# SG-FIGS: Fold-Aware Synergy-Guided Oblique Splits with Threshold Ablation

This notebook demonstrates the SG-FIGS experiment, which compares 6 methods across tabular classification benchmarks:

1. **FIGS** — Standard axis-aligned splits
2. **RO-FIGS** — Random oblique splits (no synergy guidance)
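3. **SG-FIGS-10** — Synergy threshold p90 (aggressive): only feature pairs at or above the 90th percentile of synergy (roughly the top 10%) become graph edges
4. **SG-FIGS-25** — Synergy threshold p75 (moderate): roughly the top 25% of feature pairs become edges
5. **SG-FIGS-50** — Synergy threshold p50 (permissive): roughly the top 50% of feature pairs become edges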
6. **GradientBoosting** — Non-interpretable baseline

The experiment implements fold-aware PID synergy graphs, oblique splits via Ridge projection, non-circular interpretability scoring, and statistical tests (Wilcoxon, Friedman/Nemenyi).

- **Part 1 (Quick Demo)**: Runs on 3 datasets with 2-fold CV for fast iteration
- **Part 2 (Full Run)**: Runs on all 12 datasets with 5-fold CV (original parameters)

In [None]:
# ── NUMPY MONKEY-PATCH (must be before dit import) ──────────────────────────
# Re-create the legacy NumPy aliases `alltrue`, `cumproduct` and `product`
# (pointing at `all`, `cumprod`, `prod`) only when the installed NumPy lacks
# them. The header requires this to run before `import dit`; presumably dit
# still references the old names — TODO confirm against the installed dit.
import numpy as np
if not hasattr(np, 'alltrue'):
    np.alltrue = np.all
if not hasattr(np, 'cumproduct'):
    np.cumproduct = np.cumprod
if not hasattr(np, 'product'):
    np.product = np.prod

# ── IMPORTS ──────────────────────────────────────────────────────────────────
import json
import sys
import time
import warnings
from pathlib import Path

import dit
from dit.pid import PID_WB
from imodels import FIGSClassifier
from imodels.tree.figs import Node
from sklearn.linear_model import Ridge
from sklearn.tree import DecisionTreeRegressor
from sklearn.preprocessing import (
    KBinsDiscretizer, StandardScaler, OrdinalEncoder, LabelEncoder
)
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import accuracy_score, roc_auc_score
from sklearn.datasets import load_breast_cancer, load_wine, fetch_openml
from sklearn.ensemble import GradientBoostingClassifier
from scipy.stats import wilcoxon, friedmanchisquare, rankdata, ttest_rel
from scipy.stats import studentized_range

import matplotlib.pyplot as plt
import matplotlib
# Render notebook figures at 100 DPI.
matplotlib.rcParams['figure.dpi'] = 100

# Hide UserWarning / FutureWarning output for a cleaner notebook.
# NOTE(review): this also hides genuine deprecation signals from libraries.
warnings.filterwarnings('ignore', category=UserWarning)
warnings.filterwarnings('ignore', category=FutureWarning)

In [None]:
# ── DATA LOADING HELPERS ─────────────────────────────────────────────────────
# Remote (GitHub raw) copies of the demo data bundles. The loaders below try
# these URLs first and fall back to a local file of the same base name.
GITHUB_FULL_DATA_URL = "https://raw.githubusercontent.com/AMGrobelnik/ai-invention-ac2586-synergy-guided-oblique-splits-using-part/main/sg_figs_eval/demo/full_demo_data.json"
GITHUB_MINI_DATA_URL = "https://raw.githubusercontent.com/AMGrobelnik/ai-invention-ac2586-synergy-guided-oblique-splits-using-part/main/sg_figs_eval/demo/mini_demo_data.json"
import json, os

def _load_json(url, local_path):
    try:
        import urllib.request
        with urllib.request.urlopen(url) as response:
            return json.loads(response.read().decode())
    except Exception: pass
    if os.path.exists(local_path):
        with open(local_path) as f: return json.load(f)
    raise FileNotFoundError(f"Could not load {local_path}")

def load_mini():
    """Return the parsed mini demo bundle (GitHub copy, else the local file)."""
    source, fallback = GITHUB_MINI_DATA_URL, "mini_demo_data.json"
    return _load_json(source, fallback)


def load_full():
    """Return the parsed full demo bundle (GitHub copy, else the local file)."""
    source, fallback = GITHUB_FULL_DATA_URL, "full_demo_data.json"
    return _load_json(source, fallback)

---
## Part 1 — Quick Demo (Mini Data)

Runs on 3 datasets (banknote, breast_cancer, wine) with **2-fold CV** and a single `max_rules=5` setting for fast execution.

In [None]:
# Load the mini demo bundle (remote first, local file as fallback) and
# print how many examples each dataset entry contains.
data = load_mini()
print(f"Loaded {len(data['datasets'])} datasets")
# Each entry is read for a 'dataset' name and an 'examples' list — schema
# inferred from the keys used here; confirm against the JSON bundle.
for ds in data['datasets']:
    print(f"  {ds['dataset']}: {len(ds['examples'])} examples")

### Constants and Dataset Configuration

Defines the 12 benchmark datasets and experiment parameters. For Part 1, we reduce to 3 datasets, 2-fold CV, a single `max_rules` value, and a shorter per-dataset timeout.

In [None]:
# ── CONSTANTS ────────────────────────────────────────────────────────────────
# Benchmark registry: dataset name -> how `load_dataset` obtains it.
#   'source'     : 'sklearn' (built-in loader) or 'openml' (fetched by id)
#   'openml_id'  : OpenML data_id, only for source == 'openml'
#   'domain'     : descriptive tag; `load_dataset` does not read it
DATASETS = {
    'breast_cancer': {'source': 'sklearn', 'domain': 'medical'},
    'wine': {'source': 'sklearn', 'domain': 'chemistry'},
    'diabetes': {'source': 'openml', 'openml_id': 37, 'domain': 'medical'},
    'heart_statlog': {'source': 'openml', 'openml_id': 53, 'domain': 'medical'},
    'ionosphere': {'source': 'openml', 'openml_id': 59, 'domain': 'physics'},
    'sonar': {'source': 'openml', 'openml_id': 40, 'domain': 'signal'},
    'vehicle': {'source': 'openml', 'openml_id': 54, 'domain': 'vision'},
    'segment': {'source': 'openml', 'openml_id': 36, 'domain': 'vision'},
    'glass': {'source': 'openml', 'openml_id': 41, 'domain': 'forensics'},
    'banknote': {'source': 'openml', 'openml_id': 1462, 'domain': 'image'},
    'credit_g': {'source': 'openml', 'openml_id': 31, 'domain': 'finance'},
    'australian': {'source': 'openml', 'openml_id': 40981, 'domain': 'finance'},
}

# SG-FIGS variant -> synergy percentile, presumably passed as
# `threshold_percentile` to `build_synergy_graph`. That function keeps pairs at
# or above the percentile, so 90 keeps roughly the top 10% of feature pairs.
THRESHOLDS = {'SG-FIGS-10': 90, 'SG-FIGS-25': 75, 'SG-FIGS-50': 50}

# ── QUICK DEMO PARAMS (reduced for Part 1) ──────────────────────────────────
DATASET_ORDER = ['banknote', 'wine', 'breast_cancer']  # Original: all 12
MAX_RULES_VALUES = [5]  # Original: [5, 10, 15]
N_FOLDS = 2  # Original: 5
PER_DATASET_TIMEOUT = 120  # Original: 600 (presumably seconds — confirm in run_experiment)

### Phase 1 — Dataset Loading

Loads datasets from sklearn or OpenML, handles categorical encoding, and standardizes features.

In [None]:
def _encode_categorical_column(column_df) -> np.ndarray:
    """Ordinal-encode one categorical column, given as a one-column frame.

    Missing entries are mapped to -1, so no NaN reaches the scaler or models.
    """
    oe = OrdinalEncoder(
        handle_unknown='use_encoded_value',
        unknown_value=-1,
    )
    codes = oe.fit_transform(column_df).ravel().astype(float)
    # OrdinalEncoder may pass missing values through as NaN (its default
    # `encoded_missing_value`); give them their own code instead.
    return np.where(np.isnan(codes), -1.0, codes)


def _impute_numeric_column(column) -> np.ndarray:
    """Return a numeric column as float64, with missing values set to the column mean.

    The previous fill of 0.0 in raw units could land far outside a feature's
    observed range. The column mean keeps imputed entries at the feature's
    centre, so they sit near 0 after standardisation. A column with no
    observed values is filled with 0.0.
    """
    vals = column.to_numpy(dtype=float, na_value=np.nan)
    missing = np.isnan(vals)
    if missing.all():
        return np.zeros_like(vals)
    if missing.any():
        vals = np.where(missing, np.nanmean(vals), vals)
    return vals


def load_dataset(name: str, info: dict) -> tuple:
    """Load and preprocess one benchmark dataset.

    sklearn datasets are loaded with their built-in loader. OpenML datasets
    are fetched by ``info['openml_id']``. For those, the target is
    label-encoded, categorical columns are ordinal-encoded (missing -> -1)
    and numeric columns are cast to float (missing -> column mean). All
    features are then standardised with a StandardScaler fitted on the whole
    dataset.

    Args:
        name: Dataset key. For ``source == 'sklearn'`` it selects the loader
            (``'breast_cancer'`` or ``'wine'``).
        info: Config dict with ``'source'`` (``'sklearn'`` or ``'openml'``)
            and, for OpenML, ``'openml_id'``.

    Returns:
        ``(X, y, feat_names)``: float64 feature matrix, int label vector and
        the list of feature names.

    Raises:
        ValueError: if ``source == 'sklearn'`` but ``name`` has no loader.
            Previously any unknown name silently loaded ``wine``.
    """
    if info['source'] == 'sklearn':
        sklearn_loaders = {
            'breast_cancer': load_breast_cancer,
            'wine': load_wine,
        }
        if name not in sklearn_loaders:
            raise ValueError(f"Unsupported sklearn dataset: {name!r}")
        data = sklearn_loaders[name]()
        X, y, feat_names = data.data, data.target, list(data.feature_names)
    else:
        d = fetch_openml(data_id=info['openml_id'], as_frame=True, parser='auto')
        df = d.frame
        target_col = d.target.name
        feat_names = [c for c in df.columns if c != target_col]
        X_df = df[feat_names]
        y = LabelEncoder().fit_transform(df[target_col])

        cat_cols = set(
            X_df.select_dtypes(include=['category', 'object']).columns
        )
        X = np.zeros((len(X_df), len(feat_names)), dtype=float)
        for i, col in enumerate(feat_names):
            if col in cat_cols:
                X[:, i] = _encode_categorical_column(X_df[[col]])
            else:
                X[:, i] = _impute_numeric_column(X_df[col])

    X = StandardScaler().fit_transform(X)
    y = y.astype(int)
    return X, y, feat_names

### Phase 2 — Synergy Graph Construction (Fold-Aware)

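Computes pairwise PID synergy between features using only the training fold: an 80% "build" subset of it, while the remaining 20% is held out for the non-circular interpretability score. It then builds an adjacency graph by keeping feature pairs whose synergy is at or above a chosen percentile.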

In [None]:
def _compute_single_synergy(
    x1: np.ndarray,
    x2: np.ndarray,
    y: np.ndarray,
) -> float:
    """Return the Williams-Beer PID synergy of (x1, x2) about y.

    The three integer-coded arrays are turned into an empirical joint
    distribution, which `dit` decomposes with ``PID_WB``. The synergistic
    atom is the value at the top of the PID lattice. Returns 0.0 when ``y``
    is empty.
    """
    # Empirical counts of each (x1, x2, y) outcome; the loop runs over len(y).
    triples = [(int(x1[k]), int(x2[k]), int(y[k])) for k in range(len(y))]
    counts: dict = {}
    for outcome in triples:
        counts[outcome] = counts.get(outcome, 0) + 1

    n_obs = sum(counts.values())
    if n_obs == 0:
        return 0.0

    support = list(counts)
    dist = dit.Distribution(support, [counts[o] / n_obs for o in support])
    # Sources are variables 0 and 1; the target is variable 2.
    decomposition = PID_WB(dist, [[0], [1]], [2])
    return float(decomposition.get_pi(decomposition._lattice.top))


def compute_pairwise_synergy_matrix(
    X_train: np.ndarray,
    y_train: np.ndarray,
    n_bins: int = 5,
    max_features: int = 30,
) -> tuple:
    """Build a symmetric matrix of pairwise PID synergies from training data only.

    When there are more than ``max_features`` columns, only the
    ``max_features`` columns with the highest mutual information with
    ``y_train`` get scored; every other entry stays 0. Features are binned
    into ``n_bins`` quantile bins, with uniform bins used if quantile binning
    raises ValueError. A pair whose synergy computation raises is recorded
    as 0.0.

    Returns:
        ``(synergy_matrix, elapsed_seconds)`` where ``synergy_matrix`` has
        shape ``(p, p)`` and a zero diagonal.
    """
    start = time.time()
    n_feat = X_train.shape[1]

    # Restrict the O(p^2) PID loop on wide datasets.
    if n_feat > max_features:
        from sklearn.feature_selection import mutual_info_classif
        relevance = mutual_info_classif(
            X_train, y_train, random_state=42, n_neighbors=5,
        )
        candidates = np.sort(np.argsort(relevance)[-max_features:])
    else:
        candidates = np.arange(n_feat)

    # Bin edges come from training data only (no test leakage).
    discretizer = KBinsDiscretizer(n_bins=n_bins, encode='ordinal', strategy='quantile')
    try:
        binned = discretizer.fit_transform(X_train).astype(int)
    except ValueError:
        discretizer = KBinsDiscretizer(n_bins=n_bins, encode='ordinal', strategy='uniform')
        binned = discretizer.fit_transform(X_train).astype(int)
    labels = y_train.astype(int)

    synergy = np.zeros((n_feat, n_feat))
    for pos, a in enumerate(candidates):
        for b in candidates[pos + 1:]:
            try:
                value = _compute_single_synergy(binned[:, a], binned[:, b], labels)
            except Exception:
                value = 0.0
            synergy[a, b] = synergy[b, a] = value

    return synergy, time.time() - start


def compute_fast_synergy_proxy(
    X_train: np.ndarray,
    y_train: np.ndarray,
    n_bins: int = 5,
) -> np.ndarray:
    """Fast MI-based synergy proxy for validation set.

    For each pair, compute the interaction gain
    ``I(X_i, X_j; Y) - I(X_i; Y) - I(X_j; Y)`` on binned features and clip
    it at zero. The joint variable is encoded as ``X_i * (n_bins + 1) + X_j``.

    Returns:
        Symmetric ``(p, p)`` matrix with a zero diagonal.
    """
    from sklearn.metrics import mutual_info_score
    n_feat = X_train.shape[1]

    discretizer = KBinsDiscretizer(n_bins=n_bins, encode='ordinal', strategy='quantile')
    try:
        binned = discretizer.fit_transform(X_train).astype(int)
    except ValueError:
        discretizer = KBinsDiscretizer(n_bins=n_bins, encode='ordinal', strategy='uniform')
        binned = discretizer.fit_transform(X_train).astype(int)
    labels = y_train.astype(int)

    # Per-feature relevance I(X_k; Y).
    marginal = np.zeros(n_feat)
    for k in range(n_feat):
        marginal[k] = mutual_info_score(binned[:, k], labels)

    base = n_bins + 1
    proxy = np.zeros((n_feat, n_feat))
    for a in range(n_feat):
        for b in range(a + 1, n_feat):
            codes = (binned[:, a] * base + binned[:, b]).astype(int)
            gain = mutual_info_score(codes, labels) - marginal[a] - marginal[b]
            proxy[a, b] = proxy[b, a] = max(gain, 0.0)

    return proxy


def build_synergy_graph(
    synergy_matrix: np.ndarray,
    threshold_percentile: float,
) -> tuple:
    """Turn a pairwise synergy matrix into an undirected feature graph.

    Only the upper triangle (i < j) is read. A pair becomes an edge when its
    synergy is at or above the ``threshold_percentile``-th percentile of all
    upper-triangle values AND strictly above 1e-6.

    Returns:
        ``(adjacency, n_edges, cutoff)``. ``adjacency`` maps every feature
        index to the set of its neighbours. ``n_edges`` counts undirected
        edges. ``cutoff`` is the percentile value used; it is ``0.0`` when
        there are no pairs or every pair is exactly zero.
    """
    n_features = synergy_matrix.shape[0]
    graph = {node: set() for node in range(n_features)}
    rows, cols = np.triu_indices(n_features, k=1)
    pair_scores = synergy_matrix[rows, cols]

    if pair_scores.size == 0 or np.all(pair_scores == 0):
        return graph, 0, 0.0

    cutoff = float(np.percentile(pair_scores, threshold_percentile))
    min_synergy = 1e-6  # ignore negligible synergy even when it clears the cutoff
    keep = (pair_scores >= cutoff) & (pair_scores > min_synergy)

    # Row-major pair order matches a nested i<j loop.
    for a, b in zip(rows[keep].tolist(), cols[keep].tolist()):
        graph[a].add(b)
        graph[b].add(a)

    return graph, int(np.count_nonzero(keep)), cutoff

### Phase 3 — SG-FIGS and RO-FIGS Implementation

Custom FIGS classifiers with synergy-guided oblique splits (SG-FIGS) and random oblique splits (RO-FIGS). Both extend `FIGSClassifier` from the `imodels` library, adding oblique split capability via Ridge regression projections.

In [None]:
class _ObliqueSplitMixin:
    """Oblique-split search and prediction shared by SG-FIGS and RO-FIGS.

    An oblique split fits a Ridge regression on a feature subset to get a
    projection direction. It then picks a threshold on the 1-D projection
    with a depth-1 ``DecisionTreeRegressor``.

    List this mixin BEFORE ``FIGSClassifier`` in the bases so that its
    ``_predict_tree`` overrides the parent's. Users must set
    ``self.ridge_alpha`` and ``self.oblique_splits_info``.

    NOTE(review): oblique nodes carry extra attributes (``is_oblique``,
    ``oblique_features``, ``oblique_weights``, ``oblique_bias``) unknown to
    ``FIGSClassifier``. The parent's ``fit`` loop may re-evaluate pending
    splits by copying only selected attributes (feature, threshold, ...)
    onto an existing node. In that case these extras would not travel with
    the split, and prediction would treat a projection threshold as an
    axis-aligned one. Verify against the installed imodels version.
    """

    def fit(self, X, y=None, *args, **kwargs):
        """Fit via ``FIGSClassifier.fit`` after clearing records from earlier fits.

        Without the reset, ``oblique_splits_info`` kept growing across
        repeated ``fit`` calls on the same instance.
        """
        self.oblique_splits_info = []
        return super().fit(X, y, *args, **kwargs)

    def _try_oblique_split(
        self,
        X: np.ndarray,
        y: np.ndarray,
        idxs: np.ndarray,
        tree_num: int,
        feat_subset: list,
        sample_weight: np.ndarray = None,
        depth: int = None,
    ):
        """Try an oblique split of the rows selected by ``idxs``.

        Args:
            X: Full training feature matrix.
            y: Targets/residuals aligned with ``X`` (1-D or 2-D). The Ridge
                direction is fitted to the first column when 2-D.
            idxs: Row selector for this node. Child masks are built as
                ``idxs.copy()`` plus assignment through ``idxs``, which
                assumes a boolean mask over all rows of ``X``.
            tree_num: Tree index stored on the created nodes.
            feat_subset: Column indices combined in the projection.
            sample_weight: Optional per-row weights over all rows of ``X``.
            depth: Depth of the node being split.

        Returns:
            A ``Node`` with ``is_oblique=True`` and ``left_temp`` /
            ``right_temp`` children. Returns ``None`` if the target is
            constant, the stump does not split, the impurity reduction is
            not positive, or any step raises. Every successful candidate is
            appended to ``self.oblique_splits_info``, even if the caller
            later prefers another split.
        """
        try:
            X_sub = X[idxs][:, feat_subset]
            y_sub = y[idxs]
            if y_sub.ndim > 1:
                y_fit = y_sub[:, 0]
            else:
                y_fit = y_sub.copy().astype(float)

            # A constant target gives no usable direction.
            if len(np.unique(y_fit)) < 2:
                return None

            ridge = Ridge(alpha=self.ridge_alpha)
            ridge.fit(X_sub, y_fit)
            weights = ridge.coef_
            intercept = float(ridge.intercept_)
            projection = X_sub @ weights + intercept

            node_weights = None if sample_weight is None else sample_weight[idxs]
            dt = DecisionTreeRegressor(max_depth=1)
            dt.fit(projection.reshape(-1, 1), y_sub, sample_weight=node_weights)
            stump = dt.tree_

            # feature == -2 marks a leaf: the stump found no split.
            if stump.feature[0] == -2 or len(stump.feature) < 3:
                return None

            threshold = float(stump.threshold[0])
            impurity = stump.impurity
            goes_left = projection <= threshold

            if node_weights is not None:
                n_left = node_weights[goes_left].sum()
                n_right = node_weights[~goes_left].sum()
                n_total = n_left + n_right
            else:
                n_samples = stump.n_node_samples
                n_left = n_samples[1]
                n_right = n_samples[2]
                n_total = n_samples[0]

            if n_total == 0:
                return None

            # Impurity decrease scaled by node size (total, not per-sample).
            imp_red = (
                impurity[0]
                - impurity[1] * n_left / n_total
                - impurity[2] * n_right / n_total
            ) * n_total

            if imp_red <= 0:
                return None

            idxs_left = idxs.copy()
            idxs_right = idxs.copy()
            idxs_left[idxs] = goes_left
            idxs_right[idxs] = projection > threshold

            node_oblique = Node(
                idxs=idxs,
                value=stump.value[0],
                tree_num=tree_num,
                # Stand-in column; _predict_tree uses oblique_* instead.
                feature=feat_subset[0],
                threshold=threshold,
                impurity=float(impurity[0]),
                impurity_reduction=float(imp_red),
                depth=depth,
            )
            node_oblique.is_oblique = True
            node_oblique.oblique_features = feat_subset
            node_oblique.oblique_weights = weights.tolist()
            node_oblique.oblique_bias = intercept

            child_depth = (depth or 0) + 1
            node_oblique.setattrs(
                left_temp=Node(
                    idxs=idxs_left,
                    value=stump.value[1],
                    tree_num=tree_num,
                    depth=child_depth,
                ),
                right_temp=Node(
                    idxs=idxs_right,
                    value=stump.value[2],
                    tree_num=tree_num,
                    depth=child_depth,
                ),
            )

            self.oblique_splits_info.append({
                'features': feat_subset,
                'weights': weights.tolist(),
                'threshold': float(threshold),
                'impurity_reduction': float(imp_red),
            })

            return node_oblique

        except Exception:
            # Best-effort: any numerical/library failure means "no oblique split".
            return None

    def _predict_tree(self, root: Node, X: np.ndarray) -> np.ndarray:
        """Predict one tree, routing oblique nodes by their projection.

        Returns an array of shape ``(n_samples, self.n_outputs)``.
        """
        def _predict_single(node, x):
            if node.left is None and node.right is None:
                return node.value

            if getattr(node, 'is_oblique', False):
                proj = sum(
                    w * x[f]
                    for w, f in zip(node.oblique_weights, node.oblique_features)
                )
                proj += node.oblique_bias
                go_left = proj <= node.threshold
            else:
                go_left = x[node.feature] <= node.threshold

            if go_left:
                return _predict_single(node.left, x) if node.left else node.value
            else:
                return _predict_single(node.right, x) if node.right else node.value

        preds = np.zeros((X.shape[0], self.n_outputs))
        for i in range(X.shape[0]):
            preds[i] = _predict_single(root, X[i])
        return preds


class SynergyGuidedFIGS(_ObliqueSplitMixin, FIGSClassifier):
    """FIGS with synergy-guided oblique splits.

    At each candidate node, the best axis-aligned feature is combined with
    its neighbours in ``synergy_adj`` into an oblique split. The oblique
    split is kept only if it reduces impurity strictly more.
    """

    def __init__(
        self,
        synergy_adj: dict,
        synergy_matrix: np.ndarray,
        max_rules: int = 12,
        ridge_alpha: float = 1.0,
        max_features_per_split: int = 5,
        random_state: int = None,
    ):
        super().__init__(max_rules=max_rules, random_state=random_state)
        self.synergy_adj = synergy_adj
        self.synergy_matrix = synergy_matrix
        self.ridge_alpha = ridge_alpha
        self.max_features_per_split = max_features_per_split
        self.oblique_splits_info: list = []

    def _synergy_feature_subset(self, best_feature):
        """Return the features to combine with ``best_feature``, or ``None``.

        The subset is ``best_feature`` plus its synergy-graph neighbours, in
        sorted order. If that exceeds ``max_features_per_split``, the
        subset becomes ``best_feature`` followed by its highest-synergy
        neighbours. Returns ``None`` when fewer than two features remain.
        """
        if best_feature is None or best_feature not in self.synergy_adj:
            return None
        neighbors = self.synergy_adj[best_feature]
        if len(neighbors) == 0:
            return None

        feat_subset = sorted([best_feature] + list(neighbors))
        if len(feat_subset) > self.max_features_per_split:
            # Stable sort: ties keep the neighbour set's iteration order.
            ranked = sorted(
                neighbors,
                key=lambda f: self.synergy_matrix[best_feature, f],
                reverse=True,
            )
            feat_subset = [best_feature] + ranked[:self.max_features_per_split - 1]

        if len(feat_subset) < 2:
            return None
        return feat_subset

    def _construct_node_with_stump(
        self, X, y, idxs, tree_num,
        sample_weight=None, compare_nodes_with_sample_weight=True,
        max_features=None, depth=None,
    ):
        """Override: try synergy-guided oblique splits alongside axis-aligned."""
        node_axis = super()._construct_node_with_stump(
            X, y, idxs, tree_num,
            sample_weight=sample_weight,
            compare_nodes_with_sample_weight=compare_nodes_with_sample_weight,
            max_features=max_features,
            depth=depth,
        )

        if not hasattr(node_axis, 'left_temp') or node_axis.left_temp is None:
            node_axis.is_oblique = False
            return node_axis

        best_node = node_axis
        feat_subset = self._synergy_feature_subset(node_axis.feature)
        if feat_subset is not None:
            oblique_node = self._try_oblique_split(
                X, y, idxs, tree_num, feat_subset,
                sample_weight=sample_weight,
                depth=depth,
            )
            if oblique_node is not None:
                obl_imp = oblique_node.impurity_reduction or 0.0
                if obl_imp > (node_axis.impurity_reduction or 0.0):
                    best_node = oblique_node

        if not hasattr(best_node, 'is_oblique'):
            best_node.is_oblique = False
        return best_node


class RandomObliqueFIGS(_ObliqueSplitMixin, FIGSClassifier):
    """Baseline: oblique splits with RANDOM feature subsets (no synergy guidance).

    Each candidate node tries ``n_random_trials`` random subsets of
    ``min(beam_size, p)`` features. The best oblique split is kept if it
    beats the axis-aligned one.
    """

    def __init__(
        self,
        beam_size: int = 3,
        max_rules: int = 12,
        ridge_alpha: float = 1.0,
        random_state: int = None,
        n_random_trials: int = 3,
    ):
        super().__init__(max_rules=max_rules, random_state=random_state)
        self.beam_size = beam_size
        self.ridge_alpha = ridge_alpha
        self.n_random_trials = n_random_trials
        self.oblique_splits_info: list = []
        self._rng = np.random.RandomState(random_state)

    def fit(self, X, y=None, *args, **kwargs):
        """Fit after re-seeding the subset RNG.

        Re-seeding makes repeated fits of one instance draw the same subsets.
        """
        self._rng = np.random.RandomState(self.random_state)
        return super().fit(X, y, *args, **kwargs)

    def _construct_node_with_stump(
        self, X, y, idxs, tree_num,
        sample_weight=None, compare_nodes_with_sample_weight=True,
        max_features=None, depth=None,
    ):
        """Override: try random-subset oblique splits alongside axis-aligned."""
        node_axis = super()._construct_node_with_stump(
            X, y, idxs, tree_num,
            sample_weight=sample_weight,
            compare_nodes_with_sample_weight=compare_nodes_with_sample_weight,
            max_features=max_features,
            depth=depth,
        )

        if not hasattr(node_axis, 'left_temp') or node_axis.left_temp is None:
            node_axis.is_oblique = False
            return node_axis

        best_node = node_axis
        best_imp_red = node_axis.impurity_reduction or 0.0

        p = X.shape[1]
        feat_size = min(self.beam_size, p)
        # An oblique split needs at least two features; skip drawing otherwise.
        if feat_size >= 2:
            for _ in range(self.n_random_trials):
                feat_subset = sorted(
                    self._rng.choice(p, size=feat_size, replace=False).tolist()
                )
                oblique_node = self._try_oblique_split(
                    X, y, idxs, tree_num, feat_subset,
                    sample_weight=sample_weight,
                    depth=depth,
                )
                if oblique_node is not None:
                    obl_imp = oblique_node.impurity_reduction or 0.0
                    if obl_imp > best_imp_red:
                        best_node = oblique_node
                        best_imp_red = obl_imp

        if not hasattr(best_node, 'is_oblique'):
            best_node.is_oblique = False
        return best_node

### Helper Functions

Utilities for computing AUC, counting splits, interpretability scoring, and extracting split descriptions.

In [None]:
def compute_auc(y_true: np.ndarray, y_proba: np.ndarray) -> float:
    """ROC AUC from a ``(n_samples, n_classes)`` probability matrix.

    With two columns, the positive-class column is scored directly.
    Otherwise a class-weighted one-vs-rest AUC is used. Any failure (e.g. a
    class missing from ``y_true``, or a non-2-D ``y_proba``) yields NaN.
    """
    try:
        if y_proba.shape[1] != 2:
            score = roc_auc_score(
                y_true, y_proba, multi_class='ovr', average='weighted'
            )
        else:
            score = roc_auc_score(y_true, y_proba[:, 1])
        return float(score)
    except Exception:
        return float('nan')


def count_splits_in_tree(node) -> tuple:
    """Return ``(n_splits, n_oblique_splits)`` for the tree rooted at ``node``.

    A node counts as a split when it has at least one child. It counts as
    oblique when its ``is_oblique`` attribute is truthy. ``None`` and leaves
    contribute nothing.
    """
    n_splits = n_oblique = 0
    pending = [node]
    while pending:
        current = pending.pop()
        if current is None or (current.left is None and current.right is None):
            continue
        n_splits += 1
        if getattr(current, 'is_oblique', False):
            n_oblique += 1
        pending.extend((current.left, current.right))
    return n_splits, n_oblique


def count_model_splits(model) -> tuple:
    """Return ``(total_splits, oblique_splits)`` summed over ``model.trees_``."""
    per_tree = [count_splits_in_tree(root) for root in model.trees_]
    total = sum(n_total for n_total, _ in per_tree)
    oblique = sum(n_obl for _, n_obl in per_tree)
    return total, oblique


def compute_mean_features_per_oblique(model) -> float:
    """Average ``len(oblique_features)`` over every oblique node in ``model.trees_``.

    Returns 0.0 when the model contains no oblique node.
    """
    sizes: list = []
    for root in model.trees_:
        # Explicit-stack pre-order walk (left subtree before right).
        stack = [root]
        while stack:
            node = stack.pop()
            if node is None:
                continue
            if getattr(node, 'is_oblique', False):
                sizes.append(len(node.oblique_features))
            stack.append(node.right)
            stack.append(node.left)
    return float(np.mean(sizes)) if sizes else 0.0


def compute_interpretability_score(
    model,
    synergy_matrix_validate: np.ndarray,
    min_synergy: float = 1e-6,
) -> float:
    """Non-circular interpretability score.

    Fraction of oblique splits whose feature pairs ALL rank in the top 25% of
    synergy scores computed on the VALIDATE subset. A pair ranks in the top
    25% when its synergy is at or above the 75th percentile of the
    upper-triangular synergy values.

    A pair must also exceed ``min_synergy``. When more than 75% of pairs have
    zero synergy, the 75th-percentile cutoff is itself 0. Without the floor,
    every zero-synergy pair would then count as "top 25%" and inflate the
    score. The default equals the ``MIN_SYNERGY`` floor in
    ``build_synergy_graph``. NaN synergy never counts as high.

    Args:
        model: Fitted FIGS-style model exposing ``trees_``. Oblique nodes
            carry ``is_oblique`` and ``oblique_features``.
        synergy_matrix_validate: ``(p, p)`` synergy matrix from held-out data.
        min_synergy: Minimum synergy for a pair to count as high.

    Returns:
        Fraction in [0, 1]. ``nan`` when the model has no oblique split;
        ``0.0`` when every validate pair is exactly zero (or there are none).
    """
    from itertools import combinations

    def _collect_oblique_features(node, oblique_feats):
        if node is None:
            return
        if getattr(node, 'is_oblique', False):
            oblique_feats.append(node.oblique_features)
        _collect_oblique_features(node.left, oblique_feats)
        _collect_oblique_features(node.right, oblique_feats)

    oblique_feats: list = []
    for tree in model.trees_:
        _collect_oblique_features(tree, oblique_feats)

    if not oblique_feats:
        return float('nan')

    p = synergy_matrix_validate.shape[0]
    upper_tri = synergy_matrix_validate[np.triu_indices(p, k=1)]
    if len(upper_tri) == 0 or np.all(upper_tri == 0):
        return 0.0
    top25_cutoff = float(np.percentile(upper_tri, 75))

    def _is_high(a, b):
        score = synergy_matrix_validate[a, b]
        return score >= top25_cutoff and score > min_synergy

    n_aligned = sum(
        1
        for feat_list in oblique_feats
        if all(_is_high(a, b) for a, b in combinations(feat_list, 2))
    )
    return float(n_aligned / len(oblique_feats))

### Phase 4 — Main Experiment Loop

Runs all 6 methods on each dataset with stratified K-fold CV. Computes fold-aware synergy, trains models, and collects per-fold accuracy/AUC metrics.

In [None]:
def _fit_and_score(model, X_train, y_train, X_test, y_test):
    """Fit ``model`` on the training split; return ``(accuracy, auc)`` on the test split."""
    model.fit(X_train, y_train)
    y_pred = model.predict(X_test)
    y_proba = model.predict_proba(X_test)
    return float(accuracy_score(y_test, y_pred)), compute_auc(y_test, y_proba)


def _make_result(method, ds_name, fold_idx, max_rules, acc, auc, *,
                 n_splits, n_oblique=0, oblique_fraction=0.0,
                 mean_features_per_oblique=0.0,
                 interpretability_score=float('nan'), synergy_time_s=0.0):
    """Build one per-fold result record (key order matches the results schema)."""
    return {
        'method': method, 'dataset': ds_name, 'fold': fold_idx,
        'max_rules': max_rules, 'accuracy': acc, 'auc': auc,
        'n_splits': n_splits, 'n_oblique': n_oblique,
        'oblique_fraction': oblique_fraction,
        'mean_features_per_oblique': mean_features_per_oblique,
        'interpretability_score': interpretability_score,
        'synergy_time_s': synergy_time_s,
    }


def _synergy_graph_stability(fold_synergy_graphs, percentile=75):
    """Return mean/std pairwise Jaccard similarity of per-fold synergy-graph edge sets.

    Each fold's graph is thresholded once with ``build_synergy_graph`` at
    ``percentile``, and every pair of folds is then compared. Two empty edge
    sets give a Jaccard of 1.0. Fewer than two folds yields 0.0 for both
    statistics.
    """
    edge_sets = []
    for matrix in fold_synergy_graphs:
        adj, _, _ = build_synergy_graph(matrix, percentile)
        # Keep each undirected edge once, as (low, high).
        edge_sets.append({(a, b) for a in adj for b in adj[a] if a < b})

    jaccard_pairs = []
    for fi in range(len(edge_sets)):
        for fj in range(fi + 1, len(edge_sets)):
            union = edge_sets[fi] | edge_sets[fj]
            jaccard_pairs.append(
                len(edge_sets[fi] & edge_sets[fj]) / len(union) if union else 1.0
            )
    return {
        'mean_jaccard': float(np.mean(jaccard_pairs)) if jaccard_pairs else 0.0,
        'std_jaccard': float(np.std(jaccard_pairs)) if jaccard_pairs else 0.0,
    }


def run_experiment(dataset_order, max_rules_values, n_folds, per_dataset_timeout):
    """Run the SG-FIGS experiment across datasets and folds.

    Each dataset in ``dataset_order`` goes through stratified ``n_folds``-fold
    CV. Inside each fold, the training split is divided 80/20 into two parts:
    - a synergy-build subset, whose PID synergy matrix guides SG-FIGS;
    - a synergy-validate subset, used only for the fast synergy proxy behind
      the interpretability score.
    This keeps test data out of the synergy computation. Every method is then
    trained on the full training split for each value in ``max_rules_values``.

    Args:
        dataset_order: dataset names; each must be a key of ``DATASETS``.
        max_rules_values: iterable of ``max_rules`` settings to evaluate.
        n_folds: number of stratified CV folds.
        per_dataset_timeout: soft per-dataset time budget in seconds, checked
            between folds. Once it is exceeded, the dataset's remaining folds
            are skipped; at least one fold always runs. ``None`` or a
            non-positive value disables the budget.

    Returns:
        A tuple ``(results, synergy_stability, method_names)``:
        - results: list of per-fold result dicts.
        - synergy_stability: mapping dataset -> {'mean_jaccard',
          'std_jaccard'}, the synergy-graph stability at percentile 75.
        - method_names: method names in display order.

    A dataset that fails to load, a fold whose synergy computation fails, and
    an individual method that raises are each reported and skipped. None of
    them aborts the whole run.
    """
    global_start = time.time()
    print("=" * 60)
    print("SG-FIGS Experiment — Starting")
    print(f"  Datasets: {dataset_order}")
    print(f"  Max rules: {max_rules_values}, Folds: {n_folds}")
    print("=" * 60)

    use_timeout = per_dataset_timeout is not None and per_dataset_timeout > 0

    results: list = []
    synergy_stability: dict = {}
    method_names = [
        'FIGS', 'RO-FIGS',
        'SG-FIGS-10', 'SG-FIGS-25', 'SG-FIGS-50',
        'GradientBoosting',
    ]

    for ds_idx, ds_name in enumerate(dataset_order):
        ds_start = time.time()
        print(f"\n[{ds_idx+1}/{len(dataset_order)}] Processing: {ds_name}")

        try:
            X, y, _ = load_dataset(ds_name, DATASETS[ds_name])
        except Exception as e:
            print(f"  Failed to load dataset {ds_name}: {e}")
            continue

        print(f"  Shape: {X.shape}, classes: {len(np.unique(y))}")

        skf = StratifiedKFold(n_splits=n_folds, shuffle=True, random_state=42)
        fold_synergy_graphs: list = []

        for fold_idx, (train_idx, test_idx) in enumerate(skf.split(X, y)):
            # Soft per-dataset budget. It is checked only between folds, so a
            # fold is never cut off midway and at least one fold always runs.
            if use_timeout and fold_idx > 0:
                ds_so_far = time.time() - ds_start
                if ds_so_far > per_dataset_timeout:
                    print(f"  Timeout: {ds_so_far:.1f}s > {per_dataset_timeout}s; "
                          f"skipping remaining {n_folds - fold_idx} fold(s)")
                    break

            fold_start = time.time()
            X_train, X_test = X[train_idx], X[test_idx]
            y_train, y_test = y[train_idx], y[test_idx]

            # ── FOLD-AWARE SYNERGY COMPUTATION ──
            # Seeded 80/20 split of the TRAINING fold only: "build" drives the
            # synergy graph, "validate" feeds the interpretability score.
            n_train = len(train_idx)
            rng = np.random.RandomState(42 + fold_idx)
            perm = rng.permutation(n_train)
            split_pt = int(n_train * 0.8)
            syn_build_idx = perm[:split_pt]
            syn_validate_idx = perm[split_pt:]

            try:
                synergy_matrix, syn_time = compute_pairwise_synergy_matrix(
                    X_train[syn_build_idx], y_train[syn_build_idx],
                )
                synergy_matrix_validate = compute_fast_synergy_proxy(
                    X_train[syn_validate_idx], y_train[syn_validate_idx],
                )
            except Exception as e:
                # Skip the whole fold so every method stays paired on the same folds.
                print(f"  Synergy computation failed on fold {fold_idx}: {e}; skipping fold")
                continue
            fold_synergy_graphs.append(synergy_matrix)

            # GradientBoosting does not depend on max_rules, so fit it once per
            # fold and reuse its scores for every max_rules row. The fit is
            # seeded (random_state=42), so the numbers match a per-row refit.
            gb_scores = None
            try:
                gb_scores = _fit_and_score(
                    GradientBoostingClassifier(
                        n_estimators=100, max_depth=3, random_state=42,
                    ),
                    X_train, y_train, X_test, y_test,
                )
            except Exception as e:
                print(f"  GBC failed: {e}")

            for max_rules in max_rules_values:
                # ── METHOD 1: Standard FIGS ──
                try:
                    figs = FIGSClassifier(max_rules=max_rules, random_state=42)
                    acc, auc = _fit_and_score(figs, X_train, y_train, X_test, y_test)
                    n_splits, _ = count_model_splits(figs)
                    results.append(_make_result(
                        'FIGS', ds_name, fold_idx, max_rules, acc, auc,
                        n_splits=n_splits,
                    ))
                except Exception as e:
                    print(f"  FIGS failed: {e}")

                # ── METHOD 2: RO-FIGS ──
                try:
                    rofigs = RandomObliqueFIGS(
                        beam_size=3, max_rules=max_rules, random_state=42,
                    )
                    acc, auc = _fit_and_score(rofigs, X_train, y_train, X_test, y_test)
                    n_splits, n_obl = count_model_splits(rofigs)
                    results.append(_make_result(
                        'RO-FIGS', ds_name, fold_idx, max_rules, acc, auc,
                        n_splits=n_splits, n_oblique=n_obl,
                        oblique_fraction=n_obl / max(n_splits, 1),
                        mean_features_per_oblique=compute_mean_features_per_oblique(rofigs),
                    ))
                except Exception as e:
                    print(f"  RO-FIGS failed: {e}")

                # ── METHODS 3-5: SG-FIGS variants ──
                for sg_name, percentile in THRESHOLDS.items():
                    try:
                        adj, _, _ = build_synergy_graph(synergy_matrix, percentile)
                        sgfigs = SynergyGuidedFIGS(
                            synergy_adj=adj,
                            synergy_matrix=synergy_matrix,
                            max_rules=max_rules,
                            random_state=42,
                        )
                        acc, auc = _fit_and_score(sgfigs, X_train, y_train, X_test, y_test)
                        n_splits, n_obl = count_model_splits(sgfigs)
                        results.append(_make_result(
                            sg_name, ds_name, fold_idx, max_rules, acc, auc,
                            n_splits=n_splits, n_oblique=n_obl,
                            oblique_fraction=n_obl / max(n_splits, 1),
                            mean_features_per_oblique=compute_mean_features_per_oblique(sgfigs),
                            interpretability_score=compute_interpretability_score(
                                sgfigs, synergy_matrix_validate,
                            ),
                            synergy_time_s=syn_time,
                        ))
                    except Exception as e:
                        print(f"  {sg_name} failed: {e}")

                # ── METHOD 6: GradientBoosting baseline (fit above) ──
                if gb_scores is not None:
                    acc, auc = gb_scores
                    results.append(_make_result(
                        'GradientBoosting', ds_name, fold_idx, max_rules, acc, auc,
                        n_splits=-1,  # not a rule model; split count not tracked
                    ))

            fold_elapsed = time.time() - fold_start
            print(f"  Fold {fold_idx} done in {fold_elapsed:.1f}s")

        # ── SYNERGY GRAPH STABILITY (Jaccard) ──
        synergy_stability[ds_name] = _synergy_graph_stability(fold_synergy_graphs, 75)

        ds_elapsed = time.time() - ds_start
        print(f"  Dataset {ds_name} done in {ds_elapsed:.1f}s | Results: {len(results)}")

    total_runtime = time.time() - global_start
    print(f"\n{'='*60}")
    print(f"EXPERIMENT COMPLETE — {total_runtime:.1f}s, {len(results)} results")
    print(f"{'='*60}")

    return results, synergy_stability, method_names

In [None]:
# Quick-demo run, configured by the constants defined earlier in the notebook.
# Positional order: datasets, max_rules values, number of folds, per-dataset timeout.
results, synergy_stability, method_names = run_experiment(
    DATASET_ORDER, MAX_RULES_VALUES, N_FOLDS, PER_DATASET_TIMEOUT,
)

### Results Visualization

Summary table of mean accuracy per method, a bar chart comparing methods, a dataset × method accuracy heatmap, and synergy graph stability (Jaccard similarity).

In [None]:
def _plot_accuracy_overview(results, summary, methods_with_data, method_names,
                            mr_for_summary, title_prefix):
    """Draw the mean-accuracy bar chart and the dataset × method accuracy heatmap."""
    fig, axes = plt.subplots(1, 2, figsize=(14, 5))

    # ── Bar Chart: Mean Accuracy by Method ──
    mean_accs = [summary[m]['mean_acc'] for m in methods_with_data]
    std_accs = [summary[m]['std_acc'] for m in methods_with_data]

    palette = ['#2196F3', '#FF9800', '#E91E63', '#9C27B0', '#673AB7', '#4CAF50']
    # Colour each method by its position in method_names rather than in the
    # filtered list. A method then keeps the same colour when others are missing.
    colors = [palette[method_names.index(m) % len(palette)] for m in methods_with_data]
    bars = axes[0].bar(range(len(methods_with_data)), mean_accs,
                       yerr=std_accs, capsize=4, color=colors,
                       edgecolor='black', linewidth=0.5)
    axes[0].set_xticks(range(len(methods_with_data)))
    axes[0].set_xticklabels(methods_with_data, rotation=30, ha='right', fontsize=9)
    axes[0].set_ylabel('Mean Accuracy')
    axes[0].set_title(f'{title_prefix}Mean Accuracy by Method (max_rules={mr_for_summary})')
    axes[0].set_ylim(0, 1.05)
    for bar, acc in zip(bars, mean_accs):
        axes[0].text(bar.get_x() + bar.get_width()/2., bar.get_height() + 0.02,
                     f'{acc:.3f}', ha='center', va='bottom', fontsize=8)

    # ── Per-Dataset Accuracy Heatmap ──
    datasets_in_results = sorted({r['dataset'] for r in results})
    acc_matrix = np.full((len(datasets_in_results), len(methods_with_data)), np.nan)
    for di, ds in enumerate(datasets_in_results):
        for mi, m in enumerate(methods_with_data):
            ds_m = [r['accuracy'] for r in results
                    if r['dataset'] == ds and r['method'] == m and r['max_rules'] == mr_for_summary]
            if ds_m:
                acc_matrix[di, mi] = np.mean(ds_m)

    im = axes[1].imshow(acc_matrix, cmap='RdYlGn', aspect='auto', vmin=0.3, vmax=1.0)
    axes[1].set_xticks(range(len(methods_with_data)))
    axes[1].set_xticklabels(methods_with_data, rotation=30, ha='right', fontsize=9)
    axes[1].set_yticks(range(len(datasets_in_results)))
    axes[1].set_yticklabels(datasets_in_results, fontsize=9)
    axes[1].set_title(f'{title_prefix}Accuracy by Dataset × Method')
    plt.colorbar(im, ax=axes[1], label='Accuracy')

    # Annotate cells; use white text on the darker (low-accuracy) end of the colormap.
    for di in range(len(datasets_in_results)):
        for mi in range(len(methods_with_data)):
            val = acc_matrix[di, mi]
            if not np.isnan(val):
                axes[1].text(mi, di, f'{val:.2f}', ha='center', va='center', fontsize=7,
                             color='white' if val < 0.6 else 'black')

    plt.tight_layout()
    plt.show()


def visualize_results(results, synergy_stability, method_names, title_prefix=""):
    """Print a results summary, plot accuracy, and report synergy stability.

    The summary is computed at the largest ``max_rules`` value found in
    ``results``. The function prints a mean ± std accuracy table for each
    method. It then draws a bar chart of those means next to a
    dataset × method accuracy heatmap, and prints each dataset's
    synergy-graph Jaccard stability.

    Args:
        results: per-fold result dicts; the 'method', 'dataset', 'max_rules',
            'accuracy' and 'auc' keys are read.
        synergy_stability: mapping dataset -> {'mean_jaccard', 'std_jaccard'}.
        method_names: methods to report, in display order.
        title_prefix: text prepended to printed headers and plot titles.

    Returns:
        dict mapping method -> {'mean_acc', 'std_acc', 'mean_auc', 'n'} for
        every method with results at the summarized ``max_rules``. The dict is
        empty when ``results`` is empty.
    """
    if not results:
        # max() below would raise on an empty sequence; report and bail out.
        print(f"\n{title_prefix}No results to visualize.")
        return {}

    # ── Summary Table ──
    mr_for_summary = max(r['max_rules'] for r in results)
    print(f"\n{'='*60}")
    print(f"{title_prefix}Summary (max_rules={mr_for_summary}, mean accuracy)")
    print(f"{'='*60}")
    summary = {}
    for m in method_names:
        m_results = [r for r in results
                     if r['method'] == m and r['max_rules'] == mr_for_summary]
        if not m_results:
            continue
        accs = [r['accuracy'] for r in m_results]
        # AUC may be NaN; average only the non-NaN values.
        aucs = [r['auc'] for r in m_results if not np.isnan(r['auc'])]
        summary[m] = {
            'mean_acc': np.mean(accs), 'std_acc': np.std(accs),
            'mean_auc': np.mean(aucs) if aucs else float('nan'),
            'n': len(m_results),
        }
        print(f"  {m:20s}: {summary[m]['mean_acc']:.4f} ± {summary[m]['std_acc']:.4f}  (n={summary[m]['n']})")

    methods_with_data = [m for m in method_names if m in summary]
    if methods_with_data:
        _plot_accuracy_overview(results, summary, methods_with_data, method_names,
                                mr_for_summary, title_prefix)
    else:
        print(f"  (no results for the requested methods at max_rules={mr_for_summary}; skipping plots)")

    # ── Synergy Stability ──
    if synergy_stability:
        print(f"\n{'='*60}")
        print(f"{title_prefix}Synergy Graph Stability (Jaccard)")
        print(f"{'='*60}")
        for ds, stab in synergy_stability.items():
            print(f"  {ds:20s}: {stab['mean_jaccard']:.3f} ± {stab['std_jaccard']:.3f}")

    return summary

In [None]:
# Summarize and plot the quick-demo results.
summary_mini = visualize_results(
    results,
    synergy_stability,
    method_names,
    title_prefix="[Quick Demo] ",
)

---
## Full Run — Original Parameters

Runs on all 12 datasets with **5-fold CV** and `max_rules=[5, 10, 15]` — matching the original experiment parameters. This takes approximately 15-20 minutes.

In [None]:
# Load the full benchmark collection and print how many examples each dataset has.
# presumably load_full() returns {'datasets': [{'dataset': name, 'examples': [...]}, ...]},
# as implied by the indexing below; confirm against its definition.
data = load_full()
print(f"Loaded {len(data['datasets'])} datasets (full)")
# NOTE: `ds` remains bound in the notebook namespace after this loop.
for ds in data['datasets']:
    print(f"  {ds['dataset']}: {len(ds['examples'])} examples")

In [None]:
# ── FULL RUN PARAMS (original) ────────────────────────────────────────────────
# All 12 benchmark datasets, three max_rules budgets, 5 CV folds, and a
# per-dataset timeout of 600 (presumably seconds).
DATASET_ORDER_FULL = [
    'banknote',
    'wine',
    'glass',
    'diabetes',
    'heart_statlog',
    'sonar',
    'breast_cancer',
    'ionosphere',
    'vehicle',
    'segment',
    'credit_g',
    'australian',
]
MAX_RULES_VALUES_FULL = [5, 10, 15]
N_FOLDS_FULL = 5
PER_DATASET_TIMEOUT_FULL = 600

# Positional order: datasets, max_rules values, number of folds, per-dataset timeout.
results_full, synergy_stability_full, method_names_full = run_experiment(
    DATASET_ORDER_FULL,
    MAX_RULES_VALUES_FULL,
    N_FOLDS_FULL,
    PER_DATASET_TIMEOUT_FULL,
)

In [None]:
# Summarize and plot the full-run results.
summary_full = visualize_results(
    results_full, synergy_stability_full, method_names_full,
    title_prefix="[Full Run] ",
)