In [1]:
import math
import json
from collections import Counter
import numpy as np
import pandas as pd
import networkx as nx
from tqdm import tqdm

# ============= CONFIG =============
_EPS = 1e-12
COMPARE_R_TO_GLOBAL = True
MAX_RADIUS = 5        # max Dijkstra radius
N_PROTOTYPES = 50     # limit prototypes if desired

# ============= LOAD DATA ==========
edges = pd.read_csv("edges.csv")
with open("features.json","r") as f: node_features = json.load(f)
targets = pd.read_csv("target.csv")

edges["from"] = edges["from"].astype(str)
edges["to"]   = edges["to"].astype(str)
if "id" in targets.columns:     targets["id"] = targets["id"].astype(str)
if "new_id" in targets.columns: targets["new_id"] = targets["new_id"].astype(str)

# pick correct ID column
edge_nodes = set(edges["from"]).union(set(edges["to"]))
best_col = max([c for c in ["id","new_id"] if c in targets.columns],
               key=lambda c: len(edge_nodes & set(targets[c].astype(str))))
targets["_key"] = targets[best_col].astype(str)
print(f"[ID match] using targets.{best_col}")

# keep necessary numeric columns
for col in ["mature","partner","days","views"]:
    if col not in targets.columns:
        raise RuntimeError(f"Missing column {col}")
targets["mature"]  = targets["mature"].fillna(0).astype(int)
targets["partner"] = targets["partner"].fillna(0).astype(int)
targets["days"]    = pd.to_numeric(targets["days"], errors="coerce")
targets["views"]   = pd.to_numeric(targets["views"], errors="coerce")

views_q1, views_q2 = targets["views"].quantile([0.33,0.66])
days_q1, days_q2   = targets["days"].quantile([0.33,0.66])
mean_days          = float(targets["days"].mean())

def _views_band(v):
    if pd.isna(v): return "unknown"
    return "low" if v<views_q1 else "medium" if v<views_q2 else "high"
def _age_band(d):
    if pd.isna(d): return "unknown"
    return "young" if d<days_q1 else "mid" if d<days_q2 else "old"

# ============= BUILD NODE GRAPH ==========
G = nx.Graph()
for _,r in edges.iterrows():
    G.add_edge(str(r["from"]), str(r["to"]))

# attach node attributes
for _,row in targets.iterrows():
    nid = row["_key"]
    if nid not in G: continue
    G.nodes[nid]["ExplicitLanguage"] = int(row["mature"])
    G.nodes[nid]["Partner"] = int(row["partner"])
    G.nodes[nid]["HighActivity"] = int(row["days"]>mean_days if not pd.isna(row["days"]) else 0)
    G.nodes[nid]["ViewsBand"] = _views_band(row["views"])
    G.nodes[nid]["AgeBand"]   = _age_band(row["days"])

print(f"Graph: |V|={G.number_of_nodes()} |E|={G.number_of_edges()}")

# ============= KL UTILITIES =============
def _dist(vals):
    n=len(vals)
    if n==0: return {}
    c=Counter(vals)
    return {k:v/n for k,v in c.items()}

def wkl_quality_nodes(S_nodes, R_nodes, binary_attrs, nominal_attrs):
    nS, nR = len(S_nodes), len(R_nodes)
    if nS==0 or nR==0: return 0.0
    qsum=0.0
    for attr in binary_attrs:
        PS=_dist([G.nodes[u][attr] for u in S_nodes if attr in G.nodes[u]])
        PR=_dist([G.nodes[u][attr] for u in R_nodes if attr in G.nodes[u]])
        for y in (0,1):
            ps,pr=PS.get(y,_EPS),PR.get(y,_EPS)
            qsum+=ps*math.log(ps/pr)
    for attr in nominal_attrs:
        PS=_dist([G.nodes[u][attr] for u in S_nodes if attr in G.nodes[u]])
        PR=_dist([G.nodes[u][attr] for u in R_nodes if attr in G.nodes[u]])
        dom=set(PS)|set(PR)
        for y in dom:
            ps,pr=PS.get(y,_EPS),PR.get(y,_EPS)
            qsum+=ps*math.log(ps/pr)
    return (nS/nR)*qsum

