In [13]:
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import numpy as np
import tqdm
from torch.utils.data import Dataset
from copy import deepcopy
from scipy import integrate
import os

In [10]:
class GRU(torch.nn.Module):
    def __init__(self, num_inputs, num_hiddens, num_outputs, per_timestep_readout=True, sigma=0.01):
        super().__init__()

        self.per_timestep_readout = per_timestep_readout
        # Gaussian random init with standard deviation *sigma*
        init_weight = lambda *shape: nn.Parameter(torch.randn(*shape) * sigma)

        # It is easier to initialize it this way since we always need to worry about
        # (1) projections from the inputs, (2) projections from the latent state, and (3) the bias
        # Note that unlike biological RNNs, we **do not** introduce stochasticity in the activities
 
        triple = lambda: (init_weight(num_inputs, num_hiddens),
                          init_weight(num_hiddens, num_hiddens),
                          nn.Parameter(torch.zeros(num_hiddens)))

        # create the parameters for the update gate
        self.W_xz, self.W_hz, self.b_z = triple()

        # create the parameters for the reset gate
        self.W_xr, self.W_hr, self.b_r = triple()

        # hidden state parameters
        self.W_xh, self.W_hh, self.b_h = triple()

        # readout layer parameters
        self.fc = nn.Linear(num_hiddens, num_outputs)
        self.relu = nn.ReLU()

    ''' Given that our parent class is nn.Module, what we are doing here is essentially *overloading*
    This is the function that will be called when we pass a batch of inputs to the GRU
    '''
    def forward(self, inputs, H=None):
        matmul_H = lambda A, B: torch.matmul(A, B) if H is not None else 0
        outputs = []
        readouts = []

        for X in inputs:
            Z = torch.sigmoid(torch.matmul(X, self.W_xz) + (
                torch.matmul(H, self.W_hz) if H is not None else 0) + self.b_z)
            if H is None: H = torch.zeros_like(Z)
            R = torch.sigmoid(torch.matmul(X, self.W_xr) +
                            torch.matmul(H, self.W_hr) + self.b_r)
            H_tilda = torch.tanh(torch.matmul(X, self.W_xh) +
                               torch.matmul(R * H, self.W_hh) + self.b_h)
            H = Z * H + (1 - Z) * H_tilda
            outputs.append(H)

            if self.per_timestep_readout:
                readouts.append(self.fc(self.relu(H)))

        if not self.per_timestep_readout:
            # final timestep readout layer
            readouts.append(self.fc(self.relu(H)))

        return outputs, readouts

    def single_step(self, X, H):
        matmul_H = lambda A, B: torch.matmul(A, B)

        Z = torch.sigmoid(torch.matmul(X, self.W_xz) + (
            torch.matmul(H, self.W_hz) if H is not None else 0) + self.b_z)
        
        R = torch.sigmoid(torch.matmul(X, self.W_xr) +
                        torch.matmul(H, self.W_hr) + self.b_r)
        
        H_tilda = torch.tanh(torch.matmul(X, self.W_xh) +
                           torch.matmul(R * H, self.W_hh) + self.b_h)
        
        H = Z * H + (1 - Z) * H_tilda

        return H, self.fc(self.relu(H))

