<a href="https://colab.research.google.com/github/atiyehghm/DeepLearningSelfStudy/blob/main/Vision_transformer_from_scratch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Vision Transformer

This notebook implements a Vision Transformer (ViT) from scratch and applies it to the MNIST dataset.

[The main article on Medium](https://medium.com/@brianpulfer/vision-transformers-from-scratch-pytorch-a-step-by-step-guide-96c3313c2e0c)

In [1]:
# Setup: %pip (unlike !pip) installs into the environment of the running kernel.
%pip install -q tqdm



In [2]:
import numpy as np
from tqdm import tqdm, trange

import torch
import torch.nn as nn
from torch.optim import Adam
from torch.nn import CrossEntropyLoss
from torch.utils.data import DataLoader

from torchvision.transforms import ToTensor
from torchvision.datasets.mnist import MNIST

# Seed every RNG this notebook uses so that Restart & Run All reproduces the same numbers.
SEED = 0
np.random.seed(SEED)
# The trailing ';' keeps the returned torch Generator repr out of the cell output.
torch.manual_seed(SEED);

<torch._C.Generator at 0x7815d0196ff0>

![picture](https://miro.medium.com/v2/resize:fit:1400/format:webp/1*tA7xE2dQA_dfzA0Bub5TVw.png)

In [6]:
## Step 1: patchify the images: (N, C, H, W) ---> (N, number of patches, patch dimensionality)
## In our example, each 28x28 image is split into 7x7 = 49 patches of 4x4 pixels (16 values per patch)
## TODO: look for more efficient ways to do this

def patchify(images, n_patches):
  """Split a batch of square images into flattened, non-overlapping patches.

  Args:
    images: Tensor (or array-like accepted by ``torch.as_tensor``) of shape
      (N, C, H, W) with H == W.
    n_patches: Number of patches along each spatial axis; must divide H evenly.

  Returns:
    A ``torch.Tensor`` of shape (N, n_patches**2, C * patch_size**2) with the
    same dtype and device as ``images``. Patches are ordered row-major over
    the patch grid and each patch is flattened in (C, row, col) order.

  Raises:
    AssertionError: If the images are not square, or H is not divisible by
      ``n_patches``.
  """
  images = torch.as_tensor(images)
  n, c, h, w = images.shape

  assert h == w, "Patchify only works for square images"
  assert h % n_patches == 0, "Image size must be divisible by the number of patches"
  patch_size = h // n_patches

  # Split each spatial axis into (patch-grid index, offset inside the patch).
  grid = images.reshape(n, c, n_patches, patch_size, n_patches, patch_size)
  # Move both grid indices in front of the channel axis so that the values of
  # one patch become contiguous in (C, row, col) order.
  grid = grid.permute(0, 2, 4, 1, 3, 5)
  return grid.reshape(n, n_patches ** 2, c * patch_size ** 2)

In [11]:
## Testing the patchify
x = torch.randn(7, 1, 28, 28) # Dummy images: batch of 7, 1 channel, 28x28 pixels
patch_shape = tuple(patchify(x, 7).shape)
# 7 images -> 7*7 = 49 patches each -> 1*4*4 = 16 values per patch
assert patch_shape == (7, 49, 16), patch_shape
print(patch_shape)

(7, 49, 16)


In [9]:
'''
hidden_d: dimension of the output of linear mapping
'''

class VIT(nn.Module):
  """First stage of a Vision Transformer: patchify images and linearly embed each patch.

  Args:
    chw: Input image shape as (channels, height, width).
    n_patches: Number of patches along each spatial axis; must divide both
      the height and the width.
    hidden_d: Dimension of the output of the linear mapping (size of each token).

  Raises:
    AssertionError: If the height or width is not divisible by ``n_patches``.
  """

  def __init__(self, chw=(1, 28, 28), n_patches=7, hidden_d=8):
    super().__init__()

    # Validate before deriving sizes: a non-zero remainder means the patch grid
    # would not tile the image exactly.
    assert chw[1] % n_patches == 0, "Input is not divisible by the number of patches"
    assert chw[2] % n_patches == 0, "Input is not divisible by the number of patches"

    self.chw = chw
    self.n_patches = n_patches
    self.patch_size = (chw[1] // n_patches, chw[2] // n_patches)
    self.hidden_d = hidden_d

    ## Linear mapping: one flattened patch (C * patch_h * patch_w values) -> hidden_d
    self.input_d = int(chw[0] * self.patch_size[0] * self.patch_size[1])
    self.linear1 = nn.Linear(self.input_d, self.hidden_d)

  def forward(self, images):
    """Embed a batch of images.

    Args:
      images: Batch of shape (N, C, H, W) matching ``self.chw``.

    Returns:
      Tensor of shape (N, n_patches**2, hidden_d) holding one token per patch.
    """
    patches = patchify(images, self.n_patches)
    # patchify may hand back a NumPy float64 array; nn.Linear needs a tensor
    # whose dtype and device match its own weights.
    weight = self.linear1.weight
    patches = torch.as_tensor(patches, dtype=weight.dtype, device=weight.device)
    tokens = self.linear1(patches)
    return tokens