In [None]:
# Install the third-party packages imported by the cells below
# (transformers, peft, torch, datasets).
!pip install transformers peft torch datasets

In [None]:
import torch
import torch.nn as nn
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForSequenceClassification, AutoTokenizer

class SharedTaskLora(nn.Module):
    """Base model wrapped with one shared LoRA adapter and two task adapters.

    The shared adapter is registered under PEFT's default adapter name
    (``"default"``); the task-specific adapters are registered as ``"task1"``
    and ``"task2"``.  All three inject LoRA into the modules matched by
    ``["query", "value"]``.  On every forward pass the shared adapter and the
    adapter of the requested task are activated *together*, so the shared
    weights contribute to both tasks.

    Args:
        base_model: Model to adapt (passed to ``get_peft_model``).
        shared_r: LoRA rank of the shared adapter.
        task_r: LoRA rank of each task-specific adapter.

    NOTE(review): no ``task_type`` is set on the LoRA configs, so PEFT is not
    told to keep a classification head trainable -- confirm whether the head
    should be trained as well.
    """

    SHARED_ADAPTER = "default"
    TASK_ADAPTERS = ("task1", "task2")

    def __init__(self, base_model, shared_r=8, task_r=4):
        super().__init__()
        self.base_model = base_model

        # Shared parameters use a higher rank / alpha than the task-specific ones.
        self.shared_lora = self._make_config(shared_r, lora_alpha=16)
        self.task_lora1 = self._make_config(task_r, lora_alpha=8)
        self.task_lora2 = self._make_config(task_r, lora_alpha=8)

        # Wrap the base model with the shared adapter first ...
        self.shared_model = get_peft_model(
            base_model, self.shared_lora, adapter_name=self.SHARED_ADAPTER
        )
        # ... then register the task-specific adapters on the same wrapper.
        for name, config in zip(
            self.TASK_ADAPTERS, (self.task_lora1, self.task_lora2)
        ):
            self.shared_model.add_adapter(name, config)

    @staticmethod
    def _make_config(r, lora_alpha):
        """Return a LoRA config that differs from its siblings only in rank/alpha."""
        return LoraConfig(
            r=r,
            lora_alpha=lora_alpha,
            target_modules=["query", "value"],
            lora_dropout=0.1,
        )

    @staticmethod
    def _resolve_task_id(task_id):
        """Collapse a per-example batch of task ids into a single id.

        Scalars are returned unchanged.  Tensors, lists and tuples must hold
        exactly one distinct value because only one task adapter can be active
        for a batch.

        Raises:
            ValueError: If the batch is empty or mixes several task ids.
        """
        if torch.is_tensor(task_id):
            values = task_id.reshape(-1).tolist()
        elif isinstance(task_id, (list, tuple)):
            values = list(task_id)
        else:
            return task_id

        distinct = set(values)
        if len(distinct) != 1:
            raise ValueError(
                "a batch must contain exactly one task id, got "
                f"{sorted(distinct)}"
            )
        return distinct.pop()

    def forward(self, input_ids, attention_mask, task_id, labels=None, **kwargs):
        """Run the model with the shared adapter plus one task adapter active.

        Args:
            input_ids: Token ids forwarded to the wrapped model.
            attention_mask: Attention mask forwarded to the wrapped model.
            task_id: ``0`` selects ``"task1"``; any other value selects
                ``"task2"``.  May be a scalar or a uniform batch of ids.
            labels: Optional targets; forwarded only when given so the wrapped
                model can compute ``outputs.loss``.
            **kwargs: Extra model inputs (e.g. ``token_type_ids``), passed
                through unchanged.

        Returns:
            Whatever the wrapped PEFT model returns.
        """
        if self._resolve_task_id(task_id) == 0:
            task_adapter = self.TASK_ADAPTERS[0]
        else:
            task_adapter = self.TASK_ADAPTERS[1]

        # PeftModel.set_adapter takes a single adapter name, which would
        # deactivate the shared adapter; the tuner-level set_adapter accepts a
        # list, so the shared and task adapters can be active at once.
        self.shared_model.base_model.set_adapter(
            [self.SHARED_ADAPTER, task_adapter]
        )

        if labels is not None:
            kwargs["labels"] = labels
        return self.shared_model(
            input_ids=input_ids, attention_mask=attention_mask, **kwargs
        )

