# Society of Mind

In [None]:
import torch
import numpy as np
import torch.nn as nn
import matplotlib.pyplot as plt
from torchvision import datasets, transforms
import yaml


In [None]:
from community.data.datasets import get_datasets, Custom_EMNIST
from community.common.init import init_community, init_optimizers
from community.common.training import train_community

In [None]:
# Preprocessing applied to every EMNIST image: rotate by -90 degrees, flip
# horizontally, convert to a tensor, then normalize with the given mean/std.
# NOTE(review): the rotate + flip pair presumably undoes EMNIST's transposed
# image layout, and (0.1307, 0.3081) look like the usual MNIST statistics --
# confirm both.
transform = transforms.Compose([
    lambda img: transforms.functional.rotate(img, -90),
    lambda img: transforms.functional.hflip(img),
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])

# Integers 10..20 with 18 removed (ten values in total).
# NOTE(review): presumably the class indices Custom_EMNIST should keep -- confirm.
truncate = np.arange(10, 21)
truncate = truncate[truncate != 18]

emnist = Custom_EMNIST(
    '../data/',
    train=False,
    data_type='byclass',
    truncate=truncate,
    download=True,
    transform=transform,
)

In [None]:
# Show the dataset's `targets` attribute as the cell output.
emnist.targets

In [None]:
from community.common.utils import plot_grid

# Sample indices for a 10x10 grid: entry [row][col] is sample number col * 10 + row.
grid_indices = [[col * 10 + row for col in range(10)] for row in range(10)]

# Element 0 of each sample goes into `data`, element 1 into `label`.
data = [[emnist[idx][0] for idx in row] for row in grid_indices]
label = [[emnist[idx][1] for idx in row] for row in grid_indices]

plot_grid(data, label, figsize=(10, 10))

In [None]:
# Scratch check: take the first item of every row of `data` and repeat that list twice.
[[d[0] for d in data] for _ in range(2)]

In [None]:
import warnings
#warnings.filterwarnings('ignore')

In [None]:
# IPython magics: load the autoreload extension and select mode 2, so edited
# modules are re-imported before each cell runs.
%load_ext autoreload
%autoreload 2

In [None]:
# Run configuration: everything stays on the CPU unless `use_cuda` is switched on.
use_cuda = False
batch_size = 256
device = torch.device("cpu") if not use_cuda else torch.device("cuda")

# get_datasets returns three groups of loaders; the same `use_cuda` flag is passed to it.
(multi_loaders,
 double_loaders,
 single_loaders) = get_datasets('../data', batch_size, use_cuda)

## Standard Community

### Community Initialization

In [None]:
# Per-agent settings handed to init_community.
agents_params_dict = dict(
    n_agents=2,
    n_in=784,
    n_ins=None,
    n_hid=100,
    n_layer=1,
    n_out=10,
    train_in_out=(True, False),
    use_readout=True,
    cell_type=str(nn.RNN),  # stored as the class's string representation
    use_bottleneck=False,
    dropout=0,
)

# NOTE(review): presumably the connection probability between agents -- confirm.
p_con = 1e-3

community = init_community(agents_params_dict, p_con, device=device)

# Display the community's connection count as the cell output.
community.nb_connections

In [None]:
# First hyper-parameter group, kept both as a tuple and as a dict.
lr, gamma = 1e-3, 0.95
params = (lr, gamma)
params_dict = dict(lr=lr, gamma=gamma)

# Second (deepR) hyper-parameter group; `lr` and `gamma` are rebound here to the
# same values as above.
l1, gdnoise, lr, gamma, cooling = 1e-5, 1e-3, 1e-3, 0.95, 0.95
deepR_params = (l1, gdnoise, lr, gamma, cooling)
deepR_params_dict = dict(l1=l1, gdnoise=gdnoise, lr=lr, gamma=gamma, cooling=cooling)

optimizers, schedulers = init_optimizers(community, params_dict, deepR_params_dict)

In [None]:
# NOTE: this cell used to contain a bare `n`, a name that no cell in this notebook
# defines, so running the notebook top to bottom stopped here with a NameError.

### Training

In [None]:
# Training configuration passed to train_community as `config`.
# Each key appears exactly once: a repeated key in a dict literal is silently
# overridden by its last occurrence, which makes later edits easy to lose.
training_dict = {
    'n_epochs': 2,
    'task': 'parity_digits',
    'global_rewire': True,
    'check_gradients': False,
    'reg_factor': 0.,
    'train_connections': True,
    'decision_params': ('last', 'max'),
    'early_stop': True,
    'deepR_params_dict': deepR_params_dict,
}

#pyaml.save(training_dict, '../community/common/default_train_dict.yml')

# The loaders in `double_loaders` are unpacked into positional arguments.
train_out = train_community(community, *double_loaders, optimizers,
                            schedulers=schedulers, config=training_dict, device=device)

# `results` is an alias for the same object; the following cells read from it.
results = train_out

In [None]:
# Unpack the six entries of `results` positionally, in the order the mapping stores them.
(
    train_losses,
    train_accs,
    test_losses,
    test_accs,
    deciding_agents,
    best_state,
) = results.values()

### Results

In [None]:
# Raw accuracy curve, one point per recorded entry of `train_accs`.
plt.plot(train_accs)

# Overlay the maximum accuracy inside consecutive windows of
# m * len(double_loaders[0]) entries.
# NOTE(review): len(double_loaders[0]) is presumably the number of training
# batches per epoch -- confirm.
m = 1
epochs_slices = np.arange(0, len(train_accs) + 1, m * len(double_loaders[0]))
max_per_epoch = [
    np.max(train_accs[start:stop])
    for start, stop in zip(epochs_slices[:-1], epochs_slices[1:])
]
# Entries after the last window boundary (an incomplete trailing window) are not plotted.
plt.plot(epochs_slices[1:], max_per_epoch)


In [None]:
# Plot the recorded training losses.
plt.plot(train_losses)

In [None]:
import pyaml
import yaml

In [None]:
# Load the project configuration from ../config.yml.
# The encoding is pinned so decoding does not depend on the platform's locale.
with open('../config.yml', 'r', encoding='utf-8') as datafile:
    config = yaml.safe_load(datafile)

    

In [None]:
# Keep only the top-level config sections named below, then display the result.
filter_config = {
    section: content
    for section, content in config.items()
    if section in ['datasets']
}
filter_config

In [None]:
from community.common.utils import get_wandb_artifact

In [None]:
# Filter passed to get_wandb_artifact, together with the explicit run id below.
filter_config = dict(
    datasets=dict(data_type='multi'),
    task='parity_digits',
)
get_wandb_artifact(
    filter_config,
    'funcspec',
    'state_dicts',
    process_config=True,
    run_id='195cgoaq',
)

In [None]:
# Shell escape: runs the `wandb login` command from the notebook.
# NOTE(review): this sits after the cell that fetches a wandb artifact; it
# presumably needs to run before that cell -- confirm.
! wandb login