<a href="https://colab.research.google.com/github/SonnetSaif/vision-transformer-from-scratch_PyTorch/blob/main/vision_transformer_PyTorch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
import torch
import torch.nn as nn
from torchvision.datasets import OxfordIIITPet
from torchvision.transforms import Resize, ToTensor

In [3]:
# Per-image preprocessing steps, applied in order: resize to 144x144, then
# convert to a tensor.
to_tensor = [
    Resize((144, 144)),
    ToTensor(),
]

class Compose:
    """Chain image transforms while passing the target through untouched.

    Args:
        transforms: Iterable of callables; each takes an image and returns
            the transformed image.
    """

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, image, target):
        """Apply every transform to ``image`` in order.

        Returns:
            A ``(transformed_image, target)`` tuple; ``target`` is returned
            exactly as it was received.
        """
        transformed = image
        for transform in self.transforms:
            transformed = transform(transformed)
        return transformed, target

In [4]:
# Download (if needed) the Oxford-IIIT Pet data into the current directory and
# run every sample through the joint (image, target) preprocessing pipeline.
dataset = OxfordIIITPet(
    root=".",
    download=True,
    transforms=Compose(to_tensor),
)

Downloading https://thor.robots.ox.ac.uk/datasets/pets/images.tar.gz to oxford-iiit-pet/images.tar.gz


100%|██████████| 791918971/791918971 [00:18<00:00, 43224388.97it/s]


Extracting oxford-iiit-pet/images.tar.gz to oxford-iiit-pet
Downloading https://thor.robots.ox.ac.uk/datasets/pets/annotations.tar.gz to oxford-iiit-pet/annotations.tar.gz


100%|██████████| 19173078/19173078 [00:01<00:00, 18459871.77it/s]


Extracting oxford-iiit-pet/annotations.tar.gz to oxford-iiit-pet


# Input Embedding

In [5]:
class PatchEmbedding(nn.Module):
  """Cut an image batch into square patches and embed each patch as a vector.

  A convolution whose kernel size and stride both equal ``patch_size`` is used,
  so each output position is a learned projection of exactly one
  non-overlapping patch.

  Args:
    in_channels: Number of channels in the input images.
    patch_size: Side length (in pixels) of each square patch.
    embed_dim: Length of the embedding vector produced for each patch.
    verbose: If True, print the tensor shape after every step of ``forward``.
      Defaults to True so existing callers keep the shape walkthrough; pass
      False to silence it (e.g. inside a training loop).
  """

  def __init__(self, in_channels, patch_size, embed_dim, verbose=True):
    super().__init__()
    self.verbose = verbose
    self.patch = nn.Conv2d(
        in_channels=in_channels,
        out_channels=embed_dim,
        kernel_size=patch_size,
        stride=patch_size
    )

  def forward(self, x):
    """Embed a batch of images.

    Args:
      x: Image batch of shape (batch, in_channels, height, width).

    Returns:
      Tensor of shape (batch, num_patches, embed_dim).
    """
    patches = self.patch(x)
    if self.verbose:
      print("after patch", patches.shape)
    # Flatten from dim 2 onward: the two spatial axes collapse into a single
    # patch axis while the channel (embedding) dimension is kept.
    patches = patches.flatten(2)
    if self.verbose:
      print("after flatten", patches.shape)
    # Swap to (batch, num_patches, embed_dim) so each patch is one token.
    patches = patches.transpose(1, 2)
    if self.verbose:
      print("after transpose", patches.shape)
    return patches

In [6]:
# Push a single image through PatchEmbedding to check the tensor shapes.
first = dataset[0][0]  # samples are (image, target) pairs; keep the image
print("Initial shape:", first.shape)
# Add a leading batch axis. Reuse `first` instead of indexing the dataset a
# second time, which would load and transform the same image again.
sample_datapoint = torch.unsqueeze(first, 0)
print("after unsqueeze: ", sample_datapoint.shape)
embedding = PatchEmbedding(3, 8, 128)(sample_datapoint)
print("final Patches shape: ", embedding.shape)

Initial shape: torch.Size([3, 144, 144])
after unsqueeze:  torch.Size([1, 3, 144, 144])
after patch torch.Size([1, 128, 18, 18])
after flatten torch.Size([1, 128, 324])
after transpose torch.Size([1, 324, 128])
final Patches shape:  torch.Size([1, 324, 128])


# Multi-Head Attention

Attention

