## Imports & constants

In [None]:
import requests
import time
import json
import threading
from queue import Queue
from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path
from typing import List, Dict, Any, Union

In [None]:
# Base URL shared by the NCBI E-utilities endpoints called from this notebook.
ENTREZ_BASE: str = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils"
# Value sent as the `tool` query parameter so NCBI can identify this client.
TOOL: str = "health-harvester"

## Storage directory configuration

In [None]:
# Root folder for harvested data: a "storage" directory two levels above the CWD.
STORAGE_DIR = Path.cwd().parent.parent.joinpath("storage")
STORAGE_DIR.mkdir(exist_ok=True)

# Seed JSON files are read from this sub-folder (it is not created here).
SEED_DIR = STORAGE_DIR.joinpath("seeds")
SEED_DIR  # bare expression: a notebook cell displays the resolved path

## Stats class

In [None]:
class SearchStats:
    """Thread-safe running counters for a PubMed search run.

    Workers call :meth:`log_term` once per searched term and :meth:`log_seed`
    once per finished seed. A one-line summary is printed every
    ``summary_every`` seeds (0 disables the periodic print) and again when
    ``seeds_done`` reaches ``total_seeds``.
    """

    def __init__(self, total_seeds: int, summary_every: int = 10):
        self.start = time.time()
        self.total_seeds = total_seeds
        self.summary_every = summary_every
        self.seeds_done = 0
        self.terms_searched = 0
        self.total_pmids = 0
        self.failed_terms = 0
        # Re-entrant: print_summary() takes the lock itself so it is safe to call
        # from any thread, and log_seed() calls it while already holding the lock.
        self.lock = threading.RLock()

    def log_term(self, pmid_count: int, success: bool = True):
        """Record one searched term.

        Successful terms add ``pmid_count`` (``None`` counts as 0) to
        ``total_pmids``; failed terms increment ``failed_terms`` instead.
        """
        with self.lock:
            self.terms_searched += 1
            if success:
                self.total_pmids += int(pmid_count or 0)
            else:
                self.failed_terms += 1

    def log_seed(self):
        """Record one finished seed and print a summary when one is due."""
        with self.lock:
            self.seeds_done += 1
            # bool() guard avoids a modulo-by-zero when periodic summaries are disabled.
            periodic = bool(self.summary_every) and self.seeds_done % self.summary_every == 0
            if periodic or self.seeds_done == self.total_seeds:
                self.print_summary()

    def print_summary(self):
        """Print a consistent one-line snapshot of all counters."""
        with self.lock:
            elapsed = time.time() - self.start
            # Clamp the divisor to 1s so the rate stays finite right after start-up.
            rate = self.terms_searched / max(elapsed, 1.0)
            avg_pmids = (self.total_pmids / self.terms_searched) if self.terms_searched else 0
            print(
                f"[SEARCH STATS] seeds={self.seeds_done}/{self.total_seeds} | "
                f"terms={self.terms_searched} | pmids_total={self.total_pmids} | "
                f"avg_pmids/term={avg_pmids:.1f} | failed_terms={self.failed_terms} | "
                f"{rate:.2f} terms/sec | {elapsed:.1f}s elapsed"
            )

## API key + email configuration

In [None]:
# SECURITY NOTE(review): NCBI API keys and personal email addresses are committed
# here in plain text. Rotate these keys and load them from an environment variable
# or secret store instead of keeping them in the notebook.
_PUBMED_KEY_EMAIL_PAIRS = [
    ("ec74621abe110994f710510d05aa0780d607", "abamsheikh@gmail.com"),
    ("fc54b7ce2e97e3cd4506c9a780d8c5e3c208", "prashant.211528@ncit.edu.np"),
    ("7441cfd3f83858837e75dd0d746419db1408", "prashantchhetrii465@gmail.com"),
    ("8e301d66b6848df9d80acf3082c910a1f308", "aman.21506@ncit.edu.np"),
    ("8e301d66b6848df9d80acf3082c910a1f308", "abamsheikh1@gmail.com"),
    ("d136942c77dd9ef93a20642fa15d9dbf8308", "abamsheikh1@gmail.com"),
]

# NCBI's request-rate limit applies per API key. A key listed twice risks being
# throttled as if it were two separate keys, exceeding that limit, so keep only
# the first (key, email) entry for each key, in original order.
_first_email_by_key = {}
for _key, _email in _PUBMED_KEY_EMAIL_PAIRS:
    _first_email_by_key.setdefault(_key, _email)
