# Vision Transformer Implementation

In [1]:
# imports
import torch
from torch import nn

from tqdm.auto import tqdm 
from typing import Union, List, Tuple, Optional, Dict

from einops import rearrange
from einops.layers.torch import Rearrange

In [2]:
class MLP(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim, device):
        super().__init__()

        self.layer = nn.Sequential(
            nn.LayerNorm(input_dim, device = device),
            nn.Linear(input_dim, hidden_dim, device = device),
            nn.GELU(),
            nn.Linear(hidden_dim, output_dim, device= device),
        )
    
    def forward(self, x):
        x = self.layer(x)
        return x

In [7]:
def test_MLP():
    mlp = MLP(128, 256, 128, 'cpu')
    x = torch.randn(1, 128)
    out = mlp(x)
    assert out.shape == (1, 128), f"Expected output shape (1, 128), but got {out.shape}"  
    print(out.shape)  
test_MLP()

torch.Size([1, 128])


In [8]:
# We will be doing this for square images and the patches will be square as well
def reshape_for_vit(self, sample_to_reshape, patch_size):
    b, c, h, w = sample_to_reshape.shape
    assert h % patch_size == 0 and w % patch_size == 0, "Height and Width must be divisible by patch size"

    # reshape c,h,w into c num_patches, patch_size*patch_size *c
    num_patches = (h // patch_size) * (w // patch_size)
    return sample_to_reshape.reshape(b, num_patches, patch_size*patch_size*c)

In [9]:
class MSA(nn.Module):
    def __init__(self, dim, num_heads, device):
        super().__init__()
        self.device = device
        self.mha = nn.MultiheadAttention(dim, num_heads, device = device)
    def forward(self, x):
        x = x.to(self.device)
        x, _ = self.mha(x, x, x)
        return x

In [10]:
def test_MSA():
    msa = MSA(128, 8, 'cpu')
    x = torch.randn(1, 16, 128)
    out = msa(x)
    assert out.shape == (1, 16, 128), f"Expected output shape (1, 16, 128), but got {out.shape}"  
    print(out.shape)
test_MSA()

torch.Size([1, 16, 128])


In [None]:
class Transformer(nn.Module):
    def __init__(self, dim, hidden_dim, num_heads, layers, device):
        super().__init__()

        # need to figure out how to compute D and then get the z array from x
        self.layers = nn.ModuleList([
            nn.ModuleList([
                MSA(dim, num_heads, device),
                MLP(dim, hidden_dim, dim, device)
            ])
            for _ in range(layers)
        ])

        self.layer_norm = nn.LayerNorm(dim)
        
    def forward(self, x):
        for attn, ffn in self.layers:
            x = attn(x) + x
            x = ffn(x) + x
        return self.layer_norm(x)
    

In [31]:
def test_transformer():
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    x = torch.randn(1, 64, 512).to(device) # as we will be passing a 3 size image
    transformer = Transformer(512, 128, 8, 6, device)
    out = transformer(x)
    print(out.shape)
test_transformer()
    

torch.Size([1, 64, 512])


In [None]:
class ViT(nn.Module):
    def __init__(self, image_size: Tuple, patch_size, dim, hid_dim, num_classes, num_heads, num_layers, channels =3, device= "cuda"):
        super().__init__()
        self.device = device
        assert len(image_size) == 2, "Image size must be a tuple of 2 elements"
        # image size can be be H * W
        H, W = image_size
        assert H % patch_size == 0 and W % patch_size == 0, "Height and Width must be divisible by patch size"
        # number of patches
        num_patches = (H // patch_size) * (W // patch_size)
        # patch embedding
        self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_size, p2 = patch_size),
            nn.LayerNorm(patch_size * patch_size * channels),
            nn.Linear(patch_size * patch_size * channels, dim),
            nn.LayerNorm(dim)
        )

        self.pos_embedding = nn.Parameter(torch.randn(num_patches + 1, dim))
        self.transformer = Transformer(dim, hid_dim, num_heads, num_layers, device)
        self.mlp_head = MLP(dim, dim, num_classes, device)
    
    def forward(self, x):
        z = self.to_patch_embedding(x) + self.pos_embedding # PROBLEM HERE
        z = self.transformer(z)
        z = self.mlp_head(z)
        return z     

In [34]:
def test_ViT_Shapes():
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    x = torch.randn(1, 3, 224, 224).to(device)
    vit = ViT((224, 224), 16, 512, 128, 1000, 8, 6, 3, device)
    out = vit(x)
    print(out.shape)
test_ViT_Shapes()

RuntimeError: The size of tensor a (196) must match the size of tensor b (197) at non-singleton dimension 1