# In [1]:
import pandas as pd
import os
import numpy as np
import logging
import sys
import torch
import copy
import configargparse as argparse

from prediction_utils.util import yaml_write
from prediction_utils.pytorch_utils.models import TorchModel
from prediction_utils.pytorch_utils.lagrangian import MultiLagrangianThresholdRateModel
from prediction_utils.pytorch_utils.robustness import GroupDROModel

from prediction_utils.pytorch_utils.metrics import StandardEvaluator, FairOVAEvaluator, CalibrationEvaluator

import git
# Locate the git repository that contains the current directory (parent
# directories are searched too) and make its working tree the process's
# current working directory.
# NOTE(review): presumably this is what lets `import train_utils` below
# resolve no matter where the notebook was launched from — confirm.
repo = git.Repo('.', search_parent_directories=True)
os.chdir(repo.working_dir) 

import train_utils
import yaml

# parser = argparse.ArgumentParser()
# parser.add_argument('--experiment_name', type=str)
# parser.add_argument('--cohort_path', type=str) 
# parser.add_argument('--result_path', type=str)
# parser.add_argument('--logging_path', type=str)
# # parser.add_argument('--base_path', type=str)
# parser.add_argument('--config_id', type=str)
# parser.add_argument('--fold_id', type=str)
# parser.add_argument('--print_debug', type=bool)
# parser.add_argument('--save_outputs', type=bool)
# parser.add_argument('--run_evaluation', type=bool)
# parser.add_argument('--run_evaluation_group_standard', type=bool)
# parser.add_argument('--run_evaluation_group_fair_ova', type=bool)
# parser.add_argument('--save_model_weights', type=bool)
# parser.add_argument('--data_query', type=str)
# parser.add_argument('--base_config_path', type=str)
# parser.add_argument('--config_path', type=str)
# parser.add_argument('--num_epochs', type=int)

# parser.set_defaults(
#     save_outputs=False,
#     run_evaluation=True,
#     run_evaluation_group_standard=True,
#     run_evaluation_group_fair_ova=True,
#     print_debug=True,
#     save_model_weights=False,
#     data_query = '',
#     num_epochs = 0
# )

# args = parser.parse_args()
# args = copy.deepcopy(args.__dict__)


def run_model(args, config_dict):
    """Train one model configuration and, if requested, evaluate it.

    Everything produced by the run is written under ``args['result_path']``:
    ``result_df_training.parquet`` always, ``state_dict.pt`` when
    ``args['save_model_weights']`` is truthy, and, when
    ``args['run_evaluation']`` is truthy, ``result_df_training_eval.parquet``
    plus ``output_df.parquet`` if ``args.get('save_outputs')`` is truthy.

    Args:
        args: dict of run-level options. Keys read directly here are
            'result_path', 'cohort_path', 'data_query',
            'save_model_weights', 'run_evaluation' and (optionally)
            'save_outputs'. The whole dict is also forwarded to the
            ``train_utils`` helpers.
        config_dict: dict of data/model configuration. It is updated in
            place with an 'input_dim' entry before the model is built.

    Returns:
        None.
    """

    def _in_result_dir(filename):
        # Re-read the result path on every call, exactly as each use site did.
        return os.path.join(args['result_path'], filename)

    # ---- initial setup ----
    os.makedirs(args['result_path'], exist_ok=True)
    logger = train_utils.logger_setup(config_dict, args)

    # ---- dataset ----
    data_df = pd.read_csv(args['cohort_path'])

    # An empty query string means "use the full cohort".
    if len(args['data_query']) > 0:
        data_df = data_df.query(args['data_query'])
        data_df = data_df.reset_index(drop=True)

    dataset_keys = ['feature_columns', 'val_fold_id', 'test_fold_id', 'batch_size']
    dataset_kwargs = train_utils.get_dict_subset(config_dict, dataset_keys)
    data = train_utils.Dataset(data_df, deg=2, **dataset_kwargs)

    # The model needs the number of input features, taken from the second
    # dimension of the training feature matrix.
    train_features = data.features_dict_uncensored_scaled['train']
    config_dict.update({'input_dim': train_features.shape[1]})

    logger.info("Result path: {}".format(args['result_path']))

    # ---- training ----
    model, logger = train_utils.model_setup(config_dict, logger, args)

    training_output = model.train(loaders=data.loaders_dict)
    result_df = training_output['performance']
    result_df.to_parquet(
        _in_result_dir("result_df_training.parquet"),
        index=False,
        engine="pyarrow",
    )

    if args['save_model_weights']:
        torch.save(model.model.state_dict(), _in_result_dir("state_dict.pt"))

    # ---- evaluation (optional) ----
    if not args['run_evaluation']:
        return

    logger.info("Evaluating model")

    predict_dict = model.predict(data.loaders_dict_predict, phases=['val', 'test'])
    output_df_eval = predict_dict["outputs"]
    result_df_eval = predict_dict["performance"]

    logger.info(result_df_eval)

    # Attach 'ldlc' from the cohort to every prediction row, then derive the
    # relative-risk columns from the merged frame.
    ldlc_lookup = data_df.filter(['person_id', 'ldlc'])
    scored = train_utils.add_ranges(output_df_eval)
    scored = scored.rename(columns={'row_id': 'person_id'})
    scored = scored.merge(ldlc_lookup, how='inner', on='person_id')
    scored = scored.assign(relative_risk=train_utils.treat_relative_risk)
    output_df_eval = scored.assign(new_risk=scored.pred_probs * scored.relative_risk)

    # Dump evaluation result to disk
    result_df_eval.to_parquet(
        _in_result_dir("result_df_training_eval.parquet"),
        index=False,
        engine="pyarrow",
    )

    if args.get('save_outputs'):
        output_df_eval.to_parquet(
            _in_result_dir("output_df.parquet"),
            index=False,
            engine="pyarrow",
        )

    train_utils.evaluation(output_df_eval, args, config_dict, logger)

        
