<a href="https://colab.research.google.com/github/G0nkly/pytorch_sandbox/blob/main/holster/VIT_quickdraw.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Helpful links

# https://github.com/benisalla/Tiny-ViT-Transformer-from-scratch?tab=readme-ov-file#training
# https://github.com/benisalla/Tiny-ViT-Transformer-from-scratch?tab=readme-ov-file#training
# https://medium.com/@tyler_yu/vit-from-scratch-61debb718e99

In [2]:
!pip install torchinfo kaggle

Collecting torchinfo
  Using cached torchinfo-1.8.0-py3-none-any.whl.metadata (21 kB)
Downloading torchinfo-1.8.0-py3-none-any.whl (23 kB)
Installing collected packages: torchinfo
Successfully installed torchinfo-1.8.0


In [4]:
from google.colab import files
files.upload()

Saving kaggle.json to kaggle.json


{'kaggle.json': b'{"username":"chemauer","key":"029f6a06205bcb5ec29e5a41783a6486"}'}

In [5]:
!mkdir -p ~/.kaggle
!cp kaggle.json ~/.kaggle/
!chmod 600 ~/.kaggle/kaggle.json

In [6]:
!kaggle datasets download -d balabaskar/tom-and-jerry-image-classification

Dataset URL: https://www.kaggle.com/datasets/balabaskar/tom-and-jerry-image-classification
License(s): CC0-1.0
Downloading tom-and-jerry-image-classification.zip to /content
 99% 430M/435M [00:02<00:00, 98.6MB/s]
100% 435M/435M [00:02<00:00, 160MB/s] 


In [14]:
!mkdir -p ./data/tom_jerry
!unzip tom-and-jerry-image-classification.zip -d ./data/tom_jerry

