# Vision Transformer (ViT)

Vision Transformers (ViTs) apply the Transformer architecture—originally designed for NLP—directly to image recognition tasks.

## Key Concepts
- **Image patches**: An image is divided into fixed-size patches (e.g., 16×16 pixels).
- **Linear embedding**: Each patch is flattened and projected into a vector embedding.
- **Positional embedding**: Self-attention is order-agnostic, so a learnable positional embedding is added to each token to retain the spatial arrangement of the patches.
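- **Transformer encoder**: A stack of blocks, each combining multi-head self-attention with a feed-forward network (plus residual connections and layer normalization), processes the sequence of patch embeddings.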
- **Classification token ([CLS])**: A learnable token prepended to the sequence; its final representation is used for classification.

ViTs have achieved state-of-the-art results in image classification and beyond.

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class PatchEmbedding(nn.Module):
    """Convert an image batch into a token sequence with a leading [CLS] token.

    A convolution whose kernel size and stride both equal ``patch_size``
    projects every non-overlapping patch to an ``emb_size`` vector. A learnable
    classification token is prepended to the patch tokens, and a learnable
    positional embedding is added to every token.

    Args:
        in_channels: Number of channels in the input images.
        patch_size: Side length, in pixels, of each square patch.
        emb_size: Dimensionality of every token embedding.
        img_size: Side length of the square images the positional table is
            sized for.
    """

    def __init__(self, in_channels=3, patch_size=16, emb_size=768, img_size=224):
        super().__init__()
        self.patch_size = patch_size
        # kernel_size == stride, so each output position covers exactly one patch.
        self.projection = nn.Conv2d(
            in_channels,
            emb_size,
            kernel_size=patch_size,
            stride=patch_size,
        )
        self.cls_token = nn.Parameter(torch.randn(1, 1, emb_size))
        patches_per_side = img_size // patch_size
        # One row per patch, plus one extra row for the classification token.
        num_positions = patches_per_side ** 2 + 1
        self.pos_embedding = nn.Parameter(torch.randn(num_positions, emb_size))

    def forward(self, x):
        """Embed ``x`` and return tokens of shape (batch, num_patches + 1, emb_size).

        Args:
            x: Image batch laid out as (batch, in_channels, height, width).
        """
        batch_size = x.size(0)
        feature_map = self.projection(x)
        # Collapse the spatial grid into one sequence axis, then put channels last.
        patch_tokens = feature_map.flatten(start_dim=2).transpose(1, 2)
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        tokens = torch.cat((cls_tokens, patch_tokens), dim=1)
        # The 2-D positional table broadcasts over the batch dimension.
        return tokens + self.pos_embedding

class TransformerEncoder(nn.Module):
    """A single post-norm Transformer encoder block.

    Multi-head self-attention and a two-layer GELU feed-forward network are
    each wrapped in a residual connection. Dropout is applied to the sub-layer
    output before the residual sum, and layer normalization is applied after it.

    Args:
        emb_size: Dimensionality of the token embeddings.
        num_heads: Number of attention heads.
        hidden_dim: Width of the feed-forward network's hidden layer.
        dropout: Dropout probability applied to each sub-layer's output
            (the attention module itself is built without dropout).
    """

    def __init__(self, emb_size=768, num_heads=8, hidden_dim=2048, dropout=0.1):
        super().__init__()
        self.attn = nn.MultiheadAttention(emb_size, num_heads, batch_first=True)
        feed_forward_layers = [
            nn.Linear(emb_size, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, emb_size),
        ]
        self.ff = nn.Sequential(*feed_forward_layers)
        self.ln1 = nn.LayerNorm(emb_size)
        self.ln2 = nn.LayerNorm(emb_size)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x):
        """Apply the block to ``x`` and return a tensor of the same shape.

        Args:
            x: Token sequence laid out as (batch, sequence, emb_size).
        """
        # Self-attention: queries, keys and values are all the input sequence.
        # The second element of the returned tuple (attention weights) is unused.
        attn_output = self.attn(x, x, x)[0]
        normed = self.ln1(x + self.dropout(attn_output))
        ff_output = self.ff(normed)
        return self.ln2(normed + self.dropout(ff_output))

class VisionTransformer(nn.Module):
    """Vision Transformer classifier: patch embedding, encoder stack, linear head.

    Args:
        img_size: Side length of the square input images.
        patch_size: Side length of each square patch.
        in_channels: Number of channels in the input images.
        num_classes: Number of output classes (width of the final linear layer).
        emb_size: Dimensionality of the token embeddings.
        depth: Number of encoder blocks.
        num_heads: Number of attention heads in every encoder block.
        hidden_dim: Hidden width of each encoder block's feed-forward network.
        dropout: Dropout probability used inside every encoder block.
    """

    def __init__(self, img_size=224, patch_size=16, in_channels=3, num_classes=10,
                 emb_size=768, depth=6, num_heads=8, hidden_dim=2048, dropout=0.1):
        super().__init__()
        self.embedding = PatchEmbedding(in_channels, patch_size, emb_size, img_size)
        # hidden_dim and dropout are forwarded explicitly so they can be tuned
        # from here; their defaults equal TransformerEncoder's own defaults.
        self.encoders = nn.ModuleList([
            TransformerEncoder(emb_size, num_heads, hidden_dim=hidden_dim, dropout=dropout)
            for _ in range(depth)
        ])
        self.mlp_head = nn.Sequential(
            nn.LayerNorm(emb_size),
            nn.Linear(emb_size, num_classes)
        )

    def forward(self, x):
        """Return class logits of shape (batch, num_classes) for an image batch.

        Args:
            x: Image batch laid out as (batch, in_channels, height, width).
        """
        x = self.embedding(x)
        for encoder in self.encoders:
            x = encoder(x)
        # Only the first token (the prepended classification token) feeds the head.
        cls_token = x[:, 0]
        return self.mlp_head(cls_token)

# Example usage: build a classifier and confirm the shape of its output.
# The model is configured for 224x224 inputs, 16x16 patches and 10 classes.
model = VisionTransformer(img_size=224, patch_size=16, num_classes=10)
# Random stand-in batch: 4 images, 3 channels, 224x224 pixels.
dummy_input = torch.randn(4, 3, 224, 224)
# Forward pass. The module is left in its default training mode, so dropout
# is active; only the output shape is inspected here.
output = model(dummy_input)
print(output.shape)  # torch.Size([4, 10])

## Summary
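- Vision Transformers split images into patches and process them as a token sequence rather than through a stack of convolutional layers (the single strided convolution in the code above is just an efficient way to apply the per-patch linear projection).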
- Self-attention enables learning long-range dependencies.
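- Pretraining on large datasets (like ImageNet-21k or JFT-300M) is crucial for ViTs to match or surpass CNNs, because they lack the built-in inductive biases of convolutions (locality and translation equivariance).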
- ViTs compete with CNNs in image classification and extend well to detection, segmentation, and vision-language tasks.