<a href="https://colab.research.google.com/github/SonnetSaif/vision-transformer-from-scratch_PyTorch/blob/main/vision_transformer_PyTorch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [28]:
import torch
import torch.nn as nn
from torchvision.datasets import OxfordIIITPet
from torchvision.transforms import Resize, ToTensor

In [29]:
# Image-only preprocessing steps, applied in list order: resize every picture
# to 144x144, then convert it to a tensor.
to_tensor = [
    Resize((144, 144)),
    ToTensor(),
]

class Compose:
    """Chain image transforms while passing the target through untouched.

    Instances are called with ``(image, target)``, which is the two-argument
    form this notebook hands to ``OxfordIIITPet`` via its ``transforms``
    argument.
    """

    def __init__(self, transforms):
        """Store the sequence of callables to apply to each image.

        Args:
            transforms: iterable of one-argument callables; each receives the
                output of the previous one.
        """
        self.transforms = transforms

    def __call__(self, image, target):
        """Run every transform on ``image`` in order.

        Args:
            image: the sample to transform.
            target: the label/annotation; returned exactly as given.

        Returns:
            A ``(transformed_image, target)`` tuple.
        """
        transformed = image
        for step in self.transforms:
            transformed = step(transformed)
        return transformed, target

In [11]:
# Fetch the Oxford-IIIT Pet data into the working directory (downloading it
# when requested via download=True) and attach the joint (image, target)
# preprocessing pipeline.
dataset = OxfordIIITPet(
    root=".",
    download=True,
    transforms=Compose(to_tensor),
)

# Input Embedding

In [30]:
class PatchEmbedding(nn.Module):
  """Split an image into non-overlapping patches and embed each one.

  A single convolution whose kernel size and stride both equal ``patch_size``
  produces one ``embed_dim``-sized vector per patch. For a batched
  ``(B, C, H, W)`` input the result is ``(B, num_patches, embed_dim)``.

  Args:
    in_channels: number of channels of the input image.
    patch_size: side length of a patch (used as both kernel size and stride).
    embed_dim: size of the embedding produced for every patch.
  """

  def __init__(self, in_channels, patch_size, embed_dim):
    super().__init__()
    # kernel_size == stride, so each output location sees exactly one patch.
    self.patch = nn.Conv2d(
        in_channels,
        embed_dim,
        kernel_size=patch_size,
        stride=patch_size,
    )

  def forward(self, x):
    """Embed ``x`` patch by patch, printing the shape after every stage.

    Returns:
      The patch embeddings with the patch axis placed before the embedding
      axis.
    """
    projected = self.patch(x)
    print("after patch", projected.shape)

    # Collapse everything from dim 2 onwards (the spatial grid) into a single
    # axis; the embedding/channel axis at dim 1 is left intact.
    flattened = projected.flatten(start_dim=2)
    print("after flatten", flattened.shape)

    # Swap dims 1 and 2 so each row is one patch's embedding vector.
    tokens = flattened.transpose(1, 2)
    print("after transpose", tokens.shape)
    return tokens

In [31]:
# Walk one image through the patch embedding and print the shape at each step.
first = dataset[0][0]
print("Initial shape:", first.shape)

# Reuse the tensor that was just fetched rather than indexing the dataset a
# second time, and add a leading batch axis of size 1.
sample_datapoint = first.unsqueeze(0)
print("after unsqueeze: ", sample_datapoint.shape)

# 3 input channels, 8x8 patches, 128-dimensional embedding per patch.
embedding = PatchEmbedding(3, 8, 128)(sample_datapoint)
print("final Patches shape: ", embedding.shape)

Initial shape: torch.Size([3, 144, 144])
after unsqueeze:  torch.Size([1, 3, 144, 144])
after patch torch.Size([1, 128, 18, 18])
after flatten torch.Size([1, 128, 324])
after transpose torch.Size([1, 324, 128])
final Patches shape:  torch.Size([1, 324, 128])


# Multi-Head Attention Block

In [32]:
class AttentionBlock(nn.Module):
  """Self-attention over a token sequence with learned Q/K/V projections.

  The input is projected by three separate linear layers into query, key and
  value tensors, which are then fed to ``torch.nn.MultiheadAttention``.

  Args:
    embed_dim: size of each token embedding (input and output).
    n_heads: number of attention heads.
    dropout: dropout probability passed to ``MultiheadAttention``.
  """

  def __init__(self, embed_dim, n_heads, dropout):
    super().__init__()
    self.n_heads = n_heads
    # NOTE(review): MultiheadAttention is left at its default input layout
    # (batch_first is not set). If callers pass batch-first tensors of shape
    # (batch, tokens, embed_dim), this likely needs batch_first=True -- confirm
    # against the call sites before changing it.
    self.attention = torch.nn.MultiheadAttention(embed_dim = embed_dim,
                                                 num_heads = n_heads,
                                                 dropout = dropout)

    self.q = nn.Linear(embed_dim, embed_dim)
    self.k = nn.Linear(embed_dim, embed_dim)
    self.v = nn.Linear(embed_dim, embed_dim)

  def forward(self, x):
    """Apply self-attention to ``x`` and return the attended tensor.

    Only the attention output is returned; the attention weights computed by
    ``MultiheadAttention`` are discarded.
    """
    # Each of query/key/value must go through its own projection. Previously
    # all three used ``self.q``, leaving ``self.k``/``self.v`` unused and
    # untrained.
    q = self.q(x)
    k = self.k(x)
    v = self.v(x)
    attention_output, _ = self.attention(q, k, v)
    return attention_output

In [40]:
# Build a block so the notebook displays its sub-modules. The same call in
# positional form would be AttentionBlock(128, 4, 0.).
AttentionBlock(embed_dim=128, n_heads=4, dropout=0.0)

AttentionBlock(
  (attention): MultiheadAttention(
    (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True)
  )
  (q): Linear(in_features=128, out_features=128, bias=True)
  (k): Linear(in_features=128, out_features=128, bias=True)
  (v): Linear(in_features=128, out_features=128, bias=True)
)

# Layer Normalization

In [34]:
class LayerNorm(nn.Module):
  """Pre-normalization wrapper: layer-normalize the input, then apply ``fn``.

  Args:
    embed_dim: size of the last dimension, passed to ``nn.LayerNorm``.
    fn: callable (typically another module) applied to the normalized input.
  """

  def __init__(self, embed_dim, fn):
    super().__init__()
    self.norm = nn.LayerNorm(embed_dim)
    self.fn = fn

  def forward(self, x):
    """Return ``fn`` evaluated on the layer-normalized ``x``."""
    normalized = self.norm(x)
    return self.fn(normalized)

In [39]:
# Wrap an attention block so its input is layer-normalized first.
norm = LayerNorm(
    embed_dim=128,
    fn=AttentionBlock(embed_dim=128, n_heads=4, dropout=0.0),
)