# Overview

This notebook plots the main simulation figures for the CPI-DNN paper

- _Supp. Figure 1_: High-level comparison of CPI and Perfmit (Power and Computation time)
- _Supp. 1 Figure 3_: High-level comparison of the different core learners
- _Supp. 2 Figure 3_: Extensive generative model benchmark, inclduding state of the art (Power and Computation time)

In [1]:
import pathlib
import numpy as np
import pandas as pd
import altair as alt
import matplotlib.pyplot as plt
plt.rcParams['text.usetex'] = True

In [2]:
res_path = pathlib.Path('../results/results_csv')
list(res_path.glob('*.csv'))

[PosixPath('../results/results_csv/time_bars_blocks_100_UKBB_single.csv'),
 PosixPath('../results/results_csv/simulation_results_blocks_100_CPI_n_1000_p_50_cpi_depth_10.csv'),
 PosixPath('../results/results_csv/simulation_results_blocks_100_allMethods_pred_final.csv'),
 PosixPath('../results/results_csv/simulation_results_blocks_100_dnn_dnn_py_perm_100--1000.csv'),
 PosixPath('../results/results_csv/time_bars_blocks_100_Mi_dnn_dnn_py_300:100.csv'),
 PosixPath('../results/results_csv/type1error_blocks_100_CPI_LOCO_DNN.csv'),
 PosixPath('../results/results_csv/simulation_results_blocks_100_CPI_n_1000_p_50_cpi_depth_4.csv'),
 PosixPath('../results/results_csv/time_bars_blocks_100_n_10000_p_50_cpi_permfit.csv'),
 PosixPath('../results/results_csv/type1error_blocks_100_dnn_dnn_py_perm_100--1000.csv'),
 PosixPath('../results/results_csv/type1error_blocks_100_UKBB_single.csv'),
 PosixPath('../results/results_csv/AUC_blocks_100_n_10000_p_50_cpi_permfit.csv'),
 PosixPath('../results/results_csv

## Supplementary Figure 1

In [3]:
def ymin(x):    return np.quantile(x, 0.25)
def ymax(x):    return np.quantile(x, 0.75)

In [4]:
df_power_Mi = pd.read_csv(res_path / 'power_blocks_100_Mi_dnn_dnn_py_300:100.csv').iloc[:, 1:]
df_comp_time_Mi = pd.read_csv(res_path / 'time_bars_blocks_100_Mi_dnn_dnn_py_300:100.csv').iloc[:, 1:]
df_comp_time_Mi

Unnamed: 0,n_samples,Method,V1
0,300,Permfit-DNN,462.146225
1,300,CPI-DNN,553.65184


In [5]:
df_power_Mi_agg = df_power_Mi.groupby(['Method', 'correlation'])['V1'].agg([np.mean, ymin, ymax]).reset_index()
df_power_Mi_agg.head(20)

  df_power_Mi_agg = df_power_Mi.groupby(['Method', 'correlation'])['V1'].agg([np.mean, ymin, ymax]).reset_index()


Unnamed: 0,Method,correlation,mean,ymin,ymax
0,CPI-DNN,0.0,0.466,0.4,0.6
1,CPI-DNN,0.2,0.468,0.4,0.6
2,CPI-DNN,0.5,0.532,0.4,0.6
3,CPI-DNN,0.8,0.55,0.4,0.6
4,Permfit-DNN,0.0,0.468,0.4,0.6
5,Permfit-DNN,0.2,0.48,0.4,0.6
6,Permfit-DNN,0.5,0.546,0.4,0.6
7,Permfit-DNN,0.8,0.594,0.6,0.6


In [6]:
## Create left panel ##
marker_size = 200
err_size = 2
height = 65
width= 300
scheme = 'tableau10'

points_b = alt.Chart(
    df_power_Mi_agg
    
).mark_point(
    size=marker_size, opacity=1, fill='white'
).encode(
    y=alt.Y('Method:O', title=None),
    x=alt.X('mean:Q', 
            title='Power',scale=alt.Scale(domain=(0.3, 0.7))),
    color=alt.Color('Method:O', title='Method', scale=alt.Scale(scheme=scheme))
).properties(
    height=height,
    width=width
)

error_b = alt.Chart(
    df_power_Mi_agg
).mark_errorbar(
    size=err_size, opacity=1
).encode(
    y=alt.Y('Method:N',
            title=None,
           ),
    x=alt.X('ymin:Q',  title='Power', scale=alt.Scale(domain=(0.3, 0.7))),
    x2=alt.X2('ymax:Q', title='Power'),
    color=alt.Color('Method:O', title='Method', scale=alt.Scale(scheme=scheme)),
    strokeWidth=alt.value(err_size)
).properties(
    height=height,
    width=width
)


fig_power_Mi = (error_b + points_b).facet(
    row=alt.Row('correlation:O', 
                sort='descending',
                title='Correlation strength')
).properties(title='A')

## create right panel ##

bar_time_Mi = alt.Chart(
    df_comp_time_Mi
).mark_bar().encode(
    y = alt.X('Method:N', title=None),
    x = alt.Y('V1:Q', title="Time (seconds)", scale=alt.Scale(type="log")),
    color = 'Method:N'

).properties(
    height=320,
    width=300,
    title=alt.TitleParams('B', anchor='start')
)

my_font = 'Helvetica'
fig_supp1 = alt.hconcat(
    fig_power_Mi,
    bar_time_Mi
).configure_axis(
    grid=False, 
    titleFont=my_font,
    labelFont=my_font,
    labelFontSize=16,
    titleFontSize=20
).configure_header(
    titleFont=my_font,
    labelFont=my_font,
    titleFontSize=20,
    labelFontSize=16
).configure_view(
    strokeWidth=0
).configure_title(
    font=my_font,
    fontSize=20
).configure_legend(
    titleFontSize=20,
    labelFontSize=20,
    labelLimit=0,
    orient='top'
)

# fig_supp1.save('figure_supp1.svg')
# fig_supp1.save('figure_supp1.png', scale_factor=3)
fig_supp1

  col = df[col_name].apply(to_list_if_array, convert_dtype=False)
  col = df[col_name].apply(to_list_if_array, convert_dtype=False)


## Figure 3 supp - Prediction scores 

In [7]:
df_pred = pd.read_csv(res_path / 'simulation_results_blocks_100_allMethods_pred_final.csv')
df_pred = df_pred.rename(columns={'method': 'Method', 'prob_data': 'Problem Data'})
df_pred = df_pred.groupby(['Method', 'Problem Data', 'iteration'])['score'].agg([np.mean]).dropna().reset_index()
df_pred = df_pred[df_pred['Method'].isin(['Marg', 'Knockoff_lasso', 'MDI', 'BART', 'CPI-DNN'])]

df_pred.loc[df_pred['Problem Data'] == 'classification', 'Problem Data'] = 'Classification'
df_pred.loc[df_pred['Problem Data'] == 'regression', 'Problem Data'] = 'Plain linear'
df_pred.loc[df_pred['Problem Data'] == 'regression_relu', 'Problem Data'] = 'Regression with ReLu'
df_pred.loc[df_pred['Problem Data'] == 'regression_product', 'Problem Data'] = 'Interactions only'
df_pred.loc[df_pred['Problem Data'] == 'regression_combine', 'Problem Data'] = 'Main effects and Interactions'


df_pred.loc[df_pred['Method'] == 'Marg', 'Method'] = 'Marginal'
df_pred.loc[df_pred['Method'] == 'Knockoff_lasso', 'Method'] = 'Lasso'
df_pred.loc[df_pred['Method'] == 'MDI', 'Method'] = 'Random Forest'
df_pred.loc[df_pred['Method'] == 'CPI-DNN', 'Method'] = 'DNN'

df_pred_class = df_pred[df_pred['Problem Data'] == 'Classification']
df_pred_regr = df_pred[~(df_pred['Problem Data'] == 'Classification')]
df_pred_class

  df_pred = pd.read_csv(res_path / 'simulation_results_blocks_100_allMethods_pred_final.csv')
  df_pred = df_pred.groupby(['Method', 'Problem Data', 'iteration'])['score'].agg([np.mean]).dropna().reset_index()


Unnamed: 0,Method,Problem Data,iteration,mean
0,BART,Classification,1,0.972375
1,BART,Classification,2,0.970125
2,BART,Classification,3,0.974638
3,BART,Classification,4,0.926255
4,BART,Classification,5,0.970627
...,...,...,...,...
3595,Marginal,Classification,96,0.595022
3596,Marginal,Classification,97,0.594228
3597,Marginal,Classification,98,0.591206
3598,Marginal,Classification,99,0.614948


In [8]:
def plot_figsupp3(df_data, title_x=None, title_facet=None):
    box_pred = alt.Chart(
        df_data
    ).mark_boxplot(
        size=30, outliers=False, ticks=True, opacity=0.8                                     
    ).encode(
        y=alt.Y('Method:N', title=None),
        x=alt.X('mean:Q', title=title_x, scale=alt.Scale(domain=(0, 1.0))),
        color=alt.Color('Problem Data:N', title='Data Problem')
    ).properties(
        height=200,
        width=250
    ).facet(column=alt.Column('Problem Data:N', title=title_facet))
    return box_pred

prob_data = [('Classification', None, None),
             ('Plain linear', None, None),
             ('Regression with ReLu', None, None),
             ('Interactions only', "Prediction score (R2, AUC)", None),
             ('Main effects and Interactions', "Prediction score (R2, AUC)", None)]
list_figs = []
for el in prob_data:
    df_pred_tmp = df_pred[df_pred['Problem Data'] == el[0]]
    list_figs.append(plot_figsupp3(df_pred_tmp, el[1], el[2]))

fig3_supp = alt.vconcat(
    alt.hconcat(list_figs[0],
                list_figs[1]).properties(title=alt.TitleParams('Scenario', anchor='middle')),
    alt.hconcat(list_figs[2],
                list_figs[3]),
    list_figs[4]
)
my_font = 'Helvetica'
fig3_supp_pred = fig3_supp.configure_axis(
    grid=True, 
    titleFont=my_font,
    titleFontWeight='normal',
    labelFont=my_font,
    labelFontSize=24,
    titleFontSize=28,
    labelLimit=0
).configure_header(
    titleFont=my_font,
    titleFontWeight='normal',
    labelFont=my_font,
    titleFontSize=28,
    labelFontSize=24
).configure_view(
    strokeWidth=0
).configure_title(
    font=my_font,
    fontSize=28
).configure_legend(
    titleFontSize=28,
    labelFontSize=24,
    labelLimit=0,
    titleLimit=0,
    orient='none',
    legendX=325,
    legendY=600
)
# fig3_supp_pred.save('figure3_supp_pred.svg')
# fig3_supp_pred.save('figure3_supp_pred.png', scale_factor=3)
fig3_supp_pred

  col = df[col_name].apply(to_list_if_array, convert_dtype=False)
  col = df[col_name].apply(to_list_if_array, convert_dtype=False)
  col = df[col_name].apply(to_list_if_array, convert_dtype=False)
  col = df[col_name].apply(to_list_if_array, convert_dtype=False)
  col = df[col_name].apply(to_list_if_array, convert_dtype=False)


## Figure 3 supp - Power

In [9]:
def ymin(x):    return np.quantile(x, 0.25)
def ymax(x):    return np.quantile(x, 0.75)
scenarios = ['Classification', 'Plain linear', 'Regression with ReLu', 'Interactions only', 'Main effects and Interactions']
methods = ['Marginal', 'd0CRT', 'Conditional-RF', 'Lazy VI', 'LOCO', 'cpi-knockoff', 'CPI-RF', 'Permfit-DNN', 'CPI-DNN']

In [10]:
df_power_all = pd.read_csv(res_path / 'power_blocks_100_allMethods_pred_imp_final.csv')

In [11]:
df_power_all_agg = df_power_all.groupby(['Method', 'Problem Data', 'prob_type'])['V1'].agg([np.mean, ymin, ymax]).reset_index()
df_power_all_agg = df_power_all_agg.set_index('Problem Data').loc[scenarios].reset_index()
df_power_all_agg = df_power_all_agg.set_index('Method').loc[methods].reset_index()

  df_power_all_agg = df_power_all.groupby(['Method', 'Problem Data', 'prob_type'])['V1'].agg([np.mean, ymin, ymax]).reset_index()


In [12]:
marker_size = 200
err_size = 2
height = 175
width= 125
scheme = 'tableau10'
points_a = alt.Chart(
    df_power_all_agg
).mark_point(
    size=marker_size, opacity=1, fill='white'
).encode(
    x=alt.X('Problem Data:O', title=None, scale=alt.Scale(domain=scenarios[::-1]),
            axis=alt.Axis(labels=False, title=None, tickSize=0)),
    y=alt.Y('mean:Q', 
            title='Power',scale=alt.Scale(domain=(0, 1.0))),
    color=alt.Color('Problem Data:O', title='Scenario',
                    scale=alt.Scale(scheme=scheme, domain=scenarios[::-1]))
).properties(
    height=height,
    width=width
)

error_a = alt.Chart(
    df_power_all_agg
).mark_errorbar(
    opacity=1
).encode(
    x=alt.X('Problem Data:O', title=None, scale=alt.Scale(domain=scenarios[::-1])),
    y=alt.Y('ymin:Q',  title='Power',scale=alt.Scale(domain=(0, 1.0))),
    y2=alt.Y2('ymax:Q', title='Power'),
    color=alt.Color('Problem Data:O', title='Scenario',
                    scale=alt.Scale(scheme=scheme, domain=scenarios[::-1])),
    strokeWidth=alt.value(err_size)
).properties(
    height=height,
    width=width
)

fig_power_all = (error_a + points_a).facet(
    column=alt.Column('Method:O', sort=methods[::-1])
)

my_font = 'Helvetica'
fig3_supp_power = fig_power_all.configure_axis(
    grid=True, 
    titleFont=my_font,
    titleFontWeight='normal',
    labelFont=my_font,
    labelFontSize=24,
    titleFontSize=28,
    labelLimit=0,
    titlePadding=20
).configure_header(
    titleFont=my_font,
    titleFontWeight='normal',
    labelFont=my_font,
    titleFontSize=28,
    labelFontSize=24,
    labelPadding=35
).configure_view(
    strokeWidth=0
).configure_title(
    font=my_font,
    fontSize=28
).configure_legend(
    titleFontSize=28,
    labelFontSize=23,
    labelLimit=0,
#     columns=3,
    orient='top',
    legendX=-300,
    legendY=-90,
    columnPadding=45,
    symbolSize=150
)
# fig3_supp_power.save('figure3_supp_power.svg')
# fig3_supp_power.save('figure3_supp_power.png', scale_factor=3)
fig3_supp_power

  col = df[col_name].apply(to_list_if_array, convert_dtype=False)


## Figure 3 supp - Computation

In [13]:
methods = ['BART', 'MDI', 'SAGE', 'SHAP', 'Knockoff-Deep', 'Knockoff-Bart', 'Knockoff-Lasso', 'Marginal', 'd0CRT', 'Conditional-RF', 'Lazy VI', 'LOCO', 'cpi-knockoff', 'CPI-RF', 'Permfit-DNN', 'CPI-DNN']

In [14]:
df_comp_time_all = pd.read_csv(res_path / 'time_bars_blocks_100_allMethods_pred_imp_final.csv').iloc[:, 1:]
df_comp_time_all = df_comp_time_all.set_index('Method').loc[methods].reset_index()
df_comp_time_all

Unnamed: 0,Method,n_samples,V1
0,BART,1000,264.25862
1,MDI,1000,3823.9078
2,SAGE,1000,36927.84344
3,SHAP,1000,52.26706
4,Knockoff-Deep,1000,1297.68407
5,Knockoff-Bart,1000,295.16324
6,Knockoff-Lasso,1000,48.54227
7,Marginal,1000,2.8457
8,d0CRT,1000,607.61035
9,Conditional-RF,1000,4157.80484


In [15]:
bar_time_all = alt.Chart(
    df_comp_time_all
).mark_bar().encode(
    y = alt.X('Method:N', title='Methods', sort=methods[::-1]),
    x = alt.Y('V1:Q', title="Time (seconds)", scale=alt.Scale(type="log"))
).properties(
    height=400,
    width=400,
)

bar_time_all = bar_time_all.configure_axis(
    grid=True, 
    titleFont=my_font,
    titleFontWeight='normal',
    labelFont=my_font,
    labelFontSize=20,
    titleFontSize=24,
    labelLimit=0,
    titlePadding=30
).configure_header(
    titleFont=my_font,
    titleFontWeight='normal',
    labelFont=my_font,
    titleFontSize=24,
    labelFontSize=20,
    labelPadding=35,
).configure_view(
    strokeWidth=0
).configure_title(
    font=my_font,
    fontSize=24
).configure_legend(
    titleFontSize=20,
    labelFontSize=20,
    labelLimit=0
)
# bar_time_all.save('figure3_supp_time.svg')
# bar_time_all.save('figure3_supp_time.png', scale_factor=3)
bar_time_all

  col = df[col_name].apply(to_list_if_array, convert_dtype=False)


## Figure 3 supp - No Pvals

In [16]:
df_auc_noPval_all = pd.read_csv(res_path / 'AUC_blocks_100_allMethods_pred_imp_final_withoutPval.csv')

In [17]:
methods_noPval = ['Knockoff-Lasso', 'Knockoff-Bart', 'Knockoff-Deep', 'SHAP', 'SAGE', 'MDI', 'BART']

In [18]:
df_auc_noPval_all_agg = df_auc_noPval_all.groupby(['Method', 'Problem Data', 'prob_type'])['V1'].agg([np.mean, ymin, ymax]).reset_index()
df_auc_noPval_all_agg = df_auc_noPval_all_agg.set_index('Problem Data').loc[scenarios].reset_index()
df_auc_noPval_all_agg = df_auc_noPval_all_agg.set_index('Method').loc[methods_noPval].reset_index()
df_auc_noPval_all_agg

  df_auc_noPval_all_agg = df_auc_noPval_all.groupby(['Method', 'Problem Data', 'prob_type'])['V1'].agg([np.mean, ymin, ymax]).reset_index()


Unnamed: 0,Method,Problem Data,prob_type,mean,ymin,ymax
0,Knockoff-Lasso,Classification,classification,0.899833,0.868125,0.94
1,Knockoff-Lasso,Plain linear,regression,0.984225,0.976667,1.0
2,Knockoff-Lasso,Regression with ReLu,regression,0.874775,0.837292,0.920417
3,Knockoff-Lasso,Interactions only,regression,0.503358,0.484583,0.516667
4,Knockoff-Lasso,Main effects and Interactions,regression,0.6002,0.558542,0.634583
5,Knockoff-Bart,Classification,classification,0.756117,0.714583,0.80375
6,Knockoff-Bart,Plain linear,regression,0.822117,0.78625,0.868333
7,Knockoff-Bart,Regression with ReLu,regression,0.779183,0.745,0.822083
8,Knockoff-Bart,Interactions only,regression,0.8031,0.766667,0.8325
9,Knockoff-Bart,Main effects and Interactions,regression,0.839617,0.809583,0.873333


In [19]:
marker_size = 200
err_size = 2
height = 175
width= 175
scheme = 'tableau10'
points_a = alt.Chart(
    df_auc_noPval_all_agg
).mark_point(
    size=marker_size, opacity=1, fill='white'
).encode(
    x=alt.X('Problem Data:O', title=None, scale=alt.Scale(domain=scenarios[::-1]),
            axis=alt.Axis(labels=False, title=None, tickSize=0)),
    y=alt.Y('mean:Q', 
            title='AUC score',scale=alt.Scale(domain=(0.5, 1.0))),
    color=alt.Color('Problem Data:O', title='Scenario',
                    scale=alt.Scale(scheme=scheme, domain=scenarios[::-1]))
).properties(
    height=height,
    width=width
)

error_a = alt.Chart(
    df_auc_noPval_all_agg
).mark_errorbar(
    opacity=1
).encode(
    x=alt.X('Problem Data:O', title=None, scale=alt.Scale(domain=scenarios[::-1])),
    y=alt.Y('ymin:Q',  title='AUC score',scale=alt.Scale(domain=(0.5, 1.0))),
    y2=alt.Y2('ymax:Q', title='AUC score'),
    color=alt.Color('Problem Data:O', title='Scenario',
                    scale=alt.Scale(scheme=scheme, domain=scenarios[::-1])),
    strokeWidth=alt.value(err_size)
).properties(
    height=height,
    width=width
)

rule2 = alt.Chart(pd.DataFrame({
  'V1': [0.5],
  'Problem Data': ['Type-I error nominal rate'],
  'color': ['black']
})).mark_rule(size=1.5, color='black').encode(
  y='V1:Q'
)

fig_auc_noPval_all = (error_a + points_a + rule2).facet(
    column=alt.Column('Method:O', sort=methods[::-1])
)

my_font = 'Helvetica'
fig3_supp_noPval = fig_auc_noPval_all.configure_axis(
    grid=True, 
    titleFont=my_font,
    titleFontWeight='normal',
    labelFont=my_font,
    labelFontSize=24,
    titleFontSize=28,
    labelLimit=0,
    titlePadding=20
).configure_header(
    titleFont=my_font,
    titleFontWeight='normal',
    labelFont=my_font,
    titleFontSize=28,
    labelFontSize=24,
    labelPadding=35
).configure_view(
    strokeWidth=0
).configure_title(
    font=my_font,
    fontSize=28
).configure_legend(
    titleFontSize=28,
    labelFontSize=23,
    labelLimit=0,
#     columns=3,
    orient='top',
#     legendX=-300,
#     legendY=-90,
    columnPadding=45,
    symbolSize=150
)
# fig3_supp_noPval.save('figure3_supp_noPval.svg')
# fig3_supp_noPval.save('figure3_supp_noPval.png', scale_factor=3)
fig3_supp_noPval

  col = df[col_name].apply(to_list_if_array, convert_dtype=False)
  col = df[col_name].apply(to_list_if_array, convert_dtype=False)
