In [None]:
# 安装和导入必要的库
!pip install unsloth
!pip install git+https://github.com/josejg/instruction_following_eval.git # 安装 IFEval，进行指令跟随评估
!pip install -U wandb # 确保 wandb 已安装

In [None]:
from unsloth import FastLanguageModel, is_bfloat16_supported
import torch
from datasets import load_dataset
from transformers import TrainingArguments
from trl import SFTTrainer
import wandb # 导入 wandb
import os    # 导入 os


In [None]:
# 0. Weights & Biases setup (mirrors the teacher-model script).
# Assumes a Kaggle environment where WANDB_API_KEY is stored as a Kaggle secret;
# the key is copied into the process environment so wandb can pick it up.
try:
    from kaggle_secrets import UserSecretsClient
    user_secrets = UserSecretsClient()
    key = user_secrets.get_secret("WANDB_API_KEY")
    os.environ["WANDB_API_KEY"] = key
except ImportError:
    # Only a missing kaggle_secrets module is handled here; any error raised
    # while fetching the secret itself still propagates.
    print("Kaggle secrets not found. Make sure WANDB_API_KEY is set in your environment if not on Kaggle.")
    # Outside Kaggle, export WANDB_API_KEY in the environment before starting the
    # notebook; avoid hard-coding the key in this file.

# W&B project that receives the run (may be the same project as the teacher run).
os.environ["WANDB_PROJECT"] = "Decoder_Knowledge_Distillation" # or any other project name

wandb.login()


In [None]:
# 1. KL-divergence helper (skewed reverse KL)
def compute_skewed_rkl(logits_student, logits_teacher, target_labels, padding_id,
                       reduction="sum", temp=1.0, skew_lambda=0.1):
    """Skewed reverse KL: KL(student || mix), mix = (1 - skew_lambda) * teacher + skew_lambda * student.

    Args:
        logits_student: Student logits; the last dimension is the vocabulary.
        logits_teacher: Teacher logits, same shape as ``logits_student``.
        target_labels: Optional tensor shaped like the logits without the vocab
            dimension. Positions equal to ``padding_id`` contribute zero.
        padding_id: Label value marking positions to ignore. Masking only happens
            when both ``target_labels`` and ``padding_id`` are not None.
        reduction: "sum" adds up all per-position values; "mean" divides the sum
            by the number of non-ignored positions (plain mean when no mask is
            given, 0.0 when every position is ignored); any other value returns
            the per-position tensor.
        temp: Temperature both sets of logits are divided by before the softmax.
        skew_lambda: Weight of the student distribution inside the mixture.

    Returns:
        A scalar tensor for "sum"/"mean", otherwise the per-position KL tensor.
    """
    student_scaled = logits_student / temp
    teacher_scaled = logits_teacher / temp

    # Student probabilities / log-probabilities, computed in float32.
    student_p = torch.softmax(student_scaled, dim=-1, dtype=torch.float32)
    student_logp = torch.log_softmax(student_scaled, dim=-1, dtype=torch.float32)

    # The teacher side is a fixed target: no gradient is tracked for it.
    with torch.no_grad():
        teacher_p = torch.softmax(teacher_scaled, dim=-1, dtype=torch.float32)

    # Mixture distribution; the small epsilon keeps log() away from zero.
    blend_p = (1 - skew_lambda) * teacher_p + skew_lambda * student_p
    blend_logp = torch.log(blend_p + 1e-10)

    # p_student * (log p_student - log p_mix), summed over the vocabulary.
    per_token = (student_p * (student_logp - blend_logp)).sum(dim=-1)

    has_mask = target_labels is not None and padding_id is not None
    if has_mask:
        per_token = per_token.masked_fill(target_labels == padding_id, 0.0)

    if reduction == "sum":
        return per_token.sum()
    if reduction != "mean":
        return per_token
    if not has_mask:
        return per_token.mean()

    valid_count = (target_labels != padding_id).sum()
    if valid_count > 0:
        return per_token.sum() / valid_count
    return torch.tensor(0.0).to(per_token.device)



