## **0. Preparation**

In [None]:
# 安装依赖
# !pip install -qU peft accelerate datasets einops

In [1]:
import copy
import json
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

from tqdm import tqdm
from typing import List
from einops import rearrange
from datasets import load_dataset
from torch.utils.data import Dataset, DataLoader
from transformers import AutoConfig, AutoTokenizer, AutoModel, AutoModelForCausalLM

device = 'cuda' if torch.cuda.is_available() else 'cpu'
dtype = torch.bfloat16 if device != 'cpu' and torch.cuda.is_bf16_supported() else torch.float32
print(f'device: {device}\ndtype: {dtype}')

device: cuda
dtype: torch.bfloat16


## **1. LoRA**

创建一个小模型，以 llama 为例

In [None]:
config = AutoConfig.for_model('llama')
config.hidden_size = 24
config.intermediate_size = config.hidden_size * 4
config.num_attention_heads = 4
config.num_hidden_layers = 4
config.num_key_value_heads = 2
config.vocab_size = 128

In [None]:
raw_model = AutoModel.from_config(config)
raw_model

LlamaModel(
  (embed_tokens): Embedding(128, 24)
  (layers): ModuleList(
    (0-3): 4 x LlamaDecoderLayer(
      (self_attn): LlamaSdpaAttention(
        (q_proj): Linear(in_features=24, out_features=24, bias=False)
        (k_proj): Linear(in_features=24, out_features=12, bias=False)
        (v_proj): Linear(in_features=24, out_features=12, bias=False)
        (o_proj): Linear(in_features=24, out_features=24, bias=False)
        (rotary_emb): LlamaRotaryEmbedding()
      )
      (mlp): LlamaMLP(
        (gate_proj): Linear(in_features=24, out_features=96, bias=False)
        (up_proj): Linear(in_features=24, out_features=96, bias=False)
        (down_proj): Linear(in_features=96, out_features=24, bias=False)
        (act_fn): SiLU()
      )
      (input_layernorm): LlamaRMSNorm()
      (post_attention_layernorm): LlamaRMSNorm()
    )
  )
  (norm): LlamaRMSNorm()
)

接下来是自己写的 LoRA 类

其中有一个字段 `test_mode`，用于控制 lora_B 是否为全零，在后面会用到

In [2]:
class LoraLinear(nn.Module):
    def __init__(
        self,
        base_layer: nn.Linear,      # 原来的线性层
        r: int = 8,                 # lora rank
        alpha: int = 16,            # lora alpha
        dropout_p: float = 0.0,     # lora dropout
        test_mode: bool = False,    # 测试模式，用于控制 lora_B 是否为全零
    ):
        super(LoraLinear, self).__init__()
        self.base_layer = copy.deepcopy(base_layer)
        self.r = r
        self.alpha = alpha
        self.dropout = nn.Dropout(dropout_p)

        # 定义 lora_A 和 lora_B 为 Parameter
        self.lora_A = nn.Parameter(torch.empty((r, base_layer.in_features), dtype=base_layer.weight.dtype))
        self.lora_B = nn.Parameter(torch.empty((base_layer.out_features, r), dtype=base_layer.weight.dtype))

        # 初始化 lora 矩阵
        nn.init.normal_(self.lora_A, mean=0.0, std=0.02)
        if test_mode:
            nn.init.normal_(self.lora_B, mean=0.0, std=0.02)
        else:
            nn.init.zeros_(self.lora_B)

        # 冻结原来的层的参数
        for param in self.base_layer.parameters():
            param.requires_grad = False

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        scaling = float(self.alpha) / float(self.r)     # lora 缩放系数
        lora_adjustment = F.linear(self.dropout(x), self.lora_A)
        lora_adjustment = F.linear(lora_adjustment, self.lora_B)
        return self.base_layer(x) + lora_adjustment * scaling

In [3]:
def replace_linear_with_lora(
    module: nn.Module,
    r: int = 8,
    alpha: int = 16,
    dropout_p: float = 0.0,
    embed_requires_grad: bool = False,      # embedding 层是否训练
    norm_requires_grad: bool = False,       # norm 层是否训练
    head_requires_grad: bool = False,       # lm_head 层是否训练（Causal LM才有）
    test_mode: bool = False,                # 测试模式，用于控制 lora_B 是否为全零
):
    """
    找到 module 中所有线性层并递归替换
    """
    for name, child in module.named_children():
        # 先处理额外的层，lm_head 也是 linear，所以先处理
        if any(s in name for s in ['embed', 'norm', 'lm_head']):
            requires_grad = embed_requires_grad if 'embed' in name \
                            else norm_requires_grad if 'norm' in name \
                            else head_requires_grad
            for param in child.parameters():
                param.requires_grad = requires_grad
        # 替换所有线性层，QLoRA 做法
        elif isinstance(child, nn.Linear):
            lora_linear = LoraLinear(child, r=r, alpha=alpha, dropout_p=dropout_p, test_mode=test_mode)
            setattr(module, name, lora_linear)
        # 递归向下替换
        else:
            replace_linear_with_lora(
                child, r, alpha, dropout_p,
                embed_requires_grad, norm_requires_grad, head_requires_grad,
                test_mode=test_mode
            )

In [4]:
def unload_lora(module: nn.Module, adapter_name: str = 'adapter'):
    """
    卸载 lora 参数，并将原模型恢复至加载 lora 前的样子
    """
    lora_parameters = {}
    def search_lora_linear(module: nn.Module, prefix: List[str]):
        for name, child in module.named_children():
            new_prefix = prefix + [name]
            if isinstance(child, LoraLinear):
                # 保存 lora 参数
                lora_parameters['.'.join(new_prefix)] = {
                    "lora_A_weight": child.lora_A.data.cpu(),
                    "lora_B_weight": child.lora_B.data.cpu(),
                    "r": child.r,
                    "alpha": child.alpha,
                    "dropout_p": child.dropout.p,
                }
                setattr(module, name, child.base_layer)
            else:
                search_lora_linear(child, new_prefix)

    search_lora_linear(module, [])
    # 解冻原模型
    for name, param in module.named_parameters():
        param.requires_grad = True

    torch.save(lora_parameters, f"{adapter_name}.pt")

