# Vision Transformer (ViT) Detailed Explanation

This document explains how the `VisionTransformer` class works in PyTorch, focusing on the `__init__` and `forward` methods, with concrete examples.

```python
import torch
import torch.nn as nn
import einops  # used to reshape attention maps in forward()

# `LayerNorm` and `Transformer` are not defined in this snippet; they are assumed to be
# provided by the surrounding codebase. `Transformer` is assumed to take input of shape
# [seq_len, batch, width] and to return `(x, attn)`, with `attn` shaped
# [n_layers, batch, n_heads, seq_len, seq_len].

class VisionTransformer(nn.Module):

    def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int,
                 in_channels: int = 3):
        super().__init__()
        self.input_resolution = input_resolution
        self.output_dim = output_dim

        # Convolution that converts each non-overlapping patch into a `width`-dim embedding
        # (kernel_size == stride == patch_size)
        self.conv1 = nn.Conv2d(
            in_channels=in_channels,
            out_channels=width,
            kernel_size=patch_size,
            stride=patch_size,
            bias=False,
        )

        # Scaling factor for initialization (1 / sqrt(width))
        scale = width ** -0.5

        # Learnable class embedding, shared by every image
        self.class_embedding = nn.Parameter(scale * torch.randn(width))

        # Learnable positional embedding (sequence length = num_patches + 1)
        self.positional_embedding = nn.Parameter(
            scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width)
        )

        # LayerNorm before Transformer
        self.ln_pre = LayerNorm(width)

        # Transformer blocks
        self.transformer = Transformer(width, layers, heads)

        # LayerNorm after Transformer
        self.ln_post = LayerNorm(width)

        # Projection matrix to final output dimension
        self.proj = nn.Parameter(scale * torch.randn(width, output_dim))

    def forward(self, x: torch.Tensor, output_all_features: bool = False, output_attention_map: bool = False):
        # Step 1: Convert image patches into embeddings using conv
        x = self.conv1(x)  # shape = [batch, width, grid, grid]
        grid = x.size(2)

        # Step 2: Flatten patches
        x = x.reshape(x.shape[0], x.shape[1], -1)  # shape = [batch, width, grid**2]
        x = x.permute(0, 2, 1)  # shape = [batch, num_patches, width]

        # Step 3: Prepend class token (one copy per image via broadcasting)
        batch_class_token = self.class_embedding.to(x.dtype) + torch.zeros(
            x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device
        )
        x = torch.cat([batch_class_token, x], dim=1)  # shape = [batch, num_patches+1, width]

        # Step 4: Add positional embeddings (broadcast over the batch)
        x = x + self.positional_embedding.to(x.dtype)

        # Step 5: LayerNorm before Transformer
        x = self.ln_pre(x)

        # Step 6: Permute for Transformer input
        x = x.permute(1, 0, 2)  # shape = [seq_len, batch, width]

        # Step 7: Pass through Transformer
        x, attn = self.transformer(x)
        x = x.permute(1, 0, 2)  # shape = [batch, seq_len, width]

        # Step 8: Normalize and project the class token embedding
        cls_feature = self.ln_post(x[:, 0, :]) @ self.proj  # shape = [batch, output_dim]

        # Step 9: Prepare outputs
        outputs = (cls_feature,)
        if output_all_features:
            outputs += (x[:, 1:, :],)  # patch embeddings, shape = [batch, num_patches, width]
        if output_attention_map:
            # Class token (query 0) attending to each patch, reshaped to the patch grid
            outputs += (einops.rearrange(attn[:, :, :, 0, 1:],
                                         'n_layers b n_heads (h w) -> n_layers b n_heads h w',
                                         h=grid, w=grid),)

        return outputs
```
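The `output_attention_map` branch depends on the attention tensor returned by `Transformer`. Since that module is not defined here, its shape is assumed to be `[n_layers, batch, n_heads, seq_len, seq_len]`. The minimal sketch below uses a random stand-in tensor (all sizes are hypothetical, chosen to match the worked example) to show what `attn[:, :, :, 0, 1:]` selects: row 0, the class token acting as the query, attending to every patch. That row is then folded back into the `grid × grid` layout. Plain `reshape` replaces `einops.rearrange` here; for the `(h w) -> h w` pattern the two are equivalent because both split the last axis row-major, which matches how Step 2 flattened the patch grid.

```python
import torch

# Hypothetical sizes matching the worked example: 2 images, 4 patches (2x2 grid)
n_layers, batch, n_heads, grid = 2, 2, 1, 2
seq_len = grid * grid + 1  # class token + patches

# Stand-in attention weights: softmax over keys so each query row sums to 1
torch.manual_seed(0)
attn = torch.randn(n_layers, batch, n_heads, seq_len, seq_len).softmax(dim=-1)

# Query 0 is the class token; keys 1: are the patches
cls_to_patches = attn[:, :, :, 0, 1:]  # [n_layers, batch, n_heads, grid*grid]

# Equivalent to einops.rearrange(cls_to_patches,
#     'n_layers b n_heads (h w) -> n_layers b n_heads h w', h=grid, w=grid)
attn_map = cls_to_patches.reshape(n_layers, batch, n_heads, grid, grid)

print(attn_map.shape)  # torch.Size([2, 2, 1, 2, 2])
```

Entry `[..., r, c]` of the map is the class token's attention to the patch in grid row `r`, column `c`.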

### Explanation with Example

* **Input batch**: 2 images of size 6x6, 1 channel each (so the model is built with `in_channels=1`)
* **Patch size**: 3x3 → each image has `(6/3)^2 = 4` patches
* **Width (embedding dimension)**: 3
* **Class embedding**: shape `(width,) = (3,)`, shared by all images (every sequence starts with the same vector)
* **Positional embedding**: shape `(num_patches + 1, width) = (5, 3)`
* **Convolution (`conv1`)**: converts each 3x3 patch to a vector of size `width` (3)
* **Transformer**: processes the sequence of 5 tokens (class token + 4 patch embeddings)
* **Projection**: maps the class token's Transformer output from `width` to `output_dim`
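The shapes in this configuration can be checked directly for the patch-embedding stage. The minimal sketch below uses random pixel values and builds only `conv1`, the class embedding, and the positional embedding, with the same shapes as in `__init__`. It also confirms the claim that `conv1` "converts each 3x3 patch to a vector": because `kernel_size == stride`, the convolution equals a linear map applied to each flattened patch. This is shown with `torch.nn.functional.unfold`; the names `patches` and `as_linear` are illustrative only.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
batch, in_channels, resolution, patch_size, width = 2, 1, 6, 3, 3
grid = resolution // patch_size          # 2 patches per side
num_patches = grid ** 2                  # 4 patches per image

conv1 = nn.Conv2d(in_channels, width, kernel_size=patch_size, stride=patch_size, bias=False)
class_embedding = torch.randn(width)                        # (3,)
positional_embedding = torch.randn(num_patches + 1, width)  # (5, 3)

images = torch.randn(batch, in_channels, resolution, resolution)  # (2, 1, 6, 6)
with torch.no_grad():
    feat = conv1(images)                                      # (2, 3, 2, 2)
    tokens = feat.reshape(batch, width, -1).permute(0, 2, 1)  # (2, 4, 3)

    # Same result as a linear map on flattened patches:
    # unfold gives (batch, in_channels * 3 * 3, num_patches)
    patches = F.unfold(images, kernel_size=patch_size, stride=patch_size)
    as_linear = patches.transpose(1, 2) @ conv1.weight.reshape(width, -1).T  # (2, 4, 3)

print(feat.shape, tokens.shape, positional_embedding.shape)
# torch.Size([2, 3, 2, 2]) torch.Size([2, 4, 3]) torch.Size([5, 3])
```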

---


### Step-by-step Explanation with Example

#### **Input patches example**

```python
# Batch of 2 images, 4 patches per image, embedding dim = 3
x = torch.tensor([
    [[0.1,0.2,0.3], [0.4,0.5,0.6], [0.7,0.8,0.9], [1.0,1.1,1.2]],  # image 1
    [[1.1,1.2,1.3], [1.4,1.5,1.6], [1.7,1.8,1.9], [2.0,2.1,2.2]]   # image 2
])  # shape = (2, 4, 3)
```

### **Adding the Class Token to the Batch**

#### **Class embedding**

```python
class_embedding = torch.tensor([0.5, 0.6, 0.7])  # shape = (width,) = (3,)
```

* Learnable vector shared by all images: it is updated during training, but every image's sequence starts with this same vector.

#### **Expanding class token for batch**

```python
batch_class_token = class_embedding + torch.zeros(x.shape[0], 1, x.shape[-1])
# shape = (2, 1, 3)
```

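* `torch.zeros(batch_size, 1, width)` only supplies the target shape `(batch_size, 1, width)`; its values are all zero.
* Adding the `(width,)`-shaped `class_embedding` broadcasts it across that shape, giving one copy per image in the batch.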

```
batch_class_token = [
 [[0.5, 0.6, 0.7]],  # image 1
 [[0.5, 0.6, 0.7]]   # image 2
]
```
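The zeros-plus-broadcast trick is one way to get a per-image copy. The short check below, using the example values, compares it with `expand`, which is an alternative and not what the class uses. `expand` produces the same values as a view without allocating a zeros tensor; `torch.cat` in the next step copies the data either way.

```python
import torch

class_embedding = torch.tensor([0.5, 0.6, 0.7])  # (width,)
batch_size, width = 2, 3

# Approach used in forward(): broadcast (3,) against (2, 1, 3)
via_zeros = class_embedding + torch.zeros(batch_size, 1, width)

# Alternative (not used by the class): expand a (1, 1, 3) view to (2, 1, 3)
via_expand = class_embedding.reshape(1, 1, width).expand(batch_size, 1, width)

print(via_zeros.shape)                     # torch.Size([2, 1, 3])
print(torch.equal(via_zeros, via_expand))  # True
```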

#### **Concatenating class token with patches**

```python
x = torch.cat([batch_class_token, x], dim=1)  # shape = (2, 5, 3)
```

* Sequence per image: `[CLS, P1, P2, P3, P4]`
* Now each image in the batch has its own copy of the class token prepended.
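Running the concatenation on the example tensors confirms the layout: position 0 of every sequence holds the class token, and positions 1 to 4 are the original patches, unchanged.

```python
import torch

x = torch.tensor([
    [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9], [1.0, 1.1, 1.2]],  # image 1
    [[1.1, 1.2, 1.3], [1.4, 1.5, 1.6], [1.7, 1.8, 1.9], [2.0, 2.1, 2.2]],  # image 2
])  # (2, 4, 3)
class_embedding = torch.tensor([0.5, 0.6, 0.7])
batch_class_token = class_embedding + torch.zeros(x.shape[0], 1, x.shape[-1])  # (2, 1, 3)

seq = torch.cat([batch_class_token, x], dim=1)  # (2, 5, 3): [CLS, P1, P2, P3, P4]

print(seq.shape)  # torch.Size([2, 5, 3])
print(seq[0, 0])  # CLS token of image 1
```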

✅ **Key points:**

* The class embedding is **shared by all images** before the Transformer.
* Each image gets a **separate copy** in the batch via broadcasting.
* After Transformer, the class token becomes **image-specific** through attention to patches.
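To see why the class token stops being identical, the sketch below applies a single `nn.MultiheadAttention` layer with random weights to the two example sequences. This layer is a stand-in, assumed to behave like one attention sublayer of the document's `Transformer`, not the actual module. The class-token inputs are identical. Each class-token output, however, is a weighted mix of its own image's tokens, so the outputs differ.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
width = 3

x = torch.tensor([
    [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9], [1.0, 1.1, 1.2]],
    [[1.1, 1.2, 1.3], [1.4, 1.5, 1.6], [1.7, 1.8, 1.9], [2.0, 2.1, 2.2]],
])
cls = torch.tensor([0.5, 0.6, 0.7]).expand(2, 1, width)
seq = torch.cat([cls, x], dim=1)  # (batch, seq_len, width) = (2, 5, 3)

# Stand-in attention layer; its default input layout is (seq_len, batch, width)
attn_layer = nn.MultiheadAttention(embed_dim=width, num_heads=1)
seq_first = seq.permute(1, 0, 2)  # (5, 2, 3)
out, weights = attn_layer(seq_first, seq_first, seq_first)

cls_in = seq[:, 0, :]  # class-token inputs, one row per image
cls_out = out[0]       # class-token outputs, (batch, width)
print(torch.equal(cls_in[0], cls_in[1]))       # True
print(torch.allclose(cls_out[0], cls_out[1]))  # False
```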


---

#### **Add positional embeddings**

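* Positional embedding shape: `(5, 3)`, added to every image in the batch (broadcast over the batch dimension), so each position gets the same offset in every image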
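The broadcast in Step 4 can be sketched with random stand-in values: adding a `(5, 3)` table to a `(2, 5, 3)` batch adds the same table to each image.

```python
import torch

torch.manual_seed(0)
batch, seq_len, width = 2, 5, 3

seq = torch.randn(batch, seq_len, width)           # [CLS, P1..P4] per image
positional_embedding = torch.randn(seq_len, width)  # one vector per position

# (2, 5, 3) + (5, 3): broadcasting adds the same (5, 3) table to every image
out = seq + positional_embedding

print(out.shape)  # torch.Size([2, 5, 3])
```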

#### **LayerNorm & Transformer**

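* `ln_pre` normalizes each token's embedding along the width dimension
* The sequence is permuted to `(seq_len, batch, width)`, and the Transformer applies self-attention across all 5 tokens (class token + 4 patches)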
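Steps 5 to 7 can be sketched with stand-ins, since the document's `LayerNorm` and `Transformer` are not shown. `nn.LayerNorm` replaces `LayerNorm`, and a single `nn.TransformerEncoderLayer` (dropout disabled) replaces `Transformer`; both are assumptions, and the real modules may differ internally. Like the class, the encoder layer expects `(seq_len, batch, width)` by default, hence the permutes.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
batch, seq_len, width = 2, 5, 3
x = torch.randn(batch, seq_len, width)

ln_pre = nn.LayerNorm(width)  # stand-in for the document's LayerNorm
normed = ln_pre(x)            # each token normalized over its 3 features

# Stand-in for the document's Transformer: one encoder layer, default layout (seq_len, batch, width)
layer = nn.TransformerEncoderLayer(d_model=width, nhead=1, dropout=0.0)
h = normed.permute(1, 0, 2)  # (5, 2, 3)
h = layer(h)                 # (5, 2, 3); every token attends to all 5 tokens
h = h.permute(1, 0, 2)       # back to (2, 5, 3)

print(h.shape)  # torch.Size([2, 5, 3])
```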

#### **Class token output**

* `cls_feature = ln_post(x[:, 0, :]) @ proj` → image-specific representation of shape `(batch, output_dim)`
* Even though the **initial class token was the same for all images**, attention mixes in each image's patch information, so its output differs per image
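The output step reduces to indexing and a matrix product. In the sketch below, `x` is a random stand-in for the Transformer output, `nn.LayerNorm` stands in for the document's `LayerNorm`, and `output_dim = 4` is an arbitrary choice.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
batch, seq_len, width, output_dim = 2, 5, 3, 4  # output_dim = 4 is arbitrary

x = torch.randn(batch, seq_len, width)  # stand-in Transformer output, (batch, seq_len, width)
ln_post = nn.LayerNorm(width)           # stand-in for the document's LayerNorm
proj = torch.randn(width, output_dim) * width ** -0.5

cls_feature = ln_post(x[:, 0, :]) @ proj  # (2, 3) @ (3, 4) -> (2, 4)
patch_features = x[:, 1:, :]              # (2, 4, 3), returned when output_all_features=True

print(cls_feature.shape, patch_features.shape)  # torch.Size([2, 4]) torch.Size([2, 4, 3])
```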

---

### ✅ Key Points

1. **Class embedding**: same vector for all images, learnable, prepended as first token
2. **Patch embeddings**: each patch is embedded by the strided convolution, then the patch grid is flattened into a sequence
3. **Sequence length**: `num_patches + 1` (including class token)
4. **Positional embeddings**: added to sequence, learnable, shape `(num_patches+1, width)`
5. **Transformer**: outputs updated embeddings for each token; class token summarizes image content
6. **Final output**: class token after Transformer → projected to desired output dimension
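The sequence length in point 3 follows from `(input_resolution // patch_size) ** 2 + 1`, the same expression used to size `positional_embedding`. In the helper below, `seq_len` is a hypothetical name, and the 224-pixel cases are illustrative inputs only.

```python
def seq_len(input_resolution: int, patch_size: int) -> int:
    """Number of tokens the Transformer sees: one per patch plus the class token."""
    grid = input_resolution // patch_size
    return grid * grid + 1

print(seq_len(6, 3))     # 5  (the worked example)
print(seq_len(224, 16))  # 197
print(seq_len(224, 32))  # 50
```

The floor division matches `conv1`. With `stride == kernel_size` and no padding, the convolution produces `floor(input_resolution / patch_size)` patches per side, so leftover border pixels are dropped and the positional embedding still has the right length.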