In [31]:
from src.experiments_utils import *
from matplotlib import pyplot as plt

In [34]:
def do_comparison(combined_path_list, last_path_list, key_param_list, key_param, exp_name, params, prior_metrics, lp_list, used_regimes=None):
    """Combine the per-model IFR result tables of an experiment into one frame.

    Looks for an ``ifr_df.csv`` in the parent folder of every path in ``lp_list``.
    Each table found is tagged with ``prior`` (``epoch == 0``), ``model`` (folder
    name without ``-retrain``) and ``retrain`` (``'retrain'`` in the folder path).
    The tables are concatenated and written to
    ``supervised/<exp_name>/c/combined_ifr_df.csv``. The frame is then
    de-duplicated, ignoring the ``loss`` column, and passed through
    ``adjust_epochs``. The further analysis stages below are switched off with
    ``if False:`` guards or sit after the ``return``.

    Args:
        combined_path_list: Per-key-param lists of evaluation result paths
            (used only by the disabled stages).
        last_path_list: Per-key-param lists of last-epoch result paths
            (used only by the unreachable tail).
        key_param_list: Key-param values aligned with ``combined_path_list``.
        key_param: Column name used to group results (e.g. ``'regime'``).
        exp_name: Experiment directory name below ``supervised/``.
        params: Scenario parameter names (unreachable tail only).
        prior_metrics: Prior metric names (unreachable tail only).
        lp_list: Nested list of file paths. The parent folder of each path is
            searched for an ``ifr_df.csv``.
        used_regimes: Forwarded to ``calculate_statistics`` (unreachable tail only).

    Returns:
        The output directory ``supervised/<exp_name>/c``, which is created if
        missing. It is also returned early when no IFR tables were found.
    """

    combined_path = os.path.join('supervised', exp_name, 'c')
    os.makedirs(combined_path, exist_ok=True)

    print('loading IFRs')
    folder_paths = {os.path.dirname(file_path) for sublist in lp_list for file_path in sublist}

    dataframes = []
    # Sorted so the concatenation order is the same on every run. That order
    # decides which row drop_duplicates keeps, and set iteration order is not stable.
    for folder in sorted(folder_paths):
        csv_path = os.path.join(folder, 'ifr_df.csv')
        if os.path.exists(csv_path):
            try:
                df = pd.read_csv(csv_path)
                base_name = os.path.basename(folder)
                base_name = base_name.replace('-retrain', '')
                df = df.dropna(subset=['epoch'])
                df['prior'] = df['epoch'] == 0
                df['model'] = base_name
                df['retrain'] = 'retrain' in folder
                dataframes.append(df)
            except Exception as e:
                # Best effort: one unreadable table should not abort the whole comparison.
                print(f"Error reading {csv_path}: {e}")

    if not dataframes:
        # Nothing to combine. Without this guard, combined_ifr_df below would be unbound.
        print('no ifr_df.csv found in', sorted(folder_paths))
        return combined_path

    combined_ifr_df = pd.concat(dataframes, ignore_index=True)
    combined_ifr_df.to_csv(os.path.join(combined_path, 'combined_ifr_df.csv'), index=False)

    # Rows that differ only in 'loss' count as duplicates.
    columns_to_ignore = ['loss']
    columns_to_consider = [col for col in combined_ifr_df.columns if col not in columns_to_ignore]
    current_length = len(combined_ifr_df)
    combined_ifr_df = combined_ifr_df.drop_duplicates(subset=columns_to_consider)
    combined_ifr_df = adjust_epochs(combined_ifr_df)
    print('duplicates removed', current_length, len(combined_ifr_df))

    # Disabled stage: per-model h5 data, experiment 3 and accuracy tables.
    if False:
        print('loading baseline h5s!!!!!')
        all_baseline_data = []
        base_path = os.path.dirname(list(folder_paths)[0])
        for regime in ['s1', 's2', 's3']:
            individual_data = load_h5_data(os.path.join(base_path, f'indy_ifr_base-{regime}-0.h5'), regime=regime, retrain=False, prior=False)
            all_baseline_data.append(individual_data)
        all_baseline_df = pd.concat(all_baseline_data, ignore_index=True)


        ### Get individual data
        print('loading model h5s')
        all_indy_data = []
        for folder in folder_paths:
            regime = os.path.basename(folder).replace('-retrain', '')
            retrain = 'retrain' in folder
            prior = 'prior' in folder
            for rep in [0, 1, 2]:
                individual_data = load_h5_data(os.path.join(folder, f'indy_ifr{rep}.h5'), regime=regime, retrain=retrain, prior=prior)
                all_indy_data.append(individual_data)
        all_indy_df = pd.concat(all_indy_data, ignore_index=True)



        if False:
            # NOTE(review): other_keys is not defined in this function.
            all_indy_all = all_indy_df[all_indy_df['act_key'] == 'all-activations']
            all_indy_final = all_indy_df[all_indy_df['act_key'] == 'final-layer-activations']
            make_splom_aux(all_indy_all, 'all', os.path.join(combined_path, 'sploms'))
            make_splom_aux(all_indy_final, 'final', os.path.join(combined_path, 'sploms'))
            make_scatters2(all_indy_all, all_indy_final, other_keys, combined_path)

        print('loading evaluation results')


        all_epochs_df = load_dataframes(combined_path_list, key_param_list, key_param)
        all_epochs_df['is_novel_task'] = all_epochs_df.apply(is_novel_task, axis=1)
        print('columns and epochs', all_epochs_df.columns, all_epochs_df['epoch'].unique())

        # run experiment 3:
        if True:
            print('running experiment 3')
            print(all_epochs_df.columns)
            print(all_indy_df.columns)
            for retrain in [True, False]:
                merged_df = pd.merge(all_indy_df[(all_indy_df['model_type'] == 'mlp1') & (all_indy_df['retrain'] == retrain)], all_epochs_df[['id', 'regime', 'is_novel_task', 'accuracy', 'repetition']], on=['id', 'regime', 'repetition'], how='left')
                merged_df = merged_df.drop_duplicates()
                print(len(merged_df))
                print(merged_df.head())
                print(merged_df.dtypes)
                print(merged_df['accuracy'].describe())
                print(merged_df['other_key'].value_counts())
                do_prediction_dependency(merged_df, 'accuracy', combined_path, retrain)


        all_epochs_df['regime_short'] = all_epochs_df['regime'].str[-2:]
        merged_baselines = pd.merge(all_baseline_df, all_epochs_df[['id', 'regime_short', 'is_novel_task', 'retrain']], left_on=['id', 'regime', 'retrain'], right_on=['id', 'regime_short', 'retrain'], how='left')
        all_baselines_df = merged_baselines
        all_baselines_df = all_baselines_df.drop_duplicates()

        # figure out whether each item in all_indy_df is novel:
        merged_df = pd.merge(all_indy_df, all_epochs_df[['id', 'regime', 'is_novel_task', 'retrain']], on=['id', 'regime', 'retrain'], how='left')
        all_indy_df = merged_df
        all_indy_df = all_indy_df.drop_duplicates()

        print('generating accuracy tables')
        if False:
            for retrain in [True, False]:
                baseline_tables = generate_accuracy_tables(all_baselines_df[(all_baselines_df['retrain'] == retrain) & (all_baselines_df['prior'] == False)], combined_path, is_baseline=True, retrain=retrain)
                result_tables = generate_accuracy_tables(all_indy_df[(all_indy_df['retrain'] == retrain) & (all_indy_df['prior'] == False)], combined_path, retrain=retrain)

    # Disabled stage: prediction-dependency bar graphs.
    if False:
        print('doing dependency heatmaps')
        for strat in ['normal', 'box']:
            if strat == "normal":
                strategies = {
                    'No-Mindreading': ['opponents', 'big-loc', 'small-loc'],
                    'Low-Mindreading': ['vision', 'fb-exist'],
                    'High-Mindreading': ['fb-loc', 'b-loc', 'target-loc']
                }
            else:
                strategies = {
                    'No-Mindreading': ['big-loc', 'big-box', 'small-loc', 'small-box'],
                    'High-Mindreading': ['fb-loc', 'fb-box', 'b-loc', 'b-box', 'target-loc', 'target-box',]
                }


            for retrain in [False, True]:
                dep_df = pd.read_csv(os.path.join(combined_path, f'accuracy_dependencies_retrain_{retrain}.csv'))

                for layer in ['all-activations', 'final-layer-activations']:
                    plot_dependency_bar_graphs_new(dep_df, combined_path, strategies, True, retrain=retrain, strat=strat, layer=layer)
            #create_faceted_heatmap(dep_df, True, 'final-layer-activations', os.path.join(combined_path, 'test.png'), strategies)

    return combined_path
    # NOTE(review): everything below is unreachable. If re-enabled it needs
    # all_epochs_df, which is only assigned inside the disabled stage above, and
    # plot_bar_graphs_new2, which is not defined in the visible source
    # (possibly provided by a star import; verify).
    strategies_short = {
        'Opponents': ['opponents'],
        'Location Beliefs': ['b-loc']
    }
    strategies_long = {
        'No-Mindreading': ['pred', 'opponents', 'big-loc', 'small-loc'],
        'Low-Mindreading': ['vision', 'fb-exist'],
        'High-Mindreading': ['fb-loc', 'b-loc', 'target-loc', 'labels']
    }
    strategies_both = {
        'No-Mindreading': ['big-loc', 'big-box', 'small-loc', 'small-box'],
        'High-Mindreading': ['fb-loc', 'fb-box', 'b-loc', 'b-box', 'target-loc', 'target-box', ]
    }
    for retrain in [True, False]:
        baseline_tables = pd.read_csv(os.path.join(combined_path, 'base_all_table_retrain_False.csv'))
        result_tables = pd.read_csv(os.path.join(combined_path, f'all_table_retrain_{retrain}.csv'))

        print('made acc tables', len(baseline_tables), len(result_tables))
        print(baseline_tables.columns, result_tables.columns)

        this_path = os.path.join(combined_path, f'strats-rt-{retrain}')

        for result_type in ['mlp1', 'linear']:
            for layer in ['all', 'final-layer']:
                plot_bar_graphs_new2(baseline_tables[(baseline_tables['Model_Type'] == result_type)], result_tables[(result_tables['Model_Type'] == result_type)], this_path, strategies_short, layer=layer, r_type=f'spec-{result_type}')
                plot_bar_graphs_new2(baseline_tables[(baseline_tables['Model_Type'] == result_type)], result_tables[(result_tables['Model_Type'] == result_type)], this_path, strategies_long, layer=layer, r_type=result_type)
                #plot_bar_graphs_new(baseline_tables[(baseline_tables['Model_Type'] == result_type)], result_tables[(result_tables['Model_Type'] == result_type)], this_path, strategies_both, layer=layer, r_type=f'both-{result_type}')

        #plot_bar_graphs_special(baseline_tables, result_tables, os.path.join(combined_path, 'strats'), strategies)
        #plot_bar_graphs(baseline_tables[(baseline_tables['Model_Type'] == 'mlp1')], result_tables[(result_tables['Model_Type'] == 'mlp1')], os.path.join(combined_path, f'strats-rt-{retrain}'), strategies)

    print('Original columns and epochs:', all_epochs_df.columns, all_epochs_df['epoch'].unique())

    all_epochs_df = adjust_epochs(all_epochs_df)

    grouped_df = group_eval_df(all_epochs_df)

    print('merging dfs')
    merged_df = pd.merge(combined_ifr_df, grouped_df, on=['rep', 'model', 'epoch', 'retrain', 'prior'])
    print('after merge', merged_df['epoch'].unique(), merged_df['retrain'].unique(), merged_df['prior'].unique())

    # SPLOM
    for act in ['all_activations', 'final_layer_activations', 'input_activations']:
        try:
            make_splom(merged_df[(merged_df['act'] == act)], combined_path, act, False, True)
            make_scatter(merged_df[merged_df['act'] == act], combined_path, act)
        except BaseException as e:
            print('failed a splom', e)

    make_corr_things(merged_df, combined_path)

    last_epoch_df = load_dataframes(last_path_list, key_param_list, key_param)

    print('last epoch df columns', last_epoch_df.columns)

    # NOTE(review): the statistics used after this `if` are only assigned inside it.
    if all_epochs_df is not None and last_epoch_df is not None:
        create_combined_histogram(last_epoch_df, all_epochs_df, key_param, os.path.join('supervised', exp_name))

        avg_loss, variances, ranges_1, ranges_2, range_dict, range_dict3, stats, key_param_stats, oracle_stats, delta_sum, delta_x = calculate_statistics(
            all_epochs_df, last_epoch_df, list(set(params + prior_metrics + [key_param])),
            skip_3x=True, skip_1x=True, key_param=key_param, used_regimes=used_regimes, savepath=os.path.join('supervised', exp_name), last_timestep=True)



    write_metrics_to_file(os.path.join(combined_path, 'metrics.txt'), last_epoch_df, ranges_1, params, stats,
                          key_param=key_param, d_s=delta_sum, d_x=delta_x)
    save_figures(combined_path, combined_ifr_df, avg_loss, ranges_2, range_dict, range_dict3,
                 params, last_epoch_df, num=12, key_param_stats=key_param_stats, oracle_stats=oracle_stats,
                 key_param=key_param, delta_sum=delta_sum, delta_x=delta_x)

