# In [None]:
# First, ensure dependencies are installed by running this in your terminal with conda env activated:
# conda install -c conda-forge requests pandas tqdm networkx matplotlib

import os
import requests
import pandas as pd
import networkx as nx
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
from collections import deque
from xml.etree import ElementTree as ET

# --- Configuration ---

# Base URLs of the two remote services queried by the helper functions.
ICITE_BASE = "https://icite.od.nih.gov/api/pubs"
EUTILS_BASE = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils"

# Maximum number of PMIDs sent in a single request to each service.
ICITE_BATCH = 150
EFETCH_BATCH = 200

# Publication-type labels that classify a record as a systematic review /
# meta-analysis (checked first by efetch_meta).
SRMA_TAGS = {
    "Meta-Analysis",
    "Systematic Review",
}

# Publication-type labels that classify a record as primary evidence, provided
# it carries none of the SRMA_TAGS labels.
PRIMARY_TAGS = {
    "Case-Control Studies",
    "Clinical Trial",
    "Clinical Trial, Phase I",
    "Clinical Trial, Phase II",
    "Clinical Trial, Phase III",
    "Clinical Trial, Phase IV",
    "Cohort Studies",
    "Comparative Study",
    "Cross-Sectional Studies",
    "Equivalence Trial",
    "Observational Study",
    "Pragmatic Clinical Trial",
    "Randomized Controlled Trial",
}

# In-memory caches keyed by PMID, filled by icite_fetch and efetch_meta so the
# same record is not requested twice within a session.
ICITE_CACHE = {}
META_CACHE = {}

# --- API Helper Functions ---

def icite_fetch(pmids: list[int]) -> None:
    """Populate ICITE_CACHE with citation links for any PMIDs not yet cached.

    PMIDs are requested from the iCite endpoint in chunks of ICITE_BATCH. For
    every record in the response's ``data`` list, the cache receives an entry
    ``{"refs": set[int], "citers": set[int]}`` built from the record's
    ``references`` and ``cited_by`` fields (a missing or null field becomes an
    empty set).

    A chunk whose request fails is reported with a printed warning and
    skipped; the remaining chunks are still attempted. Nothing is returned.
    """
    pending = [pmid for pmid in pmids if pmid not in ICITE_CACHE]
    if not pending:
        return

    for start in range(0, len(pending), ICITE_BATCH):
        batch = pending[start:start + ICITE_BATCH]
        try:
            query = {
                "pmids": ",".join(str(pmid) for pmid in batch),
                "format": "json",
            }
            response = requests.get(ICITE_BASE, params=query, timeout=60)
            response.raise_for_status()
            records = response.json().get("data", [])
            for record in records:
                pmid = int(record["pmid"])
                # `or []` covers fields that are present but null.
                outgoing = {int(ref) for ref in record.get("references", []) or []}
                incoming = {int(citer) for citer in record.get("cited_by", []) or []}
                ICITE_CACHE[pmid] = {"refs": outgoing, "citers": incoming}
        except requests.RequestException as e:
            print(f"Warning: iCite request failed for batch {batch[0]}: {e}")

