In [1]:
import math
from typing import Optional, List

import torch
from torch import nn

from labml import tracker

# Prepare for multi-head attention
This module applies a linear transformation and splits the vector into a given number of heads for multi-head attention. It is used to transform the **key**, **query**, and **value** vectors.

In [3]:
class PrepareForMultiHeadAttention(nn.Module):
    """Project vectors with one linear layer and split the result into attention heads.

    Used to transform the query, key and value vectors before multi-head attention.

    :param d_model: number of features in each input vector (last dimension of the input)
    :param heads: number of attention heads
    :param d_k: number of features per head
    :param bias: whether the linear projection has a bias term
    """

    def __init__(self, d_model: int, heads: int, d_k: int, bias: bool):
        super().__init__()
        # A single projection produces all heads at once: d_model -> heads * d_k.
        self.linear = nn.Linear(d_model, heads * d_k, bias=bias)
        self.heads = heads
        self.d_k = d_k

    def forward(self, x: torch.Tensor):
        """Project ``x`` and split its last dimension into ``(heads, d_k)``.

        :param x: tensor whose last dimension is ``d_model``, e.g.
            ``[seq_len, batch_size, d_model]`` or ``[batch_size, d_model]``
        :return: tensor of shape ``[*x.shape[:-1], heads, d_k]``, e.g.
            ``[seq_len, batch_size, heads, d_k]`` or ``[batch_size, heads, d_k]``
        """
        # The projection only touches the last dimension, so all leading
        # dimensions are carried through unchanged.
        leading_dims = x.shape[:-1]
        projected = self.linear(x)
        return projected.view(*leading_dims, self.heads, self.d_k)

# Multi-Head Attention Module
This computes scaled multi-headed attention for the given query, key, and value vectors.
$$\mathrm{Attention}(Q,K,V)=\mathrm{softmax}\left(\frac{QK^\top}{\sqrt{d_k}}\right)V$$
In simple terms, it finds the keys that match the query and gets the values of those keys.

It uses the dot-product of the query and key as the indicator of how well they match. Before taking the $\mathrm{softmax}$, the dot-products are scaled by $\frac{1}{\sqrt{d_k}}$. This is done to avoid large dot-product values causing softmax to give very small gradients when $d_k$ is large.

Softmax is calculated along the axis of the key sequence (or time).


#### Equation-1
$$\text{Calculate }QK^\top\mathrm{~or~}S_{ijbh}=\sum_dQ_{ibhd}K_{jbhd}$$

