# Create Patches

In [4]:
function_type = 'tile_annotations'  # ['tile', 'tile_annotations', 'artefact_annotations']

if function_type == 'tile':
    # Regular patching
    source_dir = os.path.abspath("/kaggle/input/fungal-10x/")
    patch_dir = os.path.abspath('/kaggle/working/patches/')
elif function_type == 'tile_annotations':
    # Patching for annotations
    source_dir = os.path.abspath("/kaggle/input/fungal-10x-annot/")
    patch_dir = os.path.abspath('/kaggle/working/patch_annot/')
    
patch_size = 256
thresholds = {
    'annotations': 254,  # For the annotated images
    'artefacts': 50,  # For artefacts in all images
    'patch_positive': 10000  # For positive label
}

In [10]:
import os
import argparse
import yaml
import cv2
import pickle
import numpy as np
from PIL import Image
from itertools import product
import matplotlib.pyplot as plt

In [6]:
def save_pkl(filename, save_object):
    writer = open(filename,'wb')
    pickle.dump(save_object, writer)
    writer.close()

def load_pkl(filename):
    loader = open(filename,'rb')
    file = pickle.load(loader)
    loader.close()
    return file

In [7]:
def tile(filename, dir_in, dir_out, d):
    if not os.path.isdir(dir_out):
        os.mkdir(dir_out)

    name, ext = os.path.splitext(filename)
    img = Image.open(os.path.join(dir_in, filename))
    w, h = img.size

    grid = product(range(0, h-h%d, d), range(0, w-w%d, d))
    for i, j in grid:
        box = (j, i, j+d, i+d)
        i /= 256
        j /= 256
        out = os.path.join(dir_out, f'{name}_{int(i)}_{int(j)}{ext}')
        img.crop(box).save(out)

        
def tile_annotations(filename, dir_in, dir_out, d):
    if not os.path.isdir(dir_out):
        os.mkdir(dir_out)

    patch_scores = []
    name, ext = os.path.splitext(filename)
    img_cv = cv2.imread(os.path.join(dir_in, filename))
    img_cv_gray = cv2.cvtColor(img_cv, cv2.COLOR_BGR2GRAY)  # Convert to grayscale
    # Thresholding options: ['THRESH_BINARY', 'THRESH_BINARY_INV', 'THRESH_TOZERO ', 'THRESH_TOZERO_INV', 'THRESH_OTSU']
    ret, img_cv_binarized = cv2.threshold(img_cv_gray, thresholds['annotations'], 255, cv2.THRESH_TOZERO)  # Apply thresholding
    img_pil_binarized = cv2.cvtColor(img_cv_binarized, cv2.COLOR_BGR2RGB)  # Convert to RGB, for PIL Image
    img_pil_binarized = Image.fromarray(img_pil_binarized)  # Convert to PIL Image
    w, h = img_pil_binarized.size

    grid = product(range(0, h-h%d, d), range(0, w-w%d, d))
    for i, j in grid:
        box = (j, i, j+d, i+d)
        i /= 256
        j /= 256
        out = os.path.join(dir_out, f'{name}_{int(i)}_{int(j)}{ext}')

        img_patch = img_pil_binarized.crop(box)

        img_patch_np = np.asarray(img_patch)  # Convert to Numpy array
        patch_non_zero = np.count_nonzero(img_patch_np)
        patch_scores.append(patch_non_zero)

        img_patch.save(out)  # Save patch image

    print("P", patch_scores)

    bin_scores = []
    for score in patch_scores:
        bin_score = (score > thresholds['patch_positive']) if 1 else 0
        bin_scores.append(bin_score)

    save_path = os.path.join(dir_out, name+".pkl")
    save_object = {
        "patch_scores": patch_scores,
        "bin_scores": bin_scores
    }
    save_pkl(save_path, save_object)

In [72]:
if not os.path.isdir(patch_dir):
    os.mkdir(patch_dir)

for filename in os.listdir(source_dir):
    name, ext = os.path.splitext(filename)
    output_patches_dir = os.path.join(patch_dir, name)

    if function_type == 'tile':
        print("Patching", filename)
        tile(filename, source_dir, output_patches_dir, patch_size)
    elif function_type == 'tile_annotations':
        print("Binarizing and Patching Annotated", filename)
        tile_annotations(filename, source_dir, output_patches_dir, patch_size)
    elif function_type == 'artefact_annotations':
        print("Binarizing and Patching Artefacts", filename)
        artefact_annotations(filename, source_dir, patch_dir, patch_size)
    else:
        print("Unknown function_type")

