<a href="https://colab.research.google.com/github/ShubhamW248/Knowledge-Distillation/blob/main/KD.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Knowledge Distillation: Training Efficient Student Models from Large Teacher Networks

## Project Overview

**Knowledge Distillation** is a model compression technique where a smaller "student" model learns to mimic the behavior of a larger "teacher" model. This approach enables us to create compact models that retain much of the performance of their larger counterparts while being significantly faster and more memory-efficient.

### Models Used:
- **Teacher Model**: BERT-base-uncased (110M parameters) - A large, high-performance transformer
- **Student Model**: DistilBERT-base-uncased (66M parameters) - A distilled version of BERT with ~40% fewer parameters

### Project Goals:
1. Demonstrate knowledge distillation on text classification
2. Compare model performance, size, and inference speed
3. Show that smaller models can achieve competitive accuracy with proper training

### What We'll Evaluate:
- **Accuracy**: Classification performance on test data
- **Model Size**: Memory footprint in MB
- **Inference Time**: Speed of prediction on sample data
- **Training Efficiency**: Loss curves and convergence

# Install Dependencies & Set Up


In [1]:
!pip install transformers datasets torch torchvision matplotlib seaborn numpy pandas scikit-learn -q

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from transformers import (
    AutoTokenizer, AutoModelForSequenceClassification,
    TrainingArguments, Trainer, DataCollatorWithPadding
)
from datasets import load_dataset, Dataset
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import accuracy_score, classification_report
from sklearn.model_selection import train_test_split
import time
import os
import warnings
warnings.filterwarnings('ignore')

# Set seed
torch.manual_seed(42)
np.random.seed(42)

# Device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")


Using device: cuda


# Load & Prepare Dataset

In [2]:
!pip install -U datasets



In [3]:
print("Loading Financial Phrasebank dataset...")
dataset = load_dataset("financial_phrasebank", "sentences_allagree")

texts = dataset['train']['sentence']
labels = dataset['train']['label']

train_texts, temp_texts, train_labels, temp_labels = train_test_split(texts, labels, test_size=0.3, stratify=labels, random_state=42)
val_texts, test_texts, val_labels, test_labels = train_test_split(temp_texts, temp_labels, test_size=0.5, stratify=temp_labels, random_state=42)

train_dataset = [{'text': t, 'label': l} for t, l in zip(train_texts, train_labels)]
val_dataset = [{'text': t, 'label': l} for t, l in zip(val_texts, val_labels)]
test_dataset = [{'text': t, 'label': l} for t, l in zip(test_texts, test_labels)]


Loading Financial Phrasebank dataset...
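
The two `train_test_split` calls above amount to a roughly 70/15/15 train/validation/test split. As a sanity check on the resulting sizes, the sketch below reproduces the arithmetic in plain Python. It assumes a total of 2264 sentences (the sum of the three split sizes reported by the tokenization step) and assumes that a float `test_size` is rounded up, with the remainder going to the training side; `split_sizes` is a hypothetical helper, not part of the notebook.

```python
import math

def split_sizes(n_samples, test_size):
    """Return (n_train, n_test) for a fractional test_size.

    Assumption: the held-out count is rounded up and the rest is kept.
    """
    n_test = math.ceil(test_size * n_samples)
    return n_samples - n_test, n_test

n_total = 2264  # assumed total: 1584 + 340 + 340
n_train, n_temp = split_sizes(n_total, 0.3)   # first split: train vs. temp
n_val, n_test = split_sizes(n_temp, 0.5)      # second split: val vs. test
print(n_train, n_val, n_test)
```

Because both calls pass `stratify`, each split should keep approximately the same label proportions as the full dataset.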


# Tokenization

In [4]:
model_name_teacher = "bert-base-uncased"
model_name_student = "distilbert-base-uncased"

tokenizer = AutoTokenizer.from_pretrained(model_name_teacher)

def tokenize_function(example):
    return tokenizer(
        example['text'],
        padding="max_length",
        truncation=True,
        max_length=128
    )

train_ds = Dataset.from_list(train_dataset).map(tokenize_function)
val_ds = Dataset.from_list(val_dataset).map(tokenize_function)
test_ds = Dataset.from_list(test_dataset).map(tokenize_function)

for ds in [train_ds, val_ds, test_ds]:
    ds.set_format(type='torch', columns=['input_ids', 'attention_mask', 'label'])


Map:   0%|          | 0/1584 [00:00<?, ? examples/s]

Map:   0%|          | 0/340 [00:00<?, ? examples/s]

Map:   0%|          | 0/340 [00:00<?, ? examples/s]
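
With `padding="max_length"` and `max_length=128`, every example becomes a fixed-length pair of `input_ids` and `attention_mask`. The sketch below imitates that shape logic in plain Python, for intuition only. `pad_and_truncate` is a hypothetical helper, the word-piece IDs are made up, and the special-token IDs (101 for `[CLS]`, 102 for `[SEP]`, 0 for `[PAD]`) are assumed values for `bert-base-uncased`.

```python
CLS_ID, SEP_ID, PAD_ID = 101, 102, 0  # assumed bert-base-uncased special-token IDs

def pad_and_truncate(token_ids, max_length=128):
    """Mimic padding="max_length" with truncation=True for one sequence."""
    body = token_ids[: max_length - 2]           # leave room for [CLS] and [SEP]
    input_ids = [CLS_ID] + body + [SEP_ID]
    attention_mask = [1] * len(input_ids)        # 1 marks real tokens
    n_pad = max_length - len(input_ids)
    return {
        "input_ids": input_ids + [PAD_ID] * n_pad,
        "attention_mask": attention_mask + [0] * n_pad,  # 0 marks padding
    }

# Short max_length so the result is easy to read; the token IDs are illustrative.
short = pad_and_truncate([2023, 2194, 3925], max_length=8)
long = pad_and_truncate(list(range(1000, 1200)), max_length=8)
print(short["input_ids"])
print(short["attention_mask"])
print(long["input_ids"])
```

One design consequence: since tokenization already pads every example to 128 tokens, a dynamic-padding collator such as `DataCollatorWithPadding` has nothing left to shorten. If most sentences are well under 128 tokens, dropping `padding="max_length"` and relying on the collator would likely reduce compute per batch.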

# Fine-Tune Teacher Model

In [5]:
!pip install -U transformers datasets huggingface_hub




In [8]:
# Step 5: Fine-tune Teacher Model
teacher_model = AutoModelForSequenceClassification.from_pretrained(model_name_teacher, num_labels=3).to(device)

training_args_teacher = TrainingArguments(
    output_dir="./teacher_model",
    # Recent transformers releases use `eval_strategy`; older ones call it `evaluation_strategy`
    eval_strategy="epoch",
    save_strategy="epoch",
    num_train_epochs=3,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    learning_rate=2e-5,
    logging_dir='./logs',
    logging_steps=50,
    load_best_model_at_end=True,
    metric_for_best_model="accuracy"
)

def compute_metrics(pred):
    labels = pred.label_ids
    preds = np.argmax(pred.predictions, axis=1)
    return {"accuracy": accuracy_score(labels, preds)}

teacher_trainer = Trainer(
    model=teacher_model,
    args=training_args_teacher,
    train_dataset=train_ds,
    eval_dataset=val_ds,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
    data_collator=DataCollatorWithPadding(tokenizer)
)

teacher_trainer.train()

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.




Epoch,Training Loss,Validation Loss,Accuracy
1,0.6993,0.320097,0.885294
2,0.2682,0.144346,0.958824
3,0.0846,0.102844,0.967647


TrainOutput(global_step=297, training_loss=0.29287603487470737, metrics={'train_runtime': 177.4447, 'train_samples_per_second': 26.78, 'train_steps_per_second': 1.674, 'total_flos': 312578740260864.0, 'train_loss': 0.29287603487470737, 'epoch': 3.0})

# Evaluate Fine-Tuned Teacher



In [9]:
print("Evaluating fine-tuned teacher...")
teacher_results = teacher_trainer.predict(test_ds)
print(f"Accuracy: {accuracy_score(teacher_results.label_ids, np.argmax(teacher_results.predictions, axis=1))}")
print(classification_report(teacher_results.label_ids, np.argmax(teacher_results.predictions, axis=1)))


Evaluating fine-tuned teacher...


Accuracy: 0.9441176470588235
              precision    recall  f1-score   support

           0       0.88      0.91      0.89        46
           1       0.98      0.97      0.97       209
           2       0.90      0.89      0.90        85

    accuracy                           0.94       340
   macro avg       0.92      0.93      0.92       340
weighted avg       0.94      0.94      0.94       340



# Knowledge Distillation Setup

In [10]:
student_model = AutoModelForSequenceClassification.from_pretrained(model_name_student, num_labels=3).to(device)

class DistillationLoss(nn.Module):
    def __init__(self, temperature=2.0, alpha=0.5):
        super().__init__()
        self.temperature = temperature
        self.alpha = alpha
        self.kl = nn.KLDivLoss(reduction="batchmean")

    def forward(self, student_logits, teacher_logits, labels):
        student_soft = F.log_softmax(student_logits / self.temperature, dim=-1)
        teacher_soft = F.softmax(teacher_logits / self.temperature, dim=-1)
        hard_loss = F.cross_entropy(student_logits, labels)
        soft_loss = self.kl(student_soft, teacher_soft) * (self.temperature ** 2)
        return self.alpha * soft_loss + (1 - self.alpha) * hard_loss


Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
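
To make the `DistillationLoss` arithmetic concrete, here is a minimal pure-Python sketch of the same computation for a single example: a temperature-softened KL term (teacher as target) scaled by `temperature ** 2`, blended with ordinary cross-entropy on the hard label. The helper names `softmax` and `distillation_loss` and the toy logits are illustrative assumptions, not part of the notebook.

```python
import math

def softmax(logits, temperature=1.0):
    """Numerically stable softmax over logits divided by the temperature."""
    scaled = [z / temperature for z in logits]
    m = max(scaled)
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def distillation_loss(student_logits, teacher_logits, label,
                      temperature=2.0, alpha=0.5):
    """Return (total, soft, hard) loss terms for one example."""
    p_teacher = softmax(teacher_logits, temperature)
    p_student = softmax(student_logits, temperature)
    # KL(teacher || student) on the softened distributions, scaled by T^2
    soft = sum(pt * math.log(pt / ps)
               for pt, ps in zip(p_teacher, p_student) if pt > 0)
    soft *= temperature ** 2
    # Cross-entropy of the unsoftened student prediction against the true label
    hard = -math.log(softmax(student_logits)[label])
    return alpha * soft + (1 - alpha) * hard, soft, hard

teacher = [0.5, 4.0, 1.0]          # toy teacher logits for 3 classes
print(softmax(teacher, 1.0))       # sharper distribution
print(softmax(teacher, 2.0))       # softened: more mass on non-argmax classes
total, soft, hard = distillation_loss(teacher, teacher, label=1)
print(total, soft, hard)           # soft term vanishes when student matches teacher
```

A higher temperature spreads probability onto the non-argmax classes, which is where the teacher's information about class similarity lives. The `temperature ** 2` factor is the usual compensation for the gradient of the softened term shrinking as the temperature grows.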


# Train Student via Distillation

In [11]:
student_optimizer = torch.optim.AdamW(student_model.parameters(), lr=2e-5)
loss_fn = DistillationLoss(temperature=2.0, alpha=0.7)

def distill_train(model, teacher, train_data):
    model.train()
    teacher.eval()
    for epoch in range(3):
        print(f"\nEpoch {epoch+1}")
        total_loss = 0
        for batch in DataLoader(train_data, batch_size=16, shuffle=True):
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["label"].to(device)

            with torch.no_grad():
                teacher_logits = teacher(input_ids, attention_mask=attention_mask).logits

            student_logits = model(input_ids, attention_mask=attention_mask).logits
            loss = loss_fn(student_logits, teacher_logits, labels)

            student_optimizer.zero_grad()
            loss.backward()
            student_optimizer.step()

            total_loss += loss.item()
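        # Caveat: total_loss sums per-batch mean losses, yet it is divided by the
        # number of samples, so the value printed is about 1/16 (1/batch_size) of
        # the mean batch loss; divide by the number of batches for a true average.
        print(f"Loss: {total_loss / len(train_data)}")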

distill_train(student_model, teacher_model, train_ds)



Epoch 1
Loss: 0.05529754754682683

Epoch 2
Loss: 0.007441587743791517

Epoch 3
Loss: 0.0031489716540561105


# Evaluate Student



In [12]:
def evaluate(model, dataset):
    model.eval()
    all_preds, all_labels = [], []
    for batch in DataLoader(dataset, batch_size=16):
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        labels = batch["label"].to(device)

        with torch.no_grad():
            logits = model(input_ids, attention_mask=attention_mask).logits
        preds = torch.argmax(logits, dim=1)
        all_preds.extend(preds.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())
    print("Student Accuracy:", accuracy_score(all_labels, all_preds))
    print(classification_report(all_labels, all_preds))

evaluate(student_model, test_ds)


Student Accuracy: 0.9529411764705882
              precision    recall  f1-score   support

           0       0.91      0.91      0.91        46
           1       0.98      0.98      0.98       209
           2       0.91      0.91      0.91        85

    accuracy                           0.95       340
   macro avg       0.93      0.93      0.93       340
weighted avg       0.95      0.95      0.95       340
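
The student's test accuracy (0.9529) is slightly higher than the teacher's (0.9441), but with only 340 test sentences the gap is small in absolute terms. The sketch below converts both accuracies back into counts and compares the gap with a rough binomial standard error. That error estimate is an approximation that treats test sentences as independent; it is not a formal significance test.

```python
import math

n_test = 340                            # test-set size from the reports above
teacher_acc = 0.9441176470588235        # teacher accuracy reported earlier
student_acc = 0.9529411764705882        # student accuracy reported above

teacher_correct = round(teacher_acc * n_test)
student_correct = round(student_acc * n_test)
print(teacher_correct, student_correct, student_correct - teacher_correct)

# Rough binomial standard error of an accuracy estimate on n_test samples
std_err = math.sqrt(student_acc * (1 - student_acc) / n_test)
gap = student_acc - teacher_acc
print(f"gap = {gap:.4f}, approx. standard error = {std_err:.4f}")
```

Under these assumptions the difference is three sentences and falls within one standard error, so the safer reading is that the student matches the teacher on this split rather than clearly beating it.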



# Compare Model Size & Inference Time


In [14]:
import time

def get_model_size(model):
    param_size = 0
    for param in model.parameters():
        param_size += param.nelement() * param.element_size()
    size_mb = param_size / (1024**2)
    return size_mb

def measure_inference_time(model, tokenizer, sentence="This company has shown excellent growth.", n=100):
    model.eval()
    inputs = tokenizer(sentence, return_tensors="pt", truncation=True, padding=True, max_length=128).to(device)

    # Check if the model is a DistilBERT model and remove 'token_type_ids' if present
    if "distilbert" in model.__class__.__name__.lower() and 'token_type_ids' in inputs:
        del inputs['token_type_ids']

    start = time.time()
    with torch.no_grad():
        for _ in range(n):
            _ = model(**inputs)
    end = time.time()

    avg_time = (end - start) / n
    return avg_time * 1000  # in milliseconds

# Load student model if not already loaded
if 'student_model' not in globals():
    from transformers import AutoModelForSequenceClassification
    student_model = AutoModelForSequenceClassification.from_pretrained("distilbert-base-uncased", num_labels=3).to(device)

# ✅ Compare size
teacher_size = get_model_size(teacher_model)
student_size = get_model_size(student_model)

# ✅ Compare response time
teacher_time = measure_inference_time(teacher_model, tokenizer)
student_time = measure_inference_time(student_model, tokenizer)

print(f"🧠 Teacher Model Size: {teacher_size:.2f} MB")
print(f"🧠 Student Model Size: {student_size:.2f} MB")
print(f"⏱️ Teacher Inference Time: {teacher_time:.2f} ms/sample")
print(f"⏱️ Student Inference Time: {student_time:.2f} ms/sample")

🧠 Teacher Model Size: 417.65 MB
🧠 Student Model Size: 255.42 MB
⏱️ Teacher Inference Time: 7.66 ms/sample
⏱️ Student Inference Time: 3.98 ms/sample
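
The timings above come from a simple `time.time()` loop. As a hedged alternative, the sketch below shows a slightly more careful pattern: untimed warm-up calls, a monotonic high-resolution clock, and an optional synchronization hook. `benchmark` and its `sync` parameter are hypothetical names introduced for illustration, and the workload is a stand-in so the block runs without a model.

```python
import time

def benchmark(fn, n=100, warmup=10, sync=None):
    """Return the mean wall-clock time of fn() in milliseconds.

    sync is an optional zero-argument callable that blocks until any
    queued work has finished (hypothetical hook; see the note below).
    """
    for _ in range(warmup):       # warm-up runs are not timed
        fn()
    if sync is not None:
        sync()
    start = time.perf_counter()
    for _ in range(n):
        fn()
    if sync is not None:
        sync()                    # wait for queued work before stopping the clock
    return (time.perf_counter() - start) / n * 1000

# Stand-in workload: record each call so the call count can be checked.
calls = []
ms = benchmark(lambda: calls.append(sum(range(1000))), n=50, warmup=5)
print(f"{ms:.4f} ms/call over {len(calls) - 5} timed calls")
```

In the notebook, `fn` could wrap `model(**inputs)` under `torch.no_grad()`, and `sync` could be `torch.cuda.synchronize` when `device` is CUDA. GPU kernels are launched asynchronously, so a loop timed without synchronization may under-report the real inference time.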
