In [1]:
import os
import shutil
import math
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from tqdm.auto import tqdm

In [2]:
#  device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Hyperparameters ---
#  ViT on Tiny ImageNet
BATCH_SIZE = 64
EPOCHS = 30
LEARNING_RATE = 3e-4
NUM_CLASSES = 200
IMAGE_SIZE = 64
PATCH_SIZE = 16
EMBED_DIM = 256
NUM_HEADS = 8
DEPTH = 6
MLP_DIM = 512
DROPOUT = 0.1
CHANNELS = 3

Using device: cuda


In [3]:
#  Dataset and DataLoader
print("\n--- Downloading and Preparing Tiny ImageNet ---")
!wget http://cs231n.stanford.edu/tiny-imagenet-200.zip
!unzip -q tiny-imagenet-200.zip
data_dir = './tiny-imagenet-200'

# Organize the validation data
val_dir = os.path.join(data_dir, 'val')
val_annotations_file = os.path.join(val_dir, 'val_annotations.txt')
val_img_dict = {}
with open(val_annotations_file, 'r') as f:
    for line in f.readlines():
        words = line.strip().split('\t')
        val_img_dict[words[0]] = words[1]
for img_filename, label in val_img_dict.items():
    src = os.path.join(val_dir, 'images', img_filename)
    dst = os.path.join(val_dir, label, img_filename)
    if not os.path.exists(os.path.join(val_dir, label)):
        os.makedirs(os.path.join(val_dir, label))
    if os.path.exists(src):
        shutil.move(src, dst)
shutil.rmtree(os.path.join(val_dir, 'images'))
os.remove(val_annotations_file)
print("Validation set successfully organized.")

data_transforms = {
    'train': transforms.Compose([
        transforms.RandomResizedCrop(IMAGE_SIZE),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5])
    ]),
    'val': transforms.Compose([
        transforms.Resize(IMAGE_SIZE),
        transforms.CenterCrop(IMAGE_SIZE),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5])
    ]),
}
train_dataset = datasets.ImageFolder(os.path.join(data_dir, 'train'), data_transforms['train'])
test_dataset = datasets.ImageFolder(os.path.join(data_dir, 'val'), data_transforms['val'])
train_loader = DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=BATCH_SIZE, shuffle=False)
print(f"Number of training samples: {len(train_dataset)}")
print(f"Number of test samples: {len(test_dataset)}")


--- Downloading and Preparing Tiny ImageNet ---
--2025-09-21 19:20:23--  http://cs231n.stanford.edu/tiny-imagenet-200.zip
Resolving cs231n.stanford.edu (cs231n.stanford.edu)... 171.64.64.64
Connecting to cs231n.stanford.edu (cs231n.stanford.edu)|171.64.64.64|:80... connected.
HTTP request sent, awaiting response... 301 Moved Permanently
Location: https://cs231n.stanford.edu/tiny-imagenet-200.zip [following]
--2025-09-21 19:20:23--  https://cs231n.stanford.edu/tiny-imagenet-200.zip
Connecting to cs231n.stanford.edu (cs231n.stanford.edu)|171.64.64.64|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 248100043 (237M) [application/zip]
Saving to: ‘tiny-imagenet-200.zip’


2025-09-21 19:20:27 (63.4 MB/s) - ‘tiny-imagenet-200.zip’ saved [248100043/248100043]

Validation set successfully organized.
Number of training samples: 100000
Number of test samples: 10000


In [4]:
# ViT Model Arch
class PatchEmbeddings(nn.Module):
    def __init__(self, in_channels, patch_size, embed_dim):
        super().__init__()
        self.projection = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
    def forward(self, x):
        return self.projection(x).flatten(2).transpose(1, 2)
