# Implement RoPE from Scratch
# 🔁 Rotary Positional Embeddings in PyTorch

### 🧠 Problem Statement
Transformers need a sense of **order**, but vanilla self-attention is position-agnostic: shuffling the input tokens just shuffles the outputs the same way. Positional encodings inject this order-awareness into the model.
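
To make "position-agnostic" concrete, here is a minimal sketch. It assumes plain single-head scaled dot-product attention with random projection weights and no positional encoding. Permuting the tokens only permutes the outputs, so attention alone cannot distinguish one token order from another.

```python
import torch

torch.manual_seed(0)
seq_len, d = 5, 8
x = torch.randn(seq_len, d)  # token embeddings with no positional information
w_q, w_k, w_v = (torch.randn(d, d) for _ in range(3))  # random projections (illustrative)

def self_attention(x):
    # Plain scaled dot-product self-attention, no positional encoding
    q, k, v = x @ w_q, x @ w_k, x @ w_v
    scores = q @ k.T / d ** 0.5
    return torch.softmax(scores, dim=-1) @ v

perm = torch.randperm(seq_len)
out = self_attention(x)
out_perm = self_attention(x[perm])

# Shuffling the input rows shuffles the output rows identically
print(torch.allclose(out_perm, out[perm], atol=1e-5))  # True
```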

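Your mission is to implement **Rotary Positional Embeddings (RoPE)** from scratch. Instead of adding sine-cosine vectors to the token embeddings, RoPE rotates each pair of query and key dimensions by an angle proportional to the token's position. The rotation is a cheap elementwise operation, so the cost of attention is essentially unchanged, and the query-key dot product ends up depending only on the *relative* distance between tokens, a property often credited with helping models handle longer sequences.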
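
In equation form, a sketch following the half-split convention that `rotate_half` uses in the hints (the base of 10000 is the commonly used value and is an assumption here): a query or key vector of head dimension $d$ is split into $d/2$ pairs, and pair $i$ at position $m$ is rotated by the angle $m\theta_i$:

$$
\theta_i = \text{base}^{-2i/d}, \qquad
\begin{pmatrix} x'_i \\ x'_{i+d/2} \end{pmatrix}
=
\begin{pmatrix} \cos m\theta_i & -\sin m\theta_i \\ \sin m\theta_i & \cos m\theta_i \end{pmatrix}
\begin{pmatrix} x_i \\ x_{i+d/2} \end{pmatrix},
\qquad i = 0, \dots, \tfrac{d}{2} - 1.
$$

Writing $R_m$ for the combined rotation at position $m$ (one $2 \times 2$ rotation per pair), rotations compose as $R_m^\top R_n = R_{n-m}$, so the attention score depends only on the offset $n - m$:

$$
\langle R_m q,\, R_n k \rangle = q^\top R_m^\top R_n\, k = q^\top R_{n-m}\, k.
$$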

---

### ✅ Requirements

1. **Implement the Rotary Module**
   - Construct a `Rotary` class that computes the inverse frequencies `1 / base^(2i / d)`, one per pair of dimensions.
   - Precompute and cache `cos` and `sin` values per sequence length.
   - Register these as buffers to keep them on the correct device.

2. **Define Rotation Helpers**
   - `rotate_half(x)` splits the last dimension of `x` into halves `x1` and `x2` and returns `[-x2, x1]`.
   - `apply_rotary_pos_emb(q, k, cos, sin)` applies these rotations to Q and K.

3. **Simulate Usage**
   - Create synthetic tensors for Q, K, V.
   - Generate rotary embeddings using the custom `Rotary` module.
   - Apply rotary embeddings to Q and K.

4. **Verify Dimensions**
   - The rotated Q and K must keep their input shapes so they can be passed unchanged to an attention module.
   - Confirm RoPE is applied to Q and K (not V) before the dot-product attention scores are computed.
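
Requirement 1 can be sketched as the module below. This is a minimal sketch, not the reference solution. It assumes Q and K are laid out as `(batch, seq_len, num_heads, head_dim)`, uses the common base of 10000, and keys the cache on sequence length and device only (dtype changes are not tracked).

```python
import torch
import torch.nn as nn

class Rotary(nn.Module):
    """Minimal RoPE cos/sin cache (sketch). Assumes inputs shaped (batch, seq, heads, head_dim)."""

    def __init__(self, dim, base=10000):
        super().__init__()
        if dim % 2 != 0:
            raise ValueError("RoPE rotates pairs of dimensions, so dim must be even")
        # One inverse frequency per pair of dimensions: 1 / base^(2i / dim)
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)  # follows the module across .to(device)
        self.seq_len_cached = None
        self.cos_cached = None
        self.sin_cached = None

    def forward(self, x):
        seq_len = x.shape[1]
        # Recompute the tables only when the sequence length or device changes
        if seq_len != self.seq_len_cached or self.cos_cached.device != x.device:
            t = torch.arange(seq_len, device=x.device).type_as(self.inv_freq)
            freqs = torch.einsum("i,j->ij", t, self.inv_freq)  # (seq_len, dim / 2)
            emb = torch.cat((freqs, freqs), dim=-1)  # (seq_len, dim), matches the half split
            # Shape (1, seq_len, 1, dim) so the tables broadcast over batch and heads
            self.cos_cached = emb.cos()[None, :, None, :]
            self.sin_cached = emb.sin()[None, :, None, :]
            self.seq_len_cached = seq_len
        return self.cos_cached, self.sin_cached

rotary = Rotary(dim=16)
q = torch.rand(2, 50, 4, 16)
cos, sin = rotary(q)
print(cos.shape)  # torch.Size([1, 50, 1, 16])
```

A second call with the same sequence length returns the cached tensors without recomputing them; a different length triggers a rebuild, which covers the dynamic-length constraint.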

---

### 📏 Constraints

- ✅ Use only PyTorch — no Fairseq or HuggingFace positional modules.
- ✅ Must support dynamic sequence lengths and cache embeddings per sequence.
- ✅ Should split the head dimension into halves correctly; the half-split rotation requires an even head dimension.
- ❌ Do **not** manually plug in Fairseq’s `SinusoidalPositionalEmbedding`.

---

<details>
  <summary>💡 Hint</summary>

  - Use `torch.einsum("i,j->ij", t, inv_freq)` to take the outer product of the positions `t` and the inverse frequencies, giving one angle per (position, frequency) pair.
  - Cache the cosine and sine values in the `Rotary` class using `self.register_buffer()`.
  - The `rotate_half(x)` function should split `x` into two halves and rotate them: `[-x2, x1]`.
  - Apply the rotary transformation using:  
    `(q * cos) + (rotate_half(q) * sin)`  
    and similarly for `k`.
  - Remember to broadcast `cos` and `sin` to match the shape of `q` and `k`.
</details>
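
The hints translate almost line by line into the two helpers. This is a minimal sketch assuming the half-split convention; to keep the block self-contained it builds `cos`/`sin` inline with the same `einsum` outer product (base 10000 assumed), and the tensor sizes are illustrative.

```python
import torch

def rotate_half(x):
    # Split the last dimension into halves x1, x2 and return [-x2, x1]
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb(q, k, cos, sin):
    # Rotate each (x_i, x_{i + d/2}) pair by its position-dependent angle
    return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)

print(rotate_half(torch.tensor([1.0, 2.0, 3.0, 4.0])).tolist())  # [-3.0, -4.0, 1.0, 2.0]

# Build cos/sin for seq_len positions, broadcastable to (batch, seq, heads, head_dim)
batch, seq_len, heads, head_dim = 2, 6, 3, 8
inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim, 2).float() / head_dim))
t = torch.arange(seq_len).float()
freqs = torch.einsum("i,j->ij", t, inv_freq)  # (seq_len, head_dim / 2)
emb = torch.cat((freqs, freqs), dim=-1)        # (seq_len, head_dim)
cos, sin = emb.cos()[None, :, None, :], emb.sin()[None, :, None, :]

q = torch.rand(batch, seq_len, heads, head_dim)
k = torch.rand(batch, seq_len, heads, head_dim)
q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin)
print(q_rot.shape)  # torch.Size([2, 6, 3, 8])
```

Because each pair is rotated rather than shifted, the norm of every query and key vector is unchanged, and position 0 (angle 0) is left untouched.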

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import math

In [None]:
# Synthetic data
torch.manual_seed(42)
batch_size = 3
seq_len = 4
d_model = 8
num_heads = 2

q = torch.rand(batch_size, seq_len, d_model)
k = torch.rand(batch_size, seq_len, d_model)
v = torch.rand(batch_size, seq_len, d_model)
print(q.shape)

device = "cuda" if torch.cuda.is_available() else "cpu"
device = "cpu"  # override: run this demo on CPU regardless of GPU availability
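
The cell above defines `num_heads` but keeps Q, K, and V as `(batch_size, seq_len, d_model)`. RoPE is normally applied per attention head, so a typical step before rotating is to split `d_model` into heads. This is a sketch assuming a `(batch, seq_len, num_heads, head_dim)` layout:

```python
import torch

torch.manual_seed(42)
batch_size, seq_len, d_model, num_heads = 3, 4, 8, 2
head_dim = d_model // num_heads  # 4; must be even for RoPE's half split

q = torch.rand(batch_size, seq_len, d_model)
# (batch, seq, d_model) -> (batch, seq, heads, head_dim)
q_heads = q.view(batch_size, seq_len, num_heads, head_dim)
print(q_heads.shape)  # torch.Size([3, 4, 2, 4])

# Merging the heads back recovers the original tensor
q_merged = q_heads.reshape(batch_size, seq_len, d_model)
```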

In [None]:
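class Rotary(torch.nn.Module):
    def __init__(self, dim, base=10000):
        super().__init__()
        # TODO: register inv_freq = 1 / base^(2i / dim) as a buffer
        raise NotImplementedError

    def forward(self, x):
        # TODO: return cached (cos, sin) for seq_len = x.shape[1], broadcastable to x
        raise NotImplementedError

def rotate_half(x):
    # TODO: split the last dim into halves x1, x2 and return cat([-x2, x1])
    raise NotImplementedError

@torch.jit.script
def apply_rotary_pos_emb(q, k, cos, sin):
    # TODO: return (q * cos) + (rotate_half(q) * sin), and the same for k
    raise NotImplementedError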

In [None]:
# Shape check: RoPE rotates Q and K per attention head and must preserve their shapes.
# Layout assumed here: (batch, seq_len, num_heads, head_dim).
batch_size, seq_len, num_heads, d_model = 2, 50, 4, 64
head_dim = d_model // num_heads  # 16; RoPE needs an even head dimension

q = torch.rand(batch_size, seq_len, num_heads, head_dim)
k = torch.rand(batch_size, seq_len, num_heads, head_dim)

rotary = Rotary(head_dim)
cos, sin = rotary(q)  # cached tables that broadcast against q and k
q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin)

print(q_rot.shape)  # torch.Size([2, 50, 4, 16])
print(k_rot.shape)  # torch.Size([2, 50, 4, 16])
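
Beyond shapes, a stronger sanity check is RoPE's defining property: after rotation, the score between a query at position m and a key at position n depends only on n - m. The sketch below is self-contained, with its own compact helper (`rope` is a hypothetical name, not part of the exercise API), a base of 10000 assumed, and float64 used so the comparison is tight.

```python
import torch

def rope(x, positions, base=10000):
    # x: (n, d) with d even; positions: (n,) integer positions, one per row of x
    d = x.shape[-1]
    inv_freq = 1.0 / (base ** (torch.arange(0, d, 2, dtype=torch.float64) / d))
    angles = positions.to(torch.float64)[:, None] * inv_freq  # (n, d / 2)
    emb = torch.cat((angles, angles), dim=-1)                 # (n, d)
    x1, x2 = x.chunk(2, dim=-1)
    return x * emb.cos() + torch.cat((-x2, x1), dim=-1) * emb.sin()

torch.manual_seed(0)
d = 16
q = torch.randn(1, d, dtype=torch.float64)
k = torch.randn(1, d, dtype=torch.float64)

def score(m, n):
    # Dot product between q rotated to position m and k rotated to position n
    return (rope(q, torch.tensor([m])) * rope(k, torch.tensor([n]))).sum().item()

# The same offset n - m = 3 at different absolute positions gives the same score
print(abs(score(2, 5) - score(10, 13)) < 1e-9)  # True
# A different offset generally gives a different score
print(score(2, 5), score(2, 6))
```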
