# Using a Custom Elman RNN for MNIST Classification (row and sequential)

In this notebook, we implement a **custom Elman RNN** to classify the **MNIST dataset** in both **row-wise** and **sequential** formats.

### Overview of the Implementation:
We define two key components:
1. `CustomRNNLayer` – A single-layer recurrent neural network (Elman RNN).
2. `RNNBackbone` – A full RNN-based model that stacks one or more `CustomRNNLayer` instances and adds a **final linear layer** to classify MNIST digits.


In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import numpy as np
import time
from tqdm import tqdm
import lightning as L

import os 
# Restrict CUDA to the physical GPU with index 1; inside this process it is
# then enumerated as device 0. NOTE(review): this only takes effect if it is
# set before CUDA is first initialised -- torch is imported above, so confirm
# nothing touches CUDA before this line.
os.environ['CUDA_VISIBLE_DEVICES'] = '1'
import sys 
# Add the sibling Flax directory to the module search path; presumably this is
# where the `utils` module imported later lives -- confirm.
sys.path.append('../Flax')


In [3]:
# Choose the compute device once. When a GPU is present, also report its name,
# the selected device, the current device index and the visible device count.
has_cuda = torch.cuda.is_available()
device = torch.device("cuda" if has_cuda else "cpu")

if not has_cuda:
    print("CUDA is not available")
else:
    print("CUDA is available")
    print(torch.cuda.get_device_name())
    print(device)
    print(torch.cuda.current_device())
    print(torch.cuda.device_count())


CUDA is available
NVIDIA GeForce RTX 4090
cuda
0
1


In [4]:
from utils import create_mnist_classification_dataset

In [5]:
# --- Tunable settings -------------------------------------------------------
BATCH_SIZE = 128
HIDDEN_SIZE = 64
LEARNING_RATE = 1e-3
EPOCHS = 10
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
DATASET_VERSION = "sequential"  # accepted values: "sequential" or "row"

# --- Data loaders plus dataset metadata ---------------------------------------
dataset_bundle = create_mnist_classification_dataset(
    bsz=BATCH_SIZE,
    version=DATASET_VERSION,
)
train_loader, val_loader, test_loader, n_classes, seq_length, in_dim = dataset_bundle

# --- 80/20 split of the training set ------------------------------------------
# NOTE(review): train_dataset / valid_dataset are not referenced by any cell
# visible here (training uses train_loader / val_loader directly) -- confirm
# whether this split is still needed.
full_train_set = train_loader.dataset
train_size = int(0.8 * len(full_train_set))
valid_size = len(full_train_set) - train_size
train_dataset, valid_dataset = torch.utils.data.random_split(
    full_train_set,
    [train_size, valid_size],
)

# --- Peek at one batch to check the tensor layout -----------------------------
batch_images, batch_labels = next(iter(train_loader))
print(batch_images.shape)
print(batch_labels.shape)


[*] Generating MNIST Classification Dataset...
(128, 784, 1)
(128,)


In [24]:
# Create model

class LightningRNNBackbone(L.LightningModule):
    """Sequence classifier: one tanh ``nn.RNN`` layer plus a linear readout.

    The recurrent layer runs over the whole input sequence; the class logits
    used for the loss are the readout of the hidden state at the last time
    step.
    """

    def __init__(self, input_size, hidden_size, output_size, criterion, batch_size,
                 learning_rate=1e-3):
        '''
        RNN backbone using 1 recurrent layer and 1 readout layer

        Args:
            input_size: number of features per time step.
            hidden_size: width of the recurrent hidden state.
            output_size: number of output units of the readout layer.
            criterion: callable ``criterion(logits, targets)`` returning the loss.
            batch_size: batch size reported to ``self.log`` for metric averaging.
            learning_rate: Adam step size used by ``configure_optimizers``.
                Defaults to 1e-3, the value that was previously hard-coded.
        '''
        super().__init__()
        self.batch_size = batch_size
        self.criterion = criterion
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.input_size = input_size
        self.learning_rate = learning_rate
        self.rnn_layer = nn.RNN(input_size, hidden_size, nonlinearity='tanh', bias=True, batch_first=True)
        self.W_out = nn.Linear(hidden_size, output_size, bias=True)
        # Shared keyword arguments for every self.log call in the step methods.
        self.logger_kwargs = {"batch_size": batch_size, "on_epoch": True, "on_step":True, "prog_bar": True}

    def forward(self, x):
        """Run the RNN over ``x`` and apply the readout at every time step.

        Args:
            x: input of shape [seq_len, input_size] or
                [batch_size, seq_len, input_size].

        Returns:
            Tuple ``(state_hist, out_hist)``: the hidden state at every time
            step and the readout applied to each of those states.
        """
        state_hist, _ = self.rnn_layer(x)
        out_hist = self.W_out(state_hist)
        return state_hist, out_hist

    def run_batch(self, batch):
        """Compute the loss for one ``(inputs, labels)`` batch.

        Args:
            batch: pair ``(x, y)``; ``x`` has shape [seq_len, input_size] or
                [batch_size, seq_len, input_size]. Both may be tensors or
                array-likes.

        Returns:
            ``self.criterion`` evaluated on the last-time-step logits and ``y``.
        """
        x, y = batch
        # as_tensor accepts arrays and tensors alike and skips the unconditional
        # copy (and the copy-construct warning) that torch.tensor(tensor) causes.
        x = torch.as_tensor(x, dtype=torch.float32, device=self.device)
        y = torch.as_tensor(y, dtype=torch.long, device=self.device)
        state_hist, _ = self.rnn_layer(x)
        # Only the last time step is classified, so the readout is applied to
        # that state alone rather than to the whole history. The ellipsis keeps
        # this valid for both batched and unbatched inputs.
        final_outputs = self.W_out(state_hist[..., -1, :])
        return self.criterion(final_outputs, y)

    def training_step(self, batch, batch_idx):
        """Compute, log (as ``train_loss``) and return the training loss."""
        loss = self.run_batch(batch)
        self.log("train_loss", loss, **self.logger_kwargs)
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute, log (as ``val_loss``) and return the validation loss."""
        loss = self.run_batch(batch)
        self.log("val_loss", loss, **self.logger_kwargs)
        return loss

    def test_step(self, batch, batch_idx):
        """Compute, log (as ``test_loss``) and return the test loss."""
        loss = self.run_batch(batch)
        self.log("test_loss", loss, **self.logger_kwargs)
        return loss

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters at ``self.learning_rate``."""
        optimizer = optim.Adam(self.parameters(), lr=self.learning_rate)
        return optimizer


