In [1]:
import matplotlib.pyplot as plt
# No figure is created in this cell before the call below.
plt.show()

In [2]:
import matplotlib.pyplot as plt
import os, sys
from typing import Iterable, Dict, List, Callable, Tuple, Union, List

import numpy as np
import torch
from torch import nn, Tensor
from torch.utils.data import DataLoader
import torch.nn.functional as F
import matplotlib.pyplot as plt

sys.path.append('../')
from dataset import ACDCDataset, MNMDataset
from model.unet import UNet2D
from model.ae import AE
from model.dae import resDAE, AugResDAE
from model.wrapper import Frankenstein, ModelAdapter
from losses import DiceScoreMMS
from utils import  epoch_average, UMapGenerator



Please cite the following paper when using nnUNet:

Isensee, F., Jaeger, P.F., Kohl, S.A.A. et al. "nnU-Net: a self-configuring method for deep learning-based biomedical image segmentation." Nat Methods (2020). https://doi.org/10.1038/s41592-020-01008-z


If you have questions or suggestions, feel free to open an issue at https://github.com/MIC-DKFZ/nnUNet



In [3]:
# Instantiate the ACDC dataset with data='train'; later cells index into `acdc_train`.
acdc_train = ACDCDataset(data='train')

loading dataset
loading all case properties


In [None]:
# Fetch one item from the ACDC training set and split it into the input image
# and its target mask; the next cells plot both.
data = acdc_train[0 * 10]
img = data['input']
mask = data['target']



In [None]:
# Show the input image with its singleton dimensions squeezed away.
plt.imshow(img.squeeze())

In [None]:
# Show the pixel-wise sum of image and target mask as a rough overlay.
plt.imshow(img.squeeze() + mask.squeeze())

In [87]:
### - datasets
# Build the M&M dataset for a single vendor and wrap it in a DataLoader.
#acdc_train = ACDCDataset(data='val')



# Forwarded to MNMDataset(debug=...).
debug = False
# NOTE(review): `loader` is not used anywhere else in this cell.
loader = {}
# Vendor identifier forwarded to MNMDataset(vendor=...).
vendor = 'A'

mnm_a = MNMDataset(vendor=vendor, debug=debug)
# One item per batch, original order, no dropped remainder.
mnm_a_loader = DataLoader(mnm_a, batch_size=1, shuffle=False, drop_last=False)


loading dataset
loading all case properties


In [4]:
### - init unets
# Load the ten pre-trained ACDC U-Nets (acdc_unet8_0 ... acdc_unet8_9) into `unets`.
ROOT = '../../'
middle = 'unet8_'
pre = 'acdc'
n_chans_out = 4
unet_names = [f'{pre}_{middle}{i}' for i in range(10)]
unets = []
# Keep the explicit loop: later cells read the loop variable `unet` after this
# cell has run, so it must stay bound at notebook level.
for name in unet_names:
    model_path = f'{ROOT}pre-trained-tmp/trained_UNets/{name}_best.pt'
    # The notebook runs everything on the CPU (outputs are converted with
    # .numpy() directly), so map the checkpoint to the CPU regardless of the
    # device it was saved from.
    state_dict = torch.load(model_path, map_location='cpu')['model_state_dict']
    unet = UNet2D(n_chans_in=1,
                  n_chans_out=n_chans_out,
                  n_filters_init=8,
                  dropout=False)
    unet.load_state_dict(state_dict)
    # The U-Nets are only used for inference here; eval() makes any
    # mode-dependent layers (e.g. normalisation, dropout) behave accordingly.
    unet.eval()
    unets.append(unet)

In [58]:
# init models
# Wrap every U-Net in a Frankenstein model that has an AugResDAE attached to
# layer 'up3', load the matching checkpoint and prepare it for inference.
# NOTE: the second assignment wins; the first `post` value is overwritten immediately.
post = 'localAug_multiImgSingleView_res_balanced_same'
post = 'localAug_multiImgSingleView_recon_balanced_same'
# Layers that get an nn.Identity() transformation and are passed as disabled_ids.
disabled_ids = ['shortcut0', 'shortcut1', 'shortcut2']
models = []
for i, unet in enumerate(unets):
    DAEs = nn.ModuleDict({'up3': AugResDAE(in_channels = 64, 
                                        in_dim      = 32,
                                        latent_dim  = 256,
                                        depth       = 3,
                                        block_size  = 4,
                                        residual    = False),
                         })


    for layer_id in disabled_ids:
        DAEs[layer_id] = nn.Identity()
    
    model = Frankenstein(seg_model=unet,
                         transformations=DAEs,
                         disabled_ids=disabled_ids,
                         copy=True)
    # Checkpoint index `i` matches the index of the U-Net in `unets`.
    model_path = f'{ROOT}pre-trained-tmp/trained_AEs/acdc_AugResDAE{i}_{post}_best.pt'
    # Alternative checkpoints tried previously:
    #model_path = f'{ROOT}pre-trained-tmp/trained_AEs/{pre}_resDAE{i}_{post}_best.pt'
    #model_path = f'{ROOT}pre-trained-tmp/trained_AEs/acdc_epinet_CE-only_prior-1_best.pt'localAug_multiImgSingleView_res
    #model_path = f'{ROOT}pre-trained-tmp/trained_AEs/acdc_resDAE0_venus_best.pt'
    state_dict = torch.load(model_path)['model_state_dict']
    model.load_state_dict(state_dict)
    # Remove training hooks, add evaluation hooks
    model.remove_all_hooks()        
    model.hook_inference_transformations(model.transformations,
                               n_samples=1)
    # Put model in evaluation state
    model.eval()
    model.freeze_seg_model()
    models.append(model)

In [97]:
### Init two models for UNet 0, Reconstruction and Residual
# This cell replaces the `models` list with two entries:
# models[0] = residual variant, models[1] = reconstruction variant.
# NOTE(review): `unet` is not assigned in this cell -- it is whatever an earlier
# cell left bound (the loop variable of the U-Net / model loops), and the
# checkpoint name below is hard-coded to `AugResDAE1`. Confirm that this matches
# the "UNet 0" stated in the heading.

# init models
posts = ['localAug_multiImgSingleView_res_balanced_same', 'localAug_multiImgSingleView_recon_balanced_same']
# residuals[k] is forwarded to AugResDAE(residual=...) for posts[k].
residuals = [True, False]
disabled_ids = ['shortcut0', 'shortcut1', 'shortcut2']
models = []
# `i` is not used inside the loop body.
for i, (post, residual) in enumerate(zip(posts, residuals)):
    DAEs = nn.ModuleDict({'up3': AugResDAE(in_channels = 64, 
                                        in_dim      = 32,
                                        latent_dim  = 256,
                                        depth       = 3,
                                        block_size  = 4,
                                        residual    = residual),
                         })


    for layer_id in disabled_ids:
        DAEs[layer_id] = nn.Identity()
    
    model = Frankenstein(seg_model=unet,
                         transformations=DAEs,
                         disabled_ids=disabled_ids,
                         copy=True)
    model_path = f'{ROOT}pre-trained-tmp/trained_AEs/acdc_AugResDAE1_{post}_best.pt'
    # Alternative checkpoints tried previously:
    #model_path = f'{ROOT}pre-trained-tmp/trained_AEs/{pre}_resDAE{i}_{post}_best.pt'
    #model_path = f'{ROOT}pre-trained-tmp/trained_AEs/acdc_epinet_CE-only_prior-1_best.pt'localAug_multiImgSingleView_res
    #model_path = f'{ROOT}pre-trained-tmp/trained_AEs/acdc_resDAE0_venus_best.pt'
    state_dict = torch.load(model_path)['model_state_dict']
    model.load_state_dict(state_dict)
    # Remove training hooks, add evaluation hooks
    model.remove_all_hooks()        
    model.hook_inference_transformations(model.transformations,
                               n_samples=1)
    # Put model in evaluation state
    model.eval()
    model.freeze_seg_model()
    models.append(model)

In [6]:
from torchvision.utils import save_image
import fiftyone as fo
import numpy as np
import os
import shutil
from torchmetrics import Dice

