# Multi-Head Self-Attention
![self_multi_attention](./../image/self_multi_atten.png)

上面的self attention其实就是描述的
$$ \mathrm{Attention}(Q, K, V) = \mathrm{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V $$
这个公式。

MultiHead attention可以参考：
$$ \mathrm{MultiHead}(Q, K, V) = \mathrm{Concat}(\mathrm{head}_1, \mathrm{head}_2, \ldots, \mathrm{head}_h)W^O $$
其中
$$ \mathrm{head}_i = \mathrm{Attention}(QW_i^Q, KW_i^K, VW_i^V) $$

In [1]:
import math
import torch
import torch.nn as nn

class MultiHeadSelfAtten(nn.Module):
    """Multi-head scaled dot-product self-attention.

    The input is projected to queries, keys and values, split into
    ``head_num`` heads of size ``hidden_dim // head_num``, attended per head,
    concatenated again and passed through a final output projection.

    Args:
        hidden_dim: Size of the last dimension of the input and the output.
        head_num: Number of attention heads; must divide ``hidden_dim``.
        attention_dropout: Dropout probability applied to the attention
            weights after the softmax.

    Raises:
        ValueError: If ``hidden_dim`` is not divisible by ``head_num``.
    """

    def __init__(self, hidden_dim, head_num, attention_dropout=0.1):
        super().__init__()
        # Fail early with a clear message instead of an opaque ``view`` shape
        # error on the first forward pass.
        if hidden_dim % head_num != 0:
            raise ValueError(
                f"hidden_dim ({hidden_dim}) must be divisible by "
                f"head_num ({head_num})"
            )
        self.hidden_dim = hidden_dim
        self.head_dim = hidden_dim // head_num
        self.head_num = head_num

        # Each projection maps hidden_dim -> head_num * head_dim (== hidden_dim).
        self.q = nn.Linear(hidden_dim, hidden_dim)
        self.k = nn.Linear(hidden_dim, hidden_dim)
        self.v = nn.Linear(hidden_dim, hidden_dim)
        self.out_proj = nn.Linear(hidden_dim, hidden_dim)

        self.attention_dropout = nn.Dropout(attention_dropout)

    def forward(self, x, attention_mask=None):
        """Apply self-attention to ``x``.

        Args:
            x: Tensor of shape ``(batch, seq_len, hidden_dim)``.
            attention_mask: Optional tensor broadcastable to
                ``(batch, head_num, seq_len, seq_len)``. Entries equal to 0
                mark key positions that must not be attended to.

        Returns:
            Tensor of shape ``(batch, seq_len, hidden_dim)``.
        """
        batch, seq_len, _ = x.size()

        def split_heads(t):
            # (b, s, h) -> (b, head_num, s, head_dim)
            return t.view(
                batch, seq_len, self.head_num, self.head_dim
            ).transpose(1, 2)

        q_state = split_heads(self.q(x))
        k_state = split_heads(self.k(x))
        v_state = split_heads(self.v(x))

        # (b, head_num, s, s)
        scores = q_state @ k_state.transpose(-1, -2) / math.sqrt(self.head_dim)

        if attention_mask is not None:
            # Use the most negative finite value rather than -inf: a query row
            # whose keys are all masked would otherwise be softmax(-inf, ...)
            # = NaN and poison the output and the gradients. With a finite
            # fill such a row becomes a uniform distribution, while partially
            # masked rows still give masked keys a weight of exactly 0.
            scores = scores.masked_fill(
                attention_mask == 0, torch.finfo(scores.dtype).min
            )

        weights = torch.softmax(scores, dim=-1)
        weights = self.attention_dropout(weights)

        # (b, head_num, s, s) @ (b, head_num, s, head_dim)
        #   -> (b, head_num, s, head_dim)
        context = weights @ v_state

        # (b, head_num, s, head_dim) -> (b, s, head_num, head_dim); the
        # transposed tensor must be made contiguous before ``view`` can merge
        # the head dimensions back into (b, s, hidden_dim).
        context = context.transpose(1, 2).contiguous()
        context = context.view(batch, seq_len, -1)

        return self.out_proj(context)
    
# Smoke test: 3 sequences of 2 tokens each, hidden size 128.
X = torch.rand(3, 2, 128)
# mask (3, 2) -view-> (3, 1, 1, 2) -expand-> (3, 8, 2, 2)
# A 0 entry marks a key position that is masked out.
mask = (
    torch.tensor([[0, 1], [0, 0], [1, 0]])
    .view(3, 1, 1, 2)
    .expand(3, 8, 2, 2)
)

# head number: 8
att_net = MultiHeadSelfAtten(128, 8)
out = att_net(X, mask)
# out