In [None]:
# we need updated APIs, so we need updates after torch 1.12+ and torchvision 0.13+
# since the torch version is  already 2.2.2+ and the torchvision version is 0.17.2+ -- we need to assert with right version and this can help install the newer updates 

try:
    import torch
    import torchvision
    assert int(torch.__version__.split(".")[0]) >=1 and int(torch.__version__.split(".")[1]) >=12, "torch version should be more than 1.12+"                                                           
    assert int(torch.__version__.split(".")[1]) >=13, "torchvision version should be more than 1.12+"
    print(f"torch version: {torch.__version__}")
    print(f"torchvision version: {torchvision.__version__}")
except:
    print(f"[INFO] torch.torchvision versions not as required, installing updated versions")
    !pip install -U torch torchvision torchaudio --extra-index-url https://dowload.pytorch.org/whl/cu113
    import torch
    import torchvision
    print(f"torch version: {torch.__version__}")
    print(f"torchvision version: {torchvision.__version__}")
    

In [None]:
from Scripts import data_setup, engine, download_data, set_seeds, plot_loss_curves

In [None]:
import matplotlib.pyplot as plt
import torch
import torchvision

from torch import nn
from torchvision import transforms 

try:
    from torchinfo import summary
except:
    print(f"[INFO] Couldn't find torchinfo...installing it.")
    !pip3 install -q torchinfo
    from torchinfo import summary

In [None]:
# Device-agnostic setup: use the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
device

In [None]:
# Fetch the pizza / steak / sushi image subset from the pytorch-deep-learning GitHub repo
# into a local "pizza_steak_sushi" folder. The helper's return value is used as the dataset
# root below (presumably a pathlib.Path, since it is joined with "/" -- confirm in download_data).
image_path = download_data.download_data(
    source="https://github.com/mrdbourke/pytorch-deep-learning/raw/main/data/pizza_steak_sushi.zip",
    destination="pizza_steak_sushi",
)
image_path

In [None]:
# The train and test splits live in sibling sub-folders of the downloaded dataset root.
train_directory, test_directory = image_path / "train", image_path / "test"

In [None]:
# Creating Datasets and DataLoaders:
""" 
  First we need to create transforms to prepare our images.
  This is where our reference to the paper comes in.
  In Table 3, the training resolution is given as 224 (height=224, width=224).

 # The ViT paper also uses a batch size of 4096,
   which is far too large for our small downloaded dataset, and our hardware may not be able to handle a batch that size,
   so we are going to use a batch size of 32.
"""

In [None]:
# Prepare transforms for the images.
# Image size taken from the ViT paper (Table 3 training resolution): 224 x 224.
IMAGE_SIZE = 224

# Build the transform pipeline by hand: resize every image to IMAGE_SIZE x IMAGE_SIZE,
# then convert it to a tensor.
_resize_to_vit_input = transforms.Resize(size=(IMAGE_SIZE, IMAGE_SIZE))
manual_transforms = transforms.Compose([_resize_to_vit_input, transforms.ToTensor()])

In [None]:
# Turn the image folders into DataLoaders.
import os

# A batch size of 32 (instead of the paper's 4096) so training fits on modest hardware.
BATCH_SIZE = 32

# create_dataloaders returns the train loader, the test loader and the class names.
train_dataloaders, test_dataloaders, class_names = data_setup.create_dataloaders(
    train_dir=train_directory,
    test_dir=test_directory,
    train_transform=manual_transforms,
    test_transform=manual_transforms,  # same resize + to-tensor pipeline for both splits
    batch_size=BATCH_SIZE,
    num_workers=os.cpu_count(),
)

                                                                                  

In [None]:
# Visualize a single image from the first training batch.
image_batch, label_batch = next(iter(train_dataloaders))
# Take the first sample of the batch.
image, label = image_batch[0], label_batch[0]
# Only a cell's *last* expression is displayed, so a bare `image.shape, label` mid-cell was
# silently discarded -- print the shape and label explicitly instead.
print(f"Image shape: {image.shape}, label: {label}")

# matplotlib wants (H, W, C); the tensor is channels-first, so move channels last before plotting.
plt.imshow(image.permute(1, 2, 0))
plt.title(class_names[label])
plt.axis(False)

In [None]:
# SPLIT DATA INTO PATCHES AND CREATING THE CLASS,POSITION AND PATCH EMBEDDING
"""
If we can represent the data in a good, learnable way (embeddings are learnable representations), chances are a learning algorithm will be able to perform well on it.

Starting with the patch embedding -- this means we will turn input images into a sequence of patches and then embed those patches.
An embedding is a learnable representation and is often a vector.

From the ViT paper :
  The standard transformer receives as input a 1D sequence of token embeddings. To handle 2D images, we reshape the image x of shape H x W x C into a sequence of flattened 2D patches x_p of shape N x (P^2 * C), where (H, W) is the resolution
  of the original image, C is the number of channels, (P, P) is the resolution of each image patch, and N = HW / P^2 is the resulting number of patches, which also serves as the effective input sequence length for the
  transformer.
  The transformer uses constant latent vector size D through all its layers, so we flatten the patches and map them to D dimensions with a trainable linear projection (Eq. 1). The output of this projection is the patch embeddings.

So,
  D is the size of the patch embeddings; the values of D for the differently sized ViT models can be found in Table 1 of the paper.
  The image starts as 2D with size H x W x C.
      * (H, W) is the resolution of the original image.
      * C is the number of color channels.
  The image gets converted to a sequence of flattened 2D patches with size N x (P^2 * C)
      * (P, P) is the resolution of each image patch (patch size).
      * N = HW / P^2 = (H * W) / (P * P) is the resulting number of patches, which also serves as the input sequence length for the transformer.
"""

In [None]:
# Calculating the patch embedding input and output shapes by hand.
# Create variables that mirror each term of the paper's notation. We use a patch size (P)
# of 16, since that is the best-performing ViT-Base variant.

