#### Imports

In [1]:
import numpy as np
import matplotlib.pyplot as plt
import os
import sys
sys.path.append("..")

In [2]:
import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision

#### Test for CUDA

In [3]:
# Choose the compute device: CUDA when a GPU is visible, otherwise the CPU.
# NOTE(review): no cell below moves tensors or parameters to `dev`, so the
# computation appears to run on the CPU either way — confirm before relying on GPU.
train_on_gpu = torch.cuda.is_available()
dev = torch.device('cuda') if train_on_gpu else torch.device('cpu')
print('GPU found, training on GPU' if train_on_gpu else 'No GPU, training on CPU')

No GPU, training on CPU


#### Load MNIST

In [4]:
## NOTE: downstream cells assume batch_size == 1 — keep the default for now!!

def load_mnist(batch_size=1, shuffle_train=True, root="../data"):
    """Build MNIST train/test DataLoaders.

    Images are converted with ``ToTensor`` and then normalized with mean 0.5
    and std 0.5 (a single channel). The dataset is downloaded into ``root``
    if it is not already present there.

    Args:
        batch_size (int): Batch size used for both loaders. The cells that
            consume these loaders flatten a single image, so keep it at 1.
        shuffle_train (bool): Whether to shuffle the training loader. The
            test loader is never shuffled.
        root (str): Directory where MNIST is stored / downloaded. Defaults to
            ``"../data"``, the previously hard-coded location.

    Returns:
        tuple: ``(train_loader, test_loader)``.
    """
    transform = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.5,), (0.5,)),
    ])
    train_set = torchvision.datasets.MNIST(root, train=True, download=True, transform=transform)
    test_set = torchvision.datasets.MNIST(root, train=False, download=True, transform=transform)
    train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=shuffle_train)
    test_loader = torch.utils.data.DataLoader(test_set, batch_size=batch_size, shuffle=False)
    return train_loader, test_loader

In [5]:
# Build the train/test loaders with the defaults (batch_size=1, shuffled training set).
mnist_train_loader, mnist_test_loader = load_mnist()

In [6]:
# Take the first batch from the training loader (one image/label pair with batch_size=1).
# NOTE(review): the training loader is shuffled and no random seed is set in this
# notebook, so which image is drawn changes from run to run.
image_batch, label_batch = next(iter(mnist_train_loader))

In [7]:
# Number of MNIST classes. It is also used below as the sequence length:
# the image is repeated once per class, one timestep each.
n_classes = 10

In [8]:
## Turn the single example into a sequence of n_classes timesteps (assumes batch_size == 1):
##   image_batch -> flattened image repeated once per timestep, shape (n_classes, 784)
##   label_batch -> one-hot label as a column, one 0/1 target per timestep, shape (n_classes, 1)
## reshape(1, -1) flattens exactly like squeeze-then-reshape, because squeeze only drops size-1 dims.
## NOTE(review): this cell overwrites its own inputs, so re-running it without first
## re-running the cell that draws the batch produces wrong shapes.
image_batch = image_batch.reshape(1, -1).repeat(n_classes, 1)
label_batch = F.one_hot(label_batch, n_classes).reshape(-1, 1)

#### Architecture

