#### Imports

In [1]:
import numpy as np
import matplotlib.pyplot as plt
import os
import sys
sys.path.append("..")

In [2]:
%matplotlib inline
%config InlineBackend.figure_format='retina'

In [3]:
import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision

In [4]:
from tqdm import tqdm

In [5]:
from bptt_tgeb_mnist_architecture import *

#### Test for CUDA

In [6]:
train_on_gpu = torch.cuda.is_available()

if not train_on_gpu:
    print('No GPU, training on CPU')
    dev = torch.device('cpu')
else:
    print('GPU found, training on GPU')
    dev = torch.device('cuda')

No GPU, training on CPU


#### Load MNIST

In [7]:
## Make sure batch_size = 1 for now!!

def load_mnist(batch_size=1, shuffle_train=True):
    transform = torchvision.transforms.Compose(
        [torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.5,), (0.5,))])
    train_set = torchvision.datasets.MNIST("../data", train=True, download=True, transform=transform)
    test_set = torchvision.datasets.MNIST("../data", train=False, download=True, transform=transform)
    train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=shuffle_train)
    test_loader = torch.utils.data.DataLoader(test_set, batch_size=batch_size, shuffle=False)
    return train_loader, test_loader

In [8]:
mnist_train_loader, mnist_test_loader = load_mnist()

#### Architectural Initialisations

In [9]:
n_classes = 10

In [10]:
input_dim = 784
hidden_dim = 100
output_dim = 1

In [11]:
## Gating vector
tvec_hh = torch.zeros(n_classes,hidden_dim)
for ii in range(n_classes):
    t_half = torch.randint(0, 2, (1, hidden_dim//2)).float()*2 - 1
    tvec_hh[ii,::2] = t_half
    tvec_hh[ii,1::2] = -t_half

In [12]:
## Gating vector
tvec_ih = torch.zeros(n_classes,hidden_dim)
for ii in range(n_classes):
    t_half = torch.randint(0, 2, (1, hidden_dim//2)).float()*2 - 1
    tvec_ih[ii,::2] = t_half
    tvec_ih[ii,1::2] = -t_half

In [13]:
# tvec_ih = tvec_hh

#### Architecture

In [14]:
# cell = RNNModule(input_dim, hidden_dim, output_dim, tvec_ih, tvec_hh)
# rnn = RNN(cell)

#### Loss

In [15]:
criterion_ce = nn.CrossEntropyLoss()

#### Training loop

In [16]:
# params = [rnn.cell.Wih]+[rnn.cell.Whh]+[rnn.cell.Woh]

In [17]:
# optimizer = optim.SGD(params, lr=5e-3)

In [18]:
epochs = 10

In [19]:
nRuns = 3

In [20]:
train_losses = np.zeros((nRuns,epochs))
train_acc = np.zeros((nRuns,epochs))

acc_classes = np.zeros(n_classes)

In [21]:
for run in range(nRuns):
    
    cell = RNNModule(input_dim, hidden_dim, output_dim, tvec_ih, tvec_hh)
    rnn = RNN(cell)
    
    params = [rnn.cell.Whh]+[rnn.cell.Woh]
    
    optimizer = optim.SGD(params, lr=5e-3)
    
    for e in range(epochs):
        
        running_loss = 0
        running_acc = 0

        for image, label in tqdm(mnist_test_loader):

            ## Clear older gradients
            optimizer.zero_grad()

            ## Change to appropriate shapes!!
            image = torch.squeeze(image).view(1,-1)
            image = image.repeat(n_classes,1)

            xs = image
            hp = torch.zeros(cell.hid_dim) ## very first hidden state is the zero vector
            ts = torch.LongTensor(label)

            if train_on_gpu:
                xs, hp, ts = xs.cuda(), hp.cuda(), ts.cuda()

            ## Forward pass
            ys, hs = rnn.forward(xs, hp)
            loss = criterion_ce(ys.float().view(1, -1),ts)

            ## Compute gradients w/ Backprop (autograd)
            loss.backward()

            ## update weights
            optimizer.step()

            ## update loss
            running_loss += loss.item()

            ##check if sample is correctly classified
            pred_class = torch.argmax(ys)
            true_class = ts
            if (pred_class-true_class) == 0:
                running_acc +=1
            acc_classes[int(true_class[0])] += 1

        
        train_losses[run,e] = running_loss/len(mnist_test_loader)
        train_acc[run,e] = running_acc/len(mnist_test_loader)
        print(f"Training loss: {running_loss/len(mnist_test_loader)}")
        print(f"Training acc: {running_acc/len(mnist_test_loader)}")

100%|██████████| 10000/10000 [01:43<00:00, 96.48it/s]


Training loss: 2.1470159982562067
Training acc: 0.3198


100%|██████████| 10000/10000 [01:17<00:00, 128.22it/s]


Training loss: 1.8462618626713754
Training acc: 0.5333


100%|██████████| 10000/10000 [01:21<00:00, 123.10it/s]


Training loss: 1.7458268641352654
Training acc: 0.621


100%|██████████| 10000/10000 [01:40<00:00, 99.10it/s] 


Training loss: 1.6905132331252097
Training acc: 0.7363


100%|██████████| 10000/10000 [01:43<00:00, 96.57it/s]


Training loss: 1.6614133506774902
Training acc: 0.7887


100%|██████████| 10000/10000 [01:42<00:00, 97.17it/s] 


Training loss: 1.644657049381733
Training acc: 0.7993


100%|██████████| 10000/10000 [01:46<00:00, 94.25it/s]


Training loss: 1.6291710914373398
Training acc: 0.8215


100%|██████████| 10000/10000 [01:44<00:00, 95.24it/s]


Training loss: 1.6231127940893173
Training acc: 0.8181


100%|██████████| 10000/10000 [01:42<00:00, 97.11it/s]


Training loss: 1.6138138742923736
Training acc: 0.8296


100%|██████████| 10000/10000 [01:50<00:00, 90.34it/s]


Training loss: 1.6087224274396896
Training acc: 0.8213


100%|██████████| 10000/10000 [01:48<00:00, 92.06it/s]


Training loss: 2.15535080589056
Training acc: 0.3739


100%|██████████| 10000/10000 [01:46<00:00, 93.81it/s] 


Training loss: 1.8105514201521873
Training acc: 0.6067


100%|██████████| 10000/10000 [01:42<00:00, 97.88it/s]


Training loss: 1.6284765392780305
Training acc: 0.8121


100%|██████████| 10000/10000 [01:47<00:00, 92.80it/s]


Training loss: 1.6126771145224572
Training acc: 0.8325


100%|██████████| 10000/10000 [01:56<00:00, 85.66it/s]


Training loss: 1.5999619376778602
Training acc: 0.8368


100%|██████████| 10000/10000 [01:46<00:00, 93.96it/s] 


Training loss: 1.632600198531151
Training acc: 0.8179


100%|██████████| 10000/10000 [01:47<00:00, 93.40it/s]


Training loss: 1.6139769778847695
Training acc: 0.8312


100%|██████████| 10000/10000 [01:38<00:00, 101.98it/s]


Training loss: 1.6075553658127786
Training acc: 0.826


100%|██████████| 10000/10000 [01:27<00:00, 114.17it/s]

Training loss: 1.6024482870697976
Training acc: 0.8321





In [23]:
np.save('train-losses-sub-weights-hh-ih-bptt',train_losses)
np.save('train-accs-sub-weights-hh-ih-bptt',train_acc)