In [None]:
import math
import json
from collections import Counter
import numpy as np
import pandas as pd
import networkx as nx
from tqdm import tqdm

# ============= CONFIG =============
# Probability substituted for values missing from a distribution, so log() ratios stay finite.
_EPS = 1e-12
# True: candidate regions R are scored against every node of the graph;
# False: against the prototype's own distance ranking (all nodes within MAX_RADIUS).
COMPARE_R_TO_GLOBAL = True
# Max BFS radius, in hops (unweighted shortest-path cutoff), explored around each prototype.
MAX_RADIUS = 5
# Number of highest-degree nodes used as prototypes (capped at the number of nodes).
N_PROTOTYPES = 50

# ============= LOAD DATA ==========
edges = pd.read_csv("edges.csv")
# JSON is UTF-8 by specification; don't depend on the platform's default encoding.
with open("features.json", "r", encoding="utf-8") as f:
    node_features = json.load(f)
targets = pd.read_csv("target.csv")

# Node ids are compared as strings everywhere (graph nodes and target keys).
edges["from"] = edges["from"].astype(str)
edges["to"]   = edges["to"].astype(str)
if "id" in targets.columns:     targets["id"] = targets["id"].astype(str)
if "new_id" in targets.columns: targets["new_id"] = targets["new_id"].astype(str)

# Pick the ID column of target.csv that shares the most values with the edge list
# (ties go to "id", the first candidate).
edge_nodes = set(edges["from"]).union(set(edges["to"]))
id_candidates = [c for c in ["id", "new_id"] if c in targets.columns]
if not id_candidates:
    raise RuntimeError("target.csv needs an 'id' or 'new_id' column to join against edges.csv")
id_overlap = {c: len(edge_nodes & set(targets[c].astype(str))) for c in id_candidates}
best_col = max(id_candidates, key=id_overlap.get)
targets["_key"] = targets[best_col].astype(str)
print(f"[ID match] using targets.{best_col}")
if id_overlap[best_col] == 0:
    # Every node attribute lookup below would miss; make that visible instead of silent.
    print("[ID match] WARNING: no target id matches any edge endpoint; nodes will get no attributes")

# Required columns: stop at the first one that is absent.
for col in ["mature", "partner", "days", "views"]:
    if col in targets.columns:
        continue
    raise RuntimeError(f"Missing column {col}")

# Flag columns: missing entries count as 0, stored as int.
targets["mature"] = targets["mature"].fillna(value=0).astype(int)
targets["partner"] = targets["partner"].fillna(value=0).astype(int)

# Continuous columns: anything non-numeric becomes NaN.
targets["days"] = pd.to_numeric(targets["days"], errors="coerce")
targets["views"] = pd.to_numeric(targets["views"], errors="coerce")

# 33rd / 66th percentile cut points for the three-way bands, plus the mean used
# for the HighActivity flag.
(views_q1, views_q2) = targets["views"].quantile(q=[0.33, 0.66])
(days_q1, days_q2) = targets["days"].quantile(q=[0.33, 0.66])
mean_days = float(targets["days"].mean())

def _views_band(v):
    """Bucket a views value as 'low' / 'medium' / 'high' by the views_q1 / views_q2 cut points; NaN -> 'unknown'."""
    if pd.isna(v):
        return "unknown"
    if v < views_q1:
        return "low"
    if v < views_q2:
        return "medium"
    return "high"
def _age_band(d):
    """Bucket a days value as 'young' / 'mid' / 'old' by the days_q1 / days_q2 cut points; NaN -> 'unknown'."""
    if pd.isna(d):
        return "unknown"
    if d < days_q1:
        return "young"
    if d < days_q2:
        return "mid"
    return "old"

# ============= BUILD NODE GRAPH ==========
# Undirected graph over every endpoint in edges.csv (ids as str, matching targets["_key"]).
# Built from column pairs in one call instead of a per-row iterrows() loop.
G = nx.Graph()
G.add_edges_from(zip(edges["from"].astype(str), edges["to"].astype(str)))

# Attach node attributes. Target rows whose key is not an edge endpoint are skipped,
# and graph nodes without a target row get no attributes at all (downstream code
# filters on `attr in G.nodes[u]`). Duplicate keys: the last row wins.
for node_key, mature_flag, partner_flag, n_days, n_views in zip(
    targets["_key"], targets["mature"], targets["partner"], targets["days"], targets["views"]
):
    if node_key not in G:
        continue
    node_attrs = G.nodes[node_key]
    node_attrs["ExplicitLanguage"] = int(mature_flag)
    node_attrs["Partner"] = int(partner_flag)
    # Above-average days -> 1; unknown days -> 0.
    node_attrs["HighActivity"] = 0 if pd.isna(n_days) else int(n_days > mean_days)
    node_attrs["ViewsBand"] = _views_band(n_views)
    node_attrs["AgeBand"] = _age_band(n_days)

