# 10-Reason

参数量过小的模型直接通过冷启动 SFT + GRPO 几乎无法获得推理能力，也就是说，这一训练方法对小模型推理能力的提升作用有限. 因此，MiniMind 项目作者改用推理数据集对 MiniMind 系列模型进行黑盒蒸馏（即不使用教师模型的 logits，只在推理数据上做监督训练），以此训练推理模型.

使用的推理数据格式:

```
{
    "conversations": [
        {"role": "user", "content": "Q1?"},
        {"role": "assistant", "content": "<think>T1</think>\n<answer>A1</answer>"},
        {"role": "user", "content": "Q2?"},
        {"role": "assistant", "content": "<think>T2</think>\n<answer>A2</answer>"}
    ]
}
```

此笔记本的完整实现见主仓库 `/minimind/train_reason.py`

In [1]:
import os
import platform
import argparse
import time
import math
import warnings

import pandas as pd
import torch
import torch.nn.functional as F
import torch.distributed as dist
from contextlib import nullcontext

from torch import optim, nn
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data import DataLoader, DistributedSampler
from transformers import AutoTokenizer, AutoModelForCausalLM
from model.model import MiniMindLM
from model.LMConfig import LMConfig
from model.dataset import SFTDataset

warnings.filterwarnings('ignore')

## 可选参数设置

首先，查看训练的可选参数. 在实际脚本中，这些参数通过解析命令行传入；这里我们用一个 class 将其包装起来.

In [2]:
class args:
    """Training hyper-parameters for the reasoning-distillation notebook.

    The command-line script parses these from argv; here they are plain class
    attributes so every cell can read ``args.<name>`` directly.
    The script's ``use_wandb`` and ``ddp`` switches are intentionally left out:
    this notebook runs in a single process and does not log to wandb.
    """

    # ---- schedule ----
    epochs: int = 5  # number of training epochs
    batch_size: int = 2  # the toy reasoning dataset fits in one batch of 2
    accumulation_steps: int = 1  # micro-batches per optimizer update
    learning_rate: float = 5e-4  # base learning rate
    warmup_iters: int = 0  # learning-rate warm-up iterations
    grad_clip: float = 1.0  # max global gradient norm

    # ---- hardware / precision ----
    device: str = 'cuda' if torch.cuda.is_available() else 'cpu'
    dtype: str = 'bfloat16'  # 16-bit float: 8-bit exponent + 7-bit mantissa
    num_workers: int = 1  # data-loading worker processes
    local_rank: int = 1  # device index (not read by the cells in this notebook)
    wandb_project: str = 'MiniMind-Notebook'

    # ---- model ----
    dim: int = 512  # embedding / hidden width
    n_layers: int = 2  # number of MiniMind blocks
    max_seq_len: int = 512  # maximum sequence length
    use_moe: bool = False  # enable mixture-of-experts

    # ---- data, logging and checkpoints ----
    data_path: str = './toydata/r1_data.jsonl'  # reasoning dataset path
    log_interval: int = 1  # log every N steps (1 = every step, for observation)
    save_dir: str = "./output"  # checkpoint directory
    save_weight: str = "minimind_reasoning"  # checkpoint filename prefix
    save_interval: int = 1  # save every N steps; 0 disables saving

In [3]:
# Show which device the training cells will run on.
print('查看工作设备 {}'.format(args.device))

查看工作设备 cuda


接下来，我们对分词器、MiniMind 学生模型以及数据迭代器执行初始化.

