In [1]:
import math
import json
from collections import Counter, deque
import numpy as np
import pandas as pd
import networkx as nx

# =========================
# Config
# =========================
# Stand-in probability used when an outcome is absent from a distribution,
# so the log-ratio in the quality measure never sees a zero.
_EPS = 1e-12

# Settings read by select_prototypes() and by the main loop's early stop.
PROTOTYPE_SELECTION = dict(
    mode="improved",         # "original" (all nodes) or "improved" (degree + max-min distance)
    max_prototypes=50,       # upper bound on the number of prototypes in "improved" mode
    quality_threshold=-1.0,  # early stop is only active when this is > 0
    patience_k=8,            # length of the weak-prototype window checked for early stop
    distance_cutoff=4,       # hop cutoff passed to the shortest-path computation
)

# True: q(R, G) is measured against all graph edges; False: against the
# prototype's own ranked out-edges.
COMPARE_R_TO_GLOBAL = True
# Optional cap for rho per prototype (e.g., 300); None leaves it uncapped.
N_RHO_LIMIT = None

# =========================
# Load Twitch PT data
# =========================
# Edge list, per-node feature codes and per-node target attributes.
edges = pd.read_csv("edges.csv")
with open("features.json", "r") as f:
    node_features = json.load(f)
targets = pd.read_csv("target.csv")

# Cast every id column to str so edge endpoints, feature keys and target ids
# are compared as the same type.
for endpoint_col in ("from", "to"):
    edges[endpoint_col] = edges[endpoint_col].astype(str)
for id_col in ("id", "new_id"):
    if id_col in targets.columns:
        targets[id_col] = targets[id_col].astype(str)

print(f"Loaded: {len(edges)} edges, {len(node_features)} nodes with features, {len(targets)} target rows")

# =========================
# Detect correct target ID column
# =========================
# All node ids that occur as an endpoint of at least one edge.
edge_nodes = set(edges["from"]) | set(edges["to"])
# Id columns that are actually present in target.csv, in the listed order.
cand_cols = [name for name in ("id", "new_id") if name in targets.columns]

def _overlap(col):
    """Return how many edge endpoint ids also appear in ``targets[col]``."""
    target_ids = set(targets[col].astype(str))
    return len(edge_nodes & target_ids)

if len(cand_cols) == 0:
    raise RuntimeError("target.csv must contain 'id' or 'new_id' column.")

# Score each candidate id column once, then pick the one sharing the most ids
# with the edge list (the first candidate wins a tie).
overlap_by_col = {name: _overlap(name) for name in cand_cols}
best_col = max(cand_cols, key=overlap_by_col.get)
best_overlap = overlap_by_col[best_col]
if len(cand_cols) == 2:
    alt_col = next(name for name in cand_cols if name != best_col)
    alt_overlap = overlap_by_col[alt_col]
else:
    alt_col, alt_overlap = None, 0

print(f"[ID match] Using targets column '{best_col}' (overlap={best_overlap}); "
      f"{alt_col+'='+str(alt_overlap) if alt_col else ''}")

# Lookup key taken from the chosen id column.
targets["_key"] = targets[best_col].astype(str)

# All four attribute columns are mandatory; fail on the first one missing.
for col in ("mature", "partner", "days", "views"):
    if col in targets.columns:
        continue
    raise RuntimeError(f"target.csv missing required column '{col}'")

# Binary flags: missing -> 0, then integer. Numerics: unparsable -> NaN.
for flag_col in ("mature", "partner"):
    targets[flag_col] = targets[flag_col].fillna(0).astype(int)
for num_col in ("days", "views"):
    targets[num_col] = pd.to_numeric(targets[num_col], errors="coerce")

# node id -> its target row (a later duplicate key replaces an earlier one).
t_by_id = {rec["_key"]: rec for _, rec in targets.iterrows()}

# =========================
# Build edge feature matrix (union of endpoint feature sets)
# =========================
def edge_feature_union(i, j):
    """Return the distinct feature codes carried by node ``i`` or node ``j``.

    A node with no entry in ``node_features`` contributes nothing. The result
    is a list built from a set, so its order is not meaningful.
    """
    codes = set(node_features.get(i, []))
    codes = codes | set(node_features.get(j, []))
    return list(codes)

# (from, to) -> union of both endpoints' feature codes. The pair is the key,
# so a repeated edge row replaces the earlier entry.
edge_feature_dict = {
    (src, dst): edge_feature_union(src, dst)
    for src, dst in zip(edges["from"], edges["to"])
}

