In [13]:
import numpy as np
import pandas as pd
import math
from collections import defaultdict

from tqdm.auto import tqdm
from joblib import Parallel, delayed

from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import accuracy_score, f1_score

from src.utils import load_data
from src.WDI_1NN import WDI_KNN
from src.EAC_1NN import EAC_KNN
from src.MBW_LR import MBW_LR
from src.ACM_SVM import ACM_SVM
from src.CASIM import CASIM

In [14]:
# Load the synthetic alarm-flood dataset (change the path to use another dataset).
# Only features and labels are kept; the two remaining return values are discarded.
X, y, _, _ = load_data("./data/synthetic/")

print("Data shape:", X.shape)
print("Labels shape:", y.shape)
# The three-way unpack fails loudly unless X is 3-D: (n_samples, n_alarms, n_steps).
n_samples, n_alarms, n_steps = X.shape

Data shape: (1875, 10, 60)
Labels shape: (1875,)


In [15]:
def stratified_train_calibration_split(X, y, calib_frac=0.3, random_state=42):
    """
    Stratified split of (X, y) into a training part and a calibration part.

    For every class (in ``np.unique(y)`` order) the sample indices are shuffled with a
    ``np.random.RandomState(random_state)`` generator. The first
    ``max(1, int(n_class * calib_frac))`` of them go to calibration and the rest go to
    training. The outputs are grouped by class in that same order.

    Returns
    -------
    X_train, y_train, X_cali, y_cali
    """
    rng = np.random.RandomState(random_state)
    calib_parts = []
    train_parts = []

    for label in np.unique(y):
        members = np.flatnonzero(y == label)
        rng.shuffle(members)
        # Every class contributes at least one calibration sample.
        cut = max(1, int(len(members) * calib_frac))
        calib_parts.append(members[:cut])
        train_parts.append(members[cut:])

    train_idx = np.concatenate(train_parts)
    calib_idx = np.concatenate(calib_parts)
    return X[train_idx], y[train_idx], X[calib_idx], y[calib_idx]


def jaccard_distance(a, b):
    """
    Multivariate Jaccard distance between two alarm floods.

    Entries > 0 are treated as active; everything else is inactive. The arrays are
    flattened before comparison (e.g. shape (n_alarms, n_steps_partial)).

    Returns
    -------
    float
        ``1 - |active(a) & active(b)| / |active(a) | active(b)|``, or 0.0 when neither
        input has an active entry.
    """
    mask_a = np.ravel(a) > 0
    mask_b = np.ravel(b) > 0

    n_union = np.count_nonzero(mask_a | mask_b)
    if n_union == 0:
        # Two empty floods are considered identical.
        return 0.0

    n_inter = np.count_nonzero(mask_a & mask_b)
    return 1.0 - n_inter / n_union

def hamming_distance(a, b):
    """
    Normalised Hamming distance between two alarm floods.

    Both inputs are flattened and binarised (entry > 0 means active). The result is
    the fraction of positions whose active/inactive state differs. The number of
    positions is taken from ``a``.

    Returns
    -------
    float in [0, 1]
        Proportion of differing positions; 0.0 when ``a`` is empty.
    """
    flat_a = np.ravel(a) > 0
    flat_b = np.ravel(b) > 0

    n_positions = flat_a.size
    if n_positions == 0:
        return 0.0

    return np.count_nonzero(flat_a != flat_b) / n_positions

