In [1]:
# !pip3 install albumentations==1.2.1
# !pip3 install -e /data/pathology/projects/pathology-lung-TIL/nnUNet_v2/
# Restart the kernel after installing so the newly installed packages are picked up.

In [2]:
import numpy as np
from wholeslidedata.samplers.callbacks import BatchCallback
from time import time
import os

from nnunetv2.utilities.file_path_utilities import load_json
from nnunetv2.training.nnUNetTrainer.variants.pathology.nnUNetTrainer_custom_dataloader_test import nnUNetTrainer_custom_dataloader_test


from batchgenerators.transforms.abstract_transforms import AbstractTransform, Compose
from batchgenerators.transforms.color_transforms import BrightnessMultiplicativeTransform, \
    ContrastAugmentationTransform, GammaTransform
from batchgenerators.transforms.noise_transforms import GaussianNoiseTransform, GaussianBlurTransform
from batchgenerators.transforms.resample_transforms import SimulateLowResolutionTransform
from batchgenerators.transforms.spatial_transforms import SpatialTransform, MirrorTransform
from batchgenerators.transforms.utility_transforms import RemoveLabelTransform, RenameTransform

from nnunetv2.training.data_augmentation.custom_transforms.pathology_transforms import HedTransform, HsvTransform, Clip01

nnUNet_raw is not defined and nnU-Net can only be used on data for which preprocessed files are already present on your system. nnU-Net cannot be used for experiment planning and preprocessing like this. If this is not intended, please read documentation/setting_up_paths.md for information on how to set this up properly.
INSTEAD USING HARDCODED: /data/pathology/projects/pathology-lung-TIL/nnUNet_v2/data/nnUNet_raw
nnUNet_preprocessed is not defined and nnU-Net can not be used for preprocessing or training. If this is not intended, please read documentation/setting_up_paths.md for information on how to set this up.
INSTEAD USING HARDCODED: /data/pathology/projects/pathology-lung-TIL/nnUNet_v2/data/nnUNet_preprocessed
nnUNet_results is not defined and nnU-Net cannot be used for training or inference. If this is not intended behavior, please read documentation/setting_up_paths.md for information on how to set this up.
INSTEAD USING HARDCODED: /data/pathology/projects/pathology-lung-TIL/nn

In [3]:
# Single-letter tag of the OS this kernel runs on: "w" when os.name is "nt", else "l".
current_os = "w" if os.name == "nt" else "l"
# Tag of the opposite OS.
other_os = "l" if current_os == "w" else "w"

# (linux mount point, windows drive) pairs used by convert_path.
# The archive mounts MUST come before "/mnt/pa_cpg": that string is a textual prefix of
# "/mnt/pa_cpgarchive1" / "/mnt/pa_cpgarchive2", so replacing it first would turn an
# archive path into "Y:archive1..." and the W:/X: mappings would never apply.
_PATH_MOUNTS = (
    ("/mnt/pa_cpgarchive1", "W:"),
    ("/mnt/pa_cpgarchive2", "X:"),
    ("/mnt/pa_cpg", "Y:"),
    ("/data/pathology", "Z:"),
)


def convert_path(path, to=current_os):
    """Translate a path between the Linux mount layout and the Windows drive layout.

    Args:
        path: path string to convert.
        to: target convention. "w", "win" or "windows" converts Linux mount points to
            drive letters and "/" to "\\"; "u", "unix", "l" or "linux" does the
            reverse. Defaults to the OS the kernel is running on.

    Returns:
        The converted path. Any other value of ``to`` returns ``path`` unchanged.
    """
    if to in ["w", "win", "windows"]:
        for linux_root, drive in _PATH_MOUNTS:
            path = path.replace(linux_root, drive)
        path = path.replace("/", "\\")
    if to in ["u", "unix", "l", "linux"]:
        for linux_root, drive in _PATH_MOUNTS:
            path = path.replace(drive, linux_root)
        path = path.replace("\\", "/")
    return path