# ============= LSD ALGORITHM (node version) =============
def rank_nodes_by_distance(G, source):
    lengths = nx.single_source_shortest_path_length(G, source, cutoff=MAX_RADIUS)
    return [n for n,_ in sorted(lengths.items(), key=lambda kv: kv[1])]

def find_best_q_node(G, proto, binary_attrs, nominal_attrs, all_nodes):
    ranking = rank_nodes_by_distance(G, proto)
    if len(ranking)<2: return None

    baseline = all_nodes if COMPARE_R_TO_GLOBAL else ranking
    best_rho, best_q_rg = 0, -float("inf")
    for rho in range(2, len(ranking)+1):
        R = ranking[:rho]
        q_rg = wkl_quality_nodes(R, baseline, binary_attrs, nominal_attrs)
        if q_rg>best_q_rg:
            best_q_rg, best_rho = q_rg, rho

    R_best = ranking[:best_rho]
    best_sigma, best_q_sr = 0, -float("inf")
    for sigma in range(1, best_rho):
        S = R_best[:sigma]
        q_sr = wkl_quality_nodes(S, R_best, binary_attrs, nominal_attrs)
        if q_sr>best_q_sr:
            best_q_sr, best_sigma = q_sr, sigma

    return {"prototype":proto,"q":best_q_sr,"q_rg":best_q_rg,
            "rho":best_rho,"sigma":best_sigma,"n_nodes":len(ranking)}

# ============= MODES (same idea) =============
MODES = {
    "A_binary": {"binary":["ExplicitLanguage"],"nominal":[]},
    "B_multi_binary": {"binary":["ExplicitLanguage","Partner","HighActivity"],"nominal":[]},
    "C_mixed": {"binary":["ExplicitLanguage"],"nominal":["ViewsBand","AgeBand"]},
    "D_nominal": {"binary":[],"nominal":["ViewsBand","AgeBand"]}
}

# ============= PROTOTYPE SELECTION ============
deg = {u:G.degree(u) for u in G.nodes()}
top_nodes = [u for u,_ in sorted(deg.items(), key=lambda kv: kv[1], reverse=True)]
prototypes = top_nodes[:min(N_PROTOTYPES, len(top_nodes))]
print(f"Selected {len(prototypes)} prototypes (highest-degree nodes)")

# ============= MAIN LOOP =====================
all_nodes = list(G.nodes())
records=[]
for proto in tqdm(prototypes):
    for mode,cfg in MODES.items():
        res = find_best_q_node(G, proto, cfg["binary"], cfg["nominal"], all_nodes)
        if res is None: continue
        records.append({
            "mode":mode,"prototype":proto,"q":res["q"],"q_rg":res["q_rg"],
            "rho":res["rho"],"sigma":res["sigma"],"n_nodes":res["n_nodes"],
            "n_binary":len(cfg["binary"]),"n_nominal":len(cfg["nominal"])
        })

results = pd.DataFrame.from_records(records)
if results.empty:
    raise RuntimeError("No LSD results produced.")

summary = results.groupby("mode").agg(
    prototypes=("prototype","nunique"),
    mean_q=("q","mean"),std_q=("q","std"),median_q=("q","median"),
    mean_q_rg=("q_rg","mean")
).reset_index().sort_values("mean_q",ascending=False)

print("\n=== MODE SUMMARY ===")
print(summary.to_string(index=False))

print("\n=== BEST PROTOTYPES (top 20) ===")
print(results.sort_values("q",ascending=False).head(20).to_string(index=False))


[ID match] using targets.new_id
Graph: |V|=1912 |E|=31299
Selected 50 prototypes (highest-degree nodes)


100%|██████████| 50/50 [11:14<00:00, 13.49s/it]


=== MODE SUMMARY ===
          mode  prototypes   mean_q    std_q  median_q  mean_q_rg
       C_mixed          50 0.014845 0.009102  0.012544   0.038471
B_multi_binary          50 0.013500 0.007026  0.013240   0.029919
     D_nominal          50 0.011933 0.007432  0.010211   0.033056
      A_binary          50 0.004380 0.002501  0.003973   0.007251

