In [1]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Local expansion on a real citation network using iCite + PubMed E-utilities (efetch).

- Seeds: the PMIDs you provided
- Graph building: H_full (C ∪ F induced; includes C–C, C–F, F–F)
- Two-phase expansion:
    * Bootstrap waves: k_in=1, triangle gate OFF, ECR at wave 0.8-quantile, Δφ≤0
    * Strict waves:    k_in=2, triangle gate ON (wave 0.6-quantile), ECR at 0.9-quantile (cap 1), Δφ≤0
- U-aware: exact deg_tot from iCite; deg_outU = deg_tot - deg_H
- Qualified d_out uses anchor score a(u) for fringe gravitating to C
- Logs: per-wave accepted titles; final TF-IDF (title+abstract) top terms, top MeSH, and centrality leaderboards.

NOTE: Be polite to NCBI/iCite (rate-limit). Add your email/tool below if you have one.
"""

import os, sys, time, math, json, gzip, io, random
from collections import defaultdict, Counter
from typing import Dict, List, Set, Tuple

import requests
import numpy as np
import pandas as pd
import networkx as nx

try:
    from lxml import etree as ET
except Exception:
    import xml.etree.ElementTree as ET

from sklearn.feature_extraction.text import TfidfVectorizer

# ---------- Config ----------
SEED_PMIDS = [40720602, 40122203, 39522714, 37693640, 37364610, 32929487, 30935731]

ICITE_BASE = "https://icite.od.nih.gov/api/pubs"
PUBMED_EFETCH = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/efetch.fcgi"

# polite options (optional but recommended)
NCBI_TOOL = os.environ.get("NCBI_TOOL", "local-expansion")
NCBI_EMAIL = os.environ.get("NCBI_EMAIL", "example@example.com")
NCBI_API_KEY = os.environ.get("NCBI_API_KEY", "")  # if you have one

# expansion knobs
MAX_WAVES = 10
BOOTSTRAP_WAVES = 2
Q_ECR_BOOT = 0.80
Q_TRI_STRICT = 0.60
Q_ECR_STRICT = 0.90
LAMBDA_OUTU = 0.90
EPSILON_PHI = 0.0
KAPPA_ANCHOR = 2.0

# safety / scale knobs
MAX_PMIDS_TO_FETCH = 50000  # hard ceiling to avoid runaway
ICITE_CHUNK = 1000          # iCite limit per request
PUBMED_CHUNK = 200          # efetch robust chunk size

# ---------- Simple on-disk cache (jsonl) ----------
CACHE_DIR = "./_icite_pubmed_cache"
os.makedirs(CACHE_DIR, exist_ok=True)
ICITE_CACHE_FILE = os.path.join(CACHE_DIR, "icite_cache.jsonl")
PUBMED_CACHE_FILE = os.path.join(CACHE_DIR, "pubmed_cache.jsonl")

def _load_jsonl(path):
    """Read a JSON-lines cache file into a dict keyed by integer PMID.

    A missing file yields an empty dict. Blank lines are ignored, and any line
    that fails to parse or lacks a usable "pmid" value is skipped. When the
    same PMID appears on several lines, the record from the last one wins.
    """
    cache = {}
    if os.path.exists(path):
        with open(path, "r", encoding="utf-8") as fh:
            for raw in fh:
                text = raw.strip()
                if text:
                    try:
                        record = json.loads(text)
                        cache[int(record["pmid"])] = record
                    except Exception:
                        # Best-effort load: a corrupt line must not block the rest.
                        pass
    return cache

def _append_jsonl(path, records):
    """Append every record to `path`, one JSON document per line.

    The file is opened in append mode (so it is created when absent) and
    non-ASCII characters are written verbatim instead of being escaped.
    """
    with open(path, "a", encoding="utf-8") as sink:
        sink.writelines(json.dumps(rec, ensure_ascii=False) + "\n" for rec in records)

ICITE_CACHE = _load_jsonl(ICITE_CACHE_FILE)
PUBMED_CACHE = _load_jsonl(PUBMED_CACHE_FILE)

# ---------- iCite fetch ----------
def fetch_icite(pmids: List[int]) -> Dict[int, dict]:
    """Fetch iCite records for `pmids`, using and extending the cache.

    PMIDs already present in ICITE_CACHE are not requested again. The others
    are requested from ICITE_BASE in batches of at most ICITE_CHUNK; every
    returned record is normalized, stored in ICITE_CACHE and appended to
    ICITE_CACHE_FILE.

    Args:
        pmids: PMIDs (ints or int-like strings) to look up.

    Returns:
        dict[pmid] -> normalized record with keys "pmid", "title", "year",
        "doi", "citedByPmids", "citedPmids", "rcr", "acr". PMIDs that are
        still not cached after fetching are absent from the result.

    Raises:
        requests.RequestException: when a batch still fails on the final
            attempt (transport error, or an HTTP error status reported by
            raise_for_status).
    """
    # Deduplicate (keeping first-seen order) so no PMID is requested twice.
    wanted = list(dict.fromkeys(int(p) for p in pmids))
    to_get = [str(p) for p in wanted if p not in ICITE_CACHE]
    max_tries = 5
    # batch in chunks of ≤ ICITE_CHUNK
    for i in range(0, len(to_get), ICITE_CHUNK):
        batch = to_get[i:i+ICITE_CHUNK]
        params = {
            "pmids": ",".join(batch),
            "legacy": "false",
            # fields: keep as default to include citations; add a few key meta fields
            # You can optionally pass fl=... to slim payloads
        }
        r = None
        for tries in range(max_tries):
            is_last = tries == max_tries - 1
            try:
                r = requests.get(ICITE_BASE, params=params, timeout=60)
            except requests.RequestException:
                # Timeouts / connection resets are retried like bad statuses;
                # only the final failure is allowed to propagate.
                if is_last:
                    raise
            else:
                if r.status_code == 200 or is_last:
                    break
            # Linear back-off; never reached after the final attempt.
            time.sleep(2*(tries+1))
        r.raise_for_status()
        payload = r.json()  # parse the body once
        data = payload.get("data", payload) if isinstance(payload, dict) else payload
        # normalize
        records = []
        for rec in data:
            raw_id = rec.get("pmid") or rec.get("_id")
            if not raw_id:
                # A record without an identifier cannot be cached; skip it
                # instead of failing the whole batch.
                continue
            pmid = int(raw_id)
            cited_by = rec.get("citedByPmids", []) or rec.get("cited_by") or []
            refs = rec.get("citedPmids", []) or rec.get("references") or []
            obj = {
                "pmid": pmid,
                "title": rec.get("title"),
                "year": rec.get("pubYear") or rec.get("year"),
                "doi": rec.get("doi"),
                "citedByPmids": [int(x) for x in cited_by],
                "citedPmids": [int(x) for x in refs],
                "rcr": rec.get("rcr"),
                "acr": rec.get("acr"),
            }
            ICITE_CACHE[pmid] = obj
            records.append(obj)
        if records:
            _append_jsonl(ICITE_CACHE_FILE, records)
        time.sleep(0.34)  # gentle pacing
    # return merged view: everything requested that is now cached
    return {p: ICITE_CACHE[p] for p in wanted if p in ICITE_CACHE}

# ---------- PubMed efetch (title, abstract, MeSH) ----------
def fetch_pubmed_meta(pmids: List[int]) -> Dict[int, dict]:
    """Fetch title, abstract and MeSH descriptors for `pmids` via efetch XML.

    PMIDs already in PUBMED_CACHE are not requested again. The others are
    requested from PUBMED_EFETCH in batches of PUBMED_CHUNK; parsed records
    are stored in PUBMED_CACHE and appended to PUBMED_CACHE_FILE.

    Args:
        pmids: PMIDs (ints or int-like strings) to look up.

    Returns:
        dict[pmid] -> {"pmid", "title", "abstract", "mesh"}. PMIDs that could
        not be fetched get a placeholder with empty title/abstract and an
        empty MeSH list (the placeholder is not written to the cache).

    Raises:
        requests.RequestException: when a batch still fails on the final
            attempt (transport error, or an HTTP error status reported by
            raise_for_status).
    """
    # Deduplicate (keeping first-seen order) so no PMID is requested twice.
    wanted = list(dict.fromkeys(int(p) for p in pmids))
    to_get = [str(p) for p in wanted if p not in PUBMED_CACHE]
    max_tries = 5
    for i in range(0, len(to_get), PUBMED_CHUNK):
        batch = to_get[i:i+PUBMED_CHUNK]
        params = {
            "db": "pubmed",
            "id": ",".join(batch),
            "retmode": "xml",
            "rettype": "abstract",
            "tool": NCBI_TOOL,
            "email": NCBI_EMAIL
        }
        if NCBI_API_KEY:
            params["api_key"] = NCBI_API_KEY
        r = None
        for tries in range(max_tries):
            is_last = tries == max_tries - 1
            try:
                r = requests.get(PUBMED_EFETCH, params=params, timeout=60)
            except requests.RequestException:
                # Timeouts / connection resets are retried like bad statuses;
                # only the final failure is allowed to propagate.
                if is_last:
                    raise
            else:
                if r.status_code == 200 or is_last:
                    break
            # Linear back-off; never reached after the final attempt.
            time.sleep(1.0*(tries+1))
        r.raise_for_status()
        # Parse the raw bytes so the XML parser honours the document's own
        # encoding declaration, instead of decoding with a guessed charset
        # and re-encoding.
        root = ET.fromstring(r.content)

        records = []
        for art in root.findall(".//PubmedArticle"):
            pmid_el = art.find(".//MedlineCitation/PMID")
            if pmid_el is None or not pmid_el.text:
                continue
            pmid = int(pmid_el.text.strip())

            title_el = art.find(".//Article/ArticleTitle")
            # itertext() keeps text nested inside inline markup of the title.
            title_text = "".join(title_el.itertext()).strip() if title_el is not None else ""

            abs_el = art.find(".//Article/Abstract")
            # One chunk per child element of <Abstract>, joined with spaces.
            abstract = " ".join("".join(x.itertext()).strip() for x in abs_el) if abs_el is not None else ""

            mesh_terms = []
            for mh in art.findall(".//MeshHeading"):
                desc = mh.find("DescriptorName")
                if desc is not None and desc.text:
                    mesh_terms.append(desc.text.strip())

            obj = {
                "pmid": pmid,
                "title": title_text,
                "abstract": abstract,
                "mesh": mesh_terms
            }
            PUBMED_CACHE[pmid] = obj
            records.append(obj)
        if records:
            _append_jsonl(PUBMED_CACHE_FILE, records)
        time.sleep(0.34)
    # merge with existing; unknown PMIDs get an empty placeholder
    out = {}
    for p in wanted:
        if p in PUBMED_CACHE:
            out[p] = PUBMED_CACHE[p]
        else:
            out[p] = {"pmid": p, "title": "", "abstract": "", "mesh": []}
    return out

# ---------- Utilities on sets ----------
def neighbors_from_icite(pmid: int) -> Set[int]:
    """Return the undirected neighbourhood of `pmid` from the iCite cache.

    The result is every PMID listed as citing it or as cited by it. A PMID
    with no (or an empty) cache record yields an empty set.
    """
    record = ICITE_CACHE.get(int(pmid))
    if record:
        citers = record.get("citedByPmids", [])
        references = record.get("citedPmids", [])
        return {*citers, *references}
    return set()

def ensure_loaded(pmids: Set[int]):
    """Make sure both iCite and PubMed data have been requested for `pmids`.

    Does nothing for an empty collection; otherwise delegates to
    fetch_icite and then fetch_pubmed_meta with the PMIDs coerced to int.
    """
    wanted = list(map(int, pmids))
    if wanted:
        fetch_icite(wanted)
        fetch_pubmed_meta(wanted)

# ---------- H_full construction ----------
def induce_H_full(C: Set[int]) -> Tuple[Set[int], Set[int]]:
    """Build the node set H and the frontier F for the community C.

    F is every cached iCite neighbour of a member of C that is not itself
    in C; H is C together with F. Returns (H_nodes, F).
    """
    frontier = set().union(*(neighbors_from_icite(member) for member in C))
    frontier -= C
    return set(C) | frontier, frontier

def deg_tot(pmid: int) -> int:
    """Total degree of `pmid`: length of its cached citer list plus length of
    its cached reference list (0 when the PMID is not in the iCite cache)."""
    record = ICITE_CACHE.get(int(pmid), {})
    n_citers = len(record.get("citedByPmids", []))
    n_references = len(record.get("citedPmids", []))
    return n_citers + n_references

def deg_H(pmid: int, H_nodes: Set[int]) -> int:
    """Number of distinct cached neighbours of `pmid` that belong to H_nodes."""
    inside_H = neighbors_from_icite(pmid).intersection(H_nodes)
    return len(inside_H)

def d_in_star(pmid: int, C: Set[int]) -> int:
    """Internal degree (k-in) of `pmid`: how many of its distinct cached
    neighbours are members of C. No weighting is applied."""
    return sum(1 for nbr in neighbors_from_icite(pmid) if nbr in C)

def tri_density_in_C(v: int, C: Set[int]) -> float:
    """Score how interlinked v's neighbours inside C are.

    With N = neighbours of v that lie in C and k = |N|, the score is
    m / (1 + k), where m is the number of links among the members of N
    (each link is seen from both endpoints, hence the halving). Returns 0.0
    when v has no neighbour in C.
    """
    anchors = neighbors_from_icite(v) & C
    if not anchors:
        return 0.0
    endpoint_hits = sum(len(neighbors_from_icite(u) & anchors) for u in anchors)
    links = endpoint_hits // 2
    return links / (1.0 + len(anchors))

def qualified_d_out(v: int, C: Set[int], H_nodes: Set[int], d_in_map: Dict[int,int]) -> float:
    """Weighted count of v's links that leave C.

    Links into C contribute nothing. A link to a node u in H (but outside C)
    contributes 1 - a(u), where a(u) = min(1, KAPPA_ANCHOR * d_in(u) / deg_tot(u))
    and d_in(u) comes from `d_in_map` when available. A link to a node outside
    H contributes a full 1.
    """
    total = 0.0
    for nbr in neighbors_from_icite(v):
        if nbr in C:
            continue
        anchor = 0.0
        if nbr in H_nodes:
            # Fall back to counting C-neighbours when the map has no entry.
            fallback = len(neighbors_from_icite(nbr) & C)
            din_nbr = d_in_map.get(nbr, fallback)
            share_in_C = din_nbr / max(1, deg_tot(nbr))
            anchor = min(1.0, KAPPA_ANCHOR * share_in_C)
        total += (1.0 - anchor)
    return total

def cut_C_F(C: Set[int], H_nodes: Set[int]) -> int:
    """Count the links that run between C and the rest of H (F = H_nodes - C),
    summed over the members of C."""
    frontier = H_nodes - C
    return sum(len(neighbors_from_icite(member) & frontier) for member in C)

def vol_C(C: Set[int], H_nodes: Set[int]) -> int:
    """Volume of C: the H-restricted degrees of its members added together."""
    volume = 0
    for member in C:
        volume += deg_H(member, H_nodes)
    return volume

def delta_phi_for_candidate(v: int, C: Set[int], H_nodes: Set[int], d_in_map: Dict[int,int]) -> Tuple[float, Tuple[float,float]]:
    """Estimate how the conductance of C changes if candidate v is added.

    The new cut removes v's links into C, adds v's qualified out-degree and,
    pessimistically, all of v's links that leave H. The new volume adds v's
    degree inside H. A conductance with non-positive volume is taken as 1.0.

    Returns (phi_new - phi_now, (phi_now, phi_new)).
    """
    def conductance(cut, vol):
        return cut / max(1e-9, vol) if vol > 0 else 1.0

    cut_before = cut_C_F(C, H_nodes)
    vol_before = vol_C(C, H_nodes)
    links_into_C = d_in_map[v]
    weighted_out = qualified_d_out(v, C, H_nodes, d_in_map)
    # pessimistic: every link of v that leaves H is charged to the new cut
    links_outside_H = deg_tot(v) - deg_H(v, H_nodes)
    cut_after = cut_before - links_into_C + weighted_out + links_outside_H
    vol_after = vol_before + deg_H(v, H_nodes)  # volume stays H-restricted

    phi_before = conductance(cut_before, vol_before)
    phi_after = conductance(cut_after, vol_after)
    return (phi_after - phi_before), (phi_before, phi_after)

def conductances_U_aware(C: Set[int], H_nodes: Set[int]) -> Tuple[float,float,float]:
    """Return the three diagnostics (φ_C→F, φ_F→U, φ_H→U).

    φ_C→F is cut(C, F) over the H-volume of C (1.0 when that volume is 0).
    φ_F→U and φ_H→U are, for F = H_nodes - C and for H_nodes respectively,
    the share of total degree that points outside H.
    """
    def escape_share(nodes):
        # Fraction of the nodes' total degree that leaves H.
        leaving = sum((deg_tot(x) - deg_H(x, H_nodes)) for x in nodes)
        total = sum(deg_tot(x) for x in nodes) if nodes else 1
        return leaving / max(1, total)

    # internal: C against its frontier
    volume = vol_C(C, H_nodes)
    if volume > 0:
        phi_internal = cut_C_F(C, H_nodes) / max(1, volume)
    else:
        phi_internal = 1.0

    # external: frontier → U\H, then whole of H → U\H
    phi_frontier = escape_share(H_nodes - C)
    phi_global = escape_share(H_nodes)
    return phi_internal, phi_frontier, phi_global

# ---------- Expansion driver ----------
def expand_two_phase(seed_pmids: List[int]) -> Tuple[Set[int], List[dict], List[List[Tuple[int, str]]]]:
    """Grow a community C from the seed PMIDs in up to MAX_WAVES waves.

    Each wave:
      1. builds the frontier F (neighbours of C outside C) and loads iCite +
         PubMed data for it;
      2. computes the U-aware conductance diagnostics and, per candidate,
         d_in, triangle density, ECR and Δφ;
      3. picks thresholds: waves 1..BOOTSTRAP_WAVES use k_in=1, no triangle
         gate and the Q_ECR_BOOT quantile of ECR; later waves use k_in=2, the
         Q_TRI_STRICT quantile of triangle density and the Q_ECR_STRICT
         quantile of ECR capped at 1.0;
      4. accepts every candidate that passes all gates (decided against the
         same C for the whole wave), prints a report and adds them to C.

    The loop stops early when a wave accepts nobody or when ICITE_CACHE holds
    more than MAX_PMIDS_TO_FETCH entries.

    Args:
        seed_pmids: PMIDs that form the initial community.

    Returns:
        (C, history, accepted_titles_per_wave): the final community, one stats
        dict per executed wave, and per wave the list of (pmid, title) pairs
        accepted in it.
    """
    # seed load
    C = set(int(p) for p in seed_pmids)
    ensure_loaded(C)
    # NOTE(review): total_fetched is never read afterwards. It looks like it
    # was meant as a baseline for the fetch ceiling at the end of the loop —
    # confirm intent.
    total_fetched = len(ICITE_CACHE)

    history = []
    accepted_titles_per_wave = []

    for wave in range(1, MAX_WAVES+1):
        # Build H_full
        H_nodes, F = induce_H_full(C)
        # fetch icite/pubmed for F (to get their neighbor lists, deg_tot, titles)
        ensure_loaded(F)
        # NOTE(review): induce_H_full only reads the cache records of members
        # of C, so this second call presumably returns the same sets unless a
        # member of C was missing from the cache before — verify.
        H_nodes, F = induce_H_full(C)  # may grow after ensure_loaded

        # diagnostics (U-aware)
        phi_CtoF, phi_FtoU, phi_HtoU = conductances_U_aware(C, H_nodes)
        # Ratio of internal to external conductance; 1e-9 avoids division by zero.
        R = phi_CtoF / (phi_FtoU + 1e-9)

        # per-candidate features
        d_in_map = {v: d_in_star(v, C) for v in F}
        tri_map = {v: tri_density_in_C(v, C) for v in F}
        # Links of v inside H that do not go to C.
        d_out_raw = {v: (deg_H(v, H_nodes) - d_in_map[v]) for v in F}
        # ECR = d_in / (d_out inside H + LAMBDA_OUTU * links leaving H); 1e-9 avoids /0.
        ECR_map = {v: d_in_map[v] / (d_out_raw[v] + LAMBDA_OUTU*(deg_tot(v)-deg_H(v, H_nodes)) + 1e-9) for v in F}
        dphi_map = {}
        for v in F:
            dphi, _ = delta_phi_for_candidate(v, C, H_nodes, d_in_map)
            dphi_map[v] = dphi

        # thresholds per phase
        if wave <= BOOTSTRAP_WAVES:
            # Bootstrap: one link into C suffices, no triangle gate.
            k_in = 1
            use_triangle = False
            tau_tri = None
            ecr_vals = np.array([ECR_map[v] for v in F]) if F else np.array([0.0])
            tau_ecr = float(np.quantile(ecr_vals, Q_ECR_BOOT)) if len(ecr_vals)>0 else 0.0
        else:
            # Strict: two links into C, triangle gate on, ECR threshold capped at 1.0.
            k_in = 2
            use_triangle = True
            tri_vals = np.array([tri_map[v] for v in F]) if F else np.array([0.0])
            tau_tri = float(np.quantile(tri_vals, Q_TRI_STRICT)) if len(tri_vals)>0 else 0.0
            ecr_vals = np.array([ECR_map[v] for v in F]) if F else np.array([0.0])
            tau_ecr = min(1.0, float(np.quantile(ecr_vals, Q_ECR_STRICT))) if len(ecr_vals)>0 else 0.0

        # synchronous acceptance: all gates are evaluated against the C of the
        # start of the wave; C is only updated after the loop.
        accepted = []
        for v in F:
            if d_in_map[v] < k_in:
                continue
            if use_triangle and tri_map[v] < tau_tri:
                continue
            if ECR_map[v] < tau_ecr:
                continue
            if dphi_map[v] > EPSILON_PHI:
                continue
            accepted.append(v)

        # logging
        print(f"\n=== Wave {wave} ===")
        print(f"Frontier size: {len(F)} | Accepted: {len(accepted)}")
        print(f"k_in={k_in} | use_triangle={use_triangle} | tau_triangle={tau_tri} | tau_ECR={tau_ecr:.4f}")
        print(f"phi_C→F={phi_CtoF:.4f} | phi_F→U={phi_FtoU:.4f} | phi_H→U={phi_HtoU:.4f} | R={R:.4f}")
        if len(F)>0:
            print(f"Frontier averages: d_in={np.mean(list(d_in_map.values())):.3f} | tri_dens={np.mean(list(tri_map.values())):.3f} | ECR={np.mean(list(ECR_map.values())):.3f} | best Δφ={min(dphi_map.values()):.5f}")

        # print accepted with titles
        ensure_loaded(accepted)
        titles_this_wave = []
        for v in accepted:
            # Prefer the PubMed title, then the iCite one, then a placeholder.
            title = PUBMED_CACHE.get(v, {}).get("title") or ICITE_CACHE.get(v, {}).get("title") or "(no title)"
            print(f"ACCEPTED: {v} — {title}")
            titles_this_wave.append((v, title))
        accepted_titles_per_wave.append(titles_this_wave)

        # commit
        C |= set(accepted)

        # Stats are recorded for every executed wave, including the last one.
        history.append({
            "wave": wave,
            "frontier_size": len(F),
            "accepted": len(accepted),
            "k_in": k_in,
            "use_triangle": use_triangle,
            "tau_triangle": tau_tri,
            "tau_ECR": tau_ecr,
            "phi_CtoF": phi_CtoF,
            "phi_FtoU": phi_FtoU,
            "phi_HtoU": phi_HtoU,
            "R": R,
            "avg_d_in": float(np.mean([d_in_map[v] for v in F])) if F else 0.0,
            "avg_tri_dens": float(np.mean([tri_map[v] for v in F])) if F else 0.0,
            "avg_ECR": float(np.mean([ECR_map[v] for v in F])) if F else 0.0,
            "best_delta_phi": float(min(dphi_map.values())) if F else 0.0
        })

        if len(accepted)==0:
            print("Stopping: no candidates pass gates.")
            break

        # NOTE(review): this compares the size of the whole cache, which also
        # holds records loaded from the on-disk cache file, and it is checked
        # only after this wave's fetches have already happened — confirm that
        # this is the intended ceiling.
        if len(ICITE_CACHE) > MAX_PMIDS_TO_FETCH:
            print("Stopping: reached fetch ceiling.")
            break

    return C, history, accepted_titles_per_wave

# ---------- TF-IDF & MeSH summaries ----------
def tfidf_top_terms(pmids: Set[int], k: int = 30) -> List[Tuple[str, float]]:
    """Rank terms of the papers' title+abstract text by summed TF-IDF weight.

    Each paper contributes one document built from its cached PubMed title
    and abstract. Unigrams and bigrams are scored with English stop words
    removed, keeping terms found in at least 2 documents and in at most 90%
    of them.

    Args:
        pmids: papers to summarise.
        k: maximum number of terms to return.

    Returns:
        Up to `k` (term, score) pairs, highest score first. An empty list is
        returned when there are no papers or when the corpus is too small or
        too sparse for the vectorizer to build a vocabulary.
    """
    docs = []
    for p in pmids:
        meta = PUBMED_CACHE.get(int(p), {})
        docs.append((meta.get("title") or "") + " " + (meta.get("abstract") or ""))
    if not docs:
        return []
    vec = TfidfVectorizer(min_df=2, max_df=0.9, ngram_range=(1,2), stop_words="english")
    try:
        X = vec.fit_transform(docs)
    except ValueError:
        # Raised for degenerate corpora: with 1-2 documents max_df falls below
        # min_df, and empty / stop-word-only text leaves no vocabulary. A
        # summary helper should report "no terms" rather than abort the run.
        return []
    scores = np.asarray(X.sum(axis=0)).ravel()
    terms = vec.get_feature_names_out()
    order = np.argsort(-scores)
    return [(str(terms[i]), float(scores[i])) for i in order[:k]]

def top_mesh_terms(pmids: Set[int], k: int = 20) -> List[Tuple[str,int]]:
    """Return the `k` most frequent MeSH terms over the cached PubMed records
    of `pmids`, as (term, count) pairs. Papers without a record or without a
    "mesh" entry contribute nothing."""
    tally = Counter()
    for pmid in pmids:
        record = PUBMED_CACHE.get(int(pmid), {})
        tally.update(record.get("mesh", []))
    return tally.most_common(k)

# ---------- Centrality on final C ----------
def centrality_leaderboards(C: Set[int]) -> Dict[str, List[Tuple[int, float]]]:
    """Compute top-5 leaderboards on the undirected graph induced by C.

    Two members are linked when one appears among the other's cached iCite
    neighbours. Returns a dict with keys "degree", "betweenness" and
    "eigenvector", each holding up to five (pmid, score) pairs sorted by
    descending score and then ascending PMID.
    """
    members = list(C)
    member_set = set(members)

    graph = nx.Graph()
    graph.add_nodes_from(C)
    # Only the (smaller, larger) orientation of each pair is added.
    graph.add_edges_from(
        (u, v)
        for u in members
        for v in neighbors_from_icite(u) & member_set
        if u < v
    )

    def top_five(scores):
        return sorted(scores.items(), key=lambda item: (-item[1], item[0]))[:5]

    # Plain degree, then exact betweenness (no sampling; fine for small graphs).
    degree_scores = dict(graph.degree())
    between_scores = nx.betweenness_centrality(graph, normalized=True, k=None)

    try:
        eigen_scores = nx.eigenvector_centrality(graph, max_iter=2000)
    except Exception:
        # Fallback: score only the largest connected component and give every
        # other node 0.0.
        if graph.number_of_nodes():
            largest = max(nx.connected_components(graph), key=len)
        else:
            largest = set()
        if largest:
            partial = nx.eigenvector_centrality(graph.subgraph(largest), max_iter=2000)
        else:
            partial = {}
        eigen_scores = {node: (partial.get(node, 0.0)) for node in graph.nodes()}

    return {
        "degree": top_five(degree_scores),
        "betweenness": top_five(between_scores),
        "eigenvector": top_five(eigen_scores)
    }

def title(pmid: int) -> str:
    """Best available title for `pmid`: the PubMed cache first, then the
    iCite cache, else a "(no title) PMID <pmid>" placeholder."""
    key = int(pmid)
    for cache in (PUBMED_CACHE, ICITE_CACHE):
        found = cache.get(key, {}).get("title")
        if found:
            return found
    return f"(no title) PMID {pmid}"

# ---------- Main ----------
# Script entry point: expand the seed community, print a summary of the
# result, and write it to final_community.csv.
if __name__ == "__main__":
    print("Seeds:", SEED_PMIDS)
    # expand_two_phase loads the seeds itself as well; this call makes sure
    # they are cached before the expansion starts.
    ensure_loaded(SEED_PMIDS)

    final_C, history, accepted_titles = expand_two_phase(SEED_PMIDS)

    # Summary
    print("\n=== Final community summary ===")
    print(f"|C| = {len(final_C)}")
    # Recompute the frontier and diagnostics for the final community.
    H_nodes, F = induce_H_full(final_C)
    phi_CtoF, phi_FtoU, phi_HtoU = conductances_U_aware(final_C, H_nodes)
    print(f"phi_C→F={phi_CtoF:.4f} | phi_F→U={phi_FtoU:.4f} | phi_H→U={phi_HtoU:.4f}")

    # Semantic quick check: TF-IDF top terms (title+abstract) and top-20 MeSH
    ensure_loaded(final_C)
    top_terms = tfidf_top_terms(final_C, k=30)
    top_mesh = top_mesh_terms(final_C, k=20)

    print("\nTop TF-IDF terms (title+abstract):")
    # 30 terms are computed, but only the first 20 are printed.
    for term, score in top_terms[:20]:
        print(f"  {term:40s}  {score:.3f}")

    print("\nTop 20 MeSH terms (by frequency):")
    for mh, cnt in top_mesh:
        print(f"  {mh:50s}  {cnt}")

    # Centrality leaderboards on final C
    leaders = centrality_leaderboards(final_C)

    # Print one leaderboard: PMID, score and resolved title per entry.
    def dump_leaders(name, lst):
        print(f"\nTop-5 by {name} centrality:")
        for pmid, val in lst:
            print(f"  {pmid}  ({val:.5f})  — {title(pmid)}")

    dump_leaders("degree", leaders["degree"])
    dump_leaders("betweenness", leaders["betweenness"])
    dump_leaders("eigenvector", leaders["eigenvector"])

    # Also print the full title list of every entry accepted on each wave
    print("\n=== Accepted titles per wave (full) ===")
    for w, lst in enumerate(accepted_titles, start=1):
        print(f"\nWave {w}: {len(lst)} accepted")
        for pmid, t in lst:
            print(f"  {pmid} — {t}")

    # Write a CSV of the final community with core fields: title and MeSH
    # come from the PubMed cache (title falls back to iCite); year, DOI, RCR
    # and ACR come from the iCite cache.
    rows = []
    for p in sorted(final_C):
        meta = PUBMED_CACHE.get(p, {})
        ic = ICITE_CACHE.get(p, {})
        rows.append({
            "pmid": p,
            "title": meta.get("title") or ic.get("title"),
            "year": ic.get("year"),
            "doi": ic.get("doi"),
            "rcr": ic.get("rcr"),
            "acr": ic.get("acr"),
            "mesh": "; ".join(meta.get("mesh", [])),
        })
    df = pd.DataFrame(rows)
    df.to_csv("final_community.csv", index=False)
    print("\nWrote final_community.csv")


Seeds: [40720602, 40122203, 39522714, 37693640, 37364610, 32929487, 30935731]

=== Wave 1 ===
Frontier size: 270 | Accepted: 5
k_in=1 | use_triangle=False | tau_triangle=None | tau_ECR=0.0608
phi_C→F=0.9572 | phi_F→U=0.9585 | phi_H→U=0.9548 | R=0.9986
Frontier averages: d_in=1.493 | tri_dens=0.150 | ECR=0.043 | best Δφ=-0.01466
ACCEPTED: 37973418 — Lessons Learned after 176 Patients Treated with a Standardized Procedure of Thoracoscopic Cryoanalgesia during Minimally Invasive Repair of Pectus Excavatum.
ACCEPTED: 39363012 — Effect of cryoablation in Nuss bar placement on opioid utilization and length of stay.
ACCEPTED: 38242172 — Intercostal Nerve Cryoablation or Epidural Analgesia for Multimodal Pain Management after the Nuss Procedure: A Cohort Study.
ACCEPTED: 33112997 — Use of cryoanalgesia as a postoperative pain management for open pectus carinatum repair.
ACCEPTED: 36922280 — Impact of Cryoanalgesia Use During Minimally Invasive Pectus Excavatum Repair on Hospital Days and Total

In [4]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Local expansion on a real citation network using iCite + PubMed E-utilities (efetch).

- Seeds: the PMIDs you provided
- Graph building: H_full (C ∪ F induced; includes C–C, C–F, F–F)
- Two-phase expansion:
    * Bootstrap waves: k_in=1, triangle gate OFF, ECR at wave 0.8-quantile, Δφ≤0
    * Strict waves:    k_in=2, triangle gate ON (wave 0.6-quantile), ECR at 0.9-quantile (cap 1), Δφ≤0
- U-aware: exact deg_tot from iCite; deg_outU = deg_tot - deg_H
- Qualified d_out uses anchor score a(u) for fringe gravitating to C
- Logs: per-wave accepted titles; final TF-IDF (title+abstract) top terms, top MeSH, and centrality leaderboards.

NOTE: Be polite to NCBI/iCite (rate-limit). Add your email/tool below if you have one.
"""

