In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from dataclasses import dataclass

@dataclass
class ModelConfig:
    """Bundle of model hyperparameters; ``ModelConfig()`` yields the defaults.

    Every field has a default, so the class can be built with no arguments or
    with selective keyword overrides. Field order is part of the generated
    positional ``__init__`` signature and must not be rearranged.
    """

    # --- sizes ---
    d_model: int = 1_024
    n_heads: int = 16          # 1024 / 16 = 64 at the defaults (presumably the per-head width)
    d_ff: int = 2_816          # 2.75 * d_model at the defaults
    vocab_size: int = 32_000

    # --- depth ---
    num_encoder_layers: int = 6
    num_decoder_layers: int = 3

    # --- misc ---
    rope_theta: float = 1e4    # presumably the RoPE frequency base; equals RoPE's default `base`
    dropout: float = 0.0       # 0.0 by default (presumably a dropout probability; verify at use site)



In [5]:
class RoPE:
    """Rotary position embedding (RoPE) using the split-half pairing.

    Element ``i`` of the first half of the last axis is paired with element
    ``i`` of the second half, and each pair is rotated by the angle
    ``position * base ** (-i / (dim // 2))``.

    Args:
        dim: Size of the last axis of the tensors that will be rotated.
            Must be a positive even integer.
        base: Base of the geometric progression of inverse frequencies.

    Raises:
        ValueError: If ``dim`` is not a positive even integer.
    """

    def __init__(self, dim: int, base: float = 10000):
        # An odd size cannot be split into two equal halves; without this check
        # the failure only shows up later as a broadcast error or, worse, as an
        # output with the wrong last-axis size.
        if dim <= 0 or dim % 2 != 0:
            raise ValueError(f"RoPE dim must be a positive even integer, got {dim}")

        half_dim = dim // 2
        # freq[i] = base ** (-i / half_dim), evaluated as exp(-i * ln(base) / half_dim).
        log_base = torch.log(torch.tensor(base, dtype=torch.float32))
        freq = torch.exp(-torch.arange(half_dim, dtype=torch.float32) * (log_base / half_dim))
        self.freq = freq  # (dim/2,), float32, created on the default device

    def get_rotary(self, seq_len: int, device: torch.device):
        """Build the cos/sin tables for positions ``0 .. seq_len - 1``.

        Args:
            seq_len: Number of positions to generate.
            device: Device on which the tables are created.

        Returns:
            Tuple ``(cos, sin)`` of float32 tensors, each of shape
            ``(seq_len, dim // 2)``.
        """
        # position = [0, 1, 2, ..., seq_len-1]
        positions = torch.arange(seq_len, dtype=torch.float32, device=device)
        # angles[p, i] = p * freq[i]
        angles = torch.outer(positions, self.freq.to(device))
        return torch.cos(angles), torch.sin(angles)

    def apply_rotary(self, x, cos, sin):
        """Rotate ``x`` with the given cos/sin tables.

        Args:
            x: Tensor of shape ``(batch, seq_len, heads, head_dim)`` with an
                even ``head_dim``.
            cos: Table of shape ``(seq_len, head_dim // 2)`` as returned by
                ``get_rotary``.
            sin: Same shape as ``cos``.

        Returns:
            Tensor with the same shape and dtype as ``x``.

        Raises:
            ValueError: If ``x`` is not 4-dimensional or ``head_dim`` is odd.
        """
        if x.dim() != 4:
            raise ValueError(
                f"expected x of shape (batch, seq_len, heads, head_dim), got {tuple(x.shape)}"
            )
        head_dim = x.shape[-1]
        if head_dim % 2 != 0:
            raise ValueError(f"head_dim must be even, got {head_dim}")
        half = head_dim // 2

        # Insert batch and head axes so the tables broadcast against x:
        # (seq_len, half) -> (1, seq_len, 1, half)
        cos = cos.unsqueeze(0).unsqueeze(2)
        sin = sin.unsqueeze(0).unsqueeze(2)

        x1 = x[..., :half]
        x2 = x[..., half:]

        x_rotated = torch.cat(
            [
                x1 * cos - x2 * sin,
                x2 * cos + x1 * sin,
            ],
            dim=-1,
        )
        # The float32 tables promote half-precision inputs to float32; cast back
        # so the result keeps the caller's dtype (a no-op for float32/float64 x).
        return x_rotated.to(x.dtype)



In [6]:
cfg = ModelConfig()

# Take the rotary size and frequency base from the config instead of
# hard-coding them (d_model // n_heads == 1024 // 16 == 64 at the defaults).
head_dim = cfg.d_model // cfg.n_heads
rope = RoPE(dim=head_dim, base=cfg.rope_theta)

# A local generator makes the input reproducible without touching the global RNG.
x = torch.randn(
    2, 5, 4, head_dim, generator=torch.Generator().manual_seed(0)
)   # (batch, seq_len, heads, head_dim)
cos, sin = rope.get_rotary(seq_len=x.shape[1], device=x.device)

y = rope.apply_rotary(x, cos, sin)

# Sanity checks: shape is preserved, position 0 (angle 0) is untouched, and the
# rotation keeps every vector's norm.
assert y.shape == x.shape
assert torch.equal(y[:, 0], x[:, 0])
assert torch.allclose(y.norm(dim=-1), x.norm(dim=-1), rtol=1e-4, atol=1e-5)

print("Input shape :", x.shape)
print("Output shape:", y.shape)


Input shape : torch.Size([2, 5, 4, 64])
Output shape: torch.Size([2, 5, 4, 64])
