In [None]:
import pandas as pd
import os
import numpy as np
import logging
import sys
import torch
import copy
import yaml
import random

from prediction_utils.pytorch_utils.metrics import (
    StandardEvaluator,
    FairOVAEvaluator,
    CalibrationEvaluator
)

import matplotlib.pyplot as plt
import seaborn as sns; sns.set()
sns.set_style("ticks")

# Root of the project's data and output tree. Defaults to the original lab path;
# set AGATAF_PROJECT_ROOT in the environment to run the notebook from another location.
project_root = os.environ.get('AGATAF_PROJECT_ROOT', '/labs/shahlab/projects/agataf')
pooled_cohorts_dir = os.path.join(project_root, 'data', 'cohorts', 'pooled_cohorts')

args = {'cohort_path': os.path.join(pooled_cohorts_dir, 'cohort', 'all_cohorts.csv'),
        'base_path': pooled_cohorts_dir,
        # Raw group codes (plus the synthetic 'overall' stratum) -> display names.
        'grp_label_dict': {1: 'Black women', 2: 'White women', 3: 'Black men', 4: 'White men', 'overall': 'Overall'},
        'plot_path': os.path.join(project_root, 'fairness_utility', 'eval_manuscript', 'plots', 'paper_plots'),
       }

# Destination file for the risk-category counts figure; make sure its directory exists.
output_path = os.path.join(args['plot_path'], 'risk_category_counts.png')
os.makedirs(args['plot_path'], exist_ok=True)


In [None]:
# def get_frac_treated(df, add_strata=None):
#     base_grouping = ['model_type', 'model_id','labels'] 
#     if add_strata:
#         base_grouping.append(add_strata)
#     num_treated_by_stratum = (df
#          .query('(phase=="test")')
#          .filter(['person_id', 'treat'] + base_grouping)
#          .groupby(['treat'] + base_grouping)
#          .count()
#         )

#     num_per_stratum = (df
#           .query('(phase=="test")')
#                .filter(['person_id'] + base_grouping)
#                .groupby(base_grouping)
#                .count()
#          )

#     frac_treated = num_treated_by_stratum.div(num_per_stratum)


#     return frac_treated.reset_index()

# frac_treated_all = []
# eqodds_threshold = 0.1
# for experiment in ['original_pce', 'revised_pce', 'apr14_erm', 'apr14_erm_recalib', 'scratch_thr']:

#     aggregate_path = os.path.join(args['base_path'], 'experiments', 
#                                       experiment, 'performance',
#                                       'all')
#     preds_path = os.path.join(aggregate_path, 'predictions.csv')
#     preds = pd.read_csv(preds_path)
    
#     if experiment in ['apr14_mmd', 'apr14_thr']:
#         preds = preds.query('model_id >= @eqodds_threshold')
        
#     if 'model_id' not in preds.columns:
#         preds = preds.assign(model_id=0)
        
#     frac_treated = get_frac_treated(preds, add_strata='group') 
#     frac_treated_overall = get_frac_treated(preds, add_strata=None).assign(group='overall')
#     frac_treated = frac_treated.append(frac_treated_overall)
#     frac_treated = (frac_treated
#                     .assign(group = lambda x: x.group.map(args['grp_label_dict']),
#                             labels = lambda x: x.labels.map({0: ' Individuals with \nno ASCVD event',
#                                                              1: 'Individuals with \n  ASCVD event'}))
#                     .rename(columns={'person_id': 'fraction'})
#                    )
#     frac_treated_all.append(frac_treated)

# frac_treated_all = (pd
#                     .concat(frac_treated_all)
#                     .assign(model_type = lambda x: x.model_type.map({'original_pce': 'PCE',
#                                                                      'revised_pce': 'rPCE',
#                                                                      'erm': 'BL',
#                                                                      'recalib_erm': 'rBL',
#                                                                      'eqodds_thr': 'EO',
#                                                                     }
#                                                                    )
#                            )
#                    )


# model_type = np.where((frac_treated_all.model_type=='EO') & (frac_treated_all.model_id==0.1), 'EO1', frac_treated_all.model_type)
# model_type = np.where((frac_treated_all.model_type=='EO') & (frac_treated_all.model_id==0.21544346900318825), 'EO2', model_type)
# model_type = np.where((frac_treated_all.model_type=='EO') & (frac_treated_all.model_id==0.4641588833612778), 'EO', model_type)
# model_type = np.where((frac_treated_all.model_type=='EO') & (frac_treated_all.model_id==1.0), 'EO4', model_type)