In [None]:
# try:
#     dataset.delete()
# except:
#     print("dataset name already available or dataset didnt exist")


# # build first dataset
# # init dataset
# dataset = fo.Dataset(name="segmentation_dataset_test")
# # add group and all Groups we need
# dataset.add_group_field("group", default="image")

# # set mask targets
# ## ground truth targets
# dataset.mask_targets = {
#     "ground_truth": {0: "background",
#                      1: "LV",
#                      2: "MYO",
#                      3: "RV"}
# }
# # error map labels
# for i in range(n_unets):
#     dataset.mask_targets[f'errormap_it:{i}'] = {1: 'error'}
    
# # make temporary dir for data handling
# os.makedirs('tmp', exist_ok=True)
# path = 'tmp/'
# # init sample list. We save each sample here and add it to the
# # dataset in the end
# samples = []
# n_unets = 1
# # init dice score class
# dcs = Dice(num_classes=4, ignore_index=0)
# # init umap generator
# umap_generator_entropy = UMapGenerator(method='entropy', net_out='mms')
# umap_generator_AE = UMapGenerator(method='ae', net_out='mms')

# # itertatively make samples
# for i in range(10):
#     # get data
#     data = mnm_a[i*10]
#     img = data['input']
#     mask = data['target']
#     mask[mask < 0] = 0
    
#     # save image to disk
#     img_path  = path + f'test_{i}.png'
#     img_norm  = img - img.min()
#     img_norm /= img_norm.max()

#     save_image(img_norm, img_path)
#     # make sample
#     group = fo.Group()
#     sample_image = fo.Sample(
#         filepath=img_path,
#         group=group.element('image'),
#         ground_truth=fo.Segmentation(
#             mask=mask.squeeze().numpy()
#         )
#     )
    
#     sample_activation = fo.Sample(
#         filepath=img_path,
#         group=group.element('test'),
#         ground_truth=fo.Segmentation(
#             mask=mask.squeeze().numpy()
#         )
#     )
    
#     # predictions, error maps, umaps and foreground dice
#     DSC = torch.zeros((n_unets))
#     for i in range(n_unets):
#         # U-Net - Predictions and errormaps
#         unet_output = unets[i](img.unsqueeze(0))
#         pred = torch.argmax(unet_output, dim=1, keepdims=True)
#         err_map = (pred != mask)
#         sample_image[f'pred_unet_it:{i}'] = fo.Segmentation(mask=pred.squeeze().numpy())
#         sample_image[f'error_unet_it:{i}']  = fo.Segmentation(mask=err_map.squeeze().numpy())
#         # umaps
#         ## entropy
#         umap_entropy = umap_generator_entropy(unet_output)
#         sample_image[f'umap_entropy_it:{i}'] = fo.Heatmap(map=umap_entropy.squeeze().numpy())
#         ## segmentation distortion - predictions with AE and UQ map
#         model_output  = models[i](img.unsqueeze(0))
#         model_pred    = torch.argmax(model_output[1], dim=0, keepdims=False)
#         model_err_map = (model_pred != mask.squeeze())
#         umap_ae = umap_generator_AE(model_output)
#         sample_image[f'pred_ae_it:{i}'] = fo.Segmentation(mask=model_pred.squeeze().numpy())
#         sample_image[f'error_ae_it:{i}']  = fo.Segmentation(mask=model_err_map.squeeze().numpy())
#         sample_image[f'umap_ae_it:{i}'] = fo.Heatmap(map=umap_ae.squeeze().numpy())
        
#         # dice score
#         DSC[i] = dcs(pred.squeeze(), mask.squeeze().int())        
    
#     # get dice related tags per sample
#     DSC = DSC.mean()
#     thresholds = [0.3, 0.5, 0.8]
#     for threshold in thresholds:
#         if DSC < threshold:
#             sample_image.tags.append(f'DSC_below_{int((threshold * 100))}')
        
#     # add sample to sample list
#     samples += [sample_image, sample_activation]
    
# # add samples to dataset
# dataset.add_samples(samples)


# # validate dataset
# print(dataset)

# # export dataset
# export_dir = os.path.expanduser('~') + '/fiftyone/test-dataset'
# dataset.export(export_dir=export_dir, dataset_type=fo.types.FiftyOneDataset)

# # clean up. remove tmp
# shutil.rmtree('tmp')
# #dataset.delete()

In [None]:
# Build the FiftyOne dataset "test": one group per ACDC slice and one group
# element (it_0, it_1, ...) per U-Net. Every sample carries the ground truth,
# the U-Net and AE predictions, their error maps, two uncertainty maps and the
# foreground Dice scores.
n_unets = 2
dataset_name = "test"
export_dir = os.path.expanduser('~') + '/fiftyone/test-dataset'

# Remove a stale dataset of the same name first. Checking the database (instead
# of try/except around `dataset.delete()`) also covers datasets left behind by
# an earlier kernel and does not hide unrelated errors.
if fo.dataset_exists(dataset_name):
    fo.delete_dataset(dataset_name)

# init dataset
dataset = fo.Dataset(name=dataset_name)
# add group field; the first U-Net's element is shown by default
dataset.add_group_field("group", default="it_0")

# Mask targets are keyed by the sample field they describe. They are assigned
# in one go because in-place edits of `dataset.mask_targets` are only persisted
# after an explicit `dataset.save()`.
dataset.mask_targets = {
    "ground_truth": {0: "background",
                     1: "LV",
                     2: "MYO",
                     3: "RV"},
    "error_unet": {1: "error"},
    "error_ae": {1: "error"},
}

# make temporary dir for data handling
os.makedirs('tmp', exist_ok=True)
path = 'tmp/'
# Samples are collected here and added to the dataset in one call at the end.
samples = []
# init dice score class
dcs = Dice(num_classes=4, ignore_index=0)
# init umap generators
umap_generator_entropy = UMapGenerator(method='entropy', net_out='mms')
umap_generator_AE = UMapGenerator(method='ae', net_out='mms')
# A slice is tagged `DSC_below_<t>` for every threshold t its mean Dice stays under.
thresholds = [0.3, 0.5, 0.8]

# iteratively make samples
for i in range(20, 30):
    # get data
    data = acdc_train[i*10]
    img = data['input']
    mask = data['target']
    mask[mask < 0] = 0

    # Save a min-max normalised PNG for the App and the raw tensor next to it;
    # the activation helpers later reload the tensor via the `.pt` path.
    img_path  = path + f'img_{i}.png'
    img_norm  = img - img.min()
    img_norm /= img_norm.max()
    save_image(img_norm, img_path)
    torch.save(img, path + f'img_{i}.pt')

    # init group for slice
    group = fo.Group()
    group_samples = []
    # foreground dice of each U-Net on this slice
    DSC = torch.zeros((n_unets))
    for j in range(n_unets):
        # make sample
        sample_image = fo.Sample(
            filepath=img_path,
            group=group.element(f'it_{j}'),
            ground_truth=fo.Segmentation(
                mask=mask.squeeze().numpy()
            )
        )
        # save additional info to link id to image and all its masks
        # NOTE(review): 'vendor' is 'val' although the data comes from `acdc_train` -- confirm.
        sample_info = {
            'vendor': 'val',
            'slice': i,
            'unet': j
        }
        sample_image['sample_info'] = sample_info

        # U-Net - predictions and error maps
        unet_output = unets[j](img.unsqueeze(0))
        pred = torch.argmax(unet_output, dim=1, keepdims=True)
        err_map = (pred != mask)
        sample_image['pred_unet'] = fo.Segmentation(mask=pred.squeeze().numpy())
        sample_image['error_unet'] = fo.Segmentation(mask=err_map.squeeze().numpy())
        # umaps
        ## entropy
        umap_entropy = umap_generator_entropy(unet_output)
        sample_image['umap_entropy'] = fo.Heatmap(map=umap_entropy.squeeze().numpy())
        ## segmentation distortion - predictions with AE and UQ map
        model_output  = models[j](img.unsqueeze(0))
        model_pred    = torch.argmax(model_output[1], dim=0, keepdims=False)
        model_err_map = (model_pred != mask.squeeze())
        umap_ae = umap_generator_AE(model_output)
        sample_image['pred_ae']  = fo.Segmentation(mask=model_pred.squeeze().numpy())
        sample_image['error_ae'] = fo.Segmentation(mask=model_err_map.squeeze().numpy())
        sample_image['umap_ae']  = fo.Heatmap(map=umap_ae.squeeze().numpy())

        # dice scores (the U-Net score is computed once and reused)
        unet_dsc = dcs(pred.squeeze(), mask.squeeze().int())
        sample_image['unet_dsc'] = unet_dsc.numpy()
        sample_image['model_dsc'] = dcs(model_pred.squeeze(), mask.squeeze().int()).numpy()
        DSC[j] = unet_dsc

        group_samples.append(sample_image)

    # The mean Dice is a property of the slice, so every sample of the group
    # receives the tags -- not only the one created last in the inner loop.
    DSC = DSC.mean()
    for threshold in thresholds:
        if DSC < threshold:
            for sample_image in group_samples:
                sample_image.tags.append(f'DSC_below_{int((threshold * 100))}')

    # add the slice's samples to the sample list
    samples += group_samples

