# Functional Specialization metrics

### Imports

In [None]:
import torch.nn as nn
import torch
import numpy as np
#from tqdm import tqdm, trange
from tqdm.notebook import tqdm as tqdm_n

In [None]:
from community.data.datasets import get_datasets
from community.data.process import temporal_data
from community.common.init import init_community, init_optimizers
from community.common.utils import plot_grid
from community.common.training import train_community


In [None]:
import warnings
#warnings.filterwarnings('ignore')

In [None]:
%load_ext autoreload
%aimport community.funcspec.masks
%autoreload 2



# Datasets

In [None]:
# Device and batch configuration for the whole notebook (CPU unless use_cuda is flipped).
use_cuda = False
device = torch.device("cuda" if use_cuda else "cpu")
batch_size = 256
# get_datasets returns five items; only double_loaders_letters is used as the default `loaders` below.
multi_loaders, double_loaders_dig, double_loaders_letters, single_loaders, letters = get_datasets('../data', batch_size, use_cuda, fix_asym=True)
# Loader pair used by the training and metric cells.
loaders = double_loaders_letters

In [None]:
# Gather every target batch from loaders[1] (presumably the test loader — TODO confirm)
# and count how often each distinct target row occurs.
targets = torch.cat([t for _, t in loaders[1]])
uniques, counts = targets.unique(dim=0, return_counts=True)
counts, counts.shape, targets.shape

### Community Initialization

In [62]:
# Architecture parameters passed to init_community.
agents_params_dict = dict(
    n_agents=2,
    n_in=784,
    n_ins=None,
    n_hid=50,
    n_layer=1,
    n_out=10,
    train_in_out=(True, True),
    use_readout=True,
    cell_type=str(nn.RNN),
    use_bottleneck=True,
    ag_dropout=0.05,
)

com_dropout = 0.2

# Connection probability chosen so that p_con * n_hid**2 == 500.
p_con = 500 / agents_params_dict['n_hid'] ** 2

community = init_community(
    agents_params_dict,
    p_con,
    device=device,
    use_deepR=False,
    com_dropout=com_dropout,
)
print(community.nb_connections)

# Hyper-parameters of the regular optimizer.
lr, gamma = 1e-3, 0.95
params = (lr, gamma)
params_dict = dict(lr=lr, gamma=gamma)

# DeepR hyper-parameters; lr and gamma are the same values as above.
l1, gdnoise, cooling = 1e-5, 1e-3, 0.95
deepR_params = (l1, gdnoise, lr, gamma, cooling)
deepR_params_dict = dict(l1=l1, gdnoise=gdnoise, lr=lr, gamma=gamma, cooling=cooling)

optimizers, schedulers = init_optimizers(community, params_dict, deepR_params_dict)

{'01': tensor(500), '10': tensor(500)}


In [47]:
# Display the connection probability used to initialise the community.
p_con

0.2

## Training

In [63]:
# Configuration passed to train_community as `config`.
# n_epochs scales with p_con: with p_con == 0.2 this evaluates to 12.
training_dict = {
    'n_epochs' : 10 + int(p_con*10), 
    'task' : 'parity_digits',
    'global_rewire' : True, 
    'check_gradients' : False, 
    'reg_factor' : 0.,
    'train_connections' : True,
    'decision_params' : ('last', 'max'),
    'stopping_acc' : .9 ,
    'early_stop' : True,
    'deepR_params_dict' : deepR_params_dict,
}

#pyaml.save(training_dict, '../community/common/default_train_dict.yml')

# Train in place; `loaders` is unpacked into positional arguments
# (presumably train and test loaders — TODO confirm against train_community).
train_out = train_community(community, *loaders, optimizers, 
                            schedulers=schedulers, config=training_dict,
                            trials=(True, True), device=device)

# `results` is an alias of the training output, not a copy.
results = train_out
# Remember the best test accuracy on the community object itself.
community.best_acc = results['test_accs'].max()

Train Epoch::   0%|          | 0/12 [00:00<?, ?it/s]

In [None]:
# Average the recorded deciding-agent values over the last two axes.
results['deciding_agents'].mean(-1).mean(-1)

In [64]:
# Test accuracies recorded during training (one entry per epoch in the saved output).
results['test_accs']

array([0.6553125 , 0.79179687, 0.83515625, 0.84828125, 0.86070313,
       0.86765625, 0.87625   , 0.88070313, 0.87984375, 0.88617188,
       0.891875  , 0.8928125 ])

