In [None]:
import torch
from torch.utils.data import DataLoader, TensorDataset
from torchvision import datasets, transforms

# Each 28x28 MNIST image becomes a sequence of 28 steps, each a 28-dim vector
# (seq_len = 28, input_size = 28). ToTensor yields float images scaled to [0, 1].
DATA_ROOT = 'data'  # download/cache directory for torchvision's MNIST files

transform = transforms.Compose([transforms.ToTensor()])
mnist_kwargs = dict(root=DATA_ROOT, download=True, transform=transform)

train_set = datasets.MNIST(train=True, **mnist_kwargs)
test_set = datasets.MNIST(train=False, **mnist_kwargs)

# Convert image datasets into stacked sequence tensors for the LSTM.
def prepare_sequences(dataset):
    """Materialise an image dataset as ``(sequences, labels)`` tensors.

    Each item of ``dataset`` is expected to be ``(image, label)`` with ``image``
    shaped ``(1, H, W)`` (what ``ToTensor`` produces for a grayscale image).
    The channel axis is squeezed away and the remaining ``(H, W)`` matrix is
    transposed. As a result, timestep ``t`` of the sequence is image *column*
    ``t``, so the LSTM scans left to right. For square MNIST digits the shape
    is ``(28, 28)`` with or without the transpose.

    Returns:
        sequences: ``(N, W, H)`` tensor stacked from every transposed image.
        labels: ``(N,)`` int64 tensor of the labels, in dataset order.
    """
    pairs = [(image.squeeze(0).T, target) for image, target in dataset]
    sequences = torch.stack([seq for seq, _ in pairs])
    labels = torch.tensor([target for _, target in pairs])
    return sequences, labels

BATCH_SIZE = 64  # shared by the train and test loaders

X_train, y_train = prepare_sequences(train_set)
X_test, y_test = prepare_sequences(test_set)

# Only the training loader is shuffled; evaluation order does not matter.
train_ds = TensorDataset(X_train, y_train)
test_ds = TensorDataset(X_test, y_test)
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=False)

print("LSTM-ready data:", X_train.shape, y_train.shape)


100%|██████████| 9.91M/9.91M [00:00<00:00, 18.0MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 485kB/s]
100%|██████████| 1.65M/1.65M [00:00<00:00, 4.51MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 11.2MB/s]


LSTM-ready data: torch.Size([60000, 28, 28]) torch.Size([60000])


In [None]:
import torch

# Seed PyTorch's RNGs (CPU and every CUDA device) so model weight init and the
# DataLoader's shuffle order repeat across Restart & Run All. This cell runs
# before the model is built and before any shuffled iteration, so both are covered.
# NOTE: cuDNN LSTM kernels may still be nondeterministic on GPU, so losses can
# differ slightly between runs there.
SEED = 0
torch.manual_seed(SEED)

# Use the GPU when PyTorch can see one, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)


Using device: cuda


In [None]:
import torch.nn as nn
import torch.nn.functional as F

class MNIST_LSTM(nn.Module):
    """Stacked LSTM that classifies a sequence from its final hidden state.

    Args:
        input_size: feature size of each timestep.
        hidden_size: LSTM hidden/cell state size; also the classifier's input size.
        num_layers: number of stacked LSTM layers.
        num_classes: number of output logits.

    ``forward`` takes ``x`` shaped ``(batch, seq_len, input_size)``
    (``batch_first=True``). It returns unnormalised logits shaped
    ``(batch, num_classes)``.
    """

    def __init__(self, input_size=28, hidden_size=128, num_layers=2, num_classes=10):
        super().__init__()
        self.lstm = nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
        )
        self.fc = nn.Linear(in_features=hidden_size, out_features=num_classes)

    def forward(self, x):
        # h_n is (num_layers, batch, hidden_size). Its last row is the top layer's
        # hidden state after the final timestep, which for this unidirectional,
        # unpacked LSTM is the same vector as the output sequence at index -1.
        _, (h_n, _) = self.lstm(x)
        return self.fc(h_n[-1])

model = MNIST_LSTM()
model = model.to(device)  # Module.to moves parameters and returns the same module


In [None]:
LEARNING_RATE = 1e-3  # same value as torch.optim.Adam's default lr

# CrossEntropyLoss takes raw logits plus integer class labels.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)


In [None]:
def train_lstm(model, epochs=3, *, loader=None, opt=None, loss_fn=None, dev=None):
    """Train ``model`` for ``epochs`` passes, printing the mean loss per epoch.

    When the keyword-only arguments are omitted, the notebook-level
    ``train_loader``, ``optimizer``, ``criterion`` and ``device`` are used.
    Pass them explicitly to train on other data or to call this outside the
    notebook.

    Args:
        model: module mapping a batch of inputs to class logits.
        epochs: number of full passes over ``loader``.
        loader: iterable of ``(inputs, targets)`` batches; defaults to ``train_loader``.
        opt: optimizer over ``model``'s parameters; defaults to ``optimizer``.
        loss_fn: loss that returns the *mean* over the batch (e.g. the default
            ``nn.CrossEntropyLoss()``); defaults to ``criterion``.
        dev: device each batch is moved to; defaults to ``device``.

    The printed epoch loss is averaged per sample. If ``loader`` yields no
    batches, it prints ``nan`` instead of dividing by zero.
    """
    loader = train_loader if loader is None else loader
    opt = optimizer if opt is None else opt
    loss_fn = criterion if loss_fn is None else loss_fn
    dev = device if dev is None else dev

    model.train()
    for epoch in range(epochs):
        loss_sum, n_seen = 0.0, 0
        for X_batch, y_batch in loader:
            X_batch, y_batch = X_batch.to(dev), y_batch.to(dev)
            opt.zero_grad()
            loss = loss_fn(model(X_batch), y_batch)
            loss.backward()
            opt.step()
            # loss is a per-batch mean. Weighting by batch size keeps a smaller
            # final batch from counting as much as a full one.
            batch_n = y_batch.size(0)
            loss_sum += loss.item() * batch_n
            n_seen += batch_n
        epoch_loss = loss_sum / n_seen if n_seen else float("nan")
        print(f"Epoch {epoch+1}/{epochs}, Loss: {epoch_loss:.4f}")

train_lstm(model, epochs=3)


Epoch 1/3, Loss: 0.4879
Epoch 2/3, Loss: 0.1293
Epoch 3/3, Loss: 0.0897


In [None]:
def evaluate_lstm(model, *, loader=None, dev=None):
    """Print the top-1 classification accuracy of ``model`` on ``loader``.

    When the keyword-only arguments are omitted, the notebook-level
    ``test_loader`` and ``device`` are used.

    Args:
        model: module mapping a batch of inputs to class logits; it is left in
            eval mode afterwards.
        loader: iterable of ``(inputs, targets)`` batches; defaults to ``test_loader``.
        dev: device each batch is moved to; defaults to ``device``.

    If ``loader`` yields no samples, a "n/a" message is printed instead of
    dividing by zero.
    """
    loader = test_loader if loader is None else loader
    dev = device if dev is None else dev

    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for X_batch, y_batch in loader:
            X_batch, y_batch = X_batch.to(dev), y_batch.to(dev)
            preds = model(X_batch).argmax(dim=1)
            correct += (preds == y_batch).sum().item()
            total += y_batch.size(0)
    if total == 0:
        print("LSTM Test Accuracy: n/a (empty loader)")
        return
    print(f"LSTM Test Accuracy: {100*correct/total:.2f}%")

evaluate_lstm(model=model)


LSTM Test Accuracy: 97.65%
