In [1]:
import pandas as pd
import numpy as np
import os
import matplotlib.pyplot as plt
from tqdm import tqdm
import json
from sklearn.model_selection import KFold,StratifiedKFold,GroupKFold
import seaborn as sns
import time

In [26]:
def get_subset(patients: list, records: np.ndarray) -> list:
    """Return the records that belong to any of the given patients.

    Each record is a relative path of the form ``'<patient_id>/<file>'``; the
    patient ID is the text before the first ``'/'`` and is compared as an int.

    Args:
        patients: Patient IDs to keep. Ints and numeric strings are both
            accepted (``os.listdir`` yields strings, while the hand-written
            split lists in this notebook use ints).
        records: Record paths, as a NumPy array or any iterable of str.

    Returns:
        List of the matching record paths (plain ``str``), in input order.
    """
    # Normalize once to a set of ints: O(1) membership and str/int agnostic.
    wanted = {int(patient) for patient in patients}
    # np.asarray(...).tolist() accepts both ndarrays and plain lists and always
    # yields Python str elements (JSON-serializable).
    return [
        record
        for record in np.asarray(records).tolist()
        if int(record.split('/')[0]) in wanted
    ]

# Get list of patients

In [19]:
# Root of the per-patient processed data and destination of the split tables.
DATA_PATH = '../data/JET_data/processed_data/'
OUT_PATH = '../data/split_tables/'

# Patient folders are the entries of DATA_PATH whose names contain no '.';
# entries with an extension are skipped. os.listdir returns entries in an
# arbitrary order, so sort to make the downstream record order reproducible.
patients = sorted(name for name in os.listdir(DATA_PATH) if '.' not in name)

len(patients)

13

# Get all records

In [21]:
# Collect every '.npy' record as '<patient>/<file>' plus a parallel array with
# the owning patient ID (used as the group label for GroupKFold below).
records = []
record_patient = []

for patient in tqdm(patients):
    # sorted(): os.listdir order is arbitrary; endswith(): match the extension
    # only, not any file name that merely contains '.npy' somewhere.
    patient_records = [
        f'{patient}/{name}'
        for name in sorted(os.listdir(os.path.join(DATA_PATH, patient)))
        if name.endswith('.npy')
    ]
    records += patient_records
    record_patient += [patient] * len(patient_records)

records = np.array(records)
record_patient = np.array(record_patient)

100%|██████████| 13/13 [00:00<00:00, 495.03it/s]


In [30]:
# NOTE(review): 720 records per hour is assumed here (i.e. one record every
# 5 s) -- confirm against the preprocessing that produced the .npy files.
RECORDS_PER_HOUR = 720

print('Total number of records: ',len(records))
print('Total number of hours: ',len(records)//RECORDS_PER_HOUR)
print('Total number of patients: ',len(np.unique(record_patient)))

Total number of records:  16892
Total number of hours:  23
Total number of patients:  13


# Current (fixed) patient-level split tables

In [31]:
# Fixed patient-level split: each patient must belong to exactly one of the
# three sets so that no patient's records leak between train, val and test.
test_patients = [2573187, 2740581, 2337416, 5006378, 2726762, 2799117, 5001360]

train_patients = [2717090, 5014987, 2718860, 5025542]

val_patients = [2240784, 5025547]

# Guard against accidental leakage when these lists are edited by hand.
assert not (set(test_patients) & set(train_patients)), 'test/train patients overlap'
assert not (set(test_patients) & set(val_patients)), 'test/val patients overlap'
assert not (set(train_patients) & set(val_patients)), 'train/val patients overlap'

In [32]:
# Materialize the record paths for each part of the fixed patient partition.
records_train, records_val, records_test = (
    get_subset(patients=split_patients, records=records)
    for split_patients in (train_patients, val_patients, test_patients)
)

In [40]:
# Write the fixed split as two JSON tables: one holding the test records and
# one holding the train/validation records (fold "0" of the fixed split).
os.makedirs(OUT_PATH, exist_ok=True)  # open(..., 'w') fails if the folder is missing

fixed_split_tables = {
    'test_split_table_.json': {'test': records_test},
    '0_fold_split_table_.json': {'train': records_train, 'val': records_val},
}

for file_name, split in fixed_split_tables.items():
    with open(os.path.join(OUT_PATH, file_name), 'w') as outfile:
        json.dump(split, outfile)


# Grouped K-fold split by patient (for when more data is available)

In [33]:

# Hold out one GroupKFold(3) fold as the test set. Grouping by patient keeps
# all of a patient's records on the same side of the split. Only the first
# fold is used, so take it with next() instead of looping and breaking.
kf = GroupKFold(3)

train_index, test_index = next(kf.split(records, groups=record_patient))

# Separate name so the fixed-split `records_test` from above is not clobbered
# (re-running the fixed-split writer cell would otherwise dump these records).
records_test_cv = records[test_index]
patients_test = record_patient[test_index]

split = {
    'test': records_test_cv.tolist(),
}

with open(f'{OUT_PATH}test_split_table.json', 'w') as outfile:
    json.dump(split, outfile)

print(f'Number of patients in test set: {np.unique(patients_test).shape[0]}')
print(f'Patients: {np.unique(patients_test)}')

Number of patients in test set: 4
Patients: ['2337416' '5006378' '5014987' '5025547']


In [34]:
# Restrict the working set to the records outside the GroupKFold test hold-out.
# Keep them as ndarrays: converting to lists breaks `.tolist()`/fancy indexing
# in other cells if they are re-run afterwards.
# NOTE: this rebinds `records`/`record_patient`, so the cell is NOT idempotent;
# re-running it indexes the already-reduced arrays with stale indices.
records = records[train_index]
record_patient = record_patient[train_index]

In [38]:
# Split the remaining (non-test) records into train/validation folds grouped by
# patient, and write one JSON table per fold into OUT_PATH.
# asarray: the previous cell may leave lists or ndarrays here.
records = np.asarray(records)
record_patient = np.asarray(record_patient)

kf = GroupKFold(2)

# Fold-scoped names avoid clobbering the fixed-split `records_train`/`records_val`
# and the `train_index` consumed by the previous cell.
for fold, (fold_train_index, fold_val_index) in enumerate(kf.split(records, groups=record_patient)):
    fold_train_records = records[fold_train_index]
    fold_val_records = records[fold_val_index]
    fold_train_patients = record_patient[fold_train_index]
    fold_val_patients = record_patient[fold_val_index]

    split = {
        'train': fold_train_records.tolist(),
        'val': fold_val_records.tolist(),
    }

    # Same folder as the other split tables; avoid a second hard-coded copy.
    with open(f'{OUT_PATH}{fold}_split_table.json', 'w') as outfile:
        json.dump(split, outfile)

    print(f'Fold: {fold}')
    print(f'Number of patients in train set: {np.unique(fold_train_patients).shape[0]}')
    print(f'Patients, train: {np.unique(fold_train_patients)}')
    print(f'Number of patients in validation set: {np.unique(fold_val_patients).shape[0]}')
    print(f'Patients, validation: {np.unique(fold_val_patients)}')
    print('\n')


Fold: 0
Number of patients in train set: 5
Patients, train: ['2240784' '2573187' '2717090' '2718860' '2740581']
Number of patients in validation set: 4
Patients, test: ['2726762' '2799117' '5001360' '5025542']


Fold: 1
Number of patients in train set: 4
Patients, train: ['2726762' '2799117' '5001360' '5025542']
Number of patients in validation set: 5
Patients, test: ['2240784' '2573187' '2717090' '2718860' '2740581']


