In [None]:
import pandas as pd
import ast
import re

# Notebook display settings: never wrap wide frames onto several lines and
# lift the caps on how many columns and rows are rendered.
pd.set_option("display.expand_frame_repr", False)
pd.set_option("display.max_columns", None)
pd.set_option("display.max_rows", None)

# Stimulus types whose "Correct Response" cell is parsed by the graders as
# several alternative answers rather than a single one.
ambiguous_types = {
    "simple_pos_dis",
    "simple_neg_dis",
    "simple_neg_conj",
    "comp_pos_dis",
    "comp_neg_dis",
    "comp_neg_conj",
}

In [None]:
# For letter and number tasks

def normalize_response(response):
    """Turn one raw response cell into a set of string items.

    Recognised spellings, checked in this order:

    * missing value or blank text      -> empty set
    * bracketed list ``[a, 'b', "c"]`` -> one item per comma-separated entry,
      with one pair of matching surrounding quotes removed from each entry
    * comma-separated text ``a, b``    -> one item per entry
    * anything else                    -> split on whitespace

    Parameters
    ----------
    response : str | number | missing
        Raw cell value. Non-string values are converted with ``str``; floats
        holding a whole number (``5.0``) are rendered without the ``.0``.

    Returns
    -------
    set[str]
        The order-insensitive items of the response.
    """
    if pd.isna(response):
        return set()

    # A numeric column containing blanks is read back as float; render 5.0 as
    # "5" so it still compares equal to an answer key written as 5.
    if isinstance(response, float) and response.is_integer():
        response = int(response)

    # Strip once, up front, so that padding cannot hide the brackets.
    text = str(response).strip()
    if not text:
        return set()

    def unquote(item):
        # Drop one pair of matching surrounding quotes ('x' or "x").
        if len(item) >= 2 and item[0] == item[-1] and item[0] in "'\"":
            return item[1:-1]
        return item

    # Bracketed list
    if text.startswith("[") and text.endswith("]"):
        inner = text[1:-1].strip()
        if not inner:
            # "[]" means "no items", the same as a blank response.
            return set()
        return {unquote(item.strip()) for item in inner.split(",")}

    # Comma-separated list without brackets
    if "," in text:
        return {item.strip() for item in text.split(",")}

    # Plain string: split on whitespace
    return set(text.split())

def check_responses(df):
    """Grade every row of ``df`` for the letter/number tasks, in place.

    Reads the "Type", "Model Response" and "Correct Response" columns and adds
    two columns to ``df``:

    * "Correct"          - 1 if the model response matches, else 0
    * "Matched Response" - index of the matching alternative (0 for rows with
      a single correct answer), or -1 when nothing matched

    For rows whose type is in ``ambiguous_types`` the "Correct Response" cell
    is parsed as a Python literal holding a sequence of alternatives; the
    model response has to equal one of them as a set. All other rows compare
    the two cells after ``normalize_response``.

    Note: later cells of this notebook define functions with this same name,
    so the definition in effect depends on which cell ran last.

    Parameters
    ----------
    df : pandas.DataFrame
        Frame to grade; modified in place. Nothing is returned.
    """
    def is_correct(row):
        model_resp = normalize_response(row["Model Response"])

        # Single-answer rows: direct comparison of the normalised sets.
        if row["Type"] not in ambiguous_types:
            correct_resp = normalize_response(row["Correct Response"])
            return (1, 0) if model_resp == correct_resp else (0, -1)

        # Ambiguous rows: any one of several alternatives is acceptable.
        try:
            alternatives = [
                set(map(str, option))
                for option in ast.literal_eval(row["Correct Response"])
            ]
        except (ValueError, SyntaxError, TypeError) as e:
            # TypeError covers keys that parse but are not a sequence of
            # sequences (e.g. a flat list of ints); report the row instead of
            # aborting the whole run or failing silently.
            print(f"Error parsing correct response for row {row.name}: {e}")
            return 0, -1

        for index, option in enumerate(alternatives):
            if model_resp == option:
                return 1, index
        return 0, -1

    df[["Correct", "Matched Response"]] = df.apply(lambda row: pd.Series(is_correct(row)), axis=1)


