# FineTune LLM - SeqKD Distillation

In [None]:
import os, torch, math, torch.nn.functional as F
from datasets import Dataset
from transformers import (
    AutoModelForCausalLM, AutoTokenizer,
    DataCollatorForLanguageModeling, TrainingArguments, Trainer
)
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

**Note**:
- This approach feeds in `RAW_PROMPTS` instead of tokens, so it is not interchangeable with the other notebooks.
- See `finetune_llm_distillation_seq_kd_adjusted.ipynb` for a version that is interchangeable.
- See `finetune_llm_distillation_seq_kd_adjusted_realistic.ipynb` for a version that is more realistic in how it handles the training.

In [None]:
# ────────────────────────────────────────────────────────────────
# Helper: tiny wrapper to run teacher once and grab its answers
# ────────────────────────────────────────────────────────────────
@torch.no_grad()
def synthesize_with_teacher(teacher, tokenizer, prompts, max_new_tokens=64):
    """Generate one greedy teacher completion per prompt.

    Args:
        teacher: Model exposing ``eval``, ``requires_grad_``, ``parameters``
            and ``generate``; it is switched to eval mode and frozen.
        tokenizer: Callable tokenizer that also provides ``decode`` and
            ``eos_token_id``.
        prompts: Iterable of prompt strings.
        max_new_tokens: Upper bound on generated tokens per prompt.

    Returns:
        A list of ``{"text": prompt + completion}`` dicts, one per prompt,
        in input order.
    """
    teacher.eval()
    teacher.requires_grad_(False)
    # The device does not change between prompts, so look it up once.
    device = next(teacher.parameters()).device

    pairs = []
    for p in prompts:
        inputs = tokenizer(p, return_tensors="pt").to(device)
        out = teacher.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=False,
            pad_token_id=tokenizer.eos_token_id,
        )
        # The generated sequence starts with the prompt tokens; decode only
        # what follows them, otherwise `p + answer` would contain the prompt
        # twice.
        prompt_len = inputs["input_ids"].shape[1]
        answer = tokenizer.decode(out[0][prompt_len:], skip_special_tokens=True)
        pairs.append({"text": p + answer})
    return pairs

# ────────────────────────────────────────────────────────────────
# SeqKD Fine-tuner class
# ────────────────────────────────────────────────────────────────
class SeqKDFineTuner(Trainer):
    """
    Runs Sequence-Level KD:
      1. uses the TEACHER to create (prompt, teacher_answer) pairs,
      2. fine-tunes the STUDENT (with LoRA) on those pairs via standard CE.

    All dataset construction happens in ``__init__``; afterwards the object
    is used like any other ``Trainer`` (``.train()``, ``.evaluate()``, ...),
    plus the ``generate`` / ``push_to_hub`` convenience helpers below.
    """

    def __init__(
        self,
        teacher_model,
        student_model,
        tokenizer,
        raw_prompts,                   # list[str] or iterable
        max_seq_len: int = 256,
        **training_args_kwargs
    ):
        """Build the synthetic dataset and initialise the underlying Trainer.

        Args:
            teacher_model: Model used once, up front, to answer ``raw_prompts``.
            student_model: Model that is trained on the teacher's answers.
            tokenizer: Tokenizer used for generation and for tokenising the
                synthetic dataset.
            raw_prompts: Prompt strings handed to the teacher.
            max_seq_len: Every example is truncated/padded to this length.
            **training_args_kwargs: Forwarded verbatim to ``TrainingArguments``.
        """
        # ---------------- 1️⃣  Create synthetic dataset ---------------- #
        print("🗒️  Generating teacher answers for SeqKD …")
        synthetic_rows = synthesize_with_teacher(
            teacher_model, tokenizer, raw_prompts
        )
        # Fixed seed keeps the 80/20 train/eval split reproducible.
        dataset = Dataset.from_list(synthetic_rows).train_test_split(
            test_size=0.2, seed=42
        )

        def tok_fn(ex):
            return tokenizer(
                ex["text"],
                truncation=True,
                padding="max_length",
                max_length=max_seq_len,
            )

        tokenised = dataset.map(tok_fn, remove_columns=["text"])
        # NOTE(review): with mlm=False this collator masks pad-token positions
        # out of the labels; if the tokenizer's pad token is its EOS token,
        # EOS positions are masked too — confirm that is acceptable.
        data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)

        # -------------- 2️⃣  Standard TrainingArguments -------------- #
        args = TrainingArguments(**training_args_kwargs)

        # -------------- 3️⃣  Kick off Hugging-Face Trainer ------------- #
        super().__init__(
            model=student_model,
            args=args,
            data_collator=data_collator,
            train_dataset=tokenised["train"],
            eval_dataset=tokenised["test"],
        )

        # Trainer.__init__ fills its own tokenizer slot from its constructor
        # arguments (none is passed above), so anything assigned to
        # `self.tokenizer` beforehand is reset. Keep our reference in a
        # private attribute for the helpers below, and re-publish the public
        # attribute only after the base initialiser has run.
        self._seqkd_tokenizer = tokenizer
        self.tokenizer = tokenizer

    # ----- Convenience API (unchanged from previous class) ----- #
    @torch.no_grad()
    def generate(self, prompt: str, max_new_tokens: int = 40):
        """Return the decoded output of the student model for ``prompt``.

        The whole generated sequence is decoded with special tokens skipped.
        """
        tok = self._seqkd_tokenizer
        ids = tok(prompt, return_tensors="pt").to(
            next(self.model.parameters()).device
        )
        out = self.model.generate(
            **ids,
            max_new_tokens=max_new_tokens,
            pad_token_id=tok.eos_token_id,
        )
        return tok.decode(out[0], skip_special_tokens=True)

    def push_to_hub(self, repo_id: str):
        """Upload the student model and the tokenizer to ``repo_id``.

        NOTE(review): this replaces ``Trainer.push_to_hub`` with a different
        signature, and ``use_auth_token`` is deprecated in recent
        transformers releases in favour of ``token`` — confirm against the
        pinned library version before changing it.
        """
        self.model.push_to_hub(repo_id, use_auth_token=True)
        self._seqkd_tokenizer.push_to_hub(repo_id, use_auth_token=True)

