# Training


Importing libraries and configuring the environment (random seeds and device)

In [None]:
# Importing and setups
!pip install ptflops # need to install everytime either cpu or gpu
import os
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import time
import pandas as pd
import seaborn as sns
from sklearn.metrics import confusion_matrix, precision_recall_fscore_support, classification_report, roc_curve, auc
import csv
from datetime import datetime
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.utils import make_grid
from ptflops import get_model_complexity_info

# Fix the random seeds used by NumPy and PyTorch so runs are repeatable
seed = 42
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)  # explicitly seed the current CUDA device as well
torch.backends.cudnn.deterministic = True  # ask cuDNN for deterministic kernels

# Prefer the GPU when one is visible, otherwise fall back to the CPU
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Using device: {device}")
if device.type != 'cuda':
    print("WARNING: Training will be very slow without GPU!")

Collecting ptflops
  Downloading ptflops-0.7.4-py3-none-any.whl.metadata (9.4 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch>=2.0->ptflops)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch>=2.0->ptflops)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch>=2.0->ptflops)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch>=2.0->ptflops)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch>=2.0->ptflops)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.2.1.3 (from torch>=2.0->ptflops)
  Downloading nvidia_

Preparing dataset

In [None]:
# Data Preparation with Augmentation
class CIFAR10DataModule:
    """Bundle the CIFAR-10 datasets, their transforms and their DataLoaders.

    Args:
        batch_size: Number of samples per batch for both loaders.
        num_workers: Number of DataLoader worker processes.
        data_root: Directory the datasets are downloaded to / read from.

    Call ``setup()`` before requesting a dataloader; the dataset attributes
    only exist after it has run.
    """

    def __init__(self, batch_size=128, num_workers=4, data_root='./data'):
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.data_root = data_root

        # Pinned host memory only helps host-to-GPU copies; requesting it on a
        # CPU-only machine just triggers a DataLoader warning.
        self.pin_memory = torch.cuda.is_available()

        # CIFAR10 normalization values - DON'T CHANGE
        self.mean = (0.4914, 0.4822, 0.4465)
        self.std = (0.2470, 0.2435, 0.2616)

        # Training pipeline: augment first, then convert and normalize
        self.train_transform = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),  # standard augmentation
            transforms.RandAugment(num_ops=2, magnitude=9),  # tried 3 ops but too aggressive
            transforms.ToTensor(),
            transforms.Normalize(self.mean, self.std)
        ])

        # No augmentation for the evaluation split
        self.test_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(self.mean, self.std)
        ])

    def setup(self):
        """Download (if needed) and instantiate the train and validation datasets."""
        print("Setting up datasets...")
        self.train_dataset = datasets.CIFAR10(
            root=self.data_root,
            train=True,
            download=True,
            transform=self.train_transform
        )

        # The "validation" set is the split loaded with train=False
        self.val_dataset = datasets.CIFAR10(
            root=self.data_root,
            train=False,
            download=True,
            transform=self.test_transform
        )
        print(f"Loaded {len(self.train_dataset)} training and {len(self.val_dataset)} validation samples")

    def _make_loader(self, dataset, shuffle):
        """Build a DataLoader over ``dataset`` using the module's shared settings."""
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=shuffle,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory
        )

    def train_dataloader(self):
        """Return a shuffled DataLoader over the training dataset."""
        return self._make_loader(self.train_dataset, shuffle=True)  # shuffling matters for training

    def val_dataloader(self):
        """Return an in-order DataLoader over the validation dataset."""
        return self._make_loader(self.val_dataset, shuffle=False)  # no need to shuffle for validation

Patch Embedding

In [None]:
# Patch Embedding Layer
class PatchEmbedding(nn.Module):
    """Cut an image into square patches and project each patch to an embedding.

    Args:
        img_size: Expected height and width of the (square) input image.
        patch_size: Side length of each square patch.
        in_channels: Number of channels in the input image.
        embed_dim: Size of the embedding produced for each patch.
    """

    def __init__(self, img_size=32, patch_size=4, in_channels=3, embed_dim=192):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        patches_per_side = img_size // patch_size
        self.n_patches = patches_per_side ** 2

        # A convolution whose kernel and stride both equal the patch size
        # applies one shared projection to every non-overlapping patch.
        self.proj = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        """Map images [B, C, H, W] to patch embeddings [B, n_patches, embed_dim]."""
        _, _, height, width = x.shape
        assert (height, width) == (self.img_size, self.img_size), \
            f"Input image size ({height}*{width}) doesn't match expected size ({self.img_size}*{self.img_size})"

        # Conv gives [B, E, H/P, W/P]; merge the spatial grid into one token
        # axis, then move it ahead of the embedding axis.
        return self.proj(x).flatten(2).transpose(1, 2)

