# Vision Transformer (ViT) Architecture Demonstration

This Jupyter Notebook checks that the Vision Transformer (ViT) implementation produces tensors of the expected shape at every stage of the forward pass. We will:

1.  **Define the ViT model components** (Patch Embedding, Positional Encoding, Encoder Block, Classification Head, and the full ViT model).
2.  **Instantiate the ViT model** with example parameters.
3.  **Create a dummy input image tensor**.
4.  **Perform a forward pass** through the model, step-by-step, and print the **output shape at each stage** to verify that it aligns with the expected ViT architecture.

This notebook focuses on **architectural verification** and does not include training or performance evaluation.

In [1]:
import torch
from torch import nn
from model import *

In [2]:
class Patch_Embedding_Projection(nn.Module):
    def __init__(self, patch_size, in_channels, embedding_dim):
        super().__init__()
        self.patch_size = patch_size
        self.in_channels = in_channels
        self.embedding_dim = embedding_dim
        self.patch_former = nn.Unfold(kernel_size=patch_size, stride=patch_size, padding=0)
        self.lin_proj = nn.Linear(in_features=patch_size ** 2 * self.in_channels, out_features=embedding_dim)

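    def forward(self, x):
        # x: (batch, in_channels, height, width)
        patches = self.patch_former(x)  # (batch, in_channels * patch_size**2, num_patches)
        x = patches.transpose(1, 2)     # (batch, num_patches, in_channels * patch_size**2)
        return self.lin_proj(x)         # (batch, num_patches, embedding_dim)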

In [3]:
class POS_ENC(nn.Module):
    def __init__(self, seq_len, enc_dim):
        super(POS_ENC, self).__init__()
        self.enc_dim = enc_dim
        self.pos_embeddings = nn.Parameter(torch.randn(1, seq_len, enc_dim))

    def forward(self, x):
        return x + self.pos_embeddings

In [4]:
class Encoder(nn.Module):
    def __init__(self, enc_dim, num_heads):
        super(Encoder, self).__init__()
        assert enc_dim % num_heads == 0, "enc_dim must be divisible by num_heads"

        self.enc_dim = enc_dim
        self.num_heads = num_heads
        self.head_dim = enc_dim // num_heads

        self.norm1 = nn.LayerNorm(enc_dim)

        self.wq = nn.Linear(in_features=enc_dim, out_features=enc_dim)
        self.wk = nn.Linear(in_features=enc_dim, out_features=enc_dim)
        self.wv = nn.Linear(in_features=enc_dim, out_features=enc_dim)

        self.softmax = nn.Softmax(dim=-1)
        self.norm2 = nn.LayerNorm(enc_dim)
        self.mlp = nn.Sequential(
            nn.Linear(in_features=enc_dim, out_features=enc_dim*4),
            nn.GELU(),
            nn.Linear(in_features=enc_dim*4, out_features=enc_dim))
        self.proj_o = nn.Linear(in_features=enc_dim, out_features=enc_dim)

    def attention(self, x):
        batch_size, seq_len, enc_dim = x.shape

        q = self.wq(x)
        k = self.wk(x)
        v = self.wv(x)

        q = q.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        k = k.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        v = v.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

        k_q = torch.matmul(q, k.transpose(-2, -1))
        s_kq = self.softmax(k_q / (self.head_dim**0.5))
        attn_output_heads = torch.matmul(s_kq, v)

        attn_output_concat = attn_output_heads.transpose(1, 2).contiguous().view(batch_size, seq_len, enc_dim)
        attn_output_proj = self.proj_o(attn_output_concat)

        return attn_output_proj

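    def forward(self, x):
        # Pre-norm self-attention sub-block with a residual connection
        x_residual = x
        x = self.norm1(x)
        attn_output = self.attention(x)
        x = attn_output + x_residual

        # Pre-norm MLP sub-block with a residual connection
        x_residual_mlp = x
        x = self.norm2(x)
        mlp_output = self.mlp(x)
        x = mlp_output + x_residual_mlp
        return x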

In [5]:
class ViTClassifierHead(nn.Module):
    def __init__(self, enc_dim, num_classes):
        super().__init__()
        self.fc = nn.Linear(in_features=enc_dim, out_features=num_classes)

    def forward(self, x):
        class_token = x[:, 0, :]
        logits = self.fc(class_token)
        return logits

In [6]:
class ViTForImageClassification(nn.Module):
    def __init__(self, image_size, patch_size, in_channels, num_classes, embedding_dim, depth, num_heads):
        super().__init__()
        self.image_size = image_size
        self.patch_size = patch_size
        self.in_channels = in_channels
        self.num_classes = num_classes
        self.embedding_dim = embedding_dim
        self.depth = depth
        self.num_heads = num_heads

        assert image_size % patch_size == 0, "Image dimensions must be divisible by the patch size."
        self.num_patches = (image_size // patch_size) ** 2
        self.seq_len = self.num_patches + 1

        self.patch_embed_proj = Patch_Embedding_Projection(patch_size=patch_size, in_channels=in_channels, embedding_dim=embedding_dim)
        self.pos_embedding = POS_ENC(seq_len=self.seq_len,enc_dim=self.embedding_dim)
        self.class_token = nn.Parameter(torch.randn(1, 1, embedding_dim))

        self.encoder_layers = nn.Sequential(*[Encoder(enc_dim=embedding_dim, num_heads=num_heads) for _ in range(depth)])
        self.norm_head = nn.LayerNorm(embedding_dim)
        self.classifier_head = ViTClassifierHead(enc_dim=embedding_dim, num_classes=num_classes)

    def forward(self, x):
        batch_size = x.shape[0]
        patch_embeddings = self.patch_embed_proj(x)

        class_token = self.class_token.expand(batch_size, -1, -1)
        embeddings = torch.cat((class_token, patch_embeddings), dim=1)
        embeddings_with_pos = self.pos_embedding(embeddings)

        encoded_sequence = self.encoder_layers(embeddings_with_pos)
        pooled = self.norm_head(encoded_sequence)
        logits = self.classifier_head(pooled)
        return logits

## Instantiate the ViT Model

Here, we instantiate the `ViTForImageClassification` model with some example parameters. These parameters are based on the ViT-Base configuration but can be adjusted.

**Parameters:**
*   `image_size`: Input image size (e.g., 224x224).
*   `patch_size`: Size of image patches (e.g., 16x16).
*   `in_channels`: Number of input image channels (e.g., 3 for RGB).
*   `num_classes`: Number of classes for classification (e.g., 1000 for ImageNet).
*   `embedding_dim`: Embedding dimension for patches and tokens (e.g., 768).
*   `depth`: Number of Transformer Encoder layers (e.g., 12).
*   `num_heads`: Number of attention heads in Multi-Head Self-Attention (e.g., 12).

In [7]:
# Example parameters (ViT-Base like)
image_size = 224
patch_size = 16
in_channels = 3
num_classes = 1000
embedding_dim = 768
depth = 12
num_heads = 12

# Instantiate the ViT model
model = ViTForImageClassification(
    image_size=image_size,
    patch_size=patch_size,
    in_channels=in_channels,
    num_classes=num_classes,
    embedding_dim=embedding_dim,
    depth=depth,
    num_heads=num_heads
)

print("ViT Model Instantiated!")

ViT Model Instantiated!
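
As a cross-check on the configuration above, the number of learnable parameters can be derived by hand from the layer definitions in this notebook. The sketch below is plain Python arithmetic; the helper names (`linear_params`, `layernorm_params`, `vit_param_count`) are hypothetical and only mirror the layers defined earlier. `num_heads` does not appear because splitting `embedding_dim` across heads does not change the number of weights.

```python
def linear_params(n_in, n_out):
    # Weight matrix plus bias vector of an nn.Linear layer.
    return n_in * n_out + n_out


def layernorm_params(dim):
    # Learnable scale and shift of an nn.LayerNorm layer.
    return 2 * dim


def vit_param_count(image_size, patch_size, in_channels, num_classes,
                    embedding_dim, depth):
    seq_len = (image_size // patch_size) ** 2 + 1  # patches + class token
    patch_dim = patch_size ** 2 * in_channels      # flattened patch length

    embed = linear_params(patch_dim, embedding_dim)  # lin_proj
    pos = seq_len * embedding_dim                    # pos_embeddings
    cls = embedding_dim                              # class_token

    attn = 4 * linear_params(embedding_dim, embedding_dim)  # wq, wk, wv, proj_o
    mlp = (linear_params(embedding_dim, 4 * embedding_dim)
           + linear_params(4 * embedding_dim, embedding_dim))
    block = 2 * layernorm_params(embedding_dim) + attn + mlp

    # norm_head followed by the classifier's fc layer
    head = layernorm_params(embedding_dim) + linear_params(embedding_dim, num_classes)
    return embed + pos + cls + depth * block + head


total = vit_param_count(224, 16, 3, 1000, 768, 12)
print(f"{total:,}")
```

The result can be compared with `sum(p.numel() for p in model.parameters())` on the instantiated model; the two numbers should agree if the arithmetic mirrors the layers correctly.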


## Create a Dummy Input Image

We create a dummy input image tensor to simulate a batch of images passing through the network. The shape should be `(batch_size, in_channels, image_size, image_size)`.

In [8]:
batch_size = 2 # Example batch size
dummy_input = torch.randn(batch_size, in_channels, image_size, image_size)

print("Dummy Input Image Shape:", dummy_input.shape)

Dummy Input Image Shape: torch.Size([2, 3, 224, 224])


## Patch Embedding and Projection

We pass the dummy input image through the `Patch_Embedding_Projection` layer. This layer should:
*   Divide the image into patches of size `patch_size` x `patch_size`.
*   Flatten each patch.
*   Project the flattened patches to `embedding_dim`.

The expected output shape is `(batch_size, num_patches, embedding_dim)`, where `num_patches = (image_size // patch_size) ** 2`. With the parameters above, that is `(224 // 16) ** 2 = 196` patches.

In [9]:
patch_embeddings = model.patch_embed_proj(dummy_input)

print("Patch Embeddings Shape:", patch_embeddings.shape)
expected_num_patches = (image_size // patch_size) ** 2
assert patch_embeddings.shape == (batch_size, expected_num_patches, embedding_dim), \
       f"Patch Embeddings shape is incorrect. Expected {(batch_size, expected_num_patches, embedding_dim)}, but got {patch_embeddings.shape}"

Patch Embeddings Shape: torch.Size([2, 196, 768])
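
To make the patch extraction step concrete, here is a minimal sketch of what `nn.Unfold` followed by `transpose(1, 2)` produces on a toy input. The 4x4 single-channel image and the 2x2 patch size are illustrative assumptions, not values used elsewhere in this notebook.

```python
import torch
from torch import nn

# Hypothetical toy input: one 4x4 single-channel image holding the values 0..15.
image = torch.arange(16, dtype=torch.float32).view(1, 1, 4, 4)

# Non-overlapping 2x2 patches, the same pattern Patch_Embedding_Projection uses.
unfold = nn.Unfold(kernel_size=2, stride=2)
columns = unfold(image)            # (batch, channels * 2 * 2, num_patches)
patches = columns.transpose(1, 2)  # (batch, num_patches, channels * 2 * 2)

print(patches.shape)
print(patches[0])  # one flattened patch per row, patches ordered left to right, top to bottom
```

Each row holds the pixels of one patch in row-major order, which is the layout the linear projection then maps to `embedding_dim`.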


## Positional Encoding

The full model first prepends the class token to the patch embeddings and then adds learnable positional embeddings to every position in the resulting sequence using the `POS_ENC` layer. We reproduce both steps here.

Prepending the class token increases the sequence length from `num_patches` to `num_patches + 1`, and adding the positional embeddings leaves the shape unchanged, so the expected shape is `(batch_size, num_patches + 1, embedding_dim)`.

In [10]:
# Create class token and prepend, then apply positional encoding (simulating inside ViTForImageClassification forward)
class_token = model.class_token.expand(batch_size, -1, -1)
embeddings_with_class_token = torch.cat((class_token, patch_embeddings), dim=1)
embeddings_with_pos = model.pos_embedding(embeddings_with_class_token)

print("Embeddings with Positional Encoding Shape:", embeddings_with_pos.shape)
expected_seq_len = expected_num_patches + 1 # +1 for class token
assert embeddings_with_pos.shape == (batch_size, expected_seq_len, embedding_dim), \
       f"Embeddings with Positional Encoding shape is incorrect. Expected {(batch_size, expected_seq_len, embedding_dim)}, but got {embeddings_with_pos.shape}"

Embeddings with Positional Encoding Shape: torch.Size([2, 197, 768])
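
The two tensor operations in this step, expanding a single class token across the batch and broadcasting one set of positional embeddings over every sample, can be sketched in isolation. The sizes below are toy values chosen for readability, and the plain tensors stand in for the learnable `nn.Parameter` objects of the model.

```python
import torch

batch_size, num_patches, dim = 2, 4, 8  # toy sizes, not the notebook's

patch_embeddings = torch.zeros(batch_size, num_patches, dim)
class_token = torch.ones(1, 1, dim)                    # stands in for model.class_token
pos_embeddings = torch.randn(1, num_patches + 1, dim)  # stands in for pos_embeddings

# expand() repeats the single token along the batch axis without copying memory.
tokens = torch.cat((class_token.expand(batch_size, -1, -1), patch_embeddings), dim=1)

# The leading dimension of size 1 broadcasts over the batch axis.
with_pos = tokens + pos_embeddings

print(tokens.shape, with_pos.shape)
```

Because the same positional embeddings are added to every sample, two identical inputs in a batch remain identical after this step.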


## Transformer Encoder Layers

We pass the embeddings with positional encoding through the stacked Transformer Encoder layers (`model.encoder_layers`). Each Encoder layer consists of a Multi-Head Self-Attention block and an MLP block, each preceded by LayerNorm and wrapped in a residual connection.

The shape should remain the same throughout the Encoder layers: `(batch_size, num_patches + 1, embedding_dim)`.

In [11]:
encoded_sequence = model.encoder_layers(embeddings_with_pos)

print("Output Shape after Transformer Encoder Layers:", encoded_sequence.shape)
assert encoded_sequence.shape == (batch_size, expected_seq_len, embedding_dim), \
       f"Encoder Output shape is incorrect. Expected {(batch_size, expected_seq_len, embedding_dim)}, but got {encoded_sequence.shape}"

Output Shape after Transformer Encoder Layers: torch.Size([2, 197, 768])
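
The shape is preserved because the attention step splits `enc_dim` into heads and merges them back afterwards. The following is a minimal sketch of that split, the scaled dot-product attention, and the merge, using toy sizes; random tensors stand in for the outputs of the `wq`, `wk`, and `wv` projections, and the helper `split_heads` is a hypothetical name.

```python
import math
import torch

batch_size, seq_len, num_heads, head_dim = 2, 5, 2, 4  # toy sizes
enc_dim = num_heads * head_dim


def split_heads(t):
    # (batch, seq, enc_dim) -> (batch, heads, seq, head_dim)
    return t.view(batch_size, seq_len, num_heads, head_dim).transpose(1, 2)


# Random tensors stand in for the projected queries, keys, and values.
q, k, v = (split_heads(torch.randn(batch_size, seq_len, enc_dim)) for _ in range(3))

scores = q @ k.transpose(-2, -1) / math.sqrt(head_dim)  # (batch, heads, seq, seq)
weights = torch.softmax(scores, dim=-1)                 # each row sums to 1
heads = weights @ v                                     # (batch, heads, seq, head_dim)

# Merge the heads back into one vector per token.
merged = heads.transpose(1, 2).contiguous().view(batch_size, seq_len, enc_dim)

print(weights.shape, merged.shape)
```

The `contiguous()` call is needed because `transpose` returns a non-contiguous view, and `view` requires a compatible memory layout.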


## Classification Head

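Finally, we apply the model's final LayerNorm (`model.norm_head`) to the Transformer Encoder output and pass the result through the `ViTClassifierHead`, mirroring the full model's forward pass. The head extracts the `[class]` token representation and projects it to class logits.

The expected output shape is `(batch_size, num_classes)`, representing the logits for each class.

In [12]:
# Apply the final LayerNorm first, as in ViTForImageClassification.forward
pooled = model.norm_head(encoded_sequence)
output_logits = model.classifier_head(pooled)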

print("Output Logits Shape from Classification Head:", output_logits.shape)
assert output_logits.shape == (batch_size, num_classes), \
       f"Classification Head output shape is incorrect. Expected {(batch_size, num_classes)}, but got {output_logits.shape}"

Output Logits Shape from Classification Head: torch.Size([2, 1000])
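
In the model's `forward`, `norm_head` is applied to the whole encoded sequence before the head selects token 0. Since `nn.LayerNorm` normalizes each token independently over the last dimension, this gives the same logits as normalizing only the class token. The sketch below illustrates that equivalence with toy sizes; the freshly created `norm` and `fc` layers are stand-ins, not the notebook's trained or instantiated modules.

```python
import torch
from torch import nn

torch.manual_seed(0)
batch_size, seq_len, dim, num_classes = 2, 5, 8, 3  # toy sizes

encoded = torch.randn(batch_size, seq_len, dim)  # stands in for the encoder output
norm = nn.LayerNorm(dim)                         # stands in for model.norm_head
fc = nn.Linear(dim, num_classes)                 # stands in for the head's fc layer

# Normalize the whole sequence, then take the class token (as the model does).
logits_a = fc(norm(encoded)[:, 0, :])

# Take the class token first, then normalize only that token.
logits_b = fc(norm(encoded[:, 0, :]))

print(logits_a.shape)
print(torch.allclose(logits_a, logits_b, atol=1e-6))
```

Normalizing only the class token would avoid work on the other tokens, although for a single LayerNorm the saving is small.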


## Full Forward Pass Verification

To further confirm, let's perform a full forward pass through the entire `ViTForImageClassification` model in one step and check the final output logits shape.

In [13]:
full_output_logits = model(dummy_input)

print("Output Logits Shape from Full Forward Pass:", full_output_logits.shape)
assert full_output_logits.shape == (batch_size, num_classes), \
       f"Full Forward Pass output shape is incorrect. Expected {(batch_size, num_classes)}, but got {full_output_logits.shape}"

Output Logits Shape from Full Forward Pass: torch.Size([2, 1000])


## Conclusion

The shape checks pass at every stage, so the implemented Vision Transformer (ViT) produces tensors of the expected shape throughout the forward pass.

**Verification Points:**
*   **Patch Embedding:** The `Patch_Embedding_Projection` layer divides the image into patches and projects them to a tensor of shape `(batch_size, num_patches, embedding_dim)`.
*   **Positional Encoding:** Positional embeddings are added to the full token sequence (class token plus patch embeddings) without changing its shape.
*   **Transformer Encoder:** The Transformer Encoder layers maintain the sequence length and embedding dimension.
*   **Classification Head:** The `ViTClassifierHead` correctly extracts the `[class]` token representation and projects it to the final logits of shape `(batch_size, num_classes)`.

**Next Steps:**
*   Train the ViT model on an image classification dataset (e.g., CIFAR-10, ImageNet).
*   Evaluate the model's performance on validation and test sets.
*   Experiment with different ViT configurations (depth, embedding dimension, number of heads, patch size).

This notebook provides a solid foundation for further experimentation and development with Vision Transformers.