## Vision Transformer (ViT)

In this assignment, we will work with the Vision Transformer (ViT). We will build our own ViT model and train it on an image classification task.
The purpose of this homework is to familiarize you with ViT and prepare you for the final project.

In [1]:
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cuda


# ViT Implementation

The Vision Transformer can be separated into three parts; we will implement each part and combine them at the end.

For the implementation, feel free to experiment with different setups, as long as attention is the main computation unit and the ViT can be trained to perform the image classification task presented later.
You can read about the ViT implementations in other libraries: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py and https://github.com/pytorch/vision/blob/main/torchvision/models/vision_transformer.py

## PatchEmbedding
PatchEmbedding is responsible for dividing the input image into non-overlapping patches and projecting each patch into a specified embedding dimension. It uses a 2D convolution layer whose kernel size and stride both equal the patch size. The output is a sequence containing one linear embedding per patch.
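A convolution whose kernel size and stride both equal the patch size is the same as "cut the image into patches, flatten each one, and apply a shared linear layer". The sketch below checks this on a random input by comparing the convolution with an explicit `F.unfold` + matrix multiply using the same weights. The small sizes are arbitrary and chosen only for illustration.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Illustrative sizes (not the notebook's config): batch, channels, image size, patch size, embed dim
B, C, H, P, D = 2, 3, 8, 4, 16

x = torch.randn(B, C, H, H)
conv = nn.Conv2d(C, D, kernel_size=P, stride=P)

# Conv path: (B, D, H/P, H/P) -> (B, num_patches, D)
conv_out = conv(x).flatten(2).transpose(1, 2)

# Explicit path: extract non-overlapping P x P patches as C*P*P vectors,
# then apply the conv's weights as an ordinary linear map.
patches = F.unfold(x, kernel_size=P, stride=P)   # (B, C*P*P, num_patches)
patches = patches.transpose(1, 2)                 # (B, num_patches, C*P*P)
weight = conv.weight.reshape(D, C * P * P)        # (D, C*P*P), same (C, kh, kw) ordering
linear_out = patches @ weight.T + conv.bias       # (B, num_patches, D)

print(conv_out.shape)                                   # torch.Size([2, 4, 16])
print(torch.allclose(conv_out, linear_out, atol=1e-5))  # True
```

The convolution form is simply a convenient, efficient way to write the per-patch linear projection.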

In [3]:
class PatchEmbedding(nn.Module):
    def __init__(self, image_size, patch_size, in_channels, embed_dim):
        super(PatchEmbedding, self).__init__()
        self.image_size = image_size
        self.patch_size = patch_size
        self.num_patches = (image_size // patch_size) ** 2

        # kernel_size = stride = patch_size ensures non-overlapping patches
        self.proj = nn.Conv2d(
            in_channels,
            embed_dim,
            kernel_size=patch_size,
            stride=patch_size
        )

    def forward(self, x):
        # x: (B, in_channels, image_size, image_size)
        x = self.proj(x)       # (B, embed_dim, image_size // patch_size, image_size // patch_size)
        x = x.flatten(2)       # (B, embed_dim, num_patches)
        x = x.transpose(1, 2)  # (B, num_patches, embed_dim)
        return x

## MultiHeadSelfAttention

This class implements the multi-head self-attention mechanism, a key component of the transformer architecture. It consists of multiple attention heads that independently compute scaled dot-product attention on the input embeddings. Because each head works on its own slice of the query, key, and value projections, different heads can capture different relationships between positions. The head outputs are concatenated and linearly projected back to the original embedding size.
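Each head computes `softmax(Q Kᵀ / sqrt(d)) V`, where `d` is the head dimension; in the class below, `self.scale = self.head_dim ** -0.5` is the `1 / sqrt(d)` factor. The sketch below computes this by hand on random tensors and compares it with `torch.nn.functional.scaled_dot_product_attention`, which is available in PyTorch 2.0 and later. The sizes are illustrative only.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Illustrative sizes: batch, heads, tokens, head dim
B, H, N, d = 2, 4, 5, 8
q, k, v = (torch.randn(B, H, N, d) for _ in range(3))

# softmax(Q K^T / sqrt(d)) V, computed independently for every head
scores = q @ k.transpose(-2, -1) / d ** 0.5  # (B, H, N, N)
weights = scores.softmax(dim=-1)             # each row is a distribution over the N tokens
manual = weights @ v                         # (B, H, N, d)

# PyTorch's built-in version (default scale is 1 / sqrt(d))
fused = F.scaled_dot_product_attention(q, k, v)

print(manual.shape)                              # torch.Size([2, 4, 5, 8])
print(torch.allclose(manual, fused, atol=1e-5))  # True
```

Replacing the three explicit lines in `forward` with `F.scaled_dot_product_attention` is one option if you want a fused kernel, as long as your PyTorch version provides it.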

In [4]:
class MultiHeadSelfAttention(nn.Module):
    def __init__(self, embed_dim, num_heads):
        super(MultiHeadSelfAttention, self).__init__()
        assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"

        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads

        # linear projections for Q, K, V
        self.qkv = nn.Linear(embed_dim, embed_dim * 3)
        # output projection
        self.proj = nn.Linear(embed_dim, embed_dim)
        # scaling factor for attention scores
        self.scale = self.head_dim ** -0.5

    def forward(self, x):
        # x shape: (B, N, D) where B=batch, N=num_patches+1, D=embed_dim
        B, N, D = x.shape

        # get Q, K, V in one go
        qkv = self.qkv(x)  # (B, N, 3*D)

        # split into Q, K, V
        q, k, v = qkv.chunk(3, dim=-1)  # each is (B, N, D)

        # reshape for multi-head attention: (B, N, D) -> (B, num_heads, N, head_dim)
        q = q.view(B, N, self.num_heads, self.head_dim).transpose(1, 2)
        k = k.view(B, N, self.num_heads, self.head_dim).transpose(1, 2)
        v = v.view(B, N, self.num_heads, self.head_dim).transpose(1, 2)

        # scaled dot-product attention
        attn_scores = (q @ k.transpose(-2, -1)) * self.scale  # (B, num_heads, N, N)
        attn_weights = F.softmax(attn_scores, dim=-1)         # (B, num_heads, N, N)
        attn_output = attn_weights @ v                        # (B, num_heads, N, head_dim)

        # concatenate heads: (B, num_heads, N, head_dim) -> (B, N, D)
        attn_output = attn_output.transpose(1, 2).reshape(B, N, D)

        # final projection
        output = self.proj(attn_output)
        return output

## TransformerBlock
This class represents a single transformer layer. It consists of a multi-head self-attention sublayer followed by a position-wise feed-forward network (MLP), each wrapped in a residual connection.
You may also want to use layer normalization or another type of normalization; the block below applies LayerNorm before each sublayer (pre-norm).
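The two common ways to combine LayerNorm with a residual connection differ only in where the normalization sits. The sketch below uses a single `nn.Linear` as a stand-in for the attention or MLP sublayer (an illustrative assumption) to contrast post-norm, as in the original Transformer, with pre-norm, which is what ViT and the `TransformerBlock` here use.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

D = 16
norm = nn.LayerNorm(D)
sublayer = nn.Linear(D, D)  # stand-in for attention or the MLP

def post_norm(x):
    # residual add first, then normalize the sum
    return norm(x + sublayer(x))

def pre_norm(x):
    # normalize the sublayer input only; the residual path stays un-normalized
    return x + sublayer(norm(x))

x = torch.randn(2, 5, D)
print(post_norm(x).shape, pre_norm(x).shape)
```

With pre-norm the residual stream is never normalized inside the blocks, which is why the full model applies one extra LayerNorm (`self.norm`) before the classification head.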

In [5]:
class TransformerBlock(nn.Module):
    def __init__(self, embed_dim, num_heads, mlp_dim, dropout):
        super(TransformerBlock, self).__init__()

        # Multi-head self-attention
        self.attn = MultiHeadSelfAttention(embed_dim, num_heads)

        # MLP (Feed-forward network)
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, mlp_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(mlp_dim, embed_dim),
            nn.Dropout(dropout)
        )

        # layer normalizations
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)

        # dropout for attention output
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # first sub-layer: Multi-head attention with residual connection
        # pre-norm: normalize before attention
        attn_output = self.attn(self.norm1(x))
        x = x + self.dropout(attn_output)
        # second sub-layer: MLP with residual connection
        # pre-norm: normalize before MLP
        mlp_output = self.mlp(self.norm2(x))
        x = x + mlp_output

        return x

## VisionTransformer
This is the main class that assembles the entire Vision Transformer architecture. It starts with the PatchEmbedding layer to create patch embeddings from the input image. A learnable class token is prepended to the sequence, and learnable positional embeddings are added to both the patch and class tokens. The sequence is then passed through multiple TransformerBlock layers. Finally, the layer-normalized class token is fed to a linear head, which outputs the logits for all classes.
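With the notebook's configuration (32x32 inputs, 4x4 patches, `embed_dim=256`), the sequence has 64 patch tokens plus one class token. The sketch below traces the shapes through the token-assembly step using random stand-in patch embeddings; the encoder blocks are skipped, so the logits are only illustrative.

```python
import torch
import torch.nn as nn

# Sizes from the notebook's config; batch_size is arbitrary here
batch_size, image_size, patch_size, embed_dim, num_classes = 2, 32, 4, 256, 100
num_patches = (image_size // patch_size) ** 2  # 8 * 8 = 64

patch_tokens = torch.randn(batch_size, num_patches, embed_dim)  # stand-in for PatchEmbedding output
cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))

# expand() broadcasts the single learned token across the batch without copying
cls_tokens = cls_token.expand(batch_size, -1, -1)        # (2, 1, 256)
tokens = torch.cat([cls_tokens, patch_tokens], dim=1)    # (2, 65, 256)
tokens = tokens + pos_embed                              # pos_embed broadcasts over the batch

# after the encoder, only position 0 (the class token) feeds the head
head = nn.Linear(embed_dim, num_classes)
logits = head(tokens[:, 0])                              # (2, 100)
print(tokens.shape, logits.shape)
```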

In [6]:
class VisionTransformer(nn.Module):
    def __init__(self, image_size, patch_size, in_channels, embed_dim, num_heads, mlp_dim, num_layers, num_classes, dropout=0.1):
        super(VisionTransformer, self).__init__()

        # patch embedding layer
        self.patch_embed = PatchEmbedding(image_size, patch_size, in_channels, embed_dim)
        num_patches = self.patch_embed.num_patches

        # class token
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))

        # positional embeddings; add +1 for the class token
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))

        # positional dropout
        self.dropout = nn.Dropout(dropout)

        # transformer encoder blocks
        self.blocks = nn.ModuleList([
            TransformerBlock(embed_dim, num_heads, mlp_dim, dropout)
            for _ in range(num_layers)
        ])

        # layer normalization before classification head
        self.norm = nn.LayerNorm(embed_dim)

        # classification head
        self.head = nn.Linear(embed_dim, num_classes)

        # initialize weights
        self._init_weights()

    def _init_weights(self):
        # initialize positional embeddings with truncated normal
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        nn.init.trunc_normal_(self.cls_token, std=0.02)

        # initialize the classification head
        nn.init.trunc_normal_(self.head.weight, std=0.02)
        nn.init.zeros_(self.head.bias)

    def forward(self, x):
        # x shape: (batch_size, in_channels, image_size, image_size)
        batch_size = x.shape[0]

        # patch embedding
        x = self.patch_embed(x) # (batch_size, num_patches, embed_dim)

        # prepend class token
        cls_tokens = self.cls_token.expand(batch_size, -1, -1) # (1, 1, embed_dim) -> (batch_size, 1, embed_dim)
        # concatenate: (batch_size, num_patches + 1, embed_dim)
        x = torch.cat([cls_tokens, x], dim=1)

        # add positional embeddings
        x = x + self.pos_embed
        x = self.dropout(x)

        # pass through transformer blocks
        for block in self.blocks:
            x = block(x)

        # apply final layer normalization
        x = self.norm(x)

        # extract class token and pass through classification head
        cls_token_final = x[:, 0]
        logits = self.head(cls_token_final)

        return logits

## Let's train the ViT!

We will train the ViT on image classification with CIFAR-100. Feel free to change the optimizer and/or add other tricks to improve training.

In [7]:
image_size = 32
patch_size = 4
in_channels = 3
embed_dim = 256
num_heads = 8
mlp_dim = 512
num_layers = 6
num_classes = 100
dropout = 0.1

batch_size = 64
num_epochs = 200

In [8]:
model = VisionTransformer(image_size, patch_size, in_channels, embed_dim, num_heads, mlp_dim, num_layers, num_classes, dropout).to(device)
input_tensor = torch.randn(1, in_channels, image_size, image_size).to(device)
output = model(input_tensor)
print(output.shape)
print(f"Number of parameters = {round(sum(p.numel() for p in model.parameters()) / 10**6, 2)}M")

torch.Size([1, 100])
Number of parameters = 3.22M
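The 3.22M figure can be checked by hand. The helper `vit_param_count` below is hypothetical (not part of the notebook); it sums the weights and biases of every layer in the model defined above, with defaults matching the configuration cell. `num_heads` does not appear because splitting `embed_dim` across heads changes how the projections are used, not their size.

```python
def vit_param_count(image_size=32, patch_size=4, in_channels=3, embed_dim=256,
                    mlp_dim=512, num_layers=6, num_classes=100):
    """Count the parameters of the VisionTransformer defined above, layer by layer."""
    D = embed_dim
    num_patches = (image_size // patch_size) ** 2
    patch_embed = in_channels * patch_size ** 2 * D + D  # Conv2d weight + bias
    cls_and_pos = D + (num_patches + 1) * D              # cls_token + pos_embed
    attn = (D * 3 * D + 3 * D) + (D * D + D)             # qkv Linear + output projection
    mlp = (D * mlp_dim + mlp_dim) + (mlp_dim * D + D)    # the two MLP Linear layers
    norms = 2 * (2 * D)                                  # norm1 + norm2 (weight + bias each)
    block = attn + mlp + norms
    final_norm = 2 * D
    head = D * num_classes + num_classes
    return patch_embed + cls_and_pos + num_layers * block + final_norm + head

n = vit_param_count()
print(n, f"{n / 1e6:.2f}M")  # 3218276 3.22M
```

Almost all of the budget (about 0.53M per block, 3.16M in total) sits in the six transformer blocks, so `num_layers`, `embed_dim`, and `mlp_dim` are the knobs that control model size.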


In [9]:
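# Load the CIFAR-100 dataset
# Note: the AutoAugment policy and the Normalize mean/std below are CIFAR-10 settings reused for CIFAR-100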
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.Resize(image_size),
    transforms.RandomHorizontalFlip(),
    # transforms.RandAugment(num_ops=2, magnitude=9),
    transforms.AutoAugment(transforms.AutoAugmentPolicy.CIFAR10),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

transform_test = transforms.Compose([
    transforms.Resize(image_size),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

trainset = datasets.CIFAR100(root='./data', train=True, download=True, transform=transform_train)
testset = datasets.CIFAR100(root='./data', train=False, download=True, transform=transform_test)

trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=2)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=2)

100%|██████████| 169M/169M [00:03<00:00, 49.4MB/s]
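The `Normalize` constants used above are widely used CIFAR-10 channel statistics rather than values measured on CIFAR-100. If you want dataset-specific values, a sketch like the following computes the per-channel mean and standard deviation from any loader that yields unnormalized tensors. `channel_stats` is a hypothetical helper, and it is checked here on random stand-in images to avoid a download.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

def channel_stats(loader):
    """Per-channel mean and population std over all pixels yielded by `loader`.

    Assumes batches of (images, labels) with images shaped (B, C, H, W),
    scaled to [0, 1] by ToTensor and not yet normalized or augmented.
    """
    total, sums, sq_sums = 0, 0.0, 0.0
    for images, _ in loader:
        x = images.double()
        total += x.numel() // x.shape[1]  # pixels per channel in this batch
        sums = sums + x.sum(dim=(0, 2, 3))
        sq_sums = sq_sums + (x ** 2).sum(dim=(0, 2, 3))
    mean = sums / total
    std = (sq_sums / total - mean ** 2).clamp_min(0).sqrt()  # Var = E[x^2] - E[x]^2
    return mean, std

# Check on random stand-in "images" instead of downloading CIFAR-100
torch.manual_seed(0)
images = torch.rand(10, 3, 4, 4)
loader = DataLoader(TensorDataset(images, torch.zeros(10)), batch_size=4)
mean, std = channel_stats(loader)

flat = images.double().transpose(0, 1).reshape(3, -1)  # (C, all pixels)
ref_mean = flat.mean(dim=1)
ref_std = ((flat - ref_mean[:, None]) ** 2).mean(dim=1).sqrt()
print(mean, std)
```

On the real data, you would pass a loader over `datasets.CIFAR100(root='./data', train=True, transform=transforms.ToTensor())`, with no augmentation, and plug the resulting values into `transforms.Normalize`.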


In [10]:
# Define the loss function and optimizer
criterion = nn.CrossEntropyLoss(label_smoothing=0.1)

optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=6e-4,
    weight_decay=0.2
)

warmup_ratio = 0.1
warmup_epochs = int(num_epochs * warmup_ratio)

warmup_scheduler = torch.optim.lr_scheduler.LinearLR(
    optimizer,
    start_factor=0.1,
    end_factor=1.0,
    total_iters=warmup_epochs
)

cosine_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer,
    T_max=num_epochs - warmup_epochs,
    eta_min=1e-6
)

scheduler = torch.optim.lr_scheduler.SequentialLR(
    optimizer,
    schedulers=[warmup_scheduler, cosine_scheduler],
    milestones=[warmup_epochs]
)
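To see the learning rate each epoch actually receives, the sketch below writes the warmup + cosine schedule in closed form. `lr_at_epoch` is a hypothetical helper whose defaults mirror the cell above; it assumes `scheduler.step()` is called once per epoch, and its values should match `SequentialLR` up to floating-point differences, which you can confirm by printing `scheduler.get_last_lr()` on your PyTorch version.

```python
import math

def lr_at_epoch(epoch, base_lr=6e-4, num_epochs=200, warmup_ratio=0.1,
                start_factor=0.1, eta_min=1e-6):
    """Learning rate used during `epoch` (0-indexed) by LinearLR warmup then CosineAnnealingLR."""
    warmup_epochs = int(num_epochs * warmup_ratio)
    if epoch < warmup_epochs:
        # LinearLR: the multiplier grows linearly from start_factor to 1.0
        factor = start_factor + (1.0 - start_factor) * epoch / warmup_epochs
        return base_lr * factor
    # CosineAnnealingLR, starting from base_lr at the milestone
    t = epoch - warmup_epochs
    T_max = num_epochs - warmup_epochs
    return eta_min + (base_lr - eta_min) * (1 + math.cos(math.pi * t / T_max)) / 2

for e in (0, 10, 20, 110, 199):
    print(e, f"{lr_at_epoch(e):.3e}")
```

Because the training loop reads `scheduler.get_last_lr()` right after `scheduler.step()`, the `LR` printed for epoch *k* is the rate that will be used in epoch *k + 1*.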

In [None]:
from tqdm.auto import tqdm

progress_bar = tqdm(total=len(trainloader) * num_epochs)

# Train the model
best_val_acc = 0
for epoch in range(num_epochs):
    model.train()
    train_loss = 0.0
    train_correct = 0
    train_total = 0
    for i, data in enumerate(trainloader, 0):
        inputs, labels = data
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()

        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()

        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        # Track training metrics
        train_loss += loss.item()
        predicted = outputs.argmax(dim=1)
        train_total += labels.size(0)
        train_correct += (predicted == labels).sum().item()

        # Update progress bar
        progress_bar.set_postfix({
            'epoch': f'{epoch+1}/{num_epochs}',
            'train_loss': f'{loss.item():.4f}',
            'train_acc': f'{100 * train_correct / train_total:.4f}%'
        })
        progress_bar.update(1)

    # Compute epoch metrics
    train_acc = 100 * train_correct / train_total
    avg_train_loss = train_loss / len(trainloader)

    scheduler.step()
    current_lr = scheduler.get_last_lr()[0]

    # Validate the model
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data in testloader:
            images, labels = data
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            predicted = outputs.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    val_acc = 100 * correct / total
    progress_bar.write(f"Epoch [{epoch+1}/{num_epochs}] - train_loss: {avg_train_loss:.4f}, train_acc: {train_acc:.2f}%, val_acc: {val_acc:.2f}%, LR: {current_lr:.6f}")

    # Save the best model
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        torch.save(model.state_dict(), "best_model.pth")
        progress_bar.write(f"✓ Epoch {epoch+1}: New best model saved! Val Acc: {val_acc:.2f}%")

progress_bar.close()
print(f"Training completed! Best validation accuracy: {best_val_acc:.2f}%")

Link to `best_model.pth`: https://drive.google.com/file/d/1NBnCtDfnf-hHGZ3dhk_gBl5O78ZMgUlL/view?usp=drive_link

Link to other trained models: https://drive.google.com/drive/folders/1Ozm-WQh_RIqfuyVQq5bEIYuS7vRbw77z?usp=sharing

In [11]:
import glob

# Function to evaluate the model on the test set; returns accuracy in percent
def evaluate_model(model, testloader, device):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in testloader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            predicted = outputs.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    val_acc = 100 * correct / total
    print(f"Val Accuracy: {val_acc:.2f}%")
    return val_acc

# Loop over all .pth checkpoints in the working directory
for model_path in glob.glob("*.pth"):
    print(f"Evaluating: {model_path}")

    # load the saved state dict into the existing model (same architecture required)
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.to(device)

    # Evaluate
    evaluate_model(model, testloader, device)

Evaluating: best_model.pth
Val Accuracy: 67.32%


Please submit your best_model.pth with this notebook, and report the best test result you obtain.