![mla](./img/MLA-.png)

# MLA (带低秩压缩)

## 公式：
$$
Q = HW_Q^{down}W_Q^{up}, \quad K, V = HW_{KV}^{down}W_{KV}^{up}
$$

- Query 压缩到 $r_q$，再升到 $(d_{q,nope} + d_{q,rope})$
- Key 与 Value 共享同一个压缩向量：压缩到 $r_k$，再升到每个头的 $(d_{q,nope} + d_v)$，拆分为 $K_{nope}$ 和 $V$
- RoPE 部分 $d_{rope}$ 单独保留：Key 的 RoPE 分量直接取自 $W_{KV}^{down}$ 的输出（不经过升维），并在所有头之间共享

```
输入 H: [B, L, d]
   │
   ├─ W_Q^down: [d, r_q]
   │     → [B, L, r_q]
   ├─ W_Q^up: [r_q, n_h·(d_q_nope + d_q_rope)]
   │     → Q: [B, n_h, L, d_q_nope + d_q_rope]
   │             split → Q_nope[d_q_nope], Q_pe[d_q_rope]
   │
   ├─ W_KV^down: [d, r_k + d_rope]
   │     → [B, L, r_k + d_rope]
   │             split → K_comp[r_k], K_pe[d_rope]
   ├─ W_KV^up: [r_k, n_h·(d_q_nope + d_v)]
   │     → [B, n_h, L, d_q_nope + d_v]
   │             split → K_nope[d_q_nope], V[d_v]
   │
   └─ 注意力: 
         Q_pe ⨉ K_pe^T   (positional)
       + Q_nope ⨉ K_nope^T (content; K_nope 由 W_KV^up 从 K_comp 升维得到)

```

In [1]:
import torch 
import torch.nn as nn 
import torch.nn.functional as F
from typing import Optional, Tuple
import math
from dataclasses import dataclass

In [2]:
# RMS标准化和RoPE旋转位置编码
class DeepseekV2RMSNorm(nn.Module):
    """RMS normalization over the last dimension with a learned per-channel gain.

    Computes ``weight * x / sqrt(mean(x ** 2, dim=-1) + eps)``. The statistics
    are computed in float32; the normalized tensor is cast back to the input
    dtype before being multiplied by ``weight``.

    Args:
        hidden_size: size of the last (normalized) dimension.
        eps: small constant added to the mean square for numerical stability.
    """

    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))  # learned gain, initialised to 1
        self.variance_epsilon = eps  # floor inside the square root

    def forward(self, hidden_states):
        original_dtype = hidden_states.dtype
        x_fp32 = hidden_states.to(torch.float32)
        mean_square = x_fp32.pow(2).mean(-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.variance_epsilon)
        normalized = (x_fp32 * inv_rms).to(original_dtype)
        return self.weight * normalized
    