# Initialise the tokenizer and the wrapped model from the same checkpoint.
# The tokenizer is bound at module scope because later cells call
# `tokenizer` as a global.
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
base_model = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased")
model = SharedTaskLora(base_model)

In [None]:
from torch.optim import AdamW
from datasets import load_dataset

# Simulated data for the two tasks: 1000 training examples each.
# Shuffle (with a fixed seed for reproducibility) before selecting: taking
# the first rows of a split that is stored grouped by label -- as the IMDB
# train split appears to be -- would yield examples of a single class.
task1_dataset = (
    load_dataset("imdb", split="train").shuffle(seed=42).select(range(1000))
)
# NOTE(review): confirm that "hate_speech" is a valid dataset id and that it
# exposes "text" and "label" columns, which the training code reads.
task2_dataset = (
    load_dataset("hate_speech", split="train").shuffle(seed=42).select(range(1000))
)

# Optimizer groups: low learning rate for the shared adapter, higher learning
# rate for the task-specific adapters.
def _lora_param_groups(peft_model, adapter_lrs):
    """Build one optimizer parameter group per LoRA adapter.

    Parameters are assigned to an adapter by name; this relies on PEFT naming
    LoRA weights ``...lora_A.<adapter>.weight`` / ``...lora_B.<adapter>.weight``.
    Frozen base-model weights match no adapter and are therefore left out.

    Parameters are intentionally NOT filtered on ``requires_grad``: adapters
    that are inactive right now are frozen and only become trainable once they
    are activated, and the optimizer skips parameters without gradients.

    Args:
        peft_model: The PEFT-wrapped model holding the adapters.
        adapter_lrs: Mapping of adapter name to learning rate.

    Returns:
        A list of ``{"params": [...], "lr": ...}`` dicts, one per adapter.

    Raises:
        ValueError: If no LoRA parameter is found for a requested adapter.
    """
    groups = []
    for adapter_name, lr in adapter_lrs.items():
        marker = f".{adapter_name}."
        params = [
            param
            for name, param in peft_model.named_parameters()
            if "lora_" in name and marker in name
        ]
        if not params:
            raise ValueError(
                f"no LoRA parameters found for adapter {adapter_name!r}"
            )
        groups.append({"params": params, "lr": lr})
    return groups


# "default" is the adapter name PEFT gives to the first (shared) adapter.
# NOTE(review): the classification head is not part of any group; add it here
# if it is meant to be trained.
optimizer = AdamW(
    _lora_param_groups(
        model.shared_model,
        {"default": 1e-5, "task1": 1e-4, "task2": 1e-4},
    )
)

# Alternating multi-task training
def _interleave_batches(task1_data, task2_data, batch_size=8):
    """Cut both datasets into single-task batches and shuffle the batch order.

    Each dataset must support ``len()`` and slicing that returns a mapping of
    column name to list (as Hugging Face ``Dataset`` objects do) with ``"text"``
    and ``"label"`` columns.  A batch never mixes tasks because exactly one
    task adapter is active per forward pass.

    Returns:
        A list of ``{"text": [...], "label": [...], "task_id": int}`` dicts,
        where ``task_id`` is 0 for ``task1_data`` and 1 for ``task2_data``.
    """
    import random

    batches = []
    for task_id, dataset in enumerate((task1_data, task2_data)):
        for start in range(0, len(dataset), batch_size):
            rows = dataset[start:start + batch_size]
            batches.append(
                {"text": rows["text"], "label": rows["label"], "task_id": task_id}
            )
    random.shuffle(batches)
    return batches


def _adapter_update_distance(peft_model, task_adapter, shared_adapter="default"):
    """Sum of L2 distances between task and shared LoRA weight updates.

    The two adapters may have different ranks, so their raw parameters cannot
    be subtracted.  Instead the effective update ``B @ A * scaling`` of each
    adapter is compared layer by layer; both updates have the shape of the
    adapted weight.  Relies on PEFT LoRA layers exposing ``lora_A``,
    ``lora_B`` and ``scaling`` keyed by adapter name.

    Returns:
        A scalar tensor, or ``0.0`` when no layer carries both adapters.
    """
    total = 0.0
    for module in peft_model.modules():
        lora_a = getattr(module, "lora_A", None)
        lora_b = getattr(module, "lora_B", None)
        if lora_a is None or lora_b is None:
            continue
        if task_adapter not in lora_a or shared_adapter not in lora_a:
            continue
        task_update = (
            lora_b[task_adapter].weight @ lora_a[task_adapter].weight
        ) * module.scaling[task_adapter]
        shared_update = (
            lora_b[shared_adapter].weight @ lora_a[shared_adapter].weight
        ) * module.scaling[shared_adapter]
        total = total + torch.norm(task_update - shared_update)
    return total


