In [2]:
%pip install peft

Collecting peft
  Downloading peft-0.15.2-py3-none-any.whl.metadata (13 kB)
Collecting accelerate>=0.21.0 (from peft)
  Downloading accelerate-1.6.0-py3-none-any.whl.metadata (19 kB)
Downloading peft-0.15.2-py3-none-any.whl (411 kB)
Downloading accelerate-1.6.0-py3-none-any.whl (354 kB)
Installing collected packages: accelerate, peft
Successfully installed accelerate-1.6.0 peft-0.15.2
Note: you may need to restart the kernel to use updated packages.


In [16]:
import os
import sys
import torch
import numpy as np
import transformers
from transformers import AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments
from datasets import load_dataset
from peft import LoraConfig, get_peft_model
import argparse

In [17]:
def parse_args():
    parser = argparse.ArgumentParser(description="Knowledge Distillation for Llama3-OpenBioLLM")
    parser.add_argument("--teacher_model_path", type=str, default="aaditya/Llama3-OpenBioLLM-70B", 
                        help="Path to the teacher model")
    parser.add_argument("--student_model_path", type=str, default="meta-llama/Meta-Llama-3-8B-Instruct", 
                        help="Path to the student model")
    parser.add_argument("--output_dir", type=str, default="./openbiollm-8b-distilled", 
                        help="Output directory for the distilled model")
    parser.add_argument("--dataset_name", type=str, default="medical_dialogues", 
                        help="Dataset name for distillation")
    parser.add_argument("--batch_size", type=int, default=8, 
                        help="Batch size for training")
    parser.add_argument("--learning_rate", type=float, default=3e-4, 
                        help="Learning rate for distillation")
    parser.add_argument("--num_epochs", type=int, default=3, 
                        help="Number of training epochs")
    parser.add_argument("--kd_alpha", type=float, default=0.5, 
                        help="Weight for KL divergence loss in knowledge distillation")
    parser.add_argument("--temperature", type=float, default=2.0, 
                        help="Temperature for softening probability distributions")
    parser.add_argument("--use_lora", action="store_true", 
                        help="Whether to use LoRA for fine-tuning")
    parser.add_argument("--lora_r", type=int, default=16, 
                        help="LoRA attention dimension")
    parser.add_argument("--lora_alpha", type=int, default=32, 
                        help="LoRA alpha parameter")
    return parser.parse_args()

class KnowledgeDistillationTrainer(Trainer):
    def __init__(self, teacher_model=None, temperature=1.0, kd_alpha=0.5, **kwargs):
        super().__init__(**kwargs)
        self.teacher_model = teacher_model
        self.temperature = temperature
        self.kd_alpha = kd_alpha
        # Make sure the teacher model is in evaluation mode
        if self.teacher_model is not None:
            self.teacher_model.eval()
    
    def compute_loss(self, model, inputs, return_outputs=False):
        # Standard cross-entropy loss
        outputs = model(**inputs)
        ce_loss = outputs.loss
        
        # Knowledge distillation loss
        if self.teacher_model is not None:
            with torch.no_grad():
                teacher_outputs = self.teacher_model(**inputs)
            
            student_logits = outputs.logits / self.temperature
            teacher_logits = teacher_outputs.logits / self.temperature
            
            # KL divergence loss
            kd_loss = torch.nn.functional.kl_div(
                torch.nn.functional.log_softmax(student_logits, dim=-1),
                torch.nn.functional.softmax(teacher_logits, dim=-1),
                reduction="batchmean"
            ) * (self.temperature ** 2)
            
            # Combined loss: (1 - alpha) * CE + alpha * KD
            loss = (1 - self.kd_alpha) * ce_loss + self.kd_alpha * kd_loss
        else:
            loss = ce_loss
        
        return (loss, outputs) if return_outputs else loss