In [4]:
class nnUnetBatchCallback(BatchCallback):
    """Batch callback that applies an nnU-Net style augmentation pipeline.

    The callback converts a channels-last image batch and its label batch to the
    channels-first layout used by the batchgenerators transforms, runs the composed
    pipeline, and converts the result back to channels-last uint8 images.
    The wall-clock time of every augmentation pass is printed.
    """

    # patch_size_spatial (width/height)
    def __init__(self, patch_size_spatial):
        """Build the composed augmentation pipeline.

        Args:
            patch_size_spatial: spatial patch size (width/height) passed to
                ``SpatialTransform``.
        """
        tr_transforms = []
        # Rotation is only sampled around 'x'; the 'y' and 'z' ranges are fixed at 0.
        rotation_for_DA = {'x': (-np.pi, np.pi), 'y': (0, 0), 'z': (0, 0)}

        # Rotation + scaling; elastic deformation and random cropping are disabled.
        tr_transforms.append(SpatialTransform(
            patch_size_spatial,
            patch_center_dist_from_border=None,
            do_elastic_deform=False,
            alpha=(0, 0),
            sigma=(0, 0),
            do_rotation=True,
            angle_x=rotation_for_DA['x'],
            angle_y=rotation_for_DA['y'],
            angle_z=rotation_for_DA['z'],
            p_rot_per_axis=1,  # todo experiment with this
            do_scale=True,
            scale=(0.7, 1.4),
            border_mode_data="constant",
            border_cval_data=0,
            order_data=3,
            border_mode_seg="constant",
            border_cval_seg=-1,
            order_seg=1,
            random_crop=False,  # random cropping is part of our dataloaders
            p_el_per_sample=0,
            p_scale_per_sample=0.2,
            p_rot_per_sample=0.2,
            independent_scale_for_each_axis=False  # todo experiment with this
        ))

        # Colour augmentation: HED is always on (former `do_hed` switch).
        tr_transforms.append(HedTransform(factor=0.05))
        # HSV alternative, currently disabled (former `do_hsv` switch):
        #     tr_transforms.append(HsvTransform(h_lim=0.10, s_lim=0.10, v_lim=0.10))

        tr_transforms.append(GaussianNoiseTransform(p_per_sample=0.1))

        tr_transforms.append(GaussianBlurTransform((0.5, 1.), different_sigma_per_channel=True, p_per_sample=0.2,
                                                   p_per_channel=0.5))

        tr_transforms.append(BrightnessMultiplicativeTransform(multiplier_range=(0.75, 1.25), p_per_sample=0.15))

        tr_transforms.append(ContrastAugmentationTransform(p_per_sample=0.15))

        tr_transforms.append(SimulateLowResolutionTransform(zoom_range=(0.5, 1), per_channel=True,
                                                            p_per_channel=0.5,
                                                            order_downsample=0, order_upsample=3, p_per_sample=0.25,
                                                            ignore_axes=None))

        tr_transforms.append(GammaTransform((0.7, 1.5), True, True, retain_stats=True, p_per_sample=0.1))

        tr_transforms.append(GammaTransform((0.7, 1.5), False, True, retain_stats=True, p_per_sample=0.3))

        # Clip01 runs after the intensity transforms and before the images are scaled
        # back by 255 in __call__.
        tr_transforms.append(Clip01())

        tr_transforms.append(MirrorTransform((0, 1)))

        # __call__ reads the labels back under the 'target' key.
        tr_transforms.append(RenameTransform('seg', 'target', True))

        self._transforms = Compose(tr_transforms)

    def __call__(self, x_batch, y_batch):
        """Augment one batch.

        Args:
            x_batch: sequence of channels-last images; each is divided by 255
                (presumably 0-255 RGB patches — confirm against the iterator).
            y_batch: sequence of label maps, one per image.

        Returns:
            Tuple ``(images, labels)``: images as a channels-last ``uint8`` array and
            labels with the channel axis (axis 1) removed. The batch axis is always
            kept, also for a batch of a single sample.
        """
        # format to nnUNet: channels-first float32 images, labels with a channel axis
        x_batch = np.stack([x/255  for x in x_batch]).transpose((0, 3, 1, 2)).astype('float32')
        y_batch = np.expand_dims(np.stack(y_batch).astype('int8'), 1)

        # transform
        start_time = time()
        batch = self._transforms(**{'data': x_batch, 'seg': y_batch})
        line_time = time() - start_time
        print("Time taken for AUG (callback, multi thread):\t\t\t\t\t\t\t", line_time)

        # format back to wsd
        x_batch, y_batch = batch['data'], batch['target']
        x_batch = np.multiply(x_batch, 255).astype(np.uint8)
        # Remove only the channel axis added above; a bare squeeze() would also drop
        # the batch axis when the batch holds a single sample.
        return x_batch.transpose((0, 2, 3, 1)), np.squeeze(y_batch, axis=1)

