In [None]:
# 安装和导入必要的库
!pip install  unsloth

In [None]:
# Unsloth should be imported before transformers / trl so its patches take effect.
from unsloth import FastLanguageModel, is_bfloat16_supported

import torch
from datasets import load_dataset
from transformers import AutoTokenizer, TrainingArguments
from trl import SFTTrainer


In [None]:
# 1. KL-divergence helper (reverse KL only)
def compute_rkl(logits_student, logits_teacher, target_labels, padding_id, reduction="sum", temp=1.0):
    """Reverse KL divergence KL(student || teacher) between two sets of logits.

    Both logit tensors are divided by ``temp`` and normalized with a log-softmax
    over the last (vocabulary) dimension. The per-position divergence
    ``sum_v q_v * (log q_v - log p_v)`` (q = student, p = teacher) is then
    masked and reduced.

    Args:
        logits_student: student logits; the last dim is the vocabulary.
            Gradients flow through this tensor.
        logits_teacher: teacher logits with the same shape as ``logits_student``.
            Treated as a constant (no gradient).
        target_labels: tensor shaped like the logits without the last dim, or
            None. Positions equal to ``padding_id`` are excluded.
        padding_id: label value marking ignored positions (e.g. -100), or None.
        reduction: "sum" -> scalar sum over unmasked positions;
            "mean" -> that sum divided by the number of unmasked positions
            (0 when there are none; the plain mean when no mask is given);
            any other value -> the per-position divergence tensor, with masked
            positions set to 0.
        temp: softmax temperature; must be > 0.

    Returns:
        A scalar tensor, or the per-position tensor for an unrecognized reduction.

    Raises:
        ValueError: if ``temp`` is not positive.
    """
    if temp <= 0:
        raise ValueError(f"temp must be positive, got {temp}")

    # Temperature scaling softens both distributions.
    logits_student = logits_student / temp
    logits_teacher = logits_teacher / temp

    # One log-softmax for the student. Its probabilities are the exponent:
    # cheaper than a second softmax over a large vocabulary, and it keeps
    # q and log q exactly consistent.
    log_probs_student = torch.log_softmax(logits_student, dim=-1, dtype=torch.float32)
    probs_student = log_probs_student.exp()

    # The teacher distribution is a fixed target: no gradient flows through it.
    with torch.no_grad():
        log_probs_teacher = torch.log_softmax(logits_teacher, dim=-1, dtype=torch.float32)

    # q * (log q - log p), summed over the vocabulary dimension.
    kl_divergence = (probs_student * (log_probs_student - log_probs_teacher)).sum(dim=-1)

    has_mask = target_labels is not None and padding_id is not None
    if has_mask:
        pad_mask = target_labels == padding_id
        kl_divergence = kl_divergence.masked_fill(pad_mask, 0.0)

    if reduction == "sum":
        return kl_divergence.sum()
    if reduction == "mean":
        if has_mask:
            num_tokens = (~pad_mask).sum()
            # clamp avoids 0/0. With no valid tokens the masked sum is already 0,
            # so the result is 0 and remains attached to the autograd graph.
            return kl_divergence.sum() / num_tokens.clamp(min=1)
        return kl_divergence.mean()
    # Unrecognized reduction: per-position values.
    return kl_divergence