# Example values
height = 224        # H
width = 224         # W
color_channels = 3  # C
patch_size = 16     # P

# The image must tile exactly into P x P patches for N to be a whole number.
assert height % patch_size == 0 and width % patch_size == 0, "Image dimensions must be divisible by patch size"

# N = HW / P^2 -- integer division keeps N exact (no float round-trip through int()).
number_of_patches = (height * width) // patch_size**2
print(f"Number of patches (N) with image height (H={height}), width (W={width}) and patch size (P={patch_size}): {number_of_patches}")



In [None]:
# Let's replicate the input and output shapes of the patch embedding layer
"""
Input : the image starts as 2D with size H x W x C.
Output : The image gets converted to a sequence of flattened 2D patches with size N x (P**2 * C)
"""

# A single image enters the layer as (H, W, C) ...
embedding_layer_input_shape = (height, width, color_channels)
# ... and leaves as N flattened patches, each of length P^2 * C.
flattened_patch_length = patch_size**2 * color_channels
embedding_layer_output_shape = (number_of_patches, flattened_patch_length)

for _caption, _shape in (("Input shape (single 2D image)", embedding_layer_input_shape),
                         ("Output shape (single 2D image flattened into patches)", embedding_layer_output_shape)):
    print(f"{_caption}: {_shape}")

"Ideal Input and Output shape for patch embedding layer"

In [None]:
# How do we create the patch embedding layer?
"Let's turn a single image into patches"

# Re-display the sample image: channels moved last for matplotlib, class name as the title.
_sample_hwc = image.permute(1, 2, 0)
plt.imshow(_sample_hwc)
plt.title(class_names[label])
plt.axis(False)

In [None]:
# Visualize the top row of the image, kept channels-last (H, W, C) so matplotlib can plot it.
image_permuted = image.permute(1, 2, 0)

patch_size = 16  # used for the rest of this cell
# First patch_size pixel rows, every column, all colour channels.
top_row = image_permuted[:patch_size, :, :]
plt.figure(figsize=(patch_size, patch_size))
plt.imshow(top_row);

In [None]:
# Turn the top row into patches.

# Hyperparameters: img_size must split exactly into patch_size-wide patches, so check
# that before deriving anything from it.
img_size = 224
patch_size = 16
assert img_size % patch_size == 0, "Image must be divisible by patch size"
# Integer division: the patch count is a whole number (true division printed "14.0").
num_patches = img_size // patch_size
print(f"Number of patches per row: {num_patches}\nPatch Size: {patch_size} pixels x {patch_size} pixels")

# One subplot per patch along the top row.
fig, axs = plt.subplots(nrows=1,
                        ncols=num_patches,
                        figsize=(num_patches, num_patches),
                        sharex=True,
                        sharey=True)

# Walk across the top row in steps of patch_size pixels.
for i, patch in enumerate(range(0, img_size, patch_size)):
    axs[i].imshow(image_permuted[:patch_size, patch:patch + patch_size, :])
    axs[i].set_xlabel(i + 1)  # 1-based patch number
    axs[i].set_xticks([])
    axs[i].set_yticks([])

In [None]:
# Now patchify the whole image.

# Hyperparameters: img_size must split exactly into patch_size x patch_size patches.
img_size = 224
patch_size = 16
assert img_size % patch_size == 0, "Image must be divisible by patch size"
num_patches = img_size // patch_size
# Implicitly concatenated literals: a backslash-continued f-string would embed the next
# line's indentation into the printed text.
print(f"Number of patches per row: {num_patches}"
      f"\nNumber of patches per column: {num_patches}"
      f"\nTotal patches: {num_patches*num_patches}"
      f"\nPatch size: {patch_size} pixels X {patch_size} pixels")

# A num_patches x num_patches grid of subplots, one per patch.
fig, axs = plt.subplots(nrows=num_patches,
                        ncols=num_patches,
                        figsize=(num_patches, num_patches),
                        sharex=True,
                        sharey=True)

for i, patch_height in enumerate(range(0, img_size, patch_size)):      # rows of patches
    for j, patch_width in enumerate(range(0, img_size, patch_size)):   # columns of patches
        axs[i, j].imshow(image_permuted[patch_height:patch_height + patch_size,
                                        patch_width:patch_width + patch_size,
                                        :])

        # Label rows/columns with 1-based patch indices, hide ticks, and let
        # label_outer() keep labels only on the outer edge of the grid.
        axs[i, j].set_ylabel(i + 1,
                             rotation="horizontal",
                             horizontalalignment="right",
                             verticalalignment="center")
        axs[i, j].set_xlabel(j + 1)
        axs[i, j].set_xticks([])
        axs[i, j].set_yticks([])
        axs[i, j].label_outer()

# fontsize is a keyword argument of suptitle, not part of the title text.
fig.suptitle(f"{class_names[label]} -> Patchified", fontsize=16)
plt.show()
        