# Metrics

### Correlation 

In [None]:
from community.funcspec.correlation import fixed_information_data, get_pearson_metrics, compute_correlation_metric, plot_correlations, v_pearsonr, get_correlation

In [None]:
# Take one batch from loaders[1] and convert it with temporal_data before moving it to the device.
datas, label = next(iter(loaders[1]))
datas = temporal_data(datas).to(device)
# Build data with fixed information (fixed=0, fixed_mode='label'); exact semantics live in
# community.funcspec.correlation.
fixed_data = fixed_information_data(datas, label, fixed=0, fixed_mode='label')
# For each entry, take index 0 on the first axis, keep the first 10 items of the third axis,
# and reshape every item to a (1, 56, 28) image for plotting.
fixed_datas = [[d.reshape(1, 56, 28) for d in data[0, :, :10, :].transpose(0, 1).cpu()] for data in fixed_data]
# NOTE(review): `labels` is not used anywhere else in the visible notebook.
labels = [[]]
plot_grid(fixed_datas, figsize=(10, 1*len(fixed_datas)))

In [None]:
# Pearson correlation metrics of the trained community on the current loaders.
corrs = get_pearson_metrics(community, loaders, use_tqdm=True, device=device)

In [None]:
# Average the correlations over the last two axes.
corrs.mean(-1).mean(-1)

In [None]:
def diff(n):
    """Normalised contrast between entries [n, n] and [1-n, n] of the averaged correlations."""
    mean_corrs = corrs.mean(-1).mean(-1)
    same, other = mean_corrs[n, n], mean_corrs[1 - n, n]
    return (same - other) / (same + other)

[diff(n) for n in range(2)]

In [None]:
# NOTE(review): a later cell imports get_wandb_artifact from community.common.wandb_utils
# instead of community.common.utils — confirm which module is current.
from community.common.utils import get_wandb_artifact
# Fetch the 'state_dicts' artifact of the 'funcspec' project; only the first returned item is kept.
community_states, *_ = get_wandb_artifact(name='state_dicts', project='funcspec')

In [None]:
# Pearson correlation metrics for every saved community state, grouped by the key
# under which the states were saved.
# The loop variable is deliberately NOT called `p_con`: that name holds the connection
# probability of the community built earlier and is read by the training configuration,
# so rebinding it here would silently change what a re-run of those cells uses.
# NOTE: load_state_dict overwrites the weights of `community`; after this cell it holds
# the last state that was loaded.
correlations = {}
for saved_p_con, states in tqdm_n(community_states.items()):
    state_metrics = []
    for state in states:
        community.load_state_dict(state)
        state_metrics.append(
            get_pearson_metrics(community, multi_loaders, use_tqdm=False, device=device)
        )
    # Stack the per-state results into a single array for this key.
    correlations[saved_p_con] = np.array(state_metrics)

In [None]:
# Wrap the per-p_con correlations under the label expected by plot_correlations.
pearson_correlations = {'Pearson_Label': correlations}

In [None]:
# Plot the collected correlation metrics.
plot_correlations(pearson_correlations)

### Bottleneck

In [None]:
from community.funcspec.bottleneck import readout_retrain, compute_bottleneck_metrics


In [None]:
# Bottleneck metric via readout_retrain: 10 epochs, a single test, train_all_param=False.
bottleneck_metric = readout_retrain(community, loaders, device=device, use_tqdm=True, n_epochs=10, n_tests=1, train_all_param=False)

In [None]:
# Inspect the readout weights of the second agent.
community.agents[1].readout.weight

In [None]:
# Accuracies returned by the bottleneck retraining.
bottleneck_metric['accs']

### Weight Masks

In [None]:
from community.funcspec.masks import train_and_get_mask_metric, compute_mask_metric, Mask_Community, get_proportions, get_proportions_per_agent, train_mask, find_optimal_sparsity

In [None]:
# Train one weight mask on the community.
# NOTE(review): the positional arguments 0.1 and 0 are presumably the sparsity and the
# target index — confirm against train_mask's signature.
masked_community, test_loss, test_accs, best_state = train_mask(community, 0.1, 0, loaders, use_tqdm=True)

