In [1]:
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from dataclasses import dataclass

In [None]:
class DeepseekV2RMSNorm(nn.Module):
    """Root-mean-square layer normalization with a learnable per-channel gain.

    Args:
        hidden_size: Size of the last dimension; the gain vector has this length.
        eps: Constant added to the mean square before the reciprocal square root.
    """

    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.variance_epsilon = eps
        # The gain starts at one, so a freshly built module only rescales by 1/RMS.
        self.weight = nn.Parameter(torch.ones(hidden_size))

    def forward(self, hidden_states):
        """Normalize ``hidden_states`` over its last dimension and apply the gain.

        The statistics are computed in float32; the normalized tensor is cast
        back to the input dtype before it is multiplied by ``self.weight``.
        """
        original_dtype = hidden_states.dtype
        upcast = hidden_states.to(torch.float32)
        mean_square = upcast.pow(2).mean(-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.variance_epsilon)
        normalized = (upcast * inv_rms).to(original_dtype)
        return self.weight * normalized
    
class DeepseekV2RotaryEmbedding(nn.Module):
    """Rotary position embedding (RoPE) cos/sin tables with an on-demand cache.

    Args:
        dim: Rotary dimension; one inverse frequency is created for every
            index in ``range(0, dim, 2)``.
        max_position_embeddings: Sequence length of the tables built at
            construction time.
        base: Base of the geometric frequency progression.
        device: Device on which ``inv_freq`` is created (``None`` = default).
    """

    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()

        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        # inv_freq[i] = base ** (-2i / dim): index 0 has the highest frequency
        # (1.0) and the frequency decreases as the index grows.
        inv_freq = 1.0 / (
            self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)
        )
        self.register_buffer("inv_freq", inv_freq, persistent=False)

        # Build here to make `torch.jit.trace` work.
        self._set_cos_sin_cache(
            seq_len=max_position_embeddings,
            device=self.inv_freq.device,
            dtype=torch.get_default_dtype(),
        )
        # Mark the cache as stale: the buffers built above still exist, but the
        # first forward() call rebuilds them on the input's device and dtype.
        self.max_seq_len_cached = None

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        """(Re)build the ``cos_cached`` / ``sin_cached`` buffers for ``seq_len`` positions.

        Each table has shape ``(seq_len, 2 * len(inv_freq))``.
        """
        self.max_seq_len_cached = seq_len
        t = torch.arange(
            self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype
        )

        freqs = torch.outer(t, self.inv_freq.to(t.device))
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)

    def forward(self, x, seq_len=None):
        """Return the ``(cos, sin)`` tables for the first ``seq_len`` positions.

        Args:
            x: Tensor laid out as [bs, num_attention_heads, seq_len, head_size];
                it supplies the device and dtype of the returned tables.
            seq_len: Number of positions to return. When omitted it is taken
                from ``x.shape[-2]`` (previously the default ``None`` crashed).

        Returns:
            Tuple ``(cos, sin)``, each of shape ``(seq_len, rotary_dim)`` in
            ``x.dtype``.
        """
        # x: [bs, num_attention_heads, seq_len, head_size]
        if seq_len is None:
            seq_len = x.shape[-2]

        if self.max_seq_len_cached is None or seq_len > self.max_seq_len_cached:
            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)

        return (
            self.cos_cached[:seq_len].to(dtype=x.dtype),
            self.sin_cached[:seq_len].to(dtype=x.dtype),
        )

# Copied from transformers.models.llama.modeling_llama.rotate_half
def rotate_half(x):
    """Swap the two halves of the last dimension, negating the second half.

    For ``x = [a, b]`` (split at ``x.shape[-1] // 2``) the result is ``[-b, a]``.
    """
    split_at = x.shape[-1] // 2
    front, back = x[..., :split_at], x[..., split_at:]
    return torch.cat((-back, front), dim=-1)

# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
    """Apply rotary position embedding to the query and key tensors.

    Args:
        q: 4-D query tensor, unpacked as ``(b, h, s, d)``.
        k: 4-D key tensor, unpacked as ``(b, h, s, d)``.
        cos: Cosine table, indexed by ``position_ids``.
        sin: Sine table, indexed by ``position_ids``.
        position_ids: Indices selecting the rows of ``cos`` / ``sin`` to use.
        unsqueeze_dim: Axis inserted into the selected ``cos`` / ``sin`` so
            they broadcast against ``q`` and ``k``.

    Returns:
        Tuple ``(q_embed, k_embed)`` with the same shapes as ``q`` and ``k``.
    """

    def _deinterleave(t):
        # Regroup the last dim from interleaved pairs (x0, y0, x1, y1, ...)
        # into the half layout (x0, x1, ..., y0, y1, ...) used by rotate_half.
        batch, heads, length, width = t.shape
        paired = t.view(batch, heads, length, width // 2, 2)
        return paired.transpose(4, 3).reshape(batch, heads, length, width)

    cos_sel = cos[position_ids].unsqueeze(unsqueeze_dim)
    sin_sel = sin[position_ids].unsqueeze(unsqueeze_dim)

    q_half = _deinterleave(q)
    k_half = _deinterleave(k)

    q_embed = (q_half * cos_sel) + (rotate_half(q_half) * sin_sel)
    k_embed = (k_half * cos_sel) + (rotate_half(k_half) * sin_sel)
    return q_embed, k_embed




In [None]:
@dataclass
class DeepseeekConfig:
    """Hyper-parameters for the ``MLA`` attention module defined in this notebook.

    Field names (including their original spellings) and field order are part
    of the constructor interface and must not change.
    """

    # Width of the hidden states fed to the down projections and produced by
    # the output projection.
    hidden_dim: int
    # Number of attention heads.
    num_heads: int
    # Intended maximum sequence length for the rotary cos/sin tables.
    max_posistion_embeddings: int
    # Intended base of the rotary frequency progression.
    rope_theta: float
    # Attention dropout probability (not read by the code visible in this notebook).
    attetion_dropout: float

    # Width of the compressed query latent.
    q_lora_rank: int
    # Width of the compressed key/value latent.
    kv_lora_rank: int
    # Not read by the code visible in this notebook.
    qk_rope_rank: int
    # Per-head width of the query/key part that carries rotary position information.
    qk_rope_head_dim: int
    # Per-head width of the value vectors.
    v_head_dim: int
    # Per-head width of the query/key part without rotary position information.
    qk_nope_head_dim: int
    # Whether the q/kv down projections and the output projection use a bias term.
    attetion_bias: bool

class MLA(nn.Module):
    def __init__(self, config: DeepseeekConfig):
        """Create the projection, normalization and rotary sub-modules.

        Args:
            config: Hyper-parameters; see ``DeepseeekConfig`` for the fields read here.
        """
        super().__init__()

        self.config = config

        self.hidden_dim = config.hidden_dim
        self.num_heads = config.num_heads
        self.q_lora_rank = config.q_lora_rank
        self.kv_lora_rank = config.kv_lora_rank
        self.v_head_dim = config.v_head_dim
        self.qk_nope_head_dim = config.qk_nope_head_dim
        self.qk_rope_head_dim = config.qk_rope_head_dim
        self.attetions_bias = config.attetion_bias
        # These two were previously read from ``self`` without ever being set,
        # which made construction fail with AttributeError. The config field
        # keeps its original spelling (``max_posistion_embeddings``).
        self.max_position_embeddings = config.max_posistion_embeddings
        self.rope_theta = config.rope_theta
        # Kept for backward compatibility: this is the RMSNorm *class*, not an
        # instance. Use ``q_down_norm`` / ``kv_down_norm`` to normalize tensors.
        self.Deepseekrms = DeepseekV2RMSNorm

        # Each query/key head is the concatenation of a non-rotary and a rotary part.
        self.q_head_dim = self.qk_nope_head_dim + self.qk_rope_head_dim

        # Query compression: hidden states -> query latent, followed by RMSNorm.
        self.q_down_proj = nn.Linear(self.hidden_dim, self.q_lora_rank, bias=self.attetions_bias)
        self.q_down_norm = DeepseekV2RMSNorm(self.q_lora_rank)

        # Query up-projection: latent -> all heads. The per-head output
        # (q_head_dim) is meant to be split into its non-rotary and rotary parts.
        self.q_up_proj = nn.Linear(self.q_lora_rank, self.q_head_dim * self.num_heads, bias=False)

        # Key/value compression: the output holds the kv latent (kv_lora_rank,
        # the part kv_down_norm is sized for) plus qk_rope_head_dim extra
        # channels meant to go through the rotary embedding.
        self.kv_down_proj = nn.Linear(
            self.hidden_dim,
            self.kv_lora_rank + self.qk_rope_head_dim,
            bias=self.attetions_bias,
        )
        self.kv_down_norm = DeepseekV2RMSNorm(self.kv_lora_rank)

        # Key/value up-projection; its output is meant to be split per head into
        # the non-rotary key part and the value. The input is the normalized kv
        # latent, so in_features must be kv_lora_rank (it was qk_nope_head_dim,
        # which only matched when the two sizes happened to be equal).
        self.kv_up_proj = nn.Linear(
            self.kv_lora_rank,
            self.num_heads * (self.v_head_dim + self.qk_nope_head_dim),
            bias=False,
        )

        # Output projection: concatenated per-head values -> hidden size.
        self.output_proj = nn.Linear(self.v_head_dim * self.num_heads, self.hidden_dim, bias=self.attetions_bias)

        # Rotary embedding tables for the rotary part of queries and keys.
        self.rotary_emb = DeepseekV2RotaryEmbedding(
            self.qk_rope_head_dim,
            self.max_position_embeddings,
            self.rope_theta,
        )

    def forward(self,x,mask = None):
        """Forward pass; the lines visible here only build the normalized query latent.

        Args:
            x: 3-D tensor whose size is unpacked as ``(bsz, seq, _)``; its last
                dimension must match the input width of ``q_down_proj``.
            mask: Not used by the lines visible here.
        """
        bsz , seq , _ = x.size()

        # First compression: project the hidden states down to the query latent.
        q_donw = self.q_down_proj(x)

        # Normalize with the RMSNorm instance sized for the query latent.
        # ``self.Deepseekrms`` is the RMSNorm class itself, so calling it here
        # tried to construct a new module from the tensor instead of normalizing.
        q_donw = self.q_down_norm(q_donw)
        # Up-projection
        

