In [91]:
import numpy as np
import pandas as pd
import os

from prediction_utils.pytorch_utils.metrics import StandardEvaluator
from prediction_utils.pytorch_utils.metrics import *
from collections import ChainMap

class CalibrationEvaluatorNew(CalibrationEvaluator):
    """
    CalibrationEvaluator extended with a point query of the fitted calibration model.
    """

    def observation_rate_at_point(
        self,
        point,
        labels,
        pred_probs,
        sample_weight=None,
        model_type="logistic",
        transform=None,
    ):
        """
        Evaluate the calibration model fit on (labels, pred_probs) at `point`.

        The model is obtained from `self.get_calibration_density_df` using the same
        `sample_weight`, `model_type` and `transform`; its `predict_proba` is then
        queried at `point`, after mapping `point` with `transform`.

        Args:
            point: value(s) on the predicted-probability scale to query.
            labels, pred_probs, sample_weight: data the calibration model is fit on.
            model_type: forwarded to `get_calibration_density_df`.
            transform: None, "log" or "c_log_log".

        Returns:
            The first element of the model's prediction; when `predict_proba` returns a
            2-D array, its last column is used.

        Raises:
            ValueError: `transform` is not one of the values above (raised after the
                calibration model has been fit).
        """
        _, model = self.get_calibration_density_df(
            labels,
            pred_probs,
            sample_weight=sample_weight,
            model_type=model_type,
            transform=transform,
        )

        # Map the query point with the same transform handed to the fit above
        # (presumably the feature space of the calibration model — confirm in base class).
        if transform is None:
            transformed_point = point
        elif transform == "log":
            transformed_point = np.log(point)
        elif transform == "c_log_log":
            transformed_point = self.c_log_log(point)
        else:
            raise ValueError("Invalid transform provided")

        # predict_proba is given a single-feature column: shape (n_points, 1).
        query = np.array(transformed_point).reshape(-1, 1)
        calibration_density = model.predict_proba(query)

        # 2-D output: keep the last column (presumably the positive class).
        if np.ndim(calibration_density) > 1:
            calibration_density = calibration_density[:, -1]

        return calibration_density[0]

class StandardEvaluatorNew(StandardEvaluator):
    """
    StandardEvaluator whose threshold metrics also support "observation_rate" and "nri_e".
    """

    def get_threshold_metrics(
        self,
        threshold_metrics=None,
        thresholds=(0.01, 0.05, 0.1, 0.2, 0.5),
        weighted=False,
    ):
        """
        Returns a set of metric functions that are defined with respect to a set of thresholds.

        Args:
            threshold_metrics: names of the metric families to build. Recognised:
                "recall", "precision", "specificity", "observation_rate", "nri_e",
                "nri_ne"; other names are ignored. Defaults to recall, precision and
                specificity.
            thresholds: iterable of thresholds. None disables threshold metrics.
            weighted: forwarded to every generator; with weighted=True the generated
                callables take a third ``sample_weight`` argument.

        Returns:
            A flat dict mapping "<metric>_<threshold>" (e.g. "recall_0.2") to a callable.

        Raises:
            NotImplementedError: "nri_ne" is requested but no
                ``generate_nri_ne_at_threshold`` exists in this namespace.
        """
        if thresholds is None:
            return {}

        if threshold_metrics is None:
            threshold_metrics = [
                "recall",
                "precision",
                "specificity",
            ]  # acts as default value

        def _per_threshold(prefix, generator):
            # One callable per threshold, keyed like "recall_0.2".
            return {
                "{}_{}".format(prefix, threshold): generator(threshold, weighted=weighted)
                for threshold in thresholds
            }

        # Each generator name is only looked up when its metric is requested, so a
        # missing generator cannot break unrelated metrics.
        result = {}
        if "recall" in threshold_metrics:
            result["recall"] = _per_threshold("recall", generate_recall_at_threshold)
        if "precision" in threshold_metrics:
            result["precision"] = _per_threshold("precision", generate_precision_at_threshold)
        if "specificity" in threshold_metrics:
            result["specificity"] = _per_threshold(
                "specificity", generate_specificity_at_threshold
            )
        if "observation_rate" in threshold_metrics:
            result["observation_rate"] = _per_threshold(
                "observation_rate", generate_observation_rate_at_threshold
            )
        if "nri_e" in threshold_metrics:
            result["nri_e"] = _per_threshold("nri_e", generate_nri_e_at_threshold)
        if "nri_ne" in threshold_metrics:
            # No generate_nri_ne_at_threshold is defined alongside this class; fail with an
            # explicit message instead of a bare NameError if nothing else provides it.
            nri_ne_generator = globals().get("generate_nri_ne_at_threshold")
            if nri_ne_generator is None:
                raise NotImplementedError(
                    "threshold metric 'nri_ne' requested but "
                    "generate_nri_ne_at_threshold is not defined"
                )
            result["nri_ne"] = _per_threshold("nri_ne", nri_ne_generator)

        # Flatten {family: {name: fn}} into {name: fn}; names are unique across families.
        return dict(ChainMap(*result.values())) if result else {}
        