In [9]:
class RNNModule:
    """An RNN cell responsible for a single timestep, with a manual backward pass.

    Forward computation at timestep index ``kk`` (see ``_get_internals``)::

        r = Wih @ x + Whh @ hp
        h = r * [tvec[kk] * r > 0]   # keep units whose sign matches tvec[kk]
        s = Woh @ h
        y = sigmoid(s)

    ``backward`` computes the parameter gradients by hand. It accumulates them
    into ``Wih_grad``, ``Whh_grad`` and ``Woh_grad``, so they can be compared
    with the autograd results in ``Wih.grad``, ``Whh.grad`` and ``Woh.grad``.

    Args:
        inp_dim (int): Input size.
        hid_dim (int): Hidden size.
        out_dim (int): Output size.
        n_classes (int): Number of rows of the gating matrix ``tvec``. The
            ``kk`` argument of ``forward``/``backward`` indexes these rows,
            so it must be smaller than ``n_classes``. Defaults to 10.
    """

    def __init__(self, inp_dim, hid_dim, out_dim, n_classes=10):
        self.inp_dim = inp_dim
        self.hid_dim = hid_dim
        self.out_dim = out_dim
        self.n_classes = n_classes

        # Wih, Whh, Woh are the parameters; requires_grad=True lets autograd
        # produce reference gradients for them.
        self.Wih = torch.empty(hid_dim, inp_dim, requires_grad=True)
        self.Whh = torch.empty(hid_dim, hid_dim, requires_grad=True)
        self.Woh = torch.empty(out_dim, hid_dim, requires_grad=True)

        # Manually computed gradients (accumulated by `backward`), to be
        # compared with the gradients computed by PyTorch.
        self.Wih_grad = torch.zeros_like(self.Wih)
        self.Whh_grad = torch.zeros_like(self.Whh)
        self.Woh_grad = torch.zeros_like(self.Woh)

        # Gating matrix, one row per timestep index kk. Each row is a random
        # +/-1 pattern in which every odd position is the negation of the
        # even position before it. There are ceil(hid_dim/2) even positions
        # and floor(hid_dim/2) odd ones. Sizing t_half to the even count
        # keeps odd hid_dim valid, and for even hid_dim it draws exactly the
        # same random numbers as sizing by hid_dim // 2.
        self.tvec = torch.zeros(n_classes, hid_dim)
        n_even = (hid_dim + 1) // 2
        n_odd = hid_dim // 2
        for ii in range(n_classes):
            t_half = (torch.randint(0, 2, (1, n_even)).float() * 2 - 1)[0]
            self.tvec[ii, ::2] = t_half
            self.tvec[ii, 1::2] = -t_half[:n_odd]

        self.reset_parameters()

    def reset_parameters(self):
        """Initialize parameters in place.

        ``Wih`` is drawn from U(-0.1, 0.1). ``Whh`` and ``Woh`` are drawn from
        U(0, 0.1), so the recurrent and output weights start non-negative.

        NOTE(review): confirm the non-negative range for Whh/Woh is intended.
        A symmetric U(-0.1, 0.1) for all three weights may have been meant.
        """
        s = 0.1  # larger value may make the gradients explode
        torch.nn.init.uniform_(self.Wih, -s, s)
        torch.nn.init.uniform_(self.Whh, 0, s)
        torch.nn.init.uniform_(self.Woh, 0, s)

    def zero_grad(self):
        """Set the manually accumulated gradients to zero."""
        self.Wih_grad.zero_()
        self.Whh_grad.zero_()
        self.Woh_grad.zero_()

    def _gate_mask(self, r, kk):
        """Boolean mask of units where r is nonzero and has the sign given by tvec[kk]."""
        return (self.tvec[kk] * r) > 0

    def forward(self, x, hp, kk):
        """Perform the forward computation.

        Args:
            x (Tensor): Input at the current timestep, shape (inp_dim,).
            hp (Tensor): Hidden state at the previous timestep, shape (hid_dim,).
            kk (int): Row of ``tvec`` used to gate this timestep (< n_classes).

        Returns:
            Tensor: Output at the current timestep.
            Tensor: Hidden state at the current timestep.
        """
        _, h, _, y = self._get_internals(x, hp, kk)
        return y, h

    def backward(self, y_grad, rn_grad, x, hp, kk):
        """Perform the backward computation for one timestep.

        Parameter gradients are *added* to ``Wih_grad``/``Whh_grad``/``Woh_grad``.
        Call ``zero_grad`` first to start from zero.

        Args:
            y_grad (Tensor): Gradient on output at the current timestep.
            rn_grad (Tensor): Gradient on vector r at the next timestep.
            x (Tensor): Input at the current timestep that was passed to `forward`.
            hp (Tensor): Hidden state at the previous timestep that was passed to `forward`.
            kk (int): The same gating row index that was passed to `forward`.

        Returns:
            Tensor: Gradient on vector r at the current timestep.
        """
        # Recompute the internal vectors r, h and s from the forward pass.
        r, h, s, _ = self._get_internals(x, hp, kk)

        sig = torch.sigmoid(s)
        s_grad = y_grad * sig * (1 - sig)  # d sigmoid(s)/ds = sigmoid(s) * (1 - sigmoid(s))
        # h feeds both this timestep's output (via Woh) and the next timestep's r (via Whh).
        h_grad = self.Woh.t().matmul(s_grad) + self.Whh.t().matmul(rn_grad)
        # The gate passes r through where the mask is set and zeroes it
        # elsewhere, so dh/dr is the mask itself (treated as a constant).
        r_grad = h_grad * self._gate_mask(r, kk)

        # Parameter gradients are accumulated as outer products.
        self.Wih_grad += r_grad.view(-1, 1).matmul(x.view(1, -1))
        self.Whh_grad += r_grad.view(-1, 1).matmul(hp.view(1, -1))
        self.Woh_grad += s_grad.view(-1, 1).matmul(h.view(1, -1))

        return r_grad

    def _get_internals(self, x, hp, kk):
        """Run the forward computation and return all intermediates (r, h, s, y)."""
        r = self.Wih.matmul(x) + self.Whh.matmul(hp)
        h = self._gate_mask(r, kk) * r
        s = self.Woh.matmul(h)
        y = torch.sigmoid(s)
        return r, h, s, y

