# Knowledge Distillation and Quantization for Neural Networks

## References
[Compressing Large Language Models (LLMs) | w/ Python Code](https://www.youtube.com/watch?v=FLkUOkeMd5M&ab_channel=ShawTalebi)

In [19]:
import os
import torch
import torch.nn as nn
import torch.optim as optim

from typing import List, Dict, Any, Tuple
from datasets import load_dataset
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    BitsAndBytesConfig,
)
from transformers import DistilBertForSequenceClassification, DistilBertConfig
from torch.utils.data import DataLoader
from torch.functional import F
from sklearn.metrics import accuracy_score, precision_recall_fscore_support

# Data Processing

In [None]:
# Fetch the phishing-site classification dataset by its Hub identifier
data = load_dataset(path="shawhin/phishing-site-classification")
# Bare expression so the notebook displays the loaded object
data

In [None]:
# Show the first three training records as a sanity check
train_split = data["train"]
for row_index in (0, 1, 2):
    print(train_split[row_index])

# Load Models

In [13]:
# The teacher checkpoint provides both the fine-tuned classifier and the
# tokenizer used for every model in this notebook
teacher_model_name = "shawhin/bert-phishing-classifier_teacher"
teacher_model = AutoModelForSequenceClassification.from_pretrained(
    pretrained_model_name_or_path=teacher_model_name
)
tokenizer = AutoTokenizer.from_pretrained(
    pretrained_model_name_or_path=teacher_model_name
)

In [None]:
# Student architecture: 4 transformer layers with 8 attention heads each.
# For reference, distilbert-base-uncased has 6 layers and 12 heads, a context
# window of 512 tokens and an embedding size of 768, so the student is built
# with 2 fewer layers and 4 fewer heads per layer.
# NOTE(review): changing n_heads presumably re-partitions the pretrained
# attention projections rather than discarding specific heads - confirm
# before describing this as head pruning.
student_model_config = DistilBertConfig(n_layers=4, n_heads=8)
student_model = DistilBertForSequenceClassification.from_pretrained(
    "distilbert-base-uncased",
    config=student_model_config,
)

# Tokenization

In [None]:
# Text preprocessing applied to every batch of examples
def preprocess_function(examples):
    """Tokenize the "text" field, truncating and padding to the maximum length."""
    texts = examples["text"]
    encoded = tokenizer(texts, truncation=True, padding="max_length")
    return encoded


# Tokenize every split of the dataset in batches
tokenized_data = data.map(function=preprocess_function, batched=True)
# Return the model inputs and the labels as torch tensors
tokenized_data.set_format(
    "torch",
    columns=["input_ids", "attention_mask", "labels"],
)

# Helper Functions

In [17]:
# Compute classification metrics for a model over a DataLoader
def evaluate_model(
    model: nn.Module,
    dataloader: DataLoader,
    device: torch.device,
) -> Tuple[float, float, float, float]:
    """
    Score a sequence-classification model on every batch of ``dataloader``.

    Each batch must provide "input_ids", "attention_mask" and "labels".
    The predicted class is the argmax of the model's logits.

    Returns:
        (accuracy, precision, recall, f1), where precision, recall and f1 use
        binary averaging.

    Note:
        The model is switched to evaluation mode and is left in that mode;
        callers that keep training must call ``model.train()`` afterwards.
    """

    model.eval()
    predictions = []
    targets = []

    # Inference only: no gradients are needed
    with torch.no_grad():
        for batch in dataloader:
            token_ids = batch["input_ids"].to(device)
            mask = batch["attention_mask"].to(device)
            batch_labels = batch["labels"].to(device)

            logits = model(token_ids, attention_mask=mask).logits

            # Highest-scoring class per example
            predicted_classes = logits.argmax(dim=1)
            predictions.extend(predicted_classes.cpu().numpy())
            targets.extend(batch_labels.cpu().numpy())

    accuracy = accuracy_score(targets, predictions)
    precision, recall, f1, _ = precision_recall_fscore_support(
        targets, predictions, average="binary"
    )

    return accuracy, precision, recall, f1