In [None]:
# Now how do we turn each of these patches into embedding and convert them into a sequence?
"""
We have seen what an image looks like when it gets turned into patches, now let's start moving towards replicating the patch embedding layers with PyTorch.
The above operation of creating patches of an image is very similar to the convolutional operation in a CNN.

Referencing the ViT paper, section 3.1 mentions that the patch embedding is achievable with a convolutional neural network (CNN):

"Hybrid Architecture. As an alternative to raw image patches, the input sequence can be formed from feature maps of a CNN. In this hybrid model, the patch embedding projection E (Eq. 1)
is applied to patches extracted from a CNN feature map. As a special case, the patches can have spatial size 1x1, which means that the input sequence is obtained by simply flattening
the spatial dimensions of the feature map and projecting to the Transformer dimension. The classification input embedding and position embeddings are added as described above."

The "feature map" referred to is the output (activations) produced by the convolutional layer passing over a given image.
By setting the "kernel_size" and "stride" parameters of a torch.nn.Conv2d() layer equal to "patch_size", we can effectively get a layer
that splits our image into patches and creates a learnable embedding of each patch.

# For our image size of 224 and patch size 16:
* Input : (2D image) 224,224,3 -> (height, width, color_channels)
* Output: (flattened 2D patches) (196,768) -> (number of patches, embedding dimension)

# We can recreate these with :
* torch.nn.Conv2d(): for turning the image into patches of CNN feature maps.
* torch.nn.Flatten(): for flattening the spatial dimensions of the feature map.

# kernel_size = patch_size -> each convolutional kernel will be of size (patch_size, patch_size)
# stride = patch_size -> each step of the convolutional kernel will be patch_size
# Set in_channels=3 for 3 color_channels and out_channels=768 (the D value in Table 1 for ViT-Base; this is the embedding dimension, a learnable vector of size 768)
"""

In [None]:
from torch import nn

# ViT-Base patch size: every convolution kernel covers exactly one 16x16 patch.
patch_size = 16

# With kernel_size == stride == patch_size and no padding, the convolution visits each
# non-overlapping patch once and projects its 3 colour channels to a 768-dim embedding (D).
conv2d = nn.Conv2d(
    3,    # in_channels: RGB
    768,  # out_channels: embedding dimension D for ViT-Base
    kernel_size=patch_size,
    stride=patch_size,
    padding=0,
)

In [None]:
# Run the image through the patchifying convolution. unsqueeze(0) prepends a batch axis,
# turning the channels-first (C, H, W) tensor into a batch of one: (1, C, H, W).
batched_image = image.unsqueeze(0)
image_conv = conv2d(batched_image)
print(image_conv.shape)

In [None]:
# Plot five random convolutional feature maps.

import random

# Draw channel indices from however many feature maps the conv layer actually produced
# (image_conv is (batch, channels, H', W')) instead of hard-coding 768, so the indices
# stay in range if the embedding dimension changes.
num_maps_to_show = 5
num_feature_maps = image_conv.shape[1]
random_idx = random.sample(range(num_feature_maps), k=num_maps_to_show)  # distinct indices in [0, num_feature_maps)
print(f"Showing random convolutional feature maps from indexes: {random_idx}")

# One subplot per sampled feature map.
fig, axs = plt.subplots(nrows=1, ncols=num_maps_to_show, figsize=(12, 12))

for i, index in enumerate(random_idx):
    image_map = image_conv[:, index, :, :]  # a single output channel of the conv layer
    # detach() drops the autograd graph so the tensor can be converted to NumPy for matplotlib.
    axs[i].imshow(image_map.squeeze().detach().numpy())
    axs[i].set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])

# A single feature map in tensor form. Its grad_fn and requires_grad=True show that autograd
# tracks it, so gradients can flow back to the conv weights that produced it during training.
single_feature_map = image_conv[:, 0, :, :]
single_feature_map, single_feature_map.requires_grad


In [None]:
# Flatten only the spatial dimensions of the feature map with torch.nn.Flatten().
# Target: a 1D sequence of flattened 2D patches, (196, 768) -> (number_of_patches, embedding_dim),
# i.e. N x (P^2 * C). The conv output is (batch, embedding_dim, feature_map_height, feature_map_width),
# so merge dims 2 and 3 and leave the batch and embedding dims untouched.
flatten = torch.nn.Flatten(start_dim=2, end_dim=3)

In [None]:
# Put it all together.
# 1. Show the sample image (channels moved last for matplotlib).
plt.imshow(image.permute(1, 2, 0))
plt.title(class_names[label])
plt.axis(False)
print(f"Original image shape: {image.shape}")

# 2. image -> feature map, 3. flatten the spatial dims, 4. move the embedding axis last:
#    (batch, embedding_dim, N) -> (batch, N, embedding_dim), the desired patch-sequence layout.
image_conv = conv2d(image.unsqueeze(0))
image_flattened = flatten(image_conv)
image_flattened_reshaped = image_flattened.permute(0, 2, 1)

for _caption, _tensor in (("Image feature map shape", image_conv),
                          ("Flattened image feature map shape", image_flattened),
                          ("Patch embedding sequence shape", image_flattened_reshaped)):
    print(f"{_caption}: {_tensor.shape}")

In [None]:
# Creating the patch embedding layer as a PyTorch module: subclass nn.Module and
# package the conv -> flatten -> permute steps from above.

class patchembedding(nn.Module):
    """Turns a 2D image into a 1D sequence of learnable embedding vectors.

    Args:
        in_channels (int): Number of color channels of the input images. Defaults to 3.
        patch_size (int): Height/width of the square patches the image is cut into. Defaults to 16.
        embedding_dim (int): Size of the embedding each patch is projected to. Defaults to 768.

    Input is (batch, in_channels, height, width); output is
    (batch, number_of_patches, embedding_dim) with number_of_patches = (height * width) / patch_size**2.
    """

    def __init__(self,
                 in_channels: int = 3,
                 patch_size: int = 16,
                 embedding_dim: int = 768):
        super().__init__()
        # Remember this layer's own patch size so forward() validates against it rather than
        # whatever global `patch_size` happens to exist in the notebook (or a NameError if none).
        self.patch_size = patch_size

        # kernel_size == stride == patch_size: one embedding per non-overlapping patch.
        self.patcher = nn.Conv2d(in_channels=in_channels,
                                 out_channels=embedding_dim,
                                 kernel_size=patch_size,
                                 stride=patch_size,
                                 padding=0)

        # Merge the two spatial dims of the feature map into a single sequence dim.
        self.flatten = nn.Flatten(start_dim=2,
                                  end_dim=3)

    def forward(self, x):
        """Map (batch, C, H, W) images to (batch, N, embedding_dim) patch embeddings."""
        height, width = x.shape[-2], x.shape[-1]
        # Both spatial dims must tile exactly into patches (previously only the width was checked).
        assert height % self.patch_size == 0 and width % self.patch_size == 0, (
            f"Input image size must be divisible by patch size, image shape: {tuple(x.shape[-2:])}, "
            f"patch_size: {self.patch_size}")

        x_patched = self.patcher(x)            # (batch, embedding_dim, H/P, W/P)
        x_flattened = self.flatten(x_patched)  # (batch, embedding_dim, N)
        return x_flattened.permute(0, 2, 1)    # (batch, N, embedding_dim): embedding on the last dim

