#### Imports

In [1]:
import numpy as np
import matplotlib.pyplot as plt
import os
import sys
sys.path.append("..")

In [2]:
%matplotlib inline
%config InlineBackend.figure_format='retina'

In [3]:
import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision

In [4]:
from tqdm import tqdm

In [5]:
from bptt_tgeb_mnist_architecture import *

#### Test for CUDA

In [6]:
train_on_gpu = torch.cuda.is_available()

if not train_on_gpu:
    print('No GPU, training on CPU')
    dev = torch.device('cpu')
else:
    print('GPU found, training on GPU')
    dev = torch.device('cuda')

No GPU, training on CPU


#### Load MNIST

In [7]:
## Make sure batch_size = 1 for now!!

def load_mnist(batch_size=1, shuffle_train=True):
    transform = torchvision.transforms.Compose(
        [torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.5,), (0.5,))])
    train_set = torchvision.datasets.MNIST("../data", train=True, download=True, transform=transform)
    test_set = torchvision.datasets.MNIST("../data", train=False, download=True, transform=transform)
    train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=shuffle_train)
    test_loader = torch.utils.data.DataLoader(test_set, batch_size=batch_size, shuffle=False)
    return train_loader, test_loader

In [8]:
mnist_train_loader, mnist_test_loader = load_mnist()

#### Architectural initialisations

In [9]:
n_classes = 10

In [10]:
input_dim = 784
hidden_dim = 100
output_dim = 1

In [11]:
## Gating vector
tvec_hh = torch.zeros(n_classes,hidden_dim)
for ii in range(n_classes):
    t_half = torch.randint(0, 2, (1, hidden_dim//2)).float()*2 - 1
    tvec_hh[ii,::2] = t_half
    tvec_hh[ii,1::2] = -t_half

In [12]:
# ## Gating vector
# tvec_ih = torch.zeros(n_classes,hidden_dim)
# for ii in range(n_classes):
#     t_half = torch.randint(0, 2, (1, hidden_dim//2)).float()*2 - 1
#     tvec_ih[ii,::2] = t_half
#     tvec_ih[ii,1::2] = -t_half

In [13]:
tvec_ih = tvec_hh

#### Architecture

In [14]:
cell = RNNModule(input_dim, hidden_dim, output_dim, tvec_ih, tvec_hh)
rnn = RNN(cell)

In [15]:
# cell.to(dev)
# rnn.to(dev)

#### Loss

In [16]:
def compute_loss(ys, ts):
    return 0.5 * torch.sum((ys - ts)**2)

In [17]:
criterion_ce = nn.CrossEntropyLoss()

#### Clip micro gradients

In [18]:
def clip_micro_grads(grad_tensor, minVal=-1e-7, maxVal=1e-7):
    grad_tensor[grad_tensor==torch.clamp(grad_tensor, minVal, maxVal)] = 0
    return grad_tensor

#### Training loop

In [19]:
params = [rnn.cell.Wih]+[rnn.cell.Whh]+[rnn.cell.Woh]

In [20]:
optimizer = optim.SGD(params, lr=5e-3)

In [21]:
epochs = 10

In [22]:
train_losses = np.zeros(epochs)
train_acc = np.zeros(epochs)

acc_classes = np.zeros(n_classes)

In [23]:
# a = ys.float().view(1, -1)
# b = torch.LongTensor([6])

# criterion = nn.CrossEntropyLoss()
# ll = criterion(a, b)
# print(ll)

In [24]:
for e in range(epochs):
    
    running_loss = 0
    running_acc = 0
    
    for image, label in tqdm(mnist_test_loader):
        
        ## Clear older gradients
        optimizer.zero_grad()
        
        ## Change to appropriate shapes!!
        image = torch.squeeze(image).view(1,-1)
        image = image.repeat(n_classes,1)

        xs = image
        hp = torch.zeros(cell.hid_dim) ## very first hidden state is the zero vector
        ts = torch.LongTensor(label)
        
        if train_on_gpu:
            xs, hp, ts = xs.cuda(), hp.cuda(), ts.cuda()

        ## Forward pass
        ys, hs = rnn.forward(xs, hp)
#         loss = compute_loss(ys, ts)
        loss = criterion_ce(ys.float().view(1, -1),ts)

        ## Compute gradients w/ Backprop (autograd)
        loss.backward()
        
        ## update weights
        optimizer.step()
        
        ## update loss
        running_loss += loss.item()
        
        ##check if sample is correctly classified
        pred_class = torch.argmax(ys)
        true_class = ts
        if (pred_class-true_class) == 0:
            running_acc +=1
        acc_classes[int(true_class[0])] += 1
    
    train_loss = running_loss
    train_acc = running_acc
    print(f"Training loss: {running_loss/len(mnist_test_loader)}")
    print(f"Training acc: {running_acc/len(mnist_test_loader)}")

100%|██████████| 10000/10000 [00:34<00:00, 293.43it/s]


Training loss: 1.8151166736125945
Training acc: 0.5808


100%|██████████| 10000/10000 [00:35<00:00, 280.72it/s]


Training loss: 1.5971530517935753
Training acc: 0.7994


100%|██████████| 10000/10000 [00:36<00:00, 277.31it/s]


Training loss: 1.5605059388518334
Training acc: 0.8733


100%|██████████| 10000/10000 [00:35<00:00, 280.87it/s]


Training loss: 1.5401383207321167
Training acc: 0.9124


100%|██████████| 10000/10000 [00:36<00:00, 271.40it/s]


Training loss: 1.5309971501350403
Training acc: 0.9196


100%|██████████| 10000/10000 [00:38<00:00, 260.72it/s]


Training loss: 1.5226776613235473
Training acc: 0.9298


100%|██████████| 10000/10000 [01:21<00:00, 122.61it/s]


Training loss: 1.515297588610649
Training acc: 0.9353


100%|██████████| 10000/10000 [00:38<00:00, 261.84it/s]


Training loss: 1.513420859515667
Training acc: 0.9364


100%|██████████| 10000/10000 [00:39<00:00, 252.50it/s]


Training loss: 1.5090141604542733
Training acc: 0.9413


100%|██████████| 10000/10000 [00:45<00:00, 220.99it/s]

Training loss: 1.5068807791233063
Training acc: 0.9464



