#### Imports

In [1]:
import numpy as np
import matplotlib.pyplot as plt
import os
import sys
sys.path.append("..")

In [2]:
%matplotlib inline
%config InlineBackend.figure_format='retina'

In [3]:
import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision

In [4]:
from tqdm import tqdm

In [5]:
from bptt_tgeb_mnist_architecture import *

#### Test for CUDA

In [6]:
train_on_gpu = torch.cuda.is_available()

if not train_on_gpu:
    print('No GPU, training on CPU')
    dev = torch.device('cpu')
else:
    print('GPU found, training on GPU')
    dev = torch.device('cuda')

No GPU, training on CPU


#### Load MNIST

In [7]:
## Make sure batch_size = 1 for now!!

def load_mnist(batch_size=1, shuffle_train=True):
    transform = torchvision.transforms.Compose(
        [torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.5,), (0.5,))])
    train_set = torchvision.datasets.MNIST("../data", train=True, download=True, transform=transform)
    test_set = torchvision.datasets.MNIST("../data", train=False, download=True, transform=transform)
    train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=shuffle_train)
    test_loader = torch.utils.data.DataLoader(test_set, batch_size=batch_size, shuffle=False)
    return train_loader, test_loader

In [8]:
mnist_train_loader, mnist_test_loader = load_mnist()

#### Architectural Initialisations

In [9]:
n_classes = 10

In [10]:
input_dim = 784
hidden_dim = 100
output_dim = 1

