<a href="https://colab.research.google.com/github/NataliaVrabcova/Assessment_2_Mini_Project/blob/main/MSO3255_Assessment2_BaseCode.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
pip install datasets

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset, Dataset
import numpy as np
import string, re

from datasets import load_dataset
from transformers import (
    AutoTokenizer,
    AutoModelForQuestionAnswering,
    BertConfig,
    BertForQuestionAnswering
)


In [None]:
pip install evaluate
import evaluate


In [None]:
###############################################################################
# CONFIG
###############################################################################
TEACHER_MODEL_NAME = "bert-large-uncased-whole-word-masking-finetuned-squad"
MAX_LENGTH = 384
BATCH_SIZE = 8
EPOCHS = 2  # Increase for better results
LEARNING_RATE = 3e-5

# Distillation hyperparams
ALPHA = 0.5        # Weight of soft (teacher) loss
TEMPERATURE = 3.0  # Temperature for soft logits

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print("Using device:", DEVICE)


In [6]:

###############################################################################
# 1) DATA LOADING & PREPROCESSING
###############################################################################
raw_squad = load_dataset("squad")

# For demonstration, we'll use the full training set or a subset
train_data = raw_squad["train"]
val_data = raw_squad["validation"].select(range(1000))  # Subset of 1k for speed

tokenizer = AutoTokenizer.from_pretrained(TEACHER_MODEL_NAME, use_fast=True)

def preprocess_function(ex):
    """
    Tokenize question + context and try to map answer start/end to token indices.
    We'll do a naive single-chunk approach (no sliding window).
    """
    # SQuAD "answers" has a list of possible answers; we take the first
    start_char = ex["answers"]["answer_start"][0]
    ans_texts = ex["answers"]["text"]
    answer_text = ans_texts[0] if len(ans_texts) > 0 else ""

    encoding = tokenizer(
        ex["question"],
        ex["context"],
        max_length=MAX_LENGTH,
        padding="max_length",
        truncation=True,
        return_offsets_mapping=True  # We'll use offsets for naive char->token mapping
    )

    offsets = encoding["offset_mapping"]
    input_ids = encoding["input_ids"]

    # Find start/end token indices
    start_token_idx = 0
    end_token_idx = 0

    # End char
    end_char = start_char + len(answer_text)

    # loop through offsets to find the best match
    for i, (off_start, off_end) in enumerate(offsets):
        # Some offsets may be None or special tokens
        if off_start is None or off_end is None:
            continue
        if off_start <= start_char < off_end:
            start_token_idx = i
        if off_start < end_char <= off_end:
            end_token_idx = i
            break

    if end_token_idx < start_token_idx:
        end_token_idx = start_token_idx

    # Store in encoding
    encoding["start_positions"] = start_token_idx
    encoding["end_positions"] = end_token_idx

    # Remove offset mapping to reduce data size
    encoding.pop("offset_mapping")

    return encoding

train_processed = train_data.map(preprocess_function)
val_processed   = val_data.map(preprocess_function)

# We'll convert to PyTorch Tensors
def to_tensor_dataset(hf_dataset):
    input_ids = torch.tensor(hf_dataset["input_ids"], dtype=torch.long)
    attention_mask = torch.tensor(hf_dataset["attention_mask"], dtype=torch.long)
    token_type_ids = torch.tensor(hf_dataset["token_type_ids"], dtype=torch.long) \
        if "token_type_ids" in hf_dataset.features else None
    start_positions = torch.tensor(hf_dataset["start_positions"], dtype=torch.long)
    end_positions   = torch.tensor(hf_dataset["end_positions"], dtype=torch.long)

    if token_type_ids is not None:
        return TensorDataset(input_ids, attention_mask, token_type_ids, start_positions, end_positions)
    else:
        # For models without token_type_ids (like DistilBERT)
        return TensorDataset(input_ids, attention_mask, start_positions, end_positions)

train_tds = to_tensor_dataset(train_processed)
val_tds   = to_tensor_dataset(val_processed)

train_loader = DataLoader(train_tds, batch_size=BATCH_SIZE, shuffle=False)
val_loader   = DataLoader(val_tds,   batch_size=BATCH_SIZE, shuffle=False)


Map:   0%|          | 0/1000 [00:00<?, ? examples/s]

In [5]:
###############################################################################
# 2) TEACHER MODEL INFERENCE (collect teacher logits)
###############################################################################
teacher_model = AutoModelForQuestionAnswering.from_pretrained(TEACHER_MODEL_NAME).to(DEVICE)
teacher_model.eval()

