# Recurrent Neural Network Example

Build a recurrent neural network (LSTM) with PyTorch.

- Author: Ritchie Ng (with some changes to the original)
- Project: https://github.com/ritchieng/deep-learning-wizard

## RNN Overview

<img src="http://colah.github.io/posts/2015-08-Understanding-LSTMs/img/RNN-unrolled.png" alt="nn" style="width: 600px;"/>

References:
- [Long Short Term Memory](http://deeplearning.cs.cmu.edu/pdfs/Hochreiter97_lstm.pdf), Sepp Hochreiter & Jurgen Schmidhuber, Neural Computation 9(8): 1735-1780, 1997.

## MNIST Dataset Overview

This example uses the MNIST dataset of handwritten digits. The dataset contains 60,000 examples for training and 10,000 examples for testing. The digits have been size-normalized and centered in a fixed-size image (28x28 pixels, i.e. 784 pixels). Here, `transforms.ToTensor()` converts each image into a float tensor of shape 1x28x28 (one grayscale channel) and scales the pixel values from 0-255 to the range 0 to 1.

<img src="http://neuralnetworksanddeeplearning.com/images/mnist_100_digits.png" alt="MNIST Dataset" style="width: 300px;"/>

To classify images with a recurrent neural network, we treat every image as a sequence of its rows. Because MNIST images are 28x28 pixels, each sample becomes a sequence of 28 timesteps, where each timestep is one row of 28 pixel values (features).
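The row-as-timestep view is just a reshape. The sketch below uses a hypothetical random batch in place of real MNIST images and shows that timestep `t` of a sequence is exactly row `t` of the corresponding image.

```python
import torch

# Hypothetical batch of 4 single-channel 28x28 images, shaped like the
# output of transforms.ToTensor() + DataLoader: (batch, channel, height, width)
example_images = torch.rand(4, 1, 28, 28)

seq_dim, input_dim = 28, 28  # 28 timesteps (rows), 28 features (pixels per row)
sequences = example_images.view(-1, seq_dim, input_dim)  # (batch, timesteps, features)
print(sequences.shape)  # torch.Size([4, 28, 28])

# Timestep t of sample 0 is simply row t of image 0
t = 5
assert torch.equal(sequences[0, t], example_images[0, 0, t])
```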

More info: http://yann.lecun.com/exdb/mnist/

### 0. Import dependencies
At the beginning of every project, we first import all the required dependencies to process the data, define the model, train the model and visualize the results.

```python
import torch
from torch import nn
from torchvision import transforms
from torchvision import datasets

import matplotlib.pyplot as plt
```

### 1. Data Loading
To train our model to recognize (classify) handwritten digits, we download the well-known MNIST dataset from torchvision's collection of example datasets. The dataset consists of a training dataset (train split) and an independent test dataset (test split). Each split is a collection of tuples, where every tuple consists of an input image with 28x28 pixels and a target value (label) representing the value of the handwritten digit.

```python
train_dataset = datasets.MNIST(root='./data', 
                               train=True, 
                               transform=transforms.ToTensor(),
                               download=True)

test_dataset = datasets.MNIST(root='./data', 
                              train=False, 
                              transform=transforms.ToTensor(),
                              download=True)
```

```python
# Visualize the first input image of the dataset (grayscale)
plt.imshow(train_dataset[0][0].squeeze(), cmap='gray')
plt.show()

# Print the associated ground truth value (label)
print(f"The ground truth label for the given image is: {train_dataset[0][1]}")
```

### 2. Model Definition
We define a neural network consisting of one or more stacked LSTM layers, unrolled over the 28 timesteps (image rows), and a final fully connected layer that predicts the digit from the hidden state at the last timestep. The initial hidden state and cell state are initialized with zeros.

```python
class LSTMModel(nn.Module):
    def __init__(self, input_dim, hidden_dim, layer_dim, output_dim):
        super().__init__()
        # Hidden dimensions
        self.hidden_dim = hidden_dim

        # Number of stacked LSTM layers
        self.layer_dim = layer_dim

        # Build the LSTM
        # batch_first=True causes input/output tensors to be of shape
        # (batch_dim, seq_dim, feature_dim)
        self.lstm = nn.LSTM(input_dim, hidden_dim, layer_dim, batch_first=True)

        # Readout layer: maps the last hidden state to one score per class
        self.fc = nn.Linear(hidden_dim, output_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Initialize hidden state and cell state with zeros,
        # on the same device and with the same dtype as the input
        h0 = torch.zeros(self.layer_dim, x.size(0), self.hidden_dim, device=x.device, dtype=x.dtype)
        c0 = torch.zeros(self.layer_dim, x.size(0), self.hidden_dim, device=x.device, dtype=x.dtype)

        # Unroll the LSTM over the 28 time steps.
        # The states are created fresh for every batch and need no gradients,
        # so nothing is carried over (or backpropagated) between batches.
        out, _ = self.lstm(x, (h0, c0))

        # Index hidden state of last step
        # out.size() --> (batch_size, 28, hidden_dim), e.g. (100, 28, 64)
        # out[:, -1, :] --> (100, 64) --> we only want the last time step
        out = self.fc(out[:, -1, :])
        # out.size() --> (100, 10)
        return out
```
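To see why the model reads `out[:, -1, :]`, it helps to look at what `nn.LSTM` returns. The sketch below uses a standalone LSTM with the configuration from step 3 and a hypothetical random batch. It checks two documented behaviours: omitting `(h0, c0)` makes PyTorch use zero states (so the explicit zeros in `LSTMModel` are equivalent), and for a single-direction LSTM the last timestep of `out` equals the top layer of `h_n`.

```python
import torch
from torch import nn

torch.manual_seed(0)

# 28 features per step, 64 hidden units, 1 layer (as configured in step 3)
lstm = nn.LSTM(input_size=28, hidden_size=64, num_layers=1, batch_first=True)
x = torch.rand(100, 28, 28)  # hypothetical batch: (batch, seq, features)

out, (h_n, c_n) = lstm(x)  # no initial state given -> zeros are used
print(out.shape)  # torch.Size([100, 28, 64]): hidden state at every timestep
print(h_n.shape)  # torch.Size([1, 100, 64]): final hidden state per layer
print(c_n.shape)  # torch.Size([1, 100, 64]): final cell state per layer

# Passing explicit zero states gives the same result as omitting them
h0 = torch.zeros(1, 100, 64)
c0 = torch.zeros(1, 100, 64)
out_explicit, _ = lstm(x, (h0, c0))
assert torch.allclose(out, out_explicit)

# For a single-direction LSTM, the last timestep of `out` is the top layer's h_n
assert torch.allclose(out[:, -1, :], h_n[-1])
```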

### 3. Model Initialization
The model is initialized to expect input sequences whose timesteps each have 28 features (one image row) and to output 10 scores (logits), one for each of the digits 0-9. The model consists of a single LSTM layer with a 64-dimensional hidden state.

```python
# Dimension of the input data (28 pixels per row, i.e. per timestep)
input_dim = 28

# Dimension of the hidden state
hidden_dim = 64

# Number of LSTM layers
layer_dim = 1

# Dimension of the output data (digits 0-9)
output_dim = 10

# Model initialization
model = LSTMModel(input_dim, hidden_dim, layer_dim, output_dim)
```
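As a sanity check on the model size: a PyTorch LSTM layer stores, for its four gates, input weights `weight_ih` (4·hidden × input), recurrent weights `weight_hh` (4·hidden × hidden) and two bias vectors `bias_ih` and `bias_hh` (4·hidden each). With input 28 and hidden 64 that is 7,168 + 16,384 + 512 = 24,064 parameters, plus 64·10 + 10 = 650 for the readout layer. The sketch below rebuilds the same two layers standalone and compares the formula with PyTorch's own count.

```python
from torch import nn

input_dim, hidden_dim, output_dim = 28, 64, 10

# Same layers as LSTMModel with layer_dim = 1
lstm = nn.LSTM(input_dim, hidden_dim, 1, batch_first=True)
fc = nn.Linear(hidden_dim, output_dim)

def count(module: nn.Module) -> int:
    """Total number of scalar parameters in a module."""
    return sum(p.numel() for p in module.parameters())

# Four gates, each with input weights, recurrent weights and two bias vectors
expected_lstm = 4 * (hidden_dim * input_dim + hidden_dim * hidden_dim + 2 * hidden_dim)
expected_fc = hidden_dim * output_dim + output_dim

print(count(lstm), expected_lstm)  # 24064 24064
print(count(fc), expected_fc)      # 650 650
```

In the notebook itself, `sum(p.numel() for p in model.parameters())` should therefore report 24,714 parameters.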

### 4. Training Setup
A cross-entropy loss function is used to train the model, and the model weights are optimized with stochastic gradient descent (SGD) using a learning rate (step size) of 0.1.

```python
# Definition of the loss function (cross-entropy for a classification problem)
criterion = nn.CrossEntropyLoss()

# Learning rate for the model training (optimization)
learning_rate = 0.1

# Definition of the optimizer (stochastic gradient descent)
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
```
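What these two objects compute can be checked on a small example. The sketch below uses a hypothetical toy linear classifier and made-up labels (all names prefixed `toy_` so they do not overwrite the notebook's `criterion` or `optimizer`). It verifies that `nn.CrossEntropyLoss` equals log-softmax followed by the mean negative log-likelihood of the true class, and that one plain SGD step (no momentum, no weight decay) updates each parameter as p ← p − lr · grad.

```python
import torch
import torch.nn.functional as F
from torch import nn

torch.manual_seed(0)

# Hypothetical toy classifier and data: 4 samples with 28 features, 10 classes
toy_model = nn.Linear(28, 10)
toy_inputs = torch.rand(4, 28)
toy_labels = torch.tensor([3, 7, 0, 9])
toy_lr = 0.1

toy_criterion = nn.CrossEntropyLoss()
toy_optimizer = torch.optim.SGD(toy_model.parameters(), lr=toy_lr)

toy_logits = toy_model(toy_inputs)
toy_loss = toy_criterion(toy_logits, toy_labels)

# CrossEntropyLoss = log-softmax over the classes, then the negative
# log-likelihood of the true class, averaged over the batch (default reduction)
manual_loss = -F.log_softmax(toy_logits, dim=1)[torch.arange(4), toy_labels].mean()
assert torch.allclose(toy_loss, manual_loss)

# One plain SGD step: p <- p - lr * grad
toy_optimizer.zero_grad()
toy_loss.backward()
expected = [p.detach() - toy_lr * p.grad for p in toy_model.parameters()]
toy_optimizer.step()
for p, e in zip(toy_model.parameters(), expected):
    assert torch.allclose(p.detach(), e)
```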

```python
# Batch size for model training (size of the mini-batches)
batch_size = 100

# Make the datasets iterable
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, 
                                           batch_size=batch_size, 
                                           shuffle=True)

test_loader = torch.utils.data.DataLoader(dataset=test_dataset, 
                                          batch_size=batch_size, 
                                          shuffle=False)
```
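Each dataset element is an `(image, label)` tuple, and a `DataLoader` stacks those tuples into batched tensors. The sketch below mimics this with a small synthetic `TensorDataset` (a hypothetical stand-in for MNIST, so it runs without downloading anything) to show the shapes the training loop receives.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical stand-in for MNIST: 250 random "images" with random labels 0-9
fake_images = torch.rand(250, 1, 28, 28)
fake_labels = torch.randint(0, 10, (250,))
fake_dataset = TensorDataset(fake_images, fake_labels)

# Like datasets.MNIST, each element is an (image, label) tuple
sample_image, sample_label = fake_dataset[0]
print(sample_image.shape)  # torch.Size([1, 28, 28])

# The DataLoader stacks tuples into batched tensors
fake_loader = DataLoader(fake_dataset, batch_size=100, shuffle=True)
batch_shapes = [(imgs.shape, lbls.shape) for imgs, lbls in fake_loader]

# 250 samples -> two full batches of 100 and a final partial batch of 50
for image_shape, label_shape in batch_shapes:
    print(image_shape, label_shape)
```

With the real 60,000 training images and `batch_size = 100`, `train_loader` yields 600 full batches per epoch (and `test_loader` 100). Unlike this sketch, `datasets.MNIST` returns each label as a Python `int`; the default collate function stacks those into a tensor as well.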

### 5. Model Training

```python
# Number of steps to unroll (28 rows of 28 pixels each)
seq_dim = 28

# Number of training epochs
num_epochs = 2

iteration = 0
for epoch in range(num_epochs):
    for i, (images, labels) in enumerate(train_loader):
        # Reshape images from (batch, 1, 28, 28) to (batch, seq_dim, input_dim);
        # the inputs themselves do not need gradients
        images = images.view(-1, seq_dim, input_dim)

        # Clear gradients w.r.t. parameters
        optimizer.zero_grad()

        # Forward pass to get output/logits
        # outputs.size() --> 100, 10
        outputs = model(images)

        # Calculate Loss: softmax --> cross entropy loss
        loss = criterion(outputs, labels)

        # Getting gradients w.r.t. parameters
        loss.backward()

        # Updating parameters
        optimizer.step()

        iteration += 1

        # Evaluate model performance on the test set every 100 iterations
        if iteration % 100 == 0:
            correct = 0
            total = 0
            # Disable gradient tracking during evaluation
            with torch.no_grad():
                # Iterate through test dataset
                for test_images, test_labels in test_loader:
                    # Resize images
                    test_images = test_images.view(-1, seq_dim, input_dim)

                    # Forward pass only to get logits/output
                    test_outputs = model(test_images)

                    # Get predictions from the maximum value
                    predicted = test_outputs.argmax(dim=1)

                    # Total number of labels
                    total += test_labels.size(0)

                    # Total correct predictions
                    correct += (predicted == test_labels).sum().item()

            accuracy = 100 * correct / total

            # Print Loss
            print(f'Iteration: {iteration}. Loss: {loss.item():.4f}. Accuracy: {accuracy:.2f}')
```

### 6. Model Evaluation

```python
# Select arbitrary test image
image = test_dataset[0][0]

# Resize image to a batch of one sequence: (1, seq_dim, input_dim)
image = image.view(-1, seq_dim, input_dim)

# Forward pass only to get logits/output (no gradients needed)
with torch.no_grad():
    outputs = model(image)

# Get predictions from the maximum value
predicted = outputs.argmax(dim=1)

# Visualize the test image (model input)
plt.imshow(image.squeeze(), cmap='gray')
plt.show()

# Print the predicted value
print(f"The predicted value for the given image is: {predicted.item()}")
```
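The model outputs unnormalized scores (logits), not probabilities. If class probabilities are wanted, for example to see how confident a prediction is, `torch.softmax` can be applied along the class dimension; because softmax is monotonic, the predicted class stays the same. The `example_logits` tensor below is hypothetical, standing in for the model's `outputs`.

```python
import torch

# Hypothetical logits for one image, shape (1, 10), as returned by the model
example_logits = torch.tensor([[0.2, -1.3, 0.5, 0.1, -0.4, 0.0, -2.0, 4.1, 0.3, 1.2]])

# Softmax over the class dimension turns the scores into probabilities
example_probs = torch.softmax(example_logits, dim=1)

# Softmax preserves the ordering of the scores, so the most likely class is unchanged
example_pred = example_probs.argmax(dim=1)
print(example_pred.item())  # 7
print(f"Confidence: {example_probs[0, example_pred.item()].item():.3f}")
```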