In [1]:
import pandas as pd
import os
import numpy as np
import logging
import sys
import torch
import copy
import yaml
import random

from prediction_utils.pytorch_utils.metrics import (
    StandardEvaluator,
    FairOVAEvaluator,
    CalibrationEvaluator
)

import matplotlib.pyplot as plt
import seaborn as sns; sns.set()
sns.set_style("ticks")

# Human-readable names for the integer group codes found in the `group` column.
grp_label_dict = {
    1: 'Black women',
    2: 'White women',
    3: 'Black men',
    4: 'White men',
}

# Experiment whose aggregated predictions are analysed in this notebook.
# Alternatives used previously (swap in as needed):
#   'apr14_mmd', 'apr14_erm_recalib', 'apr14_erm', 'original_pce'
EXPERIMENT_NAME = 'apr14_thr'

include_recalibrated = False

# Run configuration shared by the cells below.
args = dict(
    experiment_name=EXPERIMENT_NAME,
    cohort_path='/labs/shahlab/projects/agataf/data/cohorts/pooled_cohorts/cohort/all_cohorts.csv',
    base_path='/labs/shahlab/projects/agataf/data/cohorts/pooled_cohorts',
    n_bootstrap=100,
    eval_fold='test',
)

# <base_path>/experiments/<experiment>/performance/all holds the pooled outputs.
aggregate_path = os.path.join(
    args['base_path'], 'experiments', EXPERIMENT_NAME, 'performance', 'all'
)
preds_path = os.path.join(aggregate_path, 'predictions.csv')

preds = pd.read_csv(preds_path)

# Some prediction files lack fold/model identifiers; fill in a constant 0 so
# the downstream per-fold / per-model loops still work.
for id_column in ('fold_id', 'model_id'):
    if id_column not in preds.columns:
        preds = preds.assign(**{id_column: 0})

def get_calib_probs(model, x, transform=None):
    """Evaluate a fitted calibration model at the given predicted probabilities.

    Parameters
    ----------
    model : object
        Must expose ``predict_proba``; it is called with a single-column 2-D
        array and the last column of its result is used.
    x : float or array-like of float
        Predicted probabilities at which to evaluate the model. Lists, tuples,
        scalars and NumPy arrays are all accepted.
    transform : {'log', None}, optional
        If ``'log'``, the natural log of ``x`` is fed to the model; any other
        value passes ``x`` through unchanged.

    Returns
    -------
    pandas.DataFrame
        One row per entry of ``x`` with columns ``pred_probs`` (the input),
        ``model_input`` (the possibly transformed input) and
        ``calibration_density`` (the model output).
    """
    # Coerce up front: without this a plain list/scalar only worked by accident
    # on the 'log' path (np.log returns an ndarray) and failed on `.reshape`
    # when no transform was requested.
    x = np.atleast_1d(np.asarray(x, dtype=float))

    if transform == 'log':
        model_input = np.log(x)
    else:
        model_input = x

    # predict_proba expects a 2-D (n_samples, 1) array; keep the last column.
    calibration_density = model.predict_proba(model_input.reshape(-1, 1))[:, -1]

    df = pd.DataFrame({'pred_probs': x,
                       'model_input': model_input,
                       'calibration_density': calibration_density})
    return df
    
def get_calib_model(labels, pred_probs, weights, transform=None):
    """Fit a calibration model and return it.

    Delegates to ``CalibrationEvaluator.get_calibration_density_df`` with the
    given labels, predicted probabilities, sample weights and ``transform``,
    and returns the second element of the pair it produces (the first element
    is discarded).
    """
    calibration_evaluator = CalibrationEvaluator()
    _density_df, fitted_model = calibration_evaluator.get_calibration_density_df(
        labels,
        pred_probs,
        weights,
        transform=transform,
    )
    return fitted_model

# Bootstrap the calibration curve of every (group, model_id, fold_id) stratum
# of the evaluation fold and write the stacked results to CSV.
df_to_calibrate = preds[preds.phase==args['eval_fold']].reset_index(drop=True)

BOOTSTRAP_SEED = 0          # fixes the resampling so a re-run reproduces the CSVs
GRID_STEP = 0.025           # spacing of the probability grid for the curves
RISK_THRESHOLDS = np.array([0.075, 0.2])  # probabilities evaluated separately

# A single RandomState shared by all `.sample` calls: every draw advances it,
# so replicates differ from each other but the whole run is deterministic.
bootstrap_rng = np.random.RandomState(BOOTSTRAP_SEED)

# NOTE(review): every output row is tagged with the first model_type found in
# `preds`; this is only right if `preds` holds a single model type - confirm.
model_type = preds.model_type.unique()[0]

# Everything that does not depend on the bootstrap iteration (subsetting, the
# evaluation grid) is computed once here instead of on every iteration.
strata = []
for group in [1,2,3,4]:
    for model_id in preds.model_id.unique():
        group_df = df_to_calibrate[(df_to_calibrate['group'] == group)
                                   & (df_to_calibrate['model_id'] == model_id)]
        if group_df.empty:
            # No evaluation-fold rows for this combination; taking the max of
            # an empty array would raise, so skip it.
            continue

        # Grid: a near-zero point followed by multiples of GRID_STEP up to the
        # largest multiple not exceeding the maximum predicted probability.
        max_pred_prob = group_df.pred_probs.values.max()
        n_steps = int(max_pred_prob / GRID_STEP)
        eval_grid = np.append([1e-15],
                              np.linspace(GRID_STEP, n_steps * GRID_STEP, n_steps))

        for fold_id in group_df.fold_id.unique():
            fold_df = group_df[group_df['fold_id'] == fold_id]
            loop_kwargs = {'group': group,
                           'fold_id': fold_id,
                           'phase': args['eval_fold'],
                           'model_type': model_type,
                           'model_id': model_id}
            strata.append((fold_df, eval_grid, loop_kwargs))