In [None]:
class KDTrainer(SFTTrainer):
    """SFTTrainer variant that distils a frozen teacher into the student being trained.

    Every step runs the same batch through the student and the teacher and adds a
    skewed reverse-KL term (``compute_skewed_rkl``) to the student's own
    cross-entropy loss.

    Extra keyword arguments (everything else is forwarded to ``SFTTrainer``):
        teacher_model: Model providing the target distribution. It is put in eval
            mode here and is only ever called under ``torch.no_grad()``.
        use_ce_loss: If True the loss is
            ``kl_loss_weight * KL + (1 - kl_loss_weight) * CE``; otherwise KL only.
        kl_loss_weight: Weight of the KL term when ``use_ce_loss`` is True.
        skew_lambda_rkl: Mixing coefficient forwarded to ``compute_skewed_rkl``.
        kl_temperature: Temperature applied to both sets of logits.
        kl_reduction: "mean" (default) averages the KL over the supervised tokens so
            it lives on the same per-token scale as the CE loss it is mixed with;
            "sum" restores the earlier behaviour of summing over the whole batch.
    """

    # Label value marking positions that must not contribute to the loss. Padded
    # label positions carry this value, not the tokenizer's pad_token_id.
    IGNORE_INDEX = -100

    def __init__(self, *args, teacher_model=None, use_ce_loss=True,
                 kl_loss_weight=0.5, skew_lambda_rkl=0.1, kl_temperature=2.0,
                 kl_reduction="mean", **kwargs):
        # A non-scalar ("none") reduction cannot be used as a training loss.
        if kl_reduction not in ("mean", "sum"):
            raise ValueError(f"kl_reduction must be 'mean' or 'sum', got {kl_reduction!r}")
        super().__init__(*args, **kwargs)
        self.teacher_model = teacher_model
        self.use_ce_loss = use_ce_loss
        self.kl_loss_weight = kl_loss_weight
        self.kl_temperature = kl_temperature
        self.skew_lambda_rkl = skew_lambda_rkl
        self.kl_reduction = kl_reduction
        if self.teacher_model is not None:
            self.teacher_model.eval()
        # NOTE(review): recent Unsloth releases appear to skip materialising the
        # student's logits during training unless this variable is "1" — confirm
        # for the installed version. setdefault keeps any value set by the user.
        os.environ.setdefault("UNSLOTH_RETURN_LOGITS", "1")

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Return the distillation loss (optionally blended with the student's CE loss).

        Args:
            model: The student model being trained.
            inputs: Batch mapping forwarded unchanged to both student and teacher.
            return_outputs: If True, also return the student's forward outputs.
            num_items_in_batch: Accepted for Trainer compatibility; not used.

        Returns:
            ``total_loss`` or ``(total_loss, outputs_student)``.

        Raises:
            ValueError: If no teacher was provided, if the teacher's vocabulary is
                smaller than the student's, or if ``use_ce_loss`` is True but the
                student forward pass produced no loss (no labels in the batch).
        """
        if self.teacher_model is None:
            raise ValueError("KDTrainer.compute_loss requires a teacher_model, but none was provided.")

        outputs_student = model(**inputs)
        loss_ce_student = outputs_student.loss
        logits_student = outputs_student.logits

        with torch.no_grad():
            logits_teacher = self.teacher_model(**inputs).logits

        # Align vocabulary sizes: drop the teacher's extra trailing entries.
        vocab_size_student = logits_student.shape[-1]
        vocab_size_teacher = logits_teacher.shape[-1]
        if vocab_size_teacher > vocab_size_student:
            logits_teacher = logits_teacher[..., :vocab_size_student]
        elif vocab_size_teacher < vocab_size_student:
            raise ValueError(
                f"Teacher vocabulary ({vocab_size_teacher}) is smaller than the student's "
                f"({vocab_size_student}); the two distributions cannot be compared."
            )

        labels = inputs.get("labels")
        if labels is None:
            # Without labels there is nothing to mask.
            kd_logits_student, kd_logits_teacher = logits_student, logits_teacher
            kd_labels, padding_id_val = None, None
        else:
            # Logits at position t predict token t+1, so pair logits[:, :-1] with
            # labels[:, 1:] (assumes logits are (batch, seq, vocab) — TODO confirm).
            kd_logits_student = logits_student[:, :-1, :]
            kd_logits_teacher = logits_teacher[:, :-1, :]
            kd_labels = labels[:, 1:]
            padding_id_val = self.IGNORE_INDEX

        # Skewed reverse-KL between student and teacher, using the configured
        # temperature (previously a hard-coded 2.0 ignored self.kl_temperature).
        kl_loss = compute_skewed_rkl(
            kd_logits_student,
            kd_logits_teacher,
            target_labels=kd_labels,
            padding_id=padding_id_val,
            temp=self.kl_temperature,
            reduction=self.kl_reduction,
            skew_lambda=self.skew_lambda_rkl,
        )

        if self.use_ce_loss:
            if loss_ce_student is None:
                raise ValueError("use_ce_loss=True requires 'labels' in the batch so the student returns a CE loss.")
            total_loss = self.kl_loss_weight * kl_loss + (1 - self.kl_loss_weight) * loss_ce_student
        else:
            total_loss = kl_loss

        return (total_loss, outputs_student) if return_outputs else total_loss
In [None]:
# 3. Configuration
# --- Models and output locations ---
# teacher_model_path = "qwen_teacher_finetune"
teacher_model_path = "unsloth/Qwen2.5-3B-Instruct"
student_model_name = "unsloth/Qwen2.5-0.5B-Instruct" # the student is an Instruct model
output_dir_distillation = "./results_qwen_student_distilled_skewed_rkl_chat"
save_directory_student = "qwen_student_distilled_skewed_rkl_chat_final"

# --- Dataset and prompt formatting ---
dataset_name = "yahma/alpaca-cleaned"
ALPACA_SYSTEM_PROMPT = "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request."

# --- Model loading options ---
max_seq_length = 2048
load_in_4bit = True

# On Apple MPS, 4-bit loading is switched off and float16 is forced; everywhere
# else dtype stays None and 4-bit loading stays enabled.
if not torch.backends.mps.is_available():
    dtype = None
    print("CUDA or CPU detected. Using auto dtype and 4-bit quantization if enabled.")
else:
    load_in_4bit = False
    dtype = torch.float16
    print("MPS detected. Disabling 4-bit quantization and using float16.")


In [None]:
# Distillation-specific hyper-parameters

# Loss composition
distill_use_ce_loss = True
distill_kl_loss_weight = 0.5
distill_kl_temperature = 2.0
skew_lambda_rkl = 0.1

# Optimisation schedule
distill_epochs = 1
distill_lr = 5e-4
distill_batch_size = 2  # tune to fit GPU memory
distill_grad_accum = 8  # effective batch size = 2 * 8 = 16

# Run name shown in Weights & Biases
wandb_run_name = "decoder_knowledge_distillation_student_skewed_rkl"


In [None]:
# 4. Load the dataset and the student model/tokenizer used to format it
print("Loading and formatting dataset...")
dataset_full = load_dataset(dataset_name, split="train")
# dataset = dataset_full.select(range(200)) # small subset for a quick demo run
dataset = dataset_full

# The student tokenizer is needed before formatting because its chat template
# renders the training text.
print(f"Loading student model ({student_model_name}) and its tokenizer...")
student_model, student_tokenizer = FastLanguageModel.from_pretrained(
    model_name=student_model_name,
    max_seq_length=max_seq_length,
    dtype=dtype,
    load_in_4bit=load_in_4bit, # this call actually loads the model weights
)

# Make sure the student tokenizer has a pad token and report its chat template.
# When no pad token is defined, the EOS token is reused for padding.
if student_tokenizer.pad_token is None:
    student_tokenizer.pad_token = student_tokenizer.eos_token
    print(f"Set student_tokenizer.pad_token to eos_token: {student_tokenizer.pad_token}")

# A missing chat template is only reported, not repaired, here.
if student_tokenizer.chat_template is None:
    print(f"Warning: student_tokenizer (for {student_model_name}) loaded without a chat_template. Unsloth might apply a default one for Qwen models. Ensure this is intended.")
else:
    print(f"Using student tokenizer chat template: {student_tokenizer.chat_template}")

def formatting_prompts_func(examples):
    """Render a batch of Alpaca rows as chat-formatted training strings.

    Each row becomes a system/user/assistant conversation: the user turn is the
    instruction, followed by the input on a new line when the input is non-blank.
    The conversation is rendered with ``student_tokenizer``'s chat template. A row
    whose rendering raises is logged and replaced by an empty string, so the
    output stays aligned with the batch.

    Args:
        examples: Batch mapping with "instruction", "input" and "output" columns.

    Returns:
        ``{"text": [...]}`` holding one rendered string per input row.
    """
    rendered = []
    rows = zip(examples["instruction"], examples["input"], examples["output"])

    for instruction, input_text, output in rows:
        # Only append the optional input when it contains non-whitespace text.
        input_suffix = f"\n{input_text}" if input_text and input_text.strip() else ""
        messages = [
            {"role": "system", "content": ALPACA_SYSTEM_PROMPT},
            {"role": "user", "content": instruction + input_suffix},
            {"role": "assistant", "content": output},
        ]
        try:
            # The full conversation is rendered as text; tokenisation happens later.
            rendered.append(
                student_tokenizer.apply_chat_template(
                    messages,
                    tokenize=False,
                    add_generation_prompt=False,
                )
            )
        except Exception as e:
            print(f"Error applying chat template: {e}")
            print(f"Problematic messages: {messages}")
            rendered.append("")

    return {"text": rendered}

# Render every example with the chat template; rows whose rendering failed come
# back from formatting_prompts_func as "" and are dropped by the filter.
dataset = dataset.map(formatting_prompts_func, batched=True, num_proc=4)
dataset = dataset.filter(lambda example: example['text'] != "" and example['text'] is not None)
print(f"Dataset formatted. Number of examples after formatting: {len(dataset)}")
if len(dataset) == 0:
    # Raise SystemExit instead of calling exit(): inside a Jupyter kernel exit()
    # requests a kernel shutdown, whereas raising only stops this cell (and a
    # Run All) while keeping the kernel and everything loaded in it alive.
    raise SystemExit("Dataset is empty after formatting. Exiting.")

print("\nSample formatted text (for student model training):")
print(dataset[0]['text'])


In [None]:
# #挑选token数前32的数据用于估计显存占用
# def count_tokens(example):
#     return {
#         "num_tokens": len(
#             student_tokenizer(example["text"], add_special_tokens=False).input_ids
#         )
#     }

# dataset_with_counts = dataset.map(count_tokens, batched=False)

# # 2. 按 token 数降序排序
# sorted_dataset = dataset_with_counts.sort("num_tokens", reverse=True)

# # 3. Keep the top 32 examples (matches select(range(32)) below)
# dataset = sorted_dataset.select(range(32))

# dataset ['num_tokens']

In [None]:
# 5. Load the teacher model
# NOTE(review): the comment and log message call the teacher "fine-tuned", but
# teacher_model_path is currently set to the base Instruct checkpoint (the
# fine-tuned path is commented out in the configuration cell) — confirm intent.
print(f"Loading fine-tuned teacher model from {teacher_model_path}...")
teacher_model, teacher_tokenizer = FastLanguageModel.from_pretrained(
    model_name=teacher_model_path,
    max_seq_length=max_seq_length,
    dtype=dtype,
    load_in_4bit=load_in_4bit,
)
# The teacher is never trained here, so it is switched to Unsloth's inference mode.
FastLanguageModel.for_inference(teacher_model)
print("Teacher model loaded.")

# Make sure the teacher tokenizer has a pad token as well, just in case.
if teacher_tokenizer.pad_token is None:
    teacher_tokenizer.pad_token = teacher_tokenizer.eos_token
    print(f"Set teacher_tokenizer.pad_token to eos_token: {teacher_tokenizer.pad_token}")


In [None]:
# 6. Attach LoRA adapters to the student model
# All adapter settings are collected in one mapping and passed to Unsloth in a
# single call; the values are unchanged.
lora_settings = dict(
    r=16,
    lora_alpha=16,
    lora_dropout=0,
    bias="none",
    target_modules=[
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    ],
    use_gradient_checkpointing="unsloth",
    random_state=3407,
    max_seq_length=max_seq_length,
)
student_model = FastLanguageModel.get_peft_model(student_model, **lora_settings)
print("Student model loaded and LoRA configured.")
student_model.print_trainable_parameters()


In [None]:
# 7. Training arguments for the distillation run
print("Configuring TrainingArguments for distillation...")

# Mixed precision: bf16 when supported, otherwise fp16; both are off on MPS.
_on_mps = torch.backends.mps.is_available()
_bf16_ok = is_bfloat16_supported()

distill_training_args = TrainingArguments(
    # Output and checkpointing
    output_dir=output_dir_distillation,
    save_strategy="epoch",  # or "steps" (then also set save_steps, e.g. 50)
    save_total_limit=2,
    # Schedule (set max_steps instead of num_train_epochs for a fixed step budget)
    num_train_epochs=distill_epochs,
    per_device_train_batch_size=distill_batch_size,
    gradient_accumulation_steps=distill_grad_accum,
    # Optimiser
    learning_rate=distill_lr,
    warmup_ratio=0.1,
    lr_scheduler_type="linear",
    optim="adamw_8bit",
    # Precision
    fp16=not _bf16_ok and not _on_mps,
    bf16=_bf16_ok and not _on_mps,
    # Reproducibility and logging
    seed=3407,
    logging_steps=10,
    report_to="wandb",  # send metrics to Weights & Biases
    run_name=wandb_run_name,  # name of the W&B run
)


In [None]:
# 8. Build the KDTrainer and run the distillation
if len(dataset) == 0:
    print("Skipping distillation training as dataset is empty.")
else:
    print("Initializing KDTrainer...")
    distill_trainer = KDTrainer(
        model=student_model,
        teacher_model=teacher_model,
        args=distill_training_args,
        train_dataset=dataset,
        tokenizer=student_tokenizer, # the trainer tokenises with the student tokenizer
        dataset_text_field="text",   # column produced by formatting_prompts_func
        max_seq_length=max_seq_length,
        dataset_num_proc=2,
        packing=False, # each 'text' row is already a complete, pre-formatted conversation
        use_ce_loss=distill_use_ce_loss,
        kl_loss_weight=distill_kl_loss_weight,
        kl_temperature = distill_kl_temperature,
        skew_lambda_rkl = skew_lambda_rkl,
    )

    # The loss is the skewed reverse KL from compute_skewed_rkl; the log messages
    # previously said "Forward KL", which did not match the loss being trained.
    print("Starting distillation training with Skewed Reverse KL Divergence and Chat Template...")
    distill_trainer.train()
    wandb.finish()
    print("Distillation training completed.")

    # 9. Save the distilled student (LoRA weights) together with its tokenizer
    print(f"Saving distilled student model to {save_directory_student}...")
    student_model.save_pretrained(save_directory_student)
    student_tokenizer.save_pretrained(save_directory_student)
    print("Distilled student model saved.")

print("\nKnowledge distillation process (Skewed Reverse KL with Chat Template) finished.")

In [None]:
# ---------------------------------------------------------------------------------
# IFEval Evaluation for the Distilled Student Model
# ---------------------------------------------------------------------------------
print("\nStarting IFEval Evaluation for the distilled student model...")

# Both precondition failures below raise SystemExit instead of calling exit():
# inside a Jupyter kernel exit() requests a kernel shutdown, whereas raising only
# stops this cell (and a Run All) and keeps the trained model in memory.
try:
    from instruction_following_eval import get_examples, evaluate_instruction_following
except ImportError as import_error:
    raise SystemExit("IFEval library not found. Please install it first.") from import_error

# The saved-model directory must exist and must not be empty.
if not os.path.exists(save_directory_student) or not os.listdir(save_directory_student):
    raise SystemExit(f"Error: Saved model directory '{save_directory_student}' not found or empty. Skipping IFEval.")

# Reload the distilled student from disk and switch it to inference mode.
print(f"Loading distilled student model from {save_directory_student} for IFEval...")
eval_model, eval_tokenizer = FastLanguageModel.from_pretrained(
    model_name=save_directory_student,
    max_seq_length=max_seq_length,
    dtype=dtype,
    load_in_4bit=load_in_4bit,
)
FastLanguageModel.for_inference(eval_model)
print("Distilled student model and tokenizer loaded for IFEval.")

In [None]:
# Reuse the EOS token for padding when the reloaded tokenizer defines no pad token.
if eval_tokenizer.pad_token is None:
    eval_tokenizer.pad_token = eval_tokenizer.eos_token
    print(f"Set eval_tokenizer.pad_token to eos_token: {eval_tokenizer.pad_token}")

# Make sure the evaluation tokenizer has a chat template as well.
# A tokenizer reloaded from the save directory normally carries the configuration
# it was saved with.
if eval_tokenizer.chat_template is None:
    print(f"Warning: eval_tokenizer (for {save_directory_student}) loaded without a chat_template.")
    if student_tokenizer.chat_template is not None: # student_tokenizer is the one used for training
        eval_tokenizer.chat_template = student_tokenizer.chat_template
        print(f"Applied chat_template from student_tokenizer to eval_tokenizer.")
    # If student_tokenizer has no template either, one has to be set manually
    # (or the model's default behaviour is relied upon).
else:
    print(f"Eval tokenizer chat template: {eval_tokenizer.chat_template}")

In [None]:
ifeval_examples = get_examples()
print(f"Loaded {len(ifeval_examples)} examples for IFEval.")
# ifeval_examples = ifeval_examples[:5] # small subset for a quick demo run

print("Generating responses for IFEval prompts using the distilled student model...")
# One dict per example: a copy of the IFEval example plus a 'response' key.
generated_responses_for_ifeval = []

for i, example in enumerate(ifeval_examples):
    ifeval_prompt_text = example['prompt']
    # Same system prompt as in training, followed by the IFEval prompt.
    messages_for_eval = [
        {"role": "system", "content": ALPACA_SYSTEM_PROMPT},
        {"role": "user", "content": ifeval_prompt_text}
    ]
    try:
        inputs = eval_tokenizer.apply_chat_template(
            messages_for_eval,
            tokenize=True,
            add_generation_prompt=True, # important: must be True for generation
            return_tensors="pt"
        ).to(eval_model.device)
    except Exception as e:
        print(f"Error applying chat template for IFEval prompt: {e}")
        # Record the failure on a copy, exactly like the success path below, so
        # the shared ifeval_examples entries are never mutated.
        failed_example = example.copy()
        failed_example['response'] = f"Error during input formatting: {e}"
        generated_responses_for_ifeval.append(failed_example)
        continue

    try:
        outputs = eval_model.generate(
            inputs,
            max_new_tokens=2048, # upper bound on the generated length
            use_cache=True
        )
        # Decode only the newly generated tokens (everything after the prompt).
        response_text = eval_tokenizer.batch_decode(outputs[:, inputs.shape[1]:], skip_special_tokens=True)[0].strip()
    except Exception as e:
        print(f"Error during model generation for IFEval prompt {i+1}: {e}")
        response_text = f"Error during model generation: {e}"

    current_example_with_response = example.copy() # copy of the original dict
    current_example_with_response['response'] = response_text # add the 'response' key
    generated_responses_for_ifeval.append(current_example_with_response)

    if (i + 1) % 10 == 0 or i == len(ifeval_examples) - 1:
        print(f"Generated response for IFEval example {i + 1}/{len(ifeval_examples)}")

print("Finished generating responses for IFEval prompts.")

if generated_responses_for_ifeval:
    print("Evaluating generated responses with IFEval...")
    model_responses_list = [ex['response'] for ex in generated_responses_for_ifeval]
    # NOTE(review): this passes the original examples plus a parallel list of
    # responses — confirm this matches the installed evaluate_instruction_following
    # signature (it may instead expect the dicts that carry a 'response' key).
    ifeval_metrics = evaluate_instruction_following(ifeval_examples, model_responses_list)

    print("\nIFEval Metrics for Distilled Student Model:")
    for metric_name, value in ifeval_metrics.items():
        print(f"  {metric_name}: {value:.4f}")
else:
    print("No responses were generated, skipping IFEval evaluation.")

print("\nIFEval Evaluation for distilled student model finished.")