##### Implementing Multi-Head Latent Attention

Multi-Head Latent Attention (MLA) was a game changer because it sharply reduces the memory that attention needs during inference.
- Standard multi-head attention has a drawback: during inference it must store the keys ($K$) and values ($V$) of every token it has seen so far. This store is called the Key-Value (KV) cache, and its size grows linearly with the sequence length.
- Multi-Head Latent Attention is an attention mechanism designed to solve this KV cache memory problem. It compresses the keys and values into a smaller, shared representation called a latent vector, and only that latent vector needs to be cached.
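A back-of-the-envelope calculation makes the linear growth concrete. The configuration below (32 layers, 32 heads of width 128, fp16 values, a 512-dimensional latent) is a hypothetical example picked for round numbers, not the configuration of any particular model, and the latent estimate ignores any extra positional key.

```python
def standard_kv_cache_bytes(seq_len, n_layers, n_heads, d_head, bytes_per_value=2):
    """Keys and values: 2 tensors x heads x head width, per token, per layer."""
    return 2 * n_layers * n_heads * d_head * seq_len * bytes_per_value


def latent_kv_cache_bytes(seq_len, n_layers, latent_dim, bytes_per_value=2):
    """One shared latent vector per token, per layer (no extra RoPE key counted)."""
    return n_layers * latent_dim * seq_len * bytes_per_value


# Hypothetical configuration, chosen only for round numbers.
cfg = dict(n_layers=32, n_heads=32, d_head=128)
for seq_len in (1024, 2048, 4096):
    std = standard_kv_cache_bytes(seq_len, **cfg)
    lat = latent_kv_cache_bytes(seq_len, n_layers=32, latent_dim=512)
    print(f"{seq_len} tokens: standard {std / 2**20:.0f} MiB, latent {lat / 2**20:.0f} MiB")
```

Doubling the sequence length doubles both caches; with these assumed sizes the latent cache is 16x smaller because 2 x 32 x 128 = 8192 cached values per token per layer shrink to 512.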

MLA introduces two innovations:

1. **Low-Rank Key-Value Compression**: instead of storing full keys and values, MLA compresses them into a latent vector. This latent vector is much smaller than the original keys and values, which significantly reduces memory usage.
$$\text{Let } K \in \mathbb{R}^{n \times d} \text{ and } V \in \mathbb{R}^{n \times d} \text{ be the original key and value matrices.}\\
\text{We decompose them into low-rank factors:}$$

$$K = \begin{bmatrix} k_1^T \\ k_2^T \\ \vdots \\ k_n^T \end{bmatrix} \approx K_A K_B, \quad \text{where } K_A \in \mathbb{R}^{n \times r}, K_B \in \mathbb{R}^{r \times d}$$

$$V = \begin{bmatrix} v_1^T \\ v_2^T \\ \vdots \\ v_n^T \end{bmatrix} \approx V_A V_B, \quad \text{where } V_A \in \mathbb{R}^{n \times r}, V_B \in \mathbb{R}^{r \times d}$$

$$\text{where } r \ll d \text{ is the rank. Storing the factors instead of } K \text{ and } V \text{ takes a fraction } \frac{2r(n+d)}{2nd} = \frac{r(n+d)}{nd} \approx \frac{r}{d} \text{ of the memory for large } n$$

$$\text{Standard attention: } \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d}}\right)V$$

$$\text{Low-rank attention: } \text{Attention}(Q, K_A, K_B, V_A, V_B) = \text{softmax}\left(\frac{Q(K_A K_B)^T}{\sqrt{d}}\right)(V_A V_B)$$

$$= \text{softmax}\left(\frac{QK_B^T K_A^T}{\sqrt{d}}\right)V_A V_B$$

$$\text{Original memory: } \mathcal{O}(nd) \text{ for both K and V}$$

$$\text{Low-rank memory: } \mathcal{O}(nr + rd) = \mathcal{O}(r(n + d))$$
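Plugging illustrative sizes into these formulas shows how close the ratio gets to $r/d$ once $n$ is much larger than $d$. The sizes below are arbitrary assumptions, and the helper names are hypothetical.

```python
def full_kv_elements(n, d):
    """K and V stored in full: two n x d matrices."""
    return 2 * n * d


def low_rank_kv_elements(n, d, r):
    """K ~ K_A K_B and V ~ V_A V_B: two (n x r) factors plus two (r x d) factors."""
    return 2 * (n * r + r * d)


n, d, r = 4096, 512, 64  # illustrative sizes, not taken from a specific model
full = full_kv_elements(n, d)
low = low_rank_kv_elements(n, d, r)
print(full, low, low / full)  # 4194304 589824 0.140625, i.e. r * (n + d) / (n * d)
print(r / d)                  # 0.125, the limit of that ratio as n grows
```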


$$\text{At inference step } t, \text{ given new query } q_t \in \mathbb{R}^d \text{ and cached low-rank KV:}$$

$$K_{\text{cache}} = K_A K_B \text{ where } K_A \in \mathbb{R}^{(t-1) \times r}, K_B \in \mathbb{R}^{r \times d}$$

$$V_{\text{cache}} = V_A V_B \text{ where } V_A \in \mathbb{R}^{(t-1) \times r}, V_B \in \mathbb{R}^{r \times d}$$

$$\text{Compute only the low-rank factors for the new position } t\text{; the up-projections are learned weights shared by every position:}$$

$$k_{t,A} = \text{Linear}_{K,A}(x_t) \in \mathbb{R}^{1 \times r}, \quad v_{t,A} = \text{Linear}_{V,A}(x_t) \in \mathbb{R}^{1 \times r}$$

$$k_t = k_{t,A} k_{t,B}, \quad v_t = v_{t,A} v_{t,B}, \quad \text{where } k_{t,B} = K_B \in \mathbb{R}^{r \times d}, \; v_{t,B} = V_B \in \mathbb{R}^{r \times d}$$

$$\text{Update cached factors by concatenation:}$$

$$K_A^{(t)} = \begin{bmatrix} K_A^{(t-1)} \\ k_{t,A} \end{bmatrix} \in \mathbb{R}^{t \times r}$$

$$K_B^{(t)} = K_B^{(t-1)} = k_{t,B} \in \mathbb{R}^{r \times d} \quad \text{(shared across positions)}$$

$$V_A^{(t)} = \begin{bmatrix} V_A^{(t-1)} \\ v_{t,A} \end{bmatrix} \in \mathbb{R}^{t \times r}$$

$$V_B^{(t)} = V_B^{(t-1)} = v_{t,B} \in \mathbb{R}^{r \times d} \quad \text{(shared across positions)}$$

$$\text{Compute attention scores efficiently:}$$

$$\text{scores}_t = \frac{q_t (K_A^{(t)} K_B^{(t)})^T}{\sqrt{d}} = \frac{q_t (K_B^{(t)})^T (K_A^{(t)})^T}{\sqrt{d}}$$

$$= \frac{(q_t (K_B^{(t)})^T) (K_A^{(t)})^T}{\sqrt{d}} \in \mathbb{R}^{1 \times t}$$

$$\text{where } q_t (K_B^{(t)})^T \in \mathbb{R}^{1 \times r} \text{ is computed once}$$

$$\alpha_t = \text{softmax}(\text{scores}_t) \in \mathbb{R}^{1 \times t}$$

$$\text{output}_t = \alpha_t (V_A^{(t)} V_B^{(t)}) = (\alpha_t V_A^{(t)}) V_B^{(t)}$$

$$= \sum_{i=1}^{t} \alpha_{t,i} \cdot (v_{i,A} v_{i,B}) = \left(\sum_{i=1}^{t} \alpha_{t,i} v_{i,A}\right) V_B^{(t)}$$
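The decode step can be checked numerically. This sketch, with random placeholder factors and sizes, appends the new position's factors to the cache, computes attention through $q_t K_B^T$ and $\alpha_t V_A$ without ever forming the $t \times d$ matrices, and compares the result with explicit reconstruction. It uses float64 so the comparison is not blurred by rounding.

```python
import torch

torch.manual_seed(0)
d, r, t = 16, 4, 6  # illustrative sizes
dt = torch.float64

# Learned up-projections shared by every position (random placeholders here).
K_B = torch.randn(r, d, dtype=dt)
V_B = torch.randn(r, d, dtype=dt)

# Cached low-rank factors for positions 1..t-1, then append position t.
K_A = torch.randn(t - 1, r, dtype=dt)
V_A = torch.randn(t - 1, r, dtype=dt)
k_tA = torch.randn(1, r, dtype=dt)
v_tA = torch.randn(1, r, dtype=dt)
K_A = torch.cat([K_A, k_tA], dim=0)  # (t, r)
V_A = torch.cat([V_A, v_tA], dim=0)  # (t, r)

q_t = torch.randn(1, d, dtype=dt)

# Efficient path: the t x d key and value matrices are never built.
q_lat = q_t @ K_B.T                                     # (1, r), computed once
alpha = torch.softmax(q_lat @ K_A.T / d**0.5, dim=-1)   # (1, t)
out_fast = (alpha @ V_A) @ V_B                          # (1, d)

# Reference path: reconstruct K and V explicitly.
K = K_A @ K_B
V = V_A @ V_B
alpha_ref = torch.softmax(q_t @ K.T / d**0.5, dim=-1)
out_ref = alpha_ref @ V

print(torch.allclose(out_fast, out_ref))
```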


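2. **Decoupled Rotary Position Embedding (RoPE)**: RoPE encodes each token's position by rotating its query and key vectors by a position-dependent angle. Standard RoPE clashes with low-rank compression: because the rotation of a key depends on its position, it would sit between the query and the shared up-projection $K_B$. The compressed keys could then no longer be reused as-is, and $K_B$ could not be folded into the query side as in the step-$t$ derivation above.
- To keep positional information without losing this memory efficiency, MLA uses a decoupled RoPE strategy: it adds extra multi-head query dimensions and a single key shared across heads that carry RoPE, while the compressed keys and values stay position-free.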
$$\text{Standard RoPE applies rotation to both queries and keys:}$$

$$q_m = \text{RoPE}(q_m, m), \quad k_m = \text{RoPE}(k_m, m)$$
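The property that makes RoPE useful is that the rotated dot product depends only on the offset between positions. This minimal sketch assumes the interleaved-pair layout (dimensions $2i$ and $2i+1$ rotate together); the `rotate_half` code later in this notebook pairs dimension $i$ with $i + d/2$ instead, which is the same rotation up to a fixed permutation of dimensions. The `rope` helper is a hypothetical name.

```python
import torch


def rope(x, pos, base=10000.0):
    """Rotate each pair (x[2i], x[2i+1]) of a 1-D vector by angle pos * base**(-2i/d)."""
    d = x.shape[-1]
    inv_freq = base ** (-torch.arange(0, d, 2, dtype=x.dtype) / d)  # (d/2,)
    angle = pos * inv_freq
    cos, sin = torch.cos(angle), torch.sin(angle)
    x_even, x_odd = x[0::2], x[1::2]
    out = torch.empty_like(x)
    out[0::2] = x_even * cos - x_odd * sin
    out[1::2] = x_even * sin + x_odd * cos
    return out


torch.manual_seed(0)
q = torch.randn(8, dtype=torch.float64)
k = torch.randn(8, dtype=torch.float64)

# Same offset (m - n = 2) at different absolute positions gives the same score.
s1 = rope(q, 5) @ rope(k, 3)
s2 = rope(q, 12) @ rope(k, 10)
print(torch.allclose(s1, s2))
```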

$$\text{In MLA, decouple position encoding from compressed representations:}$$

$$\text{Shared key: } k_m^{\text{shared}} = \text{RoPE}(W_k^{\text{shared}} x_m, m) \in \mathbb{R}^{d_k}$$

$$\text{Multi-head queries: } q_{m,h} = \text{RoPE}(W_{q,h} x_m, m) \in \mathbb{R}^{d_k} \text{ for head } h$$

$$\text{Low-rank KV (position-free): } k_m^{\text{LR}} = W_k^{\text{LR}} x_m, \quad v_m^{\text{LR}} = W_v^{\text{LR}} x_m$$

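$$\text{Attention sums a content score and a positional score, then applies a single softmax.}$$

$$\text{Content query (position-free): } q_{m,h}^{\text{C}} = W_{q,h}^{\text{C}} x_m \in \mathbb{R}^{d_h}, \quad k_{j,h}^{\text{LR}}, v_{j,h}^{\text{LR}} \text{ are the head-}h \text{ slices of } k_j^{\text{LR}}, v_j^{\text{LR}}$$

$$\text{score}_h(m, j) = \frac{(q_{m,h}^{\text{C}})^T k_{j,h}^{\text{LR}} + q_{m,h}^T k_j^{\text{shared}}}{\sqrt{d_h + d_k}}$$

$$\alpha_{m,j,h} = \text{softmax}_{j \le m}\left(\text{score}_h(m, j)\right), \qquad \text{Attention}_h(m) = \sum_{j \le m} \alpha_{m,j,h} \, v_{j,h}^{\text{LR}}$$

$$\text{Only } k_j^{\text{shared}} \text{ carries position, and it is a single } d_k\text{-dimensional vector per token shared by all heads, so it adds little to the cache.}$$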
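The combination can be checked numerically. This sketch assumes the DeepSeek-V2 style layout: each head's query has a position-free part and a RoPE part, keys are the up-projected latent plus one RoPE key shared by all heads, and the two dot products are summed before a single softmax. All sizes and weights are random placeholders, and the RoPE rotation is assumed to have been applied to `q_rope` and `k_rope` already.

```python
import torch

torch.manual_seed(0)
n_heads, d_h, d_rope, d_latent, T = 4, 8, 4, 16, 5  # illustrative sizes
dt = torch.float64

c = torch.randn(T, d_latent, dtype=dt)                 # cached latent per token
k_rope = torch.randn(T, d_rope, dtype=dt)              # cached RoPE key: one per token, shared by all heads
W_uk = torch.randn(n_heads, d_latent, d_h, dtype=dt)   # hypothetical per-head up-projection

q_content = torch.randn(n_heads, d_h, dtype=dt)        # current token's position-free query per head
q_rope = torch.randn(n_heads, d_rope, dtype=dt)        # current token's RoPE query per head

k_content = torch.einsum('tl,hld->htd', c, W_uk)       # (heads, T, d_h), position-free keys
scale = (d_h + d_rope) ** 0.5

# Sum the two dot products, then take one softmax over the T keys.
scores = (torch.einsum('hd,htd->ht', q_content, k_content)
          + torch.einsum('hr,tr->ht', q_rope, k_rope)) / scale
alpha = torch.softmax(scores, dim=-1)                  # (heads, T)

# Equivalent view: one dot product of concatenated query and key vectors.
q_cat = torch.cat([q_content, q_rope], dim=-1)                               # (heads, d_h + d_rope)
k_cat = torch.cat([k_content, k_rope.expand(n_heads, T, d_rope)], dim=-1)   # (heads, T, d_h + d_rope)
scores_cat = torch.einsum('hd,htd->ht', q_cat, k_cat) / scale

print(torch.allclose(scores, scores_cat))
```

Per token, only `c` and `k_rope` need to be cached; the content keys are derived from `c`.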

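- A key optimization in MLA is weight absorption. Because the compressed keys carry no rotation, the key up-projection can be folded into the query projection, and the value up-projection into the output projection. The model therefore never has to reconstruct the full keys and values during inference, which saves both computation and memory.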
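The key-side absorption can be checked directly. In this sketch (random weights, sizes chosen for illustration, weight matrices laid out like `nn.Linear.weight`), each head's absorbed matrix $A_h = W_{q,h}^T W_{uk,h}$ maps the hidden state straight into the latent space, so scores are computed against the cached latents without rebuilding keys. The value side works the same way with the value up-projection and the output projection, which is not shown.

```python
import torch

torch.manual_seed(0)
d_model, n_heads, d_latent, T = 32, 4, 8, 5  # illustrative sizes
d_h = d_model // n_heads
dt = torch.float64

W_q = torch.randn(d_model, d_model, dtype=dt)    # (out, in), like nn.Linear.weight
W_uk = torch.randn(d_model, d_latent, dtype=dt)  # up-projection: latent -> keys
x = torch.randn(d_model, dtype=dt)               # current token's hidden state
c = torch.randn(T, d_latent, dtype=dt)           # cached latents for T tokens

# Reference: build per-head queries and keys explicitly.
q = (W_q @ x).view(n_heads, d_h)                 # (heads, d_h)
k = (c @ W_uk.T).view(T, n_heads, d_h)           # (tokens, heads, d_h)
scores_ref = torch.einsum('hd,thd->ht', q, k)

# Absorbed: A_h = W_q[h]^T @ W_uk[h], shape (d_model, d_latent) per head.
A = torch.einsum('hdi,hdl->hil',
                 W_q.view(n_heads, d_h, d_model),
                 W_uk.view(n_heads, d_h, d_latent))
q_latent = torch.einsum('i,hil->hl', x, A)       # query projected into latent space
scores_absorbed = torch.einsum('hl,tl->ht', q_latent, c)

print(torch.allclose(scores_ref, scores_absorbed))
```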


In [6]:
import torch
import torch.nn as nn
import torch.nn.functional as F 
import math 

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')  # default GPU if present, else CPU


class RoPE(nn.Module):
    '''precomputes the RoPE cos/sin tables; apply_rotary performs the rotation'''
    def __init__(self, d_model, max_seq_len=2048, base=10000):
        super().__init__()
        self.d_model = d_model
        self.max_seq_len = max_seq_len
        self.base = base 

        # inverse frequencies (how fast each pair of dimensions rotates), kept as a non-persistent buffer
        inverse_freq = 1.0 / (self.base **(torch.arange(0, self.d_model, 2).float().to(device)/self.d_model))
        self.register_buffer("inverse_freq", inverse_freq, persistent = False)

        self.cache(seq_len = max_seq_len, device=device, dtype =torch.float32)
    
    def cache(self, seq_len, device, dtype):
        self.max_seq_len_cached = seq_len
        t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inverse_freq.dtype)

        frequencies = torch.einsum("i,j->ij", t, self.inverse_freq)
        embedding = torch.cat((frequencies, frequencies), dim=-1)
        self.register_buffer("cosine_cached", embedding.cos().to(dtype), persistent=False)
        self.register_buffer("sine_cached", embedding.sin().to(dtype), persistent=False)

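    def forward(self, x, seq_len=None):
        # default to the sequence length of x, laid out as (batch, seq_len, ...)
        if seq_len is None:
            seq_len = x.shape[1]
        if seq_len > self.max_seq_len_cached: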
            self.cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
        return (
            self.cosine_cached[:seq_len].to(dtype=x.dtype),
            self.sine_cached[:seq_len].to(dtype=x.dtype),

        )

def rotate_half(x):
    '''rotate half of the hidden dims of the input: (x1, x2) -> (-x2, x1)'''
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary(x, cos, sin, position_ids):
    # cos and sin have shape [seq_len, dim]
    # position_ids have shape [batch, seq_len]
    cos = cos[position_ids].unsqueeze(2) # [bs, seq_len, 1, dim]
    sin = sin[position_ids].unsqueeze(2) # [bs, seq_len, 1, dim]
    # x has shape [bs, seq_len, n_heads, dim]
    x_embed = (x * cos) + (rotate_half(x) * sin)
    return x_embed
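A quick way to sanity-check the helpers above is to confirm that the rotation preserves vector length and that position 0 leaves a vector unchanged. The sketch restates `rotate_half` so it runs on its own; `rope_half_layout` is a hypothetical helper that mirrors what `RoPE` plus `apply_rotary` compute for a single vector (dimension $i$ paired with $i + d/2$).

```python
import torch


def rotate_half(x):
    x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def rope_half_layout(x, pos, base=10000.0):
    """RoPE in the rotate_half layout: dimension i rotates together with i + d/2."""
    d = x.shape[-1]
    inv_freq = 1.0 / base ** (torch.arange(0, d, 2, dtype=x.dtype) / d)
    emb = torch.cat((pos * inv_freq, pos * inv_freq))  # same layout as the cached tables
    return x * emb.cos() + rotate_half(x) * emb.sin()


torch.manual_seed(0)
x = torch.randn(8, dtype=torch.float64)

# A rotation keeps the vector length, and position 0 is the identity.
print(torch.allclose(rope_half_layout(x, 7).norm(), x.norm()))
print(torch.allclose(rope_half_layout(x, 0), x))
```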

In [16]:
class MultiHeadLatent(nn.Module):
    def __init__(self, d_model, n_heads, kv_latent_dim):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.kv_latent_dim = kv_latent_dim
        self.dh = d_model//n_heads
        self.rotary_emb = RoPE(self.kv_latent_dim)

        # projection layer
        self.w_q = nn.Linear(d_model, d_model, bias=False)  # query projection
        self.w_dkv = nn.Linear(d_model, kv_latent_dim, bias=False)  # down-projection to the shared KV latent
        self.w_uk = nn.Linear(kv_latent_dim, d_model, bias=False)  # up-projection: latent -> keys
        self.w_uv = nn.Linear(kv_latent_dim, d_model, bias=False)  # up-projection: latent -> values
        self.w_o = nn.Linear(d_model, d_model, bias=False)  # output projection
        self.ln = nn.LayerNorm(kv_latent_dim)  # normalises the latent before it is cached

    def forward(self, x, kv_cache=None, past_length=0):
        B, S, D = x.size()  # batch, sequence length, model dimension

        # Absorb W_uk into W_q per head: A_h = W_q[h]^T @ W_uk[h], shape (D, kv_latent_dim),
        # so q_h . k_h == (x @ A_h) . c_kv and the full keys are never materialised.
        # Rebuilt on every call (cheap) so it never goes stale while the weights train.
        w_q = self.w_q.weight.view(self.n_heads, self.dh, self.d_model)
        w_uk = self.w_uk.weight.view(self.n_heads, self.dh, self.kv_latent_dim)
        absorbed_k = torch.einsum('hdi,hdl->hil', w_q, w_uk)  # (n_heads, D, kv_latent_dim)

        # compress input into latent kv
        new_c_kv = self.ln(self.w_dkv(x))

        if kv_cache is None:
            c_kv = new_c_kv
        else:
            c_kv = torch.cat([kv_cache, new_c_kv], dim=1)

        s_full = c_kv.size(1)

        # Decompress V to full d_model and split into heads
        v_full = self.w_uv(c_kv)
        v = v_full.view(B, s_full, self.n_heads, self.dh).transpose(1, 2)

        # Project the input straight into each head's latent space (W_q is absorbed)
        q_latent = torch.einsum('bsi,hil->bshl', x, absorbed_k)  # (B, S, n_heads, kv_latent_dim)

        # get cos, sin
        cos, sin = self.rotary_emb(x=q_latent, seq_len=s_full)

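        # Apply RoPE to the queries at their absolute positions. Simplification: this rotates
        # the latent vectors directly instead of adding a separate decoupled RoPE key.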
        query_pos = torch.arange(past_length, past_length+S, device=x.device).view(1, S)
        q_rotate = apply_rotary(q_latent, cos, sin, query_pos)  
        q_rotate = q_rotate.transpose(1, 2) # shape (B, n_heads, S, kv_latent_dim)

        # Apply RoPE to keys (from c_kv) based on their positions
        key_pos = torch.arange(0, s_full, device=x.device).view(1, s_full)
        k_rotate = apply_rotary(c_kv.unsqueeze(2), cos, sin, key_pos) # shape (B, S_full, 1, kv_latent_dim)
        k_rotate = k_rotate.transpose(1,2) # shape (B, 1, S_full, kv_latent_dim)

        # Scaled dot-product attention in latent space, Q @ K.T.
        # Shapes: (B, n_heads, S, kv_latent_dim) @ (B, 1, kv_latent_dim, S_full) -> (B, n_heads, S, S_full)
        # The single latent "head" in K is broadcast across all n_heads of Q (as in multi-query attention).
        # Scale by sqrt(head_dim): the absorbed score stands in for the per-head q_h . k_h product.
        attention_score = torch.matmul(q_rotate, k_rotate.transpose(2, 3)) / math.sqrt(self.dh)
        mask = torch.tril(torch.ones((S, s_full), device=x.device), diagonal=past_length)  # causal mask
        attention_score = attention_score.masked_fill(mask.view(1, 1, S, s_full) == 0, float('-inf'))
        attention_weights = F.softmax(attention_score, dim=-1)  # normalise over the keys, not the heads

        output = torch.matmul(attention_weights, v)
        output = output.transpose(1,2).contiguous().view(B,S,D)

        return self.w_o(output), c_kv # final output + updated latent cache


In [17]:
# testing the memory 
def demo():
    
    model = MultiHeadLatent(d_model=512,n_heads=8,kv_latent_dim=256).to(device)
    
    x = torch.randn(1,5,512).to(device)

    out, cache = model(x)
    print(f"Output shape: {out.shape}, Cache shape:{cache.shape}")

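    B, S, D = x.shape
    latent_dim = cache.shape[-1]
    # sizes in KB, assuming fp32 (4 bytes per value)
    std_size = B * 2 * S * D * 4 / 1024  # standard cache: K and V, D values each per token
    latent_size = B * S * latent_dim * 4 / 1024  # latent cache: one latent vector per token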
    print(f"Memory: Standard KV Cache = {std_size:.1f} KB, Latent Cache = {latent_size:.1f} KB, Reduction = {std_size/latent_size:.1f}x")

if __name__ == "__main__":
    demo()

Output shape: torch.Size([1, 5, 512]), Cache shape:torch.Size([1, 5, 256])
Memory: Standard KV Cache = 20.0 KB, Latent Cache = 5.0 KB, Reduction = 4.0x


In [18]:
def demo_kv_cache_growth(num_initial_tokens=5, num_new_tokens=3):
  """Demonstrates the growth of the KV cache during autoregressive decoding."""
  print(f"--- Starting KV Cache Growth Demo ---")
  torch.manual_seed(0)

  model = MultiHeadLatent(d_model=8, n_heads=2, kv_latent_dim=4).to(device)

  # Move the input tensor to the device
  x = torch.randn(1, num_initial_tokens, 8).to(device)
  
  # The first pass has no past, so kv_cache is None and past_length is 0
  out, cache = model(x)
  print(f"Step 0: Initial prompt processed.")
  print(f"        Input shape: ({num_initial_tokens} tokens)")
  print(f"        Cache shape: {cache.shape} -> (Batch, Sequence, LatentDim)\n")

  # Steps 1..num_new_tokens: append new tokens one at a time
  for step in range(1, num_new_tokens + 1):
    # The past context is now the cache from the previous step
    past_context = cache
    
    #  Move the new token tensor to the device
    new_token = torch.randn(1, 1, 8).to(device)
    
    # Pass the new token and the existing cache to the model
    out, cache = model(new_token, kv_cache=past_context, past_length=past_context.shape[1])
    
    print(f"Step {step}: Generating one new token...")
    print(f"        Input shape: (1 token)")
    print(f"        Cache shape: {cache.shape} -> Sequence length grew by 1!\n")

if __name__ == "__main__":
    demo_kv_cache_growth(num_initial_tokens=50, num_new_tokens=4)

--- Starting KV Cache Growth Demo ---
Step 0: Initial prompt processed.
        Input shape: (50 tokens)
        Cache shape: torch.Size([1, 50, 4]) -> (Batch, Sequence, LatentDim)

Step 1: Generating one new token...
        Input shape: (1 token)
        Cache shape: torch.Size([1, 51, 4]) -> Sequence length grew by 1!

Step 2: Generating one new token...
        Input shape: (1 token)
        Cache shape: torch.Size([1, 52, 4]) -> Sequence length grew by 1!

Step 3: Generating one new token...
        Input shape: (1 token)
        Cache shape: torch.Size([1, 53, 4]) -> Sequence length grew by 1!

Step 4: Generating one new token...
        Input shape: (1 token)
        Cache shape: torch.Size([1, 54, 4]) -> Sequence length grew by 1!

