In [None]:
"""
    DistilBERT training via knowledge distillation from BERT using PyTorch and Hugging Face Transformers.
"""
# !pip install torch transformers datasets

In [39]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import BertTokenizer, BertForSequenceClassification
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification
from datasets import load_dataset
from torch.utils.data import DataLoader
from transformers import AdamW 
from tqdm import tqdm

In [None]:
"""
    AdamW is a popular optimization algorithm in deep learning, especially well-suited for training models like Transformers (e.g., BERT, GPT, etc.). 
    It is a variant of the Adam optimizer that introduces a correct way to apply weight decay (L2 regularization).
    AdamW helps prevent overfitting while maintaining the benefits of Adam (adaptive learning rates, momentum).
    It is the default optimizer in Hugging Face Transformers and many other frameworks for fine-tuning pre-trained language models.
"""

In [None]:
# Load and Preprocess the Dataset
imdb_dataset = load_dataset("imdb", split='train')
imdb_dataset_shuffle = imdb_dataset.shuffle(seed=42)  # Shuffle the full train split
dataset = imdb_dataset_shuffle.select(range(int(0.02 * len(imdb_dataset_shuffle))))  # Take 2% of the shuffled dataset randomly
dataset = dataset.train_test_split(test_size=0.2)
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased") # Loads the tokenizer that corresponds to bert-base-uncased
"""
    The IMDB dataset is a popular dataset used for binary sentiment classification—determining whether a movie review is positive or negative.
"""

In [93]:
# Tokenizing the Text
def encode_batch(batch):
    return tokenizer(batch['text'], truncation=True, padding='max_length', max_length=256)
"""
    encode_batch(): Tokenizes batches of texts, with:
        - truncation: Cuts long reviews down to max_length.
        - padding='max_length': Ensures uniform tensor sizes.
        - max_length=256: Keeps sequences to 256 tokens max.
"""

dataset = dataset.map(encode_batch, batched=True) # Applies tokenizer to all samples 
dataset.set_format(type='torch', columns=['input_ids', 'attention_mask', 'label']) # Converts dataset to PyTorch tensors
"""
    It ensures that each time you access a sample from the dataset, it returns a dictionary like:
        {
            'input_ids': tensor(...),
            'attention_mask': tensor(...),
            'label': tensor(...)
        }
"""

print(">>> length of dataset: ", len(dataset['train']))
print(">>> labels of dataset: ", dataset['train']['label'])

print(">>> first sample from dataset: ")
print(dataset['train'][0])
print(">>> shape of input_ids: ", dataset['train'][0]['input_ids'].shape)
print(">>> shape of attention_mask: ", dataset['train'][0]['attention_mask'].shape)
print(">>> shape of label: ", dataset['train'][0]['label'].shape)

Map:   0%|          | 0/400 [00:00<?, ? examples/s]

Map:   0%|          | 0/100 [00:00<?, ? examples/s]

