# Vision Transformer Implementation

In [5]:
# imports
import torch
from torch import nn

from tqdm.auto import tqdm 

from einops import rearrange
from einops.layers.torch import Rearrange

In [7]:
class MLP(nn.Module):
    """Pre-norm feed-forward block: LayerNorm -> Linear -> GELU -> Linear.

    Args:
        input_dim: size of the last dimension of the input.
        hidden_dim: width of the hidden layer.
        output_dim: size of the last dimension of the output.
        device: device on which the parameters are created.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, device):
        super().__init__()

        self.layer = nn.Sequential(
            nn.LayerNorm(input_dim, device=device),
            nn.Linear(input_dim, hidden_dim, device=device),
            nn.GELU(),
            nn.Linear(hidden_dim, output_dim, device=device),
        )

    def forward(self, x):
        """Apply the block to the last dimension of ``x``; leading dims are preserved."""
        return self.layer(x)

In [None]:
# Split an image batch into non-overlapping square patches, one flattened row per patch.
def reshape_for_vit(sample_to_reshape, patch_size):
    """Turn images into a sequence of flattened square patches.

    Args:
        sample_to_reshape: image batch of shape (batch, channels, height, width).
        patch_size: side length of each square patch.

    Returns:
        Tensor of shape (batch, num_patches, patch_size * patch_size * channels).
        Patches are ordered row by row across the image. Inside a patch,
        values are ordered (patch_row, patch_col, channel).

    Raises:
        ValueError: if height or width is not divisible by ``patch_size``.
    """
    b, c, h, w = sample_to_reshape.shape
    if h % patch_size or w % patch_size:
        raise ValueError("Height and Width must be divisible by patch size")

    rows, cols = h // patch_size, w // patch_size
    # A bare reshape would only regroup consecutive memory (pixels of one row of
    # one channel). Expose the patch grid explicitly first, then move the
    # in-patch axes and the channel axis to the end before flattening.
    patches = sample_to_reshape.reshape(b, c, rows, patch_size, cols, patch_size)
    patches = patches.permute(0, 2, 4, 3, 5, 1)  # (b, rows, cols, p, p, c)
    return patches.reshape(b, rows * cols, patch_size * patch_size * c)

In [None]:
class MSA(nn.Module):
    """Multi-head self-attention over a batch of token sequences.

    Args:
        dim: embedding width of each token; must be divisible by ``num_heads``.
        num_heads: number of attention heads.
        device: device on which parameters are created and inputs are placed.
        batch_first: if True (default), inputs are (batch, seq, dim); pass
            False for (seq, batch, dim) inputs.
    """

    def __init__(self, dim, num_heads, device, batch_first=True):
        super().__init__()
        self.device = device
        # nn.MultiheadAttention defaults to batch_first=False, i.e. (seq, batch, dim).
        # This model passes (batch, num_patches, dim) tensors. Without the flag,
        # attention would mix tokens from different images in the batch.
        self.mha = nn.MultiheadAttention(
            dim, num_heads, device=device, batch_first=batch_first
        )

    def forward(self, x):
        x = x.to(self.device)
        # Self-attention: query, key and value are all x. The attention
        # weights are discarded, so don't ask the module to compute them.
        out, _ = self.mha(x, x, x, need_weights=False)
        return out

In [None]:
class Transformer(nn.Module):
    """Stack of pre-norm Transformer encoder blocks.

    Each block has its own parameters and computes
        x = x + MSA(LayerNorm(x))
        x = x + MLP(LayerNorm(x))

    Args:
        dim: token embedding width D.
        hidden_dim: hidden width of each block's MLP.
        num_heads: number of attention heads per block.
        layers: number of encoder blocks.
        device: device on which parameters are created.
    """

    def __init__(self, dim, hidden_dim, num_heads, layers, device):
        super().__init__()

        # Build a fresh set of modules per block. Repeating one module instance
        # would make every depth share the same weights.
        self.layers = nn.ModuleList(
            nn.ModuleList([
                nn.LayerNorm(dim, device=device),
                MSA(dim, num_heads, device),
                # NOTE(review): MLP as defined in this notebook starts with its
                # own LayerNorm, so the MLP input is normalized twice; consider
                # dropping one of the two.
                nn.LayerNorm(dim, device=device),
                MLP(dim, hidden_dim, dim, device),
            ])
            for _ in range(layers)
        )

    def forward(self, x):
        # Residual additions require every sub-layer to preserve x's shape.
        for attn_norm, attn, mlp_norm, mlp in self.layers:
            x = x + attn(attn_norm(x))
            x = x + mlp(mlp_norm(x))
        return x
    

In [None]:
class ViT(nn.Module):
    """Vision Transformer encoder: patch embedding followed by a Transformer stack.

    NOTE(review): this is still a partial ViT. There is no class token,
    positional embedding, or classification head yet, so ``forward`` returns
    one embedding per patch.

    Args:
        seq_len: length of one flattened patch vector, i.e.
            ``patch_size * patch_size * channels``; the channel count is
            recovered from it.
        patch_size: side length of the square patches.
        dim: embedding width D of every token.
        hidden_dim: hidden width of each Transformer MLP.
        num_heads: number of attention heads.
        num_layers: number of Transformer blocks.
        device: device on which parameters are created and inputs are placed.

    Raises:
        ValueError: if ``seq_len`` is not a multiple of ``patch_size ** 2``.
    """

    def __init__(self, seq_len, patch_size, dim, hidden_dim, num_heads, num_layers, device):
        super().__init__()

        self.device = device
        patch_area = patch_size * patch_size
        if seq_len % patch_area:
            raise ValueError("seq_len must equal patch_size * patch_size * channels")
        channels = seq_len // patch_area
        patch_dim = patch_area * channels

        # Patch embedding projection E from the paper. Cut the image into
        # non-overlapping patch_size x patch_size patches and flatten each in
        # (p1 p2 c) order. Then normalize, project to dim, and normalize again.
        self.to_patch_embedding = nn.Sequential(
            Rearrange(
                'b c (h p1) (w p2) -> b (h w) (p1 p2 c)',
                p1=patch_size,
                p2=patch_size,
            ),
            nn.LayerNorm(patch_dim, device=device),
            nn.Linear(patch_dim, dim, device=device),
            nn.LayerNorm(dim, device=device),
        )

        self.transformer = Transformer(dim, hidden_dim, num_heads, num_layers, device)

    def forward(self, img):
        """Embed a batch of images into patch tokens and encode them.

        Args:
            img: image batch of shape (batch, channels, height, width); height
                and width must be divisible by ``patch_size``.

        Returns:
            Tensor of shape (batch, num_patches, dim).
        """
        tokens = self.to_patch_embedding(img.to(self.device))
        return self.transformer(tokens)

       

    

torch.Size([1, 196, 768])
