<a href="https://colab.research.google.com/github/G0nkly/pytorch_sandbox/blob/main/holster/VIT_quickdraw.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Helpful links

# https://github.com/benisalla/Tiny-ViT-Transformer-from-scratch?tab=readme-ov-file#training
# https://medium.com/@tyler_yu/vit-from-scratch-61debb718e99

In [14]:
import torch
import torch.functional as F
import torch.nn as nn
import math

In [11]:
# Notebook-level hyperparameters. VisionEmbedding reads these names directly
# as module globals, so they must keep exactly these identifiers.
num_channels: int = 3   # input channels seen by the patch convolution
batch_size: int = 16
image_size: int = 224   # side length of the (square) input image
patch_size: int = 16    # side length of each square, non-overlapping patch
# Embedding width is chosen equal to the length of one flattened patch:
# patch_size * patch_size * num_channels (= 768 with the values above).
embd_dim: int = patch_size * patch_size * num_channels
# Patches per side, squared (= 14 * 14 = 196 with the values above).
num_patches: int = (image_size // patch_size) * (image_size // patch_size)
dropout: float = 0.0
device: str = "cuda" if torch.cuda.is_available() else "cpu"

In [12]:
class VisionEmbedding(nn.Module):
    """Map an image batch to a token sequence: [CLS] + patch embeddings + positions.

    A strided convolution cuts the image into non-overlapping
    ``patch_size x patch_size`` patches and projects each one to ``embd_dim``
    features. A learnable CLS token is prepended and a learnable positional
    embedding is added to every token, followed by dropout.

    Args:
        config: optional hyperparameter object. ``num_channels``,
            ``patch_size``, ``image_size``, ``embd_dim`` and ``dropout`` are
            read from it when present. Any missing attribute falls back to the
            module-level constant of the same name; when ``config`` is None,
            all of them do. This lets ``VisionTransformer`` pass its config
            while the zero-argument call keeps working.
            NOTE(review): the exact attribute names of the real config object
            are not visible here; the fallback keeps this safe either way.

    Input:
        x: (B, num_channels, image_size, image_size)
    Output:
        (B, num_patches + 1, embd_dim)
    """

    def __init__(self, config=None):
        super().__init__()

        def setting(name):
            # Config attribute wins; otherwise use the notebook-level global.
            if config is not None and hasattr(config, name):
                return getattr(config, name)
            return globals()[name]

        in_channels = setting("num_channels")
        patch = setting("patch_size")
        image = setting("image_size")
        dim = setting("embd_dim")
        # Patches per side, squared -> number of patch tokens (same formula as
        # the module-level num_patches).
        n_patches = (image // patch) ** 2

        self.patch_embedding = nn.Sequential(
            nn.Conv2d(
                in_channels=in_channels,
                out_channels=dim,
                kernel_size=patch,
                stride=patch,
                padding="valid",
            ),
            # (B, dim, H/p, W/p) -> (B, dim, n_patches)
            nn.Flatten(start_dim=2),
        )

        # nn.Parameter already defaults to requires_grad=True.
        self.cls_token = nn.Parameter(torch.randn(size=(1, 1, dim)))
        self.pos_embeddings = nn.Parameter(torch.randn(size=(1, n_patches + 1, dim)))
        self.dropout = nn.Dropout(p=setting("dropout"))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Broadcast the single CLS token across the batch (a view, no copy).
        cls_token = self.cls_token.expand(x.shape[0], -1, -1)

        # (B, dim, n_patches) -> (B, n_patches, dim)
        patch_embd = self.patch_embedding(x).transpose(2, 1)
        tokens = torch.cat([cls_token, patch_embd], dim=1)
        tokens = self.pos_embeddings + tokens
        return self.dropout(tokens)

In [13]:
class AttentionHead(nn.Module):
    """A single head of (unmasked) scaled dot-product self-attention.

    Args:
        embed_dim: feature size of the incoming tokens.
        head_size: size of the query/key/value projections (d_k).
        dropout: dropout probability applied to the attention weights.

    Input:
        x: (B, T, embed_dim)
    Output:
        out: (B, T, head_size)
    """

    def __init__(self, embed_dim, head_size, dropout=0.0):
        super().__init__()
        self.head_size = head_size  # d_k
        # Bias-free linear projections for Q, K, V
        self.query = nn.Linear(embed_dim, head_size, bias=False)
        self.key = nn.Linear(embed_dim, head_size, bias=False)
        self.value = nn.Linear(embed_dim, head_size, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        Q = self.query(x)  # (B, T, head_size)
        K = self.key(x)    # (B, T, head_size)
        V = self.value(x)  # (B, T, head_size)

        # (B, T, head_size) @ (B, head_size, T) -> (B, T, T), scaled by sqrt(d_k)
        # to keep softmax inputs from growing with the head size.
        scores = Q @ K.transpose(-2, -1) / math.sqrt(self.head_size)

        # torch.softmax rather than F.softmax: this file binds F to
        # `torch.functional`, which has no `softmax`.
        attn_probs = torch.softmax(scores, dim=-1)
        attn_probs = self.dropout(attn_probs)

        # Weighted sum of values -> (B, T, head_size)
        return attn_probs @ V

In [15]:
class MultiHeadAttention(nn.Module):
    """
    Multi-head self-attention for ViT.

    ``num_heads`` independent AttentionHead modules (head size
    ``embed_dim // num_heads``) each attend over the full token sequence.
    Their outputs are concatenated along the feature axis, then mixed by a
    linear projection and dropout.

    Input:
      x: (B, T, embed_dim)
    Output:
      out: (B, T, embed_dim)

    Raises:
      ValueError: if num_heads is not positive or does not divide embed_dim.
    """

    def __init__(self, embed_dim, num_heads, dropout=0.0):
        super().__init__()
        # Explicit checks instead of `assert`, which is stripped under `python -O`;
        # the positivity check also avoids a ZeroDivisionError in the modulo.
        if num_heads <= 0:
            raise ValueError("num_heads must be positive")
        if embed_dim % num_heads != 0:
            raise ValueError("embed_dim must be divisible by num_heads")

        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_size = embed_dim // num_heads

        self.heads = nn.ModuleList([
            AttentionHead(embed_dim, self.head_size, dropout)
            for _ in range(num_heads)
        ])

        # Output projection back to embed_dim
        self.proj = nn.Linear(embed_dim, embed_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # Heads are evaluated one after another in a Python loop; their
        # (B, T, head_size) outputs concatenate to (B, T, embed_dim).
        out = torch.cat([head(x) for head in self.heads], dim=-1)
        out = self.proj(out)
        return self.dropout(out)


In [16]:
class FeedForward(nn.Module):
    """
    Feed-forward (MLP) block used in ViT.

    Linear(embed_dim -> hidden) -> GELU -> Linear(hidden -> embed_dim) -> Dropout,
    applied to every token independently, where hidden = mlp_ratio * embed_dim.

    Args:
        embed_dim: token feature size (input and output width).
        dropout: dropout probability applied after the second linear layer.
        mlp_ratio: hidden width as a multiple of embed_dim. Defaults to 4, the
            previously hard-coded expansion factor.

    Input:
        x: (B, T, embed_dim)
    Output:
        out: (B, T, embed_dim)
    """

    def __init__(self, embed_dim, dropout=0.0, mlp_ratio=4):
        super().__init__()
        # int() allows fractional ratios (e.g. 2.5) while keeping layer sizes integral.
        hidden_dim = int(mlp_ratio * embed_dim)
        self.net = nn.Sequential(
            nn.Linear(embed_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, embed_dim),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.net(x)


In [17]:
class AttentionBlock(nn.Module):
    """
    Pre-norm transformer encoder layer for the ViT.

    Each sub-layer normalizes its input, transforms it, and adds the result
    back onto the residual stream:

        x <- x + MHA(LN1(x))
        x <- x + MLP(LN2(x))
    """

    def __init__(self, embed_dim, num_heads, dropout=0.0):
        super().__init__()
        # Submodules are registered in the same order as before so parameter
        # order (and random initialization order) is unchanged.
        self.ln1 = nn.LayerNorm(embed_dim)
        self.ln2 = nn.LayerNorm(embed_dim)
        self.attn = MultiHeadAttention(embed_dim=embed_dim, num_heads=num_heads, dropout=dropout)
        self.ffwd = FeedForward(embed_dim=embed_dim, dropout=dropout)

    def forward(self, x):
        attended = self.attn(self.ln1(x))
        residual = x + attended
        transformed = self.ffwd(self.ln2(residual))
        return residual + transformed


In [18]:
class VisionTransformer(nn.Module):
    """
    Vision Transformer (ViT) image classifier.

    Pipeline: VisionEmbedding -> ``config.num_layers`` AttentionBlocks ->
    final LayerNorm -> linear head applied to the CLS token (position 0).

    Reads from ``config``: embd_dim, num_heads, dropout, num_layers,
    layer_norm_eps, num_classes (and passes ``config`` to VisionEmbedding).
    """

    def __init__(self, config):
        super().__init__()

        self.embedding = VisionEmbedding(config)

        encoder_layers = [
            AttentionBlock(
                embed_dim=config.embd_dim,
                num_heads=config.num_heads,
                dropout=config.dropout,
            )
            for _ in range(config.num_layers)
        ]
        self.blocks = nn.ModuleList(encoder_layers)

        self.ln_f = nn.LayerNorm(config.embd_dim, eps=config.layer_norm_eps)
        self.head = nn.Linear(config.embd_dim, config.num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        tokens = self.embedding(x)  # (B, T, embd_dim)
        for layer in self.blocks:
            tokens = layer(tokens)
        # Normalize every token, then classify from the CLS slot only.
        cls_repr = self.ln_f(tokens)[:, 0]
        return self.head(cls_repr)
