In [158]:
import torch
import torch.nn as nn
from torchinfo import summary
import numpy as np
from torchvision.models.vision_transformer import ViT_B_16_Weights
from torchvision.models.vision_transformer import vit_b_16

<img src="./ViT.png"></img>

| Model | Layers | Hidden size $D$ | MLP size | Heads | Params |
| :--- | :---: | :---: | :---: | :---: | :---: |
| ViT-Base | 12 | 768 | 3072 | 12 | $86 \mathrm{M}$ |

In [None]:
# ViT-Base hyper-parameters (see the table above).
L = 12       # number of encoder layers
D = 768      # hidden size (embedding dimension)
HEADS = 12   # attention heads per layer

PATCH = 16     # patch side length P, in pixels
IMAGE_W = 224  # input image side length (square images)
# Integer division keeps the derived sizes as ints, so they can be used directly as tensor dimensions.
N = (IMAGE_W // PATCH) ** 2  # number of patches, N = HW / P^2
DH = D // HEADS  # To keep num of params constant we set DH = D/HEADS
DMLP = 3072  # 4 times the D

To handle 2D images, we reshape the image $\mathbf{x} \in \mathbb{R}^{H \times W \times C}$ into a sequence of flattened 2D patches $\mathbf{x}_p \in \mathbb{R}^{N \times\left(P^2 \cdot C\right)}$,<br> where $(H, W)$ is the resolution of the original image, $C$ is the number of channels, $(P, P)$ is the resolution of each image patch, and $N=H W / P^2$ is the resulting number of patches.

$$
\mathbf{x}_p \in \mathbb{R}^{N \times\left(P^2 \cdot C\right)}
$$

Similar to BERT's [class] token, we prepend a learnable embedding to the sequence of embedded patches $\left(\mathbf{z}_0^0=\mathbf{x}_{\text {class }}\right)$,<br>
whose state at the output of the Transformer encoder $\left(\mathbf{z}_L^0\right)$ serves as the image representation y.<br>
Both during pre-training and fine-tuning, a classification head is attached to $\mathbf{z}_L^0$.<br>
The classification head is implemented by a MLP with one hidden layer at pre-training time and by a single linear layer at fine-tuning time.

Layernorm (LN) is applied before every block, and residual connections after every block.<br>
The MLP contains two layers with a GELU non-linearity.

$$
\begin{aligned}
\mathbf{z}_0 & =\left[\mathbf{x}_{\text {class }} ; \mathbf{x}_p^1 \mathbf{E} ; \mathbf{x}_p^2 \mathbf{E} ; \cdots ; \mathbf{x}_p^N \mathbf{E}\right]+\mathbf{E}_{\text {pos }}, & & \mathbf{E} \in \mathbb{R}^{\left(P^2 \cdot C\right) \times D}, \mathbf{E}_{\text {pos }} \in \mathbb{R}^{(N+1) \times D} \\
\mathbf{z}_{\ell}^{\prime} & =\operatorname{MSA}\left(\operatorname{LN}\left(\mathbf{z}_{\ell-1}\right)\right)+\mathbf{z}_{\ell-1}, & & \ell=1 \ldots L \\
\mathbf{z}_{\ell} & =\operatorname{MLP}\left(\operatorname{LN}\left(\mathbf{z}_{\ell}^{\prime}\right)\right)+\mathbf{z}_{\ell}^{\prime}, & & \ell=1 \ldots L \\
\mathbf{y} & =\operatorname{LN}\left(\mathbf{z}_L^0\right) & &
\end{aligned}
$$

Standard qkv self-attention (SA, Vaswani et al. (2017))<br>
For each element in an input sequence $\mathbf{z} \in \mathbb{R}^{N \times D}$, we compute a weighted sum over all values $\mathbf{v}$ in the sequence.<br> 
The attention weights $A_{i j}$ are based on the pairwise similarity between two elements of the sequence and their respective query $\mathbf{q}^i$ and key $\mathbf{k}^j$ representations.
$$
\begin{array}{rlrl}
{[\mathbf{q}, \mathbf{k}, \mathbf{v}]} & =\mathbf{z} \mathbf{U}_{q k v} & \mathbf{U}_{q k v} & \in \mathbb{R}^{D \times 3 D_h}, \\
A & =\operatorname{softmax}\left(\mathbf{q} \mathbf{k}^{\top} / \sqrt{D_h}\right) & A \in \mathbb{R}^{N \times N} \\
\mathrm{SA}(\mathbf{z}) & =A \mathbf{v} &
\end{array}
$$
Multihead self-attention (MSA) is an extension of SA in which we run $k$ self-attention operations, called "heads", in parallel, and project their concatenated outputs.<br>
To keep compute and number of parameters constant when changing $k$, $D_h$ (Eq. 5) is typically set to $D / k$.
$$
\operatorname{MSA}(\mathbf{z})=\left[\mathrm{SA}_1(z) ; \mathrm{SA}_2(z) ; \cdots ; \mathrm{SA}_k(z)\right] \mathbf{U}_{m s a} \quad \mathbf{U}_{m s a} \in \mathbb{R}^{k \cdot D_h \times D}
$$

Dropout, when used, is applied after every dense layer except for the qkv-projections and directly after adding positional- to patch embeddings.<br>
Finally, all training is done on resolution 224.

In order to stay as close as possible to the original Transformer model, we made use of an additional [class] token, which is taken as image representation.<br>
The output of this token is then transformed into a class prediction via a small multi-layer perceptron (MLP) with tanh as non-linearity in the single hidden layer.

In [None]:
# Default pretrained weight set for torchvision's ViT-B/16.
w = ViT_B_16_Weights.DEFAULT
# Fetch the checkpoint once here. The trailing semicolon stops Jupyter from
# rendering the returned state dict (every parameter tensor) as cell output.
w.get_state_dict(progress=True);


In [200]:
# Pretrained ViT-B/16 parameters as a name -> tensor mapping. Reused below to
# populate the reference model and to list the parameter names.
state = w.get_state_dict(progress=True)

In [201]:
# Reference torchvision ViT-B/16: instantiated with no arguments, then
# populated from the pretrained `state` dict.
ref_m = vit_b_16()
# Last expression of the cell: the key-matching report returned by
# load_state_dict is what gets displayed as the cell output.
ref_m.load_state_dict(state)

<All keys matched successfully>

<pre>
class_token: [1, 1, D] = [1, 1, 768] second dimension is the number of embeddings (a single class token), so it's always 1
conv_proj: Conv2d(3, D, kernel_size=(16, 16), stride=(16, 16))
conv_proj.weight: [B, 3, 224, 224] -> [OUT_CHAN, IN_CHAN, KERNEL, KERNEL] = [D, 3, P, P] = [768, 3, 16, 16] -> [B, 768, 224/16, 224/16] = [B, 768, 14, 14]
conv_proj.bias: [768]
flatten: [B, 768, 14, 14] -> [B, 768, 14 * 14] = [B, 768, 196]
reshape [B, 768, 196] -> [B, 196, 768]
encoder.pos_embedding: [1, N+1, D] = [1, 197, 768]
dropout
    encoder.layers.encoder_layer_0.ln_1: LayerNorm((768,), eps=1e-06, elementwise_affine=True)
    encoder.layers.encoder_layer_0.ln_1.weight: [D] = [768] has learnable weights for each component of the vector
    encoder.layers.encoder_layer_0.ln_1.bias: [D] = [768]
    encoder.layers.encoder_layer_0.self_attention: 
        encoder.layers.encoder_layer_0.self_attention.in_proj_weight: [HEADS * DH * 3, D] = [2304, 768]
        encoder.layers.encoder_layer_0.self_attention.in_proj_bias : [2304]
        attention dropout is optional in pytorch's implementation, off by default. The paper doesn't have it
        encoder.layers.encoder_layer_0.self_attention.out_proj.weight: [HEADS * DH, D] = [768, 768]
        encoder.layers.encoder_layer_0.self_attention.out_proj.bias: [768]
    dropout
    residual
    encoder.layers.encoder_layer_0.ln_2: LayerNorm((768,), eps=1e-06, elementwise_affine=True)
    encoder.layers.encoder_layer_0.mlp.linear_1: Linear(in_features=D, out_features=DMLP, bias=True)
    GELU
    dropout
    encoder.layers.encoder_layer_0.mlp.linear_2: Linear(in_features=DMLP, out_features=D, bias=True)
    dropout
    residual
encoder.ln: LayerNorm((768,), eps=1e-06, elementwise_affine=True)
heads.head: Linear(in_features=768, out_features=CLASSES, bias=True)
</pre>

In [202]:
# Layer-by-layer overview of the reference model for a single 3-channel
# IMAGE_W x IMAGE_W input, expanded four levels deep and labelled with the
# attribute (variable) name of each sub-module.
summary(
    ref_m,
    input_size=(1, 3, IMAGE_W, IMAGE_W),
    depth=4,
    col_names=["kernel_size", "input_size", "output_size", "num_params"],
    row_settings=["var_names"],
)

Layer (type (var_name))                                      Kernel Shape              Input Shape               Output Shape              Param #
VisionTransformer (VisionTransformer)                        --                        [1, 3, 224, 224]          [1, 1000]                 768
├─Conv2d (conv_proj)                                         [16, 16]                  [1, 3, 224, 224]          [1, 768, 14, 14]          590,592
├─Encoder (encoder)                                          --                        [1, 197, 768]             [1, 197, 768]             151,296
│    └─Dropout (dropout)                                     --                        [1, 197, 768]             [1, 197, 768]             --
│    └─Sequential (layers)                                   --                        [1, 197, 768]             [1, 197, 768]             --
│    │    └─EncoderBlock (encoder_layer_0)                   --                        [1, 197, 768]             [1, 197, 768]      

In [206]:
import math

# Print every state-dict key as a fixed-width table that is read column by
# column: `rows` consecutive keys per column, columns separated by `margin` spaces.
keys = np.array(list(state.keys()))

# Width of the longest key; every cell is left-justified to this width.
max_len = max((len(k) for k in keys), default=-1)

rows = 10
n_cols = math.ceil(len(keys) / rows)
# Pad with empty strings so the keys fill a complete rows x n_cols grid.
pad_end = n_cols * rows - len(keys)
keys = np.pad(keys, (0, pad_end), constant_values='')

margin = 4
# Each split holds one column of `rows` consecutive keys; transposing turns
# the columns into printable lines with one key from every column.
lines = np.stack(np.array_split(keys, n_cols)).T
for line in lines:
    print((' ' * margin).join(str(key).ljust(max_len) for key in line))

class_token                                                       encoder.layers.encoder_layer_0.ln_2.weight                        encoder.layers.encoder_layer_1.self_attention.out_proj.weight     encoder.layers.encoder_layer_2.self_attention.in_proj_weight      encoder.layers.encoder_layer_3.ln_1.weight                        encoder.layers.encoder_layer_3.mlp.linear_2.weight                encoder.layers.encoder_layer_4.mlp.linear_1.weight                encoder.layers.encoder_layer_5.ln_2.weight                        encoder.layers.encoder_layer_6.self_attention.out_proj.weight     encoder.layers.encoder_layer_7.self_attention.in_proj_weight      encoder.layers.encoder_layer_8.ln_1.weight                        encoder.layers.encoder_layer_8.mlp.linear_2.weight                encoder.layers.encoder_layer_9.mlp.linear_1.weight                encoder.layers.encoder_layer_10.ln_2.weight                       encoder.layers.encoder_layer_11.self_attention.out_proj.weight    heads.head