# ---------------------------------------------------------------------------
# Experiment driver: for each of ten regularisation strengths and each of
# folds 1..10, configure a run and call run_model().
# ---------------------------------------------------------------------------
EXPERIMENT_NAME = 'eq_oddsconstr'

BASE_PATH = '/labs/shahlab/projects/agataf/data/cohorts/pooled_cohorts'

# Run-level options passed to run_model(); 'result_path' is filled in per run.
args = {'experiment_name': EXPERIMENT_NAME,
        'cohort_path': '/labs/shahlab/projects/agataf/data/pooled_cohorts/cohort_extraction/all_cohorts.csv',
        'base_path': BASE_PATH,
        'print_debug': True,
        'save_outputs': True,
        'run_evaluation_group_standard': True,
        'run_evaluation_group_fair_ova': True,
        'save_model_weights': True,
        'run_evaluation': True,
        'split_gender': False,
        'data_query': ''
       }


BASE_CONFIG_PATH = os.path.join(BASE_PATH, 'experiments', 'basic_config.yaml')
# Open the base config in a context manager so the file handle is closed as
# soon as it has been parsed instead of being left to the garbage collector.
with open(BASE_CONFIG_PATH) as base_config_file:
    config_dict = yaml.load(base_config_file, Loader=yaml.FullLoader)

# Experiment-specific overrides applied on top of the base config.
# NOTE(review): 'num_epochs' is set to 10 here but is overwritten with 100
# before every run in the loop below, so 100 is the value actually used —
# confirm which one is intended.
update_dict = {
    "threshold_mode": "conditional",
    "thresholds": [0.075, 0.2],
    "surrogate_scale": 1.0,
    'logging_metrics': ['auc', 'auprc', 'brier', 'loss_bce'],
    'data_query': '',
    'group_objective_type': 'multiThreshold',
    'evaluate_by_group': True,
    'sparse': False,
    'output_dim': 2,
    "num_groups": 4,
    'num_hidden': 0,
    'weighted_loss': True,
    'num_epochs': 10
}

config_dict.update(update_dict)

# Pair each two-digit config id with one of ten log-spaced regularisation
# strengths between 1e-3 and 1e-1.
configs = zip(['00', '01', '02', '03', '04', '05', '06', '07', '08', '09'], np.geomspace(1e-3,1e-1,num=10))

for config_id, lambda_reg in configs:

    # Directory shared by all folds of this configuration; it is named
    # '<config_id>.yaml'. Built once per config, not once per fold.
    CONFIG_RESULT_DIR = os.path.join(args['base_path'], 'experiments', args['experiment_name'], 'performance',
                                     '.'.join((str(config_id), 'yaml')))

    for fold_id in range(1,11):

        RESULT_PATH = os.path.join(CONFIG_RESULT_DIR, str(fold_id))
        LOGGING_PATH = os.path.join(RESULT_PATH, 'training_log.log')

        # config_dict and args are shared across runs and updated in place, so
        # every per-run key is re-set on each iteration.
        config_dict.update({'val_fold_id': str(fold_id), 
                            'num_epochs': 100, 
                            'lambda_group_regularization': lambda_reg,
                            'logging_path': LOGGING_PATH})


        args.update({'result_path': RESULT_PATH})

        run_model(args, config_dict)
