# Knowledge Distillation Pipeline
This notebook demonstrates knowledge distillation using LoRA adapters and your Llama models.

In [1]:
# Install required packages and import libraries
%pip install --upgrade ipywidgets jupyter --quiet
%pip install peft optuna safetensors --quiet
%pip install safetensors

from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import torch.nn as nn
import torch.nn.functional as F
from datasets import load_dataset
from tqdm.notebook import tqdm
from safetensors.torch import safe_open
from peft import get_peft_model, LoraConfig
import optuna
import copy

# Basic settings
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
num_epochs = 2
batch_size = 8
print("Device:", device)

Note: you may need to restart the kernel to use updated packages.
Note: you may need to restart the kernel to use updated packages.
Defaulting to user installation because normal site-packages is not writeable
Note: you may need to restart the kernel to use updated packages.




Device: cuda


In [15]:
# Load student and teacher models using your custom loaders (keeps code modular)
from llama_1b import load_llama_1b
from llama_8b import load_llama_8b

student_model, tokenizer = load_llama_1b()
teacher_model, _ = load_llama_8b()

student_model.to(device)
teacher_model.to(device)

# Ensure tokenizer has pad token
if tokenizer.pad_token is None:
    tokenizer.add_special_tokens({'pad_token': '<pad>'})
student_model.generation_config.pad_token_id = tokenizer.pad_token_id

print('Loaded student and teacher models')

--- Loading Student Model (Llama 1B) ---
--- Student Model Loaded ---
--- Loading Teacher Model (Llama 8B) ---
--- Student Model Loaded ---
--- Loading Teacher Model (Llama 8B) ---


Loading checkpoint shards: 100%|██████████| 4/4 [00:14<00:00,  3.52s/it]



--- Teacher Model Loaded ---
Loaded student and teacher models


In [16]:
# Configure LoRA for the student model
lora_config = LoraConfig(
    r=4,
    lora_alpha=8,
    lora_dropout=0.2,
    target_modules=["q_proj", "v_proj", "k_proj", "o_proj"],
)
student_model = get_peft_model(student_model, lora_config)
print('LoRA attached to student model')

LoRA attached to student model


In [17]:
# Load and prepare data
train_data = load_dataset('openai/gsm8k', 'socratic', split='train').select(range(2638))
val_data = load_dataset('openai/gsm8k', 'socratic', split='test')
print('Data loaded:', len(train_data), 'train examples,', len(val_data), 'eval examples')

Data loaded: 2638 train examples, 1319 eval examples


In [18]:
# Cell 1: Generate and Save Teacher Logits

import os
from tqdm.notebook import tqdm

def compute_and_save_teacher_logits(teacher_model, tokenizer, dataset, out_dir='teacher_logits', device=None, max_examples=None):
    """
    Run the teacher model on each example and save the logits for the answer tokens.
    Saves one .pt file per example to avoid high RAM usage.
    """
    os.makedirs(out_dir, exist_ok=True)
    if device is None:
        device = next(teacher_model.parameters()).device

    teacher_model.eval()
    num_examples = len(dataset) if max_examples is None else min(len(dataset), max_examples)
    
    print(f"Computing teacher logits for {num_examples} examples -> saving to '{out_dir}'")
    
    with torch.no_grad():
        for i in tqdm(range(num_examples), desc="Generating Teacher Logits"):
            example = dataset[i]
            question = example.get('question')
            answer = example.get('answer')

            if not question or not answer:
                # Save an empty tensor for malformed examples to maintain order
                torch.save(torch.empty((0, teacher_model.config.vocab_size)), os.path.join(out_dir, f"{i:08d}.pt"))
                continue

            # Concatenate question and answer to match how the model sees the full sequence
            concat_text = question + (tokenizer.eos_token or '') + answer
            inputs = tokenizer(concat_text, truncation=True, padding=False, return_tensors='pt').to(device)
            
            # Get logits for the entire sequence
            outputs = teacher_model(**inputs)
            logits = outputs.logits.cpu()  # Shape: (1, seq_len, vocab_size)

            # Isolate the logits that correspond to the 'answer' part
            answer_tokens = tokenizer(answer, truncation=True, padding=False, return_tensors='pt')['input_ids']
            answer_len = answer_tokens.size(1)
            
            if answer_len > 0:
                # The answer logits are the last 'answer_len' tokens of the full sequence logits
                answer_logits = logits[0, -answer_len:, :].contiguous()
            else:
                answer_logits = torch.empty((0, logits.size(-1)), dtype=logits.dtype)

            # Save the tensor for this specific example
            torch.save(answer_logits, os.path.join(out_dir, f"{i:08d}.pt"))

    print('Finished generating and saving teacher logits.')

