In [1]:
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning import loggers as pl_loggers
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from torchvision.transforms import Compose, ToTensor, RandomCrop, RandomRotation, Normalize

from repath.utils.paths import project_root
import repath.data.datasets.camelyon16 as camelyon16
from repath.preprocess.tissue_detection import TissueDetectorOTSU
from repath.preprocess.patching import GridPatchFinder, SlidesIndex, SlidesIndexResults
from repath.preprocess.sampling import split_camelyon16, balanced_sample
from torchvision.models import GoogLeNet

In [2]:
# Name of the experiment; it selects the sub-directory under <project_root>/experiments.
experiment_name = "wang"
# Root directory under which the cells below read the checkpoint and patch index
# and write the prediction results.
experiment_root = project_root() / "experiments" / experiment_name
# Tissue detector instance (presumably Otsu thresholding, judging by the class name).
# It is not referenced again in the cells visible in this notebook.
tissue_detector = TissueDetectorOTSU()

In [4]:
class PatchClassifier(pl.LightningModule):
    """Two-class patch classifier built on torchvision's GoogLeNet.

    Training uses the main logits plus both auxiliary heads (each weighted 0.3);
    validation and inference use the main logits only.
    """

    def __init__(self) -> None:
        super().__init__()
        self.model = GoogLeNet(num_classes=2)

    def forward(self, x):
        """Run the wrapped GoogLeNet on a batch of images.

        Without this method, calling the module directly (as the inference loops do
        with ``model(data)``) falls through to ``nn.Module.forward`` and raises
        ``NotImplementedError``.

        Args:
            x: batch of input images.

        Returns:
            Whatever the wrapped model returns: a plain logits tensor in eval mode,
            or the (logits, aux_logits2, aux_logits1) tuple in training mode.
        """
        return self.model(x)

    def training_step(self, batch, batch_idx):
        """Compute and log the combined training loss and the batch accuracy.

        Args:
            batch: tuple of (images, integer class targets).
            batch_idx: index of the batch within the epoch (unused).

        Returns:
            The scalar loss: main loss + 0.3 * each auxiliary loss.
        """
        x, y = batch
        # In training mode GoogLeNet yields the main logits and both auxiliary outputs.
        output, aux2, aux1 = self.model(x)

        # F.cross_entropy with default arguments matches nn.CrossEntropyLoss().
        loss1 = F.cross_entropy(output, y)
        loss2 = F.cross_entropy(aux1, y)
        loss3 = F.cross_entropy(aux2, y)
        loss = loss1 + 0.3 * loss2 + 0.3 * loss3
        self.log("train_loss", loss)

        # argmax over logits equals argmax over log-softmax, so no softmax is needed.
        correct = output.argmax(dim=1).eq(y).sum().item()
        total = len(y)
        accu = correct / total
        self.log("train_accuracy", accu)

        return loss

    def validation_step(self, batch, batch_idx):
        """Compute and log the validation loss and the batch accuracy.

        Args:
            batch: tuple of (images, integer class targets).
            batch_idx: index of the batch within the epoch (unused).

        Returns:
            The scalar cross-entropy loss on the main logits.
        """
        x, y = batch
        # Expects a single logits tensor here, i.e. the model must be in eval mode.
        output = self.model(x)

        loss = F.cross_entropy(output, y)
        self.log("val_loss", loss)

        correct = output.argmax(dim=1).eq(y).sum().item()
        total = len(y)
        accu = correct / total
        self.log("val_accuracy", accu)
        return loss

    def configure_optimizers(self):
        """Create the SGD optimiser and a step-wise learning-rate schedule.

        Returns:
            A ([optimizer], [scheduler_config]) pair. The scheduler halves the
            learning rate every 50000 optimiser steps ('interval': 'step').
        """
        optimizer = torch.optim.SGD(self.model.parameters(),
                                    lr=0.01,
                                    momentum=0.9,
                                    weight_decay=0.0005)
        scheduler = {
            'scheduler': torch.optim.lr_scheduler.StepLR(optimizer, step_size=50000, gamma=0.5),
            'interval': 'step'
        }
        return [optimizer], [scheduler]

In [12]:
import numpy as np
import torch

from repath.postprocess.slide_dataset import SlideDataset

def evaluate_loop_dp(model, device, loader, num_classes):
    """Run inference over a loader and collect softmax probabilities per sample.

    When more than one CUDA device is visible the model is wrapped in
    ``torch.nn.DataParallel`` first.

    Args:
        model: module mapping a batch of inputs to per-class scores.
        device: torch device the model and each batch are moved to.
        loader: DataLoader yielding (data, target) pairs; targets are ignored.
        num_classes: number of columns in the model output.

    Returns:
        A float ndarray of shape (len(loader) * loader.batch_size, num_classes).
        Rows beyond the number of samples actually seen (a short final batch)
        are left as zeros.
    """
    gpu_count = torch.cuda.device_count()
    if gpu_count > 1:
        print("Let's use", gpu_count, "GPUs!")
        model = torch.nn.DataParallel(model)

    model.eval()
    model.to(device)

    batch_size = loader.batch_size
    n_batches = len(loader)

    # One row per potential sample, sized as if every batch were full.
    prob_out = np.zeros((n_batches * batch_size, num_classes))
    to_probs = torch.nn.Softmax(1)

    with torch.no_grad():
        for batch_no, (inputs, _targets) in enumerate(loader):
            # rows: samples in this batch, cols: num_classes
            batch_probs = to_probs(model(inputs.to(device))).cpu().numpy()

            first_row = batch_no * batch_size
            last_row = first_row + batch_probs.shape[0]
            prob_out[first_row:last_row, :] = batch_probs

            if batch_no % 100 == 0:
                print('Batch {} of {}'.format(batch_no, n_batches))

    return prob_out


