# MLA代码实现

## ROPE+RMSNorm代码初始化

In [20]:
# 改编自：https://github.com/flashinfer-ai/flashinfer/blob/738460ff82e2230ebcc8dff50e49e1d6278e011a/tests/test_mla_decode_kernel.py
from typing import Optional, Tuple

import torch
import torch.nn.functional as F
from torch import nn

def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, use_scaled: bool = False):
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    t = torch.arange(end, device=freqs.device, dtype=torch.float32)
    freqs = torch.outer(t, freqs)
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64
    return freqs_cis

def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
    ndim = x.ndim
    assert 0 <= 1 < ndim
    print('freqs_cis.shape', freqs_cis.shape)
    print('x.shape', x.shape)
    assert freqs_cis.shape == (x.shape[1], x.shape[-1])
    shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
    return freqs_cis.view(*shape)

def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
    return xq_out.type_as(xq), xk_out.type_as(xk)

class DeepseekV2RMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        """
        DeepseekV2RMSNorm is equivalent to T5LayerNorm
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return (self.weight * hidden_states).to(input_dtype)

## MLA 朴素版

In [109]:
class DeepseekV2AttentionVanilla(nn.Module):
    def __init__(self):
        super().__init__()
        # 以 deepseekv2 参数为准
        self.hidden_size = 7168
        self.num_heads = 128

        self.q_lora_rank = 1536
        self.qk_rope_head_dim = 64
        self.kv_lora_rank = 512
        self.v_head_dim = 128
        self.qk_nope_head_dim = 128
        self.q_head_dim = 192  # 192 = 128 + 64 = config.qk_nope_head_dim + config.qk_rope_head_dim

        self.rope_theta = 10000
        self.q_a_layernorm = DeepseekV2RMSNorm(self.q_lora_rank)
        self.softmax_scale = self.q_head_dim ** (-0.5)

        # W^DQ ~ [7168, 1536]
        self.W_DQ = nn.Linear(self.hidden_size, self.q_lora_rank, bias=False)
        # W^UQ ~ [1536, 128*128]
        self.W_UQ = nn.Linear(self.q_lora_rank, self.num_heads * self.qk_nope_head_dim, bias=False)
        # W^QR ~ [1536, 128*64]
        self.W_QR = nn.Linear(self.q_lora_rank, self.num_heads * self.qk_rope_head_dim, bias=False)
        # W^KR ~ [1536, 64]
        self.W_KR = nn.Linear(self.q_lora_rank, self.qk_rope_head_dim, bias=False)
        # W^DKV ~ [7168, 512]
        self.W_DKV = nn.Linear(self.hidden_size, self.kv_lora_rank, bias=False)
        # W^UK ~ [512, 128*128]
        self.W_UK = nn.Linear(self.kv_lora_rank, self.num_heads * self.qk_nope_head_dim, bias=False)
        # W^UV ~ [512, 128*128]
        self.W_UV = nn.Linear(self.kv_lora_rank, self.num_heads * self.v_head_dim, bias=False)
        # W^O ~ [128*128, 7168]
        self.W_O = nn.Linear(self.num_heads * self.v_head_dim, self.hidden_size, bias=False)


    def run_decode(
        self,
        hidden_states: torch.Tensor,
        compressed_kv_normed_cache: torch.Tensor,
        k_pe_cache: torch.Tensor,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        # 获取维度：[batch_size, query_length=1, hidden_size]
        bsz, q_len, _ = hidden_states.size()
        
        # 当前step输入的x，经过W_DQ，得到 [batch_size, 1, q_lora_rank]
        c_t_Q = self.q_a_layernorm(self.W_DQ(hidden_states))
        # 再经过W_UQ，得到 [batch_size, 1, num_heads=128 * qk_nope_head_dim=128]
        q_t_C = self.W_UQ(c_t_Q)
        # 再经过W_QR，得到 [batch_size, 1, num_heads=128 * qk_rope_head_dim=64]
        q_t_R = self.W_QR(c_t_Q).view(bsz, -1, self.num_heads, self.qk_rope_head_dim)
        # 再经过W_KR，得到 [batch_size, 1, qk_rope_head_dim=64]
        # 将当前step的k_t_R添加到k_pe_cache的最后一个位置，得到新的k_pe_cache
        k_t_R = self.W_KR(c_t_Q)
        k_pe_cache = torch.cat([k_pe_cache, k_t_R], dim=1)
        # 将最后一个维度拆开，方便注意力计算
        q_t_C = q_t_C.view(bsz, q_len, self.num_heads, self.qk_nope_head_dim).transpose(1, 2)

        c_t_KV = self.W_DKV(hidden_states)
        compressed_kv_normed_cache = torch.cat([compressed_kv_normed_cache, c_t_KV], dim=1)
        k_C = self.W_UK(compressed_kv_normed_cache).view(bsz, -1, self.num_heads, self.qk_nope_head_dim).transpose(1, 2)
        v_C = self.W_UV(compressed_kv_normed_cache).view(bsz, -1, self.num_heads, self.v_head_dim).transpose(1, 2)
        
        # 计算位置编码，暂时不用管，与其他的RoPE计算方式类似，最终得到旋转之后的 q_pe, k_pe
        freqs_cis = precompute_freqs_cis(self.qk_rope_head_dim, compressed_kv_normed_cache.shape[1], self.rope_theta, use_scaled=False).to(q_t_R.device)
        q_t_R, k_R = apply_rotary_emb(
            q_t_R.repeat(1, compressed_kv_normed_cache.shape[1], 1, 1),
            k_pe_cache.unsqueeze(2),
            freqs_cis,
        )
        q_t_R = q_t_R[:, -1:, :, :].transpose(1, 2)
        k_R = k_R.transpose(1, 2).repeat(1, self.num_heads, 1, 1)

        attn_R = torch.matmul(q_t_R, k_R.transpose(2, 3))
        attn_C = torch.matmul(q_t_C, k_C.transpose(2, 3))

        attn_weights = (attn_R + attn_C) * self.softmax_scale
        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q_t_C.dtype)

        # 将注意力权重和v相乘，得到注意力输出，维度为[batch_size, num_heads, q_len, v_head_dim=128]
        attn_output = torch.matmul(attn_weights, v_C)

        # 将最后一个维度展开，得到[batch_size, num_heads, q_len, v_head_dim=128]
        attn_output = attn_output.transpose(1, 2).reshape(
            bsz, q_len, self.num_heads * self.v_head_dim
        )

        # 将注意力输出和W^O相乘，得到最终的输出，维度为[batch_size, q_len, hidden_size=7168]
        output = self.W_O(attn_output)

        return output, attn_weights, compressed_kv_normed_cache, k_pe_cache

mla_vanilla = DeepseekV2AttentionVanilla()

batch_size = 6
kv_len = 640

hidden_states = torch.randn([batch_size, 1, mla_vanilla.hidden_size])
compressed_kv_normed_cache = torch.randn([batch_size, kv_len, mla_vanilla.kv_lora_rank])
k_pe_cache = torch.randn([batch_size, kv_len, mla_vanilla.qk_rope_head_dim])

import time
start_time = time.time()
for i in range(100):
    output_vanilla, attn_weights, compressed_kv_normed_cache, k_pe_cache = mla_vanilla.run_decode(
        hidden_states, compressed_kv_normed_cache, k_pe_cache
    )
    # print('output_vanilla.shape', output_vanilla.shape)
    # print('attn_weights.shape', attn_weights.shape)
    # print('compressed_kv_normed_cache.shape', compressed_kv_normed_cache.shape)
    # print('k_pe_cache.shape', k_pe_cache.shape)
    # print('-'*70)
end_time = time.time()
print('time', end_time - start_time)

time 55.68700408935547


In [2]:
from transformers import Qwen2ForCausalLM




## MLA 吸收矩阵版

In [108]:
from torch import nn
class DeepseekV2AttentionMatAbsorbDecode(nn.Module):
    def __init__(self):
        super().__init__()
        self.hidden_size = 7168
        self.num_heads = 128

        self.q_lora_rank = 1536
        self.qk_rope_head_dim = 64
        self.kv_lora_rank = 512
        self.v_head_dim = 128
        self.qk_nope_head_dim = 128
        self.q_head_dim = 192  # 192 = 128 + 64 = config.qk_nope_head_dim + config.qk_rope_head_dim

        self.rope_theta = 10000
        self.q_a_layernorm = DeepseekV2RMSNorm(self.q_lora_rank)
        self.softmax_scale = self.q_head_dim ** (-0.5)

        # W^DQ ~ [7168, 1536]
        self.W_DQ = nn.Linear(self.hidden_size, self.q_lora_rank, bias=False)
        # W^UQ ~ [1536, 128*128]
        self.W_UQ = nn.Linear(self.q_lora_rank, self.num_heads * self.qk_nope_head_dim, bias=False)
        # W^QR ~ [1536, 128*64]
        self.W_QR = nn.Linear(self.q_lora_rank, self.num_heads * self.qk_rope_head_dim, bias=False)
        # W^KR ~ [1536, 64]
        self.W_KR = nn.Linear(self.q_lora_rank, self.qk_rope_head_dim, bias=False)
        # W^DKV ~ [7168, 512]
        self.W_DKV = nn.Linear(self.hidden_size, self.kv_lora_rank, bias=False)
        # W^UK ~ [512, 128*128]
        self.W_UK = nn.Linear(self.kv_lora_rank, self.num_heads * self.qk_nope_head_dim, bias=False)
        # W^UV ~ [512, 128*128]
        self.W_UV = nn.Linear(self.kv_lora_rank, self.num_heads * self.v_head_dim, bias=False)
        # W^O ~ [128*128, 7168]
        self.W_O = nn.Linear(self.num_heads * self.v_head_dim, self.hidden_size, bias=False)

        # 由于nn.Linear初始化的时一个对象，权重矩阵只是对象中的一个类，没法直接两个矩阵相乘
        # 所以需要用.weight来取出来，而且因为.weight的维度与初始化是反的，所以需要用t()来转置
        # W_UQ_absorb ~ [1536, 128, 128]
        W_UQ_absorb = self.W_UQ.weight.t().view(self.q_lora_rank, self.num_heads, self.qk_nope_head_dim)
        # W_UK_absorb ~ [512, 128, 128]
        W_UK_absorb = self.W_UK.weight.t().view(self.kv_lora_rank, self.num_heads, self.qk_nope_head_dim)
        # W_UV_absorb ~ [512, 128, 128]
        W_UV_absorb = self.W_UV.weight.t().view(self.kv_lora_rank, self.num_heads, self.v_head_dim)
        # W_O_absorb ~ [7168, 128, 128]
        W_O_absorb = self.W_O.weight.view(self.hidden_size, self.num_heads, self.v_head_dim)

        # 吸收矩阵：将W_UQ和W_UK合并，得到新的W_UQK，维度为[1536, 128, 128]
        # q~q_lora_rank  n~num_heads  d~qk_nope_head_dim  l~kv_lora_rank
        # 这里把n当做batch_size，也就是矩阵相乘不会影响的那个维度，矩阵qd与dl相乘，得到ql，加上刚才的n，所以得到qnl
        # 再将其flatten展平，得到[1536, 65536]
        self.W_UQK = torch.einsum("q n d, l n d -> q n l", W_UQ_absorb, W_UK_absorb).flatten(start_dim=1)
        # 吸收矩阵，将W_UV和W_O合并，得到新的W_UV_O，维度为[128, 512, 7168]
        # l~kv_lora_rank  n~num_heads  d~v_head_dim  h~hidden_size
        # 这里把n当做batch_size，也就是矩阵相乘不会影响的那个维度，矩阵ld与dh相乘，得到lh，加上刚才的n，并把n放到最前面，所以得到nlh
        # 再将其flatten展平，得到[65536, 7168]
        self.W_UV_O = torch.einsum("l n d, h n d -> n l h", W_UV_absorb, W_O_absorb).flatten(start_dim=0, end_dim=1)

    def run_decode(
        self,
        hidden_states: torch.Tensor,
        compressed_kv_normed_cache: torch.Tensor,
        k_pe_cache: torch.Tensor,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        # 获取维度：[batch_size, query_length=1, hidden_size]
        bsz, q_len, _ = hidden_states.size()

        # 当前step输入的x，经过W_DQ，得到 [batch_size, 1, q_lora_rank]
        c_t_Q = self.q_a_layernorm(self.W_DQ(hidden_states))
        # 再经过W_UQ，得到 [batch_size, 1, num_heads=128 * qk_nope_head_dim=128]
        q_t_C = torch.matmul(c_t_Q, self.W_UQK)
        # 再经过W_QR，得到 [batch_size, 1, num_heads=128 * qk_rope_head_dim=64]
        q_t_R = self.W_QR(c_t_Q).view(bsz, -1, self.num_heads, self.qk_rope_head_dim)
        # 再经过W_KR，得到 [batch_size, 1, qk_rope_head_dim=64]
        # 将当前step的k_t_R添加到k_pe_cache的最后一个位置，得到新的k_pe_cache
        k_t_R = self.W_KR(c_t_Q)
        k_pe_cache = torch.cat([k_pe_cache, k_t_R], dim=1)
        # 将最后一个维度拆开，方便注意力计算
        q_t_C = q_t_C.view(bsz, q_len, self.num_heads, self.kv_lora_rank).transpose(1, 2)

        c_t_KV = self.W_DKV(hidden_states)
        compressed_kv_normed_cache = torch.cat([compressed_kv_normed_cache, c_t_KV], dim=1)
        
        # 计算位置编码，暂时不用管，与其他的RoPE计算方式类似，最终得到旋转之后的 q_pe, k_pe
        freqs_cis = precompute_freqs_cis(self.qk_rope_head_dim, compressed_kv_normed_cache.shape[1], self.rope_theta, use_scaled=False).to(q_t_R.device)
        q_t_R, k_R = apply_rotary_emb(
            q_t_R.repeat(1, compressed_kv_normed_cache.shape[1], 1, 1),
            k_pe_cache.unsqueeze(2),
            freqs_cis,
        )
        q_t_R = q_t_R[:, -1:, :, :].transpose(1, 2)
        k_R = k_R.transpose(1, 2).repeat(1, self.num_heads, 1, 1)

        attn_R = torch.matmul(q_t_R, k_R.transpose(2, 3))
        attn_C = torch.matmul(q_t_C, compressed_kv_normed_cache.unsqueeze(1).transpose(2, 3))

        attn_weights = (attn_R + attn_C) * self.softmax_scale
        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q_t_C.dtype)

        # attn_weights * c^KV * W^UVO
        attn_output = torch.matmul(
            attn_weights.squeeze(2),  # [bsz, 128, kv_len]
            compressed_kv_normed_cache,  # [bsz, kv_len, 512]
        ).reshape(bsz, self.num_heads * self.kv_lora_rank)
        output = torch.matmul(attn_output, self.W_UV_O,)  # W_UV_O ~ [65536, 7168]

        return output, attn_weights, compressed_kv_normed_cache, k_pe_cache
        

bsz = 6
kv_len = 640
page_size = 16

hidden_states = torch.randn([bsz, 1, 7168])
compressed_kv_normed_cache = torch.randn([bsz, kv_len, 512])
k_pe_cache = torch.randn([bsz, kv_len, 64])

mla_mat_absorb = DeepseekV2AttentionMatAbsorbDecode()
import time
start_time = time.time()
for i in range(100):
    output_vanilla, attn_weights, compressed_kv_normed_cache, k_pe_cache = mla_mat_absorb.run_decode(
        hidden_states, compressed_kv_normed_cache, k_pe_cache
    )
    # print('output_vanilla.shape', output_vanilla.shape)
    # print('attn_weights.shape', attn_weights.shape)
    # print('compressed_kv_normed_cache.shape', compressed_kv_normed_cache.shape)
    # print('k_pe_cache.shape', k_pe_cache.shape)
    # print('-'*70)
end_time = time.time()
print('time', end_time - start_time)

time 27.271135807037354


## 计算量对比

In [140]:
# 实际吸收后的矩阵更大了，计算量更多了，但由于n的存在，序列越长，总体计算量越小
n=20000
W_UQK = 1536*128*128 + 512*128*128*n + 128*128*n
W_UQK_absorbed = 1536*128*512 + 128*512*n
W_UV_O = 512*128*128*n + 128*128*n + 128*128*7168
W_UV_O_absorbed = 128*512*7168 + 128*512*n
print('W_UQK吸收前：', W_UQK)
print('W_UQK吸收后：', W_UQK_absorbed)
print('W_UV_O吸收前：', W_UV_O)
print('W_UV_O吸收后：', W_UV_O_absorbed)
print('全部吸收前：', W_UQK + W_UV_O)
print('全部吸收后：', W_UQK_absorbed + W_UV_O_absorbed)


W_UQK吸收前： 168125005824
W_UQK吸收后： 1411383296
W_UV_O吸收前： 168217280512
W_UV_O吸收后： 1780482048
全部吸收前： 336342286336
全部吸收后： 3191865344