In [None]:
set_seeds.set_seeds()

# Patchify the sample image with the ViT-Base settings (P=16, D=768).
patchify = patchembedding(in_channels=3, patch_size=16, embedding_dim=768)

single_image_batch = image.unsqueeze(0)  # add a batch axis: (1, C, H, W)
print(f"Input image shape: {single_image_batch.shape}")
patch_embedded_image = patchify(single_image_batch)
print(f"Output image shape: {patch_embedded_image.shape}")

# Layer-by-layer summary of a freshly constructed default patch embedding layer.
random_input = (1, 3, 224, 224)  # an input *size* handed to torchinfo, not a tensor
summary(patchembedding(),
        input_size=random_input,
        col_names=["input_size", "output_size", "num_params", "trainable"],
        col_width=20,
        row_settings=["var_names"])


In [None]:
# CREATING THE CLASS TOKEN EMBEDDING 
"""
X(CLASS) FROM Eq 1 OF THE ViT PAPER.
FROM THE SECOND PARAGRAPH OF SECTION 3.1 OF THE ViT PAPER, WE SEE:
   " Similar to BERT's [class] token, we prepend a learnable embedding to the sequence of embedded patches (z_0^0 = x_class), whose state at the output of the Transformer encoder (z_L^0) serves as the image representation y (Eq. 4)."

So we need to "prepend a learnable embedding to the sequence of embedded patches"
"""

In [None]:
# We need to create a learnable embedding in the shape of the embedding_dimensions(D) and then add it to the number_of_patches dimension.

"""
pseudocode:
 patch_embedding = [image_patch_1, image_patch_2, ...]
 class_token = learnable_embedding
 patch_embedding_with_class_token = torch.cat((class_token, patch_embedding), dim=1)

The torch.cat() happens on dim=1 (the number_of_patches / sequence dimension).
To create a learnable embedding:
 we'll get the batch size and embedding dimension and then create a tensor (e.g. with torch.randn() or torch.ones()) in the shape [batch_size, 1, embedding_dimension].

We'll make the tensor learnable by passing it to nn.Parameter() with requires_grad=True.
"""

In [None]:
# Batch size and embedding dimension come from the patch embedding output,
# which is (batch_size, number_of_patches, embedding_dimension).
batch_size = patch_embedded_image.shape[0]
embedding_dimension = patch_embedded_image.shape[-1]

# The class token is a learnable parameter as wide as the embedding dimension:
# [batch_size, number_of_tokens=1, embedding_dimension]; nn.Parameter makes it trainable.
# NOTE(review): a full model would normally learn a single (1, 1, D) token and expand it over
# the batch; sizing it by batch_size only works here because the demo batch holds one image.
class_token = nn.Parameter(
    torch.randn(batch_size, 1, embedding_dimension),
    requires_grad=True,
)

# Peek at the first 10 values, then report the shape.
print(class_token[:, :, :10])
print(f"Class token shape: {class_token.shape} -> [batch_size, number_of_tokens, embeddding dimension]")


In [None]:
# Prepend the class token to the sequence of image patches. Concatenating along dim=1
# (the token/sequence axis) turns (batch, N, D) into (batch, N + 1, D), class token first.
patch_embedded_image_with_class_embedding = torch.cat([class_token, patch_embedded_image], dim=1)

print(patch_embedded_image_with_class_embedding)
print(f"Sequence of patch embedding with class token prepended shape: {patch_embedded_image_with_class_embedding.shape} -> [ batch_size, number_of_patches, embedding_dimension]")

In [None]:
# CREATING THE POSITION EMBEDDING

"""
From the ViT paper section 3.1:
    " Position embeddings are added to the patch embeddings to retain positional information. We use standard learnable 1D position embeddings, since we have not observed significant performance gains from using more advanced 2D-aware
     position embeddings. The resulting sequence of embedding vectors serves as input to the encoder."

By 'retain positional information', the authors mean they want the architecture to know what "order" the patches come in.
This positional information is important when considering what's in an image (without positional information, a flattened sequence could be seen as having no order, and thus no patch relates to any other patch).
"""

# From Eq 1 we have :
"""
E(pos) should have shape (N+1) x D
where,
* N = HW/P**2 is the resulting number of patches, which also serves as the effective input sequence length for the transformer (the +1 is for the prepended class token).
* D is the size of the patch embeddings; the different values of D can be found in Table 1 (embedding dimension).
"""

In [None]:
# N = HW / P^2 patches, plus one extra sequence slot for the prepended class token.
num_of_patches = int((height * width) / patch_size**2)
sequence_length = num_of_patches + 1

# Embedding dimension D, read from the class-token-augmented sequence (batch, N + 1, D).
embedding_dimension = patch_embedded_image_with_class_embedding.shape[2]

# Learnable 1D position embedding: one D-dim vector per sequence position; the leading
# dimension of 1 lets it broadcast over the batch when added.
positional_embedding = nn.Parameter(torch.randn(1, sequence_length, embedding_dimension),
                                    requires_grad=True)
print(positional_embedding[:, :10, :10])
print(f"Positional embedding shape: {positional_embedding.shape} -> [batch_size, number of patches, embedding dimension]")

