In [None]:
import torch
from tqdm import tqdm
from argparse import Namespace


# Evaluation settings; `config` is also handed to the test DataLoader factory below.
config = Namespace(
    data_folder='./wm_bench_data',
    max_seq_len=20,
    rs_img_size=32,
    batch_size=10,
    num_workers=4,
    use_cnn=1,
    model_path='./model.pt'
)

# Preferred GPU index. A hard-coded 'cuda:3' fails with "invalid device ordinal" as soon
# as tensors are moved to it on a machine with fewer than four GPUs, so fall back to
# cuda:0 there, and to the CPU when CUDA is unavailable.
PREFERRED_CUDA_INDEX = 3

if torch.cuda.is_available():
    cuda_index = PREFERRED_CUDA_INDEX if PREFERRED_CUDA_INDEX < torch.cuda.device_count() else 0
    device = torch.device(f'cuda:{cuda_index}')
else:
    device = torch.device('cpu')

In [None]:
# Build the test-set DataLoaders from `config`. The result is dict-like (the evaluation
# loop iterates `test_loader.values()`); presumably one loader per task, in a fixed
# order — confirm against src/utils/data_utils.py.
from src.utils.data_utils import get_test_multitask_dataloader

test_loader = get_test_multitask_dataloader(config)

In [None]:
from src.model import WM_Model

# Restore the trained model. The checkpoint stores the training hyper-parameters under
# 'config' and the weights under 'model_state_dict'.
# map_location remaps tensors onto the selected `device`: without it, a checkpoint saved
# on a GPU cannot be loaded when `device` fell back to the CPU (or to another GPU index).
model_data = torch.load(config.model_path, map_location=device)
model = WM_Model(Namespace(**model_data['config']), device).to(device)
model.load_state_dict(model_data['model_state_dict'])
model.eval()

In [None]:
# Models to evaluate. Each key becomes the prefix of the result keys in `epoch_acc`.
model_dict = dict(model_1=model)

In [None]:
import numpy as np

# Index of the VSRec loader among test_loader.values() (dict insertion order).
VSREC_LOADER_INDEX = 7

# epoch_acc[key] collects one 0/1 score per (trial, serial position). Key format:
#   '<model_name>_LL_<list length>_DD_<distractor diff>_SP_<1-based serial position>'
epoch_acc = {}

# Iterate only the VSRec loader. Zipping all task loaders together stops at the shortest
# one (which can silently drop VSRec batches) and loads batches for unused tasks.
vsrec_loader = list(test_loader.values())[VSREC_LOADER_INDEX]

with torch.no_grad():
    for stim_batch, resp_batch, seq_len, list_length, distractor_diff in tqdm(vsrec_loader):
        stim_batch = stim_batch.to(device)
        resp_batch = resp_batch.to(device)

        for model_name, eval_model in model_dict.items():
            out, _, _, _, _ = eval_model(stim_batch, 'VSRec_Task', seq_len)
            # Binarise the per-timestep outputs at a 0.5 sigmoid threshold.
            pred = torch.round(torch.sigmoid(out))

            for index in range(len(seq_len)):
                ll = list_length[index].item()
                dd = distractor_diff[index].item()

                # Only timesteps [ll, 2*ll) are scored; presumably the response window
                # that follows ll study items — TODO confirm against the task definition.
                curr_gt = resp_batch[index, ll:ll * 2]
                curr_pred = pred[index, ll:ll * 2]

                key_prefix = f'{model_name}_LL_{ll}_DD_{dd}_SP_'
                for count, item in enumerate(curr_gt):
                    correct = 1 if item == curr_pred[count] else 0
                    epoch_acc.setdefault(key_prefix + str(count + 1), []).append(correct)

In [None]:
import json

# Save the raw per-trial 0/1 scores (keyed as in epoch_acc) so they can be reloaded
# and re-summarised later.
VSREC_ACCS_PATH = 'vsrec_accs.json'
with open(VSREC_ACCS_PATH, mode='w') as accs_file:
    json.dump(epoch_acc, accs_file)

In [None]:
import numpy as np
from scipy import stats

# Summarise each result key as [mean accuracy, standard error of the mean] over all
# scored trials (stats.sem yields nan for a key with a single trial).
epoch_acc_std_err = {
    key: [np.mean(scores), stats.sem(scores)]
    for key, scores in epoch_acc.items()
}

In [None]:
# Regroup the summary statistics for plotting:
#   plot_data[model_name][list_length][dd] = [serial positions, means, SEMs]
# Every key and every serial position is kept as the string parsed from the result key.
plot_data = {}

