# Pairwise PID Synergy Matrices on Benchmark Datasets

This notebook computes **Partial Information Decomposition (PID)** synergy matrices on benchmark classification datasets. It decomposes the mutual information between feature pairs and a target variable into four atoms: **synergy**, **redundancy**, **unique information from feature 0**, and **unique information from feature 1**.

**Key analyses:**
1. **XOR Validation** — verifies the PID computation on the canonical XOR distribution (expected synergy ≈ 1 bit)
2. **Pairwise Synergy Matrix** — computes PID for all feature pairs using BROJA or MMI estimators
3. **Synergy vs MI Comparison** — measures Jaccard overlap between top synergy pairs and top MI features
4. **Stability Analysis** — assesses reproducibility via subsampled Spearman correlations
5. **Synergy Graph** — constructs a graph with edges between high-synergy feature pairs

**Method:** PID_BROJA when a dataset has ≤ 100 feature pairs (`BROJA_PAIR_LIMIT`), PID_MMI otherwise, with a Co-Information baseline. The stability analysis uses PID_MMI.

In [None]:
import subprocess, sys
def _pip(*args):
    """Quietly pip-install ``args`` into the interpreter running this notebook."""
    subprocess.check_call([sys.executable, '-m', 'pip', 'install', '-q', *args])


# Packages not pre-installed on Colab; dit is installed without its dependency pins.
for _pip_args in (
    ('boltons',),
    ('debtcollector',),
    ('lattices',),
    ('PLTable',),
    ('--no-deps', 'dit==1.5'),
):
    _pip(*_pip_args)

# Core packages ship with Colab; pin them only when running somewhere else.
_on_colab = 'google.colab' in sys.modules
if not _on_colab:
    _pip('pandas==2.2.2', 'scikit-learn==1.6.1', 'matplotlib==3.10.0', 'networkx==3.6.1', 'scipy==1.16.3')

In [None]:
# NumPy 2.0 removed np.alltrue / np.sometrue, which dit still calls; alias them
# back to np.all / np.any (only when missing) before dit is imported.
import numpy as np
for _legacy_name, _replacement in (("alltrue", np.all), ("sometrue", np.any)):
    if not hasattr(np, _legacy_name):
        setattr(np, _legacy_name, _replacement)
del _legacy_name, _replacement

import json
import time
import numpy as np
import matplotlib.pyplot as plt
from collections import Counter
from itertools import combinations
from typing import Any

import dit
from dit.pid import PID_BROJA, PID_MMI
from sklearn.preprocessing import KBinsDiscretizer
from sklearn.feature_selection import mutual_info_classif
from scipy.stats import spearmanr
import networkx as nx

## Data Loading

Load the mini demo dataset (iris, 90 balanced samples across 3 classes) from GitHub, with local fallback.

In [None]:
GITHUB_DATA_URL = "https://raw.githubusercontent.com/AMGrobelnik/test-colab-install/master/pid_mini_demo_data.json"
import json, os

def load_data(url: str = GITHUB_DATA_URL, local_path: str = "mini_demo_data.json",
              timeout: float = 30.0) -> dict:
    """Load the demo dataset JSON, preferring the remote copy.

    Args:
        url: location to download from first (any scheme ``urllib`` supports).
        local_path: file to read if the download fails for any reason.
        timeout: seconds to wait on the remote request before giving up.

    Returns:
        The parsed JSON document (the pipeline reads its ``"datasets"`` list).

    Raises:
        FileNotFoundError: if the download fails and ``local_path`` does not
            exist; the download error is chained as ``__cause__``.
    """
    import urllib.request

    remote_error = None
    try:
        # Without a timeout urlopen can block the notebook indefinitely.
        with urllib.request.urlopen(url, timeout=timeout) as response:
            return json.loads(response.read().decode("utf-8"))
    except Exception as exc:  # best-effort: any failure falls back to the local copy
        remote_error = exc
        print(f"Remote load failed ({exc!r}); trying local file {local_path!r}")

    if os.path.exists(local_path):
        with open(local_path, encoding="utf-8") as f:
            return json.load(f)
    raise FileNotFoundError(f"Could not load {local_path}") from remote_error

In [None]:
data = load_data()
dataset_entries = data['datasets']
print(f"Loaded {len(dataset_entries)} dataset(s)")
for entry in dataset_entries:
    print(f"  {entry['dataset']}: {len(entry['examples'])} examples")

## Configuration

All tunable parameters live in this cell. The defaults are kept small so the notebook runs quickly; increase them for a more thorough analysis.

In [None]:
# ── Tunable parameters ──────────────────────────────────────────────
# Discretization
N_BINS = 5                    # Quantile bins per continuous feature

# PID estimator selection
BROJA_PAIR_LIMIT = 100        # Use PID_BROJA when n_pairs <= this, else PID_MMI

# Feature budget
MAX_FEATURES_LARGE = 30       # Above this many features, keep only the top-MI ones

# Stability analysis
STABILITY_SUBSAMPLES = 5      # Random row subsamples to draw
STABILITY_FRACTION = 0.8      # Fraction of rows in each subsample

# Synergy graph
SYNERGY_THRESHOLD_PCTL = 75   # Pairs at/above this synergy percentile become edges

## Dataset Parsing

Parse the loaded JSON into numpy arrays (X features, y labels).

In [None]:
def parse_datasets(raw_data: dict) -> dict[str, dict[str, Any]]:
    """Convert the raw JSON payload into per-dataset numpy arrays.

    Each entry of ``raw_data["datasets"]`` carries a ``"dataset"`` name and a
    list of ``"examples"``; each example's ``"input"`` is a JSON-encoded object
    of feature values and ``"output"`` is its label. Entries without examples
    are skipped. Feature columns follow the key order of the first example,
    and labels (compared as strings) are coded 0..C-1 in sorted order.

    Returns:
        ``{name: {"X", "y", "feature_names", "n_classes", "n_samples",
        "n_features", "class_labels"}}``. A one-line summary per dataset is
        printed as a side effect.
    """
    parsed: dict[str, dict[str, Any]] = {}
    for entry in raw_data["datasets"]:
        name = entry["dataset"]
        examples = entry["examples"]
        if not examples:
            continue

        columns = list(json.loads(examples[0]["input"]).keys())

        rows, labels = [], []
        for example in examples:
            values = json.loads(example["input"])
            rows.append([float(values[col]) for col in columns])
            labels.append(str(example["output"]))

        X = np.asarray(rows, dtype=np.float64)
        class_labels = sorted(set(labels))
        code_of = {label: code for code, label in enumerate(class_labels)}
        y = np.fromiter((code_of[label] for label in labels), dtype=np.int32, count=len(labels))

        # Prefer the dataset's declared class count over the observed one.
        n_classes = int(examples[0].get("metadata_n_classes", len(class_labels)))
        width = len(columns)
        parsed[name] = {
            "X": X,
            "y": y,
            "feature_names": columns,
            "n_classes": n_classes,
            "n_samples": X.shape[0],
            "n_features": width,
            "class_labels": class_labels,
        }
        n_pairs = width * (width - 1) // 2
        print(f"  {name:30s} | {X.shape[0]:5d} samples | "
              f"{width:3d} features | {n_classes} classes | {n_pairs} pairs")
    return parsed

# Parse every non-empty dataset once; the pipeline cell iterates over this dict.
datasets = parse_datasets(data)

## XOR Validation

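Validate the PID computation on the canonical XOR distribution. XOR carries **pure synergy**: each input alone says nothing about the output, but together they determine it. BROJA synergy should therefore be ≈ 1 bit and redundancy ≈ 0. The AND gate is also decomposed for reference.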

In [None]:
def xor_validation() -> dict:
    """Sanity-check the PID estimators on two textbook logic gates.

    For XOR (Y = X1 XOR X2) the pair of inputs determines Y while neither does
    alone, so PID_BROJA should report ~1 bit of synergy, ~0 redundancy, and its
    four atoms should add up to ~1 bit. PID_MMI is also run on XOR, and
    PID_BROJA on AND (Y = X1 AND X2); those values are reported, not checked.

    Returns:
        dict of atom values rounded to 6 decimals plus the boolean flags
        ``xor_synergy_pass``, ``xor_redundancy_pass`` and ``conservation_pass``.
        A PASS/FAIL summary is printed as well.
    """
    # PID lattice nodes: {01} synergy, {0}{1} redundancy, {0} / {1} unique.
    syn_node, red_node = ((0, 1),), ((0,), (1,))
    u0_node, u1_node = ((0,),), ((1,),)

    # Outcomes are written "x1 x2 y"; XOR puts 1/4 on each consistent row.
    xor_dist = dit.Distribution(
        ["000", "001", "010", "011", "100", "101", "110", "111"],
        [1 / 4, 0, 0, 1 / 4, 0, 1 / 4, 1 / 4, 0],
    )
    xor_broja = PID_BROJA(xor_dist)
    broja_syn, broja_red, broja_u0, broja_u1 = (
        xor_broja.get_pi(node) for node in (syn_node, red_node, u0_node, u1_node)
    )

    xor_mmi = PID_MMI(xor_dist)
    mmi_syn, mmi_red = (xor_mmi.get_pi(node) for node in (syn_node, red_node))

    # AND: only the all-ones input row produces Y = 1.
    and_pid = PID_BROJA(dit.Distribution(["000", "010", "100", "111"], [1 / 4] * 4))
    and_syn, and_red = (and_pid.get_pi(node) for node in (syn_node, red_node))

    # The atoms should add back up to the total mutual information (1 bit for XOR).
    total_pi = broja_syn + broja_red + broja_u0 + broja_u1

    def r6(value) -> float:
        return round(float(value), 6)

    results = {
        "xor_broja_synergy": r6(broja_syn),
        "xor_broja_redundancy": r6(broja_red),
        "xor_broja_unique_0": r6(broja_u0),
        "xor_broja_unique_1": r6(broja_u1),
        "xor_mmi_synergy": r6(mmi_syn),
        "xor_mmi_redundancy": r6(mmi_red),
        "xor_conservation_total": r6(total_pi),
        "and_broja_synergy": r6(and_syn),
        "and_broja_redundancy": r6(and_red),
        "xor_synergy_pass": bool(abs(broja_syn - 1.0) < 0.01),
        "xor_redundancy_pass": bool(abs(broja_red) < 0.01),
        "conservation_pass": bool(abs(total_pi - 1.0) < 0.05),
    }

    def verdict(flag: str) -> str:
        return "PASS" if results[flag] else "FAIL"

    print(f"XOR synergy (BROJA):  {broja_syn:.4f} (expect ~1.0) {verdict('xor_synergy_pass')}")
    print(f"XOR redundancy:       {broja_red:.4f} (expect ~0.0) {verdict('xor_redundancy_pass')}")
    print(f"Conservation:         {total_pi:.4f} (expect ~1.0) {verdict('conservation_pass')}")
    print(f"AND synergy:          {and_syn:.4f}, AND redundancy: {and_red:.4f}")
    return results

# Run the logic-gate sanity checks once; the returned dict is printed again in the results cell.
xor_results = xor_validation()

## Helper Functions

Core computation functions: discretization, trivariate distribution building, PID synergy, and Co-Information baseline.

In [None]:
def discretize(X: np.ndarray, n_bins: int = N_BINS) -> np.ndarray:
    """Replace every column of ``X`` with integer quantile-bin codes.

    Constant columns map to all zeros. Any other column is binned by a
    quantile ``KBinsDiscretizer`` fit on the whole column, using at most
    ``n_bins`` bins and never more bins than the column has distinct values.

    Returns:
        int32 array with the same shape as ``X``.
    """
    codes = np.zeros_like(X, dtype=np.int32)  # constant columns stay 0
    for j in range(X.shape[1]):
        column = X[:, j]
        distinct = np.unique(column).size
        if distinct > 1:
            binner = KBinsDiscretizer(
                n_bins=min(n_bins, distinct),
                encode="ordinal",
                strategy="quantile",
                subsample=None,
            )
            binned = binner.fit_transform(column.reshape(-1, 1))
            codes[:, j] = binned.ravel().astype(np.int32)
    return codes


def build_trivariate_dist(xi: np.ndarray, xj: np.ndarray, y: np.ndarray):
    """Build the empirical joint distribution of (xi, xj, y) as a ``dit.Distribution``.

    Variables are ordered (xi, xj, y). Probabilities are the relative counts of
    each observed (xi, xj, y) triple.

    Outcome encoding: when every observed value is a plain int in 0-9, each
    outcome is a 3-character string such as ``"120"``. Otherwise outcomes are
    tuples. Multi-digit, negative, float or bool codes would not map to exactly
    one character per variable (e.g. ``"-100"``, ``"1.02.00.0"``).

    Raises:
        ValueError: if the arrays are empty or have different lengths
            (``zip`` would otherwise silently truncate them).
    """
    xi_vals, xj_vals, y_vals = xi.tolist(), xj.tolist(), y.tolist()
    if not (len(xi_vals) == len(xj_vals) == len(y_vals)):
        raise ValueError(
            f"length mismatch: xi={len(xi_vals)}, xj={len(xj_vals)}, y={len(y_vals)}"
        )
    if not y_vals:
        raise ValueError("cannot build a distribution from empty arrays")

    counts = Counter(zip(xi_vals, xj_vals, y_vals))
    total = sum(counts.values())
    outcomes = list(counts)
    probs = [counts[outcome] / total for outcome in outcomes]

    # `type(v) is int` deliberately excludes bool (a subclass of int).
    single_digit = all(
        type(v) is int and 0 <= v <= 9 for outcome in outcomes for v in outcome
    )
    if single_digit:
        outcomes = ["".join(map(str, outcome)) for outcome in outcomes]
    return dit.Distribution(outcomes, probs)


def compute_full_pid(xi: np.ndarray, xj: np.ndarray, y: np.ndarray, use_broja: bool = True) -> dict:
    """Decompose I(xi, xj; y) into its four PID atoms.

    Args:
        xi, xj: discrete source variables (sources 0 and 1).
        y: discrete target variable.
        use_broja: use ``PID_BROJA`` when True, otherwise ``PID_MMI``.

    Returns:
        ``{"synergy", "unique_0", "unique_1", "redundancy"}`` as Python floats.
    """
    estimator = PID_BROJA if use_broja else PID_MMI
    decomposition = estimator(build_trivariate_dist(xi, xj, y))
    # Atom name -> PID lattice node understood by get_pi.
    lattice_nodes = {
        "synergy": ((0, 1),),
        "unique_0": ((0,),),
        "unique_1": ((1,),),
        "redundancy": ((0,), (1,)),
    }
    return {atom: float(decomposition.get_pi(node)) for atom, node in lattice_nodes.items()}


def compute_co_information(xi: np.ndarray, xj: np.ndarray, y: np.ndarray) -> float:
    """Return the negated co-information -CoI(Xi; Xj; Y), in bits, as a cheap synergy proxy.

    CoI = I(Xi; Y) + I(Xj; Y) - I(Xi, Xj; Y). A negative CoI means the pair
    tells more about Y jointly than the sum of its parts (net synergy), so the
    sign is flipped: a positive result means synergy dominates and a negative
    one means redundancy dominates.

    All three MI terms come from ``mutual_info_classif`` with discrete
    features. That function reports nats, so the result is divided by ln 2 to
    match the units of the dit PID atoms (bits).
    """
    def _mi(feature: np.ndarray) -> float:
        return float(mutual_info_classif(
            feature.reshape(-1, 1), y, discrete_features=True, random_state=42
        )[0])

    # mutual_info_classif scores each column *separately*: passing the two
    # columns side by side and taking [0] yields I(Xi; Y), not I(Xi, Xj; Y).
    # Encode every distinct (xi, xj) combination as one discrete variable instead.
    _, joint_codes = np.unique(np.column_stack([xi, xj]), axis=0, return_inverse=True)
    joint = np.asarray(joint_codes).ravel()  # ravel: inverse shape varies across NumPy versions

    coi_nats = _mi(xi) + _mi(xj) - _mi(joint)
    return float(-coi_nats / np.log(2))

## Compute Pairwise Synergy Matrix

For each feature pair, compute the full PID decomposition (synergy, redundancy, unique info) and the Co-Information baseline.

In [None]:
def compute_synergy_matrix(
    X_disc: np.ndarray, y: np.ndarray, feature_names: list[str],
    dataset_name: str, feature_indices: list[int] | None = None,
) -> dict[str, Any]:
    """Run the PID and the Co-Information baseline on every pair of selected features.

    Args:
        X_disc: discretized feature matrix (columns indexed globally).
        y: discrete target.
        feature_names: names for the global column indices.
        dataset_name: label used in progress messages.
        feature_indices: global columns to include (all columns when None).
            Row/column ``k`` of the returned matrices is ``feature_indices[k]``.

    Returns:
        dict with symmetric ``synergy_matrix`` and ``coi_matrix`` (nested lists,
        zero diagonal), the estimator used (BROJA when the pair count is within
        ``BROJA_PAIR_LIMIT``, else MMI), success/error counts, timing, the
        per-pair PID atoms keyed ``"<name_i>__x__<name_j>"``, and the indices used.
        Pairs whose PID fails keep synergy 0 and are counted in ``errors``;
        pairs whose Co-Information fails keep 0 silently.
    """
    cols = list(range(X_disc.shape[1])) if feature_indices is None else feature_indices
    n_feat = len(cols)
    local_pairs = list(combinations(range(n_feat), 2))
    n_pairs = len(local_pairs)

    use_broja = n_pairs <= BROJA_PAIR_LIMIT
    pid_method = "BROJA" if use_broja else "MMI"

    synergy = np.zeros((n_feat, n_feat), dtype=np.float64)
    coi = np.zeros_like(synergy)
    details: dict[str, dict] = {}

    print(f"  {dataset_name}: {n_pairs} pairs, method={pid_method}")

    started = time.time()
    n_ok = n_failed = 0
    for a, b in local_pairs:
        col_a, col_b = X_disc[:, cols[a]], X_disc[:, cols[b]]

        try:
            atoms = compute_full_pid(col_a, col_b, y, use_broja=use_broja)
            synergy[a, b] = synergy[b, a] = atoms["synergy"]
            details[f"{feature_names[cols[a]]}__x__{feature_names[cols[b]]}"] = atoms
            n_ok += 1
        except Exception:
            n_failed += 1

        try:
            value = compute_co_information(col_a, col_b, y)
            coi[a, b] = coi[b, a] = value
        except Exception:
            pass  # CoI is only a baseline; a failed pair keeps 0.0

    elapsed = time.time() - started
    print(f"  {dataset_name}: {n_ok}/{n_pairs} pairs in {elapsed:.1f}s ({n_failed} errors)")

    return {
        "synergy_matrix": synergy.tolist(),
        "coi_matrix": coi.tolist(),
        "pid_method": pid_method,
        "n_pairs": n_pairs,
        "completed_pairs": n_ok,
        "errors": n_failed,
        "total_time_s": round(elapsed, 2),
        "pid_details": details,
        "feature_indices_used": cols,
    }

## Synergy vs MI Comparison

Compare the top synergy pairs against pairs formed from top-MI features. A low Jaccard overlap suggests that synergy ranks feature pairs differently from marginal MI.

In [None]:
def compare_synergy_vs_mi(
    synergy_matrix: np.ndarray, mi_values: np.ndarray,
    feature_names: list[str], feature_indices: list[int], top_k: int = 10,
) -> dict:
    """Measure how differently PID synergy and marginal MI rank feature pairs.

    Row/column ``i`` of ``synergy_matrix`` is feature ``feature_indices[i]``;
    ``mi_values`` and ``feature_names`` are indexed by global feature index.

    * Jaccard overlap between the ``k = min(top_k, n_pairs)`` pairs with the
      highest synergy and the ``k`` pairs with the highest combined MI
      ``mi[i] + mi[j]``. Both sets have size ``k``, so the overlap is not
      capped by a size mismatch. If ``k`` equals the number of pairs, both sets
      contain every pair and the overlap is trivially 1.
    * Spearman correlation between each feature's row maximum in
      ``synergy_matrix`` and its MI. With fewer than three features, or an
      undefined (NaN) correlation, rho = 0.0 and p = 1.0 are reported.

    Returns:
        dict with ``jaccard_overlap``, ``spearman_rho``, ``spearman_pval``,
        ``top_synergy_pairs`` (up to five named pairs, highest first),
        ``n_overlap`` and ``top_k`` (the ``k`` actually used).
    """
    synergy_matrix = np.asarray(synergy_matrix, dtype=np.float64)
    mi_values = np.asarray(mi_values, dtype=np.float64)
    n_feat = len(feature_indices)
    pairs = list(combinations(range(n_feat), 2))
    mi_sub = mi_values[feature_indices]

    synergy_vals = np.array([synergy_matrix[i, j] for (i, j) in pairs], dtype=np.float64)
    combined_mi = np.array([mi_sub[i] + mi_sub[j] for (i, j) in pairs], dtype=np.float64)

    k = max(0, min(top_k, len(pairs)))

    def _top_pairs(scores: np.ndarray) -> set[tuple[int, int]]:
        """Global-index pairs holding the k highest scores."""
        # Slice from len-k rather than -k: argsort(...)[-0:] would select every pair.
        chosen = np.argsort(scores)[len(scores) - k:]
        return {(feature_indices[pairs[idx][0]], feature_indices[pairs[idx][1]]) for idx in chosen}

    top_synergy_pairs = _top_pairs(synergy_vals)
    top_mi_pairs = _top_pairs(combined_mi)
    intersection = top_synergy_pairs & top_mi_pairs
    union = top_synergy_pairs | top_mi_pairs
    jaccard = len(intersection) / max(len(union), 1)

    if n_feat > 2:
        max_synergy_per_feat = synergy_matrix.max(axis=1)
        rho, pval = spearmanr(max_synergy_per_feat, mi_sub)
    else:
        rho, pval = 0.0, 1.0

    top_synergy_named = []
    for idx in np.argsort(synergy_vals)[::-1][:5]:
        i, j = pairs[idx]
        top_synergy_named.append({
            "feature_i": feature_names[feature_indices[i]],
            "feature_j": feature_names[feature_indices[j]],
            "synergy": round(float(synergy_vals[idx]), 6),
        })

    return {
        "jaccard_overlap": round(jaccard, 4),
        "spearman_rho": round(float(rho), 4) if not np.isnan(rho) else 0.0,
        "spearman_pval": round(float(pval), 6) if not np.isnan(pval) else 1.0,
        "top_synergy_pairs": top_synergy_named,
        "n_overlap": len(intersection),
        "top_k": k,
    }

## Stability Analysis & Synergy Graph

Assess reproducibility by computing synergy on random subsamples, then build a synergy graph with edges between high-synergy pairs.

In [None]:
def stability_analysis(
    X_disc: np.ndarray, y: np.ndarray, feature_names: list[str],
    dataset_name: str, feature_indices: list[int],
    n_subsamples: int = STABILITY_SUBSAMPLES,
    subsample_frac: float = STABILITY_FRACTION,
    use_broja: bool = False,
) -> dict:
    """Estimate how reproducible the pairwise synergy ranking is under row subsampling.

    Draws ``n_subsamples`` random row subsets of size
    ``int(n_samples * subsample_frac)`` without replacement, using a fixed
    seed (42). On each subset it recomputes the synergy of every pair of
    ``feature_indices``. It then reports the Spearman correlation between the
    synergy vectors of every pair of subsets (only when there are more than two
    feature pairs).

    Args:
        X_disc: discretized feature matrix (columns indexed globally).
        y: discrete target.
        feature_names: unused; kept for interface compatibility.
        dataset_name: label used in the printed summary.
        feature_indices: global columns whose pairs are analysed.
        n_subsamples: number of random subsets.
        subsample_frac: fraction of rows in each subset.
        use_broja: use PID_BROJA instead of PID_MMI. The default (MMI) matches
            the original behaviour. When the main synergy matrix was computed
            with BROJA, stability is then measured with a different estimator.

    Returns:
        dict with the subsample settings, mean/std of the pairwise Spearman
        rho, the individual correlations, ``pid_method`` and ``failed_pairs``.
        Failed pair computations are scored as 0 synergy and counted in
        ``failed_pairs``.
    """
    pairs = list(combinations(range(len(feature_indices)), 2))
    n_samples = X_disc.shape[0]
    subsample_size = int(n_samples * subsample_frac)
    pid_cls = PID_BROJA if use_broja else PID_MMI

    rng = np.random.RandomState(42)
    upper_triangles = []
    errors = 0

    for _ in range(n_subsamples):
        indices = rng.choice(n_samples, size=subsample_size, replace=False)
        X_sub, y_sub = X_disc[indices], y[indices]

        synergies = []
        for li, lj in pairs:
            gi, gj = feature_indices[li], feature_indices[lj]
            try:
                d = build_trivariate_dist(X_sub[:, gi], X_sub[:, gj], y_sub)
                synergies.append(float(pid_cls(d).get_pi(((0, 1),))))
            except Exception:
                # Best-effort: keep the vector aligned, but record the failure.
                errors += 1
                synergies.append(0.0)
        upper_triangles.append(synergies)

    correlations = []
    if len(pairs) > 2:  # a rank correlation over <= 2 points is meaningless
        for a, b in combinations(range(n_subsamples), 2):
            rho, _ = spearmanr(upper_triangles[a], upper_triangles[b])
            if not np.isnan(rho):
                correlations.append(float(rho))

    mean_corr = float(np.mean(correlations)) if correlations else 0.0
    std_corr = float(np.std(correlations)) if correlations else 0.0
    print(f"  Stability {dataset_name}: mean_rho={mean_corr:.4f} +/- {std_corr:.4f}")
    if errors:
        print(f"  Stability {dataset_name}: {errors} pair computations failed (scored as 0)")

    return {
        "n_subsamples": n_subsamples, "subsample_fraction": subsample_frac,
        "mean_spearman": round(mean_corr, 4), "std_spearman": round(std_corr, 4),
        "all_correlations": [round(c, 4) for c in correlations],
        "pid_method": "BROJA" if use_broja else "MMI",
        "failed_pairs": errors,
    }


def build_synergy_graph(
    synergy_matrix: np.ndarray, feature_names: list[str],
    feature_indices: list[int], threshold_percentile: int = SYNERGY_THRESHOLD_PCTL,
) -> dict:
    """Summarise the graph linking feature pairs whose synergy clears a percentile cut.

    Every selected feature becomes a node (named via ``feature_names``). An edge
    joins a pair whose synergy is at least the ``threshold_percentile``
    percentile of all pairwise synergies, weighted by that synergy. If there
    are no pairs, or the largest synergy is 0, an empty-graph summary is
    returned without building the graph.

    Returns:
        dict with ``threshold``, edge/node/component counts,
        ``largest_clique_size`` (0 if clique enumeration fails) and up to five
        heaviest edges.
    """
    n_feat = len(feature_indices)
    pairs = list(combinations(range(n_feat), 2))
    weights = [synergy_matrix[a, b] for a, b in pairs]

    if not weights or max(weights) == 0:
        return {"threshold": 0.0, "n_edges": 0, "n_nodes": n_feat,
                "n_components": n_feat, "largest_clique_size": 0, "top_5_edges": []}

    threshold = float(np.percentile(weights, threshold_percentile))
    names = [feature_names[g] for g in feature_indices]

    graph = nx.Graph()
    graph.add_nodes_from(names)
    graph.add_edges_from(
        (names[a], names[b], {"weight": synergy_matrix[a, b]})
        for a, b in pairs
        if synergy_matrix[a, b] >= threshold
    )

    n_components = nx.number_connected_components(graph)
    try:
        largest_clique_size = max((len(c) for c in nx.find_cliques(graph)), default=0)
    except Exception:
        largest_clique_size = 0

    heaviest = sorted(graph.edges(data=True), key=lambda e: e[2].get("weight", 0), reverse=True)
    top_5 = [
        {"feature_i": u, "feature_j": v, "synergy": round(attrs["weight"], 6)}
        for u, v, attrs in heaviest[:5]
    ]

    return {
        "threshold": round(threshold, 6),
        "n_edges": graph.number_of_edges(),
        "n_nodes": graph.number_of_nodes(),
        "n_components": n_components,
        "largest_clique_size": largest_clique_size,
        "top_5_edges": top_5,
    }

## Run Full Pipeline

Process each dataset: discretize → synergy matrix → MI comparison → synergy graph → stability.

In [None]:
t_global_start = time.time()
all_results = []

for ds_name, ds_info in datasets.items():
    print(f"\n{'='*50}")
    print(f"Processing: {ds_name}")
    print(f"{'='*50}")

    X, y = ds_info["X"], ds_info["y"]
    feature_names = ds_info["feature_names"]
    n_features = ds_info["n_features"]

    # Discretize once; feature selection, PID, MI and stability all share it.
    X_disc = discretize(X, n_bins=N_BINS)

    # Per-feature MI with the target, computed once and reused both for the
    # feature budget below and for the synergy-vs-MI comparison.
    mi_values = mutual_info_classif(X_disc, y, discrete_features=True, random_state=42)

    # Wide datasets: keep only the MAX_FEATURES_LARGE highest-MI features
    feature_indices = list(range(n_features))
    if n_features > MAX_FEATURES_LARGE:
        top_feat_idx = np.argsort(mi_values)[-MAX_FEATURES_LARGE:]
        feature_indices = sorted(top_feat_idx.tolist())

    # Synergy matrix (PID atoms + Co-Information baseline for every pair)
    synergy_result = compute_synergy_matrix(
        X_disc=X_disc, y=y, feature_names=feature_names,
        dataset_name=ds_name, feature_indices=feature_indices,
    )
    synergy_matrix = np.array(synergy_result["synergy_matrix"])

    # MI comparison
    mi_comparison = compare_synergy_vs_mi(
        synergy_matrix=synergy_matrix,
        mi_values=mi_values, feature_names=feature_names,
        feature_indices=feature_indices,
    )

    # Synergy graph
    graph_result = build_synergy_graph(
        synergy_matrix=synergy_matrix,
        feature_names=feature_names, feature_indices=feature_indices,
    )

    # Stability analysis
    stability_result = stability_analysis(
        X_disc=X_disc, y=y, feature_names=feature_names,
        dataset_name=ds_name, feature_indices=feature_indices,
    )

    # MI values for the features actually used (keys follow feature_indices order)
    mi_dict = {feature_names[idx]: round(float(mi_values[idx]), 6)
               for idx in feature_indices}

    all_results.append({
        "dataset": ds_name,
        "n_samples": ds_info["n_samples"], "n_features": n_features,
        "n_features_used": len(feature_indices), "n_classes": ds_info["n_classes"],
        "synergy": synergy_result, "mi_values": mi_dict,
        "mi_comparison": mi_comparison, "synergy_graph": graph_result,
        "stability": stability_result,
    })

total_time = time.time() - t_global_start
print(f"\nDone in {total_time:.1f}s — {len(all_results)} dataset(s) processed")

## Results Visualization

Display the key results for each dataset: a synergy heatmap, a synergy vs Co-Information scatter, a PID decomposition bar chart, and a summary table, followed once by the XOR validation results.

In [None]:
for r in all_results:
    ds_name = r["dataset"]
    syn_mat = np.array(r["synergy"]["synergy_matrix"])
    coi_mat = np.array(r["synergy"]["coi_matrix"])
    # mi_values was built in feature_indices order, matching the matrix rows/cols.
    feat_labels = list(r["mi_values"].keys())

    fig, axes = plt.subplots(1, 3, figsize=(18, 5))
    fig.suptitle(f"PID Synergy Analysis: {ds_name}", fontsize=14, fontweight="bold")

    # 1. Synergy heatmap
    ax = axes[0]
    im = ax.imshow(syn_mat, cmap="YlOrRd", aspect="auto")
    ax.set_xticks(range(len(feat_labels)))
    ax.set_xticklabels(feat_labels, rotation=45, ha="right", fontsize=8)
    ax.set_yticks(range(len(feat_labels)))
    ax.set_yticklabels(feat_labels, fontsize=8)
    ax.set_title("Synergy Matrix")
    fig.colorbar(im, ax=ax, shrink=0.8)

    # 2. Synergy vs Co-Information scatter (the stored CoI value is -CoI)
    ax = axes[1]
    pairs = list(combinations(range(syn_mat.shape[0]), 2))
    syn_vals = [syn_mat[i, j] for i, j in pairs]
    coi_vals = [coi_mat[i, j] for i, j in pairs]
    ax.scatter(coi_vals, syn_vals, alpha=0.7, edgecolors="k", linewidth=0.5)
    ax.set_xlabel("Co-Information (\u2212CoI \u2192 synergy)")
    ax.set_ylabel("PID Synergy")
    ax.set_title("Synergy vs Co-Information")
    ax.axhline(y=0, color="gray", linestyle="--", alpha=0.5)
    ax.axvline(x=0, color="gray", linestyle="--", alpha=0.5)

    # 3. PID decomposition for the six pairs with the highest synergy
    ax = axes[2]
    pid_details = r["synergy"]["pid_details"]
    pair_labels = sorted(pid_details, key=lambda k: pid_details[k]["synergy"], reverse=True)[:6]
    synergies = [pid_details[k]["synergy"] for k in pair_labels]
    redundancies = [pid_details[k]["redundancy"] for k in pair_labels]
    unique_0s = [pid_details[k]["unique_0"] for k in pair_labels]
    unique_1s = [pid_details[k]["unique_1"] for k in pair_labels]
    x = np.arange(len(pair_labels))
    width = 0.2
    ax.bar(x - 1.5*width, synergies, width, label="Synergy", color="#e74c3c")
    ax.bar(x - 0.5*width, redundancies, width, label="Redundancy", color="#3498db")
    ax.bar(x + 0.5*width, unique_0s, width, label="Unique\u2080", color="#2ecc71")
    ax.bar(x + 1.5*width, unique_1s, width, label="Unique\u2081", color="#f39c12")
    short_labels = [lbl.replace("__x__", "\n\u00d7\n") for lbl in pair_labels]
    ax.set_xticks(x)
    ax.set_xticklabels(short_labels, fontsize=6, rotation=45, ha="right")
    ax.set_ylabel("Information (bits)")
    ax.set_title("PID Decomposition (top pairs by synergy)")
    ax.legend(fontsize=7, loc="upper right")

    plt.tight_layout()
    plt.show()

    # Summary table
    print(f"\n{'='*60}")
    print(f"Summary: {ds_name}")
    print(f"{'='*60}")
    print(f"  PID method:           {r['synergy']['pid_method']}")
    print(f"  Pairs computed:       {r['synergy']['completed_pairs']}/{r['synergy']['n_pairs']}")
    print(f"  Errors:               {r['synergy']['errors']}")
    print(f"  Computation time:     {r['synergy']['total_time_s']:.2f}s")
    print(f"  Jaccard overlap:      {r['mi_comparison']['jaccard_overlap']:.4f}")
    print(f"  Spearman rho:         {r['mi_comparison']['spearman_rho']:.4f}")
    print(f"  Graph edges:          {r['synergy_graph']['n_edges']}")
    print(f"  Graph components:     {r['synergy_graph']['n_components']}")
    print(f"  Largest clique:       {r['synergy_graph']['largest_clique_size']}")
    if r["stability"]:
        print(f"  Stability rho:        {r['stability']['mean_spearman']:.4f} \u00b1 {r['stability']['std_spearman']:.4f}")
    print(f"\n  Top synergy pairs:")
    for p in r["mi_comparison"]["top_synergy_pairs"][:5]:
        print(f"    {p['feature_i']:25s} \u00d7 {p['feature_j']:25s} \u2192 {p['synergy']:.6f}")

# XOR validation summary: dataset-independent, so print it once after the loop.
print(f"\n{'='*60}")
print("XOR Validation Results")
print(f"{'='*60}")
for k, v in xor_results.items():
    print(f"  {k:30s}: {v}")