<a href="https://colab.research.google.com/github/Hadrien-Cornier/cool-nn-stuff/blob/main/rotary_positional_embeddings.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

In [31]:
def build_rotation_matrix(dim, θ):
    rotation_matrix = torch.zeros((dim, dim))
    for i in range(dim // 2):
        rotation_matrix[2 * i, 2 * i] = torch.cos(θ)
        rotation_matrix[2 * i, 2 * i + 1] = -torch.sin(θ)
        rotation_matrix[2 * i + 1, 2 * i] = torch.sin(θ)
        rotation_matrix[2 * i + 1, 2 * i + 1] = torch.cos(θ)
    return rotation_matrix

build_rotation_matrix(4, torch.tensor(math.pi/2))

tensor([[-4.3711e-08, -1.0000e+00,  0.0000e+00,  0.0000e+00],
        [ 1.0000e+00, -4.3711e-08,  0.0000e+00,  0.0000e+00],
        [ 0.0000e+00,  0.0000e+00, -4.3711e-08, -1.0000e+00],
        [ 0.0000e+00,  0.0000e+00,  1.0000e+00, -4.3711e-08]])

In [32]:
(build_rotation_matrix(4, torch.tensor(math.pi/4))) @ torch.tensor([[1.0, 0.0, 0.0, 0.0]]).transpose(0, 1)

tensor([[0.7071],
        [0.7071],
        [0.0000],
        [0.0000]])

In [33]:
(build_rotation_matrix(4, torch.tensor(math.pi/4)))**2 @ torch.tensor([[1.0, 0.0, 0.0, 0.0]]).transpose(0, 1)

tensor([[0.5000],
        [0.5000],
        [0.0000],
        [0.0000]])

In [25]:
class RotaryPositionalEmbeddings(nn.Module):
    def __init__(self, dim, base=10000):
        super().__init__()
        self.dim = dim
        self.base = base

    def _build_rotation_matrix(self, θ):
        rotation_matrix = torch.zeros((self.dim, self.dim))
        for i in range(self.dim // 2):
            rotation_matrix[2 * i, 2 * i] = torch.cos(θ)
            rotation_matrix[2 * i, 2 * i + 1] = -torch.sin(θ)
            rotation_matrix[2 * i + 1, 2 * i] = torch.sin(θ)
            rotation_matrix[2 * i + 1, 2 * i + 1] = torch.cos(θ)
        return rotation_matrix

    def forward(self, x):
        seq_len = x.shape[1]
        theta = 1. / (self.base ** (torch.arange(0, self.dim, 2).float() / self.dim))
        rotation_matrix = self._build_rotation_matrix(theta)
        x_rotated = torch.einsum('ijk,kl->ijl', x, rotation_matrix)
        return x_rotated

In [26]:
x = torch.randn(1, 10, 512)
rope = RotaryPositionalEmbeddings(512)
output = rope(x)
print(output.shape)

RuntimeError: expand(torch.FloatTensor{[256]}, size=[]): the number of sizes provided (0) must be greater or equal to the number of dimensions in the tensor (1)

In [None]:
class RotaryPEMultiHeadAttention(nn.Module):
    def __init__(self, heads, d_model, rope_percentage=0.5, dropout_prob=0.0):
        super().__init__()
        self.heads = heads
        self.d_model = d_model
        self.rope_percentage = rope_percentage
        self.dropout_prob = dropout_prob
        self.query_rotary_pe = RotaryPositionalEmbeddings(int(d_model * rope_percentage))
        self.key_rotary_pe = RotaryPositionalEmbeddings(int(d_model * rope_percentage))

    def get_scores(self, query, key):
        return torch.einsum('ibhd,jbhd->ijbh', self.query_rotary_pe(query), self.key_rotary_pe(key))

In [None]:
multi_head_attention = RotaryPEMultiHeadAttention(8, 512)
query = torch.randn(1, 10, 512)
key = torch.randn(1, 10, 512)
scores = multi_head_attention.get_scores(query, key)
print(scores.shape)