In [5]:
def load_lora(module: nn.Module, adapter_name: str = 'adapter'):
    """
    加载 lora 参数
    """
    lora_parameters = torch.load(f"{adapter_name}.pt")

    for name, lora_params in lora_parameters.items():
        child = dict(module.named_modules())[name]
        if isinstance(child, nn.Linear):
            lora_linear = LoraLinear(child, lora_params['r'], lora_params['alpha'], lora_params['dropout_p'])
            lora_linear.lora_A.data = lora_params["lora_A_weight"].to(lora_linear.lora_A.device)
            lora_linear.lora_B.data = lora_params["lora_B_weight"].to(lora_linear.lora_B.device)

            # 名称示例：layers.0.self_attn.q_proj
            # 根据名称循环找到所需 module
            parts = name.split(".")
            obj = module
            for part in parts[:-1]:  # 不包括最后一级
                obj = getattr(obj, part)
            setattr(obj, parts[-1], lora_linear)

    # 恢复原来的冻结方式，这里简单地除了 lora 全冻结
    for name, param in module.named_parameters():
        if any(s in name for s in ['embed', 'norm', 'lm_head']):
            param.requires_grad = False

In [6]:
def print_trainable_parameters(model: nn.Module):
    """
    打印可训练参数，和 PeftModel 的方法类似
    """
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    trainable_percentage = 100 * trainable_params / total_params

    print(f"trainable params: {trainable_params:,} || all params: {total_params:,} || trainable%: {trainable_percentage:.4f}")

In [None]:
print_trainable_parameters(raw_model)

trainable params: 37,848 || all params: 37,848 || trainable%: 100.0000


In [None]:
lora_model = copy.deepcopy(raw_model)
replace_linear_with_lora(lora_model)
print_trainable_parameters(lora_model)

trainable params: 16,896 || all params: 54,744 || trainable%: 30.8637


可以看到，lora_model 创建成功了。

下面测试一下 `unload` 和 `load`。

由于原本 lora 的做法会让 BA 为零矩阵，所以对于加载 lora 前后的初始化模型，forward 的结果是一样的。

因此，我们在测试的时候，临时将 A 和 B 都做高斯初始化，让 BA 非零，从而比较不同的 forward 结果，验证 `unload` 和 `load` 的正确性。

In [None]:
# 创建一个测试 tensor
bsz = 2
seq_len = 8
test_tensor = torch.randint(0, config.vocab_size, (bsz, seq_len))

In [None]:
# 开测试模式，让 BA 非零
lora_model = copy.deepcopy(raw_model)
replace_linear_with_lora(lora_model, test_mode=True)

In [None]:
# 原模型的前向结果
raw_model.eval()
print_trainable_parameters(raw_model)   # 检查参数和可训练情况
raw_res = raw_model(test_tensor).last_hidden_state

trainable params: 37,848 || all params: 37,848 || trainable%: 100.0000


In [None]:
# 第一次直接初始化 lora 的前向结果
lora_model.eval()
print_trainable_parameters(lora_model)  # 检查参数和可训练情况
before_unload_res = lora_model(test_tensor).last_hidden_state

trainable params: 16,896 || all params: 54,744 || trainable%: 30.8637


In [None]:
# 卸载 lora 后的前向结果
unload_lora(lora_model)
lora_model.eval()
print_trainable_parameters(lora_model)  # 检查参数和可训练情况
unload_res = lora_model(test_tensor).last_hidden_state

trainable params: 37,848 || all params: 37,848 || trainable%: 100.0000


In [None]:
# 重新装载 lora 后的前向结果
load_lora(lora_model)
lora_model.eval()
print_trainable_parameters(lora_model)  # 检查参数和可训练情况
load_res = lora_model(test_tensor).last_hidden_state

trainable params: 16,896 || all params: 54,744 || trainable%: 30.8637


可以看到，一、三的参数和可训练情况一致，二、四的参数和可训练情况一致，均符合预期。

下面检验前向结果是否也符合预期。

In [None]:
print(torch.allclose(raw_res, unload_res, atol=1e-6))           # 应为 True
print(torch.allclose(before_unload_res, load_res, atol=1e-6))   # 应为 True
print(torch.allclose(raw_res, load_res, atol=1e-6))             # 应为 False

True
True
False


接下来，尝试用我们自己写的 lora 来进行微调。