class MultiHeadAttention(nn.Module):
    def __init__(self, embed_dim, num_heads, dropout):
        super().__init__()
        self.num_heads = num_heads
        head_dim = embed_dim // num_heads
        self.scale = head_dim ** -0.5
        self.qkv = nn.Linear(embed_dim, embed_dim * 3, bias=False)
        self.proj = nn.Linear(embed_dim, embed_dim)
        self.dropout = nn.Dropout(dropout)
    def forward(self, x):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = F.softmax(attn, dim=-1)
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        return self.dropout(x)
class MLP(nn.Module):
    def __init__(self, in_features, hidden_features, out_features, dropout):
        super().__init__()
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.dropout = nn.Dropout(dropout)
    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.fc2(x)
        return self.dropout(x)
class TransformerBlock(nn.Module):
    def __init__(self, embed_dim, num_heads, mlp_dim, dropout):
        super().__init__()
        self.norm1 = nn.LayerNorm(embed_dim)
        self.attn = MultiHeadAttention(embed_dim, num_heads, dropout)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.mlp = MLP(embed_dim, mlp_dim, embed_dim, dropout)
    def forward(self, x):
        x = x + self.attn(self.norm1(x))
        x = x + self.mlp(self.norm2(x))
        return x
class VisionTransformer(nn.Module):
    def __init__(self, img_size, patch_size, in_channels, num_classes, embed_dim, depth, num_heads, mlp_dim, dropout):
        super().__init__()
        num_patches = (img_size // patch_size) ** 2
        self.patch_embed = PatchEmbeddings(in_channels, patch_size, embed_dim)
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.dropout = nn.Dropout(dropout)
        self.transformer_blocks = nn.ModuleList([
            TransformerBlock(embed_dim, num_heads, mlp_dim, dropout) for _ in range(depth)
        ])
        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes)
    def forward(self, x):
        B = x.shape[0]
        x = self.patch_embed(x)
        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
        x = x + self.pos_embed
        x = self.dropout(x)
        for block in self.transformer_blocks:
            x = block(x)
        x = self.norm(x)
        cls_token_output = x[:, 0]
        logits = self.head(cls_token_output)
        return logits

In [5]:
# Training and Evaluation Functions
def train_epoch(model, dataloader, optimizer, loss_fn, device):
    model.train()
    total_loss, total_correct = 0, 0
    with tqdm(dataloader, unit="batch") as tepoch:
        for images, labels in tepoch:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = loss_fn(outputs, labels)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            _, predicted = torch.max(outputs.data, 1)
            total_correct += (predicted == labels).sum().item()
            tepoch.set_postfix(loss=total_loss / (tepoch.n+1), accuracy=100. * total_correct / len(dataloader.dataset))
    return total_loss / len(dataloader), 100. * total_correct / len(dataloader.dataset)
def evaluate(model, dataloader, device):
    model.eval()
    total_correct = 0
    with torch.no_grad(), tqdm(dataloader, unit="batch") as tepoch:
        for images, labels in tepoch:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total_correct += (predicted == labels).sum().item()
            tepoch.set_postfix(accuracy=100. * total_correct / len(dataloader.dataset))
    return 100. * total_correct / len(dataloader.dataset)

# Loop
print("\n--- Initializing Model and Training ---")
model = VisionTransformer(
    img_size=IMAGE_SIZE, patch_size=PATCH_SIZE, in_channels=CHANNELS, num_classes=NUM_CLASSES,
    embed_dim=EMBED_DIM, depth=DEPTH, num_heads=NUM_HEADS, mlp_dim=MLP_DIM, dropout=DROPOUT
).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
loss_fn = nn.CrossEntropyLoss()
for epoch in range(EPOCHS):
    print(f"\nEpoch {epoch+1}/{EPOCHS}")
    train_loss, train_acc = train_epoch(model, train_loader, optimizer, loss_fn, device)
    test_acc = evaluate(model, test_loader, device)
    print(f"Epoch {epoch+1} | Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}% | Test Acc: {test_acc:.2f}%")

print("\nTraining complete!")


--- Initializing Model and Training ---

Epoch 1/30


  0%|          | 0/1563 [00:00<?, ?batch/s]

  0%|          | 0/157 [00:00<?, ?batch/s]