print(f"Graph: |V|={G.number_of_nodes()} |E|={G.number_of_edges()}")

# ============= KL UTILITIES =============
def _dist(vals):
    """Empirical distribution of `vals` as {value: relative frequency} (first-seen order); {} when empty."""
    total = len(vals)
    if not total:
        return {}
    return {value: count / total for value, count in Counter(vals).items()}

def wkl_quality_nodes(S_nodes, R_nodes, binary_attrs, nominal_attrs, graph=None):
    """Size-weighted KL divergence of node set S from reference node set R.

    For each attribute, P_S and P_R are the empirical value distributions over the
    nodes of S / R that carry the attribute. The per-attribute term is
    sum_y P_S(y) * log(P_S(y) / P_R(y)), with _EPS substituted for a probability
    that is missing (zero). The terms are summed and multiplied by |S| / |R|
    (full set sizes, including nodes without the attribute).

    Parameters
    ----------
    S_nodes, R_nodes : sequences of node ids.
    binary_attrs : attribute names scored on the fixed domain {0, 1}.
    nominal_attrs : attribute names scored on the union of values seen in S or R.
    graph : object whose ``graph.nodes[u]`` is the attribute dict of node u.
        Defaults to the module-level ``G`` (the original behaviour); pass it
        explicitly when scoring another graph.

    Returns
    -------
    float : the weighted score; 0.0 if either node set is empty.
    """
    if graph is None:
        graph = G
    nS, nR = len(S_nodes), len(R_nodes)
    if nS == 0 or nR == 0:
        return 0.0

    def attr_dist(nodes, attr):
        # Only nodes that carry `attr` contribute to its distribution.
        return _dist([graph.nodes[u][attr] for u in nodes if attr in graph.nodes[u]])

    # Binary attributes first (fixed domain), then nominal ones (domain None ->
    # observed values); same summation order as scoring them in two passes.
    specs = [(a, (0, 1)) for a in binary_attrs] + [(a, None) for a in nominal_attrs]
    qsum = 0.0
    for attr, domain in specs:
        PS = attr_dist(S_nodes, attr)
        PR = attr_dist(R_nodes, attr)
        for y in (domain if domain is not None else set(PS) | set(PR)):
            ps, pr = PS.get(y, _EPS), PR.get(y, _EPS)
            qsum += ps * math.log(ps / pr)
    return (nS / nR) * qsum

# ============= LSD ALGORITHM (node version) =============
def rank_nodes_by_distance(G, source):
    """List the nodes within MAX_RADIUS hops of `source`, nearest first.

    The source itself has hop distance 0 and so comes first. Nodes at equal distance
    keep the order networkx reports them in, because the sort is stable.
    """
    hops = nx.single_source_shortest_path_length(G, source, cutoff=MAX_RADIUS)
    return sorted(hops, key=hops.get)

def _best_prefix_by_wkl(G, ordered, reference, min_size, max_size, binary_attrs, nominal_attrs):
    """Return (k, score) for the prefix ordered[:k], min_size <= k <= max_size, with
    the highest size-weighted KL score against `reference`.

    The score matches wkl_quality_nodes(ordered[:k], reference, ...):
    (k / |reference|) * sum over attributes of sum_y P_S(y) * log(P_S(y) / P_R(y)),
    with _EPS substituted for missing probabilities. There are three differences.
    Attributes are read from the graph passed in. Reference distributions are built
    once. Prefix counts grow one node at a time instead of rescanning the prefix for
    every k. Ties keep the smallest k; an empty range yields (0, -inf).
    """
    # (attribute, fixed domain): binary attributes always use {0, 1}; nominal ones
    # (domain None) use the values seen in either distribution.
    specs = [(a, (0, 1)) for a in binary_attrs] + [(a, None) for a in nominal_attrs]
    n_ref = len(reference)
    # The reference side does not depend on k: compute it once, not per prefix.
    ref_dists = [
        _dist([G.nodes[u][a] for u in reference if a in G.nodes[u]]) for a, _ in specs
    ]
    # Running value counts over the current prefix. Insertion order is first
    # appearance, exactly as Counter() over the prefix's values would give, so the
    # per-value summation order below matches wkl_quality_nodes.
    prefix_counts = [Counter() for _ in specs]

    best_k, best_score = 0, -float("inf")
    for k, u in enumerate(ordered[:max_size], start=1):
        node_attrs = G.nodes[u]
        for (a, _), counts in zip(specs, prefix_counts):
            if a in node_attrs:
                counts[node_attrs[a]] += 1
        if k < min_size:
            continue
        if n_ref == 0:
            score = 0.0
        else:
            qsum = 0.0
            for (_, domain), counts, PR in zip(specs, prefix_counts, ref_dists):
                n_known = sum(counts.values())
                PS = {v: c / n_known for v, c in counts.items()} if n_known else {}
                for y in (domain if domain is not None else set(PS) | set(PR)):
                    ps, pr = PS.get(y, _EPS), PR.get(y, _EPS)
                    qsum += ps * math.log(ps / pr)
            score = (k / n_ref) * qsum
        if score > best_score:
            best_k, best_score = k, score
    return best_k, best_score