# frac_treated_all = (frac_treated_all
#            .assign(model_type=model_type)
#            .query("model_type != ['EO1', 'EO2', 'EO4']")
#            #.assign(model_type={'EO3': 'EO'})
#            .assign(model_type = lambda x: pd.Categorical(x.model_type, 
#                                                          categories=['PCE', 'rPCE', 'BL', 'rBL', 'EO'],
#                                                          ordered=True),
#                   group = lambda x: pd.Categorical(x.group, 
#                                                    categories=args['grp_label_dict'].values(),
#                                                    ordered=True)
#                   )
#           )

In [None]:
# (frac_treated_all
#  .assign(labels = lambda x: x.labels.map({' Individuals with \nno ASCVD event':0,
#                                           'Individuals with \n  ASCVD event': 1}))
#  .drop(columns=['model_id'])
#  .pivot(index=['treat', 'labels','group'], columns=['model_type'], values='fraction')
#  .fillna(0)
#  .round(3)
#  .head(50)
# )

In [None]:
# Stacked bars: share of individuals in each risk category (`treat`) per group,
# one facet column per model and one facet row per outcome label. Large fonts for the paper figure.
# NOTE(review): `frac_treated_all` is built by the experiment-loading and relabelling
# cells further down in the notebook; run those first.
sns.set(font_scale=4)
sns.set_style("ticks")

# Readable row titles: `replace` maps 0/1 outcome codes and leaves labels that are
# already strings unchanged.
risk_plot_data = frac_treated_all.assign(
    labels=lambda x: x.labels.replace({0: ' Individuals with \nno ASCVD event',
                                       1: 'Individuals with \n  ASCVD event'}))

g = sns.displot(
    risk_plot_data,
    x='group',
    col='model_type',
    row='labels',
    hue='treat',
    weights='fraction',
    shrink=0.8,
    multiple='stack',
    height=10,
    # One bar per group. The previous bins=len(model_type.unique()) counted the wrong
    # column; for a categorical x seaborn bins discretely, so say so explicitly.
    discrete=True,
    # NOTE(review): string hue levels assume `treat` is stored as '0'/'1'/'2' — confirm.
    hue_order=['0', '1', '2'],
    legend=False,
    facet_kws={'margin_titles': True},
)

# Manual legend (seaborn's own legend is disabled above), placed via bbox_to_anchor
# relative to the current axes.
plt.legend(title='Risk category',
           loc='lower center',
           labels=['High risk', 'Intermediate risk', 'Low risk'],
           bbox_to_anchor=(-2, -1.2))

(g.set_xlabels('')
 .set_ylabels('Fraction of individuals')
 .set_titles(row_template="{row_name}", col_template="{col_name}")
 .set_xticklabels(rotation=90))

# Restore the notebook-wide theme: sns.set() alone also resets the style to "darkgrid".
sns.set(font_scale=1)
sns.set_style("ticks")

# if output_path is not None:
#     g.savefig(output_path)
    

In [None]:
# Stacked bars: share of individuals in each risk category (`treat`) per group,
# one facet column per model (no split by outcome label). Large fonts for the paper figure.
# NOTE(review): get_frac_treated computes fractions within each `labels` stratum, so
# without row='labels' each bar sums the fractions of both label strata — check
# whether bar heights should be 1 before using this figure.
sns.set(font_scale=4)
sns.set_style("ticks")

g = sns.displot(
    frac_treated_all,
    x='group',
    col='model_type',
    hue='treat',
    weights='fraction',
    shrink=0.8,
    multiple='stack',
    height=10,
    # One bar per group. The previous bins=len(model_type.unique()) counted the wrong
    # column; for a categorical x seaborn bins discretely, so say so explicitly.
    discrete=True,
    # NOTE(review): string hue levels assume `treat` is stored as '0'/'1'/'2' — confirm.
    hue_order=['0', '1', '2'],
    legend=False,
    facet_kws={'margin_titles': True},
)

# Manual legend (seaborn's own legend is disabled above), placed via bbox_to_anchor
# relative to the current axes.
plt.legend(title='Risk category',
           loc='lower center',
           labels=['High risk', 'Intermediate risk', 'Low risk'],
           bbox_to_anchor=(-2.6, -1.2))

