In [196]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [191]:
class PatchEmbedding(nn.Module):
    """Convert an image batch into a token sequence with a class token.

    A convolution whose kernel size and stride both equal ``patch_size``
    projects each non-overlapping patch to ``embd_dim`` channels. The patch
    grid is flattened into a sequence, a learnable class token is prepended,
    and a learnable positional embedding is added.

    Args:
        img_size: Side length of the input images (the patch count is
            computed as ``(img_size // patch_size) ** 2``, i.e. a square grid).
        patch_size: Side length of each patch; used as conv kernel and stride.
        in_channels: Number of channels in the input images.
        embd_dim: Size of each token embedding.
    """

    def __init__(self, img_size, patch_size, in_channels, embd_dim):
        super().__init__()
        # Number of patches per image; sizes the positional embedding below.
        num_patches = (img_size // patch_size) ** 2
        self.proj = nn.Conv2d(
            in_channels=in_channels,
            out_channels=embd_dim,
            kernel_size=patch_size,
            stride=patch_size,
        )
        self.cls_token = nn.Parameter(torch.randn(1, 1, embd_dim))
        # One position per patch plus one for the class token.
        self.pos_embd = nn.Parameter(torch.randn(1, 1 + num_patches, embd_dim))

    def forward(self, x):
        """Embed a batch of images.

        Intermediate activations are kept in locals rather than stored on the
        module, so no tensors (or their autograd graphs) stay alive after the
        call returns.

        Args:
            x: Image batch of shape ``(B, in_channels, H, W)``. The number of
                patches produced must equal the one implied by ``img_size``,
                otherwise adding the positional embedding fails.

        Returns:
            Tensor of shape ``(B, 1 + num_patches, embd_dim)``.
        """
        batch_size = x.size(0)
        # (B, C, H, W) -> (B, D, H', W') -> (B, D, N) -> (B, N, D)
        patches = self.proj(x).flatten(2).transpose(1, 2)
        # Broadcast the single learned class token across the batch (no copy).
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        tokens = torch.cat((cls_tokens, patches), dim=1)
        return tokens + self.pos_embd

    


In [197]:
class MLP(nn.Module):
    """Two-layer feed-forward block: Linear -> GELU -> Dropout -> Linear -> Dropout.

    Args:
        in_feature: Size of the input features; also the output size.
        hidden_features: Size of the hidden layer.
        drop_rate: Dropout probability applied after each linear stage.
    """

    def __init__(self, in_feature, hidden_features, drop_rate):
        super().__init__()
        self.fc1 = nn.Linear(in_feature, hidden_features)
        self.fc2 = nn.Linear(hidden_features, in_feature)
        self.dropout = nn.Dropout(drop_rate)

    def forward(self, x):
        """Apply the feed-forward block to ``x``.

        Activations are held in locals instead of module attributes so the
        module does not keep tensors alive between calls.

        Args:
            x: Tensor whose last dimension is ``in_feature``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        hidden = self.dropout(F.gelu(self.fc1(x)))
        return self.dropout(self.fc2(hidden))

In [240]:
class TransformerEncoderLayer(nn.Module):
    """Pre-norm Transformer encoder layer: self-attention then MLP, each with a residual.

    Args:
        embd_dim: Token embedding size.
        num_heads: Number of attention heads.
        mlp_dim: Hidden size of the feed-forward block.
        drop_rate: Dropout probability used in attention and in the MLP.
    """

    def __init__(self, embd_dim, num_heads, mlp_dim, drop_rate):
        super().__init__()
        self.ln1 = nn.LayerNorm(embd_dim)
        self.attn = nn.MultiheadAttention(
            embed_dim=embd_dim,
            num_heads=num_heads,
            dropout=drop_rate,
            batch_first=True,
        )
        self.ln2 = nn.LayerNorm(embd_dim)
        self.mlp = MLP(
            in_feature=embd_dim,
            hidden_features=mlp_dim,
            drop_rate=drop_rate,
        )

    def forward(self, x):
        """Run one encoder layer.

        Args:
            x: Token tensor of shape ``(B, N, embd_dim)`` (batch-first, as the
                attention module is built with ``batch_first=True``).

        Returns:
            Tensor with the same shape as ``x``.
        """
        # Normalise only the attention input; the skip connection must carry
        # the un-normalised x, mirroring the MLP sub-block below.
        attn_in = self.ln1(x)
        # Attention weights are discarded, so skip computing them.
        x = x + self.attn(attn_in, attn_in, attn_in, need_weights=False)[0]
        x = x + self.mlp(self.ln2(x))
        return x

In [None]:
class ViT(nn.Module):
    """Vision Transformer classifier.

    Pipeline: patch embedding -> ``depth`` encoder layers -> LayerNorm ->
    linear head applied to the first (class) token.

    Args:
        img_size: Side length of the input images.
        patch_size: Side length of each patch.
        in_channels: Number of input image channels.
        num_classes: Number of output logits.
        emb_dim: Token embedding size.
        mlp_dim: Hidden size of each encoder layer's MLP.
        drop_rate: Dropout probability passed to every encoder layer.
        num_heads: Number of attention heads per encoder layer.
        depth: Number of stacked encoder layers.
    """

    def __init__(self, img_size, patch_size, in_channels, num_classes,
                 emb_dim, mlp_dim, drop_rate, num_heads, depth):
        super().__init__()
        self.patch_emb = PatchEmbedding(
            img_size=img_size,
            patch_size=patch_size,
            in_channels=in_channels,
            embd_dim=emb_dim,
        )
        # Build the encoder stack one layer at a time, then chain the layers.
        encoder_layers = []
        for _ in range(depth):
            encoder_layers.append(
                TransformerEncoderLayer(emb_dim, num_heads, mlp_dim, drop_rate)
            )
        self.enc = nn.Sequential(*encoder_layers)
        self.norm = nn.LayerNorm(emb_dim)
        self.head = nn.Linear(emb_dim, num_classes)

    def forward(self, x):
        """Return class logits of shape ``(B, num_classes)`` for image batch ``x``."""
        tokens = self.enc(self.patch_emb(x))
        normed = self.norm(tokens)
        # Index 0 along the sequence axis is the class token.
        return self.head(normed[:, 0])
        

In [232]:
# Random batch of five 3-channel 32x32 images for shape checks.
x = torch.randn(5, 3, 32, 32)

# Patch embedding matching that batch: 4x4 patches -> 64 patches (+1 class token).
model = PatchEmbedding(img_size=32, patch_size=4, in_channels=3, embd_dim=64)

# A single encoder layer whose width equals the embedding size above.
t = TransformerEncoderLayer(embd_dim=64, num_heads=4, mlp_dim=128, drop_rate=0.1)