In [2]:
import ir_datasets
import pandas as pd

# Load TREC Clinical Trials 2021 benchmark set (topics + qrels)
ds = ir_datasets.load("clinicaltrials/2021/trec-ct-2021")

# --- Load queries (patient topics) ---
queries = []
for q in ds.queries_iter():
    # q: (query_id, text)
    queries.append({"topic_id": q.query_id, "text": q.text})
queries_df = pd.DataFrame(queries)

# --- Load qrels (relevance judgments) ---
qrels = []
for qr in ds.qrels_iter():
    # qr: (query_id, doc_id, relevance, iteration)
    qrels.append({"topic_id": qr.query_id, "doc_id": qr.doc_id, "relevance": int(qr.relevance)})
qrels_df = pd.DataFrame(qrels)

print("Loaded dataset:", ds)
print("Num topics:", len(queries_df))
print("Num qrels rows:", len(qrels_df))
print("Topics with at least 1 judged relevant doc:",
      qrels_df[qrels_df["relevance"] > 0]["topic_id"].nunique())

display(queries_df.head(3))
display(qrels_df.head(5))

Loaded dataset: Dataset(id='clinicaltrials/2021/trec-ct-2021', provides=['docs', 'queries', 'qrels'])
Num topics: 75
Num qrels rows: 35832
Topics with at least 1 judged relevant doc: 75


Unnamed: 0,topic_id,text
0,1,\nPatient is a 45-year-old man with a history ...
1,2,"\n48 M with a h/o HTN hyperlipidemia, bicuspid..."
2,3,\nA 32 yo woman who presents following a sever...


Unnamed: 0,topic_id,doc_id,relevance
0,1,NCT00002569,1
1,1,NCT00002620,1
2,1,NCT00002806,0
3,1,NCT00002814,2
4,1,NCT00003022,1


In [3]:
from tqdm import tqdm

# Load the full ClinicalTrials corpus used in TREC CT 2021
docs = []
for doc in tqdm(ds.docs_iter(), desc="Loading trial documents"):
    # doc: (doc_id, title, summary, detailed_description, ...)
    docs.append({
        "doc_id": doc.doc_id,
        "title": doc.title,
        "summary": doc.summary,
        "description": doc.detailed_description
    })

docs_df = pd.DataFrame(docs)

print("Trial documents loaded")
print("Total number of trials:", len(docs_df))

display(docs_df.head(3))

Loading trial documents: 375580it [00:02, 147552.78it/s]


Trial documents loaded
Total number of trials: 375580


Unnamed: 0,doc_id,title,summary,description
0,NCT00000102,Congenital Adrenal Hyperplasia: Calcium Channe...,\n \n This study will test the ability...,\n \n This protocol is designed to ass...
1,NCT00000104,Does Lead Burden Alter Neuropsychological Deve...,\n \n Inner city children are at an in...,
2,NCT00000105,Vaccination With Tetanus and KLH to Assess Imm...,\n \n The purpose of this study is to ...,\n \n Patients will receive each vacci...


In [4]:
import re
from rank_bm25 import BM25Okapi
from tqdm import tqdm

def normalize_text(s: str) -> str:
    if s is None:
        return ""
    s = s.lower()
    s = re.sub(r"\s+", " ", s).strip()
    return s

def tokenize(s: str):
    # simple tokenizer: words + numbers
    return re.findall(r"[a-z0-9]+", normalize_text(s))

# Build a compact text per doc (keeps memory reasonable)
doc_ids = docs_df["doc_id"].tolist()
doc_texts = (docs_df["title"].fillna("") + " " +
             docs_df["summary"].fillna("") + " " +
             docs_df["description"].fillna("")).tolist()

# Tokenize corpus for BM25
tokenized_corpus = [tokenize(t) for t in tqdm(doc_texts, desc="Tokenizing corpus")]

bm25 = BM25Okapi(tokenized_corpus)

print("BM25 index ready")
print("Num docs indexed:", len(doc_ids))
print("Example tokens:", tokenized_corpus[0][:30])

Tokenizing corpus: 100%|██████████| 375580/375580 [00:36<00:00, 10224.50it/s]


BM25 index ready
Num docs indexed: 375580
Example tokens: ['congenital', 'adrenal', 'hyperplasia', 'calcium', 'channels', 'as', 'therapeutic', 'targets', 'this', 'study', 'will', 'test', 'the', 'ability', 'of', 'extended', 'release', 'nifedipine', 'procardia', 'xl', 'a', 'blood', 'pressure', 'medication', 'to', 'permit', 'a', 'decrease', 'in', 'the']


In [5]:
import numpy as np
from tqdm import tqdm

K = 100

runs = []  # list of dicts: topic_id, doc_id, rank, score

for row in tqdm(queries_df.itertuples(index=False), total=len(queries_df), desc=f"Retrieving top {K}"):
    topic_id = str(row.topic_id)
    query_text = row.text

    q_tokens = tokenize(query_text)
    scores = bm25.get_scores(q_tokens)  # numpy array aligned with doc_ids

    top_idx = np.argpartition(scores, -K)[-K:]          # unsorted top-K indices
    top_idx = top_idx[np.argsort(scores[top_idx])[::-1]]  # sort descending

    for rank, idx in enumerate(top_idx, start=1):
        runs.append({
            "topic_id": topic_id,
            "doc_id": doc_ids[idx],
            "rank": rank,
            "score": float(scores[idx])
        })

run_df = pd.DataFrame(runs)

print("Run generated")
print("Rows in run:", len(run_df))
display(run_df.head(10))

Retrieving top 100: 100%|██████████| 75/75 [16:10<00:00, 12.94s/it]

Run generated
Rows in run: 7500





Unnamed: 0,topic_id,doc_id,rank,score
0,1,NCT03528642,1,293.930154
1,1,NCT00003176,2,290.582836
2,1,NCT02942264,3,288.751092
3,1,NCT01466686,4,287.16188
4,1,NCT03633552,5,286.098323
5,1,NCT00841555,6,286.02794
6,1,NCT00968240,7,279.294295
7,1,NCT00089427,8,277.892832
8,1,NCT00734682,9,276.883413
9,1,NCT00003537,10,276.22625


In [6]:
from pathlib import Path

# ---------- project-aligned paths ----------
BASE_DIR = Path(r"C:\Ajesh_Drive\PersonalProjects\ClinicalTrialNexus\models\user")
BASE_DIR.mkdir(parents=True, exist_ok=True)

run_path = BASE_DIR / "bm25_python_trec2021_top100.run"
metrics_path = BASE_DIR / "bm25_python_trec2021_metrics.csv"

# ---------- save TREC run file ----------
runname = "bm25_python"

with open(run_path, "w", encoding="utf-8") as f:
    for r in run_df.itertuples(index=False):
        f.write(f"{r.topic_id} Q0 {r.doc_id} {r.rank} {r.score:.6f} {runname}\n")

print("Saved run file to:")
print(run_path)

# ---------- evaluation (P@K, R@K) ----------
def precision_recall_at_k(run_df: pd.DataFrame, qrels_df: pd.DataFrame, k: int):
    rel = (
        qrels_df[qrels_df["relevance"] > 0]
        .groupby("topic_id")["doc_id"]
        .apply(set)
        .to_dict()
    )

    topk = (
        run_df.sort_values(["topic_id", "rank"])
        .groupby("topic_id")
        .head(k)
        .groupby("topic_id")["doc_id"]
        .apply(list)
        .to_dict()
    )

    topics = sorted(set(rel.keys()) & set(topk.keys()))
    precisions, recalls = [], []

    for t in topics:
        retrieved = topk.get(t, [])
        relevant = rel.get(t, set())

        hits = sum(1 for d in retrieved if d in relevant)
        precisions.append(hits / k)
        recalls.append(hits / len(relevant))

    return sum(precisions) / len(precisions), sum(recalls) / len(recalls)

p10, r10 = precision_recall_at_k(run_df, qrels_df, k=10)
p100, r100 = precision_recall_at_k(run_df, qrels_df, k=100)

metrics_df = pd.DataFrame([
    {"metric": "Precision@10", "value": p10},
    {"metric": "Recall@10", "value": r10},
    {"metric": "Precision@100", "value": p100},
    {"metric": "Recall@100", "value": r100},
])

metrics_df.to_csv(metrics_path, index=False)

print("\nEvaluation (BM25 python baseline):")
print(f"K=10   Precision@10:  {p10:.4f}   Recall@10:  {r10:.4f}")
print(f"K=100  Precision@100: {p100:.4f}   Recall@100: {r100:.4f}")

print("\nSaved metrics to:")
print(metrics_path)

display(metrics_df)

Saved run file to:
C:\Ajesh_Drive\PersonalProjects\ClinicalTrialNexus\models\user\bm25_python_trec2021_top100.run

Evaluation (BM25 python baseline):
K=10   Precision@10:  0.3000   Recall@10:  0.0279
K=100  Precision@100: 0.1485   Recall@100: 0.1040

Saved metrics to:
C:\Ajesh_Drive\PersonalProjects\ClinicalTrialNexus\models\user\bm25_python_trec2021_metrics.csv


Unnamed: 0,metric,value
0,Precision@10,0.3
1,Recall@10,0.027917
2,Precision@100,0.148533
3,Recall@100,0.104029


In [9]:
from tqdm import tqdm

# Rebuild docs_df including eligibility field (this is where criteria live)
docs = []
for doc in tqdm(ds.docs_iter(), desc="Reloading trial docs with eligibility"):
    docs.append({
        "doc_id": doc.doc_id,
        "title": doc.title,
        "summary": doc.summary,
        "description": doc.detailed_description,
        "eligibility": doc.eligibility
    })

docs_df = pd.DataFrame(docs)

