In [1]:
import sys
sys.path.append('..')

import os
import random
import argparse
import time
import warnings
import json
import numpy as np
import pandas as pd
import torch
from torch import optim, nn
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data import DataLoader, Dataset, DistributedSampler

from types import SimpleNamespace
from contextlib import nullcontext

from transformers import AutoTokenizer, AutoModelForCausalLM, TextStreamer
from model.model_minimind import MiniMindConfig, MiniMindForCausalLM
from model.model_lora import *
from model.model_minimind import MiniMindConfig
from model.model_lora import save_lora, apply_lora
from trainer.trainer_utils import setup_seed

from trainer.trainer_utils import get_lr, Logger, is_main_process, lm_checkpoint, init_distributed_mode, setup_seed, init_model, SkipBatchSampler

warnings.filterwarnings('ignore')

In [2]:
DATA_LOC = '/root/xlcoder/MiniMind2-Small/dataset'
INSTR_LOC = '/root/xlcoder/MiniMind2-Small/hands-on'

In [3]:
# def init_model(args):
#     tokenizer = AutoTokenizer.from_pretrained(args.load_from)
#     if 'model' in args.load_from:
#         model = MiniMindForCausalLM(MiniMindConfig(
#             hidden_size=args.hidden_size,
#             num_hidden_layers=args.num_hidden_layers,
#             use_moe=bool(args.use_moe),
#             inference_rope_scaling=args.inference_rope_scaling
#         ))
#         moe_suffix = '_moe' if args.use_moe else ''
#         ckp = f'./{args.save_dir}/{args.weight}_{args.hidden_size}{moe_suffix}.pth'
#         model.load_state_dict(torch.load(ckp, map_location=args.device), strict=True)
#         if args.lora_weight != 'None':
#             apply_lora(model)
#             load_lora(model, f'./{args.save_dir}/lora/{args.lora_weight}_{args.hidden_size}.pth')
#     else:
#         model = AutoModelForCausalLM.from_pretrained(args.load_from, trust_remote_code=True)
#     print(f'MiniMind模型参数: {sum(p.numel() for p in model.parameters()) / 1e6:.2f} M(illion)')
#     return model.eval().to(args.device), tokenizer

In [4]:
train_data = pd.read_csv(os.path.join(DATA_LOC, 'bbc_train_std.csv'))

In [5]:
class ClassificationSFTDataset(Dataset):
    """
    MiniMind-compatible SFT Dataset for text classification
    Returns: X, Y, loss_mask
    """

    def __init__(
        self,
        df: pd.DataFrame,
        tokenizer,
        max_length=1024,
        use_instruction=True,
        use_title=True,
        instruction_position="head",  # "head" | "middle"
    ):
        super().__init__()
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.samples = df

        self.instruction = 'Classify the following passage into one of the categories: <CLS_B>, <CLS_E>, <CLS_P>, <CLS_S>, <CLS_T>.'

        self.use_instruction = use_instruction
        self.use_title = use_title
        self.instruction_position = instruction_position

        self.pad_id = tokenizer.pad_token_id

    def __len__(self):
        return len(self.samples)
        
    def format_example(self, ex):
        parts = []

        instruction = self.instruction
        title = ex.title
        content = ex.content
        label = ex.category

        if self.use_instruction and self.instruction_position == "head":
            parts.append(instruction)
            parts.append("")

        if self.use_title and title:
            parts.append("Title:")
            parts.append(title)
            parts.append("")

        parts.append("Text:")
        parts.append(content)
        parts.append("")

        if self.use_instruction and self.instruction_position == "middle":
            parts.append(instruction)
            parts.append("")

        parts.append("Label:")

        prompt = "\n".join(parts)
        return prompt, label

    def __getitem__(self, idx):
        ex = self.samples.iloc[idx]
        prompt, label = self.format_example(ex)

        # 注意：label 前加空格，避免 tokenizer 粘连
        full_text = prompt + " " + label

        encoding = self.tokenizer(
            full_text,
            truncation=True,
            max_length=self.max_length,
            padding="max_length",
            return_tensors="pt"
        )

        input_ids = encoding.input_ids.squeeze(0)
        attention_mask = encoding.attention_mask.squeeze(0)

        # -------- 构造 loss_mask（核心区别）--------
        # 只在 label 部分计算 loss
        prompt_ids = self.tokenizer(
            prompt,
            truncation=True,
            max_length=self.max_length,
            padding=False,
            return_tensors=None
        ).input_ids

        loss_mask = torch.zeros_like(input_ids)
        start = len(prompt_ids)

        # 对 label token 打 1
        loss_mask[start: start + (attention_mask[start:].sum())] = 1

        # -------- 构造 MiniMind 所需的 X, Y --------
        X = input_ids[:-1]
        Y = input_ids[1:]
        loss_mask = loss_mask[1:]

        return X, Y, loss_mask


