In [70]:
import numpy as np
from tqdm.auto import tqdm
import collections

import torch
    
from datasets import load_dataset
from transformers import AutoTokenizer
from transformers import AutoModelForQuestionAnswering
from transformers import TrainingArguments
from transformers import Trainer
import evaluate

device = torch.device(
    "cuda") if torch.cuda.is_available() else torch.device("cpu")

In [71]:
# Sử dụng mô hình "distilbert-base-uncased"
# làm mô hình checkpoint
MODEL_NAME = "distilbert-base-uncased"

# Độ dài tối đa cho mỗi đoạn văn bản
# sau khi được xử lý
MAX_LENGTH = 384

# Khoảng cách giữa các điểm bắt đầu
# của các đoạn văn bản liên tiếp
STRIDE = 128

In [72]:
DATASET_NAME = "squad"
raw_datasets = load_dataset(DATASET_NAME)

Generating train split: 100%|██████████| 87599/87599 [00:00<00:00, 633089.81 examples/s]
Generating validation split: 100%|██████████| 10570/10570 [00:00<00:00, 612649.84 examples/s]


In [73]:
raw_datasets["train"].filter(lambda x: len(x["answers"]["text"]) != 1)

Filter: 100%|██████████| 87599/87599 [00:01<00:00, 52579.21 examples/s]


Dataset({
    features: ['id', 'title', 'context', 'question', 'answers'],
    num_rows: 0
})

In [74]:
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

In [75]:
raw_datasets["train"][0]

{'id': '5733be284776f41900661182',
 'title': 'University_of_Notre_Dame',
 'context': 'Architecturally, the school has a Catholic character. Atop the Main Building\'s gold dome is a golden statue of the Virgin Mary. Immediately in front of the Main Building and facing it, is a copper statue of Christ with arms upraised with the legend "Venite Ad Me Omnes". Next to the Main Building is the Basilica of the Sacred Heart. Immediately behind the basilica is the Grotto, a Marian place of prayer and reflection. It is a replica of the grotto at Lourdes, France where the Virgin Mary reputedly appeared to Saint Bernadette Soubirous in 1858. At the end of the main drive (and in a direct line that connects through 3 statues and the Gold Dome), is a simple, modern stone statue of Mary.',
 'question': 'To whom did the Virgin Mary allegedly appear in 1858 in Lourdes France?',
 'answers': {'text': ['Saint Bernadette Soubirous'], 'answer_start': [515]}}

In [76]:
question = "to whom did the virgin mary allegedly appear in 1858 in lourdes france ?"
context = 'Architecturally, the school has a Catholic character. Atop the Main Building\'s gold dome is a golden statue of the Virgin Mary. Immediately in front of the Main Building and facing it, is a copper statue of Christ with arms upraised with the legend " Venite Ad Me Omnes ". Next to the Main Building is the Basilica of the Sacred Heart. Immediately behind the basilica is the Grotto, a Marian place of prayer and reflection. It is a replica of the grotto at Lourdes, France where the Virgin Mary reputedly appeared to Saint Bernadette Soubirous in 1858. At the end of the main drive ( and in a direct line that connects through 3 statues and the Gold Dome ), is a simple, modern stone statue of Mary.'
answers = {'text': ['Saint Bernadette Soubirous'], 'answer_start': [515]}
tokens = tokenizer(
    question,
    context,
    truncation="only_second",
    max_length=100,
    stride=50,
    return_overflowing_tokens=True,
    return_offsets_mapping=True,)

print(tokens)
print("inputs_id: ", tokens['input_ids'])
print("overflow_to_sample_mapping: ", tokens['overflow_to_sample_mapping'])
print("off_mapping: ", tokens["offset_mapping"])
print(tokens.sequence_ids(3))
print(tokenizer.decode(tokens['input_ids'][0]))