(g.set_xlabels('')
 .set_ylabels('Fraction of individuals')
 .set_titles(row_template="{row_name}", col_template="{col_name}")
 .set_xticklabels(rotation=90))

# Restore the notebook-wide theme: sns.set() alone also resets the style to "darkgrid".
sns.set(font_scale=1)
sns.set_style("ticks")

    

In [None]:
# Table: fraction of individuals in each risk category (`treat`) per model, outcome
# label (0 = no ASCVD event, 1 = ASCVD event) and group.
(frac_treated_all
 # Normalise outcome labels to 0/1; `replace` leaves labels that are already numeric
 # unchanged (a plain `map` would turn them into NaN).
 .assign(labels=lambda x: x.labels.replace({' Individuals with \nno ASCVD event': 0,
                                            'Individuals with \n  ASCVD event': 1}))
 .drop(columns=['model_id'])
 # Fractions are computed per label stratum, so `labels` must be in the index;
 # (model_type, treat) alone is not unique per group and pivot would raise.
 .pivot(index=['model_type', 'treat', 'labels'], columns=['group'], values='fraction')
 .fillna(0)
 .round(3)
 .head(50)
)

In [None]:
    # Scratch: mean observed outcome (`labels`) per risk category within each model
    # and group on the test split — an inline copy of get_outcomes_treated.
    # NOTE(review): `df` is not defined in this cell; it relies on a frame left in the
    # kernel by another cell (e.g. `df=preds`).
    add_strata = 'group'
    base_grouping = ['model_type', 'model_id']
    if add_strata:
        base_grouping.append(add_strata)
    num_treated_by_stratum = (df
         .query('(phase=="test")')
         .filter(['person_id', 'treat', 'labels'] + base_grouping)
         .groupby(['treat'] + base_grouping)
         # 'mean' instead of np.mean: passing the numpy callable warns in pandas >= 2.1
         .agg({'labels': 'mean'})
        )
    num_treated_by_stratum
#     #num_treated_by_stratum
#     num_per_stratum = (df
#           .query('(phase=="test")')
#                .filter(['person_id'] + base_grouping)
#                .groupby(base_grouping)
#                .sum()
#          )

#     frac_treated = num_treated_by_stratum.div(num_per_stratum)
#     frac_treated.reset_index()

In [None]:
def get_frac_treated(df, add_strata=None):
    """Share of test-phase individuals in each `treat` level within every stratum.

    Strata are (model_type, model_id, labels), optionally extended with the column
    named by `add_strata`. The returned flat frame has the stratum keys, `treat`,
    and a `person_id` column holding
    count(person_id | treat, stratum) / count(person_id | stratum).
    """
    strata = ['model_type', 'model_id', 'labels']
    if add_strata:
        strata = strata + [add_strata]

    test_rows = df.query('(phase=="test")')

    per_treat = test_rows.filter(['person_id', 'treat'] + strata).groupby(['treat'] + strata).count()
    per_stratum = test_rows.filter(['person_id'] + strata).groupby(strata).count()

    # Division aligns on the shared stratum levels; `treat` is carried along.
    return per_treat.div(per_stratum).reset_index()

def get_outcomes_treated(df, add_strata=None):
    """Observed outcome rate (mean of `labels`) per `treat` level on the test split.

    Parameters
    ----------
    df : pd.DataFrame
        Prediction rows with at least `phase`, `treat`, `labels`, `model_type`
        and `model_id` (plus `add_strata` if given).
    add_strata : str, optional
        Extra column to stratify by (e.g. 'group').

    Returns
    -------
    pd.DataFrame
        One row per (treat, model_type, model_id[, add_strata]) with the mean of
        `labels` in column `outcome_rate`.
    """
    base_grouping = ['model_type', 'model_id']
    if add_strata:
        base_grouping.append(add_strata)

    outcome_rate = (df
         .query('(phase=="test")')
         .filter(['person_id', 'treat', 'labels'] + base_grouping)
         .groupby(['treat'] + base_grouping)
         # 'mean' instead of np.mean: passing the numpy callable warns in pandas >= 2.1
         .agg({'labels': 'mean'})
        )

    return outcome_rate.reset_index().rename(columns={'labels': 'outcome_rate'})

