# Relative position embeddings according to the T5 paper


The part of the paper that deals with position embeddings:

> Since self-attention is order-independent (i.e. it is an operation on sets), it is common to provide an explicit position signal to the Transformer. While the original Transformer used a sinusoidal position signal or learned position embeddings, it has recently become more common to use relative position embeddings (Shaw et al., 2018; Huang et al., 2018a). Instead of using a fixed embedding for each position, relative position embeddings produce a different learned embedding according to the offset between the “key” and “query” being compared in the self-attention mechanism. We use a simplified form of position embeddings where each “embedding” is simply a scalar that is added to the corresponding logit used for computing the attention weights. For efficiency, we also share the position embedding parameters across all layers in our model, though within a given layer each attention head uses a different learned position embedding. Typically, a fixed number of embeddings are learned, each corresponding to a range of possible key-query offsets. In this work, we use 32 embeddings for all of our models with ranges that increase in size logarithmically up to an offset of 128 beyond which we assign all relative positions to the same embedding. Note that a given layer is insensitive to relative position beyond 128 tokens, but subsequent layers can build a sensitivity to larger offsets by combining local information from previous layers.

Their approach can seem a little confusing because of the bucketing algorithm it relies on, which may be why few online hobbyists like me have implemented it. Personally, when I searched I found only the original implementation in TensorFlow.

Let's make it as simple as possible:
* They adopt the approach of (Shaw et al., 2018; Huang et al., 2018a). That is, they adopt the idea of relative positional representations by modifying the attention matrix. But what's new?
* Instead of using high-dimensional embedding vectors to encode relative distances, they use scalar values. That is, instead of using a d-dimensional vector to represent the relative distance -2, they use a single scalar. So the embeddings here are scalars. Nice! This reduces space complexity.
* They modify only the attention matrix and do not modify the value matrix as their predecessors did. This is sufficient: experiments showed that modifying the value matrix did not improve performance. This is another reduction in space complexity.
* Instead of using 2k+1 embedding vectors, they use only 32 embeddings.
* Their predecessors shared the embedding parameters across different attention heads and across all layers, whereas here they are shared across layers but differ across heads. That is, the embedding for the first head in the first layer is the same as the embedding for the first head in the second layer, but the first head's embedding differs from the second head's. This allows each head to capture positional information that the others may not.
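
To make the "scalar added to the logit" idea concrete, here is a minimal plain-Python sketch. The logits and bias values are made up for illustration (in the model the biases are learned), and the helper names `softmax` and `biased_weights` are hypothetical. It only shows how a per-head scalar, looked up by the offset between key and query, shifts the attention weights of one query.

```python
import math

def softmax(row):
    # Numerically stable softmax over a list of floats.
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]

# Hypothetical content logits (q . k) for one query at position 2 over 5 keys.
# They are all equal so that only the position bias decides the outcome.
query_pos = 2
content_logits = [0.0, 0.0, 0.0, 0.0, 0.0]

# One scalar per relative offset (key_pos - query_pos), a separate table per head.
# These values are invented; in the model they are trained parameters.
head_bias = [
    {-2: -2.0, -1: 1.0, 0: 2.0, 1: 1.0, 2: -2.0},  # head 0 favours nearby keys
    {-2: 2.0, -1: 0.0, 0: -1.0, 1: 0.0, 2: 2.0},   # head 1 favours distant keys
]

def biased_weights(logits, bias, query_pos):
    # Add the scalar bias for each key's offset to its logit, then normalise.
    shifted = [logit + bias[key_pos - query_pos] for key_pos, logit in enumerate(logits)]
    return softmax(shifted)

weights = [biased_weights(content_logits, b, query_pos) for b in head_bias]
for h, w in enumerate(weights):
    print(f"head {h}:", [round(x, 3) for x in w])
```

With identical content logits, head 0 puts most of its weight on the query's own position and head 1 on the two ends, which is the kind of per-head specialisation the last point describes.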

**Now how do things work?**
1. First, we have `num_buckets`, the number of embeddings (32 in our case), and `max_distance`, the offset from which all relative distances share a single embedding (128 in the paper), together with a logarithmic binning function.
2. The embeddings are split into two equal sets, one for close distances and one for farther distances. In our case we have 32 embeddings, so 16 are reserved for close distances and the other 16 for longer ones. Because attention here is bidirectional (there are relative distances to the left and to the right), each direction gets half of each set: 8 embeddings for close distances and 8 for long ones. For close distances (0 to 7 in each direction), every relative distance gets its own embedding. This is similar to what their predecessors did, and it is necessary for precision at short range. For long distances, a logarithmic function sorts the relative distances into groups (binning), and each group shares one embedding. With the default settings, relative distances 8 to 11 share one embedding, 12 to 15 share another, 16 to 22 another, and so on until we reach `max_distance`. From `max_distance` onwards, all relative distances fall into the last group. The number of distances in each group grows as we move away from the center: the first group has 4, while the third has 7. This approach may be better than simply assigning all relative distances beyond a certain threshold to the same embedding.

Unusually for me, I won't explain the mathematics behind the logarithmic function. I find it tedious and don't see much benefit in it, and the authors of the paper do not detail it either. So set it aside, but understand its role.

**Analysis**
* This approach clearly focuses on reducing the number of parameters while maintaining good performance compared to its predecessors. This is a logical choice given the experimental results on the predecessors' method: for example, modifying the value matrix did not give an improvement.
* This approach uses embeddings that are learned during training. This allows the position information to adapt to the task at hand, but, as with any learned parameters, it may also leave the model open to overfitting.
* Can this approach generalize to arbitrary lengths? Only partly. It can be applied to sequences of any length, but it clips beyond a certain distance: from `max_distance` onwards every relative distance receives the same embedding, so a single layer can no longer tell those distances apart.
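
The parameter saving mentioned in the first point can be made concrete with a rough count. The sketch below compares one T5-style scalar table with vector-valued relative embeddings. The clipping distance `clip_k=16`, the head dimension `head_dim=64`, and the two tables (one on the key side, one on the value side) are illustrative assumptions, not figures taken from either paper, and both function names are hypothetical.

```python
def t5_bias_params(num_buckets=32, num_heads=8):
    # One scalar per (bucket, head) pair.
    return num_buckets * num_heads

def vector_relative_params(clip_k, head_dim, num_tables=2):
    # 2k+1 vectors of size head_dim per table; num_tables=2 models separate
    # key-side and value-side tables (an assumption for this comparison).
    return (2 * clip_k + 1) * head_dim * num_tables

scalar_total = t5_bias_params()
vector_total = vector_relative_params(clip_k=16, head_dim=64)
print("scalar biases:", scalar_total)
print("vector embeddings:", vector_total)
print("ratio:", vector_total / scalar_total)
```

Under these assumed sizes the scalar table is more than an order of magnitude smaller, and it costs a single addition per logit instead of an extra dot product.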


**Ref**
* [Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer](https://arxiv.org/abs/1910.10683).

In [12]:
import torch
import torch.nn as nn
from typing import List
from torch.nn import functional as F

torch.manual_seed(42)

<torch._C.Generator at 0x7f69e8151450>

In [13]:
class RelativePositionBias(nn.Module):
    """
    Translate relative position to a bucket number for relative attention.

    The relative position is defined as memory_position - query_position, i.e.
    the distance in tokens from the attending position to the attended-to
    position. If bidirectional=False, then positive relative positions are
    invalid.

    We use smaller buckets for small absolute relative_position and larger buckets
    for larger absolute relative_positions. All relative positions >=max_distance
    map to the same bucket. All relative positions <=-max_distance map to the
    same bucket. This should allow for more graceful generalization to longer
    sequences than the model has been trained on.

    Args:
        bidirectional (bool): Whether the attention is bidirectional.
        num_buckets (int): Number of buckets.
        max_distance (int): Maximum distance for relative positions.
        num_heads (int): Number of attention heads.

    # Reference: https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593
    """
    def __init__(self, bidirectional=True, num_buckets=32, max_distance=128, num_heads=8):
        super(RelativePositionBias, self).__init__()
        self.bidirectional = bidirectional
        self.num_buckets = num_buckets
        self.max_distance = max_distance
        self.num_heads = num_heads
        self.relative_attention_bias = nn.Embedding(self.num_buckets, self.num_heads)

    @staticmethod
    def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
        """
        Translate relative position to a bucket number.

        Args:
            relative_position (torch.Tensor): Relative position tensor.
            bidirectional (bool): Whether the attention is bidirectional.
            num_buckets (int): Number of buckets.
            max_distance (int): Maximum distance for relative positions.

        Returns:
            torch.Tensor: Bucket number tensor.
        """
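        # Bucket indices start at 0 so they stay within [0, num_buckets - 1],
        # the valid index range of the nn.Embedding table
        ret = torch.zeros_like(relative_position, dtype=torch.long)

        # Handle bidirectional
        if bidirectional:
            num_buckets //= 2
            ret += (relative_position < 0).to(torch.long) * num_buckets
            relative_position = relative_position.abs()
        else:
            # Unidirectional: only keys at or before the query are valid, so
            # clamp positive offsets to 0 and work with non-negative distances
            relative_position = -torch.clamp(relative_position, max=0)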

        max_exact = num_buckets // 2
        is_small = (relative_position < max_exact)

        # Compute val_if_large
        val_if_large = max_exact + (
            torch.log(relative_position.float() / max_exact) /
            torch.log(torch.tensor(max_distance / max_exact).float()) * (num_buckets - max_exact)
        ).to(torch.long)

        # Clamp val_if_large
        val_if_large = torch.clamp(val_if_large, max=num_buckets - 1)

        # Update ret based on is_small
        ret += torch.where(is_small, relative_position, val_if_large)

        return ret

    def compute_bias(self, qlen, klen):
        """
        Compute binned relative position bias.

        Args:
            qlen (int): Length of the query sequence.
            klen (int): Length of the key sequence.

        Returns:
            torch.Tensor: Relative position bias tensor.
        """
        # Create context and memory positions
        context_position = torch.arange(0, qlen, dtype=torch.long, device=self.relative_attention_bias.weight.device)[:, None]
        memory_position = torch.arange(0, klen, dtype=torch.long, device=self.relative_attention_bias.weight.device)[None, :]

        # Compute relative position
        relative_position = memory_position - context_position

        # Compute relative position bucket
        rp_bucket = self._relative_position_bucket(
            relative_position,
            bidirectional=self.bidirectional,
            num_buckets=self.num_buckets,
            max_distance=self.max_distance
        )
        rp_bucket = rp_bucket.to(self.relative_attention_bias.weight.device)

        # Get values from the embedding
        values = self.relative_attention_bias(rp_bucket)
        values = values.permute([2, 0, 1]).unsqueeze(0)

        return values

    def forward(self, qlen, klen):
        """
        Forward pass.

        Args:
            qlen (int): Length of the query sequence.
            klen (int): Length of the key sequence.

        Returns:
            torch.Tensor: Relative position bias tensor.
        """
        return self.compute_bias(qlen, klen)


In [14]:
# Example usage
def test_relative_position_bias():
    # Instantiate the RelativePositionBias module
    num_buckets = 32
    max_distance = 128
    num_heads = 2
    relative_position_bias = RelativePositionBias(num_buckets=num_buckets, max_distance=max_distance, num_heads=num_heads)

    # Example input sequence lengths (can be adjusted based on your data)
    qlen = 5
    klen = 5

    # Compute relative position biases
    biases = relative_position_bias(qlen, klen)

    # Print the computed biases
    print("Computed biases shape:", biases.shape)
    print(biases[:,0].shape)
    #print("Computed biases:\n", biases)

if __name__ == "__main__":
    test_relative_position_bias()

Computed biases shape: torch.Size([1, 2, 5, 5])
torch.Size([1, 5, 5])


In [15]:
class AttentionHead(nn.Module):
    """
    Relation-aware attention head implementation.

    Args:
        hidden_size (int): Hidden size for the model (embedding dimension).
        head_dim (int): Dimensionality of the attention head.

    Attributes:
        query_weights (nn.Linear): Linear layer for query projection.
        key_weights (nn.Linear): Linear layer for key projection.
        value_weights (nn.Linear): Linear layer for value projection.
    """

    def __init__(self, hidden_size, head_dim):
        """
        Initializes the AttentionHead.

        Args:
            hidden_size (int): Hidden size for the model (embedding dimension).
            head_dim (int): Dimensionality of the attention head.
        """
        super().__init__()
        self.head_dim = head_dim
        self.query_weights: nn.Linear = nn.Linear(hidden_size, head_dim)
        self.key_weights: nn.Linear = nn.Linear(hidden_size, head_dim)
        self.value_weights: nn.Linear = nn.Linear(hidden_size, head_dim)

    def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor,
                 relative_biases:torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
        """
        Applies attention mechanism to the input query, key, and value tensors.

        Args:
            query (torch.Tensor): Query tensor.
            key (torch.Tensor): Key tensor.
            value (torch.Tensor): Value tensor.
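            relative_biases (torch.Tensor): Relative position bias for this head, broadcastable to the attention scores.
            mask (torch.Tensor): Optional mask tensor.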

        Returns:
            torch.Tensor: Updated value embeddings after applying attention mechanism.
        """
        query: torch.Tensor = self.query_weights(query)
        key: torch.Tensor = self.key_weights(key)
        value: torch.Tensor = self.value_weights(value)

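        # Add the scalar position bias to the content logits. Note that the sum is
        # divided by sqrt(head_dim), so the learned bias is rescaled by the same constant.
        att_scores: torch.Tensor = (torch.matmul(query, key.transpose(1, 2)) + relative_biases) / self.head_dim ** 0.5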

        if mask is not None:
            mask = mask.to(torch.int)
            att_scores: torch.Tensor = att_scores.masked_fill(mask.unsqueeze(1) == 0, -1e9)

        att_weights: torch.Tensor = F.softmax(att_scores, dim=-1)
        n_value: torch.Tensor = torch.matmul(att_weights, value)

        return n_value


class MultiHeadAttention(nn.Module):
    """
    Multi-head attention layer implementation.

    Args:
        hidden_size (int): Hidden size for the model (embedding dimension).
        num_heads (int): Number of attention heads.

    Attributes:
        hidden_size (int): Hidden size for the model (embedding dimension).
        num_heads (int): Number of attention heads.
        head_dim (int): Dimensionality of each attention head.
        attention_heads (nn.ModuleList): List of AttentionHead layers.
        fc (nn.Linear): Fully connected layer for final projection.
    """

    def __init__(self, hidden_size, num_heads):
        """
        Initializes the MultiHeadAttention layer.
        """
        super().__init__()
        self.hidden_size: int = hidden_size
        self.num_heads: int = num_heads
        self.head_dim: int = hidden_size // num_heads
        self.attention_heads: nn.ModuleList = nn.ModuleList([AttentionHead(self.hidden_size, self.head_dim) for head_num in range(self.num_heads)])
        self.fc: nn.Linear = nn.Linear(hidden_size, hidden_size)

    def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, relative_position_bias: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
        """
        Applies multi-head attention mechanism to the input query, key, and value tensors.

        Args:
            query (torch.Tensor): Query tensor.
            key (torch.Tensor): Key tensor.
            value (torch.Tensor): Value tensor.
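            relative_position_bias (torch.Tensor): Relative position biases for all heads, shape (1, num_heads, qlen, klen).
            mask (torch.Tensor): Optional mask tensor.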

        Returns:
            torch.Tensor: Updated hidden state after applying multi-head attention mechanism.
        """
        attention_outputs: List[torch.Tensor] = [attention_head(query, key, value, mask=mask, relative_biases=relative_position_bias[:,i]) for i, attention_head in enumerate(self.attention_heads)]
        hidden_state: torch.Tensor = torch.cat(attention_outputs, dim=-1)
        hidden_state: torch.Tensor = self.fc(hidden_state)
        return hidden_state

In [16]:
import torch

# Define input dimensions
batch_size = 2
seq_len = 5
hidden_size = 16
num_heads = 8
num_buckets = 32
max_distance = 128

# Compute relative position biases
relative_position_bias = RelativePositionBias(num_buckets=num_buckets, max_distance=max_distance, num_heads=num_heads)
biases = relative_position_bias(seq_len, seq_len)

# Create random input tensors
x = torch.randn(batch_size, seq_len, hidden_size)

# Instantiate the MultiHeadAttention module
multihead_attention = MultiHeadAttention(hidden_size=hidden_size, num_heads=num_heads)

# Forward pass
output = multihead_attention(x, x, x, biases)

# Print the output shape
print("Output Shape:", output.shape)


Output Shape: torch.Size([2, 5, 16])