# --- Execute the function ---
# This will create a 'teacher_logits' directory with one file per training example.
# It uses your loaded teacher_model.
compute_and_save_teacher_logits(teacher_model, tokenizer, train_data, out_dir='teacher_logits', device=device)

Computing teacher logits for 2638 examples -> saving to 'teacher_logits'


ImportError: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html

### Summary of the Plan

1.  **Generate Teacher Logits**: We will use your existing `teacher_model` to process the `train_data` and save the output logits. This creates the missing `safetensors` file that the research pipeline needs. We will save one file per example to avoid using too much memory.
2.  **Load Teacher Logits**: We will load these newly created logits.
3.  **Run the Pipeline**: With the logits loaded, the existing training cell from the research pipeline should now work correctly.
4.  **Diagnostic (Optional)**: I've included a cell to run a check on a single example. This is extremely useful for debugging and verifying that the shapes and loss values make sense before starting a long training run.

Run the following cells in order.

In [None]:
# Cell 2: Load the newly created teacher logits

import glob

def load_teacher_logits_from_files(dirpath='teacher_logits', max_examples=None):
    """
    Loads the per-example teacher logits from the specified directory.
    Ensures they are loaded in the correct order to match the dataset.
    """
    # Sort files numerically to ensure order matches dataset index
    files = sorted(glob.glob(os.path.join(dirpath, '*.pt')))
    
    if max_examples is not None:
        files = files[:max_examples]
        
    print(f"Loading {len(files)} teacher logit tensors from '{dirpath}'...")
    
    # Load each tensor into a list
    teacher_logits_list = [torch.load(p) for p in tqdm(files, desc="Loading Logits")]
    
    print(f"Successfully loaded {len(teacher_logits_list)} logit tensors.")
    return teacher_logits_list

# --- Execute the function ---
# This replaces the old safetensors loading logic with our file-based approach.
teacher_logits_L = load_teacher_logits_from_files('teacher_logits', max_examples=len(train_data))

Loading teacher logits from ../llama-8b-gsm8k-tensors.safetensors


FileNotFoundError: No such file or directory: "../llama-8b-gsm8k-tensors.safetensors"

In [None]:
# Helper: get teacher logits on-the-fly for a batch (no safetensors)
def get_teacher_logits_for_batch(teacher_model, tokenizer, questions, answers, device, max_length=512):
    teacher_model.eval()
    concat_texts = [q + (tokenizer.eos_token or "") + a for q, a in zip(questions, answers)]
    encoded = tokenizer(concat_texts, truncation=True, padding=True, max_length=max_length, return_tensors="pt")
    input_ids = encoded["input_ids"].to(device)
    attention_mask = encoded["attention_mask"].to(device)

    with torch.no_grad():
        outputs = teacher_model(input_ids=input_ids, attention_mask=attention_mask)
        logits = outputs.logits

    ans_enc = tokenizer(answers, truncation=True, padding=False, return_tensors="pt")
    answer_lens = [int(x.size(0)) for x in ans_enc["input_ids"]]

    batch_logits = []
    for i, ans_len in enumerate(answer_lens):
        if ans_len == 0:
            v = torch.zeros((0, logits.size(-1)), device=logits.device, dtype=logits.dtype)
            batch_logits.append(v)
            continue
        seq_len = int(attention_mask[i].sum().item())
        start_idx = seq_len - ans_len
        if start_idx < 0:
            start_idx = max(0, seq_len - ans_len)
        v = logits[i, start_idx:seq_len, :].cpu()
        batch_logits.append(v)

    max_ans_len = max([v.size(0) for v in batch_logits]) if batch_logits else 0
    if max_ans_len == 0:
        return torch.empty((len(batch_logits), 0, logits.size(-1))).to(device)

    padded = torch.zeros((len(batch_logits), max_ans_len, logits.size(-1)), dtype=batch_logits[0].dtype)
    for i, v in enumerate(batch_logits):
        L = v.size(0)
        if L > 0:
            padded[i, :L, :] = v

    return padded.to(device)

print('Teacher logits helper ready')

