In [None]:
"""
    DistilBERT training via knowledge distillation from BERT using PyTorch and Hugging Face Transformers.
"""
# !pip install torch transformers datasets

In [None]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim.lr_scheduler import SequentialLR, LinearLR, StepLR
from transformers import BertTokenizer, BertForSequenceClassification
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification
from datasets import load_dataset
from torch.utils.data import DataLoader
from transformers import AdamW 
from tqdm import tqdm

In [None]:
"""
    AdamW is a widely used optimization algorithm in deep learning, especially well-suited for training Transformer models (e.g., BERT, GPT, etc.).
    It is a variant of the Adam optimizer that decouples weight decay from the gradient-based update, instead of folding it into the gradient as an L2 penalty.
    This decoupled weight decay helps prevent overfitting while maintaining the benefits of Adam (adaptive learning rates, momentum).
    It is the default optimizer in Hugging Face Transformers and many other frameworks for fine-tuning pre-trained language models.
"""

In [None]:
# Load and Preprocess the Dataset
# The IMDB dataset is a popular dataset for binary sentiment classification:
# deciding whether a movie review is positive or negative.
imdb_dataset = load_dataset("imdb", split='train')
imdb_dataset_shuffle = imdb_dataset.shuffle(seed=42)  # Shuffle the full train split (seeded)
dataset = imdb_dataset_shuffle.select(range(int(0.1 * len(imdb_dataset_shuffle))))  # Keep the first 10% of the shuffled split, i.e. a random 10% sample
# Seed the split as well: without a seed the held-out 20% differs on every run,
# so reported test accuracy would not be comparable between runs.
dataset = dataset.train_test_split(test_size=0.2, seed=42)
# Tokenizer matching bert-base-uncased. The same token ids are later fed to the DistilBERT student;
# NOTE(review): this presumes distilbert-base-uncased shares the bert-base-uncased vocabulary — confirm.
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

In [None]:
# Tokenizing the Text
def encode_batch(batch):
    """
        Tokenize a batch of reviews with the notebook-level `tokenizer`.

        Reads batch['text'] and returns the tokenizer output, configured with:
            - truncation=True: cuts long reviews down to max_length.
            - padding='max_length': pads every sequence to the same length, giving uniform tensor sizes.
            - max_length=256: keeps sequences to 256 tokens max.
    """
    tokenizer_options = dict(truncation=True, padding='max_length', max_length=256)
    return tokenizer(batch['text'], **tokenizer_options)

# Apply the tokenizer to all samples
dataset = dataset.map(encode_batch, batched=True)

# Convert the selected columns to PyTorch tensors, so that accessing a sample returns a dictionary like:
#     {
#         'input_ids': tensor(...),
#         'attention_mask': tensor(...),
#         'label': tensor(...)
#     }
dataset.set_format(type='torch', columns=['input_ids', 'attention_mask', 'label'])

train_split = dataset['train']
print(">>> length of dataset: ", len(train_split))
print(">>> labels of dataset: ", train_split['label'])  # Eyeball whether the slice is roughly balanced

# Inspect one sample and the shape of each of its fields
first_sample = train_split[0]
print(">>> first sample from dataset: ")
print(first_sample)
print(">>> shape of input_ids: ", first_sample['input_ids'].shape)
print(">>> shape of attention_mask: ", first_sample['attention_mask'].shape)
print(">>> shape of label: ", first_sample['label'].shape)

In [None]:
# Apply DataLoader
# Batches of 8 samples, reshuffled on every pass over the training split
train_loader = DataLoader(dataset['train'], shuffle=True, batch_size=8)
print(">>> length of train_loader: ", len(train_loader))

# Pull a single batch to sanity-check the batched tensor shapes
first_batch = next(iter(train_loader))
input_ids_shape = first_batch['input_ids'].shape
attention_mask_shape = first_batch['attention_mask'].shape
label_shape = first_batch['label'].shape
print(">>> shape of input_ids in first_batch: ", input_ids_shape)
print(">>> shape of attention_mask in first_batch: ", attention_mask_shape)
print(">>> shape of label in first_batch: ", label_shape)

