In [1]:
import pandas as pd
import numpy as np
from sklearn.model_selection import StratifiedKFold

In [2]:
##
## Config

# cross-validation / reproducibility settings
rng_seed = 42
k_fold_splits = 6

# input CSV paths
frame = '/zfs/wficai/pda/model_data/20221028_frame.csv'
video = '/zfs/wficai/pda/model_data/20221028_video.csv'
study = '/zfs/wficai/pda/model_data/20221028_study.csv'
patient_study = '/zfs/wficai/pda/model_data/20221028_patient_study.csv'
patient = '/zfs/wficai/pda/model_data/20221028_patient.csv'

# destination of the frame-level table written at the end of the notebook
output_csv = '/zfs/wficai/pda/model_data/pda_train_val_test.csv'

# read every input table into its own DataFrame (same order as the paths above)
df_frame, df_video, df_study, df_patient_study, df_patient = (
    pd.read_csv(csv_path)
    for csv_path in (frame, video, study, patient_study, patient)
)

In [3]:
# split data: give every patient a fold id in [0, k_fold_splits), stratified
# on patient_type
df_patient['cv_split'] = None

# Resolve the column position by name. A positional -1 only points at
# cv_split the first time this cell runs; once later cells have appended
# more columns, re-running would silently overwrite the wrong column.
cv_split_pos = df_patient.columns.get_loc('cv_split')

cv = StratifiedKFold(n_splits=k_fold_splits, shuffle=True, random_state=rng_seed)
for ix, (_, test_ix) in enumerate(cv.split(X = df_patient, y = df_patient.patient_type)):
    # the held-out indices of fold `ix` define which patients belong to it
    df_patient.iloc[test_ix, cv_split_pos] = ix

In [4]:
# every fold defaults to TRAIN; the last fold becomes TEST and the one
# before it VAL
split_map = dict.fromkeys(range(k_fold_splits), 'TRAIN')
split_map.update({k_fold_splits - 1: 'TEST', k_fold_splits - 2: 'VAL'})

# translate each patient's fold id into its split label
df_patient['Split'] = df_patient['cv_split'].map(split_map)

In [5]:
# sanity check: number of patients assigned to each fold
df_patient['cv_split'].value_counts()

2    11
4    11
0    11
3    11
5    11
1    11
Name: cv_split, dtype: int64

In [6]:
# sanity check: number of patients in TRAIN / VAL / TEST
df_patient['Split'].value_counts()

TRAIN    44
VAL      11
TEST     11
Name: Split, dtype: int64

### Design subsamples for testing sample efficiency

In [7]:
# use linearly spaced subsample sizes (np.linspace), expressed as a number of
# training patients per subset
start = 5
# the largest subset is the whole training split, so take its size from the
# data instead of hard-coding it
stop = int((df_patient['Split'] == 'TRAIN').sum())
num_steps = 6
# astype(int) truncates the evenly spaced values to whole patient counts
subset_nums = np.linspace(start, stop, num=num_steps).astype(int)
subset_nums

array([ 5, 12, 20, 28, 36, 44])

In [8]:
# extract the patients of each split; explicit copies so the column
# assignments below never write into a view of df_patient
df_pat_train, df_pat_val, df_pat_test = [
    df_patient[df_patient['Split'] == split].copy()
    for split in ('TRAIN', 'VAL', 'TEST')
]

# shuffle the train set; seeded so the subsets are identical on every run
df_pat_train = df_pat_train.sample(frac=1, replace=False, random_state=rng_seed)

# form the new sample columns
# Subsets are nested: the first `sn` shuffled training patients belong to
# subset `sn`, so every smaller subset is contained in the larger ones.
train_position = np.arange(len(df_pat_train))
for sn in subset_nums:
    column = f'patient_sample_{sn}'
    # assign by name with a positional mask; this stays correct when the
    # cell is re-run and the column already exists somewhere in the frame
    df_pat_train[column] = train_position < sn
    # val/test patients are always included, whatever the subset size
    df_pat_val[column] = True
    df_pat_test[column] = True

# recombine into single patient table
# (no frame-wide fillna needed: the sample columns are fully populated above,
# and missing values in unrelated columns are no longer turned into True)
df_patient = pd.concat([df_pat_train, df_pat_val, df_pat_test])

In [14]:
# preview the first rows of the patient table, including the
# patient_sample_* membership columns
df_patient.head(20)

Unnamed: 0,patient_id,patient_type,num_studies,cv_split,Split,patient_sample_5,patient_sample_12,patient_sample_20,patient_sample_28,patient_sample_36,patient_sample_44
53,d89a49b21f3a94ee,pda,1,1,TRAIN,True,True,True,True,True,True
57,ded498e8adf7852f,pda,1,2,TRAIN,True,True,True,True,True,True
54,d9d3ed4a1ab8b062,nopda,1,0,TRAIN,True,True,True,True,True,True
55,db6c7756ec256c5d,nopda,3,3,TRAIN,True,True,True,True,True,True
64,f87e8f22175b90d2,nopda,1,1,TRAIN,True,True,True,True,True,True
36,7b0bda0dcd4f6c57,nopda,1,2,TRAIN,False,True,True,True,True,True
58,e54c49ddb1688524,pda,1,3,TRAIN,False,True,True,True,True,True
50,bc469a5717698f88,pda,1,3,TRAIN,False,True,True,True,True,True
19,4725d50f8dacb4da,nopda,3,3,TRAIN,False,True,True,True,True,True
38,7ea606265b1074dc,nopda,1,3,TRAIN,False,True,True,True,True,True


