In [None]:
import math
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

import matplotlib.pyplot as plt
%matplotlib inline
from IPython import display
from time import sleep

In [None]:
# MNIST dataset, stored under ./data and converted to tensors by ToTensor().
# Only the training split requests the download.
train_dataset = torchvision.datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transforms.ToTensor(),
)

test_dataset = torchvision.datasets.MNIST(
    root='./data',
    train=False,
    transform=transforms.ToTensor(),
)

In [None]:
# Model hyperparameters: each 28x28 image is fed to the LSTM as a sequence of
# 28 rows (timesteps), each row being a 28-pixel input vector.
input_size = 28       # pixels per image row = features per timestep
sequence_length = 28  # image rows = timesteps
hidden_size = 128
num_classes = 10

# Training hyperparameters
num_epochs = 3
batch_size = 100
learning_rate = 0.001

# Fix the RNG so weight initialisation and data shuffling are reproducible
# when the notebook is re-run from a fresh kernel.
seed = 0
torch.manual_seed(seed)

# Create data loaders (only the training data is shuffled)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                           batch_size=batch_size,
                                           shuffle=True)

test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                          batch_size=batch_size,
                                          shuffle=False)

Custom LSTM module

In [None]:
class emiLSTM_ver1(nn.Module):  # TODO: multiple LSTM:s on top of each other (num_layers)?, directions?
  """ Own implementation of a single-layer, unidirectional LSTM.

  The update equations for the cell state S_t and hidden state H_t follow
  https://colah.github.io/posts/2015-08-Understanding-LSTMs/ :

      F_t = sigmoid(Wf [x_t; h_{t-1}] + bf)     forget gate
      I_t = sigmoid(Wi [x_t; h_{t-1}] + bi)     input gate
      C_t = tanh   (Wc [x_t; h_{t-1}] + bc)     candidate cell state
      O_t = sigmoid(Wo [x_t; h_{t-1}] + bo)     output gate
      S_t = F_t * S_{t-1} + I_t * C_t
      H_t = O_t * tanh(S_t)

  States are kept column-wise, with shape [hidden_size, batch_size], so every
  gate is a plain matrix product W @ [x_t; h_{t-1}] over the whole batch.
  """

  def __init__(self, input_size, hidden_size):
    """ Creates and initializes the gate parameters.

    Args:
      input_size: number of features in the input vector at each timestep.
      hidden_size: size of the hidden state; in the standard architecture the
        cell state has the same size.
    """
    super().__init__()
    self.input_size = input_size
    self.hidden_size = hidden_size

    # Every weight matrix acts on the concatenation [x_t; h_{t-1}], so it has
    # input_size + hidden_size columns (input columns first, then hidden ones).
    concat_size = input_size + hidden_size

    # Forget-gate layer parameters
    self.Wf = nn.Parameter(torch.zeros(hidden_size, concat_size))
    self.bf = nn.Parameter(torch.zeros(hidden_size, 1))

    # Input-gate layer parameters
    self.Wi = nn.Parameter(torch.zeros(hidden_size, concat_size))
    self.bi = nn.Parameter(torch.zeros(hidden_size, 1))

    # Candidate parameters
    self.Wc = nn.Parameter(torch.zeros(hidden_size, concat_size))
    self.bc = nn.Parameter(torch.zeros(hidden_size, 1))

    # Output-gate layer parameters
    self.Wo = nn.Parameter(torch.zeros(hidden_size, concat_size))
    self.bo = nn.Parameter(torch.zeros(hidden_size, 1))

    self.init_weights()

  def init_weights(self):
    """ Draws every parameter from U(-1/sqrt(hidden_size), 1/sqrt(hidden_size)).

    This matches torch.nn.LSTM's default initialisation scheme.
    """
    stdv = 1.0 / math.sqrt(self.hidden_size)
    with torch.no_grad():
      for param in self.parameters():
        param.uniform_(-stdv, stdv)

  def forward(self, X, H_t=None, S_t=None):
    """ Runs the LSTM over every sequence in the batch X simultaneously.

    Args:
      X: input batch of shape [batch_size, seq_length, input_size]
        (for an MNIST image: batch, image rows, image columns).
      H_t: optional initial hidden state, shape [hidden_size, batch_size].
        Zeros if None.
      S_t: optional initial cell state, shape [hidden_size, batch_size].
        Zeros if None.

    Returns:
      hidden_sequences: hidden state at every timestep, with shape
        [batch_size, seq_length, hidden_size] (batch-first, like torch.nn.LSTM
        with batch_first=True).
      (H_t, S_t): final hidden and cell state, each [hidden_size, batch_size].

    Raises:
      ValueError: if X is not 3-D or its last dimension is not input_size.
    """
    if X.dim() != 3 or X.shape[2] != self.input_size:
      raise ValueError(
          f'expected X of shape [batch_size, seq_length, {self.input_size}], '
          f'got {tuple(X.shape)}')

    batch_size, seq_length, _ = X.shape
    # [batch, seq, input] -> [seq, input, batch]: X[t] then holds one column per sequence.
    X = X.permute(1, 2, 0)

    # Default states are created on X's device and dtype, so the module keeps
    # working after model.to('cuda').
    if H_t is None:
      H_t = X.new_zeros(self.hidden_size, batch_size)
    if S_t is None:
      S_t = X.new_zeros(self.hidden_size, batch_size)

    hidden_sequence = []
    for t in range(seq_length):  # length of *this* input, not a notebook-global constant
      X_t = X[t]  # [input_size, batch_size]; for an image, the t:th row of every image
      X_and_H = torch.cat((X_t, H_t), 0)  # [input_size + hidden_size, batch_size]

      # Update cell state (S_t)
      F_t = torch.sigmoid(self.Wf @ X_and_H + self.bf)
      I_t = torch.sigmoid(self.Wi @ X_and_H + self.bi)
      C_t = torch.tanh(self.Wc @ X_and_H + self.bc)
      S_t = F_t * S_t + I_t * C_t  # element-wise (Hadamard) products

      # Update hidden state (H_t)
      O_t = torch.sigmoid(self.Wo @ X_and_H + self.bo)
      H_t = O_t * torch.tanh(S_t)

      hidden_sequence.append(H_t)

    # [seq, hidden, batch] -> [batch, seq, hidden]
    hidden_sequences = torch.stack(hidden_sequence, dim=0).permute(2, 0, 1)

    return hidden_sequences, (H_t, S_t)


