In [None]:
!pip install evaluate --quiet

In [None]:
import os
# Set WANDB_DISABLED before transformers is imported below.
# NOTE(review): presumably this switches off the Trainer's Weights & Biases
# integration — confirm against the installed transformers version.
os.environ["WANDB_DISABLED"] = "true"

from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForQuestionAnswering, TrainingArguments, Trainer
import numpy as np
import evaluate
from collections import defaultdict


# Fetch the NewsQA dataset and report how many rows each split contains.
dataset = load_dataset("lucadiliello/newsqa")
split_sizes = {split_name: len(split_rows) for split_name, split_rows in dataset.items()}
print("Dataset sizes:", split_sizes)

# Work on the leading rows of each split only, to keep the notebook quick to run.
train_rows = range(200)
val_rows = range(50)
train_subset = dataset["train"].select(train_rows)
val_subset = dataset["validation"].select(val_rows)
print(f"Using train={len(train_subset)} samples, val={len(val_subset)} samples")


In [None]:
# Checkpoint used for both the tokenizer and the question-answering head.
model_name = "deepset/roberta-base-squad2"

# The tuple is evaluated left to right: tokenizer first, then the model.
tokenizer, model = (
    AutoTokenizer.from_pretrained(model_name, use_fast=True),
    AutoModelForQuestionAnswering.from_pretrained(model_name),
)
print("Loaded model:", model_name)

In [None]:
def _extract_answer_info(answer_obj):
    if isinstance(answer_obj, dict):
        texts = answer_obj.get("text", [])
        starts = answer_obj.get("answer_start", [])
        if texts and starts:
            return texts[0], starts
    elif isinstance(answer_obj, list) and len(answer_obj) > 0:
        first = answer_obj[0]
        if isinstance(first, dict):
            texts = first.get("text", [])
            starts = first.get("answer_start", [])
            if texts and starts:
                return texts[0], starts
    return None, []

In [None]:

def _flatten_ints(value):
    """Flatten arbitrarily nested lists/tuples of numbers into a flat list of ints."""
    if value is None:
        return []
    if isinstance(value, (list, tuple)):
        flat = []
        for item in value:
            flat.extend(_flatten_ints(item))
        return flat
    return [int(value)]


def _label_start_candidates(label_obj):
    """Collect the ``"start"`` character offsets found in a ``labels`` entry."""
    if isinstance(label_obj, dict):
        return _flatten_ints(label_obj.get("start"))
    if isinstance(label_obj, (list, tuple)):
        starts = []
        for entry in label_obj:
            if isinstance(entry, dict):
                starts.extend(_flatten_ints(entry.get("start")))
        return starts
    return []


def _resolve_answer_span(answer_obj, context, label_obj=None):
    """Return ``(start_char, end_char)`` of the first gold answer, or ``None``."""
    answer_text, answer_starts = _extract_answer_info(answer_obj)
    if answer_text is not None and len(answer_starts) > 0:
        start_char = int(answer_starts[0])
        return start_char, start_char + len(answer_text)

    # Fallback for answers stored as a list of plain strings with the offsets
    # kept in a separate column.
    # NOTE(review): presumably this is the MRQA-style layout of the dataset
    # loaded in this notebook ("answers": [str], "labels": [{"start": [...],
    # "end": [...]}]) — verify against the dataset card.
    is_string_list = (
        isinstance(answer_obj, (list, tuple))
        and len(answer_obj) > 0
        and isinstance(answer_obj[0], str)
    )
    if not is_string_list or not answer_obj[0]:
        return None
    answer_text = answer_obj[0]

    # Only trust a labelled offset if the context really contains the answer there.
    for start_char in _label_start_candidates(label_obj):
        end_char = start_char + len(answer_text)
        if start_char >= 0 and context[start_char:end_char] == answer_text:
            return start_char, end_char

    # Otherwise locate the first literal occurrence of the answer in the context.
    start_char = context.find(answer_text)
    if start_char < 0:
        return None
    return start_char, start_char + len(answer_text)