In [None]:
# Knowledge Distillation Loss
class KnowledgeDistillationLoss(nn.Module):
    def __init__(self, temperature=1.0, alpha=0.5):
        super().__init__()
        self.temperature = temperature
        self.alpha = alpha
        self.criterion = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)

    def forward(self, student_logits, teacher_logits, labels):
        loss_hard = self.criterion(student_logits, labels)
        teacher_log_probs = F.log_softmax(teacher_logits / self.temperature, dim=1)
        student_probs = F.softmax(student_logits / self.temperature, dim=1)
        loss_soft = F.kl_div(teacher_log_probs, student_probs, reduction='batchmean', log_target=False) * (self.temperature ** 2)
        return self.alpha * loss_hard + (1.0 - self.alpha) * loss_soft

print('KD loss defined')

In [None]:
# Cell 3: Updated Training Function (Handles padding and alignment)

def train_model(model, teacher_logits_L, data, tokenizer, optimizer, scheduler, kd_loss, num_epochs, device, batch_size=1):
    model.train()
    num_batches = len(data) // batch_size + int(len(data) % batch_size != 0)

    for epoch in range(num_epochs):
        progress_bar = tqdm(range(num_batches), desc=f'Epoch {epoch+1}/{num_epochs}', unit='batch')

        for batch_idx in progress_bar:
            optimizer.zero_grad()
            
            start_idx = batch_idx * batch_size
            end_idx = min(start_idx + batch_size, len(data))

            batch_data = data[start_idx:end_idx]
            batch_teacher_logits = teacher_logits_L[start_idx:end_idx]
            
            questions = [ex['question'] for ex in batch_data]
            answers = [ex['answer'] for ex in batch_data]

            # Tokenize questions and answers separately for better control
            inputs = tokenizer(questions, return_tensors='pt', padding=True, truncation=True, max_length=256).to(device)
            labels = tokenizer(answers, return_tensors='pt', padding=True, truncation=True, max_length=256).to(device)

            # Get student logits
            outputs = model(**inputs, labels=labels['input_ids'])
            student_logits = outputs.logits

            # Pad the teacher logits to match the batch's max sequence length
            # This is critical for alignment.
            max_len = labels['input_ids'].size(1)
            padded_teacher_logits = torch.full(
                (len(batch_teacher_logits), max_len, student_logits.size(-1)),
                fill_value=0.0, # Use a neutral fill value
                device=device,
                dtype=student_logits.dtype
            )
            for i, t in enumerate(batch_teacher_logits):
                len_t = min(t.size(0), max_len)
                if len_t > 0:
                    padded_teacher_logits[i, :len_t, :] = t[:len_t].to(device)

            # Align sequence lengths for all tensors before flattening
            seq_len = min(student_logits.size(1), padded_teacher_logits.size(1), labels['input_ids'].size(1))
            
            student_logits_aligned = student_logits[:, :seq_len, :]
            teacher_logits_aligned = padded_teacher_logits[:, :seq_len, :]
            labels_aligned = labels['input_ids'][:, :seq_len]

            # Flatten for loss calculation
            student_flat = student_logits_aligned.contiguous().view(-1, student_logits.size(-1))
            teacher_flat = teacher_logits_aligned.contiguous().view(-1, student_logits.size(-1))
            labels_flat = labels_aligned.contiguous().view(-1)

            # Calculate loss
            loss = kd_loss(student_flat, teacher_flat, labels_flat)
            
            loss.backward()
            optimizer.step()
            scheduler.step()

            progress_bar.set_postfix(loss=loss.item())

    return model

print('Train function updated to handle padding of file-based teacher logits.')

In [None]:
def evaluate_model(model, validation_data, tokenizer, device, max_eval=200):
    model.eval()
    total_loss = 0.0
    total_tokens = 0

    with torch.no_grad():
        for example in validation_data[:max_eval]:
            inputs = tokenizer(example['question'], truncation=True, padding=True, max_length=256, return_tensors='pt').to(device)
            labels = tokenizer(example['answer'], truncation=True, padding=True, max_length=256, return_tensors='pt')['input_ids'].to(device)

            outputs = model(**inputs)
            student_logits = outputs.logits

            seq_len = min(student_logits.size(1), labels.size(1))
            student_logits = student_logits[:, :seq_len, :]
            labels = labels[:, :seq_len]

            student_flat = student_logits.contiguous().view(-1, student_logits.size(-1))
            labels_flat = labels.view(-1)

            loss = F.cross_entropy(student_flat, labels_flat, ignore_index=tokenizer.pad_token_id, reduction='sum')
            total_loss += loss.item()
            total_tokens += (labels_flat != tokenizer.pad_token_id).sum().item()

    avg_loss = total_loss / total_tokens if total_tokens > 0 else float('nan')
    return avg_loss

