In [1]:
from google.colab import drive
drive.mount('/content/drive')
!ls "/content/drive/MyDrive/Colab Notebooks/LLM25/MinimindReproductionGuide/BreakdownProcess/dataset"
import os
import sys
file_path = "/content/drive/MyDrive/Colab Notebooks/LLM25/MinimindReproductionGuide/BreakdownProcess"
os.chdir(file_path)
__package__ == "trainer"

Mounted at /content/drive
dpo_test.jsonl	 __init__.py	   lm_dataset.py      __pycache__
dpo_train.jsonl  lm_dataset.ipynb  pretrain_hq.jsonl  sft_mini_512.jsonl


False

In [2]:
import time
import math
import warnings
import torch
from contextlib import nullcontext
from torch import optim, nn
from torch.utils.data import DataLoader, Subset
from transformers import AutoTokenizer, AutoModelForCausalLM
from model.model_minimind import MiniMindConfig, MiniMindForCausalLM
from dataset.lm_dataset import SFTDataset
warnings.filterwarnings('ignore')


def Logger(content):
    """Write ``content`` to standard output via ``print``.

    Args:
        content: Any object; it is passed to ``print`` unchanged.
    """
    print(content)


def get_lr(current_step, total_steps, lr):
    """Cosine-decayed learning rate with a constant floor of ``lr / 10``.

    The value is ``lr / 10 + 0.5 * lr * (1 + cos(pi * current_step / total_steps))``,
    i.e. ``1.1 * lr`` at step 0 and ``0.1 * lr`` when ``current_step == total_steps``.

    Args:
        current_step: Position in the schedule.
        total_steps: Length of the schedule; must be non-zero.
        lr: Base learning rate.

    Returns:
        The learning rate to use at ``current_step``.
    """
    floor = lr / 10
    angle = math.pi * current_step / total_steps
    decay = 0.5 * lr * (1 + math.cos(angle))
    return floor + decay

In [3]:
class Arguments:
    """Plain attribute bag holding the settings for the fine-tuning run.

    Every constructor argument is stored on the instance under the same name.
    Two extra attributes are set as well:

    * ``save_dir`` starts as ``None`` and is expected to be filled in by the caller.
    * ``tokens_per_iter`` is ``batch_size * max_seq_len``.

    Note that the ``device`` and ``data_path`` defaults are evaluated once, when
    the class is defined (they read ``torch.cuda.is_available()`` and the
    module-level ``file_path`` at that moment).
    """

    def __init__(
        self,
        out_dir="./out",
        epochs=2,
        batch_size=16,
        learning_rate=5e-8,
        device="cuda" if torch.cuda.is_available() else "cpu",
        dtype="bfloat16",
        accumulation_steps=1,
        grad_clip=1,
        warmup_iters=0,
        log_interval=100,
        save_interval=100,
        local_rank=-1,
        hidden_size=512,
        num_hidden_layers=8,
        max_seq_len=512,
        data_path=os.path.join(file_path, "./dataset/sft_mini_512.jsonl"),
    ):
        # Output locations and input data.
        self.out_dir = out_dir
        self.data_path = data_path
        self.save_dir = None

        # Optimisation schedule.
        self.epochs = epochs
        self.batch_size = batch_size
        self.learning_rate = learning_rate
        self.accumulation_steps = accumulation_steps
        self.grad_clip = grad_clip
        self.warmup_iters = warmup_iters

        # Runtime / hardware.
        self.device = device
        self.dtype = dtype
        self.local_rank = local_rank

        # Reporting and checkpoint cadence (in steps).
        self.log_interval = log_interval
        self.save_interval = save_interval

        # Model shape.
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.max_seq_len = max_seq_len

        # Derived value.
        self.tokens_per_iter = self.batch_size * self.max_seq_len