teacher_start_logits_list = []
teacher_end_logits_list   = []
gt_start_list = []
gt_end_list   = []
input_tensors = []
count=0
with torch.no_grad():
    for batch in train_loader:
      count+=1
      print("batch", count)
      # batch can have 4 or 5 tensors depending on token_type_ids existence
      if len(batch) == 5:
          input_ids, attention_mask, token_type_ids, start_pos, end_pos = batch
          token_type_ids = token_type_ids.to(DEVICE)
      else:
          input_ids, attention_mask, start_pos, end_pos = batch
          token_type_ids = None

      input_ids = input_ids.to(DEVICE)
      attention_mask = attention_mask.to(DEVICE)
      start_pos = start_pos.to(DEVICE)
      end_pos   = end_pos.to(DEVICE)

      outputs = teacher_model(
          input_ids=input_ids,
          attention_mask=attention_mask,
          token_type_ids=token_type_ids
      )
      # Collect logits
      teacher_start_logits_list.append(outputs.start_logits.cpu())
      teacher_end_logits_list.append(outputs.end_logits.cpu())

      gt_start_list.append(start_pos.cpu())
      gt_end_list.append(end_pos.cpu())

      # We'll store the CPU tensors of inputs for the student dataset
      if token_type_ids is not None:
          input_tensors.append((input_ids.cpu(), attention_mask.cpu(), token_type_ids.cpu()))
      else:
          # Rare for BERT-based QA to not have token_type_ids, but just in case:
          # We'll store None in place of token_type_ids.
          input_tensors.append((input_ids.cpu(), attention_mask.cpu(), None))

# Concatenate teacher outputs
teacher_start_logits_full = torch.cat(teacher_start_logits_list, dim=0)
teacher_end_logits_full   = torch.cat(teacher_end_logits_list, dim=0)
gt_start_full = torch.cat(gt_start_list, dim=0)
gt_end_full   = torch.cat(gt_end_list, dim=0)

# Flatten input tensors
all_input_ids         = []
all_attention_masks   = []
all_token_type_ids    = []
for batch_data in input_tensors:
    i_ids, i_mask, i_type = batch_data
    all_input_ids.append(i_ids)
    all_attention_masks.append(i_mask)
    if i_type is not None:
        all_token_type_ids.append(i_type)
    else:
        # If there's no token_type_ids, we store None
        all_token_type_ids = None

all_input_ids       = torch.cat(all_input_ids, dim=0)
all_attention_masks = torch.cat(all_attention_masks, dim=0)
if all_token_type_ids is not None and len(all_token_type_ids) > 0:
    all_token_type_ids = torch.cat(all_token_type_ids, dim=0)

print("Teacher logits shapes:")
print("  start_logits:", teacher_start_logits_full.shape)
print("  end_logits:  ", teacher_end_logits_full.shape)
print("Ground truth shapes:", gt_start_full.shape, gt_end_full.shape)


model.safetensors:   0%|          | 0.00/1.34G [00:00<?, ?B/s]

Some weights of the model checkpoint at bert-large-uncased-whole-word-masking-finetuned-squad were not used when initializing BertForQuestionAnswering: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight']
- This IS expected if you are initializing BertForQuestionAnswering from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForQuestionAnswering from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


batch 1
batch 2
batch 3
batch 4
batch 5
batch 6
batch 7
batch 8
batch 9
batch 10
batch 11
batch 12
batch 13
batch 14
batch 15
batch 16
batch 17
batch 18
batch 19
batch 20
batch 21
batch 22
batch 23
batch 24
batch 25
batch 26
batch 27
batch 28
batch 29
batch 30
batch 31
batch 32
batch 33
batch 34
batch 35
batch 36
batch 37
batch 38
batch 39
batch 40
batch 41
batch 42
batch 43
batch 44
batch 45
batch 46
batch 47
batch 48
batch 49
batch 50
batch 51
batch 52
batch 53
batch 54
batch 55
batch 56
batch 57
batch 58
batch 59
batch 60
batch 61
batch 62
batch 63
batch 64
batch 65
batch 66
batch 67
batch 68
batch 69
batch 70
batch 71
batch 72
batch 73
batch 74
batch 75
batch 76
batch 77
batch 78
batch 79
batch 80
batch 81
batch 82
batch 83
batch 84
batch 85
batch 86
batch 87
batch 88
batch 89
batch 90
batch 91
batch 92
batch 93
batch 94
batch 95
batch 96
batch 97
batch 98
batch 99
batch 100
batch 101
batch 102
batch 103
batch 104
batch 105
batch 106
batch 107
batch 108
batch 109
batch 110
batch 11