In [10]:
class RNN:
    """Unrolls a single-timestep cell over a sequence and backpropagates through time.

    The 0-based timestep index ``t`` is passed to the cell as its ``kk``
    argument in both ``forward`` and ``backward``. This guarantees that both
    passes use the same per-timestep gating.

    Args:
        cell: Object providing ``forward(x, hp, kk) -> (y, h)``,
            ``backward(y_grad, rn_grad, x, hp, kk) -> r_grad`` and ``hid_dim``.
    """

    def __init__(self, cell):
        self.cell = cell

    def forward(self, xs, h0):
        """Perform the forward computation for all timesteps.

        Args:
            xs (Tensor): 2-D tensor of inputs for each timestep. The
                first dimension corresponds to the number of timesteps.
            h0 (Tensor): Initial hidden state.

        Returns:
            Tensor: 2-D tensor of outputs for each timestep. The first
                dimension corresponds to the number of timesteps.
            Tensor: 2-D tensor of hidden states for each timestep plus
                `h0`. The first dimension corresponds to the number of
                timesteps + 1, with ``hs[0] == h0``.
        """
        ys, hs = [], [h0]
        for t, x in enumerate(xs):
            y, h = self.cell.forward(x, hs[-1], t)
            ys.append(y)
            hs.append(h)
        return torch.stack(ys), torch.stack(hs)

    def backward(self, ys_grad, xs, hs):
        """Perform the backward computation for all timesteps.

        Gradients are accumulated into the cell's manual-gradient buffers;
        nothing is returned.

        Args:
            ys_grad (Tensor): 2-D tensor of the gradients on outputs
                for each timestep. The first dimension corresponds to
                the number of timesteps.
            xs (Tensor): 2-D tensor of inputs for each timestep that
                was passed to `forward`.
            hs (Tensor): 2-D tensor of hidden states that is returned
                by `forward`.
        """
        # After the last timestep nothing depends on r, so its gradient is zero.
        rn_grad = torch.zeros(self.cell.hid_dim)

        # hs[t] is the hidden state *before* timestep t (hs[0] == h0). Zipping
        # therefore pairs each input with its previous hidden state; zip drops
        # the final entry hs[-1].
        steps = list(zip(ys_grad, xs, hs))
        for t in reversed(range(len(steps))):
            y_grad, x, hp = steps[t]
            # Use the same timestep index as forward, so the cell re-applies
            # the same gating row.
            rn_grad = self.cell.backward(y_grad, rn_grad, x, hp, t)

