In [12]:
import numpy as np
import pandas as pd
import os
import git
import argparse

def patient_split(df, test_frac=0.1, eval_frac=0.1, nfold=10, seed=657):
    """Randomly label every row of ``df`` as test, eval, or a numbered training fold.

    Row positions are shuffled with ``seed``. The first ``floor(test_frac * n)``
    shuffled rows are labelled ``"test"``, the next ``floor(eval_frac * n)`` are
    labelled ``"eval"``, and the remaining rows are dealt round-robin into the
    folds ``"1"`` ... ``str(nfold)``.

    Args:
        df: DataFrame with one row per sample to split.
        test_frac: Fraction of rows sent to the test fold; must be in (0, 1).
        eval_frac: Fraction of rows sent to the eval fold; 0 means no eval fold.
        nfold: Number of training folds (at least 1).
        seed: Random state handed to the pandas sampler.

    Returns:
        A new DataFrame holding the rows of ``df`` in their original order, with
        a fresh ``RangeIndex`` and an extra string column ``fold_id``. The
        output always has exactly as many rows as ``df``.

    Raises:
        AssertionError: If ``test_frac`` is outside (0, 1), ``eval_frac`` is
            negative, ``nfold`` is smaller than 1, or the two fractions leave a
            negative number of training rows.
    """
    assert 0.0 < test_frac < 1.0
    # eval_frac == 0 is a legitimate "no eval fold" setting, so only reject negatives.
    assert eval_frac >= 0.0, "eval_frac must be non-negative"
    assert nfold >= 1, "nfold must be at least 1"

    # Record the number of samples in each split
    num_rows = df.shape[0]
    num_test = int(np.floor(test_frac * num_rows))
    num_eval = int(np.floor(eval_frac * num_rows))
    num_train = num_rows - (num_test + num_eval)
    assert num_train >= 0, "test_frac + eval_frac must not exceed 1"

    # Shuffle row *positions* rather than the rows themselves. The sampler draws
    # positions from the length and the seed only, so this is the same permutation
    # that df.sample(frac=1, random_state=seed) would apply to the rows.
    shuffled = (
        pd.Series(np.arange(num_rows))
        .sample(frac=1, random_state=seed)
        .to_numpy()
    )

    # Fill the labels by position. This replaces a merge on every shared column,
    # which duplicated rows whenever the input contained identical records.
    fold_id = np.empty(num_rows, dtype=object)
    fold_id[shuffled[:num_test]] = "test"
    fold_id[shuffled[num_test:num_test + num_eval]] = "eval"
    # Round-robin fold numbers 1..nfold over the remaining shuffled rows.
    train_folds = (np.arange(num_train) % nfold) + 1
    fold_id[shuffled[num_test + num_eval:]] = train_folds.astype(str)

    return df.reset_index(drop=True).assign(fold_id=fold_id)

# parser = argparse.ArgumentParser()
# parser.add_argument("--cohort_path", type=str, help="path where input cohorts are stored", required=False,
#                    default='/labs/shahlab/projects/agataf/data/pooled_cohorts/cohort_extraction')
# parser.add_argument("--output_path", type=str, help="path where aggregated data should be stored", required=False,
#                    default='/labs/shahlab/projects/agataf/data/pooled_cohorts/cohort_extraction/all_cohorts.csv')
# parser.add_argument("--test_frac", type=float, help="fraction of data that should go into test", required=False,
#                    default=0.1)

# args = parser.parse_args()

# Directory holding one CSV per cohort, and the destination of the pooled table.
cohort_path = '/labs/shahlab/projects/agataf/data/pooled_cohorts/cohort_extraction'
output_path = '/labs/shahlab/projects/agataf/data/pooled_cohorts/cohort_extraction/all_cohorts_witheval.csv'

# Fractions of each stratum labelled "test" and "eval" by patient_split.
test_frac = 0.2
eval_frac = 0

# Read every cohort file (<cohort_path>/<name>.csv) into its own frame.
cohort_frames = [
    pd.read_csv(os.path.join(cohort_path, '.'.join((cohort_name, 'csv'))))
    for cohort_name in ('mesa', 'fhs_os', 'chs', 'aric', 'jhs', 'cardia')
]

# Stack the cohorts and expose the 10-year outcome under generic column names.
df = pd.concat(cohort_frames)
df = df.assign(
    event_time=df['event_time_10yr'],
    event_indicator=df['ascvd_10yr'],
)

# Stratified splitting: each combination of the grouping columns is split on its own.
# NOTE(review): groupby drops rows with missing values in the grouping columns by
# default - confirm that none of these columns contain NaN.
result = {
    stratum: patient_split(members, test_frac=test_frac, eval_frac=eval_frac)
    for stratum, members in df.groupby(
        ['censored_10yr', 'ascvd_10yr', 'race_black', 'gender_male', 'study']
    )
}

# Re-assemble the strata; the fresh 0..n-1 row number becomes the person_id column.
data_df = (
    pd.concat(result, ignore_index=True)
    .reset_index()
    .rename(columns={'index': 'person_id'})
)

#data_df.to_csv(os.path.join(args.output_path), index = False)


