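import torch
import einops
from torch import nn

# `LayerNorm` and `Transformer` are assumed to be defined in earlier cells;
# `Transformer` must return `(x, attn)` as used in `forward` below.

class VisionTransformer(nn.Module):
    def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int, in_channels: int = 3):
        """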
        Vision Transformer (ViT) for image classification.

        Splits the input image into patches, projects each patch into a `width`-dim embedding, prepends a 
        learnable class token, adds learnable positional embeddings, and passes the sequence through a 
        Transformer encoder. The output of the class token is used as the global image representation.

        Args:
            input_resolution (int): Input image size (assumed square), e.g., 224 for 224x224 images.
            patch_size (int): Size of each square patch, e.g., 16.
            width (int): Embedding dimension for patches and class token (d_model), e.g., 768.
            layers (int): Number of Transformer blocks to stack.
            heads (int): Number of attention heads per block.
            output_dim (int): Dimension of final output (e.g., number of classes).
            in_channels (int, optional): Number of image channels (default 3 for RGB).

        Notes:
            - class_embedding is a learnable vector that summarizes the image. Because it is a single parameter, the same vector is prepended to every image (random at initialization, then learned during training). The transformer output at the class-token position is what differs per image and serves as that image's summary.
              E.g. sequence for one image: [class_embedding, patch1, patch2, patch3, patch4]
            - positional_embedding is learnable and encodes spatial info.
            - conv1 splits image into patches and projects them.
            - scale = width**-0.5 ensures stable initialization.
        """
        super().__init__()
        self.input_resolution = input_resolution
        self.output_dim = output_dim
        self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)
        
        scale = width ** -0.5 #eg, for dim=768, scale = 1/sqrt(768)=0.036
        # torch.randn draws from N(0, 1); multiplying by scale shrinks each element's std to ≈ 0.036, which keeps dot products and softmax inputs small and stable at initialization.
        #class embedding is a learnable parameter that is added to the sequence of patch embeddings to represent the entire image
        self.class_embedding = nn.Parameter(scale*torch.randn(width)) # a random vector of size width
        # learnable positional embeddings for each patch + 1 for class embedding
        self.positional_embedding = nn.Parameter(scale*torch.randn((input_resolution // patch_size) ** 2 + 1, width)) # learnable, so it is wrapped in nn.Parameter at initialization and updated during training.
        self.ln_pre = LayerNorm(width)
        
        self.transformer = Transformer(width, layers, heads)
        
        self.ln_post = LayerNorm(width)
        self.proj = nn.Parameter(scale*torch.randn(width, output_dim))
        
    def forward(self, x: torch.Tensor, output_all_features: bool = False, output_attention_map: bool = False):
        """
        Forward pass of the Vision Transformer.

        Args:
            x (torch.Tensor): Input images of shape [batch_size, in_channels, height, width].
            output_all_features (bool, optional): If True, the function will also return all patch embeddings, not just the class token. Useful for things like visualizing patch features or doing segmentation tasks. Default: False.
            output_attention_map (bool, optional): If True, the function will also return attention maps from the class token to all patches. Useful for visualizing where the model "looks" in the image. Default: False.

        Returns:
            tuple: Contains at least the class token features (cls_feature) of shape [batch_size, output_dim].
                Optionally:
                - Patch embeddings of shape [batch_size, num_patches, width] (if output_all_features=True)
                - Attention maps of shape [n_layers, batch_size, n_heads, grid, grid] (if output_attention_map=True)

        Notes:
            - Images are first converted to patch embeddings via a Conv2d layer.
            - A learnable class token is prepended to the sequence to aggregate image-level information.
            - Learnable positional embeddings are added to each token (including the class token).
            - The sequence is normalized (LayerNorm) and passed through the Transformer blocks.
            - The class token embedding is extracted, normalized, and projected to output_dim for downstream tasks.
            - Patch embeddings and attention maps are optional outputs useful for visualization or analysis.

            Input image: [B, 3, 224, 224]
                    │
            Conv2d → Patch embeddings: [B, 768, 14, 14]
                    │
            Flatten → [B, 196, 768]
                    │
            Add CLS token → [B, 197, 768]
                    │
            Add positional embeddings → [B, 197, 768]
                    │
            LayerNorm → [B, 197, 768]
                    │
            Transformer → [B, 197, 768], attn maps [layers, B, heads, 197, 197]
                    │
            Extract CLS token → ln_post → proj → [B, output_dim]
                    │
            Optional outputs → patch embeddings, attention maps

        """
        #split image into non-overlapping patches and project to `width` dimensions
        x = self.conv1(x) #shape=[*, width, grid, grid], eg. [*, 768, 14, 14] for 224x224 input and 16x16 patches
        grid = x.size(2)
        #flatten the 2D grid into a sequence of patches
        x = x.reshape(x.shape[0], x.shape[1],-1) #shape=[*, width, grid**2] eg. [*, 768, 196]
        x = x.permute(0,2,1) #shape=[*, grid**2, width]
        #add class token to the beginning of the sequence
        # self.class_embedding has shape (width,), a 1D vector; adding zeros of shape [B, 1, width] broadcasts it to one copy per image (worked example in vision_transformer_explanation.ipynb)
        batch_class_token = self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device)
        x=torch.cat(
            [batch_class_token, x],
             dim=1) #shape=[*, grid**2+1, width]
        x=x + self.positional_embedding.to(x.dtype) #add positional embeddings
        
        # Pre-normalize every sequence element (including the class token) before the transformer.
        # Each transformer block also has its own ln_1, but that one normalizes the input of each block;
        # ln_pre normalizes the embeddings once, before the first block, so both are used.
        x=self.ln_pre(x)
        # The transformer expects input of shape (seq_len, batch, width)
        x = x.permute(1,0,2) #NLD-> LND , shape=[grid**2+1, *, width] ,(* is batch size)
        x,attn = self.transformer(x) #shape=[grid**2+1, *, width], attn shape=[layers, *, heads, grid**2+1, grid**2+1]
        x = x.permute(1,0,2) #LND-> NLD
        
        
        
        # Class-feature extraction, used for feature extraction or classification. It is not strictly needed, but cocap includes it for feature extraction.
        # ln_post normalizes the class token embedding
        # @ self.proj is a matrix multiplication that projects the embedding from width -> output_dim
        # Shape after projection: [batch_size, output_dim]
        # x[:, 0, :] → selects the class token embedding for all images in the batch.
        cls_feature = self.ln_post(x[:,0,:]) @ self.proj ## cls_feature.shape = [batch_size, output_dim]
        
        # 1️⃣1️⃣ Prepare outputs tuple
        # Start with just the class token feature as primary output
        outputs = (cls_feature,)
        
        # Optional: include patch embeddings
        if output_all_features:
            # x[:, 1:, :] excludes the class token and keeps only patch embeddings
            # Shape: [batch_size, num_patches, width]
            outputs += (x[:, 1:, :],)
            # Purpose:
            # - Useful for tasks where individual patch features are needed
            #   e.g., segmentation, attention visualization, or feature extraction

        # Optional: include attention maps
        if output_attention_map:
            # attn.shape = [n_layers, batch_size, n_heads, seq_len, seq_len]
            # attn[:, :, :, 0, 1:] selects attention from the class token to all patches
            # Shape: [n_layers, batch_size, n_heads, num_patches]
            # einops rearranges it to match the 2D grid layout of patches: [n_layers, batch_size, n_heads, h, w]
            outputs += (einops.rearrange(
                attn[:, :, :, 0, 1:],  # class token attends to patches
                "n_layers b n_heads (h w) -> n_layers b n_heads h w",
                h=grid, w=grid
            ),)
            # Purpose:
            # - Visualizes where the class token "looks" in the image
            # - Helpful for interpretability of attention

        # 1️⃣2️⃣ Return final outputs
        # Tuple contains:
        # 1. cls_feature: image-level representation for classification
        # 2. (optional) patch embeddings: individual patch features
        # 3. (optional) attention maps: visualization of attention from class token to patches
        return outputs

## EXPLAINING ABOVE CODE

This section answers a specific question: not **what the outputs are**, but **where they come from inside the network**, i.e., the **flow of tensors through the blocks** that produces `cls_feature`, `patch_features`, and `attention_map`. Let's go **step by step from input to output**, including which intermediate nodes produce them.

---

# Step 0️⃣ Input image

```python
# x: [batch, in_channels, H, W] → e.g., [8, 3, 224, 224]
```

* This is your raw image.
* Goes into `self.conv1` to turn into **patch embeddings**.
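All the token counts in the steps below follow from two constructor arguments. The sketch below reproduces `grid`, the number of patches, and the sequence length, which is also the number of rows in `positional_embedding`. The helper name `vit_token_counts` is hypothetical and is not part of the code above.

```python
def vit_token_counts(input_resolution: int, patch_size: int) -> dict:
    """Hypothetical helper: token bookkeeping for a square ViT input."""
    # Floor division matches conv1's output size when kernel_size == stride == patch_size
    grid = input_resolution // patch_size
    num_patches = grid * grid
    # +1 for the CLS token; equals positional_embedding.shape[0]
    return {"grid": grid, "num_patches": num_patches, "seq_len": num_patches + 1}


print(vit_token_counts(224, 16))  # {'grid': 14, 'num_patches': 196, 'seq_len': 197}
```

Integer division matches what `conv1` does when the resolution is not a multiple of the patch size. The leftover border pixels are not covered by any patch.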

---

# Step 1️⃣ Patch embedding (`x` after conv1)

```python
x = self.conv1(x)  # shape: [B, width, grid, grid] → e.g., [8, 768, 14, 14]
```

* Each patch becomes a `width`-dimensional embedding vector.
* Output: **patch-level features** (still 2D grid at this stage).
* Not yet a sequence for the transformer.
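Why does a `Conv2d` with `kernel_size=stride=patch_size` count as "splitting into patches"? Each output position sees exactly one non-overlapping patch, so the conv equals flattening every patch and applying one shared bias-free linear map. The sketch below checks this numerically. The toy sizes are hypothetical and chosen only for speed.

```python
import torch
from torch import nn

torch.manual_seed(0)
B, C, H, P, D = 2, 3, 32, 8, 12            # toy sizes: grid = 32 // 8 = 4, so 16 patches
conv = nn.Conv2d(C, D, kernel_size=P, stride=P, bias=False)
img = torch.randn(B, C, H, H)

# Conv route, then the same flatten/permute as Step 2
out_conv = conv(img).flatten(2).transpose(1, 2)          # [B, 16, D]

# Manual route: cut P x P patches, flatten each one, multiply by the conv weights
patches = img.unfold(2, P, P).unfold(3, P, P)            # [B, C, 4, 4, P, P]
patches = patches.permute(0, 2, 3, 1, 4, 5).reshape(B, -1, C * P * P)  # [B, 16, C*P*P]
weight = conv.weight.reshape(D, -1)                      # [D, C*P*P], same (c, p, q) order
out_manual = patches @ weight.T                          # [B, 16, D]

print(torch.allclose(out_conv, out_manual, atol=1e-5))   # True
```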

---

# Step 2️⃣ Flatten grid → sequence

```python
x = x.reshape(B, width, grid*grid)  # [8, 768, 196]
x = x.permute(0, 2, 1)             # [B, num_patches, width] → [8, 196, 768]
```

* Each patch is now **one token in a sequence**.
* This is **the sequence input to the transformer**, before the class token is added.
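Step 2's reshape + permute orders the tokens row-major: token `k` comes from grid cell `(k // grid, k % grid)`. This mapping matters again in Step 9, when attention weights are reshaped back into a grid. The sketch below uses toy sizes (hypothetical, picked only to make the mapping visible) and stores each cell's row and column index as its two "channels".

```python
import torch

grid, width = 3, 2
# Toy conv output [1, width, grid, grid]: channel 0 holds each cell's row index,
# channel 1 holds its column index, so we can read back where each token came from.
rows = torch.arange(grid).view(grid, 1).expand(grid, grid)
cols = torch.arange(grid).view(1, grid).expand(grid, grid)
x = torch.stack([rows, cols]).unsqueeze(0).float()

# Same two operations as Step 2: flatten the grid, then move channels last
seq = x.reshape(x.shape[0], x.shape[1], -1).permute(0, 2, 1)   # [1, grid*grid, width]

for k in range(grid * grid):
    r, c = seq[0, k].tolist()
    print(f"token {k}: row {int(r)}, col {int(c)}")
```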

---

# Step 3️⃣ Add CLS token

```python
batch_class_token = self.class_embedding.to(x.dtype) + torch.zeros(B, 1, width, dtype=x.dtype, device=x.device)
x = torch.cat([batch_class_token, x], dim=1)
```

* `self.class_embedding` → **learnable tensor** of shape `[width]` (1D)
* Adding `torch.zeros(B, 1, width)` broadcasts it to `[B, 1, width]`, one identical copy per image
* Concatenated **at the start of the sequence** → `[B, num_patches+1, width]` → `[8, 197, 768]`

✅ Now the **CLS token is part of the input sequence**, not separate.
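The sketch below checks the Step 3 broadcasting, using random tensors in place of the real parameter (only the shapes match the running example). It also shows that every image starts with the identical CLS vector. Only the transformer makes them differ.

```python
import torch

B, num_patches, width = 4, 196, 768
class_embedding = torch.randn(width)            # stands in for self.class_embedding, shape [width]
patches = torch.randn(B, num_patches, width)    # Step 2 output, [B, num_patches, width]

# [width] + [B, 1, width] broadcasts to [B, 1, width]: one copy of the CLS vector per image
batch_class_token = class_embedding + torch.zeros(B, 1, width)
x = torch.cat([batch_class_token, patches], dim=1)   # [B, num_patches + 1, width]

# expand() builds the same values as a view, without allocating the zeros
assert torch.equal(batch_class_token, class_embedding.expand(B, 1, width))
print(x.shape)  # torch.Size([4, 197, 768])
```

`torch.cat` allocates a new tensor either way, so choosing `expand` over the zeros trick saves only the small temporary.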

---

# Step 4️⃣ Add positional embeddings

```python
x = x + self.positional_embedding.to(x.dtype)
```

* Output shape: `[B, 197, 768]`; `self.positional_embedding` is `[197, 768]` and broadcasts over the batch
* Each token (CLS + patches) gets a **position vector**.
* Still part of the same sequence.
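Because `positional_embedding` has no batch dimension, row `i` is added to token `i` of every image. In the sketch below, random tensors stand in for the real ones and only the shapes follow the running example.

```python
import torch

B, seq_len, width = 2, 197, 768
x = torch.randn(B, seq_len, width)                   # CLS + patch tokens
positional_embedding = torch.randn(seq_len, width)   # stands in for self.positional_embedding

out = x + positional_embedding                       # [197, 768] broadcasts over the batch dim
print(out.shape)  # torch.Size([2, 197, 768])
```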

---

# Step 5️⃣ Pre-normalize

```python
x = self.ln_pre(x)
```

* Normalizes each token vector
* Output shape: `[B, 197, 768]`
* Still a **sequence of tokens**.
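"Normalizes each token vector" means each token is normalized over its own `width` features, independently of the other tokens and images. The sketch below uses `torch.nn.LayerNorm` and assumes the notebook's `LayerNorm` class behaves the same way. A fresh LayerNorm has weight 1 and bias 0, so its output should match the manual formula.

```python
import torch
from torch import nn

torch.manual_seed(0)
x = torch.randn(2, 197, 768) * 5 + 3     # tokens with non-zero mean and large spread
ln = nn.LayerNorm(768)                   # fresh: weight = 1, bias = 0, eps = 1e-5

y = ln(x)

# Manual version: normalize each token over its 768 features (biased variance)
mean = x.mean(dim=-1, keepdim=True)
var = ((x - mean) ** 2).mean(dim=-1, keepdim=True)
manual = (x - mean) / torch.sqrt(var + ln.eps)

print(torch.allclose(y, manual, atol=1e-5))   # True
```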

---

# Step 6️⃣ Transformer blocks

```python
x = x.permute(1, 0, 2)         # NLD → LND: [seq_len, B, width] → [197, 8, 768]
x, attn = self.transformer(x)
x = x.permute(1, 0, 2)         # back to [B, 197, 768]
```

### What happens here:

* `self.transformer` is a stack of **multi-head attention + feed-forward blocks**
* Input: `[CLS + patch tokens]` sequence
* Output:

  * `x` → **same shape `[B, 197, 768]`** (after permuting back), updated embeddings for each token

    * `x[:,0,:]` → **CLS token embedding**
    * `x[:,1:,:]` → **patch embeddings**
  * `attn` → attention matrices `[n_layers, B, n_heads, seq_len, seq_len]`

✅ **Key point:** The CLS token is **not injected separately**; it travels **through the transformer as part of the sequence**, learning to "aggregate" information from all patches.
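`permute` is not an in-place operation. It returns a new view, so a bare `x.permute(1, 0, 2)` without assignment does nothing. If that happened before a sequence-first transformer, the transformer would treat the batch dimension as the sequence. The sketch below shows the difference.

```python
import torch

x = torch.randn(8, 197, 768)   # [B, L, D] ("NLD")

x.permute(1, 0, 2)             # returns a new view; the result is discarded
print(x.shape)                 # torch.Size([8, 197, 768]), unchanged

x = x.permute(1, 0, 2)         # assign it to actually switch to [L, B, D] ("LND")
print(x.shape)                 # torch.Size([197, 8, 768])

# permute returns a view: no data is copied, only the strides change
print(x.is_contiguous())       # False
```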

---

# Step 7️⃣ Extract CLS token → `cls_feature`

```python
cls_feature = self.ln_post(x[:,0,:]) @ self.proj
```

* `x[:,0,:]` → **first token** (CLS) after the transformer → `[B, width]`
* LayerNorm (`ln_post`) → normalize
* Projection (`@ self.proj`, a bias-free matrix multiply) → `[B, output_dim]`

💡 So **`cls_feature` comes from the transformer's output node**, specifically:

```
x (output of transformer) → select [:,0,:] → ln_post → proj
```

* It is not injected from outside: the CLS token **propagates with the sequence** and is transformed at every layer.
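Multiplying by the `proj` parameter is the same as a bias-free `nn.Linear` whose weight is `proj.T`, since `nn.Linear` stores its weight as `[out, in]`. The sketch below checks this with random stand-ins. `output_dim = 512` is a hypothetical value, and `nn.LayerNorm` is assumed to match the notebook's `LayerNorm`.

```python
import torch
from torch import nn

torch.manual_seed(0)
B, L, width, output_dim = 4, 197, 768, 512        # output_dim is hypothetical
x = torch.randn(B, L, width)                      # transformer output, [B, L, width]
ln_post = nn.LayerNorm(width)
proj = torch.randn(width, output_dim) * width ** -0.5   # same scaled init as the model

cls_feature = ln_post(x[:, 0, :]) @ proj          # [B, output_dim]

# Equivalent nn.Linear without bias: weight is stored as [out, in], so copy proj.T
linear = nn.Linear(width, output_dim, bias=False)
with torch.no_grad():
    linear.weight.copy_(proj.T)

print(cls_feature.shape)  # torch.Size([4, 512])
```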


---

# Step 8️⃣ Optional: Patch embeddings → `output_all_features`

```python
if output_all_features:
    outputs += (x[:, 1:, :],)
```

* **Source:** transformer output sequence `x`

  * `x[:,1:,:]` → all tokens **except the CLS token**
* **Shape:** `[B, num_patches, width]` → e.g., `[8, 196, 768]`
* **Purpose / Why use:**

  * Gives **individual patch-level embeddings**
  * Useful for tasks that require **per-patch information**, e.g.:

    * Segmentation
    * Patch-level feature extraction
    * Attention visualization per patch
* **Why we don't always need it:**

  * For **image-level classification**, the **CLS token already aggregates all patch information**
  * Returning all patch embeddings increases the memory footprint unnecessarily

✅ So this output is optional: it is only included if `output_all_features=True`.
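For dense tasks such as segmentation, the patch tokens are commonly reshaped back into a 2D feature map. The sketch below assumes the row-major token order produced in Step 2, and a random tensor stands in for `x[:, 1:, :]`.

```python
import torch

B, width, grid = 2, 768, 14
patch_tokens = torch.randn(B, grid * grid, width)   # stands in for x[:, 1:, :]

# Invert Step 2 (permute, then reshape) to get a [B, width, grid, grid] feature map
feature_map = patch_tokens.permute(0, 2, 1).reshape(B, width, grid, grid)

# Round trip: applying Step 2's reshape + permute again recovers the tokens exactly
tokens_again = feature_map.reshape(B, width, -1).permute(0, 2, 1)
print(feature_map.shape)  # torch.Size([2, 768, 14, 14])
```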

---

# Step 9️⃣ Optional: Attention maps → `output_attention_map`

```python
if output_attention_map:
    outputs += (einops.rearrange(
        attn[:, :, :, 0, 1:],  # CLS token attends to all patches
        "n_layers b n_heads (h w) -> n_layers b n_heads h w",
        h=grid, w=grid
    ),)
```

* **Source:** `attn` tensor from transformer blocks

  * `attn` shape: `[n_layers, B, n_heads, seq_len, seq_len]` → e.g., `[12, 8, 12, 197, 197]`
  * Slice `[:, :, :, 0, 1:]` → all layers, all images, all heads, but only **row 0 (the CLS token as query)** and **columns 1 onward (the patch tokens as keys)**
  * Shape after slice: `[n_layers, B, n_heads, num_patches]` → `[12, 8, 12, 196]`
  * Reshaped to `[n_layers, B, n_heads, grid, grid]` → `[12, 8, 12, 14, 14]` for visualization
* **Purpose / Why use:**

  * Visualizes **where the CLS token is "looking"** in the image, per layer and per head
  * Helps **interpretability**, e.g., which patches the model focuses on when building the image-level feature
* **Why we don't return the full attention:**

  * The full tensor also contains **patch → patch attention**, which is not needed for a CLS-centric heatmap
  * The CLS → patches slice is **much smaller**: in the example above, 1,152 × 196 = 225,792 values instead of 1,152 × 197 × 197 ≈ 44.7 million (1,152 = 12 layers × 8 images × 12 heads)
  * It is **directly interpretable** as a heatmap over the image patches

✅ Optional output, included only if `output_attention_map=True`.
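The sketch below runs the slice and the grid reshape on random row-normalized matrices that stand in for `attn`. Plain `reshape` is used instead of einops. It performs the same row-major split as the `"(h w) -> h w"` pattern. The toy sizes are hypothetical. Averaging the heads of the last layer is one common visualization choice, not something the code above does.

```python
import torch

n_layers, B, n_heads, grid = 2, 3, 4, 14         # toy sizes (hypothetical)
seq_len = grid * grid + 1
# Random row-stochastic matrices standing in for the transformer's attention weights
attn = torch.softmax(torch.randn(n_layers, B, n_heads, seq_len, seq_len), dim=-1)

cls_to_patches = attn[:, :, :, 0, 1:]             # query = CLS (row 0), keys = patches
maps = cls_to_patches.reshape(n_layers, B, n_heads, grid, grid)  # same split as "(h w) -> h w"

# One common visualization choice (not the only one): last layer, averaged over heads
heatmap = maps[-1].mean(dim=1)                    # [B, grid, grid]
print(maps.shape, heatmap.shape)
```

Each attention row sums to 1 over all 197 keys. The 196 patch weights therefore sum to 1 minus the CLS token's attention to itself, which the slice drops.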

---

# Step 10️⃣ Return final outputs

```python
return outputs
```

* **Tuple contains:**

  1. `cls_feature` → `[B, output_dim]` → main image representation
  2. `(optional) patch_features` → `[B, num_patches, width]` → only if `output_all_features=True`
  3. `(optional) attention_map` → `[n_layers, B, n_heads, grid, grid]` → only if `output_attention_map=True`

* **Key insight:**

  * The CLS token is **central**: it aggregates information from all patches
  * Patch embeddings and attention maps come from **the same transformer call**: patch embeddings are a slice of its output sequence, attention maps a slice of its attention weights
  * They are optional because **for most classification tasks, the CLS token alone is sufficient**
  * Returning the others is useful for **visualization, research, or patch-level tasks**
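Optional entries are appended only when requested, so tuple positions shift. For example, `attention_map` is at index 1 when `output_all_features=False`. The sketch below names the entries by flag. The helper `unpack_vit_outputs` is hypothetical and is not part of the model.

```python
def unpack_vit_outputs(outputs, output_all_features=False, output_attention_map=False):
    """Hypothetical helper: name the entries of VisionTransformer.forward's tuple.

    The flags must match the ones passed to forward, because optional items
    are only appended when requested (patch features first, then attention).
    """
    it = iter(outputs)
    result = {"cls_feature": next(it)}
    if output_all_features:
        result["patch_features"] = next(it)
    if output_attention_map:
        result["attention_map"] = next(it)
    return result


# Usage with a model instance (hypothetical names):
# outs = model(images, output_attention_map=True)
# named = unpack_vit_outputs(outs, output_attention_map=True)
print(unpack_vit_outputs(("cls", "attn"), output_attention_map=True))
# {'cls_feature': 'cls', 'attention_map': 'attn'}
```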

---


# ✅ Summary of "where each output comes from"

| Output         | Source inside forward                               | Shape                                | Notes                                                         |
| -------------- | --------------------------------------------------- | ------------------------------------ | ------------------------------------------------------------- |
| cls_feature    | `x[:,0,:]` from transformer output → ln_post → proj | `[B, output_dim]`                    | Image-level representation, learned via CLS token aggregation |
| patch_features | `x[:,1:,:]` from transformer output                 | `[B, num_patches, width]`            | Each patch embedding after transformer                        |
| attention_map  | `attn[:, :, :, 0, 1:]` from transformer attention   | `[n_layers, B, n_heads, grid, grid]` | CLS token attention to patches, reshaped                      |

✅ **Key insight:** The CLS token is **part of the input sequence**, travels through the transformer, and learns to aggregate the patches → `cls_feature`. Patch embeddings and attention maps come from the **same transformer call**: the former are a slice of its output sequence, the latter a slice (reshaped to the grid) of its attention weights.

