# Information-Theoretic Split Quality Audit for SG-FIGS

This notebook evaluates whether **SG-FIGS oblique splits** capture more target information than random feature subsets of the same size, directly testing the partial information decomposition (PID) synergy premise behind SG-FIGS.

**Six metrics are computed per oblique split:**
1. **Synergy Concentration Ratio (SCR)** — mean synergy of the split's feature pairs / mean synergy over all feature pairs (>1 = the split concentrates high-synergy pairs)
2. **Above-Median Fraction** — fraction of split feature pairs with above-median synergy (random baseline = 0.5)
3. **Impurity Reduction per Feature** — oblique impurity per feature vs best axis-aligned split
4. **Redundancy Load** — mean redundancy among split features vs expected random
5. **Joint MI Coverage** — total joint MI of split pairs vs sum of individual MI
6. **Synergy-Weighted Efficiency** — (impurity × mean_synergy) / best_axis_impurity

**Part 1** runs a quick demo on a curated subset (3 datasets). **Part 2** runs on all 9 dataset×threshold combinations (144 examples).

In [None]:
import json
import math
from itertools import combinations
from typing import Any

import matplotlib.pyplot as plt
import numpy as np

In [None]:
# Remote copies of the demo data. `_load_json` tries the URL first and falls
# back to a local file of the same base name if the download fails.
GITHUB_FULL_DATA_URL = "https://raw.githubusercontent.com/AMGrobelnik/ai-invention-ac2586-synergy-guided-oblique-splits-using-part/main/split_quality/demo/full_demo_data.json"
GITHUB_MINI_DATA_URL = "https://raw.githubusercontent.com/AMGrobelnik/ai-invention-ac2586-synergy-guided-oblique-splits-using-part/main/split_quality/demo/mini_demo_data.json"

import json, os

def _load_json(url, local_path, timeout=30.0):
    """Load a JSON document from ``url``, falling back to ``local_path``.

    The download is best-effort: a network, HTTP, or decoding failure does not
    abort the load, but the error is remembered so it can be reported if the
    local fallback is unavailable too.

    Args:
        url: Location of the remote JSON document.
        local_path: Path of a local copy to use when the download fails.
        timeout: Seconds to wait on the remote server before giving up.

    Returns:
        The decoded JSON document.

    Raises:
        FileNotFoundError: If the download fails and ``local_path`` does not
            exist. The download error is attached as ``__cause__``.
    """
    import http.client
    import urllib.request

    try:
        with urllib.request.urlopen(url, timeout=timeout) as response:
            return json.loads(response.read().decode("utf-8"))
    except (OSError, ValueError, http.client.HTTPException) as exc:
        # OSError covers URLError/timeouts, ValueError covers malformed URLs
        # and undecodable/invalid JSON payloads. Keep the error for reporting
        # instead of discarding it.
        remote_error = exc

    if os.path.exists(local_path):
        with open(local_path, encoding="utf-8") as f:
            return json.load(f)

    raise FileNotFoundError(
        f"Could not load {local_path} "
        f"(download from {url} failed: {remote_error!r})"
    ) from remote_error

def load_mini():
    """Return the mini demo data, preferring the remote copy over the local file."""
    local_fallback = "mini_demo_data.json"
    return _load_json(GITHUB_MINI_DATA_URL, local_fallback)

def load_full():
    """Return the full demo data, preferring the remote copy over the local file."""
    local_fallback = "full_demo_data.json"
    return _load_json(GITHUB_FULL_DATA_URL, local_fallback)

## Part 1 — Quick Demo (Mini Data)

Load a curated subset of 3 dataset×threshold combinations (diabetes SG-FIGS-10, heart_statlog SG-FIGS-25, breast_cancer SG-FIGS-50) to demonstrate the audit pipeline quickly.

In [None]:
data = load_mini()

# Summarise what was loaded: block count, total example count, block names.
dataset_blocks = data["datasets"]
n_examples = sum(len(block["examples"]) for block in dataset_blocks)
dataset_names = [block["dataset"] for block in dataset_blocks]
print(f"Loaded {len(dataset_blocks)} datasets, {n_examples} total examples")
print(f"Datasets: {dataset_names}")

### Metric Extraction and Aggregation

Extract per-split quality metrics and per-fold performance context from each dataset block. The data contains two types of examples:
- **Split quality** examples (with `eval_scr`, `eval_above_median_fraction`, etc.)
- **Performance context** examples (with `eval_accuracy`, `eval_auc`)

