In [None]:
import torch
from tqdm import tqdm
from argparse import Namespace


# Evaluation settings handed to the project's data-loading helper; `model_path`
# is the checkpoint read by the model-loading cell.
config = Namespace(
    data_folder='./wm_bench_data',
    max_seq_len=20,
    rs_img_size=32,
    batch_size=10,
    num_workers=4,
    use_cnn=1,
    model_path='./model.pt',
)

# GPU ordinal this notebook prefers. `torch.cuda.is_available()` alone does not
# guarantee that this ordinal exists, so check the device count and fall back to
# the first GPU (or the CPU) instead of failing later when tensors are moved.
PREFERRED_GPU = 3

if not torch.cuda.is_available():
    device = torch.device('cpu')
elif torch.cuda.device_count() > PREFERRED_GPU:
    device = torch.device(f'cuda:{PREFERRED_GPU}')
else:
    device = torch.device('cuda:0')

In [None]:
from src.utils.data_utils import get_test_multitask_dataloader

# Build the evaluation loaders from the notebook config. The evaluation cell
# calls `.values()` on the result, so it is expected to be a mapping with one
# loader per task (presumably task name -> DataLoader; confirm in
# src.utils.data_utils).
test_loader = get_test_multitask_dataloader(config)

In [None]:
from src.model import WM_Model

# NOTE(security): torch.load unpickles the file, so only load checkpoints from a
# trusted source.
# `map_location` remaps the stored tensors onto the device selected for this
# session; without it a checkpoint saved on a GPU that is not present here
# fails to deserialize.
model_data = torch.load(config.model_path, map_location=device)

# The checkpoint carries the training-time config (a dict) next to the weights;
# rebuild the architecture from it before restoring the parameters.
model = WM_Model(Namespace(**model_data['config']), device).to(device)
model.load_state_dict(model_data['model_state_dict'])

# Switch to inference mode before evaluation.
model.eval()

In [None]:
# Label -> model evaluated below. The label becomes the prefix of every
# accuracy key and is what the plotting cell looks up ('trf_1024_2'); the key
# parser further down also assumes a three-token, underscore-separated label.
# The previous label 'model_1' satisfied neither, so the notebook ended in a
# KeyError on a fresh run.
# NOTE(review): confirm this label describes the checkpoint in config.model_path.
model_dict = {
    'trf_1024_2': model,
}

In [None]:
import numpy as np

# Position of the evaluated task's batch inside each zipped multitask tuple.
# Presumably the Complex_WM_Task loader -- TODO confirm against the ordering of
# `test_loader`.
TASK_BATCH_INDEX = 13

# Sequence positions whose predictions are scored for every trial.
# NOTE(review): why these two positions is not visible in this notebook.
SCORED_POSITIONS = (16, 17)

# Accumulator: '<model>_Variation_<v>_Processing_<n>_SP' -> list of per-trial
# accuracies (mean over SCORED_POSITIONS).
epoch_acc = {}

# One batch from every task loader per step; zip stops at the shortest loader.
dataloader = zip(*test_loader.values())

with torch.no_grad():
    for batch_index, multitask_batch in tqdm(enumerate(dataloader)):
        stim_batch, resp_batch, seq_len, num_distractor, variation = multitask_batch[TASK_BATCH_INDEX]

        stim_batch = stim_batch.to(device)
        resp_batch = resp_batch.to(device)

        # `eval_model` rather than `model`, so the loop does not rebind the
        # notebook-level model loaded in the earlier cell.
        for model_name, eval_model in model_dict.items():
            out, _, _, _, _ = eval_model(stim_batch, 'Complex_WM_Task', seq_len)

            pred = torch.argmax(out, dim=-1)

            for index, _length in enumerate(seq_len):
                # Group trials by model, task variation and distractor count.
                acc_key = (f'{model_name}_Variation_{variation[index].item()}'
                           f'_Processing_{num_distractor[index].item()}_SP')

                trial_acc = [
                    1 if pred[index, position] == resp_batch[index, position] else 0
                    for position in SCORED_POSITIONS
                ]

                epoch_acc.setdefault(acc_key, []).append(np.mean(trial_acc))

In [None]:
import json

# Write the accumulated per-condition accuracy lists to disk as JSON.
with open('cs_accs.json', 'w') as out_file:
    out_file.write(json.dumps(epoch_acc))

In [None]:
import numpy as np
from scipy import stats

def _parse_acc_key(key):
    """Split '<model>_Variation_<v>_Processing_<n>_SP' into (model, v, n).

    The key is parsed from the right using the literal separators, so the
    model label may contain any number of underscores. Raises ValueError when
    the key does not follow this layout.
    """
    head, processing_sep, tail = key.rpartition('_Processing_')
    model_name, variation_sep, variation = head.rpartition('_Variation_')
    if not (processing_sep and variation_sep and tail.endswith('_SP')):
        raise ValueError(f'Unexpected accuracy key format: {key!r}')
    return model_name, variation, tail[:-len('_SP')]


