In [None]:
import torch
from transformers import AutoTokenizer, AutoModelForQuestionAnswering, TrainingArguments, Trainer
from datasets import load_dataset

model_path = "./results/checkpoint-472"

#tokenizer = AutoTokenizer.from_pretrained("medicalai/ClinicalBERT")
#model = AutoModelForQuestionAnswering.from_pretrained("medicalai/ClinicalBERT")

tokenizer = AutoTokenizer.from_pretrained("medicalai/ClinicalBERT")
model = AutoModelForQuestionAnswering.from_pretrained(model_path)

# Example/dummy dataset: each sample contains a context, a question, and the ground-truth answer.
dataset = load_dataset('json', data_files='dataset.json')['train']

In [51]:
# Util functions

def flatten_squad_like_data(example):
    new_contexts = []
    new_questions = []
    new_answers = []

    # Pull out the context just once
    context_text = example["context"]

    for i in range(0,  len(context_text)):
        # Each 'qa' within 'qas' is a question + list of answers
        for qa in example["qas"][i]:
            answers = qa["answers"]
            for answer in answers:
                new_contexts.append(context_text[i])
                new_questions.append(qa["question"])
                new_answers.append({"text": answer["text"], "answer_start": answer["answer_start"]})

    # Return the new, flattened lists. We rely on the 'map' call with 'batched=True'
    # to automatically collate these into a single output set of columns.
    return {
        "context": new_contexts,
        "question": new_questions,
        "answers": new_answers
    }

# Calculates the most similar questions and returns the contexts from the dataset
import evaluate
import torch
import torch.nn.functional as F

# Load MedicalBERT tokenizer and model
model.eval()  # Set the model to evaluation mode

def get_contexts(user_question):
    # Calculate the embedding for the user question
    user_embedding = get_embedding(user_question)

    # Compute cosine similarity with each question in the dataset and store results
    results = []
    for entry in dataset:
        context = entry["context"]
        for qa in entry["qas"]:
            question = qa["question"]
            question_embedding = get_embedding(question)
            # Compute cosine similarity between the user question embedding and the dataset question embedding
            similarity = F.cosine_similarity(user_embedding, question_embedding, dim=0)
            results.append({
                "context": context,
                "question": question,
                "similarity": similarity.item()
            })

    # Get the top 3 similar questions (highest cosine similarity)
    top_results = sorted(results, key=lambda x: x["similarity"], reverse=True)[:3]

    top_contexts_dict = {entry["context"]: entry for entry in top_results}
    unique_results = list(top_contexts_dict.values())

    result_to_return = ""
    # Display the top matching contexts, questions, and similarity scores
    for result in unique_results:
        result_to_return += result["context"]

    return result_to_return

def get_embedding(text):
    """
    Calculates the embedding for a given text using the MedicalBERT model.

    This function tokenizes the input text, runs the model to extract the hidden states,
    and then performs mean pooling over the last hidden state to generate a fixed-size embedding.
    """
    inputs = tokenizer(text, return_tensors='pt')

    with torch.no_grad():
        # Enable output_hidden_states so we can use the hidden representations
        outputs = model(**inputs, output_hidden_states=True)

    # Use the last hidden state [batch, sequence_length, hidden_dim]
    # and then perform mean pooling over the sequence_length dimension.
    last_hidden_state = outputs.hidden_states[-1]
    embedding = last_hidden_state.mean(dim=1)  # shape: [batch, hidden_dim]

    return embedding.squeeze(0)  # Remove batch dimension for similarity computation

