In [2]:
import neurogym as nygym
import numpy as np
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
import copy

import torch
import torch.nn.functional as F

import sys
import networks as nets
import net_utils as net_utils
import int_data as syn
import analysis as analysis
import context_data as context


# Two lists of named matplotlib colors, 17 entries each.
# NOTE(review): c_vals_l looks like a lighter-shade counterpart of c_vals for the
# leading entries (the tail mostly repeats c_vals); presumably both are indexed
# in parallel by the plotting code -- confirm against the cells that use them.
c_vals = ['firebrick', 'darkgreen', 'blue', 'darkorange', 'm', 'deeppink', 'r', 'gray', 'g', 'navy', 
                'y', 'purple', 'cyan', 'olive', 'skyblue', 'pink', 'tan']
c_vals_l = ['salmon', 'limegreen', 'cornflowerblue', 'bisque', 'plum', 'pink', 'tomato', 'lightgray', 'g', 'b', 
                      'y', 'purple', 'cyan', 'olive', 'skyblue', 'pink', 'tan']

def participation_ratio_vector(C):
    """Computes the participation ratio of a vector of variances.

    PR = (sum_i C_i)^2 / (sum_i C_i^2)

    INPUTS:
    C: array-like of variances (ndarray, list or tuple). It is passed through
       np.asarray so plain Python sequences work as well as arrays.

    OUTPUTS:
    Scalar participation ratio. It equals the number of entries when they are
    all equal and 1 when only a single entry is nonzero.
    """
    # Without the conversion, `C * C` fails for lists/tuples
    C = np.asarray(C)
    return np.sum(C) ** 2 / np.sum(C * C)

In [3]:
# Reload modules if changes have been made to them
# (re-executes each already-imported local module so edits to its source file
# are picked up without restarting the kernel). Names that were bound with
# `from <module> import <name>` are not refreshed by reload and must be
# re-imported afterwards.
from importlib import reload

reload(net_utils)
reload(nets)
reload(syn)
reload(analysis)
reload(context)

<module 'context_data' from '/home/eulerfrog/research/mpn/context_data.py'>

In [None]:
from net_utils import xe_classifier_accuracy

def init_net(net_params, verbose=True):
    """
    Builds an untrained network of the kind named by net_params['netType'].

    INPUTS:
    net_params: dict, the keys read here are 'n_inputs', 'n_hidden',
                'n_outputs', 'netType' and 'MAct'.
    verbose: forwarded to the network constructor.

    OUTPUTS:
    net: the freshly constructed network

    Raises ValueError when 'netType' is not one of 'MPN', 'MPN2', 'MPN_rec'.
    """
    # Layer sizes handed to the constructor: [inputs, hidden, outputs]
    layer_sizes = [net_params[key] for key in ('n_inputs', 'n_hidden', 'n_outputs')]
    kind = net_params['netType']

    # Classes are looked up only inside the matching branch, so variants that
    # are not defined in the current session are never touched.
    if kind == 'MPN':
        net_cls = MultiPlasticNet
    elif kind == 'MPN2':
        net_cls = MultiPlasticNetTwo
        # This variant receives the hidden size twice: [inputs, hidden, hidden, outputs]
        layer_sizes.insert(1, net_params['n_hidden'])
    elif kind == 'MPN_rec':
        net_cls = MultiPlasticNetRec
    else:
        raise ValueError('netType not recognized.')

    return net_cls(layer_sizes, verbose=verbose, MAct=net_params['MAct'])