[1;30;43mDie letzten 5000Â Zeilen der Streamingausgabe wurden abgeschnitten.[0m
  inflating: ./data/tom_jerry/tom_and_jerry/tom_and_jerry/jerry/frame2793.jpg  
  inflating: ./data/tom_jerry/tom_and_jerry/tom_and_jerry/jerry/frame2794.jpg  
  inflating: ./data/tom_jerry/tom_and_jerry/tom_and_jerry/jerry/frame2795.jpg  
  inflating: ./data/tom_jerry/tom_and_jerry/tom_and_jerry/jerry/frame2796.jpg  
  inflating: ./data/tom_jerry/tom_and_jerry/tom_and_jerry/jerry/frame2797.jpg  
  inflating: ./data/tom_jerry/tom_and_jerry/tom_and_jerry/jerry/frame2798.jpg  
  inflating: ./data/tom_jerry/tom_and_jerry/tom_and_jerry/jerry/frame2799.jpg  
  inflating: ./data/tom_jerry/tom_and_jerry/tom_and_jerry/jerry/frame2800.jpg  
  inflating: ./data/tom_jerry/tom_and_jerry/tom_and_jerry/jerry/frame2801.jpg  
  inflating: ./data/tom_jerry/tom_and_jerry/tom_and_jerry/jerry/frame2802.jpg  
  inflating: ./data/tom_jerry/tom_and_jerry/tom_and_jerry/jerry/frame2803.jpg  
  inflating: ./data/tom_jerry/tom_and_

In [20]:
import torch
from torchinfo import summary
from torch.nn import functional as F
import torch.nn as nn
import math

In [21]:
num_channels: int = 3
batch_size:int = 16
image_size: int = 224
patch_size: int = 16
embd_dim: int = (patch_size ** 2) * num_channels           # 768
num_heads: int = 8
num_classes: int = 10
num_layers: int = 12
num_patches: int = (image_size // patch_size) ** 2         # 196
dropout: float = 0.0
layer_norm_eps: float = 1e-6
device: str = "cuda" if torch.cuda.is_available() else "cpu"

In [22]:
class VisionEmbedding(nn.Module):
    def __init__(self):
        super().__init__()

        self.patch_embedding = nn.Sequential(
            nn.Conv2d(
                in_channels=num_channels,
                out_channels=embd_dim,
                kernel_size=patch_size,
                stride=patch_size,
                padding="valid"
            ),
            nn.Flatten(start_dim=2)
        )

        self.cls_token = nn.Parameter(torch.randn(size=(1, 1, embd_dim)), requires_grad=True)
        self.pos_embeddings = nn.Parameter(torch.randn(size=(1, num_patches + 1, embd_dim)), requires_grad=True)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x : torch.Tensor) -> torch.Tensor:
        cls_token = self.cls_token.expand(x.shape[0], -1, -1)

        patch_embd = self.patch_embedding(x).transpose(2,1)
        patch_embd = torch.cat([cls_token, patch_embd], dim=1)
        embd = self.pos_embeddings + patch_embd
        embd = self.dropout(embd)
        return embd

In [23]:
class AttentionHead(nn.Module):

    def __init__(self, embed_dim, head_size, dropout=0.0):
        super().__init__()
        self.head_size = head_size  # d_k
        # Linear projections for Q, K, V
        self.query = nn.Linear(embed_dim, head_size, bias=False)
        self.key   = nn.Linear(embed_dim, head_size, bias=False)
        self.value = nn.Linear(embed_dim, head_size, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        B, T, C = x.shape  # batch, tokens, embedding

        # Project inputs to Q, K, V
        Q = self.query(x)  # (B, T, head_size)
        K = self.key(x)    # (B, T, head_size)
        V = self.value(x)  # (B, T, head_size)

        # Compute attention scores
        # Q @ K^T per batch
        # (B, T, head_size) @ (B, head_size, T) -> (B, T, T)
        scores = Q @ K.transpose(-2, -1)
        scores = scores / math.sqrt(self.head_size)  # scale by head_size



        # Softmax to get attention probabilities
        attn_probs = F.softmax(scores, dim=-1)
        attn_probs = self.dropout(attn_probs)

        # Weighted sum of values
        out = attn_probs @ V  # (B, T, head_size)

        return out

In [24]:
class MultiHeadAttention(nn.Module):
    """
    Multi-head self-attention for ViT.

    Input:
      x: (B, T, embed_dim)
    Output:
      out: (B, T, embed_dim)
    """

    def __init__(self, embed_dim, num_heads, dropout=0.0):
      super().__init__()
      assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"

      self.embed_dim = embed_dim
      self.num_heads = num_heads
      self.head_size = embed_dim // num_heads

      self.heads = nn.ModuleList([
        AttentionHead(embed_dim, self.head_size, dropout)
        for _ in range(num_heads)
      ])

      # Output projection
      self.proj = nn.Linear(embed_dim, embed_dim)
      self.dropout = nn.Dropout(dropout)

    def forward(self, x):
      # Run all heads in parallel
      out = torch.cat([head(x) for head in self.heads], dim=-1)
      # Project back to embed_dim
      out = self.proj(out)
      out = self.dropout(out)
      return out


In [25]:
class FeedForward(nn.Module):
    """
    Feed-forward (MLP) block used in ViT.

    Input:
        x: (B, T, embed_dim)
    Output:
        out: (B, T, embed_dim)
    """

    def __init__(self, embed_dim, dropout=0.0):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(embed_dim, 4 * embed_dim),
            nn.GELU(),
            nn.Linear(4 * embed_dim, embed_dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        return self.net(x)


In [26]:
class AttentionBlock(nn.Module):
    """
    Transformer encoder block for ViT.

    Structure:
        x = x + MHA(LN(x))
        x = x + MLP(LN(x))
    """

    def __init__(self, embed_dim, num_heads, dropout=0.0):
        super().__init__()
        self.ln1 = nn.LayerNorm(embed_dim)
        self.ln2 = nn.LayerNorm(embed_dim)

        self.attn = MultiHeadAttention(
            embed_dim=embed_dim,
            num_heads=num_heads,
            dropout=dropout
        )

        self.ffwd = FeedForward(
            embed_dim=embed_dim,
            dropout=dropout
        )

    def forward(self, x):
        x = x + self.attn(self.ln1(x))
        x = x + self.ffwd(self.ln2(x))
        return x


In [27]:
class VisionTransformer(nn.Module):
    """
    Vision Transformer (ViT) for image classification.
    """

    def __init__(self):
        super().__init__()

        self.embedding = VisionEmbedding()

        self.blocks = nn.ModuleList([
            AttentionBlock(
                embed_dim=embd_dim,
                num_heads=num_heads,
                dropout=dropout
            )
            for _ in range(num_layers)
        ])

        self.ln_f = nn.LayerNorm(
            embd_dim,
            eps=layer_norm_eps
        )

        self.head = nn.Linear(
            embd_dim,
            num_classes
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Patch + position embedding
        x = self.embedding(x)  # (B, T, embd_dim)

        # Transformer encoder
        for block in self.blocks:
            x = block(x)

        # Final normalization
        x = self.ln_f(x)

        # Classification via CLS token
        cls_token = x[:, 0]
        logits = self.head(cls_token)

        return logits


In [28]:
vit = VisionTransformer()
summary(model=vit,
        input_size=(16, 3, 224, 224),
        col_names= ["input_size", "output_size", "num_params", "trainable"],
        col_width=20,
        row_settings= ["var_names"]
    )

Layer (type (var_name))                            Input Shape          Output Shape         Param #              Trainable
VisionTransformer (VisionTransformer)              [16, 3, 224, 224]    [16, 10]             --                   True
ââVisionEmbedding (embedding)                      [16, 3, 224, 224]    [16, 197, 768]       152,064              True
â    ââSequential (patch_embedding)                [16, 3, 224, 224]    [16, 768, 196]       --                   True
â    â    ââConv2d (0)                             [16, 3, 224, 224]    [16, 768, 14, 14]    590,592              True
â    â    ââFlatten (1)                            [16, 768, 14, 14]    [16, 768, 196]       --                   --
â    ââDropout (dropout)                           [16, 197, 768]       [16, 197, 768]       --                   --
ââModuleList (blocks)                              --                   --                   --                   True
â    ââ