In [6]:
def train_epoch(epoch, loader, iters, start_step=0, wandb=None):
    loss_fct = nn.CrossEntropyLoss(reduction='none')
    start_time = time.time()
    for step, (X, Y, loss_mask) in enumerate(loader, start=start_step + 1):
        X = X.to(args.device)
        Y = Y.to(args.device)
        loss_mask = loss_mask.to(args.device)
        lr = get_lr(epoch * iters + step, args.epochs * iters, args.learning_rate)
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

        with autocast_ctx:
            res = model(X)
            loss = loss_fct(
                res.logits.view(-1, res.logits.size(-1)),
                Y.view(-1)
            ).view(Y.size())

            loss = (loss * loss_mask).sum() / loss_mask.sum()
            loss += res.aux_loss
            loss = loss / args.accumulation_steps

        scaler.scale(loss).backward()

        if (step + 1) % args.accumulation_steps == 0:
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)

            scaler.step(optimizer)
            scaler.update()

            optimizer.zero_grad(set_to_none=True)
            torch.cuda.empty_cache()

        if step % args.log_interval == 0 or step == iters - 1:
            spend_time = time.time() - start_time
            current_loss = loss.item() * args.accumulation_steps
            current_lr = optimizer.param_groups[-1]['lr']
            eta_min = spend_time / (step + 1) * iters // 60 - spend_time // 60
            
            Logger(f'Epoch:[{epoch+1}/{args.epochs}]({step}/{iters}) loss:{current_loss:.6f} lr:{current_lr:.12f} epoch_Time:{eta_min}min:')
            
            if wandb: wandb.log({"loss": current_loss, "lr": current_lr, "epoch_Time": eta_min})

        if (step % args.save_interval == 0 or step == iters - 1) and is_main_process():
            model.eval()
            moe_suffix = '_moe' if lm_config.use_moe else ''
            ckp = f'{args.save_dir}/{args.save_weight}_{lm_config.hidden_size}{moe_suffix}.pth'
            if isinstance(model, torch.nn.parallel.DistributedDataParallel):
                state_dict = model.module.state_dict()
            else:
                state_dict = model.state_dict()
            state_dict = {k: v.half().cpu() for k, v in state_dict.items()}
            torch.save(state_dict, ckp)
            lm_checkpoint(lm_config, weight=args.save_weight, model=model, optimizer=optimizer, 
                         epoch=epoch, step=step, wandb=wandb, save_dir='../checkpoints', scaler=scaler)
            model.train()
            del state_dict

        del X, Y, loss_mask, res, loss

In [7]:
def train_epoch_lora(epoch, loader, iters, lora_params, start_step=0, wandb=None):
    loss_fct = nn.CrossEntropyLoss(reduction='none')
    start_time = time.time()
    for step, (X, Y, loss_mask) in enumerate(loader, start=start_step + 1):
        X = X.to(args.device)
        Y = Y.to(args.device)
        loss_mask = loss_mask.to(args.device)
        lr = get_lr(epoch * iters + step, args.epochs * iters, args.learning_rate)
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

        with autocast_ctx:
            res = model(X)
            loss = loss_fct(
                res.logits.view(-1, res.logits.size(-1)),
                Y.view(-1)
            ).view(Y.size())

            loss = (loss * loss_mask).sum() / loss_mask.sum()
            loss += res.aux_loss
            loss = loss / args.accumulation_steps

        scaler.scale(loss).backward()

        if (step + 1) % args.accumulation_steps == 0:
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(lora_params, args.grad_clip)

            scaler.step(optimizer)
            scaler.update()

            optimizer.zero_grad(set_to_none=True)
            torch.cuda.empty_cache()

        if step % args.log_interval == 0 or step == iters - 1:
            spend_time = time.time() - start_time
            current_loss = loss.item() * args.accumulation_steps
            current_lr = optimizer.param_groups[-1]['lr']
            eta_min = spend_time / (step + 1) * iters // 60 - spend_time // 60
            
            Logger(f'Epoch:[{epoch+1}/{args.epochs}]({step}/{iters}) loss:{current_loss:.6f} lr:{current_lr:.12f} epoch_Time:{eta_min}min:')
            
            if wandb: wandb.log({"loss": current_loss, "lr": current_lr, "epoch_Time": eta_min})

        if (step % args.save_interval == 0 or step == iters - 1) and is_main_process():
            model.eval()
            lora_save_path = f'{args.save_dir}/lora/{args.save_weight}_{lm_config.hidden_size}.pth'
            # LoRA只保存LoRA权重
            save_lora(model, lora_save_path)
            lm_checkpoint(lm_config, weight=args.save_weight, model=model, optimizer=optimizer, scaler=scaler, epoch=epoch, step=step, wandb=wandb, save_dir='../checkpoints')
            model.train()

        del X, Y, loss_mask, res, loss