# Column order of the indicator matrix: every code seen on any edge, sorted.
all_codes = sorted({c for feats in edge_feature_dict.values() for c in feats})

# Fill a dense 0/1 array in one pass and wrap it in a DataFrame once.
# Assigning through `.loc[(i, j), feats]` for each edge gives the same frame
# but pays pandas' MultiIndex lookup and setitem overhead once per edge.
_code_pos = {code: pos for pos, code in enumerate(all_codes)}
_indicator = np.zeros((len(edge_feature_dict), len(all_codes)), dtype=np.int64)
for _row_no, feats in enumerate(edge_feature_dict.values()):
    if feats:  # an edge whose endpoints have no features stays all-zero
        _indicator[_row_no, [_code_pos[c] for c in feats]] = 1

edge_features = pd.DataFrame(
    _indicator,
    index=pd.MultiIndex.from_tuples(edge_feature_dict.keys(), names=["i", "j"]),
    columns=all_codes,
)

print(f"Edge feature matrix: {edge_features.shape[0]} edges × {edge_features.shape[1]} features")

# =========================
# Global stats for target binning
# =========================
# Only use rows with non-null numerics for quantiles/means
# Quantiles and the mean are computed over non-missing values only.
views_series = targets["views"].dropna()
if views_series.empty:
    # Nothing usable in the column (bad CSV): substitute a single 0.0 so the
    # quantile call below still has data to work with.
    views_series = pd.Series([0.0])

days_series = targets["days"].dropna()
if days_series.empty:
    days_series = pd.Series([0.0])

# Cut points at the 33% and 66% quantiles, used to bin views and account age.
views_q1, views_q2 = views_series.quantile([0.33, 0.66])
days_q1, days_q2 = days_series.quantile([0.33, 0.66])
# Threshold for the HighActivity flag.
mean_days = float(days_series.mean())

def _bands_from_val(val, q1, q2, low="low", mid="medium", hi="high"):
    if pd.isna(val):
        return "unknown"
    if val < q1:
        return low
    elif val < q2:
        return mid
    else:
        return hi

def _age_band(days_val):
    """Label an account age in days as "young", "mid" or "old".

    The cut points are the module-level ``days_q1`` and ``days_q2``; a missing
    value is labelled "unknown".
    """
    if pd.isna(days_val):
        return "unknown"
    if days_val < days_q1:
        return "young"
    if days_val < days_q2:
        return "mid"
    return "old"

def _get_row(node_id: str):
    """Return the target row stored for ``node_id``, or ``None`` if there is none."""
    return t_by_id.get(node_id)

# =========================
# Derive *edge* targets from both endpoints
# =========================
def derive_edge_targets(i: str, j: str, row_vec: pd.Series):
    """Build the target attributes of edge ``(i, j)`` from both endpoints.

    Args:
        i: source node id.
        j: destination node id.
        row_vec: the edge's feature-indicator row; its sum becomes ``TotalValue``.

    Returns:
        A dict with, for each endpoint (``_src`` / ``_dst``), the binary
        attributes ``ExplicitLanguage``, ``Partner`` and ``HighActivity`` and
        the nominal attributes ``ViewsBand`` and ``AgeBand``, plus the helper
        value ``TotalValue``. An endpoint without a target row gets 0 for the
        binaries and "unknown" for the bands.
    """
    total_value = float(row_vec.sum())

    def _endpoint_attrs(node_id):
        # Returns (mature, partner, high_activity, views_band, age_band).
        rec = _get_row(node_id)
        if rec is None:
            mature, partner = 0, 0
            days, views = np.nan, np.nan
        else:
            mature = int(rec["mature"])
            partner = int(rec["partner"])
            days = float(rec["days"])
            views = float(rec["views"])
        # High activity: days is known and above the global mean.
        high_activity = int((not pd.isna(days)) and (days > mean_days))
        views_band = _bands_from_val(views, views_q1, views_q2)
        return mature, partner, high_activity, views_band, _age_band(days)

    src = _endpoint_attrs(i)
    dst = _endpoint_attrs(j)

    # Key order is kept stable: binaries, then nominals, then the helper.
    out = {}
    for pos, key in enumerate(("ExplicitLanguage", "Partner", "HighActivity")):
        out[f"{key}_src"] = src[pos]
        out[f"{key}_dst"] = dst[pos]
    out["ViewsBand_src"] = src[3]
    out["ViewsBand_dst"] = dst[3]
    out["AgeBand_src"] = src[4]
    out["AgeBand_dst"] = dst[4]
    out["TotalValue"] = total_value
    return out

