# Homework 7 - Experiments on RNN and LSTM

Please implement the following two classes:
- MnistRNN() - Design an RNN
- MnistLSTM() - Design an LSTM

Please train both models on the MNIST dataset and print the training results for each epoch.

In [2]:
from torchvision.datasets import MNIST
from torchvision.transforms import Compose,ToTensor,Normalize
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam
import os
import torch
import numpy as np

# Mini-batch size for the training loader.
BATCH_SIZE = 128
# Batch size for the evaluation loader used by test().
TEST_BATCH_SIZE = 1000

# Run on the GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# cuDNN is disabled globally for this notebook.
# NOTE(review): the reason is not stated here (possibly determinism or a
# cuDNN RNN issue) — confirm before re-enabling.
torch.backends.cudnn.enabled = False

# dataloader for the dataset
def get_dataloader(train, batch_size=BATCH_SIZE, shuffle=None):
    """Build a DataLoader over the normalized MNIST train or test split.

    Args:
        train: True for the training split, False for the test split.
        batch_size: number of samples per batch.
        shuffle: whether to reshuffle the data every epoch. Defaults to
            ``train``, so training batches are shuffled and the test split is
            iterated in a fixed, reproducible order.

    Returns:
        A ``DataLoader`` yielding ``(images, labels)`` batches.
    """
    if shuffle is None:
        shuffle = train
    transform_fn = Compose([
        ToTensor(),
        # Constants as given; presumably the MNIST training-set mean/std.
        Normalize(mean=(0.1307,), std=(0.3081,)),
    ])
    # download=True fetches the data into ./data on first use.
    dataset = MNIST(root='./data', train=train, transform=transform_fn, download=True)
    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)

In [3]:
# RNN
class MnistRNN(nn.Module):
    """Image classifier: convolutional patch stem followed by a vanilla RNN.

    A 4x4 / stride-4 convolution turns each image into a grid of
    ``hidden_dim``-dimensional patch embeddings (7x7 for 28x28 inputs). The
    grid is read row by row as a sequence by ``nn.RNN``. The per-step outputs
    are mean-pooled, projected to ``output_dim`` classes and log-softmaxed,
    which matches ``F.nll_loss``.

    Args:
        input_dim: number of input image channels.
        hidden_dim: channels of the conv stem and hidden size of the RNN.
        output_dim: number of output classes.
    """

    def __init__(self, input_dim=1, hidden_dim=128, output_dim=10):
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        # B x input_dim x H x W -> B x hidden_dim x H/4 x W/4
        self.stem = nn.Sequential(
            nn.Conv2d(input_dim, hidden_dim, kernel_size=4, stride=4),
            nn.ReLU(),
        )
        # One patch embedding per time step; batch_first keeps B x T x D.
        self.rnn = nn.RNN(hidden_dim, hidden_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, output_dim)
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, inputs):
        """Map images (B x input_dim x H x W) to log-probabilities (B x output_dim)."""
        features = self.stem(inputs)  # B x D x 7 x 7
        # Flatten the spatial grid and move channels last: B x (7x7) x D.
        # A plain .view(B, -1, D) would interleave channels and positions.
        seq = features.flatten(2).transpose(1, 2)
        outputs, _ = self.rnn(seq)  # B x (7x7) x D
        # Pool over all steps so every step gets a direct gradient path,
        # rather than relying only on the last hidden state of a tanh RNN.
        hidden = self.fc(outputs.mean(dim=-2))  # B x output_dim
        return self.softmax(hidden)