{'input_ids': [[101, 2000, 3183, 2106, 1996, 6261, 2984, 9382, 3711, 1999, 8517, 1999, 10223, 26371, 2605, 1029, 102, 6549, 2135, 1010, 1996, 2082, 2038, 1037, 3234, 2839, 1012, 10234, 1996, 2364, 2311, 1005, 1055, 2751, 8514, 2003, 1037, 3585, 6231, 1997, 1996, 6261, 2984, 1012, 3202, 1999, 2392, 1997, 1996, 2364, 2311, 1998, 5307, 2009, 1010, 2003, 1037, 6967, 6231, 1997, 4828, 2007, 2608, 2039, 14995, 6924, 2007, 1996, 5722, 1000, 2310, 3490, 2618, 4748, 2033, 18168, 5267, 1000, 1012, 2279, 2000, 1996, 2364, 2311, 2003, 1996, 13546, 1997, 1996, 6730, 2540, 1012, 3202, 2369, 1996, 13546, 2003, 1996, 24665, 102], [101, 2000, 3183, 2106, 1996, 6261, 2984, 9382, 3711, 1999, 8517, 1999, 10223, 26371, 2605, 1029, 102, 2364, 2311, 1998, 5307, 2009, 1010, 2003, 1037, 6967, 6231, 1997, 4828, 2007, 2608, 2039, 14995, 6924, 2007, 1996, 5722, 1000, 2310, 3490, 2618, 4748, 2033, 18168, 5267, 1000, 1012, 2279, 2000, 1996, 2364, 2311, 2003, 1996, 13546, 1997, 1996, 6730, 2540, 1012, 3202, 2369, 19

#### Preprocessing train with 1 sample

In [77]:
# Dài hơn max_length
question = "to whom did the virgin mary allegedly appear in 1858 in lourdes france ?"
context = 'Architecturally, the school has a Catholic character. Atop the Main Building\'s gold dome is a golden statue of the Virgin Mary. Immediately in front of the Main Building and facing it, is a copper statue of Christ with arms upraised with the legend " Venite Ad Me Omnes ". Next to the Main Building is the Basilica of the Sacred Heart. Immediately behind the basilica is the Grotto, a Marian place of prayer and reflection. It is a replica of the grotto at Lourdes, France where the Virgin Mary reputedly appeared to Saint Bernadette Soubirous in 1858. At the end of the main drive ( and in a direct line that connects through 3 statues and the Gold Dome ), is a simple, modern stone statue of Mary.'
answers = {'text': ['Saint Bernadette Soubirous'], 'answer_start': [515]}

inputs = tokenizer(
    question,
    context,
    max_length=100,
    truncation="only_second",
    stride=50,
    return_overflowing_tokens=True,
    return_offsets_mapping=True,
)

print(f"The example gave {len(inputs['input_ids'])} features.")
print(
    f"Here is where each comes from: {inputs['overflow_to_sample_mapping']}.")

start_positions = []
end_positions = []

# Corrected loop for single example: Iterate through features' offset_mapping
for i, offset in enumerate(inputs["offset_mapping"]):
    sample_idx = inputs["overflow_to_sample_mapping"][i]
    answer = answers  # answers is already a single dict
    start_char = answer["answer_start"][0]
    end_char = answer["answer_start"][0] + len(answer["text"][0])
    sequence_ids = inputs.sequence_ids(i)
    print(end_char)

    # Find the start and end of the context
    idx_seq = 0  # Use idx_seq to avoid confusion with feature index 'i'
    while sequence_ids[idx_seq] != 1:
        idx_seq += 1
    context_start = idx_seq
    while sequence_ids[idx_seq] == 1:
        idx_seq += 1
    context_end = idx_seq - 1

    # If the answer is not fully inside the context, label is (0, 0)
    if context_start >= len(offset) or context_end >= len(offset) or offset[context_start][0] > start_char or offset[context_end][1] < end_char:
        start_positions.append(0)
        end_positions.append(0)
    else:
        # Otherwise it's the start and end token positions
        idx_token = context_start  # Use idx_token to iterate over tokens in context
        while idx_token <= context_end and offset[idx_token][0] <= start_char:
            idx_token += 1
        start_positions.append(idx_token - 1)

        idx_token = context_end
        while idx_token >= context_start and offset[idx_token][1] >= end_char:
            idx_token -= 1
        end_positions.append(idx_token + 1)

start_positions, end_positions

The example gave 4 features.
Here is where each comes from: [0, 0, 0, 0].
541
541
541
541


([0, 0, 65, 33], [0, 0, 72, 40])

In [78]:
idx = 3
sample_idx = inputs["overflow_to_sample_mapping"][idx]
answer = answers["text"][0]

start = start_positions[idx]
end = end_positions[idx]
labeled_answer = tokenizer.decode(inputs["input_ids"][idx][start: end + 1])

print(f"Theoretical answer: {answer}, labels give: {labeled_answer}")

Theoretical answer: Saint Bernadette Soubirous, labels give: to saint bernadette soubiro


#### Preprocessing train with mini-batch 

In [79]:
inputs = tokenizer(
    raw_datasets["train"][2:6]["question"],
    raw_datasets["train"][2:6]["context"],
    max_length=100,
    truncation="only_second",
    stride=50,
    return_overflowing_tokens=True,
    return_offsets_mapping=True,
)

print(f"The 4 examples gave {len(inputs['input_ids'])} features.")
print(
    f"Here is where each comes from: {inputs['overflow_to_sample_mapping']}.")

answers = raw_datasets["train"][2:6]["answers"]
start_positions = []
end_positions = []

for i, offset in enumerate(inputs["offset_mapping"]):
    sample_idx = inputs["overflow_to_sample_mapping"][i]
    answer = answers[sample_idx]
    start_char = answer["answer_start"][0]
    end_char = answer["answer_start"][0] + len(answer["text"][0])
    sequence_ids = inputs.sequence_ids(i)

    # Find the start and end of the context
    idx = 0
    while sequence_ids[idx] != 1:
        idx += 1
    context_start = idx
    while sequence_ids[idx] == 1:
        idx += 1
    context_end = idx - 1

    # If the answer is not fully inside the context, label is (0, 0)
    if offset[context_start][0] > start_char or offset[context_end][1] < end_char:
        start_positions.append(0)
        end_positions.append(0)
    else:
        # Otherwise it's the start and end token positions
        idx = context_start
        while idx <= context_end and offset[idx][0] <= start_char:
            idx += 1
        start_positions.append(idx - 1)

        idx = context_end
        while idx >= context_start and offset[idx][1] >= end_char:
            idx -= 1
        end_positions.append(idx + 1)

start_positions, end_positions

The 4 examples gave 17 features.
Here is where each comes from: [0, 0, 0, 0, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3].


([81, 49, 17, 0, 0, 57, 19, 33, 0, 0, 0, 63, 27, 0, 0, 0, 0],
 [83, 51, 19, 0, 0, 63, 25, 39, 0, 0, 0, 64, 28, 0, 0, 0, 0])

In [80]:
idx = 4
sample_idx = inputs["overflow_to_sample_mapping"][idx]
answer = answers[sample_idx]["text"][0]

start = start_positions[idx]
end = end_positions[idx]
labeled_answer = tokenizer.decode(inputs["input_ids"][idx][start: end + 1])

print(f"Theoretical answer: {answer}, labels give: {labeled_answer}")

Theoretical answer: a Marian place of prayer and reflection, labels give: [CLS]


#### Preprocessing validation 1 sample and mini-batch

In [81]:
question = "to whom did the virgin mary allegedly appear in 1858 in lourdes france ?"
context = 'Architecturally, the school has a Catholic character. Atop the Main Building\'s gold dome is a golden statue of the Virgin Mary. Immediately in front of the Main Building and facing it, is a copper statue of Christ with arms upraised with the legend " Venite Ad Me Omnes ". Next to the Main Building is the Basilica of the Sacred Heart. Immediately behind the basilica is the Grotto, a Marian place of prayer and reflection. It is a replica of the grotto at Lourdes, France where the Virgin Mary reputedly appeared to Saint Bernadette Soubirous in 1858. At the end of the main drive ( and in a direct line that connects through 3 statues and the Gold Dome ), is a simple, modern stone statue of Mary.'
answers = {'text': ['Saint Bernadette Soubirous'], 'answer_start': [515]}

max_length = 100
stride = 50

def preprocess_single_example(question, context, example_id):
    """
    Preprocesses a single question-context example for validation, similar to preprocess_validation_examples.

    Args:
        question (str): The question string.
        context (str): The context string.
        example_id (str or int): A unique identifier for the example.

    Returns:
        dict: A dictionary containing the preprocessed inputs, including input_ids,
              attention_mask, offset_mapping, and example_id.
    """
    questions = [question.strip()]  # Wrap question in a list
    contexts = [context]  # Wrap context in a list
    examples = {"question": questions, "context": contexts,
                "id": [example_id]}  # Simulate examples batch

    inputs = tokenizer(
        questions,  # Use the list of questions
        contexts,  # Use the list of contexts
        max_length=max_length,
        truncation="only_second",
        stride=stride,
        return_overflowing_tokens=True,
        return_offsets_mapping=True,
        padding="max_length",  # Keep padding for consistency
    )

    sample_map = inputs.pop("overflow_to_sample_mapping")
    example_ids = []

    for i in range(len(inputs["input_ids"])):
        sample_idx = sample_map[i]
        example_ids.append(examples["id"][sample_idx])

        sequence_ids = inputs.sequence_ids(i)
        offset = inputs["offset_mapping"][i]
        inputs["offset_mapping"][i] = [
            o if sequence_ids[k] == 1 else None for k, o in enumerate(offset)
        ]

    inputs["example_id"] = example_ids
    return inputs


# Example usage:
# You can use any identifier for your single example
example_id = "56be85543aeaaa14008c9063"
processed_input = preprocess_single_example(question, context, example_id)

print("Processed Input Keys:", processed_input.keys())
print("Input IDs (first feature):", processed_input["input_ids"])
print("Attention Mask (first feature):", processed_input["attention_mask"][0])
print("Offset Mapping (first feature):", processed_input["offset_mapping"])
print("Example ID:", processed_input["example_id"])

Processed Input Keys: dict_keys(['input_ids', 'attention_mask', 'offset_mapping', 'example_id'])
Input IDs (first feature): [[101, 2000, 3183, 2106, 1996, 6261, 2984, 9382, 3711, 1999, 8517, 1999, 10223, 26371, 2605, 1029, 102, 6549, 2135, 1010, 1996, 2082, 2038, 1037, 3234, 2839, 1012, 10234, 1996, 2364, 2311, 1005, 1055, 2751, 8514, 2003, 1037, 3585, 6231, 1997, 1996, 6261, 2984, 1012, 3202, 1999, 2392, 1997, 1996, 2364, 2311, 1998, 5307, 2009, 1010, 2003, 1037, 6967, 6231, 1997, 4828, 2007, 2608, 2039, 14995, 6924, 2007, 1996, 5722, 1000, 2310, 3490, 2618, 4748, 2033, 18168, 5267, 1000, 1012, 2279, 2000, 1996, 2364, 2311, 2003, 1996, 13546, 1997, 1996, 6730, 2540, 1012, 3202, 2369, 1996, 13546, 2003, 1996, 24665, 102], [101, 2000, 3183, 2106, 1996, 6261, 2984, 9382, 3711, 1999, 8517, 1999, 10223, 26371, 2605, 1029, 102, 2364, 2311, 1998, 5307, 2009, 1010, 2003, 1037, 6967, 6231, 1997, 4828, 2007, 2608, 2039, 14995, 6924, 2007, 1996, 5722, 1000, 2310, 3490, 2618, 4748, 2033, 18168, 5

In [82]:
def preprocess_validation_examples(examples):
    questions = [q.strip() for q in examples["question"]]
    inputs = tokenizer(
        questions,
        examples["context"],
        max_length=max_length,
        truncation="only_second",
        stride=stride,
        return_overflowing_tokens=True,
        return_offsets_mapping=True,
        padding="max_length",
    )

    sample_map = inputs.pop("overflow_to_sample_mapping")
    example_ids = []

    for i in range(len(inputs["input_ids"])):
        sample_idx = sample_map[i]
        example_ids.append(examples["id"][sample_idx])

        sequence_ids = inputs.sequence_ids(i)
        offset = inputs["offset_mapping"][i]
        inputs["offset_mapping"][i] = [
            o if sequence_ids[k] == 1 else None for k, o in enumerate(offset)
        ]

    inputs["example_id"] = example_ids
    return inputs

#### Training

In [83]:
small_eval_set = raw_datasets["validation"].select(range(100))

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
eval_set = small_eval_set.map(
    preprocess_validation_examples,
    batched=True,
    remove_columns=raw_datasets["validation"].column_names,
)

Map: 100%|██████████| 100/100 [00:00<00:00, 1294.08 examples/s]


In [84]:
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

In [85]:
import torch
from transformers import AutoModelForQuestionAnswering

eval_set_for_model = eval_set.remove_columns(["example_id", "offset_mapping"])
eval_set_for_model.set_format("torch")

device = torch.device(
    "cuda") if torch.cuda.is_available() else torch.device("cpu")
batch = {k: eval_set_for_model[k].to(device)
         for k in eval_set_for_model.column_names}
trained_model = AutoModelForQuestionAnswering.from_pretrained(MODEL_NAME).to(
    device
)

with torch.no_grad():
    outputs = trained_model(**batch)

Some weights of DistilBertForQuestionAnswering were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['qa_outputs.bias', 'qa_outputs.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [86]:
outputs

QuestionAnsweringModelOutput(loss=None, start_logits=tensor([[ 0.0117, -0.0059,  0.1130,  ..., -0.1730, -0.2052,  0.1237],
        [ 0.0970,  0.0544,  0.1646,  ...,  0.1479,  0.0669,  0.2438],
        [ 0.1428,  0.0769,  0.2369,  ...,  0.2013,  0.3976,  0.1529],
        ...,
        [ 0.0635, -0.1360,  0.0261,  ..., -0.0630,  0.0617,  0.1032],
        [ 0.1210,  0.0173,  0.0841,  ...,  0.2631,  0.1714,  0.1978],
        [ 0.1233,  0.0842,  0.1270,  ..., -0.0102,  0.2096,  0.2625]]), end_logits=tensor([[-0.1218,  0.0535, -0.1136,  ..., -0.2045, -0.0174,  0.0951],
        [-0.1907, -0.0134, -0.1944,  ..., -0.3828,  0.0265, -0.0081],
        [-0.1893,  0.0307, -0.1950,  ...,  0.3523, -0.1001, -0.0567],
        ...,
        [-0.1612,  0.0608, -0.1158,  ..., -0.1107, -0.1438, -0.2154],
        [-0.1466,  0.1031, -0.3443,  ...,  0.0935,  0.0599, -0.1210],
        [-0.1487,  0.0058, -0.1928,  ...,  0.0931,  0.2353,  0.0872]]), hidden_states=None, attentions=None)

In [87]:
start_logits = outputs.start_logits.cpu().numpy()
end_logits = outputs.end_logits.cpu().numpy()

In [88]:
start_logits[0][-1:-21:-1]

array([ 0.12367848, -0.20515595, -0.17301556,  0.22297172, -0.14648412,
        0.07698998, -0.25346652, -0.45339662, -0.14374146, -0.03400343,
       -0.7031467 , -0.40218648, -0.18835765, -0.1797778 , -0.14668193,
        0.05877595, -0.45716953, -0.29644898, -0.25350407, -0.05450182],
      dtype=float32)

In [89]:
start_logits.shape # từ 100 sample tokenizer ra 366 sample (lý do đã giải thích ở mục preprocesing ở mục train) 100 là maxlength

(221, 100)

#### Postprocessing

In [90]:
import collections

example_to_features = collections.defaultdict(list)
for idx, feature in enumerate(eval_set):
    example_to_features[feature["example_id"]].append(idx)
example_to_features

defaultdict(list,
            {'56be4db0acb8001400a502ec': [0, 1, 2],
             '56be4db0acb8001400a502ed': [3, 4, 5],
             '56be4db0acb8001400a502ee': [6, 7, 8],
             '56be4db0acb8001400a502ef': [9, 10, 11],
             '56be4db0acb8001400a502f0': [12, 13, 14, 15],
             '56be8e613aeaaa14008c90d1': [16, 17, 18],
             '56be8e613aeaaa14008c90d2': [19, 20, 21],
             '56be8e613aeaaa14008c90d3': [22, 23, 24],
             '56bea9923aeaaa14008c91b9': [25, 26, 27],
             '56bea9923aeaaa14008c91ba': [28, 29, 30],
             '56bea9923aeaaa14008c91bb': [31, 32, 33],
             '56beace93aeaaa14008c91df': [34, 35, 36],
             '56beace93aeaaa14008c91e0': [37, 38, 39],
             '56beace93aeaaa14008c91e1': [40, 41, 42],
             '56beace93aeaaa14008c91e2': [43, 44, 45, 46],
             '56beace93aeaaa14008c91e3': [47, 48, 49],
             '56bf10f43aeaaa14008c94fd': [50, 51, 52, 53],
             '56bf10f43aeaaa14008c94fe': [54,

In [91]:
import numpy as np

n_best = 20
max_answer_length = 30
predicted_answers = []

for example in small_eval_set:
    example_id = example["id"]
    context = example["context"]
    answers = []

    for feature_index in example_to_features[example_id]:
        start_logit = start_logits[feature_index]
        end_logit = end_logits[feature_index]
        offsets = eval_set["offset_mapping"][feature_index]

        start_indexes = np.argsort(start_logit)[-1: -n_best - 1: -1].tolist()
        end_indexes = np.argsort(end_logit)[-1: -n_best - 1: -1].tolist()
        for start_index in start_indexes:
            for end_index in end_indexes:
                # Skip answers that are not fully in the context
                if offsets[start_index] is None or offsets[end_index] is None:
                    continue
                # Skip answers with a length that is either < 0 or > max_answer_length.
                if (
                    end_index < start_index
                    or end_index - start_index + 1 > max_answer_length
                ):
                    continue

                answers.append(
                    {
                        "text": context[offsets[start_index][0]: offsets[end_index][1]],
                        "logit_score": start_logit[start_index] + end_logit[end_index],
                    }
                )

    if len(answers) > 0:
        best_answer = max(answers, key=lambda x: x["logit_score"])
        predicted_answers.append(
            {"id": example_id, "prediction_text": best_answer["text"]})
    else:
        predicted_answers.append({"id": example_id, "prediction_text": ""})
        print("rỗng")

In [92]:
predicted_answers

[{'id': '56be4db0acb8001400a502ec',
  'prediction_text': ') champion Denver Broncos defeated the National Football Conference ('},
 {'id': '56be4db0acb8001400a502ed',
  'prediction_text': ') champion Denver Broncos defeated the National Football Conference ('},
 {'id': '56be4db0acb8001400a502ee',
  'prediction_text': ') for the 2015 season. The American Football Conference (AFC) champion Denver Broncos defeated the National Football Conference ('},
 {'id': '56be4db0acb8001400a502ef',
  'prediction_text': ') for the 2015 season. The American Football Conference (AFC) champion Denver Broncos defeated the National Football Conference ('},
 {'id': '56be4db0acb8001400a502f0',
  'prediction_text': 'Broncos defeated the National Football Conference ('},
 {'id': '56be8e613aeaaa14008c90d1',
  'prediction_text': '. The American Football Conference (AFC) champion Denver Broncos defeated the National Football Conference ('},
 {'id': '56be8e613aeaaa14008c90d2',
  'prediction_text': '. The American 

In [98]:
import evaluate

metric = evaluate.load("squad")
theoretical_answers = [
    {"id": ex["id"], "answers": ex["answers"]} for ex in small_eval_set
]

print(predicted_answers[0])
print(theoretical_answers[0])

{'id': '56be4db0acb8001400a502ec', 'prediction_text': ') champion Denver Broncos defeated the National Football Conference ('}
{'id': '56be4db0acb8001400a502ec', 'answers': {'text': ['Denver Broncos', 'Denver Broncos', 'Denver Broncos'], 'answer_start': [177, 177, 177]}}


In [99]:
metric.compute(predictions=predicted_answers, references=theoretical_answers)

{'exact_match': 0.0, 'f1': 4.282815974874798}

#### Completed postprocessing

In [None]:
from tqdm.auto import tqdm


def compute_metrics(start_logits, end_logits, features, examples):
    example_to_features = collections.defaultdict(list)
    for idx, feature in enumerate(features):
        example_to_features[feature["example_id"]].append(idx)

    predicted_answers = []
    for example in tqdm(examples):
        example_id = example["id"]
        context = example["context"]
        answers = []

        # Loop through all features associated with that example
        for feature_index in example_to_features[example_id]:
            start_logit = start_logits[feature_index]
            end_logit = end_logits[feature_index]
            offsets = features[feature_index]["offset_mapping"]

            start_indexes = np.argsort(
                start_logit)[-1: -n_best - 1: -1].tolist()
            end_indexes = np.argsort(end_logit)[-1: -n_best - 1: -1].tolist()
            for start_index in start_indexes:
                for end_index in end_indexes:
                    # Skip answers that are not fully in the context
                    if offsets[start_index] is None or offsets[end_index] is None:
                        continue
                    # Skip answers with a length that is either < 0 or > max_answer_length
                    if (
                        end_index < start_index
                        or end_index - start_index + 1 > max_answer_length
                    ):
                        continue

                    answer = {
                        "text": context[offsets[start_index][0]: offsets[end_index][1]],
                        "logit_score": start_logit[start_index] + end_logit[end_index],
                    }
                    answers.append(answer)

        # Select the answer with the best score
        if len(answers) > 0:
            best_answer = max(answers, key=lambda x: x["logit_score"])
            predicted_answers.append(
                {"id": example_id, "prediction_text": best_answer["text"]}
            )
        else:
            predicted_answers.append({"id": example_id, "prediction_text": ""})

    theoretical_answers = [
        {"id": ex["id"], "answers": ex["answers"]} for ex in examples]
    return metric.compute(predictions=predicted_answers, references=theoretical_answers)


start_logits = outputs.start_logits.cpu().numpy()
end_logits = outputs.end_logits.cpu().numpy()
compute_metrics(start_logits, end_logits, eval_set, small_eval_set)

100%|██████████| 100/100 [00:00<00:00, 572.81it/s]


{'exact_match': 0.0, 'f1': 4.282815974874798}