### The model


In [59]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, IterableDataset

from lightning.pytorch import LightningModule, Trainer, seed_everything
from lightning.pytorch.loggers import TensorBoardLogger
import neurogym as ngym

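The hidden activity $\mathbf{a}$ follows the leaky rate dynamics $\tau \frac{d\mathbf{a}}{dt} = -\mathbf{a} + f(W_{a\rightarrow a} \mathbf{a} + W_{x\rightarrow a} \mathbf{x} + \mathbf{b}_r)$. A forward-Euler step of size $\Delta t$ gives the update implemented below:

\begin{align}
    \mathbf{a}(t+\Delta t) = \mathbf{a}(t) + \Delta \mathbf{a} &= \mathbf{a}(t) + \frac{\Delta t}{\tau}[-\mathbf{a}(t) + f(W_{a\rightarrow a} \mathbf{a}(t) + W_{x\rightarrow a}  \mathbf{x}(t) + \mathbf{b}_r)] \\
    &= (1 - \frac{\Delta t}{\tau})\mathbf{a}(t) + \frac{\Delta t}{\tau}f(W_{a\rightarrow a} \mathbf{a}(t) + W_{x\rightarrow a}  \mathbf{x}(t) + \mathbf{b}_r)
\end{align}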

In [60]:
class LeakyRNN(nn.Module):
    """Leaky (rate) RNN layer integrated with forward Euler.

    For every timestep it computes

        a(t + delta_t) = (1 - alpha) * a(t) + alpha * f(W_in x(t) + b_in + W_rec a(t) + b_rec)

    with alpha = delta_t / tau and f = LeakyReLU.

    Args:
        num_input: number of input features per timestep.
        num_hidden: number of recurrent units.
        delta_t: Euler integration step; only the ratio delta_t / tau is used,
            so it must be expressed in the same time units as ``tau``.
        tau: neuronal time constant.
    """

    def __init__(self, num_input, num_hidden, delta_t = 0.05, tau = 100):

        super().__init__()

        self.input_size = num_input
        self.hidden_size = num_hidden

        self.delta_t = delta_t
        self.tau = tau
        # Euler mixing coefficient. NOTE(review): with the defaults this is
        # 0.05 / 100 = 5e-4, so each step moves the state very little; confirm
        # delta_t/tau match the task's timestep.
        self.alpha = self.delta_t / self.tau

        # W_{x->a} (input -> hidden) and W_{a->a} (hidden -> hidden).
        # Both layers carry a bias, so the equation's single b_r is represented
        # by the sum of the two bias vectors.
        self.inh = nn.Linear(self.input_size, self.hidden_size)
        self.hh = nn.Linear(self.hidden_size, self.hidden_size)

        self.nonlinearity = nn.LeakyReLU()


    def _init_hidden(self, input_shape, device=None, dtype=None):
        """Return a zero initial state of shape (batch, num_hidden).

        The batch size is read from ``input_shape[1]``. ``device``/``dtype``
        default to torch's defaults when not given.
        """
        batch_size = input_shape[1]
        return torch.zeros(batch_size, self.hidden_size, device=device, dtype=dtype)

    def recurrent(self, input, hidden):
        """Advance the state by one Euler step.

        Args:
            input: x(t), shape (batch, num_input).
            hidden: a(t), shape (batch, num_hidden).

        Returns:
            a(t + delta_t), same shape as ``hidden``.
        """
        # Drive term f(W_{x->a} x(t) + W_{a->a} a(t) + biases)
        drive = self.nonlinearity(self.inh(input) + self.hh(hidden))

        # Leaky integration: keep (1 - alpha) of a(t), mix in alpha of the drive
        return (1 - self.alpha) * hidden + self.alpha * drive

    def forward(self, input, hidden=None):
        """Run the RNN over a whole sequence.

        Args:
            input: tensor of shape (seq_len, batch, num_input); it is iterated
                over its first dimension, one timestep at a time.
            hidden: optional initial state of shape (batch, num_hidden). When
                None, a zero state is created on the same device and with the
                same dtype as ``input`` (so ``.double()``/``.half()`` models work).

        Returns:
            (outputs, hidden): ``outputs`` stacks the state after every step,
            shape (seq_len, batch, num_hidden); ``hidden`` is the final state.
        """
        if hidden is None:
            hidden = self._init_hidden(input.shape, device=input.device, dtype=input.dtype)

        states = []
        for x_t in input:
            hidden = self.recurrent(x_t, hidden)
            states.append(hidden)

        outputs = torch.stack(states, dim=0)     # seq_len, batch, hidden_size

        return outputs, hidden
    

class RNNNet(LightningModule):
    """Leaky RNN followed by a per-timestep linear readout.

    Trained with cross-entropy on every timestep: the (seq_len, batch,
    output_size) logits are flattened to (seq_len * batch, output_size) and
    compared with the targets, which are expected to be flattened in the same
    row-major (seq_len, batch) order.

    Args:
        input_size: number of input features per timestep.
        hidden_size: number of recurrent units.
        output_size: number of output classes.
        lr: Adam learning rate used by ``configure_optimizers``. Defaults to
            0.1, the value the optimizer was previously hard-coded to.
    """

    def __init__(self, input_size, hidden_size, output_size, lr=0.1):
        super().__init__()

        self.input_size = input_size
        self.hidden_size = hidden_size
        self.output_size = output_size
        # leaky recurrent layer
        self.rnn = LeakyRNN(input_size, hidden_size)

        # linear readout, applied independently at every timestep
        self.ho = nn.Linear(hidden_size, output_size)

        # loss func
        self.loss = torch.nn.CrossEntropyLoss()
        # Read by configure_optimizers, so changing model.lr actually changes
        # the optimizer (previously this attribute was ignored).
        self.lr = lr


    def forward(self, input):
        """Return (logits, hidden_states), both indexed (seq_len, batch, ...)."""
        leaky_out, _ = self.rnn(input)
        output = self.ho(leaky_out)

        return output, leaky_out


    def _step_loss(self, batch):
        """Compute the loss for an ``(inputs, flat_targets)`` batch.

        Returns:
            (loss, batch_size), where batch_size is ``inputs.shape[1]``.
        """
        x, y = batch
        output, _ = self(x)
        # (seq_len, batch, out) -> (seq_len * batch, out)
        logits = output.reshape(-1, self.output_size)
        return self.loss(logits, y), x.shape[1]


    def training_step(self, batch, x_idx):
        loss, batch_size = self._step_loss(batch)

        # batch_size is passed explicitly: the batch's first tensor has
        # seq_len as its leading dim, so an inferred batch size would be wrong.
        self.log('train_loss', loss, on_step=False, on_epoch=True, prog_bar=True,
                 batch_size=batch_size)

        return loss


    def test_step(self, batch, x_idx):
        test_loss, batch_size = self._step_loss(batch)

        self.log('test_loss', test_loss, batch_size=batch_size)


    def configure_optimizers(self):
        """Adam over all parameters, using ``self.lr``."""
        return torch.optim.Adam(self.parameters(), lr=self.lr)

In [61]:
# Task configuration: neurogym environment id, env kwargs, trial length, batch size
task = 'PerceptualDecisionMaking-v0'
kwargs = {'dt': 100}
seq_len = 100
batch_size = 64

# Build the neurogym dataset (it emits warnings that can be ignored)
ngym_dataset = ngym.Dataset(task, kwargs, batch_size, seq_len)
env = ngym_dataset.env

# Draw one batch only to inspect its shapes
inputs, target = ngym_dataset()
print('Input has shape (SeqLen, Batch, Dim) =', inputs.shape)
print('Target has shape (SeqLen, Batch) =', target.shape)

inputs = torch.from_numpy(inputs).float()
target = torch.from_numpy(target.flatten()).long()

# Model dimensions come straight from the environment's spaces
input_size, output_size = env.observation_space.shape[0], env.action_space.n
print(input_size, output_size)


Input has shape (SeqLen, Batch, Dim) = (100, 64, 3)
Target has shape (SeqLen, Batch) = (100, 64)
3 3


In [62]:
class NgymWrapper(IterableDataset):
    """Iterable dataset that converts neurogym batches into torch tensors.

    neurogym's dataset returns numpy arrays; this wrapper calls it and yields
    ``(inputs, targets)`` tensors:

    * ``inputs``: float32 tensor with the same shape as the sampled inputs.
    * ``targets``: int64 tensor holding the sampled targets flattened in
      row-major order (``ndarray.flatten()``).

    Args:
        dataset: callable returning ``(inputs, targets)`` numpy arrays,
            e.g. an ``ngym.Dataset``.
        num_batches: number of batches yielded per iteration (i.e. per epoch
            when used as a training dataloader). The dataset is called once per
            batch. Defaults to 1.

    TODO:
        allow multiple datasets
        should sample randomly?
        control batch size from here by building the ngym.Dataset in __init__
    """
    def __init__(self, dataset, num_batches=1):
        super().__init__()
        self.dataset = dataset
        self.num_batches = num_batches

    def __iter__(self):
        for _ in range(self.num_batches):
            inputs, targets = self.dataset()
            inputs = torch.from_numpy(inputs).type(torch.float)
            targets = torch.from_numpy(targets.flatten()).type(torch.long)

            # generator: hand back one batch, resume on the next request
            yield (inputs, targets)


In [63]:
# NOTE(review): ngym_dataset was created and sampled before this seed call,
# so the generated trials may not be reproducible; consider seeding earlier.
seed_everything(42)

# Trainer setup: fixed epoch budget, TensorBoard logs under logs/
trainer_kwargs = dict(
    max_epochs=25,
    logger=TensorBoardLogger('logs/'),
)
trainer = Trainer(**trainer_kwargs)

# One neurogym batch per epoch, converted to tensors by the wrapper
dataset = NgymWrapper(ngym_dataset)
hidden_size = 64

model = RNNNet(input_size, hidden_size, output_size)
print(model)

trainer.fit(model, train_dataloaders=dataset)

Seed set to 42
💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name | Type             | Params | Mode 
--------------------------------------------------
0 | rnn  | LeakyRNN         | 4.4 K  | train
1 | ho   | Linear           | 195    | train
2 | loss | CrossEntropyLoss | 0      | train
--------------------------------------------------
4.6 K     Trainable params
0         Non-trainable params
4.6 K     Total params
0.018     Total estimated model params size (MB)
6         Modules in train mode
0         Modules in eval mode


RNNNet(
  (rnn): LeakyRNN(
    (inh): Linear(in_features=3, out_features=64, bias=True)
    (hh): Linear(in_features=64, out_features=64, bias=True)
    (nonlinearity): LeakyReLU(negative_slope=0.01)
  )
  (ho): Linear(in_features=64, out_features=3, bias=True)
  (loss): CrossEntropyLoss()
)
Epoch 24: |          | 1/? [00:00<00:00, 11.67it/s, v_num=4, train_loss=0.253]

`Trainer.fit` stopped: `max_epochs=25` reached.


Epoch 24: |          | 1/? [00:00<00:00, 10.33it/s, v_num=4, train_loss=0.253]


### Testing
Evaluate the trained model on a new, unseen batch of neurogym data.

In [64]:
# Fresh neurogym dataset with the same task settings as training
test_ngym_dataset = ngym.Dataset(task, kwargs, batch_size, seq_len)
test_dataset = NgymWrapper(test_ngym_dataset)

# Evaluate; the returned list of metric dicts is shown as the cell output
trainer.test(model, dataloaders=test_dataset)


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing DataLoader 0: |          | 1/? [00:00<00:00, 25.25it/s]
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
        test_loss           0.20977191627025604
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────

[{'test_loss': 0.20977191627025604}]