# 开始 AdaLoRA 训练

In [1]:
# AutoDL官方学术资源加速
import subprocess
import os

result = subprocess.run('bash -c "source /etc/network_turbo && env | grep proxy"', shell=True, capture_output=True, text=True)
output = result.stdout
for line in output.splitlines():
    if '=' in line:
        var, value = line.split('=', 1)
        os.environ[var] = value

In [2]:
import sys
import os

# 添加项目根目录到Python路径
project_root = "/home/cuipeng/Gemma"
sys.path.append(project_root)

# 现在可以正常导入src下的模块
from src.core.model.model_initializer import initialize_model_and_tokenizer
from src.core.utils.model_utils import generate_response, apply_chat_template

In [3]:
import os
import json
import torch # type: ignore
from transformers import ( # type: ignore
    AutoModelForCausalLM, 
    AutoTokenizer,
    TrainingArguments,
    Trainer,
    DataCollatorForLanguageModeling
)
from peft import AdaLoraConfig, get_peft_model # type: ignore
from datasets import Dataset # type: ignore
from transformers import BitsAndBytesConfig # type: ignore # 导入 BitsAndBytesConfig

In [4]:
def load_dataset(file_path, tokenizer):
    """加载数据集并进行预处理"""
    with open(file_path, 'r', encoding='utf-8') as f:
        data = json.load(f)
    
    # 使用tokenizer处理文本
    def preprocess_function(examples):
        return tokenizer(
            examples['text'],
            truncation=True, 
            max_length=512,
            padding='max_length',
            return_tensors=None
        )
    
    # 创建数据集
    dataset = Dataset.from_list([{
        'text': item['text'] # 直接使用text字段，不需要拼接prompt和completion
    } for item in data])
    
    # 对数据集进行预处理
    tokenized_dataset = dataset.map(
        preprocess_function,
        remove_columns=['text'],
        desc="正在对数据集进行分词处理",
    )
    
    return tokenized_dataset

In [5]:
def create_peft_config():
    """创建AdaLoRA配置"""
    return AdaLoraConfig(
        r=64,
        lora_alpha=32,
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM",
        inference_mode=False,
        target_r=32,
        beta1=0.85,
        beta2=0.85,
        tinit=500, # 预热
        tfinal=2000, # 最终步数
        deltaT=10, # 步数间隔
    )

In [6]:
# 2. 主训练函数
def train():
    # 设置模型路径和缓存目录
    model_path = "google/gemma-2-9b"
    cache_dir = "/root/autodl-tmp/gemma"
    lora_path = None

    print("创建模型和tokenizer...")
    model, tokenizer = initialize_model_and_tokenizer(
        model_path=model_path,
        cache_dir=cache_dir,
        lora_path=lora_path,
        use_quantization=True
    )
    
    # 然后加载和预处理数据集
    print("开始加载数据集...")
    train_dataset = load_dataset("../data_processing/stage1/data_final/train.json", tokenizer)
    eval_dataset = load_dataset("../data_processing/stage1/data_final/valid.json", tokenizer)
    
    print("应用AdaLoRA配置...")
    peft_config = create_peft_config()
    model = get_peft_model(model, peft_config)
    model.print_trainable_parameters() # trainable params: 13,420,512 || all params: 9,255,126,664 || trainable%: 0.1450
    
    # 创建训练参数
    training_args = TrainingArguments(
        output_dir="../../../../../../root/autodl-tmp/models/stage1/checkpoints/gemma-base-zh", # 保存中间模型和日志的目录
        learning_rate=5e-5, # 学习率
        num_train_epochs=5, # 训练轮数
        per_device_train_batch_size=2, # 训练批次大小
        per_device_eval_batch_size=4, # 选8好像也不会报错，选4放心一点  
        gradient_accumulation_steps=4, # 梯度累积步数 # 所以每个epoch的step总数: 80000 / (2*4) = 100000
        warmup_steps=500, # 预热步数, 在训练开始时逐渐增加学习率，以防止初始阶段的不稳定
        # max_steps=1000, # 最大训练步数, 训练的最大步数
        logging_steps=100, # 日志记录步数, 每100步记录一次日志
        save_steps=500, # 保存模型步数, 每500步保存一次模型
        evaluation_strategy="steps", # 按步数评估, 有三个选项: "no", "steps", "epoch"
        eval_steps=500, # 评估步数
        fp16=True, # 使用16位浮点数
        optim="paged_adamw_32bit", # 优化器
        lr_scheduler_type="cosine", # 学习率调度器类型
        report_to="tensorboard", # 报告到tensorboard
        remove_unused_columns=False, # 添加这一行
    )
    
    # 创建数据整理器
    data_collator = DataCollatorForLanguageModeling(
        tokenizer=tokenizer, 
        mlm=False, # 表示我们在做因果语言建模(CLM)而不是掩码语言建模(MLM)
    )
    # 没有这个组件，模型将无法正确处理不同长度的序列
    # 它确保了批处理中的数据格式统一
    # 对于语言模型训练来说是必需的组件
    
    # 创建训练器
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        data_collator=data_collator,
    )
    
    print("开始训练...")
    trainer.train()
    
    print("保存模型...")
    trainer.save_model("../../../../../../root/autodl-tmp/models/stage1/gemma-base-zh-final") # 保存最终模型

In [7]:
if __name__ == "__main__":
    train()

创建模型和tokenizer...


Loading checkpoint shards:   0%|          | 0/8 [00:00<?, ?it/s]

开始加载数据集...


正在对数据集进行分词处理:   0%|          | 0/80000 [00:00<?, ? examples/s]

正在对数据集进行分词处理:   0%|          | 0/1000 [00:00<?, ? examples/s]

应用AdaLoRA配置...




trainable params: 13,420,512 || all params: 9,255,126,664 || trainable%: 0.1450
开始训练...




Step,Training Loss,Validation Loss
500,9.0964,1.909219
1000,5.9222,1.499785
1500,5.6385,1.467625
2000,5.7069,1.453999
2500,5.6322,1.445326
3000,5.5842,1.436898
3500,5.5981,1.431238
4000,5.5276,1.425206
4500,5.5441,1.421027
5000,5.4724,1.416371


The 'batch_size' argument of HybridCache is deprecated and will be removed in v4.49. Use the more precisely named 'max_batch_size' argument instead.
The 'batch_size' attribute of HybridCache is deprecated and will be removed in v4.49. Use the more precisely named 'self.max_batch_size' attribute instead.

Cannot access gated repo for url https://huggingface.co/google/gemma-2-9b/resolve/main/config.json.
Access to model google/gemma-2-9b is restricted. You must have access to it and be authenticated to access it. Please log in. - silently ignoring the lookup for the file config.json in google/gemma-2-9b.

Cannot access gated repo for url https://huggingface.co/google/gemma-2-9b/resolve/main/config.json.
Access to model google/gemma-2-9b is restricted. You must have access to it and be authenticated to access it. Please log in. - silently ignoring the lookup for the file config.json in google/gemma-2-9b.

Cannot access gated repo for url https://huggingface.co/google/gemma-2-9b/resolve/ma