In [12]:
import torch.nn as nn
from torch.nn import Softmax, GELU
from dataclasses import dataclass
from einops import rearrange, repeat
import torch

# Model

In [None]:
@dataclass
class ImageParams:
    """Spatial size and channel count of the images fed to the model."""
    width: int  # image width; PatchEmbedding and ViT divide it by the patch size
    height: int  # image height; used by DataHandler for the training crop size
    in_channel: int  # number of input channels of the patch-embedding Conv2d
@dataclass
class ModelParameters:
    """Architecture settings shared by every sub-module of the ViT."""
    patch_size: int  # kernel size and stride of the patch-embedding Conv2d
    inner_dim: int  # token embedding width (D) used throughout the model
    transformer_layers: int  # number of EncoderBlock layers stacked in Transformer
    num_head: int  # attention heads in MHA; inner_dim must be divisible by it
    embed_dropout: float  # dropout applied in ViT after adding positional embeddings
    attn_dropout: float  # dropout used by MHA (attention probabilities and output projection)
    mlp_dropout: float  # dropout used inside the MLP block
@dataclass
class Hyperparameters:
    """Training-run settings; every field is required (no defaults)."""
    batch_size: int  # batch size handed to DataHandler
    out_classes: int  # output width of the ViT classification head
    epochs: int  # number of epochs given to OneCycleLR
    learning_rate: float  # AdamW learning rate, also used as OneCycleLR max_lr
    weight_decay: float  # AdamW weight decay


In [None]:
# Alternative configuration: 256-dim embeddings, 6 encoder layers, 4 heads.
# NOTE: the following cell re-assigns img_info / mparams / hparams, so this
# configuration is only active if that cell is skipped.
img_info = ImageParams(width=32, height=32, in_channel=3)
mparams = ModelParameters(patch_size=4, inner_dim=256, transformer_layers=6, num_head=4, embed_dropout=0.1, attn_dropout=0, mlp_dropout=0.1)
# Hyperparameters declares no defaults, so epochs, learning_rate and
# weight_decay must be given explicitly (omitting them raises TypeError).
# The training values mirror the configuration in the following cell.
hparams = Hyperparameters(batch_size=1024, out_classes=10, epochs=2, learning_rate=5e-4*(1024/512), weight_decay=0.05)

