# ViT (Vision Transformer) Implementation

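Example shape of the reshaped QKV tensor: (B, N, 3, h, d)  ==  (2, 8, 3, 4, 16)
This means:

B = 2: batch size (2 samples in the batch)

N = 8: tokens per sample

3: the stacked projections (q, k, v)

h = 4: attention heads

d = 16: dimensions per head (so the embedding dimension is h * d = 64)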


In [1]:
import torch
import torch.nn as nn

class PatchEmbedding(nn.Module):
    """Turn an image into a sequence of patch tokens.

    A convolution whose kernel size and stride both equal ``patch_size``
    visits each non-overlapping ``patch_size x patch_size`` patch exactly
    once and maps it to a vector of length ``embed_dim`` (one value per
    output filter). With the defaults the kernel has shape (64, 3, 4, 4),
    so every 3x4x4 patch becomes a 64-dimensional token.

    Args:
        img_size: Side length of the (square) input image.
        patch_size: Side length of each square patch.
        in_channels: Number of channels in the input image.
        embed_dim: Length of the vector produced for each patch.
    """

    def __init__(self, img_size=32, patch_size=4, in_channels=3, embed_dim=64):
        super().__init__()
        patches_per_side = img_size // patch_size
        self.num_patches = patches_per_side * patches_per_side
        self.patch_size = patch_size

        # kernel_size == stride, so neighbouring patches never overlap.
        self.proj = nn.Conv2d(
            in_channels,
            embed_dim,
            kernel_size=patch_size,
            stride=patch_size,
        )

    def forward(self, x):
        """Embed an image batch.

        Args:
            x: Image batch, expected as (B, in_channels, H, W),
                e.g. (B, 3, 32, 32) with the defaults.

        Returns:
            Tensor of shape (B, num_patches, embed_dim).
        """
        feature_map = self.proj(x)  # (B, embed_dim, H/patch, W/patch)
        # Collapse the spatial grid into one axis: (B, embed_dim, num_patches).
        flat = feature_map.flatten(start_dim=2)
        # Put the patch axis before the channel axis so that each row is one
        # token holding the responses of all filters for that patch.
        return flat.permute(0, 2, 1)  # (B, num_patches, embed_dim)

class MultiHeadSelfAttention(nn.Module):
    """Multi-head scaled dot-product self-attention over a token sequence.

    Args:
        dim: Embedding size of each token; must be divisible by ``num_heads``.
        num_heads: Number of attention heads (positive).

    Raises:
        ValueError: If ``num_heads`` is not a positive divisor of ``dim``.
    """

    def __init__(self, dim, num_heads=4):
        super().__init__()
        # Validate before deriving head_dim, and with a real exception rather
        # than ``assert`` so the check is not stripped under ``python -O``.
        if num_heads <= 0 or dim % num_heads != 0:
            raise ValueError(
                f"num_heads ({num_heads}) must be a positive divisor of dim ({dim})"
            )
        self.num_heads = num_heads
        self.head_dim = dim // num_heads

        self.qkv = nn.Linear(dim, dim * 3)  # One layer computes Q, K and V together
        self.out = nn.Linear(dim, dim)

    def forward(self, x):
        """Apply self-attention.

        Args:
            x: Token tensor of shape (B, N, D) where D equals ``dim``.

        Returns:
            Tensor of shape (B, N, D).
        """
        B, N, D = x.shape
        qkv = self.qkv(x)  # (B, N, 3D)
        # Split the last axis into (3, heads, head_dim), then move the q/k/v
        # axis to the front and the head axis ahead of the token axis:
        # (B, N, 3, H, d) -> (3, B, H, N, d). permute only reorders the axes
        # (it returns a view); no values are changed.
        qkv = qkv.reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)  # each (B, H, N, d)

        # Scale by sqrt(head_dim) so the logits do not grow with head size.
        scores = (q @ k.transpose(-2, -1)) / self.head_dim ** 0.5  # (B, H, N, N)
        attn = scores.softmax(dim=-1)  # each query's weights over keys sum to 1
        # (B, H, N, d) -> (B, N, H, d) -> (B, N, D): concatenate the heads.
        out = (attn @ v).transpose(1, 2).reshape(B, N, D)
        return self.out(out)

class TransformerEncoderBlock(nn.Module):
    """One Transformer encoder layer: self-attention followed by an MLP.

    Each sub-layer is fed a LayerNorm-ed copy of its input and its output
    is added back to that input (residual connection).

    Args:
        dim: Embedding size of each token.
        num_heads: Number of attention heads passed to the attention layer.
        mlp_ratio: Hidden width of the MLP as a multiple of ``dim``
            (truncated to an int).
    """

    def __init__(self, dim, num_heads, mlp_ratio=4.0):
        super().__init__()
        hidden_dim = int(dim * mlp_ratio)

        self.norm1 = nn.LayerNorm(dim)
        self.attn = MultiHeadSelfAttention(dim, num_heads)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, hidden_dim),  # expand
            nn.GELU(),
            nn.Linear(hidden_dim, dim),  # project back to the token width
        )

    def forward(self, x):
        """Run the attention and MLP sub-layers; output has the shape of ``x``."""
        after_attn = x + self.attn(self.norm1(x))
        return after_attn + self.mlp(self.norm2(after_attn))

class ViT(nn.Module):
    """Vision Transformer classifier.

    The image is split into patch tokens, a learnable class token is
    prepended, learnable positional embeddings are added, the sequence is
    passed through ``depth`` encoder blocks, and the final class-token
    representation is mapped to class logits.

    Args:
        img_size: Side length of the (square) input image.
        patch_size: Side length of each square patch.
        in_channels: Number of channels in the input image.
        num_classes: Number of output logits.
        embed_dim: Token embedding size used throughout the model.
        depth: Number of encoder blocks.
        num_heads: Attention heads per encoder block.
    """

    def __init__(self, img_size=32, patch_size=4, in_channels=3, num_classes=10,
                 embed_dim=64, depth=6, num_heads=4):
        super().__init__()
        # Produces (B, N, embed_dim) patch tokens.
        self.patch_embed = PatchEmbedding(img_size, patch_size, in_channels, embed_dim)
        seq_len = 1 + self.patch_embed.num_patches  # class token + patches

        # A single learnable token, shape (1, 1, embed_dim); prepending it to
        # the patch tokens gives a sequence of shape (B, N + 1, embed_dim).
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))

        # One learnable position vector per sequence slot.
        self.pos_embed = nn.Parameter(torch.zeros(1, seq_len, embed_dim))

        encoder_layers = (
            TransformerEncoderBlock(embed_dim, num_heads) for _ in range(depth)
        )
        self.blocks = nn.Sequential(*encoder_layers)
        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes)

    def forward(self, x):
        """Classify an image batch.

        Args:
            x: Image batch; the first axis is the batch axis.

        Returns:
            Logits of shape (B, num_classes).
        """
        batch_size = x.size(0)
        patch_tokens = self.patch_embed(x)  # (B, N, D)

        # Broadcast the single class token across the batch (no copy).
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        tokens = torch.cat((cls_tokens, patch_tokens), dim=1)  # (B, N + 1, D)

        # Use only as many position vectors as there are tokens.
        num_tokens = tokens.size(1)
        tokens = tokens + self.pos_embed[:, :num_tokens, :]

        encoded = self.norm(self.blocks(tokens))
        class_repr = encoded[:, 0]  # class token summarises the image
        return self.head(class_repr)