def train_network(net_params, toy_params, current_net=None, save=False, save_root='', 
                  set_seed=True, verbose=True):
    """ 
    Code to train a single network. 

    INPUTS:
    net_params: dict of network and training settings. Keys read here: 'seed'
                (optional), 'cuda', 'valid_set_size', 'train_set_size',
                'n_outputs', 'batch_size', 'learning_rate', 'accEarlyStop',
                'weight_reg', 'reg_lambda', 'gradient_clip', 'validEarlyStop',
                'minMaxIter', plus whatever init_net reads. The key 'epochs'
                is written (and advanced) in place.
    toy_params: settings forwarded to syn.generate_data.
    current_net: if given, training continues on this network, starting from
                 its recorded net.hist['epoch']; otherwise a new network is
                 created with init_net.
    save, save_root: accepted for interface compatibility, not used in this function.
    set_seed: if True and net_params['seed'] is set, seeds numpy and torch.
    verbose: print which device is used; also forwarded to init_net.

    OUTPUTS:
    net: the trained network
    toy_params: these are updated when data is generated
    net_params: the same dict that was passed in, with 'epochs' updated

    """

    # Sets the random seed for reproducibility (this affects both data generation and network)
    if set_seed and net_params.get('seed') is not None:
        np.random.seed(seed=net_params['seed'])
        torch.manual_seed(net_params['seed'])

    # Picks the device the network and data live on
    if net_params['cuda']:
        if verbose: print('Using CUDA...')
        device = torch.device('cuda')
    else:
        if verbose: print('Using CPU...')
        device = torch.device('cpu')

    # Initializes a network, or continues with the one supplied by the caller.
    # (Previously a fresh network was always built, so current_net was silently
    # discarded even though its epoch count was being requested.)
    if current_net is None:
        net = init_net(net_params, verbose=verbose)
        net_params['epochs'] = 0
    else:
        net = current_net
        net_params['epochs'] = net.hist['epoch']
    net.to(device)

    # Validation data is generated once and reused for every call of .fit
    validData, validOutputMask, _ = syn.generate_data(
        net_params['valid_set_size'], toy_params, net_params['n_outputs'], 
        verbose=False, auto_balance=False, device=device)
    early_stop = False
    new_thresh = True # Triggers threshold setting for first call of .fit, but turns off after first call

    # Continually generates new data to train on
    # This will iterate in loop so that it only sees each type of data a set amount of times
    while not early_stop:
        net_params['epochs'] += 10 # Number of times each example is passed to the network
        trainData, trainOutputMask, toy_params = syn.generate_data(
            net_params['train_set_size'], toy_params, net_params['n_outputs'], 
            verbose=False, auto_balance=False, device=device)

        # The loop ends once .fit returns a truthy value
        early_stop = net.fit('sequence', epochs=net_params['epochs'], 
                             trainData=trainData, batchSize=net_params['batch_size'],
                             validBatch=validData[:,:,:], learningRate=net_params['learning_rate'],
                             newThresh=new_thresh, monitorFreq=50, 
                             trainOutputMask=trainOutputMask, validOutputMask=validOutputMask,
                             validStopThres=net_params['accEarlyStop'], weightReg=net_params['weight_reg'], 
                             regLambda=net_params['reg_lambda'], gradientClip=net_params['gradient_clip'],
                             earlyStopValid=net_params['validEarlyStop'], minMaxIter=net_params['minMaxIter']) 
        new_thresh = False   

    return net, toy_params, net_params

