In [125]:
import matplotlib.pyplot as plt
import os, sys
from typing import Iterable, Dict, List, Callable, Tuple, Union, List

import numpy as np
import torch
from torch import nn, Tensor
from torch.utils.data import DataLoader
import torch.nn.functional as F
import matplotlib.pyplot as plt

sys.path.append('../')
from dataset import ACDCDataset, MNMDataset
from model.unet import UNet2D
from model.ae import AE
from model.dae import resDAE, AugResDAE
from model.wrapper import Frankenstein, FrankensteinV2
from losses import DiceScoreMMS
from utils import  epoch_average, UMapGenerator

In [128]:
### - datasets
debug = False
loader = {}
vendor = 'A'

mnm_a = MNMDataset(vendor=vendor, debug=debug)
mnm_a_loader = DataLoader(mnm_a, batch_size=1, shuffle=False, drop_last=False)


In [11]:
### - init unets
# U-Nets
ROOT = '../../'
middle = 'unet8_'
pre = 'acdc'
unet_names = [f'{pre}_{middle}{i}' for i in range(10)]
unets = []
for name in unet_names:
    model_path = f'{ROOT}pre-trained-tmp/trained_UNets/{name}_best.pt'
    state_dict = torch.load(model_path)['model_state_dict']
    n_chans_out = 4
    unet = UNet2D(n_chans_in=1, 
                  n_chans_out=n_chans_out, 
                  n_filters_init=8, 
                  dropout=False)
    unet.load_state_dict(state_dict)
    unets.append(unet)

In [82]:
# init models
post = 'localAug_multiImgSingleView_res_balanced_same'
disabled_ids = ['shortcut0', 'shortcut1', 'shortcut2']
models = []
for i, unet in enumerate(unets):
    DAEs = nn.ModuleDict({'up3': AugResDAE(in_channels = 64, 
                                        in_dim      = 32,
                                        latent_dim  = 256,
                                        depth       = 3,
                                        block_size  = 4)})


    for layer_id in disabled_ids:
        DAEs[layer_id] = nn.Identity()
    
    model = Frankenstein(seg_model=unet,
                         transformations=DAEs,
                         disabled_ids=disabled_ids,
                         copy=True)
    model_path = f'{ROOT}pre-trained-tmp/trained_AEs/acdc_AugResDAE{i}_{post}_best.pt'
    #model_path = f'{ROOT}pre-trained-tmp/trained_AEs/{pre}_resDAE{i}_{post}_best.pt'
    #model_path = f'{ROOT}pre-trained-tmp/trained_AEs/acdc_epinet_CE-only_prior-1_best.pt'localAug_multiImgSingleView_res
    #model_path = f'{ROOT}pre-trained-tmp/trained_AEs/acdc_resDAE0_venus_best.pt'
    state_dict = torch.load(model_path)['model_state_dict']
    model.load_state_dict(state_dict)
    # Remove trainiung hooks, add evaluation hooks
    model.remove_all_hooks()        
    model.hook_transformations(model.transformations,
                               n_samples=1)
    # Put model in evaluation state
    model.eval()
    model.freeze_seg_model()
    models.append(model)

In [51]:
from torchvision.utils import save_image
import fiftyone as fo
import numpy as np
import os
import shutil
from torchmetrics import Dice

In [58]:
dcs = Dice(num_classes=4)

In [60]:
dcs(pred.squeeze(), mask.squeeze().int())

tensor(0.9543)

In [117]:
# try:
#     dataset.delete()
# except:
#     print("dataset name already available or dataset didnt exist")


# # build first dataset
# # init dataset
# dataset = fo.Dataset(name="segmentation_dataset_test")
# # add group and all Groups we need
# dataset.add_group_field("group", default="image")

# # set mask targets
# ## ground truth targets
# dataset.mask_targets = {
#     "ground_truth": {0: "background",
#                      1: "LV",
#                      2: "MYO",
#                      3: "RV"}
# }
# # error map labels
# for i in range(n_unets):
#     dataset.mask_targets[f'errormap_it:{i}'] = {1: 'error'}
    
