In [95]:
from __future__ import print_function, division
import os
import torch
import numpy as np
import pandas as pd
import math
import re
import pdb
import pickle
from scipy import stats

from torch.utils.data import Dataset
import h5py

from utils.utils import generate_split, nth

from datasets.dataset_generic import Generic_WSI_Classification_Dataset, Generic_MIL_Dataset, save_splits

In [96]:
# csv_path = 'dataset_csv/fungal_vs_nonfungal.csv'
csv_path = 'dataset_csv/tumor_vs_normal_dummy_clean.csv'
shuffle = False
print_info = True
label_dict = {'normal_tissue':0, 'tumor_tissue':1}
# label_dict = {'nonfungal':0, 'fungal':1}
patient_strat=False
ignore=[]

seed = 7
filter_dict = {}
label_col = 'label'
patient_voting = 'max'

In [97]:
def patient_data_prep(patient_voting='max'):
    patients = np.unique(np.array(slide_data['case_id'])) # get unique patients
    patient_labels = []

    for p in patients:
        locations = slide_data[slide_data['case_id'] == p].index.tolist()
        assert len(locations) > 0
        label = slide_data['label'][locations].values
        if patient_voting == 'max':
            label = label.max() # get patient label (MIL convention)
        elif patient_voting == 'maj':
            label = stats.mode(label)[0]
        else:
            raise NotImplementedError
        patient_labels.append(label)

    patient_data = {'case_id':patients, 'label':np.array(patient_labels)}
    return patient_data


In [98]:
def cls_ids_prep(patient_data):
    # store ids corresponding each class at the patient or case level
    patient_cls_ids = [[] for i in range(num_classes)]
    for i in range(num_classes):
        patient_cls_ids[i] = np.where(patient_data['label'] == i)[0]

    # store ids corresponding each class at the slide level
    slide_cls_ids = [[] for i in range(num_classes)]
    for i in range(num_classes):
        slide_cls_ids[i] = np.where(slide_data['label'] == i)[0]

In [99]:
def filter_df(df, filter_dict={}):
    if len(filter_dict) > 0:
        filter_mask = np.full(len(df), True, bool)
        # assert 'label' not in filter_dict.keys()
        for key, val in filter_dict.items():
            mask = df[key].isin(val)
            filter_mask = np.logical_and(filter_mask, mask)
        df = df[filter_mask]
    return df

In [100]:
def df_prep(data, label_dict, ignore, label_col):
    if label_col != 'label':
        data['label'] = data[label_col].copy()

    mask = data['label'].isin(ignore)
    data = data[~mask]
    data.reset_index(drop=True, inplace=True)
    for i in data.index:
        key = data.loc[i, 'label']
        data.at[i, 'label'] = label_dict[key]

    return data

In [101]:
class Generic_MIL_Dataset(Generic_WSI_Classification_Dataset):
	def __init__(self,
		data_dir,
		**kwargs):

		super(Generic_MIL_Dataset, self).__init__(**kwargs)
		self.data_dir = data_dir
		self.use_h5 = False

	def load_from_h5(self, toggle):
		self.use_h5 = toggle

	def __getitem__(self, idx):
		slide_id = self.slide_data['slide_id'][idx]
		label = self.slide_data['label'][idx]
		if type(self.data_dir) == dict:
			source = self.slide_data['source'][idx]
			data_dir = self.data_dir[source]
		else:
			data_dir = self.data_dir

		if not self.use_h5:
			if self.data_dir:
				full_path = os.path.join(data_dir, 'pt_files', '{}.pt'.format(slide_id))
				features = torch.load(full_path)
				return features, label

			else:
				return slide_id, label

		else:
			full_path = os.path.join(data_dir,'h5_files','{}.h5'.format(slide_id))
			with h5py.File(full_path,'r') as hdf5_file:
				features = hdf5_file['features'][:]
				coords = hdf5_file['coords'][:]

			features = torch.from_numpy(features)
			return features, label, coords