def _locate_answer_tokens(offsets, sequence_ids, span, cls_index):
    """Map a character span to ``(start_token, end_token)`` within one window."""
    if span is None:
        return cls_index, cls_index
    start_char, end_char = span

    context_tokens = [idx for idx, seq_id in enumerate(sequence_ids) if seq_id == 1]
    if not context_tokens:
        # Nothing from the context made it into this window (e.g. empty context).
        return cls_index, cls_index
    context_first, context_last = context_tokens[0], context_tokens[-1]

    # The answer is not fully contained in this window: label it with CLS.
    if start_char < offsets[context_first][0] or end_char > offsets[context_last][1]:
        return cls_index, cls_index

    # Both scans stay inside the context window. Special and padding tokens
    # are skipped on purpose: scanning past the last context token would keep
    # advancing through them and end on the final pad position.
    start_token = context_first
    while start_token <= context_last and offsets[start_token][0] <= start_char:
        start_token += 1
    end_token = context_last
    while end_token >= context_first and offsets[end_token][1] >= end_char:
        end_token -= 1
    return start_token - 1, end_token + 1


def preprocess_function(examples):
    """Tokenize a batch of QA examples into fixed-length windows with span labels.

    Each (question, context) pair is split into overlapping windows of 256
    tokens (stride 64, only the context is truncated, padded to full length).
    For every window the gold answer's character span is converted to token
    indices; windows that do not fully contain the answer, or examples
    without a usable answer, are labelled with the CLS token index.

    Args:
        examples: batch mapping with ``"question"``, ``"context"`` and
            ``"answers"`` columns; an optional ``"labels"`` column is used to
            recover character offsets when the answers are plain strings.

    Returns:
        The tokenizer output extended with ``"start_positions"``,
        ``"end_positions"`` and ``"example_id"`` (one entry per window).
        ``"example_id"`` is the string form of the example's index *within
        this batch*, so it only identifies a dataset row when the whole
        dataset is processed as a single batch.
    """
    questions = [q.strip() if q else "" for q in examples["question"]]
    contexts = [c if c else "" for c in examples["context"]]
    answers = examples["answers"]
    labels = examples["labels"] if "labels" in examples else None

    inputs = tokenizer(
        questions,
        contexts,
        max_length=256,
        truncation="only_second",
        stride=64,
        return_overflowing_tokens=True,
        return_offsets_mapping=True,
        padding="max_length"
    )

    # These two are only needed to build the labels and are not returned.
    sample_mapping = inputs.pop("overflow_to_sample_mapping")
    offset_mapping = inputs.pop("offset_mapping")

    start_positions, end_positions, example_ids = [], [], []
    span_by_sample = {}  # the character span is the same for all windows of an example

    for i, offsets in enumerate(offset_mapping):
        input_ids = inputs["input_ids"][i]
        cls_index = input_ids.index(tokenizer.cls_token_id)
        sequence_ids = inputs.sequence_ids(i)
        sample_index = sample_mapping[i]

        if sample_index not in span_by_sample:
            span_by_sample[sample_index] = _resolve_answer_span(
                answers[sample_index],
                contexts[sample_index],
                labels[sample_index] if labels is not None else None,
            )

        start_token, end_token = _locate_answer_tokens(
            offsets, sequence_ids, span_by_sample[sample_index], cls_index
        )
        start_positions.append(start_token)
        end_positions.append(end_token)
        example_ids.append(str(sample_index))

    inputs["start_positions"] = start_positions
    inputs["end_positions"] = end_positions
    inputs["example_id"] = example_ids

    return inputs

In [None]:
# Training features: default batching is fine because the example ids are not
# used during training.
tokenized_train = train_subset.map(preprocess_function, batched=True, remove_columns=train_subset.column_names)

