## Pytorch supervised learning of perceptual decision making task

Pytorch-based example code for training a RNN on a perceptual decision-making task.

### Dataset

In [1]:
import numpy as np
import torch
import torch.nn as nn
import json
import neurogym as ngym
import gym
import pandas as pd
import matplotlib.pyplot as plt

In [None]:
# Environment
task = 'GoNogo-v0'  # 'PerceptualDecisionMaking-v0', 'DelayComparison-v0'
kwargs = {'dt': 100}
seq_len = 100

# Make supervised dataset
dataset = ngym.Dataset(task, env_kwargs=kwargs, batch_size=32,
                       seq_len=seq_len)
env = dataset.env
ob_size = env.observation_space.shape[0]
act_size = env.action_space.n

### Network and Training

In [3]:
class Net(nn.Module):
    def __init__(self, num_h):
        super(Net, self).__init__()
        self.rnn = nn.RNN(ob_size, num_h)
        self.linear = nn.Linear(num_h, act_size)

    def forward(self, x):
        out, hidden = self.rnn(x)
        x = self.linear(out)
        return x, out

device = 'cuda' if torch.cuda.is_available() else 'cpu'
net = Net(num_h=8).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(net.parameters(), lr=1e-2)

running_loss = 0.0
for i in range(2000):
    inputs, labels = dataset()
    inputs = torch.from_numpy(inputs).type(torch.float).to(device)
    labels = torch.from_numpy(labels.flatten()).type(torch.long).to(device)
    # print(inputs.shape, labels.shape)  # inputs: (seq_len, batch_size, 3), labels: seq_len * batch_size

    # zero the parameter gradients
    optimizer.zero_grad()

    # forward + backward + optimize
    outputs, _ = net(inputs)

    loss = criterion(outputs.view(-1, act_size), labels)
    loss.backward()
    optimizer.step()

    # print statistics
    running_loss += loss.item()
    if i % 200 == 199:
        print('{:d} loss: {:0.5f}'.format(i + 1, running_loss / 200))
        running_loss = 0.0

print('Finished Training')

200 loss: 0.28027
400 loss: 0.22989
600 loss: 0.22902
800 loss: 0.22769
1000 loss: 0.06301
1200 loss: 0.02177
1400 loss: 0.02070
1600 loss: 0.02067
1800 loss: 0.02030
2000 loss: 0.02056
Finished Training


In [5]:
torch.save(net.state_dict(), "taskRNN_data/GoNogo/rnn8/rnn8_ep2000.pth")

In [6]:
config = {
    'dt': 100,
    'hidden_size': 8,
    'lr': 1e-2,
    'batch_size': 32,
    'seq_len': 100,
    'envid': task,
}

env_kwargs = {'dt': config['dt']}
config['env_kwargs'] = env_kwargs

# Save config
with open('taskRNN_data/GoNogo/rnn8/config_rnn8.json', 'w') as f:
    json.dump(config, f)

In [7]:
def infer_test_timing(env):
    """Infer timing of environment for testing."""
    timing = {}
    for period in env.timing.keys():
        period_times = [env.sample_time(period) for _ in range(100)]
        timing[period] = np.median(period_times)
    return timing

In [8]:
"""Run trained networks for analysis.

Args:
    envid: str, Environment ID

Returns:
    activity: a list of activity matrices, each matrix has shape (
    N_time, N_neuron)
    info: pandas dataframe, each row is information of a trial
    config: dict of network, training configurations
"""
with open('taskRNN_data/GoNogo/rnn8/config_rnn8.json') as f:
    config = json.load(f)

env_kwargs = config['env_kwargs']

# Run network to get activity and info
# Environment
env = gym.make(task, **env_kwargs)
env.timing = infer_test_timing(env)
print(env.timing)
env.reset(no_step=True)

# Instantiate the network and print information
with torch.no_grad():
    net = Net(
              num_h=config['hidden_size'],
              )
    net = net.to(device)
    net.load_state_dict(torch.load('taskRNN_data/GoNogo/rnn8/rnn8_ep2000.pth'))

    perf = 0
    num_trial = 1000

    activity_alltrials = list()
    inputs_alltrials = list()
    gt_alltrials = list()
    action_pred_alltrials = list()
    info = pd.DataFrame()

    for i in range(num_trial):
        env.new_trial()
        ob, gt = env.ob, env.gt
        inputs = torch.from_numpy(ob[:, np.newaxis, :]).type(torch.float)
        action_pred, hidden = net(inputs)

        # Compute performance
        action_pred = action_pred.detach().numpy()
        choice = np.argmax(action_pred[-1, 0, :])
        correct = choice == gt[-1]

        # Log trial info
        trial_info = env.trial
        trial_info.update({'correct': correct, 'choice': choice})
        # info = info.append(trial_info, ignore_index=True)
        info = pd.concat([info, pd.DataFrame([trial_info])], ignore_index=True)

        # Log stimulus period activity
        inputs_alltrials.append(inputs.numpy()[:, 0, :])
        gt_alltrials.append(gt)
        activity_alltrials.append(np.array(hidden)[:, 0, :])
        action_pred_alltrials.append(action_pred)

    print('Average performance', np.mean(info['correct']))

{'fixation': 0.0, 'stimulus': 500.0, 'delay': 500.0, 'decision': 500.0}
Average performance 1.0


In [18]:
info

Unnamed: 0,ground_truth,correct,choice
0,1,True,1
1,0,True,0
2,0,True,0
3,1,True,1
4,1,True,1
...,...,...,...
995,0,True,0
996,1,True,1
997,0,True,0
998,0,True,0


In [9]:
activity_alltrials = np.array(activity_alltrials)
inputs_alltrials = np.array(inputs_alltrials)
gt_alltrials = np.array(gt_alltrials)
action_pred_alltrials = np.array(action_pred_alltrials)

In [10]:
activity_alltrials.shape, inputs_alltrials.shape, gt_alltrials.shape, action_pred_alltrials.shape

((1000, 15, 8), (1000, 15, 3), (1000, 15), (1000, 15, 1, 2))

In [11]:
action_pred_alltrials = action_pred_alltrials[:, :, 0, :]

In [12]:
action_pred_alltrials.shape

(1000, 15, 2)

In [14]:
np.save('taskRNN_data/GoNogo/rnn8/hidden_activity_alltrials.npy', activity_alltrials)
np.save('taskRNN_data/GoNogo/rnn8/inputs_alltrials.npy', inputs_alltrials)
np.save('taskRNN_data/GoNogo/rnn8/gt_alltrials.npy', gt_alltrials)
np.save('taskRNN_data/GoNogo/rnn8/action_pred_alltrial.npy', action_pred_alltrials)

In [19]:
HH_mat = net.rnn.weight_hh_l0.detach()

In [20]:
HH_mat.shape

torch.Size([8, 8])

In [21]:
np.save('taskRNN_data/GoNogo/rnn8/W_hidden_GT.npy', HH_mat)