In [None]:
# Train masks and collect the mask metric (one test, two epochs, optimal-sparsity search enabled).
# NOTE(review): the positional .5 is presumably a sparsity value — confirm against the function signature.
masks_metric = train_and_get_mask_metric(community, .5, loaders, device=device, n_tests=1, n_epochs=2, use_tqdm=True, use_optimal_sparsity=True)

In [None]:
def diff(n):
    """Normalised contrast between entries [n, n] and [1-n, n] of the first proportions matrix."""
    proportions = masks_metric['proportions'][0]
    same, other = proportions[n, n], proportions[1 - n, n]
    return (same - other) / (same + other)

[diff(n) for n in range(2)]

In [None]:
# Test accuracies recorded while training the masks.
masks_metric['test_accs']

In [None]:
# Sparsities averaged over the first axis.
masks_metric['sparsities'].mean(0)

## WandB Loading

In [9]:
from community.common.wandb_utils import get_wandb_runs, get_wandb_artifact
import wandb, torch, numpy as np

In [21]:
#runs = get_wandb_runs(run_id=None)
# Fetch the 'all_results' artifact of the 'funcspec' project.
artifacts = get_wandb_artifact(project='funcspec', name='all_results')

Found 16 runs, returning...


In [20]:
# Correlation results stored in the first artifact entry.
# NOTE(review): the saved output contains NaN entries for keys 0.1413 and 0.999 — worth investigating.
artifacts[0]['Correlation']

{0.0004: array([[0.68913522, 0.61625586],
        [0.38977652, 0.43961567]]),
 0.0028: array([[0.74053665, 0.61810281],
        [0.46858397, 0.49160229]]),
 0.02: array([[0.51784033, 0.37991701],
        [0.53182244, 0.5562143 ]]),
 0.1413: array([[       nan, 0.48647937],
        [       nan, 0.48874858]]),
 0.999: array([[0.66787579,        nan],
        [0.806576  ,        nan]])}

In [None]:
from community.funcspec.masks import plot_mask_metric, get_metrics_from_saved_masks

In [None]:
# NOTE(review): `mask_metric` is read here before any visible assignment (the earlier cell
# defines `masks_metric`), so this raises NameError on a fresh kernel; the cell is also
# self-referential and therefore not idempotent. Confirm the intended input.
mask_metric = get_metrics_from_saved_masks(mask_metric, sparsities=[0.2])

In [None]:
# Plot the mask metric computed from the saved masks.
plot_mask_metric(mask_metric)

# New Polygon Task

In [None]:
from math import radians, pi, cos, sin
import numpy as np
def draw_polygon(sides, x0, y0, r=1, rotate=0):
    """Return the vertices of a closed n-sided regular polygon.

    Args:
        sides (int): Number of polygon sides.
        x0, y0 (float): Coordinates of the center point.
        r (float): Circumradius, i.e. distance from the center to each vertex.
        rotate (float): Rotation in degrees added to the default orientation,
            in which the first vertex sits at angle pi/2 (directly above the
            center when the y axis points up). Positive values rotate
            counter-clockwise in that coordinate system.
    Returns:
        np.ndarray: Array with one (x, y) row per vertex and ``sides + 1``
        rows in total; the first vertex is repeated as the last row so that
        plotting the rows in order draws a closed outline.
    """
    # Previously `rotate` was accepted but ignored; it now offsets the start angle.
    theta = pi / 2 + radians(rotate)
    # sides + 1 angles: the final one wraps around to close the outline.
    angles = (2.0 * pi * s / sides + theta for s in range(sides + 1))
    return np.array([[r * cos(t) + x0, r * sin(t) + y0] for t in angles])


In [None]:
# NOTE(review): `poly` is only assigned in a later cell, so this fails on a fresh top-to-bottom run.
poly

