$$
\newcommand{\mat}[1]{\boldsymbol {#1}}
\newcommand{\mattr}[1]{\boldsymbol {#1}^\top}
\newcommand{\matinv}[1]{\boldsymbol {#1}^{-1}}
\renewcommand{\vec}[1]{\boldsymbol {#1}}
\newcommand{\vectr}[1]{\boldsymbol {#1}^\top}
\newcommand{\rvar}[1]{\mathrm {#1}}
\newcommand{\rvec}[1]{\boldsymbol{\mathrm{#1}}}
\newcommand{\diag}{\mathop{\mathrm {diag}}}
\newcommand{\set}[1]{\mathbb {#1}}
\newcommand{\cset}[1]{\mathcal {#1}}
\newcommand{\norm}[1]{\left\lVert#1\right\rVert}
\newcommand{\abs}[1]{\left\lvert#1\right\rvert}
\newcommand{\pderiv}[2]{\frac{\partial #1}{\partial #2}}
\newcommand{\bb}[1]{\boldsymbol{#1}}
\newcommand{\Tr}[0]{^\top}
\newcommand{\grad}[0]{\nabla}
\newcommand{\E}[2][]{\mathbb{E}_{#1}\left[#2\right]}
\newcommand{\Var}[1]{\mathrm{Var}\left[#1\right]}
\newcommand{\ip}[3]{\left<#1,#2\right>_{#3}}
\newcommand{\given}[0]{\middle\vert}
\newcommand{\DKL}[2]{\cset{D}_{\text{KL}}\left(#1\,\Vert\, #2\right)}
\DeclareMathOperator*{\argmax}{arg\,max}
\DeclareMathOperator*{\argmin}{arg\,min}
\DeclareMathOperator*{\trace}{trace}
\newcommand{\1}[1]{\mathbb{I}\left\{#1\right\}}
\newcommand{\setof}[1]{\left\{#1\right\}}
\newcommand{\DO}[1]{\mathrm{do}\left(#1\right)}
\newcommand{\indep}{\perp \!\!\! \perp}
$$


# <center>Causal Inference 097400, Winter 2019-20<br><br>Final Project</center>

#### <center>Aviv Rosenberg<br>`avivr@cs`</center>

##### <center>April, 2020<br></center>


In [1]:
%load_ext autoreload
%autoreload 2

import os
import numpy as np
import pandas as pd

PHYSIONET_DB = 'data/physionet/crisdb'
MHRV_DATA_FILE = 'data/crisdb-full-60min.xlsx'
OUT_DIR = 'out/'

os.makedirs(OUT_DIR, exist_ok=True)



In [2]:
import matplotlib.pyplot as plt
plt.rcParams['axes.prop_cycle'] = plt.cycler(color=plt.cm.Set1.colors)
plt.rcParams['font.size'] = 12

## Part 1: Creating the datasets

In [3]:
from ciproj import data

Load metadata from the PhysioNet CASTRR database files.

In [10]:
df_meta = data.load_castrr_metadata(PHYSIONET_DB)
df_meta.head()

| rec | AGE | SEX |
|---|---|---|
| e001a | 60 | Male |
| e001b | 60 | Male |
| e002a | 65 | Male |
| e002b | 65 | Male |
| e003a | 55 | Male |


Load the HRV features calculated on this database with `mhrv`.
Then, join the HRV features with the metadata.
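The join itself can be sketched with plain pandas. This is a minimal illustration on toy data, assuming (as the tables in this notebook suggest) that the HRV features are indexed by a `(rec, win)` MultiIndex and the metadata by `rec` alone. It is not necessarily how `data.load_mhrv_xls` does it internally.

```python
import pandas as pd

# Toy per-window HRV features, indexed by (record, window)
features = pd.DataFrame(
    {"AVNN": [701.1, 750.7, 739.5], "SDNN": [41.6, 49.5, 29.1]},
    index=pd.MultiIndex.from_tuples(
        [("e001a", 1), ("e001a", 2), ("e286a", 16)], names=["rec", "win"]
    ),
)

# Toy per-record metadata, indexed by record only
meta = pd.DataFrame(
    {"AGE": [60, 65], "SEX": ["Male", "Female"]},
    index=pd.Index(["e001a", "e286a"], name="rec"),
)

# Join on the 'rec' index level: every window inherits its record's metadata,
# and the (rec, win) MultiIndex of the features is preserved
joined = features.join(meta, on="rec")
print(joined)
```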

In [12]:
MHRV_GROUP_NAMES = ['E_CONTROL', 'E_TREATED', 'F_CONTROL', 'F_TREATED', 'M_CONTROL', 'M_TREATED']
dfs = data.load_mhrv_xls(MHRV_DATA_FILE, sheet_names=MHRV_GROUP_NAMES, df_meta=df_meta)

dfs['E_CONTROL']

Loaded E_CONTROL: 5937 samples, 49 features
Loaded E_TREATED: 6018 samples, 49 features
Loaded F_CONTROL: 4617 samples, 49 features
Loaded F_TREATED: 4760 samples, 49 features
Loaded M_CONTROL: 6137 samples, 49 features
Loaded M_TREATED: 6337 samples, 49 features


| rec | win | RR | NN | AVNN | SDNN | RMSSD | pNN50 | SEM | BETA_AR | HF_NORM_AR | HF_PEAK_AR | ... | MSE17 | MSE18 | MSE19 | MSE20 | PIP | IALS | PSS | PAS | AGE | SEX |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| e001a | 1 | 4327 | 4264 | 701.071106 | 41.581642 | 13.757771 | 0.234577 | 0.636785 | -1.117969 | 7.328302 | 0.327401 | ... | 1.348242 | 1.395579 | 1.477925 | 1.502828 | 30.816135 | 0.308468 | 19.723265 | 6.191370 | 60 | Male |
| e001a | 2 | 4477 | 4404 | 750.743286 | 49.516403 | 15.641404 | 0.408812 | 0.746149 | -1.128235 | 5.956380 | 0.254147 | ... | 1.307543 | 1.354694 | 1.323236 | 1.371318 | 34.196186 | 0.342267 | 25.272480 | 8.969119 | 60 | Male |
| e001a | 3 | 4722 | 4650 | 722.733521 | 59.652050 | 14.098043 | 0.150570 | 0.874780 | -1.095494 | 6.322984 | 0.252652 | ... | 1.222841 | 1.239768 | 1.133098 | 1.205776 | 30.537634 | 0.305657 | 20.021505 | 7.118279 | 60 | Male |
| e001a | 4 | 4769 | 4696 | 703.978455 | 53.346508 | 13.653588 | 0.255591 | 0.778470 | -1.121374 | 4.156405 | 0.255642 | ... | 1.251534 | 1.125110 | 1.228755 | 1.128171 | 31.856899 | 0.318850 | 21.869677 | 6.984668 | 60 | Male |
| e001a | 5 | 4986 | 4922 | 687.711121 | 47.553204 | 12.772495 | 0.386100 | 0.677812 | -1.179842 | 4.143440 | 0.240692 | ... | 1.444636 | 1.419865 | 1.398180 | 1.431729 | 28.728159 | 0.287543 | 18.468102 | 5.099553 | 60 | Male |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| e286a | 16 | 4840 | 4819 | 739.530396 | 29.074734 | 10.354865 | 0.020755 | 0.418830 | -0.414241 | 5.571465 | 0.346089 | ... | 1.133445 | 1.180290 | 1.206903 | 1.167605 | 38.659473 | 0.386883 | 29.196928 | 15.646399 | 65 | Female |
| e286a | 17 | 5289 | 5213 | 666.909973 | 41.037853 | 10.034287 | 0.095932 | 0.568383 | -0.682378 | 7.548707 | 0.325159 | ... | 0.673941 | 0.681831 | 0.705784 | 0.639730 | 40.053711 | 0.400806 | 29.100327 | 15.288701 | 65 | Female |
| e286a | 18 | 5374 | 5355 | 667.397583 | 23.006018 | 11.370922 | 0.261487 | 0.314385 | -1.036899 | 12.584630 | 0.319926 | ... | 0.906764 | 0.942591 | 0.936589 | 1.025636 | 42.689075 | 0.427157 | 33.053223 | 18.169935 | 65 | Female |
| e286a | 19 | 4924 | 4856 | 719.261597 | 22.993504 | 10.629254 | 0.082389 | 0.329963 | -0.436453 | 9.975235 | 0.269097 | ... | 1.371525 | 1.503135 | 1.390089 | 1.548041 | 38.632618 | 0.386612 | 28.644976 | 14.229818 | 65 | Female |


Now we'll add the outcome columns. We'll use the non-linear HRV features as the outcomes, since they characterize the type of dynamics found in the heartbeat intervals.

Note that for the treated group, the outcomes must come from the post-treatment data,
whereas the covariates (the remaining HRV features) are taken from the pre-treatment data.
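The idea of taking covariates from one recording and outcomes from another can be illustrated with a small pandas sketch. Everything here is hypothetical: the toy frames `df_pre`/`df_post`, the helper `pair_pre_post`, and the assumption that pre- and post-treatment windows are aligned through a shared `(subject, win)` index. The `X_`/`Y_` prefixes mirror the column names used later in this notebook. An outer join keeps windows that exist in only one recording, which then appear with missing covariates or outcomes.

```python
import pandas as pd

def pair_pre_post(df_pre, df_post, outcome_cols):
    """Hypothetical helper: covariates from pre-treatment rows, outcomes from post-treatment rows."""
    X = df_pre.drop(columns=outcome_cols).add_prefix("X_")
    Y = df_post[outcome_cols].add_prefix("Y_")
    # Outer join on the shared (subject, win) index; unmatched windows get NaN
    return X.join(Y, how="outer")

idx_pre = pd.MultiIndex.from_tuples([("s1", 1), ("s1", 2)], names=["subject", "win"])
idx_post = pd.MultiIndex.from_tuples([("s1", 1), ("s1", 2), ("s1", 3)], names=["subject", "win"])
df_pre = pd.DataFrame({"AVNN": [700.0, 750.0], "MSE": [1.3, 1.2]}, index=idx_pre)
df_post = pd.DataFrame({"AVNN": [710.0, 720.0, 730.0], "MSE": [1.1, 1.0, 0.9]}, index=idx_post)

paired = pair_pre_post(df_pre, df_post, outcome_cols=["MSE"])
print(paired)
```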

In [18]:
data.consolidate_psd(dfs['E_CONTROL'], '_AR')

| rec | win | RR | NN | AVNN | SDNN | RMSSD | pNN50 | SEM | BETA | HF_NORM | HF_PEAK | ... | MSE17 | MSE18 | MSE19 | MSE20 | PIP | IALS | PSS | PAS | AGE | SEX |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| e001a | 1 | 4327 | 4264 | 701.071106 | 41.581642 | 13.757771 | 0.234577 | 0.636785 | -1.117969 | 7.328302 | 0.327401 | ... | 1.348242 | 1.395579 | 1.477925 | 1.502828 | 30.816135 | 0.308468 | 19.723265 | 6.191370 | 60 | Male |
| e001a | 2 | 4477 | 4404 | 750.743286 | 49.516403 | 15.641404 | 0.408812 | 0.746149 | -1.128235 | 5.956380 | 0.254147 | ... | 1.307543 | 1.354694 | 1.323236 | 1.371318 | 34.196186 | 0.342267 | 25.272480 | 8.969119 | 60 | Male |
| e001a | 3 | 4722 | 4650 | 722.733521 | 59.652050 | 14.098043 | 0.150570 | 0.874780 | -1.095494 | 6.322984 | 0.252652 | ... | 1.222841 | 1.239768 | 1.133098 | 1.205776 | 30.537634 | 0.305657 | 20.021505 | 7.118279 | 60 | Male |
| e001a | 4 | 4769 | 4696 | 703.978455 | 53.346508 | 13.653588 | 0.255591 | 0.778470 | -1.121374 | 4.156405 | 0.255642 | ... | 1.251534 | 1.125110 | 1.228755 | 1.128171 | 31.856899 | 0.318850 | 21.869677 | 6.984668 | 60 | Male |
| e001a | 5 | 4986 | 4922 | 687.711121 | 47.553204 | 12.772495 | 0.386100 | 0.677812 | -1.179842 | 4.143440 | 0.240692 | ... | 1.444636 | 1.419865 | 1.398180 | 1.431729 | 28.728159 | 0.287543 | 18.468102 | 5.099553 | 60 | Male |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| e286a | 16 | 4840 | 4819 | 739.530396 | 29.074734 | 10.354865 | 0.020755 | 0.418830 | -0.414241 | 5.571465 | 0.346089 | ... | 1.133445 | 1.180290 | 1.206903 | 1.167605 | 38.659473 | 0.386883 | 29.196928 | 15.646399 | 65 | Female |
| e286a | 17 | 5289 | 5213 | 666.909973 | 41.037853 | 10.034287 | 0.095932 | 0.568383 | -0.682378 | 7.548707 | 0.325159 | ... | 0.673941 | 0.681831 | 0.705784 | 0.639730 | 40.053711 | 0.400806 | 29.100327 | 15.288701 | 65 | Female |
| e286a | 18 | 5374 | 5355 | 667.397583 | 23.006018 | 11.370922 | 0.261487 | 0.314385 | -1.036899 | 12.584630 | 0.319926 | ... | 0.906764 | 0.942591 | 0.936589 | 1.025636 | 42.689075 | 0.427157 | 33.053223 | 18.169935 | 65 | Female |
| e286a | 19 | 4924 | 4856 | 719.261597 | 22.993504 | 10.629254 | 0.082389 | 0.329963 | -0.436453 | 9.975235 | 0.269097 | ... | 1.371525 | 1.503135 | 1.390089 | 1.548041 | 38.632618 | 0.386612 | 28.644976 | 14.229818 | 65 | Female |


In [15]:
data.create_outcome_columns(dfs['E_CONTROL'], mse=True, dfa=True)[0]

| rec | win | RR | NN | AVNN | SDNN | RMSSD | pNN50 | SEM | BETA_AR | HF_NORM_AR | HF_PEAK_AR | ... | SD2 | SampEn | PIP | IALS | PSS | PAS | AGE | SEX | Y_MSE | Y_DFA |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| e001a | 1 | 4327 | 4264 | 701.071106 | 41.581642 | 13.757771 | 0.234577 | 0.636785 | -1.117969 | 7.328302 | 0.327401 | ... | 57.969677 | 0.756291 | 30.816135 | 0.308468 | 19.723265 | 6.191370 | 60 | Male | 1.370547 | 1.222382 |
| e001a | 2 | 4477 | 4404 | 750.743286 | 49.516403 | 15.641404 | 0.408812 | 0.746149 | -1.128235 | 5.956380 | 0.254147 | ... | 69.152100 | 0.869921 | 34.196186 | 0.342267 | 25.272480 | 8.969119 | 60 | Male | 1.259971 | 1.220502 |
| e001a | 3 | 4722 | 4650 | 722.733521 | 59.652050 | 14.098043 | 0.150570 | 0.874780 | -1.095494 | 6.322984 | 0.252652 | ... | 83.759972 | 0.796012 | 30.537634 | 0.305657 | 20.021505 | 7.118279 | 60 | Male | 1.108175 | 1.258079 |
| e001a | 4 | 4769 | 4696 | 703.978455 | 53.346508 | 13.653588 | 0.255591 | 0.778470 | -1.121374 | 4.156405 | 0.255642 | ... | 74.813255 | 0.715674 | 31.856899 | 0.318850 | 21.869677 | 6.984668 | 60 | Male | 1.118184 | 1.305807 |
| e001a | 5 | 4986 | 4922 | 687.711121 | 47.553204 | 12.772495 | 0.386100 | 0.677812 | -1.179842 | 4.143440 | 0.240692 | ... | 66.641884 | 0.647037 | 28.728159 | 0.287543 | 18.468102 | 5.099553 | 60 | Male | 1.233004 | 1.272703 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| e286a | 16 | 4840 | 4819 | 739.530396 | 29.074734 | 10.354865 | 0.020755 | 0.418830 | -0.414241 | 5.571465 | 0.346089 | ... | 40.462791 | 1.266139 | 38.659473 | 0.386883 | 29.196928 | 15.646399 | 65 | Female | 1.212141 | 1.173110 |
| e286a | 17 | 5289 | 5213 | 666.909973 | 41.037853 | 10.034287 | 0.095932 | 0.568383 | -0.682378 | 7.548707 | 0.325159 | ... | 57.595081 | 0.305259 | 40.053711 | 0.400806 | 29.100327 | 15.288701 | 65 | Female | 0.533411 | 1.183375 |
| e286a | 18 | 5374 | 5355 | 667.397583 | 23.006018 | 11.370922 | 0.261487 | 0.314385 | -1.036899 | 12.584630 | 0.319926 | ... | 31.520037 | 1.185581 | 42.689075 | 0.427157 | 33.053223 | 18.169935 | 65 | Female | 0.876491 | 1.039011 |
| e286a | 19 | 4924 | 4856 | 719.261597 | 22.993504 | 10.629254 | 0.082389 | 0.329963 | -0.436453 | 9.975235 | 0.269097 | ... | 31.637148 | 1.220223 | 38.632618 | 0.386612 | 28.644976 | 14.229818 | 65 | Female | 1.295721 | 1.098521 |


In [21]:

data.mark_dataset(
    *data.create_outcome_columns(dfs['E_CONTROL'], mse=True, dfa=True, prefix=''),
    ignore=['SampEn']
)

(           X_RR  X_NN      X_AVNN     X_SDNN    X_RMSSD   X_pNN50     X_SEM  \
 rec   win                                                                     
 e001a 1    4327  4264  701.071106  41.581642  13.757771  0.234577  0.636785   
       2    4477  4404  750.743286  49.516403  15.641404  0.408812  0.746149   
       3    4722  4650  722.733521  59.652050  14.098043  0.150570  0.874780   
       4    4769  4696  703.978455  53.346508  13.653588  0.255591  0.778470   
       5    4986  4922  687.711121  47.553204  12.772495  0.386100  0.677812   
 ...         ...   ...         ...        ...        ...       ...       ...   
 e286a 16   4840  4819  739.530396  29.074734  10.354865  0.020755  0.418830   
       17   5289  5213  666.909973  41.037853  10.034287  0.095932  0.568383   
       18   5374  5355  667.397583  23.006018  11.370922  0.261487  0.314385   
       19   4924  4856  719.261597  22.993504  10.629254  0.082389  0.329963   
       20   5130  4992  681.668823  20.5

In [30]:
data.create_ci_dataset(dfs['E_CONTROL'], dfs['E_TREATED'], ignore_features=['SampEn'], random_seed=42)

| rec | win | X_RR | X_NN | X_AVNN | X_SDNN | X_RMSSD | X_pNN50 | X_SEM | X_BETA | X_HF_NORM | X_HF_PEAK | ... | X_SD2 | X_PIP | X_IALS | X_PSS | X_PAS | X_AGE | X_SEX | Y_MSE | Y_DFA | T |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| e001a | 1 | 4327.0 | 4264.0 | 701.071106 | 41.581642 | 13.757771 | 0.234577 | 0.636785 | -1.117969 | 7.328302 | 0.327401 | ... | 57.969677 | 30.816135 | 0.308468 | 19.723265 | 6.191370 | 60 | Male | 1.370547 | 1.222382 | 0 |
| e001a | 2 | 4477.0 | 4404.0 | 750.743286 | 49.516403 | 15.641404 | 0.408812 | 0.746149 | -1.128235 | 5.956380 | 0.254147 | ... | 69.152100 | 34.196186 | 0.342267 | 25.272480 | 8.969119 | 60 | Male | 1.259971 | 1.220502 | 0 |
| e001a | 3 | 4722.0 | 4650.0 | 722.733521 | 59.652050 | 14.098043 | 0.150570 | 0.874780 | -1.095494 | 6.322984 | 0.252652 | ... | 83.759972 | 30.537634 | 0.305657 | 20.021505 | 7.118279 | 60 | Male | 1.108175 | 1.258079 | 0 |
| e001a | 4 | 4769.0 | 4696.0 | 703.978455 | 53.346508 | 13.653588 | 0.255591 | 0.778470 | -1.121374 | 4.156405 | 0.255642 | ... | 74.813255 | 31.856899 | 0.318850 | 21.869677 | 6.984668 | 60 | Male | 1.118184 | 1.305807 | 0 |
| e001a | 5 | 4986.0 | 4922.0 | 687.711121 | 47.553204 | 12.772495 | 0.386100 | 0.677812 | -1.179842 | 4.143440 | 0.240692 | ... | 66.641884 | 28.728159 | 0.287543 | 18.468102 | 5.099553 | 60 | Male | 1.233004 | 1.272703 | 0 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| e282b | 21 |  |  |  |  |  |  |  |  |  |  | ... |  |  |  |  |  |  |  | 1.282108 | 1.305852 | 1 |
| e282b | 22 |  |  |  |  |  |  |  |  |  |  | ... |  |  |  |  |  |  |  | 0.900446 | 1.293175 | 1 |
| e282b | 23 |  |  |  |  |  |  |  |  |  |  | ... |  |  |  |  |  |  |  | 0.747876 | 1.189825 | 1 |
| e282b | 24 |  |  |  |  |  |  |  |  |  |  | ... |  |  |  |  |  |  |  | 1.093302 | 1.009494 | 1 |


In [24]:
_.index

MultiIndex([('e001a',  '1'),
            ('e001a',  '2'),
            ('e001a',  '3'),
            ('e001a',  '4'),
            ('e001a',  '5'),
            ('e001a',  '6'),
            ('e001a',  '7'),
            ('e001a',  '8'),
            ('e001a',  '9'),
            ('e001a', '10'),
            ...
            ('e281a', '15'),
            ('e281a', '16'),
            ('e281a', '17'),
            ('e281a', '18'),
            ('e281a', '19'),
            ('e281a', '20'),
            ('e281a', '21'),
            ('e281a', '22'),
            ('e281a', '23'),
            ('e281a', '24')],
           names=['rec', 'win'], length=3001)

## Part 2: Exploring the data

In [None]:
df_full = data.create_ci_dataset(df_control=dfs['NSR'], df_treated=dfs['AGING'])

We wish to assess the causal effect of age (the treatment variable) on several different possible outcome variables which are known to reflect heart-rate dynamics:

- The normalized power in the very-low frequency band.
- The DFA slopes, $\alpha_1$ and $\alpha_2$.
- The multiscale entropy.
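For reference, the DFA slopes can be sketched in NumPy. This is a minimal first-order DFA, assuming the commonly used HRV scale ranges of 4 to 16 beats for $\alpha_1$ and 16 to 64 beats for $\alpha_2$; `mhrv` may use different scale ranges, box spacing or detrending details, so the numbers need not match its output exactly.

```python
import numpy as np

def dfa_fluctuation(x, scales):
    """First-order DFA: RMS fluctuation F(n) of the integrated, locally detrended series."""
    x = np.asarray(x, dtype=float)
    y = np.cumsum(x - x.mean())  # integrated profile
    F = []
    for n in scales:
        n_boxes = len(y) // n
        boxes = y[: n_boxes * n].reshape(n_boxes, n)
        t = np.arange(n)
        # Least-squares line in every box at once (one column per box)
        coeffs = np.polyfit(t, boxes.T, deg=1)  # shape (2, n_boxes)
        trend = np.outer(coeffs[0], t) + coeffs[1][:, None]
        F.append(np.sqrt(np.mean((boxes - trend) ** 2)))
    return np.array(F)

def dfa_alpha(x, n_min, n_max):
    """Slope of log F(n) versus log n over box sizes n_min..n_max."""
    scales = np.arange(n_min, n_max + 1)
    F = dfa_fluctuation(x, scales)
    slope, _ = np.polyfit(np.log(scales), np.log(F), deg=1)
    return slope

rng = np.random.default_rng(0)
white = rng.standard_normal(10_000)  # uncorrelated noise: theoretical alpha = 0.5
brown = np.cumsum(white)             # random walk: theoretical alpha = 1.5
a1_white, a2_white = dfa_alpha(white, 4, 16), dfa_alpha(white, 16, 64)
a1_brown, a2_brown = dfa_alpha(brown, 4, 16), dfa_alpha(brown, 16, 64)
print(a1_white, a2_white, a1_brown, a2_brown)
```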

For simplicity we make the treatment binary:
- $T=0$: Our control group consists of **healthy** individuals aged 22-45.
- $T=1$: Our treated group consists of **healthy** individuals aged 60 and over.

Note that we only use data from healthy individuals: no known underlying cardiopathologies and no arrhythmias in the processed ECG recordings. This ensures that we measure the effect of age on HRV, rather than the effect of pathologies.
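A minimal sketch of how such a binary treatment could be derived from an age column, assuming a hypothetical `AGE` column and the cut-offs stated above (22-45 vs. 60 and over). Subjects aged 46-59 fall in neither group and are dropped. The helper `binarize_age` is illustrative and not part of `ciproj`.

```python
import pandas as pd

def binarize_age(df, age_col="AGE", young=(22, 45), old_min=60):
    """Hypothetical helper: map age to a binary treatment T, dropping rows in neither group."""
    age = df[age_col]
    is_young = age.between(*young)  # inclusive on both ends
    is_old = age >= old_min
    keep = is_young | is_old
    out = df[keep].copy()
    out["T"] = is_old[keep].astype(int)  # T=1 for the older group
    return out

subjects = pd.DataFrame({"AGE": [25, 45, 52, 60, 71]})
treated = binarize_age(subjects)
print(treated)
```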

Let's plot the distribution of the data in the outcome variables of interest, conditioned on the treatment (young/old).

In [None]:
import ciproj.plot

fig, ax = plt.subplots(1, 1, figsize=(15,10))
group_by = dict(by='T')
group_legend_names={0:'control', 1:'treated'}
violin_args = dict(showextrema=False, showmeans=True, widths=0.5)
ciproj.plot.df_group_violins(ax, df_full, ['alpha1', 'alpha2'], group_by, violin_args, group_legend_names)

- The multiscale entropy:

In [None]:
mse_vars = [f'MSE{i}' for i in range(1, 21)]
fig, ax = plt.subplots(1, 1, figsize=(20,10))
ciproj.plot.df_group_violins(ax, df_full, mse_vars, group_by, violin_args, group_legend_names)
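As a reference for what `MSE1`-`MSE20` measure, here is a minimal NumPy sketch of the standard multiscale entropy procedure: coarse-grain the series by averaging non-overlapping windows of length $\tau$, then compute the sample entropy of each coarse-grained series with $m=2$ and a tolerance $r$ fixed at 0.2 times the standard deviation of the original series. These parameter values are common defaults and an assumption here; `mhrv` may use different ones. The pairwise matching loop is $O(N^2)$, so the sketch is only meant for short series.

```python
import numpy as np

def sample_entropy(x, m=2, r=0.2):
    """SampEn = -ln(A/B): B counts matching pairs of length-m templates, A of length m+1.
    r is an absolute tolerance on the Chebyshev distance; self-matches are excluded."""
    x = np.asarray(x, dtype=float)
    n_templates = len(x) - m  # same number of templates for both lengths

    def count_pairs(length):
        templates = np.array([x[i:i + length] for i in range(n_templates)])
        count = 0
        for i in range(n_templates - 1):
            dist = np.max(np.abs(templates[i + 1:] - templates[i]), axis=1)
            count += int(np.sum(dist <= r))
        return count

    B = count_pairs(m)
    A = count_pairs(m + 1)
    return -np.log(A / B)

def coarse_grain(x, scale):
    """Average consecutive non-overlapping windows of length `scale`."""
    n = len(x) // scale
    return np.asarray(x[: n * scale], dtype=float).reshape(n, scale).mean(axis=1)

def multiscale_entropy(x, max_scale=20, m=2, r_factor=0.2):
    r = r_factor * np.std(x)  # tolerance fixed from the original (scale 1) series
    return np.array([sample_entropy(coarse_grain(x, s), m, r) for s in range(1, max_scale + 1)])

rng = np.random.default_rng(0)
white = rng.standard_normal(1500)
mse = multiscale_entropy(white, max_scale=5)
print(np.round(mse, 3))  # for white noise, entropy is expected to decrease with scale
```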

- The normalized power within three frequency bands of the power spectral density (VLF, LF and HF):

In [None]:
fig, ax = plt.subplots(1, 1, figsize=(20,10))
ciproj.plot.df_group_violins(ax, df_full, ['VLF_NORM', 'LF_NORM', 'HF_NORM'], group_by, violin_args, group_legend_names)
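The normalized band powers can be sketched with SciPy. This is an illustration under stated assumptions, not the `mhrv` implementation: it resamples the RR-interval series onto an evenly spaced 4 Hz grid by linear interpolation, estimates the PSD with Welch's method (`scipy.signal.welch`), integrates it over the conventional VLF (0.003-0.04 Hz), LF (0.04-0.15 Hz) and HF (0.15-0.4 Hz) bands, and expresses each band as a percentage of the total power in 0.003-0.4 Hz. The features in this notebook come from AR-based estimates (`_AR` suffix), and `mhrv` may normalize differently.

```python
import numpy as np
from scipy.signal import welch

# Conventional HRV frequency bands in Hz (an assumption; mhrv's defaults may differ)
BANDS = {"VLF": (0.003, 0.04), "LF": (0.04, 0.15), "HF": (0.15, 0.4)}

def normalized_band_powers(rr_sec, fs=4.0):
    """Percent of total (0.003-0.4 Hz) power in each band, from RR intervals in seconds."""
    rr_sec = np.asarray(rr_sec, dtype=float)
    t_beats = np.cumsum(rr_sec)                   # time of each beat
    t_even = np.arange(t_beats[0], t_beats[-1], 1 / fs)
    rr_even = np.interp(t_even, t_beats, rr_sec)  # evenly resampled tachogram
    f, pxx = welch(rr_even - rr_even.mean(), fs=fs, nperseg=min(len(rr_even), 1024))
    df = f[1] - f[0]

    def band_power(lo, hi):
        # Rectangle-rule integral of the PSD over [lo, hi)
        return pxx[(f >= lo) & (f < hi)].sum() * df

    total = band_power(BANDS["VLF"][0], BANDS["HF"][1])
    return {name: 100 * band_power(lo, hi) / total for name, (lo, hi) in BANDS.items()}

# Synthetic RR series with a respiratory-like 0.25 Hz oscillation (HF band)
rr, t = [], 0.0
while t < 600:
    r = 0.8 + 0.05 * np.sin(2 * np.pi * 0.25 * t)
    rr.append(r)
    t += r

powers = normalized_band_powers(rr)
print({k: round(v, 1) for k, v in powers.items()})
```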

We'll create three datasets.
- The `vlf` dataset: The outcome is the normalized VLF power. The absolute VLF power (`VLF_POWER`) is removed from the covariates, since it is closely tied to the outcome.
- The `dfa` dataset: The outcomes are `alpha1` and `alpha2`, the DFA slopes.
- The `mse` dataset: Outcomes are `MSE1-20`, the multiscale entropy variables.
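The column convention used for these datasets (outcomes prefixed with `Y_`, covariates with `X_`, a treatment column `T`, ignored features dropped) can be sketched as follows. The `mark_columns` helper is hypothetical and only mirrors the column names visible in this notebook; the actual `ciproj.data.mark_dataset` may differ, for example in column order or in what it returns.

```python
import pandas as pd

def mark_columns(df, outcomes, treatment="T", ignore=()):
    """Hypothetical: drop ignored columns, prefix outcomes with Y_ and remaining covariates with X_."""
    df = df.drop(columns=list(ignore))
    covariates = [c for c in df.columns if c not in outcomes and c != treatment]
    renamed = {c: f"X_{c}" for c in covariates}
    renamed.update({c: f"Y_{c}" for c in outcomes})
    out = df.rename(columns=renamed)
    # Order: covariates, then outcomes, then treatment
    return out[[renamed[c] for c in covariates] + [renamed[c] for c in outcomes] + [treatment]]

toy = pd.DataFrame({"AVNN": [700.0, 650.0], "VLF_POWER": [900.0, 400.0],
                    "VLF_NORM": [55.0, 40.0], "T": [0, 1]})
marked = mark_columns(toy, outcomes=["VLF_NORM"], ignore=["VLF_POWER"])
print(list(marked.columns))
```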

In [None]:
datasets = {}
datasets['vlf'] = ciproj.data.mark_dataset(df_full, outcomes=['VLF_NORM'], treatment='T', ignore=['VLF_POWER'])
datasets['dfa'] = ciproj.data.mark_dataset(df_full, outcomes=['alpha1', 'alpha2'], treatment='T')
datasets['mse'] = ciproj.data.mark_dataset(df_full, outcomes=mse_vars, treatment='T')

In [None]:
for name, df in datasets.items():
    print(f'*** {name} dataset: ', end='')
    X, y, t = ciproj.data.split_dataset(df, scale_covariates=True)
    print(f'X{X.shape}, y{y.shape}, t{t.shape}')
    print(f'covariates: {[x for x in df.columns if x.startswith("X_")]}')
    print(f'outcomes: {[x for x in df.columns if x.startswith("Y_")]}')
    print()

## Part 3: Propensity estimation and common support

The *propensity score* is defined as $e(\vec{x}):=\Pr{\left(\rvar{T}=1\given \rvec{X}=\vec{x}\right)}$.
In other words, it is the probability that a treatment $\rvar{T}=1$ will be assigned to a
unit with covariates $\rvec{X}=\vec{x}$.
In this case, since our treatment variable is age, it is not really "assigned" based on covariates.
However, the propensity estimation can shed light on whether covariates have predictive power regarding age, and help us achieve balanced covariate marginal distributions for matching.

We would like to obtain an estimator for the propensity score, $\hat e(\vec{x})$, from the data.
One important aspect of this estimator is that we would like it to be *calibrated*.
Briefly, this means that if, for example, we look at all units $\vec{x}$ for which $\hat e(\vec{x})=0.8$, we expect about 80\% of them to actually belong to the treatment group ($\rvar{T}=1$).
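Calibration can be checked empirically by binning the predicted propensities and comparing each bin's mean prediction with the observed fraction of treated units; the Brier score summarizes the squared error of the predicted probabilities. A minimal scikit-learn sketch using `sklearn.calibration.calibration_curve` and `sklearn.metrics.brier_score_loss`, on synthetic data whose data-generating process is made up purely for illustration:

```python
import numpy as np
from sklearn.calibration import calibration_curve
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import brier_score_loss
from sklearn.model_selection import train_test_split

rng = np.random.default_rng(42)
X = rng.standard_normal((4000, 3))
true_e = 1 / (1 + np.exp(-(1.5 * X[:, 0] - X[:, 1])))  # synthetic true propensity e(x)
t = rng.binomial(1, true_e)

X_tr, X_te, t_tr, t_te = train_test_split(X, t, test_size=0.3, random_state=42)
e_hat = LogisticRegression().fit(X_tr, t_tr).predict_proba(X_te)[:, 1]

# Per bin: observed fraction of treated units vs. mean predicted propensity
frac_treated, mean_pred = calibration_curve(t_te, e_hat, n_bins=10)
print(np.round(np.c_[mean_pred, frac_treated], 2))
print("Brier score:", brier_score_loss(t_te, e_hat))
```

For a well-calibrated estimator the two printed columns should be close to each other, i.e. the points lie near the diagonal of a calibration plot.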

In [None]:
from sklearn.linear_model import LogisticRegressionCV, LogisticRegression
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.naive_bayes import GaussianNB
from scipy import stats

from ci.cv import CVConfig, LogSpaceSampler

# Define the methods and CV search space for propensity estimation
rcv_config = {
    'logistic': CVConfig(
        model=LogisticRegression(dual=False, solver='liblinear'),
        params=dict(C=LogSpaceSampler(-5, 2), penalty=['l1','l2']),
    ),
    'gbm': CVConfig(
        model=GradientBoostingClassifier(),
        params=dict(
            n_estimators=stats.randint(10, 100 + 1),
            max_depth=stats.randint(1, 3 + 1),
            learning_rate=LogSpaceSampler(-2, 0.5),
        )
    ),
}

In [None]:
from ci.propensity import fit_propensity_cv

# Create a dataframe to store propensities from each dataset and each method
midx = pd.MultiIndex.from_product([datasets.keys(), rcv_config.keys()], names=['dataset', 'method'])
df_propensity = pd.DataFrame(columns=midx)

for method, cv_cfg in rcv_config.items():
    fig, ax = plt.subplots(nrows=1, ncols=len(datasets), figsize=(16, 6))
    
    for i, (dataset_name, df) in enumerate(datasets.items()):
        # Get covariates X and treatment assignment t as numpy arrays
        X, _, t = ciproj.data.split_dataset(df, scale_covariates=True)
        
        # Train a model to estimate propensity using current method,
        # calibrate with two approaches and generate a calibration plot
        model, best_params = fit_propensity_cv(
            cv_cfg, X, t, plot_args=dict(name=method, ax=ax[i]),
            test_size=0.3, n_iter=42, random_state=42+i, cv_splits=4
        )
        
        # Use the best calibrated model (lowest Brier score) to estimate the propensities
        prop = model.predict_proba(X)[:, 1]
        df_propensity[(dataset_name, method)] = prop
        
        ax[i].set_title(f'{dataset_name}, method={method}')
        print(f'{dataset_name} {method} \tbest_params={best_params}')

The plots above show the calibration curves of the logistic regression (first row of plots) and the GBM model (second row).
The ideal calibration is shown as a dotted line.
For each classifier, two calibration methods, Platt scaling and isotonic regression, are shown.
In the legend, the ROC-AUC score (`a`) measures the model's classification performance (higher is better),
and the Brier score (`b`), the mean squared error of the predicted probabilities, serves as a measure of miscalibration (lower is better).
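Both calibration approaches are available in scikit-learn through `CalibratedClassifierCV` (`method='sigmoid'` for Platt scaling, `method='isotonic'` for isotonic regression). The following is a minimal sketch on synthetic data that compares ROC-AUC and Brier score for each; it is not the `fit_propensity_cv` implementation, whose internals are not shown here.

```python
import numpy as np
from sklearn.calibration import CalibratedClassifierCV
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.metrics import brier_score_loss, roc_auc_score
from sklearn.model_selection import train_test_split

rng = np.random.default_rng(0)
X = rng.standard_normal((3000, 4))
t = rng.binomial(1, 1 / (1 + np.exp(-(X[:, 0] + 0.5 * X[:, 1] ** 2 - 0.5))))
X_tr, X_te, t_tr, t_te = train_test_split(X, t, test_size=0.3, random_state=0)

scores = {}
for method in ("sigmoid", "isotonic"):  # Platt scaling vs. isotonic regression
    clf = CalibratedClassifierCV(GradientBoostingClassifier(random_state=0), method=method, cv=4)
    e_hat = clf.fit(X_tr, t_tr).predict_proba(X_te)[:, 1]
    scores[method] = (roc_auc_score(t_te, e_hat), brier_score_loss(t_te, e_hat))
    print(f"{method:9s} AUC={scores[method][0]:.3f}  Brier={scores[method][1]:.3f}")

# As in the text above, keep the calibrated variant with the lowest Brier score
best = min(scores, key=lambda m: scores[m][1])
print("best:", best)
```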

In [None]:
df_propensity.head()

To visualize the propensity estimation results of each method, we can plot the distribution of propensity scores between treated and control groups.

In [None]:
for i, (dataset_name, df) in enumerate(datasets.items()):
    fig, ax = plt.subplots(nrows=1, ncols=len(rcv_config.keys()), figsize=(15, 5))
    # Create a temporary dataframe
    df_tmp = df.copy()
    
    for j, method in enumerate(rcv_config.keys()):
        # Add propensity from current method to the temporary dataframe
        df_tmp['propensity'] = df_propensity[(dataset_name, method)].values
        # Plot propensity scores conditioned on treatment
        groups = df_tmp.groupby('T')
        groups['propensity'].plot(kind='hist', sharex=True, alpha=0.7, bins=50, ax=ax[j])
        ax[j].set_title(f'{dataset_name}, method={method}')
        ax[j].set_xlabel('propensity'); ax[j].grid(True); ax[j].legend([f'T=0','T=1'])

In order to do causal inference we need the common support ("overlap") assumption to hold:
$$
\forall t,~\vec{x}:\ \Pr\left(\rvar{T}=t\given \rvec{X}=\vec{x}\right)>0.
$$
This means that every treatment must have a nonzero probability for every possible covariate vector $\vec{x}$ of a unit.
In practice we have a limited dataset, and the assumption effectively fails in ranges of propensity scores
where we don't have samples from both groups (treatment and control).
Therefore, to maintain the overlap assumption in our dataset, we'll remove samples whose propensity
score is outside the range of propensity scores of the other group.
In addition, we'll remove samples with extremely low or extremely high propensity scores, as these can cause numerical errors (e.g. when dividing by $\hat e(\vec{x})$ or $1-\hat e(\vec{x})$).

The common support calculation is implemented in the `common_support` function within the `propensity.py` module.
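The two trimming rules described above can be sketched in NumPy. The function below is a hypothetical stand-in, not the `common_support` function from `propensity.py`: it keeps a unit only if its propensity lies inside the range of propensities observed in both groups (which is equivalent to lying inside the other group's range) and inside the `[min_thresh, max_thresh]` interval.

```python
import numpy as np

def common_support_mask(t, e_hat, min_thresh=0.05, max_thresh=0.95):
    """Hypothetical sketch: boolean mask of units inside the propensity overlap region."""
    t = np.asarray(t)
    e_hat = np.asarray(e_hat, dtype=float)
    e0, e1 = e_hat[t == 0], e_hat[t == 1]
    # Overlap region: intersection of the two groups' propensity ranges
    lo = max(e0.min(), e1.min())
    hi = min(e0.max(), e1.max())
    in_overlap = (e_hat >= lo) & (e_hat <= hi)
    # Additionally drop extreme scores, where 1/e or 1/(1-e) becomes very large
    in_bounds = (e_hat >= min_thresh) & (e_hat <= max_thresh)
    return in_overlap & in_bounds

t = np.array([0, 0, 0, 1, 1, 1])
e_hat = np.array([0.01, 0.20, 0.60, 0.30, 0.70, 0.97])
mask = common_support_mask(t, e_hat)
print(mask)
```

Here the overlap region is $[0.30, 0.60]$, so only the third and fourth units are kept.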

In [None]:
from ci.propensity import common_support

# Select 'logistic' as the propensity method for all datasets
prop_methods = ('logistic', 'logistic', 'logistic')
cs_threshold = 5e-2 # 5%/95%

# We'll save the dataframes after removing samples outside common support
datasets_cs = {}

fig, ax = plt.subplots(nrows=1, ncols=len(datasets), figsize=(20, 6))
for i, (dataset_name, df) in enumerate(datasets.items()):
    # Copy the dataset and add a propensity column based on the selected method
    df = df.copy()
    df['propensity'] = df_propensity[(dataset_name, prop_methods[i])].values
    
    # Get common support indices and remove outlying samples
    idx_common = common_support(df['T'].to_numpy(), df['propensity'].to_numpy(),
                                min_thresh=cs_threshold, max_thresh=1-cs_threshold)
    df = df[idx_common]
    
    # Save the df with removed samples
    datasets_cs[dataset_name] = df
    
    groups = df.groupby('T')
    groups['propensity'].plot(kind='hist', sharex=True, alpha=0.7, bins=50, ax=ax[i], density=True)
    ax[i].set_xlabel('propensity'); ax[i].grid(True); ax[i].legend([f'T=0','T=1'])
    ax[i].set_title(f'{dataset_name} Common Support ({len(df)}/{len(datasets[dataset_name])} samples)')
    

In [None]:
datasets_cs['mse']