<a href="https://colab.research.google.com/github/NataliaVrabcova/Assessment_2_Mini_Project/blob/main/MSO3255_Assessment2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# =========================================
# MSO3255: Distilled Question Answering Model
# Compact BERT with Intermediate Distillation and Pruning
# =========================================

# -----------------------------------------
# 0. SETUP AND DEPENDENCIES
# -----------------------------------------

# Install Hugging Face datasets library for loading SQuAD
pip install datasets

In [4]:
# Standard library (token normalization, paths, metric counting)
import os
import re
import string
from collections import Counter

# Core libraries for model, training, and data handling
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.utils.prune as prune  # imported explicitly; not guaranteed to be loaded by `import torch`
from torch.utils.data import DataLoader, TensorDataset
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForQuestionAnswering, BertConfig, BertForQuestionAnswering

# Utilities for visualization and model inspection
import matplotlib.pyplot as plt
import numpy as np
from torchinfo import summary

# Enable Google Drive access to store teacher logits
from google.colab import drive

ModuleNotFoundError: No module named 'datasets'

In [5]:
# -----------------------------------------
# 1. CONFIGURATION
# -----------------------------------------
drive.mount('/content/drive')  # Mount Google Drive (teacher logits are saved there)

# Model, training, and distillation parameters
TEACHER_MODEL_NAME = "bert-large-uncased-whole-word-masking-finetuned-squad"
MAX_LENGTH = 384
BATCH_SIZE = 8
EPOCHS = 2
LEARNING_RATE = 3e-5
ALPHA = 0.5            # Balance between soft and hard loss
TEMPERATURE = 4.0      # Softmax temperature applied to both teacher and student logits
SEED = 42              # Makes student initialisation and loader shuffling reproducible
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# The student is randomly initialised and the distillation loader shuffles,
# so seed every RNG they draw from. torch.manual_seed also seeds CUDA devices.
torch.manual_seed(SEED)
np.random.seed(SEED)

print("Using device:", DEVICE)

NameError: name 'drive' is not defined

In [6]:
# -----------------------------------------
# 2. DATA LOADING & PREPROCESSING
# -----------------------------------------

# Work on a fixed slice of SQuAD: the first 5,000 training examples and the
# first 1,000 validation examples, which keeps teacher inference affordable.
raw_squad = load_dataset("squad")
train_data, val_data = (
    raw_squad[split_name].select(range(n_rows))
    for split_name, n_rows in (("train", 5000), ("validation", 1000))
)

# The student reuses the teacher's (fast) tokenizer, so both models see identical token ids.
tokenizer = AutoTokenizer.from_pretrained(TEACHER_MODEL_NAME, use_fast=True)

# Preprocessing: tokenize input and align answer positions
def preprocess_function(ex):
    start_char = ex["answers"]["answer_start"][0]
    answer_text = ex["answers"]["text"][0] if len(ex["answers"]["text"]) > 0 else ""
    encoding = tokenizer(ex["question"], ex["context"], max_length=MAX_LENGTH, padding="max_length", truncation=True, return_offsets_mapping=True)
    offsets = encoding["offset_mapping"]
    end_char = start_char + len(answer_text)
    start_token_idx = end_token_idx = 0
    for i, (off_start, off_end) in enumerate(offsets):
        if off_start is None or off_end is None:
            continue
        if off_start <= start_char < off_end:
            start_token_idx = i
        if off_start < end_char <= off_end:
            end_token_idx = i
            break
    if end_token_idx < start_token_idx:
        end_token_idx = start_token_idx
    encoding["start_positions"] = start_token_idx
    encoding["end_positions"] = end_token_idx
    encoding.pop("offset_mapping")
    return encoding

# Tokenise both splits example-by-example and attach start/end token labels.
train_processed, val_processed = (split.map(preprocess_function) for split in (train_data, val_data))

