In [11]:
#!/usr/bin/env python
# =============================================================================
# SCRIPT: recursive_dag_expansion.py
# =============================================================================
# Purpose:
#   - Start from a PubMed query or list of DOIs.
#   - Recursively expand references (“parents”) and citations (“children”) up to N generations.
#   - Use the V14 pipeline for:
#       • ESearch → PMIDs (with WebEnv/QueryKey)
#       • EFetch → dates, DOI, references
#       • CrossRef augmentation for low‐ref articles
#       • ELink (POST batching) → citation edges
#       • Chronological filtering & DAG enforcement
#   - Benchmark per‐generation node/edge growth and total runtime.
#   - Produce an interactive network plot with nodes colored by generation.
# =============================================================================
import requests
import time
import datetime
import xml.etree.ElementTree as ET
import networkx as nx
import matplotlib.pyplot as plt
from tqdm import tqdm

# --- Configuration ---
SEED_QUERY    = None
SEED_DOIS     = ["10.1053/j.jvca.2024.10.005"]
MAX_DEPTH     = 5
SEARCH_MAX    = 100
ELINK_BATCH   = 100
AUGMENT_REFS        = True
AUGMENT_THRESHOLD   = 5
CROSSREF_FILTER_BATCH_SIZE = 50
DOI_ES_BATCH        = 100
SLEEP_SECONDS       = 0.1

NCBI_TOOL     = "recursive_dag_expansion"
NCBI_EMAIL    = "levi4328@gmail.com"
NCBI_API_KEY  = "44d5c1b49a9ed02ae1fc52fa9d01e148e009"

# --- Session Setup ---
session = requests.Session()
session_cr = requests.Session()

def _ncbi_params(extra=None):
    p = {"tool":NCBI_TOOL}
    if NCBI_EMAIL:    p["email"]   = NCBI_EMAIL
    if NCBI_API_KEY:  p["api_key"] = NCBI_API_KEY
    if extra: p.update(extra)
    return p

# --- 1) ESearch for PMIDs or DOI→PMID via OR-query ---
def esearch_pmids(term, retmax):
    r = session.get("https://eutils.ncbi.nlm.nih.gov/entrez/eutils/esearch.fcgi",
                    params=_ncbi_params({"db":"pubmed","term":term,"retmax":retmax,"usehistory":"y"}))
    r.raise_for_status(); root = ET.fromstring(r.content)
    ids = [e.text for e in root.findall(".//IdList/Id")]
    we, qk = root.findtext(".//WebEnv"), root.findtext(".//QueryKey")
    return ids, we, qk

def dois_to_pmids(dois, batch_size=DOI_ES_BATCH):
    doi2pm = {}
    for i in range(0,len(dois),batch_size):
        batch = dois[i:i+batch_size]
        term  = " OR ".join(f"{d}[doi]" for d in batch)
        r = session.post("https://eutils.ncbi.nlm.nih.gov/entrez/eutils/esearch.fcgi",
                         data=_ncbi_params({"db":"pubmed","term":term,"retmode":"json","retmax":str(len(batch))}))
        r.raise_for_status()
        ids = r.json()["esearchresult"].get("idlist",[])
        for d,p in zip(batch,ids):
            doi2pm[d]=p
        time.sleep(SLEEP_SECONDS)
    return doi2pm

# --- 2) EFetch: dates, refs, DOI ---
def fetch_details(pmids, we=None, qk=None):
    params = {"db":"pubmed","retmode":"xml","rettype":"abstract","retmax":len(pmids)}
    if we and qk:
        params.update({"WebEnv":we,"query_key":qk})
        r = session.post("https://eutils.ncbi.nlm.nih.gov/entrez/eutils/efetch.fcgi",
                         data=_ncbi_params(params))
    else:
        params["id"] = ",".join(pmids)
        r = session.post("https://eutils.ncbi.nlm.nih.gov/entrez/eutils/efetch.fcgi",
                         data=_ncbi_params(params))
    r.raise_for_status(); root = ET.fromstring(r.content)
    MONTH = {m:i for i,m in enumerate(["Jan","Feb","Mar","Apr","May","Jun","Jul","Aug","Sep","Oct","Nov","Dec"],1)}
    out = {}
    for art in root.findall(".//PubmedArticle"):
        pm = art.findtext(".//PMID")
        # date
        ad = art.find(".//ArticleDate") or art.find(".//JournalIssue/PubDate")
        dt = None
        if ad is not None:
            y,mo,da = ad.findtext("Year"),ad.findtext("Month"),ad.findtext("Day")
            if y and (mo or da):
                mnum = int(mo) if mo and mo.isdigit() else MONTH.get(mo[:3],1)
                dnum = int(da) if da and da.isdigit() else 1
                try: dt = datetime.date(int(y),mnum,dnum)
                except: pass
        # refs
        refs = [r.findtext(".//ArticleId[@IdType='pubmed']")
                for r in art.findall(".//ReferenceList/Reference")
                if r.findtext(".//ArticleId[@IdType='pubmed']")]
        # doi
        doi = art.findtext(".//ArticleIdList/ArticleId[@IdType='doi']")
        out[pm] = {"date":dt, "refs":refs, "doi":doi}
    return out

