In [3]:
import math
import json
from collections import Counter
import numpy as np
import pandas as pd
import networkx as nx

# ============= CONFIG =============
# Stand-in probability for a value that is absent from a distribution, so the
# log-ratio in the KL terms never sees a zero numerator or denominator.
_EPS = 1e-12
# True: each candidate region R is scored against all graph nodes;
# False: it is scored against the prototype's own distance ranking.
COMPARE_R_TO_GLOBAL = True
MAX_RADIUS = 5        # hop cutoff for the unweighted shortest-path search around a prototype
N_PROTOTYPES = 30     # number of highest-degree nodes used as prototypes

# ============= LOAD DATA ==========
edges = pd.read_csv("edges.csv")
# Explicit encoding: JSON text is UTF-8, so do not depend on the platform's
# default locale encoding when reading it.
with open("features.json", "r", encoding="utf-8") as f:
    node_features = json.load(f)
targets = pd.read_csv("target.csv")

# Node identifiers are compared as strings everywhere below.
edges["from"] = edges["from"].astype(str)
edges["to"] = edges["to"].astype(str)

id_candidates = [c for c in ("id", "new_id") if c in targets.columns]
if not id_candidates:
    # Without this guard, max() below fails with an opaque
    # "empty sequence" ValueError.
    raise RuntimeError("Missing column id or new_id")
for col in id_candidates:
    targets[col] = targets[col].astype(str)

# pick the ID column that overlaps most with the node ids used by the edges
edge_nodes = set(edges["from"]).union(edges["to"])
best_col = max(id_candidates,
               key=lambda c: len(edge_nodes & set(targets[c])))
targets["_key"] = targets[best_col].astype(str)
print(f"[ID match] using targets.{best_col}")

# keep necessary numeric columns
for col in ["mature", "partner", "days", "views"]:
    if col not in targets.columns:
        raise RuntimeError(f"Missing column {col}")
# Missing flags are treated as 0; unparseable days/views become NaN.
targets["mature"] = targets["mature"].fillna(0).astype(int)
targets["partner"] = targets["partner"].fillna(0).astype(int)
targets["days"] = pd.to_numeric(targets["days"], errors="coerce")
targets["views"] = pd.to_numeric(targets["views"], errors="coerce")

# Tercile-style cut points (33% / 66%) used by the banding helpers, and the
# mean of "days" used as the HighActivity threshold.
views_q1, views_q2 = targets["views"].quantile([0.33, 0.66])
days_q1, days_q2 = targets["days"].quantile([0.33, 0.66])
mean_days = float(targets["days"].mean())

def _views_band(v):
    """Map a view count to "low"/"medium"/"high", or "unknown" if missing.

    The cut points are the module-level ``views_q1`` and ``views_q2``.
    """
    if pd.isna(v):
        return "unknown"
    if v < views_q1:
        return "low"
    if v < views_q2:
        return "medium"
    return "high"
def _age_band(d):
    """Map a days value to "young"/"mid"/"old", or "unknown" if missing.

    The cut points are the module-level ``days_q1`` and ``days_q2``.
    """
    if pd.isna(d):
        return "unknown"
    if d < days_q1:
        return "young"
    if d < days_q2:
        return "mid"
    return "old"

# ============= BUILD NODE GRAPH ==========
# Undirected graph with one edge per row of the edge list, endpoints as strings.
G = nx.Graph()
G.add_edges_from(zip(edges["from"].astype(str), edges["to"].astype(str)))

# attach node attributes (target rows whose key is not a graph node are skipped)
target_rows = zip(targets["_key"], targets["mature"], targets["partner"],
                  targets["days"], targets["views"])
for nid, mature, partner, days, views in target_rows:
    if nid in G:
        node_attrs = G.nodes[nid]
        node_attrs["ExplicitLanguage"] = int(mature)
        node_attrs["Partner"] = int(partner)
        # A missing "days" value counts as not-high rather than propagating NaN.
        node_attrs["HighActivity"] = 0 if pd.isna(days) else int(days > mean_days)
        node_attrs["ViewsBand"] = _views_band(views)
        node_attrs["AgeBand"] = _age_band(days)