def evaluate_loop_threaded(model, device, loader, num_classes):
    """Run inference on a single device and collect softmax probabilities.

    Args:
        model: module mapping a batch of inputs to per-class scores.
        device: torch device the model and each batch are moved to.
        loader: DataLoader yielding (data, target) pairs; targets are ignored.
        num_classes: number of columns in the model output.

    Returns:
        A float ndarray of shape (len(loader) * loader.batch_size, num_classes).
        Rows that a short final batch does not fill stay at zero.
    """
    model.eval()
    model.to(device)

    total_batches = len(loader)
    rows_per_batch = loader.batch_size
    prob_out = np.zeros((total_batches * rows_per_batch, num_classes))

    softmax = torch.nn.Softmax(1)

    with torch.no_grad():
        for step, (images, _labels) in enumerate(loader):
            scores = model(images.to(device))
            # rows: samples in this batch, cols: num_classes
            probs = softmax(scores).cpu().numpy()

            offset = step * rows_per_batch
            prob_out[offset:offset + probs.shape[0], :] = probs

            if step % 100 == 0:
                print('Batch {} of {}'.format(step, total_batches))

    return prob_out



def inference_on_slide(slideps: 'SlidePatchSet', model: torch.nn.Module, num_classes: int,
                       batch_size: int, num_workers: int, transform) -> np.array:

    """ runs inference for every patch on a slide using data parallel

    Outputs probabilities for each class

    Args:
        slideps: A SlidePatchSet object containing all non background patches for the slide
        model: a patch classifier model
        num_classes: the number of output classes predicted by the model
        batch_size: the batch size for inference
        num_workers: the num_workers for inference
        transform: the transform handed to SlideDataset together with slideps

    Returns:
        An ndarray the same length as the slide dataset with a column for each class containing a float that
        represents the probability of the patch being that class.
    """

    # Use the first GPU if one is available, otherwise fall back to the CPU
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    slide_dataset = SlideDataset(slideps, transform)

    test_loader = torch.utils.data.DataLoader(slide_dataset, shuffle=False,
                                              batch_size=batch_size,  num_workers=num_workers)

    probabilities = evaluate_loop_dp(model, device, test_loader, num_classes)

    ### HACK - ntransforms only needed for google paper need to sort out using transform compose or list of transform compose
    # ntransforms is the number of predictions per patch (e.g. rotations or flips);
    # it is fixed at 1 for now, so the averaging branch below is not exercised.
    ntransforms = 1

    npreds = int(len(slide_dataset) * ntransforms)

    # The evaluate loop pads its output to a whole number of batches; drop the padding rows.
    probabilities = probabilities[0:npreds, :]

    if ntransforms > 1:
        # Rows are laid out as ntransforms consecutive blocks of prob_rows patches.
        # Average the blocks so each patch ends up with one row. (The previous
        # per-class loop iterated over the int num_classes and raised TypeError.)
        prob_rows = probabilities.shape[0] // ntransforms
        probabilities = probabilities.reshape(ntransforms, prob_rows, num_classes).mean(axis=0)

    return probabilities


def inference_on_slide_threaded(slideps: 'SlidePatchSet', model: torch.nn.Module, num_classes: int,
                       batch_size: int, num_workers: int, transform, device) -> np.array:

    """ runs inference for every patch on a slide on one specific device

    Outputs probabilities for each class

    Args:
        slideps: A SlidePatchSet object containing all non background patches for the slide
        model: a patch classifier model
        num_classes: the number of output classes predicted by the model
        batch_size: the batch size for inference
        num_workers: the num_workers for inference
        transform: the transform handed to SlideDataset together with slideps
        device: index of the CUDA device to run on; ignored (CPU is used) when CUDA is unavailable

    Returns:
        An ndarray the same length as the slide dataset with a column for each class containing a float that
        represents the probability of the patch being that class.

    """

    # Use the requested GPU if CUDA is available, otherwise fall back to the CPU
    device = torch.device(f"cuda:{device}" if torch.cuda.is_available() else "cpu")

    slide_dataset = SlideDataset(slideps, transform)

    test_loader = torch.utils.data.DataLoader(slide_dataset, shuffle=False,
                                              batch_size=batch_size,  num_workers=num_workers)

    probabilities = evaluate_loop_threaded(model, device, test_loader, num_classes)

    ### HACK - ntransforms only needed for google paper need to sort out using transform compose or list of transform compose
    # ntransforms is the number of predictions per patch (e.g. rotations or flips);
    # it is fixed at 1 for now, so the averaging branch below is not exercised.
    ntransforms = 1

    npreds = int(len(slide_dataset) * ntransforms)

    # The evaluate loop pads its output to a whole number of batches; drop the padding rows.
    probabilities = probabilities[0:npreds, :]

    if ntransforms > 1:
        # Rows are laid out as ntransforms consecutive blocks of prob_rows patches.
        # Average the blocks so each patch ends up with one row. (The previous
        # per-class loop iterated over the int num_classes and raised TypeError.)
        prob_rows = probabilities.shape[0] // ntransforms
        probabilities = probabilities.reshape(ntransforms, prob_rows, num_classes).mean(axis=0)

    return probabilities



