## VISION TRANSFORMERS

In [121]:
# Import libraries
import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import os, sys
import einops
import math
from einops import rearrange, reduce, repeat

# Pick an accelerator: prefer CUDA, then Apple's MPS backend, else fall back to CPU.
# Both GPU backends print the same banner so the cell output format is unchanged.
if torch.cuda.is_available():
    device = torch.device("cuda")
    print(f"Using PyTorch version {torch.__version__} with GPU...")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
    print(f"Using PyTorch version {torch.__version__} with GPU...")
else:
    device = torch.device("cpu")
    print("Using CPU")

Using PyTorch version 2.3.0 with GPU...


In [134]:
# Shared hyperparameters used by the patch-embedding cells below.
channels: int = 3     # number of input image channels
patch_size: int = 2   # side length, in pixels, of each square patch
embed_dim: int = 512  # size of each patch embedding vector

class PatchEmbedding(nn.Module):
    """Split images into non-overlapping square patches and linearly embed each one.

    Args:
        patch_size: side length (pixels) of each square patch.
        channels: number of input channels the module expects.
        embed_dim: size of the output embedding for every patch.
    """

    def __init__(self, patch_size, channels, embed_dim):
        super().__init__()
        self.patch_size = patch_size
        self.channels = channels
        self.embed_dim = embed_dim
        # A flattened patch holds patch_size * patch_size * channels values.
        self.linear = nn.Linear(self.patch_size**2 * self.channels, self.embed_dim)

    def forward(self, data):
        """Embed a channel-first image batch.

        Args:
            data: tensor of shape (batch, channels, height, width); height and
                width must both be multiples of ``patch_size``.

        Returns:
            Tensor of shape (batch, num_patches, embed_dim), where
            num_patches = (height // patch_size) * (width // patch_size) and
            patches are ordered row by row over the patch grid.

        Raises:
            ValueError: if ``data`` is not 4-D, has a channel count different from
                ``channels``, or has a spatial size that is not a multiple of
                ``patch_size``.
        """
        # Validate up front so mistakes surface as a clear message rather than an
        # opaque reshape / matmul shape error further down.
        if data.ndim != 4:
            raise ValueError(
                "expected a 4-D (batch, channels, height, width) input, "
                f"got shape {tuple(data.shape)}"
            )
        _, num_channels, height, width = data.shape
        if num_channels != self.channels:
            raise ValueError(f"expected {self.channels} channels, got {num_channels}")
        if height % self.patch_size or width % self.patch_size:
            raise ValueError(
                f"spatial size {height}x{width} is not divisible by "
                f"patch_size={self.patch_size}"
            )

        # (b, c, H, W) -> (b, num_patches, p*p*c)
        patches = rearrange(data, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)',
                            p1=self.patch_size, p2=self.patch_size)
        return self.linear(patches)



class PositionalEmbedding(nn.Module):
    """Add a learned positional embedding to a token sequence.

    The table holds ``N + 1`` positions (presumably room for ``N`` patch tokens
    plus one extra token such as a class token). Inputs may have any sequence
    length up to ``N + 1``; only the leading positions are added.

    Args:
        N: number of patch tokens.
        embed_dim: embedding size; must match the last axis of the input.
    """

    def __init__(self, N, embed_dim):
        super().__init__()
        self.pos_embed = nn.Parameter(torch.randn(1, N + 1, embed_dim))

    def forward(self, x):
        """Return ``x`` plus the embeddings for its first ``x.shape[1]`` positions.

        Args:
            x: tensor of shape (batch, seq_len, embed_dim) with seq_len <= N + 1.

        Raises:
            ValueError: if ``seq_len`` exceeds the number of learned positions.
        """
        seq_len = x.shape[1]
        max_len = self.pos_embed.shape[1]
        if seq_len > max_len:
            raise ValueError(
                f"sequence length {seq_len} exceeds the {max_len} learned positions"
            )
        # Slice so a plain patch sequence (N tokens, no extra token) works too:
        # adding the full (1, N + 1, D) table to an N-token input cannot broadcast.
        return x + self.pos_embed[:, :seq_len]

In [137]:
# Toy image batch of (batch=2, channels=3, height=10, width=10) holding 1..600.
X = torch.Tensor(np.arange(1, 601).reshape(2, 3, 10, 10))
# Patches per image: pixel count divided by pixels per square patch.
N = X.shape[-2] * X.shape[-1] // patch_size**2
X.shape

torch.Size([2, 3, 10, 10])

In [138]:
# Build the modules from the shared constants instead of repeating 512 / 3, so
# editing the constants cell keeps every cell consistent.
patch_embed = PatchEmbedding(patch_size=patch_size, embed_dim=embed_dim, channels=channels)
pe = PositionalEmbedding(N=N, embed_dim=embed_dim)
out = patch_embed(X)
out.shape

torch.Size([2, 25, 512])

