In [None]:
from dataclasses import dataclass
import torch
from torch import nn
from torch.nn import functional as F

class RMSNorm(nn.Module):
    """Root-mean-square normalization over the last dimension.

    The input is divided by ``sqrt(mean(x ** 2) + eps)`` computed along the
    last axis and then multiplied by a learnable per-feature gain ``gamma``.

    Args:
        dim: Size of the last (feature) dimension; length of ``gamma``.
        eps: Constant added to the mean square before the reciprocal square
            root, so an all-zero input does not divide by zero.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        # Learnable gain, initialised to ones so the layer starts as a plain
        # RMS normalization.
        self.gamma = nn.Parameter(torch.ones(dim))

    def _norm(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` scaled by the reciprocal RMS of its last dimension."""
        mean_square = x.pow(2).mean(-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.eps)
        return x * inv_rms

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize ``x`` in float32, cast back to its dtype, apply ``gamma``."""
        # The statistics are computed in float32 regardless of the input dtype.
        normalized = self._norm(x.float())
        # Cast back to the input dtype before the gain is applied.
        normalized = normalized.type_as(x)
        return self.gamma * normalized


@dataclass
class Config:
    """Hyperparameters read by the ``MLA`` constructor.

    Field order is part of the positional constructor signature generated by
    ``@dataclass``; do not reorder.
    """

    # Model width: input size of the Q and KV down-projections and output size
    # of the attention output projection.
    hidden_size: int
    # Number of attention heads; multiplies the per-head dims when sizing the
    # Q up-projection and the output projection.
    num_heads: int
    # Only stored by MLA in the visible code — presumably the maximum sequence
    # length for positional encoding; TODO confirm.
    max_seq_len: int
    # Only stored by MLA in the visible code — presumably the RoPE base
    # frequency; TODO confirm.
    rope_theta: float

    # Only stored by MLA in the visible code — presumably the dropout
    # probability on attention weights; TODO confirm.
    attention_dropout: float

    # Width of the compressed query representation (output of the Q
    # down-projection, input of the Q up-projection).
    q_lora_rank: int
    # Per-head width of the rotary part of Q/K; also the number of extra output
    # channels on the KV down-projection.
    qk_rope_head_dim: int
    # Width of the compressed KV representation that the KV RMSNorm is sized for.
    kv_lora_rank: int

    # Per-head width of V; the output projection consumes num_heads * v_head_dim.
    v_head_dim: int
    # Per-head width of the non-rotary part of Q/K; added to qk_rope_head_dim
    # to form the full per-head query width.
    qk_nope_head_dim: int
    # Whether the Q down-, KV down- and Q up-projections carry a bias term.
    attention_bias: bool


class MLA(nn.Module):
    """Attention module that projects Q and KV through low-rank bottlenecks.

    NOTE(review): the name presumably stands for Multi-head Latent Attention;
    only the projection layers created in ``__init__`` are described here.
    """

    def __init__(self, config: Config):
        """Mirror the hyperparameters from ``config`` and build the projections.

        Args:
            config: Hyperparameter container; each field is copied onto
                ``self`` under the same name.
        """
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_heads
        self.max_seq_len = config.max_seq_len
        self.rope_theta = config.rope_theta

        self.attention_dropout = config.attention_dropout

        self.q_lora_rank = config.q_lora_rank
        self.qk_rope_head_dim = config.qk_rope_head_dim
        self.kv_lora_rank = config.kv_lora_rank

        self.v_head_dim = config.v_head_dim
        self.qk_nope_head_dim = config.qk_nope_head_dim
        self.attention_bias = config.attention_bias

        # Output projection: (num_heads * v_head_dim) -> hidden_size.
        # NOTE(review): bias is hard-coded to False here, unlike the other
        # projections, which follow config.attention_bias — confirm intended.
        self.out_proj = nn.Linear(
            self.num_heads * self.v_head_dim, self.hidden_size, bias=False
        )
        # Compression (down-projection) part of MLA; per the original author's
        # note this is typically 7168 -> 1536, a compression ratio of about 4.67.
        self.q_down_proj = nn.Linear(
            self.hidden_size, self.q_lora_rank, bias=config.attention_bias
        )
        # RMSNorm over the q_lora_rank compressed query features.
        self.q_down_norm = RMSNorm(self.q_lora_rank)

        # KV down-projection: its output is kv_lora_rank + qk_rope_head_dim wide.
        # Presumably the extra qk_rope_head_dim channels carry the rotary key
        # component — TODO confirm against the forward pass.
        self.kv_down_proj = nn.Linear(
            self.hidden_size,
            self.kv_lora_rank + self.qk_rope_head_dim,
            bias=config.attention_bias,
        )
        # Sized for kv_lora_rank only, not the extra qk_rope_head_dim channels.
        self.kv_down_norm = RMSNorm(self.kv_lora_rank)


        # Up-projection part.
        # Full per-head query width = non-rotary part + rotary part.
        self.q_head_dim = self.qk_nope_head_dim + self.qk_rope_head_dim
        # Expands the compressed query to num_heads * q_head_dim features.
        self.q_up_proj = nn.Linear(
            self.q_lora_rank,
            self.num_heads * self.q_head_dim,
            bias=config.attention_bias,
        )