In [1]:
from transformers import AutoModelForQuestionAnswering, AutoTokenizer, Trainer, TrainingArguments
from datasets import load_dataset
from pprint import pprint
from tqdm import tqdm
import evaluate
import torch

debugging = True

# Load the SQuAD dataset
squad = load_dataset("squad")


# Print formatted records
#pprint(train_dataset[:2])
#pprint(validation_dataset[:2])


In [2]:
# Load teacher model and tokenizer
teacher_model_name = "csarron/roberta-base-squad-v1"
teacher_model = AutoModelForQuestionAnswering.from_pretrained(teacher_model_name)
teacher_tokenizer = AutoTokenizer.from_pretrained(teacher_model_name)

# Load student model and tokenizer (smaller version of RoBERTa)
student_model_name = "distilroberta-base"  # Example smaller model
student_model = AutoModelForQuestionAnswering.from_pretrained(student_model_name)
student_tokenizer = AutoTokenizer.from_pretrained(student_model_name)

Some weights of the model checkpoint at csarron/roberta-base-squad-v1 were not used when initializing RobertaForQuestionAnswering: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
- This IS expected if you are initializing RobertaForQuestionAnswering from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaForQuestionAnswering from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of RobertaForQuestionAnswering were not initialized from the model checkpoint at distilroberta-base and are newly initialized: ['qa_outputs.bias', 'qa_outputs.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [3]:
# Define evaluation function
def evaluate_model(model, tokenizer, dataset):
    metric = evaluate.load("squad")
    predictions = []
    references = []

    for example in tqdm(dataset, desc="Evaluating"):
        # Tokenize inputs
        inputs = tokenizer(
            example["context"], example["question"], truncation=True, padding=True, return_tensors="pt"
        )
        
        # Get model outputs
        outputs = model(**inputs)
        start_logits, end_logits = outputs.start_logits, outputs.end_logits
        start_idx = torch.argmax(start_logits, dim=-1).item()
        end_idx = torch.argmax(end_logits, dim=-1).item()
        
        # Decode prediction
        prediction = tokenizer.decode(inputs["input_ids"][0][start_idx:end_idx + 1])
        
        # Append to predictions
        predictions.append({
            "id": example["id"],
            "prediction_text": prediction
        })

        # Append to references (ground truth)
        references.append({
            "id": example["id"],
            "answers": example["answers"]
        })

    # Debugging: Print a couple of predictions and references
    #pprint(predictions[:2])
    #pprint(references[:2])

    # Compute metrics
    result = metric.compute(predictions=predictions, references=references)
    print(f"Exact Match: {result['exact_match']:.2f}%")
    print(f"F1 Score: {result['f1']:.2f}%\n")
    
    return result

In [4]:
# Evaluate teacher model on validation set

#validation_set = squad["validation"].shuffle(seed=42).select(range(500)).with_format("torch")
validation_dataset = squad["validation"]

if debugging: 
    validation_set = validation_dataset.select(range(100))
    validation_dataset = validation_set

print("Teacher Model Evaluation")
evaluate_model(teacher_model, teacher_tokenizer, validation_dataset)

# Evaluate student model on validation set (before distillation)
print("Student Model Evaluation (Before Distillation)")
evaluate_model(student_model, student_tokenizer, validation_dataset)



Teacher Model Evaluation


Evaluating: 100%|██████████| 100/100 [00:05<00:00, 18.88it/s]


Exact Match: 64.00%
F1 Score: 68.10%

Student Model Evaluation (Before Distillation)


Evaluating: 100%|██████████| 100/100 [00:02<00:00, 37.29it/s]

Exact Match: 0.00%
F1 Score: 1.67%






{'exact_match': 0.0, 'f1': 1.6675498158256774}

In [None]:
# Prepare data for distillation
def preprocess_data(example):
    inputs = teacher_tokenizer(
        example["context"], example["question"], truncation=True, padding=True, return_tensors="pt"
    )
    with torch.no_grad():
        outputs = teacher_model(**inputs)
    example["input_ids"] = inputs["input_ids"][0]
    example["attention_mask"] = inputs["attention_mask"][0]
    example["start_logits"] = outputs.start_logits[0]
    example["end_logits"] = outputs.end_logits[0]
    return example

train_dataset = squad["train"].shuffle(seed=42).select(range(8000)).with_format("torch").map(preprocess_data)

# Define training arguments for student model
training_args = TrainingArguments(
    output_dir="./results",
    evaluation_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=16,
    num_train_epochs=3,
    weight_decay=0.01,
    save_total_limit=2,
)

In [None]:
# Define custom loss function for knowledge distillation
class DistillationTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False):
        outputs = model(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"])
        start_loss = torch.nn.functional.kl_div(
            torch.nn.functional.log_softmax(outputs.start_logits, dim=-1),
            torch.nn.functional.softmax(inputs["start_logits"], dim=-1),
            reduction="batchmean"
        )
        end_loss = torch.nn.functional.kl_div(
            torch.nn.functional.log_softmax(outputs.end_logits, dim=-1),
            torch.nn.functional.softmax(inputs["end_logits"], dim=-1),
            reduction="batchmean"
        )
        loss = (start_loss + end_loss) / 2
        return (loss, outputs) if return_outputs else loss

In [None]:
# Train the student model using the custom trainer
trainer = DistillationTrainer(
    model=student_model,
    args=training_args,
    train_dataset=train_dataset,
)

trainer.train()

In [None]:
# Evaluate student model after distillation
student_eval_results_after = evaluate_model(student_model, student_tokenizer, validation_set)
print("Student Model Evaluation (After Distillation):", student_eval_results_after)