#### Imports

In [1]:
import numpy as np
import matplotlib.pyplot as plt
import os
import sys
sys.path.append("..")

In [2]:
%matplotlib inline
%config InlineBackend.figure_format='retina'

In [3]:
import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision

In [4]:
from tqdm import tqdm

In [5]:
from bptt_tgeb_mnist_architecture import *

#### Test for CUDA

In [6]:
train_on_gpu = torch.cuda.is_available()

if not train_on_gpu:
    print('No GPU, training on CPU')
    dev = torch.device('cpu')
else:
    print('GPU found, training on GPU')
    dev = torch.device('cuda')

No GPU, training on CPU


#### Load MNIST

In [7]:
## Make sure batch_size = 1 for now!!

def load_mnist(batch_size=1, shuffle_train=True):
    transform = torchvision.transforms.Compose(
        [torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.5,), (0.5,))])
    train_set = torchvision.datasets.MNIST("../data", train=True, download=True, transform=transform)
    test_set = torchvision.datasets.MNIST("../data", train=False, download=True, transform=transform)
    train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=shuffle_train)
    test_loader = torch.utils.data.DataLoader(test_set, batch_size=batch_size, shuffle=False)
    return train_loader, test_loader

In [8]:
mnist_train_loader, mnist_test_loader = load_mnist()

#### Architectural initialisations

In [9]:
n_classes = 10

In [10]:
input_dim = 784
hidden_dim = 100
output_dim = 1

In [11]:
## Gating vector
tvec_hh = torch.zeros(n_classes,hidden_dim)
for ii in range(1):
    t_half = torch.randint(0, 2, (1, hidden_dim//2)).float()*2 - 1
    tvec_hh[ii,::2] = t_half
    tvec_hh[ii,1::2] = -t_half

In [12]:
tvec_ih = tvec_hh

#### Architecture

In [13]:
cell = RNNModule(input_dim, hidden_dim, output_dim, tvec_ih, tvec_hh)
rnn = RNN(cell)

In [14]:
# cell.to(dev)
# rnn.to(dev)

#### Loss

In [15]:
def compute_loss(ys, ts):
    return 0.5 * torch.sum((ys - ts)**2)

#### Clip micro gradients

In [16]:
def clip_micro_grads(grad_tensor, minVal=-1e-7, maxVal=1e-7):
    grad_tensor[grad_tensor==torch.clamp(grad_tensor, minVal, maxVal)] = 0
    return grad_tensor

#### Training loop

In [17]:
params = [rnn.cell.Whh]+[rnn.cell.Woh]

In [18]:
optimizer = optim.SGD(params, lr=0.01)

In [19]:
epochs = 5

In [20]:
for e in range(epochs):
    
    running_loss = 0
    
    for image, label in tqdm(mnist_test_loader):
        
        ## Clear older gradients
        optimizer.zero_grad()
        
        ## Change to appropriate shapes!!
        image = torch.squeeze(image).view(1,-1)
        image = image.repeat(n_classes,1)
        label = F.one_hot(label,n_classes).view(-1,1)

        xs = image
        hp = torch.zeros(cell.hid_dim) ## very first hidden state is the zero vector
        ts = label
        
        if train_on_gpu:
            xs, hp, ts = xs.cuda(), hp.cuda(), ts.cuda()

        ## Forward pass
        ys, hs = rnn.forward(xs, hp)
        loss = compute_loss(ys, ts)

        ## Compute gradients w/ Backprop (autograd)
        loss.backward()
        
        ## update weights
        optimizer.step()
        
        ## update loss
        running_loss += loss.item()
    
    print(f"Training loss: {running_loss/len(mnist_train_loader)}")

100%|██████████| 10000/10000 [01:21<00:00, 122.58it/s]


Training loss: 0.19286559781432153


100%|██████████| 10000/10000 [01:21<00:00, 123.04it/s]


Training loss: 0.19074403175314267


100%|██████████| 10000/10000 [01:22<00:00, 121.38it/s]


Training loss: 0.1901945289850235


100%|██████████| 10000/10000 [01:22<00:00, 120.63it/s]


Training loss: 0.18993746676246326


100%|██████████| 10000/10000 [01:24<00:00, 118.10it/s]

Training loss: 0.18977983029882114



