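# Functional Specialization Metrics

This notebook trains a two-agent RNN community on the `parity_digits` task. It then measures how functionally specialized the agents are, using three metrics: activation correlation, readout retraining (bottleneck), and weight masks.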

## Imports

In [1]:
import torch.nn as nn
import torch
#from tqdm import tqdm, trange
from tqdm.notebook import tqdm as tqdm_n

In [2]:
from community.data.datasets import get_datasets
from community.data.process import temporal_data
from community.common.init import init_community, init_optimizers
from community.common.utils import plot_grid
from community.common.training import train_community
from community.funcspec.metrics.correlation import fixed_information_data, get_pearson_metrics, compute_correlation_metric
from community.funcspec.metrics.bottleneck import readout_retrain, compute_bottleneck_metrics
from community.funcspec.metrics.masks import train_and_get_mask_metric, compute_mask_metric

In [3]:
import warnings
#warnings.filterwarnings('ignore')

In [4]:
%load_ext autoreload
%autoreload 2
%aimport community.funcspec.metrics.masks



## Datasets

In [5]:
# Run configuration: seed, data location, device and batch size.
SEED = 0
DATA_DIR = '../data'

# Seed torch's global RNG before anything random happens (dataset creation, parameter
# init, any DataLoader shuffling), so Restart & Run All gives reproducible results.
# NOTE(review): if the project code also draws from numpy/`random`, seed those too.
torch.manual_seed(SEED)

use_cuda = False
device = torch.device("cuda" if use_cuda else "cpu")
batch_size = 32
multi_loaders, double_loaders, single_loaders = get_datasets(DATA_DIR, batch_size, use_cuda)

## Community Initialization

In [6]:
# Per-agent architecture of the community: 2 agents, 784 inputs, 100 hidden units,
# 1 layer, 10 outputs, with a readout and no bottleneck.
agents_params_dict = dict(
    n_agents=2,
    n_in=784,
    n_ins=None,
    n_hid=100,
    n_layer=1,
    n_out=10,
    train_in_out=(True, False),
    use_readout=True,
    # The class is passed as its string form, i.e. "<class 'torch.nn.modules.rnn.RNN'>".
    cell_type=str(nn.RNN),
    use_bottleneck=False,
    dropout=0,
)

# Presumably the inter-agent connection probability: with n_hid=100 the saved output
# shows 10 connections per direction (100 * 100 * 1e-3). TODO confirm in init_community.
p_con = 1e-3

community = init_community(agents_params_dict, p_con, device=device)
community.nb_connections

{'01': tensor(10), '10': tensor(10)}

In [7]:
# Hyper-parameters of the standard optimizer/scheduler pair.
# NOTE(review): gamma is presumably the scheduler decay factor — confirm in init_optimizers.
params = lr, gamma = 1e-3, 0.95
params_dict = {'lr' : lr, 'gamma' : gamma}

# DeepR hyper-parameters. lr/gamma get their own names so this line no longer silently
# rebinds the `lr` / `gamma` defined above.
deepR_params = l1, gdnoise, deepR_lr, deepR_gamma, cooling = 1e-5, 1e-3, 1e-3, 0.95, 0.95
deepR_params_dict = {'l1' : l1, 'gdnoise' : gdnoise, 'lr' : deepR_lr, 'gamma' : deepR_gamma, 'cooling' : cooling}

optimizers, schedulers = init_optimizers(community, params_dict, deepR_params_dict)

In [8]:
# Inspect the optimizers and schedulers returned by init_optimizers in the previous cell.
optimizers, schedulers

In [None]:
# Preview the first samples of one batch from multi_loaders[1], with their targets.
# Each sample is reshaped to 1 x 56 x 28 (presumably two stacked 28x28 digits — TODO confirm).
data, target = next(iter(multi_loaders[1]))
n_show = min(10, len(data))  # avoid IndexError when batch_size < 10
plot_grid([data[i, :, ...].reshape(1, 56, 28) for i in range(n_show)],
          [target[i] for i in range(n_show)], figsize=(20, 2))

## Training

In [None]:
# Training configuration handed to train_community via `config=`.
training_dict = dict(
    n_epochs=5,
    task='parity_digits',
    global_rewire=True,
    check_gradients=False,
    reg_factor=0.,
    train_connections=True,
    decision_params=('last', 'max'),
    early_stop=True,
    deepR_params_dict=deepR_params_dict,
)

# All multi-digit loaders are unpacked positionally, followed by the optimizers.
train_out = train_community(
    community, *multi_loaders, optimizers,
    schedulers=schedulers, config=training_dict, device=device,
)
# Alias of train_out; not used in the visible cells, kept in case later cells rely on it.
results = train_out

Train Epoch::   0%|          | 0/5 [00:00<?, ?it/s]

KeyboardInterrupt: 

## Metrics

### Correlation

In [None]:
# Take one batch, convert it with temporal_data, and build the "fixed information" variants.
# NOTE(review): 'label' / i=1 presumably hold one digit's label fixed — confirm in fixed_information_data.
datas, label = next(iter(multi_loaders[0]))
datas = temporal_data(datas).to(device)

# For each variant, take the first sample, keep the first 3 entries along dim 2, move them
# to the front, and reshape each one to a 1 x 56 x 28 image for plotting.
fixed_datas = [
    [img.reshape(1, 56, 28) for img in variant[0, :, :3, :].transpose(0, 1).cpu()]
    for variant in fixed_information_data(datas, label, 'label', i=1)
]
labels = [[]]
plot_grid(fixed_datas, figsize=(3, len(fixed_datas)))

In [None]:
# Correlation-based specialization metric over the multi-digit loaders, with a progress bar.
correlations = get_pearson_metrics(community, multi_loaders, device=device, use_tqdm=True)

In [None]:
# Average `correlations` over its last two axes (mean over -1, applied twice).
# NOTE(review): axis meanings are not visible here — TODO document what the remaining axes index.
correlations.mean(-1).mean(-1)

### Bottleneck

In [None]:
# Bottleneck metric: retrain the readout for a single epoch; the result carries an 'accs' entry.
bottleneck_metric = readout_retrain(community, multi_loaders, n_epochs=1, use_tqdm=True, device=device)

In [None]:
# Mean of the retrained accuracies over axis 0, then the max over the last axis.
# NOTE(review): if 'accs' is a torch tensor, .max(-1) returns (values, indices), not just values.
bottleneck_metric['accs'].mean(0).max(-1)

### Weight Masks

In [9]:
# Weight-mask metric, single trial (n_tests=1), with notebook progress bars.
# NOTE(review): the meaning of the positional 0.1 (sparsity? threshold?) is not visible — TODO name it.
masks_metric = train_and_get_mask_metric(community, 0.1, multi_loaders, n_tests=1, notebook=True, device=device)

Metric Trials :   0%|          | 0/1 [00:00<?, ?it/s]

Train Epoch::   0%|          | 0/1 [00:00<?, ?it/s]

Train Epoch::   0%|          | 0/1 [00:00<?, ?it/s]

In [11]:
# First element of the mask metric. In the saved run it is a (1, 2, 2) array whose rows
# each sum to 1 — TODO document what rows and columns represent.
masks_metric[0]

array([[[0.90414041, 0.09585959],
        [0.        , 1.        ]]])