In [10]:
# build the video-level table: patient -> patient/study link -> study -> video
# patient_type is removed from the patient table before the first merge; the
# later merges key on the patient_type column carried by df_patient_study
df_vid = (
    df_patient
    .drop(columns='patient_type')
    .merge(df_patient_study, on='patient_id')
    .merge(df_study, on=['patient_type', 'study'])
    .merge(df_video, on=['patient_type', 'study'])
)

df_vid.shape

(5086, 21)

In [11]:
# number of videos by subset
# the mode/view filter does not depend on the subset, so apply it once
# instead of re-running the query on every loop iteration
dfq = df_vid.query("mode != '2d' and view != 'nonPDAView'")
for sn in subset_nums:
    column = f'patient_sample_{sn}'
    # summing the boolean membership column counts the selected videos per split
    print(dfq.groupby('Split')[column].sum())

Split
TEST     272
TRAIN     46
VAL      118
Name: patient_sample_5, dtype: int64
Split
TEST     272
TRAIN    143
VAL      118
Name: patient_sample_12, dtype: int64
Split
TEST     272
TRAIN    258
VAL      118
Name: patient_sample_20, dtype: int64
Split
TEST     272
TRAIN    465
VAL      118
Name: patient_sample_28, dtype: int64
Split
TEST     272
TRAIN    621
VAL      118
Name: patient_sample_36, dtype: int64
Split
TEST     272
TRAIN    755
VAL      118
Name: patient_sample_44, dtype: int64


In [12]:
# expand the video-level table to one row per frame, then remove the
# suffixed count columns produced by the merges and the diagnosis column
df = (
    df_vid
    .merge(df_frame, on=['patient_type', 'external_id'])
    .drop(columns=['num_videos_x', 'num_videos_y', 'num_frames_x', 'num_frames_y', 'diagnosis'])
)
df

Unnamed: 0,patient_id,num_studies,cv_split,Split,patient_sample_5,patient_sample_12,patient_sample_20,patient_sample_28,patient_sample_36,patient_sample_44,patient_type,study,external_id,view,mode,mp4_path,png_path
0,d89a49b21f3a94ee,1,1,TRAIN,True,True,True,True,True,True,pda,study47,study47_dicom17,pdaRelatedView,2d,/zfs/wficai/pda/batch_1/PDA_Batch_1/Superior V...,/zfs/wficai/pda/model_data/20221028/pda_study4...
1,d89a49b21f3a94ee,1,1,TRAIN,True,True,True,True,True,True,pda,study47,study47_dicom17,pdaRelatedView,2d,/zfs/wficai/pda/batch_1/PDA_Batch_1/Superior V...,/zfs/wficai/pda/model_data/20221028/pda_study4...
2,d89a49b21f3a94ee,1,1,TRAIN,True,True,True,True,True,True,pda,study47,study47_dicom17,pdaRelatedView,2d,/zfs/wficai/pda/batch_1/PDA_Batch_1/Superior V...,/zfs/wficai/pda/model_data/20221028/pda_study4...
3,d89a49b21f3a94ee,1,1,TRAIN,True,True,True,True,True,True,pda,study47,study47_dicom17,pdaRelatedView,2d,/zfs/wficai/pda/batch_1/PDA_Batch_1/Superior V...,/zfs/wficai/pda/model_data/20221028/pda_study4...
4,d89a49b21f3a94ee,1,1,TRAIN,True,True,True,True,True,True,pda,study47,study47_dicom17,pdaRelatedView,2d,/zfs/wficai/pda/batch_1/PDA_Batch_1/Superior V...,/zfs/wficai/pda/model_data/20221028/pda_study4...
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
471918,efa3614c94506474,1,5,TEST,True,True,True,True,True,True,nopda,study55,study55_dicom122,nonPDAView,color,/zfs/wficai/pda/batch_1/Non-PDA_Batch_1/Superi...,/zfs/wficai/pda/model_data/20221028/nopda_stud...
471919,efa3614c94506474,1,5,TEST,True,True,True,True,True,True,nopda,study55,study55_dicom122,nonPDAView,color,/zfs/wficai/pda/batch_1/Non-PDA_Batch_1/Superi...,/zfs/wficai/pda/model_data/20221028/nopda_stud...
471920,efa3614c94506474,1,5,TEST,True,True,True,True,True,True,nopda,study55,study55_dicom122,nonPDAView,color,/zfs/wficai/pda/batch_1/Non-PDA_Batch_1/Superi...,/zfs/wficai/pda/model_data/20221028/nopda_stud...
471921,efa3614c94506474,1,5,TEST,True,True,True,True,True,True,nopda,study55,study55_dicom122,nonPDAView,color,/zfs/wficai/pda/batch_1/Non-PDA_Batch_1/Superi...,/zfs/wficai/pda/model_data/20221028/nopda_stud...


In [13]:
# save splits: write the frame-level table without the row index
# (index=False is the documented way to omit it; None only worked because
# it is falsy)
df.to_csv(output_csv, index=False)