In [1]:
import torch
import torch.nn as nn
import numpy as np
import math

In [2]:
a = torch.tensor(
    [
        [[2, 3, 4, 3, 1, 6], [1, 1, 1, 1, 1, 1], [0, -4, 18, 0, 0, 0], [5, 6, 7, 1, 4, 8]],
        [[1, 2, 55, 11, 34, 88], [5, 34, 13, 22, 12, 2], [0, 0, 0, 9, 98, 22], [-10, -6, 7, 23, 2, 1]],
    ],
    dtype=torch.get_default_dtype(),  # same float dtype the legacy torch.Tensor(...) constructor uses
)
a.size()  # toy input laid out as [batch_size, seq_len, emb_size]

torch.Size([2, 4, 6])

## Sinusoidal PE

In [4]:
# Proposed in "Attention Is All You Need" (Vaswani et al., 2017).
# Applied once, to the input embeddings before the first self-attention layer: x = x_input + sinusoidal_pe

In [28]:
class SinusoidalPositionalEncoding:
  """Fixed sine/cosine position table from "Attention Is All You Need" (2017).

  Column 2k holds sin(pos * w_k) and column 2k + 1 holds cos(pos * w_k), where
  w_k = 10000 ** (-2k / emb_size).

  Args:
    seq_len: number of positions (rows of the table).
    emb_size: embedding width (columns of the table); may be odd.
  """

  def __init__(self, seq_len, emb_size):
    self.seq_len = seq_len
    self.emb_size = emb_size

  def __call__(self):
    """Return the table with shape [1, seq_len, emb_size], ready to broadcast over a batch."""
    pos_emb = torch.zeros(self.seq_len, self.emb_size)
    pos = torch.arange(0, self.seq_len, dtype=torch.float32).unsqueeze(-1)  # [seq_len, 1]
    # One frequency per sin column: ceil(emb_size / 2) of them.
    denom = torch.exp(torch.arange(0, self.emb_size, 2).float() * (-math.log(10_000.0) / self.emb_size))
    angles = pos * denom  # [seq_len, ceil(emb_size / 2)]
    pos_emb[:, 0::2] = torch.sin(angles)
    # An odd emb_size has one fewer cos column than sin column; drop the last
    # frequency instead of failing with a shape mismatch.
    pos_emb[:, 1::2] = torch.cos(angles[:, : self.emb_size // 2])
    return pos_emb.unsqueeze(0)  # [1, seq_len, emb_size]

In [29]:
# Table sized to `a` ([batch_size, seq_len, emb_size]); its leading dim of 1 broadcasts over the batch.
sin_pe = SinusoidalPositionalEncoding(seq_len=a.shape[1], emb_size=a.shape[2])
out = sin_pe()
(out, out.shape, a + out, (a + out).shape)

(tensor([[[ 0.0000,  1.0000,  0.0000,  1.0000,  0.0000,  1.0000],
          [ 0.8415,  0.5403,  0.0464,  0.9989,  0.0022,  1.0000],
          [ 0.9093, -0.4161,  0.0927,  0.9957,  0.0043,  1.0000],
          [ 0.1411, -0.9900,  0.1388,  0.9903,  0.0065,  1.0000]]]),
 torch.Size([1, 4, 6]),
 tensor([[[ 2.0000e+00,  4.0000e+00,  4.0000e+00,  4.0000e+00,  1.0000e+00,
            7.0000e+00],
          [ 1.8415e+00,  1.5403e+00,  1.0464e+00,  1.9989e+00,  1.0022e+00,
            2.0000e+00],
          [ 9.0930e-01, -4.4161e+00,  1.8093e+01,  9.9569e-01,  4.3089e-03,
            9.9999e-01],
          [ 5.1411e+00,  5.0100e+00,  7.1388e+00,  1.9903e+00,  4.0065e+00,
            9.0000e+00]],
 
         [[ 1.0000e+00,  3.0000e+00,  5.5000e+01,  1.2000e+01,  3.4000e+01,
            8.9000e+01],
          [ 5.8415e+00,  3.4540e+01,  1.3046e+01,  2.2999e+01,  1.2002e+01,
            3.0000e+00],
          [ 9.0930e-01, -4.1615e-01,  9.2699e-02,  9.9957e+00,  9.8004e+01,
            2.3000e+01],

## Relative PE

In [None]:
# Applied inside self-attention, while computing the attention scores: attn_scores[i][j] += f(i - j)

In [32]:
class RelativePositionalEncoding:
  """Per-head additive attention bias built from sinusoids of the relative offset i - j.

  Every offset in [-(seq_len - 1), seq_len - 1] gets a num_heads-wide sinusoidal
  vector (sin in even slots, cos in odd slots); head h of the bias reads slot h of
  the vector for offset i - j. The table is fixed (not learned).

  Args:
    seq_len: query/key length the bias is built for.
    num_heads: number of attention heads; may be odd.
  """

  def __init__(self, seq_len, num_heads=8):
    self.seq_len = seq_len
    self.num_heads = num_heads

  def __call__(self):
    """Return bias[h, i, j] = table[i - j][h] with shape [num_heads, seq_len, seq_len]."""
    max_relative_distance = 2 * self.seq_len - 1  # seq_len=4 -> offsets -3..3 -> 7 rows
    relative_positions = torch.arange(-self.seq_len + 1, self.seq_len, dtype=torch.float32).unsqueeze(-1)  # [2*seq_len-1, 1]
    denom = torch.exp(torch.arange(0, self.num_heads, 2).float() * (-math.log(10_000.0) / self.num_heads))
    angles = relative_positions * denom  # [2*seq_len-1, ceil(num_heads/2)]

    pos_emb = torch.zeros(max_relative_distance, self.num_heads)
    pos_emb[:, 0::2] = torch.sin(angles)
    # An odd num_heads has one fewer cos slot than sin slot.
    pos_emb[:, 1::2] = torch.cos(angles[:, : self.num_heads // 2])

    # Row of pos_emb for offset i - j, computed for all (i, j) pairs at once
    # instead of with a Python double loop.
    positions = torch.arange(self.seq_len)
    rel_index = positions[:, None] - positions[None, :] + self.seq_len - 1  # [seq_len, seq_len]
    bias = pos_emb[rel_index].permute(2, 0, 1).contiguous()  # [num_heads, seq_len, seq_len]
    return bias

In [48]:
relative_pe = RelativePositionalEncoding(seq_len=a.size(1))
relative_bias = relative_pe()  # [num_heads, seq_len, seq_len]
# Stand-in for q @ k.T: the per-batch Gram matrix of `a`, copied once per head so it
# matches the bias's head dimension -> [batch_size, num_heads, seq_len, seq_len].
out = (a @ a.transpose(1, 2)).unsqueeze(1).repeat(1, relative_pe.num_heads, 1, 1)
out.shape, (out + relative_bias).shape

(torch.Size([2, 8, 4, 4]), torch.Size([2, 8, 4, 4]))

## Rotary PE

In [None]:
# Applied inside self-attention to Q and K, before the q · k dot product:
# Q_roped = apply_rope(Q)
# K_roped = apply_rope(K)
# attn_scores = Q_roped @ K_roped.T

In [3]:
class RotaryPositionalEmbeddings(nn.Module):
  """Rotary position embeddings (RoPE) applied to per-head query/key vectors.

  Each consecutive pair (x[..., 2k], x[..., 2k + 1]) of the last dimension is
  rotated by the angle pos * theta_k, with theta_k = rope_base ** (-2k / head_dim).
  cos/sin values for positions 0..max_seq_len-1 are precomputed into the
  non-persistent buffer `cache` of shape [max_seq_len, head_dim // 2, 2].

  Args:
    head_dim: size of the last (per-head) dimension of the input; must be even.
    rope_base: base of the geometric frequency progression.
    max_seq_len: number of positions to precompute.

  Raises:
    ValueError: if head_dim is odd (dimensions are rotated in pairs).
  """

  def __init__(self, head_dim, rope_base=10_000, max_seq_len=4096):
    super().__init__()
    if head_dim % 2 != 0:
      raise ValueError(f"head_dim must be even for RoPE, got {head_dim}")
    self.head_dim = head_dim
    self.rope_base = rope_base
    self.max_seq_len = max_seq_len
    self._rope_init()

  def reset_parameters(self):
    """Recompute `theta` and `cache` from the stored hyperparameters."""
    self._rope_init()

  def _rope_init(self):
    # Parenthesised exponent: base ** (2k / d). Without the parentheses `**` binds
    # tighter than `/` and yields (base ** 2k) / d, i.e. wrong frequencies.
    exponents = torch.arange(0, self.head_dim, 2)[: (self.head_dim // 2)].float() / self.head_dim
    theta = 1.0 / (self.rope_base ** exponents)  # [head_dim // 2]
    self.register_buffer('theta', theta, persistent=False)
    self.build_rope_cache(self.max_seq_len)

  def build_rope_cache(self, max_seq_len=4096):
    """Precompute cos/sin of pos * theta for positions 0..max_seq_len-1 into `cache`."""
    # dtype/device come from the theta buffer (rope_base is a plain Python number).
    seq_idx = torch.arange(max_seq_len, dtype=self.theta.dtype, device=self.theta.device)
    # Outer product of positions and frequencies.
    idx_theta = torch.einsum('i, j -> ij', seq_idx, self.theta).float()  # [max_seq_len, head_dim//2]
    cache = torch.stack([torch.cos(idx_theta), torch.sin(idx_theta)], dim=-1)  # [max_seq_len, head_dim//2, 2]
    self.register_buffer('cache', cache, persistent=False)

  def forward(self, x, input_pos=None):
    """Rotate `x` according to its positions.

    Args:
      x: tensor of shape [batch_size, seq_len, num_heads, head_dim].
      input_pos: optional integer tensor of positions used to index `cache`; when
        None, positions 0..seq_len-1 are used. Presumably shaped [batch_size, seq_len]
        or [seq_len] — TODO confirm with callers.

    Returns:
      Tensor with the same shape and dtype as `x`.
    """
    seq_len = x.size(1)
    rope_cache = self.cache[:seq_len] if input_pos is None else self.cache[input_pos]
    # The rotation is computed in float32 regardless of the input dtype.
    x_shaped = x.float().reshape(*x.shape[:-1], -1, 2)  # [batch_size, seq_len, num_heads, head_dim//2, 2]
    # Insert a head axis (size 1) so the cache broadcasts over all heads.
    rope_cache = rope_cache.view(-1, x_shaped.size(1), 1, x_shaped.size(3), 2)

    cos, sin = rope_cache[..., 0], rope_cache[..., 1]
    x_even, x_odd = x_shaped[..., 0], x_shaped[..., 1]
    x_out = torch.stack(
        [
            x_even * cos - x_odd * sin,
            x_odd * cos + x_even * sin,
        ],
        dim=-1,
    )  # [batch_size, seq_len, num_heads, head_dim//2, 2]
    x_out = x_out.flatten(3)  # [batch_size, seq_len, num_heads, head_dim]
    return x_out.type_as(x)

## ALiBi (Attention with Linear Biases)

In [None]:
# Applied inside self-attention, while computing the attention scores:
# attn_scores[i][j] += -slope * |i - j|

In [5]:
class AlibiPositionalBias:
  """ALiBi: a per-head linear penalty on the query-key distance.

  Builds bias[h, i, j] = -slope_h * |i - j|. The distance is symmetric (no causal
  masking is applied here), so keys before and after the query are penalised alike.

  Args:
    num_heads: number of attention heads (>= 1).
    seq_len: query/key length the bias is built for.
    learnable_slopes: if True, wrap the slopes in an nn.Parameter. This class is not
      an nn.Module, so the parameter is not registered anywhere automatically; pass
      `alibi.slopes` to the optimizer explicitly.
  """

  def __init__(self, num_heads, seq_len, learnable_slopes=False):
    self.num_heads = num_heads
    self.seq_len = seq_len

    self.slopes = self._get_slopes(num_heads)  # [num_heads]
    if learnable_slopes:
      self.slopes = nn.Parameter(self.slopes, requires_grad=True)

    pos = torch.arange(seq_len)
    self.rel_dist = (pos[None, :] - pos[:, None]).abs()  # [seq_len, seq_len], integer |i - j|

  def _get_slopes(self, n):
    """Return the n head slopes used by the ALiBi reference implementation.

    For a power-of-two n the slopes are the geometric sequence
    2^(-8/n), 2^(-16/n), ..., 2^(-8). Otherwise the slopes for the largest power of
    two below n are followed by every other slope of the next power of two.

    Raises:
      ValueError: if n < 1.
    """
    if n < 1:
      raise ValueError(f"num_heads must be >= 1, got {n}")

    def get_pow2_slopes(n):
      start = 2.0 ** (-8.0 / n)
      # The first slope is `start` itself (exponent i + 1), not 1.0.
      return torch.Tensor([start ** (i + 1) for i in range(n)])

    if math.log2(n).is_integer():
      return get_pow2_slopes(n)
    closest_power_of2 = 2 ** math.floor(math.log2(n))
    base_slopes = get_pow2_slopes(closest_power_of2)
    extra = self._get_slopes(2 * closest_power_of2)[0::2][:n - closest_power_of2]
    return torch.cat([base_slopes, extra], dim=0)

  def __call__(self):
    """Return the bias tensor of shape [num_heads, seq_len, seq_len]."""
    # [num_heads, 1, 1] * [1, seq_len, seq_len] => [num_heads, seq_len, seq_len]
    bias = -self.slopes[:, None, None] * self.rel_dist[None, :, :]
    return bias

In [6]:
batch_size = 2
num_heads = 4
seq_len = 6

torch.manual_seed(0)  # fixed seed so the random stand-in scores are identical on every re-run
attn_scores = torch.randn(batch_size, num_heads, seq_len, seq_len)  # stand-in for q @ k.T
alibi = AlibiPositionalBias(num_heads, seq_len)
bias = alibi()  # [num_heads, seq_len, seq_len]

# Add a leading batch axis so the per-head bias broadcasts over the batch.
attn_scores_with_alibi = attn_scores + bias.unsqueeze(0)  # [batch_size, num_heads, seq_len, seq_len]
assert attn_scores_with_alibi.shape == (batch_size, num_heads, seq_len, seq_len)