Import libraries

In [24]:
# imports
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
from torch.nn import init
from torch.nn import functional as F
import math
import time
import neurogym as ngym

## NeuroGym Task 1

In [2]:
task_name = 'PerceptualDecisionMaking-v0'
# Environment settings handed to ngym.Dataset as env_kwargs:
# dt is the simulation step and timing overrides the stimulus epoch duration.
kwargs = dict(dt=20, timing=dict(stimulus=1000))

In [3]:
# Build a supervised dataset from the NeuroGym task
# (hover over ngym.Dataset to see its input arguments).
seq_len = 100
batch_size = 16
dataset = ngym.Dataset(
    task_name,
    env_kwargs=kwargs,
    seq_len=seq_len,
    batch_size=batch_size,
)
env = dataset.env

# Each call to the dataset yields one (inputs, target) batch
inputs, target = dataset()
inputs = torch.from_numpy(inputs).float()

# Network dimensions come from the environment's observation/action spaces
input_size = env.observation_space.shape[0]
output_size = env.action_space.n

print('Input has shape (SeqLen, Batch, Dim) =', inputs.shape)
print('Target has shape (SeqLen, Batch) =', target.shape)

Input has shape (SeqLen, Batch, Dim) = torch.Size([100, 16, 3])
Target has shape (SeqLen, Batch) = (100, 16)


In [None]:
def train_model(net, dataset, criterion, model=None, *,
                lr=0.01, n_steps=2000, print_every=100):
    """Train ``net`` with Adam on batches drawn from ``dataset``.

    Args:
        net: a pytorch nn.Module whose forward returns ``(output, activity)``,
            with ``output`` shaped (SeqLen, Batch, OutputSize).
        dataset: a callable that returns an ``(inputs, labels)`` pair of numpy
            arrays each time it is called.
        criterion: a loss function, called as ``criterion(output, labels)``,
            or as ``criterion(output, labels, model)`` when ``model`` is given.
        model: optional object forwarded to ``criterion`` (e.g. the network
            itself, so the loss can add weight penalties).
        lr: Adam learning rate.
        n_steps: number of training batches.
        print_every: report the mean loss every this many steps.

    Returns:
        net: the same network object, trained in place.
    """
    optimizer = optim.Adam(net.parameters(), lr=lr)

    running_loss = 0.0
    start_time = time.time()
    print('Training network...')
    for i in range(n_steps):
        # Fresh batch each step; convert numpy arrays to torch tensors
        inputs, labels = dataset()
        inputs = torch.from_numpy(inputs).type(torch.float)
        labels = torch.from_numpy(labels.flatten()).type(torch.long)

        optimizer.zero_grad()
        output, _ = net(inputs)
        # Flatten (SeqLen, Batch, OutputSize) -> (SeqLen*Batch, OutputSize)
        # to line up with the flattened labels. The class count is read from
        # the output itself rather than from a notebook-global variable.
        output = output.reshape(-1, output.shape[-1])
        if model is None:
            loss = criterion(output, labels)
        else:
            loss = criterion(output, labels, model)
        loss.backward()
        optimizer.step()

        # Report the mean loss over the last ``print_every`` steps
        running_loss += loss.item()
        if i % print_every == print_every - 1:
            running_loss /= print_every
            print('Step {}, Loss {:0.4f}, Time {:0.1f}s'.format(
                i + 1, running_loss, time.time() - start_time))
            running_loss = 0.0
    return net

## Model 1: Leaky RNN (from notes)