In [None]:
class MultiPlasticNet(StatefulBase):
    """
    Two-layer feedforward setup, with single multi-plastic layer followed by a readout layer.

    Same architecture used in "Neural Population Dynamics of Computing with Synaptic Modulations"
    """
    def __init__(self, init, verbose=True, **mpnArgs):
        """
        init: (Nx, Nh, Ny), sizes of the input, hidden and output layers.
        verbose: forwarded to the MultiPlasticLayer constructor.
        mpnArgs: extra keyword arguments forwarded to MultiPlasticLayer.
        """
        super(MultiPlasticNet, self).__init__()

        Nx, Nh, Ny = init
        # For readouts
        W, b = random_weight_init([Nh, Ny], bias=True)

        self.n_inputs = Nx
        self.n_hidden = Nh
        self.n_outputs = Ny

        self.loss_fn = F.cross_entropy # Reduction is mean by default
        self.acc_fn = xe_classifier_accuracy

        # Creates the input MP layer
        self.mp_layer = MultiPlasticLayer((self.n_inputs, self.n_hidden), verbose=verbose, **mpnArgs)

        # Input layer activation
        self.f = torch.tanh
        # Readout layer (torch.nn is reached through the `torch` import, no `nn` alias needed)
        self.w2 = torch.nn.Parameter(torch.tensor(W[0], dtype=torch.float))
        # Readout bias is not used (easier interpreting of readouts later on),
        # so it is a zero buffer instead of a trainable parameter.
        # NOTE(review): b2 inherits the dtype of b[0]; if that is float64 the
        # readout sum below is promoted to float64 -- confirm this is acceptable.
        self.register_buffer('b2', torch.zeros_like(torch.tensor(b[0])))

    def reset_state(self, batchSize=1):
        """
        Resets states of all internal layer SM matrices
        """
        self.mp_layer.reset_state(batchSize=batchSize)

    def forward(self, x, debug=False):
        """
        This modifies the internal state of the model (self.M). 
        Don't call twice in a row unless you want to update self.M twice!

        x.shape: [B, Nx]
        b1.shape: [Nh]
        w1.shape=[Nx,Nh], 
        M.shape=[B,Nh,Nx], 

        Returns the readout y when debug is False, otherwise the tuple
        (h_tilde, h, y_tilde, y, M) where M is read after the update.
        """
        # Apply input multi-plastic layer, returns pre-activation
        h_tilde = self.mp_layer(x, debug=False)
        h = self.f(h_tilde)

        # M updated internally when this is called
        self.mp_layer.update_sm_matrix(x, h)

        # (1, Ny) + [(B, Nh,) x (Nh, Ny) = (B, Ny)] = (B, Ny)
        y_tilde = self.b2.unsqueeze(0) + h.squeeze(dim=2) @ torch.transpose(self.w2, 0, 1)
        # No output nonlinearity: the readout is returned as-is
        y = y_tilde

        if debug:
            h_tilde = h_tilde.squeeze(dim=2)
            h = h.squeeze(dim=2)
            return h_tilde, h, y_tilde, y, self.mp_layer.M
        else:
            return y

    def evaluate(self, batch):
        """
        Runs a full sequence of the given batch size through the network.

        batch: indexable where batch[0] holds the inputs, indexed as
               [B, T, Nx], and batch[1] supplies the leading [B, T] dims,
               layout and device of the output.
        Returns a float tensor of shape [B, T, Ny] with the readout at every time step.
        """
        # Begin by resetting the state
        self.reset_state(batchSize=batch[0].shape[0])

        out_size = torch.Size([batch[1].shape[0], batch[1].shape[1], self.n_outputs]) # [B, T, Ny]
        out = torch.empty(out_size, dtype=torch.float, layout=batch[1].layout, device=batch[1].device)

        # One forward call per time step; each call also updates the internal state
        for time_idx in range(batch[0].shape[1]):
            x = batch[0][:, time_idx, :] # [B, Nx]
            out[:, time_idx] = self(x)

        return out

    @torch.no_grad()
    def evaluate_debug(self, batch, batchMask=None, acc=True, reset=True):
        """ 
        Runs a full sequence of the given batch size through the network, but now keeps track of all sorts of parameters

        batch: same structure as for evaluate.
        batchMask: passed to self.accuracy as outputMask.
        acc: if True, db['acc'] is filled with the accuracy over the sequence.
        reset: if True, the internal state is reset before the sequence is run.

        Returns a dict with per-time-step records under the keys 'x',
        'h_tilde', 'h', 'Wxb', 'M', 'Mx', 'y_tilde', 'out' (and 'acc' when requested).
        """
        B = batch[0].shape[0]

        if reset:
            self.reset_state(batchSize=B)

        Nx = self.n_inputs
        Nh = self.n_hidden
        Ny = self.n_outputs
        T = batch[1].shape[1]
        # Storage is allocated with torch.empty defaults (no device argument)
        db = {'x' : torch.empty(B,T,Nx),
              'h_tilde' : torch.empty(B,T,Nh),
              'h' : torch.empty(B,T,Nh),
              'Wxb' : torch.empty(B,T,Nh),
              'M': torch.empty(B,T,Nh,Nx),
              'Mx' : torch.empty(B,T,Nh),
              'y_tilde' : torch.empty(B,T,Ny),
              'out' : torch.empty(B,T,Ny),
              }
        for time_idx in range(batch[0].shape[1]):
            x = batch[0][:, time_idx, :] # [B, Nx]
            db['x'][:,time_idx,:] = x

            (db['h_tilde'][:,time_idx], db['h'][:,time_idx,:], db['y_tilde'][:,time_idx,:], 
                db['out'][:,time_idx,:], db['M'][:,time_idx,:]) = self(x, debug=True)
            # NOTE(review): self.mp_layer.M is read after the forward call above
            # has already run update_sm_matrix -- confirm this is the intended M.
            db['Mx'][:,time_idx,:] = torch.bmm(self.mp_layer.M, x.unsqueeze(2)).squeeze(2)

            db['Wxb'][:,time_idx] = self.mp_layer.b1.unsqueeze(0) + torch.mm(x, torch.transpose(self.mp_layer.w1, 0, 1))

        if acc:
            # Outputs are moved to the readout's device before scoring
            db['acc'] = self.accuracy(batch, out=db['out'].to(self.w2.device), outputMask=batchMask).item()

        return db

