- **Reference:** [DifferentialAttentionHead — "Let's Implement the Differential Transformer Paper"](https://medium.com/@AykutCayir34/lets-implement-differential-transformer-paper-0e4499659604)

In [9]:
import torch
import torch.nn as nn
import torchvision.transforms as T
from torch.optim import Adam
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
import numpy as np

import wandb
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix
import itertools
import os
import yaml


In [10]:
# Load hyperparameters from the YAML file.
with open("vit_config.yaml", "r") as file:
    config = yaml.safe_load(file)

# Unpack every hyperparameter into a notebook-level name: the data, model and
# training cells below refer to these names directly.
d_model    = config["d_model"]
n_classes  = config["n_classes"]
img_size   = config["img_size"]
patch_size = config["patch_size"]
n_channels = config["n_channels"]
n_heads    = config["n_heads"]
n_layers   = config["n_layers"]
batch_size = config["batch_size"]
epochs     = config["epochs"]
alpha      = config["alpha"]   # used as the Adam learning rate in the training cell

# The run name encodes the architecture knobs. Indexing patch_size[0] assumes
# the YAML stores patch_size as a sequence (e.g. [h, w]) — TODO confirm.
exp_name = "vit-patchsize-{}-attention_head-{}-layer-{}".format(patch_size[0], n_heads, n_layers)

wandb.init(project="vit-image-classification", name=exp_name)

# Re-bind `config` to just the tracked hyperparameters and record them on the run.
config = dict(
    d_model=d_model,
    n_classes=n_classes,
    img_size=img_size,
    patch_size=patch_size,
    n_channels=n_channels,
    n_heads=n_heads,
    n_layers=n_layers,
    batch_size=batch_size,
    epochs=epochs,
    alpha=alpha,
)
wandb.config.update(config)


[34m[1mwandb[0m: [32m[41mERROR[0m The nbformat package was not found. It is required to save notebook history.


0,1
epoch,▁▁▁▁
train_loss,█▁
validation_loss,█▁

0,1
epoch,1.0
train_loss,2.20715
validation_loss,2.14251


In [11]:
# Dataset location. Set the CIFAR10_ROOT environment variable to run this
# notebook on another machine; the default keeps the original author's path.
DATA_ROOT = os.environ.get(
    "CIFAR10_ROOT",
    "/home/akash/ws/cv_assignment/assignment-5-MlLearnerAkash/Q1/dataset",
)

# Only ToTensor is applied (no Resize / Normalize). This assumes img_size in
# vit_config.yaml matches the native CIFAR-10 image size — TODO confirm.
transform = T.Compose([T.ToTensor()])

train_set = CIFAR10(root=DATA_ROOT, train=True, download=True, transform=transform)
test_set = CIFAR10(root=DATA_ROOT, train=False, download=True, transform=transform)

# Shuffle the training split only; the test split is evaluated in a fixed order.
train_loader = DataLoader(train_set, shuffle=True, batch_size=batch_size)
test_loader = DataLoader(test_set, shuffle=False, batch_size=batch_size)

Files already downloaded and verified
Files already downloaded and verified


In [12]:
from diff_vit import VisionTransformer


def train_transformer(transformer,save_path, criterion, epochs, optimizer):
   
    # Setup
    init_val_loss = np.inf
    os.makedirs(save_path, exist_ok=True)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Using device: ", device, 
          f"({torch.cuda.get_device_name(device)})" if torch.cuda.is_available() else "")
    
    # Ensure the model is on the proper device.
    transformer.to(device)
    
    # Training & Validation loop
    for epoch in range(epochs):
        transformer.train()
        training_loss = 0.0
        
        # Training loop
        for i, (inputs, labels) in enumerate(train_loader, 0):
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = transformer(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            training_loss += loss.item()
        
        avg_loss = training_loss / len(train_loader)
        print(f'Epoch {epoch + 1}/{epochs} - Train loss: {avg_loss:.3f}')
        wandb.log({"epoch": epoch + 1, "train_loss": avg_loss})
        
        # Validation loop
        transformer.eval()
        validation_loss = 0.0
        with torch.no_grad():
            for val_inputs, val_labels in test_loader:
                val_inputs, val_labels = val_inputs.to(device), val_labels.to(device)
                val_outputs = transformer(val_inputs)
                val_loss = criterion(val_outputs, val_labels)
                validation_loss += val_loss.item()
        
        avg_val_loss = validation_loss / len(test_loader)
        print(f'Epoch {epoch + 1}/{epochs} - Validation loss: {avg_val_loss:.3f}')
        wandb.log({"validation_loss": avg_val_loss})
        
        # Save best model based on validation loss
        if avg_val_loss < init_val_loss:
            init_val_loss = avg_val_loss
            torch.save(transformer.state_dict(), os.path.join(save_path, "best.pt"))
        
        # Log a few sample predictions from the last validation batch.
        sample_inputs = val_inputs[:4].detach().cpu()
        sample_labels = val_labels[:4].detach().cpu()
        sample_outputs = val_outputs[:4].detach().cpu()
        _, sample_preds = torch.max(sample_outputs, 1)
        
        samples = []
        for idx in range(len(sample_inputs)):
            # Convert image from (C, H, W) to (H, W, C) for plotting.
            image_np = sample_inputs[idx].permute(1, 2, 0).numpy()
            plt.figure(figsize=(2,2))
            plt.imshow(image_np)
            plt.title(f"GT: {sample_labels[idx].item()} | Pred: {sample_preds[idx].item()}")
            plt.axis("off")
            fig = plt.gcf()
            samples.append(wandb.Image(fig))
            plt.close(fig)
        
        wandb.log({"sample_predictions": samples, "epoch": epoch + 1})

    


In [13]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# NOTE(review): the trailing positional argument 10 is interpreted by
# diff_vit.VisionTransformer, which is not shown here. Confirm its meaning and
# pass it by keyword, or move it into vit_config.yaml.
transformer = VisionTransformer(d_model, n_classes, img_size, patch_size, n_channels, n_heads, n_layers, 10).to(device)

save_path = exp_name  # train_transformer writes best.pt inside this directory
criterion = nn.CrossEntropyLoss()
optimizer = Adam(transformer.parameters(), lr=alpha)

# Finish the W&B run even if training raises or is interrupted, so the run is
# always finalized.
try:
    train_transformer(transformer=transformer,
                      save_path=save_path,
                      criterion=criterion,
                      epochs=epochs,
                      optimizer=optimizer)
finally:
    wandb.finish()

Using device:  cuda (NVIDIA GeForce RTX 4060 Ti)
Epoch 1/10 - Train loss: 2.214
Epoch 1/10 - Validation loss: 2.170
Epoch 2/10 - Train loss: 2.149
Epoch 2/10 - Validation loss: 2.105
Epoch 3/10 - Train loss: 2.101
Epoch 3/10 - Validation loss: 2.076
Epoch 4/10 - Train loss: 2.082
Epoch 4/10 - Validation loss: 2.059
Epoch 5/10 - Train loss: 2.073
Epoch 5/10 - Validation loss: 2.044
Epoch 6/10 - Train loss: 2.060
Epoch 6/10 - Validation loss: 2.038
Epoch 7/10 - Train loss: 2.051
Epoch 7/10 - Validation loss: 2.051
Epoch 8/10 - Train loss: 2.055
Epoch 8/10 - Validation loss: 2.043
Epoch 9/10 - Train loss: 2.051
Epoch 9/10 - Validation loss: 2.029
Epoch 10/10 - Train loss: 2.048
Epoch 10/10 - Validation loss: 2.041