In [8]:
args={
    'load_from': '../model',
    'save_dir': '../out',
    'save_weight': 'en_text_cls_logits',
    'epochs': 10,
    'batch_size': 16,
    'learning_rate': 5e-7,
    'device': "cuda:0" if torch.cuda.is_available() else "cpu",
    'dtype': 'bfloat16',
    'num_workers': 1,
    'accumulation_steps': 1,
    'grad_clip': 1.0,
    'log_interval': 100,
    'save_interval': 100,
    'hidden_size': 512,
    'num_hidden_layers': 8,
    'max_seq_len': 4096,
    'use_moe': 0,
    'from_weight': 'en_pretrain',
    'from_resume': 1,
    'use_wandb': 0,
    'wandb_project': "MiniMind-Classification-SFT",
    'train_mode': ''
}
args = json.loads(json.dumps(args), object_hook=lambda d: SimpleNamespace(**d))

In [9]:
# ========== 1. 初始化环境和随机种子 ==========
local_rank = init_distributed_mode()
if dist.is_initialized(): args.device = f"cuda:{local_rank}"
setup_seed(42 + (dist.get_rank() if dist.is_initialized() else 0))

# ========== 2. 配置目录、模型参数、检查ckp ==========
os.makedirs(args.save_dir, exist_ok=True)
if args.train_mode == 'lora':
    os.makedirs(os.path.join(args.save_dir, 'lora'), exist_ok=True)
lm_config = MiniMindConfig(hidden_size=args.hidden_size, num_hidden_layers=args.num_hidden_layers, use_moe=bool(args.use_moe))
ckp_data = lm_checkpoint(lm_config, weight=args.save_weight, save_dir='../checkpoints') if args.from_resume==1 else None

# ========== 3. 设置混合精度 ==========
device_type = "cuda" if "cuda" in args.device else "cpu"
dtype = torch.bfloat16 if args.dtype == "bfloat16" else torch.float16
autocast_ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast(dtype=dtype)

# ========== 4. 配wandb ==========
wandb = None
if args.use_wandb and is_main_process():
    import swanlab as wandb
    wandb_id = ckp_data.get('wandb_id') if ckp_data else None
    resume = 'must' if wandb_id else None
    wandb_run_name = f"MiniMind-LoRA-Epoch-{args.epochs}-BatchSize-{args.batch_size}-LearningRate-{args.learning_rate}"
    wandb.init(project=args.wandb_project, name=wandb_run_name, id=wandb_id, resume=resume)

# ========== 5. 定义模型、数据、优化器 ==========
model, tokenizer = init_model(lm_config, args.from_weight, device=args.device)

SPECIAL_LABELS = [
    "<CLS_B>",
    "<CLS_E>",
    "<CLS_P>",
    "<CLS_S>",
    "<CLS_T>",
]

tokenizer.add_special_tokens({
    "additional_special_tokens": SPECIAL_LABELS
})

model.resize_token_embeddings(len(tokenizer))


if args.train_mode == 'lora':
    apply_lora(model, rank=1024)
    
    # 统计参数
    total_params = sum(p.numel() for p in model.parameters())
    lora_params_count = sum(p.numel() for name, p in model.named_parameters() if 'lora' in name)
    Logger(f"LLM 总参数量: {total_params / 1e6:.3f} M")
    Logger(f"LoRA 参数量: {lora_params_count / 1e6:.3f} M")
    Logger(f"LoRA 参数占比: {lora_params_count / total_params * 100:.2f}%")
    
    # 冻结非LoRA参数，收集LoRA参数
    lora_params = []
    for name, param in model.named_parameters():
        if 'lora' in name:
            param.requires_grad = True
            lora_params.append(param)
        else:
            param.requires_grad = False