import os, sys, time, math, json, gzip, io, random
from collections import defaultdict, Counter
from typing import Dict, List, Set, Tuple

import requests
import numpy as np
import pandas as pd
import networkx as nx

try:
    from lxml import etree as ET
except Exception:
    import xml.etree.ElementTree as ET

from sklearn.feature_extraction.text import TfidfVectorizer

# ---------- Config ----------
SEED_PMIDS = [37693640, 37364610, 32929487, 30935731]

ICITE_BASE = "https://icite.od.nih.gov/api/pubs"
PUBMED_EFETCH = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/efetch.fcgi"

# polite options (optional but recommended)
NCBI_TOOL = os.environ.get("NCBI_TOOL", "local-expansion")
NCBI_EMAIL = os.environ.get("NCBI_EMAIL", "example@example.com")
NCBI_API_KEY = os.environ.get("NCBI_API_KEY", "")  # if you have one

# expansion knobs
MAX_WAVES = 10
BOOTSTRAP_WAVES = 2
Q_ECR_BOOT = 0.80
Q_TRI_STRICT = 0.60
Q_ECR_STRICT = 0.90
LAMBDA_OUTU = 0.90
EPSILON_PHI = 0.0
KAPPA_ANCHOR = 2.0

# safety / scale knobs
MAX_PMIDS_TO_FETCH = 50000  # hard ceiling to avoid runaway
ICITE_CHUNK = 1000          # iCite limit per request
PUBMED_CHUNK = 200          # efetch robust chunk size