In [5]:
# Cross-validation fold whose results are inspected.
fold = 0

# Results directory of the trained configuration, translated to the current OS.
trainer_results_folder = convert_path(
    '/data/pathology/projects/pathology-lung-TIL/nnUNet_v2/data/nnUNet_results/Dataset004_TIGER_split/'
    'nnUNetTrainer_custom_dataloader_test__nnUNetWholeSlideDataPlans__wsd_balanced_iterator_nnunet_aug__2d'
)

# Best checkpoint of the selected fold.
checkpoint_path = os.path.join(trainer_results_folder, 'fold_{}'.format(fold), 'checkpoint_best.pth')

# plans.json is read first, dataset.json second.
plans_dict, dataset_dict = (
    load_json(os.path.join(trainer_results_folder, json_name))
    for json_name in ('plans.json', 'dataset.json')
)

In [6]:
# Display the dataset description that was read from dataset.json.
dataset_dict

{'channel_names': {'0': 'rgb_to_0_1', '1': 'rgb_to_0_1', '2': 'rgb_to_0_1'},
 'labels': {'background': 0,
  'invasive tumor': 1,
  'tumor-associated stroma': 2,
  'in-situ tumor': 3,
  'healthy glands': 4,
  'necrosis not in-situ': 5,
  'inflamed stroma': 6,
  'rest': 7},
 'Unet_max_patch_size': [2048, 2048]}

In [7]:
# Hand-written dataset description that replaces the one loaded from disk.
# 'roi' shares index 0 with 'background'; the remaining labels are numbered from 1
# in the order listed. `data` is a second name bound to the same dict object.
dataset_dict = data = dict(
    channel_names={str(channel): 'rgb_to_0_1' for channel in range(3)},
    labels={
        'background': 0,
        'roi': 0,
        **{
            label_name: label_index
            for label_index, label_name in enumerate(
                (
                    'invasive tumor',
                    'tumor-associated stroma',
                    'in-situ tumor',
                    'healthy glands',
                    'necrosis not in-situ',
                    'inflamed stroma',
                    'rest',
                ),
                start=1,
            )
        },
    },
    Unet_max_patch_size=[2048, 2048],
)

In [8]:
# Instantiate the custom-dataloader trainer from the loaded plans, the '2d'
# configuration, the selected fold and the current dataset description.
trainer = nnUNetTrainer_custom_dataloader_test(plans_dict, '2d', fold, dataset_dict)

Using device: cuda:0

#######################################################################
Please cite the following paper when using nnU-Net:
Isensee, F., Jaeger, P. F., Kohl, S. A., Petersen, J., & Maier-Hein, K. H. (2021). nnU-Net: a self-configuring method for deep learning-based biomedical image segmentation. Nature methods, 18(2), 203-211.
#######################################################################



In [9]:
# Display dataset_dict again after the trainer has been constructed from it.
dataset_dict

{'channel_names': {'0': 'rgb_to_0_1', '1': 'rgb_to_0_1', '2': 'rgb_to_0_1'},
 'labels': {'background': 0,
  'roi': 0,
  'invasive tumor': 1,
  'tumor-associated stroma': 2,
  'in-situ tumor': 3,
  'healthy glands': 4,
  'necrosis not in-situ': 5,
  'inflamed stroma': 6,
  'rest': 7},
 'Unet_max_patch_size': [2048, 2048]}

