# Using a Custom Elman RNN for MNIST Classification (row and sequential)

In this notebook, we implement a **custom Elman RNN** to classify the **MNIST dataset** in both **row-wise** and **sequential** formats.

### Overview of the Implementation:
We define two key components:
1. `CustomRNN` – A single-layer recurrent neural network (Elman RNN).
2. `RNN` – The full model, which wraps a `CustomRNN` layer and adds a final linear layer to classify MNIST digits.

---

## Optimizing Computation for Efficiency

One common bottleneck in RNN implementations is the input projection, a matrix multiplication that is repeated at every time step inside the loop. To optimize this, we leverage the following observation:

### Key Idea:
When multiplying a matrix ```W``` of shape ```(input_size, hidden_size)``` with an input tensor ```x```, each row of ```x``` undergoes the same transformation. Instead of computing ```x_t @ W``` separately for each time step ```t```, we can **precompute the entire transformation** outside the loop.

For an input ```x``` of shape ```(batch_size, seq_len, input_size)```, conceptually, we can think of it as a collection of individual ```(input_size,)``` vectors, each of which gets multiplied by ```W```. Instead of performing these multiplications step by step inside the loop, we **batch the operation** in one go:

$$
\tilde{x} = x W
$$

This provides the transformed values for all time steps at once, which we can then slice efficiently within the loop.

### What This Means in Practice:
Instead of performing ```x_t @ W``` at every time step ```t```, we:
1. **Precompute** the full transformation ```x_tilde = x @ W``` before entering the loop.
2. **Slice** the entry for time step ```t``` inside the loop (```x_tilde[:, t, :]```).

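The arithmetic is unchanged, but it now runs as one large batched matrix multiplication instead of ```seq_len``` small ones, which typically makes better use of the hardware and removes per-step overhead from the loop. The recurrent product with the hidden state still has to be computed step by step, because each step depends on the previous one.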
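
The two steps above can be checked on a small random tensor. This is a minimal sketch, not the notebook's layer: the sizes and the names `per_step` and `same` are made up for illustration, and `W` is a plain tensor of shape ```(input_size, hidden_size)``` as in the text.

```python
import torch

torch.manual_seed(0)

# Small, arbitrary sizes chosen only for this illustration
batch_size, seq_len, input_size, hidden_size = 4, 5, 3, 2
x = torch.randn(batch_size, seq_len, input_size)
W = torch.randn(input_size, hidden_size)

# Per-step version: one small matmul per time step, stacked along the time axis
per_step = torch.stack([x[:, t, :] @ W for t in range(seq_len)], dim=1)

# Precomputed version: a single batched matmul over all time steps
x_tilde = x @ W

# Both routes should agree up to floating-point rounding
same = torch.allclose(per_step, x_tilde, atol=1e-5)
print(x_tilde.shape)  # torch.Size([4, 5, 2])
print(same)
```

Slicing ```x_tilde[:, t, :]``` inside the loop therefore gives the same values as computing ```x_t @ W``` there.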

### Code Comparison:
- **Previous (Less Efficient) Approach**:  
  Each time step computes ```x_t @ W``` inside the loop.
- **Optimized Approach**:
  We precompute ```x_tilde = x @ W``` once and slice the result in the loop.

This simple change leads to a **more efficient** implementation without altering the core behavior of the RNN.
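
For reference, the recurrence can be written out as below. This is a sketch in the notation of this section, where $W$ is the input weight matrix of shape ```(input_size, hidden_size)```. The symbols $U$ (recurrent weight matrix of shape ```(hidden_size, hidden_size)```) and $b$ (bias) are introduced here for illustration and are not names used in the notebook.

$$
\begin{aligned}
\tilde{x}_t &= x_t W \\
h_t &= \tanh\left(\tilde{x}_t + h_{t-1} U + b\right), \qquad h_0 = 0
\end{aligned}
$$

Only the term $h_{t-1} U$ depends on the previous step, so $\tilde{x}_t$ can be computed for all $t$ before the loop starts. In the code, the bias is part of the input projection (an `nn.Linear` with `bias=True`), so it is folded into the precomputed term.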


In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import numpy as np
import time
from tqdm import tqdm
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
import sys 
sys.path.append('../Flax')


In [2]:
if torch.cuda.is_available():
    print("CUDA is available")
    print(torch.cuda.get_device_name())
    device = torch.device("cuda")
    print(device)
    print(torch.cuda.current_device())
    print(torch.cuda.device_count())

else:
    print("CUDA is not available")
    device = torch.device("cpu")


CUDA is available
NVIDIA GeForce RTX 4090
cuda
0
1


In [3]:
from utils import create_mnist_classification_dataset

## Illustration of the matrix multiplication optimization:

In [4]:
a = torch.arange(0, 27).reshape(3,3,3)
# change the dtype of the tensor
a = a.type(torch.float32)
print(a)
print(a.shape)

tensor([[[ 0.,  1.,  2.],
         [ 3.,  4.,  5.],
         [ 6.,  7.,  8.]],

        [[ 9., 10., 11.],
         [12., 13., 14.],
         [15., 16., 17.]],

        [[18., 19., 20.],
         [21., 22., 23.],
         [24., 25., 26.]]])
torch.Size([3, 3, 3])


In [None]:
b = torch.ones(3,3) * torch.tensor([1,2,3])
print(b)
print(b.shape)

tensor([[1., 2., 3.],
        [1., 2., 3.],
        [1., 2., 3.]])
torch.Size([3, 3])


In [6]:
print(a @ b)

tensor([[[  3.,   6.,   9.],
         [ 12.,  24.,  36.],
         [ 21.,  42.,  63.]],

        [[ 30.,  60.,  90.],
         [ 39.,  78., 117.],
         [ 48.,  96., 144.]],

        [[ 57., 114., 171.],
         [ 66., 132., 198.],
         [ 75., 150., 225.]]])
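
The same tensors can be used to confirm the slicing claim directly. The following sketch rebuilds `a` and `b` from the cells above and compares one slice of the batched product with the product of the slice; the names `full`, `step`, and `matches` are introduced here for illustration.

```python
import torch

# Same tensors as in the cells above
a = torch.arange(0, 27, dtype=torch.float32).reshape(3, 3, 3)
b = torch.ones(3, 3) * torch.tensor([1.0, 2.0, 3.0])

# Batched product over every position at once
full = a @ b

# Product for a single position t along dimension 1
t = 1
step = a[:, t, :] @ b

# The slice of the batched product equals the per-step product
matches = torch.equal(full[:, t, :], step)
print(step)
print(matches)  # True
```

The values are small integers stored as floats, so the comparison is exact here. With general floating-point data, `torch.allclose` is the safer check.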


## New Optimized Approach:

In [7]:
# Hyperparameters
BATCH_SIZE = 128
HIDDEN_SIZE = 64
LEARNING_RATE = 0.001
EPOCHS = 10
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
DATASET_VERSION = "row" # "sequential" or "row"

# Create dataset
train_loader, val_loader, test_loader, n_classes, seq_length, in_dim = create_mnist_classification_dataset(
    bsz=BATCH_SIZE, version=DATASET_VERSION
)


[*] Generating MNIST Classification Dataset...


In [8]:
class CustomRNN(nn.Module):
    '''
    Custom Elman RNN layer with tanh activation.
    It handles batched input the same way as nn.Linear does:
    - x: [batch_size, seq_len, input_size]
    - The input projection x_tilde = x @ W.T + b is computed once for all time steps,
      giving [batch_size, seq_len, hidden_size].
    - Inside the loop we slice x_t = x_tilde[:, t, :] and update the hidden state.
    - W (the nn.Linear weight): [hidden_size, input_size]
    '''
    def __init__(self, input_size, hidden_size, output_size=10):
        super().__init__()
        self.hidden_size = hidden_size
        self.input_size = input_size
        self.output_size = output_size
        
        # Single weight matrices for the entire RNN
        self.W_ih = nn.Linear(input_size, hidden_size, bias=True)
        self.W_hh = nn.Linear(hidden_size, hidden_size, bias=False)
        # self.W_out = nn.Linear(hidden_size, output_size, bias=True)
    
    def forward(self, x):
        """
        Processes a batch of sequences through the RNN.
        
        Args:
            x: Input tensor of shape [batch_size, seq_len, input_size]
        
        Returns:
            state_history: Tensor of shape [batch_size, seq_len, hidden_size]
                containing the hidden state at every time step.
        """
        batch_size, seq_len, _ = x.shape
        device = x.device
        
        # Initialize hidden state
        h = torch.zeros(batch_size, self.hidden_size, device=device)
        
        # Prepare tensors to store history
        state_history = torch.zeros(batch_size, seq_len, self.hidden_size, device=device)
        # output_history = torch.zeros(batch_size, seq_len, self.output_size, device=device)
        
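        # Precompute the input projection for all time steps in one batched matmul
        x_tilde = self.W_ih(x)
        # The recurrence itself is sequential: step t depends on h from step t-1
        for t in range(seq_len):
            # Slice the precomputed projection for time step t: [batch_size, hidden_size]
            x_t = x_tilde[:, t, :]
            
            # h_t = tanh(W_ih x_t + b_ih + W_hh h_{t-1}); the first two terms are already in x_t
            h = torch.tanh(x_t + self.W_hh(h))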
            
            # Calculate output
            # y = torch.tanh(self.W_out(h))
            
            # Store states and outputs
            state_history[:, t, :] = h
            # output_history[:, t, :] = y
        
        return state_history #, output_history
    
# Full model: recurrent layer followed by a linear readout
class RNN(nn.Module):
    '''
    RNN backbone using 1 recurrent layer and 1 readout layer
    '''
    def __init__(self, input_size, hidden_size, output_size=10):
        super().__init__()
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.input_size = input_size
        self.rnn_layer = CustomRNN(input_size, hidden_size, output_size)
        self.W_out = nn.Linear(hidden_size, output_size, bias=True)

    
    def forward(self, x):
        # x shape: [batch_size, seq_len, input_size] (the recurrent layer expects batched input)
        
        # state_hist, out_hist = self.rnn_layer(x)
        state_hist = self.rnn_layer(x)
        out_hist = self.W_out(state_hist)
        return state_hist, out_hist
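
One way to sanity-check a layer like this is to compare it against PyTorch's built-in `nn.RNN` with copied weights. The sketch below is self-contained and uses a compact stand-in class, `TinyElman` (a hypothetical name, not part of the notebook), that follows the same recurrence. It assumes row-wise MNIST input, that is 28 steps of 28 features. Since the custom layer has no bias on the recurrent term, the reference's `bias_hh_l0` is set to zero.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)


class TinyElman(nn.Module):
    """Compact stand-in for the custom layer (hypothetical name)."""

    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.hidden_size = hidden_size
        self.W_ih = nn.Linear(input_size, hidden_size, bias=True)
        self.W_hh = nn.Linear(hidden_size, hidden_size, bias=False)

    def forward(self, x):
        batch_size, seq_len, _ = x.shape
        h = x.new_zeros(batch_size, self.hidden_size)
        x_tilde = self.W_ih(x)  # input projection for all steps at once
        states = []
        for t in range(seq_len):
            h = torch.tanh(x_tilde[:, t, :] + self.W_hh(h))
            states.append(h)
        return torch.stack(states, dim=1)


# Assumed sizes: 28 features per row, 64 hidden units
input_size, hidden_size = 28, 64
custom = TinyElman(input_size, hidden_size)
reference = nn.RNN(input_size, hidden_size, nonlinearity="tanh", batch_first=True)

# Copy the custom weights into the reference layer
with torch.no_grad():
    reference.weight_ih_l0.copy_(custom.W_ih.weight)
    reference.bias_ih_l0.copy_(custom.W_ih.bias)
    reference.weight_hh_l0.copy_(custom.W_hh.weight)
    reference.bias_hh_l0.zero_()  # the custom layer has no recurrent bias

x = torch.randn(8, 28, input_size)
with torch.no_grad():
    ours = custom(x)
    theirs, _ = reference(x)

# Largest elementwise difference; expected to be at rounding-error level
max_diff = (ours - theirs).abs().max().item()
print(ours.shape, max_diff)
```

If the two outputs differ by more than rounding error, the usual suspects are a transposed weight or a bias that is counted twice.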

In [9]:
def train(model, train_loader, optimizer, criterion, device='cuda'):
    model.train()
    total_loss = 0
    correct = 0
    total = 0
    
    for batch_idx, (inputs, targets) in enumerate(tqdm(train_loader)):
        # Convert numpy arrays to PyTorch tensors
        inputs = torch.tensor(inputs, dtype=torch.float32).to(device)
        targets = torch.tensor(targets, dtype=torch.long).to(device)
        
        optimizer.zero_grad()
        
        # Forward pass
        _, outputs = model(inputs)
        # Use the last output for classification
        final_outputs = outputs[:, -1, :]
        
        loss = criterion(final_outputs, targets)
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
        
        # Calculate accuracy
        _, predicted = final_outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()
    
    accuracy = 100.0 * correct / total
    average_loss = total_loss / len(train_loader)
    
    return average_loss, accuracy

def validate(model, val_loader, criterion, device='cuda'):
    model.eval()
    total_loss = 0
    correct = 0
    total = 0
    
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(val_loader):
            # Convert numpy arrays to PyTorch tensors
            inputs = torch.tensor(inputs, dtype=torch.float32).to(device)
            targets = torch.tensor(targets, dtype=torch.long).to(device)
            
            # Forward pass
            _, outputs = model(inputs)
            # Use the last output for classification
            final_outputs = outputs[:, -1, :]
            
            loss = criterion(final_outputs, targets)
            total_loss += loss.item()
            
            # Calculate accuracy
            _, predicted = final_outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
    
    accuracy = 100.0 * correct / total
    average_loss = total_loss / len(val_loader)
    
    return average_loss, accuracy

def test(model, test_loader, device='cuda'):
    model.eval()
    correct = 0
    total = 0
    
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(test_loader):
            # Convert numpy arrays to PyTorch tensors
            inputs = torch.tensor(inputs, dtype=torch.float32).to(device)
            targets = torch.tensor(targets, dtype=torch.long).to(device)
            
            # Forward pass
            _, outputs = model(inputs)
            # Use the last output for classification
            final_outputs = outputs[:, -1, :]
            
            # Calculate accuracy
            _, predicted = final_outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
    
    accuracy = 100.0 * correct / total
    
    return accuracy
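
All three functions classify from the last time step and compute accuracy the same way. The sketch below walks through that readout on hand-made logits. The tensor values and the name `intended` are invented for illustration, so no model or dataset is needed.

```python
import torch

batch_size, seq_len, n_classes = 4, 28, 10

# Fake per-step logits: all zeros, except a single 1.0 at the last time step
# placed on the class we want each sample to be predicted as
outputs = torch.zeros(batch_size, seq_len, n_classes)
intended = [3, 1, 0, 1]
for i, c in enumerate(intended):
    outputs[i, -1, c] = 1.0

targets = torch.tensor([3, 1, 4, 1])

# Use the last output for classification: [batch_size, n_classes]
final_outputs = outputs[:, -1, :]

# Same accuracy computation as in train/validate/test
_, predicted = final_outputs.max(1)
correct = predicted.eq(targets).sum().item()
accuracy = 100.0 * correct / targets.size(0)

print(predicted.tolist())  # [3, 1, 0, 1]
print(accuracy)            # 75.0
```

Three of the four predictions match their targets, which gives 75%. Only the final hidden state contributes to the loss, so the gradient has to flow back through every step of the sequence.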


In [None]:
# Create model
model = RNN(in_dim, HIDDEN_SIZE, n_classes).to(DEVICE)
# model = CustomRNN(in_dim, HIDDEN_SIZE, n_classes).to(DEVICE)
# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

# Training loop
for epoch in range(1, EPOCHS + 1):
    start_time = time.time()
    
    # Training
    train_loss, train_acc = train(model, train_loader, optimizer, criterion, DEVICE)
    
    # Validation (using test set as validation in this case)
    val_loss, val_acc = validate(model, test_loader, criterion, DEVICE)
    
    # Calculate epoch time
    epoch_time = time.time() - start_time
    
    # Print statistics
    print(f'Epoch {epoch}/{EPOCHS}:')
    print(f'Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}%')
    print(f'Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.2f}%')
    print(f'Time: {epoch_time:.2f}s')
    print('-' * 60)
        
# Final test
test_acc = test(model, test_loader, DEVICE)
print(f'Test Accuracy: {test_acc:.2f}%')


100%|██████████| 468/468 [00:12<00:00, 37.27it/s]


Epoch 1/10:
Train Loss: 1.2272 | Train Acc: 57.12%
Val Loss: 0.7780 | Val Acc: 74.83%
Time: 13.78s
------------------------------------------------------------


100%|██████████| 468/468 [00:11<00:00, 41.72it/s]


Epoch 2/10:
Train Loss: 0.6493 | Train Acc: 79.77%
Val Loss: 0.5507 | Val Acc: 83.94%
Time: 12.45s
------------------------------------------------------------


 93%|█████████▎| 436/468 [00:09<00:00, 45.20it/s]


KeyboardInterrupt: 