# --- 3) CrossRef augment ---
def crossref_refs(dois):
    out = {}
    for i in range(0,len(dois),CROSSREF_FILTER_BATCH_SIZE):
        batch = dois[i:i+CROSSREF_FILTER_BATCH_SIZE]
        fv    = ",".join(f"doi:{d}" for d in batch)
        r = session_cr.get("https://api.crossref.org/works",
                           params={"filter":fv,"rows":len(batch)},
                           headers={"User-Agent":f"{NCBI_TOOL}({NCBI_EMAIL})"})
        items = r.json().get("message",{}).get("items",[])
        for itm in items:
            doi0 = itm.get("DOI")
            refs = [ref.get("DOI") for ref in itm.get("reference",[]) if ref.get("DOI")]
            out[doi0] = refs
        time.sleep(SLEEP_SECONDS)
    return out

# --- 4) ELink citations ---
def fetch_citations(pmids):
    out = {}
    def batch_post(batch):
        data = [("dbfrom","pubmed"),("linkname","pubmed_pubmed_citedin"),("cmd","neighbor")]
        data += [("id",p) for p in batch] + list(_ncbi_params().items())
        r = session.post("https://eutils.ncbi.nlm.nih.gov/entrez/eutils/elink.fcgi", data=data)
        r.raise_for_status(); return r.content
    for i in range(0,len(pmids),ELINK_BATCH):
        batch = pmids[i:i+ELINK_BATCH]
        raw   = batch_post(batch)
        root  = ET.fromstring(raw)
        for ls in root.findall(".//LinkSet"):
            src = ls.findtext(".//IdList/Id")
            cits= [c.text for c in ls.findall(".//LinkSetDb/Link/Id") if c.text]
            out[src] = cits
        time.sleep(SLEEP_SECONDS)
    return out

# --- 5) Chronological & DAG checks ---
def remove_time_invalid_edges(G):
    to_remove=[]
    for u,v in G.edges():
        du, dv = G.nodes[u].get("date"), G.nodes[v].get("date")
        if du and dv and du > dv:
            to_remove.append((u,v))
    G.remove_edges_from(to_remove)
    return to_remove

def enforce_acyclic(G):
    if not nx.is_directed_acyclic_graph(G):
        for cycle in nx.simple_cycles(G):
            # remove one edge per cycle
            G.remove_edge(cycle[-1], cycle[0])
        return False
    return True

# --- Recursive Expansion ---
def recursive_expand(pmids):
    G = nx.DiGraph()
    gen_map = {pm:0 for pm in pmids}
    for pm in pmids: G.add_node(pm, date=None, generation=0)

    current = pmids
    for gen in range(1, MAX_DEPTH+1):
        # fetch details + citing
        details = fetch_details(current)
        # CrossRef augment
        if AUGMENT_REFS:
            cands = [pm for pm,d in details.items() if len(d["refs"])<AUGMENT_THRESHOLD and d["doi"]]
            cr   = crossref_refs([details[pm]["doi"] for pm in cands])
            doi2pm = dois_to_pmids(sum(cr.values(), []))
            for pm in cands:
                mapped = [doi2pm.get(d) for d in cr.get(details[pm]["doi"],[])]
                details[pm]["refs"].extend([m for m in mapped if m])

        # parents
        parents = set(sum((d["refs"] for d in details.values()), []))
        # citations
        cits    = set(sum((fetch_citations(current).values()), []))

        next_gen = list((parents|cits) - set(G.nodes()))
        # add nodes + edges
        for pm in next_gen:
            G.add_node(pm, generation=gen, date=None)
        # add edges
        for src,info in details.items():
            for r in info["refs"]:
                if G.has_node(src) and G.has_node(r):
                    G.add_edge(r, src, provenance="reference")
        for src,clist in fetch_citations(current).items():
            for c in clist:
                if G.has_node(src) and G.has_node(c):
                    G.add_edge(src, c, provenance="citation")

        # chronology & cycle enforcement
        removed = remove_time_invalid_edges(G)
        enforce_acyclic(G)

        current = next_gen

    return G

