# Knowledge Distillation with BERT

## Overview
A simple implementation of knowledge distillation using BERT for sentiment classification.

In [2]:
# Import necessary libraries
import torch
import torch.nn as nn
from transformers import BertForSequenceClassification, BertTokenizer
from datasets import load_dataset
from torch.utils.data import DataLoader, TensorDataset

# Check for GPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')

Using device: cuda


In [3]:
# Load dataset
dataset = load_dataset('glue', 'sst2')['train']

# Initialize tokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# Tokenize and prepare data
def prepare_data(examples, max_length=128):
    encodings = tokenizer(
        examples['sentence'], 
        truncation=True, 
        padding='max_length', 
        max_length=max_length
    )
    
    return {
        'input_ids': torch.tensor(encodings['input_ids']),
        'attention_mask': torch.tensor(encodings['attention_mask']),
        'labels': torch.tensor(examples['label'])
    }

# Prepare dataset
processed_data = prepare_data(dataset[:1000])  # Use first 1000 samples for demonstration

Generating train split: 100%|██████████| 67349/67349 [00:00<00:00, 1957806.98 examples/s]
Generating validation split: 100%|██████████| 872/872 [00:00<00:00, 599186.29 examples/s]
Generating test split: 100%|██████████| 1821/1821 [00:00<00:00, 1031163.44 examples/s]


In [4]:
# Knowledge Distillation Loss Function
def knowledge_distillation_loss(student_outputs, teacher_outputs, labels, temperature=2.0, alpha=0.5):
    # Soft target loss
    soft_loss = nn.KLDivLoss(reduction='batchmean')(
        torch.log_softmax(student_outputs / temperature, dim=1),
        torch.softmax(teacher_outputs / temperature, dim=1)
    )
    
    # Hard target loss
    hard_loss = nn.CrossEntropyLoss()(student_outputs, labels)
    
    # Combine losses
    return alpha * soft_loss + (1 - alpha) * hard_loss

In [5]:
# Create DataLoader
dataset = TensorDataset(
    processed_data['input_ids'], 
    processed_data['attention_mask'], 
    processed_data['labels']
)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

# Initialize teacher and student models
teacher_model = BertForSequenceClassification.from_pretrained('bert-base-uncased').to(device)
student_model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=2).to(device)

# Freeze teacher model
for param in teacher_model.parameters():
    param.requires_grad = False

# Training Setup
optimizer = torch.optim.AdamW(student_model.parameters(), lr=2e-5)
epochs = 3

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [6]:
# Training Loop
for epoch in range(epochs):
    student_model.train()
    teacher_model.eval()
    total_loss = 0
    
    for batch in dataloader:
        # Unpack batch
        input_ids, attention_mask, labels = [b.to(device) for b in batch]
        
        # Zero gradients
        optimizer.zero_grad()
        
        # Forward pass
        with torch.no_grad():
            teacher_output = teacher_model(input_ids, attention_mask=attention_mask)
        
        student_output = student_model(input_ids, attention_mask=attention_mask)
        
        # Compute loss
        loss = knowledge_distillation_loss(
            student_output.logits, 
            teacher_output.logits, 
            labels
        )
        
        # Backward pass
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
    
    print(f'Epoch {epoch+1}, Loss: {total_loss/len(dataloader)}')

# Save student model
torch.save(student_model.state_dict(), 'student_model.pt')

Epoch 1, Loss: 0.32823586091399193
Epoch 2, Loss: 0.21574989007785916
Epoch 3, Loss: 0.17426472157239914