for key, (mean_acc, sem_acc) in epoch_acc_std_err.items():
    # Keys look like '<model_name>_LL_<ll>_DD_<dd>_SP_<sp>'. Split from the right so
    # model names with any number of underscores (e.g. 'model_1', 'trf_1024_2') parse.
    parts = key.rsplit('_', 6)
    if len(parts) != 7 or parts[1:6:2] != ['LL', 'DD', 'SP']:
        raise ValueError(f'Unexpected accuracy key format: {key!r}')
    model_name, _, list_length, _, dd, _, sp = parts

    series = (plot_data
              .setdefault(model_name, {})
              .setdefault(list_length, {})
              .setdefault(dd, [[], [], []]))
    series[0].append(sp)
    series[1].append(mean_acc)
    series[2].append(sem_acc)

In [None]:
# Human reference curves for the inset panel, keyed by distractor difference.
# Raw values are scores out of 20 at serial positions 1-5; presumably mean correct
# responses from a behavioural study — TODO cite the source.
HUMAN_SERIAL_POSITIONS = [1, 2, 3, 4, 5]
HUMAN_MAX_SCORE = 20
human_raw_scores = {
    '2': [13.65, 14.1, 13.5, 13.65, 13.7],
    '4': [14.85, 14.8, 15.4, 15.1, 14],
    '6': [15.55, 16, 15.8, 15.6, 15],
}

# human_plot_data[dd] = [serial positions, accuracies in [0, 1]]
human_plot_data = {
    dd: [list(HUMAN_SERIAL_POSITIONS), [score / HUMAN_MAX_SCORE for score in scores]]
    for dd, scores in human_raw_scores.items()
}

In [None]:
import seaborn as sns
from matplotlib import cm
import matplotlib.pyplot as plt

# What to draw on the main (model) panel. plot_data keys are strings.
PLOT_MODEL = 'trf_1024_2'
PLOT_LIST_LENGTH = '5'
PLOT_DISTRACTOR_DIFFS = ['2', '4', '6', '8', '10']

plt.rcParams["font.family"] = "Serif"
sns.set_context("paper", font_scale=1.5, rc={"lines.linewidth": 1})

# Create both panels on a fresh figure. Calling plt.subplot() on top of a plt.subplots()
# figure leaves an extra, empty full-size Axes behind on Matplotlib >= 3.6, which no
# longer deletes overlapping axes.
fig = plt.figure(figsize=(10, 5))
ax1 = fig.add_subplot(2, 4, 7)  # inset: human reference data
ax2 = fig.add_subplot(1, 2, 1)  # main panel: model accuracy by serial position

# If PLOT_MODEL was not evaluated (model_dict may use another name), plot the first
# evaluated model instead of silently drawing an empty panel.
plot_model = PLOT_MODEL
if plot_model not in plot_data and plot_data:
    plot_model = next(iter(plot_data))
    print(f"'{PLOT_MODEL}' not in plot_data; plotting '{plot_model}' instead.")

series_by_dd = plot_data.get(plot_model, {}).get(PLOT_LIST_LENGTH, {})
for dd in PLOT_DISTRACTOR_DIFFS:
    if dd not in series_by_dd:
        continue
    positions, mean_accs, sem_accs = series_by_dd[dd]
    ax2.errorbar(positions, mean_accs, yerr=sem_accs,
                 label='Pattern-Distrac Diff. ' + dd, fmt='o-',
                 linewidth=2, markersize=3, capsize=4)

sns.despine(fig=fig, left=False, bottom=False, right=True, top=True)

ax2.set_ylabel('Top-1 Accuracy', fontsize=25)

# Serial positions in plot_data are strings, so Matplotlib places them as categories at
# x = 0, 1, 2, ...; relabel those slots as positions 1-5.
ax2.set_xticks([0, 1, 2, 3, 4])
ax2.set_xticklabels([1, 2, 3, 4, 5], fontsize=20)

ax2.set_yticks([0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0])
ax2.set_ylim([0.3, 1.01])

for dd, (positions, accs) in human_plot_data.items():
    ax1.errorbar(positions, accs, yerr=0,
                 label='Pattern-Distrac Diff. ' + dd, fmt='o-',
                 linewidth=2, markersize=7)

ax1.set_ylim([0.5, 1.01])
# The unlabelled tick at 0.5 extends the inset's x-range a little to the left.
ax1.set_xticks([0.5, 1, 2, 3, 4, 5])
ax1.set_xticklabels(['', 1, 2, 3, 4, 5], fontsize=20)

ax1.set_yticks([0.6, 0.8, 1.0])
ax1.set_yticklabels([0.6, 0.8, 1.0], fontsize=18)

ax2.tick_params(axis='both', labelsize=20)

# Legend for the main panel, anchored to the right of it.
ax2.legend(frameon=False, loc='upper center', bbox_to_anchor=(1.49, 1.0), ncol=1, prop={'size': 16})

# X-axis caption, placed in ax2 data coordinates below its lower y-limit (0.3).
ax2.text(2.3, 0.17, 'Serial Position', fontsize=25)

plt.show()