In [1]:
import torch
from transformers import get_linear_schedule_with_warmup
from utils import (
    load_env_variables,
    ModelManager,
    MMLUMedicalDataset,
    VanillaKnowledgeDistillation,
    DistillationTrainer,
    Evaluator
)


# --- Run configuration -------------------------------------------------------
# Everything a reader might want to tune lives here; the pipeline below only
# consumes these values.

# Checkpoint identifiers: the 8B model is used as teacher, the 3B one as student.
teacher_model_name = "meta-llama/Llama-3.1-8B-Instruct"
student_model_name = "meta-llama/Llama-3.2-3B-Instruct"

# Total length of the linear LR schedule and of its warmup phase.
# NOTE(review): both are fixed numbers, not derived from epochs / batch_size /
# grad_accum_steps below — confirm they match how often the trainer actually
# steps the scheduler.
num_training_steps = 10000
num_warmup_steps = 1000

optimizer_kwargs = {"lr": 5e-5, "weight_decay": 0.01}
distillation_kwargs = {"temperature": 2.0, "alpha": 0.5}
trainer_kwargs = {"use_mixed_precision": True, "grad_accum_steps": 4}
train_kwargs = {
    "epochs": 3,
    "batch_size": 8,
    "checkpoint_dir": "./checkpoints",
    "eval_steps": 500,
    "save_steps": 1000,
    "num_workers": 0,  # kept at 0 to avoid multiprocessing issues
}

# --- Credentials and model loading -------------------------------------------
api_key = load_env_variables()
model_manager = ModelManager(api_key=api_key)

# The teacher is requested with use_8bit=True; the student is loaded without
# that flag, i.e. with whatever the loader's default is.
teacher_model = model_manager.load_model(
    model_name=teacher_model_name, is_teacher=True, use_8bit=True
)
teacher_tokenizer = model_manager.load_tokenizer(model_name=teacher_model_name)

student_model = model_manager.load_model(
    model_name=student_model_name, is_teacher=False
)
student_tokenizer = model_manager.load_tokenizer(model_name=student_model_name)

# --- Data --------------------------------------------------------------------
# The dataset is handed the teacher's tokenizer and is opened in streaming mode.
dataset = MMLUMedicalDataset(
    tokenizer=teacher_tokenizer, max_length=512, streaming=True
)
dataset.load_data()
dataset.prepare_for_distillation()

# --- Optimisation ------------------------------------------------------------
# Only the student's parameters are given to the optimizer.
optimizer = torch.optim.AdamW(student_model.parameters(), **optimizer_kwargs)
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=num_warmup_steps,
    num_training_steps=num_training_steps,
)

# --- Distillation objective and trainer --------------------------------------
distillation = VanillaKnowledgeDistillation(
    teacher_model=teacher_model,
    student_model=student_model,
    teacher_tokenizer=teacher_tokenizer,
    student_tokenizer=student_tokenizer,
    **distillation_kwargs,
)

trainer = DistillationTrainer(
    distillation=distillation,
    dataset=dataset,
    optimizer=optimizer,
    scheduler=scheduler,
    **trainer_kwargs,
)

# --- Training ----------------------------------------------------------------
# NOTE(review): the last recorded run of this cell ended with
# `KeyError: 'answer'`. No traceback was kept, so the raising call is unknown —
# presumably a dataset field lookup inside utils; verify before re-running.
metrics = trainer.train(**train_kwargs)


2025-04-06 21:16:08,834 - utils - INFO - ModelManager initialized with device: cuda
2025-04-06 21:16:08,834 - utils - INFO - Loading model: meta-llama/Llama-3.1-8B-Instruct (teacher: True)
The `load_in_4bit` and `load_in_8bit` arguments are deprecated and will be removed in the future versions. Please, pass a `BitsAndBytesConfig` object in `quantization_config` argument instead.


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

2025-04-06 21:16:17,076 - utils - INFO - Applied torch.compile to model
2025-04-06 21:16:17,077 - utils - INFO - Successfully loaded model: meta-llama/Llama-3.1-8B-Instruct
2025-04-06 21:16:17,079 - utils - INFO - GPU Memory: 8.46 GB allocated, 8.51 GB reserved
2025-04-06 21:16:17,079 - utils - INFO - Loading tokenizer for: meta-llama/Llama-3.1-8B-Instruct
2025-04-06 21:16:17,419 - utils - INFO - Successfully loaded tokenizer for: meta-llama/Llama-3.1-8B-Instruct
2025-04-06 21:16:17,421 - utils - INFO - Loading model: meta-llama/Llama-3.2-3B-Instruct (teacher: False)


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

2025-04-06 21:16:23,257 - utils - INFO - Successfully loaded model: meta-llama/Llama-3.2-3B-Instruct
2025-04-06 21:16:23,257 - utils - INFO - GPU Memory: 14.45 GB allocated, 14.51 GB reserved
2025-04-06 21:16:23,257 - utils - INFO - Loading tokenizer for: meta-llama/Llama-3.2-3B-Instruct
2025-04-06 21:16:23,653 - utils - INFO - Successfully loaded tokenizer for: meta-llama/Llama-3.2-3B-Instruct
2025-04-06 21:16:23,653 - utils - INFO - MMLUMedicalDataset initialized with max_length: 512
2025-04-06 21:16:23,654 - utils - INFO - Loading HPAI-BSC/MMLU-medical-cot-llama31 dataset
2025-04-06 21:16:24,576 - utils - INFO - Dataset loaded successfully
2025-04-06 21:16:24,577 - utils - INFO - Preparing dataset for knowledge distillation
2025-04-06 21:16:24,577 - utils - INFO - Using streaming mode - data will be processed on-the-fly
2025-04-06 21:16:24,578 - utils - INFO - Dataset prepared for distillation
2025-04-06 21:16:24,580 - utils - INFO - Knowledge Distillation initialized with temperatu

KeyError: 'answer'

In [None]:
# 10. Build the evaluator for the student model and persist the student to disk
distilled_model_dir = "./distilled_model"

evaluator = Evaluator(
    model=student_model, tokenizer=student_tokenizer, batch_size=16
)

# The model is saved first, then the tokenizer, both into the same directory.
for artifact in (student_model, student_tokenizer):
    artifact.save_pretrained(distilled_model_dir)

In [None]:
# Hand-written evaluation examples. Each record carries a prompt under 'input'
# and an expected answer under 'reference'.
# NOTE(review): how the evaluator consumes these two keys is not visible here.
pneumonia_example = {
    'input': "Question: What is the most common cause of community-acquired pneumonia?",
    'reference': "The most common cause of community-acquired pneumonia is Streptococcus pneumoniae.",
}

# Extend this list with further examples as needed.
test_data = [pneumonia_example]

# Score the examples and report the two metrics to four decimal places.
eval_results = evaluator.evaluate_model(test_data)
print(f"Perplexity: {eval_results['perplexity']:.4f}")
print(f"METEOR score: {eval_results['meteor']:.4f}")