# 从零构建一个LLM

## 构建模型

In [1]:
import math
import json
import os
import time
import concurrent.futures
import copy
from typing import Optional, Tuple, Dict, List
from dataclasses import dataclass, asdict

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
from torch.utils.checkpoint import checkpoint

In [2]:
# 参数配置文件
@dataclass
class DeepSeekV2Config:
    # 基础参数
    vocab_size: int = 151936
    hidden_size: int = 4096
    num_hidden_layers: int = 32
    num_attention_heads: int = 32
    max_position_embeddings: int = 4096
    initializer_range: float = 0.02   ###############
    max_epochs: int = 100
    
    # MLA 参数
    q_lora_rank: int = 1536
    qk_rope_head_dim: int = 64
    kv_lora_rank: int = 512
    v_head_dim: int = 128
    qk_nope_head_dim: int = 128
    rope_theta: float = 10000.0
    attention_bias: bool = False
    
    # MoE 参数
    expert_number: int = 8
    top_k: int = 2
    shared_expert_number: int = 2
    moe_load_balance_alpha: float = 0.01
    expert_dropout: float = 0.1
    
    # 训练参数
    batch_size: int = 4
    seq_len: int = 2048
    lr: float = 5e-5
    weight_decay: float = 0.1   
    warmup_steps: int = 1000
    total_steps: int = 100000
    valid_steps: int = 100
    grad_accum_steps: int = 1   
    save_every: int = 1000
    validation_batch: int = 50      ###
    async_validation: bool = True   ###
    
    # 其他参数
    attention_dropout: float = 0.1
    hidden_dropout: float = 0.1
    gradient_checkpointing: bool = False
    tie_word_embeddings: bool = True
    output_hidden_states: bool = False
    output_attentions: bool = False
    output_router_logits: bool = False

    # 日志和检查点
    log_dir: str = "model/logs"
    checkpoint_dir: str = "model/checkpoints"
    experiment_name: str = "llm_experiment"
    
    def save(self, path: str):
        """保存配置到JSON文件"""
        with open(path, 'w') as f:
            json.dump(asdict(self), f, indent=4)
    
    @classmethod
    def load(cls, path: str):
        """从JSON文件加载配置"""
        with open(path, 'r') as f:
            data = json.load(f)
        return cls(**data)

In [3]:
# utils
class DeepseekV2RMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)
    
class DeepseekV2RotaryEmbedding(nn.Module):
    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()

        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        inv_freq = 1.0 / (
            self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)
        )
        self.register_buffer("inv_freq", inv_freq, persistent=False)

        # Build here to make `torch.jit.trace` work.
        self._set_cos_sin_cache(
            seq_len=max_position_embeddings,
            device=self.inv_freq.device,
            dtype=torch.get_default_dtype(),
        )
        self.max_seq_len_cached = max_position_embeddings

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        self.max_seq_len_cached = seq_len
        t = torch.arange(
            self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype
        )

        freqs = torch.outer(t, self.inv_freq.to(t.device))
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)

    def forward(self, x, seq_len=None):
        # x: [bs, num_attention_heads, seq_len, head_size]
        if seq_len is not None and seq_len > self.max_seq_len_cached:
            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)

        return (
            self.cos_cached[:seq_len].to(dtype=x.dtype),
            self.sin_cached[:seq_len].to(dtype=x.dtype),
        )

def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

# 旋转位置编码MALA
def apply_rotary_pos_emb_v2(q: torch.Tensor, cos, sin, position_ids, unsqueeze_dim=1):
    cos = cos[position_ids].unsqueeze(unsqueeze_dim)
    sin = sin[position_ids].unsqueeze(unsqueeze_dim)

    b, h, s, d = q.shape
    q = q.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)

    q_embed = (q * cos) + (rotate_half(q) * sin)
    return q_embed

# 学习率调度器(init_lr=1e-6)
class WarmupCosineScheduler:
    def __init__(self, optimizer, warmup_steps, total_steps, init_lr=1e-6):
        self.optimizer = optimizer
        self.warmup_steps = warmup_steps
        self.total_steps = total_steps
        self.current_step = 0
        self.init_lr = init_lr
        self.base_lr = optimizer.param_groups[0]['lr'] 
        
        # Set initial learning rates
        for group in self.optimizer.param_groups:
            group.setdefault('initial_lr', group['lr'])
        
    def step(self):
        self.current_step += 1
        lr = self._get_lr()
        
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr
            
    def _get_lr(self):
        if self.current_step < self.warmup_steps:
            # Linear warmup
            return self.init_lr + (self.base_lr - self.init_lr) * self.current_step / self.warmup_steps
        else:
            # Cosine decay
            progress = (self.current_step - self.warmup_steps) / (self.total_steps - self.warmup_steps)
            return 0.5 * (1 + math.cos(math.pi * progress)) * self.optimizer.param_groups[0]['initial_lr']
    
    def state_dict(self):
        return {
            'current_step': self.current_step,
            'warmup_steps': self.warmup_steps,
            'total_steps': self.total_steps,
        }
    
    def load_state_dict(self, state_dict):
        self.current_step = state_dict['current_step']
        self.warmup_steps = state_dict['warmup_steps']
        self.total_steps = state_dict['total_steps']