>>> length of dataset:  400
>>> labels of dataset:  tensor([1, 1, 1, 0, 1, 0, 0, 0, 0, 1, 1, 0, 1, 0, 1, 1, 0, 1, 0, 0, 1, 0, 1, 0,
        0, 1, 0, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 1,
        1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 1, 1, 1, 1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0,
        1, 0, 1, 1, 0, 0, 0, 1, 1, 0, 0, 0, 1, 0, 0, 1, 1, 0, 1, 1, 0, 1, 1, 1,
        1, 1, 0, 1, 1, 1, 0, 0, 1, 0, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0,
        1, 0, 1, 1, 1, 1, 1, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 1, 1, 1, 1, 0, 1, 1,
        1, 0, 1, 1, 0, 1, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0,
        0, 0, 0, 0, 0, 1, 0, 1, 1, 0, 1, 1, 1, 0, 0, 1, 1, 0, 0, 1, 0, 0, 0, 0,
        1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 1, 0, 1, 1, 0, 1, 1,
        0, 1, 1, 1, 0, 1, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 1, 0, 1,
        0, 0, 0, 1, 1, 0, 0, 1, 0, 1, 1, 1, 0, 1, 1, 0, 0, 0, 1, 0, 1, 1, 0, 0,
        0, 1, 1, 1, 0, 0, 1, 0, 1, 0, 1, 1, 1, 0, 1, 0, 1, 0, 1, 1, 

In [94]:
# Apply DataLoader
train_loader = DataLoader(dataset['train'], batch_size=8, shuffle=True) # Prepares data loader for batching
print(">>> length of train_loader: ", len(train_loader))

first_batch = next(iter(train_loader)) # Get first batch
print(">>> shape of input_ids in first_batch: ", first_batch['input_ids'].shape)
print(">>> shape of attention_mask in first_batch: ", first_batch['attention_mask'].shape)
print(">>> shape of label in first_batch: ", first_batch['label'].shape)

>>> length of train_loader:  50
>>> shape of input_ids in first_batch:  torch.Size([8, 256])
>>> shape of attention_mask in first_batch:  torch.Size([8, 256])
>>> shape of label in first_batch:  torch.Size([8])


In [None]:
# Load Teacher and Student Models
teacher_model = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=2) # Full BERT model fine-tuned for sequence classification
student_model = DistilBertForSequenceClassification.from_pretrained("distilbert-base-uncased", num_labels=2) # Lighter, faster DistilBERT model also for classification

teacher_model.eval() # Freezes teacher_model for inference only
student_model.train() # Prepares student_model for training

In [97]:
# Define Distillation Loss
class DistillationLoss(nn.Module):

    def __init__(self, temperature=2.0, alpha=0.5):
        super(DistillationLoss, self).__init__()
        self.temperature = temperature
        self.alpha = alpha
        self.ce_loss = nn.CrossEntropyLoss()

    def forward(self, student_logits, teacher_logits, true_labels):
        eps = 1e-7  # small epsilon to prevent log(0)

        # Soft targets: distillation loss
        soft_teacher = F.softmax(teacher_logits / self.temperature, dim=1)
        soft_teacher = torch.clamp(soft_teacher, min=eps, max=1.0) # Clamp to avoid log(0) -> -inf
        soft_student = F.log_softmax(student_logits / self.temperature, dim=1)
        kd_loss = F.kl_div(soft_student, soft_teacher, reduction='batchmean') * (self.temperature ** 2)
        """
            F.kl_div(input, target) expects:
                - input: log-probabilities (i.e., log_softmax)
                - target: probabilities (i.e., softmax)
            If both are given as softmax, the KL loss can go negative or become numerically unstable.
        """

        # Hard targets: standard classification loss
        ce_loss = self.ce_loss(student_logits, true_labels)

        # Total loss
        return self.alpha * kd_loss + (1. - self.alpha) * ce_loss

In [98]:
# Set Up Optimizer and Distillation Loss
optimizer = AdamW(student_model.parameters(), lr=1e-5)
kd_loss_fn = DistillationLoss(temperature=2.0, alpha=0.5)

In [99]:
# Setup Device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if torch.cuda.is_available():
    print(f"CUDA available: Yes")
    print(f"Total GPUs: {torch.cuda.device_count()}")
    for i in range(torch.cuda.device_count()):
        print(f"\n--- GPU {i} ---")
        print(f"Name: {torch.cuda.get_device_name(i)}")
        print(f"Capability: {torch.cuda.get_device_capability(i)}")
        print(f"Memory Allocated: {torch.cuda.memory_allocated(i)/1024**2:.2f} MB")
        print(f"Memory Reserved: {torch.cuda.memory_reserved(i)/1024**2:.2f} MB")
else:
    print("CUDA available: No. Using CPU.")

teacher_model.to(device)
student_model.to(device)

CUDA available: Yes
Total GPUs: 1

--- GPU 0 ---
Name: NVIDIA RTX 2000 Ada Generation Laptop GPU
Capability: (8, 9)
Memory Allocated: 1055.77 MB
Memory Reserved: 3620.00 MB


DistilBertForSequenceClassification(
  (distilbert): DistilBertModel(
    (embeddings): Embeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (transformer): Transformer(
      (layer): ModuleList(
        (0-5): 6 x TransformerBlock(
          (attention): MultiHeadSelfAttention(
            (dropout): Dropout(p=0.1, inplace=False)
            (q_lin): Linear(in_features=768, out_features=768, bias=True)
            (k_lin): Linear(in_features=768, out_features=768, bias=True)
            (v_lin): Linear(in_features=768, out_features=768, bias=True)
            (out_lin): Linear(in_features=768, out_features=768, bias=True)
          )
          (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (ffn): FFN(
            (dropout): Dropout(p=0.1, inplace=False)
 

In [100]:
# Training Loop with Knowledge Distillation
epochs = 12

for epoch in range(epochs):
    total_loss = 0

    for batch in tqdm(train_loader):
        input_ids = batch['input_ids'].to(device) # torch.Size([8, 256])
        attention_mask = batch['attention_mask'].to(device) # torch.Size([8, 256])
        labels = batch['label'].to(device) # torch.Size([8])
    
        with torch.no_grad(): # Runs the teacher model in inference mode (no gradients computed)
            teacher_outputs = teacher_model(input_ids=input_ids, attention_mask=attention_mask)
            teacher_logits = teacher_outputs.logits 
            """
                For BERT, the return value (teacher_logits) is specifically a SequenceClassifierOutput object:

                    SequenceClassifierOutput(
                        loss=None,  # only present if labels are passed
                        logits=tensor(...),
                        hidden_states=None,  # optional
                        attentions=None      # optional
                    )
            """

        student_outputs = student_model(input_ids=input_ids, attention_mask=attention_mask)
        student_logits = student_outputs.logits

        loss = kd_loss_fn(student_logits, teacher_logits, labels)
        if torch.isnan(loss):
            print("Loss is NaN! Debug info:")
            print("Student logits:", student_logits)
            print("Teacher logits:", teacher_logits)
            print("Labels:", labels)
            break 

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    print(f">>> Epoch {epoch+1}, Loss: {total_loss/len(train_loader):.4f}")

100%|██████████| 50/50 [00:14<00:00,  3.56it/s]


>>> Epoch 1, Loss: 0.3520


100%|██████████| 50/50 [00:13<00:00,  3.79it/s]


>>> Epoch 2, Loss: 0.3068


100%|██████████| 50/50 [00:13<00:00,  3.74it/s]


>>> Epoch 3, Loss: 0.2654


100%|██████████| 50/50 [00:13<00:00,  3.69it/s]


>>> Epoch 4, Loss: 0.2430


100%|██████████| 50/50 [00:14<00:00,  3.54it/s]


>>> Epoch 5, Loss: 0.2356


100%|██████████| 50/50 [00:13<00:00,  3.71it/s]


>>> Epoch 6, Loss: 0.2317


100%|██████████| 50/50 [00:13<00:00,  3.66it/s]


>>> Epoch 7, Loss: 0.2311


100%|██████████| 50/50 [00:13<00:00,  3.70it/s]


>>> Epoch 8, Loss: 0.2301


100%|██████████| 50/50 [00:13<00:00,  3.63it/s]


>>> Epoch 9, Loss: 0.2296


100%|██████████| 50/50 [00:14<00:00,  3.48it/s]


>>> Epoch 10, Loss: 0.2293


100%|██████████| 50/50 [00:13<00:00,  3.72it/s]


>>> Epoch 11, Loss: 0.2285


100%|██████████| 50/50 [00:13<00:00,  3.72it/s]

>>> Epoch 12, Loss: 0.2283



