## Multi-Head Latent Attention with Rotary Positional Embeddings

### How does MLA with RoPE work?

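Let's first break down plain MLA (without RoPE) step by step. The shapes assume 2 tokens, a model dimension of 768 and a latent dimension of 512, with all heads treated as a single head for simplicity: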

- **Input:**  
  X (shape: [2, 768])

- **Query projection:**  
  Q = X * Wq (shape: [2, 768])

- **Latent key-value space:**  
  Ckv = X * Wdkv (shape: [2, 512])

- **Key and Value projections:**  
  K = Ckv * Wuk (shape: [2, 768])  
  V = Ckv * Wuv (shape: [2, 768])

- **Attention Scores:**  
  Attention_Scores = Q * K_transpose (shape: [2, 2])
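As a quick sanity check of the shapes listed above, here is a minimal PyTorch sketch. Random matrices stand in for trained weights (an assumption for illustration), and all heads are treated as one, as in the list.

```python
import torch

torch.manual_seed(0)
n_tokens, d_model, d_latent = 2, 768, 512  # sizes used in the walkthrough

# Random stand-ins for trained projection weights (illustration only)
X = torch.randn(n_tokens, d_model)
Wq = torch.randn(d_model, d_model)
Wdkv = torch.randn(d_model, d_latent)
Wuk = torch.randn(d_latent, d_model)
Wuv = torch.randn(d_latent, d_model)

Q = X @ Wq       # query projection
Ckv = X @ Wdkv   # latent key-value space (this is what gets cached)
K = Ckv @ Wuk    # keys up-projected from the latent
V = Ckv @ Wuv    # values up-projected from the latent
attention_scores = Q @ K.T

for name, t in [("Q", Q), ("Ckv", Ckv), ("K", K), ("V", V), ("scores", attention_scores)]:
    print(name, tuple(t.shape))
```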

#### Expanded view of attention computation:

- Q * K_transpose = X * Wq * (Ckv * Wuk)_transpose
- Since (A * B)_transpose = B_transpose * A_transpose, this can be rewritten as: X * Wq * Wuk_transpose * Ckv_transpose
- Note: Wq * Wuk_transpose does not depend on the input, so it is a fixed matrix. Substituting Ckv = X * Wdkv, the expression becomes:  
  X * (Wq * Wuk_transpose) * (X * Wdkv)_transpose

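- Wq * Wuk_transpose (shape [768, 512]) is a fixed matrix: Wuk has been "absorbed" into the query projection. The term X * (Wq * Wuk_transpose) (shape [2, 512]) is sometimes called the "absorbed query".
- The term (X * Wdkv)_transpose is just Ckv_transpose, so only the latent Ckv needs to be cached, not the keys.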
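The rearrangement can be checked numerically. In this sketch the weights are random stand-ins (an assumption), and float64 is used so that the two orders of multiplication agree to within rounding:

```python
import torch

torch.manual_seed(0)
n_tokens, d_model, d_latent = 2, 768, 512

# Random stand-ins for trained weights, scaled to keep values moderate
X = torch.randn(n_tokens, d_model, dtype=torch.float64)
Wq = torch.randn(d_model, d_model, dtype=torch.float64) / d_model**0.5
Wdkv = torch.randn(d_model, d_latent, dtype=torch.float64) / d_model**0.5
Wuk = torch.randn(d_latent, d_model, dtype=torch.float64) / d_latent**0.5

# Direct route: materialize the keys, then Q * K_transpose
Ckv = X @ Wdkv
scores_direct = (X @ Wq) @ (Ckv @ Wuk).T

# Absorbed route: fold Wuk into the query side once
W_absorbed = Wq @ Wuk.T          # (d_model, d_latent), independent of the input
absorbed_query = X @ W_absorbed  # (n_tokens, d_latent)
scores_absorbed = absorbed_query @ Ckv.T

print(W_absorbed.shape, absorbed_query.shape)
print(torch.allclose(scores_direct, scores_absorbed))
```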

---

### Absorbed Query and Latent KV Caching

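- The absorbed weight matrix Wq * Wuk_transpose (shape [768, 512]) depends only on the trained weights, so it can be precomputed once and stored.
- For each new input x (one or more tokens), you compute the absorbed query x * (Wq * Wuk_transpose), append the new latent x * Wdkv to the cached Ckv, and multiply the absorbed query by the transposed cache. The keys are never materialized.
- The KV cache therefore stores the latent key-value space instead of the keys and values: 512 numbers per token instead of 768 + 768 = 1,536 in this example, a 3x reduction.
- Computation is saved as well, because the cached latents no longer have to be up-projected into keys at every decoding step.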
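A toy decoding loop makes the saving concrete. This is a single-head sketch with random stand-in weights (assumptions for illustration), and it shows only the score side, not the values:

```python
import torch

torch.manual_seed(0)
d_model, d_latent = 768, 512
dtype = torch.float64

# Random stand-ins for trained weights (illustration only)
Wq = torch.randn(d_model, d_model, dtype=dtype) / d_model**0.5
Wdkv = torch.randn(d_model, d_latent, dtype=dtype) / d_model**0.5
Wuk = torch.randn(d_latent, d_model, dtype=dtype) / d_latent**0.5
W_absorbed = Wq @ Wuk.T  # precomputed once, reused for every token

tokens = torch.randn(6, d_model, dtype=dtype)       # a toy sequence of 6 tokens
ckv_cache = torch.empty(0, d_latent, dtype=dtype)   # latent cache, one row per token

for t in range(tokens.shape[0]):
    x = tokens[t : t + 1]                                # (1, d_model), the new token
    ckv_cache = torch.cat([ckv_cache, x @ Wdkv], dim=0)  # append its latent
    scores_t = (x @ W_absorbed) @ ckv_cache.T            # (1, t + 1), no keys built

# Reference: scores of the last token computed with explicit keys
K_full = (tokens @ Wdkv) @ Wuk
reference = (tokens[-1:] @ Wq) @ K_full.T
print(torch.allclose(scores_t, reference))

# Numbers cached per token: latent only vs. full key + value
print(ckv_cache.shape[1], "vs", 2 * d_model)
```

Only `ckv_cache` grows with the sequence; `W_absorbed` is computed once and reused.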

---

### What changes when introducing RoPE?

- With RoPE (Rotary Positional Embeddings), both the query and the key are rotated by an angle that depends on their token position, and this has to happen before the attention scores are calculated.
- The computation becomes:

  ```
  Attention_Scores = Q * K_transpose = X * Wq * (Ckv * Wuk)_transpose
  Attention_Scores = Rope(X * Wq) * Rope(Ckv * Wuk)_transpose
  ```

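- This breaks the absorbed-query optimization. Write the RoPE rotation at position m as multiplication by a rotation matrix R_m. For a query at position m and a key at position n:  
  Rope(X_m * Wq) * Rope(Ckv_n * Wuk)_transpose = X_m * Wq * (R_m * R_n_transpose) * Wuk_transpose * Ckv_n_transpose
- R_m * R_n_transpose depends on the positions (it is the rotation for the offset m - n), so Wq * R_m * R_n_transpose * Wuk_transpose is no longer one fixed matrix that can be precomputed.
- Without absorption, you would have to either up-project and rotate every cached latent into a key at each step, or cache the full rotated keys, which gives up the memory saving of the latent cache.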
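The position dependence can be seen with explicit rotation matrices. `rope_matrix` is a hypothetical helper written only for this check. It builds R_pos for the same rotate-half pairing (dimension i with i + dim/2) and base 10000 that `apply_rope` in this notebook uses, with the row-vector convention x_rot = x @ R_pos:

```python
import torch


def rope_matrix(pos: int, dim: int, base: float = 10000.0) -> torch.Tensor:
    """Hypothetical helper: R_pos such that x @ R_pos rotates x like RoPE at position pos.

    Pairs dimension i with i + dim/2 (rotate-half convention).
    """
    half = dim // 2
    freqs = 1.0 / base ** (torch.arange(half, dtype=torch.float64) / half)
    theta = pos * freqs
    cos, sin = torch.cos(theta), torch.sin(theta)
    idx = torch.arange(half)
    R = torch.zeros(dim, dim, dtype=torch.float64)
    R[idx, idx] = cos
    R[idx, idx + half] = sin
    R[idx + half, idx] = -sin
    R[idx + half, idx + half] = cos
    return R


torch.manual_seed(0)
dim = 8
q = torch.randn(dim, dtype=torch.float64)
k = torch.randn(dim, dtype=torch.float64)

# The rotated score depends only on the offset m - n (3 in both cases)
s_a = (q @ rope_matrix(5, dim)) @ (k @ rope_matrix(2, dim))
s_b = (q @ rope_matrix(105, dim)) @ (k @ rope_matrix(102, dim))
same_offset_equal = bool(torch.isclose(s_a, s_b))

# R_m @ R_n.T equals R_(m-n) ...
rel_ok = torch.allclose(rope_matrix(5, dim) @ rope_matrix(2, dim).T, rope_matrix(3, dim))
# ... so the matrix between Wq and Wuk_transpose changes with the query position
shifted_equal = torch.allclose(
    rope_matrix(5, dim) @ rope_matrix(2, dim).T,
    rope_matrix(6, dim) @ rope_matrix(2, dim).T,
)
print(same_offset_equal, rel_ok, shifted_equal)  # True True False
```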

---

### DeepSeek's Approach

- To address this limitation, DeepSeek splits each query and key into two parts:
  - A part without RoPE, which uses the latent key-value cache and the absorbed weights exactly as before.
  - A small extra part that carries RoPE ("decoupled RoPE").
- For the RoPE part, the keys follow Multi-Query Attention (MQA): a single RoPE key is computed per token and shared by all heads, while each head still has its own RoPE query.
- The RoPE keys are cached already rotated, since a token's position never changes. Because one key serves all heads, this cache grows by only one small vector per token on top of the latent.
- Each part produces its own attention scores, and the two are added to get the final attention scores. This is the same as concatenating the non-RoPE and RoPE parts of each query and key and taking one dot product, so the scores are scaled by the square root of the combined per-head width.
- The values are unchanged: they are still up-projected from the latent key-value space, and the attention weights (softmax of the final scores) multiply them to give the context vectors.
- The queries are also compressed: they go through a down-projection to a low-rank latent followed by an up-projection. This does not shrink the KV cache; according to DeepSeek, it reduces activation memory during training. (The implementation below keeps a full-rank query projection.)
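Adding the two score terms is equivalent to a single dot product over concatenated [non-RoPE ; RoPE] query and key parts, which is why the combined width sets the scaling. This sketch checks that equivalence using toy sizes and random tensors in place of real projections (assumptions). It also assumes the RoPE rotation has already been applied and omits the causal mask for brevity:

```python
import torch

torch.manual_seed(0)
n_heads, dh, dh_rope, seq = 4, 16, 8, 6  # toy sizes (assumptions)

q_c = torch.randn(n_heads, seq, dh)       # per-head non-RoPE queries
k_c = torch.randn(n_heads, seq, dh)       # per-head non-RoPE keys (from the latent)
q_r = torch.randn(n_heads, seq, dh_rope)  # per-head RoPE queries (already rotated)
k_r = torch.randn(1, seq, dh_rope)        # ONE RoPE key shared by all heads (already rotated)

# Route 1: two separate score terms, added (k_r broadcasts over the head dimension)
scores_sum = q_c @ k_c.transpose(-2, -1) + q_r @ k_r.transpose(-2, -1)

# Route 2: concatenate the parts per head, then take one dot product
q_cat = torch.cat([q_c, q_r], dim=-1)
k_cat = torch.cat([k_c, k_r.expand(n_heads, -1, -1)], dim=-1)
scores_cat = q_cat @ k_cat.transpose(-2, -1)

print(torch.allclose(scores_sum, scores_cat, atol=1e-5))

# One softmax over the summed scores, scaled by the full per-head width
weights = torch.softmax(scores_sum / (dh + dh_rope) ** 0.5, dim=-1)
print(weights.shape)
```

Broadcasting `k_r` over the heads is what keeps the shared-key cache cheap: only one `dh_rope`-sized vector per token is stored for the RoPE part.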

---

```python
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F


def apply_rope(x: torch.Tensor, pos: torch.Tensor) -> torch.Tensor:
    """
    RoPE-rotate q or k (rotate-half convention: dim i is paired with dim i + dh/2).

    Args
    ----
    x   : (B, H, S, dh)  - even dh required
    pos : (S,)  or  (B, S)  - absolute token indices

    Returns
    -------
    x_rot : same shape as x, with RoPE applied
    """
    dh = x.shape[-1]
    assert dh % 2 == 0, "dh must be even"
    half = dh // 2
    device, dtype = x.device, x.dtype

    # One rotation frequency per (i, i + half) pair: 10000^(-i/half)
    idx = torch.arange(half, device=device, dtype=dtype)
    freqs = 1.0 / (10000 ** (idx / half))

    if pos.dim() == 1:
        theta = pos.to(dtype).unsqueeze(-1) * freqs  # (S, half)
        theta = theta.unsqueeze(0).unsqueeze(0)      # (1, 1, S, half), broadcasts over (B, H)
    else:
        theta = pos.to(dtype).unsqueeze(-1) * freqs  # (B, S, half)
        theta = theta.unsqueeze(1)                   # (B, 1, S, half), broadcasts over H

    sin, cos = theta.sin(), theta.cos()
    x_first_half, x_second_half = x[..., :half], x[..., half:]

    # 2-D rotation of each (first-half, second-half) pair by angle theta
    x_first_half_rot = x_first_half * cos - x_second_half * sin
    x_second_half_rot = x_first_half * sin + x_second_half * cos
    return torch.cat([x_first_half_rot, x_second_half_rot], dim=-1)


class MLAWithRope(nn.Module):
    def __init__(self, d_model: int, n_heads: int, n_heads_rope: int, kv_latent_dim: int):
        super().__init__()
        assert d_model % n_heads == 0, "d_model must be divisible by n_heads"
        # The non-RoPE scores and the RoPE scores are added head by head,
        # so both parts must have the same number of heads.
        assert n_heads_rope == n_heads, "n_heads_rope must equal n_heads"
        self.d_model = d_model
        self.n_heads = n_heads
        self.n_heads_rope = n_heads_rope
        self.kv_latent_dim = kv_latent_dim
        self.dh = d_model // n_heads
        self.dh_rope = d_model // n_heads_rope

        # Weights for the non-RoPE part of MLA
        self.W_q = nn.Linear(d_model, d_model, bias=False)
        self.W_dkv = nn.Linear(d_model, kv_latent_dim, bias=False)  # down-projection to the latent
        self.W_uk = nn.Linear(kv_latent_dim, d_model, bias=False)   # latent -> keys
        self.W_uv = nn.Linear(kv_latent_dim, d_model, bias=False)   # latent -> values

        # Weights for the decoupled RoPE part: one query per head,
        # but a single key shared by all heads (MQA-style)
        self.W_q_rope = nn.Linear(d_model, n_heads_rope * self.dh_rope, bias=False)
        self.W_k_rope = nn.Linear(d_model, self.dh_rope, bias=False)

        self.W_o = nn.Linear(d_model, d_model, bias=False)
        self.ln1 = nn.LayerNorm(kv_latent_dim)

        # Absorbed weight W_q[h].T @ W_uk[h], cached in eval mode only (assumes frozen weights)
        self.register_buffer("absorbed_k", None, persistent=False)

    def _absorbed_weight(self) -> torch.Tensor:
        # nn.Linear stores weights as (out_features, in_features), so per head
        # q_h = x @ W_q[h].T and k_h = c @ W_uk[h].T, hence
        # q_h @ k_h.T = x @ (W_q[h].T @ W_uk[h]) @ c.T
        w_q = self.W_q.weight.view(self.n_heads, self.dh, self.d_model)
        w_uk = self.W_uk.weight.view(self.n_heads, self.dh, self.kv_latent_dim)
        return w_q.transpose(1, 2) @ w_uk  # (n_heads, d_model, kv_latent_dim)

    def forward(
        self,
        x: torch.Tensor,
        latent_kv_cache: Optional[torch.Tensor] = None,
        rope_key_cache: Optional[torch.Tensor] = None,
        past_length: int = 0,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        B, S, D = x.shape  # batch size, sequence length, d_model
        assert D == self.d_model, "last dimension of x must be d_model"

        # Recompute while training (weights change); compute once in eval mode
        if self.training:
            absorbed = self._absorbed_weight()
        else:
            if self.absorbed_k is None:
                with torch.no_grad():
                    self.absorbed_k = self._absorbed_weight()
            absorbed = self.absorbed_k

        # ckv is the latent key-value space; only this (not K and V) is cached
        new_c_kv = self.ln1(self.W_dkv(x))  # (B, S, kv_latent_dim)
        if latent_kv_cache is None:
            latent_kv_cache = new_c_kv
        else:
            latent_kv_cache = torch.cat([latent_kv_cache, new_c_kv], dim=1)  # (B, S_full, kv_latent_dim)
        s_full = latent_kv_cache.size(1)

        # MLA without RoPE: absorbed query against the latent cache
        # (B, 1, S, d_model) @ (n_heads, d_model, kv_latent_dim) -> (B, n_heads, S, kv_latent_dim)
        q_lat = torch.matmul(x.unsqueeze(1), absorbed)
        # (B, n_heads, S, kv_latent_dim) @ (B, 1, kv_latent_dim, S_full) -> (B, n_heads, S, S_full)
        scores = torch.matmul(q_lat, latent_kv_cache.unsqueeze(1).transpose(-2, -1))

        # Decoupled RoPE part
        # (B, S, n_heads_rope * dh_rope) -> (B, n_heads_rope, S, dh_rope)
        q_rope = self.W_q_rope(x).view(B, S, self.n_heads_rope, self.dh_rope).transpose(1, 2)
        k_rope_new = self.W_k_rope(x).unsqueeze(1)  # (B, 1, S, dh_rope), shared by all heads

        pos_cur = torch.arange(past_length, past_length + S, device=x.device)
        q_rope = apply_rope(q_rope, pos_cur)
        k_rope_new = apply_rope(k_rope_new, pos_cur)

        # Keys are cached already rotated: their positions never change
        if rope_key_cache is None:
            rope_key_cache = k_rope_new
        else:
            rope_key_cache = torch.cat([rope_key_cache, k_rope_new], dim=2)  # (B, 1, S_full, dh_rope)

        # The single key broadcasts over the head dimension
        # (B, n_heads_rope, S, dh_rope) @ (B, 1, dh_rope, S_full) -> (B, n_heads_rope, S, S_full)
        scores_rope = torch.matmul(q_rope, rope_key_cache.transpose(-2, -1))

        print(f"final attention scores : {scores.shape} and {scores_rope.shape}")

        # Final scores: sum of both parts, scaled by the full per-head width dh + dh_rope
        final_attention_scores = (scores + scores_rope) / (self.dh + self.dh_rope) ** 0.5

        # Causal mask: query i (absolute position past_length + i) sees keys 0..past_length + i
        mask = torch.tril(torch.ones(S, s_full, device=x.device), diagonal=past_length)
        final_attention_scores = final_attention_scores.masked_fill(
            mask.view(1, 1, S, s_full) == 0, float("-inf")
        )
        attention_weights = F.softmax(final_attention_scores, dim=-1)  # (B, n_heads, S, S_full)

        # Values are up-projected from the latent key-value space
        value_matrix = self.W_uv(latent_kv_cache)  # (B, S_full, d_model)
        value_matrix = value_matrix.view(B, s_full, self.n_heads, self.dh).transpose(1, 2)  # (B, n_heads, S_full, dh)

        context_matrix = torch.matmul(attention_weights, value_matrix)  # (B, n_heads, S, dh)
        context_matrix = context_matrix.transpose(1, 2).contiguous().view(B, S, self.d_model)

        return self.W_o(context_matrix), latent_kv_cache, rope_key_cache
```

```python
def demo_mla():
    torch.manual_seed(0)

    # n_heads and n_heads_rope must be equal because the two sets of
    # per-head attention scores are added element-wise
    model = MLAWithRope(d_model=8, n_heads=2, n_heads_rope=2, kv_latent_dim=4)

    # Step 1: process a 5-token prompt with empty caches
    x1 = torch.randn(1, 5, 8)  # (batch_size, seq_len, d_model)
    out1, latent_kv1, k_rope = model(x1)

    print("Step 1:")
    print(f"Output shape: {out1.shape}")
    print(f"KV shape: {latent_kv1.shape}")
    print(f"Rope Key cache shape: {k_rope.shape}")

    # Step 2: decode one new token at position 5, reusing both caches
    x2 = torch.randn(1, 1, 8)  # (batch_size, seq_len, d_model)
    out2, latent_kv2, k_rope2 = model(x2, latent_kv_cache=latent_kv1, rope_key_cache=k_rope, past_length=5)

    print("Step 2:")
    print(f"Output shape: {out2.shape}")
    print(f"KV shape: {latent_kv2.shape}")
    print(f"Rope Key cache shape: {k_rope2.shape}")


demo_mla()
```

```
final attention scores : torch.Size([1, 2, 5, 5]) and torch.Size([1, 2, 5, 5])
Step 1:
Output shape: torch.Size([1, 5, 8])
KV shape: torch.Size([1, 5, 4])
Rope Key cache shape: torch.Size([1, 1, 5, 4])
final attention scores : torch.Size([1, 2, 1, 6]) and torch.Size([1, 2, 1, 6])
Step 2:
Output shape: torch.Size([1, 1, 8])
KV shape: torch.Size([1, 6, 4])
Rope Key cache shape: torch.Size([1, 1, 6, 4])
```
