Randomized controlled trials (RCTs) are the cornerstone for learning causal effects, yet extending their inferences to target populations is challenging because of effect heterogeneity and underrepresentation. Our paper addresses the problem of identifying and characterizing subgroups that are underrepresented in RCTs and proposes a framework for refining target populations to improve generalizability. We introduce an optimization-based approach, the Rashomon Set of Optimal Trees (ROOT), to characterize underrepresented groups. ROOT optimizes the distribution of the target subpopulation by minimizing the variance of the target average treatment effect estimate, which yields more precise treatment effect estimates. Notably, ROOT describes the underrepresented population in interpretable terms, which helps researchers communicate their findings.
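To make the variance argument concrete, here is a minimal sketch that uses a standard inverse-probability-weighted (IPW) transport estimator, swapped in for illustration; ROOT's actual estimator and objective may differ. The function name `ipw_target_ate`, the simulated data, and the assumption that the trial-participation probability `pi` = P(S = 1 | X) is known are all hypothetical. Trial units whose covariates are rare in the trial (small `pi`) receive large weights `(1 - pi) / pi`, so excluding that region from the target subpopulation (setting `w = 0` there) shrinks the variance, at the cost of changing the estimand to the average effect over the retained units.

```python
import numpy as np


def ipw_target_ate(y, t, s, pi, e, w):
    """Illustrative IPW estimate of the ATE over target units (s == 0) with w == 1,
    plus a plug-in variance estimate.

    y, t, s : outcome, treatment, and sample indicator (s == 1 for trial units)
    pi      : P(S = 1 | X), assumed known here for simplicity
    e       : P(T = 1 | X, S = 1), the trial's treatment probability
    w       : 0/1 indicator defining the refined target subpopulation
    """
    n = len(y)
    n_target = np.sum(w * (1 - s))             # size of the refined target sample
    odds = (1 - pi) / pi                       # large where trial participation is rare
    contrast = t * y / e - (1 - t) * y / (1 - e)
    # Target units contribute zero through the factor s
    psi = n * s * w * odds * contrast / n_target
    # Plug-in variance that treats n_target as fixed
    return psi.mean(), psi.var(ddof=1) / n


rng = np.random.default_rng(0)
n = 2000
x = rng.uniform(size=n)
pi = np.where(x > 0.8, 0.05, 0.5)              # units with x > 0.8 are rarely in the trial
s = rng.binomial(1, pi)
e = np.full(n, 0.5)
t = rng.binomial(1, e)
y = x + t * (1 + x) + rng.normal(size=n)

full = ipw_target_ate(y, t, s, pi, e, np.ones(n))
refined = ipw_target_ate(y, t, s, pi, e, (x <= 0.8).astype(float))
print(f"all target units:       est={full[0]:.3f}, var={full[1]:.4f}")
print(f"drop underrepresented:  est={refined[0]:.3f}, var={refined[1]:.4f}")
```

The refined estimate targets a different quantity (the effect for units with x <= 0.8), which is exactly the tradeoff that motivates describing the excluded region in interpretable terms.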
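```python
import warnings

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from sklearn import tree

import dgp  # local helper module that provides the simulated data used here

warnings.filterwarnings("ignore")

# Generate the simulated data (n = 1000, fixed seed for reproducibility)
df, Y = dgp.get_data(n=1000, seed=0)

# Column names of the outcome, the treatment indicator, and the
# sample indicator (S = 1 for experimental/RCT units)
outcome = "Yobs"
treatment = "T"
sample = "S"
```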
Our approach takes as input a single dataframe that combines the experimental (RCT) units and the target-sample units, the names of the outcome, treatment, and sample-indicator columns, and a set of hyperparameters. In general, using more trees improves performance, but run time grows with the number of trees, so there is a tradeoff between computation time and performance.
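For readers assembling their own input, here is a minimal sketch of such a combined dataframe, using this notebook's column names (`X0` to `X9`, `T`, `Yobs`, `S`). The values are synthetic, and leaving the treatment and outcome missing (NaN) for target units is an assumption; whether `ROOT.forest_opt` expects those columns to be filled for target units is not stated here, so check ROOT's own documentation.

```python
import numpy as np
import pandas as pd

rng = np.random.default_rng(0)
feature_cols = [f"X{j}" for j in range(10)]

# Hypothetical experimental sample (S = 1): covariates, randomized treatment, observed outcome
trial = pd.DataFrame(rng.uniform(size=(200, 10)), columns=feature_cols)
trial["T"] = rng.binomial(1, 0.5, size=len(trial))
trial["Yobs"] = trial["X0"] + trial["T"] * (1 + trial["X1"]) + rng.normal(size=len(trial))
trial["S"] = 1

# Hypothetical target sample (S = 0): covariates only; T and Yobs become NaN after concat
target = pd.DataFrame(rng.uniform(size=(300, 10)), columns=feature_cols)
target["S"] = 0

# One combined dataframe, as described above
combined = pd.concat([trial, target], ignore_index=True)
print(combined["S"].value_counts())
```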
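```python
import ROOT

# Fit ROOT on the combined data. More trees generally improve performance
# but increase run time.
D_rash, D_forest, w_forest, rashomon_set, f, testing_data = ROOT.forest_opt(
    data=df,
    outcome=outcome,
    treatment=treatment,
    sample=sample,
    leaf_proba=1,
    num_trees=3000,
    vote_threshold=2 / 5,
)
# f is the summary decision tree that separates well-represented units
# from underrepresented ones; it is plotted and used for prediction below.
```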
```text
Fold 0
Fold 1
Fold 2
Fold 3
Fold 4
ATE Est: 1.2211
leaf 0.500000
X0 0.156484
X1 0.105682
X2 0.018238
X3 0.044944
X4 0.019970
X5 0.058912
X6 0.030142
X7 0.054766
X8 0.001075
X9 0.009787
dtype: float64
```
In the plotted tree `f`, units in the orange leaves (predicted class 0) are underrepresented in the RCT, and units in the blue leaves (predicted class 1) are well represented.
```python
tree.plot_tree(f, filled=True)
```
Node contents of the plotted tree (the true branch of each split is listed first; `value` gives sample counts for classes [0, 1]):

```text
x[0] <= 0.497  (gini = 0.337, samples = 457, value = [98, 359])
├── True:  leaf  (gini = 0.0, samples = 300, value = [0, 300])      blue
└── False: x[5] <= 0.51  (gini = 0.469, samples = 157, value = [98, 59])
    ├── True:  x[3] <= 0.209  (gini = 0.153, samples = 72, value = [66, 6])
    │   ├── True:  leaf  (gini = 0.494, samples = 9, value = [4, 5])     blue
    │   └── False: leaf  (gini = 0.031, samples = 63, value = [62, 1])   orange
    └── False: x[1] <= 0.506  (gini = 0.469, samples = 85, value = [32, 53])
        ├── True:  leaf  (gini = 0.301, samples = 65, value = [12, 53])  blue
        └── False: leaf  (gini = 0.0, samples = 20, value = [20, 0])     orange
```
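The tree's rules can also be printed as text, which is handy when the rendered figure is unavailable or when leaf definitions need to be quoted in a write-up. The sketch below uses scikit-learn's `export_text` on a small tree fit to hypothetical data (the three features and the labeling rule are made up); it only stands in for `f`.

```python
import numpy as np
from sklearn.tree import DecisionTreeClassifier, export_text

rng = np.random.default_rng(0)
X = rng.uniform(size=(500, 3))
# Hypothetical labels: class 0 ("underrepresented") when X0 > 0.5 and X2 > 0.5, else class 1
w = np.where((X[:, 0] > 0.5) & (X[:, 2] > 0.5), 0, 1)

clf = DecisionTreeClassifier(max_depth=2, random_state=0).fit(X, w)

# Each line is a split condition or a leaf's predicted class
rules = export_text(clf, feature_names=["X0", "X1", "X2"])
print(rules)
```

For the notebook's own tree, something like `print(tree.export_text(f, feature_names=[c for c in df.columns if c.startswith("X")]))` should work, assuming `f` is a scikit-learn decision tree fit on the `X` columns in that order; leaves reported as `class: 0` would then correspond to the orange (underrepresented) leaves.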
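```python
# Keep only the experimental (RCT) units; .copy() avoids pandas' SettingWithCopyWarning
df_exp = df.loc[df[sample] == 1].copy()

# Covariate columns X0, ..., X9
feature_cols = [col for col in df_exp.columns if col.startswith("X")]

# Predict w with the summary tree: 1 = well represented (blue), 0 = underrepresented (orange)
df_exp["w"] = f.predict(df_exp[feature_cols])

# Refined experimental sample: units that fall in the blue leaves
df_exp_refined = df_exp.loc[df_exp["w"] == 1]
df_exp_refined.describe().round(2)
```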
| | X0 | X1 | X2 | X3 | X4 | X5 | X6 | X7 | X8 | X9 | Yobs | S | T | w |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| count | 377.00 | 377.00 | 377.00 | 377.00 | 377.00 | 377.00 | 377.00 | 377.00 | 377.00 | 377.00 | 377.00 | 377.0 | 377.00 | 377.0 |
| mean | 0.34 | 0.45 | 0.50 | 0.47 | 0.52 | 0.53 | 0.48 | 0.48 | 0.52 | 0.49 | 14.11 | 1.0 | 0.51 | 1.0 |
| std | 0.25 | 0.27 | 0.28 | 0.29 | 0.29 | 0.29 | 0.29 | 0.29 | 0.28 | 0.28 | 4.96 | 0.0 | 0.50 | 0.0 |
| min | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 2.37 | 1.0 | 0.00 | 1.0 |
| 25% | 0.14 | 0.21 | 0.25 | 0.20 | 0.26 | 0.29 | 0.23 | 0.22 | 0.27 | 0.25 | 10.40 | 1.0 | 0.00 | 1.0 |
| 50% | 0.30 | 0.43 | 0.50 | 0.46 | 0.56 | 0.56 | 0.48 | 0.50 | 0.53 | 0.50 | 13.98 | 1.0 | 1.00 | 1.0 |
| 75% | 0.47 | 0.66 | 0.73 | 0.71 | 0.75 | 0.78 | 0.71 | 0.72 | 0.75 | 0.72 | 17.51 | 1.0 | 1.00 | 1.0 |
| max | 0.98 | 1.00 | 1.00 | 1.00 | 1.00 | 1.00 | 1.00 | 0.99 | 1.00 | 1.00 | 27.01 | 1.0 | 1.00 | 1.0 |
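The summary above describes only the retained units. A quick way to see which covariates drive the exclusion is to contrast covariate means between retained (`w = 1`) and excluded (`w = 0`) experimental units. The sketch below uses a hypothetical stand-in for `df_exp`, with a made-up exclusion rule in place of `f.predict`; its numbers are not ROOT output.

```python
import numpy as np
import pandas as pd

rng = np.random.default_rng(1)
df_exp = pd.DataFrame(rng.uniform(size=(400, 3)), columns=["X0", "X1", "X5"])
# Hypothetical stand-in for f.predict(...): exclude units with high X0 and high X5
df_exp["w"] = np.where((df_exp["X0"] > 0.5) & (df_exp["X5"] > 0.5), 0, 1)

# Covariate means by group: rows are covariates, columns are w = 0 and w = 1
summary = df_exp.groupby("w").mean().T
summary["diff"] = summary[1] - summary[0]   # retained minus excluded
print(summary.round(2))
```

With the notebook's objects, the analogous call would be `df_exp.groupby("w")[[c for c in df_exp.columns if c.startswith("X")]].mean().T`; covariates with large differences are the ones that define the underrepresented group.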