In [19]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class Patching(nn.Module):
    """
    Split an image batch into non-overlapping square patches and embed each one.

    A Conv2d whose kernel size and stride both equal ``patch_size`` acts as a
    per-patch linear projection; the resulting spatial grid is then flattened
    into a single token axis.

    Args:
        in_channels: Number of input image channels.
        patch_size: Side length of each square patch, in pixels.
        embedding_dim: Size of the embedding produced for every patch.
    """

    def __init__(self, in_channels=3, patch_size=4, embedding_dim=48):
        super().__init__()
        projection = nn.Conv2d(
            in_channels,
            embedding_dim,
            kernel_size=patch_size,
            stride=patch_size,
        )
        # (B, D, H', W') -> (B, D, H' * W')
        self.patch = nn.Sequential(projection, nn.Flatten(start_dim=2, end_dim=3))

    def forward(self, x):
        """
        Embed the patches of an image batch.

        Returns a tensor of shape (B, num_patches, embedding_dim), where
        num_patches = (H // patch_size) * (W // patch_size).
        """
        tokens = self.patch(x)         # (B, D, N)
        return tokens.transpose(1, 2)  # (B, N, D)

class Head(nn.Module):
    """
    Single self-attention head (scaled dot-product attention).

    Args:
        n_embed: Size of the last dimension of the input features.
        head_size: Size of the query/key/value projections of this head.
        dropout: Dropout probability applied to the attention weights.
    """
    def __init__(self, n_embed, head_size, dropout):
        super().__init__()
        self.n_embed = n_embed
        self.head_size = head_size
        # One fused projection producing q, k and v; split in forward().
        self.qkv = nn.Linear(n_embed, head_size * 3, bias=False)
        self.attention_dropout = nn.Dropout(dropout)

    def forward(self, x, attention_mask=None):
        """
        Forward pass of the Head module.

        Args:
            x: Input of shape (B, T, n_embed); torch.bmm requires a 3-D batch.
            attention_mask: Optional (B, T) tensor. Positions whose value is 0
                are excluded as keys, i.e. no query attends to them.

        Returns:
            Tensor of shape (B, T, head_size).
        """
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        # scores[b, i, j] = q_i . k_j, so row i holds query i's scores over all
        # keys and the softmax over the last dim normalizes across keys.
        # Scale by the per-head dimension (d_k), not the model width.
        scores = torch.bmm(q, k.transpose(-2, -1)) * (self.head_size ** -0.5)
        if attention_mask is not None:
            # (B, T) -> (B, 1, T): broadcast the key mask over every query row.
            key_mask = attention_mask.unsqueeze(1).to(torch.bool)
            # Masking must happen before softmax via a large negative score;
            # multiplying by 0 would still give masked keys weight exp(0).
            # The dtype's most negative finite value (rather than -inf) makes
            # a fully-masked row uniform instead of NaN.
            scores = scores.masked_fill(~key_mask, torch.finfo(scores.dtype).min)
        weights = F.softmax(scores, dim=-1)
        weights = self.attention_dropout(weights)
        return torch.bmm(weights, v)

class MultiHead(nn.Module):
    """
    Multi-head self-attention: runs ``n_heads`` independent Head modules on
    the same input, concatenates their outputs along the feature axis and
    projects the result back to ``n_embed`` features.

    Args:
        head_size: Output size of each head.
        n_heads: Number of heads.
        n_embed: Input and output feature size.
        dropout: Dropout probability passed to every head.
    """
    def __init__(self, head_size, n_heads, n_embed, dropout):
        super().__init__()
        self.heads = nn.ModuleList([Head(n_embed, head_size, dropout) for _ in range(n_heads)])
        # The concatenated heads have head_size * n_heads features. That only
        # equals n_embed for particular configurations, so size the projection
        # from the real input width (same shapes as before when they match).
        self.proj = nn.Linear(head_size * n_heads, n_embed)

    def forward(self, x, attention_mask=None):
        """
        Forward pass of the MultiHead module.

        Args:
            x: Input features; the last dimension must be n_embed.
            attention_mask: Optional mask passed unchanged to every head.

        Returns:
            Tensor with the same leading dimensions as x and n_embed features.
        """
        out = torch.cat([head(x, attention_mask) for head in self.heads], dim=-1)
        return self.proj(out)

