# Knowledge Distillation: Llama 3.2 3B → Llama 3.2 1B

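This notebook distills knowledge from a teacher model (Llama 3.2 3B) into a student model (Llama 3.2 1B) using the [`HPAI-BSC/MMLU-medical-cot-llama31`](https://huggingface.co/datasets/HPAI-BSC/MMLU-medical-cot-llama31) dataset. The student is trained on a weighted sum of cross-entropy against the reference responses and a temperature-scaled KL divergence toward the teacher's token distributions, and is evaluated with METEOR and perplexity before and after distillation.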

In [1]:
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForCausalLM, get_linear_schedule_with_warmup
from datasets import load_dataset
from nltk.translate.meteor_score import meteor_score
from nltk.tokenize import word_tokenize
import nltk
from tqdm import tqdm
import os
import logging
import gc
from pathlib import Path

# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Download the NLTK resources used by the evaluation code:
#   punkt / punkt_tab -> word_tokenize (newer NLTK releases look up 'punkt_tab';
#                        older ones only know 'punkt')
#   wordnet / omw-1.4 -> WordNet synonym matching inside meteor_score (some NLTK
#                        versions require 'omw-1.4' alongside 'wordnet')
# Each resource is fetched on its own so one failure does not skip the others.
for nltk_resource in ("punkt", "punkt_tab", "wordnet", "omw-1.4"):
    try:
        nltk.download(nltk_resource)
    except Exception as e:
        logger.error(f"Failed to download NLTK data: {e}")

# Ensure using GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Report GPU memory. memory_reserved() is what PyTorch's caching allocator currently
# holds, not what is free, so the free figure is queried from the driver instead.
if torch.cuda.is_available():
    free_bytes, _total_bytes = torch.cuda.mem_get_info(0)
    print(f"Total GPU memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
    print(f"Available GPU memory: {free_bytes / 1e9:.2f} GB free")
    print(f"Reserved by PyTorch: {torch.cuda.memory_reserved(0) / 1e9:.2f} GB")

Using device: cuda
Total GPU memory: 12.88 GB
Available GPU memory: 0.00 GB reserved


[nltk_data] Downloading package punkt to
[nltk_data]     C:\Users\prati\AppData\Roaming\nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package wordnet to
[nltk_data]     C:\Users\prati\AppData\Roaming\nltk_data...
[nltk_data]   Package wordnet is already up-to-date!


## Load Models and Tokenizers

In [None]:
# Hub identifiers for the two checkpoints (swap in local directories if the weights
# are already on disk).
student_model_name = "meta-llama/Llama-3.2-1B"
teacher_model_name = "meta-llama/Llama-3.2-3B"

# Every sequence is truncated/padded to this many tokens.
max_length = 512

# Meta-Llama repositories require Hugging Face authentication, so the access token is
# read from the HF_TOKEN environment variable (None when it is unset). Export it in the
# shell or set os.environ["HF_TOKEN"] before running this cell; never hard-code it here.
hf_token = os.environ.get("HF_TOKEN")

try:
    print("Loading student model and tokenizer...")
    student_tokenizer = AutoTokenizer.from_pretrained(student_model_name, token=hf_token, use_fast=True)

    # NOTE(review): the student weights are float16 and are updated directly by the
    # AdamW optimizer created in the distillation cell (no fp32 master weights, no loss
    # scaling). torch.bfloat16 is usually the safer choice if the GPU supports it.
    student_model = AutoModelForCausalLM.from_pretrained(
        student_model_name,
        token=hf_token,
        device_map="auto",          # let Transformers decide where the weights live
        torch_dtype=torch.float16,  # half precision halves the weight memory
        use_cache=False,            # no KV cache needed while training
    )
    # Recompute activations during backward instead of storing them all.
    student_model.gradient_checkpointing_enable()

    # Batched padding needs a pad id; fall back to EOS when the tokenizer defines none.
    if student_tokenizer.pad_token_id is None:
        student_tokenizer.pad_token_id = student_tokenizer.eos_token_id

    print("Student model loaded successfully")
except Exception as e:
    print(f"Failed to load student model: {e}")
    raise

In [None]:
try:
    print("Loading teacher model and tokenizer...")
    teacher_tokenizer = AutoTokenizer.from_pretrained(teacher_model_name, token=hf_token, use_fast=True)
    teacher_model = AutoModelForCausalLM.from_pretrained(
        teacher_model_name,
        token=hf_token,
        device_map="auto",          # let Transformers decide where the weights live
        torch_dtype=torch.float16,  # half precision halves the weight memory
    )

    # The teacher only supplies target distributions, so keep it in inference mode.
    teacher_model.eval()

    # Batched padding needs a pad id; fall back to EOS when the tokenizer defines none.
    if teacher_tokenizer.pad_token_id is None:
        teacher_tokenizer.pad_token_id = teacher_tokenizer.eos_token_id

    print("Teacher model loaded successfully")
except Exception as e:
    print(f"Failed to load teacher model: {e}")
    raise

## Load and Prepare Dataset

In [None]:
try:
    # Load the MMLU-medical-cot-llama31 dataset (train split)
    dataset_name = "HPAI-BSC/MMLU-medical-cot-llama31"
    print(f"Loading dataset: {dataset_name}")
    dataset = load_dataset(dataset_name, split="train")

    # Select a small prefix for faster experimentation (can increase later). Clamp to the
    # split size so `select` never asks for indices that do not exist.
    subset_size = 100
    subset = dataset.select(range(min(subset_size, len(dataset))))
    print(f"Loaded {len(subset)} examples from {dataset_name}")
    if len(subset) == 0:
        raise ValueError(f"Dataset {dataset_name} is empty; nothing to distill on")

    # Display an example from the dataset (long fields truncated to 300 characters)
    example = subset[0]
    print("\nExample input:")
    print(f"{example['question'][:300]}{'...' if len(example['question']) > 300 else ''}")
    print("\nExample response:")
    print(f"{example['response'][:300]}{'...' if len(example['response']) > 300 else ''}")
except Exception as e:
    print(f"Failed to load dataset: {e}")
    raise

## Prepare Dataset for Batching

In [None]:
def _split_questions_responses(batch_data):
    """Return parallel lists of questions and responses from row- or column-oriented batch data."""
    from collections.abc import Mapping

    if isinstance(batch_data, Mapping):
        # Column-oriented, e.g. {"question": [...], "response": [...]} as produced by
        # slicing a Hugging Face Dataset; a single row of plain strings is also accepted.
        questions, responses = batch_data["question"], batch_data["response"]
        if isinstance(questions, str):
            return [questions], [responses]
        return list(questions), list(responses)
    # Row-oriented: an iterable of {"question": ..., "response": ...} dicts.
    rows = list(batch_data)
    return [row["question"] for row in rows], [row["response"] for row in rows]


def _encode_prompt_response(tokenizer, questions, responses, max_length, separator):
    """
    Tokenize ``question + separator + response`` pairs into right-padded tensors.

    Returns a dict of (batch, max_length) long tensors: ``input_ids`` / ``attention_mask``
    for the whole sequence, and ``labels`` / ``label_mask`` in which only response-token
    positions are kept (every other position of ``labels`` holds the pad id).

    Raises:
        ValueError: if the tokenizer has neither a pad id nor an EOS id to pad with.
    """
    pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id
    if pad_id is None:
        raise ValueError("Tokenizer defines neither pad_token_id nor eos_token_id")

    # The prompt keeps the tokenizer's special tokens (e.g. a leading BOS); the response is
    # appended without them so the sequence carries a single start-of-sequence marker.
    prompt_ids = tokenizer(list(questions), add_special_tokens=True)["input_ids"]
    response_ids = tokenizer([separator + r for r in responses], add_special_tokens=False)["input_ids"]

    shape = (len(prompt_ids), max_length)
    input_ids = torch.full(shape, pad_id, dtype=torch.long)
    attention_mask = torch.zeros(shape, dtype=torch.long)
    labels = torch.full(shape, pad_id, dtype=torch.long)
    label_mask = torch.zeros(shape, dtype=torch.long)

    for row, (prompt, response) in enumerate(zip(prompt_ids, response_ids)):
        # Truncate from the right: an over-long prompt leaves no response tokens to learn.
        sequence = (list(prompt) + list(response))[:max_length]
        seq_len = len(sequence)
        if seq_len == 0:
            continue
        prompt_len = min(len(prompt), seq_len)
        input_ids[row, :seq_len] = torch.tensor(sequence, dtype=torch.long)
        attention_mask[row, :seq_len] = 1
        labels[row, prompt_len:seq_len] = input_ids[row, prompt_len:seq_len]
        label_mask[row, prompt_len:seq_len] = 1

    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "labels": labels,
        "label_mask": label_mask,
    }


def prepare_batch(batch_data, student_tokenizer, teacher_tokenizer, max_length=512, device=None, separator="\n"):
    """
    Prepare a batch of (question, response) pairs for distillation training.

    Each example becomes one causal-LM sequence ``question + separator + response`` so
    that the logit at position t is trained against token t + 1 of the *same* sequence.
    Previously questions and responses were tokenized separately and compared
    position-by-position, which never showed the model the text it had to predict.

    Args:
        batch_data: a list of row dicts with "question"/"response" keys, or a column dict
            {"question": [...], "response": [...]} (what slicing a Hugging Face Dataset
            returns).
        student_tokenizer: tokenizer for the student model.
        teacher_tokenizer: tokenizer for the teacher; re-encodes the text only when it is
            a different object from ``student_tokenizer``.
        max_length: every sequence is truncated / right-padded to this many tokens.
        device: target device; None selects CUDA when available, else CPU.
        separator: text inserted between question and response.

    Returns:
        dict of (batch, max_length) long tensors on ``device``:
          student_input_ids / student_attention_mask: the full sequence for the student.
          teacher_input_ids / teacher_attention_mask: the same text for the teacher.
          target_input_ids: student_input_ids with question and padding positions replaced
              by the student pad id, so a loss with ignore_index=pad_token_id only scores
              response tokens.
          target_attention_mask: 1 where target_input_ids holds a response token.
    """
    if device is None:
        # Same choice the notebook makes for its global `device`.
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    questions, responses = _split_questions_responses(batch_data)

    student_enc = _encode_prompt_response(student_tokenizer, questions, responses, max_length, separator)
    if teacher_tokenizer != student_tokenizer:
        teacher_enc = _encode_prompt_response(teacher_tokenizer, questions, responses, max_length, separator)
    else:
        teacher_enc = student_enc

    return {
        "student_input_ids": student_enc["input_ids"].to(device),
        "student_attention_mask": student_enc["attention_mask"].to(device),
        "teacher_input_ids": teacher_enc["input_ids"].to(device),
        "teacher_attention_mask": teacher_enc["attention_mask"].to(device),
        "target_input_ids": student_enc["labels"].to(device),
        "target_attention_mask": student_enc["label_mask"].to(device),
    }

## Define Evaluation Function

In [None]:
def _generate_continuation(model, tokenizer, prompt, max_length, max_new_tokens):
    """Greedily continue ``prompt`` and return only the newly generated text (prompt excluded)."""
    inputs = tokenizer(
        prompt,
        return_tensors="pt",
        truncation=True,
        max_length=max_length
    ).to(model.device)
    with torch.no_grad():
        output_ids = model.generate(
            input_ids=inputs.input_ids,
            attention_mask=inputs.attention_mask,
            max_new_tokens=max_new_tokens,
            do_sample=False,
            pad_token_id=tokenizer.pad_token_id,
            # The student is loaded with use_cache=False for training; the KV cache is
            # fine for inference and avoids re-running the prefix at every step.
            use_cache=True,
        )
    # For decoder-only models generate() returns prompt + continuation; keep the continuation.
    prompt_length = inputs.input_ids.shape[1]
    return tokenizer.decode(output_ids[0, prompt_length:], skip_special_tokens=True)


def _meteor_for(reference, candidate):
    """
    METEOR of ``candidate`` against ``reference`` on lower-cased word tokens.

    Returns 0 when either side tokenizes to nothing, and None if scoring raised.
    """
    try:
        reference_tokens = word_tokenize(reference.lower())
        candidate_tokens = word_tokenize(candidate.lower())
        if not reference_tokens or not candidate_tokens:
            print("Warning: Empty token list detected, skipping METEOR calculation")
            return 0
        return meteor_score([reference_tokens], candidate_tokens)
    except Exception as e:
        print(f"Error calculating METEOR score: {e}")
        return None


def _response_perplexity(model, tokenizer, reference, max_length):
    """
    Perplexity (exp of mean next-token NLL) of ``reference`` under ``model``.

    Returns None when there is no token to predict or when the computation raised.
    """
    try:
        target_ids = tokenizer(
            reference,
            return_tensors="pt",
            truncation=True,
            max_length=max_length
        ).input_ids.to(model.device)

        with torch.no_grad():
            logits = model(input_ids=target_ids).logits

        # Logit at t predicts token t + 1; compute the loss in float32 for accuracy.
        shift_logits = logits[..., :-1, :].float()
        shift_labels = target_ids[..., 1:]

        num_tokens = (shift_labels != tokenizer.pad_token_id).sum().item()
        if num_tokens == 0:
            print("Warning: No valid tokens for perplexity calculation")
            return None

        loss = F.cross_entropy(
            shift_logits.reshape(-1, shift_logits.size(-1)),
            shift_labels.reshape(-1),
            ignore_index=tokenizer.pad_token_id,
            reduction='sum'
        )
        return torch.exp(loss / num_tokens).item()
    except Exception as e:
        print(f"Error calculating perplexity: {e}")
        return None


def evaluate_model(model, tokenizer, dataset, max_length=512, max_new_tokens=50, batch_size=4):
    """
    Evaluate a causal LM on question/response pairs using METEOR and perplexity.

    For every example the model greedily generates up to ``max_new_tokens`` tokens from
    the question, and only the generated continuation (not the echoed prompt) is scored
    with METEOR against the reference response. Perplexity is computed from the model's
    next-token loss on the reference response alone.

    Each metric is averaged only over the examples where it was computed successfully.
    A failure in one metric no longer counts as a 0 in the other's average.

    Args:
        model: Hugging Face causal LM (uses ``generate``, ``.logits`` and ``.device``).
        tokenizer: matching tokenizer; ``pad_token_id`` must be set.
        dataset: iterable of rows with "question" and "response" fields, such as a
            Hugging Face Dataset or a list of dicts.
        max_length: truncation length for prompts and references.
        max_new_tokens: generation budget per example.
        batch_size: accepted for backward compatibility; examples are processed one at
            a time.

    Returns:
        (average_meteor_score, average_perplexity). The METEOR average defaults to 0
        and the perplexity average to inf when no example produced that metric.

    The model's train/eval mode is restored before returning.
    """
    was_training = model.training
    model.eval()

    meteor_total, meteor_count = 0.0, 0
    perplexity_total, perplexity_count = 0.0, 0

    try:
        # Iterate rows directly: slicing a Hugging Face Dataset would yield a column dict.
        for data in tqdm(dataset, desc="Evaluating"):
            try:
                prompt = data["question"]
                correct_answer = data["response"]
                generated_text = _generate_continuation(model, tokenizer, prompt, max_length, max_new_tokens)
            except Exception as e:
                print(f"Error during evaluation: {e}")
                continue

            meteor = _meteor_for(correct_answer, generated_text)
            if meteor is not None:
                meteor_total += meteor
                meteor_count += 1

            perplexity = _response_perplexity(model, tokenizer, correct_answer, max_length)
            if perplexity is not None:
                perplexity_total += perplexity
                perplexity_count += 1
    except Exception as e:
        print(f"Evaluation failed: {e}")
    finally:
        model.train(was_training)

    if meteor_count == 0 and perplexity_count == 0:
        print("Warning: No examples were successfully evaluated")

    average_meteor_score = meteor_total / meteor_count if meteor_count else 0
    average_perplexity = perplexity_total / perplexity_count if perplexity_count else float('inf')
    return average_meteor_score, average_perplexity

## Baseline Evaluation

In [None]:
# Score the untrained student so the post-distillation numbers have a reference point.
try:
    print("Evaluating student model before knowledge distillation...")
    initial_meteor, initial_perplexity = evaluate_model(
        model=student_model,
        tokenizer=student_tokenizer,
        dataset=subset,
        max_length=max_length,
        max_new_tokens=50,
        batch_size=4,
    )
    print(f"Initial Student Model - Average METEOR Score: {initial_meteor * 100:.2f}%")
    print(f"Initial Student Model - Average Perplexity: {initial_perplexity:.2f}")
except Exception as e:
    print(f"Baseline evaluation failed: {e}")
    # Sentinel values so the summary cell can still run.
    initial_meteor = 0
    initial_perplexity = float('inf')

## Knowledge Distillation

In [None]:
# Knowledge Distillation parameters
alpha = 0.5  # Weight of the soft (KL) loss; the hard (cross-entropy) loss gets 1 - alpha
temperature = 2.0  # Temperature for softening probability distributions
num_epochs = 3
learning_rate = 5e-5
batch_size = 4
gradient_accumulation_steps = 4  # Accumulate gradients to simulate larger batch sizes
warmup_steps = 100  # Upper bound on LR warmup; clamped to the real schedule length below

# AdamW's default eps (1e-8) is not representable in float16 (smallest normal ~6.1e-5,
# smallest subnormal ~6e-8) and rounds to 0 when the optimizer runs on float16
# parameters. The update denominator can then be exactly 0, and entries with zero
# gradient become 0/0 = NaN. Use a normal float16 value in that case.
student_param_dtype = next(student_model.parameters()).dtype
adam_eps = 1e-4 if student_param_dtype == torch.float16 else 1e-8
optimizer = torch.optim.AdamW(student_model.parameters(), lr=learning_rate, eps=adam_eps)

# The scheduler advances once per *optimizer* step: every gradient_accumulation_steps
# batches, plus one flush per epoch for a trailing partial window (see the training
# loop). Count optimizer steps, not batches.
batches_per_epoch = (len(subset) + batch_size - 1) // batch_size
optimizer_steps_per_epoch = (batches_per_epoch + gradient_accumulation_steps - 1) // gradient_accumulation_steps
total_steps = optimizer_steps_per_epoch * num_epochs

# With the defaults (100 examples, batch 4, accumulation 4, 3 epochs) there are only 21
# optimizer steps. A fixed 100-step warmup would keep the LR ramping for the whole run,
# so the linear decay would never start. Cap warmup at ~10% of the schedule.
num_warmup_steps = min(warmup_steps, max(1, total_steps // 10))
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=num_warmup_steps,
    num_training_steps=total_steps
)


@torch.no_grad()
def get_teacher_logits(input_ids, attention_mask=None):
    """Run the frozen teacher on a batch and return its logits (no autograd graph is built)."""
    return teacher_model(input_ids=input_ids, attention_mask=attention_mask).logits


# Create checkpoint directory
checkpoint_dir = Path("checkpoints")
checkpoint_dir.mkdir(exist_ok=True)

In [None]:
# Training loop with knowledge distillation


def compute_distillation_loss(student_logits, teacher_logits, labels, student_mask, teacher_mask,
                              pad_token_id, alpha, temperature):
    """
    Blend the hard (ground-truth) and soft (teacher) losses for one batch.

    Args:
        student_logits: student logits indexed as [batch, position, vocab].
        teacher_logits: teacher logits indexed the same way. This assumes the teacher
            shares the student's vocabulary, since the KL is taken across the last dim.
            Only the first min(student_len, teacher_len) positions are compared.
        labels: target ids aligned with the student input; positions equal to
            ``pad_token_id`` are ignored by the hard loss.
        student_mask, teacher_mask: attention masks (1 = real token) used to drop
            padding positions from the soft loss.
        pad_token_id: id treated as "ignore" in ``labels``.
        alpha: weight of the soft loss; the hard loss gets ``1 - alpha``.
        temperature: softmax temperature; the soft loss is rescaled by ``temperature**2``.

    Returns:
        Scalar tensor ``(1 - alpha) * hard_loss + alpha * soft_loss``.
    """
    # Hard loss: next-token cross-entropy (the logit at t predicts the label at t + 1).
    shift_logits = student_logits[..., :-1, :]
    shift_labels = labels[..., 1:]
    if (shift_labels != pad_token_id).any():
        hard_loss = F.cross_entropy(
            shift_logits.reshape(-1, shift_logits.size(-1)),
            shift_labels.reshape(-1),
            ignore_index=pad_token_id,
        )
    else:
        # Mean cross-entropy over zero non-ignored targets is NaN; contribute 0 instead.
        hard_loss = student_logits.new_zeros(())

    # Soft loss: KL(teacher_T || student_T) per position, averaged over real tokens only.
    # KLDivLoss(reduction='batchmean') on [batch, position, vocab] would instead sum over
    # every position (padding included) and divide only by the batch size.
    min_length = min(student_logits.size(1), teacher_logits.size(1))
    student_log_probs = F.log_softmax(student_logits[:, :min_length, :] / temperature, dim=-1)
    teacher_probs = F.softmax(teacher_logits[:, :min_length, :] / temperature, dim=-1)
    kl_per_position = F.kl_div(student_log_probs, teacher_probs, reduction="none").sum(dim=-1).float()
    position_mask = (student_mask[:, :min_length] * teacher_mask[:, :min_length]).float()
    soft_loss = (kl_per_position * position_mask).sum() / position_mask.sum().clamp(min=1.0)
    soft_loss = soft_loss * (temperature ** 2)

    return (1 - alpha) * hard_loss + alpha * soft_loss


def _apply_optimizer_step():
    """Clip accumulated student gradients, update weights and LR schedule, then reset gradients."""
    torch.nn.utils.clip_grad_norm_(student_model.parameters(), max_norm=1.0)
    optimizer.step()
    scheduler.step()
    optimizer.zero_grad()


print("Starting knowledge distillation training...")
try:
    for epoch in range(num_epochs):
        student_model.train()
        optimizer.zero_grad()
        total_loss = 0.0
        batches = 0

        progress_bar = tqdm(range(0, len(subset), batch_size), desc=f"Epoch {epoch+1}/{num_epochs}")
        for i in progress_bar:
            # Dataset.select yields row dicts; plain slicing (subset[i:j]) would return a
            # dict of columns instead.
            batch_data = subset.select(range(i, min(i + batch_size, len(subset))))
            batch = prepare_batch(
                batch_data,
                student_tokenizer,
                teacher_tokenizer,
                max_length=max_length
            )

            student_logits = student_model(
                input_ids=batch["student_input_ids"],
                attention_mask=batch["student_attention_mask"]
            ).logits
            teacher_logits = get_teacher_logits(
                batch["teacher_input_ids"],
                attention_mask=batch["teacher_attention_mask"]
            )

            loss = compute_distillation_loss(
                student_logits,
                teacher_logits,
                batch["target_input_ids"],
                batch["student_attention_mask"],
                batch["teacher_attention_mask"],
                pad_token_id=student_tokenizer.pad_token_id,
                alpha=alpha,
                temperature=temperature,
            )
            batch_loss = loss.item()

            # Scale so the accumulated gradient is the mean over the accumulation window.
            (loss / gradient_accumulation_steps).backward()
            batches += 1
            if batches % gradient_accumulation_steps == 0:
                _apply_optimizer_step()

            total_loss += batch_loss
            progress_bar.set_postfix({
                'loss': batch_loss,
                'avg_loss': total_loss / batches
            })

            # Log every 5 batches
            if batches % 5 == 0:
                print(f"Epoch {epoch+1}/{num_epochs}, Batch {batches}, Loss: {batch_loss:.4f}")

            # Release the large logit tensors before the next forward pass allocates new ones.
            del batch, student_logits, teacher_logits, loss

        # Flush a trailing partial accumulation window (clipped like every other step).
        if batches % gradient_accumulation_steps != 0:
            _apply_optimizer_step()

        avg_loss = total_loss / batches if batches else float('nan')
        print(f"Epoch {epoch+1}/{num_epochs} completed, Average Loss: {avg_loss:.4f}")

        # Save checkpoint after each epoch
        try:
            checkpoint_path = checkpoint_dir / f"student_model_checkpoint_epoch_{epoch+1}.pt"
            torch.save({
                'epoch': epoch,
                'model_state_dict': student_model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                'loss': avg_loss,
            }, checkpoint_path)
            print(f"Checkpoint saved to {checkpoint_path}")
        except Exception as e:
            print(f"Failed to save checkpoint: {e}")

        # Clear cache to free up memory
        torch.cuda.empty_cache()
        gc.collect()

except Exception as e:
    print(f"Training failed: {e}")
    raise

## Final Evaluation

In [None]:
# Re-run the baseline evaluation on the distilled student and print a before/after summary.
try:
    print("Evaluating student model after knowledge distillation...")
    final_meteor, final_perplexity = evaluate_model(
        model=student_model,
        tokenizer=student_tokenizer,
        dataset=subset,
        max_length=max_length,
        max_new_tokens=50,
        batch_size=4,
    )
    print(f"Final Student Model - Average METEOR Score: {final_meteor * 100:.2f}%")
    print(f"Final Student Model - Average Perplexity: {final_perplexity:.2f}")

    # Baseline -> distilled comparison
    print("\nKnowledge Distillation Results Summary:")
    print(f"METEOR Score: {initial_meteor * 100:.2f}% → {final_meteor * 100:.2f}%")
    print(f"Perplexity: {initial_perplexity:.2f} → {final_perplexity:.2f}")
except Exception as e:
    print(f"Final evaluation failed: {e}")

## Save Distilled Model