# Imports and parameters

In [79]:
from __future__ import annotations
import time, os
from typing import Dict, List, Tuple, Optional
from urllib.parse import quote_plus

import requests
import pandas as pd
import networkx as nx
from tqdm.auto import tqdm

import math
import matplotlib.pyplot as plt
import networkx as nx

from pyvis.network import Network
import math

In [80]:
S2_BASE = "https://api.semanticscholar.org/graph/v1"
S2_SLEEP_S = 10  # adjust if you hit 429s; unauthenticated is ~100 req/5 min
S2_API_KEY = os.getenv("S2_API_KEY")  # optional; set for higher limits

# Main functions

In [81]:
def s2_request(path: str, params: Dict) -> Dict:
    url = f"{S2_BASE}/{path.lstrip('/')}"
    headers = {}
    if S2_API_KEY:
        headers["x-api-key"] = S2_API_KEY
    r = requests.get(url, params=params, headers=headers, timeout=60)
    if r.status_code == 429:
        # simple backoff
        time.sleep(max(S2_SLEEP_S * 4, 2.0))
        r = requests.get(url, params=params, headers=headers, timeout=60)
    r.raise_for_status()
    return r.json()

def resolve_seed(seed: str) -> Dict:
    """
    Return minimal paper dict {paperId,title,year,doi,venue} for:
      - DOI like '10.1038/nature12373'
      - Semantic Scholar paperId like 'arXiv:1706.03762' or a 40-char S2 ID
      - Title/keywords (falls back to search, picks top hit)
    """
    fields = "paperId,title,year,externalIds,venue"
    s = seed.strip()
    # DOI?
    if s.lower().startswith("10."):
        pid = f"DOI:{s}"
        return s2_request(f"paper/{quote_plus(pid)}", {"fields": fields})
    # Looks like S2 paper id? (heuristic)
    if len(s) >= 10 and ":" in s or len(s) == 40:
        return s2_request(f"paper/{quote_plus(s)}", {"fields": fields})
    # Otherwise search by title/keywords
    data = s2_request("paper/search", {"query": s, "limit": 1, "fields": fields})
    results = data.get("data", [])
    if not results:
        raise ValueError(f"No Semantic Scholar match for {seed!r}")
    return results[0]


In [82]:
def s2_list_citers(paper_id: str,
                   since: Optional[int]=None,
                   until: Optional[int]=None,
                   limit: int = 1000) -> List[Dict]:
    """
    Returns a list of citing paper dicts with fields:
      {paperId,title,year,externalIds,venue,citationCount}
    """
    fields = "citingPaper.paperId,citingPaper.title,citingPaper.year,citingPaper.externalIds,citingPaper.venue,citingPaper.citationCount"
    params = {"fields": fields, "limit": min(limit, 1000), "offset": 0}
    out: List[Dict] = []

    while True:
        data = s2_request(f"paper/{quote_plus(paper_id)}/citations", params)
        items = data.get("data", [])
        for it in items:
            cp = it.get("citingPaper") or {}
            y = cp.get("year")
            if since is not None and (y is None or y < since): 
                continue
            if until is not None and (y is None or y > until):
                continue
            out.append(cp)
        if len(items) < params["limit"]:
            break
        params["offset"] += params["limit"]
        time.sleep(S2_SLEEP_S)
    return out


# Helper functions

In [None]:
def s2_forward_graph(seed: str,
                     depth: int = 1,
                     since: Optional[int] = None,
                     until: Optional[int] = None,
                     max_nodes: int = 1000,
                     progress: bool = True) -> nx.DiGraph:
    """
    DiGraph with edge A -> B meaning "B cites A" (forward).
    """
    seed_meta = resolve_seed(seed)
    G = nx.DiGraph()
    seen: set[str] = set()
    queue: List[Tuple[Dict, int]] = [(seed_meta, depth)]

    pbar = tqdm(disable=not progress, total=None, desc="Traversing")

    while queue and G.number_of_nodes() < max_nodes:
        current, dleft = queue.pop(0)
        pid = current["paperId"]
        if pid in seen:
            continue
        seen.add(pid)

        # add current node
        G.add_node(
            pid,
            title=current.get("title"),
            year=current.get("year"),
            doi=(current.get("externalIds") or {}).get("DOI", ""),
            venue=current.get("venue") or "",
            cited_by_count=current.get("citationCount", None),
        )

        if dleft <= 0:
            pbar.update(1); continue

        # get forward citers
        try:
            citers = s2_list_citers(pid, since=since, until=until)
        except Exception as e:
            print(f"[WARN] citers failed for {pid}: {e}")
            pbar.update(1); continue

        for cp in citers:
            cpid = cp.get("paperId")
            if not cpid:
                continue
            if cpid not in G:
                G.add_node(
                    cpid,
                    title=cp.get("title"),
                    year=cp.get("year"),
                    doi=(cp.get("externalIds") or {}).get("DOI", ""),
                    venue=cp.get("venue") or "",
                    cited_by_count=cp.get("citationCount", None),
                )
            if not G.has_edge(pid, cpid):
                G.add_edge(pid, cpid)

        # enqueue next layer
        for cp in citers:
            cpid = cp.get("paperId")
            if cpid and cpid not in seen and G.number_of_nodes() < max_nodes:
                queue.append((cp, dleft - 1))

        pbar.update(1)
        time.sleep(S2_SLEEP_S)

    pbar.close()
    return G