# Convert Hugging Face dataset to PyTorch tensors
def to_tensor_dataset(hf_dataset):
    """Pack the model-input and label columns of `hf_dataset` into a TensorDataset.

    Row layout is (input_ids, attention_mask, [token_type_ids],
    start_positions, end_positions). The token_type_ids slot is present only
    when the dataset has that column, so rows have 5 or 4 tensors. Every
    column becomes a torch.long tensor.
    """
    def as_long(column_name):
        return torch.tensor(hf_dataset[column_name], dtype=torch.long)

    column_names = ["input_ids", "attention_mask"]
    if "token_type_ids" in hf_dataset.features:
        column_names.append("token_type_ids")
    column_names += ["start_positions", "end_positions"]
    return TensorDataset(*[as_long(name) for name in column_names])

# Wrap both processed splits as tensor datasets and batch them in dataset
# order. Evaluation looks up gold answers in val_data by running position,
# so val_loader must not shuffle.
train_tds, val_tds = map(to_tensor_dataset, (train_processed, val_processed))
train_loader, val_loader = (
    DataLoader(tensor_ds, batch_size=BATCH_SIZE, shuffle=False)
    for tensor_ds in (train_tds, val_tds)
)

NameError: name 'load_dataset' is not defined

In [None]:
# -----------------------------------------
# 3. TEACHER MODEL INFERENCE
# -----------------------------------------

# The teacher is needed here (soft labels) and again during distillation
# (intermediate hidden states), so it is always loaded.
teacher_model = AutoModelForQuestionAnswering.from_pretrained(TEACHER_MODEL_NAME).to(DEVICE).eval()

LOGITS_PATH = "/content/drive/MyDrive/teacher_logits.pt"

# Inputs and gold labels come straight from the training TensorDataset, whose
# row layout is (input_ids, attention_mask, [token_type_ids], start, end).
all_input_ids, all_attention_masks = train_tds.tensors[0], train_tds.tensors[1]
all_token_type_ids = train_tds.tensors[2] if len(train_tds.tensors) == 5 else None
gt_start_full, gt_end_full = train_tds.tensors[-2], train_tds.tensors[-1]


def run_teacher(dataset):
    """Return the teacher's (start_logits, end_logits) for every row of `dataset`.

    Results are CPU tensors in dataset order. A dedicated unshuffled loader
    guarantees row-for-row alignment with the input/label tensors above.
    """
    start_chunks, end_chunks = [], []
    loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=False)
    with torch.no_grad():
        for batch in loader:
            token_type_ids = batch[2].to(DEVICE) if len(batch) == 5 else None
            outputs = teacher_model(
                input_ids=batch[0].to(DEVICE),
                attention_mask=batch[1].to(DEVICE),
                token_type_ids=token_type_ids,
            )
            start_chunks.append(outputs.start_logits.cpu())
            end_chunks.append(outputs.end_logits.cpu())
    return torch.cat(start_chunks, dim=0), torch.cat(end_chunks, dim=0)


# Teacher inference is the most expensive step. Reuse logits saved by an
# earlier run, but only if they were computed on exactly these input ids.
cached = torch.load(LOGITS_PATH, map_location="cpu") if os.path.exists(LOGITS_PATH) else None
if cached is not None and torch.equal(cached["input_ids"], all_input_ids):
    teacher_start_logits_full = cached["start_logits"]
    teacher_end_logits_full = cached["end_logits"]
    print("Loaded cached teacher logits from", LOGITS_PATH)
else:
    teacher_start_logits_full, teacher_end_logits_full = run_teacher(train_tds)
    # Save teacher logits and metadata to disk (Google Drive)
    torch.save({
        "start_logits": teacher_start_logits_full,
        "end_logits": teacher_end_logits_full,
        "gt_start": gt_start_full,
        "gt_end": gt_end_full,
        "input_ids": all_input_ids,
        "attention_mask": all_attention_masks,
        "token_type_ids": all_token_type_ids
    }, LOGITS_PATH)

In [7]:
# -----------------------------------------
# 4. BUILD STUDENT MODEL
# -----------------------------------------

# A compact, randomly initialised BERT for the student: 8 layers of width 384,
# 8 attention heads (48 dims each), FFN width 4x hidden. It shares the
# teacher's tokenizer vocabulary. Positions are capped at MAX_LENGTH, the
# padded sequence length used throughout.
student_config = BertConfig(
    vocab_size=tokenizer.vocab_size,
    hidden_size=384,
    num_hidden_layers=8,
    num_attention_heads=8,
    intermediate_size=384 * 4,
    max_position_embeddings=MAX_LENGTH
)
student_model = BertForQuestionAnswering(student_config).to(DEVICE)

