# Processing of the 'interactions' dataset to compute BERTScores

In this notebook, we process the 'interactions' dataset to determine the candidate answer for each datapoint. Since almost all of the datapoints contain a multi-turn interaction, we select as the candidate answer the assistant answer that has the highest semantic similarity (measured by computing the BERTScore) with the golden answer.

In [None]:
!pip install bert-score # https://pypi.org/project/bert-score/

In [None]:
from bert_score import score

In [None]:
import json

# Load the interactions dataset. The encoding is given explicitly so the file
# is decoded as UTF-8 regardless of the platform's default locale encoding.
with open('interactions_v1.json', 'r', encoding='utf-8') as file:
    data = json.load(file)

In [None]:
# Load the reference solutions. The encoding is given explicitly so the file
# is decoded as UTF-8 regardless of the platform's default locale encoding.
with open('solutions_v1.json', 'r', encoding='utf-8') as file:
    solutions = json.load(file)

In [None]:
def flatten_and_concatenate(nested_list):
    """Collapse an arbitrarily nested list of strings into one string.

    A string is returned unchanged. A list is flattened recursively and its
    flattened elements are joined with single spaces. Any other value
    (numbers, None, dicts, ...) contributes an empty string.
    """
    if isinstance(nested_list, list):
        pieces = []
        for item in nested_list:
            pieces.append(flatten_and_concatenate(item))
        return ' '.join(pieces)

    # Not a list: only plain strings carry text, everything else is dropped.
    return nested_list if isinstance(nested_list, str) else ''

In [None]:
import time
import json

# Per-datapoint summaries (best candidate answer, gold answer, score, ...)
datapoints = []
# Full interactions, with a BERTScore attached to every assistant turn
complete_data = []
# Index of the first datapoint still to be processed; stays 0 on a fresh run
# (i.e. when no checkpoint files exist yet).
start_index = 0

try:
    # Load the previous results (used to resume the process if it was interrupted)
    with open('datapoints.json', 'r', encoding='utf-8') as f:
        datapoints = json.load(f)
    with open('complete.json', 'r', encoding='utf-8') as f:
        complete_data = json.load(f)
except (OSError, ValueError) as e:
    # Missing/unreadable file or invalid JSON: start from scratch, and reset
    # both lists so that they stay aligned with each other.
    print(f"Failed to load previous data: {e}")
    datapoints = []
    complete_data = []
else:
    # The two checkpoints are written one after the other, so an interruption
    # can leave them with different lengths. Resume from the shorter one and
    # drop the surplus entries of the longer one.
    start_index = min(len(datapoints), len(complete_data))
    del datapoints[start_index:]
    del complete_data[start_index:]

# Index the solutions by 'sol_id' once instead of scanning the whole list for
# every datapoint. setdefault keeps the first entry for a duplicated id, which
# is the entry a linear scan with an early break would pick.
solutions_by_id = {}
for entry in solutions:
    solutions_by_id.setdefault(entry.get('sol_id'), entry)

data_len = len(data)
start_time = time.time()
elapsed_times = []

# Iterate over the datapoints that have not been processed yet.
# NOTE(review): a datapoint that raises is skipped without being recorded, so
# after a failure len(datapoints) no longer equals the position reached in
# `data`, and a later resume will re-process some datapoints. Confirm whether
# this matters for the downstream consumers of the checkpoint files.
for count, datapoint in enumerate(data[start_index:], start_index + 1):
    try:
        print("########################################################")
        print("Processing datapoint", count, "of", data_len, "(", round(count/data_len*100, 2), "%)")
        print("########################################################")

        iteration_start_time = time.time()
        max_score = -1
        best_content = ''
        gold_answer = ''
        explanation = None
        is_mcq = 0

        # Look up the solution whose sol_id matches the current datapoint
        entry = solutions_by_id.get(datapoint.get("sol_id"))
        if entry is not None:
            gold_answer = entry.get('answer', '')
            if 'choices' in entry:
                is_mcq = 1
            if entry.get('explanation') is not None:
                explanation = entry['explanation']

        # The reference text is the same for every turn of this datapoint
        reference = flatten_and_concatenate(gold_answer)

        interactions = []
        # Iterate over the turns of the interaction
        for interaction in datapoint.get('interaction', []):
            # Only assistant turns are candidate answers
            if interaction.get('role') == 'assistant':
                content = interaction.get('content')
                if isinstance(content, str):
                    # bert_score.score returns (precision, recall, F1);
                    # element 2 (the F1 tensor) is used as the similarity.
                    # NOTE(review): score() is invoked once per assistant turn;
                    # bert_score also provides a BERTScorer class that keeps the
                    # model loaded between calls — consider it if this is slow.
                    score_ = score([content], [reference], model_type="bert-base-multilingual-cased")[2]

                    # Convert tensor to a single value
                    score_ = score_.item()

                    # Keep the assistant turn with the highest score
                    if score_ > max_score:
                        max_score = score_
                        best_content = content

                    interaction["BERTScore"] = float(score_)
                else:
                    # Non-text content cannot be scored. Mark it explicitly
                    # rather than reusing the score of a previous turn (or
                    # failing on an unbound variable).
                    interaction["BERTScore"] = None

            interactions.append(interaction)

        # Summary of the datapoint; the explanation is only attached to
        # multiple-choice questions that actually have one.
        datapoint_dict = {
            "candidate_answer": best_content,
            "gold_answer": gold_answer,
            "max_score": max_score,
            "MCQ": is_mcq
        }
        if is_mcq and explanation is not None:
            datapoint_dict["explanation"] = explanation

        # Add the datapoint to the list
        datapoints.append(datapoint_dict)

        complete_data.append({
            "confidence": datapoint.get("confidence", None),
            "interaction": interactions
        })

        iteration_end_time = time.time()
        elapsed_time = iteration_end_time - iteration_start_time
        elapsed_times.append(elapsed_time)

        # Estimate the remaining time from the mean duration of the
        # datapoints processed in this run.
        average_time_per_datapoint = sum(elapsed_times) / len(elapsed_times)
        remaining_datapoints = data_len - count
        estimated_time_remaining = remaining_datapoints * average_time_per_datapoint

        print("Estimated time remaining: ", round(estimated_time_remaining/60, 2), "minutes")

        # Checkpoint after every datapoint so the run can be resumed.
        # Save the list of datapoints as a JSON file
        with open('datapoints.json', 'w', encoding='utf-8') as f:
            json.dump(datapoints, f)

        # Save the complete data with BERTScore for each interaction
        with open('complete.json', 'w', encoding='utf-8') as f:
            json.dump(complete_data, f)
    except Exception as e:
        # Best effort: report the failure and move on to the next datapoint
        print(f"Failed on datapoint {count}: {e}")