In [1]:
!pip install transformers datasets torch scikit-learn



## MNIST with MLP

In [11]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import time
import numpy as np
from collections import defaultdict
import matplotlib.pyplot as plt

# Set random seeds for reproducibility
torch.manual_seed(42)
np.random.seed(42)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Data loading
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

train_dataset = datasets.MNIST('data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST('data', train=False, transform=transform)

train_loader = DataLoader(train_dataset, batch_size=256, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=256, shuffle=False)

class TeacherMLP(nn.Module):
    """Large teacher network"""
    def __init__(self):
        super(TeacherMLP, self).__init__()
        self.features = nn.Sequential(
            nn.Linear(784, 1200),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(1200, 800),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(800, 400),
            nn.ReLU(inplace=True),
            nn.Dropout(0.3),
        )
        self.classifier = nn.Linear(400, 10)

    def forward(self, x):
        x = x.view(x.size(0), -1)
        features = self.features(x)
        output = self.classifier(features)
        return output, features

class StudentMLP(nn.Module):
    """Small student network"""
    def __init__(self):
        super(StudentMLP, self).__init__()
        self.features = nn.Sequential(
            nn.Linear(784, 128),
            nn.ReLU(inplace=True),
            nn.Dropout(0.3),
            nn.Linear(128, 64),
            nn.ReLU(inplace=True),
        )
        self.classifier = nn.Linear(64, 10)

    def forward(self, x):
        x = x.view(x.size(0), -1)
        features = self.features(x)
        output = self.classifier(features)
        return output, features

def count_parameters(model):
    """Count trainable parameters"""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

def train_teacher(model, epochs=15):
    """Train teacher model from scratch"""
    model.train()
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    criterion = nn.CrossEntropyLoss()

    for epoch in range(epochs):
        total_loss = 0
        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            output, _ = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        if (epoch + 1) % 5 == 0:
            acc = evaluate_model(model)
            print(f'Teacher Epoch {epoch+1}/{epochs}, Loss: {total_loss/len(train_loader):.4f}, Acc: {acc:.4f}')

def train_student_vanilla(student, epochs=15):
    """Train student without distillation"""
    student.train()
    optimizer = optim.Adam(student.parameters(), lr=0.001)
    criterion = nn.CrossEntropyLoss()

    for epoch in range(epochs):
        total_loss = 0
        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            output, _ = student(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        if (epoch + 1) % 5 == 0:
            acc = evaluate_model(student)
            print(f'Student Vanilla Epoch {epoch+1}/{epochs}, Loss: {total_loss/len(train_loader):.4f}, Acc: {acc:.4f}')

def train_vanilla_kd(teacher, student, temperature=1, alpha=0.5, epochs=15):
    """Vanilla Knowledge Distillation"""
    teacher.eval()
    student.train()
    optimizer = optim.Adam(student.parameters(), lr=0.001)
    criterion_ce = nn.CrossEntropyLoss()
    criterion_kd = nn.KLDivLoss(reduction='batchmean')

    for epoch in range(epochs):
        total_loss = 0
        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.to(device), target.to(device)

            with torch.no_grad():
                teacher_output, _ = teacher(data)
                teacher_soft = F.softmax(teacher_output / temperature, dim=1)

            optimizer.zero_grad()
            student_output, _ = student(data)
            student_soft = F.log_softmax(student_output / temperature, dim=1)

            # Distillation loss
            kd_loss = criterion_kd(student_soft, teacher_soft) * (temperature ** 2)
            # Cross-entropy loss
            ce_loss = criterion_ce(student_output, target)
            # Combined loss
            loss = alpha * kd_loss + (1 - alpha) * ce_loss

            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        if (epoch + 1) % 5 == 0:
            acc = evaluate_model(student)
            print(f'Vanilla KD Epoch {epoch+1}/{epochs}, Loss: {total_loss/len(train_loader):.4f}, Acc: {acc:.4f}')

def train_feature_kd(teacher, student, temperature=1, alpha=0.5, beta=0.5, epochs=15):
    """Feature-based Knowledge Distillation"""
    teacher.eval()
    student.train()
    optimizer = optim.Adam(student.parameters(), lr=0.001)
    criterion_ce = nn.CrossEntropyLoss()
    criterion_kd = nn.KLDivLoss(reduction='batchmean')
    criterion_feat = nn.MSELoss()

    # Feature adaptation layer (to match dimensions)
    feat_adapter = nn.Linear(64, 400).to(device)
    feat_optimizer = optim.Adam(feat_adapter.parameters(), lr=0.001)

    for epoch in range(epochs):
        total_loss = 0
        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.to(device), target.to(device)

            with torch.no_grad():
                teacher_output, teacher_features = teacher(data)
                teacher_soft = F.softmax(teacher_output / temperature, dim=1)

            optimizer.zero_grad()
            feat_optimizer.zero_grad()

            student_output, student_features = student(data)
            student_soft = F.log_softmax(student_output / temperature, dim=1)

            # Adapt student features to match teacher dimension
            adapted_student_features = feat_adapter(student_features)

            # Distillation loss
            kd_loss = criterion_kd(student_soft, teacher_soft) * (temperature ** 2)
            # Cross-entropy loss
            ce_loss = criterion_ce(student_output, target)
            # Feature distillation loss
            feat_loss = criterion_feat(adapted_student_features, teacher_features)

            # Combined loss
            loss = alpha * kd_loss + (1 - alpha) * ce_loss + beta * feat_loss

            loss.backward()
            optimizer.step()
            feat_optimizer.step()
            total_loss += loss.item()

        if (epoch + 1) % 5 == 0:
            acc = evaluate_model(student)
            print(f'Feature KD Epoch {epoch+1}/{epochs}, Loss: {total_loss/len(train_loader):.4f}, Acc: {acc:.4f}')

def train_simkd(teacher, student, temperature=4, alpha=0.7, beta=0.1, epochs=15):
    """Similarity-preserving Knowledge Distillation (SimKD)"""
    teacher.eval()
    student.train()
    optimizer = optim.Adam(student.parameters(), lr=0.001)
    criterion_ce = nn.CrossEntropyLoss()
    criterion_kd = nn.KLDivLoss(reduction='batchmean')

    def similarity_loss(f_s, f_t):
        """Compute similarity-preserving loss"""
        # Normalize features
        f_s = F.normalize(f_s, p=2, dim=1)
        f_t = F.normalize(f_t, p=2, dim=1)

        # Compute similarity matrices
        G_s = torch.mm(f_s, f_s.t())
        G_t = torch.mm(f_t, f_t.t())

        # Similarity-preserving loss
        return F.mse_loss(G_s, G_t)

    for epoch in range(epochs):
        total_loss = 0
        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.to(device), target.to(device)

            with torch.no_grad():
                teacher_output, teacher_features = teacher(data)
                teacher_soft = F.softmax(teacher_output / temperature, dim=1)

            optimizer.zero_grad()
            student_output, student_features = student(data)
            student_soft = F.log_softmax(student_output / temperature, dim=1)

            # Distillation loss
            kd_loss = criterion_kd(student_soft, teacher_soft) * (temperature ** 2)
            # Cross-entropy loss
            ce_loss = criterion_ce(student_output, target)
            # Similarity loss
            sim_loss = similarity_loss(student_features, teacher_features)

            # Combined loss
            loss = alpha * kd_loss + (1 - alpha) * ce_loss + beta * sim_loss

            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        if (epoch + 1) % 5 == 0:
            acc = evaluate_model(student)
            print(f'SimKD Epoch {epoch+1}/{epochs}, Loss: {total_loss/len(train_loader):.4f}, Acc: {acc:.4f}')

def evaluate_model(model):
    """Evaluate model accuracy"""
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output, _ = model(data)
            _, predicted = torch.max(output.data, 1)
            total += target.size(0)
            correct += (predicted == target).sum().item()
    return correct / total

def measure_inference_time(model, num_runs=100):
    """Measure inference time"""
    model.eval()
    # Warm up
    with torch.no_grad():
        for _ in range(10):
            dummy_input = torch.randn(1, 784).to(device)
            _ = model(dummy_input)

    # Measure time
    start_time = time.time()
    with torch.no_grad():
        for _ in range(num_runs):
            dummy_input = torch.randn(1, 784).to(device)
            _ = model(dummy_input)
    end_time = time.time()

    avg_time = (end_time - start_time) / num_runs * 1000  # milliseconds
    return avg_time

def benchmark_models():
    """Benchmark all models"""
    results = defaultdict(dict)

    # Teacher model
    print("="*50)
    print("Training Teacher Model")
    print("="*50)
    teacher = TeacherMLP().to(device)
    train_teacher(teacher, epochs=5)

    teacher_acc = evaluate_model(teacher)
    teacher_params = count_parameters(teacher)
    teacher_time = measure_inference_time(teacher)

    results['Teacher']['Top-1 Acc'] = teacher_acc
    results['Teacher']['Accuracy'] = teacher_acc
    results['Teacher']['Parameters'] = teacher_params
    results['Teacher']['Inference Time (ms)'] = teacher_time

    print(f"Teacher - Acc: {teacher_acc:.4f}, Params: {teacher_params}, Time: {teacher_time:.3f}ms")

    # Student (Vanilla)
    print("\n" + "="*50)
    print("Training Student Model (Vanilla)")
    print("="*50)
    student_vanilla = StudentMLP().to(device)
    train_student_vanilla(student_vanilla , epochs=5)

    vanilla_acc = evaluate_model(student_vanilla)
    vanilla_params = count_parameters(student_vanilla)
    vanilla_time = measure_inference_time(student_vanilla)

    results['Student (Vanilla)']['Top-1 Acc'] = vanilla_acc
    results['Student (Vanilla)']['Accuracy'] = vanilla_acc
    results['Student (Vanilla)']['Parameters'] = vanilla_params
    results['Student (Vanilla)']['Inference Time (ms)'] = vanilla_time

    print(f"Student Vanilla - Acc: {vanilla_acc:.4f}, Params: {vanilla_params}, Time: {vanilla_time:.3f}ms")

    # Vanilla KD
    print("\n" + "="*50)
    print("Training Student with Vanilla KD")
    print("="*50)
    student_kd = StudentMLP().to(device)
    train_vanilla_kd(teacher, student_kd, epochs=5)

    kd_acc = evaluate_model(student_kd)
    kd_params = count_parameters(student_kd)
    kd_time = measure_inference_time(student_kd)

    results['Vanilla KD']['Top-1 Acc'] = kd_acc
    results['Vanilla KD']['Accuracy'] = kd_acc
    results['Vanilla KD']['Parameters'] = kd_params
    results['Vanilla KD']['Inference Time (ms)'] = kd_time

    print(f"Vanilla KD - Acc: {kd_acc:.4f}, Params: {kd_params}, Time: {kd_time:.3f}ms")

    # Feature KD
    print("\n" + "="*50)
    print("Training Student with Feature KD")
    print("="*50)
    student_feat = StudentMLP().to(device)
    train_feature_kd(teacher, student_feat, epochs=5)

    feat_acc = evaluate_model(student_feat)
    feat_params = count_parameters(student_feat)
    feat_time = measure_inference_time(student_feat)

    results['Feature KD']['Top-1 Acc'] = feat_acc
    results['Feature KD']['Accuracy'] = feat_acc
    results['Feature KD']['Parameters'] = feat_params
    results['Feature KD']['Inference Time (ms)'] = feat_time

    print(f"Feature KD - Acc: {feat_acc:.4f}, Params: {feat_params}, Time: {feat_time:.3f}ms")

    # SimKD
    print("\n" + "="*50)
    print("Training Student with SimKD")
    print("="*50)
    student_sim = StudentMLP().to(device)
    train_simkd(teacher, student_sim, epochs=5)

    sim_acc = evaluate_model(student_sim)
    sim_params = count_parameters(student_sim)
    sim_time = measure_inference_time(student_sim)

    results['SimKD']['Top-1 Acc'] = sim_acc
    results['SimKD']['Accuracy'] = sim_acc
    results['SimKD']['Parameters'] = sim_params
    results['SimKD']['Inference Time (ms)'] = sim_time

    print(f"SimKD - Acc: {sim_acc:.4f}, Params: {sim_params}, Time: {sim_time:.3f}ms")

    return results

def print_results_table(results):
    """Print benchmark results in a nice table format"""
    print("\n" + "="*80)
    print("KNOWLEDGE DISTILLATION BENCHMARK RESULTS")
    print("="*80)

    # Table header
    header = f"{'Method':<20} {'Top-1 Acc':<12} {'Accuracy':<12} {'Parameters':<12} {'Inf. Time (ms)':<15}"
    print(header)
    print("-" * 80)

    # Table rows
    for method, metrics in results.items():
        row = f"{method:<20} "
        row += f"{metrics['Top-1 Acc']:<12.4f} "
        row += f"{metrics['Accuracy']:<12.4f} "
        row += f"{metrics['Parameters']:<12,} "
        row += f"{metrics['Inference Time (ms)']:<15.3f}"
        print(row)

    print("-" * 80)

    # Performance improvements
    teacher_acc = results['Teacher']['Accuracy']
    vanilla_acc = results['Student (Vanilla)']['Accuracy']

    print("\nPerformance Improvements over Student (Vanilla):")
    for method in ['Vanilla KD', 'Feature KD', 'SimKD']:
        if method in results:
            improvement = (results[method]['Accuracy'] - vanilla_acc) * 100
            print(f"{method}: +{improvement:.2f}% accuracy improvement")

    # Compression ratio
    teacher_params = results['Teacher']['Parameters']
    student_params = results['Student (Vanilla)']['Parameters']
    compression_ratio = teacher_params / student_params
    print(f"\nModel Compression Ratio: {compression_ratio:.1f}x parameter reduction")
    print(f"Teacher Parameters: {teacher_params:,}")
    print(f"Student Parameters: {student_params:,}")

# Run the benchmark
if __name__ == "__main__":
    results = benchmark_models()
    print_results_table(results)

Using device: cuda
Training Teacher Model
Teacher Epoch 5/5, Loss: 0.0942, Acc: 0.9771
Teacher - Acc: 0.9771, Params: 2227210, Time: 0.246ms

Training Student Model (Vanilla)
Student Vanilla Epoch 5/5, Loss: 0.1103, Acc: 0.9745
Student Vanilla - Acc: 0.9745, Params: 109386, Time: 0.191ms

Training Student with Vanilla KD
Vanilla KD Epoch 5/5, Loss: 0.0909, Acc: 0.9710
Vanilla KD - Acc: 0.9710, Params: 109386, Time: 0.175ms

Training Student with Feature KD
Feature KD Epoch 5/5, Loss: 0.2402, Acc: 0.9717
Feature KD - Acc: 0.9717, Params: 109386, Time: 0.168ms

Training Student with SimKD
SimKD Epoch 5/5, Loss: 0.5154, Acc: 0.9633
SimKD - Acc: 0.9633, Params: 109386, Time: 0.230ms

KNOWLEDGE DISTILLATION BENCHMARK RESULTS
Method               Top-1 Acc    Accuracy     Parameters   Inf. Time (ms) 
--------------------------------------------------------------------------------
Teacher              0.9771       0.9771       2,227,210    0.246          
Student (Vanilla)    0.9745       0.9

## Distillation in LLM (Bert)

In [2]:
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from datasets import load_dataset
from sklearn.metrics import accuracy_score, f1_score

In [3]:
dataset = load_dataset("imdb")
test_texts = dataset['test']['text']
test_labels = dataset['test']['label']

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


README.md: 0.00B [00:00, ?B/s]

plain_text/train-00000-of-00001.parquet:   0%|          | 0.00/21.0M [00:00<?, ?B/s]

plain_text/test-00000-of-00001.parquet:   0%|          | 0.00/20.5M [00:00<?, ?B/s]

plain_text/unsupervised-00000-of-00001.p(…):   0%|          | 0.00/42.0M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/25000 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/25000 [00:00<?, ? examples/s]

Generating unsupervised split:   0%|          | 0/50000 [00:00<?, ? examples/s]

In [4]:
models = {
    "mbert": "bert-base-multilingual-cased",
    "distil-mbert": "distilbert-base-multilingual-cased"
}

tokenizers = {name: AutoTokenizer.from_pretrained(model_name) for name, model_name in models.items()}
model_objects = {name: AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2)
                 for name, model_name in models.items()}

device = "cuda" if torch.cuda.is_available() else "cpu"
for model in model_objects.values():
    model.to(device)
    model.eval()

tokenizer_config.json:   0%|          | 0.00/49.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/625 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/996k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.96M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/49.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/466 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/996k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.96M [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/714M [00:00<?, ?B/s]

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-multilingual-cased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


model.safetensors:   0%|          | 0.00/542M [00:00<?, ?B/s]

Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-multilingual-cased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [5]:
from torch.utils.data import DataLoader, TensorDataset

def encode_texts(tokenizer, texts, max_len=256):
    encodings = tokenizer(texts, truncation=True, padding=True, max_length=max_len, return_tensors="pt")
    return encodings

batch_size = 16

In [6]:
def evaluate_model(model, tokenizer, texts, labels, batch_size=16):
    encodings = encode_texts(tokenizer, texts)
    dataset = TensorDataset(encodings['input_ids'], encodings['attention_mask'], torch.tensor(labels))
    dataloader = DataLoader(dataset, batch_size=batch_size)

    preds = []
    for batch in dataloader:
        input_ids, attention_mask, _ = [x.to(device) for x in batch]
        with torch.no_grad():
            outputs = model(input_ids, attention_mask=attention_mask)
            batch_preds = torch.argmax(outputs.logits, dim=-1)
            preds.extend(batch_preds.cpu().numpy())

    acc = accuracy_score(labels, preds)
    f1 = f1_score(labels, preds, average='weighted')
    return acc, f1

In [8]:
for name in models:
    acc, f1 = evaluate_model(model_objects[name], tokenizers[name], list(test_texts), test_labels)
    print(f"{name}: Accuracy = {acc:.4f}, F1 = {f1:.4f}")

mbert: Accuracy = 0.5000, F1 = 0.3333
distil-mbert: Accuracy = 0.5062, F1 = 0.4087


- mBERT-base has approximately 110 million parameters,
- while Distil-mBERT has about 66 million parameters.
- This represents a reduction of roughly 40% in the number of parameters.

- Distil-mBERT utilizes 6 transformer layers compared to mBERT-base's 12 layers.