In [1]:
# Reload edited modules (e.g. the CauseML package) automatically before each cell executes.
%load_ext autoreload
%autoreload 2

In [30]:
from CauseML.parameters import build_parameters_from_specification, build_parameters_from_metric_levels
from CauseML.constants import Constants
from CauseML.data_generation import DataGeneratingProcessWrapper
import CauseML.data_sources as data_sources
from CauseML.utilities import extract_treat_and_control_data
from CauseML.data_metrics import calculate_data_metrics

In [29]:
from collections import defaultdict

import matplotlib.pyplot as plt
import numpy as np
from IPython.display import clear_output

## 1. Demonstration of Basic Dataset Generation Workflow

The code below demonstrates how the various pieces of CauseML benchmarking work together to generate a synthetic dataset.

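1. Covariate data is fetched from a covariate data source (imported above). The random-normal source is loaded first for illustration, but the rest of the demo uses the real-world CPP covariates.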


2. Parameters are generated by setting some metrics to given levels (with the rest left at default).


3. A data generating process is sampled according to the parameters and this DGP is used to generate data.


4. The observed and oracle data are extracted from the DGP wrapper. The observed data corresponds to the data a model would run on: the observed covariates, treatment assignment and outcome. The oracle data contains the potential outcomes and the transformed covariates which make up the treatment and outcome functions. Throughout this notebook, these transformed covariates are referred to as the 'True Covariates', because a simple linear model over these values would accurately model the treatment and response surfaces.

In [13]:
# Synthetic random-normal covariates (12 covariates). Kept under its own name so the
# CPP covariates loaded in the next cell (used for the rest of the demo) do not
# silently overwrite it.
random_covar_data = data_sources.load_random_normal_covariates(n_covars=12)

In [14]:
# Real-world covariates from the CPP dataset; this is the covariate data used below.
covar_data = data_sources.load_cpp()
covar_data.head(5)

Unnamed: 0,x_1,x_2,x_3,x_4,x_5,x_6,x_7,x_8,x_9,x_10,...,x_49,x_50,x_51,x_52,x_53,x_54,x_55,x_56,x_57,x_58
0,-0.117647,2.0,-0.966667,-0.6,0.052632,-0.87156,-1.0,-1.0,-0.6,-1.0,...,-1.0,-1.0,0.0,-1.0,-1.0,0.0,-1.0,-1.0,0.294118,0.241379
1,-0.235294,2.0,-1.0,-1.0,0.263158,-0.018349,-1.0,-1.0,-1.0,-1.0,...,-1.0,-1.0,0.0,-1.0,-1.0,0.0,-1.0,-1.0,0.411765,0.448276
2,-0.235294,2.0,-1.0,-1.0,0.052632,-0.715596,-1.0,-1.0,-1.0,-1.0,...,-0.75,-1.0,0.0,-1.0,-1.0,0.0,-1.0,-1.0,0.294118,0.310345
3,0.352941,2.0,-1.0,-1.0,0.315789,-0.055046,-1.0,-1.0,-1.0,-1.0,...,-1.0,-1.0,0.0,-1.0,-1.0,0.0,-1.0,-1.0,0.529412,0.310345
4,-0.411765,2.0,-0.333333,-0.2,0.210526,-0.46789,-1.0,-1.0,-1.0,-1.0,...,-0.5,-1.0,0.0,-1.0,-1.0,0.0,-1.0,-1.0,0.529412,0.517241


In [9]:
# Request a highly nonlinear treatment mechanism and low treatment-effect
# heterogeneity; every other metric keeps its default level.
demo_metric_levels = {
    Constants.MetricNames.TREATMENT_NONLINEARITY: Constants.MetricLevels.HIGH,
    Constants.MetricNames.TE_HETEROGENEITY: Constants.MetricLevels.LOW,
}
dgp_params = build_parameters_from_metric_levels(demo_metric_levels)

In [48]:
%%time

dgp_wrapper = DataGeneratingProcessWrapper(
    parameters=dgp_params, source_covariate_data=covar_data)

dgp_wrapper.sample_dgp()

_ = dgp_wrapper.generate_data()

obs = dgp_wrapper.get_observed_data()
oracle = dgp_wrapper.get_oracle_data()

Sampling observed individuals...
Sampling potential confounding variables...
Sampling covariate transforms...
Sampling treatment function...
Sampling outcome function...
Done
CPU times: user 16.9 s, sys: 50 ms, total: 16.9 s
Wall time: 16.8 s


In [41]:
# Observed data: treatment T, outcome Y and the observed covariates.
obs.head(5)

Unnamed: 0,T,Y,x_1,x_2,x_3,x_4,x_5,x_6,x_7,x_8,...,x_50,x_51,x_52,x_53,x_54,x_55,x_56,x_57,x_58,NOISE(Y)
0,1,-0.48207,-0.117647,2.0,-0.966667,-0.6,0.052632,-0.87156,-1.0,-1.0,...,-1.0,0.0,-1.0,-1.0,0.0,-1.0,-1.0,0.294118,0.241379,-0.076
1,0,-0.153521,-0.235294,2.0,-1.0,-1.0,0.263158,-0.018349,-1.0,-1.0,...,-1.0,0.0,-1.0,-1.0,0.0,-1.0,-1.0,0.411765,0.448276,0.089
2,1,0.978425,-0.235294,2.0,-1.0,-1.0,0.052632,-0.715596,-1.0,-1.0,...,-1.0,0.0,-1.0,-1.0,0.0,-1.0,-1.0,0.294118,0.310345,0.475
3,1,0.665105,0.352941,2.0,-1.0,-1.0,0.315789,-0.055046,-1.0,-1.0,...,-1.0,0.0,-1.0,-1.0,0.0,-1.0,-1.0,0.529412,0.310345,-0.013
4,1,-1.191833,-0.411765,2.0,-0.333333,-0.2,0.210526,-0.46789,-1.0,-1.0,...,-1.0,0.0,-1.0,-1.0,0.0,-1.0,-1.0,0.529412,0.517241,-0.045


In [42]:
# Oracle data: propensity score and its logit, potential outcomes, TE and transformed covariates.
oracle.head(5)

Unnamed: 0,logit(P(T|X)),P(T|X),Y0,Y1,TE,TRANSFORMED_X0,TRANSFORMED_X1,TRANSFORMED_X2,TRANSFORMED_X3,TRANSFORMED_X4,...,TRANSFORMED_X301,TRANSFORMED_X302,TRANSFORMED_X303,TRANSFORMED_X304,TRANSFORMED_X305,TRANSFORMED_X306,TRANSFORMED_X307,TRANSFORMED_X308,TRANSFORMED_X309,TRANSFORMED_X310
0,-0.465611358020533,0.385655500897233,-0.64307,-0.48207,0.161,0.000637,-0.003241,-0.061,0.0,0.0,...,0.070495,0.0,0.039608,0.0,-0.00525,-0.0,0.038679,0.004143,-0.056118,0.045732
1,0.227907376875991,0.55673149533612,-0.153521,0.007479,0.161,-0.000141,-0.003602,-0.061,0.0,0.0,...,0.026954,-0.0,-0.003772,0.0,0.001969,-0.0,0.014789,-0.006905,-0.187059,0.057879
2,0.771324073781636,0.683807247860239,0.817425,0.978425,0.161,-0.00592,0.002701,-0.061,0.0,0.0,...,0.057018,0.0,0.081101,0.0,0.001356,-0.0,0.031284,-0.012429,-0.0,0.076299
3,0.209724982296737,0.5522399066514,0.504105,0.665105,0.161,0.001634,0.000324,-0.061,0.0,0.0,...,0.060128,0.0,0.050924,0.0,-0.007656,-0.0,0.032991,0.009667,0.364765,0.049622
4,-0.33460631090466,0.417120260677784,-1.352833,-1.191833,0.161,-0.027055,-0.002701,-0.061,0.0,0.193889,...,0.089156,0.0,0.033949,0.0,0.001181,-0.0,0.048917,-0.012429,0.168353,0.057879


## 2. Metrics and Measures

The code below displays the various metrics which specify the distributional challenges in the data and the measures used to quantify them.

Note that most of the measures below are what are referred to as oracle measures. They could not be calculated on real observational data, because they rely on access to the counterfactual outcomes and the treatment assignment mechanism.

#### 1. Outcome Nonlinearity

This is the degree of non-linearity in the outcome mechanism. It is measured as the linear fit $R^2$ between the observed covariates and the observed outcome. The higher the $R^2$, the more linear the outcome surface must be. This is not an oracle measure.

#### 2. Treatment Nonlinearity

This is the degree of non-linearity in the treatment mechanism. It is measured as the linear fit $R^2$ between the observed covariates and the logit of the propensity score. The higher the $R^2$, the more linear the treatment surface must be. This is an oracle measure, as the propensity score and its logit are unknown in normal circumstances.

#### 3. Percent Treated

This is the percent of the dataset in the treated condition. It is a simple non-oracle measure.

#### 4. Balance

Three measures are used to quantify balance.

The simplest is the distance between the means of the true covariates in the two groups (see above for a definition of the true covariates). If the two distributions are identical, their means coincide (the converse does not hold: equal means do not imply identical distributions). The problem with this measure is that it is sensitive to outliers and to dimensions with higher variance, which can pull the means far apart even if most dimensions have very similar values. Because the true covariates are only available in the oracle data, this is an oracle metric.

The second simplest measure is the error in the naive estimate of the average treatment effect (ATE). In balanced data, the naive ATE estimate should be accurate; in imbalanced data it will not be. This is an oracle metric, as it relies on access to the true ATE.

Finally, the Wasserstein distance is used. This will be discussed more in future work, but it is an integral probability metric (in one dimension, it equals the area between the two distributions' CDFs), which makes it a principled measure of the distance between two distributions.

#### 5. Alignment

The alignment measures the correlation ($R^2$) between the outcome and the propensity logit. The higher this value, the greater the degree of confounding, as the covariates which affect the outcome also affect the treatment assignment more strongly. This is an oracle metric, as it relies on the propensity logit.

#### 6. Treatment Effect Heterogeneity

This measures the degree to which the treatment effect interacts with covariates (i.e. how much it changes from individual to individual). The standard deviation of the treatment effect is normalized by the standard deviation of the outcome. Because the treatment effect is the difference between the potential outcomes ($TE = Y_1 - Y_0$), its variation reflects only the interaction terms rather than the base response function, and dividing by $\mathrm{std}(Y)$ expresses that variation relative to the overall spread of the outcome.


In [62]:
# Measures used to quantify each metric. These strings are the measure names
# requested from calculate_data_metrics through the observation spec.
metrics_and_measures = {
    Constants.MetricNames.OUTCOME_NONLINEARITY: ["Lin r2(X_obs, Y)"],
    Constants.MetricNames.TREATMENT_NONLINEARITY: ["Lin r2(X_obs, Treat Logit)"],
    Constants.MetricNames.PERCENT_TREATED: ["Percent(T==1)"],
    Constants.MetricNames.BALANCE: [
        "Mean dist X_true: T=1<->T=0",
        "Wass dist X_true: T=1<->T=0",
        "Naive TE",
    ],
    Constants.MetricNames.ALIGNMENT: [
        "Lin r2(Y, Treat Logit)",
        "Lin r2(Y0, Treat Logit)",
    ],
    Constants.MetricNames.TE_HETEROGENEITY: ["std(TE)/std(Y)"],
}

# Levels to sweep, in plotting order along the x-axis: low -> medium -> high.
levels = [getattr(Constants.MetricLevels, name) for name in ("LOW", "MEDIUM", "HIGH")]

## 3. Parameter Validation

The code below generates 25 datasets (`n_trials=25`) from the CPP covariate data for each possible setting (low, medium, high) of each metric, and plots the distribution of the measures of each metric at each level. Each metric setting corresponds to different underlying parameters, so this serves as a validation that the parameter settings do indeed produce the desired effect on the relevant metric.

Each of the plots shows a clear correspondence between the metric/parameter setting and the metric measures, which track the realized level of the metric in the simulated data. The exact values of the metrics also agree with the medians and quartiles stipulated in Dorie et al. (2019).

There are two metrics worth exploring in a little more detail. Firstly, note that only the middle metric measure in Balance (the Wasserstein distance) appears to accurately capture the increasing imbalance. Measuring balance is, in general, a hard problem. The Wasserstein distance between the true covariates in the treatment and control groups is a promising measure based on the summed absolute error between the distribution CDFs. I plan to spend some time discussing this distance measure in the final draft.

Also note that the measure of treatment effect heterogeneity has a very wide distribution. This is because the distribution used to generate the treatment effect itself has thick tails. This means that there is large variance in the treatment effect size and hence large variance in the impact of the interactions with covariates. I plan to refine this soon.

In [None]:
# NOTE(review): analyze_metric_measures_across_levels (and the helpers it calls) is
# defined in cells further down; run/move those definition cells above this one so
# that Restart & Run All does not fail with a NameError.
results = analyze_metric_measures_across_levels(
    metrics_and_measures,
    covar_generator=data_sources.load_cpp,
    n_trials=25,
)


Running for OUTCOME_NONLINEARITY. Level:  LOW 

In [None]:
# Three panel columns: the largest number of measures any metric has (Balance).
plot_metric_measure_analysis(results=results, max_measure_count=3)

In [51]:
def analyze_metric_measures_across_levels(
    metrics_and_measures, covar_generator,
    n_trials=20, metric_levels=None):
    '''
    Collect measure values for every metric at every metric level.

    For each metric, DGP parameters are built with only that metric set to
    the current level (all other metrics left at their defaults), and
    gather_metric_measures_for_given_params samples n_trials datasets and
    computes the requested measures on each.

    Args:
        metrics_and_measures: dict mapping a metric name to the list of
            measure names to compute for that metric.
        covar_generator: zero-argument callable returning source covariate
            data; it is called once per trial.
        n_trials: number of datasets sampled per (metric, level) pair.
        metric_levels: iterable of levels to sweep. Defaults to the
            notebook-level ``levels`` list, so existing calls are unchanged.

    Returns:
        Nested dict: results[metric][measure][level] -> list of values,
        one per trial.
    '''
    if metric_levels is None:
        # Backward-compatible fallback to the notebook-global level list.
        metric_levels = levels

    results = defaultdict(lambda: defaultdict(dict))

    for metric, measures in metrics_and_measures.items():
        print(f"\nRunning for {metric}. Level: ", end=" ")

        # Only compute the measures that belong to the metric being varied.
        observation_spec = {metric: measures}

        for level in metric_levels:
            print(level, end=" ")
            dgp_params = build_parameters_from_metric_levels({metric: level})

            res = gather_metric_measures_for_given_params(
                dgp_params, observation_spec, covar_generator,
                n_trials=n_trials)

            for measure, values in res[metric].items():
                results[metric][measure][level] = values

    return results

In [52]:
def gather_metric_measures_for_given_params(
        dgp_params, observation_spec,
        covar_generator,
        n_trials=10, verbose=False):
    '''
    Sample n_trials datasets from DGPs built with dgp_params and collect the
    metric measures listed in observation_spec for each dataset.

    Args:
        dgp_params: parameters passed to DataGeneratingProcessWrapper.
        observation_spec: dict mapping metric name -> list of measure names;
            passed to calculate_data_metrics and used to pick out results.
        covar_generator: zero-argument callable returning the source
            covariate data; it is called afresh for every trial.
        n_trials: number of datasets (and sampled DGPs) to generate.
        verbose: if True, show a running trial counter and, at the end,
            print the min/mean/max of every collected measure.

    Returns:
        Nested dict: results[metric][measure] -> list with one value per trial.
    '''
    results = defaultdict(lambda: defaultdict(list))
    for trial in range(n_trials):
        if verbose:
            clear_output()
            print("Trials run:", trial + 1)

        # Fresh covariates for every trial.
        covar_data = covar_generator()

        # Sample a DGP and generate one dataset from it.
        dgp_wrapper = DataGeneratingProcessWrapper(
            parameters=dgp_params, source_covariate_data=covar_data)
        dgp_wrapper.sample_dgp()
        (observed_covariate_data, observed_outcome_data,
         oracle_covariate_data, oracle_outcome_data) = dgp_wrapper.generate_data()

        metrics = calculate_data_metrics(
            observed_covariate_data, observed_outcome_data,
            oracle_covariate_data, oracle_outcome_data,
            observation_spec=observation_spec)

        for metric, measures in observation_spec.items():
            for measure in measures:
                results[metric][measure].append(metrics[metric][measure])

    if verbose:
        for metric, measures in results.items():
            for measure, result_data in measures.items():
                print(f"{metric} {measure}:")
                print("min", round(np.min(result_data), 3), end=" ")
                print("mean:", round(np.mean(result_data), 3), end=" ")
                print("max", round(np.max(result_data), 3))
                print("-------------\n\n")

    return results

In [None]:
def plot_metric_measure_analysis(results, max_measure_count=1, metric_levels=None):
    '''
    Plot the output of analyze_metric_measures_across_levels.

    One figure is drawn per metric with one panel per measure. Inside a panel,
    level number k (1-based) is drawn at x = k showing the median (colored
    point), the interquartile range (error bar from the 25th to the 75th
    percentile) and the mean (red point).

    Args:
        results: nested dict results[metric][measure][level] -> values.
        max_measure_count: number of panel columns per figure; widened
            automatically when a metric has more measures than this.
        metric_levels: levels to plot, in x-axis order. Defaults to the
            notebook-level ``levels`` list, so existing calls are unchanged.
    '''
    if metric_levels is None:
        # Backward-compatible fallback to the notebook-global level list.
        metric_levels = levels

    level_colors = ["g", "b", "y"]  # low, medium, high
    mean_color = "r"  # means are drawn in red

    for metric, measures in results.items():
        # Never request fewer columns than there are measures to draw.
        n_cols = max(max_measure_count, len(measures), 1)
        fig, axes = plt.subplots(1, n_cols, figsize=(10, 4), squeeze=False)
        fig.suptitle(f"{metric}")

        for ax, (measure_name, measure_values) in zip(axes[0], measures.items()):
            ax.set_title(f"{measure_name}")
            # Levels sit at x = 1..len(levels); pad by one unit on each side.
            # (Previously this used max_measure_count, which cut off levels
            # whenever fewer than 3 columns were requested.)
            ax.set_xlim((0, len(metric_levels) + 1))

            for level_num, level in enumerate(metric_levels):
                level_values = measure_values[level]

                # NOTE: `interpolation` is deprecated in NumPy >= 1.22 in favour
                # of `method`; kept here for compatibility with older NumPy.
                q1, median, q3 = np.percentile(
                    level_values, [25, 50, 75], interpolation='midpoint')

                x = level_num + 1
                err = np.array([[median - q1], [q3 - median]])
                color = level_colors[level_num % len(level_colors)]

                ax.scatter(x, median, label=level, color=color)
                ax.scatter(x, np.mean(level_values), color=mean_color)
                ax.errorbar(x, median, err, color=color)

            ax.legend()

        # Hide any panel columns that have no measure assigned.
        for ax in axes[0][len(measures):]:
            ax.set_visible(False)

        # Lay out after drawing; leave headroom for the suptitle.
        fig.tight_layout(rect=(0, 0, 1, 0.95))
        plt.show()