# Show summary of the student model architecture.
# torchinfo builds float32 dummy inputs by default, which the embedding layer
# rejects; the dummy input must be integer token ids on the model's device.
summary(student_model, input_size=(BATCH_SIZE, MAX_LENGTH), dtypes=[torch.long], device=DEVICE)

NameError: name 'BertConfig' is not defined

In [None]:
# -----------------------------------------
# 5. DISTILLATION TRAINING + PRUNING
# -----------------------------------------

# Loss terms that are mixed together during distillation
kl_loss_fn = nn.KLDivLoss(reduction="batchmean")  # soft targets: student vs. teacher span distributions
ce_loss_fn = nn.CrossEntropyLoss()                  # hard targets: gold start/end token indices
mse_loss_fn = nn.MSELoss()                           # intermediate hidden-state matching

# Maps teacher hidden states (width 1024) down to the student's width (384)
# so the two can be compared with MSE.
# NOTE(review): this layer is not handed to the optimizer, so it acts as a
# fixed random projection. Confirm that is intended.
project_teacher = nn.Linear(in_features=1024, out_features=384).to(DEVICE)

# Apply pruning (zeroes the smallest-magnitude fraction of each Linear weight)
def apply_pruning(model, amount=0.1, permanent=True):
    """L1-magnitude unstructured pruning of every nn.Linear weight in `model`, in place.

    Args:
        model: module whose nn.Linear submodules are pruned.
        amount: fraction (0-1) of each Linear weight's entries to zero,
            smallest |w| first. An int is treated as an absolute count, as
            accepted by `prune.l1_unstructured`.
        permanent: if True (default, the original behaviour), bake the zeros
            into `weight` with `prune.remove`. They are then ordinary values
            that later optimizer steps are free to change. If False, keep the
            pruning reparametrization (`weight_orig` * `weight_mask`) attached,
            so pruned entries stay zero during subsequent training.
    """
    for module in model.modules():
        if isinstance(module, nn.Linear):
            prune.l1_unstructured(module, name="weight", amount=amount)
            if permanent:
                prune.remove(module, "weight")

apply_pruning(student_model, amount=0.1)

# Pruning zeroes entries but does not remove them, so the plain numel count is
# unchanged. Report the non-zero count to show the pruning actually took effect.
trainable_params = [p for p in student_model.parameters() if p.requires_grad]
print("Trainable parameters (incl. pruned zeros):", sum(p.numel() for p in trainable_params))
print("Non-zero trainable parameters after pruning:", sum(int(torch.count_nonzero(p)) for p in trainable_params))

# Define optimizer after pruning.
# NOTE: apply_pruning's default calls prune.remove, so the pruned entries are
# ordinary zeros that AdamW may move away from zero during training.
optimizer = torch.optim.AdamW(student_model.parameters(), lr=LEARNING_RATE)

# Temperature-softened probabilities used as soft targets
def softmax_with_temperature(logits, temperature):
    """Return softmax(logits / temperature) over the last dimension.

    A temperature above 1 flattens the distribution, exposing more of the
    teacher's relative preferences among non-argmax positions.
    """
    return F.softmax(logits / temperature, dim=-1)

# Soften the teacher's start/end logits once, up front; each training step then just reads them.
teacher_start_probs, teacher_end_probs = (
    softmax_with_temperature(span_logits, TEMPERATURE)
    for span_logits in (teacher_start_logits_full, teacher_end_logits_full)
)

# Assemble the distillation dataset. Every row has the same 7-tensor layout:
# when the tokenizer produced no token_type_ids, an all-zeros placeholder
# shaped like input_ids fills that slot.
distill_dataset = TensorDataset(
    all_input_ids,
    all_attention_masks,
    all_token_type_ids if all_token_type_ids is not None else torch.zeros_like(all_input_ids),
    teacher_start_probs,
    teacher_end_probs,
    gt_start_full,
    gt_end_full,
)

# Shuffled: inputs, soft targets and gold labels travel together per row, so order is free.
distill_loader = DataLoader(distill_dataset, batch_size=BATCH_SIZE, shuffle=True)