Multi Head Attention (MHA)

In [None]:
# Multi-Head Self-Attention
class MultiHeadSelfAttention(nn.Module):
    """Multi-head scaled dot-product self-attention over a token sequence.

    Args:
        embed_dim: Size of each token embedding; must be divisible by num_heads.
        num_heads: Number of attention heads the embedding is split across.
        dropout: Dropout probability for the attention weights and the output.
    """

    def __init__(self, embed_dim=192, num_heads=8, dropout=0.1):  # 192/8 = 24 per head
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads

        # The heads must tile the embedding exactly
        assert embed_dim == self.head_dim * num_heads, \
            f"embed_dim {embed_dim} must be divisible by num_heads {num_heads}"

        # One linear layer yields queries, keys and values in a single pass
        self.qkv = nn.Linear(embed_dim, embed_dim * 3)
        self.proj = nn.Linear(embed_dim, embed_dim)
        self.attn_dropout = nn.Dropout(dropout)
        self.proj_dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Apply self-attention to x of shape [B, N, E]; the result has the same shape."""
        batch, seq_len, dim = x.shape

        # [B, N, 3E] -> [B, N, 3, H, D] -> [3, B, H, N, D], then split q/k/v
        fused = self.qkv(x).reshape(batch, seq_len, 3, self.num_heads, self.head_dim)
        q, k, v = fused.permute(2, 0, 3, 1, 4).unbind(0)  # each [B, H, N, D]

        # Scale the logits by 1/sqrt(head_dim) before the softmax
        scale = 1.0 / np.sqrt(self.head_dim)
        scores = torch.matmul(q, k.transpose(-2, -1)) * scale  # [B, H, N, N]
        weights = self.attn_dropout(F.softmax(scores, dim=-1))

        # Weighted sum of values, then merge the heads back into E
        context = torch.matmul(weights, v).transpose(1, 2).reshape(batch, seq_len, dim)
        return self.proj_dropout(self.proj(context))

MLP

In [None]:
# MLP Block
class MLP(nn.Module):
    """Two-layer feed-forward network with a GELU in between.

    Args:
        in_features: Size of the input features.
        hidden_features: Width of the hidden layer.
        out_features: Size of the output features.
        dropout: Dropout probability, applied after the activation and again
            after the second linear layer.
    """

    def __init__(self, in_features, hidden_features, out_features, dropout=0.1):
        super().__init__()
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = nn.GELU()  # GELU rather than ReLU for transformers
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Return dropout(fc2(dropout(act(fc1(x)))))."""
        hidden = self.dropout(self.act(self.fc1(x)))
        # The same dropout module is reused on the output projection
        return self.dropout(self.fc2(hidden))

Transformer Encoder Block

In [None]:
# Transformer Encoder Block
class TransformerBlock(nn.Module):
    """Pre-norm transformer encoder block: attention then MLP, each with a residual.

    Args:
        embed_dim: Size of each token embedding.
        num_heads: Number of attention heads.
        mlp_ratio: Hidden width of the MLP as a multiple of embed_dim.
        dropout: Dropout probability passed to the attention and MLP sub-layers.
    """

    def __init__(self, embed_dim=192, num_heads=8, mlp_ratio=4.0, dropout=0.1):
        super().__init__()
        hidden_dim = int(embed_dim * mlp_ratio)  # the ratio matters!

        self.norm1 = nn.LayerNorm(embed_dim)
        self.attn = MultiHeadSelfAttention(embed_dim, num_heads, dropout)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.mlp = MLP(embed_dim, hidden_dim, embed_dim, dropout)

    def forward(self, x):
        """Return x after the attention and MLP residual updates."""
        # Pre-norm: each sub-layer sees a normalized input and its output is
        # added back onto the un-normalized stream, i.e. x + sublayer(norm(x))
        # rather than norm(x + sublayer(x)).
        attended = self.attn(self.norm1(x))
        x = x + attended
        transformed = self.mlp(self.norm2(x))
        return x + transformed