# =========================
# Build DiGraph with edge data
# =========================
# One directed edge per row of the feature matrix. Each edge stores its
# feature indicators under "features" plus the derived target attributes.
G = nx.DiGraph()
for (i, j), row in edge_features.iterrows():
    tdict = derive_edge_targets(i, j, row)
    G.add_edge(i, j, features=row.to_dict(), **tdict)

# Coverage: how many edge rows have their source / destination in the targets.
found_src = sum(1 for src, _ in edge_features.index if src in t_by_id)
found_dst = sum(1 for _, dst in edge_features.index if dst in t_by_id)

print(f"Graph: |V|={G.number_of_nodes()}, |E|={G.number_of_edges()}")
print(f"[coverage] edges with source in targets: ~{found_src/len(edge_features):.1%}, "
      f"dest in targets: ~{found_dst/len(edge_features):.1%}")

# =========================
# Quality (Weighted KL) on edges
# =========================
def _dist(vals):
    n = len(vals)
    if n == 0:
        return {}
    c = Counter(vals)
    return {k: v / n for k, v in c.items()}

def wkl_quality_edges(S_edges, R_edges, binary_attrs, nominal_attrs):
    """Weighted KL quality of edge subset ``S_edges`` against ``R_edges``.

    Both edge collections hold ``(u, v, data)`` triples. For every listed
    attribute, the terms ``P_S(y) * log(P_S(y) / P_R(y))`` are summed over the
    attribute's outcomes, with ``_EPS`` standing in for an outcome that is
    missing from either distribution. Binary attributes always use the
    outcomes 0 and 1; nominal attributes use the outcomes seen in either
    collection. The sum is scaled by ``len(S_edges) / len(R_edges)``.

    Returns 0.0 when either collection is empty.
    """
    size_s = len(S_edges)
    size_r = len(R_edges)
    if not (size_s and size_r):
        return 0.0

    def _both_dists(name):
        # Distributions of one attribute over the subset and the reference.
        sub = _dist([data[name] for _, _, data in S_edges])
        ref = _dist([data[name] for _, _, data in R_edges])
        return sub, ref

    total = 0.0

    # Binary attributes: both outcomes are always evaluated.
    for name in binary_attrs:
        sub, ref = _both_dists(name)
        for outcome in (0, 1):
            p_sub = sub.get(outcome, _EPS)
            p_ref = ref.get(outcome, _EPS)
            total += p_sub * math.log(p_sub / p_ref)

    # Nominal attributes: outcomes come from the union of both supports.
    for name in nominal_attrs:
        sub, ref = _both_dists(name)
        for outcome in set(sub) | set(ref):
            p_sub = sub.get(outcome, _EPS)
            p_ref = ref.get(outcome, _EPS)
            total += p_sub * math.log(p_sub / p_ref)

    return (size_s / size_r) * total

def _rank_out_edges_for_proto(G, proto):
    """Return ``proto``'s out-edges as ``(u, v, data)`` triples, highest TotalValue first.

    Edges with equal TotalValue keep the order in which the graph reports them.
    """
    return sorted(
        ((u, v, data) for u, v, data in G.out_edges(proto, data=True)),
        key=lambda edge: edge[2]["TotalValue"],
        reverse=True,
    )

# =========================
# Modes (≥2 nominal vars included)
# =========================
# Target-attribute configurations evaluated for every prototype. Each entry
# names the binary and the nominal edge attributes passed to the quality
# measure.
MODES = dict(
    A_binary=dict(
        binary=["ExplicitLanguage_src", "ExplicitLanguage_dst"],
        nominal=[],
    ),
    B_multi_binary=dict(
        binary=[
            "ExplicitLanguage_src", "ExplicitLanguage_dst",
            "Partner_src", "Partner_dst",
            "HighActivity_src", "HighActivity_dst",
        ],
        nominal=[],
    ),
    C_mixed=dict(
        binary=["ExplicitLanguage_src", "ExplicitLanguage_dst"],
        nominal=["ViewsBand_src", "ViewsBand_dst", "AgeBand_src", "AgeBand_dst"],
    ),
    D_nominal=dict(
        binary=[],
        nominal=["ViewsBand_src", "ViewsBand_dst", "AgeBand_src", "AgeBand_dst"],
    ),
)