print(f"Graph: |V|={G.number_of_nodes()} |E|={G.number_of_edges()}")

# ============= KL UTILITIES =============
def _dist(vals):
    n=len(vals)
    if n==0: return {}
    c=Counter(vals)
    return {k:v/n for k,v in c.items()}

def wkl_quality_nodes(S_nodes, R_nodes, binary_attrs, nominal_attrs):
    """Size-weighted sum of KL terms between node set S and node set R.

    For every listed attribute, the value distribution over ``S_nodes`` is
    compared with the one over ``R_nodes`` (nodes lacking the attribute are
    left out of that attribute's distribution). The summed p_S * log(p_S/p_R)
    terms are scaled by len(S_nodes) / len(R_nodes).

    Binary attributes are evaluated on the values 0 and 1; nominal attributes
    on every value seen in either set. A value missing from a distribution is
    given probability ``_EPS``. Returns 0.0 if either node set is empty.
    """
    size_s, size_r = len(S_nodes), len(R_nodes)
    if not size_s or not size_r:
        return 0.0

    def distributions(attr):
        # Value distributions of `attr` over S and R, in that order.
        in_s = _dist([G.nodes[u][attr] for u in S_nodes if attr in G.nodes[u]])
        in_r = _dist([G.nodes[u][attr] for u in R_nodes if attr in G.nodes[u]])
        return in_s, in_r

    def probability_pairs():
        # Yield (p_S, p_R) for each attribute value: binary attributes first,
        # then nominal ones.
        for attr in binary_attrs:
            in_s, in_r = distributions(attr)
            for value in (0, 1):
                yield in_s.get(value, _EPS), in_r.get(value, _EPS)
        for attr in nominal_attrs:
            in_s, in_r = distributions(attr)
            for value in set(in_s) | set(in_r):
                yield in_s.get(value, _EPS), in_r.get(value, _EPS)

    divergence = 0.0
    for p_s, p_r in probability_pairs():
        divergence += p_s * math.log(p_s / p_r)
    return (size_s / size_r) * divergence

# ============= LSD ALGORITHM (node version) =============
def rank_nodes_by_distance(G, source):
    """List the nodes within ``MAX_RADIUS`` hops of ``source``, nearest first.

    The source itself (distance 0) comes first; nodes at equal distance keep
    the order in which the shortest-path search reported them.
    """
    hops = nx.single_source_shortest_path_length(G, source, cutoff=MAX_RADIUS)
    # sorted() is stable, so equal-distance nodes stay in dict order.
    return sorted(hops, key=hops.get)

def find_best_q_node(G, proto, binary_attrs, nominal_attrs, all_nodes):
    """Search prefix sizes of the distance ranking around ``proto``.

    Step 1 picks ``rho``: the prefix R = ranking[:rho] (rho >= 2) with the
    highest quality against the baseline, which is ``all_nodes`` when
    ``COMPARE_R_TO_GLOBAL`` is set and the ranking itself otherwise.
    Step 2 picks ``sigma``: the prefix S = R[:sigma] (1 <= sigma < rho) with
    the highest quality against R.

    Returns a dict with keys prototype, q, q_rg, rho, sigma and n_nodes, or
    None when fewer than two nodes are reachable from ``proto``.
    """
    def first_argmax(sizes, score):
        # Strict ">" keeps the earliest size on ties; (0, -inf) if `sizes`
        # is empty.
        top_size, top_score = 0, -float("inf")
        for size in sizes:
            current = score(size)
            if current > top_score:
                top_size, top_score = size, current
        return top_size, top_score

    ranking = rank_nodes_by_distance(G, proto)
    if len(ranking) < 2:
        return None

    if COMPARE_R_TO_GLOBAL:
        reference = all_nodes
    else:
        reference = ranking

    best_rho, best_q_rg = first_argmax(
        range(2, len(ranking) + 1),
        lambda rho: wkl_quality_nodes(
            ranking[:rho], reference, binary_attrs, nominal_attrs),
    )

    R_best = ranking[:best_rho]
    best_sigma, best_q_sr = first_argmax(
        range(1, best_rho),
        lambda sigma: wkl_quality_nodes(
            R_best[:sigma], R_best, binary_attrs, nominal_attrs),
    )

    return {
        "prototype": proto,
        "q": best_q_sr,
        "q_rg": best_q_rg,
        "rho": best_rho,
        "sigma": best_sigma,
        "n_nodes": len(ranking),
    }