In [102]:
class Generic_Split(Generic_MIL_Dataset):
    def __init__(self, slide_data, data_dir=None, num_classes=2):
        self.use_h5 = False
        self.slide_data = slide_data
        self.data_dir = data_dir
        self.num_classes = num_classes
        self.slide_cls_ids = [[] for i in range(self.num_classes)]
        for i in range(self.num_classes):
            self.slide_cls_ids[i] = np.where(self.slide_data['label'] == i)[0]

    def __len__(self):
        return len(self.slide_data)


In [103]:
def get_split_from_df(all_splits, split_key='train'):
    split = all_splits[split_key]
    split = split.dropna().reset_index(drop=True)

    if len(split) > 0:
        mask = slide_data['slide_id'].isin(split.tolist())
        df_slice = slide_data[mask].reset_index(drop=True)
        split = Generic_Split(df_slice, data_dir=data_dir, num_classes=num_classes)
    else:
        split = None

    return split

In [104]:
def return_splits(from_id=True, csv_path=None):


    if from_id:
        if len(train_ids) > 0:
            train_data = slide_data.loc[train_ids].reset_index(drop=True)
            train_split = Generic_Split(train_data, data_dir=data_dir, num_classes=num_classes)

        else:
            train_split = None

        if len(val_ids) > 0:
            val_data = slide_data.loc[val_ids].reset_index(drop=True)
            val_split = Generic_Split(val_data, data_dir=data_dir, num_classes=num_classes)

        else:
            val_split = None

        if len(self.test_ids) > 0:
            test_data = slide_data.loc[test_ids].reset_index(drop=True)
            test_split = Generic_Split(test_data, data_dir=data_dir, num_classes=num_classes)

        else:
            test_split = None


    else:
        assert csv_path
        all_splits = pd.read_csv(csv_path, dtype=slide_data['slide_id'].dtype)  # Without "dtype=self.slide_data['slide_id'].dtype", read_csv() will convert all-number columns to a numerical type. Even if we convert numerical columns back to objects later, we may lose zero-padding in the process; the columns must be correctly read in from the get-go. When we compare the individual train/val/test columns to self.slide_data['slide_id'] in the get_split_from_df() method, we cannot compare objects (strings) to numbers or even to incorrectly zero-padded objects/strings. An example of this breaking is shown in https://github.com/andrew-weisman/clam_analysis/tree/main/datatype_comparison_bug-2021-12-01.
        train_split = get_split_from_df(all_splits, 'train')
        val_split = get_split_from_df(all_splits, 'val')
        test_split = get_split_from_df(all_splits, 'test')

    return train_split, val_split, test_split

In [105]:
num_classes = len(set(label_dict.values()))
train_ids, val_ids, test_ids = (None, None, None)
data_dir = None
if not label_col:
	label_col = 'label'
label_col = label_col

slide_data = pd.read_csv(csv_path)
slide_data = filter_df(slide_data, filter_dict)
slide_data = df_prep(slide_data, label_dict, ignore, label_col)
print(slide_data)

###shuffle data
if shuffle:
	np.random.seed(seed)
	np.random.shuffle(slide_data)

slide_data = slide_data

patient_data = patient_data_prep(patient_voting)
cls_ids_prep(patient_data)

# if print_info:
# 	print("label column: {}".format(label_col))
# 	print("label dictionary: {}".format(label_dict))
# 	print("number of classes: {}".format(num_classes))
# 	print("slide-level counts: ", '\n', slide_data['label'].value_counts(sort = False))
# 	for i in range(num_classes):
# 		print('Patient-LVL; Number of samples registered in class %d: %d' % (i, patient_cls_ids[i].shape[0]))
# 		print('Slide-LVL; Number of samples registered in class %d: %d' % (i, slide_cls_ids[i].shape[0]))


         case_id   slide_id label
0      patient_0    slide_0     1
1      patient_0    slide_1     0
2      patient_1    slide_2     1
3      patient_2    slide_3     0
4      patient_2    slide_4     0
..           ...        ...   ...
495  patient_445  slide_495     1
496  patient_446  slide_496     0
497  patient_447  slide_497     0
498  patient_448  slide_498     0
499  patient_449  slide_499     1

[500 rows x 3 columns]


In [106]:
split_datasets = slide_data
splits = [split_datasets[i].slide_data['slide_id'] for i in range(len(split_datasets))]

print(splits)

KeyError: 0