### Below is initial pseudocode for DDP-based inference; it needs developing to speed up inference
#def process_predict(rank, loader_for_subset, model, batch_size, num_classes):
#    model = model.to_gpu(rank)
#    output = Tensor.empty(len(loader_for_subset) * batch_size, num_classes))
#    for batch_idx, batch in enumerate(loader_for_subset):
#        batch = batch.to_gpu(rank)
#        y = model(batch)
#        y = nn.softmax(y)
#        output[batch_idx * batch_size, :] = y
#    return output
#
#
#torch.multiprocessing.spawn(fn, args=(), nprocs=1, join=True, daemon=False, start_method='spawn')#
#
#
#def distibuted_predict(model, patchset):
#    dataset = SlideDataset(patchset)
#    sampler = SequentialSampler(dataset)


In [27]:
from collections import namedtuple
from itertools import chain
from pathlib import Path
import threading
from typing import List, Sequence

import cv2
import numpy as np
import pandas as pd
from PIL import Image
from sklearn.utils import shuffle
import torch
from torch import nn
from torchvision import transforms

from repath.data.datasets import Dataset
from repath.preprocess.patching.patch_finder import PatchFinder
from repath.preprocess.tissue_detection.tissue_detector import TissueDetector
from repath.data.slides import Region
from repath.utils.convert import remove_item_from_dict
from repath.postprocess.prediction import inference_on_slide


class PatchSet(Sequence):
    """A table of patches taken from one dataset at a single patch size and level.

    Acts as a sequence: its length is the number of rows in ``patches_df`` and
    indexing returns rows of that frame by position.
    """
    def __init__(
        self,
        dataset: Dataset,
        patch_size: int,
        level: int,
        patches_df: pd.DataFrame,
    ) -> None:
        """Store the dataset, the patch geometry and the patch table.

        Args:
            dataset: the dataset the patches come from; ``summary`` reads its ``labels`` mapping.
            patch_size: size of each patch.
            level: slide level of the patches.
            patches_df: one row per patch; ``summary`` requires a ``label`` column.
        """
        self.dataset = dataset
        self.patch_size = patch_size
        self.level = level
        self.patches_df = patches_df

    def __len__(self) -> int:
        """Return the number of patches (rows of ``patches_df``)."""
        return len(self.patches_df)

    def __getitem__(self, idx):
        """Return the row(s) of ``patches_df`` at positional index ``idx``."""
        return self.patches_df.iloc[idx,]

    def summary(self) -> pd.DataFrame:
        """Count the patches per label.

        Returns:
            A frame with one column per label name in ``dataset.labels`` holding the
            number of patches with that label (0 for labels with no patches).
        """
        # number of patches for each label value present in the table
        by_label = self.patches_df.groupby("label").size()
        # invert dataset.labels (name -> value) into a value -> name lookup
        labels = {v: k for k, v in self.dataset.labels.items()}
        # one row of counts, with the columns renamed from label value to label name
        count_df = by_label.to_frame().T.rename(columns = labels)
        columns = list(labels.values())
        summary = pd.DataFrame(columns=columns)
        for l in labels.values():
            if l in count_df:
                summary[l] = count_df[l]
            else:
                summary[l] = 0
        summary = summary.replace(np.nan, 0)  # if there are no patches for some classes
        return summary


class CombinedPatchSet(PatchSet):
    """A PatchSet whose rows may come from several slides of one dataset.

    ``patches_df`` carries a ``slide_idx`` column identifying the slide of each patch.
    """

    def __init__(self, dataset: Dataset, patch_size: int, level: int, patches_df: pd.DataFrame) -> None:
        super().__init__(dataset, patch_size, level, patches_df)
        # columns of patches_df are x, y, label, slide_idx

    def save_patches(self, output_dir: Path, transforms: List[transforms.Compose] = None) -> None:
        """Read every patch from its slide and write it as a PNG under ``output_dir/<label name>/``.

        Files are named ``<relative slide path without extension, '/' replaced by '-'>-<x>-<y>.png``.

        Args:
            output_dir: root directory for the patch images; label sub-directories are created as needed.
            transforms: optional list of transforms. When given, each patch is passed through
                ``transforms[row.transform]``, so ``patches_df`` must then have a ``transform`` column.
        """
        # value -> name lookup; it is the same for every patch, so build it once
        label_names = {v: k for k, v in self.dataset.labels.items()}
        created_dirs = set()

        for slide_idx, group in self.patches_df.groupby('slide_idx'):
            slide_path, _, _, _ = self.dataset[slide_idx]
            rel_slide_path = self.dataset.to_rel_path(slide_path)
            # Strip the real file extension whatever its length, rather than
            # assuming it is three characters long as the old `[:-4]` slice did.
            slide_name_str = str(Path(rel_slide_path).with_suffix('')).replace('/', '-')

            with self.dataset.slide_cls(slide_path) as slide:
                print(f"Writing patches for {rel_slide_path}")
                for row in group.itertuples():
                    # read the patch image from the slide
                    region = Region.patch(row.x, row.y, self.patch_size, self.level)
                    image = slide.read_region(region)

                    # apply any transforms, as indexed in the 'transform' column
                    if transforms:
                        image = transforms[row.transform](image)

                    # ensure the output directory for this label exists (checked once per label)
                    output_subdir = output_dir / label_names[row.label]
                    if output_subdir not in created_dirs:
                        output_subdir.mkdir(parents=True, exist_ok=True)
                        created_dirs.add(output_subdir)

                    # write out the patch
                    patch_filename = slide_name_str + f"-{row.x}-{row.y}.png"
                    image.save(output_subdir / patch_filename)