In [None]:
# ---------------------------------------------------------------------------
# Metric aggregation helpers (from eval.py)
# ---------------------------------------------------------------------------
def safe_mean(vals: list[float]) -> float:
    """Return the arithmetic mean of ``vals``, or 0.0 when ``vals`` is empty."""
    if not vals:
        return 0.0
    return sum(vals) / len(vals)

def safe_std(vals: list[float]) -> float:
    """Return the sample standard deviation (n - 1 denominator) of ``vals``.

    Inputs with fewer than two values have no sample deviation and yield 0.0.
    """
    n = len(vals)
    if n < 2:
        return 0.0
    center = safe_mean(vals)
    squared_deviation = sum((v - center) ** 2 for v in vals)
    return math.sqrt(squared_deviation / (n - 1))


def extract_and_aggregate(data: dict) -> dict:
    """Collect per-split metrics and per-fold performance, then aggregate.

    Every example's ``input`` field is parsed as JSON. Examples whose parsed
    input has ``type == "performance_context"`` become performance records;
    all other examples are treated as split-quality examples.

    Metric values equal to 0.0 (which is also the default used for a missing
    key) are left out of the aggregate means/stds but are still stored in the
    per-split records.

    Args:
        data: Dict with a ``"datasets"`` list; each entry carries a
            ``"dataset"`` name and an ``"examples"`` list.

    Returns:
        Dict with ``split_records``, ``perf_records`` and ``metrics_agg``.
    """
    # (record key, example key) for the six split-quality metrics, in the
    # order they appear in each split record.
    metric_fields = (
        ("scr", "eval_scr"),
        ("above_median_fraction", "eval_above_median_fraction"),
        ("impurity_per_feature_ratio", "eval_impurity_per_feature_ratio"),
        ("redundancy_ratio", "eval_redundancy_ratio"),
        ("joint_mi_ratio", "eval_joint_mi_ratio"),
        ("synergy_weighted_efficiency", "eval_synergy_weighted_efficiency"),
    )
    pooled: dict[str, list[float]] = {name: [] for name, _ in metric_fields}

    split_records: list[dict] = []
    perf_records: list[dict] = []

    for block in data["datasets"]:
        dataset_name = block["dataset"]

        for example in block["examples"]:
            payload = json.loads(example["input"])

            # Performance-context examples carry fold-level accuracy / AUC.
            if payload.get("type") == "performance_context":
                perf_records.append({
                    "dataset": dataset_name,
                    "method": payload.get("method", example.get("metadata_method", "")),
                    "fold": payload.get("fold", 0),
                    "accuracy": example.get("eval_accuracy", 0.0),
                    "auc": example.get("eval_auc", 0.0),
                })
                continue

            # Split-quality example: read all six metrics, pool the non-zero ones.
            values = {name: example.get(key, 0.0) for name, key in metric_fields}
            for name, value in values.items():
                if value != 0.0:
                    pooled[name].append(value)

            split_records.append({
                "dataset": dataset_name,
                "method": example.get("metadata_method", ""),
                "features": payload.get("features", []),
                "n_features": payload.get("n_features", 0),
                "rule_str": example.get("metadata_rule_str", ""),
                **values,
                "impurity_reduction": example.get("metadata_impurity_reduction", 0.0),
                "best_axis_impurity": example.get("metadata_best_axis_impurity", 0.0),
            })

    def rounded_stats(samples: list[float]) -> tuple[float, float]:
        """Return (mean, sample std) of ``samples``, each rounded to 6 places."""
        return round(safe_mean(samples), 6), round(safe_std(samples), 6)

    scr_mean, scr_std = rounded_stats(pooled["scr"])
    above_mean, above_std = rounded_stats(pooled["above_median_fraction"])
    imp_mean, imp_std = rounded_stats(pooled["impurity_per_feature_ratio"])
    red_mean, red_std = rounded_stats(pooled["redundancy_ratio"])
    jmi_mean, jmi_std = rounded_stats(pooled["joint_mi_ratio"])
    eff_mean, eff_std = rounded_stats(pooled["synergy_weighted_efficiency"])

    # Dataset names look like "<base>_SG-FIGS-<threshold>"; count distinct bases.
    base_names = {r["dataset"].rsplit("_SG-FIGS-", 1)[0] for r in split_records}

    metrics_agg: dict[str, float] = {
        "mean_scr": scr_mean,
        "std_scr": scr_std,
        "n_oblique_splits_evaluated": len(pooled["scr"]),
        "mean_above_median_fraction": above_mean,
        "std_above_median_fraction": above_std,
        "expected_random_above_median": 0.5,
        "mean_impurity_per_feature_ratio": imp_mean,
        "std_impurity_per_feature_ratio": imp_std,
        "mean_redundancy_ratio": red_mean,
        "std_redundancy_ratio": red_std,
        "mean_joint_mi_ratio": jmi_mean,
        "std_joint_mi_ratio": jmi_std,
        "mean_synergy_weighted_efficiency": eff_mean,
        "std_synergy_weighted_efficiency": eff_std,
        "n_datasets_with_oblique_splits": len(base_names),
        "total_examples": sum(len(d["examples"]) for d in data["datasets"]),
    }

    return {
        "split_records": split_records,
        "perf_records": perf_records,
        "metrics_agg": metrics_agg,
    }