# ---------- Simple on-disk cache (jsonl) ----------
CACHE_DIR = "./_icite_pubmed_cache"
os.makedirs(CACHE_DIR, exist_ok=True)
ICITE_CACHE_FILE = os.path.join(CACHE_DIR, "icite_cache.jsonl")
PUBMED_CACHE_FILE = os.path.join(CACHE_DIR, "pubmed_cache.jsonl")

def _load_jsonl(path):
    """Read a JSON-lines cache file into a dict keyed by integer PMID.

    A missing file yields an empty dict. Blank lines are ignored, and any line
    that fails to parse or lacks a usable "pmid" value is skipped. When the
    same PMID appears on several lines, the record from the last one wins.
    """
    cache = {}
    if os.path.exists(path):
        with open(path, "r", encoding="utf-8") as fh:
            for raw in fh:
                text = raw.strip()
                if text:
                    try:
                        record = json.loads(text)
                        cache[int(record["pmid"])] = record
                    except Exception:
                        # Best-effort load: a corrupt line must not block the rest.
                        pass
    return cache

def _append_jsonl(path, records):
    """Append every record to `path`, one JSON document per line.

    The file is opened in append mode (so it is created when absent) and
    non-ASCII characters are written verbatim instead of being escaped.
    """
    with open(path, "a", encoding="utf-8") as sink:
        sink.writelines(json.dumps(rec, ensure_ascii=False) + "\n" for rec in records)