In [None]:
# Load Teacher and Student Models
# NOTE(review): "bert-base-uncased" looks like the generic pre-trained checkpoint rather than one
# fine-tuned on IMDB, and no cell visible in this notebook trains the teacher. If that is right, its
# classification head is untrained and the soft targets it produces are uninformative — confirm, and
# fine-tune the teacher (or load a fine-tuned checkpoint) before distilling from it.
teacher_checkpoint = "bert-base-uncased"
student_checkpoint = "distilbert-base-uncased"

# Teacher: full BERT with a 2-class sequence-classification head
teacher_model = BertForSequenceClassification.from_pretrained(teacher_checkpoint, num_labels=2)
# Student: lighter DistilBERT with the same 2-class head
student_model = DistilBertForSequenceClassification.from_pretrained(student_checkpoint, num_labels=2)

# eval() only switches the teacher to evaluation mode; it does not by itself disable gradients
teacher_model.eval()
# The student is the model being trained
student_model.train()

In [None]:
# Define Distillation Loss
class DistillationLoss(nn.Module):
    """
        Knowledge-distillation loss: weighted sum of a soft-target term and a hard-target term.

            loss = alpha * T^2 * KL(teacher_T || student_T) + (1 - alpha) * CE(student_logits, true_labels)

        where teacher_T / student_T are the softmax distributions of the logits divided by the temperature T.

        Args:
            temperature: Softening temperature T applied to both sets of logits. Must be > 0.
            alpha: Weight of the distillation (KL) term, in [0, 1]; the cross-entropy term gets 1 - alpha.

        Raises:
            ValueError: If temperature is not positive or alpha lies outside [0, 1].
    """

    def __init__(self, temperature=2.0, alpha=0.5):
        super().__init__()
        # Logits are divided by the temperature, so zero would produce inf/NaN losses
        # and a negative value would invert the distributions.
        if temperature <= 0:
            raise ValueError(f"temperature must be > 0, got {temperature}")
        if not 0.0 <= alpha <= 1.0:
            raise ValueError(f"alpha must be in [0, 1], got {alpha}")
        self.temperature = temperature
        self.alpha = alpha
        self.ce_loss = nn.CrossEntropyLoss()

    def forward(self, student_logits, teacher_logits, true_labels):
        """
            Compute the combined distillation loss for one batch.

            Args:
                student_logits: Student logits; softmax is taken over dim=1 (the class dimension).
                teacher_logits: Teacher logits with the same layout; treated as a constant target.
                true_labels: Ground-truth class indices, passed to nn.CrossEntropyLoss.

            Returns:
                Scalar tensor: alpha * kd_loss + (1 - alpha) * ce_loss.
        """
        eps = 1e-7  # small epsilon to keep teacher probabilities strictly positive

        # The teacher only supplies targets: detach so no gradient can flow back into it,
        # even if the caller did not compute its logits under torch.no_grad().
        teacher_logits = teacher_logits.detach()

        # Soft targets: distillation loss.
        # F.kl_div(input, target) expects log-probabilities as input (log_softmax) and
        # probabilities as target (softmax); passing probabilities for both gives a wrong,
        # possibly negative or numerically unstable, value.
        soft_teacher = F.softmax(teacher_logits / self.temperature, dim=1)
        soft_teacher = torch.clamp(soft_teacher, min=eps, max=1.0)  # Clamp to avoid log(0) -> -inf
        soft_student = F.log_softmax(student_logits / self.temperature, dim=1)
        # Scale by T^2 so the soft-target term stays comparable in magnitude as T changes
        kd_loss = F.kl_div(soft_student, soft_teacher, reduction='batchmean') * (self.temperature ** 2)

        # Hard targets: standard classification loss
        ce_loss = self.ce_loss(student_logits, true_labels)

        # Total loss
        return self.alpha * kd_loss + (1. - self.alpha) * ce_loss

