trainer_distillation.ipynb
---------------------------------
MicroCortex 语言模型蒸馏脚本（带详细中文注释）。
本脚本演示了如何使用 PyTorch + Transformers 来蒸馏一个自定义的
MicroCortexCausalLanguageModel，并支持分布式数据并行（DDP）训练、
梯度累积、自动混合精度 (AMP)、学习率余弦退火以及按间隔保存检查点。
"""
模型加载：
    学生模型512*8权重full_sft_512.pth
    教师模型768*16权重full_sft_768.pth
    学生模型参与优化；教师模型参数冻结，仅做前向传播，为蒸馏提供目标分布
损失：
    训练损失 = α × CE(学生 vs GT) + (1–α) × T² × KL(教师 || 学生)，其中 T 为 temperature，两个分布均由 logits / T 经 softmax 得到；蒸馏部分仅在 mask=1 的 token 上计算。
"""

一、导入相关包，和train_pretrain一样

In [None]:
import os
import sys

__package__ = "trainer"
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
import argparse
import time
import math
import warnings

import torch
import torch.nn.functional as F
import torch.distributed as dist
from contextlib import nullcontext
from torch import optim
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data import DataLoader, DistributedSampler
from transformers import AutoTokenizer, AutoModelForCausalLM
from model.model_minimind import MiniMindConfig, MiniMindForCausalLM
from dataset.lm_dataset import SFTDataset

warnings.filterwarnings('ignore')

二、相关的工具函数，和train_pretrain一样

In [None]:
def Logger(content):
    """Print ``content`` from the main process only.

    When the module-level ``ddp`` flag is truthy, every rank except rank 0
    stays silent; when it is falsy the message is always printed.
    """
    # Guard clause: in DDP mode, non-zero ranks skip the print.
    if ddp and dist.get_rank() != 0:
        return
    print(content)

def get_lr(current_step, total_steps, lr):
    """Return the learning rate for ``current_step`` on a cosine schedule.

    The value is ``lr / 10`` (a constant floor) plus a half-cosine term that
    goes from ``lr`` at step 0 down to 0 at ``total_steps``.
    """
    cosine = math.cos(math.pi * current_step / total_steps)
    floor = lr / 10
    return floor + 0.5 * lr * (1 + cosine)

kl散度损失

In [None]:
# KL-divergence distillation loss
def distillation_loss_fn(student_logits, teacher_logits, temperature=1.0, reduction='batchmean'):
    """Compute the temperature-scaled KL-divergence distillation loss.

    Both logit tensors are divided by ``temperature`` and normalized over the
    last dimension. The result is ``F.kl_div(log_softmax(student), softmax(teacher))``,
    i.e. KL(teacher || student), multiplied by ``temperature ** 2``.

    Args:
        student_logits: Unnormalized student logits. The caller in this file
            passes a 2-D ``[N, V]`` tensor (rows = selected tokens).
        teacher_logits: Unnormalized teacher logits with the same shape as
            ``student_logits``. No gradient flows into this tensor.
        temperature: Softmax temperature. Higher values flatten both
            distributions; lower values sharpen them.
        reduction: Reduction mode forwarded to ``F.kl_div``.

    Returns:
        The KL loss tensor scaled by ``temperature ** 2``. For an empty input
        with a reducing ``reduction`` a zero scalar that is still connected to
        ``student_logits`` is returned instead of NaN.
    """
    if student_logits.numel() == 0 and reduction != 'none':
        # 'batchmean' / 'mean' on zero rows would evaluate 0 / 0 and yield NaN.
        # Return a zero that keeps the autograd graph so backward() still works.
        return student_logits.sum() * 0.0

    with torch.no_grad():
        # Teacher target distribution, softened by the temperature.
        teacher_probs = F.softmax(teacher_logits / temperature, dim=-1)

    # F.kl_div expects the input (student) as log-probabilities.
    student_log_probs = F.log_softmax(student_logits / temperature, dim=-1)

    kl = F.kl_div(
        student_log_probs,
        teacher_probs,
        reduction=reduction
    )
    # Dividing the logits by T shrinks the gradients by roughly 1 / T^2, so the
    # loss is multiplied by T^2 to keep its gradient scale independent of T.
    # With T == 1 this is a no-op.
    return (temperature ** 2) * kl

三、单个epoch的训练逻辑

模型与tokenizer的初始化，这里我们要初始化两个模型，老师模型和学生模型，这里我们是加载之前SFT模型的参数

