# 0. Preparation

In [57]:
# install required packages
# -q stands for "quiet" and suppresses output of the installation process.
# -U stands for "upgrade" and ensures that the latest version of the package is installed.
! pip install -qU peft accelerate datasets einops

In [3]:
import copy
import json
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

from tqdm import tqdm
from typing import List
from einops import rearrange
from datasets import load_dataset
from torch.utils.data import Dataset, DataLoader
from transformers import AutoConfig, AutoTokenizer, AutoModel, AutoModelForCausalLM

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Use bfloat16 only on a CUDA device that reports bf16 support; float32 otherwise.
dtype = torch.bfloat16 if device != 'cpu' and torch.cuda.is_bf16_supported() else torch.float32
print(f'device: {device}\ndtype: {dtype}')

device: cpu
dtype: torch.float32


# 1. LoRA by custom llama model

## 1.1 declare lora module

create a small model, using llama as example:

In [59]:
from transformers import LlamaConfig, LlamaForCausalLM

# Build the configuration of a tiny LLaMA model
config = LlamaConfig(
    hidden_size=24,
    intermediate_size=24 * 4,  # 4 x hidden_size
    num_attention_heads=4,
    num_hidden_layers=4,
    num_key_value_heads=2,
    vocab_size=128
)

# Instantiate a randomly initialised model from the configuration
# raw_model = LlamaForCausalLM(config) # this variant would include the lm_head
# AutoModel.from_config builds the bare model without a task-specific head
raw_model = AutoModel.from_config(config)

# Print the model structure
print(raw_model)

LlamaModel(
  (embed_tokens): Embedding(128, 24)
  (layers): ModuleList(
    (0-3): 4 x LlamaDecoderLayer(
      (self_attn): LlamaAttention(
        (q_proj): Linear(in_features=24, out_features=24, bias=False)
        (k_proj): Linear(in_features=24, out_features=12, bias=False)
        (v_proj): Linear(in_features=24, out_features=12, bias=False)
        (o_proj): Linear(in_features=24, out_features=24, bias=False)
      )
      (mlp): LlamaMLP(
        (gate_proj): Linear(in_features=24, out_features=96, bias=False)
        (up_proj): Linear(in_features=24, out_features=96, bias=False)
        (down_proj): Linear(in_features=96, out_features=24, bias=False)
        (act_fn): SiLU()
      )
      (input_layernorm): LlamaRMSNorm((24,), eps=1e-06)
      (post_attention_layernorm): LlamaRMSNorm((24,), eps=1e-06)
    )
  )
  (norm): LlamaRMSNorm((24,), eps=1e-06)
  (rotary_emb): LlamaRotaryEmbedding()
)


then we will create our LoRA class:


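The class takes a `test_mode` argument that controls how `lora_B` is initialised: when `test_mode` is `True`, `lora_B` is drawn from a normal distribution (non-zero), so the adapter visibly changes the output in tests; otherwise `lora_B` is all zeros, so the wrapped layer initially produces exactly the same output as the base layer.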

In [9]:
class LoraLinear(nn.Module):
    """Wrap an ``nn.Linear`` with a trainable low-rank (LoRA) update.

    The forward pass computes

        base_layer(x) + (alpha / r) * dropout(x) @ lora_A^T @ lora_B^T

    where ``lora_A`` has shape ``(r, in_features)`` and ``lora_B`` has shape
    ``(out_features, r)``. The wrapped layer is deep-copied and frozen, so only
    ``lora_A`` and ``lora_B`` receive gradients.

    Args:
        base_layer: Linear layer to adapt. It is copied, not modified.
        r: Rank of the low-rank update; must be positive.
        alpha: Numerator of the scaling factor ``alpha / r``.
        dropout_p: Dropout probability applied to the input of the LoRA branch
            only (the base branch always sees the raw input).
        test_mode: When True, ``lora_B`` is initialised from N(0, 0.02^2) so the
            adapter changes the output immediately. When False, ``lora_B`` is
            all zeros and the wrapped layer initially matches the base layer.

    Raises:
        ValueError: If ``r`` is not positive.
    """

    def __init__(
        self,
        base_layer: nn.Linear,
        r: int = 8,
        alpha: int = 16,
        dropout_p: float = 0.0,
        test_mode: bool = False,
    ):
        super().__init__()
        if r <= 0:
            # Fail early: r == 0 would otherwise only surface as a division by
            # zero inside forward().
            raise ValueError(f"LoRA rank r must be a positive integer, got {r}")

        self.base_layer = copy.deepcopy(base_layer)
        self.r = r
        self.alpha = alpha
        self.dropout = nn.Dropout(dropout_p)

        # Plain nn.Parameter matrices are sufficient for the two LoRA factors.
        # Allocate them with the base weight's device and dtype so wrapping a
        # layer that already lives on an accelerator does not produce a
        # CPU/GPU mismatch in forward().
        base_weight = self.base_layer.weight
        factory_kwargs = {"device": base_weight.device, "dtype": base_weight.dtype}
        self.lora_A = nn.Parameter(
            torch.empty(self.r, self.base_layer.in_features, **factory_kwargs)
        )
        self.lora_B = nn.Parameter(
            torch.empty(self.base_layer.out_features, self.r, **factory_kwargs)
        )

        nn.init.normal_(self.lora_A, mean=0.0, std=0.02)
        if test_mode:
            nn.init.normal_(self.lora_B, mean=0.0, std=0.02)
        else:
            nn.init.zeros_(self.lora_B)

        # Freeze the parameters of the wrapped layer.
        for param in self.base_layer.parameters():
            param.requires_grad = False

    def forward(self, x):
        """Return the base layer output plus the scaled low-rank correction."""
        scaling = float(self.alpha) / float(self.r)
        # F.linear(x, W) computes x @ W^T and expects W as (out_features, in_features).
        # The last dimension of x goes in_features -> r -> out_features.
        lora_adjustment = F.linear(self.dropout(x), self.lora_A)
        lora_adjustment = F.linear(lora_adjustment, self.lora_B)
        return self.base_layer(x) + lora_adjustment * scaling
        

In [61]:
for name, child in raw_model.named_children():
    print(f"name: {name}, child:{child}")

name: embed_tokens, child:Embedding(128, 24)
name: layers, child:ModuleList(
  (0-3): 4 x LlamaDecoderLayer(
    (self_attn): LlamaAttention(
      (q_proj): Linear(in_features=24, out_features=24, bias=False)
      (k_proj): Linear(in_features=24, out_features=12, bias=False)
      (v_proj): Linear(in_features=24, out_features=12, bias=False)
      (o_proj): Linear(in_features=24, out_features=24, bias=False)
    )
    (mlp): LlamaMLP(
      (gate_proj): Linear(in_features=24, out_features=96, bias=False)
      (up_proj): Linear(in_features=24, out_features=96, bias=False)
      (down_proj): Linear(in_features=96, out_features=24, bias=False)
      (act_fn): SiLU()
    )
    (input_layernorm): LlamaRMSNorm((24,), eps=1e-06)
    (post_attention_layernorm): LlamaRMSNorm((24,), eps=1e-06)
  )
)
name: norm, child:LlamaRMSNorm((24,), eps=1e-06)
name: rotary_emb, child:LlamaRotaryEmbedding()


Declare the function that replaces the original linear layers with LoRA linear layers.

## 1.2 declare replace_func

In [7]:
def replace_linear_with_lora(
    module: nn.Module,
    r: int = 8,
    alpha: int = 16,
    dropout_p: float = 0.0,
    embed_requires_grad: bool = False,
    norm_requires_grad: bool = False,
    head_requires_grad: bool = False,
    test_mode: bool = False,
):
    """Recursively replace ``nn.Linear`` children of ``module`` with ``LoraLinear``.

    The module tree is modified in place. Children are handled by name and type:

    * A child whose name contains ``"embed"``, ``"norm"`` or ``"head"`` is never
      wrapped; its parameters get ``requires_grad`` set from the matching
      ``*_requires_grad`` flag.
    * An ``nn.Linear`` child is replaced by ``LoraLinear(child, r, alpha, dropout_p)``.
    * Any other child is searched recursively with the same settings.

    Args:
        module: Root of the (sub)tree to convert.
        r: LoRA rank forwarded to ``LoraLinear``.
        alpha: LoRA alpha forwarded to ``LoraLinear``.
        dropout_p: LoRA dropout probability forwarded to ``LoraLinear``.
        embed_requires_grad: Whether parameters of "embed" children stay trainable.
        norm_requires_grad: Whether parameters of "norm" children stay trainable.
        head_requires_grad: Whether parameters of "head" children stay trainable.
        test_mode: Forwarded to every ``LoraLinear`` that gets created.
    """
    for name, child in module.named_children():
        if any(s in name for s in ["embed", "norm", "head"]):
            if "embed" in name:
                requires_grad = embed_requires_grad
            elif "norm" in name:
                requires_grad = norm_requires_grad
            else:
                requires_grad = head_requires_grad
            for param in child.parameters():
                param.requires_grad = requires_grad
        elif isinstance(child, LoraLinear):
            # Already wrapped: recursing would wrap its inner base_layer a
            # second time, so calling this function twice stays harmless.
            continue
        elif isinstance(child, nn.Linear):
            # Replace the attribute `name` on `module` with the LoRA wrapper.
            # The wrapper's own parameters then show up as
            # `<name>.lora_A` / `<name>.lora_B` in named_parameters().
            setattr(module, name, LoraLinear(child, r, alpha, dropout_p, test_mode=test_mode))
        else:
            # test_mode must be forwarded explicitly; otherwise it silently
            # resets to False below the top level and nested layers get a
            # zero-initialised lora_B.
            replace_linear_with_lora(
                child,
                r,
                alpha,
                dropout_p,
                embed_requires_grad,
                norm_requires_grad,
                head_requires_grad,
                test_mode=test_mode,
            )

## 1.3 declare unload and load functions

In [63]:
def unload_lora(module: nn.Module, adapter_name: str = 'adapter'):
    """Strip every ``LoraLinear`` out of ``module`` and save the adapter to disk.

    The module tree is walked depth-first. Each ``LoraLinear`` found is replaced
    in its parent by its ``base_layer``; its LoRA tensors (moved with ``.cpu()``)
    together with ``r``, ``alpha`` and the dropout probability are recorded under
    the dotted module path. Afterwards every parameter of ``module`` is marked
    trainable and the collected dict is written to ``{adapter_name}.pt`` with
    ``torch.save``.
    """
    collected = {}

    def detach_adapters(parent: nn.Module, path: str):
        for child_name, sub in parent.named_children():
            dotted = f"{path}.{child_name}" if path else child_name
            if not isinstance(sub, LoraLinear):
                detach_adapters(sub, dotted)
                continue
            collected[dotted] = dict(
                lora_A_weight=sub.lora_A.data.cpu(),
                lora_B_weight=sub.lora_B.data.cpu(),
                r=sub.r,
                alpha=sub.alpha,
                # The probability lives on the nn.Dropout submodule (`dropout.p`);
                # LoraLinear has no `dropout_p` attribute of its own.
                dropout_p=sub.dropout.p,
            )
            setattr(parent, child_name, sub.base_layer)

    detach_adapters(module, "")

    # Unfreeze all parameters of the restored model.
    for weight in module.parameters():
        weight.requires_grad = True

    torch.save(collected, f"{adapter_name}.pt")

In [None]:
def load_lora(module: nn.Module, adapter_name: str = 'adapter'):
    """Re-attach a saved LoRA adapter to ``module`` in place.

    Reads ``{adapter_name}.pt`` (a dict keyed by dotted module path, each value
    holding ``lora_A_weight``, ``lora_B_weight``, ``r``, ``alpha`` and
    ``dropout_p``). Every listed ``nn.Linear`` submodule is replaced by a
    ``LoraLinear`` carrying the stored weights; entries whose target is not an
    ``nn.Linear`` are skipped. Finally, parameters whose name contains
    ``'embed'``, ``'norm'`` or ``'lm_head'`` are frozen.

    Raises:
        KeyError: If a stored module path does not exist in ``module``.
    """
    # NOTE(review): torch.load unpickles the file, so only load adapter files
    # from a trusted source. map_location keeps loading independent of the
    # device layout of the machine that saved the file.
    lora_parameters = torch.load(f'{adapter_name}.pt', map_location='cpu')

    # Build the name -> module lookup once rather than once per adapter entry.
    # named_modules() (not named_children()) is required because the keys are
    # full dotted paths. The snapshot stays valid while we swap layers: each
    # key points at a distinct Linear module.
    modules_by_name = dict(module.named_modules())

    for name, lora_params in lora_parameters.items():
        child = modules_by_name[name]
        if not isinstance(child, nn.Linear):
            continue

        lora_linear = LoraLinear(
            child, lora_params['r'], lora_params['alpha'], lora_params['dropout_p']
        )
        # Place the restored factors on the base layer's device and dtype. The
        # freshly constructed LoRA parameters may sit on the default device,
        # which is not necessarily where the model lives.
        target = child.weight
        lora_linear.lora_A.data = lora_params['lora_A_weight'].to(
            device=target.device, dtype=target.dtype
        )
        lora_linear.lora_B.data = lora_params['lora_B_weight'].to(
            device=target.device, dtype=target.dtype
        )

        # Walk to the parent module and swap the attribute.
        parts = name.split(".")
        parent = module
        for part in parts[:-1]:
            parent = getattr(parent, part)
        setattr(parent, parts[-1], lora_linear)

    for name, param in module.named_parameters():
        if any(s in name for s in ['embed', 'norm', 'lm_head']):
            param.requires_grad = False

## 1.4 Test creating the LoRA llama model

In [10]:
def print_trainable_parameters(model: nn.Module):
    """Print trainable vs. total parameter counts, similar to PeftModel's helper.

    Output format:
    ``trainable params: N || all params: M || trainable%: P`` with ``P`` shown to
    four decimals. A model without any parameters reports ``0.0000`` rather than
    raising ``ZeroDivisionError``.
    """
    total_params = 0
    trainable_params = 0
    # Single pass over the parameters instead of two separate generators.
    for param in model.parameters():
        count = param.numel()
        total_params += count
        if param.requires_grad:
            trainable_params += count

    trainable_percentage = 100 * trainable_params / total_params if total_params else 0.0

    print(f"trainable params: {trainable_params:,} || all params: {total_params:,} || trainable%: {trainable_percentage:.4f}")

In [66]:
print_trainable_parameters(raw_model)

trainable params: 37,848 || all params: 37,848 || trainable%: 100.0000


In [67]:
lora_model = copy.deepcopy(raw_model)
replace_linear_with_lora(lora_model)
print_trainable_parameters(lora_model)


trainable params: 16,896 || all params: 54,744 || trainable%: 30.8637


In [68]:
print(lora_model)

LlamaModel(
  (embed_tokens): Embedding(128, 24)
  (layers): ModuleList(
    (0-3): 4 x LlamaDecoderLayer(
      (self_attn): LlamaAttention(
        (q_proj): LoraLinear(
          (base_layer): Linear(in_features=24, out_features=24, bias=False)
          (dropout): Dropout(p=0.0, inplace=False)
        )
        (k_proj): LoraLinear(
          (base_layer): Linear(in_features=24, out_features=12, bias=False)
          (dropout): Dropout(p=0.0, inplace=False)
        )
        (v_proj): LoraLinear(
          (base_layer): Linear(in_features=24, out_features=12, bias=False)
          (dropout): Dropout(p=0.0, inplace=False)
        )
        (o_proj): LoraLinear(
          (base_layer): Linear(in_features=24, out_features=24, bias=False)
          (dropout): Dropout(p=0.0, inplace=False)
        )
      )
      (mlp): LlamaMLP(
        (gate_proj): LoraLinear(
          (base_layer): Linear(in_features=24, out_features=96, bias=False)
          (dropout): Dropout(p=0.0, inplace=False)

In [69]:
def print_model_parameters(model):
    """Print a two-line header, then one row per named parameter.

    Each row shows the parameter name left-aligned in a 50-character column
    followed by its ``requires_grad`` flag.
    """
    rows = ["Layer Name & Parameters", "-------------------"]
    rows.extend(
        f"{param_name:50} | Requires Grad: {tensor.requires_grad}"
        for param_name, tensor in model.named_parameters()
    )
    print("\n".join(rows))

In [70]:
print_model_parameters(lora_model)

Layer Name & Parameters
-------------------
embed_tokens.weight                                | Requires Grad: False
layers.0.self_attn.q_proj.lora_A                   | Requires Grad: True
layers.0.self_attn.q_proj.lora_B                   | Requires Grad: True
layers.0.self_attn.q_proj.base_layer.weight        | Requires Grad: False
layers.0.self_attn.k_proj.lora_A                   | Requires Grad: True
layers.0.self_attn.k_proj.lora_B                   | Requires Grad: True
layers.0.self_attn.k_proj.base_layer.weight        | Requires Grad: False
layers.0.self_attn.v_proj.lora_A                   | Requires Grad: True
layers.0.self_attn.v_proj.lora_B                   | Requires Grad: True
layers.0.self_attn.v_proj.base_layer.weight        | Requires Grad: False
layers.0.self_attn.o_proj.lora_A                   | Requires Grad: True
layers.0.self_attn.o_proj.lora_B                   | Requires Grad: True
layers.0.self_attn.o_proj.base_layer.weight        | Requires Grad: False
la

## 1.5 test the load and unload function

In [71]:
# Build a test tensor of random token ids
bsz = 2
seq_len = 8
# torch.randint draws integers uniformly from [low, high); the last argument is the output shape
test_tensor = torch.randint(0, config.vocab_size, (bsz, seq_len))

In [72]:
lora_model_test = copy.deepcopy(raw_model)
replace_linear_with_lora(lora_model_test, test_mode=True)

In [73]:
raw_model.eval()
print_trainable_parameters(raw_model)
raw_res = raw_model(test_tensor).last_hidden_state

trainable params: 37,848 || all params: 37,848 || trainable%: 100.0000


In [74]:
raw_res.shape # [batch_size, seq_len, hidden_size]

torch.Size([2, 8, 24])

In [75]:
lora_model.eval()
print_trainable_parameters(lora_model)
before_unload_res = lora_model(test_tensor).last_hidden_state

trainable params: 16,896 || all params: 54,744 || trainable%: 30.8637


In [76]:
unload_lora(lora_model)


In [77]:
lora_model.eval()
print_trainable_parameters(lora_model)
# Forward pass AFTER unloading the adapter. It is stored under its own name so
# it does not overwrite `before_unload_res`, which holds the pre-unload output
# and is needed to compare the two results.
unload_res = lora_model(test_tensor).last_hidden_state

trainable params: 37,848 || all params: 37,848 || trainable%: 100.0000


In [84]:
load_lora(lora_model)

In [85]:
# Forward pass after re-attaching the LoRA adapter

lora_model.eval()
print_trainable_parameters(lora_model)  # check the parameter counts and which ones are trainable
load_res = lora_model(test_tensor).last_hidden_state

trainable params: 16,896 || all params: 54,744 || trainable%: 30.8637


# 2. lora finetuning by litellama-460M-1T model

## 2.1 加载模型和分词器

In [4]:
model_name = 'ahxt/LiteLlama-460M-1T'
data_name = 'vicgalle/alpaca-gpt4'

In [5]:
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Reuse the EOS token id as the padding token id.
tokenizer.pad_token_id = tokenizer.eos_token_id
# Pad on the left side of each sequence.
tokenizer.padding_side = "left"

# Load the causal-LM checkpoint in the dtype selected earlier and move it to the target device.
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=dtype).to(device)

tokenizer_config.json:   0%|          | 0.00/252 [00:00<?, ?B/s]

vocab.json: 0.00B [00:00, ?B/s]

merges.txt: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/189 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/364 [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/923M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/132 [00:00<?, ?B/s]

## 2.2 将模型的线性层替换成lora module

In [11]:
# Wrap the model's nn.Linear layers with LoraLinear in place (children whose
# name contains "embed", "norm" or "head" are frozen instead of wrapped).
replace_linear_with_lora(model, r=8, alpha=16, dropout_p=0.0)
# Move the whole model, including the newly created LoRA parameters, to the target device.
model.to(device)

print_trainable_parameters(model)

trainable params: 4,177,920 || all params: 465,863,680 || trainable%: 0.8968


In [12]:
print(model)

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(50304, 1024, padding_idx=0)
    (layers): ModuleList(
      (0-23): 24 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): LoraLinear(
            (base_layer): Linear(in_features=1024, out_features=1024, bias=False)
            (dropout): Dropout(p=0.0, inplace=False)
          )
          (k_proj): LoraLinear(
            (base_layer): Linear(in_features=1024, out_features=128, bias=False)
            (dropout): Dropout(p=0.0, inplace=False)
          )
          (v_proj): LoraLinear(
            (base_layer): Linear(in_features=1024, out_features=128, bias=False)
            (dropout): Dropout(p=0.0, inplace=False)
          )
          (o_proj): LoraLinear(
            (base_layer): Linear(in_features=1024, out_features=1024, bias=False)
            (dropout): Dropout(p=0.0, inplace=False)
          )
        )
        (mlp): LlamaMLP(
          (gate_proj): LoraLinear(
            (bas

## 2.3 数据预处理

### 2.3.1 构造数据集类

In [None]:
class SFTDataset(Dataset):
    """Instruction-tuning dataset that pre-tokenises every example once.

    Each record is expected to carry ``instruction``, ``input`` and ``output``
    string fields. The prompt is ``"Human: {instruction}\\n"`` followed by
    ``"{input}\\n"`` when the input is non-empty and then ``"Assistant: "``;
    the target is ``"{output}\\n"``. Prompt and target are tokenised separately,
    concatenated, and truncated to ``max_len`` tokens.

    Every item provides three equally long sequences (next-token layout):

    * ``input_ids``      - all tokens except the last one,
    * ``labels``         - all tokens except the first one,
    * ``attention_mask`` - 1 where the label is a target token, 0 where it is a
      prompt token. If the prompt alone reaches ``max_len``, the mask is all zeros.

    Args:
        tokenizer: Callable returning a mapping with an ``'input_ids'`` entry.
        data_path: Dataset name, or a local directory when ``load_local`` is True.
        load_local: Load JSON files from ``data_path`` instead of a named dataset.
        max_len: Maximum number of tokens kept per example before shifting.
        split_len: Slice of the train split to load, e.g. ``'1%'``.
    """

    def __init__(
        self,
        tokenizer: AutoTokenizer,
        data_path: str,
        load_local: bool = False,
        max_len: int = 256,
        split_len: str = '1%',
    ):
        super().__init__()
        self.tokenizer = tokenizer

        split = f'train[:{split_len}]'
        if load_local:
            self.ds = load_dataset('json', data_dir=data_path, split=split)
        else:
            self.ds = load_dataset(data_path, split=split)
        self.max_len = max_len

        def encode_example(example):
            instruction = example['instruction'].strip()
            extra_input = example['input'].strip()
            answer = example['output'].strip()

            prompt_pieces = [f"Human: {instruction}\n"]
            if len(extra_input) > 0:
                prompt_pieces.append(f"{extra_input}\n")
            prompt_pieces.append("Assistant: ")
            prompt_text = "".join(prompt_pieces)
            answer_text = f"{answer}\n"

            prompt_ids = self.tokenizer(prompt_text, add_special_tokens=False)['input_ids']
            answer_ids = self.tokenizer(answer_text, add_special_tokens=False)['input_ids']

            # Tokens beyond max_len are dropped. When the prompt alone fills the
            # window, no target token survives and the mask below is all zeros.
            token_ids = (prompt_ids + answer_ids)[:self.max_len]
            target_flags = ([0] * len(prompt_ids) + [1] * len(answer_ids))[:self.max_len]

            # Shift by one for next-token prediction: position i predicts token i + 1.
            return {
                'input_ids': torch.LongTensor(token_ids[:-1]),
                'attention_mask': torch.LongTensor(target_flags[1:]),
                'labels': torch.LongTensor(token_ids[1:]),
            }

        self.ds = self.ds.map(
            encode_example,
            batched=False,
            remove_columns=self.ds.column_names,
            desc='Processing dataset...',
        )

    def __len__(self):
        return len(self.ds)

    def __getitem__(self, index: int):
        return self.ds[index]

In [14]:
ds = SFTDataset(tokenizer, data_name, load_local=False)

README.md: 0.00B [00:00, ?B/s]

(…)-00000-of-00001-6ef3991c06080e14.parquet:   0%|          | 0.00/48.4M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/52002 [00:00<?, ? examples/s]

Processing dataset...:   0%|          | 0/520 [00:00<?, ? examples/s]

In [19]:
print(len(ds[0]['input_ids']))
print(ds[0]['input_ids'])


163
[20490, 25, 13786, 1115, 9040, 329, 10589, 5448, 13, 198, 48902, 25, 220, 16, 13, 27574, 257, 12974, 290, 48102, 5496, 25, 6889, 1654, 534, 13840, 389, 19889, 286, 257, 4996, 286, 15921, 290, 13701, 11, 10904, 7532, 11, 2187, 21824, 11, 290, 5448, 27997, 13, 770, 5419, 284, 2148, 534, 1767, 351, 262, 6393, 20901, 284, 2163, 379, 663, 1266, 290, 460, 1037, 2948, 10726, 10040, 13, 198, 198, 17, 13, 1985, 496, 287, 3218, 3518, 3842, 25, 32900, 318, 8780, 329, 10941, 1913, 11945, 11, 12749, 11, 290, 21134, 1535, 13, 36223, 329, 379, 1551, 6640, 2431, 286, 10768, 43294, 5517, 393, 5441, 2431, 286, 31543, 5517, 1123, 1285, 13, 198, 198, 18, 13, 3497, 1576, 3993, 25, 18067, 1576, 3081, 3993, 318, 8780, 329, 3518, 290, 5110, 880, 12, 11873, 13, 632, 5419, 284, 16697, 10038, 11, 2987, 10870, 2163, 11, 290, 6971, 5448, 3349, 290, 10900, 2163, 13, 36223, 329, 767, 12, 24, 2250, 286, 3993, 1123, 1755, 13]


In [20]:
print(len(ds[0]['attention_mask']))
print(ds[0]['attention_mask'])

163
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]


In [21]:
print(len(ds[0]['labels']))
print(ds[0]['labels'])

163
[25, 13786, 1115, 9040, 329, 10589, 5448, 13, 198, 48902, 25, 220, 16, 13, 27574, 257, 12974, 290, 48102, 5496, 25, 6889, 1654, 534, 13840, 389, 19889, 286, 257, 4996, 286, 15921, 290, 13701, 11, 10904, 7532, 11, 2187, 21824, 11, 290, 5448, 27997, 13, 770, 5419, 284, 2148, 534, 1767, 351, 262, 6393, 20901, 284, 2163, 379, 663, 1266, 290, 460, 1037, 2948, 10726, 10040, 13, 198, 198, 17, 13, 1985, 496, 287, 3218, 3518, 3842, 25, 32900, 318, 8780, 329, 10941, 1913, 11945, 11, 12749, 11, 290, 21134, 1535, 13, 36223, 329, 379, 1551, 6640, 2431, 286, 10768, 43294, 5517, 393, 5441, 2431, 286, 31543, 5517, 1123, 1285, 13, 198, 198, 18, 13, 3497, 1576, 3993, 25, 18067, 1576, 3081, 3993, 318, 8780, 329, 3518, 290, 5110, 880, 12, 11873, 13, 632, 5419, 284, 16697, 10038, 11, 2987, 10870, 2163, 11, 290, 6971, 5448, 3349, 290, 10900, 2163, 13, 36223, 329, 767, 12, 24, 2250, 286, 3993, 1123, 1755, 13, 198]


In [17]:
print(ds[0])

{'input_ids': [20490, 25, 13786, 1115, 9040, 329, 10589, 5448, 13, 198, 48902, 25, 220, 16, 13, 27574, 257, 12974, 290, 48102, 5496, 25, 6889, 1654, 534, 13840, 389, 19889, 286, 257, 4996, 286, 15921, 290, 13701, 11, 10904, 7532, 11, 2187, 21824, 11, 290, 5448, 27997, 13, 770, 5419, 284, 2148, 534, 1767, 351, 262, 6393, 20901, 284, 2163, 379, 663, 1266, 290, 460, 1037, 2948, 10726, 10040, 13, 198, 198, 17, 13, 1985, 496, 287, 3218, 3518, 3842, 25, 32900, 318, 8780, 329, 10941, 1913, 11945, 11, 12749, 11, 290, 21134, 1535, 13, 36223, 329, 379, 1551, 6640, 2431, 286, 10768, 43294, 5517, 393, 5441, 2431, 286, 31543, 5517, 1123, 1285, 13, 198, 198, 18, 13, 3497, 1576, 3993, 25, 18067, 1576, 3081, 3993, 318, 8780, 329, 3518, 290, 5110, 880, 12, 11873, 13, 632, 5419, 284, 16697, 10038, 11, 2987, 10870, 2163, 11, 290, 6971, 5448, 3349, 290, 10900, 2163, 13, 36223, 329, 767, 12, 24, 2250, 286, 3993, 1123, 1755, 13], 'attention_mask': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1,

### 2.3.2 构造dataloader

collate_fn函数的作用是将已经分好批的数据处理成符合模型输入格式的一个批次数据

In [24]:
def collate_fn(batch: List, tokenizer):
    """Left-pad a list of samples to a common length and stack them into tensors.

    Args:
        batch: Samples, each a mapping with list-valued ``'input_ids'``,
            ``'attention_mask'`` and ``'labels'`` entries.
        tokenizer: Object exposing ``eos_token_id``, used as the fill value for
            ``input_ids`` and ``labels``.

    Returns:
        Dict with ``'input_ids'``, ``'attention_mask'`` and ``'labels'`` as
        ``LongTensor`` batches. Padding is prepended; the mask is padded with 0.
        The pad width of every field is derived from the sample's ``input_ids``
        length.
    """
    width = max(len(sample['input_ids']) for sample in batch)
    fill_id = tokenizer.eos_token_id
    gaps = [width - len(sample['input_ids']) for sample in batch]

    padded_ids = [[fill_id] * gap + sample['input_ids'] for gap, sample in zip(gaps, batch)]
    padded_mask = [[0] * gap + sample['attention_mask'] for gap, sample in zip(gaps, batch)]
    padded_labels = [[fill_id] * gap + sample['labels'] for gap, sample in zip(gaps, batch)]

    # The training loop consumes a dict with exactly these three keys.
    return {
        'input_ids': torch.LongTensor(padded_ids),
        'attention_mask': torch.LongTensor(padded_mask),
        'labels': torch.LongTensor(padded_labels),
    }
        
    

## 2.4 训练

In [22]:
# Batch size handed to the DataLoader.
bsz = 16
# Learning rate handed to the optimizer.
lr = 1e-3
# Number of passes over the dataloader.
num_epochs = 10
# Print the running average loss every `logging_steps` optimisation steps.
logging_steps = 5
# Upper bound on the gradient norm used for clipping.
max_grad_norm = 1.0

In [29]:
dataloader = DataLoader(ds, batch_size=bsz, shuffle=True, collate_fn=lambda batch: collate_fn(batch, tokenizer))

In [26]:
print(dataloader)

<torch.utils.data.dataloader.DataLoader object at 0x329082750>


In [35]:
next(iter(dataloader))['input_ids'].shape

torch.Size([16, 255])

In [36]:
optimizer = optim.Adam(model.parameters(), lr=lr)

下面是参数更新的公式，$g_{lora\_A\_new}$是经过梯度裁剪后的梯度

$$\text{lora\_A} \gets \text{lora\_A} - \text{lr} \cdot \text{adam\_update}(g_{\text{lora\_A\_new}})$$


In [None]:
# Put the model in training mode (enables dropout and other train-only behaviour).
model.train()

# Running totals used to report the average loss.
total_loss = 0
total_step = 0

# Outer loop: one iteration per epoch.
for epoch in range(num_epochs):
    # Inner loop: one iteration per batch; tqdm renders the progress bar.
    for step, batch in enumerate(tqdm(dataloader, desc=f"Epoch {epoch+1}/{num_epochs}")):
        input_ids = batch['input_ids'].to(device)  # (batch_size, seq_len)
        labels = batch['labels'].to(device)        # (batch_size, seq_len), already shifted by one

        # Despite its key, batch['attention_mask'] is a LOSS mask: it is 1 only
        # where the label is a response token and 0 on prompt and padding
        # positions. It must not be given to the model as the attention mask,
        # or the prompt tokens would be hidden from attention during training
        # (while being fully visible at inference time).
        loss_mask = batch['attention_mask'].to(device)

        # Real attention mask: hide only the left padding, i.e. the leading run
        # of pad tokens. Using cumsum keeps a pad/eos id that appears later in
        # the text visible.
        attention_mask = (input_ids.ne(tokenizer.pad_token_id).cumsum(dim=1) > 0).long()

        # Reset gradients left over from the previous step.
        optimizer.zero_grad()

        # Forward pass. `labels` is deliberately NOT passed: the model would
        # shift them again and compute an internal loss that is never used.
        outputs = model(input_ids, attention_mask=attention_mask)
        logits = outputs.logits  # (batch_size, seq_len, vocab_size)

        # Flatten batch and sequence dimensions for the token-level loss.
        rearranged_logits = rearrange(logits, 'bsz seq_len vocab_size -> (bsz seq_len) vocab_size')
        rearranged_loss_mask = rearrange(loss_mask, 'bsz seq_len -> (bsz seq_len)')
        rearranged_labels = rearrange(labels, 'bsz seq_len -> (bsz seq_len)')

        # Per-token cross entropy. No ignore_index is set: token id 0 is a
        # regular vocabulary entry, and prompt/padding positions are removed
        # by the loss mask below.
        sum_loss = F.cross_entropy(rearranged_logits, rearranged_labels, reduction='none')

        # Average over supervised positions only. The clamp avoids 0/0 when
        # truncation leaves a batch without any response token.
        supervised_tokens = torch.sum(rearranged_loss_mask).clamp(min=1)
        loss = torch.sum(sum_loss * rearranged_loss_mask) / supervised_tokens

        # Backward pass.
        loss.backward()

        # Clip the global gradient norm to max_grad_norm to guard against exploding gradients.
        total_norm = nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)

        # Optimizer update of the trainable (LoRA) parameters using the clipped gradients.
        optimizer.step()

        # Accumulate the scalar loss and the step count.
        total_loss += loss.item()
        total_step += 1

        # Report the running average loss and the gradient norm every logging_steps steps.
        if total_step % logging_steps == 0:
            avg_loss = total_loss / total_step
            print(f"Step: {step+1}/{len(dataloader)}, Loss: {avg_loss:.4f}, Grad Norm: {total_norm:.4f}", flush=True)

    # Report the running average loss at the end of every epoch.
    print(f"Epoch {epoch+1} finished, Average Loss: {total_loss/total_step:.4f}", flush=True)

## 2.5 推理

In [37]:
tokenizer.decode(ds[0]['input_ids'])

'Human: Give three tips for staying healthy.\nAssistant: 1. Eat a balanced and nutritious diet: Make sure your meals are inclusive of a variety of fruits and vegetables, lean protein, whole grains, and healthy fats. This helps to provide your body with the essential nutrients to function at its best and can help prevent chronic diseases.\n\n2. Engage in regular physical activity: Exercise is crucial for maintaining strong bones, muscles, and cardiovascular health. Aim for at least 150 minutes of moderate aerobic exercise or 75 minutes of vigorous exercise each week.\n\n3. Get enough sleep: Getting enough quality sleep is crucial for physical and mental well-being. It helps to regulate mood, improve cognitive function, and supports healthy growth and immune function. Aim for 7-9 hours of sleep each night.'

In [None]:
def inference(
    model,
    tokenizer,
    text: str,
    max_new_tokens: int = 200,
    do_sample: bool = True,
    top_k: int = 40,
    temperature: float = 0.3,
):
    """Generate a reply to ``text`` using the training prompt template.

    The prompt ``"Human: {text}\\nAssistant: "`` is tokenised without special
    tokens, moved to the notebook-level ``device`` and passed to
    ``model.generate`` together with the sampling settings.

    Returns:
        The first sequence of the generated batch, decoded with special tokens
        skipped.
    """
    encoded = tokenizer(
        f"Human: {text}\nAssistant: ",
        return_tensors='pt',
        add_special_tokens=False,
    ).to(device)

    generation_kwargs = dict(
        max_new_tokens=max_new_tokens,
        do_sample=do_sample,
        top_k=top_k,
        temperature=temperature,
    )
    generated = model.generate(**encoded, **generation_kwargs)

    decoded = tokenizer.batch_decode(generated, skip_special_tokens=True)
    return decoded[0]

In [None]:
for test_text in [
    'Give three tips for staying healthy.',
    'What are the three primary colors?',
    'Describe the structure of an atom.',
]:
    print('=' * 80)
    print(inference(model, tokenizer, test_text))