# RNNs
Training RNNs on the Mante-Susillo context-dependent integration (CDI) task.

# Overview of

In [None]:
# import the necesary modules
import numpy as np
from random import randrange
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torch.autograd import Variable

# Overview of how to build NN in PyTorch

In [None]:
# create a simple RNN class
class RNN(nn.Module):
  def __init__(self):
    super(RNN,self).__init__()
    # initialize RNN
  def forward(x):
    # forward propogate inputs


Add a section on how to build different RNN architectures in python

Add some intuition of BPTT

Add some inutition on the pros/cons of these methods for neuroscience


# Overview of how Torch handles backpropogation and writing a simple training loop

In [1]:
# simple traiing loop

## Build an RNN

In [None]:
class RNN(nn.Module):
    # recently changed var by normalizing it by N, before was 0.045 w/o normilazation
    def __init__(self, hyperParams, task, device='cuda:0'):
        #self._device = torch.device("cpu")
        #if torch.cuda.is_available():
        self._device = torch.device(device)
        super(RNN, self).__init__()                                            # initialize parent class
        self._inputSize = int(hyperParams["inputSize"])
        self._hiddenSize = int(hyperParams["hiddenSize"])
        self._outputSize = int(hyperParams["outputSize"])
        self._g = hyperParams["g"]
        self._hiddenInitScale = hyperParams["initScale"]                       # sets tolerance for determining validation accuracy# initializes hidden layer                                                   # sets noise for hidden initialization
        self._dt= hyperParams["dt"]
        self._batchSize = int(hyperParams["batchSize"])
        self._hParams = hyperParams                                             # used for saving training conditions
        self._init_hidden()   
        self._totalTrainTime = 0                                               # accumulates training time
        self._timerStarted = False
        self._useForce = False            # if set to true this slightly changes the forward pass 
        self._useHeb = False
        self._fixedPoints = []
        self._tol = 1
        
        self._J = {
        'in' : torch.randn(self._hiddenSize, self._inputSize).to(self._device)*(1/2),
        'rec' : ((self._g**2)/50)*torch.randn(self._hiddenSize, self._hiddenSize).to(self._device),
        'out' : (0.1*torch.randn(self._outputSize, self._hiddenSize).to(self._device)),
        'bias' : torch.zeros(self._hiddenSize, 1).to(self._device)*(1/2)
        }


        try:
            self._use_ReLU = int(hyperParams["ReLU"])             # determines the activation function to use
        except:
            self._use_ReLU = 0
            

        self._task = task

              

        #create an activity tensor that will hold a history of all hidden states
        self._activityTensor = np.zeros((50))
        self._neuronIX = None    #will hold neuron sorting from TCA
        self._targets=[]
        self._losses = 0     #trainer should update this with a list of training losses
        self._MODEL_NAME = 'models/UNSPECIFIED MODEL'   #trainer should update this
        self._pca = []
        self._recMagnitude = []     #will hold the magnitude of reccurent connections at each step of training
        # will hold the previous weight values
        self._w_hist = []    # will hold the previous weight values

        self._valHist = []        # empty list to hold history of validation accuracies

        #self.fractions = []
        #weight matrices for the RNN are initialized with random normal

    def _startTimer(self):
        '''
        starts timer for training purposes

        Returns
        -------
        None.

        '''
        self._tStart = time.time()
        self._timerStarted = True
    
    def _endTimer(self):
        '''
        stops training timer

        Returns
        -------
        None.

        '''
        if self._timerStarted == False:
            return
        else:
            self._totalTrainTime += time.time() - self._tStart
            self._timerStarted = False
    
    def setName(self, name):
        '''
        sets name of model

        Parameters
        ----------
        name : string
            model will be saved as a .pt file with this name in the /models/ directory.

        Returns
        -------
        None.

        '''
        self._MODEL_NAME = "models/" + name
        
    def createValidationSet(self, test_iters=2000):
        '''
        DESCRIPTION:
        Creates the validation dataset which will be used to 
        decide when to terminate RNN training. The validation 
        dataset consists of means sampled uniformly from -0.1875 
        to +0.1875. The variance of all instances in the validation 
        dataset is equal to the variance of the trainign dataset 
        as specified by the task object.
        
        validation dataset has shape (test_iters, numInputs, tSteps)
        
        PARAMETERS:
        **valSize: specifies the size of the validation dataset to use
        **task: task object that is used to create the validation data 
        set
        '''
        # initialize tensors to hold validation data
        breakpoint()
        self.validationData = torch.zeros(test_iters, self._inputSize,  self._task.N).to(self._device)
        self.validationTargets = torch.zeros(test_iters,1).to(self._device)
        # means for validation data
        meanValues = np.linspace(0, 0.1875, 20)
        for trial in range(test_iters):
            # to get genetic and bptt different I divided by 30
            mean_overide = meanValues[trial %20]
            inpt_tmp, condition_tmp = self._task.GetInput(mean_overide=mean_overide)
            self.validationData[trial,:, :] = inpt_tmp.t()
            self.validationTargets[trial] = condition_tmp
        print('Validation dataset created!\n')

    def GetValidationAccuracy(self, test_iters=2000):
        '''
        Will get validation accuracy of the model on the specified task to 
        determine if training should be terminated
        '''
        accuracy = 0.0
        tol = self._tol

        inpt_data = self.validationData#.t()
        condition = self.validationTargets
        condition = torch.squeeze(condition)
        output = self.feed(inpt_data)
        output_final = torch.sign(output[-1,:])
        # scale ouput in the case of multisensory network
        # compute the difference magnitudes between model output and target output
        #if isinstance(self._task, DMC):
        #    differences = self.validationTargets-output_final.view(-1,1)
        #    num_errors = torch.sum(torch.abs(differences) > 0.15).item()
        #else:
        error = torch.abs(condition-output_final)
        # threshold this so that errors greater than tol(=0.1) are counted
        num_errors = torch.where(error>tol)[0].shape[0]
        accuracy = (test_iters - num_errors) / test_iters

        
        self._valHist.append(accuracy)                 # appends current validation accuracy to history
    
        
        return accuracy

    def _UpdateHidden(self, inpt):
        '''Updates the hidden state of the RNN. NOTE: The way the hidden state is updated 
        will vary depending on the task and learning rule the RNN was trained with. 


        Parameters
        ----------
        inpt : PyTorch cuda tensor
            inputs to the network at current time step. Has shape (inputSize, batchSize)
        use_relu : BOOL, optional
            when true, the network will use ReLU activations. The default is False.

        Returns
        -------
        hidden_next : PyTorch cuda Tensor
            Tensor of the updated neuron activations. Has shape (hiddenSize, batchSize)

        '''
        dt = self._dt
        Jin = self._J["in"]

        # set activation function
        if self._use_ReLU:
            act_func = nn.functional.relu
        else:
            act_func = lambda x : 1+torch.tanh(x)

        # forward pass through the network
        noiseTerm=0
        hidden_next = dt*torch.matmul(Jin, inpt) + \
        dt*torch.matmul(self._J['rec'], (1+torch.tanh(self._hidden))) + \
        (1-dt)*self._hidden + dt*self._J['bias'] + 0*noiseTerm

        self._hidden = hidden_next        # updates hidden layer
        return hidden_next

    def _forward(self, inpt):
        '''
        Computes the RNNs forward pass activations for a single timestep

        Parameters
        ----------
        inpt : PyTorch cuda Tensor 
            inputs to the network for the current timestep. Shape (inputSize, batchSize)

        Returns
        -------
        output : PyTorch cuda Tensor
            output of the network after this timestep. Has shape (outputSize, batchSize)
        PyTorch cuda Tensor
            copy of the hidden state activations after the forward pass. Has shape
            (hiddenSize, batchSize)

        '''

        #ensure the input is a torch tensor
        if torch.is_tensor(inpt) == False:
            inpt = torch.from_numpy(inpt).float()                              # inpt must have shape (1,1)
        inpt = inpt.reshape(self._inputSize, -1)
        
        # compute the forward pass
        self._UpdateHidden(inpt)
        if self._useHeb:
            output = torch.tanh(self._hidden[0])
        elif self._useForce:
            output = torch.tanh(torch.matmul(self._J['out'], torch.tanh(self._hidden)))
        else:
            output = torch.matmul(self._J['out'], self._hidden)

        return output, self._hidden.clone()

    def feed(self, inpt_data, return_hidden=False, return_states=False):
        '''
        Feeds an input data sequence into an RNN

        Parameters
        ----------
        inpt_data : PyTorch CUDA Tensor
            Inputs sequence to be fed into RNN. Has shape (batchSize, inputSize, Time)
        return_hidden : BOOL, optional
            When True, the hidden states of the RNN are returned as list of length
            Time where each element is a NumPy array of shape (hiddenSize, batchSize)
            containing the hidden state activations through the course of the input
            sequence. The default is False.

        Returns
        -------
        output_trace : PyTorch CUDA Tensor
            output_trace: output of the network over all timesteps. Will have shape 
            (batchSize, inputSize) i.e. 40x1 for single sample inputs
        hidden_states : PyTorch CUDA Tensor
            hidden_states: hidden states of the network through a trial, has shape
            (batch_size, time_steps, hidden_size)

        '''

        #num_inputs = len(inpt_data[0])
        batch_size = inpt_data.shape[0]
        inpt_seq_len = inpt_data.shape[-1]
        assert inpt_data.shape[1] == self._inputSize, "Size of inputs:{} does not match network's input size:{}".format(inpt_data.shape[1], self._inputSize)
        num_t_steps = inpt_data.shape[2]
        
        output_trace = torch.zeros(num_t_steps, batch_size).to(self._device)
        if return_hidden:
            hidden_trace = []
        if return_states:
            hidden_states = torch.zeros((batch_size, inpt_seq_len, self._hiddenSize), requires_grad=True)

        self._init_hidden(numInputs=batch_size)  # initializes hidden state
        inpt_data = inpt_data.permute(2,1,0)     # now has shape TxMxB
        for t_step in range(len(inpt_data)):
            output, hidden = self._forward(inpt_data[t_step])
            if return_hidden:
                hidden_trace.append(hidden.cpu().detach().numpy())      # unsure if there are any dependencies on hidden_trace

            if return_states:
                hidden_states[:,t_step,:] = hidden.T                    # (batch_size, hidden_size)
                
            if self._useHeb:
                output_trace[t_step,:] = hidden.detach()[0]
            else:
                output_trace[t_step,:] = output
        if return_hidden:
            hh = np.array(hidden_trace)
            return output_trace, hh
            
        if return_states:
            return output_trace, hidden_states
        #print('shape of output trace', len(output_trace[0]))
        return output_trace
        

    def save(self, N="", tElapsed=0, *kwargs):
        '''
        saves RNN parameters and attributes. User may define additional attributes
        to be saved through kwargs
        '''
        print('valdiation history', self._valHist)
        if N=="":     # no timestamp
            model_name = self._MODEL_NAME+'.pt'
            print("model name: ", model_name)
        else:         # timestamp
            model_name = self._MODEL_NAME + '_' + str(N) + '.pt'
        if tElapsed==0:
            torch.save({'weights': self._J, \
                        'weight_hist':self._w_hist, \
                        'activities': self._activityTensor, \
                        'targets': self._targets, \
                        'pca': self._pca, \
                        'losses': self._losses, \
                        'rec_magnitude' : self._recMagnitude, \
                        'neuron_idx': self._neuronIX,\
                        'validation_history' : self._valHist,
                        'fixed_points': self._fixedPoints}, model_name)
                
            # save model hyper-parameters to text file
            f = open(self._MODEL_NAME+".txt","w")
            for key in self._hParams:
                f.write( str(key)+" : "+str(self._hParams[key]) + '\n')
            f.write( "total training time: " + str(self._totalTrainTime) + '\n')
            f.close()
            
        else:
            torch.save({'weights': self.J, \
                        'weight_hist':self.w_hist, \
                        'activities': self.activity_tensor, \
                        'targets': self.targets, \
                        'pca': self.pca, \
                        'losses': self.losses, \
                        'rec_magnitude' : self.rec_magnitude, \
                        'neuron_idx': self.neuron_idx, \
                        'fractions' : self.fractions, \
                        'validation_history' : self.valHist, \
                        'tElapsed' : tElapsed,
                        'fixed_points' : self._fixedPoints}, model_name)
        
        #torch.save({'weights': self.J, 'targets': self.targets,  'losses': self.losses,'validation_history' : self.valHist}, model_name)

    def load(self, model_name, *kwargs):
        '''
        Loads in parameters and attributers from a previously instantiated model.
        User may define additional model attributes to load through kwargs
        '''
        # add file suffix to model_name
        fname = model_name+'.pt'
        model_dict = torch.load(fname)
        # load attributes in model dictionary
        if 'weights' in model_dict:
            self._J = model_dict['weights']
            for layer in self._J:  # move weights to correct device
                self._J[layer] = self._J[layer].to(self._device)
        else:
            print('WARNING!! NO WEIGHTS FOUND\n\n')
        if 'activities' in model_dict:
            self._activityTensor = model_dict['activities']
        else:
            print('WARNING!! NO ACTIVITIES FOUND\n\n')
        if 'targets' in model_dict:
            self._targets = model_dict['targets']
        else:
            print('WARNING!! NO TARGETS FOUND\n\n')
        if 'pca' in model_dict:
            self._pca = model_dict['pca']
        else:
            print('WARNING!! NO PCA DATA FOUND\n\n')
        if 'losses' in model_dict:
            self._losses = model_dict['losses']
        else:
            print('WARNING!! NO LOSS HISTORY FOUND\n\n')
        if 'validation_history' in model_dict:
            self._valHist = model_dict['validation_history']
        else:
            print('WARNING!! NO VALIDATION HISTORY FOUND\n\n')
            
        if 'fixed_points' in model_dict:
            self._fixedPoints = model_dict['fixed_points']
            
        # try to load additional attributes specified for kwargs
        for key in kwargs:
            print('loading of', key, 'has not yet been implemented!')

        
        
        
        
        if 'rec_magnitude' in model_dict:
            self.rec_magnitude = model_dict['rec_magnitude']
        else:
            print('WARNING!! NO WEIGHT HISTORY FOUND\n\n')
        if 'neuron_idx' in model_dict:
            self.neuron_idx = model_dict['neuron_idx']
        else:
            print('WARNING!! NO NEURON INDEX FOUND\n\n')

        print('\n\n')
        print('-'*50)
        print('-'*50)
        print('RNN model succesfully loaded ...\n\n')


    # maybe I should consider learning the initial state?
    def _init_hidden(self, numInputs=1):
        self._hidden = self._hiddenInitScale*(torch.randn(self._hiddenSize, numInputs).to(self._device))
    
    def GetF(self):
        W_rec = self._J['rec'].data.cpu().detach().numpy()
        W_in = self._J['in'].data.cpu().detach().numpy()
        b = self._J['bias'].data.cpu().detach().numpy()
        ReLU_flag = self._use_ReLU

        def master_function(inpt, relu=ReLU_flag):
            dt = 0.1
            sizeOfInput = len(inpt)
            inpt = inpt.reshape(sizeOfInput,1)
            if relu:
                return lambda x: np.squeeze( dt*np.matmul(W_in, inpt) + dt*np.matmul(W_rec, (np.maximum( np.zeros((self._hiddenSize,1)), x.reshape(self._hiddenSize,1)) )) - dt*x.reshape(self._hiddenSize,1) + b*dt)
            else:
                if self._useHeb: #TODO: update this to incorporate bias
                    def update_fcn(x):
                        x[1] = 1       # Bias from Miconi 2017
                        x[10] = 1
                        x[11] = -1
                        x = np.squeeze( dt*np.matmul(W_in, inpt) + dt*np.matmul(W_rec, (np.tanh(x.reshape(self._hiddenSize,1)))) - dt*x.reshape(self._hiddenSize,1) + b*dt)
                        x[1] = 1       # Bias from Miconi 2017
                        x[10] = 1
                        x[11] = -1
                        return x
                    #return lambda x: np.squeeze( dt*np.matmul(W_in, inpt) + dt*np.matmul(W_rec, (np.tanh(x.reshape(self._hiddenSize,1)))) - dt*x.reshape(self._hiddenSize,1) + b*dt)
                    return update_fcn
                elif self._useForce:
                    return lambda x: np.squeeze( dt*np.matmul(W_in, inpt) + dt*np.matmul(W_rec, (np.tanh(x.reshape(self._hiddenSize,1)))) - dt*x.reshape(self._hiddenSize,1) + b*dt)
                else:  # BPTT and GA RNNs
                    return lambda x: np.squeeze( dt*np.matmul(W_in, inpt) + dt*np.matmul(W_rec, (1+np.tanh(x.reshape(self._hiddenSize,1)))) - dt*x.reshape(self._hiddenSize,1) + b*dt)

        return master_function
        
    def plotLosses(self):
        plt.plot(self._losses)
        plt.ylabel('Loss')
        plt.xlabel('Trial')
        plt.title(self._MODEL_NAME)

## Build the Data Loader

In [None]:
# class for generating data 
class ContextDependentIntegration():
    '''
    This class simulates the Mante and Sussillo 2013 context-dependent 
    integration task described here:

    https://www.nature.com/articles/nature12742
    '''

    N_INPUT_CHANNELS = 2

    def __init__(self, N=750, mean=0.1857, var=1, device="cuda:0"):
        # determines if code should run on CPU or GPU
        if torch.cuda.is_available() and device=="cuda:0":
            self._device = torch.device(device)
        else:
            self._device = torch.device("cpu")

        self.N = N                 # number of time steps in a trial
        self._max_mean = mean      # maximum mean value of signal
        self._var = var            # variance of signal
        #self._version = ""        # is this deprecated ??
        #self._name = "Ncontext"   # is this deprecated ??

    def _random_generate_input(self, mean, trial_prob=0.5):
        '''
        Generates a 1D noisy signal for a context with the provided mean. The 
        chance of a positive trial is given by trial_prob.
        '''
        if torch.rand(1).item() < trial_prob:
            return mean*torch.ones(self.N)
        return -mean*torch.ones(self.N)
        
    def GetInput(self, mean_overide=None, var_overide=None):
        '''
        Generates a single trial of data for the CDI-task. Returns the
        perceptual data (inpts) and the labels (target)


        Parameters
        ----------
        mean_overide : TYPE, optional
            DESCRIPTION. The default is 1.

        Returns
        -------
        inpts : PyTorch CUDA Tensor
            DESCRIPTION.
        target : TYPE
            DESCRIPTION.
        '''
        inpts = torch.zeros((self.N, 2*self.N_INPUT_CHANNELS)).to(self._device)
        
        # sample trial mean from ~U[-max_mean, max_mean]
        mean = torch.rand(1).item() * self._max_mean

        # allows caller to use a deterministic mean for this trial
        if mean_overide is not None:
            mean = mean_overide

        # allows caller to override the default noise
        if var_overide is not None:
            var = var_overide
        else:
            var = self._var  

        # randomly sets one of the channels to be on
        go_channel = randrange(self.N_INPUT_CHANNELS)
        for channel_num in range(self.N_INPUT_CHANNELS):
            inpts[:,channel_num] = self._random_generate_input(mean)
            # generate the GO signals
            if go_channel == channel_num:
                inpts[:, self.N_INPUT_CHANNELS + channel_num] = 1
                target = torch.sign(torch.mean(inpts[:, channel_num]))
            else:
                inpts[:, self.N_INPUT_CHANNELS + channel_num] = 0

        
        # adds noise to inputs
        inpts[:,:self.N_INPUT_CHANNELS] += var*torch.randn(self.N, 
            self.N_INPUT_CHANNELS).to(self._device)
        return inpts, target
      