def generate_observation_rate_at_threshold(
    threshold, weighted=False, model_type="logistic", transform="log"
):
    """
    Returns a function that computes the observation rate at a provided threshold,
    via ``observation_rate_at_point(threshold, labels, pred_probs, sample_weight, ...)``.

    NOTE(review): ``observation_rate_at_point`` is resolved by name at call time; it is
    presumably a module-level counterpart of CalibrationEvaluatorNew.observation_rate_at_point
    (fit a calibration model, query it at ``threshold``) — confirm which definition binds.

    Args:
        threshold: point on the predicted-probability scale to evaluate at.
        weighted: if True the returned function takes (labels, pred_probs, sample_weight);
            otherwise it takes (labels, pred_probs) and passes sample_weight=None.
        model_type, transform: forwarded to observation_rate_at_point.
    """

    def _observation_rate(labels, pred_probs, sample_weight):
        return observation_rate_at_point(
            threshold,
            labels,
            pred_probs,
            sample_weight,
            model_type=model_type,
            transform=transform,
        )

    # Match the calling convention of the other generate_*_at_threshold helpers:
    # the unweighted variant is called with exactly two arguments.
    if weighted:
        return lambda labels, pred_probs, sample_weight: _observation_rate(
            labels, pred_probs, sample_weight
        )
    return lambda labels, pred_probs: _observation_rate(labels, pred_probs, None)

def generate_nri_e_at_threshold(
    threshold, weighted=False, model_type="logistic", transform="log"
):
    """
    Returns a function for the "nri_e" threshold metric at a provided threshold.

    FIXME: placeholder — this currently computes exactly the same quantity as the
    observation-rate metric (``observation_rate_at_point`` with the given model_type and
    transform), not the reclassification metric the name presumably refers to.

    Args:
        threshold: point on the predicted-probability scale to evaluate at.
        weighted: if True the returned function takes (labels, pred_probs, sample_weight);
            otherwise it takes (labels, pred_probs) and passes sample_weight=None.
        model_type, transform: forwarded to observation_rate_at_point.
    """

    def _metric(labels, pred_probs, sample_weight):
        return observation_rate_at_point(
            threshold,
            labels,
            pred_probs,
            sample_weight,
            model_type=model_type,
            transform=transform,
        )

    # Same two-/three-argument calling convention as the other generators.
    if weighted:
        return lambda labels, pred_probs, sample_weight: _metric(
            labels, pred_probs, sample_weight
        )
    return lambda labels, pred_probs: _metric(labels, pred_probs, None)

def recall_at_threshold(labels, pred_probs, sample_weight=None, threshold=0.5):
    """
    Computes recall at a threshold.

    Predictions are binarised as ``pred_probs >= threshold`` by
    generate_recall_at_threshold.

    Args:
        labels: true binary labels.
        pred_probs: predicted probabilities.
        sample_weight: optional per-sample weights, forwarded to threshold_metric_fn
            (which presumably switches to the weighted recall variant when not None —
            confirm against its definition).
        threshold: decision threshold.
    """
    return threshold_metric_fn(
        labels=labels,
        pred_probs=pred_probs,
        # Forward the caller's weights instead of discarding them.
        sample_weight=sample_weight,
        threshold=threshold,
        metric_generator_fn=generate_recall_at_threshold,
    )