In [16]:
class ConformalAlarmClassifier:
    """
    Wrapper around an alarm-flood classifier (AFC) that adds class-wise conformal
    thresholds per step and a delay-timer-based construction of prediction sets.

    A "step" is a prefix length: at step ``s`` the classifier sees ``X[:, :, :s]``.

    Parameters
    ----------
    base_clf_cls : type
        Classifier class, instantiated as ``base_clf_cls(clf_params)``. Its instances
        must expose ``fit(X, y)``, ``predict_proba(X)`` and ``classes_``.
    clf_params : object
        Passed positionally to ``base_clf_cls``.
    step_list : iterable of int
        Prefix lengths at which predictions are made (stored sorted ascending).
    clf_stepwise : bool, default False
        If True, train one classifier per step on ``X_train[:, :, :step]``. Otherwise
        train one classifier on the full training sequences and reuse it for every step.
    input_type : {"PROBA", "DIST"}, default "PROBA"
        Meaning of the classifier scores: "PROBA" means larger is better, "DIST" means
        smaller is better. Matched case-insensitively and stored upper-case.
    alpha : float, default 0.05
        Miscoverage level for the class-conditional thresholds.
    delay_timer : int, default 5
        Once a change in the raw prediction set is observed, the accepted set is held
        until at least ``delay_timer`` time steps (difference of step values) have
        passed. Only then is the raw set compared again and possibly accepted.
    random_state : int, default 42
        Seed passed to ``np.random.seed`` at the start of ``fit``.

    Raises
    ------
    ValueError
        If ``input_type`` is not "PROBA" or "DIST" (case-insensitive).
    """

    _VALID_INPUT_TYPES = ("PROBA", "DIST")

    def __init__(
        self,
        base_clf_cls,
        clf_params,
        step_list,
        clf_stepwise=False,
        input_type="PROBA",   # "PROBA" or "DIST"
        alpha=0.05,
        delay_timer=5,
        random_state=42,
    ):
        # Previously every value other than exactly "PROBA" was silently treated as
        # "DIST" (so "proba" inverted the semantics). Normalise case and reject the rest.
        normalized_input_type = str(input_type).upper()
        if normalized_input_type not in self._VALID_INPUT_TYPES:
            raise ValueError(
                f"input_type must be one of {self._VALID_INPUT_TYPES}, got {input_type!r}"
            )

        self.base_clf_cls = base_clf_cls
        self.clf_params = clf_params
        self.step_list = sorted(step_list)
        self.clf_stepwise = clf_stepwise
        self.input_type = normalized_input_type
        self.alpha = alpha
        self.delay_timer = delay_timer
        self.random_state = random_state

        self.clf_dict_ = {}        # step (or 0 when not stepwise) -> fitted classifier
        self.thresholds_ = {}      # step -> array(len(classes))
        self.class_labels_ = None  # taken from the (first) fitted classifier's classes_

    @property
    def name(self):
        """Display name derived from the wrapped classifier class."""
        return f"Conformal_{self.base_clf_cls.__name__}"

    # ---------- CP threshold helpers (unchanged) ----------

    def _cp_threshold_proba(self, scores):
        """
        Class-conditional threshold for score-type outputs (larger = better).

        Returns the floor(alpha * (n + 1))-th smallest calibration score. If that index
        is < 1, returns the smallest score minus 1e-12. Returns 0.0 if there are no
        calibration scores.
        """
        scores = np.sort(scores)  # ascending
        n = len(scores)
        if n == 0:
            return 0.0
        k = math.floor(self.alpha * (n + 1))
        if k < 1:
            return scores[0] - 1e-12
        else:
            return scores[k - 1]

    def _cp_threshold_dist(self, scores):
        """
        Class-conditional threshold for distance-type outputs (smaller = better).

        Returns the ceil((1 - alpha) * (n + 1))-th smallest calibration distance,
        clipped to the largest one. Returns +inf if there are no calibration scores.
        """
        scores = np.sort(scores)  # ascending (small = good)
        n = len(scores)
        if n == 0:
            return float("inf")
        k = math.ceil((1.0 - self.alpha) * (n + 1))
        if k > n:
            k = n
        return scores[k - 1]

    def _compute_thresholds_for_step(self, y_cali, y_proba):
        """
        Compute one threshold per class from the calibration scores of that class.

        Only calibration samples whose true label is the class contribute to the
        class's threshold (class-conditional calibration).
        """
        thresholds = []
        for class_idx, class_label in enumerate(self.class_labels_):
            mask = (y_cali == class_label)
            scores = y_proba[mask, class_idx]
            if self.input_type == "PROBA":
                tau = self._cp_threshold_proba(scores)
            else:
                tau = self._cp_threshold_dist(scores)
            thresholds.append(tau)
        return np.array(thresholds)

    # ---------- training & prediction ----------

    def fit(self, X_train, y_train, X_cali, y_cali):
        """
        Train the classifier(s) on the training set and compute per-step thresholds
        from the calibration set.

        Any classifiers and thresholds from a previous ``fit`` call are discarded.
        """
        np.random.seed(self.random_state)

        # Start from a clean state so a refit never mixes in stale models/thresholds.
        self.clf_dict_ = {}
        self.thresholds_ = {}

        # Train classifier(s)
        if self.clf_stepwise:
            for idx, step in enumerate(self.step_list):
                X_step = X_train[:, :, :step]
                clf = self.base_clf_cls(self.clf_params)
                clf.fit(X_step, y_train)
                self.clf_dict_[step] = clf
                if idx == 0:
                    self.class_labels_ = clf.classes_
        else:
            clf = self.base_clf_cls(self.clf_params)
            clf.fit(X_train, y_train)
            self.clf_dict_[0] = clf
            self.class_labels_ = clf.classes_

        # Compute thresholds for each step from calibration set
        for step in self.step_list:
            X_cali_step = X_cali[:, :, :step]
            y_proba = self.predict_proba(X_cali_step, step)
            self.thresholds_[step] = self._compute_thresholds_for_step(y_cali, y_proba)

    def predict_proba(self, X, step):
        """
        Scores of the classifier responsible for ``step``.

        In stepwise mode this is the per-step model; otherwise it is the single model.
        """
        if self.clf_stepwise:
            clf = self.clf_dict_[step]
        else:
            clf = self.clf_dict_[0]
        return clf.predict_proba(X)

    # ---------- raw CP sets (no delay) ----------

    def _raw_conformal_sets_for_step(self, X_step, step):
        """
        Compute raw CP prediction sets for a single step (no delay timer).

        A class is in a sample's set when its score is >= the class threshold
        ("PROBA") or <= the class threshold ("DIST").

        Returns:
            raw_sets: array(shape=(n_samples,), dtype=object of sets)
            y_pred_proba: (n_samples, n_classes)
        """
        y_pred_proba = self.predict_proba(X_step, step)
        thresholds = self.thresholds_[step]

        # Broadcast (n_samples, n_classes) scores against the (n_classes,) thresholds.
        if self.input_type == "PROBA":
            membership = y_pred_proba >= thresholds
        else:  # "DIST"
            membership = y_pred_proba <= thresholds

        labels = self.class_labels_
        raw_sets = np.empty(len(membership), dtype=object)
        for i, row in enumerate(membership):
            raw_sets[i] = {label for label, keep in zip(labels, row) if keep}
        return raw_sets, y_pred_proba

    # ---------- final CP sets with set-level delay timer ----------

    def predict_conformal_full_sequence(self, X):
        """
        Compute step-wise conformal prediction sets for a full sequence of alarm floods,
        with a set-level delay timer to suppress chattering of bifurcation points.

        X shape: (n_samples, n_alarms, n_steps)

        Returns:
            conformal_sets: dict[step] -> array(shape=(n_samples,), dtype=object of sets)
                            (these are the delayed, "accepted" sets)
            conformal_vectors: dict[step] -> (n_samples, n_classes) boolean vectors
            proba_per_step: dict[step] -> (n_samples, n_classes) scores (raw)
        """
        n_samples = X.shape[0]
        labels = self.class_labels_

        # Per-sample delay-timer state.
        ref_set = [None] * n_samples                 # last accepted prediction set
        timer_running = [False] * n_samples          # whether a change timer is active
        timer_start_step = [None] * n_samples        # step at which the change was seen
        ref_set_at_timer_start = [None] * n_samples  # accepted set when the timer started

        conformal_sets = {}
        conformal_vectors = {}
        proba_per_step = {}

        for step in self.step_list:
            raw_sets, y_pred_proba = self._raw_conformal_sets_for_step(X[:, :, :step], step)
            step_sets_final = np.empty(n_samples, dtype=object)

            for i in range(n_samples):
                if ref_set[i] is None:
                    # First step: accept the raw set directly.
                    ref_set[i] = raw_sets[i]
                    timer_running[i] = False
                    timer_start_step[i] = step
                    ref_set_at_timer_start[i] = ref_set[i]
                elif not timer_running[i]:
                    if raw_sets[i] != ref_set[i]:
                        # A change is observed: start the timer but do NOT accept yet.
                        timer_running[i] = True
                        timer_start_step[i] = step
                        ref_set_at_timer_start[i] = ref_set[i]
                elif step - timer_start_step[i] >= self.delay_timer:
                    # Timer elapsed: accept the raw set only if it still differs from
                    # the set that was accepted when the timer started.
                    if raw_sets[i] != ref_set_at_timer_start[i]:
                        ref_set[i] = raw_sets[i]
                    # Reset the timer in any case.
                    timer_running[i] = False
                    timer_start_step[i] = step
                    ref_set_at_timer_start[i] = ref_set[i]
                # Output the accepted set (possibly just updated).
                step_sets_final[i] = ref_set[i]

            conformal_sets[step] = step_sets_final

            # Boolean membership vectors; the reshape keeps shape (0, n_classes) when
            # there are no samples.
            conformal_vectors[step] = np.array(
                [[cls in s for cls in labels] for s in step_sets_final],
                dtype=bool,
            ).reshape(n_samples, len(labels))
            proba_per_step[step] = y_pred_proba

        return conformal_sets, conformal_vectors, proba_per_step

In [17]:
class BifurcationDetector:
    """
    Find bifurcation points: steps at which a sample's CP prediction set shrinks to a
    smaller, non-empty set compared with the previous step.
    """

    def __init__(self, step_list):
        self.step_list = sorted(step_list)

    def detect(self, conformal_sets):
        """
        Scan every sample's sequence of prediction sets for bifurcations.

        conformal_sets: dict[step] -> array(shape=(n_samples,), dtype=object of sets).
        The sample count is read from the first entry of the dict.

        Returns:
            events_per_sample: one list per sample. Each event is a dict with:
                "bif_step":        step at which the set shrank,
                "start_step":      step of the previous bifurcation of this sample,
                                   or the first entry of step_list if there was none,
                "prev_step":       the step preceding bif_step in step_list,
                "prev_set":        prediction set at prev_step,
                "curr_set":        prediction set at bif_step,
                "dropped_classes": prev_set - curr_set
        """
        steps = self.step_list
        n_samples = len(next(iter(conformal_sets.values())))
        events_per_sample = []

        for i in range(n_samples):
            trajectory = [(step, conformal_sets[step][i]) for step in steps]
            sample_events = []
            anchor_step = None  # step of the most recent bifurcation for this sample

            # Walk over consecutive (previous, current) step pairs.
            for (prev_step, prev_set), (step, curr_set) in zip(trajectory, trajectory[1:]):
                shrank = len(curr_set) < len(prev_set)
                if not shrank or len(curr_set) == 0:
                    continue

                sample_events.append(
                    {
                        "bif_step": step,
                        "start_step": steps[0] if anchor_step is None else anchor_step,
                        "prev_step": prev_step,
                        "prev_set": prev_set,
                        "curr_set": curr_set,
                        "dropped_classes": prev_set - curr_set,
                    }
                )
                anchor_step = step

            events_per_sample.append(sample_events)

        return events_per_sample

