In [45]:
# import prediction_utils.pytorch_utils.metrics as metrics
# from prediction_utils.pytorch_utils.metrics import (
#     StandardEvaluator,
#     CalibrationEvaluator
# )

from prediction_utils.pytorch_utils.metrics import *
from collections import ChainMap


class CalibrationEvaluatorNew(CalibrationEvaluator):
    """CalibrationEvaluator extended with ``observation_rate_at_point``.

    ``observation_rate_at_point`` reads the fitted calibration model at a
    single predicted-probability value.
    """

    def observation_rate_at_point(
        self,
        point,
        labels,
        pred_probs,
        sample_weight=None,
        model_type="logistic",
        transform=None,
    ):
        """Evaluate the fitted calibration model at ``point``.

        A calibration model is obtained from
        ``self.get_calibration_density_df`` (with the same ``labels``,
        ``pred_probs``, ``sample_weight``, ``model_type`` and ``transform``).
        Its ``predict_proba`` is then evaluated at ``point``.

        Args:
            point: predicted-probability value(s) at which to evaluate the
                calibration model. Only the result for the first value is
                returned.
            labels: outcome labels forwarded to the calibration fit.
            pred_probs: predicted probabilities forwarded to the calibration fit.
            sample_weight: optional weights forwarded to the calibration fit.
            model_type: calibration model type forwarded to the fit.
            transform: ``None``, ``"log"`` or ``"c_log_log"``. The same
                transform is applied to ``point`` before prediction, presumably
                so it is on the scale the calibration model was fit on
                (TODO confirm against ``get_calibration_density_df``).

        Returns:
            The calibration model's predicted value at the first element of
            ``point``. When ``predict_proba`` returns 2-D output, the last
            column is used (presumably the positive-class probability).

        Raises:
            ValueError: if ``transform`` is not one of the supported values.
                This is checked before any model is fitted.
        """
        # Fail fast: reject an unsupported transform before the (potentially
        # expensive) calibration fit instead of after it.
        valid_transforms = (None, "log", "c_log_log")
        if transform not in valid_transforms:
            raise ValueError("Invalid transform provided")

        _, model = self.get_calibration_density_df(
            labels,
            pred_probs,
            sample_weight=sample_weight,
            model_type=model_type,
            transform=transform,
        )

        # Map the query point onto the same scale as the model's input.
        if transform == "log":
            point = np.log(point)
        elif transform == "c_log_log":
            point = self.c_log_log(point)
        point = np.array(point).reshape(-1, 1)

        calibration_density = model.predict_proba(point)
        if len(calibration_density.shape) > 1:
            calibration_density = calibration_density[:, -1]

        return calibration_density[0]

In [42]:
import prediction_utils

In [44]:
# Display the metrics module object; its repr shows which file on disk was
# resolved.
getattr(prediction_utils.pytorch_utils, "metrics")

<module 'prediction_utils.pytorch_utils.metrics' from '/labs/shahlab/projects/agataf/prediction_utils/prediction_utils/pytorch_utils/metrics.py'>

<module 'prediction_utils.pytorch_utils.metrics' from '/labs/shahlab/projects/agataf/prediction_utils/prediction_utils/pytorch_utils/metrics.py'>

In [46]:
class StandardEvaluatorNew(StandardEvaluator):
    """StandardEvaluator whose threshold metrics also support ``observation_rate``."""

    def get_threshold_metrics(
        self,
        threshold_metrics=None,
        thresholds=(0.01, 0.05, 0.1, 0.2, 0.5),
        weighted=False,
    ):
        """
        Returns a set of metric functions that are defined with respect to a set of thresholds.

        Args:
            threshold_metrics: metric names to include. Recognized names are
                "recall", "precision", "specificity" and "observation_rate";
                any other name is ignored. Defaults to recall, precision and
                specificity.
            thresholds: iterable of thresholds. ``None`` disables threshold
                metrics entirely.
            weighted: forwarded to every ``generate_*_at_threshold`` factory.

        Returns:
            A dict mapping ``"<metric>_<threshold>"`` to the metric function.
            Metric groups appear in reverse of the order recall, precision,
            specificity, observation_rate; thresholds appear in the order
            given. The dict is empty when ``thresholds`` is None or no
            recognized metric was requested.
        """
        if thresholds is None:
            return {}

        # Materialize once so a generator/iterator is usable for every metric group.
        thresholds = list(thresholds)

        if threshold_metrics is None:
            threshold_metrics = ["recall", "precision", "specificity"]

        # Canonical metric order -> factory for one threshold. The lambdas defer
        # name lookup, so a factory only has to exist when its metric is requested.
        factories = {
            "recall": lambda t: generate_recall_at_threshold(t, weighted=weighted),
            "precision": lambda t: generate_precision_at_threshold(t, weighted=weighted),
            "specificity": lambda t: generate_specificity_at_threshold(t, weighted=weighted),
            "observation_rate": lambda t: generate_observation_rate_at_threshold(
                t, weighted=weighted
            ),
        }
        requested = [name for name in factories if name in threshold_metrics]

        result = {}
        # Later groups first: this reproduces the key order of the previous
        # dict(ChainMap(recall, precision, specificity, observation_rate)) merge.
        for name in reversed(requested):
            result.update(
                {f"{name}_{threshold}": factories[name](threshold) for threshold in thresholds}
            )
        return result