# Grade the digit-task results: check_responses adds the "Correct" and
# "Matched Response" columns to df in place, then the graded frame is saved
# without its index.
df = pd.read_csv('digit_task_stimuli_L3.1_8B.csv')

check_responses(df)
df.to_csv('digit_task_L3.1_8B_ch.csv', index=False)


In [None]:
# For emoji task

def normalize_emoji_response(response):
    """
    Convert one raw emoji response into a set of frozensets.

    Each frozenset is one group of emoji; the outer set collects the groups.
    Spellings are tried in this order (ASCII spaces are removed first):

    1. A Python literal. A list gives one group per element (a set element
       becomes a group, anything else a one-item group); a set gives a single
       group. Brace-wrapped text that is not a valid literal is retried with
       every bare symbol wrapped in quotes, so ``{🐇,🐖}`` parses as well.
    2. Text containing commas: split into groups on ``},{``, then each group
       is split on commas after braces and single quotes are removed.
    3. Otherwise: the whitespace-separated tokens form a single group.

    Parameters
    ----------
    response : str | missing
        Raw cell value; non-string values are converted with ``str``.

    Returns
    -------
    set[frozenset[str]]
        Empty set for a missing or blank response.
    """
    if pd.isna(response):
        return set()
    text = str(response).strip()
    if not text:
        return set()

    # Keep the whitespace-separated tokens before the spaces are removed;
    # otherwise '🐇 🦨' collapses into the single token '🐇🦨' and the final
    # fallback can never split it.
    space_tokens = text.split()

    # Remove spaces so that "a, b" and "a,b" are treated alike.
    compact = text.replace(" ", "")

    # Try the text as written first. Quoting is only a retry, because applying
    # it to emoji that are already quoted doubles the quotes and breaks input
    # such as {'🐇'}.
    candidates = [compact]
    if compact.startswith("{") and compact.endswith("}"):
        candidates.append(re.sub(r"([^\w\s,{}'])", r"'\1'", compact))

    for candidate in candidates:
        try:
            parsed = ast.literal_eval(candidate)
        except (ValueError, SyntaxError, TypeError):
            # TypeError: e.g. a set nested inside a set is unhashable.
            continue

        # List: one group per element
        if isinstance(parsed, list):
            return {frozenset(map(str, s)) if isinstance(s, set) else frozenset([str(s)]) for s in parsed}

        # Set: a single group
        if isinstance(parsed, set):
            return {frozenset(map(str, parsed))}

        # Any other literal (tuple, dict, str, ...) is left to the textual
        # fallbacks below.
        break

    # Comma-separated groups such as {'🐇'},{'🐖'} or a bare 🐇,🐖
    if ',' in compact:
        groups = set()
        for chunk in compact.split('},{'):
            chunk = chunk.strip().strip('{}').replace("'", "")
            if chunk:
                groups.add(frozenset(chunk.split(',')))
        return groups

    # Space-separated emoji (like '🐇 🦨'): one group holding every token
    return {frozenset(space_tokens)}

