In [None]:
import pandas as pd
import os
import numpy as np
import logging
import sys
import torch
import copy
import yaml
import random

from prediction_utils.pytorch_utils.metrics import (
    StandardEvaluator,
    FairOVAEvaluator,
    CalibrationEvaluator
)

import matplotlib.pyplot as plt
import seaborn as sns; sns.set()
sns.set_style("ticks")

# Display names for the four subgroups. Keys presumably match the integer codes
# in the `group` column of the predictions file (the bootstrap loop queries
# group==1..4) — confirm against the cohort definition.
grp_label_dict = {
    1: 'Black women',
    2: 'White women',
    3: 'Black men',
    4: 'White men',
}

# Experiment whose pooled predictions are analysed. Other experiment names this
# notebook has been run with: 'apr14_mmd', 'apr14_erm_recalib', 'apr14_erm',
# 'original_pce'.
EXPERIMENT_NAME = 'apr14_thr'

include_recalibrated = False

# Run configuration: data locations, number of bootstrap replicates, and the
# data split (`phase` value) on which calibration is evaluated.
args = dict(
    experiment_name=EXPERIMENT_NAME,
    cohort_path='/labs/shahlab/projects/agataf/data/cohorts/pooled_cohorts/cohort/all_cohorts.csv',
    base_path='/labs/shahlab/projects/agataf/data/cohorts/pooled_cohorts',
    n_bootstrap=100,
    eval_fold='test',
)

# All per-experiment aggregate outputs (input predictions and the CSVs written
# later) live under <base_path>/experiments/<EXPERIMENT_NAME>/performance/all.
aggregate_path = os.path.join(
    args['base_path'], 'experiments', EXPERIMENT_NAME, 'performance', 'all'
)
preds_path = os.path.join(aggregate_path, 'predictions.csv')

preds = pd.read_csv(preds_path)

# Some prediction files may lack the fold / model identifier columns; default
# any missing one to 0 so the per-fold, per-model loops below work uniformly.
preds = preds.assign(**{
    id_col: 0
    for id_col in ('fold_id', 'model_id')
    if id_col not in preds.columns
})

def get_calib_probs(model, x, transform=None):
    """Evaluate a fitted calibration model on a set of predicted probabilities.

    Args:
        model: fitted estimator exposing ``predict_proba``. It is queried with a
            single-feature column and only the last output column is kept.
        x: 1-D array-like of predicted probabilities. Lists and tuples are
            accepted as well as ndarrays.
        transform: ``'log'`` to feed ``np.log(x)`` to the model (for a model
            fit on log-probabilities). Any other value feeds ``x`` unchanged.

    Returns:
        DataFrame with one row per element of ``x`` and columns
        ``pred_probs`` (the input values), ``model_input`` (what was passed to
        the model) and ``calibration_density`` (the model's last-column
        ``predict_proba`` output for that row).
    """
    # Coerce to ndarray: `.reshape` below is an ndarray method, so a plain list
    # would previously fail whenever no log transform converted it first.
    x = np.asarray(x)
    model_input = np.log(x) if transform == 'log' else x

    calibration_density = model.predict_proba(model_input.reshape(-1, 1))[:, -1]

    return pd.DataFrame({'pred_probs': x,
                         'model_input': model_input,
                         'calibration_density': calibration_density})
    
def get_calib_model(labels, pred_probs, weights, transform=None):
    """Fit and return a calibration model for the given predictions.

    Thin wrapper around ``CalibrationEvaluator.get_calibration_density_df``.
    That call returns a pair. Only its second element (the fitted model) is
    returned. The first element (presumably the density table) is discarded.

    Args:
        labels: observed outcomes.
        pred_probs: predicted probabilities to calibrate.
        weights: per-sample weights forwarded to the evaluator.
        transform: forwarded unchanged as the evaluator's ``transform`` option
            (callers in this file pass ``'log'``).
    """
    _density, fitted_model = CalibrationEvaluator().get_calibration_density_df(
        labels, pred_probs, weights, transform=transform
    )
    return fitted_model