=== BEST PROTOTYPES (top 20) ===
          mode prototype        q     q_rg  rho  sigma  n_nodes  n_binary  n_nominal
       C_mixed      1147 0.045297 0.044005  708    209     1912         1          2
       C_mixed      1036 0.039119 0.043429  384    170     1912         1          2
     D_nominal      1036 0.037775 0.037995  382    170     1912         0          2
       C_mixed       682 0.034948 0.025102  612    229     1912         1          2
       C_mixed      1179 0.034277 0.044647  193     13     1912         1          2
     D_nominal      1147 0.032746 0.039777  857    209     1912         0          2
B_multi_binary     




In [None]:
# ================================================================
# Find the top 100 subgroups based on KL-divergence quality measure
# ================================================================

import math
import json
from collections import Counter, defaultdict
import multiprocessing as mp
from functools import partial
import numpy as np
import pandas as pd
import networkx as nx
from tqdm import tqdm

# Parameters
MAX_RADIUS = 6
TOP_K = 100
_EPS = 1e-12

# Helper functions
def normalized_counts(values):
    c = Counter(values)
    total = sum(c.values())
    if total == 0:
        return {}
    return {k: v / total for k, v in c.items()}

def kl_divergence(p, q, eps=_EPS):
    s = 0.0
    for k, pv in p.items():
        qv = q.get(k, eps)
        pv = max(pv, eps)
        s += pv * math.log(pv / qv)
    return s

def make_global_distributions(G, attrs):
    global_dist = {}
    for a in attrs:
        vals = [str(G.nodes[n].get(a, None)) for n in G.nodes if G.nodes[n].get(a, None) is not None]
        global_dist[a] = normalized_counts(vals)
    return global_dist

# Core function to evaluate node quality
def compute_quality_for_node(G, node, attrs, global_dist, max_radius=MAX_RADIUS):
    best_q, best_r = -float('inf'), 0
    best_nodes = [node]

    try:
        dist = nx.single_source_shortest_path_length(G, node, cutoff=max_radius)
    except Exception:
        return node, 0, 0.0, 0, [node]

    nodes_by_d = defaultdict(list)
    for n, d in dist.items():
        nodes_by_d[d].append(n)

    accumulated = set()
    for r in range(1, max_radius + 1):
        accumulated.update(nodes_by_d.get(r, []))
        if not accumulated:
            continue
        q_sum = 0.0
        for a in attrs:
            values = [str(G.nodes[n].get(a, '')) for n in accumulated if G.nodes[n].get(a, None) is not None]
            if not values:
                continue
            p = normalized_counts(values)
            q = global_dist.get(a, {})
            q_sum += kl_divergence(p, q)
        if q_sum > best_q:
            best_q, best_r, best_nodes = q_sum, r, list(accumulated)

    return node, best_r, best_q, len(best_nodes), best_nodes

# Run search over all nodes (added min_size filter)
def find_top_k_subgroups(G, attrs, max_radius=MAX_RADIUS, top_k=TOP_K, min_size=5, processes=4):
    """Return top_k subgroups whose size >= min_size. """
    global_dist = make_global_distributions(G, attrs)
    nodes = list(G.nodes)

    worker = partial(compute_quality_for_node, G, attrs=attrs, global_dist=global_dist, max_radius=max_radius)

    results = []
    for n in tqdm(nodes, total=len(nodes), desc='Scanning nodes'):
        results.append(worker(n))

    df = pd.DataFrame(results, columns=['node', 'best_radius', 'best_q', 'size', 'members'])
    # filter by minimum subgroup size
    df = df[df['size'] >= min_size]
    df = df.sort_values(by='best_q', ascending=False).reset_index(drop=True)
    return df.head(top_k)

In [3]:
# === Run top-k subgroup search for a single feature ===
# Set FEATURE to any attribute name present in the graph nodes, e.g. 'ViewsBand'
attrs_list = ['ExplicitLanguage', 'Partner', 'ViewsBand', 'AgeBand', 'HighActivity']

