In [None]:
import torch
from tqdm import tqdm
from argparse import Namespace


# Evaluation settings. The whole namespace is handed to the test dataloader
# factory, and `model_path` is the checkpoint read by the model-loading cell.
config = Namespace()
config.data_folder = './wm_bench_data'
config.max_seq_len = 20
config.rs_img_size = 32
config.batch_size = 10
config.num_workers = 4
config.use_cnn = 1
config.model_path = './trained_models/LSTM/256/model.pt'

# Preferred GPU ordinal for this evaluation run.
_PREFERRED_GPU_INDEX = 3

if torch.cuda.is_available():
    # Only use the preferred ordinal when the host actually exposes that many
    # GPUs; otherwise fall back to the first one instead of failing later with
    # an "invalid device ordinal" error when tensors are moved.
    _gpu_index = _PREFERRED_GPU_INDEX if torch.cuda.device_count() > _PREFERRED_GPU_INDEX else 0
    device = torch.device('cuda', _gpu_index)
else:
    device = torch.device('cpu')

In [None]:
from src.utils.data_utils import get_test_multitask_dataloader

# Build the test-time loaders from `config`.
# The evaluation cell consumes the result through `test_loader.values()` and
# picks entries by position, so this is presumably an ordered mapping with one
# loader per task — TODO confirm against src.utils.data_utils.
test_loader = get_test_multitask_dataloader(config)

In [None]:
from src.model import WM_Model

# Read the checkpoint from disk. `map_location=device` remaps the stored
# tensors onto the device chosen above, so a checkpoint saved on a GPU that
# this host does not have (or on any GPU, when running on CPU) still loads.
model_data = torch.load(config.model_path, map_location=device)

# The checkpoint carries the training-time configuration as a plain dict under
# 'config'; it is rewrapped in a Namespace before being passed to WM_Model.
model = WM_Model(Namespace(**model_data['config']), device).to(device)
model.load_state_dict(model_data['model_state_dict'])

# Switch the module to evaluation mode before running the test set.
model.eval()

In [None]:
# Registry of the networks to evaluate. Each key is used as the prefix of every
# result key recorded for that network in the evaluation cell.
model_dict = dict(model_1=model)

In [None]:
import numpy as np

# Per-condition correctness records:
#   "<model name>_<feature>_Set_Size_<set size>_RI_<ri>" -> list of 0/1 outcomes.
epoch_acc = {}

# (position within `test_loader.values()`, task name passed to the model,
#  feature tag used in the result keys).
# NOTE(review): the positions depend on the iteration order of `test_loader`;
# confirm them against the dataloader factory if tasks are added or reordered.
CD_TASKS = (
    (8, 'CD_Color_Task', 'Color'),
    (9, 'CD_Orientation_Task', 'Orientation'),
    (10, 'CD_Size_Task', 'Size'),
    (11, 'CD_Gap_Task', 'Gap'),
    (12, 'CD_Conj_Task', 'Conj'),
)


def record_task_accuracy(task_batch, task_name, feature, models, results):
    """Score one task batch with every model and append the outcomes to `results`.

    Args:
        task_batch: Iterable whose first five items are
            (stimuli, responses, seq_len, ri, set_size). Any further items
            (the conjunction task supplies a sixth) are ignored.
        task_name: Task identifier forwarded to each model call.
        feature: Tag embedded in the result keys (e.g. 'Color').
        models: Mapping from model name to a callable invoked as
            `model(stimuli, task_name, seq_len)` and returning five values,
            of which only the first is used.
        results: Dict updated in place; every sample appends 1 (match) or
            0 (mismatch) under
            "<model name>_<feature>_Set_Size_<set size>_RI_<ri>".

    For each sample only the time step at index `seq_len[i] - 1` is compared
    against the response. Stimuli and responses are moved to the notebook-level
    `device` before the models are called.
    """
    stim_batch, resp_batch, seq_len, ri, set_size, *_unused = task_batch

    stim_batch = stim_batch.to(device)
    resp_batch = resp_batch.to(device)

    for model_name, net in models.items():
        out, _, _, _, _ = net(stim_batch, task_name, seq_len)
        # Sigmoid followed by rounding turns the first output into 0/1 predictions.
        pred = torch.round(torch.sigmoid(out))

        for index, length in enumerate(seq_len):
            key = (model_name + '_' + feature
                   + '_Set_Size_' + str(set_size[index].item())
                   + '_RI_' + str(ri[index].item()))
            is_correct = (pred[index, length-1] == resp_batch[index, length-1]).item()
            results.setdefault(key, []).append(1 if is_correct else 0)


# Step through all task loaders in lockstep; zip stops at the shortest one.
dataloader = zip(*test_loader.values())

with torch.no_grad():
    for multitask_batch in tqdm(dataloader):
        for position, task_name, feature in CD_TASKS:
            record_task_accuracy(multitask_batch[position], task_name, feature,
                                 model_dict, epoch_acc)

In [None]:
import json

# Persist the raw per-sample 0/1 records so the statistics and plots below can
# be rebuilt without re-running the evaluation.
with open('cd_accs.json', 'w') as out_file:
    out_file.write(json.dumps(epoch_acc))