In [7]:
class AttentionBlock(nn.Module):
  """Multi-head self-attention with separate learned query/key/value projections.

  The input is projected three times (query, key, value) by independent linear
  layers and the results are fed to ``torch.nn.MultiheadAttention``.

  Args:
    embed_dim: Size of the last (feature) dimension of the input and output.
    n_heads: Number of attention heads.
    dropout: Dropout probability handed to ``torch.nn.MultiheadAttention``.
  """

  def __init__(self, embed_dim, n_heads, dropout):
    super().__init__()
    self.n_heads = n_heads
    self.attention = torch.nn.MultiheadAttention(embed_dim = embed_dim,
                                                 num_heads = n_heads,
                                                 dropout = dropout)

    self.q = nn.Linear(embed_dim, embed_dim)
    self.k = nn.Linear(embed_dim, embed_dim)
    self.v = nn.Linear(embed_dim, embed_dim)

  def forward(self, x):
    """Apply self-attention to ``x`` and return a tensor of the same shape.

    Only the attended values are returned; the attention weights are dropped.
    """
    # Each projection must use its own layer. Previously all three went through
    # self.q, which left self.k and self.v unused and untrained.
    q = self.q(x)
    k = self.k(x)
    v = self.v(x)
    # NOTE(review): self.attention is built with the default batch_first=False,
    # so dim 0 of x is treated as the sequence axis. PatchEmbedding in this
    # notebook emits (batch, patches, embed_dim) -- confirm whether
    # batch_first=True was intended.
    attention_output, _ = self.attention(q, k, v)
    return attention_output

In [8]:
# Build a block and leave it as the cell's last expression so the notebook
# displays its module structure.
AttentionBlock(
    embed_dim=128,
    n_heads=4,
    dropout=0.0,
)

AttentionBlock(
  (attention): MultiheadAttention(
    (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True)
  )
  (q): Linear(in_features=128, out_features=128, bias=True)
  (k): Linear(in_features=128, out_features=128, bias=True)
  (v): Linear(in_features=128, out_features=128, bias=True)
)

Layer Normalization

In [9]:
class LayerNorm(nn.Module):
  """Pre-normalization wrapper: layer-normalize the input, then call ``fn``.

  Args:
    embed_dim: Size of the last dimension, over which normalization is done.
    fn: Callable (typically another module) applied to the normalized tensor.
  """

  def __init__(self, embed_dim, fn):
    super().__init__()
    self.norm = nn.LayerNorm(embed_dim)
    self.fn = fn

  def forward(self, x):
    """Return ``fn`` evaluated on the layer-normalized ``x``."""
    normalized = self.norm(x)
    return self.fn(normalized)

In [10]:
# Wrap an attention block so its input is layer-normalized first.
norm = LayerNorm(
    128,
    AttentionBlock(embed_dim=128, n_heads=4, dropout=0.0),
)

Feed Forward

In [12]:
class FeedForward(nn.Module):
  """Two-layer MLP: Linear -> GELU -> Linear -> Dropout.

  Args:
    embed_dim: Size of the input and output feature dimension.
    hidden_dim: Size of the intermediate (expanded) feature dimension.
    dropout: Dropout probability applied after the second linear layer.
  """

  def __init__(self, embed_dim, hidden_dim, dropout):
    super().__init__()
    # Layers are created in the same order they are applied.
    expand = nn.Linear(embed_dim, hidden_dim)
    activation = nn.GELU()
    project = nn.Linear(hidden_dim, embed_dim)
    regularize = nn.Dropout(dropout)
    self.feedForward = nn.Sequential(expand, activation, project, regularize)

  def forward(self, x):
    """Run ``x`` through the MLP; the last dimension stays ``embed_dim``."""
    mlp = self.feedForward
    return mlp(x)

In [14]:
# Example MLP block: 128-dim features expanded to 256 and projected back.
feedForward = FeedForward(
    embed_dim=128,
    hidden_dim=256,
    dropout=0.0,
)

Residuals

In [15]:
class Residuals(nn.Module):
  """Residual (skip) connection: returns ``fn(x) + x``.

  Args:
    fn: Callable (typically another module) whose output has a shape that can
      be added to its input.
  """

  def __init__(self, fn):
    super().__init__()
    self.fn = fn

  def forward(self, x):
    """Apply ``fn`` to ``x`` and add the original ``x`` back."""
    # Out-of-place add on purpose. An in-place `+=` on fn's output would
    # overwrite the caller's tensor whenever fn returns its input unchanged
    # (e.g. nn.Identity), and raises for leaf tensors that require grad.
    return self.fn(x) + x

In [16]:
# Wrap an attention block in a skip connection.
residual = Residuals(
    AttentionBlock(embed_dim=128, n_heads=4, dropout=0.0),
)

In [17]:
# Hyperparameters for the model. The code that consumes them is not visible in
# this part of the notebook, so the notes below are hedged where needed.
embed_dim = 128  # same embedding size as the demo cells above
n_head = 4  # same head count as the AttentionBlock examples above
dropout = 0.1
patch_size = 4  # NOTE(review): the PatchEmbedding demo above used 8 -- confirm which is intended
in_channels = 3  # matches the 3-channel image tensor printed above
img_size = 144  # matches Resize((144, 144)) in the preprocessing list
n_layers = 6  # presumably the number of stacked transformer blocks -- confirm
out_dim = 37  # presumably the number of target classes -- confirm against the dataset