class CombinedIndex(object):
    """Patches from several CombinedPatchSets gathered into one table.

    ``patches_df`` is the concatenation of the input tables plus a ``cps_idx`` column
    giving the position of the source patch set; ``datasets``, ``patchsizes`` and
    ``levels`` are indexed by that same position.
    """

    def __init__(self, cps: List[CombinedPatchSet]) -> None:
        self.datasets = [cp.dataset for cp in cps]
        self.patchsizes = [cp.patch_size for cp in cps]
        self.levels = [cp.level for cp in cps]
        patches_dfs = [cp.patches_df for cp in cps]
        patches_df = pd.concat(patches_dfs, axis=0)
        # one cps index per row, in the same order as the concatenated frames
        cps_index = [[idx] * len(cp) for idx, cp in enumerate(cps)]
        cps_index = [item for sublist in cps_index for item in sublist]
        patches_df['cps_idx'] = cps_index
        self.patches_df = patches_df

    def __len__(self):
        """Return the total number of patches across all patch sets."""
        return len(self.patches_df)

    @classmethod
    def for_slide_indexes(cls, indexes: List['SlidesIndex']) -> 'CombinedIndex':
        """Build a CombinedIndex from the combined form of each given SlidesIndex."""
        cps = [index.as_combined() for index in indexes]
        ci = cls(cps)
        return ci

    def save_patches(self, output_dir: Path, transforms: List[transforms.Compose] = None, affix: str = '') -> None:
        """Read every patch from its slide and write it as a PNG under ``output_dir/<label name>/``.

        Files are named
        ``<relative slide path without extension, '/' replaced by '-'>-<x>-<y><affix>.png``.

        Args:
            output_dir: root directory for the patch images; label sub-directories are created as needed.
            transforms: optional list of transforms. When given, each patch is passed through
                ``transforms[row.transform]``, so ``patches_df`` must then have a ``transform`` column.
            affix: text inserted in the file name just before ``.png``.
        """
        created_dirs = set()

        for cps_idx, cps_group in self.patches_df.groupby('cps_idx'):
            dataset = self.datasets[cps_idx]
            patch_size = self.patchsizes[cps_idx]
            level = self.levels[cps_idx]
            # value -> name lookup; it only depends on the dataset, so build it once per patch set
            label_names = {v: k for k, v in dataset.labels.items()}

            for slide_idx, sl_group in cps_group.groupby('slide_idx'):
                slide_path, _, _, _ = dataset[slide_idx]
                rel_slide_path = dataset.to_rel_path(slide_path)
                # Strip the real file extension whatever its length, rather than
                # assuming it is three characters long as the old `[:-4]` slice did.
                slide_name_str = str(Path(rel_slide_path).with_suffix('')).replace('/', '-')

                with dataset.slide_cls(slide_path) as slide:
                    print(f"Writing patches for {rel_slide_path}")
                    for row in sl_group.itertuples():
                        # read the patch image from the slide
                        region = Region.patch(row.x, row.y, patch_size, level)
                        image = slide.read_region(region)

                        # apply any transforms, as indexed in the 'transform' column
                        if transforms:
                            image = transforms[row.transform](image)

                        # ensure the output directory for this label exists (checked once per label)
                        output_subdir = output_dir / label_names[row.label]
                        if output_subdir not in created_dirs:
                            output_subdir.mkdir(parents=True, exist_ok=True)
                            created_dirs.add(output_subdir)

                        # write out the patch
                        patch_filename = slide_name_str + f"-{row.x}-{row.y}{affix}.png"
                        image.save(output_subdir / patch_filename)



class SlidePatchSet(PatchSet):
    """The patches of one slide, plus that slide's metadata from the dataset.

    ``dataset[slide_idx]`` is unpacked as (slide path, annotation path, label, tags);
    the tags are a ';'-separated string that is stored as a list of stripped strings.
    """

    def __init__(
        self,
        slide_idx: int,
        dataset: Dataset,
        patch_size: int,
        level: int,
        patches_df: pd.DataFrame
    ) -> None:
        super().__init__(dataset, patch_size, level, patches_df)
        self.slide_idx = slide_idx

        full_path, annotation_path, slide_label, raw_tags = dataset[slide_idx]
        self.annotation_path = annotation_path
        self.label = slide_label
        # the slide path is kept relative to the dataset; see abs_slide_path for the reverse
        self.slide_path = dataset.to_rel_path(full_path)
        self.tags = list(map(str.strip, raw_tags.split(';')))

    @classmethod
    def index_slide(cls, slide_idx: int, dataset: Dataset, tissue_detector: TissueDetector, patch_finder: PatchFinder):
        """Build the patch set for one slide.

        Renders the slide's annotations at ``patch_finder.labels_level``, zeroes every
        pixel outside the tissue mask, and hands the result to ``patch_finder``.

        Args:
            slide_idx: index of the slide within ``dataset``.
            dataset: the dataset holding the slide and its annotations.
            tissue_detector: callable turning a slide thumbnail into a boolean mask.
            patch_finder: callable returning (patches frame, level, patch size).

        Returns:
            A new instance of ``cls`` for the slide.
        """
        slide_path, annotation_path, _, _ = dataset[slide_idx]
        labels_level = patch_finder.labels_level

        with dataset.slide_cls(slide_path) as slide:
            print(f"indexing {slide_path.name}")  # TODO: Add proper logging!
            annotations = dataset.load_annotations(annotation_path)

            # label image rendered at the (downsampled) labels level
            shape_at_level = slide.dimensions[labels_level].as_shape()
            labels_image = annotations.render(shape_at_level, 2 ** labels_level)

            # anything outside the tissue mask is set to label value 0
            thumbnail = slide.get_thumbnail(labels_level)
            tissue_mask = tissue_detector(thumbnail)
            labels_image[~tissue_mask] = 0

            patch_dims = slide.dimensions[patch_finder.patch_level]
            found_df, found_level, found_size = patch_finder(labels_image, patch_dims)
            return cls(slide_idx, dataset, found_size, found_level, found_df)

    @property
    def abs_slide_path(self):
        """Absolute path of the slide, resolved through the dataset."""
        return self.dataset.to_abs_path(self.slide_path)

    def open_slide(self):
        """Return the dataset's slide object for this slide (not yet entered as a context)."""
        slide_cls = self.dataset.slide_cls
        return slide_cls(self.abs_slide_path)