In [11]:
class LeakyRNN(nn.Module):
    """Leaky ReLU RNN.

    Each step moves the hidden state a fraction ``alpha = dt / tau`` of the way
    toward ``relu(W_in x + W_h h)``.

    Parameters:
        input_size: Number of input neurons
        hidden_size: Number of hidden neurons
        dt: discretization time step in ms.
            If None, dt equals time constant tau (alpha = 1, i.e. no leak).

    Inputs:
        input: tensor of shape (seq_len, batch, input_size)
        hidden: tensor of shape (batch, hidden_size), initial hidden activity
            if None, hidden is initialized through self.init_hidden()

    Outputs:
        output: tensor of shape (seq_len, batch, hidden_size)
        hidden: tensor of shape (batch, hidden_size), final hidden activity
    """

    def __init__(self, input_size, hidden_size, dt=None, **kwargs):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.tau = 100
        # Fraction of the update applied per step
        self.alpha = 1 if dt is None else dt / self.tau

        # Input projection first, then recurrent projection
        self.input2h = nn.Linear(input_size, hidden_size)
        self.h2h = nn.Linear(hidden_size, hidden_size)

    def init_hidden(self, input_shape):
        """Return an all-zero hidden state of shape (batch, hidden_size)."""
        return torch.zeros(input_shape[1], self.hidden_size)

    def recurrence(self, input, hidden):
        """Advance the network by one time step.

        Inputs:
            input: tensor of shape (batch, input_size)
            hidden: tensor of shape (batch, hidden_size)

        Outputs:
            tensor of shape (batch, hidden_size), activity at the next step
        """
        drive = torch.relu(self.input2h(input) + self.h2h(hidden))
        # Leaky integration: retain (1 - alpha) of the previous activity
        return (1 - self.alpha) * hidden + self.alpha * drive

    def forward(self, input, hidden=None):
        """Propagate input through the network, one time step at a time."""
        if hidden is None:
            hidden = self.init_hidden(input.shape).to(input.device)

        states = []
        for x_t in input:  # iterates over the seq_len dimension
            hidden = self.recurrence(x_t, hidden)
            states.append(hidden)

        # (seq_len, batch, hidden_size)
        return torch.stack(states, dim=0), hidden


class RNNNet(nn.Module):
    """Leaky RNN followed by a linear readout.

    Parameters:
        input_size: int, input size
        hidden_size: int, hidden size
        output_size: int, output size
        **kwargs: forwarded to LeakyRNN (e.g. dt)

    Inputs:
        x: tensor of shape (Seq Len, Batch, Input size)

    Outputs:
        out: tensor of shape (Seq Len, Batch, Output size)
        rnn_output: tensor of shape (Seq Len, Batch, Hidden size)
    """
    def __init__(self, input_size, hidden_size, output_size, **kwargs):
        super().__init__()
        self.rnn = LeakyRNN(input_size, hidden_size, **kwargs)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        hidden_seq = self.rnn(x)[0]
        # The readout is applied independently at every time step
        return self.fc(hidden_seq), hidden_seq

In [12]:
# Build the leaky-RNN model, show its layers, then train with cross-entropy
hidden_size = 128
net = RNNNet(input_size, hidden_size, output_size, dt=env.dt)
print(net)

criterion = nn.CrossEntropyLoss()
net = train_model(net, dataset, criterion)

RNNNet(
  (rnn): LeakyRNN(
    (input2h): Linear(in_features=3, out_features=128, bias=True)
    (h2h): Linear(in_features=128, out_features=128, bias=True)
  )
  (fc): Linear(in_features=128, out_features=3, bias=True)
)
Training network...
Step 100, Loss 0.1700, Time 3.0s
Step 200, Loss 0.0829, Time 5.9s
Step 300, Loss 0.0568, Time 9.0s
Step 400, Loss 0.0431, Time 12.0s
Step 500, Loss 0.0378, Time 15.0s
Step 600, Loss 0.0337, Time 18.1s
Step 700, Loss 0.0319, Time 21.2s
Step 800, Loss 0.0296, Time 24.4s
Step 900, Loss 0.0284, Time 27.5s
Step 1000, Loss 0.0285, Time 30.6s
Step 1100, Loss 0.0337, Time 33.8s
Step 1200, Loss 0.0286, Time 37.0s
Step 1300, Loss 0.0274, Time 40.2s
Step 1400, Loss 0.0257, Time 43.6s
Step 1500, Loss 0.0295, Time 46.8s
Step 1600, Loss 0.0262, Time 50.2s
Step 1700, Loss 0.0262, Time 53.4s
Step 1800, Loss 0.0250, Time 56.7s
Step 1900, Loss 0.0255, Time 60.0s
Step 2000, Loss 0.0262, Time 63.2s