# add samples to dataset
dataset.add_samples(samples)
print()

# # validate dataset
# print(dataset)

# # export dataset
# export_dir = os.path.expanduser('~') + '/fiftyone/test-dataset'
# dataset.export(export_dir=export_dir, dataset_type=fo.types.FiftyOneDataset)

# # clean up. remove tmp
# shutil.rmtree('tmp')
# #dataset.delete()

In [None]:
# Delete the temporary directory 'tmp' and everything in it. Sample file paths
# created in earlier cells point into this directory.
shutil.rmtree('tmp')

In [None]:
# Show whatever `img` currently holds, with singleton dimensions squeezed away.
plt.imshow(img.squeeze())

In [None]:
# dataset.delete()

In [None]:
#dataset_dir = "/home/lennartz/fiftyone/test-dataset"
#dataset_type = fo.types.FiftyOneDataset
#dataset_show = fo.Dataset.from_dir(dataset_dir, dataset_type=dataset_type)

# Open the FiftyOne App for `dataset` on port 8100 and keep the session handle.
session = fo.launch_app(dataset, port=8100)

In [None]:
#fo.close_app()
#session.refresh()

In [None]:
#dataset['651e795cf009e13d026dd8ec']

# Visualizing the effect of our models on activations from certain images, specified by ID

In [7]:
### load model
def load_model(post, residual, i, seg_model=None):
    """Build a ModelAdapter with an AugResDAE on layer 'up3' and load its checkpoint.

    Args:
        post: Suffix of the checkpoint file name
            (``acdc_AugResDAE{i}_{post}_best.pt``).
        residual: Forwarded to ``AugResDAE(residual=...)``.
        i: Index that is inserted into the checkpoint file name.
        seg_model: Segmentation network to wrap. When omitted, the
            notebook-level variable ``unet`` is used, which is what this
            function did before the parameter existed.

    Returns:
        The ModelAdapter after loading the checkpoint, removing all hooks,
        attaching the inspection hook, switching to eval mode and calling
        ``freeze_seg_model()``.
    """
    if seg_model is None:
        # Backward-compatible default: fall back to whatever `unet` is bound to
        # at notebook level when the function is called.
        seg_model = unet

    # These layers get an identity transformation and are passed as disabled_ids.
    disabled_ids = ['shortcut0', 'shortcut1', 'shortcut2']
    DAEs = nn.ModuleDict({'up3': AugResDAE(in_channels = 64,
                                           in_dim      = 32,
                                           latent_dim  = 256,
                                           depth       = 3,
                                           block_size  = 4,
                                           residual    = residual)})
    for layer_id in disabled_ids:
        DAEs[layer_id] = nn.Identity()

    model = ModelAdapter(seg_model=seg_model,
                         transformations=DAEs,
                         disabled_ids=disabled_ids,
                         copy=True)
    model_path = f'../../pre-trained-tmp/trained_AEs/acdc_AugResDAE{i}_{post}_best.pt'
    state_dict = torch.load(model_path)['model_state_dict']
    model.load_state_dict(state_dict)
    # Remove training hooks, add the inspection hook that fills `model.inspect_data`
    model.remove_all_hooks()
    model.hook_inspect_transformation(model.transformations)
    # Put model in evaluation state
    model.eval()
    model.freeze_seg_model()

    return model


### test
# Smoke test: load the residual variant for checkpoint index 0 into the notebook-level `model`.
model = load_model('localAug_multiImgSingleView_res_balanced_same', True, 0)

In [8]:
### handle data from dataset before forward pass (i.e. extract correct slice and augment)
from nnunet.training.network_training.nnUNetTrainerV2 import nnUNetTrainerV2
import batchgenerators
from batchgenerators.transforms.local_transforms import *
from batchgenerators.dataloading.single_threaded_augmenter import SingleThreadedAugmenter
### get augmentor from training
def get_local_augmentor_from_nnUnetplan(nnUnet_prefix='../../../nnUNet/', fold=0, scale=200.):
    """Compose the local augmentations with selected transforms of the nnU-Net training pipeline.

    An ``nnUNetTrainerV2`` is initialised only to obtain its training generator;
    from that generator's transform list the transforms of the types listed in
    ``original_transforms`` are kept and appended to four local transforms.

    Args:
        nnUnet_prefix: Path prefix of the nnU-Net directory (plans, results and
            preprocessed data are resolved below it).
        fold: Fold handed to ``nnUNetTrainerV2``.
        scale: ``scale`` argument shared by the four local transforms.

    Returns:
        A batchgenerators ``Compose`` of local transforms followed by the
        retained training transforms.
    """
    # NOTE(review): the output folder refers to Task027_ACDC while plans and data
    # refer to Task500_ACDC -- confirm this mix is intended.
    pkl_file          = nnUnet_prefix + 'data/nnUNet_preprocessed/Task500_ACDC/nnUNetPlansv2.1_plans_2D.pkl'
    output_folder     = nnUnet_prefix + 'results/nnUnet/nnUNet/2d/Task027_ACDC/nnUNetTrainerV2__nnUNetPlansv2.1/'
    dataset_directory = nnUnet_prefix + 'data/nnUNet_preprocessed/Task500_ACDC'

    trainer = nnUNetTrainerV2(pkl_file, fold, output_folder, dataset_directory)
    trainer.initialize()

    train_loader = trainer.tr_gen

    # Transform types of the original training pipeline that are retained.
    original_transforms = (
        batchgenerators.transforms.resample_transforms.SimulateLowResolutionTransform,
        batchgenerators.transforms.noise_transforms.GaussianNoiseTransform,
        batchgenerators.transforms.utility_transforms.RemoveLabelTransform,
        batchgenerators.transforms.utility_transforms.RenameTransform,
        batchgenerators.transforms.utility_transforms.NumpyToTensor
    )

    local_transforms = [
        BrightnessGradientAdditiveTransform(scale=scale, max_strength=4, p_per_sample=0.2, p_per_channel=1),
        LocalGammaTransform(scale=scale, gamma=(2, 5), p_per_sample=0.2, p_per_channel=1),
        LocalSmoothingTransform(scale=scale, smoothing_strength=(0.5, 1), p_per_sample=0.2, p_per_channel=1),
        LocalContrastTransform(scale=scale, new_contrast=(1, 3), p_per_sample=0.2, p_per_channel=1),
    ]

    # Local transforms run first, then the retained training transforms in their original order.
    retained = [t for t in train_loader.transform.transforms if isinstance(t, original_transforms)]
    return batchgenerators.transforms.abstract_transforms.Compose(local_transforms + retained)

### built latent vis dataloader from SingleImageMultiView