In [16]:
# Function to compute distillation and hard-label loss
def distillation_loss(
    student_logits: torch.FloatTensor,
    teacher_logits: torch.FloatTensor,
    true_labels: torch.LongTensor,
    temperature: float,
    alpha: float,
) -> torch.FloatTensor:
    """
    Compute the knowledge distillation loss by combining:
      - KL divergence between the temperature-softened student and teacher
        distributions
      - Hard-label cross-entropy between the student logits and the ground truth
    Original paper: "Distilling the Knowledge in a Neural Network"

    Args:
        student_logits: Raw student scores; classes are on dim 1.
        teacher_logits: Raw teacher scores with the same layout.
        true_labels: Ground-truth class indices, one per example.
        temperature: Softmax temperature, must be > 0. Typical values are
            1.0 - 20.0, with 2.0 being common.
        alpha: Weight of the distillation term; the hard-label term is
            weighted by ``1 - alpha``. Typical values: 0.3 and 0.7.

    Returns:
        Scalar tensor ``alpha * distill_loss + (1 - alpha) * hard_loss``.

    Raises:
        ValueError: If ``temperature`` is not strictly positive.
    """

    # Dividing logits by a zero temperature yields inf/NaN and a negative one
    # inverts the ranking, so fail loudly instead of returning a bogus loss
    if temperature <= 0:
        raise ValueError(f"temperature must be positive, got {temperature}")

    # Soften both distributions with the same temperature
    teacher_probs = F.softmax(teacher_logits / temperature, dim=1)
    student_log_probs = F.log_softmax(student_logits / temperature, dim=1)

    # kl_div expects log-probabilities as input and probabilities as target.
    # The T^2 factor keeps the soft-target gradients on a scale comparable to
    # the hard-label term when the temperature changes.
    distill_loss = F.kl_div(
        student_log_probs, teacher_probs, reduction="batchmean"
    ) * (temperature**2)

    # Hard-label cross-entropy on the unscaled student logits
    hard_loss = F.cross_entropy(student_logits, true_labels)

    # Weighted combination of the two terms
    return alpha * distill_loss + (1.0 - alpha) * hard_loss

In [20]:
# Smoke-test the distillation loss on a random batch of 4 examples, 2 classes
temperature = 2.0
alpha = 0.5
true_labels = torch.tensor([0, 1, 0, 1])
teacher_logits = torch.randn(4, 2)
student_logits = torch.randn(4, 2)

loss = distillation_loss(
    student_logits=student_logits,
    teacher_logits=teacher_logits,
    true_labels=true_labels,
    temperature=temperature,
    alpha=alpha,
)

# Teacher Model Pruning

# Student Model Training

In [21]:
# Training hyperparameters
batch_size = 32
num_epochs = 5
lr = 1e-4

# Distillation settings: softmax temperature and weight of the soft-target term
temperature = 2.0
alpha = 0.5

# Only the student's parameters are handed to the optimizer
optimizer = optim.AdamW(student_model.parameters(), lr=lr)

# Use the GPU when one is available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Training batches are reshuffled on every pass
dataloader = DataLoader(
    dataset=tokenized_data["train"],
    batch_size=batch_size,
    shuffle=True,
)

# Test batches keep their original order
test_dataloader = DataLoader(
    dataset=tokenized_data["test"],
    batch_size=batch_size,
    shuffle=False,
)

In [None]:
# Move models to device
student_model.to(device)
teacher_model.to(device)

# The teacher only supplies soft targets and is never updated, so keep it in
# evaluation mode and measure its test metrics once up front instead of
# repeating the same full test-set pass after every epoch. The identical
# report line is still printed each epoch for side-by-side comparison.
teacher_model.eval()
teacher_accuracy, teacher_precision, teacher_recall, teacher_f1 = evaluate_model(
    teacher_model, test_dataloader, device
)
teacher_report = (
    f"Teacher (test) - Accuracy: {teacher_accuracy:.4f}, "
    f"Precision: {teacher_precision:.4f}, "
    f"Recall: {teacher_recall:.4f}, "
    f"F1 Score: {teacher_f1:.4f}"
)

# put student model in train mode
student_model.train()

