In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T
from torch.optim import Adam
from torchvision.datasets.mnist import MNIST
from torch.utils.data import DataLoader
import numpy as np

### <center> Patch Embeddings </center>

In [2]:
class PatchEmbedding(nn.Module):
  """
  Split an image into non-overlapping square patches and linearly embed each one.

  Patch extraction and the linear projection are done in a single Conv2d whose
  kernel size and stride both equal ``patch_size``. The defaults (224x224,
  3-channel input) correspond to ImageNet-sized images; pass e.g.
  ``img_size=28, in_channels=1`` for MNIST.

  Args:
    img_size: side length (pixels) of the square input image. Only used to
      compute ``num_patches``; ``forward`` itself does not check the input size.
    patch_size: side length (pixels) of each square patch.
    in_channels: number of channels of the input image.
    embed_dim: dimensionality of each patch embedding.
    bias: whether the projection Conv2d has a learnable bias.

  Shapes:
    input:  (B, in_channels, img_size, img_size)
    output: (B, num_patches, embed_dim), num_patches = (img_size // patch_size) ** 2
  """
  def __init__(self, img_size=224, patch_size=16, in_channels=3, embed_dim=768, bias=True):
    super().__init__()

    self.img_size = img_size
    self.patch_size = patch_size
    self.in_channels = in_channels
    self.embed_dim = embed_dim

    # Matches the Conv2d output grid: floor(img_size / patch_size) per side.
    self.num_patches = (img_size // patch_size) ** 2

    # kernel_size == stride -> each output position sees exactly one patch.
    # `bias` was previously accepted but never forwarded to the layer.
    self.proj = nn.Conv2d(
      in_channels=in_channels,
      out_channels=embed_dim,
      kernel_size=patch_size,
      stride=patch_size,
      bias=bias,
    )

  # B: batch size
  # C: num channels
  # P_row: number of patch rows    (H // patch_size)
  # P_col: number of patch columns (W // patch_size)
  def forward(self, x):
    x = self.proj(x)  # (B, C, H, W) -> (B, embed_dim, P_row, P_col)

    x = x.flatten(2)  # (B, embed_dim, P_row, P_col) -> (B, embed_dim, num_patches), row-major patch order

    x = x.transpose(1, 2)  # (B, embed_dim, num_patches) -> (B, num_patches, embed_dim)

    return x

### <center> Self Attention </center>

In [None]:
class SelfAttentionEncoder(nn.Module):
    """
    Multi-head self-attention: Q/K/V linear projections, scaled dot-product
    attention per head, then an output projection with dropout.

    Args:
        embed_dim: token dimensionality; must be divisible by ``num_heads``.
        num_heads: number of attention heads.
        attn_p: dropout probability applied to the attention weights (training only).
        proj_p: dropout probability applied after the output projection.
        flash_attn: if True, use the fused ``F.scaled_dot_product_attention``
            kernel when the installed PyTorch provides it; otherwise compute
            attention explicitly.

    Shapes:
        input:  (B, S, embed_dim)
        output: (B, S, embed_dim)

    Raises:
        ValueError: if ``embed_dim`` is not divisible by ``num_heads``.
    """
    def __init__(self, embed_dim=768, num_heads=12, attn_p=0, proj_p=0, flash_attn=True):
        super().__init__()
        # Fail fast here instead of with an opaque reshape error inside forward().
        if embed_dim % num_heads != 0:
            raise ValueError(
                f"embed_dim ({embed_dim}) must be divisible by num_heads ({num_heads})"
            )
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        # 1/sqrt(head_dim); also the default scale used by F.scaled_dot_product_attention.
        self.scale = self.head_dim ** -0.5
        self.flash_attn = flash_attn

        self.q = nn.Linear(embed_dim, embed_dim)
        self.k = nn.Linear(embed_dim, embed_dim)
        self.v = nn.Linear(embed_dim, embed_dim)
        self.attn_p = attn_p
        self.attn_drop = nn.Dropout(attn_p)

        self.proj = nn.Linear(embed_dim, embed_dim)
        self.proj_drop = nn.Dropout(proj_p)

    def _split_heads(self, t):
        """Reshape (B, S, E) -> (B, num_heads, S, head_dim)."""
        return t.reshape(*t.shape[:-1], self.num_heads, self.head_dim).transpose(1, 2)

    def forward(self, x):
        batch_size, seq_len, embed_dim = x.shape

        q = self._split_heads(self.q(x))
        k = self._split_heads(self.k(x))
        v = self._split_heads(self.v(x))

        if self.flash_attn and hasattr(F, "scaled_dot_product_attention"):
            # SDPA is a plain function with no train/eval state, so dropout
            # must be switched off manually outside training.
            x = F.scaled_dot_product_attention(q, k, v, dropout_p=self.attn_p if self.training else 0)
        else:
            # k: (B, H, S, D) -> (B, H, D, S); attn: (B, H, S, S)
            attn = (q @ k.transpose(-2, -1)) * self.scale
            attn = attn.softmax(dim=-1)
            attn = self.attn_drop(attn)
            x = attn @ v  # (B, H, S, D)

        # Merge heads: (B, H, S, D) -> (B, S, H, D) -> (B, S, E)
        x = x.transpose(1, 2).reshape(batch_size, seq_len, embed_dim)

        # Output projection
        x = self.proj(x)
        x = self.proj_drop(x)

        return x

### <center>MLP</center>