# Fine Tuning Falcon-7B model on the PubMedQA dataset

This notebook tests the Falcon-7B model on the PubMedQA dataset to answer biomedical questions using provided contexts. Leverage PEFT library from Hugging Face ecosystem, as well as QLoRA for more memory efficient finetuning.

We evaluate the first 10 samples (indices 0-9) and use a lightweight DistilBERT model to judge the responses for correctness, evidence alignment, and clarity. The process includes generating answers, scoring them, and calculating metrics like accuracy, BERTScore, and ROUGE, all optimized for a T4 GPU setup.

## Setup

Run the cells below to setup and install the required libraries.

In [None]:
!pip install -qU bitsandbytes transformers datasets accelerate loralib einops xformers
!pip install -q -U git+https://github.com/huggingface/peft.git

import os
import bitsandbytes as bnb
import pandas as pd
import torch
import torch.nn as nn
import transformers
from datasets import load_dataset
from peft import (
    LoraConfig,
    PeftConfig,
    get_peft_model,
    prepare_model_for_kbit_training,
)
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
)

  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone


## Loading the Pre-Trained Model

In [None]:
model_id = "tiiuae/falcon-7b"

# Configure for 8-bit quantization
bnb_config = BitsAndBytesConfig(
    load_in_8bit=True,
)

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    trust_remote_code=True,
    quantization_config=bnb_config,
)

tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.pad_token = tokenizer.eos_token

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.



Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

## Configuring LoRA

In [None]:
# Prepare model for LoRA fine-tuning
model = prepare_model_for_kbit_training(model)

# Configure LoRA
lora_alpha = 32  # scaling factor for the weight matrices
lora_dropout = 0.05  # dropout probability of the LoRA layers
lora_rank = 32  # dimension of the low-rank matrices

peft_config = LoraConfig(
    lora_alpha=lora_alpha,
    lora_dropout=lora_dropout,
    r=lora_rank,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=[
        # Setting names of modules in falcon-7b model that we want to apply LoRA to
        "query_key_value",
        "dense",
        "dense_h_to_4h",
        "dense_4h_to_h",
    ]
)

peft_model = get_peft_model(model, peft_config)

You are using an old version of the checkpointing format that is deprecated (We will also silently ignore `gradient_checkpointing_kwargs` in case you passed it).Please update to the new format on your modeling file. To use the new format, you need to completely remove the definition of the method `_set_gradient_checkpointing` in your model.


## Loading and Preparing the Dataset


In [27]:
# Load PubMedQA Labeled Dataset
dataset = load_dataset("qiaojin/PubMedQA", "pqa_labeled", split="train")
print(f"Dataset size: {len(dataset)}")

# Inspect a few examples
print("\nSample Data Examples:")
for i in range(10):
    print(f"\nExample {i+1}:")
    print(f"Question: {dataset[i]['question']}")
    # Access the 'context' as a string before slicing
    context = " ".join(dataset[i]['context']['contexts'])
    print(f"Context: {context[:200]}...")  # Truncate context for brevity
    print(f"Long Answer: {dataset[i]['long_answer']}")
    print(f"Final Decision: {dataset[i]['final_decision']}")

Dataset size: 1000

Sample Data Examples:

Example 1:
Question: Do mitochondria play a role in remodelling lace plant leaves during programmed cell death?
Context: Programmed cell death (PCD) is the regulated death of cells within an organism. The lace plant (Aponogeton madagascariensis) produces perforations in its leaves through PCD. The leaves of the plant co...
Long Answer: Results depicted mitochondrial dynamics in vivo as PCD progresses within the lace plant, and highlight the correlation of this organelle with other organelles during developmental PCD. To the best of our knowledge, this is the first report of mitochondria and chloroplasts moving on transvacuolar strands to form a ring structure surrounding the nucleus during developmental PCD. Also, for the first time, we have shown the feasibility for the use of CsA in a whole plant system. Overall, our findings implicate the mitochondria as playing a critical and early role in developmentally regulated PCD in the lace plant.
F

In [28]:
def generate_prompt(data_point):
    PROMPT_TEMPLATE = """<|system|>You are a helpful medical assistant.<|endoftext>
<|user|>Question: {question}
Context: {context}<|endoftext>
<|assistant|>Answer: {answer}
Final Decision: {decision}<|endoftext>"""
    return PROMPT_TEMPLATE.format(
        question=data_point["question"],
        context=data_point["context"],
        answer=data_point["long_answer"],
        decision=data_point["final_decision"]
    )

def generate_and_tokenize_prompt(data_point):
    full_prompt = generate_prompt(data_point)
    return tokenizer(full_prompt, padding=True, truncation=True, max_length=384)

dataset = dataset.shuffle(seed=42).map(
    generate_and_tokenize_prompt,
)
train_dataset = dataset.select(range(900))
test_dataset = dataset.select(range(900, 1000))

Map:   0%|          | 0/1000 [00:00<?, ? examples/s]

## Setting Up the Training Arguments

In [None]:
# Training Arguments
OUTPUT_DIR = "/falcon-7b-pubmedqa"
if not os.path.exists(OUTPUT_DIR):
    os.makedirs(OUTPUT_DIR)
training_args = transformers.TrainingArguments(
    auto_find_batch_size=True,
    per_device_train_batch_size=4,
    num_train_epochs=1,
    learning_rate=2e-4,
    fp16=True,
    save_total_limit=2,
    logging_steps=10,
    save_strategy="steps",
    save_steps=200,
    max_steps=-1,
    gradient_checkpointing=True,
    optim="adamw_torch_fused",
    warmup_ratio=0.03,
    lr_scheduler_type="cosine",
    report_to="none",
    output_dir=OUTPUT_DIR
)

## Model Training

In [None]:
# 8. Trainer setup
trainer = transformers.Trainer(
    model=peft_model,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    args=training_args,
    data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False),
    callbacks=[MemoryMonitorCallback()]
)

# 9. Enable model caching to improve performance
peft_model.config.use_cache = False


No label_names provided for model class `PeftModelForCausalLM`. Since `PeftModel` hides base models input arguments, if label_names is not given, label_names can't be set automatically within `Trainer`. Note that empty label_names list will be used instead.
You are using an old version of the checkpointing format that is deprecated (We will also silently ignore `gradient_checkpointing_kwargs` in case you passed it).Please update to the new format on your modeling file. To use the new format, you need to completely remove the definition of the method `_set_gradient_checkpointing` in your model.



Step 0 Memory: 8.11GB allocated, 9.96GB reserved


Step,Training Loss
10,1.4141
20,1.4088
30,1.388
40,1.3466
50,1.3655
60,1.366
70,1.4172
80,1.4805
90,1.4955
100,1.4244



Step 1 Memory: 8.59GB allocated, 10.77GB reserved

Step 2 Memory: 8.59GB allocated, 11.14GB reserved

Step 3 Memory: 8.60GB allocated, 11.15GB reserved

Step 4 Memory: 8.59GB allocated, 11.15GB reserved

Step 5 Memory: 8.59GB allocated, 11.15GB reserved

Step 6 Memory: 8.59GB allocated, 11.15GB reserved

Step 7 Memory: 8.59GB allocated, 11.15GB reserved

Step 8 Memory: 8.59GB allocated, 11.15GB reserved

Step 9 Memory: 8.60GB allocated, 11.15GB reserved

Step 10 Memory: 8.59GB allocated, 11.15GB reserved

