## MLA


In [9]:
import torch 
from dataclasses import dataclass
import torch.nn as nn
import torch.nn.functional as F

In [10]:
@dataclass
class MLAconfig:
    """Hyper-parameters shared by the attention modules in this file.

    The field order below is significant: the generated ``__init__`` accepts
    the values positionally in exactly this order.
    """

    # --- generic multi-head-attention settings ---
    hidden_dim: int  # width of the hidden states fed to / produced by attention
    num_heads: int  # number of attention heads
    max_position_embeddings: int  # number of positions pre-computed in the rotary cache
    rope_theta: float  # base of the rotary frequency schedule
    attention_dropout: float  # attention dropout probability (read by MLA)

    # --- MLA (low-rank latent attention) settings ---
    q_lora_dim: int  # bottleneck width of the query down/up projection
    kv_lora_dim: int  # bottleneck width of the shared key/value projection
    v_head_dim: int  # per-head width of the values
    qk_nope_head_dim: int  # per-head query/key width that gets no rotary embedding
    qk_rope_head_dim: int  # per-head query/key width that gets the rotary embedding
    attention_bias: bool  # not read by any module visible in this file
    
   


    

In [11]:
class DeepseekV2RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last axis with a learnable gain.

    Each vector along the last dimension is divided by
    ``sqrt(mean(x ** 2) + eps)`` and then multiplied element-wise by
    ``weight`` (initialised to ones). The arithmetic is done in float32 and
    the result is cast back to the dtype of the input.
    """

    def __init__(self, hidden_size: int, eps=1e-6):
        """Create the gain vector.

        Args:
            hidden_size: Size of the last (normalised) dimension.
            eps: Constant added to the mean square so the divisor is never 0.
        """
        super().__init__()
        self.eps = eps
        gain = torch.ones(int(hidden_size))
        self.weight = nn.Parameter(gain)

    def forward(self, hidden_states):
        """Return ``hidden_states`` RMS-normalised along its last dimension."""
        original_dtype = hidden_states.dtype
        # Up-cast first so the squared mean is accumulated in float32.
        x = hidden_states.to(torch.float32)
        mean_square = x.pow(2).mean(dim=-1, keepdim=True)
        rms = torch.sqrt(mean_square + self.eps)
        scaled = (x / rms) * self.weight
        return scaled.to(original_dtype)

    


class DeepseekV2RotaryEmbedding(nn.Module):
    """Cached cos/sin lookup tables for rotary position embeddings (RoPE).

    The tables have shape ``(cached_len, rotary_dim)`` where ``rotary_dim`` is
    ``config.qk_rope_head_dim``: in MLA only that slice of every query/key
    head is rotated, so the tables must have exactly that width to broadcast
    against it.
    """

    def __init__(self, config: "MLAconfig", max_position_embeddings=2048, base=10000, device=None):
        """Build the inverse-frequency vector and pre-compute the tables.

        Args:
            config: Object exposing ``qk_rope_head_dim``.
            max_position_embeddings: Number of positions to pre-compute.
            base: Base of the geometric frequency schedule.
            device: Device on which the frequencies are created.
        """
        super().__init__()
        # Width of the rotated slice. The attribute keeps its historical name.
        self.head_dim = config.qk_rope_head_dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base

        # inv_freq[i] = base ** (-2i / dim): channel 0 rotates fastest
        # (frequency 1) and, for base > 1, the frequency decays geometrically
        # as the channel index grows.
        exponents = torch.arange(0, self.head_dim, 2).float().to(device) / self.head_dim
        inv_freq = 1.0 / (self.base ** exponents)
        self.register_buffer("inv_freq", inv_freq, persistent=False)

        self.max_seq_len_cached = None
        self._set_cos_sin_cache(
            seq_len=max_position_embeddings,
            device=self.inv_freq.device,
            dtype=torch.get_default_dtype(),
        )

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        """(Re)build ``cos_cached`` / ``sin_cached`` for ``seq_len`` positions."""
        self.max_seq_len_cached = seq_len
        positions = torch.arange(
            self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype
        )
        # Move the frequencies next to the positions so the outer product
        # also works when the cache is rebuilt for an input on another device.
        inv_freq = self.inv_freq.to(device=device, dtype=positions.dtype)
        freq = torch.outer(positions, inv_freq)
        # Duplicate the angles so that channel i and channel i + dim/2 share a
        # rotation angle, which is the layout `rotate_half` expects.
        emb = torch.cat((freq, freq), dim=-1)
        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)

    def forward(self, x, seq_len=None):
        """Return the ``(cos, sin)`` tables for the first ``seq_len`` positions.

        Args:
            x: Tensor that supplies the target dtype and device; when
                ``seq_len`` is omitted its second-to-last dimension is used
                as the sequence length.
            seq_len: Number of positions required.

        Returns:
            Tuple ``(cos, sin)``, each of shape ``(seq_len, rotary_dim)`` and
            of dtype ``x.dtype``.
        """
        if seq_len is None:
            seq_len = x.shape[-2]
        # Grow the cache on demand when a longer sequence shows up.
        if self.max_seq_len_cached is None or seq_len > self.max_seq_len_cached:
            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)

        return (
            self.cos_cached[:seq_len].to(dtype=x.dtype),
            self.sin_cached[:seq_len].to(dtype=x.dtype),
        )
    
def rotate_half(x):
    """Swap the two halves of the last dimension, negating the second one.

    For ``x = [a, b]`` (split at ``x.shape[-1] // 2``) the result is
    ``[-b, a]``.
    """
    midpoint = x.shape[-1] // 2
    front, back = x[..., :midpoint], x[..., midpoint:]
    return torch.cat((-back, front), dim=-1)
    
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
    """Rotate ``q`` and ``k`` with the rotary tables selected by ``position_ids``.

    Args:
        q: 4-D tensor, unpacked here as ``(batch, heads, seq_len, dim)``.
        k: 4-D tensor with the same layout; its head axis may be of size 1.
        cos: Cosine table indexed along its first axis by ``position_ids``.
        sin: Sine table indexed the same way.
        position_ids: Integer index tensor. The caller in this file passes
            shape ``(batch, seq_len)``, which makes the gathered tables
            ``(batch, seq_len, dim)``.
        unsqueeze_dim: Axis inserted into the gathered tables so they
            broadcast over the head axis of ``q`` and ``k``.

    Returns:
        Tuple ``(q_embed, k_embed)`` with the same shapes as ``q`` and ``k``.
    """

    def _pairs_to_halves(t):
        # Reorder the last axis from interleaved pairs (x0, y0, x1, y1, ...)
        # to two contiguous halves (x0, x1, ..., y0, y1, ...), the layout
        # `rotate_half` operates on.
        b, h, s, d = t.shape
        return t.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)

    cos = cos[position_ids].unsqueeze(unsqueeze_dim)
    sin = sin[position_ids].unsqueeze(unsqueeze_dim)

    q = _pairs_to_halves(q)
    k = _pairs_to_halves(k)

    q_embed = q * cos + rotate_half(q) * sin
    k_embed = k * cos + rotate_half(k) * sin
    return q_embed, k_embed
            


In [12]:
class MultiHeadAttention(nn.Module):
    """Skeleton of a multi-head attention layer: only the output projection.

    No ``forward`` is defined; the class merely owns ``out_proj``, which maps
    the concatenated per-head values (``num_heads * v_head_dim``) back to
    ``hidden_dim``.
    """

    def __init__(self, config: "MLAconfig"):
        """Store the sizes and create the bias-free output projection.

        Args:
            config: Object exposing ``hidden_dim``, ``num_heads`` and
                ``v_head_dim``.
        """
        super().__init__()
        self.hidden_dim = config.hidden_dim
        self.num_heads = config.num_heads
        # Must be stored under the same name that is read just below.
        self.v_head_dim = config.v_head_dim
        self.out_proj = nn.Linear(
            self.num_heads * self.v_head_dim,
            self.hidden_dim,
            bias=False,
        )
    

        



class MLA(nn.Module):
    """Multi-head Latent Attention layer.

    Queries and keys/values are first compressed into low-rank latents
    (``q_lora_dim`` / ``kv_lora_dim``), RMS-normalised, and expanded back to
    per-head tensors. Each query/key head is the concatenation of a part that
    carries no positional information (``qk_nope_head_dim``) and a part that
    receives rotary position embeddings (``qk_rope_head_dim``); the rotary key
    part is computed once and shared by all heads. No KV cache is implemented.
    """

    def __init__(self, config: MLAconfig):
        """Create the projections, norms and rotary tables described by ``config``."""
        super().__init__()
        self.hidden_dim = config.hidden_dim
        self.num_heads = config.num_heads

        # --- Part 1: low-rank projections ---
        # Query path: hidden -> q_lora_dim -> RMSNorm -> num_heads * qk_head_dim.
        self.q_lora_dim = config.q_lora_dim
        self.qk_nope_head_dim = config.qk_nope_head_dim
        self.qk_rope_head_dim = config.qk_rope_head_dim
        self.qk_head_dim = self.qk_nope_head_dim + self.qk_rope_head_dim
        self.q_down_proj = nn.Linear(self.hidden_dim, self.q_lora_dim, bias=False)
        self.q_layer_norm = DeepseekV2RMSNorm(hidden_size=self.q_lora_dim)
        self.q_up_proj = nn.Linear(
            self.q_lora_dim,
            self.num_heads * self.qk_head_dim,
            bias=False,
        )

        # Key/value path: one down-projection yields both the head-shared
        # rotary key (qk_rope_head_dim) and the latent (kv_lora_dim). The
        # latent is then expanded to the per-head no-rope keys and the values.
        self.kv_lora_dim = config.kv_lora_dim
        self.v_head_dim = config.v_head_dim
        self.kv_down_proj = nn.Linear(
            self.hidden_dim,
            self.kv_lora_dim + self.qk_rope_head_dim,
            bias=False,
        )
        self.kv_layer_norm = DeepseekV2RMSNorm(self.kv_lora_dim)
        self.kv_up_proj = nn.Linear(
            self.kv_lora_dim,
            self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
            bias=False,
        )

        # --- Part 2: rotary position embedding tables ---
        self.rope_emb = DeepseekV2RotaryEmbedding(
            config,
            config.max_position_embeddings,
            config.rope_theta,
        )

        # --- Part 3: attention output ---
        self.attention_dropout = config.attention_dropout
        self.out_proj = nn.Linear(
            self.num_heads * self.v_head_dim,
            self.hidden_dim,
            bias=False,
        )

        # --- Part 4: KV cache (not implemented) ---

    def forward(self, hidden_dim, position_ids, attention_mask=None):
        """Run latent attention over a batch of sequences.

        Args:
            hidden_dim: Input activations of shape
                ``(batch_size, seq_len, hidden_dim)``. Despite its name this
                is the tensor, not a size.
            position_ids: Integer indices into the rotary tables; the test in
                this file passes shape ``(batch_size, seq_len)``.
            attention_mask: Optional boolean mask broadcastable to
                ``(batch_size, num_heads, seq_len, seq_len)``. Positions where
                it is ``True`` are filled with ``-inf`` before the softmax,
                i.e. they are NOT attended to.

        Returns:
            Tuple ``(outputs, atten_weights)``: ``outputs`` has shape
            ``(batch_size, seq_len, hidden_dim)`` and ``atten_weights`` has
            shape ``(batch_size, num_heads, seq_len, seq_len)``. In training
            mode the returned weights are the ones after dropout.
        """
        hidden_states = hidden_dim
        batch_size, seq_len, _ = hidden_states.size()

        # ---- queries: compress, normalise, expand, split into heads ----
        q = self.q_down_proj(hidden_states)  # (batch, seq, q_lora_dim)
        q = self.q_layer_norm(q)
        q = self.q_up_proj(q)  # (batch, seq, num_heads * qk_head_dim)
        q = q.view(batch_size, seq_len, self.num_heads, self.qk_head_dim).transpose(1, 2)
        # q: (batch, num_heads, seq, qk_nope_head_dim + qk_rope_head_dim)
        q_nope, q_rope = torch.split(
            q,
            [self.qk_nope_head_dim, self.qk_rope_head_dim],
            dim=-1,
        )

        # ---- keys / values: compress, split, normalise, expand ----
        kv = self.kv_down_proj(hidden_states)  # (batch, seq, qk_rope_head_dim + kv_lora_dim)
        kv_rope, kv_lora = torch.split(
            kv,
            [self.qk_rope_head_dim, self.kv_lora_dim],
            dim=-1,
        )
        # The rotary key has a single head that is later shared by all heads.
        kv_rope = kv_rope.view(
            batch_size, seq_len, 1, self.qk_rope_head_dim
        ).transpose(1, 2)  # (batch, 1, seq, qk_rope_head_dim)

        kv_lora = self.kv_layer_norm(kv_lora)  # (batch, seq, kv_lora_dim)
        kv_lora = self.kv_up_proj(kv_lora)
        # kv_lora: (batch, seq, num_heads * (qk_nope_head_dim + v_head_dim))
        kv_heads = kv_lora.view(
            batch_size, seq_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim
        ).transpose(1, 2)
        k_nope, v_dim = torch.split(
            kv_heads,
            [self.qk_nope_head_dim, self.v_head_dim],
            dim=-1,
        )
        # k_nope: (batch, num_heads, seq, qk_nope_head_dim)
        # v_dim:  (batch, num_heads, seq, v_head_dim)

        # ---- rotary position embedding on the rope slices ----
        kv_seq_len = v_dim.shape[-2]
        # v_dim is passed only to supply the dtype/device of the tables.
        cos, sin = self.rope_emb(v_dim, seq_len=kv_seq_len)
        q_rope, kv_rope = apply_rotary_pos_emb(
            q_rope, kv_rope, cos, sin, position_ids
        )

        # ---- scaled dot-product attention ----
        q_head_dim = torch.cat([q_nope, q_rope], dim=-1)
        # Broadcast the shared rotary key to every head so the two key parts
        # can be concatenated.
        k_head_dim = torch.cat(
            [k_nope, kv_rope.expand(-1, self.num_heads, -1, -1)],
            dim=-1,
        )
        # Both: (batch, num_heads, seq, qk_head_dim)

        atten_weights = torch.matmul(q_head_dim, k_head_dim.transpose(-2, -1))
        atten_weights = atten_weights / (self.qk_head_dim ** 0.5)
        # atten_weights: (batch, num_heads, seq, seq)
        if attention_mask is not None:
            atten_weights = torch.masked_fill(atten_weights, attention_mask, float("-inf"))
        # Softmax in float32 for numerical stability under half precision,
        # then cast back to the working dtype.
        atten_weights = F.softmax(
            atten_weights, dim=-1, dtype=torch.float32
        ).to(q_head_dim.dtype)
        # Apply the configured attention dropout; it is a no-op in eval mode.
        atten_weights = F.dropout(
            atten_weights, p=self.attention_dropout, training=self.training
        )

        atten_output = torch.matmul(atten_weights, v_dim)
        # atten_output: (batch, num_heads, seq, v_head_dim)
        atten_output = atten_output.transpose(1, 2).reshape(
            batch_size, seq_len, self.num_heads * self.v_head_dim
        )
        # atten_output: (batch, seq, num_heads * v_head_dim)

        outputs = self.out_proj(atten_output)

        return outputs, atten_weights


## 测试代码

In [13]:
def test():
    """Smoke test: push one random batch through MLA and print output shapes."""
    settings = dict(
        # generic attention settings
        hidden_dim=512,
        num_heads=8,
        max_position_embeddings=2048,
        rope_theta=30,
        attention_dropout=0.1,
        # MLA-specific settings
        q_lora_dim=128,
        kv_lora_dim=128,
        v_head_dim=64,
        qk_nope_head_dim=64,
        qk_rope_head_dim=64,
        attention_bias=False,
    )
    config = MLAconfig(**settings)
    mla = MLA(config)

    x = torch.randn(2, 2048, 512)
    # One row of positions 0..max_position_embeddings-1, repeated per batch
    # item, giving shape [batch_size, seq_len].
    positions = torch.arange(config.max_position_embeddings)
    position_ids = positions.unsqueeze(0).expand(x.size(0), -1)

    atten_output, atten_weights = mla(x, position_ids=position_ids)
    print('the atten_output is : \n' ,atten_output.shape)
    print( 'the atten_weights is: \n',atten_weights.shape)


test()

the atten_output is : 
 torch.Size([2, 2048, 512])
the atten_weights is: 
 torch.Size([2, 8, 2048, 2048])
