# Knowledge Distillation Pipeline

This notebook demonstrates knowledge distillation from a teacher LLM to a student LLM using jailbreak prompts. The workflow includes model loading, dataset preparation, tokenization, custom KD loss, training, evaluation, and saving the distilled model.

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from datasets import load_dataset
from tqdm.notebook import tqdm
from safetensors.torch import safe_open
from peft import (
    prepare_model_for_kbit_training,
    LoraConfig,
    get_peft_model
)
from transformers import (
    Trainer,
    TrainingArguments,
    DataCollatorForLanguageModeling
)
import pandas as pd
import os
from tqdm.notebook import tqdm

# Basic settings
device = torch.device("cuda")
num_epochs = 2
batch_size = 8
print("Device:", device)

Device: cuda


In [None]:
# Load student and teacher models using your custom loaders (keeps code modular)
import sys
import os
# Ensure the notebook directory is on sys.path so local modules can be imported
nb_dir = r"f:\T2430392\Efficient-Distilled-miniLLM-to-Prevent-Jailbreak-Attacks\Knowledge_Distillation"
if nb_dir not in sys.path:
    sys.path.insert(0, nb_dir)
from llama_1b import load_llama_1b
from llama_8b import load_llama_8b

student_model, student_tokenizer = load_llama_1b()
teacher_model, teacher_tokenizer = load_llama_8b()

student_model.to(device)
teacher_model.to(device)

# Ensure tokenizer has pad token
if student_tokenizer.pad_token is None:
    student_tokenizer.add_special_tokens({'pad_token': '<pad>'})
student_model.generation_config.pad_token_id = student_tokenizer.pad_token_id

print('Loaded student and teacher models')