In [18]:
class CounterfactualGenerator:
    """
    Generate counterfactual explanations (CFs) at bifurcation points.

    For each bifurcation event and each class dropped from the prediction set:

    - Select calibration samples of the dropped (target) class whose own CP prediction
      set at ``bif_step`` already contains the target class. Order them by Hamming
      distance to the original flood over ``[0, bif_step)``.
    - Build candidate CFs by copying alarm variables (rows) from a neighbor into the
      original flood over ``[0, bif_step)``. How variables are chosen depends on
      ``strategy``.
    - Score each candidate with ``cost = shortfall + lambda_distance * dist_cf_orig``,
      where ``shortfall`` measures violation of the conformal condition for the target.
    - Keep only candidates whose CP prediction set at ``bif_step`` contains the target,
      and return the one with minimal cost.

    Neighbor handling per strategy:

    * ``"random"`` / ``"all"``: only the single closest feasible neighbor is used. Any
      unrecognised strategy is handled like ``"random"``, both for neighbor selection
      and for candidate construction.
    * ``"greedy"``: a CF is searched for each of up to ``k_neighbors`` closest feasible
      neighbors, and the successful candidate with the smallest cost is kept.
    """

    # The only strategy that searches over several neighbors.
    _MULTI_NEIGHBOR_STRATEGY = "greedy"

    def __init__(
        self,
        conformal_classifier,   # ConformalAlarmClassifier
        X_cali,
        y_cali,
        step_list,
        k_neighbors=3,
        max_iter=200,
        strategy="random",      # "random", "all", "greedy"
        lambda_distance=1.0,
        random_state=42,
    ):
        self.cc = conformal_classifier
        self.X_cali = X_cali
        self.y_cali = y_cali
        self.step_list = sorted(step_list)
        self.k_neighbors = k_neighbors
        self.max_iter = max_iter
        self.strategy = strategy
        self.lambda_distance = lambda_distance
        # Private RNG so random swaps are reproducible and do not touch global state.
        self.rng = np.random.RandomState(random_state)

        # index calibration samples by class label
        self.calib_indices_by_class = {
            cl: np.where(y_cali == cl)[0] for cl in self.cc.class_labels_
        }

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _failure_result():
        """Result dict of `_search_cf_for_target` when no counterfactual is found."""
        return {
            "found": False,
            "neighbor_global_idx": None,
            "neighbor_rank": None,
            "var_indices": None,
            "attempts": 0,
            "cf_set": None,
            "cost": None,
            "dist_cf_orig": None,
            "dist_cf_nn": None,
            "dist_nn_orig_segment": None,
        }

    def _evaluate_candidate(
        self,
        x_candidate,
        x_orig,
        neighbor,
        step_bif,
        start_step,
        target_idx,
    ):
        """
        Evaluate a candidate counterfactual x_candidate at ``step_bif``.

        ``start_step`` is accepted for interface compatibility but is not used.

        Returns:
            dict with:
                - scores: posterior scores at step_bif
                - cp_set: CP prediction set at step_bif
                - shortfall: normalized violation amount for target class in [0, 1]
                - dist_cf_orig: Hamming distance(CF, orig) up to step_bif
                - dist_cf_nn:   Hamming distance(CF, neighbor) up to step_bif
                - cost: shortfall + lambda_distance * dist_cf_orig
        """
        # Scores for the truncated sequence up to the bifurcation step
        scores = self.cc.predict_proba(
            x_candidate[np.newaxis, :, :step_bif], step_bif
        )[0]
        thresholds = self.cc.thresholds_[step_bif]
        labels = self.cc.class_labels_
        tau = thresholds[target_idx]

        if self.cc.input_type == "PROBA":
            cp_set = {cls for cls, s, t in zip(labels, scores, thresholds) if s >= t}

            # Normalized shortfall in [0, 1]:
            # - if p >= tau -> 0
            # - if p = 0    -> 1
            # - if 0 < p < tau -> (tau - p) / tau
            prob_target = scores[target_idx]
            if prob_target >= tau or tau <= 0.0:
                shortfall = 0.0
            else:
                shortfall = (tau - prob_target) / tau

        else:  # "DIST"
            cp_set = {cls for cls, s, t in zip(labels, scores, thresholds) if s <= t}

            # Normalized shortfall in [0, 1] for the CP condition dist <= tau:
            # - if dist <= tau -> 0
            # - if dist = 1    -> 1
            # - if tau < dist < 1 -> (dist - tau) / (1 - tau)
            dist_target = scores[target_idx]
            if dist_target <= tau or tau >= 1.0:
                shortfall = 0.0
            else:
                shortfall = (dist_target - tau) / max(1.0 - tau, 1e-12)

        # Distances (always up to step_bif)
        dist_cf_orig = hamming_distance(
            x_candidate[:, :step_bif], x_orig[:, :step_bif]
        )
        dist_cf_nn = hamming_distance(
            x_candidate[:, :step_bif], neighbor[:, :step_bif]
        )

        return {
            "scores": scores,
            "cp_set": cp_set,
            "shortfall": shortfall,
            "dist_cf_orig": dist_cf_orig,
            "dist_cf_nn": dist_cf_nn,
            "cost": shortfall + self.lambda_distance * dist_cf_orig,
        }

    def _accept_if_target_included(
        self, x_candidate, x_sample, neighbor, var_indices, step_bif, start_step, target_idx
    ):
        """Evaluate one candidate; return it (attempts=1) if the target is in its CP set."""
        eval_res = self._evaluate_candidate(
            x_candidate, x_sample, neighbor, step_bif, start_step, target_idx
        )
        if self.cc.class_labels_[target_idx] not in eval_res["cp_set"]:
            return None
        return {"var_indices": var_indices, "attempts": 1, **eval_res}

    # ------------------------------------------------------------------
    # Strategies for a single neighbor
    # ------------------------------------------------------------------
    def _cf_with_neighbor_random(
        self,
        x_sample,
        neighbor,
        step_bif,
        start_step,
        target_idx,
    ):
        """
        Strategy 1: one random draw of alarm variables to swap on [0, step_bif).
        Between 1 and n_alarms variables are drawn without replacement.
        """
        n_alarms = x_sample.shape[0]
        n_vars_to_swap = self.rng.randint(1, n_alarms + 1)
        var_indices = self.rng.choice(n_alarms, size=n_vars_to_swap, replace=False)

        x_candidate = x_sample.copy()
        x_candidate[var_indices, :step_bif] = neighbor[var_indices, :step_bif]

        return self._accept_if_target_included(
            x_candidate, x_sample, neighbor, var_indices, step_bif, start_step, target_idx
        )

    def _cf_with_neighbor_all(
        self,
        x_sample,
        neighbor,
        step_bif,
        start_step,
        target_idx,
    ):
        """Strategy 2: swap all alarm variables on [0, step_bif)."""
        var_indices = np.arange(x_sample.shape[0])

        x_candidate = x_sample.copy()
        x_candidate[:, :step_bif] = neighbor[:, :step_bif]

        return self._accept_if_target_included(
            x_candidate, x_sample, neighbor, var_indices, step_bif, start_step, target_idx
        )

    def _cf_with_neighbor_greedy(
        self,
        x_sample,
        neighbor,
        step_bif,
        start_step,
        target_idx,
    ):
        """
        Strategy 3: greedy search over variables that differ between neighbor and
        original on [0, step_bif).

        At each iteration:
        - for each remaining candidate variable, build a CF with the current set S
          plus that variable
        - evaluate the cost and keep the variable with minimum cost
        - stop as soon as the resulting CP set includes the target class

        At most min(#differing variables, max_iter) iterations are run.
        ``attempts`` is the running evaluation count when the returned candidate
        was evaluated.
        """
        diff_mask = np.any(
            x_sample[:, :step_bif] != neighbor[:, :step_bif],
            axis=1,
        )
        candidate_vars = np.where(diff_mask)[0]
        if len(candidate_vars) == 0:
            return None

        target_class = self.cc.class_labels_[target_idx]
        S = set()
        current_attempt = 0

        for _ in range(min(len(candidate_vars), self.max_iter)):
            best_local = None

            for var_j in candidate_vars:
                if var_j in S:
                    continue

                var_indices = np.array(sorted(S | {var_j}), dtype=int)
                x_candidate = x_sample.copy()
                x_candidate[var_indices, :step_bif] = neighbor[var_indices, :step_bif]

                eval_res = self._evaluate_candidate(
                    x_candidate, x_sample, neighbor, step_bif, start_step, target_idx
                )
                current_attempt += 1

                if best_local is None or eval_res["cost"] < best_local["cost"]:
                    best_local = {
                        "var_indices": var_indices,
                        "attempts": current_attempt,
                        **eval_res,
                    }

            if best_local is None:
                break

            # Grow S to the best variable set found in this iteration.
            S = set(best_local["var_indices"])
            if target_class in best_local["cp_set"]:
                return best_local

        return None

    def _find_cf_with_neighbor(
        self,
        x_sample,
        neighbor,
        step_bif,
        start_step,
        target_idx,
    ):
        """
        Dispatch to the chosen strategy for a single neighbor.

        Unknown strategies fall back to random swaps.
        """
        if self.strategy == "all":
            builder = self._cf_with_neighbor_all
        elif self.strategy == "greedy":
            builder = self._cf_with_neighbor_greedy
        else:  # "random" and any unrecognised value
            builder = self._cf_with_neighbor_random
        return builder(x_sample, neighbor, step_bif, start_step, target_idx)

    def _try_neighbor(self, x_sample, global_idx, rank, step_bif, start_step, target_idx):
        """
        Run the strategy with calibration sample ``global_idx`` as neighbor.

        Returns the successful candidate annotated with neighbor info, or None.
        ``rank`` is the neighbor's position in the distance ordering (0 = closest).
        """
        neighbor = self.X_cali[global_idx]
        cf_candidate = self._find_cf_with_neighbor(
            x_sample, neighbor, step_bif, start_step, target_idx
        )
        if cf_candidate is None:
            return None
        if self.cc.class_labels_[target_idx] not in cf_candidate["cp_set"]:
            return None

        # Distance between neighbor and original up to the last bifurcation step.
        if start_step > 0:
            dist_nn_orig_segment = hamming_distance(
                x_sample[:, :start_step], neighbor[:, :start_step]
            )
        else:
            dist_nn_orig_segment = 0.0

        cf_candidate["neighbor_global_idx"] = int(global_idx)
        cf_candidate["neighbor_rank"] = int(rank)
        cf_candidate["dist_nn_orig_segment"] = dist_nn_orig_segment
        return cf_candidate

    # ------------------------------------------------------------------
    # Main CF search for one target class
    # ------------------------------------------------------------------
    def _search_cf_for_target(self, x_sample, step_bif, start_step, target_class):
        """
        Search for a counterfactual for a single target class at a given bifurcation.

        Returns:
            dict with CF info if found (``found=True``), otherwise the failure dict
            (``found=False``, all other fields None / 0).
        """
        target_idx = np.where(self.cc.class_labels_ == target_class)[0][0]
        idxs_class = self.calib_indices_by_class.get(
            target_class, np.array([], dtype=int)
        )
        if len(idxs_class) == 0:
            return self._failure_result()

        # Partial sequences (0:step_bif) for distance-based neighbor selection.
        x_partial = x_sample[:, :step_bif]
        X_cali_class_partial = self.X_cali[idxs_class][:, :, :step_bif]

        # Keep only neighbors whose own CP set at step_bif contains the target class.
        scores_neighbors = self.cc.predict_proba(X_cali_class_partial, step_bif)
        tau_target = self.cc.thresholds_[step_bif][target_idx]
        if self.cc.input_type == "PROBA":
            cp_mask = scores_neighbors[:, target_idx] >= tau_target
        else:  # "DIST"
            cp_mask = scores_neighbors[:, target_idx] <= tau_target

        if not np.any(cp_mask):
            return self._failure_result()

        idxs_class = idxs_class[cp_mask]
        X_cali_class_partial = X_cali_class_partial[cp_mask]

        # Order CP-feasible neighbors by Hamming distance to the original prefix.
        dists = np.array(
            [hamming_distance(x_partial, candidate) for candidate in X_cali_class_partial]
        )
        nn_order = np.argsort(dists)

        # Only "greedy" considers several neighbors. Every other strategy (including
        # the random fallback) uses the single closest feasible neighbor.
        if self.strategy == self._MULTI_NEIGHBOR_STRATEGY:
            n_neighbors = min(self.k_neighbors, len(nn_order))
        else:
            n_neighbors = 1

        best_global = None
        # `rank` is the distance rank; `local_idx` indexes the filtered candidates.
        for rank, local_idx in enumerate(nn_order[:n_neighbors]):
            cf_candidate = self._try_neighbor(
                x_sample, idxs_class[local_idx], rank, step_bif, start_step, target_idx
            )
            if cf_candidate is None:
                continue
            if best_global is None or cf_candidate["cost"] < best_global["cost"]:
                best_global = cf_candidate

        if best_global is None:
            return self._failure_result()

        return {
            "found": True,
            "neighbor_global_idx": best_global["neighbor_global_idx"],
            "neighbor_rank": best_global["neighbor_rank"],
            "var_indices": best_global["var_indices"],
            "attempts": best_global["attempts"],
            "cf_set": best_global["cp_set"],
            "cost": best_global["cost"],
            "dist_cf_orig": best_global["dist_cf_orig"],
            "dist_cf_nn": best_global["dist_cf_nn"],
            "dist_nn_orig_segment": best_global["dist_nn_orig_segment"],
        }

    # ------------------------------------------------------------------
    # Public API: generate CFs for one sample across all its bifurcations
    # ------------------------------------------------------------------
    def generate_for_sample(self, x_sample, events_for_sample, true_label):
        """
        Generate CFs for all bifurcation events of a single alarm flood.

        x_sample: (n_alarms, n_steps)
        events_for_sample: list of bifurcation event dicts from BifurcationDetector
        true_label: ground truth class label for this sample

        Returns one result dict per (event, dropped class). Class labels, steps and
        variable indices are converted to plain Python ints so that the results are
        easy to serialise.
        """
        results = []

        for event in events_for_sample:
            step_bif = event["bif_step"]
            start_step = event["start_step"]
            orig_set = {int(c) for c in event["curr_set"]}

            for target_class in event["dropped_classes"]:
                cf_res = self._search_cf_for_target(
                    x_sample, step_bif, start_step, target_class
                )

                cf_set = (
                    None if cf_res["cf_set"] is None
                    else {int(c) for c in cf_res["cf_set"]}
                )
                var_indices = (
                    None if cf_res["var_indices"] is None
                    else [int(j) for j in cf_res["var_indices"]]
                )

                results.append(
                    {
                        "true_class": int(true_label),
                        "bif_step": int(step_bif),
                        # The previous bifurcation step is the event's start_step.
                        "last_bif_step": int(start_step),
                        "start_step": int(start_step),
                        "target_class": int(target_class),
                        "orig_set": orig_set,
                        "cf_set": cf_set,
                        "found": cf_res["found"],
                        "neighbor_global_idx": cf_res["neighbor_global_idx"],
                        "neighbor_rank": cf_res["neighbor_rank"],
                        "var_indices": var_indices,
                        "attempts": cf_res["attempts"],
                        "cost": cf_res["cost"],
                        "dist_nn_orig_segment": cf_res["dist_nn_orig_segment"],
                        "dist_cf_orig": cf_res["dist_cf_orig"],
                        "dist_cf_nn": cf_res["dist_cf_nn"],
                        "strategy": self.strategy,
                    }
                )
        return results