In [None]:
# Set Up Optimizer, Distillation Loss and Learning Rate Scheduler
# Use PyTorch's AdamW (via the existing `torch.optim as optim` import): the `transformers.AdamW`
# class is deprecated in favour of it, which also makes the `from transformers import AdamW`
# import at the top of the notebook unnecessary.
# eps and weight_decay are pinned to the defaults of the deprecated class so the hyperparameters
# stay the same; raise weight_decay if regularization is wanted.
optimizer = optim.AdamW(student_model.parameters(), lr=1e-4, eps=1e-6, weight_decay=0.0)
kd_loss_fn = DistillationLoss(temperature=2.0, alpha=0.5)

# Warm-up scheduler: ramps the LR linearly from 1% to 100% of the base LR over the first 5 scheduler steps
# (the training loop steps the scheduler once per epoch)
warmup_scheduler = LinearLR(optimizer, start_factor=0.01, end_factor=1.0, total_iters=5)
# After warm-up, switch to StepLR: multiply the LR by 0.1 every 5 scheduler steps
main_scheduler = StepLR(optimizer, step_size=5, gamma=0.1)
# Combine warmup_scheduler with main_scheduler
scheduler = SequentialLR(optimizer,
                         schedulers=[warmup_scheduler, main_scheduler],
                         milestones=[5] # switch from warm-up to StepLR after 5 epochs
                        )

In [None]:
# Setup Device
cuda_available = torch.cuda.is_available()
device = torch.device("cuda" if cuda_available else "cpu")

if not cuda_available:
    print("CUDA available: No. Using CPU.")
else:
    gpu_count = torch.cuda.device_count()
    print("CUDA available: Yes")
    print(f"Total GPUs: {gpu_count}")
    # Report name, compute capability and current memory usage of every visible GPU
    for gpu_index in range(gpu_count):
        allocated_mb = torch.cuda.memory_allocated(gpu_index) / 1024**2
        reserved_mb = torch.cuda.memory_reserved(gpu_index) / 1024**2
        print(f"\n--- GPU {gpu_index} ---")
        print(f"Name: {torch.cuda.get_device_name(gpu_index)}")
        print(f"Capability: {torch.cuda.get_device_capability(gpu_index)}")
        print(f"Memory Allocated: {allocated_mb:.2f} MB")
        print(f"Memory Reserved: {reserved_mb:.2f} MB")

# Move both models onto the selected device
teacher_model.to(device)
student_model.to(device)

In [None]:
# Training Loop with Knowledge Distillation
epochs = 20

# The teacher only provides targets; keep it in evaluation mode
teacher_model.eval()

for epoch in range(epochs):
    # Re-enter training mode every epoch: the evaluation cell switches the student to eval(),
    # so re-running this cell afterwards would otherwise train in evaluation mode.
    student_model.train()
    total_loss = 0.0
    nan_detected = False

    for batch in tqdm(train_loader):
        input_ids = batch['input_ids'].to(device) # torch.Size([8, 256])
        attention_mask = batch['attention_mask'].to(device) # torch.Size([8, 256])
        labels = batch['label'].to(device) # torch.Size([8])

        with torch.no_grad(): # Runs the teacher model in inference mode (no gradients computed)
            # teacher_outputs is a SequenceClassifierOutput; its `loss` stays None because no labels
            # are passed, and its `logits` field is the tensor used as the distillation target.
            teacher_outputs = teacher_model(input_ids=input_ids, attention_mask=attention_mask)
            teacher_logits = teacher_outputs.logits

        student_outputs = student_model(input_ids=input_ids, attention_mask=attention_mask)
        student_logits = student_outputs.logits # Shape: [batch_size, num_classes] -> torch.Size([8, 2])

        loss = kd_loss_fn(student_logits, teacher_logits, labels)
        if torch.isnan(loss):
            print("Loss is NaN! Debug info:")
            print("Student logits:", student_logits)
            print("Teacher logits:", teacher_logits)
            print("Labels:", labels)
            nan_detected = True
            break

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    if nan_detected:
        # The inner `break` only leaves the batch loop. Stop the whole run here instead of
        # stepping the scheduler, printing a misleading epoch average and training on.
        print(f">>> Stopping training in epoch {epoch+1}: loss became NaN.")
        break

    scheduler.step()
    current_lr = optimizer.param_groups[-1]['lr']
    print(f">>> Epoch {epoch+1}, Loss: {total_loss/len(train_loader):.4f}, LR: {current_lr:.6f}")

