<a href="https://colab.research.google.com/github/BrisGeorge24044/Cognitive_Artificial_Intelligence/blob/main/LeakyRNN_with_Hebbian_Learning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Hebbian Learning Rule:
- Neurons that fire together, wire together

$$\Delta w_{ij} = \eta \cdot x_i \cdot y_j$$

- w<sub>ij</sub> = weight of the connection from pre-synaptic neuron i to post-synaptic neuron j
- x<sub>i</sub> = pre-synaptic neuron activity
- y<sub>j</sub> = post-synaptic neuron activity
- η = Learning Rate


In [None]:
# neurogym is installed first, then gym is re-pinned to 0.25.1.
# NOTE(review): pip reports that neurogym 0.0.2 requires gym<0.25, so this pin
# produces a dependency-conflict warning — confirm the pin is intentional.
!pip install neurogym
!pip install gym==0.25.1

Collecting gym<0.25,>=0.20.0 (from neurogym)
  Using cached gym-0.24.1-py3-none-any.whl
Installing collected packages: gym
  Attempting uninstall: gym
    Found existing installation: gym 0.25.1
    Uninstalling gym-0.25.1:
      Successfully uninstalled gym-0.25.1
Successfully installed gym-0.24.1
Collecting gym==0.25.1
  Using cached gym-0.25.1-py3-none-any.whl
Installing collected packages: gym
  Attempting uninstall: gym
    Found existing installation: gym 0.24.1
    Uninstalling gym-0.24.1:
      Successfully uninstalled gym-0.24.1
[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
neurogym 0.0.2 requires gym<0.25,>=0.20.0, but you have gym 0.25.1 which is incompatible.[0m[31m
[0mSuccessfully installed gym-0.25.1


In [None]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
import time

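The `LeakyRNN` class defines a leaky recurrent layer. It sets the leak rate from the time step and time constant, and applies the Hebbian learning rule to its input and recurrent weights at every time step.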

In [None]:
class LeakyRNN(nn.Module):
  """Leaky recurrent layer whose weights are adapted by a Hebbian rule.

  At every time step the layer first applies a Hebbian update to the
  input-to-hidden and hidden-to-hidden weight matrices, then computes a
  ReLU candidate state and blends it with the previous hidden state:

      h_t = (1 - alpha) * h_{t-1} + alpha * relu(W_in x_t + W_rec h_{t-1})

  Args:
    input_size: number of input features per time step.
    hidden_size: number of recurrent units.
    dt: simulation time step. When given, ``alpha = dt / tau`` with
      ``tau = 100``; when ``None``, ``alpha = 1`` (the candidate state fully
      replaces the previous one).
  """

  def __init__(self, input_size, hidden_size, dt = None):
    super().__init__()
    self.input_size = input_size
    self.hidden_size = hidden_size
    self.tau = 100
    # Sets alpha as a function of the time step, dt, and the time constant, tau
    if dt is None:
      alpha = 1
    else:
      alpha = dt/self.tau
    self.alpha = alpha
    # Defines linear layers for the input to hidden and hidden to hidden
    # connections
    self.input2hidden = nn.Linear(input_size, hidden_size)
    self.hidden2hidden = nn.Linear(hidden_size, hidden_size)

  def hebbian_learning_rule(self, input, hidden, lr = 0.0001,
                            hebbian_scale_factor = 0.001):
    """Apply one in-place Hebbian update to both weight matrices.

    The update uses the current input as pre-synaptic activity and the
    hidden state passed in (i.e. the state *before* this step's recurrence)
    as post-synaptic activity. After the update each weight matrix is divided
    by its own norm so that it cannot grow without bound.

    Args:
      input: tensor that can be viewed as (batch_size, input_size).
      hidden: tensor that can be viewed as (batch_size, hidden_size).
      lr: Hebbian learning rate (eta in the rule).
      hebbian_scale_factor: additional damping applied to the summed weight
        change; keeps the weights from increasing too quickly, which causes
        instability.
    """
    batch_size = input.size(0)

    # Add a singleton dimension so the outer products can be batched
    input = input.view(batch_size, 1, -1)
    hidden = hidden.view(batch_size, -1, 1)

    # Per-sample outer products via batched matrix multiplication:
    # (batch_size, hidden_size, input_size) for the input weights ...
    delta_w_i = lr * torch.bmm(hidden, input)
    # ... and (batch_size, hidden_size, hidden_size) for the recurrent weights
    delta_w_j = lr * torch.bmm(hidden, hidden.transpose(1, 2))

    # Weight changes are summed over the batch
    delta_w_i_sum = delta_w_i.sum(dim=0)
    delta_w_j_sum = delta_w_j.sum(dim=0)

    # The update is applied outside of autograd
    with torch.no_grad():
      self.input2hidden.weight += hebbian_scale_factor * delta_w_i_sum
      self.hidden2hidden.weight += hebbian_scale_factor * delta_w_j_sum

      # Rescale each matrix by its norm; the epsilon guards against division
      # by zero
      self.input2hidden.weight.data = self.input2hidden.weight.data / (torch.norm(self.input2hidden.weight.data) + 1e-6)
      self.hidden2hidden.weight.data = self.hidden2hidden.weight.data / (torch.norm(self.hidden2hidden.weight.data) + 1e-6)

    # Check for NaN in weights
    if torch.isnan(self.input2hidden.weight).any() or torch.isnan(self.hidden2hidden.weight).any():
      print("NaN detected in weights!")

  # init_hidden function initialises the hidden state as a matrix of zeros
  def init_hidden(self, input_shape):
    """Return a zero hidden state of shape (batch_size, hidden_size).

    ``input_shape`` is indexed at position 1 for the batch size, i.e. the
    input is expected to be laid out as (seq_len, batch_size, ...).
    """
    batch_size = input_shape[1]
    return torch.zeros(batch_size, self.hidden_size)

  # Recurrence function updates the hidden state
  def recurrence(self, input, hidden):
    """Run one time step: Hebbian update, then leaky state update."""
    # The weights are adapted first, so the new state below is computed with
    # the already-updated matrices
    self.hebbian_learning_rule(input, hidden)
    # ReLU non-linear activation function is used to calculate the new hidden
    # state using the input and recurrent weight matrices
    h_new = torch.relu(self.input2hidden(input) + self.hidden2hidden(hidden))
    # Leaky integration is applied with alpha deciding how much of the last
    # state is retained
    h_new = hidden * (1 - self.alpha) + h_new * self.alpha
    return h_new

  # Forward function - process the sequence for every time step
  def forward(self, input, hidden = None):
    """Process a whole sequence.

    Args:
      input: tensor indexed along dim 0 by time step; dim 1 is the batch.
      hidden: optional initial hidden state; zeros are used when omitted.

    Returns:
      Tuple ``(output, hidden)`` where ``output`` stacks the hidden state of
      every time step along dim 0 and ``hidden`` is the final state.
    """
    # If there is no hidden state, then it is initialised here
    if hidden is None:
      hidden = self.init_hidden(input.shape).to(input.device)

    # The loop must run whether or not the caller supplied a hidden state
    output = []
    for i in range(input.size(0)):
      # Updates the hidden state using the recurrence function
      hidden = self.recurrence(input[i], hidden)
      output.append(hidden)

    # Stacks all outputs in one dimension
    output = torch.stack(output, dim = 0)
    return output, hidden

The `RNNNet` class wraps `LeakyRNN` and adds a fully connected layer that converts the hidden state at every time step into the final predictions.

In [None]:
class RNNNet(nn.Module):
  """Leaky RNN followed by a linear read-out layer.

  Extra keyword arguments (e.g. ``dt``) are forwarded unchanged to
  ``LeakyRNN``.
  """

  def __init__(self, input_size, hidden_size, output_size, **kwargs):
    super().__init__()
    # Recurrent core
    self.rnn = LeakyRNN(input_size, hidden_size, **kwargs)
    # Read-out: maps each hidden state to one score per output unit
    self.fc = nn.Linear(hidden_size, output_size)

  def forward(self, x):
    """Return ``(predictions, hidden_states)`` for the input sequence ``x``."""
    # The final hidden state returned by the RNN is not needed here
    hidden_states, _final_state = self.rnn(x)
    # The read-out is applied to the hidden state of every time step
    predictions = self.fc(hidden_states)
    return predictions, hidden_states

In [None]:
import neurogym as ngym

# Name of the neurogym task the network is run on
task_name = 'ContextDecisionMaking-v0'
# Environment settings: the time step `dt` and the length of the stimulus
# period (presumably both in milliseconds — TODO confirm against neurogym)
kwargs = dict(dt=20, timing=dict(stimulus=1500))

In [None]:
# seq_len is the number of time steps in each sequence; batch_size is the
# number of sequences generated per batch
seq_len, batch_size = 150, 24

# ngym.Dataset wraps the task and produces batches on demand
dataset = ngym.Dataset(
    task_name,
    env_kwargs=kwargs,
    batch_size=batch_size,
    seq_len=seq_len,
)
env = dataset.env

# The network's input and output sizes are dictated by the environment
input_size = env.observation_space.shape[0]
output_size = env.action_space.n

# One batch is drawn and the observations are converted to a float tensor
inputs, target = dataset()
inputs = torch.from_numpy(inputs).float()

print(inputs.shape)
print(target.shape)

torch.Size([150, 24, 7])
(150, 24)


In [None]:
# Number of recurrent units
hidden_size = 192

# The network is built with the environment's time step, which is forwarded
# to the recurrent layer as `dt`
net = RNNNet(
    input_size=input_size,
    hidden_size=hidden_size,
    output_size=output_size,
    dt=env.dt,
)
print(net)

def train_model_hebbian(net, dataset, epoch_num = 10, steps_per_epoch = None):
  """Run the network over freshly sampled batches and report the loss.

  No optimiser or backward pass is used: the cross-entropy loss is only
  monitored. Any weight change comes from what ``net.rnn`` does during its
  own forward pass.

  Args:
    net: model exposing ``rnn(inputs, hidden)`` and ``fc(output)``.
    dataset: callable returning ``(inputs, labels)`` as numpy arrays.
    epoch_num: number of epochs to run.
    steps_per_epoch: batches per epoch; defaults to ``epoch_num`` to keep the
      original step count.

  Returns:
    The same ``net`` object, after all epochs have run.
  """
  if steps_per_epoch is None:
    steps_per_epoch = epoch_num

  criterion = nn.CrossEntropyLoss()
  for epoch in range(epoch_num):
    for i in range(steps_per_epoch):
      inputs, labels = dataset()
      inputs = torch.from_numpy(inputs).type(torch.float)
      labels = torch.from_numpy(labels.flatten()).type(torch.long)

      # The loss is never back-propagated, so no autograd graph is needed
      with torch.no_grad():
        # A fresh (None) hidden state is used for every batch
        output, _ = net.rnn(inputs, None)
        output = net.fc(output)

        # Time and batch dimensions are flattened to match the flat labels
        loss = criterion(output.view(-1, output.size(-1)), labels)

      print(f"Epoch num: {epoch + 1}, Step num: {i + 1}, Loss: {loss.item()}")

  # Returned only after every epoch has completed
  return net


# Runs the Hebbian pass; the function returns the network object it was given
net = train_model_hebbian(net, dataset)

RNNNet(
  (rnn): LeakyRNN(
    (input2hidden): Linear(in_features=7, out_features=192, bias=True)
    (hidden2hidden): Linear(in_features=192, out_features=192, bias=True)
  )
  (fc): Linear(in_features=192, out_features=3, bias=True)
)
Epoch num: 1, Step num: 1, Loss: 1.1536469459533691
Epoch num: 1, Step num: 2, Loss: 1.1537741422653198
Epoch num: 1, Step num: 3, Loss: 1.1537294387817383
Epoch num: 1, Step num: 4, Loss: 1.1532005071640015
Epoch num: 1, Step num: 5, Loss: 1.151999592781067
Epoch num: 1, Step num: 6, Loss: 1.1540336608886719
Epoch num: 1, Step num: 7, Loss: 1.1531153917312622
Epoch num: 1, Step num: 8, Loss: 1.1525181531906128
Epoch num: 1, Step num: 9, Loss: 1.1537060737609863
Epoch num: 1, Step num: 10, Loss: 1.15753173828125


Testing Phase:

In [None]:
# Environment is reset after training
env = dataset.env
env.reset(no_step=True)

# Performance is tracked
perf = 0
# Dictionaries to store activity and information for each trial are created
activity_dict = {}
trial_infos = {}

# Number of trials is set to 200
num_trial = 200

# For loop is created to run the model for every trial
for i in range(num_trial):
  # Starts a trial
  trial_info = env.new_trial()
  # defines the observations and ground truths for each trial
  ob, gt = env.ob, env.gt
  # Adds a batch dimension of size 1: (time, features) -> (time, 1, features)
  inputs = torch.from_numpy(ob[:,np.newaxis, :]).type(torch.float)

  # Takes the predictions from the model. No gradients are needed for
  # evaluation, so the forward pass runs without building an autograd graph.
  # NOTE: the Hebbian rule inside the RNN's recurrence still updates the
  # weights during these test trials.
  with torch.no_grad():
    action_pred, rnn_activity = net(inputs)

  # Converts the predictions to a numpy array and selects the action predicted
  # by the model (the largest value) at the last time step
  action_pred = action_pred.detach().numpy()[:, 0, :]
  choice = np.argmax(action_pred[-1, :])
  correct = choice == gt[-1]

  # Stores the activity and trial info in dictionaries
  rnn_activity = rnn_activity[:, 0, :].detach().numpy()
  activity_dict[i] = rnn_activity
  trial_infos[i] = trial_info
  trial_infos[i].update({'choice': choice, 'correct': correct})

  # Tracks the performance
  if correct:
    perf += 1

# Prints at most the first 20 trials; the bound avoids a KeyError when
# num_trial is set below 20
for i in range(min(20, num_trial)):
  print('Trial', i + 1, trial_infos[i])

print(f"Average performance = {np.mean([val['correct'] for val in trial_infos.values()])}")
print(f"Performance: {perf}/{num_trial}")

Trial 1 {'ground_truth': 1, 'other_choice': 2, 'context': 1, 'coh_0': 50, 'coh_1': 5, 'choice': 2, 'correct': False}
Trial 2 {'ground_truth': 2, 'other_choice': 1, 'context': 1, 'coh_0': 15, 'coh_1': 5, 'choice': 2, 'correct': True}
Trial 3 {'ground_truth': 2, 'other_choice': 1, 'context': 0, 'coh_0': 15, 'coh_1': 50, 'choice': 2, 'correct': True}
Trial 4 {'ground_truth': 1, 'other_choice': 2, 'context': 1, 'coh_0': 15, 'coh_1': 15, 'choice': 2, 'correct': False}
Trial 5 {'ground_truth': 1, 'other_choice': 1, 'context': 0, 'coh_0': 15, 'coh_1': 15, 'choice': 2, 'correct': False}
Trial 6 {'ground_truth': 2, 'other_choice': 1, 'context': 1, 'coh_0': 50, 'coh_1': 50, 'choice': 2, 'correct': True}
Trial 7 {'ground_truth': 2, 'other_choice': 1, 'context': 1, 'coh_0': 15, 'coh_1': 5, 'choice': 2, 'correct': True}
Trial 8 {'ground_truth': 1, 'other_choice': 2, 'context': 0, 'coh_0': 5, 'coh_1': 15, 'choice': 2, 'correct': False}
Trial 9 {'ground_truth': 1, 'other_choice': 2, 'context': 0, 'co