In [None]:
import os
import pickle

import matplotlib as mpl
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import seaborn as sns
import torch.nn as nn
import torch.optim
from tqdm.auto import tqdm

import multitask.dataset as dataset
from multitask.models.individual import get_individual_model
from multitask.models.individual import train as train_individual
from multitask.models.individual import hooks as hooks_individual
from multitask.models.parallel import get_parallel_model
from multitask.models.parallel import train as train_parallel
from multitask.models.parallel import hooks as hooks_parallel
from multitask.models.task_switching import get_task_model
from multitask.models.task_switching import train as train_task_switching
from multitask.models.task_switching import hooks as hooks_task_switching

from train.utils.argparse import check_runs
from train.utils.training import get_device

# Global plot styling: seaborn "ticks" theme with a pastel palette, followed by
# matplotlib overrides for the font family and the top/right axes spines.
sns.set_theme(style='ticks', palette='pastel')
mpl.rcParams.update({
    'font.family': 'Liberation Sans',
    'axes.spines.right': False,
    'axes.spines.top': False,
})

# Per-model output folders; these are the paths later handed to check_runs.
model_path_individual, model_path_parallel, model_path_task_switching = (
    os.path.join('out', model_family)
    for model_family in ('individual', 'parallel', 'task_switching')
)

In [None]:
# Experiment configuration. These values are collected into the `parameters`
# dict that is passed to check_runs to locate the stored runs.
num_runs = 10
initial_seed = 6789
max_seed = 1e6  # kept as a float: identical value to the literal 10e5
num_epochs = 100
num_hidden = [100] * 5  # one width per hidden layer
batch_size = 100
num_train = 50_000
num_test = 10_000

# Alternative task sets:
# tasks_names = ["parity", "small", "prime", "fibonacci", "multiples_3"]
# tasks_names = ["parity", "imparity", "small", "large", "prime", "not_prime", "fibonacci", "not_fibonacci", "multiples_3", "not_multiples_3"]
tasks_names = ['parity', 'value']
idxs_contexts = list(range(5))

In [None]:
# Parameter set describing the task-switching runs; check_runs is given the
# model folder together with these parameters and returns the data folder.
parameters = dict(
    num_runs=num_runs,
    initial_seed=initial_seed,
    max_seed=max_seed,
    num_epochs=num_epochs,
    num_hidden=num_hidden,
    batch_size=batch_size,
    num_train=num_train,
    num_test=num_test,
    tasks=tasks_names,
    idxs_contexts=idxs_contexts,
)
data_folder_task_switching = check_runs(model_path_task_switching, parameters)

# The individual and parallel runs are looked up with idxs_contexts set to None.
parameters.update(idxs_contexts=None)
data_folder_individual = check_runs(model_path_individual, parameters)
data_folder_parallel = check_runs(model_path_parallel, parameters)

In [None]:
def _read_pickle(path):
    """Open *path* in binary mode and return the unpickled object."""
    with open(path, 'rb') as handle:
        return pickle.load(handle)


# Each model family stores its results in a `data.pickle` inside its data folder.
pickle_data_individual = os.path.join(data_folder_individual, 'data.pickle')
results_individual = _read_pickle(pickle_data_individual)

pickle_data_parallel = os.path.join(data_folder_parallel, 'data.pickle')
results_parallel = _read_pickle(pickle_data_parallel)

pickle_data_task_switching = os.path.join(data_folder_task_switching, 'data.pickle')
results_task_switching = _read_pickle(pickle_data_task_switching)

In [None]:
# The result dictionaries are keyed by seed; iterating a dict yields its keys,
# so sorted(...) gives the ordered list of seeds for each model family.
seeds_individual = sorted(results_individual)
seeds_parallel = sorted(results_parallel)
seeds_task_switching = sorted(results_task_switching)

# All three model families must have been trained on exactly the same seeds.
assert seeds_individual == seeds_parallel == seeds_task_switching, \
    'Individual, parallel and task-switching results do not share the same seeds'

