In [17]:
import torch
import torch.nn as nn
import math
from typing import Optional, List

In [18]:
class TokenEmbedding(nn.Module):
    
    def __init__(self, d_model: int, vocab_size: int):
        super().__init__()
        self.emb = nn.Embedding(vocab_size, d_model)
        
    def forward(self, x: torch.Tensor):
        return self.emb(x)

In [19]:
def get_sinusoidal_positional_encoding(d_model: int, max_len: int = 4096):
    # Empty encodings vectors
    encodings = torch.zeros(max_len, d_model)
    # Position indexes
    position = torch.arange(0, max_len, dtype=torch.float32).unsqueeze(1)
    # $2 * i$
    two_i = torch.arange(0, d_model, 2, dtype=torch.float32)
    # $10000^{\frac{2i}{d_{model}}}$
    div_term = torch.exp(two_i * -(math.log(10000.0) / d_model))
    # $PE_{p,2i} = sin\Bigg(\frac{p}{10000^{\frac{2i}{d_{model}}}}\Bigg)$
    encodings[:, 0::2] = torch.sin(position * div_term)
    # $PE_{p,2i + 1} = cos\Bigg(\frac{p}{10000^{\frac{2i}{d_{model}}}}\Bigg)$
    encodings[:, 1::2] = torch.cos(position * div_term)
    # Add batch dimension
    encodings = encodings.unsqueeze(1).requires_grad_(False)
    return encodings


class SinusoidalPositionalEncoding(nn.Module):

    def __init__(self, d_model: int, dropout_prob: float, max_len: int = 4096):
        super().__init__()
        self.register_buffer('positional_encoding', get_sinusoidal_positional_encoding(d_model, max_len), False)
        self.dropout = nn.Dropout(dropout_prob)

    def forward(self, x: torch.Tensor):
        pe = self.positional_encoding[:x.shape[0]].detach().requires_grad_(False)
        return self.dropout(x + pe)


class LearnedPositionalEncoding(nn.Module):

    def __init__(self, d_model: int, dropout_prob: float, max_len: int = 4096):
        super().__init__()
        self.positional_encoding = nn.Parameter(torch.zeros(max_len, 1, d_model), requires_grad=True)
        self.dropout = nn.Dropout(dropout_prob)

    def forward(self, x: torch.Tensor):
        pe = self.positional_encoding[:x.shape[0]]
        return self.dropout(x + pe)

In [20]:
class PrepareForMultiHeadAttention(nn.Module):
    """
    approx a linear transformation
    """

    def __init__(self, d_model: int, heads: int, d_k: int, bias: bool):
        super().__init__()
        self.linear = nn.Linear(d_model, heads * d_k, bias=bias)
        self.heads = heads
        self.d_k = d_k

    def forward(self, x: torch.Tensor):
        head_shape = x.shape[:-1]
        x = self.linear(x)
        x = x.view(*head_shape, self.heads, self.d_k) # batch, length, heads, d_k
        return x
    
class MultiHeadAttention(nn.Module):

    def __init__(self, d_model: int, heads: int, dropout_prob: float = 0.1, bias: bool = True,
                 use_drop_key: bool = False, mask_ratio: float = 0.3):
        super().__init__()
        self.d_k = d_model // heads
        self.heads = heads
        self.query = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias)
        self.key = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias)
        self.value = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias)
        self.scale = 1 / math.sqrt(self.d_k)
        self.softmax = nn.Softmax(dim=-1)
        self.output = nn.Linear(d_model, d_model)
        self.attn = None
        self.use_drop_key = use_drop_key
        self.mask_ratio = mask_ratio

    def get_score(self, query: torch.Tensor, key: torch.Tensor):
        score = torch.matmul(query, key.transpose(-2, -1))
        return score

    def prepare_mask(self, mask: torch.Tensor, query_shape: List[int], key_shape: List[int]):
        assert len(mask.shape) == 2 or len(mask.shape) == 3
        if mask.shape == 2:
            assert mask.shape[0] == query_shape[1]
            assert mask.shape[1] == key_shape[1]
            mask = mask.unsqueeze(0)
        else:
            assert mask.shape[0] == 1 or mask.shape[0] == query_shape[0]
            assert mask.shape[1] == query_shape[1]
            assert mask.shape[2] == key_shape[1]
        return mask

    def forward(self, *,
                query: torch.Tensor,
                key: torch.Tensor,
                value: torch.Tensor,
                mask: Optional[torch.Tensor] = None):
        """
        Args:
            query: shape (batch_size, seq_len, d_model)
            key: shape (batch_size, seq_len, d_model)
            value: shape (batch_size, seq_len, d_model)
            mask: shape (batch_size, seq_len, seq_len). Since we assume all data use a same mask, so
                  here the shape also equals to (1, seq_len, seq_len)

        Return:
            out: shape (batch_size, seq_len, d_model). The output of a multihead attention layer
        """
        seq_len, batch_size, _ = query.shape
        if mask is not None:
            mask = self.prepare_mask(mask, query.shape, key.shape)

        query = self.query(query)
        key = self.key(key)
        value = self.value(value)

        scores = self.get_score(query, key)
        scores *= self.scale

        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))

        attn = self.softmax(scores)
        x = torch.matmul(attn, value)
        self.attn = attn.detach()
        x = x.transpose(1, 2).contiguous().view(batch_size, seq_len, -1)
        return self.output(x)

In [21]:
inputs = torch.tensor([
    [1, 2, 3, 4],
    [5, 6, 7, 8]
])

In [22]:
token_embed = TokenEmbedding(d_model=512, vocab_size=32000)
x = token_embed(inputs)

In [23]:
pe = SinusoidalPositionalEncoding(d_model=512, dropout_prob=0.1, max_len=4096)
x = pe(x)

torch.Size([2, 4, 512])

In [26]:
mha = MultiHeadAttention(d_model=512, heads=8)
mha(query=x, key=x, value=x)

torch.Size([4, 2, 512])