In [4]:
import argparse
import ast
import json
import os
import traceback
import multiprocessing
from itertools import product

import h5py

import pandas as pd
import tqdm

from src.pz_envs import ScenarioConfigs
from src.supervised_learning import gen_data
from src.utils.plotting import create_combined_histogram, plot_progression, save_key_param_figures, plot_learning_curves, make_splom, make_ifrscores, make_scatter, make_corr_things, make_splom_aux, plot_strategy_bar, \
    create_faceted_heatmap, plot_bar_graphs, plot_bar_graphs_special, plot_bar_graphs_new, plot_dependency_bar_graphs, plot_dependency_bar_graphs_new
from supervised_learning_main import run_supervised_session, calculate_statistics, write_metrics_to_file, save_figures, \
    train_model
import numpy as np
import random

import warnings

  from .autonotebook import tqdm as notebook_tqdm


In [19]:

def experiments(todo, repetitions, epochs=50, batches=5000, skip_train=False, skip_calc=False, batch_size=64, desired_evals=5,
                skip_eval=False, skip_activations=False, last_timestep=True, retrain=False, current_model_type=None, current_label=None, current_label_name=None,
                comparison=False):
    """Run the requested supervised-learning experiment stages.

    Args:
        todo: Collection of stage ids. ``0`` generates datasets, ``'h'`` runs a
            hyperparameter search, and ``2`` trains/evaluates models on the
            s1/s3/s2 training regimes. ``todo[0]`` also names the output
            directory ``supervised/exp_<todo[0]>``.
        repetitions, epochs, batches, batch_size: Forwarded to
            ``run_supervised_session`` via ``session_params``.
        skip_train, skip_eval, skip_calc, skip_activations: Stage-skip flags,
            also forwarded via ``session_params``.
        desired_evals: Approximate number of saves per run. Sets
            ``save_every = max(1, epochs // desired_evals)``; values below 1 are
            treated as 1.
        last_timestep: Forwarded via ``session_params``; when true, ``-L`` is
            appended to the experiment directory name.
        retrain: Forwarded as ``do_retrain_model``.
        current_model_type, current_label, current_label_name: When
            ``current_model_type`` is given, stage 2 runs only that model with the
            single ``(current_label, current_label_name)`` pair. Otherwise it runs
            ``cnn``, ``smlp`` and ``clstm`` on ``('correct-loc', 'loc')``.
        comparison: After stage 2, call ``do_comparison`` on the collected paths.
    """

    # Clamping the divisor avoids ZeroDivisionError when desired_evals is 0.
    save_every = max(1, epochs // max(1, desired_evals))

    params = ['visible_baits', 'swaps', 'visible_swaps', 'first_swap_is_both',
              'second_swap_to_first_loc', 'delay_2nd_bait', 'first_bait_size',
              'uninformed_bait', 'uninformed_swap', 'first_swap', 'test_regime']
    prior_metrics = ['shouldAvoidSmall', 'correct-loc', 'incorrect-loc',
                     'shouldGetBig', 'informedness', 'p-b-0', 'p-b-1', 'p-s-0', 'p-s-1', 'delay', 'opponents']

    sub_regime_keys = [
        "Nn", "Fn", "Nf", "Tn", "Nt", "Ff", "Tf", "Ft", "Tt"
    ]
    # Regime names are 'sl-<key><0|1>', where the trailing digit is 0 for
    # regimes['noOpponent'] and 1 for regimes['direct'].
    all_regimes = ['sl-' + x + '0' for x in sub_regime_keys] + ['sl-' + x + '1' for x in sub_regime_keys]
    mixed_regimes = {k: ['sl-' + x + '0' for x in sub_regime_keys] + ['sl-' + k + '1'] for k in sub_regime_keys}

    regimes = {}
    regimes['direct'] = ['sl-' + x + '1' for x in sub_regime_keys]
    regimes['noOpponent'] = ['sl-' + x + '0' for x in sub_regime_keys]
    regimes['everything'] = all_regimes

    # Alternative regime sets; currently not used by any stage below.
    hregime = {}
    hregime['homogeneous'] = ['sl-Tt0', 'sl-Ff0', 'sl-Nn0', 'sl-Tt1', 'sl-Ff1', 'sl-Nn1']
    #hregime['identity'] = ['sl-' + x + '0' for x in sub_regime_keys] + ['sl-Tt1', 'sl-Ff1', 'sl-Nn1']
    sregime = {}
    sregime['special'] = ['sl-Tt0', 'sl-Tt1', 'sl-Nt0', 'sl-Nt1', 'sl-Nf0', 'sl-Nf1', 'sl-Nn0', 'sl-Nn1']

    # Training regimes for stage 2. Insertion order (s1, s3, s2) is the run order.
    fregimes = {}
    fregimes['s1'] = regimes['noOpponent']
    fregimes['s3'] = regimes['everything']
    fregimes['s2'] = mixed_regimes['Tt']
    #fregimes['homogeneous'] = hregime['homogeneous']

    # Also currently unused alternatives.
    single_regimes = {k[3:]: [k] for k in all_regimes}
    leave_one_out_regimes = {
        "lo_" + held_out: ['sl-' + x + '0' for x in sub_regime_keys]
        + ['sl-' + x + '1' for x in sub_regime_keys if x != held_out]
        for held_out in sub_regime_keys
    }

    pref_types = [
        ('same', ''), # ('different', 'd'), # ('varying', 'v'),
    ]
    role_types = [
        ('subordinate', ''), # ('dominant', 'D'), # ('varying', 'V'),
    ]

    # labels for ICLR, including size just in case
    labels = [
        'id', 'i-informedness', # must have these or it will break
        'opponents',
        'big-loc',
        'small-loc',
        'target-loc',
        'b-loc',
        'fb-loc',
        'fb-exist',
        'vision',
        'big-box',
        'small-box',
        'target-box',
        'b-box',
        'fb-box',
        'box-locations'
    ]

    conf = ScenarioConfigs()
    exp_name = f'exp_{todo[0]}'
    if last_timestep:
        exp_name += "-L"

    session_params = {
        'repetitions': repetitions,
        'epochs': epochs,
        'batches': batches,
        'skip_train': skip_train,
        'skip_eval': skip_eval,
        'batch_size': batch_size,
        'prior_metrics': list(set(prior_metrics + labels)),
        'save_every': save_every,
        'skip_calc': skip_calc,
        'act_label_names': labels,
        'skip_activations': skip_activations,
        #'oracle_is_target': False,
        'last_timestep': last_timestep,
    }
    if 0 in todo:
        print('Generating datasets with labels', labels)
        os.makedirs('supervised', exist_ok=True)
        for pref_type, pref_suffix in pref_types:
            for role_type, role_suffix in role_types:
                gen_data(labels, path='supervised', pref_type=pref_suffix, role_type=role_suffix,
                         prior_metrics=prior_metrics, conf=conf)

    if 'h' in todo:
        # NOTE(review): run_hparam_search is not imported in the visible source;
        # this branch raises NameError unless it comes from a star import.
        print('Running hyperparameter search on all regimes, pref_types, role_types')
        run_hparam_search(trials=100, repetitions=3, log_file='hparam_file.txt', train_sets=regimes['direct'], epochs=20)

    if 2 in todo:
        print('Running experiment 1: base, different models and answers')

        combined_path_list = []
        last_path_list = []
        lp_list = []
        key_param = 'regime'
        key_param_list = []
        session_params['oracle_is_target'] = False

        if current_model_type is not None:
            model_types = [current_model_type]
            label_tuples = [(current_label, current_label_name)]
        else:
            model_types = ['cnn', 'smlp', 'clstm']
            label_tuples = [('correct-loc', 'loc')]

        # Path filters applied to each session's output paths (loop-invariant).
        conditions = [
            (lambda x: 'prior' not in x and 'retrain' not in x, ''),
            #(lambda x: 'prior' in x and 'retrain' not in x, '-prior'),
            #(lambda x: 'prior' not in x and 'retrain' in x, '-retrain')
        ]

        for label, label_name in label_tuples:
            for model_type in model_types:
                for regime, train_sets in fregimes.items():
                    kpname = f'{model_type}-{label_name}-{regime}'
                    print(model_type + '-' + label_name, 'regime:', regime, 'train_sets:', train_sets)
                    combined_paths, last_epoch_paths, lp = run_supervised_session(
                        save_path=os.path.join('supervised', exp_name, kpname),
                        train_sets=train_sets,
                        eval_sets=fregimes['s3'],
                        oracle_labels=[None],
                        key_param=key_param,
                        key_param_value=kpname,
                        label=label,
                        model_type=model_type,
                        do_retrain_model=retrain,
                        **session_params
                    )

                    print('paths found', combined_paths, last_epoch_paths)

                    for condition, suffix in conditions:
                        last_path_list.append([x for x in last_epoch_paths if condition(x)])
                        combined_path_list.append([x for x in combined_paths if condition(x)])
                        key_param_list.append(kpname + suffix)
                    lp_list.append(lp) # has x, x-retrain currently

        if comparison:
            do_comparison(combined_path_list, last_path_list, key_param_list, key_param, exp_name, params, prior_metrics, lp_list)

In [39]:
# Stage 2 only, with every skip_* flag set (these are forwarded to
# run_supervised_session through session_params), followed by the comparison.
run_kwargs = dict(
    repetitions=3,
    batches=10000,
    batch_size=256,
    desired_evals=1,
    skip_train=True,
    skip_eval=True,
    skip_calc=True,
    skip_activations=True,
    retrain=False,
    last_timestep=True,
    comparison=True,
)
experiments([2], **run_kwargs)
print('finished')

{'Nn': 24, 'Tt': 104, 'Nt': 32, 'Tn': 32, 'Ff': 28, 'Tf': 16, 'Ft': 16, 'Nf': 22, 'Fn': 22}
list_events 296 total fillers 440 total permutations 23520
Running experiment 1: base, different models and answers
cnn-loc regime: s1 train_sets: ['sl-Nn0', 'sl-Fn0', 'sl-Nf0', 'sl-Tn0', 'sl-Nt0', 'sl-Ff0', 'sl-Tf0', 'sl-Ft0', 'sl-Tt0']
sup sess dfs paths ['supervised\\exp_2-L\\cnn-loc-s1\\param_losses_49_0.csv', 'supervised\\exp_2-L\\cnn-loc-s1\\param_losses_49_1.csv', 'supervised\\exp_2-L\\cnn-loc-s1\\param_losses_49_2.csv'] supervised\exp_2-L\cnn-loc-s1
sup sess dfs paths ['supervised\\exp_2-L\\cnn-loc-s1\\param_losses_49_0.csv', 'supervised\\exp_2-L\\cnn-loc-s1\\param_losses_49_1.csv', 'supervised\\exp_2-L\\cnn-loc-s1\\param_losses_49_2.csv', 'supervised\\exp_2-L\\cnn-loc-s1-retrain\\param_losses_49_0.csv', 'supervised\\exp_2-L\\cnn-loc-s1-retrain\\param_losses_49_1.csv', 'supervised\\exp_2-L\\cnn-loc-s1-retrain\\param_losses_49_2.csv'] supervised\exp_2-L\cnn-loc-s1
paths found ['supervised

In [28]:
def get_sort_key(model):
    """Rank a model name for display order: MLP (1), CNN (2), LSTM (3), anything else (4).

    Matching is a case-insensitive substring test, checked in that order, so the
    first matching family wins.
    """
    lowered = model.lower()
    for rank, family in enumerate(('mlp', 'cnn', 'lstm'), start=1):
        if family in lowered:
            return rank
    return 4

In [71]:
def _draw_feature_bars(ax, lookup, strategies, acc_tag, bar_color, err_color,
                       bar_width, feature_gap, strategy_gap, error_offset):
    """Draw familiar/novel accuracy bars for every feature on ``ax``; return the final x offset.

    ``lookup(feature)`` returns the matching rows of an accuracy table (possibly
    empty); only the first row is used. ``acc_tag`` selects the columns, e.g.
    ``'input'`` reads ``'Familiar accuracy (input-activations)'``. Familiar
    accuracy is an outlined white bar and novel accuracy a filled bar drawn after
    it. Their q1/q3 error bars are shifted right/left by ``error_offset``.
    Features missing from the table leave an empty slot.
    """
    x_offset = 0
    for features in strategies.values():
        for feature in features:
            rows = lookup(feature)
            if len(rows) > 0:
                familiar = rows[f'Familiar accuracy ({acc_tag}-activations)'].values[0]
                novel = rows[f'Novel accuracy ({acc_tag}-activations)'].values[0]
                familiar_q1 = rows[f'Familiar q1 ({acc_tag}-activations)'].values[0]
                familiar_q3 = rows[f'Familiar q3 ({acc_tag}-activations)'].values[0]
                novel_q1 = rows[f'Novel q1 ({acc_tag}-activations)'].values[0]
                novel_q3 = rows[f'Novel q3 ({acc_tag}-activations)'].values[0]

                ax.bar(x_offset, familiar, bar_width, color='white', edgecolor=bar_color)
                ax.bar(x_offset, novel, bar_width, color=bar_color)
                ax.errorbar(x_offset + error_offset, familiar,
                            yerr=[[familiar - familiar_q1], [familiar_q3 - familiar]],
                            color=err_color, capsize=3, fmt='none')
                ax.errorbar(x_offset - error_offset, novel,
                            yerr=[[novel - novel_q1], [novel_q3 - novel]],
                            color=err_color, capsize=3, fmt='none')
                ax.text(x_offset, -0.05, feature, ha='right', va='top',
                        rotation=45, rotation_mode='anchor', fontsize=12)
            x_offset += bar_width + feature_gap
        x_offset += strategy_gap
    return x_offset


def _style_accuracy_axes(ax, col, x_min, x_max):
    """Apply the shared panel styling: 0-1 y range, no x ticks or spines, y ticks only in column 0."""
    ax.set_ylabel('Accuracy' if col == 0 else '')
    ax.set_ylim(0, 1)
    ax.set_xlim(x_min, x_max)
    ax.set_xticks([])
    for spine in ax.spines.values():
        spine.set_visible(False)
    if col != 0:
        ax.yaxis.set_ticks([])


def plot_bar_graphs_new(baselines_df, results_df, save_dir, strategies, one_dataset=True, one_model=False, layer='all', r_type='mlp1'):
    """Plot familiar vs. novel probe accuracy per feature for raw input and each model class.

    Layout:
      * Columns: one per training dataset, taken from the last two characters of
        ``results_df['Model']`` (e.g. ``'s1'``).
      * Rows: one raw-input baseline row, then one row per model class, i.e.
        ``results_df['Model']`` minus its ``'-loc-<dataset>'`` suffix (the last
        seven characters).
      * A short bottom row naming the strategy groups.
    Within each panel, features are grouped by ``strategies`` on coloured
    backgrounds.

    Args:
        baselines_df: Baseline table. Rows are matched on ``Feature`` and
            ``Model == <dataset>``, and read from the ``(input-activations)`` columns.
        results_df: Model table. Rows are matched on ``Feature`` and
            ``Model == f'{model_class}-loc-{dataset}'``, and read from the
            ``({layer}-activations)`` columns.
        save_dir: Output directory, created if missing.
        strategies: Ordered mapping of strategy name to list of feature names.
        one_dataset, one_model: Unused; kept for call compatibility.
        layer: Activation tag for the model rows, e.g. ``'all'`` or ``'final-layer'``.
        r_type: Filename suffix. If it contains ``'spe'``, strategy groups are
            packed more tightly.

    Returns:
        None. Writes
        ``<save_dir>/<layer>_models_datasets_accuracy_comparison_<r_type>.png``.
        When ``results_df`` contains no models it only prints a message.
    """
    model_names = results_df['Model'].unique()
    if len(model_names) == 0:
        # A zero-column grid is invalid, so there is nothing sensible to draw.
        print('plot_bar_graphs_new: no models in results_df, skipping', layer, r_type)
        return

    # Name as a secondary key keeps row order deterministic when ranks tie.
    model_classes = sorted({model[:-7] for model in model_names}, key=lambda m: (get_sort_key(m), m))
    datasets = sorted({model[-2:] for model in model_names})

    model_name_map = {
        'smlp': 'MLP Model',
        'cnn': 'CNN Model',
        'clstm': 'CLSTM Model'
    }

    model_colors = {
        'Raw Input': '#808080',
        'smlp': '#FF4444',
        'cnn': '#9944FF',
        'clstm': '#4444FF'
    }

    strategy_colors = {
        'No-Mindreading': '#FFE5E5',
        'Low-Mindreading': '#FFFAE5',
        'High-Mindreading': '#E5F0FF'
    }

    dataset_name_map = {
        's1': 'Stage 1 Training',
        's2': 'Stage 2 Training',
        's3': 'Stage 3 Training'
    }

    n_rows = len(model_classes) + 1  # raw-input baseline row + one row per model class
    n_cols = len(datasets)

    fig = plt.figure(figsize=(8 * n_cols, 5.5 * n_rows + 2))
    try:
        gs = fig.add_gridspec(n_rows + 1, n_cols,
                              height_ratios=[4] * n_rows + [1],
                              hspace=0.45,
                              wspace=0.05)
        axs = [[fig.add_subplot(gs[i, j]) for j in range(n_cols)] for i in range(n_rows)]
        # The short bottom grid row holds the strategy names. It must be its own
        # axes: turning the axis off on axs[-1] would hide the last model row's axis.
        label_axs = [fig.add_subplot(gs[n_rows, j]) for j in range(n_cols)]

        bar_width = 0.15
        feature_gap = 0.03
        strategy_gap = 0.25 if 'spe' not in r_type else 0.15
        error_offset = bar_width * 0.25

        # Width of each strategy group, including the gap that follows it.
        strategy_widths = {
            strategy: len(features) * bar_width + (len(features) - 1) * feature_gap + strategy_gap
            for strategy, features in strategies.items()
        }

        # Strategy backgrounds: x in data units, spanning the full panel height.
        for col in range(n_cols):
            x_offset = -bar_width
            for strategy in strategies:
                width = strategy_widths[strategy]
                face = strategy_colors.get(strategy, '#F0F0F0')
                for row in range(n_rows):
                    ax = axs[row][col]
                    ax.add_patch(plt.Rectangle((x_offset, 0), width, 1,
                                               facecolor=face,
                                               transform=ax.get_xaxis_transform(),
                                               zorder=-1,
                                               clip_on=False))
                x_offset += width

        # Row 0: raw-input baselines (input-activation columns).
        for col, dataset in enumerate(datasets):
            ax = axs[0][col]
            x_end = _draw_feature_bars(
                ax,
                lambda feature, ds=dataset: baselines_df[(baselines_df['Feature'] == feature) &
                                                         (baselines_df['Model'] == ds)],
                strategies, 'input', model_colors['Raw Input'], 'black',
                bar_width, feature_gap, strategy_gap, error_offset)
            _style_accuracy_axes(ax, col, -bar_width, x_end - strategy_gap)
            if col == 0:
                ax.text(-0.25, 0.5, 'Raw Input', va='center', ha='right',
                        color=model_colors['Raw Input'], fontsize=24, fontweight='bold')
            ax.set_title(dataset_name_map.get(dataset, dataset), fontsize=16, fontweight='bold')

        # Rows 1..: one row per model class ({layer}-activation columns).
        for row, model_class in enumerate(model_classes, start=1):
            color = model_colors.get(model_class, '#444444')
            for col, dataset in enumerate(datasets):
                ax = axs[row][col]
                model_name = f"{model_class}-loc-{dataset}"
                x_end = _draw_feature_bars(
                    ax,
                    lambda feature, name=model_name: results_df[(results_df['Feature'] == feature) &
                                                                (results_df['Model'] == name)],
                    strategies, layer, color, color,
                    bar_width, feature_gap, strategy_gap, error_offset)
                _style_accuracy_axes(ax, col, -bar_width, x_end - strategy_gap)
                if col == 0:
                    ax.text(-0.25, 0.5, model_name_map.get(model_class, model_class), va='center', ha='right',
                            color=color, fontsize=24, fontweight='bold')

        # Strategy names in the dedicated bottom row.
        for col in range(n_cols):
            ax = label_axs[col]
            # Match the panels' x range so each name sits under its background group.
            ax.set_xlim(axs[0][col].get_xlim())
            ax.set_ylim(0, 1)
            x_offset = -bar_width
            for strategy in strategies:
                width = strategy_widths[strategy]
                ax.text(x_offset + width / 2, 1.0, strategy.replace('-', '\n'),
                        ha='center', va='top', fontsize=16, fontweight='bold')
                x_offset += width
            ax.axis('off')

        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
        fig.savefig(os.path.join(save_dir, f'{layer}_models_datasets_accuracy_comparison_{r_type}.png'),
                    bbox_inches='tight', dpi=300)
    finally:
        # Release the figure even if a table lookup or the save fails.
        plt.close(fig)

In [72]:
# Rebuild the strategy bar plots from the saved accuracy tables of experiment 2.
last_timestep = True
exp_name = 'exp_2'
if last_timestep:
    exp_name += "-L"

combined_path = os.path.join('supervised', exp_name, 'c')

strategies_short = {
    'No-Mindreading': ['opponents'],
    'High-Mindreading': ['b-loc']
}
strategies_long = {
    'No-Mindreading': ['pred', 'opponents', 'big-loc', 'small-loc'],
    'Low-Mindreading': ['vision', 'fb-exist'],
    'High-Mindreading': ['fb-loc', 'b-loc', 'target-loc', 'labels']
}
strategies_both = {
    'No-Mindreading': ['big-loc', 'big-box', 'small-loc', 'small-box'],
    'High-Mindreading': ['fb-loc', 'fb-box', 'b-loc', 'b-box', 'target-loc', 'target-box', ]
}

# Only the non-retrained tables for now; use [True, False] to include retrained models.
for retrain in [False]:
    # The baseline table is always the retrain_False one.
    baseline_tables = pd.read_csv(os.path.join(combined_path, 'base_all_table_retrain_False.csv'))
    result_tables = pd.read_csv(os.path.join(combined_path, f'all_table_retrain_{retrain}.csv'))

    print('made acc tables', len(baseline_tables), len(result_tables))
    print(baseline_tables.columns, result_tables.columns)

    this_path = os.path.join(combined_path, f'strats-rt-{retrain}')

    # NOTE(review): plot_bar_graphs_new2 is not defined in the visible source.
    for result_type in ['mlp1', 'linear']:
        for layer in ['all', 'final-layer']:
            print('doing')
            # Short (spec-) plot first, then the long plot, for each layer.
            for strats, tag in ((strategies_short, f'spec-{result_type}'), (strategies_long, result_type)):
                plot_bar_graphs_new2(
                    baseline_tables[(baseline_tables['Model_Type'] == result_type)],
                    result_tables[(result_tables['Model_Type'] == result_type)],
                    this_path, strats, layer=layer, r_type=tag)

made acc tables 54 288
Index(['Feature', 'Model', 'Model_Type',
       'Familiar accuracy (input-activations)',
       'Familiar between-model std (input-activations)',
       'Familiar within-model std (input-activations)',
       'Familiar q1 (input-activations)', 'Familiar q3 (input-activations)',
       'Novel accuracy (input-activations)',
       'Novel between-model std (input-activations)',
       'Novel within-model std (input-activations)',
       'Novel q1 (input-activations)', 'Novel q3 (input-activations)',
       'data_Type'],
      dtype='object') Index(['Feature', 'Model', 'Model_Type', 'Familiar accuracy (all-activations)',
       'Familiar between-model std (all-activations)',
       'Familiar within-model std (all-activations)',
       'Familiar q1 (all-activations)', 'Familiar q3 (all-activations)',
       'Novel accuracy (all-activations)',
       'Novel between-model std (all-activations)',
       'Novel within-model std (all-activations)',
       'Novel q1 (all-ac