In [None]:
#初始化学生和教师模型
def init_student_model(lm_config):
    """Create the student model, load its SFT weights and return it with the tokenizer.

    The tokenizer is loaded from ``'../model/'``. The weights are read from
    ``{args.save_dir}/full_sft_{hidden_size}[_moe].pth`` and applied with
    ``strict=False``. The number of trainable parameters is logged before the
    model is moved to ``args.device``.

    Args:
        lm_config: Model configuration; ``use_moe`` and ``hidden_size`` are
            read to build the checkpoint file name.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    tokenizer = AutoTokenizer.from_pretrained('../model/')
    student = MiniMindForCausalLM(lm_config)

    # The checkpoint name encodes the hidden size and an optional MoE suffix.
    if lm_config.use_moe:
        suffix = '_moe'
    else:
        suffix = ''
    weights_path = f'{args.save_dir}/full_sft_{lm_config.hidden_size}{suffix}.pth'
    weights = torch.load(weights_path, map_location=args.device)
    student.load_state_dict(weights, strict=False)

    # Only parameters that require gradients are counted.
    trainable = sum(p.numel() for p in student.parameters() if p.requires_grad)
    Logger(f'学生模型(LLM)总参数量：{trainable / 1e6:.3f} 百万')

    student = student.to(args.device)
    return student, tokenizer


def init_teacher_model(lm_config):
    """Create the teacher model, load its SFT weights and return it.

    The weights are read from ``{args.save_dir}/full_sft_{hidden_size}[_moe].pth``
    and applied with ``strict=False``. The number of parameters that require
    gradients is logged before the model is moved to ``args.device``.

    Args:
        lm_config: Model configuration; ``use_moe`` and ``hidden_size`` are
            read to build the checkpoint file name.

    Returns:
        The teacher model placed on ``args.device``.
    """
    teacher = MiniMindForCausalLM(lm_config)

    # The checkpoint name encodes the hidden size and an optional MoE suffix.
    if lm_config.use_moe:
        suffix = '_moe'
    else:
        suffix = ''
    weights_path = f'{args.save_dir}/full_sft_{lm_config.hidden_size}{suffix}.pth'
    weights = torch.load(weights_path, map_location=args.device)
    teacher.load_state_dict(weights, strict=False)

    trainable = sum(p.numel() for p in teacher.parameters() if p.requires_grad)
    Logger(f'教师模型(LLM)总参数量：{trainable / 1e6:.3f} 百万')

    return teacher.to(args.device)

分布式初始化，torchrun环境变量自动注入，和train_pretrain一样

In [None]:
#和train_pretrain一样
def init_distributed_mode():
    """Initialise the NCCL process group and bind this process to its GPU.

    Does nothing when the module-level ``ddp`` flag is falsy. Otherwise it
    reads ``RANK``, ``LOCAL_RANK`` and ``WORLD_SIZE`` from the environment,
    stores the local rank and the ``cuda:<local_rank>`` device string in the
    globals ``ddp_local_rank`` and ``DEVICE``, and makes that device current.

    Raises:
        KeyError: If one of the three environment variables is missing.
    """
    global ddp_local_rank, DEVICE

    if not ddp:
        return

    dist.init_process_group(backend="nccl")

    env = os.environ
    # RANK and WORLD_SIZE are parsed but not used further in this function.
    _rank = int(env["RANK"])
    ddp_local_rank = int(env["LOCAL_RANK"])
    _world_size = int(env["WORLD_SIZE"])

    DEVICE = f"cuda:{ddp_local_rank}"
    torch.cuda.set_device(DEVICE)

单个epoch的训练逻辑

In [None]:
def train_epoch(epoch, wandb, alpha=0.0, temperature=1.0):
    """Run one distillation epoch over ``train_loader``.

    For every batch the student (``model``) is run under ``ctx`` and, when a
    ``teacher_model`` is set, the teacher is run under ``torch.no_grad()``.
    The loss is ``alpha * CE + (1 - alpha) * distill`` divided by
    ``args.accumulation_steps``. The optimizer steps every
    ``args.accumulation_steps`` batches, progress is logged every
    ``args.log_interval`` batches, and a half-precision student checkpoint is
    written every ``args.save_interval`` batches on the main process.

    Relies on module-level objects defined elsewhere: ``model``,
    ``teacher_model``, ``optimizer``, ``scaler``, ``ctx``, ``train_loader``,
    ``iter_per_epoch``, ``lm_config_student``, ``args`` and ``ddp``.

    Args:
        epoch: Zero-based index of the current epoch; used for the learning
            rate schedule and for logging.
        wandb: Logging object exposing ``log(dict)``, or ``None`` to disable.
        alpha: Weight of the ground-truth cross-entropy term; the distillation
            term is weighted by ``1 - alpha``.
        temperature: Temperature forwarded to ``distillation_loss_fn``.
    """
    start_time = time.time()

    # The teacher is frozen: eval mode and no parameter gradients.
    if teacher_model is not None:
        teacher_model.eval()
        teacher_model.requires_grad_(False)

    for step, (X, Y, loss_mask) in enumerate(train_loader):
        # Move the batch to the target device.
        X = X.to(args.device)
        Y = Y.to(args.device)
        loss_mask = loss_mask.to(args.device)
        # Compute and apply the per-step learning rate.
        lr = get_lr(epoch * iter_per_epoch + step,
                    args.epochs * iter_per_epoch,
                    args.learning_rate)
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

        # Student forward pass inside ctx (presumably an autocast context, or
        # a nullcontext on CPU; ctx is created outside this function).
        with ctx:
            res = model(X)
            student_logits = res.logits

        # Teacher forward pass, without building an autograd graph.
        if teacher_model is not None:
            with torch.no_grad():
                teacher_logits = teacher_model(X).logits
                # Truncate the teacher's last dimension to the student's
                # vocabulary size so both logit tensors line up.
                vocab_size_student = student_logits.size(-1)
                teacher_logits = teacher_logits[..., :vocab_size_student]

        # ========== Loss ==========
        # 1) Ground-truth cross-entropy, per token, averaged over mask == 1.
        loss_mask_flat = loss_mask.view(-1)
        ce_loss = F.cross_entropy(
            student_logits.view(-1, student_logits.size(-1)),
            Y.view(-1),
            ignore_index=0,
            reduction='none'
        )
        # Masked mean. The denominator is clamped to 1 so a batch without any
        # supervised token gives a loss of 0 instead of 0 / 0 = NaN.
        ce_loss = torch.sum(ce_loss * loss_mask_flat) / loss_mask_flat.sum().clamp(min=1)
        # Add the auxiliary loss returned by the model when MoE is enabled.
        if lm_config_student.use_moe:
            ce_loss += res.aux_loss

        # 2) Distillation loss, only on positions where mask == 1.
        if teacher_model is not None:
            valid = loss_mask_flat == 1
            student_valid = student_logits.view(-1, student_logits.size(-1))[valid]
            teacher_valid = teacher_logits.view(-1, teacher_logits.size(-1))[valid]
            if student_valid.size(0) > 0:
                distill_loss = distillation_loss_fn(
                    student_valid,
                    teacher_valid,
                    temperature=temperature
                )
            else:
                # No valid token in this batch: contribute a zero that stays
                # connected to the student graph rather than a NaN.
                distill_loss = student_valid.sum() * 0.0
        else:
            distill_loss = torch.tensor(0.0, device=args.device)

        # 3) Total loss = alpha * CE + (1 - alpha) * distill, pre-divided by
        # the number of gradient-accumulation steps.
        loss = (alpha * ce_loss + (1 - alpha) * distill_loss) / args.accumulation_steps

        scaler.scale(loss).backward()

        if (step + 1) % args.accumulation_steps == 0:
            # Unscale before clipping so the threshold applies to true gradients.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad(set_to_none=True)

        if step % args.log_interval == 0:
            spend_time = time.time() - start_time
            # Estimated minutes left in this epoch, extrapolated from the
            # average time per step so far.
            remaining_min = spend_time / (step + 1) * iter_per_epoch // 60 - spend_time // 60
            Logger(
                'Epoch:[{}/{}]({}/{}) loss:{:.4f} lr:{:.12f} epoch_Time:{}min:'.format(
                    epoch,
                    args.epochs - 1,
                    step,
                    iter_per_epoch,
                    loss.item(),
                    optimizer.param_groups[-1]['lr'],
                    remaining_min
                )
            )

            if (wandb is not None) and (not ddp or dist.get_rank() == 0):
                wandb.log({
                    "loss": loss.item(),
                    "ce_loss": ce_loss.item(),
                    "distill_loss": distill_loss.item() if teacher_model is not None else 0.0,
                    "lr": optimizer.param_groups[-1]['lr'],
                    "last-time": remaining_min
                })

        if (step + 1) % args.save_interval == 0 and (not ddp or dist.get_rank() == 0):
            model.eval()
            moe_path = '_moe' if lm_config_student.use_moe else ''
            ckp = f'{args.save_dir}/full_dist_{lm_config_student.hidden_size}{moe_path}.pth'
            # Unwrap DDP so the saved keys carry no 'module.' prefix.
            if isinstance(model, torch.nn.parallel.DistributedDataParallel):
                state_dict = model.module.state_dict()
            else:
                state_dict = model.state_dict()
            state_dict = {k: v.half() for k, v in state_dict.items()}  # save in half precision
            torch.save(state_dict, ckp)
            model.train()