ICITE_CACHE = _load_jsonl(ICITE_CACHE_FILE)
PUBMED_CACHE = _load_jsonl(PUBMED_CACHE_FILE)

# ---------- iCite fetch ----------
def fetch_icite(pmids: List[int]) -> Dict[int, dict]:
    """Fetch iCite records for `pmids`, using and extending the cache.

    PMIDs already present in ICITE_CACHE are not requested again. The others
    are requested from ICITE_BASE in batches of at most ICITE_CHUNK; every
    returned record is normalized, stored in ICITE_CACHE and appended to
    ICITE_CACHE_FILE.

    Args:
        pmids: PMIDs (ints or int-like strings) to look up.

    Returns:
        dict[pmid] -> normalized record with keys "pmid", "title", "year",
        "doi", "citedByPmids", "citedPmids", "rcr", "acr". PMIDs that are
        still not cached after fetching are absent from the result.

    Raises:
        requests.RequestException: when a batch still fails on the final
            attempt (transport error, or an HTTP error status reported by
            raise_for_status).
    """
    # Deduplicate (keeping first-seen order) so no PMID is requested twice.
    wanted = list(dict.fromkeys(int(p) for p in pmids))
    to_get = [str(p) for p in wanted if p not in ICITE_CACHE]
    max_tries = 5
    # batch in chunks of ≤ ICITE_CHUNK
    for i in range(0, len(to_get), ICITE_CHUNK):
        batch = to_get[i:i+ICITE_CHUNK]
        params = {
            "pmids": ",".join(batch),
            "legacy": "false",
            # fields: keep as default to include citations; add a few key meta fields
            # You can optionally pass fl=... to slim payloads
        }
        r = None
        for tries in range(max_tries):
            is_last = tries == max_tries - 1
            try:
                r = requests.get(ICITE_BASE, params=params, timeout=60)
            except requests.RequestException:
                # Timeouts / connection resets are retried like bad statuses;
                # only the final failure is allowed to propagate.
                if is_last:
                    raise
            else:
                if r.status_code == 200 or is_last:
                    break
            # Linear back-off; never reached after the final attempt.
            time.sleep(2*(tries+1))
        r.raise_for_status()
        payload = r.json()  # parse the body once
        data = payload.get("data", payload) if isinstance(payload, dict) else payload
        # normalize
        records = []
        for rec in data:
            raw_id = rec.get("pmid") or rec.get("_id")
            if not raw_id:
                # A record without an identifier cannot be cached; skip it
                # instead of failing the whole batch.
                continue
            pmid = int(raw_id)
            cited_by = rec.get("citedByPmids", []) or rec.get("cited_by") or []
            refs = rec.get("citedPmids", []) or rec.get("references") or []
            obj = {
                "pmid": pmid,
                "title": rec.get("title"),
                "year": rec.get("pubYear") or rec.get("year"),
                "doi": rec.get("doi"),
                "citedByPmids": [int(x) for x in cited_by],
                "citedPmids": [int(x) for x in refs],
                "rcr": rec.get("rcr"),
                "acr": rec.get("acr"),
            }
            ICITE_CACHE[pmid] = obj
            records.append(obj)
        if records:
            _append_jsonl(ICITE_CACHE_FILE, records)
        time.sleep(0.34)  # gentle pacing
    # return merged view: everything requested that is now cached
    return {p: ICITE_CACHE[p] for p in wanted if p in ICITE_CACHE}

