In [1]:
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torchvision import datasets
from torch.utils.data import DataLoader

In [2]:
# Device configuration
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Parameters
BATCH_SIZE = 128

SEQUENCE_LENGTH = 28
INPUT_SIZE = 28
HIDDEN_SIZE = 128
NUM_LAYERS = 2
NUM_CLASSES = 10

LEARNING_RATE = 0.003
EPOCHS = 2

In [3]:
# MNIST dataset
train_dataset = datasets.MNIST(
    root="../../data",
    train=True,
    transform=transforms.ToTensor(),
    download=True
)
test_dataset = datasets.MNIST(
    root="../../data",
    train=False,
    transform=transforms.ToTensor()
)

# Data loader
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=True)

In [4]:
a, b = train_dataset[0]
print(a.shape)
print(type(b))

torch.Size([1, 28, 28])
<class 'int'>


In [5]:
class BiRNN(nn.Module):

    def __init__(self, input_size: int, hidden_size: int, num_layers: int, num_classes: int):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers

        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True, bidirectional=True)
        self.fc = nn.Linear(hidden_size * 2, num_classes)

    def forward(self, x):
        print(x.shape)

        h0 = torch.zeros(self.num_layers * 2, x.size(0), self.hidden_size).to(device)
        c0 = torch.zeros(self.num_layers * 2, x.size(0), self.hidden_size).to(device)

        # Forward to LSTM
        out, _ = self.lstm(x, (h0, c0))  # output format: (batch_size, seq_length, hidden_size * 2)

        out = self.fc(out[:, -1, :])
        return out

In [6]:
model = BiRNN(INPUT_SIZE, HIDDEN_SIZE, NUM_LAYERS, NUM_CLASSES)

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

# Train the model
for epoch in range(EPOCHS):
    for i, (images, labels) in enumerate(train_loader):
        print(images.shape)
        
        images = images.reshape(-1, SEQUENCE_LENGTH, INPUT_SIZE).to(device)
        labels = labels.to(device)

        print(f"I: {images.size()}")
        print(f"L: {labels.size()}")

        break
        
        # forward pass
        outputs = model(images)

        print(f"O: {outputs.size()}")

        loss = criterion(outputs, labels)

        # backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if i % 99 == 0:
            print(f"Epoch [{epoch + 1}/{EPOCHS}], Loss: {round(loss.item(), 4)}")
        break
    break

print("DONE")

torch.Size([128, 1, 28, 28])
I: torch.Size([128, 28, 28])
L: torch.Size([128])
DONE


In [10]:
# Test model
with torch.no_grad():
    correct = 0
    total = 0
    for images, labels in test_loader:
        print(images.shape)
        
        images = images.reshape(-1, SEQUENCE_LENGTH, INPUT_SIZE).to(device)
        labels = labels.to(device)
        outputs = model(images)
        _, predicted = torch.max(outputs.data, dim=1)

        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    print(f"Test Accuracy: {100 * correct / total}")

# Save model
torch.save(model.state_dict(), 'model.pth')

Test Accuracy: 97.59