KeyboardInterrupt: 

In [7]:

###############################################################################
# 3) DEFINE A SMALLER STUDENT MODEL
###############################################################################
student_config = BertConfig(
    vocab_size=tokenizer.vocab_size,
    hidden_size=384,         # smaller hidden dim
    num_hidden_layers=8,     # fewer layers
    num_attention_heads=8,   # fewer heads
    intermediate_size=384 * 4,
    max_position_embeddings=MAX_LENGTH
)
student_model = BertForQuestionAnswering(student_config).to(DEVICE)


In [8]:
###############################################################################
# 4) DISTILLATION TRAINING
###############################################################################
# We'll build a custom dataset that yields inputs & teacher (soft) + ground-truth (hard)
# We'll store teacher's probabilities (soft targets) with a temperature-based softmax.

def softmax_with_temperature(logits, temperature=TEMPERATURE):
    # logits: [batch_size, seq_len]
    return F.softmax(logits / temperature, dim=-1)

teacher_start_probs = softmax_with_temperature(teacher_start_logits_full, TEMPERATURE)
teacher_end_probs   = softmax_with_temperature(teacher_end_logits_full,   TEMPERATURE)

class DistillationDataset(Dataset):
    def __init__(self, input_ids, attention_mask, token_type_ids,
                 teacher_start_probs, teacher_end_probs,
                 gt_start, gt_end):
        self.input_ids = input_ids
        self.attention_mask = attention_mask
        self.token_type_ids = token_type_ids
        self.teacher_start_probs = teacher_start_probs
        self.teacher_end_probs   = teacher_end_probs
        self.gt_start = gt_start
        self.gt_end   = gt_end

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, idx):
        if self.token_type_ids is not None:
            return (
                self.input_ids[idx],
                self.attention_mask[idx],
                self.token_type_ids[idx],
                self.teacher_start_probs[idx],
                self.teacher_end_probs[idx],
                self.gt_start[idx],
                self.gt_end[idx],
            )
        else:
            # If no token_type_ids exist
            return (
                self.input_ids[idx],
                self.attention_mask[idx],
                None,
                self.teacher_start_probs[idx],
                self.teacher_end_probs[idx],
                self.gt_start[idx],
                self.gt_end[idx],
            )

distill_dataset = DistillationDataset(
    all_input_ids, all_attention_masks, all_token_type_ids,
    teacher_start_probs, teacher_end_probs,
    gt_start_full, gt_end_full
)
distill_loader = DataLoader(distill_dataset, batch_size=BATCH_SIZE, shuffle=True)

optimizer = torch.optim.AdamW(student_model.parameters(), lr=LEARNING_RATE)

def distillation_train_step(batch_data):
    if all_token_type_ids is not None:
        input_ids, attention_mask, token_type_ids, t_start_probs, t_end_probs, gt_start, gt_end = batch_data
        token_type_ids = token_type_ids.to(DEVICE)
    else:
        input_ids, attention_mask, _, t_start_probs, t_end_probs, gt_start, gt_end = batch_data
        token_type_ids = None

    input_ids = input_ids.to(DEVICE)
    attention_mask = attention_mask.to(DEVICE)
    t_start_probs = t_start_probs.to(DEVICE)
    t_end_probs   = t_end_probs.to(DEVICE)
    gt_start = gt_start.to(DEVICE)
    gt_end   = gt_end.to(DEVICE)

    optimizer.zero_grad()
    outputs = student_model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        token_type_ids=token_type_ids
    )
    student_start_logits = outputs.start_logits
    student_end_logits   = outputs.end_logits

    # 1) Hard loss (CE with ground truth)
    ce_loss_fn = nn.CrossEntropyLoss()
    loss_start_hard = ce_loss_fn(student_start_logits, gt_start)
    loss_end_hard   = ce_loss_fn(student_end_logits,   gt_end)
    hard_loss = 0.5 * (loss_start_hard + loss_end_hard)

    # 2) Soft loss (KL divergence with teacher distribution)
    # Convert student logits to probabilities with same temperature
    s_start_probs = softmax_with_temperature(student_start_logits, TEMPERATURE)
    s_end_probs   = softmax_with_temperature(student_end_logits,   TEMPERATURE)

    # We'll use KLDivLoss (expects log probs for input by default)
    kl_loss_fn = nn.KLDivLoss(reduction="batchmean")
    start_kl = kl_loss_fn(s_start_probs.log(), t_start_probs)  # KL(student||teacher)
    end_kl   = kl_loss_fn(s_end_probs.log(),   t_end_probs)
    soft_loss = 0.5 * (start_kl + end_kl)

    loss = ALPHA * soft_loss + (1 - ALPHA) * hard_loss
    loss.backward()
    optimizer.step()

    return loss.item(), hard_loss.item(), soft_loss.item()

