In [3]:

from typing import List, Optional

import torch
import torch.nn.functional as F
from torch import nn


In [4]:


class PrepareForMHA(nn.Module):

    """
    Linear projection followed by a split of the last dimension into attention heads.

    Used to project the query, key and value inputs of multi-head attention.
    """

    def __init__(self,
                 num_heads: int,
                 embed_dim: int,
                 key_dim: int,
                 bias: bool
                 ) -> None:

        """
        Args:
            num_heads: number of heads the projected vector is split into.
            embed_dim: size of the last dimension of the input.
            key_dim: size of each head's slice of the projected vector.
            bias: whether the linear projection has an additive bias.
        """

        super().__init__()

        # A single projection produces all heads at once; forward() splits it.
        self.fc = nn.Linear(embed_dim, num_heads * key_dim, bias = bias)
        self.num_heads = num_heads
        self.key_dim = key_dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:

        """
        Project ``x`` and split its last dimension into heads.

        Input of shape [..., embed_dim] (e.g. [seq_len, batch_size, embed_dim]
        or [batch_size, embed_dim]) becomes [..., num_heads, key_dim]
        (e.g. [seq_len, batch_size, num_heads, key_dim] or
        [batch_size, num_heads, key_dim]).
        """

        # Every leading dimension is preserved; only the last one is split.
        leading_shape = x.shape[:-1]

        return self.fc(x).reshape(*leading_shape, self.num_heads, self.key_dim)

class MHA(nn.Module):

    """
    Multi-head attention over sequence-first inputs of shape
    [seq_len, batch_size, embed_dim].
    """

    def __init__(self,
                 num_heads: int,
                 embed_dim: int,
                 dropout_p: float,
                 bias: bool = True
                 ) -> None:
        """
        Args:
            num_heads: number of attention heads; must evenly divide ``embed_dim``.
            embed_dim: feature size of the query/key/value inputs and of the output.
            dropout_p: dropout probability applied to the attention weights.
            bias: whether the query and key projections have a bias term.

        Raises:
            ValueError: if ``num_heads`` is not positive or does not divide ``embed_dim``.
        """
        super().__init__()

        # The concatenated heads (num_heads * key_dim) must equal embed_dim,
        # otherwise the output projection below would get the wrong input size.
        if num_heads <= 0 or embed_dim % num_heads != 0:
            raise ValueError(
                f"embed_dim ({embed_dim}) must be divisible by a positive "
                f"num_heads ({num_heads})"
            )

        self.num_heads = num_heads
        self.key_dim = embed_dim // num_heads

        self.W_Q = PrepareForMHA(num_heads, embed_dim, self.key_dim, bias = bias)
        self.W_K = PrepareForMHA(num_heads, embed_dim, self.key_dim, bias = bias)
        # The value projection always has a bias, independent of ``bias``.
        self.W_V = PrepareForMHA(num_heads, embed_dim, self.key_dim, bias = True)

        # Output projection over the concatenated heads.
        self.fc = nn.Linear(embed_dim, embed_dim)
        self.dropout = nn.Dropout(dropout_p)
        # 1 / sqrt(key_dim). torch.sqrt does not accept a Python int, so plain
        # arithmetic is used here.
        self.scale = self.key_dim ** -0.5

        # Initialised to None; not updated by forward().
        self.attn = None

    def forward(self,
                query: torch.Tensor,
                key: torch.Tensor,
                value: torch.Tensor,
                attn_mask: Optional[torch.Tensor] = None,
                need_weights: bool = True):
        """
        Compute multi-head attention.

        Args:
            query: tensor of shape [seq_len_q, batch_size, embed_dim].
            key: tensor of shape [seq_len_k, batch_size, embed_dim].
            value: tensor of shape [seq_len_k, batch_size, embed_dim].
            attn_mask: optional tensor of shape
                [seq_len_q or 1, seq_len_k, batch_size or 1]; positions where
                the mask equals 0 receive no attention.
            need_weights: if True, also return the attention weights.

        Returns:
            ``(output, weights)`` when ``need_weights`` is True, else ``output``.
            ``output`` has shape [seq_len_q, batch_size, embed_dim]; ``weights``
            has shape [seq_len_q, seq_len_k, batch_size, num_heads] and is taken
            after dropout.
        """

        seq_len, batch_size, _ = query.shape

        # Project inputs; each has shape [seq_len, batch_size, heads, key_dim].
        Q = self.W_Q(query)
        K = self.W_K(key)
        V = self.W_V(value)

        # Per-head scaled dot products Q·K^T: [seq_len_q, seq_len_k, batch, heads].
        scores = self.scale * torch.einsum('ibhd,jbhd->ijbh', Q, K)

        if attn_mask is not None:
            mask = self.prepare_mask(attn_mask, query.shape, key.shape)
            scores = scores.masked_fill(mask == 0, float('-inf'))

        # Normalise over the key axis (dim 1) explicitly; the implicit-dim
        # form of softmax is deprecated.
        scores = self.dropout(F.softmax(scores, dim=1))

        # Weighted sum of values, then concatenate heads back to embed_dim.
        output = torch.einsum('ijbh,jbhd->ibhd', scores, V).reshape(seq_len, batch_size, -1)
        output = self.fc(output)

        return (output, scores) if need_weights else output

    def prepare_mask(self,
                     mask: torch.Tensor,
                     query_shape: List[int],
                     key_shape: List[int]) -> torch.Tensor:

        """
        Validate ``mask`` and append a trailing heads axis for broadcasting.

        ``mask`` has shape [seq_len_q, seq_len_k, batch_size], where the first
        dimension is the query dimension. The query and batch dimensions may
        be 1, in which case they are broadcast.

        Returns:
            ``mask`` with shape [seq_len_q, seq_len_k, batch_size, 1].

        Raises:
            AssertionError: if ``mask`` does not have a compatible shape.
        """

        # Every condition must hold; the previous version joined them with `or`,
        # which accepted almost any mask.
        assert mask.dim() == 3, f"mask must be 3-D, got shape {tuple(mask.shape)}"
        assert mask.shape[0] == 1 or mask.shape[0] == query_shape[0], \
            f"mask query dim {mask.shape[0]} incompatible with {query_shape[0]}"
        assert mask.shape[1] == key_shape[0], \
            f"mask key dim {mask.shape[1]} must equal {key_shape[0]}"
        assert mask.shape[2] == 1 or mask.shape[2] == query_shape[1], \
            f"mask batch dim {mask.shape[2]} incompatible with {query_shape[1]}"

        return mask.unsqueeze(-1)

