<a href="https://colab.research.google.com/github/BrisGeorge24044/Cognitive_Artificial_Intelligence/blob/main/LeakyRNN_ContextDecisionMaking_sparse_matrices.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [6]:
! pip install neurogym
!pip install gym==0.25.1

Collecting neurogym
  Downloading neurogym-0.0.2.tar.gz (79 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m79.1/79.1 kB[0m [31m1.9 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting gym<0.25,>=0.20.0 (from neurogym)
  Downloading gym-0.24.1.tar.gz (696 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m696.4/696.4 kB[0m [31m12.2 MB/s[0m eta [36m0:00:00[0m
[?25h  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Building wheels for collected packages: neurogym, gym
  Building wheel for neurogym (setup.py) ... [?25l[?25hdone
  Created wheel for neurogym: filename=neurogym-0.0.2-py3-none-any.whl size=118573 sha256=b58ffc7678955aa38639ec6123603c1aec29b4e7881630e0a5cda17dba16d600
  Stored in directory: /root/.cache/pip/wheels/f4/57/a7/66ed4eccf946052534253e4279438b97133b64facca5

In [7]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
import time

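The `LeakyRNN` class defines a leaky recurrent layer with sparse input-to-hidden and hidden-to-hidden weight matrices, and sets its parameters (hidden size, time constant `tau`, and the leak rate `alpha = dt / tau`).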

In [8]:
class LeakyRNN(nn.Module):
  """Leaky-integrator recurrent layer with sparse input and recurrent weights.

  At every time step the hidden state is updated as

      h_t = (1 - alpha) * h_{t-1} + alpha * relu(W_in x_t + W_rec h_{t-1})

  where ``alpha = dt / tau`` (or 1 when ``dt`` is None). Both weight matrices
  are multiplied by a fixed binary mask every time they are used, so the
  connections removed at construction stay removed during training.

  Args:
    input_size: number of input features per time step.
    hidden_size: number of recurrent units.
    dt: simulation time step; when None, alpha is 1 (no leak).
    tau: time constant used to compute alpha (default 100, as before).
    sparsity: fraction of connections removed from each weight matrix
      (default 0.8, as before).
  """

  # Function creates sparse matrices
  def sparse_matrices(self, in_feat, out_feat, sparsity = 0.8):
    """Build an ``nn.Linear`` whose weight matrix is randomly sparsified.

    Each weight is kept with probability ``1 - sparsity``. The binary mask is
    stored on the returned layer as the buffer ``mask`` so that it moves with
    the module across devices, is saved in the state dict, and can be
    re-applied whenever the weights are used.

    Returns:
      The ``nn.Linear`` layer with masked weights and a ``mask`` buffer.
    """
    linear_layer = nn.Linear(in_feat, out_feat)
    # Creates a binary mask to apply the sparsity
    mask = (torch.rand(out_feat, in_feat) > sparsity).float()
    linear_layer.register_buffer('mask', mask)
    # Zero the pruned weights in place without recording the op in autograd
    with torch.no_grad():
      linear_layer.weight.mul_(mask)
    return linear_layer

  # Applies a sparse layer using only the connections kept by its mask
  def _masked_linear(self, layer, x):
    """Apply ``layer`` to ``x`` with its weights multiplied by ``layer.mask``.

    Masking inside the forward computation makes the gradient of every pruned
    weight exactly zero, so an optimizer step cannot silently turn the sparse
    matrix back into a dense one.
    """
    return nn.functional.linear(x, layer.weight * layer.mask, layer.bias)


  def __init__(self, input_size, hidden_size, dt = None, tau = 100, sparsity = 0.8):
    super().__init__()
    self.input_size = input_size
    self.hidden_size = hidden_size
    self.tau = tau
    # Sets alpha as a function of the time step, dt, and the time constant, tau
    if dt is None:
      alpha = 1
    else:
      alpha = dt/self.tau
    self.alpha = alpha
    # Defines sparse layers for the input to hidden and hidden to hidden
    # connections
    self.input2hidden = self.sparse_matrices(input_size, hidden_size, sparsity)
    self.hidden2hidden = self.sparse_matrices(hidden_size, hidden_size, sparsity)

  # init_hidden function initialises the hidden state as a matrix of zeros
  def init_hidden(self, input_shape):
    """Return a zero hidden state of shape (input_shape[1], hidden_size)."""
    batch_size = input_shape[1]
    return torch.zeros(batch_size, self.hidden_size)

  # Recurrence function updates the hidden state
  def recurrence(self, input, hidden):
    """Advance the hidden state by one time step.

    Args:
      input: input at the current time step.
      hidden: hidden state from the previous time step.

    Returns:
      The new hidden state.
    """
    # ReLU non-linear activation function is used to calculate the new hidden
    # state using the (masked) input and recurrent weight matrices
    drive = (self._masked_linear(self.input2hidden, input)
             + self._masked_linear(self.hidden2hidden, hidden))
    h_new = torch.relu(drive)
    # Leaky integration is applied with alpha deciding how much of the last
    # state is retained
    h_new = hidden * (1 - self.alpha) + h_new * self.alpha
    return h_new

  # Forward function - process the sequence for every time step
  def forward(self, input, hidden = None):
    """Run the layer over a whole sequence.

    Args:
      input: tensor indexed as input[time_step]; its second dimension is used
        as the batch size when the hidden state has to be created.
      hidden: optional initial hidden state; zeros are used when omitted.

    Returns:
      A tuple ``(output, hidden)`` where ``output`` stacks the hidden state of
      every time step along dimension 0 and ``hidden`` is the final state.
    """
    # If there is no hidden state, then it is initialised here
    if hidden is None:
      hidden = self.init_hidden(input.shape).to(input.device)

    # The loop below must run whether or not a hidden state was supplied;
    # nesting it under the `if` above made forward() return None for callers
    # that passed their own initial state.
    output = []
    # Creates a loop for every time step
    steps = range(input.size(0))
    for i in steps:
      # Updates the hidden state using the recurrence function
      hidden = self.recurrence(input[i], hidden)
      output.append(hidden)

    # Stacks all outputs in one dimension
    output = torch.stack(output, dim = 0)
    return output, hidden

In [9]:
class RNNNet(nn.Module):
  """Recurrent network: a LeakyRNN core followed by a linear read-out.

  Any extra keyword arguments are handed to LeakyRNN unchanged.
  """

  def __init__(self, input_size, hidden_size, output_size, **kwargs):
    super().__init__()
    # Recurrent core that produces one hidden state per time step
    self.rnn = LeakyRNN(input_size, hidden_size, **kwargs)
    # Read-out layer mapping each hidden state to the output units
    self.fc = nn.Linear(hidden_size, output_size)

  def forward(self, x):
    """Return the read-out of every hidden state and the hidden states themselves."""
    hidden_states, _final_state = self.rnn(x)
    readout = self.fc(hidden_states)
    return readout, hidden_states

In [10]:
import neurogym as ngym


# Name of the NeuroGym task used to build the dataset
task_name = 'ContextDecisionMaking-v0'
# Environment keyword arguments, passed to ngym.Dataset as env_kwargs.
# NOTE(review): dt and the stimulus timing are presumably in milliseconds —
# confirm against the NeuroGym documentation.
kwargs = {'dt': 20, 'timing': {'stimulus': 1500}}

  logger.warn(


In [11]:
# Number of time steps per sequence and number of sequences per batch
seq_len = 150
batch_size = 24

# Build the dataset for the chosen task and keep a handle on its environment
dataset = ngym.Dataset(task_name, env_kwargs=kwargs, batch_size=batch_size, seq_len=seq_len)
env = dataset.env

# Draw one batch and convert the inputs to a float tensor to inspect the shapes
inputs, target = dataset()
inputs = torch.from_numpy(inputs).type(torch.float)

# Network dimensions are read from the environment's observation/action spaces
input_size = env.observation_space.shape[0]
output_size = env.action_space.n

print(inputs.shape)
print(target.shape)

torch.Size([150, 24, 7])
(150, 24)


  deprecation(
  deprecation(


In [12]:
# Number of recurrent units in the LeakyRNN
hidden_size = 192

# dt is forwarded through RNNNet's **kwargs to LeakyRNN, where it sets
# alpha = dt / tau
net = RNNNet(input_size=input_size, hidden_size=hidden_size, output_size=output_size, dt=env.dt)
print(net)

def train_model(net, dataset, num_steps = 2000, lr = 0.01, print_every = 100):
    """Train ``net`` with Adam and cross-entropy loss on batches from ``dataset``.

    Args:
        net: module whose forward pass returns a tuple; the first element holds
            the class scores, with the classes along the last dimension.
        dataset: callable returning ``(inputs, labels)`` as NumPy arrays; the
            labels are flattened to match the flattened class scores.
        num_steps: number of optimisation steps (default 2000).
        lr: Adam learning rate (default 0.01).
        print_every: the mean loss over the last ``print_every`` steps is
            printed each time that many steps have completed (default 100).

    Returns:
        The same ``net`` object, trained in place.
    """
    optimizer = optim.Adam(net.parameters(), lr = lr)
    criterion = nn.CrossEntropyLoss()

    running_loss = 0
    start_time = time.time()

    for i in range(num_steps):
        inputs, labels = dataset()
        inputs = torch.from_numpy(inputs).type(torch.float)
        labels = torch.from_numpy(labels.flatten()).type(torch.long)

        optimizer.zero_grad()
        output, _ = net(inputs)
        # The number of classes is taken from the network output itself so the
        # function does not depend on a notebook-level `output_size` variable
        output = output.reshape(-1, output.shape[-1])
        loss = criterion(output, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        if (i + 1) % print_every == 0:
            running_loss /= print_every
            print('Step {}, Loss {:0.4f}, Time {:0.1f}s'.format(
                i+1, running_loss, time.time() - start_time))
            running_loss = 0
    return net

# Train the network on batches drawn from the dataset; train_model returns the
# same network object it was given
net = train_model(net, dataset)

RNNNet(
  (rnn): LeakyRNN(
    (input2hidden): Linear(in_features=7, out_features=192, bias=True)
    (hidden2hidden): Linear(in_features=192, out_features=192, bias=True)
  )
  (fc): Linear(in_features=192, out_features=3, bias=True)
)
Step 100, Loss 26.2008, Time 10.5s
Step 200, Loss 0.0777, Time 18.6s
Step 300, Loss 0.0392, Time 27.8s
Step 400, Loss 0.0302, Time 36.6s
Step 500, Loss 0.0266, Time 44.8s
Step 600, Loss 0.0249, Time 53.4s
Step 700, Loss 0.0247, Time 62.2s
Step 800, Loss 0.0234, Time 70.3s
Step 900, Loss 0.0229, Time 79.6s
Step 1000, Loss 0.0229, Time 88.4s
Step 1100, Loss 0.0225, Time 96.8s
Step 1200, Loss 0.0223, Time 108.1s
Step 1300, Loss 0.0228, Time 117.0s
Step 1400, Loss 0.0221, Time 126.0s
Step 1500, Loss 0.0220, Time 133.8s
Step 1600, Loss 0.0220, Time 142.7s
Step 1700, Loss 0.0223, Time 151.6s
Step 1800, Loss 0.0220, Time 159.6s
Step 1900, Loss 0.0220, Time 169.0s
Step 2000, Loss 0.0219, Time 178.1s


In [13]:
env = dataset.env
env.reset(no_step=True)

perf = 0
activity_dict = {}
trial_infos = {}

num_trial = 200
# Evaluation never calls backward(), so gradient tracking is switched off to
# avoid building an autograd graph for every trial
with torch.no_grad():
  for i in range(num_trial):
    # Sample a new trial and read its observations and ground-truth labels
    trial_info = env.new_trial()
    ob, gt = env.ob, env.gt
    # Insert a batch dimension of size 1 between time and features
    inputs = torch.from_numpy(ob[:,np.newaxis, :]).type(torch.float)

    action_pred, rnn_activity = net(inputs)

    # The choice is the highest-scoring output at the last time step
    action_pred = action_pred.detach().numpy()[:, 0, :]
    choice = np.argmax(action_pred[-1, :])
    correct = choice == gt[-1]

    # Keep the hidden activity and trial metadata for later analysis
    rnn_activity = rnn_activity[:, 0, :].detach().numpy()
    activity_dict[i] = rnn_activity
    trial_infos[i] = trial_info
    trial_infos[i].update({'choice': choice, 'correct': correct})

    if correct:
      perf += 1

# Preview at most the first ten trials (fewer if num_trial is smaller)
for i in range(min(10, num_trial)):
  print('Trial', i, trial_infos[i])

print(f"Average performance = {np.mean([val['correct'] for val in trial_infos.values()])}")
print(f"Performance: {perf}/{num_trial}")


Trial 0 {'ground_truth': 2, 'other_choice': 1, 'context': 0, 'coh_0': 15, 'coh_1': 15, 'choice': 2, 'correct': True}
Trial 1 {'ground_truth': 2, 'other_choice': 1, 'context': 0, 'coh_0': 5, 'coh_1': 5, 'choice': 2, 'correct': True}
Trial 2 {'ground_truth': 1, 'other_choice': 1, 'context': 1, 'coh_0': 50, 'coh_1': 15, 'choice': 2, 'correct': False}
Trial 3 {'ground_truth': 1, 'other_choice': 1, 'context': 0, 'coh_0': 50, 'coh_1': 5, 'choice': 2, 'correct': False}
Trial 4 {'ground_truth': 2, 'other_choice': 1, 'context': 1, 'coh_0': 5, 'coh_1': 5, 'choice': 2, 'correct': True}
Trial 5 {'ground_truth': 1, 'other_choice': 1, 'context': 0, 'coh_0': 15, 'coh_1': 50, 'choice': 2, 'correct': False}
Trial 6 {'ground_truth': 1, 'other_choice': 2, 'context': 1, 'coh_0': 15, 'coh_1': 15, 'choice': 2, 'correct': False}
Trial 7 {'ground_truth': 2, 'other_choice': 2, 'context': 0, 'coh_0': 15, 'coh_1': 50, 'choice': 2, 'correct': True}
Trial 8 {'ground_truth': 2, 'other_choice': 2, 'context': 0, 'coh