# RNN on MNIST
<img src="./LSTM.png">

In [1]:
import torch
import torch.nn as nn
from torch.autograd import Variable
import torchvision.transforms as transforms
import torchvision.datasets as dsets
import numpy as np


In [2]:
# 1. Hyper Parameters
# Model-shape settings: the train/test cells view every batch as
# (batch, sequence_size, input_size) before passing it to the model.
input_size = 28       # features fed to the RNN at each time step
sequence_size = 28    # number of time steps per sample
hidden_size = 128     # hidden width of the RNN; also the input width of the final linear layer
num_layers = 1        # number of stacked RNN layers
num_classes = 10      # number of scores produced by the final linear layer

# Optimisation settings.
learning_rate = 0.01  # learning rate handed to the Adam optimizer
batch_size = 1        # batch size used by both the train and the test DataLoader
                      # NOTE(review): a batch of 1 means one optimizer step per sample — confirm this is intended
ephoc_size = 2        # number of passes over the training loader ("ephoc" = epoch; spelling kept because the training cell reads this name)


In [3]:
# 2. Preparing datasets
# MNIST images and labels, converted to tensors on access.
# Only the training split asks torchvision to download the files.
train_dataset = dsets.MNIST(
    './data',
    train=True,
    download=True,
    transform=transforms.ToTensor(),
)
test_dataset = dsets.MNIST(
    './data',
    train=False,
    transform=transforms.ToTensor(),
)

# Input pipeline: training batches are reshuffled every epoch,
# test batches are served in dataset order.
train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=batch_size, shuffle=True
)
test_loader = torch.utils.data.DataLoader(
    test_dataset, batch_size=batch_size, shuffle=False
)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz


  0%|          | 0/9912422 [00:00<?, ?it/s]

Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz


  0%|          | 0/28881 [00:00<?, ?it/s]

Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz


  0%|          | 0/1648877 [00:00<?, ?it/s]

Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz


  0%|          | 0/4542 [00:00<?, ?it/s]

Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw



In [8]:
# 3. Build a model
class RNNModel(nn.Module):
    """Sequence classifier: an `nn.RNN` followed by a linear read-out.

    The recurrent layer is built with `batch_first=True`, so inputs are
    laid out as (batch, sequence, feature). Only the recurrent output at
    the final time step is passed to the linear layer.
    """

    def __init__(self, input_size, hidden_size, num_layers, num_classes):
        """Create the recurrent layer and the classification head.

        Args:
            input_size: number of features per time step.
            hidden_size: width of the RNN hidden state.
            num_layers: number of stacked RNN layers.
            num_classes: number of output scores per sample.
        """
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.rnn = nn.RNN(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
        )
        self.fnn = nn.Linear(in_features=hidden_size, out_features=num_classes)

    def forward(self, x):
        """Return one row of `num_classes` scores per sample in `x`."""
        # The final hidden state returned by the RNN is not needed here.
        sequence_out, _ = self.rnn(x)
        # Keep only the last time step: (batch, hidden_size).
        last_step = sequence_out[:, -1, :]
        # Project to class scores: (batch, num_classes).
        return self.fnn(last_step)

In [9]:
# 4. Generate the model
# Instantiate the classifier from the hyper-parameters defined above.
model = RNNModel(
    input_size=input_size,
    hidden_size=hidden_size,
    num_layers=num_layers,
    num_classes=num_classes,
)

In [10]:
# 5. Set Loss and Optimizer
# Adam updates every parameter of the model; the loss is cross-entropy
# computed on the model's class scores.
optimizer = torch.optim.Adam(params=model.parameters(), lr=learning_rate)
loss_function = nn.CrossEntropyLoss()

In [None]:
# 6. Train
# Runs `ephoc_size` passes over `train_loader`, taking one optimizer step per
# batch and printing the batch loss every 100 batches.
model.train()  # explicitly select training mode before optimising
for ephoc in range(ephoc_size):
    for idx, (images, labels) in enumerate(train_loader):
        # View the batch as (batch, sequence_size, input_size) for the RNN.
        # Tensors are passed directly: the deprecated `Variable` wrapper is
        # no longer needed for autograd.
        images = images.view(-1, sequence_size, input_size)

        # Forward, backward and gradient-descent step.
        optimizer.zero_grad()  # clear gradients left over from the previous batch
        output = model(images)
        loss = loss_function(output, labels)
        loss.backward()
        optimizer.step()
        if idx % 100 == 0:
            print("loss:", loss.item())

loss: 2.2353217601776123
loss: 2.86865234375
loss: 2.1444647312164307
loss: 2.4055161476135254
loss: 1.3455277681350708
loss: 2.1672654151916504
loss: 2.2023348808288574
loss: 2.7592661380767822
loss: 4.385477542877197
loss: 1.659388780593872
loss: 0.8252060413360596
loss: 2.148242950439453
loss: 2.499058246612549
loss: 0.9146583080291748
loss: 5.119024276733398
loss: 1.1530007123947144
loss: 1.1509158611297607
loss: 1.8377019166946411
loss: 1.5790812969207764
loss: 0.7153019905090332
loss: 0.7283716201782227
loss: 0.81229567527771
loss: 1.2011340856552124
loss: 1.200746774673462
loss: 0.37594088912010193
loss: 0.03935695439577103
loss: 4.458440780639648
loss: 0.6619580388069153
loss: 0.0959213450551033
loss: 1.519495964050293
loss: 0.26937732100486755
loss: 0.020531507208943367
loss: 0.5322520732879639
loss: 0.8357500433921814
loss: 0.3747805953025818
loss: 2.791842460632324
loss: 0.023426424711942673
loss: 0.6903879046440125
loss: 0.023575585335493088
loss: 0.36850690841674805
loss: 

In [None]:
# 7. Test
# Counts how many predictions over `test_loader` match the labels and
# prints the resulting accuracy.
correct = 0
total = 0
model.eval()  # select inference mode for evaluation
with torch.no_grad():  # no gradients are needed, so skip building the autograd graph
    for images, labels in test_loader:
        # View the batch as (batch, sequence_size, input_size) for the RNN.
        images = images.view(-1, sequence_size, input_size)

        outputs = model(images)
        # The index of the largest score along dim 1 is the predicted class.
        _, predicted = torch.max(outputs, 1)
        total += len(predicted)
        # `.item()` keeps `correct` a plain Python int for the whole loop.
        correct += (predicted == labels).sum().item()

print("accuracy:", correct / total)