In [1]:
import os
import pickle

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch.nn as nn
import torch.optim
from tqdm.auto import tqdm
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression

import multitask.dataset as dataset
from multitask.models.parallel import get_parallel_model, calculate_rdm, plot_rdm
import multitask.models.parallel.hooks as hooks
from multitask.utils.training import get_device
from multitask.utils.argparse import check_runs


In [2]:
# Filesystem layout, relative to the notebook's working directory.
# Both the data and the results folders live two directories up.
_root = os.path.join('..', '..')

# Folder handed to the dataset loader as `root`.
path_data = os.path.join(_root, 'data')
# Folder searched for a previous run of the parallel model.
path_model_individual = os.path.join(_root, 'results', 'parallel')
# File where this notebook stores the decoder accuracies.
path_pickle = os.path.join('pickle', 'results_linear_decoder_parallel.pickle')

In [3]:
# Experiment configuration. These values are bundled into the `parameters`
# dict that is used to look up a stored simulation, so they are kept exactly
# as they were when the run was produced.

# --- runs / seeding ---
num_runs = 10
initial_seed = 6789
max_seed = 10e5  # NOTE: a float (1000000.0), not an int

# --- network and optimisation ---
num_epochs = 50
num_hidden = [100] * 10  # one entry per hidden layer (num_layers = len(num_hidden))
batch_size = 100

# --- data ---
num_train = 50_000
num_test = 10_000
tasks_names = ['parity', 'value']
idxs_contexts = None  # NOTE(review): meaning not visible in this notebook -- confirm

In [4]:
# Bundle the experiment settings under the keys passed to `check_runs`.
# NOTE(review): presumably these must match the metadata saved with the run
# exactly (including the float `max_seed`) -- confirm against `check_runs`.
parameters = {
    'num_runs': num_runs,
    'initial_seed': initial_seed,
    'max_seed': max_seed,
    'num_epochs': num_epochs,
    'num_hidden': num_hidden,
    'batch_size': batch_size,
    'num_train': num_train,
    'num_test': num_test,
    'tasks': tasks_names,
    'idxs_contexts': idxs_contexts
}

# Resolve the folder of the stored simulation for these parameters; the
# returned path is used below to locate that run's `data.pickle`.
data_folder = check_runs(path_model_individual, parameters)

Found simulation in ../../results/parallel with the same parameters (2022-09-28_01_56_10)


In [5]:
# Load the stored results of the matching run. The object is indexed by seed
# below, and each entry is read for its 'model' (state dict) and 'indices'.
# SECURITY: unpickling can execute arbitrary code -- only load result files
# produced by a trusted run of this project.
pickle_data = os.path.join(data_folder, 'data.pickle')
with open(pickle_data, 'rb') as handle:
    results_parallel = pickle.load(handle)

In [6]:
# The stored results are keyed by seed; iterate over them in ascending order.
seeds = sorted(results_parallel)
num_seeds, num_tasks = len(seeds), len(tasks_names)

print(seeds)
print(tasks_names)

[10612, 17350, 130146, 173249, 213794, 341996, 440064, 668870, 858781, 894813]
['parity', 'value']


In [7]:
# Build the datasets for the requested tasks from the data folder; the result
# is indexed by task name in the next cell.
tasks_datasets = dataset.get_tasks_dict(tasks_names, root=path_data)

In [8]:
# Keep only the datasets of the tasks under study, in `tasks_names` order,
# and wrap them in a single multilabel dataset for the parallel network.
parallel_datasets = {name: tasks_datasets[name] for name in tasks_names}

parallel_tasks = dataset.MultilabelTasks(parallel_datasets)

In [9]:
device = get_device()
criterion = nn.CrossEntropyLoss()
num_layers = len(num_hidden)
num_tasks = len(tasks_names)

# Digit label of every sample, read from the first task's dataset. This does
# not depend on the seed, so it is fetched once outside the loop.
# NOTE(review): assumes all task datasets share the same `numbers` in the same
# order -- confirm.
all_numbers = parallel_datasets[tasks_names[0]].numbers

list_activations = []  # per seed: {task_name: activations}
list_numbers = []      # per seed: digit labels of that seed's test samples

# Iterate over the seeds actually present in the stored results, so the
# progress bar total is right even if that differs from the `num_runs` setting.
for seed in tqdm(seeds):
    # Rebuild the network and restore the weights trained with this seed.
    state_dict = results_parallel[seed]['model']
    model = get_parallel_model(num_tasks, num_hidden, device)
    model.load_state_dict(state_dict)
    # Record activations in inference mode (matters only if the model has
    # dropout / batch-norm layers; harmless otherwise).
    model.eval()

    # Visit the test samples of this seed in the stored order.
    indices = results_parallel[seed]['indices']
    test_sampler = dataset.SequentialSampler(indices['test'])
    parallel_testloader = torch.utils.data.DataLoader(parallel_tasks,
                                                      sampler=test_sampler,
                                                      batch_size=batch_size)

    # Digit labels of the same test samples.
    # NOTE(review): presumably row-aligned with the activations because the
    # sampler walks `indices['test']` sequentially -- confirm.
    numbers = all_numbers[indices['test']]

    _, activations = hooks.get_layer_activations(model,
                                                 parallel_testloader,
                                                 criterion=criterion,
                                                 device=device,
                                                 disable=True)

    # The network is run once per seed, so every task name maps to the *same*
    # activations object.
    parallel_activations = {task_name: activations for task_name in tasks_names}

    list_activations.append(parallel_activations)
    list_numbers.append(numbers)

