Ensemble of Causal Trees

A random forest combines its trees by simply averaging the result of each individual tree, which is both efficient and effective. The same technique applies to causal inference: a causal forest can be grown by combining many individual causal trees. In YLearn, this idea is implemented in the class CTCausalForest (short for Causal Tree Causal Forest).
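The averaging step described above can be illustrated with a minimal sketch. It is not YLearn's internal implementation; it assumes each tree has already produced one effect estimate per sample, and the helper name `average_tree_estimates` is hypothetical.

```python
import numpy as np


def average_tree_estimates(per_tree_estimates):
    """Average per-tree effect estimates into a single forest estimate.

    per_tree_estimates: array-like of shape (n_trees, n_samples).
    This helper is hypothetical and only illustrates the averaging idea.
    """
    estimates = np.asarray(per_tree_estimates, dtype=float)
    if estimates.ndim != 2:
        raise ValueError("expected shape (n_trees, n_samples)")
    # The forest estimate for each sample is the mean over all trees.
    return estimates.mean(axis=0)


# Three hypothetical trees, each estimating the effect for two samples.
tree_estimates = [[1.0, 2.0], [2.0, 4.0], [3.0, 6.0]]
forest_estimate = average_tree_estimates(tree_estimates)
```

Averaging leaves the estimate of each tree untouched and only reduces the variance of the combined result, which is why many trees grown on different subsamples are typically used.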

Because CTCausalForest is an ensemble of CausalTree models, it currently supports only binary treatment. You may need to specify the treatment and control groups before applying CTCausalForest. This limitation is expected to be addressed in a future version.
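Given the binary-treatment restriction, it can help to verify the treatment column before fitting. The following is a minimal sketch using only NumPy; the helper `check_binary_treatment` is hypothetical and not part of YLearn, and treating the smaller value as the control group is an assumption made here for illustration.

```python
import numpy as np


def check_binary_treatment(treatment_values):
    """Return (control_value, treat_value) if the treatment is binary.

    Hypothetical helper: the smaller level is taken as control by
    assumption; raises ValueError when there are not exactly two levels.
    """
    levels = np.unique(np.asarray(treatment_values))  # sorted distinct values
    if levels.size != 2:
        raise ValueError(
            f"expected exactly 2 treatment levels, found {levels.size}"
        )
    return levels[0], levels[1]


# Toy treatment assignments for four units.
control_value, treat_value = check_binary_treatment([0, 1, 1, 0])
```

Which level counts as treated and which as control remains a modelling decision; the check above only confirms that exactly two levels are present.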

An example of using CTCausalForest is given below.

Example

We first build a dataset and define the names of the treatment, outcome, and adjustment (covariate) columns.

import numpy as np
import matplotlib.pyplot as plt

from ylearn.estimator_model import CTCausalForest
from ylearn.exp_dataset.exp_data import sq_data
from ylearn.utils._common import to_df


# build dataset
n = 2000   # number of samples
d = 10     # dimension of the covariates v
n_x = 1    # dimension of the treatment
y, x, v = sq_data(n, d, n_x)
true_te = lambda X: np.hstack([X[:, [0]]**2 + 1, np.ones((X.shape[0], n_x - 1))])
data = to_df(treatment=x, outcome=y, v=v)
outcome = 'outcome'
treatment = 'treatment'
adjustment = data.columns[2:]

# build test data
v_test = v[:min(100, n)].copy()
v_test[:, 0] = np.linspace(np.percentile(v[:, 0], 1), np.percentile(v[:, 0], 99), min(100, n))
test_data = to_df(v=v_test)

We now train the CTCausalForest and apply it to the test data. For better performance, it is also recommended to set honest_subsample_num to a value other than None.

ctcf = CTCausalForest(
    n_jobs=-1,
    honest_subsample_num=0.5,
    min_samples_split=2,
    min_samples_leaf=10,
    sub_sample_num=0.8,
    n_estimators=500,
    random_state=2022,
    min_impurity_decrease=1e-10,
    max_depth=100,
    max_features=0.8,
    verbose=0,
)
ctcf.fit(data=data, outcome=outcome, treatment=treatment, adjustment=adjustment)
ctcf_pred = ctcf.estimate(data=test_data)
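Because the example defines the true effect as `true_te`, the estimates can be compared against it. The sketch below is self-contained and does not call YLearn: it assumes the prediction holds one effect estimate per test row, the helper names `true_effect` and `rmse` are hypothetical, and `stand_in_pred` is a placeholder for `ctcf_pred`, not real model output.

```python
import numpy as np


def true_effect(v):
    """Ground-truth effect from the example for n_x = 1: v[:, 0]**2 + 1."""
    return v[:, 0] ** 2 + 1


def rmse(pred, v):
    """Root-mean-square error between predictions and the true effect.

    Assumes `pred` holds one estimate per row of `v` (an assumption
    about the shape of the estimator output).
    """
    pred = np.asarray(pred, dtype=float).reshape(-1)
    return float(np.sqrt(np.mean((pred - true_effect(v)) ** 2)))


# Small grid over the first covariate, mirroring how v_test is built.
v_grid = np.zeros((5, 10))
v_grid[:, 0] = np.linspace(-1.0, 1.0, 5)

# Placeholder for ctcf_pred: the truth shifted by 0.1 (NOT real model output).
stand_in_pred = true_effect(v_grid) + 0.1
error = rmse(stand_in_pred, v_grid)
```

In the actual workflow one would pass `ctcf_pred` and `v_test` in place of the placeholders, after checking that the prediction shape matches the assumption above.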

Class Structures