In [14]:
import torch, math
from torch import nn


class MultiheadAttention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    The input is projected to queries, keys and values, split into ``heads``
    independent heads of size ``embed_dim // heads``, attended per head, merged
    back to ``embed_dim`` and passed through a final linear layer.

    Args:
        embed_dim: Size of the last dimension of the input and of the output.
            Must be divisible by ``heads``.
        heads: Number of attention heads.
        dropout: Dropout probability applied to the attention weights.
    """

    def __init__(self, embed_dim: int, heads: int, dropout: float = 0.1):
        super().__init__()
        assert embed_dim % heads == 0, (
            f"embed_dim ({embed_dim}) must be divisible by heads ({heads})"
        )
        self.embed_dim = embed_dim
        self.heads = heads
        self.head_dim = embed_dim // heads
        # Attribute names are kept as-is so existing state_dict keys still load.
        self.q_proj_weight = nn.Linear(embed_dim, embed_dim)
        self.k_proj_weight = nn.Linear(embed_dim, embed_dim)
        self.v_proj_weight = nn.Linear(embed_dim, embed_dim)

        self.softmax = nn.Softmax(dim=-1)

        self.dropout = nn.Dropout(dropout)
        self.output = nn.Linear(embed_dim, embed_dim)

    def _split_heads(self, t: torch.Tensor) -> torch.Tensor:
        """Reshape ``(b, l, embed_dim)`` into ``(b, heads, l, head_dim)``."""
        b, l, _ = t.shape
        return t.reshape(b, l, self.heads, self.head_dim).transpose(1, 2)

    @staticmethod
    def _blocked_positions(mask: torch.Tensor) -> torch.Tensor:
        """Return a boolean tensor (True = masked out) broadcastable to ``(b, heads, l, l)``."""
        if mask.dim() == 2:  # (l, l): shared by every batch element and head
            mask = mask[None, None]
        elif mask.dim() == 3:  # (b, l, l): shared by every head
            mask = mask.unsqueeze(1)
        elif mask.dim() != 4:  # (b, heads or 1, l, l) is used unchanged
            raise ValueError(
                f"mask must have 2, 3 or 4 dimensions, got {mask.dim()}"
            )
        return mask == 0

    def forward(self, x: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
        """Apply self-attention to ``x``.

        Args:
            x: Tensor of shape ``(b, l, embed_dim)``.
            mask: Optional attention mask; entries equal to 0 are masked out.
                Accepted shapes are ``(l, l)``, ``(b, l, l)`` and
                ``(b, heads, l, l)`` (size-1 dimensions broadcast). ``None``
                disables masking.

        Returns:
            Tensor of shape ``(b, l, embed_dim)``. A query position whose keys
            are all masked receives zero attention weights (instead of NaN), so
            its output is the bias of the final linear layer.

        Raises:
            ValueError: If ``mask`` does not have 2, 3 or 4 dimensions.
        """
        b, l, d = x.shape
        q = self._split_heads(self.q_proj_weight(x))  # (b, heads, l, head_dim)
        k = self._split_heads(self.k_proj_weight(x))
        v = self._split_heads(self.v_proj_weight(x))

        # (b, heads, l, l), scaled by sqrt(head_dim).
        attn_score = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(self.head_dim)

        fully_masked = None
        if mask is not None:
            blocked = self._blocked_positions(mask)
            # Rows with every key masked would become all -inf and softmax
            # would return NaN (poisoning gradients too). Leave those rows
            # finite here and zero their weights after the softmax instead.
            fully_masked = blocked.all(dim=-1, keepdim=True)
            attn_score = attn_score.masked_fill(blocked & ~fully_masked, float("-inf"))

        attn_weight = self.softmax(attn_score)
        if fully_masked is not None:
            attn_weight = attn_weight.masked_fill(fully_masked, 0.0)
        attn_weight = self.dropout(attn_weight)

        attn_output = torch.matmul(attn_weight, v)  # (b, heads, l, head_dim)
        # Merge heads back: (b, heads, l, head_dim) -> (b, l, embed_dim).
        attn_output = attn_output.transpose(1, 2).reshape(b, l, d)
        return self.output(attn_output)

In [15]:
# Smoke test: build a small attention layer (embed_dim=8 split across 2 heads)
# and run it once on a random input, without a mask.
mha = MultiheadAttention(embed_dim=8, heads=2)
x = torch.randn(2, 4, 8)  # (batch=2, seq_len=4, embed_dim=8)
# Display only the output shape; it matches the input shape.
mha(x).shape

torch.Size([2, 4, 8])