In [None]:
import torch
from tqdm import tqdm
from argparse import Namespace


# Evaluation settings. The whole namespace is handed to the project's
# dataloader factory, and `model_path` is the checkpoint loaded further down.
config = Namespace(
    data_folder='./wm_bench_data',
    max_seq_len=20,
    rs_img_size=32,
    batch_size=10,
    num_workers=4,
    use_cnn=1,
    model_path='./model.pt',
)

# Preferred GPU index. Hosts that have CUDA but fewer than four GPUs would
# otherwise fail later with an invalid device ordinal, so fall back to GPU 0
# there, and to the CPU when CUDA is not available at all.
PREFERRED_GPU = 3

if torch.cuda.is_available():
    gpu_index = PREFERRED_GPU if torch.cuda.device_count() > PREFERRED_GPU else 0
    device = torch.device(f'cuda:{gpu_index}')
else:
    device = torch.device('cpu')

In [None]:
from src.utils.data_utils import get_test_multitask_dataloader

# Build the test loaders from the shared config. The result is used below as a
# mapping whose .values() are zipped together -- presumably one DataLoader per
# task; confirm against get_test_multitask_dataloader.
test_loader = get_test_multitask_dataloader(config)

In [None]:
from src.model import WM_Model

# NOTE: torch.load unpickles the file, so only point `model_path` at trusted
# checkpoints.
# map_location remaps the stored tensors onto `device`; without it a checkpoint
# saved on a GPU cannot be loaded on the CPU fallback (or on a machine where
# that GPU index does not exist).
model_data = torch.load(config.model_path, map_location=device)

# The checkpoint carries its own model config (as a plain dict) next to the
# weights; rebuild the network from it before restoring the parameters.
model = WM_Model(Namespace(**model_data['config']), device).to(device)
model.load_state_dict(model_data['model_state_dict'])

# Switch to inference mode. As the last expression of the cell this also
# displays the module.
model.eval()

In [None]:
# Models to evaluate, keyed by the name that prefixes every result key.
# NOTE(review): the plotting cell only draws curves for a model named
# 'trf_96_1' -- confirm that the key used here is the intended one.
model_dict = dict(model_1=model)

In [None]:
import numpy as np

# Maps '<model>_RI_<ri>_LL_<ll>_SP_<gt_index>' to a list of 0/1 outcomes,
# one entry per evaluated sequence.
epoch_acc = {}

# Step through all loaders in lockstep: each iteration yields a tuple holding
# one batch per loader, in the order of test_loader's values.
dataloader = zip(*test_loader.values())

with torch.no_grad():
    for batch_index, multitask_batch in tqdm(enumerate(dataloader)):
        # Only the batch at position 5 is evaluated -- presumably the
        # 'VIRec_2C_Task' loader, given the task name passed to the model;
        # confirm against the loader ordering.
        stim_batch, resp_batch, seq_len, ri, gt_index = multitask_batch[5]

        stim_batch = stim_batch.to(device)
        resp_batch = resp_batch.to(device)

        for model_name, model in model_dict.items():
            out, _, _, _, _ = model(stim_batch, 'VIRec_2C_Task', seq_len)
            # Threshold the sigmoid output at 0.5 to get binary predictions.
            pred = torch.round(torch.sigmoid(out))

            for sample_idx, sample_len in enumerate(seq_len):
                ri_value = ri[sample_idx].item()
                ll = sample_len.item() - ri_value - 1
                sp_value = gt_index[sample_idx].item()
                cond_key = f'{model_name}_RI_{ri_value}_LL_{ll}_SP_{sp_value}'

                # Create the bucket on first sight of this condition.
                outcomes = epoch_acc.setdefault(cond_key, [])

                # Score only the final time step of each sequence.
                last_step = sample_len - 1
                if pred[sample_idx, last_step] == resp_batch[sample_idx, last_step]:
                    outcomes.append(1)
                else:
                    outcomes.append(0)

In [None]:
import json

# Save the raw per-condition 0/1 outcomes so the statistics and plots below
# can be rebuilt without re-running the model.
with open('vir_accs.json', 'w') as out_file:
    out_file.write(json.dumps(epoch_acc))

In [None]:
import numpy as np
from scipy import stats

