# Reranked RAG Experiments

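This notebook evaluates medical causal claims with several local LLMs (served via Ollama). For each claim, the model receives the top 3 reranked PubMed abstracts as retrieval-augmented context and must judge whether they support or contradict the claim.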

In [None]:
import sys

sys.path.append("..")

from helpers import llm
import pandas as pd
from tqdm import tqdm
import time
import os

## Load Reranked RAG Documents

In [None]:
# One row per claim; the experiment loop below reads its `claim`, `keywords`,
# `top3_paper_ids`, `num_selected` and `top3_abstracts` columns.
reranked_df = pd.read_csv("reports/reranked_rag_documents.csv")
num_claims = len(reranked_df)
print(f"Loaded {num_claims} claims with reranked top 3 documents")
reranked_df.head()

## Setup Ollama Client

In [None]:
# Connection settings for the Ollama server that hosts every model queried below.
host, port = "localhost", 11434
client = llm.setup_ollama_client(host=host, port=port)

## Define RAG Prompt

In [None]:
# Prompt template for claim verification against retrieved evidence.
# `{documents}` receives the reranked abstracts and `{claim}` the claim text,
# filled in with str.format() in the experiment loop. Any literal braces added
# to this template must therefore be doubled ({{ }}).
# NOTE(review): the model is only offered SUPPORTED or CONTRADICT. There is no
# label for insufficient or irrelevant evidence. Confirm this binary framing is
# intended and that downstream answer parsing expects exactly these two labels.
rag_prompt = """
You are a biomedical expert specializing in causal inference. 

Evaluate the following medical causal claim based ONLY on the provided scientific abstracts.

ABSTRACTS:
{documents}

CLAIM: "{claim}"

Carefully analyze the evidence in the abstracts. If the abstracts support the claim, respond with SUPPORTED. If they contradict the claim, respond with CONTRADICT.

Provide your reasoning, cite relevant papers by PMID, and then give your final answer.

Final Answer: [SUPPORTED or CONTRADICT]
"""

## Run Reranked RAG Experiments

In [None]:
output_file = "reports/reranked_rag_results.csv"

# Append to an existing results file. Write the CSV header only when the file
# is missing or empty; a zero-byte file would otherwise get no header at all.
file_exists = os.path.isfile(output_file) and os.path.getsize(output_file) > 0

# Rows are written once per model, after all of its claims are processed. A
# model already present in the output file therefore has a full set of results,
# assuming reranked_df has not changed since that run. Skipping those models
# lets this cell resume after an interruption instead of appending duplicate
# rows, which would inflate the per-model counts in the summary cells.
# Set to False to deliberately run every model again (e.g. repeated trials).
skip_completed_models = True
completed_models = set()
if file_exists and skip_completed_models:
    completed_models = set(pd.read_csv(output_file, usecols=["model"])["model"])

models = [
    "deepseek-r1:32b",
    "mistral:7b",
    "llama3.1:8b",
    "qwen3:30b",
    "qwen3:8b",
    "llama3.1:70b",
]

for model in models:
    if model in completed_models:
        print(f"Skipping {model}: results already in {output_file}")
        continue

    # Column-oriented buffer for this model's rows; flushed to CSV once the model is done.
    result = {
        "model": [],
        "method": [],
        "claim": [],
        "keywords": [],
        "top3_paper_ids": [],
        "num_selected": [],
        "top3_abstracts": [],
        "answer": [],
    }

    for _, row in tqdm(reranked_df.iterrows(), total=len(reranked_df), desc=f"Reranked RAG - {model}"):
        claim = row["claim"]
        top3_abstracts = row["top3_abstracts"]
        keywords = row["keywords"]
        top3_paper_ids = row["top3_paper_ids"]
        num_selected = row["num_selected"]

        # Call LLM with RAG prompt using top 3 reranked abstracts
        response = llm.call_ollama(
            model=model,
            prompt=rag_prompt.format(claim=claim, documents=top3_abstracts),
            client=client,
        )
        # "NAN" marks a response dict without a "response" key.
        output = response.get("response", "NAN")

        result["model"].append(model)
        result["method"].append("reranked_rag")
        result["claim"].append(claim)
        result["keywords"].append(keywords)
        result["top3_paper_ids"].append(top3_paper_ids)
        result["num_selected"].append(num_selected)
        result["top3_abstracts"].append(top3_abstracts)
        result["answer"].append(output)

    # Persist after each model so a crash later on does not lose finished models.
    result_df = pd.DataFrame(result)
    result_df.to_csv(
        output_file,
        mode="a" if file_exists else "w",
        header=not file_exists,
        index=False,
    )
    file_exists = True  # After first write, file exists

    print(f"Completed {model}")
    time.sleep(2)

print("\nAll experiments complete!")

## Load and Display Results

In [None]:
# Re-read the full results file (all models written so far) and show its last rows.
final_results = pd.read_csv(output_file)
n_saved = len(final_results)
print(f"Total results saved: {n_saved}")
final_results.tail()

## Summary Statistics

In [None]:
# Row count per model, then the mean of `num_selected` per model.
print("Results by model:")
per_model = final_results.groupby("model")
print(per_model.size())
print("\nAverage number of papers used:")
print(per_model["num_selected"].mean())

## Example: View One Result

In [None]:
# Print every stored field of one result row (the first by default).
example_idx = 0
example = final_results.iloc[example_idx]
print(f"Model: {example['model']}")
print(f"\nClaim: {example['claim']}")
print(f"\nKeywords: {example['keywords']}")
print(f"\nNumber of selected papers: {example['num_selected']}")
print(f"\nTop 3 Paper IDs: {example['top3_paper_ids']}")
print(f"\nAnswer:\n{example['answer']}")

## Compare with Original RAG Results

In [None]:
# Compare the average number of papers per model between the original RAG run
# (`num_papers` column) and this reranked run (`num_selected` column).
original_rag_results = pd.read_csv("reports/rag_results.csv")

print("Original RAG - Average number of papers used:")
original_avg_papers = original_rag_results.groupby("model")["num_papers"].mean()
print(original_avg_papers)
print("\nReranked RAG - Average number of papers used:")
reranked_avg_papers = final_results.groupby("model")["num_selected"].mean()
print(reranked_avg_papers)