def preprocess_function(examples, tokenizer, max_length=512):
    """Tokenize and prepare the examples for distillation."""
    # For instruction tuning dataset format
    prompts = []
    for example in examples["text"]:
        prompt = f"<s>[INST] {example} [/INST]"
        prompts.append(prompt)
    
    # Tokenize inputs
    tokenized_inputs = tokenizer(
        prompts,
        padding="max_length",
        truncation=True,
        max_length=max_length,
        return_tensors="pt"
    )
    
    return tokenized_inputs

def load_models(args):
    """Load the teacher and student models."""
    print("Loading the teacher model...")
    teacher_tokenizer = AutoTokenizer.from_pretrained(args.teacher_model_path)
    
    # Check if accelerate is available
    try:
        import accelerate
        has_accelerate = True
    except ImportError:
        has_accelerate = False
    
    # Load teacher model with appropriate settings based on available libraries
    if has_accelerate:
        teacher_model = AutoModelForCausalLM.from_pretrained(
            args.teacher_model_path,
            device_map="auto",
            torch_dtype=torch.bfloat16,
            load_in_8bit=True  # Use quantization for large teacher model
        )
    else:
        # Fallback without device_map and quantization
        print("Accelerate library not found. Loading model without device_map and quantization.")
        print("This might require a lot of GPU memory. Consider installing accelerate: pip install accelerate")
        teacher_model = AutoModelForCausalLM.from_pretrained(
            args.teacher_model_path,
            torch_dtype=torch.bfloat16
        )
    
    print("Loading the student model...")
    student_tokenizer = AutoTokenizer.from_pretrained(args.student_model_path)
    
    # Load student model with appropriate settings
    if has_accelerate:
        student_model = AutoModelForCausalLM.from_pretrained(
            args.student_model_path,
            device_map="auto",
            torch_dtype=torch.bfloat16
        )
    else:
        student_model = AutoModelForCausalLM.from_pretrained(
            args.student_model_path,
            torch_dtype=torch.bfloat16
        )
    
    # Apply LoRA if requested
    if args.use_lora:
        print("Applying LoRA to the student model...")
        lora_config = LoraConfig(
            r=args.lora_r,
            lora_alpha=args.lora_alpha,
            target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
            lora_dropout=0.05,
            bias="none",
            task_type="CAUSAL_LM"
        )
        student_model = get_peft_model(student_model, lora_config)
    
    return teacher_model, teacher_tokenizer, student_model, student_tokenizer

def load_biomedical_dataset(dataset_name, tokenizer, max_length=512):
    """Load and preprocess the biomedical dataset for distillation."""
    # You would need to replace this with your own dataset loading logic
    # This is a placeholder example
    if dataset_name == "pubmed_qa":
        dataset = load_dataset("pubmed_qa", "pqa_labeled")
        # Process the dataset into a format suitable for distillation
        dataset = dataset.map(
            lambda x: {"text": x["question"] + " " + x["context"]},
            remove_columns=["question", "context", "pubid", "long_answer", "label"]
        )
    elif dataset_name == "medical_dialogues":
        # This is a placeholder - replace with actual medical dialogue dataset
        dataset = load_dataset("csv", data_files={"train": "medical_dialogues_train.csv", 
                                               "validation": "medical_dialogues_val.csv"})
    else:
        raise ValueError(f"Unknown dataset: {dataset_name}")
    
    # Tokenize the dataset
    tokenized_dataset = dataset.map(
        lambda examples: preprocess_function(examples, tokenizer, max_length),
        batched=True,
        remove_columns=dataset["train"].column_names
    )
    
    return tokenized_dataset

