In [1]:
import numpy as np
import torch
import torch.nn as nn
from torchvision import datasets, transforms

import math

transform = transforms.Compose([transforms.ToTensor()])

train_data = datasets.CIFAR10(root='./data/cifar-10', train=True, download=True, transform=transform)
test_data = datasets.CIFAR10(root='./data/cifar-10', train=False, download=True, transform=transform)

train_data, test_data

Files already downloaded and verified
Files already downloaded and verified


(Dataset CIFAR10
     Number of datapoints: 50000
     Root location: ./data/cifar-10
     Split: Train
     StandardTransform
 Transform: Compose(
                ToTensor()
            ),
 Dataset CIFAR10
     Number of datapoints: 10000
     Root location: ./data/cifar-10
     Split: Test
     StandardTransform
 Transform: Compose(
                ToTensor()
            ))

In [2]:
train_data.data.shape, len(train_data.targets)

((50000, 32, 32, 3), 50000)

In [3]:
train_data[0][0].numpy().shape, train_data[0][1]

((3, 32, 32), 6)

In [4]:
train_data[0][0].numpy().shape, train_data[0][1]

((3, 32, 32), 6)

In [5]:
patch_size = 4
for i in range(0, 32, patch_size):
    for j in range(0, 32, patch_size):
        patch = train_data[0][0][:, i:i+patch_size, j:j+patch_size]

        # do something with patch...

print(patch.shape)

torch.Size([3, 4, 4])


In [6]:
image = torch.arange(0.,48).reshape(1, 3, 4, 4) 
image, image.unfold(2, 2, 2).unfold(3, 2, 2)

