# Lab 11: Transformer

Reference: 
- https://medium.com/the-dl/transformers-from-scratch-in-pytorch-8777e346ca51
- https://towardsdatascience.com/how-to-code-the-transformer-in-pytorch-24db27c8f9ec
- https://towardsdatascience.com/implementing-visualttransformer-in-pytorch-184f9f16f632
- https://github.com/lucidrains/vit-pytorch#vision-transformer---pytorch
- https://medium.com/mlearning-ai/vision-transformers-from-scratch-pytorch-a-step-by-step-guide-96c3313c2e0c

Today's lab will describe how to create transformer from scratch and how to implement it. :-)

Moreover, we will continue to ViT (Vision transformer).

<img src="img/optimus_prime.jpg" title="Transformer" style="width: 600px;" />

Note that this photo does not have relevance in the topic. :-)

## Transformer trend

In recent years, the deep learning network structures are in the form of CNNs or RNNs. CNNs are mostly used in visions, and RNNs are mostly used in NLP (Natural language processing). Their models can be worked together. CNNs can be used in language and RNNs also can be used in vision. However, both models have some disadvantage such as coreference resolution, and huge or complex dataset in languages.

In 2017, The google research lab has been introduced Transformer in the paper name: "Attention is all you need". This paper is very famous and has been cited more than 8 thousands times.

The main concept of Transformer is self-Attention process. The procees is not only replace the RNNs and CNNs, it also shows the relation between text and data.

## Transformer architecture

