In [1]:
%load_ext autoreload
%autoreload 2

In [None]:
#!pip install POT

In [244]:
# from CausalBenchmark.data_generation_framework import *
from CausalBenchmark.parameters import build_parameters, DEFAULT_PARAMETER_PATH
from CausalBenchmark.data_generation import DataGeneratingProcessWrapper
from CausalBenchmark.utilities import generate_random_covariates, normalize_covariate_data
from CausalBenchmark.data_metrics import calculate_data_metrics

In [304]:
import json
from collections import defaultdict

import numpy as np
from IPython.display import clear_output, display

In [300]:
# Draw 12 random covariates and pass them straight through normalize_covariate_data.
# NOTE(review): no seed is passed here, so re-running this cell presumably yields different data.
covar_data = normalize_covariate_data(generate_random_covariates(n_covars=12))

In [301]:
# Start from the default parameter file, then override the knobs explored in this section.
dgp_params = build_parameters(DEFAULT_PARAMETER_PATH)
for _attr, _value in {
    "POTENTIAL_CONFOUNDER_SELECTION_PROBABILITY": 0.99,
    "ACTUAL_CONFOUNDER_ALIGNMENT": 0.99,
    "TREATMENT_EFFECT_HETEROGENEITY": 0,
    "OUTCOME_NOISE_TAIL_THICKNESS": 250,
}.items():
    setattr(dgp_params, _attr, _value)
del _attr, _value  # keep the loop variables out of the notebook namespace

In [313]:
# Build a wrapper over the normalized covariates using the parameters configured above.
dgp_wrapper = DataGeneratingProcessWrapper(
    source_covariate_data=covar_data,
    parameters=dgp_params,
)
# Sample a concrete DGP before asking the wrapper to generate any data.
dgp_wrapper.sample_dgp()

# generate_data returns four frames: observed covariates/outcomes, then oracle covariates/outcomes.
(
    observed_covariate_data,
    observed_outcome_data,
    oracle_covariate_data,
    oracle_outcome_data,
) = dgp_wrapper.generate_data()

In [314]:
# Preview each generated table. A cell only renders its *last* bare expression,
# so the earlier previews were silently discarded; display() shows all of them.
display(
    observed_outcome_data.head(),
    oracle_outcome_data.join(oracle_covariate_data).head(),
    observed_outcome_data.join(observed_covariate_data).head(),
    oracle_outcome_data.head(),
)

Unnamed: 0,logit(P(T|X)),P(T|X),Y0,Y1,TE
0,-0.274829,0.431722,1.024203,2.535203,1.511
1,1.381897,0.799296,0.684659,2.195659,1.511
2,1.262487,0.779454,0.734363,2.245363,1.511
3,1.082848,0.747032,-1.178141,0.332859,1.511
4,1.443417,0.808983,1.223133,2.734133,1.511


## Metrics

In [253]:
# Score the generated data and pretty-print the nested metrics dict.
metrics = calculate_data_metrics(
    observed_covariate_data,
    observed_outcome_data,
    oracle_covariate_data,
    oracle_outcome_data,
)
print(json.dumps(metrics, indent=4))