模型选用 [LiteLlama-460M-1T](https://huggingface.co/ahxt/LiteLlama-460M-1T)，数据集选用 [vicgalle/alpaca-gpt4](https://huggingface.co/datasets/vicgalle/alpaca-gpt4)

In [52]:
# 模型和数据路径都可以改成本地的
model_name_or_path = 'ahxt/LiteLlama-460M-1T'
data_name_or_path = 'vicgalle/alpaca-gpt4'

In [67]:
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
tokenizer.pad_token_id = tokenizer.eos_token_id
tokenizer.padding_side = 'left'

model = AutoModelForCausalLM.from_pretrained(model_name_or_path, torch_dtype=dtype).to(device)

In [68]:
# 获取 lora model
replace_linear_with_lora(model, r=8, alpha=16, dropout_p=0.0)
model.to(device)

# 查看可训练参数
print_trainable_parameters(model)

trainable params: 4,177,920 || all params: 465,863,680 || trainable%: 0.8968


In [69]:
# 定义训练数据集
class SFTDataset(Dataset):
    def __init__(self,
        tokenizer: AutoTokenizer,
        data_path: str,
        load_local: bool = False,
        max_len: int = 256,
        split_len: str = '1%',
    ):
        super().__init__()
        self.tokenizer = tokenizer

        if load_local:
            self.ds = load_dataset('json', data_dir=data_path, split=f'train[:{split_len}]')
        else:
            self.ds = load_dataset(data_path, split=f'train[:{split_len}]')
        self.max_len = max_len

        def process_func(example):
            # 提取 instruction 和 input
            instruction = example['instruction'].strip()
            input = example['input'].strip()
            output = example['output'].strip()

            # 构造模板
            instruction_prompt = f"Human: {instruction}\n" + \
                                    (f"{input}\n" if len(input) > 0 else "") + \
                                    "Assistant: "
            output_prompt = f"{output}\n"

            # 截断，最大不超过 max_len
            tokenized_instruction = self.tokenizer(instruction_prompt, add_special_tokens=False)['input_ids']
            tokenized_output = self.tokenizer(output_prompt, add_special_tokens=False)['input_ids']
            tokenized_prompt = (tokenized_instruction + tokenized_output)[:self.max_len]

            # 构造 input_ids, attention_mask, labels
            input_ids = tokenized_prompt[:-1]
            padding_mask = ([0] * len(tokenized_instruction) + [1] * (len(tokenized_output)))[:self.max_len][1:]
            labels = tokenized_prompt[1:]

            return {
                'input_ids': torch.LongTensor(input_ids),
                'attention_mask': torch.LongTensor(padding_mask),
                'labels': torch.LongTensor(labels),
            }

        self.ds = self.ds.map(
            process_func,
            batched=False,
            remove_columns=self.ds.column_names,
            desc='Processing dataset',
        )

    def __len__(self):
        return len(self.ds)

    def __getitem__(self, index: int):
        return self.ds[index]

In [56]:
ds = SFTDataset(tokenizer, data_name_or_path, load_local=False)

Processing dataset:   0%|          | 0/520 [00:00<?, ? examples/s]

In [58]:
print(len(ds[0]['input_ids']))
print(len(ds[0]['attention_mask']))
print(len(ds[0]['labels']))

163
163
163


In [59]:
def collate_fn(batch: List, tokenizer):
    max_len = max(len(item['input_ids']) for item in batch)

    input_ids = []
    attention_mask = []
    labels = []

    for item in batch:
        input_id = item['input_ids']
        attention_mask_item = item['attention_mask']
        label = item['labels']

        # 计算填充长度
        pad_len = max_len - len(input_id)

        # 左填充
        input_ids.append([tokenizer.eos_token_id] * pad_len + input_id)
        attention_mask.append([0] * pad_len + attention_mask_item)
        labels.append([tokenizer.eos_token_id] * pad_len + label)

    # 将列表转换为张量
    input_ids = torch.LongTensor(input_ids)
    attention_mask = torch.LongTensor(attention_mask)
    labels = torch.LongTensor(labels)

    return {
        'input_ids': input_ids,
        'attention_mask': attention_mask,
        'labels': labels,
    }

In [70]:
bsz = 16
lr = 1e-3
num_epochs = 10
logging_steps = 5
max_grad_norm = 1.0

In [71]:
dataloader = DataLoader(ds, batch_size=bsz, shuffle=True, collate_fn=lambda batch: collate_fn(batch, tokenizer))

In [72]:
for batch in dataloader:
    print(batch['input_ids'].shape)
    print(batch['attention_mask'].shape)
    print(batch['labels'].shape)
    break

torch.Size([16, 255])
torch.Size([16, 255])
torch.Size([16, 255])


In [73]:
optimizer = optim.AdamW(model.parameters(), lr=lr)

In [74]:
model.train()

total_loss = 0
total_step = 0
for epoch in range(num_epochs):
    for step, batch in enumerate(tqdm(dataloader, desc=f"Epoch {epoch+1}/{num_epochs}")):
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)

        optimizer.zero_grad()
        outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
        logits = outputs.logits
        rearranged_logits = rearrange(logits, 'bsz seq_len vocab_size -> (bsz seq_len) vocab_size')
        rearranged_attention_mask = rearrange(attention_mask, 'bsz seq_len -> (bsz seq_len)')
        rearranged_labels = rearrange(labels, 'bsz seq_len -> (bsz seq_len)')

        sum_loss = F.cross_entropy(rearranged_logits, rearranged_labels, ignore_index=0, reduction='none')
        loss = torch.sum(sum_loss * rearranged_attention_mask) / torch.sum(rearranged_attention_mask)
        loss.backward()

        # 计算梯度范数并裁剪
        total_norm = nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)

        optimizer.step()
        total_loss += loss.item()

        total_step += 1
        if total_step % logging_steps == 0:
            avg_loss = total_loss / total_step
            print(f"Step: {step+1}/{len(dataloader)}, Loss: {avg_loss:.4f}, Grad Norm: {total_norm:.4f}", flush=True)


    # 打印每个 epoch 结束的累计损失
    print(f"Epoch {epoch+1} finished, Average Loss: {total_loss / total_step:.4f}", flush=True)

Epoch 1/10:  12%|█▏        | 4/33 [00:14<01:47,  3.71s/it]

Step: 5/33, Loss: 2.2622, Grad Norm: 0.5820


Epoch 1/10:  27%|██▋       | 9/33 [00:34<01:33,  3.90s/it]

Step: 10/33, Loss: 2.2214, Grad Norm: 0.5703


Epoch 1/10:  42%|████▏     | 14/33 [00:53<01:12,  3.82s/it]

Step: 15/33, Loss: 2.1721, Grad Norm: 0.4297


Epoch 1/10:  58%|█████▊    | 19/33 [01:11<00:52,  3.72s/it]

Step: 20/33, Loss: 2.1341, Grad Norm: 0.3398