def find_best_q_node(G, proto, binary_attrs, nominal_attrs, all_nodes):
    """Two-stage local subgroup search around prototype `proto` on graph `G`.

    1. Rank nodes by hop distance from `proto` (rank_nodes_by_distance).
    2. Choose rho in [2, |ranking|] so that region R = ranking[:rho] maximises the
       weighted KL score against the baseline. The baseline is `all_nodes` when
       COMPARE_R_TO_GLOBAL is set, otherwise the full ranking.
    3. Choose sigma in [1, rho - 1] so that subgroup S = R_best[:sigma] maximises
       the weighted KL score against R_best.
    Both stages use the passed `G` (not the module-level graph). Ties keep the
    smaller size.

    Returns None when fewer than two nodes are reachable. Otherwise it returns a dict
    with keys prototype, q (best S-vs-R score), q_rg (best R-vs-baseline score),
    rho, sigma and n_nodes (length of the ranking).
    """
    ranking = rank_nodes_by_distance(G, proto)
    if len(ranking) < 2:
        return None

    baseline = all_nodes if COMPARE_R_TO_GLOBAL else ranking
    # Stage 1: best region prefix against the baseline.
    best_rho, best_q_rg = _best_prefix_by_wkl(
        G, ranking, baseline, 2, len(ranking), binary_attrs, nominal_attrs)

    # Stage 2: best strict sub-prefix of the chosen region, scored against that region.
    R_best = ranking[:best_rho]
    best_sigma, best_q_sr = _best_prefix_by_wkl(
        G, R_best, R_best, 1, best_rho - 1, binary_attrs, nominal_attrs)

    return {"prototype": proto, "q": best_q_sr, "q_rg": best_q_rg,
            "rho": best_rho, "sigma": best_sigma, "n_nodes": len(ranking)}

# ============= MODES =============
# Each mode names the node attributes (set while building the graph) that are scored:
# "binary" ones on {0, 1}, "nominal" ones on their observed values.
MODES = dict(
    A_binary=dict(binary=["ExplicitLanguage"], nominal=[]),
    B_multi_binary=dict(binary=["ExplicitLanguage", "Partner", "HighActivity"], nominal=[]),
    C_mixed=dict(binary=["ExplicitLanguage"], nominal=["ViewsBand", "AgeBand"]),
    D_nominal=dict(binary=[], nominal=["ViewsBand", "AgeBand"]),
)

# ============= PROTOTYPE SELECTION ============
# Highest-degree nodes first; equal degrees keep graph node order (stable sort).
deg = dict(G.degree())
top_nodes = sorted(deg, key=deg.get, reverse=True)
prototypes = top_nodes[:min(N_PROTOTYPES, len(top_nodes))]
print(f"Selected {len(prototypes)} prototypes (highest-degree nodes)")

# ============= MAIN LOOP =====================
# One record per (prototype, mode); find_best_q_node returns None for prototypes
# with fewer than two reachable nodes, and those pairs are skipped.
all_nodes = list(G.nodes())
records = []
for proto in tqdm(prototypes):
    for mode, cfg in MODES.items():
        res = find_best_q_node(G, proto, cfg["binary"], cfg["nominal"], all_nodes)
        if res is not None:
            records.append({
                "mode": mode,
                "prototype": proto,
                **{key: res[key] for key in ("q", "q_rg", "rho", "sigma", "n_nodes")},
                "n_binary": len(cfg["binary"]),
                "n_nominal": len(cfg["nominal"]),
            })

results = pd.DataFrame.from_records(records)
if results.empty:
    raise RuntimeError("No LSD results produced.")

