# Knowledge Distillation Example

Code authored by: Shaw Talebi

[Video](https://youtu.be/FLkUOkeMd5M) <br>
[Blog](https://towardsdatascience.com/compressing-large-language-models-llms-9f406eea5b5e) <br>
[Colab](https://colab.research.google.com/drive/1B6kHngy3cC2i1VFa4saG9tliCvPPoFw0?usp=sharing)

In [None]:
# Install the Hugging Face libraries imported below:
# `datasets` (provides load_dataset) and `transformers` (tokenizer and model classes)
!pip install datasets
!pip install transformers



### imports

In [None]:
from datasets import load_dataset

from transformers import AutoTokenizer, AutoModelForSequenceClassification
from transformers import DistilBertForSequenceClassification, DistilBertConfig

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

from sklearn.metrics import accuracy_score, precision_recall_fscore_support

### load data

In [None]:
# Load the phishing-site classification dataset by its Hugging Face repository id
data = load_dataset("shawhin/phishing-site-classification")
# Bare expression: shows the dataset summary (splits, features, row counts) as the cell output
data

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


DatasetDict({
    train: Dataset({
        features: ['text', 'labels'],
        num_rows: 2100
    })
    validation: Dataset({
        features: ['text', 'labels'],
        num_rows: 450
    })
    test: Dataset({
        features: ['text', 'labels'],
        num_rows: 450
    })
})

### load models

In [None]:
# Prefer the GPU, but fall back to the CPU so the notebook still runs
# (slowly) on machines where CUDA is not available.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
# Teacher checkpoint: the tokenizer and the classifier are both read from it
model_path = "shawhin/bert-phishing-classifier_teacher"

# Tokenizer that turns raw text into model inputs
tokenizer = AutoTokenizer.from_pretrained(model_path)

# Load the teacher weights first, then move the model onto the target device
teacher_model = AutoModelForSequenceClassification.from_pretrained(model_path)
teacher_model = teacher_model.to(device)

In [None]:
# Student architecture: a DistilBERT that is smaller than the default one
# (4 fewer attention heads per layer and 2 fewer layers)
my_config = DistilBertConfig(n_heads=8, n_layers=4)

# Start from the pretrained DistilBERT checkpoint, shaped by the reduced config
student_model = DistilBertForSequenceClassification.from_pretrained(
    "distilbert-base-uncased",
    config=my_config,
)
# Move the student onto the same device as the teacher
student_model = student_model.to(device)

Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


### tokenize text

In [None]:
# define text preprocessing
def preprocess_function(examples):
    """Tokenize the "text" field of `examples` with the module-level tokenizer.

    Sequences are padded with ``padding='max_length'`` and truncated when
    too long. Returns whatever the tokenizer call returns.
    """
    texts = examples["text"]
    encoded = tokenizer(texts, padding='max_length', truncation=True)
    return encoded

# tokenize all datasets, passing examples to preprocess_function in batches
tokenized_data = data.map(preprocess_function, batched=True)
# expose the model inputs and the labels as PyTorch tensors
tokenized_data.set_format(type='torch', columns=['input_ids', 'attention_mask', 'labels'])

### evaluation function

In [None]:
# Function to evaluate model performance
def evaluate_model(model, dataloader, device):
    """Score `model` on every batch of `dataloader`.

    Args:
        model: Model called as ``model(input_ids, attention_mask=...)`` whose
            output exposes ``.logits``. It is switched to evaluation mode
            and left in that mode when the function returns.
        dataloader: Iterable of batches with 'input_ids', 'attention_mask'
            and 'labels' entries.
        device: Device the input tensors are moved to before the forward pass.

    Returns:
        Tuple ``(accuracy, precision, recall, f1)``; precision, recall and F1
        are computed with ``average='binary'``.
    """
    model.eval()

    predictions = []
    targets = []

    # No gradients are needed for scoring
    with torch.no_grad():
        for batch in dataloader:
            logits = model(
                batch['input_ids'].to(device),
                attention_mask=batch['attention_mask'].to(device),
            ).logits

            # Predicted class = index of the largest logit along dim 1
            predictions.extend(logits.argmax(dim=1).cpu().numpy())
            targets.extend(batch['labels'].cpu().numpy())

    accuracy = accuracy_score(targets, predictions)
    precision, recall, f1, _ = precision_recall_fscore_support(
        targets, predictions, average='binary'
    )
    return accuracy, precision, recall, f1

### train student model

In [None]:
# Function to compute distillation and hard-label loss
def distillation_loss(student_logits, teacher_logits, true_labels, temperature, alpha):
    """Blend a soft-target (distillation) loss with a hard-label loss.

    Args:
        student_logits: Student scores; classes are laid out along dim 1.
        teacher_logits: Teacher scores with the same layout.
        true_labels: Ground-truth targets passed to the cross-entropy loss.
        temperature: Divisor applied to both sets of logits before the
            softmax; the KL term is multiplied by ``temperature ** 2``.
        alpha: Weight of the distillation term; the hard-label term is
            weighted by ``1.0 - alpha``.

    Returns:
        ``alpha * distillation_term + (1.0 - alpha) * cross_entropy``.
    """
    F = nn.functional

    # Temperature-scaled distributions: probabilities for the teacher,
    # log-probabilities for the student (kl_div takes its input in log space)
    teacher_probs = F.softmax(teacher_logits / temperature, dim=1)
    student_log_probs = F.log_softmax(student_logits / temperature, dim=1)

    # Soft-target term: batch-mean KL divergence, scaled by temperature squared
    soft_term = (temperature ** 2) * F.kl_div(student_log_probs, teacher_probs, reduction='batchmean')

    # Hard-label term: cross-entropy on the unscaled student logits
    hard_term = F.cross_entropy(student_logits, true_labels)

    return alpha * soft_term + (1.0 - alpha) * hard_term

In [None]:
# hyperparameters
batch_size = 32      # examples per batch for both data loaders below
lr = 1e-4            # learning rate passed to Adam
num_epochs = 5       # number of passes over the training split
temperature = 2.0    # divisor applied to the logits inside distillation_loss
alpha = 0.5          # weight of the distillation term; (1 - alpha) goes to the hard-label term

# define optimizer (only the student's parameters are optimized)
optimizer = optim.Adam(student_model.parameters(), lr=lr)

# create training data loader; reshuffle every epoch so the batches are not
# always the same examples in dataset order
dataloader = DataLoader(tokenized_data['train'], batch_size=batch_size, shuffle=True)
# create testing data loader (order does not affect the metrics, so no shuffling)
test_dataloader = DataLoader(tokenized_data['test'], batch_size=batch_size)

In [None]:
# The teacher is never updated in this loop, so its test metrics are identical
# after every epoch. Evaluate it once up front instead of running a full pass
# over the test set per epoch. evaluate_model also puts the teacher in
# evaluation mode before it is used to produce the soft targets.
teacher_accuracy, teacher_precision, teacher_recall, teacher_f1 = evaluate_model(teacher_model, test_dataloader, device)

# put student model in train mode
student_model.train()

# train model
for epoch in range(num_epochs):
    for batch in dataloader:
        # Prepare inputs
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)

        # Disable gradient calculation for teacher model
        with torch.no_grad():
            teacher_outputs = teacher_model(input_ids, attention_mask=attention_mask)
            teacher_logits = teacher_outputs.logits

        # Forward pass through the student model
        student_outputs = student_model(input_ids, attention_mask=attention_mask)
        student_logits = student_outputs.logits

        # Compute the distillation loss
        loss = distillation_loss(student_logits, teacher_logits, labels, temperature, alpha)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # NOTE: this is the loss of the last batch of the epoch, not an epoch average
    print(f"Epoch {epoch + 1} completed with loss: {loss.item()}")

    # Report the teacher metrics computed once before the loop
    print(f"Teacher (test) - Accuracy: {teacher_accuracy:.4f}, Precision: {teacher_precision:.4f}, Recall: {teacher_recall:.4f}, F1 Score: {teacher_f1:.4f}")

    # Evaluate the student model
    student_accuracy, student_precision, student_recall, student_f1 = evaluate_model(student_model, test_dataloader, device)
    print(f"Student (test) - Accuracy: {student_accuracy:.4f}, Precision: {student_precision:.4f}, Recall: {student_recall:.4f}, F1 Score: {student_f1:.4f}")
    print("\n")

    # put student model back into train mode (evaluate_model switched it to eval mode)
    student_model.train()

Epoch 1 completed with loss: 0.13162189722061157
Teacher (test) - Accuracy: 0.8667, Precision: 0.8967, Recall: 0.8341, F1 Score: 0.8643
Student (test) - Accuracy: 0.8778, Precision: 0.8655, Recall: 0.8996, F1 Score: 0.8822


Epoch 2 completed with loss: 0.07712383568286896
Teacher (test) - Accuracy: 0.8667, Precision: 0.8967, Recall: 0.8341, F1 Score: 0.8643
Student (test) - Accuracy: 0.9000, Precision: 0.9259, Recall: 0.8734, F1 Score: 0.8989


Epoch 3 completed with loss: 0.06931211799383163
Teacher (test) - Accuracy: 0.8667, Precision: 0.8967, Recall: 0.8341, F1 Score: 0.8643
Student (test) - Accuracy: 0.8978, Precision: 0.8961, Recall: 0.9039, F1 Score: 0.9000


Epoch 4 completed with loss: 0.07302902638912201
Teacher (test) - Accuracy: 0.8667, Precision: 0.8967, Recall: 0.8341, F1 Score: 0.8643
Student (test) - Accuracy: 0.9022, Precision: 0.9302, Recall: 0.8734, F1 Score: 0.9009


Epoch 5 completed with loss: 0.05635004863142967
Teacher (test) - Accuracy: 0.8667, Precision: 0.896

### Evaluate Models

In [None]:
# create validation data loader (the validation split, in batches of 8)
validation_dataloader = DataLoader(tokenized_data['validation'], batch_size=8)

# Evaluate the teacher model
teacher_accuracy, teacher_precision, teacher_recall, teacher_f1 = evaluate_model(teacher_model, validation_dataloader, device)
print(f"Teacher (validation) - Accuracy: {teacher_accuracy:.4f}, Precision: {teacher_precision:.4f}, Recall: {teacher_recall:.4f}, F1 Score: {teacher_f1:.4f}")

# Evaluate the student model
student_accuracy, student_precision, student_recall, student_f1 = evaluate_model(student_model, validation_dataloader, device)
print(f"Student (validation) - Accuracy: {student_accuracy:.4f}, Precision: {student_precision:.4f}, Recall: {student_recall:.4f}, F1 Score: {student_f1:.4f}")

Teacher (validation) - Accuracy: 0.8889, Precision: 0.9070, Recall: 0.8667, F1 Score: 0.8864
Student (validation) - Accuracy: 0.9289, Precision: 0.9707, Recall: 0.8844, F1 Score: 0.9256


### Push to Hub

In [None]:
from huggingface_hub import notebook_login

# Log in to the Hugging Face Hub from inside the notebook (renders a login widget)
notebook_login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [None]:
# Upload the trained student model to the Hugging Face Hub repository named below
student_model.push_to_hub("shawhin/bert-phishing-classifier_student")

README.md:   0%|          | 0.00/31.0 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/211M [00:00<?, ?B/s]

CommitInfo(commit_url='https://huggingface.co/shawhin/bert-phishing-classifier_student/commit/37c41cfff9c41478eb4a341bd3f54d3a8e7ea322', commit_message='Upload DistilBertForSequenceClassification', commit_description='', oid='37c41cfff9c41478eb4a341bd3f54d3a8e7ea322', pr_url=None, pr_revision=None, pr_num=None)