In [None]:
# Install required packages
!pip install -q transformers==4.47.1
!pip install -q datasets==3.2.0
!pip install -q torch==2.5.1
!pip install -q lm-eval==0.4.7

In [None]:
# Import necessary libraries
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torch.optim import AdamW
from getpass import getpass
from lm_eval import evaluator, tasks, models

In [None]:
# Get Hugging Face token securely: getpass prompts for the value without
# echoing it, so the secret is never hardcoded in the notebook.
hf_token = getpass("Hugging Face: ")

In [None]:
# Login to Hugging Face by handing the token stored in `hf_token` to the CLI.
# NOTE(review): `$hf_token` is expanded into the shell command line, so the
# secret is passed as a process argument — consider a Python-level login
# instead; confirm what the target environment supports before changing.
!huggingface-cli login --token $hf_token

In [None]:
# Hub identifiers of the teacher and student checkpoints. These are only the
# names; the tokenizer and the models themselves are loaded in later cells.
# NOTE(review): judging by its name the student is presumably a pruned
# variant of the teacher — confirm against the model card.
teacher_model_name = "meta-llama/Llama-3.2-1B"
student_model_name = "oopere/pruned40-llama-3.2-1B"

In [None]:
# Initialize tokenizer - we can use the same tokenizer for both models since they're both Llama-based
tokenizer = AutoTokenizer.from_pretrained(teacher_model_name)
# Reuse the end-of-sequence token as the padding token so that
# padding="max_length" in the tokenization step has a token to pad with.
# Consequence: padding cannot be told apart from EOS by token id alone;
# use the attention mask to identify padded positions.
tokenizer.pad_token = tokenizer.eos_token


In [None]:
# Load models
# Both checkpoints are loaded with the library defaults (no dtype or device
# arguments are passed here); they are moved to the target device in a later cell.
teacher_model = AutoModelForCausalLM.from_pretrained(teacher_model_name)
student_model = AutoModelForCausalLM.from_pretrained(student_model_name)

In [None]:
# Data Loading
# Keep a handle on the complete training split ...
original_dataset = load_dataset("ptb_text_only", "penn_treebank", split="train", trust_remote_code=True)
# ... and work on its first 1000 rows only, for faster training/testing
dataset = original_dataset.select(range(1000))

In [None]:
# Create a tokenization function
def tokenize_function(examples):
    """
    Tokenize a batch of sentences and build language-modeling labels.

    Args:
    - examples: A batch mapping with a "sentence" entry holding the raw texts.

    Returns:
    - dict with "input_ids", "attention_mask" and "labels". Labels are a copy
      of the input ids in which every padded position (attention_mask == 0)
      is replaced by -100 so that it is ignored by cross-entropy style losses.
    """
    # Tokenize with padding and truncation to a fixed length of 128 tokens
    tokenized = tokenizer(
        examples["sentence"],
        padding="max_length",
        truncation=True,
        max_length=128,  # Adjusted for Llama models
        return_tensors="pt"
    )

    input_ids = tokenized["input_ids"]
    attention_mask = tokenized["attention_mask"]

    # Labels start as a copy of the inputs (clone so input_ids stay untouched).
    labels = input_ids.clone()
    # The pad token is the EOS token, so padding cannot be detected by token id;
    # the attention mask is the reliable marker. Without this masking, padded
    # positions would be treated as real prediction targets.
    labels[attention_mask == 0] = -100

    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "labels": labels
    }

In [None]:
# Process the dataset with progress bar
# tokenize_function is applied to batches of 32 examples; the original
# columns are dropped so only the fields returned by tokenize_function remain.
print("Tokenizing dataset...")
tokenized_datasets = dataset.map(
    tokenize_function,
    batched=True,
    batch_size=32,  # Number of examples handed to tokenize_function per call
    remove_columns=dataset.column_names,
    desc="Processing examples",
    load_from_cache_file=False  # Always recompute instead of reusing a cached result (debugging aid)
)

In [None]:
# Convert to PyTorch format so that indexing the dataset (and therefore the
# DataLoader built from it) yields torch tensors. This modifies
# tokenized_datasets in place.
tokenized_datasets.set_format("torch")

In [None]:
# Create DataLoader
# Batches are dictionaries with "input_ids", "attention_mask" and "labels".
# NOTE(review): shuffling is not seeded anywhere in the visible notebook, so
# batch order (and therefore training results) will differ between runs.
dataloader = DataLoader(
    tokenized_datasets,
    batch_size=8,  # Reduced batch size due to model size
    shuffle=True
)

In [None]:
# Pick the compute device: CUDA when a GPU is visible, otherwise the CPU
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Place both networks on that device
teacher_model.to(device)
student_model.to(device)

In [None]:
# Set teacher model to evaluation mode: the teacher is never trained in this
# notebook, it only provides target distributions for the student.
teacher_model.eval()

In [None]:
# Define optimizer for student model (only the student's parameters are updated)
optimizer = AdamW(student_model.parameters(), lr=1e-5)  # Reduced learning rate for Llama

# Training hyperparameters
num_epochs = 10
temperature = 2.0  # Softening temperature applied to both teacher and student logits
alpha = 1  # Weight for soft loss. NOTE(review): not referenced anywhere else in the visible notebook

