In [None]:
import torch
from tqdm import tqdm
from argparse import Namespace


# Evaluation settings. `config` is handed to the test data-loader factory and
# `config.model_path` is the checkpoint file restored further down.
config = Namespace(
    data_folder='./wm_bench_data',
    max_seq_len=20,
    rs_img_size=32,
    batch_size=10,
    num_workers=4,
    use_cnn=1,
    model_path='./model.pt')

# GPU ordinal to use when it exists on this machine.
PREFERRED_GPU = 3

# Pick the compute device. `torch.cuda.is_available()` only says that *some*
# GPU exists, so the preferred ordinal is checked against the device count to
# avoid an "invalid device ordinal" failure on hosts with fewer GPUs; in that
# case the first GPU is used instead.
if not torch.cuda.is_available():
    device = torch.device('cpu')
elif PREFERRED_GPU < torch.cuda.device_count():
    device = torch.device(f'cuda:{PREFERRED_GPU}')
else:
    device = torch.device('cuda:0')

In [None]:
from src.utils.data_utils import get_test_multitask_dataloader

# Build the evaluation loaders from `config`. The result is treated as a mapping
# later on: its `.values()` are zipped together and iterated in lockstep
# (presumably one loader per task -- confirm against the helper).
test_loader = get_test_multitask_dataloader(config)

In [None]:
from src.model import WM_Model

# Restore the trained model from the checkpoint at `config.model_path`.
# The checkpoint is a dict holding the training-time config ('config') and the
# weights ('model_state_dict').
#
# `map_location=device` remaps the stored tensors onto the device chosen for
# this session; without it a checkpoint saved on a GPU cannot be loaded on a
# CPU-only host (or on a host where that GPU ordinal does not exist).
# NOTE: torch.load unpickles the file -- only load checkpoints you trust.
model_data = torch.load(config.model_path, map_location=device)
model = WM_Model(Namespace(**model_data['config']), device).to(device)
model.load_state_dict(model_data['model_state_dict'])
# Switch to evaluation mode; as the cell's last expression this also displays
# the model's repr.
model.eval()

In [None]:
# Models to evaluate, keyed by the name that prefixes every result key
# ('<name>_Set_Size_<n>'). The plotting cell styles curves by names of the form
# '<family>_<size>', e.g. 'lstm_96' or 'trf_64'.
model_dict = dict(model_1=model)

In [None]:
import numpy as np

# Per-trial accuracies, grouped under '<model_name>_Set_Size_<n>' keys.
epoch_acc = {}

# Step through every loader in lockstep; zip stops at the shortest loader.
dataloader = zip(*test_loader.values())

with torch.no_grad():
    for multitask_batch in tqdm(dataloader):
        # Only the fourth loader's batch is scored in this notebook
        # (presumably the SMU task -- TODO confirm against the loader order).
        stim_batch, resp_batch, seq_len, set_size = multitask_batch[3]

        stim_batch = stim_batch.to(device)
        resp_batch = resp_batch.to(device)

        # `net` rather than `model` so the notebook-level `model` is not rebound.
        for model_name, net in model_dict.items():
            out, _, _, _, _ = net(stim_batch, 'SMU_Task', seq_len)
            pred = torch.argmax(out, dim=-1)

            for index, length in enumerate(seq_len):
                # Score only the last `set_size` positions of the trial, i.e.
                # the window [length - set_size, length).
                start = length - set_size[index]

                # One vectorized comparison per trial instead of two `.item()`
                # calls (each a device sync) per element.
                # Assumes pred and resp_batch are both (batch, seq) -- confirm.
                hits = pred[index][start:length] == resp_batch[index][start:length]

                key = model_name + '_Set_Size_' + str(set_size[index].item())
                # Mean of the boolean hits = fraction correct for this trial
                # (NaN for an empty window, as with np.mean([])).
                epoch_acc.setdefault(key, []).append(np.mean(hits.cpu().numpy()))

In [None]:
import json

# Persist the raw per-trial accuracies (one list per '<model>_Set_Size_<n>' key)
# to 'smu_accs.json' in the working directory.
with open('smu_accs.json', mode='w') as out_file:
    json.dump(epoch_acc, out_file)

In [None]:
import numpy as np
from scipy import stats

# Summarise every condition as [mean accuracy, standard error of the mean]
# over its per-trial accuracies.
epoch_acc_std_err = {
    condition: [np.mean(trial_accs), stats.sem(trial_accs)]
    for condition, trial_accs in epoch_acc.items()
}

In [None]:
# Regroup the per-condition statistics into one entry per model:
#   plot_data[name] = [set sizes, mean accuracies, standard errors]
plot_data = {}