class DeepseekV2RotaryEmbedding(nn.Module):
    """Precomputed cos/sin tables for rotary position embedding (RoPE).

    ``forward(x, seq_len)`` returns ``(cos, sin)``, each of shape
    ``[seq_len, dim]`` and cast to ``x``'s dtype and device. The cached table is
    rebuilt (grown) only when a longer sequence than the cached one is requested.

    Args:
        dim: number of rotated channels; ``inv_freq`` has ``ceil(dim / 2)`` entries.
        max_position_embeddings: number of positions in the table built at construction.
        base: RoPE base (theta).
        device: device for ``inv_freq`` and the initial table.
    """

    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()

        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        # inv_freq[i] = base^(-2i/dim): index 0 has the highest frequency (1.0) and
        # the frequency decreases as the index grows.
        inv_freq = 1.0 / (
            self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)
        )
        self.register_buffer("inv_freq", inv_freq, persistent=False)

        # Build here to make `torch.jit.trace` work. `_set_cos_sin_cache` records
        # `max_seq_len_cached`, so forward() reuses this table instead of
        # discarding it and rebuilding on the first call.
        self._set_cos_sin_cache(
            seq_len=max_position_embeddings,
            device=self.inv_freq.device,
            dtype=torch.get_default_dtype(),
        )

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        """(Re)build the cos/sin tables for positions ``0 .. seq_len - 1``."""
        self.max_seq_len_cached = seq_len
        t = torch.arange(
            self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype
        )

        freqs = torch.outer(t, self.inv_freq.to(t.device))  # [seq_len, len(inv_freq)]
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)

    def forward(self, x, seq_len=None):
        """Return ``(cos, sin)`` for the first ``seq_len`` positions.

        Args:
            x: reference tensor providing dtype/device; the comment below gives the
                expected layout [bs, num_attention_heads, seq_len, head_size].
            seq_len: number of positions; defaults to ``x.shape[-2]``.
        """
        # x: [bs, num_attention_heads, seq_len, head_size]
        if seq_len is None:
            seq_len = x.shape[-2]
        if self.max_seq_len_cached is None or seq_len > self.max_seq_len_cached:
            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)

        return (
            self.cos_cached[:seq_len].to(device=x.device, dtype=x.dtype),
            self.sin_cached[:seq_len].to(device=x.device, dtype=x.dtype),
        )

# Copied from transformers.models.llama.modeling_llama.rotate_half
def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
    cos = cos[position_ids].unsqueeze(unsqueeze_dim)
    sin = sin[position_ids].unsqueeze(unsqueeze_dim)

    b, h, s, d = q.shape
    q = q.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)

    b, h, s, d = k.shape
    k = k.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)

    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed

In [3]:
# 无矩阵吸收版本MLA
@dataclass
class DeepseekConfig:
    """Hyper-parameters consumed by the MLA module.

    Field order is the positional-argument order of the generated ``__init__``.
    ``__post_init__`` rejects values that MLA cannot build or run with.
    """

    hidden_size: int  # model width d
    num_heads: int  # number of attention heads H
    max_position_embeddings: int  # positions in the RoPE table built at init
    rope_theta: float  # RoPE base
    attention_dropout: float  # dropout probability applied to attention weights

    q_lora_rank: int  # r_q: width of the compressed query latent
    qk_rope_head_dim: int  # per-head width of the RoPE part of q/k (shared k part)
    kv_lora_rank: int  # r_kv: width of the compressed key/value latent
    v_head_dim: int  # per-head value width d_v
    qk_nope_head_dim: int  # per-head width of the position-free part of q/k
    attention_bias: bool  # bias on the down-projections and output projection

    def __post_init__(self):
        """Validate dimensions early instead of failing deep inside forward()."""
        for name in (
            "hidden_size",
            "num_heads",
            "max_position_embeddings",
            "q_lora_rank",
            "qk_rope_head_dim",
            "kv_lora_rank",
            "v_head_dim",
        ):
            value = getattr(self, name)
            if value <= 0:
                raise ValueError(f"{name} must be positive, got {value}")
        if self.qk_nope_head_dim < 0:
            raise ValueError(
                f"qk_nope_head_dim must be non-negative, got {self.qk_nope_head_dim}"
            )
        if self.rope_theta <= 0:
            raise ValueError(f"rope_theta must be positive, got {self.rope_theta}")
        # RoPE rotates channel pairs (apply_rotary_pos_emb views the last dim as
        # (d // 2, 2)), so the rotary width must be even.
        if self.qk_rope_head_dim % 2:
            raise ValueError(
                f"qk_rope_head_dim must be even, got {self.qk_rope_head_dim}"
            )
        if not 0.0 <= self.attention_dropout <= 1.0:
            raise ValueError(
                f"attention_dropout must be in [0, 1], got {self.attention_dropout}"
            )