In [19]:
def _separating_gt_alarm_vars(true_label, target_class, bif_step, ground_truth_bifurcations):
    """
    Return the ground-truth alarm variables of the latest bifurcation at or
    before ``bif_step`` that places ``true_label`` and ``target_class`` in
    *different* split groups (the alarm vars of the target's group are used).

    Returns None when no such bifurcation exists, e.g. when the classifier
    excluded ``target_class`` before the ground truth separates it from the
    true class, or when ``target_class == true_label``.
    """
    latest_time, latest_vars = None, None
    for true_bif_step, bif_info in ground_truth_bifurcations.items():
        if true_bif_step > bif_step:
            continue
        true_group = target_group = None
        for group in bif_info["split_groups"]:
            if true_label in group["classes"]:
                true_group = group
            if target_class in group["classes"]:
                target_group = group
        # Only a bifurcation that puts the two classes on different sides
        # actually separates them; others are irrelevant for this pair.
        if true_group is None or target_group is None or true_group is target_group:
            continue
        if latest_time is None or true_bif_step > latest_time:
            latest_time, latest_vars = true_bif_step, target_group["alarm_vars"]
    return latest_vars


def compute_counterfactual_metrics(cf_results_fold, conformal_sets, y_test, ground_truth_bifurcations, step_list):
    """
    Compute metrics for counterfactual explanations.

    Parameters
    ----------
    cf_results_fold : list of list of dict
        One entry per test sample; each entry is the list of counterfactual
        result dicts produced for that sample (keys read here: "found",
        "bif_step", "target_class", "var_indices", optionally "true_class",
        "attempts", "cost", "dist_nn_orig_segment", "dist_cf_orig",
        "dist_cf_nn").
    conformal_sets : dict
        Mapping steps to conformal prediction sets. Not used by the current
        metrics; kept for interface compatibility.
    y_test : array-like
        True labels aligned by position with ``cf_results_fold``. Used as the
        true class when a result dict carries no "true_class" entry.
    ground_truth_bifurcations : dict
        Mapping bifurcation time -> {"split_groups": [{"classes": set,
        "alarm_vars": set}, ...], ...}.
    step_list : list
        Prediction steps. Not used by the current metrics; kept for interface
        compatibility.

    Returns
    -------
    cf_metrics : dict
        Mapping metric name -> list of per-counterfactual values.
        "changed_prediction_set" gets one 0/1 entry per result; every other
        metric is only recorded for results with ``found`` truthy.
        "overlap_with_ground_truth" is the fraction of swapped variables that
        belong to the ground-truth alarm vars of the bifurcation separating
        the true class from the target class, or NaN if no such bifurcation
        happened at or before ``bif_step``.
    """
    print("Computing counterfactual metrics ...")

    cf_metrics = {
        "num_vars_swapped": [],
        "changed_prediction_set": [],
        "overlap_with_ground_truth": [],
        "attempts": [],
        "cost": [],
        "dist_nn_orig_segment": [],
        "dist_cf_orig": [],
        "dist_cf_nn": [],
    }

    # Optional scalar diagnostics, recorded as float whenever present and not None.
    optional_scalars = ("attempts", "cost", "dist_nn_orig_segment", "dist_cf_orig", "dist_cf_nn")

    for sample_idx, sample_cf_results in enumerate(cf_results_fold):
        for cf_result in sample_cf_results:
            found = cf_result.get("found", False)

            # 1. Did the counterfactual change the prediction set?
            cf_metrics["changed_prediction_set"].append(1 if found else 0)
            if not found:
                continue

            # 2. Number of variables swapped
            var_indices = cf_result["var_indices"]
            num_swapped = len(var_indices) if var_indices is not None else 0
            cf_metrics["num_vars_swapped"].append(num_swapped)

            # 3. Overlap with the ground-truth alarm variables of the split that
            #    separates the sample's true class from the excluded target class.
            #    Prefer the label stored with the result; fall back to y_test.
            true_label = cf_result["true_class"] if "true_class" in cf_result else y_test[sample_idx]
            ground_truth_vars = _separating_gt_alarm_vars(
                int(true_label),
                cf_result["target_class"],
                cf_result["bif_step"],
                ground_truth_bifurcations,
            )
            if ground_truth_vars is None:
                cf_metrics["overlap_with_ground_truth"].append(np.nan)
            else:
                overlap = len(set(var_indices) & set(ground_truth_vars)) if var_indices is not None else 0
                cf_metrics["overlap_with_ground_truth"].append(overlap / num_swapped if num_swapped > 0 else 0)

            # 4. Attempts, cost, distances (if present)
            for key in optional_scalars:
                value = cf_result.get(key, None)
                if value is not None:
                    cf_metrics[key].append(float(value))

    return cf_metrics