# Position information is injected by element-wise addition, not concatenation.
patch_and_position_embedding = patch_embedded_image_with_class_embedding + positional_embedding
print(patch_and_position_embedding)
print(f" Patch embedding, class token prepend and positional embeddings added shape: {patch_and_position_embedding.shape}")

In [None]:
# Putting it all together: image -> patch embedding -> + class token -> + position embedding.

set_seeds.set_seeds()
patch_size = 16  # P

# The original image and its spatial size (tensor is channels-first: C, H, W).
print(f"Image tensor shape: {image.shape}")
height, width = image.shape[1], image.shape[2]

# Add a batch axis so the image can go through the layers.
x = image.unsqueeze(0)
print(f"Image tensor with batch dimension shape: {x.shape}")

# Patch embedding: (1, C, H, W) -> (1, N, D).
patch_embedded_layer = patchembedding(in_channels=3, patch_size=patch_size, embedding_dim=768)
patch_embedding = patch_embedded_layer(x)
print(f"Patch embedding shape: {patch_embedding.shape}")

# Class token sized from the patch embedding's batch and embedding dims ...
batch_size, embedding_dimension = patch_embedding.shape[0], patch_embedding.shape[-1]
class_token = nn.Parameter(torch.ones(batch_size, 1, embedding_dimension), requires_grad=True)
print(f"Class token embedding shape: {class_token.shape}")

# ... prepended along the sequence axis.
patch_embedding_class_token = torch.cat((class_token, patch_embedding), dim=1)
print(f"Patch embedding with class token shape: {patch_embedding_class_token.shape}")

# Position embedding: one learnable vector per position (N patches + 1 class token), broadcast over the batch.
number_of_patches = int((height*width) / patch_size**2)
positional_embedding = nn.Parameter(torch.ones(1, number_of_patches + 1, embedding_dimension), requires_grad=True)
print(f"Position embedding shape: {positional_embedding.shape}")

patch_and_position_embedding = patch_embedding_class_token + positional_embedding
print(f"Patch and position embedding shape: {patch_and_position_embedding.shape}")

In [None]:
# EQUATION 2 : MULTI-HEAD ATTENTION(MSA)
"""
The transformer encoder section is divided into two parts: the first being equation 2 and the second being equation 3

Equation 2 states:
 "A multi-head self-attention (MSA) layer wrapped in a LayerNorm (LN) layer with a residual connection (the input to the layers gets added to the output of the layer)"
Equation 2 refers to the MSA block: z'_l = MSA(LN(z_(l-1))) + z_(l-1), for l = 1...L.

We can replicate these layers by using :
* Multi-Head Self Attention(MSA) - torch.nn.MultiheadAttention()
* Norm (LN or LayerNorm) - torch.nn.LayerNorm()
* Residual connection - add the input to output
"""

# Replicating Equation 2 with Pytorch layers :
"""
1. Create a class called MultiheadSelfAttentionBlock that inherits from nn.Module.
2. Initialize the class with hyperparameters from Table 1 of the ViT paper for the ViT-Base model.
3. Create a layer normalization (LN) layer with torch.nn.LayerNorm() with the normalized_shape parameter the same as our embedding dimension (D from Table 1).
4. Create a multi-head attention (MSA) layer with the appropriate embed_dim, num_heads, dropout, and batch_first parameters.
5. Create a forward() method for our class passing the inputs through the LN layer and the MSA layer.
"""

In [None]:
# Pre-norm multi-head self-attention block (Eq. 2 of the ViT paper, without the residual).
class MultiheadSelfAttentionBlock(nn.Module):
    """LayerNorm followed by multi-head self-attention.

    The residual connection of Eq. 2 is *not* added here; the caller is responsible for it.

    Args:
        embedding_dim: Hidden size D (768 for ViT-Base, Table 1).
        num_heads: Number of attention heads (12 for ViT-Base, Table 1).
        attn_dropout: Dropout inside the attention layer (0 -- no dropout is used in the paper for the MSA block).
    """

    def __init__(self,
                 embedding_dim: int = 768,
                 num_heads: int = 12,
                 attn_dropout: float = 0):
        super().__init__()
        # Normalise over the embedding dimension before attention.
        self.layer_norm = nn.LayerNorm(normalized_shape=embedding_dim)
        # batch_first=True: inputs are (batch, sequence, embedding).
        self.multihead_attn = nn.MultiheadAttention(embed_dim=embedding_dim,
                                                    num_heads=num_heads,
                                                    dropout=attn_dropout,
                                                    batch_first=True)

    def forward(self, x):
        normed = self.layer_norm(x)
        # Self-attention: query, key and value are all the normalised input. Attention weights
        # are not requested because only the layer output is used.
        attn_output, _ = self.multihead_attn(query=normed, key=normed, value=normed, need_weights=False)
        return attn_output

# Run the combined patch + position embedding through a ViT-Base sized MSA block
# (D=768, 12 heads, both from Table 1).
multihead_self_attentionblock = MultiheadSelfAttentionBlock(embedding_dim=768, num_heads=12)
patched_image_msa_block = multihead_self_attentionblock(patch_and_position_embedding)

for _caption, _tensor in (("Input shape of MSA block", patch_and_position_embedding),
                          ("Output shape of MSA block", patched_image_msa_block)):
    print(f"{_caption}: {_tensor.shape}")

