# Identifying Heterogeneous Treatment Effects using Causal Forest

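This notebook identifies heterogeneous treatment effects with a causal forest in Python. It is a port of an original R implementation and uses `pandas`, `scikit-learn`, `causalml`, and `matplotlib`. The workflow fits a causal forest, ranks units into quintiles of predicted conditional average treatment effect (CATE) within cross-validation folds, and estimates the average treatment effect in each quintile from doubly robust (AIPW) scores.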


In [None]:
# Optional setup: install dependencies into the active kernel (use %pip, not !pip, and pin versions for reproducibility).
# %pip install pandas scikit-learn causalml matplotlib openpyxl

In [None]:
%matplotlib inline
import numpy as np
import pandas as pd
import datetime as dt
import matplotlib.pyplot as plt
from sklearn.model_selection import KFold
from sklearn import preprocessing
from sklearn.preprocessing import StandardScaler
from causalml.inference.tree import CausalForest
from scipy import stats

# Silence every warning category globally.
# NOTE(review): this also hides RuntimeWarnings such as "Mean of empty slice",
# which would reveal empty CATE-ranking groups later on; consider narrowing it.
import warnings
warnings.simplefilter('ignore')

# Display up to 400 columns when rendering wide DataFrames.
pd.options.display.max_columns = 400

In [None]:
import os

# Input / output locations, relative to the notebook's working directory.
DATA_PATH = "../data/"
RESULT_PATH = "../result/"
FILE_NAME = "FILE.csv"  # placeholder: replace with the real input file name

# Create the output directory up front: the figure cells save into RESULT_PATH,
# and savefig raises FileNotFoundError when the directory does not exist.
os.makedirs(RESULT_PATH, exist_ok=True)

In [None]:
# Load the analysis data set and preview the first rows.
input_path = DATA_PATH + FILE_NAME
df = pd.read_csv(input_path)
df.head()

In [None]:
# Column names in the input data. These are placeholders: replace them with
# the real column names before running the notebook.
OUTCOME = "OUTCOME"  # outcome column (Y)
TREATMENT = "TREATMENT"  # treatment indicator column (W)
COVARIATES_LIST = [
    "COVARIATES_LIST",  # covariate columns passed to the forest (X)
]

In [None]:
# Extract covariates (X, DataFrame), outcome (Y, array) and treatment (W, array).
X = df[COVARIATES_LIST]
Y = df[OUTCOME].values
W = df[TREATMENT].values

# The AIPW scores and the W == 1 / W == 0 comparisons later in the notebook only
# make sense for a 0/1-coded treatment; fail fast instead of producing
# meaningless estimates (this also rejects missing treatment values).
_treatment_levels = set(pd.unique(W).tolist())
assert _treatment_levels <= {0, 1}, (
    f"{TREATMENT} must be coded 0/1, found levels {_treatment_levels}"
)

In [None]:
# Settings for ranking the estimated CATEs.
num_folds = 10  # number of CV folds used when ranking tau_hat
num_rankings = 5  # number of CATE groups (5 = quintiles)
kf = KFold(n_splits=num_folds, random_state=1, shuffle=True)

In [None]:
# Refit the forest under many seeds and record summary statistics for each one.
# NOTE(review): n fits x 5000 trees is very expensive; consider caching `results`
# or lowering `n` while iterating.
n = 1000  # number of simulation repetitions (one forest per seed)
results = []

for i in range(n):
    np.random.seed(i)
    cf = CausalForest(n_estimators=5000, random_state=i)
    cf.fit(X.values, W, Y)

    # NOTE(review): predict() on the training rows yields in-sample CATEs; the R
    # original presumably relied on out-of-bag estimates. Confirm what this
    # forest's predict returns.
    tau_hat = cf.predict(X.values).flatten()
    # NOTE(review): confirm that `CausalForest`, `propensity` and
    # `marginal_outcome` exist in the installed causalml version.
    e_hat = cf.propensity
    m_hat = cf.marginal_outcome

    # Rank CATEs into `num_rankings` groups (0 = lowest) within each CV fold.
    # Only the interior quantile cut points are used. Digitizing against the full
    # edges (which include the fold maximum) would label each fold's maximum as
    # `num_rankings` and create an extra, almost empty group.
    rankings = np.full(X.shape[0], np.nan)
    for _, test_index in kf.split(X):
        fold_tau = tau_hat[test_index]
        edges = np.quantile(fold_tau, np.linspace(0, 1, num_rankings + 1))
        rankings[test_index] = np.digitize(fold_tau, edges[1:-1])

    # Outcome predictions under control / treatment implied by m = mu0 + e * tau,
    # then the doubly robust (AIPW) score per unit.
    mu_hat_0 = m_hat - e_hat * tau_hat
    mu_hat_1 = m_hat + (1 - e_hat) * tau_hat
    aipw_scores = tau_hat + W / e_hat * (Y - mu_hat_1) - (1 - W) / (1 - e_hat) * (Y - mu_hat_0)

    # No-intercept OLS of AIPW scores on one-hot group indicators: each coefficient
    # is the mean AIPW score of one group. Fixing the columns to 0..num_rankings-1
    # keeps `ols_coef` the same length (and column order) for every seed.
    X_ols = (rankings[:, None] == np.arange(num_rankings)).astype(float)
    ols_coef = np.linalg.lstsq(X_ols, aipw_scores, rcond=None)[0]

    results.append({
        "tau_hat_mean": np.mean(tau_hat),
        "tau_hat_diff": np.mean(tau_hat[W == 1]) - np.mean(tau_hat[W == 0]),
        "ols_coef": ols_coef,
    })

In [None]:

# Pick the seed whose forest shows the smallest absolute gap between the mean
# predicted CATE of treated and control units.
# NOTE(review): choosing a seed by a statistic that depends on the outcome is a
# form of specification search and can overstate certainty. Consider reporting
# results aggregated over seeds (or a single pre-specified seed) instead.
results_df = pd.DataFrame(results)
abs_tau_gap = results_df["tau_hat_diff"].abs()
best_seed = abs_tau_gap.idxmin()

In [None]:
# Refit the forest with the selected seed and rebuild the quantities used below.
np.random.seed(best_seed)
cf = CausalForest(n_estimators=5000, random_state=best_seed)
cf.fit(X.values, W, Y)

# NOTE(review): in-sample predictions on the training rows; the R original
# presumably used out-of-bag estimates. Confirm what this forest's predict returns.
tau_hat = cf.predict(X.values).flatten()
# NOTE(review): confirm `propensity` / `marginal_outcome` exist on the forest class in use.
e_hat = cf.propensity
m_hat = cf.marginal_outcome

# Assign each unit to one of `num_rankings` CATE groups (0 = lowest) within its CV
# fold. Only the interior cut points are used, so each fold's maximum lands in the
# top group instead of an out-of-range label `num_rankings`.
rankings = np.full(X.shape[0], np.nan)
for _, test_index in kf.split(X):
    tau_hat_quantiles = np.quantile(tau_hat[test_index], np.linspace(0, 1, num_rankings + 1))
    rankings[test_index] = np.digitize(tau_hat[test_index], tau_hat_quantiles[1:-1])

# Outcome predictions under control / treatment implied by m = mu0 + e * tau.
mu_hat_0 = m_hat - e_hat * tau_hat
mu_hat_1 = m_hat + (1 - e_hat) * tau_hat
# Doubly robust (AIPW) score per unit; its mean over a group estimates that group's ATE.
aipw_scores = tau_hat + W / e_hat * (Y - mu_hat_1) - (1 - W) / (1 - e_hat) * (Y - mu_hat_0)

In [None]:
# Estimated individual treatment effects (tau_hat), sorted so the x-axis is each
# unit's rank. Plotting in row order would not match the "ranking" axis label.
fig, ax = plt.subplots(figsize=(10, 6))
ax.scatter(np.arange(len(tau_hat)), np.sort(tau_hat), alpha=0.5)
ax.axhline(y=0, color='grey', linestyle='--')
ax.set_xlabel("Individual ranking of treatment effects")
ax.set_ylabel("Estimated individual Polymyxin-B Hemoperfusion effects on 28-day survival")
ax.set_title("Estimated Treatment Effects")
fig.tight_layout()
fig.savefig(RESULT_PATH + FILE_NAME + "_AIPW.png", dpi=300)
plt.show()

# Per-group ATE: mean AIPW score and its standard error, using the sample SD
# (ddof=1). Groups too small to estimate a quantity get NaN explicitly instead of
# relying on numpy's empty-slice behavior.
forest_ate = []
for q in range(num_rankings):
    group_scores = aipw_scores[rankings == q]
    n_q = group_scores.size
    ate = group_scores.mean() if n_q > 0 else np.nan
    stderr = group_scores.std(ddof=1) / np.sqrt(n_q) if n_q > 1 else np.nan
    forest_ate.append((f"Quintile{q + 1}", ate, stderr))
forest_ate_df = pd.DataFrame(forest_ate, columns=["Quintile", "Estimate", "StdErr"])

In [None]:
# Per-group ATE estimates with 95% normal-approximation intervals.
fig, ax = plt.subplots(figsize=(10, 6))
ax.errorbar(forest_ate_df["Quintile"], forest_ate_df["Estimate"], yerr=1.96 * forest_ate_df["StdErr"], fmt='o', color='black', capsize=5)
ax.axhline(y=0, color='grey', linestyle='--')
ax.set_xlabel("Quintile")
ax.set_ylabel("The treatment effect on Survival")
ax.set_title("Average CATE within each ranking (as defined by predicted CATE)")
fig.tight_layout()
fig.savefig(RESULT_PATH + FILE_NAME + "_ATE.png", dpi=300)
plt.show()

# ATT and ATE within each CATE group, computed from the doubly robust pieces built
# earlier (e_hat, mu_hat_0, aipw_scores). This avoids a forest-specific
# `estimate_ate(..., target="treated")` call, which is not a verified API of the
# forest class in use. Each entry is an (estimate, standard error) tuple;
# groups without enough units yield NaN.
w_all = np.asarray(W, dtype=float)
y_all = np.asarray(Y, dtype=float)
e_all = np.asarray(e_hat, dtype=float)
mu0_all = np.asarray(mu_hat_0, dtype=float)
scores_all = np.asarray(aipw_scores, dtype=float)

att_results = []
ate_results = []
for q in range(num_rankings):
    mask = (rankings == q)
    w_q, y_q, e_q, mu0_q = w_all[mask], y_all[mask], e_all[mask], mu0_all[mask]
    n_q = int(mask.sum())
    n_treated = w_q.sum()

    # ATT (doubly robust): treated units contribute Y - mu0; controls are reweighted
    # by e / (1 - e) to stand in for the treated units' untreated outcomes.
    resid0 = y_q - mu0_q
    g = w_q * resid0 - (1 - w_q) * e_q / (1 - e_q) * resid0
    if n_treated > 0:
        att_est = g.sum() / n_treated
        # Linearized (sandwich) SE of the ratio sum(g) / sum(W).
        att_se = np.sqrt(np.sum((g - w_q * att_est) ** 2)) / n_treated
    else:
        att_est, att_se = np.nan, np.nan

    # ATE: mean AIPW score within the group.
    ate_est = scores_all[mask].mean() if n_q > 0 else np.nan
    ate_se = scores_all[mask].std(ddof=1) / np.sqrt(n_q) if n_q > 1 else np.nan

    att_results.append((att_est, att_se))
    ate_results.append((ate_est, ate_se))

print("ATT Results:", att_results)
print("ATE Results:", ate_results)