In [None]:
import os
import math
import argparse
from dataclasses import dataclass
from typing import Dict, Any, Optional

import torch
from datasets import load_dataset
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    BitsAndBytesConfig,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
)
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training


In [None]:
def print_once(msg: str):
    """Print ``msg`` (flushed immediately) from the rank-0 process only.

    The rank is read from the ``RANK`` environment variable; when it is unset the
    process is treated as rank 0, so single-process runs always print.
    """
    is_main_process = int(os.environ.get("RANK", "0")) == 0
    if not is_main_process:
        return
    print(msg, flush=True)


def find_pad_token(tokenizer: AutoTokenizer):
    """Make sure ``tokenizer`` has a pad token, falling back to its EOS token.

    The tokenizer is modified in place and also returned so calls can be chained.

    NOTE(review): when the fallback fires, pad and EOS share one id. Collators that mask
    pad positions in ``labels`` (e.g. ``DataCollatorForLanguageModeling``) would then also
    mask the EOS appended to every training example — confirm that is acceptable.
    """
    if tokenizer.pad_token is not None:
        return tokenizer
    tokenizer.pad_token = tokenizer.eos_token
    return tokenizer


def format_prompt(instruction: str, inp: Optional[str], answer: str) -> str:
    """Render one Alpaca-style training example as a single string.

    Args:
        instruction: the task description.
        inp: optional extra context; the ``### Input:`` section is emitted only when it is
            non-empty after stripping whitespace (``None`` counts as empty).
        answer: the target response, appended right after ``### Response:\\n``.

    Returns:
        The prompt prefix followed by ``answer``.
    """
    has_input = inp is not None and inp.strip() != ""
    prefix = (
        f"### Instruction:\n{instruction}\n\n### Input:\n{inp}\n\n### Response:\n"
        if has_input
        else f"### Instruction:\n{instruction}\n\n### Response:\n"
    )
    return prefix + answer


def get_4bit_config():
    """Build the bitsandbytes 4-bit quantization config shared by teacher and student.

    Weights are stored as NF4 with double (nested) quantization enabled; the compute
    dtype for the quantized layers is ``torch.bfloat16``.
    """
    quant_kwargs = dict(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True,
        bnb_4bit_compute_dtype=torch.bfloat16,
    )
    return BitsAndBytesConfig(**quant_kwargs)


In [None]:
def build_dataset(dataset_id: str, tokenizer: AutoTokenizer, max_len: int, limit_train: Optional[int] = None):
    """Load an instruction dataset and turn it into tokenized causal-LM examples.

    Each row must provide ``instruction`` and ``output`` (``input`` is optional). Rows are
    rendered with ``format_prompt``, terminated with the tokenizer's EOS token, and
    tokenized with truncation to ``max_len`` tokens.

    Args:
        dataset_id: Hugging Face ``datasets`` id (or local path) passed to ``load_dataset``.
        tokenizer: tokenizer used for both the EOS string and tokenization.
        max_len: maximum number of tokens per example (longer examples are truncated).
        limit_train: if a positive int, keep only the first ``limit_train`` rows of the
            ``train`` split; ``None`` or ``<= 0`` keeps everything.

    Returns:
        A ``DatasetDict`` whose splits hold the tokenizer's output columns
        (presumably ``input_ids`` / ``attention_mask``) instead of the raw text columns.
    """
    raw = load_dataset(dataset_id)

    # Subsample BEFORE formatting/tokenizing: the original tokenized the whole split and
    # then discarded most of it, which is wasted work for a small --limit_train run.
    # The first N rows are identical either way, so the result is unchanged.
    if limit_train is not None and limit_train > 0:
        raw["train"] = raw["train"].select(range(min(limit_train, len(raw["train"]))))

    eos = tokenizer.eos_token

    def _format(row):
        text = format_prompt(row["instruction"], row.get("input", ""), row["output"])
        return {"text": text + eos}

    formatted = raw.map(_format, remove_columns=raw["train"].column_names)

    def _tok(batch):
        return tokenizer(batch["text"], truncation=True, max_length=max_len)

    return formatted.map(_tok, batched=True, remove_columns=["text"])