# ---------- PubMed efetch (title, abstract, MeSH) ----------
def fetch_pubmed_meta(pmids: List[int]) -> Dict[int, dict]:
    """Fetch title, abstract and MeSH descriptors for `pmids` via efetch XML.

    PMIDs already in PUBMED_CACHE are not requested again. The others are
    requested from PUBMED_EFETCH in batches of PUBMED_CHUNK; parsed records
    are stored in PUBMED_CACHE and appended to PUBMED_CACHE_FILE.

    Args:
        pmids: PMIDs (ints or int-like strings) to look up.

    Returns:
        dict[pmid] -> {"pmid", "title", "abstract", "mesh"}. PMIDs that could
        not be fetched get a placeholder with empty title/abstract and an
        empty MeSH list (the placeholder is not written to the cache).

    Raises:
        requests.RequestException: when a batch still fails on the final
            attempt (transport error, or an HTTP error status reported by
            raise_for_status).
    """
    # Deduplicate (keeping first-seen order) so no PMID is requested twice.
    wanted = list(dict.fromkeys(int(p) for p in pmids))
    to_get = [str(p) for p in wanted if p not in PUBMED_CACHE]
    max_tries = 5
    for i in range(0, len(to_get), PUBMED_CHUNK):
        batch = to_get[i:i+PUBMED_CHUNK]
        params = {
            "db": "pubmed",
            "id": ",".join(batch),
            "retmode": "xml",
            "rettype": "abstract",
            "tool": NCBI_TOOL,
            "email": NCBI_EMAIL
        }
        if NCBI_API_KEY:
            params["api_key"] = NCBI_API_KEY
        r = None
        for tries in range(max_tries):
            is_last = tries == max_tries - 1
            try:
                r = requests.get(PUBMED_EFETCH, params=params, timeout=60)
            except requests.RequestException:
                # Timeouts / connection resets are retried like bad statuses;
                # only the final failure is allowed to propagate.
                if is_last:
                    raise
            else:
                if r.status_code == 200 or is_last:
                    break
            # Linear back-off; never reached after the final attempt.
            time.sleep(1.0*(tries+1))
        r.raise_for_status()
        # Parse the raw bytes so the XML parser honours the document's own
        # encoding declaration, instead of decoding with a guessed charset
        # and re-encoding.
        root = ET.fromstring(r.content)

        records = []
        for art in root.findall(".//PubmedArticle"):
            pmid_el = art.find(".//MedlineCitation/PMID")
            if pmid_el is None or not pmid_el.text:
                continue
            pmid = int(pmid_el.text.strip())

            title_el = art.find(".//Article/ArticleTitle")
            # itertext() keeps text nested inside inline markup of the title.
            title_text = "".join(title_el.itertext()).strip() if title_el is not None else ""

            abs_el = art.find(".//Article/Abstract")
            # One chunk per child element of <Abstract>, joined with spaces.
            abstract = " ".join("".join(x.itertext()).strip() for x in abs_el) if abs_el is not None else ""

            mesh_terms = []
            for mh in art.findall(".//MeshHeading"):
                desc = mh.find("DescriptorName")
                if desc is not None and desc.text:
                    mesh_terms.append(desc.text.strip())

            obj = {
                "pmid": pmid,
                "title": title_text,
                "abstract": abstract,
                "mesh": mesh_terms
            }
            PUBMED_CACHE[pmid] = obj
            records.append(obj)
        if records:
            _append_jsonl(PUBMED_CACHE_FILE, records)
        time.sleep(0.34)
    # merge with existing; unknown PMIDs get an empty placeholder
    out = {}
    for p in wanted:
        if p in PUBMED_CACHE:
            out[p] = PUBMED_CACHE[p]
        else:
            out[p] = {"pmid": p, "title": "", "abstract": "", "mesh": []}
    return out