In [None]:
class Attention(nn.Module):
    """Multi-head self-attention over a token sequence.

    Args:
        embed_dim: token embedding size (must be divisible by ``num_heads``).
        num_heads: number of attention heads.
        dropout: dropout probability on the attention weights (default 0.0).
        batch_first: if True (default), inputs are (batch, seq_len, embed_dim),
            the layout produced by the patch-embedding step.
    """

    def __init__(self, embed_dim, num_heads, dropout=0.0, batch_first=True):
        super().__init__()
        # batch_first matters: nn.MultiheadAttention otherwise expects
        # (seq_len, batch, embed_dim) and would mix tokens across the batch axis.
        self.attn = nn.MultiheadAttention(embed_dim=embed_dim, num_heads=num_heads,
                                          dropout=dropout, batch_first=batch_first)

        # NOTE: nn.MultiheadAttention already has its own learned Q/K/V input
        # projections, so these add a second projection. They are kept so the
        # module's parameter layout (and any saved state_dict) stays unchanged.
        self.W_q = nn.Linear(embed_dim, embed_dim)
        self.W_k = nn.Linear(embed_dim, embed_dim)
        self.W_v = nn.Linear(embed_dim, embed_dim)

    def forward(self, x):
        """Self-attend over ``x``.

        Returns:
            (attn_out, attn_weights): ``attn_out`` has the same shape as ``x``;
            ``attn_weights`` is (batch, seq_len, seq_len), averaged over heads.
        """
        q = self.W_q(x)
        k = self.W_k(x)
        v = self.W_v(x)

        attn_out, attn_weights = self.attn(q, k, v)
        return attn_out, attn_weights

class EncoderBlock(nn.Module):
    """
    Defines a post-norm Encoder block as: Multihead + (Add & Norm) + FFN + (Add & Norm)

    Args:
        d_model: token embedding size (last axis of the input).
        expansion_factor: feed-forward hidden width as a multiple of ``d_model``.
        num_heads: number of attention heads.
        dropout: residual dropout applied to each sub-layer output (default 0.2).
        attn_dropout: dropout on the attention weights (default 0.0).
    """
    def __init__(self, d_model, expansion_factor, num_heads, dropout=0.2, attn_dropout=0.0):
        super().__init__()

        # Pass the attention dropout explicitly rather than relying on a default.
        self.attention = Attention(num_heads=num_heads, embed_dim=d_model, dropout=attn_dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

        self.feed_forward = nn.Sequential(
            nn.Linear(d_model, expansion_factor*d_model),
            nn.GELU(),
            nn.Linear(expansion_factor*d_model, d_model)
        )
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, X):
        """Apply attention and feed-forward sub-layers, each with Add & Norm.

        Args:
            X: token tensor whose last axis is ``d_model``.

        Returns:
            Tensor with the same shape as ``X``.
        """
        # output from attention
        attn_out, _ = self.attention(X)
        assert attn_out.shape == X.shape, "Tensor size do not match!"

        # Residual dropout goes on the sub-layer output *before* the add, so the
        # skip path itself is never dropped (Vaswani et al., 2017, Sec. 5.4).
        # In eval mode this is identical to the previous formulation.
        norm1_out = self.norm1(X + self.dropout1(attn_out))

        # Feed-forward sub-layer with the same Add & Norm pattern.
        feedfwd_out = self.feed_forward(norm1_out)
        norm2_out = self.norm2(norm1_out + self.dropout2(feedfwd_out))

        return norm2_out

## Transformer Encoder
class TransformerEncoder(nn.Module):
    """Patch embedding, learned positional embedding, then a stack of EncoderBlocks.

    Args:
        d_model: embedding size used by every layer.
        num_layers: number of stacked EncoderBlocks.
        expansion_factor: feed-forward expansion factor for each block.
        num_heads: attention heads per block.
        channels: input image channels (default 3).
        num_patches: patches per image; falls back to the notebook-level global
            ``N`` when omitted.

    The patch size comes from the notebook-level global ``patch_size``.
    """
    def __init__(self, d_model, num_layers=6, expansion_factor=4, num_heads=8,
                 channels=3, num_patches=None):
        super().__init__()
        if num_patches is None:
            num_patches = N  # notebook-level global, as before

        self.embedding_layer = PatchEmbedding(patch_size=patch_size, channels=channels, embed_dim=d_model)
        # Size the positional table from d_model so it matches the patch embeddings
        # for any model width (a fixed 512 only worked when d_model == 512).
        self.positional_encoder = PositionalEmbedding(embed_dim=d_model, N=num_patches)
        self.layers = nn.ModuleList(
            EncoderBlock(d_model=d_model, expansion_factor=expansion_factor, num_heads=num_heads)
            for _ in range(num_layers)
        )

    def forward(self, X):
        """Encode a channel-first image batch into a sequence of ``d_model`` tokens."""
        out = self.positional_encoder(self.embedding_layer(X))
        for layer in self.layers:
            out = layer(out)
        return out


In [115]:
# Attention's constructor takes `num_heads` (not `n_heads`); the embedding size
# is read from `out` so it always matches the patch embeddings it attends over.
attn = Attention(embed_dim=out.shape[-1], num_heads=8, dropout=0.1)
attn_out, attn_weights = attn(out)