if args.train_mode == 'partial_freeze':
    # For Partial Freeze only
    for p in model.model.embed_tokens.parameters():
        p.requires_grad = False
    
    layers = model.model.layers
    freeze_n = len(layers) // 2
    for i in range(freeze_n):
        for p in layers[i].parameters():
            p.requires_grad = False

train_ds =  ClassificationSFTDataset(
    train_data,
    tokenizer,
    max_length=args.max_seq_len,
    use_instruction=True,
    use_title=True,
    instruction_position="head"
)
train_sampler = DistributedSampler(train_ds) if dist.is_initialized() else None
scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype == 'float16'))
if args.train_mode == 'lora':
    optimizer = optim.AdamW(lora_params, lr=args.learning_rate)
else:
    optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate)

# ========== 6. 从ckp恢复状态 ==========
start_epoch, start_step = 0, 0
if ckp_data:
    model.load_state_dict(ckp_data['model'])
    optimizer.load_state_dict(ckp_data['optimizer'])
    scaler.load_state_dict(ckp_data['scaler'])
    start_epoch = ckp_data['epoch']
    start_step = ckp_data.get('step', 0)

# ========== 7. DDP包模型 ==========
if dist.is_initialized():
    model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"}
    model = DistributedDataParallel(model, device_ids=[local_rank])

# ========== 8. 开始训练 ==========
for epoch in range(start_epoch, args.epochs):
    train_sampler and train_sampler.set_epoch(epoch)
    if epoch == start_epoch and start_step > 0: # 第一个epoch且存在检查点
        batch_sampler = SkipBatchSampler(train_sampler or range(len(train_ds)), args.batch_size, start_step + 1)
        loader = DataLoader(train_ds, batch_sampler=batch_sampler, num_workers=args.num_workers, pin_memory=True)
        Logger(f'Epoch [{epoch + 1}/{args.epochs}]: 跳过前{start_step}个step，从step {start_step + 1}开始')
        if args.train_mode == 'lora':
            train_epoch_lora(epoch, loader, len(loader) + start_step + 1, lora_params, start_step, wandb)
        else:
            train_epoch(epoch, loader, len(loader) + start_step + 1, start_step, wandb)
    else: # 默认从头开始
        loader = DataLoader(train_ds, batch_size=args.batch_size, shuffle=(train_sampler is None), sampler=train_sampler, num_workers=args.num_workers, pin_memory=True)
        if args.train_mode == 'lora':
            train_epoch_lora(epoch, loader, len(loader), lora_params, 0, wandb)
        else:
            train_epoch(epoch, loader, len(loader), 0, wandb)


The new embeddings will be initialized from a multivariate normal distribution that has old embeddings' mean and covariance. As described in this article: https://nlp.stanford.edu/~johnhew/vocab-expansion.html. To disable this, use `mean_resizing=False`


所加载Model可训练参数：25.830 百万
Epoch [5/10]: 跳过前111个step，从step 112开始
Epoch:[6/10](100/112) loss:0.782172 lr:0.000000212712 epoch_Time:0.0min:
Epoch:[6/10](111/112) loss:0.919683 lr:0.000000206072 epoch_Time:0.0min:
Epoch:[7/10](100/112) loss:0.635566 lr:0.000000148949 epoch_Time:0.0min:
Epoch:[7/10](111/112) loss:0.736335 lr:0.000000143259 epoch_Time:0.0min:
Epoch:[8/10](100/112) loss:0.660283 lr:0.000000097525 epoch_Time:0.0min:
Epoch:[8/10](111/112) loss:0.678993 lr:0.000000093343 epoch_Time:0.0min:
Epoch:[9/10](100/112) loss:0.677492 lr:0.000000063473 epoch_Time:0.0min:
Epoch:[9/10](111/112) loss:0.713626 lr:0.000000061208 epoch_Time:0.0min:
Epoch:[10/10](100/112) loss:0.688024 lr:0.000000050127 epoch_Time:0.0min:
Epoch:[10/10](111/112) loss:0.568730 lr:0.000000050001 epoch_Time:0.0min:
