
ROOT is a Python package that helps analyze and identify underrepresented populations in an RCT with respect to a target population.


CI-NYC/ROOT


ROOT: Rashomon Set of Optimal Trees for Characterizing Underrepresented Populations

Randomized controlled trials (RCTs) serve as the cornerstone for understanding causal effects, yet extending inferences to target populations presents challenges due to effect heterogeneity and underrepresentation. Our paper addresses the critical issue of identifying and characterizing underrepresented subgroups in RCTs, proposing a novel framework for refining target populations to improve generalizability. We introduce an optimization-based approach, Rashomon Set of Optimal Trees (ROOT), to characterize underrepresented groups. ROOT optimizes the target subpopulation distribution by minimizing the variance of the target average treatment effect estimate, which yields more precise treatment effect estimates. Notably, ROOT generates interpretable characteristics of the underrepresented population, aiding researchers in communicating their findings effectively.
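The variance-minimization idea can be illustrated with a deliberately simplified toy, which is not ROOT's estimator or objective. It assumes hypothetical per-unit pseudo-outcomes psi whose mean estimates the target ATE, and it restricts the inclusion indicator w to a single-split rule w = 1{x <= c}, a depth-one stand-in for a tree. It then picks the threshold c that minimizes the squared standard error of the mean over the kept units, subject to a hypothetical minimum subgroup size.

```python
import numpy as np


def subgroup_mean_and_variance(psi, w):
    """Mean of psi over units with w == 1 and its squared standard error."""
    kept = psi[w == 1]
    est = kept.mean()
    var = kept.var(ddof=1) / kept.size  # sample variance / n
    return est, var


def best_threshold(x, psi, min_size=3):
    """Toy search over one-split rules w = 1{x <= c} (hypothetical helper).

    Returns (c, variance) for the threshold whose kept subgroup has the
    smallest squared standard error, among rules keeping >= min_size units.
    """
    best = None
    for c in np.unique(x):
        w = (x <= c).astype(int)
        if w.sum() < min_size:
            continue  # too few units kept to trust the estimate
        _, var = subgroup_mean_and_variance(psi, w)
        if best is None or var < best[1]:
            best = (c, var)
    return best


# Toy values for illustration only (not data from ROOT or the paper)
x = np.array([0.1, 0.2, 0.3, 0.4, 0.5, 0.9])
psi = np.array([1.0, 1.2, 0.8, 1.1, 0.9, 6.0])

c, var = best_threshold(x, psi)
print(f"keep units with x <= {c}; variance = {var:.4f}")
```

In this toy, the search drops the single unit with an extreme pseudo-outcome (x = 0.9). Roughly, this is the intuition for excluding poorly covered regions: in weighting-based estimators, units that need very large weights inflate the variance.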

Importing Necessary Libraries / Packages

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import warnings
import sklearn.tree as tree

warnings.filterwarnings("ignore")

Generating Synthetic Data for the Example

import dgp

df, Y = dgp.get_data(n=1000, seed=0)
outcome = "Yobs"
treatment = "T"
sample = "S"
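If you want to try the workflow without the repository's dgp module, the hypothetical stand-in below produces one combined dataframe with the column layout used in this README (X0 to X9, Yobs, S, T). Its data-generating process is invented for illustration and is not the one in dgp; it also returns only the dataframe, not the second value returned by dgp.get_data. Treatment and outcome are generated for every unit purely for simplicity.

```python
import numpy as np
import pandas as pd


def make_toy_data(n=1000, seed=0):
    """Hypothetical stand-in for dgp.get_data (NOT the repository's DGP).

    Returns one combined dataframe with covariates X0..X9, sample indicator
    S (1 = RCT, 0 = target), treatment T and observed outcome Yobs.
    """
    rng = np.random.default_rng(seed)
    X = rng.uniform(0.0, 1.0, size=(n, 10))
    df = pd.DataFrame(X, columns=[f"X{j}" for j in range(10)])

    # RCT participation is less likely when X0 is large,
    # so that region of the covariate space is underrepresented
    p_rct = 1.0 / (1.0 + np.exp(-(2.0 - 4.0 * X[:, 0])))
    df["S"] = rng.binomial(1, p_rct)

    # Randomized treatment and a treatment effect that varies with X0
    df["T"] = rng.binomial(1, 0.5, size=n)
    y0 = 10.0 + 5.0 * X[:, 1] + rng.normal(0.0, 1.0, size=n)
    tau = 1.0 + 2.0 * X[:, 0]
    df["Yobs"] = y0 + df["T"].to_numpy() * tau
    return df
```

Usage: df = make_toy_data(n=1000, seed=0) replaces the dgp.get_data call, and the outcome, treatment, and sample names stay "Yobs", "T", and "S".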

Running ROOT (forest)

Our approach takes as input a single dataframe combining experimental and target-sample units, the names of the outcome, treatment, and sample-indicator columns, and hyperparameters. In general, the larger the number of trees, the better the performance; however, runtime also grows with the number of trees, so there is a tradeoff between time and performance.
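If the RCT and target samples live in separate tables, they can be stacked into the single dataframe this step expects. The helper below is a hypothetical sketch: it assumes S = 1 marks RCT units (consistent with the df.loc[df["S"] == 1] step later in this README) and S = 0 marks target units. Columns missing from one frame, such as T or Yobs in the target sample, become NaN; whether ROOT.forest_opt accepts such missing values is not stated here, so check before relying on it.

```python
import pandas as pd


def combine_samples(rct, target, sample="S"):
    """Stack RCT and target units into one dataframe (hypothetical helper).

    Adds a sample indicator: 1 for RCT rows, 0 for target rows (assumed
    coding). Columns present in only one frame are filled with NaN.
    """
    rct = rct.assign(**{sample: 1})
    target = target.assign(**{sample: 0})
    return pd.concat([rct, target], ignore_index=True)
```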

import ROOT

D_rash, D_forest, w_forest, rashomon_set, f, testing_data = ROOT.forest_opt(
    data=df,
    outcome=outcome,
    treatment=treatment,
    sample=sample,
    leaf_proba=1,
    num_trees=3000,
    vote_threshold=2 / 5,
)
Fold 0
Fold 1
Fold 2
Fold 3
Fold 4
ATE Est: 1.2211
leaf    0.500000
X0      0.156484
X1      0.105682
X2      0.018238
X3      0.044944
X4      0.019970
X5      0.058912
X6      0.030142
X7      0.054766
X8      0.001075
X9      0.009787
dtype: float64
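The name ROOT refers to a Rashomon set: the set of near-optimal models whose objective lies within a tolerance of the best one. This README does not spell out how ROOT builds and aggregates its Rashomon set, or exactly what num_trees and vote_threshold control, so the sketch below is only a generic illustration under assumed rules (a multiplicative tolerance epsilon on a positive, lower-is-better objective, and a vote-share threshold). It is not ROOT's implementation, and rashomon_vote is a hypothetical name.

```python
import numpy as np


def rashomon_vote(objectives, votes, epsilon=0.05, threshold=0.5):
    """Generic Rashomon-set aggregation (illustrative, not ROOT's code).

    objectives: shape (n_trees,), positive, lower is better.
    votes: shape (n_trees, n_units), each tree's 0/1 assignment per unit.
    Keeps trees whose objective is within (1 + epsilon) of the best, then
    sets w = 1 for a unit when the share of kept trees voting 1 is >= threshold.
    """
    objectives = np.asarray(objectives, dtype=float)
    votes = np.asarray(votes)
    best = objectives.min()
    in_set = objectives <= (1.0 + epsilon) * best  # near-optimal trees
    share = votes[in_set].mean(axis=0)  # fraction of kept trees voting 1
    w = (share >= threshold).astype(int)
    return in_set, w
```

With objectives [1.0, 1.02, 1.5], the third tree falls outside the set, and the threshold then decides the units on which the two remaining trees disagree.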

Plotting the Characteristic Tree

Units in the orange leaves (predicted w = 0) are underrepresented in the RCT, and units in the blue leaves (predicted w = 1) are well represented.

tree.plot_tree(f, filled=True)
Text rendering of the plotted tree (value = [class 0, class 1] sample counts; each leaf predicts its majority class):

x[0] <= 0.497 (gini = 0.337, samples = 457, value = [98, 359])
    True:  leaf (gini = 0.0, samples = 300, value = [0, 300]) -> class 1
    False: x[5] <= 0.51 (gini = 0.469, samples = 157, value = [98, 59])
        True:  x[3] <= 0.209 (gini = 0.153, samples = 72, value = [66, 6])
            True:  leaf (gini = 0.494, samples = 9, value = [4, 5]) -> class 1
            False: leaf (gini = 0.031, samples = 63, value = [62, 1]) -> class 0
        False: x[1] <= 0.506 (gini = 0.469, samples = 85, value = [32, 53])
            True:  leaf (gini = 0.301, samples = 65, value = [12, 53]) -> class 1
            False: leaf (gini = 0.0, samples = 20, value = [20, 0]) -> class 0
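The plotted tree labels splits by column position (x[0], x[5], and so on). scikit-learn's export_text, like plot_tree, accepts a feature_names argument that shows covariate names instead. Since f comes from ROOT.forest_opt, the sketch below fits a toy DecisionTreeClassifier as a stand-in. With the real tree, you would pass the columns f was fit on; we assume these are X0 to X9 in order, as the prediction step in this README suggests.

```python
import numpy as np
import pandas as pd
from sklearn.tree import DecisionTreeClassifier, export_text

# Toy stand-in for the characteristic tree f: labels w follow a simple rule on X0
rng = np.random.default_rng(0)
X = pd.DataFrame(rng.uniform(0, 1, size=(200, 10)), columns=[f"X{j}" for j in range(10)])
w = (X["X0"] <= 0.5).astype(int)

toy_tree = DecisionTreeClassifier(max_depth=3, random_state=0).fit(X, w)

# export_text prints the splits using readable feature names
report = export_text(toy_tree, feature_names=list(X.columns))
print(report)
```

For the figure itself, tree.plot_tree(f, filled=True, feature_names=[f"X{j}" for j in range(10)], class_names=["w = 0", "w = 1"]) labels nodes the same way, under the same assumption about feature order.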

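Getting the Refined Experimental Data for Further Analysis

# Keep only the experimental (RCT) units; .copy() avoids pandas' SettingWithCopyWarning
df_exp = df.loc[df[sample] == 1].copy()
# Predict w with the characteristic tree (w = 1: well represented, w = 0: underrepresented)
feature_cols = [col for col in df_exp.columns if col.startswith("X")]
df_exp["w"] = f.predict(df_exp[feature_cols])

# Keep the well-represented units and summarize them
df_exp_refined = df_exp.loc[df_exp["w"] == 1]
df_exp_refined.describe().round(2)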
          X0     X1     X2     X3     X4     X5     X6     X7     X8     X9   Yobs      S      T      w
count 377.00 377.00 377.00 377.00 377.00 377.00 377.00 377.00 377.00 377.00 377.00  377.0 377.00  377.0
mean    0.34   0.45   0.50   0.47   0.52   0.53   0.48   0.48   0.52   0.49  14.11    1.0   0.51    1.0
std     0.25   0.27   0.28   0.29   0.29   0.29   0.29   0.29   0.28   0.28   4.96    0.0   0.50    0.0
min     0.00   0.00   0.00   0.00   0.00   0.00   0.00   0.00   0.00   0.00   2.37    1.0   0.00    1.0
25%     0.14   0.21   0.25   0.20   0.26   0.29   0.23   0.22   0.27   0.25  10.40    1.0   0.00    1.0
50%     0.30   0.43   0.50   0.46   0.56   0.56   0.48   0.50   0.53   0.50  13.98    1.0   1.00    1.0
75%     0.47   0.66   0.73   0.71   0.75   0.78   0.71   0.72   0.75   0.72  17.51    1.0   1.00    1.0
max     0.98   1.00   1.00   1.00   1.00   1.00   1.00   0.99   1.00   1.00  27.01    1.0   1.00    1.0
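A common follow-up check, not a step this README prescribes, is covariate balance between the refined RCT units and the refined target population, for example via standardized mean differences (SMD). With the objects above you might compare df_exp_refined against the target units that the tree also assigns w = 1. Here we assume target units are the rows with S == 0. The sketch uses toy data so that it runs on its own, and standardized_mean_differences is a hypothetical helper.

```python
import numpy as np
import pandas as pd


def standardized_mean_differences(df, group_col, covariates):
    """SMD per covariate between group_col == 1 and group_col == 0:
    (mean_1 - mean_0) / sqrt((var_1 + var_0) / 2)."""
    g1 = df.loc[df[group_col] == 1, covariates]
    g0 = df.loc[df[group_col] == 0, covariates]
    pooled_sd = np.sqrt((g1.var() + g0.var()) / 2.0)
    return (g1.mean() - g0.mean()) / pooled_sd


# Toy example: group 1 = refined RCT units, group 0 = a comparison sample
rng = np.random.default_rng(0)
refined = pd.DataFrame({"X0": rng.uniform(0.0, 0.5, 300), "X1": rng.uniform(0.0, 1.0, 300)})
comparison = pd.DataFrame({"X0": rng.uniform(0.0, 1.0, 300), "X1": rng.uniform(0.0, 1.0, 300)})
toy_df = pd.concat([refined.assign(grp=1), comparison.assign(grp=0)], ignore_index=True)

smd = standardized_mean_differences(toy_df, "grp", ["X0", "X1"])
print(smd.round(2))
```

A large absolute SMD (here on X0) flags a covariate whose distribution still differs between the two groups.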