In [None]:
def init_model(lm_config):
    """Build the tokenizer and the MiniMind student model, warm-started from the DPO checkpoint.

    Args:
        lm_config: model configuration. ``dim`` and ``use_moe`` also select which
            checkpoint file (``./output/minimind_dpo_{dim}[_moe].pth``) is loaded.

    Returns:
        ``(model, tokenizer)``, with the model moved to ``args.device``.
    """
    tokenizer = AutoTokenizer.from_pretrained('./model/minimind_tokenizer')
    model = MiniMindLM(lm_config)
    moe_path = '_moe' if lm_config.use_moe else ''
    # Warm start: initialise from the DPO-stage weights.
    ckp = f'./output/minimind_dpo_{lm_config.dim}{moe_path}.pth'
    # The checkpoint is a plain {name: tensor} dict, so refuse to unpickle arbitrary objects.
    state_dict = torch.load(ckp, map_location=args.device, weights_only=True)
    # strict=False tolerates key mismatches; report them so a wrong or partial
    # checkpoint is not silently accepted.
    missing, unexpected = model.load_state_dict(state_dict, strict=False)
    if missing or unexpected:
        print(f'警告：{ckp} 与模型结构不完全匹配，缺失 {len(missing)} 个键，多余 {len(unexpected)} 个键')
    # Counts trainable parameters only (requires_grad=True).
    print(f'LLM总参数量：{sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6:.3f} 百万')
    model = model.to(args.device)
    return model, tokenizer

In [5]:
# Build the model config from args, then the warm-started student model and its tokenizer.
lm_config = LMConfig(dim=args.dim, n_layers=args.n_layers, max_seq_len=args.max_seq_len, use_moe=args.use_moe)
model, tokenizer = init_model(lm_config)

# Reasoning samples are read with SFTDataset, capped by max_length=lm_config.max_seq_len.
train_ds = SFTDataset(args.data_path, tokenizer, max_length=lm_config.max_seq_len)

train_loader = DataLoader(
    train_ds,
    batch_size=args.batch_size,
    # Pinned host memory only speeds up host-to-GPU copies; on a CPU-only run PyTorch
    # warns and ignores it, so request it only when training on CUDA.
    pin_memory='cuda' in args.device,
    drop_last=False,
    shuffle=False,
)

print(f'模型位于设备：{model.device}, 词表长度：{tokenizer.vocab_size}, DataLoader：{train_loader}')

LLM总参数量：8.915 百万
模型位于设备：cuda:0, 词表长度：6400, DataLoader：<torch.utils.data.dataloader.DataLoader object at 0x000002480320A290>


## 启动训练

接下来，我们定义 MiniMind 推理蒸馏训练所使用的学习率调度、优化器以及混合精度训练设置，并进行简单的训练.

In [6]:
# Learning-rate schedule: cosine annealing on top of a constant floor.
def get_lr(current_step, total_steps, lr):
    """Return the learning rate for ``current_step`` out of ``total_steps``.

    A cosine term decays from ``lr`` (at step 0) to 0 (at ``total_steps``), and a
    constant ``lr / 10`` is added on top. The schedule therefore starts at
    ``1.1 * lr`` and ends at ``lr / 10``.
    """
    min_lr = lr / 10
    angle = math.pi * current_step / total_steps
    return min_lr + 0.5 * lr * (1 + math.cos(angle))

# Optimizer and mixed-precision setup.
device_type = "cuda" if "cuda" in args.device else "cpu"
print(f'设备类型：{device_type}')

# Map the configured dtype name to a torch dtype; anything unrecognised falls back to FP32.
amp_dtype = {'bfloat16': torch.bfloat16, 'float16': torch.float16}.get(args.dtype, torch.float32)
print(f'使用混精度训练，数据类型：{amp_dtype}')

# Loss scaling protects gradients from underflowing float16's narrow exponent range.
# bfloat16 keeps an 8-bit exponent like float32, so it needs no scaling. The scaler is
# therefore enabled only for float16 on CUDA (GradScaler('cuda') is CUDA-only). When it
# is disabled, scale/unscale_/step/update fall through to a plain optimizer step.
scaler = torch.amp.GradScaler('cuda', enabled=(device_type == 'cuda' and args.dtype == 'float16'))

# AdamW over all student-model parameters.
optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate)

# Autocast on CUDA; on CPU an empty context (full precision).
autocast_ctx = nullcontext() if device_type == "cpu" else torch.amp.autocast(device_type='cuda', dtype=amp_dtype)

设备类型：cuda
使用混精度训练，数据类型：torch.bfloat16


接下来，我们来看看训练函数.

蒸馏思考数据集的训练过程与 SFT 类似，区别在于：在模型需要生成的序列中，`<think>`、`</think>`、`<answer>`、`</answer>` 这些标签位置的损失权重被放大为 10 倍，以加大对标签预测错误的惩罚.

