In [1]:
!pip install labml-nn --quiet

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m435.0/435.0 kB[0m [31m2.5 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m131.0/131.0 kB[0m [31m10.1 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m43.2/43.2 kB[0m [31m3.7 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m266.3/266.3 kB[0m [31m13.6 MB/s[0m eta [36m0:00:00[0m
[?25h  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Installing backend dependencies ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m207.3/207.3 kB[0m [31m15.2 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m21.3/21.3 MB[0m [31m43.1 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━

In [2]:
import math
from typing import Optional, List
import torch
from torch import nn
from labml import tracker

In [3]:
class PrepareForMultiHeadAttention(nn.Module):
    """Project vectors with one linear layer and split the result into heads.

    The input is expected to be ``[seq_len, batch_size, d_model]`` or
    ``[batch_size, d_model]``. Only the last dimension is transformed; it is
    mapped to ``heads * d_k`` features and then viewed as ``(heads, d_k)``,
    giving ``[seq_len, batch_size, heads, d_k]`` or ``[batch_size, heads, d_k]``.
    """

    def __init__(self, d_model: int, heads: int, d_k: int, bias: bool):
        """
        Args:
            d_model: number of input features.
            heads: number of attention heads.
            d_k: number of features per head.
            bias: whether the linear projection has a bias term.
        """
        super().__init__()
        # Remembered so that ``forward`` can split the projected features.
        self.heads = heads
        self.d_k = d_k
        # A single projection produces the features of every head at once.
        self.linear = nn.Linear(d_model, heads * d_k, bias=bias)

    def forward(self, x: torch.Tensor):
        """Return the projection of ``x`` with its last dimension split into ``(heads, d_k)``."""
        # All dimensions in front of the feature dimension are carried over unchanged.
        leading_dims = x.size()[:-1]
        projected = self.linear(x)
        return projected.view(*leading_dims, self.heads, self.d_k)


In [12]:
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    Query, key and value vectors are each projected and split into ``heads``
    heads of ``d_k = d_model // heads`` features. Attention is computed
    independently per head, the heads are concatenated again and passed
    through a final linear layer.

    Tensors are sequence-first: ``[seq_len, batch_size, d_model]``.
    """

    def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1, bias: bool = True):
        """
        Args:
            heads: number of attention heads.
            d_model: number of features of the query/key/value vectors; must be
                divisible by ``heads``.
            dropout_prob: dropout probability applied to the attention weights.
            bias: whether the query and key projections have a bias term
                (the value projection always has one).

        Raises:
            ValueError: if ``d_model`` is not divisible by ``heads``.
        """
        super().__init__()
        # Fail early: otherwise ``heads * d_k != d_model`` and the module only
        # breaks later with a shape error in the output layer.
        if d_model % heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by heads ({heads})"
            )

        # Number of features per head
        self.d_k = d_model // heads
        # Number of heads
        self.heads = heads

        # These transform the query, key and value vectors for multi-headed attention.
        self.query = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=bias)
        self.key = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=bias)
        self.value = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=True)

        # Softmax for attention along the key (time) dimension of the scores,
        # which have shape [seq_len_q, seq_len_k, batch_size, heads].
        self.softmax = nn.Softmax(dim=1)

        # Output layer
        self.output = nn.Linear(d_model, d_model)
        # Dropout
        self.dropout = nn.Dropout(dropout_prob)
        # Scaling factor before the softmax
        self.scale = 1 / math.sqrt(self.d_k)

        # We store attentions so that they can be used for logging, or other
        # computations if needed.
        self.attn = None

    def get_scores(self, query: torch.Tensor, key: torch.Tensor):
        """Calculate the scores between queries and keys.

        Args:
            query: tensor of shape ``[seq_len_q, batch_size, heads, d_k]``.
            key: tensor of shape ``[seq_len_k, batch_size, heads, d_k]``.

        Returns:
            Tensor of shape ``[seq_len_q, seq_len_k, batch_size, heads]`` holding
            the dot product of every query with every key, per batch and head.
        """
        # Batched matrix multiplication, contracting over the feature dimension d.
        return torch.einsum('ibhd,jbhd->ijbh', query, key)

    def prepare_mask(self, mask: torch.Tensor, query_shape: List[int], key_shape: List[int]):
        """Validate the mask and add a trailing dimension for the heads.

        Args:
            mask: tensor of shape ``[seq_len_q, seq_len_k, batch_size]``; the
                first and last dimensions may be 1 to broadcast over queries
                or over the batch.
            query_shape: shape of the query tensor, ``[seq_len_q, batch_size, ...]``.
            key_shape: shape of the key tensor, ``[seq_len_k, ...]``.

        Returns:
            The mask with shape ``[seq_len_q, seq_len_k, batch_size, 1]``, so
            the same mask is broadcast across all heads.

        Raises:
            AssertionError: if the mask shape does not match the query/key shapes.
        """
        assert mask.shape[0] == 1 or mask.shape[0] == query_shape[0]
        assert mask.shape[1] == key_shape[0]
        assert mask.shape[2] == 1 or mask.shape[2] == query_shape[1]

        # Same mask applied to all heads.
        return mask.unsqueeze(-1)

    def forward(self, *, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor,
                mask: Optional[torch.Tensor] = None):
        """Compute multi-head attention.

        All arguments are keyword-only.

        Args:
            query: tensor of shape ``[seq_len, batch_size, d_model]``.
            key: tensor of shape ``[seq_len_k, batch_size, d_model]``.
            value: tensor of shape ``[seq_len_k, batch_size, d_model]``.
            mask: optional tensor of shape ``[seq_len_q, seq_len_k, batch_size]``
                (see ``prepare_mask``). Positions where the mask equals 0 are
                excluded from attention. A query whose keys are all masked
                receives NaN weights, since softmax over only ``-inf`` is undefined.

        Returns:
            Tensor of shape ``[seq_len, batch_size, d_model]``.

        Side effects:
            Stores the (post-dropout, detached) attention weights in
            ``self.attn`` and reports the pre-dropout weights to the labml tracker.
        """
        seq_len, batch_size, _ = query.shape

        if mask is not None:
            mask = self.prepare_mask(mask, query.shape, key.shape)

        # Project and split into heads: [seq_len, batch_size, heads, d_k]
        query = self.query(query)
        key = self.key(key)
        value = self.value(value)

        # Scaled dot-product scores: [seq_len_q, seq_len_k, batch_size, heads]
        scores = self.get_scores(query, key) * self.scale

        # Blocked positions get -inf so that softmax assigns them zero weight.
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))

        # Normalise over the key dimension.
        attn = self.softmax(scores)
        tracker.debug('attn', attn)
        attn = self.dropout(attn)

        # Weighted sum of the values: [seq_len_q, batch_size, heads, d_k]
        x = torch.einsum("ijbh,jbhd->ibhd", attn, value)

        # Keep the attention weights for logging / later inspection.
        self.attn = attn.detach()

        # Concatenate the heads and apply the output layer.
        x = x.reshape(seq_len, batch_size, -1)
        return self.output(x)