# Run distillation training
for epoch in range(EPOCHS):
    print(f"\n=== EPOCH {epoch+1}/{EPOCHS} ===")
    epoch_losses = []
    for step, batch_data in enumerate(distill_loader):
        loss_val, hard_val, soft_val = distillation_train_step(batch_data)
        epoch_losses.append(loss_val)
        if step % 200 == 0:
            print(f" Step {step} - Distill Loss: {loss_val:.4f} (Hard: {hard_val:.4f}, Soft: {soft_val:.4f})")
    print(f" Average Loss: {np.mean(epoch_losses):.4f}")


NameError: name 'teacher_start_logits_full' is not defined

In [None]:
###############################################################################
# Official SQuAD Evaluation Evaluation
###############################################################################

## Initialize the evaluation metric for SQuAD
metric = evaluate.load("squad")
## Metric Computation Function
# Function to compute the SQuAD metrics
def compute_metrics(p):
    predictions, labels = p
    # Extract the predicted start and end token indices
    pred_start, pred_end = predictions
    # Extract the label's start and end token indices
    label_start, label_end = labels

    # Decode the predicted token indices back into text answers
    pred_answers = tokenizer.decode(pred_start)
    label_answers = tokenizer.decode(label_start)

    # Compute the EM and F1 scores using the SQuAD metric
    return metric.compute(predictions=pred_answers, references=label_answers)

## Model Evaluation on Validation Set

# Evaluate during the validation phase
for batch in val_loader:
    # Process the batch and get the model predictions (start_logits, end_logits)
    outputs = model(input_ids=batch[0], attention_mask=batch[1])  # Adjust if needed for your model
    start_logits, end_logits = outputs.start_logits, outputs.end_logits

    # Use the compute_metrics function to calculate Exact Match (EM) and F1 scores
    metrics = compute_metrics((start_logits, end_logits))
    # Output the results
    print(f"Exact Match (EM): {metrics['exact_match']}, F1 Score: {metrics['f1']}")



1. Evaluation Metric Setup

In this step, we initialize the evaluation metric that will be used to assess the performance of the model on the SQuAD dataset. We use the evaluate library, which provides a standard implementation for the SQuAD metric.
2. Metric Computation Function

Next, we define a helper function compute_metrics, which will be responsible for computing the EM and F1 scores during the validation phase. This function takes in a tuple containing the predicted and true token indices.
3. Model Evaluation on Validation Set

After training the student model, we evaluate its performance during the validation phase. The model outputs logits for the start and end positions of the answer span for each input question. We use these logits to calculate the EM and F1 scores.

This evaluation pipeline provides a clear way to track how well the student model is performing on the SQuAD dataset, particularly in terms of answer span prediction. The Exact Match score gives us an understanding of how often the model produces exact, correct answers, while the F1 score provides insight into the quality of the predicted answers by measuring both precision and recall.

In [None]:


###############################################################################
# 5) POST-PROCESSING & EVALUATION
###############################################################################
# We'll do a naive approach: take argmax of start/end, decode, compare with ground truth.
val_contexts = val_data["context"]
val_questions = val_data["question"]
val_answers = val_data["answers"]  # list of dicts with "text", "answer_start"

def normalize_text(s):
    """Lower text and remove punctuation, articles and extra whitespace."""
    def remove_articles(txt):
        return re.sub(r"\b(a|an|the)\b", " ", txt)
    def remove_punc(txt):
        return "".join(ch for ch in txt if ch not in string.punctuation)

    s = s.lower()
    s = remove_articles(s)
    s = remove_punc(s)
    s = " ".join(s.split())
    return s

def compute_exact_match(pred, truth):
    return int(normalize_text(pred) == normalize_text(truth))

def compute_f1(pred, truth):
    pred_tokens = normalize_text(pred).split()
    truth_tokens = normalize_text(truth).split()
    common = set(pred_tokens) & set(truth_tokens)
    if len(pred_tokens) == 0 or len(truth_tokens) == 0:
        return int(pred_tokens == truth_tokens)
    if len(common) == 0:
        return 0
    precision = len(common) / len(pred_tokens)
    recall = len(common) / len(truth_tokens)
    if precision + recall == 0:
        return 0
    return 2 * (precision * recall) / (precision + recall)