Running on GPU.


  0%|          | 0/10 [00:00<?, ?it/s]

In [10]:
num_layers = len(num_hidden)
max_iter = 8000

# Digit groups used to build the binary labels.
EVEN_DIGITS = [0, 2, 4, 6, 8]            # label 1 for the parity output
LOW_DIGITS = [0, 1, 2, 3, 4]             # label 1 for the value output
CONGRUENT_DIGITS = [0, 2, 4, 5, 7, 9]    # parity and value labels coincide


def _decoder_accuracy(features, labels, seed, max_iter=max_iter, tol=1e-3,
                      test_size=0.1):
    """Return the held-out accuracy of a logistic-regression decoder.

    Args:
        features: 2D array of samples x units used as decoder input.
        labels: 1D array with one target label per row of `features`.
        seed: Random state for both the train/test split and the classifier.
        max_iter: Maximum number of solver iterations.
        tol: Stopping tolerance of the solver.
        test_size: Fraction of the samples held out for scoring.

    Returns:
        Mean accuracy of the fitted classifier on the held-out split.
    """
    X_train, X_test, y_train, y_test = train_test_split(features,
                                                        labels,
                                                        test_size=test_size,
                                                        random_state=seed)
    clf = LogisticRegression(random_state=seed,
                             max_iter=max_iter,
                             tol=tol)
    clf.fit(X_train, y_train)
    return clf.score(X_test, y_test)


# One accuracy per (seed, layer) for each decoded variable.
acc_numbers_all = np.zeros((num_seeds, num_layers))
acc_tasks_all = np.zeros((num_seeds, num_layers))
acc_congruency_all = np.zeros((num_seeds, num_layers))
acc_output_all = np.zeros((num_seeds, num_layers))


for i_seed, seed in enumerate(seeds):
    activations = list_activations[i_seed]
    numbers = list_numbers[i_seed]

    # The decoder input stacks one copy of the activations per task (two
    # tasks here), so the per-sample labels are repeated to match.
    labels_numbers = np.hstack((numbers, numbers))
    # Task identity: 0 for the first stacked copy, 1 for the second.
    labels_task = np.concatenate((np.zeros_like(numbers), np.ones_like(numbers)))
    # 1 where the digit receives the same label under both tasks.
    labels_congruency = np.isin(labels_numbers, CONGRUENT_DIGITS).astype(int)

    # Expected output: parity label for the first copy, value label for the second.
    labels_output_parity = np.isin(numbers, EVEN_DIGITS).astype(int)
    labels_output_value = np.isin(numbers, LOW_DIGITS).astype(int)
    labels_output = np.concatenate((labels_output_parity, labels_output_value))

    # Result array and target labels of each decoder, fitted in this order.
    decoding_targets = (
        (acc_numbers_all, labels_numbers),        # digit identity
        (acc_tasks_all, labels_task),             # task identity
        (acc_congruency_all, labels_congruency),  # congruency
        (acc_output_all, labels_output),          # expected output
    )

    for j_layer in tqdm(range(num_layers), desc=f'{i_seed} [{seed}]'):
        # Stack the activations of this layer for every task along the sample axis.
        # NOTE(review): the activation extraction cell stores the same
        # activations under every task name, so each sample appears once per
        # task here; identical rows can therefore land on both sides of the
        # train/test split, and task identity cannot be told apart from the
        # features -- confirm this is the intended control.
        layer_key = f'layer{j_layer+1}'
        activations_decoder = np.vstack([activations[task][layer_key]
                                         for task in tasks_names])
        assert activations_decoder.shape[0] == labels_numbers.shape[0]

        # Global z-score: a single mean and std over all samples and units
        # (not per feature).
        activations_decoder = (activations_decoder - activations_decoder.mean()) / activations_decoder.std()

        for acc_all, labels in decoding_targets:
            acc_all[i_seed, j_layer] = _decoder_accuracy(activations_decoder,
                                                         labels,
                                                         seed)

0 [10612]:   0%|          | 0/10 [00:00<?, ?it/s]

1 [17350]:   0%|          | 0/10 [00:00<?, ?it/s]

2 [130146]:   0%|          | 0/10 [00:00<?, ?it/s]

3 [173249]:   0%|          | 0/10 [00:00<?, ?it/s]

4 [213794]:   0%|          | 0/10 [00:00<?, ?it/s]

5 [341996]:   0%|          | 0/10 [00:00<?, ?it/s]

6 [440064]:   0%|          | 0/10 [00:00<?, ?it/s]

7 [668870]:   0%|          | 0/10 [00:00<?, ?it/s]

8 [858781]:   0%|          | 0/10 [00:00<?, ?it/s]

9 [894813]:   0%|          | 0/10 [00:00<?, ?it/s]

In [11]:
# Collect the (seed x layer) accuracy matrices of every decoder under one dict.
results = {
    'numbers': acc_numbers_all,
    'tasks': acc_tasks_all,
    'congruency': acc_congruency_all,
    'output': acc_output_all,
}

In [12]:
# Make sure the output folder exists: `open(..., 'wb')` does not create
# missing directories and would raise FileNotFoundError on a fresh checkout.
pickle_dir = os.path.dirname(path_pickle)
if pickle_dir:
    os.makedirs(pickle_dir, exist_ok=True)

# Persist the decoder accuracies for later use.
with open(path_pickle, 'wb') as f:
    pickle.dump(results, f, protocol=pickle.HIGHEST_PROTOCOL)