# In[4]:
import pandas as pd
import os
import numpy as np
import eval_utils

def add_ranges(df, one_hot=False, threshold1=0.075, threshold2=0.2):
    """Bucket ``df.pred_probs`` into three ranges and attach them as columns.

    The ranges are ``(-inf, threshold1)``, ``[threshold1, threshold2)`` and
    ``[threshold2, +inf)``.

    Args:
        df: DataFrame with a ``pred_probs`` column. It is not modified.
        one_hot: If True, add 0/1 indicator columns ``treat0``, ``treat1`` and
            ``treat2``, one per range. If False, add a single integer column
            ``treat`` holding the range index (0, 1 or 2).
        threshold1: Lower cut point.
        threshold2: Upper cut point; must be >= ``threshold1``.

    Returns:
        A new DataFrame (built with ``DataFrame.assign``) with the extra column(s).

    Raises:
        ValueError: If ``threshold1 > threshold2``. The ranges would then overlap
            and a row could fall into two of them.

    Note:
        NaN probabilities fail every comparison. They therefore get
        ``treat == 0`` (and all-zero indicators in one-hot mode), which
        silently lumps them in with the lowest range.
    """
    if threshold1 > threshold2:
        raise ValueError(
            f"threshold1 ({threshold1}) must not exceed threshold2 ({threshold2})"
        )

    probs = df.pred_probs
    low = (probs < threshold1).astype(int)
    mid = ((probs >= threshold1) & (probs < threshold2)).astype(int)
    high = (probs >= threshold2).astype(int)

    if one_hot:
        return df.assign(treat0=low, treat1=mid, treat2=high)
    # The ranges are disjoint, so this yields exactly 0, 1 or 2 per row.
    return df.assign(treat=mid + 2 * high)

# ---- Experiment configuration (hard-coded; no command-line parsing happens here) ----
EXPERIMENT_NAME = 'apr14_erm'
args = {
    'experiment_name': EXPERIMENT_NAME,
    'cohort_path': '/labs/shahlab/projects/agataf/data/cohorts/pooled_cohorts/cohort/all_cohorts.csv',
    'base_path': '/labs/shahlab/projects/agataf/data/cohorts/pooled_cohorts',
}
# Value of the `phase` column whose rows are used to fit the recalibration model.
eval_fold = 'eval'
# -------------------------------------------------------------------------------------

# Source directory: aggregated outputs of the original experiment.
aggregate_path = os.path.join(
    args['base_path'], 'experiments', EXPERIMENT_NAME, 'performance', 'all'
)

# Destination directory: a sibling experiment named "<EXPERIMENT_NAME>_recalib".
new_experiment_name = '_'.join([EXPERIMENT_NAME, 'recalib'])
new_aggregate_path = os.path.join(
    args['base_path'], 'experiments', new_experiment_name, 'performance', 'all'
)
os.makedirs(new_aggregate_path, exist_ok=True)

preds = pd.read_csv(os.path.join(aggregate_path, 'predictions.csv'))

# For every (group, fold): fit a recalibration model on the `eval_fold` phase.
# Then apply it to (a) a 30-point grid of raw probabilities, which gives the
# calibration curve, and (b) the test-phase predictions of the same group and fold.
lin_calibs = []
test_calibs = []
for group in [1, 2, 3, 4]:
    # Upper end of the curve grid. It depends only on the group (across all folds
    # and phases), so it is computed once per group instead of once per fold.
    max_pred_prob = preds.query("(group==@group)").pred_probs.values.max()
    for fold_id in range(1, 11):
        group_df = preds.query("(group==@group) & (fold_id==@fold_id)")
        group_test = group_df.query("phase=='test'").reset_index(drop=True)
        group_eval = group_df[group_df.phase == eval_fold].reset_index(drop=True)

        model = eval_utils.get_calib_model(group_eval, transform='log')

        lin_calib = (
            eval_utils.get_calib_probs(model,
                                       np.linspace(1e-15, max_pred_prob, 30),
                                       'log')
            # Also tag fold_id. Otherwise the 10 per-fold curves of a group
            # cannot be told apart after concatenation.
            .assign(group=group, fold_id=fold_id)
        )

        test_calib = (
            eval_utils.get_calib_probs(model,
                                       group_test.pred_probs.values,
                                       'log')
            # The merge below joins on every shared column. If a raw probability
            # occurs k times on both sides, it would produce k*k rows. Identical
            # inputs presumably give identical calibration rows, so de-duplicate
            # first; this changes nothing when there are no duplicates.
            .drop_duplicates()
            .merge(group_test)
            .drop(['pred_probs', 'model_input'], axis=1)
            .rename(columns={'calibration_density': 'pred_probs'})
            .assign(calibrated=True)
        )

        lin_calibs.append(lin_calib)
        test_calibs.append(test_calib)
lin_calibs = pd.concat(lin_calibs)
test_calibs = pd.concat(test_calibs)

# Uncalibrated test-phase rows, labelled calibrated=False / lambda_reg=0. They are
# NOT part of the file written below.
test_preds = preds.query("phase=='test'").assign(calibrated=False, lambda_reg=0)

# Bucket the recalibrated probabilities into treatment ranges and tag lambda_reg.
test_calibs = add_ranges(test_calibs)
test_calibs = test_calibs.assign(lambda_reg=0)

# To also save the uncalibrated rows, stack them with
# pd.concat([test_preds, test_calibs]); DataFrame.append was removed in pandas 2.0.
test_calibs.to_csv(
    os.path.join(new_aggregate_path, 'predictions.csv'),
    index=False,
)