Complete Vision Transformer Model

In [None]:
# Complete Vision Transformer Model
class VisionTransformer(nn.Module):
    """Vision Transformer classifier for small square images.

    The image is split into patches and embedded, a learnable class token is
    prepended, learned position embeddings are added, and the sequence runs
    through a stack of transformer blocks. The final-norm output of the class
    token feeds a linear classification head.

    Args:
        img_size: Height and width of the (square) input images.
        patch_size: Side length of each square patch.
        in_channels: Number of input image channels.
        num_classes: Number of output logits.
        embed_dim: Token embedding size.
        depth: Number of transformer blocks.
        num_heads: Attention heads per block; must divide embed_dim.
        mlp_ratio: MLP hidden width as a multiple of embed_dim.
        dropout: Dropout probability used inside every transformer block.
        embed_dropout: Dropout probability applied to the embedded tokens.
    """

    def __init__(
        self,
        img_size=32,
        patch_size=4,  # 4x4 patches for CIFAR, i.e. (32 // 4) ** 2 == 64 tokens
        in_channels=3,  # RGB channels
        num_classes=10,  # number of output classes
        embed_dim=192,  # tried 384 but too many params for CIFAR tend to overfit
        depth=9,  # paper uses 12, but 9 is enough for CIFAR and 12 tend to overfit
        num_heads=8,  # must divide embed_dim evenly 192/8 = 24
        mlp_ratio=4.0,
        dropout=0.1,  # dropout probability inside each transformer block
        embed_dropout=0.1  # separate dropout rate for embeddings
    ):
        super().__init__()
        self.num_classes = num_classes
        self.embed_dim = embed_dim
        self.num_tokens = (img_size // patch_size) ** 2

        # Patch embedding
        self.patch_embed = PatchEmbedding(
            img_size=img_size,
            patch_size=patch_size,
            in_channels=in_channels,
            embed_dim=embed_dim
        )

        # Learnable class token, prepended to every sequence
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        # Learned position embeddings: one per patch token plus one for the
        # class token. Self-attention is order-agnostic on its own, so these
        # are what tell the model where each patch came from.
        self.pos_embed = nn.Parameter(torch.zeros(1, self.num_tokens + 1, embed_dim))

        # Truncated-normal init for the learned tokens
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        nn.init.trunc_normal_(self.cls_token, std=0.02)

        self.dropout = nn.Dropout(embed_dropout)

        # Stack of `depth` identical transformer blocks - the main part of the model
        self.blocks = nn.ModuleList([
            TransformerBlock(
                embed_dim=embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                dropout=dropout
            )
            for _ in range(depth)
        ])

        # Final normalization layer
        self.norm = nn.LayerNorm(embed_dim)

        # Classification head - just a linear layer
        self.head = nn.Linear(embed_dim, num_classes)

        # Re-initialize every Linear and LayerNorm submodule
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Initialize one submodule; used as the callback for ``self.apply``.

        Linear layers get truncated-normal weights (std=0.02) and zero bias;
        LayerNorm gets unit weight and zero bias. Any other module type
        (for example the patch-embedding convolution) is left untouched.
        """
        if isinstance(m, nn.Linear):
            nn.init.trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(self, x):
        """Classify a batch of images.

        Args:
            x: Image batch of shape [B, C, H, W].

        Returns:
            Logits of shape [B, num_classes].
        """
        B = x.shape[0]

        # Create patch embeddings
        x = self.patch_embed(x)  # [B, N, E]

        # Prepend the class token - its final state is used for classification
        cls_token = self.cls_token.expand(B, -1, -1)  # [B, 1, E]
        x = torch.cat((cls_token, x), dim=1)  # [B, N+1, E]

        # Add position embeddings and apply dropout
        x = x + self.pos_embed  # broadcasting takes care of batch dim
        x = self.dropout(x)

        # Pass through transformer blocks
        for block in self.blocks:
            x = block(x)

        # Apply final normalization
        x = self.norm(x)

        # Keep only the class token (index 0) for classification
        x = x[:, 0]

        # Linear classification head, no extra non-linearity
        x = self.head(x)

        return x