# Applying causal trees to data

Causal trees are a powerful tool for identifying treatment effects in randomized trials. In this notebook, we'll demonstrate their use with a real-world dataset from the [medicaldata R package](https://cran.r-project.org/web/packages/medicaldata/index.html).


In [None]:
import rpy2.robjects as ro
from rpy2.robjects import pandas2ri
import numpy as np

# Install package for downloading dataset
cran_mirror = "https://cran.r-project.org"
ro.r(f'install.packages("medicaldata", repos="{cran_mirror}", quiet=TRUE)')

# Load the R package and dataset. Retrieve the dataset from the R environment and convert from R to pandas dataframe
ro.r('library(medicaldata)')
ro.r('data("indo_rct")')
pandas2ri.activate()
df = ro.r('indo_rct')

# Convert binary columns whose values are 1 and 2 to 0 and 1.
one_two_array = np.array([1,2])
for var in df.columns:
    if np.array_equal(sorted(df[var].unique()), one_two_array):
        df[var] -= 1

# Dropping 'bleed' variable as it is mostly missing (unclear if 0)
df.drop(columns= 'bleed', inplace=True)

print(f'Dataframe shape: {df.shape}')
display(df)

## Dataset Overview: RCT of Indomethacin for Prevention of Post-ERCP Pancreatitis

### TLDR

The data come from a randomized controlled trial (RCT). The trial asks whether indomethacin (the treatment) prevents PEP (the outcome variable) after a medical procedure known as ERCP.

### Context
This dataset originates from a multicenter, randomized, placebo-controlled, prospective 2-arm trial designed to investigate whether rectal indomethacin (100 mg) can prevent post-ERCP pancreatitis (PEP). ERCP, or endoscopic retrograde cholangiopancreatography, is a medical procedure used to diagnose and treat conditions involving the pancreas and bile ducts. However, it carries a significant risk of complications, with post-ERCP pancreatitis being one of the most common and severe outcomes.

The trial, published in the *New England Journal of Medicine* in 2012 by Elmunzer, Higgins, et al., enrolled 602 participants across multiple centers. It aimed to assess the effectiveness of indomethacin in reducing the occurrence of post-ERCP pancreatitis compared to a placebo.

### Key Variables
- **Identifier variables**: `id` (subject id) and `site` (where the experiment took place)
- **Outcome Variable (Y):** The occurrence of post-ERCP pancreatitis (`outcome`, factor: 0_no, 1_yes).
- **Treatment Variable (W):** The treatment arm (`rx`, factor: 0_placebo, 1_indomethacin).

### Study Design
This dataset is derived from a **randomized controlled trial (RCT)**. Participants were randomly assigned to one of two groups: a treatment group receiving rectal indomethacin (100 mg) or a placebo group. The primary outcome was the rate of post-ERCP pancreatitis, with secondary analyses exploring risk factors and safety concerns, such as gastrointestinal bleeding.
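
Because assignment is randomized, the average treatment effect can be estimated by a simple difference in mean outcomes between the two arms. The sketch below shows this on made-up records; the `records` list and the `difference_in_means` helper are hypothetical and the numbers are not the trial's results.

```python
from statistics import mean

# Hypothetical toy records: (treated, outcome), where outcome 1 means PEP occurred.
# These values are invented for illustration and are NOT taken from the trial.
records = [
    (1, 0), (1, 0), (1, 0), (1, 1),   # treated arm: 1 event in 4
    (0, 0), (0, 1), (0, 1), (0, 0),   # placebo arm: 2 events in 4
]

def difference_in_means(records):
    """Estimate the ATE as mean(outcome | treated) - mean(outcome | control)."""
    treated = [y for w, y in records if w == 1]
    control = [y for w, y in records if w == 0]
    return mean(treated) - mean(control)

ate_hat = difference_in_means(records)
print(ate_hat)  # negative values mean fewer events in the treated arm
```

On the notebook's dataframe, the analogous quantity would compare the mean of `outcome` across the two levels of `rx`, assuming both columns are numeric 0/1 after the recoding step.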


## Data Frame Summary

| Variable       | Description                                                                                                   | Type      | Levels                                                                                                    |
|----------------|---------------------------------------------------------------------------------------------------------------|-----------|-----------------------------------------------------------------------------------------------------------|
| id             | Subject ID, first integer indicates center (range: 1001–4003)                                                | Integer   | -                                                                                                         |
| site           | Study site (center)                                                                                          | Factor    | 1 = University of Michigan, 2 = Indiana University, 3 = University of Kentucky, 4 = Case Western         |
| age            | Age in years (range: 19–90)                                                                                  | Numeric   | -                                                                                                         |
| risk           | Risk score (range: 1–5.5)                                                                                    | Numeric   | -                                                                                                         |
| gender         | Gender                                                                                                       | Factor    | 0_female, 1_male                                                                                         |
| sod            | Sphincter of Oddi dysfunction present (a risk factor for post-ERCP pancreatitis)                             | Factor    | 0_no, 1_yes                                                                                              |
| pep            | Previous post-ERCP pancreatitis (PEP), a risk factor for future PEP                                          | Factor    | 0_no, 1_yes                                                                                              |
| recpanc        | Recurrent pancreatitis, a risk factor for future PEP                                                         | Factor    | 0_no, 1_yes                                                                                              |
| psphinc        | Pancreatic sphincterotomy performed (a risk factor for PEP)                                                  | Factor    | 0_no, 1_yes                                                                                              |
| precut         | Sphincter pre-cut needed to enter the papilla (a risk factor for PEP)                                        | Factor    | 0_no, 1_yes                                                                                              |
| difcan         | Difficulty cannulating the papilla (a risk factor for PEP)                                                   | Factor    | 0_no, 1_yes                                                                                              |
| pneudil        | Pneumatic dilation of the papilla performed (a risk factor for PEP)                                          | Factor    | 0_no, 1_yes                                                                                              |
| amp            | Ampullectomy performed for dysplasia or cancer (a risk factor for PEP)                                       | Factor    | 0_no, 1_yes                                                                                              |
| paninj         | Contrast injected into the pancreas during the procedure (a risk factor for PEP)                             | Factor    | 0_no, 1_yes                                                                                              |
| acinar         | Pancreatic acinarization observed on imaging (a risk factor for PEP)                                         | Factor    | 0_no, 1_yes                                                                                              |
| brush          | Brushings taken from the pancreatic duct (possible risk factor for PEP)                                      | Factor    | 0_no, 1_yes                                                                                              |
| asa81          | Aspirin used at 81 mg/day (may increase bleeding risk)                                                       | Factor    | 0_no, 1_yes                                                                                              |
| asa325         | Aspirin used at 325 mg/day (may increase bleeding risk)                                                      | Factor    | 0_no, 1_yes                                                                                              |
| asa            | Aspirin used at any dose (may increase bleeding risk)                                                        | Factor    | 0_no, 1_yes                                                                                              |
| prophystent    | Pancreatic duct stent placed as a protective measure against PEP                                             | Factor    | 0_no, 1_yes                                                                                              |
| therastent     | Pancreatic duct stent placed to treat narrowing of the duct                                                  | Factor    | 0_no, 1_yes                                                                                              |
| pdstent        | Pancreatic duct stent placed for any reason (potential protective effect against PEP)                        | Factor    | 0_no, 1_yes                                                                                              |
| sodsom         | Sphincter of Oddi manometry performed for SOD (a risk factor for PEP)                                        | Factor    | 0_no, 1_yes                                                                                              |
| bsphinc        | Biliary sphincterotomy performed (a risk factor for PEP)                                                     | Factor    | 0_no, 1_yes                                                                                              |
| bstent         | Biliary stent placed to relieve obstruction                                                                  | Factor    | 0_no, 1_yes                                                                                              |
| chole          | Choledocholithiasis (gallstones blocking the duct) present                                                  | Factor    | 0_no, 1_yes                                                                                              |
| pbmal          | Malignancy of the biliary duct or pancreas found                                                             | Factor    | 0_no, 1_yes                                                                                              |
| train          | Trainee participated in the ERCP (a potential risk factor for PEP)                                           | Factor    | 0_no, 1_yes                                                                                              |
| outcome        | Outcome of post-ERCP pancreatitis                                                                            | Factor    | 0_no, 1_yes                                                                                              |
| status         | Outpatient status                                                                                           | Factor    | 0_inpatient, 1_outpatient                                                                                 |
| type           | Sphincter of Oddi dysfunction type/level (higher levels indicate greater association with PEP)               | Factor    | 0_no SOD, 1_type 1, 2_type 2, 3_type 3                                                                    |
| rx             | Treatment arm                                                                                               | Factor    | 0_placebo, 1_indomethacin                                                                                 |
| bleed          | Gastrointestinal bleed occurred (a potential complication of indomethacin therapy)                          | Factor    | 0_no, 1_yes                                                                                              |


Here are a few variables from this study:

In [None]:
import matplotlib.pyplot as plt
df[['age','gender','rx','outcome']].hist(figsize=(6,6))
plt.suptitle('Histograms of age, gender, rx (treatment) and outcome:')
plt.tight_layout()
plt.show()

In [59]:
y = df['outcome']
w = df['rx']
# Omit the outcome, treatment and identifier variables from X.
# I also omit 'risk' because it is unclear how it is generated, and it could be based on a combination of variables within the dataset.
excluded = ['outcome', 'rx', 'id', 'site', 'risk']
X = df[[col for col in df.columns if col not in excluded]]

## Visualizing a tree

We use sklearn's `plot_tree` function to visualize the causal tree. This visualization helps us understand the decision rules and identify the variables most important for determining heterogeneous treatment effects.

In the tree diagram, the 'value' in each leaf refers to the Conditional Average Treatment Effect (CATE) for that subgroup.
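
A rough mental model for the 'value' in a leaf is the difference in mean outcome between the treated and control units that fall into that leaf. The sketch below computes this on made-up rows for a depth-1 tree; `rows` and `leaf_cate` are hypothetical names, and EconML's estimator may differ in detail (for example through honest sample splitting).

```python
from statistics import mean

# Hypothetical toy rows: (feature_value, treated, outcome). Not real trial data.
rows = [
    (0, 1, 0.0), (0, 1, 1.0), (0, 0, 1.0), (0, 0, 0.0),
    (1, 1, 0.0), (1, 1, 0.0), (1, 0, 1.0), (1, 0, 1.0),
]

def leaf_cate(leaf_rows):
    """Difference in mean outcome between treated and control rows of one leaf."""
    treated = [y for _, w, y in leaf_rows if w == 1]
    control = [y for _, w, y in leaf_rows if w == 0]
    return mean(treated) - mean(control)

# A depth-1 "tree": a single split on the binary feature gives two leaves.
leaves = {value: [r for r in rows if r[0] == value] for value in (0, 1)}
cate_by_leaf = {value: leaf_cate(leaf) for value, leaf in leaves.items()}
print(cate_by_leaf)
```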


In [None]:
from sklearn.tree import DecisionTreeClassifier
from sklearn.tree import plot_tree
from econml.grf._base_grftree import GRFTree
from econml.grf import CausalForest
DecisionTreeClassifier.register(GRFTree)

# Instantiate and fit the model
causal_tree = CausalForest(n_estimators=1, subforest_size=1, inference=False, honest=True, random_state=123, criterion='mse')
causal_tree.fit(y=y, X=X, T=w)

# Plot the tree
fig = plt.figure(figsize=(12, 8))  # Adjust the size (width, height)
plot_tree(causal_tree[0], feature_names=X.columns.tolist(), filled=True, proportion=True)
plt.title('Causal tree estimation of the CATE')
plt.show()


- In this tree, the first splitting variable is `pep` (whether the patient has had previous post-ERCP pancreatitis); this split yields the largest reduction in mean squared error (MSE).
- It seems plausible that a history of PEP helps separate patients by treatment effect: patients with recurring PEP may respond to treatment differently from first-time cases.
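
To make "choosing the first split" concrete, here is a simplified proxy, loosely inspired by the adaptive criterion discussed by Athey and Imbens: score each candidate split by the size-weighted mean of squared leaf effects and keep the highest score. This is an assumption-laden sketch on made-up rows, not EconML's `criterion='mse'` implementation; `rows`, `effect` and `split_score` are hypothetical names.

```python
from statistics import mean

# Hypothetical toy rows with two binary features. The effect differs by "a" but not by "b".
rows = [
    {"a": 0, "b": 0, "w": 1, "y": 1.0}, {"a": 0, "b": 0, "w": 0, "y": 1.0},
    {"a": 0, "b": 1, "w": 1, "y": 1.0}, {"a": 0, "b": 1, "w": 0, "y": 1.0},
    {"a": 1, "b": 0, "w": 1, "y": 0.0}, {"a": 1, "b": 0, "w": 0, "y": 1.0},
    {"a": 1, "b": 1, "w": 1, "y": 0.0}, {"a": 1, "b": 1, "w": 0, "y": 1.0},
]

def effect(subset):
    """Difference in mean outcome between treated and control rows."""
    treated = [r["y"] for r in subset if r["w"] == 1]
    control = [r["y"] for r in subset if r["w"] == 0]
    return mean(treated) - mean(control)

def split_score(rows, feature):
    """Size-weighted mean of squared leaf effects after splitting on a binary feature."""
    score = 0.0
    for value in (0, 1):
        leaf = [r for r in rows if r[feature] == value]
        score += len(leaf) / len(rows) * effect(leaf) ** 2
    return score

scores = {feature: split_score(rows, feature) for feature in ("a", "b")}
best_feature = max(scores, key=scores.get)
print(scores, best_feature)
```

Splitting on `a` separates a leaf with no effect from a leaf with a strong effect, so it scores higher than splitting on `b`, where both leaves show the same averaged effect.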

## Honest vs Adaptive

Athey and Imbens write about the costs and advantages of using honest estimation (see page six from [their paper](https://arxiv.org/pdf/1504.01132)):

### Costs and advantages of honest vs. adaptive approaches

- **Cost**:  
  Honest estimation reduces the available sample size for training since part of the data is reserved for estimation.

- **Advantage**:  
  Honest estimation minimizes bias in the leaf-specific treatment effect estimates. In adaptive estimation:
  - Spurious extreme values of $Y_i$ are more likely to be grouped together in the same leaf.
  - This can lead to exaggerated sample means within leaves, as these groups are not independently sampled.
---
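
The bias described above can be reproduced with plain sample means, without any treatment at all. In this sketch (a toy simulation under my own assumptions, not the procedure from the paper), every group has a true mean of zero; the "adaptive" estimate reuses the data that selected the most extreme group, while the "honest" estimate uses an independent sample for the same group. All names are hypothetical.

```python
import random
from statistics import fmean

def one_trial(rng, n_groups=10, group_size=20):
    """Simulate pure-noise outcomes, so every group's true mean is 0."""
    # First sample: used to pick the most extreme group (and, adaptively, to estimate it).
    select = [[rng.gauss(0, 1) for _ in range(group_size)] for _ in range(n_groups)]
    # Second, independent sample for the same groups: used only for honest estimation.
    estimate = [[rng.gauss(0, 1) for _ in range(group_size)] for _ in range(n_groups)]

    best = max(range(n_groups), key=lambda g: fmean(select[g]))
    adaptive = fmean(select[best])   # the same data chose the group and estimates it
    honest = fmean(estimate[best])   # fresh data estimates the chosen group
    return adaptive, honest

rng = random.Random(0)
trials = [one_trial(rng) for _ in range(1000)]
adaptive_avg = fmean(a for a, _ in trials)
honest_avg = fmean(h for _, h in trials)
print(f"adaptive: {adaptive_avg:.3f}, honest: {honest_avg:.3f}")
```

Averaged over many trials, the adaptive estimate sits clearly above zero even though nothing is going on, while the honest estimate stays close to zero.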
- Does that mean that the predicted CATEs will have a wider distribution when using an adaptive method as compared to an honest method?
- To test this, I explore how honest and adaptive CATE estimates differ using K-Fold cross-validation to evaluate out-of-sample predictions.
- It appears that the standard deviation is indeed typically larger for the adaptive method in this experiment.
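
As a reminder of why K-Fold cross-validation gives out-of-sample predictions, here is a hand-rolled sketch: each row lands in exactly one test fold and is predicted by a model that never saw it. The `k_fold_indices` helper and the toy data are hypothetical, and scikit-learn's `KFold` assigns rows to folds differently; only the idea is the same.

```python
import random

def k_fold_indices(n_rows, n_splits, seed):
    """Yield (train_idx, test_idx) pairs from a shuffled partition of range(n_rows)."""
    indices = list(range(n_rows))
    random.Random(seed).shuffle(indices)
    # Deal the shuffled indices round-robin so fold sizes differ by at most one.
    folds = [indices[i::n_splits] for i in range(n_splits)]
    for i, test_idx in enumerate(folds):
        train_idx = [j for k, fold in enumerate(folds) if k != i for j in fold]
        yield train_idx, test_idx

# Toy "model": predict the training mean of a made-up outcome for every test row.
toy_y = [float(v) for v in range(12)]
out_of_sample = {}
for train_idx, test_idx in k_fold_indices(len(toy_y), n_splits=5, seed=1):
    prediction = sum(toy_y[j] for j in train_idx) / len(train_idx)
    for j in test_idx:
        out_of_sample[j] = prediction  # row j was not part of the training rows

print(len(out_of_sample))  # one out-of-sample prediction per row
```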

In [None]:
from sklearn.model_selection import KFold
import pandas as pd
import seaborn as sns
import warnings
warnings.filterwarnings("ignore", "use_inf_as_na")
warnings.filterwarnings("ignore", category=FutureWarning, module="seaborn")

# Cross-validation setup
kf = KFold(n_splits=5, shuffle=True, random_state=1)

def train_and_predict(X_train, X_test, y_train, w_train, honest):
    tree = CausalForest(
        n_estimators=1, subforest_size=1, inference=False, honest=honest, random_state=123
    )
    tree.fit(y=y_train, X=X_train, T=w_train)
    return pd.Series(tree.predict(X_test).squeeze(), index=X_test.index)

# Perform cross-validation
results = []
for i, (train_idx, test_idx) in enumerate(kf.split(X)):
    # Split data into training and test sets
    X_train, X_test = X.iloc[train_idx, :], X.iloc[test_idx, :]
    y_train, y_test = y.iloc[train_idx], y.iloc[test_idx]
    w_train, w_test = w.iloc[train_idx], w.iloc[test_idx]

    # Get predictions for both honest and adaptive models
    predictions = {
        'honest': train_and_predict(X_train, X_test, y_train, w_train, honest=True),
        'adaptive': train_and_predict(X_train, X_test, y_train, w_train, honest=False)
    }

    # Convert dictionary of predictions into a DataFrame
    predictions_df = pd.DataFrame(predictions).reset_index().assign(fold=i)
    predictions_df = predictions_df.melt(id_vars = ['fold','index'], var_name='method', value_name='CATE')
    results.append(predictions_df)

# Combine results from all folds into a single DataFrame
results = pd.concat(results, ignore_index=True)

# Plot the distribution of predicted CATE values by method
sns.histplot(results, x='CATE', hue='method')
plt.show()

print('The standard deviation of the predicted CATE distribution of the two methods.')
results.groupby('method').CATE.std().round(2)

## Using synthetic data

In order to understand the difference in methods as the sample size changes, let's work with some synthetic data where we know the CATE and can evaluate our models' predictive performance.

Here are some functions to generate synthetic causal data and to evaluate and cross-validate the predictive performance of our causal tree model.

In [62]:
from sklearn.metrics import mean_squared_error, mean_absolute_percentage_error

def generate_data(n_samples, n_features=5, treatment_effect_fn=None, cate_noise_level=0.0, seed=42):
    """
    Generate synthetic data with a noisy true conditional average treatment effect (CATE).
    Optimized for large datasets by minimizing memory overhead.
    """
    np.random.seed(seed)
    X = np.random.normal(0, 1, size=(n_samples, n_features))  # Use NumPy array directly
    
    if treatment_effect_fn is None:
        treatment_effect_fn = lambda x: 2 * x[:, 0] + np.sin(x[:, 1]) + 0.5 * x[:, 2] + 0.5 * x[:, 3]
    
    true_cate = treatment_effect_fn(X)  # Noise-free CATE; Gaussian noise is added below
    cate_noise = np.random.normal(0, cate_noise_level, n_samples)
    true_cate_noisy = true_cate + cate_noise

    w = np.random.binomial(1, 0.5, size=n_samples)  # Treatment assignment
    baseline = np.dot(X, np.random.uniform(-1, 1, n_features))
    y = baseline + np.random.normal(0, 1, n_samples) + w * true_cate_noisy  # Observed outcomes

    # Return as DataFrame/Series
    return (
        pd.DataFrame(X, columns=[f"x{i+1}" for i in range(n_features)]),
        pd.Series(w, name="w"),
        pd.Series(y, name="y"),
        pd.Series(true_cate_noisy, name="true_cate"),
    )

# Function to train a CausalForest and compute a prediction error (MSE by default)
def evaluate_method(X_train, X_test, y_train, w_train, test_cate, honest, error_fn = mean_squared_error, n_estimators = 1, subforest_size = 1, model=CausalForest):
    tree = model(n_estimators=n_estimators, subforest_size=subforest_size, inference=False, 
                        honest=honest, random_state=123)
    tree.fit(y=y_train, X=X_train, T=w_train)
    cate_pred = tree.predict(X_test).squeeze()
    return error_fn(test_cate, cate_pred)

# Function to perform cross-validation for a dataset
def cross_validate(X, w, y, true_cate, kf, error_fn, n_estimators = 1, subforest_size = 1, model=CausalForest):
    mse_results = {"honest": [], "adaptive": []}
    for train_idx, test_idx in kf.split(X):
        X_train, X_test = X.iloc[train_idx], X.iloc[test_idx]
        y_train, y_test = y.iloc[train_idx], y.iloc[test_idx]
        w_train, w_test = w.iloc[train_idx], w.iloc[test_idx]
        test_cate = true_cate.iloc[test_idx]
        # Both methods get the same forest settings so that only `honest` differs.
        for method, honest in (("honest", True), ("adaptive", False)):
            mse_results[method].append(evaluate_method(
                X_train, X_test, y_train, w_train, test_cate, honest=honest, error_fn=error_fn,
                n_estimators=n_estimators, subforest_size=subforest_size, model=model))
    return {key: np.mean(values) for key, values in mse_results.items()}

#### Understanding the relationship between X and the CATE in our synthetic data

We have defined the CATE $= 2x_1 + \sin(x_2) + 0.5x_3 + 0.5x_4 + \epsilon$, where $\epsilon \sim \mathcal{N}(0, \sigma^2)$ represents Gaussian noise with standard deviation $\sigma$. Here is what that looks like graphically.
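
For reference, here is the full data-generating process as I read it from the `generate_data` code with its default settings (five features). The symbols $\beta$, $\eta_i$ and $\tau_i$ are my own notation, and $\sigma$ corresponds to `cate_noise_level`:

$$
\begin{aligned}
X_i &\sim \mathcal{N}(0, I_5), \qquad W_i \sim \text{Bernoulli}(0.5), \qquad \beta_j \sim \text{Uniform}(-1, 1) \\
\tau_i &= 2x_{i1} + \sin(x_{i2}) + 0.5x_{i3} + 0.5x_{i4} + \epsilon_i, \qquad \epsilon_i \sim \mathcal{N}(0, \sigma^2) \\
Y_i &= X_i^\top \beta + W_i \tau_i + \eta_i, \qquad \eta_i \sim \mathcal{N}(0, 1)
\end{aligned}
$$

The fifth feature enters only through the baseline $X_i^\top \beta$, so it affects the outcome but not the treatment effect.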

In [None]:
# Generate dataset
X, w, y, true_cate = generate_data(n_samples=500, cate_noise_level=1) # here the standard deviation is 1 for the noise parameter.

scatter = plt.scatter(X['x1'], X['x2'], c=true_cate, cmap='viridis', s=50, alpha=0.8)
plt.colorbar(scatter, label="CATE")
plt.xlabel("x1")
plt.ylabel("x2")
plt.title("CATE by x1 and x2")
plt.tight_layout()
plt.show()

In [None]:
# Compare MSE and MAPE across dataset sizes
n_samples_list = np.logspace(1.7, 4, num=10).astype(int)
mse_results = []
mape_results  = []
kf = KFold(n_splits=8, shuffle=True, random_state=42)

for n_samples in n_samples_list:
    X, w, y, true_cate = generate_data(n_samples=n_samples)
    mse_scores = cross_validate(X, w, y, true_cate, kf, error_fn=mean_squared_error)
    mse_results.append({"n_samples": n_samples, "honest": mse_scores["honest"], "adaptive": mse_scores["adaptive"]})
    mape_scores = cross_validate(X, w, y, true_cate, kf, error_fn=mean_absolute_percentage_error)
    mape_results.append({"n_samples": n_samples, "honest": mape_scores["honest"], "adaptive": mape_scores["adaptive"]})

for name, result in (('MSE',mse_results),('MAPE',mape_results)):
    # Prepare data for plotting
    results_df = pd.DataFrame(result).melt(id_vars="n_samples", var_name="method", value_name="mse")

    # Plot results
    sns.barplot(data=results_df, x="n_samples", y="mse", hue="method")
    plt.title(f"{name} comparison across sample sizes")
    plt.xlabel("Number of samples")
    plt.ylabel(f"{name}")
    plt.legend(title="Method")
    plt.show()

The error of both methods decreases as the sample size increases.

For smaller samples (n <= 950), there is no clear winner between the honest and adaptive methods.

For larger samples (n > 950), the honest method typically beats the adaptive one.
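
One thing to keep in mind when reading the MAPE panel: MAPE divides each error by the absolute true value, and the synthetic CATE is built from zero-mean features, so some true values land close to zero. The sketch below uses hand-written `mape` and `mse` helpers (hypothetical names; the formula is my reading of the definition in scikit-learn's documentation) on made-up values to show how one near-zero target can dominate MAPE while MSE is unchanged.

```python
import sys

def mape(y_true, y_pred, eps=sys.float_info.epsilon):
    """Mean absolute percentage error: mean of |error| / max(eps, |true value|)."""
    return sum(abs(t - p) / max(eps, abs(t)) for t, p in zip(y_true, y_pred)) / len(y_true)

def mse(y_true, y_pred):
    """Mean squared error."""
    return sum((t - p) ** 2 for t, p in zip(y_true, y_pred)) / len(y_true)

# Made-up CATE values; every prediction misses its target by 0.1.
targets = {
    "far_from_zero": [2.0, -2.0, 4.0],
    "near_zero": [0.001, -2.0, 4.0],
}

errors = {}
for name, truth in targets.items():
    preds = [t + 0.1 for t in truth]
    errors[name] = {"mse": mse(truth, preds), "mape": mape(truth, preds)}
    print(f"{name}: MSE={errors[name]['mse']:.4f}  MAPE={errors[name]['mape']:.2f}")
```

For that reason I would lean on the MSE panel when comparing the two methods and treat MAPE as a secondary check.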

What about when we work with causal *forests* and not trees?

In [None]:
# Compare MSE and MAPE across dataset sizes. This time we specify that n_estimators is 100.
n_samples_list = np.logspace(1.7, 4, num=10).astype(int)
mse_results = []
mape_results  = []
kf = KFold(n_splits=8, shuffle=True, random_state=42)

for n_samples in n_samples_list:
    X, w, y, true_cate = generate_data(n_samples=n_samples)
    mse_scores = cross_validate(X, w, y, true_cate, kf, error_fn=mean_squared_error, n_estimators=100, subforest_size=4)
    mse_results.append({"n_samples": n_samples, "honest": mse_scores["honest"], "adaptive": mse_scores["adaptive"]})
    mape_scores = cross_validate(X, w, y, true_cate, kf, error_fn=mean_absolute_percentage_error, n_estimators=100, subforest_size=4)
    mape_results.append({"n_samples": n_samples, "honest": mape_scores["honest"], "adaptive": mape_scores["adaptive"]})

for name, result in (('MSE',mse_results),('MAPE',mape_results)):
    # Prepare data for plotting
    results_df = pd.DataFrame(result).melt(id_vars="n_samples", var_name="method", value_name="mse")

    # Plot results
    sns.barplot(data=results_df, x="n_samples", y="mse", hue="method")
    plt.title(f"{name} comparison across sample sizes")
    plt.xlabel("Number of samples")
    plt.ylabel(f"{name}")
    plt.legend(title="Method")
    plt.show()

### *I want to learn more. What are some good places to start?*

- For comparing model performance, see ["A comparison of methods for model selection when estimating individual treatment effects" by Schuler et al.](https://arxiv.org/pdf/1804.05146) or EconML's [RScorer](https://econml.azurewebsites.net/_autosummary/econml.score.RScorer.html#schuleretal2018) (not to be confused with an R score).
- For Double Machine Learning see ["Double/Debiased Machine Learning for Treatment and Structural Parameters" by Chernozhukov et al.](https://arxiv.org/pdf/1608.00060)
- For Generalized Random Forests see ["Generalized Random Forests" by Athey et al.](https://arxiv.org/pdf/1610.01271)
- For IV in ML see ["Instrumental Variables in Causal Inference and Machine Learning: A Survey" by Wu et al.](https://arxiv.org/pdf/2212.05778)