In [5]:
class FitzhughNagumo(Dataset):
    def __init__(self, N, T, I=0.5, a=0.7, b=0.8):
        self.I = I
        self.a = a
        self.b = b
        self.N = N
        self.T = T

        data_x = []
        data_y = []
        for i in range(N):
            t = np.linspace(0,400,T+1)
            x0 = np.array([float(np.random.rand(1))*2.-1.,0.])
            sol = integrate.solve_ivp(self.FHN_rhs, [0,400], x0, t_eval=t)
            data_x.append(sol.y[0,:-1])
            data_y.append(sol.y[0,1:])

        self.data_x = np.array(data_x).reshape(N,T,1)
        self.data_y = np.array(data_y).reshape(N,T,1)

    def __len__(self):
        return self.data_x.shape[0]

    def __getitem__(self, idx):
        return torch.Tensor(self.data_x[idx]), torch.Tensor(self.data_y[idx])

    def FHN_rhs(self, t,x):
        I, a, b = self.I, self.a, self.b
        eps = 1./50.
        dim1 = x[0] - (x[0]**3)/3. - x[1] + I
        dim2 = eps*(x[0] + a - b*x[1])
        out = np.stack((dim1,dim2)).T

        return out

    def get_init(self):
        t = np.linspace(0,400,self.T+1)
        x0 = np.array([float(np.random.rand(1))*2.-1.,0.])
        sol = integrate.solve_ivp(self.FHN_rhs, [0,400], x0, t_eval=t)
        init_x = sol.y[0, :50]
        return init_x

In [6]:
def train_model(model, dataset, params, warm_up=50):

    # create the data generator to iterate over mini batches
    trainDataGenerator = torch.utils.data.DataLoader(dataset, **params['train_params'])

    criterion = torch.nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=params['init_lr'])

    for epoch in range(params['num_epochs']):

        for data, label in trainDataGenerator:
            # The inputs need to be of the form T x B x N_in
            # where T is the total "time" duration of the signal, B is the batch size
            # and N_in is the feature dimensionality of an observation
            data = data.transpose(0, 1)

            # forward pass to warm-up
            latent_activities, readout = model(data[:warm_up])

            # now the autoregression begins
            autoreg_outputs = []
            latent = latent_activities[-1]
            X = readout[-1]

            for t in range(warm_up, data.shape[0]):
                latent, X = model.single_step(X, latent)
                autoreg_outputs.append(X)

            autoreg_outputs = torch.stack(autoreg_outputs)

            # compute the loss
            loss = criterion(autoreg_outputs, data[warm_up:]) #.to('cuda:0'))

            # backpropagate through time!
            loss.backward()

            # update model parameters
            optimizer.step()
            optimizer.zero_grad()

        print('Epoch: {} | Training Loss: {}'.format(epoch, loss.item()))

    return model

In [7]:
fhDataset = FitzhughNagumo(N=256, T=512)


In [8]:
params = {
        'n_inputs': 1,
        'n_hidden': 128,
        'num_epochs': 1000,
        'init_lr': 1e-3,
        'n_outputs': 1,

        'train_params': {
                    'batch_size': 128,
                    'shuffle': True,
                    'num_workers': 1
                }
    }

In [11]:
# initialize the model architecture and set it to train mode
model = GRU(params['n_inputs'], params['n_hidden'], params['n_outputs'])

In [14]:
if not os.path.exists(os.path.join('../ckpts', 'autoregressiveGRU.pth')):
    model = model.train()

    # Now let's train the model. 
    # Pass visualize_train=False to suppress any display
    model = train_model(model, fhDataset, params, visualize_train=False)
    torch.save(model.state_dict(), '../ckpts/autoregressiveGRU.pth')

else:
    model.load_state_dict(torch.load(os.path.join('../ckpts', 'autoregressiveGRU.pth')))

In [16]:
def generate(model, init_x, future_T=1000):
    model = model.eval()
    gen_seq = init_x.clone()

    for t in tqdm.tqdm(range(future_T)):
        with torch.no_grad():
            _, output = model(gen_seq)
            gen_seq = torch.cat([gen_seq, output[-1].unsqueeze(0)], dim=0) 

    return gen_seq

In [17]:
model = model.eval()

# This is going to be cool. We can treat RNNs as "generative" models too :)
# Let's "seed" the model with an initial sequence
init_x = fhDataset.get_init()
init_x = torch.Tensor(init_x[:, np.newaxis, np.newaxis])

gen_seq = generate(model, init_x, future_T=250)

100%|██████████| 250/250 [00:05<00:00, 42.81it/s]
