<a href="https://colab.research.google.com/github/agrawalabr/deeplearning/blob/main/Vision%20Transformer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Vision Transformers (ViT) | Transformers in Computer Vision

Transformer architectures, originally designed for Natural Language Processing (NLP), have revolutionized the field of deep learning, forming the backbone of state-of-the-art models across various NLP applications. More recently, these architectures have been successfully adapted to computer vision tasks, leading to a new paradigm in image processing.

The [Vision Transformer (ViT)](https://arxiv.org/pdf/2010.11929), which gained wide attention in 2021, demonstrates that a standard transformer architecture can achieve competitive performance on image classification tasks. The core idea is to split an image into fixed-size patches, treat each flattened patch as a token, and process the resulting sequence with a stack of self-attention-based transformer blocks. This approach does away with convolutional neural networks (CNNs) entirely, offering a more flexible and scalable framework for vision tasks.

However, working with ViTs presents a few challenges:
- Computational Complexity: ViTs often require significantly more computational resources than CNNs, both because of their large number of parameters and because the cost of self-attention grows quadratically with the number of patches. Training from scratch demands large-scale datasets and extensive GPU/TPU resources.
- Interpretability: Unlike CNNs, which build up spatial hierarchies of local features, transformers rely on global attention across all patches, which can make them harder to interpret.
- Pretraining Dependency: Because ViTs lack the built-in inductive biases of CNNs (locality and translation equivariance), in most practical scenarios they are pre-trained on massive datasets (e.g., ImageNet-21k, JFT-300M) before being fine-tuned on specific downstream tasks. Training from scratch is feasible but often impractical for smaller datasets.

Despite these challenges, ViTs have shown impressive results in image classification, object detection, and segmentation, signaling a shift in deep learning methodologies for vision tasks. In this notebook, we will explore how to build a Vision Transformer from scratch, gaining insights into its structure, training requirements, and performance characteristics.

- Name: Abhishek Agrawal
- NetId: aa9360

In [None]:
!pip install einops
!pip install torchinfo

In [None]:
import torch
from torch import nn
from torch import nn, einsum
import torch.nn.functional as F
from torch import optim
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
import numpy as np
import torchvision
import time
from torchinfo import summary
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_split

[[0.64520577 0.01220981 0.41715171]
 [0.85467564 0.13222973 0.02282506]]
tensor([[0.6452, 0.0122, 0.4172],
        [0.8547, 0.1322, 0.0228]], dtype=torch.float64)


In [None]:
# Fix PyTorch's global RNG seed so the model initialisation in later cells is
# reproducible.
torch.manual_seed(42)

# Data / loader configuration.
# NOTE(review): the data-loading cell hard-codes './data', batch_size=64 and
# Normalize((0.5,), (0.5,)) rather than reading these values.
DOWNLOAD_PATH = './data'
BATCH_SIZE_TRAIN, BATCH_SIZE_TEST = 100, 1000
# Presumably the Fashion-MNIST per-pixel mean / std — TODO confirm.
MEAN, STD = 0.2859, 0.3530

tensor([[10.6452, 10.0122, 10.4172],
        [10.8547, 10.1322, 10.0228]], dtype=torch.float64)

tensor([[0.6014, 0.0122, 0.4052],
        [0.7544, 0.1318, 0.0228]], dtype=torch.float64)

tensor(2.0843, dtype=torch.float64)

tensor(0.3474, dtype=torch.float64)

torch.Size([2, 3])


In [None]:
# ToTensor scales pixel values to [0, 1]; Normalize((0.5,), (0.5,)) then maps
# them to [-1, 1] via (x - 0.5) / 0.5.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

# Fashion-MNIST train and test splits (train is built first, then test),
# downloaded into ./data if not already present.
trainingdata, testdata = (
    datasets.FashionMNIST(root='./data', train=split_is_train, download=True, transform=transform)
    for split_is_train in (True, False)
)

# Only the training batches are shuffled; test order stays fixed.
trainDataLoader = DataLoader(trainingdata, batch_size=64, shuffle=True)
testDataLoader = DataLoader(testdata, batch_size=64, shuffle=False)

In [None]:
def pair(t):
    """Return ``t`` unchanged if it is a tuple, otherwise the 2-tuple ``(t, t)``.

    Lets a size argument be given either as a single value or as an explicit
    ``(height, width)`` pair.
    """
    return t if isinstance(t, tuple) else (t, t)

In [None]:
class PreNorm(nn.Module):
    """Layer-normalise the input over its last dimension, then apply ``fn``.

    Args:
        dim: size of the last input dimension (the LayerNorm normalized shape).
        fn: module applied to the normalised tensor. Any keyword arguments
            given to ``forward`` are forwarded to it unchanged.
    """

    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        normalised = self.norm(x)
        wrapped = self.fn
        return wrapped(normalised, **kwargs)

class FeedForward(nn.Module):
    """Token-wise MLP applied to the last dimension.

    Layer order (indices in ``self.net``): 0 Linear(dim -> hidden_dim),
    1 ReLU, 2 Dropout, 3 Linear(hidden_dim -> dim), 4 Dropout.

    Args:
        dim: input and output feature size.
        hidden_dim: width of the hidden layer.
        dropout: drop probability used by both Dropout layers.
    """

    def __init__(self, dim, hidden_dim, dropout=0.):
        super().__init__()
        stages = [
            nn.Linear(dim, hidden_dim),
            nn.ReLU(),  # Alternative: nn.GELU()
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        mlp = self.net
        return mlp(x)

class Attention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    Args:
        dim: feature size of each input token (last dimension of ``x``).
        heads: number of attention heads.
        dim_head: feature size per head; the inner width is ``heads * dim_head``.
        dropout: drop probability applied to the attention weights and, when
            present, after the output projection.

    ``forward`` takes ``x`` of shape ``(batch, tokens, dim)`` and returns a
    tensor of the same shape.
    """

    def __init__(self, dim, heads=4, dim_head=64, dropout=0.):
        super().__init__()
        inner_dim = dim_head * heads
        # The output projection is only skippable when a single head already
        # produces exactly `dim` features.
        project_out = not (heads == 1 and dim_head == dim)
        self.heads = heads
        self.scale = dim_head ** -0.5

        self.attend = nn.Softmax(dim=-1)
        # Dropout on the attention weights. Previously `dropout` was silently
        # ignored whenever the output projection collapsed to nn.Identity().
        self.dropout = nn.Dropout(dropout)
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def _split_heads(self, t):
        # (b, n, h*d) -> (b, h, n, d)
        b, n, _ = t.shape
        return t.reshape(b, n, self.heads, -1).transpose(1, 2)

    def forward(self, x):
        q, k, v = (self._split_heads(t) for t in self.to_qkv(x).chunk(3, dim=-1))

        # (b, h, n, d) @ (b, h, d, n) -> (b, h, n, n) attention logits
        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
        attn = self.dropout(self.attend(dots))

        out = torch.matmul(attn, v)  # (b, h, n, d)
        b, _, n, _ = out.shape
        # (b, h, n, d) -> (b, n, h*d)
        out = out.transpose(1, 2).reshape(b, n, -1)

        return self.to_out(out)

class Transformer(nn.Module):
    """Encoder stack of ``depth`` pre-norm layers.

    Each layer applies ``PreNorm(Attention)`` and then ``PreNorm(FeedForward)``,
    each wrapped in a residual connection (``x = f(x) + x``), so the output has
    the same shape as the input.

    Args:
        dim: token feature size.
        depth: number of layers.
        heads: attention heads per layer.
        dim_head: feature size per attention head.
        mlp_dim: hidden width of each feed-forward block.
        dropout: dropout probability passed to attention and feed-forward.
    """

    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout=0.):
        super().__init__()
        stack = []
        for _ in range(depth):
            attention = PreNorm(dim, Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout))
            mlp = PreNorm(dim, FeedForward(dim, mlp_dim, dropout=dropout))
            stack.append(nn.ModuleList([attention, mlp]))
        self.layers = nn.ModuleList(stack)

    def forward(self, x):
        hidden = x
        for attention, mlp in self.layers:
            hidden = attention(hidden) + hidden
            hidden = mlp(hidden) + hidden
        return hidden

class ViT(nn.Module):
    """Vision Transformer (ViT) image classifier.

    The image is cut into non-overlapping patches. Each patch is flattened and
    linearly projected to ``dim``, a learnable CLS token is prepended, learnable
    position embeddings are added, and the token sequence is run through a
    ``Transformer`` encoder. A LayerNorm + Linear head then classifies either
    the CLS token output (``pool='cls'``) or the mean over all tokens
    (``pool='mean'``).

    Args:
        image_size: image side length, or an ``(height, width)`` tuple.
        patch_size: patch side length, or an ``(height, width)`` tuple; must
            evenly divide ``image_size`` along each axis.
        num_classes: number of output logits.
        dim: token embedding size.
        depth: number of transformer layers.
        heads: attention heads per layer.
        mlp_dim: hidden width of each feed-forward block.
        pool: ``'cls'`` or ``'mean'`` (see above).
        channels: number of input image channels.
        dim_head: feature size per attention head.
        dropout: dropout inside the transformer layers.
        emb_dropout: dropout applied to the token embeddings before the encoder.

    ``forward(img)`` expects ``img`` of shape ``(batch, channels, height, width)``
    and returns logits of shape ``(batch, num_classes)``.
    """

    def __init__(self, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim,
                 pool='cls', channels=3, dim_head=64, dropout=0., emb_dropout=0.):
        super().__init__()

        # Accept a single int (square) or an explicit (height, width) tuple.
        image_h, image_w = image_size if isinstance(image_size, tuple) else (image_size, image_size)
        patch_h, patch_w = patch_size if isinstance(patch_size, tuple) else (patch_size, patch_size)

        assert image_h % patch_h == 0 and image_w % patch_w == 0, \
            "Image dimensions must be divisible by patch size."

        num_patches = (image_h // patch_h) * (image_w // patch_w)
        patch_dim = channels * patch_h * patch_w

        assert pool in {'cls', 'mean'}, "pool type must be either 'cls' (CLS token) or 'mean' (mean pooling)."

        self.to_patch_embedding = nn.Sequential(
            # `Rearrange` is the nn.Module form. The functional `rearrange`
            # needs a tensor argument and cannot be used as a layer here.
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=patch_h, p2=patch_w),
            nn.Linear(patch_dim, dim)
        )

        # One learned position vector per patch plus one for the CLS token.
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))

        self.dropout = nn.Dropout(emb_dropout)

        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)

        self.pool = pool
        self.mlp_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes)
        )

    def forward(self, img):
        x = self.to_patch_embedding(img)  # (b, num_patches, dim)
        b, n, _ = x.shape

        # Prepend a copy of the learnable CLS token to every sequence in the batch.
        cls_tokens = repeat(self.cls_token, '() n d -> b n d', b=b)
        x = torch.cat([cls_tokens, x], dim=1)

        x += self.pos_embedding[:, :(n + 1)]
        x = self.dropout(x)

        x = self.transformer(x)

        # Reduce each sequence to one vector: the CLS output or the token mean.
        x = x[:, 0] if self.pool == 'cls' else x.mean(dim=1)

        return self.mlp_head(x)

In [None]:
# ViT for 28x28 single-channel images: 4x4 patches give 7 * 7 = 49 patch
# tokens (plus one CLS token), classified into 10 classes.
model = ViT(
    image_size=28,
    patch_size=4,
    num_classes=10,
    channels=1,
    dim=64,
    depth=6,
    heads=4,
    mlp_dim=128,
)
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

In [None]:
# Train `model` for `num_epochs` epochs. After each epoch, evaluate on the test
# loader and record per-epoch loss and accuracy in the *_history lists.
try:
    from tqdm import tqdm
except ImportError:
    # tqdm only draws progress bars; fall back to plain iteration without it.
    def tqdm(iterable, **_kwargs):
        return iterable

train_loss_history = []
test_loss_history = []
train_accuracy_history = []
test_accuracy_history = []

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
num_epochs = 30

for epoch in range(num_epochs):
    model.train()
    train_loss = 0.0
    correct_train = 0
    total_train = 0

    loop = tqdm(trainDataLoader, desc=f"Epoch {epoch+1}/{num_epochs} - Training", leave=False)
    for images, labels in loop:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        predicted_output = model(images)
        loss = loss_fn(predicted_output, labels)
        loss.backward()
        optimizer.step()

        # Assumes loss_fn returns the batch mean (nn.CrossEntropyLoss default).
        # Weighting by batch size makes the epoch loss a true per-sample mean,
        # even when the last batch is smaller.
        train_loss += loss.item() * labels.size(0)

        predicted = predicted_output.argmax(dim=1)
        total_train += labels.size(0)
        correct_train += (predicted == labels).sum().item()

    train_accuracy = correct_train / total_train
    train_loss = train_loss / total_train

    train_loss_history.append(train_loss)
    train_accuracy_history.append(train_accuracy)

    model.eval()
    test_loss = 0.0
    correct_test = 0
    total_test = 0

    with torch.no_grad():
        loop = tqdm(testDataLoader, desc="Evaluating", leave=False)
        for images, labels in loop:
            images, labels = images.to(device), labels.to(device)

            predicted_output = model(images)
            loss = loss_fn(predicted_output, labels)
            test_loss += loss.item() * labels.size(0)

            predicted = predicted_output.argmax(dim=1)
            total_test += labels.size(0)
            correct_test += (predicted == labels).sum().item()

    test_accuracy = correct_test / total_test
    test_loss = test_loss / total_test

    test_loss_history.append(test_loss)
    test_accuracy_history.append(test_accuracy)

    print(f"Epoch {epoch+1}/{num_epochs} | "
          f"Train Loss: {train_loss:.4f}, Train Acc: {train_accuracy:.4f} | "
          f"Test Loss: {test_loss:.4f}, Test Acc: {test_accuracy:.4f}")