In [1]:
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import numpy as np
import tqdm
from torch.utils.data import Dataset
from scipy import integrate

In [2]:
class GRU(torch.nn.Module):
    """Gated recurrent unit written out by hand, followed by a linear readout.

    Each step updates the hidden state ``H`` from the input ``X`` as::

        Z  = sigmoid(X @ W_xz + H @ W_hz + b_z)          # update gate
        R  = sigmoid(X @ W_xr + H @ W_hr + b_r)          # reset gate
        H~ = tanh(X @ W_xh + (R * H) @ W_hh + b_h)       # candidate state
        H  = Z * H + (1 - Z) * H~

    and the readout is ``fc(relu(H))``.

    Args:
        num_inputs: feature size of one input step.
        num_hiddens: size of the hidden state.
        num_outputs: size of the readout.
        per_timestep_readout: if True, ``forward`` returns a readout for every
            step; otherwise only for the final step.
        sigma: standard deviation of the Gaussian init of the weight matrices
            (biases start at zero).
    """

    def __init__(self, num_inputs, num_hiddens, num_outputs, per_timestep_readout=True, sigma=0.01):
        super().__init__()

        self.per_timestep_readout = per_timestep_readout
        # Gaussian random init with standard deviation *sigma*
        init_weight = lambda *shape: nn.Parameter(torch.randn(*shape) * sigma)

        # Every gate needs (1) a projection from the inputs, (2) a projection from the
        # latent state, and (3) a bias, so the three are built together.
        # Note that unlike biological RNNs, we **do not** introduce stochasticity in the activities.
        triple = lambda: (init_weight(num_inputs, num_hiddens),
                          init_weight(num_hiddens, num_hiddens),
                          nn.Parameter(torch.zeros(num_hiddens)))

        # update gate
        self.W_xz, self.W_hz, self.b_z = triple()

        # reset gate
        self.W_xr, self.W_hr, self.b_r = triple()

        # candidate hidden state
        self.W_xh, self.W_hh, self.b_h = triple()

        # readout layer parameters
        self.fc = nn.Linear(num_hiddens, num_outputs)
        self.relu = nn.ReLU()

    def _step(self, X, H):
        """Apply one GRU update to state ``H`` given input ``X``; ``H=None`` means a zero state."""
        Z = torch.sigmoid(torch.matmul(X, self.W_xz) +
                          (torch.matmul(H, self.W_hz) if H is not None else 0) + self.b_z)
        if H is None:
            # A zero state adds nothing to Z, so it is only materialised from here on,
            # with Z's shape/dtype/device.
            H = torch.zeros_like(Z)
        R = torch.sigmoid(torch.matmul(X, self.W_xr) +
                          torch.matmul(H, self.W_hr) + self.b_r)
        H_tilda = torch.tanh(torch.matmul(X, self.W_xh) +
                             torch.matmul(R * H, self.W_hh) + self.b_h)
        return Z * H + (1 - Z) * H_tilda

    def forward(self, inputs, H=None):
        """Run the GRU over a sequence (overrides ``nn.Module.forward``; called by ``model(inputs)``).

        Args:
            inputs: iterable over time steps. Each step ``X`` is right-multiplied by
                ``(num_inputs, num_hiddens)`` weights, e.g. a ``(T, B, num_inputs)`` tensor.
            H: optional initial hidden state; ``None`` starts from zeros.

        Returns:
            ``(outputs, readouts)``: ``outputs`` is the list of hidden states, one per
            step. ``readouts`` holds ``fc(relu(H))`` for every step, or only for the
            last step when ``per_timestep_readout`` is False.

        Raises:
            ValueError: if ``per_timestep_readout`` is False, ``inputs`` is empty and
                no initial ``H`` was given (there is no state to read out).
        """
        outputs = []
        readouts = []

        for X in inputs:
            H = self._step(X, H)
            outputs.append(H)

            if self.per_timestep_readout:
                readouts.append(self.fc(self.relu(H)))

        if not self.per_timestep_readout:
            if H is None:
                raise ValueError('empty input sequence and no initial hidden state: nothing to read out')
            # final timestep readout layer
            readouts.append(self.fc(self.relu(H)))

        return outputs, readouts

    def single_step(self, X, H):
        """Advance the GRU by one step from state ``H`` (``None`` = zero state, as in ``forward``).

        Returns:
            ``(H_new, fc(relu(H_new)))``.
        """
        H = self._step(X, H)
        return H, self.fc(self.relu(H))


