In [2]:
import sys, platform
print("Python:", sys.version)
print("Platform:", platform.platform())

import numpy as np
import pandas as pd
from tqdm.auto import tqdm

import ir_datasets
from rank_bm25 import BM25Okapi

import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

DATASET_NAME = "clinicaltrials/2021/trec-ct-2022"
ds = ir_datasets.load(DATASET_NAME)

print("\nLoaded dataset:", DATASET_NAME)
print("Docs handler  :", ds.docs_handler())
print("Queries handler:", ds.queries_handler())

print("\nHas qrels?:", ds.has_qrels())

# Qrels schema + a few examples
if ds.has_qrels():
    print("Qrels definition:", ds.qrels_defs())
    qrels_preview = []
    for i, qrel in enumerate(ds.qrels_iter()):
        qrels_preview.append(qrel)
        if i >= 4:
            break
    print("First 5 qrels:", qrels_preview)

# Counts (if available)
print("\nCounts (may take time):")
print(" - #docs   :", ds.docs_count())
print(" - #queries:", ds.queries_count())
if ds.has_qrels():
    print(" - #qrels  :", ds.qrels_count())

print("\nTorch:")
print(" - torch version:", torch.__version__)
print(" - cuda available:", torch.cuda.is_available())
if torch.cuda.is_available():
    print(" - gpu:", torch.cuda.get_device_name(0))

Python: 3.13.9 (tags/v3.13.9:8183fa5, Oct 14 2025, 14:09:13) [MSC v.1944 64 bit (AMD64)]
Platform: Windows-11-10.0.26200-SP0

Loaded dataset: clinicaltrials/2021/trec-ct-2022
Docs handler  : <ir_datasets.datasets.clinicaltrials.ClinicalTrialsDocs object at 0x0000025E93107ED0>
Queries handler: <ir_datasets.formats.trec.TrecXmlQueries object at 0x0000025E93148F30>

Has qrels?: False

Counts (may take time):
 - #docs   : 375580
 - #queries: 50

Torch:
 - torch version: 2.9.1+cpu
 - cuda available: False


In [3]:
import os
from pathlib import Path
import urllib.request

DATASET_NAME = "clinicaltrials/2021/trec-ct-2022"
ds = ir_datasets.load(DATASET_NAME)

# ---- Load queries (topics) ----
queries = list(ds.queries_iter())
print("Loaded queries:", len(queries))
print("Query object example:", queries[0])

# ---- Qrels: use built-in if present, otherwise download qrels2022.txt ----
qrels = None

if ds.has_qrels():
    qrels = list(ds.qrels_iter())
    print("Loaded qrels from ir_datasets:", len(qrels))
else:
    # Download qrels2022.txt into your project (tiny file)
    qrels_dir = Path("data") / "trec2022"
    qrels_dir.mkdir(parents=True, exist_ok=True)
    qrels_path = qrels_dir / "qrels2022.txt"

    if not qrels_path.exists():
        url = "https://trec.nist.gov/data/trials/qrels2022.txt"
        print("Downloading qrels from:", url)
        urllib.request.urlretrieve(url, qrels_path)
        print("Saved to:", qrels_path.resolve())
    else:
        print("Using cached qrels file:", qrels_path.resolve())

    # Parse TREC qrels format: topic 0 doc_id rel
    # We'll keep only rel > 0 as relevant
    qrels = []
    with open(qrels_path, "r", encoding="utf-8") as f:
        for line in f:
            parts = line.strip().split()
            if len(parts) != 4:
                continue
            topic_id, _, doc_id, rel = parts
            qrels.append((topic_id, doc_id, int(rel)))

    print("Parsed qrels lines:", len(qrels))
    # quick preview
    print("First 5 qrels:", qrels[:5])

# Build a relevance dict: {topic_id: set(relevant_doc_ids)}
qrels_rel = {}
for topic_id, doc_id, rel in qrels:
    if rel > 0:
        qrels_rel.setdefault(str(topic_id), set()).add(str(doc_id))

print("\n#topics with >=1 relevant doc:", len(qrels_rel))
# show a couple topics
some_topics = list(qrels_rel.keys())[:3]
for t in some_topics:
    print(f"topic {t}: #relevant={len(qrels_rel[t])}")

[INFO] [starting] https://www.trec-cds.org/topics2022.xml
[INFO] [finished] https://www.trec-cds.org/topics2022.xml: [00:00] [32.4kB] [1.70MB/s]
                                                                    

