In [1]:
!pip install -e ../CauseML

Obtaining file:///home/jovyan/work/CauseML
Installing collected packages: cause-ml
  Found existing installation: cause-ml 0.0.11
    Uninstalling cause-ml-0.0.11:
      Successfully uninstalled cause-ml-0.0.11
  Running setup.py develop for cause-ml
Successfully installed cause-ml


In [1]:
%load_ext autoreload
%autoreload 2

In [3]:
from cause_ml.parameters import build_parameters_from_axis_levels
from cause_ml.constants import Constants
from cause_ml.data_generation import DataGeneratingProcessSampler
import cause_ml.data_sources as data_sources
from cause_ml.benchmarking import run_benchmark

In [17]:
import pandas as pd

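## Model Demo

Sample one data-generating process (DGP) with low outcome and treatment nonlinearity, generate a dataset from it, and compare the dataset's ATE with a linear-regression model's estimate.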

In [8]:
# Covariate source: n_covars=10 covariates over n_observations=1000 rows.
covar_data_source = data_sources.load_random_normal_covariates(n_covars=10, n_observations=1000)

# A single DGP configuration: low nonlinearity on both the outcome and treatment axes.
dgp_params = build_parameters_from_axis_levels({
    Constants.AxisNames.OUTCOME_NONLINEARITY: Constants.AxisLevels.LOW,
    Constants.AxisNames.TREATMENT_NONLINEARITY: Constants.AxisLevels.LOW,
})

dgp_sampler = DataGeneratingProcessSampler(
    parameters=dgp_params, data_source=covar_data_source)

# Sample one concrete DGP from the parameterized sampler, then draw a dataset from it.
# NOTE(review): no random seed is set anywhere in this notebook, so the sampled DGP,
# the dataset, and the ATE values shown below presumably change between runs.
# TODO: seed via whatever mechanism cause_ml supports.
dgp = dgp_sampler.sample_dgp()
dataset = dgp.generate_data()

In [9]:
# ATE recorded on the generated dataset (presumably the DGP's true effect),
# used as the reference for the model's estimate in the next cell.
dataset.ATE

-1.013

In [10]:
# Fit a linear-regression causal model to the dataset and estimate the ATE,
# for comparison with `dataset.ATE` above.
# FIXME(review): LinearRegressionCausalModel is never imported in this notebook, so
# this cell raises NameError after Restart & Run All; the output shown presumably came
# from leftover kernel state. Add its import (module path not visible here) to the imports cell.
model = LinearRegressionCausalModel(dataset)
model.fit()
model.estimate(estimand=Constants.Model.ATE_ESTIMAND)

-1.0170460400173909

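## Benchmarking Demo

Repeat the estimate across a 3 × 3 grid of treatment/outcome nonlinearity levels (HIGH/MEDIUM/LOW) and tabulate the absolute mean bias and root mean squared error for each combination.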

In [16]:
%%time

HIGH, MEDIUM, LOW = Constants.AxisLevels.HIGH, Constants.AxisLevels.MEDIUM, Constants.AxisLevels.LOW
param_grid = dgp_params = {
    Constants.AxisNames.TREATMENT_NONLINEARITY: [HIGH, MEDIUM, LOW],
    Constants.AxisNames.OUTCOME_NONLINEARITY: [HIGH, MEDIUM, LOW]
}

covar_data_source = data_sources.load_random_normal_covariates(
    n_covars=10,
    n_observations=500)

result = run_benchmark(
    model_class=LinearRegressionCausalModel,
    estimand=Constants.Model.ATE_ESTIMAND,
    data_source=covar_data_source,
    param_grid=param_grid,
    num_dgp_samples=1,
    num_data_samples_per_dgp=1,
    enable_ray_multiprocessing=True)

CPU times: user 900 ms, sys: 520 ms, total: 1.42 s
Wall time: 21.1 s


In [18]:
# Tabulate the benchmark metrics: one row per nonlinearity-level combination.
pd.DataFrame(data=result)

Unnamed: 0,param_outcome_nonlinearity,param_treatment_nonlinearity,absolute mean bias,root mean squared error
0,HIGH,HIGH,0.13526,0.13526
1,HIGH,MEDIUM,0.168341,0.168341
2,HIGH,LOW,0.25274,0.25274
3,MEDIUM,HIGH,0.016795,0.016795
4,MEDIUM,MEDIUM,0.14666,0.14666
5,MEDIUM,LOW,0.000401,0.000401
6,LOW,HIGH,0.069457,0.069457
7,LOW,MEDIUM,0.043171,0.043171
8,LOW,LOW,0.031515,0.031515