We use the code from [github](https://github.com/fkodom/transformer-from-scratch/tree/main/src)

The transformer diagram is show as below:

<img src="img/Transformer.png" title="Transformer" style="width: 600px;" />

The summary details and mathematics of the architecture is:

<img src="img/SummaryTransformer.PNG" title="Transformer Details" style="width: 1000px;" />

There are processes that we need to implement in the model:

## Attention!

In present, the standard model which is used in *sequence-to-sequence learning* is **Sequence-to-Sequence Model** (seq2seq) or **RNN Encoder-Decoder**. The model combines with 2 pieces called encoder and decoder.
The encoder receives the input and keeps some important information, and the decoder releases the output from the related information.

However, the seq2seq has a problem that some information may be lost from the long sequence running. Thus, the concept of attention is **to focus at the specific input directly**.

In attention, when we want to get the output at a target position, the decoder vector at the position willcaculate between the attention score and encoder vector in every position. The high score in some encoder position can tell that it is important than the other position. To get the probability weight for each encoder vector can be explained in softmax as:

$$r = \sum_i \frac{e^{p_i\cdot q}}{\sum_j e^{p_j\cdot q}}p_i$$

### Multi-head attention

Transformers use a specific type of attention mechanism, referred to as multi-head attention. This is the most important part of the model. An illustration from the paper is shown below.

<img src="img/MultiHeadAttention.png" title="Transformer" style="width: 600px;" />

The multi-head attention layer is described as:

$$\text{Attention}(Q,K,V)=softmax(\frac{QK^T}{\sqrt{d_k}})V$$

Q, K, and V are batches of matrices, each with shape <code>(batch_size, seq_length, num_features)</code>. Multiplying the query $Q$ and key $K$ arrays results in a <code>(batch_size, seq_length, seq_length)</code> array, which tells us roughly how important each element in the sequence is. This is the attention of this layer — it determines which elements we “pay attention” to. The attention array is normalized using softmax, so that all of the weights sum to one. Finally, the attention is applied to the value (V) array using matrix multiplication.

The scaled dot-product attention code is below: (multi_head_attention.py)

In [1]:
import torch
import torch.nn.functional as f
from torch import Tensor, nn


def scaled_dot_product_attention(query: Tensor, key: Tensor, value: Tensor) -> Tensor:
    # MatMul operations are translated to torch.bmm in PyTorch
    temp = query.bmm(key.transpose(1, 2))
    scale = query.size(-1) ** 0.5
    softmax = f.softmax(temp / scale, dim=-1)
    return softmax.bmm(value)

The multi-head is composed of several identical *attention heads*. Each attention head contains 3 lineary layers and combine them using scaled dot-product attention.

In [2]:
class AttentionHead(nn.Module):
    def __init__(self, dim_in: int, dim_q: int, dim_k: int):
        super().__init__()
        self.q = nn.Linear(dim_in, dim_q)
        self.k = nn.Linear(dim_in, dim_k)
        self.v = nn.Linear(dim_in, dim_k)

    def forward(self, query: Tensor, key: Tensor, value: Tensor) -> Tensor:
        return scaled_dot_product_attention(self.q(query), self.k(key), self.v(value))

Then we can combine them to be multi-head attention layer.

In [3]:
class MultiHeadAttention(nn.Module):
    def __init__(self, num_heads: int, dim_in: int, dim_q: int, dim_k: int):
        super().__init__()
        self.heads = nn.ModuleList(
            [AttentionHead(dim_in, dim_q, dim_k) for _ in range(num_heads)]
        )
        self.linear = nn.Linear(num_heads * dim_k, dim_in)

    def forward(self, query: Tensor, key: Tensor, value: Tensor) -> Tensor:
        return self.linear(
            torch.cat([h(query, key, value) for h in self.heads], dim=-1)
        )

Each attention head computes its own query, key, and value arrays, and then applies scaled dot-product attention. Conceptually, this means each head can attend to a different part of the input sequence, independent of the others. Increasing the number of attention heads allows us to “pay attention” to more parts of the sequence at once, which makes the model more powerful.

### Positional Encoding

To completed the encoding zone of transformer (input zone), we need to build another component: **position encoder**.
The <code>MultiHeadAttention</code> has no trainable components that operate over the *sequence dimension* (axis 1). Everything operates over the *feature dimension* (axis 2), and so it is independent of sequence length. We have to provide positional information to the model, so that it knows about the relative position of data points in the input sequences.

The positional information is encoded using **trigonometric functions** as:

$$PE_{(pos,2i)}=\sin (\frac{pos}{10000^{2i/d_{model}}})$$
$$PE_{(pos,2i+1)}=\cos (\frac{pos}{10000^{2i/d_{model}}})$$

This constant is a 2d matrix. Pos refers to the order in the sentence, and i refers to the position along the embedding vector dimension. Each value in the pos/i matrix is then worked out using the equations above.

<img src="img/positionalencoder.png" title="Positional Encoder" style="width: 400px;" />

The positional encoding is implemented in code as: (utils.py)

In [4]:
def position_encoding(seq_len: int, dim_model: int, device: torch.device = torch.device("cpu"),) -> Tensor:
    pos = torch.arange(seq_len, dtype=torch.float, device=device).reshape(1, -1, 1)
    dim = torch.arange(dim_model, dtype=torch.float, device=device).reshape(1, 1, -1)
    phase = pos / (1e4 ** (dim // dim_model))

    return torch.where(dim.long() % 2 == 0, torch.sin(phase), torch.cos(phase))

Why should sinusoidal encodings extrapolate to longer sequence lengths? Because sine/cosine functions are periodic, and they cover a range of $[0, 1]$. Most other choices of encoding would not be periodic or restricted to the range $[0, 1]$. Suppose that, during inference, you provide an input sequence longer than any used during training. Positional encoding for the last elements in the sequence could be different than anything the model has seen before. For those reasons, and despite the fact that learned embeddings appeared to perform equally as well, the authors still chose to use sinusoidal encoding.

### The completed transformer

Transformer uses an encoder-decoder architecture. The encoder processes the input sequence and returns a **feature vector (or memory vector)**. The decoder processes the **target sequence, and incorporates information from the encoder memory**. The output from the decoder is the model prediction.

We need to build another piece for the transformer. That is the feed forward network. (utils.py)

In [5]:
def feed_forward(dim_input: int = 512, dim_feedforward: int = 2048) -> nn.Module:
    return nn.Sequential(
        nn.Linear(dim_input, dim_feedforward),
        nn.ReLU(),
        nn.Linear(dim_feedforward, dim_input),
    )

To combine dropout and normalization, we can build residual module (utils.py)

In [6]:
class Residual(nn.Module):
    def __init__(self, sublayer: nn.Module, dimension: int, dropout: float = 0.1):
        super().__init__()
        self.sublayer = sublayer
        self.norm = nn.LayerNorm(dimension)
        self.dropout = nn.Dropout(dropout)

    def forward(self, *tensors: Tensor) -> Tensor:
        # Assume that the "query" tensor is given first, so we can compute the
        # residual.  This matches the signature of 'MultiHeadAttention'.
        return self.norm(tensors[0] + self.dropout(self.sublayer(*tensors)))

Woooo, we can create encoder now! (encoder.py)

First, let's create Transformer Encoder layer.

In [7]:
class TransformerEncoderLayer(nn.Module):
    def __init__(
        self,
        dim_model: int = 512,
        num_heads: int = 6,
        dim_feedforward: int = 2048,
        dropout: float = 0.1,
    ):
        super().__init__()
        dim_q = dim_k = max(dim_model // num_heads, 1)
        self.attention = Residual(
            MultiHeadAttention(num_heads, dim_model, dim_q, dim_k),
            dimension=dim_model,
            dropout=dropout,
        )
        self.feed_forward = Residual(
            feed_forward(dim_model, dim_feedforward),
            dimension=dim_model,
            dropout=dropout,
        )

    def forward(self, src: Tensor) -> Tensor:
        src = self.attention(src, src, src)
        return self.feed_forward(src)

Then the Transformer encoder model

In [8]:
class TransformerEncoder(nn.Module):
    def __init__(
        self,
        num_layers: int = 6,
        dim_model: int = 512,
        num_heads: int = 8,
        dim_feedforward: int = 2048,
        dropout: float = 0.1,
    ):
        super().__init__()
        self.layers = nn.ModuleList(
            [
                TransformerEncoderLayer(dim_model, num_heads, dim_feedforward, dropout)
                for _ in range(num_layers)
            ]
        )

    def forward(self, src: Tensor) -> Tensor:
        seq_len, dimension = src.size(1), src.size(2)
        src += position_encoding(seq_len, dimension)
        for layer in self.layers:
            src = layer(src)

        return src

### The decoder

The decoder module is extremely similar. Just a few small differences:
- The decoder accepts two arguments (target and memory), rather than one.
- There are two multi-head attention modules per layer, instead of one.
- The second multi-head attention accepts memory for two of its inputs.

In [9]:
class TransformerDecoderLayer(nn.Module):
    def __init__(
        self,
        dim_model: int = 512,
        num_heads: int = 6,
        dim_feedforward: int = 2048,
        dropout: float = 0.1,
    ):
        super().__init__()
        dim_q = dim_k = max(dim_model // num_heads, 1)
        self.attention_1 = Residual(
            MultiHeadAttention(num_heads, dim_model, dim_q, dim_k),
            dimension=dim_model,
            dropout=dropout,
        )
        self.attention_2 = Residual(
            MultiHeadAttention(num_heads, dim_model, dim_q, dim_k),
            dimension=dim_model,
            dropout=dropout,
        )
        self.feed_forward = Residual(
            feed_forward(dim_model, dim_feedforward),
            dimension=dim_model,
            dropout=dropout,
        )

    def forward(self, tgt: Tensor, memory: Tensor) -> Tensor:
        tgt = self.attention_1(tgt, tgt, tgt)
        tgt = self.attention_2(tgt, memory, memory)
        return self.feed_forward(tgt)

In [10]:
class TransformerDecoder(nn.Module):
    def __init__(
        self,
        num_layers: int = 6,
        dim_model: int = 512,
        num_heads: int = 8,
        dim_feedforward: int = 2048,
        dropout: float = 0.1,
    ):
        super().__init__()
        self.layers = nn.ModuleList(
            [
                TransformerDecoderLayer(dim_model, num_heads, dim_feedforward, dropout)
                for _ in range(num_layers)
            ]
        )
        self.linear = nn.Linear(dim_model, dim_model)

    def forward(self, tgt: Tensor, memory: Tensor) -> Tensor:
        seq_len, dimension = tgt.size(1), tgt.size(2)
        tgt += position_encoding(seq_len, dimension)
        for layer in self.layers:
            tgt = layer(tgt, memory)

        return torch.softmax(self.linear(tgt), dim=-1)

### Combine everthing, Finish!!!

In [11]:
class Transformer(nn.Module):
    def __init__(
        self, 
        num_encoder_layers: int = 6,
        num_decoder_layers: int = 6,
        dim_model: int = 512, 
        num_heads: int = 6, 
        dim_feedforward: int = 2048, 
        dropout: float = 0.1, 
        activation: nn.Module = nn.ReLU(),
    ):
        super().__init__()
        self.encoder = TransformerEncoder(
            num_layers=num_encoder_layers,
            dim_model=dim_model,
            num_heads=num_heads,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
        )
        self.decoder = TransformerDecoder(
            num_layers=num_decoder_layers,
            dim_model=dim_model,
            num_heads=num_heads,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
        )

    def forward(self, src: Tensor, tgt: Tensor) -> Tensor:
        return self.decoder(tgt, self.encoder(src))

Let’s create a simple test, as a sanity check for our implementation. We can construct random tensors for src and tgt, check that our model executes without errors, and confirm that the output tensor has the correct shape.

In [12]:
# input (64-batch_size, 32-words, 512-embedding)
src = torch.rand(64, 32, 512)
tgt = torch.rand(64, 16, 512)
out = Transformer()(src, tgt)
print(out.shape)
# torch.Size([64, 16, 512])

torch.Size([64, 16, 512])


You can try to train transformer using the PyTorch model from [here](https://pytorch.org/tutorials/beginner/transformer_tutorial.html)

## Vision Transformer (ViT)

A Vision Transformer (ViT) is a transformer that is targeted at vision processing tasks. It have dominated the field of Computer Vision, obtaining state-of-the-art performance in image classification and others.

The ViT concept for image classification is shown at below:

<img src="img/vit.gif" title="ViT" />

### How the ViT works?

There are steps of ViT as:

1. Split an image into patches
2. Flatten the patches
3. Produce lower-dimensional linear embeddings from the flattened patches
4. Add positional embeddings
5. Feed the sequence as an input to a standard transformer encoder
6. Pretrain the model with image labels (fully supervised on a huge dataset)
7. Finetune on the downstream dataset for image classification

### ViT architecture

As you know that the ViT has implemented from Transforemer, so the architecture will be the same as transformer. However, using ViT for image classification is used only the Transformer encoder. Thus, the decoder has been removed from the architecture. The ViT architecture is shown at below:

<img src="img/ViTArchitecture.png" title="ViT architecture" />

From the figure, there are 4 parts:
<ol style="list-style-type:lower-alpha">
    <li> the main architecture of the model </li>
    <li> the Transformer module </li>
    <li> the Multiscale self-attention (MSA) head </li>
    <li> The self-attention (SA) head </li>
</ol>

### Let's start

Since the goal is how to create ViT model, so we use the MNIST dataset for train and test it!

We get the code example from: https://github.com/BrianPulfer/PapersReimplementations

In [1]:
import numpy as np

import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss
from torch.optim import Adam
from torch.utils.data import DataLoader

from torchvision.datasets.mnist import MNIST
from torchvision.transforms import ToTensor

Import the MNIST dataset.

In [2]:
# Loading data
transform = ToTensor()

train_set = MNIST(root='./../datasets', train=True, download=True, transform=transform)
test_set = MNIST(root='./../datasets', train=False, download=True, transform=transform)

train_loader = DataLoader(train_set, shuffle=True, batch_size=16)
test_loader = DataLoader(test_set, shuffle=False, batch_size=16)

### Train/Test function

Create the train and test function.

In [3]:
def train_ViT_classify(model, optimizer, N_EPOCHS, train_loader, device="cpu"):
    criterion = CrossEntropyLoss()
    for epoch in range(N_EPOCHS):
        train_loss = 0.0
        for batch in train_loader:
            x, y = batch
            x = x.to(device)
            y = y.to(device)
            y_hat = model(x)
            loss = criterion(y_hat, y) / len(x)

            train_loss += loss.item()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        print(f"Epoch {epoch + 1}/{N_EPOCHS} loss: {train_loss:.2f}")
        
def test_ViT_classify(model, optimizer, test_loader):
    criterion = CrossEntropyLoss()
    correct, total = 0, 0
    test_loss = 0.0
    for batch in test_loader:
        x, y = batch
        x = x.to(device)
        y = y.to(device)
        
        y_hat = model(x)
        loss = criterion(y_hat, y) / len(x)
        test_loss += loss

        correct += torch.sum(torch.argmax(y_hat, dim=1) == y).item()
        total += len(x)
    print(f"Test loss: {test_loss:.2f}")
    print(f"Test accuracy: {correct / total * 100:.2f}%")

### Multi-head Self Attention (MSA) Model

Before create ViT model, as same as normal transformer, we need to create MSA model to complete all model.

For a single image, that each patch gets updated based on some similarity measure with the other patches.
Thus, do linearly mapping each patch (that is now an 8-dimensional vector in our example) to 3 distinct vectors: $q$, $k$, and $v$ (query, key, value).

For each single patch, compute the dot product between its $q$ vector with all of the $k$ vectors, and divide by the square root of the dimension of these vectors.
At the end of the patch, do softmax the patch ouput. This is called attention cues, and multiply each attention cue with the $v$ vectors associated with the different $k$ vectors and sum all up.

Each patch is assumed to be a new value that is based on its similarity (after the linear mapping to $q$, $k$, and $v$) with other patches.

However, the whole procedure is carried out $H$ times on $H$ sub-vectors of our current 8-dimensional patches, where $H$ is the number of Heads.

Once all results are obtained, they are concatenated together. Finally, the result is passed through a linear layer.

The MSA model is shown at below:

In [4]:
class MSA(nn.Module):
    def __init__(self, d, n_heads=2):
        super(MSA, self).__init__()
        self.d = d
        self.n_heads = n_heads

        assert d % n_heads == 0, f"Can't divide dimension {d} into {n_heads} heads"

        d_head = int(d / n_heads)
        self.q_mappings = [nn.Linear(d_head, d_head) for _ in range(self.n_heads)]
        self.k_mappings = [nn.Linear(d_head, d_head) for _ in range(self.n_heads)]
        self.v_mappings = [nn.Linear(d_head, d_head) for _ in range(self.n_heads)]
        self.d_head = d_head
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, sequences):
        # Sequences has shape (N, seq_length, token_dim)
        # We go into shape    (N, seq_length, n_heads, token_dim / n_heads)
        # And come back to    (N, seq_length, item_dim)  (through concatenation)
        result = []
        for sequence in sequences:
            seq_result = []
            for head in range(self.n_heads):
                q_mapping = self.q_mappings[head]
                k_mapping = self.k_mappings[head]
                v_mapping = self.v_mappings[head]

                seq = sequence[:, head * self.d_head: (head + 1) * self.d_head]
                q, k, v = q_mapping(seq), k_mapping(seq), v_mapping(seq)

                attention = self.softmax(q @ k.T / (self.d_head ** 0.5))
                seq_result.append(attention @ v)
            result.append(torch.hstack(seq_result))
        return torch.cat([torch.unsqueeze(r, dim=0) for r in result])

**Note**: for each head, we create distinct Q, K, and V mapping functions (square matrices of size 4x4 in our example).

Since our inputs will be sequences of size (N, 50, 8), and we only use 2 heads, we will at some point have an (N, 50, 2, 4) tensor, use a nn.Linear(4, 4) module on it, and then come back, after concatenation, to an (N, 50, 8) tensor.

### Positional encoding

As anticipated, positional encoding allows the model to understand where each patch would be placed in the original image. While it is theoretically possible to learn such positional embeddings, previous work by Vaswani et. al.

In particular, positional encoding adds low-frequency values to the first dimensions and higher-frequency values to the latter dimensions.

In each sequence, for token i we add to its j-th coordinate the following value:

$$ p_{i,j} =
\left\{\begin{matrix}
\sin (\frac{i}{10000^{j/d_{embdim}}})\\ 
\cos (\frac{i}{10000^{j/d_{embdim}}})
\end{matrix}\right.
$$

<img src="img/peimages.png" title="" style="width: 800px;" />

In [5]:
def get_positional_embeddings(sequence_length, d, device="cpu"):
    result = torch.ones(sequence_length, d)
    for i in range(sequence_length):
        for j in range(d):
            result[i][j] = np.sin(i / (10000 ** (j / d))) if j % 2 == 0 else np.cos(i / (10000 ** ((j - 1) / d)))
    return result.to(device)

### ViT Model

Create the ViT model as below. The explaination is later.

In [6]:
class ViT(nn.Module):
    def __init__(self, input_shape, n_patches=7, hidden_d=8, n_heads=2, out_d=10):
        # Super constructor
        super(ViT, self).__init__()

        # Input and patches sizes
        self.input_shape = input_shape
        self.n_patches = n_patches
        self.n_heads = n_heads
        assert input_shape[1] % n_patches == 0, "Input shape not entirely divisible by number of patches"
        assert input_shape[2] % n_patches == 0, "Input shape not entirely divisible by number of patches"
        self.patch_size = (input_shape[1] / n_patches, input_shape[2] / n_patches)
        self.hidden_d = hidden_d

        # 1) Linear mapper
        self.input_d = int(input_shape[0] * self.patch_size[0] * self.patch_size[1])
        self.linear_mapper = nn.Linear(self.input_d, self.hidden_d)

        # 2) Classification token
        self.class_token = nn.Parameter(torch.rand(1, self.hidden_d))

        # 3) Positional embedding
        # (In forward method)

        # 4a) Layer normalization 1
        self.ln1 = nn.LayerNorm((self.n_patches ** 2 + 1, self.hidden_d))

        # 4b) Multi-head Self Attention (MSA) and classification token
        self.msa = MSA(self.hidden_d, n_heads)

        # 5a) Layer normalization 2
        self.ln2 = nn.LayerNorm((self.n_patches ** 2 + 1, self.hidden_d))

        # 5b) Encoder MLP
        self.enc_mlp = nn.Sequential(
            nn.Linear(self.hidden_d, self.hidden_d),
            nn.ReLU()
        )

        # 6) Classification MLP
        self.mlp = nn.Sequential(
            nn.Linear(self.hidden_d, out_d),
            nn.Softmax(dim=-1)
        )

    def forward(self, images):
        # Dividing images into patches
        n, c, w, h = images.shape
        patches = images.reshape(n, self.n_patches ** 2, self.input_d)

        # Running linear layer for tokenization
        tokens = self.linear_mapper(patches)

        # Adding classification token to the tokens
        tokens = torch.stack([torch.vstack((self.class_token, tokens[i])) for i in range(len(tokens))])

        # Adding positional embedding
        tokens += get_positional_embeddings(self.n_patches ** 2 + 1, self.hidden_d, device).repeat(n, 1, 1)

        # TRANSFORMER ENCODER BEGINS ###################################
        # NOTICE: MULTIPLE ENCODER BLOCKS CAN BE STACKED TOGETHER ######
        # Running Layer Normalization, MSA and residual connection
        self.msa(self.ln1(tokens.to("cpu")).to(device))
        out = tokens + self.msa(self.ln1(tokens))

        # Running Layer Normalization, MLP and residual connection
        out = out + self.enc_mlp(self.ln2(out))
        # TRANSFORMER ENCODER ENDS   ###################################

        # Getting the classification token only
        out = out[:, 0]

        return self.mlp(out)

#### Step 1: Patchifying and the linear mapping

The transformer encoder was developed with sequence data in mind, such as English sentences. However, an image is not a sequence. Thus, we break it into multiple sub-images and map each sub-image to a vector.

We do so by simply reshaping our input, which has size $(N, C, H, W)$ (in our example $(N, 1, 28, 28)$), to size (N, #Patches, Patch dimensionality), where the dimensionality of a patch is adjusted accordingly.

In MNIST, we break each $(1, 28, 28)$ into 7x7 patches (hence, each of size 4x4). That is, we are going to obtain 7x7=49 sub-images out of a single image.

$$(N,1,28,28) \rightarrow (N,P\times P, H \times C/P  \times W \times C/P) \rightarrow (N, 7\times 7, 4\times 4) \rightarrow (N, 49, 16)$$

<img src="img/patch.png" title="an image is split into patches" />

#### Step 2: Adding the classification token

When information about all other tokens will be present here, we will be able to classify the image using only this special token. The initial value of the special token (the one fed to the transformer encoder) is a parameter of the model that needs to be learned.

We can now add a parameter to our model and convert our (N, 49, 8) tokens tensor to an (N, 50, 8) tensor (we add the special token to each sequence).

Passing from (N,49,8) → (N,50,8) is probably sub-optimal. Also, notice that the classification token is put as the first token of each sequence. This will be important to keep in mind when we will then retrieve the classification token to feed to the final MLP.

#### Step 3: Positional encoding

See above, as we mentioned.

#### Step 4: LN, MSA, and Residual Connection

The step is to apply layer normalization to the tokens, then apply MSA, and add a residual connection (add the input we had before applying LN).
- **Layer normalization** is a popular block that, given an input, subtracts its mean and divides by the standard deviation.
- **MSA**: same as the vanilla transformer.
- **A residual connection** consists in just adding the original input to the result of some computation. This, intuitively, allows a network to become more powerful while also preserving the set of possible functions that the model can approximate.

The residual connection is added at the original (N, 50, 8) tensor to the (N, 50, 8) obtained after LN and MSA.

#### Step 5: LN, MLP, and Residual Connection
All that is left to the transformer encoder is just a simple residual connection between what we already have and what we get after passing the current tensor through another LN and an MLP.

#### Step 6: Classification MLP
Finally, we can extract just the classification token (first token) out of our N sequences, and use each token to get N classifications.
Since we decided that each token is an 8-dimensional vector, and since we have 10 possible digits, we can implement the classification MLP as a simple 8x10 matrix, activated with the SoftMax function.

The output of our model is now an (N, 10) tensor.

In [7]:
# device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
# print('Using device', device)
# cuda is does not work now -> kernels always died
device = "cpu"

Using device cuda:1


In [8]:
model = ViT((1, 28, 28), n_patches=7, hidden_d=20, n_heads=2, out_d=10)
model = model.to(device)

N_EPOCHS = 5
LR = 0.01
optimizer = Adam(model.parameters(), lr=LR)

In [9]:
train_ViT_classify(model, optimizer, N_EPOCHS, train_loader, device)

Epoch 1/5 loss: 406.94
Epoch 2/5 loss: 373.10
Epoch 3/5 loss: 367.05
Epoch 4/5 loss: 365.66
Epoch 5/5 loss: 364.58


In [10]:
test_ViT_classify(model, optimizer, test_loader)

Test loss: 60.45
Test accuracy: 91.45%


The testing accuracy is around 90% and our implement is done.

### Pytorch ViT

[Here](https://github.com/lucidrains/vit-pytorch#vision-transformer---pytorch) is the link of the full version of ViT using pytorch.

In [11]:
!pip install vit-pytorch

Collecting vit-pytorch
  Downloading vit_pytorch-0.29.0-py3-none-any.whl (56 kB)
[K     |████████████████████████████████| 56 kB 1.6 MB/s eta 0:00:011
Collecting einops>=0.4.1
  Downloading einops-0.4.1-py3-none-any.whl (28 kB)
Installing collected packages: einops, vit-pytorch
Successfully installed einops-0.4.1 vit-pytorch-0.29.0


#### Using ViT Pytorch

In [12]:
import torch
from vit_pytorch import ViT

v = ViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 16,
    mlp_dim = 2048,
    dropout = 0.1,
    emb_dropout = 0.1
)

img = torch.randn(1, 3, 256, 256)

preds = v(img) # (1, 1000)

DistillableViT

In [13]:
import torch
from torchvision.models import resnet50

from vit_pytorch.distill import DistillableViT, DistillWrapper

teacher = resnet50(pretrained = True)

v = DistillableViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 8,
    mlp_dim = 2048,
    dropout = 0.1,
    emb_dropout = 0.1
)

distiller = DistillWrapper(
    student = v,
    teacher = teacher,
    temperature = 3,           # temperature of distillation
    alpha = 0.5,               # trade between main loss and distillation loss
    hard = False               # whether to use soft or hard distillation
)

img = torch.randn(2, 3, 256, 256)
labels = torch.randint(0, 1000, (2,))

loss = distiller(img, labels)
loss.backward()

# after lots of training above ...

pred = v(img) # (2, 1000)

and so on...

## Exercise

