In [None]:
!pip install transformers datasets torch torchvision accelerate evaluate scikit-learn

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForSequenceClassification, TrainingArguments
from datasets import load_dataset
from sklearn.metrics import accuracy_score, f1_score
import numpy as np
from torch.utils.data import DataLoader
import warnings
warnings.filterwarnings('ignore')

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

Using device: cuda


In [None]:
dataset = load_dataset('sst2')

README.md: 0.00B [00:00, ?B/s]

data/train-00000-of-00001.parquet:   0%|          | 0.00/3.11M [00:00<?, ?B/s]

data/validation-00000-of-00001.parquet:   0%|          | 0.00/72.8k [00:00<?, ?B/s]

data/test-00000-of-00001.parquet:   0%|          | 0.00/148k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/67349 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/872 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/1821 [00:00<?, ? examples/s]

In [None]:
dataset

DatasetDict({
    train: Dataset({
        features: ['idx', 'sentence', 'label'],
        num_rows: 67349
    })
    validation: Dataset({
        features: ['idx', 'sentence', 'label'],
        num_rows: 872
    })
    test: Dataset({
        features: ['idx', 'sentence', 'label'],
        num_rows: 1821
    })
})

In [None]:
train_data = dataset['train'].select(range(1000))
val_data = dataset['validation'].select(range(200))

In [None]:
teacher_model_name = "bert-base-uncased"  # Large teacher model
student_model_name = "distilbert-base-uncased" # Smaller student model

In [None]:
tokenizer = AutoTokenizer.from_pretrained(teacher_model_name)

tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

In [None]:
teacher_model = AutoModelForSequenceClassification.from_pretrained(
    teacher_model_name,
    num_labels=2
).to(device)

model.safetensors:   0%|          | 0.00/440M [00:00<?, ?B/s]

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [None]:
student_model = AutoModelForSequenceClassification.from_pretrained(
    student_model_name,
    num_labels=2
).to(device)

config.json:   0%|          | 0.00/483 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/268M [00:00<?, ?B/s]

Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [None]:
print(f"Teacher model parameters: {sum(p.numel() for p in teacher_model.parameters()):,} \n")
print(f"Student model parameters: {sum(p.numel() for p in student_model.parameters()):,}")

Teacher model parameters: 109,483,778 

Student model parameters: 66,955,010


In [None]:
def tokenize_data(examples):
    return tokenizer(
        examples['sentence'],
        truncation=True,
        padding='max_length',
        max_length=128,
        return_tensors='pt'
    )

In [None]:
train_tokenized = train_data.map(tokenize_data, batched=True)
val_tokenized = val_data.map(tokenize_data, batched=True)

Map:   0%|          | 0/1000 [00:00<?, ? examples/s]

Map:   0%|          | 0/200 [00:00<?, ? examples/s]

In [None]:
train_tokenized.set_format(type='torch', columns=['input_ids', 'attention_mask', 'label'])
val_tokenized.set_format(type='torch', columns=['input_ids', 'attention_mask', 'label'])

In [None]:
def distillation_loss(student_logits, teacher_logits, labels, temperature=4.0, alpha=0.7):
    """
    Compute knowledge distillation loss

    Args:
        student_logits: Output logits from student model
        teacher_logits: Output logits from teacher model
        labels: True labels
        temperature: Temperature for softening probability distributions
        alpha: Weight between distillation loss and task loss
    """
    # Distillation loss (KL divergence between soft predictions)
    soft_teacher = F.softmax(teacher_logits / temperature, dim=1)
    soft_student = F.log_softmax(student_logits / temperature, dim=1)

    distill_loss = F.kl_div(soft_student, soft_teacher, reduction='batchmean') * (temperature ** 2)

    # Task loss (standard cross-entropy with true labels)
    task_loss = F.cross_entropy(student_logits, labels)

    # Combined loss
    total_loss = alpha * distill_loss + (1 - alpha) * task_loss

    return total_loss, distill_loss, task_loss

In [None]:
# Training parameters
learning_rate = 2e-5
batch_size = 16
num_epochs = 3
temperature = 4.0
alpha = 0.7

In [None]:
train_loader = DataLoader(train_tokenized, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_tokenized, batch_size=batch_size, shuffle=False)

In [None]:
optimizer = torch.optim.AdamW(student_model.parameters(), lr=learning_rate)