Binarizing and Patching Annotated F052a17.tif
P [1758, 0, 0, 0, 11256, 15765, 9333, 13179, 5532, 0, 636, 1410, 5640, 13674, 1188, 0, 0, 0, 5136, 606, 915, 0, 0, 0]
Binarizing and Patching Annotated F021a01.tif
P [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 261, 0, 0, 0, 0, 8172, 8367, 0, 0, 0, 0, 0, 0, 0]
Binarizing and Patching Annotated F033a12.tif
P [34302, 24297, 36282, 21402, 690, 0, 20457, 14700, 23901, 27651, 14691, 1845, 14886, 7566, 2319, 21015, 8235, 6045, 20211, 14979, 9825, 1494, 5793, 1683]
Binarizing and Patching Annotated F017a09.tif
P [2847, 3048, 0, 7785, 111, 0, 0, 0, 0, 6843, 8436, 729, 0, 0, 2583, 8607, 8862, 7611, 5643, 9033, 4644, 1890, 2610, 3459]
Binarizing and Patching Annotated F052a06.tif
P [5025, 1881, 4518, 7605, 1320, 975, 8499, 17136, 14025, 15303, 9891, 0, 5487, 16491, 13971, 12336, 9390, 1515, 16680, 20250, 7707, 1338, 14547, 4347]
Binarizing and Patching Annotated F053a01.tif
P [0, 0, 0, 0, 0, 150, 399, 0, 0, 3, 0, 5676, 396, 78, 0, 1914, 2184, 7788, 39, 234, 2862, 

# Feature Extraction

In [11]:
seed = 1
patch_dir = os.path.abspath('/kaggle/working/patches/')
patch_annot_dir = os.path.abspath('/kaggle/working/patches_annot/')
feat_dir = os.path.abspath('/kaggle/working/features/')

In [12]:
import os
import yaml
import argparse

# import h5py
import cv2
import numpy as np
from PIL import Image

import tensorflow as tf
from tensorflow import keras

from keras.models import Sequential
from keras.layers import Dense
from keras.utils.np_utils import to_categorical

from tensorflow.keras.applications.resnet50 import ResNet50
from tensorflow.keras.applications.resnet50 import preprocess_input

from tensorflow.keras.utils import img_to_array
from tensorflow.keras.preprocessing.image import load_img
from keras.callbacks import ModelCheckpoint

from sklearn.utils import shuffle
from sklearn.metrics import accuracy_score, confusion_matrix

In [57]:
# !mkdir data
# F = os.path.abspath('/kaggle/input/fungal-10x/')
# !cp -r $F data
# !ls data

# !mkdir data/train
# !mkdir data/train/fungal
# !mv data/F* data/train/fungal/
# !mkdir data/train/nonfungal
# !mv data/N* data/train/nonfungal/

# !ls data/train/fungal | wc -l
# !ls data/train/nonfungal | wc -l

In [63]:
epochs = 20

train_ds, validation_ds = tf.keras.preprocessing.image_dataset_from_directory(
    os.path.abspath("/kaggle/working/data/train/fungal"),
    labels="inferred",
    batch_size=32,
    image_size=(256, 256),
    seed=seed
)

Found 0 files belonging to 0 classes.


ValueError: No images found.

In [29]:
# Create feat_dir if not exists.
if not os.path.isdir(feat_dir):
    os.mkdir(feat_dir)

data_augmentation = tf.keras.Sequential([
  keras.layers.RandomFlip("horizontal_and_vertical"),
  keras.layers.RandomRotation(0.2),
])
    
# Loading ResNet50 wit imagenet weights, include_top means that we loading model without last fully connected layers
base_model  = ResNet50(weights = 'imagenet', include_top = False)
base_model.trainable = False  # Freeze base_model

# Create new model on top
inputs = keras.Input(shape=(256, 256, 3))
x = data_augmentation(inputs)  # Apply random data augmentation

# The base model contains batchnorm layers. We want to keep them in inference mode
# when we unfreeze the base model for fine-tuning, so we make sure that the
# base_model is running in inference mode here.
x = base_model(x, training=False)
x = keras.layers.GlobalAveragePooling2D()(x)
x = keras.layers.Dropout(0.25)(x)  # Regularize with dropout
outputs = keras.layers.Dense(1, activation="relu")(x)
model = keras.Model(inputs, outputs)

model.summary()

Model: "model_3"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_15 (InputLayer)        [(None, 256, 256, 3)]     0         
_________________________________________________________________
sequential_3 (Sequential)    (None, 256, 256, 3)       0         
_________________________________________________________________
resnet50 (Functional)        (None, None, None, 2048)  23587712  
_________________________________________________________________
global_average_pooling2d_3 ( (None, 2048)              0         
_________________________________________________________________
dropout_3 (Dropout)          (None, 2048)              0         
_________________________________________________________________
dense_3 (Dense)              (None, 1)                 2049      
Total params: 23,589,761
Trainable params: 2,049
Non-trainable params: 23,587,712
___________________________________________

In [30]:
model.compile(
    optimizer=keras.optimizers.Adam(),
    loss=keras.losses.BinaryCrossentropy(from_logits=True),
    metrics=[keras.metrics.BinaryAccuracy()],
)

epochs = 20
model.fit(train_ds, epochs=epochs, validation_data=validation_ds)

NameError: name 'train_ds' is not defined

In [None]:
# patch_folders = [os.path.join(patch_dir, folder) for folder in sorted(os.listdir(patch_dir))]
# patches_per_image = len(os.listdir(patch_folders[0]))
# print(patches_per_image)

# Create dataset from the image patches
for folder in sorted(os.listdir(patch_dir)):
    filename = str(folder).split("/")[-1]
    filePath = os.path.join(feat_dir, filename)
    # Run only if file doesn't already exist
    if os.path.exists(filePath):
        print("Skipping File:", filename)
        continue
    print("Running on File:", filename)

    features = []
    patch_folder = os.path.join(patch_dir, folder)
    for patch_file in sorted(os.listdir(patch_folder)):
        img_path = os.path.join(patch_folder, patch_file)

        # Get coord in [x, y] format
        coord = img_path.split("/")
        coord = coord[-1]
        coord = coord.split(".")[-2]
        coord = coord.split("_")
        coord = [int(coord[-2])/256, int(coord[-1])/256]

        # Read image
        orig = cv2.imread(img_path)

        # Convert image to RGB from BGR (another way is to use "image = image[:, :, ::-1]" code)
        orig = cv2.cvtColor(orig, cv2.COLOR_BGR2RGB)

        # Resize image to 224x224 size
        image = cv2.resize(orig, (224, 224)).reshape(-1, 224, 224, 3)

        # We need to preprocess imageto fulfill ResNet50 requirements
        image = preprocess_input(image)

        # Extracting our features
        feature = model.predict(image)

        # Group the features
        features.append(feature)
    np.save(filePath, features)

# Dataset creator

In [None]:
filename = os.path.join('/kaggle/working/', 'fungal_vs_nonfungal.csv')
patch_dir = os.path.abspath('/kaggle/working/patches/')
feat_dir = os.path.abspath('/kaggle/working/features/')
annotated_dir = os.path.abspath('/kaggle/input/fungal-10x-annot/')

In [None]:
with open(filename, 'w') as file:
    file.write('case_id,slide_id,label' + '\n')

    patch_folders = [os.path.join(patch_dir, folder) for folder in sorted(os.listdir(patch_dir))]

    for i, name in enumerate(patch_folders):
        name = name.split("/")[-1]
        if name != feat_dir:
            if name[0] == "F":
                f_nf = "fungal"
            elif name[0] == "N":
                f_nf = "nonfungal"
                annotated = True
            else:
                f_nf = "unclassified"

            line = 'case_' + str(i) + ',' + name + ',' + f_nf
            file.write('{}\n'.format(line))


# Create Splits

In [31]:
label_frac = 1.0
seed = 1
k = 5
val_frac = 0.15
test_frac = 0.15
annot_frac = 0.4
annot_positive_frac = 1

dataset_csv_file = os.path.join('/kaggle/working/', 'fungal_vs_nonfungal.csv')

In [32]:
import os
import yaml
import random
import argparse
import numpy as np

In [33]:
from __future__ import print_function, division
import os
import torch
import numpy as np
import pandas as pd
import math
import re
import pdb
import pickle
import random
from scipy import stats

from torch.utils.data import Dataset
import h5py

In [35]:
def generate_split(cls_ids, val_num, test_num, samples, n_splits = 5,
    seed = 7, label_frac = 1.0, custom_test_ids = None):
    indices = np.arange(samples).astype(int)

    if custom_test_ids is not None:
        indices = np.setdiff1d(indices, custom_test_ids)

    np.random.seed(seed)
    for i in range(n_splits):
        all_val_ids = []
        all_test_ids = []
        sampled_train_ids = []

        if custom_test_ids is not None: # pre-built test split, do not need to sample
            all_test_ids.extend(custom_test_ids)

        for c in range(len(val_num)):
            possible_indices = np.intersect1d(cls_ids[c], indices) #all indices of this class
            val_ids = np.random.choice(possible_indices, val_num[c], replace = False) # validation ids

            remaining_ids = np.setdiff1d(possible_indices, val_ids) #indices of this class left after validation
            all_val_ids.extend(val_ids)

            if custom_test_ids is None: # sample test split

                test_ids = np.random.choice(remaining_ids, test_num[c], replace = False)
                remaining_ids = np.setdiff1d(remaining_ids, test_ids)
                all_test_ids.extend(test_ids)

            if label_frac == 1:
                sampled_train_ids.extend(remaining_ids)

            else:
                sample_num  = math.ceil(len(remaining_ids) * label_frac)
                slice_ids = np.arange(sample_num)
                sampled_train_ids.extend(remaining_ids[slice_ids])

        yield sampled_train_ids, all_val_ids, all_test_ids


def nth(iterator, n, default=None):
    if n is None:
        return collections.deque(iterator, maxlen=0)
    else:
        return next(islice(iterator,n, None), default)

In [88]:
def save_splits(split_datasets, column_keys, filename, annot_frac=None, annot_positive_frac=None, boolean_style=False, annot_create=True):
    print(split_datasets)
    splits = [split_datasets[i].slide_data['slide_id'] for i in range(len(split_datasets))]

    if annot_create:
        # Add annot column # Only for 2 classes
        train_set = split_datasets[0]
        train_set_list = []
        annot_set = []
        positive_list = []
        negative_list = []

        for ids in train_set.slide_cls_ids[0]:
            negative_list.append(str(train_set.slide_data['slide_id'][ids]))

        for ids in train_set.slide_cls_ids[1]:
            positive_list.append(str(train_set.slide_data['slide_id'][ids]))

        train_set_list.extend(negative_list)
        train_set_list.extend(positive_list)

        train_set_annot = np.round(len(train_set_list) * annot_frac)
        neg_annot_num = np.round(train_set_annot * (1-annot_positive_frac)).astype(int)
        pos_annot_num = np.round(train_set_annot * annot_positive_frac).astype(int)

        neg_annot_set = random.sample(negative_list, neg_annot_num)
        pos_annot_set = random.sample(positive_list, pos_annot_num)

        annot_set.extend(neg_annot_set)
        annot_set.extend(pos_annot_set)

    #     print("annot_set", annot_set)

        true_annot_set = [False]*len(train_set_list)
        for idx in range(len(true_annot_set)):
            if train_set_list[idx] in annot_set:
                true_annot_set[idx] = True
    #     print("true_annot_set", true_annot_set)
        true_annot_set = pd.DataFrame(true_annot_set)
    #     print("splits", splits)
    #     print("true_annot_set", true_annot_set)
        splits.insert(1, true_annot_set)

    if not boolean_style:
        df = pd.concat(splits, ignore_index=True, axis=1)
        df.columns = column_keys
    else:
        df = pd.concat(splits, ignore_index = True, axis=0)
        index = df.values.tolist()
        one_hot = np.eye(len(split_datasets)).astype(bool)
        bool_array = np.repeat(one_hot, [len(dset) for dset in split_datasets], axis=0)
        df = pd.DataFrame(bool_array, index=index, columns = ['train', 'annot', 'val', 'test'])

    print(split_datasets[0].slide_data)
    df.to_csv(filename)
    print()

class Generic_WSI_Classification_Dataset(Dataset):
    def __init__(self,
        csv_path = 'dataset_csv/ccrcc_clean.csv',
        shuffle = False,
        seed = 7,
        print_info = True,
        label_dict = {},
        filter_dict = {},
        ignore=[],
        patient_strat=False,
        label_col = None,
        patient_voting = 'max',
        results_dir = None
        ):
        """
        Args:
            csv_file (string): Path to the csv file with annotations.
            shuffle (boolean): Whether to shuffle
            seed (int): random seed for shuffling the data
            print_info (boolean): Whether to print a summary of the dataset
            label_dict (dict): Dictionary with key, value pairs for converting str labels to int
            ignore (list): List containing class labels to ignore
        """
        self.label_dict = label_dict
        self.num_classes = len(set(self.label_dict.values()))
        self.seed = seed
        self.print_info = print_info
        self.patient_strat = patient_strat
        self.train_ids, self.val_ids, self.test_ids  = (None, None, None)
        self.data_dir = None
        self.annot_dir = None
        if not label_col:
            label_col = 'label'
        self.label_col = label_col

        slide_data = pd.read_csv(csv_path)
        slide_data = self.filter_df(slide_data, filter_dict)
        slide_data = self.df_prep(slide_data, self.label_dict, ignore, self.label_col)
        print(slide_data)

        ###shuffle data
        if shuffle:
            np.random.seed(seed)
            np.random.shuffle(slide_data)

        self.slide_data = slide_data
        if results_dir:
            slide_data.to_csv(os.path.join(results_dir, 'dataset_csv.csv'))

        self.patient_data_prep(patient_voting)
        self.cls_ids_prep()

        if print_info:
            self.summarize()

    def cls_ids_prep(self):
        # store ids corresponding each class at the patient or case level
        self.patient_cls_ids = [[] for i in range(self.num_classes)]
        for i in range(self.num_classes):
            self.patient_cls_ids[i] = np.where(self.patient_data['label'] == i)[0]

        # store ids corresponding each class at the slide level
        self.slide_cls_ids = [[] for i in range(self.num_classes)]
        for i in range(self.num_classes):
            self.slide_cls_ids[i] = np.where(self.slide_data['label'] == i)[0]

    def patient_data_prep(self, patient_voting='max'):
        patients = np.unique(np.array(self.slide_data['case_id'])) # get unique patients
        patient_labels = []

        for p in patients:
            locations = self.slide_data[self.slide_data['case_id'] == p].index.tolist()
            assert len(locations) > 0
            label = self.slide_data['label'][locations].values
            if patient_voting == 'max':
                label = label.max() # get patient label (MIL convention)
            elif patient_voting == 'maj':
                label = stats.mode(label)[0]
            else:
                raise NotImplementedError
            patient_labels.append(label)

        self.patient_data = {'case_id':patients, 'label':np.array(patient_labels)}

    @staticmethod
    def df_prep(data, label_dict, ignore, label_col):
        if label_col != 'label':
            data['label'] = data[label_col].copy()

        mask = data['label'].isin(ignore)
        data = data[~mask]
        data.reset_index(drop=True, inplace=True)
        for i in data.index:
            key = data.loc[i, 'label']
            data.at[i, 'label'] = label_dict[key]

        return data

    def filter_df(self, df, filter_dict={}):
        if len(filter_dict) > 0:
            filter_mask = np.full(len(df), True, bool)
            # assert 'label' not in filter_dict.keys()
            for key, val in filter_dict.items():
                mask = df[key].isin(val)
                filter_mask = np.logical_and(filter_mask, mask)
            df = df[filter_mask]
        return df

    def __len__(self):
        if self.patient_strat:
            return len(self.patient_data['case_id'])

        else:
            return len(self.slide_data)

    def summarize(self):
        print("label column: {}".format(self.label_col))
        print("label dictionary: {}".format(self.label_dict))
        print("number of classes: {}".format(self.num_classes))
        print("slide-level counts: ", '\n', self.slide_data['label'].value_counts(sort = False))
        for i in range(self.num_classes):
            print('Patient-LVL; Number of samples registered in class %d: %d' % (i, self.patient_cls_ids[i].shape[0]))
            print('Slide-LVL; Number of samples registered in class %d: %d' % (i, self.slide_cls_ids[i].shape[0]))

    def create_splits(self, k = 3, val_num = (25, 25), test_num = (40, 40), label_frac = 1.0, custom_test_ids = None):
        settings = {
                    'n_splits' : k,
                    'val_num' : val_num,
                    'test_num': test_num,
                    'label_frac': label_frac,
                    'seed': self.seed,
                    'custom_test_ids': custom_test_ids
                    }

        if self.patient_strat:
            settings.update({'cls_ids' : self.patient_cls_ids, 'samples': len(self.patient_data['case_id'])})
        else:
            settings.update({'cls_ids' : self.slide_cls_ids, 'samples': len(self.slide_data)})

        self.split_gen = generate_split(**settings)

    def set_splits(self,start_from=None):
        if start_from:
            ids = nth(self.split_gen, start_from)

        else:
            ids = next(self.split_gen)

        if self.patient_strat:
            slide_ids = [[] for i in range(len(ids))]

            for split in range(len(ids)):
                for idx in ids[split]:
                    case_id = self.patient_data['case_id'][idx]
                    slide_indices = self.slide_data[self.slide_data['case_id'] == case_id].index.tolist()
                    slide_ids[split].extend(slide_indices)

            self.train_ids, self.val_ids, self.test_ids = slide_ids[0], slide_ids[1], slide_ids[2]

        else:
            self.train_ids, self.val_ids, self.test_ids = ids

    def get_split_from_df(self, all_splits, split_key='train'):
        split = all_splits[split_key]
        split = split.dropna().reset_index(drop=True)

        if len(split) > 0:
            mask = self.slide_data['slide_id'].isin(split.tolist())
            df_slice = self.slide_data[mask].reset_index(drop=True)
            split = Generic_Split(df_slice, data_dir=self.data_dir, annot_dir=self.annot_dir, num_classes=self.num_classes)
        else:
            split = None

        return split

    def get_merged_split_from_df(self, all_splits, split_keys=['train']):
        merged_split = []
        for split_key in split_keys:
            split = all_splits[split_key]
            split = split.dropna().reset_index(drop=True).tolist()
            merged_split.extend(split)

        if len(split) > 0:
            mask = self.slide_data['slide_id'].isin(merged_split)
            df_slice = self.slide_data[mask].reset_index(drop=True)
            split = Generic_Split(df_slice, data_dir=self.data_dir, num_classes=self.num_classes)
        else:
            split = None

        return split

    def get_overlap_split_from_df(self, all_splits, split_keys=['train', 'annot']):
        train_split = all_splits['train']
        annot_split = all_splits['annot']

        if len(train_split) > 0:
            mask = self.slide_data['slide_id'].isin(train_split)
            df_slice = self.slide_data[mask].reset_index(drop=True)

            mask = train_split.isin(df_slice['slide_id'].tolist())
            df_slice['annot'] = annot_split[mask]

            split = Generic_Split(df_slice, data_dir=self.data_dir, annot_dir=self.annot_dir, num_classes=self.num_classes)
        else:
            split = None

        return split


    def return_splits(self, from_id=True, csv_path=None):


        if from_id:
            if len(self.train_ids) > 0:
                train_data = self.slide_data.loc[self.train_ids].reset_index(drop=True)
                train_split = Generic_Split(train_data, annot_dir=self.annot_dir, data_dir=self.data_dir, num_classes=self.num_classes)

            else:
                train_split = None

            if len(self.val_ids) > 0:
                val_data = self.slide_data.loc[self.val_ids].reset_index(drop=True)
                val_split = Generic_Split(val_data, data_dir=self.data_dir, num_classes=self.num_classes)

            else:
                val_split = None

            if len(self.test_ids) > 0:
                test_data = self.slide_data.loc[self.test_ids].reset_index(drop=True)
                test_split = Generic_Split(test_data, data_dir=self.data_dir, num_classes=self.num_classes)

            else:
                test_split = None


        else:
            assert csv_path
            all_splits = pd.read_csv(csv_path, dtype=self.slide_data['slide_id'].dtype)  # Without "dtype=self.slide_data['slide_id'].dtype", read_csv() will convert all-number columns to a numerical type. Even if we convert numerical columns back to objects later, we may lose zero-padding in the process; the columns must be correctly read in from the get-go. When we compare the individual train/val/test columns to self.slide_data['slide_id'] in the get_split_from_df() method, we cannot compare objects (strings) to numbers or even to incorrectly zero-padded objects/strings. An example of this breaking is shown in https://github.com/andrew-weisman/clam_analysis/tree/main/datatype_comparison_bug-2021-12-01.
            train_split = self.get_overlap_split_from_df(all_splits, ['train', 'annot'])
            val_split = self.get_overlap_split_from_df(all_splits, ['val', 'annot'])
            test_split = self.get_overlap_split_from_df(all_splits, ['test', 'annot'])
            # train_split = self.get_split_from_df(all_splits, 'train')
            # val_split = self.get_split_from_df(all_splits, 'val')
            # test_split = self.get_split_from_df(all_splits, 'test')

        return train_split, val_split, test_split

    def get_list(self, ids):
        return self.slide_data['slide_id'][ids]

    def getlabel(self, ids):
        return self.slide_data['label'][ids]

    def __getitem__(self, idx):
        return None

    def test_split_gen(self, return_descriptor=False):

        if return_descriptor:
            index = [list(self.label_dict.keys())[list(self.label_dict.values()).index(i)] for i in range(self.num_classes)]
            columns = ['train', 'val', 'test']
            df = pd.DataFrame(np.full((len(index), len(columns)), 0, dtype=np.int32), index= index,
                            columns= columns)

        count = len(self.train_ids)
        print('\nnumber of training samples: {}'.format(count))
        labels = self.getlabel(self.train_ids)
        unique, counts = np.unique(labels, return_counts=True)
        for u in range(len(unique)):
            print('number of samples in cls {}: {}'.format(unique[u], counts[u]))
            if return_descriptor:
                df.loc[index[u], 'train'] = counts[u]

        count = len(self.val_ids)
        print('\nnumber of val samples: {}'.format(count))
        labels = self.getlabel(self.val_ids)
        unique, counts = np.unique(labels, return_counts=True)
        for u in range(len(unique)):
            print('number of samples in cls {}: {}'.format(unique[u], counts[u]))
            if return_descriptor:
                df.loc[index[u], 'val'] = counts[u]

        count = len(self.test_ids)
        print('\nnumber of test samples: {}'.format(count))
        labels = self.getlabel(self.test_ids)
        unique, counts = np.unique(labels, return_counts=True)
        for u in range(len(unique)):
            print('number of samples in cls {}: {}'.format(unique[u], counts[u]))
            if return_descriptor:
                df.loc[index[u], 'test'] = counts[u]

        assert len(np.intersect1d(self.train_ids, self.test_ids)) == 0
        assert len(np.intersect1d(self.train_ids, self.val_ids)) == 0
        assert len(np.intersect1d(self.val_ids, self.test_ids)) == 0

        if return_descriptor:
            return df

    def save_split(self, filename):
        train_split = self.get_list(self.train_ids)
        val_split = self.get_list(self.val_ids)
        test_split = self.get_list(self.test_ids)
        df_tr = pd.DataFrame({'train': train_split})
        df_v = pd.DataFrame({'val': val_split})
        df_t = pd.DataFrame({'test': test_split})
        df = pd.concat([df_tr, df_v, df_t], axis=1)
        df.to_csv(filename, index = False)


class Generic_MIL_Dataset(Generic_WSI_Classification_Dataset):
    def __init__(self,
        data_dir,
        annot_dir=None,
        patch_annot_dir=None,
        **kwargs):

        super(Generic_MIL_Dataset, self).__init__(**kwargs)
        self.data_dir = data_dir
        self.annot_dir = annot_dir
        self.patch_annot_dir = patch_annot_dir
        self.use_h5 = False

    def load_from_h5(self, toggle):
        self.use_h5 = toggle

    def __getitem__(self, idx):
        slide_id = self.slide_data['slide_id'][idx]
        label = self.slide_data['label'][idx]
        if self.slide_data['annot'][idx]:
            bool_annot = bool(self.slide_data['annot'][idx])
            if label == 1:
                patch_annot_path = os.path.join(self.patch_annot_dir, slide_id, slide_id+'.pkl')
                patch_annot = load_pkl(patch_annot_path)
                patch_annot = patch_annot['bin_scores']
            elif label == 0:
                patch_annot = [False]*24

        if type(self.data_dir) == dict:
            source = self.slide_data['source'][idx]
            data_dir = self.data_dir[source]
        else:
            data_dir = self.data_dir

        if not self.use_h5:
            if self.data_dir:
                full_path = os.path.join(data_dir, '{}.pt'.format(slide_id))
                features = torch.load(full_path)
                return features, label, bool_annot, patch_annot
                # return features, label

            else:
                # if bool_annot:
                #     return slide_id, label, bool_annot, patch_annot
                # else:
                #     return slide_id, label, None, None
                return slide_id, label

        else:
            full_path = os.path.join(data_dir,'h5_files','{}.h5'.format(slide_id))
            with h5py.File(full_path,'r') as hdf5_file:
                features = hdf5_file['features'][:]
                coords = hdf5_file['coords'][:]

            features = torch.from_numpy(features)
            return features, label, coords


class Generic_Split(Generic_MIL_Dataset):
    def __init__(self, slide_data, annot_dir=None, patch_annot_dir=None, data_dir=None, num_classes=2):
        self.use_h5 = False
        self.slide_data = slide_data
        self.data_dir = data_dir
        self.annot_dir = annot_dir
        self.patch_annot_dir = patch_annot_dir,
        self.num_classes = num_classes
        self.slide_cls_ids = [[] for i in range(self.num_classes)]
        for i in range(self.num_classes):
            self.slide_cls_ids[i] = np.where(self.slide_data['label'] == i)[0]

    def __len__(self):
        return len(self.slide_data)

In [89]:
random.seed(seed)

# task_1_fungal_vs_nonfungal
n_classes=2
dataset = Generic_WSI_Classification_Dataset(csv_path = dataset_csv_file,
                        shuffle = False,
                        seed = seed,
                        print_info = True,
                        label_dict = {'nonfungal':0, 'fungal':1},
                        patient_strat=True,
                        ignore=[])

num_slides_cls = np.array([len(cls_ids) for cls_ids in dataset.patient_cls_ids])
val_num = np.round(num_slides_cls * val_frac).astype(int)
test_num = np.round(num_slides_cls * test_frac).astype(int)

if label_frac > 0:
    label_fracs = [label_frac]
else:
    label_fracs = [0.1, 0.25, 0.5, 0.75, 1.0]

for lf in label_fracs:
    split_dir = 'splits/fungal_vs_nonfungal' + '_{}'.format(int(lf * 100))
    os.makedirs(split_dir, exist_ok=True)
    dataset.create_splits(k = k, val_num = val_num, test_num = test_num, label_frac=lf)
    for i in range(k):
        dataset.set_splits()
        descriptor_df = dataset.test_split_gen(return_descriptor=True)
        splits = dataset.return_splits(from_id=True)

        save_splits(splits, ['train', 'annot', 'val', 'test'], os.path.join(split_dir, 'splits_{}.csv'.format(i)), annot_frac=annot_frac, annot_positive_frac=annot_positive_frac)
        # save_splits(splits, ['train', 'annot', 'val', 'test'], os.path.join(split_dir, 'splits_{}_bool.csv'.format(i)), boolean_style=True)
        # descriptor_df.to_csv(os.path.join(split_dir, 'splits_{}_descriptor.csv'.format(i)))

      case_id  slide_id label
0      case_0   F005a02     1
1      case_1   F006a01     1
2      case_2   F006a02     1
3      case_3   F006a03     1
4      case_4   F006a04     1
..        ...       ...   ...
418  case_418  N012a017     0
419  case_419  N012a018     0
420  case_420  N012a019     0
421  case_421  N012a020     0
422  case_422  N017a001     0

[423 rows x 3 columns]
label column: label
label dictionary: {'nonfungal': 0, 'fungal': 1}
number of classes: 2
slide-level counts:  
 1    208
0    215
Name: label, dtype: int64
Patient-LVL; Number of samples registered in class 0: 215
Slide-LVL; Number of samples registered in class 0: 215
Patient-LVL; Number of samples registered in class 1: 208
Slide-LVL; Number of samples registered in class 1: 208

number of training samples: 297
number of samples in cls 0: 151
number of samples in cls 1: 146

number of val samples: 63
number of samples in cls 0: 32
number of samples in cls 1: 31

number of test samples: 63
number of samples 

In [90]:
# View the splits
# !cat splits/fungal_vs_nonfungal_100/splits_0.csv

# Training the model

In [91]:
max_epochs = 200
lr = 1e-4
label_frac = 1.0
reg = 1e-5
seed = 1
k = 5
k_start = -1
k_end = -1
dataset_csv = os.path.join('/kaggle/working/fungal_vs_nonfungal.csv')
results_parent_dir = os.path.join('/kaggle/working/results/')
split_dir = os.path.join('/kaggle/working/splits/fungal_vs_nonfungal_100/')
log_data = False
testing = False
early_stopping = False
opt = 'adam'
drop_out = False
bag_loss = 'ce'  # ['ce', 'svm']
model_type = 'clam_sb'  # ['clam_sb', 'clam_mb', 'mil']
weighted_sample = False
model_size = 'small'
task = 'task_fungal_vs_nonfungal'

### CLAM specific options
no_inst_cluster = False
inst_loss = None  # ['svm', 'ce', None]
subtyping = False
bag_weight = 0.5
B = 12

exp_code = "exp_01"  # Experiment name
dropout = True  # Whether to use dropout
patch_dir = os.path.abspath('/kaggle/working/patches/')
dest_dir = os.path.abspath('/kaggle/working/splits/')
annot_dir = os.path.abspath('/kaggle/input/fungal-10x-annot/')
patch_annot_dir = os.path.abspath('/kaggle/working/patch_annot/')
feat_dir = os.path.abspath('/kaggle/working/features/')

### Alpha weight
alpha_weight = False
T1 = 50
T2 = 150
af = 1.0

In [92]:
# !pip install --upgrade setuptools

# !git clone https://github.com/oval-group/smooth-topk.git
# !cd smooth-topk && python setup.py install

In [93]:
import pdb
import os
import yaml
import argparse
import math

# pytorch imports
import torch
from torch.utils.data import DataLoader, sampler
import torch.nn as nn
import torch.nn.functional as F

import pandas as pd
import numpy as np

In [94]:
if not os.path.isdir(results_parent_dir):
    os.mkdir(results_parent_dir)
results_dir = os.path.join(results_parent_dir, str(exp_code) + '_s{}'.format(seed))
if not os.path.isdir(results_dir):
    os.mkdir(results_dir)

print(results_dir)

/kaggle/working/results/exp_01_s1


from modules.file_utils import save_pkl, load_pkl

from modules.dataset_generic import Generic_WSI_Classification_Dataset, Generic_MIL_Dataset

In [95]:
import pickle
import torch
import numpy as np
import torch.nn as nn
import pdb

import torch
import numpy as np
import torch.nn as nn
from torchvision import transforms
from torch.utils.data import DataLoader, Sampler, WeightedRandomSampler, RandomSampler, SequentialSampler, sampler
import torch.optim as optim
import pdb
import torch.nn.functional as F
import math
from itertools import islice
import collections
device=torch.device("cuda" if torch.cuda.is_available() else "cpu")

class SubsetSequentialSampler(Sampler):
    """Samples elements sequentially from a given list of indices, without replacement.

    Arguments:
        indices (sequence): a sequence of indices
    """
    def __init__(self, indices):
        self.indices = indices

    def __iter__(self):
        return iter(self.indices)

    def __len__(self):
        return len(self.indices)

def collate_MIL(batch):
    img = torch.cat([item[0] for item in batch], dim = 0)
    label = torch.LongTensor([item[1] for item in batch])
    return [img, label]

def collate_MIL_annot(batch):
    img = torch.cat([item[0] for item in batch], dim = 0)
    label = torch.LongTensor([item[1] for item in batch])
    bool_annot = torch.LongTensor([item[2] for item in batch])
    patch_annot = torch.LongTensor([item[3] for item in batch])
    return [img, label, bool_annot, patch_annot]

def collate_features(batch):
    img = torch.cat([item[0] for item in batch], dim = 0)
    coords = np.vstack([item[1] for item in batch])
    return [img, coords]


def get_simple_loader(dataset, batch_size=1, num_workers=1):
    kwargs = {'num_workers': 4, 'pin_memory': False, 'num_workers': num_workers} if device.type == "cuda" else {}
    loader = DataLoader(dataset, batch_size=batch_size, sampler = sampler.SequentialSampler(dataset), collate_fn = collate_MIL, **kwargs)
    return loader

def get_split_loader(split_dataset, training = False, testing = False, weighted = False):
    """
        return either the validation loader or training loader
    """
    kwargs = {'num_workers': 4} if device.type == "cuda" else {}
    if not testing:
        if training:
            if weighted:
                weights = make_weights_for_balanced_classes_split(split_dataset)
                loader = DataLoader(split_dataset, batch_size=1, sampler = WeightedRandomSampler(weights, len(weights)), collate_fn = collate_MIL, **kwargs)
            else:

                loader = DataLoader(split_dataset, batch_size=1, sampler = RandomSampler(split_dataset), collate_fn = collate_MIL_annot, **kwargs)
        else:
            loader = DataLoader(split_dataset, batch_size=1, sampler = SequentialSampler(split_dataset), collate_fn = collate_MIL, **kwargs)

    else:
        ids = np.random.choice(np.arange(len(split_dataset), int(len(split_dataset)*0.1)), replace = False)
        loader = DataLoader(split_dataset, batch_size=1, sampler = SubsetSequentialSampler(ids), collate_fn = collate_MIL, **kwargs )

    return loader

def get_optim(model, settings):
    if settings['opt'] == "adam":
        optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=settings['lr'], weight_decay=settings['reg'])
    elif settings['opt'] == 'sgd':
        optimizer = optim.SGD(filter(lambda p: p.requires_grad, model.parameters()), lr=settings['lr'], momentum=0.9, weight_decay=setting['reg'])
    else:
        raise NotImplementedError
    return optimizer

def print_network(net):
    num_params = 0
    num_params_train = 0
    print(net)

    for param in net.parameters():
        n = param.numel()
        num_params += n
        if param.requires_grad:
            num_params_train += n

    print('Total number of parameters: %d' % num_params)
    print('Total number of trainable parameters: %d' % num_params_train)


def generate_split(cls_ids, val_num, test_num, samples, n_splits = 5,
    seed = 7, label_frac = 1.0, custom_test_ids = None):
    indices = np.arange(samples).astype(int)

    if custom_test_ids is not None:
        indices = np.setdiff1d(indices, custom_test_ids)

    np.random.seed(seed)
    for i in range(n_splits):
        all_val_ids = []
        all_test_ids = []
        sampled_train_ids = []

        if custom_test_ids is not None: # pre-built test split, do not need to sample
            all_test_ids.extend(custom_test_ids)

        for c in range(len(val_num)):
            possible_indices = np.intersect1d(cls_ids[c], indices) #all indices of this class
            val_ids = np.random.choice(possible_indices, val_num[c], replace = False) # validation ids

            remaining_ids = np.setdiff1d(possible_indices, val_ids) #indices of this class left after validation
            all_val_ids.extend(val_ids)

            if custom_test_ids is None: # sample test split

                test_ids = np.random.choice(remaining_ids, test_num[c], replace = False)
                remaining_ids = np.setdiff1d(remaining_ids, test_ids)
                all_test_ids.extend(test_ids)

            if label_frac == 1:
                sampled_train_ids.extend(remaining_ids)

            else:
                sample_num  = math.ceil(len(remaining_ids) * label_frac)
                slice_ids = np.arange(sample_num)
                sampled_train_ids.extend(remaining_ids[slice_ids])

        yield sampled_train_ids, all_val_ids, all_test_ids


def nth(iterator, n, default=None):
    if n is None:
        return collections.deque(iterator, maxlen=0)
    else:
        return next(islice(iterator,n, None), default)

def calculate_error(Y_hat, Y):
    error = 1. - Y_hat.float().eq(Y.float()).float().mean().item()

    return error

def make_weights_for_balanced_classes_split(dataset):
    N = float(len(dataset))
    weight_per_class = [N/len(dataset.slide_cls_ids[c]) for c in range(len(dataset.slide_cls_ids))]
    weight = [0] * int(N)
    for idx in range(len(dataset)):
        y = dataset.getlabel(idx)
        weight[idx] = weight_per_class[y]

    return torch.DoubleTensor(weight)

def initialize_weights(module):
    for m in module.modules():
        if isinstance(m, nn.Linear):
            nn.init.xavier_normal_(m.weight)
            m.bias.data.zero_()

        elif isinstance(m, nn.BatchNorm1d):
            nn.init.constant_(m.weight, 1)

            nn.init.constant_(m.bias, 0)

#from modules.model_mil import MIL_fc, MIL_fc_mc

from modules.utils import initialize_weights

In [96]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

"""
Attention Network without Gating (2 fc layers)
args:
    L: input feature dimension
    D: hidden layer dimension
    dropout: whether to use dropout (p = 0.25)
    n_classes: number of classes
"""
class Attn_Net(nn.Module):

    def __init__(self, L = 1024, D = 256, dropout = False, n_classes = 1):
        super(Attn_Net, self).__init__()
        self.module = [
            nn.Linear(L, D),
            nn.Tanh()]

        if dropout:
            self.module.append(nn.Dropout(0.25))

        self.module.append(nn.Linear(D, n_classes))

        self.module = nn.Sequential(*self.module)

    def forward(self, x):
        return self.module(x), x # N x n_classes

"""
Attention Network with Sigmoid Gating (3 fc layers)
args:
    L: input feature dimension
    D: hidden layer dimension
    dropout: whether to use dropout (p = 0.25)
    n_classes: number of classes
"""
class Attn_Net_Gated(nn.Module):
    def __init__(self, L = 1024, D = 256, dropout = False, n_classes = 1):
        super(Attn_Net_Gated, self).__init__()
        self.attention_a = [
            nn.Linear(L, D),
            nn.Tanh()]

        self.attention_b = [nn.Linear(L, D),
                            nn.Sigmoid()]
        if dropout:
            self.attention_a.append(nn.Dropout(0.25))
            self.attention_b.append(nn.Dropout(0.25))

        self.attention_a = nn.Sequential(*self.attention_a)
        self.attention_b = nn.Sequential(*self.attention_b)

        self.attention_c = nn.Linear(D, n_classes)

    def forward(self, x):
        a = self.attention_a(x)
        b = self.attention_b(x)
        A = a.mul(b)
        A = self.attention_c(A)  # N x n_classes
        return A, x

"""
args:
    gate: whether to use gated attention network
    size_arg: config for network size
    dropout: whether to use dropout
    k_sample: number of positive/neg patches to sample for instance-level training
    dropout: whether to use dropout (p = 0.25)
    n_classes: number of classes
    instance_loss_fn: loss function to supervise instance-level training
    subtyping: whether it's a subtyping problem
"""
class CLAM_SB(nn.Module):
    def __init__(self, gate = True, size_arg = "small", dropout = False, k_sample=8, n_classes=2,
        instance_loss_fn=nn.CrossEntropyLoss(), subtyping=False):
        super(CLAM_SB, self).__init__()
        self.size_dict = {"small": [1024, 512, 256], "big": [1024, 512, 384]}
        size = self.size_dict[size_arg]
        fc = [nn.Linear(size[0], size[1]), nn.ReLU()]
        if dropout:
            fc.append(nn.Dropout(0.25))
        if gate:
            attention_net = Attn_Net_Gated(L = size[1], D = size[2], dropout = dropout, n_classes = 1)
        else:
            attention_net = Attn_Net(L = size[1], D = size[2], dropout = dropout, n_classes = 1)
        fc.append(attention_net)
        self.attention_net = nn.Sequential(*fc)
        self.classifiers = nn.Linear(size[1], n_classes)
        instance_classifiers = [nn.Linear(size[1], 2) for i in range(n_classes)]
        self.instance_classifiers = nn.ModuleList(instance_classifiers)
        self.k_sample = k_sample
        self.instance_loss_fn = instance_loss_fn
        self.n_classes = n_classes
        self.subtyping = subtyping

        initialize_weights(self)

    def relocate(self):
        device=torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.attention_net = self.attention_net.to(device)
        self.classifiers = self.classifiers.to(device)
        self.instance_classifiers = self.instance_classifiers.to(device)

    @staticmethod
    def create_positive_targets(length, device):
        return torch.full((length, ), 1, device=device).long()
    @staticmethod
    def create_negative_targets(length, device):
        return torch.full((length, ), 0, device=device).long()

    #instance-level evaluation for in-the-class attention branch
    def inst_eval(self, A, h, classifier, bool_annot, patch_annot, alpha_weight, weight_alpha):
        device=h.device
        if len(A.shape) == 1:
            A = A.view(1, -1)

        # Get instance
        top_p_ids = torch.topk(A.squeeze(), self.k_sample)[1]
        top_n_ids = torch.topk(-A.squeeze(), self.k_sample)[1]
        top_p = torch.index_select(h, dim=0, index=top_p_ids)
        top_n = torch.index_select(h, dim=0, index=top_n_ids)
        all_instances = torch.cat([top_p, top_n], dim=0)

        logits = classifier(all_instances)
        logits = logits.view(2*self.k_sample, 2)  # Shape is [24, 2]; B=12

        all_preds = torch.topk(logits, 1, dim = 1)[1].squeeze(1)

        # Get target labels
        if bool_annot:
            p_targets = torch.index_select(patch_annot.squeeze(), dim=0, index=top_p_ids).long()
            n_targets = torch.index_select(patch_annot.squeeze(), dim=0, index=top_n_ids).long()
        else:
            p_targets = self.create_positive_targets(self.k_sample, device)
            n_targets = self.create_negative_targets(self.k_sample, device)

        all_targets = torch.cat([p_targets, n_targets], dim=0)
#         print("logits", logits.shape)
#         print("all_targets", all_targets.shape)
        instance_loss = self.instance_loss_fn(logits, all_targets)
        if alpha_weight and not bool_annot:
            instance_loss *= weight_alpha
        return instance_loss, all_preds, all_targets

    #instance-level evaluation for out-of-the-class attention branch
    def inst_eval_out(self, A, h, classifier):
        device=h.device
        if len(A.shape) == 1:
            A = A.view(1, -1)
        top_p_ids = torch.topk(A.squeeze(), self.k_sample)[1]
        top_p = torch.index_select(h, dim=0, index=top_p_ids)
        p_targets = self.create_negative_targets(self.k_sample, device)
#         print("top_p", top_p.shape)
        logits = classifier(top_p)
        p_preds = torch.topk(logits, 1, dim = 1)[1].squeeze(1)
#         print("logits", logits.shape)
#         print("p_targets", p_targets.shape)
        instance_loss = self.instance_loss_fn(logits.squeeze(), p_targets)
        return instance_loss, p_preds, p_targets

    def forward(self, h, bool_annot=None, patch_annot=None, alpha_weight=False, weight_alpha=None, label=None, instance_eval=False, return_features=False, attention_only=False):
        device = h.device
#         print("h.shape", h.shape)
        A, h = self.attention_net(h)  # NxK
#         print("A.shape", A.shape)
#         print("h.shape", h.shape)
        A = torch.transpose(A, 1, 0)  # KxN
        if attention_only:
            return A
        A_raw = A
        A = F.softmax(A, dim=1)  # softmax over N

        if instance_eval:
            total_inst_loss = 0.0
            all_preds = []
            all_targets = []
            inst_labels = F.one_hot(label, num_classes=self.n_classes).squeeze() #binarize label
            for i in range(len(self.instance_classifiers)):
                inst_label = inst_labels[i].item()
#                 print("inst_label", inst_label)
                classifier = self.instance_classifiers[i]
                if inst_label == 1: #in-the-class:
                    instance_loss, preds, targets = self.inst_eval(A, h, classifier, bool_annot, patch_annot, alpha_weight, weight_alpha)
#                     print("1, preds", preds.shape)
#                     print("1, targets", targets.shape)
                    all_preds.extend(preds.cpu().numpy())
                    all_targets.extend(targets.cpu().numpy())
                else: #out-of-the-class
                    if self.subtyping:
                        instance_loss, preds, targets = self.inst_eval_out(A, h, classifier)
#                         print("0, preds", preds.shape)
#                         print("0, targets", targets.shape)
#                         print("0, all_preds", all_preds)
#                         print("0, all_targets", all_targets)
                        all_preds.extend(preds.squeeze().cpu().numpy())
                        all_targets.extend(targets.cpu().numpy())
#                         print("0, all_preds", all_preds)
#                         print("0, all_targets", all_targets)
                    else:
                        continue
                total_inst_loss += instance_loss

            if self.subtyping:
                total_inst_loss /= len(self.instance_classifiers)

        M = torch.mm(A.view(1, 24), h.view(24, 512))
        logits = self.classifiers(M)
        Y_hat = torch.topk(logits, 1, dim = 1)[1]
#         print("logits", logits)
#         print("logits.shape", logits.shape)
#         print("Y_hat", Y_hat)
#         print("Y_hat.shape", Y_hat.shape)

#         print("all_targets", all_targets)
#         print("all_targets shape", len(all_targets))
#         print("all_preds", all_preds)
#         print("all_preds shape", len(all_preds))
        Y_prob = F.softmax(logits, dim = 1)
#         print("Y_prob.shape", Y_prob.shape)
#         print("Y_prob", logits)
        if instance_eval:
            results_dict = {'instance_loss': total_inst_loss, 'inst_labels': np.array(all_targets),
            'inst_preds': np.array(all_preds)}
        else:
            results_dict = {}
        if return_features:
            results_dict.update({'features': M})
#         print("Y_hat shape", Y_hat.shape)
        return logits, Y_prob, Y_hat, A_raw, results_dict

class CLAM_MB(CLAM_SB):
    def __init__(self, gate = True, size_arg = "small", dropout = False, k_sample=8, n_classes=2,
        instance_loss_fn=nn.CrossEntropyLoss(), subtyping=False):
        nn.Module.__init__(self)
        self.size_dict = {"small": [1024, 512, 256], "big": [1024, 512, 384]}
        size = self.size_dict[size_arg]
        fc = [nn.Linear(size[0], size[1]), nn.ReLU()]
        if dropout:
            fc.append(nn.Dropout(0.25))
        if gate:
            attention_net = Attn_Net_Gated(L = size[1], D = size[2], dropout = dropout, n_classes = n_classes)
        else:
            attention_net = Attn_Net(L = size[1], D = size[2], dropout = dropout, n_classes = n_classes)
        fc.append(attention_net)
        self.attention_net = nn.Sequential(*fc)
        bag_classifiers = [nn.Linear(size[1], 1) for i in range(n_classes)] #use an indepdent linear layer to predict each class
        self.classifiers = nn.ModuleList(bag_classifiers)
        instance_classifiers = [nn.Linear(size[1], 2) for i in range(n_classes)]
        self.instance_classifiers = nn.ModuleList(instance_classifiers)
        self.k_sample = k_sample
        self.instance_loss_fn = instance_loss_fn
        self.n_classes = n_classes
        self.subtyping = subtyping
        initialize_weights(self)

    def forward(self, h, label=None, instance_eval=False, return_features=False, attention_only=False):
        device = h.device
        A, h = self.attention_net(h)  # NxK
        A = torch.transpose(A, 1, 0)  # KxN
        if attention_only:
            return A
        A_raw = A
        A = F.softmax(A, dim=1)  # softmax over N

        if instance_eval:
            total_inst_loss = 0.0
            all_preds = []
            all_targets = []
            inst_labels = F.one_hot(label, num_classes=self.n_classes).squeeze() #binarize label
            for i in range(len(self.instance_classifiers)):
                inst_label = inst_labels[i].item()
                classifier = self.instance_classifiers[i]
                if inst_label == 1: #in-the-class:
                    instance_loss, preds, targets = self.inst_eval(A.view(self.n_classes, 24)[i], h, classifier)
                    all_preds.extend(preds.cpu().numpy())
                    all_targets.extend(targets.cpu().numpy())
                else: #out-of-the-class
                    if self.subtyping:
                        instance_loss, preds, targets = self.inst_eval_out(A.view(self.n_classes, 24)[i], h, classifier)
                        all_preds.extend(preds.cpu().numpy())
                        all_targets.extend(targets.cpu().numpy())
                    else:
                        continue
                total_inst_loss += instance_loss

            if self.subtyping:
                total_inst_loss /= len(self.instance_classifiers)

        M = torch.mm(A.view(self.n_classes, 24), h.view(24, 512))
        logits = torch.empty(1, self.n_classes).float().to(device)
        for c in range(self.n_classes):
            logits[0, c] = self.classifiers[c](M[c])
        Y_hat = torch.topk(logits, 1, dim = 1)[1]
        Y_prob = F.softmax(logits, dim = 1)
        if instance_eval:
            results_dict = {'instance_loss': total_inst_loss, 'inst_labels': np.array(all_targets),
            'inst_preds': np.array(all_preds)}
        else:
            results_dict = {}
        if return_features:
            results_dict.update({'features': M})
        return logits, Y_prob, Y_hat, A_raw, results_dict

In [97]:
import numpy as np
import torch
import os

from sklearn.preprocessing import label_binarize
from sklearn.metrics import roc_auc_score, roc_curve
from sklearn.metrics import auc as calc_auc

class Accuracy_Logger(object):
    """Accuracy logger"""
    def __init__(self, n_classes):
        super(Accuracy_Logger, self).__init__()
        self.n_classes = n_classes
        self.initialize()

    def initialize(self):
        self.data = [{"count": 0, "correct": 0} for i in range(self.n_classes)]

    def log(self, Y_hat, Y):
        Y_hat = int(Y_hat)
        Y = int(Y)
        self.data[Y]["count"] += 1
        self.data[Y]["correct"] += (Y_hat == Y)

    def log_batch(self, Y_hat, Y):
#         print("Y_hat", Y_hat)
        Y_hat = np.array(Y_hat).astype(int)
        Y = np.array(Y).astype(int)
        for label_class in np.unique(Y):
            cls_mask = Y == label_class
#             print("B-Log Y_hat", Y_hat)
#             print("B-Log Y_hat.shape", Y_hat.shape)
#             print("B-Log Y.shape", Y.shape)
#             print("B-0", sum(cls_mask))
#             print("B-1", sum(Y_hat[cls_mask] == Y[cls_mask]))
            self.data[label_class]["count"] += sum(cls_mask)
            self.data[label_class]["correct"] += sum(Y_hat[cls_mask] == Y[cls_mask])

#         Y_hat = np.array(Y_hat).astype(int)
#         Y_hat = np.reshape(Y_hat, (16, 2))
#         Y_hat = Y_hat[:, 0]
#         Y = np.array(Y).astype(int)
#         for label_class in np.unique(Y):
#             cls_mask = [Y == label_class]
#             self.data[label_class]["count"] += sum(cls_mask)
#             self.data[label_class]["correct"] += sum(tuple([(Y_hat[cls_mask] == Y[cls_mask]) ]) )

#             print("B count:", self.data[label_class]["count"])
#             print("B correct:", self.data[label_class]["correct"])

    def get_summary(self, c):
        count = self.data[c]["count"]
        correct = self.data[c]["correct"]

        if count == 0:
            acc = None
        else:
            acc = float(correct) / count

        return acc, correct, count

class EarlyStopping:
    """Early stops the training if validation loss doesn't improve after a given patience."""
    def __init__(self, patience=20, stop_epoch=50, verbose=False):
        """
        Args:
            patience (int): How long to wait after last time validation loss improved.
                            Default: 20
            stop_epoch (int): Earliest epoch possible for stopping
            verbose (bool): If True, prints a message for each validation loss improvement.
                            Default: False
        """
        self.patience = patience
        self.stop_epoch = stop_epoch
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        self.val_loss_min = np.Inf

    def __call__(self, epoch, val_loss, model, ckpt_name = 'checkpoint.pt'):

        score = -val_loss

        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(val_loss, model, ckpt_name)
        elif score < self.best_score:
            self.counter += 1
            print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience and epoch > self.stop_epoch:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(val_loss, model, ckpt_name)
            self.counter = 0

    def save_checkpoint(self, val_loss, model, ckpt_name):
        '''Saves model when validation loss decrease.'''
        if self.verbose:
            print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).  Saving model ...')
        torch.save(model.state_dict(), ckpt_name)
        self.val_loss_min = val_loss

