# Using a Custom Elman RNN for MNIST Classification (row and sequential)

In this notebook, we implement a **custom Elman RNN** to classify the **MNIST dataset** in both **row-wise** and **sequential** formats.

### Overview of the Implementation:
We define two key components:
1. ````CustomRNNLayer```` – A single-layer recurrent neural network (Elman RNN).
2. ````RNNBackbone```` – A full RNN-based model that stacks one or more `CustomRNNLayer` instances and adds a ```final linear layer``` to classify MNIST digits.


In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import numpy as np
import time
from tqdm import tqdm
import os 
os.environ['CUDA_VISIBLE_DEVICES'] = '1'
import sys 
sys.path.append('../Flax')


In [2]:
if torch.cuda.is_available():
    print("CUDA is available")
    print(torch.cuda.get_device_name())
    device = torch.device("cuda")
    print(device)
    print(torch.cuda.current_device())
    print(torch.cuda.device_count())

else:
    print("CUDA is not available")
    device = torch.device("cpu")


CUDA is available
NVIDIA GeForce RTX 4090
cuda
0
1


In [3]:
from utils import create_mnist_classification_dataset

In [4]:
# Hyperparameters
BATCH_SIZE = 128
HIDDEN_SIZE = 64
LEARNING_RATE = 0.001
EPOCHS = 10
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
DATASET_VERSION = "row" # "sequential" or "row"

# Create dataset
train_loader, val_loader, test_loader, n_classes, seq_length, in_dim = create_mnist_classification_dataset(
    bsz=BATCH_SIZE, version=DATASET_VERSION
)


[*] Generating MNIST Classification Dataset...


In [5]:
class CustomRNNLayer(nn.Module):
    '''
    Custom RNN module with tanh activation.
    This layer is run for each time step in the sequence.
    It handles batched input the same way as nn.Linear does:
    - x: [batch_size, seq_len, input_size]
    - we access x_t for each time step t: x_t = x[:, t, :]
    - It computes x_new = x_t @ W.T + b (batch_size, hidden_size) for each time step.
    - W: [input_size, hidden_size]
    '''
    def __init__(self, input_size, hidden_size, output_size=10):
        super().__init__()
        self.hidden_size = hidden_size
        self.input_size = input_size
        self.output_size = output_size
        
        # Single weight matrices for the entire RNN
        self.W_ih = nn.Linear(input_size, hidden_size, bias=True)
        self.W_hh = nn.Linear(hidden_size, hidden_size, bias=False)
        # self.W_out = nn.Linear(hidden_size, output_size, bias=True)
    
    def forward(self, x):
        """
        Processes a batch of sequences through the RNN.
        
        Args:
            x: Input tensor of shape [batch_size, seq_len, input_size]
        
        Returns:
            tuple: (state_history, output_history)
                - state_history: Tensor of shape [batch_size, seq_len, hidden_size]
                - output_history: Tensor of shape [batch_size, seq_len, output_size]
        """
        batch_size, seq_len, _ = x.shape
        device = x.device
        
        # Initialize hidden state
        h = torch.zeros(batch_size, self.hidden_size, device=device)
        
        # Prepare tensors to store history
        state_history = torch.zeros(batch_size, seq_len, self.hidden_size, device=device)
        # output_history = torch.zeros(batch_size, seq_len, self.output_size, device=device)
        
        # Efficiently process the sequence using vectorized operations
        for t in range(seq_len):
            # Extract current input across all batches
            x_t = x[:, t, :]
            
            # Combined operation: h = tanh(W_ih @ x_t + W_hh @ h)
            h = torch.tanh(self.W_ih(x_t) + self.W_hh(h))
            
            # Calculate output
            # y = torch.tanh(self.W_out(h))
            
            # Store states and outputs
            state_history[:, t, :] = h
            # output_history[:, t, :] = y
        
        return state_history #, output_history
    
# Create model
class RNNBackbone(nn.Module):
    '''
    RNN backbone using 1 recurrent layer and 1 readout layer
    '''
    def __init__(self, input_size, hidden_size, output_size=10):
        super().__init__()
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.input_size = input_size
        self.rnn_layer = CustomRNNLayer(input_size, hidden_size, output_size)
        self.W_out = nn.Linear(hidden_size, output_size, bias=True)

    
    def forward(self, x):
        # x shape: [seq_len, input_size] or [batch_size, seq_len, input_size]
        
        # state_hist, out_hist = self.rnn_layer(x)
        state_hist = self.rnn_layer(x)
        out_hist = self.W_out(state_hist)
        return state_hist, out_hist

