# FineTune LLM - SeqKD Distillation

In [None]:
import os, torch, math, torch.nn.functional as F
from datasets import Dataset, disable_caching
from transformers import (
    AutoModelForCausalLM, AutoTokenizer,
    DataCollatorForLanguageModeling, TrainingArguments, Trainer
)
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

# Turn off `datasets` transform caching: `.map()` outputs are recomputed on each
# run and written to a temporary directory (removed when the session ends)
# instead of being persisted in, and reloaded from, the HF cache directory.
disable_caching()

**Note**:
- If you plan to distill from a large prompt set (100K+ prompts), use `SeqKDStreamingFineTuner` in `finetune_llm_distillation_seq_kd_adjusted_realistic.ipynb` instead; it adds batched teacher generation and disk caching so that RAM/VRAM usage stays bounded.

**How it works differently from `finetune_llm_distillation_seq_kd.ipynb` & things to watch out for**
- When the datasets already contain a text column (prompt + answer joined) the class skips teacher inference and trains immediately.
- If the datasets contain only a `prompt` column, the teacher is invoked once (in batches) to create the answers in memory; for very large corpora this can still consume a lot of GPU time and host memory—see the next class for streaming.
- Loss is the standard causal-LM cross-entropy; there is no KL term, temperature, or compute_loss override.

In [None]:
# ───────────────────────────────────────────────────────────────
# Helper: combine prompt + teacher answer
# ───────────────────────────────────────────────────────────────
@torch.no_grad()
def generate_answers(teacher, tokenizer, ds_prompts,
                     prompt_col="prompt",
                     max_new_tokens=128,
                     batch_size=8):
    """Let the teacher answer every prompt and return a SeqKD training set.

    Args:
        teacher: model used for greedy generation. It is switched to eval mode
            and its parameters are frozen in place.
        tokenizer: tokenizer used both to encode the prompts and to decode the
            teacher output.
        ds_prompts: Hugging Face ``Dataset`` holding the prompts.
        prompt_col: name of the column that contains the prompt strings.
        max_new_tokens: generation budget per prompt.
        batch_size: number of prompts sent through the teacher at once.

    Returns:
        A ``Dataset`` with a single ``text`` column where each row is
        ``prompt + teacher continuation``.
    """
    teacher.eval().requires_grad_(False)
    device = next(teacher.parameters()).device

    # Decoder-only models echo the prompt in the output of `generate` and need
    # left padding so that new tokens directly follow the real prompt tokens.
    # Encoder-decoder models return only the continuation, so leave them alone.
    config = getattr(teacher, "config", None)
    decoder_only = not getattr(config, "is_encoder_decoder", False)

    original_padding_side = tokenizer.padding_side
    if decoder_only:
        tokenizer.padding_side = "left"

    rows = []
    try:
        # walk the dataset in mini-batches of `batch_size`
        for start in range(0, len(ds_prompts), batch_size):
            prompts = ds_prompts[start:start + batch_size][prompt_col]
            inputs = tokenizer(
                prompts,
                return_tensors="pt",
                padding=True, truncation=True
            ).to(device)
            outs = teacher.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                do_sample=False,
                pad_token_id=tokenizer.eos_token_id
            )
            if decoder_only:
                # Drop the echoed (padded) prompt; otherwise `p + a` below
                # would contain the prompt twice.
                outs = outs[:, inputs["input_ids"].shape[1]:]
            answers = tokenizer.batch_decode(outs, skip_special_tokens=True)
            rows.extend({"text": p + a} for p, a in zip(prompts, answers))
    finally:
        # Do not leak the padding-side change to the caller's tokenizer.
        tokenizer.padding_side = original_padding_side
    return Dataset.from_list(rows)

