In [1]:
import os
import pickle

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch.nn as nn
import torch.optim
from tqdm.auto import tqdm
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression

import multitask.dataset as dataset
from multitask.models.task_switching import get_task_model
import multitask.models.task_switching.hooks as hooks
from multitask.utils.training import get_device
from multitask.utils.argparse import check_runs

In [2]:
# Relative paths (resolved against the working directory, i.e. the notebook folder
# when launched from Jupyter).
path_up_two = os.path.join('..', '..')
path_data = os.path.join(path_up_two, 'data')
path_model_task_switching = os.path.join(path_up_two, 'results', 'task_switching')
# Where this notebook stores its decoding accuracies.
path_pickle = os.path.join('pickle', 'results_linear_decoder_task_switching.pickle')

In [3]:
# Simulation hyper-parameters. They are passed to check_runs below, so they must match
# the stored task-switching run exactly.
num_runs, initial_seed = 10, 1234
max_seed = 10e5  # float literal: 10e5 == 1_000_000.0 (kept verbatim so the stored run still matches)
num_epochs = 50
num_hidden = [100] * 10  # hidden-layer widths: 10 layers of 100 units
batch_size = 100
num_train, num_test = 41080, 8216
tasks_names = ['vowel', 'position']
# One index per hidden layer (0 .. len(num_hidden) - 1); passed to get_task_model,
# presumably selecting the layers that receive the task-context input.
idxs_contexts = [*range(len(num_hidden))]

In [4]:
# Parameter set that identifies the simulation; check_runs searches
# path_model_task_switching for a previous run with the same values and returns its folder.
parameters = dict(
    num_runs=num_runs,
    initial_seed=initial_seed,
    max_seed=max_seed,
    num_epochs=num_epochs,
    num_hidden=num_hidden,
    batch_size=batch_size,
    num_train=num_train,
    num_test=num_test,
    tasks=tasks_names,
    idxs_contexts=idxs_contexts,
)

data_folder = check_runs(path_model_task_switching, parameters)

Found simulation in ../../results/task_switching with the same parameters (2024-01-12_18-56-09)


In [5]:
# Per-seed training results of the matched run (keyed by seed; each entry holds at
# least 'model' and 'indices', as used below).
# NOTE: pickle.load can execute arbitrary code — only load result files produced locally.
pickle_data = os.path.join(data_folder, 'data.pickle')
with open(pickle_data, 'rb') as results_file:
    results_task_switching = pickle.load(results_file)

In [6]:
# Seeds of the trained networks, in ascending order (fixes the row order of all results).
seeds = sorted(results_task_switching)
num_seeds, num_tasks = len(seeds), len(tasks_names)

print(seeds)
print(tasks_names)

[165158, 220532, 318129, 451283, 486191, 514041, 818831, 869016, 908341, 978124]
['vowel', 'position']


In [7]:
tasks_datasets = dataset.get_tasks_dict(tasks_names, root=path_data)

num_tasks = len(tasks_names)
# Each task gets its dataset plus a one-hot context vector: entry i is 1 only for the
# i-th task in tasks_names.
task_switching_tasks = {
    task_name: {
        'data': tasks_datasets[task_name],
        'activations': [int(position == i_context) for position in range(num_tasks)],
    }
    for i_context, task_name in enumerate(tasks_names)
}

for task_name, task_config in task_switching_tasks.items():
    print(f'{task_name}: {task_config["activations"]}')

vowel: [1, 0]
position: [0, 1]


In [8]:
device = get_device()
criterion = nn.CrossEntropyLoss()

# For every seed (same order as `seeds`): the hidden-layer activations on that seed's
# test split, and the letter of each test sample.
list_activations = []
list_letters = []

for seed in tqdm(seeds):
    model = get_task_model(task_switching_tasks,
                           num_hidden,
                           idxs_contexts,
                           device)
    model.load_state_dict(results_task_switching[seed]['model'])

    indices = results_task_switching[seed]['indices']
    _, test_dataloaders = dataset.create_dict_dataloaders(task_switching_tasks,
                                                          indices,
                                                          batch_size=batch_size)
    tasks_testloader = dataset.SequentialTaskDataloader(test_dataloaders)

    # NOTE(review): the decoders pair the i-th activation row with letters[indices['test']][i],
    # which assumes the test dataloaders iterate indices['test'] in order (no shuffling) —
    # confirm in dataset.create_dict_dataloaders.
    letters = test_dataloaders[tasks_names[0]].dataset.letters.numpy()
    letters = letters[indices['test']]

    _, activations = hooks.get_layer_activations(model,
                                                tasks_testloader,
                                                criterion,
                                                device=device,
                                                disable=True)

    list_activations.append(activations)
    list_letters.append(letters)

Running on GPU.


  0%|          | 0/10 [00:00<?, ?it/s]