# --- Main ---
if __name__=="__main__":
    if SEED_DOIS:
        seed_pmids = list(dois_to_pmids(SEED_DOIS).values())
    elif SEED_QUERY:
        seed_pmids, we, qk = esearch_pmids(SEED_QUERY, SEARCH_MAX)
    else:
        raise ValueError("Set SEED_DOIS or SEED_QUERY")

    print("Seed PMIDs:", seed_pmids)
    G = recursive_expand(seed_pmids)

    # Plot with matplotlib
    pos = nx.spring_layout(G, k=0.5, iterations=50)
    colors = [G.nodes[n]["generation"] for n in G.nodes()]
    plt.figure(figsize=(10,10))
    nx.draw_networkx_nodes(G, pos, node_color=colors, cmap="viridis", node_size=50)
    nx.draw_networkx_edges(G, pos, alpha=0.3, arrows=False)
    plt.title("Recursive Citation DAG (nodes by generation)")
    plt.axis("off")
    plt.show()


Seed PMIDs: ['39489669']


NetworkXError: The edge 28713211-20626510 not in graph.

In [None]:
# ============================================================
# CELL: Recursive Citation Expansion DAG Builder V15
# ============================================================
# Purpose
# -------
# 1.  Bootstrap from either a PubMed query or a list of DOIs.
# 2.  Build a *directed*, time‑consistent citation graph.
# 3.  Recursively expand   P₁…Pₙ  (references / “parents”)
#                        + C₁…Cₙ  (citations / “children”)
#     in two modes:
#       • linear  – chain:  Pₙ→⋯→P₁→Seed→C₁→⋯→Cₙ
#       • global  – snow‑ball: expand parents **and** children of the *current*
#                                 frontier at every depth.
# 4.  Threshold filter: a new node is accepted only if it has ≥ T
#     (absolute ‑or‑ relative %) links *into* the frontier that spawned it.
# 5.  CrossRef‑based augmentation for low‑reference articles
#     (same batching & OR‑ESearch DOI→PMID mapping as V14).
# 6.  Strict DAG enforcement:
#       • drop time‑inverted edges   (src_date > tgt_date)
#       • after each layer, verify graph is acyclic; remove offending edges.
# 7.  Benchmark & log per generation:
#       nodes added, edges added, cumulative nodes, run‑time.
# 8.  Visualise final graph:
#       nodes coloured by generation, layers laid out left→right.
#
# Requirements
# ------------
#   pip install requests networkx matplotlib pandas tqdm lxml
# ============================================================

import requests, time, datetime, xml.etree.ElementTree as ET, networkx as nx
import pandas as pd, matplotlib.pyplot as plt
from tqdm import tqdm
from requests.exceptions import ChunkedEncodingError, RequestException
from urllib3.exceptions import ProtocolError

# ── User Config ──────────────────────────────────────────────
SEED_QUERY      = None
SEED_DOIS       = ["10.1053/j.jvca.2024.10.005"]
MAX_DEPTH       = 4

EXPANSION_MODE  = "linear"     # "linear" or "global"
THRESHOLD_TYPE  = "relative"   # "absolute" or "relative"
THRESHOLD_VAL   = 0.05         # ≥ count  (abs)   or ≥ p% (rel)

AUGMENT_REFS          = True
AUGMENT_THRESHOLD     = 3      # if < this many refs, augment via CrossRef
CROSSREF_BATCH        = 50

