# Vision Transformer
- [Paper Link](https://arxiv.org/pdf/2010.11929v2.pdf)



## 1. Model from scratch

In [None]:
import torch
from torch import nn
from torch import Tensor
from einops.layers.torch import Rearrange
import numpy as np
from torch.nn import functional as F
import einops
from torch.utils.data.dataloader import DataLoader
import matplotlib.pyplot as plt
from torchvision.datasets import CIFAR10
from torch.utils.data import random_split
from torchvision import transforms
from torch.utils.data import Subset
from transformers import ViTForImageClassification, ViTImageProcessor
from PIL import Image
import requests
import random
import os
from datetime import datetime

In [None]:
# Training and checkpoint configuration
WEIGHT_FOLDER_PATH = "./checkpoints"  # folder that receives saved models
BATCH_SIZE = 256
LEARNING_RATE = 0.001
WEIGHT_DECAY = 0.00001
EPOCHS = 300
# File that always holds the best model seen so far
MODEL_SAVE_NAME = WEIGHT_FOLDER_PATH + "/CURR-BEST-ViT-MODEL-CIFAR-10-BROAD-FINE.pt"


def set_seed(seed: int = 42) -> None:
    """Seed every random number generator used by this notebook.

    Seeds NumPy, Python's ``random`` module and PyTorch (CPU and every CUDA
    device), and switches cuDNN to its deterministic, non-benchmarking mode.

    Args:
        seed: Value handed to every generator.
    """
    np.random.seed(seed)
    random.seed(seed)
    torch.manual_seed(seed)
    # manual_seed_all covers every GPU rather than only the current one and
    # is silently ignored when CUDA is unavailable.
    torch.cuda.manual_seed_all(seed)
    # When running on the CuDNN backend, two further options must be set
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    # NOTE: PYTHONHASHSEED is read at interpreter start-up only, so this line
    # affects child processes launched afterwards, not hashing in the
    # already-running interpreter.
    os.environ["PYTHONHASHSEED"] = str(seed)
    print(f"Random seed set as {seed}")

In [None]:
# One empty dict per key 0 and 1 (presumably the two broad class ids).
# NOTE(review): never filled or read in the visible code - confirm it is needed.
cifar_10_labels = {broad_id: {} for broad_id in (0, 1)}

In [None]:
def accuracy(outputs, labels):
    """Return the fraction of samples whose arg-max prediction matches the label.

    Args:
        outputs: Per-class scores; the arg-max is taken over dimension 1.
        labels: Class indices, one per row of ``outputs``.

    Returns:
        A 0-dim float tensor living on the same device as ``outputs``
        (``nan`` for an empty batch).
    """
    preds = outputs.argmax(dim=1)
    # Stay on the device: no ``.item()`` round-trip (which forces a host
    # synchronisation on every batch) and no re-wrapping of a Python float.
    return (preds == labels).float().mean()

In [None]:
class ImageClassificationBase(nn.Module):
    """
    Base class containing the training / validation utilities shared by the
    two-headed (broad + fine label) classifiers.

    Sub-classes implement ``forward`` so that ``self(images)`` returns
    ``(broad_logits, fine_logits)``.
    """

    def __init__(self):
        super().__init__()
        # Best tracked validation accuracy so far; -1 means "none recorded"
        self.acc = -1

    def training_step(self, batch):
        """Return ``(fine_loss, broad_loss)`` cross-entropy losses for one batch."""
        images, labels_fine, labels_broad = batch
        broad_out, fine_out = self(images)  # Generate predictions
        loss_fine = F.cross_entropy(fine_out, labels_fine)
        loss_broad = F.cross_entropy(broad_out, labels_broad)
        return loss_fine, loss_broad

    def validation_step(self, batch):
        """Return detached losses and accuracies of both heads for one batch."""
        images, labels_fine, labels_broad = batch
        broad_out, fine_out = self(images)  # Generate predictions

        loss_fine = F.cross_entropy(fine_out, labels_fine)
        loss_broad = F.cross_entropy(broad_out, labels_broad)

        return {
            "val_loss_fine": loss_fine.detach(),
            "val_loss_broad": loss_broad.detach(),
            "val_acc_fine": accuracy(fine_out, labels_fine),
            "val_acc_broad": accuracy(broad_out, labels_broad),
        }

    def validation_epoch_end(self, outputs):
        """Average the per-batch dicts of ``validation_step`` into Python floats.

        Every batch has the same weight, whatever its size.
        """
        metric_keys = (
            "val_loss_fine",
            "val_acc_fine",
            "val_loss_broad",
            "val_acc_broad",
        )
        return {
            key: torch.stack([batch_out[key] for batch_out in outputs]).mean().item()
            for key in metric_keys
        }

    def reset_stored_accuracy(self):
        """Forget the best accuracy so the next epoch is saved unconditionally."""
        self.acc = -1

    def epoch_end(self, epoch, result, mode):
        """Print the epoch metrics, write checkpoints and track the best model.

        Args:
            epoch: Epoch index; a checkpoint is written whenever it is a
                multiple of 25 (which includes epoch 0).
            result: Dict holding the ``train_loss_*``, ``val_loss_*`` and
                ``val_acc_*`` entries for both heads.
            mode: ``"BROAD_ONLY"`` or ``"FINE_ONLY"`` track that head's
                validation accuracy; ``"FINE_AND_BROAD"`` / ``"BROAD_AND_FINE"``
                track the mean of both.

        Raises:
            ValueError: If ``mode`` is none of the values above.
        """
        print(f"Epoch: [{epoch:<5}]")
        print(f"Train Loss Fine: {result['train_loss_fine']:.5f}")
        print(f"Val Loss Fine: {result['val_loss_fine']:.5f}")
        print(f"Val Acc Fine: {result['val_acc_fine']:.5f}")

        print(f"Train Loss Broad: {result['train_loss_broad']:.5f}")
        print(f"Val Loss Broad: {result['val_loss_broad']:.5f}")
        print(f"Val Acc Broad: {result['val_acc_broad']:.5f}")

        if epoch % 25 == 0:
            os.makedirs("./checkpoints", exist_ok=True)
            # Build the path once so the printed name is the file actually
            # written; the timestamp avoids spaces and colons, which are not
            # valid in file names on every platform.
            stamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
            checkpoint_path = f"./checkpoints/Vit-{stamp}.pt"
            torch.save(self, checkpoint_path)
            print(f"\nCheckpoint saved at {checkpoint_path}")

        if mode == "BROAD_ONLY":
            mode_label = "val_acc_broad"
            tracked_acc = result["val_acc_broad"]
        elif mode == "FINE_ONLY":
            mode_label = "val_acc_fine"
            tracked_acc = result["val_acc_fine"]
        elif mode in ("BROAD_AND_FINE", "FINE_AND_BROAD"):
            # Both spellings are accepted: this method historically used
            # "BROAD_AND_FINE" while the notebook's Mode object says
            # "FINE_AND_BROAD".
            mode_label = "val_acc_mean"
            tracked_acc = (result["val_acc_fine"] + result["val_acc_broad"]) / 2
        else:
            raise ValueError(f"Unknown mode: {mode!r}")

        if tracked_acc > self.acc:
            print(
                f"{mode_label} Validation Accuracy Increased from {self.acc} to {tracked_acc}"
            )
            self.acc = tracked_acc
            save_path = MODEL_SAVE_NAME
            save_dir = os.path.dirname(save_path)
            if save_dir:
                os.makedirs(save_dir, exist_ok=True)
            # The whole module is pickled (not only the state dict) because
            # the loading cell reads ``model.acc`` from the restored object.
            torch.save(self, save_path)
            print(f"Model Saved @ {save_path}")

In [None]:
class PositionalEmbedding1D(nn.Module):
    """
    Learnable 1D position table that is summed onto a token sequence.
    The table has one row per position, so seq_len has to count any extra
    classification tokens as well.
    """

    def __init__(self, seq_len, d_model):
        super().__init__()
        zero_table = torch.zeros((1, seq_len, d_model))
        self.pos_embedding = nn.Parameter(zero_table)

    def forward(self, x):
        # The leading dimension of size 1 broadcasts over the batch
        return torch.add(x, self.pos_embedding)

In [None]:
class MultiHeadSelfAttention(nn.Module):
    """
    Multi-head self-attention.
    Projects the input to queries, keys and values, splits the embedding
    dimension across the heads, attends per head and joins the heads again.
    """

    def __init__(self, emb_dim, num_heads, dropout_proba):
        super().__init__()
        self.emb_dim = emb_dim
        self.num_heads = num_heads
        self.query_layer = nn.Linear(emb_dim, emb_dim)
        self.key_layer = nn.Linear(emb_dim, emb_dim)
        self.value_layer = nn.Linear(emb_dim, emb_dim)
        self.dropout = nn.Dropout(dropout_proba)
        # Attention weights of the most recent forward pass
        self.scores = None

    def forward(self, x, mask=None):
        """
        x : (B(batch_size), S(seq_len), D(dim))
        mask : (B(batch_size), S(seq_len)); 1 keeps a position, 0 masks it
        D is split into (H(n_heads), W(width of head)) ; D = H * W
        Returns a tensor of shape (B, S, D).
        """

        def split_heads(projected):
            # (B, S, D) -> (B, H, S, W)
            return einops.rearrange(
                projected, "b s (nh wh) -> b nh s wh", nh=self.num_heads
            )

        q = split_heads(self.query_layer(x))
        k = split_heads(self.key_layer(x))
        v = split_heads(self.value_layer(x))

        # Scaled dot product: (B, H, S, W) @ (B, H, W, S) -> (B, H, S, S)
        head_width = k.size(-1)
        scores = q @ k.transpose(-2, -1) / np.sqrt(head_width)
        if mask is not None:
            # A large negative offset drives masked keys to ~0 after softmax
            mask = mask[:, None, None, :].float()
            scores -= 10000.0 * (1.0 - mask)
        scores = self.dropout(F.softmax(scores, dim=-1))

        # (B, H, S, S) @ (B, H, S, W) -> (B, H, S, W) -> (B, S, H, W) -> (B, S, D)
        per_head = einops.rearrange(scores @ v, "b h s w -> b s h w")
        h = einops.rearrange(per_head, "b s h w -> b s (h w)")
        self.scores = scores
        return h

In [None]:
class PositionWiseFeedForward(nn.Module):
    """
    Two-layer feed-forward network with a GELU in between, applied to
    each element of the sequence
    """

    def __init__(self, emb_dim, feed_fwd_dim):
        super().__init__()
        self.fc1 = nn.Linear(emb_dim, feed_fwd_dim)
        self.fc2 = nn.Linear(feed_fwd_dim, emb_dim)

    def forward(self, x):
        hidden = self.fc1(x)  # (B, S, D) -> (B, S, D_ff)
        activated = F.gelu(hidden)
        return self.fc2(activated)  # (B, S, D_ff) -> (B, S, D)

In [None]:
class TransformerBlock(nn.Module):
    """
    One pre-norm transformer block: self-attention and feed-forward
    sub-layers, each wrapped in a residual connection
    """

    def __init__(self, dim, num_heads, ff_dim, dropout):
        super().__init__()
        self.attn = MultiHeadSelfAttention(dim, num_heads, dropout)
        self.proj = nn.Linear(dim, dim)
        self.norm1 = nn.LayerNorm(dim, eps=1e-6)
        self.pwff = PositionWiseFeedForward(dim, ff_dim)
        self.norm2 = nn.LayerNorm(dim, eps=1e-6)
        self.drop = nn.Dropout(dropout)

    def forward(self, x, mask):
        # Attention sub-layer: norm -> attention -> projection -> dropout
        attended = self.attn(self.norm1(x), mask)
        x = x + self.drop(self.proj(attended))
        # Feed-forward sub-layer: norm -> MLP -> dropout
        fed_forward = self.pwff(self.norm2(x))
        return x + self.drop(fed_forward)

In [None]:
class TransformerEncoder(nn.Module):
    """
    Stack of num_layers identically configured TransformerBlocks applied
    one after another
    """

    def __init__(self, num_layers, emb_dim, num_heads, ff_dim, dropout):
        super().__init__()
        layers = []
        for _ in range(num_layers):
            layers.append(TransformerBlock(emb_dim, num_heads, ff_dim, dropout))
        self.blocks = nn.ModuleList(layers)

    def forward(self, x, mask=None):
        # The same mask is handed to every layer
        for layer in self.blocks:
            x = layer(x, mask)
        return x

In [None]:
class VisionTransformer(ImageClassificationBase):
    """
    Vision Transformer for 2D colour images with two classification heads
    (broad and fine labels).
    Patches are embedded with a strided Conv2D (the hybrid architecture).
    """

    def __init__(
        self,
        img_height: int,
        img_width: int,
        patch_dim: int,
        emb_dim: int,
        num_classes: int,
        in_channels: int = 3,
        num_heads: int = 12,
        pwff_dim: int = 3072,
        num_layers: int = 12,
        dropout: float = 0.6,
        num_broad_classes: int = 2,  # default for CIFAR 10
    ):
        """
        Args:
            img_height: Input image height in pixels.
            img_width: Input image width in pixels.
            patch_dim: Side length of the square patches.
            emb_dim: Token embedding width.
            num_classes: Number of fine classes.
            in_channels: Number of image channels.
            num_heads: Attention heads per encoder block.
            pwff_dim: Hidden width of the position-wise feed-forward nets.
            num_layers: Number of encoder blocks.
            dropout: Dropout probability used inside the encoder.
            num_broad_classes: Number of broad classes.
        """
        super().__init__()
        self.img_height = img_height
        self.img_width = img_width
        self.patch_dim = patch_dim
        self.emb_dim = emb_dim
        self.num_patch_width = img_width // patch_dim
        self.num_patch_height = img_height // patch_dim
        # +2 for 2 additional tokens, one for broad class one for fine class
        self.seq_len = (self.num_patch_width * self.num_patch_height) + 2

        # (b, c, h, w) -> (b, nh*nw, d)
        self.embedding_layer = nn.Sequential(
            nn.Conv2d(
                in_channels,
                emb_dim,
                kernel_size=(patch_dim, patch_dim),
                stride=(patch_dim, patch_dim),
            ),
            Rearrange("b d x y -> b (x y) d"),
        )

        # class tokens
        self.fine_class_token = nn.Parameter(torch.zeros(1, 1, self.emb_dim))
        self.broad_class_token = nn.Parameter(torch.zeros(1, 1, self.emb_dim))

        # 1D positional embedding
        self.positional_embedding = PositionalEmbedding1D(self.seq_len, self.emb_dim)
        self.transformer_encoder = TransformerEncoder(
            num_layers, emb_dim, num_heads, pwff_dim, dropout
        )
        # Classification heads
        self.norm = nn.LayerNorm(emb_dim)
        self.mlp_fine = nn.Linear(emb_dim, num_classes)
        self.mlp_broad = nn.Linear(emb_dim, num_broad_classes)

    def add_class_tokens_to_input(self, x):
        """
        Appends the two [class] tokens to the end of the sequence x.
        Resulting order: [patch tokens..., fine_class_token, broad_class_token]
        """
        bs = x.shape[0]

        fine_class_token_expanded = einops.repeat(
            self.fine_class_token, "1 s d -> bs s d", bs=bs
        )
        broad_class_token_expanded = einops.repeat(
            self.broad_class_token, "1 s d -> bs s d", bs=bs
        )

        return torch.cat(
            [x, fine_class_token_expanded, broad_class_token_expanded], dim=1
        )

    def forward(self, x):
        """
        Args:
            x: Image batch of shape (b, c, h, w).

        Returns:
            Tuple ``(broad_logits, fine_logits)``.
        """
        y = self.embedding_layer(x)
        y = self.add_class_tokens_to_input(y)
        y = self.positional_embedding(y)
        y = self.transformer_encoder(y)

        # LayerNorm works on each token independently, so normalising only the
        # two trailing class tokens yields the same values as normalising the
        # whole sequence, without doing that work twice.
        class_tokens = self.norm(y[:, -2:])

        # NOTE: the fine head reads the LAST position and the broad head the
        # second-to-last, although the concatenation order places the
        # parameter named ``fine_class_token`` second-to-last. The positions
        # are kept as they were so previously saved models behave the same;
        # only the parameter names are swapped relative to their use.
        broad_head_input = class_tokens[:, 0]
        fine_head_input = class_tokens[:, 1]

        # Return broad and fine logits from the MLP heads
        return (
            self.mlp_broad(broad_head_input),
            self.mlp_fine(fine_head_input),
        )

In [None]:
def get_default_device():
    """Return the CUDA device when CUDA is available, otherwise the CPU device"""
    device_name = "cuda" if torch.cuda.is_available() else "cpu"
    return torch.device(device_name)

In [None]:
def to_device(data, device):
    """
    Move data to device.
    Lists and tuples are processed element by element (recursively) and always
    come back as a list; anything else is moved with a non-blocking .to().
    """
    if not isinstance(data, (list, tuple)):
        return data.to(device, non_blocking=True)

    moved = []
    for item in data:
        moved.append(to_device(item, device))
    return moved

In [None]:
class DeviceDataLoader:
    """Thin wrapper that moves every batch of a dataloader to one device"""

    def __init__(self, dl, device):
        self.dl, self.device = dl, device

    def __iter__(self):
        """Lazily yield the wrapped loader's batches, moved to the device"""
        yield from (to_device(batch, self.device) for batch in self.dl)

    def __len__(self):
        """Number of batches of the wrapped loader"""
        return len(self.dl)

In [None]:
# Pick the compute device (GPU if available, else CPU)
device = get_default_device()
device  # bare expression: displayed as the cell output in a notebook

In [None]:
@torch.no_grad()
def evaluate(model, val_loader, mode):
    """
    Run one validation pass without gradient tracking.

    model - model providing eval(), validation_step() and validation_epoch_end()
    val_loader - iterable of validation batches
    mode - accepted for call-site symmetry with fit(); not used here

    Returns whatever model.validation_epoch_end() returns for the collected
    per-batch outputs.
    """
    model.eval()
    step_outputs = []
    for batch in val_loader:
        step_outputs.append(model.validation_step(batch))
    return model.validation_epoch_end(step_outputs)

In [None]:
def fit(epochs, model, train_loader, val_loader, optimizer, mode):
    """
    Train model for a number of epochs, validating after each one.

    epochs - number of epochs
    model - the model in required device
    train_loader - train dataloader
    val_loader - val dataloader
    optimizer - optimizer, with parameters and parameters set
    mode - selects the loss that is optimised:
           "FINE_ONLY" -> fine loss,
           "FINE_AND_BROAD" / "BROAD_AND_FINE" -> sum of both losses,
           anything else (including "BROAD_ONLY") -> broad loss

    Returns a list with one metrics dict per epoch (validation metrics plus
    the mean training loss of both heads).
    """
    history = []
    for epoch in range(epochs):
        # Training Phase
        model.train()
        train_losses_broad = []
        train_losses_fine = []

        for batch in train_loader:
            loss_fine, loss_broad = model.training_step(batch)
            # Log detached copies: keeping the live tensors would hold the
            # autograd graph of every batch until the end of the epoch.
            train_losses_broad.append(loss_broad.detach())
            train_losses_fine.append(loss_fine.detach())

            if mode == "FINE_ONLY":
                loss = loss_fine
            elif mode in ("FINE_AND_BROAD", "BROAD_AND_FINE"):
                loss = loss_fine + loss_broad
            else:
                loss = loss_broad

            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

        # Validation phase
        result = evaluate(model, val_loader, mode)

        result["train_loss_fine"] = torch.stack(train_losses_fine).mean().item()
        result["train_loss_broad"] = torch.stack(train_losses_broad).mean().item()

        model.epoch_end(epoch, result, mode)
        history.append(result)
    return history

In [None]:
from torch.utils.data import random_split

In [None]:
from numpy.random import choice

# Class ids 0..9
class_label = list(range(10))

device = get_default_device()

# PIL image -> tensor -> float tensor
transform = transforms.Compose(
    [transforms.PILToTensor(), transforms.ConvertImageDtype(torch.float)]
)

In [None]:
from typing import Any, Tuple
from torch.utils.data import Dataset, DataLoader
from torchvision.datasets import CIFAR10


def get_broad_label(idx):
    """
    Map a fine class index (0..9) to its broad class (0 or 1).

    Fine classes 0, 1, 8 and 9 map to broad class 0, classes 2..7 to broad
    class 1 (presumably vehicles vs. animals in torchvision's CIFAR-10 class
    order - confirm). Indexing semantics are those of a 10-element sequence.
    """
    return (0, 0, 1, 1, 1, 1, 1, 1, 0, 0)[idx]


class CIFAR10MultiLabelDataset(CIFAR10):
    """
    CIFAR10 whose samples carry an additional broad label.

    Construction arguments and length are inherited unchanged from CIFAR10;
    only item access is extended.
    """

    def __getitem__(self, index: int) -> Tuple[Any, Any, int]:
        """Return ``(image, fine_label, broad_label)`` for the sample at index.

        The broad label is looked up from the fine label, so the fine label
        must still be usable as an index into the fine-to-broad table.
        """
        img_tensor, fine_label = super().__getitem__(index)
        return img_tensor, fine_label, get_broad_label(fine_label)

In [None]:
dataset = CIFAR10MultiLabelDataset(".", download=True, transform=transform)

TRAIN_DATA_SIZE = 25000
VAL_DATA_SIZE = 10000
REMAINING = len(dataset) - TRAIN_DATA_SIZE - VAL_DATA_SIZE

# A fixed generator makes the split identical in every session. Without it a
# model reloaded from a checkpoint could be validated on images it was
# trained on in an earlier session.
split_generator = torch.Generator().manual_seed(42)
train_ds, val_ds, _ = random_split(
    dataset,
    [TRAIN_DATA_SIZE, VAL_DATA_SIZE, REMAINING],
    generator=split_generator,
)

len(train_ds), len(val_ds)

In [None]:
# Training loader: reshuffled every epoch
train_dl = DataLoader(
    dataset=train_ds,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=2,
    pin_memory=True,
)
# Validation loader: fixed order, twice the training batch size
val_dl = DataLoader(
    dataset=val_ds,
    batch_size=BATCH_SIZE * 2,
    num_workers=2,
    pin_memory=True,
)

In [None]:
from dataclasses import dataclass


@dataclass(frozen=True)
class Mode:
    """Names of the supported training / model-selection modes.

    Frozen so that the shared ``all_modes`` instance cannot be modified by
    accident.
    """

    FINE_ONLY: str = "FINE_ONLY"
    BROAD_ONLY: str = "BROAD_ONLY"
    # NOTE(review): ImageClassificationBase.epoch_end spells the combined mode
    # "BROAD_AND_FINE" - confirm which spelling is canonical.
    FINE_AND_BROAD: str = "FINE_AND_BROAD"


# Shared instance used as a namespace, e.g. all_modes.BROAD_ONLY
all_modes = Mode()

In [None]:
# Experiment 2: 32x32 RGB input cut into 8x8 patches, 8 encoder layers
model = VisionTransformer(
    # input geometry
    in_channels=3,
    img_height=32,
    img_width=32,
    patch_dim=8,
    # encoder size
    emb_dim=768,
    pwff_dim=3072,
    num_layers=8,
    num_heads=8,
    dropout=0.5,
    # heads
    num_classes=10,
    num_broad_classes=2,  # default for CIFAR 10
)

model

In [None]:
LOAD_MODEL = False

if LOAD_MODEL:
    print("loading saved")
    # The checkpoint is a fully pickled module (written with
    # torch.save(self, ...)), so weights_only has to be False; unpickling can
    # execute arbitrary code, so only load files you created yourself.
    # map_location lets a checkpoint saved on a GPU open on a CPU-only host.
    model = torch.load(MODEL_SAVE_NAME, map_location=device, weights_only=False)

model = to_device(model, device)
print(f"Last Saved Model Validation Accuracy: {model.acc}")

# From here on the loaders yield batches that already live on the device
train_dl, val_dl = DeviceDataLoader(train_dl, device), DeviceDataLoader(val_dl, device)

# Sanity Test
evaluate(model, val_dl, all_modes.BROAD_ONLY)

In [None]:
# SGD over all model parameters with L2 weight decay
optimizer = torch.optim.SGD(
    params=model.parameters(),
    lr=LEARNING_RATE,
    weight_decay=WEIGHT_DECAY,
)
# Parameters
EPOCHS = 300

history = fit(
    epochs=EPOCHS,
    model=model,
    train_loader=train_dl,
    val_loader=val_dl,
    optimizer=optimizer,
    mode=all_modes.BROAD_ONLY,
)

In [None]:
# Validation metrics after training (displayed as the cell output)
evaluate(model=model, val_loader=val_dl, mode=all_modes.BROAD_ONLY)