def get_both(df, fun, grp_label_dict=None):
    """Apply a stratified summary per group and over the whole cohort, then stack.

    Parameters
    ----------
    df : pd.DataFrame
        Rows passed straight to `fun`.
    fun : callable
        Summary with signature ``fun(df, add_strata=None)``, e.g.
        ``get_frac_treated`` or ``get_outcomes_treated``.
    grp_label_dict : dict, optional
        Mapping from raw group codes (and ``'overall'``) to display names.
        Defaults to ``args['grp_label_dict']`` from the notebook config.

    Returns
    -------
    pd.DataFrame
        Per-group rows followed by the cohort-wide rows (``group='overall'``),
        with `group` mapped through `grp_label_dict` and any ``person_id``
        column renamed to ``fraction``.
    """
    if grp_label_dict is None:
        grp_label_dict = args['grp_label_dict']

    by_group = fun(df, add_strata='group')
    overall = fun(df, add_strata=None).assign(group='overall')
    # DataFrame.append was removed in pandas 2.0; concat keeps the same row order and index labels.
    stacked = pd.concat([by_group, overall])

    return (stacked
            .assign(group=lambda x: x.group.map(grp_label_dict))
            .rename(columns={'person_id': 'fraction'}))

# Load each experiment's predictions and summarise them per model and group (plus an
# 'overall' row): risk-category fractions and observed outcome rates.
# NOTE: `frac_oucomes_all` (sic) keeps its original spelling because the concat /
# relabel cell reads it under that name.
frac_oucomes_all = []
frac_treated_all = []
eqodds_threshold = 0.1
# Experiments whose rows with model_id < eqodds_threshold are dropped. Previously only
# the apr14_* names were listed and none of them is loaded below, so the filter never
# ran; 'scratch_thr' is the EO experiment that is actually loaded.
# NOTE(review): confirm scratch_thr's model_id grid — the relabel cell only keeps
# ids 0.1 / 0.215... / 0.464... / 1.0, so lower ids would otherwise stay tagged 'EO'.
eqodds_experiments = ['apr14_mmd', 'apr14_thr', 'scratch_thr']

for experiment in ['original_pce', 'revised_pce', 'apr14_erm', 'apr14_erm_recalib', 'scratch_thr']:
    preds_path = os.path.join(args['base_path'], 'experiments', experiment,
                              'performance', 'all', 'predictions.csv')
    preds = pd.read_csv(preds_path)

    if experiment in eqodds_experiments:
        preds = preds.query('model_id >= @eqodds_threshold')

    # Experiments without a model_id column get a constant id so the shared
    # grouping keys in get_frac_treated / get_outcomes_treated exist.
    if 'model_id' not in preds.columns:
        preds = preds.assign(model_id=0)

    frac_treated_all.append(get_both(preds, get_frac_treated))
    frac_oucomes_all.append(get_both(preds, get_outcomes_treated))

In [None]:
    # Scratch: step-by-step copy of get_both(preds, get_outcomes_treated) for inspection.
    # Relies on `preds` left over from the experiment-loading loop.
    df = preds
    fun = get_outcomes_treated
    frac = fun(df, add_strata='group')
    frac_overall = fun(df, add_strata=None).assign(group='overall')
    # DataFrame.append was removed in pandas 2.0; concat keeps the same row order and index labels.
    frac = pd.concat([frac, frac_overall])
    a = (frac
         .assign(group=lambda x: x.group.map(args['grp_label_dict']))
         .rename(columns={'person_id': 'fraction'}))

In [None]:
# Observed outcome rate per model (x axis) and group (colour), one panel per risk
# category. The palette is also installed globally via set_palette.
# NOTE(review): `frac_outcomes_all` is built by a later cell; run that first.
custom_palette = ["red", "green", "orange", "blue", "black"]
sns.set_palette(custom_palette)

outcome_plot_kwargs = dict(
    x='model_type',
    y='outcome_rate',
    col='treat',
    kind='line',
    palette=custom_palette,
    hue='group',
    facet_kws={'sharey': True, 'margin_titles': True},
    aspect=1,
    # Markers only, no connecting line.
    marker='X',
    markersize=14,
    linestyle='',
)
g = sns.relplot(data=frac_outcomes_all, **outcome_plot_kwargs)

In [None]:
# Peek at the first five rows of the risk-category fractions.
frac_treated_all.iloc[:5]

