# Model Compression using Knowledge Distillation

In [None]:
!pip install datasets transformers torch scikit-learn

from datasets import load_dataset

from transformers import AutoTokenizer, AutoModelForSequenceClassification
from transformers import DistilBertForSequenceClassification, DistilBertConfig

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

from sklearn.metrics import accuracy_score, precision_recall_fscore_support

In [6]:
data = load_dataset("Arnav0805/phishing-site-classification")
data

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


README.md:   0%|          | 0.00/515 [00:00<?, ?B/s]

train-00000-of-00001.parquet:   0%|          | 0.00/96.9k [00:00<?, ?B/s]

validation-00000-of-00001.parquet:   0%|          | 0.00/21.1k [00:00<?, ?B/s]

test-00000-of-00001.parquet:   0%|          | 0.00/23.9k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/2100 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/450 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/450 [00:00<?, ? examples/s]

DatasetDict({
    train: Dataset({
        features: ['text', 'labels'],
        num_rows: 2100
    })
    validation: Dataset({
        features: ['text', 'labels'],
        num_rows: 450
    })
    test: Dataset({
        features: ['text', 'labels'],
        num_rows: 450
    })
})

In [7]:
device = torch.device('cuda')

In [8]:
model_path = "Arnav0805/bert-phishing-classifier_teacher"

tokenizer = AutoTokenizer.from_pretrained(model_path)
teacher_model = AutoModelForSequenceClassification.from_pretrained(model_path).to(device)

tokenizer_config.json:   0%|          | 0.00/1.22k [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/711k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/125 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/851 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/438M [00:00<?, ?B/s]

In [9]:
my_config = DistilBertConfig(n_heads=8, n_layers=4) # drop 4 heads per layer and 2 layers

student_model = DistilBertForSequenceClassification.from_pretrained("distilbert-base-uncased", config=my_config,).to(device)

model.safetensors:   0%|          | 0.00/268M [00:00<?, ?B/s]

Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


### Tokenize Text

In [10]:
def preprocess_function(examples):
    return tokenizer(examples["text"], padding='max_length', truncation=True)

tokenized_data = data.map(preprocess_function, batched=True)
tokenized_data.set_format(type='torch', columns=['input_ids', 'attention_mask', 'labels'])

Map:   0%|          | 0/2100 [00:00<?, ? examples/s]

Map:   0%|          | 0/450 [00:00<?, ? examples/s]

Map:   0%|          | 0/450 [00:00<?, ? examples/s]

### Evaluation Function

In [11]:
def evaluate_model(model, dataloader, device):
    model.eval()  # Set model to evaluation mode
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for batch in dataloader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)

            # Forward pass to get logits
            outputs = model(input_ids, attention_mask=attention_mask)
            logits = outputs.logits

            # Get predictions
            preds = torch.argmax(logits, dim=1).cpu().numpy()
            all_preds.extend(preds)
            all_labels.extend(labels.cpu().numpy())

    accuracy = accuracy_score(all_labels, all_preds)
    precision, recall, f1, _ = precision_recall_fscore_support(all_labels, all_preds, average='binary')

    return accuracy, precision, recall, f1

### Train Student Model

In [12]:
# Function to compute distillation and hard-label loss
def distillation_loss(student_logits, teacher_logits, true_labels, temperature, alpha):
    # Compute soft targets from teacher logits
    soft_targets = nn.functional.softmax(teacher_logits / temperature, dim=1)
    student_soft = nn.functional.log_softmax(student_logits / temperature, dim=1)

    # KL Divergence loss for distillation
    distill_loss = nn.functional.kl_div(student_soft, soft_targets, reduction='batchmean') * (temperature ** 2)

    # Cross-entropy loss for hard labels
    hard_loss = nn.CrossEntropyLoss()(student_logits, true_labels)

    # Combine losses
    loss = alpha * distill_loss + (1.0 - alpha) * hard_loss

    return loss

In [13]:
batch_size = 32
lr = 1e-4
num_epochs = 5
temperature = 2.0
alpha = 0.5

optimizer = optim.Adam(student_model.parameters(), lr=lr)

dataloader = DataLoader(tokenized_data['train'], batch_size=batch_size)
test_dataloader = DataLoader(tokenized_data['test'], batch_size=batch_size)

In [14]:
student_model.train()

