# Functional Specialization metrics

### Imports

In [None]:
import matplotlib.patches as pat
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torchvision.transforms.functional as TF
from PIL import Image
#from tqdm import tqdm, trange
from tqdm.notebook import tqdm as tqdm_n

In [None]:
from community.data.datasets import get_datasets
from community.data.process import temporal_data
from community.common.init import init_community, init_optimizers
from community.common.utils import plot_grid
from community.common.training import train_community
from community.funcspec.correlation import fixed_information_data, get_pearson_metrics, compute_correlation_metric, plot_correlations
from community.funcspec.bottleneck import readout_retrain, compute_bottleneck_metrics
from community.funcspec.masks import train_and_get_mask_metric, compute_mask_metric

In [None]:
import warnings
#warnings.filterwarnings('ignore')

In [None]:
%load_ext autoreload
%aimport community.funcspec.masks
%autoreload 2



# Datasets

In [None]:
# Data loaders. Everything runs on the CPU unless use_cuda is flipped to True.
use_cuda = False
batch_size = 256
device = torch.device("cuda" if use_cuda else "cpu")
# get_datasets returns three loader groups; all three are kept for later cells.
multi_loaders, double_loaders, single_loaders = get_datasets('../data', batch_size, use_cuda)

### Community Initialization

In [None]:
# Architecture of each agent in the community (keys are consumed by init_community).
agents_params_dict = dict(
    n_agents=2,
    n_in=784,
    n_ins=None,
    n_hid=100,
    n_layer=1,
    n_out=10,
    train_in_out=(True, False),
    use_readout=True,
    cell_type=str(nn.RNN),
    use_bottleneck=False,
    dropout=0,
)

# Connection probability between agents.
p_con = 1e-3

community = init_community(agents_params_dict, p_con, device=device, use_deepR=False)
community.nb_connections

In [None]:
# Base optimizer settings.
lr, gamma = 1e-3, 0.95
params = (lr, gamma)
params_dict = dict(lr=lr, gamma=gamma)

# DeepR settings. Note that this rebinds lr and gamma (to the same values as above).
l1, gdnoise, lr, gamma, cooling = 1e-5, 1e-3, 1e-3, 0.95, 0.95
deepR_params = (l1, gdnoise, lr, gamma, cooling)
deepR_params_dict = dict(l1=l1, gdnoise=gdnoise, lr=lr, gamma=gamma, cooling=cooling)

optimizers, schedulers = init_optimizers(community, params_dict, deepR_params_dict)

In [None]:
# Intentionally empty. A stray expression here raised NameError on a fresh kernel.

In [None]:
# Show the first `n_show` samples of one batch from multi_loaders[0] with their targets.
# Each sample is reshaped to a (1, 56, 28) image for display. These are presumably two
# 28x28 digits stacked vertically; confirm against get_datasets.
n_show = 10
data, target = next(iter(multi_loaders[0]))
n_show = min(n_show, len(data))  # guard against batches smaller than n_show
plot_grid([data[i, :, ...].reshape(1, 56, 28) for i in range(n_show)],
          [target[i] for i in range(n_show)], figsize=(20, 2))

## Training

In [None]:
# Training configuration handed to train_community via `config=`.
training_dict = dict(
    n_epochs=5,
    task='parity_digits',
    global_rewire=True,
    check_gradients=False,
    reg_factor=0.,
    train_connections=True,
    decision_params=('last', 'max'),
    early_stop=True,
    deepR_params_dict=deepR_params_dict,
)

# `results` and `train_out` are the same object; both names are kept for later cells.
results = train_out = train_community(
    community, *multi_loaders, optimizers,
    schedulers=schedulers, config=training_dict, device=device,
)

# Metrics

### Correlation 

In [None]:
# Build input variants with fixed_information_data(..., 'label', i=1) and plot them.
# For each variant, take element 0 and the first `n_show` slices along axis 2. These are
# shown as (1, 56, 28) images; axis meanings depend on temporal_data, so confirm them.
n_show = 10
datas, label = next(iter(multi_loaders[0]))
datas = temporal_data(datas).to(device)
fixed_raw = fixed_information_data(datas, label, 'label', i=1)
fixed_datas = [[d.reshape(1, 56, 28) for d in data[0, :, :n_show, :].transpose(0, 1).cpu()]
               for data in fixed_raw]
