#### Imports

In [1]:
import numpy as np
import matplotlib.pyplot as plt
import os
import sys
sys.path.append("..")

In [2]:
import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision

In [3]:
from tqdm import tqdm

In [4]:
from bptt_tgeb_mnist_architecture import *

#### Test for CUDA

In [5]:
train_on_gpu = torch.cuda.is_available()

if not train_on_gpu:
    print('No GPU, training on CPU')
    dev = torch.device('cpu')
else:
    print('GPU found, training on GPU')
    dev = torch.device('cuda')

No GPU, training on CPU


#### Load MNIST

In [6]:
## Make sure batch_size = 1 for now!!

def load_mnist(batch_size=1, shuffle_train=True):
    transform = torchvision.transforms.Compose(
        [torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.5,), (0.5,))])
    train_set = torchvision.datasets.MNIST("../data", train=True, download=True, transform=transform)
    test_set = torchvision.datasets.MNIST("../data", train=False, download=True, transform=transform)
    train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=shuffle_train)
    test_loader = torch.utils.data.DataLoader(test_set, batch_size=batch_size, shuffle=False)
    return train_loader, test_loader

In [7]:
mnist_train_loader, mnist_test_loader = load_mnist()

#### Architectural initialisations

In [8]:
n_classes = 10

In [9]:
input_dim = 784
hidden_dim = 100
output_dim = 1

