In [None]:
import torch
from tqdm import tqdm
from argparse import Namespace


# Evaluation settings. `config` is handed to the test dataloader factory, and
# `model_path` is the checkpoint that gets loaded for evaluation.
config = Namespace(
    data_folder='./wm_bench_data',
    max_seq_len=20,
    rs_img_size=32,
    batch_size=10,
    num_workers=4,
    use_cnn=1,
    model_path='./model.pt',
)

# GPU ordinal this notebook prefers to run on.
PREFERRED_GPU_INDEX = 3

if torch.cuda.is_available():
    # Fall back to GPU 0 when the host exposes fewer GPUs than the preferred
    # ordinal; otherwise moving anything to `device` fails with an invalid
    # device ordinal.
    gpu_index = PREFERRED_GPU_INDEX if torch.cuda.device_count() > PREFERRED_GPU_INDEX else 0
    device = torch.device(f'cuda:{gpu_index}')
else:
    device = torch.device('cpu')

In [None]:
from src.utils.data_utils import get_test_multitask_dataloader

# Build the test loaders from the evaluation config. The result is consumed
# later through `.values()`, one entry per zipped batch stream — presumably a
# dict with one dataloader per task; TODO confirm against data_utils.
test_loader = get_test_multitask_dataloader(config)

In [None]:
from src.model import WM_Model

# Load the checkpoint directly onto the evaluation device. Without
# `map_location`, a checkpoint saved from a GPU cannot be deserialized on a
# host where that GPU is missing (e.g. when `device` fell back to CPU).
# NOTE: torch.load unpickles the file — only load checkpoints from a trusted source.
model_data = torch.load(config.model_path, map_location=device)

# The checkpoint carries the model's own config dict alongside its weights;
# the network is rebuilt from that config rather than from the notebook's.
model = WM_Model(Namespace(**model_data['config']), device).to(device)
model.load_state_dict(model_data['model_state_dict'])
# Put the module in evaluation mode before running inference.
model.eval()

In [None]:
# Models to evaluate, by name. Each name becomes the prefix of that model's
# accuracy keys in the evaluation loop.
model_dict = {'model_1': model}

In [None]:
import numpy as np

# Per-item hit (1) / miss (0) lists, keyed by
# '<model>_List_Length_<n>_SP_<serial position>'.
epoch_acc = {}

# One batch per loader at every step; iteration stops with the shortest loader.
dataloader = zip(*test_loader.values())

with torch.no_grad():
    for batch_index, multitask_batch in tqdm(enumerate(dataloader)):
        # Entry 6 is taken as the batch for the task evaluated here — assumes
        # the loader order puts the VSR task at index 6; TODO confirm.
        stim_batch, resp_batch, seq_len, list_length = multitask_batch[6]
        stim_batch, resp_batch = stim_batch.to(device), resp_batch.to(device)

        for model_name, model in model_dict.items():
            out, _, _, _, _ = model(stim_batch, 'VSR_Task', seq_len)
            pred = out.argmax(dim=-1)

            for index, _ in enumerate(seq_len):
                # Only steps n .. 2n-1 are scored, where n is this sample's
                # list length — presumably the response phase; TODO confirm.
                n_items = list_length[index]
                scored_steps = slice(n_items, n_items * 2)
                curr_gt = resp_batch[index, scored_steps]
                curr_pred = pred[index, scored_steps]

                for count, item in enumerate(curr_gt):
                    # Serial positions are reported 1-based.
                    acc_key = f"{model_name}_List_Length_{n_items.item()}_SP_{count + 1}"
                    hits = epoch_acc.setdefault(acc_key, [])
                    hits.append(1 if item == curr_pred[count] else 0)

In [None]:
import json

# Persist the raw per-item hit/miss lists so the statistics and plots can be
# rebuilt without re-running the evaluation.
with open('vsr_accs.json', 'w') as acc_file:
    json.dump(epoch_acc, acc_file)

In [None]:
import numpy as np
from scipy import stats

# For every accuracy key: [mean accuracy, standard error of the mean].
epoch_acc_std_err = {
    acc_key: [np.mean(hits), stats.sem(hits)]
    for acc_key, hits in epoch_acc.items()
}

In [None]:
def _split_acc_key(acc_key):
    """Split '<model>_List_Length_<n>_SP_<pos>' into (model, n, pos) strings.

    The key is split on its literal markers from the right, so the model name
    may contain any number of underscores.

    Raises:
        ValueError: if the key does not contain both markers.
    """
    prefix, sp_marker, serial_position = acc_key.rpartition('_SP_')
    model_name, len_marker, list_len = prefix.rpartition('_List_Length_')
    if not sp_marker or not len_marker:
        raise ValueError('Unexpected accuracy key format: ' + repr(acc_key))
    return model_name, list_len, serial_position