In [None]:
teacher_model.eval()

BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e

In [None]:
def train_student_with_distillation():
    student_model.train()

    for epoch in range(num_epochs):
        total_loss = 0
        total_distill_loss = 0
        total_task_loss = 0

        print(f"\nEpoch {epoch + 1}/{num_epochs}")

        for batch_idx, batch in enumerate(train_loader):
            # Move batch to device
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['label'].to(device)

            # Get teacher predictions (no gradient computation)
            with torch.no_grad():
                teacher_outputs = teacher_model(input_ids=input_ids, attention_mask=attention_mask)
                teacher_logits = teacher_outputs.logits

            # Get student predictions
            student_outputs = student_model(input_ids=input_ids, attention_mask=attention_mask)
            student_logits = student_outputs.logits

            # Calculate distillation loss
            loss, distill_loss, task_loss = distillation_loss(
                student_logits, teacher_logits, labels, temperature, alpha
            )

            # Backward pass
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Track losses
            total_loss += loss.item()
            total_distill_loss += distill_loss.item()
            total_task_loss += task_loss.item()

            # Print progress every 20 batches
            if (batch_idx + 1) % 20 == 0:
                print(f"  Batch {batch_idx + 1}: Loss={loss.item():.4f}, "
                      f"Distill={distill_loss.item():.4f}, Task={task_loss.item():.4f}")

        # Print epoch summary
        avg_loss = total_loss / len(train_loader)
        avg_distill = total_distill_loss / len(train_loader)
        avg_task = total_task_loss / len(train_loader)

        print(f"  Epoch {epoch + 1} Summary:")
        print(f"    Average Total Loss: {avg_loss:.4f}")
        print(f"    Average Distillation Loss: {avg_distill:.4f}")
        print(f"    Average Task Loss: {avg_task:.4f}")

In [None]:
# Start training
print("Starting knowledge distillation training...")
train_student_with_distillation()
print("\nTraining completed!")

Starting knowledge distillation training...

Epoch 1/3
  Batch 20: Loss=0.2148, Distill=0.0152, Task=0.6805
  Batch 40: Loss=0.2057, Distill=0.0249, Task=0.6274
  Batch 60: Loss=0.1908, Distill=0.0250, Task=0.5776
  Epoch 1 Summary:
    Average Total Loss: 0.2058
    Average Distillation Loss: 0.0210
    Average Task Loss: 0.6370

Epoch 2/3
  Batch 20: Loss=0.1784, Distill=0.0185, Task=0.5516
  Batch 40: Loss=0.1843, Distill=0.0296, Task=0.5453
  Batch 60: Loss=0.1876, Distill=0.0304, Task=0.5545
  Epoch 2 Summary:
    Average Total Loss: 0.1840
    Average Distillation Loss: 0.0401
    Average Task Loss: 0.5196

Epoch 3/3
  Batch 20: Loss=0.1893, Distill=0.0538, Task=0.5054
  Batch 40: Loss=0.1705, Distill=0.0506, Task=0.4502
  Batch 60: Loss=0.1806, Distill=0.0404, Task=0.5077
  Epoch 3 Summary:
    Average Total Loss: 0.1754
    Average Distillation Loss: 0.0462
    Average Task Loss: 0.4768

Training completed!


In [None]:
def evaluate_model(model, data_loader, model_name):
    model.eval()
    predictions = []
    true_labels = []
    total_loss = 0

    with torch.no_grad():
        for batch in data_loader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['label'].to(device)

            outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
            loss = outputs.loss
            logits = outputs.logits

            total_loss += loss.item()

            # Get predictions
            preds = torch.argmax(logits, dim=1)
            predictions.extend(preds.cpu().numpy())
            true_labels.extend(labels.cpu().numpy())

    # Calculate metrics
    accuracy = accuracy_score(true_labels, predictions)
    f1 = f1_score(true_labels, predictions, average='weighted')
    avg_loss = total_loss / len(data_loader)

    print(f"\n{model_name} Evaluation Results:")
    print(f"  Accuracy: {accuracy:.4f}")
    print(f"  F1 Score: {f1:.4f}")
    print(f"  Average Loss: {avg_loss:.4f}")

    return accuracy, f1, avg_loss