# The result arrays below are pre-allocated with `num_runs` rows and filled by
# enumerating the seeds: fewer seeds would leave all-zero rows that bias the
# statistics, more seeds would raise an IndexError mid-loop. Fail early instead.
assert len(seeds_individual) == num_runs, \
    f'Expected {num_runs} runs but found {len(seeds_individual)} seeds in the results'

In [None]:
# Objects shared by every evaluation loop below: the compute device, the loss
# passed to the activation hooks, and the per-task datasets read from 'data'.
device = get_device()
criterion = nn.CrossEntropyLoss()
tasks = dataset.get_tasks_dict(tasks_names, root='data')

# Sizes used to allocate the result arrays.
num_tasks, num_layers = len(tasks), len(num_hidden)

In [None]:
# Per (task, run, layer) statistics for the individually trained models, in percent:
#   sparsity: mean over test samples of the fraction of units whose activation is exactly 0
#   dead:     fraction of units whose activation summed over all test samples is 0
#             (NOTE(review): equals "never active" only for non-negative activations
#             such as ReLU -- confirm against the model definition)
sparsity_individual = np.zeros((num_tasks, num_runs, num_layers))
dead_individual = np.zeros((num_tasks, num_runs, num_layers))

for i_seed, seed in tqdm(enumerate(seeds_individual), total=len(seeds_individual)):
    indices = results_individual[seed]['indices']
    # One model instance per seed; the weights of each task are loaded into it in turn.
    test_model = get_individual_model(num_hidden, device)

    for i_task, (task_name, task_dataset) in enumerate(tasks.items()):
        saved_model = results_individual[seed][task_name]['model']
        test_model.load_state_dict(saved_model)
        test_model = test_model.to(device)

        # Visit the stored test indices in order, using the configured batch
        # size (same setting as the parallel and task-switching loops).
        test_sampler = dataset.SequentialSampler(indices['test'])
        testloader = torch.utils.data.DataLoader(task_dataset,
                                                 sampler=test_sampler,
                                                 batch_size=batch_size)

        _, activations_individuals = hooks_individual.get_layer_activations(test_model,
                                                                            testloader,
                                                                            criterion,
                                                                            device=device,
                                                                            disable=True)

        for j_layer in range(num_layers):
            layer = f'layer{j_layer+1}'
            layer_activations = activations_individuals[layer]
            sparsity_individual[i_task, i_seed, j_layer] = 100 * (np.sum(layer_activations == 0, axis=1).mean() / num_hidden[j_layer])
            dead_individual[i_task, i_seed, j_layer] = 100 * (np.sum(layer_activations.sum(axis=0) == 0) / num_hidden[j_layer])

In [None]:
# Select the datasets of the configured tasks, in `tasks_names` order, and wrap
# them in a single MultilabelTasks dataset for the parallel model.
parallel_datasets = {name: tasks[name] for name in tasks_names}

parallel_tasks = dataset.MultilabelTasks(parallel_datasets)

In [None]:
# Per (run, layer) statistics for the parallel model, in percent:
#   sparsity: mean over test samples of the fraction of units whose activation is exactly 0
#   dead:     fraction of units whose activation summed over all test samples is 0
#             (NOTE(review): equals "never active" only for non-negative activations
#             such as ReLU -- confirm against the model definition)
sparsity_parallel = np.zeros((num_runs, num_layers))
dead_parallel = np.zeros((num_runs, num_layers))

for i_seed, seed in tqdm(enumerate(seeds_parallel), total=len(seeds_parallel)):
    saved_model = results_parallel[seed]['model']
    test_model = get_parallel_model(num_tasks,
                                    num_hidden,
                                    device)
    test_model.load_state_dict(saved_model)
    test_model = test_model.to(device)

    indices = results_parallel[seed]['indices']

    # Visit the stored test indices of this run in order.
    test_sampler = dataset.SequentialSampler(indices['test'])
    parallel_testloader = torch.utils.data.DataLoader(parallel_tasks,
                                                      sampler=test_sampler,
                                                      batch_size=batch_size)

    _, activations_parallel = hooks_parallel.get_layer_activations(test_model,
                                                                   parallel_testloader,
                                                                   criterion=criterion,
                                                                   device=device,
                                                                   disable=True)

    for j_layer in range(num_layers):
        layer = f'layer{j_layer+1}'
        layer_activations = activations_parallel[layer]
        sparsity_parallel[i_seed, j_layer] = 100 * (np.sum(layer_activations == 0, axis=1).mean() / num_hidden[j_layer])
        dead_parallel[i_seed, j_layer] = 100 * (np.sum(layer_activations.sum(axis=0) == 0) / num_hidden[j_layer])

