In [2]:
import torch

In [3]:
# Prefer Apple's Metal Performance Shaders backend when PyTorch reports it
# as available; otherwise fall back to the CPU.
if torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

### Rotary Position Embedding

In [None]:
import torch
import torch.nn as nn

class RotaryPositionEmbedding(nn.Module):
    """Rotary position embedding (RoPE) backed by precomputed cos/sin tables.

    Channel pairs ``(x[..., 2i], x[..., 2i + 1])`` are rotated by the angle
    ``pos * base ** (-2i / dim)``. The output is laid out as two concatenated
    halves -- all rotated first components, then all rotated second
    components -- rather than being re-interleaved.

    Args:
        dim: Size of the last (feature) axis of the tensors to rotate.
            Must be even because channels are rotated in pairs.
        max_seq_len: Number of positions for which tables are precomputed.
        base: Base of the geometric frequency progression.

    Raises:
        ValueError: If ``dim`` is odd.
    """

    def __init__(self, dim: int, max_seq_len: int, base: float = 10000.0):
        super().__init__()
        if dim % 2 != 0:
            raise ValueError(f"dim must be even to form channel pairs, got {dim}")

        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        pos = torch.arange(max_seq_len).float()
        # Outer product: freqs[p, i] = p * inv_freq[i], shape (max_seq_len, dim // 2).
        freqs = torch.einsum("i,j->ij", pos, inv_freq)

        # Buffers (not plain attributes) so the tables follow the module through
        # .to()/.cuda()/etc.; non-persistent so the state_dict stays unchanged.
        self.register_buffer("cos", torch.cos(freqs), persistent=False)
        self.register_buffer("sin", torch.sin(freqs), persistent=False)

    def apply_rotary(self, x: torch.Tensor, seq_len: int) -> torch.Tensor:
        """Rotate the channel pairs of ``x`` for positions ``0 .. seq_len - 1``.

        Args:
            x: Tensor whose last axis has size ``dim``. The tables are sliced
                to ``(1, seq_len, dim // 2)`` and broadcast against the
                even/odd channel views, so the second-to-last axis of ``x``
                is treated as the position axis.
            seq_len: Number of leading table rows (positions) to use.

        Returns:
            The rotated tensor, with the rotated first components of every
            pair in the first half of the last axis and the rotated second
            components in the second half.

        Raises:
            ValueError: If ``seq_len`` exceeds the precomputed table length or
                the last axis of ``x`` does not match ``dim``.
        """
        max_seq_len, half_dim = self.cos.shape
        if seq_len > max_seq_len:
            # Slicing would silently truncate the table (and a 1-row table
            # would broadcast position 0 to every token).
            raise ValueError(
                f"seq_len {seq_len} exceeds max_seq_len {max_seq_len}"
            )
        if x.size(-1) != 2 * half_dim:
            raise ValueError(
                f"last dimension of x must be {2 * half_dim}, got {x.size(-1)}"
            )

        x_even = x[..., ::2]
        x_odd = x[..., 1::2]
        # .to() is a no-op when the module already lives on x's device; kept so
        # callers that never moved the module keep working.
        cos = self.cos[:seq_len].unsqueeze(0).to(x.device)
        sin = self.sin[:seq_len].unsqueeze(0).to(x.device)

        return torch.cat(
            [x_even * cos - x_odd * sin, x_even * sin + x_odd * cos], dim=-1
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the rotation using ``x.size(1)`` as the sequence length.

        Intended for inputs shaped ``(batch, seq_len, dim)``.
        """
        return self.apply_rotary(x, x.size(1))