def generate_recall_at_threshold(threshold, weighted=False):
    """
    Build a recall metric that binarises predictions at `threshold`.

    The returned callable takes (labels, pred_probs), or (labels, pred_probs,
    sample_weight) when `weighted` is True, and returns recall_score of
    ``1.0 * (pred_probs >= threshold)`` against `labels`.
    """

    def _recall_weighted(labels, pred_probs, sample_weight):
        predicted = 1.0 * (pred_probs >= threshold)
        return recall_score(labels, predicted, sample_weight=sample_weight)

    def _recall_unweighted(labels, pred_probs):
        predicted = 1.0 * (pred_probs >= threshold)
        return recall_score(labels, predicted)

    return _recall_weighted if weighted else _recall_unweighted

In [92]:
import numpy as np
import pandas as pd
import os

# Display names for the numeric `group` codes that appear in the prediction files.
grp_label_dict = {1: "Black women", 2: "White women", 3: "Black men", 4: "White men"}

# Project root; experiment outputs are read from <BASE_PATH>/experiments/<name>/...
BASE_PATH = "/labs/shahlab/projects/agataf/data/cohorts/pooled_cohorts"

args = {
    "cohort_path": os.path.join(BASE_PATH, "cohort", "all_cohorts.csv"),
    "base_path": BASE_PATH,
    "eval_fold": "test",
}

In [93]:
# Experiments whose predictions are pooled for the evaluation below.
EXPERIMENTS = ['original_pce', 'revised_pce', 'apr14_erm', 'apr14_erm_recalib', 'scratch_thr']
# Experiments for which only rows with model_id >= eqodds_threshold are kept
# (model_id appears to encode the equalized-odds sweep value for these).
THRESHOLDED_EXPERIMENTS = {'apr14_mmd', 'apr14_thr', 'scratch_thr'}
eqodds_threshold = 0.1


def _load_experiment_predictions(experiment, min_model_id):
    """
    Read <base_path>/experiments/<experiment>/performance/all/predictions.csv.

    Missing `model_id` / `fold_id` columns are filled with 0 so every experiment has the
    same stratification columns; for THRESHOLDED_EXPERIMENTS, rows with
    model_id < min_model_id are dropped.
    """
    preds_path = os.path.join(
        args['base_path'], 'experiments', experiment, 'performance', 'all', 'predictions.csv'
    )
    preds = pd.read_csv(preds_path)
    if 'model_id' not in preds.columns:
        preds = preds.assign(model_id=0)
    if 'fold_id' not in preds.columns:
        preds = preds.assign(fold_id=0)
    if experiment in THRESHOLDED_EXPERIMENTS:
        preds = preds.query('model_id >= @min_model_id')
    return preds


# One concat over a list built in a comprehension: no loop temporaries leak into the
# notebook namespace and `preds_all` is only ever a DataFrame.
preds_all = pd.concat(
    [_load_experiment_predictions(experiment, eqodds_threshold) for experiment in EXPERIMENTS]
)

In [94]:
# Distinct model_id values of the equalized-odds models on the test split.
preds_all.loc[
    (preds_all['phase'] == 'test') & (preds_all['model_type'] == 'eqodds_thr'),
    'model_id',
].unique()

array([0.1, 0.21544346900318825, 0.4641588833612778, 1.0], dtype=object)

In [95]:
# Non-null counts per column for every (phase, model_type, fold_id, group) stratum of the test split.
(preds_all[preds_all['phase'] == 'test']
 .groupby(['phase', 'model_type', 'fold_id', 'group'])
 .count())

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Unnamed: 3_level_0,pred_probs,labels,weights,person_id,treat,relative_risk,model_id,outputs,ldlc,config_id
phase,model_type,fold_id,group,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1
test,eqodds_thr,1,1,1944,1944,1944,1944,1944,1944,1944,1944,1944,1944
test,eqodds_thr,1,2,3232,3232,3232,3232,3232,3232,3232,3232,3232,3232
test,eqodds_thr,1,3,1348,1348,1348,1348,1348,1348,1348,1348,1348,1348
test,eqodds_thr,1,4,2804,2804,2804,2804,2804,2804,2804,2804,2804,2804
test,eqodds_thr,2,1,1944,1944,1944,1944,1944,1944,1944,1944,1944,1944
test,...,...,...,...,...,...,...,...,...,...,...,...,...
test,recalib_erm,10,4,701,701,701,701,701,701,701,701,701,701
test,revised_pce,0,1,539,539,539,539,539,539,539,0,0,0
test,revised_pce,0,2,856,856,856,856,856,856,856,0,0,0
test,revised_pce,0,3,382,382,382,382,382,382,382,0,0,0