## Model 2: Brain-inspired — adding penalties to the cost function to discourage non-brain-like solutions

In [None]:
# define penalties
def sparsity_penalty(model, beta_sparsity=0.01):
    """Return an L1 penalty on the readout weights ``model.fc.weight``.

    Only the output layer is penalized; input and recurrent weights are not.

    Args:
        model: object exposing an ``fc`` layer with a ``weight`` tensor.
        beta_sparsity: scale of the penalty.

    Returns:
        ``beta_sparsity * sum(|model.fc.weight|)`` as a scalar tensor.
    """
    readout = model.fc.weight
    return beta_sparsity * readout.abs().sum()

def firing_rate_penalty(output, beta_firing_rate=0.1):
    """Return ``beta_firing_rate`` times the sum of squares of ``output``.

    The sum runs over every element, so the value grows with the number of
    elements (e.g. time steps times batch size).

    NOTE(review): in this notebook the loss receives the network's first
    forward output (the readout logits), not the hidden-unit activity.
    """
    squared = output.square()
    return beta_firing_rate * squared.sum()


# define loss function
def loss_brain_penalties(output, target, model,
                         beta_sparsity=0.05, beta_firing_rate=0.5):
    """Cross-entropy task loss plus sparsity and firing-rate penalties.

    Args:
        output: logits of shape (N, num_classes); train_model passes the
            flattened (SeqLen*Batch, OutputSize) network output.
        target: class indices of shape (N,).
        model: network handed to ``sparsity_penalty`` (needs ``model.fc``).
        beta_sparsity: weight of the L1 penalty on ``model.fc.weight``;
            0 disables it.
        beta_firing_rate: weight of the sum-of-squares penalty on ``output``;
            0 disables it.

    Returns:
        Scalar tensor: task loss + sparsity penalty + firing-rate penalty.

    NOTE(review): the firing-rate term is computed on ``output`` (the readout
    logits), not on the hidden RNN activity, and ``firing_rate_penalty`` sums
    rather than averages. With a large beta it can dominate the task loss
    and push all logits toward zero.
    """
    # Base task-specific loss (same as nn.CrossEntropyLoss with defaults)
    task_loss = F.cross_entropy(output, target)

    # Out-of-place sum so the task-loss tensor is not mutated in place
    return (task_loss
            + sparsity_penalty(model, beta_sparsity)
            + firing_rate_penalty(output, beta_firing_rate))

### Sparsity penalty

In [21]:
# Fresh leaky-RNN network, trained with only the sparsity penalty switched on
net = RNNNet(input_size, hidden_size, output_size, dt=env.dt)  # same as before


def sparsity_only_loss(output, target, model):
    """Task loss + L1 readout penalty (beta 0.5); firing-rate penalty off."""
    return loss_brain_penalties(output, target, model,
                                beta_sparsity=0.5, beta_firing_rate=0)


# model=net is forwarded to the loss so it can compute the weight penalty
net = train_model(net, dataset, sparsity_only_loss, model=net)