# =========================
# Find best q for a prototype
#   - choose ρ by maximizing q(R, GLOBAL or out-edges)
#   - choose σ by maximizing q(S, R)
# =========================
def find_best_q_for_prototype(G, proto, binary_attrs, nominal_attrs, global_edges_cache):
    """Search the best reference size rho and subgroup size sigma for one prototype.

    The prototype's out-edges are ranked by TotalValue (descending). rho is the
    prefix length R = ranked[:rho] that maximizes q(R, baseline), where the
    baseline is ``global_edges_cache`` when ``COMPARE_R_TO_GLOBAL`` is true and
    the ranked out-edges otherwise. sigma is then the prefix length
    S = R[:sigma], with 1 <= sigma < rho, that maximizes q(S, R).

    Args:
        G: graph whose edges carry the target attributes and ``TotalValue``.
        proto: prototype node id.
        binary_attrs: names of binary edge attributes to score.
        nominal_attrs: names of nominal edge attributes to score.
        global_edges_cache: list of all ``(u, v, data)`` edges of ``G``.

    Returns:
        A dict with keys ``prototype``, ``rho``, ``sigma``, ``q`` (best
        q(S, R)), ``n_out`` (number of out-edges) and ``q_rg`` (best
        q(R, baseline)); or ``None`` when no rho >= 2 can be tried, i.e. the
        prototype has fewer than 2 out-edges or ``N_RHO_LIMIT`` is below 2.
    """
    ranked = _rank_out_edges_for_proto(G, proto)

    # Largest rho that may be tried. With fewer than two candidate edges the
    # search loops below would never run and the result would carry rho=0 and
    # q=-inf, so bail out the same way as for a prototype with < 2 out-edges.
    max_rho = len(ranked) if N_RHO_LIMIT is None else min(N_RHO_LIMIT, len(ranked))
    if max_rho < 2:
        return None

    baseline = global_edges_cache if COMPARE_R_TO_GLOBAL else ranked

    # Stage 1: choose rho by maximizing q(R, baseline); first maximum wins.
    best_rho, best_q_rg = 0, -float("inf")
    for rho in range(2, max_rho + 1):
        R = ranked[:rho]
        q_rg = wkl_quality_edges(R, baseline, binary_attrs, nominal_attrs)
        if q_rg > best_q_rg:
            best_q_rg = q_rg
            best_rho = rho

    # Stage 2: inside the chosen R, choose sigma by maximizing q(S, R).
    R_best = ranked[:best_rho]
    best_sigma, best_q_sr = 0, -float("inf")
    for sigma in range(1, best_rho):
        S = R_best[:sigma]
        q_sr = wkl_quality_edges(S, R_best, binary_attrs, nominal_attrs)
        if q_sr > best_q_sr:
            best_q_sr = q_sr
            best_sigma = sigma

    return {
        "prototype": proto,
        "rho": best_rho,
        "sigma": best_sigma,
        "q": best_q_sr,
        "n_out": len(ranked),
        "q_rg": best_q_rg
    }

# =========================
# Prototype selection
# =========================
def select_prototypes(G):
    """Choose the prototype nodes according to ``PROTOTYPE_SELECTION``.

    In "original" mode every node of ``G`` is returned. Otherwise selection is
    greedy: start from the highest-degree node, then repeatedly add the node
    whose hop distance to its nearest already-selected prototype is largest
    (distances are measured on the undirected view and treated as infinite
    beyond ``distance_cutoff``). Ties go to the higher-degree node and then to
    the node that comes first in ``G.nodes()``, so the result is reproducible.

    Args:
        G: the directed graph to pick prototypes from.

    Returns:
        A list of node ids; empty when ``G`` has no nodes. In "improved" mode
        it holds at most ``max_prototypes`` entries (but always the first one).
    """
    nodes = list(G.nodes())
    if PROTOTYPE_SELECTION["mode"] == "original":
        return nodes
    if not nodes:
        # max() below would raise on an empty graph.
        return []

    limit = PROTOTYPE_SELECTION["max_prototypes"]
    cutoff = PROTOTYPE_SELECTION["distance_cutoff"]

    # Start with the highest-degree node (first one in node order on a tie).
    degs = {u: G.degree(u) for u in nodes}
    first = max(degs, key=degs.get)
    selected = [first]

    UG = G.to_undirected(as_view=True)

    # nearest[u]: hops from candidate u to its closest selected prototype,
    # inf while no prototype lies within the cutoff. The dict keeps node
    # order, which is what makes the tie-breaking deterministic.
    nearest = {u: np.inf for u in nodes if u != first}

    def _absorb(source):
        # One bounded BFS from a newly selected prototype updates the running
        # minimum, so earlier prototypes never need to be revisited.
        reach = nx.single_source_shortest_path_length(UG, source, cutoff=cutoff)
        for node, hops in reach.items():
            if node in nearest and hops < nearest[node]:
                nearest[node] = hops

    _absorb(first)
    while len(selected) < limit and nearest:
        # Farthest candidate first, then higher degree; max() keeps the first
        # maximal candidate it meets.
        best_node = max(nearest, key=lambda u: (nearest[u], degs[u]))
        selected.append(best_node)
        del nearest[best_node]
        _absorb(best_node)

    return selected

