In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import math


def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
    """
    Build the table of unit-magnitude rotation factors used by rotary embeddings.

    Entry ``[p, i]`` equals ``exp(1j * p * theta ** (-2i / dim))``: the rotation
    applied at position ``p`` to the ``i``-th pair of channels.

    Args:
        dim: Embedding (head) dimension; must be even.
        end: Number of positions to precompute (positions ``0 .. end - 1``).
        theta: Base of the geometric frequency progression.

    Returns:
        Complex tensor of shape ``[end, dim // 2]``.

    Raises:
        ValueError: If ``dim`` is odd.
    """
    if dim % 2:
        raise ValueError(f"Dimension {dim} must be even")

    half = dim // 2
    # Exponents 2i / dim for i in [0, dim / 2).
    exponents = torch.arange(0, dim, 2)[:half].float() / dim
    inv_freq = 1.0 / (theta ** exponents)

    # Angle table: row p holds p * inv_freq.
    positions = torch.arange(end, device=inv_freq.device)
    angles = torch.outer(positions, inv_freq)

    # Magnitude 1 everywhere, so every entry lies on the unit circle.
    return torch.polar(torch.ones_like(angles), angles)


def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,
    start_pos: int = 0,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Apply rotary positional embeddings (RoPE) to query and key tensors.

    Consecutive channel pairs ``(x[2i], x[2i+1])`` are viewed as the real and
    imaginary parts of one complex number and multiplied by the unit-magnitude
    factor ``freqs_cis[pos, i]``, i.e. rotated by that position's angle.

    Args:
        xq: Query states of shape [batch_size, seq_len, n_heads, head_dim].
        xk: Key states of shape [batch_size, seq_len, n_kv_heads, head_dim].
            The head count may differ from ``xq`` because the frequency table
            is broadcast over the heads axis.
        freqs_cis: Complex tensor of shape [max_seq_len, head_dim // 2], e.g.
            as produced by ``precompute_freqs_cis``.
        start_pos: Absolute position of the first token in ``xq``/``xk``.
            Defaults to 0 (original behaviour). Pass a non-zero offset when the
            inputs continue an earlier sequence, e.g. during incremental decoding.

    Returns:
        Tuple ``(xq_out, xk_out)`` with the same shapes and dtypes as the inputs.

    Raises:
        ValueError: If the inputs are not 4-D, ``head_dim`` is odd,
            ``start_pos`` is negative, the requested positions run past the end
            of ``freqs_cis``, or ``freqs_cis`` does not have ``head_dim // 2``
            columns.
    """
    # Unpacking enforces the 4-D [batch, seq, heads, head_dim] layout.
    _, seq_len, _, head_dim = xq.shape

    if head_dim % 2 != 0:
        raise ValueError(f"Head dimension {head_dim} must be even")
    if start_pos < 0:
        raise ValueError(f"start_pos must be non-negative, got {start_pos}")

    end_pos = start_pos + seq_len
    # Without this check, slicing past the table silently returns fewer rows and
    # the multiply below fails with an opaque broadcasting error.
    if end_pos > freqs_cis.shape[0]:
        raise ValueError(
            f"Positions up to {end_pos} requested but freqs_cis only covers "
            f"{freqs_cis.shape[0]} positions; increase max_seq_len"
        )
    if freqs_cis.shape[-1] != head_dim // 2:
        raise ValueError(
            f"freqs_cis has {freqs_cis.shape[-1]} columns, expected head_dim // 2 = {head_dim // 2}"
        )

    # Pair up the last dimension and view each pair as one complex number.
    # .float() is needed because view_as_complex supports only float/double.
    xq_complex = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_complex = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))

    # Rows for positions [start_pos, end_pos), with singleton batch and heads
    # axes added: [seq_len, head_dim // 2] -> [1, seq_len, 1, head_dim // 2].
    rotation = freqs_cis[start_pos:end_pos].unsqueeze(0).unsqueeze(2)

    # Complex multiplication performs the rotation; flatten(-2) merges the
    # (real, imag) pairs back into head_dim.
    xq_out = torch.view_as_real(xq_complex * rotation).flatten(-2)
    xk_out = torch.view_as_real(xk_complex * rotation).flatten(-2)

    # Return the rotated tensors with their original dtype.
    return xq_out.type_as(xq), xk_out.type_as(xk)


class RotaryEmbedding(nn.Module):
    """
    ``nn.Module`` wrapper that owns a precomputed RoPE rotation table.

    The table ``freqs_cis`` (complex, shape ``[max_seq_len, dim // 2]``) is
    built once in ``__init__`` via ``precompute_freqs_cis`` and registered as a
    buffer. It therefore moves with the module on ``.to(device)`` and is
    included in ``state_dict`` (the default for ``register_buffer``).

    Args:
        dim: Per-head feature size the rotation is applied to (must be even).
        max_seq_len: Number of positions to precompute.
        base: Frequency base (``theta``) passed to ``precompute_freqs_cis``.
    """

    def __init__(self, dim, max_seq_len=2048, base=10000):
        super().__init__()
        self.dim, self.max_seq_len, self.base = dim, max_seq_len, base
        table = precompute_freqs_cis(dim, max_seq_len, base)
        # NOTE(review): this is a complex-valued buffer; converting the whole
        # module to a real dtype (e.g. .to(torch.bfloat16)) may cast it and drop
        # the imaginary part — confirm before changing the model dtype.
        self.register_buffer("freqs_cis", table)

    def forward(self, q, k):
        """
        Rotate ``q`` and ``k`` by their positions using the cached table.

        Args:
            q: Query tensor, in the layout ``apply_rotary_emb`` expects.
            k: Key tensor, in the layout ``apply_rotary_emb`` expects.

        Returns:
            ``(q, k)`` with rotary embeddings applied.
        """
        rotated_q, rotated_k = apply_rotary_emb(q, k, self.freqs_cis)
        return rotated_q, rotated_k


# Example usage in a self-attention layer
class SelfAttentionWithRoPE(nn.Module):
    """
    Multi-head self-attention with rotary position embeddings (RoPE).

    Queries and keys are rotated by ``RotaryEmbedding`` after the per-head
    reshape; values are left untouched. Attention is plain scaled dot-product
    attention with an optional additive mask and no dropout.

    Args:
        hidden_size: Model width; must be divisible by ``num_heads``.
        num_heads: Number of attention heads (must be positive).
        max_seq_len: Number of positions the rotary table is precomputed for.

    Raises:
        ValueError: If ``num_heads`` is not positive or does not divide
            ``hidden_size``. An odd per-head size is rejected with a
            ``ValueError`` when the rotary table is built.
    """

    def __init__(self, hidden_size, num_heads, max_seq_len=2048):
        super().__init__()
        # Fail fast on an invalid head split. Otherwise the mistake only shows
        # up later as an opaque shape error from .view() in forward(), or as a
        # ZeroDivisionError for num_heads == 0.
        if num_heads <= 0:
            raise ValueError(f"num_heads must be positive, got {num_heads}")
        if hidden_size % num_heads != 0:
            raise ValueError(
                f"hidden_size ({hidden_size}) must be divisible by num_heads ({num_heads})"
            )

        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads

        # Query, key, value and output projections (hidden_size -> hidden_size).
        self.q_proj = nn.Linear(hidden_size, hidden_size)
        self.k_proj = nn.Linear(hidden_size, hidden_size)
        self.v_proj = nn.Linear(hidden_size, hidden_size)
        self.out_proj = nn.Linear(hidden_size, hidden_size)

        # Rotary table sized per head, since each head is rotated independently.
        self.rotary_emb = RotaryEmbedding(self.head_dim, max_seq_len)

    def forward(self, hidden_states: torch.Tensor, attention_mask=None) -> torch.Tensor:
        """
        Run self-attention over ``hidden_states``.

        Args:
            hidden_states: Input of shape [batch_size, seq_len, hidden_size].
            attention_mask: Optional *additive* mask, broadcastable to
                [batch_size, num_heads, seq_len, seq_len]. It is added to the
                scaled scores before softmax, so use 0 for allowed positions
                and a large negative value (e.g. -inf) for blocked ones.

        Returns:
            Tensor of shape [batch_size, seq_len, hidden_size].
        """
        batch_size, seq_length = hidden_states.shape[:2]

        # Project to query, key, value
        q: torch.Tensor = self.q_proj(hidden_states)
        k: torch.Tensor = self.k_proj(hidden_states)
        v: torch.Tensor = self.v_proj(hidden_states)

        # Split into heads: (batch, seq_len, num_heads, head_dim), the layout
        # the rotary embedding expects.
        q = q.view(batch_size, seq_length, self.num_heads, self.head_dim)
        k = k.view(batch_size, seq_length, self.num_heads, self.head_dim)
        v = v.view(batch_size, seq_length, self.num_heads, self.head_dim)

        # Apply rotary embeddings to queries and keys only
        q, k = self.rotary_emb(q, k)

        # Transpose for attention calculation
        q = q.transpose(1, 2)  # (batch, num_heads, seq_len, head_dim)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        # Scaled dot-product scores: (batch, num_heads, seq_len, seq_len)
        scale = 1.0 / math.sqrt(self.head_dim)
        attn_weights = torch.matmul(q, k.transpose(-2, -1)) * scale

        # Additive mask (see docstring)
        if attention_mask is not None:
            attn_weights = attn_weights + attention_mask

        # Softmax over key positions (no dropout is applied)
        attn_weights = F.softmax(attn_weights, dim=-1)

        # Weighted sum of values
        attn_output = torch.matmul(attn_weights, v)

        # Merge heads back: (batch, seq_len, hidden_size)
        attn_output = attn_output.transpose(1, 2).contiguous().view(
            batch_size, seq_length, self.hidden_size
        )

        # Output projection
        attn_output = self.out_proj(attn_output)

        return attn_output


# Example usage
if __name__ == "__main__":
    # Seed first so both the random input and the layer's random weight
    # initialisation are reproducible under Restart & Run All.
    torch.manual_seed(0)

    # Model parameters
    batch_size = 2
    seq_length = 10
    hidden_size = 512
    num_heads = 8

    # Create random input of shape [batch_size, seq_length, hidden_size]
    hidden_states = torch.rand(batch_size, seq_length, hidden_size)

    # Initialize the self-attention layer with RoPE
    self_attn = SelfAttentionWithRoPE(hidden_size, num_heads)

    # Forward pass
    output = self_attn(hidden_states)
    print(f"Input shape: {hidden_states.shape}")
    print(f"Output shape: {output.shape}")

    # Self-attention preserves the [batch, seq, hidden] shape
    assert output.shape == hidden_states.shape

Input shape: torch.Size([2, 10, 512])
Output shape: torch.Size([2, 10, 512])


In [2]:
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
    """
    Return the complex RoPE rotation table ``exp(1j * pos * inv_freq)``.

    NOTE(review): this cell re-defines ``precompute_freqs_cis`` with the same
    logic as the first cell; whichever cell ran last wins. Consider deleting one
    copy to avoid hidden-state drift between them.

    Args:
        dim: Even embedding dimension; the table has ``dim // 2`` columns.
        end: Number of positions (rows), covering ``0 .. end - 1``.
        theta: Base of the inverse-frequency progression.

    Returns:
        Complex tensor of shape ``[end, dim // 2]``.

    Raises:
        ValueError: If ``dim`` is odd.
    """
    if dim % 2 != 0:
        raise ValueError(f"Dimension {dim} must be even")

    channel_pairs = torch.arange(0, dim, 2)[: (dim // 2)]
    inv_freq = 1.0 / (theta ** (channel_pairs.float() / dim))

    # Column of positions times row of frequencies -> [end, dim // 2].
    pos = torch.arange(end, device=inv_freq.device)
    angles = pos[:, None] * inv_freq[None, :]

    unit = torch.ones_like(angles)
    return torch.polar(unit, angles)

In [3]:
# Toy configuration for the step-by-step cells below.
batch_size = 2
seq_length = 10
hidden_size = 512
num_heads = 8

# Seed before drawing random data so Restart & Run All reproduces the input.
torch.manual_seed(0)

# Create random input of shape [batch_size, seq_length, hidden_size]
hidden_states = torch.rand(batch_size, seq_length, hidden_size)

In [4]:
import torch

# torch.polar(abs, angle) builds abs * (cos(angle) + i*sin(angle)).
magnitude = torch.ones(1)  # abs = 1
angle = torch.full((1,), torch.pi / 4)  # angle = π/4

complex_number = torch.polar(magnitude, angle)
print(complex_number)  # Output: tensor([0.7071+0.7071j])


tensor([0.7071+0.7071j])


In [5]:
import torch

# Angles (radians) on a 2x2 grid: 0, π/2, π and 3π/2. Named `freqs` to mirror
# precompute_freqs_cis, although here they are already per-position angles.
freqs = torch.tensor([[0.0, torch.pi / 2], [torch.pi, 3 * torch.pi / 2]])
magnitudes = torch.full_like(freqs, 1.0)  # magnitude 1 -> points on the unit circle

# polar(r, θ) = r·cos θ + i·r·sin θ. The ~1e-8 residues in the printed result
# come from float32 rounding of π, not from the math itself.
freqs_cis = torch.polar(abs=magnitudes, angle=freqs)

print(freqs_cis)


tensor([[ 1.0000e+00+0.0000e+00j, -4.3711e-08+1.0000e+00j],
        [-1.0000e+00-8.7423e-08j,  1.1925e-08-1.0000e+00j]])


In [8]:
def custom_polar(magnitude, angle):
    """
    Hand-rolled counterpart of ``torch.polar`` that returns real/imag pairs.

    Args:
        magnitude: Tensor of magnitudes r.
        angle: Tensor of angles θ in radians, broadcast-compatible with ``magnitude``.

    Returns:
        Real tensor with a new trailing axis of size 2 holding
        ``(r·cos θ, r·sin θ)``. ``torch.view_as_complex`` turns it into complex
        values matching ``torch.polar`` up to floating-point rounding.
    """
    cos_part, sin_part = torch.cos(angle), torch.sin(angle)
    return torch.stack((magnitude * cos_part, magnitude * sin_part), dim=-1)

magnitude = torch.tensor([1.0])  # abs = 1
angle = torch.tensor([torch.pi / 4])  # angle = π/4

complex_number = custom_polar(magnitude, angle)

# custom_polar stacks (real, imag) on a new trailing axis, so a shape-[1]
# input yields a real [1, 2] tensor rather than a complex one.
print(complex_number)  # Output: tensor([[0.7071, 0.7071]])

# view_as_complex reinterprets that trailing size-2 axis as one complex value.
print(torch.view_as_complex(complex_number))  # Output: tensor([0.7071+0.7071j])

tensor([[0.7071, 0.7071]])
tensor([0.7071+0.7071j])


In [10]:
# Walk through precompute_freqs_cis step by step for a 512-wide embedding.
dim = 512
theta = 10000.0

if dim % 2 != 0:
    raise ValueError(f"Dimension {dim} must be even")

# One inverse frequency per channel pair: theta ** (-2i / dim), i = 0 .. dim/2 - 1.
freqs = 1.0 / (
    theta ** (torch.arange(0, dim, 2)[: dim // 2].float() / dim)
)

freqs.shape  # expected: torch.Size([256])

torch.Size([256])

In [19]:
# Same inverse frequencies as `freqs`, but with the base hard-coded to 10000;
# relies on the global `dim` from the dim = 512 cell.
(10000 ** (torch.arange(0, dim, 2)[: dim // 2].float() / dim)).reciprocal()

tensor([1.0000e+00, 9.6466e-01, 9.3057e-01, 8.9769e-01, 8.6596e-01, 8.3536e-01,
        8.0584e-01, 7.7737e-01, 7.4989e-01, 7.2339e-01, 6.9783e-01, 6.7317e-01,
        6.4938e-01, 6.2643e-01, 6.0430e-01, 5.8294e-01, 5.6234e-01, 5.4247e-01,
        5.2330e-01, 5.0481e-01, 4.8697e-01, 4.6976e-01, 4.5316e-01, 4.3714e-01,
        4.2170e-01, 4.0679e-01, 3.9242e-01, 3.7855e-01, 3.6517e-01, 3.5227e-01,
        3.3982e-01, 3.2781e-01, 3.1623e-01, 3.0505e-01, 2.9427e-01, 2.8387e-01,
        2.7384e-01, 2.6416e-01, 2.5483e-01, 2.4582e-01, 2.3714e-01, 2.2876e-01,
        2.2067e-01, 2.1288e-01, 2.0535e-01, 1.9810e-01, 1.9110e-01, 1.8434e-01,
        1.7783e-01, 1.7154e-01, 1.6548e-01, 1.5963e-01, 1.5399e-01, 1.4855e-01,
        1.4330e-01, 1.3824e-01, 1.3335e-01, 1.2864e-01, 1.2409e-01, 1.1971e-01,
        1.1548e-01, 1.1140e-01, 1.0746e-01, 1.0366e-01, 1.0000e-01, 9.6466e-02,
        9.3057e-02, 8.9769e-02, 8.6596e-02, 8.3536e-02, 8.0584e-02, 7.7737e-02,
        7.4989e-02, 7.2339e-02, 6.9783e-

In [17]:
# Display the 256-entry inverse-frequency vector `freqs` from the dim = 512 cell.
# NOTE(review): this cell is In[17] but appears after In[19]; execution order
# differs from document order — verify with Restart & Run All.
freqs

tensor([1.0000e+00, 9.6466e-01, 9.3057e-01, 8.9769e-01, 8.6596e-01, 8.3536e-01,
        8.0584e-01, 7.7737e-01, 7.4989e-01, 7.2339e-01, 6.9783e-01, 6.7317e-01,
        6.4938e-01, 6.2643e-01, 6.0430e-01, 5.8294e-01, 5.6234e-01, 5.4247e-01,
        5.2330e-01, 5.0481e-01, 4.8697e-01, 4.6976e-01, 4.5316e-01, 4.3714e-01,
        4.2170e-01, 4.0679e-01, 3.9242e-01, 3.7855e-01, 3.6517e-01, 3.5227e-01,
        3.3982e-01, 3.2781e-01, 3.1623e-01, 3.0505e-01, 2.9427e-01, 2.8387e-01,
        2.7384e-01, 2.6416e-01, 2.5483e-01, 2.4582e-01, 2.3714e-01, 2.2876e-01,
        2.2067e-01, 2.1288e-01, 2.0535e-01, 1.9810e-01, 1.9110e-01, 1.8434e-01,
        1.7783e-01, 1.7154e-01, 1.6548e-01, 1.5963e-01, 1.5399e-01, 1.4855e-01,
        1.4330e-01, 1.3824e-01, 1.3335e-01, 1.2864e-01, 1.2409e-01, 1.1971e-01,
        1.1548e-01, 1.1140e-01, 1.0746e-01, 1.0366e-01, 1.0000e-01, 9.6466e-02,
        9.3057e-02, 8.9769e-02, 8.6596e-02, 8.3536e-02, 8.0584e-02, 7.7737e-02,
        7.4989e-02, 7.2339e-02, 6.9783e-

In [20]:
seq_length = 10

# Integer position indices 0 .. seq_length - 1 (int64 by default).
t = torch.arange(0, seq_length, 1)

t

tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])

In [22]:
torch.outer(t, freqs).size()  # [seq_length, dim // 2]

torch.Size([10, 256])

In [23]:
v1 = torch.arange(1.0, 5.0)  # [1., 2., 3., 4.]
v2 = torch.arange(1.0, 4.0)  # [1., 2., 3.]
v1.outer(v2)  # result[i, j] = v1[i] * v2[j] -> shape [4, 3]

tensor([[ 1.,  2.,  3.],
        [ 2.,  4.,  6.],
        [ 3.,  6.,  9.],
        [ 4.,  8., 12.]])

In [37]:
torch.outer(v1, v2).transpose(0, 1)  # same values as torch.outer(v2, v1)

tensor([[ 1.,  2.,  3.,  4.],
        [ 2.,  4.,  6.,  8.],
        [ 3.,  6.,  9., 12.]])

In [26]:
(v1, v2)  # 1-D tensors of length 4 and 3

(tensor([1., 2., 3., 4.]), tensor([1., 2., 3.]))

In [38]:
v1[:, None] @ v2[None, :]  # (4, 1) @ (1, 3) -> (4, 3), identical to torch.outer(v1, v2)

tensor([[ 1.,  2.,  3.],
        [ 2.,  4.,  6.],
        [ 3.,  6.,  9.],
        [ 4.,  8., 12.]])

In [36]:
# Transposing a 1-D tensor is a no-op: reversing a single dimension leaves it
# unchanged. `v1.T` shows the same thing but is deprecated for tensors that are
# not 2-D (it emits a warning), so reverse the dimensions explicitly instead.
v1, v1.permute(*range(v1.ndim - 1, -1, -1))

(tensor([1., 2., 3., 4.]), tensor([1., 2., 3., 4.]))

In [29]:
v1[:, None], v2[None, :]  # column vector (4, 1) and row vector (1, 3)

(tensor([[1.],
         [2.],
         [3.],
         [4.]]),
 tensor([[1., 2., 3.]]))

In [30]:
# torch.outer works only on 1-D tensors (it does not broadcast)

In [50]:
torch.ones_like(t[:, None] * freqs[None, :])  # unit magnitudes, same shape as the angle table

tensor([[1., 1., 1.,  ..., 1., 1., 1.],
        [1., 1., 1.,  ..., 1., 1., 1.],
        [1., 1., 1.,  ..., 1., 1., 1.],
        ...,
        [1., 1., 1.,  ..., 1., 1., 1.],
        [1., 1., 1.,  ..., 1., 1., 1.],
        [1., 1., 1.,  ..., 1., 1., 1.]])

In [46]:
t[:, None].shape, freqs[None, :].shape  # (seq_len, 1) and (1, dim // 2) broadcast to the outer product

(torch.Size([10, 1]), torch.Size([1, 256]))

In [47]:
freqs[None, :]  # freqs as a single row: shape (1, dim // 2)

tensor([[1.0000e+00, 9.6466e-01, 9.3057e-01, 8.9769e-01, 8.6596e-01, 8.3536e-01,
         8.0584e-01, 7.7737e-01, 7.4989e-01, 7.2339e-01, 6.9783e-01, 6.7317e-01,
         6.4938e-01, 6.2643e-01, 6.0430e-01, 5.8294e-01, 5.6234e-01, 5.4247e-01,
         5.2330e-01, 5.0481e-01, 4.8697e-01, 4.6976e-01, 4.5316e-01, 4.3714e-01,
         4.2170e-01, 4.0679e-01, 3.9242e-01, 3.7855e-01, 3.6517e-01, 3.5227e-01,
         3.3982e-01, 3.2781e-01, 3.1623e-01, 3.0505e-01, 2.9427e-01, 2.8387e-01,
         2.7384e-01, 2.6416e-01, 2.5483e-01, 2.4582e-01, 2.3714e-01, 2.2876e-01,
         2.2067e-01, 2.1288e-01, 2.0535e-01, 1.9810e-01, 1.9110e-01, 1.8434e-01,
         1.7783e-01, 1.7154e-01, 1.6548e-01, 1.5963e-01, 1.5399e-01, 1.4855e-01,
         1.4330e-01, 1.3824e-01, 1.3335e-01, 1.2864e-01, 1.2409e-01, 1.1971e-01,
         1.1548e-01, 1.1140e-01, 1.0746e-01, 1.0366e-01, 1.0000e-01, 9.6466e-02,
         9.3057e-02, 8.9769e-02, 8.6596e-02, 8.3536e-02, 8.0584e-02, 7.7737e-02,
         7.4989e-02, 7.2339e

In [49]:
t[:, None]  # positions as a column: shape (seq_len, 1)

tensor([[0],
        [1],
        [2],
        [3],
        [4],
        [5],
        [6],
        [7],
        [8],
        [9]])

In [51]:
t.unsqueeze(1) * freqs.unsqueeze(0)  # angle table: row p is p * freqs, same as torch.outer(t, freqs)

tensor([[0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [1.0000e+00, 9.6466e-01, 9.3057e-01,  ..., 1.1140e-04, 1.0746e-04,
         1.0366e-04],
        [2.0000e+00, 1.9293e+00, 1.8611e+00,  ..., 2.2279e-04, 2.1492e-04,
         2.0733e-04],
        ...,
        [7.0000e+00, 6.7526e+00, 6.5140e+00,  ..., 7.7978e-04, 7.5223e-04,
         7.2564e-04],
        [8.0000e+00, 7.7173e+00, 7.4446e+00,  ..., 8.9118e-04, 8.5969e-04,
         8.2931e-04],
        [9.0000e+00, 8.6820e+00, 8.3751e+00,  ..., 1.0026e-03, 9.6715e-04,
         9.3297e-04]])