In [98]:
# Bootstrap resamples. NOTE(review): 2 resamples cannot give reliable CI bounds; this looks
# like a smoke-test setting — increase before reporting results.
N_BOOT = 2

evaluator = StandardEvaluatorNew(
    thresholds=[0.075, 0.2],
    metrics=['auc', 'auprc', 'ace_rmse_logistic_log', 'ace_rmse_bin_log', 'loss_bce'],
    threshold_metrics=['observation_rate', 'specificity', 'recall'],
)

# Bootstrap evaluation of the test split, with per-group rows plus an 'overall' row
# (compute_overall=True). NOTE(review): no random seed is passed; check whether
# bootstrap_evaluate accepts one so reruns reproduce the same CIs.
result_df_ci = evaluator.bootstrap_evaluate(
    df=preds_all.query("phase=='test'"),
    n_boot=N_BOOT,
    strata_vars_eval=['phase', 'model_type', 'model_id', 'fold_id', 'group'],
    strata_vars_boot=['phase', 'labels', 'group'],
    strata_var_replicate='fold_id',
    replicate_aggregation_mode=None,
    baseline_experiment_name=0,
    strata_var_group='group',
    weight_var='weights',
    n_jobs=-1,
    compute_overall=True,
)

In [99]:
# Manuscript output directory, derived from the project root in `args` rather than
# repeating the absolute path (resolves to .../pooled_cohorts/experiments/bmj_manuscript).
aggregate_path_all = os.path.join(args['base_path'], 'experiments', 'bmj_manuscript')
os.makedirs(aggregate_path_all, exist_ok=True)
result_df_ci.to_csv(
    os.path.join(aggregate_path_all, 'bootstrap_standard_eval_raw_scratch.csv'), index=False
)

In [107]:
# Reporting labels for the `model_type` values in the prediction files.
model_type_names = {'original_pce': 'PCE',
                    'revised_pce': 'rPCE',
                    'erm': 'UC',
                    'recalib_erm': 'rUC',
                    'eqodds_thr': 'EO'
                   }

# Evaluator metric name -> reporting name. Threshold metrics share one name per family;
# the threshold itself goes into the separate `thresholds` column.
metric_names = {'auc':                    'auc',
                'auprc':                  'auprc',
                'ace_rmse_logistic_log':  'ace',
                'ace_rmse_bin_log':       'ace_bin',
                'loss_bce':               'loss',
                'recall_0.075':           'sensitivity',
                'recall_0.2':             'sensitivity',
                'specificity_0.075':      'specificity',
                'specificity_0.2':        'specificity',
                'observation_rate_0.075': 'impl_threshold',
                'observation_rate_0.2':   'impl_threshold'
               }

# Threshold of each threshold metric; other metrics map to NaN.
thresholds = {'recall_0.075':           0.075,
              'recall_0.2':             0.2,
              'specificity_0.075':      0.075,
              'specificity_0.2':        0.2,
              'observation_rate_0.075': 0.075,
              'observation_rate_0.2':   0.2
             }

# model_id of each equalized-odds sweep point -> its reporting label.
eo_model_labels = {0.1: 'EO1',
                   0.21544346900318825: 'EO2',
                   0.4641588833612778: 'EO3',
                   1.0: 'EO4'}

model_type_order = ['PCE', 'rPCE', 'UC', 'rUC', 'EO1', 'EO2', 'EO3', 'EO4']

# `thresholds` must be derived from the raw metric names, i.e. before `metric` is remapped;
# DataFrame.assign evaluates its keyword arguments in the order given.
plot_df = (result_df_ci
           .assign(model_type=lambda x: x.model_type.map(model_type_names),
                   thresholds=lambda x: x.metric.map(thresholds),
                   metric=lambda x: x.metric.map(metric_names)
                  )
          )

# Split the 'EO' rows by model_id. An unlisted model_id keeps the plain 'EO' label, which
# is not in `model_type_order` and therefore becomes NaN in the categorical below.
model_type_labels = [
    eo_model_labels.get(model_id, label) if label == 'EO' else label
    for label, model_id in zip(plot_df.model_type, plot_df.model_id)
]

plot_df = (plot_df
           .assign(model_type=pd.Categorical(model_type_labels,
                                             categories=model_type_order,
                                             ordered=True)
                  )
           .drop(columns=['model_id'])
          )




