# Store RAG Documents

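This notebook uses an LLM to generate search keywords for each medical claim, retrieves matching PubMed papers, and stores their abstracts as context for retrieval-augmented generation (RAG).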

In [1]:
import re
import sys
import time

import pandas as pd
from tqdm import tqdm

# Make the project-local `helpers` package (one directory up) importable.
sys.path.append("..")

from helpers import llm
from helpers import pubmed

## Set Up the LLM Client and PubMed API

In [2]:
# LLM backend: an Ollama server whose model is used to generate search keywords.
host, port = "localhost", 11434
default_model = "deepseek-r1:32b"

client = llm.setup_ollama_client(host=host, port=port)

# PubMed access: no key is passed here, so the helper resolves it itself
# (presumably from an environment variable or config file — TODO confirm).
pubmed.set_api_key()

## Load Claims from CSV

In [3]:
# One row per medical causal claim (SciFact-derived dataset); first column is the saved index.
df = pd.read_csv(
    "../dataloader/scifact_medical_causal_claims.csv",
    index_col=0,
)
print(f"Loaded {len(df)} claims")
df.head()

Loaded 200 claims


Unnamed: 0_level_0,id,claim,evidence_doc_id,evidence_label,evidence_sentences,cited_doc_ids,causal_result_raw,is_medical_causal
Unnamed: 0,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1
7,12,40mg/day dosage of folic acid and 2mg/day dosa...,33409100,SUPPORT,[8],[33409100],Yes,True
24,30,A breast cancer patient's capacity to metaboli...,24341590,SUPPORT,[10],[24341590],Yes,True
29,34,A deficiency of folate increases blood levels ...,11705328,SUPPORT,[4],[11705328],Yes,True
32,39,A diminished ovarian reserve does not solely i...,13497630,SUPPORT,[7],[13497630],Yes,True
42,41,A high microerythrocyte count protects against...,18174210,SUPPORT,[1 9],[18174210],Yes,True


## Define Keyword Generation and Retrieval Functions

This approach:
1. Uses DeepSeek-R1 to generate 4 keywords for each claim
2. First searches for papers using AND logic with the keywords
3. If fewer than 10 papers are found with AND, tops up the results with additional papers from an OR search over the same keywords (skipping duplicates)
4. Returns up to 10 papers total

In [None]:
def _parse_keyword_list(text, n_keywords):
    """Split an LLM's free-text keyword answer into at most `n_keywords` clean terms.

    Accepts comma-, semicolon- or newline-separated answers and strips common
    decorations: list bullets ("- ", "* "), list numbering ("1.", "2)"),
    markdown bold ("**"), a leading "Keywords:" label, surrounding quotes and
    trailing periods. Empty items are dropped.
    """
    keywords = []
    for chunk in re.split(r"[,;\n]", text):
        kw = chunk.strip()
        # Leading list markers; "\d+[.)]" requires the dot/paren so terms like "5-FU" survive.
        kw = re.sub(r"^(?:[-*\u2022]\s+|\d+[.)]\s*)", "", kw)
        kw = kw.replace("**", "")
        kw = re.sub(r"^keywords?\s*:\s*", "", kw, flags=re.IGNORECASE)
        kw = kw.strip().strip("\"'`").rstrip(".").strip()
        if kw:
            keywords.append(kw)
    return keywords[:n_keywords]


def get_keywords(claim, n_keywords=4, model=default_model, client=client):
    """
    Generate PubMed search keywords for a claim using an LLM.

    Note: the `model` and `client` defaults are bound when this function is
    defined, so re-assigning the globals later does not change them.

    Args:
        claim: The medical claim to generate keywords for
        n_keywords: Maximum number of keywords to keep (default: 4)
        model: LLM model to use
        client: Ollama client

    Returns:
        String with keywords separated by ' AND ' (an empty string if no
        keyword could be parsed from the LLM output).
    """
    response = llm.call_ollama(
        model=model,
        prompt=f"Suggest me a set of keywords to search for finding scientific articles about the following claim: {claim}. Give just a simple list of {n_keywords} keywords, separated by commas with no further explanation.",
        client=client,
    )

    output = response.get("response", "")

    # DeepSeek-R1 emits its reasoning before a </think> tag; keep only the final answer.
    if "</think>" in output:
        output = output.split("</think>")[-1].strip()

    # Models do not always follow "comma-separated" literally (numbered lists,
    # labels, trailing periods), so normalise before building the query.
    keywords = _parse_keyword_list(output, n_keywords)

    # Join with AND for PubMed search
    return " AND ".join(keywords)


def retrieve_papers_with_fallback(keywords, top_k=10):
    """
    Retrieve papers from PubMed with an AND-then-OR augmentation strategy.

    First searches with AND logic. If fewer than `top_k` papers are found,
    the AND results are topped up with papers from an OR search over the same
    keywords. A PMID already collected (from either search) is never added twice.

    Args:
        keywords: String with keywords separated by ' AND '
            (e.g., "diabetes AND insulin AND glucose")
        top_k: Target number of papers to retrieve (default: 10)

    Returns:
        List of paper objects (as returned by `pubmed.get_papers`),
        at most `top_k` long; AND hits come first.
    """
    # First attempt: Use AND logic
    papers_and = pubmed.get_papers(keywords, top_k=top_k)

    # If we got enough papers with AND, return them
    if len(papers_and) >= top_k:
        return papers_and[:top_k]

    print(f"  Found only {len(papers_and)} papers with AND logic")

    # Convert 'AND' to 'OR' for a broader search
    keywords_or = keywords.replace(" AND ", " OR ")
    if keywords_or == keywords:
        # Single keyword (no AND to relax): the OR query would be identical,
        # so a second request cannot return anything new.
        print("  No AND terms to relax; skipping OR search")
        return papers_and

    print(f"  Augmenting with OR search: {keywords_or}")

    # Over-request: some OR hits are expected to duplicate the AND hits.
    papers_or = pubmed.get_papers(keywords_or, top_k=top_k + len(papers_and))

    # Track every PMID taken so far, including ones added from the OR results,
    # so duplicates inside the OR list cannot consume slots either.
    seen_pmids = {paper.pmid for paper in papers_and}
    augmented_papers = list(papers_and)
    for paper in papers_or:
        if len(augmented_papers) >= top_k:
            break
        if paper.pmid in seen_pmids:
            continue
        seen_pmids.add(paper.pmid)
        augmented_papers.append(paper)

    print(f"  Total papers after augmentation: {len(augmented_papers)}")

    return augmented_papers[:top_k]

## Retrieve and Store Papers

In [None]:
results = {
    "claim": [],
    "keywords": [],
    "paper_ids": [],  # Comma-separated PMIDs (one string per claim)
    "concatenated_abstracts": [],  # All abstracts joined, each prefixed with its PMID
    "num_papers": [],  # Number of papers actually retrieved (may be below top_k)
}

claims = df["claim"].tolist()

for idx, claim in enumerate(tqdm(claims, desc="Processing claims")):
    # tqdm.write prints above the progress bar; a bare print() inside a tqdm loop garbles it.
    tqdm.write(f"\nProcessing claim {idx+1}/{len(claims)}")
    # Only append an ellipsis when the claim was actually truncated.
    claim_preview = claim if len(claim) <= 100 else f"{claim[:100]}..."
    tqdm.write(f"Claim: {claim_preview}")

    # Generate keywords with the LLM (already joined with ' AND ')
    keywords = get_keywords(claim, n_keywords=4, model=default_model, client=client)
    tqdm.write(f"Generated keywords: {keywords}")

    # Retrieve papers with fallback strategy (AND first, then OR if needed)
    papers = retrieve_papers_with_fallback(keywords, top_k=10)

    # str() keeps the join below working whether PMIDs come back as str or int;
    # a missing abstract becomes "" instead of the literal text "None".
    paper_ids = [str(paper.pmid) for paper in papers]
    abstracts = [paper.abstract or "" for paper in papers]

    # Concatenate all abstracts with separators
    concatenated_abstracts = "\n\n---\n\n".join(
        f"[PMID: {pmid}] {abstract}" for pmid, abstract in zip(paper_ids, abstracts)
    )

    # Store results
    results["claim"].append(claim)
    results["keywords"].append(keywords)
    results["paper_ids"].append(",".join(paper_ids))
    results["concatenated_abstracts"].append(concatenated_abstracts)
    results["num_papers"].append(len(papers))

    tqdm.write(f"Retrieved {len(papers)} papers")

    # Small delay to avoid overwhelming the API
    time.sleep(0.5)

print("\nRetrieval complete!")

## Save Results to CSV

In [None]:
# One row per claim; the DataFrame gets a default RangeIndex, so it is not written out.
output_file = "rag_documents.csv"
results_df = pd.DataFrame(results)
results_df.to_csv(output_file, index=False)
print(
    f"Results saved to {output_file}",
    f"Total claims processed: {len(results_df)}",
    sep="\n",
)

## Preview Results

In [None]:
# First five stored rows (claim, keywords, PMIDs, abstracts, paper count).
results_df.head(n=5)

## Statistics

In [None]:
# Retrieval target per claim; keep in sync with `top_k` in the retrieval loop.
target_papers = 10

print("Number of papers retrieved per claim:")
print(results_df["num_papers"].describe())

n_short = (results_df["num_papers"] < target_papers).sum()
n_full = (results_df["num_papers"] == target_papers).sum()
print(f"\nClaims with fewer than {target_papers} papers: {n_short}")
print(f"Claims with exactly {target_papers} papers: {n_full}")

## Example: View one claim's retrieved documents

In [None]:
# Inspect what was stored for a single claim (row 0 by default).
example_idx = 0
example_row = results_df.iloc[example_idx]

print(f"Claim: {example_row['claim']}")
print(f"\nKeywords: {example_row['keywords']}")
print(f"\nNumber of papers: {example_row['num_papers']}")
print(f"\nPaper IDs: {example_row['paper_ids']}")
abstract_preview = example_row['concatenated_abstracts'][:500]
print(f"\nConcatenated Abstracts (first 500 chars):\n{abstract_preview}...")