# plot_data[model][list length] = [serial positions, mean accuracies, std errors]
plot_data = {}

# Distinct loop names are used on purpose: reusing `model` here would replace
# the loaded network with a string for any cell that is re-run afterwards.
for acc_key, (mean_acc, std_err) in epoch_acc_std_err.items():
    model_name, list_len, serial_position = _split_acc_key(acc_key)

    series = plot_data.setdefault(model_name, {}).setdefault(list_len, [[], [], []])
    series[0].append(serial_position)
    series[1].append(mean_acc)
    series[2].append(std_err)

In [None]:
# Human reference values per list length, one raw value per serial position.
# Every raw value is divided by 8 below — presumably the maximum attainable
# score; TODO confirm against the data source.
_human_raw_scores = {
    '3': [7.16, 6.8, 6.9],
    '4': [6.35, 5.25, 4.8, 6.05],
    '5': [5.5, 4.3, 4.2, 3.6, 5],
    '6': [5.1, 4.25, 3.55, 3.5, 2.9, 3.95],
}

# Layout per list length: [serial positions (1-based, as strings),
# accuracies, standard errors]. The standard-error list is left empty.
human_plot_data = {
    list_len: [
        [str(position) for position in range(1, len(raw_scores) + 1)],
        [raw_score / 8 for raw_score in raw_scores],
        [],
    ]
    for list_len, raw_scores in _human_raw_scores.items()
}

In [None]:
import seaborn as sns
from matplotlib import cm
import matplotlib.pyplot as plt

plt.rcParams["font.family"] = "Serif"
sns.set_context("paper", font_scale=1.5, rc={"lines.linewidth": 1})

plt.figure(figsize=(10, 5))
# Create two subplots: ax1 is a small panel in the bottom row of the right
# half (human data), ax2 fills the left half (model data).
ax1 = plt.subplot(2, 4, 7)
ax2 = plt.subplot(1, 2, 1)

# Choose which evaluated model to draw. 'gru_1024' is preferred; when no model
# was evaluated under that name, fall back to the first available one so the
# model panel is not silently left empty.
preferred_model = 'gru_1024'
model_to_plot = preferred_model if preferred_model in plot_data else next(iter(plot_data), None)

# NOTE(review): list length '8' is not in this selection — confirm intentional.
plotted_list_lengths = ['3', '4', '5', '6', '7', '9']

if model_to_plot is not None:
    for list_len, (positions, mean_accs, std_errs) in plot_data[model_to_plot].items():
        if list_len in plotted_list_lengths:
            ax2.errorbar(positions,
                         mean_accs,
                         yerr=std_errs,
                         label='List Length ' + list_len, fmt='o-',
                         linewidth=2, markersize=3, capsize=4)


ax2.set_ylabel('Top-1 Accuracy', fontsize=25)

# Serial positions are plotted as strings (categorical axis), so tick
# locations are 0-based while the labels are the 1-based positions.
ax2.set_xticks([0, 1, 2, 3, 4, 5, 6, 7, 8])
ax2.set_xticklabels([1, 2, 3, 4, 5, 6, 7, 8, 9])

ax2.set_ylim([0.5, 1.00])
ax2.set_yticks([0.5, 0.6, 0.7, 0.8, 0.9, 1.0])


for human_list_len, (positions, accuracies, _unused_errs) in human_plot_data.items():
    ax1.errorbar(positions,
                 accuracies,
                 yerr=0,
                 label='List Length ' + human_list_len, fmt='o-',
                 linewidth=2, markersize=7, color='C' + str(int(human_list_len) - 2))

sns.despine(left=False, bottom=False, right=True, top=True)


ax1.set_xticks([0, 1, 2, 3, 4, 5])
ax1.set_xticklabels([1, 2, 3, 4, 5, 6], fontsize=18)

ax1.set_yticks([0.4, 0.6, 0.8, 1.0])
ax1.set_yticklabels([0.4, 0.6, 0.8, 1.0], fontsize=18)

ax1.set_ylim([0.3, 1.01])


# The remaining calls style the model panel. ax2 is already the current axes
# (it was created last); selecting it explicitly keeps that from being implicit.
plt.sca(ax2)
plt.xticks(fontsize=20)
plt.yticks(fontsize=20)

ax2.legend(frameon=False, loc='upper center',
           bbox_to_anchor=(1.33, 1.05), ncol=1, prop={'size': 16})

# Placed in ax2 data coordinates, below its y-range (which starts at 0.5).
ax2.text(5, 0.405, 'Serial Position', fontsize=25)

plt.show()