In [None]:
import os
import pickle

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch.nn as nn
from tqdm.auto import tqdm

import multitask.dataset as dataset
from multitask.models.task_switching import get_task_model
import multitask.models.task_switching.hooks as hooks
from multitask.utils.training import get_device
from multitask.utils.argparse import check_runs

In [None]:
# Global figure style: seaborn 'ticks' theme with the pastel palette, followed
# by matplotlib overrides (font family, no right/top spines). The overrides are
# applied after the theme call, in the same order as before.
sns.set_theme(palette='pastel', style='ticks')
mpl.rcParams.update({
    'font.family': 'Liberation Sans',
    'axes.spines.right': False,
    'axes.spines.top': False,
})

In [None]:
# Locations of the input data and of the stored task-switching results, both
# expressed relative to the directory two levels above the working directory.
path_data = os.path.join(os.pardir, os.pardir, 'data')
path_model_task_switching = os.path.join(
    os.pardir, os.pardir, 'results', 'task_switching'
)

In [None]:
# Experiment configuration. These values are packed into the `parameters`
# dictionary that is handed to check_runs when the stored results are located.

# Runs and seeding
num_runs = 10
initial_seed = 6789
# Float literal (1000000.0), intentionally left as written.
# NOTE(review): presumably matched against the parameters saved with the runs;
# confirm before changing its type to int.
max_seed = 10e5

# Training setup
num_epochs = 50
batch_size = 100
num_train = 50000
num_test = 10000

# Architecture: ten hidden layers of 100 units each
num_hidden = [100] * 10

# Tasks
tasks_names = ['parity', 'value']
num_tasks = len(tasks_names)

In [None]:
# Load one pickled results object per experiment. Experiment k (1-based) uses
# idxs_contexts = [0, ..., k-1] and k hidden layers, each with the width of the
# first entry of num_hidden. list_results keeps the experiments in that order.
num_layers = len(num_hidden)
list_results = []

for max_contexts in range(1, num_layers + 1):
    idxs_contexts = list(range(max_contexts))
    num_hidden_contexts = [num_hidden[0]] * max_contexts
    print(idxs_contexts, num_hidden_contexts)

    # Parameter set identifying the stored runs of this experiment.
    parameters = {
        'num_runs': num_runs,
        'initial_seed': initial_seed,
        'max_seed': max_seed,
        'num_epochs': num_epochs,
        'num_hidden': num_hidden_contexts,
        'batch_size': batch_size,
        'num_train': num_train,
        'num_test': num_test,
        'tasks': tasks_names,
        'idxs_contexts': idxs_contexts,
    }

    # The value returned by check_runs is used as the folder holding data.pickle.
    data_folder = check_runs(path_model_task_switching, parameters)
    pickle_data = os.path.join(data_folder, 'data.pickle')

    # pickle.load executes arbitrary code on load: only point this at result
    # files produced by this project.
    with open(pickle_data, 'rb') as handle:
        results_task_switching = pickle.load(handle)

    list_results.append(results_task_switching)


In [None]:
# The evaluation cell further down reads the 'parity' and 'value' accuracies by
# name and multiplies them into a joint accuracy, so configurations with more
# than two tasks are rejected here with an explicit reason.
if num_tasks > 2:
    raise NotImplementedError(
        f'This analysis supports at most 2 tasks, got {num_tasks}: {tasks_names}'
    )

In [None]:
# Describe every task for the task-switching model: its dataset plus a one-hot
# list ('activations') with a 1 at the task's position in tasks_names.
tasks_datasets = dataset.get_tasks_dict(tasks_names, root=path_data)
num_tasks = len(tasks_names)

task_switching_tasks = {}
for i_context, task_name in enumerate(tasks_names):
    one_hot = [0] * num_tasks
    one_hot[i_context] = 1
    task_switching_tasks[task_name] = {
        'data': tasks_datasets[task_name],
        'activations': one_hot,
    }

# Show the resulting task -> one-hot mapping.
for key, value in task_switching_tasks.items():
    print(f'{key}: {value["activations"]}')

In [None]:
# Re-evaluate every stored model on its own test indices and collect the test
# accuracies. In the three arrays below, rows index the run (position of the
# seed in the results dict) and columns index the experiment in list_results
# (column k corresponds to idxs_contexts = [0, ..., k]).
criterion = nn.CrossEntropyLoss()
device = get_device()