def preprocess_function(examples):
    """
    For each example, tokenize the question and context together. Also, convert the raw answer span (start position in chars)
    into token indices that the model will use as labels.
    """
    # Tokenize question and context together.
    inputs = tokenizer(examples["question"], examples["context"],
                       truncation=True,
                       padding="max_length",
                       max_length=256,
                       return_offsets_mapping=True)

    offset_mappings = inputs.pop("offset_mapping")  # remove offsets (used only to align char indices)
    start_positions = []
    end_positions = []

    #cada elemento no "offsets" representa a posiçao inicial e posiçao final de cada token da concatenação da questão com o contexto
    for i, offsets in enumerate(offset_mappings):
        # For the first answer in each example
        answer = examples["answers"][i]["text"]
        answer_start_char = examples["answers"][i]["answer_start"]
        answer_end_char = answer_start_char + len(answer)

        # Find the token indices corresponding to the start and end character positions.
        start_index = None
        end_index = None
        for idx, (start_char, end_char) in enumerate(offsets):
            if start_char <= answer_start_char < end_char:
                start_index = idx
            if start_char < answer_end_char <= end_char:
                end_index = idx
        # In some cases the answer might not align perfectly with a token span.
        if start_index is None:
            start_index = 0
        if end_index is None:
            end_index = len(offsets) - 1

        start_positions.append(start_index)
        end_positions.append(end_index)

    inputs["start_positions"] = start_positions
    inputs["end_positions"] = end_positions
    return inputs

squad_metric = evaluate.load("squad")

def postprocess_qa_predictions(predictions, features, max_answer_length=30):
    """
    Convert raw model predictions (start/end logits) into answer text predictions.
    Assumes that the examples and features have an 'id' field to uniquely identify each example.
    This example is simplified; see Hugging Face's run_qa.py for a complete version.
    """
    all_start_logits, all_end_logits = predictions
    predictions_dict = {}

    for i, feature in enumerate(features):
        # We assume each feature has an "offset_mapping" that maps tokens to character spans,
        # and an "example_id" to link back to the original example.
        offsets = feature["offset_mapping"]
        input_ids = feature["input_ids"]
        start_logits = all_start_logits[i]
        end_logits = all_end_logits[i]

        # Find the best token span for the answer
        best_score = -float("inf")
        best_start, best_end = None, None
        for start_index, start_logit in enumerate(start_logits):
            for end_index, end_logit in enumerate(end_logits[start_index:]):
                # Check condition: maximum answer length, valid offsets (ignore special tokens)
                if (end_index + start_index) < len(offsets) and (end_index <= max_answer_length):
                    score = start_logit + end_logit
                    if score > best_score:
                        best_score = score
                        best_start = start_index
                        best_end = start_index + end_index
        # Convert token indices to answer text using offset mappings
        if best_start is not None and best_end is not None:
            start_char = offsets[best_start][0]
            end_char = offsets[best_end][1]
            # Use the original context to recover text; assuming feature has a "context" key.
            answer_text = feature["context"][start_char: end_char]
        else:
            answer_text = ""
        # Map predictions to example IDs (this requires that your features have an "id" or "example_id")
        example_id = feature["id"] if "id" in feature else feature.get("example_id", i)
        predictions_dict[example_id] = answer_text

    # Build a list of prediction dicts as expected by the squad_metric.
    predictions_list = [{"id": id_, "prediction_text": text} for id_, text in predictions_dict.items()]
    return predictions_list


def compute_metrics(eval_pred):
    raw_predictions, raw_label_ids = eval_pred

    # Postprocess predictions to generate answer texts (implementation omitted for brevity)
    predictions = postprocess_qa_predictions(raw_predictions, tokenized_val_dataset)

    # Create references by enumerating through the validation dataset
    references = []
    for idx, example in enumerate(tokenized_val_dataset):
        # Use the index as a string id
        references.append({
            "id": str(idx),
            "answers": example["answers"]  # assuming this is formatted as needed
        })

    results = squad_metric.compute(predictions=predictions, references=references)
    return {
        "exact_match": results["exact_match"],
        "f1": results["f1"]
    }

def add_dummy_ids(example, idx):
    example["id"] = str(idx)
    return example


In [None]:
# Fine tune the model

train_dataset = dataset.map(
    flatten_squad_like_data,
    batched=True,
    remove_columns=dataset.column_names
)