In [None]:
# Fraction of individuals per model (x axis) and group (colour); facet columns are risk
# categories and facet rows are outcome labels. The palette is also installed globally.
custom_palette = ["red", "green", "orange", "blue", "black"]
sns.set_palette(custom_palette)

treated_plot_kwargs = dict(
    x='model_type',
    y='fraction',
    col='treat',
    row='labels',
    kind='line',
    palette=custom_palette,
    hue='group',
    facet_kws={'sharey': True, 'margin_titles': True},
    aspect=1,
    # Markers only, no connecting line.
    marker='X',
    markersize=14,
    linestyle='',
)
g = sns.relplot(data=frac_treated_all, **treated_plot_kwargs)

In [None]:
# Stack the per-experiment outcome-rate frames and give models their display names.
frac_outcomes_all = (pd
                     .concat(frac_oucomes_all)
                     .assign(model_type=lambda x: x.model_type.map({'original_pce': 'PCE',
                                                                   'revised_pce': 'rPCE',
                                                                   'erm': 'BL',
                                                                   'recalib_erm': 'rBL',
                                                                   'eqodds_thr': 'EO',
                                                                   }))
                     )

# Tag the four known EO settings by model_id, then keep only the 0.4641... setting
# (labelled plain 'EO'). np.isclose instead of ==: model_id round-trips through CSV,
# and an exact float mismatch would silently leave every setting labelled 'EO'.
eo_model_id_labels = {0.1: 'EO1',
                      0.21544346900318825: 'EO2',
                      0.4641588833612778: 'EO',
                      1.0: 'EO4'}
is_eo = (frac_outcomes_all.model_type == 'EO').to_numpy()
model_type = frac_outcomes_all.model_type.to_numpy()
for eo_model_id, eo_label in eo_model_id_labels.items():
    model_type = np.where(is_eo & np.isclose(frac_outcomes_all.model_id, eo_model_id),
                          eo_label, model_type)

frac_outcomes_all = (frac_outcomes_all
                     .assign(model_type=model_type)
                     .query("model_type != ['EO1', 'EO2', 'EO4']")
                     # Ordered categoricals fix the plotting order of models and groups.
                     .assign(model_type=lambda x: pd.Categorical(x.model_type,
                                                                 categories=['PCE', 'rPCE', 'BL', 'rBL', 'EO'],
                                                                 ordered=True),
                             group=lambda x: pd.Categorical(x.group,
                                                            categories=list(args['grp_label_dict'].values()),
                                                            ordered=True))
                     )

In [None]:
# Stack the per-experiment risk-category frames and give models their display names.
# NOTE: this cell rebinds `frac_treated_all` from the list of frames to the stacked
# frame, so it is not re-runnable on its own — re-run the loading cell first.
frac_treated_all = (pd
                    .concat(frac_treated_all)
                    .assign(model_type=lambda x: x.model_type.map({'original_pce': 'PCE',
                                                                  'revised_pce': 'rPCE',
                                                                  'erm': 'BL',
                                                                  'recalib_erm': 'rBL',
                                                                  'eqodds_thr': 'EO',
                                                                  }))
                    )

# Tag the four known EO settings by model_id, then keep only the 0.4641... setting
# (labelled plain 'EO'). np.isclose instead of ==: model_id round-trips through CSV,
# and an exact float mismatch would silently leave every setting labelled 'EO'.
eo_model_id_labels = {0.1: 'EO1',
                      0.21544346900318825: 'EO2',
                      0.4641588833612778: 'EO',
                      1.0: 'EO4'}
is_eo = (frac_treated_all.model_type == 'EO').to_numpy()
model_type = frac_treated_all.model_type.to_numpy()
for eo_model_id, eo_label in eo_model_id_labels.items():
    model_type = np.where(is_eo & np.isclose(frac_treated_all.model_id, eo_model_id),
                          eo_label, model_type)

frac_treated_all = (frac_treated_all
                    .assign(model_type=model_type)
                    .query("model_type != ['EO1', 'EO2', 'EO4']")
                    # Ordered categoricals fix the plotting order of models and groups.
                    .assign(model_type=lambda x: pd.Categorical(x.model_type,
                                                                categories=['PCE', 'rPCE', 'BL', 'rBL', 'EO'],
                                                                ordered=True),
                            group=lambda x: pd.Categorical(x.group,
                                                           categories=list(args['grp_label_dict'].values()),
                                                           ordered=True))
                    )