In [1]:
import os
import json
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

def collect_data(root_dir, answers_filename='answers.json'):
    """Load every ``<root_dir>/<setting>/<model>/<answers_filename>`` into one frame.

    Each file is parsed with ``json.load`` and iterated item by item; each
    item presumably is a JSON object (it is indexed like a dict) — TODO
    confirm the file format against whatever produces these files. Every
    item becomes one row, tagged with ``setting_name`` and ``model_name``
    taken from the two directory levels it was found under (overwriting any
    keys of the same name already in the item).

    Directories are visited in sorted order so the resulting row order is
    reproducible; ``os.listdir`` itself returns entries in arbitrary order.
    Non-directory entries at the setting level, and model entries without the
    answers file, are skipped.

    Args:
        root_dir: top-level directory containing one sub-directory per setting.
        answers_filename: name of the per-model answers file.

    Returns:
        pandas.DataFrame with one row per answer item; an empty frame (with
        no columns) if nothing was found.
    """
    records = []
    for setting_name in sorted(os.listdir(root_dir)):
        setting_path = os.path.join(root_dir, setting_name)
        if not os.path.isdir(setting_path):
            continue
        for model_name in sorted(os.listdir(setting_path)):
            answers_file = os.path.join(setting_path, model_name, answers_filename)
            # A regular file at <setting>/<model>/<answers_filename> implies the
            # model entry is a directory, so one check replaces isdir + exists.
            if not os.path.isfile(answers_file):
                continue
            with open(answers_file, 'r', encoding='utf-8') as f:
                items = json.load(f)
            for item in items:
                item['setting_name'] = setting_name
                item['model_name'] = model_name
                records.append(item)
    return pd.DataFrame(records)

def _column_or_default(df, col, default):
    """Return ``df[col]``, or a Series of ``default`` aligned to ``df`` when the column is absent."""
    if col in df.columns:
        return df[col]
    return pd.Series(default, index=df.index)


def process_data(df):
    """Normalize annotation labels and coerce the score columns to numbers.

    ``df`` is modified in place and also returned, so both ``process_data(df)``
    and ``df = process_data(df)`` work.

    Args:
        df: answers frame (as built by ``collect_data``). ``pass_rate`` is
            required; the annotation and ``*_aware_score`` columns are optional.

    Returns:
        The same DataFrame, with:
        - ``tool_annotation`` / ``info_annotation`` stripped and lower-cased,
          then ``idk``/``yes``/``no`` mapped to ``IDK``/``Yes``/``No`` (other
          values are kept as-is; a missing column becomes all ``''``);
        - ``pass_rate`` numeric, with unparseable/missing values set to 0;
        - ``tool_aware_score`` / ``info_aware_score`` numeric, with
          unparseable/missing values (or a missing column) set to 0.

    Raises:
        KeyError: if ``pass_rate`` is not a column of ``df``.
    """
    # Canonical casing for the three expected answers.
    label_map = {'idk': 'IDK', 'yes': 'Yes', 'no': 'No'}
    for col in ('tool_annotation', 'info_annotation'):
        # The default must be a Series, not a bare '' — a plain str has no
        # ``.str`` accessor, which is what the old ``df.get(col, '')`` broke on.
        normalized = _column_or_default(df, col, '').str.strip().str.lower()
        df[col] = normalized.replace(label_map)

    df['pass_rate'] = pd.to_numeric(df['pass_rate'], errors='coerce').fillna(0)

    for col in ('tool_aware_score', 'info_aware_score'):
        # Likewise a Series default: pd.to_numeric on a scalar returns a scalar
        # with no ``.fillna``.
        df[col] = pd.to_numeric(_column_or_default(df, col, 0), errors='coerce').fillna(0)

    return df