# train model
for epoch in range(num_epochs):
    for batch in dataloader:
        # Prepare inputs
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        labels = batch["labels"].to(device)

        # Teacher forward pass without gradients: it is not being trained
        with torch.no_grad():
            teacher_outputs = teacher_model(input_ids, attention_mask=attention_mask)
            teacher_logits = teacher_outputs.logits

        # Forward pass through the student model
        student_outputs = student_model(input_ids, attention_mask=attention_mask)
        student_logits = student_outputs.logits

        # Compute the distillation loss
        loss = distillation_loss(
            student_logits, teacher_logits, labels, temperature, alpha
        )

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # The reported value is the loss of the last mini-batch of the epoch
    print(f"Epoch {epoch + 1} completed with loss: {loss.item()}")

    # Teacher metrics were computed once before training (see above)
    print(teacher_report)

    # Evaluate the student model (this switches it to eval mode)
    student_accuracy, student_precision, student_recall, student_f1 = evaluate_model(
        student_model, test_dataloader, device
    )
    print(
        f"Student (test) - Accuracy: {student_accuracy:.4f}, "
        f"Precision: {student_precision:.4f}, "
        f"Recall: {student_recall:.4f}, "
        f"F1 Score: {student_f1:.4f}"
    )
    print("\n")

    # put student model back into train mode
    student_model.train()

# Evaluate Models

In [None]:
# Build a data loader over the validation split
validation_dataloader = DataLoader(
    dataset=tokenized_data["validation"],
    batch_size=8,
)

# Teacher metrics on the validation split
teacher_accuracy, teacher_precision, teacher_recall, teacher_f1 = evaluate_model(
    teacher_model, validation_dataloader, device
)
print(
    f"Teacher (validation) - Accuracy: {teacher_accuracy:.4f}, "
    f"Precision: {teacher_precision:.4f}, "
    f"Recall: {teacher_recall:.4f}, "
    f"F1 Score: {teacher_f1:.4f}"
)

# Student metrics on the same split
student_accuracy, student_precision, student_recall, student_f1 = evaluate_model(
    student_model, validation_dataloader, device
)
print(
    f"Student (validation) - Accuracy: {student_accuracy:.4f}, "
    f"Precision: {student_precision:.4f}, "
    f"Recall: {student_recall:.4f}, "
    f"F1 Score: {student_f1:.4f}"
)

# Save the Student Model Locally

In [None]:
# Persist the distilled student to a local directory
student_model.save_pretrained(
    save_directory="models/phishing-site-classifier_student"
)

# Example: reloading the saved student from disk later on
# from transformers import AutoModelForSequenceClassification

# # Load the student model from disk
# loaded_student_model = AutoModelForSequenceClassification.from_pretrained(
#     "models/phishing-site-classifier_student"
# )

# Quantization

In [None]:
# Use the GPU when one is available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Quantization settings: 4-bit NF4 weights with double quantization,
# computing in bfloat16
# NOTE(review): bitsandbytes 4-bit loading typically needs a CUDA GPU -
# confirm the CPU fallback works in your environment
nf4_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)

# Reload the saved student from disk with the quantization config applied
student_model_nf4 = AutoModelForSequenceClassification.from_pretrained(
    "models/phishing-site-classifier_student",
    quantization_config=nf4_config,
    device_map=device,
)

In [None]:
# Score the 4-bit student on the validation split
quantized_accuracy, quantized_precision, quantized_recall, quantized_f1 = evaluate_model(
    student_model_nf4, validation_dataloader, device
)

print("Post-quantization Performance")
print(
    f"Accuracy: {quantized_accuracy:.4f}, "
    f"Precision: {quantized_precision:.4f}, "
    f"Recall: {quantized_recall:.4f}, "
    f"F1 Score: {quantized_f1:.4f}"
)

In [None]:
# Persist the 4-bit student to its own local directory
student_model_nf4.save_pretrained(
    save_directory="models/phishing-site-classifier_student_nf4"
)

# Evaluate size difference between student and quantized student models

In [None]:
# Function to get model size on disk
def get_model_size(filepath):
    """
    Return the on-disk size of ``filepath`` in megabytes (1 MB = 1024**2 bytes).

    ``filepath`` may be a single file or a directory. For a directory, the
    sizes of all regular files below it are summed recursively, because
    ``os.path.getsize`` on a directory reports only the directory entry,
    not its contents.

    Raises:
        OSError: If ``filepath`` does not exist or cannot be read.
    """
    if os.path.isdir(filepath):
        size_in_bytes = 0
        for root, _dirs, filenames in os.walk(filepath):
            for filename in filenames:
                candidate = os.path.join(root, filename)
                # Skip broken links and other non-regular entries
                if os.path.isfile(candidate):
                    size_in_bytes += os.path.getsize(candidate)
    else:
        size_in_bytes = os.path.getsize(filepath)

    size_in_mb = size_in_bytes / (1024**2)  # Convert to MB
    return size_in_mb