Training network...
Step 100, Loss 1.4102, Time 3.0s
Step 200, Loss 0.6318, Time 6.1s
Step 300, Loss 0.6044, Time 9.2s
Step 400, Loss 0.6031, Time 12.3s
Step 500, Loss 0.5947, Time 15.3s
Step 600, Loss 0.6326, Time 18.5s
Step 700, Loss 0.5851, Time 21.8s
Step 800, Loss 0.5684, Time 25.0s
Step 900, Loss 0.5649, Time 28.2s
Step 1000, Loss 0.5655, Time 31.5s
Step 1100, Loss 0.5685, Time 34.8s
Step 1200, Loss 0.5654, Time 38.0s
Step 1300, Loss 0.5682, Time 41.4s
Step 1400, Loss 0.5684, Time 44.8s
Step 1500, Loss 0.5655, Time 48.1s
Step 1600, Loss 0.5694, Time 51.5s
Step 1700, Loss 0.5683, Time 55.0s
Step 1800, Loss 0.5645, Time 58.3s
Step 1900, Loss 0.5658, Time 61.7s
Step 2000, Loss 0.5659, Time 65.1s


### Low Firing Rate Penalty

In [22]:
# Fresh leaky-RNN network, trained with only the firing-rate penalty switched on
net = RNNNet(input_size, hidden_size, output_size, dt=env.dt)  # same as before


def firing_rate_only_loss(output, target, model):
    """Task loss + firing-rate penalty (beta 0.5); sparsity penalty off."""
    return loss_brain_penalties(output, target, model,
                                beta_sparsity=0, beta_firing_rate=0.5)


# model=net is forwarded to the loss, as train_model expects when model is set
net = train_model(net, dataset, firing_rate_only_loss, model=net)

Training network...
Step 100, Loss 14.9122, Time 2.9s
Step 200, Loss 1.1213, Time 5.7s
Step 300, Loss 1.1071, Time 8.6s
Step 400, Loss 1.1025, Time 11.5s
Step 500, Loss 1.1014, Time 14.5s
Step 600, Loss 1.1008, Time 18.7s
Step 700, Loss 1.1007, Time 22.1s
Step 800, Loss 1.1006, Time 25.5s
Step 900, Loss 1.1007, Time 29.1s
Step 1000, Loss 1.1002, Time 32.6s
Step 1100, Loss 1.1007, Time 36.1s
Step 1200, Loss 1.1000, Time 39.5s
Step 1300, Loss 1.1004, Time 42.9s
Step 1400, Loss 1.1000, Time 46.2s
Step 1500, Loss 1.0998, Time 49.6s
Step 1600, Loss 1.1000, Time 53.2s
Step 1700, Loss 1.0998, Time 56.6s
Step 1800, Loss 1.1001, Time 60.0s
Step 1900, Loss 1.0993, Time 63.3s
Step 2000, Loss 1.0998, Time 66.7s


# Model 3: Excitatory-Inhibitory (EI) RNN

