## Vision Transformer (ViT)

In this assignment we will work with the Vision Transformer (ViT). We will build our own ViT model and train it on an image classification task.
The purpose of this homework is for you to get familiar with ViT and to prepare for the final project.

In [None]:
from google.colab import drive
drive.mount('/content/drive')

In [None]:
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
import torch.optim as optim

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

# ViT Implementation

The vision transformer can be separated into three parts; we will implement each part and combine them at the end.

For the implementation, feel free to experiment with different setups, as long as you use attention as the main computation unit and the ViT can be trained to perform the image classification task presented later.
You can read the ViT implementations in other libraries: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py and https://github.com/pytorch/vision/blob/main/torchvision/models/vision_transformer.py

## PatchEmbedding
PatchEmbedding is responsible for dividing the input image into non-overlapping patches and projecting them into a specified embedding dimension. It uses a 2D convolution layer with a kernel size and stride equal to the patch size. The output is a sequence of linear embeddings for each patch.

- Use a CNN layer with kernel size = patch_size and stride = patch_size
- Output is a sequence of linear embeddings for each patch - reshape
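
The patch-splitting step can be illustrated without PyTorch. The sketch below is a toy example on a single-channel image stored as nested lists; `patchify` is a hypothetical helper, not part of the assignment. In the notebook, the strided convolution performs this split and the linear projection in a single step.

```python
def patchify(image, patch_size):
    """Split a 2D image (list of rows) into flattened, non-overlapping patches, row-major."""
    height, width = len(image), len(image[0])
    assert height % patch_size == 0 and width % patch_size == 0
    patches = []
    for top in range(0, height, patch_size):
        for left in range(0, width, patch_size):
            # Flatten one patch_size x patch_size window into a single vector
            patch = [image[top + i][left + j]
                     for i in range(patch_size)
                     for j in range(patch_size)]
            patches.append(patch)
    return patches


image = [[r * 4 + c for c in range(4)] for r in range(4)]  # 4x4 toy "image"
patches = patchify(image, patch_size=2)
print(len(patches))  # (4 // 2) ** 2 patches
print(patches[0])    # top-left patch, flattened
```

The same arithmetic gives the sequence length of the real model: a 32x32 image with patch_size = 8 yields (32 // 8) ** 2 = 16 patches, or 17 tokens once the class token is prepended.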

In [None]:
class PatchEmbedding(nn.Module):
    def __init__(self, image_size, patch_size, in_channels, embed_dim):
      super().__init__()
      self.embed_dim = embed_dim
      self.image_size = image_size 
      self.patch_size = patch_size
      self.in_channels = in_channels

      self.cnn = nn.Sequential(nn.Conv2d(in_channels, embed_dim, patch_size, patch_size)) # Takes each patch and converts it into embedding dimension


    def forward(self, x):
      batch_size = x.shape[0]
      x = self.cnn(x) # (batch_size, embed_dim, num_patches_h, num_patches_w)
      x = x.permute(0, 2, 3, 1) # (batch_size, num_patches_h, num_patches_w, embed_dim); permute is not in-place, so the result must be assigned
      return x.reshape(batch_size, -1, self.embed_dim) # (batch_size, num_patches, embed_dim), i.e. (batch_size, sequence_length, embed_dim); reshape handles the non-contiguous tensor

## MultiHeadSelfAttention

This class implements the multi-head self-attention mechanism, which is a key component of the transformer architecture. It consists of multiple attention heads that independently compute scaled dot-product attention on the input embeddings. This allows the model to capture different aspects of the input at different positions. The attention outputs are concatenated and linearly transformed back to the original embedding size.

- Multiple attention heads independently compute scaled dot-product attention on input embeddings 
- Individual heads are concatenated and linearly transformed to the original embedding dimension
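- Compute the attention inside this class as well:
- matmul(Q, K^T), scaled by 1/sqrt(head_dim)
- softmax over the last (key) dimension, so that each query's weights sum to 1
- matmul the softmax output with V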
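
The three steps above can be sketched for a single head in plain Python. This is a minimal illustration on nested lists, not the batched PyTorch implementation used in the notebook; the function names are made up for this example.

```python
import math


def softmax(row):
    """Numerically stable softmax over one list of scores."""
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


def scaled_dot_product_attention(Q, K, V):
    """Q, K, V are lists of token vectors belonging to one attention head."""
    d_k = len(K[0])
    # scores[i][j] = <Q[i], K[j]> / sqrt(d_k)
    scores = [[sum(a * b for a, b in zip(q, k)) / math.sqrt(d_k) for k in K] for q in Q]
    # Softmax over the key axis: each query gets weights that sum to 1
    weights = [softmax(row) for row in scores]
    # Each output token is a weighted sum of the value vectors
    output = [[sum(w * v[j] for w, v in zip(w_row, V)) for j in range(len(V[0]))]
              for w_row in weights]
    return output, weights


Q = [[1.0, 0.0], [0.0, 1.0]]
K = [[1.0, 0.0], [0.0, 1.0]]
V = [[1.0, 2.0], [3.0, 4.0]]
out, attn = scaled_dot_product_attention(Q, K, V)
print(attn)
print(out)
```

Each query attends most strongly to the key it matches, so the first output row is pulled toward the first value vector and the second toward the second.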

In [None]:
class MultiHeadSelfAttention(nn.Module):
    def __init__(self, embed_dim, num_heads):
      super().__init__()
      self.embed_dim = embed_dim
      self.num_heads = num_heads
      # weight matrices
      self.wQ = nn.Linear(embed_dim, embed_dim) # Output size is typically kept equal to the embedding dimension
      self.wK = nn.Linear(embed_dim, embed_dim) # Output size is typically kept equal to the embedding dimension
      self.wV = nn.Linear(embed_dim, embed_dim) # Output size is typically kept equal to the embedding dimension
      # Each head is a slice of these projections, so the heads do not share weights
      ## The linear transformation on the input is applied first, and the result is then split into attention heads
      ### Why is W_q embed_dim * embed_dim when the input is batch_size * num_patches * embed_dim?
      ### nn.Linear acts on the last dimension only, so every patch is projected with the same weights and the output keeps the shape batch_size * num_patches * embed_dim

      self.wH = nn.Linear(embed_dim, embed_dim) # Output projection applied after the heads are concatenated


    def split_heads(self, x, num_heads):
      self.num_heads = num_heads
      self.rf_size = x.shape[2] // num_heads # each attention head works on 512/8 (for eight heads) = 64 of the embedding dimension
      # rf_size -> RF stands for receptive field; this is the per-head dimension, usually called head_dim
      return x.view(x.shape[0], -1, num_heads, self.rf_size).transpose(1, 2) # (batch_size, num_heads, num_patches, RF_size)

    def group_heads(self, x):
      # Takes (batch_size, num_heads, num_patches, RF_size) and concatenates the heads back together
      # (batch_size, num_heads, num_patches, RF_size) -> (batch_size, num_patches, num_heads, RF_size) -> (batch_size, num_patches, embed_dim)
      return x.transpose(1, 2).contiguous().view(x.shape[0], -1, self.num_heads * self.rf_size) # contiguous() is needed because view() requires contiguous memory after transpose
      


    def forward(self, x):
      ## The image has already been broken down into patches: x is batch_size * num_patches * embed_dim
      ## Project, then break down into multiple attention heads = self.num_heads

      Q = self.split_heads(self.wQ(x), self.num_heads) # batch_size * num_heads * num_patches * RF_size -> weights are shared across all tokens
      Q = Q / math.sqrt(Q.shape[-1]) # Scale by the square root of the per-head dimension, as in the language transformer
      K = self.split_heads(self.wK(x), self.num_heads)
      V = self.split_heads(self.wV(x), self.num_heads)
      A = torch.matmul(Q, K.transpose(-2, -1)) # With more than 2 dimensions, matmul treats all leading dimensions (batch, heads) as batch dimensions
      # So A is batch_size * num_heads * num_patches * num_patches
      # Now for each query patch, take a softmax over the key patches
      A = torch.softmax(A, dim = -1) # The normalizing constant was already applied to Q
      # Weighted sum of the values, then group the heads together
      H = torch.matmul(A, V) # (batch_size * num_heads * num_patches * num_patches) x (batch_size * num_heads * num_patches * RF_size) -> batch_size * num_heads * num_patches * RF_size
      H_cat = self.group_heads(H) # batch_size * num_patches * embed_dim
      H_cat = self.wH(H_cat) # Output projection -> again acts only on the last dimension
      return H_cat, A

## TransformerBlock
This class represents a single transformer layer. It includes a multi-head self-attention sublayer followed by a position-wise feed-forward network (MLP). Each sublayer is surrounded by residual connections.
You may also want to use layer normalization or another type of normalization.
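
As a rough illustration of one residual sublayer followed by layer normalization (the post-norm arrangement), here is a plain-Python sketch on a single token vector. It omits the learnable scale and shift of a real layer norm, and `post_norm_residual` and the toy sublayer are hypothetical names used only for this example.

```python
import math


def layer_norm(x, eps=1e-6):
    """Normalize one token vector to zero mean and (almost) unit variance."""
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)  # biased variance
    return [(v - mean) / math.sqrt(var + eps) for v in x]


def post_norm_residual(x, sublayer):
    """Compute LayerNorm(x + sublayer(x)) for a single token vector."""
    residual_sum = [a + b for a, b in zip(x, sublayer(x))]
    return layer_norm(residual_sum)


token = [1.0, 2.0, 3.0, 4.0]
out = post_norm_residual(token, lambda v: [0.5 * a for a in v])  # toy sublayer
print(out)
```

In a full block this pattern is applied twice: once around the attention sublayer and once around the MLP.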

In [None]:
# Implement GeLU (tanh approximation)
class GeLU(nn.Module):
    def forward(self, x):
        # math.sqrt avoids building a new tensor on every call
        return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
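
The formula used in the GeLU cell is the tanh approximation of GELU. A small standard-library sketch comparing it with the exact definition based on the error function is shown below; the function names are chosen for this example only.

```python
import math


def gelu_exact(x):
    """GELU(x) = x * Phi(x), with Phi the standard normal CDF."""
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))


def gelu_tanh(x):
    """Tanh approximation, same constants as the GeLU module in the notebook."""
    return 0.5 * x * (1.0 + math.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x ** 3)))


for x in (-3.0, -1.0, 0.0, 1.0, 3.0):
    print(x, gelu_exact(x), gelu_tanh(x))
```

On these sample points the two versions stay close to each other, which is why the approximation is a common drop-in.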

In [None]:
# pip install --upgrade torch torchvision

In [None]:
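# Position-wise feed-forward network (MLP). A linear layer applied to each token independently is equivalent to a convolution with kernel size 1, hence the names.
class CNN(nn.Module):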
    def __init__(self, d_model, hidden_dim):
        super().__init__()
        self.k1convL1 = nn.Linear(d_model,    hidden_dim)
        self.k1convL2 = nn.Linear(hidden_dim, d_model)
        self.activation = GeLU()

    def forward(self, x):
        x = self.k1convL1(x)
        x = self.activation(x)
        x = self.k1convL2(x)
        return x

In [None]:
class TransformerBlock(nn.Module):
    def __init__(self, embed_dim, num_heads, mlp_dim, dropout):
      super().__init__()
      self.embed_dim = embed_dim
      self.mlp_dim = mlp_dim
      self.num_heads = num_heads
      self.dropout = dropout
      self.mha = MultiHeadSelfAttention(embed_dim, num_heads) # Passing x to forward 
      self.cnn = CNN(embed_dim, mlp_dim)
      self.dropout1 = nn.Dropout(dropout)
      self.dropout2 = nn.Dropout(dropout)


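      ## Layer norm applied after the attention residual
      self.ln1 = nn.LayerNorm(embed_dim, eps=1e-6) # Small value added to the denominator to avoid division by 0

      ## Layer norm applied after the MLP residual
      self.ln2 = nn.LayerNorm(embed_dim, eps=1e-6)
      ## Dropout is applied to each sublayer output before the residual addition (see forward)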

    def forward(self, x):
      ## Residual connections
        h, _ = self.mha(x)
        h = self.dropout1(h) # Dropout after MHA
        x = self.ln1(x + h) # Residual + MHA passed to LayerNorm
        fcn_x = self.cnn(x)
        fcn_x = self.dropout2(fcn_x)
        x = self.ln2(x + fcn_x)
        return x
         

## VisionTransformer
This is the main class that assembles the entire Vision Transformer architecture. It starts with the PatchEmbedding layer to create patch embeddings from the input image. A special class token is added to the sequence, and positional embeddings are added to both the patch and class tokens. The sequence of patch embeddings is then passed through multiple TransformerBlock layers. The final output is the logits for all classes.

In [None]:
# This is not the right way to implement positional embeddings; ViT uses learnable positional embeddings 
# ten = np.array([np.zeros(512) for i in range(196)])
# nums = np.array([i for i in range(512)])
# for i in range(196):
#   ten[i, nums[i]] = 1

In [None]:
class VisionTransformer(nn.Module):
    def __init__(self, image_size, patch_size, in_channels, embed_dim, num_heads, mlp_dim, num_layers, num_classes, dropout=0.1):
        super().__init__()
        self.image_size = image_size
        self.patch_size = patch_size
        self.in_channels = in_channels
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.mlp_dim = mlp_dim
        self.num_layers = num_layers
        self.num_classes = num_classes
        self.dropout = dropout
        self.max_embedding = (image_size // patch_size) ** 2 # number of patches
        self.embeddings = PatchEmbedding(image_size, patch_size, in_channels, embed_dim) # Creates embeddings
        ## Learnable positional embeddings, one per patch (this draft has no CLS token or classifier head)
        self.pos_embeddings = nn.Embedding(self.max_embedding, embed_dim)

        ## Some number of transformer layers
        self.enc_layers = nn.ModuleList()
        for _ in range(num_layers):
          self.enc_layers.append(TransformerBlock(embed_dim, num_heads, mlp_dim, dropout))


    def forward(self, x):
        embeddings = self.embeddings(x) # batch_size * num_patches * embed_dim

        # Look up one positional embedding per patch index; the result broadcasts over the batch
        positions = torch.arange(embeddings.shape[1], device=embeddings.device)
        pos_embeddings = self.pos_embeddings(positions)
        x = embeddings + pos_embeddings

        # Run through the ViT blocks
        for i in range(self.num_layers):
          x = self.enc_layers[i](x)

        return x

In [None]:
class VisionTransformer(nn.Module):
    def __init__(self, image_size, patch_size, in_channels, embed_dim, num_heads, mlp_dim, num_layers, num_classes, dropout=0.1):
        super().__init__()
        self.image_size = image_size
        self.patch_size = patch_size
        self.in_channels = in_channels
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.mlp_dim = mlp_dim
        self.num_layers = num_layers
        self.num_classes = num_classes
        self.dropout = dropout

        self.embeddings = PatchEmbedding(image_size, patch_size, in_channels, embed_dim)  # Creates embeddings

        # +1 for the CLS token
        self.pos_embeddings = nn.Parameter(torch.zeros(1, (image_size // patch_size)**2 + 1, embed_dim))

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))

        self.enc_layers = nn.ModuleList()
        for _ in range(num_layers):
            self.enc_layers.append(TransformerBlock(embed_dim, num_heads, mlp_dim, dropout))

        self.mlp_head = nn.Sequential(
            nn.LayerNorm(embed_dim),
            nn.Linear(embed_dim, num_classes)
        )

    def forward(self, x, stochastic_depth_rate = 0.1):
        # Layer-skipping form of stochastic depth: during training, each block is skipped for the whole batch with probability stochastic_depth_rate
        batch_size = x.shape[0]
        embeddings = self.embeddings(x)

        # Add the CLS token
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        x = torch.cat((cls_tokens, embeddings), dim=1)

        # Add positional embeddings
        x = x + self.pos_embeddings

        # Run through the ViT blocks
        for i in range(self.num_layers):
            if self.training and torch.rand(1).item() < stochastic_depth_rate:
                continue
            x = self.enc_layers[i](x)

        # Classifier head
        x = x[:, 0]
        x = self.mlp_head(x)

        return x
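
The layer-skipping rule in the forward pass above can be simulated in isolation. The sketch below uses Python's random module instead of torch.rand, and `layers_executed` is a hypothetical helper that only reports which block indices would run; it is meant to show the control flow, not to replace the model code.

```python
import random


def layers_executed(num_layers, skip_rate, training, rng):
    """Return the indices of the blocks that run in one forward pass."""
    executed = []
    for i in range(num_layers):
        if training and rng.random() < skip_rate:
            continue  # the whole block is skipped for this batch
        executed.append(i)
    return executed


rng = random.Random(0)
print(layers_executed(9, 0.1, training=True, rng=rng))   # some blocks may be missing
print(layers_executed(9, 0.1, training=False, rng=rng))  # evaluation: every block runs
```

With a skip rate of 0.1 and 9 blocks, a training pass runs 9 * (1 - 0.1) = 8.1 blocks on average, while evaluation always uses the full depth.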

## Let's train the ViT!

We will train the ViT to do image classification on CIFAR-100. Feel free to change the optimizer and/or add other tricks to improve the training.

Example configurations tried so far, with the result recorded for each. All runs use in_channels = 3, num_classes = 100 and dropout = 0.1; "n/a" means the value was not recorded.

| Run | image_size | patch_size | embed_dim | num_heads | mlp_dim | num_layers | batch_size | learning_rate | weight_decay | num_epochs | Result |
|---|---|---|---|---|---|---|---|---|---|---|---|
| 1 | 224 | 16 | 384 | 2 | 3072 | 1 | 64 | n/a | n/a | n/a | 19.61% |
| 2 | 224 | 16 | 384 | 2 | 3072 | 1 | 64 | n/a | n/a | n/a | 17.92% |
| 3 | 224 | 16 | 384 | 4 | 1536 | 1 | 64 | n/a | n/a | n/a | BAD |
| 4 | 32 | 8 | 256 | 4 | 2048 | 12 | 128 | 0.0001 | 0.01 | 50 | 39% |
| 5 | 96 | 8 | 384 | 6 | 1536 | 1 | 64 | 0.0001 | 0.0001 | 70 | Overfit |

Scheduler settings noted with run 4: num_steps_per_epoch = len(trainloader), T_0 = num_steps_per_epoch * 10 (restart every 10 epochs), T_mult = 1, lr_scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=T_0, T_mult=T_mult, eta_min=1e-6).

In [None]:
image_size = 32
patch_size = 8 # 4 or 8
in_channels = 3
embed_dim = 300 # Try 512
num_heads = 12
mlp_dim = embed_dim*2 
num_layers = 9 # Increase num_layers
num_classes = 100
dropout = 0.25
batch_size = 256 # Lower batch size
num_epochs = 200

In [None]:
model = VisionTransformer(image_size, patch_size, in_channels, embed_dim, num_heads, mlp_dim, num_layers, num_classes, dropout).to(device)
input_tensor = torch.randn(1, in_channels, image_size, image_size).to(device)
output = model(input_tensor)
print(output.shape)

In [None]:
# Keep around 2M parameters - experimenting with approx 4M parameters
print (sum(p.numel() for p in model.parameters()))
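
The parameter count can also be estimated by hand, which helps when choosing embed_dim, mlp_dim and num_layers for a parameter budget. The sketch below assumes the layer layout of the second VisionTransformer definition above (patch convolution with bias, four embed_dim x embed_dim attention projections, a two-layer MLP, two layer norms per block, and a LayerNorm + Linear head); `vit_param_count` is a hypothetical helper.

```python
def vit_param_count(image_size, patch_size, in_channels, embed_dim,
                    mlp_dim, num_layers, num_classes):
    """Count trainable parameters for the layer layout described above."""
    num_patches = (image_size // patch_size) ** 2
    patch_embed = in_channels * embed_dim * patch_size ** 2 + embed_dim  # Conv2d weight + bias
    pos_and_cls = (num_patches + 1) * embed_dim + embed_dim              # positional table + class token
    attention = 4 * (embed_dim * embed_dim + embed_dim)                  # wQ, wK, wV, wH
    mlp = embed_dim * mlp_dim + mlp_dim + mlp_dim * embed_dim + embed_dim
    layer_norms = 2 * 2 * embed_dim                                      # two norms, scale + shift each
    block = attention + mlp + layer_norms
    head = 2 * embed_dim + embed_dim * num_classes + num_classes         # LayerNorm + Linear
    return patch_embed + pos_and_cls + num_layers * block + head


total = vit_param_count(image_size=32, patch_size=8, in_channels=3, embed_dim=300,
                        mlp_dim=600, num_layers=9, num_classes=100)
print(f"{total:,}")
```

The number of heads does not appear in the formula: splitting the projections into heads changes how the output is reshaped, not how many weights there are.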

In [None]:
# Load the CIFAR-100 dataset
from PIL import ImageOps
from torchvision.transforms import Lambda

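# Despite its name, this transform only applies PIL's autocontrast; it is not torchvision's AutoAugment policy
class AutoAugment: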
    def __call__(self, img):
        return ImageOps.autocontrast(img)

transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(degrees=15),
    transforms.RandomGrayscale(p=0.1),
    AutoAugment(),  # the class instance is already callable, so no Lambda wrapper is needed
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    transforms.RandomErasing(p=0.5, scale=(0.02, 0.33), ratio=(0.3, 3.3), value=0, inplace=False),
])

transform_test = transforms.Compose([
    transforms.Resize(image_size),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
trainset = datasets.CIFAR100(root='./data', train=True, download=True, transform=transform_train)
testset = datasets.CIFAR100(root='./data', train=False, download=True, transform=transform_test)

trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=2)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=2)

In [None]:
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
criterion = nn.CrossEntropyLoss(label_smoothing = 0.1)
learning_rate = 0.0002
weight_decay = 0.00005
optimizer = optim.AdamW(model.parameters(), lr = learning_rate, weight_decay = weight_decay)
num_steps_per_epoch = len(trainloader)
T_0 = num_steps_per_epoch * 10  # Restart every 10 epochs
T_mult = 1
lr_scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=T_0, T_mult=T_mult, eta_min=1e-6)
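
To see what this schedule does to the learning rate, the cosine formula can be evaluated directly. The sketch below covers only the T_mult = 1 case used here and assumes 50,000 CIFAR-100 training images with batch_size = 256; `cosine_warm_restart_lr` is a hypothetical helper, not a PyTorch function.

```python
import math


def cosine_warm_restart_lr(step, base_lr, eta_min, T_0):
    """Learning rate after `step` scheduler steps, for T_mult = 1."""
    t_cur = step % T_0  # position inside the current cycle
    return eta_min + (base_lr - eta_min) * (1 + math.cos(math.pi * t_cur / T_0)) / 2


steps_per_epoch = math.ceil(50000 / 256)  # assumed: 50,000 training images, batch_size = 256
T_0 = steps_per_epoch * 10                # restart every 10 epochs
for step in (0, T_0 // 2, T_0 - 1, T_0):
    print(step, cosine_warm_restart_lr(step, base_lr=2e-4, eta_min=1e-6, T_0=T_0))
```

Because lr_scheduler.step() is called once per batch in the training loop, T_0 has to be expressed in steps rather than epochs, which is why it is multiplied by num_steps_per_epoch.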

In [None]:
# from torch.optim.lr_scheduler import LambdaLR
# warmup_epochs = 10
# base_lr = 0.0005
# def lr_schedule(epoch):
#     if epoch < warmup_epochs:
#         return (batch_size / 256) * (epoch + 1) / warmup_epochs
#     else:
#         t = (epoch - warmup_epochs) / (num_epochs - warmup_epochs)
#         return (batch_size / 256) * 0.5 * (1 + math.cos(math.pi * t))

# lr_scheduler = LambdaLR(optimizer, lr_schedule)

In [None]:
checkpoint_path = "/content/drive/MyDrive/DL-HW4/best_model5.pth"
# checkpoint = torch.load(checkpoint_path)
# # # Restore the model and optimizer states
# model.load_state_dict(checkpoint['model_state_dict'])
# optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

- best_model1: 40% - Model starts to overfit after 20th epoch

In [None]:
# Train the model (num_epochs is set in the configuration cell above)
best_val_acc = 0
train_losses = []
val_losses = []
val_accuracies = []
for epoch in range(num_epochs):
    model.train()
    running_train_loss = 0.0
    correct = 0
    total = 0
    for i, data in enumerate(trainloader, 0):
        inputs, labels = data
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()

        outputs = model(inputs)
        _, predicted = torch.max(outputs.data, 1)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        lr_scheduler.step()
        running_train_loss = running_train_loss + loss.item()
        correct += (predicted==labels).sum().item()
        total += labels.size(0)
    train_acc = 100*correct/total
    print(f"Epoch: {epoch + 1}, Training Accuracy: {train_acc:.2f}%")
    avg_train_loss = running_train_loss / len(trainloader)
    train_losses.append(avg_train_loss)
    # Validate the model (the CIFAR-100 test split serves as the validation set here)
    model.eval()
    correct = 0
    total = 0
    running_val_loss = 0.0
    with torch.no_grad():
        for data in testloader:
            images, labels = data
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)
            running_val_loss += loss.item()
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    val_losses.append(running_val_loss/len(testloader))
    val_acc = 100 * correct / total
    val_accuracies.append(val_acc)
    print(f"Epoch: {epoch + 1}, Validation Accuracy: {val_acc:.2f}%")

    # Save the best model
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        torch.save({
            'epoch': epoch + 1,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
        }, checkpoint_path)

In [None]:
import matplotlib.pyplot as plt

# Plot training and validation loss
plt.figure(figsize=(12, 6))
plt.plot(train_losses, label="Training Loss")
plt.plot(val_losses, label="Validation Loss")
plt.xlabel("Epochs")
plt.ylabel("Loss")
plt.title("Training and Validation Loss")
plt.legend()
plt.show()

# Plot validation accuracy
plt.figure(figsize=(12, 6))
plt.plot(val_accuracies)
plt.xlabel("Epochs")
plt.ylabel("Accuracy")
plt.title("Validation Accuracy")
plt.show()

Please submit your best_model.pth with this notebook, and report the best test results you get.