<a href="https://colab.research.google.com/github/adampotton/Cognitive_AI_CW/blob/main/Q2A_LSTM_BI_2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Adapting the LSTM model architecture to tackle sparsity

In [1]:
! git clone https://github.com/neurogym/neurogym.git
%cd neurogym/
! pip install -e .

Cloning into 'neurogym'...
remote: Enumerating objects: 11100, done.[K
remote: Counting objects: 100% (1002/1002), done.[K
remote: Compressing objects: 100% (106/106), done.[K
remote: Total 11100 (delta 928), reused 896 (delta 896), pack-reused 10098 (from 1)[K
Receiving objects: 100% (11100/11100), 8.17 MiB | 7.04 MiB/s, done.
Resolving deltas: 100% (8333/8333), done.
/content/neurogym
Obtaining file:///content/neurogym
  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting gym<0.25,>=0.20.0 (from neurogym==0.0.2)
  Downloading gym-0.24.1.tar.gz (696 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m696.4/696.4 kB[0m [31m9.2 MB/s[0m eta [36m0:00:00[0m
[?25h  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Building wheels for collected packages: gym
  Building wheel for gym (pyproject.toml) ... [?25l[?25hdone
  Created wheel for

In [2]:
import neurogym as ngym
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
import time

  logger.warn(


### Define a simple LSTM model

In [229]:
class Sparse_LSTM(nn.Module):
    """Multi-layer ``nn.LSTM`` whose weight matrices are kept sparse by fixed binary masks.

    The masks are multiplied into the LSTM weights in place, once at the start of
    every forward pass and once more whenever
    ``apply_sparsity_masks_after_optimiser`` is called (intended to follow each
    optimiser step, which may move masked weights away from zero).

    Args:
        input_size: Number of features in each input vector.
        hidden_size: Number of hidden units per LSTM layer.
        num_layers: Number of stacked LSTM layers.
        sparsity_masks: Dict mapping LSTM parameter names (for example
            ``'weight_ih_l0'`` or ``'weight_hh_l1'``) to mask tensors that are
            multiplied element-wise into that parameter. ``None`` or an empty
            dict leaves every weight dense; weights without an entry are left
            untouched. Biases are never masked.
    """

    def __init__(self, input_size, hidden_size, num_layers, sparsity_masks):
        super(Sparse_LSTM, self).__init__()
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers=num_layers)  # Standard LSTM with n layers
        self.sparsity_masks = sparsity_masks  # Binary masks for sparsity (may be None)

    def _apply_sparsity_masks(self):
        """Zero the masked entries of every LSTM weight matrix in place."""
        if not self.sparsity_masks:  # None or {}: nothing to mask
            return
        with torch.no_grad():
            for name, param in self.lstm.named_parameters():
                if 'weight' not in name:  # Only weight matrices are masked, never biases
                    continue
                # Mask keys are the LSTM's own parameter names, so no name parsing is needed
                mask = self.sparsity_masks.get(name)
                if mask is None:  # No mask supplied for this weight: keep it dense
                    continue
                # Align device/dtype so the masks keep working after the model is moved or cast
                param.data *= mask.to(device=param.device, dtype=param.dtype)

    def forward(self, x):
        """Re-apply the masks, then run the wrapped LSTM on ``x``.

        Returns:
            ``(output, (hn, cn))`` exactly as returned by ``nn.LSTM``.
        """
        self._apply_sparsity_masks()
        output, (hn, cn) = self.lstm(x)
        return output, (hn, cn)

    def apply_sparsity_masks_after_optimiser(self):
        """Re-apply the masks; call after each optimiser step to restore the zeros."""
        self._apply_sparsity_masks()


class LSTMNet(nn.Module):
    """Sparse LSTM stack followed by a linear read-out layer.

    Args:
        input_size: Number of features in each input vector.
        hidden_size: Number of hidden units per LSTM layer (also the read-out's input width).
        output_size: Number of output units produced by the read-out layer.
        num_lstm_layers: Number of stacked LSTM layers.
        sparsity_masks: Mask dictionary passed unchanged to ``Sparse_LSTM``.
            NOTE(review): the default ``None`` is forwarded as-is; confirm
            ``Sparse_LSTM`` accepts it before relying on the default.
    """

    def __init__(self, input_size, hidden_size, output_size, num_lstm_layers=2, sparsity_masks=None):
        super().__init__()
        # Recurrent core with masked weights
        self.lstm = Sparse_LSTM(
            input_size,
            hidden_size,
            num_layers=num_lstm_layers,
            sparsity_masks=sparsity_masks,
        )
        # Linear read-out mapping hidden activity to output units
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        """Return ``(read-out of the LSTM output, LSTM output)`` for input ``x``."""
        hidden_seq = self.lstm(x)[0]  # Drop the final (hn, cn) state, keep the per-step output
        return self.fc(hidden_seq), hidden_seq


### Creating dataset and adjusting parameters

In [230]:
# Task and data-loading settings shared by the rest of the notebook
config = dict(
    dt=200,  # Timestep parameter
    hidden_size=32,  # Hidden size for the LSTM
    batch_size=16,  # Batch size for training
    seq_len=100,  # Sequence length for input data
    envid='ReadySetGo-v0',  # Task name
    gain=2,  # Custom gain
    prod_margin=10,  # Custom production margin
)

# Subset of the settings forwarded to the environment:
#   dt          - timestep parameter
#   gain        - controls the measure that the agent has to produce
#   prod_margin - controls the interval around the ground-truth production time
#                 within which the agent receives proportional reward
env_kwargs = {key: config[key] for key in ('dt', 'gain', 'prod_margin')}
config['env_kwargs'] = env_kwargs

# Generate the dataset and keep a handle on its underlying environment
dataset = ngym.Dataset(
    config['envid'],
    env_kwargs=config['env_kwargs'],
    batch_size=config['batch_size'],
    seq_len=config['seq_len'],
)
env = dataset.env

# Draw one batch to assign inputs and targets and to inspect their shapes
inputs, target = dataset()
inputs = torch.from_numpy(inputs).float()

# Model dimensions are read from the environment's observation and action spaces
input_size = env.observation_space.shape[0]
output_size = env.action_space.n

print('Input has shape (SeqLen, Batch, Dim) =', inputs.shape)
print('Target has shape (SeqLen, Batch) =', target.shape)

Input has shape (SeqLen, Batch, Dim) = torch.Size([100, 16, 3])
Target has shape (SeqLen, Batch) = (100, 16)


  and should_run_async(code)
  logger.warn(
  logger.warn(
  logger.warn(


### Generate sparsity matrices

In [231]:
def generate_sparsity_masks(input_size, hidden_size, num_layers, sparsity, generator=None):
    """Build random binary masks for the weight matrices of a multi-layer LSTM.

    For each layer one mask is drawn for the input-to-hidden weights and one for
    the hidden-to-hidden weights. Each mask is sampled for a single gate and then
    tiled four times along the first dimension, so all four LSTM gates share the
    same connectivity pattern.

    Args:
        input_size: Number of input features (column count of the layer-0
            input-to-hidden mask; deeper layers use ``hidden_size``).
        hidden_size: Number of hidden units per layer.
        num_layers: Number of LSTM layers to generate masks for.
        sparsity: Probability in ``[0, 1]`` that an entry is 0. Entries are
            sampled independently, so the realised fraction of zeros only
            matches this value in expectation.
        generator: Optional ``torch.Generator`` for reproducible masks; the
            global torch RNG is used when omitted.

    Returns:
        Dict with keys ``'weight_ih_l{k}'`` and ``'weight_hh_l{k}'`` for every
        layer ``k``, each a float tensor of zeros and ones with
        ``4 * hidden_size`` rows.

    Raises:
        ValueError: If ``sparsity`` lies outside ``[0, 1]``.
    """
    if not 0.0 <= sparsity <= 1.0:
        raise ValueError(f"sparsity must lie in [0, 1], got {sparsity}")

    sparsity_masks = {}
    for layer in range(num_layers):
        in_features = input_size if layer == 0 else hidden_size
        # Per-gate shapes; the ih mask is drawn before the hh mask for each layer
        per_gate_shapes = {
            f'weight_ih_l{layer}': (hidden_size, in_features),  # Input-to-hidden mask
            f'weight_hh_l{layer}': (hidden_size, hidden_size),  # Hidden-to-hidden mask
        }
        for key, shape in per_gate_shapes.items():
            # torch.rand samples from [0, 1); `>=` keeps sparsity=0 fully dense and sparsity=1 fully zero
            gate_mask = (torch.rand(shape, generator=generator) >= sparsity).float()
            sparsity_masks[key] = gate_mask.repeat(4, 1)  # Same pattern for all four gates

    return sparsity_masks

  and should_run_async(code)


### Training the model

In [232]:
iter_steps = 1000 # Training loops
report_freq = 100 # How often (in training loops) a progress report is printed
num_lstm_layers = 2 # Number of LSTM layers
sparsity = 0.2 # Target proportion of 0s in binary masks

SEED = 0
torch.manual_seed(SEED) # Fix the torch RNG so the masks and the initial weights are repeatable across runs
# NOTE(review): trial sampling inside the neurogym dataset is not seeded here; confirm how to seed it if full reproducibility is required

binary_masks = generate_sparsity_masks(input_size, config['hidden_size'], num_lstm_layers, sparsity)

net = LSTMNet(input_size, config['hidden_size'], output_size, num_lstm_layers , binary_masks) # Create an instance of the sparse LSTM

def train_model(net, dataset, iter_steps, report_freq, lr=0.01):
    """Train ``net`` with AdamW and cross-entropy loss, re-applying sparsity masks after every update.

    Args:
        net: Model whose forward pass returns ``(logits, activity)`` and which
            exposes ``net.lstm.apply_sparsity_masks_after_optimiser()``.
        dataset: Callable returning ``(inputs, labels)`` as NumPy arrays; it is
            called once per training step.
        iter_steps: Number of training steps.
        report_freq: Print the averaged loss and accuracy every ``report_freq``
            steps; ``0`` disables reporting.
        lr: Learning rate for AdamW.

    Returns:
        The same ``net`` object, trained in place.
    """
    optimizer = optim.AdamW(net.parameters(), lr=lr)  # AdamW optimiser
    criterion = nn.CrossEntropyLoss()  # Loss function

    running_loss = 0
    running_acc = 0
    start_time = time.time()  # Start training timer

    for i in range(iter_steps):  # Loop over training batches
        inputs, labels = dataset()  # Generate a set of data
        inputs = torch.from_numpy(inputs).type(torch.float)
        labels = torch.from_numpy(labels.flatten()).type(torch.long)

        optimizer.zero_grad()  # Reset gradients
        output, _ = net(inputs)
        # Collapse all leading dimensions so each row lines up with one flattened label.
        # The class count comes from the tensor itself rather than from a notebook global.
        output = output.reshape(-1, output.shape[-1])

        loss = criterion(output, labels)
        loss.backward()
        optimizer.step()  # Update

        # The optimiser step can move masked weights away from zero, so restore the zeros
        net.lstm.apply_sparsity_masks_after_optimiser()

        batch_acc = (torch.argmax(output, dim=1) == labels).sum().item() / labels.shape[0]  # Current batch accuracy
        running_loss += loss.item()
        running_acc += batch_acc

        # Guard against report_freq == 0, which would otherwise raise on the modulo
        if report_freq and i % report_freq == report_freq - 1:
            running_loss /= report_freq
            running_acc /= report_freq  # Average accuracy over the last report_freq batches
            print('Step {}, Loss {:0.4f}, Accuracy {:0.4f}, Time {:0.1f}s'.format(
                i+1, running_loss, running_acc, time.time() - start_time))
            running_loss = 0  # Reset metrics for next report
            running_acc = 0
    return net


# Call the training function; it returns the same network object it was given
net = train_model(net=net, dataset=dataset, iter_steps=iter_steps, report_freq=report_freq)

Step 100, Loss 0.1802, Accuracy 0.9622, Time 1.9s
Step 200, Loss 0.1264, Accuracy 0.9622, Time 3.9s
Step 300, Loss 0.0735, Accuracy 0.9697, Time 6.4s
Step 400, Loss 0.0228, Accuracy 0.9949, Time 8.7s
Step 500, Loss 0.0183, Accuracy 0.9950, Time 11.6s
Step 600, Loss 0.0158, Accuracy 0.9958, Time 13.5s
Step 700, Loss 0.0158, Accuracy 0.9957, Time 15.6s
Step 800, Loss 0.0156, Accuracy 0.9958, Time 17.7s
Step 900, Loss 0.0150, Accuracy 0.9958, Time 19.7s
Step 1000, Loss 0.0157, Accuracy 0.9956, Time 22.2s


### Function to check the sparsity of each layer

In [218]:
def check_sparsity_of_weights(net):
    """Print, for each weight parameter under ``net.lstm``, how many entries are exactly zero.

    Parameters whose name does not contain ``'weight'`` (the biases) are skipped.
    Nothing is returned; one line per weight parameter is printed.
    """
    with torch.no_grad():  # Pure inspection, so no gradients are needed
        weight_params = [(n, p) for n, p in net.lstm.named_parameters() if 'weight' in n]
        for name, param in weight_params:
            total = param.numel()
            zeros = int((param == 0).sum())  # Count of entries that are exactly 0
            sparsity = zeros / total
            print(f"Sparsity of {name}: {zeros} zeros out of {total} total weights. Sparsity: {sparsity:.2f}")

# Report the fraction of exactly-zero entries in each LSTM weight matrix of the trained network
check_sparsity_of_weights(net=net)

Sparsity of lstm.weight_ih_l0: 99 zeros out of 384 total weights. Sparsity: 0.26
Sparsity of lstm.weight_hh_l0: 827 zeros out of 4096 total weights. Sparsity: 0.20
Sparsity of lstm.weight_ih_l1: 824 zeros out of 4096 total weights. Sparsity: 0.20
Sparsity of lstm.weight_hh_l1: 831 zeros out of 4096 total weights. Sparsity: 0.20


  and should_run_async(code)


In [220]:
def print_weight_dimensions(net):
    """Print ``name: shape`` for every parameter of ``net`` whose name contains ``'weight'``."""
    weight_shapes = (
        (param_name, tensor.shape)
        for param_name, tensor in net.named_parameters()
        if 'weight' in param_name
    )
    for param_name, shape in weight_shapes:
        print(f"{param_name}: {shape}")

# List the shape of every weight matrix in the full model (LSTM layers and read-out)
print_weight_dimensions(net=net)


lstm.lstm.weight_ih_l0: torch.Size([128, 3])
lstm.lstm.weight_hh_l0: torch.Size([128, 32])
lstm.lstm.weight_ih_l1: torch.Size([128, 32])
lstm.lstm.weight_hh_l1: torch.Size([128, 32])
fc.weight: torch.Size([2, 32])


  and should_run_async(code)


In [205]:
env = dataset.env # Reset environment
env.reset(no_step=True)

perf = 0 # Initialise logging vars
activity_dict = {}
trial_infos = {}

num_trial = 200
num_trials_to_print = 20 # Upper bound on how many trial summaries are printed below

with torch.no_grad(): # Evaluation only: skip building the autograd graph
    for i in range(num_trial):

        trial_info = env.new_trial() # New trial
        ob, gt = env.ob, env.gt # Observation and ground-truth of this trial
        inputs = torch.from_numpy(ob[:, np.newaxis, :]).type(torch.float) # Insert a batch axis of size 1

        action_pred, rnn_activity = net(inputs) # Run network for one trial

        action_pred = action_pred.detach().numpy()[:, 0, :] # Compute performance
        choice = np.argmax(action_pred[-1, :]) # Final choice at final time step
        correct = choice == gt[-1]

        rnn_activity = rnn_activity[:, 0, :].detach().numpy() # Record activity
        activity_dict[i] = rnn_activity
        trial_infos[i] = trial_info  # Record trial infos
        trial_infos[i].update({'correct': correct})

# Never index past the trials that were actually run (avoids a KeyError when num_trial < 20)
for i in range(min(num_trials_to_print, num_trial)):
    print('Trial ', i, trial_infos[i])

print('Average performance', np.mean([val['correct'] for val in trial_infos.values()]))

Trial  0 {'measure': 800.0, 'gain': 2, 'production': 1600.0, 'correct': True}
Trial  1 {'measure': 1400.0, 'gain': 2, 'production': 2800.0, 'correct': True}
Trial  2 {'measure': 1400.0, 'gain': 2, 'production': 2800.0, 'correct': True}
Trial  3 {'measure': 800.0, 'gain': 2, 'production': 1600.0, 'correct': True}
Trial  4 {'measure': 800.0, 'gain': 2, 'production': 1600.0, 'correct': True}
Trial  5 {'measure': 800.0, 'gain': 2, 'production': 1600.0, 'correct': True}
Trial  6 {'measure': 800.0, 'gain': 2, 'production': 1600.0, 'correct': True}
Trial  7 {'measure': 1000.0, 'gain': 2, 'production': 2000.0, 'correct': True}
Trial  8 {'measure': 1200.0, 'gain': 2, 'production': 2400.0, 'correct': True}
Trial  9 {'measure': 800.0, 'gain': 2, 'production': 1600.0, 'correct': True}
Trial  10 {'measure': 1000.0, 'gain': 2, 'production': 2000.0, 'correct': True}
Trial  11 {'measure': 800.0, 'gain': 2, 'production': 1600.0, 'correct': True}
Trial  12 {'measure': 800.0, 'gain': 2, 'production': 160

  and should_run_async(code)