def _processing_order(level):
    """Sort key that orders processing levels numerically, non-numeric ones last."""
    try:
        return (0, float(level), level)
    except ValueError:
        return (1, 0.0, level)


# plot_data[model][variation] = [processing levels, mean accuracies, standard errors]
plot_data = {}

# Visit keys in ascending processing level so that the three parallel lists of
# every (model, variation) pair are ordered by level instead of by the order in
# which conditions happened to appear in the evaluation data.
for key in sorted(epoch_acc, key=lambda k: _processing_order(_parse_acc_key(k)[2])):
    model_name, variation, processing = _parse_acc_key(key)

    series = plot_data.setdefault(model_name, {}).setdefault(variation, [[], [], []])

    trial_accs = epoch_acc[key]
    series[0].append(processing)
    series[1].append(np.mean(trial_accs))
    series[2].append(stats.sem(trial_accs))

In [None]:
import seaborn as sns
from matplotlib import cm
import matplotlib.pyplot as plt

plt.rcParams["font.family"] = "Serif"
sns.set_context("paper", font_scale=1.5, rc={"lines.linewidth": 1})

plt.figure(figsize=(10, 5))

# ax1: small panel in slot 7 of a 2x4 grid, used for the hard-coded reference curves.
# ax2: main panel on the left half of the figure, used for the model's accuracy.
ax1 = plt.subplot(2, 4, 7)
ax2 = plt.subplot(1, 2, 1)


# Kept under its own name so the notebook-level `model` object is not replaced
# by a string when this cell runs.
plot_model = 'trf_1024_2'
cl = ("1", "3", "5")

if plot_model not in plot_data:
    raise KeyError(
        f"{plot_model!r} not found in plot_data; available models: {sorted(plot_data)}"
    )

# Variation id (as stored in plot_data) -> legend label.
condition_labels = {
    '0': 'Spatial-Spatial',
    '1': 'Spatial-Visual',
    '2': 'Visual-Spatial',
    '3': 'Visual-Visual',
}

# [mean accuracies, standard errors] per condition. The first entry of each
# series is dropped so that three points remain for the three `cl` levels.
# NOTE(review): this assumes the series are ordered by processing level and
# that the first entry is the level not shown -- confirm.
model_curves = {
    label: [plot_data[plot_model][variation][1][1:], plot_data[plot_model][variation][2][1:]]
    for variation, label in condition_labels.items()
}

for attribute, measurement in model_curves.items():
    ax2.errorbar(cl, measurement[0], yerr=measurement[1], fmt='o-', label=attribute,
                 linewidth=2.5, markersize=3, capsize=4)

ax2.set_ylabel('Top-1 Accuracy', fontsize=25)

# `cl` holds strings, so the x axis is categorical with positions 0, 1, 2.
ax2.set_xticks([0, 1, 2])
ax2.set_xticklabels([1, 3, 5], fontsize=20)

ax2.set_ylim([0.0, 1.01])
ax2.set_yticks([0.0, 0.2, 0.4, 0.6, 0.8, 1.0])


# Hard-coded reference curves (labelled 'human' in the original notebook):
# [x values, y values] per condition.
# NOTE(review): the source of these numbers is not recorded here -- add a citation.
human_curves = {
    'Spatial-Spatial': [[0.33, 0.41, 0.49], [0.72, 0.66, 0.60]],
    'Spatial-Visual': [[0.27, 0.345, 0.405], [0.775, 0.715, 0.60]],
    'Visual-Spatial': [[0.33, 0.425, 0.53], [0.68, 0.57, 0.43]],
    'Visual-Visual': [[0.275, 0.35, 0.42], [0.695, 0.55, 0.525]],
}

for attribute, measurement in human_curves.items():
    ax1.errorbar(measurement[0], measurement[1], yerr=0, fmt='o-', label=attribute,
                 linewidth=2, markersize=7)

sns.despine(left=False, bottom=False, right=True, top=True)

ax1.set_xticklabels([])
ax1.set_ylim(0.3, 1.0)
ax1.set_yticks([0.4, 0.6, 0.8, 1.0])
ax1.set_yticklabels([0.4, 0.6, 0.8, 1.0], fontsize=17)

# The remaining styling targets the main panel explicitly instead of relying on
# pyplot's implicit "current axes" (which is ax2, the last subplot created).
ax2.tick_params(axis='both', labelsize=20)

ax2.legend(frameon=False, loc='upper center', bbox_to_anchor=(1.39, 0.95), ncol=1, prop={'size': 16})

# Shared x-axis caption, positioned manually in ax2 data coordinates.
ax2.text(1.5, -0.17, 'Cognitive Load', fontsize=25)
plt.show()