# Gradients are accumulated over this many batches before each optimizer step
# (DataLoader batch size 8 x 4 steps = 32 examples per update).
accumulation_steps = 4  # Gradient accumulation for larger effective batch size


In [None]:
# Start training
# Knowledge distillation: the student is trained to match the teacher's
# temperature-softened next-token distribution at every real (non-padded) token.
num_batches = len(dataloader)

for epoch in range(num_epochs):
    ### 1 - Model Preparation ###
    # Put the student model in training mode.
    student_model.train()
    # Start each epoch from clean gradients, so that an earlier interrupted run
    # of this cell cannot leak stale gradients into the first optimizer step.
    optimizer.zero_grad()
    # Running sum of the (unscaled) per-batch losses, used for epoch reporting.
    total_loss = 0

    for batch_idx, batch in enumerate(dataloader):
        ### 2 - Data processing ###
        # Move the batch data to the appropriate device (CPU/GPU).
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)

        ### 3 - Teacher Model Inference ###
        # No gradients are needed for the teacher; only its logits are used,
        # so hidden states are not requested.
        with torch.no_grad():
            teacher_outputs = teacher_model(
                input_ids,
                attention_mask=attention_mask
            )
            # Temperature scaling softens the teacher's predictions.
            teacher_logits = teacher_outputs.logits / temperature

        ### 4 - Student Model Inference ###
        # The student produces logits for the same inputs.
        student_outputs = student_model(
            input_ids,
            attention_mask=attention_mask
        )
        student_logits = student_outputs.logits

        ### 5 - Compute loss ###
        # Teacher probabilities and student log-probabilities, both softened
        # with the same temperature.
        teacher_probs = F.softmax(teacher_logits, dim=-1)
        student_log_probs = F.log_softmax(student_logits / temperature, dim=-1)
        # KL divergence per token: reduce over the vocabulary axis only, which
        # leaves one value per (sequence, position).
        token_kl = F.kl_div(student_log_probs, teacher_probs, reduction='none').sum(dim=-1)
        # Padded positions carry no information (inputs are padded to a fixed
        # length), so they are masked out and the loss is averaged over real
        # tokens only. clamp() guards against a batch with no real tokens.
        token_mask = attention_mask.to(token_kl.dtype)
        loss = (token_kl * token_mask).sum() / token_mask.sum().clamp(min=1)

        # Size of the accumulation group this batch belongs to. The last group
        # of an epoch may hold fewer than accumulation_steps batches; dividing
        # by the real group size keeps its update from being under-weighted.
        group_start = (batch_idx // accumulation_steps) * accumulation_steps
        group_size = min(accumulation_steps, num_batches - group_start)

        ### 6 - Backward pass ###
        # Scale so that the accumulated gradient is the mean over the group.
        (loss / group_size).backward()

        ### 7 - Optimization / Gradient Accumulation ###
        # Update parameters once a full group has been accumulated, or at the
        # last batch of the epoch, then reset the gradients.
        if ((batch_idx + 1) % accumulation_steps == 0) or (batch_idx + 1 == num_batches):
            optimizer.step()
            optimizer.zero_grad()

        ### 8 - Loss Tracking ###
        # Track the unscaled batch loss (mean KL per real token).
        total_loss += loss.item()

    ### 9 - Epoch-Level Reporting ###
    # Average loss over the batches of this epoch.
    avg_loss = total_loss / num_batches
    print(f"Epoch {epoch+1}/{num_epochs}, Average Loss: {avg_loss:.4f}")


In [None]:
# Evaluation function
def evaluate_hf_model(model_name, tasks=['arc_easy'], num_fewshot=0):
    """
    Evaluate a model available on Hugging Face with lm-eval's simple_evaluate.

    Args:
    - model_name: The model name on Hugging Face (passed as `pretrained=...`).
    - tasks: Task name or list of task names to evaluate.
    - num_fewshot: Number of few-shot examples given to the evaluator.

    Returns:
    - metrics: The 'results' entry of the evaluator output, or an empty dict
      when the evaluator returns nothing or has no such entry.
    """
    # Use the GPU when one is available, otherwise fall back to the CPU so the
    # function remains usable on machines without CUDA.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model_args = f"pretrained={model_name},device={device}"

    # Hand the evaluator its own list so the (shared) default argument is never
    # exposed to mutation; a single task name is accepted as well.
    task_list = [tasks] if isinstance(tasks, str) else list(tasks)

    results = evaluator.simple_evaluate(
      model="hf",
      model_args=model_args,
      tasks=task_list,
      num_fewshot=num_fewshot,  # Number of few-shot samples (honours the argument).
      limit=None,  # Use all the samples in the evaluation dataset.
      bootstrap_iters=10
    )

    if results is None:
        return {}
    return results.get('results', {})

In [None]:
# Select tasks to evaluate.
# NOTE(review): this rebinds the name `tasks`, which the import cell also
# binds to the `lm_eval.tasks` module; after this cell `tasks` is a plain list.
tasks = ['lambada']

In [None]:
# Evaluate the Hub checkpoint named below on the selected tasks.
# evaluate_hf_model loads the model by name (`pretrained=...`), so this does
# not evaluate the in-memory student_model trained earlier in this notebook.
metrics_pruned_kd = evaluate_hf_model("AminNasiri63/pruned20-llama-1b-st", tasks=tasks)
metrics_pruned_kd