print("✅ Rebuilt docs_df with eligibility field")
print("Total docs:", len(docs_df))
print("Eligibility non-empty:", (docs_df["eligibility"].fillna("").str.strip() != "").sum())

# Rebuild fast lookup to include eligibility
docs_lookup = (
    docs_df.set_index("doc_id")[["title", "summary", "description", "eligibility"]]
    .fillna("")
    .to_dict(orient="index")
)

def get_trial_text(doc_id: str) -> str:
    d = docs_lookup.get(doc_id, {"title": "", "summary": "", "description": "", "eligibility": ""})
    return (d["title"] + "\n" + d["summary"] + "\n" + d["description"] + "\n" + d["eligibility"]).strip()

# Re-run sanity check on the same 3 docs from topic 1
sample_doc_ids = run_df[run_df["topic_id"] == "1"].sort_values("rank").head(3)["doc_id"].tolist()

for did in sample_doc_ids:
    t = get_trial_text(did)
    elig = extract_eligibility(t)
    print("\n==============================")
    print("DOC:", did)
    print("Eligibility text length:", len(docs_lookup[did]["eligibility"]))
    print("Inclusion found:", len(elig["inclusion"]))
    print("Exclusion found:", len(elig["exclusion"]))
    print("Inclusion sample:", elig["inclusion"][:3])
    print("Exclusion sample:", elig["exclusion"][:3])

Reloading trial docs with eligibility: 375580it [00:02, 140614.14it/s]


✅ Rebuilt docs_df with eligibility field
Total docs: 375580
Eligibility non-empty: 374648

DOC: NCT03528642
Eligibility text length: 7218
Inclusion found: 58
Exclusion found: 30
Inclusion sample: ['Patients must have histopathologic or molecular confirmation of either IDH-mutant DA', 'or IDH-mutant AA. Acceptable IDH mutations for study eligibility include any IDH1', 'mutation at codon 132 or any IDH2 mutation at codon 172.']
Exclusion sample: ['Patients must not have received prior chemotherapy to treat the glioma.', 'Patients who are receiving any other investigational agents.', 'History of allergic reactions attributed to compounds of similar chemical or biologic']

DOC: NCT00003176
Eligibility text length: 3027
Inclusion found: 0
Exclusion found: 0
Inclusion sample: []
Exclusion sample: []

