$$
\newcommand{\mat}[1]{\boldsymbol {#1}}
\newcommand{\mattr}[1]{\boldsymbol {#1}^\top}
\newcommand{\matinv}[1]{\boldsymbol {#1}^{-1}}
\renewcommand{\vec}[1]{\boldsymbol {#1}}
\newcommand{\vectr}[1]{\boldsymbol {#1}^\top}
\newcommand{\rvar}[1]{\mathrm {#1}}
\newcommand{\rvec}[1]{\boldsymbol{\mathrm{#1}}}
\newcommand{\diag}{\mathop{\mathrm {diag}}}
\newcommand{\set}[1]{\mathbb {#1}}
\newcommand{\cset}[1]{\mathcal {#1}}
\newcommand{\norm}[1]{\left\lVert#1\right\rVert}
\newcommand{\abs}[1]{\left\lvert#1\right\rvert}
\newcommand{\pderiv}[2]{\frac{\partial #1}{\partial #2}}
\newcommand{\bb}[1]{\boldsymbol{#1}}
\newcommand{\Tr}[0]{^\top}
\newcommand{\grad}[0]{\nabla}
\newcommand{\E}[2][]{\mathbb{E}_{#1}\left[#2\right]}
\newcommand{\Var}[1]{\mathrm{Var}\left[#1\right]}
\newcommand{\ip}[3]{\left<#1,#2\right>_{#3}}
\newcommand{\given}[0]{\middle\vert}
\newcommand{\DKL}[2]{\cset{D}_{\text{KL}}\left(#1\,\Vert\, #2\right)}
\DeclareMathOperator*{\argmax}{arg\,max}
\DeclareMathOperator*{\argmin}{arg\,min}
\DeclareMathOperator*{\trace}{trace}
\newcommand{\1}[1]{\mathbb{I}\left\{#1\right\}}
\newcommand{\setof}[1]{\left\{#1\right\}}
\newcommand{\DO}[1]{\mathrm{do}\left(#1\right)}
\newcommand{\indep}{\perp \!\!\! \perp}
$$


# <center>Causal Inference 097400, Winter 2019-20<br><br>Final Project</center>

#### <center>Aviv Rosenberg<br>`avivr@cs`</center>

##### <center>April, 2020<br></center>


In [1]:
%load_ext autoreload
%autoreload 2

import os
import numpy as np
import pandas as pd

PHYSIONET_DB = 'data/physionet/crisdb'
MHRV_DATA_FILE = 'data/crisdb-full-60min.xlsx'
OUT_DIR = 'out/'

os.makedirs(OUT_DIR, exist_ok=True)

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [2]:
import matplotlib.pyplot as plt

# Notebook-wide figure styling: use the Set1 colormap's colors as the
# default color cycle and a 12pt base font.
plt.rcParams.update({
    'axes.prop_cycle': plt.cycler(color=plt.cm.Set1.colors),
    'font.size': 12,
})

## Part 1: Creating the datasets

In [3]:
from proj import data
from proj import ci

Load metadata from the PhysioNet CASTRR database files.

In [4]:
# Per-record metadata read from the PhysioNet database directory; the preview
# below shows AGE and SEX columns indexed by record name (`rec`).
df_meta = data.castrr_load_metadata(PHYSIONET_DB)
df_meta.head()

Unnamed: 0_level_0,AGE,SEX
rec,Unnamed: 1_level_1,Unnamed: 2_level_1
e001a,60,Male
e001b,60,Male
e002a,65,Male
e002b,65,Male
e003a,55,Male


Load the HRV features calculated on this database with `mhrv`.
Then, join the HRV features with the metadata.

In [6]:
# Sheet names in the mhrv spreadsheet: one CONTROL and one TREATED sheet for
# each of the three drug codes (E, F, M), in that order.
MHRV_GROUP_NAMES = [
    f'{drug_code}_{group}'
    for drug_code in ('E', 'F', 'M')
    for group in ('CONTROL', 'TREATED')
]

# Mapping of sheet name -> dataframe; the metadata is handed to the loader
# so it can be joined with the HRV features.
dfs = data.load_mhrv_xls(
    MHRV_DATA_FILE,
    sheet_names=MHRV_GROUP_NAMES,
    df_meta=df_meta,
)

# Sanity check: no loaded sheet may contain missing values.
# The sheet name is used as the assertion message to identify the offender.
for name, df in dfs.items():
    assert not pd.isna(df).to_numpy().any(), name

Loaded E_CONTROL: 5937 samples, 49 features
Loaded E_TREATED: 6018 samples, 49 features
Loaded F_CONTROL: 4617 samples, 49 features
Loaded F_TREATED: 4760 samples, 49 features
Loaded M_CONTROL: 6137 samples, 49 features
Loaded M_TREATED: 6337 samples, 49 features


In [7]:
# Inspect one of the loaded groups: HRV features indexed by record (`rec`)
# and window (`win`), with the AGE and SEX metadata columns appended.
dfs['E_CONTROL']

Unnamed: 0_level_0,Unnamed: 1_level_0,RR,NN,AVNN,SDNN,RMSSD,pNN50,SEM,BETA_AR,HF_NORM_AR,HF_PEAK_AR,...,MSE17,MSE18,MSE19,MSE20,PIP,IALS,PSS,PAS,AGE,SEX
rec,win,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1
e001a,1,4327,4264,701.071106,41.581642,13.757771,0.234577,0.636785,-1.117969,7.328302,0.327401,...,1.348242,1.395579,1.477925,1.502828,30.816135,0.308468,19.723265,6.191370,60,Male
e001a,2,4477,4404,750.743286,49.516403,15.641404,0.408812,0.746149,-1.128235,5.956380,0.254147,...,1.307543,1.354694,1.323236,1.371318,34.196186,0.342267,25.272480,8.969119,60,Male
e001a,3,4722,4650,722.733521,59.652050,14.098043,0.150570,0.874780,-1.095494,6.322984,0.252652,...,1.222841,1.239768,1.133098,1.205776,30.537634,0.305657,20.021505,7.118279,60,Male
e001a,4,4769,4696,703.978455,53.346508,13.653588,0.255591,0.778470,-1.121374,4.156405,0.255642,...,1.251534,1.125110,1.228755,1.128171,31.856899,0.318850,21.869677,6.984668,60,Male
e001a,5,4986,4922,687.711121,47.553204,12.772495,0.386100,0.677812,-1.179842,4.143440,0.240692,...,1.444636,1.419865,1.398180,1.431729,28.728159,0.287543,18.468102,5.099553,60,Male
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
e286a,16,4840,4819,739.530396,29.074734,10.354865,0.020755,0.418830,-0.414241,5.571465,0.346089,...,1.133445,1.180290,1.206903,1.167605,38.659473,0.386883,29.196928,15.646399,65,Female
e286a,17,5289,5213,666.909973,41.037853,10.034287,0.095932,0.568383,-0.682378,7.548707,0.325159,...,0.673941,0.681831,0.705784,0.639730,40.053711,0.400806,29.100327,15.288701,65,Female
e286a,18,5374,5355,667.397583,23.006018,11.370922,0.261487,0.314385,-1.036899,12.584630,0.319926,...,0.906764,0.942591,0.936589,1.025636,42.689075,0.427157,33.053223,18.169935,65,Female
e286a,19,4924,4856,719.261597,22.993504,10.629254,0.082389,0.329963,-0.436453,9.975235,0.269097,...,1.371525,1.503135,1.390089,1.548041,38.632618,0.386612,28.644976,14.229818,65,Female


Now we'll add the outcome columns. We'll use the non-linear HRV features as the outcomes, which measure the type of dynamics found in the heart beat intervals.

Note that for the treated group, the outcomes must come from the post-treatment data.
The covariates of the treated group, however, are taken from the pre-treatment HRV features.

In [8]:
# Build one causal-inference dataset per drug from its control/treated sheets,
# validate it, write it to OUT_DIR as CSV and display it.
#
# All datasets are kept in `dfs_ci` (keyed by drug code). `df_ci` stays bound
# to the dataset of the last drug ('M') after the loop, which is what the
# following cells operate on.
dfs_ci = {}
for drug in ['E', 'F', 'M']:
    df_ci = data.castrr_ci_dataset(
        dfs[f'{drug}_CONTROL'], dfs[f'{drug}_TREATED'],
        ignore_features=['SampEn'],
        include_counterfactuals=True,
        random_seed=42  # fixed seed so the generated dataset is reproducible
    )

    # Validate BEFORE persisting, so a dataset containing missing values is
    # never written to disk; the drug code identifies the failing dataset.
    assert not np.any(pd.isna(df_ci)), drug

    # os.path.join avoids the doubled separator that f'{OUT_DIR}/...' produced
    # (OUT_DIR already ends with a slash); the resulting file is the same.
    df_ci.to_csv(os.path.join(OUT_DIR, f'df_ci_{drug}.csv'))

    dfs_ci[drug] = df_ci
    display(df_ci)



Unnamed: 0_level_0,Unnamed: 1_level_0,X_RR,X_NN,X_AVNN,X_SDNN,X_RMSSD,X_pNN50,X_SEM,X_BETA,X_HF_NORM,X_HF_PEAK,...,X_IALS,X_PSS,X_PAS,X_AGE,X_SEX,Y_MSE,Y_DFA,Y_MSE_CF,Y_DFA_CF,T
rec,win,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1
e003,1,2302,1922,785.615540,108.269882,62.563004,19.260801,2.469625,-1.643589,14.567132,0.327401,...,0.484123,44.849117,22.320499,55,Male,0.810852,0.912992,0.919241,1.290615,0
e003,2,3576,2998,839.575317,85.939987,48.302204,14.848182,1.569566,-0.898494,19.157187,0.393928,...,0.492492,47.965309,16.210808,55,Male,1.101020,1.001630,1.206214,1.218996,0
e003,3,4227,4211,847.497620,77.423080,23.275806,3.752969,1.193103,-1.202459,3.817187,0.288532,...,0.423040,37.805748,12.134885,55,Male,1.333161,1.254841,0.966393,1.351943,0
e003,4,4135,4135,870.423950,68.379890,25.949564,4.741171,1.063385,-1.312657,6.859756,0.287784,...,0.537494,51.463120,26.457073,55,Male,1.104663,1.120857,1.440737,1.198067,0
e003,5,4162,4160,864.023071,65.426506,25.984091,4.664583,1.014395,-1.078064,7.304376,0.288532,...,0.543159,54.591347,23.221153,55,Male,1.201327,1.165819,1.403497,1.338633,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
e286,16,4840,4819,739.530396,29.074734,10.354865,0.020755,0.418830,-0.414241,5.571465,0.346089,...,0.386883,29.196928,15.646399,65,Female,0.583669,1.258133,1.212141,1.173110,1
e286,17,5289,5213,666.909973,41.037853,10.034287,0.095932,0.568383,-0.682378,7.548707,0.325159,...,0.400806,29.100327,15.288701,65,Female,1.057332,1.180567,0.533411,1.183375,1
e286,18,5374,5355,667.397583,23.006018,11.370922,0.261487,0.314385,-1.036899,12.584630,0.319926,...,0.427157,33.053223,18.169935,65,Female,1.464731,1.267163,0.876491,1.039011,1
e286,19,4924,4856,719.261597,22.993504,10.629254,0.082389,0.329963,-0.436453,9.975235,0.269097,...,0.386612,28.644976,14.229818,65,Female,1.020859,1.324425,1.295721,1.098521,1




Unnamed: 0_level_0,Unnamed: 1_level_0,X_RR,X_NN,X_AVNN,X_SDNN,X_RMSSD,X_pNN50,X_SEM,X_BETA,X_HF_NORM,X_HF_PEAK,...,X_IALS,X_PSS,X_PAS,X_AGE,X_SEX,Y_MSE,Y_DFA,Y_MSE_CF,Y_DFA_CF,T
rec,win,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1
f003,1,2136,1799,623.271606,41.754089,14.734976,0.166852,0.984427,-1.296954,7.668265,0.373746,...,0.429366,35.519733,16.453585,55,Male,1.092734,1.166016,0.961286,1.190175,0
f003,2,2654,1947,620.216980,34.527798,15.583796,0.411100,0.782503,-1.303576,8.523433,0.338614,...,0.458890,38.674885,18.798151,55,Male,1.360651,1.192065,1.146658,1.178145,0
f003,3,3134,2595,678.037720,93.592285,14.180335,0.346955,1.837263,-1.256573,5.994501,0.373746,...,0.401311,32.793835,13.256262,55,Male,0.583588,1.164349,1.060824,1.180857,0
f003,4,3200,2842,864.655701,45.613350,19.269770,1.337557,0.855618,-0.993488,6.390196,0.366271,...,0.442802,39.232933,13.194933,55,Male,1.367110,1.195974,1.192654,1.203273,0
f003,5,2825,2442,844.357849,45.739792,19.037077,1.556739,0.925596,-0.943524,7.088824,0.371503,...,0.424826,36.773136,10.851761,55,Male,1.449908,1.168133,1.069999,1.174503,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
f229,17,3936,3926,906.958435,136.147064,21.316225,2.878981,2.172867,-1.266117,3.690003,0.306472,...,0.449936,41.645439,11.156393,55,Female,0.820468,1.150365,0.544653,1.284006,1
f229,18,5012,4980,703.677185,56.472477,11.324690,0.080337,0.800243,-1.642962,1.985371,0.296007,...,0.398875,30.682732,11.044177,55,Female,1.129551,0.914447,0.873783,1.322412,1
f229,19,4719,4668,747.099609,52.903912,12.687689,0.107135,0.774323,-1.333049,2.373204,0.296007,...,0.403685,32.090832,11.396744,55,Female,0.411574,1.113703,1.038265,1.337645,1
f229,20,4533,4493,778.791321,52.577774,13.328798,0.200356,0.784393,-1.505871,2.810117,0.378978,...,0.368210,29.979969,7.188961,55,Female,0.777500,1.200346,1.070522,1.331406,1




Unnamed: 0_level_0,Unnamed: 1_level_0,X_RR,X_NN,X_AVNN,X_SDNN,X_RMSSD,X_pNN50,X_SEM,X_BETA,X_HF_NORM,X_HF_PEAK,...,X_IALS,X_PSS,X_PAS,X_AGE,X_SEX,Y_MSE,Y_DFA,Y_MSE_CF,Y_DFA_CF,T
rec,win,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1
m004,1,3888,3764,830.254944,54.718555,37.990566,6.404465,0.891887,-1.613317,24.273483,0.230227,...,0.500664,47.210415,20.536663,75,Male,0.868981,0.819913,0.691113,1.176074,0
m004,2,4067,3920,854.955872,49.046280,37.899147,5.792294,0.783363,-1.423211,28.345371,0.372251,...,0.522582,48.010204,24.974489,75,Male,0.773713,0.782073,0.671109,1.086064,0
m004,3,3950,3775,871.750000,41.066044,40.667088,8.002120,0.668381,-0.965243,38.653778,0.322916,...,0.519873,48.238411,23.735100,75,Male,1.028134,0.769491,1.263955,1.093659,0
m004,4,3949,3873,891.648926,46.359188,34.758556,6.224174,0.744924,-1.512840,29.633215,0.347584,...,0.522986,50.064548,20.423445,75,Male,0.773600,0.827209,1.199861,1.039614,0
m004,5,3889,3832,911.322388,47.382942,35.399673,6.395197,0.765437,-1.286781,27.544939,0.219015,...,0.513443,46.816284,21.894571,75,Male,0.971787,0.870471,1.250625,1.115467,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
m294,19,3586,3497,977.574524,37.354572,16.642721,0.943936,0.631678,-0.577826,12.399299,0.370008,...,0.496568,42.179012,26.794395,70,Male,0.880734,0.821657,0.898034,1.126386,1
m294,20,3828,3724,914.404480,29.345848,19.031404,0.859522,0.480886,-0.753966,12.013743,0.364028,...,0.576148,51.906551,36.922665,70,Male,0.675288,0.839720,1.381127,1.129456,1
m294,21,3711,3601,940.956055,28.774065,16.857832,0.194444,0.479501,-0.599465,17.896460,0.361786,...,0.569444,48.708691,37.683975,70,Male,0.673560,1.093003,0.892264,1.023486,1
m294,22,4107,3991,845.335571,79.906998,22.974422,4.837093,1.264864,-1.143045,12.646611,0.231722,...,0.565664,49.787022,34.227013,70,Male,0.647124,0.921717,0.393721,1.070972,1


In [9]:
# `df_ci` is the value left over from the last iteration of the loop above,
# i.e. the dataset for drug 'M'. Split it into covariates X, outcomes y and
# treatment assignment t, then show the shape of each array.
X, y, t = data.split_dataset(df_ci, scale_covariates=True)
tuple(arr.shape for arr in (X, y, t))

((5621, 26), (5621, 4), (5621,))

In [10]:
# First sample: its covariate vector, outcome vector and treatment indicator
(X[0], y[0], t[0])

(array([-0.674552  , -0.5250128 ,  0.40915948,  0.08429667,  1.3321184 ,
         0.2775112 ,  0.0981509 , -1.3258218 ,  0.91730934, -0.92903256,
         0.5894641 , -1.1520776 , -0.45506716, -0.5237314 , -0.98965263,
        -0.24726298, -0.62948173, -0.3337854 ,  1.3320574 ,  0.01366417,
         0.92591834,  0.9259942 ,  0.8893802 ,  0.7158837 ,  1.6381848 ,
         0.45383295], dtype=float32),
 array([0.86898124, 0.8199128 , 0.6911126 , 1.1760738 ], dtype=float32),
 0)

## Part 2: Exploring the data

In [None]:
# NOTE(review): `dfs` as loaded in Part 1 is keyed by MHRV_GROUP_NAMES
# ('E_CONTROL', 'E_TREATED', ...), which does not include 'NSR' or 'AGING', so
# this cell raises KeyError on a fresh Restart & Run All unless `dfs` is
# reloaded with those sheets first. This cell also uses `data` (imported from
# `proj`) while the cells below use `ciproj.data` - TODO confirm which module
# and which sheets are intended.
df_full = data.create_ci_dataset(df_control=dfs['NSR'], df_treated=dfs['AGING'])

We wish to assess the causal effect of age (the treatment variable) on several different possible outcome variables which are known to correspond to heart-rate dynamics:

- The normalized power in the very-low frequency band.
- The DFA slopes, $\alpha_1$ and $\alpha_2$.
- The multiscale entropy.

For simplicity we make the treatment binary:
- $T=0$: Our control group consists of **healthy** individuals aged 22-45.
- $T=1$: Our treated group consists of **healthy** individuals aged 60 and over.

Note that we only use data from healthy individuals: no known underlying cardiopathologies and no arrhythmias in the processed ECG recordings. This is to ensure that we measure the effect of age, and not of pathologies, on the changes in HRV.

Let's plot the distribution of the data in the outcome variables of interest, conditioned on the treatment (young/old).

In [None]:
import ciproj.plot

# Plot settings shared with the violin-plot cells that follow:
# group the data by the treatment column, label the two groups in the legend,
# and show only the mean marker on each violin.
group_by = {'by': 'T'}
group_legend_names = {0: 'control', 1: 'treated'}
violin_args = {'showextrema': False, 'showmeans': True, 'widths': 0.5}

# Distribution of the two DFA slopes, conditioned on the treatment
fig, ax = plt.subplots(figsize=(15, 10))
ciproj.plot.df_group_violins(
    ax, df_full, ['alpha1', 'alpha2'],
    group_by, violin_args, group_legend_names,
)

- The multiscale entropy:

In [None]:
# Column names MSE1 ... MSE20 (reused below when building the `mse` dataset)
mse_vars = ['MSE%d' % scale for scale in range(1, 21)]

# Distribution of each multiscale-entropy variable, conditioned on the treatment
fig, ax = plt.subplots(figsize=(20, 10))
ciproj.plot.df_group_violins(
    ax, df_full, mse_vars,
    group_by, violin_args, group_legend_names,
)

- The power-spectral density within three different frequency bands:

In [None]:
# Distribution of the normalized power in the VLF, LF and HF bands,
# conditioned on the treatment
fig, ax = plt.subplots(figsize=(20, 10))
ciproj.plot.df_group_violins(
    ax, df_full, ['VLF_NORM', 'LF_NORM', 'HF_NORM'],
    group_by, violin_args, group_legend_names,
)

We'll create three datasets.
- The `vlf` dataset: Outcome is normalized VLF power. Regular VLF power will be removed from the covariates.
- The `dfa` dataset: Outcomes are `alpha1` and `alpha2` which are the DFA slopes.
- The `mse` dataset: Outcomes are `MSE1-20`, the multiscale entropy variables.

In [None]:
# One marked dataset per outcome family, all derived from `df_full` with 'T'
# as the treatment column. Insertion order (vlf, dfa, mse) matters: later
# cells enumerate this dict and index per-dataset settings by position.
datasets = {
    # VLF_POWER is passed as `ignore` (the un-normalized counterpart of the outcome)
    'vlf': ciproj.data.mark_dataset(
        df_full, outcomes=['VLF_NORM'], treatment='T', ignore=['VLF_POWER'],
    ),
    'dfa': ciproj.data.mark_dataset(
        df_full, outcomes=['alpha1', 'alpha2'], treatment='T',
    ),
    'mse': ciproj.data.mark_dataset(
        df_full, outcomes=mse_vars, treatment='T',
    ),
}

In [None]:
# Print, for every dataset, the shapes of its arrays and the names of its
# covariate (X_*) and outcome (Y_*) columns.
for name, df in datasets.items():
    print(f'*** {name} dataset: ', end='')
    # Bind the third return value to `t`. The previous `X, y, y = ...` bound it
    # to `y` instead, so the shapes printed below came from the wrong arrays
    # (`y` held the treatment vector and `t` was a stale value from an earlier
    # cell, or undefined on a fresh kernel).
    X, y, t = ciproj.data.split_dataset(df, scale_covariates=True)
    print(f'X{X.shape}, y{y.shape}, t{t.shape}')
    print(f'covariates: {[x for x in df.columns if x.startswith("X_")]}')
    print(f'outcomes: {[x for x in df.columns if x.startswith("Y_")]}')
    print()

## Part 3: Propensity estimation and common support

The *propensity score* is defined as $e(\vec{x}):=\Pr{\left(\rvar{T}=1\given \rvec{X}=\vec{x}\right)}$.
In other words, it is the probability that a treatment $\rvar{T}=1$ will be assigned to a
unit with covariates $\rvec{X}=\vec{x}$.
In this case, since our treatment variable is age, it is not really "assigned" based on covariates.
However, the propensity estimation can shed light on whether covariates have predictive power regarding age, and help us achieve balanced covariate marginal distributions for matching.

We would like to obtain an estimator for the propensity score, $\hat e(\vec{x})$, from the data.
One important aspect of this estimator is that we would like it to be *calibrated*.
Briefly, this means that if, for example, we look at all units $\vec{x}$ such that $\hat e(\vec{x})=0.8$, we expect 80\% of them to actually belong to the treatment group ($\rvar{T}=1$).

In [None]:
from sklearn.linear_model import LogisticRegressionCV, LogisticRegression
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.naive_bayes import GaussianNB
from scipy import stats

from ci.cv import CVConfig, LogSpaceSampler

# Propensity-estimation methods, each paired with the hyper-parameter search
# space used for cross-validation. Keys are the method names used to label
# plots and the columns of the propensity dataframe.
rcv_config = dict(
    logistic=CVConfig(
        # liblinear supports both the l1 and the l2 penalty searched below
        model=LogisticRegression(dual=False, solver='liblinear'),
        params={
            'C': LogSpaceSampler(-5, 2),
            'penalty': ['l1', 'l2'],
        },
    ),
    gbm=CVConfig(
        model=GradientBoostingClassifier(),
        params={
            # scipy's randint excludes the upper bound, hence the "+ 1"
            'n_estimators': stats.randint(10, 100 + 1),
            'max_depth': stats.randint(1, 3 + 1),
            'learning_rate': LogSpaceSampler(-2, 0.5),
        },
    ),
)

In [None]:
from ci.propensity import fit_propensity_cv

# Create a dataframe to store propensities from each dataset and each method.
# Columns are (dataset, method) pairs.
midx = pd.MultiIndex.from_product([datasets.keys(), rcv_config.keys()], names=['dataset', 'method'])
df_propensity = pd.DataFrame(columns=midx)

for method, cv_cfg in rcv_config.items():
    # One figure per method with one calibration plot per dataset.
    # (Fixed: this previously referenced an undefined name `dataframes`.)
    fig, ax = plt.subplots(nrows=1, ncols=len(datasets), figsize=(16, 6))

    for i, (dataset_name, df) in enumerate(datasets.items()):
        # Get covariates X and treatment assignment t as numpy arrays
        X, _, t = ciproj.data.split_dataset(df, scale_covariates=True)

        # Train a model to estimate propensity using current method,
        # calibrate with two approaches and generate a calibration plot.
        # The random state is offset by the dataset index, so every dataset
        # gets a different but reproducible seed.
        model, best_params = fit_propensity_cv(
            cv_cfg, X, t, plot_args=dict(name=method, ax=ax[i]),
            test_size=0.3, n_iter=42, random_state=42+i, cv_splits=4
        )

        # Use the best calibrated model (lowest Brier score) to estimate the propensities;
        # column 1 of predict_proba is the probability of the positive class (T=1).
        prop = model.predict_proba(X)[:, 1]
        df_propensity[(dataset_name, method)] = prop

        ax[i].set_title(f'{dataset_name}, method={method}')
        print(f'{dataset_name} {method} \tbest_params={best_params}')

The plots above show the calibration curves of the logistic regression (top row) and the GBM models (bottom row).
The ideal calibration is shown as a dotted line.
For each classifier, two calibration methods, Platt and Isotonic Regression, are shown.
In the legend, both the ROC-AUC score (`a`) and the Brier score (`b`) are reported as measures of the model's
classification performance (higher is better) and de-calibration level (lower is better), respectively.

In [None]:
# Preview the estimated propensities; columns are (dataset, method) pairs
df_propensity.head()

To visualize the propensity estimation results of each method, we can plot the distribution of propensity scores between treated and control groups.

In [None]:
# Histograms of the estimated propensity scores, split by treatment group:
# one figure per dataset, one panel per estimation method.
method_names = list(rcv_config.keys())

for dataset_name, df in datasets.items():
    fig, ax = plt.subplots(nrows=1, ncols=len(method_names), figsize=(15, 5))

    # Work on a temporary copy so the dataset itself does not gain a
    # 'propensity' column
    df_tmp = df.copy()

    for j, method in enumerate(method_names):
        panel = ax[j]

        # Attach the propensities estimated by the current method
        df_tmp['propensity'] = df_propensity[(dataset_name, method)].values

        # One histogram per treatment group, drawn on the same axes
        groups = df_tmp.groupby('T')
        groups['propensity'].plot(kind='hist', sharex=True, alpha=0.7, bins=50, ax=panel)

        panel.set_title(f'{dataset_name}, method={method}')
        panel.set_xlabel('propensity')
        panel.grid(True)
        # groupby iterates the groups in sorted key order, so T=0 is drawn first
        panel.legend(['T=0', 'T=1'])

In order to do causal inference we need the common support ("overlap") assumption to hold:
$$
\forall t,~\vec{x}:\ \Pr\left(\rvar{T}=t\given \rvec{X}=\vec{x}\right)>0.
$$
This means that every treatment assignment has a non-zero probability for any possible covariates of a unit.
In practice we have a limited dataset, and in our specific data this assumption does not hold when
we don't have samples from both groups (treatment and control) within the entire range of propensity scores.
Therefore, to maintain the overlap assumption in our dataset, we'll remove samples for which the propensity
score is outside the range of propensity scores of the other group.
In addition, we'll remove samples with extremely low or extremely high propensity scores, as these will cause numerical errors.

The common support calculation is implemented in the `common_support` function within the `propensity.py` module.

In [None]:
from ci.propensity import common_support

# Select 'logistic' as the propensity method for all datasets.
# Entries are matched to `datasets` by position, so there must be one per dataset.
prop_methods = ('logistic', 'logistic', 'logistic')
cs_threshold = 5e-2 # 5%/95%

# We'll save the dataframes after removing samples outside common support
datasets_cs = {}

# One panel per dataset.
# (Fixed: this previously referenced an undefined name `dataframes`.)
fig, ax = plt.subplots(nrows=1, ncols=len(datasets), figsize=(20, 6))
for i, (dataset_name, df) in enumerate(datasets.items()):
    # Copy the dataset and add a propensity column based on the selected method
    df = df.copy()
    df['propensity'] = df_propensity[(dataset_name, prop_methods[i])].values

    # Get common support indices and remove outlying samples
    idx_common = common_support(df['T'].to_numpy(), df['propensity'].to_numpy(),
                                min_thresh=cs_threshold, max_thresh=1-cs_threshold)
    # NOTE(review): row selection with df[...] only works if `idx_common` is a
    # boolean mask (integer positions would be interpreted as column labels) -
    # confirm against common_support in propensity.py.
    df = df[idx_common]

    # Save the df with removed samples
    datasets_cs[dataset_name] = df

    # Normalized propensity histograms of both groups after trimming;
    # groupby iterates in sorted key order, so T=0 is drawn first.
    groups = df.groupby('T')
    groups['propensity'].plot(kind='hist', sharex=True, alpha=0.7, bins=50, ax=ax[i], density=True)
    ax[i].set_xlabel('propensity'); ax[i].grid(True); ax[i].legend([f'T=0','T=1'])
    # The title reports how many samples remain out of the original dataset size
    ax[i].set_title(f'{dataset_name} Common Support ({len(df)}/{len(datasets[dataset_name])} samples)')
    

In [None]:
# Inspect the `mse` dataset after restricting it to the common support
datasets_cs['mse']