<a href="https://colab.research.google.com/github/Nahom32/ViT/blob/main/notebooks/ViT.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import numpy as np
import torch
import torch.nn as nn
from torchvision import datasets, transforms

In [2]:
import math

transform = transforms.Compose([transforms.ToTensor()])

train_data = datasets.CIFAR10(root='./data/cifar-10', train=True, download=True, transform=transform)
test_data = datasets.CIFAR10(root='./data/cifar-10', train=False, download=True, transform=transform)

100%|██████████| 170M/170M [00:03<00:00, 47.4MB/s]


In [3]:
class PatchEmbedding(nn.Module):
    def __init__(self, img_size = 32, patch_size = 4, in_chans = 3, embed_dim = 768):
        super().__init__()
        self.img_size   = img_size
        self.patch_size = patch_size    # P
        self.in_chans   = in_chans      # C
        self.embed_dim  = embed_dim     # D

        self.num_patches = (img_size // patch_size) ** 2        # N = H*W/P^2
        self.flatten_dim = patch_size * patch_size * in_chans   # P^2*C

        self.proj = nn.Linear(self.flatten_dim, embed_dim) # (P^2*C,D)

        self.position_embed = nn.Parameter(torch.zeros(1, 1 + self.num_patches, embed_dim))
        self.class_embed    = nn.Parameter(torch.zeros(1, 1, embed_dim))

    def forward(self, x):
        B, C, H, W = x.shape

        x = x.unfold(2, self.patch_size, self.patch_size).unfold(3, self.patch_size, self.patch_size)
        x = x.reshape(1, -1, self.patch_size, self.patch_size)
        x = x.permute(0, 2, 1, 3).reshape(B, self.num_patches, -1)

        x = self.proj(x)

        cls_emb = self.class_embed.expand(B, -1, -1)
        x = torch.cat((cls_emb, x), dim = 1)

        x = x + self.position_embed
        return x


In [4]:
patch_embed = PatchEmbedding()

embeddings = patch_embed(torch.stack([train_data[i][0] for i in range(10)]))
print(embeddings.shape)
print(embeddings)


torch.Size([10, 65, 768])
tensor([[[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
           0.0000e+00,  0.0000e+00],
         [-5.2078e-01,  1.0504e-02,  5.4104e-01,  ...,  1.8488e-01,
          -3.0784e-01,  1.0644e-01],
         [-4.0490e-01,  4.1700e-03,  4.2748e-01,  ...,  1.1319e-01,
          -2.0800e-01, -2.2263e-02],
         ...,
         [-3.0173e-01,  7.5302e-02,  3.5745e-01,  ...,  1.8374e-01,
          -1.4220e-01,  1.4374e-01],
         [-3.3998e-01, -3.5417e-02,  3.2596e-01,  ...,  1.0464e-01,
          -6.7990e-02, -7.4361e-02],
         [-1.7090e-01, -1.3885e-02,  2.8166e-01,  ...,  1.0261e-01,
          -6.6580e-02, -3.1313e-02]],

        [[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
           0.0000e+00,  0.0000e+00],
         [-6.6811e-01, -1.0550e-01,  5.6385e-01,  ..., -4.7664e-03,
          -2.9877e-01, -7.1115e-02],
         [-3.9925e-01, -3.2990e-03,  3.9107e-01,  ...,  1.3110e-01,
          -9.6710e-02, -8.2381e-03],
         ...,

In [5]:
class SelfAttention(nn.Module):
    def __init__(self, embed_dim = 768, num_heads = 4, bias = False, dropout=0.1):
        super().__init__()
        assert embed_dim % num_heads == 0

        self.embed_dim   = embed_dim
        self.num_heads   = num_heads
        self.head_dim    = embed_dim // num_heads

        self.query   = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.key     = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.value   = nn.Linear(embed_dim, embed_dim, bias=bias)

        self.out     = nn.Linear(embed_dim, embed_dim, bias=bias)

        self.attn_dropout = nn.Dropout(dropout)
        self.resid_dropout = nn.Dropout(dropout)

    def forward(self, x):
        B, N, _ = x.size()

        q = self.query(x).view(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        k = self.key(x).view(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        v = self.value(x).view(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3)

        # do NOT use causal attention as we are not dealing with sequential data (image patches are unordered)
        attn = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(self.head_dim))
        attn = attn.softmax(dim=-1)

        out = (attn @ v).permute(0, 2, 1, 3).reshape(B, N, self.embed_dim)

        out = self.out(out)

        return out


In [6]:
MSA = SelfAttention()
LN = nn.LayerNorm(embeddings.shape, bias=False)

MSA(LN(embeddings))


tensor([[[ 0.2489,  0.1068,  0.2229,  ..., -0.2221, -0.5877,  0.2212],
         [ 0.2648,  0.0926,  0.1939,  ..., -0.2076, -0.6067,  0.2941],
         [ 0.2649,  0.0987,  0.2033,  ..., -0.2089, -0.6147,  0.2913],
         ...,
         [ 0.2570,  0.0990,  0.2022,  ..., -0.2134, -0.5974,  0.2629],
         [ 0.2574,  0.0963,  0.2004,  ..., -0.2156, -0.5942,  0.2632],
         [ 0.2552,  0.1006,  0.2089,  ..., -0.2170, -0.5923,  0.2500]],

        [[ 0.2255,  0.1066,  0.2066,  ..., -0.2092, -0.5811,  0.2345],
         [ 0.2371,  0.0945,  0.1928,  ..., -0.1960, -0.5817,  0.2740],
         [ 0.2345,  0.0960,  0.1916,  ..., -0.2061, -0.5797,  0.2649],
         ...,
         [ 0.2299,  0.1054,  0.2050,  ..., -0.1972, -0.5866,  0.2527],
         [ 0.2289,  0.1020,  0.1959,  ..., -0.1942, -0.5860,  0.2612],
         [ 0.2379,  0.0993,  0.1998,  ..., -0.1937, -0.5869,  0.2707]],

        [[ 0.2527,  0.1149,  0.2157,  ..., -0.2263, -0.5973,  0.2352],
         [ 0.2642,  0.1064,  0.1812,  ..., -0

In [7]:
class MLP(nn.Module):
    def __init__(self, embed_dim = 768, bias = False, dropout = 0.1):
        super().__init__()
        self.c_fc = nn.Linear(embed_dim, embed_dim * 4, bias=bias)
        self.gelu = nn.GELU()
        self.c_proj = nn.Linear(embed_dim * 4, embed_dim, bias=bias)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        x = self.c_fc(x)
        x = self.gelu(x)
        x = self.c_proj(x)
        x = self.dropout(x)

        return x

class Block(nn.Module):

    def __init__(self, embed_dim = 768, bias = False):
        super().__init__()
        self.ln_1 = nn.LayerNorm(embed_dim, bias=bias)
        self.attn = SelfAttention(embed_dim, bias=bias)
        self.ln_2 = nn.LayerNorm(embed_dim, bias=bias)
        self.mlp = MLP(embed_dim, bias=bias)

    def forward(self, x):
        x = x + self.attn(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x




In [8]:
class ViT(nn.Module):

    def __init__(self, embed_dim = 768, num_layers = 4, out_dim = 10, bias = False, dropout = 0.1):
        super().__init__()

        self.transformer = nn.ModuleDict(dict(
            pe = PatchEmbedding(),
            drop = nn.Dropout(dropout),
            h = nn.ModuleList([Block() for _ in range(num_layers)]),
            ln_f = nn.LayerNorm(embed_dim)
        ))
        self.head = nn.Linear(embed_dim, out_dim, bias=False)

vit = ViT()
vit(torch.stack([train_data[i][0] for i in range(10)]))

        print("number of parameters: %.2fM" % (self.get_num_params()/1e6,))

    def get_num_params(self):
        n_params = sum(p.numel() for p in self.parameters())
        return n_params

    def forward(self, x):
        emb = self.transformer.pe(x)
        x = self.transformer.drop(emb)
        for block in self.transformer.h:
            x = block(x)
        x = self.transformer.ln_f(x)

        class_token = x[:, 0]
        logits = self.head(class_token)
        return logits



In [9]:
vit = ViT()
vit(torch.stack([train_data[i][0] for i in range(10)]))


number of parameters: 28.42M


tensor([[-5.8224e-01, -3.7082e-01,  5.7170e-01, -1.7401e-01,  1.3356e-01,
         -2.1121e-01, -1.3950e-01,  8.3914e-01, -3.8696e-01, -2.5666e-01],
        [-3.3041e-01, -5.4182e-01,  5.2174e-01, -4.4653e-01,  1.6548e-01,
         -1.6080e-02, -2.4076e-01,  9.8670e-01, -4.2810e-01, -5.0980e-01],
        [-5.1813e-01, -3.1301e-01,  4.0145e-01, -2.6649e-02,  2.5983e-01,
         -1.8473e-01, -3.2406e-01,  8.4892e-01, -4.7218e-01, -4.1467e-01],
        [-4.1519e-01, -5.1449e-01,  6.7793e-01, -1.9332e-01,  1.5040e-01,
         -2.7068e-04, -1.4009e-01,  8.9058e-01, -6.0834e-01, -4.3051e-01],
        [-3.3792e-01, -6.3134e-01,  4.4530e-01, -3.9470e-02,  1.9257e-01,
         -4.1653e-02, -2.3850e-01,  8.4040e-01, -5.9365e-01, -3.3698e-01],
        [-6.3377e-01, -5.7146e-01,  5.5477e-01,  6.4145e-02,  7.1113e-02,
          1.0503e-01, -1.1914e-01,  8.2762e-01, -4.7477e-01, -4.8431e-01],
        [-5.1507e-01, -4.0855e-01,  4.6911e-01, -1.5588e-01,  2.6813e-01,
         -2.8914e-02, -2.9787e-0