DOC: NCT02942264
Eligibility text length: 9312
Inclusion found: 99
Exclusion found: 6
Inclusion sample: ['of prior disease relapses', 'Patients must have pathologic diagnosis of anaplastic astro

In [10]:
import json
import requests
from textwrap import shorten

OLLAMA_URL = "http://localhost:11434/api/generate"
OLLAMA_MODEL = "llama3.1:8b"

def ollama_generate(prompt: str, model: str = OLLAMA_MODEL, temperature: float = 0.0) -> str:
    payload = {
        "model": model,
        "prompt": prompt,
        "stream": False,
        "options": {
            "temperature": temperature
        }
    }
    r = requests.post(OLLAMA_URL, json=payload, timeout=300)
    r.raise_for_status()
    return r.json()["response"]

def build_matching_prompt(patient_text: str, inclusion: list, exclusion: list, trial_id: str) -> str:
    # Keep prompt compact but structured
    inc = "\n".join([f"- {c}" for c in inclusion[:60]])  # cap to avoid huge prompts
    exc = "\n".join([f"- {c}" for c in exclusion[:60]])

    return f"""
You are a clinical trial eligibility matching assistant.
Given a patient description and a trial's inclusion/exclusion criteria, determine eligibility evidence.

Return ONLY valid JSON with this exact schema:
{{
  "trial_id": "{trial_id}",
  "overall_assessment": "likely_eligible" | "likely_ineligible" | "uncertain",
  "inclusion": [
    {{
      "criterion": "...",
      "label": "met" | "not_met" | "unknown",
      "evidence": "short quote or phrase from patient text that supports your label, or 'none'"
    }}
  ],
  "exclusion": [
    {{
      "criterion": "...",
      "label": "triggers" | "does_not_trigger" | "unknown",
      "evidence": "short quote or phrase from patient text that supports your label, or 'none'"
    }}
  ],
  "notes": "1-2 sentences about key unknowns"
}}

Rules:
- Use ONLY the patient text. Do NOT assume missing facts.
- If patient text doesn't mention something, label it unknown.
- Be conservative: if any exclusion likely triggers, overall_assessment should be likely_ineligible.
- Include at most 12 inclusion items and 8 exclusion items (pick the most important/decisive ones).

PATIENT:
{patient_text}

TRIAL INCLUSION CRITERIA (subset):
{inc if inc.strip() else "- (none found)"}

TRIAL EXCLUSION CRITERIA (subset):
{exc if exc.strip() else "- (none found)"}
""".strip()

def safe_json_load(s: str):
    """
    Tries hard to parse JSON even if model adds extra text.
    """
    s = s.strip()
    # If it already starts with { ... }, try direct
    try:
        return json.loads(s)
    except Exception:
        pass

    # Extract the first {...} block
    start = s.find("{")
    end = s.rfind("}")
    if start != -1 and end != -1 and end > start:
        try:
            return json.loads(s[start:end+1])
        except Exception:
            pass

    raise ValueError("Could not parse JSON from model output.")

# ---- pick one topic and one trial to test end-to-end ----
topic_id = "1"
trial_id = "NCT03528642"

patient_text = queries_df.loc[queries_df["topic_id"].astype(str) == topic_id, "text"].iloc[0]
trial_text = get_trial_text(trial_id)
elig = extract_eligibility(trial_text)

prompt = build_matching_prompt(
    patient_text=patient_text,
    inclusion=elig["inclusion"],
    exclusion=elig["exclusion"],
    trial_id=trial_id
)

print("Patient preview:", shorten(patient_text.replace("\n", " "), width=220, placeholder="..."))
print("Trial:", trial_id, "| inclusion:", len(elig["inclusion"]), "| exclusion:", len(elig["exclusion"]))
print("\n--- Sending to Ollama ---")

raw = ollama_generate(prompt, temperature=0.0)
print("\n--- Raw model output (first 800 chars) ---")
print(raw[:800])

result = safe_json_load(raw)
print("\nParsed JSON keys:", list(result.keys()))
print("\nOverall assessment:", result.get("overall_assessment"))
print("\nInclusion items:", len(result.get("inclusion", [])))
print("Exclusion items:", len(result.get("exclusion", [])))

# show the parsed JSON (pretty)
print("\n--- Parsed JSON ---")
print(json.dumps(result, indent=2)[:3000])

Patient preview: Patient is a 45-year-old man with a history of anaplastic astrocytoma of the spine complicated by severe lower extremity weakness and urinary retention s/p Foley catheter, high-dose steroids, hypertension, and chronic...
Trial: NCT03528642 | inclusion: 58 | exclusion: 30

--- Sending to Ollama ---

--- Raw model output (first 800 chars) ---
Here is the eligibility assessment:

```json
{
  "trial_id": "NCT03528642",
  "overall_assessment": "likely_eligible",
  "inclusion": [
    {
      "criterion": "histopathologic or molecular confirmation of either IDH-mutant DA or IDH-mutant AA",
      "label": "met",
      "evidence": "The tumor is located in the T-L spine, unresectable anaplastic astrocytoma s/p radiation."
    },
    {
      "criterion": "Age >= 16 years",
      "label": "met",
      "evidence": "Patient is a 45-year-old man"
    },
    {
      "criterion": "Eastern Cooperative Oncology Group (ECOG) performance status =< 1",
      "label": "unknown",
      "evide

In [11]:
import time
import pandas as pd

# --------- settings ----------
TOPIC_ID = "1"      # change later to test other topics
TOP_N = 10          # start small; later we can do 50 or 100
SLEEP_SEC = 0.0     # set to e.g. 0.2 if Ollama gets overloaded

# --------- fetch patient text ----------
patient_text = queries_df.loc[queries_df["topic_id"].astype(str) == TOPIC_ID, "text"].iloc[0]

# --------- choose top-N trials from retrieval ----------
top_trials = (
    run_df[run_df["topic_id"].astype(str) == TOPIC_ID]
    .sort_values("rank")
    .head(TOP_N)[["doc_id", "rank", "score"]]
    .reset_index(drop=True)
)

def score_match(result_json: dict) -> float:
    """
    Simple deterministic scoring (baseline aggregator).
    - Inclusion: met +1, unknown 0, not_met -1
    - Exclusion: triggers/met -4, unknown 0, does_not_trigger/not_met +0.5
    """
    score = 0.0

    inc = result_json.get("inclusion", [])
    exc = result_json.get("exclusion", [])

    for it in inc:
        lab = (it.get("label") or "").strip().lower()
        if lab == "met":
            score += 1.0
        elif lab == "not_met":
            score -= 1.0
        else:  # unknown
            score += 0.0

    for it in exc:
        lab = (it.get("label") or "").strip().lower()
        # allow either schema wordings (some models say met/not_met)
        if lab in {"triggers", "met"}:
            score -= 4.0
        elif lab in {"does_not_trigger", "not_met"}:
            score += 0.5
        else:
            score += 0.0

    return score

# --------- run matching over Top-N ----------
results = []
for row in top_trials.itertuples(index=False):
    trial_id = row.doc_id

    trial_text = get_trial_text(trial_id)
    elig = extract_eligibility(trial_text)

    prompt = build_matching_prompt(
        patient_text=patient_text,
        inclusion=elig["inclusion"],
        exclusion=elig["exclusion"],
        trial_id=trial_id
    )

    raw = ollama_generate(prompt, temperature=0.0)
    match_json = safe_json_load(raw)

    agg_score = score_match(match_json)

    results.append({
        "topic_id": TOPIC_ID,
        "trial_id": trial_id,
        "bm25_rank": int(row.rank),
        "bm25_score": float(row.score),
        "overall_assessment": match_json.get("overall_assessment"),
        "agg_score": agg_score,
        "n_inclusion_items": len(match_json.get("inclusion", [])),
        "n_exclusion_items": len(match_json.get("exclusion", [])),
        "notes": match_json.get("notes", "")
    })

    if SLEEP_SEC > 0:
        time.sleep(SLEEP_SEC)

rank_df = pd.DataFrame(results).sort_values(["agg_score", "bm25_score"], ascending=[False, False]).reset_index(drop=True)

print("Finished matching + scoring for topic:", TOPIC_ID)
display(rank_df)

Finished matching + scoring for topic: 1


Unnamed: 0,topic_id,trial_id,bm25_rank,bm25_score,overall_assessment,agg_score,n_inclusion_items,n_exclusion_items,notes
0,1,NCT00734682,9,276.883413,likely_eligible,5.0,7,3,Unclear if patient has fully recovered from pr...
1,1,NCT00968240,7,279.294295,likely_eligible,4.5,8,3,Patient's expected survival and hematologic re...
2,1,NCT00841555,6,286.02794,likely_eligible,3.5,3,3,Patient's Karnofsky score and prior chemothera...
3,1,NCT03528642,1,293.930154,likely_eligible,3.0,12,4,Patient's history of hypertension and chronic ...
4,1,NCT03633552,5,286.098323,likely_ineligible,1.0,4,5,Key unknowns include the patient's Karnofsky P...
5,1,NCT00003176,2,290.582836,likely_eligible,0.0,0,0,
6,1,NCT00003537,10,276.22625,likely_ineligible,0.0,0,0,Patient's condition and treatment history are ...
7,1,NCT00089427,8,277.892832,likely_ineligible,-2.0,6,2,Unknown Karnofsky Performance Scale score and ...
8,1,NCT02942264,3,288.751092,likely_ineligible,-5.0,9,2,Patient's history of anaplastic astrocytoma an...
9,1,NCT01466686,4,287.16188,likely_eligible,-8.0,12,8,"ECOG performance status, organ and marrow func..."


In [12]:
import json
import os
from pathlib import Path
from tqdm import tqdm

# --------- settings ----------
TOPIC_ID = "1"
TOP_K = 50  # start with 50; later increase to 100
CACHE_DIR = Path(r"C:\Ajesh_Drive\PersonalProjects\ClinicalTrialNexus\models\user\trialgpt_local\matches")
RUN_TAG = "llama31_8b"

topic_dir = CACHE_DIR / f"topic_{TOPIC_ID}"
topic_dir.mkdir(parents=True, exist_ok=True)

# --------- patient text ----------
patient_text = queries_df.loc[queries_df["topic_id"].astype(str) == TOPIC_ID, "text"].iloc[0]

# --------- top-K trials from retrieval ----------
top_trials = (
    run_df[run_df["topic_id"].astype(str) == TOPIC_ID]
    .sort_values("rank")
    .head(TOP_K)[["doc_id", "rank", "score"]]
    .reset_index(drop=True)
)

def cache_path_for(trial_id: str) -> Path:
    return topic_dir / f"{trial_id}_{RUN_TAG}.json"

def run_one_match(topic_id: str, trial_id: str) -> dict:
    trial_text = get_trial_text(trial_id)
    elig = extract_eligibility(trial_text)

    prompt = build_matching_prompt(
        patient_text=patient_text,
        inclusion=elig["inclusion"],
        exclusion=elig["exclusion"],
        trial_id=trial_id
    )

    raw = ollama_generate(prompt, temperature=0.0)
    match_json = safe_json_load(raw)

    # add minimal metadata so files are self-contained
    match_json["_meta"] = {
        "topic_id": topic_id,
        "trial_id": trial_id,
        "model": OLLAMA_MODEL,
        "run_tag": RUN_TAG
    }
    return match_json

# --------- batch with caching ----------
saved = 0
skipped = 0
errors = 0

for row in tqdm(top_trials.itertuples(index=False), total=len(top_trials), desc=f"Matching Top-{TOP_K} for topic {TOPIC_ID}"):
    trial_id = row.doc_id
    out_path = cache_path_for(trial_id)

    if out_path.exists():
        skipped += 1
        continue

    try:
        match_json = run_one_match(TOPIC_ID, trial_id)
        with open(out_path, "w", encoding="utf-8") as f:
            json.dump(match_json, f, indent=2)
        saved += 1
    except Exception as e:
        errors += 1
        err_path = topic_dir / f"{trial_id}_{RUN_TAG}.error.txt"
        with open(err_path, "w", encoding="utf-8") as f:
            f.write(str(e))
        continue

print("\nBatch matching complete")
print("Saved new JSON files:", saved)
print("Skipped (already cached):", skipped)
print("Errors:", errors)
print("Cache folder:", str(topic_dir))

Matching Top-50 for topic 1: 100%|██████████| 50/50 [15:49<00:00, 19.00s/it]


Batch matching complete
Saved new JSON files: 49
Skipped (already cached): 0
Errors: 1
Cache folder: C:\Ajesh_Drive\PersonalProjects\ClinicalTrialNexus\models\user\trialgpt_local\matches\topic_1





In [13]:
import json
from pathlib import Path
import pandas as pd

TOPIC_ID = "1"
TOP_K = 50
RUN_TAG = "llama31_8b"

CACHE_DIR = Path(r"C:\Ajesh_Drive\PersonalProjects\ClinicalTrialNexus\models\user\trialgpt_local\matches") / f"topic_{TOPIC_ID}"
OUT_DIR = Path(r"C:\Ajesh_Drive\PersonalProjects\ClinicalTrialNexus\models\user\trialgpt_local")
OUT_DIR.mkdir(parents=True, exist_ok=True)

rerank_run_path = OUT_DIR / f"trialgpt_{RUN_TAG}_trec2021_topic{TOPIC_ID}_top{TOP_K}.run"
rerank_table_path = OUT_DIR / f"trialgpt_{RUN_TAG}_trec2021_topic{TOPIC_ID}_top{TOP_K}_table.csv"

# --- helper: same scoring as before ---
def score_match(result_json: dict) -> float:
    score = 0.0
    inc = result_json.get("inclusion", [])
    exc = result_json.get("exclusion", [])

    for it in inc:
        lab = (it.get("label") or "").strip().lower()
        if lab == "met":
            score += 1.0
        elif lab == "not_met":
            score -= 1.0

    for it in exc:
        lab = (it.get("label") or "").strip().lower()
        if lab in {"triggers", "met"}:
            score -= 4.0
        elif lab in {"does_not_trigger", "not_met"}:
            score += 0.5

    return score

# --- read top-K from retrieval (same candidate set) ---
top_trials = (
    run_df[run_df["topic_id"].astype(str) == TOPIC_ID]
    .sort_values("rank")
    .head(TOP_K)[["doc_id", "rank", "score"]]
    .rename(columns={"doc_id": "trial_id", "rank": "bm25_rank", "score": "bm25_score"})
    .reset_index(drop=True)
)

# --- load cached JSONs ---
rows = []
missing = 0

for r in top_trials.itertuples(index=False):
    trial_id = r.trial_id
    json_path = CACHE_DIR / f"{trial_id}_{RUN_TAG}.json"

    if not json_path.exists():
        missing += 1
        continue

    with open(json_path, "r", encoding="utf-8") as f:
        match_json = json.load(f)

    agg = score_match(match_json)

    rows.append({
        "topic_id": TOPIC_ID,
        "trial_id": trial_id,
        "bm25_rank": int(r.bm25_rank),
        "bm25_score": float(r.bm25_score),
        "overall_assessment": match_json.get("overall_assessment", ""),
        "agg_score": float(agg),
        "n_inclusion_items": len(match_json.get("inclusion", [])),
        "n_exclusion_items": len(match_json.get("exclusion", [])),
        "notes": match_json.get("notes", "")
    })

rank_df = pd.DataFrame(rows)
rank_df = rank_df.sort_values(["agg_score", "bm25_score"], ascending=[False, False]).reset_index(drop=True)

print("Loaded cached matches:", len(rank_df))
print("Missing (due to errors / not cached):", missing)

# --- assign new ranks and write a TREC run file (reranked) ---
runname = f"trialgpt_{RUN_TAG}"
with open(rerank_run_path, "w", encoding="utf-8") as f:
    for i, row in enumerate(rank_df.itertuples(index=False), start=1):
        # Use agg_score as the run score (primary ranking signal)
        f.write(f"{TOPIC_ID} Q0 {row.trial_id} {i} {row.agg_score:.6f} {runname}\n")

rank_df.to_csv(rerank_table_path, index=False)

print("\nSaved reranked run to:")
print(rerank_run_path)
print("\nSaved rerank table to:")
print(rerank_table_path)

display(rank_df.head(15))

Loaded cached matches: 49
Missing (due to errors / not cached): 1

Saved reranked run to:
C:\Ajesh_Drive\PersonalProjects\ClinicalTrialNexus\models\user\trialgpt_local\trialgpt_llama31_8b_trec2021_topic1_top50.run

Saved rerank table to:
C:\Ajesh_Drive\PersonalProjects\ClinicalTrialNexus\models\user\trialgpt_local\trialgpt_llama31_8b_trec2021_topic1_top50_table.csv


Unnamed: 0,topic_id,trial_id,bm25_rank,bm25_score,overall_assessment,agg_score,n_inclusion_items,n_exclusion_items,notes
0,1,NCT03896568,21,262.80782,likely_eligible,6.5,16,6,The patient has a history of anaplastic astroc...
1,1,NCT00360828,37,255.430346,likely_eligible,6.5,12,4,"Predicted life expectancy and ANC, Platelets, ..."
2,1,NCT00704080,20,262.849991,likely_eligible,6.0,7,5,Unknowns include adequate organ and bone marro...
3,1,NCT00783393,46,253.370398,likely_eligible,6.0,8,3,Unknown: tissue samples available for Central ...
4,1,NCT00047879,14,267.152148,likely_eligible,5.5,9,5,Unknown Karnofsky performance status and adequ...
5,1,NCT00458731,47,253.129568,likely_eligible,5.5,12,8,Key unknowns include the patient's total bilir...
6,1,NCT00734682,9,276.883413,likely_eligible,5.0,7,3,Unclear if patient has fully recovered from pr...
7,1,NCT00504660,11,269.024049,likely_eligible,5.0,6,6,Patient's history of anaplastic astrocytoma an...
8,1,NCT00859222,32,256.908038,likely_eligible,5.0,14,15,
9,1,NCT01847235,40,253.914335,likely_eligible,5.0,12,3,Key unknowns include the patient's ECOG perfor...


In [14]:
import json
from pathlib import Path
import pandas as pd

TOP_K = 100
RUN_TAG = "llama31_8b"

# Where cached matches live:
CACHE_BASE = Path(r"C:\Ajesh_Drive\PersonalProjects\ClinicalTrialNexus\models\user\trialgpt_local\matches")

# Output files:
OUT_DIR = Path(r"C:\Ajesh_Drive\PersonalProjects\ClinicalTrialNexus\models\user\trialgpt_local")
OUT_DIR.mkdir(parents=True, exist_ok=True)

full_run_path = OUT_DIR / f"trialgpt_{RUN_TAG}_trec2021_top{TOP_K}.run"
coverage_path = OUT_DIR / f"trialgpt_{RUN_TAG}_trec2021_top{TOP_K}_coverage.csv"

def score_match(result_json: dict) -> float:
    score = 0.0
    inc = result_json.get("inclusion", [])
    exc = result_json.get("exclusion", [])

    for it in inc:
        lab = (it.get("label") or "").strip().lower()
        if lab == "met":
            score += 1.0
        elif lab == "not_met":
            score -= 1.0

    for it in exc:
        lab = (it.get("label") or "").strip().lower()
        if lab in {"triggers", "met"}:
            score -= 4.0
        elif lab in {"does_not_trigger", "not_met"}:
            score += 0.5

    return score

runname = f"trialgpt_{RUN_TAG}"

all_lines = []
coverage_rows = []

topic_ids = sorted(queries_df["topic_id"].astype(str).unique(), key=lambda x: int(x))

for topic_id in topic_ids:
    # candidates from BM25 run_df (already have top100 per topic)
    cands = (
        run_df[run_df["topic_id"].astype(str) == topic_id]
        .sort_values("rank")
        .head(TOP_K)[["doc_id", "rank", "score"]]
        .rename(columns={"doc_id": "trial_id", "rank": "bm25_rank", "score": "bm25_score"})
        .reset_index(drop=True)
    )

    topic_dir = CACHE_BASE / f"topic_{topic_id}"
    matched = []
    missing = 0

    for r in cands.itertuples(index=False):
        trial_id = r.trial_id
        json_path = topic_dir / f"{trial_id}_{RUN_TAG}.json"

        if not json_path.exists():
            missing += 1
            continue

        with open(json_path, "r", encoding="utf-8") as f:
            match_json = json.load(f)

        agg = score_match(match_json)

        matched.append({
            "trial_id": trial_id,
            "agg_score": float(agg),
            "bm25_score": float(r.bm25_score)
        })

    if len(matched) == 0:
        # no cached matches for this topic yet; we will skip writing any lines for it
        coverage_rows.append({
            "topic_id": topic_id,
            "candidates": len(cands),
            "matched_cached": 0,
            "missing_cached": missing
        })
        continue

    ranked = (
        pd.DataFrame(matched)
        .sort_values(["agg_score", "bm25_score"], ascending=[False, False])
        .reset_index(drop=True)
    )

    # Write TREC run lines for this topic
    for rank, row in enumerate(ranked.itertuples(index=False), start=1):
        all_lines.append(f"{topic_id} Q0 {row.trial_id} {rank} {row.agg_score:.6f} {runname}\n")

    coverage_rows.append({
        "topic_id": topic_id,
        "candidates": len(cands),
        "matched_cached": len(ranked),
        "missing_cached": missing
    })

# Save the run file
with open(full_run_path, "w", encoding="utf-8") as f:
    f.writelines(all_lines)

coverage_df = pd.DataFrame(coverage_rows)
coverage_df.to_csv(coverage_path, index=False)

print("Saved FULL reranked run to:")
print(full_run_path)
print("\nSaved coverage report to:")
print(coverage_path)

print("\nCoverage summary:")
print(coverage_df[["matched_cached"]].describe())
display(coverage_df.head(10))

Saved FULL reranked run to:
C:\Ajesh_Drive\PersonalProjects\ClinicalTrialNexus\models\user\trialgpt_local\trialgpt_llama31_8b_trec2021_top100.run

Saved coverage report to:
C:\Ajesh_Drive\PersonalProjects\ClinicalTrialNexus\models\user\trialgpt_local\trialgpt_llama31_8b_trec2021_top100_coverage.csv

Coverage summary:
       matched_cached
count       75.000000
mean         0.653333
std          5.658033
min          0.000000
25%          0.000000
50%          0.000000
75%          0.000000
max         49.000000


Unnamed: 0,topic_id,candidates,matched_cached,missing_cached
0,1,100,49,51
1,2,100,0,100
2,3,100,0,100
3,4,100,0,100
4,5,100,0,100
5,6,100,0,100
6,7,100,0,100
7,8,100,0,100
8,9,100,0,100
9,10,100,0,100


In [18]:
import json
import time
import requests
from pathlib import Path
from tqdm import tqdm

# ------------------ TEST SETTINGS ------------------
TEST_TOPICS = 5        # only 5 topics for testing
TOP_K = 20             # only top-20 candidates per topic
RUN_TAG = "llama31_8b_test"
SLEEP_SEC = 0.0

# Make generation faster + more consistent
OLLAMA_URL = "http://localhost:11434/api/generate"
OLLAMA_MODEL = "llama3.1:8b"

def ollama_generate(prompt: str, model: str = OLLAMA_MODEL, temperature: float = 0.0, num_predict: int = 450) -> str:
    payload = {
        "model": model,
        "prompt": prompt,
        "stream": False,
        "options": {
            "temperature": temperature,
            "num_predict": num_predict,  # hard cap output tokens
        }
    }
    r = requests.post(OLLAMA_URL, json=payload, timeout=300)
    r.raise_for_status()
    return r.json()["response"]

def build_matching_prompt(patient_text: str, inclusion: list, exclusion: list, trial_id: str) -> str:
    # Smaller slices for faster testing
    inc = "\n".join([f"- {c}" for c in inclusion[:25]])
    exc = "\n".join([f"- {c}" for c in exclusion[:20]])

    return f"""
Return ONLY valid JSON (no markdown, no extra text).

Schema:
{{
  "trial_id": "{trial_id}",
  "overall_assessment": "likely_eligible" | "likely_ineligible" | "uncertain",
  "inclusion": [{{"criterion":"...","label":"met|not_met|unknown","evidence":"quote|none"}}],
  "exclusion":  [{{"criterion":"...","label":"triggers|does_not_trigger|unknown","evidence":"quote|none"}}],
  "notes": "short"
}}

Rules:
- Use ONLY patient text. If not stated, label unknown.
- Pick at most 8 inclusion and 5 exclusion items (most decisive).

PATIENT:
{patient_text}

INCLUSION (subset):
{inc if inc.strip() else "- (none found)"}

EXCLUSION (subset):
{exc if exc.strip() else "- (none found)"}
""".strip()

def safe_json_load(s: str):
    s = s.strip()
    try:
        return json.loads(s)
    except Exception:
        pass
    start = s.find("{")
    end = s.rfind("}")
    if start != -1 and end != -1 and end > start:
        return json.loads(s[start:end+1])
    raise ValueError("Could not parse JSON from model output.")

# ------------------ CACHE PATHS ------------------
CACHE_BASE = Path(r"C:\Ajesh_Drive\PersonalProjects\ClinicalTrialNexus\models\user\trialgpt_local\matches_test")
CACHE_BASE.mkdir(parents=True, exist_ok=True)

def cache_path(topic_id: str, trial_id: str) -> Path:
    topic_dir = CACHE_BASE / f"topic_{topic_id}"
    topic_dir.mkdir(parents=True, exist_ok=True)
    return topic_dir / f"{trial_id}_{RUN_TAG}.json"

def error_path(topic_id: str, trial_id: str) -> Path:
    topic_dir = CACHE_BASE / f"topic_{topic_id}"
    topic_dir.mkdir(parents=True, exist_ok=True)
    return topic_dir / f"{trial_id}_{RUN_TAG}.error.txt"

# ------------------ RUN TEST MATCHING ------------------
topic_ids = sorted(queries_df["topic_id"].astype(str).unique(), key=lambda x: int(x))[:TEST_TOPICS]

total_saved = total_skipped = total_errors = 0

for topic_id in tqdm(topic_ids, desc="Test topics"):
    patient_text = queries_df.loc[queries_df["topic_id"].astype(str) == topic_id, "text"].iloc[0]

    cands = (
        run_df[run_df["topic_id"].astype(str) == topic_id]
        .sort_values("rank")
        .head(TOP_K)["doc_id"]
        .tolist()
    )

    saved = skipped = errors = 0
    for trial_id in tqdm(cands, desc=f"Topic {topic_id}", leave=False):
        out_json = cache_path(topic_id, trial_id)
        if out_json.exists():
            skipped += 1
            continue

        try:
            trial_text = get_trial_text(trial_id)
            elig = extract_eligibility(trial_text)

            prompt = build_matching_prompt(patient_text, elig["inclusion"], elig["exclusion"], trial_id)
            raw = ollama_generate(prompt, temperature=0.0, num_predict=450)
            match_json = safe_json_load(raw)

            match_json["_meta"] = {"topic_id": topic_id, "trial_id": trial_id, "model": OLLAMA_MODEL, "run_tag": RUN_TAG}
            with open(out_json, "w", encoding="utf-8") as f:
                json.dump(match_json, f, indent=2)

            saved += 1

        except Exception as e:
            errors += 1
            with open(error_path(topic_id, trial_id), "w", encoding="utf-8") as f:
                f.write(str(e))

        if SLEEP_SEC > 0:
            time.sleep(SLEEP_SEC)

    total_saved += saved
    total_skipped += skipped
    total_errors += errors

    print(f"\n[Topic {topic_id}] saved={saved} skipped={skipped} errors={errors}")

print("\n✅ TEST MATCHING DONE")
print("Total saved:", total_saved)
print("Total skipped:", total_skipped)
print("Total errors:", total_errors)
print("Cache base:", str(CACHE_BASE))

Test topics:  20%|██        | 1/5 [03:11<12:47, 191.99s/it]


[Topic 1] saved=12 skipped=0 errors=8


Test topics:  40%|████      | 2/5 [06:29<09:45, 195.11s/it]


[Topic 2] saved=16 skipped=0 errors=4


Test topics:  60%|██████    | 3/5 [09:53<06:38, 199.44s/it]


[Topic 3] saved=15 skipped=0 errors=5


Test topics:  80%|████████  | 4/5 [13:16<03:20, 200.58s/it]


[Topic 4] saved=13 skipped=0 errors=7


Test topics: 100%|██████████| 5/5 [16:32<00:00, 198.57s/it]


[Topic 5] saved=18 skipped=0 errors=2

✅ TEST MATCHING DONE
Total saved: 74
Total skipped: 0
Total errors: 26
Cache base: C:\Ajesh_Drive\PersonalProjects\ClinicalTrialNexus\models\user\trialgpt_local\matches_test





In [19]:
import json
from pathlib import Path
import pandas as pd

# --------- must match the test cell settings ----------
TEST_TOPICS = 5
TOP_K = 20
RUN_TAG = "llama31_8b_test"

CACHE_BASE = Path(r"C:\Ajesh_Drive\PersonalProjects\ClinicalTrialNexus\models\user\trialgpt_local\matches_test")
OUT_DIR = Path(r"C:\Ajesh_Drive\PersonalProjects\ClinicalTrialNexus\models\user\trialgpt_local\test_outputs")
OUT_DIR.mkdir(parents=True, exist_ok=True)

full_run_path = OUT_DIR / f"trialgpt_{RUN_TAG}_trec2021_{TEST_TOPICS}topics_top{TOP_K}.run"
table_path = OUT_DIR / f"trialgpt_{RUN_TAG}_trec2021_{TEST_TOPICS}topics_top{TOP_K}_table.csv"
coverage_path = OUT_DIR / f"trialgpt_{RUN_TAG}_trec2021_{TEST_TOPICS}topics_top{TOP_K}_coverage.csv"

def score_match(result_json: dict) -> float:
    score = 0.0
    inc = result_json.get("inclusion", [])
    exc = result_json.get("exclusion", [])

    for it in inc:
        lab = (it.get("label") or "").strip().lower()
        if lab == "met":
            score += 1.0
        elif lab == "not_met":
            score -= 1.0

    for it in exc:
        lab = (it.get("label") or "").strip().lower()
        if lab in {"triggers", "met"}:
            score -= 4.0
        elif lab in {"does_not_trigger", "not_met"}:
            score += 0.5

    return score

topic_ids = sorted(queries_df["topic_id"].astype(str).unique(), key=lambda x: int(x))[:TEST_TOPICS]

all_lines = []
rows = []
cov_rows = []
runname = f"trialgpt_{RUN_TAG}"

for topic_id in topic_ids:
    cands = (
        run_df[run_df["topic_id"].astype(str) == topic_id]
        .sort_values("rank")
        .head(TOP_K)[["doc_id", "rank", "score"]]
        .rename(columns={"doc_id": "trial_id", "rank": "bm25_rank", "score": "bm25_score"})
        .reset_index(drop=True)
    )

    topic_dir = CACHE_BASE / f"topic_{topic_id}"
    matched = []

    for r in cands.itertuples(index=False):
        json_path = topic_dir / f"{r.trial_id}_{RUN_TAG}.json"
        if not json_path.exists():
            continue
        with open(json_path, "r", encoding="utf-8") as f:
            mj = json.load(f)

        agg = score_match(mj)
        matched.append((r.trial_id, float(agg), float(r.bm25_score), int(r.bm25_rank), mj.get("overall_assessment", ""), mj.get("notes", "")))

    cov_rows.append({
        "topic_id": topic_id,
        "candidates": len(cands),
        "matched_cached": len(matched),
        "missing_cached": len(cands) - len(matched)
    })

    if not matched:
        continue

    ranked = (
        pd.DataFrame(matched, columns=["trial_id", "agg_score", "bm25_score", "bm25_rank", "overall_assessment", "notes"])
        .sort_values(["agg_score", "bm25_score"], ascending=[False, False])
        .reset_index(drop=True)
    )

    # write run lines
    for rank, rr in enumerate(ranked.itertuples(index=False), start=1):
        all_lines.append(f"{topic_id} Q0 {rr.trial_id} {rank} {rr.agg_score:.6f} {runname}\n")

    # store table rows
    for rr in ranked.itertuples(index=False):
        rows.append({
            "topic_id": topic_id,
            "trial_id": rr.trial_id,
            "agg_score": rr.agg_score,
            "bm25_score": rr.bm25_score,
            "bm25_rank": rr.bm25_rank,
            "overall_assessment": rr.overall_assessment,
            "notes": rr.notes
        })

# save outputs
with open(full_run_path, "w", encoding="utf-8") as f:
    f.writelines(all_lines)

table_df = pd.DataFrame(rows)
coverage_df = pd.DataFrame(cov_rows)

table_df.to_csv(table_path, index=False)
coverage_df.to_csv(coverage_path, index=False)

print("✅ Saved reranked TEST run to:")
print(full_run_path)
print("\n✅ Saved table to:")
print(table_path)
print("\n✅ Saved coverage to:")
print(coverage_path)

print("\nCoverage:")
display(coverage_df)

print("\nSample reranked rows:")
display(table_df.head(15))

✅ Saved reranked TEST run to:
C:\Ajesh_Drive\PersonalProjects\ClinicalTrialNexus\models\user\trialgpt_local\test_outputs\trialgpt_llama31_8b_test_trec2021_5topics_top20.run

✅ Saved table to:
C:\Ajesh_Drive\PersonalProjects\ClinicalTrialNexus\models\user\trialgpt_local\test_outputs\trialgpt_llama31_8b_test_trec2021_5topics_top20_table.csv

✅ Saved coverage to:
C:\Ajesh_Drive\PersonalProjects\ClinicalTrialNexus\models\user\trialgpt_local\test_outputs\trialgpt_llama31_8b_test_trec2021_5topics_top20_coverage.csv

Coverage:


Unnamed: 0,topic_id,candidates,matched_cached,missing_cached
0,1,20,12,8
1,2,20,16,4
2,3,20,15,5
3,4,20,13,7
4,5,20,18,2



Sample reranked rows:


Unnamed: 0,topic_id,trial_id,agg_score,bm25_score,bm25_rank,overall_assessment,notes
0,1,NCT00362570,4.5,267.891473,12,likely_eligible,Patient has a history of anaplastic astrocytom...
1,1,NCT02942264,3.5,288.751092,3,likely_eligible,Patient has a history of anaplastic astrocytom...
2,1,NCT00976313,1.5,264.269625,18,likely_eligible,Patient has a history of anaplastic astrocytom...
3,1,NCT00003176,0.0,290.582836,2,likely_ineligible,Patient has unresectable anaplastic astrocytom...
4,1,NCT00003537,0.0,276.22625,10,likely_ineligible,Patient has unresectable anaplastic astrocytom...
5,1,NCT00003775,0.0,267.316937,13,likely_ineligible,Patient has unresectable anaplastic astrocytom...
6,1,NCT00052624,0.0,264.655959,16,likely_ineligible,Patient has unresectable anaplastic astrocytom...
7,1,NCT00028795,0.0,264.555911,17,likely_ineligible,Patient has unresectable anaplastic astrocytom...
8,1,NCT00004080,0.0,262.940941,19,likely_ineligible,Patient has unresectable anaplastic astrocytom...
9,1,NCT00968240,-3.0,279.294295,7,likely_eligible,The patient has a history of anaplastic astroc...


In [20]:
import pandas as pd

# ---------- paths ----------
BM25_RUN = Path(r"C:\Ajesh_Drive\PersonalProjects\ClinicalTrialNexus\models\user\bm25_python_trec2021_top100.run")
TRIALGPT_RUN = Path(r"C:\Ajesh_Drive\PersonalProjects\ClinicalTrialNexus\models\user\trialgpt_local\test_outputs\trialgpt_llama31_8b_test_trec2021_5topics_top20.run")

TEST_TOPICS = set(["1", "2", "3", "4", "5"])

# ---------- load run files ----------
def load_run(path):
    rows = []
    with open(path, "r") as f:
        for line in f:
            topic, _, docid, rank, score, runname = line.strip().split()
            if topic in TEST_TOPICS:
                rows.append({
                    "topic_id": topic,
                    "doc_id": docid,
                    "rank": int(rank)
                })
    return pd.DataFrame(rows)

bm25_df = load_run(BM25_RUN)
trialgpt_df = load_run(TRIALGPT_RUN)

# ---------- metric function ----------
def precision_recall_at_k(run_df, qrels_df, k):
    rel = (
        qrels_df[qrels_df["relevance"] > 0]
        .groupby("topic_id")["doc_id"]
        .apply(set)
        .to_dict()
    )

    precisions, recalls = [], []

    for topic in TEST_TOPICS:
        retrieved = (
            run_df[run_df["topic_id"] == topic]
            .sort_values("rank")
            .head(k)["doc_id"]
            .tolist()
        )
        relevant = rel.get(topic, set())

        if not relevant:
            continue

        hits = sum(1 for d in retrieved if d in relevant)
        precisions.append(hits / k)
        recalls.append(hits / len(relevant))

    return sum(precisions)/len(precisions), sum(recalls)/len(recalls)

# ---------- compute ----------
metrics = []

for name, df in [("BM25", bm25_df), ("TrialGPT-test", trialgpt_df)]:
    p10, r10 = precision_recall_at_k(df, qrels_df, 10)
    p20, r20 = precision_recall_at_k(df, qrels_df, 20)
    metrics.append({
        "model": name,
        "P@10": round(p10, 4),
        "R@10": round(r10, 4),
        "R@20": round(r20, 4)
    })

metrics_df = pd.DataFrame(metrics)
display(metrics_df)

Unnamed: 0,model,P@10,R@10,R@20
0,BM25,0.54,0.0312,0.051
1,TrialGPT-test,0.52,0.0282,0.0381


In [23]:
import json
import time
import requests
from pathlib import Path
from tqdm import tqdm

# ------------------ FULL RUN SETTINGS ------------------
TOP_K = 100
RUN_TAG = "llama31_8b_full_top100"
SLEEP_SEC = 0.0                 # set 0.1–0.3 if Ollama gets unstable
NUM_PREDICT = 900               # higher cap reduces truncated JSON
INCL_SLICE = 45                 # criteria slices (balance speed vs accuracy)
EXCL_SLICE = 35

OLLAMA_URL = "http://localhost:11434/api/generate"
OLLAMA_MODEL = "llama3.1:8b"

CACHE_BASE = Path(r"C:\Ajesh_Drive\PersonalProjects\ClinicalTrialNexus\models\user\trialgpt_local\matches_full_top100")
CACHE_BASE.mkdir(parents=True, exist_ok=True)

def cache_path(topic_id: str, trial_id: str) -> Path:
    topic_dir = CACHE_BASE / f"topic_{topic_id}"
    topic_dir.mkdir(parents=True, exist_ok=True)
    return topic_dir / f"{trial_id}_{RUN_TAG}.json"

def error_path(topic_id: str, trial_id: str) -> Path:
    topic_dir = CACHE_BASE / f"topic_{topic_id}"
    topic_dir.mkdir(parents=True, exist_ok=True)
    return topic_dir / f"{trial_id}_{RUN_TAG}.error.txt"

def ollama_generate(prompt: str, temperature: float = 0.0, num_predict: int = NUM_PREDICT) -> str:
    payload = {
        "model": OLLAMA_MODEL,
        "prompt": prompt,
        "stream": False,
        "options": {
            "temperature": temperature,
            "num_predict": num_predict,
        }
    }
    r = requests.post(OLLAMA_URL, json=payload, timeout=300)
    r.raise_for_status()
    return r.json()["response"]

def build_matching_prompt(patient_text: str, inclusion: list, exclusion: list, trial_id: str) -> str:
    inc = "\n".join([f"- {c}" for c in inclusion[:INCL_SLICE]])
    exc = "\n".join([f"- {c}" for c in exclusion[:EXCL_SLICE]])

    return f"""
Return ONLY valid JSON (no markdown, no extra text).

Schema:
{{
  "trial_id": "{trial_id}",
  "overall_assessment": "likely_eligible" | "likely_ineligible" | "uncertain",
  "inclusion": [{{"criterion":"...","label":"met|not_met|unknown","evidence":"quote|none"}}],
  "exclusion":  [{{"criterion":"...","label":"triggers|does_not_trigger|unknown","evidence":"quote|none"}}],
  "notes": "1-2 sentences"
}}

Rules:
- Use ONLY patient text. If not stated, label unknown.
- Pick at most 12 inclusion and 8 exclusion items (most decisive).
- Be conservative: if any exclusion likely triggers, overall_assessment should be likely_ineligible.

PATIENT:
{patient_text}

INCLUSION (subset):
{inc if inc.strip() else "- (none found)"}

EXCLUSION (subset):
{exc if exc.strip() else "- (none found)"}
""".strip()

def safe_json_load(s: str):
    s = s.strip()
    try:
        return json.loads(s)
    except Exception:
        pass
    start = s.find("{")
    end = s.rfind("}")
    if start != -1 and end != -1 and end > start:
        return json.loads(s[start:end+1])
    raise ValueError("Could not parse JSON from model output.")

# ------------------ FULL RUN LOOP ------------------
topic_ids = sorted(queries_df["topic_id"].astype(str).unique(), key=lambda x: int(x))

total_saved = total_skipped = total_errors = 0

for topic_id in tqdm(topic_ids, desc="FULL topics (75)"):
    patient_text = queries_df.loc[queries_df["topic_id"].astype(str) == topic_id, "text"].iloc[0]

    cands = (
        run_df[run_df["topic_id"].astype(str) == topic_id]
        .sort_values("rank")
        .head(TOP_K)["doc_id"]
        .tolist()
    )

    saved = skipped = errors = 0

    for trial_id in tqdm(cands, desc=f"Topic {topic_id}", leave=False):
        out_json = cache_path(topic_id, trial_id)
        if out_json.exists():
            skipped += 1
            continue

        try:
            trial_text = get_trial_text(trial_id)
            elig = extract_eligibility(trial_text)

            prompt = build_matching_prompt(patient_text, elig["inclusion"], elig["exclusion"], trial_id)
            raw = ollama_generate(prompt, temperature=0.0, num_predict=NUM_PREDICT)
            match_json = safe_json_load(raw)

            match_json["_meta"] = {
                "topic_id": topic_id,
                "trial_id": trial_id,
                "model": OLLAMA_MODEL,
                "run_tag": RUN_TAG,
                "top_k": TOP_K
            }

            with open(out_json, "w", encoding="utf-8") as f:
                json.dump(match_json, f, indent=2)

            saved += 1

        except Exception as e:
            errors += 1
            with open(error_path(topic_id, trial_id), "w", encoding="utf-8") as f:
                f.write(str(e))

        if SLEEP_SEC > 0:
            time.sleep(SLEEP_SEC)

    total_saved += saved
    total_skipped += skipped
    total_errors += errors

    print(f"\n[Topic {topic_id}] saved={saved} skipped={skipped} errors={errors}")

print("\n FULL MATCHING DONE (or resumed)")
print("Total saved:", total_saved)
print("Total skipped:", total_skipped)
print("Total errors:", total_errors)
print("Cache base:", str(CACHE_BASE))

FULL topics (75):   1%|▏         | 1/75 [15:23<18:58:32, 923.14s/it]


[Topic 1] saved=53 skipped=38 errors=9


FULL topics (75):   3%|▎         | 2/75 [34:27<21:21:26, 1053.24s/it]


[Topic 2] saved=97 skipped=0 errors=3


FULL topics (75):   4%|▍         | 3/75 [53:44<22:01:00, 1100.85s/it]


[Topic 3] saved=97 skipped=0 errors=3


FULL topics (75):   5%|▌         | 4/75 [1:15:25<23:16:08, 1179.83s/it]


[Topic 4] saved=90 skipped=0 errors=10


FULL topics (75):   7%|▋         | 5/75 [1:33:56<22:27:19, 1154.84s/it]


[Topic 5] saved=99 skipped=0 errors=1


FULL topics (75):   8%|▊         | 6/75 [1:53:57<22:26:06, 1170.53s/it]


[Topic 6] saved=92 skipped=0 errors=8


FULL topics (75):   9%|▉         | 7/75 [2:13:21<22:04:11, 1168.41s/it]


[Topic 7] saved=98 skipped=0 errors=2


FULL topics (75):  11%|█         | 8/75 [2:30:44<21:00:02, 1128.39s/it]


[Topic 8] saved=100 skipped=0 errors=0


FULL topics (75):  12%|█▏        | 9/75 [2:51:16<21:17:03, 1160.96s/it]


[Topic 9] saved=98 skipped=0 errors=2


FULL topics (75):  13%|█▎        | 10/75 [3:13:07<21:47:52, 1207.27s/it]


[Topic 10] saved=97 skipped=0 errors=3


FULL topics (75):  15%|█▍        | 11/75 [3:32:32<21:13:47, 1194.17s/it]


[Topic 11] saved=96 skipped=0 errors=4


FULL topics (75):  16%|█▌        | 12/75 [3:53:13<21:08:50, 1208.41s/it]


[Topic 12] saved=98 skipped=0 errors=2


FULL topics (75):  17%|█▋        | 13/75 [4:13:12<20:45:47, 1205.60s/it]


[Topic 13] saved=97 skipped=0 errors=3


FULL topics (75):  19%|█▊        | 14/75 [4:33:03<20:21:09, 1201.14s/it]


[Topic 14] saved=98 skipped=0 errors=2


FULL topics (75):  20%|██        | 15/75 [4:51:44<19:37:07, 1177.12s/it]


[Topic 15] saved=98 skipped=0 errors=2


FULL topics (75):  21%|██▏       | 16/75 [5:12:47<19:42:55, 1202.98s/it]


[Topic 16] saved=98 skipped=0 errors=2


FULL topics (75):  23%|██▎       | 17/75 [5:36:44<20:30:54, 1273.36s/it]


[Topic 17] saved=92 skipped=0 errors=8


FULL topics (75):  24%|██▍       | 18/75 [5:55:41<19:30:49, 1232.45s/it]


[Topic 18] saved=98 skipped=0 errors=2


FULL topics (75):  25%|██▌       | 19/75 [6:13:55<18:31:24, 1190.80s/it]


[Topic 19] saved=100 skipped=0 errors=0


FULL topics (75):  27%|██▋       | 20/75 [6:33:56<18:14:16, 1193.75s/it]


[Topic 20] saved=96 skipped=0 errors=4


FULL topics (75):  28%|██▊       | 21/75 [6:52:45<17:36:57, 1174.41s/it]


[Topic 21] saved=99 skipped=0 errors=1


FULL topics (75):  29%|██▉       | 22/75 [7:10:22<16:46:10, 1139.07s/it]


[Topic 22] saved=99 skipped=0 errors=1


FULL topics (75):  31%|███       | 23/75 [7:32:29<17:16:09, 1195.56s/it]


[Topic 23] saved=93 skipped=0 errors=7


FULL topics (75):  32%|███▏      | 24/75 [7:53:54<17:19:02, 1222.41s/it]


[Topic 24] saved=94 skipped=0 errors=6


FULL topics (75):  33%|███▎      | 25/75 [8:15:50<17:22:01, 1250.44s/it]


[Topic 25] saved=98 skipped=0 errors=2


FULL topics (75):  35%|███▍      | 26/75 [8:34:23<16:27:28, 1209.15s/it]


[Topic 26] saved=97 skipped=0 errors=3


FULL topics (75):  36%|███▌      | 27/75 [8:56:42<16:38:35, 1248.24s/it]


[Topic 27] saved=92 skipped=0 errors=8


FULL topics (75):  37%|███▋      | 28/75 [9:15:19<15:46:51, 1208.76s/it]


[Topic 28] saved=99 skipped=0 errors=1


FULL topics (75):  39%|███▊      | 29/75 [9:34:21<15:11:28, 1188.89s/it]


[Topic 29] saved=95 skipped=0 errors=5


FULL topics (75):  40%|████      | 30/75 [9:53:10<14:38:11, 1170.93s/it]


[Topic 30] saved=97 skipped=0 errors=3


FULL topics (75):  41%|████▏     | 31/75 [10:11:33<14:03:40, 1150.47s/it]


[Topic 31] saved=99 skipped=0 errors=1


FULL topics (75):  43%|████▎     | 32/75 [10:32:15<14:04:04, 1177.77s/it]


[Topic 32] saved=94 skipped=0 errors=6


FULL topics (75):  44%|████▍     | 33/75 [10:53:00<13:58:42, 1198.16s/it]


[Topic 33] saved=96 skipped=0 errors=4


FULL topics (75):  45%|████▌     | 34/75 [11:12:16<13:29:59, 1185.36s/it]


[Topic 34] saved=99 skipped=0 errors=1


FULL topics (75):  47%|████▋     | 35/75 [11:31:27<13:03:18, 1174.97s/it]


[Topic 35] saved=97 skipped=0 errors=3


FULL topics (75):  48%|████▊     | 36/75 [11:52:48<13:04:27, 1206.85s/it]


[Topic 36] saved=94 skipped=0 errors=6


FULL topics (75):  49%|████▉     | 37/75 [12:13:16<12:48:28, 1213.37s/it]


[Topic 37] saved=95 skipped=0 errors=5


FULL topics (75):  51%|█████     | 38/75 [12:34:10<12:35:42, 1225.47s/it]


[Topic 38] saved=95 skipped=0 errors=5


FULL topics (75):  52%|█████▏    | 39/75 [12:52:49<11:56:11, 1193.65s/it]


[Topic 39] saved=94 skipped=0 errors=6


FULL topics (75):  53%|█████▎    | 40/75 [13:11:14<11:20:41, 1166.89s/it]


[Topic 40] saved=97 skipped=0 errors=3


FULL topics (75):  55%|█████▍    | 41/75 [13:29:26<10:48:29, 1144.40s/it]


[Topic 41] saved=100 skipped=0 errors=0


FULL topics (75):  56%|█████▌    | 42/75 [13:49:58<10:43:51, 1170.65s/it]


[Topic 42] saved=96 skipped=0 errors=4


FULL topics (75):  57%|█████▋    | 43/75 [14:10:13<10:31:27, 1183.99s/it]


[Topic 43] saved=95 skipped=0 errors=5


FULL topics (75):  59%|█████▊    | 44/75 [14:28:15<9:55:52, 1153.30s/it] 


[Topic 44] saved=97 skipped=0 errors=3


FULL topics (75):  60%|██████    | 45/75 [14:49:37<9:55:59, 1191.97s/it]


[Topic 45] saved=95 skipped=0 errors=5


FULL topics (75):  61%|██████▏   | 46/75 [15:11:59<9:57:57, 1237.16s/it]


[Topic 46] saved=90 skipped=0 errors=10


FULL topics (75):  63%|██████▎   | 47/75 [15:31:24<9:27:13, 1215.47s/it]


[Topic 47] saved=99 skipped=0 errors=1


FULL topics (75):  64%|██████▍   | 48/75 [15:49:58<8:53:16, 1185.05s/it]


[Topic 48] saved=99 skipped=0 errors=1


FULL topics (75):  65%|██████▌   | 49/75 [16:08:54<8:27:09, 1170.35s/it]


[Topic 49] saved=97 skipped=0 errors=3


FULL topics (75):  67%|██████▋   | 50/75 [16:26:57<7:56:41, 1144.08s/it]


[Topic 50] saved=96 skipped=0 errors=4


FULL topics (75):  68%|██████▊   | 51/75 [16:47:37<7:49:07, 1172.83s/it]


[Topic 51] saved=96 skipped=0 errors=4


FULL topics (75):  69%|██████▉   | 52/75 [17:07:40<7:33:02, 1181.86s/it]


[Topic 52] saved=100 skipped=0 errors=0


FULL topics (75):  71%|███████   | 53/75 [17:28:25<7:20:16, 1200.76s/it]


[Topic 53] saved=94 skipped=0 errors=6


FULL topics (75):  72%|███████▏  | 54/75 [17:49:26<7:06:37, 1218.91s/it]


[Topic 54] saved=95 skipped=0 errors=5


FULL topics (75):  73%|███████▎  | 55/75 [18:07:15<6:31:18, 1173.93s/it]


[Topic 55] saved=99 skipped=0 errors=1


FULL topics (75):  75%|███████▍  | 56/75 [18:25:22<6:03:27, 1147.78s/it]


[Topic 56] saved=100 skipped=0 errors=0


FULL topics (75):  76%|███████▌  | 57/75 [18:44:05<5:42:09, 1140.52s/it]


[Topic 57] saved=98 skipped=0 errors=2


FULL topics (75):  77%|███████▋  | 58/75 [19:00:33<5:10:09, 1094.70s/it]


[Topic 58] saved=98 skipped=0 errors=2


FULL topics (75):  79%|███████▊  | 59/75 [19:19:42<4:56:17, 1111.09s/it]


[Topic 59] saved=99 skipped=0 errors=1


FULL topics (75):  80%|████████  | 60/75 [19:41:59<4:54:39, 1178.65s/it]


[Topic 60] saved=91 skipped=0 errors=9


FULL topics (75):  81%|████████▏ | 61/75 [20:05:48<4:52:32, 1253.77s/it]


[Topic 61] saved=95 skipped=0 errors=5


FULL topics (75):  83%|████████▎ | 62/75 [20:23:35<4:19:30, 1197.71s/it]


[Topic 62] saved=99 skipped=0 errors=1


FULL topics (75):  84%|████████▍ | 63/75 [20:42:07<3:54:24, 1172.04s/it]


[Topic 63] saved=95 skipped=0 errors=5


FULL topics (75):  85%|████████▌ | 64/75 [21:04:21<3:43:48, 1220.81s/it]


[Topic 64] saved=90 skipped=0 errors=10


FULL topics (75):  87%|████████▋ | 65/75 [21:25:14<3:25:04, 1230.43s/it]


[Topic 65] saved=98 skipped=0 errors=2


FULL topics (75):  88%|████████▊ | 66/75 [21:44:53<3:02:14, 1214.89s/it]


[Topic 66] saved=98 skipped=0 errors=2


FULL topics (75):  89%|████████▉ | 67/75 [22:03:38<2:38:23, 1187.98s/it]


[Topic 67] saved=97 skipped=0 errors=3


FULL topics (75):  91%|█████████ | 68/75 [22:22:59<2:17:39, 1179.96s/it]


[Topic 68] saved=96 skipped=0 errors=4


FULL topics (75):  92%|█████████▏| 69/75 [22:44:20<2:01:00, 1210.14s/it]


[Topic 69] saved=98 skipped=0 errors=2


FULL topics (75):  93%|█████████▎| 70/75 [23:04:28<1:40:47, 1209.50s/it]


[Topic 70] saved=98 skipped=0 errors=2


FULL topics (75):  95%|█████████▍| 71/75 [23:24:03<1:19:57, 1199.30s/it]


[Topic 71] saved=97 skipped=0 errors=3


FULL topics (75):  96%|█████████▌| 72/75 [23:44:30<1:00:22, 1207.61s/it]


[Topic 72] saved=98 skipped=0 errors=2


FULL topics (75):  97%|█████████▋| 73/75 [24:04:07<39:56, 1198.27s/it]  


[Topic 73] saved=92 skipped=0 errors=8


FULL topics (75):  99%|█████████▊| 74/75 [24:22:56<19:37, 1177.62s/it]


[Topic 74] saved=98 skipped=0 errors=2


FULL topics (75): 100%|██████████| 75/75 [24:41:17<00:00, 1185.03s/it]


[Topic 75] saved=99 skipped=0 errors=1

 FULL MATCHING DONE (or resumed)
Total saved: 7197
Total skipped: 38
Total errors: 265
Cache base: C:\Ajesh_Drive\PersonalProjects\ClinicalTrialNexus\models\user\trialgpt_local\matches_full_top100





In [24]:
import json
from pathlib import Path
import pandas as pd

TOP_K = 100
RUN_TAG = "llama31_8b_full_top100"
OLLAMA_MODEL = "llama3.1:8b"

CACHE_BASE = Path(r"C:\Ajesh_Drive\PersonalProjects\ClinicalTrialNexus\models\user\trialgpt_local\matches_full_top100")
OUT_DIR = Path(r"C:\Ajesh_Drive\PersonalProjects\ClinicalTrialNexus\models\user\trialgpt_local\full_outputs")
OUT_DIR.mkdir(parents=True, exist_ok=True)

full_run_path = OUT_DIR / f"trialgpt_{RUN_TAG}_trec2021_top{TOP_K}.run"
coverage_path = OUT_DIR / f"trialgpt_{RUN_TAG}_trec2021_top{TOP_K}_coverage.csv"
table_path = OUT_DIR / f"trialgpt_{RUN_TAG}_trec2021_top{TOP_K}_table.csv"

def score_match(result_json: dict) -> float:
    # deterministic aggregator (same idea as before)
    score = 0.0
    inc = result_json.get("inclusion", [])
    exc = result_json.get("exclusion", [])

    for it in inc:
        lab = (it.get("label") or "").strip().lower()
        if lab == "met":
            score += 1.0
        elif lab == "not_met":
            score -= 1.0

    for it in exc:
        lab = (it.get("label") or "").strip().lower()
        if lab in {"triggers", "met"}:
            score -= 4.0
        elif lab in {"does_not_trigger", "not_met"}:
            score += 0.5

    return score

topic_ids = sorted(queries_df["topic_id"].astype(str).unique(), key=lambda x: int(x))

all_lines = []
table_rows = []
cov_rows = []

runname = f"trialgpt_{RUN_TAG}"

for topic_id in topic_ids:
    cands = (
        run_df[run_df["topic_id"].astype(str) == topic_id]
        .sort_values("rank")
        .head(TOP_K)[["doc_id", "rank", "score"]]
        .rename(columns={"doc_id": "trial_id", "rank": "bm25_rank", "score": "bm25_score"})
        .reset_index(drop=True)
    )

    topic_dir = CACHE_BASE / f"topic_{topic_id}"
    matched = []
    missing = 0

    for r in cands.itertuples(index=False):
        json_path = topic_dir / f"{r.trial_id}_{RUN_TAG}.json"
        if not json_path.exists():
            missing += 1
            continue

        with open(json_path, "r", encoding="utf-8") as f:
            mj = json.load(f)

        agg = score_match(mj)
        matched.append({
            "trial_id": r.trial_id,
            "agg_score": float(agg),
            "bm25_score": float(r.bm25_score),
            "bm25_rank": int(r.bm25_rank),
            "overall_assessment": mj.get("overall_assessment", ""),
            "notes": mj.get("notes", "")
        })

    cov_rows.append({
        "topic_id": topic_id,
        "candidates": len(cands),
        "matched_cached": len(matched),
        "missing_cached": missing
    })

    if not matched:
        continue

    ranked = (
        pd.DataFrame(matched)
        .sort_values(["agg_score", "bm25_score"], ascending=[False, False])
        .reset_index(drop=True)
    )

    # write TREC run lines for this topic
    for rank, row in enumerate(ranked.itertuples(index=False), start=1):
        all_lines.append(f"{topic_id} Q0 {row.trial_id} {rank} {row.agg_score:.6f} {runname}\n")

    # store table rows
    for row in ranked.itertuples(index=False):
        table_rows.append({
            "topic_id": topic_id,
            "trial_id": row.trial_id,
            "agg_score": row.agg_score,
            "bm25_score": row.bm25_score,
            "bm25_rank": row.bm25_rank,
            "overall_assessment": row.overall_assessment,
            "notes": row.notes
        })

with open(full_run_path, "w", encoding="utf-8") as f:
    f.writelines(all_lines)

coverage_df = pd.DataFrame(cov_rows)
table_df = pd.DataFrame(table_rows)

coverage_df.to_csv(coverage_path, index=False)
table_df.to_csv(table_path, index=False)

print("✅ Saved FULL reranked run:", full_run_path)
print("✅ Saved coverage report:", coverage_path)
print("✅ Saved rerank table:", table_path)

print("\nCoverage summary:")
print(coverage_df["matched_cached"].describe())
display(coverage_df.head(10))

✅ Saved FULL reranked run: C:\Ajesh_Drive\PersonalProjects\ClinicalTrialNexus\models\user\trialgpt_local\full_outputs\trialgpt_llama31_8b_full_top100_trec2021_top100.run
✅ Saved coverage report: C:\Ajesh_Drive\PersonalProjects\ClinicalTrialNexus\models\user\trialgpt_local\full_outputs\trialgpt_llama31_8b_full_top100_trec2021_top100_coverage.csv
✅ Saved rerank table: C:\Ajesh_Drive\PersonalProjects\ClinicalTrialNexus\models\user\trialgpt_local\full_outputs\trialgpt_llama31_8b_full_top100_trec2021_top100_table.csv

Coverage summary:
count     75.000000
mean      96.466667
std        2.616648
min       90.000000
25%       95.000000
50%       97.000000
75%       98.000000
max      100.000000
Name: matched_cached, dtype: float64


Unnamed: 0,topic_id,candidates,matched_cached,missing_cached
0,1,100,91,9
1,2,100,97,3
2,3,100,97,3
3,4,100,90,10
4,5,100,99,1
5,6,100,92,8
6,7,100,98,2
7,8,100,100,0
8,9,100,98,2
9,10,100,97,3


In [25]:
import pandas as pd
from pathlib import Path

TOP_K = 100
RUN_TAG = "llama31_8b_full_top100"

BM25_RUN_PATH = Path(r"C:\Ajesh_Drive\PersonalProjects\ClinicalTrialNexus\models\user\bm25_python_trec2021_top100.run")
TRIALGPT_RUN_PATH = Path(r"C:\Ajesh_Drive\PersonalProjects\ClinicalTrialNexus\models\user\trialgpt_local\full_outputs") / f"trialgpt_{RUN_TAG}_trec2021_top{TOP_K}.run"
OUT_METRICS_PATH = Path(r"C:\Ajesh_Drive\PersonalProjects\ClinicalTrialNexus\models\user\trialgpt_local\full_outputs") / f"trialgpt_{RUN_TAG}_trec2021_metrics.csv"

def load_run(path: Path) -> pd.DataFrame:
    rows = []
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            parts = line.strip().split()
            if len(parts) < 6:
                continue
            topic, _, docid, rank, score, runname = parts[:6]
            rows.append({"topic_id": str(topic), "doc_id": docid, "rank": int(rank)})
    return pd.DataFrame(rows)

def precision_recall_at_k(run_df: pd.DataFrame, qrels_df: pd.DataFrame, k: int):
    rel = (
        qrels_df[qrels_df["relevance"] > 0]
        .groupby("topic_id")["doc_id"]
        .apply(set)
        .to_dict()
    )

    topics = sorted(set(run_df["topic_id"]) & set(qrels_df["topic_id"].astype(str)))
    precisions, recalls = [], []

    for t in topics:
        retrieved = (
            run_df[run_df["topic_id"] == t]
            .sort_values("rank")
            .head(k)["doc_id"]
            .tolist()
        )
        relevant = rel.get(t, set())
        if not relevant:
            continue

        hits = sum(1 for d in retrieved if d in relevant)
        precisions.append(hits / k)
        recalls.append(hits / len(relevant))

    return float(sum(precisions) / len(precisions)), float(sum(recalls) / len(recalls))

bm25_df = load_run(BM25_RUN_PATH)
trialgpt_df = load_run(TRIALGPT_RUN_PATH)

p10_b, r10_b = precision_recall_at_k(bm25_df, qrels_df, 10)
p100_b, r100_b = precision_recall_at_k(bm25_df, qrels_df, 100)

p10_t, r10_t = precision_recall_at_k(trialgpt_df, qrels_df, 10)
p100_t, r100_t = precision_recall_at_k(trialgpt_df, qrels_df, 100)

metrics_df = pd.DataFrame([
    {"model": "BM25", "Precision@10": p10_b, "Recall@10": r10_b, "Precision@100": p100_b, "Recall@100": r100_b},
    {"model": f"TrialGPT-local ({RUN_TAG})", "Precision@10": p10_t, "Recall@10": r10_t, "Precision@100": p100_t, "Recall@100": r100_t},
])

metrics_df.to_csv(OUT_METRICS_PATH, index=False)

print("✅ Saved metrics:", OUT_METRICS_PATH)
display(metrics_df)

✅ Saved metrics: C:\Ajesh_Drive\PersonalProjects\ClinicalTrialNexus\models\user\trialgpt_local\full_outputs\trialgpt_llama31_8b_full_top100_trec2021_metrics.csv


Unnamed: 0,model,Precision@10,Recall@10,Precision@100,Recall@100
0,BM25,0.3,0.027917,0.148533,0.104029
1,TrialGPT-local (llama31_8b_full_top100),0.173333,0.012462,0.144533,0.101909