# ============= MODES =============
# Attribute sets to evaluate, keyed by mode name. "binary" attributes are
# compared over the fixed values (0, 1); "nominal" attributes are compared
# over whatever values occur in either node set.
MODES = {
    "A_binary": {"binary":["ExplicitLanguage"],"nominal":[]},
    "B_multi_binary": {"binary":["ExplicitLanguage","Partner","HighActivity"],"nominal":[]},
    "C_mixed": {"binary":["ExplicitLanguage"],"nominal":["ViewsBand","AgeBand"]},
    "D_nominal": {"binary":[],"nominal":["ViewsBand","AgeBand"]}
}

# ============= PROTOTYPE SELECTION ============
# Prototypes are the highest-degree nodes; equal-degree nodes keep graph order.
deg = {}
for u in G.nodes():
    deg[u] = G.degree(u)
top_nodes = sorted(deg, key=deg.get, reverse=True)
n_selected = min(N_PROTOTYPES, len(top_nodes))
prototypes = top_nodes[:n_selected]
print(f"Selected {len(prototypes)} prototypes (highest-degree nodes)")

# ============= MAIN LOOP =====================
# Evaluate every (prototype, mode) pair; pairs with no result are skipped.
all_nodes = list(G.nodes())
records = []
for proto in prototypes:
    for mode, cfg in MODES.items():
        res = find_best_q_node(G, proto, cfg["binary"], cfg["nominal"], all_nodes)
        if res is not None:
            row = {"mode": mode, "prototype": proto}
            for field in ("q", "q_rg", "rho", "sigma", "n_nodes"):
                row[field] = res[field]
            row["n_binary"] = len(cfg["binary"])
            row["n_nominal"] = len(cfg["nominal"])
            records.append(row)

results = pd.DataFrame.from_records(records)
if results.empty:
    raise RuntimeError("No LSD results produced.")

# Per-mode statistics, best mean q first.
per_mode = results.groupby("mode").agg(
    prototypes=("prototype", "nunique"),
    mean_q=("q", "mean"),
    std_q=("q", "std"),
    median_q=("q", "median"),
    mean_q_rg=("q_rg", "mean"),
)
summary = per_mode.reset_index().sort_values("mean_q", ascending=False)

print("\n=== MODE SUMMARY ===")
print(summary.to_string(index=False))

print("\n=== BEST PROTOTYPES (top 20) ===")
top_rows = results.sort_values("q", ascending=False).head(20)
print(top_rows.to_string(index=False))


[ID match] using targets.new_id
Graph: |V|=1912 |E|=31299
Selected 30 prototypes (highest-degree nodes)

=== MODE SUMMARY ===
          mode  prototypes   mean_q    std_q  median_q  mean_q_rg
       C_mixed          30 0.012090 0.006576  0.010821   0.034380
B_multi_binary          30 0.011956 0.006328  0.009575   0.027171
     D_nominal          30 0.009646 0.005207  0.008428   0.029798
      A_binary          30 0.004074 0.001825  0.004067   0.006144

=== BEST PROTOTYPES (top 20) ===
          mode prototype        q     q_rg  rho  sigma  n_nodes  n_binary  n_nominal
       C_mixed       682 0.034948 0.025102  612    229     1912         1          2
B_multi_binary      1311 0.031860 0.018163  488    293     1912         3          0
       C_mixed       777 0.025115 0.055881  261     12     1912         1          2
     D_nominal       777 0.024104 0.046904  227     11     1912         0          2
     D_nominal       682 0.023090 0.024399  612    227     1912         0          2