for epoch in range(num_epochs):
    for batch in dataloader:

        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)

        # Disable gradient calculation for teacher model
        with torch.no_grad():
            teacher_outputs = teacher_model(input_ids, attention_mask=attention_mask)
            teacher_logits = teacher_outputs.logits

        # Forward pass through the student model
        student_outputs = student_model(input_ids, attention_mask=attention_mask)
        student_logits = student_outputs.logits

        # Compute the distillation loss
        loss = distillation_loss(student_logits, teacher_logits, labels, temperature, alpha)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    print(f"Epoch {epoch + 1} completed with loss: {loss.item()}")

    teacher_accuracy, teacher_precision, teacher_recall, teacher_f1 = evaluate_model(teacher_model, test_dataloader, device)
    print(f"Teacher (test) - Accuracy: {teacher_accuracy:.4f}, Precision: {teacher_precision:.4f}, Recall: {teacher_recall:.4f}, F1 Score: {teacher_f1:.4f}")

    student_accuracy, student_precision, student_recall, student_f1 = evaluate_model(student_model, test_dataloader, device)
    print(f"Student (test) - Accuracy: {student_accuracy:.4f}, Precision: {student_precision:.4f}, Recall: {student_recall:.4f}, F1 Score: {student_f1:.4f}")
    print("\n")

    student_model.train()

Epoch 1 completed with loss: 0.20644226670265198
Teacher (test) - Accuracy: 0.8667, Precision: 0.8858, Recall: 0.8472, F1 Score: 0.8661
Student (test) - Accuracy: 0.8867, Precision: 0.9009, Recall: 0.8734, F1 Score: 0.8869


Epoch 2 completed with loss: 0.08574569970369339
Teacher (test) - Accuracy: 0.8667, Precision: 0.8858, Recall: 0.8472, F1 Score: 0.8661
Student (test) - Accuracy: 0.8956, Precision: 0.9252, Recall: 0.8646, F1 Score: 0.8939


Epoch 3 completed with loss: 0.10440117120742798
Teacher (test) - Accuracy: 0.8667, Precision: 0.8858, Recall: 0.8472, F1 Score: 0.8661
Student (test) - Accuracy: 0.8911, Precision: 0.9018, Recall: 0.8821, F1 Score: 0.8918


Epoch 4 completed with loss: 0.05450151860713959
Teacher (test) - Accuracy: 0.8667, Precision: 0.8858, Recall: 0.8472, F1 Score: 0.8661
Student (test) - Accuracy: 0.9133, Precision: 0.8831, Recall: 0.9563, F1 Score: 0.9182


Epoch 5 completed with loss: 0.07484656572341919
Teacher (test) - Accuracy: 0.8667, Precision: 0.885

### Evaluate Models

In [15]:
validation_dataloader = DataLoader(tokenized_data['validation'], batch_size=8)

teacher_accuracy, teacher_precision, teacher_recall, teacher_f1 = evaluate_model(teacher_model, validation_dataloader, device)
print(f"Teacher (validation) - Accuracy: {teacher_accuracy:.4f}, Precision: {teacher_precision:.4f}, Recall: {teacher_recall:.4f}, F1 Score: {teacher_f1:.4f}")

student_accuracy, student_precision, student_recall, student_f1 = evaluate_model(student_model, validation_dataloader, device)
print(f"Student (validation) - Accuracy: {student_accuracy:.4f}, Precision: {student_precision:.4f}, Recall: {student_recall:.4f}, F1 Score: {student_f1:.4f}")

Teacher (validation) - Accuracy: 0.8844, Precision: 0.9100, Recall: 0.8533, F1 Score: 0.8807
Student (validation) - Accuracy: 0.9244, Precision: 0.9484, Recall: 0.8978, F1 Score: 0.9224


In [19]:
from huggingface_hub import notebook_login

notebook_login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [20]:
student_model.push_to_hub("Arnav0805/bert-phishing-classifier_student")

model.safetensors:   0%|          | 0.00/211M [00:00<?, ?B/s]

CommitInfo(commit_url='https://huggingface.co/Arnav0805/bert-phishing-classifier_student/commit/6eaa48f488482b84538d20b45a9b538c4a4776f1', commit_message='Upload DistilBertForSequenceClassification', commit_description='', oid='6eaa48f488482b84538d20b45a9b538c4a4776f1', pr_url=None, repo_url=RepoUrl('https://huggingface.co/Arnav0805/bert-phishing-classifier_student', endpoint='https://huggingface.co', repo_type='model', repo_id='Arnav0805/bert-phishing-classifier_student'), pr_revision=None, pr_num=None)