In [3]:
class FitzhughNagumo(Dataset):
    """One-step-ahead prediction dataset of simulated FitzHugh-Nagumo trajectories.

    Each of the ``N`` samples integrates the two-variable system in ``FHN_rhs``
    with ``scipy.integrate.solve_ivp`` over ``[0, t_max]``. It starts from
    ``(x0, 0)``, where ``x0`` is drawn uniformly from ``[-1, 1)`` using NumPy's
    global RNG, and keeps the first variable at ``T + 1`` evenly spaced times.
    The input is samples ``0..T-1`` and the target is samples ``1..T``, i.e.
    the next value.

    Args:
        N: number of trajectories.
        T: number of time steps per sample.
        I: constant term added to the first variable's derivative.
        a, b: parameters of the second variable's derivative.
        eps: factor scaling the second variable's derivative (default ``1/50``).
        t_max: end of the integration window (default ``400``).

    Attributes:
        data_x, data_y: arrays of shape ``(N, T, 1)``.
    """

    def __init__(self, N, T, I=0.5, a=0.7, b=0.8, eps=1./50., t_max=400):
        self.I = I
        self.a = a
        self.b = b
        self.eps = eps
        self.t_max = t_max
        self.N = N
        self.T = T

        traces = [self._simulate() for _ in range(N)]
        # Input and target are the same trace offset by one sample.
        self.data_x = np.array([trace[:-1] for trace in traces]).reshape(N, T, 1)
        self.data_y = np.array([trace[1:] for trace in traces]).reshape(N, T, 1)

    def __len__(self):
        return self.data_x.shape[0]

    def __getitem__(self, idx):
        # torch.Tensor(...) copies the arrays into float32 tensors of shape (T, 1).
        return torch.Tensor(self.data_x[idx]), torch.Tensor(self.data_y[idx])

    def FHN_rhs(self, t, x):
        """Right-hand side of the FitzHugh-Nagumo ODE, in the ``fun(t, y)`` form ``solve_ivp`` expects.

        ``dx0/dt = x0 - x0**3 / 3 - x1 + I`` and ``dx1/dt = eps * (x0 + a - b * x1)``.
        ``t`` is unused (the system is autonomous). Returns an array shaped like ``x``.
        """
        x0, x1 = x[0], x[1]
        dim1 = x0 - (x0 ** 3) / 3. - x1 + self.I
        dim2 = self.eps * (x0 + self.a - self.b * x1)
        return np.stack((dim1, dim2))

    def get_init(self, n_steps=50):
        """Simulate one new trajectory and return its first ``n_steps`` samples of ``x[0]``.

        Uses the same random start and time grid as the dataset samples, so the
        result has ``min(n_steps, T + 1)`` entries.
        """
        return self._simulate()[:n_steps]

    def _simulate(self):
        """Integrate one trajectory from a random start; return ``x[0]`` at the ``T + 1`` grid times.

        Raises:
            RuntimeError: if ``solve_ivp`` reports failure (its trace would then be shorter than ``T + 1``).
        """
        t_eval = np.linspace(0, self.t_max, self.T + 1)
        # One uniform draw in [-1, 1) for the first variable; the second starts at 0.
        start = np.array([np.random.rand() * 2. - 1., 0.])
        sol = integrate.solve_ivp(self.FHN_rhs, [0, self.t_max], start, t_eval=t_eval)
        if not sol.success:
            raise RuntimeError('FitzHugh-Nagumo integration failed: {}'.format(sol.message))
        return sol.y[0]


