In [1]:
import torch
import torch.nn as nn
from einops import rearrange, repeat

### Patch Embedding
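It has three components:
- Split the image into a sequence of flattened patches and project each patch to the embedding dimension.
- Prepend a learnable CLS token to the sequence of patches.
- Add a learnable positional embedding to every token (including the CLS token).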

In [39]:
class PatchEmbedding(nn.Module):
    """Turn a batch of images into a sequence of patch tokens for a ViT.

    Steps, in order:
      1. Split each image into non-overlapping ``patch_height x patch_width``
         patches and flatten each to a vector of length
         ``im_channels * patch_height * patch_width``.
      2. Project every flattened patch to ``emb_dim``
         (LayerNorm -> Linear -> LayerNorm).
      3. Prepend a learnable CLS token.
      4. Add a learnable positional embedding to every token (CLS included),
         then apply dropout.

    Config keys:
        image_height, image_width: expected input spatial size; each must be
            divisible by the corresponding patch size.
        im_channels: number of input channels.
        emb_dim: transformer hidden size (D).
        patch_height, patch_width: patch size.
        patch_emb_dropout: dropout probability on the embeddings
            (optional, default 0.0).

    Raises:
        ValueError: at construction if the image size is not divisible by the
            patch size, and in ``forward`` if the input's spatial size differs
            from the configured one.
    """

    def __init__(self, config):
        super().__init__()
        image_height = config["image_height"]
        image_width = config["image_width"]
        im_channels = config["im_channels"]
        emb_dim = config["emb_dim"]  # Transformer dimension (D)
        patch_embd_dropout = config.get("patch_emb_dropout", 0.0)

        self.patch_height = config["patch_height"]
        self.patch_width = config["patch_width"]

        # Fail fast: with non-divisible sizes the floor division below would
        # give a wrong patch count and forward() would fail much later.
        if image_height % self.patch_height or image_width % self.patch_width:
            raise ValueError(
                f"Image size ({image_height}, {image_width}) must be divisible by "
                f"patch size ({self.patch_height}, {self.patch_width})"
            )
        # Kept so forward() can reject inputs whose patch grid would not match pos_emb.
        self.image_height = image_height
        self.image_width = image_width

        num_patches = (image_height // self.patch_height) * (image_width // self.patch_width)
        patch_dim = im_channels * self.patch_height * self.patch_width

        # Patch projection: W in R^(patch_dim x emb_dim), wrapped in LayerNorms.
        self.patch_emb = nn.Sequential(
            nn.LayerNorm(patch_dim),
            nn.Linear(patch_dim, emb_dim),
            nn.LayerNorm(emb_dim),
        )

        # Positional information is added to the CLS token as well, hence 1 + num_patches.
        self.pos_emb = nn.Parameter(torch.randn(1, num_patches + 1, emb_dim))
        # Single learnable CLS token in R^emb_dim, shared across the batch.
        self.cls_token = nn.Parameter(torch.randn(emb_dim))
        self.patch_emb_dropout = nn.Dropout(patch_embd_dropout)

    def forward(self, x):
        """Embed images of shape (B, C, H, W) into tokens of shape (B, 1 + num_patches, emb_dim)."""
        batch_size = x.shape[0]

        # A different H/W can produce the same number of patches (e.g. 4x24 vs
        # 8x12 with 4x4 patches); that would silently mis-assign positional embeddings.
        if tuple(x.shape[-2:]) != (self.image_height, self.image_width):
            raise ValueError(
                f"Expected input of spatial size ({self.image_height}, {self.image_width}), "
                f"got {tuple(x.shape[-2:])}"
            )

        # (B, C, H, W) -> (B, num_patches, patch_dim); features within a patch are ordered (p1, p2, c).
        out = rearrange(x, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=self.patch_height, p2=self.patch_width)

        out = self.patch_emb(out)
        # One copy of the CLS token per batch element: (B, 1, D).
        cls_token = repeat(self.cls_token, 'd -> b n d', b=batch_size, n=1)
        out = torch.cat([cls_token, out], dim=1)
        # pos_emb is (1, 1 + num_patches, D) and broadcasts over the batch.
        out = out + self.pos_emb
        out = self.patch_emb_dropout(out)

        return out


In [51]:
# Example run: a 224x224 RGB image with 16x16 patches gives
# (224 // 16) * (224 // 16) = 196 patch tokens plus 1 CLS token.
image = torch.randn(1, 3, 224, 224)
config = {
    "image_height": 224,
    "image_width": 224,
    "im_channels": 3,
    "emb_dim": 512,
    "patch_height": 16,
    "patch_width": 16,
    "patch_emb_dropout": 0.1
}
patch_emb = PatchEmbedding(config)
with torch.no_grad():  # shape check only; no autograd graph needed
    out = patch_emb(image)

expected_tokens = 1 + (config["image_height"] // config["patch_height"]) * (config["image_width"] // config["patch_width"])
assert out.shape == (image.shape[0], expected_tokens, config["emb_dim"])
print(out.shape)

torch.Size([1, 197, 512])


### Attention Module

In [47]:
class Attention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    Config keys:
        n_heads: number of attention heads (h).
        head_dim: per-head dimension (d_h).
        emb_dim: token dimension (D) of both input and output.
        dropout: dropout probability applied to the attention weights and to
            the output projection (optional, default 0.0).

    The inner attention width is ``n_heads * head_dim``; it need not equal
    ``emb_dim`` because ``out_proj`` maps it back to ``emb_dim``.
    """

    def __init__(self, config):
        super().__init__()
        self.n_heads = config["n_heads"]
        self.head_dim = config["head_dim"]  # d_h
        self.emb_dim = config["emb_dim"]
        self.drop_prob = config.get("dropout", 0.0)
        self.att_dim = self.n_heads * self.head_dim

        # One fused, bias-free projection that produces Q, K and V together.
        self.qkv_proj = nn.Linear(self.emb_dim, self.att_dim * 3, bias=False)
        self.att_drop = nn.Dropout(self.drop_prob)

        self.out_proj = nn.Sequential(
            nn.Linear(self.att_dim, self.emb_dim),
            nn.Dropout(self.drop_prob),
        )

    def forward(self, x):
        """Attend over the token axis; assumes x is (B, N, emb_dim) and returns the same shape."""
        # Split the fused projection into Q, K, V, each (B, N, h * d_h).
        q, k, v = self.qkv_proj(x).split(self.att_dim, dim=-1)

        # (B, N, h * d_h) -> (B, h, N, d_h) so each head attends independently.
        q, k, v = (
            rearrange(t, 'b n (h d_h) -> b h n d_h', h=self.n_heads, d_h=self.head_dim)
            for t in (q, k, v)
        )

        # Scaled dot-product attention weights: softmax(Q K^T / sqrt(d_h)).
        att = torch.matmul(q, k.transpose(-2, -1)) / (self.head_dim ** 0.5)
        att = nn.functional.softmax(att, dim=-1)
        att = self.att_drop(att)

        # Weighted sum of values, then merge heads back to (B, N, h * d_h).
        out = torch.matmul(att, v)
        out = rearrange(out, 'b h n d_h -> b n (h d_h)', h=self.n_heads, d_h=self.head_dim)

        return self.out_proj(out)
        

In [48]:
# Example run on a standalone input so re-running the cell is idempotent:
# 197 tokens (196 patches + CLS) of width emb_dim, matching the patch-embedding example.
config = {
    "n_heads": 8,
    "head_dim": 64,
    "emb_dim": 512,
    "dropout": 0.1
}

att = Attention(config)
tokens = torch.randn(1, 197, config["emb_dim"])
with torch.no_grad():  # shape check only
    out = att(tokens)
assert out.shape == tokens.shape
print(out.shape)

torch.Size([1, 197, 512])


### Transformer

In [52]:
class TransformerLayer(nn.Module):
    """A single pre-norm transformer encoder block.

    Computes ``h = x + Attention(LayerNorm(x))`` and then
    ``h + MLP(LayerNorm(h))``.

    Config keys (in addition to those read by ``Attention``):
        emb_dim: token dimension (D).
        ff_dim: hidden width of the MLP (optional, default 4 * emb_dim).
        ff_drop: dropout probability inside the MLP (optional, default 0.0).
    """

    def __init__(self, config):
        super().__init__()
        dim = config["emb_dim"]
        hidden = config.get("ff_dim", 4 * dim)
        p_drop = config.get("ff_drop", 0.0)

        # Submodule registration order fixes both the parameter-init RNG order
        # and the state_dict key order, so it is kept stable.
        self.att_norm = nn.LayerNorm(dim)
        self.ff_norm = nn.LayerNorm(dim)
        self.attention_block = Attention(config)
        self.ff_block = nn.Sequential(
            nn.Linear(dim, hidden),
            nn.GELU(),
            nn.Dropout(p_drop),
            nn.Linear(hidden, dim),
            nn.Dropout(p_drop),
        )

    def forward(self, x):
        # Residual connection around attention, then around the MLP.
        h = x + self.attention_block(self.att_norm(x))
        return h + self.ff_block(self.ff_norm(h))
        

In [54]:
# Example run on a standalone input so re-running the cell is idempotent:
# 197 tokens (196 patches + CLS) of width emb_dim.
config = {
    "emb_dim": 512,
    "ff_dim": 2048,
    "ff_drop": 0.1,
    "n_heads": 8,
    "head_dim": 64,
    "dropout": 0.1
}

transformer_block = TransformerLayer(config)
tokens = torch.randn(1, 197, config["emb_dim"])
with torch.no_grad():  # shape check only
    out = transformer_block(tokens)
assert out.shape == tokens.shape
print(out.shape)

torch.Size([1, 197, 512])


### ViT

In [55]:
class ViT(nn.Module):
    """Vision Transformer classifier.

    Pipeline: PatchEmbedding -> ``n_layers`` x TransformerLayer -> LayerNorm ->
    linear head applied to the CLS token (sequence position 0).

    Config keys (in addition to those read by the sub-modules):
        n_layers: number of transformer layers.
        emb_dim: token dimension (D).
        num_dim: number of output classes.

    Returns raw logits of shape (B, num_dim); no softmax is applied.
    """

    def __init__(self, config):
        super().__init__()
        depth = config["n_layers"]
        dim = config["emb_dim"]
        num_classes = config["num_dim"]

        self.patch_embedding = PatchEmbedding(config)
        self.transformer = nn.ModuleList([TransformerLayer(config) for _ in range(depth)])
        self.norm = nn.LayerNorm(dim)
        self.fc_layer = nn.Linear(dim, num_classes)

    def forward(self, x):
        tokens = self.patch_embedding(x)
        for block in self.transformer:
            tokens = block(tokens)
        tokens = self.norm(tokens)

        # Classify from the CLS token only; logits, no softmax.
        cls_repr = tokens[:, 0]
        return self.fc_layer(cls_repr)


In [56]:
# Example run: full ViT forward pass on one image; expect logits of shape (batch, num_dim).
image = torch.randn(1, 3, 224, 224)
config = {
    "image_height": 224,
    "image_width": 224,
    "im_channels": 3,
    "emb_dim": 512,
    "patch_height": 16,
    "patch_width": 16,
    "patch_emb_dropout": 0.1,
    "n_layers": 12,
    "num_dim": 1000,
    "n_heads": 8,
    "head_dim": 64,
    "dropout": 0.1
}

model = ViT(config)
model.eval()  # inference demo: disable dropout
with torch.no_grad():  # no autograd graph needed for a forward-only check
    out = model(image)
assert out.shape == (image.shape[0], config["num_dim"])
print(out.shape)

torch.Size([1, 1000])