In [None]:
# 2. KDTrainer (knowledge-distillation trainer)
class KDTrainer(SFTTrainer):
    """SFTTrainer that distills a frozen teacher into the trained (student) model.

    The loss is the reverse KL divergence ``KL(student || teacher)`` between the
    temperature-scaled next-token distributions (see ``compute_rkl``). It is
    optionally mixed with the student's own cross-entropy on the hard labels:

        use_ce_loss=True : loss = kl_loss_weight * KL + (1 - kl_loss_weight) * CE
        use_ce_loss=False: loss = KL

    Extra keyword arguments (all other arguments go to SFTTrainer):
        teacher_model: model whose logits are the target distribution. It is put
            in eval mode and only run under ``torch.no_grad()``.
        use_ce_loss: whether to mix in the student's cross-entropy loss.
        kl_loss_weight: weight of the KL term when ``use_ce_loss`` is True.
        temperature: softmax temperature passed to ``compute_rkl``.
        kl_reduction: reduction passed to ``compute_rkl``. The default "mean"
            averages over non-ignored tokens, which puts KL on the same per-token
            scale as a token-averaged CE loss. This makes ``kl_loss_weight`` a
            real mixing weight; "sum" would scale KL by the number of tokens.
    """

    def __init__(self, *args, teacher_model=None, use_ce_loss=True, kl_loss_weight=0.5,
                 temperature=2.0, kl_reduction="mean", **kwargs):
        super().__init__(*args, **kwargs)
        self.teacher_model = teacher_model
        self.use_ce_loss = use_ce_loss  # also train on the student's hard-label CE loss
        self.kl_loss_weight = kl_loss_weight  # weight of the KL term
        self.temperature = temperature
        self.kl_reduction = kl_reduction
        if self.teacher_model is not None:
            self.teacher_model.eval()  # the teacher is never trained

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Return the distillation loss for one batch (plus student outputs if requested).

        ``model`` is the student. ``num_items_in_batch`` is accepted for
        Trainer-API compatibility but is not used.

        NOTE(review): recent transformers versions skip dividing the loss by
        gradient_accumulation_steps when ``num_items_in_batch`` is supplied,
        expecting compute_loss to normalize by it. Confirm the effective loss
        scale for the installed version.
        NOTE(review): recent Unsloth versions may not materialize ``logits`` when
        ``labels`` are passed unless the env var UNSLOTH_RETURN_LOGITS=1 is set
        before training. Confirm for the installed version.
        """
        if self.teacher_model is None:
            raise ValueError("KDTrainer requires a teacher_model to compute the distillation loss.")

        # Student forward. `.loss` is its hard-label CE when the batch carries labels.
        outputs_student = model(**inputs)
        loss_ce_student = outputs_student.loss
        logits_student = outputs_student.logits

        # Teacher forward without labels: only its logits are needed, so skip
        # computing a teacher CE loss over the full vocabulary.
        teacher_inputs = {key: value for key, value in inputs.items() if key != "labels"}
        with torch.no_grad():
            logits_teacher = self.teacher_model(**teacher_inputs).logits

        # Vocabulary sizes may differ (e.g. padded embedding matrices). Keep the
        # shared prefix on both sides. This is only meaningful when the two models
        # use the same tokenizer.
        vocab_size = min(logits_student.shape[-1], logits_teacher.shape[-1])
        logits_student = logits_student[..., :vocab_size]
        logits_teacher = logits_teacher[..., :vocab_size]

        labels = inputs.get("labels")
        if labels is not None:
            # Causal LM: logits at position t predict token t+1. Align them with
            # the next-token labels, as the HF causal-LM CE loss does.
            # Assumes (batch, seq[, vocab]) layout — TODO confirm for custom collators.
            logits_student = logits_student[:, :-1, :]
            logits_teacher = logits_teacher[:, :-1, :]
            labels = labels[:, 1:]

        # Ignored label positions use -100 unless the trainer defines its own pad id.
        kl_loss = compute_rkl(
            logits_student,
            logits_teacher,
            target_labels=labels,
            padding_id=getattr(self, "label_pad_token_id", -100),
            reduction=self.kl_reduction,
            temp=self.temperature,
        )

        if self.use_ce_loss:
            if loss_ce_student is None:
                raise ValueError("use_ce_loss=True requires 'labels' in the batch.")
            total_loss = self.kl_loss_weight * kl_loss + (1 - self.kl_loss_weight) * loss_ce_student
        else:
            total_loss = kl_loss

        return (total_loss, outputs_student) if return_outputs else total_loss


In [None]:
# 3. Configuration
# Models and output locations
teacher_model_path = "qwen_teacher_finetune"  # assumed location of the fine-tuned teacher
student_model_name = "unsloth/Qwen2.5-0.5B"   # a smaller student, e.g. 0.5B
# student_model_name = "unsloth/Qwen2.5-1.5B"  # or 1.5B
output_dir_distillation = "./results_qwen_student_distilled_rkl"
save_directory_student = "qwen_student_distilled_rkl_final"

# Dataset and prompt format
dataset_name = "yahma/alpaca-cleaned"
alpaca_prompt_template = """Below is an instruction that describes a task, paired with
an input that provides further context. Write a response that appropriately
completes the request.