class SingleImageMultiViewDataLoader(batchgenerators.dataloading.data_loader.SlimDataLoaderBase):
    """Batch generator that repeats one stored image `batch_size` times per batch."""

    def __init__(self, data: dict, batch_size: int):
        # The base class keeps `data` in self._data and `batch_size` in self.batch_size.
        super().__init__(data, batch_size)

    def generate_train_batch(self):
        """Return one batch dictionary built from the stored image.

        Keys:
            'data': the image as float32 numpy array, tiled `batch_size` times
                along a new leading axis.
            'seg': float32 zeros with the same shape as 'data'.
            'data_orig': the stored tensor with one extra leading dimension,
                kept under a separate key alongside the tiled copy.
        """
        source = self._data
        views = np.tile(source.numpy().astype(np.float32), (self.batch_size, 1, 1, 1))
        return {
            'data': views,
            'seg': np.zeros_like(views, dtype=np.float32),
            'data_orig': source.unsqueeze(0),
        }

def get_dataloader_for_dataset_and_id(dataset, sample_id, batch_size=1):
    """Create a SingleImageMultiViewDataLoader for the image behind one sample.

    The sample's 'filepath' points to a PNG; the tensor that is loaded is the
    file of the same name with '.pt' in place of '.png'.

    Args:
        dataset: Object that returns a sample when indexed with `sample_id`.
        sample_id: Key of the sample in `dataset`.
        batch_size: Batch size handed to the data loader.

    Returns:
        The data loader wrapping the loaded tensor.
    """
    sample = dataset[sample_id]
    # Only the file path is needed; the former unused 'sample_info' lookup is
    # gone so that samples without that field can be used as well.
    img_path = sample['filepath'].replace('.png', '.pt')
    x_in = torch.load(img_path)

    return SingleImageMultiViewDataLoader(x_in, batch_size)

### test
# augmentor  = get_local_augmentor_from_nnUnetplan()
# dataloader = get_dataloader_for_dataset_and_id(dataset_show, '65313dd6718acf6628df3bc4', batch_size=1)
# generator  = SingleThreadedAugmenter(dataloader, augmentor)
# batch      = next(generator)
# print(batch['data'].shape)

In [43]:
### get activations
def get_activations(
    dataset, 
    sample_id, 
    model='localAug_multiImgSingleView_res_balanced_same', 
    residual=True, 
    iteration=0
): 
    """Collect 'up3' activations for one sample and a locally augmented view of it.

    Args:
        dataset: Dataset handed to `get_dataloader_for_dataset_and_id`.
        sample_id: Id of the sample whose image is used.
        model: Checkpoint suffix handed to `load_model` as `post`.
        residual: Handed to `load_model`.
        iteration: Checkpoint index handed to `load_model`.

    Returns:
        The seven maps below, concatenated along dim 0, normalised, clamped to
        [0, 1] and finally returned with dims 0 and 1 swapped:
        0 input activation, 1 input activation of the augmented view,
        2 'residuals' entry, 3 'residuals' entry of the augmented view,
        4 squared difference between 0 and 1,
        5 squared difference between 0 and its denoised version,
        6 squared difference between 0 and the denoised augmented version.
        Assumes each activation has a leading dimension of size 1 so that the
        row indices used for scaling line up with this list -- TODO confirm.
    """
    net = load_model(model, residual, iteration)

    augmentor  = get_local_augmentor_from_nnUnetplan()
    dataloader = get_dataloader_for_dataset_and_id(dataset, sample_id, batch_size=1)
    generator  = SingleThreadedAugmenter(dataloader, augmentor)
    batch      = next(generator)

    # Augmented view. Everything is read from `inspect_data` right after the
    # forward pass that produced it, so that the second pass below cannot
    # overwrite an entry before it has been picked up.
    net(batch['data'])
    activations_augmented       = net.inspect_data['up3']
    original_act_augmented      = activations_augmented['input']
    denoised_act_augmented      = activations_augmented['denoised']
    residual_original_augmented = activations_augmented['residuals']

    # Original view.
    net(batch['data_orig'])
    activations       = net.inspect_data['up3']
    original_act      = activations['input']
    denoised_act      = activations['denoised']
    residual_original = activations['residuals']

    recon_residual_original = (original_act - denoised_act) ** 2
    recon_residual_original_augmented = (original_act - denoised_act_augmented) ** 2
    change_act = (original_act - original_act_augmented) ** 2

    data = torch.cat([original_act, 
                      original_act_augmented, 
                      residual_original, 
                      residual_original_augmented, 
                      change_act,
                      recon_residual_original, 
                      recon_residual_original_augmented], dim=0)

    # Global value range, taken before the per-row minimum is subtracted.
    scaling_factor = data.max() - data.min()
    data -= torch.flatten(data, start_dim=1, end_dim=3).min(1).values.view(-1, 1, 1, 1)
    data /= scaling_factor

    # Amplify the residual rows (x10) and the squared-difference rows (x100)
    # before clamping to the displayable range.
    data[2:4] *= 10
    data[4:] *= 100
    data = data.clamp(0,1)

    return data.transpose(0,1)

def temp_save_activation_data(data, path, img_names):
    """Write every map in `data` to `path` as PNG, plus an all-ones 'background.png'.

    For the i-th entry of `data`, its maps are paired with `img_names` in order
    and saved as '<name>_<i>_test.png'. The background image has the size of the
    last two dimensions of `data`.
    """
    named_maps = (
        (f'{name}_{idx}_test.png', img)
        for idx, channel in enumerate(data)
        for name, img in zip(img_names, channel)
    )
    for filename, img in named_maps:
        save_image(img, path + filename)

    height_width = data.shape[-2:]
    save_image(torch.ones(height_width), path + 'background.png')
## testing
#data = get_activations(dataset, '65315041bc3ffd5e948dbabb')

In [None]:
### Make dataset for fiftyone
# Build the grouped dataset "activation_dataset_test": one group per entry of
# `data`, one group element per activation map type.
act_dataset_name = "activation_dataset_test"
# Remove a stale dataset of the same name; checking the database also covers
# datasets that are no longer referenced by a Python variable.
if fo.dataset_exists(act_dataset_name):
    fo.delete_dataset(act_dataset_name)
# init dataset
act_dataset = fo.Dataset(name=act_dataset_name)
# add group field; the plain activation is shown by default
act_dataset.add_group_field("group", default="activation")
# make temporary dir for data handling
os.makedirs('tmp', exist_ok=True)
path = 'tmp/'
# init samples
samples = []

data = get_activations(dataset, '65326e3811ebb0636b055ed3')
# Names in the order in which `get_activations` concatenates its maps.
img_names = [
    'activation',
    'augmented_activation',
    'residual_activation',
    'residual_augmented_activation',
    'activation_difference',
    'reconstruction_residual_activation',
    'reconstruction_residual_augmented_activation'
]
# Write the PNGs the samples below point to. `data` has to be passed as the
# first argument (it was missing before, which raised a TypeError).
temp_save_activation_data(data, path, img_names)

for i, channel in enumerate(data):
    # make group for channel
    group = fo.Group()

    for name, img in zip(img_names, channel):
        # Same file name pattern as used by temp_save_activation_data.
        img_path = path + f'{name}_{i}_test.png'

        sample_image = fo.Sample(
            filepath=img_path,
            group=group.element(f'{name}'),
        )

        samples += [sample_image]

act_dataset.add_samples(samples)
print()

In [None]:
# List the names of the FiftyOne datasets that currently exist.
fo.list_datasets()

In [None]:
# Open the FiftyOne App for the activation dataset on port 8100 (same port as the other sessions in this notebook).
session = fo.launch_app(act_dataset, port=8100)

In [None]:
# Clean up both FiftyOne datasets. Each deletion is attempted independently;
# the actual reason for a failure (e.g. the variable was never defined) is
# reported instead of being swallowed by a bare `except`.
try:
    dataset.delete()
except Exception as err:
    print(f"could not delete `dataset`: {err!r}")


### Make dataset for fiftyone
try:
    act_dataset.delete()
except Exception as err:
    print(f"could not delete `act_dataset`: {err!r}")

In [None]:
# Print minimum and maximum of each activation tensor.
# NOTE(review): none of the four names is assigned at notebook level in the
# visible cells -- they only occur as local variables inside the
# get_activations* functions. Confirm where they are expected to come from.
for tmp in [original_act, original_act_augmented, residual_original, residual_original_augmented]:
    print(tmp.min(), tmp.max())

# Visualizing the effect of local augmentations