def _calibration_grid(max_pred_prob, step=0.025):
    """Return the probability grid on which bootstrap calibration curves are evaluated.

    The grid is 1e-15, followed by ``step, 2*step, ..., k*step`` with
    ``k = int(max_pred_prob / step)``. The 1e-15 is presumably used in place
    of 0 because the curves are evaluated on log-probabilities, and log(0) is
    -inf. When ``max_pred_prob < step`` the grid is just ``[1e-15]``.
    """
    n_steps = int(max_pred_prob / step)
    return np.append([1e-15], np.linspace(step, n_steps * step, n_steps))


# Bootstrap the per-group calibration curves on the evaluation fold.
# For each of args['n_bootstrap'] replicates, every (group, model_id, fold_id)
# slice is resampled with replacement. A calibration model is refit on the
# resample (log-transformed predictions), and the fitted curve is evaluated on:
#   (a) a regular grid of predicted probabilities  -> lin_calibs
#   (b) a fixed pair of probability thresholds     -> thr_calibs
# Rows are appended in (replicate, group, model_id, fold_id) order and written to CSV.

CALIB_GRID_STEP = 0.025
# NOTE(review): presumably the 7.5% / 20% risk-category cutoffs — confirm.
RISK_THRESHOLDS = np.array([0.075, 0.2])
GROUPS = [1, 2, 3, 4]

df_to_calibrate = preds[preds.phase == args['eval_fold']].reset_index(drop=True)

# ---- Loop invariants: computed once, not on every bootstrap replicate. ----
# NOTE(review): assumes `preds` holds a single model_type; only the first is recorded.
model_type = preds.model_type.unique()[0]

# Slice per (group, model_id). Model ids are taken from the evaluation fold:
# an id absent there would give an empty slice, and `.max()` below would fail.
subsets = {}
for group in GROUPS:
    for model_id in df_to_calibrate.model_id.unique():
        group_df = df_to_calibrate.query("(group==@group) & (model_id==@model_id)")
        if group_df.empty:
            logging.warning('No %s rows for group=%s, model_id=%s; skipping',
                            args['eval_fold'], group, model_id)
            continue
        # The grid is shared by all folds of this slice: it is sized by the
        # largest prediction across the whole (group, model_id) slice.
        grid = _calibration_grid(group_df.pred_probs.max(), CALIB_GRID_STEP)
        subsets[(group, model_id)] = (group_df, grid)

# Seeded so the bootstrap (and the CSVs written below) are reproducible.
# Override via an optional args['bootstrap_seed'].
rng = np.random.RandomState(args.get('bootstrap_seed', 0))

lin_calibs = []
thr_calibs = []
for iter_idx in range(args['n_bootstrap']):
    for (group, model_id), (group_df, grid) in subsets.items():
        for fold_id in group_df.fold_id.unique():
            # Nonparametric bootstrap: resample this fold's rows with replacement.
            df = (group_df[group_df.fold_id == fold_id]
                  .sample(frac=1, replace=True, random_state=rng))

            loop_kwargs = {'group': group,
                           'fold_id': fold_id,
                           'phase': args['eval_fold'],
                           'model_type': model_type,
                           'model_id': model_id}

            model = get_calib_model(df.labels, df.pred_probs, df.weights, transform='log')

            lin_calibs.append(get_calib_probs(model, grid, 'log')
                              .assign(**loop_kwargs))
            thr_calibs.append(get_calib_probs(model, RISK_THRESHOLDS, 'log')
                              .assign(**loop_kwargs))
    print(iter_idx)

lin_calibs = pd.concat(lin_calibs)
lin_calibs.to_csv(os.path.join(aggregate_path, 'calibration_sensitivity_test_raw.csv'), index=False)

thr_calibs = pd.concat(thr_calibs)
thr_calibs.to_csv(os.path.join(aggregate_path, 'calibration_sensitivity_thresholds_raw.csv'), index=False)

0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