In [None]:
from net_utils import random_weight_init

def convert_ngym_dataset(dataset_params, set_size=None, device='cpu', mask_type=None):
    """
    This converts a NeuroGym dataset into one that the code can use.

    Mostly just transposes the batch and sequence dimensions, then combines
    inputs and labels into a TensorDataset. Also creates an output mask.

    INPUTS:
    dataset_params: dict, keys read here are 'dt', 'dataset_name', 'seq_length'
                    and 'convert_inputs'; when 'convert_inputs' is truthy,
                    'input_dim' is read as well and 'convert_mat'/'convert_b'
                    are created on first use and stored back into the dict.
    set_size: batch size requested from the dataset.
    device: device the returned tensors are placed on.
    mask_type: None -> mask is True everywhere,
               'label' -> mask is True where the label is nonzero,
               'no_fix' -> mask is True where input channel 0 is zero.

    OUTPUTS:
    trainData: TensorDataset of (inputs, labels)
    masks: bool tensor, shape (B, seq_len, act_size)
    dataset_params: the same dict, possibly with the conversion matrices added

    Raises ValueError for an unknown mask_type, or when the 'no_fix' mask is empty.
    """
    # The notebook's import cell binds neurogym under a different alias, so the
    # `ngym` name used below is bound here to keep this function self-contained.
    import neurogym as ngym

    # Creates a brand new dataset each time to make sure we are starting from the
    # beginning of a new sequence. There is probably a better way to do this.
    kwargs = {'dt': dataset_params['dt']}
    dataset = ngym.Dataset(dataset_params['dataset_name'],
                           env_kwargs=kwargs, batch_size=set_size,
                           seq_len=dataset_params['seq_length'])

    if set_size is not None:
        dataset.batchsize = set_size # Just create as a single large batch for now

    act_size = dataset.env.action_space.n

    inputs, labels = dataset()

    # Default in our setup (batch, seq_idx, :) so need to swap dims
    inputs = np.transpose(inputs, axes=(1, 0, 2))
    labels = np.transpose(labels, axes=(1, 0,))[:, :, np.newaxis]

    if mask_type is None: # Mask is always just all time steps, so creates all True array
        masks = np.ones((inputs.shape[0], inputs.shape[1], act_size))
    elif mask_type == 'label': # Masks on when labels are nonzero
        masks_flat = (labels > 0.0).astype(np.int32) # (B, seq_len, 1)
        masks = np.repeat(masks_flat, act_size, axis=-1) # (B, seq_len, act_size)
    elif mask_type == 'no_fix': # Masks on when fixation is zero, assumes fixation is zeroth input
        masks_flat = (inputs[:, :, 0:1] == 0.0).astype(np.int32) # (B, seq_len, 1)
        if np.sum(masks_flat) == 0:
            raise ValueError('Mask is all zeros!')
        masks = np.repeat(masks_flat, act_size, axis=-1) # (B, seq_len, act_size)
    else:
        raise ValueError('mask type {} not recoginized'.format(mask_type))

    # Note this has to happen after mask assignment since sometimes raw inputs are used to determine mask
    if dataset_params['convert_inputs']:
        # If the conversion is not already generated, create it
        # (it is cached in dataset_params so later calls reuse the same projection)
        if 'convert_mat' not in dataset_params:
            W,b = random_weight_init([inputs.shape[-1], dataset_params['input_dim']], bias=True)
            dataset_params['convert_mat'] = W[0]
            dataset_params['convert_b'] = b[0][np.newaxis, np.newaxis, :]

        inputs = np.matmul(inputs, dataset_params['convert_mat'].T) + dataset_params['convert_b']

    inputs = torch.from_numpy(inputs).type(torch.float).to(device) # e.g. inputs.shape (16, 100, 3)
    labels = torch.from_numpy(labels).type(torch.long).to(device) # e.g. labels.shape = (16, 100, 1)
    masks = torch.tensor(masks, dtype=torch.bool, device=device)

    trainData = torch.utils.data.TensorDataset(inputs, labels)

    return trainData, masks, dataset_params