In [10]:
## Initialize trainer to get data loaders with data augmentations from training
# This cell repeats the steps of get_local_augmentor_from_nnUnetplan() at
# notebook level and additionally builds `train_gen`, the augmented batch
# generator consumed by the local-augmentation dataset cell.
from batchgenerators.dataloading.multi_threaded_augmenter import MultiThreadedAugmenter
from augment import MultiImageSingleViewDataLoader

nnUnet_prefix = '../../../nnUNet/'
pkl_file          = nnUnet_prefix + 'data/nnUNet_preprocessed/Task500_ACDC/nnUNetPlansv2.1_plans_2D.pkl'
# NOTE: `fold` is assigned but the trainer below receives the literal 0.
fold              = 0
# NOTE(review): output folder refers to Task027_ACDC, plans and data to Task500_ACDC -- confirm.
output_folder     = nnUnet_prefix + 'results/nnUnet/nnUNet/2d/Task027_ACDC/nnUNetTrainerV2__nnUNetPlansv2.1/'
dataset_directory = nnUnet_prefix + 'data/nnUNet_preprocessed/Task500_ACDC'

trainer = nnUNetTrainerV2(pkl_file, 0, output_folder, dataset_directory)
trainer.initialize()

# Only the transform list of the training generator is used below.
train_loader = trainer.tr_gen


from batchgenerators.transforms.local_transforms import *
from batchgenerators.dataloading.single_threaded_augmenter import SingleThreadedAugmenter

# Transform types of the original training pipeline that are retained.
original_transforms = (
    batchgenerators.transforms.resample_transforms.SimulateLowResolutionTransform,
    batchgenerators.transforms.noise_transforms.GaussianNoiseTransform,
    batchgenerators.transforms.utility_transforms.RemoveLabelTransform,
    batchgenerators.transforms.utility_transforms.RenameTransform,
    batchgenerators.transforms.utility_transforms.NumpyToTensor
)

# `scale` is shared by all four local transforms.
scale = 200.
local_transforms = [
    BrightnessGradientAdditiveTransform(scale=scale, max_strength=4, p_per_sample=0.2, p_per_channel=1),
    LocalGammaTransform(scale=scale, gamma=(2, 5), p_per_sample=0.2, p_per_channel=1),
    LocalSmoothingTransform(scale=scale, smoothing_strength=(0.5, 1), p_per_sample=0.2, p_per_channel=1),
    LocalContrastTransform(scale=scale, new_contrast=(1, 3), p_per_sample=0.2, p_per_channel=1),
]

# Local transforms first, then the retained training transforms in their original order.
train_transforms = local_transforms + [t for t in train_loader.transform.transforms if isinstance(t, original_transforms)]

train_augmentor = batchgenerators.transforms.abstract_transforms.Compose(train_transforms)
### - Load dataset and init batch generator
train_data = ACDCDataset(data='train', debug=False)
# `train_gen` is first the plain loader and is then rebound to the augmenter wrapping it.
train_gen = MultiImageSingleViewDataLoader(train_data, batch_size=1, return_orig=True)
train_gen = MultiThreadedAugmenter(train_gen, train_augmentor, 1, 1, seeds=None)

loading dataset
loading all case properties
2023-10-24 15:29:55.485117: Using splits from existing split file: ../../../nnUNet/data/nnUNet_preprocessed/Task500_ACDC/splits_final.pkl
2023-10-24 15:29:55.486391: The split file contains 5 splits.
2023-10-24 15:29:55.486510: Desired fold for training: 0
2023-10-24 15:29:55.487473: This split has 160 training and 40 validation cases.
unpacking dataset
done
loading dataset
loading all case properties


In [93]:
# Delete all non-persistent FiftyOne datasets; handles created in earlier cells may refer to deleted datasets afterwards.
fo.delete_non_persistent_datasets()

In [71]:
# what we need
# model (Unet, single is alright for this)
# train data loader from DAEs. Half original, half augmented
# visualize segmentation, ground by original and augmented.
# Maybe include augmentation pipeline here to be able to change it if necessary (e.g. increase magnitude)

# Build the FiftyOne dataset "local_aug_0": one group per drawn batch with the
# elements 'original' and 'augmented'. Each sample carries the ground truth,
# the plain U-Net prediction and the predictions of models[0] ('res_pred_ae')
# and models[1] ('rec_pred_ae').
local_aug_dataset_name = "local_aug_0"
# Remove a stale dataset of the same name; checking the database also covers
# datasets that are no longer referenced by a Python variable.
if fo.dataset_exists(local_aug_dataset_name):
    fo.delete_dataset(local_aug_dataset_name)


def _ae_segmentation(model, batch):
    """Run `model` on `batch` and wrap the arg-max over dim 0 of output[1] as fo.Segmentation."""
    model_output = model(batch)
    model_pred = torch.argmax(model_output[1], dim=0, keepdims=False)
    return fo.Segmentation(mask=model_pred.squeeze().numpy())


# build dataset
samples = []
n_unets = 1
export_dir = os.path.expanduser('~') + '/fiftyone/test-dataset'
# init dataset
local_aug_dataset = fo.Dataset(name=local_aug_dataset_name)
# add group field; the original image is shown by default
local_aug_dataset.add_group_field("group", default="original")

# set mask targets
## ground truth targets
local_aug_dataset.mask_targets = {
    "ground_truth": {0: "background",
                     1: "LV",
                     2: "MYO",
                     3: "RV"}
}

# make temporary dir for data handling
os.makedirs('tmp', exist_ok=True)
path = 'tmp/'

# iteratively make samples
for i in range(0, 50):
    # get data
    data = next(train_gen)
    img_original = data['data_orig']
    img_augmented = data['data']
    mask = data['target_orig']
    mask[mask < 0] = 0

    # Save a min-max normalised PNG for the App and the raw tensor next to it,
    # once for the original and once for the augmented image.
    img_path  = path + f'img_original_{i}.png'
    img_norm  = img_original - img_original.min()
    img_norm /= img_norm.max()
    save_image(img_norm, img_path)
    torch.save(img_original, path + f'img_original_{i}.pt')

    img_augmented_path  = path + f'img_augmented_{i}.png'
    img_augmented_norm  = img_augmented - img_augmented.min()
    img_augmented_norm /= img_augmented_norm.max()
    save_image(img_augmented_norm, img_augmented_path)
    torch.save(img_augmented, path + f'img_augmented_{i}.pt')

    # init group for slice
    group = fo.Group()

    sample_original = fo.Sample(
        filepath=img_path,
        group=group.element('original'),
        ground_truth=fo.Segmentation(
            mask=mask.squeeze().numpy()
        )
    )

    sample_augmented = fo.Sample(
        filepath=img_augmented_path,
        group=group.element('augmented'),
        ground_truth=fo.Segmentation(
            mask=mask.squeeze().numpy()
        )
    )

    # UNet predictions: both images go through the first U-Net in one batch,
    # row 0 is the original and row 1 the augmented image.
    input_batch = torch.cat([img_original, img_augmented], dim=0)
    unet_output = unets[0](input_batch)
    pred = torch.argmax(unet_output, dim=1, keepdims=True)
    sample_original['pred_unet'] = fo.Segmentation(mask=pred[0].squeeze().numpy())
    sample_augmented['pred_unet'] = fo.Segmentation(mask=pred[1].squeeze().numpy())

    # UNet predictions with feature resampling: models[0] first, then models[1],
    # each on the original and then on the augmented image.
    for field, model_idx in (('res_pred_ae', 0), ('rec_pred_ae', 1)):
        sample_original[field] = _ae_segmentation(models[model_idx], img_original)
        sample_augmented[field] = _ae_segmentation(models[model_idx], img_augmented)

    # add samples to sample list
    samples += [sample_original, sample_augmented]


# add samples to dataset
local_aug_dataset.add_samples(samples)
print()



 100% |█████████████████| 100/100 [1.6s elapsed, 0s remaining, 63.6 samples/s]         


INFO:eta.core.utils: 100% |█████████████████| 100/100 [1.6s elapsed, 0s remaining, 63.6 samples/s]         





In [84]:
# Open the FiftyOne App for the local-augmentation dataset on port 8100.
session = fo.launch_app(local_aug_dataset, port=8100)