def check_responses(df):
    """Grade every row of ``df`` for the emoji task, in place.

    Reads the "Type", "Model Response" and "Correct Response" columns and adds
    two columns to ``df``:

    * "Correct"          - 1 if the model response matches, else 0
    * "Matched Response" - 0 or 1 for the alternative that matched (always 0
      for rows with a single correct answer), or -1 when nothing matched

    For rows whose type is in ``ambiguous_types`` the "Correct Response" cell
    must be a Python literal tuple of exactly two alternatives; each one is
    converted to a set of frozensets and compared with the normalised model
    response. A key that cannot be parsed is reported and the row is graded
    incorrect. All other rows compare the two cells after
    ``normalize_emoji_response``.

    Note: this name is also defined in other cells of this notebook, so the
    definition in effect depends on which cell ran last.

    Parameters
    ----------
    df : pandas.DataFrame
        Frame to grade; modified in place. Nothing is returned.
    """
    def as_groups(items):
        # A set element becomes one group; anything else a one-item group.
        return {
            frozenset(map(str, s)) if isinstance(s, set) else frozenset([str(s)])
            for s in items
        }

    def is_correct(row):
        model_resp = normalize_emoji_response(row["Model Response"])

        # Single-answer rows: direct comparison of the normalised responses.
        if row["Type"] not in ambiguous_types:
            correct_resp = normalize_emoji_response(row["Correct Response"])
            return (1, 0) if model_resp == correct_resp else (0, -1)

        try:
            # Parse Correct Response as a tuple of two possible correct lists
            correct_responses_tuple = ast.literal_eval(row["Correct Response"])
            if not isinstance(correct_responses_tuple, tuple) or len(correct_responses_tuple) != 2:
                raise ValueError("Incorrect format for correct responses")
            alternatives = [as_groups(option) for option in correct_responses_tuple]
        except (ValueError, SyntaxError, TypeError) as e:
            print(f"Error parsing correct response for row {row.name}: {e}")
            # Return at once: substituting empty alternatives here would grade
            # an empty model response as correct against an unreadable key.
            return 0, -1

        # Model response should match either alternative exactly
        for index, option in enumerate(alternatives):
            if model_resp == option:
                return 1, index
        return 0, -1

    df[["Correct", "Matched Response"]] = df.apply(lambda row: pd.Series(is_correct(row)), axis=1)


# Grade the emoji-task results: check_responses adds the "Correct" and
# "Matched Response" columns to df in place, then the graded frame is saved
# without its index.
df = pd.read_csv('emoji_task_stimuli_L3.1_8B.csv')

check_responses(df)
df.to_csv('emoji_task_L3.1_8B_ch.csv', index=False)


In [None]:
# For word task

def normalize_single_response(response):
    """Turn one raw word-task response into a set of sentences.

    Surrounding whitespace and one pair of enclosing square brackets are
    removed, every single and double quote character is deleted, and the
    remaining text is split on commas and periods. Pieces are trimmed and
    empty pieces dropped.

    Parameters
    ----------
    response : str | missing
        Raw cell value; non-string values are converted with ``str``.

    Returns
    -------
    set[str]
        The sentences of the response; empty set for a missing or blank cell.
    """
    if pd.isna(response):
        return set()

    # Strip before looking for brackets so padding cannot hide them.
    text = str(response).strip()
    if not text:
        return set()

    # Drop one pair of enclosing square brackets
    if text.startswith("[") and text.endswith("]"):
        text = text[1:-1].strip()

    # Remove single and double quotes around (and inside) sentences
    text = text.replace("'", "").replace('"', '')

    # Split on commas (with optional surrounding spaces) and on periods
    pieces = re.split(r'\s*,\s*|\.\s*', text)

    return {piece.strip() for piece in pieces if piece.strip()}

def _split_sentences(text):
    """Delete quote characters, split on commas/periods, drop empty pieces."""
    cleaned = str(text).replace("'", "").replace('"', '')
    pieces = (piece.strip() for piece in re.split(r'\s*,\s*|\.\s*', cleaned))
    return [piece for piece in pieces if piece]


