In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import math

# Hyperparameters and run configuration — everything a reader might tune lives here.
BATCH_SIZE = 64
EPOCHS = 5
LR = 1e-3
SEED = 42  # makes weight init, dropout and DataLoader shuffling repeatable across runs
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Seed torch's default generators (CPU and CUDA) before any model or loader is built.
# Some CUDA kernels may still be nondeterministic; see the PyTorch reproducibility notes.
torch.manual_seed(SEED)


In [2]:
# ToTensor turns each MNIST image into a float tensor of shape [1, 28, 28].
transform = transforms.Compose([transforms.ToTensor()])

# Both splits live under ./data; download=True lets torchvision fetch them if missing.
train_data = datasets.MNIST('./data', train=True, transform=transform, download=True)
test_data = datasets.MNIST('./data', train=False, transform=transform, download=True)

# Only the training split is shuffled; evaluation iterates in a fixed order.
train_loader = DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(dataset=test_data, batch_size=BATCH_SIZE, shuffle=False)


In [3]:
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a batch of sequences.

    Row ``p`` of the table holds ``sin(p * w_k)`` in even columns and ``cos(p * w_k)``
    in odd columns, with frequencies ``w_k = 10000 ** (-2k / d_model)``.

    Args:
        d_model: Feature width of the inputs (number of table columns). May be odd.
        max_len: Longest sequence length the table supports.

    ``forward`` expects ``x`` as (batch, seq_len, d_model): the table is broadcast over
    dim 0 and sliced to ``x.size(1)`` rows along dim 1.

    Raises:
        ValueError: from ``forward`` if ``x.size(1)`` exceeds ``max_len``.
    """

    def __init__(self, d_model, max_len=500):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1).float()
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        # When d_model is odd there is one fewer cosine column than sine column.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        # A buffer follows module.to(device) / .double() etc., unlike a plain attribute.
        # persistent=False keeps it out of state_dict, so existing checkpoints still load.
        self.register_buffer('pe', pe.unsqueeze(0), persistent=False)

    def forward(self, x):
        """Return ``x`` plus the encodings for its first ``x.size(1)`` positions."""
        seq_len = x.size(1)
        if seq_len > self.pe.size(1):
            raise ValueError(
                f"sequence length {seq_len} exceeds PositionalEncoding max_len={self.pe.size(1)}"
            )
        # .to() is a no-op once the module lives on x's device; kept so a module left
        # on CPU still works with inputs on another device.
        return x + self.pe[:, :seq_len].to(x.device)


In [7]:
class MNISTTransformer(nn.Module):
    """Classify MNIST digits by treating each image row as one token.

    An image of shape (batch, 1, 28, 28) becomes a sequence of 28 row vectors with
    ``input_dim`` pixels each. The rows are linearly embedded to ``model_dim``, given
    sinusoidal position encodings and passed through a Transformer encoder. The result
    is mean-pooled over the sequence and mapped to ``num_classes`` logits.

    Args:
        input_dim: Features per token (pixels per image row).
        model_dim: Embedding width ``d_model`` used throughout the encoder.
        num_heads: Attention heads per encoder layer; must divide ``model_dim``.
        num_layers: Number of stacked encoder layers.
        num_classes: Size of the output logit vector.
        dim_feedforward: Hidden width of each layer's feed-forward sublayer.
        dropout: Dropout probability inside each encoder layer.
    """

    def __init__(self, input_dim=28, model_dim=128, num_heads=4, num_layers=2, num_classes=10,
                 dim_feedforward=256, dropout=0.1):
        super().__init__()
        self.embedding = nn.Linear(input_dim, model_dim)
        self.pos_enc = PositionalEncoding(model_dim)

        # batch_first=True keeps tensors as (batch, seq, dim) end to end. No permute is
        # needed, and PyTorch can take its fast inference / nested-tensor path.
        # Parameter names and shapes are the same as with batch_first=False.
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=model_dim,
            nhead=num_heads,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            batch_first=True,
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        self.classifier = nn.Linear(model_dim, num_classes)

    def forward(self, x):
        """Map images of shape (batch, 1, 28, 28) to logits of shape (batch, num_classes)."""
        x = x.squeeze(1)            # (batch, 28, 28): drop the single channel dim
        x = self.embedding(x)       # (batch, 28, model_dim)
        x = self.pos_enc(x)         # (batch, 28, model_dim)
        x = self.transformer(x)     # (batch, 28, model_dim)
        x = x.mean(dim=1)           # (batch, model_dim): average over the 28 row tokens
        return self.classifier(x)   # (batch, num_classes)


In [8]:
# Build the network, move its parameters to the target device, then attach an
# Adam optimizer over all parameters and the multi-class classification loss.
model = MNISTTransformer()
model = model.to(DEVICE)
optimizer = torch.optim.Adam(params=model.parameters(), lr=LR)
loss_fn = nn.CrossEntropyLoss()




In [9]:
def train(net=None, loader=None, opt=None, criterion=None, device=None):
    """Run one optimization epoch and return the mean per-sample training loss.

    Each argument defaults to the notebook-level object in the same role (``model``,
    ``train_loader``, ``optimizer``, ``loss_fn``, ``DEVICE``). The defaults are looked
    up at call time, so ``train()`` keeps working after those cells are re-run. The
    function can also be reused on other models or loaders.

    Args:
        net: Module to train; switched to training mode.
        loader: Iterable of ``(inputs, labels)`` batches.
        opt: Optimizer over ``net``'s parameters.
        criterion: Loss taking ``(logits, labels)``. Assumed to average over the batch
            (the ``reduction='mean'`` default of ``nn.CrossEntropyLoss``).
        device: Device that each batch is moved to.

    Returns:
        Batch losses weighted by batch size and averaged over all samples, as a float.
        Returns ``nan`` if the loader yields no samples.
    """
    # Conditional expressions only evaluate the global when the argument is omitted.
    net = model if net is None else net
    loader = train_loader if loader is None else loader
    opt = optimizer if opt is None else opt
    criterion = loss_fn if criterion is None else criterion
    device = DEVICE if device is None else device

    net.train()
    # Accumulate on-device so there is no GPU->CPU sync per batch; convert once at the end.
    total_loss = torch.zeros((), device=device)
    seen = 0
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        opt.zero_grad()
        output = net(images)
        loss = criterion(output, labels)
        loss.backward()
        opt.step()
        batch_size = labels.size(0)
        total_loss += loss.detach() * batch_size  # undo the batch mean before summing
        seen += batch_size
    return (total_loss / seen).item() if seen else float('nan')


In [10]:
def test():
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            output = model(images)
            preds = output.argmax(dim=1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)
    acc = correct / total
    print(f"Test Accuracy: {acc * 100:.2f}%")


In [11]:
# Each epoch: one optimization pass over train_loader, then a held-out accuracy report.
for epoch in range(EPOCHS):
    train()
    print("Epoch {} complete.".format(epoch + 1))
    test()


Epoch 1 complete.
Test Accuracy: 96.18%
Epoch 2 complete.
Test Accuracy: 97.07%
Epoch 3 complete.
Test Accuracy: 97.34%
Epoch 4 complete.
Test Accuracy: 97.04%
Epoch 5 complete.
Test Accuracy: 97.97%
