In [4]:
import gym
import numpy as np
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
import copy

import torch
import torch.nn.functional as F

import sys
import networks as nets
import net_utils as net_utils
import int_data as syn
import analysis as analysis
import context_data as context


# Matplotlib named colours, presumably one per plotted series/condition index.
c_vals = [
    'firebrick', 'darkgreen', 'blue', 'darkorange', 'm', 'deeppink',
    'r', 'gray', 'g', 'navy', 'y', 'purple',
    'cyan', 'olive', 'skyblue', 'pink', 'tan',
]
# Companion palette of the same length; presumably c_vals_l[i] is meant as a
# lighter shade of c_vals[i] (TODO confirm — several entries are identical).
c_vals_l = [
    'salmon', 'limegreen', 'cornflowerblue', 'bisque', 'plum', 'pink',
    'tomato', 'lightgray', 'g', 'b', 'y', 'purple',
    'cyan', 'olive', 'skyblue', 'pink', 'tan',
]

def participation_ratio_vector(C):
    """Computes the participation ratio of a vector of variances.

    PR = (sum_i C_i)^2 / sum_i C_i^2

    Parameters
    ----------
    C : array_like
        Variances (any array-like; all elements are summed). The input is
        converted to a float64 array first, so plain lists/tuples work
        (``C*C`` is element-wise, not list repetition) and integer inputs
        cannot overflow when squared.

    Returns
    -------
    float
        The participation ratio, or ``nan`` when every entry is zero (or the
        input is empty), since the ratio is undefined in that case.
    """
    C = np.asarray(C, dtype=np.float64)
    sum_sq = np.sum(C * C)
    if sum_sq == 0:
        # 0/0: undefined; return nan explicitly instead of emitting a RuntimeWarning.
        return np.nan
    return np.sum(C) ** 2 / sum_sq

In [5]:
# Reload modules if changes have been made to them
from importlib import reload

# Re-execute the project modules so edits on disk take effect without a kernel
# restart. NOTE: names bound via `from X import Y` (e.g. xe_classifier_accuracy)
# keep pointing at the old objects until that import line is re-run.
for _stale_module in (net_utils, nets, syn, analysis):
    reload(_stale_module)
del _stale_module
# Kept as the cell's final expression so its module repr is still displayed.
reload(context)

<module 'context_data' from '/home/ewertj2/KAM/joe-mpn/mpn/context_data.py'>

In [3]:
from net_utils import xe_classifier_accuracy

def init_net(net_params, verbose=True):
    """Build an untrained network of the type named by ``net_params['netType']``.

    Layer sizes come from 'n_inputs', 'n_hidden' and 'n_outputs'. The 'MPN2'
    variant receives an additional hidden layer of width 'n_hidden'.

    Args:
        net_params: dict containing 'n_inputs', 'n_hidden', 'n_outputs',
            'netType' (one of 'MPN', 'MPN2', 'MPN_rec') and 'MAct'.
        verbose: forwarded to the network constructor.

    Returns:
        The newly constructed network instance.

    Raises:
        ValueError: if 'netType' is not one of the supported values.

    NOTE(review): MultiPlasticNet / MultiPlasticNetTwo / MultiPlasticNetRec are
    not defined in this cell; they must already exist in the notebook namespace.
    """
    layer_sizes = [net_params[key] for key in ('n_inputs', 'n_hidden', 'n_outputs')]
    kind = net_params['netType']

    if kind == 'MPN':
        constructor = MultiPlasticNet
    elif kind == 'MPN2':
        constructor = MultiPlasticNetTwo
        # Insert a second hidden layer of the same width right after the input.
        layer_sizes = layer_sizes[:1] + [net_params['n_hidden']] + layer_sizes[1:]
    elif kind == 'MPN_rec':
        constructor = MultiPlasticNetRec
    else:
        raise ValueError('netType not recognized.')

    return constructor(layer_sizes, verbose=verbose, MAct=net_params['MAct'])

def train_network(net_params, toy_params, current_net=None, save=False, save_root='', 
                  set_seed=True, verbose=True):
    """ 
    Code to train a single network, or to continue training an existing one.

    Training proceeds in rounds. Each round generates a fresh training set,
    raises the cumulative epoch target by 10 and calls ``net.fit`` with it.
    Rounds repeat until ``net.fit`` returns a truthy early-stop flag.
    NOTE(review): there is no cap on the number of rounds here; termination
    relies entirely on ``net.fit`` eventually returning True.

    INPUTS:
    net_params: dict of network/training settings. Mutated in place:
        ``net_params['epochs']`` is (re)set by this function.
    toy_params: data-generation settings passed to ``syn.generate_data``.
    current_net: optional existing network. If given, training continues on it,
        starting from ``current_net.hist['epoch']``; otherwise a new network is
        built with ``init_net``.
    save, save_root: accepted for API compatibility but currently unused.
    set_seed: if True and ``net_params['seed']`` is not None, seed numpy and torch.
    verbose: print which device is used; also forwarded to ``init_net``.

    OUTPUTS:
    net: the trained network
    toy_params: these are updated when data is generated
    net_params: the same dict that was passed in, with 'epochs' updated
    """

    # Sets the random seed for reproducibility (this affects both data generation and network)
    if 'seed' in net_params and set_seed:
        if net_params['seed'] is not None:
            np.random.seed(seed=net_params['seed'])
            torch.manual_seed(net_params['seed'])

    # Selects the device the network and data will live on
    if net_params['cuda']:
        if verbose:
            print('Using CUDA...')
        device = torch.device('cuda')
    else:
        if verbose:
            print('Using CPU...')
        device = torch.device('cpu')

    # Either continues training the supplied network from its recorded epoch,
    # or builds a fresh network starting at epoch 0.
    if current_net is None:
        net = init_net(net_params, verbose=verbose)
        net_params['epochs'] = 0
    else:
        net = current_net
        net_params['epochs'] = net.hist['epoch']
    net.to(device)

    # A fixed validation set is generated once; new training data is generated every round.
    validData, validOutputMask, _ = syn.generate_data(
        net_params['valid_set_size'], toy_params, net_params['n_outputs'], 
        verbose=False, auto_balance=False, device=device)
    early_stop = False
    new_thresh = True # Triggers threshold setting for first call of .fit, but turns off after first call

    while not early_stop:
        net_params['epochs'] += 10 # Cumulative epoch target passed to .fit this round
        trainData, trainOutputMask, toy_params = syn.generate_data(
            net_params['train_set_size'], toy_params, net_params['n_outputs'], 
            verbose=False, auto_balance=False, device=device)
        
        early_stop = net.fit('sequence', epochs=net_params['epochs'], 
                             trainData=trainData, batchSize=net_params['batch_size'],
                             validBatch=validData[:,:,:], learningRate=net_params['learning_rate'],
                             newThresh=new_thresh, monitorFreq=50, 
                             trainOutputMask=trainOutputMask, validOutputMask=validOutputMask,
                             validStopThres=net_params['accEarlyStop'], weightReg=net_params['weight_reg'], 
                             regLambda=net_params['reg_lambda'], gradientClip=net_params['gradient_clip'],
                             earlyStopValid=net_params['validEarlyStop'], minMaxIter=net_params['minMaxIter']) 
        new_thresh = False   

    return net, toy_params, net_params