class SlidesIndex(Sequence):
    """An ordered collection of SlidePatchSets belonging to one dataset.

    Can be saved as one CSV of patches per slide plus an ``index.csv`` listing
    them, and loaded back from that layout.
    """

    def __init__(self, dataset: Dataset, patches: List[SlidePatchSet]) -> None:
        self.dataset = dataset
        self.patches = patches

    @classmethod
    def index_dataset(cls, dataset: Dataset, tissue_detector: TissueDetector, patch_finder: PatchFinder) -> 'SlidesIndex':
        """Index every slide of ``dataset`` with ``SlidePatchSet.index_slide``."""
        patchsets = [SlidePatchSet.index_slide(idx, dataset, tissue_detector, patch_finder) for idx in range(len(dataset))]
        return cls(dataset, patchsets)

    def __len__(self):
        return len(self.patches)

    def __getitem__(self, idx):
        return self.patches[idx]

    @property
    def patch_size(self):
        """Patch size of the first patch set (raises IndexError when the index is empty)."""
        return self.patches[0].patch_size

    @property
    def level(self):
        """Level of the first patch set (raises IndexError when the index is empty)."""
        return self.patches[0].level

    def summary(self) -> pd.DataFrame:
        """Return per-slide patch counts.

        Returns:
            One row per slide with columns ``slide_path``, ``slide_label`` and one
            count column per label name in ``dataset.labels``.
        """
        summaries = [p.summary() for p in self.patches]
        slide_path = [p.slide_path for p in self.patches]
        slide_label = [p.label for p in self.patches]
        rtn = pd.concat(summaries)
        rtn['slide_path'] = slide_path
        rtn['slide_label'] = slide_label
        rtn = rtn.reset_index(drop=True)
        rtn = rtn[['slide_path', 'slide_label'] + list(self.dataset.labels.keys())]
        return rtn

    def as_combined(self) -> CombinedPatchSet:
        """Concatenate all slides' patches into one CombinedPatchSet with a ``slide_idx`` column."""
        # combine all patchsets into one
        frames = [ps.patches_df for ps in self.patches]
        slide_indexes = [[ps.slide_idx]*len(ps) for ps in self.patches]
        slide_indexes = list(chain(*slide_indexes))
        patches_df =  pd.concat(frames, axis=0)
        patches_df['slide_idx'] = slide_indexes
        return CombinedPatchSet(self.dataset, self.patch_size, self.level, patches_df)

    def save(self, output_dir: Path) -> None:
        """Write one patches CSV per slide and an ``index.csv`` under ``output_dir``.

        Each slide's CSV is written to ``output_dir / <relative slide path>.csv``.
        ``index.csv`` has the columns slide_idx, csv_path, level and patch_size.

        Args:
            output_dir: directory to write into; created if missing.
        """
        columns = ['slide_idx', 'csv_path', 'level', 'patch_size']
        # Collect the rows and build the frame once. DataFrame.append was removed
        # in pandas 2.0, and appending row by row is quadratic.
        records = []
        for ps in self.patches:
            # save out the csv file for this slide
            csv_path = output_dir / ps.slide_path.with_suffix('.csv')
            csv_path.parent.mkdir(parents=True, exist_ok=True)
            print(f"Saving {csv_path}")
            ps.patches_df.to_csv(csv_path, index=False)

            # record information about slide for the index
            # NOTE(review): csv_path includes output_dir, and load() joins it onto
            # input_dir again; that round-trips only when output_dir is absolute — confirm.
            records.append((ps.slide_idx, csv_path, ps.level, ps.patch_size))

        # tidy up a bit and save the csv
        index_df = pd.DataFrame(records, columns=columns)
        index_df = index_df.astype({"level": int, "patch_size": int})
        output_dir.mkdir(parents=True, exist_ok=True)
        index_df.to_csv(output_dir / 'index.csv', index=False)

    @classmethod
    def load(cls, dataset: Dataset, input_dir: Path) -> 'SlidesIndex':
        """Rebuild a SlidesIndex from the files written by ``save``.

        Args:
            dataset: the dataset the saved slide indexes refer to.
            input_dir: directory containing ``index.csv``.

        Returns:
            A new index with one SlidePatchSet per row of ``index.csv``.
        """
        def patchset_from_row(r) -> PatchSet:
            patches_df = pd.read_csv(input_dir / r.csv_path)
            return SlidePatchSet(int(r.slide_idx), dataset, int(r.patch_size),
                                 int(r.level), patches_df)

        index = pd.read_csv(input_dir / 'index.csv')
        patches = [patchset_from_row(r) for r in index.itertuples()]
        rtn = cls(dataset, patches)
        return rtn