BATCH_SIZE            = 200    # for EFetch / ELink
SLEEP_SECONDS         = 0.1
NCBI_TOOL             = "dag_checker_v15"
NCBI_EMAIL            = "levi4328@gmail.com"
NCBI_API_KEY          = "YOUR_NCBI_KEY"  # optional

# ── NCBI helpers ─────────────────────────────────────────────
session = requests.Session()
def ncbi(extra=None):
    p = {"tool": NCBI_TOOL, "email": NCBI_EMAIL}
    if NCBI_API_KEY: p["api_key"] = NCBI_API_KEY
    if extra: p.update(extra)
    return p

# ── Bootstrap functions ─────────────────────────────────────
def esearch_to_pmids(query, retmax=100):
    r = session.get("https://eutils.ncbi.nlm.nih.gov/entrez/eutils/esearch.fcgi",
                    params=ncbi({"db":"pubmed","term":query,
                                 "retmax":str(retmax),"usehistory":"n"}), timeout=30)
    r.raise_for_status(); root = ET.fromstring(r.content)
    return [e.text for e in root.findall(".//IdList/Id")]

def doi_batch_esearch(dois, bsz=100):
    """
    Batch-convert DOIs → PMIDs via OR-joined ESearch POST.
    Wraps each DOI in quotes:  "10.xxxx/yyy"[doi].
    """
    doi2pmid = {}
    for i in range(0, len(dois), bsz):
        batch = dois[i:i+bsz]
        # wrap each DOI in quotes to avoid Bad Request
        term = " OR ".join(f'"{d}"[doi]' for d in batch)
        params = ncbi({
            "db":      "pubmed",
            "term":    term,
            "retmode": "json",
            "retmax":  str(len(batch))
        })
        try:
            resp = session.post(
                "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/esearch.fcgi",
                data=params,
                timeout=60
            )
            # if PubMed still complains, dump the body for debugging
            if not resp.ok:
                print("⛔ ESearch DOI→PMID batch failed:",
                      resp.status_code, resp.reason)
                print("Raw response:\n", resp.text)
            resp.raise_for_status()
            idlist = resp.json().get("esearchresult", {}).get("idlist", [])
            # map 1:1 for as many returned IDs as we got
            for doi, pm in zip(batch, idlist):
                doi2pmid[doi] = pm
        except Exception as e:
            print(f"⚠️ Batch #{i//bsz+1} DOI→PMID error:", e)
        time.sleep(SLEEP_SECONDS)
    print(f"   Mapped {len(doi2pmid)}/{len(dois)} DOIs → PMIDs")
    return doi2pmid


# ── EFetch details + DOI extraction ─────────────────────────
MONTH = {m:i for i,m in enumerate(
    ["Jan","Feb","Mar","Apr","May","Jun","Jul","Aug","Sep","Oct","Nov","Dec"],1)}
def fetch_details(pmids):
    out={}
    for chunk in tqdm(range(0,len(pmids),BATCH_SIZE), desc="EFetch"):
        batch=pmids[chunk:chunk+BATCH_SIZE]
        r=session.post("https://eutils.ncbi.nlm.nih.gov/entrez/eutils/efetch.fcgi",
                       data=ncbi({"db":"pubmed","retmode":"xml","id":",".join(batch)}),
                       timeout=120)
        r.raise_for_status(); root=ET.fromstring(r.content)
        for art in root.findall(".//PubmedArticle"):
            pm=art.findtext(".//PMID")
            ad = art.find(".//ArticleDate") or art.find(".//PubDate")
            date=None
            if ad is not None:
                y,mo,da = ad.findtext("Year"),ad.findtext("Month"),ad.findtext("Day")
                if y and (mo or da):
                    m=int(mo) if mo and mo.isdigit() else MONTH.get(mo[:3],1)
                    d=int(da) if da and da.isdigit() else 1
                    try: date=datetime.date(int(y),m,d)
                    except: pass
            refs=[ r.findtext("./ArticleId[@IdType='pubmed']")
                   for r in art.findall(".//ReferenceList/Reference")
                   if r.findtext("./ArticleId[@IdType='pubmed']") ]
            doi = art.findtext(".//ArticleIdList/ArticleId[@IdType='doi']")
            out[pm]={"date":date,"refs":refs,"doi":doi}
        time.sleep(SLEEP_SECONDS)
    return out

