# FineTune LLM - Teacher-Student Distillation

In [None]:
import os, torch, math, torch.nn.functional as F
from datasets import Dataset
from transformers import (
    AutoModelForCausalLM, AutoTokenizer,
    DataCollatorForLanguageModeling, TrainingArguments, Trainer
)
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

In [None]:
class DistillationFineTuner(Trainer):
    """
    A Hugging-Face Trainer subclass that performs knowledge-distillation
    from a frozen teacher into a student equipped with LoRA adapters.

    The training objective is

        loss = alpha * T^2 * KL(teacher_T || student_T) + (1 - alpha) * CE

    where ``T`` is ``temperature``, the KL term compares the
    temperature-softened next-token distributions of teacher and student on
    every position kept by the attention mask, and ``CE`` is the usual
    next-token cross-entropy of the student against ``labels`` (logits at
    position ``t`` are scored against the label at position ``t + 1``).
    """

    def __init__(
        self,
        teacher_model,
        student_model,
        tokenizer,
        train_dataset,
        eval_dataset,
        temperature: float = 2.0,
        alpha: float = 0.5,
        **training_args_kwargs
    ):
        """
        Args:
            teacher_model: Model providing the soft targets. It is switched
                to eval mode and its parameters are frozen here.
            student_model: Model handed to ``Trainer`` as the trainable model.
            tokenizer: Tokenizer used by the data collator and by
                :meth:`generate` / :meth:`push_to_hub`.
            train_dataset: Tokenized training split.
            eval_dataset: Tokenized evaluation split.
            temperature: Softening temperature applied to both sets of
                logits before the KL term. Must be strictly positive.
            alpha: Weight of the KL term; ``1 - alpha`` weights the CE term.
            **training_args_kwargs: Forwarded verbatim to
                ``TrainingArguments``.

        Raises:
            ValueError: If ``temperature`` is not strictly positive.
        """
        # Logits are divided by the temperature, so zero would produce
        # inf/NaN and a negative value would invert the distributions.
        if temperature <= 0:
            raise ValueError(f"temperature must be > 0, got {temperature!r}")

        self.teacher = teacher_model.eval()         # freeze teacher
        self.teacher.requires_grad_(False)

        # HuggingFace TrainingArguments
        args = TrainingArguments(**training_args_kwargs)

        super().__init__(
            model=student_model,
            args=args,
            data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
        )
        self.tokenizer = tokenizer
        self.temperature = temperature
        self.alpha = alpha

        # Loss functions ready-made
        self.kl_loss_fn = torch.nn.KLDivLoss(reduction="batchmean")
        self.ce_loss_fn = torch.nn.CrossEntropyLoss(ignore_index=-100)

    # ---------------------- loss helpers ---------------------- #
    def _distillation_loss(self, logits_s, logits_t, attention_mask):
        """Return the T^2-scaled KL divergence averaged over unmasked tokens."""
        flat_s = logits_s.reshape(-1, logits_s.size(-1))
        flat_t = logits_t.reshape(-1, logits_t.size(-1))

        if attention_mask is not None:
            # Drop padded positions so they neither contribute to the sum nor
            # inflate the "batchmean" denominator (the number of rows).
            keep = attention_mask.reshape(-1).to(flat_s.device).bool()
            flat_s = flat_s[keep]
            flat_t = flat_t[keep]

        if flat_s.size(0) == 0:
            # Nothing to distil on: return a zero that is still attached to
            # the student's graph instead of the NaN a 0/0 mean would give.
            return logits_s.sum() * 0.0

        # Softmax in float32 for numerical stability with half-precision logits.
        s_log_probs = F.log_softmax(flat_s.float() / self.temperature, dim=-1)
        t_probs = F.softmax(flat_t.float() / self.temperature, dim=-1)
        # The T^2 factor offsets the 1/T^2 scaling that the temperature
        # introduces into the gradients of the soft-target term.
        return self.kl_loss_fn(s_log_probs, t_probs) * (self.temperature ** 2)

    def _next_token_loss(self, logits_s, labels):
        """Return the causal-LM cross-entropy of the student against ``labels``."""
        if labels is None:
            return logits_s.sum() * 0.0

        # Logits at position t predict the token at t + 1, so drop the last
        # logit and the first label before comparing them.
        shift_logits = logits_s[:, :-1, :]
        shift_labels = labels[:, 1:].to(logits_s.device)

        if not (shift_labels != self.ce_loss_fn.ignore_index).any():
            # Every target is ignored; the mean-reduced CE would be NaN.
            return logits_s.sum() * 0.0

        return self.ce_loss_fn(
            shift_logits.reshape(-1, shift_logits.size(-1)).float(),
            shift_labels.reshape(-1),
        )

    # ---------------------- loss override ---------------------- #
    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """
        Mixes KL-divergence (student vs. teacher) with standard CE.

        Args:
            model: The student model to run the forward pass on.
            inputs: Batch mapping; ``input_ids`` is required, while
                ``attention_mask`` and ``labels`` are used when present.
            return_outputs: When true, also return the student's outputs.
            num_items_in_batch: Accepted so that Trainer versions which pass
                this keyword can call the override; it is not used because
                both loss terms are already per-token means.

        Returns:
            The scalar loss, or ``(loss, student_outputs)`` when
            ``return_outputs`` is true.
        """
        labels = inputs.get("labels")
        attention_mask = inputs.get("attention_mask")

        # Forward pass — STUDENT
        outputs_student = model(**inputs)
        logits_s = outputs_student.logits  # [B, T, V]

        # Forward pass — TEACHER (no grad)
        with torch.no_grad():
            outputs_teacher = self.teacher(
                input_ids=inputs["input_ids"],
                attention_mask=attention_mask,
            )
        # The teacher's output may live on a different device than the
        # student's; align them before combining.
        logits_t = outputs_teacher.logits.to(logits_s.device)

        loss_kl = self._distillation_loss(logits_s, logits_t, attention_mask)
        loss_ce = self._next_token_loss(logits_s, labels)

        loss = self.alpha * loss_kl + (1 - self.alpha) * loss_ce
        return (loss, outputs_student) if return_outputs else loss

    # ---------------------- convenience API ---------------------- #
    @torch.no_grad()
    def generate(self, prompt: str, max_new_tokens: int = 40):
        """
        Generate a completion for ``prompt`` with the student model.

        Args:
            prompt: Text to condition on.
            max_new_tokens: Upper bound on the number of generated tokens.

        Returns:
            The decoded sequence (prompt plus continuation) with special
            tokens stripped.
        """
        device = next(self.model.parameters()).device
        encoded = self.tokenizer(prompt, return_tensors="pt").to(device)

        # Generate in eval mode so dropout is disabled, then restore the
        # previous mode so an ongoing training run is unaffected.
        was_training = self.model.training
        self.model.eval()
        try:
            out = self.model.generate(
                **encoded,
                max_new_tokens=max_new_tokens,
                pad_token_id=self.tokenizer.eos_token_id
            )
        finally:
            if was_training:
                self.model.train()
        return self.tokenizer.decode(out[0], skip_special_tokens=True)

    def push_to_hub(self, repo_id: str):
        """
        Upload the student model and the tokenizer to ``repo_id`` on the Hub.

        NOTE(review): ``use_auth_token`` is deprecated in recent
        transformers / huggingface_hub releases in favour of ``token``;
        confirm against the installed versions before switching.
        NOTE(review): this appears to shadow ``Trainer.push_to_hub`` with a
        different signature — verify nothing relies on the parent's version
        (e.g. ``push_to_hub=True`` in the training arguments).
        """
        self.model.push_to_hub(repo_id, use_auth_token=True)
        self.tokenizer.push_to_hub(repo_id, use_auth_token=True)