# Visualizing activations for a pair of augmented / non-augmented images

In [51]:
def get_data_from_dataset_and_ids(dataset, sample_ids):
    """Load the saved tensors that belong to the given samples.

    For every id, the sample's 'filepath' is taken, '.png' is replaced by '.pt'
    and the resulting file is loaded with `torch.load`.

    Args:
        dataset: Object that returns a sample when indexed with a sample id.
        sample_ids: Iterable of sample ids.

    Returns:
        List with one loaded object per id, in the order of `sample_ids`.
    """
    tensor_paths = (
        dataset[sample_id]['filepath'].replace('.png', '.pt')
        for sample_id in sample_ids
    )
    return [torch.load(tensor_path) for tensor_path in tensor_paths]


# def get_data_from_dataset_and_ids(dataset, sample_id, batch_size=1):
#     sample = dataset[sample_id]
#     img_path = sample['filepath'].replace('.png', '.pt')
#     return torch.load(img_path)
        

### get activations
def get_activations_without_additional_augmentations(
    dataset,
    sample_ids,
    model='localAug_multiImgSingleView_res_balanced_same',
    residual=True,
    iteration=0
):
    """Gather 'up3' activation maps for an original / augmented sample pair.

    The model is built with ``load_model(model, residual, iteration)`` and run
    once on each of the first two samples loaded for ``sample_ids``. After each
    forward pass the entries ``'input'``, ``'denoised'`` and ``'residuals'``
    are read from ``inspect_data['up3']`` of the model.

    Args:
        dataset: Dataset handed to ``get_data_from_dataset_and_ids``.
        sample_ids: Sample ids; index 0 is used as the original input and
            index 1 as its augmented counterpart. Further ids are loaded but
            not used.
        model: Model identifier forwarded to ``load_model``.
        residual: Forwarded to ``load_model``.
        iteration: Forwarded to ``load_model``.

    Returns:
        Tensor obtained by concatenating along dim 0, in this order: original
        input activation, augmented input activation, their difference,
        residuals of the original pass, residuals of the augmented pass, and
        original input activation minus the denoised augmented activation.
        The result is clamped to [-10, 10] and dims 0 and 1 are swapped.
        NOTE(review): in the recorded run each of the six maps printed as
        [64, 32, 32]; shapes are not validated here.

    Raises:
        ValueError: If fewer than two samples were loaded.
    """
    net = load_model(model, residual, iteration)
    samples = get_data_from_dataset_and_ids(dataset, sample_ids)
    if len(samples) < 2:
        raise ValueError(
            "need two sample ids (original, augmented), "
            f"got {len(samples)}"
        )
    print(samples[0].shape)

    # Forward pass on the augmented sample. The output is discarded; the call
    # is made so that the model fills `inspect_data`.
    net(samples[1])
    activations_augmented = net.inspect_data['up3']
    original_act_augmented = activations_augmented['input']
    denoised_act_augmented = activations_augmented['denoised']
    # Read right after the pass that produced it, so the value cannot be
    # affected by the second forward pass below.
    residual_original_augmented = activations_augmented['residuals']

    # Forward pass on the original sample.
    net(samples[0])
    activations = net.inspect_data['up3']
    original_act = activations['input']
    residual_original = activations['residuals']

    # Derived maps: reconstruction residual w.r.t. the denoised augmented
    # activation, and the change the augmentation causes in the activation.
    recon_residual_original_augmented = original_act - denoised_act_augmented
    change_act = original_act - original_act_augmented

    data = torch.cat([original_act,
                      original_act_augmented,
                      change_act,
                      residual_original,
                      residual_original_augmented,
                      recon_residual_original_augmented], dim=0)

    # Report shape and value range of every stacked map.
    for t in data:
        print(t.shape, t.min(), t.max())

    # Limit the value range so single outliers do not dominate the heatmaps.
    data = data.clamp(-10, 10)

    return data.transpose(0, 1)


In [82]:
### Make dataset for fiftyone
# Builds a grouped FiftyOne dataset from the activation maps of the residual
# and the reconstruction model: one group per entry along dim 0 of `data`,
# one sample per map type in `img_names`.
act_dataset_name = "activation_dataset_test"
# Look the dataset up by name instead of calling `act_dataset2.delete()` inside
# a bare `except`: this also works on a fresh kernel, where the variable does
# not exist yet, and it no longer hides unrelated errors.
if fo.dataset_exists(act_dataset_name):
    fo.delete_dataset(act_dataset_name)
# init dataset
act_dataset2 = fo.Dataset(name=act_dataset_name)
# add group and all Groups we need
act_dataset2.add_group_field("group", default="activation")
# make temporary dir for data handling
os.makedirs('tmp', exist_ok=True)
path = 'tmp/'
# init samples
samples = []

# Sample ids of the (original, augmented) pair in `local_aug_dataset`.
# NOTE(review): presumably these ids change whenever `local_aug_dataset` is
# rebuilt and then have to be updated by hand - confirm.
ids = [
    '6538e8c4324eda9bed423a06',
    '6538e8c4324eda9bed423a07'
]

data_res = get_activations_without_additional_augmentations(
    local_aug_dataset, ids,
    model='localAug_multiImgSingleView_res_balanced_same',
    residual=True, iteration=0)
data_rec = get_activations_without_additional_augmentations(
    local_aug_dataset, ids,
    model='localAug_multiImgSingleView_recon_balanced_same',
    residual=False, iteration=0)
# New axis at dim 2: index 0 holds the residual model, index 1 the
# reconstruction model.
data = torch.stack([data_res, data_rec], dim=2)

# One name per map returned by get_activations_without_additional_augmentations,
# in the same order.
img_names = [
    'activation',
    'augmented_activation',
    'activation_difference',
    'residual_activation',
    'residual_augmented_activation',
    'reconstruction_residual_augmented_activation'
]
temp_save_activation_data(data, path, img_names)
# NOTE(review): 'background.png' is presumably written by
# temp_save_activation_data - confirm.
path_background = path + 'background.png'

for channel in data:
    # make group for channel
    group = fo.Group()

    for name, img in zip(img_names, channel):
        # Index order follows the torch.stack call above. The labels used to
        # be swapped (index 0 was stored under 'rec_*', index 1 under 'res_*').
        res_map = img[0].numpy()
        rec_map = img[1].numpy()

        sample_image = fo.Sample(
            filepath=path_background,
            group=group.element(f'{name}'),
        )

        # Each map is stored twice, once as is and once negated, as separate
        # '*_pos' / '*_neg' heatmap fields.
        sample_image['rec_heatmap_pos'] = fo.Heatmap(map=rec_map)
        sample_image['rec_heatmap_neg'] = fo.Heatmap(map=-rec_map)
        sample_image['res_heatmap_pos'] = fo.Heatmap(map=res_map)
        sample_image['res_heatmap_neg'] = fo.Heatmap(map=-res_map)

        samples.append(sample_image)


act_dataset2.add_samples(samples)
print()

torch.Size([1, 1, 256, 256])
torch.Size([64, 32, 32]) tensor(-6.6202) tensor(10.5962)
torch.Size([64, 32, 32]) tensor(-6.2948) tensor(9.2680)
torch.Size([64, 32, 32]) tensor(-5.4037) tensor(6.6280)
torch.Size([64, 32, 32]) tensor(-0.2675) tensor(0.3243)
torch.Size([64, 32, 32]) tensor(-4.7259) tensor(6.6208)
torch.Size([64, 32, 32]) tensor(-3.6312) tensor(4.2636)
torch.Size([1, 1, 256, 256])
torch.Size([64, 32, 32]) tensor(-6.6202) tensor(10.5962)
torch.Size([64, 32, 32]) tensor(-6.2948) tensor(9.2680)
torch.Size([64, 32, 32]) tensor(-5.4037) tensor(6.6280)
torch.Size([64, 32, 32]) tensor(-1.0151) tensor(0.7926)
torch.Size([64, 32, 32]) tensor(-4.7242) tensor(5.6957)
torch.Size([64, 32, 32]) tensor(-3.5219) tensor(3.7527)
 100% |█████████████████| 384/384 [1.2s elapsed, 0s remaining, 322.2 samples/s]         


