# In [None]:
import math
import torch
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftMixedModel, LoraConfig
from datasets import load_dataset

# Device string that the wrapped model and every batch are moved to.
device = "mps"

# Load the tokenizer (slow implementation) and the base causal LM in float16.
# NOTE(review): device_map="auto" asks the loader to place the weights itself,
# yet the wrapped model is moved again with .to(device) further down —
# confirm the two placements do not conflict on the target machine.
base_model_name = "Qwen/Qwen2-7B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(base_model_name, use_fast=False)
model = AutoModelForCausalLM.from_pretrained(base_model_name, torch_dtype=torch.float16, device_map="auto")

# LoRA configuration; the same config object is used for both adapters.
rank = 4
peft_config = LoraConfig(
    inference_mode=False,
    r=rank,
    lora_alpha=32,
    lora_dropout=0.1,
    task_type="CAUSAL_LM"
)

# Wrap the base model in a PeftMixedModel and register two adapters named
# "0" and "1".
mixed_model = PeftMixedModel(model, peft_config, adapter_name="0")
mixed_model.add_adapter("1", peft_config)

# Freeze the base model: every parameter whose name does not contain "lora"
# gets requires_grad=False, leaving only LoRA weights eligible for training.
for n, p in mixed_model.named_parameters():
    if "lora" not in n:
        p.requires_grad = False

# Switch to training mode and move the wrapped model to the target device.
mixed_model.train()
mixed_model.to(device)

# Load the two task datasets (example: only the first 1% of each train split).
task1_dataset = load_dataset("fzkuji/cMedQA2", name="deduplicate_all", split="train[:1%]")
task2_dataset = load_dataset("fzkuji/HealthCareMagic-100k", split="train[:1%]")

# Preprocess according to the fields of your dataset.
def preprocess(examples):
    """Tokenize a batch of question/answer pairs for causal-LM fine-tuning.

    Args:
        examples: Column-oriented batch as passed by
            ``Dataset.map(..., batched=True)``. It is assumed to contain
            "question" and "answer" columns of strings; adjust the keys to
            the schema of the dataset actually used.

    Returns:
        The tokenizer output (lists of token ids, one list per example) with
        an extra "labels" entry that mirrors "input_ids".

    Sequences are deliberately left unpadded. Padding here would stretch
    every example to the longest one in the whole ``map`` batch and would
    put pad tokens into the labels; padding belongs at batch-collation time,
    where padded label positions can be set to the ignore index.
    """
    inputs = ["Q: " + q + "\nA: " + a for q, a in zip(examples["question"], examples["answer"])]
    # No return_tensors: ``datasets`` stores plain lists, so ragged
    # per-example lists are the natural output.
    tokenized = tokenizer(inputs, truncation=True)
    # Independent copies so labels never alias input_ids.
    tokenized["labels"] = [list(ids) for ids in tokenized["input_ids"]]
    return tokenized

# Tokenize both datasets. The original (raw text) columns are dropped so that
# each row holds only tokenizer outputs and labels; string columns cannot be
# turned into tensors or moved to a device later on.
task1_dataset = task1_dataset.map(
    preprocess,
    batched=True,
    remove_columns=task1_dataset.column_names,
)
task2_dataset = task2_dataset.map(
    preprocess,
    batched=True,
    remove_columns=task2_dataset.column_names,
)

def collate_fn(batch, pad_token_id=0, label_pad_id=-100):
    """Collate tokenized examples into padded ``(batch, seq)`` long tensors.

    Args:
        batch: Non-empty list of per-example dicts. The values under
            "input_ids", "attention_mask" and "labels" may be Python lists or
            tensors, optionally with a leading dimension of size 1. Any other
            key (for example raw text columns) is ignored.
        pad_token_id: Fill value for padded "input_ids" positions. The exact
            id is irrelevant as long as "attention_mask" marks it as padding.
        label_pad_id: Fill value for padded "labels" positions; -100 is the
            index ignored by PyTorch's cross-entropy loss.

    Returns:
        Dict with the recognised keys present in the first example, each a
        ``torch.long`` tensor of shape ``(len(batch), longest_sequence)``.

    Raises:
        ValueError: If ``batch`` is empty.
    """
    from torch.nn.utils.rnn import pad_sequence

    if not batch:
        raise ValueError("collate_fn received an empty batch")

    pad_values = {
        "input_ids": pad_token_id,
        "attention_mask": 0,
        "labels": label_pad_id,
    }

    out = {}
    for key, pad_value in pad_values.items():
        if key not in batch[0]:
            continue
        # as_tensor accepts lists and tensors alike; reshape(-1) drops a
        # leading batch dimension of size 1 if the example carries one.
        rows = [
            torch.as_tensor(example[key], dtype=torch.long).reshape(-1)
            for example in batch
        ]
        out[key] = pad_sequence(rows, batch_first=True, padding_value=pad_value)

    # Examples that were already padded upstream carry pad token ids in their
    # labels; blank out every position the attention mask marks as padding so
    # the loss never scores pad tokens.
    labels = out.get("labels")
    mask = out.get("attention_mask")
    if labels is not None and mask is not None and labels.shape == mask.shape:
        out["labels"] = labels.masked_fill(mask == 0, label_pad_id)

    return out