def main(teacher_model_path="aaditya/Llama3-OpenBioLLM-70B",
         student_model_path="meta-llama/Meta-Llama-3-8B-Instruct",
         output_dir="./openbiollm-8b-distilled",
         dataset_name="medical_dialogues",
         batch_size=8,
         learning_rate=3e-4,
         num_epochs=3,
         kd_alpha=0.5,
         temperature=2.0,
         use_lora=True,
         lora_r=16,
         lora_alpha=32):
    """
    Run the knowledge distillation process with provided parameters.
    This version works in Jupyter notebooks without requiring command line arguments.
    """
    # Create a simple namespace object to mimic argparse args
    class Args:
        pass
    
    args = Args()
    args.teacher_model_path = teacher_model_path
    args.student_model_path = student_model_path
    args.output_dir = output_dir
    args.dataset_name = dataset_name
    args.batch_size = batch_size
    args.learning_rate = learning_rate
    args.num_epochs = num_epochs
    args.kd_alpha = kd_alpha
    args.temperature = temperature
    args.use_lora = use_lora
    args.lora_r = lora_r
    args.lora_alpha = lora_alpha
    
    # Load models
    teacher_model, teacher_tokenizer, student_model, student_tokenizer = load_models(args)
    
    # Load dataset
    tokenized_dataset = load_biomedical_dataset(args.dataset_name, student_tokenizer)
    
    # Create training arguments
    training_args = TrainingArguments(
        output_dir=args.output_dir,
        overwrite_output_dir=True,
        num_train_epochs=args.num_epochs,
        per_device_train_batch_size=args.batch_size,
        per_device_eval_batch_size=args.batch_size,
        learning_rate=args.learning_rate,
        weight_decay=0.01,
        logging_dir=f"{args.output_dir}/logs",
        logging_steps=10,
        evaluation_strategy="epoch",
        save_strategy="epoch",
        load_best_model_at_end=True,
        save_total_limit=2,
        fp16=True,
        report_to="tensorboard",
    )
    
    # Create custom trainer with knowledge distillation
    trainer = KnowledgeDistillationTrainer(
        model=student_model,
        teacher_model=teacher_model,
        args=training_args,
        train_dataset=tokenized_dataset["train"],
        eval_dataset=tokenized_dataset["validation"],
        tokenizer=student_tokenizer,
        temperature=args.temperature,
        kd_alpha=args.kd_alpha,
    )
    
    # Train the model
    print("Starting knowledge distillation training...")
    trainer.train()
    
    # Save the final model
    print(f"Saving the distilled model to {args.output_dir}")
    trainer.save_model(args.output_dir)
    student_tokenizer.save_pretrained(args.output_dir)
    
    print("Knowledge distillation completed successfully!")

In [18]:
if __name__ == "__main__":
    # Check if running in Jupyter
    try:
        # This will raise NameError if not in IPython/Jupyter
        if 'ipykernel' in sys.modules:
            # Call the function directly with defaults when in Jupyter
            print("Running in Jupyter environment - using function arguments instead of command line")
            main()
        else:
            # Use argparse when running as a script
            args = parse_args()
            main(
                teacher_model_path=args.teacher_model_path,
                student_model_path=args.student_model_path,
                output_dir=args.output_dir,
                dataset_name=args.dataset_name,
                batch_size=args.batch_size,
                learning_rate=args.learning_rate,
                num_epochs=args.num_epochs,
                kd_alpha=args.kd_alpha,
                temperature=args.temperature,
                use_lora=args.use_lora,
                lora_r=args.lora_r,
                lora_alpha=args.lora_alpha
            )
    except NameError:
        # Default to using argparse
        args = parse_args()
        main(
            teacher_model_path=args.teacher_model_path,
            student_model_path=args.student_model_path,
            output_dir=args.output_dir,
            dataset_name=args.dataset_name,
            batch_size=args.batch_size,
            learning_rate=args.learning_rate,
            num_epochs=args.num_epochs,
            kd_alpha=args.kd_alpha,
            temperature=args.temperature,
            use_lora=args.use_lora,
            lora_r=args.lora_r,
            lora_alpha=args.lora_alpha
        )

Running in Jupyter environment - using function arguments instead of command line
Loading the teacher model...


ImportError: Using `low_cpu_mem_usage=True` or a `device_map` requires Accelerate: `pip install 'accelerate>=0.26.0'`