# Rerank RAG Documents

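This notebook reranks the retrieved PubMed abstracts and selects the top 3 most relevant ones for each claim using DeepSeek-R1, served locally through Ollama. It reads the retrieved documents from `reports/rag_documents.csv` and writes the reranked selection to `reports/reranked_rag_documents.csv`.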

In [1]:
import re
import sys
import time

import pandas as pd
from tqdm import tqdm

sys.path.append("..")  # make the project-local `helpers` package importable
from helpers import llm

## Setup LLM Client

In [None]:
# Connection settings for the local Ollama server and the model used for reranking.
default_model = "deepseek-r1:32b"  # used for reranking
host, port = "localhost", 11434

client = llm.setup_ollama_client(host=host, port=port)

## Load Retrieved Documents

In [None]:
# One row per claim, with its retrieved PubMed abstracts in `concatenated_abstracts`.
df = pd.read_csv("reports/rag_documents.csv")

# Fail fast: the rerank loop reads these columns for every row, so a missing
# column would otherwise only surface after the first (slow) LLM call.
required_columns = {"claim", "keywords", "concatenated_abstracts"}
missing_columns = required_columns - set(df.columns)
assert not missing_columns, f"rag_documents.csv is missing columns: {sorted(missing_columns)}"

print(f"Loaded {len(df)} claims with retrieved documents")
df.head()

## Define Reranking Function

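This function asks DeepSeek-R1 to select the top 3 most relevant abstracts for a claim. It parses the abstract numbers from the model's final answer, ignoring the `<think>` reasoning block, and returns the matching PMIDs together with the selected abstract texts.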

In [None]:
def rerank_abstracts(claim, concatenated_abstracts, model=default_model, client=client, top_k=3):
    """
    Rerank abstracts using an LLM and keep the ``top_k`` most relevant ones.

    The abstracts are split on the blank-line / ``---`` / blank-line separator,
    shown to the model as a 1-based numbered list, and the model is asked to
    answer with the numbers of the most relevant abstracts. Integers in the
    answer (after any ``</think>`` reasoning block) are mapped back to abstracts.

    Args:
        claim: The medical claim.
        concatenated_abstracts: String with all abstracts concatenated using the
            separator above. NaN, empty or whitespace-only input yields no selection
            and no LLM call.
        model: LLM model to use.
        client: Ollama client.
        top_k: Number of top abstracts to select (default: 3).

    Returns:
        Tuple of (list of PMID strings, concatenated top abstracts). PMIDs are
        only returned for selected abstracts containing a ``[PMID: <digits>]``
        tag, so the list can be shorter than the number of selected abstracts.
    """
    separator = "\n\n---\n\n"

    # Handle missing or blank input without calling the LLM.
    if pd.isna(concatenated_abstracts) or not str(concatenated_abstracts).strip():
        return [], ""
    if top_k <= 0:
        return [], ""

    # Split into individual papers, dropping empty segments (e.g. from a trailing
    # separator) so the model is never shown blank numbered items.
    papers = [paper for paper in str(concatenated_abstracts).split(separator) if paper.strip()]
    if not papers:
        return [], ""

    papers_list = "".join(f"\n{i}. {paper}\n" for i, paper in enumerate(papers, 1))

    prompt = f"""You are a scientific assistant tasked with identifying the most relevant research papers for a given medical claim.

Claim: {claim}

Below are {len(papers)} abstracts from PubMed. Please analyze each abstract and select the top {top_k} most relevant ones that best address or relate to the claim above.

Abstracts:
{papers_list}

Please respond with ONLY the numbers of the top {top_k} most relevant abstracts, separated by commas (e.g., "1,5,8"). Do not provide any explanation, just the numbers."""

    response = llm.call_ollama(
        model=model,
        prompt=prompt,
        client=client,
    )

    # Guard against a missing or None "response" field.
    output = response.get("response") or ""

    # DeepSeek-R1 emits its reasoning inside <think>...</think>; only the text
    # after the closing tag is the actual answer.
    if "</think>" in output:
        output = output.split("</think>")[-1].strip()

    # Convert the 1-based numbers in the answer to 0-based paper indices.
    # Validate before truncating to top_k: "0" would otherwise become -1 and
    # silently pick the last paper, and out-of-range or repeated numbers would
    # crowd out valid ones that come later in the answer.
    selected_indices = []
    for number in re.findall(r"\d+", output):
        index = int(number) - 1
        if 0 <= index < len(papers) and index not in selected_indices:
            selected_indices.append(index)
            if len(selected_indices) == top_k:
                break

    selected_papers = [papers[i] for i in selected_indices]

    # Extract PMIDs from the selected papers; papers without a PMID tag are skipped.
    selected_pmids = []
    for paper in selected_papers:
        pmid_match = re.search(r"\[PMID: (\d+)\]", paper)
        if pmid_match:
            selected_pmids.append(pmid_match.group(1))

    concatenated_top_abstracts = separator.join(selected_papers)

    return selected_pmids, concatenated_top_abstracts

## Rerank and Store Top 3 Abstracts

In [None]:
# Parallel lists, one entry per claim; turned into a DataFrame in the next section.
results = {
    "claim": [],
    "keywords": [],
    "top3_paper_ids": [],  # Comma-separated PMIDs of the selected abstracts
    "top3_abstracts": [],  # Selected abstracts concatenated
    "num_selected": [],  # Number of PMIDs found: at most 3, fewer if fewer papers or missing PMID tags
}

for position, (_, row) in enumerate(
    tqdm(df.iterrows(), total=len(df), desc="Reranking abstracts"), start=1
):
    claim = row["claim"]
    # tqdm.write prints above the progress bar instead of breaking it.
    tqdm.write(f"\nProcessing claim {position}/{len(df)}")
    # str() so a missing (NaN) claim does not crash the preview slice.
    tqdm.write(f"Claim: {str(claim)[:100]}...")

    # Rerank abstracts to get top 3
    top_pmids, top_abstracts = rerank_abstracts(
        claim=claim,
        concatenated_abstracts=row["concatenated_abstracts"],
        model=default_model,
        client=client,
        top_k=3,
    )

    results["claim"].append(claim)
    results["keywords"].append(row["keywords"])
    results["top3_paper_ids"].append(",".join(top_pmids))
    results["top3_abstracts"].append(top_abstracts)
    results["num_selected"].append(len(top_pmids))

    tqdm.write(f"Selected top {len(top_pmids)} papers: {', '.join(top_pmids)}")

    # Small delay between requests to avoid overwhelming the Ollama server.
    time.sleep(0.5)

print("\nReranking complete!")

## Save Results to CSV

In [None]:
# Persist the reranked selection (one row per claim) for downstream steps.
output_file = "reports/reranked_rag_documents.csv"
results_df = pd.DataFrame(results)
results_df.to_csv(output_file, index=False)

print(f"Results saved to {output_file}")
print(f"Total claims processed: {len(results_df)}")

## Preview Results

In [None]:
results_df.head(n=5)  # first rows of the reranked output

## Statistics

In [None]:
# Distribution of how many PMIDs were kept per claim.
selected_counts = results_df["num_selected"]
print("Number of papers selected per claim:")
print(selected_counts.describe())
print(f"\nClaims with exactly 3 papers: {selected_counts.eq(3).sum()}")
print(f"Claims with less than 3 papers: {selected_counts.lt(3).sum()}")

## Example: View one claim's top 3 documents

In [None]:
# Inspect the reranked selection for a single claim (the first one by default).
example_idx = 0
example = results_df.iloc[example_idx]

print(f"Claim: {example['claim']}")
print(f"\nKeywords: {example['keywords']}")
print(f"\nNumber of selected papers: {example['num_selected']}")
print(f"\nTop 3 Paper IDs: {example['top3_paper_ids']}")
print(f"\nTop 3 Abstracts (first 500 chars):\n{example['top3_abstracts'][:500]}...")