PUBMED_KEYS = list(_first_email_by_key.items())

## PubMed search function

In [None]:
def pubmed_search(query: str, retstart=0, retmax=20, api_key=None, email=None):
    """Run one ESearch request against the PubMed database and return the parsed JSON.

    ``retstart``/``retmax`` select the window of IDs returned. ``api_key`` is only
    sent when truthy. A 4xx/5xx response raises via ``raise_for_status()``; the
    request times out after 30 seconds.
    """
    endpoint = ENTREZ_BASE + "/esearch.fcgi"
    query_params = dict(
        db="pubmed",
        term=query,
        retmode="json",
        retstart=retstart,
        retmax=retmax,
        tool=TOOL,
        email=email,
    )
    if api_key:
        query_params["api_key"] = api_key

    response = requests.get(endpoint, params=query_params, timeout=30)
    response.raise_for_status()
    return response.json()

## API key rate-limiter

In [None]:
class ApiKeyRateLimiter:
    """Serialize calls made with one API key and space them at least ``delay`` seconds apart.

    The wrapped callable receives this limiter's ``api_key`` and ``email`` as
    keyword arguments.
    """

    def __init__(self, api_key, email, delay=0.11):
        self.api_key = api_key
        self.email = email
        self.delay = delay
        self.lock = threading.Lock()
        # time.monotonic() reading taken when the previous call finished;
        # -inf means "never called", so the first call never waits.
        self.last_call = float("-inf")

    def call(self, fn, *args, **kwargs):
        """Return ``fn(*args, api_key=..., email=..., **kwargs)`` once the delay has elapsed.

        The lock is held for the whole call, so calls through one limiter never
        overlap. The delay is measured on the monotonic clock, which is immune
        to wall-clock adjustments. The end-of-call timestamp is recorded even
        when ``fn`` raises, so a failed request (e.g. an HTTP 429) still delays
        the next one. Exceptions from ``fn`` propagate unchanged.
        """
        with self.lock:
            wait = self.delay - (time.monotonic() - self.last_call)
            if wait > 0:
                time.sleep(wait)
            try:
                return fn(*args, api_key=self.api_key, email=self.email, **kwargs)
            finally:
                self.last_call = time.monotonic()

In [None]:
def build_api_key_pool(key_email_pairs, delay_per_key=0.11):
    """Return a ``Queue`` holding one ``ApiKeyRateLimiter`` per distinct API key.

    Workers borrow a limiter with ``get()`` and return it with ``put()``, so the
    queue doubles as a pool of available keys.

    A key that appears more than once is collapsed into a single limiter, using
    its first email. Two independent limiters on one key would together exceed
    that key's intended rate.

    Raises:
        ValueError: ``key_email_pairs`` is empty. An empty pool would make every
            ``Queue.get()`` block forever.
    """
    first_email_by_key = {}
    for key, email in key_email_pairs:
        first_email_by_key.setdefault(key, email)
    if not first_email_by_key:
        raise ValueError("key_email_pairs must contain at least one (api_key, email) pair")

    pool = Queue()
    for key, email in first_email_by_key.items():
        pool.put(ApiKeyRateLimiter(key, email, delay_per_key))
    return pool

In [None]:
def load_seed_files_from_dir(seeds_dir: Union[str, Path]) -> List[Dict[str, Any]]:
    """Load seed dicts from every ``*.json`` file directly inside *seeds_dir*.

    A file may hold one seed (a JSON object) or a list of seeds; non-object list
    items are ignored. Each returned seed gets ``_seed_file_path`` set to the file
    it came from, so it can be written back later. Files that cannot be read or
    parsed are reported with a ``[WARN]`` line and skipped.

    Files are processed in sorted path order so runs are reproducible.
    """
    seeds_dir = Path(seeds_dir)
    seeds: List[Dict[str, Any]] = []
    for path in sorted(seeds_dir.glob("*.json")):
        try:
            with open(path, "r", encoding="utf-8") as fh:
                data = json.load(fh)
        except (OSError, ValueError) as e:  # ValueError covers JSON and UTF-8 decode errors
            print(f"[WARN] failed to load seed file {path}: {e}")
            continue

        if isinstance(data, dict):
            records = [data]
        elif isinstance(data, list):
            records = data
        else:
            records = []

        for record in records:
            if isinstance(record, dict):
                # Always overwrite: a seed persisted by an earlier run may carry a
                # stale path that no longer points at this file.
                record["_seed_file_path"] = str(path)
                seeds.append(record)
    return seeds