In [4]:
class MLA(nn.Module):
    """Multi-head Latent Attention (MLA), without weight absorption.

    Queries and keys/values are projected down to low-rank latents
    (``q_lora_rank`` / ``kv_lora_rank``), RMS-normalized, and projected back up
    per head. Each query/key head is split into a position-free ("nope") part
    and a RoPE part. The key RoPE part is taken directly from the KV
    down-projection as a single head that is broadcast to all query heads.

    Args:
        config: object providing the ``DeepseekConfig`` fields read below.
        verbose: keyword-only; when True (default, as in the notebook demo),
            ``forward`` prints the shape of each intermediate tensor.
    """

    def __init__(self, config, *, verbose=True):
        super().__init__()
        self.verbose = verbose

        self.attention_dropout = config.attention_dropout
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_heads
        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = config.rope_theta
        # Backward-compatible alias for the attribute's former misspelled name.
        self.rpoe_theta = self.rope_theta

        # Rank of the compressed query latent (r_q).
        self.q_lora_rank = config.q_lora_rank
        # Per-head width of the q/k part that receives RoPE.
        self.qk_rope_head_dim = config.qk_rope_head_dim

        # Rank of the compressed key/value latent (r_kv) and per-head value width.
        self.kv_lora_rank = config.kv_lora_rank
        self.v_head_dim = config.v_head_dim

        # q/k heads split into a position-free part and a RoPE part; the shared key
        # RoPE part k^R_t has the same width d^R_h as each query's q^R_{t,i}.
        self.qk_nope_head_dim = config.qk_nope_head_dim
        self.q_head_dim = config.qk_nope_head_dim + config.qk_rope_head_dim

        # Query path: d -> r_q -> H * (d_nope + d_rope).
        self.q_down_proj = nn.Linear(self.hidden_size,
                                     self.q_lora_rank,
                                     bias=config.attention_bias)
        self.q_down_layernorm = DeepseekV2RMSNorm(self.q_lora_rank)
        self.q_up_proj = nn.Linear(self.q_lora_rank,
                                   self.num_heads * self.q_head_dim,
                                   bias=False)

        # KV path: d -> r_kv + d_rope. The trailing d_rope columns are the shared
        # RoPE key; the r_kv latent is up-projected to H * (d_nope + d_v).
        self.kv_down_proj = nn.Linear(self.hidden_size,
                                      self.kv_lora_rank + self.qk_rope_head_dim,
                                      bias=config.attention_bias)
        self.kv_down_layernorm = DeepseekV2RMSNorm(self.kv_lora_rank)
        self.kv_up_proj = nn.Linear(self.kv_lora_rank,
                                    self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
                                    bias=False)

        self.o_proj = nn.Linear(
            self.num_heads * self.v_head_dim,
            self.hidden_size,
            bias=config.attention_bias,
        )

        self.rotary_emb = DeepseekV2RotaryEmbedding(
            self.qk_rope_head_dim,
            self.max_position_embeddings,
            self.rope_theta,
        )

    def _log(self, *args):
        """Print intermediate shapes when ``self.verbose`` is set."""
        if self.verbose:
            print(*args)

    def forward(
            self,
            hidden_state: torch.Tensor,
            attention_mask: Optional[torch.Tensor] = None,
            position_ids: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """MLA (Multi-head Latent Attention) forward pass.

        Notation: B batch, T sequence length, d hidden size, H heads,
        d_q = d_q^{nope} + d_q^{rope}, d_v value head width.

        Args:
            hidden_state: [B, T, d].
            attention_mask: optional tensor broadcastable to [B, H, T, T]; score
                positions where it equals 0 are masked out. No causal mask is
                applied unless the caller supplies one.
            position_ids: optional [B, T] integer positions used to index the RoPE
                table; defaults to 0..T-1 for every batch row.

        Returns:
            ``(attn_output, attn_weights)``: [B, T, d] and the post-softmax,
            post-dropout weights [B, H, T, T].
        """
        batch_size, seq_len, _ = hidden_state.size()  # [B, T, d]

        if position_ids is None:
            position_ids = torch.arange(seq_len, device=hidden_state.device)
            position_ids = position_ids.unsqueeze(0).expand(batch_size, -1)

        # 1. Query: compress, normalize, expand, then split heads into nope/rope parts.
        q = self.q_down_proj(hidden_state)    # [B, T, d] -> [B, T, r_q]
        self._log("q_down_proj:", q.shape)

        q = self.q_down_layernorm(q)          # shape unchanged [B, T, r_q]
        q = self.q_up_proj(q)                 # [B, T, r_q] -> [B, T, H * d_q]
        self._log("q_up_proj:", q.shape)

        # [B, T, H * d_q] -> [B, H, T, d_q]
        q = q.view(batch_size, seq_len, self.num_heads, self.q_head_dim).transpose(1, 2)
        self._log("q_reshape:", q.shape)

        # q_nope: [B, H, T, d_q^{nope}], q_pe: [B, H, T, d_q^{rope}]
        q_nope, q_pe = torch.split(
            q,
            [self.qk_nope_head_dim, self.qk_rope_head_dim],
            dim=-1,
        )
        self._log("q_nope:", q_nope.shape, "q_pe:", q_pe.shape)

        # 2. Key/value: compress (plus the shared RoPE key), normalize, expand.
        # [B, T, d] -> [B, T, r_kv + d_q^{rope}]
        compressed_kv = self.kv_down_proj(hidden_state)
        self._log("kv_compressed with rope:", compressed_kv.shape)

        # compressed_kv: [B, T, r_kv], k_pe: [B, T, d_q^{rope}]
        compressed_kv, k_pe = torch.split(compressed_kv,
                                          [self.kv_lora_rank, self.qk_rope_head_dim],
                                          dim=-1)
        self._log("compressed_kv:", compressed_kv.shape, "k_pe:", k_pe.shape)

        # One RoPE key head, later broadcast to all H heads: [B, 1, T, d_q^{rope}]
        k_pe = k_pe.view(batch_size, seq_len, 1, self.qk_rope_head_dim).transpose(1, 2)
        self._log("reshape k_pe:", k_pe.shape)

        # [B, T, r_kv] -> [B, T, H * (d_q^{nope} + d_v)]
        kv = self.kv_down_layernorm(compressed_kv)
        self._log("kv_down_layernorm:", kv.shape)
        kv = self.kv_up_proj(kv)
        self._log("kv_up:", kv.shape)

        # -> [B, H, T, d_q^{nope} + d_v]
        kv = kv.view(batch_size, seq_len, self.num_heads,
                     self.qk_nope_head_dim + self.v_head_dim).transpose(1, 2)
        self._log("kv_shape:", kv.shape)

        # k_nope: [B, H, T, d_q^{nope}], value_state: [B, H, T, d_v]
        k_nope, value_state = torch.split(
            kv,
            [self.qk_nope_head_dim, self.v_head_dim],
            dim=-1,
        )
        self._log("k_nope:", k_nope.shape, "value_states:", value_state.shape)

        # 3. Rotary position embedding on the rope parts only.
        kv_seq_len = value_state.shape[-2]  # T
        cos, sin = self.rotary_emb(value_state, seq_len=kv_seq_len)
        # q_pe: [B, H, T, d_q^{rope}], k_pe: [B, 1, T, d_q^{rope}]
        q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids)

        # 4. Concatenate nope and rope parts. torch.cat keeps the activations' dtype
        # and device (an uninitialised float32 buffer would break fp16/bf16 models),
        # and the single RoPE key head is expanded to every head.
        query_states = torch.cat([q_nope, q_pe], dim=-1)  # [B, H, T, d_q]
        key_states = torch.cat(
            [k_nope, k_pe.expand(-1, self.num_heads, -1, -1)], dim=-1
        )  # [B, H, T, d_q]
        self._log("final query_states:", query_states.shape)
        self._log("final key_states:", key_states.shape)

        # 5. Scaled dot-product scores: [B, H, T, d_q] x [B, H, d_q, T] -> [B, H, T, T]
        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3))
        attn_weights = attn_weights / math.sqrt(self.q_head_dim)

        if attention_mask is not None:
            # Use the dtype's most negative finite value rather than -inf so a
            # fully masked row yields a uniform distribution instead of NaN.
            attn_weights = attn_weights.masked_fill(
                attention_mask == 0,
                torch.finfo(attn_weights.dtype).min,
            )
        self._log("attn_weights:", attn_weights.shape)

        # 6. Softmax (computed in float32) and dropout.
        attn_weights = F.softmax(
            attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
        attn_weights = F.dropout(
            attn_weights, p=self.attention_dropout, training=self.training)

        # 7. Weighted sum of values and output projection.
        attn_output = torch.matmul(attn_weights, value_state)  # [B, H, T, d_v]
        attn_output = attn_output.transpose(1, 2).reshape(batch_size, seq_len, -1)
        attn_output = self.o_proj(attn_output)  # [B, T, d]

        return attn_output, attn_weights


In [5]:
def test_mla(batch_size=2, seq_len=1024):
    """Smoke-test MLA on random input and check the output shapes.

    Uses DeepSeek-V2-sized projections (hidden 7168, r_q 1536, r_kv 512,
    16 heads).

    Args:
        batch_size: number of sequences in the random input.
        seq_len: sequence length of the random input. It may differ from
            ``max_position_embeddings``.
    """
    config = DeepseekConfig(
        hidden_size=7168,
        num_heads=16,
        max_position_embeddings=1024,  # test context length (DeepSeek supports 8192 / 32768 ...)
        rope_theta=128000,
        attention_dropout=0.1,
        q_lora_rank=1536,
        qk_rope_head_dim=64,
        kv_lora_rank=512,

        v_head_dim=128,
        qk_nope_head_dim=128,
        attention_bias=False,
    )

    mla = MLA(config)
    mla.eval()  # disable dropout: this is a shape check, not a training step
    x = torch.randn(batch_size, seq_len, config.hidden_size)
    # Positions must follow the input's sequence length, not the config maximum.
    position_ids = torch.arange(seq_len).unsqueeze(0).expand(batch_size, -1)
    with torch.no_grad():  # no autograd graph -> much lower peak memory
        attn_output, attn_weights = mla(x, position_ids=position_ids)
    print("attention outputs: ", attn_output.shape)

    assert attn_output.shape == (batch_size, seq_len, config.hidden_size)
    assert attn_weights.shape == (batch_size, config.num_heads, seq_len, seq_len)


test_mla()

q_down_proj: torch.Size([2, 1024, 1536])
q_up_proj: torch.Size([2, 1024, 3072])
q_reshape: torch.Size([2, 16, 1024, 192])
q_nope: torch.Size([2, 16, 1024, 128]) q_pe: torch.Size([2, 16, 1024, 64])
kv_compressed with rope: torch.Size([2, 1024, 576])
compressed_kv: torch.Size([2, 1024, 512]) k_pe: torch.Size([2, 1024, 64])
reshape k_pe: torch.Size([2, 1, 1024, 64])
kv_down_layernorm: torch.Size([2, 1024, 512])
kv_up: torch.Size([2, 1024, 4096])
kv_shape: torch.Size([2, 16, 1024, 256])
k_nope: torch.Size([2, 16, 1024, 128]) value_states: torch.Size([2, 16, 1024, 128])
final query_states: torch.Size([2, 16, 1024, 192])
final key_states: torch.Size([2, 16, 1024, 192])
attn_weights: torch.Size([2, 16, 1024, 1024])
attention outputs:  torch.Size([2, 1024, 7168])
