In [41]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import math

In [42]:
# Pick the compute device once: CUDA when PyTorch can see a GPU, otherwise CPU.
if torch.cuda.is_available():
  device = torch.device('cuda')
else:
  device = torch.device('cpu')

In [43]:
class PatchEmbedding(nn.Module):
  """Split images into non-overlapping patches, embed them, and prepend a CLS token.

  Args:
    patch_size: side length P of each square patch; also used as the unfold
      stride, so patches do not overlap.
    embed_dim: width D that each flattened patch is projected to.
    img_h, img_w: expected input height/width. The patch count uses floor
      division, so trailing rows/columns that do not fill a patch are dropped.
    in_channels: number of input channels C.

  Input:  (B, C, H, W) image batch.
  Output: (B, num_patches + 1, embed_dim); index 0 along dim 1 is the CLS token.

  Raises:
    ValueError: in forward(), if the input yields a different number of
      patches than the positional embedding was built for.
  """
  def __init__(self, patch_size, embed_dim, img_h, img_w, in_channels):
    super().__init__()
    self.patch_size = patch_size
    self.num_patches = (img_h // patch_size) * (img_w // patch_size)
    patch_dim = in_channels * (patch_size ** 2)  # flattened length of one patch: C * P^2
    # stride == kernel size, so patches never overlap
    self.unfold = nn.Unfold(kernel_size=patch_size, stride=patch_size)
    # project each flattened patch to the transformer input width D
    self.proj = nn.Linear(in_features=patch_dim, out_features=embed_dim)
    # one learnable position vector for the CLS slot plus one per patch
    self.pos_embed = nn.Parameter(torch.randn(1, self.num_patches + 1, embed_dim))
    self.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim))

  def forward(self, x):
    B = x.shape[0]
    # (B, C * P^2, N) -> (B, N, C * P^2)
    patches = self.unfold(x).transpose(1, 2)
    n_patches = patches.shape[1]
    if n_patches != self.num_patches:
      # Without this check the failure surfaces later as an opaque broadcast error on pos_embed.
      raise ValueError(
          f"input of spatial size {tuple(x.shape[-2:])} yields {n_patches} patches, "
          f"but this PatchEmbedding was built for {self.num_patches}"
      )
    x = self.proj(patches)  # (B, N, D)
    cls_tokens = self.cls_token.expand(B, -1, -1)  # same CLS vector for every sample
    x = torch.cat([cls_tokens, x], dim=1)  # (B, N + 1, D)
    return x + self.pos_embed

In [44]:
class LayerNorm(nn.Module):
  """Layer normalization over the last dimension with learnable scale and shift.

  Computes gamma * (x - mean) / sqrt(var + eps) + beta, where mean and the
  biased (population) variance are taken over the last axis. This is the
  standard definition, matching torch.nn.functional.layer_norm.

  Args:
    d_model: size of the last dimension being normalized.
    eps: added to the variance for numerical stability.
  """
  def __init__(self, d_model, eps=1e-6):
    super().__init__()
    self.gamma = nn.Parameter(torch.ones(d_model))
    self.beta = nn.Parameter(torch.zeros(d_model))
    self.eps = eps

  def forward(self, x):
    mean = x.mean(-1, keepdim=True)  # keepdim so it broadcasts against x
    centered = x - mean
    # Biased variance (divide by N). The unbiased std divides by N - 1, which
    # deviates from standard LayerNorm and yields NaN when N == 1.
    var = (centered ** 2).mean(-1, keepdim=True)
    return self.gamma * centered / torch.sqrt(var + self.eps) + self.beta

In [45]:
class FeedForward(nn.Module):
  """Position-wise MLP applied independently to every token.

  Pipeline: Linear(d_model -> d_ff) -> GELU -> Dropout -> Linear(d_ff -> d_model).
  The output has the same shape as the input.
  """
  def __init__(self, d_model, d_ff, dropout_rate):
    super().__init__()
    # Layers are created in this order so parameter initialization and ordering stay fixed.
    self.linear1 = nn.Linear(d_model, d_ff)   # expand
    self.linear2 = nn.Linear(d_ff, d_model)   # contract back to model width
    self.dropout = nn.Dropout(dropout_rate)   # applied to the hidden activations only

  def forward(self, x):
    for stage in (self.linear1, F.gelu, self.dropout, self.linear2):
      x = stage(x)
    return x