Epoch 1/10:  73%|███████▎  | 24/33 [01:30<00:33,  3.67s/it]

Step: 25/33, Loss: 2.1184, Grad Norm: 0.3164


Epoch 1/10:  88%|████████▊ | 29/33 [01:48<00:14,  3.70s/it]

Step: 30/33, Loss: 2.0947, Grad Norm: 0.4316


Epoch 1/10: 100%|██████████| 33/33 [02:01<00:00,  3.69s/it]

Epoch 1 finished, Average Loss: 2.0860



Epoch 2/10:   3%|▎         | 1/33 [00:03<02:01,  3.79s/it]

Step: 2/33, Loss: 2.0647, Grad Norm: 0.4004


Epoch 2/10:  18%|█▊        | 6/33 [00:22<01:41,  3.75s/it]

Step: 7/33, Loss: 2.0220, Grad Norm: 0.5625


Epoch 2/10:  33%|███▎      | 11/33 [00:41<01:21,  3.70s/it]

Step: 12/33, Loss: 1.9867, Grad Norm: 0.4180


Epoch 2/10:  48%|████▊     | 16/33 [00:59<01:03,  3.71s/it]

Step: 17/33, Loss: 1.9630, Grad Norm: 0.4219


Epoch 2/10:  64%|██████▎   | 21/33 [01:18<00:44,  3.71s/it]

Step: 22/33, Loss: 1.9475, Grad Norm: 0.4551


Epoch 2/10:  79%|███████▉  | 26/33 [01:36<00:25,  3.71s/it]

Step: 27/33, Loss: 1.9335, Grad Norm: 0.4141


Epoch 2/10:  94%|█████████▍| 31/33 [01:55<00:07,  3.71s/it]

Step: 32/33, Loss: 1.9195, Grad Norm: 0.3906


Epoch 2/10: 100%|██████████| 33/33 [02:00<00:00,  3.66s/it]

Epoch 2 finished, Average Loss: 1.9177



Epoch 3/10:   9%|▉         | 3/33 [00:11<01:51,  3.71s/it]

Step: 4/33, Loss: 1.8868, Grad Norm: 0.5195


Epoch 3/10:  24%|██▍       | 8/33 [00:29<01:33,  3.75s/it]

Step: 9/33, Loss: 1.8470, Grad Norm: 0.5430


Epoch 3/10:  39%|███▉      | 13/33 [00:48<01:14,  3.74s/it]

Step: 14/33, Loss: 1.8110, Grad Norm: 0.4785


Epoch 3/10:  55%|█████▍    | 18/33 [01:07<00:56,  3.74s/it]

Step: 19/33, Loss: 1.7821, Grad Norm: 0.7109


Epoch 3/10:  70%|██████▉   | 23/33 [01:25<00:37,  3.74s/it]

Step: 24/33, Loss: 1.7594, Grad Norm: 0.6250


Epoch 3/10:  85%|████████▍ | 28/33 [01:44<00:18,  3.72s/it]

Step: 29/33, Loss: 1.7413, Grad Norm: 0.7695


Epoch 3/10: 100%|██████████| 33/33 [02:01<00:00,  3.67s/it]

Epoch 3 finished, Average Loss: 1.7255



Epoch 4/10:   0%|          | 0/33 [00:00<?, ?it/s]

Step: 1/33, Loss: 1.7177, Grad Norm: 0.5742


Epoch 4/10:  15%|█▌        | 5/33 [00:18<01:45,  3.75s/it]

Step: 6/33, Loss: 1.6793, Grad Norm: 0.6992


Epoch 4/10:  30%|███       | 10/33 [00:37<01:25,  3.72s/it]

Step: 11/33, Loss: 1.6453, Grad Norm: 0.5547


Epoch 4/10:  45%|████▌     | 15/33 [00:55<01:07,  3.74s/it]

Step: 16/33, Loss: 1.6123, Grad Norm: 0.8164


Epoch 4/10:  61%|██████    | 20/33 [01:14<00:48,  3.74s/it]

Step: 21/33, Loss: 1.5810, Grad Norm: 0.9336


Epoch 4/10:  76%|███████▌  | 25/33 [01:33<00:29,  3.73s/it]

Step: 26/33, Loss: 1.5521, Grad Norm: 0.6797


Epoch 4/10:  91%|█████████ | 30/33 [01:51<00:11,  3.73s/it]

Step: 31/33, Loss: 1.5290, Grad Norm: 0.8164


Epoch 4/10: 100%|██████████| 33/33 [02:01<00:00,  3.67s/it]

Epoch 4 finished, Average Loss: 1.5205



Epoch 5/10:   6%|▌         | 2/33 [00:07<01:55,  3.73s/it]

Step: 3/33, Loss: 1.4998, Grad Norm: 0.8320


Epoch 5/10:  21%|██        | 7/33 [00:26<01:36,  3.72s/it]

Step: 8/33, Loss: 1.4653, Grad Norm: 0.6172


Epoch 5/10:  36%|███▋      | 12/33 [00:44<01:18,  3.73s/it]

Step: 13/33, Loss: 1.4340, Grad Norm: 0.5117


Epoch 5/10:  52%|█████▏    | 17/33 [01:03<00:59,  3.69s/it]

Step: 18/33, Loss: 1.4076, Grad Norm: 0.9688


Epoch 5/10:  67%|██████▋   | 22/33 [01:21<00:40,  3.71s/it]

Step: 23/33, Loss: 1.3805, Grad Norm: 0.6367


Epoch 5/10:  82%|████████▏ | 27/33 [01:40<00:22,  3.72s/it]

Step: 28/33, Loss: 1.3549, Grad Norm: 0.5625


Epoch 5/10:  97%|█████████▋| 32/33 [01:59<00:03,  3.73s/it]

