# FineTune LLM - SeqKD Distillation

In [None]:
import os
import shutil
import tempfile

import pyarrow as pa
import pyarrow.ipc as pa_ipc
import torch
import torch.nn.functional as F
from datasets import Dataset, IterableDataset, load_dataset
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from transformers import (
    AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig,
    DataCollatorForLanguageModeling, TrainingArguments, Trainer
)

**Note**:
- Adds batched teacher generation and on-disk caching, so memory use is bounded by the generation batch size rather than by the size of the corpus.

**How it works differently from `finetune_llm_distillation_seq_kd.ipynb` & things to watch out for**
- When the datasets already contain a text column (prompt + answer joined) the class skips teacher inference and trains immediately.
- If the datasets contain only a prompt column the teacher is invoked once (batched) to create answers in memory; for very large corpora this may still spike GPU time—see the next class for streaming.
- Loss is the standard causal-LM cross-entropy; there is no KL term, temperature, or compute_loss override.

**How it works**
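- For corpora that do not fit in RAM, use the streaming variant below. It still runs the teacher over every prompt before training starts, so it does not shorten total teacher time. However, it never holds the full prompt set or the generated answers in memory. It: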
  - iterates over the prompt dataset lazily (streaming mode),
  - generates teacher answers in mini-batches,
  - writes each (prompt+answer) row to a temporary Arrow file,
  - reads that file back as a memory-mapped dataset for the student.

In [None]:
"""
Sequence-Level Knowledge Distillation — *streaming* edition
────────────────────────────────────────────────────────────
• Consumes a **streaming** Hugging-Face dataset containing a `"prompt"` column.
• Generates teacher answers **on-the-fly in mini-batches**,
  writes the (prompt+answer) pairs to a temporary Arrow file,
  then memory-maps that file for the student fine-tune.
• Interface mirrors DistillationFineTuner, except it needs a streaming dataset.

Example usage
──────────────
from datasets import load_dataset
prompts = load_dataset("my/billion_prompt_corpus", streaming=True, split="train")

stream_ft = SeqKDStreamingFineTuner(
    teacher_model = teacher,
    student_model = student,
    tokenizer     = tok,
    prompt_stream = prompts,
    output_dir    = "./seqkd_stream_out",
    num_train_epochs = 1,
    per_device_train_batch_size = 8,
    learning_rate = 3e-4,
    logging_steps = 50,
    optim = "paged_adamw_8bit",
)
stream_ft.train()
"""

In [None]:
# ───────────────────────────────────────────────────────────────
# Helper: batch-generate teacher answers
# ───────────────────────────────────────────────────────────────
@torch.no_grad()
def teacher_batch_generate(
    teacher,
    tokenizer: "AutoTokenizer",
    prompts: list[str],
    max_new_tokens: int = 128,
):
    """Generate one teacher answer per prompt and return prompt+answer strings.

    Generation is deterministic (``do_sample=False``) and runs as one batch.

    Args:
        teacher: model exposing ``parameters()`` and ``generate()``; inputs are
            moved to the device of its first parameter.
        tokenizer: tokenizer used to encode ``prompts`` and decode the answers.
            Its ``padding_side`` is temporarily set to ``"left"`` and restored.
        prompts: prompt strings to answer.
        max_new_tokens: generation budget per prompt.

    Returns:
        list[str]: ``prompt + answer`` for each prompt, in input order, where
        ``answer`` is only the newly generated text. ``[]`` for empty input.
    """
    if not prompts:
        return []

    device = next(teacher.parameters()).device

    # Causal LMs continue from the last position of each row, so batched
    # generation needs left padding; restore the caller's setting afterwards.
    original_side = tokenizer.padding_side
    tokenizer.padding_side = "left"
    try:
        inputs = tokenizer(
            prompts, return_tensors="pt", padding=True, truncation=True
        ).to(device)
    finally:
        tokenizer.padding_side = original_side

    pad_id = (
        tokenizer.pad_token_id
        if tokenizer.pad_token_id is not None
        else tokenizer.eos_token_id
    )
    outs = teacher.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        do_sample=False,
        pad_token_id=pad_id,
    )

    # generate() echoes the (padded) prompt tokens before the continuation;
    # keep only the new tokens so the prompt is not duplicated below.
    prompt_len = inputs["input_ids"].shape[1]
    answers = tokenizer.batch_decode(outs[:, prompt_len:], skip_special_tokens=True)
    return [p + a for p, a in zip(prompts, answers)]

# ───────────────────────────────────────────────────────────────
# Main class
# ───────────────────────────────────────────────────────────────
class SeqKDStreamingFineTuner(Trainer):
    """
    Sequence-level KD for very-large prompt sets (streaming, disk-backed).

    The teacher answers every prompt in ``prompt_stream`` (and in
    ``eval_prompts`` if given) in mini-batches of ``batch_size``. The resulting
    prompt+answer strings are appended to Arrow files, loaded back with
    ``Dataset.from_file``, tokenised, and used to fine-tune ``student_model``
    with ``DataCollatorForLanguageModeling(mlm=False)`` (plain causal-LM loss).

    All teacher generation happens inside ``__init__``, before training
    starts. Only memory use is bounded, not the total generation time.

    Constructor intentionally follows DistillationFineTuner signature
    (temperature/alpha removed, prompt_stream added).

    Args:
        teacher_model: model used to generate the answers.
        student_model: model handed to ``Trainer`` as ``model``.
        tokenizer: tokenizer shared by teacher generation and student training.
        prompt_stream: iterable of dicts containing ``prompt_field``.
        eval_prompts: optional iterable like ``prompt_stream``. When omitted
            (or empty), 10 % of the generated training rows are held out
            (seed 42), provided there are at least two rows.
        cache_dir: directory for the Arrow files. It is created if missing.
            When omitted, a temporary directory is used and removed in
            ``__del__``.
        batch_size: teacher-generation mini-batch size (>= 1).
        max_seq_len: truncation length for the student's tokenised text.
        max_new_tokens: teacher generation budget per prompt.
        prompt_field: key holding the prompt string in each row.
        **training_args_kwargs: forwarded to ``TrainingArguments``.

    Raises:
        ValueError: if ``batch_size < 1`` or ``prompt_stream`` yields no rows.
    """

    def __init__(
        self,
        teacher_model,
        student_model,
        tokenizer: AutoTokenizer,
        prompt_stream: IterableDataset,      # iterable of dicts holding `prompt_field`
        eval_prompts: IterableDataset | None = None,
        *,
        cache_dir: str | None = None,
        batch_size: int = 32,
        max_seq_len: int = 256,
        max_new_tokens: int = 128,
        prompt_field: str = "prompt",
        **training_args_kwargs,
    ):
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")

        # Only a directory created here may be deleted by __del__; a
        # user-supplied cache_dir is never removed.
        self._owns_tmp_dir = cache_dir is None
        if cache_dir is None:
            self._tmp_dir = tempfile.mkdtemp(prefix="seqkd_cache_")
        else:
            os.makedirs(cache_dir, exist_ok=True)
            self._tmp_dir = cache_dir

        # Trainer.__init__ (called below without a tokenizer) resets
        # ``self.tokenizer``, so keep a private handle for generate()/push_to_hub().
        self._kd_tokenizer = tokenizer

        # 1️⃣  Stream prompts → Arrow file (all teacher generation happens here)
        arrow_path = os.path.join(self._tmp_dir, "train.arrow")
        n_train = self._stream_to_arrow(
            arrow_path,
            teacher_model,
            tokenizer,
            prompt_stream,
            batch_size,
            max_new_tokens,
            prompt_field,
        )
        if n_train == 0:
            raise ValueError(
                f"prompt_stream yielded no '{prompt_field}' rows; nothing to distil."
            )
        ds_train = Dataset.from_file(arrow_path)

        # Build the evaluation synthetic set if provided; an empty one counts as absent.
        ds_eval = None
        if eval_prompts is not None:
            eval_arrow = os.path.join(self._tmp_dir, "eval.arrow")
            n_eval = self._stream_to_arrow(
                eval_arrow,
                teacher_model,
                tokenizer,
                eval_prompts,
                batch_size,
                max_new_tokens,
                prompt_field,
            )
            if n_eval:
                ds_eval = Dataset.from_file(eval_arrow)

        # Otherwise hold out a random 10 % split; a single row cannot be split.
        if ds_eval is None and n_train > 1:
            ds_train, ds_eval = ds_train.train_test_split(
                test_size=0.1, seed=42
            ).values()

        # 2️⃣  Tokenise in batches. No padding here: the collator pads each
        # batch to its longest row instead of every row to max_seq_len.
        def tok(batch):
            return tokenizer(
                batch["text"],
                truncation=True,
                max_length=max_seq_len,
            )

        ds_train_tok = ds_train.map(tok, batched=True, remove_columns=["text"])
        ds_eval_tok = (
            ds_eval.map(tok, batched=True, remove_columns=["text"])
            if ds_eval is not None
            else None
        )

        # 3️⃣  Init Hugging-Face Trainer
        data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)
        args = TrainingArguments(**training_args_kwargs)

        super().__init__(
            model=student_model,
            args=args,
            data_collator=data_collator,
            train_dataset=ds_train_tok,
            eval_dataset=ds_eval_tok,
        )

    # ─── utility: streaming → arrow ────────────────────────────
    def _stream_to_arrow(
        self,
        arrow_path,
        teacher,
        tokenizer,
        prompt_iter: IterableDataset,
        batch_size,
        max_new_tokens,
        prompt_field,
    ) -> int:
        """Answer every prompt with the teacher and write prompt+answer rows.

        Rows go to a single ``text`` column in ``arrow_path``, written batch by
        batch so that at most ``batch_size`` prompts are buffered in memory.

        Returns:
            int: number of rows written.
        """
        schema = pa.schema([("text", pa.string())])

        def write_batch(writer, prompts):
            texts = teacher_batch_generate(
                teacher, tokenizer, prompts, max_new_tokens=max_new_tokens
            )
            writer.write_table(
                pa.Table.from_arrays([pa.array(texts, type=pa.string())], schema=schema)
            )
            return len(texts)

        n_rows = 0
        # datasets' Dataset.from_file reads the Arrow IPC *stream* format, so
        # write with new_stream (new_file emits the random-access file format).
        with pa_ipc.new_stream(arrow_path, schema) as writer:
            buffer_prompts = []
            for sample in prompt_iter:
                buffer_prompts.append(sample[prompt_field])
                if len(buffer_prompts) == batch_size:
                    n_rows += write_batch(writer, buffer_prompts)
                    buffer_prompts = []
            if buffer_prompts:  # flush the final, partial batch
                n_rows += write_batch(writer, buffer_prompts)
        return n_rows

    # ─── convenience helpers (same API as before) ─────────────
    @torch.no_grad()
    def generate(self, prompt: str, max_new_tokens: int = 40) -> str:
        """Decode ``prompt`` with the trained model; return prompt + continuation."""
        tokenizer = self._kd_tokenizer
        # Training may leave dropout (e.g. LoRA dropout) active; switch it off.
        self.model.eval()
        ids = tokenizer(prompt, return_tensors="pt").to(
            next(self.model.parameters()).device
        )
        out = self.model.generate(
            **ids,
            max_new_tokens=max_new_tokens,
            pad_token_id=tokenizer.eos_token_id,
        )
        return tokenizer.decode(out[0], skip_special_tokens=True)

    def push_to_hub(self, repo_id: str, **hub_kwargs):
        """Push the model and tokenizer to ``repo_id``.

        Extra keyword arguments are forwarded to both ``push_to_hub`` calls.
        The locally stored Hub token is used unless ``token=...`` is passed.
        """
        self.model.push_to_hub(repo_id, **hub_kwargs)
        self._kd_tokenizer.push_to_hub(repo_id, **hub_kwargs)

    # ─── cleanup tmp dir if we created it ─────────────────────
    def __del__(self):
        # Remove only the temporary directory created in __init__.
        if getattr(self, "_owns_tmp_dir", False) and getattr(self, "_tmp_dir", None):
            shutil.rmtree(self._tmp_dir, ignore_errors=True)

In [None]:
MODEL_ID = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
device = "cuda" if torch.cuda.is_available() else "cpu"

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
tokenizer.pad_token = tokenizer.eos_token

# ---------- Load TEACHER (frozen) ----------
teacher = AutoModelForCausalLM.from_pretrained(
    MODEL_ID, load_in_8bit=True, device_map="auto"
)

# ---------- Build STUDENT with LoRA ----------
student = AutoModelForCausalLM.from_pretrained(
    MODEL_ID, load_in_8bit=True, device_map="auto"
)
student = prepare_model_for_kbit_training(student)
lora_cfg = LoraConfig(
    r=16, lora_alpha=32, target_modules=["q_proj", "v_proj"],
    lora_dropout=0.05, task_type="CAUSAL_LM"
)
student = get_peft_model(student, lora_cfg)

# ---------- Raw prompts ----------
PROMPTS = [
    "### User:\nTranslate 'Good morning' to Spanish.\n### Assistant:\n",
    "### User:\nSummarise: 'The cat sat on the mat.'\n### Assistant:\n",
    "### User:\nList three primary colours.\n### Assistant:\n",
    "### User:\nWhat is 2 + 2?\n### Assistant:\n",
    "### User:\nRewrite 'I like apples' in the past tense.\n### Assistant:\n",
]

hf_train_prompts = Dataset() # Hugging Face Dataset objects with a 'prompt' column (plain strings)
hf_val_prompts = Dataset() # Hugging Face Dataset objects with a 'prompt' column (plain strings)

In [None]:
MODEL_ID = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
device   = "cuda" if torch.cuda.is_available() else "cpu"

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
tokenizer.pad_token = tokenizer.eos_token        # required for padding

teacher = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    load_in_8bit=True,
    device_map="auto"
)

student = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    load_in_8bit=True,
    device_map="auto"
)
student = prepare_model_for_kbit_training(student)

lora_cfg = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=["q_proj", "v_proj"],
    lora_dropout=0.05,
    task_type="CAUSAL_LM",
)
student = get_peft_model(student, lora_cfg)

# Must yield dicts with a "prompt" string.
train_stream = load_dataset(
    "my/big_prompt_corpus",
    split="train",
    streaming=True
)  # IterableDataset

# Must yield dicts with a "prompt" string.
val_stream = load_dataset(
    "my/big_prompt_corpus",
    split="validation",
    streaming=True
)  # IterableDataset

In [None]:
# ---------- SeqKD fine-tuner ----------
# The constructor answers every training and validation prompt with the
# teacher and caches the results on disk before it returns.
seqkd_stream_ft = SeqKDStreamingFineTuner(
    teacher_model=teacher,
    student_model=student,
    tokenizer=tokenizer,
    prompt_stream=train_stream,        # streaming training prompts
    eval_prompts=val_stream,           # optional streaming validation prompts
    batch_size=64,                     # teacher-generation mini-batch
    # Remaining keywords are forwarded to TrainingArguments.
    output_dir="./seqkd_stream_large",
    num_train_epochs=3,
    per_device_train_batch_size=8,     # student update batch
    learning_rate=3e-4,
    optim="paged_adamw_8bit",
    logging_steps=100,
)

# ---------- Train & test ----------
seqkd_stream_ft.train()

test_prompt = (
    "### User:\n"
    "What is 2 + 2?\n"
    "### Assistant:\n"
)
print("\n🟢  Student AFTER streaming SeqKD:")
print(seqkd_stream_ft.generate(test_prompt))