In [None]:
# Write the fold-annotated table to disk without the DataFrame index.
# The argparse `args` namespace is commented out in the first cell, so the
# notebook-level `output_path` defined there is used instead of `args.output_path`.
data_df.to_csv(output_path, index=False)

In [15]:
# Display the rows of data_df whose fold_id is 'test'.
data_df.query("fold_id=='test'")

Unnamed: 0,person_id,cohort_idx,cohort_pid,age,race_black,gender_male,grp,hdlc,ldlc,trigly,...,unrxsbp,rxsbp,study,ascvd_10yr,censored_10yr,event_time_10yr,bmi,event_time,event_indicator,fold_id
3,3,6,C000011,51.0,0.0,0.0,2,62.000,170.600,122.0,...,131.0,0.0,ARIC,False,False,10.000000,24.532728,10.000000,False,test
8,8,16,C000024,48.0,0.0,0.0,2,70.299,75.101,93.0,...,105.0,0.0,ARIC,False,False,10.000000,23.411924,10.000000,False,test
11,11,26,C000037,50.0,0.0,0.0,2,49.113,155.087,299.0,...,99.0,0.0,ARIC,False,False,10.000000,20.341527,10.000000,False,test
21,21,46,C000060,50.0,0.0,0.0,2,43.335,170.265,197.0,...,112.0,0.0,ARIC,False,False,10.000000,31.603212,10.000000,False,test
27,27,62,C000078,51.0,0.0,0.0,2,59.706,125.494,84.0,...,129.0,0.0,ARIC,False,False,10.000000,27.052350,10.000000,False,test
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
25603,25603,3561,7014325,79.0,1.0,1.0,3,31.000,135.000,115.0,...,0.0,178.0,MESA,False,True,1.303217,30.166230,1.303217,False,test
25605,25605,3655,7015690,78.0,1.0,1.0,3,53.000,151.000,95.0,...,135.5,0.0,MESA,False,True,0.484600,32.980066,0.484600,False,test
25606,25606,3667,7015860,74.0,1.0,1.0,3,41.000,97.000,62.0,...,0.0,141.0,MESA,False,True,8.889802,29.879926,8.889802,False,test
25613,25613,4198,8010927,53.0,1.0,1.0,3,61.000,133.000,50.0,...,151.0,0.0,MESA,False,True,9.111567,25.453563,9.111567,False,test


In [14]:
# Display the rows of data_df_eval whose fold_id is either 'test' or 'eval'.
# NOTE(review): within the visible cells, data_df_eval is only assigned in a cell
# placed further down (execution count In [6]), so on Restart & Run All this cell
# would fail with NameError - consider moving that assignment above this cell.
data_df_eval.query("(fold_id=='test') | fold_id=='eval'")

Unnamed: 0,person_id,cohort_idx,cohort_pid,age,race_black,gender_male,grp,hdlc,ldlc,trigly,...,unrxsbp,rxsbp,study,ascvd_10yr,censored_10yr,event_time_10yr,bmi,event_time,event_indicator,fold_id
3,3,6,C000011,51.0,0.0,0.0,2,62.000,170.600,122.0,...,131.0,0.0,ARIC,False,False,10.000000,24.532728,10.000000,False,test
8,8,16,C000024,48.0,0.0,0.0,2,70.299,75.101,93.0,...,105.0,0.0,ARIC,False,False,10.000000,23.411924,10.000000,False,eval
11,11,26,C000037,50.0,0.0,0.0,2,49.113,155.087,299.0,...,99.0,0.0,ARIC,False,False,10.000000,20.341527,10.000000,False,eval
21,21,46,C000060,50.0,0.0,0.0,2,43.335,170.265,197.0,...,112.0,0.0,ARIC,False,False,10.000000,31.603212,10.000000,False,test
27,27,62,C000078,51.0,0.0,0.0,2,59.706,125.494,84.0,...,129.0,0.0,ARIC,False,False,10.000000,27.052350,10.000000,False,eval
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
25603,25603,3561,7014325,79.0,1.0,1.0,3,31.000,135.000,115.0,...,0.0,178.0,MESA,False,True,1.303217,30.166230,1.303217,False,eval
25605,25605,3655,7015690,78.0,1.0,1.0,3,53.000,151.000,95.0,...,135.5,0.0,MESA,False,True,0.484600,32.980066,0.484600,False,eval
25606,25606,3667,7015860,74.0,1.0,1.0,3,41.000,97.000,62.0,...,0.0,141.0,MESA,False,True,8.889802,29.879926,8.889802,False,eval
25613,25613,4198,8010927,53.0,1.0,1.0,3,61.000,133.000,50.0,...,151.0,0.0,MESA,False,True,9.111567,25.453563,9.111567,False,test


In [6]:
# Keep a copy of data_df under a separate name.
# NOTE(review): this cell's execution count (In [6]) precedes the split cell's
# (In [12]), and the displayed data_df_eval output contains 'eval' rows although
# eval_frac is currently 0 - presumably this captured an earlier run with a
# non-zero eval_frac; confirm and make that dependency explicit.
data_df_eval = data_df.copy()