plot_grid(fixed_datas, figsize=(10, 1 * len(fixed_datas)))

In [None]:
# Pearson-correlation metrics of the community as it currently stands.
corrs = get_pearson_metrics(community, multi_loaders, device=device, use_tqdm=False)

In [None]:
# Shape of the Pearson-correlation metrics computed in the previous cell.
corrs.shape

In [None]:
from community.common.utils import get_wandb_artifact

# Saved community state dicts from the 'funcspec' W&B project. Only the first returned item is kept.
community_states, *_ = get_wandb_artifact(project='funcspec', name='state_dicts')

In [None]:
# Pearson metrics for every saved community state, grouped by the keys of community_states.
# The loop variable is deliberately not `p_con`, so it does not clobber the connection
# probability set in the Community Initialization cell.
# NOTE(review): load_state_dict overwrites `community` in place. Cells run after this one
# therefore see the last state loaded here, not the community trained above.
correlations = {}
for con_key, states in tqdm_n(community_states.items()):
    per_state = []
    for state in states:
        community.load_state_dict(state)
        per_state.append(get_pearson_metrics(community, multi_loaders, use_tqdm=False, device=device))
    correlations[con_key] = np.array(per_state)

In [None]:
# Wrap the per-state correlations under the metric name that plot_correlations receives.
pearson_correlations = {'Pearson_Label': correlations}

In [None]:
# Plot the stored 'Pearson_Label' correlations.
plot_correlations(pearson_correlations)

### Bottleneck

In [None]:
# Bottleneck metric: retrain the readout for a single epoch on the current community.
bottleneck_metric = readout_retrain(community, multi_loaders, n_epochs=1, use_tqdm=True, device=device)

In [None]:
# Summarise the 'accs' entry: average over axis 0, then take the max over the last axis.
# NOTE(review): axis meanings depend on readout_retrain's output layout; confirm them. If
# 'accs' is a torch tensor, .max(-1) returns a (values, indices) pair, not just values.
bottleneck_metric['accs'].mean(0).max(-1)

### Weight Masks

In [None]:
# Mask metric for the current community. The positional 0.1 is interpreted by
# train_and_get_mask_metric (presumably a sparsity level; confirm).
masks_metric = train_and_get_mask_metric(community, 0.1, multi_loaders, n_tests=1, device=device)

In [None]:
# Unpack the 4-item result of train_and_get_mask_metric; the third item is discarded.
masks_prop, masks_accs, _, masks_states = masks_metric

In [None]:
# Shape of the mask accuracies unpacked above.
masks_accs.shape

In [None]:
from community.common.utils import get_wandb_runs
import wandb, torch, numpy as np

In [None]:
# Load the saved 'Masks' metric of a W&B run and assign a run to the global `wandb.run`.
# NOTE(review): the metric path is read from runs[0], but wandb.run is set to runs[-1].
# These differ whenever more than one run is returned; confirm which is intended.
# NOTE(review): torch.load unpickles the file, so only use it on trusted metric files.
runs = get_wandb_runs(run_id=None)
metric_path = runs[0].config['saves']['metrics_save_path'] + 'Masks'
mask_metric = torch.load(metric_path)
wandb.run = runs[-1]

In [None]:
from community.funcspec.masks import plot_mask_metric, get_metrics_from_saved_masks

In [None]:
# Recompute metrics from the saved masks with sparsities=[0.2].
# NOTE(review): this rebinds `mask_metric` to the function's own output. Re-running this
# cell alone feeds processed metrics back in, so re-run the loading cell first.
mask_metric = get_metrics_from_saved_masks(mask_metric, sparsities=[0.2])

In [None]:
# Plot the mask metric produced by the previous cell.
plot_mask_metric(mask_metric)

# New Polygon Task

