In [None]:
#| default_exp transformers.relative_mha

In [None]:
#| hide

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [None]:
#| hide
#| export

from fastcore.basics import patch
from nbdev.showdoc import *

# Relative Multi-Headed Attention
> Annotated [PyTorch](https://pytorch.org) implementation of Relative Multi-Headed Attention from the paper [Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context](https://arxiv.org/abs/1901.02860).

In [None]:
#| export
#| hide

import torch
from torch import nn

from fastcore.test import *

from labml.logger import inspect
from mlrpg.transformers.mha import MultiHeadAttention

In [None]:
#| export

def shift_right(x: torch.Tensor):
    "Shift row $i$ of `x` (rows along dim 0, columns along dim 1) right by $i$ columns; any trailing dims move along."

    # For j >= i the result holds x[i, j - i]. The i leading slots of row i are NOT
    # masked: they hold leftovers wrapped in from earlier rows, then a zero.
    # Ideally that lower triangle would be masked out, but it's ok for our purpose.
    n_rows, n_cols, *trailing = x.shape

    # Give every row one extra (zero) slot at its end.
    padded = torch.cat([x, x.new_zeros(n_rows, 1, *trailing)], dim=1)

    # Re-read the same contiguous memory as (n_cols + 1) rows of length n_rows.
    # This offsets each original row by one more slot than the previous one.
    # Then drop the surplus last row and view the result back in the input's shape.
    return padded.view(n_cols + 1, n_rows, *trailing)[:-1].view_as(x)

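For example, if the input is `[[1, 2, 3], [4, 5, 6], [7, 8, 9]]`, the shifted result is `[[1, 2, 3], [0, 4, 5], [6, 0, 7]]`. Row $i$ moves right by $i$ columns. The $i$ positions vacated at the start of row $i$ hold values wrapped in from the previous row followed by a zero (e.g. the `6, 0` in the last row).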

In [None]:

a = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
b = torch.tensor([[1, 2, 3], [0, 4, 5], [6, 0, 7]])
test_eq(shift_right(a), b)

# `get_scores` shifts non-square tensors (query_len x (key_len + query_len)) with
# trailing dims, so check a rectangular case and its 4-D form as well.
a_rect = torch.arange(1, 9).view(2, 4)
b_rect = torch.tensor([[1, 2, 3, 4], [0, 5, 6, 7]])
test_eq(shift_right(a_rect), b_rect)
test_eq(shift_right(a_rect[:, :, None, None])[:, :, 0, 0], b_rect)

We can visualize tensors with `inspect()`:

In [None]:
# A 5x5 grid whose every row is 1..5, with two singleton trailing dims (4-D like the score tensors).
c = torch.arange(1, 6).repeat(5, 1)[:, :, None, None]
inspect(c[:, :, 0, 0])

In [None]:
shifted = shift_right(c)  # row i moves right by i columns
inspect(shifted[..., 0, 0])

In [None]:
#| exports

class RelativeMultiHeadAttention(MultiHeadAttention):
    """Relative Multi-Head Attention Module.

    `heads`, `d_model` and `dropout_prob` are forwarded to `MultiHeadAttention` (with `bias=False`).
    Query and key lengths may not exceed `self.P`, the number of relative positions per direction.
    """

    def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1):

        super().__init__(heads, d_model, dropout_prob, bias=False) # <1>

        self.P = 2 ** 12 # <2>
        self.key_pos_embeddings = nn.Parameter(torch.zeros((self.P * 2, heads, self.d_k)), requires_grad=True) # <3>
        self.key_pos_bias = nn.Parameter(torch.zeros((self.P * 2, heads)), requires_grad=True) # <4>
        self.query_pos_bias = nn.Parameter(torch.zeros((heads, self.d_k)), requires_grad=True) # <5>

1. The linear transformations do not need a bias since we explicitly include it when calculating scores. However having a bias for `value` might make sense.
2. Number of relative positions $P$ in each direction; query and key lengths may not exceed it.
3. Relative positional embeddings for keys relative to the query. We need $2P$ embeddings because a key can be before or after the query.
4. Relative positional embedding bias for key relative to the query.
5. Positional embedding for the query ($v$); it is independent of the position of the query.

In [None]:
#| exports

@patch
def get_scores(self: RelativeMultiHeadAttention, query:torch.Tensor, key:torch.Tensor): # <1>
    """Get relative attention scores.

    `query` and `key` are indexed `ibhd` / `jbhd` as in the einsums below (position, presumably batch,
    head, head-dim). The result is indexed `ijbh`: (query_len, key_len, batch, heads).
    Raises `ValueError` if the query or key length exceeds `self.P`.
    """
    q_len, k_len = query.shape[0], key.shape[0]
    # The `[P - k_len : P + q_len]` window below only exists for lengths <= P. Beyond that it
    # wraps or gets clipped, which yields misaligned scores or a cryptic broadcast error.
    if q_len > self.P or k_len > self.P:
        raise ValueError(f"query length ({q_len}) and key length ({k_len}) must not exceed P={self.P}")

    key_pos_emb = self.key_pos_embeddings[self.P - k_len:self.P + q_len] # <2>
    key_pos_bias = self.key_pos_bias[self.P - k_len:self.P + q_len] # <3>
    query_pos_bias = self.query_pos_bias[None, None, :, :] # <4>
    ac = torch.einsum('ibhd,jbhd->ijbh', query + query_pos_bias, key) # <5>
    b = torch.einsum('ibhd,jhd->ijbh', query, key_pos_emb) # <6>
    d = key_pos_bias[None, :, None, :] # <7>
    bd = shift_right(b + d) # <8>
    bd = bd[:, -k_len:] # <9>

    # Return the sum
    return ac + bd

1. `RelativeMultiHeadAttention` subclasses `MultiHeadAttention`, so we only need to override the `get_scores()` method.
2. $R_k$
3. $S_k$
4. ${v^\top}$
5. $\mathbf{A + C}_{i,j} = Q_i^\top K_j + v^\top K_j$
6. $\mathbf{B'}_{i,k} = Q_i^\top {R_k}$
7. $\mathbf{D'}_{i,k} = {S_k}$
8. Shift the rows of $\mathbf{(B' + D')}_{i,k}$ to get $\mathbf{(B + D)}_{i,j} = \mathbf{(B' + D')}_{i,i - j}$
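9. Keep only the last `key_len` columns. This removes the extra positions, including the leading columns where `shift_right` leaves wrapped values.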

With absolute positional encodings the attention score is $A^{abs}_{i,j} = \text{lin}_q(X^q_i + P_i)^\top \text{lin}_k(X^k_j + P_j) = Q_i^\top K_j + Q_i^\top U^K_j + {U^Q_i}^\top K_j + {U^Q_i}^\top U^K_j$. Here $Q_i, K_j$ are linear transformations of the original embeddings $X^q_i, X^k_j$, and $U^Q_i, U^K_j$ are linear transformations of the absolute positional encodings $P_i, P_j$. The Transformer-XL authors argue that the attention paid to a given key should be the same regardless of the position of the query. They therefore replace ${U^Q_i}^\top K_j$ with a learned constant term ${v^\top} K_j$.

For the second and fourth terms, relative positional encodings are introduced: ${Q_i^\top U^K_j}$ is replaced with ${Q_i^\top R_{i - j}}$, and ${U^Q_i}^\top U^K_j$ with $S_{i-j}$. This gives $A^{rel}_{i,j} = Q_i^\top K_j + Q_i^\top R_{i - j} + v^\top K_j + S_{i-j}$.

In [None]:
#| hide

from nbdev import nbdev_export
nbdev_export()  # export this notebook's `#| export` cells (target set by `#| default_exp` above)