INFO:eta.core.utils: 100% |█████████████████| 384/384 [1.2s elapsed, 0s remaining, 322.2 samples/s]         





In [85]:
# Open the FiftyOne app on the activation dataset `act_dataset2`, served on port 8099.
session = fo.launch_app(act_dataset2, port=8099)

In [80]:
# Close the FiftyOne app. The recorded output below shows the session
# repeatedly failing to reconnect after this call.
fo.close_app()


Could not connect session, trying again in 10 seconds


Could not connect session, trying again in 10 seconds



In [107]:
# what we need
# model (Unet, single is alright for this)
# train data loader from DAEs. Half original, half augmented
# visualize segmentation, ground by original and augmented.
# Maybe include augmentation pipeline here to be able to change it if necessary (e.g. increase magnitude)

# Builds a FiftyOne dataset from 50 items of `mnm_a` with ground truth,
# U-Net prediction / error map, and prediction + uncertainty map of the two
# models in `models`. Only the original (non-augmented) image is visualised.

# first fiftyone stuff
test_dataset_name = "test_set_A"
# Look the dataset up by name instead of calling `test_dataset.delete()` inside
# a bare `except`: this also works on a fresh kernel, where the variable does
# not exist yet, and it no longer hides unrelated errors.
if fo.dataset_exists(test_dataset_name):
    fo.delete_dataset(test_dataset_name)


# build dataset
samples = []
# init dataset
test_dataset = fo.Dataset(name=test_dataset_name)
# add group and all Groups we need
test_dataset.add_group_field("group", default="original")

# set mask targets
## ground truth targets
test_dataset.mask_targets = {
    "ground_truth": {0: "background",
                     1: "LV",
                     2: "MYO",
                     3: "RV"}
}


# make temporary dir for data handling
os.makedirs('tmp', exist_ok=True)
path = 'tmp/'

# init
umap_generator_AE = UMapGenerator(method='ae', net_out='mms')

# iteratively make samples from every tenth item of the dataset
for i in range(0, 50):
    # get data
    data = mnm_a[i * 10]
    img_original = data['input'].unsqueeze(0)
    # Negative label values become 0. `clamp` returns a new tensor; the former
    # in-place assignment wrote through the `unsqueeze` view into data['target'].
    mask = data['target'].unsqueeze(0).clamp(min=0)

    # save images to disk: PNG scaled to [0, 1] for display, raw tensor as .pt
    img_path = path + f'img_testA_{i}.png'
    img_norm = img_original - img_original.min()
    peak = img_norm.max()
    # A constant image has peak 0; skip the division instead of producing NaNs.
    if peak > 0:
        img_norm = img_norm / peak
    save_image(img_norm, img_path)
    torch.save(img_original, path + f'img_testA_{i}.pt')

    # init group for slice
    group = fo.Group()

    sample_original = fo.Sample(
        filepath=img_path,
        group=group.element('original'),
        ground_truth=fo.Segmentation(
            mask=mask.squeeze().numpy()
        )
    )

    # UNet predictions
    # NOTE(review): uses unets[1]; confirm this is the U-Net that the entries
    # of `models` are built on.
    unet_output = unets[1](img_original)
    pred = torch.argmax(unet_output, dim=1, keepdims=True)
    err_map = (pred != mask)
    sample_original['pred_unet'] = fo.Segmentation(mask=pred[0].squeeze().numpy())
    sample_original['error_unet'] = fo.Segmentation(mask=err_map.squeeze().numpy())

    # UNet predictions with feature resampling: models[0] is stored under the
    # 'res' prefix, models[1] under the 'rec' prefix.
    for prefix, dae_model in (('res', models[0]), ('rec', models[1])):
        dae_output = dae_model(img_original)
        # Entry 1 of the model output is used for the prediction; argmax runs
        # over its first dim.
        dae_pred = torch.argmax(dae_output[1], dim=0, keepdims=False)
        sample_original[f'{prefix}_pred_ae'] = fo.Segmentation(mask=dae_pred.squeeze().numpy())
        dae_umap = umap_generator_AE(dae_output)
        sample_original[f'{prefix}_umap_ae'] = fo.Heatmap(map=dae_umap.squeeze().numpy())

    # add samples to sample list
    samples += [sample_original]

# add samples to dataset
test_dataset.add_samples(samples)
print()



 100% |███████████████████| 50/50 [1.7s elapsed, 0s remaining, 30.2 samples/s]         


INFO:eta.core.utils: 100% |███████████████████| 50/50 [1.7s elapsed, 0s remaining, 30.2 samples/s]         





In [100]:
# Delete all non-persistent FiftyOne datasets.
# NOTE(review): the datasets created in this notebook never set `persistent`,
# so presumably all of them are removed here - re-run the dataset cells before
# launching the app again.
fo.delete_non_persistent_datasets()

In [108]:
# Open the FiftyOne app on `test_dataset`, served on port 8098.
session = fo.launch_app(test_dataset, port=8098)


Could not connect session, trying again in 10 seconds



In [92]:
# Shape of the 'input' entry of the most recently loaded `data` item.
data['input'].shape

torch.Size([1, 256, 256])

In [None]:
# Check the batch dimension: draw one batch from the training generator and
# pull out the original images, the augmented images and the original targets.
data = next(train_gen)
img, img_augmented = data['data_orig'], data['data']
mask = data['target_orig']
# Negative label values are set to 0 (in place).
mask[mask < 0] = 0

In [None]:
# Run the first U-Net on original and augmented images in one forward pass.
input_batch = torch.cat([data['data_orig'], data['data']], dim=0)
unet_output = unets[0](input_batch)
pred = torch.argmax(unet_output, dim=1, keepdims=True)
# `input_batch` stacks originals and augmentations along dim 0, so `pred`
# covers twice as many images as `mask`. Repeat the mask to match: plain
# broadcasting only works for a batch size of one and fails otherwise.
# NOTE(review): assumes the augmentations leave the segmentation unchanged, so
# the original targets are valid for both halves - confirm.
err_map = (pred != torch.cat([mask, mask], dim=0))
input_batch.shape

In [None]:
# Shape of the U-Net error map `err_map`.
err_map.shape

In [None]:
# original_transforms = (
#     batchgenerators.transforms.resample_transforms.SimulateLowResolutionTransform,
#     batchgenerators.transforms.noise_transforms.GaussianNoiseTransform,
#     batchgenerators.transforms.utility_transforms.RemoveLabelTransform,
#     batchgenerators.transforms.utility_transforms.RenameTransform,
#     batchgenerators.transforms.utility_transforms.NumpyToTensor
# )

# scale = 200.
# local_transforms = [
#     BrightnessGradientAdditiveTransform(scale=scale, max_strength=4, p_per_sample=0.2, p_per_channel=1),
#     LocalGammaTransform(scale=scale, gamma=(2, 5), p_per_sample=0.2, p_per_channel=1),
#     LocalSmoothingTransform(scale=scale, smoothing_strength=(0.5, 1), p_per_sample=0.2, p_per_channel=1),
#     LocalContrastTransform(scale=scale, new_contrast=(1, 3), p_per_sample=0.2, p_per_channel=1),
# ]

# transforms = local_transforms + [t for t in train_loader.transform.transforms if isinstance(t, original_transforms)]
# augmentor = batchgenerators.transforms.abstract_transforms.Compose(train_transforms)




# ### - Load dataset and init batch generator
# train_data = ACDCDataset(data='train', debug=False)
# valid_data = ACDCDataset(data='val', debug=False)

# train_gen = MultiImageSingleViewDataLoader(train_data, batch_size=cfg['batch_size'], return_orig=True)
# #train_gen = SingleThreadedAugmenter(train_gen, train_augmentor)
# train_gen = MultiThreadedAugmenter(train_gen, train_augmentor, 4, 2, seeds=None)
# valid_gen = MultiImageSingleViewDataLoader(valid_data, batch_size=cfg['batch_size'], return_orig=True)
# #valid_gen = SingleThreadedAugmenter(valid_gen, valid_augmentor)
# valid_gen = MultiThreadedAugmenter(valid_gen, valid_augmentor, 4, 2, seeds=None)

