In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from dataclasses import dataclass

In [None]:
@dataclass
class DeepseekArgs:
    """Hyperparameters for the DeepSeek-style model components in this notebook.

    Fields are grouped as: batch/sequence limits and model width, per-head
    low-rank / head dimensions for attention, and rotary-embedding settings.
    """

    # --- sizes and limits ---
    max_batch_size: int = 8  # maximum batch size
    max_seq_len: int = 4 * 4096  # maximum sequence length (4x original_seq_len)
    dim: int = 2048  # model (hidden) dimension
    n_heads: int = 16  # number of attention heads

    # --- attention projection / head dimensions ---
    q_lora_rank: int = 0  # LoRA (low-rank) rank for the Query projection
    kv_lora_rank: int = 512  # LoRA (low-rank) rank for the Key/Value projection
    qk_nope_head_dim: int = 128  # per-head Query/Key dimension without rotary embedding (NoPE)
    qk_rope_head_dim: int = 64  # per-head Query/Key dimension that receives rotary embedding (RoPE)
    v_head_dim: int = 128  # per-head Value dimension

    # --- rotary embedding ---
    original_seq_len: int = 4096  # original (pre-extension) sequence length
    rope_theta: float = 1e4  # RoPE frequency base

In [None]:
def precompute_freqs_cis(args: DeepseekArgs) -> torch.Tensor:
    """Precompute the complex rotary-embedding (RoPE) factors for every position.

    Entry ``[t, i]`` of the result is ``exp(1j * t * theta_i)``, where
    ``theta_i = rope_theta ** (-2i / qk_rope_head_dim)``.

    Args:
        args: model config; only ``qk_rope_head_dim``, ``max_seq_len`` and
            ``rope_theta`` are read.

    Returns:
        Complex tensor of shape ``(max_seq_len, ceil(qk_rope_head_dim / 2))``
        whose entries all have magnitude 1.
    """
    # NOTE(review): args.original_seq_len is not used, so no long-context
    # frequency scaling is applied even when max_seq_len exceeds it.
    rope_dim = args.qk_rope_head_dim

    # One inverse frequency per channel pair: base^(-2i / rope_dim).
    even_channels = torch.arange(0, rope_dim, 2, dtype=torch.float32)
    inv_freq = 1.0 / (args.rope_theta ** (even_channels / rope_dim))

    # Rotation angle t * theta_i for every position t and channel pair i.
    positions = torch.arange(args.max_seq_len)
    angles = torch.outer(positions, inv_freq)

    # Unit-magnitude complex numbers e^{i * angle}.
    return torch.polar(torch.ones_like(angles), angles)

In [None]:
class RMSNorm(nn.Module):
    """Root-mean-square normalization with a learnable per-channel scale.

    Normalizes over the last dimension as ``x / sqrt(mean(x**2) + eps)`` and
    then multiplies by ``weight``, which is initialised to ones.

    Args:
        dim: size of the last dimension (the length of ``weight``).
        eps: small constant added inside the square root to avoid division by zero.
    """

    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def norm(self, x):
        """Divide ``x`` by the RMS of its last dimension (no learnable scale)."""
        mean_square = x.square().mean(dim=-1, keepdim=True)
        return x * (mean_square + self.eps).rsqrt()

    def forward(self, x):
        """Normalize ``x`` in float32, cast back to its dtype, then scale by ``weight``."""
        # Statistics are computed in float32 so low-precision inputs do not lose accuracy.
        normalized = self.norm(x.float()).type_as(x)
        # The product follows PyTorch type promotion: with a float32 weight the
        # output is float32 even when x is half/bfloat16.
        return self.weight * normalized