# ── ELink citations ─────────────────────────────────────────
def fetch_citations(pmids):
    out={}
    def _call(batch):
        data=[("dbfrom","pubmed"),("linkname","pubmed_pubmed_citedin"),("cmd","neighbor")]
        data+=[("id",p) for p in batch]+list(ncbi().items())
        for attempt in range(3):
            try:
                r=session.post("https://eutils.ncbi.nlm.nih.gov/entrez/eutils/elink.fcgi",
                               data=data,timeout=120)
                r.raise_for_status(); return r.content
            except (ChunkedEncodingError,ProtocolError,RequestException):
                time.sleep(SLEEP_SECONDS*(attempt+1))
        # split fallback
        if len(batch)>1:
            mid=len(batch)//2
            return (_call(batch[:mid]) or b"")+(_call(batch[mid:]) or b"")
        return b""
    for chunk in tqdm(range(0,len(pmids),BATCH_SIZE), desc="ELink"):
        raw=_call(pmids[chunk:chunk+BATCH_SIZE])
        if not raw: continue
        root=ET.fromstring(raw)
        for ls in root.findall(".//LinkSet"):
            src=ls.findtext(".//IdList/Id")
            outs=[c.text for c in ls.findall(".//LinkSetDb/Link/Id") if c.text]
            out[src]=outs
        time.sleep(SLEEP_SECONDS)
    return out

# ── CrossRef augmentation ──────────────────────────────────
def crossref_refs(dois):
    out={}
    for i in range(0,len(dois),CROSSREF_BATCH):
        batch=dois[i:i+CROSSREF_BATCH]
        filt=",".join(f"doi:{d}" for d in batch)
        r=requests.get("https://api.crossref.org/works",
                       params={"filter":filt,"rows":len(batch),"select":"DOI,reference"},
                       headers={"User-Agent":f"{NCBI_TOOL}({NCBI_EMAIL})"},timeout=60)
        items=r.json().get("message",{}).get("items",[])
        for it in items:
            doi=it.get("DOI"); refs=[ref.get("DOI") for ref in it.get("reference",[]) if ref.get("DOI")]
            out[doi]=refs
        time.sleep(SLEEP_SECONDS)
    return out

# ── DAG / chronology helpers ───────────────────────────────
def add_edge_safe(G,u,v,prov="ref"):
    du,dv=G.nodes[u].get("date"),G.nodes[v].get("date")
    if du and dv and du>dv:         # time‑inverted ⇒ skip
        return False
    G.add_edge(u,v,prov=prov); return True

# ── Core expansion ─────────────────────────────────────────
def meets(cnt,total):
    if THRESHOLD_TYPE=="absolute":  return cnt>=THRESHOLD_VAL
    return (cnt/total)>=THRESHOLD_VAL if total else False

def expand_graph(seed_pmids):
    G=nx.DiGraph()
    for pm in seed_pmids: G.add_node(pm,gen=0,src="seed")
    frontier=seed_pmids[:]; visited=set(seed_pmids)
    metrics=[]
    for depth in range(1,MAX_DEPTH+1):
        print(f"\n=== DEPTH {depth} ({EXPANSION_MODE}) ===")
        # ---------- Parents ----------
        det=fetch_details(frontier)
        # CrossRef augmentation
        if AUGMENT_REFS:
            low=[p for p,d in det.items() if len(d["refs"])<AUGMENT_THRESHOLD and d["doi"]]
            if low:
                cr = crossref_refs([det[p]["doi"] for p in low])
                all_dois=set(r for lst in cr.values() for r in lst)
                m = doi_batch_esearch(list(all_dois))
                for p in low:
                    add=[m[d] for d in cr.get(det[p]["doi"],[]) if d in m]
                    det[p]["refs"].extend(add)
        # candidate counts
        par_count={}
        for p,info in det.items():
            for r in info["refs"]:
                par_count.setdefault(r,0); par_count[r]+=1
        total=len(frontier)
        parents=[r for r,c in par_count.items() if meets(c,total)]
        # ---------- Children ----------
        cit_map=fetch_citations(frontier)
        chi_count={}
        for p,cl in cit_map.items():
            for c in cl:
                chi_count.setdefault(c,0); chi_count[c]+=1
        children=[c for c,cnt in chi_count.items() if meets(cnt,total)]

        # linear vs global frontier update
        if EXPANSION_MODE=="linear":
            next_frontier=children
        else:                      # global
            next_frontier=list(set(parents+children))

        # Add nodes / edges
        new_nodes=[n for n in next_frontier if n not in visited]
        det_new=fetch_details(new_nodes) if new_nodes else {}
        for n in new_nodes:
            G.add_node(n,gen=depth,src="parent" if n in parents else "child",
                       date=det_new.get(n,{}).get("date"))
        # edges parent→frontier
        e_added=0
        for p in frontier:
            for r in det[p]["refs"]:
                if r in parents and r in G:
                    if add_edge_safe(G,r,p,"ref"): e_added+=1
        # edges frontier→child
        for p,cl in cit_map.items():
            for c in cl:
                if c in children and c in G:
                    if add_edge_safe(G,p,c,"cit"): e_added+=1
        visited.update(new_nodes)
        metrics.append({"depth":depth,"new_nodes":len(new_nodes),
                        "cum_nodes":G.number_of_nodes(),"edges":e_added})
        frontier=next_frontier
    return G,pd.DataFrame(metrics)

