#### Imports

In [1]:
import numpy as np
import matplotlib.pyplot as plt
import os
import sys
sys.path.append("..")

In [2]:
%matplotlib inline
%config InlineBackend.figure_format='retina'

In [3]:
import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision

In [4]:
from tqdm import tqdm

In [5]:
from bptt_tgeb_mnist_architecture import *

#### Test for CUDA

In [6]:
train_on_gpu = torch.cuda.is_available()

if not train_on_gpu:
    print('No GPU, training on CPU')
    dev = torch.device('cpu')
else:
    print('GPU found, training on GPU')
    dev = torch.device('cuda')

No GPU, training on CPU


#### Load MNIST

In [7]:
## Make sure batch_size = 1 for now!!

def load_mnist(batch_size=1, shuffle_train=True):
    transform = torchvision.transforms.Compose(
        [torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.5,), (0.5,))])
    train_set = torchvision.datasets.MNIST("../data", train=True, download=True, transform=transform)
    test_set = torchvision.datasets.MNIST("../data", train=False, download=True, transform=transform)
    train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=shuffle_train)
    test_loader = torch.utils.data.DataLoader(test_set, batch_size=batch_size, shuffle=False)
    return train_loader, test_loader

In [8]:
mnist_train_loader, mnist_test_loader = load_mnist()

#### Architectural Initialisations

In [9]:
n_classes = 10

In [10]:
input_dim = 784
hidden_dim = 100
output_dim = 1

