In [1]:
import torch

import torch.nn as nn

class Config:

    def __init__(self):

        self.img_dim = 224

        self.patch_dim = 16

        self.channels = 3

        #This would be the length of the 1D vector after flattening each patch's embedding. 16,16,3 would be flattened to 16*16*3 i.e 768

        self.embed_dim = self.patch_dim * self.patch_dim * self.channels

        self.attention_heads = 12

        self.attention_head_size = self.embed_dim // self.attention_heads

        self.transformer_layers = 3

        self.classes = 10



class Patcher(nn.Module):

    def __init__(self):

        #Whenever a class involve's usage of nn Module's submodules like Conv2d, Parameter, the class needs to inherit from nn.Module's superclass hence this super() call

        super().__init__()

        config = Config()

        #For an image of dimension (224,224) this would be equal to 224

        self.img_dim = config.img_dim

        #This is the size of the patch. For a patch of size 16 this would be equal to 16

        self.patch_dim = config.patch_dim

        #For a 3 channeled image this would be 3/ For one it would be equal to 1

        self.channels = config.channels

        #Number of patches for an image. If the image dimension is 224,224 and the patch dimensions are 16,16, The number of patches would be (224*224)/(16*16) == 196

        self.num_patches = (config.img_dim * config.img_dim) // (config.patch_dim * config.patch_dim)

        #This results in a 2D Convolution of the image where the kernel size would be the patch size and the size of the output would be (Batch_Size, embed_dim, patch_dim, patch_dim)

        #A note here, stride should be set to patch_dim, since all the patches need to traversed by the transformer and the default value of 1 would be mean traversing a single pixel at a time which is not the goal

        self.out = nn.Conv2d(self.channels, config.embed_dim, self.patch_dim,stride = self.patch_dim)



    def forward(self, x):

        self.projection = self.out(x)

        # print(self.projection.shape)

        return self.projection


In [2]:
'''

1. [CLS] TOKEN: Its a classification token which is a global representation of an image

'''