Epoch 1 | Train Loss: 4.8354 | Train Acc: 4.37% | Test Acc: 7.24%

Epoch 2/30


  0%|          | 0/1563 [00:00<?, ?batch/s]

  0%|          | 0/157 [00:00<?, ?batch/s]

Epoch 2 | Train Loss: 4.5087 | Train Acc: 8.08% | Test Acc: 9.10%

Epoch 3/30


  0%|          | 0/1563 [00:00<?, ?batch/s]

  0%|          | 0/157 [00:00<?, ?batch/s]

Epoch 3 | Train Loss: 4.3532 | Train Acc: 10.13% | Test Acc: 11.08%

Epoch 4/30


  0%|          | 0/1563 [00:00<?, ?batch/s]

  0%|          | 0/157 [00:00<?, ?batch/s]

Epoch 4 | Train Loss: 4.2518 | Train Acc: 11.45% | Test Acc: 12.26%

Epoch 5/30


  0%|          | 0/1563 [00:00<?, ?batch/s]

  0%|          | 0/157 [00:00<?, ?batch/s]

Epoch 5 | Train Loss: 4.1702 | Train Acc: 12.66% | Test Acc: 13.63%

Epoch 6/30


  0%|          | 0/1563 [00:00<?, ?batch/s]

  0%|          | 0/157 [00:00<?, ?batch/s]

Epoch 6 | Train Loss: 4.0943 | Train Acc: 13.57% | Test Acc: 14.92%

Epoch 7/30


  0%|          | 0/1563 [00:00<?, ?batch/s]

  0%|          | 0/157 [00:00<?, ?batch/s]

Epoch 7 | Train Loss: 4.0202 | Train Acc: 14.79% | Test Acc: 15.75%

Epoch 8/30


  0%|          | 0/1563 [00:00<?, ?batch/s]

  0%|          | 0/157 [00:00<?, ?batch/s]

Epoch 8 | Train Loss: 3.9638 | Train Acc: 15.34% | Test Acc: 16.97%

Epoch 9/30


  0%|          | 0/1563 [00:00<?, ?batch/s]

  0%|          | 0/157 [00:00<?, ?batch/s]

Epoch 9 | Train Loss: 3.9062 | Train Acc: 16.34% | Test Acc: 17.06%

Epoch 10/30


  0%|          | 0/1563 [00:00<?, ?batch/s]

  0%|          | 0/157 [00:00<?, ?batch/s]

Epoch 10 | Train Loss: 3.8568 | Train Acc: 17.00% | Test Acc: 18.27%

Epoch 11/30


  0%|          | 0/1563 [00:00<?, ?batch/s]

  0%|          | 0/157 [00:00<?, ?batch/s]

Epoch 11 | Train Loss: 3.8082 | Train Acc: 17.79% | Test Acc: 18.89%

Epoch 12/30


  0%|          | 0/1563 [00:00<?, ?batch/s]

  0%|          | 0/157 [00:00<?, ?batch/s]

Epoch 12 | Train Loss: 3.7432 | Train Acc: 18.79% | Test Acc: 20.03%

Epoch 13/30


  0%|          | 0/1563 [00:00<?, ?batch/s]

  0%|          | 0/157 [00:00<?, ?batch/s]

Epoch 13 | Train Loss: 3.6985 | Train Acc: 19.46% | Test Acc: 20.52%

Epoch 14/30


  0%|          | 0/1563 [00:00<?, ?batch/s]

  0%|          | 0/157 [00:00<?, ?batch/s]

Epoch 14 | Train Loss: 3.6498 | Train Acc: 20.26% | Test Acc: 21.27%

Epoch 15/30


  0%|          | 0/1563 [00:00<?, ?batch/s]

  0%|          | 0/157 [00:00<?, ?batch/s]

Epoch 15 | Train Loss: 3.6005 | Train Acc: 21.01% | Test Acc: 21.59%

Epoch 16/30


  0%|          | 0/1563 [00:00<?, ?batch/s]

  0%|          | 0/157 [00:00<?, ?batch/s]