In [4]:
def _plot_training_snapshot(ax, target, prediction, epoch):
    """Redraw ``ax`` with one target sequence and the model's prediction for it, then refresh the figure.

    ``target`` and ``prediction`` are 1-D tensors over time (on any device).
    """
    ax.clear()
    ax.plot(target.detach().cpu().numpy(), linewidth=2, color='tab:gray', label='groundtruth')
    ax.plot(prediction.detach().cpu().numpy(), '--', linewidth=2, color='r', label='prediction')

    # Just formatting options. This is my pet peeve so you can safely ignore!
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.set_title('Training epoch: {}'.format(epoch))
    ax.set_xlabel('Time', fontsize=16, fontweight='bold')
    ax.set_ylabel('Firing rate (in a.u.)', fontsize=16, fontweight='bold')
    ax.legend(loc='upper right')
    n_steps = target.shape[0]
    # NOTE(review): each sample is labelled as 1 ms; confirm against the dataset's time step.
    ax.set_xticks([0., n_steps])
    ax.set_xticklabels(['0ms', '{}ms'.format(n_steps)])
    ax.set_yticks([])
    ax.set_ylim([-2.5, 2.5])

    # Runs the GUI event loop briefly so the live figure is redrawn.
    plt.pause(0.1)


def train_model(model, dataset, params, visualize_train=True):
    """Train ``model`` for one-step-ahead prediction with Adam and an MSE loss.

    Args:
        model: recurrent module called as ``model(inputs)`` with ``inputs`` of shape
            ``(T, B, N_in)``. It must return ``(latents, readouts)``, where ``readouts``
            is a list with one ``(B, N_out)`` tensor per time step.
        dataset: map-style dataset yielding ``(input, target)`` pairs of shape
            ``(T, N_in)`` / ``(T, N_out)``.
        params: dict with ``'train_params'`` (``DataLoader`` kwargs), ``'init_lr'``
            and ``'num_epochs'``.
        visualize_train: if True, live-plot target vs. prediction for the first
            sequence of every minibatch. Under IPython this also switches matplotlib
            to the ``notebook`` backend.

    Returns:
        ``model`` itself, with its parameters updated in place.

    Raises:
        ValueError: if the data loader yields no batches (e.g. empty dataset).
    """
    import builtins

    # create the data generator to iterate over mini batches
    train_loader = torch.utils.data.DataLoader(dataset, **params['train_params'])

    criterion = torch.nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=params['init_lr'])
    # Train on whichever device the model's parameters live on (CPU unless moved beforehand).
    device = next(model.parameters()).device
    model.train()

    if visualize_train:
        # Same effect as the `%matplotlib notebook` magic, but valid plain Python:
        # IPython exposes `get_ipython` as a builtin. Outside IPython it is absent,
        # so the backend is left alone instead of failing.
        ipython_getter = getattr(builtins, 'get_ipython', None)
        if ipython_getter is not None:
            ipython_getter().run_line_magic('matplotlib', 'notebook')
        _, ax = plt.subplots()

    for epoch in range(params['num_epochs']):
        epoch_loss, n_batches = 0.0, 0

        for data, label in train_loader:
            # The model expects T x B x N_in (time first), where T is the duration of the
            # signal, B the batch size and N_in the feature size; the loader yields B x T x N_in.
            data = data.transpose(0, 1).to(device)
            label = label.to(device)

            # forward pass: one (B, N_out) readout per time step ...
            _, readouts = model(data)
            # ... stacked to (T, B, N_out) and transposed to (B, T, N_out) to line up with `label`
            prediction = torch.stack(readouts).transpose(0, 1)
            loss = criterion(prediction, label)

            # backpropagate through time, then update the model parameters
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()
            n_batches += 1

            if visualize_train:
                # Sequence 0 is an arbitrary example when the loader shuffles. The ground
                # truth is the target (next-step values), not the input sequence.
                _plot_training_snapshot(ax, label[0, :, 0], prediction[0, :, 0], epoch)

        if n_batches == 0:
            raise ValueError('the DataLoader yielded no batches; is the dataset empty?')
        # Mean minibatch loss over the epoch (equal to the batch loss when one batch covers the dataset).
        print('Epoch: {} | Training Loss: {}'.format(epoch, epoch_loss / n_batches))

    return model

In [5]:
# Create a dataset instance. What do the N and T signify?
# Seed NumPy first: the initial conditions are drawn from its global RNG,
# so this makes the generated dataset reproducible across Restart & Run All.
np.random.seed(0)
fhDataset = FitzhughNagumo(N=128, T=1000)

In [6]:
# Model and optimisation hyper-parameters.
# num_workers=0 loads batches in the main process. The dataset already lives in
# memory, and worker processes started with the 'spawn' method (the default on
# Windows/macOS) cannot unpickle a Dataset class defined inside a notebook.
params = {
    'n_inputs': 1,
    'n_hidden': 32,
    'num_epochs': 50,
    'init_lr': 1e-2,
    'n_outputs': 1,
    'train_params': {
        'batch_size': 128,
        'shuffle': True,
        'num_workers': 0,
    },
}