for FEATURE in attrs_list:
    attrs = [FEATURE]
    print(f"Running top-k subgroup search using single feature: {FEATURE}")
    top_subgroups_single = find_top_k_subgroups(G, attrs, max_radius=MAX_RADIUS, top_k=TOP_K)
    display(top_subgroups_single.head())
    top_subgroups_single.to_csv(f'top100_subgroups_{FEATURE}.csv', index=False)

Running top-k subgroup search using single feature: ExplicitLanguage


Scanning nodes: 100%|██████████| 1912/1912 [00:09<00:00, 199.32it/s]


Unnamed: 0,node,best_radius,best_q,size,members
0,1544,1,1.062151,5,"[67, 868, 1494, 315, 11]"
1,1849,1,1.062151,6,"[227, 1837, 1137, 775, 1545, 996]"
2,108,1,0.673274,10,"[92, 809, 67, 1354, 1287, 212, 1311, 1337, 751..."
3,1642,1,0.595293,31,"[110, 67, 536, 1361, 1422, 1036, 1736, 6, 188,..."
4,1152,1,0.5609,14,"[1286, 407, 267, 1106, 1667, 747, 530, 1090, 7..."


Running top-k subgroup search using single feature: Partner


Scanning nodes: 100%|██████████| 1912/1912 [00:09<00:00, 200.80it/s]


Unnamed: 0,node,best_radius,best_q,size,members
0,1542,1,1.924693,5,"[127, 290, 1320, 1739, 45]"
1,1149,1,1.924693,11,"[656, 1366, 1631, 1476, 1280, 428, 1297, 467, ..."
2,126,1,1.924693,39,"[1476, 67, 36, 781, 455, 725, 1533, 122, 369, ..."
3,476,1,1.924693,6,"[488, 1533, 622, 1394, 496, 798]"
4,1309,1,1.924693,5,"[1821, 465, 1699, 471, 467]"


Running top-k subgroup search using single feature: ViewsBand


Scanning nodes: 100%|██████████| 1912/1912 [00:08<00:00, 222.62it/s]


Unnamed: 0,node,best_radius,best_q,size,members
0,1197,1,1.078933,29,"[1366, 1543, 1476, 455, 1137, 1739, 306, 1555,..."
1,149,1,1.078933,11,"[488, 25, 1669, 880, 1908, 928, 637, 429, 1552..."
2,401,1,1.078933,6,"[1476, 127, 290, 1660, 1721, 103]"
3,890,1,1.078933,5,"[127, 290, 1660, 1320, 103]"
4,1910,1,1.078933,15,"[1476, 1821, 127, 290, 905, 428, 471, 1758, 17..."


Running top-k subgroup search using single feature: AgeBand


Scanning nodes: 100%|██████████| 1912/1912 [00:09<00:00, 209.73it/s]


Unnamed: 0,node,best_radius,best_q,size,members
0,292,1,1.078933,5,"[1476, 1196, 460, 682, 1374]"
1,1305,1,1.078933,7,"[738, 287, 122, 88, 1035, 1479, 1421]"
2,100,1,1.078933,10,"[309, 208, 265, 1188, 414, 231, 939, 1421, 885..."
3,1346,1,1.078933,6,"[1466, 1731, 1287, 274, 1721, 305]"
4,1816,1,1.078933,7,"[1666, 1409, 1232, 1287, 1706, 195, 1755]"


Running top-k subgroup search using single feature: HighActivity


Scanning nodes: 100%|██████████| 1912/1912 [00:09<00:00, 197.69it/s]


Unnamed: 0,node,best_radius,best_q,size,members
0,144,1,0.69734,6,"[403, 689, 1297, 1762, 1643, 467]"
1,1905,1,0.69734,5,"[1476, 127, 1297, 851, 1259]"
2,1048,1,0.69734,10,"[1119, 1884, 1118, 394, 1232, 484, 779, 622, 1..."
3,1355,1,0.69734,7,"[471, 982, 1147, 176, 1196, 1311, 1374]"
4,1521,1,0.69734,10,"[1424, 1476, 127, 290, 857, 1660, 1739, 798, 1..."
