## **Self-Attention with Relative Position Representations**

**Note:** I advise you to watch [this video](https://www.youtube.com/watch?v=DwaBQbqh5aE) if you do not have prior knowledge of this type of representation.

This approach extends the self-attention mechanism of the original Transformer to efficiently consider representations of the relative positions, or distances, between sequence elements. The abstract of the original paper reads:

> Relying entirely on an attention mechanism, the Transformer introduced by Vaswani et al. (2017) achieves state-of-the-art results for machine translation. In contrast to recurrent and convolutional neural networks, it does not explicitly model relative or absolute position information in its structure. Instead, it requires adding representations of absolute positions to its inputs. In this work we present an alternative approach, extending the self-attention mechanism to efficiently consider representations of the relative positions, or distances between sequence elements. On the WMT 2014 English-to-German and English-to-French translation tasks, this approach yields improvements of 1.3 BLEU and 0.3 BLEU over absolute position representations, respectively. Notably, we observe that combining relative and absolute position representations yields no further improvement in translation quality. We describe an efficient implementation of our method and cast it as an instance of relation-aware self-attention mechanisms that can generalize to arbitrary graph-labeled inputs.

### **Self-Attention**

1. **Input and Output Dimensions:**
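   - Input sequence $x=\left(x_{1}, \ldots, x_{n}\right)$ of $n$ elements, where each $x_i \in \mathbb{R}^{d_x}$.
   - Output sequence $z=\left(z_{1}, \ldots, z_{n}\right)$ of the same length, where each $z_i \in \mathbb{R}^{d_z}$.
   - The projection matrices $W^{Q}, W^{K}, W^{V} \in \mathbb{R}^{d_x \times d_z}$ are learned parameters.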

2. **Computation of $z_i$:**
   $$
z_{i}=\sum_{j=1}^{n} \alpha_{i j}\left(x_{j} W^{V}\right)
$$

3. **Weight Coefficients $\alpha_{ij}$:**
   $$
\alpha_{i j}=\frac{\exp e_{i j}}{\sum_{k=1}^{n} \exp e_{i k}}
$$

4. **Compatibility Function $e_{ij}$:**
$$
e_{i j}=\frac{\left(x_{i} W^{Q}\right)\left(x_{j} W^{K}\right)^{T}}{\sqrt{d_{z}}}
$$
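
The definitions above map onto a handful of tensor operations. The following is a minimal single-head sketch in PyTorch, assuming one unbatched sequence and randomly initialised projection matrices. The function name `self_attention`, the variable names, and the sizes are illustrative choices rather than anything prescribed by the paper.

```python
import torch

def self_attention(x, w_q, w_k, w_v):
    """Single-head self-attention for one sequence x of shape (n, d_in)."""
    q = x @ w_q                       # queries, shape (n, d_z)
    k = x @ w_k                       # keys, shape (n, d_z)
    v = x @ w_v                       # values, shape (n, d_z)
    d_z = q.size(-1)
    e = (q @ k.T) / d_z ** 0.5        # compatibility scores e_ij, shape (n, n)
    alpha = torch.softmax(e, dim=-1)  # weights alpha_ij; every row sums to 1
    z = alpha @ v                     # z_i = sum_j alpha_ij * (x_j W^V)
    return z, alpha

torch.manual_seed(0)
n, d_in, d_z = 5, 8, 4                # illustrative sizes
x = torch.randn(n, d_in)
w_q, w_k, w_v = (torch.randn(d_in, d_z) for _ in range(3))
z, alpha = self_attention(x, w_q, w_k, w_v)
print(z.shape)
print(alpha.sum(dim=-1))
```

Each row of `alpha` is a probability distribution over the positions $j$, so `z` is a row-wise convex combination of the projected values.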

### **Relation-aware Self-Attention**

1. **Extension to Self-Attention:**
   - Self-attention is extended to consider pairwise relationships between input elements: the input is modeled as a labeled, directed, fully connected graph. The edge between input elements $x_i$ and $x_j$ is represented by two vectors, $a_{i j}^{V}, a_{i j}^{K} \in \mathbb{R}^{d_{a}}$, which model the interaction between positions $i$ and $j$. The paper uses $d_a = d_z$.
   - These representations can be shared across attention heads.
   - Edges can capture information about the relative position differences between input elements.

2. **Modification of $z_i$ with Edge Information $a_{ij}^V$:**
   $$
z_{i}=\sum_{j=1}^{n} \alpha_{i j}\left(x_{j} W^{V}+a_{i j}^{V}\right)
$$
3. **Modification of Compatibility Function with Edge Information $a_{ij}^K$:**
$$
e_{i j}=\frac{x_{i} W^{Q}\left(x_{j} W^{K}+a_{i j}^{K}\right)^{T}}{\sqrt{d_{z}}}
$$
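
To make the role of the edge vectors concrete, here is a hedged sketch that adds arbitrary edge tensors to the plain attention computation and checks one output position against a literal, loop-based reading of the two equations. The function name `relation_aware_attention`, the random edge tensors `a_k` and `a_v`, and the sizes are assumptions made for illustration only.

```python
import torch

def relation_aware_attention(x, w_q, w_k, w_v, a_k, a_v):
    """x: (n, d_in); a_k, a_v: (n, n, d_z) edge vectors for every pair (i, j)."""
    q, k, v = x @ w_q, x @ w_k, x @ w_v              # each (n, d_z)
    d_z = q.size(-1)
    # e_ij = q_i . (k_j + a^K_ij) / sqrt(d_z), split into two matrix products
    e = (q @ k.T + torch.einsum("id,ijd->ij", q, a_k)) / d_z ** 0.5
    alpha = torch.softmax(e, dim=-1)
    # z_i = sum_j alpha_ij * (v_j + a^V_ij)
    return alpha @ v + torch.einsum("ij,ijd->id", alpha, a_v)

torch.manual_seed(0)
n, d_in, d_z = 4, 6, 3                               # illustrative sizes
x = torch.randn(n, d_in)
w_q, w_k, w_v = (torch.randn(d_in, d_z) for _ in range(3))
a_k = torch.randn(n, n, d_z)
a_v = torch.randn(n, n, d_z)
z = relation_aware_attention(x, w_q, w_k, w_v, a_k, a_v)

# Literal evaluation of the equations for a single position i.
i = 1
q_i = x[i] @ w_q
e_i = torch.stack([q_i @ (x[j] @ w_k + a_k[i, j]) for j in range(n)]) / d_z ** 0.5
alpha_i = torch.softmax(e_i, dim=0)
z_i = sum(alpha_i[j] * (x[j] @ w_v + a_v[i, j]) for j in range(n))
print(torch.allclose(z[i], z_i, atol=1e-4))
```

Splitting the score into a content term and an edge term avoids materialising the sum $x_j W^K + a_{ij}^K$ for every pair, which is the same decomposition the batched implementation in this notebook relies on.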

### **Relative Position Representations**

1. **Edge Labels for Relative Positions:**
   - Edge representations $a_{ij}^K, a_{ij}^V$ capture relative position differences.
   - Relative positions are clipped to a maximum absolute value of $k$. The authors hypothesize that precise relative position information is not useful beyond a certain distance. Clipping the maximum distance also enables the model to generalize to sequence lengths not seen during training.
   - As a result, only $2k+1$ unique edge labels are considered, one for each offset from $-k$ to $k$.

2. **Learnable Relative Position Representations:**
   - $a_{ij}^K, a_{ij}^V$ are determined using learnable relative position representations $w^K, w^V$.
   - Clipping function:
$$
\begin{aligned}
a_{i j}^{K} & =w_{\operatorname{clip}(j-i, k)}^{K} \\
a_{i j}^{V} & =w_{\operatorname{clip}(j-i, k)}^{V} \\
\operatorname{clip}(x, k) & =\max (-k, \min (k, x))
\end{aligned}
$$

3. **Learnable Vectors $w^K, w^V$:**
   - $w^K = (w_{-k}^K, \ldots, w_k^K)$
   - $w^V = (w_{-k}^V, \ldots, w_k^V)$
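
As a small worked example of the clipping, the sketch below builds the matrix of clipped offsets $\operatorname{clip}(j-i, k)$ for a sequence of length 5 with $k=2$, then shifts it by $k$ to obtain row indices into a table of $2k+1$ vectors. The table `w_k` is random here purely for illustration; in a model it would be a learned parameter, and the sizes are arbitrary.

```python
import torch

n, k, d_a = 5, 2, 4                              # illustrative sizes
pos = torch.arange(n)
offsets = pos.unsqueeze(0) - pos.unsqueeze(1)    # offsets[i, j] = j - i
clipped = offsets.clamp(-k, k)                   # clip(j - i, k)
print(clipped)
# Values (rows are i, columns are j):
#  0  1  2  2  2
# -1  0  1  2  2
# -2 -1  0  1  2
# -2 -2 -1  0  1
# -2 -2 -2 -1  0

torch.manual_seed(0)
w_k = torch.randn(2 * k + 1, d_a)                # one vector per offset in [-k, k]
a_k = w_k[clipped + k]                           # shift offsets to indices 0..2k
print(a_k.shape)                                 # one d_a-vector per pair (i, j)
print(torch.equal(a_k[0, 2], a_k[0, 4]))         # offsets 2 and 4 share one vector
```

Because every offset beyond $\pm k$ reuses the outermost vector, the same table serves sequences of any length.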

### **Analysis**

* Shaw, Uszkoreit, and Vaswani (2018) were among the first to introduce an alternative method for incorporating both absolute and relative position encodings.
* The use of relative position representations allows the model to consider pairwise relationships and capture information about the relative position differences between input elements, enhancing its ability to understand the sequence structure.
* Although the approach cannot be compared directly with simply adding position embeddings to the inputs, it roughly omits the position–position interaction and keeps only one unit–position term. In addition, it does not share the projection matrices but models the pairwise position interaction directly with the vectors $a$. In an ablation analysis, the authors found that adding only $a_{i j}^{K}$ might be sufficient.
* To reduce space complexity, the parameters are shared across attention heads. Although the paper does not state this explicitly, the position information appears to be added in every layer, with parameters that are not shared between layers. The authors find that relative position embeddings perform better in machine translation and that combining absolute and relative embeddings does not improve performance.
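
Expanding the numerator of the relation-aware compatibility function makes the remark about interaction terms concrete. The expansion follows directly from the formula for $e_{ij}$. The symbols $p_i, p_j$ for absolute position embeddings are notation assumed here for the comparison and are not used elsewhere in this section.

$$
x_{i} W^{Q}\left(x_{j} W^{K}+a_{i j}^{K}\right)^{T}=\underbrace{x_{i} W^{Q}\left(x_{j} W^{K}\right)^{T}}_{\text{unit–unit}}+\underbrace{x_{i} W^{Q}\left(a_{i j}^{K}\right)^{T}}_{\text{unit–position}}
$$

For comparison, adding absolute position embeddings to the inputs gives

$$
\left(x_{i}+p_{i}\right) W^{Q}\left(\left(x_{j}+p_{j}\right) W^{K}\right)^{T}=x_{i} W^{Q}\left(x_{j} W^{K}\right)^{T}+x_{i} W^{Q}\left(p_{j} W^{K}\right)^{T}+p_{i} W^{Q}\left(x_{j} W^{K}\right)^{T}+p_{i} W^{Q}\left(p_{j} W^{K}\right)^{T}
$$

which contains two unit–position terms and a position–position term $p_{i} W^{Q}\left(p_{j} W^{K}\right)^{T}$ that has no counterpart in the relative formulation.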

[Ref.1](https://direct.mit.edu/coli/article/48/3/733/111478/Position-Information-in-Transformers-An-Overview), [Ref.2](https://doi.org/10.18653/v1/N18-2074), [Ref.3](https://sh-tsang.medium.com/review-self-attention-with-relative-position-representations-266ab2f78dd7).

import torch
import torch.nn as nn
from typing import List
import torch.nn.functional as F

class RelativePosition(nn.Module):
    """
    Relative Position Embeddings Module

    This module generates learnable relative position embeddings to enrich
    the self-attention mechanism with information about the relative distances
    between elements in input sequences.

    Args:
        d_a (int): Number of dimensions in the relative position embeddings.
        k (int): Clipping distance.

    Attributes:
        position_embeddings (nn.Parameter): Learnable parameter for relative position embeddings.

    Example:
        >>> # Create a RelativePosition instance with 16 dimensions and clipping distance of 10
        >>> relative_position = RelativePosition(d_a=16, k=10)
        >>> # Generate relative position embeddings for sequences of lengths 5 and 7
        >>> embeddings = relative_position(length_query=5, length_key=7)
    """

    def __init__(self, d_a: int, k: int):
        """
        Initialize the RelativePosition module.

        Args:
        - d_a (int): Number of dimensions in the relative position embeddings.
        - k (int): Clipping distance.
        """
        super().__init__()
        self.d_a = d_a
        self.k = k
        self.position_embeddings = nn.Parameter(torch.empty((2 * k + 1, d_a)))
        nn.init.xavier_uniform_(self.position_embeddings)

    def forward(self, length_query: int, length_key: int) -> torch.Tensor:
        """
        Compute relative position embeddings.

        Args:
        - length_query (int): Length of the query sequence.
        - length_key (int): Length of the key sequence.

        Returns:
        - embeddings (torch.Tensor): Relative position embeddings of shape (length_query, length_key, d_a).
        """
        # Generate relative position embeddings
        indices_query = torch.arange(length_query, device=self.position_embeddings.device)
        indices_key = torch.arange(length_key, device=self.position_embeddings.device)
        distance_matrix = indices_key.unsqueeze(0) - indices_query.unsqueeze(1)
        distance_matrix_clipped = torch.clamp(distance_matrix, -self.k, self.k)
        final_matrix = distance_matrix_clipped + self.k
        embeddings = self.position_embeddings[final_matrix.to(torch.long)]

        return embeddings


class RelationAwareAttentionHead(nn.Module):
    """
    Relation-aware attention head implementation.

    Args:
        hidden_size (int): Hidden size for the model (embedding dimension).
        head_dim (int): Dimensionality of the attention head.
        k_bias_matrix (torch.Tensor): Matrix for relative position attention in query-key interaction.
        v_bias_matrix (torch.Tensor): Matrix for relative position attention in query-value interaction.

    Attributes:
        query_weights (nn.Linear): Linear layer for query projection.
        key_weights (nn.Linear): Linear layer for key projection.
        value_weights (nn.Linear): Linear layer for value projection.
    """

    def __init__(self, hidden_size, head_dim, k_bias_matrix, v_bias_matrix):
        """
        Initializes the RelationAwareAttentionHead.

        Args:
            hidden_size (int): Hidden size for the model (embedding dimension).
            head_dim (int): Dimensionality of the attention head.
            k_bias_matrix (torch.Tensor): Matrix for relative position attention in query-key interaction.
            v_bias_matrix (torch.Tensor): Matrix for relative position attention in query-value interaction.
        """
        super().__init__()
        self.head_dim = head_dim
        self.query_weights: nn.Linear = nn.Linear(hidden_size, head_dim)
        self.key_weights: nn.Linear = nn.Linear(hidden_size, head_dim)
        self.value_weights: nn.Linear = nn.Linear(hidden_size, head_dim)
        self.k_bias_matrix = k_bias_matrix
        self.v_bias_matrix = v_bias_matrix

    def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
        """
        Applies attention mechanism to the input query, key, and value tensors.

        Args:
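            query (torch.Tensor): Query tensor of shape (b_s, n_t, hidden_size).
            key (torch.Tensor): Key tensor of shape (b_s, n_t, hidden_size).
            value (torch.Tensor): Value tensor of shape (b_s, n_t, hidden_size).
            mask (torch.Tensor): Optional mask of shape (b_s, n_t); positions equal to 0 are masked out.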

        Returns:
            torch.Tensor: Updated value embeddings after applying attention mechanism.
        """
        query: torch.Tensor = self.query_weights(query) # (b_s, n_t, head_dim)
        key: torch.Tensor = self.key_weights(key) # (b_s, n_t, head_dim)
        value: torch.Tensor = self.value_weights(value) # (b_s, n_t, head_dim)

        # Self-Attention scores
        attn_1: torch.Tensor = torch.matmul(query, key.transpose(1, 2)) # Q*K^T:(b_s, n_t, n_t)

        # Relative Position Attention scores
        attn_2: torch.Tensor = torch.matmul(query.permute(1, 0, 2), self.k_bias_matrix.transpose(1, 2)).transpose(0, 1) # Q*K_shifting^T:(b_s, n_t, n_t)

        # Relation-aware Self-Attention scores
        att_scores: torch.Tensor = (attn_1 + attn_2)/self.head_dim ** 0.5

        if mask is not None:
            mask = mask.to(torch.int)
            att_scores: torch.Tensor = att_scores.masked_fill(mask.unsqueeze(1) == 0, -1e9)

        att_weights: torch.Tensor = F.softmax(att_scores, dim=-1)

        # Weighted sum of values
        values_1: torch.Tensor = torch.matmul(att_weights, value) # (b_s, n_t, head_dim)

        # Relative Position Representation for values
        values_2: torch.Tensor = torch.matmul(att_weights.permute(1, 0, 2), self.v_bias_matrix).transpose(0, 1) # (b_s, n_t, head_dim)

        # Relation-aware values
        n_value  = values_1 + values_2
        return n_value


class RelationAwareMultiHeadAttention(nn.Module):
    """
    Multi-head attention layer implementation.

    Args:
        hidden_size (int): Hidden size for the model (embedding dimension).
        num_heads (int): Number of attention heads.
        k (int): Clipping distance for relative position embeddings.
        seq_len (int): Length of the input sequences.

    Attributes:
        hidden_size (int): Hidden size for the model (embedding dimension).
        num_heads (int): Number of attention heads.
        head_dim (int): Dimensionality of each attention head.
        relative_position_k (RelativePosition): Instance of RelativePosition for query-key relative positions.
        relative_position_v (RelativePosition): Instance of RelativePosition for query-value relative positions.
        k_bias_matrix (torch.Tensor): Matrix for relative position attention in query-key interaction.
        v_bias_matrix (torch.Tensor): Matrix for relative position attention in query-value interaction.
        attention_heads (nn.ModuleList): List of RelationAwareAttentionHead layers.
        fc (nn.Linear): Fully connected layer for final projection.
    """

    def __init__(self, hidden_size, num_heads, k, seq_len):
        """
        Initializes the RelationAwareMultiHeadAttention layer.

        Args:
            hidden_size (int): Hidden size for the model (embedding dimension).
            num_heads (int): Number of attention heads.
            k (int): Clipping distance for relative position embeddings.
            seq_len (int): Length of the input sequences.
        """
        super().__init__()
        self.hidden_size: int = hidden_size
        self.num_heads: int = num_heads
        self.head_dim: int = hidden_size // num_heads
        self.relative_position_k: RelativePosition = RelativePosition(self.head_dim, k)
        self.relative_position_v: RelativePosition = RelativePosition(self.head_dim, k)
        self.k_bias_matrix: torch.Tensor = self.relative_position_k(seq_len, seq_len)
        self.v_bias_matrix: torch.Tensor = self.relative_position_v(seq_len, seq_len)
        self.attention_heads: nn.ModuleList = nn.ModuleList([RelationAwareAttentionHead(self.hidden_size, self.head_dim,
                                                                           self.k_bias_matrix, self.v_bias_matrix) for _ in range(self.num_heads)])
        self.fc: nn.Linear = nn.Linear(hidden_size, hidden_size)

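    def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
        """
        Applies multi-head attention mechanism to the input query, key, and value tensors.

        Args:
            query (torch.Tensor): Query tensor of shape (b_s, n_t, hidden_size).
            key (torch.Tensor): Key tensor of shape (b_s, n_t, hidden_size).
            value (torch.Tensor): Value tensor of shape (b_s, n_t, hidden_size).
            mask (torch.Tensor): Optional mask of shape (b_s, n_t); positions equal to 0 are masked out.

        Returns:
            torch.Tensor: Updated hidden state after applying multi-head attention mechanism.
        """
        # Rebuild the bias matrices on every call so that they follow the current
        # embedding parameters and device. The copies created in __init__ are fixed
        # snapshots whose autograd graph is freed after the first backward pass.
        k_bias_matrix: torch.Tensor = self.relative_position_k(query.size(1), key.size(1))
        v_bias_matrix: torch.Tensor = self.relative_position_v(query.size(1), key.size(1))
        for attention_head in self.attention_heads:
            attention_head.k_bias_matrix = k_bias_matrix
            attention_head.v_bias_matrix = v_bias_matrix

        attention_outputs: List[torch.Tensor] = [attention_head(query, key, value, mask=mask) for attention_head in self.attention_heads]
        hidden_state: torch.Tensor = torch.cat(attention_outputs, dim=-1)
        hidden_state: torch.Tensor = self.fc(hidden_state)
        return hidden_state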

hidden_size = 768
num_heads = 8
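head_dim = 64  # used only for the standalone head below; the multi-head layer uses hidden_size // num_heads = 96
k = 4
seq_len = 20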
batch_size = 16

# Set a random seed for reproducibility
torch.manual_seed(42)

# Create instances of RelativePosition, RelationAwareAttentionHead, and RelationAwareMultiHeadAttention
relative_position_k = RelativePosition(d_a=head_dim, k=k)
relative_position_v = RelativePosition(d_a=head_dim, k=k)

attention_head = RelationAwareAttentionHead(hidden_size=hidden_size,
                                           head_dim=head_dim,
                                           k_bias_matrix=relative_position_k(seq_len, seq_len),
                                           v_bias_matrix=relative_position_v(seq_len, seq_len))

multihead_attention = RelationAwareMultiHeadAttention(hidden_size, num_heads, k, seq_len)

# Generate dummy input tensors
x_input = torch.rand((batch_size, seq_len, hidden_size))

# Test RelativePosition
relative_position_embeddings = relative_position_k(seq_len, seq_len)
print("Relative Position Embeddings Shape:", relative_position_embeddings.shape)

# Test RelationAwareAttentionHead
output_attention_head = attention_head(x_input, x_input, x_input)
print("RelationAwareAttentionHead Output Shape:", output_attention_head.shape)

# Test RelationAwareMultiHeadAttention
output_multihead_attention = multihead_attention(x_input, x_input, x_input)
print("RelationAwareMultiHeadAttention Output Shape:", output_multihead_attention.shape)


Output:

Relative Position Embeddings Shape: torch.Size([20, 20, 64])
RelationAwareAttentionHead Output Shape: torch.Size([16, 20, 64])
RelationAwareMultiHeadAttention Output Shape: torch.Size([16, 20, 768])