# Per-mode statistics of the subgroup score q and the region score q_rg across
# prototypes (std_q is NaN when a mode has a single prototype: pandas uses ddof=1).
summary = (
    results
    .groupby("mode")
    .agg(
        prototypes=("prototype", "nunique"),
        mean_q=("q", "mean"),
        std_q=("q", "std"),
        median_q=("q", "median"),
        mean_q_rg=("q_rg", "mean"),
    )
    .reset_index()
    .sort_values("mean_q", ascending=False)
)

for title, table in (
    ("\n=== MODE SUMMARY ===", summary),
    ("\n=== BEST PROTOTYPES (top 20) ===", results.sort_values("q", ascending=False).head(20)),
):
    print(title)
    print(table.to_string(index=False))


[ID match] using targets.new_id
Graph: |V|=1912 |E|=31299
Selected 1 prototypes (highest-degree nodes)
Graph: |V|=1912 |E|=31299
Selected 1 prototypes (highest-degree nodes)


100%|██████████| 1/1 [00:12<00:00, 12.97s/it]


=== MODE SUMMARY ===
          mode  prototypes   mean_q  std_q  median_q  mean_q_rg
       C_mixed           1 0.007804    NaN  0.007804   0.015138
B_multi_binary           1 0.005154    NaN  0.005154   0.009807
     D_nominal           1 0.004810    NaN  0.004810   0.012874
      A_binary           1 0.003164    NaN  0.003164   0.002468

=== BEST PROTOTYPES (top 20) ===
          mode prototype        q     q_rg  rho  sigma  n_nodes  n_binary  n_nominal
       C_mixed       127 0.007804 0.015138 1019    769     1912         1          2
B_multi_binary       127 0.005154 0.009807 1083      2     1912         3          0
     D_nominal       127 0.004810 0.012874 1092    770     1912         0          2
      A_binary       127 0.003164 0.002468 1018    779     1912         1          0





In [None]:
# ================================================================
# Find the top-100 node neighbourhoods by KL divergence from the global attribute distribution
# ================================================================

import math
import json
from collections import Counter, defaultdict
import multiprocessing as mp
from functools import partial
import numpy as np
import pandas as pd
import networkx as nx
from tqdm import tqdm

# Parameters
# NOTE(review): this rebinds the module-level MAX_RADIUS (5 in the first cell) to 6.
# The first cell's rank_nodes_by_distance reads the global MAX_RADIUS at call time,
# so calling it after this cell has run uses radius 6 until the first cell is re-run.
MAX_RADIUS, TOP_K, _EPS = 6, 100, 1e-12

# Helper functions
def normalized_counts(values):
    """Share of each distinct value in `values` (first-seen order); {} when `values` is empty."""
    tally = Counter(values)
    n = sum(tally.values())
    return {key: cnt / n for key, cnt in tally.items()} if n else {}

def kl_divergence(p, q, eps=_EPS):
    """KL(p || q) summed over the keys of `p`: sum_k p_k * log(p_k / q_k).

    Keys missing from `q` get probability `eps`; values of `p` are floored at `eps`.
    Keys present only in `q` contribute nothing.
    """
    total = 0.0
    for key, p_val in p.items():
        p_val = max(p_val, eps)
        total += p_val * math.log(p_val / q.get(key, eps))
    return total

def make_global_distributions(G, attrs):
    """Graph-wide distribution of each attribute in `attrs`.

    Values are stringified; nodes lacking the attribute (or holding None) are ignored.
    Returns {attr: {str(value): share}}.
    """
    result = {}
    for attr in attrs:
        observed = []
        for n in G.nodes:
            val = G.nodes[n].get(attr, None)
            if val is not None:
                observed.append(str(val))
        result[attr] = normalized_counts(observed)
    return result