In [None]:
# Active configuration used by the cells below: 32x32 RGB inputs,
# 192-dim tokens, 12 encoder layers and 3 attention heads.
img_info = ImageParams(
    width=32,
    height=32,
    in_channel=3,
)
mparams = ModelParameters(
    patch_size=4,
    inner_dim=192,
    transformer_layers=12,
    num_head=3,
    embed_dropout=0.1,
    attn_dropout=0,
    mlp_dropout=0.1,
)
# The learning rate is the 5e-4 base value scaled by (1024/512).
hparams = Hyperparameters(
    batch_size=1024,
    out_classes=10,
    epochs=2,
    learning_rate=5e-4*(1024/512),
    weight_decay=0.05,
)
class PatchEmbedding(nn.Module):
    """Convert an image batch into patch tokens with a leading class token.

    A strided Conv2d (kernel == stride == patch size) projects every
    non-overlapping patch to ``inner_dim`` channels; the spatial grid is then
    flattened row by row and a learnable class token is prepended.

    Args:
        mparams: object exposing ``patch_size`` and ``inner_dim``.
        hparams: accepted for a uniform constructor signature; not used here.
        img_info: object exposing ``width``, ``height`` and ``in_channel``.
    """

    def __init__(self, mparams, hparams, img_info):
        super().__init__()
        self.patch_size = mparams.patch_size
        # Kept for backward compatibility: this attribute has always held the width.
        self.img_size = img_info.width
        # Count patches along each axis separately so that non-square images
        # get the right token count (previously width was used for both axes).
        self.num_patches = (img_info.height // self.patch_size) * (img_info.width // self.patch_size)
        self.D = mparams.inner_dim
        self.patch_embed = nn.Conv2d(
            in_channels=img_info.in_channel,
            out_channels=self.D,
            kernel_size=self.patch_size,
            stride=self.patch_size
        )
        # NOTE(review): uniform [0, 1) init via torch.rand is kept as-is so
        # existing runs stay reproducible; confirm this init is intended.
        self.cls_token = nn.Parameter(torch.rand(1, 1, self.D))

    def forward(self, x):
        """Embed ``x`` of shape [B, C, H, W] into tokens of shape [B, N + 1, D].

        N is the number of patches produced by the convolution; index 0 along
        the token axis is the class token.
        """
        batch = x.shape[0]
        # Broadcast the single learnable token over the batch.
        cls_token = self.cls_token.expand(batch, -1, -1)
        x = self.patch_embed(x)            # [B, D, H/p, W/p]
        x = x.flatten(2).transpose(1, 2)   # [B, (H/p)*(W/p), D], row-major patch order
        return torch.cat((cls_token, x), dim=1)
class MHA(nn.Module):
    """Multi-head scaled dot-product self-attention.

    Args:
        mparams: object exposing ``inner_dim``, ``num_head`` and ``attn_dropout``.
        hparams: accepted for a uniform constructor signature; not used here.
    """

    def __init__(self, mparams, hparams):
        super().__init__()
        self.D = mparams.inner_dim
        self.num_head = mparams.num_head
        assert self.D % self.num_head == 0 , 'Inner dimensions and number of attention head need to be perfectly divisible'
        self.head_size = self.D // self.num_head
        self.all_head_size = self.head_size * self.num_head
        # Set up QKV
        self.query = nn.Linear(in_features=self.D, out_features=self.all_head_size)
        self.key = nn.Linear(in_features=self.D, out_features=self.all_head_size)
        self.value = nn.Linear(in_features=self.D, out_features=self.all_head_size)
        # The projection consumes the concatenated heads; all_head_size equals D
        # because of the divisibility assertion above.
        self.output = nn.Linear(in_features=self.all_head_size, out_features=self.D)
        self.attn_dropout = nn.Dropout(mparams.attn_dropout)
        # The output projection shares the attention dropout rate.
        self.proj_dropout = nn.Dropout(mparams.attn_dropout)
        self.softmax = Softmax(dim=-1)

    def _split_heads(self, t):
        """Reshape [B, N, h*d] into [B, h, N, d]."""
        batch, tokens, _ = t.shape
        return t.reshape(batch, tokens, self.num_head, self.head_size).transpose(1, 2)

    def forward(self, x, mask= None):
        """Apply self-attention to ``x`` of shape [B, N, D]; returns [B, N, D].

        Args:
            x: token tensor of shape [B, N, D].
            mask: optional tensor; positions where ``mask == 0`` are excluded
                from attention. It is applied to the [B, h, N, N] score tensor,
                so it presumably must broadcast to that shape - confirm with callers.
        """
        q = self._split_heads(self.query(x))
        k = self._split_heads(self.key(x))
        v = self._split_heads(self.value(x))
        attn_score = torch.matmul(q, k.transpose(-1, -2)) / self.head_size**0.5
        if mask is not None:
            # Use the most negative value representable in the score dtype:
            # a literal -1e9 cannot be stored in float16 and makes masked_fill
            # fail when the scores are half precision (e.g. under autocast).
            fill_value = torch.finfo(attn_score.dtype).min
            attn_score = attn_score.masked_fill(mask == 0, fill_value)
        attn_probs = self.softmax(attn_score)
        attn_probs = self.attn_dropout(attn_probs)
        # Weighted sum of the values: [B, h, N, d]
        context = torch.matmul(attn_probs, v)
        # Merge the heads back into one feature axis: [B, N, h*d]
        batch, _, tokens, _ = context.shape
        context = context.transpose(1, 2).reshape(batch, tokens, self.all_head_size)
        output = self.output(context)
        output = self.proj_dropout(output)
        return output
class MLP(nn.Module):
    """Feed-forward block: Linear(D -> 4D), GELU, dropout, Linear(4D -> D), dropout.

    Args:
        mparams: object exposing ``inner_dim`` and ``mlp_dropout``.
        hparams: accepted for a uniform constructor signature; not used here.
    """

    def __init__(self, mparams, hparams):
        super().__init__()
        width = mparams.inner_dim
        expanded = 4 * width
        drop_p = mparams.mlp_dropout
        self.D = width
        self.hidden_dim = expanded
        # Stage order is significant: it fixes the state-dict keys (net.0, net.3).
        stages = [
            nn.Linear(width, expanded),
            nn.GELU(),
            nn.Dropout(drop_p),
            nn.Linear(expanded, width),
            nn.Dropout(drop_p),
        ]
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        """Apply the feed-forward stack to ``x``; the last dimension stays D."""
        transformed = self.net(x)
        return transformed
class EncoderBlock(nn.Module):
    """One encoder layer: LayerNorm -> attention and LayerNorm -> MLP, each
    wrapped in a residual connection (normalisation is applied before the
    sub-layer).

    Args:
        mparams: model settings forwarded to MHA and MLP; ``inner_dim`` sizes
            the two LayerNorms.
        hparams: forwarded unchanged to MHA and MLP.
    """

    def __init__(self, mparams, hparams):
        super().__init__()
        width = mparams.inner_dim
        self.norm1 = nn.LayerNorm(width)
        self.attn = MHA(mparams=mparams, hparams=hparams)
        self.norm2 = nn.LayerNorm(width)
        self.ffn = MLP(mparams=mparams, hparams=hparams)

    def forward(self, x):
        """Return the block output; same shape as ``x``."""
        # Attention sub-layer with skip connection.
        attended = self.attn(self.norm1(x)) + x
        # Feed-forward sub-layer with skip connection.
        return self.ffn(self.norm2(attended)) + attended
class Transformer(nn.Module):
    """Stack of ``mparams.transformer_layers`` EncoderBlock layers applied in order.

    Args:
        mparams: model settings; ``transformer_layers`` sets the depth and the
            object is forwarded to every EncoderBlock.
        hparams: forwarded unchanged to every EncoderBlock.
    """

    def __init__(self, mparams, hparams):
        super().__init__()
        self.depth = mparams.transformer_layers
        blocks = []
        for _ in range(self.depth):
            blocks.append(EncoderBlock(mparams=mparams, hparams=hparams))
        self.layers = nn.ModuleList(blocks)

    def forward(self, x):
        """Feed ``x`` through each layer in turn and return the final result."""
        hidden = x
        for block in self.layers:
            hidden = block(hidden)
        return hidden
class ViT(nn.Module):
    """Vision Transformer classifier.

    Pipeline: patch embedding (with class token) -> add learnable positional
    embedding -> dropout -> Transformer encoder -> LayerNorm + Linear head on
    the class token.

    Args:
        mparams: object exposing ``patch_size``, ``inner_dim`` and
            ``embed_dropout``; also forwarded to the sub-modules.
        hparams: object exposing ``out_classes``; also forwarded to the sub-modules.
        img_info: object exposing ``width`` and ``height``; also forwarded to
            PatchEmbedding.
    """

    def __init__(self, mparams, hparams, img_info):
        super().__init__()
        patch_size = mparams.patch_size
        # Patches per axis are counted separately so the positional embedding
        # also matches the token count for non-square images (previously the
        # width was squared).
        num_patches = (img_info.height // patch_size) * (img_info.width // patch_size)
        # +1 position for the class token prepended by PatchEmbedding.
        self.pos_embed = nn.Parameter(torch.rand(1, num_patches + 1, mparams.inner_dim))
        self.patch_embed = PatchEmbedding(mparams=mparams, hparams=hparams, img_info=img_info)
        self.transformer = Transformer(mparams=mparams, hparams=hparams)
        self.mlp_head = nn.Sequential(
            nn.LayerNorm(mparams.inner_dim),
            nn.Linear(mparams.inner_dim, hparams.out_classes)
        )
        self.embed_dropout = nn.Dropout(mparams.embed_dropout)

    def forward(self, x):
        """Return class logits of shape [B, out_classes] for an image batch ``x``."""
        x = self.patch_embed(x)
        x = x + self.pos_embed
        x = self.embed_dropout(x)
        x = self.transformer(x)
        # Classify from the class token (index 0); mean pooling over tokens,
        # x.mean(dim=1), would be the alternative.
        cls_token_output = x[:, 0]
        return self.mlp_head(cls_token_output)


In [24]:
# Smoke test: push a random batch of two images through the model and check
# that the logits have shape [batch, out_classes].
test_tensor = torch.rand(2, img_info.in_channel, img_info.height, img_info.width)
print(f'test tensor shape: {test_tensor.shape}')
vit = ViT(mparams=mparams, hparams=hparams, img_info=img_info)
# Call the module itself rather than .forward() so nn.Module hooks run, and
# skip autograd bookkeeping because no gradients are needed here.
with torch.no_grad():
    output = vit(test_tensor)
print(f'Output Shape: {output.shape}')
assert tuple(output.shape) == (test_tensor.shape[0], hparams.out_classes)

test tensor shape: torch.Size([2, 3, 32, 32])
Output Shape: torch.Size([2, 10])


# Data Loader

In [None]:
from torch.utils.data import DataLoader


# Trainer

In [None]:
from tqdm import tqdm
import torch.optim as optim
from torch.optim.lr_scheduler import OneCycleLR
import random
import numpy as np
class Trainer:
    """Runs training and validation epochs for a classification model.

    Args:
        model: the ``nn.Module`` to train.
        train_loader: iterable of ``(images, labels)`` training batches. Labels
            may be class indices of shape [B] or per-class probabilities of
            shape [B, C] (as produced by MixUp / CutMix collation).
        val_loader: iterable of ``(images, labels)`` validation batches.
        optimizer: optimizer stepped once per training batch.
        scaler: gradient scaler exposing ``scale``/``step``/``update``; may be
            ``None`` to run plain ``backward()`` / ``optimizer.step()``.
        lr_scheduler: scheduler stepped once per training batch.
        epochs: number of epochs run by ``_run_trainer``.
    """

    def __init__(self, model, train_loader, val_loader, optimizer, scaler, lr_scheduler, epochs):
        self.device = self._select_device()
        self.model = model
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.criterion = nn.CrossEntropyLoss()
        self.optim = optimizer
        self.scaler = scaler
        self.lr_sch = lr_scheduler
        self.epochs = epochs

    @staticmethod
    def _select_device():
        """Pick 'cuda', then 'mps', then 'cpu', whichever is available first."""
        if torch.cuda.is_available():
            return 'cuda'
        # Same availability check as set_seed; guarded because builds without
        # the MPS backend do not expose torch.backends.mps.
        mps_backend = getattr(torch.backends, 'mps', None)
        if mps_backend is not None and mps_backend.is_available():
            return 'mps'
        return 'cpu'

    def _forward_loss(self, images, targets):
        """Return ``(logits, loss)``; uses automatic mixed precision on CUDA only."""
        if self.device == 'cuda':
            # autocast needs an explicit device type.
            with torch.autocast(device_type='cuda'):
                logits = self.model(images)
                loss = self.criterion(logits, targets)
        else:
            # Full-precision path so CPU / MPS runs also get logits and a loss.
            logits = self.model(images)
            loss = self.criterion(logits, targets)
        return logits, loss

    @staticmethod
    def _count_correct(logits, targets):
        """Count predictions whose arg-max class matches the target class."""
        predicted = logits.argmax(dim=1)
        # Soft targets ([B, C]) are reduced to their dominant class so they can
        # be compared with the [B] prediction vector.
        hard_targets = targets.argmax(dim=1) if targets.ndim > 1 else targets
        return (predicted == hard_targets).sum().item()

    @staticmethod
    def _epoch_metrics(total_loss, num_batches, total_correct, total_samples):
        """Return ``(mean batch loss, accuracy in percent)``; 0.0 for empty loaders."""
        avg_loss = total_loss / num_batches if num_batches else 0.0
        accuracy = 100 * total_correct / total_samples if total_samples else 0.0
        return avg_loss, accuracy

    def _train_one_epoch(self):
        """Train for one pass over ``train_loader``; returns ``(avg_loss, accuracy)``."""
        self.model.train()
        total_loss = 0.0
        total_correct = 0
        total_samples = 0
        for images, labels in tqdm(self.train_loader, desc='Training'):
            images, labels = images.to(self.device), labels.to(self.device)
            logits, loss = self._forward_loss(images, labels)
            # Backpropagation
            self.optim.zero_grad()
            if self.scaler is not None:
                self.scaler.scale(loss).backward()
                self.scaler.step(self.optim)
                self.scaler.update()
            else:
                loss.backward()
                self.optim.step()
            # Update lr (the scheduler is stepped per batch, not per epoch)
            self.lr_sch.step()
            # Metrics
            total_loss += loss.item()
            total_samples += labels.size(0)
            total_correct += self._count_correct(logits.detach(), labels)

        return self._epoch_metrics(total_loss, len(self.train_loader), total_correct, total_samples)

    def _validate(self):
        """Evaluate on ``val_loader`` without gradients; returns ``(avg_loss, accuracy)``."""
        self.model.eval()
        total_loss = 0.0
        total_correct = 0
        total_samples = 0

        with torch.no_grad():
            for images, labels in tqdm(self.val_loader, desc="Validating"):
                images, labels = images.to(self.device), labels.to(self.device)

                # Validation runs in full precision (no AMP).
                outputs = self.model(images)
                loss = self.criterion(outputs, labels)

                # Metrics
                total_loss += loss.item()
                total_samples += labels.size(0)
                total_correct += self._count_correct(outputs, labels)

        return self._epoch_metrics(total_loss, len(self.val_loader), total_correct, total_samples)

    def _run_trainer(self, model_path):
        """Train for ``self.epochs`` epochs and checkpoint the best model.

        The model's ``state_dict`` is written to ``model_path`` whenever the
        validation accuracy improves on the best value seen so far.
        """
        model = self.model
        model.to(self.device)
        best_val_accuracy = 0.0
        print("Starting Training...")
        for epoch in range(self.epochs):
            print(f"--- Epoch {epoch+1}/{self.epochs} ---")

            train_loss, train_acc = self._train_one_epoch()
            val_loss, val_acc = self._validate()

            print(f"Epoch {epoch+1} Summary:")
            print(f"  Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}%")
            print(f"  Val Loss:   {val_loss:.4f} | Val Acc:   {val_acc:.2f}%")
            print(f"  Current Learning Rate: {self.optim.param_groups[0]['lr']:.6f}")

            # Save the model if it has the best validation accuracy so far
            if val_acc > best_val_accuracy:
                best_val_accuracy = val_acc
                torch.save(model.state_dict(), model_path)
                print(f"✅ New best model saved with accuracy: {best_val_accuracy:.2f}%")

        print("--- Training Finished ---")
        print(f"Best Validation Accuracy: {best_val_accuracy:.2f}%")

    @staticmethod
    def set_seed(seed: int, strict: bool = True):
        """
        Sets the seed for all relevant libraries for reproducibility.
        Args:
            seed (int): The seed value.
            strict (bool): If True, enforces full determinism which may cause errors on MPS/CUDA if an operation is not supported.
        """
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        # Specific settings for CUDA
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(seed)
            if strict:
                import os
                # Deterministic cuBLAS kernels require this variable; keep any
                # value the user already exported. It presumably has to be set
                # before cuBLAS is first used - call set_seed before any CUDA work.
                os.environ.setdefault("CUBLAS_WORKSPACE_CONFIG", ":4096:8")
                torch.backends.cudnn.benchmark = False
                torch.backends.cudnn.deterministic = True
        # Specific settings for MPS
        if torch.backends.mps.is_available():
            torch.mps.manual_seed(seed)
        # Enforce deterministic algorithms globally
        if strict:
            torch.use_deterministic_algorithms(True)


# Dataloader

In [None]:
import torch
from torch.utils.data import DataLoader, Subset
from torchvision import datasets, transforms
import collections
from torchvision.transforms.v2 import MixUp, CutMix

# Compatibility shim for libraries that still import the abstract base classes
# from `collections`: Python 3.10 removed those aliases, so restore
# `collections.Sequence` from `collections.abc` when it is missing.
# (The previous check was inverted: it tested `collections.abc`, which always
# has `Sequence`, so the shim never took effect.)
import collections.abc
if not hasattr(collections, 'Sequence'):
    collections.Sequence = collections.abc.Sequence

class DataHandler:
    """Builds the CIFAR-10 train / validation / test DataLoaders.

    Args:
        image_information: object exposing ``width`` and ``height``; sets the
            output size of the training crop.
        batch_size: batch size used by all three loaders.
        split_seed: seed of the private generator that shuffles the indices
            before the train/validation split, so the split is identical
            across runs regardless of the global seed.
    """

    def __init__(self, image_information, batch_size, split_seed=42):
        self.img_info = image_information
        self.batch_size = batch_size
        self.split_seed = split_seed

    def _train_transform(self):
        """Augmentation pipeline applied to training images."""
        return transforms.Compose([
            transforms.RandomHorizontalFlip(p=0.5),
            # torchvision expects the crop size as (height, width).
            transforms.RandomResizedCrop((self.img_info.height, self.img_info.width), scale=(0.8, 1.0)),
            # num_ops = number of augmentations to apply
            # magnitude = strength of augmentations
            transforms.RandAugment(num_ops=2, magnitude=9), # can be tweaked after checking how similar the images are
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ])

    def _val_test_transform(self):
        """Deterministic pipeline (tensor conversion + normalisation) for evaluation."""
        return transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ])

    @staticmethod
    def _split_indices(num_samples, train_fraction=0.9, seed=42):
        """Shuffle ``range(num_samples)`` with a seeded generator and split it.

        Returns ``(train_indices, val_indices)`` as lists; the first
        ``int(train_fraction * num_samples)`` shuffled indices go to training.
        A dedicated generator is used so the global RNG state is untouched.
        """
        generator = torch.Generator().manual_seed(seed)
        shuffled = torch.randperm(num_samples, generator=generator).tolist()
        cut = int(train_fraction * num_samples)
        return shuffled[:cut], shuffled[cut:]

    def get_dataloaders(self):
        """Download CIFAR-10 if needed and return ``(train_loader, val_loader, test_loader)``.

        The official training set is split 90/10 into training and validation
        subsets. Both subsets index the same images, but the validation subset
        uses the evaluation transform. Training batches are passed through
        MixUp or CutMix (chosen at random per batch), so their labels are
        per-class weights rather than class indices.
        """
        train_dataset = datasets.CIFAR10(root='./data', train=True,
                                         download=True, transform=self._train_transform())

        val_dataset = datasets.CIFAR10(root='./data', train=True,
                                       download=True, transform=self._val_test_transform())

        # Carve the validation set out of the training data. The indices are
        # shuffled first (previously a SubsetRandomSampler was created and
        # discarded, leaving the split unshuffled).
        train_indices, val_indices = self._split_indices(
            len(train_dataset), train_fraction=0.9, seed=self.split_seed)

        # Create subsets based on the indices
        train_subset = Subset(train_dataset, train_indices)
        val_subset = Subset(val_dataset, val_indices)

        # Now, create the test dataset
        test_dataset = datasets.CIFAR10(root='./data', train=False,
                                        download=True, transform=self._val_test_transform())

        # Special collate function for MixUp and CutMix:
        # these operations need to be applied to a whole batch at once.
        num_classes = len(train_dataset.classes)
        mixup_cutmix = [
            MixUp(alpha=1.0, num_classes=num_classes),
            CutMix(alpha=1.0, num_classes=num_classes)
        ]

        def combine_fn(batch):
            # Pick MixUp (0) or CutMix (1) for this batch.
            choice = torch.randint(0, 2, (1,)).item()
            return mixup_cutmix[choice](*torch.utils.data.default_collate(batch))

        # Create the DataLoaders
        train_loader = DataLoader(train_subset, batch_size=self.batch_size,
                                  shuffle=True, num_workers=2, collate_fn=combine_fn)

        val_loader = DataLoader(val_subset, batch_size=self.batch_size,
                                shuffle=False, num_workers=2)

        test_loader = DataLoader(test_dataset, batch_size=self.batch_size,
                                 shuffle=False, num_workers=2)

        print("DataLoaders created successfully.")
        print(f"Training samples: {len(train_subset)}")
        print(f"Validation samples: {len(val_subset)}")
        print(f"Test samples: {len(test_dataset)}")

        return train_loader, val_loader, test_loader