In [None]:
# =====================================
# DISTILLATION TRAINING LOOP
# =====================================

# Intermediate-layer matching. In `hidden_states`, index 0 is the embedding
# output and index k is the output of encoder layer k. The student has 8
# layers. Its layer 4 is paired with teacher layer 12, the proportional
# midpoint assuming a 24-layer bert-large teacher.
TEACHER_LAYER_TO_MATCH = 12
STUDENT_LAYER_TO_MATCH = 4
HIDDEN_LOSS_WEIGHT = 0.1   # weight of the hidden-state MSE term (starting value; tune)

assert TEACHER_LAYER_TO_MATCH <= teacher_model.config.num_hidden_layers
assert STUDENT_LAYER_TO_MATCH <= student_model.config.num_hidden_layers


def distillation_kl(student_logits, teacher_probs, temperature):
    """KL(teacher || student) between temperature-softened span distributions.

    Uses log_softmax rather than softmax(...).log(), so student
    log-probabilities cannot underflow to -inf. Scales by temperature**2 so
    the soft-target gradient magnitude does not shrink as the temperature
    grows (Hinton et al., 2015).
    """
    student_log_probs = F.log_softmax(student_logits / temperature, dim=-1)
    return kl_loss_fn(student_log_probs, teacher_probs) * temperature ** 2


student_model.train()
epoch_losses = []

for epoch in range(EPOCHS):
    print(f"\n[Epoch {epoch+1}/{EPOCHS}] Starting distillation training...")
    running_loss = []
    for step, batch in enumerate(distill_loader):
        # Every distill row has 7 tensors (see the dataset construction); move all to the device
        input_ids, attention_mask, token_type_ids, t_start_probs, t_end_probs, gt_start, gt_end = (
            tensor.to(DEVICE) for tensor in batch
        )
        # The token_type_ids slot is a zero placeholder when the tokenizer produced none
        if all_token_type_ids is None:
            token_type_ids = None

        optimizer.zero_grad()

        # Forward pass teacher (hidden states only; frozen)
        with torch.no_grad():
            teacher_out = teacher_model(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids, output_hidden_states=True)
            teacher_hidden = teacher_out.hidden_states[TEACHER_LAYER_TO_MATCH]
            teacher_proj = project_teacher(teacher_hidden)

        # Forward pass student
        student_out = student_model(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids, output_hidden_states=True)
        student_hidden = student_out.hidden_states[STUDENT_LAYER_TO_MATCH]
        s_start_logits = student_out.start_logits
        s_end_logits = student_out.end_logits

        # Calculate losses
        loss_hard = 0.5 * (ce_loss_fn(s_start_logits, gt_start) + ce_loss_fn(s_end_logits, gt_end))
        loss_soft = 0.5 * (
            distillation_kl(s_start_logits, t_start_probs, TEMPERATURE)
            + distillation_kl(s_end_logits, t_end_probs, TEMPERATURE)
        )
        loss_hidden = mse_loss_fn(student_hidden, teacher_proj)

        # Total distillation loss
        loss = ALPHA * loss_soft + (1 - ALPHA) * loss_hard + HIDDEN_LOSS_WEIGHT * loss_hidden
        loss.backward()
        optimizer.step()

        running_loss.append(loss.item())

        if step % 200 == 0:
            print(f"Step {step}: Loss = {loss.item():.4f}")

    epoch_avg = np.mean(running_loss)
    epoch_losses.append(epoch_avg)
    print(f"[Epoch {epoch+1}] Average Loss: {epoch_avg:.4f}")

In [None]:
# -----------------------------------------
# 6. PLOT TRAINING LOSS
# -----------------------------------------

# Number epochs from 1 to match the "[Epoch k/N]" training logs.
epoch_numbers = list(range(1, len(epoch_losses) + 1))

fig, ax = plt.subplots(figsize=(7, 4))
ax.plot(epoch_numbers, epoch_losses, marker='o')
ax.set_xticks(epoch_numbers)
ax.set_title("Distillation Loss per Epoch")
ax.set_xlabel("Epoch")
ax.set_ylabel("Loss")
ax.grid(True)
fig.tight_layout()
plt.show()