In [None]:
# Checkpoint used for both the teacher and the student.
MODEL_ID = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
device = "cuda" if torch.cuda.is_available() else "cpu"

# Tokenizer: the EOS token doubles as the padding token.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
tokenizer.pad_token = tokenizer.eos_token

# LoRA adapter settings for the student (rank 16 on the q/v projections).
lora_cfg = LoraConfig(
    task_type="CAUSAL_LM",
    target_modules=["q_proj", "v_proj"],
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
)

# --- Teacher (full 8-bit) --- #
teacher = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    device_map="auto",
    load_in_8bit=True,
)

# --- Student (with LoRA) --- #
# Load the same checkpoint in 8-bit, prepare it for k-bit training, then
# wrap it with the LoRA adapters configured above.
student = get_peft_model(
    prepare_model_for_kbit_training(
        AutoModelForCausalLM.from_pretrained(
            MODEL_ID,
            device_map="auto",
            load_in_8bit=True,
        )
    ),
    lora_cfg,
)

# --- Toy dataset --- #
# Prompt-only records: each one ends right after the "### Assistant:" header.
examples = [
    {"text": prompt}
    for prompt in (
        "### User:\nTranslate 'Good morning' to Spanish.\n### Assistant:\n",
        "### User:\nSummarise: 'The cat sat on the mat.'\n### Assistant:\n",
        "### User:\nList three primary colours.\n### Assistant:\n",
        "### User:\nWhat is 2 + 2?\n### Assistant:\n",
        "### User:\nRewrite 'I like apples' in the past tense.\n### Assistant:\n",
    )
]


def tok_fn(e):
    """Tokenize the ``text`` field of ``e``, truncating and padding to 256 tokens."""
    return tokenizer(
        e["text"],
        padding="max_length",
        truncation=True,
        max_length=256,
    )


# Hold out 40 % of the records for evaluation (fixed seed for repeatability),
# then tokenize both splits and drop the raw text column.
full_ds = Dataset.from_list(examples)
ds = full_ds.train_test_split(test_size=0.4, seed=0)
ds_tok = ds.map(tok_fn, remove_columns=["text"])

In [None]:
# --- Distillation fine-tuner --- #
# Extra keyword arguments; the fine-tuner forwards them to TrainingArguments.
training_kwargs = {
    "output_dir": "./distill_out",
    "num_train_epochs": 1,
    "per_device_train_batch_size": 4,
    "learning_rate": 2e-4,
    "logging_steps": 1,
    "optim": "paged_adamw_8bit",
    "report_to": [],                        # no trackers for the demo
}

distill_ft = DistillationFineTuner(
    teacher_model=teacher,                  # frozen 8-bit teacher
    student_model=student,                  # LoRA-equipped student
    tokenizer=tokenizer,
    train_dataset=ds_tok["train"],
    eval_dataset=ds_tok["test"],
    temperature=2.0,                        # soften logits
    alpha=0.7,                              # more emphasis on teacher (70 % KL, 30 % CE)
    **training_kwargs,
)

# Train the student, then sample a completion as a quick smoke test.
distill_ft.train()

demo_prompt = "### User:\nWhat is 2 + 2?\n### Assistant:\n"
completion = distill_ft.generate(demo_prompt)
print(completion)