In [1]:
# Filename: custom_vit_tiny_imagenet.py

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torchvision import transforms, datasets
from torch.utils.data import DataLoader
from tqdm import tqdm
import matplotlib.pyplot as plt
import numpy as np
import os

# Seed PyTorch's global RNG so repeated runs start from the same state
torch.manual_seed(42)

# Pick the compute device: CUDA when available, CPU otherwise
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

# Training-time preprocessing: random resized crop to 64x64 and random
# horizontal flip, then conversion to a tensor (no Normalize step is applied)
train_transforms = transforms.Compose(
    [
        transforms.RandomResizedCrop(64),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ]
)

# Validation-time preprocessing: resize and tensor conversion only
val_transforms = transforms.Compose(
    [
        transforms.Resize(64),
        transforms.ToTensor(),
    ]
)

# Samples per mini-batch, shared by the training and validation loaders
batch_size = 64

# Both splits are read as image folders (one sub-directory per class);
# change the roots below to match where the dataset lives on disk
train_dataset = datasets.ImageFolder(
    transform=train_transforms,
    root='tiny-imagenet-200/train',
)
val_dataset = datasets.ImageFolder(
    transform=val_transforms,
    root='tiny-imagenet-200/val',
)

# Training batches are reshuffled every epoch; validation order stays fixed
train_loader = DataLoader(
    dataset=train_dataset,
    num_workers=4,
    shuffle=True,
    batch_size=batch_size,
)
val_loader = DataLoader(
    dataset=val_dataset,
    num_workers=4,
    shuffle=False,
    batch_size=batch_size,
)

# Size of the classifier output passed to the model below
num_classes = 200

# Helper function to extract patches
def extract_patches(x, patch_size):
    """
    Cut every image into non-overlapping square tiles and flatten each tile.

    x: (batch_size, channels, height, width)
    patch_size: side length of a tile, used as both window and stride

    Returns: (batch_size, num_patches, patch_size*patch_size*channels).
    Tiles are ordered row by row over the image grid, and the values inside
    one tile are ordered channel first, then tile row, then tile column.
    Rows/columns that do not fill a complete tile are not included.
    """
    n_images, _, img_h, img_w = x.shape
    grid_rows = img_h // patch_size
    grid_cols = img_w // patch_size

    # Slide a window over height, then over width:
    # (B, C, H, W) -> (B, C, grid_rows, grid_cols, patch, patch)
    tiles = x.unfold(2, patch_size, patch_size)
    tiles = tiles.unfold(3, patch_size, patch_size)

    # Bring the grid axes in front of the channel axis so that each tile's
    # values are adjacent, then collapse grid -> sequence and tile -> vector
    tiles = tiles.permute(0, 2, 3, 1, 4, 5)
    return tiles.reshape(n_images, grid_rows * grid_cols, -1)

# Custom cross-patch attention module
class CrossPatchAttention(nn.Module):
    """Multi-head scaled dot-product cross-attention.

    Queries are projected from one token sequence and keys/values from a
    second one, so the output contains one token per query token.

    Args:
        emb_dim: Embedding width of both input sequences and of the output.
        num_heads: Number of attention heads; must divide ``emb_dim``.

    Raises:
        ValueError: If ``emb_dim`` is not a multiple of ``num_heads``. Without
            this check the mismatch only surfaces later as a reshape error
            inside ``forward``.
    """

    def __init__(self, emb_dim, num_heads):
        super().__init__()
        if emb_dim % num_heads != 0:
            raise ValueError(
                f"emb_dim ({emb_dim}) must be divisible by "
                f"num_heads ({num_heads})"
            )
        self.num_heads = num_heads
        self.head_dim = emb_dim // num_heads
        # Standard 1/sqrt(d_head) scaling of the dot products
        self.scale = self.head_dim ** -0.5

        self.to_q = nn.Linear(emb_dim, emb_dim)
        self.to_k = nn.Linear(emb_dim, emb_dim)
        self.to_v = nn.Linear(emb_dim, emb_dim)
        self.to_out = nn.Linear(emb_dim, emb_dim)

    def _split_heads(self, tokens):
        """Reshape (B, N, D) into per-head form (B, num_heads, N, head_dim)."""
        batch, length, _ = tokens.shape
        return tokens.view(
            batch, length, self.num_heads, self.head_dim
        ).transpose(1, 2)

    def forward(self, q_input, kv_input):
        """Attend from ``q_input`` tokens to ``kv_input`` tokens.

        Args:
            q_input: Query tokens, shape (B, N_q, D).
            kv_input: Key/value tokens, shape (B, N_kv, D).

        Returns:
            Tensor of shape (B, N_q, D).
        """
        B, N_q, D = q_input.shape

        q = self._split_heads(self.to_q(q_input))
        k = self._split_heads(self.to_k(kv_input))
        v = self._split_heads(self.to_v(kv_input))

        # (B, heads, N_q, N_kv): similarity of every query to every key,
        # normalised over the key axis
        attn_scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale
        attn_weights = torch.softmax(attn_scores, dim=-1)

        # Weighted sum of values, then merge the heads back into D channels
        attn_output = torch.matmul(attn_weights, v)
        attn_output = attn_output.transpose(1, 2).reshape(B, N_q, D)
        return self.to_out(attn_output)

