# Knowledge Distillation with BERT on GLUE Benchmark

In [1]:
# Import necessary libraries
import torch
import torch.nn as nn
from transformers import BertForSequenceClassification, BertTokenizer, BertModel
from datasets import load_dataset
from torch.utils.data import DataLoader, TensorDataset
from tqdm import tqdm

# Import the custom distillation loss
import sys
sys.path.append('/app/src')
from loss.distillation import DistillationLoss

# Check for GPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')

  from .autonotebook import tqdm as notebook_tqdm


Using device: cuda


In [3]:
# Load GLUE dataset
glue_dataset = load_dataset('glue','sst2')

# Initialize tokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# Tokenize and prepare data
def prepare_data(examples, max_length=128):
    encodings = tokenizer(
        examples['sentence'], 
        truncation=True, 
        padding='max_length', 
        max_length=max_length
    )
    
    return {
        'input_ids': torch.tensor(encodings['input_ids']),
        'attention_mask': torch.tensor(encodings['attention_mask']),
        'labels': torch.tensor(examples['label'])
    }

# Prepare dataset for all GLUE tasks
glue_datasets = {}
for task in glue_dataset.keys():
    glue_datasets[task] = {
        split: prepare_data(glue_dataset[task][split])
        for split in ['train', 'validation', 'test']
    }

KeyError: "Column train not in the dataset. Current columns in the dataset: ['sentence', 'label', 'idx']"

In [None]:
# Custom Student Model with Reduced Complexity
class StudentBertModel(nn.Module):
    def __init__(self, num_labels=2, hidden_size=768, num_hidden_layers=3):
        super(StudentBertModel, self).__init__()
        
        # Base BERT model with reduced layers
        self.bert = BertModel.from_pretrained('bert-base-uncased', 
                                               num_hidden_layers=num_hidden_layers)
        
        # Simplified classification head
        self.classifier = nn.Sequential(
            nn.Linear(hidden_size, hidden_size // 2),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_size // 2, num_labels)
        )
        
    def forward(self, input_ids, attention_mask):
        # Get BERT outputs
        bert_outputs = self.bert(
            input_ids=input_ids, 
            attention_mask=attention_mask
        )
        
        # Use the pooled output for classification
        pooled_output = bert_outputs.pooler_output
        logits = self.classifier(pooled_output)
        
        return logits
    
    def get_attention_and_values(self, input_ids, attention_mask):
        # Get BERT outputs including attention
        bert_outputs = self.bert(
            input_ids=input_ids, 
            attention_mask=attention_mask,
            output_attentions=True
        )
        
        # Return attention and values
        return (
            bert_outputs.attentions,  # Attention weights
            bert_outputs.last_hidden_state  # Value vectors
        )

In [None]:
# Training Setup
batch_size = 32
num_epochs = 3
learning_rate = 2e-5

# Initialize teacher and student models
teacher_model = BertForSequenceClassification.from_pretrained('bert-base-uncased').to(device)
teacher_base_model = teacher_model.bert  # Extract the base BERT model
student_model = StudentBertModel(num_labels=2).to(device)

# Freeze teacher model
for param in teacher_model.parameters():
    param.requires_grad = False

# Optimizers and loss functions
optimizer = torch.optim.AdamW(student_model.parameters(), lr=learning_rate)
distillation_criterion = DistillationLoss(temperature=2.0).to(device)
classification_criterion = nn.CrossEntropyLoss()

# Training function
def train_model(task):
    train_dataset = TensorDataset(
        glue_datasets[task]['train']['input_ids'],
        glue_datasets[task]['train']['attention_mask'],
        glue_datasets[task]['train']['labels']
    )
    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

    student_model.train()
    teacher_model.eval()
    total_loss = 0

    progress_bar = tqdm(train_dataloader, desc=f'Training {task}', leave=False)
    for batch in progress_bar:
        # Unpack batch
        input_ids, attention_mask, labels = [b.to(device) for b in batch]
        
        # Zero gradients
        optimizer.zero_grad()
        
        # Forward pass for teacher
        with torch.no_grad():
            teacher_output = teacher_model(input_ids, attention_mask=attention_mask)
            teacher_bert_output = teacher_base_model(
                input_ids, 
                attention_mask=attention_mask, 
                output_attentions=True
            )
            teacher_attention = teacher_bert_output.attentions
            teacher_values = teacher_bert_output.last_hidden_state
        
        # Forward pass for student
        student_logits = student_model(input_ids, attention_mask)
        student_attention, student_values = student_model.get_attention_and_values(input_ids, attention_mask)
        
        # Compute distillation loss
        # Aggregate attention and values across layers
        student_A = torch.mean(torch.stack(student_attention), dim=0)
        teacher_A = torch.mean(torch.stack(teacher_attention), dim=0)
        
        knowledge_loss = distillation_criterion(
            teacher_A, teacher_values,
            student_A, student_values
        )
        
        # Compute classification loss
        classification_loss = classification_criterion(student_logits, labels)
        
        # Combine losses
        total_batch_loss = 0.5 * knowledge_loss + 0.5 * classification_loss
        
        # Backward pass
        total_batch_loss.backward()
        optimizer.step()
        
        total_loss += total_batch_loss.item()
        progress_bar.set_description(f'Training {task}, Loss: {total_loss/(len(progress_bar)+1):.4f}')

    return total_loss / len(train_dataloader)

In [None]:
# Train on all GLUE tasks
for task in glue_dataset.keys():
    print(f'Training on {task} task')
    train_model(task)
    print(f'Finished training on {task} task')