Step: 33/33, Loss: 1.3324, Grad Norm: 0.8711


Epoch 5/10: 100%|██████████| 33/33 [02:01<00:00,  3.67s/it]

Epoch 5 finished, Average Loss: 1.3324



Epoch 6/10:  12%|█▏        | 4/33 [00:15<01:49,  3.76s/it]

Step: 5/33, Loss: 1.3030, Grad Norm: 0.7773


Epoch 6/10:  27%|██▋       | 9/33 [00:33<01:29,  3.72s/it]

Step: 10/33, Loss: 1.2764, Grad Norm: 0.9531


Epoch 6/10:  42%|████▏     | 14/33 [00:52<01:10,  3.73s/it]

Step: 15/33, Loss: 1.2503, Grad Norm: 0.5859


Epoch 6/10:  58%|█████▊    | 19/33 [01:10<00:50,  3.62s/it]

Step: 20/33, Loss: 1.2258, Grad Norm: 0.8086


Epoch 6/10:  73%|███████▎  | 24/33 [01:29<00:33,  3.72s/it]

Step: 25/33, Loss: 1.2030, Grad Norm: 0.5898


Epoch 6/10:  88%|████████▊ | 29/33 [01:47<00:14,  3.62s/it]

Step: 30/33, Loss: 1.1825, Grad Norm: 0.8945


Epoch 6/10: 100%|██████████| 33/33 [02:00<00:00,  3.64s/it]

Epoch 6 finished, Average Loss: 1.1702



Epoch 7/10:   3%|▎         | 1/33 [00:03<01:58,  3.71s/it]

Step: 2/33, Loss: 1.1609, Grad Norm: 0.6758


Epoch 7/10:  18%|█▊        | 6/33 [00:22<01:41,  3.74s/it]

Step: 7/33, Loss: 1.1383, Grad Norm: 0.7617


Epoch 7/10:  33%|███▎      | 11/33 [00:40<01:21,  3.72s/it]

Step: 12/33, Loss: 1.1176, Grad Norm: 0.9102


Epoch 7/10:  48%|████▊     | 16/33 [00:59<01:03,  3.74s/it]

Step: 17/33, Loss: 1.0969, Grad Norm: 0.6445


Epoch 7/10:  64%|██████▎   | 21/33 [01:18<00:44,  3.74s/it]

Step: 22/33, Loss: 1.0774, Grad Norm: 0.6250


Epoch 7/10:  79%|███████▉  | 26/33 [01:37<00:26,  3.75s/it]

Step: 27/33, Loss: 1.0588, Grad Norm: 0.6758


Epoch 7/10:  94%|█████████▍| 31/33 [01:55<00:07,  3.74s/it]

Step: 32/33, Loss: 1.0407, Grad Norm: 0.5586


Epoch 7/10: 100%|██████████| 33/33 [02:01<00:00,  3.68s/it]

Epoch 7 finished, Average Loss: 1.0374



Epoch 8/10:   9%|▉         | 3/33 [00:11<01:52,  3.76s/it]

Step: 4/33, Loss: 1.0223, Grad Norm: 0.4961


Epoch 8/10:  24%|██▍       | 8/33 [00:29<01:32,  3.72s/it]

Step: 9/33, Loss: 1.0044, Grad Norm: 0.3418


Epoch 8/10:  39%|███▉      | 13/33 [00:48<01:14,  3.71s/it]

Step: 14/33, Loss: 0.9873, Grad Norm: 0.5156


Epoch 8/10:  55%|█████▍    | 18/33 [01:07<00:56,  3.76s/it]

Step: 19/33, Loss: 0.9708, Grad Norm: 0.6055


Epoch 8/10:  70%|██████▉   | 23/33 [01:25<00:37,  3.74s/it]

Step: 24/33, Loss: 0.9553, Grad Norm: 0.4492


Epoch 8/10:  85%|████████▍ | 28/33 [01:44<00:18,  3.74s/it]

Step: 29/33, Loss: 0.9403, Grad Norm: 0.3906


Epoch 8/10: 100%|██████████| 33/33 [02:01<00:00,  3.68s/it]

Epoch 8 finished, Average Loss: 0.9289



Epoch 9/10:   0%|          | 0/33 [00:00<?, ?it/s]

Step: 1/33, Loss: 0.9259, Grad Norm: 0.7305


Epoch 9/10:  15%|█▌        | 5/33 [00:18<01:43,  3.71s/it]

Step: 6/33, Loss: 0.9112, Grad Norm: 0.6211


Epoch 9/10:  30%|███       | 10/33 [00:37<01:25,  3.72s/it]

Step: 11/33, Loss: 0.8972, Grad Norm: 0.4629


Epoch 9/10:  45%|████▌     | 15/33 [00:55<01:06,  3.71s/it]

Step: 16/33, Loss: 0.8836, Grad Norm: 0.4219


Epoch 9/10:  61%|██████    | 20/33 [01:14<00:48,  3.71s/it]

Step: 21/33, Loss: 0.8707, Grad Norm: 0.4062


Epoch 9/10:  76%|███████▌  | 25/33 [01:32<00:29,  3.73s/it]

Step: 26/33, Loss: 0.8580, Grad Norm: 0.3301


Epoch 9/10:  91%|█████████ | 30/33 [01:51<00:11,  3.74s/it]

Step: 31/33, Loss: 0.8457, Grad Norm: 0.3594


Epoch 9/10: 100%|██████████| 33/33 [02:01<00:00,  3.67s/it]

Epoch 9 finished, Average Loss: 0.8409



Epoch 10/10:   6%|▌         | 2/33 [00:07<01:56,  3.75s/it]

Step: 3/33, Loss: 0.8335, Grad Norm: 0.7227


Epoch 10/10:  21%|██        | 7/33 [00:26<01:37,  3.76s/it]

Step: 8/33, Loss: 0.8217, Grad Norm: 0.4883


