<a href="https://colab.research.google.com/github/G0nkly/pytorch_sandbox/blob/main/holster/VIT_quickdraw.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Helpful links

# https://github.com/benisalla/Tiny-ViT-Transformer-from-scratch?tab=readme-ov-file#training
# https://medium.com/@tyler_yu/vit-from-scratch-61debb718e99

In [None]:
import torch
import torch.nn as nn

In [None]:
# ---- ViT hyper-parameters (read by VisionEmbedding below) ----
num_channels: int = 3                      # channels of the input images
batch_size: int = 16
image_size: int = 224                      # images are image_size x image_size pixels
patch_size: int = 16                       # patches are patch_size x patch_size pixels

# Token width = number of raw values in one patch (3 * 16 * 16 = 768 by default).
embd_dim: int = num_channels * patch_size * patch_size
# Patches per image = (patches per side) squared (14 * 14 = 196 by default).
num_patches: int = (image_size // patch_size) * (image_size // patch_size)

dropout: float = 0.0
device: str = "cuda" if torch.cuda.is_available() else "cpu"

In [None]:
class VisionEmbedding(nn.Module):
    """Turn a batch of images into a token sequence for a Vision Transformer.

    A ``Conv2d`` whose kernel and stride both equal the patch size (no padding)
    projects every non-overlapping square patch to an ``embed_dim``-wide vector.
    A learnable [CLS] token is then prepended, learnable positional embeddings
    are added, and dropout is applied.

    Every constructor argument is optional: an omitted argument falls back to
    the notebook-level setting (``num_channels``, ``embd_dim``, ``patch_size``,
    ``image_size``, ``dropout``), so ``VisionEmbedding()`` builds the same
    module as before.

    Args:
        in_channels: number of channels in the input images.
        embed_dim: width of every output token.
        patch: side length, in pixels, of a square patch.
        img_size: side length, in pixels, of the square images the positional
            table is sized for.
        drop_p: dropout probability applied to the summed embeddings.
    """

    def __init__(
        self,
        in_channels: "int | None" = None,
        embed_dim: "int | None" = None,
        patch: "int | None" = None,
        img_size: "int | None" = None,
        drop_p: "float | None" = None,
    ):
        super().__init__()

        # Fall back to the notebook-wide settings for anything not given.
        # (Only the selected branch of each conditional is evaluated, so the
        # globals are looked up only when they are actually needed.)
        in_channels = num_channels if in_channels is None else in_channels
        embed_dim = embd_dim if embed_dim is None else embed_dim
        patch = patch_size if patch is None else patch
        img_size = image_size if img_size is None else img_size
        drop_p = dropout if drop_p is None else drop_p

        n_patches = (img_size // patch) ** 2

        # kernel == stride == patch turns the conv into a per-patch linear map:
        # (B, C, H, W) -> (B, D, H // p, W // p); Flatten(start_dim=2) -> (B, D, N).
        self.patch_embedding = nn.Sequential(
            nn.Conv2d(
                in_channels=in_channels,
                out_channels=embed_dim,
                kernel_size=patch,
                stride=patch,
                padding="valid",
            ),
            nn.Flatten(start_dim=2),
        )

        # One learnable [CLS] token shared across the batch, plus one positional
        # vector per token (N patches + 1 CLS). nn.Parameter already defaults
        # to requires_grad=True. Creation order matches the original so seeded
        # initialisation is unchanged.
        self.cls_token = nn.Parameter(torch.randn(size=(1, 1, embed_dim)))
        self.pos_embeddings = nn.Parameter(torch.randn(size=(1, n_patches + 1, embed_dim)))
        self.dropout = nn.Dropout(p=drop_p)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Embed images ``x`` of shape (B, C, H, W) into tokens of shape (B, N + 1, D)."""
        tokens = self.patch_embedding(x).transpose(1, 2)      # (B, N, D)
        cls = self.cls_token.expand(x.shape[0], -1, -1)       # (B, 1, D)
        tokens = torch.cat([cls, tokens], dim=1)              # (B, N + 1, D)
        return self.dropout(tokens + self.pos_embeddings)

In [None]:
class AttentionHead(nn.Module):
    """A single head of unmasked scaled dot-product self-attention.

    Args:
        embed_dim: width of the incoming token embeddings.
        head_size: width of the query/key/value projections (``d_k``) and of
            the head's output.
        dropout: dropout probability applied to the attention weights.
    """

    def __init__(self, embed_dim, head_size, dropout=0.0):
        super().__init__()
        self.head_size = head_size  # d_k
        # Bias-free linear projections of the input to queries, keys and values.
        self.query = nn.Linear(embed_dim, head_size, bias=False)
        self.key = nn.Linear(embed_dim, head_size, bias=False)
        self.value = nn.Linear(embed_dim, head_size, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Let every token of ``x`` attend to every token of ``x``.

        Args:
            x: token embeddings, expected shape (B, T, embed_dim).

        Returns:
            Tensor of shape (B, T, head_size).
        """
        q = self.query(x)  # (B, T, d_k)
        k = self.key(x)    # (B, T, d_k)
        v = self.value(x)  # (B, T, d_k)

        # (B, T, d_k) @ (B, d_k, T) -> (B, T, T), scaled by 1/sqrt(d_k) so the
        # softmax does not saturate as d_k grows. Tensor-native ops are used so
        # this head needs nothing beyond `torch`/`nn` (no `math` or `F`).
        scores = q @ k.transpose(-2, -1) / self.head_size ** 0.5

        # Attention probabilities over the key axis, then dropout on them.
        weights = self.dropout(scores.softmax(dim=-1))

        # Weighted sum of values: (B, T, T) @ (B, T, d_k) -> (B, T, d_k).
        return weights @ v

In [None]:
class FeedForward(nn.Module):
    """Placeholder for the ViT feed-forward sub-layer; not implemented yet."""

    def __init__(self):
        super().__init__()
        # TODO: define the sub-layers.

    def forward(self, x):
        # A bare `...` body would silently return None (masking nn.Module's own
        # NotImplementedError) and only fail later with a confusing NoneType
        # error downstream; fail loudly here instead.
        raise NotImplementedError("FeedForward.forward is not implemented yet")

In [None]:
class MultiHeadAttention(nn.Module):
    """Placeholder for multi-head self-attention; not implemented yet."""

    def __init__(self):
        super().__init__()
        # TODO: define the sub-layers.

    def forward(self, x):
        # Fail loudly rather than silently returning None, which would only
        # surface later as a confusing NoneType error downstream.
        raise NotImplementedError("MultiHeadAttention.forward is not implemented yet")

In [None]:
class AttentionBlock(nn.Module):
    """Placeholder for one transformer encoder block; not implemented yet."""

    def __init__(self):
        super().__init__()
        # TODO: define the sub-layers.

    def forward(self, x):
        # Fail loudly rather than silently returning None, which would only
        # surface later as a confusing NoneType error downstream.
        raise NotImplementedError("AttentionBlock.forward is not implemented yet")

In [None]:
class VisionTransformer(nn.Module):
    """Placeholder for the full Vision Transformer model; not implemented yet."""

    def __init__(self):
        # nn.Module.__init__ must run first: without it the internal registries
        # (_parameters, _modules, hooks) are never created, so calling, printing
        # or registering sub-modules on the instance raises AttributeError.
        super().__init__()
        # TODO: build the model's sub-modules.

    def forward(self, x=None):
        """Run the model on a batch of images ``x`` (not implemented yet).

        ``x`` defaults to None only so existing no-argument calls keep working.
        """
        raise NotImplementedError("VisionTransformer.forward is not implemented yet")