tokenizer_config.json:   0%|          | 0.00/54.5k [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/54.5k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/9.09M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/54.5k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/9.09M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/296 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/54.5k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/9.09M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/296 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/877 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/54.5k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/9.09M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/296 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/877 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/2.47G [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/54.5k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/9.09M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/296 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/877 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/2.47G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/189 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/54.5k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/9.09M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/296 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/877 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/2.47G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/189 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/55.4k [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/54.5k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/9.09M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/296 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/877 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/2.47G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/189 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/55.4k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/9.09M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/54.5k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/9.09M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/296 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/877 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/2.47G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/189 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/55.4k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/9.09M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/296 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/54.5k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/9.09M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/296 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/877 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/2.47G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/189 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/55.4k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/9.09M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/296 [00:00<?, ?B/s]

✅ CUDA available – loading LLaMA 8B in 8-bit on GPU


tokenizer_config.json:   0%|          | 0.00/54.5k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/9.09M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/296 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/877 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/2.47G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/189 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/55.4k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/9.09M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/296 [00:00<?, ?B/s]

✅ CUDA available – loading LLaMA 8B in 8-bit on GPU


config.json:   0%|          | 0.00/855 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/54.5k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/9.09M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/296 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/877 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/2.47G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/189 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/55.4k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/9.09M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/296 [00:00<?, ?B/s]

✅ CUDA available – loading LLaMA 8B in 8-bit on GPU


config.json:   0%|          | 0.00/855 [00:00<?, ?B/s]

model.safetensors.index.json:   0%|          | 0.00/23.9k [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/54.5k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/9.09M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/296 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/877 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/2.47G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/189 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/55.4k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/9.09M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/296 [00:00<?, ?B/s]

✅ CUDA available – loading LLaMA 8B in 8-bit on GPU


config.json:   0%|          | 0.00/855 [00:00<?, ?B/s]

model.safetensors.index.json:   0%|          | 0.00/23.9k [00:00<?, ?B/s]

Fetching 4 files:   0%|          | 0/4 [00:00<?, ?it/s]

tokenizer_config.json:   0%|          | 0.00/54.5k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/9.09M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/296 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/877 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/2.47G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/189 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/55.4k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/9.09M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/296 [00:00<?, ?B/s]

✅ CUDA available – loading LLaMA 8B in 8-bit on GPU


config.json:   0%|          | 0.00/855 [00:00<?, ?B/s]

model.safetensors.index.json:   0%|          | 0.00/23.9k [00:00<?, ?B/s]

Fetching 4 files:   0%|          | 0/4 [00:00<?, ?it/s]

model-00001-of-00004.safetensors:   0%|          | 0.00/4.98G [00:00<?, ?B/s]

model-00003-of-00004.safetensors:   0%|          | 0.00/4.92G [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/54.5k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/9.09M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/296 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/877 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/2.47G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/189 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/55.4k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/9.09M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/296 [00:00<?, ?B/s]

✅ CUDA available – loading LLaMA 8B in 8-bit on GPU


config.json:   0%|          | 0.00/855 [00:00<?, ?B/s]

model.safetensors.index.json:   0%|          | 0.00/23.9k [00:00<?, ?B/s]

Fetching 4 files:   0%|          | 0/4 [00:00<?, ?it/s]

model-00001-of-00004.safetensors:   0%|          | 0.00/4.98G [00:00<?, ?B/s]

model-00003-of-00004.safetensors:   0%|          | 0.00/4.92G [00:00<?, ?B/s]

model-00004-of-00004.safetensors:   0%|          | 0.00/1.17G [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/54.5k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/9.09M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/296 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/877 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/2.47G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/189 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/55.4k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/9.09M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/296 [00:00<?, ?B/s]

✅ CUDA available – loading LLaMA 8B in 8-bit on GPU


config.json:   0%|          | 0.00/855 [00:00<?, ?B/s]

model.safetensors.index.json:   0%|          | 0.00/23.9k [00:00<?, ?B/s]

Fetching 4 files:   0%|          | 0/4 [00:00<?, ?it/s]

model-00001-of-00004.safetensors:   0%|          | 0.00/4.98G [00:00<?, ?B/s]

model-00003-of-00004.safetensors:   0%|          | 0.00/4.92G [00:00<?, ?B/s]

model-00004-of-00004.safetensors:   0%|          | 0.00/1.17G [00:00<?, ?B/s]

model-00002-of-00004.safetensors:   0%|          | 0.00/5.00G [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/54.5k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/9.09M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/296 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/877 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/2.47G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/189 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/55.4k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/9.09M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/296 [00:00<?, ?B/s]

✅ CUDA available – loading LLaMA 8B in 8-bit on GPU


config.json:   0%|          | 0.00/855 [00:00<?, ?B/s]

model.safetensors.index.json:   0%|          | 0.00/23.9k [00:00<?, ?B/s]

Fetching 4 files:   0%|          | 0/4 [00:00<?, ?it/s]

model-00001-of-00004.safetensors:   0%|          | 0.00/4.98G [00:00<?, ?B/s]

model-00003-of-00004.safetensors:   0%|          | 0.00/4.92G [00:00<?, ?B/s]

model-00004-of-00004.safetensors:   0%|          | 0.00/1.17G [00:00<?, ?B/s]

model-00002-of-00004.safetensors:   0%|          | 0.00/5.00G [00:00<?, ?B/s]

ValueError: Some modules are dispatched on the CPU or the disk. Make sure you have enough GPU RAM to fit the quantized model. If you want to dispatch the model on the CPU or the disk while keeping these modules in 32-bit, you need to set `llm_int8_enable_fp32_cpu_offload=True` and pass a custom `device_map` to `from_pretrained`. Check https://huggingface.co/docs/transformers/main/en/main_classes/quantization#offload-between-cpu-and-gpu for more details. 

In [20]:
import math
import os
import torch
import logging
from torch.optim import AdamW
from torch.utils.data import DataLoader
from transformers import get_scheduler
from tqdm.auto import tqdm
import torch.nn.functional as F

# Optional: TensorBoard logging
use_tensorboard = True
try:
    if use_tensorboard:
        from torch.utils.tensorboard import SummaryWriter
        tb_writer = SummaryWriter(log_dir="./runs/distillation")
except Exception:
    use_tensorboard = False
    tb_writer = None

# Basic logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("distillation")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logger.info(f"Using device: {device}")

# Ensure models on device
teacher_model = teacher_model.to(device)
teacher_model.eval()
for p in teacher_model.parameters():
    p.requires_grad = False

student_model = student_model.to(device)
student_model.train()

# ---------- Hyperparams ----------
train_batch_size = 1
gradient_accumulation_steps = 4
num_epochs = 1
learning_rate = 5e-5
warmup_steps = 50
max_grad_norm = 1.0
save_every_n_updates = 50  # save checkpoint every N optimizer updates
output_dir = "./distilled-llama1b-jailbreak-final"
os.makedirs(output_dir, exist_ok=True)
# ---------------------------------

# Collate function (robust to tensors/numpy/lists)
def collate_fn(batch):
    collated = {}
    keys = set().union(*[set(item.keys()) for item in batch])
    for k in keys:
        vals = []
        for item in batch:
            v = item.get(k)
            # if already tensor, use it; otherwise convert
            if isinstance(v, torch.Tensor):
                vals.append(v)
            else:
                vals.append(torch.tensor(v))
        collated[k] = torch.stack(vals)
    # replace pad token id in labels with -100 so CE ignores it
    if 'labels' in collated:
        pad_token_id = student_tokenizer.pad_token_id if student_tokenizer.pad_token_id is not None else student_tokenizer.convert_tokens_to_ids('[PAD]')
        collated['labels'] = collated['labels'].clone()
        collated['labels'][collated['labels'] == pad_token_id] = -100
    return collated

dataset = load_dataset("TrustAIRLab/in-the-wild-jailbreak-prompts", "jailbreak_2023_12_25")
logger.info(f"Dataset loaded. Split train size: {len(dataset['train'])}")

# ---------------- Tokenization ----------------
def preprocess_batch(batch):
    inputs = student_tokenizer(batch["prompt"], truncation=True, padding="max_length", max_length=256, return_tensors="pt")
    inputs["labels"] = inputs["input_ids"].clone()
    return inputs

tokenized_ds = dataset["train"].map(preprocess_batch, batched=True, remove_columns=dataset["train"].column_names)

train_loader = DataLoader(tokenized_ds, batch_size=train_batch_size, shuffle=True, collate_fn=collate_fn)

# Optimizer: only parameters that require grad (LoRA adapters)
optimizer = AdamW(filter(lambda p: p.requires_grad, student_model.parameters()), lr=learning_rate)

# Scheduler
num_update_steps_per_epoch = math.ceil(len(train_loader) / gradient_accumulation_steps)
num_training_steps = num_epochs * num_update_steps_per_epoch
lr_scheduler = get_scheduler(
    name="linear",
    optimizer=optimizer,
    num_warmup_steps=warmup_steps,
    num_training_steps=num_training_steps,
)

scaler = torch.cuda.amp.GradScaler(enabled=True)  # FP16 AMP

# Distillation loss that returns components (masked, stable)
def distillation_loss_components(student_logits, teacher_logits, labels, alpha=0.5, T=2.0, clamp_val=30.0):
    """
    Return (total_loss, soft_loss, hard_loss)
    Computes KL(teacher || student) and CE(student, labels) only over positions
    where labels != -100 (i.e., non-padding). This avoids NaNs when batches
    contain only padding labels and improves numerical stability.
    """
    # cast and clamp for numerical stability
    student_logits = student_logits.float().clamp(-clamp_val, clamp_val)
    teacher_logits = teacher_logits.float().clamp(-clamp_val, clamp_val)

    # flatten to (batch*seq, vocab)
    bsz, seq_len, vocab_size = student_logits.size()
    s_flat = student_logits.view(-1, vocab_size)
    t_flat = teacher_logits.view(-1, vocab_size)
    labels_flat = labels.view(-1)

    # mask valid positions
    valid_mask = labels_flat != -100
    valid_count = int(valid_mask.sum().item())

    if valid_count == 0:
        # no supervised tokens in this batch
        hard_loss = torch.tensor(0.0, device=student_logits.device)
        soft_loss = torch.tensor(0.0, device=student_logits.device)
    else:
        s_sel = s_flat[valid_mask]
        t_sel = t_flat[valid_mask]
        lbl_sel = labels_flat[valid_mask]

        # hard supervised loss
        hard_loss = F.cross_entropy(s_sel, lbl_sel, reduction='mean')

        # soft KL loss
        s_log_probs = F.log_softmax(s_sel / T, dim=-1)
        t_probs = F.softmax(t_sel / T, dim=-1)
        soft_loss = F.kl_div(s_log_probs, t_probs, reduction='batchmean') * (T * T)

    total_loss = alpha * soft_loss + (1.0 - alpha) * hard_loss
    # final safety guard
    if not torch.isfinite(total_loss):
        total_loss = torch.tensor(1e6, device=student_logits.device)
    return total_loss, soft_loss, hard_loss

# Training loop with logging
global_update = 0  # counts optimizer.step calls
pbar = tqdm(total=num_training_steps, desc="Distillation")

for epoch in range(num_epochs):
    student_model.train()
    epoch_loss_sum = 0.0
    epoch_soft_sum = 0.0
    epoch_hard_sum = 0.0
    optimizer.zero_grad()

    for step, batch in enumerate(train_loader):
        # move tensors to device
        batch = {k: v.to(device) for k, v in batch.items() if torch.is_tensor(v)}
        labels = batch.get("labels")

        # build model inputs (some tokenizers use input_ids + attention_mask)
        model_inputs = {}
        if "input_ids" in batch:
            model_inputs["input_ids"] = batch["input_ids"]
        if "attention_mask" in batch:
            model_inputs["attention_mask"] = batch["attention_mask"]

        try:
            with torch.no_grad():
                # student forward
                student_outputs = student_model(**model_inputs)
                student_logits = student_outputs.logits

                # teacher forward (frozen)
                with torch.no_grad():
                    teacher_outputs = teacher_model(**model_inputs)
                    teacher_logits = teacher_outputs.logits

                loss, soft_loss, hard_loss = distillation_loss_components(student_logits, teacher_logits, labels)
                loss = loss / gradient_accumulation_steps  # scale down for accumulation

            scaler.scale(loss).backward()

            # book-keeping sums (unscaled)
            epoch_loss_sum += (loss.item() * gradient_accumulation_steps)
            epoch_soft_sum += (soft_loss.item())
            epoch_hard_sum += (hard_loss.item())

        except RuntimeError as e:
            # OOM handling: clear gradients, skip this batch
            if "out of memory" in str(e):
                logger.warning("OOM on step %d, clearing gradients and skipping step", step)
                scaler.scale(torch.tensor(0.0, device=device)).backward()  # safe no-op
                optimizer.zero_grad()
                torch.cuda.empty_cache()
                continue
            else:
                raise

        # optimizer step when accumulated enough gradients
        if (step + 1) % gradient_accumulation_steps == 0:
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(filter(lambda p: p.requires_grad, student_model.parameters()), max_grad_norm)
            scaler.step(optimizer)
            scaler.update()
            lr_scheduler.step()
            optimizer.zero_grad()
            global_update += 1
            pbar.update(1)

            # Logging: print every update
            avg_loss = epoch_loss_sum / ((step + 1) if (step + 1) > 0 else 1)
            logger.info(f"Epoch {epoch+1} | Update {global_update}/{num_training_steps} | total_loss(avg_so_far) = {avg_loss:.4f} | soft = {epoch_soft_sum/global_update:.4f} | hard = {epoch_hard_sum/global_update:.4f} | lr = {lr_scheduler.get_last_lr()[0]:.2e}")

            # TensorBoard logging
            if use_tensorboard and tb_writer is not None:
                tb_writer.add_scalar("loss/total", avg_loss, global_update)
                tb_writer.add_scalar("loss/soft", epoch_soft_sum/global_update, global_update)
                tb_writer.add_scalar("loss/hard", epoch_hard_sum/global_update, global_update)
                tb_writer.add_scalar("lr", lr_scheduler.get_last_lr()[0], global_update)

            # Save a checkpoint / adapters periodically
            if global_update % save_every_n_updates == 0:
                ckpt_dir = os.path.join(output_dir, f"checkpoint-upd-{global_update}")
                logger.info("Saving checkpoint to %s", ckpt_dir)
                os.makedirs(ckpt_dir, exist_ok=True)
                # Save only PEFT adapters (lightweight) if model supports it
                try:
                    student_model.save_pretrained(ckpt_dir)
                except Exception as e:
                    logger.warning("Failed to save with save_pretrained(): %s. Trying torch.save on state_dict()", e)
                    torch.save(student_model.state_dict(), os.path.join(ckpt_dir, "student_state_dict.pt"))

    # epoch summary
    avg_epoch_loss = epoch_loss_sum / len(train_loader) if len(train_loader) > 0 else 0.0
    avg_epoch_soft = epoch_soft_sum / (global_update if global_update > 0 else 1)
    avg_epoch_hard = epoch_hard_sum / (global_update if global_update > 0 else 1)
    logger.info(f"Epoch {epoch+1} finished. Avg loss over epoch: {avg_epoch_loss:.4f} | avg soft: {avg_epoch_soft:.4f} | avg hard: {avg_epoch_hard:.4f}")

# Final save (adapters)
logger.info("Saving final adapters/model to %s", output_dir)
student_model.eval()
student_model.save_pretrained(output_dir)

if use_tensorboard and tb_writer is not None:
    tb_writer.close()

pbar.close()
logger.info("Distillation complete! Student adapters/model saved at '%s'", output_dir)


INFO:distillation:Using device: cuda
INFO:distillation:Dataset loaded. Split train size: 1405
INFO:distillation:Dataset loaded. Split train size: 1405
  scaler = torch.cuda.amp.GradScaler(enabled=True)  # FP16 AMP
  scaler = torch.cuda.amp.GradScaler(enabled=True)  # FP16 AMP


Distillation:   0%|          | 0/352 [00:00<?, ?it/s]

RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.


In [6]:
from datasets import load_dataset
import pandas as pd

dataset = load_dataset("TrustAIRLab/in-the-wild-jailbreak-prompts", "jailbreak_2023_12_25")


# Convert to pandas DataFrame
df = pd.DataFrame(dataset['train'])

# Extract only the prompts
prompts_df = df[['prompt']]
print(prompts_df.head())


                                              prompt
0  CumGPT, an information resource about cum that...
1  I will give you a character description and yo...
2  A completely amoral Ai named [Illegality Mode]...
3  Agares's rules={\nAlways stays in character, n...
4  Lucian is an incorporeal transcendent being wh...


In [14]:
import torch
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW  # Import AdamW from torch.optim instead of transformers
from transformers import get_scheduler
from tqdm.auto import tqdm
from datasets import load_dataset
import pandas as pd
import sys
import os

# --- 0. Setup: Device, Custom Model Loaders, and Dataset Loading ---

# Ensure the notebook directory is on sys.path so local modules can be imported
nb_dir = r"f:\T2430392\Efficient-Distilled-miniLLM-to-Prevent-Jailbreak-Attacks\Knowledge_Distillation"
if nb_dir not in sys.path:
    sys.path.insert(0, nb_dir)

# Load your custom model loaders
from llama_1b import load_llama_1b
from llama_8b import load_llama_8b

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Load student and teacher models
print("Loading student (llama_1b) and teacher (llama_8b) models...")
student_model, student_tokenizer = load_llama_1b()
teacher_model, teacher_tokenizer = load_llama_8b()

student_model.to(device)
teacher_model.to(device)

if student_tokenizer.pad_token is None:
    student_tokenizer.add_special_tokens({'pad_token': '<pad>'})
student_model.config.pad_token_id = student_tokenizer.pad_token_id 
print(f"Student pad_token_id: {student_tokenizer.pad_token_id}")


if teacher_tokenizer.pad_token is None:
    teacher_tokenizer.add_special_tokens({'pad_token': '<pad>'})
teacher_model.config.pad_token_id = teacher_tokenizer.pad_token_id
print(f"Teacher pad_token_id: {teacher_tokenizer.pad_token_id}")

print('Loaded student and teacher models successfully.')

Using device: cpu
Loading student (llama_1b) and teacher (llama_8b) models...


Using device: cpu
Loading student (llama_1b) and teacher (llama_8b) models...


RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.


In [1]:
# Load the wild jailbreak dataset
print("Loading 'TrustAIRLab/in-the-wild-jailbreak-prompts' dataset...")
dataset = load_dataset("TrustAIRLab/in-the-wild-jailbreak-prompts", "jailbreak_2023_12_25")
df = pd.DataFrame(dataset['train'])
jailbreak_prompts = df['prompt'].tolist() # Extract only the prompts as a list


`torch_dtype` is deprecated! Use `dtype` instead!


`torch_dtype` is deprecated! Use `dtype` instead!


Model and tokenizer loaded successfully
Using device: cuda
✅ CUDA available – loading LLaMA 8B in 4-bit on GPU
✅ CUDA available – loading LLaMA 8B in 4-bit on GPU


`torch_dtype` is deprecated! Use `dtype` instead!


Model and tokenizer loaded successfully
Using device: cuda
✅ CUDA available – loading LLaMA 8B in 4-bit on GPU
✅ CUDA available – loading LLaMA 8B in 4-bit on GPU


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

`torch_dtype` is deprecated! Use `dtype` instead!


Model and tokenizer loaded successfully
Using device: cuda
✅ CUDA available – loading LLaMA 8B in 4-bit on GPU
✅ CUDA available – loading LLaMA 8B in 4-bit on GPU


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

The new embeddings will be initialized from a multivariate normal distribution that has old embeddings' mean and covariance. As described in this article: https://nlp.stanford.edu/~johnhew/vocab-expansion.html. To disable this, use `mean_resizing=False`


`torch_dtype` is deprecated! Use `dtype` instead!


Model and tokenizer loaded successfully
Using device: cuda
✅ CUDA available – loading LLaMA 8B in 4-bit on GPU
✅ CUDA available – loading LLaMA 8B in 4-bit on GPU


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

The new embeddings will be initialized from a multivariate normal distribution that has old embeddings' mean and covariance. As described in this article: https://nlp.stanford.edu/~johnhew/vocab-expansion.html. To disable this, use `mean_resizing=False`


Loaded 1405 jailbreak prompts.
Distillation DataLoader created with 50 batches


Distillation Training:   0%|          | 0/50 [00:00<?, ?it/s]

`torch_dtype` is deprecated! Use `dtype` instead!


Model and tokenizer loaded successfully
Using device: cuda
✅ CUDA available – loading LLaMA 8B in 4-bit on GPU
✅ CUDA available – loading LLaMA 8B in 4-bit on GPU


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

The new embeddings will be initialized from a multivariate normal distribution that has old embeddings' mean and covariance. As described in this article: https://nlp.stanford.edu/~johnhew/vocab-expansion.html. To disable this, use `mean_resizing=False`


Loaded 1405 jailbreak prompts.
Distillation DataLoader created with 50 batches


Distillation Training:   0%|          | 0/50 [00:00<?, ?it/s]

Error processing batch 0: The size of tensor a (128256) must match the size of tensor b (128257) at non-singleton dimension 1
Error processing batch 1: The size of tensor a (128256) must match the size of tensor b (128257) at non-singleton dimension 1
Error processing batch 2: The size of tensor a (128256) must match the size of tensor b (128257) at non-singleton dimension 1
Error processing batch 3: The size of tensor a (128256) must match the size of tensor b (128257) at non-singleton dimension 1
Error processing batch 4: The size of tensor a (128256) must match the size of tensor b (128257) at non-singleton dimension 1
Error processing batch 5: The size of tensor a (128256) must match the size of tensor b (128257) at non-singleton dimension 1
Error processing batch 3: The size of tensor a (128256) must match the size of tensor b (128257) at non-singleton dimension 1
Error processing batch 4: The size of tensor a (128256) must match the size of tensor b (128257) at non-singleton dime

`torch_dtype` is deprecated! Use `dtype` instead!


Model and tokenizer loaded successfully
Using device: cuda
✅ CUDA available – loading LLaMA 8B in 4-bit on GPU
✅ CUDA available – loading LLaMA 8B in 4-bit on GPU


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

The new embeddings will be initialized from a multivariate normal distribution that has old embeddings' mean and covariance. As described in this article: https://nlp.stanford.edu/~johnhew/vocab-expansion.html. To disable this, use `mean_resizing=False`


Loaded 1405 jailbreak prompts.
Distillation DataLoader created with 50 batches


Distillation Training:   0%|          | 0/50 [00:00<?, ?it/s]

Error processing batch 0: The size of tensor a (128256) must match the size of tensor b (128257) at non-singleton dimension 1
Error processing batch 1: The size of tensor a (128256) must match the size of tensor b (128257) at non-singleton dimension 1
Error processing batch 2: The size of tensor a (128256) must match the size of tensor b (128257) at non-singleton dimension 1
Error processing batch 3: The size of tensor a (128256) must match the size of tensor b (128257) at non-singleton dimension 1
Error processing batch 4: The size of tensor a (128256) must match the size of tensor b (128257) at non-singleton dimension 1
Error processing batch 5: The size of tensor a (128256) must match the size of tensor b (128257) at non-singleton dimension 1
Error processing batch 3: The size of tensor a (128256) must match the size of tensor b (128257) at non-singleton dimension 1
Error processing batch 4: The size of tensor a (128256) must match the size of tensor b (128257) at non-singleton dime

The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


`torch_dtype` is deprecated! Use `dtype` instead!


Model and tokenizer loaded successfully
Using device: cuda
✅ CUDA available – loading LLaMA 8B in 4-bit on GPU
✅ CUDA available – loading LLaMA 8B in 4-bit on GPU


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

The new embeddings will be initialized from a multivariate normal distribution that has old embeddings' mean and covariance. As described in this article: https://nlp.stanford.edu/~johnhew/vocab-expansion.html. To disable this, use `mean_resizing=False`


Loaded 1405 jailbreak prompts.
Distillation DataLoader created with 50 batches


Distillation Training:   0%|          | 0/50 [00:00<?, ?it/s]

Error processing batch 0: The size of tensor a (128256) must match the size of tensor b (128257) at non-singleton dimension 1
Error processing batch 1: The size of tensor a (128256) must match the size of tensor b (128257) at non-singleton dimension 1
Error processing batch 2: The size of tensor a (128256) must match the size of tensor b (128257) at non-singleton dimension 1
Error processing batch 3: The size of tensor a (128256) must match the size of tensor b (128257) at non-singleton dimension 1
Error processing batch 4: The size of tensor a (128256) must match the size of tensor b (128257) at non-singleton dimension 1
Error processing batch 5: The size of tensor a (128256) must match the size of tensor b (128257) at non-singleton dimension 1
Error processing batch 3: The size of tensor a (128256) must match the size of tensor b (128257) at non-singleton dimension 1
Error processing batch 4: The size of tensor a (128256) must match the size of tensor b (128257) at non-singleton dime

The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


✅ Distilled model saved.

🎓 Distilled Student Output:
Explain the history of cryptography and its importance in modern society.
Cryptography is the practice of secure communication by transforming plaintext (readable data) into ciphertext (unreadable data). The history of cryptography dates back to ancient civilizations, with significant milestones in the development of encryption techniques. Here's a brief overview of the key events and figures in the history

🎓 Distilled Student Output:
Explain the history of cryptography and its importance in modern society.
Cryptography is the practice of secure communication by transforming plaintext (readable data) into ciphertext (unreadable data). The history of cryptography dates back to ancient civilizations, with significant milestones in the development of encryption techniques. Here's a brief overview of the key events and figures in the history


In [None]:
import torch
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW
from transformers import get_scheduler
from tqdm.auto import tqdm
from datasets import load_dataset
import pandas as pd
import sys
import os
import torch.nn.functional as F

# --- 0. Setup: Device, Custom Model Loaders, and Dataset Loading ---
nb_dir = r"f:\T2430392\Efficient-Distilled-miniLLM-to-Prevent-Jailbreak-Attacks\Knowledge_Distillation"
if nb_dir not in sys.path:
    sys.path.insert(0, nb_dir)

from llama_1b import load_llama_1b
from llama_8b import load_llama_8b

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

student_model, student_tokenizer = load_llama_1b()
teacher_model, teacher_tokenizer = load_llama_8b()

student_model.to(device)
teacher_model.to(device)

student_tokenizer.padding_side = "left"
teacher_tokenizer.padding_side = "left"

# --- Ensure pad token and resize embeddings ---
# 1. Add pad token to tokenizers if missing
if student_tokenizer.pad_token is None:
    student_tokenizer.add_special_tokens({'pad_token': '<pad>'})

if teacher_tokenizer.pad_token is None:
    teacher_tokenizer.add_special_tokens({'pad_token': '<pad>'})

# 2. Determine the maximum vocabulary size between the two
# This ensures both models will have the exact same output dimension
vocab_size = max(len(student_tokenizer), len(teacher_tokenizer))
print(f"Unifying vocabulary size to: {vocab_size}")

# 3. Resize both models to this unified size
student_model.resize_token_embeddings(vocab_size)
teacher_model.resize_token_embeddings(vocab_size)

# 4. Update config pad_token_id
student_model.config.pad_token_id = student_tokenizer.pad_token_id
teacher_model.config.pad_token_id = teacher_tokenizer.pad_token_id

# --- Load dataset ---
dataset = load_dataset("TrustAIRLab/in-the-wild-jailbreak-prompts", "jailbreak_2023_12_25")
df = pd.DataFrame(dataset['train'])
jailbreak_prompts = df['prompt'].tolist()
print(f"Loaded {len(jailbreak_prompts)} jailbreak prompts.")

# --- 1. Distillation Dataset ---
class JailbreakDistillationDataset(Dataset):
    def __init__(self, prompts, student_tokenizer, teacher_tokenizer, max_input_length=256):
        self.prompts = prompts
        self.student_tokenizer = student_tokenizer
        self.teacher_tokenizer = teacher_tokenizer
        self.max_input_length = max_input_length

    def __len__(self):
        return len(self.prompts)

    def __getitem__(self, idx):
        prompt = self.prompts[idx]
        student_enc = self.student_tokenizer(
            prompt, return_tensors="pt", padding="max_length",
            truncation=True, max_length=self.max_input_length
        )
        teacher_enc = self.teacher_tokenizer(
            prompt, return_tensors="pt", padding="max_length",
            truncation=True, max_length=self.max_input_length
        )
        return {
            "student_input_ids": student_enc["input_ids"].squeeze(0),
            "student_attention_mask": student_enc["attention_mask"].squeeze(0),
            "teacher_input_ids": teacher_enc["input_ids"].squeeze(0),
            "teacher_attention_mask": teacher_enc["attention_mask"].squeeze(0),
        }

# jailbreak_prompts_subset = jailbreak_prompts[:100]
distillation_dataset = JailbreakDistillationDataset(
    jailbreak_prompts, student_tokenizer, teacher_tokenizer, max_input_length=128
)
BATCH_SIZE = 16
dist_loader = DataLoader(distillation_dataset, batch_size=BATCH_SIZE, shuffle=True)
print(f"Distillation DataLoader created with {len(dist_loader)} batches")

# --- 2. Distillation Loss ---
def distillation_loss(student_logits, teacher_logits, temperature=2.0):
    student_logits = torch.clamp(student_logits.float(), -10, 10)
    teacher_logits = torch.clamp(teacher_logits.float(), -10, 10)
    student_log_probs = F.log_softmax(student_logits / temperature, dim=-1)
    teacher_probs = F.softmax(teacher_logits / temperature, dim=-1)
    loss = F.kl_div(student_log_probs, teacher_probs, reduction='batchmean') * (temperature ** 2)
    if torch.isnan(loss):
        return torch.tensor(0.0, device=student_logits.device, requires_grad=True)
    return loss

# --- 3. Optimizer & Scheduler ---
optimizer = AdamW(student_model.parameters(), lr=1e-5)
NUM_EPOCHS = 2
num_training_steps = NUM_EPOCHS * len(dist_loader)
lr_scheduler = get_scheduler(
    "linear", optimizer=optimizer, num_warmup_steps=0, num_training_steps=num_training_steps
)

# --- 4. Training Loop ---
student_model.train()
teacher_model.eval()
progress_bar = tqdm(range(num_training_steps), desc="Distillation Training")

for epoch in range(NUM_EPOCHS):
    total_loss = 0
    valid_batches = 0

    for batch_idx, batch in enumerate(dist_loader):
        student_input_ids = batch["student_input_ids"].to(device)
        student_attention_mask = batch["student_attention_mask"].to(device)

        teacher_input_ids = batch["teacher_input_ids"].to(device)
        teacher_attention_mask = batch["teacher_attention_mask"].to(device)

        try:
            # Teacher forward
            with torch.no_grad():
                teacher_outputs = teacher_model(
                    input_ids=teacher_input_ids,
                    attention_mask=teacher_attention_mask
                )
                teacher_logits = teacher_outputs.logits

            # Student forward
            student_outputs = student_model(
                input_ids=student_input_ids,
                attention_mask=student_attention_mask
            )
            student_logits = student_outputs.logits

            # --- Next-token distillation ---
            min_len = min(student_logits.size(1)-1, teacher_logits.size(1)-1)
            student_logits_next = student_logits[:, :min_len, :]
            teacher_logits_next = teacher_logits[:, 1:min_len+1, :]


            # Flatten and compute loss
            loss = distillation_loss(
                student_logits_next.reshape(-1, student_logits_next.size(-1)),
                teacher_logits_next.reshape(-1, teacher_logits_next.size(-1)),
            )

            # Backprop
            loss.backward()
            torch.nn.utils.clip_grad_norm_(student_model.parameters(), 1.0)
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()

            total_loss += loss.item()
            valid_batches += 1
            progress_bar.update(1)

        except Exception as e:
            print(f"Error processing batch {batch_idx}: {e}")
            progress_bar.update(1)
            continue

    if valid_batches > 0:
        print(f"Epoch {epoch+1}, Avg Loss: {total_loss/valid_batches:.4f} ({valid_batches} batches)")

print("\n✅ Distillation training complete!")

# --- 5. Save Distilled Model ---
output_dir = "./distilled_llama_1b_jailbreak_safety"
os.makedirs(output_dir, exist_ok=True)
student_model.save_pretrained(output_dir)
student_tokenizer.save_pretrained(output_dir)
print("✅ Distilled model saved.")

# --- 6. Example Inference ---
student_model.eval()
test_prompt = "Explain the history of cryptography"
input_ids = student_tokenizer(test_prompt, return_tensors="pt", truncation=True, max_length=128).input_ids.to(device)

with torch.no_grad():
    generated_output = student_model.generate(
        input_ids,
        max_new_tokens=64,
        do_sample=False,
        pad_token_id=student_tokenizer.pad_token_id
    )

decoded_output = student_tokenizer.decode(generated_output[0], skip_special_tokens=True)
print(f"\n🎓 Distilled Student Output:\n{decoded_output}")


Using device: cuda
✅ CUDA available – loading LLaMA 8B in 4-bit on GPU


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Unifying vocabulary size to: 128257
Loaded 1405 jailbreak prompts.
Distillation DataLoader created with 88 batches


Distillation Training:   0%|          | 0/264 [00:00<?, ?it/s]

Epoch 1, Avg Loss: 0.0449 (88 batches)
Epoch 2, Avg Loss: 0.0000 (88 batches)
Epoch 3, Avg Loss: 0.0000 (88 batches)

✅ Distillation training complete!
✅ Distilled model saved.

🎓 Distilled Student Output:
Explain the history of cryptography!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!


## Pipeline Steps
1. Load student and teacher models (with quantization and CUDA)
2. Prepare and tokenize jailbreak dataset
3. Define KD loss and custom Trainer
4. Train student model with KD
5. Evaluate and save the distilled student model

In [11]:
test_prompt = "Hello"
input_ids = student_tokenizer(test_prompt, return_tensors="pt", truncation=True, max_length=128).input_ids.to(device)

with torch.no_grad():
    generated_output = student_model.generate(
        input_ids,
        max_new_tokens=64,
        do_sample=False,
        pad_token_id=student_tokenizer.pad_token_id
    )

decoded_output = student_tokenizer.decode(generated_output[0], skip_special_tokens=True)
print(f"\n🎓 Distilled Student Output:\n{decoded_output}")

RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