(tensor([[[[ 0.,  1.,  2.,  3.],
           [ 4.,  5.,  6.,  7.],
           [ 8.,  9., 10., 11.],
           [12., 13., 14., 15.]],
 
          [[16., 17., 18., 19.],
           [20., 21., 22., 23.],
           [24., 25., 26., 27.],
           [28., 29., 30., 31.]],
 
          [[32., 33., 34., 35.],
           [36., 37., 38., 39.],
           [40., 41., 42., 43.],
           [44., 45., 46., 47.]]]]),
 tensor([[[[[[ 0.,  1.],
             [ 4.,  5.]],
 
            [[ 2.,  3.],
             [ 6.,  7.]]],
 
 
           [[[ 8.,  9.],
             [12., 13.]],
 
            [[10., 11.],
             [14., 15.]]]],
 
 
 
          [[[[16., 17.],
             [20., 21.]],
 
            [[18., 19.],
             [22., 23.]]],
 
 
           [[[24., 25.],
             [28., 29.]],
 
            [[26., 27.],
             [30., 31.]]]],
 
 
 
          [[[[32., 33.],
             [36., 37.]],
 
            [[34., 35.],
             [38., 39.]]],
 
 
           [[[40., 41.],
             [44.,

In [7]:
image.unfold(2,2,2).unfold(3,2,2).reshape(1, -1, 4, 4)

tensor([[[[ 0.,  1.,  4.,  5.],
          [ 2.,  3.,  6.,  7.],
          [ 8.,  9., 12., 13.],
          [10., 11., 14., 15.]],

         [[16., 17., 20., 21.],
          [18., 19., 22., 23.],
          [24., 25., 28., 29.],
          [26., 27., 30., 31.]],

         [[32., 33., 36., 37.],
          [34., 35., 38., 39.],
          [40., 41., 44., 45.],
          [42., 43., 46., 47.]]]])

In [8]:
image.unfold(2,2,2).unfold(3,2,2).reshape(1, -1, 4, 4).permute(0, 2, 1, 3) # B x N x C x P
image.unfold(2,2,2).unfold(3,2,2).reshape(1, -1, 4, 4).permute(0, 2, 1, 3).reshape(1, 4, -1)

tensor([[[ 0.,  1.,  4.,  5., 16., 17., 20., 21., 32., 33., 36., 37.],
         [ 2.,  3.,  6.,  7., 18., 19., 22., 23., 34., 35., 38., 39.],
         [ 8.,  9., 12., 13., 24., 25., 28., 29., 40., 41., 44., 45.],
         [10., 11., 14., 15., 26., 27., 30., 31., 42., 43., 46., 47.]]])

In [9]:
class PatchEmbedding(nn.Module):
    def __init__(self, img_size = 32, patch_size = 4, in_chans = 3, embed_dim = 768):
        super().__init__()
        self.img_size   = img_size
        self.patch_size = patch_size    # P
        self.in_chans   = in_chans      # C
        self.embed_dim  = embed_dim     # D

        self.num_patches = (img_size // patch_size) ** 2        # N = H*W/P^2
        self.flatten_dim = patch_size * patch_size * in_chans   # P^2*C
        
        self.proj = nn.Linear(self.flatten_dim, embed_dim) # (P^2*C,D)

        self.position_embed = nn.Parameter(torch.zeros(1, 1 + self.num_patches, embed_dim))
        self.class_embed    = nn.Parameter(torch.zeros(1, 1, embed_dim))

    def forward(self, x):
        B, C, H, W = x.shape

        x = x.unfold(2, self.patch_size, self.patch_size).unfold(3, self.patch_size, self.patch_size)
        x = x.reshape(1, -1, self.patch_size, self.patch_size)
        x = x.permute(0, 2, 1, 3).reshape(B, self.num_patches, -1)

        x = self.proj(x)

        cls_emb = self.class_embed.expand(B, -1, -1)
        x = torch.cat((cls_emb, x), dim = 1)

        x = x + self.position_embed
        return x

In [10]:
patch_embed = PatchEmbedding()

embeddings = patch_embed(torch.stack([train_data[i][0] for i in range(10)]))

embeddings, embeddings.shape

(tensor([[[ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
          [ 0.2749, -0.2905, -0.5764,  ...,  0.1573,  0.1988,  0.0659],
          [ 0.2592, -0.1101, -0.4806,  ...,  0.1212,  0.2012, -0.0917],
          ...,
          [ 0.2375, -0.1225, -0.2667,  ...,  0.0106,  0.0968, -0.0470],
          [ 0.1667, -0.2141, -0.4615,  ...,  0.0946,  0.1268, -0.0650],
          [ 0.1417, -0.0721, -0.2739,  ...,  0.0598,  0.1200, -0.0240]],
 
         [[ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
          [ 0.2669, -0.4007, -0.8260,  ...,  0.3725,  0.2589, -0.0055],
          [ 0.1941, -0.2255, -0.5947,  ...,  0.2060,  0.1223, -0.0429],
          ...,
          [ 0.4147, -0.3022, -0.4285,  ...,  0.2373,  0.2337,  0.0757],
          [ 0.1847, -0.1846, -0.6169,  ...,  0.2372,  0.3218,  0.0298],
          [ 0.2384, -0.3314, -0.4986,  ...,  0.0724,  0.1768,  0.0325]],
 
         [[ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
          [ 0.2773, -0.3076,

In [11]:
class SelfAttention(nn.Module):
    def __init__(self, embed_dim = 768, num_heads = 4, bias = False, dropout=0.1):
        super().__init__()
        assert embed_dim % num_heads == 0

        self.embed_dim   = embed_dim
        self.num_heads   = num_heads
        self.head_dim    = embed_dim // num_heads

        self.query   = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.key     = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.value   = nn.Linear(embed_dim, embed_dim, bias=bias)

        self.out     = nn.Linear(embed_dim, embed_dim, bias=bias)

        self.attn_dropout = nn.Dropout(dropout)
        self.resid_dropout = nn.Dropout(dropout)

    def forward(self, x):
        B, N, _ = x.size()

        q = self.query(x).view(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        k = self.key(x).view(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        v = self.value(x).view(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3)

        # do NOT use causal attention as we are not dealing with sequential data (image patches are unordered)
        attn = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(self.head_dim))
        attn = attn.softmax(dim=-1)

        out = (attn @ v).permute(0, 2, 1, 3).reshape(B, N, self.embed_dim)

        out = self.out(out)

        return out

In [12]:
MSA = SelfAttention()
LN = nn.LayerNorm(embeddings.shape, bias=False)

MSA(LN(embeddings))

tensor([[[ 0.4521,  0.3022, -0.4003,  ...,  0.1410, -0.1591,  0.1097],
         [ 0.4379,  0.2964, -0.3942,  ...,  0.1399, -0.1582,  0.1012],
         [ 0.4376,  0.2990, -0.3890,  ...,  0.1354, -0.1567,  0.0946],
         ...,
         [ 0.4447,  0.2999, -0.3981,  ...,  0.1387, -0.1594,  0.1015],
         [ 0.4448,  0.3031, -0.3942,  ...,  0.1397, -0.1557,  0.1006],
         [ 0.4457,  0.2967, -0.4000,  ...,  0.1437, -0.1589,  0.1051]],

        [[ 0.4397,  0.3206, -0.3868,  ...,  0.1627, -0.1734,  0.0518],
         [ 0.4380,  0.3194, -0.3835,  ...,  0.1600, -0.1794,  0.0408],
         [ 0.4349,  0.3249, -0.3820,  ...,  0.1565, -0.1751,  0.0447],
         ...,
         [ 0.4388,  0.3122, -0.3854,  ...,  0.1745, -0.1755,  0.0433],
         [ 0.4333,  0.3087, -0.3845,  ...,  0.1708, -0.1781,  0.0381],
         [ 0.4326,  0.3140, -0.3814,  ...,  0.1620, -0.1773,  0.0423]],

        [[ 0.4510,  0.3417, -0.4013,  ...,  0.1494, -0.1787,  0.1014],
         [ 0.4403,  0.3454, -0.3978,  ...,  0

In [13]:
class MLP(nn.Module):
    def __init__(self, embed_dim = 768, bias = False, dropout = 0.1):
        super().__init__()
        self.c_fc = nn.Linear(embed_dim, embed_dim * 4, bias=bias)
        self.gelu = nn.GELU()
        self.c_proj = nn.Linear(embed_dim * 4, embed_dim, bias=bias)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        x = self.c_fc(x)
        x = self.gelu(x)
        x = self.c_proj(x)
        x = self.dropout(x)

        return x
    
class Block(nn.Module):

    def __init__(self, embed_dim = 768, bias = False):
        super().__init__()
        self.ln_1 = nn.LayerNorm(embed_dim, bias=bias)
        self.attn = SelfAttention(embed_dim, bias=bias)
        self.ln_2 = nn.LayerNorm(embed_dim, bias=bias)
        self.mlp = MLP(embed_dim, bias=bias)

    def forward(self, x):
        x = x + self.attn(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x

In [14]:
class ViT(nn.Module):

    def __init__(self, embed_dim = 768, num_layers = 4, out_dim = 10, bias = False, dropout = 0.1):
        super().__init__()

        self.transformer = nn.ModuleDict(dict(
            pe = PatchEmbedding(),
            drop = nn.Dropout(dropout),
            h = nn.ModuleList([Block() for _ in range(num_layers)]),
            ln_f = nn.LayerNorm(embed_dim)
        ))
        self.head = nn.Linear(embed_dim, out_dim, bias=False)


        print("number of parameters: %.2fM" % (self.get_num_params()/1e6,))

    def get_num_params(self):
        n_params = sum(p.numel() for p in self.parameters())
        return n_params

    def forward(self, x):
        emb = self.transformer.pe(x)
        x = self.transformer.drop(emb)
        for block in self.transformer.h:
            x = block(x)
        x = self.transformer.ln_f(x)

        class_token = x[:, 0]
        logits = self.head(class_token)
        return logits

In [15]:
vit = ViT()
vit(torch.stack([train_data[i][0] for i in range(10)]))

number of parameters: 28.42M


tensor([[-0.1456, -0.2529, -0.3532, -1.1192,  1.4267,  1.2951, -0.9670,  1.7261,
          0.2198,  0.4939],
        [-0.2264, -0.2649, -0.6843, -1.2652,  1.1855,  1.7658, -1.0856,  2.0121,
          0.0788,  0.5883],
        [-0.2941, -0.2332, -0.5577, -1.1553,  1.3977,  1.4724, -1.0573,  1.9429,
         -0.0756,  0.6941],
        [-0.2021, -0.2922, -0.6660, -1.2247,  1.4233,  1.6697, -1.2059,  1.9832,
          0.0576,  0.7225],
        [-0.4389, -0.3675, -0.6694, -1.2252,  1.2807,  1.2867, -1.2640,  1.8317,
          0.0752,  0.5037],
        [-0.2618, -0.4284, -0.4262, -1.3106,  1.2286,  1.4338, -1.0285,  1.9554,
          0.2316,  0.5074],
        [-0.1456, -0.4101, -0.7117, -1.2143,  1.3792,  1.4249, -1.0340,  1.6290,
          0.3441,  0.5831],
        [-0.4597, -0.2289, -0.4271, -1.0971,  1.1724,  1.6888, -1.1043,  1.7292,
          0.0644,  0.4074],
        [-0.2236, -0.3282, -0.6501, -1.2261,  1.4172,  1.5638, -1.2377,  1.8906,
          0.2984,  0.5549],
        [-0.3231, -