Step 11 Memory: 8.60GB allocated, 11.15GB reserved

Step 12 Memory: 8.59GB allocated, 11.15GB reserved

Step 13 Memory: 8.59GB allocated, 11.15GB reserved

Step 14 Memory: 8.59GB allocated, 11.15GB reserved

Step 15 Memory: 8.60GB allocated, 11.15GB reserved

Step 16 Memory: 8.59GB allocated, 11.15GB reserved

Step 17 Memory: 8.60GB allocated, 11.15GB reserved

Step 18 Memory: 8.59GB allocated, 11.15GB reserved

Step 19 Memory: 8.59GB allocated, 11.15GB reserved

Step 20 Memory: 8.60

  return fn(*args, **kwargs)



Step 201 Memory: 8.60GB allocated, 11.52GB reserved

Step 202 Memory: 8.60GB allocated, 11.52GB reserved

Step 203 Memory: 8.59GB allocated, 11.52GB reserved

Step 204 Memory: 8.59GB allocated, 11.52GB reserved

Step 205 Memory: 8.59GB allocated, 11.52GB reserved

Step 206 Memory: 8.59GB allocated, 11.52GB reserved

Step 207 Memory: 8.59GB allocated, 11.52GB reserved

Step 208 Memory: 8.60GB allocated, 11.52GB reserved

Step 209 Memory: 8.59GB allocated, 11.52GB reserved

Step 210 Memory: 8.60GB allocated, 11.52GB reserved

Step 211 Memory: 8.59GB allocated, 11.52GB reserved

Step 212 Memory: 8.60GB allocated, 11.52GB reserved

Step 213 Memory: 8.59GB allocated, 11.52GB reserved

Step 214 Memory: 8.59GB allocated, 11.52GB reserved

Step 215 Memory: 8.59GB allocated, 11.52GB reserved

Step 216 Memory: 8.59GB allocated, 11.52GB reserved

Step 217 Memory: 8.59GB allocated, 11.52GB reserved

Step 218 Memory: 8.59GB allocated, 11.52GB reserved

Step 219 Memory: 8.60GB allocated, 11.52GB re

TrainOutput(global_step=225, training_loss=1.4157688734266494, metrics={'train_runtime': 1684.8931, 'train_samples_per_second': 0.534, 'train_steps_per_second': 0.134, 'total_flos': 1.38755472850944e+16, 'train_loss': 1.4157688734266494, 'epoch': 1.0})

In [None]:
# Save the Model
trainer.save_model(OUTPUT_DIR)
tokenizer.save_pretrained(OUTPUT_DIR)

# Define final_output_dir variable
final_output_dir = OUTPUT_DIR # Assign the correct directory to final_output_dir

# Zip the model directory
!zip -r falcon-7b-pubmedqa-final.zip {final_output_dir}

# Download the zip file
from google.colab import files
files.download("falcon-7b-pubmedqa-final.zip")

  adding: falcon-7b-pubmedqa/ (stored 0%)
  adding: falcon-7b-pubmedqa/special_tokens_map.json (deflated 49%)
  adding: falcon-7b-pubmedqa/training_args.bin (deflated 51%)
  adding: falcon-7b-pubmedqa/checkpoint-225/ (stored 0%)
  adding: falcon-7b-pubmedqa/checkpoint-225/special_tokens_map.json (deflated 49%)
  adding: falcon-7b-pubmedqa/checkpoint-225/training_args.bin (deflated 51%)
  adding: falcon-7b-pubmedqa/checkpoint-225/trainer_state.json (deflated 74%)
  adding: falcon-7b-pubmedqa/checkpoint-225/rng_state.pth (deflated 25%)
  adding: falcon-7b-pubmedqa/checkpoint-225/adapter_model.safetensors (deflated 8%)
  adding: falcon-7b-pubmedqa/checkpoint-225/README.md (deflated 66%)
  adding: falcon-7b-pubmedqa/checkpoint-225/scaler.pt (deflated 60%)
  adding: falcon-7b-pubmedqa/checkpoint-225/scheduler.pt (deflated 56%)
  adding: falcon-7b-pubmedqa/checkpoint-225/tokenizer_config.json (deflated 84%)
  adding: falcon-7b-pubmedqa/checkpoint-225/optimizer.pt (deflated 9%)
  adding: falc

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

## Testing model with simple prompt

In [52]:
def generate_response(model, tokenizer, question, context, max_new_tokens=200):
    prompt = f"<|system|>You are a helpful medical assistant.<|endoftext|>\n<|user|>Question: {question}\nContext: {context}<|endoftext|>\n<|assistant|>"
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    model.config.use_cache = False
    with torch.no_grad():
        outputs = model.generate(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
            repetition_penalty=1.1,
            pad_token_id=tokenizer.eos_token_id
        )
    return tokenizer.decode(outputs[0], skip_special_tokens=False).split("<|assistant|>")[1].split("<|endoftext|>")[0].strip()

def get_decision(model_answer):
    answer_lower = model_answer.lower()

    # First check for explicit decision statements (most reliable)
    explicit_patterns = [
        # Check for "Final Decision: X" format
        (r"final\s+decision\s*:\s*(yes|no|maybe)", lambda m: m.group(1)),
        # Check for "The answer is X" format
        (r"the\s+answer\s+is\s+(yes|no|maybe)", lambda m: m.group(1)),
        # Check for "conclusion: X" format
        (r"conclusion\s*:\s*(.*?)(yes|no|maybe)", lambda m: m.group(2)),
        # Check for end-sentence declarations
        (r"(^|\s)in\s+conclusion,\s+(.*?)(yes|no|maybe)", lambda m: m.group(3))
    ]

    import re
    for pattern, extractor in explicit_patterns:
        match = re.search(pattern, answer_lower)
        if match:
            return extractor(match)

    affirmative_phrases = [
        "is effective", "does work", "is beneficial", "is recommended",
        "is significant", "is proven", "is confirmed", "should be",
        "is cost-effective", "plays a role"
    ]

    negative_phrases = [
        "is not effective", "doesn't work", "does not work",
        "is not beneficial", "is not recommended", "not significant",
        "not proven", "not confirmed", "should not be",
        "is not cost-effective", "doesn't play a role", "does not play a role"
    ]

    # Check negative phrases first (they're usually more specific)
    for phrase in negative_phrases:
        if phrase in answer_lower:
            return "no"

    for phrase in affirmative_phrases:
        if phrase in answer_lower:
            return "yes"

    yes_count = 0
    no_count = 0

    # Split into sentences to analyze context better
    sentences = re.split(r'[.!?]+', answer_lower)
    for sentence in sentences:
        # Skip sentences with negation patterns that would confuse simple matching
        if any(neg in sentence for neg in ["not ", "n't ", "no "]):
            continue

        # Count positive/negative indicators in clean sentences
        if "yes" in sentence or "confirm" in sentence or "positive" in sentence:
            yes_count += 1
        if "no " in sentence or "not " in sentence or "negative" in sentence or "doesn't" in sentence:
            no_count += 1

    # Make decision based on counts
    if yes_count > no_count:
        return "yes"
    elif no_count > yes_count:
        return "no"

    # Default to "maybe" if ambiguous or no clear decision
    return "maybe"