def _awareness_counts(df, column, labels=('Yes', 'IDK', 'No')):
    """Count rows per (setting_name, model_name) for each value of ``column``, restricted to ``labels``."""
    counts = df.pivot_table(index=['setting_name', 'model_name'], columns=column,
                            aggfunc='size', fill_value=0)
    # reindex instead of counts[[...]]: a label that never occurs becomes a
    # column of zeros rather than raising KeyError; other values are dropped.
    return counts.reindex(columns=list(labels), fill_value=0)


def _group_mean(df, column):
    """Mean of ``column`` per (setting_name, model_name), keys returned as ordinary columns."""
    return df.groupby(['setting_name', 'model_name'])[column].mean().reset_index()


def compute_metrics(df):
    """Aggregate per-(setting, model) metrics from the processed answers frame.

    Args:
        df: frame with ``setting_name``, ``model_name``, ``tool_annotation``,
            ``info_annotation``, ``pass_rate``, ``tool_aware_score`` and
            ``info_aware_score`` columns (as produced by ``process_data``).

    Returns:
        A 5-tuple:
        - tool_awareness_counts, info_awareness_counts: frames indexed by
          (setting_name, model_name) with integer columns ``Yes``, ``IDK``,
          ``No`` in that order (a label that never occurs is all zeros);
        - pass_rates, tool_validity, info_validity: flat frames with columns
          ``setting_name``, ``model_name`` and the group mean of
          ``pass_rate`` / ``tool_aware_score`` / ``info_aware_score``.
    """
    tool_awareness_counts = _awareness_counts(df, 'tool_annotation')
    info_awareness_counts = _awareness_counts(df, 'info_annotation')

    pass_rates = _group_mean(df, 'pass_rate')
    tool_validity = _group_mean(df, 'tool_aware_score')
    info_validity = _group_mean(df, 'info_aware_score')

    return tool_awareness_counts, info_awareness_counts, pass_rates, tool_validity, info_validity

def plot_awareness_counts(awareness_counts, awareness_type):
    """Save one Yes/IDK/No bar chart per (setting, model) pair.

    Args:
        awareness_counts: frame indexed by (setting_name, model_name) with
            count columns ``Yes``, ``IDK``, ``No`` (as returned by
            ``compute_metrics``). A missing label column is plotted as 0.
        awareness_type: label used in each chart title and file name.

    Side effects:
        Writes ``{awareness_type}_{setting_name}_{model_name}.png`` to the
        current working directory for every (setting, model) pair.
    """
    labels = ['Yes', 'IDK', 'No']
    counts = awareness_counts.reset_index().reindex(
        columns=['setting_name', 'model_name'] + labels, fill_value=0)
    for (setting_name, model_name), group_df in counts.groupby(['setting_name', 'model_name']):
        row = group_df[labels].iloc[0]
        # A fresh figure per chart: the old pyplot-state version drew the first
        # chart onto whatever figure happened to be current in the kernel.
        fig, ax = plt.subplots()
        try:
            row.plot(kind='bar', color=['green', 'orange', 'red'], ax=ax)
            ax.set_title(f'{awareness_type} Responses for {model_name} in {setting_name} Setting')
            ax.set_xlabel('Response Type')
            ax.set_ylabel('Count')
            fig.tight_layout()
            fig.savefig(f'{awareness_type}_{setting_name}_{model_name}.png')
        finally:
            # Always release the figure, even if plotting/saving fails.
            plt.close(fig)

def plot_pass_rates(pass_rates):
    """Save a grouped bar chart of ``pass_rate`` by setting, one hue per model.

    Args:
        pass_rates: flat frame with ``setting_name``, ``model_name`` and
            ``pass_rate`` columns.

    Side effects:
        Writes ``pass_rates.png`` to the current working directory.
    """
    fig, ax = plt.subplots(figsize=(10, 6))
    sns.barplot(data=pass_rates, x='setting_name', y='pass_rate', hue='model_name', ax=ax)
    ax.set_title('Pass Rate by Setting and Model')
    ax.set_xlabel('Setting')
    ax.set_ylabel('Pass Rate')
    fig.tight_layout()
    fig.savefig('pass_rates.png')
    plt.close(fig)

