In [12]:
from typing import Optional

import torch
from torch import nn
import torch.nn.functional as F

In [13]:
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    The input is projected to queries, keys and values, split into
    ``num_heads`` heads of size ``hidden_dim // num_heads``, attended
    independently per head, merged back and passed through an output
    projection.

    Args:
        hidden_dim: Size of the last dimension of the input and the output.
        num_heads: Number of attention heads; must divide ``hidden_dim``.
        dropout: Keyword-only dropout probability applied to the attention
            probabilities. Defaults to 0.1.
        *args, **kwargs: Forwarded unchanged to ``nn.Module.__init__``.

    Raises:
        ValueError: If ``num_heads`` is not positive or does not divide
            ``hidden_dim``.
    """

    def __init__(self,
                 hidden_dim: int,
                 num_heads: int, *args,
                 dropout: float = 0.1, **kwargs):
        # Fail early with a clear message instead of an obscure `view` error
        # on the first forward pass.
        if num_heads <= 0 or hidden_dim % num_heads != 0:
            raise ValueError(
                f"hidden_dim ({hidden_dim}) must be divisible by a positive "
                f"num_heads ({num_heads})"
            )
        super().__init__(*args, **kwargs)
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        self.head_dim = hidden_dim // num_heads
        self.q_proj = nn.Linear(hidden_dim, hidden_dim)
        self.k_proj = nn.Linear(hidden_dim, hidden_dim)
        self.v_proj = nn.Linear(hidden_dim, hidden_dim)
        self.o_proj = nn.Linear(hidden_dim, hidden_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self,
                X: torch.Tensor,
                casual_mask: Optional[torch.Tensor] = None,
                pad_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Apply self-attention to ``X``.

        Args:
            X: Input of shape ``[bsz, seq, hidden_dim]``.
            casual_mask: Optional causal mask broadcastable to
                ``[bsz, num_heads, seq, seq]`` (e.g. ``[seq, seq]``).
                Non-zero / ``True`` entries mark key positions that the
                corresponding query must NOT attend to.
            pad_mask: Optional padding mask of shape ``[bsz, seq]``.
                Non-zero / ``True`` entries mark key positions to ignore.

        Returns:
            Tensor of shape ``[bsz, seq, hidden_dim]``.
        """
        bsz, seq, _ = X.shape

        def split_heads(t: torch.Tensor) -> torch.Tensor:
            # [bsz, seq, hidden_dim] -> [bsz, num_heads, seq, head_dim]
            return t.view(bsz, seq, self.num_heads, self.head_dim).transpose(1, 2)

        q = split_heads(self.q_proj(X))
        k = split_heads(self.k_proj(X))
        v = split_heads(self.v_proj(X))

        # [bsz, num_heads, seq, seq]
        attn_score = q @ k.transpose(-2, -1) / self.head_dim ** 0.5

        # Merge both masks into one boolean "blocked" mask so that a position
        # hidden by both is treated exactly like one hidden by either.
        blocked = None
        if casual_mask is not None:
            blocked = casual_mask.bool()
        if pad_mask is not None:
            # [bsz, seq] -> [bsz, 1, 1, seq] so it broadcasts over heads and queries
            pad_blocked = pad_mask.bool().unsqueeze(1).unsqueeze(1)
            blocked = pad_blocked if blocked is None else blocked | pad_blocked
        if blocked is not None:
            # Use the dtype's most negative finite value rather than a fixed
            # -1e9, which is not representable in float16.
            attn_score = attn_score.masked_fill(
                blocked, torch.finfo(attn_score.dtype).min
            )

        attn_probs = F.softmax(attn_score, dim=-1)
        attn_probs = self.dropout(attn_probs)

        # [bsz, nh, seq, seq] @ [bsz, nh, seq, head_dim] = [bsz, nh, seq, head_dim]
        output_mid = attn_probs @ v
        # Merge heads back: [bsz, nh, seq, head_dim] -> [bsz, seq, hidden_dim]
        merged = output_mid.transpose(1, 2).reshape(bsz, seq, self.hidden_dim)
        return self.o_proj(merged)

In [16]:
def mha_test(bsz: int = 128,
             seq_len: int = 512,
             hidden_dim: int = 1024,
             num_heads: int = 8):
    """Smoke-test ``MultiHeadAttention`` on random input with a causal mask.

    The defaults reproduce the original fixed configuration. Note that with
    those defaults the attention-score tensor alone has
    ``bsz * num_heads * seq_len * seq_len`` = 268,435,456 float32 elements
    (about 1 GiB), so pass smaller sizes for a quick check.

    Args:
        bsz: Batch size of the random input.
        seq_len: Sequence length of the random input and side of the mask.
        hidden_dim: Model width passed to ``MultiHeadAttention``.
        num_heads: Number of heads passed to ``MultiHeadAttention``.

    Returns:
        The attention output tensor of shape ``[bsz, seq_len, hidden_dim]``.

    Raises:
        AssertionError: If the output shape differs from the input shape.
    """
    X = torch.randn(bsz, seq_len, hidden_dim)
    mha = MultiHeadAttention(hidden_dim, num_heads)
    # Ones strictly above the diagonal mark the future positions to hide.
    casual_mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1)
    output = mha(X, casual_mask)
    expected_shape = (bsz, seq_len, hidden_dim)
    assert tuple(output.shape) == expected_shape, (
        f"unexpected output shape {tuple(output.shape)}, expected {expected_shape}"
    )
    return output

In [17]:
# Run the smoke test only when this file is executed as the main module
# (not when it is imported); the returned tensor is discarded.
if __name__ == "__main__":
    mha_test()