In [None]:
import os
import pickle

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch.nn as nn
import torch.optim
from tqdm.auto import tqdm

from train.utils.argparse import check_runs

sns.set_theme(style='ticks', palette='pastel')
# Global figure style: Liberation Sans font and no top/right axis spines.
mpl.rcParams.update({
    'font.family': 'Liberation Sans',
    'axes.spines.right': False,
    'axes.spines.top': False,
})

# Output folders (relative to the working directory) of the three model families.
model_path_individual, model_path_parallel, model_path_task_switching = (
    os.path.join('out', subdir)
    for subdir in ('individual', 'parallel', 'task_switching')
)

In [None]:
# Settings of the trained runs. They are passed to check_runs in the next cell
# to select the matching saved results, so presumably they must equal the
# values used at training time.
num_runs = 10                       # runs (seeds) per model; rows of the result arrays
initial_seed = 6789
max_seed = 1e6                      # i.e. 10e5 (float); NOTE(review): confirm 1e5 was not intended
num_epochs = 100                    # length of every stored learning curve
num_hidden = [100] * 5              # presumably 5 hidden layers of 100 units — confirm
batch_size = 100
num_train = 50_000                  # presumably the training-set size
num_test = 10_000                   # presumably the test-set size
tasks_names = ['parity', 'value']   # task keys used to index every results dict
idxs_contexts = [0, 1, 2, 3, 4]     # only used for the task-switching run lookup

In [None]:
# Settings handed to check_runs together with each model's output folder; its
# return value is the folder holding that model's data.pickle.
# NOTE(review): the matching logic lives in train.utils.argparse.check_runs and
# is not visible here.
parameters = dict(
    num_runs=num_runs,
    initial_seed=initial_seed,
    max_seed=max_seed,
    num_epochs=num_epochs,
    num_hidden=num_hidden,
    batch_size=batch_size,
    num_train=num_train,
    num_test=num_test,
    tasks=tasks_names,
    idxs_contexts=idxs_contexts,
)
data_folder_task_switching = check_runs(model_path_task_switching, parameters)

# The individual and parallel runs are looked up with idxs_contexts=None;
# only the task-switching lookup uses the context indices.
parameters.update(idxs_contexts=None)
data_folder_individual, data_folder_parallel = [
    check_runs(model_path, parameters)
    for model_path in (model_path_individual, model_path_parallel)
]

In [None]:
def load_results(pickle_path):
    """Open ``pickle_path`` in binary mode and return the unpickled object.

    NOTE(review): ``pickle.load`` can execute arbitrary code embedded in the
    file; only use this on ``data.pickle`` files produced by this project's own
    training runs, never on downloaded or shared files you do not trust.
    """
    with open(pickle_path, 'rb') as handle:
        return pickle.load(handle)


# Each model family stores its results in ``data.pickle`` inside its data folder.
pickle_data_individual = os.path.join(data_folder_individual, 'data.pickle')
pickle_data_parallel = os.path.join(data_folder_parallel, 'data.pickle')
pickle_data_task_switching = os.path.join(data_folder_task_switching, 'data.pickle')

results_individual = load_results(pickle_data_individual)
results_parallel = load_results(pickle_data_parallel)
results_task_switching = load_results(pickle_data_task_switching)

In [None]:
# The seeds are the top-level keys of each results dict. The three model
# families must share exactly the same seeds so that row i of every
# aggregated array refers to the same run.
seeds_individual = sorted(results_individual)
seeds_parallel = sorted(results_parallel)
seeds_task_switching = sorted(results_task_switching)
assert seeds_individual == seeds_parallel == seeds_task_switching, \
    'individual, parallel and task-switching results use different seeds'

# The aggregation arrays are preallocated with `num_runs` rows and filled one row
# per seed: fewer seeds would leave all-zero rows that silently bias the
# mean/std curves, more seeds would overflow the arrays.
assert len(seeds_individual) == num_runs, \
    'expected %d runs but found %d seeds' % (num_runs, len(seeds_individual))

In [None]:
list_models = ['individual', 'parallel', 'task_switching']

# results_models[model][task][metric] is a (num_runs, num_epochs) array of zeros:
# one row per run, one column per epoch, filled in the next cell.
results_models = {
    model_name: {
        task: {
            metric_name: np.zeros((num_runs, num_epochs))
            for metric_name in ('train_loss', 'train_acc', 'valid_loss', 'valid_acc')
        }
        for task in tasks_names
    }
    for model_name in list_models
}

num_tasks = len(tasks_names)

In [None]:
# Copy every seed's learning curves into row `row` of the preallocated arrays.
# Individual results are nested seed -> task -> 'results' -> metric, whereas
# parallel and task-switching results are nested seed -> 'results' -> metric -> task.
for row, seed in enumerate(seeds_individual):
    for task_name in tasks_names:
        curves_individual = results_individual[seed][task_name]['results']
        for metric_key in ('train_loss', 'train_acc', 'valid_loss', 'valid_acc'):
            results_models['individual'][task_name][metric_key][row, :] = curves_individual[metric_key]

        for model_key, raw_results in (('parallel', results_parallel),
                                       ('task_switching', results_task_switching)):
            curves_shared = raw_results[seed]['results']
            for metric_key in ('train_loss', 'train_acc', 'valid_loss', 'valid_acc'):
                results_models[model_key][task_name][metric_key][row, :] = curves_shared[metric_key][task_name]

In [None]:
# Learning curves: one row of panels per task, one column per metric. Each panel
# shows, for every model, the mean over runs with a +/- 1 std shaded band.
# (The former stray `from tkinter.tix import Y_REGION` was unused and fails to
# import on Python >= 3.13, where tkinter.tix was removed.)
epochs = range(num_epochs)
metrics = ['train_loss', 'train_acc', 'valid_loss', 'valid_acc']

# squeeze=False keeps `ax` 2-D so ax[i, j] indexing also works with a single task.
fig, ax = plt.subplots(num_tasks, len(metrics),
                       figsize=(12, 3 * num_tasks),
                       squeeze=False)

for i_task, task_name in enumerate(tasks_names):
    for k_metric, metric in enumerate(metrics):
        axis = ax[i_task, k_metric]
        for model in list_models:
            runs = results_models[model][task_name][metric]  # rows = runs, cols = epochs
            mean_model = runs.mean(axis=0)
            std_model = runs.std(axis=0)

            line, = axis.plot(epochs, mean_model,
                              label=model.replace('_', ' ').capitalize())
            # Reuse the line colour so each shaded band matches its mean curve
            # instead of relying on the line and patch colour cyclers staying in sync.
            axis.fill_between(epochs,
                              mean_model - std_model,
                              mean_model + std_model,
                              color=line.get_color(),
                              alpha=0.5)

        axis.set_title(task_name.capitalize())
        axis.set_ylabel(' '.join(metric.split('_')).capitalize())
        if i_task == num_tasks - 1:
            axis.set_xlabel('Epochs')

# Every panel uses the same model-to-colour mapping, so one legend suffices.
ax[0, 0].legend(frameon=False)

fig.tight_layout()
plt.show()