In [None]:
tasks_datasets = dataset.get_tasks_dict(tasks_names, root='data')
num_tasks = len(tasks_names)

# For every task store its dataset together with a one-hot list marking the
# position of the task in `tasks_names`.
task_switching_tasks = {}
for i_context, task_name in enumerate(tasks_names):
    one_hot = [0] * num_tasks
    one_hot[i_context] = 1
    task_switching_tasks[task_name] = {
        'data': tasks_datasets[task_name],
        'activations': one_hot,
    }

# Show the one-hot vector assigned to each task.
for key, value in task_switching_tasks.items():
    print(f'{key}: {value["activations"]}')

In [None]:
# Per (run, layer) statistics for the task-switching model, in percent, computed
# over the activations of all tasks stacked together:
#   sparsity: mean over rows of the fraction of units whose activation is exactly 0
#   dead:     fraction of units whose activation summed over all rows is 0
#             (NOTE(review): equals "never active" only for non-negative activations
#             such as ReLU -- confirm against the model definition)
sparsity_task_switching = np.zeros((num_runs, num_layers))
dead_task_switching = np.zeros((num_runs, num_layers))

for i_seed, seed in tqdm(enumerate(seeds_task_switching), total=len(seeds_task_switching)):
    state_dict = results_task_switching[seed]['model']
    model = get_task_model(task_switching_tasks,
                           num_hidden,
                           idxs_contexts,
                           device)
    model.load_state_dict(state_dict)

    indices = results_task_switching[seed]['indices']

    # Only the test dataloaders are needed; the first returned value is ignored.
    _, test_dataloaders = dataset.create_dict_dataloaders(task_switching_tasks,
                                                          indices,
                                                          batch_size=batch_size)
    tasks_testloader = dataset.SequentialTaskDataloader(test_dataloaders)

    _, activations_task_switching = hooks_task_switching.get_layer_activations(model,
                                                                               tasks_testloader,
                                                                               criterion,
                                                                               device=device,
                                                                               disable=True)

    for j_layer in range(num_layers):
        layer = f'layer{j_layer+1}'
        # Stack the activations of every task (in `tasks_names` order) row-wise in one call.
        total_activations_layer = np.vstack([activations_task_switching[task][layer]
                                             for task in tasks_names])

        sparsity_task_switching[i_seed, j_layer] = 100 * (np.sum(total_activations_layer == 0, axis=1).mean() / num_hidden[j_layer])
        dead_task_switching[i_seed, j_layer] = 100 * (np.sum(total_activations_layer.sum(axis=0) == 0) / num_hidden[j_layer])

In [None]:
# Individual models: the first axis is averaged away before taking the mean and
# standard deviation over the next axis; the other two models are reduced over
# their first axis directly.
mean_sparsity_individual = np.mean(np.mean(sparsity_individual, axis=0), axis=0)
std_sparsity_individual = np.std(np.mean(sparsity_individual, axis=0), axis=0)

mean_sparsity_parallel = np.mean(sparsity_parallel, axis=0)
std_sparsity_parallel = np.std(sparsity_parallel, axis=0)

mean_sparsity_task_switching = np.mean(sparsity_task_switching, axis=0)
std_sparsity_task_switching = np.std(sparsity_task_switching, axis=0)

In [None]:
layers = range(1, num_layers + 1)
fig = plt.figure()