def plot_validity_scores(validity_scores, validity_type):
    """Save a grouped bar chart of ``<validity_type>_aware_score`` by setting and model.

    Args:
        validity_scores: flat frame with ``setting_name``, ``model_name`` and
            ``f'{validity_type}_aware_score'`` columns.
        validity_type: column prefix (``'tool'`` or ``'info'`` in ``main``);
            also used, capitalized, in the title/y-label, and in the file name.

    Side effects:
        Writes ``{validity_type}_awareness_scores.png`` to the current
        working directory.
    """
    pretty_type = validity_type.capitalize()
    fig, ax = plt.subplots(figsize=(10, 6))
    sns.barplot(data=validity_scores, x='setting_name',
                y=f'{validity_type}_aware_score', hue='model_name', ax=ax)
    ax.set_title(f'{pretty_type} Awareness Score by Setting and Model')
    ax.set_xlabel('Setting')
    ax.set_ylabel(f'{pretty_type} Awareness Score')
    fig.tight_layout()
    fig.savefig(f'{validity_type}_awareness_scores.png')
    plt.close(fig)

def main(root_dir='.'):
    """Collect answers, compute metrics, save all plots and print the tables.

    Args:
        root_dir: directory holding ``<setting>/<model>/answers.json`` files.
            Defaults to the current directory.

    Raises:
        ValueError: if no answer rows were collected under ``root_dir``.
    """
    df = collect_data(root_dir)
    if df.empty:
        # Without this, the failure surfaces later as an opaque
        # KeyError: 'pass_rate' inside process_data.
        raise ValueError(
            f"No answers found in '<setting>/<model>/answers.json' files under {root_dir!r}")
    df = process_data(df)

    tool_awareness_counts, info_awareness_counts, pass_rates, tool_validity, info_validity = compute_metrics(df)

    # Plotting (each call writes PNG files to the working directory)
    plot_awareness_counts(tool_awareness_counts, 'Tool Awareness')
    plot_awareness_counts(info_awareness_counts, 'Information Awareness')
    plot_pass_rates(pass_rates)
    plot_validity_scores(tool_validity, 'tool')
    plot_validity_scores(info_validity, 'info')

    # Print the computed metrics, one heading per table.
    tables = [
        ("Pass Rates:", pass_rates),
        ("\nTool Awareness Counts:", tool_awareness_counts),
        ("\nInformation Awareness Counts:", info_awareness_counts),
        ("\nTool Validity Scores:", tool_validity),
        ("\nInformation Validity Scores:", info_validity),
    ]
    for heading, table in tables:
        print(heading)
        print(table)

if __name__ == '__main__':
    main()


Pass Rates:
       setting_name                      model_name  pass_rate
0          No-tools      claude3.5_sonnet_auto_eval   0.202020
1          No-tools                gpt_4o_auto_eval   0.575758
2          No-tools            gpt_4o_auto_eval_qaq   0.320000
3          No-tools            llama_405B_auto_eval   0.434343
4          No-tools             llama_70B_auto_eval   0.525253
5   Non-replaceable      claude3.5_sonnet_auto_eval   0.085106
6   Non-replaceable  claude3.5_sonnet_auto_eval_qaq   0.074468
7   Non-replaceable                gpt_4o_auto_eval   0.105263
8   Non-replaceable            gpt_4o_auto_eval_qaq   0.101266
9   Non-replaceable            llama_405B_auto_eval   0.297872
10  Non-replaceable        llama_405B_auto_eval_qaq   0.161290
11  Non-replaceable             llama_70B_auto_eval   0.043011
12  Non-replaceable         llama_70B_auto_eval_qaq   0.053191
13         Original      claude3.5_sonnet_auto_eval   0.670213
14         Original                gpt_4o_a