def train(datasets, cur, settings):
    """
        train for a single fold
    """
    print("Settings:", settings)
    print('\nTraining Fold {}!'.format(cur))
    exp_dir = os.path.join(settings["results_dir"], str(settings["exp_code"]) + '_s{}'.format(settings["seed"]))
    if not os.path.isdir(exp_dir):
        os.mkdir(exp_dir)
    split_dir = os.path.join(exp_dir, 'splits_{}'.format(cur))
    if not os.path.isdir(split_dir):
        os.mkdir(split_dir)

    if settings['log_data']:
        writer_dir = os.path.join(exp_dir, "logs", str(cur))
        if not os.path.isdir(writer_dir):
            os.mkdir(writer_dir)

        from tensorboardX import SummaryWriter
        writer = SummaryWriter(writer_dir, flush_secs=15)

    else:
        writer = None

    print('\nInit train/val/test splits...', end=' ')
    train_split, val_split, test_split = datasets
    save_splits(datasets, ['train', 'val', 'test'], os.path.join(split_dir, 'splits_{}.csv'.format(cur)), annot_create=False)
    print('Done!')
    print("Training on {} samples".format(len(train_split)))
    print("Validating on {} samples".format(len(val_split)))
    print("Testing on {} samples".format(len(test_split)))

    print('\nInit loss function...', end=' ')
    if settings['bag_loss'] == 'svm':
        from topk.svm import SmoothTop1SVM
        loss_fn = SmoothTop1SVM(n_classes = settings['n_classes'])
        if device.type == 'cuda':
            loss_fn = loss_fn.cuda()
    else:
        loss_fn = nn.CrossEntropyLoss()
    print('Done!')

    print('\nInit Model...', end=' ')
    model_dict = {"dropout": settings['dropout'], 'n_classes': settings['n_classes']}
    if settings['model_type'] == 'clam' and settings['subtyping']:
        model_dict.update({'subtyping': True})

    if settings['model_size'] is not None and settings['model_type'] != 'mil':
        model_dict.update({"size_arg": settings['model_size']})

    if settings['model_type'] in ['clam_sb', 'clam_mb']:
        if settings['subtyping']:
            model_dict.update({'subtyping': True})

        if settings['B'] > 0:
            model_dict.update({'k_sample': settings['B']})

        if settings['inst_loss'] == 'svm':
            from topk.svm import SmoothTop1SVM
            instance_loss_fn = SmoothTop1SVM(n_classes = 2)
            if device.type == 'cuda':
                instance_loss_fn = instance_loss_fn.cuda()
        else:
            instance_loss_fn = nn.CrossEntropyLoss()

        if settings['model_type'] =='clam_sb':
            model = CLAM_SB(**model_dict, instance_loss_fn=instance_loss_fn)
        elif settings['model_type'] == 'clam_mb':
            model = CLAM_MB(**model_dict, instance_loss_fn=instance_loss_fn)
        else:
            raise NotImplementedError

    else: # settings['model_type == 'mil'
        if settings['n_classes'] > 2:
            model = MIL_fc_mc(**model_dict)
        else:
            model = MIL_fc(**model_dict)

    model.relocate()
    print('Done!')
    print_network(model)

    print('\nInit optimizer ...', end=' ')
    optimizer = get_optim(model, settings)
    print('Done!')

    print('\nInit Loaders...', end=' ')
    train_loader = get_split_loader(train_split, training=True, testing = settings['testing'], weighted = settings['weighted_sample'])
    val_loader = get_split_loader(val_split,  testing = settings['testing'])
    test_loader = get_split_loader(test_split, testing = settings['testing'])
    print('Done!')

    print('\nSetup EarlyStopping...', end=' ')
    if settings['early_stopping']:
        early_stopping = EarlyStopping(patience = 20, stop_epoch=50, verbose = True)

    else:
        early_stopping = None
    print('Done!')
    
    for epoch in range(settings['max_epochs']):
        weight_alpha = get_alpha_weight(epoch, settings['T1'], settings['T2'], settings['af'])
        if settings['model_type'] in ['clam_sb', 'clam_mb'] and not settings['no_inst_cluster']:
            train_loop_clam(epoch, model, train_loader, optimizer, settings['n_classes'], settings['bag_weight'], writer, loss_fn, alpha_weight=settings['alpha_weight'], weight_alpha=weight_alpha)
            stop = validate_clam(cur, epoch, model, val_loader, settings['n_classes'],
                early_stopping, writer, loss_fn, settings['results_dir'])

        else:
            train_loop(epoch, model, train_loader, optimizer, settings['n_classes'], writer, loss_fn)
            stop = validate(cur, epoch, model, val_loader, settings['n_classes'],
                early_stopping, writer, loss_fn, settings['results_dir'])

        if stop:
            break

    exp_dir = os.path.join(settings["results_dir"], str(settings["exp_code"]) + '_s{}'.format(settings["seed"]))
    split_dir = os.path.join(exp_dir, 'splits_{}'.format(cur))
    if settings['early_stopping']:
        model.load_state_dict(torch.load(os.path.join(split_dir, "s_{}_checkpoint.pt".format(cur))))
    else:
        torch.save(model.state_dict(), os.path.join(split_dir, "s_{}_checkpoint.pt".format(cur)))

    _, val_error, val_auc, _= summary(model, val_loader, settings['n_classes'])
    print('Val error: {:.4f}, ROC AUC: {:.4f}'.format(val_error, val_auc))

    results_dict, test_error, test_auc, acc_logger = summary(model, test_loader, settings['n_classes'])
    print('Test error: {:.4f}, ROC AUC: {:.4f}'.format(test_error, test_auc))

    for i in range(settings['n_classes']):
        acc, correct, count = acc_logger.get_summary(i)
        print('class {}: acc {}, correct {}/{}'.format(i, acc, correct, count))

        if writer:
            writer.add_scalar('final/test_class_{}_acc'.format(i), acc, 0)

    if writer:
        writer.add_scalar('final/val_error', val_error, 0)
        writer.add_scalar('final/val_auc', val_auc, 0)
        writer.add_scalar('final/test_error', test_error, 0)
        writer.add_scalar('final/test_auc', test_auc, 0)
        writer.close()
    return results_dict, test_auc, val_auc, 1-test_error, 1-val_error


