In [1]:
import pandas as pd
import os
import numpy as np
import matplotlib.pyplot as plt
import logging
import sys
import torch
import configargparse as argparse

from prediction_utils.util import yaml_write
from prediction_utils.pytorch_utils.models import TorchModel
from prediction_utils.pytorch_utils.lagrangian import MultiLagrangianThresholdRateModel
from prediction_utils.pytorch_utils.robustness import GroupDROModel

from prediction_utils.pytorch_utils.metrics import StandardEvaluator, FairOVAEvaluator, CalibrationEvaluator

import git
# Locate the enclosing git repository (searching parent directories) and make
# its working-tree root the current working directory.
# NOTE(review): this changes the process-wide cwd for every later cell;
# presumably it is what lets the repo-local `train_utils` import below resolve — confirm.
repo = git.Repo('.', search_parent_directories=True)
os.chdir(repo.working_dir) 

import train_utils
import yaml

In [11]:
# --- Run identifiers: which experiment / hyper-parameter config / fold to run ---
EXPERIMENT_NAME = 'erm'
config_id = '00'
fold_id = '0'
BASE_PATH = '/labs/shahlab/projects/agataf/data/cohorts/pooled_cohorts'

# Options handed to the train_utils helpers. The identifiers above are reused
# here so each value is written down exactly once.
args = dict(
    experiment_name=EXPERIMENT_NAME,
    cohort_path='/labs/shahlab/projects/agataf/data/pooled_cohorts/cohort_extraction/all_cohorts.csv',
    base_path=BASE_PATH,
    config_id=config_id,
    fold_id=fold_id,
    print_debug=True,
    save_outputs=True,
    run_evaluation_group_standard=True,
    run_evaluation_group_fair_ova=True,
    save_model_weights=True,
    run_evaluation=True,
    # Optional keys, currently disabled:
    # split_gender=True,
    # data_query='gender_male==0',
)

# Every artefact of this experiment lives under <base_path>/experiments/<experiment_name>.
_experiment_dir = os.path.join(args['base_path'], 'experiments', args['experiment_name'])
_config_filename = args['config_id'] + '.yaml'

# Settings shared by all configs of the experiment, and the per-config overrides.
BASE_CONFIG_PATH = os.path.join(_experiment_dir, 'basic_config.yaml')
CONFIG_PATH = os.path.join(_experiment_dir, 'config', _config_filename)

# Outputs go to performance/<config_id>.yaml/<fold_id>; note that the directory
# component really is named "<config_id>.yaml".
RESULT_PATH = os.path.join(_experiment_dir, 'performance', _config_filename, args['fold_id'])
LOGGING_PATH = os.path.join(RESULT_PATH, 'training_log.log')

args['result_path'] = RESULT_PATH
# following https://github.com/som-shahlab/group_robustness_fairness/blob/main/group_robustness_fairness/scripts/tune_baseline_model_starr.sh

In [31]:
##### INITIAL SETUP #####

# Make sure the output directory exists before anything is written into it.
os.makedirs(RESULT_PATH, exist_ok=True)

# Read both YAML files through context managers so the file handles are closed
# deterministically. An empty YAML document parses to None; treat that as
# "no settings" so the merge below does not fail with a TypeError.
# NOTE(review): FullLoader is kept as-is; if the configs only contain plain
# scalars/lists/mappings, yaml.safe_load would be the safer choice.
with open(CONFIG_PATH) as config_file:
    model_params = yaml.load(config_file, Loader=yaml.FullLoader) or {}
with open(BASE_CONFIG_PATH) as base_config_file:
    config_dict = yaml.load(base_config_file, Loader=yaml.FullLoader) or {}

# Per-config hyper-parameters take precedence over the shared base settings.
config_dict.update(model_params)
config_dict.update({'logging_path': LOGGING_PATH})

## remove ##
# TODO(review): debugging override — this replaces whatever `num_epochs` the
# YAML configs specify. Delete this line before running a real experiment.
config_dict['num_epochs'] = 5

logger = train_utils.logger_setup(config_dict, args)

##### DATASET #####
data_df = pd.read_csv(args['cohort_path'])

# `data_query` is an optional key (it is commented out in the args cell), so
# look it up with .get(): a missing, None or empty query means "use every row".
# Indexing args['data_query'] directly raised KeyError on a fresh kernel.
if args.get('data_query'):
    data_df = (data_df
               .query(args['data_query'])
               .reset_index(drop=True)
              )

# Forward only the dataset-related settings from the merged config.
# NOTE(review): deg=2 is a hard-coded argument of train_utils.Dataset;
# presumably a polynomial feature degree — confirm and move it to the config.
data_args = train_utils.get_dict_subset(config_dict, ['feature_columns', 'val_fold_id', 'test_fold_id', 'batch_size'])
data = train_utils.Dataset(data_df, deg=2, **data_args)

# add input dim to dictionary (number of columns of the training feature matrix)
config_dict.update({'input_dim': data.features_dict_uncensored_scaled['train'].shape[1]})

# log
logger.info("Result path: {}".format(args['result_path']))

# model_setup returns the model together with the (possibly updated) logger.
model, logger = train_utils.model_setup(config_dict, logger, args)

##### TRAINING #####
# Fit the model on the prepared loaders and keep the 'performance' entry of
# the training output.
result_df = model.train(loaders=data.loaders_dict)['performance']

# Persist the training performance next to the other outputs of this run.
result_df.to_parquet(
    os.path.join(RESULT_PATH, "result_df_training.parquet"),
    index=False,
    engine="pyarrow",
)

# Optionally store the weights of the trained network.
if args['save_model_weights']:
    torch.save(
        model.model.state_dict(),
        os.path.join(RESULT_PATH, "state_dict.pt"),
    )

##### EVALUATION #####
# Optionally predict on the validation and test phases and pass the
# predictions to the evaluation helper, which returns the logger.
if args['run_evaluation']:
    logger.info("Evaluating model")

    predict_dict = model.predict(
        data.loaders_dict_predict,
        phases=['val', 'test'],
    )

    logger = train_utils.evaluation(predict_dict, args, config_dict, logger)