In [4]:
class MultiHeadAttention(nn.Module):
    """Scaled multi-head attention over sequence-first tensors.

    Inputs are laid out as ``[seq_len, batch_size, d_model]``. Query, key and value
    are each projected and split into ``heads`` heads of ``d_k = d_model // heads``
    features. Attention is computed per head, and the heads are concatenated and
    passed through a final linear layer back to ``d_model`` features.

    :param heads: number of attention heads
    :param d_model: number of features in the query, key and value vectors
    :param dropout_prob: dropout probability applied to the attention weights
        before they are multiplied with the values
    :param bias: whether the query and key projections have a bias term
        (the value projection always has one)
    """

    def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1,
                 bias: bool = True):
        super().__init__()
        # Number of features per head. If d_model is not divisible by heads, the
        # remainder is dropped and the concatenated heads have heads * d_k features.
        self.d_k = d_model // heads
        self.heads = heads
        # These transform the query, key and value vectors for multi-headed attention.
        self.query = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=bias)
        self.key = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=bias)
        # NOTE: the value projection always uses a bias, independent of `bias`.
        self.value = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=True)
        # Scores have shape [seq_len_q, seq_len_k, batch_size, heads]; dim=1 is the
        # key sequence dimension.
        self.softmax = nn.Softmax(dim=1)
        # Output layer: maps the concatenated heads (heads * d_k features) back to
        # d_model. This equals nn.Linear(d_model, d_model) whenever d_model is
        # divisible by heads.
        self.output = nn.Linear(heads * self.d_k, d_model)
        # Dropout on the attention weights.
        self.dropout = nn.Dropout(dropout_prob)
        # Scaling factor 1 / sqrt(d_k) applied to the scores before the softmax.
        self.scale = 1 / math.sqrt(self.d_k)
        # Most recent (pre-dropout, detached) attention weights, kept for logging
        # or other computations.
        self.attn = None

    def get_scores(self, query: torch.Tensor, key: torch.Tensor):
        """Calculate raw attention scores between queries and keys (Equation-1).

        This method can be overridden for other variations like relative attention.

        :param query: projected queries of shape [seq_len_q, batch_size, heads, d_k]
        :param key: projected keys of shape [seq_len_k, batch_size, heads, d_k]
        :return: ``S[i, j, b, h] = sum_d query[i, b, h, d] * key[j, b, h, d]``,
            of shape [seq_len_q, seq_len_k, batch_size, heads]
        """
        return torch.einsum('ibhd,jbhd->ijbh', query, key)

    def prepare_mask(self, mask: torch.Tensor, query_shape: List[int], key_shape: List[int]):
        """Check ``mask`` against the query/key shapes and add a trailing heads dimension.

        ``mask`` has shape [seq_len_q, seq_len_k, batch_size], where the first dimension
        is the query dimension. The query and batch dimensions may be 1, in which
        case they are broadcast.

        :param mask: attention mask; ``forward`` masks out entries equal to 0
        :param query_shape: shape of the query input, [seq_len_q, batch_size, d_model]
        :param key_shape: shape of the key input, [seq_len_k, batch_size, d_model]
        :return: mask of shape [seq_len_q, seq_len_k, batch_size, 1]; the trailing
            dimension applies the same mask to every head
        :raises AssertionError: if ``mask`` does not match the query/key shapes
        """
        assert mask.shape[0] == 1 or mask.shape[0] == query_shape[0]
        assert mask.shape[1] == key_shape[0]
        assert mask.shape[2] == 1 or mask.shape[2] == query_shape[1]
        # Same mask applied to all heads.
        return mask.unsqueeze(-1)

    def forward(self, *, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor,
                mask: Optional[torch.Tensor] = None):
        """Compute multi-head attention.

        :param query: query vectors of shape [seq_len_q, batch_size, d_model]
        :param key: key vectors of shape [seq_len_k, batch_size, d_model]
        :param value: value vectors of shape [seq_len_k, batch_size, d_model]
        :param mask: optional mask of shape [seq_len_q, seq_len_k, batch_size];
            ``mask[i, j, b] == 0`` prevents query ``i`` of batch ``b`` from
            attending to key-value position ``j``. The query and batch dimensions
            may be 1 and are then broadcast.
        :return: tensor of shape [seq_len_q, batch_size, d_model]

        Note: if every key is masked for some query, that query's softmax row
        (and therefore its output) is NaN.
        """
        seq_len, batch_size, _ = query.shape
        if mask is not None:
            mask = self.prepare_mask(mask, query.shape, key.shape)
        # Project and split into heads: [seq_len, batch_size, heads, d_k].
        query = self.query(query)
        key = self.key(key)
        value = self.value(value)
        # Scaled scores QK^T / sqrt(d_k): [seq_len_q, seq_len_k, batch_size, heads].
        scores = self.get_scores(query, key)
        scores *= self.scale
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float("-inf"))
        # Softmax attention along the key sequence dimension.
        attn = self.softmax(scores)
        # Save attentions if debugging.
        tracker.debug('attn', attn)
        # Keep the (pre-dropout) attention probabilities for other calculations.
        self.attn = attn.detach()
        # Dropout on the attention weights (identity in eval mode).
        attn_dropped = self.dropout(attn)
        # Weighted sum of values: [seq_len_q, batch_size, heads, d_k].
        x = torch.einsum('ijbh,jbhd->ibhd', attn_dropped, value)
        # Concatenate the heads: [seq_len_q, batch_size, heads * d_k].
        x = x.reshape(seq_len, batch_size, -1)
        return self.output(x)