# (mean, std) per model, in the same order as the legend labels below.
sparsity_curves = (
    (mean_sparsity_individual, std_sparsity_individual),
    (mean_sparsity_parallel, std_sparsity_parallel),
    (mean_sparsity_task_switching, std_sparsity_task_switching),
)

# All mean curves are drawn before any band, as in the original ordering, so the
# positional legend labels are matched with the three lines.
for curve_mean, _ in sparsity_curves:
    plt.plot(layers, curve_mean)

# Shaded band of one standard deviation around each mean.
for curve_mean, curve_std in sparsity_curves:
    plt.fill_between(layers,
                     curve_mean - curve_std,
                     curve_mean + curve_std,
                     alpha=0.5)

plt.xlabel('Layer', fontsize=16)
plt.ylabel('Sparsity (%)', fontsize=16)
plt.xticks(layers, fontsize=14, fontname='Liberation Sans')
plt.yticks(fontsize=12,  fontname='Liberation Sans')
plt.legend(['Individual', 'Parallel', 'Task Switching'], prop={'size':12})
plt.show()

In [None]:
# Long-format table of sparsity values with columns 'Sparsity', 'Model', 'Layer',
# holding one block of rows per layer and model.
# The per-layer frames are collected in a list and concatenated once: growing a
# DataFrame with pd.concat inside the loop copies all previous rows every time,
# and seeding it with an empty frame is deprecated in recent pandas.
sparsity_individual_per_run = sparsity_individual.mean(axis=0)  # average over the task axis
sparsity_frames = []
for i_layer in range(num_layers):
    layer = f'layer{i_layer+1}'
    df_sparsity_individual = pd.DataFrame({'Sparsity': sparsity_individual_per_run[:, i_layer], 'Model': 'Individual', 'Layer': layer})
    df_sparsity_parallel = pd.DataFrame({'Sparsity': sparsity_parallel[:, i_layer], 'Model': 'Parallel', 'Layer': layer})
    df_sparsity_task_switching = pd.DataFrame({'Sparsity': sparsity_task_switching[:, i_layer], 'Model': 'Task Switching', 'Layer': layer})
    sparsity_frames.extend([df_sparsity_individual, df_sparsity_parallel, df_sparsity_task_switching])

if sparsity_frames:
    df_sparsity_all = pd.concat(sparsity_frames)
else:
    # No layers: keep the original result, an empty frame with the expected columns.
    df_sparsity_all = pd.DataFrame({}, columns=['Sparsity', 'Model', 'Layer'])

In [None]:
# Individual models: the first axis is averaged away before taking the mean and
# standard deviation over the next axis; the other two models are reduced over
# their first axis directly.
mean_dead_individual = np.mean(np.mean(dead_individual, axis=0), axis=0)
std_dead_individual = np.std(np.mean(dead_individual, axis=0), axis=0)

mean_dead_parallel = np.mean(dead_parallel, axis=0)
std_dead_parallel = np.std(dead_parallel, axis=0)

mean_dead_task_switching = np.mean(dead_task_switching, axis=0)
std_dead_task_switching = np.std(dead_task_switching, axis=0)

In [None]:
layers = range(1, num_layers + 1)
fig = plt.figure()

# (mean, std) per model, in the same order as the legend labels below.
dead_curves = (
    (mean_dead_individual, std_dead_individual),
    (mean_dead_parallel, std_dead_parallel),
    (mean_dead_task_switching, std_dead_task_switching),
)

# All mean curves are drawn before any band, as in the original ordering, so the
# positional legend labels are matched with the three lines.
for curve_mean, _ in dead_curves:
    plt.plot(layers, curve_mean)

# Shaded band of one standard deviation around each mean.
for curve_mean, curve_std in dead_curves:
    plt.fill_between(layers,
                     curve_mean - curve_std,
                     curve_mean + curve_std,
                     alpha=0.5)

plt.xlabel('Layer', fontsize=16)
plt.ylabel('Dead Units (%)', fontsize=16)
plt.xticks(layers, fontsize=14, fontname='Liberation Sans')
plt.yticks(fontsize=12,  fontname='Liberation Sans')
plt.legend(['Individual', 'Parallel', 'Task Switching'], prop={'size':12})
plt.show()