In [4]:
def train_epoch(epoch):
    """Run one pass over ``train_loader`` and update ``model`` in place.

    Reads these notebook-level globals: ``args``, ``train_loader``, ``model``,
    ``optimizer``, ``scaler``, ``ctx``, ``iter_per_epoch`` and ``lm_config``.

    For each batch the function:
      1. sets the learning rate of every param group from ``get_lr``;
      2. computes per-token cross entropy, averages it over positions where
         ``loss_mask`` is non-zero, adds ``aux_loss`` from the model output and
         divides by ``args.accumulation_steps``;
      3. back-propagates through the grad scaler and, every
         ``args.accumulation_steps`` steps, clips gradients and steps the optimizer;
      4. logs every ``args.log_interval`` steps;
      5. every ``args.save_interval`` steps writes a half-precision state dict
         to ``{args.save_dir}/full_sft_{hidden_size}.pth`` (same path each time,
         so the file is overwritten).

    Args:
        epoch: Zero-based epoch index, used for the schedule position and the log line.
    """
    criterion = nn.CrossEntropyLoss(reduction='none')
    epoch_start = time.time()

    for step, batch in enumerate(train_loader):
        inputs, targets, mask = batch
        inputs = inputs.to(args.device)
        targets = targets.to(args.device)
        mask = mask.to(args.device)

        # Position in the whole-run schedule, not just within this epoch.
        global_step = epoch * iter_per_epoch + step
        current_lr = get_lr(global_step, args.epochs * iter_per_epoch, args.learning_rate)
        for group in optimizer.param_groups:
            group['lr'] = current_lr

        with ctx:
            outputs = model(inputs)
            logits = outputs.logits
            token_loss = criterion(
                logits.view(-1, logits.size(-1)),
                targets.view(-1)
            ).view(targets.size())

            # Mean over the masked-in positions only.
            loss = (token_loss * mask).sum() / mask.sum()
            loss += outputs.aux_loss
            # Scale so that accumulated gradients average rather than sum.
            loss = loss / args.accumulation_steps

        scaler.scale(loss).backward()

        is_update_step = (step + 1) % args.accumulation_steps == 0
        if is_update_step:
            # Unscale first so the clip threshold applies to true gradient norms.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad(set_to_none=True)

        is_log_step = step % args.log_interval == 0
        if is_log_step:
            elapsed = time.time() - epoch_start
            message = 'Epoch:[{}/{}]({}/{}) loss:{:.3f} lr:{:.12f} epoch_Time:{}min:'.format(
                epoch + 1,
                args.epochs,
                step,
                iter_per_epoch,
                # Undo the accumulation scaling for display.
                loss.item() * args.accumulation_steps,
                optimizer.param_groups[-1]['lr'],
                elapsed // 60)
            Logger(message)

        if (step + 1) % args.save_interval == 0:
            model.eval()
            ckpt_path = f'{args.save_dir}/full_sft_{lm_config.hidden_size}.pth'
            # Save in half precision.
            half_weights = {name: tensor.half() for name, tensor in model.state_dict().items()}
            torch.save(half_weights, ckpt_path)
            model.train()


def init_model(lm_config, device=None):
    """Build the MiniMind model, load checkpoint weights and return it with its tokenizer.

    The tokenizer is loaded from ``./model`` (relative to the current working
    directory). Weights are read from ``<file_path>/out/full_sft_512.pth``.

    Args:
        lm_config: Configuration object passed to ``MiniMindForCausalLM``.
        device: Target device for the checkpoint tensors and the model. When
            ``None`` (the default) CUDA is used if available, otherwise CPU.
            Previously ``'cuda'`` was hard-coded, which raised on CPU-only
            runtimes; behaviour on CUDA machines is unchanged.

    Returns:
        ``(model, tokenizer)`` with the model in training mode on ``device``.
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    tokenizer = AutoTokenizer.from_pretrained('./model')
    # Continues from an existing SFT checkpoint. To start from the pre-trained
    # weights instead, point this at "./out/pretrain_512.pth".
    pretrain_model_path = os.path.join(file_path, "./out/full_sft_512.pth")
    model = MiniMindForCausalLM(lm_config)
    state_dict = torch.load(pretrain_model_path, map_location=device)
    model.load_state_dict(state_dict)
    model.train()
    # Reports the number of trainable parameters, in millions.
    Logger(f'LLM可训练总参数量：{sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6:.3f} 百万')
    model.to(device)
    return model, tokenizer

In [None]:
# Training entry point. Every name assigned below is a notebook-level global;
# train_epoch() reads args, lm_config, ctx, model, train_loader, scaler,
# optimizer and iter_per_epoch directly, so statement order matters here.
if __name__ == "__main__":
    args = Arguments()
    lm_config = MiniMindConfig(hidden_size=args.hidden_size, num_hidden_layers=args.num_hidden_layers)
    # os.path.join with a single argument returns it unchanged, so save_dir is
    # the same path as out_dir and the second makedirs call below is redundant.
    args.save_dir = os.path.join(args.out_dir)
    os.makedirs(args.save_dir, exist_ok=True)
    os.makedirs(args.out_dir, exist_ok=True)
    # Same value as args.tokens_per_iter; not read again in this section.
    tokens_per_iter = args.batch_size * args.max_seq_len
    device_type = "cuda" if "cuda" in args.device else "cpu"

    # No autocast on CPU. NOTE(review): args.dtype is not forwarded to
    # autocast(), so the autocast default dtype is used — confirm this is intended.
    ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast()
    # Seed before the model is built so initialisation is repeatable.
    base_seed = 114514
    torch.manual_seed(base_seed)
    torch.cuda.manual_seed(base_seed)

    model, tokenizer = init_model(lm_config)

    train_ds = SFTDataset(args.data_path, tokenizer, max_length=args.max_seq_len)
    # Disabled: restrict training to the tail of the dataset, starting at
    # sample index 19199*32.
    # ranges = list(range(19199*32, len(train_ds)))
    # train_ds = Subset(train_ds, ranges)
    train_loader = DataLoader(
        train_ds,
        batch_size=args.batch_size,
        pin_memory=True,
        drop_last=False,
        shuffle=False
    )

    # The scaler is enabled for both 'float16' and 'bfloat16' values of args.dtype.
    scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype in ['float16', 'bfloat16']))
    optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate)

    iter_per_epoch = len(train_loader)
    # NOTE(review): this runs args.epochs - 1 epochs, while the LR schedule in
    # train_epoch spans args.epochs * iter_per_epoch steps, so the cosine decay
    # never reaches its end. Confirm the shortened run is intentional.
    for epoch in range(args.epochs-1):
        train_epoch(epoch)

In [8]:
# Write the end-of-run checkpoint next to the periodic ones, under a distinct
# "_end" file name so it does not overwrite them.
model.eval()
ckp = f'{args.save_dir}/full_sft_{lm_config.hidden_size}_end.pth'
print(ckp)
# Save in half precision: every tensor in the state dict is cast with .half().
state_dict = {
    name: tensor.half()
    for name, tensor in model.state_dict().items()
}
torch.save(state_dict, ckp)
# Back to training mode; as the last expression this also makes the notebook
# display the module's repr.
model.train()

./out/full_sft_512_end.pth


MiniMindForCausalLM(
  (model): MiniMindModel(
    (embed_tokens): Embedding(6400, 512)
    (dropout): Dropout(p=0.0, inplace=False)
    (layers): ModuleList(
      (0-7): 8 x MiniMindBlock(
        (self_attn): Attention(
          (q_proj): Linear(in_features=512, out_features=512, bias=False)
          (k_proj): Linear(in_features=512, out_features=128, bias=False)
          (v_proj): Linear(in_features=512, out_features=128, bias=False)
          (o_proj): Linear(in_features=512, out_features=512, bias=False)
          (attn_dropout): Dropout(p=0.0, inplace=False)
          (resid_dropout): Dropout(p=0.0, inplace=False)
        )
        (input_layernorm): RMSNorm()
        (post_attention_layernorm): RMSNorm()
        (mlp): FeedForward(
          (gate_proj): Linear(in_features=512, out_features=1408, bias=False)
          (down_proj): Linear(in_features=1408, out_features=512, bias=False)
          (up_proj): Linear(in_features=512, out_features=1408, bias=False)
          (drop