In [None]:
# visualize a few trials
training_task = ContextDependentIntegration()

'''for sample_ix in range(10):
    sample_trial, sample_target = training_task.GetInput(mean_overide=1, 
                                                         var_overide=0.25)


    plt.figure(sample_ix)
    plt.plot(sample_trial.cpu()[:,0])
    plt.plot(sample_trial.cpu()[:,1])
    plt.plot(sample_trial.cpu()[:,2])
    plt.plot(sample_trial.cpu()[:,3])
    plt.legend(["context 1", "context 2", "GO 1", "GO 2"])
    plt.title("Goal: " + str(sample_target.item()))'''

'for sample_ix in range(10):\n    sample_trial, sample_target = training_task.GetInput(mean_overide=1, \n                                                         var_overide=0.25)\n\n\n    plt.figure(sample_ix)\n    plt.plot(sample_trial.cpu()[:,0])\n    plt.plot(sample_trial.cpu()[:,1])\n    plt.plot(sample_trial.cpu()[:,2])\n    plt.plot(sample_trial.cpu()[:,3])\n    plt.legend(["context 1", "context 2", "GO 1", "GO 2"])\n    plt.title("Goal: " + str(sample_target.item()))'

## Write a training loop

In [None]:
PRINT_EVERY = 10

class BPTT(RNN):
    '''create a trainer object that will be used to trian an RNN object'''
    def __init__(self, hyperParams, task, lr=5e-4):
        super(BPTT, self).__init__(hyperParams, task)
        '''
        description of parameters:

        '''

        self._num_epochs = 2_000
        self._learning_rate = lr   
        self._hParams["learning_rate"] = self._learning_rate
        
        # cast as PyTorch variables
        self._J['in'] = Variable(self._J['in'], requires_grad=True)
        self._J['rec'] = Variable(self._J['rec'], requires_grad=True)
        self._J['out'] = Variable(self._J['out'], requires_grad=True)
        
        #self._params = [self._J['in'], self._J['rec'], self._J['out']]
        self._params = [self._J['rec']]
        self._optimizer = torch.optim.Adam(self._params, lr=self._learning_rate)


        
        self.all_losses = []
        
        self._hidden = Variable(self._hidden, requires_grad=True)


    def loss_fn(self, target, output, hidden_states=0):

        # extract start and end of output
        y_end = output[-1]
        y_strt = output[0]

        # use loss from Mante 2013
        squareLoss = (y_end-torch.sign(target.T))**2 + (y_strt - 0)**2
        SquareLoss = torch.sum( squareLoss, axis=0 )
        return SquareLoss  


    def train_one_batch(self, input, trial, condition):
        '''
        trains a model on a single batch of trials

        Parameters
        ----------
        input : PyTorch A Tensor
            Inputs sequence to be fed into RNN. Has shape (batchSize, sequence_len, inputSize)
        
        Returns
        -------
        output : PyTorch CUDA Tensor
        '''

        batch_size, inpt_seq_len, input_size = input.shape      # input dimensions

        #create an activities tensor for the rnn_model
        self._optimizer.zero_grad()
        #self.StoreRecMag()

        output = torch.zeros((self._task.N, batch_size*self._outputSize)) 
        output_temp = torch.Tensor([0])

        trial_length = self._task.N
        #hidden_states = torch.zeros((batch_size, inpt_seq_len, self._hiddenSize), requires_grad=True)
        for i in range(trial_length): 
            inputNow = input[:,i,:].t()
            output_temp, hidden = self._forward(inputNow)           #I need to generalize this line to work for context task
            #hidden_states[:,i,:] = hidden.T                         # (batch_size, hidden_size)
            #output_temp, hidden = self.rnn_model.forward(input[:,i], hidden, dt)             #this incridebly hacky must improve data formatting accross all modules to correctly implement a context task that doesn't clash with DM task
            output[i] = np.squeeze(output_temp)
            if (i %10 == 0):
                activityIX = int(i/10)
                self._activityTensor[self.trial_count, activityIX, :] = np.squeeze(torch.tanh(self._hidden).cpu().detach().numpy())[:,0]
        self.trial_count += 1
        # self.activity_tensor[trial, i, :] = hidden.detach().numpy()  # make sure calling detach does not mess with the backprop gradients (I think .data does)
        # https://pytorch.org/docs/stable/autograd.html
        #pdb.set_trace()
        loss = self.loss_fn(condition, output)
        loss.backward()
        self._optimizer.step()
        
        return output, loss.item()

    
    def getBatch(self):
        x_batch = torch.zeros((self._batchSize, self._task.N, self._inputSize))
        #x_batch = torch.zeros((750, self._inputSize, self._batchSize))
        y_batch = torch.zeros(self._batchSize)
        for dataPtIX in range(self._batchSize):
            inpt, condition = self._task.GetInput()
            x_batch[dataPtIX,:,:] = inpt
            y_batch[dataPtIX] = condition
        return x_batch, y_batch


    def train(self, termination_accuracy=0.9):
        self._startTimer()
        # pre-generate a set of validation trials that will be constant throughout training
        self.createValidationSet()
        # create activity tensor
        self._activityTensor = np.zeros((self._num_epochs, int(self._task.N/10), self._hiddenSize))

        
        # inps_save = np.zeros((num_epochs, trial_length))
        self.trial_count=0
        self.targets = []   #will hold target output for each trial
        
        validation_accuracy = 0.0
        validation_acc_hist = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
        meanLossHist = 100*np.ones(20)
        # empty list that will hold history of validation accuracy
        loss_hist = []
        
        # for CUDA implementation
        inpt = Variable(torch.zeros(int(self._batchSize), self._task.N, self._inputSize).to(self._device))
        
        #trial = 0
        loss = np.inf
        
        # start main training loop
        while(validation_accuracy < termination_accuracy):
            # terminate training after maximum allowed epochs
            if self.trial_count >= self._num_epochs:
                break
            
            # periodically prints training status
            if self.trial_count %PRINT_EVERY == 0:
                print('trial #:', self.trial_count)
                print('validation accuracy', validation_accuracy)
                print("validation history", validation_acc_hist)
            #print('loss', loss)\
            
            
            # generate a batch of training trials
            # I should move the getBatch method to the data class
            inpt[:], condition = self.getBatch()    # inpt has shape 750x1
            self.targets.append(condition[-1].item())
            condition = Variable(condition)
            
            # train model on this data    
            self._init_hidden()   # resets the hidden state for new trials
            # train the network on this batch of trials
            output, loss = self.train_one_batch(inpt, self.trial_count, condition)
            
            # append current loss to history
            self.all_losses.append(loss)

            validation_accuracy_curr = self.GetValidationAccuracy()
            loss_hist.append(1-validation_accuracy_curr)
            #print('loss hist', np.mean(np.diff(meanLossHist)))
            validation_acc_hist[:9] = validation_acc_hist[1:]
            validation_acc_hist[-1] = validation_accuracy_curr
            meanLossHist[:19] = meanLossHist[1:]
            meanLossHist[-1] = np.mean(self.all_losses[-20:])
            validation_accuracy = np.min(validation_acc_hist)
                 
            
            # save the model every 100 trials
            # why am I saving the model twice?
            if self.trial_count %100 == 0:
                self.saveProgress()
            
        self._targets = np.array(self.targets)        #hacky
        self._losses = np.array(self.all_losses)#self.all_losses      #also hacky
        self._activityTensor = self._activityTensor[:self.trial_count,:,:]
        print('shape of activity tensor', self._activityTensor.shape)  
        print('trial count', self.trial_count)
        self._endTimer()


    def saveProgress(self):
        self._targets = self.targets        #hacky
        self._losses = np.array(self.all_losses)#self.all_losses      #also hacky
        #self.rnn_model.activity_tensor = self.rnn_model.activity_tensor[:self.trial_count,:,:]
        self.save()
        print('model back-ed up')