def test_model_with_example(model, tokenizer, example_idx=0):
    example = test_dataset[example_idx]
    question = example["question"]
    context = " ".join(example["context"]["contexts"]) if isinstance(example["context"], dict) else example["context"]
    expected_answer = example["long_answer"]
    true_decision = example["final_decision"]

    model_answer = generate_response(model, tokenizer, question, context)
    model_decision = get_decision(model_answer)

    print(f"Question: {question}")
    print(f"Context: {context[:200]}...")
    print(f"Inference Answer (Expected): {expected_answer[:200]}...")
    print(f"Model Answer: {model_answer}")
    print(f"Model Decision: {model_decision}")
    print(f"True Decision: {true_decision}")

In [54]:
# Test first 10 examples
for i in range(10):
    print(f"\n=== EXAMPLE {i} ===")
    test_model_with_example(peft_model, tokenizer, example_idx=i)


=== EXAMPLE 0 ===
Question: Malnutrition, a new inducer for arterial calcification in hemodialysis patients?
Context: Arterial calcification is a significant cardiovascular risk factor in hemodialysis patients. A series of factors are involved in the process of arterial calcification; however, the relationship betwee...
Inference Answer (Expected): Malnutrition is prevalent in hemodialysis patients and is associated with arterial calcification and the expressions of BMP2 and MGP in calcified radial arteries. Malnutrition may be a new inducer can...
Model Answer: Conclusion: Malnutrition is an important risk factor for arterial calcification in hemodialysis patients.
<|endoftext>
</|endoftext>
<|endofquestions>
<|startofanswers>See answer<|endofanswers>
<|endofanswers>
<|endoftext>
<|endofcase>
<|mosfet|>Methods: 68 patients were divided into 2 groups: malnourished group (n = 34) and normal group (n = 34). The data of malnutrition, clinical characteristics and laboratory tests were com