In [None]:
# Run extraction on mini data and unpack the three result sections.
results = extract_and_aggregate(data)
split_records, perf_records, metrics_agg = (
    results[section]
    for section in ("split_records", "perf_records", "metrics_agg")
)

print(f"Split quality records: {len(split_records)}")
print(f"Performance records:   {len(perf_records)}")
print(f"\n--- Per-Split Details ---")
# One line per oblique split with its four headline metrics.
for rec in split_records:
    print(
        f"  {rec['dataset']:30s}  SCR={rec['scr']:.3f}  "
        f"AboveMed={rec['above_median_fraction']:.3f}  "
        f"ImpRatio={rec['impurity_per_feature_ratio']:.3f}  "
        f"RedRatio={rec['redundancy_ratio']:.3f}"
    )

### Aggregate Results and Interpretation

Print the aggregate metrics and interpret whether SG-FIGS oblique splits effectively concentrate synergistic feature pairs (from eval.py lines 531–611).

In [None]:
# ---------------------------------------------------------------------------
# Aggregate results and interpretation (from eval.py main())
# ---------------------------------------------------------------------------
banner = "=" * 60

# Dump every aggregate metric as "key: value".
print(banner)
print("AGGREGATE RESULTS")
print(banner)
for k, v in metrics_agg.items():
    print(f"  {k}: {v}")

print()
print(banner)
print("INTERPRETATION SUMMARY")
print(banner)

# SCR: values above 1 are read as the splits concentrating high-synergy pairs.
scr_mean = metrics_agg["mean_scr"]
if scr_mean > 1:
    scr_verdict = '> 1: splits concentrate high-synergy pairs'
else:
    scr_verdict = '<= 1: splits do NOT preferentially select high-synergy pairs'
print(f"SCR = {scr_mean:.3f} ({scr_verdict})")

# Above-median fraction: compared against the 0.5 random baseline.
above_med = metrics_agg["mean_above_median_fraction"]
if above_med > 0.5:
    above_med_verdict = 'better'
else:
    above_med_verdict = 'worse'
print(
    f"Above-median fraction = {above_med:.3f} (random baseline = 0.5, "
    f"{above_med_verdict} than random)"
)

# Impurity per feature: values above 1 favour the oblique split.
imp_ratio = metrics_agg["mean_impurity_per_feature_ratio"]
if imp_ratio > 1:
    imp_verdict = 'oblique splits MORE efficient per feature'
else:
    imp_verdict = 'oblique splits LESS efficient per feature than best axis-aligned'
print(f"Impurity/feature ratio = {imp_ratio:.3f} ({imp_verdict})")

# Redundancy: values above 1 mean more overlap than the random expectation.
red_ratio = metrics_agg["mean_redundancy_ratio"]
if red_ratio > 1:
    red_verdict = 'higher than random: features overlap'
else:
    red_verdict = 'lower than random: features complementary'
print(f"Redundancy ratio = {red_ratio:.3f} ({red_verdict})")

### Visualization

Visualize the six split-quality metrics across all oblique splits and show per-dataset performance context.