In [None]:
# EQUATION 3 : MULTILAYER PERCEPTRON (MLP)
""" It is referred to as the MLP block of the transformer encoder:
The term MLP is quite broad as it can refer to almost any combination of multiple layers:
Generally it follows the pattern:
linear layer -> non-linear layer -> linear layer -> non-linear layer
In the case of the ViT paper, the MLP structure is defined in section 3.1 as:
" The MLP contains 2 linear layers with a GELU non-linearity."
GELU non-linearity refers to the Gaussian Error Linear Unit activation function.
Also, Appendix B.1 says:
"Table 3 summarizes our training setups for our different models...Dropout, when used, is applied after every dense layer except for the qkv-projections and directly after adding positional- to patch embeddings"
This means that every linear layer (or dense layer) in the MLP block is followed by a dropout layer, the value of which can be found in Table 3 of the ViT paper (ViT-Base, dropout=0.1).


Knowing this the structure of our MLP is :
layer norm -> linear layer -> non-linear layer -> dropout -> linear layer -> dropout

With the hyper parameters values for the linear layer from Table 1(MLP size is the number of hidden units between the linear layers and the hidden size D is the output of the MLP block).
"""

# Replicating Equation 3 with Pytorch layers :
"""
1. Create a  class called MultilayerPerceptronBlock that inherits from nn.Module.
2. Initialize the class with hyperparameters from Table 1 and Table 3 of the ViT paper from the ViT-Base model.
3. Create a layer normalization (LN) layer with torch.nn.LayerNorm() with the normalized_shape parameter the same as our embedding dimension (D from Table 1).
4. Create a sequential series of MLP layer(s) using torch.nn.Linear(), torch.nn.Dropout() and torch.nn.GELU() with appropriate hyperparameter values from Table 1 and Table 3.
5. Create a forward() method for our class passing the inputs through the LN layer and the MLP layer(s).
"""

In [None]:
# Pre-norm MLP block (Eq. 3 of the ViT paper, without the residual).
class MultilayerPerceptronBlock(nn.Module):
    """LayerNorm followed by a two-layer GELU MLP with dropout after each dense layer.

    Structure: LayerNorm -> Linear(D -> mlp_size) -> GELU -> Dropout -> Linear(mlp_size -> D) -> Dropout.
    The residual connection of Eq. 3 is not added here.

    Args:
        embedding_dim: Hidden size D (768 for ViT-Base, Table 1).
        mlp_size: Hidden units between the two linear layers (3072 for ViT-Base, Table 1).
        dropout: Dropout applied after each dense layer (0.1 for ViT-Base, Table 3).
    """

    def __init__(self,
                 embedding_dim: int = 768,
                 mlp_size: int = 3072,
                 dropout: float = 0.1):
        super().__init__()
        self.layer_norm = nn.LayerNorm(normalized_shape=embedding_dim)

        # Expand to mlp_size, then project back to the embedding dimension.
        expand = nn.Linear(in_features=embedding_dim, out_features=mlp_size)
        project_back = nn.Linear(in_features=mlp_size, out_features=embedding_dim)
        self.mlp = nn.Sequential(expand, nn.GELU(), nn.Dropout(p=dropout), project_back, nn.Dropout(p=dropout))

    def forward(self, x):
        return self.mlp(self.layer_norm(x))


# Feed the MSA block output through a ViT-Base sized MLP block (Table 1 / Table 3 hyperparameters).
mlp_block = MultilayerPerceptronBlock(embedding_dim=768, mlp_size=3072, dropout=0.1)
patched_image_mlp_block = mlp_block(patched_image_msa_block)

for _caption, _tensor in (("Input shape ", patched_image_msa_block),
                          ("Output shape ", patched_image_mlp_block)):
    print(f"{_caption}: {_tensor.shape}")


In [None]:
# CREATING THE TRANSFORMER ENCODER OF ViT-Base BY COMBINING EQUATION 2 AND 3 AND THE CODE BLOCKS FROM ABOVE 
"""
In deep learning, an 'encoder' (e.g. the encoder half of an autoencoder) generally refers to a stack of layers that encodes an input (turns it into some form of numerical representation).
The Transformer Encoder will encode our patched image embedding into a learned representation using a series of alternating layers of MSA blocks and MLP blocks, as per section 3.1 of the ViT paper:
  ' The Transformer encoder consists of alternating layers of multiheaded self-attention (MSA) and MLP blocks. LayerNorm (LN) is applied before every block, and residual connections after every block.'

Residual connections: also called skip connections, are achieved by adding a layer's input to its subsequent output.
In the case of the ViT architecture, residual connections mean the input of the MSA block is added back to the output of the MSA block before it passes to the MLP block.

Pseudocode :
 x_input -> MSA block -> [MSA block_output + x_input] -> MLP block -> [MLP block_output + MSA block_output + X_input] -> ...

MAIN IDEA : of residual connections is that they prevent weight values and gradient updates from getting too small and thus allow deeper networks and in turn allow deeper representations to be learned.

"""

## Creating a transformer encoder by combining our custom made layers :
"""
1. Create a  class called TransformerEncoderBlock that inherits from nn.Module.
2. Initialize the class with hyperparameters from Table 1 and Table 3 of the ViT paper from the ViT-Base model.
3. Instantiate MSA block for equation 2 using MultiheadSelfAttentionBlock from above with appropriate parameters.
4. Instantiate MLP block for equation 3 using our MultilayerPerceptronBlock from above with appropriate parameter.
5. Create a forward() method for our TransformerEncoderBlock class.
6. Create a residual connection for MSA block.
7. Create a residual connection for MLP block.
"""

In [None]:
# One ViT Transformer encoder layer built from the custom MSA and MLP blocks.
class TransformerEncoderBlock(nn.Module):
    """Eq. 2 (MSA block) followed by Eq. 3 (MLP block), each wrapped in a residual connection.

    Args:
        embedding_dim: Hidden size D (768 for ViT-Base).
        num_heads: Number of attention heads (12 for ViT-Base).
        mlp_size: Hidden units of the MLP block (3072 for ViT-Base).
        mlp_dropout: Dropout inside the MLP block (0.1 for ViT-Base).
        attn_dropout: Dropout inside the attention layer (0 by default).
    """

    def __init__(self,
                 embedding_dim: int = 768,
                 num_heads: int = 12,
                 mlp_size: int = 3072,
                 mlp_dropout: float = 0.1,
                 attn_dropout: float = 0):
        super().__init__()
        self.msa_block = MultiheadSelfAttentionBlock(embedding_dim=embedding_dim,
                                                     num_heads=num_heads,
                                                     attn_dropout=attn_dropout)
        self.mlp_block = MultilayerPerceptronBlock(embedding_dim=embedding_dim,
                                                   mlp_size=mlp_size,
                                                   dropout=mlp_dropout)

    def forward(self, x):
        # Residual (skip) connections: each block's input is added back onto its output.
        after_attention = x + self.msa_block(x)
        return after_attention + self.mlp_block(after_attention)

