In [2]:
import torch
import torch.nn as nn
import math

In [31]:
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    Args:
        hidden_dim: Size of the last dimension of both the input and the output.
        head_num: Number of attention heads; must be positive and divide ``hidden_dim``.
        att_dropout: Dropout probability applied to the attention weights.

    Raises:
        ValueError: If ``head_num`` is not positive or does not divide ``hidden_dim``.
    """

    def __init__(self, hidden_dim, head_num, att_dropout=0.1):
        super().__init__()
        # Fail early with a clear message instead of an opaque `view` error in forward().
        if head_num <= 0 or hidden_dim % head_num != 0:
            raise ValueError(
                f"hidden_dim ({hidden_dim}) must be divisible by a positive head_num ({head_num})"
            )
        self.hidden_dim = hidden_dim
        self.head_num = head_num
        self.head_dim = hidden_dim // head_num

        self.query = nn.Linear(hidden_dim, hidden_dim)
        self.key = nn.Linear(hidden_dim, hidden_dim)
        self.value = nn.Linear(hidden_dim, hidden_dim)

        self.att_dropout = nn.Dropout(att_dropout)

        # Output projection that mixes the concatenated per-head results.
        self.proj = nn.Linear(hidden_dim, hidden_dim)

    def forward(self, X, att_mask=None):
        """Apply self-attention to ``X``.

        Args:
            X: Tensor of shape ``(batch, seq_len, hidden_dim)``.
            att_mask: Optional tensor broadcastable to
                ``(batch, head_num, seq_len, seq_len)``. Key positions where it
                equals 0 are excluded from attention. A query row whose keys are
                all masked attends uniformly instead of producing NaN.

        Returns:
            Tensor of shape ``(batch, seq_len, hidden_dim)``.
        """
        batch, seq_len, _ = X.size()

        Q = self.query(X)
        K = self.key(X)
        V = self.value(X)

        # (batch, seq_len, hidden_dim) -> (batch, head_num, seq_len, head_dim)
        q_state = Q.view(batch, seq_len, self.head_num, self.head_dim).transpose(1, 2)
        k_state = K.view(batch, seq_len, self.head_num, self.head_dim).transpose(1, 2)
        v_state = V.view(batch, seq_len, self.head_num, self.head_dim).transpose(1, 2)

        # shape: [batch, head_num, seq_len, seq_len]
        att_weight = q_state @ k_state.transpose(-1, -2) / math.sqrt(self.head_dim)

        if att_mask is not None:
            # Use the most negative finite value rather than -inf: with -inf, a row
            # whose keys are all masked becomes exp(-inf)/sum(0) = NaN after softmax.
            # For partially masked rows the masked entries still softmax to exactly 0.
            att_weight = att_weight.masked_fill(
                att_mask == 0,
                torch.finfo(att_weight.dtype).min,
            )

        att_weight = torch.softmax(att_weight, dim=-1)
        att_weight = self.att_dropout(att_weight)

        output_mid = att_weight @ v_state  # [batch, head_num, seq_len, head_dim]
        # Bring heads next to head_dim so they can be concatenated back into hidden_dim.
        output_mid = output_mid.transpose(1, 2).contiguous()

        output = output_mid.view(batch, seq_len, self.hidden_dim)

        return self.proj(output)


# Demo: 3 sequences of length 2. A mask value of 0 marks a key position to ignore.
# The second sample masks every key position.
torch.manual_seed(0)  # reproducible inputs, weight init and dropout

BATCH, SEQ_LEN, HIDDEN_DIM, HEAD_NUM = 3, 2, 128, 8

attention_mask = (
    torch.tensor(
        [
            [0, 1],
            [0, 0],
            [1, 0],
        ]
    )
    .unsqueeze(1)
    .unsqueeze(2)
    # Shape (BATCH, 1, 1, SEQ_LEN): masked_fill broadcasts it over heads and
    # query positions, so no hard-coded expand to (BATCH, HEAD_NUM, ...) is needed.
)

x = torch.rand(BATCH, SEQ_LEN, HIDDEN_DIM)
net = MultiHeadAttention(HIDDEN_DIM, HEAD_NUM)
net(x, attention_mask).shape


torch.Size([3, 2, 128])