def train_loop_clam(epoch, model, loader, optimizer, n_classes, bag_weight, writer = None, loss_fn = None, alpha_weight=False, weight_alpha=None):
    device=torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.train()
    acc_logger = Accuracy_Logger(n_classes=n_classes)
    inst_logger = Accuracy_Logger(n_classes=n_classes)

    train_loss = 0.
    train_error = 0.
    train_inst_loss = 0.
    inst_count = 0

    print('\n')
    for batch_idx, (data, label, bool_annot, patch_annot) in enumerate(loader):
        data = data.float()
        model = model.float()
        data, label = data.to(device), label.to(device)
        bool_annot, patch_annot = bool_annot.to(device), patch_annot.to(device)
        # print("data.shape", data.shape)
        logits, Y_prob, Y_hat, _, instance_dict = model(data, bool_annot=bool_annot, patch_annot=patch_annot, label=label, alpha_weight=alpha_weight, weight_alpha=weight_alpha, instance_eval=True)

        acc_logger.log(Y_hat, label)
        loss = loss_fn(logits.view(1, 2), label)
        loss_value = loss.item()

        instance_loss = instance_dict['instance_loss']
        inst_count+=1
        instance_loss_value = instance_loss.item()
        train_inst_loss += instance_loss_value

        total_loss = bag_weight * loss + (1-bag_weight) * instance_loss

        inst_preds = instance_dict['inst_preds']
        inst_labels = instance_dict['inst_labels']
