# Knowledge Distillation with BERT

## Overview
An advanced implementation of knowledge distillation using a custom distillation loss for sentiment classification.

In [1]:
# Import necessary libraries
import torch
import torch.nn as nn
from transformers import BertForSequenceClassification, BertTokenizer, BertModel
from datasets import load_dataset
from torch.utils.data import DataLoader, TensorDataset
from tqdm import tqdm

# Import the custom distillation loss
import sys
sys.path.append('/app/src')
from loss.distillation import DistillationLoss

# Check for GPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')

  from .autonotebook import tqdm as notebook_tqdm


Using device: cuda


In [2]:
# Load dataset
dataset = load_dataset('glue', 'sst2')['train']

# Initialize tokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# Tokenize and prepare data
def prepare_data(examples, max_length=128):
    encodings = tokenizer(
        examples['sentence'], 
        truncation=True, 
        padding='max_length', 
        max_length=max_length
    )
    
    return {
        'input_ids': torch.tensor(encodings['input_ids']),
        'attention_mask': torch.tensor(encodings['attention_mask']),
        'labels': torch.tensor(examples['label'])
    }

# Prepare dataset
processed_data = prepare_data(dataset[:1000])  # Use first 1000 samples for demonstration

In [3]:
# Custom Student Model with Reduced Complexity
class StudentBertModel(nn.Module):
    def __init__(self, num_labels=2, hidden_size=768, num_hidden_layers=6):
        super(StudentBertModel, self).__init__()
        
        # Base BERT model with reduced layers
        self.bert = BertModel.from_pretrained('bert-base-uncased', 
                                               num_hidden_layers=num_hidden_layers)
        
        # Simplified classification head
        self.classifier = nn.Sequential(
            nn.Linear(hidden_size, hidden_size // 2),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_size // 2, num_labels)
        )
        
    def forward(self, input_ids, attention_mask):
        # Get BERT outputs
        bert_outputs = self.bert(
            input_ids=input_ids, 
            attention_mask=attention_mask
        )
        
        # Use the pooled output for classification
        pooled_output = bert_outputs.pooler_output
        logits = self.classifier(pooled_output)
        
        return logits
    
    def get_attention_and_values(self, input_ids, attention_mask):
        # Get BERT outputs including attention
        bert_outputs = self.bert(
            input_ids=input_ids, 
            attention_mask=attention_mask,
            output_attentions=True
        )
        
        # Return attention and values
        return (
            bert_outputs.attentions,  # Attention weights
            bert_outputs.last_hidden_state  # Value vectors
        )

In [4]:
# Training Hyperparameters
batch_size = 32
num_epochs = 100
learning_rate = 2e-5

# Create DataLoader
dataset = TensorDataset(
    processed_data['input_ids'], 
    processed_data['attention_mask'], 
    processed_data['labels']
)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Initialize teacher and student models
teacher_model = BertForSequenceClassification.from_pretrained('bert-base-uncased').to(device)
teacher_base_model = teacher_model.bert  # Extract the base BERT model
student_model = StudentBertModel(num_labels=2).to(device)

# Freeze teacher model
for param in teacher_model.parameters():
    param.requires_grad = False

# Training Setup
optimizer = torch.optim.AdamW(student_model.parameters(), lr=2e-5)
distillation_criterion = DistillationLoss(temperature=2.0).to(device)
classification_criterion = nn.CrossEntropyLoss()



Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [5]:
# Training Loop
for epoch in range(num_epochs):
    student_model.train()
    teacher_model.eval()
    total_loss = 0
    
    progress_bar = tqdm(dataloader, desc=f'Epoch {epoch+1}/{num_epochs}', leave=False)
    for batch in progress_bar:
        # Unpack batch
        input_ids, attention_mask, labels = [b.to(device) for b in batch]
        
        # Zero gradients
        optimizer.zero_grad()
        
        # Forward pass for teacher
        with torch.no_grad():
            teacher_output = teacher_model(input_ids, attention_mask=attention_mask)
            teacher_bert_output = teacher_base_model(
                input_ids, 
                attention_mask=attention_mask, 
                output_attentions=True
            )
            teacher_attention = teacher_bert_output.attentions
            teacher_values = teacher_bert_output.last_hidden_state
        
        # Forward pass for student
        student_logits = student_model(input_ids, attention_mask)
        student_attention, student_values = student_model.get_attention_and_values(input_ids, attention_mask)
        
        # Compute distillation loss
        # Aggregate attention and values across layers
        student_A = torch.mean(torch.stack(student_attention), dim=0)
        teacher_A = torch.mean(torch.stack(teacher_attention), dim=0)
        
        knowledge_loss = distillation_criterion(
            teacher_A, teacher_values,
            student_A, student_values
        )
        
        # Compute classification loss
        classification_loss = classification_criterion(student_logits, labels)
        
        # Combine losses
        total_batch_loss = knowledge_loss
        
        # Backward pass
        total_batch_loss.backward()
        optimizer.step()
        
        total_loss += total_batch_loss.item()
        progress_bar.set_description(f'Epoch {epoch+1}/{num_epochs}')
    
    print(f'Epoch {epoch+1}/{num_epochs}, Loss: {total_loss/len(dataloader):.4f}')

# Save student model
torch.save(student_model.state_dict(), './saved_models/student_model.pt')

                                                            

Epoch 1/100, Loss: 46.6790


                                                            

Epoch 2/100, Loss: 25.1429


                                                            

Epoch 3/100, Loss: 19.3854


                                                            

Epoch 4/100, Loss: 17.0347


                                                            

Epoch 5/100, Loss: 15.4991


                                                            

Epoch 6/100, Loss: 14.4098


                                                            

Epoch 7/100, Loss: 13.6065


                                                            

Epoch 8/100, Loss: 12.9592


                                                            

Epoch 9/100, Loss: 12.5766


                                                             

Epoch 10/100, Loss: 12.1409


                                                             

Epoch 11/100, Loss: 11.6569


                                                             

Epoch 12/100, Loss: 11.2542


                                                             

Epoch 13/100, Loss: 10.8037


                                                             

Epoch 14/100, Loss: 10.3995


                                                             

Epoch 15/100, Loss: 9.9858


                                                             

Epoch 16/100, Loss: 9.4671


                                                             

Epoch 17/100, Loss: 9.0859


                                                             

Epoch 18/100, Loss: 8.6086


                                                             

Epoch 19/100, Loss: 8.2478


                                                             

Epoch 20/100, Loss: 7.7973


                                                             

Epoch 21/100, Loss: 7.4185


                                                             

Epoch 22/100, Loss: 7.1244


                                                             

Epoch 23/100, Loss: 6.9500


                                                             

Epoch 24/100, Loss: 6.7708


                                                             

Epoch 25/100, Loss: 6.5917


                                                             

Epoch 26/100, Loss: 6.3951


                                                             

Epoch 27/100, Loss: 6.2374


                                                             

Epoch 28/100, Loss: 6.1719


                                                             

Epoch 29/100, Loss: 5.9977


                                                             

Epoch 30/100, Loss: 5.8589


                                                             

Epoch 31/100, Loss: 5.7256


                                                             

Epoch 32/100, Loss: 5.6329


                                                             

Epoch 33/100, Loss: 5.4807


                                                             

Epoch 34/100, Loss: 5.4687


                                                             

Epoch 35/100, Loss: 5.4461


                                                             

Epoch 36/100, Loss: 5.2989


                                                             

Epoch 37/100, Loss: 5.2298


                                                             

Epoch 38/100, Loss: 5.1348


                                                             

Epoch 39/100, Loss: 5.0557


                                                             

Epoch 40/100, Loss: 4.9366


                                                             

Epoch 41/100, Loss: 4.9010


                                                             

Epoch 42/100, Loss: 4.9128


                                                             

Epoch 43/100, Loss: 4.7542


                                                             

Epoch 44/100, Loss: 4.7339


                                                             

Epoch 45/100, Loss: 4.6436


                                                             

Epoch 46/100, Loss: 4.5578


                                                             

Epoch 47/100, Loss: 4.5086


                                                             

Epoch 48/100, Loss: 4.5226


                                                             

Epoch 49/100, Loss: 4.3983


                                                             

Epoch 50/100, Loss: 4.3765


                                                             

Epoch 51/100, Loss: 4.3166


                                                             

Epoch 52/100, Loss: 4.3178


                                                             

Epoch 53/100, Loss: 4.2163


                                                             

Epoch 54/100, Loss: 4.1268


                                                             

Epoch 55/100, Loss: 4.1149


                                                             

Epoch 56/100, Loss: 4.1037


                                                             

Epoch 57/100, Loss: 4.0182


                                                             

Epoch 58/100, Loss: 3.9763


                                                             

Epoch 59/100, Loss: 3.9158


                                                             

Epoch 60/100, Loss: 3.8992


                                                             

Epoch 61/100, Loss: 3.9112


                                                             

Epoch 62/100, Loss: 3.7890


                                                             

Epoch 63/100, Loss: 3.7446


                                                             

Epoch 64/100, Loss: 3.7338


                                                             

Epoch 65/100, Loss: 3.7657


                                                             

Epoch 66/100, Loss: 3.7190


                                                             

Epoch 67/100, Loss: 3.6389


                                                             

Epoch 68/100, Loss: 3.6548


                                                             

Epoch 69/100, Loss: 3.5790


                                                             

Epoch 70/100, Loss: 3.5443


                                                             

Epoch 71/100, Loss: 3.5191


                                                             

Epoch 72/100, Loss: 3.5111


                                                             

Epoch 73/100, Loss: 3.4685


                                                             

Epoch 74/100, Loss: 3.4081


                                                             

Epoch 75/100, Loss: 3.4019


                                                             

Epoch 76/100, Loss: 3.4132


                                                             

Epoch 77/100, Loss: 3.3539


                                                             

Epoch 78/100, Loss: 3.3524


                                                             

Epoch 79/100, Loss: 3.2911


                                                             

Epoch 80/100, Loss: 3.2573


                                                             

Epoch 81/100, Loss: 3.2363


                                                             

Epoch 82/100, Loss: 3.2045


                                                             

Epoch 83/100, Loss: 3.1793


                                                             

Epoch 84/100, Loss: 3.1681


                                                             

Epoch 85/100, Loss: 3.1076


                                                             

Epoch 86/100, Loss: 3.0927


                                                             

Epoch 87/100, Loss: 3.0599


                                                             

Epoch 88/100, Loss: 3.0603


                                                             

Epoch 89/100, Loss: 3.0501


                                                             

Epoch 90/100, Loss: 3.0236


                                                             

Epoch 91/100, Loss: 3.0010


                                                             

Epoch 92/100, Loss: 2.9923


                                                             

Epoch 93/100, Loss: 2.9531


                                                             

Epoch 94/100, Loss: 2.9261


                                                             

Epoch 95/100, Loss: 2.9155


                                                             

Epoch 96/100, Loss: 2.9028


                                                             

Epoch 97/100, Loss: 2.8747


                                                             

Epoch 98/100, Loss: 2.8997


                                                             

Epoch 99/100, Loss: 2.8340


                                                              

Epoch 100/100, Loss: 2.8285


RuntimeError: Parent directory ./saved_models does not exist.

In [6]:
torch.save(student_model.state_dict(), 'student_model.pt')