In [25]:
# Build the classifier with explicit keyword arguments, move it to the selected
# device, and print where it ended up.
model = LightningRNNBackbone(
    input_size=in_dim,
    hidden_size=HIDDEN_SIZE,
    output_size=n_classes,
    criterion=nn.CrossEntropyLoss(),
    batch_size=BATCH_SIZE,
)
model.to(device)
print(model.device)


cuda:0


In [26]:
from lightning.pytorch.callbacks import TQDMProgressBar
from lightning.pytorch.loggers import TensorBoardLogger

# Log metrics to TensorBoard under logs/ with experiment name "rnn_experiment".
logger = TensorBoardLogger("logs/", name="rnn_experiment")

# Keep each epoch's finished progress bar on screen.
callbacks = [TQDMProgressBar(leave=True)]

# Short run for quick iteration: 2 epochs, at most 200 training batches each.
trainer = L.Trainer(
    max_epochs=2,
    limit_train_batches=200,
    log_every_n_steps=50,
    enable_progress_bar=True,
    callbacks=callbacks,
    logger=logger,
)
trainer.fit(
    model=model,
    train_dataloaders=train_loader,
    val_dataloaders=val_loader,
)


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [1]

  | Name      | Type             | Params | Mode 
-------------------------------------------------------
0 | criterion | CrossEntropyLoss | 0      | train
1 | rnn_layer | RNN              | 4.3 K  | train
2 | W_out     | Linear           | 650    | train
-------------------------------------------------------
4.9 K     Trainable params
0         Non-trainable params
4.9 K     Total params
0.020     Total estimated model params size (MB)
3         Modules in train mode
0         Modules in eval mode


Sanity Checking DataLoader 0:  50%|█████     | 1/2 [00:00<00:00, 303.06it/s]

Training: |          | 0/? [00:00<?, ?it/s]                                 
Epoch 0: 100%|██████████| 200/200 [00:04<00:00, 41.29it/s, v_num=6, train_loss_step=2.310, val_loss_step=2.290, val_loss_epoch=2.300, train_loss_epoch=2.300]
Epoch 1: 100%|██████████| 200/200 [00:04<00:00, 40.94it/s, v_num=6, train_loss_step=2.310, val_loss_step=2.290, val_loss_epoch=2.300, train_loss_epoch=2.300]

`Trainer.fit` stopped: `max_epochs=2` reached.





In [21]:
# Evaluate the trained model on the held-out test loader; the returned list of
# metric dicts is the cell's displayed output.
trainer.test(model=model, dataloaders=test_loader)


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [1]
/home/tristan/miniconda3/envs/.jax_conda_env_LearningJAX/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=127` in the `DataLoader` to improve performance.


Testing DataLoader 0: 100%|██████████| 78/78 [00:01<00:00, 74.18it/s]
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
     test_loss_epoch        2.2979655265808105
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


[{'test_loss_epoch': 2.2979655265808105}]

In [9]:
!pip3 uninstall ipywidgets -y

  pid, fd = os.forkpty()


Found existing installation: ipywidgets 8.1.5
Uninstalling ipywidgets-8.1.5:
  Successfully uninstalled ipywidgets-8.1.5


In [1]:
!pip3 install tensorboard

Collecting tensorboard
  Downloading tensorboard-2.19.0-py3-none-any.whl.metadata (1.8 kB)
Collecting grpcio>=1.48.2 (from tensorboard)
  Downloading grpcio-1.70.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (3.9 kB)
Collecting markdown>=2.6.8 (from tensorboard)
  Using cached Markdown-3.7-py3-none-any.whl.metadata (7.0 kB)
Collecting protobuf!=4.24.0,>=3.19.6 (from tensorboard)
  Downloading protobuf-5.29.3-cp38-abi3-manylinux2014_x86_64.whl.metadata (592 bytes)
Collecting tensorboard-data-server<0.8.0,>=0.7.0 (from tensorboard)
  Using cached tensorboard_data_server-0.7.2-py3-none-manylinux_2_31_x86_64.whl.metadata (1.1 kB)
Collecting werkzeug>=1.0.1 (from tensorboard)
  Using cached werkzeug-3.1.3-py3-none-any.whl.metadata (3.7 kB)
Downloading tensorboard-2.19.0-py3-none-any.whl (5.5 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m5.5/5.5 MB[0m [31m8.0 MB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
[?25hDownloading grpcio-1.70.0-cp312-cp