In [None]:
hyper_params = {                  # dictionary of all RNN hyper-parameters
   "inputSize" : 4,
   "hiddenSize" : 50,
   "outputSize" : 1,
   "g" : 1 ,
   "inputVariance" : 0.5,
   "outputVariance" : 0.5,
   "biasScale" : 0,
   "initScale" : 0.3,
   "dt" : 0.1,
   "batchSize" : 500,
   "taskMean" : 0.1857,
   "taskVar" : 1,
   "ReLU" : 0
   }

my_rnn = BPTT(hyper_params, training_task)

In [None]:
my_rnn.setName("tmp")
import time

In [None]:
my_rnn.train()


sys.settrace() should not be used when the debugger is being used.
This may cause the debugger to stop working correctly.
If this is needed, please check: 
http://pydev.blogspot.com/2007/06/why-cant-pydev-debugger-work-with.html
to see how to restore the debug tracing back correctly.
Call Location:
  File "/usr/lib/python3.7/bdb.py", line 332, in set_trace
    sys.settrace(self.trace_dispatch)



> <ipython-input-2-1925a9166785>(120)createValidationSet()
-> self.validationData = torch.zeros(test_iters, self._inputSize,  self._task.N).to(self._device)
(Pdb) c



sys.settrace() should not be used when the debugger is being used.
This may cause the debugger to stop working correctly.
If this is needed, please check: 
http://pydev.blogspot.com/2007/06/why-cant-pydev-debugger-work-with.html
to see how to restore the debug tracing back correctly.
Call Location:
  File "/usr/lib/python3.7/bdb.py", line 343, in set_continue
    sys.settrace(None)



Validation dataset created!

trial #: 0
validation accuracy 0.0
validation history [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]




trial #: 10
validation accuracy 0.504
validation history [0.511, 0.5125, 0.5095, 0.508, 0.504, 0.505, 0.507, 0.507, 0.5045, 0.5045]
trial #: 20
validation accuracy 0.4915
validation history [0.5035, 0.503, 0.4995, 0.503, 0.5005, 0.495, 0.4915, 0.493, 0.497, 0.4955]
trial #: 30
validation accuracy 0.498
validation history [0.498, 0.5045, 0.5, 0.5005, 0.504, 0.508, 0.5095, 0.521, 0.5155, 0.522]

Program interrupted. (Use 'cont' to resume).
> <ipython-input-2-1925a9166785>(193)_UpdateHidden()
-> self._hidden = hidden_next        # updates hidden layer
--KeyboardInterrupt--
--KeyboardInterrupt--
--KeyboardInterrupt--


## Visualize the Results

In [None]:
# view training curves

In [None]:
# sample some behavior

In [None]:
# view some trajectories in state-space

In [None]:
# find attractor states