In [None]:
%pip install datasets   # if needed

In [132]:
import torch
from torch import nn
from torch.utils.data import DataLoader
from transformers import (BertTokenizer, BertForSequenceClassification, DistilBertConfig, DistilBertForSequenceClassification)
from datasets import load_dataset
from sklearn.metrics import accuracy_score
import time
from datetime import timedelta

# =====================
# DEVICE CONFIGURATION
# =====================
# Use "mps" on macOS or "cuda" if available or "cpu" as fallback
device = torch.device("mps" if torch.backends.mps.is_available() else "cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cuda


In [133]:
# =================================
# LOAD AND PREPROCESS IMDB DATASET
# =================================
# Load the IMDb dataset (50.000 labeled reviews - 25.000 train / 25.000 test)
subset_size = 25000
dataset = load_dataset("imdb")
train_data = dataset["train"].shuffle(seed=42).select(range(subset_size))  # use subset_size random samples fro training
test_data = dataset["test"].shuffle(seed=42).select(range(subset_size))    # use subset_size random samples for testing

# Load BERT tokenizer (bert-base-uncased)
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

# Tokenization function (max length 128)
def tokenize(batch):
    return tokenizer(batch['text'], truncation=True, padding='max_length', max_length=128)

# Apply tokenizer to training data and testing data
train_data = train_data.map(tokenize, batched=True)
test_data = test_data.map(tokenize, batched=True)

# Convert to PyTorch tensors
train_data.set_format(type='torch', columns=['input_ids', 'attention_mask', 'label'])
test_data.set_format(type='torch', columns=['input_ids', 'attention_mask', 'label'])

# Create DataLoaders
train_loader = DataLoader(train_data, batch_size=8)
test_loader = DataLoader(test_data, batch_size=8)

print("\nLoading and preprocessing Data complete!")


Loading and preprocessing Data complete!


In [134]:
ds = load_dataset("imdb")
ds

DatasetDict({
    train: Dataset({
        features: ['text', 'label'],
        num_rows: 25000
    })
    test: Dataset({
        features: ['text', 'label'],
        num_rows: 25000
    })
    unsupervised: Dataset({
        features: ['text', 'label'],
        num_rows: 50000
    })
})

In [135]:
ds["train"][3000]

{'text': "This is a bad movie. Not one of the funny bad ones either. This is a lousy bad one. It was actually painful to watch. The direction was awful,with lots of jumping around and the green and yellow hues used throughout the movie makes the characters look sickly. Keira Knightly was not convincing as a tough chick at all,and I cannot believe Lucy Liu and Mickey Rourke signed on for this criminal waste of celluloid. The script was terrible and the acting was like fingernails across a chalkboard. If you haven't seen it,don't. You are not missing anything and will only waste two hours of your life watching this drivel .I have seen bad movies before and even enjoyed them due to their faults. This one is just a waste of time.",
 'label': 0}

In [136]:
# ===========================
# DEFINE TEACHER MODEL (BERT)
# ===========================
# Load pretrained BERT model for sequence classification to device
teacher = BertForSequenceClassification.from_pretrained("bert-base-uncased").to(device)

print("\nTeacher model definition complete!") 

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.



Teacher model definition complete!


In [137]:
# ===========================
# TRAINING LOOP (BERT)
# ===========================

# Start training timer
start_time = time.time()

epochs_count = 10
print(f"\nTraining started with {epochs_count} epochs ...")

# Define optimizer
optimizer = torch.optim.AdamW(teacher.parameters(), lr=2e-5)

for epoch in range(epochs_count):  # train epochs_count epochs
    teacher.train()
    total_loss = 0.0

    # Start epoch timer
    epoch_start_time = time.time()

    for batch in train_loader:
        # Move input data to device
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['label'].to(device)

        # Get teacher predictions
        teacher_logits = teacher(input_ids=input_ids, attention_mask=attention_mask).logits

        # Compute loss with true labels
        loss = nn.CrossEntropyLoss()(teacher_logits, labels)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    # Stop the epoch timer
    epoch_end_time = time.time()

    epoch_time = timedelta(seconds=(epoch_end_time - epoch_start_time))
    print(f"[Epoch {epoch+1}] Training Loss: {total_loss:.4f}  Epoch Time: {epoch_time}")

# Stop training timer
end_time = time.time()
training_time = timedelta(seconds=(end_time - start_time))

print(f"\nTraining complete in {training_time}!")




Training started with 10 epochs ...
[Epoch 1] Training Loss: 1036.8148  Epoch Time: 0:02:59.928827
[Epoch 2] Training Loss: 555.1100  Epoch Time: 0:03:00.004059
[Epoch 3] Training Loss: 284.8935  Epoch Time: 0:03:00.006631
[Epoch 4] Training Loss: 183.4840  Epoch Time: 0:02:59.969044
[Epoch 5] Training Loss: 132.1971  Epoch Time: 0:02:59.998647
[Epoch 6] Training Loss: 107.6060  Epoch Time: 0:02:59.962089
[Epoch 7] Training Loss: 99.4189  Epoch Time: 0:02:59.976020
[Epoch 8] Training Loss: 84.9268  Epoch Time: 0:02:59.999456
[Epoch 9] Training Loss: 70.9583  Epoch Time: 0:02:59.988552
[Epoch 10] Training Loss: 67.9117  Epoch Time: 0:02:59.996357

Training complete in 0:29:59.841565!


In [148]:
# =======================
# EVALUATE TEACHER MODEL 
# =======================

# Start evaluation timer
start_time = time.time()

print ("\nEvaluation started ...")

teacher.eval()
teacher_preds = []

with torch.no_grad():
    for batch in test_loader:
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['label'].to(device)
        outputs = teacher(input_ids=input_ids, attention_mask=attention_mask)
        preds = torch.argmax(outputs.logits, dim=1)

        teacher_preds.extend(preds.cpu().numpy())

teacher_accuracy = accuracy_score(true_labels, teacher_preds)
print(f"BERT Teacher Accuracy: {teacher_accuracy:.4f}")

# Stop evaluation timer
end_time = time.time()
evaluation_time = timedelta(seconds=(end_time - start_time))

print(f"\nEvaluation complete in {evaluation_time}!")


Evaluation started ...
BERT Teacher Accuracy: 0.8757

Evaluation complete in 0:00:47.097315!


In [222]:
# ==================================
# DEFINE STUDENT MODEL (DistilBERT)
# ==================================
# Create student model with BERT's hidden size and intermediate layer size, but only 6 layers
student_config = DistilBertConfig(
    num_labels=2,
    n_layers=6,           # DistilBERT with half the layers of BERT
    dim=768,              # hidden size of BERT
    hidden_dim=3072,      # intermediate (feed-forward) layer size of BERT
    dropout=0.1,
    attention_dropout=0.1
)
student = DistilBertForSequenceClassification(student_config).to(device)

print("\nStudent model definition complete!")


Student model definition complete!


In [223]:
# ===========================
# DISTILLATION LOSS FUNCTION
# ===========================
def distillation_loss(student_logits, teacher_logits, labels, T=2.0, alpha=0.5):
    """
    student_logits: raw output from student model
    teacher_logits: raw output from teacher model
    labels: ground truth labels (0 ... negative, 1 ... positve)
    T: temperature for softening the distribution
    alpha: tradeoff between soft and hard loss
    """
    # soft target loss: student follows the teacher's behaviour
    # Kullback-Leibler divergence between softened distributions
    # (shows how much the student deviates from the teacher)
    soft_loss = nn.KLDivLoss(reduction='batchmean')(
        nn.functional.log_softmax(student_logits / T, dim=1),
        nn.functional.softmax(teacher_logits / T, dim=1)
    ) * (T * T)

    # hard target loss: student predicts the true label
    # standard classification loss with the true labels
    hard_loss = nn.CrossEntropyLoss()(student_logits, labels)

    # combining soft target loss and hard target loss
    # alpha: how much loss from teacher (soft) and how much from true labels (hard)
    return alpha * soft_loss + (1 - alpha) * hard_loss



In [224]:
# ===========================
# TRAINING LOOP (DistilBERT)
# ===========================

# Start training timer
start_time = time.time()

epochs_count = 10
T = 4.0
alpha = 0.0    # no teacher, only true labels
#alpha = 0.25
#alpha = 0.5    # half teacher, half true labels
#alpha = 0.75
#alpha = 1.0    # only teacher, no true labels

print(f"\nTraining with {epochs_count} epochs (alpha = {alpha} ; T = {T})")

teacher.eval()  # "freeze" teacher weights

# Define optimizer
optimizer = torch.optim.AdamW(student.parameters(), lr=2e-5)

for epoch in range(epochs_count):  # train epochs_count epochs
    student.train()
    total_loss = 0.0

    # Start epoch timer
    epoch_start_time = time.time()

    for batch in train_loader:
        # Move input data to device
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['label'].to(device)

        # Get teacher predictions without computing gradients
        with torch.no_grad():
            teacher_logits = teacher(input_ids=input_ids, attention_mask=attention_mask).logits

        # Get student predictions
        student_logits = student(input_ids=input_ids, attention_mask=attention_mask).logits

        # Compute distillation loss
        loss = distillation_loss(student_logits, teacher_logits, labels, T, alpha)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    # Stop the epoch timer
    epoch_end_time = time.time()

    epoch_time = timedelta(seconds=(epoch_end_time - epoch_start_time))
    print(f"[Epoch {epoch+1}] Training Loss: {total_loss:.4f}  Epoch Time: {epoch_time}")

# Stop training timer
end_time = time.time()
training_time = timedelta(seconds=(end_time - start_time))

print(f"\nTraining complete in {training_time}!")


Training with 10 epochs (alpha = 0.0 ; T = 4.0)
[Epoch 1] Training Loss: 1534.3619  Epoch Time: 0:02:25.156244
[Epoch 2] Training Loss: 989.2915  Epoch Time: 0:02:25.381793
[Epoch 3] Training Loss: 778.0316  Epoch Time: 0:02:25.348408
[Epoch 4] Training Loss: 625.4322  Epoch Time: 0:02:25.325389
[Epoch 5] Training Loss: 490.6610  Epoch Time: 0:02:25.381561
[Epoch 6] Training Loss: 382.2016  Epoch Time: 0:02:25.318443
[Epoch 7] Training Loss: 280.5184  Epoch Time: 0:02:25.363282
[Epoch 8] Training Loss: 224.5050  Epoch Time: 0:02:25.340927
[Epoch 9] Training Loss: 166.3939  Epoch Time: 0:02:25.326895
[Epoch 10] Training Loss: 135.3951  Epoch Time: 0:02:25.316045

Training complete in 0:24:13.268500!


In [225]:
# =======================
# EVALUATE STUDENT MODEL
# =======================

# Start evaluation timer
start_time = time.time()

print ("\nEvaluation started ...")

student.eval()
student_preds, true_labels = [], []

with torch.no_grad():
    for batch in test_loader:
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['label'].to(device)

        outputs = student(input_ids=input_ids, attention_mask=attention_mask)
        preds = torch.argmax(outputs.logits, dim=1)

        student_preds.extend(preds.cpu().numpy())
        true_labels.extend(labels.cpu().numpy())

student_accuracy = accuracy_score(true_labels, student_preds)
print(f"\nDistilBERT Student Accuracy: {student_accuracy:.4f}")
print(f"BERT Teacher Accuracy: {teacher_accuracy:.4f}")

# Stop evaluation timer
end_time = time.time()
evaluation_time = timedelta(seconds=(end_time - start_time))

print(f"\nEvaluation complete in {evaluation_time}!")


Evaluation started ...

DistilBERT Student Accuracy: 0.7603
BERT Teacher Accuracy: 0.8757

Evaluation complete in 0:00:25.506938!