In [20]:
# ---- Experiment configuration ----------------------------------------------
# Cross-validation
n_splits = 5
random_state = 42

# Conformal prediction: miscoverage level, alarm delay, and evaluated time steps
alpha = 0.05
delay_timer = 3
step_list = list(range(10, 61))  # steps 10..60 inclusive; adapt to your time resolution

# Fraction of each training fold moved into the calibration set
# (0.8 leaves only 20% of the fold for fitting the base classifier).
calib_frac = 0.8

# Counterfactual search
k_neighbors = 5
max_cf_iter = 10
lambda_distance = 1.0

In [21]:
# Base classifiers under evaluation. `params` holds each classifier's
# constructor keyword arguments and `params_cone` the settings for its
# conformal wrapper; both are keyed by the classifier's class name.
clfs = [WDI_KNN, EAC_KNN, MBW_LR, ACM_SVM, CASIM]

params = dict(
    WDI_KNN=dict(template_threshold=0.5, n_neighbors=1),
    EAC_KNN=dict(attenuation_coefficient_per_min=0.0667, n_neighbors=1),
    MBW_LR=dict(
        penalty=None,
        fit_intercept=False,
        solver="lbfgs",
        multi_class="ovr",
        decision_bounds=True,
        confidence_interval=1.96,
    ),
    ACM_SVM=dict(),
    CASIM=dict(
        num_features=672,
        n_estimators=1,
        n_jobs_multirocket=1,
        random_state=random_state,
        alphas=np.logspace(-3, 3, 10),
    ),
)

# (classifier name, clf_stepwise); every wrapper consumes probabilities.
params_cone = {
    name: {"clf_stepwise": stepwise, "input_type": "PROBA"}
    for name, stepwise in (
        ("WDI_KNN", True),
        ("MBW_LR", True),
        ("EAC_KNN", True),
        ("ACM_SVM", False),
        ("CASIM", True),
    )
}

In [22]:
# Ground-truth class hierarchy of the synthetic data set. Each key is a
# bifurcation time; at that time the listed class groups become separable, and
# `alarm_vars` are the alarm variables ground truth attributes to the split.
def _two_way_split(classes_a, classes_b, alarm_vars, description):
    """Build one bifurcation entry with two split groups sharing the same
    alarm variables (each group receives its own copy of the sets)."""
    return {
        'split_groups': [
            {'classes': set(classes_a), 'alarm_vars': set(alarm_vars)},
            {'classes': set(classes_b), 'alarm_vars': set(alarm_vars)},
        ],
        'description': description,
    }


ground_truth_bifurcations = {
    10: _two_way_split({0, 1, 2}, {3, 4}, {0, 1, 2, 3, 4, 6}, 'Classes 0,1,2 split from classes 3,4'),
    20: _two_way_split({3}, {4}, {3, 4, 6, 7}, 'Class 3 splits from class 4'),
    30: _two_way_split({2}, {0, 1}, {0, 1, 2, 5, 6}, 'Class 2 splits from classes 0,1'),
    50: _two_way_split({1}, {0}, {2, 3, 4, 7}, 'Class 1 splits from class 0'),
}


def _class_timeline(cls, bifurcations):
    """Chronological list of the bifurcations `cls` takes part in, each with
    its alarm vars and the union of the classes in the other split groups."""
    timeline = []
    for t in sorted(bifurcations):
        groups = bifurcations[t]['split_groups']
        for own in groups:
            if cls not in own['classes']:
                continue
            others = set()
            for g in groups:
                if g is not own:
                    others |= g['classes']
            timeline.append({'time': t, 'alarm_vars': set(own['alarm_vars']), 'split_from': others})
    return timeline


# Per-class view of the same timeline (what bifurcations affect each class),
# and the number of ground-truth bifurcations each class goes through.
per_class_bifurcations = {
    cls: _class_timeline(cls, ground_truth_bifurcations)
    for cls in sorted({
        c
        for info in ground_truth_bifurcations.values()
        for g in info['split_groups']
        for c in g['classes']
    })
}
true_bif_counts = {cls: len(events) for cls, events in per_class_bifurcations.items()}

print("Ground Truth Bifurcation Timeline:")
print("=" * 60)
for time_step, bif_info in sorted(ground_truth_bifurcations.items()):
    print(f"\nTime {time_step}: {bif_info['description']}")
    for group in bif_info['split_groups']:
        print(f"  Classes {group['classes']}: alarm vars {sorted(group['alarm_vars'])}")