In [46]:
class MultiHeadAttention(nn.Module):
  """Multi-head self-attention with a single fused Q/K/V projection.

  Args:
    d_model: model width D; must be divisible by n_heads.
    n_heads: number of heads N_H; each head has width D_H = D // N_H.
    dropout_rate: dropout probability applied to the attention weights.

  Raises:
    ValueError: if d_model is not divisible by n_heads.
  """
  def __init__(self, d_model, n_heads, dropout_rate):
    super().__init__()
    if d_model % n_heads != 0:
      # Fail fast. Otherwise the per-head reshape in forward() fails with an opaque size error.
      raise ValueError(f"d_model ({d_model}) must be divisible by n_heads ({n_heads})")
    self.n_heads = n_heads
    self.d_head = d_model // n_heads
    self.W_qkv = nn.Linear(d_model, 3*d_model)  # one matmul produces q, k and v together
    self.dropout = nn.Dropout(dropout_rate)
    self.W_o = nn.Linear(d_model, d_model)
    self.d_model = d_model
    self.dropout_rate = dropout_rate

  def ScaledDotProductAttention(self, q, k, v, mask=None):
    """Return dropout(softmax(q @ k^T / sqrt(d_k))) @ v.

    q, k, v: (B, N_H, T, D_H).
    mask: optional tensor broadcastable to (B, N_H, T, T). Positions where
      mask == 0 receive -inf before the softmax. A query row that is fully
      masked produces NaN.
    Returns: (B, N_H, T, D_H).
    """
    d_k = k.size(-1)
    scores = q @ k.transpose(-2, -1) / math.sqrt(d_k)  # (B, N_H, T, T)
    if mask is not None:
      scores = scores.masked_fill(mask == 0, float('-inf'))
    # softmax over the key axis so each query's weights sum to 1
    attn = self.dropout(F.softmax(scores, dim=-1))
    return attn @ v

  def forward(self, x, mask=None):
    """Self-attention over x of shape (B, T, D); returns (B, T, D).

    mask: optional attention mask, broadcastable to (B, N_H, T, T); entries
      equal to 0 block attention. The default None gives full bidirectional
      attention, which is what the ViT uses.
    """
    B, T, D = x.shape
    q, k, v = torch.split(self.W_qkv(x), self.d_model, dim=-1)
    # (B, T, D) -> (B, N_H, T, D_H)
    q, k, v = (t.reshape(B, T, self.n_heads, self.d_head).transpose(1, 2) for t in (q, k, v))
    attn = self.ScaledDotProductAttention(q, k, v, mask)
    # merge heads: (B, N_H, T, D_H) -> (B, T, D)
    attn_concat = attn.transpose(1, 2).reshape(B, T, D)
    return self.W_o(attn_concat)

In [47]:
class TransformerBlock(nn.Module):
  """Pre-norm transformer encoder block.

  The block applies two residual sub-layers in sequence and preserves the input shape:
    h   = x + MHA(LN1(x))
    out = h + MLP(LN2(h))
  """
  def __init__(self, d_model, n_heads, d_ff, dropout_rate):
    super().__init__()
    # Registration order is kept so parameter initialization and ordering are unchanged.
    self.mha = MultiHeadAttention(d_model, n_heads, dropout_rate)
    self.layernorm1 = LayerNorm(d_model)
    self.layernorm2 = LayerNorm(d_model)
    self.mlp = FeedForward(d_model, d_ff, dropout_rate)

  def forward(self, x):
    # attention sub-layer: normalize first, then add back onto the residual stream
    attended = x + self.mha(self.layernorm1(x))
    # feed-forward sub-layer, same pre-norm + residual pattern
    return attended + self.mlp(self.layernorm2(attended))

In [48]:
class ViT(nn.Module):
  """Vision Transformer image classifier.

  Pipeline: patch embedding (with CLS token and positional embedding) ->
  dropout -> n_layers pre-norm TransformerBlocks -> final LayerNorm ->
  linear head applied to the CLS token.

  Args:
    d_model: transformer width D.
    n_layers: number of TransformerBlocks.
    n_heads: attention heads per block.
    d_ff: hidden width of each block's MLP.
    dropout_rate: dropout probability. It is used inside the blocks and on the
      embeddings.
    patch_size: side length of the square image patches.
    num_classes: number of output logits.
    img_h, img_w, in_channels: expected input image geometry.

  forward(x) maps a (B, C, H, W) image batch to (B, num_classes) logits.
  """
  def __init__(self, d_model, n_layers, n_heads, d_ff, dropout_rate, patch_size, num_classes, img_h, img_w, in_channels): # Need classes for classification head
    super().__init__()

    self.patch_embedding = PatchEmbedding(patch_size, d_model, img_h, img_w, in_channels)

    self.transformer_blocks = nn.ModuleList([
        TransformerBlock(d_model, n_heads, d_ff, dropout_rate)
        for _ in range(n_layers)
    ])
    self.final_layernorm = LayerNorm(d_model)
    self.mlp_head = nn.Linear(d_model, num_classes)
    # Embedding dropout, applied in forward() right after the patch and position
    # embeddings are summed. It was previously constructed but never used.
    self.dropout = nn.Dropout(dropout_rate)

  def forward(self, x):
    x = self.dropout(self.patch_embedding(x))  # (B, N + 1, D)

    for transformer_block in self.transformer_blocks:
      x = transformer_block(x)

    x = self.final_layernorm(x)

    # classification head reads only the CLS token (position 0)
    cls_token = x[:, 0, :]
    return self.mlp_head(cls_token)