In [None]:
# -----------------------------------------
# 7. FINAL VALIDATION EVALUATION (EM + F1)
# -----------------------------------------

def normalize_text(text):
    """SQuAD answer normalisation: lower-case, drop punctuation, drop the articles a/an/the, collapse whitespace."""
    text = text.lower()
    text = "".join(ch for ch in text if ch not in string.punctuation)
    text = re.sub(r"\b(a|an|the)\b", " ", text)
    return " ".join(text.split())


def compute_f1(prediction, ground_truth):
    """Token-overlap F1 between a predicted and a gold answer, after normalize_text.

    If either side normalises to no tokens, the score is 1.0 only when both are empty.
    """
    pred_tokens = normalize_text(prediction).split()
    gold_tokens = normalize_text(ground_truth).split()
    if not pred_tokens or not gold_tokens:
        return float(pred_tokens == gold_tokens)
    num_same = sum((Counter(pred_tokens) & Counter(gold_tokens)).values())
    if num_same == 0:
        return 0.0
    precision = num_same / len(pred_tokens)
    recall = num_same / len(gold_tokens)
    return 2 * precision * recall / (precision + recall)


student_model.eval()
em_scores = []
f1_scores = []
# val_loader does not shuffle, so a running offset maps batch rows back to val_data rows.
val_answers = val_data["answers"]
offset = 0

with torch.no_grad():
    for batch in val_loader:
        input_ids = batch[0].to(DEVICE)
        attention_mask = batch[1].to(DEVICE)
        token_type_ids = batch[2].to(DEVICE) if len(batch) == 5 else None

        outputs = student_model(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)

        start_indices = outputs.start_logits.argmax(dim=1).tolist()
        end_indices = outputs.end_logits.argmax(dim=1).tolist()

        for i, (s_idx, e_idx) in enumerate(zip(start_indices, end_indices)):
            e_idx = max(e_idx, s_idx)  # an end before the start collapses to a one-token span
            pred_text = tokenizer.decode(input_ids[i, s_idx:e_idx + 1], skip_special_tokens=True)

            # SQuAD scores a prediction against every annotated answer and keeps the best match.
            gold_texts = list(val_answers[offset + i]["text"]) or [""]
            em_scores.append(max(int(normalize_text(pred_text) == normalize_text(gold)) for gold in gold_texts))
            f1_scores.append(max(compute_f1(pred_text, gold) for gold in gold_texts))

        offset += len(start_indices)

# Report evaluation metrics
print("\nValidation Results:")
print(f"Exact Match (EM): {np.mean(em_scores)*100:.2f}%")
print(f"F1 Score: {np.mean(f1_scores)*100:.2f}%")



In [None]:

# -----------------------------------------
# BAR PLOT: EM and F1 Comparison
# -----------------------------------------

# Scores (%) are hard-coded here rather than computed in this notebook.
# NOTE(review): presumably copied from earlier runs of each variant; confirm they match.
labels = ["Baseline", "Distilled+Tuned", "Pruned+Fine-tuned"]
em_vals = [2.8, 2.0, 1.3]
f1_vals = [6.9, 6.59, 5.3]

x = np.arange(len(labels))
width = 0.35

# One bar group per variant: EM on the left, F1 on the right.
fig, ax = plt.subplots()
rect1, rect2 = (
    ax.bar(x + shift, values, width, label=series_name, color=colour)
    for shift, values, series_name, colour in (
        (-width / 2, em_vals, "EM", "cornflowerblue"),
        (width / 2, f1_vals, "F1", "mediumseagreen"),
    )
)

ax.set_ylabel("Score (%)")
ax.set_title("QA Performance by Model Variant")
ax.set_xticks(x)
ax.set_xticklabels(labels)
ax.legend()

# Print each bar's value just above it.
for bar in [*rect1, *rect2]:
    bar_top = bar.get_height()
    ax.annotate(
        f'{bar_top:.2f}',
        xy=(bar.get_x() + bar.get_width() / 2, bar_top),
        xytext=(0, 3),
        textcoords="offset points",
        ha='center',
        va='bottom',
    )

plt.tight_layout()
plt.show()