### Instruction:
{}

### Input:
{}

### Response:
{}"""

# Training hyperparameters
max_seq_length = 1024  # adjustable, but must not exceed the model's supported length
load_in_4bit = True

# Precision/quantization depend on the backend. On Apple Silicon (MPS), 4-bit
# quantization via unsloth/bitsandbytes may be unsupported, so fall back to fp16.
if not torch.backends.mps.is_available():
    dtype = None  # let unsloth auto-detect the dtype on CUDA
    print("CUDA or CPU detected. Using auto dtype and 4-bit quantization if enabled.")
else:
    load_in_4bit = False
    dtype = torch.float16
    print("MPS detected. Disabling 4-bit quantization and using float16.")



In [None]:
# Distillation-specific settings
distill_use_ce_loss = True    # also use the student's own CE loss during distillation
distill_kl_loss_weight = 0.5  # KL weight when CE is used; CE receives 1 - this value

# Optimization schedule
distill_epochs = 3            # reduced epoch count for a quick demo
distill_batch_size, distill_grad_accum = 2, 8  # per-device batch size, gradient-accumulation steps
distill_lr = 5e-5             # intended distillation learning rate

In [None]:
# 4. Load the dataset and prepare the EOS token used for formatting
print("Loading and formatting dataset...")
dataset_full = load_dataset(dataset_name, split="train")

# Optional subsampling for a quick demo: e.g. 2000; None uses the whole split.
num_train_samples = None
dataset = dataset_full if num_train_samples is None else dataset_full.select(range(num_train_samples))

# FastLanguageModel does not appear to expose a standalone `get_tokenizer`
# helper, so load the student's tokenizer directly. It is only needed here for
# its EOS token.
tokenizer_for_formatting = AutoTokenizer.from_pretrained(student_model_name)
EOS_TOKEN = tokenizer_for_formatting.eos_token
if EOS_TOKEN is None:
    # Guarantee an EOS token so every formatted example is terminated.
    tokenizer_for_formatting.eos_token = "<|endoftext|>"
    EOS_TOKEN = tokenizer_for_formatting.eos_token

def formatting_prompts_func(examples):
    """Render a batch of Alpaca rows into training strings.

    ``examples`` is a batched column dict with "instruction", "input" and
    "output" lists. Each row fills ``alpaca_prompt_template`` (in that order)
    and gets ``EOS_TOKEN`` appended. Returns ``{"text": [...]}`` with one
    string per row.
    """
    rows = zip(examples["instruction"], examples["input"], examples["output"])
    return {"text": [alpaca_prompt_template.format(*row) + EOS_TOKEN for row in rows]}

# batched=True hands formatting_prompts_func whole column lists at a time.
dataset = dataset.map(formatting_prompts_func, batched=True, num_proc=4)
print(f"Dataset formatted. Number of examples: {len(dataset)}")


In [None]:
# 5. Load the (already fine-tuned) teacher model
print(f"Loading fine-tuned teacher model from {teacher_model_path}...")
teacher_model, teacher_tokenizer = FastLanguageModel.from_pretrained(
    model_name=teacher_model_path,
    load_in_4bit=load_in_4bit,
    dtype=dtype,
    max_seq_length=max_seq_length,
)
# The teacher is only used for forward passes and is never trained, so put it
# in Unsloth's inference mode.
FastLanguageModel.for_inference(teacher_model)
print("Teacher model loaded.")


In [None]:
# 6. Load the student model and attach LoRA adapters
print(f"Loading student model ({student_model_name}) and configuring LoRA...")
student_model, student_tokenizer = FastLanguageModel.from_pretrained(
    model_name=student_model_name,
    max_seq_length=max_seq_length,
    dtype=dtype,
    load_in_4bit=load_in_4bit,
)
# LoRA adapters on the attention and MLP projection modules.
student_model = FastLanguageModel.get_peft_model(
    student_model,
    r=16,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
    lora_alpha=16,
    lora_dropout=0,
    bias="none",
    use_gradient_checkpointing="unsloth",
    random_state=3407,
)
print("Student model loaded and LoRA configured.")
student_model.print_trainable_parameters()

# Fill in a missing EOS token on either tokenizer with the one used for formatting.
if student_tokenizer.eos_token is None:
    student_tokenizer.eos_token = EOS_TOKEN
if teacher_tokenizer.eos_token is None:
    teacher_tokenizer.eos_token = EOS_TOKEN

# Logit distillation compares teacher and student distributions token id by
# token id, so both tokenizers must agree. Warn loudly if they do not.
if teacher_tokenizer.get_vocab() != student_tokenizer.get_vocab():
    print(
        "WARNING: teacher and student tokenizers have different vocabularies; "
        "logit-level distillation assumes identical token ids."
    )



In [None]:
# 7. Configure the distillation TrainingArguments
print("Configuring TrainingArguments for distillation...")
_on_mps = torch.backends.mps.is_available()
distill_training_args = TrainingArguments(
    output_dir=output_dir_distillation,
    # max_steps=1,       # alternatively, train for a fixed number of steps
    num_train_epochs=distill_epochs,
    per_device_train_batch_size=distill_batch_size,
    gradient_accumulation_steps=distill_grad_accum,
    warmup_ratio=0.1,
    # Use the configured distillation LR (previously a hard-coded 2e-4 silently
    # ignored distill_lr).
    learning_rate=distill_lr,
    logging_steps=10,
    save_strategy="epoch",
    save_total_limit=2,
    # Mixed precision: bf16 where supported, otherwise fp16; neither on MPS.
    fp16=not is_bfloat16_supported() and not _on_mps,
    bf16=is_bfloat16_supported() and not _on_mps,
    # adamw_8bit relies on bitsandbytes, which is typically unavailable on MPS.
    # This is the same reason 4-bit loading is disabled there, so fall back to
    # plain torch AdamW.
    optim="adamw_torch" if _on_mps else "adamw_8bit",
    lr_scheduler_type="linear",
    seed=3407,
    report_to="none",
)


In [None]:
# 8. Build the KD trainer and run distillation
# NOTE(review): newer trl releases move `tokenizer` -> `processing_class` and the
# dataset/packing options into SFTConfig; confirm against the installed version.
print("Initializing KDTrainer...")
distill_trainer = KDTrainer(
    # Models: the student is trained, the teacher only supplies target logits.
    model=student_model,
    teacher_model=teacher_model,
    tokenizer=student_tokenizer,  # the student's tokenizer drives data processing
    # Data pipeline
    train_dataset=dataset,
    dataset_text_field="text",
    dataset_num_proc=2,
    max_seq_length=max_seq_length,
    packing=False,
    # Optimization and distillation settings
    args=distill_training_args,
    use_ce_loss=distill_use_ce_loss,
    kl_loss_weight=distill_kl_loss_weight,
)

print("Starting distillation training with Reverse KL Divergence...")
distill_trainer.train()  # starts fresh; pass resume_from_checkpoint=True to resume
print("Distillation training completed.")


In [None]:
# 9. Persist the distilled student (model and tokenizer) into one directory
print(f"Saving distilled student model to {save_directory_student}...")
# student_model was wrapped by get_peft_model, so save_pretrained presumably
# writes the LoRA adapter rather than merged full weights. Confirm this if a
# merged checkpoint is required.
for artifact in (student_model, student_tokenizer):
    artifact.save_pretrained(save_directory_student)
print("Distilled student model saved.")

print("\nKnowledge distillation process finished.")