In [None]:
def visualize_results(split_records, perf_records, metrics_agg, title_prefix=""):
    """Plot the split-quality audit and print the performance context.

    Shows one two-panel figure and, when fold records are supplied, prints a
    text summary:

    1. Left panel: bar chart of the six aggregate metric means, with their
       standard deviations as error bars and a red dashed reference level on
       every metric that has one.
    2. Right panel: one SCR bar per split record, coloured by dataset, with a
       horizontal line at SCR = 1 (or a placeholder text when there are no
       split records).
    3. Printed summary: mean / std / count of accuracy and AUC per dataset,
       counting only values greater than zero.

    Args:
        split_records: Per-split dicts with at least ``dataset`` and ``scr``.
        perf_records: Per-fold dicts with ``dataset``, ``accuracy``, ``auc``.
        metrics_agg: Aggregate dict holding the ``mean_*`` / ``std_*`` keys
            and ``n_oblique_splits_evaluated``.
        title_prefix: Text prepended to the panel titles.
    """
    # (tick label, mean key, std key, reference level or None)
    # Reference lines: SCR>1 good, AboveMed>0.5 baseline, ImpRatio>1 means better, etc.
    panel_spec = [
        ("SCR", "mean_scr", "std_scr", 1.0),
        ("Above-Med\nFraction", "mean_above_median_fraction",
         "std_above_median_fraction", 0.5),
        ("Impurity/\nFeature Ratio", "mean_impurity_per_feature_ratio",
         "std_impurity_per_feature_ratio", 1.0),
        ("Redundancy\nRatio", "mean_redundancy_ratio",
         "std_redundancy_ratio", 1.0),
        ("Joint MI\nRatio", "mean_joint_mi_ratio",
         "std_joint_mi_ratio", 1.0),
        ("Synergy-Wtd\nEfficiency", "mean_synergy_weighted_efficiency",
         "std_synergy_weighted_efficiency", None),
    ]
    tick_labels = [spec[0] for spec in panel_spec]
    bar_heights = [metrics_agg[spec[1]] for spec in panel_spec]
    bar_errors = [metrics_agg[spec[2]] for spec in panel_spec]

    fig, (ax_agg, ax_scr) = plt.subplots(1, 2, figsize=(14, 5))

    # --- Left panel: aggregate metrics with error bars ---
    positions = np.arange(len(tick_labels))
    ax_agg.bar(positions, bar_heights, yerr=bar_errors, capsize=4,
               color="#4C72B0", alpha=0.8, edgecolor="black")
    for slot, (_, _, _, level) in enumerate(panel_spec):
        if level is None:
            continue
        # Short dashed segment spanning just this bar's width.
        ax_agg.plot([slot - 0.4, slot + 0.4], [level, level], "r--",
                    linewidth=1.5, alpha=0.7)
    ax_agg.set_xticks(positions)
    ax_agg.set_xticklabels(tick_labels, fontsize=9)
    ax_agg.set_ylabel("Value")
    ax_agg.set_title(f"{title_prefix}Aggregate Split Quality Metrics\n"
                     f"(n={metrics_agg['n_oblique_splits_evaluated']} oblique splits, "
                     f"red dashed = reference)")
    ax_agg.grid(axis="y", alpha=0.3)

    # --- Right panel: per-split SCR, one colour per dataset ---
    if not split_records:
        ax_scr.text(0.5, 0.5, "No split records", ha="center", va="center")
    else:
        dataset_names = sorted({r["dataset"] for r in split_records})
        # Sample at least three colours so small inputs still get distinct hues.
        palette = plt.cm.Set2(np.linspace(0, 1, max(len(dataset_names), 3)))
        color_of = dict(zip(dataset_names, palette))

        split_labels = [r["dataset"].replace("_SG-FIGS-", "\nSG-FIGS-")
                        for r in split_records]
        split_scr = [r["scr"] for r in split_records]
        split_fill = [color_of[r["dataset"]] for r in split_records]

        slots = np.arange(len(split_labels))
        ax_scr.bar(slots, split_scr, color=split_fill, alpha=0.8, edgecolor="black")
        ax_scr.axhline(y=1.0, color="red", linestyle="--", linewidth=1.5,
                       alpha=0.7, label="SCR=1 (random)")
        ax_scr.set_xticks(slots)
        ax_scr.set_xticklabels(split_labels, fontsize=8, rotation=0, ha="center")
        ax_scr.set_ylabel("Synergy Concentration Ratio")
        ax_scr.set_title(f"{title_prefix}Per-Split SCR Values")
        ax_scr.legend(fontsize=8)
        ax_scr.grid(axis="y", alpha=0.3)

    plt.tight_layout()
    plt.show()

    # --- Performance context summary ---
    if not perf_records:
        return

    print(f"\n{'='*60}")
    print(f"PERFORMANCE CONTEXT SUMMARY ({len(perf_records)} fold records)")
    print(f"{'='*60}")

    # Bucket the fold records by dataset name.
    grouped = {}
    for rec in perf_records:
        grouped.setdefault(rec["dataset"], []).append(rec)

    for ds_name in sorted(grouped):
        fold_recs = grouped[ds_name]
        acc_vals = [r["accuracy"] for r in fold_recs if r["accuracy"] > 0]
        auc_vals = [r["auc"] for r in fold_recs if r["auc"] > 0]
        print(f"  {ds_name}:")
        if acc_vals:
            print(f"    Accuracy: mean={safe_mean(acc_vals):.4f}, "
                  f"std={safe_std(acc_vals):.4f}, n={len(acc_vals)}")
        if auc_vals:
            print(f"    AUC:      mean={safe_mean(auc_vals):.4f}, "
                  f"std={safe_std(auc_vals):.4f}, n={len(auc_vals)}")