In [None]:
def load_teacher(teacher_id: str):
    """Load the teacher model in 4-bit, in eval mode, with every parameter frozen.

    Args:
        teacher_id: Hugging Face model id or local path.

    Returns:
        The loaded model, used only for forward passes under ``torch.no_grad``.

    NOTE(review): ``trust_remote_code=True`` executes code shipped with the model repo;
    only pass trusted ``teacher_id`` values.
    """
    print_once(f"Loading TEACHER: {teacher_id}")
    model = AutoModelForCausalLM.from_pretrained(
        teacher_id,
        trust_remote_code=True,
        device_map="auto",
        quantization_config=get_4bit_config(),
    )
    model.eval()
    # Module-level helper: sets requires_grad=False on every parameter.
    model.requires_grad_(False)
    return model


def load_student(
    student_id: str,
    lora_r: int,
    lora_alpha: int,
    lora_dropout: float,
    target_modules: Optional[list] = None,
):
    """Load the student in 4-bit and wrap it with trainable LoRA adapters (QLoRA).

    Args:
        student_id: Hugging Face model id or local path.
        lora_r: LoRA rank.
        lora_alpha: LoRA scaling numerator (effective scale is ``lora_alpha / lora_r``).
        lora_dropout: dropout probability on the LoRA branch.
        target_modules: names of the modules to adapt. ``None`` keeps the original default
            (attention + MLP projections named like Llama/Qwen layers); other
            architectures may need a different list.

    Returns:
        The PEFT-wrapped student; only the LoRA parameters are trainable.
    """
    if target_modules is None:
        target_modules = ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]

    print_once(f"Loading STUDENT: {student_id}")
    student = AutoModelForCausalLM.from_pretrained(
        student_id,
        trust_remote_code=True,
        device_map="auto",
        quantization_config=get_4bit_config(),
    )
    # PEFT helper that readies a quantized model for training (by default it also turns on
    # gradient checkpointing).
    student = prepare_model_for_kbit_training(student)
    lora_cfg = LoraConfig(
        r=lora_r,
        lora_alpha=lora_alpha,
        lora_dropout=lora_dropout,
        target_modules=list(target_modules),
        # Without a task type get_peft_model returns a generic PeftModel (forward is
        # *args/**kwargs), which older Trainer versions cannot inspect for input columns and
        # which saves an adapter config with no task. CAUSAL_LM yields PeftModelForCausalLM.
        task_type="CAUSAL_LM",
    )
    student = get_peft_model(student, lora_cfg)
    return student


In [None]:
def shift_logits_and_labels(logits: torch.Tensor, labels: torch.Tensor):
    """Align causal-LM predictions with the tokens they predict.

    The logit at position ``i`` is a prediction for token ``i + 1``, so the last logit step
    and the first label are dropped. The indexing assumes ``logits`` is
    ``(batch, seq, vocab)`` and ``labels`` is ``(batch, seq)``.

    Returns:
        ``(shift_logits, shift_labels)`` as contiguous tensors, so callers may ``.view`` them.
    """
    # Causal LM predicts the next token -> drop the last prediction and the first target.
    predictions = logits[:, :-1, :]
    targets = labels[:, 1:]
    return predictions.contiguous(), targets.contiguous()