class SlidePatchSetResults(SlidePatchSet):
    """A SlidePatchSet whose ``patches_df`` also holds one probability column per class."""

    def __init__(self, slide_idx: int, dataset: Dataset, patch_size: int, level: int, patches_df: pd.DataFrame) -> None:
        super().__init__(slide_idx, dataset, patch_size, level, patches_df)
        # NOTE(review): this repeats the parent's unpacking but stores `tags` as the raw
        # dataset value, overwriting the list of stripped tags set by SlidePatchSet.
        # Kept as-is because callers may rely on it — confirm whether it is intended.
        abs_slide_path, self.annotation_path, self.label, self.tags = dataset[slide_idx]
        self.slide_path = dataset.to_rel_path(abs_slide_path)

    @classmethod
    def _from_probabilities(cls, sps: SlidePatchSet, probs_out, class_names):
        """Attach a probability column per class to the slide's patches and wrap the result."""
        probs_df = pd.DataFrame(probs_out, columns=class_names)
        # concat(axis=1) aligns on the index; reset the patch index so that row i of the
        # patches always pairs with row i of the probabilities.
        patches_df = sps.patches_df.reset_index(drop=True)
        probs_df = pd.concat((patches_df, probs_df), axis=1)
        return cls(sps.slide_idx, sps.dataset, sps.patch_size, sps.level, probs_df)

    @classmethod
    def predict_slide(cls, sps: SlidePatchSet, classifier: nn.Module, batch_size: int, nworkers: int,
                      transform):
        """Predict class probabilities for every patch of a slide.

        Args:
            sps: the slide's patches.
            classifier: the patch classifier.
            batch_size: inference batch size.
            nworkers: number of DataLoader workers.
            transform: transform passed through to the inference function.

        Returns:
            A new results object whose frame is the patches plus one column per
            non-background class of ``sps.dataset.labels``.
        """
        just_patch_classes = remove_item_from_dict(sps.dataset.labels, "background")
        num_classes = len(just_patch_classes)
        probs_out = inference_on_slide(sps, classifier, num_classes, batch_size, nworkers, transform)
        return cls._from_probabilities(sps, probs_out, list(just_patch_classes.keys()))

    def to_heatmap(self, class_name: str) -> np.array:
        """Build a 2-D array with one cell per patch holding that patch's probability.

        Side effects (unchanged from before): the column names of ``patches_df`` are
        lower-cased in place and ``column`` / ``row`` columns (x and y divided by the
        patch size) are added to it.

        Args:
            class_name: name of the probability column to render (case-insensitive).

        Returns:
            A float array of shape (max row + 1, max column + 1); cells without a patch are 0.
        """
        self.patches_df.columns = [colname.lower() for colname in self.patches_df.columns]
        class_name = class_name.lower()

        self.patches_df['column'] = np.divide(self.patches_df.x, self.patch_size)
        self.patches_df['row'] = np.divide(self.patches_df.y, self.patch_size)

        # truncate to integer grid coordinates, as int() did for each row before
        rows = self.patches_df['row'].to_numpy().astype(int)
        cols = self.patches_df['column'].to_numpy().astype(int)

        # create a blank thumbnail
        thumbnail_out = np.zeros((int(rows.max()) + 1, int(cols.max()) + 1))

        # set every patch's cell in one vectorised assignment instead of a per-row loop
        thumbnail_out[rows, cols] = self.patches_df[class_name].to_numpy()

        return thumbnail_out

    def save_csv(self, output_dir):
        """Write ``patches_df`` to ``output_dir / <relative slide path>.csv`` (the directory must exist)."""
        # save out the patches csv file for this slide
        csv_path = output_dir / self.slide_path.with_suffix('.csv')
        self.patches_df.to_csv(csv_path, index=False)

    def save_heatmap(self, class_name: str, output_dir: Path):
        """Write the heatmap for ``class_name`` as an 8-bit PNG at ``output_dir / <relative slide path>.png``."""
        # get the heatmap filename for this slide
        img_path = output_dir / self.slide_path.with_suffix('.png')
        # create heatmap and write out, scaling probabilities by 255
        heatmap = self.to_heatmap(class_name)
        heatmap_out = np.array(np.multiply(heatmap, 255), dtype=np.uint8)
        cv2.imwrite(str(img_path), heatmap_out)

    @classmethod
    def predict_slide_threaded(cls, sps: SlidePatchSet, classifier: nn.Module, batch_size: int, nworkers: int,
                      transform, device: int):
        """Predict class probabilities for every patch of a slide on one given device.

        Same as ``predict_slide`` but runs through ``inference_on_slide_threaded``.

        Args:
            sps: the slide's patches.
            classifier: the patch classifier.
            batch_size: inference batch size.
            nworkers: number of DataLoader workers.
            transform: transform passed through to the inference function.
            device: index of the device to run on.

        Returns:
            A new results object whose frame is the patches plus one column per
            non-background class of ``sps.dataset.labels``.
        """
        just_patch_classes = remove_item_from_dict(sps.dataset.labels, "background")
        num_classes = len(just_patch_classes)
        probs_out = inference_on_slide_threaded(sps, classifier, num_classes, batch_size, nworkers, transform, device)
        return cls._from_probabilities(sps, probs_out, list(just_patch_classes.keys()))


