In [2]:
import lightning as L

import torch
from torch import nn

from torchvision.ops import MLP

# Model

## Patch Embeddings

In [3]:
class PatchEmbeddings(nn.Module):
    """Cut an image into non-overlapping square patches and embed each patch linearly.

    A convolution whose kernel size and stride both equal ``patch_size`` is the
    same as slicing the image into patches and applying one shared linear
    projection to every flattened patch.

    Args:
        img_size: Side length used to compute ``num_patches`` (a square image is assumed).
        patch_size: Side length of each square patch; also used as the conv stride.
        dim: Size of the embedding vector produced for every patch.

    Note:
        ``LazyConv2d`` infers its input channel count on the first forward pass,
        so its weights stay uninitialized until then. Run one dummy batch
        through the model before building an optimizer.
    """

    def __init__(self, img_size, patch_size, dim):
        super().__init__()

        patches_per_side = img_size // patch_size
        self.num_patches = patches_per_side * patches_per_side
        self.embedding = nn.LazyConv2d(dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        """Turn a batched image tensor (B, C, H, W) into patch tokens (B, N, dim)."""
        feature_map = self.embedding(x)            # (B, dim, H', W')
        tokens = feature_map.flatten(start_dim=2)  # (B, dim, H' * W')
        return tokens.transpose(1, 2)              # (B, H' * W', dim)

## Layer

In [4]:
class ViTLayer(nn.Module):
    """Pre-norm Transformer encoder block: ``x + MSA(LN(x))`` then ``x + MLP(LN(x))``.

    Args:
        dim: Token embedding size (``nn.MultiheadAttention`` requires it to be
            divisible by ``num_heads``).
        num_heads: Number of attention heads.
        hidden_channels: Width of the MLP's hidden layer.
        batch_first: When True (the default), tokens are laid out as
            (batch, seq, dim). ``VisionTransformer`` uses this layout because it
            concatenates the CLS token on dim 1 and reads it back via ``x[:, 0]``.
            Pass False for (seq, batch, dim) inputs.
    """

    def __init__(self, dim, num_heads, hidden_channels, batch_first=True):
        super().__init__()
        # Must be a ModuleList, not a plain list. Otherwise the LayerNorm
        # parameters are not registered: they are missing from
        # parameters()/state_dict(), are never trained, and ignore .to(device).
        self.norm = nn.ModuleList([nn.LayerNorm(dim) for _ in range(2)])
        # MultiheadAttention defaults to batch_first=False (seq, batch, dim).
        # Feeding it (batch, seq, dim) tokens under that default would make it
        # attend across samples in the batch instead of across tokens.
        self.attn = nn.MultiheadAttention(
            embed_dim=dim,
            num_heads=num_heads,
            batch_first=batch_first,
        )
        self.mlp = MLP(
            in_channels=dim,
            hidden_channels=[hidden_channels, dim],
            activation_layer=nn.GELU,
        )

    def forward(self, x):
        """Apply one encoder block; the output has the same shape as ``x``."""
        h = self.norm[0](x)
        # Self-attention: query, key and value are the same normalized tokens.
        attn_out, _ = self.attn(h, h, h, need_weights=False)
        x = x + attn_out
        x = x + self.mlp(self.norm[1](x))
        return x

## Vision Transformer (ViT)

In [5]:
class VisionTransformer(nn.Module):
    """Vision Transformer (ViT) classifier.

    The pipeline is:

    1. patch embedding;
    2. prepend a learnable CLS token;
    3. add learnable position embeddings;
    4. run ``num_layers`` encoder blocks;
    5. apply LayerNorm to the CLS token;
    6. apply a linear head.

    Args:
        num_layers: Number of stacked ``ViTLayer`` encoder blocks.
        img_size: Side length of the (square) input image. It fixes the number of
            position embeddings, so inputs of a different size will fail to broadcast.
        num_classes: Number of output scores produced by the head.
        dim: Token embedding size used throughout the encoder.
        patch_size: Side length of each square patch.
        num_heads: Attention heads per encoder block.
        hidden_channels: Hidden width of each block's MLP.
    """

    def __init__(self, num_layers, img_size, num_classes, dim, patch_size, num_heads, hidden_channels):
        super().__init__()
        self.patch_embeddings = PatchEmbeddings(
            img_size=img_size,
            patch_size=patch_size,
            dim=dim
        )
        self.cls_token = nn.Parameter(
            torch.zeros(1, 1, dim)
        )
        # One position embedding per patch, plus one for the CLS token.
        self.pos_embeddings = nn.Parameter(
            torch.zeros(1, 1 + self.patch_embeddings.num_patches, dim),
        )
        self.layers = nn.ModuleList([
            ViTLayer(
                dim=dim,
                num_heads=num_heads,
                hidden_channels=hidden_channels
            ) for _ in range(num_layers)
        ])
        # Final LayerNorm on the CLS output (ViT paper, Eq. 4: y = LN(z_L^0)).
        # The blocks are pre-norm, so without this the residual stream would
        # reach the classifier unnormalized.
        self.norm = nn.LayerNorm(dim)
        self.head = nn.Linear(dim, num_classes)

    def forward(self, x):
        """Classify a batch of images; returns unnormalized scores of shape (B, num_classes)."""
        tokens = self.patch_embeddings(x)  # (B, N, dim)

        cls_token = self.cls_token.expand(tokens.shape[0], -1, -1)
        tokens = torch.cat((cls_token, tokens), dim=1)  # (B, 1 + N, dim)
        tokens = tokens + self.pos_embeddings

        for layer in self.layers:
            tokens = layer(tokens)

        # LayerNorm works per token, so normalizing only the CLS token matches
        # normalizing the whole sequence and then selecting it.
        return self.head(self.norm(tokens[:, 0]))