In [None]:
# Evaluate teacher model
teacher_acc, teacher_f1, teacher_loss = evaluate_model(teacher_model, val_loader, "Teacher (BERT)")


Teacher (BERT) Evaluation Results:
  Accuracy: 0.5000
  F1 Score: 0.3452
  Average Loss: 0.7184


In [None]:
student_acc, student_f1, student_loss = evaluate_model(student_model, val_loader, "Distilled Student (DistilBERT)")


Distilled Student (DistilBERT) Evaluation Results:
  Accuracy: 0.8400
  F1 Score: 0.8382
  Average Loss: 0.5492


In [None]:
acc_retention = (student_acc / teacher_acc) * 100 if teacher_acc > 0 else 0
f1_retention = (student_f1 / teacher_f1) * 100 if teacher_f1 > 0 else 0

print(f"\nPerformance Comparison:")
print(f"  Accuracy Retention: {acc_retention:.2f}%")
print(f"  F1 Score Retention: {f1_retention:.2f}%")


Performance Comparison:
  Accuracy Retention: 168.00%
  F1 Score Retention: 242.79%


In [None]:
teacher_params = sum(p.numel() for p in teacher_model.parameters())
student_params = sum(p.numel() for p in student_model.parameters())
size_reduction = ((teacher_params - student_params) / teacher_params) * 100

print(f"\nModel Efficiency:")
print(f"  Teacher Parameters: {teacher_params:,}")
print(f"  Student Parameters: {student_params:,}")
print(f"  Size Reduction: {size_reduction:.2f}%")


Model Efficiency:
  Teacher Parameters: 109,483,778
  Student Parameters: 66,955,010
  Size Reduction: 38.84%


In [None]:
def test_models_with_examples(examples):
    teacher_model.eval()
    student_model.eval()

    print("Testing models with custom examples:")
    print("=" * 50)

    for i, text in enumerate(examples):
        # Tokenize input
        inputs = tokenizer(text, return_tensors='pt', truncation=True, max_length=128)
        inputs = {k: v.to(device) for k, v in inputs.items()}

        # Get predictions
        with torch.no_grad():
            teacher_outputs = teacher_model(**inputs)

            # Create student inputs by excluding 'token_type_ids'
            student_inputs = {k: v for k, v in inputs.items() if k != 'token_type_ids'}
            student_outputs = student_model(**student_inputs)

            teacher_probs = F.softmax(teacher_outputs.logits, dim=1)
            student_probs = F.softmax(student_outputs.logits, dim=1)

            teacher_pred = torch.argmax(teacher_probs, dim=1).item()
            student_pred = torch.argmax(student_probs, dim=1).item()

        # Convert predictions to labels
        labels = ['Negative', 'Positive']

        print(f"\nExample {i+1}: '{text}'")
        print(f"Teacher Prediction: {labels[teacher_pred]} (confidence: {teacher_probs[0][teacher_pred]:.4f})")
        print(f"Student Prediction: {labels[student_pred]} (confidence: {student_probs[0][student_pred]:.4f})")
        print(f"Agreement: {'✓' if teacher_pred == student_pred else '✗'}")

In [None]:
test_examples = [
    "This movie was absolutely fantastic!",
    "I hated this film, it was terrible.",
    "The movie was okay, nothing special.",
    "Best movie I've ever seen!",
    "Worst experience ever."
]

test_models_with_examples(test_examples)

Testing models with custom examples:

Example 1: 'This movie was absolutely fantastic!'
Teacher Prediction: Negative (confidence: 0.5789)
Student Prediction: Positive (confidence: 0.5834)
Agreement: ✗

Example 2: 'I hated this film, it was terrible.'
Teacher Prediction: Negative (confidence: 0.6406)
Student Prediction: Negative (confidence: 0.7034)
Agreement: ✓

Example 3: 'The movie was okay, nothing special.'
Teacher Prediction: Negative (confidence: 0.5556)
Student Prediction: Negative (confidence: 0.5717)
Agreement: ✓

Example 4: 'Best movie I've ever seen!'
Teacher Prediction: Negative (confidence: 0.5909)
Student Prediction: Positive (confidence: 0.6234)
Agreement: ✗

Example 5: 'Worst experience ever.'
Teacher Prediction: Negative (confidence: 0.5885)
Student Prediction: Negative (confidence: 0.6963)
Agreement: ✓
