In [1]:
import pandas as pd
import numpy as np
import pickle
import copy
import math
import os

In [2]:
def store_data(data, file_path):
    """Pickle ``data`` and write the result to ``file_path``.

    The target is opened in binary write mode, so an existing file at that
    path is overwritten.

    Args:
        data: Any picklable object.
        file_path: Destination path of the pickle file.
    """
    with open(file_path, mode='wb') as handle:
        pickle.dump(data, handle)

def get_frac_split(meta_df, matching_field, ind_column, num_folds=5):
    """Build per-class k-fold train/test splits over unique identifiers.

    For every distinct value of ``ind_column`` the unique values of
    ``matching_field`` are collected in sorted order (``np.unique``) and cut
    into ``num_folds`` consecutive chunks of ``len(ids) // num_folds``
    elements. Fold ``i`` uses chunk ``i`` as its test set and every other
    identifier of that class as its train set.

    Notes:
        * The assignment is deterministic: identifiers are sorted before
          chunking, so the row order of ``meta_df`` does not influence it.
        * The trailing ``len(ids) % num_folds`` identifiers of a class are
          never placed in a test chunk; a class with fewer than ``num_folds``
          identifiers contributes to the train sets only.

    Args:
        meta_df: pandas DataFrame holding at least the two columns below.
        matching_field: Name of the column with the identifiers to split.
        ind_column: Name of the column with the class indicator.
        num_folds: Number of folds to create; must be >= 1.

    Returns:
        dict mapping each fold index ``0..num_folds-1`` to
        ``{'train': list, 'test': list}``.

    Raises:
        ValueError: If ``num_folds`` is smaller than 1.
    """
    if num_folds < 1:
        raise ValueError('num_folds must be >= 1, got %r' % (num_folds,))

    folds = {i: {'train': [], 'test': []} for i in range(num_folds)}

    for class_ in np.unique(meta_df[ind_column]):
        # np.unique returns sorted values, so shuffling the rows beforehand
        # would not change which identifiers end up in which chunk.
        class_rows = meta_df[ind_column] == class_
        slides = np.unique(meta_df.loc[class_rows, matching_field].values)

        # Chunk size follows the requested number of folds (integer division
        # avoids float rounding).
        test_size = len(slides) // num_folds

        for i in range(num_folds):
            start, stop = i * test_size, (i + 1) * test_size
            test_sample = slides[start:stop]
            # Everything outside the test chunk, kept in sorted order so the
            # train lists are reproducible across runs.
            train_sample = list(slides[:start]) + list(slides[stop:])
            folds[i]['train'].extend(train_sample)
            folds[i]['test'].extend(test_sample)

    return folds

def get_folds(meta_df, matching_field, ind_column, num_folds=5, valid_set=False):
    """Create train/valid/test folds by splitting twice with ``get_frac_split``.

    A first call to ``get_frac_split`` yields the train/test folds. For each
    fold, the rows of ``meta_df`` whose ``matching_field`` is in that fold's
    train list are split again; fold 0 of this second split supplies the
    final ``'train'`` list and its test list becomes ``'valid'``.

    NOTE: ``valid_set`` is accepted but never read by this function; the
    ``'valid'`` entry is produced regardless of its value.

    Args:
        meta_df: pandas DataFrame with the columns named below.
        matching_field: Name of the column with the identifiers to split.
        ind_column: Name of the column with the class indicator.
        num_folds: Number of folds, forwarded to both split levels.
        valid_set: Unused.

    Returns:
        dict mapping fold index to a dict with the keys ``'test'``,
        ``'train'`` and ``'valid'``.
    """
    # Outer split: train/test per fold.
    folds = get_frac_split(meta_df, matching_field, ind_column, num_folds=num_folds)

    for fold_idx in range(num_folds):
        current = folds[fold_idx]

        # Restrict the metadata to the rows belonging to this fold's train pool.
        in_train_pool = meta_df[matching_field].isin(current['train'])
        pool_df = meta_df[in_train_pool]

        # Inner split of the pool; only its first fold is used.
        inner_first = get_frac_split(pool_df, matching_field, ind_column, num_folds=num_folds)[0]

        current.pop('train')
        current['train'] = inner_first['train']
        current['valid'] = inner_first['test']

    return folds