In [7]:
# initialize the model architecture and set it to train mode.
# Seeding torch makes the weight initialisation and the DataLoader's shuffling reproducible.
torch.manual_seed(0)
model = GRU(params['n_inputs'], params['n_hidden'], params['n_outputs'])
model = model.train()

In [8]:
# Fit the GRU to the FitzHugh-Nagumo traces with a live plot of its predictions
# (set visualize_train=False to train without any display).
model = train_model(model, fhDataset, params, visualize_train=True)

<IPython.core.display.Javascript object>

Epoch: 0 | Training Loss: 2.259450674057007
Epoch: 1 | Training Loss: 2.1657955646514893
Epoch: 2 | Training Loss: 2.067577600479126
Epoch: 3 | Training Loss: 1.934626817703247
Epoch: 4 | Training Loss: 1.7558752298355103
Epoch: 5 | Training Loss: 1.5159595012664795
Epoch: 6 | Training Loss: 1.1972414255142212
Epoch: 7 | Training Loss: 0.8041644096374512
Epoch: 8 | Training Loss: 0.43760445713996887
Epoch: 9 | Training Loss: 0.3495078682899475
Epoch: 10 | Training Loss: 0.46550247073173523
Epoch: 11 | Training Loss: 0.5076450109481812
Epoch: 12 | Training Loss: 0.455902099609375
Epoch: 13 | Training Loss: 0.3567546308040619
Epoch: 14 | Training Loss: 0.2519708573818207
Epoch: 15 | Training Loss: 0.1725485771894455
Epoch: 16 | Training Loss: 0.13224007189273834
Epoch: 17 | Training Loss: 0.12328262627124786
Epoch: 18 | Training Loss: 0.12786731123924255
Epoch: 19 | Training Loss: 0.13182291388511658
Epoch: 20 | Training Loss: 0.12870605289936066
Epoch: 21 | Training Loss: 0.117987245321

In [18]:
def evaluate_model(model, N=1024, T=2000, test_dataset=None):
    """Print and return the one-step-ahead MSE of ``model`` on held-out trajectories.

    Args:
        model: module called as in training: ``model(x)`` with ``x`` of shape
            ``(T, N, n_in)``, returning ``(latents, readouts)`` where ``readouts``
            is a list of ``(N, n_out)`` tensors, one per time step.
        N, T: number and length of the fresh ``FitzhughNagumo`` trajectories to
            generate (defaults are the previously hard-coded 1024 and 2000).
            Ignored when ``test_dataset`` is given.
        test_dataset: optional object with ``data_x`` / ``data_y`` arrays of shape
            ``(N, T, n_in)`` / ``(N, T, n_out)`` to evaluate on instead.

    Returns:
        The mean squared error as a Python float.
    """
    if test_dataset is None:
        # First off, let's create a new dataset. Since the initializations are random, we can
        # consider this a proper test! To make life harder for the model, let's change up T too.
        test_dataset = FitzhughNagumo(N=N, T=T)

    # We still need an evaluation criterion
    criterion = torch.nn.MSELoss()

    # Run on the model's device; a parameter-free model runs on the CPU.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    # Create the (float32) data tensors; the inputs go time-first, (T, N, n_in).
    x = torch.Tensor(test_dataset.data_x).transpose(0, 1).to(device)
    y = torch.Tensor(test_dataset.data_y)

    # Compute the feedforward pass.
    # But since we aren't training, we can do without the gradients
    with torch.no_grad():
        _, pred = model(x)
        # list of T (N, n_out) tensors -> (N, T, n_out), back on the CPU to match y
        pred = torch.stack(pred).transpose(0, 1).cpu()
        test_error = criterion(pred, y).item()

    # Q: How does this compare with the training loss?
    # What can you say about this?
    print('Test error: {}'.format(test_error))
    return test_error


In [17]:
# Put the network in evaluation mode (eval() works in place and returns the module),
# then measure how well it does on a freshly generated random test set.
model.eval()
evaluate_model(model);

Test error: 0.010034583508968353