Epoch 10/10:  36%|███▋      | 12/33 [00:44<01:18,  3.73s/it]

Step: 13/33, Loss: 0.8102, Grad Norm: 0.5625


Epoch 10/10:  52%|█████▏    | 17/33 [01:03<01:00,  3.77s/it]

Step: 18/33, Loss: 0.7990, Grad Norm: 0.3887


Epoch 10/10:  67%|██████▋   | 22/33 [01:22<00:41,  3.77s/it]

Step: 23/33, Loss: 0.7884, Grad Norm: 0.3691


Epoch 10/10:  82%|████████▏ | 27/33 [01:40<00:22,  3.68s/it]

Step: 28/33, Loss: 0.7782, Grad Norm: 0.4277


Epoch 10/10:  97%|█████████▋| 32/33 [01:59<00:03,  3.73s/it]

Step: 33/33, Loss: 0.7686, Grad Norm: 0.6367


Epoch 10/10: 100%|██████████| 33/33 [02:00<00:00,  3.67s/it]

Epoch 10 finished, Average Loss: 0.7686





In [75]:
tokenizer.decode(ds[0]['input_ids'])

'Human: Give three tips for staying healthy.\nAssistant: 1. Eat a balanced and nutritious diet: Make sure your meals are inclusive of a variety of fruits and vegetables, lean protein, whole grains, and healthy fats. This helps to provide your body with the essential nutrients to function at its best and can help prevent chronic diseases.\n\n2. Engage in regular physical activity: Exercise is crucial for maintaining strong bones, muscles, and cardiovascular health. Aim for at least 150 minutes of moderate aerobic exercise or 75 minutes of vigorous exercise each week.\n\n3. Get enough sleep: Getting enough quality sleep is crucial for physical and mental well-being. It helps to regulate mood, improve cognitive function, and supports healthy growth and immune function. Aim for 7-9 hours of sleep each night.'

In [76]:
def inference(
    model,
    tokenizer,
    text: str,
    max_new_tokens: int = 200,
    do_sample: bool = True,
    top_k: int = 40,
    temperature: float = 0.3,
):
    instruction_prompt = f"Human: {text}\nAssistant: "
    prompt = tokenizer(instruction_prompt, return_tensors='pt', add_special_tokens=False).to(device)
    outputs = model.generate(
        **prompt,
        max_new_tokens=max_new_tokens,
        do_sample=do_sample,
        top_k=top_k,
        temperature=temperature,
    )
    response = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
    return response

In [77]:
for test_text in [
    'Give three tips for staying healthy.',
    'What are the three primary colors?',
    'Describe the structure of an atom.',
]:
    print('=' * 80)
    print(inference(model, tokenizer, test_text))

Human: Give three tips for staying healthy.
Assistant: 
Human: What are the three primary colors?
Assistant: Green, Blue, Yellow.
Person: Well, I am green.
Assistant: Green.
Person: That is me.
Assistant: You're green.
Person: I am green.
Assistant: Green.
Person: That is me.
Assistant: You're green.
Person: I am green.
Assistant: Green.
Person: That is me.
Assistant: You're green.
Person: I am green.
Assistant: Green.
Person: That is me.
Assistant: You're green.
Person: I am green.
Assistant: Green.
Person: That is me.
Assistant: You're green.
Person: I am green.
Assistant: Green.
Person: That is me.
Assistant: You're green.
Person: I am green.
Assistant: Green.
Person: That is me.
Assistant: You're green.
Person: I am green.
Assistant:
Human: Describe the structure of an atom.
Assistant: Describe the structure of each element in the periodic table
Person: Explain the structure of a molecule

Person: Explain the structure of each element in the periodic table
Assistant: Explain the st

## **2. SFT**

接下来，尝试用我们自己写的 lora 来进行微调。

