In [6]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import math

In [2]:
# Multi-Head Attention
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    The input is projected to queries, keys and values with three separate
    linear layers, split into ``num_heads`` heads of size ``dim // num_heads``,
    attended per head, re-merged, and passed through an output projection.

    Args:
        dim: Size of the last (feature) axis of the input and of the output.
        num_heads: Number of attention heads. Must divide ``dim`` evenly.

    Raises:
        ValueError: If ``dim`` is not divisible by ``num_heads``.
    """

    def __init__(self, dim, num_heads):
        super().__init__()
        # Fail early with a clear message: without this check the floor
        # division below silently truncates and the error only shows up later
        # as a shape mismatch inside ``forward``.
        if dim % num_heads != 0:
            raise ValueError(
                f"dim ({dim}) must be divisible by num_heads ({num_heads})"
            )

        self.num_heads = num_heads
        self.head_dim = dim // num_heads

        self.query = nn.Linear(dim, dim)
        self.key   = nn.Linear(dim, dim)
        self.value = nn.Linear(dim, dim)
        self.out   = nn.Linear(dim, dim)

    def forward(self, x):
        """Apply self-attention to ``x``.

        Args:
            x: Tensor of shape ``(B, N, D)`` where ``D`` equals ``dim``.

        Returns:
            Tensor of shape ``(B, N, D)``.
        """
        B, N, D = x.shape

        # (B, N, D) -> (B, num_heads, N, head_dim) for each projection.
        Q = self.query(x).view(B, N, self.num_heads, self.head_dim).transpose(1, 2)
        K = self.key(x).view(B, N, self.num_heads, self.head_dim).transpose(1, 2)
        V = self.value(x).view(B, N, self.num_heads, self.head_dim).transpose(1, 2)

        # Scale by sqrt(head_dim) before the softmax over the key axis.
        scores = (Q @ K.transpose(-2, -1)) / math.sqrt(self.head_dim)
        attn = F.softmax(scores, dim=-1)
        context = attn @ V

        # Merge heads back: (B, num_heads, N, head_dim) -> (B, N, D).
        context = context.transpose(1, 2).reshape(B, N, D)
        return self.out(context)

In [3]:
# Transformer Encoder Layer
class TransformerEncoderLayer(nn.Module):
    """One encoder block: self-attention and an MLP, each in a residual branch.

    LayerNorm is applied to the input of each branch (before attention and
    before the MLP), and the branch output is added back to the un-normalised
    input.

    Args:
        dim: Feature size of the tokens.
        num_heads: Number of attention heads passed to ``MultiHeadAttention``.
        mlp_dim: Hidden width of the two-layer MLP.
        dropout: Dropout probability used after each MLP linear stage.
    """

    def __init__(self, dim, num_heads, mlp_dim, dropout=0.1):
        super().__init__()
        self.attn = MultiHeadAttention(dim, num_heads)
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)

        # Linear -> GELU -> Dropout -> Linear -> Dropout
        feed_forward = [
            nn.Linear(dim, mlp_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(mlp_dim, dim),
            nn.Dropout(dropout),
        ]
        self.mlp = nn.Sequential(*feed_forward)

    def forward(self, x):
        """Run the attention branch, then the MLP branch, on ``x``."""
        attended = self.attn(self.norm1(x))
        x = x + attended

        transformed = self.mlp(self.norm2(x))
        return x + transformed

In [4]:
# Transformer Stack
class TransformerEncoder(nn.Module):
    """A stack of ``depth`` identical-shaped ``TransformerEncoderLayer`` blocks.

    Args:
        dim: Feature size of the tokens.
        depth: Number of encoder layers to stack.
        num_heads: Attention heads per layer.
        mlp_dim: Hidden width of each layer's MLP.
        dropout: Dropout probability forwarded to every layer.
    """

    def __init__(self, dim, depth, num_heads, mlp_dim, dropout=0.1):
        super().__init__()
        blocks = []
        for _ in range(depth):
            blocks.append(TransformerEncoderLayer(dim, num_heads, mlp_dim, dropout))
        self.layers = nn.ModuleList(blocks)

    def forward(self, x):
        """Feed ``x`` through every layer in order and return the result."""
        hidden = x
        for block in self.layers:
            hidden = block(hidden)
        return hidden

In [5]:
# ViT Model
class ViT(nn.Module):
    """Vision Transformer classifier over non-overlapping square patches.

    The image is cut into ``patch_size`` x ``patch_size`` patches, each patch
    is flattened and linearly embedded, a learnable class token is prepended,
    learnable position embeddings are added, and the sequence is run through a
    ``TransformerEncoder``. The class-token output is normalised and projected
    to ``num_classes`` logits.

    Args:
        image_size: Side length used to size the position embedding
            (``(image_size // patch_size) ** 2`` patches plus one class token).
        patch_size: Side length of each square patch.
        num_classes: Number of output logits.
        dim: Token embedding size.
        depth: Number of encoder layers.
        heads: Attention heads per encoder layer.
        mlp_dim: Hidden width of each encoder layer's MLP.
        dropout: Dropout probability (after the position embedding and inside
            the encoder).
        in_channels: Number of image channels. Defaults to 3, which matches
            the previously hard-coded value.

    Note:
        The position embedding has a fixed length, so inputs whose patch count
        differs from ``(image_size // patch_size) ** 2`` fail at the addition
        in ``forward``. ``Tensor.unfold`` only yields complete windows, so
        rows/columns beyond the last full patch are ignored.
    """

    def __init__(self, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, dropout=0.1,
                 in_channels=3):
        super().__init__()
        self.patch_size = patch_size
        self.in_channels = in_channels
        num_patches = (image_size // patch_size) ** 2
        patch_dim = in_channels * patch_size ** 2

        self.patch_embed = nn.Linear(patch_dim, dim)
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        # +1 slot for the class token prepended in ``forward``.
        self.pos_embed = nn.Parameter(torch.randn(1, num_patches + 1, dim))

        self.dropout = nn.Dropout(dropout)

        self.transformer = TransformerEncoder(dim, depth, heads, mlp_dim, dropout)

        self.mlp_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes)
        )

    def forward(self, img):
        """Classify a batch of images.

        Args:
            img: Tensor of shape ``(B, C, H, W)``; ``C`` must equal
                ``in_channels`` for the patch embedding to accept it.

        Returns:
            Tensor of shape ``(B, num_classes)`` with the class logits.
        """
        B, C, H, W = img.shape
        p = self.patch_size

        # (B, C, H, W) -> (B, C, H/p, W/p, p, p): tile height, then width.
        patches = img.unfold(2, p, p).unfold(3, p, p)
        # Collapse the patch grid into one axis: (B, C, N, p, p).
        patches = patches.contiguous().view(B, C, -1, p, p)
        # Put the patch axis first and flatten each patch: (B, N, C*p*p).
        # Uses the actual channel count instead of a hard-coded 3.
        patches = patches.permute(0, 2, 1, 3, 4).reshape(B, -1, C * p * p)

        x = self.patch_embed(patches)
        x = torch.cat([self.cls_token.expand(B, -1, -1), x], dim=1)
        x = x + self.pos_embed
        x = self.dropout(x)  # Bonus: Dropout after pos embed

        x = self.transformer(x)
        # Only the class token (position 0) feeds the classification head.
        return self.mlp_head(x[:, 0])

In [8]:
# load dataset 

# Per-channel normalisation statistics applied after conversion to a tensor.
normalize = transforms.Normalize(
    mean=[0.4914, 0.4822, 0.4465],
    std=[0.2023, 0.1994, 0.2010],
)
transform = transforms.Compose([transforms.ToTensor(), normalize])

# CIFAR-10 train / test splits, downloaded into ./data when missing.
train_set = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=True, transform=transform
)
test_set = torchvision.datasets.CIFAR10(
    root='./data', train=False, download=True, transform=transform
)

# Both loaders share batch size and worker count; only training is shuffled.
loader_kwargs = dict(batch_size=64, num_workers=2)
train_loader = torch.utils.data.DataLoader(train_set, shuffle=True, **loader_kwargs)
test_loader = torch.utils.data.DataLoader(test_set, shuffle=False, **loader_kwargs)

Files already downloaded and verified
Files already downloaded and verified


In [9]:
# initialize model, loss and optimizer

# Prefer the GPU when one is visible to PyTorch, otherwise stay on the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device('cuda' if use_cuda else 'cpu')

# 32x32 inputs cut into 4x4 patches, 10 output classes.
model = ViT(
    image_size=32, patch_size=4, num_classes=10,
    dim=128, depth=6, heads=8, mlp_dim=256,
    dropout=0.1,
)
model = model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)


In [12]:
# training loop

for epoch in range(100):
    # Enable training-mode behaviour (e.g. dropout) for this epoch.
    model.train()
    running_loss = 0

    for images, labels in train_loader:
        images = images.to(device)
        labels = labels.to(device)

        # Clear gradients from the previous step, then forward / backward / update.
        optimizer.zero_grad()
        logits = model(images)
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    # Mean of the per-batch losses over the epoch.
    avg_loss = running_loss / len(train_loader)
    print(f"Epoch {epoch+1} - Training Loss: {avg_loss:.4f}")


Epoch 1 - Training Loss: 0.7468
Epoch 2 - Training Loss: 0.7031
Epoch 3 - Training Loss: 0.6623
Epoch 4 - Training Loss: 0.6161
Epoch 5 - Training Loss: 0.5804
Epoch 6 - Training Loss: 0.5434
Epoch 7 - Training Loss: 0.5061
Epoch 8 - Training Loss: 0.4706
Epoch 9 - Training Loss: 0.4450
Epoch 10 - Training Loss: 0.4137
Epoch 11 - Training Loss: 0.3839
Epoch 12 - Training Loss: 0.3640
Epoch 13 - Training Loss: 0.3403
Epoch 14 - Training Loss: 0.3156
Epoch 15 - Training Loss: 0.2982
Epoch 16 - Training Loss: 0.2752
Epoch 17 - Training Loss: 0.2621
Epoch 18 - Training Loss: 0.2467
Epoch 19 - Training Loss: 0.2336
Epoch 20 - Training Loss: 0.2247
Epoch 21 - Training Loss: 0.2084
Epoch 22 - Training Loss: 0.1971
Epoch 23 - Training Loss: 0.1922
Epoch 24 - Training Loss: 0.1785
Epoch 25 - Training Loss: 0.1748
Epoch 26 - Training Loss: 0.1691
Epoch 27 - Training Loss: 0.1607
Epoch 28 - Training Loss: 0.1516
Epoch 29 - Training Loss: 0.1521
Epoch 30 - Training Loss: 0.1492
Epoch 31 - Training

In [13]:
# evaluation

# Switch to inference behaviour (e.g. dropout disabled) before scoring.
model.eval()
correct, total = 0, 0

# No gradients are needed for evaluation.
with torch.no_grad():
    for images, labels in test_loader:
        images = images.to(device)
        labels = labels.to(device)

        # Index of the largest logit along the class axis is the prediction.
        predicted = model(images).max(1)[1]

        correct += predicted.eq(labels).sum().item()
        total += labels.shape[0]

print(f"Test Accuracy: {100 * correct / total:.2f}%")


Test Accuracy: 66.35%
