In [25]:
import torch
from transformers import AutoTokenizer, AutoModelForQuestionAnswering
from datasets import load_dataset
import evaluate

model_folder = "./results2/checkpoint-3880"
tokenizer = AutoTokenizer.from_pretrained("medicalai/ClinicalBERT")
model = AutoModelForQuestionAnswering.from_pretrained(model_folder)
model.eval()  # Set to evaluation mode

# -------------------
# 2. Load the Validation Dataset
# -------------------
# Assuming the dataset is in a SQuAD-like JSON format.
dataset = load_dataset('json', data_files='dataset_augmented.json')['train']
# Optionally, split off a portion for evaluation.
split_dataset = dataset.train_test_split(test_size=0.1)
val_dataset = split_dataset["test"]


# -------------------
# 3. Preprocess the Data
# -------------------
def preprocess_function(examples):
    """
    For each example, tokenize the question and context together. Also, convert the raw answer span (start position in chars)
    into token indices that the model will use as labels.
    """
    # Tokenize question and context together.
    inputs = tokenizer(examples["question"], examples["context"],
                       truncation=True,
                       padding="max_length",
                       max_length=256,
                       return_offsets_mapping=True)

    offset_mappings = inputs.pop("offset_mapping")  # remove offsets (used only to align char indices)
    start_positions = []
    end_positions = []

    #cada elemento no "offsets" representa a posiçao inicial e posiçao final de cada token da concatenação da questão com o contexto
    for i, offsets in enumerate(offset_mappings):
        # For the first answer in each example
        answer = examples["answers"][i]["text"]
        answer_start_char = examples["answers"][i]["answer_start"]
        answer_end_char = answer_start_char + len(answer)

        # Find the token indices corresponding to the start and end character positions.
        start_index = None
        end_index = None
        for idx, (start_char, end_char) in enumerate(offsets):
            if start_char <= answer_start_char < end_char:
                start_index = idx
            if start_char < answer_end_char <= end_char:
                end_index = idx
        # In some cases the answer might not align perfectly with a token span.
        if start_index is None:
            start_index = 0
        if end_index is None:
            end_index = len(offsets) - 1

        start_positions.append(start_index)
        end_positions.append(end_index)

    inputs["start_positions"] = start_positions
    inputs["end_positions"] = end_positions
    return inputs

def flatten_squad_like_data(example):
    new_contexts = []
    new_questions = []
    new_answers = []

    # Pull out the context just once
    context_text = example["context"]

    for i in range(0,  len(context_text)):
        # Each 'qa' within 'qas' is a question + list of answers
        for qa in example["qas"][i]:
            answers = qa["answers"]
            for answer in answers:
                new_contexts.append(context_text[i])
                new_questions.append(qa["question"])
                new_answers.append({"text": answer["text"], "answer_start": answer["answer_start"]})

    # Return the new, flattened lists. We rely on the 'map' call with 'batched=True'
    # to automatically collate these into a single output set of columns.
    return {
        "context": new_contexts,
        "question": new_questions,
        "answers": new_answers
    }


val_dataset = val_dataset.map(
    flatten_squad_like_data,
    batched=True,
    remove_columns=dataset.column_names
)

# Tokenize the validation dataset.
tokenized_val_dataset = val_dataset.map(preprocess_function, batched=True, remove_columns=val_dataset.column_names)


# Optionally, add dummy ids (or ensure your dataset has a unique id per sample).
def add_dummy_ids(example, idx):
    example["id"] = str(idx)
    return example


tokenized_val_dataset = tokenized_val_dataset.map(add_dummy_ids, with_indices=True)


# -------------------
# 4. Define Inference Functions
# -------------------
def predict(example):
    """
    Given a single example, perform the tokenization,
    run the model to get logits, and decode the predicted answer span.
    """
    # Tokenize the input (question and context) just as during preprocessing.
    inputs = tokenizer(example["question"], example["context"],
                       return_tensors="pt", truncation=True, padding="max_length", max_length=256)

    with torch.no_grad():
        outputs = model(**inputs)

    # Get logits for start and end positions.
    start_logits = outputs.start_logits
    end_logits = outputs.end_logits

    # Choose the token with maximum score for start and end.
    start_index = torch.argmax(start_logits, dim=1).item()
    end_index = torch.argmax(end_logits, dim=1).item()

    # Decode the answer span
    input_ids = inputs["input_ids"].squeeze()
    answer_tokens = input_ids[start_index:end_index + 1]
    answer = tokenizer.decode(answer_tokens, skip_special_tokens=True)
    return answer


def generate_predictions(tokenized_dataset):
    """
    Iterate through the tokenized dataset and generate predictions.
    Builds a predictions list and a corresponding list of reference annotations.
    """
    predictions_list = []
    references_list = []

    for index, example in enumerate(tokenized_dataset):
        pred_answer = predict(example)

        # Save prediction in the SQuAD expected format.
        predictions_list.append({"id": str(index), "prediction_text": pred_answer})
        # Reference should be formatted with an "answers" key.
        references_list.append({
            "id": str(index),
            "answers": [example["answers"]]
        })

    return predictions_list, references_list


# -------------------
# 5. Compute the Metrics (F1 and Exact Match)
# -------------------
# Load the squad evaluation metric.
squad_metric = evaluate.load("squad")

# Generate predictions on the validation set.
predictions, references = generate_predictions(val_dataset)

# Compute metrics.
results = squad_metric.compute(predictions=predictions, references=references)

print("Evaluation Results:")
print("Eval dataset size", len(val_dataset))
print("F1 Score:", results["f1"])
print("Exact Match:", results["exact_match"])

Map:   0%|          | 0/134 [00:00<?, ? examples/s]

Map:   0%|          | 0/216 [00:00<?, ? examples/s]

Map:   0%|          | 0/216 [00:00<?, ? examples/s]

Evaluation Results:
Eval dataset size 216
F1 Score: 54.5583708773515
Exact Match: 30.09259259259259