lin_calib_frames = []
thr_calib_frames = []
for iter_idx in range(args['n_bootstrap']):
    for fold_df, eval_grid, loop_kwargs in strata:
        # Resample the stratum with replacement (same size as the original).
        df = fold_df.sample(frac=1, replace=True, random_state=bootstrap_rng)

        model = get_calib_model(df.labels, df.pred_probs, df.weights, transform='log')

        lin_calib_frames.append(
            get_calib_probs(model, eval_grid, 'log').assign(**loop_kwargs)
        )
        thr_calib_frames.append(
            get_calib_probs(model, RISK_THRESHOLDS, 'log').assign(**loop_kwargs)
        )
    print(iter_idx)

lin_calibs = pd.concat(lin_calib_frames)
lin_calibs.to_csv(os.path.join(aggregate_path, 'calibration_sensitivity_test_raw.csv'), index=False)

thr_calibs = pd.concat(thr_calib_frames)
thr_calibs.to_csv(os.path.join(aggregate_path, 'calibration_sensitivity_thresholds_raw.csv'), index=False)

0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99


In [2]:
# Inspect the stacked per-threshold calibration results (pred_probs 0.075 and 0.2).
thr_calibs

Unnamed: 0,pred_probs,model_input,calibration_density,group,fold_id,phase,model_type,model_id
0,0.075,-2.590267,0.083191,1,1,test,eqodds_mmd,0.001
1,0.200,-1.609438,0.233207,1,1,test,eqodds_mmd,0.001
0,0.075,-2.590267,0.056410,1,2,test,eqodds_mmd,0.001
1,0.200,-1.609438,0.202184,1,2,test,eqodds_mmd,0.001
0,0.075,-2.590267,0.069819,1,3,test,eqodds_mmd,0.001
...,...,...,...,...,...,...,...,...
1,0.200,-1.609438,0.355876,4,8,test,eqodds_mmd,1.000
0,0.075,-2.590267,0.061109,4,9,test,eqodds_mmd,1.000
1,0.200,-1.609438,0.324344,4,9,test,eqodds_mmd,1.000
0,0.075,-2.590267,0.088150,4,10,test,eqodds_mmd,1.000


In [3]:
# Inspect the stacked calibration results evaluated on the 0.025-spaced grid.
lin_calibs

Unnamed: 0,pred_probs,model_input,calibration_density,group,fold_id,phase,model_type,model_id
0,1.000000e-15,-34.538776,7.051558e-19,1,1,test,eqodds_mmd,0.001
1,2.500000e-02,-3.688879,2.287709e-02,1,1,test,eqodds_mmd,0.001
2,5.000000e-02,-2.995732,5.216602e-02,1,1,test,eqodds_mmd,0.001
3,7.500000e-02,-2.590267,8.319070e-02,1,1,test,eqodds_mmd,0.001
4,1.000000e-01,-2.302585,1.145566e-01,1,1,test,eqodds_mmd,0.001
...,...,...,...,...,...,...,...,...
21,5.250000e-01,-0.644357,7.074774e-01,4,10,test,eqodds_mmd,1.000
22,5.500000e-01,-0.597837,7.231485e-01,4,10,test,eqodds_mmd,1.000
23,5.750000e-01,-0.553385,7.376289e-01,4,10,test,eqodds_mmd,1.000
24,6.000000e-01,-0.510826,7.510272e-01,4,10,test,eqodds_mmd,1.000


In [4]:
# Inspect the predictions table loaded from predictions.csv.
preds

Unnamed: 0,phase,outputs,pred_probs,labels,person_id,weights,group,treat,ldlc,relative_risk,fold_id,config_id,model_id,model_type
0,val,-1.344877,0.018974,0,1,1.069048,2,0,141.679,1.000000,1,0,0.001,eqodds_mmd
1,val,-1.220947,0.009054,0,12,1.069048,2,0,140.901,1.000000,1,0,0.001,eqodds_mmd
2,val,-1.244707,0.011578,0,13,1.069048,2,0,133.879,1.000000,1,0,0.001,eqodds_mmd
3,val,-1.243629,0.027535,0,20,1.069048,2,0,108.800,1.000000,1,0,0.001,eqodds_mmd
4,val,-1.881558,0.048957,0,44,1.069048,2,0,133.294,1.000000,1,0,0.001,eqodds_mmd
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
654735,eval,-1.703006,0.082949,1,23450,1.132407,3,1,167.000,0.689061,10,9,1.000,eqodds_mmd
654736,eval,-1.850750,0.238665,1,23472,1.122530,3,2,144.000,0.585539,10,9,1.000,eqodds_mmd
654737,eval,-0.604527,0.104552,1,23476,1.126947,3,1,79.000,0.838469,10,9,1.000,eqodds_mmd
654738,eval,-2.496016,0.171404,1,23487,1.034370,3,1,111.000,0.780719,10,9,1.000,eqodds_mmd


In [7]:
# Show the path the threshold-level results are written to.
# NOTE(review): the recorded output below names 'apr14_mmd' while the active
# EXPERIMENT_NAME is 'apr14_thr', and this cell's execution count is out of
# sequence - the saved output looks stale; re-run from a fresh kernel.
os.path.join(aggregate_path, 'calibration_sensitivity_thresholds_raw.csv')

'/labs/shahlab/projects/agataf/data/cohorts/pooled_cohorts/experiments/apr14_mmd/performance/all/calibration_sensitivity_thresholds_raw.csv'