#MNIST classification with RNN
Treat each row of MNIST image as an input vector at each time point. So, 32 rows will be fed sequentially. After that the output will be probablity vector for classification. This would be an example of a "sequence-to-vector" model of RNN 

In [8]:
import torch
import torchvision
from torch import nn
from torch.autograd import Variable
import torchvision.datasets as dsets
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
%matplotlib inline

n_epochs = 10
batch_size = 200

random_seed = 1
torch.backends.cudnn.enabled = False
torch.manual_seed(random_seed)

from google.colab import drive
drive.mount('/content/drive')

# Checking GPU availability
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
cuda:0


In [0]:
from torch.utils.data import random_split

MNIST_training = torchvision.datasets.MNIST('/content/drive/My Drive/HIP2019/MNIST_dataset/', train=True, download=True,
                             transform=torchvision.transforms.Compose([
                               torchvision.transforms.ToTensor(),
                               torchvision.transforms.Normalize((0.1307,), (0.3081,))]))

MNIST_test_set = torchvision.datasets.MNIST('/content/drive/My Drive/HIP2019/MNIST_dataset/', train=False, download=True,
                             transform=torchvision.transforms.Compose([
                               torchvision.transforms.ToTensor(),
                               torchvision.transforms.Normalize((0.1307,), (0.3081,))]))

# create a training and a validation set
MNIST_training_set, MNIST_validation_set = random_split(MNIST_training, [55000, 5000])

train_loader = torch.utils.data.DataLoader(MNIST_training_set,batch_size=batch_size, shuffle=True)

validation_loader = torch.utils.data.DataLoader(MNIST_validation_set,batch_size=batch_size, shuffle=True)

test_loader = torch.utils.data.DataLoader(MNIST_test_set,batch_size=batch_size, shuffle=True)

In [10]:
examples = enumerate(test_loader)
batch_idx, (example_data, example_targets) = next(examples)
print(example_data.shape,example_targets.shape)

torch.Size([200, 1, 28, 28]) torch.Size([200])


In [0]:
# One problem 
class BasicRNN(nn.Module):
    def __init__(self, batch_size, n_inputs, n_neurons):
        super(BasicRNN, self).__init__()
        
        self.rnn = nn.LSTMCell(28, 64)                     # input dimension x number of neurons
        self.hx = torch.randn(batch_size, 64).to(device)   # initialize hidden state
        self.cx = torch.randn(batch_size, 64).to(device)   # initialize memory state
        self.out = nn.Linear(64, 10)

    def forward(self, X):
        # transforms X to dimensions: n_steps X batch_size X n_inputs
        X = X.permute(1, 0, 2)

        # for each time step - there are 28 rows in the MNIST image
        hx = self.hx
        cx = self.cx
        for i in range(28):
            hx, cx = self.rnn(X[i], (hx, cx))
        
        return self.out(hx)


In [25]:
rnn = BasicRNN(batch_size,28,64).to(device)
optimizer = torch.optim.Adam(rnn.parameters(), lr=0.01)   # optimize all rnn parameters
loss_func = nn.CrossEntropyLoss()                       # the target label is not one-hotted

for epoch in range(10):
    for step, (x, y) in enumerate(train_loader):        # gives batch data
        b_x = x.view(-1, 28, 28).to(device)             # reshape x to (batch, time_step, input_size)
        b_y = y.to(device)                              # batch y

        output = rnn(b_x)                               # rnn output
        loss = loss_func(output, b_y)                   # cross entropy loss
        optimizer.zero_grad()                           # clear gradients for this training step
        loss.backward()                                 # backpropagation, compute gradients
        optimizer.step()                                # apply gradients

        if step % 50 == 0:
          with torch.no_grad():
            accuracy=0.0
            for validation_x, validation_y in validation_loader:
              validation_x = validation_x.view(-1, 28, 28).to(device)    # reshape x to (batch, time_step, input_size)
              validation_y = validation_y.to(device)                     # batch y
              validation_output = rnn(validation_x) 
              pred = validation_output.data.max(1, keepdim=True)[1]
              accuracy += 100.0*pred.eq(validation_y.data.view_as(pred)).sum() / float(len(validation_loader.dataset))
            print('Epoch: ', epoch, '| train loss: %.4f' % loss.item(), '| validation accuracy: %.2f' % accuracy)

Epoch:  0 | train loss: 2.3047 | validation accuracy: 10.62
Epoch:  0 | train loss: 0.5686 | validation accuracy: 81.62
Epoch:  0 | train loss: 0.2626 | validation accuracy: 92.64
Epoch:  0 | train loss: 0.2473 | validation accuracy: 92.44
Epoch:  0 | train loss: 0.2741 | validation accuracy: 93.62
Epoch:  0 | train loss: 0.1467 | validation accuracy: 94.66
Epoch:  1 | train loss: 0.1038 | validation accuracy: 95.42
Epoch:  1 | train loss: 0.1261 | validation accuracy: 95.90
Epoch:  1 | train loss: 0.1259 | validation accuracy: 96.24
Epoch:  1 | train loss: 0.1520 | validation accuracy: 95.86
Epoch:  1 | train loss: 0.1286 | validation accuracy: 96.10
Epoch:  1 | train loss: 0.1419 | validation accuracy: 96.52
Epoch:  2 | train loss: 0.0777 | validation accuracy: 96.62
Epoch:  2 | train loss: 0.1228 | validation accuracy: 96.80
Epoch:  2 | train loss: 0.0834 | validation accuracy: 96.70
Epoch:  2 | train loss: 0.1319 | validation accuracy: 96.56
Epoch:  2 | train loss: 0.1397 | validat

In [26]:
with torch.no_grad():
  accuracy=0.0
  for test_x, test_y in test_loader:
    test_x = test_x.view(-1, 28, 28).to(device)    # reshape x to (batch, time_step, input_size)
    test_y = test_y.to(device)                     # batch y
    test_output = rnn(test_x) 
    pred = test_output.data.max(1, keepdim=True)[1]
    accuracy += 100.0*pred.eq(test_y.data.view_as(pred)).sum() / float(len(test_loader.dataset))
  print('test accuracy: %.2f' % accuracy)

test accuracy: 97.38