def compute_expected_avg_set_size(
    y_test,
    step_list,
    ground_truth_bifurcations,
    all_classes=None,
):
    """
    Expected average prediction-set size per time step under the ground truth.

    At time t the classes are partitioned into ambiguity groups. The
    partition starts from one group holding every class, and each
    bifurcation with time <= t is applied in chronological order. A
    bifurcation replaces every group it touches by that group's intersections
    with the split subgroups, plus any classes of the group the split does not
    mention. The expected size for a sample is the size of its class's group.
    These sizes are averaged over ``y_test``.

    Parameters
    ----------
    y_test : array-like of shape (n_samples,)
        Ground-truth class labels in the test set.
    step_list : list of int
        Time steps at which to evaluate the expected average set size.
    ground_truth_bifurcations : dict
        Mapping time -> {'split_groups': [{'classes': set(...), ...}, ...], ...};
        only the 'classes' entries are read.
    all_classes : iterable, optional
        Universe of classes. If None, it is the sorted union of all class ids
        appearing in ground_truth_bifurcations.

    Returns
    -------
    expected_avg_set_size : dict
        Mapping time step -> expected average set size (float) at that time.

    Raises
    ------
    KeyError
        If a label in ``y_test`` does not belong to any ambiguity group.
    """
    labels = np.asarray(y_test)

    if all_classes is None:
        all_classes = sorted({
            cls
            for info in ground_truth_bifurcations.values()
            for grp in info["split_groups"]
            for cls in grp["classes"]
        })

    bif_times = sorted(ground_truth_bifurcations.keys())

    def refine(groups, split_info):
        """Apply one bifurcation to the current list of ambiguity groups."""
        subgroups = [set(g["classes"]) for g in split_info["split_groups"]]
        covered = set().union(*subgroups)
        refined = []
        for G in groups:
            if not (G & covered):
                # Untouched by this split: keep the group as it is.
                refined.append(G)
                continue
            refined.extend(G & sg for sg in subgroups if G & sg)
            rest = G - covered
            if rest:
                refined.append(rest)
        return refined

    def partition_at(t):
        """Ambiguity groups after every bifurcation with time <= t."""
        groups = [set(all_classes)]
        for bt in bif_times:
            if bt > t:
                break
            groups = refine(groups, ground_truth_bifurcations[bt])
        return groups

    result = {}
    for t in step_list:
        group_size = {c: len(G) for G in partition_at(t) for c in G}
        result[t] = float(np.mean([group_size[int(lbl)] for lbl in labels]))
    return result

Ground Truth Bifurcation Timeline:

Time 10: Classes 0,1,2 split from classes 3,4
  Classes {0, 1, 2}: alarm vars [0, 1, 2, 3, 4, 6]
  Classes {3, 4}: alarm vars [0, 1, 2, 3, 4, 6]

Time 20: Class 3 splits from class 4
  Classes {3}: alarm vars [3, 4, 6, 7]
  Classes {4}: alarm vars [3, 4, 6, 7]

Time 30: Class 2 splits from classes 0,1
  Classes {2}: alarm vars [0, 1, 2, 5, 6]
  Classes {0, 1}: alarm vars [0, 1, 2, 5, 6]

Time 50: Class 1 splits from class 0
  Classes {1}: alarm vars [2, 3, 4, 7]
  Classes {0}: alarm vars [2, 3, 4, 7]


In [23]:
import numpy as np
import pandas as pd
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import accuracy_score, f1_score
from joblib import Parallel, delayed
from tqdm.auto import tqdm

# Cross-validated evaluation. For every base classifier and every fold:
# fit the conformal wrapper, record step-wise prediction sets and performance,
# detect bifurcation events, and generate counterfactuals with each strategy.
skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=random_state)

# all_results[clf_name][key] -> list with one entry per fold
all_results = {}

for clf_cls in tqdm(clfs, desc="Classifiers"):
    clf_name = clf_cls.__name__
    clf_store = {
        "bifurcations": [],            # list[fold] -> list[sample] -> events
        "conformal_sets": [],          # list[fold] -> dict[step] -> array[sample] of sets
        "y_test": [],                  # list[fold] -> array[sample]
        "fold_perf": [],               # list[fold] -> dict of scalar performance metrics
        "counterfactuals_random": [],  # list[fold] -> cf_results_fold
        "counterfactuals_all": [],
        "counterfactuals_greedy": [],
    }
    all_results[clf_name] = clf_store

    for fold_idx, (train_idx, test_idx) in enumerate(skf.split(X, y), start=1):
        print(f"\n=== {clf_name} - Fold {fold_idx} / {n_splits} ===")

        X_train_full, y_train_full = X[train_idx], y[train_idx]
        X_test, y_test = X[test_idx], y[test_idx]

        # Split the training fold into a fitting part and a calibration part
        X_train, y_train, X_cali, y_cali = stratified_train_calibration_split(
            X_train_full, y_train_full, calib_frac=calib_frac, random_state=random_state + fold_idx
        )
        print(f"Train size: {len(y_train)}, Calib size: {len(y_cali)}, Test size: {len(y_test)}")

        cone_cfg = params_cone[clf_name]
        cc = ConformalAlarmClassifier(
            base_clf_cls=clf_cls,
            clf_params=params[clf_name],
            step_list=step_list,
            clf_stepwise=cone_cfg["clf_stepwise"],
            input_type=cone_cfg["input_type"],
            alpha=alpha,
            delay_timer=delay_timer,
            random_state=random_state,
        )
        cc.fit(X_train, y_train, X_cali, y_cali)

        print("Computing step-wise conformal prediction sets on test data ...")
        conformal_sets, conformal_vectors, proba_per_step = cc.predict_conformal_full_sequence(X_test)
        clf_store["conformal_sets"].append(conformal_sets)
        clf_store["y_test"].append(y_test)

        # Point-prediction quality at the last evaluated step
        n_test = len(y_test)
        final_step = step_list[-1]
        y_pred_final = np.argmax(proba_per_step[final_step], axis=1)
        acc_final = accuracy_score(y_test, y_pred_final)
        f1_final = f1_score(y_test, y_pred_final, average='weighted')

        # Step-wise set size, empirical coverage and accuracy
        avg_set_sizes = {s: np.mean([len(pset) for pset in conformal_sets[s]]) for s in step_list}
        coverages = {
            s: sum(y_test[i] in conformal_sets[s][i] for i in range(n_test)) / n_test
            for s in step_list
        }
        step_accs = [accuracy_score(y_test, np.argmax(proba_per_step[s], axis=1)) for s in step_list]

        overall_avg_acc = np.mean(step_accs)
        overall_avg_set_size = np.mean(list(avg_set_sizes.values()))
        overall_avg_coverage = np.mean(list(coverages.values()))

        # Set size an oracle would produce given the ground-truth bifurcations
        expected_sizes = compute_expected_avg_set_size(
            y_test=y_test,
            step_list=step_list,
            ground_truth_bifurcations=ground_truth_bifurcations,
        )
        average_expected_size = np.mean(list(expected_sizes.values()))

        clf_store["fold_perf"].append(dict(
            fold=fold_idx,
            acc_final=acc_final,
            f1_final=f1_final,
            avg_set_size_final=avg_set_sizes[final_step],
            coverage_final=coverages[final_step],
            overall_avg_acc=overall_avg_acc,
            overall_avg_set_size=overall_avg_set_size,
            overall_avg_coverage=overall_avg_coverage,
            average_expected_size=average_expected_size,
        ))

        print("Detecting bifurcation events ...")
        detector = BifurcationDetector(step_list)
        bif_events = detector.detect(conformal_sets)
        clf_store["bifurcations"].append(bif_events)

        print("Generating counterfactual explanations at bifurcation points ...")
        for strategy in ["random", "all", "greedy"]:
            print(f"  Strategy: {strategy}")

            cf_gen = CounterfactualGenerator(
                conformal_classifier=cc,
                X_cali=X_cali,
                y_cali=y_cali,
                step_list=step_list,
                k_neighbors=k_neighbors,
                max_iter=max_cf_iter,
                strategy=strategy,
                lambda_distance=lambda_distance,
                random_state=random_state,
            )

            # One joblib task per test sample
            sample_iter = tqdm(
                range(n_test), desc=f"{clf_name} - fold {fold_idx} - {strategy}: samples", leave=False
            )
            cf_results_fold = Parallel(n_jobs=-1)(
                delayed(cf_gen.generate_for_sample)(X_test[i], bif_events[i], y_test[i])
                for i in sample_iter
            )

            cf_key = f"counterfactuals_{strategy}"
            clf_store[cf_key].append(cf_results_fold)

        print(f"Finished fold {fold_idx} for {clf_name}")

