<a href="https://colab.research.google.com/github/BrisGeorge24044/Cognitive_Artificial_Intelligence/blob/main/Leaky_RNN_ContextDecisionMaking.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install neurogym
!pip install gym==0.25.1

Collecting neurogym
  Downloading neurogym-0.0.2.tar.gz (79 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/79.1 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[90m╺[0m[90m━━━[0m [32m71.7/79.1 kB[0m [31m2.9 MB/s[0m eta [36m0:00:01[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m79.1/79.1 kB[0m [31m1.8 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting gym<0.25,>=0.20.0 (from neurogym)
  Downloading gym-0.24.1.tar.gz (696 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m696.4/696.4 kB[0m [31m10.1 MB/s[0m eta [36m0:00:00[0m
[?25h  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Building wheels for collected packages: neurogym, gym
  Building wheel for neurogym (setup.py) ... [?25l[?25hdone
  C

In [2]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
import time

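The `LeakyRNN` class defines a recurrent layer whose hidden state is updated by leaky integration with rate $\alpha = \Delta t / \tau$, and sets its parameters (time constant, input and recurrent weights).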

In [4]:
class LeakyRNN(nn.Module):
  """Recurrent layer of ReLU units with leaky integration of the hidden state.

  At every time step the hidden state is updated as
      h_t = (1 - alpha) * h_{t-1} + alpha * relu(W_in x_t + W_rec h_{t-1})
  with alpha = dt / tau, or alpha = 1 (no leak) when dt is None.

  Args:
    input_size: number of input features per time step.
    hidden_size: number of recurrent units.
    dt: integration time step, in the same units as tau; None disables the
      leak (alpha = 1).
  """

  def __init__(self, input_size, hidden_size, dt = None):
    super().__init__()
    self.input_size = input_size
    self.hidden_size = hidden_size
    # Time constant of the units (same units as dt)
    self.tau = 100
    # alpha sets how much of the freshly computed activation replaces the
    # previous state at each step; alpha = 1 means no memory of the old state
    if dt is None:
      alpha = 1
    else:
      alpha = dt/self.tau
    self.alpha = alpha
    # Linear maps for the input -> hidden and hidden -> hidden connections
    self.input2hidden = nn.Linear(input_size, hidden_size)
    self.hidden2hidden = nn.Linear(hidden_size, hidden_size)

  def init_hidden(self, input_shape):
    """Return an all-zero hidden state of shape (batch, hidden_size).

    Inputs are sequence-first, so the batch size is read from input_shape[1].
    """
    batch_size = input_shape[1]
    return torch.zeros(batch_size, self.hidden_size)

  def recurrence(self, input, hidden):
    """Advance the hidden state by one time step.

    Args:
      input: input at a single time step, shape (batch, input_size).
      hidden: previous hidden state, shape (batch, hidden_size).

    Returns:
      The new hidden state, shape (batch, hidden_size).
    """
    # Candidate activation from the input and recurrent weight matrices
    h_new = torch.relu(self.input2hidden(input) + self.hidden2hidden(hidden))
    # Leaky integration: alpha decides how much of the last state is replaced
    h_new = hidden * (1 - self.alpha) + h_new * self.alpha
    return h_new

  def forward(self, input, hidden = None):
    """Run the recurrence over every time step of a sequence.

    Args:
      input: tensor of shape (seq_len, batch, input_size).
      hidden: optional initial hidden state of shape (batch, hidden_size);
        a zero state is used when omitted.

    Returns:
      (output, hidden), where output stacks the hidden state of every step
      with shape (seq_len, batch, hidden_size), and hidden is the final state.
    """
    # Only the initialisation is conditional; the loop must run whether or
    # not the caller supplied an initial state (previously it did not, and
    # forward returned None when `hidden` was given).
    if hidden is None:
      hidden = self.init_hidden(input.shape).to(input.device)

    output = []
    for step in range(input.size(0)):
      hidden = self.recurrence(input[step], hidden)
      output.append(hidden)

    # Collect the per-step hidden states along a new leading time dimension
    output = torch.stack(output, dim = 0)
    return output, hidden

The `RNNNet` class wraps `LeakyRNN` and adds a fully connected layer that converts the hidden state at each time step into the final predictions.

In [5]:
class RNNNet(nn.Module):
  """LeakyRNN core followed by a linear readout applied at every time step.

  Args:
    input_size: number of input features per time step.
    hidden_size: number of recurrent units in the LeakyRNN.
    output_size: number of outputs produced by the readout.
    **kwargs: passed through unchanged to LeakyRNN (e.g. dt).
  """

  def __init__(self, input_size, hidden_size, output_size, **kwargs):
    super().__init__()
    # Recurrent core (registered first, then the readout)
    self.rnn = LeakyRNN(input_size, hidden_size, **kwargs)
    # Readout mapping hidden units to outputs
    self.fc = nn.Linear(hidden_size, output_size)

  def forward(self, x):
    """Return (per-step predictions, per-step hidden activity) for input x."""
    activity, _final_state = self.rnn(x)
    return self.fc(activity), activity

In [6]:
import neurogym as ngym

# Task: context-dependent decision making from the neurogym collection
task_name = 'ContextDecisionMaking-v0'
# Environment settings passed to the task: integration time step (dt) and
# the duration of the stimulus period (neurogym times are presumably in ms)
kwargs = dict(dt=20, timing=dict(stimulus=1500))

  logger.warn(