In [None]:
class LSTM_based_RNN(nn.Module):
  """ Sequence classifier: an LSTM over the input sequence followed by a linear
  output layer applied to the hidden state of the last timestep. """

  def __init__(self, input_size, hidden_size, num_classes, use_torch_lstm=False):
    """ Builds the LSTM and the linear output layer.

    Args:
      input_size: features per timestep.
      hidden_size: size of the LSTM hidden state.
      num_classes: number of output logits.
      use_torch_lstm: if True, use torch.nn.LSTM (batch_first=True) instead of
        the custom emiLSTM_ver1. Comparing the two LSTMs is then a one-argument change.
    """
    super().__init__()
    self.num_layers = 1  # only used by torch.nn.LSTM; the custom LSTM is single-layer
    self.hidden_size = hidden_size

    if use_torch_lstm:
      self.lstm = nn.LSTM(input_size, hidden_size, num_layers=self.num_layers, batch_first=True)
    else:
      self.lstm = emiLSTM_ver1(input_size, hidden_size)

    self.end_layer = nn.Linear(hidden_size, num_classes)

  def forward(self, X):
    """ Maps X of shape [batch_size, seq_length, input_size] to logits of shape [batch_size, num_classes]. """
    # Both LSTM variants share three behaviours:
    #   - they take a batch-first input;
    #   - they default the initial hidden and cell states to zeros (on X's device);
    #   - they return (output, (h, c)) with output [batch_size, seq_length, hidden_size].
    out, _ = self.lstm(X)

    out = out[:, -1, :]  # hidden state of each sample at the last timestep: [batch_size, hidden_size]
    return self.end_layer(out)  # [batch_size, num_classes]


Model Training / Comparison

In [None]:
model = LSTM_based_RNN(input_size, hidden_size, num_classes).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

n_total_steps = len(train_loader)
model.train()  # explicit training mode (no dropout/batch-norm here, but keeps intent clear)
for epoch in range(num_epochs):
    for i, (images, labels) in enumerate(train_loader):

        # Prepare batch data: from [batch_size, channels, seq_length, input_size]
        # to [batch_size, seq_length (image rows), input_size (image columns)].
        images = images.reshape(-1, sequence_length, input_size).to(device)
        # Labels must be on the same device as the model outputs for the loss.
        labels = labels.to(device)

        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)

        # Backward pass and gradient step
        optimizer.zero_grad()  # gradients accumulate across backward() calls otherwise
        loss.backward()
        optimizer.step()

        if (i + 1) % 100 == 0:
            print(f'Epoch: {epoch+1}/{num_epochs} | Batch nr. {i+1}/{n_total_steps} | Loss: {loss.item():.4f}')


Epoch: 1/3 | Batch nr. 100/600 | Loss: 0.9142
Epoch: 1/3 | Batch nr. 200/600 | Loss: 0.4555
Epoch: 1/3 | Batch nr. 300/600 | Loss: 0.3580
Epoch: 1/3 | Batch nr. 400/600 | Loss: 0.2649
Epoch: 1/3 | Batch nr. 500/600 | Loss: 0.2165
Epoch: 1/3 | Batch nr. 600/600 | Loss: 0.2388
Epoch: 2/3 | Batch nr. 100/600 | Loss: 0.2751
Epoch: 2/3 | Batch nr. 200/600 | Loss: 0.2318
Epoch: 2/3 | Batch nr. 300/600 | Loss: 0.1760
Epoch: 2/3 | Batch nr. 400/600 | Loss: 0.1790
Epoch: 2/3 | Batch nr. 500/600 | Loss: 0.0756
Epoch: 2/3 | Batch nr. 600/600 | Loss: 0.0982
Epoch: 3/3 | Batch nr. 100/600 | Loss: 0.1136
Epoch: 3/3 | Batch nr. 200/600 | Loss: 0.1073
Epoch: 3/3 | Batch nr. 300/600 | Loss: 0.1303
Epoch: 3/3 | Batch nr. 400/600 | Loss: 0.1154
Epoch: 3/3 | Batch nr. 500/600 | Loss: 0.0582
Epoch: 3/3 | Batch nr. 600/600 | Loss: 0.0401


Accuracy on test set

In [None]:
model.eval()  # inference mode (no dropout/batch-norm here, but keeps intent clear)
with torch.no_grad():
    n_correct = 0
    n_samples = 0
    for images, labels in test_loader:

        images = images.reshape(-1, sequence_length, input_size).to(device)
        labels = labels.to(device)
        outputs = model(images)
        predicted = outputs.argmax(dim=1)  # class with the largest logit
        n_samples += labels.size(0)
        n_correct += (predicted == labels).sum().item()

    acc = 100.0 * n_correct / n_samples
    print(f'Accuracy of network on test set ({n_samples} images): {acc} %')

Accuracy of network on test set (10000 images): 97.32 %