class FeedForward(nn.Module):
    """
    Position-wise MLP applied to the last dimension: Linear -> GELU -> Linear -> Dropout.

    Args:
        n_embed: Input and output feature size.
        mlp_ratio: Expansion factor of the hidden layer (hidden = n_embed * mlp_ratio).
        dropout: Dropout probability applied to the output.
    """

    def __init__(self, n_embed, mlp_ratio, dropout):
        super().__init__()
        hidden = n_embed * mlp_ratio
        layers = [
            nn.Linear(n_embed, hidden),
            nn.GELU(),
            nn.Linear(hidden, n_embed),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Run the MLP over the last dimension of ``x``; the output keeps x's shape."""
        out = self.net(x)
        return out

class Block(nn.Module):
    """
    Pre-norm transformer block: self-attention and an MLP, each applied to a
    LayerNorm-ed copy of the input and added back onto the residual stream.

    Args:
        n_embed: Feature size of the token stream.
        head_size: Per-head attention size.
        n_heads: Number of attention heads.
        dropout: Dropout probability for the attention and MLP sub-layers.
        mlp_ratio: Hidden-size multiplier for the MLP.
    """
    def __init__(self, n_embed, head_size, n_heads, dropout, mlp_ratio):
        super().__init__()
        self.multihead = MultiHead(head_size, n_heads, n_embed, dropout)
        self.ffwd = FeedForward(n_embed, mlp_ratio, dropout)
        self.ln1 = nn.LayerNorm(n_embed)
        self.ln2 = nn.LayerNorm(n_embed)

    def forward(self, x, attention_mask=None):
        """
        Forward pass of the Block module.

        Only the sub-layer inputs are normalized; the residual path carries x
        unchanged. Normalizing x itself before the addition would discard the
        identity skip connection.
        """
        x = x + self.multihead(self.ln1(x), attention_mask)
        x = x + self.ffwd(self.ln2(x))
        return x

class ViT(nn.Module):
    """
    Vision Transformer classifier with learned sequence pooling.

    The image batch is cut into non-overlapping patch_size x patch_size
    patches. Each patch is embedded and summed with a learned positional
    embedding, and the token sequence goes through ``n_layers`` transformer
    blocks. A learned softmax weighting over tokens pools the sequence to one
    vector per image, and a small MLP (Linear -> ReLU -> Linear) maps that
    vector to class logits.

    Args:
        in_channels: Number of image channels.
        patch_size: Side length of each square patch, in pixels.
        embedding_dim: Token embedding size.
        head_size: Per-head attention size.
        n_heads: Number of attention heads per block.
        n_layers: Number of transformer blocks.
        dropout: Dropout probability passed to every block.
        mlp_ratio: Hidden-size multiplier for the block MLPs and the classifier head.
        block_size: Size of the positional-embedding table, i.e. the maximum
            number of patch tokens. With the defaults, a 32x32 image yields
            (32 // 4) ** 2 = 64 tokens.
        num_classes: Number of output logits.
    """
    def __init__(self, in_channels=3, patch_size=4, embedding_dim=48, head_size=12,
                 n_heads=4, n_layers=2, dropout=0.4, mlp_ratio=2, block_size=64, num_classes=10):
        super().__init__()
        self.patch_embedding = Patching(in_channels, patch_size, embedding_dim)
        self.positional_embedding = nn.Embedding(block_size, embedding_dim)
        self.blocks = nn.ModuleList([Block(embedding_dim, head_size, n_heads, dropout, mlp_ratio) for _ in range(n_layers)])
        self.ln = nn.LayerNorm(embedding_dim)
        # Scores each token; softmax over these scores gives the pooling weights.
        self.sequence_pooling = nn.Linear(embedding_dim, 1)
        self.cl_head = nn.Sequential(
            nn.Linear(embedding_dim, embedding_dim * mlp_ratio),
            nn.ReLU(),
            nn.Linear(embedding_dim * mlp_ratio, num_classes)
        )

    def forward(self, x, attention_mask=None):
        """
        Forward pass of the ViT model.

        Args:
            x: Image batch of shape (B, in_channels, H, W).
            attention_mask: Optional mask forwarded unchanged to every block.

        Returns:
            Logits of shape (B, num_classes).

        Raises:
            ValueError: If the image yields more patch tokens than block_size.
        """
        x = self.patch_embedding(x)  # (B, N, D)
        n_tokens = x.shape[1]
        max_tokens = self.positional_embedding.num_embeddings
        if n_tokens > max_tokens:
            # Fail with an actionable message instead of an opaque index error
            # (or a device-side assert on GPU) from the embedding lookup.
            raise ValueError(
                f"input yields {n_tokens} patch tokens but block_size is {max_tokens}; "
                "increase block_size or use smaller images / larger patches"
            )
        positions = torch.arange(n_tokens, device=x.device)
        x = x + self.positional_embedding(positions)
        for block in self.blocks:
            x = block(x, attention_mask)
        x = self.ln(x)
        # Sequence pooling: score each token, normalize the scores over the
        # token axis, and take the weighted sum of the tokens.
        pool_weights = F.softmax(self.sequence_pooling(x).transpose(-2, -1), dim=-1)  # (B, 1, N)
        x = torch.bmm(pool_weights, x).squeeze(1)  # (B, D)
        return self.cl_head(x)
    

In [20]:
# Smoke test: push a random batch of 32 images (3x32x32) through the default model.
# Both the model and the input are placed on `device`, so changing it takes effect.
device = 'cpu'
model = ViT().to(device)
out = model(torch.randn(32, 3, 32, 32, device=device))

In [21]:
out.shape  # expected torch.Size([32, 10]): 32 inputs x num_classes (ViT default 10)

torch.Size([32, 10])