In [None]:
# Teacher and student are both loaded from this checkpoint.
MODEL_ID = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
device = "cuda" if torch.cuda.is_available() else "cpu"

# Teacher and student share the same loading options.
_LOAD_KWARGS = {"load_in_8bit": True, "device_map": "auto"}

# ---------- Tokenizer (EOS doubles as the padding token) ----------
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
tokenizer.pad_token = tokenizer.eos_token

# ---------- Load TEACHER (frozen) ----------
teacher = AutoModelForCausalLM.from_pretrained(MODEL_ID, **_LOAD_KWARGS)

# ---------- Build STUDENT with LoRA ----------
lora_cfg = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=["q_proj", "v_proj"],
    lora_dropout=0.05,
    task_type="CAUSAL_LM",
)
student = AutoModelForCausalLM.from_pretrained(MODEL_ID, **_LOAD_KWARGS)
# Prepare the quantised model for training, then attach the LoRA adapters.
student = get_peft_model(prepare_model_for_kbit_training(student), lora_cfg)

# ---------- Raw prompts ----------
PROMPTS = [
    "### User:\nTranslate 'Good morning' to Spanish.\n### Assistant:\n",
    "### User:\nSummarise: 'The cat sat on the mat.'\n### Assistant:\n",
    "### User:\nList three primary colours.\n### Assistant:\n",
    "### User:\nWhat is 2 + 2?\n### Assistant:\n",
    "### User:\nRewrite 'I like apples' in the past tense.\n### Assistant:\n",
]

In [None]:
# ---------- SeqKD fine-tuner ----------
# Everything in this dict is forwarded to TrainingArguments by SeqKDFineTuner.
training_kwargs = {
    "output_dir": "./seqkd_out",
    "num_train_epochs": 1,
    "per_device_train_batch_size": 4,
    "learning_rate": 2e-4,
    "logging_steps": 1,
    "optim": "paged_adamw_8bit",
    "report_to": [],
}
seqkd_ft = SeqKDFineTuner(
    teacher_model=teacher,
    student_model=student,
    tokenizer=tokenizer,
    raw_prompts=PROMPTS,
    **training_kwargs,
)

# ---------- Train & test ----------
seqkd_ft.train()

# Query the fine-tuned student with a test prompt.
test_prompt = "### User:\nWhat is 2 + 2?\n### Assistant:\n"
print("\n🟢  Student after SeqKD:")
student_reply = seqkd_ft.generate(test_prompt)
print(student_reply)