# ---------- Utilities on sets ----------
def neighbors_from_icite(pmid: int) -> Set[int]:
    """Return the undirected neighbourhood of `pmid` from the iCite cache.

    The result is every PMID listed as citing it or as cited by it. A PMID
    with no (or an empty) cache record yields an empty set.
    """
    record = ICITE_CACHE.get(int(pmid))
    if record:
        citers = record.get("citedByPmids", [])
        references = record.get("citedPmids", [])
        return {*citers, *references}
    return set()

def ensure_loaded(pmids: Set[int]):
    """Make sure both iCite and PubMed data have been requested for `pmids`.

    Does nothing for an empty collection; otherwise delegates to
    fetch_icite and then fetch_pubmed_meta with the PMIDs coerced to int.
    """
    wanted = list(map(int, pmids))
    if wanted:
        fetch_icite(wanted)
        fetch_pubmed_meta(wanted)

# ---------- H_full construction ----------
def induce_H_full(C: Set[int]) -> Tuple[Set[int], Set[int]]:
    """Build the node set H and the frontier F for the community C.

    F is every cached iCite neighbour of a member of C that is not itself
    in C; H is C together with F. Returns (H_nodes, F).
    """
    frontier = set().union(*(neighbors_from_icite(member) for member in C))
    frontier -= C
    return set(C) | frontier, frontier

def deg_tot(pmid: int) -> int:
    """Total degree of `pmid`: length of its cached citer list plus length of
    its cached reference list (0 when the PMID is not in the iCite cache)."""
    record = ICITE_CACHE.get(int(pmid), {})
    n_citers = len(record.get("citedByPmids", []))
    n_references = len(record.get("citedPmids", []))
    return n_citers + n_references

def deg_H(pmid: int, H_nodes: Set[int]) -> int:
    """Number of distinct cached neighbours of `pmid` that belong to H_nodes."""
    inside_H = neighbors_from_icite(pmid).intersection(H_nodes)
    return len(inside_H)

def d_in_star(pmid: int, C: Set[int]) -> int:
    """Internal degree (k-in) of `pmid`: how many of its distinct cached
    neighbours are members of C. No weighting is applied."""
    return sum(1 for nbr in neighbors_from_icite(pmid) if nbr in C)

def tri_density_in_C(v: int, C: Set[int]) -> float:
    """Score how interlinked v's neighbours inside C are.

    With N = neighbours of v that lie in C and k = |N|, the score is
    m / (1 + k), where m is the number of links among the members of N
    (each link is seen from both endpoints, hence the halving). Returns 0.0
    when v has no neighbour in C.
    """
    anchors = neighbors_from_icite(v) & C
    if not anchors:
        return 0.0
    endpoint_hits = sum(len(neighbors_from_icite(u) & anchors) for u in anchors)
    links = endpoint_hits // 2
    return links / (1.0 + len(anchors))

def qualified_d_out(v: int, C: Set[int], H_nodes: Set[int], d_in_map: Dict[int,int]) -> float:
    """Weighted count of v's links that leave C.

    Links into C contribute nothing. A link to a node u in H (but outside C)
    contributes 1 - a(u), where a(u) = min(1, KAPPA_ANCHOR * d_in(u) / deg_tot(u))
    and d_in(u) comes from `d_in_map` when available. A link to a node outside
    H contributes a full 1.
    """
    total = 0.0
    for nbr in neighbors_from_icite(v):
        if nbr in C:
            continue
        anchor = 0.0
        if nbr in H_nodes:
            # Fall back to counting C-neighbours when the map has no entry.
            fallback = len(neighbors_from_icite(nbr) & C)
            din_nbr = d_in_map.get(nbr, fallback)
            share_in_C = din_nbr / max(1, deg_tot(nbr))
            anchor = min(1.0, KAPPA_ANCHOR * share_in_C)
        total += (1.0 - anchor)
    return total

def cut_C_F(C: Set[int], H_nodes: Set[int]) -> int:
    """Count the links that run between C and the rest of H (F = H_nodes - C),
    summed over the members of C."""
    frontier = H_nodes - C
    return sum(len(neighbors_from_icite(member) & frontier) for member in C)

def vol_C(C: Set[int], H_nodes: Set[int]) -> int:
    """Volume of C: the H-restricted degrees of its members added together."""
    volume = 0
    for member in C:
        volume += deg_H(member, H_nodes)
    return volume

def delta_phi_for_candidate(v: int, C: Set[int], H_nodes: Set[int], d_in_map: Dict[int,int]) -> Tuple[float, Tuple[float,float]]:
    """Estimate how the conductance of C changes if candidate v is added.

    The new cut removes v's links into C, adds v's qualified out-degree and,
    pessimistically, all of v's links that leave H. The new volume adds v's
    degree inside H. A conductance with non-positive volume is taken as 1.0.

    Returns (phi_new - phi_now, (phi_now, phi_new)).
    """
    def conductance(cut, vol):
        return cut / max(1e-9, vol) if vol > 0 else 1.0

    cut_before = cut_C_F(C, H_nodes)
    vol_before = vol_C(C, H_nodes)
    links_into_C = d_in_map[v]
    weighted_out = qualified_d_out(v, C, H_nodes, d_in_map)
    # pessimistic: every link of v that leaves H is charged to the new cut
    links_outside_H = deg_tot(v) - deg_H(v, H_nodes)
    cut_after = cut_before - links_into_C + weighted_out + links_outside_H
    vol_after = vol_before + deg_H(v, H_nodes)  # volume stays H-restricted

    phi_before = conductance(cut_before, vol_before)
    phi_after = conductance(cut_after, vol_after)
    return (phi_after - phi_before), (phi_before, phi_after)