In [4]:
#MoE: expert_hidden_dim = hidden_size * 2~3 (提升空间)
class FFNExpert(nn.Module):
    def __init__(self, hidden_dim, expert_dropout):
        super().__init__()
        mid_dim = hidden_dim * 8 // 3

        self.up = nn.Linear(hidden_dim, mid_dim, bias=False)
        self.down = nn.Linear(mid_dim, hidden_dim, bias=False)
        self.gate = nn.Linear(hidden_dim, mid_dim, bias=False)
        self.dropout = nn.Dropout(expert_dropout)

    def forward(self, x):
        output = self.dropout(
            self.down(
                F.silu(self.gate(x)) * self.up(x)
            )
        )
        return output

class MOERouter(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.gate = nn.Linear(config.hidden_size, config.expert_number)
        self.expert_number = config.expert_number
        self.top_k = config.top_k

    def forward(self, x):
        router_logits = self.gate(x)
        router_probs = F.softmax(router_logits, dim=1, dtype=torch.float)
        
        router_weights, selected_expert_indices = torch.topk(
            router_probs,
            self.top_k,
            dim=-1,
        )
        
        router_weights = router_weights / router_weights.sum(dim=-1, keepdim=True)
        router_weights = router_weights.to(x.dtype)
        
        expert_mask = F.one_hot(
            selected_expert_indices, 
            num_classes=self.expert_number,
        )
        
        expert_mask = expert_mask.permute(2, 1, 0)
        
        return router_logits, router_weights, selected_expert_indices, expert_mask, router_probs

class SparseMOE(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.top_k = config.top_k
        self.hidden_dim = config.hidden_size
        self.expert_number = config.expert_number

        self.experts = nn.ModuleList(
            [FFNExpert(config.hidden_size, config.expert_dropout) for _ in range(config.expert_number)]
        )
        self.router = MOERouter(config)

    def forward(self, x):
        batch_size, seq_len, hidden_dim = x.size()
        hidden_states = x.view(-1, hidden_dim)
        
        router_logits, router_weights, _, expert_masks, router_probs= self.router(hidden_states)
        
        final_hidden_states = torch.zeros(
            (batch_size * seq_len, hidden_dim),
            dtype=hidden_states.dtype,
            device=hidden_states.device
        )

        for expert_idx in range(self.expert_number):
            expert_layer = self.experts[expert_idx]
            current_expert_mask = expert_masks[expert_idx]
            
            router_weight_idx, top_x = torch.where(current_expert_mask)
            
            current_states = hidden_states[top_x, :]
            current_states = expert_layer(current_states)
            
            current_token_router_weight = router_weights[top_x, router_weight_idx].unsqueeze(-1)
            current_hidden_states = current_states * current_token_router_weight
            
            final_hidden_states.index_add_(0, top_x, current_hidden_states)
        
        final_hidden_states = final_hidden_states.reshape(batch_size, seq_len, hidden_dim)
        router_logits = router_logits.view(batch_size, seq_len, -1)

        return final_hidden_states, router_logits, expert_masks, router_probs  


class ShareExpertMOE(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.moe_model = SparseMOE(config)
        self.shared_experts = nn.ModuleList([
            FFNExpert(config.hidden_size, config.expert_dropout) 
            for _ in range(config.shared_expert_number)
        ])

    def forward(self, x):
        sparse_moe_out, router_logits, expert_masks, router_probs= self.moe_model(x)
        
        shared_experts_out = torch.stack(
            [expert(x) for expert in self.shared_experts], dim=0
        ).sum(dim=0)
        
        moe_loss = self._calculate_moe_loss(expert_masks, router_probs)
        
        return sparse_moe_out + shared_experts_out, router_logits, moe_loss
    

    def _calculate_moe_loss(self, expert_masks, router_probs):
        """
        expert_masks: [num_experts, top_k, batch*seq_len]
        router_probs: [batch*seq_len, num_experts]
        """
        # importance / router_fraction
        router_fraction = router_probs.mean(dim=0)  # [num_experts]

        # load / expert_fraction
        load = expert_masks.float().mean(dim=[1,2])  # [num_experts]

        # load balancing loss
        moe_loss = self.config.moe_load_balance_alpha * torch.sum(load * router_fraction)
        
        return moe_loss


In [5]:
# MLA注意力机制
class MLAV2(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.attention_dropout = config.attention_dropout
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = config.rope_theta

        self.q_lora_rank = config.q_lora_rank
        self.qk_rope_head_dim = config.qk_rope_head_dim
        self.kv_lora_rank = config.kv_lora_rank
        self.v_head_dim = config.v_head_dim
        self.qk_nope_head_dim = config.qk_nope_head_dim
        self.q_head_dim = config.qk_nope_head_dim + config.qk_rope_head_dim

        self.q_down_proj = nn.Linear(
            self.hidden_size,
            self.q_lora_rank,
            bias=config.attention_bias,
        )
        self.q_down_layernorm = DeepseekV2RMSNorm(self.q_lora_rank)

        self.q_up_proj = nn.Linear(
            self.q_lora_rank,
            self.num_heads * self.q_head_dim,
            bias=False,
        )

        self.kv_down_proj = nn.Linear(
            self.hidden_size,
            self.kv_lora_rank + self.qk_rope_head_dim,
            bias=config.attention_bias,
        )
        self.kv_down_layernorm = DeepseekV2RMSNorm(self.kv_lora_rank)
        self.kv_up_proj = nn.Linear(
            self.kv_lora_rank,
            self.num_heads * (self.q_head_dim - self.qk_rope_head_dim + self.v_head_dim),
            bias=False,
        )

        self.o_proj = nn.Linear(
            self.num_heads * self.v_head_dim,
            self.hidden_size,
            bias=config.attention_bias,
        )

        self.rotary_emb = DeepseekV2RotaryEmbedding(
            self.qk_rope_head_dim,
            self.max_position_embeddings,
            self.rope_theta,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        compressed_kv: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        
        bsz, q_len, _ = hidden_states.size()

        # Query projection and split
        q = self.q_up_proj(self.q_down_layernorm(self.q_down_proj(hidden_states)))
        q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2)
        q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)

        # Key/Value projection and split
        if compressed_kv is None:
            compressed_kv = self.kv_down_proj(hidden_states)  # [B, L, kv_lora_rank + qk_rope_head_dim]
            raw_kv, k_pe = torch.split(compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
            # 对低秩部分做 layernorm（和你在 init 中定义的一致）
            compressed_kv = self.kv_down_layernorm(raw_kv)  # [B, seq, kv_lora_rank]
        else:
            # 兼容上层传入的合并张量（假设传入的就是 kv_lora_rank + qk_rope_head_dim 的合并形式）
            raw_kv, k_pe = torch.split(compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
            compressed_kv = raw_kv  # 不在此处额外做 layernorm（假定上层缓存已处理）

        kv_seq_len = compressed_kv.size(1)
        k_pe = k_pe.view(bsz, kv_seq_len, 1, self.qk_rope_head_dim).transpose(1, 2)

        # Split kv_up_proj into heads
        kv_up_proj = self.kv_up_proj.weight.view(self.num_heads, -1, self.kv_lora_rank)
        q_absorb = kv_up_proj[:, :self.qk_nope_head_dim, :]
        out_absorb = kv_up_proj[:, self.qk_nope_head_dim:, :]

        # Apply RoPE
        cos, sin = self.rotary_emb(q_pe, seq_len=q_len)
        q_pe = apply_rotary_pos_emb_v2(q_pe, cos, sin, position_ids)

        # Project q_nope
        q_nope = torch.matmul(q_nope, q_absorb)

        # Attention score calculation
        attn_weights = (
            torch.matmul(q_pe, k_pe.mT)
            + torch.matmul(q_nope, compressed_kv.unsqueeze(-3).mT)
        ) / math.sqrt(self.q_head_dim)

        # Apply causal mask（合并causal mask和padding mask） 
        if attention_mask is not None:
            attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)  # [bsz, 1, 1, seq_len]
            attention_mask = attention_mask.expand(-1, self.num_heads, q_len, -1)
            # 创建causal mask
            causal_mask = torch.tril(
                torch.ones((q_len, kv_seq_len), device=attn_weights.device, dtype=torch.bool)
            )
            causal_mask = causal_mask.view(1, 1, q_len, kv_seq_len)
            attn_weights = attn_weights.masked_fill(~causal_mask, float("-inf"))

        attn_weights = nn.functional.softmax(
            attn_weights, dim=-1, dtype=torch.float32
        ).to(q_nope.dtype)

        # Compute attention output
        attn_output = torch.einsum("bhql,blc->bhqc", attn_weights, compressed_kv)
        attn_output = torch.matmul(attn_output, out_absorb.mT)

        # Merge heads and project
        attn_output = attn_output.transpose(1, 2).reshape(bsz, q_len, -1)
        attn_output = self.o_proj(attn_output)

        return attn_output, attn_weights

In [6]:
# Transformer Block
class TransformerBlock(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        
        # Pre-LN architecture
        self.ln1 = DeepseekV2RMSNorm(config.hidden_size)
        self.attn = MLAV2(config)
        self.ln2 = DeepseekV2RMSNorm(config.hidden_size)
        self.moe = ShareExpertMOE(config)
        
        self.dropout = nn.Dropout(config.hidden_dropout)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        compressed_kv: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
        output_router_logits: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
        # Self Attention
        residual = hidden_states
        hidden_states = self.ln1(hidden_states)
        attn_output, attn_weights = self.attn(
            hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            # compressed_kv=compressed_kv,
        )
        hidden_states = residual + self.dropout(attn_output)

        # MoE FFN
        residual = hidden_states
        hidden_states = self.ln2(hidden_states)
        ffn_output, router_logits, moe_loss = self.moe(hidden_states)  
        hidden_states = residual + self.dropout(ffn_output)

        outputs = (hidden_states,)
        if output_attentions:
            outputs += (attn_weights,)
        if output_router_logits:
            outputs += (router_logits,)
            
        return outputs, moe_loss, router_logits

In [7]:
class DeepSeekV2Model(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        # self.gradient_checkpointing = config.gradient_checkpointing ## 梯度检查点
        
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
        self.embed_positions = nn.Embedding(config.max_position_embeddings, config.hidden_size)
        
        self.layers = nn.ModuleList([
            TransformerBlock(config) for _ in range(config.num_hidden_layers)
        ])
        
        self.norm = DeepseekV2RMSNorm(config.hidden_size)
        
        if config.tie_word_embeddings:
            self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
            self.lm_head.weight = self.embed_tokens.weight
        else:
            self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
            
        self.gradient_checkpointing = False
        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            std = 1.0 / math.sqrt(module.weight.size(1))  # hidden_size 自适应
            nn.init.normal_(module.weight, mean=0.0, std=std)
            if module.bias is not None:
                nn.init.zeros_(module.bias)

        elif isinstance(module, nn.Embedding):
            std = 1.0 / math.sqrt(module.weight.size(1))
            nn.init.normal_(module.weight, mean=0.0, std=std)

        elif isinstance(module, DeepseekV2RMSNorm):
            nn.init.ones_(module.weight) 

    def forward(
        self,
        input_ids: torch.LongTensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_router_logits: Optional[bool] = None,
    ) -> Dict[str, torch.Tensor]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        output_router_logits = output_router_logits if output_router_logits is not None else self.config.output_router_logits

        # Prepare inputs
        batch_size, seq_length = input_ids.shape
        device = input_ids.device
        
        if position_ids is None:
            position_ids = torch.arange(0, seq_length, dtype=torch.long, device=device)
            position_ids = position_ids.unsqueeze(0).expand(batch_size, -1)
        
        if attention_mask is None:
            attention_mask = torch.ones((batch_size, seq_length), device=device)
        
        # Embed positions and tokens
        inputs_embeds = self.embed_tokens(input_ids)
        position_embeds = self.embed_positions(position_ids)
        hidden_states = inputs_embeds + position_embeds
        
        # Initialize output containers
        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None
        all_router_logits = () if output_router_logits else None

        # Forward through layers
        for layer in self.layers:
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            layer_outputs, moe_loss, router_logits = layer(
                hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                output_attentions=output_attentions,
                output_router_logits=output_router_logits,
            )
            
            hidden_states = layer_outputs[0]
            
            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[1],)
                
            if output_router_logits:
                all_router_logits = all_router_logits + (router_logits,)
        
        hidden_states = self.norm(hidden_states)
        
        # Add last hidden state
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)
            
        # Compute logits
        logits = self.lm_head(hidden_states)
        
        return {
            "logits": logits,
            "moe_loss": moe_loss,
            "hidden_states": all_hidden_states,
            "attentions": all_self_attentions,
            "router_logits": all_router_logits,
        }
    
    @torch.no_grad()
    def generate(
        self,
        input_ids: torch.LongTensor,
        max_new_tokens: int = 50,
        temperature: float = 1.0,
        top_k: int = None,
        top_p: float = None,
        eos_token_id: int = None,
    ) -> torch.LongTensor:
        """
        Decoder-only 自回归生成
        """
        self.eval()
        device = input_ids.device
        output_ids = input_ids.clone()

        for _ in range(max_new_tokens):
            # forward 获取 logits
            logits = self.forward(output_ids)["logits"]  # [B, seq_len, vocab]
            next_token_logits = logits[:, -1, :] / temperature

            # top-k 过滤
            if top_k is not None:
                topk_vals, topk_indices = torch.topk(next_token_logits, top_k, dim=-1)
                mask = next_token_logits < topk_vals[:, -1][:, None]
                next_token_logits = next_token_logits.masked_fill(mask, -float('Inf'))

            # top-p 过滤
            if top_p is not None:
                sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)
                cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
                sorted_indices_to_remove = cumulative_probs > top_p
                sorted_logits[sorted_indices_to_remove] = -float('Inf')
                next_token_logits.scatter_(dim=-1, index=sorted_indices, src=sorted_logits)

            # 采样下一个 token
            probs = F.softmax(next_token_logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)  # [B, 1]
            output_ids = torch.cat([output_ids, next_token], dim=-1)

            # 如果 EOS 出现则停止
            if eos_token_id is not None and (next_token == eos_token_id).all():
                break

        return output_ids


In [8]:
# 训练器
class LLMTrainer:
    def __init__(self, model, config, train_dataset, valid_dataset=None):
        self.model = model
        self.config = config
        self.train_dataset = train_dataset
        self.valid_dataset = valid_dataset
        # 数据加载器
        self.train_loader = DataLoader(
            self.train_dataset, 
            batch_size=self.config.batch_size, 
            shuffle=True, 
            collate_fn=self.collate_fn,
            pin_memory=True   # CPU->GPU 传输更快
        )

        if self.valid_dataset is not None:
            self.valid_loader = DataLoader(
                self.valid_dataset,
                batch_size=self.config.batch_size,
                shuffle=False,   # 验证集不需要 shuffle
                collate_fn=self.collate_fn,
                pin_memory=True
            )
        else:
            self.valid_loader = None

        self.validation_executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
        self.validation_future = None
        self.last_validation_result = None

        self.val_loss_ema = None  
        self.ema_decay = 0.9      # 衰减系数，可以调
    

        # 设备设置
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model.to(self.device)
        
        # 优化器和学习率调度器
        self.optimizer = torch.optim.AdamW(
            model.parameters(), 
            lr=config.lr, 
            weight_decay=config.weight_decay
        )
        
        self.scheduler = WarmupCosineScheduler(
            self.optimizer, 
            config.warmup_steps, 
            config.total_steps
        )
        
        # 混合精度训练
        self.scaler = torch.cuda.amp.GradScaler()
        
        # TensorBoard 记录器
        log_dir = os.path.join(config.log_dir, config.experiment_name)
        self.writer = SummaryWriter(log_dir=log_dir)
        
        # 训练状态
        self.global_step = 0
        self.best_val_loss = float('inf')

    def train_step(self, batch):
        self.model.train()
        self.optimizer.zero_grad()
        
        input_ids, attention_mask = batch  # 解包batch
        input_ids = input_ids.to(self.device)
        attention_mask = attention_mask.to(self.device)
        
        with torch.cuda.amp.autocast():
            outputs = self.model(input_ids, 
                                 attention_mask, 
                                 output_router_logits=True)
            logits = outputs["logits"]
            
            # 计算语言建模损失
            shift_logits = logits[:, :-1, :].contiguous()
            shift_labels = input_ids[:, 1:].contiguous()
            lm_loss = F.cross_entropy(
                shift_logits.view(-1, shift_logits.size(-1)), 
                shift_labels.view(-1)
            )
            
            # 添加MoE负载均衡损失
            moe_loss = outputs.get("moe_loss", 0.0)
            if isinstance(moe_loss, float):
                moe_loss = torch.tensor(moe_loss, device=self.device)

            total_loss = lm_loss + moe_loss

        # 专家使用率
        with torch.no_grad():
            router_logits = outputs.get("router_logits", None)

            if router_logits is not None and len(router_logits) > 0:
                # 取最后一层的 router logits
                last_router_logits = router_logits[-1]  # shape: [batch, seq, num_experts]
                router_probs = torch.softmax(last_router_logits, dim=-1)

                # 专家使用率
                expert_usage = router_probs.mean(dim=[0, 1])  # [num_experts]
                for i, usage in enumerate(expert_usage):
                    self.writer.add_scalar(f'router/expert_{i}_usage', usage.item(), self.global_step)

                # 熵
                entropy = -torch.sum(router_probs * torch.log(router_probs + 1e-8), dim=-1).mean()
                self.writer.add_scalar('router/entropy', entropy.item(), self.global_step)
            else:
                # 没有 router logits 的情况
                self.writer.add_scalar('router/entropy', 0.0, self.global_step)

        
        #反向传播
        self.scaler.scale(total_loss).backward()
        
        # 记录损失到 TensorBoard
        self.writer.add_scalar('train/lm_loss', lm_loss.item(), self.global_step)
        self.writer.add_scalar('train/moe_loss', moe_loss.item(), self.global_step)
        self.writer.add_scalar('train/total_loss', total_loss.item() * self.config.grad_accum_steps, self.global_step)
        self.writer.add_scalar('train/learning_rate', self.optimizer.param_groups[0]['lr'], self.global_step)
        
        # 梯度累积步骤
        if (self.global_step + 1) % self.config.grad_accum_steps == 0:
            self.scaler.unscale_(self.optimizer)
            torch.nn.utils.clip_grad_norm_(self.model.parameters(),  max_norm=1.0)
            self.scaler.step(self.optimizer)
            self.scaler.update()
            self.scheduler.step()
            self.optimizer.zero_grad()
        
        # 监控数值稳定性
        if self.global_step % 100 == 0:
            # 检查激活值范围
            for name, param in self.model.named_parameters():
                if param.requires_grad and param.grad is not None:
                    grad_mean = param.grad.abs().mean().item()
                    if grad_mean > 1e3:  # 梯度爆炸
                        print(f"Warning: Large gradients in {name}: {grad_mean:.6f}")
            
            # 检查损失数值
            if total_loss.item() > 100:  # 损失异常高
                print(f"Warning: High loss {total_loss.item():.6f}")
        
        return {
            "lm_loss": lm_loss.item(),
            "total_loss": total_loss.item(),
            "moe_loss": moe_loss.item()
        }

    def collate_fn(self, batch):
            # train_dataset is dict，cloumns 'input_ids' and 'attention_mask'
            input_ids_list = [item['input_ids'] for item in batch]
            attention_mask_list = [item['attention_mask'] for item in batch]
            
            # 计算最大长度，不超过config.seq_len
            max_len = min(max([len(ids) for ids in input_ids_list]), self.config.seq_len)
            
            # 初始化填充后的张量
            padded_input_ids = torch.zeros((len(batch), max_len), dtype=torch.long)
            padded_attention_mask = torch.zeros((len(batch), max_len), dtype=torch.long)
            
            for i, (ids, mask) in enumerate(zip(input_ids_list, attention_mask_list)):
                l = min(len(ids), max_len)
                padded_input_ids[i, :l] = torch.tensor(ids[:l], dtype=torch.long)
                padded_attention_mask[i, :l] = torch.tensor(mask[:l], dtype=torch.long)
            
            return padded_input_ids, padded_attention_mask
    
    def get_validation_batch(self):
        """根据训练进度动态控制验证 batch 数"""
        progress = self.global_step / self.config.total_steps
        validation_batch = self.config.validation_batch
        
        if progress < 0.3:
            return max(1, int(validation_batch * 0.15))   # 前期少验证
        elif progress < 0.7:
            return max(1, int(validation_batch * 0.4))  # 中期适中
        else:
            return validation_batch  # 后期更稳定
        

    def update_val_loss_ema(self, new_val_loss):
        """更新指数滑动平均"""
        if self.val_loss_ema is None:
            self.val_loss_ema = new_val_loss
        else:
            self.val_loss_ema = (
                self.ema_decay * self.val_loss_ema
                + (1 - self.ema_decay) * new_val_loss
            )
        return self.val_loss_ema
    
    def async_validate(self):
        """异步执行检测"""
        if self.valid_dataset is None:
            return None

        # 深度拷贝
        model_copy = copy.deepcopy(self.model)
        model_copy.eval()

        total_loss = 0
        total_moe_loss = 0
        total_samples = 0
        

        max_batch = self.get_validation_batch()

        with torch.no_grad():
             for i, (input_ids, attention_mask) in enumerate(self.valid_loader):
                if i >= max_batch:  # 限制验证批次
                    break

                input_ids = input_ids.to(self.device)
                attention_mask = attention_mask.to(self.device)
                
                with torch.cuda.amp.autocast():
                    outputs = self.model(input_ids, 
                                         attention_mask, 
                                         output_router_logits=True)
                    logits = outputs["logits"]
                    
                    # 计算LM损失
                    shift_logits = logits[:, :-1, :].contiguous()
                    shift_labels = input_ids[:, 1:].contiguous()
                    loss = F.cross_entropy(
                        shift_logits.view(-1, shift_logits.size(-1)), 
                        shift_labels.view(-1)
                    )

                    # 获取MoE损失
                    moe_loss = outputs.get("moe_loss", 0.0)
                    if isinstance(moe_loss, float):
                        moe_loss = torch.tensor(moe_loss, device=self.device)
                
                total_loss += loss.item() * input_ids.size(0)
                total_moe_loss += moe_loss.item() * input_ids.size(0)
                total_samples += input_ids.size(0)
        
        avg_lm_loss = total_loss / total_samples
        avg_moe_loss = total_moe_loss / total_samples
        avg_total_loss = avg_lm_loss + avg_moe_loss
        ema_loss = self.update_val_loss_ema(avg_total_loss)

        self.writer.add_scalar('val/lm_loss', avg_lm_loss, self.global_step)
        self.writer.add_scalar('val/moe_loss', avg_moe_loss, self.global_step)
        self.writer.add_scalar('val/total_loss', avg_total_loss, self.global_step)
        self.writer.add_scalar('val/ema_loss', ema_loss, self.global_step)

        # 保存最佳模型
        if ema_loss < self.best_val_loss:
            self.best_val_loss = ema_loss
            self.save_checkpoint(f"best_model.pt")
        
        return avg_total_loss
    
    def check_validation_result(self):
        """检查异步验证结果"""
        if self.validation_future and self.validation_future.done():
            try:
                val_loss = self.validation_future.result()
                self.last_validation_result = val_loss
                self.writer.add_scalar('val/loss', val_loss, self.global_step)
                
                # 保存最佳模型
                if val_loss < self.best_val_loss:
                    self.best_val_loss = val_loss
                    self.save_checkpoint(f"best_model.pt")
                    
                # tqdm.write(f"Step {self.global_step}: val_loss = {val_loss:.4f}")
            except Exception as e:
                print(f"Validation error: {e}")
            finally:
                self.validation_future = None

    def train(self):
        # 训练循环
        progress_bar = tqdm(total=self.config.total_steps, desc="Training")
        
        for epoch in range(self.config.max_epochs):  # 足够大的epoch数，通过total_steps控制
            for batch in self.train_loader:
                if self.global_step >= self.config.total_steps:
                    break
                
                # 检查是否有验证结果
                self.check_validation_result()

                # 训练步骤
                metrics = self.train_step(batch)
                
                # 更新进度条
                progress_bar.set_postfix({
                    "loss": f"{metrics['total_loss']:.4f}",
                    "lr": f"{self.optimizer.param_groups[0]['lr']:.2e}",
                    "val_loss": f"{self.last_validation_result:.4f}" if self.last_validation_result else "N/A"
                })
                progress_bar.update(1)
                
                # 验证和保存检查点
                if (self.global_step % self.config.valid_steps == 0 
                    and self.validation_future is None):
                    self.validation_future = self.validation_executor.submit(self.async_validate)

                    # if val_loss is not None:
                    #     tqdm.write(f"Step {self.global_step}: val_loss = {val_loss:.4f}")
                
                if self.global_step % self.config.save_every == 0:
                    self.save_checkpoint(f"checkpoint_step_{self.global_step}.pt")
                
                self.global_step += 1
            
            if self.global_step >= self.config.total_steps:
                break
        # 等待最后的验证完成
        if self.validation_future:
            self.validation_future.result()
        
        progress_bar.close()
        self.writer.close()
        
        # 保存最终模型
        self.save_checkpoint("final_model.pt")
    
    def save_checkpoint(self, filename):
        checkpoint_path = os.path.join(self.config.checkpoint_dir, self.config.experiment_name, filename)
        os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
        
        checkpoint = {
            'global_step': self.global_step,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scheduler_state_dict': self.scheduler.state_dict(),
            'scaler_state_dict': self.scaler.state_dict(),
            'best_val_loss': self.best_val_loss,
            'config': asdict(self.config)
        }
        
        torch.save(checkpoint, checkpoint_path)
    
    def load_checkpoint(self, checkpoint_path):
        checkpoint = torch.load(checkpoint_path, map_location=self.device)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        self.scaler.load_state_dict(checkpoint['scaler_state_dict'])
        self.global_step = checkpoint['global_step']
        self.best_val_loss = checkpoint['best_val_loss']

In [9]:
config = DeepSeekV2Config(
    # 基础参数（极小化）
    vocab_size = 151643,            # Qwen的词表
    hidden_size = 512,              # embedding 维度设置 MLA里
    num_hidden_layers = 4,          # 仅 4 层
    num_attention_heads = 4,        # 4 个注意力头
    max_position_embeddings = 512,
    initializer_range = 0.02,
    # max_epochs = 100, deafult

    # MLA 参数（极小化）
    q_lora_rank = 128,      # 原64 → 128
    kv_lora_rank = 96,      # 原32 → 96  
    qk_rope_head_dim = 48,  # 原32 → 48
    qk_nope_head_dim = 80,  # 原64 → 80
    v_head_dim = 96,        # 原64 → 96
    rope_theta = 10000.0,
    attention_bias = False,

    # MoE 参数（小规模测试）
    expert_number = 8,          # 仅 2 个专家
    top_k = 2,                  # 每次只激活 1 个
    shared_expert_number = 2,
    moe_load_balance_alpha= 0.01,
    expert_dropout = 0.1,

    # 训练参数（快速实验）
    batch_size = 4,
    seq_len = 512,              
    lr = 1e-5,
    weight_decay = 0.1,         
    warmup_steps = 150,
    total_steps = 10000,  
    save_every = 10000,
    grad_accum_steps=1,         # 梯度累计关闭，有可能对梯度爆炸
    valid_steps = 150,      
    validation_batch=50,
    async_validation=True,    

    # 其他参数
    attention_dropout = 0.05,
    hidden_dropout = 0.1,
    tie_word_embeddings = True,
    output_hidden_states = False,
    output_attentions = False,
    output_router_logits = True,

    # 日志和检查点
    log_dir = ".model/logs",
    checkpoint_dir= ".model/checkpoints",
    experiment_name = "llm_experiment_8",
)
model = DeepSeekV2Model(config)
model.apply(model._init_weights)

DeepSeekV2Model(
  (embed_tokens): Embedding(151643, 512)
  (embed_positions): Embedding(512, 512)
  (layers): ModuleList(
    (0-3): 4 x TransformerBlock(
      (ln1): DeepseekV2RMSNorm()
      (attn): MLAV2(
        (q_down_proj): Linear(in_features=512, out_features=128, bias=False)
        (q_down_layernorm): DeepseekV2RMSNorm()
        (q_up_proj): Linear(in_features=128, out_features=512, bias=False)
        (kv_down_proj): Linear(in_features=512, out_features=144, bias=False)
        (kv_down_layernorm): DeepseekV2RMSNorm()
        (kv_up_proj): Linear(in_features=96, out_features=704, bias=False)
        (o_proj): Linear(in_features=384, out_features=512, bias=False)
        (rotary_emb): DeepseekV2RotaryEmbedding()
      )
      (ln2): DeepseekV2RMSNorm()
      (moe): ShareExpertMOE(
        (moe_model): SparseMOE(
          (experts): ModuleList(
            (0-7): 8 x FFNExpert(
              (up): Linear(in_features=512, out_features=1365, bias=False)
              (down): 

In [10]:
print(model.embed_tokens.weight[:5, :5])  # 打印前5行5列
print(model.layers[0].moe.moe_model.experts[0].gate.weight) 

tensor([[ 0.0041, -0.0245,  0.0615,  0.0088, -0.0512],
        [ 0.0228,  0.1038,  0.0265, -0.0415,  0.0258],
        [-0.0488,  0.0112,  0.0542, -0.0627,  0.0016],
        [-0.0491, -0.0184,  0.0093, -0.0004,  0.0613],
        [-0.0326,  0.0582, -0.0642, -0.0075, -0.0313]],
       grad_fn=<SliceBackward0>)
Parameter containing:
tensor([[ 0.0266,  0.0228,  0.0310,  ...,  0.0371, -0.0277,  0.0614],
        [-0.0436,  0.0076,  0.0081,  ..., -0.0363, -0.0021,  0.0378],
        [ 0.0388,  0.0363,  0.0011,  ..., -0.0297, -0.0462,  0.0007],
        ...,
        [-0.0209,  0.0363,  0.0619,  ...,  0.0100,  0.0542,  0.0601],
        [-0.0785,  0.0399,  0.0161,  ...,  0.0688,  0.0553,  0.0214],
        [-0.0511, -0.0298, -0.0242,  ..., -0.1000, -0.0659,  0.0278]],
       requires_grad=True)


In [11]:
# 计算参数数量
def count_parameters(model):
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    return total_params, trainable_params

total_params, trainable_params = count_parameters(model)

print(f"模型总参数数量: {total_params:,}")
print(f"可训练参数数量: {trainable_params:,}")
print(f"参数数量 (百万): {total_params / 1e6:.2f}M")

模型总参数数量: 163,666,848
可训练参数数量: 163,666,848
参数数量 (百万): 163.67M


In [12]:
from datasets import load_dataset
from transformers import AutoTokenizer

dataset = load_dataset("parquet",
                        data_files = "./dataset/wudao/clean_weight.parquet", 
                        split="train") #.select(range(50))
dataset



Dataset({
    features: ['text', 'labels'],
    num_rows: 234733
})

In [13]:
data = dataset.remove_columns(column_names="labels" )
data

Dataset({
    features: ['text'],
    num_rows: 234733
})

In [14]:
tokenizer = AutoTokenizer.from_pretrained("../model/Qwen2.5-0.5B-Instruct")

if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
config.vocab_size = tokenizer.vocab_size  

# 把文本转为 token ids
def tokenize_function(example):
    return tokenizer(
        example["text"], 
        truncation=True, 
        max_length=512, 
    )

tokenized_datasets = data.map(tokenize_function, batched=True, remove_columns=["text"])

print(tokenized_datasets)

Dataset({
    features: ['input_ids', 'attention_mask'],
    num_rows: 234733
})


In [15]:
train_data, valld_data = tokenized_datasets.train_test_split(test_size=0.1, seed=42).values()
train_data, valld_data 

(Dataset({
     features: ['input_ids', 'attention_mask'],
     num_rows: 211259
 }),
 Dataset({
     features: ['input_ids', 'attention_mask'],
     num_rows: 23474
 }))

In [16]:
trainer = LLMTrainer(model, config, train_data, valld_data)

In [17]:
trainer.train()

Training: 100%|█████████████████████| 10000/10000 [3:28:55<00:00,  1.25s/it, loss=7.7800, lr=0.00e+00, val_loss=6.3903]


![LLMs_loss](./img/LLMs_loss.png)
![LLMs_entropy](./img/LLMs_entropy.png)

# 模型生成预测

In [18]:
def load_pretrained(model: nn.Module, checkpoint_path: str, device='cuda'):
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.to(device)
    model.eval()
    return model

In [38]:
inputs = torch.tensor(tokenized_datasets[58845]["input_ids"], dtype=torch.long)
print(inputs.dtype)
tokenizer.decode(inputs, skip_special_token = False)

torch.int64


'我为江湖\n《我为江湖》是毁灭ko魔神创作的网络小说，发表于起点网。作品简介 初入江湖 1.炼血宗 海浪猛烈的拍击着岸边,远处海鸟在海面上不停的盘旋着,在岸边一个青衣老者和三个长像怪异的人对峙着,那个老者”唉”了一声打破了平静,本来那老者在海边的岩洞里闭关修练,马上可以突破这一层,可是就在这时,跑出眼前的三个人,一句话也没说就开打,要是平常的时侯在来三个也是没问题的,可是在练攻最关键的时侯,为了接住他们的攻击只好强行从入定中醒来,受了很重的内伤. “你们想干什么”青衣老者出声问到. “嘿…嘿…”三个人阴阴笑着. “ [1]'

In [79]:
import torch

# 假设你有 tokenizer 和 model
device = "cuda" if torch.cuda.is_available() else "cpu"
config.vocab_size = tokenizer.vocab_size
model = DeepSeekV2Model(config)
model = load_pretrained(model, ".model/checkpoints/llm_experiment_5/final_model.pt", device)

# 初始 prompt
prompt_text = '今天的天气如何？'
input_ids = tokenizer.encode(prompt_text, return_tensors="pt").to(device)

# 自回归生成
output_ids = model.generate(
    input_ids=input_ids,
    max_new_tokens=40,
    temperature=0.9,
    # top_k=40,
    top_p=0.8,
    eos_token_id=tokenizer.eos_token_id
)

# 解码
generated_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
print("生成结果:", generated_text)


生成结果: 今天的天气如何？站血爱中了解韩
新十二阶段肌肤快速的 �青先生,他的一些的:速",时八,我想以上、甲或者因,我们体",他在