In [None]:
from math import radians, pi, cos, sin
import numpy as np
def draw_polygon(sides, x0, y0, r=1, rotate=0):
    """Return the vertices of a regular polygon as a closed outline.

    Before rotation, the first vertex sits straight above the centre (angle pi/2).
    Vertices then proceed counter-clockwise. The first vertex is repeated at the
    end, so the result can be passed straight to ``plt.plot`` to draw a closed shape.

    Args:
        sides (int): Number of polygon sides (must be >= 1).
        x0, y0 (float): Coordinates of the centre point.
        r (float): Circumradius, i.e. the distance from the centre to each vertex.
        rotate (float): Counter-clockwise rotation in degrees about the centre.

    Returns:
        np.ndarray: Array of shape ``(sides + 1, 2)`` whose rows are ``[x, y]``.

    Raises:
        ValueError: If ``sides`` is less than 1.
    """
    if sides < 1:
        raise ValueError(f"sides must be >= 1, got {sides}")
    # Start at the top of the circle, then apply the requested rotation.
    theta = pi / 2 + radians(rotate)
    # Take sides + 1 steps so the last vertex coincides with the first (closed path).
    angles = theta + 2.0 * pi * np.arange(sides + 1) / sides
    return np.column_stack((r * np.cos(angles) + x0, r * np.sin(angles) + y0))


In [None]:
# Preview the vertex array of the 15-sided polygon that is rendered in the next cell.
draw_polygon(15, 0, 0)

In [None]:
# Render a 15-gon outline to a small PNG and load it back as a single-channel tensor.
# Then bring it to the same spatial size as a zero-padded digit so the two can be overlaid.
poly = draw_polygon(15, 0, 0)
fig, ax = plt.subplots(figsize=(.7, .7))
ax.plot(poly[:, 0], poly[:, 1], linewidth=.1)
ax.axis('off')
fig.tight_layout(pad=0.)
fig.savefig('poly.png')

# Use a context manager so the file handle is closed once the pixels are read.
# Keep only the first channel.
with Image.open('poly.png') as poly_img:
    poly_tensor = TF.to_tensor(poly_img)[:1]

# First sample of the first batch drawn from single_loaders[0].
digit = next(iter(single_loaders[0]))[0][0]
# Pad the digit symmetrically so its width roughly matches the rendered polygon's.
pad = (poly_tensor.shape[1] - digit.shape[-1]) // 2
digits = TF.normalize(TF.pad(digit, [pad, pad], fill=0), 1, 1)
poly_tensor = TF.normalize(TF.resize(poly_tensor, digits.shape[1]), 1, 1)
# The overlay adds these element-wise, so their spatial sizes must agree.
assert poly_tensor.shape[-2:] == digits.shape[-2:], (poly_tensor.shape, digits.shape)
poly_tensor.shape, digits.shape


In [None]:
# Overlay: invert the polygon image, add the padded digit, and keep the first channel.
final = (1 - poly_tensor + digits).detach()[0].numpy()
plt.imshow(final)

In [None]:
# Sanity check: maximum value of the padded, normalized digit.
digits.max()

In [None]:
from PIL import Image
import torchvision.transforms.functional as TF
from torchvision.transforms import ToTensor, Resize, InterpolationMode, Pad

In [None]:
# Draw a 10-sided matplotlib RegularPolygon patch (default radius) as an outline.
# A Patch cannot be passed to plt.plot; it has to be added to an Axes.
poly = pat.RegularPolygon((0, 0), 10, fill=False)
fig, ax = plt.subplots(figsize=(2, 2))
ax.add_patch(poly)
# Make sure the view limits cover the patch.
ax.autoscale_view()
ax.set_aspect('equal')

In [None]:
# Scatter the polygon patch's vertices in data coordinates. get_path() gives the unit
# polygon; get_patch_transform() applies the patch's centre, radius and orientation.
poly_vertices = poly.get_patch_transform().transform(poly.get_path().vertices)
plt.scatter(poly_vertices[:, 0], poly_vertices[:, 1]);

In [None]:
# Intentionally empty. A stray expression here raised NameError on a fresh kernel.