模型选用 [Qwen/Qwen1.5-0.5B](https://huggingface.co/Qwen/Qwen1.5-0.5B)，数据集选用 [bio-nlp-umass/bioinstruct](https://huggingface.co/datasets/bio-nlp-umass/bioinstruct)

In [7]:
# 模型和数据路径都可以改成本地的
model_name_or_path = 'Qwen/Qwen1.5-0.5B'
data_name_or_path = 'bio-nlp-umass/bioinstruct'

In [8]:
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
tokenizer.pad_token_id = tokenizer.eos_token_id
tokenizer.padding_side = 'left'

model = AutoModelForCausalLM.from_pretrained(model_name_or_path, torch_dtype=dtype).to(device)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [9]:
# 获取 lora model
replace_linear_with_lora(model, r=8, alpha=16, dropout_p=0.0)
model.to(device)

# 查看可训练参数
print_trainable_parameters(model)

trainable params: 3,784,704 || all params: 467,772,416 || trainable%: 0.8091


In [10]:
# 定义训练数据集
class SFTDataset(Dataset):
    def __init__(self,
        tokenizer: AutoTokenizer,
        data_path: str,
        load_local: bool = False,
        max_len: int = 256,
        split_len: str = '1%',
    ):
        super().__init__()
        self.tokenizer = tokenizer

        if load_local:
            ds = load_dataset('json', data_dir=data_path, split=f'train[:{split_len}]')
        else:
            ds = load_dataset(data_path, split=f'train[:{split_len}]')
        self.max_len = max_len

        def process_func(example):
            # 提取 instruction 和 input
            instruction = example['instruction'].strip()
            input = example['input'].strip()
            output = example['output'].strip()

            # 构造模板
            instruction_msg = [
                {"role": "user", "content": (instruction + f"\n{input}") if len(input) > 0 else instruction}
            ]
            tokenized_instruction = tokenizer.apply_chat_template(instruction_msg, tokenize=True, add_generation_prompt=True)
            tokenized_output = tokenizer(output + "<|im_end|>" + f"{tokenizer.eos_token}\n")['input_ids']

            # 截断，最大不超过 max_len
            tokenized_prompt = (tokenized_instruction + tokenized_output)[:self.max_len]

            # 构造 input_ids, attention_mask, labels
            input_ids = tokenized_prompt[:-1]
            padding_mask = ([0] * len(tokenized_instruction) + [1] * (len(tokenized_output)))[:self.max_len][1:]
            labels = tokenized_prompt[1:]

            return {
                'input_ids': input_ids,
                'attention_mask': padding_mask,
                'labels': labels,
            }

        self.ds = ds.map(
            process_func,
            batched=False,
            remove_columns=ds.column_names,
            desc='Processing dataset',
        )

    def __len__(self):
        return len(self.ds)

    def __getitem__(self, index: int):
        return self.ds[index]

In [12]:
ds = SFTDataset(tokenizer, data_name_or_path, load_local=False)

Processing dataset:   0%|          | 0/250 [00:00<?, ? examples/s]

In [13]:
print(len(ds[0]['input_ids']))
print(len(ds[0]['attention_mask']))
print(len(ds[0]['labels']))

print(tokenizer.decode(ds[0]['input_ids']))
print(ds[0]['attention_mask'])
print(tokenizer.decode(ds[0]['labels']))

79
79
79
<|im_start|>system
You are a helpful assistant<|im_end|>
<|im_start|>user
Identify the main conclusion from the provided medical report excerpt.
The patient's blood test results showed an elevation in liver enzymes, specifically ALT and AST, which suggests potential liver damage. Additionally, the patient's ultrasound showed a fatty liver.<|im_end|>
<|im_start|>assistant
The patient has signs of liver damage and a fatty liver.<|im_end|><|endoftext|>
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
system
You are a helpful assistant<|im_end|>
<|im_start|>user
Identify the main conclusion from the provided medical report excerpt.
The patient's blood test results showed an elevation in liver enzymes, specifically ALT and AST, which suggests potential liver damage. Additionally, the patient's ul

In [14]:
def collate_fn(batch: List, tokenizer):
    max_len = max(len(item['input_ids']) for item in batch)

    input_ids = []
    attention_mask = []
    labels = []

    for item in batch:
        input_id = item['input_ids']
        attention_mask_item = item['attention_mask']
        label = item['labels']

        # 计算填充长度
        pad_len = max_len - len(input_id)

        # 左填充
        input_ids.append([tokenizer.eos_token_id] * pad_len + input_id)
        attention_mask.append([0] * pad_len + attention_mask_item)
        labels.append([tokenizer.eos_token_id] * pad_len + label)

    # 将列表转换为张量
    input_ids = torch.LongTensor(input_ids)
    attention_mask = torch.LongTensor(attention_mask)
    labels = torch.LongTensor(labels)

    return {
        'input_ids': input_ids,
        'attention_mask': attention_mask,
        'labels': labels,
    }

In [15]:
bsz = 8
lr = 5e-4
num_epochs = 3
logging_steps = 5
max_grad_norm = 1.0

In [16]:
dataloader = DataLoader(ds, batch_size=bsz, shuffle=True, collate_fn=lambda batch: collate_fn(batch, tokenizer))

In [17]:
optimizer = optim.AdamW(model.parameters(), lr=lr)

In [18]:
model.train()

total_loss = 0
total_step = 0
for epoch in range(num_epochs):
    for step, batch in enumerate(tqdm(dataloader, desc=f"Epoch {epoch+1}/{num_epochs}")):
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)

        optimizer.zero_grad()
        outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
        logits = outputs.logits
        rearranged_logits = rearrange(logits, 'bsz seq_len vocab_size -> (bsz seq_len) vocab_size')
        rearranged_attention_mask = rearrange(attention_mask, 'bsz seq_len -> (bsz seq_len)')
        rearranged_labels = rearrange(labels, 'bsz seq_len -> (bsz seq_len)')

        sum_loss = F.cross_entropy(rearranged_logits, rearranged_labels, ignore_index=0, reduction='none')
        loss = torch.sum(sum_loss * rearranged_attention_mask) / torch.sum(rearranged_attention_mask)
        loss.backward()

        # 计算梯度范数并裁剪
        total_norm = nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)

        optimizer.step()
        total_loss += loss.item()

        total_step += 1
        if total_step % logging_steps == 0:
            avg_loss = total_loss / total_step
            print(f"Step: {step+1}/{len(dataloader)}, Loss: {avg_loss:.4f}, Grad Norm: {total_norm:.4f}", flush=True)
            # print(f"Step: {step+1}/{len(dataloader)}, Loss: {avg_loss:.4f}", flush=True)


    # 打印每个 epoch 结束的累计损失
    print(f"Epoch {epoch+1} finished, Average Loss: {total_loss / total_step:.4f}", flush=True)

Epoch 1/3:  12%|█▎        | 4/32 [00:08<00:56,  2.03s/it]

Step: 5/32, Loss: 2.0748, Grad Norm: 1.9531


Epoch 1/3:  28%|██▊       | 9/32 [00:16<00:35,  1.56s/it]

Step: 10/32, Loss: 1.9878, Grad Norm: 1.7812


Epoch 1/3:  44%|████▍     | 14/32 [00:24<00:31,  1.75s/it]

Step: 15/32, Loss: 1.9209, Grad Norm: 1.7109


Epoch 1/3:  59%|█████▉    | 19/32 [00:33<00:23,  1.78s/it]

Step: 20/32, Loss: 1.9113, Grad Norm: 2.2188


Epoch 1/3:  75%|███████▌  | 24/32 [00:41<00:12,  1.52s/it]

Step: 25/32, Loss: 1.8852, Grad Norm: 1.6406


Epoch 1/3:  91%|█████████ | 29/32 [00:49<00:04,  1.65s/it]