In [11]:
# Ask the trainer for the training and validation batch iterators, requesting the
# data subset (subset=True) and 16 CPUs for the iterators (cpus=16).
tr_it, val_it = trainer.get_dataloaders(subset=True, cpus= 16)





USING DATA SUBSET




Found splits.json
[Getting WSD dataloaders]
[ITERATOR TEMPLATE] Using iterator template: /data/pathology/projects/pathology-lung-TIL/nnUNet_v2/nnunetv2/training/nnUNetTrainer/variants/pathology/wsd_roi_iterator_nnunet_aug_template.json
Still timing everything, only copying SOME training and val files
Taking random seed 384 for iterators
cpus used for iterators =  16
[Creating batch iterators]
	[Creating TRAIN batch iterator]
	[Creating VAL batch iterator]
[Returning batch iterators]


Process CommanderForkProcess-22:
Process ProducerForkProcess-31:
Process ProducerForkProcess-30:
Process ProducerForkProcess-25:
Process ProducerForkProcess-37:
Process ProducerForkProcess-8:
Process ProducerForkProcess-23:
Process ProducerForkProcess-24:
Process ProducerForkProcess-13:
Process ProducerForkProcess-6:
Process ProducerForkProcess-4:
Process ProducerForkProcess-35:
Process ProducerForkProcess-26:
Process ProducerForkProcess-27:
Process ProducerForkProcess-38:
Process ProducerForkProcess-29:
Process ProducerForkProcess-36:
Process ProducerForkProcess-33:
Process ProducerForkProcess-17:
Process ProducerForkProcess-28:
Process ProducerForkProcess-12:
Process ProducerForkProcess-32:
Process ProducerForkProcess-34:
Traceback (most recent call last):
  File "/usr/local/lib/python3.9/multiprocessing/process.py", line 315, in _bootstrap
    self.run()
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent ca

In [None]:
# Benchmark the training iterator: fetch 50 batches and print the seconds each fetch
# took, followed by the total. `start` is reset after the print, so printing is
# excluded from the per-batch figures but included in the total.
total_start = time()
start = time()
for i in range(50):
    batch = next(tr_it)
    print(time() - start)
    start = time()
print(time() - total_start)

In [9]:
# Display the name of the iterator template the trainer is configured with.
trainer.iterator_template

'wsd_balanced_iterator_nnunet_aug'

In [10]:
# Spatial patch size handed to SpatialTransform in the stand-alone pipeline below.
patch_size_spatial = [512, 512]

In [11]:
# Stand-alone augmentation pipeline, built at notebook level so it can be timed
# directly on a fetched batch. Transforms are listed in the order they are applied.

# Rotation is only sampled around 'x'; the 'y' and 'z' ranges are fixed at 0.
rotation_for_DA = dict(x=(-np.pi, np.pi), y=(0, 0), z=(0, 0))

