### Select patients for CFR model: Split patients in train, val and test sets ###

In [None]:
import os
import numpy as np
import pandas as pd

# Widen the pandas display limits so large metadata frames print in full.
for option_name, option_value in (('display.max_rows', 500),
                                  ('display.max_columns', 50),
                                  ('display.width', 1000)):
    pd.set_option(option_name, option_value)

In [None]:
# Locate and load the parquet file that lists the matched files.
cfr_data_root = os.path.normpath('/mnt/obi0/andreas/data/cfr')
match_view_filename = '210_getStressTest_match365_files_BWH_200131.parquet'
match_view_path = os.path.join(cfr_data_root, 'metadata_200131', match_view_filename)
files_cfr = pd.read_parquet(match_view_path)

# Report how many distinct patients, studies and files the table contains.
for summary_template, column_name in (('Total number of patients {}', 'mrn'),
                                      ('Total number of studies  {}', 'study'),
                                      ('Total number of files    {}', 'filename')):
    print(summary_template.format(len(files_cfr[column_name].unique())))

files_cfr.head()

### Split the patients in train, validate and test sets ###
Each view may have a slightly different patient population distribution, because not all views are present in every study. However, we want the same MRNs in each data set for all views, so that we can directly compare the performance of the algorithm on the same patients. We can expand the data frame above to add the splits.

In [None]:
def patientsplit(patient_list, train_test_split=0.86, train_eval_split=0.90, seed=None):
    """Randomly split patient IDs into disjoint train, eval and test lists.

    The patients are first divided into a train+eval pool and a test set;
    a part of the pool is then set aside for evaluation.

    Args:
        patient_list: Iterable of hashable patient identifiers. Duplicates
            are dropped (first occurrence wins) so that one patient can
            never land in more than one set.
        train_test_split: Fraction of the patients assigned to the
            train+eval pool (rounded down); the rest becomes the test set.
        train_eval_split: Fraction of the train+eval pool kept for training;
            the remainder (rounded up) becomes the eval set.
        seed: Optional integer seed. When given, the split is reproducible
            and independent of the global numpy random state. When None,
            the global ``np.random`` state is used.

    Returns:
        Tuple ``(patient_list_train, patient_list_eval, patient_list_test)``
        of three lists holding the original patient objects, each in the
        order in which they appear in ``patient_list``.

    Raises:
        ValueError: If one of the split fractions is outside [0, 1].
    """
    for frac_name, frac in (('train_test_split', train_test_split),
                            ('train_eval_split', train_eval_split)):
        if not 0.0 <= frac <= 1.0:
            raise ValueError('{} must be between 0 and 1, got {}'.format(frac_name, frac))

    # Drop duplicate IDs while keeping the input order.
    patients = list(dict.fromkeys(patient_list))
    n_patients = len(patients)

    # A private generator keeps seeded runs independent of the global state.
    rng = np.random if seed is None else np.random.RandomState(seed)

    # Rounding to 9 decimals removes floating point noise (e.g. 1 - 0.7 is
    # 0.30000000000000004) before the floor/ceil decides the set sizes.
    n_train_eval = int(np.floor(np.round(train_test_split * n_patients, 9)))
    n_eval = int(np.ceil(np.round((1 - train_eval_split) * n_train_eval, 9)))

    # Sample positions instead of values so that the IDs are never coerced
    # to a common numpy dtype.
    train_eval_idx = rng.permutation(n_patients)[:n_train_eval]
    eval_idx = train_eval_idx[rng.permutation(n_train_eval)[:n_eval]]
    train_eval_members = set(train_eval_idx.tolist())
    eval_members = set(eval_idx.tolist())

    patient_list_train = []
    patient_list_eval = []
    patient_list_test = []
    for position, patient in enumerate(patients):
        if position in eval_members:
            patient_list_eval.append(patient)
        elif position in train_eval_members:
            patient_list_train.append(patient)
        else:
            patient_list_test.append(patient)

    # Sanity checks: both intersections should be empty.
    train_test_intersection = set(patient_list_train).intersection(patient_list_test)
    print('Intersection of patient_list_train and patient_list_test:', train_test_intersection)
    train_eval_intersection = set(patient_list_train).intersection(patient_list_eval)
    print('Intersection of patient_list_train and patient_list_eval:', train_eval_intersection)

    # Show the numbers
    print('total patients:', n_patients)
    print('patients in set:', sum([len(patient_list_train),
                                   len(patient_list_eval),
                                   len(patient_list_test)]))
    print('patients in train:', len(patient_list_train))
    print('patients in eval:', len(patient_list_eval))
    print('patients in test:', len(patient_list_test))

    return patient_list_train, patient_list_eval, patient_list_test

In [None]:
# Shuffle the rows, then collect every distinct MRN and split them.
patient_list = list(files_cfr.sample(frac=1).mrn.unique())
patient_list_train, patient_list_eval, patient_list_test = patientsplit(patient_list)

# Map each data set name to its patient IDs.
patient_split = {}
patient_split['train'] = patient_list_train
patient_split['eval'] = patient_list_eval
patient_split['test'] = patient_list_test

# Report the size of every data set.
for size_label, dset_key in (('Patient IDs in train:', 'train'),
                             ('Patient IDs in eval:', 'eval'),
                             ('Patient IDs in test:', 'test')):
    print(size_label, len(patient_split[dset_key]))

print()

# Each pair of data sets must not share any patient; the printed sets should be empty.
for overlap_label, first_key, second_key in (('contamination train-test:', 'train', 'test'),
                                             ('contamination train-eval:', 'train', 'eval'),
                                             ('contamination eval-test:', 'eval', 'test')):
    shared_patients = set(patient_split[first_key]).intersection(set(patient_split[second_key]))
    print(overlap_label, shared_patients)

In [None]:
# Build one (mrn, dset) lookup frame per data set, then stack them.
split_list = [pd.DataFrame({'mrn': list(dset_mrns),
                            'dset': [dset_name] * len(dset_mrns)})
              for dset_name, dset_mrns in patient_split.items()]
split_df = pd.concat(split_list, ignore_index = True)

# Attach the data set label to every file via its MRN, then shuffle the rows.
files_cfr_merged = files_cfr.merge(right = split_df, on = 'mrn', how = 'left')
files_cfr_shuffled = files_cfr_merged.sample(frac = 1)
files_cfr_dset = files_cfr_shuffled.reset_index(drop = True)

In [None]:
# Preview the first 50 rows of the table that now carries the dset column.
files_cfr_dset.head(50)

In [None]:
# Write the table with the dset column next to the input metadata.
dset_filename = '210_getStressTest_files_dset_BWH_200131.parquet'
dset_output_path = os.path.join(cfr_data_root, 'metadata_200131', dset_filename)
files_cfr_dset.to_parquet(dset_output_path)