In [11]:
## Gating vector
tvec_hh = torch.zeros(n_classes,hidden_dim)
for ii in range(n_classes):
    t_half = torch.randint(0, 2, (1, hidden_dim//2)).float()*2 - 1
    tvec_hh[ii,::2] = t_half
    tvec_hh[ii,1::2] = -t_half

In [12]:
# ## Gating vector
# tvec_ih = torch.zeros(n_classes,hidden_dim)
# for ii in range(n_classes):
#     t_half = torch.randint(0, 2, (1, hidden_dim//2)).float()*2 - 1
#     tvec_ih[ii,::2] = t_half
#     tvec_ih[ii,1::2] = -t_half

In [13]:
tvec_ih = tvec_hh

#### Architecture

In [14]:
# cell = RNNModule(input_dim, hidden_dim, output_dim, tvec_ih, tvec_hh)
# rnn = RNN(cell)

#### Loss

In [15]:
criterion_ce = nn.CrossEntropyLoss()

#### Training loop

In [16]:
# params = [rnn.cell.Wih]+[rnn.cell.Whh]+[rnn.cell.Woh]

In [17]:
# optimizer = optim.SGD(params, lr=5e-3)

In [18]:
epochs = 10

In [19]:
nRuns = 3

In [20]:
train_losses = np.zeros((nRuns,epochs))
train_acc = np.zeros((nRuns,epochs))

acc_classes = np.zeros(n_classes)

In [21]:
for run in range(nRuns):
    
    cell = RNNModule(input_dim, hidden_dim, output_dim, tvec_ih, tvec_hh)
    rnn = RNN(cell)
    
    params = [rnn.cell.Wih]+[rnn.cell.Whh]+[rnn.cell.Woh]
    
    optimizer = optim.SGD(params, lr=5e-3)
    
    for e in range(epochs):
        
        running_loss = 0
        running_acc = 0

        for image, label in tqdm(mnist_test_loader):

            ## Clear older gradients
            optimizer.zero_grad()

            ## Change to appropriate shapes!!
            image = torch.squeeze(image).view(1,-1)
            image = image.repeat(n_classes,1)

            xs = image
            hp = torch.zeros(cell.hid_dim) ## very first hidden state is the zero vector
            ts = torch.LongTensor(label)

            if train_on_gpu:
                xs, hp, ts = xs.cuda(), hp.cuda(), ts.cuda()

            ## Forward pass
            ys, hs = rnn.forward(xs, hp)
            loss = criterion_ce(ys.float().view(1, -1),ts)

            ## Compute gradients w/ Backprop (autograd)
            loss.backward()

            ## update weights
            optimizer.step()

            ## update loss
            running_loss += loss.item()

            ##check if sample is correctly classified
            pred_class = torch.argmax(ys)
            true_class = ts
            if (pred_class-true_class) == 0:
                running_acc +=1
            acc_classes[int(true_class[0])] += 1

        
        train_losses[run,e] = running_loss/len(mnist_test_loader)
        train_acc[run,e] = running_acc/len(mnist_test_loader)
        print(f"Training loss: {running_loss/len(mnist_test_loader)}")
        print(f"Training acc: {running_acc/len(mnist_test_loader)}")

100%|██████████| 10000/10000 [02:22<00:00, 69.97it/s]


Training loss: 1.8388246549844742
Training acc: 0.5478


100%|██████████| 10000/10000 [01:17<00:00, 128.35it/s]


Training loss: 1.6012689563751221
Training acc: 0.8154


100%|██████████| 10000/10000 [01:22<00:00, 121.01it/s]


Training loss: 1.569059454035759
Training acc: 0.8524


100%|██████████| 10000/10000 [01:25<00:00, 116.64it/s]


Training loss: 1.5493908738851547
Training acc: 0.8912


100%|██████████| 10000/10000 [01:42<00:00, 97.32it/s]


Training loss: 1.5408206284999848
Training acc: 0.9001


100%|██████████| 10000/10000 [01:45<00:00, 95.21it/s]


Training loss: 1.5305798378825188
Training acc: 0.9132


100%|██████████| 10000/10000 [01:52<00:00, 88.53it/s]


Training loss: 1.524454635810852
Training acc: 0.9221


100%|██████████| 10000/10000 [02:03<00:00, 81.08it/s]


Training loss: 1.5170689664959907
Training acc: 0.9349


100%|██████████| 10000/10000 [01:44<00:00, 95.63it/s]


Training loss: 1.516738697564602
Training acc: 0.9296


100%|██████████| 10000/10000 [01:53<00:00, 87.91it/s]


Training loss: 1.5102963964223861
Training acc: 0.9401


100%|██████████| 10000/10000 [01:48<00:00, 92.23it/s]


Training loss: 1.9041845077514647
Training acc: 0.4551


100%|██████████| 10000/10000 [01:47<00:00, 93.38it/s]


Training loss: 1.6880655025243758
Training acc: 0.6965


100%|██████████| 10000/10000 [01:49<00:00, 91.01it/s]


Training loss: 1.6434219279289246
Training acc: 0.7957


100%|██████████| 10000/10000 [01:59<00:00, 83.97it/s]


Training loss: 1.6230251613497735
Training acc: 0.8214


100%|██████████| 10000/10000 [02:05<00:00, 79.52it/s]


Training loss: 1.6103875289916991
Training acc: 0.8311


100%|██████████| 10000/10000 [01:43<00:00, 96.71it/s]


Training loss: 1.6026254326939582
Training acc: 0.8375


100%|██████████| 10000/10000 [01:50<00:00, 90.91it/s]


Training loss: 1.5956327906727792
Training acc: 0.8347


100%|██████████| 10000/10000 [01:50<00:00, 90.18it/s]


Training loss: 1.5304655809283256
Training acc: 0.9172


100%|██████████| 10000/10000 [01:48<00:00, 91.95it/s]


Training loss: 1.5192876732826233
Training acc: 0.9282


100%|██████████| 10000/10000 [01:42<00:00, 97.40it/s]


Training loss: 1.5144367001652717
Training acc: 0.934


100%|██████████| 10000/10000 [01:12<00:00, 138.29it/s]

Training loss: 1.5093630755066871
Training acc: 0.9404





In [23]:
np.save('train-losses-all-weights-same-hh-ih-bptt',train_losses)
np.save('train-accs-all-weights-same-hh-ih-bptt',train_acc)

In [24]:
train_acc

array([[0.5478, 0.8154, 0.8524, 0.8912, 0.9001, 0.9132, 0.9221, 0.9349,
        0.9296, 0.9401],
       [0.4551, 0.6965, 0.7957, 0.8214, 0.8311, 0.8375, 0.841 , 0.8452,
        0.8516, 0.8507],
       [0.5682, 0.8347, 0.8767, 0.909 , 0.9126, 0.9172, 0.9275, 0.9282,
        0.934 , 0.9404]])