# Summarise every condition as [mean accuracy, standard error of the mean].
epoch_acc_std_err = {
    cond_key: [np.mean(outcomes), stats.sem(outcomes)]
    for cond_key, outcomes in epoch_acc.items()
}

In [None]:
# Regroup the flat '<model>_RI_<ri>_LL_<ll>_SP_<sp>' results into
#   plot_data[model][ll][ri] = [serial positions, mean accuracies, std errors]
plot_data = {}

for key, (mean_acc, std_err) in epoch_acc_std_err.items():
    # Split from the right: the six trailing tokens are always
    # RI, <ri>, LL, <ll>, SP, <sp>, so whatever is left in front is the model
    # name, however many underscores it contains ('model_1', 'trf_96_1', ...).
    parts = key.rsplit('_', 6)
    model_name, ri, ll = parts[0], parts[2], parts[4]
    # The stored index is shifted by one so positions start at 1 on the plot.
    sp = int(parts[6]) + 1

    series = (
        plot_data
        .setdefault(model_name, {})
        .setdefault(ll, {})
        .setdefault(ri, [[], [], []])
    )
    series[0].append(sp)
    series[1].append(mean_acc)
    series[2].append(std_err)

In [None]:
# Same nesting as plot_data, but every [positions, accuracies, errors] triple
# is reordered by ascending serial position and stored as tuples.
sorted_plot_data = {}

for model_key, by_ll in plot_data.items():
    for ll_key, by_ri in by_ll.items():
        for ri_key, series in by_ri.items():
            # Zip the three parallel lists into rows, sort the rows, and
            # unzip them back into three aligned tuples.
            sp_sorted, acc_sorted, err_sorted = zip(*sorted(zip(*series)))

            per_ll = sorted_plot_data.setdefault(model_key, {}).setdefault(ll_key, {})
            per_ll[ri_key] = [sp_sorted, acc_sorted, err_sorted]

In [None]:
import seaborn as sns
from matplotlib import cm
import matplotlib.pyplot as plt

plt.rcParams["font.family"] = "Serif"
sns.set_context("paper", font_scale=1.5, rc={"lines.linewidth": 1})

fig, ax = plt.subplots(figsize=(5, 5))

# Which results to draw. The model name must match a key of model_dict.
# NOTE(review): the legend calls this model 'TRF-256' -- confirm it matches.
PLOT_MODEL = 'trf_96_1'
PLOT_LL = '4'
# Fixed RI -> colour mapping so each model curve shares its colour with the
# human baseline for the same RI, whatever order the results were collected in.
RI_COLORS = {'0': 'C0', '5': 'C1'}

model_curves = sorted_plot_data.get(PLOT_MODEL, {}).get(PLOT_LL, {})
if not model_curves:
    # Make the omission visible instead of silently drawing only human data.
    print(f"No results for model {PLOT_MODEL!r} with LL = {PLOT_LL}; "
          f"available models: {sorted(sorted_plot_data)}")

for ri, color in RI_COLORS.items():
    if ri not in model_curves:
        continue
    sp, acc, err = model_curves[ri]
    ax.errorbar(sp, acc, yerr=err, label='TRF-256 RI = '+ri, color=color,
                linewidth=2, marker='s', markersize=3, capsize=4)

# Human reference accuracies per serial position (hard-coded values).
ax.plot([1, 2, 3, 4], [0.655, 0.66, 0.645, 0.81], linestyle='--', color='C0',
        linewidth=2, marker='s', markersize=7, label='Human RI = 0')
ax.plot([1, 2, 3, 4], [0.655, 0.705, 0.68, 0.71], linestyle='--', color='C1',
        linewidth=2, marker='s', markersize=7, label='Human RI = 5')

sns.despine(left=False, bottom=False, right=True, top=True)

ax.set_xlabel('Serial Position', fontsize=25)
ax.set_ylabel('Top-1 Accuracy', fontsize=25)

ax.set_xticks([1, 2, 3, 4])
ax.set_xticklabels([1, 2, 3, 4])

ax.set_ylim([0.5, 1.01])

plt.xticks(fontsize=20)
plt.yticks(fontsize=20)

# Anchor the legend outside the axes, to the right of the plot.
plt.legend(frameon=False, loc='upper center', ncol=1, bbox_to_anchor=(1.39, 1.0),
           prop={'size': 16})

plt.show()