In [None]:
# Rasterise a 15-sided polygon through a temporary PNG and prepare a digit image of matching size.
# NOTE(review): `TF` and `Image` are imported only in a later cell, and `plt` is never imported
# in the visible notebook — this cell relies on hidden kernel state.
poly = draw_polygon(15, 0, 0)
fig  = plt.figure(figsize=(.7, .7))
plt.plot(poly[:, 0], poly[:, 1], linewidth=.1)
plt.axis('off')
plt.tight_layout(pad=0.)
# The figure is written to the current working directory and read back as an image.
plt.savefig('poly.png')
#rs = Resize((28, 28), interpolation=InterpolationMode.HAMMING)
# Keep only the first channel of the saved image.
poly_tensor = TF.to_tensor(Image.open('poly.png'))[:1]
print(poly_tensor.shape)
# Padding needed on each side to bring a 28-pixel digit up to the polygon image height.
pad = (poly_tensor.shape[1] - 28) // 2
# First sample of the first batch from single_loaders[0], zero-padded, then normalised with mean=1, std=1.
digits = TF.normalize(TF.pad(next(iter(single_loaders[0]))[0][0], [pad, pad], fill=0), 1, 1)
print(digits.shape)
# Resize the polygon to the padded digit's size and apply the same normalisation.
poly_tensor = TF.normalize(TF.resize(poly_tensor, digits.shape[1]), 1, 1)


In [None]:
# Combine the inverted polygon image with the digit image and show the first channel.
final = (1 - poly_tensor.data.numpy() + digits.data.numpy())[0]
plt.imshow(final)

In [None]:
# Sanity check: largest value in the normalised digit tensor.
digits.max()

In [None]:
from PIL import Image
import torchvision.transforms.functional as TF
from torchvision.transforms import ToTensor, Resize, InterpolationMode, Pad

In [None]:
# NOTE(review): `pat` is not imported anywhere in the visible notebook (presumably
# matplotlib.patches — confirm), and this rebinds `poly` from a vertex array to a patch object.
poly = pat.RegularPolygon((0, 0), 10)
# NOTE(review): plt.plot is handed the patch object itself rather than coordinates — verify this is intended.
plt.plot(poly)

In [None]:
# Scatter the vertices of the polygon's path. The original line had an unclosed
# parenthesis (syntax error) and passed the Path object directly; Path.vertices is an
# (N, 2) array, so it is transposed and unpacked into the x and y arguments.
plt.scatter(*poly.get_path().vertices.T)

# Others


In [None]:
from community.common.wandb_utils import get_wandb_runs


In [None]:
# Load metric and training results saved by the latest wandb run.
# NOTE(review): these are absolute, machine-specific paths; they only resolve on the original
# author's machine. torch.load unpickles the files, so only load results you trust.
metrics = torch.load('/home/gb21/Code/ANNs/community-of-agents/wandb/latest-run/files/single/metrics/metric_results')
training = torch.load('/home/gb21/Code/ANNs/community-of-agents/wandb/latest-run/files/single/training/training_results')

In [None]:
# Restore the saved best state stored under keys 0.1 and True
# (presumably connection probability and a boolean run flag — TODO confirm).
community.load_state_dict(training[0.1][True]['best_state'])

In [None]:
# Recompute the Pearson metrics for the restored community on multi_loaders with n_tests=128.
# This rebinds `corrs` from the earlier correlation section.
corrs = get_pearson_metrics(community, multi_loaders, device=device, use_tqdm=True, n_tests=128)

In [None]:
# Correlations averaged over the last two axes.
cor = corrs.mean(-1).mean(-1)
cor

In [None]:
# Saved metric values for key 0.1.
# NOTE(review): this rebinds `cor`, replacing the freshly computed value from the previous cell.
cor = metrics['Correlation'][0.1]
bot = metrics['Bottleneck'][0.1]
mask = metrics['Masks'][0.1]

In [None]:
# Quick bottleneck retraining on multi_loaders (one epoch, one test).
bottleneck = readout_retrain(community, multi_loaders, n_epochs=1, n_tests=1, use_tqdm=True, device=device)

In [None]:
# Average the accuracies over the first axis, then take the maximum over the last axis.
# NOTE(review): this rebinds `bot`, replacing the value loaded from the saved metrics.
bot = bottleneck['accs'].mean(0).max(-1)

In [None]:
# Normalised contrast for column i: (cor[i, i] - cor[1-i, i]) divided by the sum of column i.
i = 1
(cor[i, i] - cor[1-i, i])/ (cor[0, i] + cor[1, i])

In [None]:
def ag_metric(metric, ag):
    """Return the pair (metric[ag, ag], metric[1 - ag, ag])."""
    return metric[ag, ag], metric[1 - ag, ag]


def diff_metric(metric, ag):
    """Return the normalised contrast (a - b) / (a + b) of the pair from ag_metric."""
    own, other = ag_metric(metric, ag)
    return (own - other) / (own + other)


In [None]:
# Display the saved mask metric loaded above.
mask