# Multi-Head Attention in Pytorch
[link](https://ut.philkr.net/deeplearning/transformers/multihead_attention_in_pytorch/)

In [9]:
import torch
from mha import MultiHeadAttention

## Multi-Head Attention (in pytorch)

The PyTorch implementation corresponds to the usual special case where
 * $d_h:=d_k=d_v$ (head dimensions)
 * $d := c_q = d_h\cdot h = d_o$ (model dimension)  

Denote $d_h = d/h$ ($h$ must divide $d$). Then we have the simplified definition:

Inputs
* a set of queries $Q = [q_1|q_2|\dots|q_M]^T \in \mathbb{R}^{M \times d}$
* a set of keys $K = [k_1|k_2|\dots|k_N]^T \in \mathbb{R}^{N \times c_k}$
* a set of values $V = [v_1|v_2|\dots|v_N]^T \in \mathbb{R}^{N \times c_v}$

Weights
* $W_{q,i} \in \mathbb{R}^{d \times d_h}$ and $b_{q,i}\in \mathbb{R}^{d_h}$, $\quad i=1,2,\dots,h.$
* $W_{k,i} \in \mathbb{R}^{c_k \times d_h}$ and $b_{k,i}\in \mathbb{R}^{d_h}$, $\quad i=1,2,\dots,h.$
* $W_{v,i} \in \mathbb{R}^{c_v \times d_h}$ and $b_{v,i}\in \mathbb{R}^{d_h}$, $\quad i=1,2,\dots,h.$
* $W_{o} \in \mathbb{R}^{d \times d }$ and $b_{o}\in \mathbb{R}^{d}$.


Output: 
* Output $O= [o_1|o_2|\dots|o_M]^T \in\mathbb{R}^{M\times d}$
\begin{align*}
O  &= \text{MultiHeadAttention}_{\{\mathcal{W}_i\}_{i=1}^h\cup \{W_o\}}(Q,K,V),\\
&=\begin{pmatrix}
\text{Attention}_{\mathcal{W}_1}(Q,K,V)|
\text{Attention}_{\mathcal{W}_2}(Q,K,V)|
\dots|
\text{Attention}_{\mathcal{W}_h}(Q,K,V)
\end{pmatrix}W_o + B_o
\end{align*}
where:
    * $W_q=[W_{q,1}|W_{q,2}|\dots|W_{q,h}]\in\mathbb{R}^{d\times d}$  
    * $W_k=[W_{k,1}|W_{k,2}|\dots|W_{k,h}]\in\mathbb{R}^{c_k\times d}$  
    * $W_v=[W_{v,1}|W_{v,2}|\dots|W_{v,h}]\in\mathbb{R}^{c_v\times d}$  
    * $b_q=[b^T_{q,1}|b^T_{q,2}|\dots|b^T_{q,h}]^T\in\mathbb{R}^{d}$  
    * $b_k=[b^T_{k,1}|b^T_{k,2}|\dots|b^T_{k,h}]^T\in\mathbb{R}^{d}$ 
    * $b_v=[b^T_{v,1}|b^T_{v,2}|\dots|b^T_{v,h}]^T\in\mathbb{R}^{d}$  
    * $B_o= [b_o|b_o|\dots|b_o]^T \in \mathbb{R}^{M \times d}$
    * $\mathcal{W}_i = \{W_{q,i},b_{q,i},W_{k,i},b_{k,i},W_{v,i},b_{v,i}\}, \quad i = 1,2,\dots,h$.

### Hyper-Parameters

In [10]:
# Shapes and switches shared by both attention implementations below.

# Fixed seed: the notebook draws random inputs and random initial weights, so
# without it a fresh "Restart & Run All" produces different numbers each time.
SEED = 0
torch.manual_seed(SEED)

batch_dim = 10
M = 5  # sequence length of q
N = 3  # sequence length of k,v
d = 16 # embedding/model dimension
ck = 32  # feature size of the key inputs
cv = 64  # feature size of the value inputs
h = 2  # number of heads

bias = False        # forwarded as `bias=` to both attention modules
add_bias_kv = True  # forwarded as `add_bias_kv=` to both attention modules

In [11]:
# Reference implementation: PyTorch's built-in multi-head attention layer.
# kdim/vdim let keys and values arrive with their own feature sizes, and
# batch_first=True puts the batch axis first in the inputs.
torch_attn = torch.nn.MultiheadAttention(
    embed_dim=d,
    num_heads=h,
    kdim=ck,
    vdim=cv,
    bias=bias,
    add_bias_kv=add_bias_kv,
    batch_first=True,
)

#### Input

In [12]:
# Random inputs with the batch axis first. Queries use the model dimension d,
# while keys and values keep their own feature sizes ck and cv.
q = torch.rand((batch_dim, M, d))
k = torch.rand((batch_dim, N, ck))
v = torch.rand((batch_dim, N, cv))

#### Output

In [13]:
# The module returns a pair; only its first element (the attention output)
# is used in this notebook.
o_torch, _ = torch_attn(query=q, key=k, value=v)

# Print the expected shape (batch_dim, M, d) next to the actual one.
print(f"dim({batch_dim},{M},{d}) = {tuple(o_torch.shape)}")

dim(10,5,16) = (10, 5, 16)


## Custom Multi-head attention

In [14]:
# The model dimension is split evenly across the heads.
assert d % h == 0, "d must be divisible by h"
dh = d // h

# Our implementation, configured like the PyTorch module: every head uses
# dh for both its key/query and value size, and the output size equals d.
our_attn = MultiHeadAttention(
    cq=d,
    ck=ck,
    cv=cv,
    dk=dh,
    dv=dh,
    do=d,
    h=h,
    bias=bias,
    add_bias_kv=add_bias_kv,
)

### Weights

In [15]:
# List each learnable tensor of our module with its shape.
for param_name, param in our_attn.named_parameters():
    print(f"{param_name} shape: {tuple(param.shape)}")

bias_k shape: (1, 1, 16)
bias_v shape: (1, 1, 16)
q_proj.weight shape: (16, 16)
k_proj.weight shape: (16, 32)
v_proj.weight shape: (16, 64)
out_proj.weight shape: (16, 16)


In [16]:
# List each learnable tensor of the PyTorch module with its shape,
# for comparison with our module's parameters.
for param_name, param in torch_attn.named_parameters():
    print(f"{param_name} shape: {tuple(param.shape)}")

q_proj_weight shape: (16, 16)
k_proj_weight shape: (16, 32)
v_proj_weight shape: (16, 64)
bias_k shape: (1, 1, 16)
bias_v shape: (1, 1, 16)
out_proj.weight shape: (16, 16)


#### Load weights

In [17]:
# Copy every parameter of the PyTorch reference module into our module so that
# both compute the same function. The copies are in-place writes into
# parameters, so they are done under no_grad to keep autograd out of it.
with torch.no_grad():
    # 1) input projection weights (shapes already match)
    our_attn.q_proj.weight.copy_(torch_attn.q_proj_weight)   # (d, d)
    our_attn.k_proj.weight.copy_(torch_attn.k_proj_weight)   # (d, ck)
    our_attn.v_proj.weight.copy_(torch_attn.v_proj_weight)   # (d, cv)

    # 2) input projection biases. PyTorch packs them into a single (3d,)
    #    vector that only exists when bias=True (it is None otherwise), so it
    #    is read inside the guard: (3d,) -> (d,) + (d,) + (d,)
    if bias:
        in_proj_bias = torch_attn.in_proj_bias
        our_attn.q_proj.bias.copy_(in_proj_bias[0:d])        # 0:d
        our_attn.k_proj.bias.copy_(in_proj_bias[d:2*d])      # d:2d
        our_attn.v_proj.bias.copy_(in_proj_bias[2*d:3*d])    # 2d:3d

    # 3) extra key/value bias parameters. Both modules store them as
    #    (1, 1, d) (see the parameter listings above), so they are copied
    #    as-is without any reshaping.
    if add_bias_kv:
        our_attn.bias_k.copy_(torch_attn.bias_k)
        our_attn.bias_v.copy_(torch_attn.bias_v)

    # 4) output projection
    our_attn.out_proj.weight.copy_(torch_attn.out_proj.weight)
    if bias:
        our_attn.out_proj.bias.copy_(torch_attn.out_proj.bias)

In [18]:
# Run both implementations on the same inputs. The PyTorch module returns a
# pair, of which only the attention output is kept.
o_our = our_attn(q, k, v)
o_torch, _ = torch_attn(q, k, v)

# With identical weights the two outputs must agree up to floating-point
# rounding. Check it programmatically so a mismatch fails the notebook
# instead of relying on a visual comparison of the printed tensors.
torch.testing.assert_close(o_our, o_torch, rtol=1e-4, atol=1e-5)

In [19]:
# Show the output of the PyTorch reference module.
o_torch

tensor([[[ 1.5013e-01, -1.2651e-01, -6.1665e-02,  1.0711e-02, -2.2102e-01,
           3.5204e-01, -3.0692e-01,  5.1115e-02, -3.1570e-02, -2.2177e-01,
          -3.5550e-01,  6.6265e-01, -7.2389e-02,  4.4651e-01,  3.1056e-01,
           5.4125e-01],
         [ 1.6387e-01, -1.4438e-01, -6.9482e-02, -3.1283e-02, -2.1814e-01,
           2.8269e-01, -2.9591e-01,  9.4113e-02,  2.1844e-03, -2.8042e-01,
          -3.8915e-01,  5.9910e-01, -1.8294e-02,  3.9788e-01,  2.3717e-01,
           4.9621e-01],
         [ 1.6837e-01, -1.4277e-01, -5.4942e-02, -6.4830e-03, -2.2956e-01,
           3.4378e-01, -2.9047e-01,  5.7376e-02, -2.2718e-02, -2.4855e-01,
          -3.6860e-01,  6.7153e-01, -7.0738e-02,  4.4581e-01,  3.2029e-01,
           5.4473e-01],
         [ 1.7244e-01, -1.5290e-01, -6.3431e-02, -3.3558e-02, -2.1663e-01,
           3.0943e-01, -2.8010e-01,  8.8373e-02, -7.3300e-03, -2.8857e-01,
          -3.9613e-01,  6.3131e-01, -4.5604e-02,  4.0592e-01,  2.6818e-01,
           5.1075e-01],
    

In [20]:
# Show the output of our implementation; it should match `o_torch`.
o_our

tensor([[[ 1.5013e-01, -1.2651e-01, -6.1665e-02,  1.0711e-02, -2.2102e-01,
           3.5204e-01, -3.0692e-01,  5.1115e-02, -3.1570e-02, -2.2177e-01,
          -3.5550e-01,  6.6265e-01, -7.2389e-02,  4.4651e-01,  3.1056e-01,
           5.4125e-01],
         [ 1.6387e-01, -1.4438e-01, -6.9482e-02, -3.1283e-02, -2.1814e-01,
           2.8269e-01, -2.9591e-01,  9.4113e-02,  2.1844e-03, -2.8042e-01,
          -3.8915e-01,  5.9910e-01, -1.8294e-02,  3.9788e-01,  2.3717e-01,
           4.9621e-01],
         [ 1.6837e-01, -1.4277e-01, -5.4942e-02, -6.4830e-03, -2.2956e-01,
           3.4378e-01, -2.9047e-01,  5.7376e-02, -2.2718e-02, -2.4855e-01,
          -3.6860e-01,  6.7153e-01, -7.0738e-02,  4.4581e-01,  3.2029e-01,
           5.4473e-01],
         [ 1.7244e-01, -1.5290e-01, -6.3431e-02, -3.3558e-02, -2.1663e-01,
           3.0943e-01, -2.8010e-01,  8.8373e-02, -7.3300e-03, -2.8857e-01,
          -3.9613e-01,  6.3131e-01, -4.5604e-02,  4.0592e-01,  2.6818e-01,
           5.1075e-01],
    