#         print("inst_preds", inst_preds.shape)
#         print("inst_labels", inst_labels.shape)
        inst_logger.log_batch(inst_preds, inst_labels)

        train_loss += loss_value
        if (batch_idx + 1) % 20 == 0:
            print('batch {}, loss: {:.4f}, instance_loss: {:.4f}, weighted_loss: {:.4f}, '.format(batch_idx, loss_value, instance_loss_value, total_loss.item()) +
                'label: {}, bag_size: {}'.format(label.item(), data.size(0)))

        error = calculate_error(Y_hat, label)
        train_error += error

        # backward pass
        total_loss.backward()
        # step
        optimizer.step()
        optimizer.zero_grad()

    # calculate loss and error for epoch
    train_loss /= len(loader)
    train_error /= len(loader)

    if inst_count > 0:
        train_inst_loss /= inst_count
        print('\n')
        for i in range(2):
            acc, correct, count = inst_logger.get_summary(i)
            print('class {} clustering acc {}: correct {}/{}'.format(i, acc, correct, count))

    print('Epoch: {}, train_loss: {:.4f}, train_clustering_loss:  {:.4f}, train_error: {:.4f}'.format(epoch, train_loss, train_inst_loss,  train_error))
    for i in range(n_classes):
        acc, correct, count = acc_logger.get_summary(i)
        print('class {}: acc {}, correct {}/{}'.format(i, acc, correct, count))
        if writer and acc is not None:
            writer.add_scalar('train/class_{}_acc'.format(i), acc, epoch)

    if writer:
        writer.add_scalar('train/loss', train_loss, epoch)
        writer.add_scalar('train/error', train_error, epoch)
        writer.add_scalar('train/clustering_loss', train_inst_loss, epoch)

