In [2]:
import torch

##### Step 1
Create an embedding table for relative positions.
##### Step 2
Compute relative distances between query and key positions.
##### Step 3
Clip distances to be within `[-max_relative_position, max_relative_position]`.
##### Step 4
Convert to indices for lookup in the embedding table.
##### Step 5
Retrieve embeddings for each relative position.
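
For a single query/key pair, steps 2 to 4 reduce to a few lines of plain Python. The sketch below is only an illustration: `relative_index` is a hypothetical helper name, and it assumes the same sign convention as the notebook's code cells (query position minus key position).

```python
def relative_index(q_pos: int, k_pos: int, max_relative_position: int) -> int:
    """Map one (query, key) position pair to a row index of the embedding table."""
    # Step 2: relative distance (assumed convention: query minus key).
    distance = q_pos - k_pos
    # Step 3: clip to [-max_relative_position, max_relative_position].
    clipped = max(-max_relative_position, min(max_relative_position, distance))
    # Step 4: shift so the smallest distance maps to row 0.
    return clipped + max_relative_position


print(relative_index(0, 2, 2))  # distance -2 -> index 0
print(relative_index(2, 0, 2))  # distance +2 -> index 4
print(relative_index(0, 7, 2))  # distance -7 is clipped to -2 -> index 0
```

The rest of the notebook does the same computation for all pairs at once using tensor broadcasting.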

In [3]:
import torch.nn as nn

- `num_units`: Embedding size (e.g., `num_units=8` means each relative position is encoded as an 8-dimensional vector).
- `max_relative_position`: Maximum relative distance (e.g., `max_relative_position=4` means relative positions range from -4 to +4).

In [38]:
num_units = 8
max_relative_position = 2
length_q = 3
length_k = 3

- `embeddings_table`: A trainable matrix of shape `(2 * max_relative_position + 1, num_units)`.

- Each row represents one relative position, from `-max_relative_position` to `+max_relative_position`.
- Example: If `max_relative_position=4` and `num_units=8`, the shape of `embeddings_table` will be `(9, 8)` because positions range from -4 to +4. With the values set above (`max_relative_position=2`, `num_units=8`), the shape is `(5, 8)`.


In [39]:
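# torch.Tensor(rows, cols) allocates the table without initializing it,
# so the values printed below are arbitrary until an init function is applied.
embeddings_table = nn.Parameter(torch.Tensor(max_relative_position * 2 + 1, num_units))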
embeddings_table

Parameter containing:
tensor([[4.2039e-45, 7.0065e-45, 4.9739e+22, 8.4218e-43, 5.9271e+22, 8.4218e-43,
         8.5149e+22, 8.4218e-43],
        [0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 1.4013e-45, 0.0000e+00,
         0.0000e+00, 0.0000e+00],
        [0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 0.0000e+00],
        [0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 0.0000e+00],
        [2.1019e-44, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 0.0000e+00]], requires_grad=True)

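- `nn.init.xavier_normal_()`: Fills the embedding table in place with Xavier (Glorot) initialization, drawing values from a zero-mean normal distribution whose standard deviation depends on the table's dimensions. This helps keep training stable.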

In [40]:
nn.init.xavier_normal_(embeddings_table)

Parameter containing:
tensor([[ 0.1740, -0.1894,  0.1764,  0.4095, -0.6246,  0.4744, -0.2800,  0.2685],
        [-0.1422, -0.5910,  0.1839, -0.2461,  0.1772, -0.2254, -0.6317, -0.3293],
        [ 0.2137,  0.5104,  0.2616, -0.5953,  0.2767, -0.1704,  0.1310, -0.2576],
        [-0.7897, -1.3042, -0.2933,  0.1419,  0.5016, -0.4019,  0.9015,  0.0623],
        [-0.3086, -0.3942,  0.3056,  0.0655, -0.4996, -0.2155, -0.0900,  0.7124]],
       requires_grad=True)
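
To make the scaling concrete: for a 2-D tensor of shape `(rows, cols)`, `xavier_normal_` with the default gain of 1 uses a standard deviation of `sqrt(2 / (rows + cols))`. The sketch below checks this empirically on a larger table. The table sizes and the seed are arbitrary choices made for this illustration, and `torch.empty` is used as the explicit way to allocate an uninitialized tensor.

```python
import math

import torch
import torch.nn as nn

torch.manual_seed(0)  # only to make this sketch repeatable

# Small table with the notebook's shape: (2 * 2 + 1, 8) = (5, 8).
small = nn.Parameter(torch.empty(5, 8))  # explicitly uninitialized
nn.init.xavier_normal_(small)            # fills in place; still trainable

# Larger table so the sample standard deviation is a meaningful estimate.
rows, cols = 400, 600
large = torch.empty(rows, cols)
nn.init.xavier_normal_(large)

expected_std = math.sqrt(2.0 / (rows + cols))  # default gain is 1.0
print(f"expected std: {expected_std:.4f}, sample std: {large.std().item():.4f}")
```

With only 40 values, the 5×8 table in the notebook is too small for its sample standard deviation to match the formula closely, which is why a larger table is used for the check.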

- `range_vec_q` is the tensor `[0, 1, 2, ..., length_q-1]`.
- `range_vec_k` is the tensor `[0, 1, 2, ..., length_k-1]`.

In [41]:
range_vec_q=torch.arange(length_q)
range_vec_k=torch.arange(length_k)
print(range_vec_k)

tensor([0, 1, 2])


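- Computes a **distance matrix** where each element is the difference between a query position and a key position.
- With `length_q=3` and `length_k=3`, the result is a 3×3 matrix.

In the cells that follow, `j` holds the query positions as a column and `i` holds the key positions as a row, so the entry at row `q`, column `k` of `distance_mat` is `j - i`, i.e. the query position minus the key position.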

In [42]:
i = range_vec_k[None, :]  # key positions as a row, shape (1, length_k)
i

tensor([[0, 1, 2]])

In [43]:
j = range_vec_q[:, None]  # query positions as a column, shape (length_q, 1)
j

tensor([[0],
        [1],
        [2]])

In [44]:
distance_mat = j - i  # broadcasts to (length_q, length_k): query minus key
distance_mat

tensor([[ 0, -1, -2],
        [ 1,  0, -1],
        [ 2,  1,  0]])
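
Because the query and key lengths are equal in this notebook, it is easy to miss which axis is which. The sketch below repeats the broadcasting step with different lengths (2 queries and 4 keys, values chosen only for illustration) so that the `(length_q, length_k)` layout is visible. The names `q_pos` and `k_pos` are introduced here for readability and do not appear in the notebook.

```python
import torch

length_q, length_k = 2, 4

q_pos = torch.arange(length_q)[:, None]  # shape (2, 1): one row per query
k_pos = torch.arange(length_k)[None, :]  # shape (1, 4): one column per key

# Broadcasting expands both operands to (2, 4) before subtracting.
distance = q_pos - k_pos
print(distance.shape)
print(distance)
```

Row `q` of the result lists the distance from query `q` to every key, so the first row is `[0, -1, -2, -3]` and the second is `[1, 0, -1, -2]`.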

- Clips the distance values so that they stay within the range `[-max_relative_position, max_relative_position]`.
- With `max_relative_position=2` and sequences of length 3, every distance already lies in `[-2, 2]`, so clipping leaves `distance_mat` unchanged in this example.

- If a distance falls outside this range, it is **clamped** to `±max_relative_position`.

In [45]:
distance_mat_clipped = torch.clamp(distance_mat, -max_relative_position, max_relative_position)
distance_mat_clipped

tensor([[ 0, -1, -2],
        [ 1,  0, -1],
        [ 2,  1,  0]])
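
With sequences of length 3 and `max_relative_position=2`, no distance is large enough to be clipped. The sketch below uses a longer sequence (length 5, chosen only for illustration) so that the effect of `torch.clamp` is actually visible.

```python
import torch

max_relative_position = 2
length = 5

positions = torch.arange(length)
# Same convention as the notebook: query position minus key position.
distance = positions[:, None] - positions[None, :]
clipped = torch.clamp(distance, -max_relative_position, max_relative_position)

print(distance[0])  # distances from query 0 to each key
print(clipped[0])   # the same row after clipping
```

In the first row, the raw distances `[0, -1, -2, -3, -4]` become `[0, -1, -2, -2, -2]`: every key more than two positions away is treated as if it were exactly two positions away.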

- **Shifts all values to be non-negative** so they can be used as indices in the embedding table.
- With `max_relative_position=2`, this adds 2 to every entry.

- This ensures the smallest relative position (`-2`) maps to index `0`, and the largest (`+2`) maps to index `4`.


In [48]:
final_mat = distance_mat_clipped + max_relative_position
final_mat

tensor([[2, 1, 0],
        [3, 2, 1],
        [4, 3, 2]])

- Uses `final_mat` as indices to select the corresponding rows of `embeddings_table`.
- The result has shape `(length_q, length_k, num_units)`, which is `(3, 3, 8)` here.

In [54]:
embeddings_table[final_mat]

tensor([[[ 0.2137,  0.5104,  0.2616, -0.5953,  0.2767, -0.1704,  0.1310,
          -0.2576],
         [-0.1422, -0.5910,  0.1839, -0.2461,  0.1772, -0.2254, -0.6317,
          -0.3293],
         [ 0.1740, -0.1894,  0.1764,  0.4095, -0.6246,  0.4744, -0.2800,
           0.2685]],

        [[-0.7897, -1.3042, -0.2933,  0.1419,  0.5016, -0.4019,  0.9015,
           0.0623],
         [ 0.2137,  0.5104,  0.2616, -0.5953,  0.2767, -0.1704,  0.1310,
          -0.2576],
         [-0.1422, -0.5910,  0.1839, -0.2461,  0.1772, -0.2254, -0.6317,
          -0.3293]],

        [[-0.3086, -0.3942,  0.3056,  0.0655, -0.4996, -0.2155, -0.0900,
           0.7124],
         [-0.7897, -1.3042, -0.2933,  0.1419,  0.5016, -0.4019,  0.9015,
           0.0623],
         [ 0.2137,  0.5104,  0.2616, -0.5953,  0.2767, -0.1704,  0.1310,
          -0.2576]]], grad_fn=<IndexBackward0>)
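
The steps above can be collected into a small module. This is a minimal sketch rather than a reference implementation: the class name `RelativePosition` and its method signature are assumptions made for this example, and it keeps the notebook's sign convention (query position minus key position) and its `xavier_normal_` initialization.

```python
import torch
import torch.nn as nn


class RelativePosition(nn.Module):
    """Hypothetical wrapper around the five steps walked through above."""

    def __init__(self, num_units: int, max_relative_position: int):
        super().__init__()
        self.max_relative_position = max_relative_position
        # Step 1: one trainable row per relative position in [-max, +max].
        self.embeddings_table = nn.Parameter(
            torch.empty(2 * max_relative_position + 1, num_units)
        )
        nn.init.xavier_normal_(self.embeddings_table)

    def forward(self, length_q: int, length_k: int) -> torch.Tensor:
        device = self.embeddings_table.device
        range_vec_q = torch.arange(length_q, device=device)
        range_vec_k = torch.arange(length_k, device=device)
        # Step 2: pairwise distances, shape (length_q, length_k).
        distance_mat = range_vec_q[:, None] - range_vec_k[None, :]
        # Step 3: clip to the supported range.
        clipped = torch.clamp(
            distance_mat, -self.max_relative_position, self.max_relative_position
        )
        # Step 4: shift to non-negative row indices.
        final_mat = clipped + self.max_relative_position
        # Step 5: look up one embedding per (query, key) pair.
        return self.embeddings_table[final_mat]


rel = RelativePosition(num_units=8, max_relative_position=2)
out = rel(3, 3)
print(out.shape)  # torch.Size([3, 3, 8])
```

Pairs with the same relative distance share the same table row, so `out[0, 0]`, `out[1, 1]`, and `out[2, 2]` are identical vectors, just as the diagonal entries are in the output above.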