In [11]:
# Model sizes: flattened 28x28 MNIST image in, 100 hidden units, one sigmoid output per timestep.
input_dim, hidden_dim, output_dim = 784, 100, 1

In [12]:
# One cell (weights + random gating matrix) shared by every timestep, wrapped by the unroller.
cell = RNNModule(input_dim, hidden_dim, output_dim)
rnn = RNN(cell)

#### Loss

In [13]:
# def compute_cross_entropy(predictions, targets, epsilon=1e-15):
#     """
#     Computes cross entropy between targets (encoded as one-hot vectors)
#     and predictions. 
#     Input: predictions (N, k) ndarray
#            targets (N, k) ndarray        
#     Returns: scalar
#     """
#     predictions = torch.clip(predictions, epsilon, 1. - epsilon)
#     N = predictions.shape[0]
#     ce = -torch.sum(targets*torch.log(predictions+1e-12))/N
    
#     return ce

In [14]:
def compute_loss(ys, ts):
    """Return half the sum of squared errors between predictions `ys` and targets `ts`.

    The 1/2 factor makes the gradient with respect to `ys` exactly `ys - ts`.
    """
    residual = ys - ts
    return residual.pow(2).sum() / 2

In [15]:
## Sequence inputs for the unrolled RNN: xs is the image repeated per timestep,
## hp the all-zero initial hidden state, ts the one-hot targets (one per timestep).
xs, hp, ts = image_batch, torch.zeros(cell.hid_dim), label_batch

In [16]:
# NOTE(review): redundant — the next cell runs the same forward pass again before computing the loss.
ys, hs = rnn.forward(xs, hp)

In [17]:
# Forward pass over all timesteps, then the scalar loss that autograd differentiates below.
ys, hs = rnn.forward(xs, hp)
loss = compute_loss(ys, ts)

In [18]:
# ys

#### PyTorch gradients

In [19]:
## gradient of cross entropy -sum_k y_k*log(yhat_k) w.r.t. yhat_k = -y_k / yhat_k
## gradient of the loss 0.5*sum((y_hat - y)**2) w.r.t. y_hat = y_hat - y

In [20]:
# Autograd reference: fills Wih.grad, Whh.grad and Woh.grad (the parameters were created with requires_grad=True).
loss.backward()

In [21]:
# rnn.cell.Wih.grad

In [22]:
# rnn.cell.Whh.grad

In [23]:
# rnn.cell.Woh.grad

#### Manual gradients

In [24]:
# dL/dys for L = 0.5 * sum((ys - ts)**2), i.e. the gradient of compute_loss.
# ys_grad = torch.matmul(label_batch,ys)
ys_grad = ys - ts

# Run the manual backward outside autograd tracking; the parameter tensors it
# reads require grad (original note: required so PyTorch won't raise error).
# zero_grad first, because backward accumulates into the *_grad buffers.
with torch.no_grad():  # required so PyTorch won't raise error
    rnn.cell.zero_grad()
    rnn.backward(ys_grad, xs, hs)

IndexError: index 11 is out of bounds for dimension 0 with size 10

In [None]:
# rnn.cell.Wih_grad

In [None]:
# rnn.cell.Whh_grad

In [None]:
# rnn.cell.Woh_grad

In [None]:
# Largest absolute disagreement between the manual and autograd gradients for Woh (~0 if backward is correct).
(rnn.cell.Woh_grad - rnn.cell.Woh.grad).abs().max()

In [None]:
# Largest absolute disagreement between the manual and autograd gradients for Whh (~0 if backward is correct).
(rnn.cell.Whh_grad - rnn.cell.Whh.grad).abs().max()

In [None]:
# Largest absolute disagreement between the manual and autograd gradients for Wih (~0 if backward is correct).
(rnn.cell.Wih_grad - rnn.cell.Wih.grad).abs().max()