# Vision Transformer Model
class VisionTransformer(nn.Module):
    """Two-scale Vision Transformer classifier.

    The image is tokenised twice: once with the finer patch size (these
    tokens act as queries) and once with the coarser patch size (these
    tokens act as keys/values). Every block applies cross-attention from the
    fine tokens to the coarse tokens followed by an MLP, both with residual
    connections. Only the fine tokens are updated from block to block; the
    coarse tokens are re-normalised by each block's first LayerNorm but are
    otherwise left unchanged. The fine tokens are mean-pooled and classified.

    Args:
        img_size: Side length of the (square) input images, used to size the
            positional embeddings.
        patch_sizes: Pair ``(query_patch, key_value_patch)`` of patch side
            lengths. Defaults to ``(8, 16)``.
        emb_dim: Token embedding width.
        num_heads: Attention heads per block.
        num_classes: Number of output logits.
        depth: Number of transformer blocks.

    Raises:
        ValueError: If ``patch_sizes`` does not contain exactly two entries.

    Note:
        The ``*_8`` / ``*_16`` attribute names refer to the default patch
        sizes and are kept regardless of ``patch_sizes`` so that parameter
        names in saved state dicts stay the same.
    """

    def __init__(self, img_size=64, patch_sizes=(8, 16),
                 emb_dim=256, num_heads=8, num_classes=200, depth=6):
        super().__init__()
        if len(patch_sizes) != 2:
            raise ValueError(
                "patch_sizes must contain exactly two entries "
                f"(query patch, key/value patch), got {len(patch_sizes)}"
            )
        patch_q, patch_kv = patch_sizes
        self.emb_dim = emb_dim
        self.patch_size_q = patch_q
        self.patch_size_kv = patch_kv

        # Positional Encodings (one learned vector per patch position)
        num_patches_q = (img_size // patch_q) ** 2
        num_patches_kv = (img_size // patch_kv) ** 2
        self.pos_embed_8 = nn.Parameter(
            torch.randn(1, num_patches_q, emb_dim)
        )
        self.pos_embed_16 = nn.Parameter(
            torch.randn(1, num_patches_kv, emb_dim)
        )

        # Patch Embedding Layers (flattened 3-channel patch -> emb_dim)
        self.patch_embed_8 = nn.Linear(patch_q * patch_q * 3, emb_dim)
        self.patch_embed_16 = nn.Linear(patch_kv * patch_kv * 3, emb_dim)

        # Transformer Encoder Layers: (pre-attn norm, attention, pre-MLP
        # norm, MLP) per block
        self.transformer_blocks = nn.ModuleList([
            nn.ModuleList([
                nn.LayerNorm(emb_dim),
                CrossPatchAttention(emb_dim, num_heads),
                nn.LayerNorm(emb_dim),
                nn.Sequential(
                    nn.Linear(emb_dim, emb_dim * 4),
                    nn.GELU(),
                    nn.Linear(emb_dim * 4, emb_dim)
                )
            ]) for _ in range(depth)
        ])

        # Classification Head
        self.classifier = nn.Sequential(
            nn.LayerNorm(emb_dim),
            nn.Linear(emb_dim, num_classes)
        )

    def forward(self, x):
        """Classify a batch of images.

        Args:
            x: Images of shape (B, 3, img_size, img_size).

        Returns:
            Logits of shape (B, num_classes).
        """
        # Fine-scale tokens (queries): (B, num_patches_q, D)
        q = self.patch_embed_8(
            extract_patches(x, patch_size=self.patch_size_q)
        )
        q = q + self.pos_embed_8

        # Coarse-scale tokens (keys/values): (B, num_patches_kv, D)
        kv = self.patch_embed_16(
            extract_patches(x, patch_size=self.patch_size_kv)
        )
        kv = kv + self.pos_embed_16

        # Pass through transformer blocks
        for ln1, attn, ln2, mlp in self.transformer_blocks:
            # Cross-attention: Q from finer patches, K/V from coarser
            # patches; the same LayerNorm is applied to both inputs
            q = q + attn(ln1(q), ln1(kv))
            # Position-wise MLP
            q = q + mlp(ln2(q))

        # Pooling (mean over the token axis) -> (B, D)
        pooled = q.mean(dim=1)

        # Classification head -> (B, num_classes)
        return self.classifier(pooled)

# Accuracy metrics
def accuracy(output, target, topk=(1, 3, 5)):
    """Compute top-k accuracy for each k in ``topk``.

    Args:
        output: Class scores of shape (batch_size, num_classes).
        target: Ground-truth class indices of shape (batch_size,).
        topk: Iterable of k values to evaluate.

    Returns:
        A list of Python floats, one per entry of ``topk`` and in the same
        order, each the fraction of samples whose target is among the k
        highest-scoring classes. A k larger than ``num_classes`` is treated
        as "all classes" instead of raising inside ``topk``.
    """
    # Never ask topk for more candidates than there are classes
    maxk = min(max(topk), output.size(1))
    batch_size = target.size(0)

    # pred[i, j] is the class with the (j+1)-th highest score for sample i
    _, pred = output.topk(maxk, dim=1, largest=True, sorted=True)
    hits = pred.eq(target.view(-1, 1))  # Shape: (batch_size, maxk)

    res = []
    for k in topk:
        # Slicing past maxk simply keeps every available column
        correct_k = hits[:, :k].float().sum()
        res.append((correct_k / batch_size).item())
    return res  # Returns a list of accuracies for topk

# Training function
def train(model, device, train_loader, optimizer, criterion, epoch):
    """Run one optimisation pass over ``train_loader``.

    Args:
        model: Network to train; switched to training mode here.
        device: Device the batches are moved to before the forward pass.
        train_loader: Iterable of ``(inputs, targets)`` batches.
        optimizer: Optimizer stepped once per batch.
        criterion: Loss function called as ``criterion(outputs, targets)``.
            Assumed to return the per-sample mean over the batch (the default
            for ``nn.CrossEntropyLoss``) - confirm if a different reduction
            is used.
        epoch: Epoch number, used only in the progress-bar label.

    Returns:
        Tuple ``(avg_loss, top1_acc, top3_acc, top5_acc)``. Each value is a
        per-sample average over the epoch: batch results are weighted by
        batch size so that a smaller final batch does not count as much as a
        full one.

    Raises:
        ZeroDivisionError: If the loader yields no samples.
    """
    model.train()
    loss_sum = 0.0
    top1_sum = 0.0
    top3_sum = 0.0
    top5_sum = 0.0
    num_samples = 0

    for inputs, targets in tqdm(train_loader, desc=f"Epoch {epoch} [Train]"):
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()

        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        # Metrics: convert per-batch means back to per-batch totals
        n = targets.size(0)
        acc1, acc3, acc5 = accuracy(outputs, targets)
        loss_sum += loss.item() * n
        top1_sum += acc1 * n
        top3_sum += acc3 * n
        top5_sum += acc5 * n
        num_samples += n

    avg_loss = loss_sum / num_samples
    avg_top1_acc = top1_sum / num_samples
    avg_top3_acc = top3_sum / num_samples
    avg_top5_acc = top5_sum / num_samples

    return avg_loss, avg_top1_acc, avg_top3_acc, avg_top5_acc

# Validation function
def validate(model, device, val_loader, criterion):
    """Evaluate ``model`` on ``val_loader`` without updating its weights.

    Args:
        model: Network to evaluate; switched to evaluation mode here.
        device: Device the batches are moved to before the forward pass.
        val_loader: Iterable of ``(inputs, targets)`` batches.
        criterion: Loss function called as ``criterion(outputs, targets)``.
            Assumed to return the per-sample mean over the batch (the default
            for ``nn.CrossEntropyLoss``) - confirm if a different reduction
            is used.

    Returns:
        Tuple ``(avg_loss, top1_acc, top3_acc, top5_acc)``. Each value is a
        per-sample average over the whole loader: batch results are weighted
        by batch size so that a smaller final batch does not count as much
        as a full one.

    Raises:
        ZeroDivisionError: If the loader yields no samples.
    """
    model.eval()
    loss_sum = 0.0
    top1_sum = 0.0
    top3_sum = 0.0
    top5_sum = 0.0
    num_samples = 0

    # Gradients are not needed for evaluation
    with torch.no_grad():
        for inputs, targets in tqdm(val_loader, desc="Validation"):
            inputs, targets = inputs.to(device), targets.to(device)

            outputs = model(inputs)
            loss = criterion(outputs, targets)

            # Metrics: convert per-batch means back to per-batch totals
            n = targets.size(0)
            acc1, acc3, acc5 = accuracy(outputs, targets)
            loss_sum += loss.item() * n
            top1_sum += acc1 * n
            top3_sum += acc3 * n
            top5_sum += acc5 * n
            num_samples += n

    avg_loss = loss_sum / num_samples
    avg_top1_acc = top1_sum / num_samples
    avg_top3_acc = top3_sum / num_samples
    avg_top5_acc = top5_sum / num_samples

    return avg_loss, avg_top1_acc, avg_top3_acc, avg_top5_acc

# Build the network on the selected device, plus its loss and optimizer
model = VisionTransformer(num_classes=num_classes)
model = model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=3e-4)

# Step decay: the learning rate is multiplied by 0.1 every 5 epochs
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)