{
    "outcome non-linearity": {
        "Lin r2(X_obs, Y)": 0.8283097174903884,
        "Lin r2(X_true, Y)": 0.8110386742121254,
        "Lin r2(X_obs, Y1)": 0.975642647751361,
        "Lin r2(X_obs, Y0)": 0.975642647751361,
        "Lin r2(X_true, Y1)": 0.9291062024145316,
        "Lin r2(X_true, Y0)": 0.9291062024145315,
        "Lin r2(X_obs, TE)": 0.0,
        "Lin r2(X_true, TE)": 0.0
    },
    "treatment non-linearity": {
        "Log r2(X_obs, T)": 0.7,
        "Lin r2(X_obs, Treat Logit)": 0.974004210878853
    },
    "Percent": {
        "Percent(T, 1)": 49.2
    },
    "Overlap": {
        "NN c-factual dist(X_obs, T)": 2.606809057275519,
        "NN c-factual dist(X_true, T)": 3.1066770669054
    },
    "Balance": {
        "mean dist(X_true, T)": 0.7144340108534263,
        "Wass dist(X_true, T)": 0.025022391435119095,
        "Wass dist(X_obs, T)": 0.08932139846986369
    },
    "Alignment": {
        "Lin r2(Y, Treat Logit)": 0.807104214648415,
        "Lin r2(Y0, Treat

In [315]:
# Show the outcome function drawn by sample_dgp() (rendered as a symbolic expression).
dgp_wrapper.outcome_function

NOISE(Y) + 1.511*T - 1.53702176552532*X0**2 - 0.509066846728832*X0 - 0.502519363169619*X1 - 2.0739154173808*X3 - 0.0311005469062631*X4**2 - 0.646564001472311*X4 - 1.43717264124731*X5**2 + 1.22274255468308*X6**2 + 2.15903270365058*X6 + 1.25384310158934*X7**2 + 0.554907925328674

In [316]:
# Show the treatment-assignment logit drawn by sample_dgp(). In the captured output the
# logit is clipped to +/-2.197 (= ln 9), which bounds P(T|X) to [0.1, 0.9] for that sample.
dgp_wrapper.treatment_assignment_logit_function

Max(-2.19722457733622, Min(2.19722457733622, -0.47289501270356*X0**3 - 0.860266459279881*X1**3 - 0.149247007555379*X1*X5 - 0.5148183294326*X1 - 0.588603366875708*X11**2 - 0.727788778416118*X4 - 1.75742543728132*X5 + 1.28453042457776*X7**2 + 0.954335678106122))

In [308]:
def run(dgp_params, n_runs=10, n_covars=10):
    """Sample repeated DGPs from ``dgp_params`` and collect alignment metrics.

    Each iteration does the following:
      * generates and normalizes a fresh set of random covariates;
      * builds a new ``DataGeneratingProcessWrapper`` and samples a DGP from it;
      * scores the generated data with ``calculate_data_metrics``.

    Args:
        dgp_params: parameter object (as returned by ``build_parameters``)
            shared by every sampled DGP.
        n_runs: number of independent DGP samples to draw (default 10).
        n_covars: number of random covariates generated per run (default 10).

    Returns:
        ``defaultdict(list)`` with one entry per run under each key:
          * ``"Alignment Obs"``: the ``"Lin r2(Y, Treat Logit)"`` metric.
          * ``"Alignment Oracle"``: the ``"Lin r2(Y0, Treat Logit)"`` metric.
    """
    results = defaultdict(list)
    for run_idx in range(n_runs):
        # Replace the previous progress line so the cell output stays a single counter.
        clear_output()
        print(run_idx + 1)

        covar_data = normalize_covariate_data(
            generate_random_covariates(n_covars=n_covars))

        dgp_wrapper = DataGeneratingProcessWrapper(
            parameters=dgp_params, source_covariate_data=covar_data)
        dgp_wrapper.sample_dgp()

        (observed_covariate_data, observed_outcome_data,
         oracle_covariate_data, oracle_outcome_data) = dgp_wrapper.generate_data()
        metrics = calculate_data_metrics(
            observed_covariate_data, observed_outcome_data,
            oracle_covariate_data, oracle_outcome_data)

        alignment = metrics["Alignment"]
        results["Alignment Obs"].append(alignment["Lin r2(Y, Treat Logit)"])
        results["Alignment Oracle"].append(alignment["Lin r2(Y0, Treat Logit)"])

    return results

In [325]:
# Re-load the defaults and apply a weaker-alignment / heterogeneous-effect configuration.
dgp_params = build_parameters(DEFAULT_PARAMETER_PATH)
for _attr, _value in {
    "POTENTIAL_CONFOUNDER_SELECTION_PROBABILITY": 0.5,
    "ACTUAL_CONFOUNDER_ALIGNMENT": 0.1,
    "TREATMENT_EFFECT_HETEROGENEITY": 1,
    "OUTCOME_NOISE_TAIL_THICKNESS": 250,
}.items():
    setattr(dgp_params, _attr, _value)
del _attr, _value  # keep the loop variables out of the notebook namespace

# Repeatedly sample DGPs under this configuration and gather alignment metrics.
res = run(dgp_params)

10


In [326]:
# Mean and maximum alignment across the sampled DGPs: observed row first, oracle row second.
print(*(stat(res["Alignment Obs"]) for stat in (np.mean, np.max)))
print(*(stat(res["Alignment Oracle"]) for stat in (np.mean, np.max)))

0.14497290062566132 0.6508155806052487
0.16608106441328108 0.6519570211223438