# ───────────────────────────────────────────────────────────────
# SeqKD class with DistillationFineTuner-compatible signature
# ───────────────────────────────────────────────────────────────
class SeqKDFineTuner(Trainer):
    """
    Sequence-level KD:
    • if `train_dataset` has column 'text' → uses it directly
    • else expects column `prompt_col` → autogenerates teacher answers

    The same rule is applied independently to `eval_dataset`. When no
    `eval_dataset` is given, 10 % of the training data is held out instead.

    The class does not override `compute_loss`, so the student is trained with
    whatever loss the underlying `Trainer` computes for the student model.

    Args:
        teacher_model: model used to synthesise answers (only touched when a
            dataset lacks a 'text' column).
        student_model: model that is fine-tuned.
        tokenizer: tokenizer used for teacher generation, dataset
            tokenisation and `generate`.
        train_dataset: dataset with a 'text' column or a `prompt_col` column.
        eval_dataset: optional dataset with the same layout rules.
        max_seq_len: every example is truncated / padded to this many tokens.
        prompt_col: name of the prompt column for teacher generation.
        **training_args_kwargs: forwarded verbatim to `TrainingArguments`.
    """

    def __init__(
        self,
        teacher_model,
        student_model,
        tokenizer: AutoTokenizer,
        train_dataset,
        eval_dataset=None,
        max_seq_len: int = 256,
        prompt_col: str = "prompt",
        **training_args_kwargs
    ):
        # 1) Build synthetic dataset (only once)
        if "text" not in train_dataset.column_names:
            print("🛠  Synthesising teacher answers …")
            train_dataset = generate_answers(
                teacher_model, tokenizer, train_dataset,
                prompt_col=prompt_col
            )

        # Checked independently of the train set: a prompt-only eval set must
        # be synthesised even when the train set already ships a 'text' column,
        # otherwise `tok_fn` below fails on the missing column.
        if eval_dataset is not None and "text" not in eval_dataset.column_names:
            eval_dataset = generate_answers(
                teacher_model, tokenizer, eval_dataset,
                prompt_col=prompt_col
            )

        if eval_dataset is None:                       # fallback split
            train_dataset, eval_dataset = train_dataset.train_test_split(
                test_size=0.1, seed=42
            ).values()

        def tok_fn(ex):
            # Fixed-length encoding so every example has `max_seq_len` tokens.
            return tokenizer(
                ex["text"],
                truncation=True,
                padding="max_length",
                max_length=max_seq_len
            )

        # 2) Tokenise; batched mapping hands whole lists of strings to the
        #    tokenizer instead of one example per call.
        train_tok = train_dataset.map(tok_fn, batched=True, remove_columns=["text"])
        eval_tok = eval_dataset.map(tok_fn, batched=True, remove_columns=["text"])

        data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)
        args = TrainingArguments(**training_args_kwargs)

        super().__init__(
            model=student_model,
            args=args,
            data_collator=data_collator,
            train_dataset=train_tok,
            eval_dataset=eval_tok,
        )

        # Must come AFTER `super().__init__()`: the base constructor assigns
        # its own tokenizer attribute (default: None) and would clobber ours,
        # breaking `generate` and `push_to_hub` below.
        self.tokenizer = tokenizer

    @torch.no_grad()
    def generate(self, prompt, max_new_tokens=40):
        """Return the student's decoded output for a single `prompt` string.

        The model is put in eval mode for the call (so dropout is inactive)
        and its previous train/eval mode is restored afterwards.
        """
        model = self.model
        was_training = model.training
        model.eval()
        try:
            ids = self.tokenizer(prompt, return_tensors="pt").to(
                next(model.parameters()).device
            )
            out = model.generate(
                **ids,
                max_new_tokens=max_new_tokens,
                pad_token_id=self.tokenizer.eos_token_id
            )
        finally:
            model.train(was_training)
        return self.tokenizer.decode(out[0], skip_special_tokens=True)

    def push_to_hub(self, repo_id):
        """Upload the student model and the tokenizer to the Hub repo `repo_id`.

        NOTE(review): this replaces `Trainer.push_to_hub` with a different
        signature, and `use_auth_token` is reported as deprecated in favour of
        `token` by recent library versions — confirm against the pinned
        `transformers` version before changing it.
        """
        self.model.push_to_hub(repo_id, use_auth_token=True)
        self.tokenizer.push_to_hub(repo_id, use_auth_token=True)

In [None]:
# Checkpoint used for BOTH teacher and student in this demo.
MODEL_ID = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
device = "cuda" if torch.cuda.is_available() else "cpu"

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
# Reuse EOS as the padding token so batched tokenisation can pad.
tokenizer.pad_token = tokenizer.eos_token

# ---------- Load TEACHER (frozen) ----------
teacher = AutoModelForCausalLM.from_pretrained(
    MODEL_ID, load_in_8bit=True, device_map="auto"
)

# ---------- Build STUDENT with LoRA ----------
student = AutoModelForCausalLM.from_pretrained(
    MODEL_ID, load_in_8bit=True, device_map="auto"
)
student = prepare_model_for_kbit_training(student)
lora_cfg = LoraConfig(
    r=16, lora_alpha=32, target_modules=["q_proj", "v_proj"],
    lora_dropout=0.05, task_type="CAUSAL_LM"
)
student = get_peft_model(student, lora_cfg)

# ---------- Raw prompts ----------
PROMPTS = [
    "### User:\nTranslate 'Good morning' to Spanish.\n### Assistant:\n",
    "### User:\nSummarise: 'The cat sat on the mat.'\n### Assistant:\n",
    "### User:\nList three primary colours.\n### Assistant:\n",
    "### User:\nWhat is 2 + 2?\n### Assistant:\n",
    "### User:\nRewrite 'I like apples' in the past tense.\n### Assistant:\n",
]

# Hugging Face Dataset objects with a 'prompt' column (plain strings).
# `Dataset()` cannot be called without an Arrow table, so build the datasets
# from the demo prompts above (4 train / 1 validation). Swap in your own
# prompt datasets here for real runs.
_prompt_splits = Dataset.from_dict({"prompt": PROMPTS}).train_test_split(
    test_size=0.2, seed=42
)
hf_train_prompts = _prompt_splits["train"]
hf_val_prompts = _prompt_splits["test"]

In [None]:
# ---------- SeqKD fine-tuner ----------
# Hyper-parameters below are handed through to `TrainingArguments` as-is.
training_kwargs = {
    "output_dir": "./seqkd_large",
    "num_train_epochs": 3,
    "per_device_train_batch_size": 8,
    "learning_rate": 3e-4,
    "optim": "paged_adamw_8bit",
}

seqkd_ft = SeqKDFineTuner(
    teacher_model=teacher,
    student_model=student,
    tokenizer=tokenizer,
    train_dataset=hf_train_prompts,
    eval_dataset=hf_val_prompts,
    **training_kwargs,
)


# ---------- Train & test ----------
# Fit the student, then sample one completion as a quick sanity check.
seqkd_ft.train()

test_prompt = "### User:\nWhat is 2 + 2?\n### Assistant:\n"
print("\n🟢  Student after SeqKD:")
print(seqkd_ft.generate(test_prompt))