In [25]:
class EIRecLinear(nn.Module):

    r"""Recurrent E-I Linear transformation.

    This module implements a linear transformation with recurrent E-I dynamics,
    where part of the units are excitatory and the rest are inhibitory:
    ``y = x @ W_eff.T + bias`` with ``W_eff = relu(weight) * mask``.

    Args:
        hidden_size: int, the number of units in the layer.
        e_prop: float between 0 and 1, the proportion of excitatory units.
            The first ``int(e_prop * hidden_size)`` units are excitatory and
            the remaining ones inhibitory.
        bias: bool, if True, adds a learnable bias to the output.
    """

    __constants__ = ['bias', 'hidden_size', 'e_prop']

    def __init__(self, hidden_size, e_prop, bias=True):
        super().__init__()
        self.hidden_size = hidden_size
        self.e_prop = e_prop
        self.e_size = int(e_prop * hidden_size)  # Number of excitatory units
        self.i_size = hidden_size - self.e_size  # Number of inhibitory units

        # Raw (unconstrained) recurrent weights; filled in by reset_parameters
        self.weight = nn.Parameter(torch.empty(hidden_size, hidden_size))

        # Sign mask: column j carries the sign of presynaptic unit j
        # (+1 excitatory, -1 inhibitory) and the diagonal is zeroed so no unit
        # connects to itself. With F.linear (y = x W^T), W[i, j] is the weight
        # from unit j onto unit i.
        signs = torch.cat([torch.ones(self.e_size), -torch.ones(self.i_size)])
        mask_no_diag = torch.ones(hidden_size, hidden_size) - torch.eye(hidden_size)
        # Non-persistent buffer: it follows .to(device)/.cuda() together with
        # the parameters but is not written to state_dict, so checkpoints keep
        # the same keys.
        self.register_buffer('mask', signs * mask_no_diag, persistent=False)

        # Optionally add a bias term
        if bias:
            self.bias = nn.Parameter(torch.empty(hidden_size))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        """Kaiming-uniform weights; bias uniform in +-1/sqrt(fan_in)."""
        init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        # Shrink excitatory columns by i_size/e_size so that, at init, summed
        # excitatory and inhibitory weight magnitudes per row are comparable
        # in expectation. Skipped when a population is empty (e_prop=1 would
        # otherwise divide by zero; with e_size=0 the slice is empty anyway).
        if self.e_size > 0 and self.i_size > 0:
            with torch.no_grad():
                self.weight[:, :self.e_size] /= (self.e_size / self.i_size)

        if self.bias is not None:
            fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / math.sqrt(fan_in)
            init.uniform_(self.bias, -bound, bound)

    def effective_weight(self):
        """Return ``relu(weight) * mask``.

        Rectification makes every magnitude non-negative; the mask then makes
        outgoing weights of excitatory units positive and of inhibitory units
        negative, with self-connections zeroed.

        NOTE(review): with relu (rather than abs), raw weights that are
        negative give a zero effective weight and get no gradient through it.
        """
        return F.relu(self.weight) * self.mask

    def forward(self, input):
        """Apply ``input @ effective_weight().T + bias``."""
        return F.linear(input, self.effective_weight(), self.bias)


In [26]:
class EIRNN(nn.Module):
    """E-I RNN.

    Reference:
        Song, H.F., Yang, G.R. and Wang, X.J., 2016.
        Training excitatory-inhibitory recurrent neural networks
        for cognitive tasks: a simple and flexible framework.
        PLoS computational biology, 12(2).

    Args:
        input_size: Number of input neurons
        hidden_size: Number of hidden neurons
        dt: discretization time step, in the same units as tau (fixed at 100).
            If None, alpha = 1 (no leak).
        e_prop: float between 0 and 1, proportion of excitatory neurons.
            Used both for ``self.e_size``/``self.i_size`` and for the E/I
            split of the recurrent layer.
        sigma_rec: recurrent noise level; the per-step noise std is
            ``sqrt(2 * alpha) * sigma_rec``.
        **kwargs: ignored.

    Inputs:
        input: (seq_len, batch, input_size)
        hidden: optional (state, output) tuple of (batch, hidden_size)
            tensors; zeros if None.

    Outputs:
        output: (seq_len, batch, hidden_size), rectified activity per step
        hidden: final (state, output) tuple
    """

    def __init__(self, input_size, hidden_size, dt=None,
                 e_prop=0.8, sigma_rec=0, **kwargs):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.e_size = int(hidden_size * e_prop)
        self.i_size = hidden_size - self.e_size
        self.num_layers = 1
        self.tau = 100
        if dt is None:
            alpha = 1
        else:
            alpha = dt / self.tau
        self.alpha = alpha
        self.oneminusalpha = 1 - alpha
        # Recurrent noise std, scaled by the discretization (sqrt(2*alpha)) and
        # noise level (sigma_rec); adds stochasticity to the recurrent dynamics
        self._sigma_rec = np.sqrt(2*alpha) * sigma_rec

        self.input2h = nn.Linear(input_size, hidden_size)
        # Forward e_prop so the recurrent layer's E/I split matches
        # self.e_size / self.i_size (Net reads the first e_size units as E).
        self.h2h = EIRecLinear(hidden_size, e_prop=e_prop)

    def init_hidden(self, input):
        """Return zero (state, output) tensors of shape (batch, hidden_size)."""
        batch_size = input.shape[1]
        return (torch.zeros(batch_size, self.hidden_size).to(input.device),
                torch.zeros(batch_size, self.hidden_size).to(input.device))

    def recurrence(self, input, hidden):
        """Advance (state, output) by one time step.

        ``state`` is leakily integrated toward the total input, noise is added
        (in both train and eval mode), and ``output`` is relu(state).
        """
        state, output = hidden
        total_input = self.input2h(input) + self.h2h(output)

        state = state * self.oneminusalpha + total_input * self.alpha
        state += self._sigma_rec * torch.randn_like(state)
        output = torch.relu(state)
        return state, output

    def forward(self, input, hidden=None):
        """Propagate input through the network."""
        if hidden is None:
            hidden = self.init_hidden(input)

        output = []
        steps = range(input.size(0))
        for i in steps:
            hidden = self.recurrence(input[i], hidden)
            output.append(hidden[1])  # keep the rectified activity
        output = torch.stack(output, dim=0)
        return output, hidden