Epoch 16 | Train Loss: 3.5573 | Train Acc: 21.55% | Test Acc: 22.86%

Epoch 17/30


  0%|          | 0/1563 [00:00<?, ?batch/s]

  0%|          | 0/157 [00:00<?, ?batch/s]

Epoch 17 | Train Loss: 3.5165 | Train Acc: 22.23% | Test Acc: 23.28%

Epoch 18/30


  0%|          | 0/1563 [00:00<?, ?batch/s]

  0%|          | 0/157 [00:00<?, ?batch/s]

Epoch 18 | Train Loss: 3.4682 | Train Acc: 23.10% | Test Acc: 23.62%

Epoch 19/30


  0%|          | 0/1563 [00:00<?, ?batch/s]

  0%|          | 0/157 [00:00<?, ?batch/s]

Epoch 19 | Train Loss: 3.4319 | Train Acc: 23.63% | Test Acc: 24.42%

Epoch 20/30


  0%|          | 0/1563 [00:00<?, ?batch/s]

  0%|          | 0/157 [00:00<?, ?batch/s]

Epoch 20 | Train Loss: 3.3925 | Train Acc: 24.27% | Test Acc: 24.91%

Epoch 21/30


  0%|          | 0/1563 [00:00<?, ?batch/s]

  0%|          | 0/157 [00:00<?, ?batch/s]

Epoch 21 | Train Loss: 3.3542 | Train Acc: 25.19% | Test Acc: 25.27%

Epoch 22/30


  0%|          | 0/1563 [00:00<?, ?batch/s]

  0%|          | 0/157 [00:00<?, ?batch/s]

Epoch 22 | Train Loss: 3.3223 | Train Acc: 25.63% | Test Acc: 26.24%

Epoch 23/30


  0%|          | 0/1563 [00:00<?, ?batch/s]

  0%|          | 0/157 [00:00<?, ?batch/s]

Epoch 23 | Train Loss: 3.2879 | Train Acc: 26.13% | Test Acc: 26.14%

Epoch 24/30


  0%|          | 0/1563 [00:00<?, ?batch/s]

  0%|          | 0/157 [00:00<?, ?batch/s]

Epoch 24 | Train Loss: 3.2546 | Train Acc: 26.70% | Test Acc: 27.02%

Epoch 25/30


  0%|          | 0/1563 [00:00<?, ?batch/s]

  0%|          | 0/157 [00:00<?, ?batch/s]

Epoch 25 | Train Loss: 3.2216 | Train Acc: 27.10% | Test Acc: 26.84%

Epoch 26/30


  0%|          | 0/1563 [00:00<?, ?batch/s]

  0%|          | 0/157 [00:00<?, ?batch/s]

Epoch 26 | Train Loss: 3.1917 | Train Acc: 27.70% | Test Acc: 27.48%

Epoch 27/30


  0%|          | 0/1563 [00:00<?, ?batch/s]

  0%|          | 0/157 [00:00<?, ?batch/s]

Epoch 27 | Train Loss: 3.1614 | Train Acc: 28.25% | Test Acc: 27.81%

Epoch 28/30


  0%|          | 0/1563 [00:00<?, ?batch/s]

  0%|          | 0/157 [00:00<?, ?batch/s]

Epoch 28 | Train Loss: 3.1306 | Train Acc: 28.66% | Test Acc: 27.46%

Epoch 29/30


  0%|          | 0/1563 [00:00<?, ?batch/s]

  0%|          | 0/157 [00:00<?, ?batch/s]

Epoch 29 | Train Loss: 3.1066 | Train Acc: 29.28% | Test Acc: 28.38%

Epoch 30/30


  0%|          | 0/1563 [00:00<?, ?batch/s]

  0%|          | 0/157 [00:00<?, ?batch/s]

Epoch 30 | Train Loss: 3.0831 | Train Acc: 29.61% | Test Acc: 28.44%

Training complete!
