In [1]:
# %% [markdown]
# # Complete Causal Inference Pipeline for Diagnostic Escalation
# 
# This notebook walks through the entire pipeline, from loading the raw ARS data to estimating the causal effect of resistance on diagnostic escalation, comparing standard and joint selection models, applying Bayesian shrinkage, and exploring heterogeneity with causal forests.
# 
# **Prerequisites**: 
# - Python environment with `escalation-causal` installed.
# - ARS dataset in Parquet format (adjust the path below).
# - Filter configuration JSON file (adjust path).
# 
# **Notation**:
# - \(T_{code}\): tested indicator (1 if tested, 0 otherwise)
# - \(R_{code}\): resistance indicator (1 if tested and resistant, 0 otherwise)
# - \(A\): trigger antibiotic (treatment)
# - \(D\): target antibiotic (outcome)
# - \(C\): context (lab, pathogen group, year)
# - \(Y^*_D\): escalation score for target \(D\)
# - \(\psi\): risk difference \(\mathbb{E}[Y^*|A=1] - \mathbb{E}[Y^*|A=0]\)

# %% [markdown]
# ## 1. Setup and Imports

# %%
import sys
import logging
from pathlib import Path

import pandas as pd
import numpy as np
from sklearn.model_selection import GroupShuffleSplit

# Our package modules
from src.controllers.DataLoader import DataLoader
from src.controllers.filters.FilteringStrategy import FilterConfig
from src.controllers.escalation_causal.config.settings import (
    RunConfig, SplitConfig, CovariateConfig, PolicyConfig, NuisanceConfig, TMLEConfig
)
from src.controllers.escalation_causal.pipeline import CausalPipeline
from src.controllers.escalation_causal.screening.phase1_screener import Phase1Screener, Phase1Config
from src.controllers.escalation_causal.utils.io import save_results
from src.controllers.escalation_causal.nuisance.joint_selection import JointSelectionModel
from src.controllers.escalation_causal.multiple_comparison.bayesian_shrinkage import BayesianShrinkage
from src.controllers.escalation_causal.heterogeneity.causal_forest import CausalForestWrapper

# Plotting
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import matplotlib.pyplot as plt

# Set up logging
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)

In [None]:
# %% [markdown]
# ## 2. Load and Filter Real Data
# 
# We use the `DataLoader` to read the ARS Parquet file and apply the inclusion/exclusion criteria defined in a JSON configuration file.
# 
# **Why this step?**  
# Raw data often contain screening cultures, repeated isolates, or other records that are not suitable for causal analysis. Filtering ensures we work with a clean, clinically relevant cohort.
# 
# **What if you skip filtering?**  
# You might include screening cultures or repeated isolates, violating the independence assumption and biasing your results.

# %%
data_path = "./datasets/structured/dataset_parquet"          # <-- UPDATE THIS PATH
filter_config_path = "./src/controllers/filters/config_all.json"  # <-- UPDATE THIS PATH

loader = DataLoader(data_path, strict=False, normalize_on_load=True)
filter_config = FilterConfig.from_json(filter_config_path)

df, meta = loader.get_cohort(
    filter_config=filter_config,
    apply_exclusions=True,
    verbose=True,
)

print(f"\n✅ Final cohort: {len(df)} isolates from {meta.n_labs} laboratories")
print(f"   Year range: {meta.yearmonth_min} – {meta.yearmonth_max}")

# %% [markdown]
# ## 3. Build Antibiotic Flags
# 
# The `DataLoader` generates binary flags for testing (`_T`) and resistance (`_R`) for all antibiotics. This gives us a clean matrix for analysis.

# %%
all_codes = sorted(loader.code_to_base.keys())
print(f"Total antibiotic codes: {len(all_codes)}")

flags = loader.get_abx_flags(
    df,
    codes=all_codes,
    recode_mode="R_vs_nonR",
    drop_I=True,
)

print(f"Flags shape: {flags.shape}")
flags.head()



APPLYING FILTERS
[Pathogen:equals] 328,914 → 26,765 (8.1% retained, 302,149 dropped)
[CSQMG:equals] 26,765 → 21,937 (82.0% retained, 4,828 dropped)
[ARS_WardType:in] 21,937 → 21,462 (97.8% retained, 475 dropped)
[CareType:in] 21,462 → 21,462 (100.0% retained, 0 dropped)
[Year:range] 21,462 → 21,462 (100.0% retained, 0 dropped)

 FILTERING SUMMARY
Initial rows:     328,914
Final rows:       21,462
Total removed:    307,452
Overall retained: 6.5%


APPLYING FILTERS (All Isolates for Klebsiella pneumoniae Analysis)
[exclusions] dropped=0 detail={'IsSpecificlyExcluded_Screening': 0, 'IsSpecificlyExcluded_Pathogen': 0, 'IsSpecificlyExcluded_PathogenevidenceNegative': 0}
[Pathogen:equals] 328,914 → 26,765 (8.1% retained, 302,149 dropped)
[CSQMG:equals] 26,765 → 21,937 (82.0% retained, 4,828 dropped)
[ARS_WardType:in] 21,937 → 21,462 (97.8% retained, 475 dropped)
[CareType:in] 21,462 → 21,462 (100.0% retained, 0 dropped)
[Year:range] 21,462 → 21,462 (100.0% retained, 0 dropped)


✅ Final coh

Unnamed: 0,AMC_T,AMC_R,AMK_T,AMK_R,AMP_T,AMP_R,AMS_T,AMS_R,AMX_T,AMX_R,...,TGC_T,TGC_R,TOB_T,TOB_R,TPL_T,TPL_R,TRP_T,TRP_R,VAN_T,VAN_R
0,1,0,1,0,1,1,0,0,0,0,...,0,0,1,0,0,0,1,0,0,0
1,1,0,1,0,1,1,0,0,0,0,...,0,0,1,0,0,0,1,0,0,0
2,1,0,1,0,1,1,0,0,0,0,...,0,0,1,0,0,0,1,0,0,0
3,1,0,1,0,1,1,0,0,0,0,...,0,0,1,0,0,0,1,0,0,0
4,1,0,1,0,1,1,0,0,0,0,...,0,0,1,0,0,0,1,0,0,0


In [4]:
# %% [markdown]
# ## 4. Split Data into Discovery and Estimation Sets
# 
# We split by laboratory (`Anonymized_Lab`) so that all isolates from a given lab stay together. This prevents information leakage between the two sets.
# 
# **Why split by lab?**  
# Testing protocols vary between laboratories. By splitting at the lab level, the routine policy learned on discovery labs is evaluated on entirely different labs, mimicking external validation.
# 
# **What if you use a random split?**  
# Then isolates from the same lab could appear in both discovery and estimation, causing the screening to be overly optimistic (winner’s curse) and the policy to be evaluated in‑sample.

# %%
group_col = "Anonymized_Lab"
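# Fill missing labs before casting: astype(str) first would turn NaN into the string "nan" and make fillna a no-op.
groups = df[group_col].astype("object").fillna("NA").astype(str).to_numpy()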

gss = GroupShuffleSplit(n_splits=1, test_size=0.5, random_state=42)
discovery_idx, estimation_idx = next(gss.split(df.index, groups=groups))

discovery_df = df.iloc[discovery_idx].copy()
discovery_flags = flags.iloc[discovery_idx].copy()
estimation_df = df.iloc[estimation_idx].copy()
estimation_flags = flags.iloc[estimation_idx].copy()

print(f"Discovery set: {len(discovery_df)} isolates from {discovery_df[group_col].nunique()} labs")
print(f"Estimation set: {len(estimation_df)} isolates from {estimation_df[group_col].nunique()} labs")

Discovery set: 7706 isolates from 38 labs
Estimation set: 13756 isolates from 39 labs
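
# %% [markdown]
# The two sets differ in size (7,706 vs 13,756 isolates) even though `test_size=0.5`, because `GroupShuffleSplit` applies the fraction to groups (labs), not to rows. The sketch below reproduces that behaviour with the standard library only. The lab names and sizes are made up for illustration, and `split_by_group` is a hypothetical helper, not part of the package.

```python
import random


def split_by_group(groups, test_fraction=0.5, seed=42):
    """Return (first_idx, second_idx) so that no group appears on both sides."""
    unique = sorted(set(groups))
    rng = random.Random(seed)
    rng.shuffle(unique)
    n_second = round(len(unique) * test_fraction)
    second_groups = set(unique[:n_second])
    first_idx = [i for i, g in enumerate(groups) if g not in second_groups]
    second_idx = [i for i, g in enumerate(groups) if g in second_groups]
    return first_idx, second_idx


# Hypothetical cohort: four labs with very different volumes.
lab_sizes = {"lab_A": 500, "lab_B": 50, "lab_C": 300, "lab_D": 20}
groups = [lab for lab, n in lab_sizes.items() for _ in range(n)]

disc_idx, est_idx = split_by_group(groups)
disc_labs = {groups[i] for i in disc_idx}
est_labs = {groups[i] for i in est_idx}

print(f"Discovery: {len(disc_idx)} isolates from {len(disc_labs)} labs")
print(f"Estimation: {len(est_idx)} isolates from {len(est_labs)} labs")
print("Shared labs:", disc_labs & est_labs)  # empty set: no lab leaks across the split
```

# %% [markdown]
# Half of the labs land on each side, but the isolate counts depend on which labs were drawn. If balanced row counts matter, check them after splitting and try a different `random_state`.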


In [5]:
# %% [markdown]
# ## 5. Phase 1 Screening on the Discovery Set
# 
# We use the `Phase1Screener` to identify the top 100 trigger–target pairs (by absolute log‑odds ratio) that pass crude association tests with FDR correction.
# 
# **Why screening?**  
# With many antibiotics, the number of possible pairs is huge (~3,600). Estimating all would be computationally prohibitive and would incur severe multiple testing penalties. Screening selects the most promising pairs for the full causal analysis.
# 
# **What if you skip screening and run all pairs?**  
# The pipeline would take much longer, and you would face a massive multiple‑comparison problem. Moreover, many pairs would have insufficient sample sizes, leading to many failed estimates.
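
# %% [markdown]
# To make the FDR step concrete, here is a minimal sketch of Benjamini–Hochberg q-values using only the standard library. It assumes the screener uses Benjamini–Hochberg, which is the most common FDR procedure; the actual method inside `Phase1Screener` is not shown in this notebook. `benjamini_hochberg` and the p-values are illustrative.

```python
def benjamini_hochberg(p_values):
    """Return Benjamini-Hochberg q-values in the same order as the input p-values."""
    m = len(p_values)
    order = sorted(range(m), key=lambda i: p_values[i])
    q_values = [0.0] * m
    running_min = 1.0
    # Walk from the largest p-value down so each q-value is the minimum over larger ranks.
    for rank in range(m, 0, -1):
        i = order[rank - 1]
        running_min = min(running_min, p_values[i] * m / rank)
        q_values[i] = running_min
    return q_values


p = [0.01, 0.04, 0.03, 0.005]  # made-up p-values for four pairs
q = benjamini_hochberg(p)
significant = [qi <= 0.05 for qi in q]  # mirrors fdr_alpha=0.05
print(q)
print(significant)
```

# %% [markdown]
# A pair is flagged when its q-value is at most `fdr_alpha`. Because q-values grow with the number of tests, screening all ~3,600 pairs makes every individual pair harder to flag.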

# %%
phase1_cfg = Phase1Config(
    min_group=50,
    min_trigger_tested=100,
    crude_screening_threshold=0.05,
    fdr_alpha=0.05,
    exclude_targets_equal_trigger=True
)

screener = Phase1Screener(phase1_cfg)

phase1_df = screener.run(
    df=discovery_df,
    flags=discovery_flags,
    all_codes=all_codes,
    top_n=100
)

print(f"Phase 1 screening retained {len(phase1_df)} pairs.")
if not phase1_df.empty:
    display(phase1_df.head())
else:
    print("⚠️ No pairs passed screening. Check group sizes and thresholds.")

# Convert to list of tuples for the pipeline
pairs = list(zip(phase1_df["trigger"], phase1_df["target"]))

Phase 1 screening retained 100 pairs.


Unnamed: 0,trigger,target,or_unadjusted,or_ci_low,or_ci_high,p_value,delta,p_D_tested_given_A_R,p_D_tested_given_A_S,n_trigger_tested,n_A_R,n_A_S,q_value,significant,abs_log_or
0,NFT,DOX,452.873255,27.960618,7335.109102,1.3038410000000001e-40,0.117647,0.117647,0.0,2213,527,1686,8.139690000000001e-39,True,6.115612
1,CRO,COL,352.548217,21.052441,5903.840125,5.3803630000000006e-17,0.038168,0.038168,0.0,4697,393,4304,8.708217e-16,True,5.865187
2,CEP,COL,222.510393,13.066392,3789.177354,6.270156e-12,0.048458,0.048458,0.0,2321,227,2094,5.591955e-11,True,5.404974
3,CPO,COL,198.499538,11.735542,3357.498651,4.318855e-12,0.0217,0.0217,0.0,4852,553,4299,3.931957e-11,True,5.290787
4,TET,CRO,193.617886,11.642249,3219.986645,1.101263e-16,0.306818,0.306818,0.0,304,88,216,1.718756e-15,True,5.265887
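
# %% [markdown]
# Every row above has `p_D_tested_given_A_S = 0.0`, which would give an infinite odds ratio without a correction. The printed values are consistent with a Haldane–Anscombe correction (adding 0.5 to each cell of the 2×2 table). This is inferred from the output, not confirmed from the `Phase1Screener` source. The sketch below recomputes the NFT → DOX row under that assumption; `odds_ratio_haldane` is a hypothetical helper.

```python
import math


def odds_ratio_haldane(tested_R, n_R, tested_S, n_S, z=1.96):
    """Odds ratio and Wald CI for 'target tested' by trigger result, with 0.5 added to each cell."""
    a = tested_R + 0.5            # trigger resistant, target tested
    b = (n_R - tested_R) + 0.5    # trigger resistant, target not tested
    c = tested_S + 0.5            # trigger susceptible, target tested
    d = (n_S - tested_S) + 0.5    # trigger susceptible, target not tested
    log_or = math.log((a * d) / (b * c))
    se = math.sqrt(1 / a + 1 / b + 1 / c + 1 / d)
    return math.exp(log_or), math.exp(log_or - z * se), math.exp(log_or + z * se)


# NFT -> DOX: 62 of 527 resistant isolates had DOX tested (62/527 = 0.117647), 0 of 1686 susceptible.
or_, lo, hi = odds_ratio_haldane(62, 527, 0, 1686)
print(f"OR = {or_:.2f}, 95% CI = ({lo:.2f}, {hi:.2f}), |log OR| = {abs(math.log(or_)):.3f}")
```

# %% [markdown]
# The result comes out close to the `or_unadjusted`, `or_ci_low`, and `or_ci_high` columns for that row. The very wide interval is driven by the `1 / 0.5` term from the empty cell, so the top-ranked pairs here are those with a zero cell rather than necessarily those with the strongest evidence.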


In [6]:
# %% [markdown]
# ## 6. Configure the Causal Pipeline
# 
# All parameters are centralised in a `RunConfig` object. You can adjust these settings as needed. Here we set the configuration for the **standard model** (no joint selection).

# %%
config = RunConfig(
    split=SplitConfig(
        test_size=0.3,
        split_group_col="Anonymized_Lab",
        random_state=42
    ),
    covariates=CovariateConfig(
        covariate_cols=["Anonymized_Lab", "ARS_WardType", "AgeGroup", "Year"],
        min_count=200,
        max_levels=25,
        drop_first=True
    ),
    policy=PolicyConfig(
        context_cols=["Anonymized_Lab", "PathogengroupL1", "Year"],
        method="empirical",
        min_context_n=100,
        model_type="xgb",
        calibrate=True,
        calibration_method="isotonic",
        calibration_cv=5
    ),
    nuisance=NuisanceConfig(
        testing_model="xgb",
        propensity_model="xgb",
        outcome_model="xgb",
        calibrate_testing=True,
        calibrate_propensity=False,
        calibrate_outcome=False,
        testing_cv_folds=5,
        use_joint_selection=False,   # we will create two copies with different values
        random_state=42
    ),
    tmle=TMLEConfig(
        n_folds=5,
        min_prob=0.01,
        weight_cap_percentile=99.0,
        min_tested=200,
        min_group=50,
        stabilize_weights=True,
        n_bootstrap=None,
        alpha=0.05
    )
)

In [7]:
# %% [markdown]
# ## 7. Run Pipeline Without Joint Selection (Standard Model)
# 
# This is our baseline – it uses standard inverse probability weighting for testing selection.

# %%
config_no_joint = config.model_copy(deep=True)
config_no_joint.nuisance.use_joint_selection = False

pipeline_no_joint = CausalPipeline(config_no_joint, n_jobs=4)   # use 4 cores
results_no_joint = pipeline_no_joint.run(
    df=estimation_df,
    flags=estimation_flags,
    all_codes=all_codes,
    pairs=pairs
)

print(f"Number of successful estimates (no joint): {len(results_no_joint[results_no_joint['status']=='ok'])}")
results_no_joint.head()

[Parallel(n_jobs=4)]: Using backend LokyBackend with 4 concurrent workers.
[Parallel(n_jobs=4)]: Done   5 tasks      | elapsed:    1.8s
[Parallel(n_jobs=4)]: Done  10 tasks      | elapsed:    2.0s
[Parallel(n_jobs=4)]: Done  17 tasks      | elapsed:    8.6s
[Parallel(n_jobs=4)]: Done  24 tasks      | elapsed:   15.8s
[Parallel(n_jobs=4)]: Done  33 tasks      | elapsed:   22.9s
[Parallel(n_jobs=4)]: Done  42 tasks      | elapsed:   23.5s
[Parallel(n_jobs=4)]: Done  53 tasks      | elapsed:   37.4s
[Parallel(n_jobs=4)]: Done  64 tasks      | elapsed:   44.8s
[Parallel(n_jobs=4)]: Done  77 tasks      | elapsed:   53.9s
[Parallel(n_jobs=4)]: Done  90 tasks      | elapsed:  1.3min
[Parallel(n_jobs=4)]: Done 100 out of 100 | elapsed:  1.5min finished


Number of successful estimates (no joint): 45


Unnamed: 0,trigger,target,rd,ci_low,ci_high,p_value,se,ess,n_used,n_trigger_tested,n_A1,n_A0,baseline_mu0,escalation_score_mean,diagnostics,model_spec,status,skip_reason
0,NFT,DOX,0.0,0.0,0.0,1.0,0.0,450.228404,718,718,447,271,0.0,0.0,{'testing_model': {'p_test_quantiles': {'q00':...,"{'propensity_model': 'xgb', 'outcome_model': '...",ok,
1,CRO,COL,,,,,,,0,365,39,326,,,{},{},failed,"Group sizes too small: A=1 39, A=0 326 (min re..."
2,CEP,COL,0.0,0.0,0.0,1.0,0.0,487.921173,582,582,60,522,0.0,0.0,{'testing_model': {'p_test_quantiles': {'q00':...,"{'propensity_model': 'xgb', 'outcome_model': '...",ok,
3,CPO,COL,0.0,0.0,0.0,1.0,0.0,984.205472,1031,1031,107,924,0.0,0.0,{'testing_model': {'p_test_quantiles': {'q00':...,"{'propensity_model': 'xgb', 'outcome_model': '...",ok,
4,TET,CRO,,,,,,,0,83,0,0,,,{},{},failed,Too few tested isolates for TET: n=83 < 200


In [8]:
# %% [markdown]
# ## 8. Run Pipeline With Joint Selection
# 
# Now we enable the bivariate probit joint selection model to account for unmeasured common causes of testing and resistance.

# %%
config_joint = config.model_copy(deep=True)
config_joint.nuisance.use_joint_selection = True

pipeline_joint = CausalPipeline(config_joint, n_jobs=4)
results_joint = pipeline_joint.run(
    df=estimation_df,
    flags=estimation_flags,
    all_codes=all_codes,
    pairs=pairs
)

print(f"Number of successful estimates (joint): {len(results_joint[results_joint['status']=='ok'])}")
results_joint.head()

[Parallel(n_jobs=4)]: Using backend LokyBackend with 4 concurrent workers.
[Parallel(n_jobs=4)]: Batch computation too fast (0.09719705581665039s.) Setting batch_size=2.
Joint model fitting failed: Singular matrix. Falling back to separate probits.
Joint model fitting failed: Singular matrix. Falling back to separate probits.
Joint model fitting failed: Singular matrix. Falling back to separate probits.
[Parallel(n_jobs=4)]: Done   5 tasks      | elapsed:    0.3s
Joint model fitting failed: Singular matrix. Falling back to separate probits.
[Parallel(n_jobs=4)]: Done  13 tasks      | elapsed:    2.3s
Joint model fitting failed: Singular matrix. Falling back to separate probits.
Joint model fitting failed: Singular matrix. Falling back to separate probits.
Joint model fitting failed: Singular matrix. Falling back to separate probits.
Joint model fitting failed: Singular matrix. Falling back to separate probits.
Joint model fitting failed: Singular matrix. Falling back to separate probit

KeyboardInterrupt: 

In [None]:
# %% [markdown]
# ## 9. Compare the Two Sets of Results
# 
# We merge the two result DataFrames on trigger–target pairs and compute the difference.

# %%
def compare_results(res_no_joint, res_joint):
    # Keep only successful estimates
    noj_ok = res_no_joint[res_no_joint["status"] == "ok"][["trigger", "target", "rd", "ci_low", "ci_high"]].copy()
    joint_ok = res_joint[res_joint["status"] == "ok"][["trigger", "target", "rd", "ci_low", "ci_high"]].copy()
    
    noj_ok.rename(columns={"rd": "rd_noj", "ci_low": "ci_low_noj", "ci_high": "ci_high_noj"}, inplace=True)
    joint_ok.rename(columns={"rd": "rd_joint", "ci_low": "ci_low_joint", "ci_high": "ci_high_joint"}, inplace=True)
    
    merged = pd.merge(noj_ok, joint_ok, on=["trigger", "target"], how="inner")
    merged["diff"] = merged["rd_joint"] - merged["rd_noj"]
    merged["pair"] = merged["trigger"] + " → " + merged["target"]
    return merged

comparison_df = compare_results(results_no_joint, results_joint)
print(f"Merged {len(comparison_df)} pairs for comparison.")
comparison_df.head()

In [None]:
# %% [markdown]
# ## 10. Visualise the Comparison
# 
# We create a two‑panel plot: point estimates with confidence intervals for both methods (top), and a bar chart of the differences (bottom).

# %%
def plot_comparison(merged):
    merged_sorted = merged.sort_values("diff", ascending=False).reset_index(drop=True)
    y_pos = list(range(len(merged_sorted)))
    
    fig = make_subplots(
        rows=2, cols=1,
        subplot_titles=("Risk Difference Estimates", "Difference (Joint – Standard)"),
        shared_xaxes=True,
        vertical_spacing=0.15
    )
    
    # Standard estimates
    fig.add_trace(go.Scatter(
        x=merged_sorted["rd_noj"],
        y=y_pos,
        mode="markers",
        marker=dict(color="blue", size=8),
        name="Standard",
        error_x=dict(
            type="data",
            symmetric=False,
            array=merged_sorted["ci_high_noj"] - merged_sorted["rd_noj"],       # upper (plus) bar length
            arrayminus=merged_sorted["rd_noj"] - merged_sorted["ci_low_noj"],   # lower (minus) bar length
            color="blue",
            thickness=1
        ),
        showlegend=True
    ), row=1, col=1)
    
    # Joint estimates
    fig.add_trace(go.Scatter(
        x=merged_sorted["rd_joint"],
        y=y_pos,
        mode="markers",
        marker=dict(color="red", size=8),
        name="Joint",
        error_x=dict(
            type="data",
            symmetric=False,
            array=merged_sorted["ci_high_joint"] - merged_sorted["rd_joint"],       # upper (plus) bar length
            arrayminus=merged_sorted["rd_joint"] - merged_sorted["ci_low_joint"],   # lower (minus) bar length
            color="red",
            thickness=1
        ),
        showlegend=True
    ), row=1, col=1)
    
    # Difference bar chart
    colors = ["red" if d > 0 else "blue" for d in merged_sorted["diff"]]
    fig.add_trace(go.Bar(
        x=merged_sorted["diff"],
        y=y_pos,
        orientation="h",
        marker_color=colors,
        name="Difference",
        showlegend=False
    ), row=2, col=1)
    
    # Vertical lines at zero
    fig.add_vline(x=0, line_dash="dash", line_color="gray", row=1, col=1)
    fig.add_vline(x=0, line_dash="dash", line_color="gray", row=2, col=1)
    
    fig.update_yaxes(tickvals=y_pos, ticktext=merged_sorted["pair"], row=1, col=1)
    fig.update_yaxes(tickvals=y_pos, ticktext=merged_sorted["pair"], row=2, col=1)
    
    fig.update_layout(
        height=800,
        width=1000,
        title_text="Comparison: Standard vs. Joint Selection Model",
        showlegend=True
    )
    
    fig.show()

plot_comparison(comparison_df)

In [None]:
# %% [markdown]
# ## 11. Save Intermediate Results (Optional)
# 
# Save the comparison data and the individual results for later use.

# %%
output_dir = Path("./comparison_output")
output_dir.mkdir(exist_ok=True)

comparison_df.to_csv(output_dir / "comparison_results.csv", index=False)
results_no_joint.to_csv(output_dir / "results_no_joint.csv", index=False)
results_joint.to_csv(output_dir / "results_joint.csv", index=False)

print(f"Results saved to {output_dir}")

In [None]:
# %% [markdown]
# ## 12. Bayesian Shrinkage for Multiple Comparisons
# 
# We apply a Bayesian hierarchical model to shrink the estimates toward a common mean, borrowing strength across pairs. This helps guard against false positives due to multiple testing.
# 
# **Important**: We must exclude pairs with zero standard error (where the target was never tested) because they cause degeneracy in the likelihood.

# %%
# Load the standard model results (we can also use joint, but let's use standard for consistency)
results_std = pd.read_csv(output_dir / "results_no_joint.csv")

# Keep only successful estimates with positive standard errors
ok = results_std[(results_std['status'] == 'ok') & (results_std['se'] > 0)].copy()
ok['pair'] = ok['trigger'] + ' → ' + ok['target']
ok = ok.dropna(subset=['rd', 'se'])

if len(ok) > 1:
    # Use higher target_accept to reduce divergences; may need to tune further
    bs = BayesianShrinkage(random_seed=42, draws=2000, tune=1000, target_accept=0.99)
    bs.fit(ok, estimate_col='rd', se_col='se')
    shrunk = bs.summary()
    shrunk.to_csv(output_dir / "bayesian_shrinkage_summary.csv", index=False)
    print(f"Shrinkage complete for {len(ok)} pairs.")
    display(shrunk.head())
else:
    print("Not enough valid pairs for Bayesian shrinkage.")

In [None]:
# %% [markdown]
# **Interpretation**: 
# - `theta_shrunken` is the posterior mean after shrinkage.
# - `prob_positive` is the posterior probability that the true effect is positive.
# - If `prob_positive` > 0.95, we have strong evidence of a positive effect.
# 
# *Note*: MCMC warnings (divergences, rhat > 1.01) may appear; they indicate that the posterior may not be fully reliable. You can try increasing `target_accept` further or using a more informative prior for `tau` (modify the `BayesianShrinkage` class).
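
# %% [markdown]
# For intuition about `theta_shrunken` and `prob_positive`, the sketch below uses the closed-form normal–normal posterior with a fixed prior mean `mu` and prior standard deviation `tau`. This is a simplification: the `BayesianShrinkage` class estimates its hyperparameters by MCMC, and its exact model is not shown here. `shrink` and all numbers are illustrative.

```python
import math


def shrink(estimate, se, mu, tau):
    """Posterior mean, sd, and P(theta > 0) for theta ~ N(mu, tau^2), estimate ~ N(theta, se^2)."""
    precision = 1 / se**2 + 1 / tau**2
    post_mean = (estimate / se**2 + mu / tau**2) / precision
    post_sd = math.sqrt(1 / precision)
    # Standard normal CDF evaluated at post_mean / post_sd.
    prob_positive = 0.5 * (1 + math.erf(post_mean / (post_sd * math.sqrt(2))))
    return post_mean, post_sd, prob_positive


# Same raw estimate, two different standard errors (made-up values).
for est, se in [(0.10, 0.05), (0.10, 0.01)]:
    mean_, sd_, prob_ = shrink(est, se, mu=0.0, tau=0.05)
    print(f"rd={est:.2f} se={se:.2f} -> shrunken={mean_:.3f} sd={sd_:.3f} P(>0)={prob_:.3f}")
```

# %% [markdown]
# Noisy estimates are pulled further toward the common mean than precise ones. The `1 / se**2` term also shows why rows with `se = 0` have to be removed beforehand: their precision is undefined.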

# %% [markdown]
# ## 13. Heterogeneity Analysis with Causal Forest (for a Pair of Interest)
# 
# Suppose we want to explore whether the effect varies by clinical context for a specific pair – e.g., the one with the strongest signal. From our results, we might choose **NFT → ERT** (p ≈ 3×10⁻⁵) or **CPO → CIP** (high posterior probability, though tiny effect). We'll use **NFT → ERT** for illustration.
# 
# **Challenge**: The pipeline does not automatically store per‑observation data. We need to reconstruct them for the pair of interest. The following steps do this by:
# - Extracting the test set from the pipeline (or re‑running the pipeline for a single pair with a modified version that returns the data).
# - Re‑fitting the testing model to obtain weights.
# - Fitting a causal forest.

# %%
# Choose a pair
trigger = 'NFT'
target = 'ERT'

# We need the test set from the pipeline run. If you saved it, load it; otherwise, we can re‑extract from estimation set.
# For simplicity, we assume the pipeline object `pipeline_no_joint` still exists and has attributes `df_test`, `flags_test`, `esc_scores`.
# If not, you may need to re‑run the pipeline for this single pair with a modified version that stores these.

# Check if the pipeline has the necessary attributes (you may need to modify pipeline.py to store them)
if hasattr(pipeline_no_joint, 'df_test') and hasattr(pipeline_no_joint, 'flags_test') and hasattr(pipeline_no_joint, 'esc_scores'):
    df_test = pipeline_no_joint.df_test
    flags_test = pipeline_no_joint.flags_test
    esc_scores = pipeline_no_joint.esc_scores
else:
    # The rest of this cell needs per-observation data, so stop here instead of failing later with a NameError.
    raise RuntimeError(
        "Pipeline does not store test data. Modify pipeline.py to keep `df_test`, `flags_test`, and "
        "`esc_scores` after the run, or re-run this single pair with a function that returns them."
    )

# Get tested mask for trigger
T_col = f"{trigger}_T"
tested_mask = flags_test[T_col].astype(int).to_numpy() == 1

# Build covariates (using same encoding as pipeline)
X = pipeline_no_joint._encode_covariates(df_test, config.covariates.covariate_cols)
X_tested = X[tested_mask]

# Treatment and outcome
A = flags_test[f"{trigger}_R"].astype(int).to_numpy()[tested_mask]
Y = esc_scores[target][tested_mask]

# Re‑compute testing model weights for this pair (using standard model, since joint may have issues)
from src.controllers.escalation_causal.nuisance.testing_model import TestingModel
test_model = TestingModel(
    model_type=config.nuisance.testing_model,
    calibrate=config.nuisance.calibrate_testing,
    n_folds_cv=config.nuisance.testing_cv_folds,
    min_prob=config.tmle.min_prob,
    weight_cap_percentile=config.tmle.weight_cap_percentile,
)
T_full = flags_test[T_col].astype(int).to_numpy()
test_model.fit(X, T_full)
p_test = test_model.get_oof_predictions() if test_model._is_cross_fitted else test_model.predict_proba(X)
w_full, _ = test_model.compute_weights(p_test, tested_mask)
w = w_full[tested_mask]

# Fit causal forest
# We need feature names (you can extract them from the encoder; here we use generic names)
feature_names = [f"cov_{i}" for i in range(X.shape[1])]

cf = CausalForestWrapper(
    n_estimators=400,
    max_depth=20,
    min_samples_leaf=10,
    random_state=42
)
cf.fit(X_tested, A, Y, sample_weight=w, feature_names=feature_names)

# Variable importance
imp = cf.feature_importances()
imp_df = pd.DataFrame(list(imp.items()), columns=['feature', 'importance'])
imp_df.sort_values('importance', ascending=False).to_csv(output_dir / f"cf_importance_{trigger}_{target}.csv", index=False)
print("Variable importance saved.")

# Subgroup summaries
# By ward type
ward_labels = df_test.loc[tested_mask, "ARS_WardType"].values
ward_summary = cf.get_cate_summary(X_tested, group_labels=ward_labels)
ward_summary.to_csv(output_dir / f"cate_by_ward_{trigger}_{target}.csv", index=False)
print("CATE by ward type:")
print(ward_summary)

# By age group
age_labels = df_test.loc[tested_mask, "AgeGroup"].values
age_summary = cf.get_cate_summary(X_tested, group_labels=age_labels)
age_summary.to_csv(output_dir / f"cate_by_age_{trigger}_{target}.csv", index=False)
print("CATE by age group:")
print(age_summary)

# Plot CATE distribution
plt.figure()
cf.plot_cate_distribution(X_tested)
plt.savefig(output_dir / f"cate_distribution_{trigger}_{target}.png", dpi=300, bbox_inches='tight')
plt.show()

# Plot variable importance (top 15)
cf.plot_variable_importance(top_k=15)
plt.savefig(output_dir / f"cf_importance_plot_{trigger}_{target}.png", dpi=300, bbox_inches='tight')
plt.show()

In [None]:
# %% [markdown]
# **Interpretation**:
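# - Variable importance shows which covariates (ward type, age, year) contribute most to variation in the estimated treatment effect across isolates; it does not measure the size of the effect itself. Because the features above were given generic names (`cov_0`, `cov_1`, ...), map them back to the encoded column names before interpreting the ranking.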
# - Subgroup summaries give average CATE within each category. If the confidence intervals are wide, the heterogeneity may not be statistically significant.
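
# %% [markdown]
# As a rough illustration of a subgroup summary, the sketch below averages per-isolate CATE predictions within each category using the standard library. It is not the implementation of `get_cate_summary`, whose output columns are not shown in this notebook; `cate_by_group` and the numbers are hypothetical.

```python
from collections import defaultdict
from statistics import mean, stdev


def cate_by_group(cate, labels):
    """Average per-observation CATE predictions within each group label."""
    buckets = defaultdict(list)
    for value, label in zip(cate, labels):
        buckets[label].append(value)
    rows = []
    for label, values in sorted(buckets.items()):
        spread = stdev(values) if len(values) > 1 else float("nan")
        rows.append({"group": label, "n": len(values), "mean_cate": mean(values), "sd_cate": spread})
    return rows


cate = [0.1, 0.3, -0.1, 0.1]                  # made-up CATE predictions
labels = ["ICU", "ICU", "Normal", "Normal"]   # made-up ward types
summary = cate_by_group(cate, labels)
for row in summary:
    print(row)
```

# %% [markdown]
# Note that `sd_cate` here is only the spread of point predictions within a group. It is not a confidence interval for the subgroup effect, which also has to account for the forest's estimation uncertainty.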

# %% [markdown]
# ## 14. Summary and Next Steps
# 
# You have now:
# 1. Loaded and filtered the ARS data.
# 2. Split into discovery and estimation sets.
# 3. Performed Phase 1 screening to select 100 trigger–target pairs.
# 4. Run the causal pipeline with and without the joint selection model.
# 5. Compared the estimates and generated a comparison plot.
# 6. Applied Bayesian shrinkage to handle multiple comparisons.
# 7. Explored heterogeneity with a causal forest for a specific pair.
# 
# **Possible next steps**:
# - Re‑run with a lower `min_group` threshold to include more pairs.
# - Improve the joint model convergence by tweaking its parameters.
# - Generate publication‑ready figures using the `visualization` module (forest plots, network graphs, etc.).
# - Write up your findings in a paper.

# %% [markdown]
# ## 15. Troubleshooting
# 
# | Problem | Likely cause | Solution |
# |---------|--------------|----------|
# | Phase 1 screening empty | `min_group` or `min_trigger_tested` too high | Lower thresholds, or check data. |
# | Pipeline fails for many pairs | `min_tested` or `min_group` too high; some triggers rarely tested | Reduce thresholds, or accept that some pairs cannot be estimated. |
# | Joint model fails to converge (e.g. "Singular matrix") | Sample size too small; perfect prediction | The pipeline falls back to separate probits automatically; consider using a two‑stage Heckman model instead. |
# | Bayesian shrinkage gives warnings (divergences, rhat) | Weak prior, small number of pairs | Increase `target_accept`, use more informative prior for tau, or increase draws. |
# | Causal forest requires test set data | Pipeline not configured to store it | Modify `pipeline.py` to save `df_test`, `flags_test`, and `esc_scores` after run. |

# %% [markdown]
# ---
# **End of notebook**