class KDTrainer(Trainer):
    """``Trainer`` that mixes supervised cross-entropy with distillation from a frozen teacher.

        loss = alpha * CE(student, labels) + (1 - alpha) * T^2 * KL(teacher_T || student_T)

    Here ``*_T`` are softmax distributions at temperature ``T``. Both terms are averaged over
    the same supervised next-token positions (``labels != -100`` after the causal shift).
    Padding is therefore excluded, and CE and KL are on the same per-token scale.

    This class is intentionally NOT a ``@dataclass``. A dataclass-generated ``__init__``
    replaces ``Trainer.__init__`` and rejects ``model=``, ``args=``, etc. The KD keyword
    arguments are consumed here and everything else is forwarded to ``Trainer``.
    """

    # Class-level defaults, kept so existing attribute reads keep working.
    teacher: Optional[AutoModelForCausalLM] = None
    alpha: float = 0.5
    temperature: float = 1.0
    ce_loss: torch.nn.Module = torch.nn.CrossEntropyLoss(ignore_index=-100)
    # "batchmean" divides by dim 0; compute_loss feeds (n_tokens, vocab) rows -> per-token mean.
    kl_loss: torch.nn.Module = torch.nn.KLDivLoss(reduction="batchmean")

    def __init__(
        self,
        *args,
        teacher: Optional[AutoModelForCausalLM] = None,
        alpha: float = 0.5,
        temperature: float = 1.0,
        ce_loss: Optional[torch.nn.Module] = None,
        kl_loss: Optional[torch.nn.Module] = None,
        **kwargs,
    ):
        """Store the KD settings and forward all remaining arguments to ``Trainer``.

        Raises:
            ValueError: if ``temperature`` is not positive (it is used as a divisor).
        """
        if temperature <= 0:
            raise ValueError(f"temperature must be > 0, got {temperature}")
        super().__init__(*args, **kwargs)
        self.teacher = teacher
        self.alpha = alpha
        self.temperature = temperature
        if ce_loss is not None:
            self.ce_loss = ce_loss
        if kl_loss is not None:
            self.kl_loss = kl_loss
        # compute_loss returns a per-token *mean* for this batch. On transformers versions
        # that consult this flag, False makes Trainer divide by gradient_accumulation_steps
        # itself, instead of assuming the loss was already normalised by num_items_in_batch.
        self.model_accepts_loss_kwargs = False

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Return the combined CE + distillation loss for one batch.

        Args:
            model: the student being trained.
            inputs: collated batch; reads ``labels``, ``input_ids`` and (if present)
                ``attention_mask``.
            return_outputs: if True, also return the student's forward outputs.
            num_items_in_batch: passed by newer ``Trainer`` versions. It is ignored because
                the returned loss is already a per-token mean over this batch.

        Returns:
            ``loss`` or ``(loss, student_outputs)``.

        Raises:
            ValueError: if no teacher model was provided.
        """
        if self.teacher is None:
            raise ValueError("KDTrainer needs a teacher model (pass teacher=...).")

        labels = inputs["labels"]
        # The CE below replaces the model's internal loss. Not passing labels to the student
        # avoids a second, redundant full-vocabulary loss computation.
        student_inputs = {k: v for k, v in inputs.items() if k != "labels"}
        outputs_s = model(**student_inputs)

        s_logits, s_labels = shift_logits_and_labels(outputs_s.logits, labels)
        device = s_logits.device
        s_labels = s_labels.to(device)
        mask = s_labels.ne(-100)  # supervised next-token positions
        if not bool(mask.any()):
            # Nothing to supervise: return a graph-connected zero rather than NaN (0 / 0).
            loss = s_logits.sum() * 0.0
            return (loss, outputs_s) if return_outputs else loss

        # Keep only the supervised rows -> (n_tokens, vocab). Upcast for a stable softmax.
        s_rows = s_logits[mask].float()

        # --- CE: student vs ground-truth labels ---
        loss_ce = self.ce_loss(s_rows, s_labels[mask])

        # --- KL: student vs teacher (soft targets) ---
        with torch.no_grad():
            t_out = self.teacher(
                input_ids=inputs["input_ids"],
                attention_mask=inputs.get("attention_mask"),
            )
            t_logits, _ = shift_logits_and_labels(t_out.logits, labels)
            # With device_map="auto" the teacher's logits may live on another device.
            t_rows = t_logits.to(device)[mask].float()

        # Checkpoints that share a tokenizer can still have different padded vocab sizes
        # (e.g. the default Qwen2.5 7B / 1.5B pair appear to differ — verify in their
        # configs). Compare only the shared leading slice of the vocabulary.
        vocab = min(s_rows.size(-1), t_rows.size(-1))
        T = self.temperature
        s_logp = torch.log_softmax(s_rows[:, :vocab] / T, dim=-1)  # student log-probs
        t_p = torch.softmax(t_rows[:, :vocab] / T, dim=-1)          # teacher probs
        # Forward KL(teacher || student), averaged per token. T^2 keeps the gradient scale
        # of the softened term comparable across temperatures.
        loss_kl = self.kl_loss(s_logp, t_p) * (T * T)

        loss = self.alpha * loss_ce + (1.0 - self.alpha) * loss_kl
        return (loss, outputs_s) if return_outputs else loss


In [None]:
def main():
    parser = argparse.ArgumentParser(description="LLM Knowledge Distillation (SFT + KL) with QLoRA")
    parser.add_argument("--teacher", type=str, default="Qwen/Qwen2.5-7B-Instruct")
    parser.add_argument("--student", type=str, default="Qwen/Qwen2.5-1.5B-Instruct")
    parser.add_argument("--dataset", type=str, default="yahma/alpaca-cleaned")
    parser.add_argument("--output_dir", type=str, default="student_kd_out")

    parser.add_argument("--alpha", type=float, default=0.5, help="Weight for CE; (1-alpha) used for KL")
    parser.add_argument("--temperature", type=float, default=1.0)
    parser.add_argument("--epochs", type=float, default=1.0)
    parser.add_argument("--lr", type=float, default=2e-4)
    parser.add_argument("--bsz", type=int, default=4)
    parser.add_argument("--grad_accum", type=int, default=8)
    parser.add_argument("--max_len", type=int, default=1024)
    parser.add_argument("--limit_train", type=int, default=0, help="If >0, limit training examples")

    parser.add_argument("--lora_r", type=int, default=16)
    parser.add_argument("--lora_alpha", type=int, default=32)
    parser.add_argument("--lora_dropout", type=float, default=0.05)

    args = parser.parse_args()

    # Tokenizer (use student's for training)
    tokenizer = AutoTokenizer.from_pretrained(args.student, use_fast=True, trust_remote_code=True)
    tokenizer = find_pad_token(tokenizer)

    # Data
    print_once(f"Loading dataset: {args.dataset}")
    ds = build_dataset(args.dataset, tokenizer, args.max_len, limit_train=(args.limit_train or None))

    # Models
    teacher = load_teacher(args.teacher)
    student = load_student(args.student, args.lora_r, args.lora_alpha, args.lora_dropout)

    # Collator
    collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)

    # Trainer
    train_args = TrainingArguments(
        output_dir=args.output_dir,
        per_device_train_batch_size=args.bsz,
        gradient_accumulation_steps=args.grad_accum,
        learning_rate=args.lr,
        num_train_epochs=args.epochs,
        bf16=True,
        logging_steps=10,
        save_strategy="epoch",
        report_to="none",
    )

    trainer = KDTrainer(
        model=student,
        args=train_args,
        train_dataset=ds["train"],
        data_collator=collator,
        teacher=teacher,
        alpha=args.alpha,
        temperature=args.temperature,
    )

    print_once("Starting training...")
    trainer.train()
    print_once("Saving model + tokenizer...")
    trainer.save_model(args.output_dir)
    tokenizer.save_pretrained(args.output_dir)
    print_once(f"Done. Saved to: {args.output_dir}")


# Script entry point. Note: inside a Jupyter kernel ``__name__`` is also "__main__", so
# executing this cell starts a full training run as well.
if __name__ == "__main__":
    main()