tr_transforms = [
    # Rotation + scaling; elastic deformation and random cropping are switched off.
    SpatialTransform(
        patch_size_spatial,
        patch_center_dist_from_border=None,
        do_elastic_deform=False,
        alpha=(0, 0),
        sigma=(0, 0),
        do_rotation=True,
        angle_x=rotation_for_DA['x'],
        angle_y=rotation_for_DA['y'],
        angle_z=rotation_for_DA['z'],
        p_rot_per_axis=1,  # todo experiment with this
        do_scale=True,
        scale=(0.7, 1.4),
        border_mode_data="constant",
        border_cval_data=0,
        order_data=3,
        border_mode_seg="constant",
        border_cval_seg=-1,
        order_seg=1,
        random_crop=False,  # random cropping is part of our dataloaders
        p_el_per_sample=0,
        p_scale_per_sample=0.2,
        p_rot_per_sample=0.2,
        independent_scale_for_each_axis=False,  # todo experiment with this
    ),
    # Colour augmentation: HED is always on (former `do_hed` switch). The HSV
    # alternative stays disabled:
    #     HsvTransform(h_lim=0.10, s_lim=0.10, v_lim=0.10)
    HedTransform(factor=0.05),
    GaussianNoiseTransform(p_per_sample=0.1),
    GaussianBlurTransform(
        (0.5, 1.),
        different_sigma_per_channel=True,
        p_per_sample=0.2,
        p_per_channel=0.5,
    ),
    BrightnessMultiplicativeTransform(multiplier_range=(0.75, 1.25), p_per_sample=0.15),
    ContrastAugmentationTransform(p_per_sample=0.15),
    SimulateLowResolutionTransform(
        zoom_range=(0.5, 1),
        per_channel=True,
        p_per_channel=0.5,
        order_downsample=0,
        order_upsample=3,
        p_per_sample=0.25,
        ignore_axes=None,
    ),
    GammaTransform((0.7, 1.5), True, True, retain_stats=True, p_per_sample=0.1),
    GammaTransform((0.7, 1.5), False, True, retain_stats=True, p_per_sample=0.3),
    Clip01(),
    MirrorTransform((0, 1)),
    # Labels leave the pipeline under the 'target' key instead of 'seg'.
    RenameTransform('seg', 'target', True),
]

transforms = Compose(tr_transforms)

In [12]:
from collections import OrderedDict

In [13]:
# Fetch one batch from the training iterator and keep it as an OrderedDict so the
# same batch can be reused by the timing cells.
batch = OrderedDict(next(tr_it))

In [14]:
# Time 10 passes of the stand-alone pipeline over the same batch and print the
# seconds per pass. The .cpu().numpy() conversions of the inputs are evaluated
# between the two time() calls, so they are part of every measurement.
# The augmented result is discarded.
for i in range(10):
    start = time()
    transforms(**{'data': batch['data'].cpu().numpy(), 'seg': batch['target'][0].cpu().numpy()})
    end = time()
    print(end-start)


KeyboardInterrupt



In [None]:
# List the entries available in the fetched batch.
batch.keys()

In [None]:
# Stop both batch iterators before benchmarking the callback on the cached batch.
# NOTE(review): presumably this shuts down the iterator worker processes — confirm.
tr_it.stop()
val_it.stop()

In [None]:
# Build the augmentation callback for 512 x 512 patches.
cb = nnUnetBatchCallback([512, 512])

In [None]:
# Shape of batch['data'] once moved to the CPU and converted to a NumPy array.
batch['data'].cpu().numpy().shape

In [None]:
# Shape of the first entry of batch['target'] as a NumPy array.
batch['target'][0].cpu().numpy().shape

In [None]:
# NumPy copies of the cached batch, used as input for the callback benchmark:
# images from batch['data'], labels from the first entry of batch['target'].
x_batch = batch['data'].cpu().numpy()
y_batch = batch['target'][0].cpu().numpy()

In [None]:
# Display the shapes of the image and label arrays side by side.
x_batch.shape, y_batch.shape

In [None]:
# Benchmark the callback: 50 augmentation passes over the same cached batch.
# nnUnetBatchCallback.__call__ iterates axis 0 as the batch and applies
# transpose((0, 3, 1, 2)) to go from channels-last to channels-first, so the
# channels-first x_batch has to be permuted with the inverse, (0, 2, 3, 1).
# The previous (1, 2, 3, 0) moved the batch axis to the end instead.
# NOTE(review): assumes x_batch is (batch, channel, height, width), as it is when
# batch['data'] is passed straight to `transforms` — confirm with the shape cell.
# NOTE(review): the callback divides the images by 255; if x_batch is already scaled
# to 0-1 the augmented values are not representative, only the timing is.
x_batch_wsd = x_batch.transpose((0, 2, 3, 1))
y_batch_wsd = y_batch.squeeze()
for i in range(50):
    cb(x_batch_wsd, y_batch_wsd)

In [None]:
# x_batch