In [118]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets
from torchvision import transforms
from torch.utils.data import DataLoader
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR
import math

In [119]:
class PatchEmbed(nn.Module):
    """Split an image into non-overlapping patches and linearly embed each one.

    Implemented as a Conv2d whose kernel size and stride both equal
    ``patch_size``, so each output position corresponds to exactly one patch.

    Args:
        img_size: side length of the (assumed square) input image in pixels.
        patch_size: side length of each square patch.
        in_channels: number of channels in the input image.
        embed_dims: width D of each patch embedding.

    Shape:
        input:  (B, in_channels, H, W)
        output: (B, (H // patch_size) * (W // patch_size), embed_dims);
                for H == W == img_size the token count equals ``self.num_patches``.
    """

    def __init__(self,img_size = 224, patch_size = 16, in_channels = 3, embed_dims = 768):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.in_channels = in_channels
        # Record the width actually used by the projection (was hard-coded to 768).
        self.embed_dims = embed_dims
        # Integer division mirrors Conv2d's floor behaviour when img_size is
        # not a multiple of patch_size (trailing pixels are dropped).
        self.num_patches = (img_size // patch_size) ** 2

        self.proj = nn.Conv2d(in_channels=in_channels,
                              out_channels=embed_dims,
                              kernel_size=patch_size,
                              stride=patch_size)

    def forward(self, x):
        x = self.proj(x)        # (B, D, H/p, W/p)
        x = x.flatten(2)        # (B, D, num_patches)
        x = x.transpose(1, 2)   # (B, num_patches, D)
        return x

In [120]:
class clsToken(nn.Module):
    """Prepend one learnable classification token to a token sequence.

    Args:
        embed_dims: width D of the token; the learnable vector has this size.
        cls: accepted but currently unused -- exactly one token is always
            prepended. NOTE(review): downstream PositionalEmbedding assumes a
            single extra token, so honoring this would need coordinated changes.

    The token starts as all zeros.
    """

    def __init__(self, embed_dims = 768, cls = 1):
        super().__init__()
        self.embed_dims = embed_dims
        # Leading singleton dims let the token be broadcast over the batch.
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dims))

    def forward(self, x):
        # x must be (B, N, embed_dims) for the concatenation along dim 1 to work.
        per_sample_token = self.cls_token.expand(x.shape[0], -1, -1)
        return torch.cat([per_sample_token, x], dim=1)

In [121]:
class PositionalEmbedding(nn.Module):
    """Add a learnable position embedding to a sequence of CLS + patch tokens.

    Args:
        num_patches: number of patch tokens; one extra slot is reserved for
            the prepended CLS token, so ``self.num_patches == num_patches + 1``.
        embed_dims: token width D.

    The table has shape (1, num_patches + 1, D). It is initialized from a
    truncated normal with std 0.02 and broadcast over the batch in ``forward``.
    """

    def __init__(self, num_patches, embed_dims):
        super().__init__()
        total_tokens = num_patches + 1
        self.num_patches = total_tokens
        self.embed_dims = embed_dims

        table = torch.zeros(1, total_tokens, embed_dims)
        self.pos_embed = nn.Parameter(table)
        nn.init.trunc_normal_(self.pos_embed, std=0.02)

    def forward(self, x):
        return x + self.pos_embed