In [None]:
import numpy as np
# For every condition key: [mean accuracy, standard error of the mean].
# The standard error uses np.std with its default ddof (population standard
# deviation) divided by sqrt(number of samples).
epoch_acc_std_err = {
    condition: [np.mean(outcomes), np.std(outcomes)/np.sqrt(len(outcomes))]
    for condition, outcomes in epoch_acc.items()
}

In [None]:
# Nested layout consumed by the plotting cell:
#   plot_data[model name][feature][ri] = [set sizes, mean accuracies, standard errors]
plot_data = {}

for key, (mean_acc, std_err) in epoch_acc_std_err.items():
    # Keys look like "<model name>_<feature>_Set_Size_<set size>_RI_<ri>".
    # Splitting from the right keeps the model name intact however many
    # underscores it contains. The loop variables are deliberately not called
    # `model`, so the loaded network in the notebook namespace is not
    # overwritten with a string.
    net_name, feature, _, _, set_size_str, _, ri_str = key.rsplit('_', 6)

    series = (plot_data
              .setdefault(net_name, {})
              .setdefault(feature, {})
              .setdefault(ri_str, [[], [], []]))

    series[0].append(int(set_size_str))
    series[1].append(mean_acc)
    series[2].append(std_err)

# Conditions arrive in the order they were first seen during evaluation; sort
# every series by set size so line plots connect the points left to right.
for per_feature in plot_data.values():
    for per_ri in per_feature.values():
        for series in per_ri.values():
            order = sorted(range(len(series[0])), key=series[0].__getitem__)
            series[:] = [[column[i] for i in order] for column in series]

In [None]:
# Human reference curves drawn next to the model results:
#   feature -> [set sizes (x values), accuracies (y values)].
# NOTE(review): the source of these numbers is not recorded in this notebook.
_HUMAN_SET_SIZES = (2, 4, 6)
_HUMAN_ACCURACIES = (
    ('Color', (0.985, 0.945, 0.84)),
    ('Orientation', (0.955, 0.90, 0.81)),
    ('Size', (0.975, 0.93, 0.82)),
    ('Gap', (0.945, 0.89, 0.80)),
    ('Conjunction', (0.965, 0.92, 0.83)),
)

# Each feature gets its own list objects, exactly as a literal dict would.
human_plot_data = {
    feature: [list(_HUMAN_SET_SIZES), list(accuracies)]
    for feature, accuracies in _HUMAN_ACCURACIES
}

In [None]:
import seaborn as sns
from matplotlib import cm
import matplotlib.pyplot as plt

plt.rcParams["font.family"] = "Serif"
sns.set_context("paper", font_scale=1.5, rc={"lines.linewidth": 1})

# Which evaluated model and RI value to draw, and the legend prefix for it.
# NOTE(review): the `model_dict` cell of this notebook registers the network as
# 'model_1', so this filter matches nothing unless the names are aligned.
MODEL_TO_PLOT = 'rnn_96'
MODEL_LABEL = 'RNN-96'
RI_TO_PLOT = '18'

# The result keys abbreviate the conjunction task as 'Conj'; spell it out in
# the legend. Every other feature name is shown as-is.
FEATURE_LABELS = {'Conj': 'Conjunction'}

fig, ax = plt.subplots(figsize=(5, 5))

if MODEL_TO_PLOT not in plot_data:
    # Make the otherwise silent "nothing to draw" case visible.
    print('No results for model ' + repr(MODEL_TO_PLOT)
          + '; available models: ' + repr(sorted(plot_data)))

# Model curves: solid lines with square markers.
for feature, per_ri in plot_data.get(MODEL_TO_PLOT, {}).items():
    if RI_TO_PLOT not in per_ri:
        continue
    set_sizes, mean_accs, _std_errs = per_ri[RI_TO_PLOT]
    # yerr=0 matches the existing figure: the computed standard errors are
    # carried in plot_data but not drawn.
    ax.errorbar(set_sizes,
                mean_accs,
                yerr=0,
                label=MODEL_LABEL+' '+FEATURE_LABELS.get(feature, feature),
                linewidth=1.5, markersize=6, fmt='-s')

# Human reference curves: dashed lines with round markers.
for feature, (set_sizes, accuracies) in human_plot_data.items():
    ax.errorbar(set_sizes,
                accuracies,
                yerr=0,
                label='Human'+' '+FEATURE_LABELS.get(feature, feature),
                linewidth=1.5, markersize=6, fmt='--o')

sns.despine(left=False, bottom=False, right=True, top=True)

ax.set_xlabel('Set Size', fontsize=20)
ax.set_ylabel('Top-1 Accuracy', fontsize=20)

ax.set_yticks([0.7, 0.8, 0.9, 1.0])

ax.set_ylim([0.7, 1.01])

plt.xticks(fontsize=16)
plt.yticks(fontsize=16)

# The legend is anchored outside the right edge of the axes.
plt.legend(frameon=False, loc='upper center', bbox_to_anchor=(1.15, 0.87), ncol=1, prop={'size': 14})

plt.savefig('cd_accs_fin.png', bbox_inches='tight', dpi=300)
plt.show()