def _ordered_alternatives(response):
    """Parse a multi-answer key into a list of frozensets, in written order.

    The list keeps the order (and any duplicates) of the alternatives as they
    appear in ``response``, so a position in it is a stable index.
    """
    if pd.isna(response):
        return []
    text = str(response).strip()
    if not text:
        return []

    try:
        # Attempt to parse as a Python literal (list of lists or single list)
        parsed = ast.literal_eval(text)
    except (ValueError, SyntaxError, TypeError):
        parsed = None  # Fall back to manual parsing below

    if isinstance(parsed, list):
        # A list of lists holds several alternatives; any other list is one.
        groups = parsed if all(isinstance(lst, list) for lst in parsed) else [parsed]
        # Normalise every sentence exactly like the textual fallback does, so
        # the result does not depend on which path parsed the key.
        return [
            frozenset(
                piece
                for sentence in group
                for piece in _split_sentences(str(sentence).strip("[] "))
            )
            for group in groups
        ]

    # Fallback: manually extract the bracketed lists, keeping them separate
    text = text.replace("(", "").replace(")", "")  # Remove parentheses
    text = text.replace("'", "").replace('"', '')  # Remove quotes
    return [
        frozenset(_split_sentences(found.strip("[] ")))
        for found in re.findall(r'\[.*?\]', text)  # Each [...] is one alternative
    ]


def normalize_multiple_response(response):
    """Parse a multi-answer key into a set of frozensets of sentences.

    ``response`` is read as a Python literal when possible: a list of lists
    gives one frozenset per inner list, any other list gives one frozenset.
    Otherwise parentheses and quotes are removed and every ``[...]`` span
    becomes one frozenset. In both cases sentences have quote characters
    removed and are split on commas and periods, with empty pieces dropped.

    Parameters
    ----------
    response : str | missing
        Raw cell value; non-string values are converted with ``str``.

    Returns
    -------
    set[frozenset[str]]
        Unordered alternatives; empty set for a missing or blank cell.
    """
    return set(_ordered_alternatives(response))


def check_responses(df):
    """Grade every row of ``df`` for the word task, in place.

    Reads the "Type", "Model Response" and "Correct Response" columns and adds
    two columns to ``df``:

    * "Correct"          - 1 if the model response matches, else 0
    * "Matched Response" - position of the matching alternative within the
      "Correct Response" cell (0 for rows with a single correct answer), or
      -1 when nothing matched

    Rows whose type is in ``ambiguous_types`` accept any one of the
    alternatives listed in "Correct Response"; all other rows compare the two
    cells after ``normalize_single_response``.

    Note: this name is also defined in earlier cells of this notebook, so the
    definition in effect depends on which cell ran last.

    Parameters
    ----------
    df : pandas.DataFrame
        Frame to grade; modified in place. Nothing is returned.
    """
    def is_correct(row):
        model_resp = normalize_single_response(row["Model Response"])

        # Single-answer rows: direct comparison of the normalised sets.
        if row["Type"] not in ambiguous_types:
            correct_resp = normalize_single_response(row["Correct Response"])
            return (1, 0) if model_resp == correct_resp else (0, -1)

        try:
            # Ordered list rather than a set: enumerating a set would make the
            # reported index depend on hash order, which is not reproducible.
            correct_responses = _ordered_alternatives(row["Correct Response"])
        except Exception as e:
            print(f"Error parsing correct response for row {row.name}: {e}")
            correct_responses = []

        model_resp = frozenset(model_resp)
        # Model response should match one of the alternative sets of sentences
        for index, cr in enumerate(correct_responses):
            if model_resp == cr:
                return 1, index
        return 0, -1

    df[["Correct", "Matched Response"]] = df.apply(lambda row: pd.Series(is_correct(row)), axis=1)





# Disabled ad-hoc check against 'checker_test.csv', kept as a bare string
# literal so that it is evaluated but never executed.
'''
df = pd.read_csv('checker_test.csv')
df = df[['Type', 'Correct Response', 'Model Response']]
check_responses(df)
df = df.iloc[18:]
'''

# Grade the word-task results: check_responses adds the "Correct" and
# "Matched Response" columns to df in place, then the graded frame is saved
# without its index.
df = pd.read_csv('word_task_stimuli_L3.1_8B.csv')

check_responses(df)
df.to_csv('word_task_L3.1_8B_ch.csv', index=False)