# 04b Validate Ensembles with Indicator B at q* = 0.9

This notebook implements **Stage 2B** of the ESTIMATOR_RANKING_PLAN:

1. For each (Q, k) configuration, compute indicator B at fixed q* = 0.9
2. Compute B̄_0.9 (proportion of folds where top subgroup outperforms complement)
3. Compute Δτ̄_0.9 (mean treatment effect difference)
4. Identify optimal (Q*, k*) configuration

**Fixed validation quantile:** q* = 0.9 (top 10% of predicted effects)

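**Selection criteria:**
- Primary: B̄_0.9 = 1.0 (strict), relaxed to B̄_0.9 ≥ 0.95 if no configuration passes the strict test
- Secondary: largest Δτ̄_0.9 among passing configurations
- Fallback: if no configuration passes either threshold, the configuration with the highest B̄_0.9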
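
The tiered rule can be sketched compactly as follows. This is an illustrative sketch, not the notebook's own selection cell: `select_config` is a hypothetical helper, the column names `B_bar` and `delta_bar` are simplified stand-ins for the notebook's `B_bar_0.9` and `delta_bar_0.9`, and the toy table is made up.

```python
import pandas as pd

def select_config(df, relaxed_threshold=0.95):
    """Hypothetical helper: tiered selection over a table with B_bar and delta_bar columns."""
    df = df.dropna(subset=["B_bar", "delta_bar"])
    if df.empty:
        return None, "none"
    # Tier 1: strict criterion (B_bar equal to 1 up to floating-point tolerance)
    strict = df[df["B_bar"] >= 1.0 - 1e-9]
    if not strict.empty:
        return strict.loc[strict["delta_bar"].idxmax()], "strict"
    # Tier 2: relaxed criterion
    relaxed = df[df["B_bar"] >= relaxed_threshold]
    if not relaxed.empty:
        return relaxed.loc[relaxed["delta_bar"].idxmax()], "relaxed"
    # Tier 3: nothing passes, so fall back to the highest B_bar
    return df.loc[df["B_bar"].idxmax()], "best_available"

# Made-up configurations for illustration only
toy = pd.DataFrame({
    "Q": ["Q1", "Q2", "Q3"],
    "k": [3, 5, 10],
    "B_bar": [0.75, 1.0, 1.0],
    "delta_bar": [0.040, 0.010, 0.025],
})
best, criterion = select_config(toy)
print(best["Q"], best["k"], criterion)
```

Among the two toy rows that pass the strict test, the one with the larger `delta_bar` is chosen; the row with the largest `delta_bar` overall is ignored because it fails the primary criterion.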

In [16]:
from pathlib import Path

# Ensure output directories exist
OUTPUT_PATH = Path("output/analysis")
INTERMEDIATE_PATH = Path("output/intermediate/grid_search")
for p in [OUTPUT_PATH, INTERMEDIATE_PATH]:
    p.mkdir(parents=True, exist_ok=True)

In [17]:
# Standard imports
import numpy as np
import pandas as pd
import joblib
import warnings
from functools import lru_cache
warnings.filterwarnings("ignore")

# Custom methods
from methods.causal_functions import get_subgroup_CATE, get_subgroup_t_statistic

## 1. Configuration

In [18]:
# Analysis configuration
OUTCOME_NAME = "fausebal"
DIR_NEG = False  # We want POSITIVE treatment effects

# Fixed validation quantile
Q_STAR = 0.9  # Top 10% (q* = 0.9 means top 10% when dir_neg=False)

# Grid search parameters (same as 04a)
Q_SETS = {
    "Q1": np.array([0.1]),
    "Q2": np.array([0.1, 0.2]),
    "Q3": np.array([0.1, 0.2, 0.3]),
    "Q4": np.array([0.1, 0.2, 0.3, 0.4]),
    "Q5": np.array([0.1, 0.2, 0.3, 0.4, 0.5]),
}

K_VALUES = list(range(1, 11))  # k ∈ {1, 2, ..., 10}

print(f"Outcome: {OUTCOME_NAME}")
# round() rather than int(): (1 - 0.9) * 100 is 9.999... in floating point and would truncate to 9
print(f"Fixed validation quantile: q* = {Q_STAR} (top {round((1 - Q_STAR) * 100)}%)")
print(f"Direction: {'Negative' if DIR_NEG else 'Positive'} effects preferred")

Outcome: fausebal
Fixed validation quantile: q* = 0.9 (top 10%)
Direction: Positive effects preferred
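
The percentage in the banner is derived from 1 − q*. A minimal demonstration of why truncating that value with `int()` is fragile, given that 0.9 has no exact binary floating-point representation:

```python
q_star = 0.9
raw = (1 - q_star) * 100

print(raw)         # slightly below 10 because of floating-point representation
print(int(raw))    # int() truncates toward zero
print(round(raw))  # round() goes to the nearest integer
```

Rounding before formatting keeps the displayed percentage aligned with the intended "top 10%".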


## 2. Load Data from 04a

In [19]:
# Load fitted libraries
fitted_libraries = joblib.load(f"output/analysis/{OUTCOME_NAME}/{OUTCOME_NAME}_fitted_libraries.pkl")
print("Loaded fitted libraries")

# Load ensemble table from 04a
ensemble_table = joblib.load(INTERMEDIATE_PATH / f"{OUTCOME_NAME}_ensemble_table.pkl")
print("Loaded ensemble table")

# Perturbation names and fold count, kept for reference only:
# the loops below read both from fitted_libraries and the estimators themselves
perturbations = ["none", "cv_0", "cv_1"]
n_folds = 4

Loaded fitted libraries
Loaded ensemble table


## 3. Compute Indicator B at q* = 0.9

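For each ensemble we average its members' CATE predictions within every (perturbation, fold). The subgroup cutoff is the q* quantile of the ensemble predictions on that fold's training samples (for negative effects, the 1 − q* quantile, taking the samples below it). Validation samples beyond the cutoff form the top subgroup, the remaining validation samples form the complement, and both are evaluated on the validation samples only: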
- **B_0.9 = 1** if τ̂_top ≥ τ̂_complement (positive effects) or τ̂_top ≤ τ̂_complement (negative effects)
- **Δτ_0.9 = τ̂_top − τ̂_complement**

These per-fold metrics feed into the aggregation step below.
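
As a minimal sketch of the per-fold computation, the toy example below derives the cutoff from training predictions, splits the validation samples into a top subgroup and its complement, and compares the two effect estimates. The data are synthetic, and `diff_in_means` is a hypothetical stand-in for `get_subgroup_CATE`, whose actual implementation lives in `methods.causal_functions` and is not shown in this notebook.

```python
import numpy as np

def diff_in_means(y, t, mask):
    """Hypothetical stand-in for get_subgroup_CATE: treated mean minus control mean."""
    y_sub, t_sub = y[mask], t[mask]
    return y_sub[t_sub == 1].mean() - y_sub[t_sub == 0].mean()

rng = np.random.default_rng(0)
n = 2000
tau_hat = rng.normal(size=n)        # synthetic ensemble CATE predictions
t = rng.integers(0, 2, size=n)      # synthetic treatment assignment
# Synthetic outcome whose true effect grows with the predicted effect
y = 0.5 * np.maximum(tau_hat, 0) * t + rng.normal(scale=0.1, size=n)

val_mask = np.arange(n) % 4 == 0    # one of four folds held out
train_mask = ~val_mask

q_star = 0.9
# Cutoff comes from the training predictions only
threshold = np.quantile(tau_hat[train_mask], q_star)
# Validation units above the cutoff. build_subgroup_masks additionally caps
# the subgroup at the training maximum; that cap is omitted here for brevity.
top_mask = val_mask & (tau_hat > threshold)
complement_mask = val_mask & ~top_mask

tau_top = diff_in_means(y, t, top_mask)
tau_complement = diff_in_means(y, t, complement_mask)
delta = tau_top - tau_complement
B = int(tau_top >= tau_complement)  # positive-effect direction (DIR_NEG = False)
print(f"B = {B}, delta = {delta:.3f}")
```

Because the cutoff is computed on training predictions, the share of validation samples in the top subgroup is only approximately 10%.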

In [20]:
def build_subgroup_masks(tau_values, train_mask, val_mask, q_star, dir_neg):
    """Build boolean masks for the desired quantile subgroup based on ensemble predictions."""
    tau_train = tau_values[train_mask]
    if tau_train.size == 0:
        return None, None
    if dir_neg:
        q_bot, q_top = 0.0, max(0.0, 1.0 - q_star)
    else:
        q_bot, q_top = q_star, 1.0
    quantile_top = np.quantile(tau_train, min(q_top, 1.0))
    if q_bot <= 0:
        base_indicator = tau_values <= quantile_top
    else:
        quantile_bot = np.quantile(tau_train, q_bot)
        base_indicator = (tau_values > quantile_bot) & (tau_values <= quantile_top)
    subgroup_mask = base_indicator & val_mask
    complement_mask = val_mask & (~subgroup_mask)
    return subgroup_mask, complement_mask


def has_treated_and_control(t_vector, mask):
    """Ensure both treatment arms are represented within the mask."""
    subset = t_vector[mask]
    if subset.size == 0:
        return False
    return np.any(subset == 1) and np.any(subset == 0)


def canonical_key(ensemble):
    """Return a sorted tuple key for caching ensemble metrics."""
    return tuple(sorted(ensemble))


@lru_cache(maxsize=None)
def compute_ensemble_fold_metrics(ensemble_key):
    """Compute raw B values, delta values, and t-statistics for an ensemble across all folds."""
    if len(ensemble_key) == 0:
        # Three arrays, matching the normal return value and the caller's unpacking
        return np.array([]), np.array([]), np.array([])
    B_values = []
    delta_values = []
    t_stat_values = []
    for pert_name, library in fitted_libraries.items():
        missing = [est for est in ensemble_key if est not in library]
        if missing:
            continue
        ref_estimator = library[ensemble_key[0]]
        y = ref_estimator.y
        t = ref_estimator.t
        n_splits = ref_estimator.n_splits
        for fold in range(n_splits):
            val_mask = ref_estimator.results[fold].val_indicator
            train_mask = ref_estimator.results[fold].train_indicator
            tau_stack = np.stack([library[est].results[fold].tau for est in ensemble_key])
            ensemble_tau = np.mean(tau_stack, axis=0)
            subgroup_mask, complement_mask = build_subgroup_masks(
                ensemble_tau, train_mask, val_mask, Q_STAR, DIR_NEG
            )
            if subgroup_mask is None or complement_mask is None:
                continue
            if subgroup_mask.sum() == 0 or complement_mask.sum() == 0:
                continue
            if not has_treated_and_control(t, subgroup_mask):
                continue
            if not has_treated_and_control(t, complement_mask):
                continue
            tau_top = get_subgroup_CATE(y, t, subgroup_mask)
            tau_complement = get_subgroup_CATE(y, t, complement_mask)
            if not np.isfinite(tau_top) or not np.isfinite(tau_complement):
                continue
            if DIR_NEG:
                B = 1 if tau_top <= tau_complement else 0
            else:
                B = 1 if tau_top >= tau_complement else 0
            B_values.append(B)
            delta_values.append(tau_top - tau_complement)
            # Compute t-statistic for the subgroup (CATE_sg - ATE) / SE
            t_stat_fold = get_subgroup_t_statistic(y, t, subgroup_mask, val_mask)
            t_stat_values.append(t_stat_fold if np.isfinite(t_stat_fold) else np.nan)
    return np.array(B_values), np.array(delta_values), np.array(t_stat_values)


def summarize_ensemble_metrics(ensemble):
    """Aggregate B, delta, and t-statistic for a given ensemble list."""
    key = canonical_key(tuple(ensemble))
    B_vals, delta_vals, t_stat_vals = compute_ensemble_fold_metrics(key)
    finite_t = t_stat_vals[np.isfinite(t_stat_vals)]
    return {
        "B_bar": float(np.mean(B_vals)),
        "delta_bar": float(np.mean(delta_vals)),
        "delta_std": float(np.std(delta_vals, ddof=0)),
        "t_bar": float(np.mean(finite_t)) if finite_t.size > 0 else np.nan,
        "n_folds_used": int(delta_vals.size),
    }


In [21]:
# Compute metrics for unique ensembles for reference
unique_records = []
seen_keys = set()

for Q_name, config in ensemble_table.items():
    for k, ensemble in config.items():
        if not ensemble:
            continue
        key = canonical_key(tuple(ensemble))
        if key in seen_keys:
            continue
        seen_keys.add(key)
        metrics = summarize_ensemble_metrics(ensemble)
        unique_records.append({
            "ensemble": ", ".join(ensemble),
            "size": len(ensemble),
            "B_bar_0.9": metrics["B_bar"],
            "delta_bar_0.9": metrics["delta_bar"],
            "delta_std_0.9": metrics["delta_std"],
            "t_bar_0.9": metrics["t_bar"],
            "n_folds_used": metrics["n_folds_used"],
        })

ensemble_metrics_df = pd.DataFrame(unique_records)
if not ensemble_metrics_df.empty:
    ensemble_metrics_df = ensemble_metrics_df.sort_values(
        ["B_bar_0.9", "delta_bar_0.9"], ascending=[False, False]
    )
    csv_path = INTERMEDIATE_PATH / f"{OUTCOME_NAME}_ensemble_metrics.csv"
    ensemble_metrics_df.to_csv(csv_path, index=False)
    print(f"Saved ensemble metrics to {csv_path}")
    ensemble_metrics_df.head()
else:
    print("No non-empty ensembles available for metric computation.")

Saved ensemble metrics to output/intermediate/grid_search/fausebal_ensemble_metrics.csv


## 4. Validate Each (Q, k) Configuration

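For each (Q, k) configuration, look up its ensemble and compute the aggregate B̄_0.9 and Δτ̄_0.9 across all (perturbation, fold) pairs. Configurations with an empty ensemble are recorded with NaN metrics and fail both criteria.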
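
To illustrate the aggregation, the sketch below averages hypothetical per-fold values. The arrays are made up for illustration; only the formulas (mean of B, mean and population standard deviation of Δτ, and the two pass thresholds) mirror `summarize_ensemble_metrics` and the flags computed in the next cell.

```python
import numpy as np

# Hypothetical per-fold results: 3 perturbations x 4 folds = 12 evaluations
B_vals = np.array([1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1])
delta_vals = np.array([0.03, 0.01, 0.02, 0.04, 0.01, -0.02,
                       0.02, 0.03, 0.01, 0.02, 0.00, 0.01])

B_bar = float(np.mean(B_vals))                 # share of folds where top beats complement
delta_bar = float(np.mean(delta_vals))         # mean effect difference
delta_std = float(np.std(delta_vals, ddof=0))  # population standard deviation

passes_strict = bool(np.isclose(B_bar, 1.0))
passes_relaxed = bool(B_bar >= 0.95)

print(f"B_bar = {B_bar:.4f}, delta_bar = {delta_bar:.4f}, delta_std = {delta_std:.4f}")
print(f"strict: {passes_strict}, relaxed: {passes_relaxed}")
```

With 12 evaluations, a single failing fold gives B̄ = 11/12 ≈ 0.9167, which already falls short of the relaxed 0.95 threshold. At this fold count the relaxed criterion can therefore only be met when every fold passes.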

In [22]:
# Compute validation metrics for each (Q, k) configuration
validation_results = []

for Q_name in Q_SETS.keys():
    for k in K_VALUES:
        ensemble = ensemble_table[Q_name][k]

        if not ensemble:
            validation_results.append({
                "Q": Q_name,
                "k": k,
                "n_estimators": 0,
                "estimators": "",
                "B_bar_0.9": np.nan,
                "delta_bar_0.9": np.nan,
                "delta_std_0.9": np.nan,
                "t_bar_0.9": np.nan,
                "n_folds_used": 0,
                "passes_strict": False,
                "passes_relaxed": False,
            })
            continue

        metrics = summarize_ensemble_metrics(ensemble)
        B_bar = metrics["B_bar"]
        delta_bar = metrics["delta_bar"]
        delta_std = metrics["delta_std"]
        t_bar = metrics["t_bar"]
        n_folds_used = metrics["n_folds_used"]

        passes_strict = np.isclose(B_bar, 1.0) if not np.isnan(B_bar) else False
        passes_relaxed = (B_bar >= 0.95) if not np.isnan(B_bar) else False

        validation_results.append({
            "Q": Q_name,
            "k": k,
            "n_estimators": len(ensemble),
            "estimators": ", ".join(ensemble),
            "B_bar_0.9": B_bar,
            "delta_bar_0.9": delta_bar,
            "delta_std_0.9": delta_std,
            "t_bar_0.9": t_bar,
            "n_folds_used": n_folds_used,
            "passes_strict": passes_strict,
            "passes_relaxed": passes_relaxed,
        })

validation_df = pd.DataFrame(validation_results)

# Save validation results
validation_df.to_csv(INTERMEDIATE_PATH / f"{OUTCOME_NAME}_validation_results.csv", index=False)
print(f"Saved validation results to {INTERMEDIATE_PATH / f'{OUTCOME_NAME}_validation_results.csv'}")

Saved validation results to output/intermediate/grid_search/fausebal_validation_results.csv


In [23]:
# Display validation results
print("Validation Results for (Q, k) Configurations:")
print("="*80)

valid_df = validation_df[validation_df["n_estimators"] > 0].copy()
valid_df = valid_df.sort_values(["B_bar_0.9", "delta_bar_0.9"], ascending=[False, False])
columns = [
    "Q",
    "k",
    "n_estimators",
    "B_bar_0.9",
    "delta_bar_0.9",
    "delta_std_0.9",
    "t_bar_0.9",
    "n_folds_used",
    "passes_strict",
    "passes_relaxed",
]

if valid_df.empty:
    print("No non-empty ensembles available.")
else:
    print(valid_df[columns].to_string(index=False))

Validation Results for (Q, k) Configurations:
 Q  k  n_estimators  B_bar_0.9  delta_bar_0.9  delta_std_0.9  t_bar_0.9  n_folds_used  passes_strict  passes_relaxed
Q3 10             5   0.916667       0.014296       0.024718   0.381860            12          False           False
Q4 10             5   0.916667       0.014296       0.024718   0.381860            12          False           False
Q4  9             4   0.833333       0.029442       0.037014   0.785744            12          False           False
Q5  8             4   0.833333       0.029442       0.037014   0.785744            12          False           False
Q5  9             5   0.833333       0.028368       0.026570   0.758069            12          False           False
Q5 10             5   0.833333       0.028368       0.026570   0.758069            12          False           False
Q3  7             2   0.750000       0.031610       0.044946   0.845787            12          False           False
Q4  6             

## 5. Select Optimal (Q*, k*) Configuration

In [24]:
# Selection procedure:
# 1. Primary: B̄_0.9 = 1.0 (strict criterion)
# 2. Secondary: Largest Δτ̄_0.9 among passing configurations

optimal = None
criterion_used = "none"

strict_passing = validation_df[validation_df["passes_strict"] == True]

if not strict_passing.empty:
    print("Configurations passing STRICT criterion (B̄_0.9 = 1.0):")
    print(
        strict_passing[
            ["Q", "k", "n_estimators", "B_bar_0.9", "delta_bar_0.9", "estimators"]
        ].to_string(index=False)
    )

    # Select the one with largest delta
    optimal = strict_passing.loc[strict_passing["delta_bar_0.9"].idxmax()]
    criterion_used = "strict"
else:
    print("No configurations pass STRICT criterion. Trying RELAXED criterion (B̄_0.9 >= 0.95)...")
    relaxed_passing = validation_df[validation_df["passes_relaxed"] == True]

    if not relaxed_passing.empty:
        print("\nConfigurations passing RELAXED criterion:")
        print(
            relaxed_passing[
                ["Q", "k", "n_estimators", "B_bar_0.9", "delta_bar_0.9", "estimators"]
            ].to_string(index=False)
        )

        # Select the one with largest delta
        optimal = relaxed_passing.loc[relaxed_passing["delta_bar_0.9"].idxmax()]
        criterion_used = "relaxed"
    else:
        print("\nNo configurations pass RELAXED criterion either.")
        print("Selecting configuration with highest B̄_0.9...")

        valid_configs = validation_df[validation_df["n_estimators"] > 0]
        if valid_configs.empty:
            print("No valid configurations with non-empty ensembles were found.")
            optimal = None
            criterion_used = "none"
        else:
            optimal = valid_configs.loc[valid_configs["B_bar_0.9"].idxmax()]
            criterion_used = "best_available"

No configurations pass STRICT criterion. Trying RELAXED criterion (B̄_0.9 >= 0.95)...

No configurations pass RELAXED criterion either.
Selecting configuration with highest B̄_0.9...


In [25]:
# Display optimal configuration
print("\n" + "="*80)
print("OPTIMAL CONFIGURATION (Q*, k*)")
print("="*80)

if optimal is None:
    print("\nNo optimal configuration could be selected.")
else:
    print(f"\nCriterion used: {criterion_used}")
    print(f"\nQ* = {optimal['Q']}")
    print(f"k* = {optimal['k']}")
    print(f"Number of estimators: {optimal['n_estimators']}")
    print(f"Estimators: {optimal['estimators']}")
    print(f"\nValidation metrics at q* = {Q_STAR}:")
    print(f"  B̄_0.9 = {optimal['B_bar_0.9']:.4f}")
    print(f"  Δτ̄_0.9 = {optimal['delta_bar_0.9']:.4f}")
    print(f"  t̄_0.9 = {optimal['t_bar_0.9']:.4f}")



OPTIMAL CONFIGURATION (Q*, k*)

Criterion used: best_available

Q* = Q3
k* = 10
Number of estimators: 5
Estimators: x_xgb, causal_tree_1, x_rf, t_rf, x_logistic

Validation metrics at q* = 0.9:
  B̄_0.9 = 0.9167
  Δτ̄_0.9 = 0.0143
  t̄_0.9 = 0.3819


In [26]:
# Save optimal configuration
optimal_config = None
if optimal is None:
    print("\nSkipping save because no optimal configuration was selected.")
else:
    optimal_config = {
        "Q_star": optimal["Q"],
        "k_star": optimal["k"],
        "q_star": Q_STAR,
        "estimators": optimal["estimators"].split(", ") if optimal["estimators"] else [],
        "B_bar_0.9": optimal["B_bar_0.9"],
        "delta_bar_0.9": optimal["delta_bar_0.9"],
        "t_bar_0.9": optimal["t_bar_0.9"],
        "criterion_used": criterion_used,
    }

    joblib.dump(optimal_config, INTERMEDIATE_PATH / f"{OUTCOME_NAME}_optimal_config.pkl")
    print(f"\nSaved optimal configuration to {INTERMEDIATE_PATH / f'{OUTCOME_NAME}_optimal_config.pkl'}")


Saved optimal configuration to output/intermediate/grid_search/fausebal_optimal_config.pkl


## 6. Summary

**Outputs saved:**
1. `{OUTCOME_NAME}_ensemble_metrics.csv` - B̄_0.9 and Δτ̄_0.9 for each unique ensemble
2. `{OUTCOME_NAME}_validation_results.csv` - Validation metrics for each (Q, k) configuration
3. `{OUTCOME_NAME}_optimal_config.pkl` - Optimal (Q*, k*) configuration

**Next steps:**
- Notebook 05b: Fit final ensemble on full training/validation sample
- Validate on holdout test set

In [27]:
# Final summary
print("="*80)
print("FINAL SUMMARY")
print("="*80)
print(f"\nFixed validation quantile: q* = {Q_STAR}")

if optimal_config is None:
    print("\nNo optimal configuration was saved.")
else:
    print("\nOptimal configuration:")
    print(f"  Q* = {optimal_config['Q_star']}")
    print(f"  k* = {optimal_config['k_star']}")
    print(f"  Estimators: {optimal_config['estimators']}")
    print("\nValidation metrics:")
    print(f"  B̄_0.9 = {optimal_config['B_bar_0.9']:.4f}")
    print(f"  Δτ̄_0.9 = {optimal_config['delta_bar_0.9']:.4f}")
    print(f"  t̄_0.9 = {optimal_config['t_bar_0.9']:.4f}")
    print(f"\nCriterion used: {optimal_config['criterion_used']}")

FINAL SUMMARY

Fixed validation quantile: q* = 0.9

Optimal configuration:
  Q* = Q3
  k* = 10
  Estimators: ['x_xgb', 'causal_tree_1', 'x_rf', 't_rf', 'x_logistic']

Validation metrics:
  B̄_0.9 = 0.9167
  Δτ̄_0.9 = 0.0143
  t̄_0.9 = 0.3819

Criterion used: best_available
