# Vision Transformer (ViT) from Scratch

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adiel2012/deep-learning-abc/blob/main/vision_transformer.ipynb)

ViT applies the pure Transformer architecture directly to sequences of image patches.

Key Steps:
1. **Patchify:** Split image into fixed-size patches.
2. **Linear Projection:** Flatten patches and map to `d_model`.
3. **Position Embeddings:** Add learnable position vectors.
4. **Transformer Encoder:** A stack of standard Transformer encoder blocks (BERT-like, but with LayerNorm applied before each sub-layer).
5. **Classification Head:** A linear layer on the final `[CLS]` token representation.

In [None]:
!pip install torch torchvision matplotlib

In [None]:
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import math

# Prefer the GPU when CUDA is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f'Using device: {device}')

## 1. Patch Embeddings

Convert an image of shape (C, H, W) into a sequence of shape (N, d_model), where N = (H / P) * (W / P) is the number of P x P patches.
Can be implemented using a Conv2d layer with kernel_size = stride = patch_size.

In [None]:
class PatchEmbed(nn.Module):
    """Turn an image batch into a sequence of patch embeddings.

    A ``Conv2d`` whose kernel size and stride both equal ``patch_size`` cuts
    the image into non-overlapping patches and linearly projects each one to
    ``d_model`` channels in a single step.

    Args:
        img_size: Side length of the (square) input image in pixels.
        patch_size: Side length of each square patch in pixels.
        in_chans: Number of input image channels.
        d_model: Embedding dimension produced for every patch.
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3, d_model=768):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size

        # Patches per side, squared: the patch count for an
        # img_size x img_size input.
        patches_per_side = img_size // patch_size
        self.n_patches = patches_per_side ** 2

        # Kernel == stride, so each output position sees exactly one patch.
        self.proj = nn.Conv2d(
            in_chans,
            d_model,
            kernel_size=patch_size,
            stride=patch_size,
        )

    def forward(self, x):
        """Map ``[B, C, H, W]`` images to ``[B, N_patches, d_model]`` tokens."""
        feature_map = self.proj(x)                  # [B, d_model, H/P, W/P]
        tokens = feature_map.flatten(start_dim=2)   # [B, d_model, N_patches]
        return tokens.permute(0, 2, 1)              # [B, N_patches, d_model]

## 2. Transformer Encoder (Standard)

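A standard Transformer encoder, similar to BERT's but with pre-norm blocks (LayerNorm is applied *before* the attention and MLP sub-layers) and no attention mask.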

In [None]:
class Attention(nn.Module):
    """Multi-head self-attention over a token sequence.

    Args:
        dim: Embedding dimension of the input tokens.
        n_heads: Number of attention heads; must be positive and divide ``dim``.

    Raises:
        ValueError: If ``n_heads`` is not positive or does not divide ``dim``.
    """

    def __init__(self, dim, n_heads):
        super().__init__()
        # Fail fast with a clear message: otherwise a bad head count only
        # surfaces as a ZeroDivisionError below or a reshape error in forward().
        if n_heads <= 0 or dim % n_heads != 0:
            raise ValueError(
                f"dim ({dim}) must be divisible by a positive n_heads ({n_heads})"
            )
        self.n_heads = n_heads
        # Scale logits by 1 / sqrt(head_dim) before the softmax.
        self.scale = (dim // n_heads) ** -0.5
        # One fused projection produces queries, keys and values together.
        self.qkv = nn.Linear(dim, dim * 3, bias=False)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x):
        """Apply self-attention to ``x`` of shape ``[B, N, C]``; shape is preserved."""
        B, N, C = x.shape
        # [B, N, 3C] -> [B, N, 3, heads, head_dim] -> [3, B, heads, N, head_dim]
        qkv = self.qkv(x).reshape(B, N, 3, self.n_heads, C // self.n_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        # Scaled dot-product scores, normalised over the key axis.
        attn = (q @ k.transpose(-2, -1)) * self.scale  # [B, heads, N, N]
        attn = attn.softmax(dim=-1)

        # Weighted sum of values, then merge the heads back into C channels.
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        return x

class Block(nn.Module):
    """Pre-norm Transformer encoder block: self-attention, then an MLP.

    Both sub-layers are preceded by LayerNorm and wrapped in a residual
    connection.

    Args:
        dim: Embedding dimension of the tokens.
        n_heads: Number of attention heads passed to ``Attention``.
        mlp_ratio: Hidden width of the MLP as a multiple of ``dim``.
    """

    def __init__(self, dim, n_heads, mlp_ratio=4.):
        super().__init__()
        hidden_dim = int(dim * mlp_ratio)

        self.norm1 = nn.LayerNorm(dim)
        self.attn = Attention(dim, n_heads)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, dim),
        )

    def forward(self, x):
        """Run attention and MLP sub-layers with residuals; shape is preserved."""
        attended = self.attn(self.norm1(x))
        x = x + attended
        transformed = self.mlp(self.norm2(x))
        return x + transformed

## 3. Full ViT Model

In [None]:
class VisionTransformer(nn.Module):
    """ViT classifier: patch embedding, encoder blocks, linear head on CLS.

    Args:
        img_size: Side length of the (square) input image in pixels.
        patch_size: Side length of each square patch in pixels.
        in_chans: Number of input image channels.
        num_classes: Number of output logits.
        d_model: Token embedding dimension used throughout the model.
        depth: Number of encoder blocks.
        n_heads: Number of attention heads in every block.
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000,
                 d_model=768, depth=12, n_heads=12):
        super().__init__()
        self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, d_model)

        # Learnable CLS token plus one position vector per token
        # (the CLS slot and every patch); both start at zero.
        seq_len = 1 + self.patch_embed.n_patches
        self.cls_token = nn.Parameter(torch.zeros(1, 1, d_model))
        self.pos_embed = nn.Parameter(torch.zeros(1, seq_len, d_model))

        encoder_layers = [Block(d_model, n_heads) for _ in range(depth)]
        self.blocks = nn.ModuleList(encoder_layers)

        self.norm = nn.LayerNorm(d_model)
        self.head = nn.Linear(d_model, num_classes)

    def forward(self, x):
        """Return class logits ``[B, num_classes]`` for images ``[B, C, H, W]``."""
        batch_size = x.shape[0]
        tokens = self.patch_embed(x)

        # Prepend one copy of the CLS token to every sequence in the batch.
        cls = self.cls_token.expand(batch_size, -1, -1)
        tokens = torch.cat([cls, tokens], dim=1)

        # Position embeddings broadcast across the batch dimension.
        tokens = tokens + self.pos_embed

        for encoder_block in self.blocks:
            tokens = encoder_block(tokens)

        # Only the normalised CLS token (sequence position 0) feeds the head.
        tokens = self.norm(tokens)
        return self.head(tokens[:, 0])

# Build ViT-Base (16x16 patches, width 768, 12 blocks, 12 heads), then move it
# to the selected device and report its parameter count in millions.
model = VisionTransformer(
    img_size=224,
    patch_size=16,
    d_model=768,
    depth=12,
    n_heads=12,
)
model = model.to(device)
print(f"ViT-Base Initialized: {sum(p.numel() for p in model.parameters())/1e6:.1f}M params")

## 4. Visualize Patch Embeddings

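Run a forward pass on a random image batch, then plot the first 32 patch-projection filters (the `Conv2d` kernels that embed each patch). The model is untrained, so the filters are just random noise.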

In [None]:
# Fake image batch. This is inference only, so run it under no_grad to avoid
# building an autograd graph for the whole model.
img = torch.randn(1, 3, 224, 224, device=device)
with torch.no_grad():
    output = model(img)
print(f"Output shape: {output.shape} (Batch, Classes)")

# Visualize filters of first layer.
# detach() gives a gradient-free view of the weights; it is the supported
# replacement for the legacy `.data` attribute.
filters = model.patch_embed.proj.weight.detach().cpu()
print(f"Patch Filters: {filters.shape}")

fig, axes = plt.subplots(4, 8, figsize=(10, 5))
for i, ax in enumerate(axes.flatten()):
    # Channels-last layout [P, P, C] so imshow treats the filter as an RGB image.
    f = filters[i].permute(1, 2, 0)
    # Min-max normalize to [0, 1]; clamp the range so a constant filter
    # yields zeros instead of NaNs from a 0 / 0 division.
    f_min, f_max = f.min(), f.max()
    f = (f - f_min) / (f_max - f_min).clamp(min=1e-8)
    ax.imshow(f)
    ax.axis('off')
plt.suptitle('First 32 Patch Projection Filters (Random Init)')
plt.show()