print('Evaluate function ready')

In [None]:
# Optimizer and scheduler setup
no_decay = ['bias', 'LayerNorm.weight']
lora_params = [p for n, p in student_model.named_parameters() if 'lora' in n]
base_params = [p for n, p in student_model.named_parameters() if 'lora' not in n]
optimizer_grouped_parameters = [
    {'params': base_params, 'weight_decay': 0.0},
    {'params': lora_params, 'weight_decay': 1e-2},
]
optimizer = torch.optim.SGD(optimizer_grouped_parameters, lr=5e-6, momentum=0.9)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=len(train_data))
kd_loss = KnowledgeDistillationLoss(temperature=5.94, alpha=0.61)

print('Optimizer and scheduler ready')

In [None]:
# Train and evaluate (run this cell to start training)
# Ensure we pass the precomputed teacher logits list `teacher_logits_L` (not the teacher model)
trained_model = train_model(student_model, teacher_logits_L, train_data, tokenizer, optimizer, scheduler, kd_loss, num_epochs=num_epochs, device=device, batch_size=batch_size)
val_loss = evaluate_model(trained_model, val_data, tokenizer, device)
print(f'Validation loss: {val_loss}')

# Optional: save the LoRA weights only
student_model.save_pretrained('llama1b-lora-kd', max_shard_size='5GB')
print('Saved student LoRA weights to llama1b-lora-kd')

ImportError: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html

In [None]:
# Cell 4 (Optional but Recommended): Single-Example Diagnostic

def kd_single_example_check(student_model, tokenizer, teacher_tensor, example, device, kd_loss_obj):
    """
    Run a single example through the student and teacher to check shapes and loss components.
    This is invaluable for debugging.
    """
    student_model.train()
    question = example['question']
    answer = example['answer']
    
    # Tokenize inputs and labels
    inputs = tokenizer(question, return_tensors='pt', padding=True, truncation=True, max_length=256).to(device)
    labels = tokenizer(answer, return_tensors='pt', padding=True, truncation=True, max_length=256).to(device)
    
    # Get student logits
    with torch.no_grad():
        outputs = student_model(**inputs, labels=labels['input_ids'])
        student_logits = outputs.logits

    # Align teacher tensor
    teacher_t = teacher_tensor.to(device)
    
    # Align sequence lengths
    seq_len = min(student_logits.size(1), teacher_t.size(0), labels['input_ids'].size(1))
    
    student_logits_aligned = student_logits[:, :seq_len, :]
    teacher_logits_aligned = teacher_t[:seq_len, :].unsqueeze(0) # Add batch dim
    labels_aligned = labels['input_ids'][:, :seq_len]

    # Flatten for loss
    V = student_logits.size(-1)
    student_flat = student_logits_aligned.contiguous().view(-1, V)
    teacher_flat = teacher_logits_aligned.contiguous().view(-1, V)
    labels_flat = labels_aligned.contiguous().view(-1)

    print('--- Single-Example Diagnostic ---')
    print(f"Shapes (student, teacher, labels): {student_flat.shape}, {teacher_flat.shape}, {labels_flat.shape}")

    # Calculate loss components
    loss_hard = F.cross_entropy(student_flat, labels_flat, ignore_index=tokenizer.pad_token_id)
    
    T = kd_loss_obj.temperature
    alpha = kd_loss_obj.alpha
    
    teacher_log_probs = F.log_softmax(teacher_flat / T, dim=1)
    student_probs = F.softmax(student_flat / T, dim=1)
    loss_soft = F.kl_div(teacher_log_probs, student_probs, reduction='batchmean', log_target=True) * (T ** 2)
    
    total_loss = alpha * loss_hard + (1.0 - alpha) * loss_soft

    print(f"Loss Components: Hard={loss_hard.item():.4f}, Soft(KL)={loss_soft.item():.4f}, Total KD={total_loss.item():.4f}")
    print('--- End Diagnostic ---')
    
    return total_loss

# --- Execute the diagnostic on the first training example ---
if train_data and teacher_logits_L:
    print("Running diagnostic on the first training example...")
    kd_single_example_check(student_model, tokenizer, teacher_logits_L[0], train_data[0], device, kd_loss)
else:
    print("Skipping diagnostic: data or logits not loaded.")