In [1]:
import pandas as pd
import os
import numpy as np
import matplotlib.pyplot as plt
import logging
import sys
import torch
import configargparse as argparse

from prediction_utils.util import yaml_write
from prediction_utils.pytorch_utils.models import TorchModel
from prediction_utils.pytorch_utils.lagrangian import MultiLagrangianThresholdRateModel
from prediction_utils.pytorch_utils.robustness import GroupDROModel
from prediction_utils.pytorch_utils.group_fairness import EqualThresholdRateModel
from prediction_utils.pytorch_utils.layers import LinearLayer
from prediction_utils.pytorch_utils.metrics import StandardEvaluator, FairOVAEvaluator, CalibrationEvaluator

import git
# Find the enclosing git repository by searching upward from the current
# directory, then make its working-tree root the process working directory so
# that the rest of the notebook runs from the repository root.
repo = git.Repo('.', search_parent_directories=True)
os.chdir(repo.working_dir) 

# NOTE(review): train_utils is imported only after the chdir above -- presumably
# it lives in the repository root and is resolved relative to the cwd. Confirm
# before moving this import up with the others.
import train_utils
import yaml



# Identifiers of this run: experiment name, configuration id and fold id.
EXPERIMENT_NAME = 'eq_oddsconstr'
config_id = '00'
fold_id = '0'

# Root directory that the experiment paths below are built from.
BASE_PATH = '/labs/shahlab/projects/agataf/data/cohorts/pooled_cohorts'

# Run-level options passed to the train_utils helpers further down.
# The boolean flags switch optional steps (saving outputs/weights, evaluation).
args = dict(
    experiment_name=EXPERIMENT_NAME,
    cohort_path='/labs/shahlab/projects/agataf/data/pooled_cohorts/cohort_extraction/all_cohorts.csv',
    base_path=BASE_PATH,
    config_id=config_id,
    fold_id=fold_id,
    print_debug=True,
    save_outputs=True,
    run_evaluation_group_standard=True,
    run_evaluation_group_fair_ova=True,
    save_model_weights=True,
    run_evaluation=True,
    # split_gender=True,   # currently disabled
    data_query='',          # empty string: no row filter is applied to the cohort
)


# Base YAML configuration, located in <base_path>/experiments.
BASE_CONFIG_PATH = os.path.join(args['base_path'], 'experiments', 'basic_config.yaml')

# Output directory of this run:
#   <base_path>/experiments/<experiment_name>/performance/<config_id>.yaml/<fold_id>
# Note that the "<config_id>.yaml" component is a directory name here.
RESULT_PATH = os.path.join(
    args['base_path'],
    'experiments',
    args['experiment_name'],
    'performance',
    args['config_id'] + '.yaml',
    args['fold_id'],
)

# Training log file, kept next to the other artifacts of the run.
LOGGING_PATH = os.path.join(RESULT_PATH, 'training_log.log')

# Make the output directory available to code that only receives `args`.
args['result_path'] = RESULT_PATH

##### INITIAL SETUP #####

# Create the output directory for this run; an already existing one is reused.
os.makedirs(RESULT_PATH, exist_ok=True)

#model_params = yaml.load(open(CONFIG_PATH), Loader=yaml.FullLoader)

# Load the base configuration. The file is opened in a `with` block so the
# handle is closed deterministically instead of being left to the garbage
# collector (the previous `yaml.load(open(...))` form never closed it).
with open(BASE_CONFIG_PATH) as base_config_file:
    config_dict = yaml.load(base_config_file, Loader=yaml.FullLoader)

# Record where the training log of this run is written.
config_dict.update({'logging_path': LOGGING_PATH})
# Experiment-specific settings that are merged into config_dict further down.
# Each key appears exactly once: the original literal listed 'sparse' twice,
# which Python accepts silently (the last occurrence wins) and which makes a
# later edit to only one of the two entries easy to get wrong.
update_dict = {
    'threshold_mode': 'conditional',
    'thresholds': [0.075, 0.2],
    'surrogate_scale': 1.0,
    'logging_metrics': ['auc', 'auprc', 'brier', 'loss_bce'],
    'data_query': '',
    'group_objective_type': 'multiThreshold',
    'lambda_group_regularization': 0.01,
    'evaluate_by_group': False,
    'sparse': False,
    'output_dim': 2,
    'num_groups': 4,
    'num_hidden': 0,
    'lr': 1e-4,
    'weighted_loss': True,
}