# # make temporary dir for data handling
# os.makedirs('tmp', exist_ok=True)
# path = 'tmp/'
# # init sample list. We save each sample here and add it to the
# # dataset in the end
# samples = []
# n_unets = 1
# # init dice score class
# dcs = Dice(num_classes=4, ignore_index=0)
# # init umap generator
# umap_generator_entropy = UMapGenerator(method='entropy', net_out='mms')
# umap_generator_AE = UMapGenerator(method='ae', net_out='mms')

# # itertatively make samples
# for i in range(10):
#     # get data
#     data = mnm_a[i*10]
#     img = data['input']
#     mask = data['target']
#     mask[mask < 0] = 0
    
#     # save image to disk
#     img_path  = path + f'test_{i}.png'
#     img_norm  = img - img.min()
#     img_norm /= img_norm.max()

#     save_image(img_norm, img_path)
#     # make sample
#     group = fo.Group()
#     sample_image = fo.Sample(
#         filepath=img_path,
#         group=group.element('image'),
#         ground_truth=fo.Segmentation(
#             mask=mask.squeeze().numpy()
#         )
#     )
    
#     sample_activation = fo.Sample(
#         filepath=img_path,
#         group=group.element('test'),
#         ground_truth=fo.Segmentation(
#             mask=mask.squeeze().numpy()
#         )
#     )
    
#     # predictions, error maps, umaps and foreground dice
#     DSC = torch.zeros((n_unets))
#     for i in range(n_unets):
#         # U-Net - Predictions and errormaps
#         unet_output = unets[i](img.unsqueeze(0))
#         pred = torch.argmax(unet_output, dim=1, keepdims=True)
#         err_map = (pred != mask)
#         sample_image[f'pred_unet_it:{i}'] = fo.Segmentation(mask=pred.squeeze().numpy())
#         sample_image[f'error_unet_it:{i}']  = fo.Segmentation(mask=err_map.squeeze().numpy())
#         # umaps
#         ## entropy
#         umap_entropy = umap_generator_entropy(unet_output)
#         sample_image[f'umap_entropy_it:{i}'] = fo.Heatmap(map=umap_entropy.squeeze().numpy())
#         ## segmentation distortion - predictions with AE and UQ map
#         model_output  = models[i](img.unsqueeze(0))
#         model_pred    = torch.argmax(model_output[1], dim=0, keepdims=False)
#         model_err_map = (model_pred != mask.squeeze())
#         umap_ae = umap_generator_AE(model_output)
#         sample_image[f'pred_ae_it:{i}'] = fo.Segmentation(mask=model_pred.squeeze().numpy())
#         sample_image[f'error_ae_it:{i}']  = fo.Segmentation(mask=model_err_map.squeeze().numpy())
#         sample_image[f'umap_ae_it:{i}'] = fo.Heatmap(map=umap_ae.squeeze().numpy())
        
#         # dice score
#         DSC[i] = dcs(pred.squeeze(), mask.squeeze().int())        
    
#     # get dice related tags per sample
#     DSC = DSC.mean()
#     thresholds = [0.3, 0.5, 0.8]
#     for threshold in thresholds:
#         if DSC < threshold:
#             sample_image.tags.append(f'DSC_below_{int((threshold * 100))}')
        
#     # add sample to sample list
#     samples += [sample_image, sample_activation]
    
# # add samples to dataset
# dataset.add_samples(samples)


# # validate dataset
# print(dataset)

# # export dataset
# export_dir = os.path.expanduser('~') + '/fiftyone/test-dataset'
# dataset.export(export_dir=export_dir, dataset_type=fo.types.FiftyOneDataset)

# # clean up. remove tmp
# shutil.rmtree('tmp')
# #dataset.delete()



 100% |███████████████████| 20/20 [570.3ms elapsed, 0s remaining, 35.1 samples/s]      
