In [1]:
import torch
import torch.utils.data as dataloader

In [2]:
import torch.nn as nn

In [3]:
import torchvision
import torchvision.transforms as transforms

In [4]:
data_transformation = transforms.Compose([transforms.ToTensor()])

In [5]:
train_dataset = torchvision.datasets.MNIST(root="./data", 
                                           train=True,
                                           download=True,
                                           transform=data_transformation)
val_dataset = torchvision.datasets.MNIST(root="./data", 
                                         train=False,
                                         download=True,
                                         transform=data_transformation)

In [6]:
img_size = 28
num_channels = 1
patch_size = 7
num_patches = (img_size//patch_size)**2
token_dim = 32
num_heads = 4 
transformer_blocks = 4  # L
num_classes = 10
batch_size=64
mlp_hidden_dim = 64
learning_rate = 3e-4
epochs = 5

In [7]:
train_dataloader = dataloader.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_dataloader = dataloader.DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

# PART 1 : Patch Embedding

In [8]:
class PatchEmbedding(nn.Module):
    def __init__(self):
        super().__init__()
        self.patch_embed = nn.Conv2d(num_channels, token_dim, patch_size, patch_size)
    
    def forward(self, x):  # x: (64, 1, 28, 28)
        x = self.patch_embed(x) # (batch, token_dim, patch_size, patch_size) = (64,32,4,4)
        x = x.flatten(2)  # (64, 32, 16)
        x = x.transpose(1, 2) # (64, 16, 32)
        return x

# PART 2 : Transformer Encoder

In [9]:
class TransformerEncoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.layernorm1 = nn.LayerNorm(token_dim)
        self.layernorm2 = nn.LayerNorm(token_dim)
        self.multihead_attention = nn.MultiheadAttention(token_dim, num_heads, batch_first=True)
        self.mlp = nn.Sequential(
            nn.Linear(token_dim, mlp_hidden_dim),
            nn.GELU(),
            nn.Linear(mlp_hidden_dim, token_dim),
        )

    def forward(self, x):
        residual1 = x

        x = self.layernorm1(x)
        x = self.multihead_attention(x, x, x)[0] # context vector
        x = x + residual1

        residual2 = x
        x = self.layernorm2(x)
        x = self.mlp(x)
        x = x + residual2

        return x

# PART 3 : MLP Classification head

In [10]:
class MLPHead(nn.Module):
    def __init__(self):
        super().__init__()
        self.layernorm = nn.LayerNorm(token_dim)
        self.mlp = nn.Linear(token_dim, num_classes)

    def forward(self, x):
        x = self.layernorm(x)
        x = self.mlp(x)
        return x

# PART 1, 2 3 combined

In [None]:
class Visitiontransformer(nn.Module):
    def __init__(self):
        super().__init__()
        self.patch_embedding = PatchEmbedding()
        self.cls_token = nn.Parameter(torch.randn(1,1,token_dim))
        self.position_embedding = nn.Parameter(torch.rand(1, num_patches+1, token_dim))
        self.transformer_blocks = nn.Sequential(*[TransformerEncoder() for _ in range(transformer_blocks)])
        self.mlp_head = MLPHead()

    def forward(self, x):
        x = self.patch_embedding(x)
        B = x.shape[0]
        cls_token = self.cls_token.expand(B, -1, -1) # (B, 1, token_dim)
        x = torch.cat((cls_token, x), dim=1) # (B, num_patches+1, token_dim)
        x = x + self.position_embedding # (B, num_patches+1, token_dim)

        x = self.transformer_blocks(x) # (B, num_patches+1, token_dim)
        x = x[:, 0]
        x = self.mlp_head(x)

        return x


# device, model, optimizer, loss

In [12]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = Visitiontransformer().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
criterion = nn.CrossEntropyLoss()

In [13]:
for epoch in range(epochs):
    model.train()
    total_loss = 0
    for images, labels in train_dataloader:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    avg_loss = total_loss / len(train_dataloader)
    print(f"Epoch [{epoch+1}/{epochs}], Loss: {avg_loss:.4f}")

    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in val_dataloader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    accuracy = 100 * correct / total
    print(f'Validation Accuracy: {accuracy:.2f}%')

Epoch [1/5], Loss: 0.6641
Validation Accuracy: 92.26%
Epoch [2/5], Loss: 0.2082
Validation Accuracy: 95.03%
Epoch [3/5], Loss: 0.1540
Validation Accuracy: 95.71%
Epoch [4/5], Loss: 0.1204
Validation Accuracy: 95.95%
Epoch [5/5], Loss: 0.1060
Validation Accuracy: 96.31%


In [14]:
for epoch in range(epochs):
    model.train()
    total_loss = 0
    correct_epoch = 0
    total_epoch = 0
    print(f"\nEpoch {epoch+1}")

    for batch_idx, (images, labels) in enumerate(train_loader):
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        preds = outputs.argmax(dim=1)
        correct = (preds == labels).sum().item()
        accuracy = 100.0 * correct / labels.size(0)

        correct_epoch += correct
        total_epoch += labels.size(0)

        if batch_idx % 100 == 0:
            print(f"  Batch {batch_idx+1:3d}: Loss = {loss.item():.4f}, Accuracy = {accuracy:.2f}%")

    epoch_acc = 100.0 * correct_epoch / total_epoch
    print(f"==> Epoch {epoch+1} Summary: Total Loss = {total_loss:.4f}, Accuracy = {epoch_acc:.2f}%")


Epoch 1


NameError: name 'train_loader' is not defined

In [None]:
# Vision Transformer Class# Switch to evaluation mode
model.eval()
correct = 0
total = 0

with torch.no_grad():
    for images, labels in val_loader:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        preds = outputs.argmax(dim=1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)

test_acc = 100.0 * correct / total
print(f"\n==> Val Accuracy: {test_acc:.2f}%")

In [None]:
import matplotlib.pyplot as plt

# Show 10 predictions from the first test batch
model.eval()
images, labels = next(iter(val_loader))
images, labels = images.to(device), labels.to(device)
with torch.no_grad():
    outputs = model(images)
    preds = outputs.argmax(dim=1)

# Move to CPU for plotting
images = images.cpu()
preds = preds.cpu()
labels = labels.cpu()

# Plot first 10 images
plt.figure(figsize=(12, 4))
for i in range(10):
    plt.subplot(2, 5, i+1)
    plt.imshow(images[i].squeeze(), cmap='gray')
    plt.title(f"Pred: {preds[i].item()}\nTrue: {labels[i].item()}")
    plt.axis('off')
plt.tight_layout()