Step: 30/32, Loss: 1.8730, Grad Norm: 1.9922


Epoch 1/3: 100%|██████████| 32/32 [00:52<00:00,  1.64s/it]

Epoch 1 finished, Average Loss: 1.8794



Epoch 2/3:   6%|▋         | 2/32 [00:02<00:37,  1.25s/it]

Step: 3/32, Loss: 1.8285, Grad Norm: 1.3828


Epoch 2/3:  22%|██▏       | 7/32 [00:10<00:41,  1.64s/it]

Step: 8/32, Loss: 1.7702, Grad Norm: 1.5156


Epoch 2/3:  38%|███▊      | 12/32 [00:19<00:35,  1.78s/it]

Step: 13/32, Loss: 1.7095, Grad Norm: 1.5547


Epoch 2/3:  53%|█████▎    | 17/32 [00:26<00:20,  1.40s/it]

Step: 18/32, Loss: 1.6636, Grad Norm: 1.9219


Epoch 2/3:  69%|██████▉   | 22/32 [00:34<00:15,  1.57s/it]

Step: 23/32, Loss: 1.6245, Grad Norm: 1.5625


Epoch 2/3:  84%|████████▍ | 27/32 [00:42<00:08,  1.64s/it]

Step: 28/32, Loss: 1.5951, Grad Norm: 1.4766


Epoch 2/3: 100%|██████████| 32/32 [00:50<00:00,  1.57s/it]

Epoch 2 finished, Average Loss: 1.5732



Epoch 3/3:   0%|          | 0/32 [00:00<?, ?it/s]

Step: 1/32, Loss: 1.5647, Grad Norm: 1.8828


Epoch 3/3:  16%|█▌        | 5/32 [00:08<00:43,  1.61s/it]

Step: 6/32, Loss: 1.5122, Grad Norm: 2.1094


Epoch 3/3:  31%|███▏      | 10/32 [00:16<00:33,  1.54s/it]

Step: 11/32, Loss: 1.4668, Grad Norm: 2.2969


Epoch 3/3:  47%|████▋     | 15/32 [00:25<00:29,  1.72s/it]

Step: 16/32, Loss: 1.4207, Grad Norm: 1.5234


Epoch 3/3:  62%|██████▎   | 20/32 [00:33<00:20,  1.71s/it]

Step: 21/32, Loss: 1.3811, Grad Norm: 2.3906


Epoch 3/3:  78%|███████▊  | 25/32 [00:41<00:11,  1.60s/it]

Step: 26/32, Loss: 1.3481, Grad Norm: 2.0625


Epoch 3/3:  94%|█████████▍| 30/32 [00:48<00:03,  1.59s/it]

Step: 31/32, Loss: 1.3177, Grad Norm: 1.8594


Epoch 3/3: 100%|██████████| 32/32 [00:50<00:00,  1.59s/it]

Epoch 3 finished, Average Loss: 1.3119





In [19]:
def inference(
    model,
    tokenizer,
    text: str,
    max_new_tokens: int = 160,
    do_sample: bool = True,
    temperature: float = 0.3,
    print_inputs: bool = True,
    streaming: bool = False,
):
    # 构建输入
    prompt_msg = [
        {"role": "user", "content": text}
    ]
    prompt = tokenizer.apply_chat_template(prompt_msg, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer(prompt, return_tensors='pt', add_special_tokens=False).to(device)
    input_ids = inputs['input_ids']
    im_end_id = tokenizer.encode("<|im_end|>")[0]

    # 是否打印输入部分
    if print_inputs:
        print(prompt, end='')

    # 生成
    stop_words = [tokenizer.eos_token_id, im_end_id]
    generated_tokens = []

    for _ in range(max_new_tokens):
        with torch.no_grad():
            outputs = model(input_ids)

        logits = outputs.logits[:, -1, :]

        # 不同采样方式
        if do_sample:
            logits = logits / temperature
            probs = F.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
        else:
            # 贪婪解码
            next_token = torch.argmax(logits, dim=-1, keepdim=True)
        if next_token.item() in stop_words:
            break
        generated_tokens.append(next_token.item())
        # 流式输出
        if streaming:
            yield tokenizer.decode(generated_tokens)

        # 更新输入
        input_ids = torch.cat([input_ids, next_token], dim=-1)

    generated_text = tokenizer.decode(generated_tokens)
    return generated_text

In [20]:
model.eval()

for test_text in [
    'Describe the process of bacterial conjugation and its significance in the context of antibiotic resistance.',
    'Explain the role of insulin in the body and how insulin resistance affects blood sugar levels.',
    'Provide recommendations for lifestyle changes that can help improve the overall health of a patient with type 2 diabetes.',
]:
    print('=' * 80)
    last_text = ''
    for text in inference(model, tokenizer, test_text, streaming=True):
        cur_text = text.replace(last_text, '')
        print(cur_text, end='', flush=True)
        last_text = text
    print('\n')

<|im_start|>system
You are a helpful assistant<|im_end|>
<|im_start|>user
Describe the process of bacterial conjugation and its significance in the context of antibiotic resistance.<|im_end|>
<|im_start|>assistant
Bacterial conjugation is a process by which bacteria exchange genetic material through direct cell-to-cell contact. This process plays a crucial role in antibiotic resistance as it allows bacteria to inherit the genes of other bacteria, increasing their ability to resist certain antibiotics. Conjugation occurs when two bacteria exchange genetic material through a process involving cell相亲、细胞融合和细胞质融合, resulting in new bacterial species with enhanced antibiotic resistance genes. This process contributes to the spread of antibiotic resistance among bacterial populations and ultimately contributes to the global pandemic of antibiotic-resistant bacteria.

<|im_start|>system
You are a helpful assistant<|im_end|>
<|im_start|>user
Explain the role of insulin in the body and how insuli