def _is_adapter_param(param_name, adapter_name):
    """Return True if ``param_name`` is a LoRA weight of ``adapter_name``.

    PEFT keeps per-adapter weights in containers keyed by the adapter name,
    so parameter names look like ``...q_proj.lora_A.0.weight``: a ``lora_*``
    path segment immediately followed by the adapter name. Comparing whole
    segments avoids false hits such as ``layers.0`` for adapter "0".
    """
    parts = param_name.split(".")
    return any(
        part.startswith("lora_") and following == adapter_name
        for part, following in zip(parts, parts[1:])
    )


def _train_only_adapter(model, adapter_name):
    """Make the LoRA weights of ``adapter_name`` the only trainable parameters."""
    for name, param in model.named_parameters():
        param.requires_grad = _is_adapter_param(name, adapter_name)


def _next_batch(loader, iterator):
    """Return ``(batch, iterator)``, restarting ``loader`` once it is exhausted."""
    try:
        return next(iterator), iterator
    except StopIteration:
        iterator = iter(loader)
        return next(iterator), iterator


batch_size = 1
task1_loader = DataLoader(task1_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
task2_loader = DataLoader(task2_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)

# Training hyper-parameters
train_steps = 1000
gradient_accumulation_steps = 4
lr_start = 1e-4
lr_end = 1e-5
warmup_steps = 100
initial_lr = 0.0  # the learning rate at step 0 is zero

# Register every LoRA parameter with the optimizer, whatever its current
# requires_grad flag. The flags are toggled per task inside the loop (and an
# adapter that is inactive at this point may well be frozen), so filtering on
# requires_grad here could silently leave one adapter out of optimisation.
# Parameters that receive no gradient in a given update are skipped by AdamW.
lora_parameters = [p for n, p in mixed_model.named_parameters() if "lora" in n]
optimizer = torch.optim.AdamW(lora_parameters, lr=initial_lr)

task1_iter = iter(task1_loader)
task2_iter = iter(task2_loader)

global_step = 0
accumulation_count = 0

mixed_model.train()

# Adapter trained by each task: task 1 -> adapter "0", task 2 -> adapter "1".
task_adapters = {1: "0", 2: "1"}

for step in range(train_steps):
    # Alternate between the tasks: even steps draw from task 1, odd from task 2.
    if step % 2 == 0:
        batch, task1_iter = _next_batch(task1_loader, task1_iter)
        current_task = 1
    else:
        batch, task2_iter = _next_batch(task2_loader, task2_iter)
        current_task = 2

    batch = {k: v.to(device) for k, v in batch.items()}

    # Learning-rate schedule, evaluated per micro-step.
    if step < warmup_steps:
        # Warm-up: rise linearly from 0 to lr_start.
        current_lr = (step / warmup_steps) * lr_start
    else:
        # Decay from lr_start towards lr_end along a quarter cosine:
        # cos(pi/2 * t) falls from 1 to 0 as t goes from 0 to 1.
        t = (step - warmup_steps) / (train_steps - warmup_steps)
        current_lr = lr_end + (lr_start - lr_end) * math.cos(t * math.pi / 2)

    for param_group in optimizer.param_groups:
        param_group['lr'] = current_lr

    adapter = task_adapters[current_task]

    # set_adapter may itself change requires_grad on adapter weights, so the
    # trainable set is re-applied after every call.

    # Pass 1: base model + this task's adapter only.
    mixed_model.set_adapter(adapter)
    _train_only_adapter(mixed_model, adapter)
    loss_single = mixed_model(**batch).loss

    # Pass 2: base model + both adapters. The other task's adapter takes part
    # in the forward pass but stays frozen; only this task's adapter trains.
    mixed_model.set_adapter(["0", "1"])
    _train_only_adapter(mixed_model, adapter)
    loss_combined = mixed_model(**batch).loss

    loss = (loss_single + loss_combined) / 2.0

    # Gradient accumulation: scale so the summed gradients form an average.
    loss = loss / gradient_accumulation_steps
    loss.backward()
    accumulation_count += 1

    if accumulation_count % gradient_accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()
        accumulation_count = 0
        global_step += 1
        # The reported loss is that of the last micro-batch only (un-scaled).
        print(f"Global step: {global_step}, loss: {loss.item() * gradient_accumulation_steps:.4f}, lr: {current_lr:.6e}")

# Apply gradients left over when train_steps is not a multiple of
# gradient_accumulation_steps (a no-op with the values above).
if accumulation_count:
    optimizer.step()
    optimizer.zero_grad()
    accumulation_count = 0
    global_step += 1

# Validation / testing goes here once training is finished.
# For evaluation the adapter selection can be fixed, e.g. ["0", "1"] or a
# single "0" / "1", to generate answers and then compute BLEU or accuracy.
print("Training completed.")