In [9]:
num_layers = len(num_hidden)
max_iter = 8000

# Letters are assumed to be alphabet indices (0 = 'a', ..., 25 = 'z').
VOWELS = (0, 4, 8, 14, 20)  # a, e, i, o, u
POSITION_SPLIT = 13  # letters with index < 13 get label 1 in the position task


def output_labels(task_name, letters):
    """Return the 0/1 response each letter requires under ``task_name``.

    'vowel' -> 1 for vowels; 'position' -> 1 for letters before POSITION_SPLIT.
    Raises ValueError for any other task name.
    """
    letters = np.asarray(letters)
    if task_name == 'vowel':
        return np.isin(letters, VOWELS).astype(int)
    if task_name == 'position':
        return (letters < POSITION_SPLIT).astype(int)
    raise ValueError(f'Unknown task: {task_name}')


def fit_linear_decoder(features, labels, seed):
    """Fit a logistic-regression decoder on a 90/10 split and return its test accuracy.

    ``seed`` drives both the split and the classifier. Because the split does not depend
    on the labels, all targets of one seed/layer share the same train/test partition.
    """
    X_train, X_test, y_train, y_test = train_test_split(features,
                                                        labels,
                                                        test_size=0.1,
                                                        random_state=seed)
    clf = LogisticRegression(random_state=seed,
                             max_iter=max_iter,
                             tol=1e-3).fit(X_train, y_train)
    return clf.score(X_test, y_test)


acc_letters_all = np.zeros((num_seeds, num_layers))
acc_tasks_all = np.zeros((num_seeds, num_layers))
acc_congruency_all = np.zeros((num_seeds, num_layers))
acc_output_all = np.zeros((num_seeds, num_layers))

for i_seed, seed in enumerate(seeds):
    activations = list_activations[i_seed]
    letters = list_letters[i_seed]

    # Activations are stacked task by task (tasks_names order) below, so each per-sample
    # label is the per-letter label repeated/concatenated in that same order.
    labels_letters = np.tile(letters, num_tasks)
    labels_task = np.repeat(np.arange(num_tasks), len(letters))
    labels_output = np.concatenate([output_labels(task, letters) for task in tasks_names])
    # Congruent = the vowel and position rules demand the same response: a, e, i and the
    # second-half consonants. 'o' (14) and 'u' (20) are second-half vowels, hence incongruent.
    labels_congruency = (output_labels('vowel', labels_letters)
                         == output_labels('position', labels_letters)).astype(int)

    # (accuracy array, labels) in decoding order: letters, task, congruency, output.
    decoding_targets = ((acc_letters_all, labels_letters),
                        (acc_tasks_all, labels_task),
                        (acc_congruency_all, labels_congruency),
                        (acc_output_all, labels_output))

    for j_layer in tqdm(range(num_layers), desc=f'{i_seed} [{seed}]'):
        activations_decoder = np.vstack([activations[task][f'layer{j_layer+1}']
                                         for task in tasks_names])
        assert activations_decoder.shape[0] == labels_letters.shape[0]

        # Global (scalar) z-score of the whole layer; avoid dividing by zero for a constant layer.
        std = activations_decoder.std()
        activations_decoder = (activations_decoder - activations_decoder.mean()) / (std if std > 0 else 1.0)

        for acc_all, labels in decoding_targets:
            acc_all[i_seed, j_layer] = fit_linear_decoder(activations_decoder, labels, seed)

0 [165158]:   0%|          | 0/10 [00:00<?, ?it/s]

1 [220532]:   0%|          | 0/10 [00:00<?, ?it/s]

2 [318129]:   0%|          | 0/10 [00:00<?, ?it/s]

3 [451283]:   0%|          | 0/10 [00:00<?, ?it/s]

4 [486191]:   0%|          | 0/10 [00:00<?, ?it/s]

5 [514041]:   0%|          | 0/10 [00:00<?, ?it/s]

6 [818831]:   0%|          | 0/10 [00:00<?, ?it/s]

7 [869016]:   0%|          | 0/10 [00:00<?, ?it/s]

8 [908341]:   0%|          | 0/10 [00:00<?, ?it/s]

9 [978124]:   0%|          | 0/10 [00:00<?, ?it/s]

In [10]:
# Decoding accuracies, each of shape (num_seeds, num_layers), keyed by decoded variable.
results = {
    'letters': acc_letters_all,
    'tasks': acc_tasks_all,
    'congruency': acc_congruency_all,
    'output': acc_output_all,
}

In [11]:
# Persist the accuracies; create the output directory first so a fresh checkout does not
# fail with FileNotFoundError.
os.makedirs(os.path.dirname(path_pickle) or '.', exist_ok=True)
with open(path_pickle, 'wb') as f:
    pickle.dump(results, f, protocol=pickle.HIGHEST_PROTOCOL)