def conductances_U_aware(C: Set[int], H_nodes: Set[int]) -> Tuple[float,float,float]:
    """Return the three diagnostics ``(phi_C_to_F, phi_F_to_U, phi_H_to_U)``.

    * ``phi_C_to_F``: ``cut_C_F / max(1, vol_C)``, or 1.0 when the volume of C
      is not positive.
    * ``phi_F_to_U``: over the fringe ``H_nodes - C``, the share of total
      degree (``deg_tot``) that is not H-internal (``deg_tot - deg_H``).
    * ``phi_H_to_U``: the same share computed over all of ``H_nodes``.
    """
    def leak_share(nodes):
        # fraction of the nodes' total degree that leaves H; denominator floored at 1
        escaping = sum((deg_tot(n) - deg_H(n, H_nodes)) for n in nodes)
        total = sum(deg_tot(n) for n in nodes) if nodes else 1
        return escaping / max(1, total)

    volume = vol_C(C, H_nodes)
    if volume > 0:
        phi_internal = cut_C_F(C, H_nodes) / max(1, volume)
    else:
        phi_internal = 1.0

    fringe = H_nodes - C
    phi_fringe = leak_share(fringe)
    phi_whole = leak_share(H_nodes)
    return phi_internal, phi_fringe, phi_whole

# ---------- Expansion driver ----------
def expand_two_phase(seed_pmids: List[int]):
    """Grow a community from ``seed_pmids`` in synchronous waves.

    Each wave rebuilds H around the current community C, scores every fringe
    node F, and accepts all nodes that pass the active gates at once:

    * waves ``1..BOOTSTRAP_WAVES``: ``d_in >= 1``, no triangle gate, ECR at or
      above the wave's ``Q_ECR_BOOT`` quantile;
    * later waves: ``d_in >= 2``, triangle density at or above the wave's
      ``Q_TRI_STRICT`` quantile, ECR at or above the ``Q_ECR_STRICT`` quantile
      (threshold capped at 1.0);
    * in both phases the estimated conductance change must not exceed
      ``EPSILON_PHI``.

    The loop ends after ``MAX_WAVES`` waves, when a wave accepts nothing, or
    when ``ICITE_CACHE`` holds more than ``MAX_PMIDS_TO_FETCH`` records.

    Args:
        seed_pmids: PMIDs forming the initial community.

    Returns:
        ``(C, history, accepted_titles_per_wave)``: the final community set, one
        diagnostics dict per wave, and per wave the list of ``(pmid, title)``
        pairs that were accepted.
    """
    C = set(int(p) for p in seed_pmids)
    ensure_loaded(C)

    history = []
    accepted_titles_per_wave = []

    for wave in range(1, MAX_WAVES+1):
        # Build H_full, load the fringe records (neighbour lists, deg_tot,
        # titles), then rebuild: H may grow once those records are available.
        H_nodes, F = induce_H_full(C)
        ensure_loaded(F)
        H_nodes, F = induce_H_full(C)
        has_frontier = len(F) > 0

        # diagnostics (U-aware)
        phi_CtoF, phi_FtoU, phi_HtoU = conductances_U_aware(C, H_nodes)
        R = phi_CtoF / (phi_FtoU + 1e-9)

        # per-candidate features
        d_in_map = {v: d_in_star(v, C) for v in F}
        tri_map = {v: tri_density_in_C(v, C) for v in F}
        d_out_raw = {v: (deg_H(v, H_nodes) - d_in_map[v]) for v in F}
        ECR_map = {v: d_in_map[v] / (d_out_raw[v] + LAMBDA_OUTU*(deg_tot(v)-deg_H(v, H_nodes)) + 1e-9) for v in F}
        dphi_map = {v: delta_phi_for_candidate(v, C, H_nodes, d_in_map)[0] for v in F}

        # thresholds per phase; an empty frontier yields 0.0 thresholds
        ecr_vals = [ECR_map[v] for v in F]
        if wave <= BOOTSTRAP_WAVES:
            k_in = 1
            use_triangle = False
            tau_tri = None
            tau_ecr = float(np.quantile(ecr_vals, Q_ECR_BOOT)) if has_frontier else 0.0
        else:
            k_in = 2
            use_triangle = True
            tri_vals = [tri_map[v] for v in F]
            tau_tri = float(np.quantile(tri_vals, Q_TRI_STRICT)) if has_frontier else 0.0
            tau_ecr = min(1.0, float(np.quantile(ecr_vals, Q_ECR_STRICT))) if has_frontier else 0.0

        # synchronous acceptance: all gates are evaluated against the same C
        accepted = []
        for v in F:
            if d_in_map[v] < k_in:
                continue
            if use_triangle and tri_map[v] < tau_tri:
                continue
            if ECR_map[v] < tau_ecr:
                continue
            if dphi_map[v] > EPSILON_PHI:
                continue
            accepted.append(v)

        # frontier summary statistics, shared by the log line and the history record
        if has_frontier:
            avg_d_in = float(np.mean([d_in_map[v] for v in F]))
            avg_tri = float(np.mean([tri_map[v] for v in F]))
            avg_ecr = float(np.mean(ecr_vals))
            best_dphi = float(min(dphi_map.values()))
        else:
            avg_d_in = avg_tri = avg_ecr = best_dphi = 0.0

        # logging
        print(f"\n=== Wave {wave} ===")
        print(f"Frontier size: {len(F)} | Accepted: {len(accepted)}")
        print(f"k_in={k_in} | use_triangle={use_triangle} | tau_triangle={tau_tri} | tau_ECR={tau_ecr:.4f}")
        print(f"phi_C→F={phi_CtoF:.4f} | phi_F→U={phi_FtoU:.4f} | phi_H→U={phi_HtoU:.4f} | R={R:.4f}")
        if has_frontier:
            print(f"Frontier averages: d_in={avg_d_in:.3f} | tri_dens={avg_tri:.3f} | ECR={avg_ecr:.3f} | best Δφ={best_dphi:.5f}")

        # print accepted with titles (PubMed title preferred, iCite as fallback)
        ensure_loaded(accepted)
        titles_this_wave = []
        for v in accepted:
            label = PUBMED_CACHE.get(v, {}).get("title") or ICITE_CACHE.get(v, {}).get("title") or "(no title)"
            print(f"ACCEPTED: {v} — {label}")
            titles_this_wave.append((v, label))
        accepted_titles_per_wave.append(titles_this_wave)

        # commit
        C |= set(accepted)

        history.append({
            "wave": wave,
            "frontier_size": len(F),
            "accepted": len(accepted),
            "k_in": k_in,
            "use_triangle": use_triangle,
            "tau_triangle": tau_tri,
            "tau_ECR": tau_ecr,
            "phi_CtoF": phi_CtoF,
            "phi_FtoU": phi_FtoU,
            "phi_HtoU": phi_HtoU,
            "R": R,
            "avg_d_in": avg_d_in,
            "avg_tri_dens": avg_tri,
            "avg_ECR": avg_ecr,
            "best_delta_phi": best_dphi
        })

        if len(accepted)==0:
            print("Stopping: no candidates pass gates.")
            break

        if len(ICITE_CACHE) > MAX_PMIDS_TO_FETCH:
            print("Stopping: reached fetch ceiling.")
            break

    return C, history, accepted_titles_per_wave

# ---------- TF-IDF & MeSH summaries ----------
def tfidf_top_terms(pmids: Set[int], k: int = 30) -> List[Tuple[str, float]]:
    """Return the ``k`` terms with the largest summed TF-IDF weight.

    One document is built per PMID from the cached PubMed title and abstract
    (missing fields count as empty text). Unigrams and bigrams are scored with
    English stop words removed.

    Terms are normally restricted to those appearing in at least two documents
    and in at most 90% of them. ``TfidfVectorizer`` raises ``ValueError`` when
    that restriction cannot be met (for example a one- or two-document
    community) or when no vocabulary remains; in that case the scoring is
    retried without document-frequency pruning, and an empty list is returned
    if that fails as well, so a tiny community does not abort the final report.

    Args:
        pmids: PMIDs whose cached text is analysed.
        k: maximum number of terms to return.

    Returns:
        ``(term, score)`` pairs sorted by descending score; ``[]`` when there
        are no documents or no usable vocabulary.
    """
    docs = []
    for p in pmids:
        meta = PUBMED_CACHE.get(int(p), {})
        docs.append((meta.get("title") or "") + " " + (meta.get("abstract") or ""))
    if not docs:
        return []

    try:
        vec = TfidfVectorizer(min_df=2, max_df=0.9, ngram_range=(1,2), stop_words="english")
        X = vec.fit_transform(docs)
    except ValueError:
        # Corpus too small for the min_df/max_df window, or every term pruned.
        try:
            vec = TfidfVectorizer(ngram_range=(1,2), stop_words="english")
            X = vec.fit_transform(docs)
        except ValueError:
            # Nothing but stop words / empty text: no terms to report.
            return []

    scores = np.asarray(X.sum(axis=0)).ravel()
    terms = vec.get_feature_names_out()
    order = np.argsort(-scores)
    # plain str/float so the result matches the annotated return type
    return [(str(terms[i]), float(scores[i])) for i in order[:k]]