# Validation features: hand the whole split to preprocess_function as a single
# batch (batch_size=None). The "example_id" it stores is the index within the
# batch, and the evaluation code pairs predictions with raw rows by that index,
# so with the default batch size the ids would repeat once the split grows
# beyond one batch.
tokenized_val = val_subset.map(
    preprocess_function,
    batched=True,
    batch_size=None,
    remove_columns=val_subset.column_names,
)


In [None]:
# Hyper-parameters for one short fine-tuning epoch; no reporting integrations
# are requested (report_to is an empty list).
training_kwargs = dict(
    output_dir="roberta_qa_model",
    learning_rate=3e-5,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    num_train_epochs=1,
    weight_decay=0.01,
    logging_steps=20,
    save_total_limit=1,
    report_to=[],
)
args = TrainingArguments(**training_kwargs)

# Wire the model, its tokenizer and the tokenized splits into a Trainer.
trainer = Trainer(
    model=model,
    args=args,
    train_dataset=tokenized_train,
    eval_dataset=tokenized_val,
    tokenizer=tokenizer,
)

trainer.train()

In [None]:
metric = evaluate.load("squad")

def compute_metrics_from_logits(start_logits, end_logits, features, raw_examples):
    """Compute the SQuAD metric from per-window start/end logits.

    For every tokenized window the arg-max start and end positions are decoded
    into a candidate answer. Among the non-empty candidates of one example the
    one with the highest ``start_logit + end_logit`` score is kept; an example
    with no non-empty candidate gets the empty string as prediction.

    Args:
        start_logits: per-window start logits, indexable as ``[window][token]``.
        end_logits: per-window end logits, same layout as ``start_logits``.
        features: tokenized windows providing ``"input_ids"`` and
            ``"example_id"`` columns, in the same order as the logits.
        raw_examples: the un-tokenized examples; row ``i`` is matched to the
            windows whose ``"example_id"`` equals ``str(i)``.

    Returns:
        Whatever ``metric.compute`` returns for the assembled predictions and
        references (exactly one of each per row of ``raw_examples``).
    """
    # Read each column once instead of on every loop iteration.
    all_input_ids = features["input_ids"]
    all_example_ids = features["example_id"]

    best_by_example = {}  # example_id -> (score, text) of the best non-empty span
    for i, (s_log, e_log) in enumerate(zip(start_logits, end_logits)):
        s = int(np.argmax(s_log))
        e = int(np.argmax(e_log))
        if s > e:
            # Inverted span: this window contributes no candidate.
            continue
        pred_text = tokenizer.decode(all_input_ids[i][s:e + 1], skip_special_tokens=True).strip()
        if not pred_text:
            continue
        score = float(s_log[s]) + float(e_log[e])
        example_id = all_example_ids[i]
        current = best_by_example.get(example_id)
        if current is None or score > current[0]:
            best_by_example[example_id] = (score, pred_text)

    # Build predictions and references together so that they always have the
    # same length and the same ids, even if an example produced no window.
    preds, refs = [], []
    for i, example in enumerate(raw_examples):
        example_id = str(i)
        txt, starts = _extract_answer_info(example["answers"])
        if not txt:
            # Missing ground truth is represented by a single empty answer.
            answers = {"text": [""], "answer_start": [0]}
        else:
            answers = {"text": [txt], "answer_start": [int(starts[0]) if starts else 0]}
        refs.append({"id": example_id, "answers": answers})

        _, best_text = best_by_example.get(example_id, (None, ""))
        preds.append({"id": example_id, "prediction_text": best_text})

    # Errors from the metric are allowed to propagate: every reference carries
    # exactly one answer text, so there is nothing to filter before a retry.
    return metric.compute(predictions=preds, references=refs)

In [None]:
# Run the trained model over the tokenized validation windows and unpack the
# two logit arrays it returns.
predictions = trainer.predict(tokenized_val)
start_logits, end_logits = predictions.predictions

# Score the per-window logits against the raw validation rows.
results = compute_metrics_from_logits(
    start_logits=start_logits,
    end_logits=end_logits,
    features=tokenized_val,
    raw_examples=val_subset,
)

print("\n Final Evaluation Results:", results, sep="\n")