# Assignment 0 Part 2 - Vision Transformers

Vision Transformers (ViTs) represent a significant shift in the field of computer vision, building on the transformative success of Transformers in natural language processing (NLP). The concept of Transformers was first introduced in the seminal 2017 paper, "Attention is All You Need," which revolutionized machine translation and other NLP tasks by employing self-attention mechanisms to process sequences of text. Inspired by this success, researchers have recently extended the Transformer architecture to the domain of computer vision, giving rise to Vision Transformers.

Unlike traditional convolutional neural networks (CNNs), which rely heavily on convolutions and leverage spatial invariance to process images, Vision Transformers take a different approach. Instead of treating an image as a grid of pixels, ViTs divide the image into a sequence of smaller patches, each of which is treated as a token—analogous to words in a sentence. These patches are then processed by the Transformer architecture, enabling the model to capture relationships between different parts of the image without relying on convolutions.

At the core of this architecture is the self-attention mechanism, which allows the model to weigh the importance of different patches relative to each other, effectively modeling the dependencies between various regions of the image. This mechanism is remarkably similar to the self-attention used in language models (LLMs) and other Transformer-based architectures designed for text, demonstrating the versatility and power of the Transformer framework.

One of the key advantages of Vision Transformers is their compatibility with transfer learning. Transfer learning allows you to leverage pretrained weights from a model trained on a large dataset, such as ImageNet, and adapt them to your specific task. This can significantly reduce the time and computational resources required, as you don't need to train a model from scratch. In this assignment, you'll explore how to load a pretrained Vision Transformer model and copy over the weights, enabling you to benefit from the knowledge encoded in the original model. This exercise will give you hands-on experience with one of the most exciting innovations in computer vision today, highlighting both the power of Transformers and the utility of transfer learning.

Here is a list of resources to help you with this part:
- [GPT from Scratch](https://youtu.be/kCc8FmEb1nY?si=h8WFNBNU6tFs7LkM) - This is all you need to start grasping Transformers. While this lecture is centered around GPT, very few modifications are needed to extend this to Vision Transformers.
- [learnpytorch.io](https://www.learnpytorch.io/) - If you want a textual guide and something higher-level for starters, look into Chapter 08.
- [Pinecone's Intro to Vision Transformers](https://www.pinecone.io/learn/series/image-search/vision-transformers/) - A good conceptual writeup.
- [An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale](https://arxiv.org/abs/2010.11929) - The official Vision Transformer paper.

In [None]:
from typing import Callable, Optional, Tuple, Type, Union
from functools import partial
from IPython.display import clear_output

from tqdm import tqdm
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision
from torchvision import datasets, transforms
import timm

import warnings
warnings.filterwarnings('ignore')

## Task 0. Getting our Data

We will be using CIFAR-100 for this part. This is a much harder dataset to do well on than CIFAR-10: there are ten times as many classes but the same total number of images, so the model sees far fewer examples of each class from which to learn its representation.

You can load the dataset and wrap it in `DataLoader` instances in the same fashion as before. Be careful with the transformations that are applied: normalize the images in the same way as in the previous notebook, as this will be critical later.

In [None]:
# Define some variables pertaining to our dataset
IMAGE_SIZE = 224
BATCH_SIZE = 16
PATCH_SIZE = 16

# Load in the CIFAR-100 dataset, create the DataLoaders, and visualize a grid of images (same as before)
raise NotImplementedError

train_ds = None
val_ds = None

train_dl = None
val_dl = None

In [None]:
# Get an instance from the dataset - returns a tuple of (input, ground truth)
x, y = train_ds[0]
print(x.shape)

## Task 1. Embedding an Input Image and adding Positional Information

(There is no input required from you for this section, but understand everything here properly)

To use a Transformer for images, the image must first be converted into a format that the model can process, similar to how text is tokenized in NLP tasks. This involves dividing the image into smaller, fixed-size patches and then embedding these patches into vectors.

For example, an image of size 224x224 pixels can be divided into 16x16 patches, resulting in a grid of 14x14 patches. Each patch is then embedded into a vector using a convolutional layer, which extracts features like edges, textures, and colors. This transformation effectively turns a 2D image into a 1D sequence of patches, where each patch contains a small portion of the image's data. These patch embeddings are treated as tokens in a sequence, and are fed into the Transformer model, which applies self-attention to learn relationships between different parts of the image.

So, to convert our image into a sequence of patch embeddings, we need to:

1. Generate fixed-size patches from the image.
2. Embed these patches into a vector space.

We can achieve both of these steps in one go using `nn.Conv2d`.

If the convolution uses a kernel of size $(k, k)$ and a stride of $k$, it effectively breaks the input image into non-overlapping patches of size $k \times k$. The kernel, also known as the filter, slides across the image, covering a different patch of the input at each step. At each position, the kernel takes a dot product between the filter weights and the corresponding input pixels (summed over all input channels) and adds a bias; note that `nn.Conv2d` itself applies no activation function. The output from each position becomes one element of the resulting feature map, and with `out_channels` filters each patch is mapped to a vector of that length.
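To make this concrete, here is a minimal sketch (toy sizes chosen for illustration, not the notebook's hyperparameters) suggesting that a convolution with `kernel_size == stride` matches explicitly cutting out non-overlapping patches with `F.unfold` and applying one shared linear map to each of them.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
k, embed_dim = 4, 5                      # toy patch size and embedding dim (assumptions)
conv = nn.Conv2d(3, embed_dim, kernel_size=k, stride=k)
img = torch.randn(2, 3, 8, 8)            # 8x8 images -> a 2x2 grid of 4x4 patches

# Path 1: strided convolution, then flatten the patch grid into a sequence
via_conv = conv(img).flatten(start_dim=2).transpose(1, 2)          # (2, 4, 5)

# Path 2: cut out the patches explicitly, then reuse the conv weights as a linear map
patches = F.unfold(img, kernel_size=k, stride=k).transpose(1, 2)   # (2, 4, 3*k*k)
via_linear = patches @ conv.weight.view(embed_dim, -1).T + conv.bias

same = torch.allclose(via_conv, via_linear, atol=1e-5)
print(via_conv.shape, same)
```

In other words, the "patchifier" is just a linear projection of each flattened patch, written as a convolution for convenience.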

In [None]:
# Let's explore how we end up with 14x14 patches with the hyperparameters defined so far
number_of_patches = int((IMAGE_SIZE**2) / PATCH_SIZE**2)
print(f"Using {PATCH_SIZE=}, we get {number_of_patches=} for each channel.")
print(f"This is what we expected as {14*14=}.")

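# Now if we consider the output as a long sequence of patches, we can compute the expected output shape
# (note: this plain `view` only illustrates the shape; it does not group pixels into true spatial patches)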
print(f"Original input shape: {x.shape}")

patchified = x.contiguous().view(number_of_patches, -1)
print(f"Transformed input shape: {patchified.shape}")

In [None]:
# Create the patchifier
patchifier = nn.Conv2d(in_channels=3,
                       out_channels=768, # as we computed above
                       kernel_size=PATCH_SIZE,
                       stride=PATCH_SIZE,
                       padding=0                       
)

# Transform a batch of inputs to see how it works
x, _ = next(iter(train_dl))
out = patchifier(x)

print(f"Input shape: {x.shape}")
print(f"Patchified shape: {out.shape}")

Quick note about the shape:

$(16, 768, 14, 14) := \text{(batch size, embedding dim, number of patches vertically, number of patches horizontally)}$

Since we want to treat this as a sequence, i.e. losing the 2D structure, we can simply flatten this along the last two axes.

We will also transpose the tensor so that the channel/feature axis comes last; this is just convention.

In [None]:
patch_emb = out.flatten(start_dim=2).transpose(1, 2) # NCHW -> NLC
print(f"Final shape: {patch_emb.shape}")

Before moving forward, another important point to note is how we will incorporate a "CLS token" or "classification token" later on for our task of image classification.

This is a technique borrowed from [BERT](https://arxiv.org/abs/1810.04805), where you have a learnable embedding meant to represent the entire input sequence. In the context of the Vision Transformer, the CLS token is a special token added to the sequence of patch embeddings. It serves as a summary or a representation of the entire image. The model learns this token's embedding along with the patch embeddings during training.

At the end of the transformer layers, the CLS token's final embedding is used as the input to a classification head, typically a fully connected layer, to predict the class label. This approach allows the model to consider the entire image's context when making a classification decision, leveraging the self-attention mechanism to gather information from all patches into this single, informative vector.

<div style="text-align: center;">
  <img src="./vit.png" alt="vit-layout" style="width:50%;">
</div>
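As a rough sketch of how the CLS token is read back out at the end, the snippet below uses a random tensor as a stand-in for the encoder output and a hypothetical linear `head`; both are assumptions for illustration only.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
batch, num_patches, dim, num_classes = 4, 196, 768, 100

# Stand-in for the output of the transformer layers: CLS token at index 0, then the patches
encoded = torch.randn(batch, num_patches + 1, dim)
head = nn.Linear(dim, num_classes)       # hypothetical classification head

cls_final = encoded[:, 0]                # (batch, dim): only the CLS token's embedding
logits = head(cls_final)                 # (batch, num_classes)
print(cls_final.shape, logits.shape)
```

Only one vector per image reaches the classifier, so it is up to self-attention to route the relevant patch information into that vector.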

In [None]:
# Quick demonstration of prepending a learnable embedding to this activation (along the "patch length" axis)
cls_token = nn.Parameter(
    torch.zeros(1, 1, 768) # channels-last
)

toks = torch.cat([
    cls_token.expand(BATCH_SIZE, -1, -1), # have to expand out the batch axis
    patch_emb,
], dim=1)

print(f"Final shape of embeddings with the CLS token: {toks.shape}")

Self-attention, while powerful, is inherently **permutation invariant** (more precisely, permutation *equivariant*: shuffling the input patches simply shuffles the outputs in the same way). This means that it does not take into account the order of the patches, treating them as a set rather than a sequence. However, in vision tasks, the spatial arrangement of patches is crucial for understanding the structure and relationships within an image.

To address this, we introduce positional encodings or embeddings. These are additional vectors added to the patch embeddings to inject information about the relative or absolute position of each patch in the image. In our implementation, we'll use a set of learnable weights for these positional encodings, allowing the model to learn the most effective way to incorporate positional information during training.

By combining these positional encodings with the patch embeddings, we create a richer input representation that not only captures the visual content of the patches but also their spatial arrangement. This enriched input is then fed into the next main component of the transformer, enabling the model to leverage both the content and the position of patches for more accurate understanding and processing of the image.
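The effect can be checked on a toy example. The sketch below assumes a single head with identity projections (a simplification, not the layer built later): without positional information, shuffling the patches merely shuffles the outputs, whereas adding a position-dependent tensor breaks that symmetry.

```python
import torch

torch.manual_seed(0)
tokens = torch.randn(1, 5, 8)            # (batch, patches, dim), toy sizes
perm = torch.tensor([3, 0, 4, 1, 2])     # an arbitrary reordering of the patches

def plain_attention(x: torch.Tensor) -> torch.Tensor:
    # Single head with identity Q/K/V projections, purely for illustration
    scores = x @ x.transpose(-2, -1) / x.shape[-1] ** 0.5
    return scores.softmax(dim=-1) @ x

# Without positions: shuffle-then-attend equals attend-then-shuffle
out = plain_attention(tokens)
out_shuffled = plain_attention(tokens[:, perm])
order_ignored = torch.allclose(out[:, perm], out_shuffled, atol=1e-5)

# With positions: the same shuffle now changes the result
pos = torch.randn(1, 5, 8)               # stand-in for learned positional embeddings
out_pos = plain_attention(tokens + pos)
out_pos_shuffled = plain_attention(tokens[:, perm] + pos)
order_matters = not torch.allclose(out_pos[:, perm], out_pos_shuffled, atol=1e-5)

print(order_ignored, order_matters)
```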

In the cell below, note how we use `nn.Parameter` to initialize the tensor, with a shape that follows the patch embeddings. We rely on broadcasting over the batch axis since we do not want different positional encodings for different images in a batch (that would make no sense).

In [None]:
# Initialize a randomly initialized set of positional encodings
pos_embed = nn.Parameter(torch.randn(1, toks.shape[1], toks.shape[2]) * 0.02) # this factor is from the timm implementation
x = toks + pos_embed

print(f"Final shape of input: {x.shape}")

## Task 2. Multi-Head Self Attention

Self-attention is a fundamental mechanism in the Vision Transformer that enables patches, or tokens, to communicate with one another across the entire image. This mechanism allows each patch to consider the information from all other patches, effectively sharing and enriching the context of each patch's representation. In the context of computer vision, this means that the model can capture relationships and dependencies between different parts of the image, such as identifying that certain shapes or colors in one part of the image may relate to features in another part. This global interaction helps the model build a more comprehensive understanding of the image, crucial for tasks like image classification.

The Self-Attention mechanism can be written as the following expression from [the 2017 paper](https://arxiv.org/abs/1706.03762):
$$\text{Attention}(Q,K,V) = \text{softmax}\left(\frac{Q K^T}{\sqrt{d_k}}\right)V$$

Here, the components are:

- Queries ($Q$): These represent the specific aspects or "questions" that each patch wants to know about other patches.
- Keys ($K$): These act like tags or "keywords" that help identify relevant information in the patches.
- Values ($V$): These are the actual data or "answers" that the patches contain.

In the Vision Transformer, these components are extracted from the patch embeddings using learned linear transformations. The three matrices are computed as follows:
$$Q = X W^Q$$
$$K = X W^K$$
$$V = X W^V$$

where $W^Q, W^K, W^V$ are learnable weight matrices, and $X$ is the tensor of input embeddings, with one row per patch (token).

In [None]:
# Define variables before moving further
embed_dim = 768
B, P, C = x.shape

# We can use nn.Linear layers (without the bias) to perform these computations
query = nn.Linear(embed_dim, embed_dim, bias=False)
key = nn.Linear(embed_dim, embed_dim, bias=False)
value = nn.Linear(embed_dim, embed_dim, bias=False)

# Get the projections, these are the Q, K, V matrices in the equation above
q = query(x)
k = key(x)
v = value(x)

# Get the shapes
q.shape, k.shape, v.shape

In the context of self-attention, a "head" refers to an individual attention mechanism within the multi-head attention framework. Each head operates independently, learning different aspects of the relationships between patches or tokens. By having multiple heads, the model can capture a diverse range of interactions and dependencies, enhancing its ability to understand complex patterns in the data.

Each head has its own set of learnable parameters for queries, keys, and values. The use of multiple heads allows the model to focus on different types of relationships or features in parallel. For instance, one head might learn to focus on spatial locality, while another might capture more global interactions.

Note how the `in_features` and `out_features` are the same for this setup. This actually makes it much easier for us to work with multiple heads in one go: we can partition the projections from these matrices by introducing a "head axis", which lets us perform computations with `H` vectors of size `C//H` in parallel.
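One possible way to introduce the head axis is sketched below on a random stand-in tensor (the variable names are illustrative assumptions): the channel axis is split into `(H, C // H)` and the head axis is moved next to the batch axis, and the operation can be undone exactly.

```python
import torch

torch.manual_seed(0)
B, P, C, H = 2, 197, 768, 12             # batch, tokens, channels, heads
q = torch.randn(B, P, C)                 # stand-in for a projected query tensor

# Split the channel axis into (heads, head_dim) and move heads next to the batch axis
q_heads = q.view(B, P, H, C // H).transpose(1, 2)    # (B, H, P, C // H)

# Head 0 is simply the first C // H channels of every token
first_head_matches = torch.equal(q_heads[:, 0], q[..., : C // H])

# Undo the split: move heads back and merge them into the channel axis
merged = q_heads.transpose(1, 2).reshape(B, P, C)
round_trip = torch.equal(merged, q)

print(q_heads.shape, first_head_matches, round_trip)
```

Note the use of `reshape` rather than `view` when merging, since the transposed tensor is no longer contiguous in memory.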

--- 

With the projections ready, we can finally implement the meat of the component: the equation.

$$\text{Attention}(Q,K,V) = \text{softmax}\left(\frac{Q K^T}{\sqrt{d_k}}\right)V$$

Here are a few things to note:
1. $Q$ and $K$ are 4D tensors, so the notion of a "transpose" sounds rather strange. Note, however, that the output we want from this operation is an **Attention Map**: a map that lays out the affinities between each and every pair of patches in the input tensor. What we really want is therefore a `(B, H, P, P)` tensor, so all we have to do is swap the last two axes of $K$ and follow the rules of Batched Matrix Multiplication.

2. The scale factor $\sqrt{d_k}$ is helpful for stable training. If the entries of $Q$ and $K$ have roughly unit variance, each dot product sums $d_k$ terms, so its variance grows on the order of $d_k$. Such large scores push the Softmax towards saturation, which can lead to unstable gradients, so we divide by $\sqrt{d_k}$ to bring the scores back to roughly unit variance.

3. The Softmax here is applied on a row axis - i.e. the rows of the 2D slice must contain values that sum up to 1. This is important to note since we consider the dot product of the row of the query matrix with the columns of the transpose of the key (so if you ask a specific question following the analogy above, you'd want the answers to be weighed according to the question, not all questions weighed according to a single answer).
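The three points above can be checked numerically. The following is a sketch on random unit-variance tensors (an assumption; real projections need not be unit-variance), covering only the attention map and not the final multiplication with $V$.

```python
import torch

torch.manual_seed(0)
B, H, P, d_k = 2, 12, 197, 64
q = torch.randn(B, H, P, d_k)            # stand-ins for per-head queries and keys
k = torch.randn(B, H, P, d_k)

# Point 1: "transposing" K means swapping its last two axes
raw = q @ k.transpose(-2, -1)            # (B, H, P, P) attention scores

# Point 2: raw scores have variance of roughly d_k; scaling brings it back to roughly 1
scaled = raw * d_k ** -0.5
raw_var, scaled_var = raw.var().item(), scaled.var().item()

# Point 3: Softmax over the last axis, so every row (one query) sums to 1
attn = scaled.softmax(dim=-1)
row_sums = attn.sum(dim=-1)

print(raw.shape, round(raw_var, 1), round(scaled_var, 2))
```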

Just from looking at the equation, try to work out what the shape for the output of this big equation would be.

--- 

Assuming you got it right, the shapes give us a very interesting interpretation of what Self-Attention really does: if the input and output shapes are exactly the same (i.e. `[16, 197, 768]`), despite so much processing with learnable parameters in the middle, then we can think of this module as simply **refining the representations**.

To expand on this, your input is provided as a sequence of embeddings (order matters here) - the Self-Attention operation allows the elements to communicate and share information with one another, enriching each other with context, so that the final consolidated output is simply an enriched version of the input. It also helps that having the same shape allows you to stack these components on top of one another very easily, as we will see later.

**Now for your first actual task:** Put all of these operations (with a few other bells and whistles) into a single `nn.Module` class.

**Note:** For your convenience in this task and the rest of this notebook, we have provided the definitions for the `__init__` methods - do not modify these. You need only define the `forward` methods for these modules.

Notice how we add two extra components here:
1. The `proj` component: this is a projection back into the same vector space, so it again acts like simply *refining* what you already have. Andrej Karpathy describes this as *thinking* on what you have just computed from the Self-Attention operation.

2. The `proj_drop` component: this is a form of Dropout which acts as a regularizing mechanism for the Vision Transformer. This becomes very important since ViTs are prone to overfitting on simplistic datasets, so this and `attn_drop` (functionally the same thing) can help mitigate this heinous activity.
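As a quick reminder of how `nn.Dropout` behaves (a toy sketch with an arbitrary `p=0.5`, not the rate used in this assignment), it only perturbs activations in training mode and becomes the identity in evaluation mode.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
drop = nn.Dropout(p=0.5)
x = torch.ones(1, 8)

drop.train()
train_out = drop(x)      # some entries zeroed, survivors scaled by 1 / (1 - p) = 2

drop.eval()
eval_out = drop(x)       # identity: Dropout is disabled at evaluation time

print(train_out, eval_out)
```

This is why calling `model.train()` and `model.eval()` at the right moments matters once Dropout is part of the model.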

In [None]:
class MHSA(nn.Module):

    def __init__(
            self,
            dim: int,
            num_heads: int = 8,
            qkv_bias: bool = False,
            attn_drop: float = 0.,
            proj_drop: float = 0.
    ) -> None:
        super().__init__()
        assert dim % num_heads == 0, 'dim should be divisible by num_heads'
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5

        self.q = nn.Linear(dim, dim, bias=qkv_bias)
        self.k = nn.Linear(dim, dim, bias=qkv_bias)
        self.v = nn.Linear(dim, dim, bias=qkv_bias)

        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        raise NotImplementedError

In [None]:
# Perform a forward pass on a dummy input
mhsa = MHSA(dim=768,
            num_heads=12,
            qkv_bias=True,
            )

# Create sample input tensor
B, P, C = 16, 196, 768
x = torch.randn(B, P, C)

# Pass the input through the MHSA layer
output = mhsa(x)
print(f"{x.shape=} --> {output.shape=}")
del mhsa

## Task 3. The "Encoder Block"

Let's go back to the main illustration in the original paper:

<div style="text-align: center;">
  <img src="./vit.png" alt="vit-layout" style="width:50%;">
</div>

The grayed out block on the right side represents a single "Encoder Block".

The Encoder Block is a fundamental component of the Vision Transformer, inheriting its architecture from the original Transformer model used in natural language processing. Each Encoder Block is designed to process and refine the input sequence—here, the sequence of patch embeddings enriched with positional encodings.

The Encoder Block consists of two main sub-layers:

1. **Multi-Head Self-Attention Mechanism**: This layer allows the model to focus on different parts of the input sequence simultaneously, capturing various relationships and dependencies among the patches. As discussed earlier, multiple heads enable the model to learn different aspects of the data, providing a comprehensive representation.

2. **Feed-Forward Neural Network (FFN)**: Following the self-attention layer, a position-wise feed-forward neural network processes the output. This network typically consists of two linear transformations separated by a non-linear activation function, such as GeLU. It acts on each position independently and helps in further transforming and refining the representation learned from the self-attention layer.
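To illustrate what "position-wise" means, here is a small sketch using a bare two-layer network with toy sizes (an assumption; it omits the Dropout and optional normalization of the `Mlp` class below). Applying it to the whole sequence gives the same result for a token as applying it to that token alone, so the FFN never mixes information across patches.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
dim, hidden = 8, 32                      # toy sizes; hidden = 4 * dim
ffn = nn.Sequential(nn.Linear(dim, hidden), nn.GELU(), nn.Linear(hidden, dim))

seq = torch.randn(1, 5, dim)             # (batch, patches, dim)
full = ffn(seq)                          # applied to every position at once
single = ffn(seq[:, 2])                  # applied to the third token on its own

independent = torch.allclose(full[:, 2], single, atol=1e-5)
print(full.shape, independent)
```

All communication between patches therefore happens in the self-attention sub-layer; the FFN only transforms each token's features.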

To ensure the effective flow of gradients during training and to preserve the original input information, each sub-layer is wrapped in a **Skip Connection** (also known as a Residual Connection). These connections add the input of each sub-layer to its output, forming a residual path that facilitates better gradient propagation and helps prevent the vanishing gradient problem. In the Vision Transformer, the input is normalized *before* it enters each sub-layer (the "pre-norm" arrangement), which can be expressed as:

$$\text{Output} = x + \text{SubLayer}(\text{LayerNorm}(x))$$

In the equation above, $x$ represents the input to the sub-layer, and the sum forms the residual connection. The original 2017 Transformer instead normalized after the sum, i.e. $\text{LayerNorm}(x + \text{SubLayer}(x))$; the ViT paper and the `timm` implementation use the pre-norm form, so follow it if you want the pretrained weights to behave correctly. Skip connections are a very old idea in Deep Learning, going [back to 2015 with ResNets](https://arxiv.org/abs/1512.03385), and can perhaps be better understood with the following illustration:

<div style="text-align: center;">
  <img src="skip-connection.png" alt="skip-connection" style="width:30%;">
</div>

Layer Normalization (`LayerNorm`) stabilizes training by normalizing the features of each patch embedding to zero mean and unit variance, keeping the model's activations within a stable range. It then applies a learnable scale and shift, allowing the model to maintain useful information while limiting excessive internal covariate shift.
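A small sketch of what `nn.LayerNorm` does to each token, using deliberately shifted and scaled random features (the sizes are arbitrary assumptions): every token is normalized over its own feature axis, independently of the other tokens and of the batch.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
dim = 16
x = torch.randn(2, 5, dim) * 3.0 + 7.0   # features with mean ~7 and std ~3

layer_norm = nn.LayerNorm(dim)           # scale starts at 1 and shift at 0
y = layer_norm(x)

# Statistics are computed per token, over the last (feature) axis
token_means = y.mean(dim=-1)
token_vars = y.var(dim=-1, unbiased=False)

print(token_means.abs().max().item(), token_vars.min().item(), token_vars.max().item())
```

At initialization the learnable scale and shift leave the normalized values untouched; during training they let the model move away from strict zero-mean, unit-variance features where that helps.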

The Encoder Block's design, with its combination of self-attention, feed-forward neural networks, skip connections, and layer normalization, enables the Vision Transformer to learn rich, hierarchical representations of the input data. This structure is repeated multiple times in the model, each block building upon the representations learned by the previous one, gradually refining the understanding of the input image.

Go ahead and implement this directly.

In [None]:
class Mlp(nn.Module):
    def __init__(self, 
                 in_features: int,
                 hidden_features: int,
                 drop: float,
                 norm_layer: nn.Module = None) -> None:
        super().__init__()

        # There are two Linear layers in the conventional implementation
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.fc2 = nn.Linear(hidden_features, in_features)

        # Dropout is used twice, once after the GeLU and the next after the second Linear layer
        self.drop1 = nn.Dropout(drop)
        self.drop2 = nn.Dropout(drop)

        # The paper uses the GeLU activation function after the first Linear layer only
        self.act = nn.GELU()

        # Optional normalization layer to be used after the first Dropout
        self.norm = norm_layer(hidden_features) if norm_layer is not None else nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        raise NotImplementedError

class Block(nn.Module):
    def __init__(
            self,
            dim: int,
            num_heads: int,
            mlp_ratio: float = 4.,
            qkv_bias: bool = False,
            proj_drop: float = 0.,
            attn_drop: float = 0.,
    ) -> None:
        super().__init__()
        
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)

        self.mlp = Mlp(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            drop=proj_drop,
        )

        self.attn = MHSA(
            dim=dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            attn_drop=attn_drop,
            proj_drop=proj_drop
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        raise NotImplementedError
    
block = Block(dim=768,
              num_heads=12)
block_out = block(x)

print(f"{x.shape=} --> {block_out.shape=}")
del block

A wonderful observation here is again that the **input and output shapes are exactly the same!**

This means we can again fall back on the interpretation that each Block (on top of each MHSA module) is simply *refining* and embedding rich context into whatever input is fed into it. 

Plus, having the same shape as the input allows us to stack these encoders nicely on top of each other without much thought or care.

## Task 4. Putting it all together

Now with everything we've learned so far, let's make one final class that aggregates everything.

Recall what we had to do with

1. The Patch Embeddings to represent our image as a sequence and let Self-Attention link parts of it with one another,

2. The Positional Encodings/Embeddings to move past the permutation invariant nature of Self-Attention and embed information regarding the position of each patch into the mix,

3. The CLS Token to let the model have an overall representation of the entire image which would provide a means of performing classification,

4. The Multi-Head Self-Attention class (`MHSA`) to let the patches communicate and share information with one another in the hopes of enriching the representations,

5. The Block class (`Block`) to be able to string together the computations performed by the Self-Attention, Feedforward, and Layer Normalization modules.
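Before assembling the model, it can help to work out the sizes involved. The arithmetic below is a sketch using the default hyperparameters from earlier in the notebook (224x224 RGB images, 16x16 patches, 768-dimensional embeddings).

```python
img_size, patch_size, in_chans, embed_dim = 224, 16, 3, 768

grid = img_size // patch_size                    # patches along each side
num_patches = grid * grid                        # length of the patch sequence
seq_len = num_patches + 1                        # plus the CLS token

# Learnable parameters in the embedding stage
patch_proj_params = embed_dim * in_chans * patch_size ** 2 + embed_dim   # conv weight + bias
cls_token_params = embed_dim
pos_embed_params = seq_len * embed_dim

print(grid, num_patches, seq_len)                # 14 196 197
print(patch_proj_params, cls_token_params, pos_embed_params)  # 590592 768 151296
```

These numbers are also the shapes you should expect to see for the corresponding entries when inspecting a `state_dict` of a model with the same configuration.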

Now we put it all together.

In [None]:
class PatchEmbed(nn.Module):
    def __init__(self,
                 img_size: int,
                 patch_size: int,
                 in_chans: int,
                 embed_dim: int,
                 bias: bool = True,
                 norm_layer: Optional[Callable] = None) -> None:
        super().__init__()
        
        self.img_size = img_size
        self.patch_size = patch_size

        self.grid_size = (self.img_size // self.patch_size, ) * 2
        self.num_patches = self.grid_size[0] * self.grid_size[1]

        self.proj = nn.Conv2d(in_chans, 
                              embed_dim, 
                              kernel_size=patch_size, 
                              stride=patch_size, 
                              bias=bias, 
                              padding=0)
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()


    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Take care to flatten and transpose the projected layer (BCHW -> BLC)
        raise NotImplementedError

class VisionTransformer(nn.Module):

    def __init__(
            self,
            img_size: Union[int, Tuple[int, int]] = 224,
            patch_size: Union[int, Tuple[int, int]] = 16,
            in_chans: int = 3,
            num_classes: int = 1000,
            embed_dim: int = 768,
            depth: int = 12,
            num_heads: int = 12,
            mlp_ratio: float = 4.,
            qkv_bias: bool = True,
            drop_rate: float = 0.,
            pos_drop_rate: float = 0.,
            proj_drop_rate: float = 0.,
            attn_drop_rate: float = 0.,
    ) -> None:
        """
        Args:
            img_size: Input image size.
            patch_size: Patch size.
            in_chans: Number of image input channels.
            num_classes: Number of classes for classification heads
            embed_dim: Transformer embedding dimension.
            depth: Depth of transformer.
            num_heads: Number of attention heads.
            mlp_ratio: Ratio of mlp hidden dim to embedding dim.
            qkv_bias: Enable bias for qkv projections if True.
            drop_rate: Head dropout rate.
            pos_drop_rate: Position embedding dropout rate.
            proj_drop_rate: MHSA projection dropout rate.
            attn_drop_rate: Attention dropout rate.
        """
        super().__init__()

        self.num_classes = num_classes
        self.num_features = self.head_hidden_size = self.embed_dim = embed_dim  # for consistency with other models

        # Define the Patch Embedding module - note this does not include the CLS token yet
        self.patch_embed = PatchEmbed(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=embed_dim,
            bias=True,
        )
        num_patches = self.patch_embed.num_patches
        embed_len = num_patches + 1 # don't forget we need to incorporate the CLS token

        # Define the CLS token, the Positional Encodings/Embeddings, and a Dropout for the Positional information
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.randn(1, embed_len, embed_dim) * .02)
        self.pos_drop = nn.Dropout(p=pos_drop_rate)

        # Define LayerNorms for before and after the Encoder block processing
        norm_layer = partial(nn.LayerNorm, eps=1e-6)
        self.norm_pre = norm_layer(embed_dim)
        self.norm = norm_layer(embed_dim)

        # Initialize the blocks
        self.blocks = nn.Sequential(*[
            Block(
                dim=embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                proj_drop=proj_drop_rate,
                attn_drop=attn_drop_rate
            )
            for i in range(depth)])
        
        self.feature_info = [
            dict(module=f'blocks.{i}', num_chs=embed_dim) for i in range(depth)
        ]

        # Classifier Head
        self.head_drop = nn.Dropout(drop_rate)
        self.head = nn.Linear(self.embed_dim, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        raise NotImplementedError
    
model = VisionTransformer()

In [None]:
x, _ = next(iter(train_dl))
out = model(x)
print(out.shape)

## Task 5.  Loading in Pretrained Weights

Now that we've built our simplistic version of a Vision Transformer, let's load in the parameters from a pretrained checkpoint, specifically the ViT-Small variant available in the `timm` library.

The procedure is simple: load in the `state_dict` from the pretrained model, and match with the corresponding parameters in your implementation. 

The reason we provided the definitions for the `__init__` methods was to make your task easier with matching the tensors - it is your task to actually do the matching and copying.

You can look into [this notebook](https://github.com/rasbt/LLMs-from-scratch/blob/main/ch05/01_main-chapter-code/ch05.ipynb) (from [Sebastian Raschka's amazing repository on "LLMs from Scratch"](https://github.com/rasbt/LLMs-from-scratch)), to understand how to go about loading in pretrained weights in a more controlled fashion. 

You can either go this route where you load in all the weights one by one manually, or you can go back and redefine the different components of the Vision Transformer (adding in methods specifically for loading in weights from a `state_dict`) to make things cleaner. It is entirely your choice.
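One detail worth checking while matching tensors: as far as we are aware, `timm`'s ViT attention stores the query, key, and value projections as a single fused `qkv` linear layer, whereas the `MHSA` class above keeps three separate layers. Treat this as an assumption and verify it by printing the keys and shapes of the pretrained `state_dict`. The sketch below uses toy stand-in layers (no `timm` download) to show how a fused weight could be split and copied.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
dim = 8
fused = nn.Linear(dim, 3 * dim)          # stand-in for a pretrained fused qkv layer
q_proj, k_proj, v_proj = (nn.Linear(dim, dim) for _ in range(3))

# Split the fused parameters along the output axis and copy them in place
with torch.no_grad():
    weights = fused.weight.chunk(3, dim=0)       # three (dim, dim) blocks
    biases = fused.bias.chunk(3, dim=0)          # three (dim,) blocks
    for layer, w, b in zip((q_proj, k_proj, v_proj), weights, biases):
        layer.weight.copy_(w)
        layer.bias.copy_(b)

# Sanity check: the separate layers reproduce the fused layer's output
x = torch.randn(2, 5, dim)
fused_q, fused_k, fused_v = fused(x).chunk(3, dim=-1)
matches = all(
    torch.allclose(layer(x), ref, atol=1e-6)
    for layer, ref in zip((q_proj, k_proj, v_proj), (fused_q, fused_k, fused_v))
)
print(matches)
```

A similar output-equivalence check against the pretrained model on a random batch is a cheap way to confirm that all weights were copied correctly.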

In [None]:
def vit_small_patch16_224(
        num_classes: int = 10,
        pretrained: bool = False,
        in_chans: int = 3,
        drop_rate: float = 0,
        pos_drop_rate: float = 0,
        attn_drop_rate: float = 0,
        proj_drop_rate: float = 0.
        ):
    
    model = VisionTransformer(
        img_size=224,
        patch_size=16,
        in_chans=in_chans,
        num_classes=num_classes,
        embed_dim=384,
        num_heads=6,
        depth=12,
        drop_rate=drop_rate,
        pos_drop_rate=pos_drop_rate,
        attn_drop_rate=attn_drop_rate,
        proj_drop_rate=proj_drop_rate
    )

    if pretrained:
        raise NotImplementedError
    
    return model

In [None]:
# Load in our model
model = vit_small_patch16_224(num_classes=100,
                             pretrained=True)

## Task 6. Finetuning our Vision Transformer

Now with everything in place, we can finally move on to actually finetuning our model.

In the cells below, write the finetuning loop and plot the resulting loss and accuracy curves.

To spice things up, you can add in random augmentations during the training process.

Train for 3 epochs and note how the model (provided you have loaded in the weights properly) performs much better than anything you could reasonably train from scratch given the same number of training steps. You should use the `AdamW` optimizer with a learning rate of `1e-4`. The goal here is to get a Validation Accuracy of over 80% within 3 epochs of finetuning; this is a piece of cake if (again) you have loaded in the weights properly.
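For reference, the mechanics of a single optimization step with `AdamW` might look like the sketch below. It uses a tiny linear layer and random tensors as stand-ins for the Vision Transformer and a CIFAR-100 batch (both are assumptions for illustration), so it shows only the order of the calls, not a full training loop.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(16, 100)                       # stand-in for the finetuned ViT
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
criterion = nn.CrossEntropyLoss()

inputs = torch.randn(8, 16)                      # stand-in for a batch of images
targets = torch.randint(0, 100, (8,))            # stand-in for class labels

# One training step: forward, loss, backward, update
model.train()
loss = criterion(model(inputs), targets)
optimizer.zero_grad()
loss.backward()
optimizer.step()

# Evaluation: switch off Dropout and gradient tracking
model.eval()
with torch.no_grad():
    accuracy = (model(inputs).argmax(dim=-1) == targets).float().mean().item()

print(f"loss={loss.item():.3f}, accuracy={accuracy:.3f}")
```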

In [None]:
raise NotImplementedError

In [None]:
# Plot the loss curves
raise NotImplementedError

In [None]:
# Plot the accuracy curves
raise NotImplementedError

## Fin.