# --- How to use it ---
# data_handler = DataHandler(image_information=img_info, batch_size=1024)
# train_loader, val_loader, test_loader = data_handler.get_dataloaders()

In [None]:
import torch
import os
import json
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix

class Evaluator:
    """Computes test-set metrics for a trained model and writes them to disk.

    Args:
        model: the ``nn.Module`` whose weights are loaded in ``evaluate``.
        test_loader: DataLoader over the test set; its ``dataset`` must expose
            a ``classes`` list with the class names.
        device: device the model and batches are moved to.
        output_dir: directory receiving the JSON report and confusion-matrix image.
    """

    def __init__(self, model, test_loader, device, output_dir):
        self.model = model
        self.test_loader = test_loader
        self.device = device
        self.output_dir = output_dir
        self.class_names = self.test_loader.dataset.classes

    def plot_confusion_matrix(self, cm, file_path):
        """Draw ``cm`` as an annotated heatmap, save it to ``file_path`` and show it."""
        fig, ax = plt.subplots(figsize=(10, 8))
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                    xticklabels=self.class_names, yticklabels=self.class_names, ax=ax)
        ax.set_xlabel('Predicted Label')
        ax.set_ylabel('True Label')
        ax.set_title('Confusion Matrix')
        fig.savefig(file_path)
        plt.show() # Display in Colab
        # Release the figure so repeated evaluations do not accumulate open figures.
        plt.close(fig)

    def evaluate(self, model_path):
        """Loads the best model and computes final metrics.

        Args:
            model_path: path of the ``state_dict`` checkpoint to load.

        Returns:
            dict with ``"overall_accuracy"`` (fraction in [0, 1]) and
            ``"report"`` (the classification report as a dict).
        """
        print(f"\n--- Starting Final Evaluation ---")
        # Load the best model state; map_location lets a checkpoint saved on
        # one device be restored onto the evaluation device.
        self.model.load_state_dict(torch.load(model_path, map_location=self.device))
        self.model.to(self.device)
        self.model.eval()

        all_preds = []
        all_labels = []

        with torch.no_grad():
            for images, labels in tqdm(self.test_loader, desc="Testing"):
                images, labels = images.to(self.device), labels.to(self.device)
                outputs = self.model(images)
                predicted = outputs.argmax(dim=1)
                all_preds.extend(predicted.cpu().tolist())
                all_labels.extend(labels.cpu().tolist())

        # --- Calculate Metrics ---
        # Overall Accuracy
        overall_accuracy = accuracy_score(all_labels, all_preds)
        print(f"Overall Test Accuracy: {overall_accuracy * 100:.2f}%")

        # Pin the label set so the report rows and the confusion-matrix axes
        # always line up with class_names, even if a class never occurs.
        label_ids = list(range(len(self.class_names)))

        # Classification Report (Precision, Recall, F1-score)
        report = classification_report(all_labels, all_preds, labels=label_ids,
                                       target_names=self.class_names, output_dict=True)

        # Confusion Matrix
        cm = confusion_matrix(all_labels, all_preds, labels=label_ids)

        # --- Save Results ---
        os.makedirs(self.output_dir, exist_ok=True)

        # Save classification report to a JSON file
        report_path = os.path.join(self.output_dir, "classification_report.json")
        with open(report_path, 'w') as f:
            json.dump(report, f, indent=4)
        print(f"Classification report saved to {report_path}")

        # Plot and save confusion matrix
        cm_path = os.path.join(self.output_dir, "confusion_matrix.png")
        self.plot_confusion_matrix(cm, cm_path)
        print(f"Confusion matrix plot saved to {cm_path}")

        return {"overall_accuracy": overall_accuracy, "report": report}