def train_loop(epoch, model, loader, optimizer, n_classes, writer = None, loss_fn = None):
    device=torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.train()
    acc_logger = Accuracy_Logger(n_classes=n_classes)
    train_loss = 0.
    train_error = 0.

    print('\n')
    for batch_idx, (data, label) in enumerate(loader):
        data, label = data.to(device), label.to(device)

        logits, Y_prob, Y_hat, _, _ = model(data)

        acc_logger.log(Y_hat, label)
        loss = loss_fn(logits.view(1, 2), label)
        loss_value = loss.item()

        train_loss += loss_value
        if (batch_idx + 1) % 20 == 0:
            print('batch {}, loss: {:.4f}, label: {}, bag_size: {}'.format(batch_idx, loss_value, label.item(), data.size(0)))

        error = calculate_error(Y_hat, label)
        train_error += error

        # backward pass
        loss.backward()
        # step
        optimizer.step()
        optimizer.zero_grad()

    # calculate loss and error for epoch
    train_loss /= len(loader)
    train_error /= len(loader)

    print('Epoch: {}, train_loss: {:.4f}, train_error: {:.4f}'.format(epoch, train_loss, train_error))
    for i in range(n_classes):
        acc, correct, count = acc_logger.get_summary(i)
        print('class {}: acc {}, correct {}/{}'.format(i, acc, correct, count))
        if writer:
            writer.add_scalar('train/class_{}_acc'.format(i), acc, epoch)

    if writer:
        writer.add_scalar('train/loss', train_loss, epoch)
        writer.add_scalar('train/error', train_error, epoch)