In [10]:
## Gating vector
tvec_hh = torch.zeros(n_classes,hidden_dim)
for ii in range(n_classes):
    t_half = torch.randint(0, 2, (1, hidden_dim//2)).float()*2 - 1
    tvec_hh[ii,::2] = t_half
    tvec_hh[ii,1::2] = -t_half

In [11]:
tvec_ih = tvec_hh

#### Architecture

In [12]:
cell = RNNModule(input_dim, hidden_dim, output_dim, tvec_ih, tvec_hh)
rnn = RNN(cell)

#### Loss

In [13]:
def compute_loss(ys, ts):
    return 0.5 * torch.sum((ys - ts)**2)

#### Training-esque loop

In [14]:
loader_choice = 'test'

In [15]:
if loader_choice == 'test':
    loader = mnist_test_loader
    len_dataset = 10000
elif loader_choice == 'train':
    loader = mnist_train_loader
    len_dataset = 60000

In [16]:
Wih_bp = np.zeros((len_dataset,n_classes,hidden_dim,input_dim),dtype='float32')
Whh_bp = np.zeros((len_dataset,n_classes,hidden_dim,hidden_dim),dtype='float32')
Woh_bp = np.zeros((len_dataset,n_classes,output_dim,hidden_dim),dtype='float32')

Wih_geb = np.zeros((len_dataset,n_classes,hidden_dim,input_dim),dtype='float32')
Whh_geb = np.zeros((len_dataset,n_classes,hidden_dim,hidden_dim),dtype='float32')
Woh_geb = np.zeros((len_dataset,n_classes,output_dim,hidden_dim),dtype='float32')

MemoryError: Unable to allocate 58.4 GiB for an array with shape (10000, 10, 100, 784) and data type float64

In [17]:
for ii, (image,label) in tqdm(enumerate(loader)):
    
    ## Change to appropriate shapes!!
    image = torch.squeeze(image).view(1,-1)
    image = image.repeat(n_classes,1)
    label = F.one_hot(label,n_classes).view(-1,1)

    xs = image
    hp = torch.zeros(cell.hid_dim)
    ts = label
    
    ## Forward pass
    ys, hs = rnn.forward(xs, hp)
    loss = compute_loss(ys, ts)
    
    ## Compute gradients w/ Backprop (autograd)
    loss.backward()
    
    ## Manual gradients
    ## Valid only for MSE!!
    ys_grad = ys - ts

    with torch.no_grad():  # required so PyTorch won't raise error
        rnn.cell.zero_grad()
        rnn.backward(ys_grad, xs, hs)
        
    for kk in range(n_classes):
        Wih_bp[ii,kk] = rnn.cell.Wih_grad_all[kk]
        Whh_bp[ii,kk] = rnn.cell.Whh_grad_all[kk]
        Woh_bp[ii,kk] = rnn.cell.Woh_grad_all[kk]
        
        Wih_geb[ii,kk] = rnn.cell.Wih_grad_geb_all[kk]
        Whh_geb[ii,kk] = rnn.cell.Whh_grad_geb_all[kk]
        Woh_geb[ii,kk] = rnn.cell.Woh_grad_geb_all[kk]

0it [00:00, ?it/s]


ValueError: only one element tensors can be converted to Python scalars

In [16]:
# def clip_mini_grads(grad_dict, nSteps=10):
    

#### PyTorch gradients

In [19]:
## gradient of cross entropy w.r.t. yhat_k = (1/yhat_k)*y_k
## gradient of mse w.r.t. y_hat = y_hat - y

In [21]:
# rnn.cell.Wih.grad

In [22]:
# rnn.cell.Whh.grad

In [23]:
# rnn.cell.Woh.grad

#### Manual gradients

In [24]:
# This is obtained from our loss function
# ys_grad = torch.matmul(label_batch,ys)
ys_grad = ys - ts

with torch.no_grad():  # required so PyTorch won't raise error
    rnn.cell.zero_grad()
    rnn.backward(ys_grad, xs, hs)

In [36]:
print('Summed total differences for Whh: ',torch.abs(rnn.cell.Whh_grad - rnn.cell.Whh_grad_geb))

Summed total differences for Whh:  tensor([[2.9098, 6.6073, 3.2535,  ..., 7.0618, 6.2041, 4.5104],
        [2.8203, 4.2550, 0.2491,  ..., 0.2389, 0.2248, 7.5119],
        [1.0735, 6.3153, 1.1986,  ..., 6.6821, 6.5828, 1.6456],
        ...,
        [0.6296, 4.9521, 0.2488,  ..., 0.8699, 0.4176, 4.5638],
        [0.1245, 9.8743, 0.2491,  ..., 5.4157, 5.5624, 4.5645],
        [6.6030, 0.6977, 3.0078,  ..., 1.7924, 0.7971, 6.9081]])


In [37]:
print('Summed total differences for Wih: ',torch.abs(rnn.cell.Wih_grad - rnn.cell.Wih_grad_geb))

Summed total differences for Wih:  tensor([[3.8722, 3.8722, 3.8722,  ..., 3.8722, 3.8722, 3.8722],
        [3.2990, 3.2990, 3.2990,  ..., 3.2990, 3.2990, 3.2990],
        [3.7147, 3.7147, 3.7147,  ..., 3.7147, 3.7147, 3.7147],
        ...,
        [4.3902, 4.3902, 4.3902,  ..., 4.3902, 4.3902, 4.3902],
        [4.3657, 4.3657, 4.3657,  ..., 4.3657, 4.3657, 4.3657],
        [4.3580, 4.3580, 4.3580,  ..., 4.3580, 4.3580, 4.3580]])


In [38]:
# for tt in range(n_classes):
#     print(tt)
    
#     rnn.cell.Wih_grad_all[tt][rnn.cell.Wih_grad_all[tt]==torch.clamp(rnn.cell.Wih_grad_all[tt], -1e-6 , 1e-6)]=0
#     rnn.cell.Wih_grad_all[tt][rnn.cell.Wih_grad_geb_all[tt]==torch.clamp(rnn.cell.Wih_grad_geb_all[tt], -1e-6 , 1e-6)]=0
    
#     rnn.cell.Whh_grad_all[tt][rnn.cell.Whh_grad_all[tt]==torch.clamp(rnn.cell.Whh_grad_all[tt], -1e-6 , 1e-6)]=0
#     rnn.cell.Whh_grad_all[tt][rnn.cell.Whh_grad_geb_all[tt]==torch.clamp(rnn.cell.Whh_grad_geb_all[tt], -1e-6 , 1e-6)]=0
    
#     rnn.cell.Woh_grad_all[tt][rnn.cell.Woh_grad_all[tt]==torch.clamp(rnn.cell.Woh_grad_all[tt], -1e-6 , 1e-6)]=0
#     rnn.cell.Woh_grad_all[tt][rnn.cell.Woh_grad_geb_all[tt]==torch.clamp(rnn.cell.Woh_grad_geb_all[tt], -1e-6 , 1e-6)]=0
    
#     print('Timestep signed differences for Wih:',torch.unique(torch.sign(rnn.cell.Wih_grad_all[tt]) - torch.sign(rnn.cell.Wih_grad_geb_all[tt])))
#     print('Timestep signed differences for Whh:',torch.unique(torch.sign(rnn.cell.Whh_grad_all[tt]) - torch.sign(rnn.cell.Whh_grad_geb_all[tt])))
#     print('Timestep signed differences for Woh:',torch.unique(torch.sign(rnn.cell.Woh_grad_all[tt]) - torch.sign(rnn.cell.Woh_grad_geb_all[tt])))
#     print('-----------------------')

In [39]:
for tt in range(n_classes):
    
    print(tt)
    
    totWih = np.prod(rnn.cell.Wih_grad_all[tt].shape)
    totWhh = np.prod(rnn.cell.Whh_grad_all[tt].shape)
    totWoh = np.prod(rnn.cell.Woh_grad_all[tt].shape)
    
#     rnn.cell.Wih_grad_all[tt][rnn.cell.Wih_grad_all[tt]==torch.clamp(rnn.cell.Wih_grad_all[tt], -1e-6 , 1e-6)]=0
#     rnn.cell.Wih_grad_all[tt][rnn.cell.Wih_grad_geb_all[tt]==torch.clamp(rnn.cell.Wih_grad_geb_all[tt], -1e-6 , 1e-6)]=0
    
#     rnn.cell.Whh_grad_all[tt][rnn.cell.Whh_grad_all[tt]==torch.clamp(rnn.cell.Whh_grad_all[tt], -1e-6 , 1e-6)]=0
#     rnn.cell.Whh_grad_all[tt][rnn.cell.Whh_grad_geb_all[tt]==torch.clamp(rnn.cell.Whh_grad_geb_all[tt], -1e-6 , 1e-6)]=0
    
#     rnn.cell.Woh_grad_all[tt][rnn.cell.Woh_grad_all[tt]==torch.clamp(rnn.cell.Woh_grad_all[tt], -1e-6 , 1e-6)]=0
#     rnn.cell.Woh_grad_all[tt][rnn.cell.Woh_grad_geb_all[tt]==torch.clamp(rnn.cell.Woh_grad_geb_all[tt], -1e-6 , 1e-6)]=0
    
    num0_Wih = len(np.where((torch.sign(rnn.cell.Wih_grad_all[tt]) - torch.sign(rnn.cell.Wih_grad_geb_all[tt]))==0)[0])
    num1_Wih = len(np.where((torch.sign(rnn.cell.Wih_grad_all[tt]) - torch.sign(rnn.cell.Wih_grad_geb_all[tt]))==1)[0])
    num2_Wih = len(np.where((torch.sign(rnn.cell.Wih_grad_all[tt]) - torch.sign(rnn.cell.Wih_grad_geb_all[tt]))==2)[0])
    numm1_Wih = len(np.where((torch.sign(rnn.cell.Wih_grad_all[tt]) - torch.sign(rnn.cell.Wih_grad_geb_all[tt]))==-1)[0])
    numm2_Wih = len(np.where((torch.sign(rnn.cell.Wih_grad_all[tt]) - torch.sign(rnn.cell.Wih_grad_geb_all[tt]))==-2)[0])
    
    num0_Whh = len(np.where((torch.sign(rnn.cell.Whh_grad_all[tt]) - torch.sign(rnn.cell.Whh_grad_geb_all[tt]))==0)[0])
    num1_Whh = len(np.where((torch.sign(rnn.cell.Whh_grad_all[tt]) - torch.sign(rnn.cell.Whh_grad_geb_all[tt]))==1)[0])
    num2_Whh = len(np.where((torch.sign(rnn.cell.Whh_grad_all[tt]) - torch.sign(rnn.cell.Whh_grad_geb_all[tt]))==2)[0])
    numm1_Whh = len(np.where((torch.sign(rnn.cell.Whh_grad_all[tt]) - torch.sign(rnn.cell.Whh_grad_geb_all[tt]))==-1)[0])
    numm2_Whh = len(np.where((torch.sign(rnn.cell.Whh_grad_all[tt]) - torch.sign(rnn.cell.Whh_grad_geb_all[tt]))==-2)[0])
    
    num0_Woh = len(np.where((torch.sign(rnn.cell.Woh_grad_all[tt]) - torch.sign(rnn.cell.Woh_grad_geb_all[tt]))==0)[0])
    num1_Woh = len(np.where((torch.sign(rnn.cell.Woh_grad_all[tt]) - torch.sign(rnn.cell.Woh_grad_geb_all[tt]))==1)[0])
    num2_Woh = len(np.where((torch.sign(rnn.cell.Woh_grad_all[tt]) - torch.sign(rnn.cell.Woh_grad_geb_all[tt]))==2)[0])
    numm1_Woh = len(np.where((torch.sign(rnn.cell.Woh_grad_all[tt]) - torch.sign(rnn.cell.Woh_grad_geb_all[tt]))==-1)[0])
    numm2_Woh = len(np.where((torch.sign(rnn.cell.Woh_grad_all[tt]) - torch.sign(rnn.cell.Woh_grad_geb_all[tt]))==-2)[0])
    
    print('Frac. 0 signed differences for Wih:',num0_Wih/totWih)
#     print('Frac. 1 signed differences for Wih:',num1_Wih/totWih)
#     print('Frac. 2 signed differences for Wih:',num2_Wih/totWih)
#     print('Frac. -1 signed differences for Wih:',numm1_Wih/totWih)
#     print('Frac. -2 signed differences for Wih:',numm2_Wih/totWih)
    print('-----------------------')
    print('Frac. 0 signed differences for Whh:',num0_Whh/totWhh)
#     print('Frac. 1 signed differences for Whh:',num1_Whh/totWhh)
#     print('Frac. 2 signed differences for Whh:',num2_Whh/totWhh)
#     print('Frac. -1 signed differences for Whh:',numm1_Whh/totWhh)
#     print('Frac. -2 signed differences for Whh:',numm2_Whh/totWhh)
    print('-----------------------')
    print('Frac. 0 signed differences for Woh:',num0_Woh/totWoh)
#     print('Frac. 1 signed differences for Woh:',num1_Woh/totWoh)
#     print('Frac. 2 signed differences for Woh:',num2_Woh/totWoh)
#     print('Frac. -1 signed differences for Woh:',numm1_Woh/totWoh)
#     print('Frac. -2 signed differences for Woh:',numm2_Woh/totWoh)
    print('=======================')

0
Frac. 0 signed differences for Wih: 1.0
-----------------------
Frac. 0 signed differences for Whh: 1.0
-----------------------
Frac. 0 signed differences for Woh: 1.0
1
Frac. 0 signed differences for Wih: 0.32
-----------------------
Frac. 0 signed differences for Whh: 0.7316
-----------------------
Frac. 0 signed differences for Woh: 1.0
2
Frac. 0 signed differences for Wih: 0.66
-----------------------
Frac. 0 signed differences for Whh: 1.0
-----------------------
Frac. 0 signed differences for Woh: 1.0
3
Frac. 0 signed differences for Wih: 0.62
-----------------------
Frac. 0 signed differences for Whh: 1.0
-----------------------
Frac. 0 signed differences for Woh: 1.0
4
Frac. 0 signed differences for Wih: 0.56
-----------------------
Frac. 0 signed differences for Whh: 1.0
-----------------------
Frac. 0 signed differences for Woh: 1.0
5
Frac. 0 signed differences for Wih: 0.52
-----------------------
Frac. 0 signed differences for Whh: 1.0
-----------------------
Frac. 0 sign