def efetch_meta(pmids: list[int]) -> None:
    """Populate META_CACHE with title and study-type flags for uncached PMIDs.

    PMIDs are posted to the E-utilities ``efetch.fcgi`` endpoint in chunks of
    EFETCH_BATCH and the returned XML is parsed one ``PubmedArticle`` at a
    time. Each cached entry has the shape::

        {"title": str, "pub_types": list[str], "is_sr": bool, "is_primary": bool}

    ``is_sr`` is true when any publication type is in SRMA_TAGS; ``is_primary``
    is true only when the record is not an SR and has a type in PRIMARY_TAGS.

    A chunk whose request or XML parse fails is reported with a printed
    warning and skipped. Articles with a missing, empty, or non-numeric PMID
    are skipped. Nothing is returned.
    """
    missing_pmids = [p for p in pmids if p not in META_CACHE]
    if not missing_pmids:
        return

    for i in range(0, len(missing_pmids), EFETCH_BATCH):
        batch = missing_pmids[i:i + EFETCH_BATCH]
        params = {"db": "pubmed", "id": ",".join(map(str, batch)), "retmode": "xml"}
        try:
            r = requests.post(f"{EUTILS_BASE}/efetch.fcgi", data=params, timeout=60)
            r.raise_for_status()
            # Parse the raw bytes so the XML parser applies the encoding
            # declared by the document itself rather than the charset that
            # requests guessed from the HTTP headers.
            root = ET.fromstring(r.content)
        except (requests.RequestException, ET.ParseError) as e:
            print(f"Warning: EFetch request failed for batch {batch[0]}: {e}")
            continue

        for art in root.findall(".//PubmedArticle"):
            pmid_node = art.find(".//PMID")
            if pmid_node is None or not pmid_node.text:
                continue
            try:
                pmid = int(pmid_node.text)
            except ValueError:
                # A malformed identifier should not abort the whole crawl.
                continue

            title_node = art.find(".//ArticleTitle")
            title = "".join(title_node.itertext()).strip() if title_node is not None else ""
            if not title:
                # Covers both a missing element and one that is present but empty.
                title = "[No Title]"

            pub_types = {
                pt.text.strip()
                for pt in art.findall(".//PublicationTypeList/PublicationType")
                if pt.text
            }
            is_sr = bool(pub_types & SRMA_TAGS)
            is_primary = not is_sr and bool(pub_types & PRIMARY_TAGS)
            META_CACHE[pmid] = {
                "title": title,
                # Sorted so the cached list does not depend on set iteration order.
                "pub_types": sorted(pub_types),
                "is_sr": is_sr,
                "is_primary": is_primary,
            }

# --- Core Expansion Logic ---

def expand_bipartite_graph(
    seed_pmid: int, target_size: int = 1000
) -> tuple[set[tuple[int, int]], dict[int, dict]]:
    """Grow a bipartite SR / primary-evidence citation graph from a seed PMID.

    The graph is explored breadth-first. A systematic-review node is expanded
    through its references and a primary-evidence node through its citers;
    a neighbour is accepted only when it belongs to the opposite class. Nodes
    that are neither class, or for which metadata / citation data could not
    be fetched, are not expanded.

    Args:
        seed_pmid: PMID at which exploration starts.
        target_size: Upper bound on the number of discovered PMIDs (the seed
            counts as one). Exploration stops once it is reached.

    Returns:
        ``(edges, nodes)`` where ``edges`` is a set of
        ``(sr_pmid, primary_pmid)`` tuples and ``nodes`` maps each discovered
        PMID that has an entry in META_CACHE to that entry.
    """
    queue = deque([seed_pmid])
    visited = {seed_pmid}
    edges = set()
    pbar = tqdm(total=target_size, desc="Discovering Nodes")
    pbar.update(1)

    try:
        while queue and len(visited) < target_size:
            current_pmid = queue.popleft()
            efetch_meta([current_pmid])
            current_node_meta = META_CACHE.get(current_pmid)
            if not current_node_meta:
                continue
            current_is_sr = current_node_meta["is_sr"]
            current_is_primary = current_node_meta["is_primary"]
            if not current_is_sr and not current_is_primary:
                continue

            icite_fetch([current_pmid])
            current_node_cits = ICITE_CACHE.get(current_pmid)
            if not current_node_cits:
                continue

            # SRs look "down" at what they cite; primary studies look "up" at
            # what cites them.
            neighbor_candidates = set()
            if current_is_sr:
                neighbor_candidates.update(current_node_cits.get("refs", set()))
            elif current_is_primary:
                neighbor_candidates.update(current_node_cits.get("citers", set()))

            efetch_meta(list(neighbor_candidates))
            for neighbor_pmid in neighbor_candidates:
                neighbor_meta = META_CACHE.get(neighbor_pmid)
                if not neighbor_meta:
                    continue

                is_valid_link = (current_is_sr and neighbor_meta["is_primary"]) or \
                                (current_is_primary and neighbor_meta["is_sr"])
                if not is_valid_link:
                    continue

                if neighbor_pmid not in visited:
                    if len(visited) >= target_size:
                        # Graph is full: admit no new nodes, but keep scanning
                        # so links to already-known nodes are still recorded.
                        continue
                    visited.add(neighbor_pmid)
                    pbar.update(1)
                    queue.append(neighbor_pmid)

                # Record the link even when the neighbour was discovered
                # earlier through another path; otherwise the result would be
                # a spanning tree instead of the citation graph. Edges are
                # always stored as (sr, primary); the set absorbs duplicates.
                if current_is_sr:
                    edges.add((current_pmid, neighbor_pmid))
                else:
                    edges.add((neighbor_pmid, current_pmid))
    finally:
        # Release the progress bar even if a fetch raises unexpectedly.
        pbar.close()

    final_nodes = {pmid: META_CACHE[pmid] for pmid in visited if pmid in META_CACHE}
    return edges, final_nodes