# Visit the keys in ascending set-size order. Dict insertion order follows the
# order in which set sizes first showed up during evaluation, which would leave
# each curve's x values unsorted and make the plotted lines zig-zag.
# (sorted() is stable, so models keep their relative order.)
for key in sorted(epoch_acc_std_err, key=lambda k: int(k.split('_')[-1])):
    tokens = key.split('_')

    # Keys look like '<model name>_Set_Size_<n>'. The model name is taken to be
    # the first two '_'-separated tokens (e.g. 'lstm_96'), which is the form
    # the plotting cell matches on; the set size is the last token.
    model_name = tokens[0] + '_' + tokens[1]
    size = int(tokens[-1])

    series = plot_data.setdefault(model_name, [[], [], []])
    series[0].append(size)
    series[1].append(epoch_acc_std_err[key][0])
    series[2].append(epoch_acc_std_err[key][1])



In [None]:
# Human reference curve. Unlike the model entries this one carries no
# standard-error list, only x and y values.
plot_data['human'] = [
    [1, 2, 3, 4],                # set sizes
    [1.0, 0.94, 0.825, 0.68],    # accuracy at each set size
]

In [None]:
import seaborn as sns
from matplotlib import cm
import matplotlib.pyplot as plt

# Figure: accuracy as a function of set size, one error-bar curve per model
# plus the human reference curve.
plt.rcParams["font.family"] = "Serif"
sns.set_context("paper", font_scale=1.5, rc={"lines.linewidth": 1})

fig, ax = plt.subplots(figsize=(5, 5))

# One row per model family:
#   (name prefix, colormap, shade positions, errorbar format,
#    [(size token, index into the shade positions), ...])
# Curves are drawn in exactly this order, which also fixes the legend order.
family_styles = [
    ('lstm', 'Reds', np.linspace(0.4, 0.9, 3), '-^',
     [('96', 0), ('256', 1), ('1024', 2)]),
    ('gru', 'Purples', np.linspace(0.5, 0.9, 3), '-D',
     [('96', 0), ('256', 1), ('1024', 2)]),
    ('rnn', 'Blues', np.linspace(0.4, 0.9, 3), '-X',
     [('96', 0), ('256', 1), ('1024', 2)]),
    ('trf', 'Greens', np.linspace(0.4, 0.9, 3), '-s',
     [('64', 1), ('128', 2)]),
]

# Names of the plot_data entries that matched one of the styles above.
plotted = set()

for family, cmap_name, shades, fmt, variants in family_styles:
    # plt.get_cmap instead of matplotlib.cm.get_cmap: the latter was deprecated
    # in Matplotlib 3.7 and removed in 3.9.
    cmap = plt.get_cmap(cmap_name)

    for size, shade_index in variants:
        for name, series in plot_data.items():
            tokens = name.split('_')
            # Entries are matched on '<family>_<size>'; names without a '_'
            # (such as 'human') are skipped here.
            if len(tokens) < 2 or tokens[0] != family or tokens[1] != size:
                continue

            ax.errorbar(series[0], series[1], yerr=series[2],
                        fmt=fmt, label=family.upper() + '-' + size,
                        color=cmap(shades[shade_index]),
                        linewidth=2, markersize=3, capsize=4)
            plotted.add(name)

# Models whose name matches none of the known styles are still drawn (default
# colour cycle, labelled with their own name) instead of being silently left
# off the figure.
for name, series in plot_data.items():
    if name == 'human' or name in plotted:
        continue

    ax.errorbar(series[0], series[1],
                yerr=series[2] if len(series) > 2 else None,
                fmt='-o', label=name, linewidth=2, markersize=3, capsize=4)

# Human reference curve: dashed black line, no error bars.
if 'human' in plot_data:
    ax.errorbar(plot_data['human'][0], plot_data['human'][1], yerr=0,
                fmt='--o', color='black', label='Human', linewidth=2, markersize=6)


sns.despine(left=False, bottom=False, right=True, top=True)

ax.set_xlabel('Set Size', fontsize=25)
ax.set_ylabel('Top-1 Accuracy', fontsize=25)

ax.set_xticks([1, 2, 3, 4, 5, 6, 7, 8])
ax.set_xticklabels([1, 2, 3, 4, 5, 6, 7, 8])

ax.set_ylim([0.0, 1.01])

plt.xticks(fontsize=20)
plt.yticks(fontsize=20)

# The legend is anchored outside the axes, to the right of the plot area.
plt.legend(frameon=False, loc='upper center', bbox_to_anchor=(1.35, 1.05), ncol=1, prop={'size': 16})

plt.show()