In [5]:
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam

In [7]:
# Transformations applied on each image
transform = transforms.Compose([
    transforms.ToTensor(),  # Convert images to tensor format
    transforms.Normalize((0.5,), (0.5,))  # Normalize images
])

# Loading the dataset
train_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
test_dataset = datasets.MNIST(root='./data', train=False, transform=transform, download=True)

# Data loaders for handling the dataset batching
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

In [8]:
class TransformerMNIST(nn.Module):
    def __init__(self, input_dim, num_heads, num_classes, dim_feedforward=512, num_layers=3):
        super().__init__()
        self.embedding = nn.Linear(input_dim, dim_feedforward)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=dim_feedforward, 
            nhead=num_heads,
            dim_feedforward=dim_feedforward,
            batch_first=True
        )
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.output_layer = nn.Linear(dim_feedforward, num_classes)

    def forward(self, x):
        x = self.embedding(x)
        x = self.transformer_encoder(x)
        x = x.mean(dim=1)  # Pool over the sequence
        x = self.output_layer(x)
        return x

# Model initialization
model = TransformerMNIST(input_dim=28, num_heads=4, num_classes=10)

In [9]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
optimizer = Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

# Training loop with status updates
def train_model(model, train_loader):
    model.train()
    total_loss = 0
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        data = data.view(data.size(0), 28, 28)  # Reshape data

        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
        
        # Status update
        if (batch_idx + 1) % 100 == 0:  # Update every 100 batches
            print(f'Batch {batch_idx+1}/{len(train_loader)}, Loss: {loss.item()}')

    print(f'Training Loss: {total_loss / len(train_loader)}')

# Test the model with status updates
def test_model(model, test_loader):
    model.eval()
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            data = data.view(data.size(0), 28, 28)  # Reshape data
            output = model(data)
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()

    print(f'Test Accuracy: {100. * correct / len(test_loader.dataset)}%')

# Run the training and testing
train_model(model, train_loader)
test_model(model, test_loader)


Batch 100/1875, Loss: 1.9712975025177002
Batch 200/1875, Loss: 1.871441125869751
Batch 300/1875, Loss: 1.9426764249801636
Batch 400/1875, Loss: 1.8364906311035156
Batch 500/1875, Loss: 2.2779624462127686
Batch 600/1875, Loss: 2.2522079944610596
Batch 700/1875, Loss: 2.3697690963745117
Batch 800/1875, Loss: 2.2546520233154297
Batch 900/1875, Loss: 2.3219988346099854
Batch 1000/1875, Loss: 2.342280626296997
Batch 1100/1875, Loss: 2.394291639328003
Batch 1200/1875, Loss: 2.3023767471313477
Batch 1300/1875, Loss: 2.1443378925323486
Batch 1400/1875, Loss: 2.2141313552856445
Batch 1500/1875, Loss: 2.195491313934326
Batch 1600/1875, Loss: 2.2670705318450928
Batch 1700/1875, Loss: 2.0390360355377197
Batch 1800/1875, Loss: 1.9013569355010986
Training Loss: 2.1675297175725303
Test Accuracy: 20.04%