In [7]:
def _reason_tag_ids(tokenizer, device):
    """Collect the token ids of the four reasoning tags into one tensor on ``device``."""
    # NOTE(review): tokenizer(tag).input_ids may contain more than the tag itself (e.g.
    # BOS/EOS if the tokenizer adds special tokens, or sub-pieces if a tag is not a single
    # vocabulary token); every such id gets up-weighted — confirm the tags are single tokens.
    ids = []
    for tag in ('<think>', '</think>', '<answer>', '</answer>'):
        ids.extend(tokenizer(tag).input_ids)
    return torch.tensor(ids, device=device)


def _save_checkpoint(lm_config):
    """Save the global ``model`` weights as fp16 CPU tensors (rank 0 only when distributed)."""
    if dist.is_initialized() and dist.get_rank() != 0:
        return
    os.makedirs(args.save_dir, exist_ok=True)  # make sure the directory exists
    model.eval()
    moe_suffix = '_moe' if lm_config.use_moe else ''
    ckp = f'{args.save_dir}/{args.save_weight}_{lm_config.dim}{moe_suffix}.pth'
    # Unwrap DDP and torch.compile wrappers so the saved keys match the bare model.
    raw_model = model.module if isinstance(model, DistributedDataParallel) else model
    raw_model = getattr(raw_model, '_orig_mod', raw_model)
    state_dict = raw_model.state_dict()
    torch.save({k: v.half().cpu() for k, v in state_dict.items()}, ckp)
    print(f'模型已保存至：{ckp}')
    model.train()


def train_epoch(epoch, loader, iters, tokenizer, lm_config, start_step=0, wandb=None):
    """Run one epoch of reasoning-distillation training on the global ``model``.

    Loss = token-level cross-entropy over supervised positions (``loss_mask``), where
    positions whose label is a reasoning-tag id are weighted 10x. It is normalised by
    the number of supervised positions, and the model's ``aux_loss`` is added.

    Relies on the notebook globals ``model``, ``optimizer``, ``scaler``,
    ``autocast_ctx``, ``get_lr`` and ``args``.

    Args:
        epoch: 0-based epoch index (LR schedule and logging).
        loader: iterable of ``(X, Y, loss_mask)`` batches.
        iters: steps per epoch, used by the LR schedule and the log line.
        tokenizer: tokenizer used to look up the reasoning-tag ids.
        lm_config: model config; ``dim``/``use_moe`` select the checkpoint filename.
        start_step: steps already done in this epoch; numbering resumes at ``start_step + 1``.
        wandb: optional run object; metrics are sent via ``wandb.log`` when truthy.
    """
    tag_ids = _reason_tag_ids(tokenizer, args.device)  # constant across steps, built once
    loss_fct = nn.CrossEntropyLoss(reduction='none')
    start_time = time.time()
    # Steps run from start_step + 1 to start_step + len(loader); used to detect the final step.
    try:
        last_step = start_step + len(loader)
    except TypeError:  # loader without __len__
        last_step = iters

    for step, (X, Y, loss_mask) in enumerate(loader, start=start_step + 1):
        X = X.to(args.device)
        Y = Y.to(args.device)
        loss_mask = loss_mask.to(args.device)
        lr = get_lr(epoch * iters + step, args.epochs * iters, args.learning_rate)
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

        with autocast_ctx:
            res = model(X)
            # NOTE(review): labels are shifted by one here relative to the logits. If
            # SFTDataset already returns Y one token ahead of X (check model/dataset.py),
            # this extra shift trains the model to predict two tokens ahead — confirm.
            shift_logits = res.logits[..., :-1, :].contiguous()
            shift_labels = Y[..., 1:].contiguous()
            shift_loss_mask = loss_mask[..., 1:].contiguous()
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)).view(shift_labels.size())

            loss_mask_flat = shift_loss_mask.view(-1)  # modified in place below
            # Normaliser = number of supervised positions before up-weighting; clamped so a
            # batch with no supervised tokens gives 0 instead of NaN.
            loss_mask_sum = loss_mask_flat.sum().clamp(min=1)
            # Up-weight tag tokens 10x, but only inside the supervised region.
            sp_ids = torch.isin(shift_labels.view(-1), tag_ids) & (loss_mask_flat > 0)
            loss_mask_flat[sp_ids] = 10
            shift_loss_mask = loss_mask_flat.view(shift_labels.size())
            logits_loss = (loss * shift_loss_mask).sum() / loss_mask_sum
            aux_loss = res.aux_loss if res.aux_loss is not None else 0.0
            loss = (logits_loss + aux_loss) / args.accumulation_steps

        scaler.scale(loss).backward()

        # `step` is 1-based, so this updates after every `accumulation_steps` micro-batches.
        if step % args.accumulation_steps == 0:
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad(set_to_none=True)

        is_last_step = step == last_step
        if step % args.log_interval == 0 or is_last_step:
            spend_time = time.time() - start_time
            current_loss = loss.item() * args.accumulation_steps
            current_aux_loss = float(aux_loss)
            current_logits_loss = logits_loss.item()
            current_lr = optimizer.param_groups[-1]['lr']
            # Estimated minutes left in this call, extrapolated from the steps run so far.
            eta_min = spend_time / (step - start_step) * max(last_step - step, 0) / 60
            print(
                f'Epoch:[{epoch + 1}/{args.epochs}]({step}/{iters}), '
                f'loss: {current_loss:.4f}, '
                f'logits_loss: {current_logits_loss:.4f}, '
                f'aux_loss: {current_aux_loss:.4f}, '
                f'lr: {current_lr:.8f}, '
                f'epoch_time: {eta_min:.1f}min'
            )
            if wandb:
                wandb.log({
                    "loss": current_loss,
                    "logits_loss": current_logits_loss,
                    "aux_loss": current_aux_loss,
                    "learning_rate": current_lr,
                    "epoch_time": eta_min
                })

        # Save on the configured interval and always on the final step.
        if args.save_interval > 0 and (step % args.save_interval == 0 or is_last_step):
            _save_checkpoint(lm_config)

        del X, Y, loss_mask, loss, res, shift_logits, shift_labels, shift_loss_mask

