In [18]:
import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision.datasets as datasets
import torchvision.transforms as transforms

### Loading data

In [19]:
# Both MNIST splits share one download directory and one transform:
# ToTensor turns each PIL image into a float tensor (C, H, W) scaled to [0, 1].
data_root = 'dataset/'
to_tensor = transforms.ToTensor()

train_dataset = datasets.MNIST(root=data_root, train=True, transform=to_tensor, download=True)
test_dataset = datasets.MNIST(root=data_root, train=False, transform=to_tensor, download=True)

### Hyperparameters

In [20]:
input_size = 28
sequence_length = 28
num_layers = 2
hidden_size = 256
epochs = 2
batch_size = 64
num_classes = 10
learning_rate =0.001

### Creating DataLoaders

In [21]:
# Only the training split is shuffled; evaluation just counts correct
# predictions, so a fixed order keeps it deterministic at no cost.
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

### Create the RNN Model

In [22]:
class RNN(nn.Module):
    """Sequence classifier built from a stacked ``nn.RNN`` and a linear head.

    Every sample is a sequence of ``sequence_length`` steps with
    ``input_size`` features each. The hidden states of *all* time steps are
    flattened and fed to a single linear layer that produces class logits.

    Args:
        input_size: number of features per time step.
        hidden_size: width of each recurrent layer.
        num_layers: number of stacked recurrent layers.
        num_classes: number of output logits.
        sequence_length: number of time steps per sample. The linear head is
            sized for exactly this many steps. Defaults to 28, which is the
            value the notebook previously supplied implicitly via its global
            ``sequence_length``.
    """

    def __init__(self, input_size, hidden_size, num_layers, num_classes, sequence_length=28):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.sequence_length = sequence_length
        self.rnn = nn.RNN(input_size, hidden_size, num_layers, batch_first=True)
        # The head consumes every time step's hidden state, not just the last one.
        self.fc = nn.Linear(hidden_size * sequence_length, num_classes)

    def forward(self, x):
        """Map x of shape (batch, sequence_length, input_size) to (batch, num_classes) logits."""
        # Zero initial hidden state, created on x's device/dtype so the model
        # keeps working after model.to(device).
        h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size,
                         device=x.device, dtype=x.dtype)
        out, _ = self.rnn(x, h0)             # (batch, seq, hidden) since batch_first=True
        out = out.reshape(out.shape[0], -1)  # (batch, seq * hidden)
        return self.fc(out)

In [23]:
# Build the classifier from the hyperparameters set in the config cell.
model = RNN(
    input_size=input_size,
    hidden_size=hidden_size,
    num_layers=num_layers,
    num_classes=num_classes,
)

### Loss and Optimizer

In [24]:
# The model returns raw logits from its final linear layer (no softmax),
# which is exactly what CrossEntropyLoss expects.
criterion = nn.CrossEntropyLoss()
# Adam over all of the model's parameters.
optimizer = optim.Adam(params=model.parameters(), lr=learning_rate)

### Training

In [25]:
# Print the loss of every `log_interval`-th batch. This is a single-batch
# value, so expect it to be noisy.
log_interval = 50

# Set training mode explicitly rather than relying on whatever mode an
# earlier cell (e.g. an evaluation that failed midway) left the model in.
model.train()
for epoch in range(epochs):
    for batch_idx, (input_data, labels) in enumerate(train_dataloader):
        # Drop dim 1: (batch, 1, H, W) -> (batch, H, W), so each image row is one time step.
        input_data = input_data.squeeze(1)
        output = model(input_data)
        loss = criterion(output, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if (batch_idx + 1) % log_interval == 0:
            print(f'Epoch:{epoch}, Batch Number:{batch_idx}, Loss={loss.item()}')
print('Finished Training!')

Epoch:0, Batch Number:49, Loss=0.2672605812549591
Epoch:0, Batch Number:99, Loss=0.31670960783958435
Epoch:0, Batch Number:149, Loss=0.3653744161128998
Epoch:0, Batch Number:199, Loss=0.4004823863506317
Epoch:0, Batch Number:249, Loss=0.1081690713763237
Epoch:0, Batch Number:299, Loss=0.1712498664855957
Epoch:0, Batch Number:349, Loss=0.5119423866271973
Epoch:0, Batch Number:399, Loss=0.26457279920578003
Epoch:0, Batch Number:449, Loss=0.2853630483150482
Epoch:0, Batch Number:499, Loss=0.31032976508140564
Epoch:0, Batch Number:549, Loss=0.12149421125650406
Epoch:0, Batch Number:599, Loss=0.16274918615818024
Epoch:0, Batch Number:649, Loss=0.2640329897403717
Epoch:0, Batch Number:699, Loss=0.04390516132116318
Epoch:0, Batch Number:749, Loss=0.20299088954925537
Epoch:0, Batch Number:799, Loss=0.18193213641643524
Epoch:0, Batch Number:849, Loss=0.15093453228473663
Epoch:0, Batch Number:899, Loss=0.027932429686188698
Epoch:1, Batch Number:49, Loss=0.14121529459953308
Epoch:1, Batch Number:

In [26]:
def check_accuracy(loader, model):
    """Print and return the classification accuracy of `model` over `loader`.

    Each input batch is squeezed along dim 1 (presumably the single image
    channel, matching the training loop) before the forward pass. The
    predicted class is the index of the maximum score along dim 1.

    The model is put in eval mode for the pass and afterwards restored to
    the mode it was in before the call, even if evaluation raises.

    Args:
        loader: iterable of (inputs, labels) batches. If ``loader.dataset``
            has a boolean ``train`` attribute (as torchvision MNIST does), it
            selects the "training"/"testing" header message.
        model: module whose output holds class scores along dim 1.

    Returns:
        Accuracy in percent as a float, or ``float('nan')`` if the loader
        yields no samples.
    """
    is_train = getattr(getattr(loader, 'dataset', None), 'train', None)
    if is_train is None:
        print('Checking accuracy')
    elif is_train:
        print('Checking accuracy on training data')
    else:
        print('Checking accuracy on testing data')

    correct = 0
    total = 0
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for x, y in loader:
                scores = model(x.squeeze(1))
                _, predictions = scores.max(1)
                # Accumulate plain ints rather than a tensor counter.
                correct += int((predictions == y).sum())
                total += predictions.size(0)
    finally:
        # Restore the caller's mode instead of forcing train mode.
        model.train(was_training)

    if total == 0:
        print('Got 0 / 0 samples; accuracy is undefined')
        return float('nan')

    accuracy = correct / total * 100
    print(f'Got {correct} / {total} with accuracy {accuracy:.2f}')
    return accuracy

In [27]:
# Report accuracy on the training split first, then on the held-out test split.
for loader in (train_dataloader, test_dataloader):
    check_accuracy(loader, model)

Checking accuracy on training data
Got 58260 / 60000 with accuracy 97.10
Checking accuracy on testing data
Got 9675 / 10000 with accuracy 96.75