# Convert and store function

In [84]:
def graph_to_dfs(G: nx.DiGraph):
    nodes = [{
        "id": nid,
        "title": d.get("title"),
        "year": d.get("year"),
        "doi": d.get("doi"),
        "venue": d.get("venue"),
        "cited_by_count": d.get("cited_by_count"),
    } for nid, d in G.nodes(data=True)]
    edges = [{"source": u, "target": v} for u, v in G.edges()]
    return pd.DataFrame(nodes), pd.DataFrame(edges)

def save_graph(G, prefix="s2_forward"):
    import networkx as nx
    import pandas as pd

    # to CSVs
    ndf = pd.DataFrame([
        {"id": nid,
         "title": d.get("title"),
         "year": d.get("year"),
         "doi": d.get("doi"),
         "venue": d.get("venue"),
         "cited_by_count": (0 if d.get("cited_by_count") is None else d.get("cited_by_count"))}
        for nid, d in G.nodes(data=True)
    ])
    edf = pd.DataFrame([{"source": u, "target": v} for u, v in G.edges()])
    ndf.to_csv(f"{prefix}_nodes.csv", index=False)
    edf.to_csv(f"{prefix}_edges.csv", index=False)

    # sanitize a copy for GEXF
    H = G.copy()
    for _, d in H.nodes(data=True):
        for k, v in list(d.items()):
            if v is None:
                d[k] = ""
            elif isinstance(v, (list, dict, set, tuple)):
                d[k] = str(v)
    for _, _, d in H.edges(data=True):
        for k, v in list(d.items()):
            if v is None:
                d[k] = ""
            elif isinstance(v, (list, dict, set, tuple)):
                d[k] = str(v)
    for k, v in list(H.graph.items()):
        if v is None:
            H.graph[k] = ""
        elif isinstance(v, (list, dict, set, tuple)):
            H.graph[k] = str(v)

    nx.write_gexf(H, f"{prefix}.gexf")
    print(f"Saved: {prefix}_nodes.csv, {prefix}_edges.csv, {prefix}.gexf")



# Run it

In [85]:
# Optionally set a key for higher limits:
# import os; os.environ["S2_API_KEY"] = "YOUR_KEY_HERE"

seed = "10.1109/ACCESS.2021.3078549"  # DOI, S2 paperId, or title string
depth = 1
since = None
until = None
max_nodes = 400
page_size = 1000  # max 1000

In [86]:
G = s2_forward_graph(seed, depth=depth, since=since, until=until, max_nodes=max_nodes, progress=True)

HTTPError: 429 Client Error:  for url: https://api.semanticscholar.org/graph/v1/paper/DOI%3A10.1109%2FACCESS.2021.3078549?fields=paperId%2Ctitle%2Cyear%2CexternalIds%2Cvenue

In [None]:
nodes_df, edges_df = graph_to_dfs(G)
display(nodes_df.head())
display(edges_df.head())
print("Nodes:", G.number_of_nodes(), "Edges:", G.number_of_edges())

Unnamed: 0,id,title,year,doi,venue,cited_by_count
0,dbbd634ab2ce72116a88123d1a6a6caa0ca34ab8,"Internet of Things 2.0: Concepts, Applications...",2021,10.1109/ACCESS.2021.3078549,IEEE Access,


Nodes: 1 Edges: 0


In [None]:
# save
save_graph(G, prefix="s2_forward")

Saved: s2_forward_nodes.csv, s2_forward_edges.csv, s2_forward.gexf


# Visualize


In [None]:
from pyvis.network import Network
import math
import html

net = Network(height="750px", width="100%", notebook=True, directed=True)
# different layouts
net.force_atlas_2based()
# Make node selection clearer, by changing border width and color
net.selection = {"nodes": {"borderWidth": 3, "borderWidthSelected": 5}}
# net.selection_menu = True
# net.toggle_physics(False)

def short(s, n=60):
    s = s or "(no title)"
    return s if len(s) <= n else s[:n-1] + "…"

for nid, data in G.nodes(data=True):
    title = data.get("title") or "(no title)"
    year  = data.get("year") or "?"
    doi   = data.get("doi") or ""
    size  = 5 + 2*math.sqrt((data.get("cited_by_count") or 0))

    # Show the (possibly truncated) title ON the node:
    label = short(title, n=42)  # adjust length to taste

    # Show full details ON HOVER:
    tooltip = (
        f"{html.escape(title)} ({year})"
        + (f"<br>DOI: {html.escape(doi)}" if doi else "")
    )

    net.add_node(
        nid,
        label=label,      # what you see on the node
        title=tooltip,    # hover text (HTML allowed)
        value=size
    )

for u, v in G.edges():
    net.add_edge(u, v, arrows="to")

net.show("s2_forward_title_labels.html", notebook=False)   # opens in notebook; also saves the file


s2_forward_title_labels.html
