<a href="https://colab.research.google.com/github/SonnetSaif/vision-transformer-from-scratch_PyTorch/blob/main/vision_transformer_PyTorch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [16]:
!pip install einops

Collecting einops
  Downloading einops-0.7.0-py3-none-any.whl (44 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/44.6 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m44.6/44.6 kB[0m [31m1.3 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: einops
Successfully installed einops-0.7.0


In [2]:
import torch
import torch.nn as nn
from torchvision.datasets import OxfordIIITPet
from torchvision.transforms import Resize, ToTensor

In [3]:
to_tensor = [Resize((144, 144)), ToTensor()]

class Compose(object):
    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, image, target):
        for t in self.transforms:
            image = t(image)
        return image, target

In [4]:
dataset = OxfordIIITPet(root=".", download=True, transforms=Compose(to_tensor))

Downloading https://thor.robots.ox.ac.uk/datasets/pets/images.tar.gz to oxford-iiit-pet/images.tar.gz


100%|██████████| 791918971/791918971 [00:20<00:00, 37869133.34it/s]


Extracting oxford-iiit-pet/images.tar.gz to oxford-iiit-pet
Downloading https://thor.robots.ox.ac.uk/datasets/pets/annotations.tar.gz to oxford-iiit-pet/annotations.tar.gz


100%|██████████| 19173078/19173078 [00:01<00:00, 16389472.73it/s]


Extracting oxford-iiit-pet/annotations.tar.gz to oxford-iiit-pet


# Input Embedding

In [5]:
class PatchEmbedding(nn.Module):
  def __init__(self, in_channels, patch_size, embed_dim):
    super().__init__()
    self.patch = nn.Conv2d(
        in_channels=in_channels,
        out_channels=embed_dim,
        kernel_size=patch_size,
        stride=patch_size
    )

  def forward(self, x):
    # Flatten along dim = 2 to maintain channel dimension.
    patches = self.patch(x)
    print("after patch", patches.shape)
    patches = patches.flatten(2)
    print("after flatten", patches.shape)
    patches = patches.transpose(1, 2)
    print("after transpose", patches.shape)
    return patches

In [6]:
first = dataset[0][0]
print("Initial shape:", first.shape)
sample_datapoint = torch.unsqueeze(dataset[0][0], 0)
print("after unsqueeze: ", sample_datapoint.shape)
embedding = PatchEmbedding(3, 8, 128)(sample_datapoint)
print("final Patches shape: ", embedding.shape)

Initial shape: torch.Size([3, 144, 144])
after unsqueeze:  torch.Size([1, 3, 144, 144])
after patch torch.Size([1, 128, 18, 18])
after flatten torch.Size([1, 128, 324])
after transpose torch.Size([1, 324, 128])
final Patches shape:  torch.Size([1, 324, 128])


# Multi-Head Attention

Attention

In [7]:
class AttentionBlock(nn.Module):
  def __init__(self, embed_dim, n_heads, dropout):
    super().__init__()
    self.n_heads = n_heads
    self.attention = torch.nn.MultiheadAttention(embed_dim = embed_dim,
                                                 num_heads = n_heads,
                                                 dropout = dropout)

    self.q = nn.Linear(embed_dim, embed_dim)
    self.k = nn.Linear(embed_dim, embed_dim)
    self.v = nn.Linear(embed_dim, embed_dim)

  def forward(self, x):
    q = self.q(x)
    k = self.q(x)
    v = self.q(x)
    attention_output, attention_output_weights = self.attention(q, k, v)
    return attention_output

In [8]:
AttentionBlock(embed_dim=128, n_heads=4, dropout=0.)
# AttentionBlock(128, 4, 0.)

AttentionBlock(
  (attention): MultiheadAttention(
    (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True)
  )
  (q): Linear(in_features=128, out_features=128, bias=True)
  (k): Linear(in_features=128, out_features=128, bias=True)
  (v): Linear(in_features=128, out_features=128, bias=True)
)

Layer Normalization

In [9]:
class LayerNorm(nn.Module):
  def __init__(self, embed_dim, fn):
    super().__init__()
    self.norm = nn.LayerNorm(embed_dim)
    self.fn = fn

  def forward(self, x):
    return self.fn(self.norm(x))

In [10]:
norm = LayerNorm(128, AttentionBlock(embed_dim=128, n_heads=4, dropout=0.))

Feed Forward

In [11]:
class FeedForward(nn.Module):
  def __init__(self, embed_dim, hidden_dim, dropout):
    super().__init__()
    self.feedForward = nn.Sequential(
        nn.Linear(embed_dim, hidden_dim),
        nn.GELU(),
        nn.Linear(hidden_dim, embed_dim),
        nn.Dropout(dropout)
    )

  def forward(self, x):
    return self.feedForward(x)

In [12]:
feedForward = FeedForward(embed_dim=128, hidden_dim=256, dropout=0.)

Residuals

In [13]:
class Residuals(nn.Module):
  def __init__(self, fn):
    super().__init__()
    self.fn = fn

  def forward(self, x):
    res = x
    x = self.fn(x)
    x += res
    return x

In [14]:
residual = Residuals(AttentionBlock(embed_dim=128, n_heads=4, dropout=0.))

In [20]:
embed_dim = 128
n_head = 4
dropout = 0.1
patch_size = 4
in_channels = 3
img_size = 144
n_layers = 6
out_dim = 37

In [23]:
from einops import repeat

class ViT(nn.Module):
    def __init__(self, ch=in_channels, img_size=img_size, patch_size=patch_size, emb_dim=embed_dim,
                n_layers=n_layers, out_dim=out_dim, dropout=dropout, heads=n_head):
        super(ViT, self).__init__()

        # Attributes
        self.channels = ch
        self.height = img_size
        self.width = img_size
        self.patch_size = patch_size
        self.n_layers = n_layers

        # Patching
        self.patch_embedding = PatchEmbedding(in_channels=ch,
                                              patch_size=patch_size,
                                              embed_dim=emb_dim)
        # Learnable params
        num_patches = (img_size // patch_size) ** 2
        self.pos_embedding = nn.Parameter(
            torch.randn(1, num_patches + 1, emb_dim))
        self.cls_token = nn.Parameter(torch.rand(1, 1, emb_dim))

        # Transformer Encoder
        self.layers = nn.ModuleList([])
        for _ in range(n_layers):
            transformer_block = nn.Sequential(
                Residuals(LayerNorm(emb_dim, AttentionBlock(emb_dim, n_heads = heads, dropout = dropout))),
                Residuals(LayerNorm(emb_dim, FeedForward(emb_dim, emb_dim, dropout = dropout))))
            self.layers.append(transformer_block)

        # Classification head
        self.head = nn.Sequential(nn.LayerNorm(emb_dim), nn.Linear(emb_dim, out_dim))


    def forward(self, img):
        # Get patch embedding vectors
        x = self.patch_embedding(img)
        b, n, _ = x.shape

        # Add cls token to inputs
        cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b = b)
        x = torch.cat([cls_tokens, x], dim=1)
        x += self.pos_embedding[:, :(n + 1)]

        # Transformer layers
        for i in range(self.n_layers):
            x = self.layers[i](x)

        # Output based on classification token
        return self.head(x[:, 0, :])

In [24]:
model = ViT()
print(model)
model(torch.ones((1, 3, 144, 144)))

ViT(
  (patch_embedding): PatchEmbedding(
    (patch): Conv2d(3, 32, kernel_size=(4, 4), stride=(4, 4))
  )
  (layers): ModuleList(
    (0-5): 6 x Sequential(
      (0): Residuals(
        (fn): LayerNorm(
          (norm): LayerNorm((32,), eps=1e-05, elementwise_affine=True)
          (fn): AttentionBlock(
            (attention): MultiheadAttention(
              (out_proj): NonDynamicallyQuantizableLinear(in_features=32, out_features=32, bias=True)
            )
            (q): Linear(in_features=32, out_features=32, bias=True)
            (k): Linear(in_features=32, out_features=32, bias=True)
            (v): Linear(in_features=32, out_features=32, bias=True)
          )
        )
      )
      (1): Residuals(
        (fn): LayerNorm(
          (norm): LayerNorm((32,), eps=1e-05, elementwise_affine=True)
          (fn): FeedForward(
            (feedForward): Sequential(
              (0): Linear(in_features=32, out_features=32, bias=True)
              (1): GELU(approximate='non

tensor([[ 0.2854, -0.1852,  0.2888,  0.1892, -0.4187,  0.1629, -0.6740, -0.4691,
          0.3448, -0.1879, -0.4464,  0.3888,  0.8779, -0.3254,  0.7555, -0.5407,
         -0.0821, -0.3850, -0.0379, -0.7557,  0.5837, -0.0823,  0.4275, -0.8089,
         -0.3892, -0.3435,  0.2637, -0.3229, -0.2170,  0.5829, -0.9651, -0.5030,
         -0.0632, -0.9391,  0.4440, -0.1836, -0.4092]],
       grad_fn=<AddmmBackward0>)