# In [8]:
import pandas as pd
import os
import numpy as np
import logging
import sys
import torch
import copy
import yaml
import random

from prediction_utils.pytorch_utils.metrics import (
    StandardEvaluator,
    FairOVAEvaluator,
    CalibrationEvaluator
)

import matplotlib.pyplot as plt
import seaborn as sns; sns.set()
sns.set_style("ticks")

# Human-readable labels for the ``group`` codes stored in the calibration
# CSVs (applied with ``Series.map`` when the CSVs are loaded).
grp_label_dict = {
    1: 'Black women',
    2: 'White women',
    3: 'Black men',
    4: 'White men',
}

# Input locations; only ``base_path`` is read by the code in this cell.
args = dict(
    cohort_path='/labs/shahlab/projects/agataf/data/cohorts/pooled_cohorts/cohort/all_cohorts.csv',
    base_path='/labs/shahlab/projects/agataf/data/cohorts/pooled_cohorts',
)

# Subdirectory (under .../performance/all/calibration/) whose results are
# loaded -- presumably the smoothing method used for the calibration curves.
calib_method = 'loess'

def get_real_prevalence_tables(df):
    """Summarise ``calibration_density`` as ``"mean (lower, upper)"`` strings.

    Rows of ``df`` are grouped by ``pred_probs``, ``model_id`` and ``group``.
    Within each group the mean and sample standard deviation (ddof=1) of
    ``calibration_density`` are computed and rendered with four decimals as
    ``"mean (mean - 2*std, mean + 2*std)"``.

    Args:
        df: DataFrame with columns ``model_type``, ``model_id``, ``group``,
            ``pred_probs`` and ``calibration_density``. Every row must share
            the same ``model_type``; other columns are ignored.

    Returns:
        DataFrame indexed by (``pred_probs``, ``model_id``, ``group``) with a
        string column ``calibration_density`` and a constant ``model_type``
        column.

    Raises:
        ValueError: if ``df`` does not contain exactly one ``model_type``.
    """
    # Validate up front, with a real exception (``assert`` is stripped
    # under ``python -O``).
    model_types = df['model_type'].unique()
    if len(model_types) != 1:
        raise ValueError(
            f"expected exactly one model_type, got {list(model_types)}"
        )

    # Aggregate only the numeric column. Averaging the whole frame would
    # include the string ``model_type`` column, which raises TypeError on
    # pandas >= 2.0. One groupby yields both statistics.
    stats = (df
             .groupby(['pred_probs', 'model_id', 'group'])['calibration_density']
             .agg(['mean', 'std']))

    def _fmt(values):
        # Series.map works on every pandas version (DataFrame.applymap is
        # deprecated since pandas 2.1).
        return values.map(lambda x: format(x, '.4f'))

    lower = stats['mean'] - 2 * stats['std']
    higher = stats['mean'] + 2 * stats['std']
    summary = (_fmt(stats['mean']) + ' (' + _fmt(lower) + ', '
               + _fmt(higher) + ')')

    # The column intentionally keeps the name 'calibration_density':
    # downstream code pivots on ``values="calibration_density"``.
    table = summary.to_frame('calibration_density')
    return table.assign(model_type=model_types[0])

# Build one formatted prevalence table per experiment. The tables are
# collected in ``b`` and concatenated for display below.
b = []
# model_id value retained for the 'apr14_mmd' / 'apr14_thr' experiments.
eqodds_threshold = 0.01
for experiment in ('original_pce', 'apr14_erm', 'apr14_erm_recalib', 'apr14_mmd'):
    aggregate_path = os.path.join(
        args['base_path'], 'experiments', experiment,
        'performance', 'all', 'calibration', calib_method,
    )
    thr_calibs = pd.read_csv(
        os.path.join(aggregate_path, 'calibration_sensitivity_thresholds_raw.csv')
    )

    # Keep only rows whose model_id equals eqodds_threshold
    # (``@`` makes ``query`` read the Python variable).
    if experiment in ('apr14_mmd', 'apr14_thr'):
        thr_calibs = thr_calibs.query('model_id == @eqodds_threshold')

    # Replace the group codes with the labels from grp_label_dict.
    a = thr_calibs.assign(group=thr_calibs['group'].map(grp_label_dict))

    # get_real_prevalence_tables groups on model_id, so supply a constant
    # one when the CSV has no such column.
    if 'model_id' not in a.columns:
        a = a.assign(model_id=0)

    table = get_real_prevalence_tables(a)
    b.append(table)

# Pivot to one row per (decision threshold, group) and one column per model.
# Each cell holds the "mean (lower, upper)" string for that model.
(pd.concat(b)
   .assign(model_type=lambda frame: frame['model_type'].map({
       'original_pce': 'PCE',
       'erm': 'ERM',
       'recalib_erm': 'rERM',
       'eqodds_mmd': 'EqOdd',
   }))
   .reset_index()
   .pivot(index=['pred_probs', 'group'], columns='model_type',
          values='calibration_density')
   .reindex(columns=['PCE', 'ERM', 'rERM', 'EqOdd'])
   # Relabel the threshold level in place instead of reset/rename/set_index.
   .rename_axis(index={'pred_probs': 'decision threshold'})
)

,model_type,PCE,ERM,rERM,EqOdd
decision threshold,group,,,,
0.075,Black men,"0.0309 (0.0309, 0.0309)","0.1086 (0.1013, 0.1159)","0.1294 (0.1216, 0.1373)","0.1116 (0.1041, 0.1191)"
0.075,Black women,"0.0343 (0.0343, 0.0343)","0.0434 (0.0366, 0.0503)","0.0567 (0.0491, 0.0644)","0.0429 (0.0355, 0.0503)"
0.075,White men,"0.0300 (0.0300, 0.0300)","0.0608 (0.0572, 0.0643)","0.0710 (0.0683, 0.0736)","0.0619 (0.0591, 0.0647)"
0.075,White women,"0.0659 (0.0659, 0.0659)","0.0650 (0.0598, 0.0703)","0.0796 (0.0739, 0.0854)","0.0648 (0.0594, 0.0703)"
0.2,Black men,"0.1480 (0.1480, 0.1480)","0.1726 (0.1621, 0.1831)","0.2215 (0.2146, 0.2283)","0.1727 (0.1624, 0.1830)"
0.2,Black women,"0.1317 (0.1317, 0.1317)","0.1826 (0.1750, 0.1903)","0.2154 (0.2063, 0.2245)","0.1805 (0.1732, 0.1877)"
0.2,White men,"0.1340 (0.1340, 0.1340)","0.1826 (0.1729, 0.1923)","0.1554 (0.1493, 0.1616)","0.1833 (0.1740, 0.1926)"
0.2,White women,"0.1768 (0.1768, 0.1768)","0.2796 (0.2750, 0.2843)","0.2768 (0.2698, 0.2839)","0.2758 (0.2693, 0.2823)"