In [47]:
def generate_observation_rate_at_threshold(
    threshold, weighted=False, model_type="logistic", transform="log"
):
    """
    Returns a lambda function that computes the observation rate at a provided threshold,
    i.e. the fitted calibration model's value at ``threshold`` (via
    ``observation_rate_at_point``).

    If weighted = True, the lambda function takes a third argument for the sample weights;
    otherwise it takes only (labels, pred_probs) and fits without weights. This matches the
    argument convention of generate_recall_at_threshold.

    ``model_type`` and ``transform`` are forwarded to ``observation_rate_at_point``;
    their defaults reproduce the previously hard-coded values.
    """
    if not weighted:
        return lambda labels, pred_probs: observation_rate_at_point(
            threshold,
            labels,
            pred_probs,
            sample_weight=None,
            model_type=model_type,
            transform=transform,
        )
    return lambda labels, pred_probs, sample_weight: observation_rate_at_point(
        threshold,
        labels,
        pred_probs,
        sample_weight=sample_weight,
        model_type=model_type,
        transform=transform,
    )

# def generate_recall_at_threshold(threshold, weighted=False):
#     """
#     Returns a lambda function that computes the recall at a provided threshold.
#     If weights = True, the lambda function takes a third argument for the sample weights
#     """
#     if not weighted:
#         return lambda labels, pred_probs: recall_score(
#             labels, 1.0 * (pred_probs >= threshold)
#         )
#     else:
#         return lambda labels, pred_probs, sample_weight: recall_score(
#             labels, 1.0 * (pred_probs >= threshold), sample_weight=sample_weight
#         )


In [None]:
# Stray inspection of `observation_rate_at_point` removed: this cell evaluated
# the name before the cell that defines it, so Restart & Run All would fail here.

In [38]:
def observation_rate_at_point(*args, **kwargs):
    """Module-level shortcut for CalibrationEvaluatorNew.observation_rate_at_point.

    Builds a fresh CalibrationEvaluatorNew on every call and forwards all
    positional and keyword arguments to it unchanged.
    """
    return CalibrationEvaluatorNew().observation_rate_at_point(*args, **kwargs)

In [39]:
import numpy as np
import pandas as pd
import os

# Display names for the integer values of the group column (not used in this cell).
grp_label_dict = {1: "Black women", 2: "White women", 3: "Black men", 4: "White men"}

# Experiment configuration. NOTE(review): absolute cluster paths; consider
# making base_path configurable so the notebook runs elsewhere.
args = {
    "experiment_name": "apr14_erm",
    "cohort_path": "/labs/shahlab/projects/agataf/data/cohorts/pooled_cohorts/cohort/all_cohorts.csv",
    "base_path": "/labs/shahlab/projects/agataf/data/cohorts/pooled_cohorts",
    "eval_fold": "test",
}
aggregate_path = os.path.join(
    args["base_path"], "experiments", args["experiment_name"], "performance", "all"
)

preds_path = os.path.join(aggregate_path, "predictions.csv")
preds = pd.read_csv(preds_path)
# Keep only rows from the configured evaluation fold (was hard-coded to "test",
# ignoring args["eval_fold"]).
eval_df = preds.loc[preds["phase"] == args["eval_fold"]]

In [40]:
evaluator = CalibrationEvaluatorNew()

# Input frame and the column names / settings used against it.
# group_var_name, result_name and group_overall_name are not used in this cell.
df = eval_df
point = 0.075
weight_var, label_var, pred_prob_var = "weights", "labels", "pred_probs"
group_var_name, result_name, group_overall_name = "group", "performance", "overall"

# Despite the name, this holds a single number: the calibration model's value
# at `point` (log-transformed logistic calibration).
calibration_df = evaluator.observation_rate_at_point(
    point,
    labels=df[label_var],
    pred_probs=df[pred_prob_var],
    sample_weight=df[weight_var],
    model_type="logistic",
    transform="log",
)
calibration_df


0.0634914643066256

In [50]:
# NOTE(review): the saved output of the metric-listing cell shows
# 'ace_rmse_logistic_log' and no 'auprc', which does not match the metrics
# requested here; re-run to refresh it.
standard_evaluator = StandardEvaluatorNew(
    thresholds=[0.075, 0.2],
    # Previously tried here as well: 'observation_rate', 'ace_rmse_logistic_log'.
    metrics=["auc", "auprc"],
    threshold_metrics=["observation_rate", "specificity", "recall"],
)

# Column names in eval_df, passed through to get_result_df.
eval_dict = dict(
    label_var="labels",
    pred_prob_var="pred_probs",
    weight_var="weights",
    group_var_name="group",
)

# Optional stratification tried earlier: strata_vars=['model_id', 'fold_id', 'experiment'].
eval_overall = standard_evaluator.get_result_df(eval_df, **eval_dict)


  if group_overall_name in (df[group_var_name].unique()):


In [51]:
eval_overall["metric"].unique()

array(['auc', 'ace_rmse_logistic_log', 'observation_rate_0.075',
       'observation_rate_0.2', 'specificity_0.075', 'specificity_0.2',
       'recall_0.075', 'recall_0.2'], dtype=object)