In [12]:
import torch.nn as nn
from torch.nn import Softmax, GELU
from dataclasses import dataclass
from einops import rearrange, repeat
import torch

In [9]:
@dataclass
class ImageParams:
    """Geometry of the images the model consumes."""

    width: int  # image width in pixels
    height: int  # image height in pixels
    in_channel: int  # number of channels per pixel
@dataclass
class ModelParameters:
    """Architecture settings for the Vision Transformer."""

    patch_size: int  # side length (pixels) of each square patch
    inner_dim: int  # token embedding width used throughout the encoder
    transformer_layers: int  # number of stacked encoder blocks
    num_head: int  # attention heads; must divide inner_dim
    embed_dropout: float  # dropout applied after adding position embeddings
    attn_dropout: float  # dropout probability used inside multi-head attention
    mlp_dropout: float  # dropout probability used inside the feed-forward MLP
@dataclass
class Hyperparameters:
    """Task/training settings that are not part of the architecture itself."""

    batch_size: int  # samples per batch
    out_classes: int  # number of target classes (width of the classifier output)


In [None]:
# Seed torch's global RNG before any parameters or test tensors are created so that
# weight initialization and the random inputs below are reproducible across runs.
SEED = 42
torch.manual_seed(SEED)

img_info = ImageParams(width=32, height=32, in_channel=3)
mparams = ModelParameters(
    patch_size=4,
    inner_dim=256,
    transformer_layers=6,
    num_head=4,
    embed_dropout=0.1,
    attn_dropout=0,
    mlp_dropout=0.1,
)
hparams = Hyperparameters(batch_size=1024, out_classes=10)

