https://huggingface.co/datasets/qiaojin/PubMedQA

In [13]:
import logging
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import random
import string

In [2]:
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("PubMedQA Evaluation")

In [4]:
MODEL_NAME = "HPAI-BSC/Llama3-Aloe-8B-Alpha"
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Downloading shards: 100%|██████████| 4/4 [04:24<00:00, 66.17s/it]
Loading checkpoint shards: 100%|██████████| 4/4 [00:03<00:00,  1.33it/s]


In [11]:
messages = [
    {"role": "system", "content": "Asnwer the question only with yes, no or maybe. You don't need say anything else."},
    {"role": "user", "content": "Are group 2 innate lymphoid cells ( ILC2s ) increased in chronic rhinosinusitis with nasal polyps or eosinophilia?"},
]

input_ids = tokenizer.apply_chat_template(
    messages,
    add_generation_prompt=True,
    return_tensors="pt"
).to(model.device)

terminators = [
    tokenizer.eos_token_id,
    tokenizer.convert_tokens_to_ids("<|eot_id|>")
]

outputs = model.generate(
    input_ids,
    max_new_tokens=256,
    eos_token_id=terminators,
    do_sample=True,
    temperature=0.6,
    top_p=0.9,
)
response = outputs[0][input_ids.shape[-1]:]
answer = tokenizer.decode(response, skip_special_tokens=True)

print(answer)

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


Yes.


In [12]:
logger.info("Loading PubMedQA dataset...")
dataset = load_dataset("qiaojin/PubMedQA", "pqa_labeled", streaming=True)["train"]

INFO:PubMedQA Evaluation:Loading PubMedQA dataset...


In [19]:
MAX_SAMPLES = 5 # Limit the dataset size for quicker experimentation

# Convert the iterable dataset to a list and then slice it
dataset_list = list(dataset)
subset = dataset_list[:min(MAX_SAMPLES, len(dataset_list))]

def generate_answer(question, context, mode):
    if mode in ["rag", "raft"]:
        messages = [
            {"role": "system", "content": f"Asnwer the question only with yes, no or maybe. You don't need say anything else. Context: {context}"},
            {"role": "user", "content": f"{question}"},
        ]
    else:
        messages = [
            {"role": "system", "content": "Asnwer the question only with yes, no or maybe. You don't need say anything else."},
            {"role": "user", "content": f"{question}"},
        ]

    input_ids = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_tensors="pt"
    ).to(model.device)

    terminators = [
        tokenizer.eos_token_id,
        tokenizer.convert_tokens_to_ids("<|eot_id|>")
    ]

    outputs = model.generate(
        input_ids,
        max_new_tokens=256,
        eos_token_id=terminators,
        do_sample=True,
        temperature=0.6,
        top_p=0.9,
    )
    response = outputs[0][input_ids.shape[-1]:]
    answer = tokenizer.decode(response, skip_special_tokens=True)
    
    return answer.translate(str.maketrans('', '', string.punctuation)).split()[0]

In [15]:
NUM_DISTRACT = 3

def evaluate_model(dataset, mode):

    correct = 0
    total = 0

    for data in dataset:
        i = dataset.index(data)
        question = data["question"]
        if mode == "rag":
            context = "".join(data["context"]["contexts"])
        elif mode == "raft":
            docs = data["context"]["contexts"]
            # Add distractors to the context
            indices = list(range(0, len(dataset)))
            indices.remove(i)
            for j in random.sample(indices, NUM_DISTRACT):
                docs.append("".join(dataset[j]["context"]["contexts"]))
            # Decide whether to add oracle document
            oracle = random.uniform(0, 1) < 0.8
            if not oracle:
                docs[0] = "".join(dataset[i]["context"]["contexts"])
            random.shuffle(docs)
            context = "".join(docs)
        else:
            context = None
        final_decision = data["final_decision"]

        generated_answer = generate_answer(question, context=context, mode=mode)
        logger.info(f"Answer: {generated_answer}, Expected: {final_decision}")

        if generated_answer.lower() == final_decision.lower():
            correct += 1

        total += 1
        logger.info(f"Processed {total}/{len(dataset)} samples in {mode} mode.")

    return correct / total if total > 0 else 0.0

In [20]:
logger.info("Evaluating RAFT mode...")
raft_accuracy = evaluate_model(subset, mode="raft")


INFO:PubMedQA Evaluation:Evaluating RAFT mode...
The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
INFO:PubMedQA Evaluation:Answer: Yes, Expected: yes
INFO:PubMedQA Evaluation:Processed 1/5 samples in raft mode.
The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
INFO:PubMedQA Evaluation:Answer: Yes, Expected: no
INFO:PubMedQA Evaluation:Processed 2/5 samples in raft mode.
The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:12

In [21]:

logger.info("Evaluating RAG mode...")
rag_accuracy = evaluate_model(subset, mode="rag")


INFO:PubMedQA Evaluation:Evaluating RAG mode...
The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
INFO:PubMedQA Evaluation:Answer: Yes, Expected: yes
INFO:PubMedQA Evaluation:Processed 1/5 samples in rag mode.
The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
INFO:PubMedQA Evaluation:Answer: No, Expected: no
INFO:PubMedQA Evaluation:Processed 2/5 samples in rag mode.
The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:128001

In [22]:

logger.info("Evaluating Direct mode...")
direct_accuracy = evaluate_model(subset, mode="direct")

INFO:PubMedQA Evaluation:Evaluating Direct mode...
The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
INFO:PubMedQA Evaluation:Answer: Maybe, Expected: yes
INFO:PubMedQA Evaluation:Processed 1/5 samples in direct mode.
The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
INFO:PubMedQA Evaluation:Answer: Maybe, Expected: no
INFO:PubMedQA Evaluation:Processed 2/5 samples in direct mode.
The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_to

In [23]:
print("\nEvaluation Results:")
print(f"RAFT Accuracy: {raft_accuracy * 100:.2f}%")
print(f"RAG Accuracy: {rag_accuracy * 100:.2f}%")
print(f"Direct Accuracy: {direct_accuracy * 100:.2f}%")


Evaluation Results:
RAFT Accuracy: 80.00%
RAG Accuracy: 100.00%
Direct Accuracy: 20.00%