In [None]:
# Save DistilBERT and BERT models
# (these file names are read again by the model-statistics cell to report the file sizes)
student_weights_path = 'DistilBERT_model_weights.pth'
teacher_weights_path = 'BERT_base_model_weights.pth'

torch.save(student_model.state_dict(), student_weights_path)
torch.save(teacher_model.state_dict(), teacher_weights_path)

# Load the state_dict back into each model from the files just written
student_model.load_state_dict(torch.load(student_weights_path))
teacher_model.load_state_dict(torch.load(teacher_weights_path))

In [None]:
# Test Loop for Evaluation
test_loader = DataLoader(dataset['test'], batch_size=4) # Batches of 4, in dataset order
print(">>> length of test_loader: ", len(test_loader))

student_model.eval()
correct = 0
total = 0
all_preds = []
all_labels = []

# Gradients are not needed for evaluation
with torch.no_grad():
    for batch in tqdm(test_loader):
        input_ids = batch['input_ids'].to(device) # torch.Size([4, 256])
        attention_mask = batch['attention_mask'].to(device) # torch.Size([4, 256])
        labels = batch['label'].to(device) # torch.Size([4])

        student_outputs = student_model(input_ids=input_ids, attention_mask=attention_mask)
        student_logits = student_outputs.logits
        # Predicted class = index of the largest logit for each sample
        predictions = student_logits.argmax(dim=1)

        # Running tally of correct predictions and samples seen
        correct += int((predictions == labels).sum())
        total += len(labels)

        # Keep per-sample predictions and labels as plain Python lists
        all_preds += predictions.cpu().tolist()
        all_labels += labels.cpu().tolist()

accuracy = correct / total
print(f">>> Test Accuracy of DistilBERT: {accuracy * 100:.2f}%")

In [None]:
def report_model_stats(model, model_name, weights_path):
    """
        Print parameter counts and the saved-weights file size for one model.

        Args:
            model: Module whose parameters() are counted.
            model_name: Label used in the printed lines (e.g. "DistilBERT").
            weights_path: Path of the saved weights file whose size is reported.

        Returns:
            Tuple (total_params, trainable_params, file_size), with file_size in MB.

        Raises:
            OSError: If weights_path does not exist or cannot be read.
    """
    # Total number of parameters
    total_params = sum(p.numel() for p in model.parameters())
    print(f">>> Total parameters of {model_name}: {total_params}")
    # Parameters that receive gradients (requires_grad=True)
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f">>> Trainable parameters of {model_name}: {trainable_params}")
    # Size on disk of the saved weights
    file_size = os.path.getsize(weights_path) / (1024 ** 2) # size in MB
    print(f">>> Model file size of {model_name}: {file_size:.2f} MB")
    return total_params, trainable_params, file_size


# Statistics for the DistilBERT student
report_model_stats(student_model, "DistilBERT", "DistilBERT_model_weights.pth")

print("---------------------------------------------------")

# Statistics for the BERT teacher. The results are bound to the same names the previous
# copy-pasted version left behind (holding the BERT values), in case other cells read them.
total_params, trainable_params, file_size = report_model_stats(teacher_model, "BERT", "BERT_base_model_weights.pth")