In [105]:
# Inspect the reporting table. NOTE(review): this displays the whole frame; `.head()` may suffice.
plot_df

Unnamed: 0,phase,model_type,group,metric,CI_lower,CI_med,CI_upper,thresholds
0,test,EO1,1,ace_bin,0.039305,0.047119,0.053654,
1,test,EO1,1,ace,0.015367,0.018348,0.022893,
2,test,EO1,1,auc,0.781378,0.787854,0.800801,
3,test,EO1,1,auprc,0.168727,0.217446,0.265906,
4,test,EO1,1,loss,0.189369,0.192966,0.196441,
...,...,...,...,...,...,...,...,...
435,test,rPCE,overall,impl_threshold,0.205094,0.209393,0.213693,0.200
436,test,rPCE,overall,sensitivity,0.693555,0.724273,0.754992,0.075
437,test,rPCE,overall,sensitivity,0.230612,0.270420,0.310228,0.200
438,test,rPCE,overall,specificity,0.713356,0.725822,0.738289,0.075


In [108]:
# Manuscript output directory, derived from the project root in `args` rather than
# repeating the absolute path (resolves to .../pooled_cohorts/experiments/bmj_manuscript).
aggregate_path_all = os.path.join(args['base_path'], 'experiments', 'bmj_manuscript')
os.makedirs(aggregate_path_all, exist_ok=True)

plot_df.to_csv(os.path.join(aggregate_path_all, 'bootstrap_standard_eval_scratch.csv'), index=False)



In [89]:
# Implied-threshold (observation rate) rows at the 0.2 threshold.
plot_df[(plot_df['metric'] == 'impl_threshold') & (plot_df['thresholds'] == 0.2)]

Unnamed: 0,phase,model_type,group,metric,CI_lower,CI_med,CI_upper,thresholds
3,test,EO1,1,impl_threshold,0.165554,0.205757,0.261187,0.2
11,test,EO1,2,impl_threshold,0.20149,0.243124,0.295233,0.2
19,test,EO1,3,impl_threshold,0.151028,0.187823,0.238186,0.2
27,test,EO1,4,impl_threshold,0.227819,0.265308,0.309103,0.2
35,test,EO1,overall,impl_threshold,0.211735,0.232343,0.255918,0.2
43,test,EO2,1,impl_threshold,0.188085,0.255168,0.340024,0.2
51,test,EO2,2,impl_threshold,0.220225,0.284894,0.362367,0.2
59,test,EO2,3,impl_threshold,0.165365,0.21563,0.287397,0.2
67,test,EO2,4,impl_threshold,0.25995,0.316628,0.384015,0.2
75,test,EO2,overall,impl_threshold,0.235157,0.276338,0.313722,0.2


In [83]:
# Columns of the reporting table.
plot_df.columns

Index(['phase', 'model_type', 'group', 'metric', 'CI_lower', 'CI_med',
       'CI_upper', 'thresholds'],
      dtype='object')

In [None]:
# Metric(s) whose spread across groups is examined; these are reporting names from plot_df.
variance_metrics = ['sensitivity']
# Select the matching metric rows of plot_df. (DataFrame.filter selects column labels,
# not rows.) NOTE(review): confirm plot_df is the intended frame.
plot_df[plot_df['metric'].isin(variance_metrics)]

In [None]:
# Spread (population std, ddof=0) of the per-group point estimates (CI_med) across the
# demographic groups, per model / threshold / metric. AUC and AUPRC are excluded (metric
# names are lowercase in plot_df); the pooled 'overall' column is dropped so it does not
# count as an extra group. model_id is not dropped here because plot_df no longer has it.
# Pivoting directly keeps non-threshold metrics (NaN `thresholds`), which a groupby would
# silently drop, and avoids zero-filled unobserved categorical combinations.
stds = (plot_df
        .query("metric not in ['auc', 'auprc']")
        .pivot(index=['model_type', 'thresholds', 'metric'], columns='group', values='CI_med')
        .drop(columns=['overall'])
        .std(axis=1, ddof=0)
       )

In [None]:
# eval_dict = {'label_var': 'labels',
#              'pred_prob_var': 'pred_probs',
#              'weight_var': 'weights',
#              'group_var_name': 'group'}
                                      
# eval_overall = standard_evaluator.get_result_df(eval_df,
#                                                     #strata_vars=['model_id', 'fold_id', 'experiment'],
#                                                     **eval_dict)