class SlidesIndexResults(SlidesIndex):
    """A SlidesIndex of prediction results, with the locations they are saved to.

    Per-slide result CSVs live under ``output_dir / results_dir_name`` and heatmap
    PNGs under ``output_dir / heatmap_dir_name``.
    """

    def __init__(self, dataset: Dataset, patches: List[SlidePatchSet],
                 output_dir: Path, results_dir_name: str, heatmap_dir_name: str) -> None:
        super().__init__(dataset, patches)
        self.output_dir = output_dir
        self.results_dir_name = results_dir_name
        self.heatmap_dir_name = heatmap_dir_name

    @staticmethod
    def _prepare_output_dirs(output_dir: Path, results_dir_name: str, heatmap_dir_name: str):
        """Create the output, results and heatmap directories; return the latter two."""
        results_dir = output_dir / results_dir_name
        heatmap_dir = output_dir / heatmap_dir_name
        for directory in (output_dir, results_dir, heatmap_dir):
            directory.mkdir(parents=True, exist_ok=True)
        return results_dir, heatmap_dir

    @staticmethod
    def _save_slide_result(sps, spsresult, results_dir: Path, heatmap_dir: Path) -> None:
        """Write one slide's results CSV and heatmap, creating its sub-directories first."""
        slide_subdir = sps.slide_path.parents[0]
        (results_dir / slide_subdir).mkdir(parents=True, exist_ok=True)
        spsresult.save_csv(results_dir)
        (heatmap_dir / slide_subdir).mkdir(parents=True, exist_ok=True)
        ### HACK since this is only binary at the moment it will always be the tumor heatmap we want need to change to work for multiple classes
        spsresult.save_heatmap('tumor', heatmap_dir)

    @classmethod
    def predict_dataset(cls,
                        si: SlidesIndex,
                        classifier: nn.Module,
                        batch_size,
                        num_workers,
                        transform,
                        output_dir: Path,
                        results_dir_name: str,
                        heatmap_dir_name: str) -> 'SlidesIndexResults':
        """Predict every slide of ``si`` in turn, saving each slide's CSV and heatmap.

        Args:
            si: the slides to predict.
            classifier: the patch classifier.
            batch_size: inference batch size.
            num_workers: number of DataLoader workers.
            transform: transform applied to each patch.
            output_dir: root directory for the outputs; created if missing.
            results_dir_name: sub-directory of ``output_dir`` for result CSVs.
            heatmap_dir_name: sub-directory of ``output_dir`` for heatmap PNGs.

        Returns:
            A SlidesIndexResults holding one result per slide, in the order of ``si``.
        """
        results_dir, heatmap_dir = cls._prepare_output_dirs(output_dir, results_dir_name, heatmap_dir_name)

        spsresults = []
        for sps in si:
            spsresult = SlidePatchSetResults.predict_slide(sps, classifier, batch_size, num_workers, transform)
            print(f"Saving {sps.slide_path}")
            spsresults.append(spsresult)
            cls._save_slide_result(sps, spsresult, results_dir, heatmap_dir)

        return cls(si.dataset, spsresults, output_dir, results_dir_name, heatmap_dir_name)

    def save_results_index(self):
        """Write ``results_index.csv`` under ``output_dir``.

        The file has one row per slide with the columns slide_idx, csv_path,
        png_path, level and patch_size.
        """
        columns = ['slide_idx', 'csv_path', 'png_path', 'level', 'patch_size']
        # Collect the rows and build the frame once. DataFrame.append was removed
        # in pandas 2.0, and appending row by row is quadratic.
        records = []
        for ps in self.patches:
            # locations of the csv and heatmap files for this slide
            csv_path = self.output_dir / self.results_dir_name / ps.slide_path.with_suffix('.csv')
            png_path = self.output_dir / self.heatmap_dir_name / ps.slide_path.with_suffix('.png')

            # record information about slide for the index
            records.append((ps.slide_idx, csv_path, png_path, ps.level, ps.patch_size))

        # tidy up a bit and save the csv
        index_df = pd.DataFrame(records, columns=columns)
        index_df = index_df.astype({"level": int, "patch_size": int})
        self.output_dir.mkdir(parents=True, exist_ok=True)
        index_df.to_csv(self.output_dir / 'results_index.csv', index=False)

    @classmethod
    def load_results_index(cls, dataset, input_dir, results_dir_name, heatmap_dir_name):
        """Rebuild a SlidesIndexResults from ``input_dir / 'results_index.csv'``.

        Args:
            dataset: the dataset the saved slide indexes refer to.
            input_dir: directory containing ``results_index.csv``; becomes ``output_dir``.
            results_dir_name: name of the results sub-directory.
            heatmap_dir_name: name of the heatmap sub-directory.

        Returns:
            A new index with one SlidePatchSetResults per row of the index file.
        """
        def patchset_from_row(r: namedtuple) -> SlidePatchSet:
            patches_df = pd.read_csv(input_dir / r.csv_path)
            return SlidePatchSetResults(int(r.slide_idx), dataset, int(r.patch_size),
                                 int(r.level), patches_df)

        index = pd.read_csv(input_dir / 'results_index.csv')
        patches = [patchset_from_row(r) for r in index.itertuples()]
        rtn = cls(dataset, patches, input_dir, results_dir_name, heatmap_dir_name)
        return rtn


    @classmethod
    def predict_dataset_threaded(cls,
                        si: SlidesIndex,
                        classifier: nn.Module,
                        batch_size,
                        num_workers,
                        transform,
                        output_dir: Path,
                        results_dir_name: str,
                        heatmap_dir_name: str) -> 'SlidesIndexResults':
        """Predict the slides of ``si`` with one worker thread per visible GPU.

        The slides are shuffled and split into near-equal chunks; worker ``i``
        predicts chunk ``i`` on device ``i`` and saves each slide's CSV and heatmap.
        With no GPU a single worker runs (the inference function falls back to CPU).

        Args:
            si: the slides to predict.
            classifier: the patch classifier.
            batch_size: inference batch size.
            num_workers: number of DataLoader workers used by each thread.
            transform: transform applied to each patch.
            output_dir: root directory for the outputs; created if missing.
            results_dir_name: sub-directory of ``output_dir`` for result CSVs.
            heatmap_dir_name: sub-directory of ``output_dir`` for heatmap PNGs.

        Returns:
            A SlidesIndexResults holding one result per slide, in shuffled order.

        Raises:
            RuntimeError: if any worker thread failed; the original exception is chained.
        """
        import copy  # only needed here, to give each worker its own model

        results_dir, heatmap_dir = cls._prepare_output_dirs(output_dir, results_dir_name, heatmap_dir_name)

        ### experiment to distribute slides across multigpus for inference
        # find how many gpus; with none, still run one worker
        ngpus = torch.cuda.device_count()
        nthreads = max(ngpus, 1)

        # work out how many slides and numbers in each split
        nslides = len(si)
        splits = np.rint(np.linspace(0, nslides, num=(nthreads + 1))).astype(int)
        start_indexes = splits[0:nthreads]
        end_indexes = splits[1:]
        print("splits:", splits)

        # Shuffle the slides into a separate list. `si` itself must not be rebound,
        # because si.dataset is still needed for the return value.
        shuffled_slides = shuffle(list(si)) if nslides else []
        si_per_gpu = []
        for ii in range(nthreads):
            print(start_indexes[ii], end_indexes[ii])
            si_per_gpu.append(shuffled_slides[start_indexes[ii]:end_indexes[ii]])

        spsresults = [[] for _ in range(nthreads)]
        errors = []

        def worker(num):
            try:
                # The evaluate loop moves the model to its device in place, so
                # concurrent workers must not share one module.
                model = copy.deepcopy(classifier) if nthreads > 1 else classifier
                for sps in si_per_gpu[num]:
                    spsresult = SlidePatchSetResults.predict_slide_threaded(sps, model, batch_size, num_workers, transform, num)
                    print(f"Saving {num}: {sps.slide_path}")
                    spsresults[num].append(spsresult)
                    cls._save_slide_result(sps, spsresult, results_dir, heatmap_dir)
            except Exception as exc:  # surfaced to the caller after all threads finish
                errors.append((num, exc))

        # Start every worker before joining any of them. Joining inside the start
        # loop made the workers run one after another.
        threads = [threading.Thread(target=worker, args=(i,)) for i in range(nthreads)]
        for t in threads:
            t.start()
        for t in threads:
            t.join()

        if errors:
            failed_num, failed_exc = errors[0]
            raise RuntimeError(f"Inference worker {failed_num} failed") from failed_exc

        # the constructor expects a flat list of per-slide results, not one list per worker
        spsresults_flat = [item for sublist in spsresults for item in sublist]

        return cls(si.dataset, spsresults_flat, output_dir, results_dir_name, heatmap_dir_name)
    

