# Knowledge Distillation LLM

## Installation

In [1]:
!pip install -U -q bitsandbytes
!pip install -q datasets

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m69.7/69.7 MB[0m [31m12.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m363.4/363.4 MB[0m [31m1.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m13.8/13.8 MB[0m [31m110.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m24.6/24.6 MB[0m [31m87.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m883.7/883.7 kB[0m [31m54.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m664.8/664.8 MB[0m [31m1.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m211.5/211.5 MB[0m [31m5.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m56.3/56.3 MB[0m [31m10.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

## Load the Dataset

In [2]:
import os

os.environ['DATASET']='aisuko/phishing-binary-classification'
os.environ["TEACHER"]='aisuko/phishing-binary-classification-bert'
os.environ["STUDENT"]='aisuko/phishing-binary-classification_student'

os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

In [4]:
from datasets import load_dataset

ds=load_dataset(os.getenv('DATASET'))
ds

train-00000-of-00001.parquet:   0%|          | 0.00/22.2M [00:00<?, ?B/s]

validation-00000-of-00001.parquet:   0%|          | 0.00/2.79M [00:00<?, ?B/s]

test-00000-of-00001.parquet:   0%|          | 0.00/2.77M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/528006 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/66001 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/66001 [00:00<?, ? examples/s]

DatasetDict({
    train: Dataset({
        features: ['url', 'labels'],
        num_rows: 528006
    })
    validation: Dataset({
        features: ['url', 'labels'],
        num_rows: 66001
    })
    test: Dataset({
        features: ['url', 'labels'],
        num_rows: 66001
    })
})

In [5]:
from datasets import DatasetDict

pre_processed_ds_train_low=ds['train'].shuffle(seed=42).select(range(10000))
pre_processed_ds_test_low=ds['test'].shuffle(seed=42).select(range(5000))
pre_processed_ds_validate_low=ds['validation'].shuffle(seed=42).select(range(5000))

ds_low=DatasetDict({
    'train': pre_processed_ds_train_low,
    'test': pre_processed_ds_test_low,
    'validation': pre_processed_ds_validate_low,
})

## Load the Teacher Model

The teacher model is a fine-tuned version of [openai-community/roberta-large-openai-detector](https://huggingface.co/openai-community/roberta-large-openai-detector) model on phishing website URLs dataset, see [FT GPT-2 Detector for text classification.](https://www.kaggle.com/code/aisuko/ft-gpt-2-detector-for-binary-classification)

In [7]:
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

device=torch.device("cuda")

# load teacher model and tokenizer
tokenizer=AutoTokenizer.from_pretrained(os.getenv('TEACHER'))
teacher_model=AutoModelForSequenceClassification.from_pretrained(os.getenv('TEACHER')).to(device)
teacher_model

tokenizer_config.json:   0%|          | 0.00/1.19k [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/711k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/125 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/851 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/438M [00:00<?, ?B/s]

BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e

In [20]:
print(f'Teacher Model Parameters: {sum(p.numel() for p in teacher_model.parameters()):,}')

Teacher Model Parameters: 109,483,778


---

## Load the Student Model

In [16]:
from transformers import DistilBertForSequenceClassification, DistilBertConfig

# drop 4 heads per layer and 2 layers. Default is 12 attention heads per layer, 6 layers.
configuration=DistilBertConfig(n_heads=8, n_layers=4)
configuration

DistilBertConfig {
  "activation": "gelu",
  "attention_dropout": 0.1,
  "dim": 768,
  "dropout": 0.1,
  "hidden_dim": 3072,
  "initializer_range": 0.02,
  "max_position_embeddings": 512,
  "model_type": "distilbert",
  "n_heads": 8,
  "n_layers": 4,
  "pad_token_id": 0,
  "qa_dropout": 0.1,
  "seq_classif_dropout": 0.2,
  "sinusoidal_pos_embds": false,
  "transformers_version": "4.48.3",
  "vocab_size": 30522
}

In [17]:
student_model = DistilBertForSequenceClassification.from_pretrained("distilbert-base-uncased", config=configuration).to(device)
student_model

model.safetensors:   0%|          | 0.00/268M [00:00<?, ?B/s]

Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


DistilBertForSequenceClassification(
  (distilbert): DistilBertModel(
    (embeddings): Embeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (transformer): Transformer(
      (layer): ModuleList(
        (0-3): 4 x TransformerBlock(
          (attention): DistilBertSdpaAttention(
            (dropout): Dropout(p=0.1, inplace=False)
            (q_lin): Linear(in_features=768, out_features=768, bias=True)
            (k_lin): Linear(in_features=768, out_features=768, bias=True)
            (v_lin): Linear(in_features=768, out_features=768, bias=True)
            (out_lin): Linear(in_features=768, out_features=768, bias=True)
          )
          (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (ffn): FFN(
            (dropout): Dropout(p=0.1, inplace=False)


In [19]:
print(f'Student Model Parameters: {sum(p.numel() for p in student_model.parameters()):,}')

Student Model Parameters: 52,779,266


---

## Tokenize the text

In [21]:
def preprocess_func(examples):
    return tokenizer(examples["url"], padding='max_length', truncation=True)


# tokenized all data
tokenized_data=ds_low.map(preprocess_func, batched=True)
tokenized_data.set_format(type='torch', columns=['input_ids', 'attention_mask', 'labels'])

Map:   0%|          | 0/10000 [00:00<?, ? examples/s]

Map:   0%|          | 0/5000 [00:00<?, ? examples/s]

Map:   0%|          | 0/5000 [00:00<?, ? examples/s]

## Evaluation Function

In [22]:
from sklearn.metrics import accuracy_score, precision_recall_fscore_support

# Function to evaluate model performance
def evaluate_model(model, dataloader, device):
    model.eval()  # Set model to evaluation mode
    all_preds = []
    all_labels = []

    # Disable gradient calculations
    with torch.no_grad():
        for batch in dataloader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)

            # Forward pass to get logits
            outputs = model(input_ids, attention_mask=attention_mask)
            logits = outputs.logits

            # Get predictions
            preds = torch.argmax(logits, dim=1).cpu().numpy()
            all_preds.extend(preds)
            all_labels.extend(labels.cpu().numpy())

    # Calculate evaluation metrics
    accuracy = accuracy_score(all_labels, all_preds)
    precision, recall, f1, _ = precision_recall_fscore_support(all_labels, all_preds, average='binary')

    return accuracy, precision, recall, f1

## Train Student Model

In [23]:
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

# hyperparameters
batch_size = 32
lr = 1e-4
num_epochs = 5
temperature = 2.0
alpha = 0.5

# Function to compute distillation and hard-label loss
def distillation_loss(student_logits, teacher_logits, true_labels, temperature, alpha):
    # Compute soft targets from teacher logits
    soft_targets = nn.functional.softmax(teacher_logits / temperature, dim=1)
    student_soft = nn.functional.log_softmax(student_logits / temperature, dim=1)

    # KL Divergence loss for distillation
    distill_loss = nn.functional.kl_div(student_soft, soft_targets, reduction='batchmean') * (temperature ** 2)

    # Cross-entropy loss for hard labels
    hard_loss = nn.CrossEntropyLoss()(student_logits, true_labels)

    # Combine losses
    loss = alpha * distill_loss + (1.0 - alpha) * hard_loss

    return loss


# define optimizer
optimizer = optim.Adam(student_model.parameters(), lr=lr)

# create training data loader
dataloader = DataLoader(tokenized_data['train'], batch_size=batch_size)
# create testing data loader
test_dataloader = DataLoader(tokenized_data['test'], batch_size=batch_size)


In [24]:
student_model.train()

# train model
for epoch in range(num_epochs):
    for batch in dataloader:
        # Prepare inputs
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)

        # Disable gradient calculation for teacher model
        with torch.no_grad():
            teacher_outputs = teacher_model(input_ids, attention_mask=attention_mask)
            teacher_logits = teacher_outputs.logits

        # Forward pass through the student model
        student_outputs = student_model(input_ids, attention_mask=attention_mask)
        student_logits = student_outputs.logits

        # Compute the distillation loss
        loss = distillation_loss(student_logits, teacher_logits, labels, temperature, alpha)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    print(f"Epoch {epoch + 1} completed with loss: {loss.item()}")

    # Evaluate the teacher model
    teacher_accuracy, teacher_precision, teacher_recall, teacher_f1 = evaluate_model(teacher_model, test_dataloader, device)
    print(f"Teacher (test) - Accuracy: {teacher_accuracy:.4f}, Precision: {teacher_precision:.4f}, Recall: {teacher_recall:.4f}, F1 Score: {teacher_f1:.4f}")

    # Evaluate the student model
    student_accuracy, student_precision, student_recall, student_f1 = evaluate_model(student_model, test_dataloader, device)
    print(f"Student (test) - Accuracy: {student_accuracy:.4f}, Precision: {student_precision:.4f}, Recall: {student_recall:.4f}, F1 Score: {student_f1:.4f}")
    print("\n")

    # put student model back into train mode
    student_model.train()

Epoch 1 completed with loss: 0.1555989682674408
Teacher (test) - Accuracy: 0.8198, Precision: 0.7737, Recall: 0.9253, F1 Score: 0.8427
Student (test) - Accuracy: 0.9586, Precision: 0.9556, Recall: 0.9655, F1 Score: 0.9605


Epoch 2 completed with loss: 0.16136083006858826
Teacher (test) - Accuracy: 0.8198, Precision: 0.7737, Recall: 0.9253, F1 Score: 0.8427
Student (test) - Accuracy: 0.9544, Precision: 0.9795, Recall: 0.9322, F1 Score: 0.9552


Epoch 3 completed with loss: 0.15648235380649567
Teacher (test) - Accuracy: 0.8198, Precision: 0.7737, Recall: 0.9253, F1 Score: 0.8427
Student (test) - Accuracy: 0.9526, Precision: 0.9386, Recall: 0.9728, F1 Score: 0.9554


Epoch 4 completed with loss: 0.161363422870636
Teacher (test) - Accuracy: 0.8198, Precision: 0.7737, Recall: 0.9253, F1 Score: 0.8427
Student (test) - Accuracy: 0.9498, Precision: 0.9739, Recall: 0.9287, F1 Score: 0.9508


Epoch 5 completed with loss: 0.1568385511636734
Teacher (test) - Accuracy: 0.8198, Precision: 0.7737, R

## Evaluate Models

In [25]:
# create testing data loader
validation_dataloader = DataLoader(tokenized_data['validation'], batch_size=8)

# Evaluate the teacher model
teacher_accuracy, teacher_precision, teacher_recall, teacher_f1 = evaluate_model(teacher_model, validation_dataloader, device)
print(f"Teacher (validation) - Accuracy: {teacher_accuracy:.4f}, Precision: {teacher_precision:.4f}, Recall: {teacher_recall:.4f}, F1 Score: {teacher_f1:.4f}")

# Evaluate the student model
student_accuracy, student_precision, student_recall, student_f1 = evaluate_model(student_model, validation_dataloader, device)
print(f"Student (validation) - Accuracy: {student_accuracy:.4f}, Precision: {student_precision:.4f}, Recall: {student_recall:.4f}, F1 Score: {student_f1:.4f}")

Teacher (validation) - Accuracy: 0.8340, Precision: 0.7911, Recall: 0.9245, F1 Score: 0.8526
Student (validation) - Accuracy: 0.9542, Precision: 0.9728, Recall: 0.9380, F1 Score: 0.9551


## Model Quantization

In [26]:
from transformers import BitsAndBytesConfig

quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype = torch.bfloat16,
    bnb_4bit_use_double_quant=True
)

model_nf4 = AutoModelForSequenceClassification.from_pretrained(os.getenv('STUDENT'), device_map=device, quantization_config=quantization_config)

config.json:   0%|          | 0.00/541 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/211M [00:00<?, ?B/s]

## Post Quantization Evaluation

In [27]:
quantized_accuracy, quantized_precision, quantized_recall, quantized_f1 = evaluate_model(model_nf4, validation_dataloader, device)

print("Post-quantization Performance")
print(f"Accuracy: {quantized_accuracy:.4f}, Precision: {quantized_precision:.4f}, Recall: {quantized_recall:.4f}, F1 Score: {quantized_f1:.4f}")

Post-quantization Performance
Accuracy: 0.9464, Precision: 0.9755, Recall: 0.9199, F1 Score: 0.9469


## Reference:
https://www.kaggle.com/code/aisuko/compressing-a-llm-with-distillation-quantization