## Vision Transformer Applied to an OCR Task

+ [ViT tutorial](https://www.datacamp.com/tutorial/vision-transformers)
+ [Multi-head attention](https://www.datacamp.com/tutorial/multi-head-attention-transformers)
+ [Transformer block in PyTorch](https://docs.pytorch.org/docs/2.9/generated/torch.nn.Transformer.html)
+ [Building a Vision Transformer Model From Scratch](https://medium.com/correll-lab/building-a-vision-transformer-model-from-scratch-a3054f707cc6)
+ [Understanding Multi-Head Attention in Transformers](https://www.datacamp.com/tutorial/multi-head-attention-transformers)
+ [Transformer Course from Stanford](https://www.youtube.com/watch?v=P127jhj-8-Y&list=PLoROMvodv4rNiJRchCzutFw5ItR_Z27CM)
+ [Paper: Attention Is All You Need](https://arxiv.org/abs/1706.03762)
+ [Paper: Vision Transformer](https://arxiv.org/abs/2010.11929)

## Load libraries

In [2]:
import torch
import torch.nn as nn
import torchvision.transforms as T
from torch.optim import Adam
from torchvision.datasets.mnist import MNIST
from torch.utils.data import DataLoader
import numpy as np




## Dataset

## Model

In [3]:
# Patch embeddings

class PatchEmbedding(nn.Module):
    """Split images into non-overlapping patches and project each to ``d_model``.

    The projection is an ``nn.Conv2d`` whose kernel size equals its stride.
    This is equivalent to flattening every patch and applying one shared
    linear layer.

    Args:
        d_model: Dimension of each output patch embedding.
        img_size: Expected image size ``(H, W)``; an int means a square image.
        patch_size: Patch size ``(Ph, Pw)``; an int means a square patch.
        n_channels: Number of input channels (3 for RGB, 1 for grayscale).

    Attributes:
        num_patches: Number of patches per image, ``(H // Ph) * (W // Pw)``.

    Raises:
        ValueError: If ``img_size`` is not divisible by ``patch_size``. Without
            this check the strided convolution would silently drop the
            border pixels that do not fill a whole patch.
    """

    def __init__(self, d_model, img_size, patch_size, n_channels):
        super().__init__()

        self.d_model = d_model          # Dimension of the model
        self.img_size = img_size        # Size (H, W)
        self.patch_size = patch_size    # (Ph, Pw)
        self.n_channels = n_channels    # 3 for RGB, 1 for grayscale

        img_h, img_w = self._as_pair(img_size)
        patch_h, patch_w = self._as_pair(patch_size)
        if img_h % patch_h or img_w % patch_w:
            raise ValueError(
                f"Image size {(img_h, img_w)} must be divisible by patch size {(patch_h, patch_w)}"
            )
        self.num_patches = (img_h // patch_h) * (img_w // patch_w)

        # kernel_size == stride -> each output location sees exactly one patch.
        self.linear_project = nn.Conv2d(
            self.n_channels, self.d_model,
            kernel_size=self.patch_size, stride=self.patch_size,
        )

    @staticmethod
    def _as_pair(value):
        """Return ``value`` as an ``(h, w)`` tuple; an int is used for both sides."""
        if isinstance(value, int):
            return value, value
        h, w = value
        return h, w

    def forward(self, x):
        """Embed a batch of images.

        Args:
            x: Image batch of shape ``(B, C, H, W)`` with ``C == n_channels``.
                NOTE(review): the spatial size is not checked against
                ``img_size`` here.

        Returns:
            Tensor of shape ``(B, P, d_model)``, where
            ``P = (H // Ph) * (W // Pw)``. Patches are ordered row by row over
            the patch grid.
        """
        x = self.linear_project(x)  # (B, C, H, W) -> (B, d_model, P_rows, P_cols)
        x = x.flatten(2)            # -> (B, d_model, P), P = P_rows * P_cols
        return x.transpose(1, 2)    # -> (B, P, d_model)
        
        
# Example usage: 8 RGB images of 128x128 pixels, cut into 16x16 patches.
img = torch.randn(8, 3, 128, 128)  # (B, C, H, W)
patch_embed = PatchEmbedding(d_model=256, img_size=(128, 128), patch_size=(16, 16), n_channels=3)
patches = patch_embed(img)  # (B, P, d_model)
# (128 / 16) * (128 / 16) = 8 * 8 = 64 patches per image, each a 256-dim vector.
print(patches.shape)  # expected: torch.Size([8, 64, 256])
assert patches.shape == (8, 64, 256)
        

torch.Size([8, 64, 256])


In [4]:
class PatchEmbedding(nn.Module):
    """Patch embedding via ``unfold`` + ``nn.Linear``, plus a learnable position term.

    Cuts ``(B, C, H, W)`` images into non-overlapping ``patch x patch`` tiles and
    flattens each tile to ``C * patch * patch`` values. Each tile is projected to
    ``d_model`` with one shared ``nn.Linear``, and one learnable vector per
    patch position is added.

    Args:
        img_height: Image height ``H``; must be divisible by ``patch``.
        img_width: Image width ``W``; must be divisible by ``patch``.
        patch: Side length of a square patch.
        d_model: Dimension of each output embedding.
        in_channels: Number of image channels (3 for RGB, 1 for grayscale).
            Defaults to 3, matching the previously hard-coded value.

    Attributes:
        num_patches: ``(img_height // patch) * (img_width // patch)``.

    Raises:
        ValueError: If ``img_height`` or ``img_width`` is not divisible by ``patch``.
    """

    def __init__(self, img_height, img_width, patch=16, d_model=256, in_channels=3):
        super().__init__()
        # Raise instead of assert so the check survives `python -O`.
        if img_height % patch or img_width % patch:
            raise ValueError(f"Image H/W must be divisible by patch size {patch}")
        self.patch = patch
        self.in_channels = in_channels
        self.num_patches = (img_height // patch) * (img_width // patch)
        self.d_model = d_model
        self.proj = nn.Linear(in_channels * patch * patch, d_model)
        # One learnable vector per patch position, broadcast over the batch.
        # NOTE(review): std-1 `randn` init; ViT implementations typically use a
        # much smaller std (e.g. 0.02) -- consider before training.
        self.pos_embed = nn.Parameter(torch.randn(1, self.num_patches, d_model))

    def forward(self, x):
        """Embed a batch of images.

        Args:
            x: Image batch of shape ``(B, in_channels, H, W)``. ``H`` and ``W``
                must produce exactly ``num_patches`` patches; otherwise adding
                ``pos_embed`` fails with a shape error.

        Returns:
            Tensor of shape ``(B, num_patches, d_model)``, patches in row-major
            order over the patch grid.

        Raises:
            ValueError: If the channel count of ``x`` is not ``in_channels``.
        """
        B, C, H, W = x.shape
        if C != self.in_channels:
            raise ValueError(f"Expected {self.in_channels} input channels, got {C}")
        p = self.patch

        # unfold(dim, size, step) with size == step tiles without overlap.
        patches = x.unfold(2, p, p).unfold(3, p, p)  # (B, C, Hp, Wp, p, p)
        # Move channels next to the pixel axes so each patch flattens as (C, p, p).
        patches = patches.permute(0, 2, 3, 1, 4, 5).reshape(B, -1, C * p * p)  # (B, Np, C*p*p)
        patches = self.proj(patches)  # (B, Np, d_model)
        return patches + self.pos_embed  # add positional embedding
    
    
# Example usage: 8 RGB images of 128x128 pixels, cut into 16x16 patches.
img = torch.randn(8, 3, 128, 128)  # (B, C, H, W)
patch_embed = PatchEmbedding(img_height=128, img_width=128, patch=16, d_model=256)
patches = patch_embed(img)  # (B, P, d_model)
# (128 / 16) * (128 / 16) = 8 * 8 = 64 patches per image, each a 256-dim vector.
print(patches.shape)  # expected: torch.Size([8, 64, 256])
assert patches.shape == (8, 64, 256)

torch.Size([8, 64, 256])


## Training

## Inference