Extractive QA

The model is given a context (paragraph) and a question, and it must find the exact text span in the context that answers the question. The model does not generate new words — it only extracts the correct substring from the context.

How the Model Works Conceptually

All models follow this general idea:

Input = [CLS] Question [SEP] Context [SEP]

1) Each token is turned into a vector embedding.

2) The Transformer processes all tokens and produces contextual embeddings.
3) Two special “heads” are added:

- Start head: predicts the probability that a token is the start of the answer.
- End head: predicts the probability that a token is the end of the answer.

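For each token $i$,

$$P_{\text{start}}(i) = \operatorname{softmax}_i(W_s \cdot h_i)$$

$$P_{\text{end}}(i) = \operatorname{softmax}_i(W_e \cdot h_i)$$

where $h_i$ is the hidden state of token $i$, $W_s$ and $W_e$ are the learnable weights, and the softmax is taken over all token positions.
The answer span is then chosen as

$$(\text{start}, \text{end}) = \arg\max_{(i, j),\; i \le j} P_{\text{start}}(i) \times P_{\text{end}}(j)$$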

In [1]:
!pip install transformers datasets evaluate accelerate




In [3]:
import torch
from datasets import load_dataset
import evaluate
from transformers import AutoTokenizer, AutoModelForQuestionAnswering, TrainingArguments, Trainer
from transformers import default_data_collator
import numpy as np

# -------------------------
# 1. Configuration
# -------------------------
# Short label -> Hugging Face Hub checkpoint; every entry is fine-tuned in turn
# by the training loop, in this order.
MODELS = dict(
    bert="bert-base-uncased",
    roberta="roberta-base",
    deberta="microsoft/deberta-v3-base",
    longformer="allenai/longformer-base-4096",
)

# Tokenization: window size in tokens (question + context) and the stride
# handed to the tokenizer when a long context is split into several windows.
MAX_LENGTH, DOC_STRIDE = 512, 128

# Optimisation: per-device batch size (training and evaluation) and epoch count.
BATCH_SIZE = 4
NUM_EPOCHS = 2  # You can increase to 3-4 for better results

# Root directory; each model writes its checkpoints to a sub-directory.
OUTPUT_DIR = "./qa_models/"


In [10]:
# 2. Load Dataset
# -------------------------
# SQuAD provides a "train" split for fine-tuning and a "validation" split that
# is used for the final EM/F1 computation.
dataset = load_dataset("squad")
train_dataset, val_dataset = dataset["train"], dataset["validation"]

# -------------------------
# 3. Load Evaluation Metric
# -------------------------
# The "squad" metric reports `exact_match` and `f1`.
metric = evaluate.load("squad")


README.md: 0.00B [00:00, ?B/s]

plain_text/train-00000-of-00001.parquet:   0%|          | 0.00/14.5M [00:00<?, ?B/s]

plain_text/validation-00000-of-00001.par(…):   0%|          | 0.00/1.82M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/87599 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/10570 [00:00<?, ? examples/s]

In [11]:
# 4. Preprocessing Function
# -------------------------
def preprocess_function(examples, tokenizer):
    """Tokenize question/context pairs and attach answer token labels.

    Each (question, context) pair is tokenized with truncation applied to the
    context only; contexts that do not fit in ``MAX_LENGTH`` tokens overflow
    into additional windows (``return_overflowing_tokens=True`` with
    ``stride=DOC_STRIDE``), so one example may yield several features.

    For every feature the first reference answer is located through the
    character offsets. If the answer lies completely inside the feature's
    context tokens, ``start_positions`` / ``end_positions`` hold the indices
    of its first and last token. Otherwise (no answer, or the answer is cut
    off by the window boundary) both labels point at the CLS token.

    Args:
        examples: Batch (dict of lists) with ``"question"``, ``"context"`` and
            ``"answers"``; each answers entry has ``"text"`` and
            ``"answer_start"`` lists.
        tokenizer: Fast tokenizer (offset mapping and ``sequence_ids`` are
            required).

    Returns:
        The tokenizer output with ``overflow_to_sample_mapping`` and
        ``offset_mapping`` removed and ``start_positions`` /
        ``end_positions`` added.
    """
    questions = [q.strip() for q in examples["question"]]
    contexts = examples["context"]

    tokenized_examples = tokenizer(
        questions,
        contexts,
        truncation="only_second",
        max_length=MAX_LENGTH,
        stride=DOC_STRIDE,
        return_overflowing_tokens=True,
        return_offsets_mapping=True,
        padding="max_length",
    )

    # Maps every produced feature back to the example it was cut from.
    sample_mapping = tokenized_examples.pop("overflow_to_sample_mapping")
    offset_mapping = tokenized_examples.pop("offset_mapping")

    start_positions = []
    end_positions = []

    for i, offsets in enumerate(offset_mapping):
        input_ids = tokenized_examples["input_ids"][i]
        cls_index = input_ids.index(tokenizer.cls_token_id) if tokenizer.cls_token_id in input_ids else 0

        sequence_ids = tokenized_examples.sequence_ids(i)
        answers = examples["answers"][sample_mapping[i]]

        if len(answers["answer_start"]) == 0:
            start_positions.append(cls_index)
            end_positions.append(cls_index)
            continue

        # Only the first reference answer is used for training labels.
        start_char = answers["answer_start"][0]
        end_char = start_char + len(answers["text"][0])

        # First and last token that belong to the context (sequence id 1).
        context_start = 0
        while sequence_ids[context_start] != 1:
            context_start += 1
        context_end = len(input_ids) - 1
        while sequence_ids[context_end] != 1:
            context_end -= 1

        # The answer must be fully contained in this window. A partially
        # visible answer would otherwise be labelled with a token outside the
        # context (separator or padding).
        if offsets[context_start][0] > start_char or offsets[context_end][1] < end_char:
            start_positions.append(cls_index)
            end_positions.append(cls_index)
            continue

        # Advance to the last context token starting at or before the answer.
        # The scan is bounded by the context so it can never run into the
        # (0, 0) offsets of the trailing separator and padding tokens.
        token_start_index = context_start
        while token_start_index <= context_end and offsets[token_start_index][0] <= start_char:
            token_start_index += 1
        start_positions.append(token_start_index - 1)

        # Walk back to the first context token ending at or after the answer.
        token_end_index = context_end
        while token_end_index >= context_start and offsets[token_end_index][1] >= end_char:
            token_end_index -= 1
        end_positions.append(token_end_index + 1)

    tokenized_examples["start_positions"] = start_positions
    tokenized_examples["end_positions"] = end_positions

    return tokenized_examples



In [12]:
# Show the columns available in each split: train first, then validation.
for split in (train_dataset, val_dataset):
    print(split.column_names)


['id', 'title', 'context', 'question', 'answers']
['id', 'title', 'context', 'question', 'answers']


In [None]:
# 5. Training Loop for Multiple Models
# -------------------------
def prepare_validation_features(examples, tokenizer):
    """Tokenize validation examples and keep what answer extraction needs.

    The training features carry neither the example id nor the character
    offsets, so predictions made on them cannot be mapped back to text. This
    function tokenizes with the same arguments as the training preprocessing
    and additionally stores, per feature:

    * ``example_id``: id of the example the feature was cut from.
    * ``offset_mapping``: character offsets, with every token that is not
      part of the context replaced by ``None``.

    Args:
        examples: Batch (dict of lists) with ``"id"``, ``"question"`` and
            ``"context"``.
        tokenizer: Fast tokenizer (offset mapping and ``sequence_ids`` are
            required).

    Returns:
        The tokenizer output with ``example_id`` added and ``offset_mapping``
        masked as described above.
    """
    questions = [q.strip() for q in examples["question"]]
    tokenized = tokenizer(
        questions,
        examples["context"],
        truncation="only_second",
        max_length=MAX_LENGTH,
        stride=DOC_STRIDE,
        return_overflowing_tokens=True,
        return_offsets_mapping=True,
        padding="max_length",
    )
    sample_mapping = tokenized.pop("overflow_to_sample_mapping")

    example_ids = []
    masked_offsets = []
    for i, offsets in enumerate(tokenized["offset_mapping"]):
        sequence_ids = tokenized.sequence_ids(i)
        example_ids.append(examples["id"][sample_mapping[i]])
        # Keep offsets only for context tokens (sequence id 1).
        masked_offsets.append(
            [offset if seq_id == 1 else None for offset, seq_id in zip(offsets, sequence_ids)]
        )

    tokenized["example_id"] = example_ids
    tokenized["offset_mapping"] = masked_offsets
    return tokenized


def postprocess_qa_predictions(examples, features, raw_predictions, tokenizer, n_best_size=20, max_answer_length=30):
    """Convert start/end logits into one answer string per example.

    Args:
        examples: Original examples; must support ``examples["id"]`` and
            ``examples[index]["context"]``.
        features: Validation features aligned row by row with the logits;
            each provides ``"example_id"`` and ``"offset_mapping"`` (``None``
            for tokens outside the context).
        raw_predictions: Tuple ``(all_start_logits, all_end_logits)``.
        tokenizer: Unused; kept so the call signature stays unchanged.
        n_best_size: Number of top start and top end candidates combined per
            feature.
        max_answer_length: Maximum answer length in tokens.

    Returns:
        Dict mapping example id to the predicted answer text (``""`` when no
        valid span was found).
    """
    all_start_logits, all_end_logits = raw_predictions
    example_id_to_index = {k: i for i, k in enumerate(examples["id"])}

    # Group feature indices by the example they were cut from.
    features_per_example = {}
    for i, feature in enumerate(features):
        features_per_example.setdefault(feature["example_id"], []).append(i)

    predictions = {}
    for example_id, feature_indices in features_per_example.items():
        context = examples[example_id_to_index[example_id]]["context"]
        valid_answers = []

        for feature_index in feature_indices:
            start_logits_i = all_start_logits[feature_index]
            end_logits_i = all_end_logits[feature_index]
            offset_mapping = features[feature_index]["offset_mapping"]

            # Indices of the n_best_size highest logits, best first.
            start_indexes = np.argsort(start_logits_i)[-1: -n_best_size - 1: -1].tolist()
            end_indexes = np.argsort(end_logits_i)[-1: -n_best_size - 1: -1].tolist()
            for start_index in start_indexes:
                for end_index in end_indexes:
                    if start_index >= len(offset_mapping) or end_index >= len(offset_mapping):
                        continue
                    # Skip spans touching the question, special or padding tokens.
                    if offset_mapping[start_index] is None or offset_mapping[end_index] is None:
                        continue
                    if end_index < start_index or end_index - start_index + 1 > max_answer_length:
                        continue
                    start_char = offset_mapping[start_index][0]
                    end_char = offset_mapping[end_index][1]
                    valid_answers.append({"score": start_logits_i[start_index] + end_logits_i[end_index],
                                          "text": context[start_char:end_char]})

        if valid_answers:
            best_answer = max(valid_answers, key=lambda x: x["score"])
            predictions[example_id] = best_answer["text"]
        else:
            predictions[example_id] = ""

    return predictions


for model_name, model_checkpoint in MODELS.items():
    print(f"\n===== Training {model_name} =====\n")
    tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
    model = AutoModelForQuestionAnswering.from_pretrained(model_checkpoint)

    tokenized_train = train_dataset.map(lambda x: preprocess_function(x, tokenizer), batched=True, remove_columns=train_dataset.column_names)
    tokenized_val = val_dataset.map(lambda x: preprocess_function(x, tokenizer), batched=True, remove_columns=val_dataset.column_names)
    # Separate validation features that keep example ids and context offsets;
    # the labelled features above cannot be mapped back to answer text.
    val_features = val_dataset.map(lambda x: prepare_validation_features(x, tokenizer), batched=True, remove_columns=val_dataset.column_names)

    # No evaluation strategy is configured: the model is scored once, after
    # training, by the EM/F1 computation below.
    training_args = TrainingArguments(
        output_dir=f"{OUTPUT_DIR.rstrip('/')}/{model_name}",  # avoid "//" in the path
        learning_rate=3e-5,
        per_device_train_batch_size=BATCH_SIZE,
        per_device_eval_batch_size=BATCH_SIZE,
        num_train_epochs=NUM_EPOCHS,
        weight_decay=0.01,
        save_total_limit=1,
        logging_steps=100,
        save_strategy="epoch",
        report_to="none"
    )

    # NOTE(review): recent transformers releases appear to deprecate
    # `tokenizer=` in favour of `processing_class=` - confirm against the
    # installed version before switching.
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_train,
        eval_dataset=tokenized_val,
        tokenizer=tokenizer,
        data_collator=default_data_collator
    )

    # Train
    trainer.train()

    # Evaluate
    print(f"Evaluating {model_name}...")
    # Predict on the model inputs only; the bookkeeping columns are needed
    # afterwards for post-processing, not by the model.
    raw_predictions = trainer.predict(val_features.remove_columns(["example_id", "offset_mapping"]))
    start_logits, end_logits = raw_predictions.predictions

    examples = val_dataset
    predictions = postprocess_qa_predictions(examples, val_features, (start_logits, end_logits), tokenizer)

    # Prepare references
    references = [{"id": ex["id"], "answers": ex["answers"]} for ex in val_dataset]

    # Compute metric
    em_f1 = metric.compute(predictions=[{"id": k, "prediction_text": v} for k, v in predictions.items()],
                           references=references)
    print(f"{model_name} EM: {em_f1['exact_match']:.2f}, F1: {em_f1['f1']:.2f}")

    # Release this model before the next checkpoint is loaded so two models
    # are never resident at the same time.
    del trainer, model
    if torch.cuda.is_available():
        torch.cuda.empty_cache()





===== Training bert =====



Some weights of BertForQuestionAnswering were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['qa_outputs.bias', 'qa_outputs.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Map:   0%|          | 0/10570 [00:00<?, ? examples/s]

  trainer = Trainer(


Step,Training Loss
100,4.5126
200,2.9416
300,2.545