In [None]:
# import batchgenerators

# # define single image dataloader from batchgenerator example here:
# # https://github.com/MIC-DKFZ/batchgenerators/blob/master/batchgenerators/examples/example_ipynb.ipynb
# class SingleImageMultiViewDataLoader(batchgenerators.dataloading.data_loader.SlimDataLoaderBase):
#     def __init__(self, data: ACDCDataset, batch_size: int = 1, return_orig: str = False):
#         super(SingleImageMultiViewDataLoader, self).__init__(data, batch_size)
#         # data is now stored in self._data.
#         self.return_orig = return_orig
    
#     def generate_train_batch(self):
#         #data = self._data[randrange(len(self._data))]
#         img = data['input'].numpy().astype(np.float32)
#         tar = data['target'][0].numpy().astype(np.float32)
        
#         #img_batched = np.tile(img, (self.batch_size, 1, 1, 1))
#         #tar_batched = np.tile(tar, (self.batch_size, 1, 1, 1))
#         # now construct the dictionary and return it. np.float32 cast because most networks take float
#         out = {'data': img, 
#                'seg':  img}
        
#         # if the original data is also needed, activate this flag to store it where augmentations
#         # cant find it.
#         if self.return_orig:
#             out['data_orig']   = data['input'].unsqueeze(0)
#             out['target_orig'] = data['target'].unsqueeze(0)
        
#         return out

# def build_dataloader():
#     pass



# def get_data_from_id_for_actvis(sample_id, dataset, debug):
#     sample = dataset[sample_id]
#     # sample information
#     sample_info = sample['sample_info']
#     # data sample
#     img_path = sample['filepath'].replace('.png', '.pt')
#     x_in = torch.load(img_path)
#     # load models
#     model = load_model(
#         post = 'localAug_multiImgSingleView_res_balanced_same',
#         residual = True,
#         i = sample_info['unet']
#     )
#     # get activations
    
#     activation = None
    
    
# def build_51dataset_from_data():
#     pass
    
# def start_51_activation_vis():
#     pass

In [None]:
# model = load_model('localAug_multiImgSingleView_res_balanced_same', True, 0)

In [None]:
# post = 'localAug_multiImgSingleView_res_balanced_same'
# disabled_ids = ['shortcut0', 'shortcut1', 'shortcut2']
# models = []
# DAEs = nn.ModuleDict({'up3': AugResDAE(in_channels = 64, 
#                                     in_dim      = 32,
#                                     latent_dim  = 256,
#                                     depth       = 3,
#                                     block_size  = 4)})


# for layer_id in disabled_ids:
#     DAEs[layer_id] = nn.Identity()

# model = Frankenstein(seg_model=unet,
#                      transformations=DAEs,
#                      disabled_ids=disabled_ids,
#                      copy=True)
# model_path = f'{ROOT}pre-trained-tmp/trained_AEs/acdc_AugResDAE{i}_{post}_best.pt'
# #model_path = f'{ROOT}pre-trained-tmp/trained_AEs/{pre}_resDAE{i}_{post}_best.pt'
# #model_path = f'{ROOT}pre-trained-tmp/trained_AEs/acdc_epinet_CE-only_prior-1_best.pt'localAug_multiImgSingleView_res
# #model_path = f'{ROOT}pre-trained-tmp/trained_AEs/acdc_resDAE0_venus_best.pt'
# state_dict = torch.load(model_path)['model_state_dict']
# model.load_state_dict(state_dict)
# # Remove trainiung hooks, add evaluation hooks
# model.remove_all_hooks()        
# model.hook_inference_transformations(model.transformations,
#                            n_samples=1)
# # Put model in evaluation state
# model.eval()
# model.freeze_seg_model()

In [None]:
# # try:
# #     dataset.delete()
# # except:
# #     print("dataset name already available or dataset didnt exist")


# # build first dataset
# # init dataset
# dataset = fo.Dataset(name="activation_dataset_test")
# # add group and all Groups we need
# dataset.add_group_field("group", default="it_0")

# # set mask targets
# ## ground truth targets
# # dataset.mask_targets = {
# #     "ground_truth": {0: "background",
# #                      1: "LV",
# #                      2: "MYO",
# #                      3: "RV"}
# # }
# # error map labels
# # for i in range(n_unets):
# #     dataset.mask_targets[f'errormap_it:{i}'] = {1: 'error'}
    
# # make temporary dir for data handling
# os.makedirs('tmp', exist_ok=True)
# path = 'tmp/'
# # init sample list. We save each sample here and add it to the
# # dataset in the end
# samples = []
# n_unets = 2
# # init dice score class
# dcs = Dice(num_classes=4, ignore_index=0)
# # init umap generator
# umap_generator_entropy = UMapGenerator(method='entropy', net_out='mms')
# umap_generator_AE = UMapGenerator(method='ae', net_out='mms')

# # itertatively make samples
# for i in range(10):
    # get data
#     data = mnm_a[i*10]
#     img = data['input']
#     mask = data['target']
#     mask[mask < 0] = 0
    
#     # save image to disk
#     img_path  = path + f'test_{i}.png'
#     img_norm  = img - img.min()
#     img_norm /= img_norm.max()

#     save_image(img_norm, img_path)

    
# #     sample_activation = fo.Sample(
# #         filepath=img_path,
# #         group=group.element('test'),
# #         ground_truth=fo.Segmentation(
# #             mask=mask.squeeze().numpy()
# #         )
# #     )
#     # init group for slice
#     group = fo.Group()
#     # predictions, error maps, umaps and foreground dice
#     DSC = torch.zeros((n_unets))
#     for i in range(n_unets):
#         # make sample
        
#         sample_image = fo.Sample(
#             filepath=img_path,
#             group=group.element(f'it_{i}'),
#             ground_truth=fo.Segmentation(
#                 mask=mask.squeeze().numpy()
#             )
#         )
#         # U-Net - Predictions and errormaps
#         unet_output = unets[i](img.unsqueeze(0))
#         pred = torch.argmax(unet_output, dim=1, keepdims=True)
#         err_map = (pred != mask)
#         sample_image[f'pred_unet'] = fo.Segmentation(mask=pred.squeeze().numpy())
#         sample_image[f'error_unet']  = fo.Segmentation(mask=err_map.squeeze().numpy())
#         # umaps
#         ## entropy
#         umap_entropy = umap_generator_entropy(unet_output)
#         sample_image[f'umap_entropy'] = fo.Heatmap(map=umap_entropy.squeeze().numpy())
#         ## segmentation distortion - predictions with AE and UQ map
#         model_output  = models[i](img.unsqueeze(0))
#         model_pred    = torch.argmax(model_output[1], dim=0, keepdims=False)
#         model_err_map = (model_pred != mask.squeeze())
#         umap_ae = umap_generator_AE(model_output)
#         sample_image[f'pred_ae'] = fo.Segmentation(mask=model_pred.squeeze().numpy())
#         sample_image[f'error_ae']  = fo.Segmentation(mask=model_err_map.squeeze().numpy())
#         sample_image[f'umap_ae'] = fo.Heatmap(map=umap_ae.squeeze().numpy())
        
#         # add sample to sample list
#         samples += [sample_image]
        
#         # dice score
#         DSC[i] = dcs(pred.squeeze(), mask.squeeze().int())        
    
#     # get dice related tags per sample
#     DSC = DSC.mean()
#     thresholds = [0.3, 0.5, 0.8]
#     for threshold in thresholds:
#         if DSC < threshold:
#             sample_image.tags.append(f'DSC_below_{int((threshold * 100))}')
        
# #     # add sample to sample list
# #     samples += [sample_image, sample_activation]
    
# # add samples to dataset
# dataset.add_samples(samples)


# # validate dataset
# print(dataset)

# # export dataset
# export_dir = os.path.expanduser('~') + '/fiftyone/test-dataset'
# dataset.export(export_dir=export_dir, dataset_type=fo.types.FiftyOneDataset)

# # clean up. remove tmp
# shutil.rmtree('tmp')
# #dataset.delete()