# Architecture

In [2]:
import torch
import torch.nn as nn

### RMSNorm
- Standard LayerNorm does two things - centers the data then scales by the standard deviation.  
- RMSNorm simplifies this by skipping the mean centering and normalizing only by the root mean square of the values.
- It is cheaper in terms of compute and just as effective empirically.  
$$\text{RMS}(x) = \sqrt{\frac{1}{n}\sum_{i=1}^{n} x_i^2}$$

In [3]:
class RMSNorm(nn.Module):
    """Root-mean-square normalization over the last dimension.

    Each vector along the last axis is multiplied by
    ``rsqrt(mean(x**2) + eps)`` and then by a learnable per-feature
    ``scale``. No mean-centering is performed. An optional learnable
    ``shift`` is added afterwards.

    Args:
        emb_dim: Size of the last (feature) dimension of the input.
        eps: Constant added to the mean square before the reciprocal
            square root is taken.
        bias: If True, create a learnable ``shift`` initialized to zeros;
            otherwise ``shift`` is None and nothing is added.
        qwen3_compatible: If True, upcast the input to float32 before
            normalizing. The result is cast back to the input dtype in
            either case.
    """

    def __init__(self, emb_dim, eps=1e-6, bias=False, qwen3_compatible=True):
        super().__init__()
        self.eps = eps
        self.qwen3_compatible = qwen3_compatible
        # Per-feature gain, initialized to ones.
        self.scale = nn.Parameter(torch.ones(emb_dim))
        # Optional per-feature offset, initialized to zeros.
        self.shift = nn.Parameter(torch.zeros(emb_dim)) if bias else None

    def forward(self, x):
        """Normalize ``x`` along its last dimension.

        Args:
            x: Input tensor whose last dimension has size ``emb_dim``.

        Returns:
            Tensor with the same shape and dtype as ``x``.
        """
        input_dtype = x.dtype

        if self.qwen3_compatible:
            # Do the reduction in float32 regardless of the input precision.
            x = x.to(torch.float32)

        # Mean of squares over the feature axis (no mean subtraction).
        variance = x.pow(2).mean(dim=-1, keepdim=True)
        norm_x = x * torch.rsqrt(variance + self.eps)
        norm_x = norm_x * self.scale

        if self.shift is not None:
            norm_x = norm_x + self.shift

        # Hand the result back in the caller's original precision.
        return norm_x.to(input_dtype)


### ROPE
- Transformers process tokens in parallel, so they have no inherent information about token positions.
- Older GPT models address this with learned absolute position embeddings, but that position information tends to degrade as it passes through the layers.
- RoPE encodes the absolute position of a token as a rotation applied to its query and key vectors. When the dot product between Q and K is computed, the result naturally depends on the relative distance between the two tokens.
<div align="center">
   <img src="./assets/rope.png" alt="RoPE Illustration" width="350px"/><br>
   <sub><b>Figure:</b> A gentle introduction to Rotary Position Embedding (<a href="https://krasserm.github.io/2022/12/13/rotary-position-embedding/">image source</a>)</sub>
</div>

In [4]:
def compute_rope_params(head_dim, theta_base=10_000, context_length=4096, dtype=torch.float32):
    """Precompute the cosine and sine tables used by rotary position embedding.

    Args:
        head_dim: Per-head embedding size; must be even.
        theta_base: Base of the geometric frequency progression.
        context_length: Number of positions (rows) to precompute.
        dtype: dtype used to build the index and position ranges. The
            frequency exponents are always converted to float32.

    Returns:
        Tuple ``(cos, sin)``, each of shape ``(context_length, head_dim)``.
    """
    assert head_dim % 2 == 0, "Embedding dim must be even"

    half_dim = head_dim // 2

    # One inverse frequency per pair of dimensions: theta_base^(-2i / head_dim)
    even_indices = torch.arange(0, head_dim, 2, dtype=dtype)[:half_dim]
    exponents = even_indices.float() / head_dim
    inv_freq = 1.0 / (theta_base ** exponents)

    # Outer product position x frequency -> (context_length, head_dim // 2)
    token_positions = torch.arange(context_length, dtype=dtype)
    half_angles = token_positions[:, None] * inv_freq[None, :]

    # Repeat the angles so both halves of the head share them
    # -> (context_length, head_dim)
    full_angles = torch.cat([half_angles, half_angles], dim=1)

    return torch.cos(full_angles), torch.sin(full_angles)


def apply_rope(x, cos, sin):
    """Rotate ``x`` by the precomputed rotary-embedding angles.

    Args:
        x: Tensor of shape ``(batch_size, num_heads, seq_len, head_dim)``
            with an even ``head_dim``.
        cos: Cosine table; only its first ``seq_len`` rows are used.
        sin: Sine table; only its first ``seq_len`` rows are used.
            Both tables are expected to look like the output of
            ``compute_rope_params``, i.e. ``(context_length, head_dim)``.

    Returns:
        Tensor with the same shape and dtype as ``x``.
    """
    # Unpacking four values requires a 4-D input.
    _, _, seq_len, head_dim = x.shape
    assert head_dim % 2 == 0, "Head dim must be even"

    half = head_dim // 2
    front_half = x[..., :half]
    back_half = x[..., half:]

    # Trim the tables to the sequence length and add leading batch and head
    # axes -> (1, 1, seq_len, head_dim) so they broadcast against x.
    cos_table = cos[:seq_len, :][None, None]
    sin_table = sin[:seq_len, :][None, None]

    # Companion vector (-x2, x1) for the 2-D rotation of each (x1, x2) pair.
    swapped = torch.cat((-back_half, front_half), dim=-1)
    result = x * cos_table + swapped * sin_table

    # Casting back to the input precision after the rotation is fine.
    return result.to(dtype=x.dtype)
