# Implementing, training, and analyzing recurrent neural networks, Notebook 1

### Brandon McMahan & Jonathan Kao

### For the CalTech DataSAI Summer School, July 15, 2022.

These notebooks will walk you through implementing, training, and analyzing a recurrent neural network for performing the context-dependent integration (CDI) task in _Mante V, Sussillo D, Shenoy KV, Newsome WT. Context-dependent computation by recurrent dynamics in prefrontal cortex. Nature. 2013;503: 78–84._

This is notebook 1, where you will implement the RNN architecture, initialization, and forward pass.

# Overview

In [37]:
# import the necessary modules
from random import randrange

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable   # deprecated no-op wrapper, kept for the legacy code below

# 1. Brief Introduction To Working in PyTorch


## Automatic Differentiation 

A short lecture on how PyTorch enables the automatic computation of derivatives, followed by an exercise in computing the derivatives of a 'perceptron-like' function.


In [12]:
# define a function and backpropagate through it to get the derivatives
# we will compute the function f(x) = 3x^2 + 6x - 5
# and then evaluate its derivative at x = 0 and x = -1, i.e. f'(0) and f'(-1)

x = torch.tensor([0., -1.], requires_grad=True)
f_x = 3*x**2 + 6*x - 5
f_x.backward(gradient=torch.tensor([1.,1.]))

# now check that the gradients are correct
print("f'(0) ={0}\nf'(-1)={1}".format(x.grad[0], x.grad[1]))


f'(0) =6.0
f'(-1)=0.0
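As a quick sanity check, the sketch below recomputes the same derivatives in a slightly different way. It reduces `f_x` to a scalar with `.sum()` before calling `backward()`, which gives the same gradients as passing `gradient=torch.tensor([1.,1.])`, and compares the result with the analytic derivative f'(x) = 6x + 6.

```python
import torch

x = torch.tensor([0., -1.], requires_grad=True)
f_x = 3 * x**2 + 6 * x - 5

# Each x[i] only affects f_x[i], so d(sum f)/dx[i] = f'(x[i]).
# This is equivalent to f_x.backward(gradient=torch.ones_like(f_x)).
f_x.sum().backward()

analytic = 6 * x.detach() + 6   # f'(x) = 6x + 6 from the power rule
print(x.grad, analytic)
```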


### Now it's your turn
Try implementing the function below and using automatic differentiation to compute its derivatives.

In [68]:
## You will implement a function and use the PyTorch autograd library to compute
## its derivatives.
## 
## First, define the following function using PyTorch tensors:
##
## h(x, y) = g(f(x, y)) where,
## g(z) = tanh(z)
## f(x,y) = xy 
## where x and y are matrices and xy denotes matrix multiplication
##
## HINT: to multiply two matrices in PyTorch you may use x@y or torch.matmul(x,y)
x = torch.tensor([[.7,-.3,-.4],[-.9,.2,.3]], requires_grad=True)  # x is a 2x3 matrix
y = torch.tensor([[1.],[0.],[1.]], requires_grad=True)      # y is a 3x1 matrix

## TODO: implement the function to compute h(x,y)

## BEGIN SOLUTION
f_x_y = torch.matmul(x, y)
h_x_y = torch.tanh(f_x_y)
## END SOLUTION

print(h_x_y)

## TODO: use the autograd library to get the gradients dh/dx and dh/dy

## BEGIN SOLUTION
# h_x_y is not a scalar, so backward() needs a gradient tensor of the same shape (2x1)
h_x_y.backward(gradient=torch.tensor([[1.],[1.]]))
## END SOLUTION
print(x.grad)
print(y.grad)


tensor([[ 0.2913],
        [-0.5370]], grad_fn=<TanhBackward0>)
tensor([[0.9151, 0.0000, 0.9151],
        [0.7116, 0.0000, 0.7116]])
tensor([[ 0.0002],
        [-0.1322],
        [-0.1526]])
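To see where these numbers come from, the sketch below derives the same gradients by hand with the chain rule. With z = xy and h = tanh(z), we have dh/dz = 1 - h², so dh/dx = (1 - h²) yᵀ and dh/dy = xᵀ (1 - h²). The code simply compares these expressions with autograd.

```python
import torch

x = torch.tensor([[.7, -.3, -.4], [-.9, .2, .3]], requires_grad=True)
y = torch.tensor([[1.], [0.], [1.]], requires_grad=True)
h = torch.tanh(x @ y)
h.backward(gradient=torch.ones_like(h))

# chain rule: dh/dz = 1 - tanh(z)^2 = 1 - h^2, where z = xy
dh_dz = 1 - h.detach() ** 2                  # shape (2, 1)
grad_x_analytic = dh_dz @ y.detach().T       # (2, 1) @ (1, 3) -> (2, 3)
grad_y_analytic = x.detach().T @ dh_dz       # (3, 2) @ (2, 1) -> (3, 1)
```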


## How to use subclassing to create neural network models in PyTorch

A brief introduction to creating models by subclassing in PyTorch.

In short, we use the following boilerplate code:


1.   Our neural network module inherits from the PyTorch nn.Module class.
2.   We initialize our model with all of its weights in __init__, wrapping each weight in nn.Parameter so that PyTorch registers it as trainable.
3.   We define the forward pass in forward.

After this, we have a neural network model that we can use for training and inference.



In [57]:
class ExampleNeuralNetwork(nn.Module):
  def __init__(self):
    super(ExampleNeuralNetwork, self).__init__()
    # weight matrix mapping 28 inputs to 100 outputs; nn.Parameter registers it
    # with the module so that my_net.parameters() returns it for training
    self.W = nn.Parameter(torch.randn((100, 28)))
    # bias vector
    self.b = nn.Parameter(torch.randn((1, 100)))

  def forward(self, x):
    z = torch.matmul(x, self.W.T)  # (batch_size, 100)
    z = z + self.b
    return F.relu(z)
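Why `nn.Parameter` matters: a tensor stored as a plain attribute (which is also what `Variable(...)` returns in current PyTorch) is not registered with the module, so `model.parameters()` does not return it and an optimizer would never update it. The two hypothetical classes below differ only in that one line.

```python
import torch
import torch.nn as nn


class PlainTensorNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.W = torch.randn(100, 28)                 # plain tensor: NOT registered


class ParameterNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.W = nn.Parameter(torch.randn(100, 28))   # registered as a trainable parameter


n_plain = len(list(PlainTensorNet().parameters()))
n_param = len(list(ParameterNet().parameters()))
print(n_plain, n_param)
```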



### Now it's your turn

Now try building a simple two-layer feedforward network with a ReLU activation and 10 output classes.

You will use two layers. Each layer will have its own weight and bias. The first layer should map an input from 28 to 100 dimensions. The outputs of this layer should use a ReLU activation.

The second layer should map its input from 100 down to 10 dimensions. This layer should use a softmax activation.

In [29]:
class MyNeuralNetwork(nn.Module):
  def __init__(self):
    super(MyNeuralNetwork, self).__init__()
    ## TODO: define and initialize all the weights and biases you will need
    pass
  def forward(self, x):
    ## TODO: define the forward pass through the network
    pass
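One possible solution to the exercise, as a sketch. The class name `MyNeuralNetworkSolution` and the 0.1 scale of the random initialization are our own choices, not part of the exercise. Each weight is wrapped in `nn.Parameter` so an optimizer can find it, and the softmax is taken over the class dimension so each row of the output is a probability distribution.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class MyNeuralNetworkSolution(nn.Module):
    def __init__(self):
        super().__init__()
        # layer 1: 28 -> 100 (initialization scale 0.1 is an arbitrary choice)
        self.W1 = nn.Parameter(0.1 * torch.randn(100, 28))
        self.b1 = nn.Parameter(torch.zeros(100))
        # layer 2: 100 -> 10
        self.W2 = nn.Parameter(0.1 * torch.randn(10, 100))
        self.b2 = nn.Parameter(torch.zeros(10))

    def forward(self, x):
        # x has shape (batch_size, 28)
        h = F.relu(x @ self.W1.T + self.b1)              # (batch_size, 100)
        return F.softmax(h @ self.W2.T + self.b2, dim=1)  # (batch_size, 10), rows sum to 1


torch.manual_seed(0)
net = MyNeuralNetworkSolution()
probs = net(torch.randn(10, 28))
```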

## Training and Inference with Neural Networks
Below, we create an instance of the example network and pass a batch of random data forward through it (inference).

In [58]:
# create an instance of our neural network
my_net = ExampleNeuralNetwork()

# forward pass some random data through our network
rand_output = my_net(torch.randn((10, 28)))

print(rand_output.shape)

torch.Size([10, 100])
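The cell above only runs inference. A minimal training loop is sketched below under illustrative assumptions: synthetic data labelled by a random linear "teacher", an `nn.Sequential` model built from `nn.Linear` layers instead of the hand-written weights above, cross-entropy loss, and Adam. None of the data or hyperparameters come from this notebook. Note that `nn.CrossEntropyLoss` expects raw scores (logits) and applies the log-softmax itself.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# illustrative toy data: 256 random 28-d inputs labelled by a fixed random linear "teacher"
X = torch.randn(256, 28)
teacher = torch.randn(28, 10)
y = (X @ teacher).argmax(dim=1)

model = nn.Sequential(nn.Linear(28, 100), nn.ReLU(), nn.Linear(100, 10))
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)

losses = []
for step in range(200):
    optimizer.zero_grad()          # clear gradients left over from the previous step
    logits = model(X)              # forward pass
    loss = criterion(logits, y)    # compare predictions with labels
    loss.backward()                # backpropagate: fills .grad of every parameter
    optimizer.step()               # update the parameters
    losses.append(loss.item())

with torch.no_grad():              # inference: no computation graph needed
    accuracy = (model(X).argmax(dim=1) == y).float().mean().item()
```

The four calls `zero_grad`, forward, `backward`, `step` are the core of every training loop in this notebook, including the RNN training loop later on.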


# 2. Implementation: RNN architecture and forward pass.

This part of the notebook will guide you through implementing the architecture and forward pass of the recurrent neural network. As we discussed in the Methods introduction, we do not have to calculate the backward pass ourselves, thanks to PyTorch's automatic differentiation. (TensorFlow also has this capability.)


## Now we are ready to create an RNN

We will implement each of the functions separately as code blocks and then stitch everything together at the end into a single RNN class.

### RNN initialization and forward pass

While we will later make an RNN class with all these functions, we wanted to separate out the implementation of these functions at first so you can debug each one on its own.

Please first complete this notebook, which implements the architecture, initialization, and forward pass of the RNN. After this, we will provide our RNN class with these functions.

Also note that while the example above used a single function to pass the inputs forward, here we will define three functions: an ```updateHidden``` function that directly updates the hidden state, a ```forward``` function that passes a single input forward through the RNN, and a ```feed``` function that passes an entire input sequence forward through the RNN (in the class, this will also handle a mini-batch of sequences).

#### Network initialization

In the following cell, we initialize all variables that will be used to compute the RNN forward pass.



In [59]:
## Input hyperparameters

inputSize = 2       # number of inputs
hiddenSize = 5      # number of artificial neurons
outputSize = 1      # number of outputs

## Later on, self.device in the RNN class will be set to either "cpu" (to do all computations on the cpu) 
## or "cuda:0" (to do computations on GPU).

device = "cpu"

## Initialization hyperparameter for W_rec, usually g = 1.

g = 1

## Hyperparameters for the standard deviation of the weight initialization for the input and output weight matrices.

inputSigma = 0.5      # input weights standard deviation
outputSigma = 0.1     # output weights standard deviation
biasScale = 0         # value of the biases


Next, you will implement the initialization of all the weights and biases.

In [60]:
## Initialize W["in"], W["rec"], W["out"], and the bias. These will be stored in the dictionary W.

W = {}
f = torch.tanh

## TODO:
##
##   Initialize all the weights and store them in the dictionary W.
##
##   The RNN equation is:
## 
##     x_{t+1} = (1-alpha) * x_t + alpha * (W_{rec} * f(x_t) + W_{in} * u_t + bias)
##
##   Here, the matrices are:
##     W_{in} is W["in"]
##     W_{rec} is W["rec"]
##     W_{out} is W["out"]
##     bias is W["bias"]
##     The activation function is f
##
##   The initializations should be:
##     W_{in} ~ N(0, inputSigma^2)
##     W_{rec} ~ N(0, (g**2 / hiddenSize)^2)
##     W_{out} ~ N(0, outputSigma^2)
##     bias: a constant vector with every entry equal to biasScale (no random draw)
##
##   Note 1: these should all be Torch tensors (not numpy arrays!)
##   Note 2: for this implementation, bias should have shape (hiddenSize,).
##
##   To make all of these CUDA compatible (i.e., able to run on GPU), append ".to(device)" after each one.
##   While these will run on the CPU here, later on this will let the RNN class run on the GPU if self.device is set to the GPU.
##    
##   EXAMPLE:
##
##     W['in'] = 0.5 * torch.randn(hiddenSize, inputSize)
##       becomes
##     W['in'] = (0.5 * torch.randn(hiddenSize, inputSize)).to(device)
##
##   Lastly, because we check your initialization against fixed values, make exactly
##   three calls to torch.randn, in this order: one each for W_{in}, W_{rec}, and W_{out}.
##   (In the RNN class, these variables become attributes such as self.hiddenSize and self.W.)

# Set the random seed so our results are comparable.

torch.manual_seed(0)

## BEGIN SOLUTION
W = {
    'in' : (inputSigma * torch.randn(hiddenSize, inputSize)).to(device),
    'rec' : ((g**2 / hiddenSize) * torch.randn(hiddenSize, hiddenSize)).to(device),
    'out' : (outputSigma * torch.randn(outputSize, hiddenSize)).to(device),
    'bias' : (biasScale * torch.ones(hiddenSize)).to(device),
    }

## END SOLUTION


#### Initialization implementation check

Run the following code to check your initialization dimensions and values. Because we set the random seed, you should get the same results. 

All the assertions should pass if your initializations are correct.

In [62]:
# Check on dimensions and types

assert W["in"].shape == (hiddenSize, inputSize), "Dimensions of W_in are incorrect."
assert W["rec"].shape == (hiddenSize, hiddenSize), "Dimensions of W_rec are incorrect."
assert W["out"].shape == (outputSize, hiddenSize), "Dimensions of W_out are incorrect."
assert W["bias"].shape == (hiddenSize,), "Dimensions of bias are incorrect."
assert torch.is_tensor(W["in"]), "W_in is not a torch tensor"
assert torch.is_tensor(W["rec"]), "W_rec is not a torch tensor"
assert torch.is_tensor(W["out"]), "W_out is not a torch tensor"
assert torch.is_tensor(W["bias"]), "bias is not a torch tensor"


In [63]:
# Check on values

assert torch.sum(torch.abs(W['in'] - torch.tensor((( 0.7705, -0.1467), (-1.0894,  0.2842), (-0.5423, -0.6993), ( 0.2017,  0.4190), (-0.3596, -0.2017))))) <= 1e-2, "W_in initialization is not correct"
assert torch.sum(torch.abs(W['rec'] - torch.tensor((( 0.1198, -0.3110, -0.0683,  0.3706,  0.0936), (-0.0315,  0.2887,  0.0532,  0.2779, -0.1357), ( 0.1877,  0.0978, -0.1346,  0.1746,  0.2111), ( 0.0356, -0.0461, -0.0613, -0.3162,  0.3413), (-0.0892,  0.1488,  0.3042,  0.6821, -0.3062))))) <= 1e-2, 'W_rec initialization is not correct'
assert torch.sum(torch.abs(W['out'] - torch.tensor(((-0.0531, -0.0431, -0.2286,  0.0070,  0.0667))))) < 1e-2, 'W_out initialization is not correct'


### Implement hidden state update for one time step

In this section, you will write a function, "hidden_next = updateHidden(hidden_current, inpt)", that takes the current hidden state and one set of RNN inputs and computes the next hidden state. By hidden state we mean the pre-activation hidden state, x, and not f(x) (where f is the nonlinearity). Computing the next hidden state requires the previous one; in the RNN class it will be stored as "self.hidden", but here we pass it in explicitly as hidden_current. This update is for ONE time step and one input. We will later use these functions to calculate all the hidden states and outputs across all time steps.

In [64]:
## You will implement the updateHidden function.
##
##     x_{t+1} = (1-alpha) * x_t + alpha * (W_{rec} * f(x_t) + W_{in} * u_t + bias)
##
## Assume that alpha = 1/self_dt, with self_dt = 10 (so alpha = 0.1).

def updateHidden(hidden_current, inpt):
    '''Updates the hidden state of the RNN for one time step.
        
        Parameters
        ----------
        hidden_current: PyTorch Tensor
            current pre-activation hidden state. Has shape (hiddenSize,) or (hiddenSize, batchSize)
        inpt: PyTorch Tensor
            inputs to the network at the current time step. Has shape (inputSize,) or (inputSize, batchSize)

        Returns
        -------
        hidden_next: PyTorch Tensor
            updated pre-activation hidden state. Has the same shape as hidden_current

    '''
    
    self_dt = 10
    alpha = 1/self_dt
    
    ## TODO: Implement one time step to update the hidden activation.
    
    ## BEGIN SOLUTION

    # the bias has shape (hiddenSize,); give it a trailing axis when the inputs are batched
    bias = W['bias'] if inpt.dim() == 1 else W['bias'][:, None]

    # leaky-integrator update of the pre-activation hidden state
    hidden_next = (1-alpha)*hidden_current + \
        alpha*(torch.matmul(W['rec'], f(hidden_current)) + torch.matmul(W['in'], inpt) + bias)
    
    ## END SOLUTION
    
    return hidden_next
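The update rule in updateHidden can be read as a forward-Euler discretization of a continuous-time leaky RNN. As a sketch, assume a time constant $\tau$ and a step size $\Delta t$, and identify $\alpha$ with $\Delta t / \tau$ (the notebook simply sets alpha = 1/self_dt; the $\tau$, $\Delta t$ naming is our assumption):

$$\tau \frac{dx}{dt} = -x + W_{rec} f(x) + W_{in} u + b$$

$$x_{t+1} = x_t + \frac{\Delta t}{\tau}\Big(-x_t + W_{rec} f(x_t) + W_{in} u_t + b\Big) = (1-\alpha)\, x_t + \alpha\Big(W_{rec} f(x_t) + W_{in} u_t + b\Big), \qquad \alpha = \frac{\Delta t}{\tau}$$

With $\alpha = 0.1$, each step moves the hidden state 10% of the way toward the current drive $W_{rec} f(x_t) + W_{in} u_t + b$; a smaller $\alpha$ gives slower, smoother dynamics.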
    

#### updateHidden implementation check

The following runs a basic check of your implementation of updateHidden: it makes sure the function runs and returns a hidden state of the correct shape.

Be sure that W is the one that passed all the prior implementation checks.

In [65]:
## code to check updateHidden

torch.manual_seed(0)

hidden_current = torch.randn((hiddenSize,))
inpt = torch.randn((inputSize,))
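hidden_next = updateHidden(hidden_current, inpt)

# basic check: the update must preserve the shape of the hidden state
assert hidden_next.shape == (hiddenSize,), "updateHidden returned a hidden state with the wrong shape."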


### Compute output for one time step

Next, we will compute the output of the RNN for an input. You should use the updateHidden() function here. Be sure your updateHidden() passed the implementation check.

Computing the output given the inputs is typically called the forward pass, and so this function is called "forward".


In [66]:
def forward(hidden_current, inpt):
    '''
    Computes the RNN's forward-pass output for a single time step.

    Parameters
    ----------
    hidden_current: PyTorch Tensor
        current pre-activation hidden state. Has shape (hiddenSize, batchSize)
    inpt: PyTorch Tensor 
        inputs to the network for the current time step. Has shape (inputSize, batchSize)

    Returns
    -------
    output: PyTorch Tensor
        output of the network after this time step. Has shape (outputSize, batchSize)
    hidden_next: PyTorch Tensor
        copy of the hidden state after the forward pass. Has shape
        (hiddenSize, batchSize)

    '''
    
    # assert that inputs are torch tensors 
    assert torch.is_tensor(hidden_current), "Current hidden state is not a torch tensor."
    assert torch.is_tensor(inpt), "Input is not a torch tensor."


    ## BEGIN SOLUTION
    hidden_next = updateHidden(hidden_current, inpt)
    output = torch.matmul(W['out'], hidden_next)

    ## END SOLUTION

    return output, hidden_next.clone()

    

### Now compute all hidden states and outputs for all time steps

We will now use the forward() function you implemented to compute the forward pass for all time steps. We will call this function "feed()".

In [67]:
T = 10

torch.manual_seed(0)
inpt_data = torch.randn((inputSize, T)) # of dimensions (inputSize, Time)

## Note that in the class, our feed function will also be able to take batches of data,
## but here we feed a single trial (no batch dimension).

def feed(inpt_data):
    '''
    Feeds an input data sequence into an RNN

    Parameters
    ----------
    inpt_data : PyTorch Tensor
        Input sequence to be fed into the RNN. Has shape (inputSize, T)
        
    Returns
    -------
    output_trace : PyTorch Tensor
        output of the network over all time steps. Has shape (T,)
        for a single trial with a 1D output
    hidden_states : PyTorch Tensor
        pre-activation hidden states of the network through the trial. Has shape
        (hiddenSize, T)

    '''

    assert inpt_data.shape[0] == inputSize, "Size of inputs:{} does not match network's input size:{}".format(inpt_data.shape[0], inputSize)
    T = inpt_data.shape[1]

    # In the class, the hidden size and output size are stored as attributes (e.g. self.hiddenSize).
    # Here we use the hiddenSize defined above (5 in our example) and assume a 1D output.
    
    output_trace = torch.zeros(T)
    hidden_states = torch.zeros((hiddenSize, T))
    hidden_current = torch.zeros((hiddenSize,))
    
    ## BEGIN SOLUTION
    for t_step in range(T):
        output, hidden_current = forward(hidden_current, inpt_data[:, t_step])
        output_trace[t_step] = output.squeeze()     # output has shape (outputSize,) = (1,)
        hidden_states[:, t_step] = hidden_current

    ## END SOLUTION
    
    return output_trace, hidden_states

## Import the RNN class

Congratulations on implementing the functions needed to build an RNN! We will now import the RNN class which just contains all the functions you have implemented. You can check out the RNN class in rnn.py if you want to see more details. For now, just understand that we have just copied the functions implemented above into the RNN class.

In [69]:
from rnn import RNN


# 3. Building the Training Data Pipeline

Now that the RNN weights and forward pass have been implemented, we are going to create our training data. We will train these RNNs to perform the Context Dependent Integration (CDI) task as described in Mante and Sussillo 2013.

## First model the task

In [None]:
def _random_generate_input(self, mean, trial_prob=0.5):
    '''
    Generates a 1D signal for one context with the provided mean magnitude.
    The chance of a positive trial is given by trial_prob. Noise is added
    later, in GetInput.
    
    Parameters
    ----------
    mean : float
      magnitude of the context signal
    trial_prob : float
      the probability that this trial should be positive
    
    Returns
    -------
    context_input : torch tensor
      context signal with dimensions (self.N,), i.e. (750,) by default
    '''
    ## TODO: implement this
    context_input = None

    ## BEGIN SOLUTION
    
    # with probability trial_prob generate a positive trial, otherwise a negative one
    if torch.rand(1).item() < trial_prob:
        context_input = mean*torch.ones(self.N)
    else:
        context_input = -mean*torch.ones(self.N)
    ## END SOLUTION

    return context_input

In [None]:
def GetInput(self, mean_overide=None, var_overide=None):
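        '''
        Generates a single trial of data for the CDI task. Returns the
        perceptual data (inpts) and the label (target).

        Parameters
        ----------
        mean_overide : float, optional
            if given, the signal magnitude used for this trial. The default
            (None) samples it uniformly from [0, max_mean).
        var_overide : float, optional
            if given, the noise scale used for this trial. The default
            (None) uses self._var.

        Returns
        -------
        inpts : PyTorch Tensor
            inputs on self._device. Has shape (N, 2*N_INPUT_CHANNELS): the noisy
            context signals followed by the GO channels (the GO channel set to 1
            indicates which context to report).
        target : PyTorch Tensor
            +1 or -1, the sign of the signal in the cued context.
        '''
        inpts = torch.zeros((self.N, 2*self.N_INPUT_CHANNELS)).to(self._device)
        
        # sample the trial's signal magnitude from U[0, max_mean); its sign is
        # chosen at random in _random_generate_input
        mean = torch.rand(1).item() * self._max_mean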

        # allows caller to use a deterministic mean for this trial
        if mean_overide is not None:
            mean = mean_overide

        # allows caller to override the default noise
        if var_overide is not None:
            var = var_overide
        else:
            var = self._var  

        # randomly sets one of the channels to be on
        go_channel = randrange(self.N_INPUT_CHANNELS)
        for channel_num in range(self.N_INPUT_CHANNELS):
            inpts[:,channel_num] = self._random_generate_input(mean)
            # generate the GO signals
            if go_channel == channel_num:
                inpts[:, self.N_INPUT_CHANNELS + channel_num] = 1
                target = torch.sign(torch.mean(inpts[:, channel_num]))
            else:
                inpts[:, self.N_INPUT_CHANNELS + channel_num] = 0

        
        # adds noise to inputs
        inpts[:,:self.N_INPUT_CHANNELS] += var*torch.randn(self.N, 
            self.N_INPUT_CHANNELS).to(self._device)
        return inpts, target
      

## Import the task class
Now that you have implemented the code to model the training task, we can use the ContextDependentIntegration task class, which uses this code to generate training data for the RNN class.



# 4. Writing a training loop
Now we can use the RNN class and PyTorch to write a loop that trains our RNN. We will implement the following functions:

## loss_fn
This function will compute the loss between the RNN outputs and the target outputs for the trial
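Written out, the loss computed by loss_fn below sums, over the trials $b$ of a mini-batch, a squared error at the last time step and a penalty on the output at the first time step. The notation ($y_b^{\text{start}}$ and $y_b^{\text{end}}$ for the first and last outputs of trial $b$, $c_b$ for its target) is ours:

$$\mathcal{L} = \sum_{b} \Big[\big(y_b^{\text{end}} - \operatorname{sign}(c_b)\big)^2 + \big(y_b^{\text{start}}\big)^2\Big]$$

For example, a trial with target $+1$ whose output starts at $0.1$ and ends at $0.8$ contributes $(0.8 - 1)^2 + 0.1^2 = 0.04 + 0.01 = 0.05$ to the loss.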

## train_one_batch
This function will train the RNN on a single mini-batch of trials.

## train
This function will pull everything together and train the RNN by looping over mini-batches of trials until the validation accuracy reaches a target threshold (90% by default).

In [None]:
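# implement the loss function
def loss_fn(target, output):
    '''
    Squared-error loss on the first and last time steps of the output.
    output has shape (N, batchSize*outputSize); target holds the label of each trial.
    '''
    # extract the first and last time steps of the output
    y_end = output[-1]
    y_start = output[0]

    # use loss from Mante 2013: the final output should match the sign of
    # the target, and the initial output should be 0
    square_loss = (y_end - torch.sign(target.T))**2 + (y_start - 0)**2
    return torch.sum(square_loss, dim=0)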

In [None]:
# implement the train-one-batch function

def train_one_batch(rnn, inpt, trial, condition):
    '''
    Trains the RNN on a single mini-batch of trials and takes one optimizer step.

    Parameters
    ----------
    rnn : RNN
        network to train; its hidden state should already be reset
    inpt : PyTorch Tensor
        input sequences fed into the RNN. Has shape (batchSize, sequence_len, inputSize)
    trial : int
        index of this mini-batch, used to store hidden-state activity
    condition : PyTorch Tensor
        target labels for the trials in this mini-batch
    
    Returns
    -------
    output : PyTorch Tensor
        network outputs over time. Has shape (sequence_len, batchSize*outputSize)
    loss : float
        loss on this mini-batch
    '''

    batch_size, inpt_seq_len, input_size = inpt.shape      # input dimensions
    trial_length = rnn._task.N

    output = torch.zeros((trial_length, batch_size*rnn._outputSize))
    for i in range(trial_length):
        inputNow = inpt[:, i, :].t()                # (inputSize, batchSize)
        output_temp, hidden = rnn._forward(inputNow)
        output[i] = output_temp.squeeze()
        # every 10 time steps, store the post-activation hidden state of the first trial
        if i % 10 == 0:
            activityIX = i // 10
            rnn._activityTensor[trial, activityIX, :] = torch.tanh(rnn._hidden).cpu().detach().numpy()[:, 0]
    rnn.trial_count += 1

    loss = loss_fn(condition, output)
    loss.backward()
    rnn._optimizer.step()
    
    return output, loss.item()

In [None]:
# implement the train function
PRINT_EVERY = 10    # print training status every 10 trials (assumed value, matching the log below)

def train(rnn, termination_accuracy=0.9):
    '''
    Trains the RNN until the minimum validation accuracy over the last 10
    trials reaches termination_accuracy, or until rnn._num_epochs
    mini-batches have been used.

    Parameters
    ----------
    rnn : RNN
        network to train
    termination_accuracy : float
        validation accuracy at which training stops
    '''

    rnn._optimizer = torch.optim.Adam(rnn._params, lr=5e-4)
    # pre-generate a set of validation trials that stays constant throughout training
    rnn.createValidationSet()

    targets = []                        # target output for each trial
    validation_accuracy = 0.0
    validation_acc_hist = [0] * 10      # validation accuracy over the last 10 trials
    meanLossHist = 100*np.ones(20)      # running mean of the training loss over the last 20 trials
    loss_hist = []                      # history of validation error (1 - accuracy)

    # pre-allocate the input buffer on the network's device
    inpt = torch.zeros(int(rnn._batchSize), rnn._task.N, rnn._inputSize).to(rnn._device)

    # main training loop
    while validation_accuracy < termination_accuracy:
        # stop after the maximum number of allowed mini-batches
        if rnn.trial_count >= rnn._num_epochs:
            break

        # periodically print training status
        if rnn.trial_count % PRINT_EVERY == 0:
            print('trial #:', rnn.trial_count)
            print('validation accuracy', validation_accuracy)
            print("validation history", validation_acc_hist)

        # generate a mini-batch of training trials
        inpt[:], condition = rnn.getBatch()
        targets.append(condition[-1].item())

        rnn._init_hidden()              # reset the hidden state for the new trials
        rnn._optimizer.zero_grad()
        output, loss = train_one_batch(rnn, inpt, rnn.trial_count, condition)
        rnn.all_losses.append(loss)

        # training stops only once the accuracy stays above threshold for 10 consecutive trials
        validation_accuracy_curr = rnn.GetValidationAccuracy()
        loss_hist.append(1 - validation_accuracy_curr)
        validation_acc_hist[:9] = validation_acc_hist[1:]
        validation_acc_hist[-1] = validation_accuracy_curr
        meanLossHist[:19] = meanLossHist[1:]
        meanLossHist[-1] = np.mean(rnn.all_losses[-20:])
        validation_accuracy = np.min(validation_acc_hist)

        # save the model every 100 trials
        if rnn.trial_count % 100 == 0:
            rnn.saveProgress()

    rnn._targets = np.array(targets)
    rnn._losses = np.array(rnn.all_losses)
    rnn._activityTensor = rnn._activityTensor[:rnn.trial_count, :, :]
    print('shape of activity tensor', rnn._activityTensor.shape)
    print('trial count', rnn.trial_count)
    rnn._endTimer()


# 5. Visualizing RNN Behavior

## Psychometric Curves
Write a function to generate psychometric curves.

## Visualize Attractor States
Write a function to find the attractor states in the RNN

In [None]:
# class for generating data 
class ContextDependentIntegration():
    '''
    This class simulates the Mante and Sussillo 2013 context-dependent 
    integration task described here:

    https://www.nature.com/articles/nature12742
    '''

    N_INPUT_CHANNELS = 2

    def __init__(self, N=750, mean=0.1857, var=1, device="cuda:0"):
        # determines if code should run on CPU or GPU
        if torch.cuda.is_available() and device=="cuda:0":
            self._device = torch.device(device)
        else:
            self._device = torch.device("cpu")

        self.N = N                 # number of time steps in a trial
        self._max_mean = mean      # maximum magnitude of the signal mean
        self._var = var            # noise scale (multiplies unit-variance Gaussian noise)

    def _random_generate_input(self, mean, trial_prob=0.5):
        '''
        Generates a 1D noisy signal for a context with the provided mean. The 
        chance of a positive trial is given by trial_prob.
        '''
        if torch.rand(1).item() < trial_prob:
            return mean*torch.ones(self.N)
        return -mean*torch.ones(self.N)
        
    def GetInput(self, mean_overide=None, var_overide=None):
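        '''
        Generates a single trial of data for the CDI task. Returns the
        perceptual data (inpts) and the label (target).

        Parameters
        ----------
        mean_overide : float, optional
            if given, the signal magnitude used for this trial. The default
            (None) samples it uniformly from [0, max_mean).
        var_overide : float, optional
            if given, the noise scale used for this trial. The default
            (None) uses self._var.

        Returns
        -------
        inpts : PyTorch Tensor
            inputs on self._device. Has shape (N, 2*N_INPUT_CHANNELS): the noisy
            context signals followed by the GO channels (the GO channel set to 1
            indicates which context to report).
        target : PyTorch Tensor
            +1 or -1, the sign of the signal in the cued context.
        '''
        inpts = torch.zeros((self.N, 2*self.N_INPUT_CHANNELS)).to(self._device)
        
        # sample the trial's signal magnitude from U[0, max_mean); its sign is
        # chosen at random in _random_generate_input
        mean = torch.rand(1).item() * self._max_mean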

        # allows caller to use a deterministic mean for this trial
        if mean_overide is not None:
            mean = mean_overide

        # allows caller to override the default noise
        if var_overide is not None:
            var = var_overide
        else:
            var = self._var  

        # randomly sets one of the channels to be on
        go_channel = randrange(self.N_INPUT_CHANNELS)
        for channel_num in range(self.N_INPUT_CHANNELS):
            inpts[:,channel_num] = self._random_generate_input(mean)
            # generate the GO signals
            if go_channel == channel_num:
                inpts[:, self.N_INPUT_CHANNELS + channel_num] = 1
                target = torch.sign(torch.mean(inpts[:, channel_num]))
            else:
                inpts[:, self.N_INPUT_CHANNELS + channel_num] = 0

        
        # adds noise to inputs
        inpts[:,:self.N_INPUT_CHANNELS] += var*torch.randn(self.N, 
            self.N_INPUT_CHANNELS).to(self._device)
        return inpts, target
      


Add a section on how to build different RNN architectures in Python

Add some intuition for BPTT

Add some intuition on the pros/cons of these methods for neuroscience


## Build an RNN

## Build the Data Loader

In [7]:
# visualize a few trials
training_task = ContextDependentIntegration()

for sample_ix in range(10):
    sample_trial, sample_target = training_task.GetInput(mean_overide=1, 
                                                         var_overide=0.25)

    plt.figure(sample_ix)
    plt.plot(sample_trial.cpu()[:,0])
    plt.plot(sample_trial.cpu()[:,1])
    plt.plot(sample_trial.cpu()[:,2])
    plt.plot(sample_trial.cpu()[:,3])
    plt.legend(["context 1", "context 2", "GO 1", "GO 2"])
    plt.title("Goal: " + str(sample_target.item()))


## Write a training loop

In [8]:
hyper_params = {                  # dictionary of all RNN hyper-parameters
   "inputSize" : 4,
   "hiddenSize" : 50,
   "outputSize" : 1,
   "g" : 1 ,
   "inputVariance" : 0.5,
   "outputVariance" : 0.5,
   "biasScale" : 0,
   "initScale" : 0.3,
   "dt" : 0.1,
   "batchSize" : 500,
   "taskMean" : 0.1857,
   "taskVar" : 1,
   "ReLU" : 0
   }

my_rnn = BPTT(hyper_params, training_task)

In [10]:
my_rnn.setName("tmp")
import time

In [11]:
my_rnn.train()



Validation dataset created!

trial #: 0
validation accuracy 0.0
validation history [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]




trial #: 10
validation accuracy 0.4695
validation history [0.54, 0.5465, 0.5385, 0.5245, 0.504, 0.491, 0.484, 0.4765, 0.4695, 0.47]
trial #: 20
validation accuracy 0.4755
validation history [0.4755, 0.4795, 0.486, 0.489, 0.5, 0.5065, 0.5085, 0.507, 0.509, 0.5085]
trial #: 30
validation accuracy 0.4775
validation history [0.5085, 0.5085, 0.501, 0.4985, 0.49, 0.4895, 0.486, 0.4785, 0.4775, 0.479]
trial #: 40
validation accuracy 0.4785
validation history [0.4785, 0.482, 0.488, 0.4865, 0.484, 0.489, 0.4945, 0.497, 0.5055, 0.507]
trial #: 50
validation accuracy 0.5055
validation history [0.5095, 0.507, 0.5055, 0.5095, 0.5135, 0.5145, 0.517, 0.519, 0.523, 0.529]
trial #: 60
validation accuracy 0.53
validation history [0.5305, 0.53, 0.536, 0.5375, 0.5405, 0.546, 0.552, 0.554, 0.5555, 0.56]
trial #: 70
validation accuracy 0.5605
validation history [0.561, 0.562, 0.5625, 0.563, 0.5605, 0.5615, 0.5685, 0.569, 0.5675, 0.572]
trial #: 80
validation accuracy 0.5745
validation history [0.5745, 0.583


## Visualize the Results

In [None]:
# view training curves
# add code to visualize training accuracy 

In [None]:
# sample some behavior
# add code for psychometric curves
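A psychometric curve plots the fraction of +1 choices against the signed stimulus strength. Below is a minimal sketch that assumes a hypothetical `decide(signed_mean)` callable, which runs one trial at the given signed mean and returns the choice (+1 or -1). For the trained RNN, `decide` might build a trial with `GetInput(mean_overide=...)` and take the sign of the final output, but that wiring is an assumption and is not shown here. The usage below plugs in an illustrative ideal observer that averages 750 noisy samples, not the RNN.

```python
import torch


def psychometric_curve(decide, signed_means, n_trials=200):
    """Fraction of +1 choices at each signed stimulus mean.

    decide: hypothetical callable mapping a signed mean (float) to a choice tensor (+1 or -1).
    """
    p_positive = []
    for m in signed_means:
        choices = torch.stack([decide(float(m)) for _ in range(n_trials)])
        p_positive.append((choices > 0).float().mean().item())
    return torch.tensor(p_positive)


# illustrative ideal observer: sign of the average of 750 samples with unit-variance noise
def ideal_observer(signed_mean, n_steps=750):
    return torch.sign((signed_mean + torch.randn(n_steps)).mean())


torch.manual_seed(0)
signed_means = torch.linspace(-0.1857, 0.1857, 9)
curve = psychometric_curve(ideal_observer, signed_means)
# plt.plot(signed_means, curve) would draw the curve
```

For the CDI task, one would typically draw a separate curve for each context, plotting choices against the relevant and the irrelevant signal.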

In [None]:
# find attractor states
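One common way to locate fixed points, and hence candidate attractors, is to numerically minimize a "speed" function from many initial states, in the spirit of Sussillo and Barak (2013). For the update rule in this notebook, $F(x) = (1-\alpha)x + \alpha(W_{rec} f(x) + W_{in} u + b)$, so $F(x) - x = \alpha(W_{rec} f(x) + W_{in} u + b - x)$. Minimizing $q(x) = \tfrac{1}{2}\lVert W_{rec} f(x) + W_{in} u + b - x\rVert^2$ therefore finds the same points, because $\alpha$ only rescales $q$. The sketch below rests on several assumptions: $f = \tanh$, a constant input $u$ (none by default), and plain gradient descent. The function name and arguments are hypothetical. The usage checks the method on two toy networks whose fixed points are known.

```python
import torch


def find_fixed_points(W_rec, bias, x0, W_in=None, u=None, n_iters=2000, lr=0.1):
    """Gradient-descent search for fixed points x* = W_rec tanh(x*) + W_in u + bias.

    x0: (n_inits, hiddenSize) starting states, one per row.
    Returns the final states and their speed q = 0.5 * ||W_rec tanh(x) + W_in u + bias - x||^2.
    """
    W_rec = W_rec.detach()                      # search over x only; leave the weights untouched
    drive = bias.detach() if u is None else (bias + W_in @ u).detach()
    x = x0.clone().detach().requires_grad_(True)
    opt = torch.optim.SGD([x], lr=lr)           # plain gradient descent

    def speed(states):
        velocity = torch.tanh(states) @ W_rec.T + drive - states   # one row per state
        return 0.5 * (velocity ** 2).sum(dim=1)

    for _ in range(n_iters):
        opt.zero_grad()
        speed(x).sum().backward()               # rows are independent, so each gets its own gradient
        opt.step()

    with torch.no_grad():
        q = speed(x)
    return x.detach(), q


torch.manual_seed(0)
# toy case (a): W_rec = 0.5 * I and no bias, so the only fixed point is x = 0
x_a, q_a = find_fixed_points(0.5 * torch.eye(4), torch.zeros(4), torch.randn(3, 4), lr=0.5)
# toy case (b): W_rec = 0, so the fixed point is x = bias
b = torch.tensor([0.3, -0.2, 0.1, 0.0])
x_b, q_b = find_fixed_points(torch.zeros(4, 4), b, torch.randn(3, 4), lr=0.5)
```

With this notebook's tensors, a call might look like `find_fixed_points(W['rec'], W['bias'], torch.randn(100, hiddenSize))`. The step size may need tuning for a trained network. States that end with a small q are candidate fixed points, but some may be saddles rather than attractors. Classifying them requires an extra step, for example inspecting the eigenvalues of the Jacobian at each point.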

In [None]:
# view some trajectories in state-space
# pass to our provided plot function