Loaded queries: 50
Query object example: GenericQuery(query_id='1', text='\nA 19-year-old male came to clinic with some sexual concern.  He recently engaged in a relationship and is worried about the satisfaction of his girlfriend. He has a "baby face" according to his girlfriend\'s statement and he is not as muscular as his classmates.  On physical examination, there is some pubic hair and poorly developed secondary sexual characteristics. He is unable to detect coffee smell during the examination, but the visual acuity is normal. Ultrasound reveals the testes volume of 1-2 ml. The hormonal evaluation showed serum testosterone level of 65 ng/dL with low levels of GnRH.\n')
Downloading qrels from: https://trec.nist.gov/data/trials/qrels2022.txt
Saved to: C:\Ajesh_Drive\PersonalProjects\ClinicalTrialNexus\notebooks\data\trec2022\qrels2022.txt
Parsed qrels lines: 35394
First 5 qrels: [('1', 'NCT00000409', 0), ('1', 'NCT00001148', 0), ('1', 'NCT00001181', 0), ('1', 'NCT00001202', 0), ('1'

In [4]:
import os
from pathlib import Path
import ir_datasets

DATASET_NAME = "clinicaltrials/2021/trec-ct-2022"
ds = ir_datasets.load(DATASET_NAME)

# Show current working directory (important for paths)
print("CWD:", Path.cwd())

# Grab one doc example
doc = next(ds.docs_iter())
print("\nDoc type:", type(doc))
print("Doc id:", getattr(doc, "doc_id", None))

# Print all available fields on this doc object (non-private)
fields = [k for k in dir(doc) if not k.startswith("_")]
print("\nFields on doc object (truncated to first 40):")
print(fields[:40])

# Print a readable dict of field->value (truncate long strings)
def trunc(x, n=250):
    s = str(x)
    return s if len(s) <= n else s[:n] + " ...[truncated]..."

print("\nDoc field values (truncated):")
for k in fields:
    v = getattr(doc, k)
    # skip methods
    if callable(v):
        continue
    print(f"- {k}: {trunc(v)}")

CWD: c:\Ajesh_Drive\PersonalProjects\ClinicalTrialNexus\notebooks

Doc type: <class 'ir_datasets.datasets.clinicaltrials.ClinicalTrialsDoc'>
Doc id: NCT00000102

Fields on doc object (truncated to first 40):
['condition', 'count', 'detailed_description', 'doc_id', 'eligibility', 'index', 'summary', 'title']

Doc field values (truncated):
- condition: 
- detailed_description: 
    
      This protocol is designed to assess both acute and chronic effects of the calcium channel
      antagonist, nifedipine, on the hypothalamic-pituitary-adrenal axis in patients with
      congenital adrenal hyperplasia. The multicenter tr ...[truncated]...
- doc_id: NCT00000102
- eligibility: 
      
        Inclusion Criteria:

          -  diagnosed with Congenital Adrenal Hyperplasia (CAH)

          -  normal ECG during baseline evaluation

        Exclusion Criteria:

          -  history of liver disease, or elevated liver f ...[truncated]...
- summary: 
    
      This study will test the ability o

In [5]:
import re
from pathlib import Path
import ir_datasets
from tqdm.auto import tqdm

DATASET_NAME = "clinicaltrials/2021/trec-ct-2022"
ds = ir_datasets.load(DATASET_NAME)

# --- Fix paths: set project root as parent of notebooks/ ---
PROJECT_ROOT = Path.cwd().parent
print("PROJECT_ROOT:", PROJECT_ROOT)

# Where we will save artifacts later
BM25_INDEX_DIR = PROJECT_ROOT / "models" / "non_user" / "bm25" / "index"
BM25_RUNS_DIR  = PROJECT_ROOT / "models" / "non_user" / "bm25" / "runs"
BM25_INDEX_DIR.mkdir(parents=True, exist_ok=True)
BM25_RUNS_DIR.mkdir(parents=True, exist_ok=True)

# --- Basic normalization/tokenizer ---
_ws = re.compile(r"\s+")
_tok = re.compile(r"[A-Za-z0-9]+")

def normalize_text(s: str) -> str:
    if s is None:
        return ""
    s = str(s)
    s = s.replace("\n", " ").replace("\t", " ")
    s = _ws.sub(" ", s).strip().lower()
    return s

def tokenize(s: str):
    return _tok.findall(s.lower())

def doc_to_text(doc) -> str:
    parts = [
        getattr(doc, "title", ""),
        getattr(doc, "summary", ""),
        getattr(doc, "detailed_description", ""),
        getattr(doc, "eligibility", ""),
        getattr(doc, "condition", ""),
    ]
    return normalize_text(" ".join(p for p in parts if p))

# --- Build a SMALL sample corpus (first N docs) to validate ---
N_SAMPLE = 5000

doc_ids = []
tokenized_corpus = []

for i, doc in enumerate(tqdm(ds.docs_iter(), total=N_SAMPLE, desc="Reading docs (sample)")):
    if i >= N_SAMPLE:
        break
    doc_ids.append(doc.doc_id)
    tokenized_corpus.append(tokenize(doc_to_text(doc)))

print("\nSample built:")
print(" - docs:", len(doc_ids))
print(" - example doc_id:", doc_ids[0])
print(" - example token count:", len(tokenized_corpus[0]))
print(" - first 30 tokens:", tokenized_corpus[0][:30])

PROJECT_ROOT: c:\Ajesh_Drive\PersonalProjects\ClinicalTrialNexus


Reading docs (sample): 100%|██████████| 5000/5000 [00:00<00:00, 6147.34it/s]


Sample built:
 - docs: 5000
 - example doc_id: NCT00000102
 - example token count: 202
 - first 30 tokens: ['congenital', 'adrenal', 'hyperplasia', 'calcium', 'channels', 'as', 'therapeutic', 'targets', 'this', 'study', 'will', 'test', 'the', 'ability', 'of', 'extended', 'release', 'nifedipine', 'procardia', 'xl', 'a', 'blood', 'pressure', 'medication', 'to', 'permit', 'a', 'decrease', 'in', 'the']





In [6]:
from rank_bm25 import BM25Okapi
import ir_datasets

# Build BM25 on the sample tokenized corpus we created in Cell 4
bm25 = BM25Okapi(tokenized_corpus)

# Load queries again (topics)
ds = ir_datasets.load("clinicaltrials/2021/trec-ct-2022")
queries = list(ds.queries_iter())

def bm25_retrieve(query_text: str, topk: int = 10):
    q_tokens = tokenize(normalize_text(query_text))
    scores = bm25.get_scores(q_tokens)  # numpy array
    top_idx = np.argsort(scores)[::-1][:topk]
    results = [(doc_ids[i], float(scores[i])) for i in top_idx]
    return results

# Try 3 queries (1, 2, 3) against the 5k-doc sample
for q in queries[:3]:
    print("\n" + "="*80)
    print("TOPIC:", q.query_id)
    print("QUERY (first 220 chars):", normalize_text(q.text)[:220], "...")

    hits = bm25_retrieve(q.text, topk=10)
    print("\nTop-10 doc_ids + scores:")
    for rank, (did, sc) in enumerate(hits, start=1):
        print(f"{rank:2d}. {did}   score={sc:.4f}")


TOPIC: 1
QUERY (first 220 chars): a 19-year-old male came to clinic with some sexual concern. he recently engaged in a relationship and is worried about the satisfaction of his girlfriend. he has a "baby face" according to his girlfriend's statement and  ...

Top-10 doc_ids + scores:
 1. NCT00005669   score=165.3529
 2. NCT00000683   score=164.2495
 3. NCT00001202   score=163.4320
 4. NCT00000383   score=160.9882
 5. NCT00001181   score=159.1987
 6. NCT00001763   score=158.1683
 7. NCT00000387   score=157.2092
 8. NCT00005664   score=154.5284
 9. NCT00001412   score=154.3683
10. NCT00005784   score=154.1880

TOPIC: 2
QUERY (first 220 chars): a 32-year-old woman comes to the hospital with vaginal spotting. her last menstrual period was 10 weeks ago. she has regular menses lasting for 6 days and repeating every 29 days. medical history is significant for appen ...

Top-10 doc_ids + scores:
 1. NCT00000862   score=139.9716
 2. NCT00005669   score=129.7251
 3. NCT00000683   score=129.5634

In [7]:
from collections import defaultdict
import ir_datasets
import numpy as np

# Load queries
ds = ir_datasets.load("clinicaltrials/2021/trec-ct-2022")
queries = list(ds.queries_iter())

# If qrels_rel is not in memory (restart), re-load it from the file we downloaded:
# (If qrels_rel already exists from Cell 2, this won't hurt)
from pathlib import Path

PROJECT_ROOT = Path.cwd().parent
qrels_path = PROJECT_ROOT / "data" / "trec2022" / "qrels2022.txt"
if not qrels_path.exists():
    # fallback: if your file ended up under notebooks/data
    qrels_path = Path.cwd() / "data" / "trec2022" / "qrels2022.txt"

qrels_rel = {}
with open(qrels_path, "r", encoding="utf-8") as f:
    for line in f:
        parts = line.strip().split()
        if len(parts) != 4:
            continue
        topic_id, _, doc_id, rel = parts
        rel = int(rel)
        if rel > 0:
            qrels_rel.setdefault(str(topic_id), set()).add(str(doc_id))

def retrieve_topk_for_query(qtext: str, k: int):
    q_tokens = tokenize(normalize_text(qtext))
    scores = bm25.get_scores(q_tokens)
    top_idx = np.argsort(scores)[::-1][:k]
    return [doc_ids[i] for i in top_idx]

def precision_at_k(retrieved, relevant, k):
    retrieved_k = retrieved[:k]
    if k == 0:
        return 0.0
    hits = sum(1 for d in retrieved_k if d in relevant)
    return hits / k

def recall_at_k(retrieved, relevant, k):
    if not relevant:
        return 0.0
    retrieved_k = set(retrieved[:k])
    hits = len(retrieved_k.intersection(relevant))
    return hits / len(relevant)

# Evaluate across all 50 topics
Ks = [10, 100]
agg = {k: {"p": [], "r": []} for k in Ks}

missing_topics = 0
for q in queries:
    tid = str(q.query_id)
    relevant = qrels_rel.get(tid, set())
    if not relevant:
        missing_topics += 1
        continue

    retrieved_100 = retrieve_topk_for_query(q.text, k=max(Ks))  # get once

    for k in Ks:
        agg[k]["p"].append(precision_at_k(retrieved_100, relevant, k))
        agg[k]["r"].append(recall_at_k(retrieved_100, relevant, k))

print("Evaluation (BM25 on 5k-doc sample):")
print("Topics evaluated:", len(queries) - missing_topics, " / ", len(queries))
for k in Ks:
    print(f"\nK={k}")
    print(" Precision@k:", round(float(np.mean(agg[k]['p'])), 4))
    print(" Recall@k   :", round(float(np.mean(agg[k]['r'])), 4))

Evaluation (BM25 on 5k-doc sample):
Topics evaluated: 50  /  50

K=10
 Precision@k: 0.084
 Recall@k   : 0.0102

K=100
 Precision@k: 0.0122
 Recall@k   : 0.0139


In [8]:
import ir_datasets
from tqdm.auto import tqdm
from rank_bm25 import BM25Okapi
import numpy as np

ds = ir_datasets.load("clinicaltrials/2021/trec-ct-2022")

doc_ids_full = []
tokenized_corpus_full = []

print("Building full corpus...")
for doc in tqdm(ds.docs_iter(), total=ds.docs_count()):
    doc_ids_full.append(doc.doc_id)
    tokenized_corpus_full.append(tokenize(doc_to_text(doc)))

print("\nTotal docs indexed:", len(doc_ids_full))

bm25_full = BM25Okapi(tokenized_corpus_full)

print("BM25 full index built.")

Building full corpus...


100%|██████████| 375580/375580 [01:23<00:00, 4493.45it/s]



Total docs indexed: 375580
BM25 full index built.


In [9]:
ds = ir_datasets.load("clinicaltrials/2021/trec-ct-2022")
queries = list(ds.queries_iter())

# Reuse qrels_rel from earlier (reload from file if needed)
from pathlib import Path

PROJECT_ROOT = Path.cwd().parent
qrels_path = PROJECT_ROOT / "data" / "trec2022" / "qrels2022.txt"
if not qrels_path.exists():
    qrels_path = Path.cwd() / "data" / "trec2022" / "qrels2022.txt"

qrels_rel = {}
with open(qrels_path, "r", encoding="utf-8") as f:
    for line in f:
        parts = line.strip().split()
        if len(parts) != 4:
            continue
        topic_id, _, doc_id, rel = parts
        rel = int(rel)
        if rel > 0:
            qrels_rel.setdefault(str(topic_id), set()).add(str(doc_id))

def retrieve_topk_full(qtext: str, k: int):
    q_tokens = tokenize(normalize_text(qtext))
    scores = bm25_full.get_scores(q_tokens)
    top_idx = np.argsort(scores)[::-1][:k]
    return [doc_ids_full[i] for i in top_idx]

def precision_at_k(retrieved, relevant, k):
    retrieved_k = retrieved[:k]
    hits = sum(1 for d in retrieved_k if d in relevant)
    return hits / k

def recall_at_k(retrieved, relevant, k):
    if not relevant:
        return 0.0
    retrieved_k = set(retrieved[:k])
    hits = len(retrieved_k.intersection(relevant))
    return hits / len(relevant)

Ks = [10, 100]
agg = {k: {"p": [], "r": []} for k in Ks}

for q in queries:
    tid = str(q.query_id)
    relevant = qrels_rel.get(tid, set())
    retrieved = retrieve_topk_full(q.text, k=max(Ks))  # compute once
    for k in Ks:
        agg[k]["p"].append(precision_at_k(retrieved, relevant, k))
        agg[k]["r"].append(recall_at_k(retrieved, relevant, k))

print("Evaluation (BM25 FULL corpus):")
for k in Ks:
    print(f"\nK={k}")
    print(" Precision@k:", round(float(np.mean(agg[k]['p'])), 4))
    print(" Recall@k   :", round(float(np.mean(agg[k]['r'])), 4))

Evaluation (BM25 FULL corpus):

K=10
 Precision@k: 0.384
 Recall@k   : 0.0384

K=100
 Precision@k: 0.1842
 Recall@k   : 0.1393


In [10]:
import re
from typing import Optional, Tuple, Dict

# --- Query parsing: extract age/sex from patient narrative ---
_age_pat = re.compile(r"\b(\d{1,3})-year-old\b", re.IGNORECASE)
_male_pat = re.compile(r"\b(male|man|boy|gentleman|he|his)\b", re.IGNORECASE)
_female_pat = re.compile(r"\b(female|woman|girl|lady|she|her|pregnan)\w*\b", re.IGNORECASE)

def extract_user_age_sex(query_text: str) -> Tuple[Optional[int], Optional[str]]:
    """
    Returns (age, sex) where sex is 'M', 'F', or None.
    """
    text = query_text.strip()
    m = _age_pat.search(text)
    age = int(m.group(1)) if m else None

    # crude but effective heuristic
    male = bool(_male_pat.search(text))
    female = bool(_female_pat.search(text))

    sex = None
    if male and not female:
        sex = "M"
    elif female and not male:
        sex = "F"
    # if both or neither -> unknown
    return age, sex


# --- Eligibility parsing: infer allowed sex and age bounds from eligibility text ---
# We'll be conservative: if we can't parse, we won't filter it out.

_between_pat = re.compile(r"\bbetween\s+(\d{1,3})\s+and\s+(\d{1,3})\s+years?\b", re.IGNORECASE)
_age_ge_pat = re.compile(r"\b(\d{1,3})\s+years?\s+(?:of\s+age\s+)?and\s+older\b|\bat\s+least\s+(\d{1,3})\s+years?\b", re.IGNORECASE)
_age_le_pat = re.compile(r"\b(\d{1,3})\s+years?\s+(?:of\s+age\s+)?and\s+younger\b|\bno\s+more\s+than\s+(\d{1,3})\s+years?\b|\bup\s+to\s+(\d{1,3})\s+years?\b", re.IGNORECASE)

_males_only = re.compile(r"\b(males?\s+only|men\s+only|male\s+subjects?\s+only)\b", re.IGNORECASE)
_females_only = re.compile(r"\b(females?\s+only|women\s+only|female\s+subjects?\s+only)\b", re.IGNORECASE)
_males = re.compile(r"\b(males?|men)\b", re.IGNORECASE)
_females = re.compile(r"\b(females?|women)\b", re.IGNORECASE)

def parse_eligibility_constraints(elig_text: str) -> Dict[str, Optional[object]]:
    """
    Returns dict: {'min_age': int|None, 'max_age': int|None, 'sex': 'M'|'F'|'ALL'|None}
    sex meaning:
      - 'M'   => males only
      - 'F'   => females only
      - 'ALL' => both sexes allowed
      - None  => unknown / can't infer
    """
    if not elig_text:
        return {"min_age": None, "max_age": None, "sex": None}

    t = " ".join(str(elig_text).split())

    # Sex parsing
    sex = None
    if _males_only.search(t):
        sex = "M"
    elif _females_only.search(t):
        sex = "F"
    else:
        has_m = bool(_males.search(t))
        has_f = bool(_females.search(t))
        # if both words appear, likely both allowed (not always, but okay baseline)
        if has_m and has_f:
            sex = "ALL"
        elif has_m and not has_f:
            # could be male-only or just mentions men; keep unknown unless explicit
            sex = None
        elif has_f and not has_m:
            sex = None

    # Age parsing
    min_age = None
    max_age = None

    m = _between_pat.search(t)
    if m:
        min_age = int(m.group(1))
        max_age = int(m.group(2))

    m = _age_ge_pat.search(t)
    if m:
        # pattern has 2 groups, only one will be non-None
        g1 = m.group(1)
        g2 = m.group(2)
        mn = int(g1 or g2)
        min_age = mn if (min_age is None or mn > min_age) else min_age

    m = _age_le_pat.search(t)
    if m:
        # multiple alt groups; pick the first non-None
        mx = next((g for g in m.groups() if g is not None), None)
        if mx is not None:
            mx = int(mx)
            max_age = mx if (max_age is None or mx < max_age) else max_age

    return {"min_age": min_age, "max_age": max_age, "sex": sex}


def passes_hard_filters(user_age: Optional[int], user_sex: Optional[str], constraints: Dict[str, Optional[object]]) -> bool:
    """
    Conservative filter:
      - If we can't parse constraint, we don't exclude.
      - If user age/sex unknown, don't exclude.
    """
    mn, mx, sx = constraints["min_age"], constraints["max_age"], constraints["sex"]

    if user_age is not None:
        if mn is not None and user_age < mn:
            return False
        if mx is not None and user_age > mx:
            return False

    if user_sex is not None and sx in ("M", "F"):
        if user_sex != sx:
            return False

    return True


# --- Sanity check on first 3 topics ---
for q in queries[:3]:
    age, sex = extract_user_age_sex(q.text)
    print(f"TOPIC {q.query_id}: age={age}, sex={sex}")

# --- Sanity check on 5 retrieved docs from topic 1 (from your earlier BM25 top10 list)
# We'll fetch those docs and see what constraints we parse.
doc_map = {}  # lazy load doc objects we need
docs_iter = ds.docs_iter()
# Build a quick lookup for a handful of doc_ids (efficient enough for small count)
needed = set([did for did, _ in bm25_retrieve(queries[0].text, topk=10)])
for doc in ds.docs_iter():
    if doc.doc_id in needed:
        doc_map[doc.doc_id] = doc
        if len(doc_map) == len(needed):
            break

print("\nParsed constraints for BM25 top-10 of topic 1:")
u_age, u_sex = extract_user_age_sex(queries[0].text)
for did, sc in bm25_retrieve(queries[0].text, topk=10):
    doc = doc_map.get(did)
    c = parse_eligibility_constraints(getattr(doc, "eligibility", "") if doc else "")
    ok = passes_hard_filters(u_age, u_sex, c)
    print(f"{did} score={sc:.2f}  constraints={c}  passes={ok}")

TOPIC 1: age=19, sex=None
TOPIC 2: age=32, sex=None
TOPIC 3: age=51, sex=M

Parsed constraints for BM25 top-10 of topic 1:
NCT00005669 score=165.35  constraints={'min_age': None, 'max_age': None, 'sex': None}  passes=True
NCT00000683 score=164.25  constraints={'min_age': None, 'max_age': None, 'sex': None}  passes=True
NCT00001202 score=163.43  constraints={'min_age': None, 'max_age': None, 'sex': None}  passes=True
NCT00000383 score=160.99  constraints={'min_age': None, 'max_age': None, 'sex': None}  passes=True
NCT00001181 score=159.20  constraints={'min_age': None, 'max_age': None, 'sex': None}  passes=True
NCT00001763 score=158.17  constraints={'min_age': None, 'max_age': None, 'sex': None}  passes=True
NCT00000387 score=157.21  constraints={'min_age': None, 'max_age': None, 'sex': None}  passes=True
NCT00005664 score=154.53  constraints={'min_age': None, 'max_age': None, 'sex': 'ALL'}  passes=True
NCT00001412 score=154.37  constraints={'min_age': None, 'max_age': None, 'sex': None

In [11]:
import re
import numpy as np
from tqdm.auto import tqdm

# --- Better query sex detection: prioritize explicit "male"/"female" ---
_explicit_male = re.compile(r"\bmale\b", re.IGNORECASE)
_explicit_female = re.compile(r"\bfemale\b", re.IGNORECASE)

def extract_user_age_sex(query_text: str):
    text = query_text.strip()

    # age
    m = re.search(r"\b(\d{1,3})-year-old\b", text, flags=re.IGNORECASE)
    age = int(m.group(1)) if m else None

    # sex (explicit first)
    if _explicit_male.search(text) and not _explicit_female.search(text):
        sex = "M"
    elif _explicit_female.search(text) and not _explicit_male.search(text):
        sex = "F"
    else:
        # fallback pronoun heuristic
        male = bool(re.search(r"\b(he|his|man|boy|gentleman)\b", text, flags=re.IGNORECASE))
        female = bool(re.search(r"\b(she|her|woman|girl|lady|pregnan)\w*\b", text, flags=re.IGNORECASE))
        if male and not female:
            sex = "M"
        elif female and not male:
            sex = "F"
        else:
            sex = None

    return age, sex

# quick sanity check again
for q in queries[:3]:
    age, sex = extract_user_age_sex(q.text)
    print(f"TOPIC {q.query_id}: age={age}, sex={sex}")

# --- Hard-filtered retrieval over BM25 full index ---
K_CANDIDATES = 200  # retrieve from BM25
K_EVALS = [10, 100]

def bm25_retrieve_full(query_text: str, topk: int):
    q_tokens = tokenize(normalize_text(query_text))
    scores = bm25_full.get_scores(q_tokens)
    top_idx = np.argsort(scores)[::-1][:topk]
    return [doc_ids_full[i] for i in top_idx]

def filtered_retrieve(query_text: str, user_age, user_sex, topk_bm25=200, max_keep=100):
    """
    Retrieve topk_bm25 by BM25, then filter by eligibility constraints.
    Keep up to max_keep after filtering (rank order preserved).
    """
    cand_ids = bm25_retrieve_full(query_text, topk_bm25)

    kept = []
    for did in cand_ids:
        doc = doc_store.get(did)
        if doc is None:
            continue
        constraints = parse_eligibility_constraints(getattr(doc, "eligibility", ""))
        if passes_hard_filters(user_age, user_sex, constraints):
            kept.append(did)
        if len(kept) >= max_keep:
            break
    return kept

# --- Build a doc_store for fast access (id -> doc object) ---
# This is a bit heavy but manageable; we need doc text anyway for BioBERT next.
doc_store = {}
for doc in tqdm(ds.docs_iter(), total=ds.docs_count(), desc="Building doc_store"):
    doc_store[doc.doc_id] = doc

def precision_at_k(retrieved, relevant, k):
    hits = sum(1 for d in retrieved[:k] if d in relevant)
    return hits / k

def recall_at_k(retrieved, relevant, k):
    if not relevant:
        return 0.0
    hits = len(set(retrieved[:k]).intersection(relevant))
    return hits / len(relevant)

# Evaluate hard-filtered BM25
agg = {k: {"p": [], "r": []} for k in K_EVALS}

for q in tqdm(queries, desc="Evaluating hard-filtered BM25"):
    tid = str(q.query_id)
    relevant = qrels_rel.get(tid, set())

    user_age, user_sex = extract_user_age_sex(q.text)
    retrieved = filtered_retrieve(q.text, user_age, user_sex, topk_bm25=K_CANDIDATES, max_keep=max(K_EVALS))

    # If filter is too strict and returns few docs, pad with unfiltered BM25 so metrics are comparable
    if len(retrieved) < max(K_EVALS):
        pad = bm25_retrieve_full(q.text, topk=max(K_EVALS))
        seen = set(retrieved)
        for d in pad:
            if d not in seen:
                retrieved.append(d)
                seen.add(d)
            if len(retrieved) >= max(K_EVALS):
                break

    for k in K_EVALS:
        agg[k]["p"].append(precision_at_k(retrieved, relevant, k))
        agg[k]["r"].append(recall_at_k(retrieved, relevant, k))

print("\nEvaluation (BM25 + hard filters on FULL corpus):")
for k in K_EVALS:
    print(f"\nK={k}")
    print(" Precision@k:", round(float(np.mean(agg[k]['p'])), 4))
    print(" Recall@k   :", round(float(np.mean(agg[k]['r'])), 4))


TOPIC 1: age=19, sex=M
TOPIC 2: age=32, sex=M
TOPIC 3: age=51, sex=M


Building doc_store: 100%|██████████| 375580/375580 [00:04<00:00, 77922.65it/s] 
Evaluating hard-filtered BM25: 100%|██████████| 50/50 [07:45<00:00,  9.30s/it]


Evaluation (BM25 + hard filters on FULL corpus):

K=10
 Precision@k: 0.378
 Recall@k   : 0.0372

K=100
 Precision@k: 0.182
 Recall@k   : 0.137





In [None]:
from pathlib import Path

PROJECT_ROOT = Path.cwd().parent

BM25_INDEX_DIR = PROJECT_ROOT / "models" / "non_user" / "bm25" / "index"
BM25_RUNS_DIR  = PROJECT_ROOT / "models" / "non_user" / "bm25" / "runs"
EVAL_DIR       = PROJECT_ROOT / "tests" / "outputs"

BM25_INDEX_DIR.mkdir(parents=True, exist_ok=True)
BM25_RUNS_DIR.mkdir(parents=True, exist_ok=True)
EVAL_DIR.mkdir(parents=True, exist_ok=True)

print("PROJECT_ROOT:", PROJECT_ROOT)
print("BM25_INDEX_DIR:", BM25_INDEX_DIR)
print("BM25_RUNS_DIR :", BM25_RUNS_DIR)
print("EVAL_DIR      :", EVAL_DIR)

PROJECT_ROOT: c:\Ajesh_Drive\PersonalProjects\ClinicalTrialNexus
BM25_INDEX_DIR: c:\Ajesh_Drive\PersonalProjects\ClinicalTrialNexus\models\non_user\bm25\index
BM25_RUNS_DIR : c:\Ajesh_Drive\PersonalProjects\ClinicalTrialNexus\models\non_user\bm25\runs
EVAL_DIR      : c:\Ajesh_Drive\PersonalProjects\ClinicalTrialNexus\tests\outputs


In [13]:
import pickle

bm25_path = BM25_INDEX_DIR / "bm25_full.pkl"

with open(bm25_path, "wb") as f:
    pickle.dump(
        {
            "bm25": bm25_full,
            "doc_ids": doc_ids_full
        },
        f,
        protocol=pickle.HIGHEST_PROTOCOL
    )

print("BM25 index saved to:", bm25_path)
print("Total docs saved:", len(doc_ids_full))

BM25 index saved to: c:\Ajesh_Drive\PersonalProjects\ClinicalTrialNexus\models\non_user\bm25\index\bm25_full.pkl
Total docs saved: 375580


In [14]:
import ir_datasets
import numpy as np

ds = ir_datasets.load("clinicaltrials/2021/trec-ct-2022")
queries = list(ds.queries_iter())

RUN_K = 100
run_path = BM25_RUNS_DIR / "bm25_trec2022_top100.run"

with open(run_path, "w", encoding="utf-8") as f:
    for q in queries:
        qid = q.query_id
        q_tokens = tokenize(normalize_text(q.text))
        scores = bm25_full.get_scores(q_tokens)
        top_idx = np.argsort(scores)[::-1][:RUN_K]

        for rank, i in enumerate(top_idx, start=1):
            doc_id = doc_ids_full[i]
            score = float(scores[i])
            # TREC format: topic Q0 doc rank score tag
            f.write(f"{qid} Q0 {doc_id} {rank} {score:.6f} BM25\n")

print("BM25 run file saved to:", run_path)

BM25 run file saved to: c:\Ajesh_Drive\PersonalProjects\ClinicalTrialNexus\models\non_user\bm25\runs\bm25_trec2022_top100.run


In [15]:
metrics = [
    {"metric": "Precision@10",  "value": np.mean(agg[10]["p"])},
    {"metric": "Recall@10",     "value": np.mean(agg[10]["r"])},
    {"metric": "Precision@100", "value": np.mean(agg[100]["p"])},
    {"metric": "Recall@100",    "value": np.mean(agg[100]["r"])},
]

df_metrics = pd.DataFrame(metrics)

metrics_path = EVAL_DIR / "bm25_trec2022_metrics.csv"
df_metrics.to_csv(metrics_path, index=False)

print("Metrics saved to:", metrics_path)
df_metrics

Metrics saved to: c:\Ajesh_Drive\PersonalProjects\ClinicalTrialNexus\tests\outputs\bm25_trec2022_metrics.csv


Unnamed: 0,metric,value
0,Precision@10,0.378
1,Recall@10,0.037235
2,Precision@100,0.182
3,Recall@100,0.136989


In [16]:
config = {
    "dataset": "clinicaltrials/2021/trec-ct-2022",
    "retriever": "BM25",
    "doc_fields": [
        "title",
        "summary",
        "detailed_description",
        "eligibility",
        "condition"
    ],
    "bm25_k_eval": [10, 100],
    "notes": "Baseline non-user retrieval model"
}

import json

config_path = EVAL_DIR / "bm25_trec2022_config.json"
with open(config_path, "w", encoding="utf-8") as f:
    json.dump(config, f, indent=2)

print("Config saved to:", config_path)

Config saved to: c:\Ajesh_Drive\PersonalProjects\ClinicalTrialNexus\tests\outputs\bm25_trec2022_config.json


In [17]:
import numpy as np
from tqdm.auto import tqdm

# Parameters
BM25_CANDIDATES = 100   # rerank top-100
BIOBERT_TOPK = 10       # final top-k after rerank

# Helper: get BM25 candidates
def bm25_candidates(query_text, topk=BM25_CANDIDATES):
    q_tokens = tokenize(normalize_text(query_text))
    scores = bm25_full.get_scores(q_tokens)
    idx = np.argsort(scores)[::-1][:topk]
    return [(doc_ids_full[i], float(scores[i])) for i in idx]

# Build text lookup once (doc_id -> full text)
doc_text_map = {}

for doc in tqdm(ds.docs_iter(), total=ds.docs_count(), desc="Building doc_text_map"):
    doc_text_map[doc.doc_id] = doc_to_text(doc)

print("Doc text map built:", len(doc_text_map))
print("Sample doc text length:", len(next(iter(doc_text_map.values()))))

Building doc_text_map: 100%|██████████| 375580/375580 [00:45<00:00, 8282.92it/s]

Doc text map built: 375580
Sample doc text length: 1367





In [18]:
from sentence_transformers import SentenceTransformer
import torch

print("Torch:", torch.__version__)
print("CUDA available:", torch.cuda.is_available())

EMBED_MODEL = "pritamdeka/S-BioBERT-snli-multinli-stsb"

embedder = SentenceTransformer(
    EMBED_MODEL,
    device="cpu"
)

print("Loaded embedder:", EMBED_MODEL)

Torch: 2.9.1+cpu
CUDA available: False


Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. Falling back to regular HTTP download. For better performance, install the package with: `pip install huggingface_hub[hf_xet]` or `pip install hf_xet`


Loaded embedder: pritamdeka/S-BioBERT-snli-multinli-stsb


In [19]:
import sklearn.metrics.pairwise
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np

def biobert_rerank(query_text, bm25_hits, topk=BIOBERT_TOPK):
    """
    bm25_hits: list of (doc_id, bm25_score)
    """
    docs = [doc_text_map[did] for did, _ in bm25_hits]

    # Encode query + docs
    q_emb = embedder.encode([query_text], convert_to_numpy=True)
    d_emb = embedder.encode(docs, convert_to_numpy=True, show_progress_bar=False)

    # Cosine similarity
    sims = cosine_similarity(q_emb, d_emb)[0]

    # Rerank
    idx = np.argsort(sims)[::-1][:topk]
    reranked = [(bm25_hits[i][0], float(sims[i])) for i in idx]

    return reranked

# Quick sanity check on first topic
q0 = queries[0]
bm25_hits = bm25_candidates(q0.text)
reranked = biobert_rerank(q0.text, bm25_hits)

print("BioBERT reranked top-10 for topic", q0.query_id)
for i, (did, score) in enumerate(reranked, 1):
    print(f"{i:2d}. {did}  score={score:.4f}")


BioBERT reranked top-10 for topic 1
 1. NCT03459326  score=0.4864
 2. NCT02014584  score=0.3576
 3. NCT04630275  score=0.3340
 4. NCT03149692  score=0.3213
 5. NCT00494208  score=0.3102
 6. NCT02777242  score=0.3043
 7. NCT00194636  score=0.3036
 8. NCT04049331  score=0.3001
 9. NCT01689896  score=0.2760
10. NCT00644163  score=0.2648


In [20]:
from tqdm.auto import tqdm
import numpy as np

K_EVAL = [10, 100]
agg_biobert = {k: {"p": [], "r": []} for k in K_EVAL}

def precision_at_k(retrieved, relevant, k):
    return sum(1 for d in retrieved[:k] if d in relevant) / k

def recall_at_k(retrieved, relevant, k):
    if not relevant:
        return 0.0
    return len(set(retrieved[:k]).intersection(relevant)) / len(relevant)

for q in tqdm(queries, desc="Evaluating BioBERT rerank"):
    tid = str(q.query_id)
    relevant = qrels_rel.get(tid, set())

    bm25_hits = bm25_candidates(q.text)
    reranked = biobert_rerank(q.text, bm25_hits, topk=max(K_EVAL))
    retrieved_ids = [d for d, _ in reranked]

    for k in K_EVAL:
        agg_biobert[k]["p"].append(precision_at_k(retrieved_ids, relevant, k))
        agg_biobert[k]["r"].append(recall_at_k(retrieved_ids, relevant, k))

print("\nEvaluation (BM25 + BioBERT rerank):")
for k in K_EVAL:
    print(f"\nK={k}")
    print(" Precision@k:", round(float(np.mean(agg_biobert[k]['p'])), 4))
    print(" Recall@k   :", round(float(np.mean(agg_biobert[k]['r'])), 4))

Evaluating BioBERT rerank: 100%|██████████| 50/50 [11:33<00:00, 13.86s/it]


Evaluation (BM25 + BioBERT rerank):

K=10
 Precision@k: 0.31
 Recall@k   : 0.0266

K=100
 Precision@k: 0.1842
 Recall@k   : 0.1393





In [21]:
from pathlib import Path
import numpy as np
from tqdm.auto import tqdm

PROJECT_ROOT = Path.cwd().parent
BIO_RUN_DIR = PROJECT_ROOT / "models" / "non_user" / "biobert" / "reranked_runs"
BIO_RUN_DIR.mkdir(parents=True, exist_ok=True)

run_path = BIO_RUN_DIR / "biobert_rerank_trec2022_top100.run"

RUN_K = 100

with open(run_path, "w", encoding="utf-8") as f:
    for q in tqdm(queries, desc="Writing BioBERT run"):
        qid = q.query_id
        bm25_hits = bm25_candidates(q.text, topk=RUN_K)  # (doc_id, bm25_score)
        reranked = biobert_rerank(q.text, bm25_hits, topk=RUN_K)  # (doc_id, biobert_score)

        for rank, (doc_id, score) in enumerate(reranked, start=1):
            f.write(f"{qid} Q0 {doc_id} {rank} {score:.6f} BioBERT_biencoder\n")

print("Saved BioBERT run file to:", run_path)


Writing BioBERT run: 100%|██████████| 50/50 [24:19<00:00, 29.20s/it]

Saved BioBERT run file to: c:\Ajesh_Drive\PersonalProjects\ClinicalTrialNexus\models\non_user\biobert\reranked_runs\biobert_rerank_trec2022_top100.run





In [22]:
import pandas as pd
import numpy as np
from pathlib import Path

PROJECT_ROOT = Path.cwd().parent
EVAL_DIR = PROJECT_ROOT / "tests" / "outputs"
EVAL_DIR.mkdir(parents=True, exist_ok=True)

metrics = [
    {"metric": "Precision@10",  "value": float(np.mean(agg_biobert[10]["p"]))},
    {"metric": "Recall@10",     "value": float(np.mean(agg_biobert[10]["r"]))},
    {"metric": "Precision@100", "value": float(np.mean(agg_biobert[100]["p"]))},
    {"metric": "Recall@100",    "value": float(np.mean(agg_biobert[100]["r"]))},
]

df_metrics = pd.DataFrame(metrics)
metrics_path = EVAL_DIR / "biobert_rerank_trec2022_metrics.csv"
df_metrics.to_csv(metrics_path, index=False)

print("Saved BioBERT metrics to:", metrics_path)
df_metrics


Saved BioBERT metrics to: c:\Ajesh_Drive\PersonalProjects\ClinicalTrialNexus\tests\outputs\biobert_rerank_trec2022_metrics.csv


Unnamed: 0,metric,value
0,Precision@10,0.31
1,Recall@10,0.026642
2,Precision@100,0.1842
3,Recall@100,0.139303