# Build one encoder block with the ViT-Base defaults and inspect it on a single
# sequence of 197 tokens (196 patches + 1 class token for a 224x224 image, patch size 16),
# each of size 768.
transformer_encoder = TransformerEncoderBlock()

encoder_summary_columns = ["input_size", "output_size", "num_params", "trainable"]
summary(model=transformer_encoder,
        input_size=(1, 197, 768),
        col_names=encoder_summary_columns,
        col_width=20,
        row_settings=["var_names"])


In [None]:
# Creating a Transformer Encoder with PyTorch's Transformer Layers
"We can recreate the TransformerEncoderBlock using torch.nn.TransformerEncoderLayer() and setting its hyperparameters."

# PyTorch's built-in counterpart of our TransformerEncoderBlock, configured with
# the ViT-Base values from Table 1 / Table 3 of the ViT paper.
torch_transformer_encoder = torch.nn.TransformerEncoderLayer(
    d_model=768,           # hidden size D (Table 1, ViT-Base)
    nhead=12,              # number of attention heads (Table 1)
    dim_feedforward=3072,  # MLP size (Table 1)
    dropout=0.1,           # Table 3 dropout; NOTE: PyTorch also passes this same rate to the
                           # self-attention module, whereas our custom block uses attn_dropout=0
    activation="gelu",     # GELU non-linearity, as in the ViT MLP
    batch_first=True,      # tensors are laid out as (batch, sequence, embedding)
    norm_first=True,       # LayerNorm before the MSA/MLP sub-blocks (pre-LN), as in ViT
)

# Compare the built-in layer's shapes and parameter count with our custom block
torch_layer_summary_columns = ["input_size", "output_size", "num_params", "trainable"]
summary(model=torch_transformer_encoder,
        input_size=(1, 197, 768),
        col_names=torch_layer_summary_columns,
        col_width=20,
        row_settings=["var_names"])


In [None]:
# PUTTING IT ALL TOGETHER TO CREATE ViT ARCHITECTURE:
""" 
We're going to combine all of the blocks that we have created to replicate the full ViT architecture.
From the patch and positional embeddings, through the Transformer Encoder(s), to the MLP head.

We'll add Equation 4 into our overall ViT architecture class.
All we need is a torch.nn.LayerNorm() and a torch.nn.Linear() layer to convert the 0th index of the Transformer Encoder logit outputs to the target number of classes we have.

To create the full architecture we'll also need to stack a number of TransformerEncoderBlocks on top of each other; this can be done by
 passing a list of them to torch.nn.Sequential().

Focus on ViT-Base hyperparameters from Table 1 but make the code adaptable to other ViT variants.
"""

# Creating the ViT architecture:
"""
1. Create a class called ViT that inherits from torch.nn.Module.
2. Initialize the class with hyperparameters from Table 1 and Table 3 of the ViT paper for the ViT-Base.
3. Make sure the image size is divisible by the patch size (the image should be split into even patches)
4. Calculate the number of patches using the formula N = HW/P**2, where H is height, W is width and P is the patch size.
5. Create a learnable class embedding token(equation 1) as done above.
6. Create a learnable position embedding vector(equation 1) as done above.
7. Setup the embedding dropout layer 
8. Create the patch embedding layer using the 'patchembedding' class above.
9. Create a series of Transformer Encoder blocks by passing list of TransformerEncoderBlocks created above to torch.nn.Sequential.
10. Create the MLP head (also called the classifier head of Equation 4) by passing a torch.nn.LayerNorm() (LN) layer and a torch.nn.Linear(out_features=num_classes) layer
(where num_classes is the target number of classes) to torch.nn.Sequential.
11. Create the forward method that accepts the input.
12. Get the batch size of the input(the first dimension of the shape).
13. Create the patch embedding using the layer created in step 8.
14. Create the class token embedding using the parameter created in step 5 and expand it across the number of batches found in step 12 using torch.Tensor.expand().
15. Concatenate the class token embedding created in step 14 to the first dimension of the patch embedding created in step 13 using torch.cat().
16. Add the position embedding created in step 6 to the patch and class token embedding created in step 15.
17. Pass the patch and position embedding through the dropout layer created in step 7.
18. Pass the patch and position embedding from step 17 through the Transformer Encoder layers created in step 9 (Equations 2 & 3).
19. Pass index 0 of the output of the stack of transformer Encoder layers from step 18 through the classifier head created in step 10(Equation 4).
20. Done.
"""