# --- Visualization ---

def draw_bipartite_graph(edges: set, nodes: dict, seed_pmid: int):
    """Plot the SR / primary-evidence graph with matplotlib.

    Args:
        edges: Pairs of PMIDs to connect.
        nodes: Mapping of PMID to a metadata dict providing the ``is_sr`` and
            ``is_primary`` flags. Entries with neither flag set are not added
            as typed nodes.
        seed_pmid: PMID to highlight with a diamond marker when it is present
            in the graph.

    A two-column bipartite layout is used when both classes are non-empty;
    otherwise a warning is printed and a spring layout is used instead. The
    figure is displayed with ``plt.show()``; nothing is returned.
    """
    graph = nx.Graph()
    reviews = []
    primaries = []
    for pmid, meta in nodes.items():
        if meta["is_sr"]:
            kind, bucket = 'sr', reviews
        elif meta["is_primary"]:
            kind, bucket = 'pe', primaries
        else:
            continue
        graph.add_node(pmid, type=kind)
        bucket.append(pmid)

    graph.add_edges_from(edges)

    if reviews and primaries:
        pos = nx.bipartite_layout(graph, reviews)
    else:
        print("Warning: Graph is not bipartite (one set of nodes is empty). Using default layout.")
        pos = nx.spring_layout(graph)

    # Node layers in drawing order; the seed marker goes last so it sits on
    # top of the marker already drawn for its class.
    layers = [
        (reviews, dict(node_color="#f06595", node_shape="*", node_size=1500,
                       label="Systematic Reviews")),
        (primaries, dict(node_color="#63e6be", node_shape="o", node_size=500,
                         label="Primary Evidence")),
    ]
    if seed_pmid in graph:
        layers.append(
            ([seed_pmid], dict(node_color="#fcc419", node_shape="D", node_size=2000))
        )

    plt.figure(figsize=(20, 20))
    for members, style in layers:
        nx.draw_networkx_nodes(graph, pos, nodelist=members, **style)

    nx.draw_networkx_edges(graph, pos, alpha=0.5, edge_color="gray")
    nx.draw_networkx_labels(graph, pos, font_size=8)

    plt.title("Bipartite Graph of SRs and Primary Evidence", size=20)
    plt.legend(scatterpoints=1, fontsize=14)
    plt.box(False)
    plt.show()

# --- Main Execution ---

# PMID the expansion starts from, and the cap on discovered nodes.
SEED_PMID = 33591115
TARGET_NODE_COUNT = 1000

# Announce the run parameters before any network traffic starts.
print(
    f"Starting bipartite expansion from seed PMID: {SEED_PMID}",
    f"Targeting a graph size of ~{TARGET_NODE_COUNT} nodes.",
    sep="\n",
)

graph_edges, graph_nodes = expand_bipartite_graph(
    seed_pmid=SEED_PMID,
    target_size=TARGET_NODE_COUNT,
)

# Summarise what was collected, then render it.
print(
    f"\nExpansion complete.",
    f"Final graph size: {len(graph_nodes)} nodes and {len(graph_edges)} edges.",
    "Generating graph plot...",
    sep="\n",
)
draw_bipartite_graph(edges=graph_edges, nodes=graph_nodes, seed_pmid=SEED_PMID)