# Multi-Head Self-Attention
![self_multi_attention](./../image/self_multi_atten.png)

上面的 self-attention 描述的其实就是下面这个公式：
$$ \mathrm{Attention}(Q, K, V) = \mathrm{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V $$

Multi-head attention 的计算公式为：
$$ \mathrm{MultiHead}(Q, K, V) = \mathrm{Concat}(\mathrm{head}_1, \mathrm{head}_2, \dots, \mathrm{head}_h)W^O $$
其中
$$ \mathrm{head}_i = \mathrm{Attention}(QW_i^Q, KW_i^K, VW_i^V) $$

In [11]:
import math
import torch
import torch.nn as nn

class MultiHeadSelfAtten(nn.Module):
    """Multi-head scaled dot-product self-attention.

    The input is projected to queries, keys and values, split into
    ``head_num`` heads of size ``hidden_dim // head_num``, attended per head,
    concatenated back together and passed through an output projection.

    Args:
        hidden_dim: Size of the last dimension of the input and the output.
        head_num: Number of attention heads; must divide ``hidden_dim``.
        attention_dropout: Dropout probability applied to the attention
            weights after the softmax.

    Raises:
        ValueError: If ``head_num`` is not positive or does not divide
            ``hidden_dim`` evenly.
    """

    def __init__(self, hidden_dim, head_num, attention_dropout=0.1):
        super().__init__()
        # Fail early with a clear message instead of an opaque shape error
        # from ``view`` inside ``forward``.
        if head_num <= 0 or hidden_dim % head_num != 0:
            raise ValueError(
                f"hidden_dim ({hidden_dim}) must be divisible by a positive "
                f"head_num ({head_num})"
            )
        self.hidden_dim = hidden_dim
        self.head_dim = hidden_dim // head_num
        self.head_num = head_num

        # Each projection maps hidden_dim -> head_num * head_dim (== hidden_dim).
        self.q = nn.Linear(hidden_dim, hidden_dim)
        self.k = nn.Linear(hidden_dim, hidden_dim)
        self.v = nn.Linear(hidden_dim, hidden_dim)
        self.out_proj = nn.Linear(hidden_dim, hidden_dim)

        self.attention_dropout = nn.Dropout(attention_dropout)

    def forward(self, x, attention_mask=None):
        """Apply self-attention to ``x``.

        Args:
            x: Tensor of shape ``(batch, seq_len, hidden_dim)``.
            attention_mask: Optional tensor broadcastable to
                ``(batch, head_num, seq_len, seq_len)``. Positions equal to 0
                are excluded from attention.

        Returns:
            Tensor of shape ``(batch, seq_len, hidden_dim)``. A query whose
            keys are all masked receives zero attention weights (so its
            output is the output-projection bias) rather than NaN.
        """
        batch, seq_len, _ = x.size()

        Q = self.q(x)
        K = self.k(x)
        V = self.v(x)  # (b, s, h)

        # (b, s, h) -> (b, head_num, s, head_dim)
        q_state = Q.view(batch, seq_len, self.head_num, self.head_dim).transpose(1, 2)
        k_state = K.view(batch, seq_len, self.head_num, self.head_dim).transpose(1, 2)
        v_state = V.view(batch, seq_len, self.head_num, self.head_dim).transpose(1, 2)

        # (b, head_num, s, s)
        attention_weight = q_state @ k_state.transpose(-1, -2) / math.sqrt(self.head_dim)

        dead_rows = None
        if attention_mask is not None:
            blocked = torch.broadcast_to(attention_mask == 0, attention_weight.shape)
            attention_weight = attention_weight.masked_fill(blocked, float('-inf'))
            # A row made only of -inf turns into NaN under softmax (and poisons
            # gradients), so neutralise such rows before the softmax and zero
            # their weights afterwards.
            dead_rows = blocked.all(dim=-1, keepdim=True)
            attention_weight = attention_weight.masked_fill(dead_rows, 0.0)

        attention_weight = torch.softmax(attention_weight, dim=-1)
        if dead_rows is not None:
            attention_weight = attention_weight.masked_fill(dead_rows, 0.0)
        attention_weight = self.attention_dropout(attention_weight)
        # (b, head_num, s, s) @ (b, head_num, s, head_dim) -> (b, head_num, s, head_dim)
        output_mid = attention_weight @ v_state

        # (b, head_num, s, head_dim) -transpose-> (b, s, head_num, head_dim);
        # ``contiguous`` is required before ``view`` after a transpose.
        output_mid = output_mid.transpose(1, 2).contiguous()
        # (b, s, head_num, head_dim) -view-> (b, s, hidden_dim)
        output_mid = output_mid.view(batch, seq_len, -1)

        output = self.out_proj(output_mid)
        return output
    
# Smoke test: batch of 3 sequences, 2 tokens each, hidden size 128.
X = torch.rand(3, 2, 128)
# Key-padding mask, one row per sequence (0 = masked key).
# mask (3, 2) -> add head and query axes (3, 1, 1, 2) -> expand to (3, 8, 2, 2)
mask = torch.tensor(
    [
        [0, 1],
        [0, 0],
        [1, 0],
    ]
)[:, None, None, :].expand(3, 8, 2, 2)

# 8 attention heads over a hidden size of 128
att_net = MultiHeadSelfAtten(hidden_dim=128, head_num=8)
out = att_net(X, attention_mask=mask)
out

tensor([[[-0.0471,  0.1656,  0.2075,  0.0242, -0.0959, -0.3411,  0.1504,
          -0.4096,  0.2382,  0.2489,  0.2209,  0.1893, -0.0222,  0.1877,
          -0.2426, -0.2405,  0.2585, -0.3117, -0.1795,  0.0900,  0.5275,
           0.3705, -0.1212, -0.2062,  0.4271, -0.2557,  0.1747, -0.1633,
          -0.1529, -0.0054, -0.1377,  0.2733, -0.4722,  0.1235, -0.1139,
          -0.2334,  0.2540, -0.3412,  0.1857, -0.0594,  0.1311,  0.0223,
           0.0512,  0.0360, -0.0992,  0.2679,  0.1296,  0.3204, -0.4174,
           0.2088,  0.1322, -0.0610,  0.1558, -0.1672,  0.1751, -0.3704,
           0.0196, -0.0290, -0.2155, -0.1648,  0.0823,  0.1269,  0.3863,
          -0.3423, -0.1285,  0.0373, -0.1913, -0.1605, -0.4552,  0.2384,
           0.0624,  0.0482,  0.0336, -0.0083, -0.3587, -0.1341,  0.1184,
           0.7127,  0.4748,  0.1044,  0.1637, -0.1446, -0.0902,  0.0372,
           0.1216, -0.3071,  0.2499,  0.1505,  0.0847, -0.2407,  0.3632,
           0.1685, -0.2675, -0.2420, -0.0836, -0.06