In [1]:
import torch
import torch.nn as nn

In [None]:
class PatchEmbedding(nn.Module):
    """Split a square image into non-overlapping patches and project each one linearly.

    The projection is a Conv2d whose kernel size and stride both equal
    ``patch_size``. This is equivalent to flattening every patch and applying
    one shared linear layer (Dosovitskiy et al., "An Image is Worth 16x16
    Words", 2020).

    Args:
        img_size: side length of the (square) input image.
        patch_size: side length of each square patch. If ``img_size`` is not a
            multiple of ``patch_size``, the strided convolution ignores the
            trailing ``img_size % patch_size`` rows/columns.
        in_channels: number of input channels.
        embedding_dim: size of each patch embedding.
    """

    def __init__(self, img_size, patch_size, in_channels=3, embedding_dim=512):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        # Patches per side is img_size // patch_size, matching the Conv2d output size.
        self.n_patches = (img_size // patch_size) ** 2
        self.proj_layer = nn.Conv2d(in_channels, embedding_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        """Embed every patch of a batch of images.

        x : [n_samples, in_channels, img_size, img_size]
        output : [n_samples, n_patches, embedding_dim]
        """
        x = self.proj_layer(x)  # [n_samples, embedding_dim, sqrt(n_patches), sqrt(n_patches)]
        x = x.flatten(2)  # [n_samples, embedding_dim, n_patches]
        # Put tokens on dim 1 so the patches can be concatenated with a [CLS]
        # token along the sequence axis.
        return x.transpose(1, 2)  # [n_samples, n_patches, embedding_dim]

class EncoderBlock(nn.Module):
    """Pre-norm Transformer encoder block, as used in ViT.

    Computes::

        y   = x + MSA(LN1(x))
        out = y + MLP(LN2(y))

    Args:
        dim: token embedding size. Must be divisible by ``n_heads``
            (an ``nn.MultiheadAttention`` requirement).
        n_heads: number of attention heads.
        mlp_ratio: hidden width of the MLP as a multiple of ``dim``.
        p_dropout: dropout probability for the attention weights and after
            each MLP stage.
    """

    def __init__(self, dim, n_heads, mlp_ratio=4, p_dropout=0.5):
        super().__init__()

        self.dim = dim
        self.n_heads = n_heads
        self.p_dropout = p_dropout
        self.mlp_ratio = mlp_ratio

        # Two independent LayerNorms: the pre-attention and pre-MLP
        # normalizations each get their own affine parameters.
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        # batch_first=True because inputs are [n_samples, n_tokens, dim]. With
        # the default (False), the batch axis would be treated as the sequence axis.
        self.attention = nn.MultiheadAttention(dim, n_heads, dropout=p_dropout, batch_first=True)
        hidden_dim = int(dim * mlp_ratio)
        self.MLP = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(p_dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(p_dropout),
        )

    def forward(self, x):
        """
        x : [n_samples, n_patches + 1, embedding_dim]
        output : [n_samples, n_patches + 1, embedding_dim]
        """
        h = self.norm1(x)
        # nn.MultiheadAttention returns (output, weights). Only the output is used.
        attn_out, _ = self.attention(h, h, h, need_weights=False)
        x = x + attn_out
        x = x + self.MLP(self.norm2(x))
        return x

class ViT(nn.Module):
    """Vision Transformer encoder that returns the final [CLS] token embedding.

    The steps are:
    1. Split the image into patches and embed each patch.
    2. Prepend a learnable [CLS] token.
    3. Add learnable positional embeddings (both zero-initialized here).
    4. Run the sequence through ``depth`` encoder blocks.
    5. Apply a final LayerNorm before reading out the [CLS] token
       (ViT paper, Eq. 4).

    Args:
        img_size: side length of the (square) input image.
        patch_size: side length of each square patch.
        in_channels: number of input image channels.
        embedding_dim: token embedding size. Must be divisible by ``n_heads``
            (``nn.MultiheadAttention`` requirement).
            NOTE(review): the defaults 512 and 6 are not divisible, so callers
            must pass a compatible ``n_heads`` or ``embedding_dim``.
        depth: number of encoder blocks.
        n_heads: number of attention heads per block.
        mlp_ratio: MLP hidden-size multiplier in each block.
        p_dropout: dropout probability passed to each encoder block.
    """

    def __init__(self, img_size, patch_size=9, in_channels=3, embedding_dim=512, depth=6, n_heads=6, mlp_ratio=4, p_dropout=0.5):
        super().__init__()

        self.patch_embedding = PatchEmbedding(img_size, patch_size, in_channels, embedding_dim)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embedding_dim))
        # One positional embedding per patch plus one for the [CLS] token.
        self.pos_embed = nn.Parameter(torch.zeros(1, 1 + self.patch_embedding.n_patches, embedding_dim))

        self.encoder_blocks = nn.ModuleList(
            [EncoderBlock(embedding_dim, n_heads, mlp_ratio, p_dropout) for _ in range(depth)]
        )
        # The pre-norm blocks leave the residual stream unnormalized, so
        # normalize once more before the readout.
        self.norm = nn.LayerNorm(embedding_dim)

    def forward(self, x):
        """Encode a batch of images.

        x : [n_samples, in_channels, img_size, img_size]
        output : [n_samples, embedding_dim] -- normalized [CLS] token embedding
        """
        n_samples = x.shape[0]
        # The patch tokens must be laid out as [n_samples, n_patches, embedding_dim]
        # for the concatenation below.
        x = self.patch_embedding(x)

        cls_token = self.cls_token.expand(n_samples, -1, -1)  # [n_samples, 1, embedding_dim]
        x = torch.cat((cls_token, x), dim=1)  # [n_samples, 1 + n_patches, embedding_dim]

        x = x + self.pos_embed  # broadcast over the batch dimension

        for enc_block in self.encoder_blocks:
            x = enc_block(x)

        x = self.norm(x)
        return x[:, 0]  # [n_samples, embedding_dim]: only the [CLS] token embedding