In [27]:
class Net(nn.Module):
    """E-I recurrent network with a linear readout from excitatory units only.

    Args:
        input_size: int, input size
        hidden_size: int, hidden size (excitatory + inhibitory units)
        output_size: int, output size
        **kwargs: forwarded to EIRNN (e.g. dt, e_prop, sigma_rec)

    forward(x) returns:
        out: readout computed from the first ``rnn.e_size`` (excitatory) units
        rnn_activity: activity of all hidden units at every time step
    """
    def __init__(self, input_size, hidden_size, output_size, **kwargs):
        super().__init__()
        self.rnn = EIRNN(input_size, hidden_size, **kwargs)
        # Readout only sees the excitatory sub-population
        self.fc = nn.Linear(self.rnn.e_size, output_size)

    def forward(self, x):
        activity = self.rnn(x)[0]
        n_exc = self.rnn.e_size
        return self.fc(activity[:, :, :n_exc]), activity

In [29]:
# Instantiate the E-I network (default e_prop=0.8 -> 40 excitatory and
# 10 inhibitory units) with recurrent noise, then train it with the shared
# train_model helper and plain cross-entropy.
hidden_size = 50
net = Net(input_size=input_size, hidden_size=hidden_size,
          output_size=output_size, dt=env.dt, sigma_rec=0.15)
print(net)

criterion = nn.CrossEntropyLoss()
net = train_model(net, dataset, criterion)

  self.mask = torch.tensor(mask, dtype=torch.float32)


Net(
  (rnn): EIRNN(
    (input2h): Linear(in_features=3, out_features=50, bias=True)
    (h2h): EIRecLinear()
  )
  (fc): Linear(in_features=40, out_features=3, bias=True)
)
Training network...
Step 100, Loss 0.2854, Time 3.9s
Step 200, Loss 0.1045, Time 7.8s
Step 300, Loss 0.0595, Time 11.7s
Step 400, Loss 0.0480, Time 15.6s
Step 500, Loss 0.0405, Time 19.7s
Step 600, Loss 0.0359, Time 24.4s
Step 700, Loss 0.0337, Time 28.8s
Step 800, Loss 0.0349, Time 33.2s
Step 900, Loss 0.0308, Time 37.4s
Step 1000, Loss 0.0296, Time 41.6s
Step 1100, Loss 0.0286, Time 46.1s
Step 1200, Loss 0.0285, Time 50.3s
Step 1300, Loss 0.0295, Time 54.7s
Step 1400, Loss 0.0269, Time 58.9s
Step 1500, Loss 0.0278, Time 63.2s
Step 1600, Loss 0.0271, Time 67.7s
Step 1700, Loss 0.0286, Time 72.1s
Step 1800, Loss 0.0275, Time 76.6s
Step 1900, Loss 0.0252, Time 81.3s
Step 2000, Loss 0.0262, Time 85.9s