# Per-epoch history of loss and top-1/3/5 accuracy for both splits
num_epochs = 10
train_losses, val_losses = [], []
train_acc1_list, val_acc1_list = [], []
train_acc3_list, val_acc3_list = [], []
train_acc5_list, val_acc5_list = [], []

# Training loop: one training pass, one validation pass, then an LR update
for epoch in range(1, num_epochs + 1):
    train_loss, train_acc1, train_acc3, train_acc5 = train(
        model, device, train_loader, optimizer, criterion, epoch
    )
    val_loss, val_acc1, val_acc3, val_acc5 = validate(
        model, device, val_loader, criterion
    )

    scheduler.step()

    # Record this epoch's numbers in the matching history list
    for history, value in (
        (train_losses, train_loss),
        (val_losses, val_loss),
        (train_acc1_list, train_acc1),
        (val_acc1_list, val_acc1),
        (train_acc3_list, train_acc3),
        (val_acc3_list, val_acc3),
        (train_acc5_list, train_acc5),
        (val_acc5_list, val_acc5),
    ):
        history.append(value)

    # Report accuracies as percentages
    print(f"\nEpoch {epoch} Summary:")
    print(
        f"Train Loss: {train_loss:.4f}, "
        f"Top-1 Acc: {train_acc1*100:.2f}%, "
        f"Top-3 Acc: {train_acc3*100:.2f}%, "
        f"Top-5 Acc: {train_acc5*100:.2f}%"
    )
    print(
        f"Val Loss: {val_loss:.4f}, "
        f"Top-1 Acc: {val_acc1*100:.2f}%, "
        f"Top-3 Acc: {val_acc3*100:.2f}%, "
        f"Top-5 Acc: {val_acc5*100:.2f}%"
    )

# Plot accuracy curves: blue = train, red = validation; the line style
# (solid / dashed / dash-dot) distinguishes top-1 / top-3 / top-5
epochs = range(1, num_epochs + 1)

plt.figure(figsize=(12, 6))
for series, line_fmt, legend_label in (
    (train_acc1_list, 'b-', 'Train Top-1 Acc'),
    (val_acc1_list, 'r-', 'Val Top-1 Acc'),
    (train_acc3_list, 'b--', 'Train Top-3 Acc'),
    (val_acc3_list, 'r--', 'Val Top-3 Acc'),
    (train_acc5_list, 'b-.', 'Train Top-5 Acc'),
    (val_acc5_list, 'r-.', 'Val Top-5 Acc'),
):
    plt.plot(epochs, series, line_fmt, label=legend_label)
plt.title('Training and Validation Accuracy')
plt.xlabel('Epochs')
plt.ylabel('Accuracy')
plt.legend()
plt.grid(True)
plt.show()


Using device: cuda


Epoch 1 [Train]:   0%|          | 0/1563 [00:00<?, ?it/s]

: 