接下来，我们按 `args.epochs`（此处为 5）个 Epoch 启动训练并观察输出.

In [8]:
# Run the whole schedule: one train_epoch call per epoch, all epochs sharing one DataLoader.
iter_per_epoch = len(train_loader)
for epoch in range(args.epochs):
    train_epoch(
        epoch,
        loader=train_loader,
        iters=iter_per_epoch,
        tokenizer=tokenizer,
        lm_config=lm_config,
    )
print('reasoning 训练完成！')

Epoch:[1/5](1/1), loss: 11.7936, logits_loss: 11.7936, aux_loss: 0.0000, lr: 0.00050225, epoch_time: 0.0min
模型已保存至：./output/minimind_reasoning_512.pth
Epoch:[2/5](1/1), loss: 10.2723, logits_loss: 10.2723, aux_loss: 0.0000, lr: 0.00037725, epoch_time: 0.0min
模型已保存至：./output/minimind_reasoning_512.pth
Epoch:[3/5](1/1), loss: 9.3473, logits_loss: 9.3473, aux_loss: 0.0000, lr: 0.00022275, epoch_time: 0.0min
模型已保存至：./output/minimind_reasoning_512.pth
Epoch:[4/5](1/1), loss: 8.8924, logits_loss: 8.8924, aux_loss: 0.0000, lr: 0.00009775, epoch_time: 0.0min
模型已保存至：./output/minimind_reasoning_512.pth
Epoch:[5/5](1/1), loss: 8.6720, logits_loss: 8.6720, aux_loss: 0.0000, lr: 0.00005000, epoch_time: 0.0min
模型已保存至：./output/minimind_reasoning_512.pth
reasoning 训练完成！


In [9]:
# Drop the notebook's reference to the trained student model.
# NOTE(review): `optimizer` was built from model.parameters(), so it still references the
# parameters (and, after training, presumably their AdamW state). This line alone may
# therefore not release their device memory; also `del optimizer` and call
# torch.cuda.empty_cache() if freeing GPU memory is the goal.
del model