# Synergy Threshold Sensitivity & Adaptive Thresholding for SG-FIGS

This notebook demonstrates the **SG-FIGS** (Synergy-Guided Fast Interpretable Greedy-tree Sums) method, which uses feature synergy information to guide oblique splits in interpretable tree models.

**What this experiment does:**
1. Loads pre-computed experiment results (balanced accuracy, AUC, interpretability scores) across multiple datasets
2. Compares **SG-FIGS** against baselines: Standard FIGS (axis-aligned), RO-FIGS (random oblique), and GBDT
3. Evaluates synergy threshold percentiles (50th, 75th, 90th) and their impact on model performance
4. Analyzes correlations between synergy distribution properties and optimal threshold selection
5. Computes aggregate statistics and determines the best universal threshold

In [None]:
%%capture
%pip install -q --force-reinstall numpy==1.26.4 scipy==1.15.3 scikit-learn==1.7.2 networkx==3.4.2 matplotlib==3.10.7

In [None]:
import warnings
warnings.filterwarnings("ignore")

import json
import time
import numpy as np
import networkx as nx
from collections import Counter
from scipy import stats
from sklearn.preprocessing import StandardScaler, KBinsDiscretizer
from sklearn.linear_model import RidgeClassifier
from sklearn.metrics import balanced_accuracy_score, roc_auc_score, mutual_info_score
from sklearn.ensemble import GradientBoostingClassifier
import matplotlib.pyplot as plt

In [None]:
import http.client
import json
import os
import urllib.request

GITHUB_DATA_URL = "https://raw.githubusercontent.com/AMGrobelnik/ai-invention-fb8249-synergy-guided-oblique-splits-using-part/main/experiment_iter3_sg_figs_thresh/demo/mini_demo_data.json"


def load_data(url=None, local_path="mini_demo_data.json", timeout=30.0):
    """Load the experiment JSON, trying the remote copy before a local file.

    Args:
        url: Location fetched first with ``urllib.request.urlopen``.
            ``None`` (the default) means the current value of
            ``GITHUB_DATA_URL``.
        local_path: File that is read when the download fails.
        timeout: Seconds to wait on the remote request, so an unresponsive
            host cannot block the notebook indefinitely.

    Returns:
        The decoded JSON document.

    Raises:
        FileNotFoundError: If the download fails and ``local_path`` does not
            exist. The download error is attached as ``__cause__`` so the
            reason for the remote failure is not lost.
    """
    if url is None:
        # Resolved at call time so reassigning GITHUB_DATA_URL still works.
        url = GITHUB_DATA_URL
    try:
        with urllib.request.urlopen(url, timeout=timeout) as response:
            return json.loads(response.read().decode())
    except (OSError, ValueError, http.client.HTTPException) as exc:
        # Network/HTTP failures (OSError covers URLError and timeouts),
        # malformed URLs and undecodable payloads (ValueError) all trigger
        # the local fallback; programming errors are no longer hidden.
        remote_error = exc
    if os.path.exists(local_path):
        # JSON is UTF-8; do not depend on the platform's locale encoding.
        with open(local_path, encoding="utf-8") as f:
            return json.load(f)
    raise FileNotFoundError(f"Could not load {local_path}") from remote_error

In [None]:
# Fetch the experiment results, then list how many examples each dataset holds.
data = load_data()
print(f"Loaded {len(data['datasets'])} datasets")
# Lazy generator: each line is formatted right before it is printed.
for line in (f"  {ds['dataset']}: {len(ds['examples'])} examples" for ds in data["datasets"]):
    print(line)

## Configuration

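Tunable parameters for the experiment. The threshold percentiles control how aggressively the synergy graph is pruned, the max-splits values control tree complexity, and `N_FOLDS` sets the number of cross-validation folds. Because the results are pre-computed, these values select which stored results are analyzed rather than re-running the experiment.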

In [None]:
# --- Tunable parameters ---
# The results in mini_demo_data.json are pre-computed, so these settings
# select which stored results are analyzed; they do not re-run any training.

# Synergy-threshold percentiles to compare. The analysis cells match these
# against each example's "metadata_threshold_percentile".
THRESHOLD_PERCENTILES = [50, 75, 90]
# NOTE(review): the remaining constants are not read by the cells visible in
# this notebook -- presumably they record the original experiment grid;
# confirm before relying on or removing them.
MAX_SPLITS_VALUES = [5, 10, 15]  # Original values
N_FOLDS = 5  # Original value
RANDOM_SEED = 42
ADAPTIVE_CANDIDATES = [50, 60, 70, 80, 90]  # Original values

## Parse Datasets from Loaded Data

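Decode each example's JSON-encoded `input` (experiment configuration) and `output` (result metrics), keep its fold, threshold-percentile and max-splits metadata, and record the number of features per dataset.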

In [None]:
def parse_datasets(data):
    """Decode the per-example JSON payloads of every dataset in ``data``.

    Args:
        data: Mapping with a ``"datasets"`` list. Each entry provides a
            ``"dataset"`` name and an ``"examples"`` list; every example
            carries JSON-encoded ``"input"`` and ``"output"`` strings plus
            ``"metadata_fold"``, ``"metadata_threshold_percentile"`` and
            ``"metadata_max_splits"``.

    Returns:
        Dict keyed by dataset name. Each value holds ``"examples"`` (the
        examples with ``input``/``output`` decoded into Python objects and the
        three metadata fields copied over) and ``"n_features"`` (read from the
        first example's decoded input).

    Raises:
        ValueError: If a dataset has no examples, since ``n_features`` cannot
            be determined for it.
    """
    metadata_keys = (
        "metadata_fold",
        "metadata_threshold_percentile",
        "metadata_max_splits",
    )
    datasets = {}
    for ds_entry in data["datasets"]:
        name = ds_entry["dataset"]
        examples = ds_entry["examples"]
        if not examples:
            # Fail with the dataset name instead of an opaque IndexError on
            # parsed_examples[0] below.
            raise ValueError(
                f"Dataset {name!r} has no examples; cannot determine n_features"
            )
        parsed_examples = []
        for ex in examples:
            # "input" holds the experiment config, "output" the result metrics.
            parsed = {
                "input": json.loads(ex["input"]),
                "output": json.loads(ex["output"]),
            }
            for key in metadata_keys:
                parsed[key] = ex[key]
            parsed_examples.append(parsed)
        datasets[name] = {
            "examples": parsed_examples,
            "n_features": parsed_examples[0]["input"]["n_features"],
        }
    return datasets

# Decode every dataset once, then print a one-line summary per dataset.
all_datasets = parse_datasets(data)
print(f"Parsed {len(all_datasets)} datasets:")
for name in all_datasets:
    info = all_datasets[name]
    print(f"  {name}: {len(info['examples'])} examples, {info['n_features']} features")

## Per-Dataset Analysis

Aggregate results by threshold percentile for each dataset. Compute mean balanced accuracy and AUC across folds/splits for each threshold setting, identify optimal thresholds, and measure improvement over the default 75th percentile.

In [None]:
start_time = time.time()


def _rounded(value):
    """Return ``value`` as a plain float rounded to six decimals."""
    return round(float(value), 6)


dataset_names = sorted(all_datasets.keys())
per_dataset_analysis = {}

for ds_name in dataset_names:
    dataset = all_datasets[ds_name]

    # One summary per threshold percentile, keyed by the percentile as a string.
    threshold_results = {}
    for pct in THRESHOLD_PERCENTILES:
        outputs = [
            ex["output"]
            for ex in dataset["examples"]
            if ex["metadata_threshold_percentile"] == pct
        ]
        if outputs:
            sg_accs = [out["sg_figs_balanced_acc"] for out in outputs]
            sg_aucs = [out["sg_figs_auc"] for out in outputs]
            figs_accs = [out["figs_balanced_acc"] for out in outputs]
            rofigs_accs = [out["rofigs_balanced_acc"] for out in outputs]
            gbdt_accs = [out["gbdt_balanced_acc"] for out in outputs]
            summary = {
                "mean_balanced_acc": _rounded(np.mean(sg_accs)),
                "std_balanced_acc": _rounded(np.std(sg_accs)),
                "mean_auc": _rounded(np.mean(sg_aucs)),
                "mean_figs_acc": _rounded(np.mean(figs_accs)),
                "mean_rofigs_acc": _rounded(np.mean(rofigs_accs)),
                "mean_gbdt_acc": _rounded(np.mean(gbdt_accs)),
            }
        else:
            # No results stored for this percentile: chance-level placeholder
            # without the baseline keys.
            summary = {"mean_balanced_acc": 0.5, "std_balanced_acc": 0.0, "mean_auc": 0.5}
        threshold_results[str(pct)] = summary

    # Percentile with the highest mean SG-FIGS balanced accuracy; on ties
    # max() keeps the first one listed in THRESHOLD_PERCENTILES.
    best_pct = max(
        THRESHOLD_PERCENTILES,
        key=lambda p: threshold_results[str(p)]["mean_balanced_acc"],
    )
    best_acc = threshold_results[str(best_pct)]["mean_balanced_acc"]

    # Accuracy at the fixed 75th percentile, the reference for "improvement".
    reference = threshold_results.get("75", {})
    fixed_75 = reference.get("mean_balanced_acc", 0.5)

    per_dataset_analysis[ds_name] = {
        "threshold_results": threshold_results,
        "optimal_threshold": best_pct,
        "optimal_acc": best_acc,
        "fixed_75_acc": fixed_75,
        "improvement": round(best_acc - fixed_75, 6),
        "n_features": dataset["n_features"],
    }

print(f"Analysis completed in {time.time() - start_time:.2f}s")
print(f"\nPer-dataset optimal thresholds:")
# Insertion order of per_dataset_analysis equals the sorted dataset_names.
for ds_name, a in per_dataset_analysis.items():
    print(f"  {ds_name:45s} | optimal: {a['optimal_threshold']:3d}th pct | "
          f"acc: {a['optimal_acc']:.4f} | improvement: {a['improvement']:+.4f}")

## Aggregate Statistics & Correlation Analysis

Compute overall statistics across all datasets: the best universal threshold, the mean differences between SG-FIGS and each baseline (at the 75th-percentile reference), and Spearman correlations of the number of features with the optimal threshold and with the improvement from tuning.

In [None]:
# --- Aggregate statistics ---
def _mean_or_zero(diffs):
    """Mean of ``diffs`` rounded to six decimals, or 0.0 when it is empty."""
    if diffs:
        return round(float(np.mean(diffs)), 6)
    return 0.0


improvements = [per_dataset_analysis[d]["improvement"] for d in dataset_names]
# Only gains above 0.001 count as a dataset having improved.
n_improved = len([imp for imp in improvements if imp > 0.001])

# Mean SG-FIGS accuracy over all datasets for each threshold; the threshold
# with the highest mean is the best universal choice.
universal_accs = {
    pct: float(np.mean([
        per_dataset_analysis[d]["threshold_results"][str(pct)]["mean_balanced_acc"]
        for d in dataset_names
    ]))
    for pct in THRESHOLD_PERCENTILES
}
best_universal = max(universal_accs, key=lambda candidate: universal_accs[candidate])

# SG-FIGS vs baselines at the 75th-percentile reference. Datasets whose
# reference summary lacks the baseline means are left out.
reference_rows = [
    per_dataset_analysis[name]["threshold_results"].get("75", {})
    for name in dataset_names
]
reference_rows = [row for row in reference_rows if "mean_figs_acc" in row]
sg_vs_figs = [row["mean_balanced_acc"] - row["mean_figs_acc"] for row in reference_rows]
sg_vs_rofigs = [row["mean_balanced_acc"] - row["mean_rofigs_acc"] for row in reference_rows]
sg_vs_gbdt = [row["mean_balanced_acc"] - row["mean_gbdt_acc"] for row in reference_rows]

aggregate = {
    "mean_improvement_from_tuning": round(float(np.mean(improvements)), 6),
    "n_datasets_improved": n_improved,
    "best_universal_threshold": int(best_universal),
    "sg_figs_vs_figs_mean_diff": _mean_or_zero(sg_vs_figs),
    "sg_figs_vs_rofigs_mean_diff": _mean_or_zero(sg_vs_rofigs),
    "sg_figs_vs_gbdt_mean_diff": _mean_or_zero(sg_vs_gbdt),
}

# Inputs for the correlation analysis
opt_pcts = []
n_feats = []
for d in dataset_names:
    opt_pcts.append(per_dataset_analysis[d]["optimal_threshold"])
    n_feats.append(per_dataset_analysis[d]["n_features"])

def safe_spearman(a, b):
    """Spearman correlation of ``a`` and ``b`` that never raises.

    Returns a dict with ``"rho"`` and ``"p_value"`` rounded to four decimals.
    A NaN statistic is reported as ``rho=0.0`` and a NaN p-value as
    ``p_value=1.0``; if the computation raises, both defaults are returned.
    """
    try:
        rho, pval = stats.spearmanr(a, b)
        if np.isnan(rho):
            rho_out = 0.0
        else:
            rho_out = round(float(rho), 4)
        if np.isnan(pval):
            p_out = 1.0
        else:
            p_out = round(float(pval), 4)
        return {"rho": rho_out, "p_value": p_out}
    except Exception:
        return {"rho": 0.0, "p_value": 1.0}

# Correlate the number of features with each tuning outcome.
correlation_analysis = {}
for label, series in (
    ("optimal_percentile_vs_n_features", opt_pcts),
    ("improvement_vs_n_features", improvements),
):
    correlation_analysis[label] = safe_spearman(series, n_feats)

banner = "=" * 60
print(banner)
print("AGGREGATE RESULTS")
print(banner)
print(f"  Best universal threshold: {aggregate['best_universal_threshold']}th percentile")
print(f"  Mean improvement from tuning: {aggregate['mean_improvement_from_tuning']:.4f}")
print(f"  Datasets improved: {aggregate['n_datasets_improved']}/{len(dataset_names)}")
print(f"  SG-FIGS vs FIGS: {aggregate['sg_figs_vs_figs_mean_diff']:+.4f}")
print(f"  SG-FIGS vs RO-FIGS: {aggregate['sg_figs_vs_rofigs_mean_diff']:+.4f}")
print(f"  SG-FIGS vs GBDT: {aggregate['sg_figs_vs_gbdt_mean_diff']:+.4f}")

print(f"\nCorrelation analysis:")
for key, val in correlation_analysis.items():
    print(f"  {key}: rho={val['rho']:.4f}, p={val['p_value']:.4f}")

print(f"\nUniversal threshold accuracies:")
for pct, acc in sorted(universal_accs.items()):
    # Flag the percentile chosen as the best universal threshold.
    if pct == best_universal:
        marker = " <-- best"
    else:
        marker = ""
    print(f"  {pct}th percentile: {acc:.4f}{marker}")

## Visualization

- **Left**: SG-FIGS balanced accuracy across synergy threshold percentiles for each dataset.
- **Right**: Method comparison — SG-FIGS vs. the FIGS, RO-FIGS, and GBDT baselines at the 75th-percentile threshold.

In [None]:
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

# --- Plot 1: Threshold sensitivity per dataset ---
colors = plt.cm.Set2(np.linspace(0, 1, len(dataset_names)))
for ds_name, line_color in zip(dataset_names, colors):
    per_threshold = per_dataset_analysis[ds_name]["threshold_results"]
    accs = [per_threshold[str(p)]["mean_balanced_acc"] for p in THRESHOLD_PERCENTILES]
    # Truncate long dataset names so the legend stays compact.
    if len(ds_name) > 15:
        short_name = ds_name[:15] + "..."
    else:
        short_name = ds_name
    ax1.plot(THRESHOLD_PERCENTILES, accs, 'o-', color=line_color, label=short_name, linewidth=1.5, markersize=5)
ax1.set(
    xlabel="Synergy Threshold Percentile",
    ylabel="Mean Balanced Accuracy",
    title="Threshold Sensitivity by Dataset",
)
ax1.legend(fontsize=7, loc="best")
ax1.set_xticks(THRESHOLD_PERCENTILES)
ax1.grid(True, alpha=0.3)

# --- Plot 2: Method comparison at 75th percentile ---
methods = ["SG-FIGS", "FIGS", "RO-FIGS", "GBDT"]
method_keys = ["mean_balanced_acc", "mean_figs_acc", "mean_rofigs_acc", "mean_gbdt_acc"]
x = np.arange(len(dataset_names))
width = 0.18

# 75th-percentile summary of every dataset; a missing metric is drawn at 0.5.
rows_75 = [
    per_dataset_analysis[ds_name]["threshold_results"].get("75", {})
    for ds_name in dataset_names
]
for slot, (method, key) in enumerate(zip(methods, method_keys)):
    vals = [row.get(key, 0.5) for row in rows_75]
    # Each method gets its own slot inside the dataset's group of bars.
    ax2.bar(x + slot * width, vals, width, label=method, alpha=0.85)

short_names = []
for n in dataset_names:
    short_names.append(n[:8] + ".." if len(n) > 10 else n)
# Centre the tick under the four bars of each group.
ax2.set_xticks(x + width * 1.5)
ax2.set_xticklabels(short_names, rotation=45, ha="right", fontsize=7)
ax2.set(ylabel="Balanced Accuracy", title="Method Comparison (75th Percentile)")
ax2.legend(fontsize=8)
ax2.grid(True, alpha=0.3, axis="y")

plt.tight_layout()
plt.show()

print(f"\nTotal notebook runtime: {time.time() - start_time:.2f}s")