### Multi-Head Attention

In [1]:
import math
import torch
import torch.nn as nn

In [None]:
class MultiHeadAttention(nn.Module):
    def __init__(self, hidden_dim, head_num, dropout=0.1):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.head_num = head_num
        self.head_dim = hidden_dim // head_num

        # shape: (hidden_dim, head_num * head_dim)
        self.q_proj = nn.Linear(hidden_dim, hidden_dim)
        self.k_proj = nn.Linear(hidden_dim, hidden_dim)
        self.v_proj = nn.Linear(hidden_dim, hidden_dim)
        self.out_proj = nn.Linear(hidden_dim, hidden_dim)

        self.attn_dropout = nn.Dropout(dropout)

    def forward(self, X, attention_mask = None):
        # X shape: (batch_size, seq_len, hidden_dim)
        
        batch_size, seq_len, _ = X.size()

        Q = self.q_proj(X)  # (batch_size, seq_len, hidden_dim)
        K = self.k_proj(X)  # (batch_size, seq_len, hidden_dim)
        V = self.v_proj(X)  # (batch_size, seq_len, hidden_dim)
        
        # 我们希望得到 shape (batch_size, head_num, seq_len, head_dim)
        q_state = Q.view(batch_size, seq_len, self.head_num, self.head_dim).transpose(1, 2)
        k_state = K.view(batch_size, seq_len, self.head_num, self.head_dim).transpose(1, 2)
        v_state = V.view(batch_size, seq_len, self.head_num, self.head_dim).transpose(1, 2)

        attention_weight = torch.matmul(
            q_state, k_state.transpose(-2, -1)  # (batch_size, head_num, head_dim, seq_len)
        ) / math.sqrt(self.head_dim) # (batch_size, head_num, seq_len, seq_len)

        if attention_mask is not None:
            attention_weight = attention_weight.masked_fill(
                attention_mask == 0, float('-inf')
            )

        attention_weight = torch.softmax(attention_weight, dim=-1)
        attention_weight = self.attn_dropout(attention_weight)
        print(attention_weight)

        output_mid = torch.matmul(
            attention_weight, v_state  # (batch_size, head_num, seq_len, head_dim)
        )

        # 我们要concat成 shape: (batch_size, seq_len, head_num * head_dim)
        
        output_mid = output_mid.transpose(1, 2).contiguous()
        output_mid = output_mid.view(batch_size, seq_len, -1)

        output = self.out_proj(output_mid)
        return output

x = torch.randn(3, 2, 128)

attention_mask = (
    torch.tensor(
        [
            [0, 1],
            [0, 0],
            [1, 0],
        ]
    )   # shape: (3, 2)
    .unsqueeze(1)   # shape: (3, 1, 2)
    .unsqueeze(2)   # shape: (3, 1, 1, 2)
    .expand(3, 8, 2, 2) # shape: (3, 8, 2, 2) 这里是广播机制, (batch_size, head_num, seq_len, seq_len)
)

net = MultiHeadAttention(128, 8) # head_dim = 16
net(x, attention_mask)


tensor([[[[0.0000, 1.1111],
          [0.0000, 1.1111]],

         [[0.0000, 1.1111],
          [0.0000, 1.1111]],

         [[0.0000, 1.1111],
          [0.0000, 1.1111]],

         [[0.0000, 1.1111],
          [0.0000, 1.1111]],

         [[0.0000, 1.1111],
          [0.0000, 1.1111]],

         [[0.0000, 0.0000],
          [0.0000, 1.1111]],

         [[0.0000, 1.1111],
          [0.0000, 1.1111]],

         [[0.0000, 1.1111],
          [0.0000, 1.1111]]],


        [[[   nan,    nan],
          [   nan,    nan]],

         [[   nan,    nan],
          [   nan,    nan]],

         [[   nan,    nan],
          [   nan,    nan]],

         [[   nan,    nan],
          [   nan,    nan]],

         [[   nan,    nan],
          [   nan,    nan]],

         [[   nan,    nan],
          [   nan,    nan]],

         [[   nan,    nan],
          [   nan,    nan]],

         [[   nan,    nan],
          [   nan,    nan]]],


        [[[1.1111, 0.0000],
          [0.0000, 0.0000]],

         [

tensor([[[-0.4018, -0.2847, -0.2598, -0.5817,  0.1536,  0.0125,  0.1009,
           0.6344, -0.2779, -0.2702, -0.2061, -0.0138, -0.0454, -0.2967,
          -0.1495, -0.2756,  0.1637,  0.2307, -0.0674, -0.3502,  0.4150,
          -0.0220,  0.1567,  0.4199,  0.0176,  0.1254,  0.0143, -0.4063,
          -0.1950, -0.2313,  0.2778, -0.1599,  0.1987, -0.3715,  0.0199,
           0.1280, -0.1462,  0.1232, -0.6659,  0.1592, -0.0554, -0.1185,
          -0.5166, -0.0453, -0.0612,  0.2194,  0.6215, -0.3896, -0.0832,
          -0.1930, -0.0470, -0.0054, -0.0187,  0.2994,  0.1914, -0.4119,
          -0.1261, -0.0128,  0.0223, -0.1768,  0.1012, -0.0023, -0.1958,
          -0.6680,  0.1518, -0.1023,  0.1019, -0.2101,  0.2763, -0.2076,
          -0.2705, -0.0869,  0.2849,  0.0713,  0.1786, -0.0509, -0.1446,
          -0.1729, -0.0984, -0.0397, -0.3444,  0.0494,  0.0979,  0.0952,
           0.0536,  0.1949, -0.1778, -0.1036,  0.0476, -0.0790, -0.0107,
           0.4252,  0.0278,  0.1853,  0.0620,  0.20