# ── PLOT helper ─────────────────────────────────────────────
def plot_generations(G,title="Citation DAG"):
    # arrange layers horizontally
    layers={}
    for n,data in G.nodes(data=True):
        g=data.get("gen",0); layers.setdefault(g,[]).append(n)
    pos={}
    for g,nodes in layers.items():
        for i,n in enumerate(nodes):
            pos[n]=(g, -i)          # x=gen, y=stack
    cmap=plt.cm.get_cmap("viridis", max(layers)+1)
    colors=[cmap[G.nodes[n]['gen']] for n in G.nodes()]
    plt.figure(figsize=(10,6))
    nx.draw(G,pos,with_labels=False,node_size=30,edge_color="#999",
            node_color=colors)
    sm=plt.cm.ScalarMappable(cmap=cmap,norm=plt.Normalize(0,max(layers)))
    sm.set_array([]); cbar=plt.colorbar(sm,ticks=range(max(layers)+1))
    cbar.set_label("Generation")
    plt.title(title); plt.axis("off"); plt.tight_layout(); plt.show()

# ── RUN ─────────────────────────────────────────────────────
if __name__=="__main__":
    if SEED_DOIS:
        seeds=list(doi_batch_esearch(SEED_DOIS).values())
    elif SEED_QUERY:
        seeds=esearch_to_pmids(SEED_QUERY,retmax=100)
    else:
        raise ValueError("Provide SEED_DOIS or SEED_QUERY")

    print("Seed PMIDs:",seeds)
    G,df=expand_graph(seeds)
    print("\n=== Metrics ===")
    print(df)

    print("\nDAG check:", "✅" if nx.is_directed_acyclic_graph(G) else "❌ (cycles!)")
    plot_generations(G,f"{EXPANSION_MODE.capitalize()} expansion (depth ≤ {MAX_DEPTH})")


⛔ ESearch DOI→PMID batch failed: 400 Bad Request
Raw response:
 {"error":"API key invalid","api-key":"YOUR_NCBI_KEY","type":"invalid",
"status":"unknown"}
⚠️ Batch #1 DOI→PMID error: 400 Client Error: Bad Request for url: https://eutils.ncbi.nlm.nih.gov/entrez/eutils/esearch.fcgi
   Mapped 0/1 DOIs → PMIDs
Seed PMIDs: []

=== DEPTH 1 (linear) ===


EFetch: 0it [00:00, ?it/s]
ELink: 0it [00:00, ?it/s]



=== DEPTH 2 (linear) ===


EFetch: 0it [00:00, ?it/s]
ELink: 0it [00:00, ?it/s]



=== DEPTH 3 (linear) ===


EFetch: 0it [00:00, ?it/s]
ELink: 0it [00:00, ?it/s]



=== DEPTH 4 (linear) ===


EFetch: 0it [00:00, ?it/s]
ELink: 0it [00:00, ?it/s]


=== Metrics ===
   depth  new_nodes  cum_nodes  edges
0      1          0          0      0
1      2          0          0      0
2      3          0          0      0
3      4          0          0      0

DAG check: ✅





ValueError: max() arg is an empty sequence