# Tokenize the dataset.
tokenized_dataset = train_dataset.map(preprocess_function, batched=True, remove_columns=train_dataset.column_names)

split_dataset = tokenized_dataset.train_test_split(test_size=0.1)
tokenized_train_dataset = split_dataset["train"]
tokenized_val_dataset = split_dataset["test"]
tokenized_val_dataset = tokenized_val_dataset.map(add_dummy_ids, with_indices=True)
# Set up training arguments.
training_args = TrainingArguments(
    output_dir="./results2",
    #evaluation_strategy="no",  # for simplicity we won't run an evaluation loop here
    num_train_epochs=2,
    per_device_train_batch_size=1,
    learning_rate=5e-5,
    weight_decay=0.01,
    logging_steps=1,
    save_steps=1000  # Save infrequently for the demo
)

# Initialize the Trainer with our model, training args, and tokenized dataset.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_train_dataset,
    eval_dataset=tokenized_val_dataset,
    compute_metrics=compute_metrics,
)

# Fine tune the model on the QA dataset.
trainer.train()

# -------------------------
# Testing/inference step:
# Let’s take a question and its context, then have the model extract an answer.

question = "What medication was given?"
context = "In the hospital, the patient was administered amoxicillin for a bacterial infection."

# Tokenize the input: note that for QA we combine question and context.
inputs = tokenizer(question, context, return_tensors="pt", truncation=True, padding="max_length", max_length=256)

with torch.no_grad():
    outputs = model(**inputs)

# The model outputs 'start_logits' and 'end_logits', which represent the probability distribution
# for the start and end of the answer span.
start_logits = outputs.start_logits
end_logits = outputs.end_logits

# Identify the most likely token positions for the start and end of the answer.
start_index = torch.argmax(start_logits, dim=1).item()
end_index = torch.argmax(end_logits, dim=1).item()

# Decode the tokens corresponding to the predicted span.
input_ids = inputs["input_ids"].squeeze()
answer_tokens = input_ids[start_index:(end_index + 1)]
answer = tokenizer.decode(answer_tokens, skip_special_tokens=True)

# Output the results.
print("Question:", question)
print("Predicted Answer:", answer)

In [53]:
# Model metrics
predictions, label_ids, metrics = trainer.predict(tokenized_val_dataset)

eval_results = trainer.evaluate()
print("Evaluation Metrics:", eval_results)

KeyError: 'offset_mapping'

In [None]:
# Test pose question1

question1 = "What does Alzheimer's disease cause?"
context1 = get_contexts(question1)

# Tokenize the input: note that for QA we combine question and context.
inputs = tokenizer(question1, context1, return_tensors="pt", truncation=True, padding="max_length", max_length=256)

with torch.no_grad():
    outputs = model(**inputs)

# The model outputs 'start_logits' and 'end_logits', which represent the probability distribution
# for the start and end of the answer span.
start_logits = outputs.start_logits
end_logits = outputs.end_logits

# Identify the most likely token positions for the start and end of the answer.
start_index = torch.argmax(start_logits, dim=1).item()
end_index = torch.argmax(end_logits, dim=1).item()

# Decode the tokens corresponding to the predicted span.
input_ids = inputs["input_ids"].squeeze()
answer_tokens = input_ids[start_index:(end_index + 1)]
answer = tokenizer.decode(answer_tokens, skip_special_tokens=True)

# Output the results.
print("Question:", question1)
print("Predicted Answer:", answer)

In [None]:
# Test pose question2 with pipeline
from transformers import pipeline

qa_pipeline = pipeline("question-answering", model=model, tokenizer=tokenizer)

# Sample context and question (replace these wi
#
# th your own data)
question2 = "O que acontece socialmente com pessoas com Alzheimer em grande parte dos casos?"
context2 = get_contexts(question2)

print("Context:", context2)

# Get the model's answer
result_pipeline = qa_pipeline(context=context2, question=question2)

print("Answer:", result_pipeline["answer"])