## MeSH-aware query builder

In [None]:
def build_mesh_aware_query(term: str) -> str:
    """Build a PubMed query matching *term* as a MeSH term or in title/abstract.

    *term* is inserted inside double quotes. An embedded ``"`` would close the
    phrase early and produce a malformed query, so double quotes are removed.
    Surrounding whitespace is also trimmed.

    >>> build_mesh_aware_query("asthma")
    '("asthma"[MeSH Terms] OR "asthma"[Title/Abstract])'
    """
    phrase = str(term).replace('"', "").strip()
    return f'("{phrase}"[MeSH Terms] OR "{phrase}"[Title/Abstract])'

## Single-term PubMed search worker

In [None]:
def search_term_with_pool(term, retmax, key_pool, stats: "SearchStats" = None):
    """Search PubMed for one term with a limiter borrowed from *key_pool*.

    Blocks on ``key_pool.get()`` until a limiter is free and always returns it
    to the pool. Returns a log dict with ``term``, ``query``, ``count``,
    ``pmids``, ``api_key_used`` (last 6 chars of the key), ``email_used`` and a
    UTC ``retrieved_at`` timestamp. On failure, ``count`` is 0, ``pmids`` is empty
    and an ``error`` message is added. Search failures are recorded rather than
    raised. If *stats* is given, ``stats.log_term`` is called once with the
    outcome.
    """
    limiter = key_pool.get()
    try:
        query = build_mesh_aware_query(term)
        entry = {"term": term, "query": query}
        try:
            res = limiter.call(pubmed_search, query, 0, retmax)
            es = res.get("esearchresult", {}) if isinstance(res, dict) else {}
            # NOTE(review): ESearch can answer HTTP 200 with an "ERROR" field
            # inside esearchresult; treat that as a failed term, not a 0-hit search.
            if es.get("ERROR"):
                raise RuntimeError(f"ESearch error: {es['ERROR']}")
            count = int(es.get("count", "0") or 0)
            pmids = es.get("idlist", []) or []
        except Exception as e:
            entry.update(count=0, pmids=[], error=str(e))
            succeeded = False
        else:
            entry.update(count=count, pmids=pmids)
            succeeded = True

        entry["api_key_used"] = (limiter.api_key or "")[-6:]
        entry["email_used"] = limiter.email
        entry["retrieved_at"] = time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime())

        if stats:
            stats.log_term(entry["count"], success=succeeded)
        return entry
    finally:
        key_pool.put(limiter)

## Seed-level PubMed search

In [None]:
def search_seed_pubmed(seed, key_pool, top_k=3, retmax=200, stats: "SearchStats" = None, verbose: bool = True):
    """Search PubMed for up to *top_k* terms of one seed.

    Terms come from ``seed["preferred_search_terms"]``. When that yields no usable
    term, the ``"term"`` fields of ``seed["keyword_candidates"]`` are used
    instead. Non-string and blank terms are dropped *before* taking the first
    *top_k*, so a missing term is never searched as the literal text ``"None"``.

    One entry per searched term is appended to ``seed["pubmed_search_log"]``,
    which is created if absent; existing entries are kept. ``stats.log_seed()``
    is called exactly once, even if a search raises, so progress counts stay
    accurate.

    Returns the same *seed* dict, mutated in place.
    """

    def _is_usable(term):
        # Only non-blank strings make meaningful PubMed queries.
        return isinstance(term, str) and bool(term.strip())

    log = seed.setdefault("pubmed_search_log", [])
    try:
        preferred = seed.get("preferred_search_terms") or []
        terms = [t for t in preferred if _is_usable(t)][:top_k]
        if not terms:
            # Fall back to the "term" field of each keyword candidate.
            candidates = seed.get("keyword_candidates") or []
            terms = [
                c.get("term") for c in candidates
                if isinstance(c, dict) and _is_usable(c.get("term"))
            ][:top_k]

        for term in terms:
            entry = search_term_with_pool(term, retmax, key_pool, stats=stats)
            log.append(entry)
            if verbose:
                if entry.get("error"):
                    print(f"[{seed.get('seed_id')}] term='{term}' FAILED -> {entry.get('error')}")
                else:
                    print(f"[{seed.get('seed_id')}] term='{term}' -> {entry.get('count')} pmids (key=*{entry.get('api_key_used')})")
    finally:
        if stats:
            stats.log_seed()

    return seed

