In [10]:
import torch
import math
import torch.nn as nn
from timm.layers import trunc_normal_
from typing import Type

### model config for vit_large_patch16_224 (search Google)
seed = 0  # fixed RNG seed so a fresh kernel reproduces the same tensors
dtype = torch.float32
embed_dim = 1024
qkv_bias = True
num_heads = 16
norm_layer: Type[nn.Module] = nn.LayerNorm
qk_norm = True
dropout_p = 0.0

# The per-head reshape further down needs embed_dim to split evenly across heads.
assert embed_dim % num_heads == 0, "embed_dim should be divisible by num_heads"

# Seed before any random tensor or layer is created; without it `x` and the
# default nn.Linear initialisation change on every kernel restart.
torch.manual_seed(seed)

# if use input minibatch shape (16, 3, 224, 224) for vit_large_patch16_224, then x tensor shape would be
# (batch, tokens, embed_dim) with (224 / 16) ** 2 = 196 patch tokens per image.
x = torch.randn((16, 196, embed_dim), dtype=dtype)

# initialised vars from model config:
head_dim = embed_dim // num_heads  # 1024 // 16 = 64 channels per head
scale = head_dim ** -0.5           # 1 / sqrt(head_dim) = 0.125
# One linear layer emits q, k and v concatenated along the last dimension.
qkv = nn.Linear(embed_dim, embed_dim * 3, bias=qkv_bias)
# Normalisation applied per head over head_dim; identity when qk_norm is off.
q_norm = norm_layer(head_dim) if qk_norm else nn.Identity()
k_norm = norm_layer(head_dim) if qk_norm else nn.Identity()

# Adapted from timm: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L858-L865
def init_weights_vit_timm(module: nn.Module, name: str = '') -> None:
    """Initialise one module the way timm's original ViT implementation does.

    Args:
        module: Module to initialise in place.
        name: Accepted for signature compatibility; not used by this function.

    Behaviour:
        * ``nn.Linear``: weight is drawn with ``trunc_normal_(std=.02)`` and the
          bias, when present, is zeroed.
        * Any other module exposing ``init_weights``: that method is called.
        * Everything else is left untouched.
    """
    if isinstance(module, nn.Linear):
        trunc_normal_(module.weight, std=.02)
        linear_bias = module.bias
        if linear_bias is not None:
            nn.init.zeros_(linear_bias)
        return

    # Non-linear modules delegate to their own initialiser when they have one.
    if hasattr(module, 'init_weights'):
        module.init_weights()
# Re-initialise the projection layer with the timm ViT scheme.
init_weights_vit_timm(qkv)

# Creating Q, K, V
batch_size, num_seq, channels = x.shape
# Keep the projected tensor under its own name: rebinding `qkv` would shadow the
# nn.Linear layer with its output tensor and leave the layer unreachable afterwards.
# Shapes: (B, N, 3 * C) -> (B, N, 3, heads, head_dim) -> (3, B, heads, N, head_dim)
qkv_out = qkv(x).reshape(batch_size, num_seq, 3, num_heads, head_dim).permute(2, 0, 3, 1, 4)
# Split the leading axis of size 3 into the query, key and value tensors.
q, k, v = qkv_out.unbind(0)
# Normalise queries and keys (a no-op when q_norm / k_norm are nn.Identity).
q, k = q_norm(q), k_norm(k)

Minor test: check that the `@` operator and `torch.matmul` produce identical attention logits

In [13]:
# Build the scaled attention logits twice: once through torch.matmul and once
# through the @ operator, then report the L2 norm of their difference.
k_t = k.transpose(-2, -1)
attn_logit_1 = torch.matmul(q, k_t) * scale
attn_logit_2 = (q @ k_t) * scale
logit_gap = attn_logit_1.detach() - attn_logit_2.detach()
print("L2-norm for 1st and 2nd Attention Logits", torch.norm(logit_gap, p=2))

L2-norm for 1st and 2nd Attention Logits tensor(0.)


### Timm Multi-head Self-Attention vs F.scaled_dot_product_attention

- Attention logits are the scaled scores $QK^\top \cdot \text{scale}$, i.e. the values right before softmax is applied

In [14]:
### F.scaled_dot_product_attention
# Step-by-step form of scaled dot-product attention with no mask:
# scale the logits after the q @ k^T matmul, add a bias, softmax, dropout, weight v.
L, S = q.size(-2), k.size(-2)
attn_bias = torch.zeros(L, S, dtype=q.dtype)  # all zeros: nothing is masked

attn_logit_F = q @ k.transpose(-2, -1) * scale
attn_logit_F += attn_bias
attn_score_F = torch.softmax(attn_logit_F, dim=-1)
attn_score_F = torch.dropout(attn_score_F, dropout_p, train=True)
output_F = attn_score_F @ v

### Timm Multi-Head Self-Attention
attn_drop = nn.Dropout(dropout_p)

# Timm scales q before the matmul. Store the result under a new name: rebinding
# `q` here would make this cell non-idempotent (every re-run would scale q again)
# and would hand an already-scaled q to any other cell that is run afterwards.
q_scaled = q * scale
attn_logit_Timm = q_scaled @ k.transpose(-2, -1)
attn_score_Timm = attn_logit_Timm.softmax(dim=-1)
attn_score_Timm = attn_drop(attn_score_Timm)
output_Timm = attn_score_Timm @ v

# (q * scale) @ k^T and (q @ k^T) * scale are equal mathematically; in floating
# point they are only guaranteed to agree bit-for-bit when scale is a power of two,
# as with head_dim = 64 (scale = 0.125). Other head sizes may show tiny non-zero gaps.
# Score/output equality additionally relies on dropout_p == 0.0: with a non-zero
# probability the two dropout calls would draw different random masks.
print("L2 for attn_logit_F and attn_logit_Timm", torch.norm(attn_logit_F.detach() - attn_logit_Timm.detach(), p=2))
print("L2 for attn_score_F and attn_score_Timm", torch.norm(attn_score_F.detach() - attn_score_Timm.detach(), p=2))
print("L2 for output_F and output_Timm", torch.norm(output_F.detach() - output_Timm.detach(), p=2))

L2 for attn_logit_F and attn_logit_Timm tensor(0.)
L2 for attn_score_F and attn_score_Timm tensor(0.)
L2 for output_F and output_Timm tensor(0.)