In [11]:
## Gating vector
tvec_hh = torch.zeros(n_classes,hidden_dim)
for ii in range(n_classes):
    t_half = torch.randint(0, 2, (1, hidden_dim//2)).float()*2 - 1
    tvec_hh[ii,::2] = t_half
    tvec_hh[ii,1::2] = -t_half

In [12]:
## Gating vector
tvec_ih = torch.zeros(n_classes,hidden_dim)
for ii in range(n_classes):
    t_half = torch.randint(0, 2, (1, hidden_dim//2)).float()*2 - 1
    tvec_ih[ii,::2] = t_half
    tvec_ih[ii,1::2] = -t_half

In [13]:
# tvec_ih = tvec_hh

#### Architecture

In [14]:
# cell = RNNModule(input_dim, hidden_dim, output_dim, tvec_ih, tvec_hh)
# rnn = RNN(cell)

#### Loss

In [15]:
criterion_ce = nn.CrossEntropyLoss()

#### Training loop

In [16]:
# params = [rnn.cell.Wih]+[rnn.cell.Whh]+[rnn.cell.Woh]

In [17]:
# optimizer = optim.SGD(params, lr=5e-3)

In [18]:
epochs = 10

In [19]:
nRuns = 3

In [20]:
train_losses = np.zeros((nRuns,epochs))
train_acc = np.zeros((nRuns,epochs))

acc_classes = np.zeros(n_classes)

In [21]:
for run in range(nRuns):
    
    cell = RNNModule(input_dim, hidden_dim, output_dim, tvec_ih, tvec_hh)
    rnn = RNN(cell)
    
    params = [rnn.cell.Wih]+[rnn.cell.Whh]+[rnn.cell.Woh]
    
    optimizer = optim.SGD(params, lr=5e-3)
    
    for e in range(epochs):
        
        running_loss = 0
        running_acc = 0

        for image, label in tqdm(mnist_test_loader):

            ## Clear older gradients
            optimizer.zero_grad()

            ## Change to appropriate shapes!!
            image = torch.squeeze(image).view(1,-1)
            image = image.repeat(n_classes,1)

            xs = image
            hp = torch.zeros(cell.hid_dim) ## very first hidden state is the zero vector
            ts = torch.LongTensor(label)

            if train_on_gpu:
                xs, hp, ts = xs.cuda(), hp.cuda(), ts.cuda()

            ## Forward pass
            ys, hs = rnn.forward(xs, hp)
            loss = criterion_ce(ys.float().view(1, -1),ts)

            ## Compute gradients w/ Backprop (autograd)
            loss.backward()

            ## update weights
            optimizer.step()

            ## update loss
            running_loss += loss.item()

            ##check if sample is correctly classified
            pred_class = torch.argmax(ys)
            true_class = ts
            if (pred_class-true_class) == 0:
                running_acc +=1
            acc_classes[int(true_class[0])] += 1

        
        train_losses[run,e] = running_loss/len(mnist_test_loader)
        train_acc[run,e] = running_acc/len(mnist_test_loader)
        print(f"Training loss: {running_loss/len(mnist_test_loader)}")
        print(f"Training acc: {running_acc/len(mnist_test_loader)}")

100%|██████████| 10000/10000 [01:20<00:00, 123.70it/s]


Training loss: 1.842613744199276
Training acc: 0.5458


100%|██████████| 10000/10000 [01:34<00:00, 106.08it/s]


Training loss: 1.5964221465468407
Training acc: 0.8012


100%|██████████| 10000/10000 [02:16<00:00, 73.06it/s]


Training loss: 1.563451170384884
Training acc: 0.8375


100%|██████████| 10000/10000 [01:18<00:00, 128.13it/s]


Training loss: 1.547852477478981
Training acc: 0.8797


100%|██████████| 10000/10000 [01:22<00:00, 120.62it/s]


Training loss: 1.5361614840269089
Training acc: 0.9064


100%|██████████| 10000/10000 [01:31<00:00, 108.86it/s]


Training loss: 1.526982453751564
Training acc: 0.9212


100%|██████████| 10000/10000 [01:33<00:00, 107.52it/s]


Training loss: 1.5212139238238334
Training acc: 0.9232


100%|██████████| 10000/10000 [01:43<00:00, 96.99it/s]


Training loss: 1.5170430536031723
Training acc: 0.9339


100%|██████████| 10000/10000 [01:59<00:00, 83.96it/s]


Training loss: 1.514110361123085
Training acc: 0.9356


100%|██████████| 10000/10000 [02:01<00:00, 82.06it/s]


Training loss: 1.5091771039128303
Training acc: 0.9406


100%|██████████| 10000/10000 [01:45<00:00, 94.50it/s]


Training loss: 1.8159673189163208
Training acc: 0.6431


100%|██████████| 10000/10000 [01:55<00:00, 86.91it/s]


Training loss: 1.5880081775784491
Training acc: 0.854


100%|██████████| 10000/10000 [01:49<00:00, 91.18it/s]


Training loss: 1.539299065876007
Training acc: 0.9121


100%|██████████| 10000/10000 [02:04<00:00, 80.63it/s]


Training loss: 1.524467913365364
Training acc: 0.9273


100%|██████████| 10000/10000 [01:49<00:00, 91.73it/s]


Training loss: 1.51003835170269
Training acc: 0.9398


100%|██████████| 10000/10000 [01:46<00:00, 93.96it/s]


Training loss: 1.593526305794716
Training acc: 0.8291


100%|██████████| 10000/10000 [02:15<00:00, 73.82it/s]


Training loss: 1.5348826131105422
Training acc: 0.9164


100%|██████████| 10000/10000 [01:49<00:00, 91.55it/s]


Training loss: 1.5150229097723962
Training acc: 0.9362


100%|██████████| 10000/10000 [02:10<00:00, 76.73it/s]

Training loss: 1.5098208153128625
Training acc: 0.9394





In [23]:
np.save('train-losses-all-weights-hh-ih-bptt',train_losses)
np.save('train-accs-all-weights-hh-ih-bptt',train_acc)

In [24]:
train_acc

array([[0.5458, 0.8012, 0.8375, 0.8797, 0.9064, 0.9212, 0.9232, 0.9339,
        0.9356, 0.9406],
       [0.6431, 0.854 , 0.9   , 0.9121, 0.9208, 0.9273, 0.9299, 0.9381,
        0.9398, 0.9422],
       [0.5936, 0.8291, 0.8717, 0.9059, 0.9164, 0.9262, 0.9267, 0.9362,
        0.9383, 0.9394]])