<a href="https://colab.research.google.com/github/Papa-Panda/llm/blob/main/RoPE.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
# https://gemini.google.com/app/4e18d355af4f8a38

# 【DeepSeek-v2 MLA 原理讲解】 https://www.bilibili.com/video/BV1BYXRYWEMj/?share_source=copy_web&vd_source=985107e9bc8449878c67f709b64e7ad2

In [1]:
import torch
import torch.nn as nn
import numpy as np

class RotaryPositionEmbedding(nn.Module):
    def __init__(self, dim, max_seq_len=2048, theta=10000.0):
        super().__init__()
        self.dim = dim
        self.max_seq_len = max_seq_len
        self.theta = theta

        # Precompute the inverse frequencies
        # These are used to calculate the rotation angles for each dimension
        # and each position.
        inv_freq = 1.0 / (self.theta ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)

        # Precompute the positions (m)
        # We'll use these to generate the angles for each position.
        self.cached_angles = None
        self.cached_seq_len = None

    def _rotate_half(self, x):
        # Rotates the input tensor by half its dimension.
        # This is a core operation in RoPE, enabling the complex-number-like rotation.
        x1, x2 = x[..., : self.dim // 2], x[..., self.dim // 2 :]
        return torch.cat((-x2, x1), dim=-1)

    def forward(self, x, seq_len=None):
        if seq_len is None:
            seq_len = x.shape[1] # Assuming x has shape (batch_size, seq_len, dim)

        # If the sequence length changes, or angles haven't been computed yet, recompute them.
        if self.cached_seq_len is None or seq_len > self.cached_seq_len:
            self.cached_seq_len = seq_len
            # Generate the positions (m) for the current sequence length.
            seq_idx = torch.arange(seq_len, device=x.device).type_as(self.inv_freq)
            # Calculate the angles for each position and each frequency.
            # Shape: (seq_len, dim // 2)
            angles = torch.einsum("i,j->ij", seq_idx, self.inv_freq)

            # Expand angles to cover the full dimension, duplicating for sin and cos parts.
            # Shape: (seq_len, dim)
            self.cached_angles = torch.cat((angles, angles), dim=-1)

        # Get the angles for the current sequence length
        angles = self.cached_angles[:seq_len, :].unsqueeze(0) # Add batch dimension

        # Apply the rotation
        # cos(angles) * x + sin(angles) * _rotate_half(x)
        # This is equivalent to multiplying by e^(i * angles) in the complex plane.
        x_rotated = x * angles.cos() + self._rotate_half(x) * angles.sin()
        return x_rotated

# Example Usage:
if __name__ == "__main__":
    # Define parameters
    batch_size = 2
    seq_len = 10
    model_dim = 64  # Must be an even number for this implementation of RoPE

    # Create a dummy input tensor (e.g., embeddings)
    # Shape: (batch_size, seq_len, model_dim)
    dummy_input = torch.randn(batch_size, seq_len, model_dim)

    print(f"Original input shape: {dummy_input.shape}")
    print(f"Original input (first element):\n{dummy_input[0, 0, :4]}")

    # Initialize RoPE
    rope = RotaryPositionEmbedding(dim=model_dim)

    # Apply RoPE to the input
    output_with_rope = rope(dummy_input)

    print(f"\nOutput shape with RoPE: {output_with_rope.shape}")
    print(f"Output with RoPE (first element):\n{output_with_rope[0, 0, :4]}")

    # Demonstrate relative positioning
    # The dot product between two tokens at a certain relative distance
    # should be related to the angle between their rotated embeddings.

    # Let's consider two tokens from the same sequence
    token_0 = dummy_input[0, 0]
    token_5 = dummy_input[0, 5]

    token_0_rope = output_with_rope[0, 0]
    token_5_rope = output_with_rope[0, 5]

    print("\nDemonstrating relative positioning effects:")
    print(f"Dot product of original token 0 and token 5: {torch.dot(token_0, token_5):.4f}")
    print(f"Dot product of RoPE-applied token 0 and token 5: {torch.dot(token_0_rope, token_5_rope):.4f}")

    # The exact relationship is that the dot product of two RoPE-applied vectors
    # depends on the original dot product and the cosine of the angle corresponding
    # to their relative position.
    # Mathematically, for q_m and k_n (query at position m, key at position n):
    # RoPE(q_m) . RoPE(k_n) = q_m . k_n * cos(theta_m - theta_n) + ... (more complex terms)
    # The key idea is that the relative position (m-n) is encoded into the rotation.
    # This leads to the desirable property that attention scores are functions of relative distance.

Original input shape: torch.Size([2, 10, 64])
Original input (first element):
tensor([ 1.8229, -1.1492,  0.4515,  1.8060])

Output shape with RoPE: torch.Size([2, 10, 64])
Output with RoPE (first element):
tensor([ 1.8229, -1.1492,  0.4515,  1.8060])

Demonstrating relative positioning effects:
Dot product of original token 0 and token 5: -11.9919
Dot product of RoPE-applied token 0 and token 5: -8.6927