def top_mesh_terms(pmids: Set[int], k: int = 20) -> List[Tuple[str,int]]:
    """Return the ``k`` most frequent MeSH headings among ``pmids``.

    Headings are read from the ``"mesh"`` list of each PMID's ``PUBMED_CACHE``
    entry; PMIDs without an entry or without that list contribute nothing.
    """
    tally = Counter(
        heading
        for pmid in pmids
        for heading in PUBMED_CACHE.get(int(pmid), {}).get("mesh", [])
    )
    return tally.most_common(k)

# ---------- Centrality on final C ----------
def centrality_leaderboards(C: Set[int]) -> Dict[str, List[Tuple[int, float]]]:
    """Rank the members of ``C`` by degree, betweenness and eigenvector centrality.

    An undirected graph is built on ``C`` with an edge wherever a node lists a
    higher-numbered member of ``C`` among its iCite neighbours. Each
    leaderboard holds the five best ``(pmid, score)`` pairs, highest score
    first and ties broken by ascending PMID.

    Returns:
        Dict with the keys ``"degree"``, ``"betweenness"`` and ``"eigenvector"``.
    """
    def top_five(scores):
        # highest score first; equal scores ordered by ascending PMID
        return sorted(scores.items(), key=lambda item: (-item[1], item[0]))[:5]

    ordered = list(C)
    members = set(ordered)

    G = nx.Graph()
    G.add_nodes_from(C)
    for u in ordered:
        linked = neighbors_from_icite(u) & members
        # add each pair once, from its lower-numbered endpoint
        G.add_edges_from((u, v) for v in linked if u < v)

    degree_scores = dict(G.degree())  # plain degree, not normalised
    between_scores = nx.betweenness_centrality(G, normalized=True, k=None)  # exact; small graphs OK

    try:
        eigen_scores = nx.eigenvector_centrality(G, max_iter=2000)
    except Exception:
        # Fall back to the largest connected component; every other node scores 0.0.
        if G.number_of_nodes():
            biggest = max(nx.connected_components(G), key=len)
        else:
            biggest = set()
        partial = nx.eigenvector_centrality(G.subgraph(biggest), max_iter=2000) if biggest else {}
        eigen_scores = {n: (partial.get(n, 0.0)) for n in G.nodes()}

    return {
        "degree": top_five(degree_scores),
        "betweenness": top_five(between_scores),
        "eigenvector": top_five(eigen_scores)
    }

def title(pmid: int) -> str:
    """Return the title for ``pmid``: PubMed cache first, then iCite, else a placeholder."""
    key = int(pmid)
    for cache in (PUBMED_CACHE, ICITE_CACHE):
        found = cache.get(key, {}).get("title")
        if found:
            return found
    return f"(no title) PMID {pmid}"

# ---------- Main ----------
# Script entry point: expand the community from SEED_PMIDS, print a summary of
# the result (conductances, TF-IDF terms, MeSH headings, centrality leaders,
# per-wave accepted titles) and write the final community to a CSV file.
if __name__ == "__main__":
    print("Seeds:", SEED_PMIDS)
    # expand_two_phase() loads the seeds as well; this call makes sure their
    # records are present before the expansion starts.
    ensure_loaded(SEED_PMIDS)

    final_C, history, accepted_titles = expand_two_phase(SEED_PMIDS)

    # Summary: size of the final community and its U-aware conductances,
    # measured on H rebuilt around the final community.
    print("\n=== Final community summary ===")
    print(f"|C| = {len(final_C)}")
    H_nodes, F = induce_H_full(final_C)
    phi_CtoF, phi_FtoU, phi_HtoU = conductances_U_aware(final_C, H_nodes)
    print(f"phi_C→F={phi_CtoF:.4f} | phi_F→U={phi_FtoU:.4f} | phi_H→U={phi_HtoU:.4f}")

    # Semantic quick check: TF-IDF top terms (title+abstract) and top-20 MeSH
    ensure_loaded(final_C)
    top_terms = tfidf_top_terms(final_C, k=30)
    top_mesh = top_mesh_terms(final_C, k=20)

    # 30 terms are computed, but only the first 20 are printed.
    print("\nTop TF-IDF terms (title+abstract):")
    for term, score in top_terms[:20]:
        print(f"  {term:40s}  {score:.3f}")

    print("\nTop 20 MeSH terms (by frequency):")
    for mh, cnt in top_mesh:
        print(f"  {mh:50s}  {cnt}")

    # Centrality leaderboards on final C
    leaders = centrality_leaderboards(final_C)

    def dump_leaders(name, lst):
        """Print one leaderboard: PMID, score to five decimals, and title."""
        print(f"\nTop-5 by {name} centrality:")
        for pmid, val in lst:
            print(f"  {pmid}  ({val:.5f})  — {title(pmid)}")

    dump_leaders("degree", leaders["degree"])
    dump_leaders("betweenness", leaders["betweenness"])
    dump_leaders("eigenvector", leaders["eigenvector"])

    # Also print the full title list of every entry accepted on each wave
    print("\n=== Accepted titles per wave (full) ===")
    for w, lst in enumerate(accepted_titles, start=1):
        print(f"\nWave {w}: {len(lst)} accepted")
        for pmid, t in lst:
            print(f"  {pmid} — {t}")

    # Write a CSV of the final community with core fields, one row per PMID in
    # ascending order. The title prefers the PubMed cache and falls back to
    # iCite; year, doi, rcr and acr are read from the iCite record; MeSH
    # headings are joined with "; ".
    rows = []
    for p in sorted(final_C):
        meta = PUBMED_CACHE.get(p, {})
        ic = ICITE_CACHE.get(p, {})
        rows.append({
            "pmid": p,
            "title": meta.get("title") or ic.get("title"),
            "year": ic.get("year"),
            "doi": ic.get("doi"),
            "rcr": ic.get("rcr"),
            "acr": ic.get("acr"),
            "mesh": "; ".join(meta.get("mesh", [])),
        })
    df = pd.DataFrame(rows)
    # Output goes to the current working directory, without the index column.
    df.to_csv("final_community.csv", index=False)
    print("\nWrote final_community.csv")


Seeds: [37693640, 37364610, 32929487, 30935731]

=== Wave 1 ===
Frontier size: 224 | Accepted: 4
k_in=1 | use_triangle=False | tau_triangle=None | tau_ECR=0.0605
phi_C→F=0.9816 | phi_F→U=0.9504 | phi_H→U=0.9463 | R=1.0328
Frontier averages: d_in=1.429 | tri_dens=0.116 | ECR=0.043 | best Δφ=-0.01363
ACCEPTED: 37973418 — Lessons Learned after 176 Patients Treated with a Standardized Procedure of Thoracoscopic Cryoanalgesia during Minimally Invasive Repair of Pectus Excavatum.
ACCEPTED: 39363012 — Effect of cryoablation in Nuss bar placement on opioid utilization and length of stay.
ACCEPTED: 38242172 — Intercostal Nerve Cryoablation or Epidural Analgesia for Multimodal Pain Management after the Nuss Procedure: A Cohort Study.
ACCEPTED: 36922280 — Impact of Cryoanalgesia Use During Minimally Invasive Pectus Excavatum Repair on Hospital Days and Total Hospital Costs Among Pediatric Patients.

=== Wave 2 ===
Frontier size: 223 | Accepted: 2
k_in=1 | use_triangle=False | tau_triangle=None | 