def alternate_train(model, task1_data, task2_data, epochs=3, batch_size=8,
                    reg_weight=0.1):
    """Train on both tasks by alternating single-task batches in random order.

    Reads the module-level ``tokenizer`` and ``optimizer``.

    Args:
        model: Module whose ``forward(input_ids, attention_mask, task_id)``
            returns an object with ``logits`` and which exposes the PEFT model
            as ``model.shared_model``.
        task1_data: Dataset for task 0 (needs ``"text"`` and ``"label"``).
        task2_data: Dataset for task 1 (needs ``"text"`` and ``"label"``).
        epochs: Number of passes over both datasets.
        batch_size: Number of examples per batch.
        reg_weight: Weight (lambda) of the L2 term that keeps the task
            adapter's update close to the shared adapter's update.
    """
    # Pretrained models are loaded in eval mode; enable dropout for training.
    model.train()
    device = next(model.parameters()).device

    for _ in range(epochs):
        # Re-batch and re-shuffle every epoch.
        for batch in _interleave_batches(task1_data, task2_data, batch_size=batch_size):
            encoded = tokenizer(
                batch["text"], padding=True, truncation=True, return_tensors="pt"
            )
            labels = torch.as_tensor(batch["label"], dtype=torch.long, device=device)
            task_id = batch["task_id"]

            # Forward pass; the model picks the adapter from task_id.  Only
            # the inputs the forward signature is known to accept are passed.
            outputs = model(
                input_ids=encoded["input_ids"].to(device),
                attention_mask=encoded["attention_mask"].to(device),
                task_id=task_id,
            )
            # The loss is computed here because labels are not passed to the model.
            loss = nn.functional.cross_entropy(outputs.logits, labels)

            # L2 regularisation tying the active task adapter to the shared one.
            task_adapter = "task1" if task_id == 0 else "task2"
            loss = loss + reg_weight * _adapter_update_distance(
                model.shared_model, task_adapter
            )

            # Backward pass and parameter update.
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

In [None]:
# Illustrative excerpt of the L2 regularisation term used in the training loop.
# NOTE(review): task_param, shared_param and loss are not defined at notebook
# scope in the cells shown, so running this cell on its own is expected to
# raise NameError.
l2_loss = torch.norm(task_param - shared_param)  # compute the L2 distance
loss += 0.1 * l2_loss  # λ=0.1

In [None]:
# Preprocessing helper
def preprocess_function(examples, task_id):
    """Tokenize the ``"text"`` field and tag the result with its task id.

    Works both for a single example (``examples["text"]`` is a string) and for
    a batch (``examples["text"]`` is a list or tuple of strings, e.g. under
    ``Dataset.map(batched=True)``).  In the batched case the task id is
    repeated once per example so the new column has the same length as the
    others.

    Reads the module-level ``tokenizer``.

    Args:
        examples: Mapping with a ``"text"`` entry.
        task_id: Identifier of the task the examples belong to.

    Returns:
        The tokenizer output with an added ``"task_id"`` entry.
    """
    texts = examples["text"]
    tokenized = tokenizer(texts, padding="max_length", truncation=True)
    if isinstance(texts, (list, tuple)):
        tokenized["task_id"] = [task_id] * len(texts)
    else:
        tokenized["task_id"] = task_id
    return tokenized

# Tokenize both task datasets, tagging each example with its task id
# (0 for the first dataset, 1 for the second).
task1_data, task2_data = (
    raw_dataset.map(preprocess_function, fn_kwargs={"task_id": task_index})
    for task_index, raw_dataset in enumerate((task1_dataset, task2_dataset))
)

# Launch alternating multi-task training.
alternate_train(model, task1_data, task2_data, epochs=3)