## Parallel execution across all seeds

In [None]:
def _persist_seed_to_file(seed):
    """Write *seed* back to the JSON file named by its ``_seed_file_path``.

    A single-seed file is replaced by the seed. In a multi-seed (list) file, only
    the entry with the same ``seed_id`` is replaced, so sibling seeds are not
    lost. The internal ``_seed_file_path`` key is not written to disk. The data
    goes to a temporary file that is then renamed over the target, so a failed
    write does not leave a truncated seed file.

    Raises:
        ValueError: the file holds a list and the seed cannot be matched to
            exactly one entry by ``seed_id``.
        OSError: writing the file failed.
    """
    path = Path(seed["_seed_file_path"])
    record = {k: v for k, v in seed.items() if k != "_seed_file_path"}

    try:
        with open(path, "r", encoding="utf-8") as fh:
            on_disk = json.load(fh)
    except (OSError, ValueError):
        on_disk = None  # missing or unreadable: write the seed as a single-seed file

    if isinstance(on_disk, list):
        seed_id = seed.get("seed_id")
        matches = [
            i for i, item in enumerate(on_disk)
            if seed_id is not None and isinstance(item, dict) and item.get("seed_id") == seed_id
        ]
        if len(matches) != 1:
            raise ValueError(
                f"cannot match seed_id={seed_id!r} to exactly one entry in multi-seed file {path}"
            )
        on_disk[matches[0]] = record
        payload = on_disk
    else:
        payload = record

    tmp = path.with_name(path.name + ".tmp")
    try:
        with open(tmp, "w", encoding="utf-8") as fh:
            json.dump(payload, fh, indent=2, ensure_ascii=False)
        tmp.replace(path)
    except Exception:
        if tmp.exists():
            tmp.unlink()
        raise


def run_pubmed_search_parallel(seeds, key_pool, workers=8, top_k_terms=3, retmax=200, persist=False, stats: "SearchStats" = None, verbose=True):
    """Run ``search_seed_pubmed`` for every seed on a pool of *workers* threads.

    Returns the seeds in completion order, not input order. If a worker raises,
    an ``[ERROR]`` line is printed and that seed is returned as it stands;
    ``search_seed_pubmed`` mutates seeds in place, so it may be partially updated.
    When *persist* is true, each successfully processed seed that has a
    ``_seed_file_path`` is written back to that file. A failed write prints a
    ``[WARN]`` line and does not stop the run. A ``SearchStats`` is created when
    *stats* is None, and a final summary is always printed.
    """
    results = []
    if stats is None:
        stats = SearchStats(total_seeds=len(seeds))

    with ThreadPoolExecutor(max_workers=workers) as ex:
        futures = {
            ex.submit(search_seed_pubmed, s, key_pool, top_k_terms, retmax, stats, verbose): s
            for s in seeds
        }
        # Completed futures are handled here on the calling thread, so the file
        # writes below never run concurrently with each other.
        for fut in as_completed(futures):
            seed = futures[fut]
            try:
                updated_seed = fut.result()
            except Exception as e:
                print(f"[ERROR] seed {seed.get('seed_id')} worker failed: {e}")
                results.append(seed)
                continue

            if persist and updated_seed.get("_seed_file_path"):
                try:
                    _persist_seed_to_file(updated_seed)
                except Exception as e:
                    print(f"[WARN] failed to persist seed {updated_seed.get('seed_id')}: {e}")

            results.append(updated_seed)

    stats.print_summary()
    return results

## Usage

In [None]:
# Rate-limited clients built from the (api_key, email) pairs in PUBMED_KEYS.
key_pool = build_api_key_pool(PUBMED_KEYS, delay_per_key=0.11)

# Every seed record found under SEED_DIR, plus a progress tracker sized to match.
seeds = load_seed_files_from_dir(SEED_DIR)
stats = SearchStats(total_seeds=len(seeds))

# Launch the searches: progress is printed per term, with periodic summaries.
updated_seeds = run_pubmed_search_parallel(
    seeds,
    key_pool,
    stats=stats,
    verbose=True,
    persist=True,     # write each updated seed back to its source JSON file
    retmax=200,       # max PMIDs requested per term
    top_k_terms=3,    # number of terms searched per seed
    workers=12,       # threads; may exceed the key count - extra threads wait for a free key
)