def validate(cur, epoch, model, loader, n_classes, early_stopping = None, writer = None, loss_fn = None, results_dir=None):
    device=torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.eval()
    acc_logger = Accuracy_Logger(n_classes=n_classes)
    # loader.dataset.update_mode(True)
    val_loss = 0.
    val_error = 0.

    prob = np.zeros((len(loader), n_classes))
    labels = np.zeros(len(loader))

    with torch.no_grad():
        for batch_idx, (data, label) in enumerate(loader):
            data, label = data.to(device, non_blocking=True), label.to(device, non_blocking=True)

            logits, Y_prob, Y_hat, _, _ = model(data)

            acc_logger.log(Y_hat, label)

            loss = loss_fn(logits.view(1, 2), label)

            prob[batch_idx] = Y_prob.cpu().numpy()
            labels[batch_idx] = label.item()

            val_loss += loss.item()
            error = calculate_error(Y_hat, label)
            val_error += error


    val_error /= len(loader)
    val_loss /= len(loader)

    if n_classes == 2:
        auc = roc_auc_score(labels, prob[:, 1])

    else:
        auc = roc_auc_score(labels, prob, multi_class='ovr')


    if writer:
        writer.add_scalar('val/loss', val_loss, epoch)
        writer.add_scalar('val/auc', auc, epoch)
        writer.add_scalar('val/error', val_error, epoch)

    print('\nVal Set, val_loss: {:.4f}, val_error: {:.4f}, auc: {:.4f}'.format(val_loss, val_error, auc))
    for i in range(n_classes):
        acc, correct, count = acc_logger.get_summary(i)
        print('class {}: acc {}, correct {}/{}'.format(i, acc, correct, count))

    if early_stopping:
        exp_dir = os.path.join(settings["results_dir"], str(settings["exp_code"]) + '_s{}'.format(settings["seed"]))
        assert exp_dir
        split_dir = os.path.join(exp_dir, 'splits_{}'.format(cur))
        assert split_dir
        early_stopping(epoch, val_loss, model, ckpt_name = os.path.join(split_dir, "s_{}_checkpoint.pt".format(cur)))

        if early_stopping.early_stop:
            print("Early stopping")
            return True

    return False

def validate_clam(cur, epoch, model, loader, n_classes, early_stopping = None, writer = None, loss_fn = None, results_dir = None):
    device=torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.eval()
    acc_logger = Accuracy_Logger(n_classes=n_classes)
    inst_logger = Accuracy_Logger(n_classes=n_classes)
    val_loss = 0.
    val_error = 0.

    val_inst_loss = 0.
    val_inst_acc = 0.
    inst_count=0

    prob = np.zeros((len(loader), n_classes))
    labels = np.zeros(len(loader))
    sample_size = model.k_sample
    with torch.no_grad():
        for batch_idx, (data, label) in enumerate(loader):
            data, label = data.to(device), label.to(device)
            logits, Y_prob, Y_hat, _, instance_dict = model(data, label=label, instance_eval=True)
            acc_logger.log(Y_hat, label)

            loss = loss_fn(logits.view(1, 2), label)

            val_loss += loss.item()

            instance_loss = instance_dict['instance_loss']

            inst_count+=1
            instance_loss_value = instance_loss.item()
            val_inst_loss += instance_loss_value

            inst_preds = instance_dict['inst_preds']
            inst_labels = instance_dict['inst_labels']
            inst_logger.log_batch(inst_preds, inst_labels)

            prob[batch_idx] = Y_prob.cpu().numpy()
            labels[batch_idx] = label.item()

            error = calculate_error(Y_hat, label)
            val_error += error

    val_error /= len(loader)
    val_loss /= len(loader)

    if n_classes == 2:
        auc = roc_auc_score(labels, prob[:, 1])
        aucs = []
    else:
        aucs = []
        binary_labels = label_binarize(labels, classes=[i for i in range(n_classes)])
        for class_idx in range(n_classes):
            if class_idx in labels:
                fpr, tpr, _ = roc_curve(binary_labels[:, class_idx], prob[:, class_idx])
                aucs.append(calc_auc(fpr, tpr))
            else:
                aucs.append(float('nan'))

        auc = np.nanmean(np.array(aucs))

    print('\nVal Set, val_loss: {:.4f}, val_error: {:.4f}, auc: {:.4f}'.format(val_loss, val_error, auc))
    if inst_count > 0:
        val_inst_loss /= inst_count
        for i in range(2):
            acc, correct, count = inst_logger.get_summary(i)
            print('class {} clustering acc {}: correct {}/{}'.format(i, acc, correct, count))

    if writer:
        writer.add_scalar('val/loss', val_loss, epoch)
        writer.add_scalar('val/auc', auc, epoch)
        writer.add_scalar('val/error', val_error, epoch)
        writer.add_scalar('val/inst_loss', val_inst_loss, epoch)


    for i in range(n_classes):
        acc, correct, count = acc_logger.get_summary(i)
        print('class {}: acc {}, correct {}/{}'.format(i, acc, correct, count))

        if writer and acc is not None:
            writer.add_scalar('val/class_{}_acc'.format(i), acc, epoch)


    if early_stopping:
        exp_dir = os.path.join(settings["results_dir"], str(settings["exp_code"]) + '_s{}'.format(settings["seed"]))
        assert exp_dir
        split_dir = os.path.join(exp_dir, 'splits_{}'.format(cur))
        assert split_dir
        early_stopping(epoch, val_loss, model, ckpt_name = os.path.join(split_dir, "s_{}_checkpoint.pt".format(cur)))

        if early_stopping.early_stop:
            print("Early stopping")
            return True

    return False