# =========================
# Main loop
# =========================
# All edges as (u, v, data) triples, built once and shared by every call.
global_edges = list(G.edges(data=True))
prototypes = select_prototypes(G)
print(f"Prototypes selected ({PROTOTYPE_SELECTION['mode']}): {len(prototypes)}")

records = []
# Sliding window over the most recent prototypes: True means "weak".
weak_streak = deque(maxlen=PROTOTYPE_SELECTION["patience_k"])

for proto in prototypes:
    any_strong = False
    for mode_name, cfg in MODES.items():
        res = find_best_q_for_prototype(G, proto, cfg["binary"], cfg["nominal"], global_edges)
        if res is not None:
            qv = res["q"]
            records.append(dict(
                mode=mode_name,
                prototype=res["prototype"],
                q=qv,
                q_rg=res["q_rg"],
                rho=res["rho"],
                sigma=res["sigma"],
                n_out=res["n_out"],
                n_binary=len(cfg["binary"]),
                n_nominal=len(cfg["nominal"]),
            ))
            if qv >= PROTOTYPE_SELECTION["quality_threshold"]:
                any_strong = True
    # A prototype is weak when no mode reached the threshold; that includes a
    # prototype for which no mode produced a result.
    weak_streak.append(not any_strong)

    # Early stop needs a positive threshold and a full window of weak prototypes.
    stop_enabled = PROTOTYPE_SELECTION["quality_threshold"] > 0
    window_full = len(weak_streak) == PROTOTYPE_SELECTION["patience_k"]
    if stop_enabled and window_full and all(weak_streak):
        print(f"Early stop: last {PROTOTYPE_SELECTION['patience_k']} prototypes yielded weak q "
              f"(< {PROTOTYPE_SELECTION['quality_threshold']}).")
        break

# =========================
# Reporting
# =========================
results = pd.DataFrame.from_records(records)
if results.empty:
    raise RuntimeError("No results produced. Most common cause: ID mismatch — check '[ID match]' and '[coverage]' lines above.")

# Per-mode aggregates, best mean q first.
summary = (
    results.groupby("mode")
    .agg(
        prototypes=pd.NamedAgg(column="prototype", aggfunc="nunique"),
        mean_q=pd.NamedAgg(column="q", aggfunc="mean"),
        std_q=pd.NamedAgg(column="q", aggfunc="std"),
        median_q=pd.NamedAgg(column="q", aggfunc="median"),
        mean_q_rg=pd.NamedAgg(column="q_rg", aggfunc="mean"),
    )
    .reset_index()
    .sort_values("mean_q", ascending=False)
)

print("\n=== MODE SUMMARY ===")
print(summary.to_string(index=False))

# The twenty individual (mode, prototype) results with the highest q.
print("\n=== BEST PROTOTYPES (top 20 by q) ===")
top_by_q = results.sort_values("q", ascending=False).head(20)
print(top_by_q.to_string(index=False))


Loaded: 31299 edges, 1912 nodes with features, 1912 target rows
[ID match] Using targets column 'new_id' (overlap=1912); id=0
Edge feature matrix: 31299 edges × 1449 features
Graph: |V|=1912, |E|=31299
[coverage] edges with source in targets: ~100.0%, dest in targets: ~100.0%
Prototypes selected (improved): 50

=== MODE SUMMARY ===
          mode  prototypes   mean_q    std_q  median_q  mean_q_rg
       C_mixed          10 0.548340 0.284942  0.618491   0.005572
B_multi_binary          10 0.387158 0.200158  0.346574   0.004862
     D_nominal          10 0.355563 0.168741  0.346574   0.004302
      A_binary          10 0.195904 0.166322  0.259930   0.001273

=== BEST PROTOTYPES (top 20 by q) ===
          mode prototype        q     q_rg  rho  sigma  n_out  n_binary  n_nominal
       C_mixed       314 1.039721 0.000291    2      1      2         2          4
       C_mixed       364 0.693147 0.000300    2      1      2         2          4
B_multi_binary       943 0.693147 0.000148    2 