In [6]:
def train(model, train_loader, optimizer, criterion, device='cuda'):
    model.train()
    total_loss = 0
    correct = 0
    total = 0
    
    for batch_idx, (inputs, targets) in enumerate(tqdm(train_loader)):
        # Convert numpy arrays to PyTorch tensors
        inputs = torch.tensor(inputs, dtype=torch.float32).to(device)
        targets = torch.tensor(targets, dtype=torch.long).to(device)
        
        optimizer.zero_grad()
        
        # Forward pass
        _, outputs = model(inputs)
        # Use the last output for classification
        final_outputs = outputs[:, -1, :]
        
        loss = criterion(final_outputs, targets)
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
        
        # Calculate accuracy
        _, predicted = final_outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()
    
    accuracy = 100.0 * correct / total
    average_loss = total_loss / len(train_loader)
    
    return average_loss, accuracy

def validate(model, val_loader, criterion, device='cuda'):
    model.eval()
    total_loss = 0
    correct = 0
    total = 0
    
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(val_loader):
            # Convert numpy arrays to PyTorch tensors
            inputs = torch.tensor(inputs, dtype=torch.float32).to(device)
            targets = torch.tensor(targets, dtype=torch.long).to(device)
            
            # Forward pass
            _, outputs = model(inputs)
            # Use the last output for classification
            final_outputs = outputs[:, -1, :]
            
            loss = criterion(final_outputs, targets)
            total_loss += loss.item()
            
            # Calculate accuracy
            _, predicted = final_outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
    
    accuracy = 100.0 * correct / total
    average_loss = total_loss / len(val_loader)
    
    return average_loss, accuracy

def test(model, test_loader, device='cuda'):
    model.eval()
    correct = 0
    total = 0
    
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(test_loader):
            # Convert numpy arrays to PyTorch tensors
            inputs = torch.tensor(inputs, dtype=torch.float32).to(device)
            targets = torch.tensor(targets, dtype=torch.long).to(device)
            
            # Forward pass
            _, outputs = model(inputs)
            # Use the last output for classification
            final_outputs = outputs[:, -1, :]
            
            # Calculate accuracy
            _, predicted = final_outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
    
    accuracy = 100.0 * correct / total
    
    return accuracy


In [7]:
# Create model
model = RNNBackbone(in_dim, HIDDEN_SIZE, n_classes).to(DEVICE)
# model = CustomRNNLayer(in_dim, HIDDEN_SIZE, n_classes).to(DEVICE)

# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

# Training loop
for epoch in range(1, EPOCHS + 1):
    start_time = time.time()
    
    # Training
    train_loss, train_acc = train(model, train_loader, optimizer, criterion, DEVICE)
    
    # Validation (using test set as validation in this case)
    val_loss, val_acc = validate(model, test_loader, criterion, DEVICE)
    
    # Calculate epoch time
    epoch_time = time.time() - start_time
    
    # Print statistics
    print(f'Epoch {epoch}/{EPOCHS}:')
    print(f'Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}%')
    print(f'Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.2f}%')
    print(f'Time: {epoch_time:.2f}s')
    print('-' * 60)

# Final test
test_acc = test(model, test_loader, DEVICE)
print(f'Test Accuracy: {test_acc:.2f}%')


100%|██████████| 468/468 [00:12<00:00, 37.18it/s]


Epoch 1/10:
Train Loss: 1.1373 | Train Acc: 61.61%
Val Loss: 0.6638 | Val Acc: 77.86%
Time: 13.80s
------------------------------------------------------------


100%|██████████| 468/468 [00:11<00:00, 41.88it/s]


Epoch 2/10:
Train Loss: 0.5846 | Train Acc: 81.23%
Val Loss: 0.4678 | Val Acc: 86.10%
Time: 12.38s
------------------------------------------------------------


 29%|██▊       | 134/468 [00:02<00:07, 46.59it/s]


KeyboardInterrupt: 

: 