In [None]:
import os
def main(num_runs, master_seed, test_name, image_information, model_parameters, hyperparameters):
    """Train and evaluate the ViT ``num_runs`` times, each run with its own seed.

    Run ``i`` (0-based) uses seed ``master_seed + i`` and writes its checkpoint
    and evaluation artefacts under ``run_outputs/<test_name>_<i+1>/``.

    Master seeds used so far:
        1. MASTER_SEED = 42  --> used for trial runs (T0)

    Args:
        num_runs: number of independent training runs.
        master_seed: base seed; the run index is added to it.
        test_name: prefix of each run's output directory name.
        image_information: ImageParams describing the input images.
        model_parameters: ModelParameters for the ViT.
        hyperparameters: Hyperparameters for data loading and optimisation.

    Returns:
        list with one metrics dict per run, as returned by ``Evaluator.evaluate``.
    """
    all_metrics = []
    for i in range(num_runs):
        run_seed = i + master_seed
        Trainer.set_seed(seed=run_seed)
        print(f"\n--- Starting Run {i+1}/{num_runs} (Seed: {run_seed}) ---")

        run_name = test_name + f'_{i+1}'
        run_output_dir = f"run_outputs/{run_name}"
        local_model_path = os.path.join(run_output_dir, "best_model.pth")
        eval_results_dir = os.path.join(run_output_dir, "evaluation_results")
        os.makedirs(run_output_dir, exist_ok=True)
        os.makedirs(eval_results_dir, exist_ok=True)

        # The loaders come from get_dataloaders(); the DataHandler instance
        # itself is not iterable.
        data_handler = DataHandler(image_information=image_information,
                                   batch_size=hyperparameters.batch_size)
        train_loader, val_loader, test_loader = data_handler.get_dataloaders()

        base_model = ViT(mparams=model_parameters, hparams=hyperparameters, img_info=image_information)
        base_optimizer = optim.AdamW(params=base_model.parameters(),
                                     lr=hyperparameters.learning_rate,
                                     weight_decay=hyperparameters.weight_decay)
        # Loss scaling is only meaningful for CUDA mixed precision; disable it
        # elsewhere so the scaler becomes a pass-through.
        base_scaler = torch.amp.GradScaler(device='cuda', enabled=torch.cuda.is_available())
        base_scheduler = OneCycleLR(optimizer=base_optimizer,
                                    max_lr=hyperparameters.learning_rate,
                                    steps_per_epoch=len(train_loader),
                                    epochs=hyperparameters.epochs)

        # Trainer requires the epoch count; it must match the one given to OneCycleLR.
        trainer = Trainer(model=base_model, train_loader=train_loader, val_loader=val_loader,
                          optimizer=base_optimizer, scaler=base_scaler,
                          lr_scheduler=base_scheduler, epochs=hyperparameters.epochs)
        trainer._run_trainer(model_path=local_model_path)
        print("\n--- Training complete. Starting final evaluation on test set. ---")
        evaluator = Evaluator(
            model=base_model,
            test_loader=test_loader,
            device=trainer.device,
            output_dir=eval_results_dir)
        all_metrics.append(evaluator.evaluate(model_path=local_model_path))

    return all_metrics


# Trial run T0: three repetitions starting from master seed 42, 100 epochs each.
main(
    test_name='test_run_check_T0',
    num_runs=3,
    master_seed=42,
    image_information=ImageParams(
        width=32,
        height=32,
        in_channel=3,
    ),
    model_parameters=ModelParameters(
        patch_size=4,
        inner_dim=192,
        transformer_layers=12,
        num_head=3,
        embed_dropout=0.1,
        attn_dropout=0.0,
        mlp_dropout=0.1,
    ),
    hyperparameters=Hyperparameters(
        batch_size=1024,
        out_classes=10,
        epochs=100,
        learning_rate=5e-4*(1024/512),
        weight_decay=0.05,
    ),
)