In [None]:
# Rebuilds the long-format sparsity table (columns 'Sparsity', 'Model', 'Layer')
# that the bar plot reads.
# The per-layer frames are collected in a list and concatenated once: growing a
# DataFrame with pd.concat inside the loop copies all previous rows every time,
# and seeding it with an empty frame is deprecated in recent pandas.
sparsity_individual_per_run = sparsity_individual.mean(axis=0)  # average over the task axis
sparsity_frames = []
for i_layer in range(num_layers):
    layer = f'layer{i_layer+1}'
    df_sparsity_individual = pd.DataFrame({'Sparsity': sparsity_individual_per_run[:, i_layer], 'Model': 'Individual', 'Layer': layer})
    df_sparsity_parallel = pd.DataFrame({'Sparsity': sparsity_parallel[:, i_layer], 'Model': 'Parallel', 'Layer': layer})
    df_sparsity_task_switching = pd.DataFrame({'Sparsity': sparsity_task_switching[:, i_layer], 'Model': 'Task Switching', 'Layer': layer})
    sparsity_frames.extend([df_sparsity_individual, df_sparsity_parallel, df_sparsity_task_switching])

if sparsity_frames:
    df_sparsity_all = pd.concat(sparsity_frames)
else:
    # No layers: keep the original result, an empty frame with the expected columns.
    df_sparsity_all = pd.DataFrame({}, columns=['Sparsity', 'Model', 'Layer'])

In [None]:
# Long-format table of dead-unit percentages with columns 'Dead', 'Model', 'Layer',
# holding one block of rows per layer and model.
# The per-layer frames are collected in a list and concatenated once: growing a
# DataFrame with pd.concat inside the loop copies all previous rows every time,
# and seeding it with an empty frame is deprecated in recent pandas.
dead_individual_per_run = dead_individual.mean(axis=0)  # average over the task axis
dead_frames = []
for i_layer in range(num_layers):
    layer = f'layer{i_layer+1}'
    df_dead_individual = pd.DataFrame({'Dead': dead_individual_per_run[:, i_layer], 'Model': 'Individual', 'Layer': layer})
    df_dead_parallel = pd.DataFrame({'Dead': dead_parallel[:, i_layer], 'Model': 'Parallel', 'Layer': layer})
    df_dead_task_switching = pd.DataFrame({'Dead': dead_task_switching[:, i_layer], 'Model': 'Task Switching', 'Layer': layer})
    dead_frames.extend([df_dead_individual, df_dead_parallel, df_dead_task_switching])

if dead_frames:
    df_dead_all = pd.concat(dead_frames)
else:
    # No layers: keep the original result, an empty frame with the expected columns.
    df_dead_all = pd.DataFrame({}, columns=['Dead', 'Model', 'Layer'])

In [None]:
# Summary figure: sparsity (left) and dead units (right) per layer, one bar per
# model, with the standard deviation as error bar.
fig, ax = plt.subplots(1, 2, figsize=(12, 4))

# NOTE(review): `ci='sd'` is reported as deprecated by recent seaborn releases in
# favour of `errorbar='sd'` -- switch once the minimum seaborn version is confirmed.
sns.barplot(data=df_sparsity_all, x="Layer", y="Sparsity", hue='Model', ci='sd', ax=ax[0])
sns.barplot(data=df_dead_all, x="Layer", y="Dead", hue='Model', ci='sd', ax=ax[1])
fig.suptitle(f'Num. Layers: {num_layers}   Num. Units: {num_hidden[0]}')

# Fixed y-ranges for the two panels.
ax[0].set_ylim(40, 100)
ax[1].set_ylim(0, 100)

fig.tight_layout()

# savefig raises if the target folder is missing (e.g. on a fresh checkout),
# so create it before writing the PDF.
os.makedirs('figures', exist_ok=True)
fig.savefig(f'figures/figS02_sparsity_{num_layers}_{num_hidden[0]}_{num_tasks}.pdf')
plt.show()