def summary(model, loader, n_classes):
    device=torch.device("cuda" if torch.cuda.is_available() else "cpu")
    acc_logger = Accuracy_Logger(n_classes=n_classes)
    model.eval()
    test_loss = 0.
    test_error = 0.

    all_probs = np.zeros((len(loader), n_classes))
    all_labels = np.zeros(len(loader))

    slide_ids = loader.dataset.slide_data['slide_id']
    patient_results = {}

    for batch_idx, (data, label) in enumerate(loader):
        data, label = data.to(device), label.to(device)
        slide_id = slide_ids.iloc[batch_idx]
        with torch.no_grad():
            logits, Y_prob, Y_hat, _, _ = model(data)

        acc_logger.log(Y_hat, label)
        probs = Y_prob.cpu().numpy()
        all_probs[batch_idx] = probs
        all_labels[batch_idx] = label.item()

        patient_results.update({slide_id: {'slide_id': np.array(slide_id), 'prob': probs, 'label': label.item()}})
        error = calculate_error(Y_hat, label)
        test_error += error

    test_error /= len(loader)

    if n_classes == 2:
        auc = roc_auc_score(all_labels, all_probs[:, 1])
        aucs = []
    else:
        aucs = []
        binary_labels = label_binarize(all_labels, classes=[i for i in range(n_classes)])
        for class_idx in range(n_classes):
            if class_idx in all_labels:
                fpr, tpr, _ = roc_curve(binary_labels[:, class_idx], all_probs[:, class_idx])
                aucs.append(calc_auc(fpr, tpr))
            else:
                aucs.append(float('nan'))

        auc = np.nanmean(np.array(aucs))


    return patient_results, test_error, auc, acc_logger

def get_alpha_weight(epoch, T1, T2, af):
    if epoch < T1:
        return 0.0
    elif epoch > T2:
        return af
    else:
         return ((epoch-T1) / (T2-T1))*af

In [98]:
def seed_torch(seed=7):
    import random
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if device.type == 'cuda':
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)  # if you are using multi-GPU.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

In [99]:
device=torch.device("cuda" if torch.cuda.is_available() else "cpu")
seed_torch(seed)

encoding_size = 1024
settings = {
    'k': k,
    'k_start': k_start,
    'k_end': k_end,
    'task': task,
    'max_epochs': max_epochs,
    'results_dir': results_dir,
    'lr': lr,
    'experiment': exp_code,
    'reg': reg,
    'label_frac': label_frac,
    'bag_loss': bag_loss,
    'seed': seed,
    'model_type': model_type,
    'model_size': model_size,
    "use_drop_out": drop_out,
    'weighted_sample': weighted_sample,
    'opt': opt,
    'patch_dir': patch_dir,
    'feat_dir': feat_dir,
    'label_frac': label_frac,
    'split_dir': split_dir,
    'log_data': log_data,
    'dataset_csv': dataset_csv,
    'testing': testing,
    'early_stopping': early_stopping,
    'dropout': dropout,
    'no_inst_cluster': no_inst_cluster,
    'subtyping': subtyping,
    'exp_code': exp_code,
    'bag_weight': bag_weight,
    'inst_loss': inst_loss,
    'B': B,
    'annot_dir': annot_dir,
    'patch_annot_dir': patch_annot_dir,
    'alpha_weight': alpha_weight,
    'T1': T1,
    'T2': T2,
    'af': af
}

In [100]:
with open(os.path.join(results_dir, 'config.yaml'), 'w') as yaml_file:
    yaml.dump(settings, yaml_file, default_flow_style=False)

In [101]:
if task == 'task_fungal_vs_nonfungal':
    n_classes = 2
    settings.update({'n_classes': n_classes})
    dataset = Generic_MIL_Dataset(csv_path=dataset_csv,
                                  data_dir=feat_dir,
                                  annot_dir=annot_dir,
                                  patch_annot_dir=patch_annot_dir,
                                  results_dir=results_dir,
                                  shuffle=False,
                                  seed=seed,
                                  print_info=True,
                                  label_dict={'nonfungal': 0, 'fungal': 1},
                                  patient_strat=False,
                                  ignore=[])

elif task == 'task_1_tumor_vs_normal':
    n_classes = 2
    settings.update({'n_classes': n_classes})
    dataset = Generic_MIL_Dataset(csv_path='dataset_csv/tumor_vs_normal_dummy_clean.csv',
                                  data_dir=os.path.join(
                                      data_root_dir, 'tumor_vs_normal_resnet_features'),
                                  shuffle=False,
                                  seed=seed,
                                  print_info=True,
                                  label_dict={'normal_tissue': 0,
                                              'tumor_tissue': 1},
                                  patient_strat=False,
                                  ignore=[])

elif task == 'task_2_tumor_subtyping':
    n_classes = 3
    settings.update({'n_classes': n_classes})
    dataset = Generic_MIL_Dataset(csv_path='dataset_csv/tumor_subtyping_dummy_clean.csv',
                                  data_dir=os.path.join(
                                      data_root_dir, 'tumor_subtyping_resnet_features'),
                                  shuffle=False,
                                  seed=seed,
                                  print_info=True,
                                  label_dict={'subtype_1': 0,
                                              'subtype_2': 1, 'subtype_3': 2},
                                  patient_strat=False,
                                  ignore=[])

    if model_type in ['clam_sb', 'clam_mb']:
        assert subtyping

else:
    raise NotImplementedError

with open(results_dir + '/experiment_{}.txt'.format(exp_code), 'w') as f:
    print(settings, file=f)
f.close()

print("################# Settings ###################")
for key, val in settings.items():
    print("{}:  {}".format(key, val))

      case_id  slide_id label
0      case_0   F005a02     1
1      case_1   F006a01     1
2      case_2   F006a02     1
3      case_3   F006a03     1
4      case_4   F006a04     1
..        ...       ...   ...
418  case_418  N012a017     0
419  case_419  N012a018     0
420  case_420  N012a019     0
421  case_421  N012a020     0
422  case_422  N017a001     0

[423 rows x 3 columns]
label column: label
label dictionary: {'nonfungal': 0, 'fungal': 1}
number of classes: 2
slide-level counts:  
 1    208
0    215
Name: label, dtype: int64
Patient-LVL; Number of samples registered in class 0: 215
Slide-LVL; Number of samples registered in class 0: 215
Patient-LVL; Number of samples registered in class 1: 208
Slide-LVL; Number of samples registered in class 1: 208
################# Settings ###################
k:  5
k_start:  -1
k_end:  -1
task:  task_fungal_vs_nonfungal
max_epochs:  200
results_dir:  /kaggle/working/results/exp_01_s1
lr:  0.0001
experiment:  exp_01
reg:  1e-05
label_frac:  1

In [102]:
start = 0 if k_start == -1 else k_start
end = k if k_end == -1 else k_end

all_test_auc = []
all_val_auc = []
all_test_acc = []
all_val_acc = []
folds = np.arange(start, end)
for i in folds:
    seed_torch(seed)
    train_dataset, val_dataset, test_dataset = dataset.return_splits(from_id=False,
            csv_path='{}/splits_{}.csv'.format(split_dir, i))

    datasets = (train_dataset, val_dataset, test_dataset)

    results, test_auc, val_auc, test_acc, val_acc  = train(datasets, i, settings)
    all_test_auc.append(test_auc)
    all_val_auc.append(val_auc)
    all_test_acc.append(test_acc)
    all_val_acc.append(val_acc)
    #write results to pkl
    filename = os.path.join(results_dir, "splits_{}".format(i), 'split_{}_results.pkl'.format(i))
    save_pkl(filename, results)

final_df = pd.DataFrame({'folds': folds, 'test_auc': all_test_auc,
    'val_auc': all_val_auc, 'test_acc': all_test_acc, 'val_acc' : all_val_acc})

if len(folds) != k:
    save_name = 'summary_partial_{}_{}.csv'.format(start, end)
else:
    save_name = 'summary.csv'
final_df.to_csv(os.path.join(results_dir, save_name))


Settings: {'k': 5, 'k_start': -1, 'k_end': -1, 'task': 'task_fungal_vs_nonfungal', 'max_epochs': 200, 'results_dir': '/kaggle/working/results/exp_01_s1', 'lr': 0.0001, 'experiment': 'exp_01', 'reg': 1e-05, 'label_frac': 1.0, 'bag_loss': 'ce', 'seed': 1, 'model_type': 'clam_sb', 'model_size': 'small', 'use_drop_out': False, 'weighted_sample': False, 'opt': 'adam', 'patch_dir': '/kaggle/working/patches', 'feat_dir': '/kaggle/working/features', 'split_dir': '/kaggle/working/splits/fungal_vs_nonfungal_100/', 'log_data': False, 'dataset_csv': '/kaggle/working/fungal_vs_nonfungal.csv', 'testing': False, 'early_stopping': False, 'dropout': True, 'no_inst_cluster': False, 'subtyping': False, 'exp_code': 'exp_01', 'bag_weight': 0.5, 'inst_loss': None, 'B': 12, 'annot_dir': '/kaggle/input/fungal-10x-annot', 'patch_annot_dir': '/kaggle/working/patch_annot', 'alpha_weight': False, 'T1': 50, 'T2': 150, 'af': 1.0, 'n_classes': 2}

Training Fold 0!

Init train/val/test splits... (<__main__.Generic_Sp

TypeError: expected str, bytes or os.PathLike object, not tuple

In [1]:
!ls features

F005a02.npy  F021a02.npy  F053a06.npy	N004a043.npy  N004a128.npy
F006a01.npy  F021a03.npy  F053a07.npy	N004a044.npy  N004a129.npy
F006a02.npy  F021a04.npy  F053a08.npy	N004a045.npy  N004a130.npy
F006a03.npy  F021a05.npy  F053a09.npy	N004a046.npy  N004a131.npy
F006a04.npy  F030a01.npy  F053a10.npy	N004a047.npy  N004a132.npy
F006a05.npy  F030a02.npy  F053a11.npy	N004a048.npy  N004a133.npy
F006a06.npy  F030a03.npy  F053a12.npy	N004a049.npy  N004a134.npy
F006a07.npy  F030a04.npy  F053a13.npy	N004a050.npy  N004a135.npy
F006a08.npy  F030a05.npy  F053a14.npy	N004a051.npy  N004a136.npy
F006a09.npy  F030a06.npy  F053a15.npy	N004a052.npy  N004a137.npy
F006a10.npy  F030a07.npy  F053a16.npy	N004a053.npy  N005a001.npy
F007a01.npy  F030a08.npy  F053a17.npy	N004a054.npy  N005a002.npy
F007a02.npy  F030a09.npy  F053a18.npy	N004a055.npy  N005a003.npy
F007a03.npy  F030a10.npy  F053a19.npy	N004a056.npy  N005a004.npy
F007a04.npy  F030a11.npy  F056a01.npy	N004a057.npy  N005a005.npy
F007a05.npy  F030a12.npy 