# Merge the experiment-specific settings into the base configuration in place;
# keys already present in config_dict are overwritten by update_dict.
config_dict.update(update_dict)
# Settings follow the reference tuning script:
# https://github.com/som-shahlab/group_robustness_fairness/blob/main/group_robustness_fairness/scripts/tune_baseline_model_starr.sh

In [2]:
# config_dict

In [2]:
## remove ##
# Temporary override: forces a 5-epoch run regardless of the loaded
# configuration. Marked for removal by the original author.
config_dict.update({'num_epochs': 5})

# Create the run logger from the assembled configuration and run options.
logger = train_utils.logger_setup(config_dict, args)

##### DATASET #####
# Read the full cohort table from CSV.
data_df = pd.read_csv(args['cohort_path'])

# Optionally restrict the cohort with a pandas query string; an empty string
# leaves the table untouched.
if len(args['data_query']):
    data_df = data_df.query(args['data_query']).reset_index(drop=True)

# Forward only the dataset-related configuration entries to the Dataset wrapper.
data_args = train_utils.get_dict_subset(
    config_dict,
    ['feature_columns', 'val_fold_id', 'test_fold_id', 'batch_size'],
)
data = train_utils.Dataset(data_df, deg=2, **data_args)

# The model input width is the number of columns of the training feature matrix.
config_dict['input_dim'] = data.features_dict_uncensored_scaled['train'].shape[1]

# Record where the artifacts of this run are written.
logger.info("Result path: {}".format(args['result_path']))

# Build the model from the configuration; model_setup hands the logger back too.
model, logger = train_utils.model_setup(config_dict, logger, args)

# Train on the dataset loaders and keep the 'performance' entry of the result.
result_df = model.train(loaders=data.loaders_dict)['performance']

# Persist the training performance table.
training_results_file = os.path.join(RESULT_PATH, "result_df_training.parquet")
result_df.to_parquet(training_results_file, index=False, engine="pyarrow")

# Optionally persist the weights of the underlying torch module.
if args['save_model_weights']:
    weights_file = os.path.join(RESULT_PATH, "state_dict.pt")
    torch.save(model.model.state_dict(), weights_file)
    
if args['run_evaluation']:
    logger.info("Evaluating model")

    # Predict on the 'val' and 'test' phases.
    predict_dict = model.predict(data.loaders_dict_predict, phases=['val', 'test'])

    # general evaluation: per-row outputs and the aggregated performance table
    output_df_eval = predict_dict["outputs"]
    result_df_eval = predict_dict["performance"]

    logger.info(result_df_eval)

    # Post-process the per-row outputs:
    #   1. add_ranges(...) from train_utils is applied first,
    #   2. 'row_id' is renamed to 'person_id' so the cohort table can be joined,
    #   3. the 'ldlc' column is attached via an inner join on 'person_id',
    #   4. relative_risk comes from train_utils.treat_relative_risk, and
    #      new_risk is pred_probs scaled by that relative risk.
    output_df_eval = (
        train_utils.add_ranges(output_df_eval)
        .rename(columns={'row_id': 'person_id'})
        .merge(data_df.filter(['person_id', 'ldlc']), how='inner', on='person_id')
        .assign(
            relative_risk=train_utils.treat_relative_risk,
            new_risk=lambda x: x.pred_probs * x.relative_risk,
        )
    )

    # Both parquet dumps below share the same writer options.
    parquet_options = dict(index=False, engine="pyarrow")

    # Dump evaluation result to disk
    result_df_eval.to_parquet(
        os.path.join(args['result_path'], "result_df_training_eval.parquet"),
        **parquet_options,
    )

    # The per-row outputs are written only when requested.
    if args.get('save_outputs'):
        output_df_eval.to_parquet(
            os.path.join(args['result_path'], "output_df.parquet"),
            **parquet_options,
        )

    # Run the project's evaluation routine; it returns the logger.
    logger = train_utils.evaluation(output_df_eval, args, config_dict, logger)