In [6]:
# Load the trained patch classifier from the first *.ckpt file found under
# <experiment_root>/patch_model.
# NOTE(review): glob order is not guaranteed, so with several checkpoints the choice
# is arbitrary, and [0] raises IndexError when no checkpoint exists.
cp_path = list((experiment_root / "patch_model").glob("*.ckpt"))[0]
classifier = PatchClassifier.load_from_checkpoint(checkpoint_path=cp_path)

In [7]:
# Root directory for this prediction run's outputs
# (the name suggests results before hard negative mining — TODO confirm).
output_dir16 = experiment_root / "train_index" / "pre_hnm_results"

# Sub-directories of output_dir16 for the per-slide result CSVs and heatmap PNGs.
results_dir_name = "results"
heatmap_dir_name = "heatmaps"

# Load the previously saved patch index for the Camelyon16 training slides.
train16 = SlidesIndex.load(camelyon16.training(), experiment_root / "train_index")


In [8]:
# Per-patch preprocessing used for inference: a 240x240 crop at a random position,
# conversion to a tensor, then per-channel normalisation with mean 0.5 and std 0.5.
# NOTE(review): RandomCrop makes the predictions non-deterministic between runs —
# confirm a random crop (rather than a fixed one) is intended at inference time.
transform = Compose([
    RandomCrop((240, 240)),
    ToTensor(),
    Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

In [None]:
# Predict every training slide (batch size 128, 8 DataLoader workers), writing the
# per-slide result CSVs and heatmaps under output_dir16.
train_results16 = SlidesIndexResults.predict_dataset_threaded(train16, classifier, 128, 8, transform, output_dir16, results_dir_name, heatmap_dir_name)


splits: [  0  27  54  81 108 134 161 188 215]
0 27
27 54
54 81
81 108
108 134
134 161
161 188
188 215


In [None]:
# Write results_index.csv (one row per predicted slide) into output_dir16.
train_results16.save_results_index()