acc_test_parity = np.zeros((num_runs, num_layers))
acc_test_value = np.zeros((num_runs, num_layers))
acc_test_joint = np.zeros((num_runs, num_layers))

for i_results, results in enumerate(list_results):
    seeds = list(results.keys())
    # Every experiment must provide exactly one entry per run; otherwise rows
    # of the accuracy arrays would silently stay at zero (too few seeds) or the
    # indexing below would run out of bounds (too many seeds).
    if len(seeds) != num_runs:
        raise ValueError(
            f'Experiment {i_results} holds {len(seeds)} runs, expected {num_runs}'
        )

    # Rebuild the architecture description used for this experiment.
    idxs_contexts = list(range(i_results + 1))
    num_hidden_contexts = len(idxs_contexts) * [num_hidden[0]]
    print(idxs_contexts, num_hidden_contexts)

    for j_seed, seed in enumerate(seeds):
        # Recreate the model and restore the stored weights.
        state_dict = results[seed]['model']
        model = get_task_model(task_switching_tasks,
                               num_hidden_contexts,
                               idxs_contexts,
                               device)
        model.load_state_dict(state_dict)

        # Build the test dataloaders from the indices stored with the run.
        indices = results[seed]['indices']
        _, test_dataloaders = dataset.create_dict_dataloaders(task_switching_tasks,
                                                            indices,
                                                            batch_size=batch_size)
        tasks_testloader = dataset.SequentialTaskDataloader(test_dataloaders)

        # First returned element is indexed by task name below.
        # NOTE(review): presumably per-sample correctness arrays, which would
        # make the product below "both tasks correct" - confirm in hooks.
        acc_test_increase, _ = hooks.get_layer_activations(model,
                                                           tasks_testloader,
                                                           criterion,
                                                           device=device,
                                                           disable=True)

        acc_parity = acc_test_increase['parity']
        acc_value = acc_test_increase['value']
        acc_test_parity[j_seed, i_results] = acc_parity.mean()
        acc_test_value[j_seed, i_results] = acc_value.mean()
        acc_test_joint[j_seed, i_results] = (acc_parity * acc_value).mean()

In [None]:
# Long-format table for plotting: one row per (run, experiment) with columns
# 'Acc' (accuracy), 'Idx' (column index of the accuracy array, i.e. position of
# the experiment in list_results) and 'Task' (label used for the hue).
#
# Only the joint accuracy is plotted. To add per-task bars, insert
# 'Parity': acc_test_parity and/or 'Value': acc_test_value before 'Joint'.
accuracies_by_task = {
    'Joint': acc_test_joint,
}

# Collect all per-experiment frames first and concatenate once. This avoids
# concatenating onto an empty frame (dtype ambiguity and a FutureWarning in
# recent pandas) and the quadratic cost of growing a frame inside a loop.
accuracy_frames = [
    pd.DataFrame({'Acc': acc_column, 'Idx': i_column, 'Task': task_label})
    for task_label, acc_matrix in accuracies_by_task.items()
    for i_column, acc_column in enumerate(acc_matrix.T)
]

if accuracy_frames:
    accuracies_df = pd.concat(accuracy_frames)
else:
    # Nothing to plot: keep the expected columns so downstream code still works.
    accuracies_df = pd.DataFrame(columns=['Acc', 'Idx', 'Task'])

In [None]:
# Horizontal bar plot of the accuracies in accuracies_df (mean over runs with
# standard-error bars), saved as an SVG and then displayed.
path_figure = 'figures/figure05/fig05c_acc_weights_contexts_layers.svg'

fig, ax = plt.subplots()
# NOTE(review): `errwidth` is reportedly deprecated in newer seaborn releases
# in favour of `err_kws`; kept as-is for the seaborn version this notebook
# targets - confirm when upgrading.
sns.barplot(x='Acc', y='Idx', hue='Task', data=accuracies_df,
            errorbar='se', errwidth=1.5, capsize=0.15,
            orient='horizontal', ax=ax)
ax.set_xlabel('Accuracy', fontsize=16)
# NOTE(review): the y variable is 'Idx' while the label reads 'Task' - confirm
# that this label is intended.
ax.set_ylabel('Task', fontsize=16)
# Zoom into the narrow accuracy range covered by the bars.
ax.set_xlim(0.96, 0.985)
ax.legend(loc='lower left', prop={'size': 12})

# savefig does not create missing directories, so make sure the target exists.
os.makedirs(os.path.dirname(path_figure), exist_ok=True)
fig.savefig(path_figure)
plt.show()