In [4]:
!pip install -U bitsandbytes

Collecting bitsandbytes
  Downloading bitsandbytes-0.47.0-py3-none-manylinux_2_24_x86_64.whl.metadata (11 kB)
Downloading bitsandbytes-0.47.0-py3-none-manylinux_2_24_x86_64.whl (61.3 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m61.3/61.3 MB[0m [31m13.9 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: bitsandbytes
Successfully installed bitsandbytes-0.47.0


In [1]:
import os, math, random, warnings
warnings.filterwarnings("ignore")

import torch
import torch.nn as nn
import torch.nn.functional as F
from datasets import load_dataset, Dataset
from transformers import (AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig,
                          Trainer, TrainingArguments)

In [25]:
device = "cuda" if torch.cuda.is_available() else "cpu"
cap = torch.cuda.get_device_capability(0)[0] if torch.cuda.is_available() else 0
CAN_BF16 = torch.cuda.is_available() and cap >= 8  # Ampere+ supports bf16
print(f"Device: {device}, CC Major: {cap}, bf16: {CAN_BF16}")

Device: cuda, CC Major: 7, bf16: False


In [26]:
FREE = True  # set False if you have Pro-tier GPU (A100/H100/etc.)

if FREE:
    TEACHER_MODEL = "gpt2"
    STUDENT_MODEL = "distilgpt2"
    LOAD_IN_4BIT  = False                # 8-bit teacher for T4
    BLOCK_SIZE    = 256
    TRAIN_TOKENS  = 200_000              # small demo
    VAL_TOKENS    = 20_000
    BATCH_SIZE    = 2
    GRAD_ACCUM    = 8
    EPOCHS        = 1
else:
    TEACHER_MODEL = "mistralai/Mistral-7B-Instruct-v0.3"
    STUDENT_MODEL = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
    LOAD_IN_4BIT  = True                 # 4-bit teacher
    BLOCK_SIZE    = 512
    TRAIN_TOKENS  = 500_000
    VAL_TOKENS    = 50_000
    BATCH_SIZE    = 1
    GRAD_ACCUM    = 16
    EPOCHS        = 1

In [27]:
LR     = 2e-4
T      = 2.0     # KD temperature
ALPHA  = 0.2     # CE weight
BETA   = 0.8     # KD weight
OUT_DIR = "kd-out"

In [28]:
if LOAD_IN_4BIT:
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16 if CAN_BF16 else torch.float16
    )
else:
    bnb_config = BitsAndBytesConfig(load_in_8bit=True)

In [29]:
# Teacher (frozen)
teacher_tokenizer = AutoTokenizer.from_pretrained(TEACHER_MODEL, use_fast=True)
teacher_tokenizer.padding_side = "left"
if teacher_tokenizer.pad_token is None:
    teacher_tokenizer.pad_token = teacher_tokenizer.eos_token

teacher = AutoModelForCausalLM.from_pretrained(
    TEACHER_MODEL,
    quantization_config=bnb_config,
    device_map="auto"
)
teacher.eval()
for p in teacher.parameters():
    p.requires_grad_(False)

In [30]:
# Student
student_tokenizer = AutoTokenizer.from_pretrained(STUDENT_MODEL, use_fast=True)
student_tokenizer.padding_side = "left"
if student_tokenizer.pad_token is None:
    student_tokenizer.pad_token = student_tokenizer.eos_token

student = AutoModelForCausalLM.from_pretrained(
    STUDENT_MODEL,
    torch_dtype=(torch.bfloat16 if CAN_BF16 else (torch.float16 if torch.cuda.is_available() else torch.float32)),
    device_map="auto"
)
student.config.use_cache = False  # safer for training

In [31]:
from peft import LoraConfig, get_peft_model

def pick_lora_targets_for_decoder(model: nn.Module):
    names = [n for n, m in model.named_modules() if isinstance(m, nn.Linear)]
    if any("q_proj" in n for n in names):  # LLaMA/Mistral
        return dict(targets=["q_proj","k_proj","v_proj","o_proj","gate_proj","up_proj","down_proj"],
                    fan_in_fan_out=False)
    if any("c_attn" in n for n in names):  # GPT-2 / DistilGPT2
        return dict(targets=["c_attn","c_fc","c_proj"], fan_in_fan_out=True)
    uniq = list({n.split(".")[-1] for n in names})
    return dict(targets=uniq, fan_in_fan_out=False)

In [32]:
cfg = pick_lora_targets_for_decoder(student)
print("LoRA targets:", cfg)

LoRA targets: {'targets': ['lm_head'], 'fan_in_fan_out': False}


In [33]:
lora_cfg = LoraConfig(
    r=8, lora_alpha=16, lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=cfg["targets"],
    fan_in_fan_out=cfg["fan_in_fan_out"]
)

In [34]:
student = get_peft_model(student, lora_cfg)
student.print_trainable_parameters()

trainable params: 408,200 || all params: 82,320,776 || trainable%: 0.4959


In [35]:
def build_causal_dataset_streaming(texts, tokenizer, block_size=256, max_tokens=200_000):
    buffer, chunks, produced = [], [], 0
    for t in texts:
        t = (t or "").strip()
        if not t:
            continue
        ids = tokenizer.encode(t + "\n", add_special_tokens=False)
        buffer.extend(ids)
        while len(buffer) >= block_size:
            chunk = buffer[:block_size]
            buffer = buffer[block_size:]
            chunks.append({"input_ids": chunk, "labels": chunk.copy()})
            produced += block_size
            if produced >= max_tokens:
                return Dataset.from_list(chunks)
    return Dataset.from_list(chunks)

In [36]:
raw = load_dataset("wikitext", "wikitext-2-raw-v1")

In [37]:
train_ds = build_causal_dataset_streaming(raw["train"]["text"], student_tokenizer,
                                          block_size=BLOCK_SIZE, max_tokens=TRAIN_TOKENS)
val_ds   = build_causal_dataset_streaming(raw["validation"]["text"], student_tokenizer,
                                          block_size=BLOCK_SIZE, max_tokens=VAL_TOKENS)

In [38]:
print("Train samples:", len(train_ds), " Val samples:", len(val_ds))

Train samples: 782  Val samples: 79


In [39]:
def collate_fn(batch):
    max_len = max(len(x["input_ids"]) for x in batch)
    pad_id = student_tokenizer.pad_token_id
    input_ids, labels, attn = [], [], []
    for x in batch:
        ids = x["input_ids"]
        pad = [pad_id] * (max_len - len(ids))
        input_ids.append(ids + pad)
        labels.append(x["labels"] + [-100] * (max_len - len(ids)))
        attn.append([1]*len(ids) + [0]*len(pad))
    return {
        "input_ids": torch.tensor(input_ids),
        "labels": torch.tensor(labels),
        "attention_mask": torch.tensor(attn),
    }

In [40]:
def kd_loss(student_logits, teacher_logits, labels, T=2.0, alpha=0.2, beta=0.8):
    ce = F.cross_entropy(
        student_logits.view(-1, student_logits.size(-1)),
        labels.view(-1),
        ignore_index=-100
    )
    s = F.log_softmax(student_logits / T, dim=-1)
    with torch.no_grad():
        t = F.softmax(teacher_logits / T, dim=-1)
    kl = F.kl_div(s, t, reduction="batchmean") * (T**2)
    return alpha * ce + beta * kl

In [41]:
class KDTrainer(Trainer):
    # Accept new HF kwarg: num_items_in_batch
    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        labels = inputs["labels"]
        outputs_s = model(input_ids=inputs["input_ids"],
                          attention_mask=inputs["attention_mask"])
        student_logits = outputs_s.logits

        with torch.no_grad():
            outputs_t = teacher(
                input_ids=inputs["input_ids"].to(teacher.device),
                attention_mask=inputs["attention_mask"].to(teacher.device)
            )
            teacher_logits = outputs_t.logits.to(student_logits.device)

        loss = kd_loss(student_logits, teacher_logits, labels, T=T, alpha=ALPHA, beta=BETA)
        return (loss, {"logits": student_logits}) if return_outputs else loss

In [43]:
args = TrainingArguments(
    output_dir=OUT_DIR,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    gradient_accumulation_steps=GRAD_ACCUM,
    learning_rate=LR,
    num_train_epochs=EPOCHS,
    logging_steps=25,
    eval_strategy="steps",   # (fix) correct arg name
    eval_steps=200,
    save_steps=200,
    save_total_limit=2,
    bf16=CAN_BF16,
    fp16=(torch.cuda.is_available() and not CAN_BF16),
    gradient_checkpointing=False,  # not needed for small student; set True if you want
    report_to="none"
)

In [44]:
trainer = KDTrainer(
    model=student,
    args=args,
    train_dataset=train_ds,
    eval_dataset=val_ds,
    data_collator=collate_fn
)


In [45]:
trainer.train()

Step,Training Loss,Validation Loss


TrainOutput(global_step=49, training_loss=962.4593431122449, metrics={'train_runtime': 44.546, 'train_samples_per_second': 17.555, 'train_steps_per_second': 1.1, 'total_flos': 51573824987136.0, 'train_loss': 962.4593431122449, 'epoch': 1.0})

In [46]:
eval_res = trainer.evaluate()
try:
    ppl = math.exp(eval_res["eval_loss"])
except OverflowError:
    ppl = float("inf")
print({"eval_loss": eval_res["eval_loss"], "perplexity": ppl})

{'eval_loss': 89.7918472290039, 'perplexity': 9.910687251491893e+38}


In [47]:
SAVE_DIR = "kd-student-lora"
student.save_pretrained(SAVE_DIR)
student_tokenizer.save_pretrained(SAVE_DIR)

('kd-student-lora/tokenizer_config.json',
 'kd-student-lora/special_tokens_map.json',
 'kd-student-lora/vocab.json',
 'kd-student-lora/merges.txt',
 'kd-student-lora/added_tokens.json',
 'kd-student-lora/tokenizer.json')

In [48]:
prompt = "In healthcare AI, knowledge distillation helps small models by"

In [49]:
inputs = student_tokenizer(prompt, return_tensors="pt").to(student.device)

In [50]:
student.eval()

PeftModelForCausalLM(
  (base_model): LoraModel(
    (model): GPT2LMHeadModel(
      (transformer): GPT2Model(
        (wte): Embedding(50257, 768)
        (wpe): Embedding(1024, 768)
        (drop): Dropout(p=0.1, inplace=False)
        (h): ModuleList(
          (0-5): 6 x GPT2Block(
            (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
            (attn): GPT2Attention(
              (c_attn): Conv1D(nf=2304, nx=768)
              (c_proj): Conv1D(nf=768, nx=768)
              (attn_dropout): Dropout(p=0.1, inplace=False)
              (resid_dropout): Dropout(p=0.1, inplace=False)
            )
            (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
            (mlp): GPT2MLP(
              (c_fc): Conv1D(nf=3072, nx=768)
              (c_proj): Conv1D(nf=768, nx=3072)
              (act): NewGELUActivation()
              (dropout): Dropout(p=0.1, inplace=False)
            )
          )
        )
        (ln_f): LayerNorm((768,), eps=1e-05, e

In [51]:
with torch.no_grad():
    gen = student.generate(
        **inputs,
        max_new_tokens=120,
        do_sample=True,
        top_p=0.9,
        temperature=0.8,
        repetition_penalty=1.1,
        eos_token_id=student_tokenizer.eos_token_id
    )

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


In [52]:
print(student_tokenizer.decode(gen[0], skip_special_tokens=True))

In healthcare AI, knowledge distillation helps small models by allowing them to perform a variety of tasks.
The researchers used their data on nearly 4 million patients in the first 10 years of clinical trials (2012-2014). In 2012, 6.5 million people who participated in these three studies were eligible for admission to Medicaid or Medicare benefits as part and parcel of this year's enrollment bonus program ($2.0 million) while 17 percent of those with pre-existing conditions got up to 30 percent more health insurance coverage than did adults under age 65. The average number of uninsured was estimated to be between 60 percent and 90 percent lower than it should have been