Name:        segmentation_dataset_test
Media type:  group
Group slice: image
Num groups:  10
Persistent:  False
Tags:        []
Sample fields:
    id:                fiftyone.core.fields.ObjectIdField
    filepath:          fiftyone.core.fields.StringField
    tags:              fiftyone.core.fields.ListField(fiftyone.core.fields.StringField)
    metadata:          fiftyone.core.fields.EmbeddedDocumentField(fiftyone.core.metadata.Metadata)
    group:             fiftyone.core.fields.EmbeddedDocumentField(fiftyone.core.groups.Group)
    ground_truth:      fiftyone.core.fields.EmbeddedDocumentField(fiftyone.core.labels.Segmentation)
    pred_unet_it:0:    fiftyone.core.fields.EmbeddedDocumentField(fiftyone.core.labels.Segmentation)
    error_unet_it:0:   fiftyone.core.fields.EmbeddedDocumentField(fiftyone.core.labels.Segmentation)
    umap_entropy_it:0: fiftyone.core.fields.EmbeddedDocumentField(fifty

In [154]:
try:
    dataset.delete()
except:
    print("dataset name already available or dataset didnt exist")


# build first dataset
# init dataset
dataset = fo.Dataset(name="segmentation_dataset_debug")
# add group and all Groups we need
dataset.add_group_field("group", default="it_0")

# set mask targets
## ground truth targets
dataset.mask_targets = {
    "ground_truth": {0: "background",
                     1: "LV",
                     2: "MYO",
                     3: "RV"}
}
# error map labels
for i in range(n_unets):
    dataset.mask_targets[f'errormap_it:{i}'] = {1: 'error'}
    
# make temporary dir for data handling
os.makedirs('tmp', exist_ok=True)
path = 'tmp/'
# init sample list. We save each sample here and add it to the
# dataset in the end
samples = []
n_unets = 2
# init dice score class
dcs = Dice(num_classes=4, ignore_index=0)
# init umap generator
umap_generator_entropy = UMapGenerator(method='entropy', net_out='mms')
umap_generator_AE = UMapGenerator(method='ae', net_out='mms')

# itertatively make samples
for i in range(10):
    # get data
    data = mnm_a[i*10]
    img = data['input']
    mask = data['target']
    mask[mask < 0] = 0
    
    # save image to disk
    img_path  = path + f'img_{i}.png'
    img_norm  = img - img.min()
    img_norm /= img_norm.max()

    save_image(img_norm, img_path)
    torch.save(img, path + f'img_{i}.pt')
    
#     sample_activation = fo.Sample(
#         filepath=img_path,
#         group=group.element('test'),
#         ground_truth=fo.Segmentation(
#             mask=mask.squeeze().numpy()
#         )
#     )
    # init group for slice
    group = fo.Group()
    # predictions, error maps, umaps and foreground dice
    DSC = torch.zeros((n_unets))
    for j in range(n_unets):
        # make sample
        sample_image = fo.Sample(
            filepath=img_path,
            group=group.element(f'it_{j}'),
            ground_truth=fo.Segmentation(
                mask=mask.squeeze().numpy()
            )
        )
        # save additional info to link id to image and all its masks
        sample_info = {
            'vendor': vendor,
            'slice': i,
            'unet': j
        }
        sample_image['sample_info'] = sample_info
#         if j == 0:
#             sample_image['input'] = img
        
        # U-Net - Predictions and errormaps
        unet_output = unets[j](img.unsqueeze(0))
        pred = torch.argmax(unet_output, dim=1, keepdims=True)
        err_map = (pred != mask)
        sample_image[f'pred_unet'] = fo.Segmentation(mask=pred.squeeze().numpy())
        sample_image[f'error_unet']  = fo.Segmentation(mask=err_map.squeeze().numpy())
        # umaps
        ## entropy
        umap_entropy = umap_generator_entropy(unet_output)
        sample_image[f'umap_entropy'] = fo.Heatmap(map=umap_entropy.squeeze().numpy())
        ## segmentation distortion - predictions with AE and UQ map
        model_output  = models[j](img.unsqueeze(0))
        model_pred    = torch.argmax(model_output[1], dim=0, keepdims=False)
        model_err_map = (model_pred != mask.squeeze())
        umap_ae = umap_generator_AE(model_output)
        sample_image[f'pred_ae']  = fo.Segmentation(mask=model_pred.squeeze().numpy())
        sample_image[f'error_ae'] = fo.Segmentation(mask=model_err_map.squeeze().numpy())
        sample_image[f'umap_ae']  = fo.Heatmap(map=umap_ae.squeeze().numpy())
        
        # dice scores
        sample_image['unet_dsc'] = dcs(pred.squeeze(), mask.squeeze().int()).numpy()
        sample_image['model_dsc'] = dcs(model_pred.squeeze(), mask.squeeze().int()).numpy()
        DSC[j] = dcs(pred.squeeze(), mask.squeeze().int())
        
        # add sample to sample list
        samples += [sample_image]
        
      
    
    # get dice related tags per sample
    DSC = DSC.mean()
    thresholds = [0.3, 0.5, 0.8]
    for threshold in thresholds:
        if DSC < threshold:
            sample_image.tags.append(f'DSC_below_{int((threshold * 100))}')
        
#     # add sample to sample list
#     samples += [sample_image, sample_activation]
    
# add samples to dataset
dataset.add_samples(samples)


# validate dataset
print(dataset)

# export dataset
export_dir = os.path.expanduser('~') + '/fiftyone/test-dataset'
dataset.export(export_dir=export_dir, dataset_type=fo.types.FiftyOneDataset)

# clean up. remove tmp
shutil.rmtree('tmp')
#dataset.delete()

dataset name already available or dataset didnt exist
 100% |███████████████████| 20/20 [848.7ms elapsed, 0s remaining, 23.6 samples/s]      
Name:        segmentation_dataset_debug
Media type:  group
Group slice: it_0
Num groups:  10
Persistent:  False
Tags:        []
Sample fields:
    id:           fiftyone.core.fields.ObjectIdField
    filepath:     fiftyone.core.fields.StringField
    tags:         fiftyone.core.fields.ListField(fiftyone.core.fields.StringField)
    metadata:     fiftyone.core.fields.EmbeddedDocumentField(fiftyone.core.metadata.Metadata)
    group:        fiftyone.core.fields.EmbeddedDocumentField(fiftyone.core.groups.Group)
    ground_truth: fiftyone.core.fields.EmbeddedDocumentField(fiftyone.core.labels.Segmentation)
    sample_info:  fiftyone.core.fields.DictField
    pred_unet:    fiftyone.core.fields.EmbeddedDocumentField(fiftyone.core.labels.Segmentation)
    error_unet:   fiftyone.core.fields.EmbeddedDocumentField(fiftyone.core.labels.Segmentation)
    umap

In [157]:
dataset_dir = "/home/lennartz/fiftyone/test-dataset"
dataset_type = fo.types.FiftyOneDataset
dataset_show = fo.Dataset.from_dir(dataset_dir, dataset_type=dataset_type)

session = fo.launch_app(dataset_show, port=8100)

Importing samples...
 100% |███████████████████| 20/20 [25.7ms elapsed, 0s remaining, 779.1 samples/s]    
Import complete



Could not connect session, trying again in 10 seconds

Could not connect session, trying again in 10 seconds



Could not connect session, trying again in 10 seconds


Could not connect session, trying again in 10 seconds


Could not connect session, trying again in 10 seconds


Could not connect session, trying again in 10 seconds


Could not connect session, trying again in 10 seconds


Could not connect session, trying again in 10 seconds


Could not connect session, trying again in 10 seconds


Could not connect session, trying again in 10 seconds


Could not connect session, trying again in 10 seconds


Could not connect session, trying again in 10 seconds


Could not connect session, trying again in 10 seconds


Could not connect session, trying again in 10 seconds


Could not connect session, trying again in 10 seconds


Could not connect session, trying again in 10 seconds


Could not connect session, trying again in 10 seconds


Could not connect session, trying again in 10 s

In [133]:
dataset['651e795cf009e13d026dd8ec']

<Sample: {
    'id': '651e795cf009e13d026dd8ec',
    'media_type': 'image',
    'filepath': '/home/lennartz/repos/Segmentation-Distortion/src/demos/tmp/test_6.png',
    'tags': [],
    'metadata': None,
    'group': <Group: {'id': '651e795af009e13d026dd8a3', 'name': 'it_1'}>,
    'ground_truth': <Segmentation: {
        'id': '651e795af009e13d026dd8ab',
        'tags': [],
        'mask': array([[0., 0., 0., ..., 0., 0., 0.],
               [0., 0., 0., ..., 0., 0., 0.],
               [0., 0., 0., ..., 0., 0., 0.],
               ...,
               [0., 0., 0., ..., 0., 0., 0.],
               [0., 0., 0., ..., 0., 0., 0.],
               [0., 0., 0., ..., 0., 0., 0.]], dtype=float32),
        'mask_path': None,
    }>,
    'sample_info': {'vendor': 'A', 'slice': 6, 'unet': 1},
    'pred_unet': <Segmentation: {
        'id': '651e795af009e13d026dd8ac',
        'tags': [],
        'mask': array([[0, 0, 0, ..., 0, 0, 0],
               [0, 0, 0, ..., 0, 0, 0],
               [0, 0, 0, 

In [None]:
original_transforms = (
    batchgenerators.transforms.resample_transforms.SimulateLowResolutionTransform,
    batchgenerators.transforms.noise_transforms.GaussianNoiseTransform,
    batchgenerators.transforms.utility_transforms.RemoveLabelTransform,
    batchgenerators.transforms.utility_transforms.RenameTransform,
    batchgenerators.transforms.utility_transforms.NumpyToTensor
)

scale = 200.
local_transforms = [
    BrightnessGradientAdditiveTransform(scale=scale, max_strength=4, p_per_sample=0.2, p_per_channel=1),
    LocalGammaTransform(scale=scale, gamma=(2, 5), p_per_sample=0.2, p_per_channel=1),
    LocalSmoothingTransform(scale=scale, smoothing_strength=(0.5, 1), p_per_sample=0.2, p_per_channel=1),
    LocalContrastTransform(scale=scale, new_contrast=(1, 3), p_per_sample=0.2, p_per_channel=1),
]

transforms = local_transforms + [t for t in train_loader.transform.transforms if isinstance(t, original_transforms)]
augmentor = batchgenerators.transforms.abstract_transforms.Compose(train_transforms)




# ### - Load dataset and init batch generator
# train_data = ACDCDataset(data='train', debug=False)
# valid_data = ACDCDataset(data='val', debug=False)

# train_gen = MultiImageSingleViewDataLoader(train_data, batch_size=cfg['batch_size'], return_orig=True)
# #train_gen = SingleThreadedAugmenter(train_gen, train_augmentor)
# train_gen = MultiThreadedAugmenter(train_gen, train_augmentor, 4, 2, seeds=None)
# valid_gen = MultiImageSingleViewDataLoader(valid_data, batch_size=cfg['batch_size'], return_orig=True)
# #valid_gen = SingleThreadedAugmenter(valid_gen, valid_augmentor)
# valid_gen = MultiThreadedAugmenter(valid_gen, valid_augmentor, 4, 2, seeds=None)

In [None]:
# define single image dataloader from batchgenerator example here:
# https://github.com/MIC-DKFZ/batchgenerators/blob/master/batchgenerators/examples/example_ipynb.ipynb
class SingleImageMultiViewDataLoader(batchgenerators.dataloading.data_loader.SlimDataLoaderBase):
    def __init__(self, data: ACDCDataset, batch_size: int = 1, return_orig: str = False):
        super(SingleImageMultiViewDataLoader, self).__init__(data, batch_size)
        # data is now stored in self._data.
        self.return_orig = return_orig
    
    def generate_train_batch(self):
        
        #data = self._data[randrange(len(self._data))]
        img = data['input'].numpy().astype(np.float32)
        tar = data['target'][0].numpy().astype(np.float32)
        
        #img_batched = np.tile(img, (self.batch_size, 1, 1, 1))
        #tar_batched = np.tile(tar, (self.batch_size, 1, 1, 1))
        # now construct the dictionary and return it. np.float32 cast because most networks take float
        out = {'data': img, 
               'seg':  img}
        
        # if the original data is also needed, activate this flag to store it where augmentations
        # cant find it.
        if self.return_orig:
            out['data_orig']   = data['input'].unsqueeze(0)
            out['target_orig'] = data['target'].unsqueeze(0)
        
        return out

def build_dataloader():
    pass

def load_model(post, residual, i):
#     post = 'localAug_multiImgSingleView_res_balanced_same'
    disabled_ids = ['shortcut0', 'shortcut1', 'shortcut2']
#     models = []
    DAEs = nn.ModuleDict({'up3': AugResDAE(in_channels = 64, 
                                           in_dim      = 32,
                                           latent_dim  = 256,
                                           depth       = 3,
                                           block_size  = 4,
                                           residual    = residual)})

    for layer_id in disabled_ids:
        DAEs[layer_id] = nn.Identity()

    model = Frankenstein(seg_model=unet,
                         transformations=DAEs,
                         disabled_ids=disabled_ids,
                         copy=True)
    model_path = f'../../pre-trained-tmp/trained_AEs/acdc_AugResDAE{i}_{post}_best.pt'
    #model_path = f'{ROOT}pre-trained-tmp/trained_AEs/{pre}_resDAE{i}_{post}_best.pt'
    #model_path = f'{ROOT}pre-trained-tmp/trained_AEs/acdc_epinet_CE-only_prior-1_best.pt'localAug_multiImgSingleView_res
    #model_path = f'{ROOT}pre-trained-tmp/trained_AEs/acdc_resDAE0_venus_best.pt'
    state_dict = torch.load(model_path)['model_state_dict']
    model.load_state_dict(state_dict)
    # Remove trainiung hooks, add evaluation hooks
    model.remove_all_hooks()        
    model.hook_inspect_transformation(model.transformations)
    # Put model in evaluation state
    model.eval()
    model.freeze_seg_model()
    
    return model


def get_data_from_id_for_actvis(sample_id, dataset, debug):
    sample = dataset[sample_id]
    # sample information
    sample_info = sample['sample_info']
    # data sample
    img_path = sample['filepath'].replace('.png', '.pt')
    x_in = torch.load(img_path)
    # load models
    model = load_model(
        post = 'localAug_multiImgSingleView_res_balanced_same',
        residual = True,
        i = sample_info['unet']
    )
    # get activations
    
    activation = None
    
    
def build_51dataset_from_data():
    pass
    
def start_51_activation_vis():
    pass

In [None]:
post = 'localAug_multiImgSingleView_res_balanced_same'
disabled_ids = ['shortcut0', 'shortcut1', 'shortcut2']
models = []
DAEs = nn.ModuleDict({'up3': AugResDAE(in_channels = 64, 
                                    in_dim      = 32,
                                    latent_dim  = 256,
                                    depth       = 3,
                                    block_size  = 4)})


for layer_id in disabled_ids:
    DAEs[layer_id] = nn.Identity()

model = Frankenstein(seg_model=unet,
                     transformations=DAEs,
                     disabled_ids=disabled_ids,
                     copy=True)
model_path = f'{ROOT}pre-trained-tmp/trained_AEs/acdc_AugResDAE{i}_{post}_best.pt'
#model_path = f'{ROOT}pre-trained-tmp/trained_AEs/{pre}_resDAE{i}_{post}_best.pt'
#model_path = f'{ROOT}pre-trained-tmp/trained_AEs/acdc_epinet_CE-only_prior-1_best.pt'localAug_multiImgSingleView_res
#model_path = f'{ROOT}pre-trained-tmp/trained_AEs/acdc_resDAE0_venus_best.pt'
state_dict = torch.load(model_path)['model_state_dict']
model.load_state_dict(state_dict)
# Remove trainiung hooks, add evaluation hooks
model.remove_all_hooks()        
model.hook_transformations(model.transformations,
                           n_samples=1)
# Put model in evaluation state
model.eval()
model.freeze_seg_model()

In [None]:
# try:
#     dataset.delete()
# except:
#     print("dataset name already available or dataset didnt exist")


# build first dataset
# init dataset
dataset = fo.Dataset(name="activation_dataset_test")
# add group and all Groups we need
dataset.add_group_field("group", default="it_0")

# set mask targets
## ground truth targets
# dataset.mask_targets = {
#     "ground_truth": {0: "background",
#                      1: "LV",
#                      2: "MYO",
#                      3: "RV"}
# }
# error map labels
# for i in range(n_unets):
#     dataset.mask_targets[f'errormap_it:{i}'] = {1: 'error'}
    
# make temporary dir for data handling
os.makedirs('tmp', exist_ok=True)
path = 'tmp/'
# init sample list. We save each sample here and add it to the
# dataset in the end
samples = []
n_unets = 2
# init dice score class
dcs = Dice(num_classes=4, ignore_index=0)
# init umap generator
umap_generator_entropy = UMapGenerator(method='entropy', net_out='mms')
umap_generator_AE = UMapGenerator(method='ae', net_out='mms')

# itertatively make samples
for i in range(10):
    # get data
    data = mnm_a[i*10]
    img = data['input']
    mask = data['target']
    mask[mask < 0] = 0
    
    # save image to disk
    img_path  = path + f'test_{i}.png'
    img_norm  = img - img.min()
    img_norm /= img_norm.max()

    save_image(img_norm, img_path)

    
#     sample_activation = fo.Sample(
#         filepath=img_path,
#         group=group.element('test'),
#         ground_truth=fo.Segmentation(
#             mask=mask.squeeze().numpy()
#         )
#     )
    # init group for slice
    group = fo.Group()
    # predictions, error maps, umaps and foreground dice
    DSC = torch.zeros((n_unets))
    for i in range(n_unets):
        # make sample
        
        sample_image = fo.Sample(
            filepath=img_path,
            group=group.element(f'it_{i}'),
            ground_truth=fo.Segmentation(
                mask=mask.squeeze().numpy()
            )
        )
        # U-Net - Predictions and errormaps
        unet_output = unets[i](img.unsqueeze(0))
        pred = torch.argmax(unet_output, dim=1, keepdims=True)
        err_map = (pred != mask)
        sample_image[f'pred_unet'] = fo.Segmentation(mask=pred.squeeze().numpy())
        sample_image[f'error_unet']  = fo.Segmentation(mask=err_map.squeeze().numpy())
        # umaps
        ## entropy
        umap_entropy = umap_generator_entropy(unet_output)
        sample_image[f'umap_entropy'] = fo.Heatmap(map=umap_entropy.squeeze().numpy())
        ## segmentation distortion - predictions with AE and UQ map
        model_output  = models[i](img.unsqueeze(0))
        model_pred    = torch.argmax(model_output[1], dim=0, keepdims=False)
        model_err_map = (model_pred != mask.squeeze())
        umap_ae = umap_generator_AE(model_output)
        sample_image[f'pred_ae'] = fo.Segmentation(mask=model_pred.squeeze().numpy())
        sample_image[f'error_ae']  = fo.Segmentation(mask=model_err_map.squeeze().numpy())
        sample_image[f'umap_ae'] = fo.Heatmap(map=umap_ae.squeeze().numpy())
        
        # add sample to sample list
        samples += [sample_image]
        
        # dice score
        DSC[i] = dcs(pred.squeeze(), mask.squeeze().int())        
    
    # get dice related tags per sample
    DSC = DSC.mean()
    thresholds = [0.3, 0.5, 0.8]
    for threshold in thresholds:
        if DSC < threshold:
            sample_image.tags.append(f'DSC_below_{int((threshold * 100))}')
        
#     # add sample to sample list
#     samples += [sample_image, sample_activation]
    
# add samples to dataset
dataset.add_samples(samples)


# validate dataset
print(dataset)

# export dataset
export_dir = os.path.expanduser('~') + '/fiftyone/test-dataset'
dataset.export(export_dir=export_dir, dataset_type=fo.types.FiftyOneDataset)

# clean up. remove tmp
shutil.rmtree('tmp')
#dataset.delete()

In [116]:
print(dataset.group_slices)

['image', 'test']


In [112]:
dataset.delete()

In [108]:
a=[1]
b=[2,3]

a+b

[1, 2, 3]

In [109]:
 a+=b

In [110]:
a

[1, 2, 3]

In [126]:
fo.__version__

'0.22.0'