# Call visualization for Part 1: shows the figure and prints the performance
# summary for the mini data; "[Mini] " is prepended to the plot titles.
visualize_results(split_records, perf_records, metrics_agg, title_prefix="[Mini] ")

## Part 2 — Full Run (Original Parameters)

Load all 9 dataset×threshold combinations (3 datasets × 3 thresholds = 9 oblique splits, plus 135 performance fold records, for 144 examples in total) and re-run the same extraction and visualization pipeline.

In [None]:
data = load_full()

# Summarise what was loaded: block count, total example count, block names.
dataset_blocks = data["datasets"]
n_examples = sum(len(block["examples"]) for block in dataset_blocks)
dataset_names = [block["dataset"] for block in dataset_blocks]
print(f"Loaded {len(dataset_blocks)} datasets, {n_examples} total examples")
print(f"Datasets: {dataset_names}")

In [None]:
# Run extraction on full data — all 9 dataset×threshold combinations
results_full = extract_and_aggregate(data)
split_records_full, perf_records_full, metrics_agg_full = (
    results_full[section]
    for section in ("split_records", "perf_records", "metrics_agg")
)

print(f"Split quality records: {len(split_records_full)}")
print(f"Performance records:   {len(perf_records_full)}")
print(f"\n--- Per-Split Details ---")
# One line per oblique split with its four headline metrics.
for rec in split_records_full:
    print(
        f"  {rec['dataset']:30s}  SCR={rec['scr']:.3f}  "
        f"AboveMed={rec['above_median_fraction']:.3f}  "
        f"ImpRatio={rec['impurity_per_feature_ratio']:.3f}  "
        f"RedRatio={rec['redundancy_ratio']:.3f}"
    )

In [None]:
# Full aggregate results and interpretation
banner = "=" * 60

# Dump every aggregate metric as "key: value".
print(banner)
print("AGGREGATE RESULTS (Full)")
print(banner)
for k, v in metrics_agg_full.items():
    print(f"  {k}: {v}")

print()
print(banner)
print("INTERPRETATION SUMMARY (Full)")
print(banner)

# SCR: values above 1 are read as the splits concentrating high-synergy pairs.
scr_mean = metrics_agg_full["mean_scr"]
if scr_mean > 1:
    scr_verdict = '> 1: splits concentrate high-synergy pairs'
else:
    scr_verdict = '<= 1: splits do NOT preferentially select high-synergy pairs'
print(f"SCR = {scr_mean:.3f} ({scr_verdict})")

# Above-median fraction: compared against the 0.5 random baseline.
above_med = metrics_agg_full["mean_above_median_fraction"]
if above_med > 0.5:
    above_med_verdict = 'better'
else:
    above_med_verdict = 'worse'
print(
    f"Above-median fraction = {above_med:.3f} (random baseline = 0.5, "
    f"{above_med_verdict} than random)"
)

# Impurity per feature: values above 1 favour the oblique split.
imp_ratio = metrics_agg_full["mean_impurity_per_feature_ratio"]
if imp_ratio > 1:
    imp_verdict = 'oblique splits MORE efficient per feature'
else:
    imp_verdict = 'oblique splits LESS efficient per feature than best axis-aligned'
print(f"Impurity/feature ratio = {imp_ratio:.3f} ({imp_verdict})")

# Redundancy: values above 1 mean more overlap than the random expectation.
red_ratio = metrics_agg_full["mean_redundancy_ratio"]
if red_ratio > 1:
    red_verdict = 'higher than random: features overlap'
else:
    red_verdict = 'lower than random: features complementary'
print(f"Redundancy ratio = {red_ratio:.3f} ({red_verdict})")

In [None]:
# Call the same reusable visualization function for full results; "[Full] "
# is prepended to the plot titles to distinguish them from the Part 1 figure.
visualize_results(split_records_full, perf_records_full, metrics_agg_full, title_prefix="[Full] ")