In [4]:
# LSTM
class MnistLSTM(nn.Module):
    """Image classifier: convolutional patch stem followed by an LSTM.

    A 4x4 / stride-4 convolution turns each image into a grid of
    ``hidden_dim``-dimensional patch embeddings (7x7 for 28x28 inputs). The
    grid is read row by row as a sequence by ``nn.LSTM``. The per-step outputs
    are mean-pooled, projected to ``output_dim`` classes and log-softmaxed,
    which matches ``F.nll_loss``.

    Args:
        input_dim: number of input image channels.
        hidden_dim: channels of the conv stem and hidden size of the LSTM.
        output_dim: number of output classes.
    """

    def __init__(self, input_dim=1, hidden_dim=128, output_dim=10):
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        # B x input_dim x H x W -> B x hidden_dim x H/4 x W/4
        self.stem = nn.Sequential(
            nn.Conv2d(input_dim, hidden_dim, kernel_size=4, stride=4),
            nn.ReLU(),
        )
        # Official nn.LSTM; batch_first keeps the B x T x D layout.
        self.lstm = nn.LSTM(hidden_dim, hidden_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, output_dim)
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, inputs):
        """Map images (B x input_dim x H x W) to log-probabilities (B x output_dim)."""
        features = self.stem(inputs)  # B x D x 7 x 7
        # B x D x 7 x 7 -> B x (7x7) x D. Viewing the channels-first tensor
        # directly as (B, -1, D) would mix channels with spatial positions,
        # so flatten the grid first and then move channels last.
        seq = features.flatten(2).transpose(1, 2)
        output, _ = self.lstm(seq)  # B x (7x7) x D
        output = self.fc(output.mean(dim=-2))  # mean-pool over time steps
        return self.softmax(output)

## Train the RNN model

In [5]:
# Fresh RNN classifier on the selected device, optimized with Adam (lr = 1e-3).
model = MnistRNN().to(device)
optimizer = Adam(params=model.parameters(), lr=1e-3)

In [6]:
def train(epoch, num_epochs):
    """Run one training epoch of the global ``model`` with the global ``optimizer``.

    Prints the current batch loss every 100 steps.

    Args:
        epoch: zero-based index of the current epoch (printed one-based).
        num_epochs: total number of epochs; used only in the progress message.
    """
    data_loader = get_dataloader(True)
    total_step = len(data_loader)
    # Enter training mode explicitly in case an evaluation pass switched the
    # model to eval mode (matters for dropout/batch-norm style layers).
    model.train()
    for step, (images, labels) in enumerate(data_loader, start=1):
        optimizer.zero_grad()
        output = model(images.to(device))
        loss = F.nll_loss(output, labels.to(device))
        loss.backward()
        optimizer.step()
        if step % 100 == 0:
            print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'.format(
                epoch + 1, num_epochs, step, total_step, loss.item()))

In [7]:
def test():
    """Evaluate the global ``model`` on the MNIST test split and print the results.

    Accuracy and NLL loss are averaged over samples rather than over batches,
    so a smaller final batch is weighted correctly. The model's
    train/eval mode is restored afterwards.
    """
    test_dataloader = get_dataloader(train=False, batch_size=TEST_BATCH_SIZE)
    was_training = model.training
    model.eval()
    total_loss = 0.0
    total_correct = 0
    total_samples = 0
    with torch.no_grad():
        for images, labels in test_dataloader:
            labels = labels.to(device)
            output = model(images.to(device))
            # Sum (not mean) so the final division weights every sample equally.
            total_loss += F.nll_loss(output, labels, reduction='sum').item()
            total_correct += (output.argmax(dim=-1) == labels).sum().item()
            total_samples += labels.size(0)
    model.train(was_training)
    print("Mean accuracy: ", total_correct / total_samples,
          "Mean loss: ", total_loss / total_samples)

In [None]:
# Evaluate the untrained RNN as a baseline, train it, then evaluate again.
num_epochs = 3
test()
for epoch in range(num_epochs):
    train(epoch, num_epochs)
test()

## Train the LSTM model

In [10]:
# Fresh LSTM classifier on the selected device, optimized with Adam (lr = 1e-3).
model = MnistLSTM().to(device)
optimizer = Adam(params=model.parameters(), lr=1e-3)

In [None]:
# Evaluate the untrained LSTM as a baseline, train it, then evaluate again.
num_epochs = 3
test()
for epoch in range(num_epochs):
    train(epoch, num_epochs)
test()