student_model.eval()
em_scores = []
f1_scores = []

val_dataloader = DataLoader(val_tds, batch_size=8, shuffle=False)
offset = 0  # to track the global index in val_data
with torch.no_grad():
    for batch in val_dataloader:
        if len(batch) == 5:
            input_ids, attention_mask, token_type_ids, start_pos, end_pos = batch
            token_type_ids = token_type_ids.to(DEVICE)
        else:
            input_ids, attention_mask, start_pos, end_pos = batch
            token_type_ids = None

        input_ids = input_ids.to(DEVICE)
        attention_mask = attention_mask.to(DEVICE)

        outputs = student_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids
        )
        start_logits = outputs.start_logits
        end_logits   = outputs.end_logits

        start_indices = torch.argmax(start_logits, dim=1).cpu().numpy()
        end_indices   = torch.argmax(end_logits,   dim=1).cpu().numpy()

        for i in range(len(start_indices)):
            global_idx = offset + i
            if global_idx >= len(val_contexts):
                continue

            s_ind = start_indices[i]
            e_ind = end_indices[i]
            if e_ind < s_ind:
                e_ind = s_ind

            # Decode predicted tokens
            tokens_ = input_ids[i][s_ind : e_ind+1].cpu().numpy().tolist()
            pred_text = tokenizer.decode(tokens_, skip_special_tokens=True)

            # Ground truth: we pick the first answer
            gold_answers = val_answers[global_idx]["text"]
            if len(gold_answers) > 0:
                gold_answer = gold_answers[0]
            else:
                gold_answer = ""

            em = compute_exact_match(pred_text, gold_answer)
            f1 = compute_f1(pred_text, gold_answer)

            em_scores.append(em)
            f1_scores.append(f1)

        offset += len(start_indices)

avg_em = np.mean(em_scores) * 100
avg_f1 = np.mean(f1_scores) * 100
print(f"\nValidation Results (subset of 1000 samples):")
print(f"  Exact Match: {avg_em:.2f}%")
print(f"  F1 Score:    {avg_f1:.2f}%")



In [None]:
import torch

def ask_question(question: str, context: str, model, tokenizer, device="cpu"):
    """
    Given a question and a context, use the provided model to
    predict the answer span and return the decoded string answer.
    """
    # 1) Encode inputs
    inputs = tokenizer(question, context, return_tensors="pt", truncation=True)
    input_ids = inputs["input_ids"].to(device)
    attention_mask = inputs["attention_mask"].to(device)

    token_type_ids = None
    if "token_type_ids" in inputs:
        token_type_ids = inputs["token_type_ids"].to(device)

    # 2) Forward pass
    with torch.no_grad():
        outputs = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids
        )
    start_logits = outputs.start_logits
    end_logits = outputs.end_logits

    # 3) Get predicted start/end token indices
    start_index = torch.argmax(start_logits, dim=1).item()
    end_index = torch.argmax(end_logits, dim=1).item()

    # Ensure the end_index is >= start_index
    if end_index < start_index:
        end_index = start_index

    # 4) Decode tokens back to string
    answer_ids = input_ids[0, start_index : end_index+1]
    answer_text = tokenizer.decode(answer_ids, skip_special_tokens=True)

    return answer_text


# -----------------------------
# Example usage
# -----------------------------

# Suppose you have:
#   teacher_model, student_model (both on the same device, e.g., "cuda" or "cpu")
#   tokenizer (matching your BERT-based QA model)
# Example question + context:
#question = "What is the capital of France?"
#context = "France is a country in Europe. Its largest city and capital is Paris. It is known for the Eiffel Tower."
#question = "Which country is Middlesex University based?"
#question = "Is Middlesex University a public or an independent university?"
#context = "Middlesex University London is a public research university based in Hendon, northwest London, England."
question = "Which city is Galatasaray based in?"
context = "Galatasaray, is a Turkish professional football club based on the European side of the city of Istanbul. It is founded in 1905. The team traditionally play in dark shades of red and yellow at home."

# Evaluate with teacher model
teacher_model.eval()
teacher_answer = ask_question(question, context, teacher_model, tokenizer, device="cuda")
print(f"[Teacher Answer]: {teacher_answer}")

# Evaluate with student model
student_model.eval()
student_answer = ask_question(question, context, student_model, tokenizer, device="cuda")
print(f"[Student Answer]: {student_answer}")


[Teacher Answer]: istanbul
[Student Answer]: yellow