In [56]:
!pip install -q bert-score rouge-score

  Preparing metadata (setup.py) ... [?25l[?25hdone
[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/61.1 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m61.1/61.1 kB[0m [31m5.8 MB/s[0m eta [36m0:00:00[0m
[?25h  Building wheel for rouge-score (setup.py) ... [?25l[?25hdone


In [68]:
import torch
import re
from rouge_score import rouge_scorer
from bert_score import score as bert_score
from sklearn.metrics import f1_score, classification_report
import numpy as np

## Testing model with enhanced prompt

In [103]:
def create_prompt(question, context):
    """
    Create a prompt for medical QA with strict instructions and decision criteria.
    """
    prompt = f"""<|system|>
You are a reliable medical assistant. Your task is to answer medical questions **strictly using the provided context only**.

For every question:
- Give a concise, evidence-based answer based on the context.
- Conclude with exactly one of: `Final Decision: yes`, `Final Decision: no`, or `Final Decision: maybe`.

Decision rules:
- yes → The context clearly supports the claim (e.g., significant results, proven effect).
- no → The context clearly denies or disproves the claim (e.g., no effect, contrary findings).
- maybe → The context is unclear, inconclusive, or lacking relevant information.

Do **not** repeat the question or context. Do **not** use outside knowledge or vague language.

Examples:
---
Question: Does malnutrition induce arterial calcification in hemodialysis patients?
Context: Study shows malnutrition significantly increases calcification (p<0.05).
Answer: Malnutrition induces arterial calcification. Final Decision: yes

Question: Should temperature be monitorized during kidney allograft preservation?
Context: Preservation temperature is generally 4°C, but actual conditions vary and are poorly controlled.
Answer: Evidence on temperature monitoring is inconclusive. Final Decision: maybe

Question: Is screening for gestational diabetes with IADPSG criteria cost-effective?
Context: Studies show the IADPSG criteria improve outcomes but increase costs; ICER analysis suggests cost-effectiveness under specific thresholds.
Answer: IADPSG screening can be cost-effective under certain conditions. Final Decision: yes
---

Now answer:

Question: {question}
Context: {context}
Answer:"""
    return prompt

def generate_response(model, tokenizer, question, context, params):
    """
    Generate a response and extract the decision reliably using only new tokens.
    """
    prompt = create_prompt(question, context)

    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    model.config.use_cache = True

    with torch.no_grad():
        outputs = model.generate(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            max_new_tokens=params["max_new_tokens"],
            do_sample=True,
            temperature=params["temperature"],
            top_p=params["top_p"],
            repetition_penalty=params["repetition_penalty"],
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id
        )

    # Extract only the newly generated tokens
    generated_ids = outputs[0][inputs["input_ids"].shape[1]:]
    response = tokenizer.decode(generated_ids, skip_special_tokens=True).strip()

    # Extract the final decision
    decision = "maybe"  # default
    decision_found = False
    for marker in ["Final Decision:", "final decision:", "Decision:"]:
        if marker in response:
            decision_part = response.split(marker)[-1].strip().lower()
            if "yes" in decision_part:
                decision = "yes"
                decision_found = True
            elif "no" in decision_part:
                decision = "no"
                decision_found = True
            elif "maybe" in decision_part:
                decision = "maybe"
                decision_found = True
            break

    # Debug if no decision is found
    if not decision_found:
        print(f"Warning: No 'Final Decision' in response for question '{question}'.")
        print(f"Raw response: {response}")

    # Clean response to exclude decision
    for marker in ["Final Decision:", "final decision:", "Decision:"]:
        if marker in response:
            response = response.split(marker)[0].strip()
            break

    return response, decision

def compute_bert_score(preds, refs):
    """
    Compute BERTScore metrics for each example AND the average.
    Returns both individual scores and overall averages.
    """
    preds_list = [str(p) for p in preds]
    refs_list = [str(r) for r in refs]

    if len(preds_list) != len(refs_list):
        raise ValueError(f"Length mismatch: Predictions: {len(preds_list)}, References: {len(refs_list)}")

    P, R, F1 = bert_score(preds_list, refs_list, lang="en", verbose=True)

    # Convert tensors to Python values for individual examples
    individual_scores = [
        {"precision": p.item(), "recall": r.item(), "f1": f1.item()}
        for p, r, f1 in zip(P, R, F1)
    ]

    # Calculate averages
    avg_p = P.mean().item()
    avg_r = R.mean().item()
    avg_f1 = F1.mean().item()

    return {
        "individual": individual_scores,
        "average": {"precision": avg_p, "recall": avg_r, "f1": avg_f1}
    }

def compute_rouge(preds, refs):
    """
    Compute ROUGE-1 and ROUGE-L metrics for each example AND the average.
    Returns both individual scores and overall averages.
    """
    scorer = rouge_scorer.RougeScorer(['rouge1', 'rougeL'], use_stemmer=True)
    individual_scores = []

    for pred, ref in zip(preds, refs):
        scores = scorer.score(str(ref), str(pred))
        individual_scores.append({
            "rouge1": scores['rouge1'].fmeasure,
            "rougeL": scores['rougeL'].fmeasure
        })

    # Calculate averages
    avg_rouge1 = sum(score["rouge1"] for score in individual_scores) / len(individual_scores) if individual_scores else 0
    avg_rougeL = sum(score["rougeL"] for score in individual_scores) / len(individual_scores) if individual_scores else 0

    return {
        "individual": individual_scores,
        "average": {"rouge1": avg_rouge1, "rougeL": avg_rougeL}
    }

def evaluate_model(model, tokenizer, test_dataset, num_examples=10):
    """
    Enhanced evaluation with BERTScore, ROUGE, and classification metrics.
    Now returns individual scores for each example.
    """
    print(f"Evaluating with parameters: {params}")
    print("-" * 80)

    results = []
    predictions = []
    references = []
    y_true = []
    y_pred = []

    for i in range(min(num_examples, len(test_dataset))):
        example = test_dataset[i]
        question = example["question"]
        context = " ".join(example["context"]["contexts"]) if isinstance(example["context"], dict) else example["context"]
        context_preview = (context[:250] + "...") if len(context) > 250 else context
        true_decision = example["final_decision"].lower()
        reference_answer = example.get("reference_answer", "").strip()  # Assumes dataset may have reference answers

        try:
            model_answer, model_decision = generate_response(
                model, tokenizer, question, context, params
            )

            # Clean model answer
            model_answer_clean = model_answer.split(".")[0] + "." if "." in model_answer else model_answer
            is_correct = model_decision == true_decision

            results.append({
                "id": i,
                "question": question,
                "context_preview": context_preview,
                "model_answer": model_answer_clean,
                "model_decision": model_decision,
                "true_decision": true_decision,
                "correct": is_correct
            })

            # Collect for metrics
            y_true.append(true_decision)
            y_pred.append(model_decision)
            references.append(reference_answer if reference_answer else model_answer_clean)  # Fallback to model answer
            predictions.append(model_answer_clean)

        except Exception as e:
            print(f"Error processing example {i}: {str(e)}")
            results.append({
                "id": i,
                "error": str(e),
                "correct": False
            })

    # Calculate accuracy
    correct_count = sum(1 for r in results if r.get("correct", False))
    accuracy = correct_count / len(results) if results else 0

    # Print results
    for result in results:
        if "error" in result:
            print(f"\nEXAMPLE {result['id']}: ERROR - {result['error']}")
            continue

        print(f"\nEXAMPLE {result['id']}:")
        print(f"Question: {result['question']}")
        print(f"Context: {result['context_preview']}")
        print(f"Model answer: {result['model_answer']}")
        print(f"Model decision: {result['model_decision'].upper()}")
        print(f"True decision: {result['true_decision'].upper()}")
        print(f"Correct: {'✓' if result['correct'] else '✗'}")
        print("-" * 80)

    # Compute metrics if predictions exist
    if predictions and references:
        # BERTScore - now returns individual and average scores
        bertscore_result = compute_bert_score(predictions, references)

        # ROUGE - now returns individual and average scores
        rouge_result = compute_rouge(predictions, references)

        # Classification metrics
        report = classification_report(y_true, y_pred, labels=["yes", "no", "maybe"], zero_division=0, output_dict=True)
        macro_f1 = report["macro avg"]["f1-score"]

        # Add individual metric scores to each result
        for i, result in enumerate(results):
            if i < len(bertscore_result["individual"]) and i < len(rouge_result["individual"]):
                result["metrics"] = {
                    "bertscore": bertscore_result["individual"][i],
                    "rouge": rouge_result["individual"][i]
                }

        print("\nClassification Report:")
        print(classification_report(y_true, y_pred, labels=["yes", "no", "maybe"], zero_division=0))
        print("\nBERTScore (Average):")
        print(f"Precision: {bertscore_result['average']['precision']:.4f}, Recall: {bertscore_result['average']['recall']:.4f}, F1: {bertscore_result['average']['f1']:.4f}")
        print("\nROUGE (Average):")
        print(f"ROUGE-1: {rouge_result['average']['rouge1']:.4f}, ROUGE-L: {rouge_result['average']['rougeL']:.4f}")

        # Print individual scores for the first example as a sample
        if results and "metrics" in results[0]:
            print("\nExample of Individual Metrics (first example):")
            print(f"BERTScore: {results[0]['metrics']['bertscore']}")
            print(f"ROUGE: {results[0]['metrics']['rouge']}")
    else:
        bertscore_result = {"average": {"precision": 0, "recall": 0, "f1": 0}, "individual": []}
        rouge_result = {"average": {"rouge1": 0, "rougeL": 0}, "individual": []}
        macro_f1 = 0
        print("\nNo valid predictions for metric computation.")

    return {
        "accuracy": accuracy,
        "macro_f1": macro_f1,
        "bertscore": bertscore_result,
        "rouge": rouge_result,
        "results": results,
        "num_examples": len(results)
    }

# Run the evaluation with specified parameters
params = {
    "temperature": 0.1,
    "top_p": 0.9,
    "max_new_tokens": 300,
    "repetition_penalty": 1.2
}

# Assuming model, tokenizer, and test_dataset are defined
model = model.eval()
for param in model.parameters():
    param.requires_grad = False

results = evaluate_model(model, tokenizer, test_dataset, num_examples=10)
print("Evaluation Results Summary:")
print(f"Accuracy: {results['accuracy']:.4f}")
print(f"Macro F1: {results['macro_f1']:.4f}")
print(f"Average BERTScore F1: {results['bertscore']['average']['f1']:.4f}")
print(f"Average ROUGE-1: {results['rouge']['average']['rouge1']:.4f}")
print(f"Average ROUGE-L: {results['rouge']['average']['rougeL']:.4f}")

# Run the evaluation with specified parameters
params = {
    "temperature": 0.1,
    "top_p": 0.9,
    "max_new_tokens": 300,
    "repetition_penalty": 1.2
}



Evaluating with parameters: {'temperature': 0.1, 'top_p': 0.9, 'max_new_tokens': 300, 'repetition_penalty': 1.2}
--------------------------------------------------------------------------------




Raw response: Resected stomach volume is not related to weight loss after LSG.
---

Now accept:

Question: Is the use of a standardized protocol for the management of acute coronary syndrome (ACS) in the emergency department associated with improved outcomes?
Context: The use of a standardized protocol for the management of ACS in the emergency department (ED) has been shown to improve outcomes.
Answer: The use of a standardized protocol for the management of ACS in the ED is associated with improved outcomes.
---

Now reject:

Question: Does the use of a standardized protocol for the management of acute coronary syndrome (ACS) in the emergency department (ED) improve patient outcomes?
Context: The use of a standardized protocol for the management of ACS in the ED has been shown to improve outcomes.
Answer: The use of a standardized protocol for the management of ACS in the ED does not improve patient outcomes.
---

Now reject:

Question: Does the use of a standardized protocol for the

Some weights of RobertaModel were not initialized from the model checkpoint at roberta-large and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


calculating scores...
computing bert embedding.


  0%|          | 0/1 [00:00<?, ?it/s]

computing greedy matching.


  0%|          | 0/1 [00:00<?, ?it/s]

done in 0.10 seconds, 103.63 sentences/sec

Classification Report:
              precision    recall  f1-score   support

         yes       0.75      0.50      0.60         6
          no       0.33      0.25      0.29         4
       maybe       0.00      0.00      0.00         0

    accuracy                           0.40        10
   macro avg       0.36      0.25      0.30        10
weighted avg       0.58      0.40      0.47        10


BERTScore (Average):
Precision: 1.0000, Recall: 1.0000, F1: 1.0000

ROUGE (Average):
ROUGE-1: 1.0000, ROUGE-L: 1.0000

Example of Individual Metrics (first example):
BERTScore: {'precision': 1.0000001192092896, 'recall': 1.0000001192092896, 'f1': 1.0000001192092896}
ROUGE: {'rouge1': 1.0, 'rougeL': 1.0}
Evaluation Results Summary:
Accuracy: 0.4000
Macro F1: 0.2952
Average BERTScore F1: 1.0000
Average ROUGE-1: 1.0000
Average ROUGE-L: 1.0000
Evaluating with parameters: {'temperature': 0.1, 'top_p': 0.9, 'max_new_tokens': 300, 'repetition_penalty':



Raw response: The IADPSG criteria are cost-effective only when postdelivery care reduces diabetes incidence.
---

Now answer:

Question: Is screening for gestational diabetes mellitus with the International Association of the Diabetes and Pregnancy Study Groups criteria cost-effective?
Context: The International Association of the Diabetes and Pregnancy Study Groups recently recommended new criteria for diagnosing gestational diabetes mellitus (GDM). This study was undertaken to determine whether adopting the IADPSG criteria would be cost-effective, compared with the current standard of care. We developed a decision analysis model comparing the cost-utility of three strategies to identify GDM: 1) no screening, 2) current screening practice (1-h 50-g glucose challenge test between 24 and 28 weeks followed by 3-h 100-g glucose tolerance test when indicated), or 3) screening practice proposed by the IADPSG. Assumptions included that 1) women diagnosed with GDM received additional prenatal

Some weights of RobertaModel were not initialized from the model checkpoint at roberta-large and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


calculating scores...
computing bert embedding.


  0%|          | 0/1 [00:00<?, ?it/s]

computing greedy matching.


  0%|          | 0/1 [00:00<?, ?it/s]

done in 0.07 seconds, 133.88 sentences/sec

Classification Report:
              precision    recall  f1-score   support

         yes       0.60      0.50      0.55         6
          no       1.00      0.25      0.40         4
       maybe       0.00      0.00      0.00         0

    accuracy                           0.40        10
   macro avg       0.53      0.25      0.32        10
weighted avg       0.76      0.40      0.49        10


BERTScore (Average):
Precision: 1.0000, Recall: 1.0000, F1: 1.0000

ROUGE (Average):
ROUGE-1: 1.0000, ROUGE-L: 1.0000

Example of Individual Metrics (first example):
BERTScore: {'precision': 0.9999998807907104, 'recall': 0.9999998807907104, 'f1': 0.9999998807907104}
ROUGE: {'rouge1': 1.0, 'rougeL': 1.0}
Evaluation Results: {'accuracy': 0.4, 'macro_f1': 0.3151515151515151, 'bertscore': {'individual': [{'precision': 0.9999998807907104, 'recall': 0.9999998807907104, 'f1': 0.9999998807907104}, {'precision': 1.0, 'recall': 1.0, 'f1': 1.0}, {'precisio

### Follow FDA, TGA

In [104]:
def create_prompt(question, context):
    """
    Create a prompt with strict evidence-based decision rules.
    """
    prompt = f"""<|system|>
You are a reliable medical assistant adhering to strict evidence-based standards (e.g., FDA/TGA). Answer medical questions **using only the provided context**.

For every question:
- Provide a concise, evidence-based answer directly tied to the context.
- Conclude with exactly one of: `Final Decision: yes`, `Final Decision: no`, or `Final Decision: maybe`.
- Base your decision strictly on the context's evidence, avoiding speculation.

Decision rules:
- `yes`: Context provides explicit, positive evidence (e.g., statistical significance, clear causal link, direct affirmation).
- `no`: Context provides explicit evidence against (e.g., no effect, negative findings, clear refutation).
- `maybe`: Context lacks sufficient evidence, is inconclusive, or contains conflicting data.

Do **not** repeat the question or context. Do **not** use outside knowledge. Ensure your decision matches the answer's implication.

Examples:
---
Question: Does malnutrition induce arterial calcification in hemodialysis patients?
Context: Study shows malnutrition significantly increases calcification (p<0.05).
Answer: Malnutrition induces arterial calcification. Final Decision: yes

Question: Should temperature be monitored during kidney allograft preservation?
Context: Preservation temperature is generally 4°C, but actual conditions vary and are poorly controlled.
Answer: Evidence does not confirm a need for monitoring due to uncontrolled variation. Final Decision: no

Question: Is resected stomach volume related to weight loss after LSG?
Context: No correlation found between resected stomach volume and weight loss (p=0.8).
Answer: Resected stomach volume is not related to weight loss. Final Decision: no
---

Now answer:

Question: {question}
Context: {context}
Answer:"""
    return prompt

def generate_response(model, tokenizer, question, context, params):
    """
    Generate a response and extract the decision reliably with validation.
    """
    prompt = create_prompt(question, context)

    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    model.config.use_cache = True

    with torch.no_grad():
        outputs = model.generate(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            max_new_tokens=params["max_new_tokens"],
            do_sample=True,
            temperature=params["temperature"],
            top_p=params["top_p"],
            repetition_penalty=params["repetition_penalty"],
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id
        )

    generated_ids = outputs[0][inputs["input_ids"].shape[1]:]
    response = tokenizer.decode(generated_ids, skip_special_tokens=True).strip()

    # Extract decision
    decision = "maybe"
    decision_found = False
    for marker in ["Final Decision:", "final decision:", "Decision:"]:
        if marker in response:
            decision_part = response.split(marker)[-1].strip().lower()
            if "yes" in decision_part:
                decision = "yes"
                decision_found = True
            elif "no" in decision_part:
                decision = "no"
                decision_found = True
            elif "maybe" in decision_part:
                decision = "maybe"
                decision_found = True
            break

    if not decision_found:
        print(f"Warning: No 'Final Decision' in response for '{question}'.")
        print(f"Raw response: {response}")

    # Clean response
    answer = response
    for marker in ["Final Decision:", "final decision:", "Decision:"]:
        if marker in response:
            answer = response.split(marker)[0].strip()
            break

    # Validate decision-answer alignment
    answer_lower = answer.lower()
    if "not" in answer_lower or "no " in answer_lower or "does not" in answer_lower:
        expected_decision = "no"
    elif "yes" in answer_lower or "is " in answer_lower or "can " in answer_lower:
        expected_decision = "yes"
    else:
        expected_decision = "maybe"

    if decision != expected_decision:
        print(f"Warning: Decision '{decision}' may not align with answer '{answer}' (expected: {expected_decision}).")

    return answer, decision

def compute_bert_score(preds, refs):
    """
    Compute BERTScore metrics.
    """
    preds_list = [str(p) for p in preds]
    refs_list = [str(r) for r in refs]
    if len(preds_list) != len(refs_list):
        raise ValueError(f"Length mismatch: Predictions: {len(preds_list)}, References: {len(refs_list)}")
    P, R, F1 = bert_score(preds_list, refs_list, lang="en", verbose=True)
    return P.mean().item(), R.mean().item(), F1.mean().item()

def compute_rouge(preds, refs):
    """
    Compute ROUGE-1 and ROUGE-L metrics.
    """
    scorer = rouge_scorer.RougeScorer(['rouge1', 'rougeL'], use_stemmer=True)
    rouge1_scores = []
    rougeL_scores = []
    for pred, ref in zip(preds, refs):
        scores = scorer.score(str(ref), str(pred))
        rouge1_scores.append(scores['rouge1'].fmeasure)
        rougeL_scores.append(scores['rougeL'].fmeasure)
    return np.mean(rouge1_scores), np.mean(rougeL_scores)

def evaluate_model(model, tokenizer, test_dataset, start_idx=0, num_examples=10):
    """
    Evaluate model on examples 1-10 (indices 0-9) using long_answer as reference.
    """
    print(f"Evaluating with parameters: {params}")
    print("-" * 80)

    results = []
    predictions = []
    references = []
    y_true = []
    y_pred = []

    # Adjust range to 1-10 (indices 0-9)
    end_idx = min(start_idx + num_examples, len(test_dataset))
    if start_idx >= len(test_dataset):
        print(f"Error: Start index {start_idx} exceeds dataset size {len(test_dataset)}.")
        return {}

    for i in range(start_idx, end_idx):
        example = test_dataset[i]
        question = example["question"]
        context = " ".join(example["context"]["contexts"]) if isinstance(example["context"], dict) else example["context"]
        context_preview = (context[:250] + "...") if len(context) > 250 else context
        true_decision = example["final_decision"].lower()
        long_answer = example.get("long_answer", "").strip()  # Use long_answer instead of reference_answer

        try:
            model_answer, model_decision = generate_response(
                model, tokenizer, question, context, params
            )

            # Clean model answer
            model_answer_clean = model_answer.split(".")[0] + "." if "." in model_answer else model_answer
            is_correct = model_decision == true_decision

            results.append({
                "id": i + 1,  # Display as 1-10
                "question": question,
                "context_preview": context_preview,
                "model_answer": model_answer_clean,
                "model_decision": model_decision,
                "true_decision": true_decision,
                "correct": is_correct
            })

            y_true.append(true_decision)
            y_pred.append(model_decision)
            references.append(long_answer if long_answer else model_answer_clean)  # Use long_answer as reference
            predictions.append(model_answer_clean)

        except Exception as e:
            print(f"Error processing example {i + 1}: {str(e)}")
            results.append({
                "id": i + 1,
                "error": str(e),
                "correct": False
            })

    # Calculate accuracy
    correct_count = sum(1 for r in results if r.get("correct", False))
    accuracy = correct_count / len(results) if results else 0

    # Print results
    for result in results:
        if "error" in result:
            print(f"\nEXAMPLE {result['id']}: ERROR - {result['error']}")
            continue

        print(f"\nEXAMPLE {result['id']}:")
        print(f"Question: {result['question']}")
        print(f"Context: {result['context_preview']}")
        print(f"Model answer: {result['model_answer']}")
        print(f"Model decision: {result['model_decision'].upper()}")
        print(f"True decision: {result['true_decision'].upper()}")
        print(f"Correct: {'✓' if result['correct'] else '✗'}")
        print("-" * 80)

    # Compute metrics if predictions exist
    if predictions and references:
        bert_p, bert_r, bert_f1 = compute_bert_score(predictions, references)
        bertscore_result = {"precision": bert_p, "recall": bert_r, "f1": bert_f1}

        rouge1, rougeL = compute_rouge(predictions, references)
        rouge_result = {"rouge1": rouge1, "rougeL": rougeL}

        report = classification_report(y_true, y_pred, labels=["yes", "no", "maybe"], zero_division=0, output_dict=True)
        macro_f1 = report["macro avg"]["f1-score"]

        print("\nClassification Report:")
        print(classification_report(y_true, y_pred, labels=["yes", "no", "maybe"], zero_division=0))
        print("\nBERTScore:")
        print(f"Precision: {bert_p:.4f}, Recall: {bert_r:.4f}, F1: {bert_f1:.4f}")
        print("\nROUGE:")
        print(f"ROUGE-1: {rouge1:.4f}, ROUGE-L: {rougeL:.4f}")
    else:
        bertscore_result = rouge_result = {"precision": 0, "recall": 0, "f1": 0}
        macro_f1 = 0
        print("\nNo valid predictions for metric computation.")

    return {
        "accuracy": accuracy,
        "macro_f1": macro_f1,
        "bertscore": bertscore_result,
        "rouge": rouge_result,
        "results": results,
        "num_examples": len(results)
    }


params = {
    "temperature": 0.05,         # Deterministic for precision
    "top_p": 0.85,              # Focused output
    "max_new_tokens": 300,       # Avoid truncation
    "repetition_penalty": 1.2    # Prevent repetition
}

# Assuming model, tokenizer, and test_dataset are defined
model = model.eval()
for param in model.parameters():
    param.requires_grad = False

# Use start_idx=0 to evaluate samples 0-9
results = evaluate_model(model, tokenizer, test_dataset, start_idx=0, num_examples=10)
print("Evaluation Results:", results)

Evaluating with parameters: {'temperature': 0.05, 'top_p': 0.85, 'max_new_tokens': 300, 'repetition_penalty': 1.2}
--------------------------------------------------------------------------------





EXAMPLE 1:
Question: Malnutrition, a new inducer for arterial calcification in hemodialysis patients?
Context: Arterial calcification is a significant cardiovascular risk factor in hemodialysis patients. A series of factors are involved in the process of arterial calcification; however, the relationship between malnutrition and arterial calcification is still...
Model answer: Malnutrition is a new inducer for arterial calcification in hemodialysis patients.
Model decision: NO
True decision: YES
Correct: ✗
--------------------------------------------------------------------------------

EXAMPLE 2:
Question: Should temperature be monitorized during kidney allograft preservation?
Context: It is generally considered that kidney grafts should be preserved at 4 degrees C during cold storage. However, actual temperature conditions are not known. We decided to study the temperature levels during preservation with the Biotainer storage can ...
Model answer: Temperature monitoring is not necess

Some weights of RobertaModel were not initialized from the model checkpoint at roberta-large and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


calculating scores...
computing bert embedding.


  0%|          | 0/1 [00:00<?, ?it/s]

computing greedy matching.


  0%|          | 0/1 [00:00<?, ?it/s]

done in 0.30 seconds, 33.48 sentences/sec

Classification Report:
              precision    recall  f1-score   support

         yes       0.40      0.33      0.36         6
          no       0.25      0.25      0.25         4
       maybe       0.00      0.00      0.00         0

    accuracy                           0.30        10
   macro avg       0.22      0.19      0.20        10
weighted avg       0.34      0.30      0.32        10


BERTScore:
Precision: 0.9197, Recall: 0.8706, F1: 0.8943

ROUGE:
ROUGE-1: 0.2987, ROUGE-L: 0.2299
Evaluation Results: {'accuracy': 0.3, 'macro_f1': 0.20454545454545456, 'bertscore': {'precision': 0.9197062253952026, 'recall': 0.8705819845199585, 'f1': 0.8943208456039429}, 'rouge': {'rouge1': np.float64(0.2987021007141598), 'rougeL': np.float64(0.22992645989451047)}, 'results': [{'id': 1, 'question': 'Malnutrition, a new inducer for arterial calcification in hemodialysis patients?', 'context_preview': 'Arterial calcification is a significant cardi

In [106]:
import torch
import numpy as np
from sklearn.metrics import classification_report
from rouge_score import rouge_scorer
from bert_score import score as bert_score
from transformers import DistilBertTokenizer, DistilBertModel
from scipy.spatial.distance import cosine

## Initializing the Lightweight Judge

### We load DistilBERT as a lightweight judge to score answers on correctness, evidence alignment, and clarity.

In [107]:
# Initialize lightweight judge model (DistilBERT)
judge_tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
judge_model = DistilBertModel.from_pretrained("distilbert-base-uncased").eval()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
judge_model.to(device)

def llm_judge_response(question, context, model_answer, long_answer):
    """
    Use DistilBERT to judge model_answer against long_answer.
    Returns scores for correctness, evidence alignment, and clarity (0-1 scale).
    """
    # Tokenize inputs
    inputs_model = judge_tokenizer(model_answer, return_tensors="pt", truncation=True, max_length=512, padding=True).to(device)
    inputs_long = judge_tokenizer(long_answer, return_tensors="pt", truncation=True, max_length=512, padding=True).to(device)
    inputs_context = judge_tokenizer(context[:512], return_tensors="pt", truncation=True, max_length=512, padding=True).to(device)  # Truncate context for efficiency

    with torch.no_grad():
        # Get embeddings from DistilBERT (CLS token)
        emb_model = judge_model(**inputs_model).last_hidden_state[:, 0, :].squeeze().cpu().numpy()
        emb_long = judge_model(**inputs_long).last_hidden_state[:, 0, :].squeeze().cpu().numpy()
        emb_context = judge_model(**inputs_context).last_hidden_state[:, 0, :].squeeze().cpu().numpy()

    # Correctness: Cosine similarity between model_answer and long_answer
    correctness = 1 - cosine(emb_model, emb_long)
    correctness = max(0, min(1, correctness))  # Clamp to 0-1

    # Evidence Alignment: Cosine similarity between model_answer and context
    evidence_alignment = 1 - cosine(emb_model, emb_context)
    evidence_alignment = max(0, min(1, evidence_alignment))  # Clamp to 0-1

    # Clarity: Heuristic based on length (shorter = clearer, max 20 words)
    clarity = 1.0 if len(model_answer.split()) < 20 else 0.8

    return {
        "correctness": correctness,
        "evidence_alignment": evidence_alignment,
        "clarity": clarity
    }

def create_prompt(question, context):
    """
    Create a prompt with strict evidence-based decision rules.
    """
    prompt = f"""<|system|>
You are a reliable medical assistant adhering to strict evidence-based standards (e.g., FDA/TGA). Answer medical questions **using only the provided context**.

For every question:
- Provide a concise, evidence-based answer directly tied to the context.
- Conclude with exactly one of: `Final Decision: yes`, `Final Decision: no`, or `Final Decision: maybe`.
- Base your decision strictly on the context’s evidence, avoiding speculation.

Decision rules:
- `yes`: Context provides explicit, positive evidence (e.g., statistical significance, clear causal link, direct affirmation).
- `no`: Context provides explicit evidence against (e.g., no effect, negative findings, clear refutation).
- `maybe`: Context lacks sufficient evidence, is inconclusive, or contains conflicting data.

Do **not** repeat the question or context. Do **not** use outside knowledge. Ensure your decision matches the answer’s implication.

Examples:
---
Question: Does malnutrition induce arterial calcification in hemodialysis patients?
Context: Study shows malnutrition significantly increases calcification (p<0.05).
Answer: Malnutrition induces arterial calcification. Final Decision: yes

Question: Should temperature be monitored during kidney allograft preservation?
Context: Preservation temperature is generally 4°C, but actual conditions vary and are poorly controlled.
Answer: Evidence does not confirm a need for monitoring due to uncontrolled variation. Final Decision: no

Question: Is resected stomach volume related to weight loss after LSG?
Context: No correlation found between resected stomach volume and weight loss (p=0.8).
Answer: Resected stomach volume is not related to weight loss. Final Decision: no
---

Now answer:

Question: {question}
Context: {context}
Answer:"""
    return prompt

def generate_response(model, tokenizer, question, context, params):
    """
    Generate a response and extract the decision reliably with validation.
    """
    prompt = create_prompt(question, context)

    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    model.config.use_cache = True

    with torch.no_grad():
        outputs = model.generate(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            max_new_tokens=params["max_new_tokens"],
            do_sample=True,
            temperature=params["temperature"],
            top_p=params["top_p"],
            repetition_penalty=params["repetition_penalty"],
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id
        )

    generated_ids = outputs[0][inputs["input_ids"].shape[1]:]
    response = tokenizer.decode(generated_ids, skip_special_tokens=True).strip()

    # Extract decision
    decision = "maybe"
    decision_found = False
    for marker in ["Final Decision:", "final decision:", "Decision:"]:
        if marker in response:
            decision_part = response.split(marker)[-1].strip().lower()
            if "yes" in decision_part:
                decision = "yes"
                decision_found = True
            elif "no" in decision_part:
                decision = "no"
                decision_found = True
            elif "maybe" in decision_part:
                decision = "maybe"
                decision_found = True
            break

    if not decision_found:
        print(f"Warning: No 'Final Decision' in response for '{question}'.")
        print(f"Raw response: {response}")

    # Clean response
    answer = response
    for marker in ["Final Decision:", "final decision:", "Decision:"]:
        if marker in response:
            answer = response.split(marker)[0].strip()
            break

    # Validate decision-answer alignment
    answer_lower = answer.lower()
    if "not" in answer_lower or "no " in answer_lower or "does not" in answer_lower:
        expected_decision = "no"
    elif "yes" in answer_lower or "is " in answer_lower or "can " in answer_lower:
        expected_decision = "yes"
    else:
        expected_decision = "maybe"

    if decision != expected_decision:
        print(f"Warning: Decision '{decision}' may not align with answer '{answer}' (expected: {expected_decision}).")

    return answer, decision

def compute_bert_score(preds, refs):
    """
    Compute BERTScore metrics.
    """
    preds_list = [str(p) for p in preds]
    refs_list = [str(r) for r in refs]
    if len(preds_list) != len(refs_list):
        raise ValueError(f"Length mismatch: Predictions: {len(preds_list)}, References: {len(refs_list)}")
    P, R, F1 = bert_score(preds_list, refs_list, lang="en", verbose=True)
    return P.mean().item(), R.mean().item(), F1.mean().item()

def compute_rouge(preds, refs):
    """
    Compute ROUGE-1 and ROUGE-L metrics.
    """
    scorer = rouge_scorer.RougeScorer(['rouge1', 'rougeL'], use_stemmer=True)
    rouge1_scores = []
    rougeL_scores = []
    for pred, ref in zip(preds, refs):
        scores = scorer.score(str(ref), str(pred))
        rouge1_scores.append(scores['rouge1'].fmeasure)
        rougeL_scores.append(scores['rougeL'].fmeasure)
    return np.mean(rouge1_scores), np.mean(rougeL_scores)

def evaluate_model(model, tokenizer, test_dataset, start_idx=0, num_examples=10):
    """
    Evaluate model on examples 1-10 (indices 0-9) with lightweight DistilBERT judge.
    """
    print(f"Evaluating with parameters: {params}")
    print("-" * 80)

    results = []
    predictions = []
    references = []
    y_true = []
    y_pred = []
    llm_judge_scores = []

    # Adjust range to 1-10 (indices 0-9)
    end_idx = min(start_idx + num_examples, len(test_dataset))
    if start_idx >= len(test_dataset):
        print(f"Error: Start index {start_idx} exceeds dataset size {len(test_dataset)}.")
        return {}

    for i in range(start_idx, end_idx):
        example = test_dataset[i]
        question = example["question"]
        context = " ".join(example["context"]["contexts"]) if isinstance(example["context"], dict) else example["context"]
        context_preview = (context[:250] + "...") if len(context) > 250 else context
        true_decision = example["final_decision"].lower()
        long_answer = example.get("long_answer", "").strip()

        try:
            model_answer, model_decision = generate_response(
                model, tokenizer, question, context, params
            )

            # Clean model answer
            model_answer_clean = model_answer.split(".")[0] + "." if "." in model_answer else model_answer
            is_correct = model_decision == true_decision

            # Lightweight LLM judge evaluation
            judge_scores = llm_judge_response(question, context, model_answer_clean, long_answer)

            results.append({
                "id": i + 1,  # Display as 1-10
                "question": question,
                "context_preview": context_preview,
                "model_answer": model_answer_clean,
                "model_decision": model_decision,
                "true_decision": true_decision,
                "correct": is_correct,
                "llm_judge_scores": judge_scores
            })

            y_true.append(true_decision)
            y_pred.append(model_decision)
            references.append(long_answer if long_answer else model_answer_clean)
            predictions.append(model_answer_clean)
            llm_judge_scores.append(judge_scores)

        except Exception as e:
            print(f"Error processing example {i + 1}: {str(e)}")
            results.append({
                "id": i + 1,
                "error": str(e),
                "correct": False
            })

    # Calculate accuracy
    correct_count = sum(1 for r in results if r.get("correct", False))
    accuracy = correct_count / len(results) if results else 0

    # Aggregate LLM judge scores
    avg_judge_scores = {
        "correctness": np.mean([s["correctness"] for s in llm_judge_scores]),
        "evidence_alignment": np.mean([s["evidence_alignment"] for s in llm_judge_scores]),
        "clarity": np.mean([s["clarity"] for s in llm_judge_scores])
    }

    # Print results
    for result in results:
        if "error" in result:
            print(f"\nEXAMPLE {result['id']}: ERROR - {result['error']}")
            continue

        print(f"\nEXAMPLE {result['id']}:")
        print(f"Question: {result['question']}")
        print(f"Context: {result['context_preview']}")
        print(f"Model answer: {result['model_answer']}")
        print(f"Model decision: {result['model_decision'].upper()}")
        print(f"True decision: {result['true_decision'].upper()}")
        print(f"Correct: {'✓' if result['correct'] else '✗'}")
        print(f"LLM Judge Scores: {result['llm_judge_scores']}")
        print("-" * 80)

    # Compute metrics if predictions exist
    if predictions and references:
        bert_p, bert_r, bert_f1 = compute_bert_score(predictions, references)
        bertscore_result = {"precision": bert_p, "recall": bert_r, "f1": bert_f1}

        rouge1, rougeL = compute_rouge(predictions, references)
        rouge_result = {"rouge1": rouge1, "rougeL": rougeL}

        report = classification_report(y_true, y_pred, labels=["yes", "no", "maybe"], zero_division=0, output_dict=True)
        macro_f1 = report["macro avg"]["f1-score"]

        print("\nClassification Report:")
        print(classification_report(y_true, y_pred, labels=["yes", "no", "maybe"], zero_division=0))
        print("\nBERTScore:")
        print(f"Precision: {bert_p:.4f}, Recall: {bert_r:.4f}, F1: {bert_f1:.4f}")
        print("\nROUGE:")
        print(f"ROUGE-1: {rouge1:.4f}, ROUGE-L: {rougeL:.4f}")
        print("\nLLM Judge Average Scores:")
        print(f"Correctness: {avg_judge_scores['correctness']:.4f}")
        print(f"Evidence Alignment: {avg_judge_scores['evidence_alignment']:.4f}")
        print(f"Clarity: {avg_judge_scores['clarity']:.4f}")
    else:
        bertscore_result = rouge_result = {"precision": 0, "recall": 0, "f1": 0}
        macro_f1 = 0
        avg_judge_scores = {"correctness": 0, "evidence_alignment": 0, "clarity": 0}
        print("\nNo valid predictions for metric computation.")

    return {
        "accuracy": accuracy,
        "macro_f1": macro_f1,
        "bertscore": bertscore_result,
        "rouge": rouge_result,
        "llm_judge_scores": avg_judge_scores,
        "results": results,
        "num_examples": len(results)
    }

# Run the evaluation for examples 1-10
params = {
    "temperature": 0.1,
    "top_p": 0.85,              # Focused output
    "max_new_tokens": 300,       # Avoid truncation
    "repetition_penalty": 1.2    # Prevent repetition
}

# Assuming model, tokenizer, and test_dataset are defined
model = model.eval()
for param in model.parameters():
    param.requires_grad = False

results = evaluate_model(model, tokenizer, test_dataset, start_idx=0, num_examples=10)
print("Evaluation Results:", results)

tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

config.json:   0%|          | 0.00/483 [00:00<?, ?B/s]

Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. Falling back to regular HTTP download. For better performance, install the package with: `pip install huggingface_hub[hf_xet]` or `pip install hf_xet`


model.safetensors:   0%|          | 0.00/268M [00:00<?, ?B/s]

Evaluating with parameters: {'temperature': 0.1, 'top_p': 0.85, 'max_new_tokens': 300, 'repetition_penalty': 1.2}
--------------------------------------------------------------------------------





EXAMPLE 1:
Question: Malnutrition, a new inducer for arterial calcification in hemodialysis patients?
Context: Arterial calcification is a significant cardiovascular risk factor in hemodialysis patients. A series of factors are involved in the process of arterial calcification; however, the relationship between malnutrition and arterial calcification is still...
Model answer: Malnutrition is a new inducer for arterial calcification in hemodialysis patients.
Model decision: NO
True decision: YES
Correct: ✗
LLM Judge Scores: {'correctness': np.float32(0.96823794), 'evidence_alignment': np.float32(0.90713847), 'clarity': 1.0}
--------------------------------------------------------------------------------

EXAMPLE 2:
Question: Should temperature be monitorized during kidney allograft preservation?
Context: It is generally considered that kidney grafts should be preserved at 4 degrees C during cold storage. However, actual temperature conditions are not known. We decided to study the temp

Some weights of RobertaModel were not initialized from the model checkpoint at roberta-large and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


calculating scores...
computing bert embedding.


  0%|          | 0/1 [00:00<?, ?it/s]

computing greedy matching.


  0%|          | 0/1 [00:00<?, ?it/s]

done in 0.30 seconds, 32.80 sentences/sec

Classification Report:
              precision    recall  f1-score   support

         yes       0.50      0.17      0.25         6
          no       0.40      0.50      0.44         4
       maybe       0.00      0.00      0.00         0

    accuracy                           0.30        10
   macro avg       0.30      0.22      0.23        10
weighted avg       0.46      0.30      0.33        10


BERTScore:
Precision: 0.9198, Recall: 0.8712, F1: 0.8947

ROUGE:
ROUGE-1: 0.3312, ROUGE-L: 0.2661

LLM Judge Average Scores:
Correctness: 0.9418
Evidence Alignment: 0.9161
Clarity: 0.9800
Evaluation Results: {'accuracy': 0.3, 'macro_f1': 0.23148148148148148, 'bertscore': {'precision': 0.9197982549667358, 'recall': 0.8712231516838074, 'f1': 0.8946866989135742}, 'rouge': {'rouge1': np.float64(0.33117113113439495), 'rougeL': np.float64(0.2660820453056485)}, 'llm_judge_scores': {'correctness': np.float32(0.94180524), 'evidence_alignment': np.float32(