Classifiers:   0%|          | 0/5 [00:00<?, ?it/s]


=== WDI_KNN - Fold 1 / 5 ===
Train size: 300, Calib size: 1200, Test size: 375
Computing step-wise conformal prediction sets on test data ...
Detecting bifurcation events ...
Generating counterfactual explanations at bifurcation points ...
  Strategy: random




  Strategy: all




  Strategy: greedy




Finished fold 1 for WDI_KNN

=== WDI_KNN - Fold 2 / 5 ===
Train size: 300, Calib size: 1200, Test size: 375
Computing step-wise conformal prediction sets on test data ...
Detecting bifurcation events ...
Generating counterfactual explanations at bifurcation points ...
  Strategy: random




  Strategy: all




  Strategy: greedy




Finished fold 2 for WDI_KNN

=== WDI_KNN - Fold 3 / 5 ===
Train size: 300, Calib size: 1200, Test size: 375
Computing step-wise conformal prediction sets on test data ...
Detecting bifurcation events ...
Generating counterfactual explanations at bifurcation points ...
  Strategy: random




  Strategy: all




  Strategy: greedy




Finished fold 3 for WDI_KNN

=== WDI_KNN - Fold 4 / 5 ===
Train size: 300, Calib size: 1200, Test size: 375
Computing step-wise conformal prediction sets on test data ...
Detecting bifurcation events ...
Generating counterfactual explanations at bifurcation points ...
  Strategy: random




  Strategy: all




  Strategy: greedy




Finished fold 4 for WDI_KNN

=== WDI_KNN - Fold 5 / 5 ===
Train size: 300, Calib size: 1200, Test size: 375
Computing step-wise conformal prediction sets on test data ...
Detecting bifurcation events ...
Generating counterfactual explanations at bifurcation points ...
  Strategy: random




  Strategy: all




  Strategy: greedy


Classifiers:  20%|██        | 1/5 [01:58<07:52, 118.08s/it]

Finished fold 5 for WDI_KNN

=== EAC_KNN - Fold 1 / 5 ===
Train size: 300, Calib size: 1200, Test size: 375
Computing step-wise conformal prediction sets on test data ...
Detecting bifurcation events ...
Generating counterfactual explanations at bifurcation points ...
  Strategy: random




  Strategy: all




  Strategy: greedy




Finished fold 1 for EAC_KNN

=== EAC_KNN - Fold 2 / 5 ===
Train size: 300, Calib size: 1200, Test size: 375
Computing step-wise conformal prediction sets on test data ...
Detecting bifurcation events ...
Generating counterfactual explanations at bifurcation points ...
  Strategy: random




  Strategy: all




  Strategy: greedy




Finished fold 2 for EAC_KNN

=== EAC_KNN - Fold 3 / 5 ===
Train size: 300, Calib size: 1200, Test size: 375
Computing step-wise conformal prediction sets on test data ...
Detecting bifurcation events ...
Generating counterfactual explanations at bifurcation points ...
  Strategy: random




  Strategy: all




  Strategy: greedy




Finished fold 3 for EAC_KNN

=== EAC_KNN - Fold 4 / 5 ===
Train size: 300, Calib size: 1200, Test size: 375
Computing step-wise conformal prediction sets on test data ...
Detecting bifurcation events ...
Generating counterfactual explanations at bifurcation points ...
  Strategy: random




  Strategy: all




  Strategy: greedy




Finished fold 4 for EAC_KNN

=== EAC_KNN - Fold 5 / 5 ===
Train size: 300, Calib size: 1200, Test size: 375
Computing step-wise conformal prediction sets on test data ...
Detecting bifurcation events ...
Generating counterfactual explanations at bifurcation points ...
  Strategy: random




  Strategy: all




  Strategy: greedy


Classifiers:  40%|████      | 2/5 [05:34<08:48, 176.19s/it]

Finished fold 5 for EAC_KNN

=== MBW_LR - Fold 1 / 5 ===
Train size: 300, Calib size: 1200, Test size: 375
Computing step-wise conformal prediction sets on test data ...
Detecting bifurcation events ...
Generating counterfactual explanations at bifurcation points ...
  Strategy: random




  Strategy: all




  Strategy: greedy




Finished fold 1 for MBW_LR

=== MBW_LR - Fold 2 / 5 ===
Train size: 300, Calib size: 1200, Test size: 375
Computing step-wise conformal prediction sets on test data ...
Detecting bifurcation events ...
Generating counterfactual explanations at bifurcation points ...
  Strategy: random




  Strategy: all




  Strategy: greedy




Finished fold 2 for MBW_LR

=== MBW_LR - Fold 3 / 5 ===
Train size: 300, Calib size: 1200, Test size: 375
Computing step-wise conformal prediction sets on test data ...
Detecting bifurcation events ...
Generating counterfactual explanations at bifurcation points ...
  Strategy: random




  Strategy: all




  Strategy: greedy




Finished fold 3 for MBW_LR

=== MBW_LR - Fold 4 / 5 ===
Train size: 300, Calib size: 1200, Test size: 375
Computing step-wise conformal prediction sets on test data ...
Detecting bifurcation events ...
Generating counterfactual explanations at bifurcation points ...
  Strategy: random




  Strategy: all




  Strategy: greedy




Finished fold 4 for MBW_LR

=== MBW_LR - Fold 5 / 5 ===
Train size: 300, Calib size: 1200, Test size: 375
Computing step-wise conformal prediction sets on test data ...
Detecting bifurcation events ...
Generating counterfactual explanations at bifurcation points ...
  Strategy: random




  Strategy: all




  Strategy: greedy


Classifiers:  60%|██████    | 3/5 [09:17<06:34, 197.35s/it]

Finished fold 5 for MBW_LR

=== ACM_SVM - Fold 1 / 5 ===
Train size: 300, Calib size: 1200, Test size: 375
Computing step-wise conformal prediction sets on test data ...
Detecting bifurcation events ...
Generating counterfactual explanations at bifurcation points ...
  Strategy: random




  Strategy: all




  Strategy: greedy




Finished fold 1 for ACM_SVM

=== ACM_SVM - Fold 2 / 5 ===
Train size: 300, Calib size: 1200, Test size: 375
Computing step-wise conformal prediction sets on test data ...
Detecting bifurcation events ...
Generating counterfactual explanations at bifurcation points ...
  Strategy: random




  Strategy: all




  Strategy: greedy




Finished fold 2 for ACM_SVM

=== ACM_SVM - Fold 3 / 5 ===
Train size: 300, Calib size: 1200, Test size: 375
Computing step-wise conformal prediction sets on test data ...
Detecting bifurcation events ...
Generating counterfactual explanations at bifurcation points ...
  Strategy: random




  Strategy: all




  Strategy: greedy




Finished fold 3 for ACM_SVM

=== ACM_SVM - Fold 4 / 5 ===
Train size: 300, Calib size: 1200, Test size: 375
Computing step-wise conformal prediction sets on test data ...
Detecting bifurcation events ...
Generating counterfactual explanations at bifurcation points ...
  Strategy: random




  Strategy: all




  Strategy: greedy




Finished fold 4 for ACM_SVM

=== ACM_SVM - Fold 5 / 5 ===
Train size: 300, Calib size: 1200, Test size: 375
Computing step-wise conformal prediction sets on test data ...
Detecting bifurcation events ...
Generating counterfactual explanations at bifurcation points ...
  Strategy: random




  Strategy: all




  Strategy: greedy


Classifiers:  80%|████████  | 4/5 [09:52<02:13, 133.41s/it]

Finished fold 5 for ACM_SVM

