In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import numpy as np
import matplotlib.pyplot as plt
import unicodedata
import string
import os
from tqdm.auto import tqdm
import io
import glob
import random

In [None]:
# Data dimensions: number of features per time step and number of time steps
# in each sequence handed to the recurrent layers.
sequence_length = 28
input_size = 28

# Network architecture: width and depth of the recurrent stack, and the size
# of the final linear layer's output.
hidden_size = 256
num_layers = 2
num_classes = 10

# Optimisation settings. Not referenced by the cells visible in this section;
# presumably consumed by the data loader / training loop further down.
batch_size = 32
num_of_epochs = 5
learning_rate = 0.001

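* Input data shape: `batch_size × color_channels × height × width`.
* The recurrent models below are built with `batch_first=True`, so they expect `batch_size × sequence_length × input_size`; the channel dimension has to be removed before the forward pass.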

# Creating RNN Models

## RNN Version 1

In [78]:
class RNNversion1(nn.Module):
    """Stacked ``nn.RNN`` classifier that reads only the last time step.

    The recurrent stack processes the whole sequence; the output of the top
    layer at the final time step is passed through a single linear layer.

    Args:
        input_size: Number of features per time step.
        hidden_size: Width of each recurrent layer.
        num_layers: Number of stacked recurrent layers.
        num_classes: Size of the linear layer's output.
    """

    def __init__(self, input_size, hidden_size, num_layers, num_classes):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers

        self.rnn = nn.RNN(input_size, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, num_classes)

    def forward(self, x: torch.Tensor):
        """Run the classifier on a batch of sequences.

        Args:
            x: Tensor of shape ``(batch, seq_len, input_size)`` (the RNN is
                built with ``batch_first=True``).

        Returns:
            Tensor of shape ``(batch, num_classes)``.
        """
        # new_zeros inherits device and dtype from x, so the model keeps
        # working after .to("cuda") / .double() instead of mixing a CPU
        # float32 hidden state with inputs that live elsewhere.
        initial_hidden_state = x.new_zeros(self.num_layers, x.size(0), self.hidden_size)
        x, _ = self.rnn(x, initial_hidden_state)
        # Keep only the top layer's output at the final time step.
        x = x[:, -1, :]
        x = self.fc(x)
        return x

## RNN Version 2

In [80]:
class RNNversion2(nn.Module):
    """Stacked ``nn.RNN`` classifier that reads every time step.

    The top-layer outputs of all time steps are concatenated and passed
    through a single linear layer, so the model is tied to one fixed
    sequence length.

    Args:
        input_size: Number of features per time step.
        hidden_size: Width of each recurrent layer.
        num_layers: Number of stacked recurrent layers.
        num_classes: Size of the linear layer's output.
        seq_len: Number of time steps the linear layer is sized for. When
            omitted, the notebook-level ``sequence_length`` in effect at
            construction time is used (the original behaviour).
    """

    def __init__(self, input_size, hidden_size, num_layers, num_classes, seq_len=None):
        super().__init__()
        if seq_len is None:
            # Backward-compatible fallback to the notebook global. Passing
            # seq_len explicitly avoids depending on whichever cell last
            # rebound `sequence_length`.
            seq_len = sequence_length
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.sequence_length = seq_len

        self.rnn = nn.RNN(input_size, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size * seq_len, num_classes)

    def forward(self, x: torch.Tensor):
        """Run the classifier on a batch of sequences.

        Args:
            x: Tensor of shape ``(batch, seq_len, input_size)``; ``seq_len``
                must match the length the linear layer was sized for.

        Returns:
            Tensor of shape ``(batch, num_classes)``.
        """
        # new_zeros inherits device and dtype from x, so the hidden state
        # always matches the input instead of being a CPU float32 tensor.
        initial_hidden_state = x.new_zeros(self.num_layers, x.size(0), self.hidden_size)
        x, _ = self.rnn(x, initial_hidden_state)
        # Flatten (batch, seq_len, hidden) -> (batch, seq_len * hidden).
        x = x.reshape(x.shape[0], -1)
        x = self.fc(x)
        return x

## LSTM

In [None]:
class LSTM(nn.Module):
    """Stacked ``nn.LSTM`` classifier with a two-layer head over all time steps.

    The top-layer outputs of every time step are flattened and passed through
    ``Linear -> ReLU -> Linear``, so the model is tied to one fixed sequence
    length.

    Args:
        input_size: Number of features per time step.
        hidden_size: Width of each LSTM layer and of the head's hidden layer.
        num_layers: Number of stacked LSTM layers.
        num_classes: Size of the final linear layer's output.
        seq_len: Number of time steps the head is sized for. When omitted,
            the notebook-level ``sequence_length`` in effect at construction
            time is used (the original behaviour).
    """

    def __init__(self, input_size, hidden_size, num_layers, num_classes, seq_len=None):
        super().__init__()
        if seq_len is None:
            # Backward-compatible fallback to the notebook global. Passing
            # seq_len explicitly avoids depending on whichever cell last
            # rebound `sequence_length`.
            seq_len = sequence_length
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.sequence_length = seq_len

        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Sequential(
            nn.Flatten(start_dim=1),
            nn.Linear(hidden_size * seq_len, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, num_classes),
        )

    def forward(self, x: torch.Tensor):
        """Run the classifier on a batch of sequences.

        Args:
            x: Tensor of shape ``(batch, seq_len, input_size)``; ``seq_len``
                must match the length the head was sized for.

        Returns:
            Tensor of shape ``(batch, num_classes)``.
        """
        # new_zeros inherits device and dtype from x, so both states always
        # match the input instead of being CPU float32 tensors.
        state_shape = (self.num_layers, x.size(0), self.hidden_size)
        initial_hidden_state = x.new_zeros(state_shape)
        initial_cell_state = x.new_zeros(state_shape)
        x, _ = self.lstm(x, (initial_hidden_state, initial_cell_state))
        # The head starts with nn.Flatten, so it takes the full
        # (batch, seq_len, hidden) output directly.
        x = self.fc(x)
        return x

In [83]:
# Smoke test: push a small random batch through RNNversion1 and LSTM and
# check that both produce (batch_size, num_classes) outputs.
#
# NOTE: these assignments rebind the notebook-level hyperparameters defined in
# the configuration cell. RNNversion2 and LSTM size their linear layers from
# the global `sequence_length`, so it has to be set before they are built.
input_size = 2
hidden_size = 100
num_layers = 20
num_classes = 50
sequence_length = 10
batch_size = 4

# Seed the RNG so the dummy data, the initial weights, and therefore the
# printed output are the same on every run.
torch.manual_seed(0)

# Input tensor of shape (batch_size, sequence_length, input_size)
dummy_input = torch.randn(batch_size, sequence_length, input_size)

# Instantiate the models
model1 = RNNversion1(input_size, hidden_size, num_layers, num_classes)
model2 = LSTM(input_size, hidden_size, num_layers, num_classes)

# Forward pass only: no gradients are needed for a shape check.
with torch.no_grad():
    output1 = model1(dummy_input)
    output2 = model2(dummy_input)

expected_shape = (batch_size, num_classes)
assert output1.shape == expected_shape, f"RNNversion1 output shape {tuple(output1.shape)}"
assert output2.shape == expected_shape, f"LSTM output shape {tuple(output2.shape)}"

# NOTE(review): in the saved output the first two rows are identical to four
# decimals, which suggests the 20-layer stack at default initialisation is
# nearly input-independent. Confirm that this depth is intended.
print("Output shape:", output2.shape)
print("Output:", output2)

Output shape: torch.Size([4, 50])
Output: tensor([[ 0.0901,  0.1010,  0.0744,  0.0105,  0.1101,  0.1084, -0.0263, -0.0751,
         -0.0860,  0.0226,  0.0301, -0.0400, -0.0561, -0.0916, -0.0444, -0.0637,
          0.0251,  0.0408,  0.0745, -0.0650,  0.0215,  0.0650, -0.0265,  0.0188,
          0.0634,  0.0627,  0.0173,  0.0527,  0.0402, -0.0720,  0.0639, -0.0492,
          0.0640, -0.0485, -0.0953,  0.0530,  0.0525,  0.0348,  0.0398,  0.0923,
         -0.0609,  0.0202,  0.0518, -0.0108,  0.0164, -0.0496,  0.0081,  0.0815,
         -0.0600, -0.0083],
        [ 0.0901,  0.1010,  0.0744,  0.0105,  0.1101,  0.1084, -0.0263, -0.0751,
         -0.0860,  0.0226,  0.0301, -0.0400, -0.0561, -0.0916, -0.0444, -0.0637,
          0.0251,  0.0408,  0.0745, -0.0650,  0.0215,  0.0650, -0.0265,  0.0188,
          0.0634,  0.0627,  0.0173,  0.0527,  0.0402, -0.0720,  0.0639, -0.0492,
          0.0640, -0.0485, -0.0953,  0.0530,  0.0525,  0.0348,  0.0398,  0.0923,
         -0.0609,  0.0202,  0.0518, -0.