# Verify: This should all be empty.
def sanity_check_overlap(folds, num_folds):
    """Print identifiers shared between splits that are expected to be disjoint.

    Two families of checks are run over folds ``0..num_folds-1``:

    * inside each fold: train vs. valid, train vs. test, valid vs. test;
      every non-empty intersection is printed as a set;
    * across folds: the test lists of every pair of folds; a non-empty
      intersection is printed once per pair, prefixed with ``'Fold i-j'``
      where ``i < j``.

    Nothing is printed when all checks pass.

    Args:
        folds: Mapping (or sequence) indexable by fold number, each entry a
            dict with ``'train'``, ``'valid'`` and ``'test'`` iterables.
        num_folds: Number of folds to inspect.

    Returns:
        None.
    """
    # Within each fold the three splits must not share identifiers.
    for i in range(num_folds):
        train = set(folds[i]['train'])
        valid = set(folds[i]['valid'])
        test = set(folds[i]['test'])
        for result in (train & valid, train & test, valid & test):
            if len(result) > 0:
                print(result)

    # Across folds the test sets must not share identifiers. This runs once,
    # outside the per-fold loop, and visits each unordered pair a single time
    # because the intersection is symmetric.
    test_sets = [set(folds[i]['test']) for i in range(num_folds)]
    for i in range(num_folds):
        for j in range(i + 1, num_folds):
            result = test_sets[i] & test_sets[j]
            if len(result) > 0:
                print('Fold %s-%s' % (i, j), result)

# Fit for legacy code.
def fit_format(folds):
    """Wrap every identifier of every split in an ``(identifier, None, None)`` tuple.

    Folds are addressed by their position ``0..len(folds)-1``, so ``folds``
    must be indexable by those integers (a dict keyed that way, or a list).

    Args:
        folds: Collection of fold dicts, each holding ``'train'``,
            ``'valid'`` and ``'test'`` iterables.

    Returns:
        dict mapping each position to a dict with the keys ``'train'``,
        ``'valid'`` and ``'test'``, whose values are lists of 3-tuples.
    """
    split_names = ('train', 'valid', 'test')
    slides_folds = {}
    for position, _ in enumerate(folds):
        source = folds[position]
        slides_folds[position] = {
            name: [(slide, None, None) for slide in source[name]]
            for name in split_names
        }
    return slides_folds


In [3]:

# Input metadata table and the destination of the pickled folds.
meta_csv    = './tcga_panCancer.csv'
pickle_path = './tcga_panCancer.pkl'

# Load the metadata and move the 'type' column into a new trailing
# 'cancer_types' column.
meta_df = pd.read_csv(meta_csv)
cancer_types = meta_df.pop('type').values
meta_df['cancer_types'] = cancer_types

# Give every distinct cancer type an integer code, numbered in the sorted
# order returned by np.unique.
mapping_cancers = {cancer: code for code, cancer in enumerate(np.unique(cancer_types))}

# Integer indicator column derived from the string form of the cancer type.
meta_df['cancer_types_ind'] = meta_df['cancer_types'].astype(str).map(mapping_cancers)

FileNotFoundError: [Errno 2] File b'./tcga_panCancer.csv' does not exist: b'./tcga_panCancer.csv'

In [7]:

# Build the folds over the 'slides' column, grouped by the integer cancer-type indicator.
folds = get_folds(
    meta_df,
    matching_field='slides',
    ind_column='cancer_types_ind',
    num_folds=5,
    valid_set=True,
)

# Convert the folds to the tuple layout produced by fit_format.
final_folds = fit_format(folds)

# The check prints overlapping identifiers; no output means all splits are disjoint.
sanity_check_overlap(folds, num_folds=5)

# Persist the formatted folds at pickle_path.
store_data(final_folds, pickle_path)