In [None]:
# ToTensor is the only transform; no normalization is applied.
transform = transforms.Compose([transforms.ToTensor()])

# MNIST train and test splits, downloaded into ./data if not already present.
train_dataset, test_dataset = (
    torchvision.datasets.MNIST(root='./data', train=split_is_train, download=True, transform=transform)
    for split_is_train in (True, False)
)

# Only the training split is shuffled; evaluation order stays fixed.
train_loader, test_loader = (
    DataLoader(dataset, batch_size=64, shuffle=split_is_train)
    for dataset, split_is_train in ((train_dataset, True), (test_dataset, False))
)

def get_config(dataloader):
  """Infer model-construction settings from a DataLoader.

  Pulls one (images, labels) batch to read the image geometry, and counts the
  entries of `dataloader.dataset.classes`.

  Returns:
    dict with keys "channels", "height", "width", "num_classes".
  """
  images, _labels = next(iter(dataloader))
  _batch, channels, height, width = images.shape  # expects (B, C, H, W)
  return dict(
      channels=channels,
      height=height,
      width=width,
      num_classes=len(dataloader.dataset.classes),
  )

# Seed PyTorch's global RNG before anything random happens. This covers weight
# init, the shuffled batch order and dropout masks, so Restart & Run All is
# repeatable. Some GPU kernels may still be nondeterministic.
SEED = 42
torch.manual_seed(SEED)

config = get_config(train_loader)

model = ViT(patch_size=4,
            d_model=192,
            n_layers=8,
            n_heads=12,          # 192 / 12 = 16-dim heads
            d_ff=768,            # 4 x d_model MLP expansion
            dropout_rate=0.1,
            num_classes=config["num_classes"],
            img_h=config["height"],
            img_w=config["width"],
            in_channels=config["channels"]).to(device)

# Training setup
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

epochs = 12
for epoch in range(epochs):
  # ---- Training ----
  model.train()
  train_loss = 0.0   # sum of per-sample losses over the epoch
  train_correct = 0
  train_total = 0

  for data, targets in train_loader:
    data, targets = data.to(device), targets.to(device)

    optimizer.zero_grad()
    outputs = model(data)
    loss = criterion(outputs, targets)
    loss.backward()
    optimizer.step()

    # Assumes criterion returns a batch mean (reduction='mean', the CrossEntropyLoss
    # default). Weighting by batch size stops a short final batch from skewing the
    # epoch average.
    batch_size = targets.size(0)
    train_loss += loss.item() * batch_size
    train_total += batch_size
    train_correct += (outputs.argmax(dim=1) == targets).sum().item()

  # ---- Validation ----
  model.eval()
  val_loss = 0.0
  val_correct = 0
  val_total = 0

  with torch.no_grad():
    for data, targets in test_loader:
      data, targets = data.to(device), targets.to(device)
      outputs = model(data)

      batch_size = targets.size(0)
      val_loss += criterion(outputs, targets).item() * batch_size
      val_total += batch_size
      val_correct += (outputs.argmax(dim=1) == targets).sum().item()

  # Per-sample averages over the whole epoch
  train_acc = 100. * train_correct / train_total
  val_acc = 100. * val_correct / val_total
  avg_train_loss = train_loss / train_total
  avg_val_loss = val_loss / val_total

  print(f'Epoch {epoch+1}:')
  print(f'  Train Loss: {avg_train_loss:.4f}, Train Acc: {train_acc:.2f}%')
  print(f'  Val Loss: {avg_val_loss:.4f}, Val Acc: {val_acc:.2f}%')
  print('=' * 50)