class Embeddings(nn.Module):

    def __init__(self):

        super().__init__() #nn.Module super() call

        config = Config()

        #The [CLS] token which is used for classification is prepended to the positional embeddings

        self.cls_token = nn.Parameter(torch.randn(1, 1, config.embed_dim))

        #This is the positional embeddings which is a trainable paramater. Its generated for each patch of the image. Its size here is: (1,197,768)

        self.position_embeddings = nn.Parameter(torch.randn(1, ((config.img_dim // config.patch_dim) ** 2)+1, config.embed_dim))

        self.dropout = nn.Dropout(0.1)

        self.patcher = Patcher()



    def forward(self, x):

        #These are the patch embeddings generated from the Conv2D projection of the '''Patcher'''

        x = self.patcher.forward(x)

        #It's size currently is (batch_size, 768, 14, 14) and it needs to be flattened to (batch_size, 196,768)

        x = x.flatten(2).transpose(1, 2)

        #Batch size

        batch_size = x.shape[0]

        #The [CLS] token in the __init__ is just for one image. But we need to generate tokens for all images in the batch. Hence, expand() called which simly repeats the parameter '''batch_size''' times keeping the size of the 2nd (1) and 3rd (768) dimension constant (-1)

        cls_token = self.cls_token.expand(batch_size, -1, -1)

        #The new positional embeddings will have the [CLS] token at the start

        x = torch.cat((cls_token, x), dim=1)

        #The new patch embeddings will be the sum of itself and the position embeddings

        x = x + self.position_embeddings

        x = self.dropout(x)

        return x

In [3]:
class Softmax(nn.Module):

    def __init__(self):

        super().__init__()



    def forward(self, x):

        max_ = torch.max(x, dim=-1, keepdim=True).values

        x = x - max_

        x = torch.exp(x)

        sum_ = torch.sum(x, dim = -1, keepdim=True)

        return x/ sum_

In [4]:
'''Self Attention: Because Vision Transformers are Encoder models, self attention is used wherein Q,K,V belong to the same input sequence'''

'''Attention Score is calculated using the formula : softmax((Q dot K.T)/sqrt(attention_head_size)) dot V'''

'''The outputs of all the heads are concatenated'''

'''All Q,K,V would be of the size : (batch_size, patch_area +1, attention_head_size) or (batch_size, 197, 64)'''



class SelfAttention(nn.Module):

    def __init__(self):

        super().__init__()

        config = Config()

        #This is the size of the output of each head. If the embedding size is 768 and there are 12 attention heads, the head size would be 768/12 = 64

        self.attention_head_size = config.attention_head_size

        #This is the Key Layer which generates the K vector for an input sequence

        self.key = nn.Linear(config.embed_dim, self.attention_head_size)

        #This is the Query Layer which generates the Q vector for an input sequence

        self.query = nn.Linear(config.embed_dim, self.attention_head_size)

        #This is the Value Layer which generates the V vector for an input sequence

        self.value =nn.Linear(config.embed_dim, self.attention_head_size)

        self.dropout = nn.Dropout(0.1)

        self.softmax_fn = Softmax()



    def forward(self, x):

        q = self.query(x)

        k = self.key(x)

        v = self.value(x)

        #This is the dot priduct of the Query Vector and the Key Vector. -2,-1 is done to reverse the 2nd dimension and the 1st dimension for a valid dot porduct

        q_k_dot = torch.matmul(q, k.transpose(-2,-1))

        #The caled score is dividing the attention weight by the square root of the head size (dimensionality)

        scaled_score = q_k_dot / (self.attention_head_size ** 0.5)

        #Softmax function

        attention_score = self.softmax_fn(scaled_score)

        #Dot product of the attention score with the Value vector

        final_attention = torch.matmul(attention_score, v)

        #The output of each head would have the size (_, 197, 64)

        return final_attention


In [5]:
class AttentionHeads(nn.Module):

    def __init__(self):

        super().__init__()

        config = Config()

        #Making a list of all the attention heads

        self.attention_heads = nn.ModuleList()

        for _ in range(config.attention_heads):

            head_ = SelfAttention()

            #Each head's outpupt is calculated

            self.attention_heads.append(head_)

        #This is a trainable layer

        self.output_layer = nn.Linear(config.embed_dim, config.embed_dim)

        self.dropout = nn.Dropout(0.1)



    def forward(self,x):

        head_out = [k(x) for k in self.attention_heads]

        #The output for each head are concatenated. So 12 outputs each of size (_,197,64) -> (_,197,768)

        combined_head_out = torch.cat(head_out, dim = -1)

        out = self.output_layer(combined_head_out)

        out = self.dropout(out)

        return out

        # print(out.shape)


In [6]:
class GELU(nn.Module):

    def __init__(self):

        super().__init__()



    def forward(self, x):

        return 0.5*x*(1 + torch.tanh(((2/3.14) ** 0.5)*(x + 0.044715 * (x**3))))

In [7]:
class Mutli_Layer_Perceptron(nn.Module):

    def __init__(self):

        super().__init__()

        config = Config()

        #First Connection Layer of the MLP

        self.layer1 = nn.Linear(config.embed_dim, config.embed_dim//2)

        #Second Connection Layer of the MLP

        self.layer2 = nn.Linear(config.embed_dim//2, config.embed_dim)

        #GELU Activation

        self.gelu_layer = GELU()

        self.dropout = nn.Dropout(0.1)



    def forward(self, x):

        layer1_output = self.layer1(x)

        gelu_output = self.gelu_layer(layer1_output)

        layer2_output = self.layer2(gelu_output)

        out = self.dropout(layer2_output)

        return out


In [8]:
class EncoderBlock(nn.Module):

    def __init__(self):

        super().__init__()

        config = Config()

        self.multi_attention = AttentionHeads()

        self.layer_norm = nn.LayerNorm(config.embed_dim)

        self.mlp = Mutli_Layer_Perceptron()



    def forward(self, x):

      #Layer Normalization

      out1 = self.layer_norm(x)

      attention_out = self.multi_attention(out1)

      #Skip Connection

      x = x + attention_out

      #Layer Normalization

      out2 = self.layer_norm(x)

      out3 = self.mlp(out2)

      #Skip Connection

      x = x + out3

      return x


In [9]:
class Transformer(nn.Module):

    def __init__(self):

        super().__init__()

        config = Config()

        #Transformer Blocks

        self.transformer_blocks = nn.ModuleList()

        for i in range(config.transformer_layers):

            block = EncoderBlock()

            self.transformer_blocks.append(block)



    def forward(self, x):

        for block in self.transformer_blocks:

            x = block(x)

        return x


In [10]:
'''

These are all the layers and their respective nn modules used uptil now:

1. Patcher: Conv2d

2. Embeddings: Parameter, Dropout

3. SelfAttention: Linear, Dropout

4. AttentionHeads: Linear, Dropout

5. Mutli_Layer_Perceptron: Linear, Dropout

6. EncoderBlock: LayerNorm

'''





'''

All these modules need initialization of their weights and biases

Conv2d: Normal distribution of weights and 0 bias

Linear: Normal distribution of weights and 0 bias

LayerNorm: Weights set to 1, Bias set to 0

Parameter: There's position_embeddings and cls_token. They are set to normal distribution which as been truncated (outliers removed)

'''

"\n\nAll these modules need initialization of their weights and biases\n\nConv2d: Normal distribution of weights and 0 bias\n\nLinear: Normal distribution of weights and 0 bias\n\nLayerNorm: Weights set to 1, Bias set to 0\n\nParameter: There's position_embeddings and cls_token. They are set to normal distribution which as been truncated (outliers removed)\n\n"

In [11]:
class VIT(nn.Module):

    def __init__(self):

        super().__init__()

        config = Config()

        #This is the embedding layer

        self.emb = Embeddings()

        #This is the Encoder.

        self.transformer = Transformer()

        #This is the ifnal layer which works on the CLS token

        self.final_layer = nn.Linear(config.embed_dim, config.classes)

        self.apply(self.initialize_weights_biases)



    def initialize_weights_biases(self, layer):



        if isinstance(layer, Embeddings):

            layer.position_embeddings.data = nn.Parameter(nn.init.trunc_normal_(

                  layer.position_embeddings.data.to(torch.float32),

                  mean=0.0,

                  std=0.01,

              ).to(layer.position_embeddings.dtype))



            layer.cls_token = nn.Parameter(nn.init.trunc_normal_(

                  layer.cls_token.data.to(torch.float32),

                  mean=0.0,

                  std=0.01,

              ).to(layer.cls_token.dtype))



        elif isinstance(layer, nn.Linear):

            nn.init.normal_(layer.weight, std = 0.02)

            nn.init.zeros_(layer.bias)



        elif isinstance(layer, nn.LayerNorm):

            layer.weight.data.fill_(1)

            nn.init.zeros_(layer.bias)



        elif isinstance(layer, nn.Conv2d):

          nn.init.normal_(layer.weight, std = 0.02)

          nn.init.zeros_(layer.bias)



    def forward(self, x):

        x = self.emb(x)

        x = self.transformer(x)

        cls_token = x[:, 0, :]

        out = self.final_layer(cls_token)



        return out


In [12]:
import torch

from torch import nn, optim

from torchvision import datasets, transforms

from torch.utils.data import DataLoader

from tqdm import tqdm



model = VIT()



batch_size = 32

learning_rate = 0.001

epochs = 20



transform = transforms.Compose([

    transforms.Resize((224, 224)),

    transforms.ToTensor(),

    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

])



train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)



test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)

test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)



criterion = nn.CrossEntropyLoss()

optimizer = optim.Adam(model.parameters(), lr=learning_rate)



device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model.to(device)



for epoch in range(epochs):

    model.train()

    train_loss = 0

    correct = 0

    total = 0



    progress_bar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{epochs}')



    for images, labels in progress_bar:

        images, labels = images.to(device), labels.to(device)



        optimizer.zero_grad()



        outputs = model(images)

        loss = criterion(outputs, labels)



        loss.backward()

        optimizer.step()



        train_loss += loss.item()

        _, predicted = outputs.max(1)

        total += labels.size(0)

        correct += predicted.eq(labels).sum().item()



        progress_bar.set_postfix({'loss': train_loss / (len(train_loader)), 'accuracy': 100. * correct / total})



    print(f'Epoch {epoch+1}, Loss: {train_loss / len(train_loader)}, Accuracy: {100. * correct / total}%')



model.eval()

correct = 0

total = 0

with torch.no_grad():

    for images, labels in test_loader:

        images, labels = images.to(device), labels.to(device)

        outputs = model(images)

        _, predicted = outputs.max(1)

        total += labels.size(0)

        correct += predicted.eq(labels).sum().item()



print(f'Test Accuracy: {100. * correct / total}%')


Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:02<00:00, 78551170.16it/s]


Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified


Epoch 1/20: 100%|██████████| 1563/1563 [04:47<00:00,  5.45it/s, loss=2.17, accuracy=22.9]


Epoch 1, Loss: 2.1745184484537945, Accuracy: 22.854%


Epoch 2/20: 100%|██████████| 1563/1563 [04:48<00:00,  5.42it/s, loss=2.14, accuracy=24.8]


Epoch 2, Loss: 2.1441758984720103, Accuracy: 24.76%


Epoch 3/20: 100%|██████████| 1563/1563 [04:49<00:00,  5.40it/s, loss=2.22, accuracy=24.4]


Epoch 3, Loss: 2.217001100114272, Accuracy: 24.376%


Epoch 4/20: 100%|██████████| 1563/1563 [04:49<00:00,  5.40it/s, loss=2.24, accuracy=24.5]


Epoch 4, Loss: 2.240109932704835, Accuracy: 24.508%


Epoch 5/20: 100%|██████████| 1563/1563 [04:48<00:00,  5.42it/s, loss=2.16, accuracy=25.4]


Epoch 5, Loss: 2.161747473138918, Accuracy: 25.394%


Epoch 6/20: 100%|██████████| 1563/1563 [04:48<00:00,  5.42it/s, loss=2.13, accuracy=25.9]


Epoch 6, Loss: 2.12752443174483, Accuracy: 25.868%


Epoch 7/20: 100%|██████████| 1563/1563 [04:48<00:00,  5.42it/s, loss=2.21, accuracy=26.5]


Epoch 7, Loss: 2.206456618742232, Accuracy: 26.502%


Epoch 8/20: 100%|██████████| 1563/1563 [04:50<00:00,  5.38it/s, loss=2.2, accuracy=27.1]


Epoch 8, Loss: 2.1967945001251943, Accuracy: 27.09%


Epoch 9/20: 100%|██████████| 1563/1563 [04:50<00:00,  5.38it/s, loss=2.07, accuracy=28.1]


Epoch 9, Loss: 2.0673231498141056, Accuracy: 28.066%


Epoch 10/20: 100%|██████████| 1563/1563 [04:51<00:00,  5.37it/s, loss=2, accuracy=29]


Epoch 10, Loss: 1.9982323814338396, Accuracy: 29.024%


Epoch 11/20: 100%|██████████| 1563/1563 [04:50<00:00,  5.38it/s, loss=2.03, accuracy=29.6]


Epoch 11, Loss: 2.0326622954104394, Accuracy: 29.552%


Epoch 12/20: 100%|██████████| 1563/1563 [04:49<00:00,  5.41it/s, loss=1.93, accuracy=30.9]


Epoch 12, Loss: 1.9312089159179024, Accuracy: 30.878%


Epoch 13/20: 100%|██████████| 1563/1563 [04:50<00:00,  5.39it/s, loss=1.93, accuracy=31.5]


Epoch 13, Loss: 1.9277552879550712, Accuracy: 31.528%


Epoch 14/20: 100%|██████████| 1563/1563 [04:47<00:00,  5.43it/s, loss=1.94, accuracy=31.8]


Epoch 14, Loss: 1.9391271962771717, Accuracy: 31.77%


Epoch 15/20: 100%|██████████| 1563/1563 [04:46<00:00,  5.46it/s, loss=1.87, accuracy=33.4]


Epoch 15, Loss: 1.8725772306694866, Accuracy: 33.418%


Epoch 16/20: 100%|██████████| 1563/1563 [04:46<00:00,  5.45it/s, loss=1.84, accuracy=34.8]


Epoch 16, Loss: 1.8420622075740452, Accuracy: 34.782%


Epoch 17/20: 100%|██████████| 1563/1563 [04:47<00:00,  5.44it/s, loss=1.81, accuracy=35.8]


Epoch 17, Loss: 1.813010285622175, Accuracy: 35.81%


Epoch 18/20: 100%|██████████| 1563/1563 [04:48<00:00,  5.42it/s, loss=1.77, accuracy=36.9]


Epoch 18, Loss: 1.7669029583628941, Accuracy: 36.936%


Epoch 19/20: 100%|██████████| 1563/1563 [04:47<00:00,  5.44it/s, loss=1.75, accuracy=37.7]


Epoch 19, Loss: 1.7513279746300277, Accuracy: 37.724%


Epoch 20/20: 100%|██████████| 1563/1563 [04:47<00:00,  5.43it/s, loss=1.68, accuracy=40.3]


Epoch 20, Loss: 1.6795655414605095, Accuracy: 40.282%
Test Accuracy: 45.43%
