# Pairwise PID Synergy Graph Construction Across 12 Benchmark Datasets

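This notebook demonstrates **Partial Information Decomposition (PID) synergy scores** (Williams-Beer I_min) for all feature pairs across tabular classification datasets, followed by **synergy graph construction** and analysis. The PID scores are precomputed by the full pipeline and loaded from JSON; the notebook rebuilds the synergy graphs and the downstream analyses from them.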

**What this artifact does:**
- Computes pairwise PID synergy scores for all feature pairs using a manual MI-based PID implementation
- Constructs synergy graphs at 3 threshold levels (top-10%, top-25%, above-median) with NetworkX
- Analyzes graph structure: cliques, connected components, clustering coefficient
- Compares PID synergy against a baseline interaction information measure
- Evaluates discretization stability (5-bin vs 10-bin) using Spearman/Jaccard metrics
- Checks domain-meaningful interactions for diabetes, breast cancer, and heart datasets
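The PID scores used below are precomputed, so the decomposition itself never runs in this notebook. A minimal sketch of a two-source PID with the Williams-Beer I_min redundancy may help readers interpret the `synergy`, `redundancy` and `unique_*` fields. It assumes plug-in (empirical) probabilities, base-2 logarithms (bits) and equal-width binning. The pipeline's actual estimator and binning scheme are not shown here, and every function name is hypothetical.

```python
from collections import Counter

import numpy as np


def discretize(x, n_bins):
    """Equal-width binning into integer codes 0..n_bins-1 (assumed scheme)."""
    edges = np.linspace(np.min(x), np.max(x), n_bins + 1)
    return np.digitize(x, edges[1:-1])  # interior edges -> codes 0..n_bins-1


def _probs(*cols):
    """Empirical (plug-in) joint distribution of one or more discrete columns."""
    counts = Counter(zip(*cols))
    n = sum(counts.values())
    return {k: v / n for k, v in counts.items()}


def mutual_info(x, y):
    """I(X;Y) in bits."""
    pxy, px, py = _probs(x, y), _probs(x), _probs(y)
    return sum(p * np.log2(p / (px[(a,)] * py[(b,)])) for (a, b), p in pxy.items())


def specific_info(x, y, y_val):
    """Specific information I(Y=y_val; X) = sum_x p(x|y) log2(p(y|x) / p(y))."""
    pxy, px, py = _probs(x, y), _probs(x), _probs(y)
    p_y = py[(y_val,)]
    total = 0.0
    for (a, b), p_ab in pxy.items():
        if b == y_val:
            total += (p_ab / p_y) * np.log2((p_ab / px[(a,)]) / p_y)
    return total


def pid_williams_beer(x1, x2, y):
    """Two-source PID with Williams-Beer I_min redundancy (all values in bits)."""
    # I_min: expected (over y) minimum specific information across the two sources
    redundancy = sum(
        p * min(specific_info(x1, y, b), specific_info(x2, y, b))
        for (b,), p in _probs(y).items()
    )
    mi_1, mi_2 = mutual_info(x1, y), mutual_info(x2, y)
    joint = mutual_info(list(zip(x1, x2)), y)  # treat (X1, X2) as one variable
    return {
        "redundancy": redundancy,
        "unique_1": mi_1 - redundancy,
        "unique_2": mi_2 - redundancy,
        "synergy": joint - mi_1 - mi_2 + redundancy,
        "joint_mi": joint,
        # Classical interaction information; always equals synergy - redundancy
        "interaction_info": joint - mi_1 - mi_2,
    }


# XOR: neither input alone says anything about y, but together they determine it
x1 = np.array([0, 0, 1, 1])
x2 = np.array([0, 1, 0, 1])
res = pid_williams_beer(x1, x2, x1 ^ x2)
print(res["synergy"], res["redundancy"])  # 1.0 0.0
```

Because I_min redundancy is non-negative, the classical interaction information is never larger than the synergy. If `baseline_interaction` in the JSON is that classical quantity (an assumption; the notebook does not define it), the points in the "PID Synergy vs Baseline Interaction" scatter should sit on or below the dashed diagonal.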

---

**Part 1** — Quick Demo using a curated subset (5 datasets, runs in seconds)

**Part 2** — Full Run using all 12 datasets with original parameters

In [None]:
import json
import os
import time
import warnings
from collections import Counter
from itertools import combinations

import numpy as np
from scipy.stats import pearsonr, spearmanr
import networkx as nx
import matplotlib.pyplot as plt

warnings.filterwarnings("ignore", category=FutureWarning)

## Data Loading

Load precomputed PID synergy results from JSON. The data was produced by the full pipeline, which computed the pairwise PID decomposition (synergy, redundancy, and unique information) for all feature pairs across 12 benchmark datasets. The loader tries the GitHub copy first and falls back to a local file of the same name.

In [None]:
GITHUB_FULL_DATA_URL = "https://raw.githubusercontent.com/AMGrobelnik/ai-invention-ac2586-synergy-guided-oblique-splits-using-part/main/pid_synergy_gra/demo/full_demo_data.json"
GITHUB_MINI_DATA_URL = "https://raw.githubusercontent.com/AMGrobelnik/ai-invention-ac2586-synergy-guided-oblique-splits-using-part/main/pid_synergy_gra/demo/mini_demo_data.json"

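def _load_json(url, local_path):
    """Fetch JSON from `url`, falling back to `local_path` if the download fails."""
    import urllib.request

    try:
        with urllib.request.urlopen(url, timeout=30) as response:
            return json.loads(response.read().decode())
    except Exception:
        pass  # offline or URL unreachable: try the local copy instead
    if os.path.exists(local_path):
        with open(local_path) as f:
            return json.load(f)
    raise FileNotFoundError(f"Could not load {url} or {local_path}")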

def load_mini():
    return _load_json(GITHUB_MINI_DATA_URL, "mini_demo_data.json")

def load_full():
    return _load_json(GITHUB_FULL_DATA_URL, "full_demo_data.json")

---
# Part 1 — Quick Demo (Mini Data)

Uses a curated subset of 5 datasets (banknote-authentication, diabetes, glass, wine, heart-statlog) for fast execution.

In [None]:
data = load_mini()
print(f"Loaded {len(data['datasets'])} datasets")
for ds in data["datasets"]:
    print(f"  {ds['dataset']}: {len(ds['examples'])} feature pairs")

## Parse PID Results

Extract PID decomposition results from the loaded JSON into per-dataset lists of dictionaries, matching the format used by the original pipeline's graph construction and analysis functions.
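One detail of the file layout is easy to miss: each example's `input` field is itself a JSON-encoded string, so it needs a second `json.loads` after the file is decoded. Dataset-level fields such as `metadata_n_features` are repeated on every example, and the parser reads them from the first one. The toy block below is a hypothetical illustration trimmed to a few of the keys the parser reads. All values are invented, and the real file carries more fields.

```python
import json

# Hypothetical dataset block: keys mirror some of those parse_pid_results reads,
# but every value below is invented for illustration.
ds_block = {
    "dataset": "toy",
    "examples": [
        {
            # "input" is a JSON *string*, not a nested object
            "input": json.dumps({"feature_i": "a", "feature_j": "b"}),
            "metadata_synergy": 0.12,
            "metadata_redundancy": 0.03,
            "metadata_bin_level": 5,
            # Dataset-level fields, repeated on every example
            "metadata_n_features": 2,
            "metadata_n_samples": 100,
            "metadata_n_classes": 2,
        }
    ],
}

raw = json.dumps({"datasets": [ds_block]})  # stands in for the downloaded file
data_toy = json.loads(raw)                  # first decode: the whole file
ex = data_toy["datasets"][0]["examples"][0]
pair = json.loads(ex["input"])              # second decode: the feature pair
print(pair["feature_i"], pair["feature_j"], ex["metadata_synergy"])  # a b 0.12
```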

In [None]:
def parse_pid_results(data):
    """Parse loaded JSON data into per-dataset PID result dicts.
    
    Returns dict mapping dataset_name -> {
        'pid_results': list of PID result dicts,
        'feature_names': list of unique feature names,
        'n_features': int, 'n_samples': int, 'n_classes': int,
        'stability_spearman': float, 'stability_pearson': float,
    }
    """
    all_datasets = {}
    for ds_block in data["datasets"]:
        ds_name = ds_block["dataset"]
        examples = ds_block["examples"]
        if not examples:
            continue

        pid_results = []
        feature_set = set()
        for ex in examples:
            inp = json.loads(ex["input"])
            feature_set.add(inp["feature_i"])
            feature_set.add(inp["feature_j"])
            pid_results.append({
                "feature_i": inp["feature_i"],
                "feature_j": inp["feature_j"],
                "synergy": ex["metadata_synergy"],
                "unique_i": ex["metadata_unique_i"],
                "unique_j": ex["metadata_unique_j"],
                "redundancy": ex["metadata_redundancy"],
                "joint_mi": ex["metadata_joint_mi"],
                "mi_i": ex["metadata_mi_i"],
                "mi_j": ex["metadata_mi_j"],
                "baseline_interaction": ex["metadata_baseline_interaction"],
                "n_bins": ex["metadata_bin_level"],
            })

        first = examples[0]
        all_datasets[ds_name] = {
            "pid_results": pid_results,
            "feature_names": sorted(feature_set),
            "n_features": first["metadata_n_features"],
            "n_samples": first["metadata_n_samples"],
            "n_classes": first["metadata_n_classes"],
            "stability_spearman": first.get("metadata_stability_spearman", None),
            "stability_pearson": first.get("metadata_stability_pearson", None),
        }

    return all_datasets

all_datasets = parse_pid_results(data)
print(f"Parsed {len(all_datasets)} datasets:")
for name, info in all_datasets.items():
    print(f"  {name}: {info['n_features']} features, "
          f"{len(info['pid_results'])} pairs, "
          f"{info['n_samples']} samples, "
          f"{info['n_classes']} classes")

## Synergy Graph Construction & Analysis

Build synergy graphs by thresholding pairwise synergy scores at 3 levels (top-10%, top-25%, above-median). Analyze graph structure including edge count, density, connected components, largest clique, and clustering coefficient.
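To make the three threshold levels concrete, the sketch below applies the same rule as `build_synergy_graph` (keep a pair if its synergy is `>=` the quantile, using NumPy's default linear interpolation) to ten hypothetical pairs with scores 0 to 9.

```python
from itertools import combinations

import networkx as nx
import numpy as np

# Hypothetical toy input: 5 features give 10 pairs with synergy scores 0.0 .. 9.0
features = list("abcde")
pairs = list(combinations(features, 2))
scores = np.arange(len(pairs), dtype=float)

edge_counts = {}
for name, q in {"top_10pct": 0.90, "top_25pct": 0.75, "above_median": 0.50}.items():
    cut = float(np.quantile(scores, q))  # NumPy's default linear interpolation
    G = nx.Graph()
    G.add_nodes_from(features)           # features without an edge stay as isolated nodes
    G.add_edges_from(p for p, s in zip(pairs, scores) if s >= cut)
    edge_counts[name] = G.number_of_edges()
    print(f"{name:<13} cut={cut:.2f} edges={G.number_of_edges()}")
```

This prints cuts of 8.10, 6.75 and 4.50, keeping 1, 3 and 5 edges respectively. With distinct scores, roughly a (1 - q) fraction of the pairs survive. Because the comparison is `>=`, every pair tied at the cut is kept, so datasets with many identical synergy values (for example, many exact zeros) can produce noticeably denser graphs than the threshold names suggest.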

In [None]:
# ── Constants (from original) ────────────────────────────────────────────────
THRESHOLDS = {"top_10pct": 0.90, "top_25pct": 0.75, "above_median": 0.50}

# ── Domain knowledge pairs for checking (from original) ─────────────────────
DOMAIN_PAIRS = {
    "diabetes": [("plas", "insu"), ("plas", "mass"), ("mass", "age")],
    "breast_cancer_wisconsin": [
        ("mean radius", "mean perimeter"),
        ("mean area", "mean concavity"),
    ],
    "heart-statlog": [
        ("chest", "maximum_heart_rate_achieved"),
        ("age", "oldpeak"),
    ],
}


def build_synergy_graph(
    pid_results: list,
    feature_names: list,
    threshold_quantile: float,
) -> tuple:
    """Build a synergy graph with edges above the given quantile threshold."""
    synergy_values = [r["synergy"] for r in pid_results]
    if not synergy_values:
        return nx.Graph(), 0.0

    threshold_value = float(np.quantile(synergy_values, threshold_quantile))

    G = nx.Graph()
    G.add_nodes_from(feature_names)
    for r in pid_results:
        if r["synergy"] >= threshold_value:
            G.add_edge(r["feature_i"], r["feature_j"], weight=r["synergy"])

    return G, threshold_value


def analyze_graph(G: nx.Graph) -> dict:
    """Compute structural statistics for a synergy graph.

    Every feature is a node, so isolated (degree-0) features are counted
    as their own connected components.
    """
    n_nodes = G.number_of_nodes()
    n_edges = G.number_of_edges()
    nodes_with_edges = sum(1 for n in G.nodes() if G.degree(n) > 0)

    if n_nodes < 2:
        # Degenerate graph (0 or 1 node): each node is its own component
        return {
            "n_edges": n_edges,
            "n_nodes_with_edges": nodes_with_edges,
            "density": 0.0,
            "n_connected_components": n_nodes,
            "largest_component_size": n_nodes,
            "largest_clique_size": 0,
            "mean_degree": 0.0,
            "max_degree": 0,
            "clustering_coefficient": 0.0,
        }

    max_edges = n_nodes * (n_nodes - 1) / 2
    density = n_edges / max_edges if max_edges > 0 else 0.0

    components = list(nx.connected_components(G))
    component_sizes = [len(c) for c in components]

    cliques = list(nx.find_cliques(G)) if n_edges > 0 else []
    clique_sizes = [len(c) for c in cliques] if cliques else [0]

    degrees = [d for _, d in G.degree()]
    mean_degree = float(np.mean(degrees)) if degrees else 0.0
    max_degree = max(degrees) if degrees else 0

    cc = nx.average_clustering(G) if n_edges > 0 else 0.0

    return {
        "n_edges": n_edges,
        "n_nodes_with_edges": nodes_with_edges,
        "density": round(density, 4),
        "n_connected_components": len(components),
        "largest_component_size": max(component_sizes) if component_sizes else 0,
        "largest_clique_size": max(clique_sizes),
        "mean_degree": round(mean_degree, 2),
        "max_degree": max_degree,
        "clustering_coefficient": round(cc, 4),
    }

print("Graph construction functions defined.")

### Build Graphs for All Datasets

For each dataset, construct synergy graphs at all 3 threshold levels and compute graph statistics.
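As a reading aid for the summary table, here is a worked example of the same NetworkX statistics on a hypothetical five-feature graph: a triangle a-b-c, a pendant node d attached to c, and an isolated feature e.

```python
import networkx as nx

# Hypothetical toy synergy graph: triangle a-b-c, pendant d on c, isolated e
G = nx.Graph()
G.add_nodes_from("abcde")
G.add_edges_from([("a", "b"), ("b", "c"), ("a", "c"), ("c", "d")])

n = G.number_of_nodes()
density = G.number_of_edges() / (n * (n - 1) / 2)          # 4 / 10
components = list(nx.connected_components(G))             # {a, b, c, d} and {e}
largest_clique = max(len(c) for c in nx.find_cliques(G))  # the triangle
clustering = nx.average_clustering(G)                      # (1 + 1 + 1/3 + 0 + 0) / 5
print(density, len(components), largest_clique, round(clustering, 4))  # 0.4 2 3 0.4667
```

Two conventions matter when reading the table. The isolated feature e counts as a separate connected component, and `nx.average_clustering` averages over all nodes, including zeros for nodes of degree below 2. Sparse, high-threshold graphs therefore tend to show many components and low clustering even when their few edges form tight groups.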

In [None]:
t_start = time.time()
graph_results = {}

for ds_name, ds_info in all_datasets.items():
    pid_results = ds_info["pid_results"]
    feature_names = ds_info["feature_names"]

    graph_stats = {}
    for thr_name, thr_q in THRESHOLDS.items():
        G, thr_val = build_synergy_graph(
            pid_results=pid_results,
            feature_names=feature_names,
            threshold_quantile=thr_q,
        )
        stats = analyze_graph(G)
        stats["threshold_value"] = round(thr_val, 6)
        graph_stats[thr_name] = stats

    graph_results[ds_name] = {
        "graph_stats": graph_stats,
        "pid_results": pid_results,
        "ds_info": ds_info,
    }

elapsed = time.time() - t_start
print(f"Built synergy graphs for {len(graph_results)} datasets in {elapsed:.2f}s\n")

# Print summary table
print(f"{'Dataset':<28} {'Thr':<12} {'Edges':>5} {'Density':>7} {'MaxClq':>6} {'ClustC':>6}")
print("-" * 70)
for ds_name, gr in graph_results.items():
    for thr_name, stats in gr["graph_stats"].items():
        print(f"{ds_name:<28} {thr_name:<12} {stats['n_edges']:>5} "
              f"{stats['density']:>7.4f} {stats['largest_clique_size']:>6} "
              f"{stats['clustering_coefficient']:>6.4f}")

## Domain-Meaningful Interaction Checks

For datasets with known meaningful feature interactions (diabetes, breast cancer, heart), check where these domain-known pairs rank in the synergy ordering. This is a sanity check on whether PID synergy ranks scientifically relevant interactions highly; it is suggestive rather than a formal validation.
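The "top-10%" flag here is computed from ranks (`rank <= max(1, int(n * 0.1))`), while the graph edges come from a quantile cut (`synergy >= np.quantile(scores, 0.9)`). The two rules need not admit the same number of pairs. The sketch below compares them for a hypothetical dataset with 8 features (28 pairs) and distinct synergy scores.

```python
import numpy as np

n_pairs = 28                               # 8 features -> 8 * 7 / 2 = 28 pairs
scores = np.arange(n_pairs, dtype=float)   # hypothetical distinct synergy scores
fracs = (0.10, 0.25)

# Rank-based rule, as in check_domain_pairs
rank_cut = {f: max(1, int(n_pairs * f)) for f in fracs}

# Quantile-based rule, as in build_synergy_graph
quant_cut = {f: int(np.sum(scores >= np.quantile(scores, 1 - f))) for f in fracs}

print(rank_cut, quant_cut)
```

With 28 distinct scores, the rank rule admits 2 pairs to the top 10% while the quantile rule admits 3; at 25% both admit 7. A domain pair reported as outside the top 10% can therefore still appear as an edge in the `top_10pct` graph.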

In [None]:
def check_domain_pairs(
    pid_results: list, dataset_name: str
) -> list:
    """Check how known-meaningful pairs rank in synergy ordering."""
    if dataset_name not in DOMAIN_PAIRS:
        return []

    # Sort by synergy descending
    sorted_results = sorted(pid_results, key=lambda r: r["synergy"], reverse=True)
    pair_to_rank = {}
    for rank, r in enumerate(sorted_results, 1):
        key = (r["feature_i"], r["feature_j"])
        pair_to_rank[key] = rank
        # Also store reverse
        pair_to_rank[(r["feature_j"], r["feature_i"])] = rank

    checks = []
    total_pairs = len(sorted_results)
    for fi, fj in DOMAIN_PAIRS[dataset_name]:
        rank = pair_to_rank.get((fi, fj), None)
        if rank is None:
            rank = pair_to_rank.get((fj, fi), None)
        checks.append({
            "feature_i": fi,
            "feature_j": fj,
            "synergy_rank": rank,
            "total_pairs": total_pairs,
            "in_top_10pct": rank is not None and rank <= max(1, int(total_pairs * 0.1)),
            "in_top_25pct": rank is not None and rank <= max(1, int(total_pairs * 0.25)),
        })

    return checks


# Run domain checks
for ds_name, gr in graph_results.items():
    domain_checks = check_domain_pairs(
        pid_results=gr["pid_results"],
        dataset_name=ds_name,
    )
    if domain_checks:
        print(f"\n{ds_name} — Domain-Meaningful Pair Checks:")
        for dc in domain_checks:
            rank_str = f"{dc['synergy_rank']}/{dc['total_pairs']}" if dc['synergy_rank'] else "not found"
            print(f"  {dc['feature_i']} <-> {dc['feature_j']}: "
                  f"rank {rank_str} "
                  f"(top-10%: {dc['in_top_10pct']}, top-25%: {dc['in_top_25pct']})")

## Visualization

Plot key results: (1) synergy distribution per dataset, (2) graph density across thresholds, (3) synergy vs baseline interaction correlation, and (4) stability metrics summary.
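The stability panel only displays the precomputed `stability_spearman` and `stability_pearson` values. A minimal sketch of how such numbers could be obtained follows. It assumes stability means correlating the per-pair synergy vectors from the 5-bin and 10-bin runs, plus the Jaccard overlap of the top-k pair sets mentioned in the introduction. The scores and k = 3 are hypothetical, and the pipeline's exact procedure is not shown in this notebook.

```python
import numpy as np
from scipy.stats import pearsonr, spearmanr

# Hypothetical synergy scores for the same 6 feature pairs at two bin levels
syn_5bin = np.array([0.10, 0.40, 0.05, 0.30, 0.20, 0.02])
syn_10bin = np.array([0.30, 0.35, 0.08, 0.45, 0.25, 0.01])

rho, _ = spearmanr(syn_5bin, syn_10bin)  # agreement of the pair rankings
r, _ = pearsonr(syn_5bin, syn_10bin)     # agreement of the raw magnitudes


def top_k(scores, k):
    """Indices of the k highest-synergy pairs."""
    return set(np.argsort(scores)[::-1][:k].tolist())


a, b = top_k(syn_5bin, 3), top_k(syn_10bin, 3)
jaccard = len(a & b) / len(a | b)
print(round(rho, 4), round(r, 4), jaccard)
```

Here the Spearman correlation is 31/35 (about 0.886), yet the top-3 sets share only 2 of 4 distinct pairs (Jaccard 0.5). Since the synergy graphs are built from the upper tail of the scores, a high rank correlation alone may overstate how stable the graph edges are across bin levels.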

In [None]:
def visualize_results(graph_results, all_datasets):
    """Reusable visualization function for PID synergy graph results."""
    ds_names = list(graph_results.keys())
    n_datasets = len(ds_names)

    fig, axes = plt.subplots(2, 2, figsize=(14, 10))
    fig.suptitle("PID Synergy Graph Analysis", fontsize=14, fontweight="bold")

    # ── Plot 1: Synergy distribution per dataset ────────────────────────────
    ax = axes[0, 0]
    synergy_data = []
    labels = []
    for ds_name in ds_names:
        values = [r["synergy"] for r in graph_results[ds_name]["pid_results"]]
        synergy_data.append(values)
        labels.append(ds_name.replace("_", "\n")[:15])
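    # Set tick labels separately: boxplot's `labels` kwarg is deprecated in Matplotlib >= 3.9
    bp = ax.boxplot(synergy_data, patch_artist=True)
    ax.set_xticks(np.arange(1, n_datasets + 1))
    ax.set_xticklabels(labels)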
    colors = plt.cm.Set3(np.linspace(0, 1, n_datasets))
    for patch, color in zip(bp["boxes"], colors):
        patch.set_facecolor(color)
    ax.set_title("Synergy Distribution per Dataset")
    ax.set_ylabel("Synergy (bits)")
    ax.tick_params(axis="x", rotation=45, labelsize=8)

    # ── Plot 2: Graph density across thresholds ─────────────────────────────
    ax = axes[0, 1]
    x_pos = np.arange(n_datasets)
    width = 0.25
    thr_names = list(THRESHOLDS.keys())
    for i, thr_name in enumerate(thr_names):
        densities = [
            graph_results[ds]["graph_stats"][thr_name]["density"]
            for ds in ds_names
        ]
        ax.bar(x_pos + i * width, densities, width, label=thr_name)
    ax.set_title("Graph Density by Threshold")
    ax.set_ylabel("Density")
    ax.set_xticks(x_pos + width)
    ax.set_xticklabels([n.replace("_", "\n")[:15] for n in ds_names],
                       rotation=45, fontsize=8)
    ax.legend(fontsize=8)

    # ── Plot 3: Synergy vs baseline interaction ─────────────────────────────
    ax = axes[1, 0]
    for ds_name in ds_names:
        pid_res = graph_results[ds_name]["pid_results"]
        syn = [r["synergy"] for r in pid_res]
        bl = [r["baseline_interaction"] for r in pid_res]
        ax.scatter(syn, bl, alpha=0.5, s=15,
                   label=ds_name[:12])
    ax.set_xlabel("PID Synergy (bits)")
    ax.set_ylabel("Baseline Interaction (bits)")
    ax.set_title("PID Synergy vs Baseline Interaction")
    ax.legend(fontsize=7, loc="best")
    # Add diagonal reference
    lims = [min(ax.get_xlim()[0], ax.get_ylim()[0]),
            max(ax.get_xlim()[1], ax.get_ylim()[1])]
    ax.plot(lims, lims, "k--", alpha=0.3, linewidth=1)

    # ── Plot 4: Stability metrics ───────────────────────────────────────────
    ax = axes[1, 1]
    stab_names = []
    spearman_vals = []
    pearson_vals = []
    for ds_name in ds_names:
        info = all_datasets[ds_name]
        if info.get("stability_spearman") is not None:
            stab_names.append(ds_name.replace("_", "\n")[:15])
            spearman_vals.append(info["stability_spearman"])
            pearson_vals.append(info["stability_pearson"])
    if stab_names:
        x_pos = np.arange(len(stab_names))
        ax.bar(x_pos - 0.15, spearman_vals, 0.3, label="Spearman rho")
        ax.bar(x_pos + 0.15, pearson_vals, 0.3, label="Pearson r")
        ax.set_xticks(x_pos)
        ax.set_xticklabels(stab_names, rotation=45, fontsize=8)
        ax.set_ylabel("Correlation")
        ax.set_title("Discretization Stability (5-bin vs 10-bin)")
        ax.legend(fontsize=8)
        ax.set_ylim(0, 1.1)
    else:
        ax.text(0.5, 0.5, "No stability data",
                ha="center", va="center", transform=ax.transAxes)
        ax.set_title("Discretization Stability")

    plt.tight_layout()
    plt.show()

    # ── Print summary table ─────────────────────────────────────────────────
    print(f"\n{'Dataset':<25} {'Feat':>4} {'Pairs':>5} "
          f"{'MeanSyn':>8} {'MaxSyn':>7} "
          f"{'StabSp':>6} {'StabPe':>6}")
    print("-" * 65)
    for ds_name in ds_names:
        info = all_datasets[ds_name]
        pid_res = graph_results[ds_name]["pid_results"]
        syn_vals = [r["synergy"] for r in pid_res]
        sp = info.get("stability_spearman", None)
        pe = info.get("stability_pearson", None)
        sp_str = f"{sp:.4f}" if sp is not None else "N/A"
        pe_str = f"{pe:.4f}" if pe is not None else "N/A"
        print(f"{ds_name:<25} {info['n_features']:>4} {len(pid_res):>5} "
              f"{np.mean(syn_vals):>8.4f} {np.max(syn_vals):>7.4f} "
              f"{sp_str:>6} {pe_str:>6}")


visualize_results(graph_results, all_datasets)

---
# Part 2 — Full Run (Original Parameters)

Load the complete data file covering all 12 benchmark datasets (3,597 feature pairs in total) and re-run the analysis pipeline with the original parameters.

In [None]:
data = load_full()
print(f"Loaded {len(data['datasets'])} datasets")
total_examples = sum(len(ds["examples"]) for ds in data["datasets"])
print(f"Total feature pairs: {total_examples}")
for ds in data["datasets"]:
    print(f"  {ds['dataset']}: {len(ds['examples'])} feature pairs")

### Parse & Build Graphs (Full Data)

In [None]:
t_start = time.time()

# Parse all datasets
all_datasets = parse_pid_results(data)

# Build synergy graphs for all 12 datasets with original thresholds
graph_results = {}
for ds_name, ds_info in all_datasets.items():
    pid_results = ds_info["pid_results"]
    feature_names = ds_info["feature_names"]

    graph_stats = {}
    for thr_name, thr_q in THRESHOLDS.items():
        G, thr_val = build_synergy_graph(
            pid_results=pid_results,
            feature_names=feature_names,
            threshold_quantile=thr_q,
        )
        stats = analyze_graph(G)
        stats["threshold_value"] = round(thr_val, 6)
        graph_stats[thr_name] = stats

    graph_results[ds_name] = {
        "graph_stats": graph_stats,
        "pid_results": pid_results,
        "ds_info": ds_info,
    }

elapsed = time.time() - t_start
print(f"Processed {len(graph_results)} datasets in {elapsed:.2f}s\n")

# Print full summary table
print(f"{'Dataset':<28} {'Thr':<12} {'Edges':>5} {'Density':>7} {'MaxClq':>6} {'ClustC':>6}")
print("-" * 70)
for ds_name, gr in graph_results.items():
    for thr_name, stats in gr["graph_stats"].items():
        print(f"{ds_name:<28} {thr_name:<12} {stats['n_edges']:>5} "
              f"{stats['density']:>7.4f} {stats['largest_clique_size']:>6} "
              f"{stats['clustering_coefficient']:>6.4f}")

### Domain Checks (Full Data)

In [None]:
# Run domain checks on full data
for ds_name, gr in graph_results.items():
    domain_checks = check_domain_pairs(
        pid_results=gr["pid_results"],
        dataset_name=ds_name,
    )
    if domain_checks:
        print(f"\n{ds_name} — Domain-Meaningful Pair Checks:")
        for dc in domain_checks:
            rank_str = f"{dc['synergy_rank']}/{dc['total_pairs']}" if dc['synergy_rank'] else "not found"
            print(f"  {dc['feature_i']} <-> {dc['feature_j']}: "
                  f"rank {rank_str} "
                  f"(top-10%: {dc['in_top_10pct']}, top-25%: {dc['in_top_25pct']})")

### Full Visualization

In [None]:
visualize_results(graph_results, all_datasets)