In [122]:
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    Args:
        embed_dims: token width D; must be divisible by ``num_heads``.
        num_heads: number of heads; each head operates on D // num_heads channels.
        qkv_bias: whether the fused Q/K/V projection has a bias term
            (default False, matching the previous hard-coded behaviour).
        attn_dropout: dropout probability applied to the attention weights
            (default 0.0).
        proj_dropout: dropout probability applied after the output projection
            (default 0.0).

    Shape:
        input and output: (B, N, D)

    Raises:
        AssertionError: if ``embed_dims`` is not divisible by ``num_heads``.
    """

    def __init__(self, embed_dims, num_heads, qkv_bias=False, attn_dropout=0.0, proj_dropout=0.0):
        super().__init__()
        self.embed_dims = embed_dims
        self.num_heads = num_heads

        assert embed_dims % num_heads == 0, "embed_dim must be divisible by num_heads"

        self.head_dim = embed_dims // num_heads
        # 1/sqrt(head_dim): keeps dot-product magnitudes independent of head width.
        self.scale = self.head_dim ** -0.5
        # One fused projection produces Q, K and V together.
        self.qkv = nn.Linear(embed_dims, embed_dims * 3, bias=qkv_bias)
        self.proj = nn.Linear(embed_dims, embed_dims)
        self.attn_dropout = nn.Dropout(attn_dropout)
        self.proj_dropout = nn.Dropout(proj_dropout)

    def forward(self, x):
        batch_size, num_tokens, embed_dims = x.shape
        # (B, N, 3D) -> (B, N, 3, H, hd) -> (3, B, H, N, hd) so q, k, v unpack along dim 0.
        qkv = self.qkv(x).reshape(batch_size, num_tokens, 3, self.num_heads, self.head_dim)
        q, k, v = qkv.permute(2, 0, 3, 1, 4)

        # (B, H, N, N) scaled similarity scores.
        attn = (q @ k.transpose(-2, -1)) * self.scale
        # Normalise over the key axis to get attention probabilities.
        attn = attn.softmax(dim=-1)
        attn = self.attn_dropout(attn)

        # (B, H, N, hd) -> (B, N, H, hd) -> (B, N, D): merge heads back together.
        x = (attn @ v).transpose(1, 2).reshape(batch_size, num_tokens, embed_dims)

        x = self.proj(x)
        x = self.proj_dropout(x)
        return x

In [123]:
class MLP(nn.Module):
    """Token-wise feed-forward network used inside each transformer block.

    The layers run in this order: Linear(D -> hidden), GELU, Dropout,
    Linear(hidden -> D), Dropout.

    Args:
        embed_dims: input/output width D.
        mlp_ratio: hidden width multiplier; hidden = int(D * mlp_ratio).
        dropout_rate: probability shared by both dropout applications.
    """

    def __init__(self, embed_dims, mlp_ratio=4.0, dropout_rate=0.0):
        super().__init__()
        self.embed_dims = embed_dims
        self.hidden_dim = int(embed_dims * mlp_ratio)
        # Creation order (fc1 before fc2) is kept so weight init consumes RNG identically.
        self.fc1 = nn.Linear(embed_dims, self.hidden_dim)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(self.hidden_dim, embed_dims)
        # Dropout has no parameters, so one module can be reused at both sites.
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, x):
        hidden = self.dropout(self.act(self.fc1(x)))
        return self.dropout(self.fc2(hidden))


In [124]:
class transformerEncoder(nn.Module):
    """One pre-norm transformer encoder block.

    The block computes
        x = x + Attention(LayerNorm(x))
        x = x + MLP(LayerNorm(x))

    Args:
        embed_dims: token width D.
        num_heads: number of attention heads.
        mlp_ratio: hidden-width multiplier for the MLP.
        dropout_rate: forwarded only to the MLP; the attention layer is built
            with its own default dropout settings.
    """

    def __init__(self, embed_dims,num_heads,  mlp_ratio = 4.0, dropout_rate = 0.0):
        super().__init__()
        self.embed_dims = embed_dims
        self.num_heads = num_heads
        self.mlp_ratio = mlp_ratio
        self.dropout_rate = dropout_rate

        # Sub-modules are created in the same order as before so that
        # random weight init draws from the RNG in the same sequence.
        self.norm1 = nn.LayerNorm(embed_dims)
        self.attn = MultiHeadAttention(embed_dims=embed_dims, num_heads=num_heads)
        self.norm2 = nn.LayerNorm(embed_dims)
        self.mlp = MLP(embed_dims=embed_dims, mlp_ratio=mlp_ratio, dropout_rate=dropout_rate)

    def forward(self, x):
        attn_residual = self.attn(self.norm1(x))
        x = x + attn_residual
        mlp_residual = self.mlp(self.norm2(x))
        return x + mlp_residual

In [125]:
class VisionTransformer(nn.Module):
    """Vision Transformer classifier.

    The forward pass runs these steps in order:
        1. Patch embedding.
        2. Prepend a CLS token.
        3. Add the position embedding, then apply dropout.
        4. Apply ``num_layers`` pre-norm encoder blocks.
        5. Apply a final LayerNorm.
        6. Feed the CLS token through a linear head.

    Args:
        img_size, patch_size, in_channels: configure the patch embedding.
        num_classes: size of the output logits.
        embed_dims: token width D.
        num_layers: number of encoder blocks.
        num_heads: attention heads per block.
        mlp_ratio: MLP hidden-width multiplier.
        dropout_rate: used after the position embedding and inside every MLP.

    Returns (forward):
        logits of shape (B, num_classes).
    """

    def __init__(self, 
                 img_size=224, 
                 patch_size=16, 
                 in_channels=3, 
                 num_classes=1000, 
                 embed_dims=768, 
                 num_layers=12,
                 num_heads=12,   
                 mlp_ratio=4.0, 
                 dropout_rate=0.1): 
        super().__init__()

        self.img_size = img_size
        self.patch_size = patch_size
        self.in_channels = in_channels
        self.num_classes = num_classes
        self.embed_dims = embed_dims
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.mlp_ratio = mlp_ratio
        self.dropout_rate = dropout_rate

        # Module creation order is significant: several constructors draw from
        # the global RNG, so this sequence keeps initialisation reproducible.
        self.patch_embed = PatchEmbed(img_size, patch_size, in_channels, embed_dims)
        self.cls_token_module = clsToken(embed_dims)
        self.pos_embed_module = PositionalEmbedding(self.patch_embed.num_patches, embed_dims)
        self.pos_dropout = nn.Dropout(dropout_rate)

        blocks = []
        for _ in range(num_layers):
            blocks.append(
                transformerEncoder(
                    embed_dims=embed_dims,
                    num_heads=num_heads,
                    mlp_ratio=mlp_ratio,
                    dropout_rate=dropout_rate,
                )
            )
        self.transformer_blocks = nn.ModuleList(blocks)

        self.norm = nn.LayerNorm(embed_dims)
        self.head = nn.Linear(embed_dims, num_classes)
        self._init_weights()

    def _init_weights(self):
        """Re-initialise every Linear and LayerNorm in the model.

        The two module types are handled as follows:
          * Linear: weight ~ truncated normal (std 0.02), bias set to 0.
          * LayerNorm: weight set to 1, bias set to 0.

        The patch-embedding Conv2d, the CLS token and the position table are
        not touched here.
        """
        for module in self.modules():
            if isinstance(module, nn.Linear):
                nn.init.trunc_normal_(module.weight, std=0.02)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
                continue
            if isinstance(module, nn.LayerNorm):
                nn.init.ones_(module.weight)
                nn.init.zeros_(module.bias)

    def forward(self, x):
        tokens = self.patch_embed(x)
        tokens = self.pos_dropout(self.pos_embed_module(self.cls_token_module(tokens)))
        for block in self.transformer_blocks:
            tokens = block(tokens)
        tokens = self.norm(tokens)
        # Position 0 holds the CLS token prepended by cls_token_module.
        return self.head(tokens[:, 0])
        

In [126]:
# Smoke test: build a full-size ViT and confirm the logits shape on random inputs.
IMG_SIZE = 224
PATCH_SIZE = 16
EMBED_DIMS = 768
NUM_HEADS = 12
NUM_LAYERS = 12
NUM_CLASSES = 37
BATCH_SIZE = 4
SEED = 0

# Seed before drawing inputs and initialising weights so re-runs are reproducible.
torch.manual_seed(SEED)

dummy_images = torch.randn(BATCH_SIZE, 3, IMG_SIZE, IMG_SIZE)

vit_model = VisionTransformer(
    img_size=IMG_SIZE,
    patch_size=PATCH_SIZE,
    in_channels=3,
    num_classes=NUM_CLASSES,
    embed_dims=EMBED_DIMS,
    num_layers=NUM_LAYERS,
    num_heads=NUM_HEADS,
    mlp_ratio=4.0,
    dropout_rate=0.1
)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
vit_model.to(device)
dummy_images = dummy_images.to(device)

# This is a forward-only shape check. eval() disables dropout, and no_grad()
# avoids building an autograd graph over ~86M parameters. The previous mode
# is restored afterwards so any later training cell is unaffected.
was_training = vit_model.training
vit_model.eval()
with torch.no_grad():
    output_logits = vit_model(dummy_images)
vit_model.train(was_training)

assert output_logits.shape == (BATCH_SIZE, NUM_CLASSES), output_logits.shape

print(f"Input image shape: {dummy_images.shape}")
print(f"Final output logits shape: {output_logits.shape}")
total_params = sum(p.numel() for p in vit_model.parameters() if p.requires_grad)
print(f"Total learnable parameters: {total_params:,}")

Input image shape: torch.Size([4, 3, 224, 224])
Final output logits shape: torch.Size([4, 37])
Total learnable parameters: 85,799,461