# Core function to evaluate node quality
def compute_quality_for_node(G, node, attrs, global_dist, max_radius=MAX_RADIUS):
    """Best-radius neighbourhood subgroup around `node`, scored by summed KL divergence.

    For r = 1..max_radius the candidate subgroup is every node at hop distance 1..r
    from `node`. The centre itself, at distance 0, is NOT a member. Its quality is
    the sum over `attrs` of KL(P_subgroup || global_dist[attr]) on stringified values.
    Attributes with no values in the subgroup are skipped. The smallest radius
    reaching the maximum quality wins. The score is not weighted by subgroup size;
    in the saved output every top subgroup has best_radius 1.

    Returns
    -------
    (node, best_radius, best_q, size, members)
        Members are listed by hop distance, nearest shell first, and are
        deterministic across runs; they no longer come from a hash-ordered set.
        If `node` has no neighbours within max_radius, the result is
        (node, 0, -inf, 1, [node]). If `node` is not in G, it is (node, 0, 0.0, 0, []).
    """
    try:
        dist = nx.single_source_shortest_path_length(G, node, cutoff=max_radius)
    except nx.NodeNotFound:
        # The only expected failure; anything else is a real bug and should surface.
        return node, 0, 0.0, 0, []

    # Group reachable nodes into shells by hop distance.
    shells = defaultdict(list)
    for n, d in dist.items():
        shells[d].append(n)

    best_q, best_r, best_nodes = -float('inf'), 0, [node]
    members = []  # nodes at distance 1..r, nearest shell first (centre excluded)
    for r in range(1, max_radius + 1):
        shell = shells.get(r, [])
        if not shell:
            # Hop-distance shells are contiguous: an empty shell means nothing lies
            # further out, so the subgroup and its score can no longer change.
            break
        members.extend(shell)
        q_sum = 0.0
        for a in attrs:
            values = [str(G.nodes[n][a]) for n in members if G.nodes[n].get(a) is not None]
            if not values:
                continue
            q_sum += kl_divergence(normalized_counts(values), global_dist.get(a, {}))
        if q_sum > best_q:
            best_q, best_r, best_nodes = q_sum, r, list(members)

    return node, best_r, best_q, len(best_nodes), best_nodes

# Run search over all nodes
def find_top_k_subgroups(G, attrs, max_radius=MAX_RADIUS, top_k=TOP_K, min_size=5, processes=4):
    """Score the best-radius neighbourhood of every node and return the top_k by quality.

    Parameters
    ----------
    G : graph to scan; every node is used as a centre.
    attrs : node attribute names compared against their graph-wide distribution.
    max_radius : largest hop radius tried around each centre.
    top_k : maximum number of rows returned (fewer if fewer subgroups pass the filter).
    min_size : subgroups with fewer members than this are discarded.
    processes : accepted for backward compatibility but currently ignored; the scan
        runs sequentially in the current process.

    Returns
    -------
    DataFrame with columns node, best_radius, best_q, size, members. It is sorted by
    best_q descending, with ties kept in graph node order, and has a fresh 0..n-1 index.
    """
    global_dist = make_global_distributions(G, attrs)
    worker = partial(compute_quality_for_node, G, attrs=attrs,
                     global_dist=global_dist, max_radius=max_radius)

    rows = [worker(n) for n in tqdm(list(G.nodes), desc='Scanning nodes')]

    df = pd.DataFrame(rows, columns=['node', 'best_radius', 'best_q', 'size', 'members'])
    df = df[df['size'] >= min_size]
    # mergesort is stable: equal scores keep node order, so the ranking is reproducible.
    df = df.sort_values(by='best_q', ascending=False, kind='mergesort').reset_index(drop=True)
    return df.head(top_k)

# Usage (run after your graph G is defined)
attrs = ['ExplicitLanguage', 'Partner', 'ViewsBand', 'AgeBand', 'HighActivity']
MIN_SIZE = 5
top_subgroups = find_top_k_subgroups(G, attrs, top_k=TOP_K, min_size=MIN_SIZE)

# Save results; the file name follows the actual TOP_K / MIN_SIZE settings.
top_subgroups.to_csv(f'top{TOP_K}_subgroups_min{MIN_SIZE}.csv', index=False)

# Last expression of the cell, so Jupyter renders it (a mid-cell `.head()` is
# evaluated and then discarded).
top_subgroups.head()

Scanning nodes: 100%|██████████| 1912/1912 [00:27<00:00, 69.54it/s]
Scanning nodes: 100%|██████████| 1912/1912 [00:27<00:00, 69.54it/s]


In [3]:
# Display the full top-subgroups table (pandas elides the middle rows when rendering).
top_subgroups

Unnamed: 0,node,best_radius,best_q,size,members
0,292,1,5.204105,5,"[1476, 1374, 682, 460, 1196]"
1,1123,1,4.935198,13,"[814, 306, 1259, 290, 467, 1311, 1297, 1376, 1..."
2,1417,1,4.885124,7,"[467, 1660, 1643, 1297, 1320, 103, 1068]"
3,1331,1,4.859868,6,"[290, 1721, 141, 1320, 1660, 305]"
4,401,1,4.758489,6,"[1476, 290, 1721, 103, 1660, 127]"
...,...,...,...,...,...
95,640,1,3.539765,28,"[682, 1376, 1721, 868, 1366, 471, 781, 305, 12..."
96,126,1,3.539110,39,"[1486, 1821, 880, 1763, 1311, 1433, 1825, 369,..."
97,1488,1,3.530827,7,"[1476, 1557, 1758, 1320, 103, 1660, 127]"
98,1757,1,3.530671,12,"[496, 290, 467, 1142, 235, 1311, 1297, 1433, 4..."