=== CASIM - Fold 1 / 5 ===
Train size: 300, Calib size: 1200, Test size: 375
Computing step-wise conformal prediction sets on test data ...
Detecting bifurcation events ...
Generating counterfactual explanations at bifurcation points ...
  Strategy: random




  Strategy: all




  Strategy: greedy




Finished fold 1 for CASIM

=== CASIM - Fold 2 / 5 ===
Train size: 300, Calib size: 1200, Test size: 375
Computing step-wise conformal prediction sets on test data ...
Detecting bifurcation events ...
Generating counterfactual explanations at bifurcation points ...
  Strategy: random




  Strategy: all




  Strategy: greedy




Finished fold 2 for CASIM

=== CASIM - Fold 3 / 5 ===
Train size: 300, Calib size: 1200, Test size: 375
Computing step-wise conformal prediction sets on test data ...
Detecting bifurcation events ...
Generating counterfactual explanations at bifurcation points ...
  Strategy: random




  Strategy: all




  Strategy: greedy




Finished fold 3 for CASIM

=== CASIM - Fold 4 / 5 ===
Train size: 300, Calib size: 1200, Test size: 375
Computing step-wise conformal prediction sets on test data ...
Detecting bifurcation events ...
Generating counterfactual explanations at bifurcation points ...
  Strategy: random




  Strategy: all




  Strategy: greedy




Finished fold 4 for CASIM

=== CASIM - Fold 5 / 5 ===
Train size: 300, Calib size: 1200, Test size: 375
Computing step-wise conformal prediction sets on test data ...
Detecting bifurcation events ...
Generating counterfactual explanations at bifurcation points ...
  Strategy: random




  Strategy: all




  Strategy: greedy


Classifiers: 100%|██████████| 5/5 [18:22<00:00, 220.60s/it]

Finished fold 5 for CASIM





In [24]:
import os

# Report fold-aggregated results per classifier and export the detailed
# counterfactuals of every strategy to CSV.
print("\n" + "="*80)
print("AGGREGATE RESULTS ACROSS ALL FOLDS")
print("="*80)

# Destination of the CSV exports; created on demand so a fresh checkout works.
cf_output_dir = "results/synthetic"

for clf_name, clf_res in all_results.items():
    print(f"\n{clf_name}:")
    print("-" * 60)

    # ------------------------------------------------------------------
    # 2.1 Aggregate performance metrics across folds
    # ------------------------------------------------------------------
    perf_df = pd.DataFrame(clf_res["fold_perf"])
    if not perf_df.empty:
        print("\nPerformance (over folds):")
        agg_perf = perf_df.drop(columns=["fold"]).agg(["mean", "std"])
        print(agg_perf)

    # ------------------------------------------------------------------
    # 2.2 Bifurcation statistics over folds
    # ------------------------------------------------------------------
    num_folds = len(clf_res["bifurcations"])
    # Each fold stores one list of detected events per test sample.
    total_bifs = sum(
        np.sum([len(events) for events in fold_bifs])
        for fold_bifs in clf_res["bifurcations"]
    )
    avg_bifs = total_bifs / num_folds if num_folds > 0 else 0

    print("\nBifurcation Detection:")
    print(f"  - Total bifurcations across {num_folds} folds: {total_bifs}")
    print(f"  - Average per fold: {avg_bifs:.2f}")

    # Expected (ground-truth) bifurcations: per-class counts summed over test labels
    total_true_bifs = 0
    for y_test_fold in clf_res["y_test"]:
        total_true_bifs += sum(true_bif_counts[int(lbl)] for lbl in y_test_fold)

    avg_true_bifs = total_true_bifs / num_folds if num_folds > 0 else 0.0

    print(f"  - Expected total bifurcations (ground truth): {total_true_bifs}")
    print(f"  - Expected average per fold: {avg_true_bifs:.2f}")

    # ------------------------------------------------------------------
    # 2.3 Aggregate conformal_sets and y_test across folds
    # ------------------------------------------------------------------
    # Flatten y_test across folds (fold-major order, matching the CF flattening below)
    y_test_all = np.concatenate(clf_res["y_test"], axis=0)

    # Flatten conformal_sets across folds per step
    conformal_sets_all = {}
    for step in step_list:
        # for each fold, conformal_sets_fold[step] holds one set per test sample
        conformal_sets_all[step] = np.concatenate(
            [conf_fold[step] for conf_fold in clf_res["conformal_sets"]],
            axis=0,
        )

    # ------------------------------------------------------------------
    # 2.4 Counterfactual statistics for each strategy (aggregated over folds)
    # ------------------------------------------------------------------
    for strategy in ["random", "all", "greedy"]:
        cf_key = f"counterfactuals_{strategy}"
        if cf_key not in clf_res:
            continue

        print(f"\nCounterfactual Generation ({strategy}):")

        # Flatten over folds -> samples (same order as y_test_all)
        cf_results_all = [
            sample_cf_results
            for fold_cf in clf_res[cf_key]      # list[fold] -> cf_results_fold
            for sample_cf_results in fold_cf    # cf_results_fold: list[samples]
        ]

        n_folds_cf = len(clf_res[cf_key])
        n_samples_cf = len(cf_results_all)
        print(f"  - Folds with CFs: {n_folds_cf}")
        print(f"  - Total samples (with possibly multiple CFs per sample): {n_samples_cf}")

        # Run compute_counterfactual_metrics ONCE for all folds combined
        cf_metrics = compute_counterfactual_metrics(
            cf_results_all,          # flattened over folds
            conformal_sets_all,      # flattened conformal sets per step
            y_test_all,              # flattened labels
            ground_truth_bifurcations,
            step_list,
        )

        # Long-format table (metric, value) so each metric can have its own length
        rows = [
            {"metric": metric_name, "value": v}
            for metric_name, vals in cf_metrics.items()
            for v in vals
        ]
        if rows:
            df_cf_metrics = pd.DataFrame(rows)
            agg_cf = df_cf_metrics.groupby("metric")["value"].agg(["mean", "std", "count"])
            print("\n  Counterfactual metrics (aggregated):")
            print(agg_cf)

        # Save the detailed CFs per classifier & strategy (flattened to CSV).
        # `sample_idx` is the position within that fold's test set.
        detailed_rows = []
        for fold_idx, fold_cf in enumerate(clf_res[cf_key], start=1):
            for sample_idx, sample_cf_results in enumerate(fold_cf):
                for cf in sample_cf_results:
                    row = {
                        "classifier": clf_name,
                        "strategy": strategy,
                        "fold": fold_idx,
                        "sample_idx": sample_idx,
                    }
                    for k, v in cf.items():
                        # Collections are stringified so they fit in one CSV cell
                        if isinstance(v, (set, list, tuple, np.ndarray)):
                            row[k] = str(list(v))
                        else:
                            row[k] = v
                    detailed_rows.append(row)

        if detailed_rows:
            df_cf = pd.DataFrame(detailed_rows)
            os.makedirs(cf_output_dir, exist_ok=True)
            out_name = f"{cf_output_dir}/{clf_name}_counterfactuals_{strategy}.csv"
            df_cf.to_csv(out_name, index=False)
            print(f"  - Saved detailed counterfactuals to: {out_name}")
        else:
            print("  - No counterfactuals to save for this strategy.")


AGGREGATE RESULTS ACROSS ALL FOLDS

WDI_KNN:
------------------------------------------------------------

Performance (over folds):
      acc_final  f1_final  avg_set_size_final  coverage_final  \
mean   0.396267  0.231099            4.088533        0.998933   
std    0.005530  0.003311            0.335912        0.001461   

      overall_avg_acc  overall_avg_set_size  overall_avg_coverage  \
mean         0.302055              4.726703              0.999739   
std          0.035998              0.202304              0.000327   

      average_expected_size  
mean               1.705882  
std                0.000000  

Bifurcation Detection:
  - Total bifurcations across 5 folds: 2940
  - Average per fold: 588.00
  - Expected total bifurcations (ground truth): 4500
  - Expected average per fold: 900.00

Counterfactual Generation (random):
  - Folds with CFs: 5
  - Total samples (with possibly multiple CFs per sample): 1875
Computing counterfactual metrics ...

  Counterfactual metric