Import the required packages.

In [1]:
import einops 
from tqdm.notebook import tqdm

from torchsummary import summary

import torch
from torch import nn
import torchvision
import torch.optim as optim
from torchvision.transforms import Compose, Resize, ToTensor, Normalize, RandomHorizontalFlip, RandomCrop

In [2]:
!jupyter nbextension enable --py widgetsnbextension

Enabling notebook extension jupyter-js-widgets/extension...
      - Validating: OK


Set hyperparameters of the network and specify device.

In [3]:
# Run on the first CUDA GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda:0'
else:
    device = 'cpu'
print(device)

# --- Model architecture ---
patch_size = 16     # side length (pixels) of each square patch cut from the image
latent_size = 768   # width of every token embedding
n_channels = 3      # channels per input image
num_heads = 12      # attention heads per encoder block
num_encoders = 12   # number of stacked encoder blocks
dropout = 0.1       # dropout probability used inside the encoder blocks
num_classes = 10    # number of logits produced by the classification head
size = 224          # presumably the input image side length (test tensors are 224x224) -- confirm

# --- Training ---
# NOTE(review): epochs / base_lr / weight_decay are not consumed by any cell
# visible here; presumably a later training cell reads them -- confirm.
epochs = 10
base_lr = 0.01      # same value as the literal 10e-3; spelled out so it is not misread as 1e-3
weight_decay = 0.03
batch_size = 8

cuda:0


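Implement the input embedding: split each image into patches, linearly project them, prepend the class token, and add the positional embedding.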

In [4]:
class InputEmbedding(nn.Module):
    """Convert a batch of images into the token sequence fed to the encoder stack.

    Each image is cut into non-overlapping ``patch_size x patch_size`` patches,
    every flattened patch is linearly projected to ``latent_size``, a learnable
    class token is prepended, and a learnable per-position embedding is added.

    Args:
        patch_size: Side length of each square patch.
        n_channels: Number of channels in the input images.
        device: Stored on the instance for backward compatibility; the module
            itself follows whatever device it is moved to with ``.to(...)``.
        latent_size: Width of each output token.
        batch_size: Stored for backward compatibility only. The learnable
            tensors no longer depend on it, so any batch size is accepted.
        image_size: Side length of the (square) input images; determines how
            many positional embeddings are allocated.

    Raises:
        ValueError: If ``image_size`` is not a multiple of ``patch_size``.
    """

    def __init__(self, patch_size=patch_size, n_channels=n_channels, device=device, latent_size=latent_size, batch_size=batch_size, image_size=size):
        super(InputEmbedding, self).__init__()
        if image_size % patch_size != 0:
            raise ValueError(
                f"image_size ({image_size}) must be divisible by patch_size ({patch_size})"
            )

        self.latent_size = latent_size
        self.patch_size = patch_size
        self.n_channels = n_channels
        self.device = device
        self.batch_size = batch_size
        self.image_size = image_size
        self.input_size = self.patch_size*self.patch_size*self.n_channels
        self.num_patches = (image_size // patch_size) ** 2

        # Linear projection of each flattened patch
        self.linearProjection = nn.Linear(self.input_size, self.latent_size)

        # Class token: a single learnable vector shared by every sample.
        # It is assigned as a bare nn.Parameter (no trailing ``.to(...)``) so the
        # module registers it; ``nn.Parameter(...).to(cuda)`` would yield a plain
        # tensor that the optimizer never sees.
        self.class_token = nn.Parameter(torch.randn(1, 1, self.latent_size))

        # Positional embedding: one learnable vector per sequence position
        # (class token + every patch), shared across the batch.
        self.pos_embedding = nn.Parameter(
            torch.randn(1, self.num_patches + 1, self.latent_size)
        )

    def forward(self, input_data):
        """Embed ``input_data`` of shape ``(batch, channels, height, width)``.

        Returns:
            Tensor of shape ``(batch, num_patches + 1, latent_size)``.

        Raises:
            ValueError: If the input yields a different number of patches than
                the positional embedding was built for.
        """
        # Follow the module's actual device so the layer keeps working after .to(...)
        input_data = input_data.to(self.class_token.device)

        # Patchify input image
        patches = einops.rearrange(
            input_data, 'b c (h h1) (w w1) -> b (h w) (h1 w1 c)', h1=self.patch_size, w1=self.patch_size)

        linear_projection = self.linearProjection(patches)
        b, n, _ = linear_projection.shape
        if n != self.num_patches:
            raise ValueError(
                f"expected {self.num_patches} patches per image "
                f"(image_size={self.image_size}, patch_size={self.patch_size}), got {n}"
            )

        # Broadcast the single class token over the batch (expand makes no copy)
        class_tokens = self.class_token.expand(b, -1, -1)
        linear_projection = torch.cat((class_tokens, linear_projection), dim=1)

        # pos_embedding broadcasts over the batch dimension
        return linear_projection + self.pos_embedding


In [5]:
# Smoke-test the embedding on a random batch whose shape comes from the config
# cell rather than from repeated magic numbers.
test_input = torch.randn((batch_size, n_channels, size, size))
test_class = InputEmbedding().to(device)
embed_test = test_class(test_input)

# One token per patch plus the class token, each of width latent_size.
assert embed_test.shape == (batch_size, (size // patch_size) ** 2 + 1, latent_size)

Implement the Encoder block.

In [6]:
class EncoderBlock(nn.Module):
    """One pre-norm Transformer encoder block: self-attention then an MLP.

    Both sub-layers are preceded by a LayerNorm and wrapped in a residual
    connection. Inputs and outputs have shape ``(batch, tokens, latent_size)``.

    Args:
        latent_size: Width of each token.
        num_heads: Number of attention heads.
        device: Stored on the instance; placement is controlled by ``.to(...)``.
        dropout: Dropout probability for the attention weights and the MLP.
    """

    def __init__(self, latent_size=latent_size, num_heads=num_heads, device=device, dropout=dropout):
        super(EncoderBlock, self).__init__()

        self.latent_size = latent_size
        self.num_heads = num_heads
        self.device = device
        self.dropout = dropout

        # Separate normalization layers for the two sub-layers, so each learns
        # its own scale/shift instead of sharing one set of parameters.
        self.norm = nn.LayerNorm(self.latent_size)       # before attention
        self.norm_mlp = nn.LayerNorm(self.latent_size)   # before the MLP

        # batch_first=True: the tokens arrive as (batch, seq, dim). Without it,
        # MultiheadAttention reads dim 0 as the sequence axis and would attend
        # across samples in the batch instead of across patches.
        self.multihead = nn.MultiheadAttention(
            self.latent_size, self.num_heads, dropout=self.dropout, batch_first=True
        )

        self.enc_MLP = nn.Sequential(
            nn.Linear(self.latent_size, self.latent_size*4),
            nn.GELU(),
            nn.Dropout(self.dropout),
            nn.Linear(self.latent_size*4, self.latent_size),
            nn.Dropout(self.dropout)
        )

    def forward(self, embedded_patches):
        """Apply attention and the MLP, each with a residual connection.

        Args:
            embedded_patches: Tensor of shape ``(batch, tokens, latent_size)``.

        Returns:
            Tensor of the same shape as ``embedded_patches``.
        """
        firstnorm_out = self.norm(embedded_patches)
        # The attention weights are discarded, so don't ask for them.
        attention_out = self.multihead(
            firstnorm_out, firstnorm_out, firstnorm_out, need_weights=False
        )[0]

        # first residual connection
        first_added = attention_out + embedded_patches

        secondnorm_out = self.norm_mlp(first_added)
        ff_out = self.enc_MLP(secondnorm_out)

        # second residual connection
        return ff_out + first_added


In [7]:
# Smoke-test a single encoder block on the embedded batch from the cell above;
# the bare last expression displays the resulting tensor.
test_encoder = EncoderBlock()
test_encoder = test_encoder.to(device)
test_encoder(embed_test)

tensor([[[ 6.2355e-01, -1.8078e+00,  1.1759e+00,  ..., -3.4359e-01,
          -7.2185e-02,  1.3529e+00],
         [ 4.6247e-02, -9.9380e-01, -4.5251e-01,  ..., -1.1911e+00,
           1.4419e+00,  1.3693e+00],
         [ 9.7330e-01, -1.8508e+00,  3.5775e-01,  ..., -1.4842e+00,
           1.3984e+00,  1.1508e+00],
         ...,
         [ 4.4001e-01, -2.2045e+00,  3.7167e-01,  ..., -1.4944e+00,
           6.3034e-01,  7.8634e-01],
         [ 8.0298e-01, -1.7168e+00,  9.0895e-01,  ..., -2.4058e+00,
           1.4339e+00,  3.6456e-01],
         [ 1.2904e+00, -1.6537e+00,  5.2323e-01,  ..., -2.5036e+00,
           5.7349e-02,  1.3578e+00]],

        [[ 1.4517e+00, -6.3832e-01, -1.6566e+00,  ...,  2.2055e+00,
          -2.2559e+00,  5.0645e-01],
         [-1.2711e-02, -2.3971e+00, -1.5669e+00,  ...,  1.5312e+00,
          -8.7514e-01, -1.2052e-01],
         [-4.3736e-01, -2.1193e+00, -1.3339e-01,  ...,  1.6818e+00,
          -1.0663e+00,  6.4820e-01],
         ...,
         [ 1.3788e-02, -1

Put everything together.

In [8]:
class Vit(nn.Module):
    """Vision Transformer classifier: patch embedding, encoder stack, MLP head.

    Args:
        num_encoders: Number of stacked ``EncoderBlock`` layers.
        latent_size: Token width used by the embedding, encoders and head.
        device: Device identifier forwarded to the sub-modules.
        num_classes: Number of output logits.
        dropout: Dropout probability forwarded to every encoder block.
    """

    def __init__(self, num_encoders=num_encoders, latent_size=latent_size, device=device, num_classes=num_classes, dropout=dropout):
        super(Vit, self).__init__()
        self.num_encoder = num_encoders
        self.latent_size = latent_size
        self.device = device
        self.num_classes = num_classes
        self.dropout = dropout

        # Forward the constructor arguments so a non-default latent_size /
        # device / dropout actually reaches the sub-modules instead of being
        # silently replaced by the notebook-level defaults.
        self.embedding = InputEmbedding(latent_size=self.latent_size, device=self.device)

        # Create the stack of encoders
        self.encStack = nn.ModuleList([
            EncoderBlock(latent_size=self.latent_size, device=self.device, dropout=self.dropout)
            for _ in range(self.num_encoder)
        ])

        # NOTE(review): the two Linear layers have no activation between them,
        # so together they compute a single affine map; consider inserting a
        # non-linearity or dropping one of them.
        self.MLP_head = nn.Sequential(
            nn.LayerNorm(self.latent_size),
            nn.Linear(self.latent_size, self.latent_size),
            nn.Linear(self.latent_size, self.num_classes)
        )

    def forward(self, test_input):
        """Classify a batch of images.

        Args:
            test_input: Image batch of shape ``(batch, channels, height, width)``.

        Returns:
            Logits of shape ``(batch, num_classes)``.
        """
        enc_output = self.embedding(test_input)

        for enc_layer in self.encStack:
            enc_output = enc_layer(enc_output)

        # The class token is the first element of the sequence
        cls_token_embed = enc_output[:, 0]

        return self.MLP_head(cls_token_embed)

Test the ViT on a random batch and check the output shape.

In [9]:
# Build the full model, move it to the target device, and run the random
# test batch through it; one row of logits is expected per input image.
model = Vit()
model = model.to(device)
vit_output = model(test_input)
print(vit_output)
print(vit_output.size())

tensor([[-0.3911,  0.3709, -0.0763,  0.3983,  0.0033,  0.1565, -0.2159,  0.3850,
         -0.0957, -0.2416],
        [-0.9099, -0.6156, -0.4327,  0.2852, -0.3389,  0.4391,  0.0589,  0.1776,
          0.2711,  0.2731],
        [-0.4315, -0.3108, -0.3195, -0.1811,  0.1486,  0.3366,  0.2288, -0.6239,
          0.2557,  0.0992],
        [-0.0902,  0.2124, -0.3800,  0.0741,  0.0971, -0.0359, -0.1254,  0.1438,
          0.1679, -0.1064],
        [ 0.3950,  0.1897, -0.0706,  0.3142, -0.2319, -0.5109, -0.6775, -0.2189,
         -0.5458,  0.4139],
        [-0.6946,  0.0742, -0.0423,  0.4796, -0.0975,  0.0267, -0.2221, -0.2862,
          0.0536, -0.4352],
        [-0.4652,  0.3042, -0.1481,  0.3951, -0.0936,  0.3667, -0.0196, -0.0049,
         -0.1430,  0.0348],
        [-0.2570, -0.1593, -0.2347,  0.2582,  0.0360,  0.1477, -0.3275, -0.1828,
          0.4856, -0.1503]], device='cuda:0', grad_fn=<AddmmBackward0>)
torch.Size([8, 10])