In [23]:
# Configuration (img_info, mparams, hparams) is defined once, in the configuration cell;
# re-declaring it here would silently override any change made there.
class PatchEmbedding(nn.Module):
    """Turn an image batch into a token sequence: patch embeddings with a [CLS] token prepended.

    A Conv2d with ``kernel_size == stride == patch_size`` projects each non-overlapping
    ``patch_size x patch_size`` patch to an ``inner_dim``-dimensional vector.

    Args:
        mparams: reads ``patch_size`` and ``inner_dim``.
        hparams: unused; accepted so all modules share the same constructor signature.
        img_info: reads ``width``, ``height`` and ``in_channel``.
    """

    def __init__(self, mparams, hparams, img_info):
        super().__init__()
        self.patch_size = mparams.patch_size
        # Kept for backward compatibility; it only records the image width.
        self.img_size = img_info.width
        # Patch grid as (rows, cols). Conv2d(kernel=stride=p, no padding) produces
        # floor(H/p) x floor(W/p) outputs, so this matches the token count of forward()
        # for non-square images too (the old width*width formula did not).
        self.grid_size = (img_info.height // self.patch_size, img_info.width // self.patch_size)
        self.num_patches = self.grid_size[0] * self.grid_size[1]
        self.D = mparams.inner_dim
        self.patch_embed = nn.Conv2d(
            in_channels=img_info.in_channel,
            out_channels=self.D,
            kernel_size=self.patch_size,
            stride=self.patch_size,
        )
        # NOTE(review): torch.rand draws from U[0, 1); consider a zero or small-normal
        # init for the class token — confirm this choice is intended.
        self.cls_token = nn.Parameter(torch.rand(1, 1, self.D))

    def forward(self, x):
        """Embed ``x`` of shape [B, C, H, W] into tokens of shape [B, 1 + num_patches, D].

        Token 0 is the learned [CLS] token; patch tokens follow in row-major grid order.
        """
        b = x.shape[0]
        # expand() broadcasts the single [CLS] vector across the batch without copying it.
        cls_tokens = self.cls_token.expand(b, -1, -1)
        x = self.patch_embed(x)           # [B, D, H/p, W/p]
        x = x.flatten(2).transpose(1, 2)  # [B, (H/p)*(W/p), D], row-major over the grid
        return torch.cat((cls_tokens, x), dim=1)
class MHA(nn.Module):
    """Multi-head scaled dot-product self-attention.

    Projects the input to queries/keys/values, splits them into ``num_head`` heads of
    size ``inner_dim // num_head``, attends within each head, then merges the heads and
    applies an output projection.

    Args:
        mparams: reads ``inner_dim``, ``num_head`` and ``attn_dropout``.
        hparams: unused; accepted so all modules share the same constructor signature.

    Raises:
        AssertionError: if ``inner_dim`` is not divisible by ``num_head``.
    """

    def __init__(self, mparams, hparams):
        super().__init__()
        self.D = mparams.inner_dim
        self.num_head = mparams.num_head
        assert self.D % self.num_head == 0 , 'Inner dimensions and number of attention head need to be perfectly divisible'
        self.head_size = self.D // self.num_head
        self.all_head_size = self.head_size * self.num_head
        # Q/K/V projections (all heads computed by a single Linear each).
        self.query = nn.Linear(in_features=self.D, out_features=self.all_head_size)
        self.key = nn.Linear(in_features=self.D, out_features=self.all_head_size)
        self.value = nn.Linear(in_features=self.D, out_features=self.all_head_size)
        # Maps the concatenated heads back to the model width.
        self.output = nn.Linear(in_features=self.all_head_size, out_features=self.D)
        self.attn_dropout = nn.Dropout(mparams.attn_dropout)
        # NOTE(review): the output-projection dropout reuses the attn_dropout rate;
        # ModelParameters has no separate rate for it — confirm this is intended.
        self.proj_dropout = nn.Dropout(mparams.attn_dropout)
        self.softmax = Softmax(dim=-1)

    def _split_heads(self, t):
        """Reshape [B, N, h*d] -> [B, h, N, d] (head index is the outer factor of the last dim)."""
        b, n, _ = t.shape
        return t.reshape(b, n, self.num_head, self.head_size).transpose(1, 2)

    def forward(self, x, mask=None):
        """Attend over the tokens of ``x``.

        Args:
            x: [B, N, D] token embeddings.
            mask: optional tensor broadcastable to [B, num_head, N, N]; score positions
                where ``mask == 0`` are excluded from attention.

        Returns:
            [B, N, D] attended representation.
        """
        b, n, _ = x.shape
        q = self._split_heads(self.query(x))
        k = self._split_heads(self.key(x))
        v = self._split_heads(self.value(x))
        attn_score = torch.matmul(q, k.transpose(-1, -2)) / self.head_size ** 0.5  # [B, h, N, N]
        if mask is not None:
            # Most negative finite value of the score dtype. A hard-coded -1e9 is not
            # representable in float16, so masked_fill raised under half/autocast training.
            attn_score = attn_score.masked_fill(mask == 0, torch.finfo(attn_score.dtype).min)
        attn_probs = self.attn_dropout(self.softmax(attn_score))
        context = torch.matmul(attn_probs, v)  # [B, h, N, d]
        # Merge heads back: [B, h, N, d] -> [B, N, h*d]
        context = context.transpose(1, 2).reshape(b, n, self.all_head_size)
        return self.proj_dropout(self.output(context))
class MLP(nn.Module):
    """Position-wise feed-forward network: Linear(D->4D) -> GELU -> Dropout -> Linear(4D->D) -> Dropout.

    Args:
        mparams: reads ``inner_dim`` and ``mlp_dropout``.
        hparams: unused; accepted so all modules share the same constructor signature.
    """

    def __init__(self, mparams, hparams):
        super().__init__()
        width = mparams.inner_dim
        drop_p = mparams.mlp_dropout
        self.D = width
        self.hidden_dim = width * 4  # conventional 4x expansion
        # Layer order (and therefore state_dict indices net.0 / net.3) matches the original.
        layers = [
            nn.Linear(width, self.hidden_dim),
            nn.GELU(),
            nn.Dropout(drop_p),
            nn.Linear(self.hidden_dim, width),
            nn.Dropout(drop_p),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the feed-forward stack to the last dimension of ``x``."""
        y = self.net(x)
        return y
class EncoderBlock(nn.Module):
    """Pre-norm Transformer encoder layer: x + MHA(LN(x)), then x + MLP(LN(x)).

    Args:
        mparams: forwarded to MHA and MLP; ``inner_dim`` also sizes the LayerNorms.
        hparams: forwarded to MHA and MLP.
    """

    def __init__(self, mparams, hparams):
        super().__init__()
        width = mparams.inner_dim
        # Submodules are created in the same order as before (norm1, attn, norm2, ffn),
        # so parameter registration and random initialization follow the same sequence.
        self.norm1 = nn.LayerNorm(width)
        self.attn = MHA(mparams=mparams, hparams=hparams)
        self.norm2 = nn.LayerNorm(width)
        self.ffn = MLP(mparams=mparams, hparams=hparams)

    def forward(self, x):
        """Apply the attention sub-layer then the feed-forward sub-layer, each with a residual."""
        h = x + self.attn(self.norm1(x))
        return h + self.ffn(self.norm2(h))
class Transformer(nn.Module):
    """Stack of ``mparams.transformer_layers`` EncoderBlocks applied one after another.

    Args:
        mparams: reads ``transformer_layers``; forwarded to each EncoderBlock.
        hparams: forwarded to each EncoderBlock.
    """

    def __init__(self, mparams, hparams):
        super().__init__()
        self.depth = mparams.transformer_layers
        blocks = [EncoderBlock(mparams=mparams, hparams=hparams) for _ in range(self.depth)]
        self.layers = nn.ModuleList(blocks)

    def forward(self, x):
        """Feed ``x`` through every encoder block in order and return the final tokens."""
        out = x
        for block in self.layers:
            out = block(out)
        return out
class ViT(nn.Module):
    """Vision Transformer classifier.

    Patch-embeds the image, adds learned position embeddings, runs the encoder stack, and
    classifies from the [CLS] token (index 0).

    Args:
        mparams: reads ``patch_size``, ``inner_dim`` and ``embed_dropout``; forwarded to submodules.
        hparams: reads ``out_classes``; forwarded to submodules.
        img_info: reads ``width`` and ``height``; forwarded to PatchEmbedding.
    """

    def __init__(self, mparams, hparams, img_info):
        super().__init__()
        patch_size = mparams.patch_size
        # The patch grid is (H // p) x (W // p). Using width for both sides made the
        # position-embedding length wrong (and the addition in forward fail) for
        # non-square images.
        num_patches = (img_info.height // patch_size) * (img_info.width // patch_size)
        # One learned position vector per patch token plus one for the [CLS] token.
        # NOTE(review): torch.rand draws from U[0, 1) — confirm this init is intended.
        self.pos_embed = nn.Parameter(torch.rand(1, num_patches + 1, mparams.inner_dim))
        self.patch_embed = PatchEmbedding(mparams=mparams, hparams=hparams, img_info=img_info)
        self.transformer = Transformer(mparams=mparams, hparams=hparams)
        self.mlp_head = nn.Sequential(
            nn.LayerNorm(mparams.inner_dim),
            nn.Linear(mparams.inner_dim, hparams.out_classes)
        )
        self.embed_dropout = nn.Dropout(mparams.embed_dropout)

    def forward(self, x):
        """Classify images.

        Args:
            x: image batch; assumed [B, C, H, W] matching ``img_info`` — TODO confirm at call sites.

        Returns:
            [B, out_classes] logits.
        """
        x = self.patch_embed(x)  # [B, 1 + num_patches, D], [CLS] prepended at index 0
        x = x + self.pos_embed
        x = self.embed_dropout(x)
        x = self.transformer(x)
        # Classify from the [CLS] token; mean pooling over tokens (x.mean(dim=1)) is an alternative.
        cls_token_output = x[:, 0]
        return self.mlp_head(cls_token_output)


In [24]:
# Smoke test: a batch of 2 random images must yield one logit per class for each image.
test_tensor = torch.rand(2, img_info.in_channel, img_info.height, img_info.width)
print(f'test tensor shape: {test_tensor.shape}')
vit = ViT(mparams=mparams, hparams=hparams, img_info=img_info)
vit.eval()  # disable dropout so the check is deterministic
with torch.no_grad():
    # Call the module itself rather than .forward() so nn.Module hooks run.
    output = vit(test_tensor)
print(f'Output Shape: {output.shape}')
assert output.shape == (2, hparams.out_classes)

test tensor shape: torch.Size([2, 3, 32, 32])
Output Shape: torch.Size([2, 10])