In [None]:
# Full Vision Transformer assembled from the blocks defined above
class ViT(nn.Module):
    """Vision Transformer; the defaults reproduce the ViT-Base hyperparameters.

    Args:
        img_size: Height and width of the square input image; must be divisible by ``patch_size``.
        in_channels: Number of channels in the input image.
        patch_size: Side length P of each square patch.
        num_transformer_layers: Number of TransformerEncoderBlocks stacked in the encoder.
        embedding_dim: Hidden size D of every token embedding.
        mlp_size: Hidden size of each encoder block's MLP.
        num_heads: Attention heads per encoder block.
        attn_dropout: Dropout probability inside each MSA block.
        mlp_dropout: Dropout probability inside each MLP block.
        embedding_dropout: Dropout applied to the class + patch + position embeddings.
        num_classes: Number of output logits (1000 matches ImageNet).

    Raises:
        AssertionError: if ``img_size`` is not a multiple of ``patch_size``.
    """

    def __init__(self,
                 img_size: int = 224,
                 in_channels: int = 3,
                 patch_size: int = 16,
                 num_transformer_layers: int = 12,
                 embedding_dim: int = 768,
                 mlp_size: int = 3072,
                 num_heads: int = 12,
                 attn_dropout: float = 0,
                 mlp_dropout: float = 0.1,
                 embedding_dropout: float = 0.1,
                 num_classes: int = 1000):
        super().__init__()

        # The image has to split into a whole number of patches along each side.
        assert img_size % patch_size == 0, f"Image size should be divisible by patch size, image size: {img_size}, patch size: {patch_size}."

        # N = H*W / P^2 with H == W == img_size.
        self.num_patches = img_size ** 2 // patch_size ** 2

        # NOTE: attribute creation order is kept fixed so random initialisation
        # (and parameter ordering) matches under a given seed.

        # Learnable class token, prepended to the patch sequence (Equation 1).
        self.class_embedding = nn.Parameter(torch.randn(1, 1, embedding_dim),
                                            requires_grad=True)

        # Learnable position embedding: one vector per patch plus one for the class token.
        self.position_embedding = nn.Parameter(torch.randn(1, self.num_patches + 1, embedding_dim),
                                               requires_grad=True)

        self.embedding_dropout = nn.Dropout(p=embedding_dropout)

        self.patch_embedding = patchembedding(in_channels=in_channels,
                                              patch_size=patch_size,
                                              embedding_dim=embedding_dim)

        # Encoder stack (Equations 2 & 3): the list is unpacked with `*` into
        # nn.Sequential's positional arguments, one block per layer.
        encoder_blocks = [
            TransformerEncoderBlock(embedding_dim=embedding_dim,
                                    num_heads=num_heads,
                                    mlp_size=mlp_size,
                                    mlp_dropout=mlp_dropout,
                                    attn_dropout=attn_dropout)
            for _ in range(num_transformer_layers)
        ]
        self.transformer_encoder = nn.Sequential(*encoder_blocks)

        # Classifier head (Equation 4): LayerNorm followed by a projection to the class logits.
        self.classifier = nn.Sequential(
            nn.LayerNorm(normalized_shape=embedding_dim),
            nn.Linear(in_features=embedding_dim, out_features=num_classes),
        )

    def forward(self, x):
        """Classify a batch of images via the encoded class token.

        ``x`` is presumably (batch, in_channels, img_size, img_size) — TODO confirm
        against ``patchembedding``. Only dimension 0 (batch size) is read here.
        """
        n_samples = x.shape[0]

        # One class token per sample; expand() with -1 keeps those dimensions'
        # sizes unchanged and returns a view instead of copying.
        cls_tokens = self.class_embedding.expand(n_samples, -1, -1)

        tokens = self.patch_embedding(x)
        # Prepend the class token along dimension 1 (the token sequence).
        tokens = torch.cat((cls_tokens, tokens), dim=1)
        tokens = self.embedding_dropout(self.position_embedding + tokens)

        encoded = self.transformer_encoder(tokens)

        # Only the class token (sequence index 0) of every sample reaches the classifier.
        return self.classifier(encoded[:, 0])
                                                                           
                 

In [None]:
# Example: expand() turns a single (1, 1, 768) class token into one copy per batch item
batch_size = 32
class_token_single = nn.Parameter(data=torch.randn(1, 1, 768))
class_token_expanded = class_token_single.expand(batch_size, -1, -1)
print(class_token_expanded.shape)

# Reseed with the project helper, then push one random 3x224x224 "image" through a fresh ViT
set_seeds.set_seeds()
random_image_tensor = torch.randn(1, 3, 224, 224)
vit = ViT(num_classes=len(class_names))  # one output logit per class in our dataset
vit(random_image_tensor)

In [None]:
# Visual summary of the full ViT for a batch of 32 images of shape 3x224x224
vit_summary_columns = ["input_size", "output_size", "num_params", "trainable"]
summary(model=vit,
        input_size=(32, 3, 224, 224),
        col_names=vit_summary_columns,
        col_width=20,
        row_settings=["var_names"])

In [None]:
# Training the ViT model

# Import engine module from Scripts
from Scripts import engine

# Adam configured with the ViT paper's training hyperparameters:
#   lr=3e-3, weight_decay=0.3 -> Table 3 values for ViT-* trained on ImageNet-1k
#   betas=(0.9, 0.999)        -> PyTorch's defaults, also the values quoted in section 4.1
optimizer = torch.optim.Adam(params=vit.parameters(),
                             lr=3e-3,
                             betas=(0.9, 0.999),
                             weight_decay=0.3)

# Multi-class classification: cross-entropy computed on the classifier's raw logits
loss_fn = torch.nn.CrossEntropyLoss()

# Reseed via the project helper, then train with the engine module and keep whatever
# results engine.train returns
set_seeds.set_seeds()
results = engine.train(model=vit,
                       train_dataloader=train_dataloaders,
                       test_dataloader=test_dataloaders,
                       loss_func=loss_fn,
                       optimizer=optimizer,
                       epochs=10,
                       device=device)
                       


In [None]:
from Scripts import plot_loss_curves

# Plot our ViT model's loss curves.
# The other helpers imported from Scripts are modules that contain a function of the
# same name (called as set_seeds.set_seeds() and download_data.download_data(...)).
# `plot_loss_curves` is therefore most likely a module too, and calling it directly
# would raise "TypeError: 'module' object is not callable". Use the inner function
# when it exists; otherwise fall back to calling the imported object itself.
plot_loss_curves_fn = getattr(plot_loss_curves, "plot_loss_curves", plot_loss_curves)
plot_loss_curves_fn(results)