# Segmentation or registration CNN training and evaluation

1. Train either a segmentation or a registration (DDF prediction) model
1. Evaluate it or test on new images
1. Plot the results

### Load required libraries

In [1]:
%load_ext autoreload
%autoreload 2

#Common imports + pytorch + torchinfo
import argparse, shlex, glob, os, sys, numpy as np
from collections import defaultdict
from functools import partial
import SimpleITK as sitk
import torch, pytorch_lightning as pl
from torchinfo import summary
torch.autograd.set_detect_anomaly(True)

#Our library
from lib.models import SegmentationModel, CustomMaskedSegmentationModel, DDFModel
from lib.data import MedicalImageDataset, SubsetDataModule, PathDataModule

#plot_lib
from pathlib import Path
sys.path.append(os.path.join(Path.home(), 'plot_lib'))
from plot_lib import plot, plot_multi_mask, plot_alpha, plot_composite

### Global configuration

In [2]:
#Which problem this notebook solves; must be one of the keys of problem_dict
PROBLEM= 'DDF'
problem_dict= dict(
    SEG_MR='Standard MR segmentation problem',
    DDF='Dense deformation field estimation from MR / US pairs with corresponding prostate masks',
)
problem_description= problem_dict[PROBLEM]
print(f"{PROBLEM=}: {problem_description}")

#Index of the patient whose sample is plotted once the data is loaded
test_id= -1

#Stages to run: train / test / predict
TRAIN, TEST, BLIND_PREDICT= False, True, True

PROBLEM='DDF': Dense deformation field estimation from MR / US pairs with corresponding prostate masks


### Data loading configuration

In [3]:
#Problem-specific configuration: every branch defines the data module, the model class,
#the network geometry (input_size, spacing, channels) and the loader settings
if PROBLEM.startswith('SEG'):

    #The image modality is the last '_'-separated token of the problem name
    modality= PROBLEM.split('_')[-1]

    #Data location and network geometry
    data_path= 'D:/oscar/Prostate Images/Promise12/Train'
    #Alternative geometry: input_size, spacing= [160, 160, 32], [0.75,0.75,3]
    input_size= [112, 112, 32]
    spacing= [1, 1, 3]
    input_channels= 1
    output_channels= 2 #Consider BG class too
    num_workers= 4
    batch_size= 4
    model_class= SegmentationModel

    #Build DataModule
    dataModule= PathDataModule(
        data_path, {'*_segmentation.mhd':'MR_mask', '*.mhd':'MR'},
        dataset='Promise12',
        num_workers=num_workers, batch_size=batch_size,
        inputs=[modality], outputs=[f'{modality}_mask'],
        subset_percentages={'test':0.15, 'val':0.10, 'train':0.75}, shuffle=True,
        size=input_size, spacing=spacing, cache=True,
        sorting_fn=lambda case_id: int(case_id.split('Case')[1]),
        id_making_fn=lambda file_path: file_path.split('_')[0],
    )

    #Plot a sample: first image channel with the second mask channel on top
    dataModule.setup()
    img, msk, meta= dataModule.train_dataset[test_id]
    print(meta)
    plot(img[0], masks=[msk[1]], title=dataModule.train_ids[test_id])

elif PROBLEM.startswith('DDF'):

    #Data locations
    data_path= r'./registration_data/preprocessed'
    data_path_DDF= r'./registration_data/DDFs'
    fiducials_path= r'./registration_data/fiducials'

    #Network geometry
    #Alternative geometry: input_size, spacing= [160]*3, [0.5]*3
    input_size= [120, 120, 120]
    spacing= [2/3, 2/3, 2/3]
    input_channels= 4
    output_channels= 3
    model_class= DDFModel
    num_workers= 3 #Set to 1 if using MI loss
    batch_size= 3 #Set to 1 if using MI loss

    #Build DataModule
    dataModule= PathDataModule(
        data_path_DDF, data_path,
        {'*_DDF_FEM.nrrd':'DDF'},
        {'*_MR_img.nrrd':'MR', '*_US_img.nrrd':'US',
         '*_MR_msk.nrrd':'MR_mask', '*_US_msk.nrrd':'US_mask'},
        sorting_fn=lambda patient_id: int(patient_id[2:]),
        id_making_fn=lambda file_path: file_path.split('_')[0],
        num_workers=num_workers, batch_size=batch_size,
        inputs=['MR', 'US', 'MR_mask', 'US_mask'], problem='masked_regression',
        outputs=['DDF', 'US_mask'], size=input_size, spacing=spacing, process_ddf=True,
        cache=True, output_channels=None, mask_sigma=[7, 7, 7], normalize_inputs=[0,1],
        subset_percentages={'test':0.15, 'val':0.10, 'train':0.75}, shuffle=False,
        input_interpolator=[sitk.sitkBSpline, sitk.sitkBSpline,
                            sitk.sitkLabelGaussian, sitk.sitkLabelGaussian],
    )

    #Soft masks are binarized at 0.5 before being drawn
    binarize= lambda soft_mask: soft_mask > 0.5

    #Plot a sample: channels 0/1 are shown as images with channels 2/3 as their masks
    dataModule.setup()
    img, ddf, meta= dataModule.train_dataset[test_id]
    print(meta)
    plot(img[0], masks=[binarize(img[2])], title=dataModule.train_ids[test_id])
    plot(img[1], masks=[binarize(img[3])])

    #First three DDF channels as a color image (channels moved last), then the last channel on its own
    masks_plot= [binarize(img[2]), binarize(img[3])]
    plot(ddf[:3].transpose((1,2,3,0)), masks=masks_plot, is_color=True)
    plot(ddf[-1], masks=[binarize(img[3])])

    #Warp channel 2 (mask) and channel 0 (image) with the ground-truth DDF as a sanity check
    from lib.models.layers import SpatialTransformer
    transformer = SpatialTransformer()
    as_double_batch= lambda array: torch.as_tensor(array[None]).double()
    actual_ddf= as_double_batch(ddf[:-1])
    mr_tf_mask= transformer(as_double_batch(img[2:3]), actual_ddf)
    mr_tf= transformer(as_double_batch(img[0:1]), actual_ddf)
    plot(mr_tf[0,0].numpy(), masks=[binarize(mr_tf_mask[0,0].numpy()), binarize(img[3])],
         title='Actual transformed MR')

else:
    raise ValueError(f'Unknown problem type {PROBLEM}')

#Data loader configuration
#Gradients are accumulated up to an equivalent of 4 images per optimizer step;
#the number of accumulated batches is clamped to the range [1, 4]
batches_for_four_images= int(4/batch_size)
acb= max(1, min(4, batches_for_four_images))
print(f'Batch size: {batch_size}; Accumulated batches: {acb}')

#Input size: a single sample with all input channels
full_input_size= [1, input_channels, *input_size]
print(f'Input size: {full_input_size}')

#Some other info: test ids followed by the size of every subset
print('Test ids:', dataModule.test_ids, end='\n\n')
subsets= [('train', dataModule.train_ids), ('test', dataModule.test_ids),
          ('val', dataModule.val_ids), ('predict', dataModule.predict_ids)]
for s_name, s in subsets:
    print(' - ', s_name, len(s))

{'dataset': 'IVO', 'pid': 'ID0002'}


interactive(children=(IntSlider(value=60, description='z', max=119, style=SliderStyle(handle_color='lightblue'…

interactive(children=(IntSlider(value=60, description='z', max=119, style=SliderStyle(handle_color='lightblue'…

interactive(children=(IntSlider(value=60, description='z', max=119, style=SliderStyle(handle_color='lightblue'…

interactive(children=(IntSlider(value=60, description='z', max=119, style=SliderStyle(handle_color='lightblue'…

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


interactive(children=(IntSlider(value=60, description='z', max=119, style=SliderStyle(handle_color='lightblue'…

Batch size: 3; Accumulated batches: 1
Input size: [1, 4, 120, 120, 120]
Test ids: []

 -  train 2
 -  test 0
 -  val 0
 -  predict 2


### Model training configuration

In [4]:
#Adjacent string literals are joined with no separator, so the first one has to end with a
#space; otherwise the message reads "... ANISOTROPIC_STAGES isunimplemented ..."
print('Warning: ENCODER_CHANNELS, DECODER_CHANNELS, ENCODER_BLOCKS and ANISOTROPIC_STAGES is '
      'unimplemented functionality as of yet. I will update the code when ready.')

#Global standard settings (channel / block lists are space-separated strings for the CLI parser)
ENCODER_CHANNELS= '48 128 256'
#ENCODER_CHANNELS= ' '.join(['%.0f'%(np.float_power(2, i)) for i in np.arange(5, 9)])
DECODER_CHANNELS= ' '.join(reversed(ENCODER_CHANNELS.split(' '))) #Encoder channels, reversed
ENCODER_BLOCKS= '3 5 9' #'3 5 9 3'
LR= 1e-4
ANISOTROPIC_STAGES= 2
LOAD_NAME= None
EPOCHS= 700
ES_PATIENCE= 19

#Custom config depending on problem type
if PROBLEM not in ('SEG_MR', 'DDF'):
    raise ValueError(f'Unknown problem type {PROBLEM}')

if PROBLEM == 'SEG_MR':
    LR= 3e-4
    CNN_NAME= 'VNET2'
    #LOAD_NAME= 'weights/SEG_MR_VNET2/version_24/epoch=404-step=34019.ckpt'
else:
    #PROBLEM == 'DDF'
    LR= 1e-2
    ES_PATIENCE= 38
    CNN_NAME= 'VNET2'
    LOAD_NAME= 'weights/DDF_VNET2/version_33/checkpoints/epoch=459-step=23459.ckpt'
    
#Get to args string: the fragments are concatenated as-is and tokenized by shlex below
print(ENCODER_CHANNELS)
checkpoint_args= ' --load_from_checkpoint %s '%LOAD_NAME if LOAD_NAME is not None else ''
argString= ''.join([
    f" --max_steps 100000 --net {CNN_NAME} --accumulate_grad_batches {acb} --gpus -1 --lr {LR} ",
    f" --bnorm --dropout 0.05 --plot --check_val_every_n_epoch 5 --log_every_n_steps 10 ",
    f" --max_epochs {EPOCHS} --batch_size {batch_size} --transform",
    f" --input_channels {input_channels} --output_classes {output_channels}",
    f" {checkpoint_args}",
    f" --anisotropic_stages {ANISOTROPIC_STAGES} --encoder_blocks {ENCODER_BLOCKS} ",
    f" --encoder_channels {ENCODER_CHANNELS} --decoder_channels {DECODER_CHANNELS} ",
])

#Build parser: model-specific options first, then the Lightning Trainer options
parser = pl.Trainer.add_argparse_args(
    model_class.add_model_specific_args(argparse.ArgumentParser()))
hparams = parser.parse_args(shlex.split(argString))
print(hparams)

#Load model: restore it from the checkpoint when one was given, otherwise build it from hparams
checkpoint_path= hparams.load_from_checkpoint
model = (model_class.load_from_checkpoint(checkpoint_path) if checkpoint_path
         else model_class(hparams))

#Logger (a random example input is attached to the model because log_graph=True is requested)
from pytorch_lightning.loggers import TensorBoardLogger
model.example_input_array= torch.rand(full_input_size)
logger = TensorBoardLogger('lightning_logs', name=f'{PROBLEM}_{CNN_NAME}',
                           default_hp_metric=False, log_graph=True)

#Trainer callbacks: early stopping on val_loss, learning rate logging and SWA
early_stopping = pl.callbacks.EarlyStopping(
    monitor="val_loss", mode="min", patience=ES_PATIENCE, strict=True, verbose=True)
lr_logger= pl.callbacks.lr_monitor.LearningRateMonitor()
swa= pl.callbacks.StochasticWeightAveraging(swa_epoch_start=0.8)
trainer_callbacks= [early_stopping, lr_logger, swa]
trainer = pl.Trainer.from_argparse_args(hparams, callbacks=trainer_callbacks, logger=logger)

#See architecture summary and generate graph
architecture_summary= summary(model, input_size=full_input_size, depth=7)
print(architecture_summary)

# Run learning rate finder?
# lr_finder = trainer.tuner.lr_find(model, dataModule)
# fig = lr_finder.plot(suggest=True)

#Train (only when requested in the global configuration)
if TRAIN:
    trainer.fit(model, dataModule)

48 128 256
Namespace(accelerator=None, accumulate_grad_batches=1, amp_backend='native', amp_level=None, anisotropic_stages=2, auto_lr_find=False, auto_scale_batch_size=False, auto_select_gpus=False, batch_size=3, benchmark=False, bnorm=True, check_val_every_n_epoch=5, checkpoint_callback=None, decoder_channels=[256, 128, 48], default_root_dir=None, detect_anomaly=False, deterministic=False, devices=None, diff_weight=1.0, dropout=0.05, dsc_weight=1.0, enable_checkpointing=True, enable_model_summary=True, enable_progress_bar=True, encoder_blocks=[3, 5, 9], encoder_channels=[48, 128, 256], fast_dev_run=False, feat_weight=0.0, flush_logs_every_n_steps=None, gpus=-1, grad_weight=0.0, gradient_clip_algorithm=None, gradient_clip_val=None, input_channels=4, ipus=None, limit_predict_batches=1.0, limit_test_batches=1.0, limit_train_batches=1.0, limit_val_batches=1.0, load_from_checkpoint='weights/DDF_VNET2/version_33/checkpoints/epoch=459-step=23459.ckpt', log_every_n_steps=10, log_gpu_memory=No

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs


Layer (type:depth-idx)                             Output Shape              Param #
DDFModel                                           --                        --
├─VNetLight: 1-1                                   [1, 3, 120, 120, 120]     --
│    └─InputTransition: 2-1                        [1, 16, 120, 120, 120]    --
│    │    └─Conv3d: 3-1                            [1, 16, 120, 120, 120]    8,016
│    │    └─InstanceNorm3d: 3-2                    [1, 16, 120, 120, 120]    --
│    │    └─PReLU: 3-3                             [1, 16, 120, 120, 120]    16
│    └─DownTransition: 2-2                         [1, 32, 60, 60, 60]       --
│    │    └─Conv3d: 3-4                            [1, 32, 60, 60, 60]       4,128
│    │    └─InstanceNorm3d: 3-5                    [1, 32, 60, 60, 60]       --
│    │    └─PReLU: 3-6                             [1, 32, 60, 60, 60]       32
│    │    └─PReLU: 3-7                             --                        (recursive)
│    │    └─Sequenti

### Evaluate model or predict on new data + show results

In [5]:
#Manually visualize results & compute metrics
from lib.processing.preprocessing import (
       undo_resample_image, surface_distance_metrics, DSC_sitk, 
       slice_wise_binary_smoothing, surface_distance_metrics, 
       MI_sitk, DSC_sitk, print_metric, sitk_transform_points, load_json, 
       get_gradient_features, transform_sitk, point_average_error, sitk_transform_points, 
       point_max_min_distance, point_hd95, torch2sitk, read_fiducials, 
       point_abd, distance0, distance, pair_test)

#Initialize some data structures
#  cpu_times / cuda_times: one profiler total per processed sample
#  results: metric name -> {dataset name -> list of values}
cpu_times, cuda_times, results= [], [], {}

#Ask the profiler for CUDA timings only when a CUDA device is actually present
PROFILE_CUDA= torch.cuda.is_available()

#Draw the fiducial markers on the DDF figures? Off by default
PLOT_FIDUCIALS= False

if PROBLEM.startswith('SEG') and (TEST or BLIND_PREDICT):
    #One list per dataset plus a pooled 'All' list for every metric
    for m in ['dsc', 'abd', 'hd95']:
        results[m]= defaultdict(list)
        results[m]['All']= []
    model.eval()
    with torch.no_grad():
        dataset= dataModule.predict_dataset if BLIND_PREDICT else dataModule.test_dataset
        for i, batch in enumerate(dataset):
            x, y, meta= batch #Unpack batch
            x, y= torch.as_tensor(x)[None].to(model.device), torch.as_tensor(y)[None].to(model.device)
            pid= meta['pid']; print('\n', pid)

            with torch.autograd.profiler.profile(use_cuda=PROFILE_CUDA) as prof: #Measure forward time!
                yp_onehot, _, _= model.forward(x) #Predict
            cpu_times.append(prof.total_average().self_cpu_time_total)
            cuda_times.append(prof.total_average().self_cuda_time_total)

            #Binarize channel 1 of the prediction, bring it back to the original image grid and smooth it
            mask_p= sitk.GetImageFromArray(yp_onehot.cpu().numpy()[0,1]) > 0.5
            mask_p= undo_resample_image(mask_p, **meta['image_meta'][0], interpolator=sitk.sitkLabelGaussian)
            mask_p= slice_wise_binary_smoothing(mask_p, radius=3)
            img= meta['inputs'][0]

            if BLIND_PREDICT:
                #No ground truth is used here: show the prediction and save it to disk
                os.makedirs('./predicted', exist_ok = True)
                plot(img, masks=[mask_p], center_crop=[120, 120], title=meta['pid'])
                sitk.WriteImage(mask_p > 0.5, os.path.join('./predicted', f'{pid}.nrrd'), True)
            else:
                #Reference mask: component 0 of a multi-component image, or the image itself
                if meta['outputs'][0].GetNumberOfComponentsPerPixel() > 1:
                    mask= sitk.VectorIndexSelectionCast(meta['outputs'][0], 0, outputPixelType=sitk.sitkFloat32) > 0.5
                else:
                    mask= sitk.Cast(meta['outputs'][0], sitk.sitkFloat32) > 0.5

                plot(img, masks=[mask, mask_p], center_crop=[120, 120], title=meta['pid'])
                dsc, (hd95, abd)= DSC_sitk(mask, mask_p), surface_distance_metrics(mask, mask_p)
                print(meta['dataset'], meta['pid'], dsc, hd95, abd)
                for name, value in zip(['dsc', 'hd95', 'abd'], [dsc, hd95, abd]):
                    results[name][meta['dataset']].append(value)
                    results[name]['All'].append(value)

elif PROBLEM.startswith('DDF') and (TEST or BLIND_PREDICT):
    all_metrics= ['mi_all', 'mi_in', 'mig_all', 'mig_in', 'dsc2','hd95', 'abd', 'tre']
    #Every metric is stored in four variants, in the order returned by print_metric.
    #Judging by the arguments passed below these are: before registration ('_base'),
    #reference DDF (no suffix), predicted DDF ('_pred') and interpolated DDF ('_interp')
    variants= ['_base', '', '_pred', '_interp']
    for m in all_metrics:
        for suffix in variants:
            results[m + suffix]= defaultdict(list)
    model.eval()
    with torch.no_grad():
        dataset= dataModule.predict_dataset if BLIND_PREDICT else dataModule.test_dataset
        for i, batch in enumerate(dataset):
            #Extract batch data
            x, y_true, meta = batch
            pid= meta['pid']; print('\n', pid)
            x, y_true= torch.as_tensor(x)[None].to(model.device), torch.as_tensor(y_true)[None].to(model.device)
            y_softmask= y_true[:,-1:] #Mask is contained in the last channel of y
            y_ddf= y_true[:,:-1] #DDF is contained in the first 3 channels of y
            with torch.autograd.profiler.profile(use_cuda=PROFILE_CUDA) as prof: #Measure forward time!
                yp_ddf= model.forward(x[:,:4]) #Use model
            cpu_times.append(prof.total_average().self_cpu_time_total)
            cuda_times.append(prof.total_average().self_cuda_time_total)

            #Load non-FEM interpolated DDF
            ddf_sitk_interp= sitk.ReadImage(os.path.join(data_path_DDF, pid + '_DDF.nrrd'))
            ddf_sitk_inv_interp= sitk.InvertDisplacementField(ddf_sitk_interp, enforceBoundaryCondition=False)

            #Everything to sitk
            mr_img, us_img, mr_msk, us_msk= meta['inputs'][:4]
            ddf_sitk_inv, _= meta['outputs']
            ddf_sitk= sitk.InvertDisplacementField(ddf_sitk_inv, enforceBoundaryCondition=False)

            #Predicted field: scaled by the voxel spacing, moved back to the original grid and inverted
            ddf_sitk_inv_pred= torch2sitk(yp_ddf, c=[2,1,0], f=lambda a: a.transpose(1,2,3,0) * spacing[2], ref=ddf_sitk)
            ddf_sitk_inv_pred= undo_resample_image(ddf_sitk_inv_pred, **meta['image_meta'][0], interpolator=sitk.sitkBSpline)
            ddf_sitk_pred= sitk.InvertDisplacementField(ddf_sitk_inv_pred, enforceBoundaryCondition=False)

            #Displacement field transforms (direct ones move points, inverse ones resample images)
            DDF_tfm= sitk.DisplacementFieldTransform(sitk.Cast(ddf_sitk, sitk.sitkVectorFloat64))
            DDF_tfm_pred= sitk.DisplacementFieldTransform(sitk.Cast(ddf_sitk_pred, sitk.sitkVectorFloat64))
            DDF_tfm_interp= sitk.DisplacementFieldTransform(sitk.Cast(ddf_sitk_interp, sitk.sitkVectorFloat64))
            DDF_tfm_inv= sitk.DisplacementFieldTransform(sitk.Cast(ddf_sitk_inv, sitk.sitkVectorFloat64))
            DDF_tfm_inv_pred= sitk.DisplacementFieldTransform(sitk.Cast(ddf_sitk_inv_pred, sitk.sitkVectorFloat64))
            DDF_tfm_inv_interp= sitk.DisplacementFieldTransform(sitk.Cast(ddf_sitk_inv_interp, sitk.sitkVectorFloat64))

            #Warp the MR image and mask with the reference, predicted and interpolated fields
            mr_img_after= sitk.Resample(mr_img, transform=DDF_tfm_inv, interpolator=sitk.sitkBSpline)
            mr_img_after_pred= sitk.Resample(mr_img, transform=DDF_tfm_inv_pred, interpolator=sitk.sitkBSpline)
            mr_img_after_interp= sitk.Resample(mr_img, transform=DDF_tfm_inv_interp, interpolator=sitk.sitkBSpline)
            mr_msk_after= sitk.Resample(mr_msk, transform=DDF_tfm_inv, interpolator=sitk.sitkLabelGaussian)
            mr_msk_after_pred= sitk.Resample(mr_msk, transform=DDF_tfm_inv_pred, interpolator=sitk.sitkLabelGaussian)
            mr_msk_after_interp= sitk.Resample(mr_msk, transform=DDF_tfm_inv_interp, interpolator=sitk.sitkLabelGaussian)

            #Load fiducials (if available) and transform them
            mr_fiducials= read_fiducials(os.path.join(fiducials_path, pid + '_MR.mrk.json'))
            us_fiducials= read_fiducials(os.path.join(fiducials_path, pid + '_US.mrk.json'), check_fiducials=mr_fiducials)
            us_fiducials= list(us_fiducials.values())
            mr_fiducials= list(mr_fiducials.values()) if len(us_fiducials) else []

            mr_fiducials_after= list(sitk_transform_points(DDF_tfm, mr_fiducials)) if len(mr_fiducials) else []
            mr_fiducials_after_pred= list(sitk_transform_points(DDF_tfm_pred, mr_fiducials)) if len(mr_fiducials) else []
            mr_fiducials_after_interp= list(sitk_transform_points(DDF_tfm_interp, mr_fiducials)) if len(mr_fiducials) else []

            #Plot
            #The fiducial overlay is only built on request instead of being computed and then thrown away
            if PLOT_FIDUCIALS:
                points= [[*us_img.TransformPhysicalPointToContinuousIndex(p),'o','tab:blue',f'{k+1} (US)']
                             for k,p in enumerate(us_fiducials)] +\
                        [[*mr_img_after.TransformPhysicalPointToContinuousIndex(p),'o','tab:red',f'{k+1} (MR after)']
                             for k,p in enumerate(mr_fiducials_after)]
            else:
                points= []
            mask_colors= ['tab:blue', 'tab:red', 'tab:orange']
            plot(mr_img_after, allowed_label_overlap=[3,3], text_kwargs=dict(size=14), scale=3,
                 masks=[ [msk, [1], [color], [name]] for msk, name, color in zip([us_msk, mr_msk, mr_msk_after],
                        ['US', 'MR before', 'MR after'], mask_colors) ],
                 points=points, hide_axis=True,)# title='Reference (CPD + FEM)')
            #The predicted image is shown with the mask warped by the *predicted* field
            #(it used to show the reference-warped mask, which hid the prediction error)
            plot(mr_img_after_pred, allowed_label_overlap=[3,3], text_kwargs=dict(size=14), scale=3,
                 masks=[ [msk, [1], [color], [name]] for msk, name, color in zip([us_msk, mr_msk, mr_msk_after_pred],
                        ['', '', ''], mask_colors) ],
                 points=points, hide_axis=True)
            plot_composite(us_img, mr_img_after, masks=[us_msk], n_tiles=[4]*2, scale=3,
                           points=points, hide_axis=True,)# title='Reference (CPD + FEM)')
            plot_composite(us_img, mr_img_after_pred, masks=[us_msk], n_tiles=[4]*2, scale=3,
                           points=points, hide_axis=True)

            #Compute and print metrics
            print('\nMetrics (>0% is better)')
            mi_all_base, mi_all, mi_all_pred, mi_all_interp= print_metric(
                us_img, mr_img, mr_img_after, mr_img_after_pred, mr_img_after_interp,
                metric=MI_sitk, name='MI (all)')
            mi_in_base, mi_in, mi_in_pred, mi_in_interp= print_metric(
                us_img, mr_img, mr_img_after, mr_img_after_pred, mr_img_after_interp,
                metric=partial(MI_sitk, mask=us_msk), name='MI (in)')
            mig_all_base, mig_all, mig_all_pred, mig_all_interp= print_metric(
                us_img, mr_img, mr_img_after, mr_img_after_pred, mr_img_after_interp,
                metric=lambda *args: MI_sitk(*(get_gradient_features(a) for a in args)), name='MI grad (all)')
            mig_in_base, mig_in, mig_in_pred, mig_in_interp= print_metric(
                us_img, mr_img, mr_img_after, mr_img_after_pred, mr_img_after_interp,
                metric=lambda *args: MI_sitk(*(get_gradient_features(a) for a in args), mask=us_msk), name='MI grad (in)')
            dsc2_base, dsc2, dsc2_pred, dsc2_interp= print_metric(
                us_msk, mr_msk, mr_msk_after, mr_msk_after_pred, mr_msk_after_interp,
                metric=DSC_sitk, name='DSC (masks)')

            print('\nMetrics (<0% is better)')
            hd95_base, hd95, hd95_pred, hd95_interp= print_metric(
                us_msk, mr_msk, mr_msk_after, mr_msk_after_pred, mr_msk_after_interp,
                metric=lambda *i: surface_distance_metrics(*i)[0], name='HD95 (masks)')
            abd_base, abd, abd_pred, abd_interp= print_metric(
                us_msk, mr_msk, mr_msk_after, mr_msk_after_pred, mr_msk_after_interp,
                metric=lambda *i: surface_distance_metrics(*i)[1], name='ABD (masks)')
            tre_base, tre, tre_pred, tre_interp= print_metric(
                us_fiducials, mr_fiducials, mr_fiducials_after, mr_fiducials_after_pred, mr_fiducials_after_interp,
                metric=lambda a, b: list(distance(a,b)), name='TRE (mm)')

            #Save all metrics (tuples follow the order of `variants`)
            per_metric= {
                'mi_all':  (mi_all_base, mi_all, mi_all_pred, mi_all_interp),
                'mi_in':   (mi_in_base, mi_in, mi_in_pred, mi_in_interp),
                'mig_all': (mig_all_base, mig_all, mig_all_pred, mig_all_interp),
                'mig_in':  (mig_in_base, mig_in, mig_in_pred, mig_in_interp),
                'dsc2':    (dsc2_base, dsc2, dsc2_pred, dsc2_interp),
                'hd95':    (hd95_base, hd95, hd95_pred, hd95_interp),
                'abd':     (abd_base, abd, abd_pred, abd_interp),
                'tre':     (tre_base, tre, tre_pred, tre_interp),
            }
            for m, variant_values in per_metric.items():
                for suffix, value in zip(variants, variant_values):
                    name= m + suffix
                    targets= [results[name][meta['dataset']]]
                    if 'All' in results[name]:
                        targets.append(results[name]['All'])
                    for target in targets:
                        #TRE yields one distance per fiducial, so it is flattened into the list;
                        #every other metric is a single value per patient
                        if m == 'tre':
                            target.extend(value)
                        else:
                            target.append(value)
else:
    raise NotImplementedError(
        f'Nothing to evaluate for PROBLEM={PROBLEM!r} with TEST={TEST} and BLIND_PREDICT={BLIND_PREDICT}: '
        "enable TEST or BLIND_PREDICT and use a problem starting with 'SEG' or 'DDF'")

#Print collected metrics
print('\nFinal metrics')

#Timing: the profiler totals are divided by 1000 before being printed as ms. The statistic
#shown is the median, and the label now says so (it used to say "mean")
if len(cpu_times):
    print(f' - CPU median time per image: {np.median(cpu_times)/1000.:.2f}ms;'
          f'\n - CUDA median time per image: {np.median(cuda_times)/1000.:.2f}ms')
else:
    print(' - No forward passes were timed')

stats_fmt= '%15s: mean: %.4f +/- %.4f (N=%d) | median : %.4f | min : %.4f | max : %.4f'
for metric, datasets in results.items():
    #print('%s'%metric)
    for dataset, values in datasets.items():
        label= f"{dataset} {metric}"
        n_valid= np.sum(~np.isnan(values))
        if n_valid == 0:
            #np.nanmin / np.nanmax raise ValueError on an empty list (e.g. the segmentation
            #lists when only blind prediction was run), so report the gap instead of crashing
            print('%15s: no valid values collected (N=0)'%label)
        else:
            stats= (label, np.nanmean(values), np.nanstd(values), n_valid,
                    np.nanmedian(values), np.nanmin(values), np.nanmax(values))
            if metric.endswith('_base') or 'DDF' not in PROBLEM :
                print(stats_fmt%stats)
            else:
                #Paired test against the '_base' variant of the same metric on the same dataset
                base_metric= metric.replace('_pred', '').replace('_interp', '') + '_base'
                p_value= pair_test(values, results[base_metric][dataset])
                print((stats_fmt + ' | p-val ttest: %.4f')%(stats + (p_value,)))
        #Blank line after the last variant of every metric
        if metric.endswith('_interp'): print('')


 ID0001


interactive(children=(IntSlider(value=80, description='z', max=159, style=SliderStyle(handle_color='lightblue'…

interactive(children=(IntSlider(value=80, description='z', max=159, style=SliderStyle(handle_color='lightblue'…

interactive(children=(IntSlider(value=80, description='z', max=159, style=SliderStyle(handle_color='lightblue'…

interactive(children=(IntSlider(value=80, description='z', max=159, style=SliderStyle(handle_color='lightblue'…


Metrics (>0% is better)
 - MI (all): 0: -0.0438, 1: -0.0663(51.22%), 2: -0.0504(15.04%), 3: -0.0996(127.15%)
 - MI (in): 0: -0.0141, 1: -0.0097(-31.25%), 2: -0.0127(-9.90%), 3: -0.0099(-29.74%)
 - MI grad (all): 0: -0.0021, 1: -0.0035(65.85%), 2: -0.0025(18.87%), 3: -0.0058(177.42%)
 - MI grad (in): 0: -0.0016, 1: -0.0012(-22.58%), 2: -0.0017(5.50%), 3: -0.0016(-0.85%)
 - DSC (masks): 0: 0.8564, 1: 0.9856(15.08%), 2: 0.9660(12.79%), 3: 0.9847(14.98%)

Metrics (<0% is better)
 - HD95 (masks): 0: 8.5000, 1: 4.0000(-52.94%), 2: 1.8708(-77.99%), 3: 6.2650(-26.29%)
 - ABD (masks): 0: 2.6749, 1: 0.6465(-75.83%), 2: 0.6169(-76.94%), 3: 0.9197(-65.62%)
 - TRE (mm): 0: 5.8015, 1: 4.2370(-26.97%), 2: 4.2309(-27.07%), 3: 2.8980(-50.05%)

 ID0002


interactive(children=(IntSlider(value=80, description='z', max=159, style=SliderStyle(handle_color='lightblue'…

interactive(children=(IntSlider(value=80, description='z', max=159, style=SliderStyle(handle_color='lightblue'…

interactive(children=(IntSlider(value=80, description='z', max=159, style=SliderStyle(handle_color='lightblue'…

interactive(children=(IntSlider(value=80, description='z', max=159, style=SliderStyle(handle_color='lightblue'…


Metrics (>0% is better)
 - MI (all): 0: -0.1211, 1: -0.1080(-10.88%), 2: -0.1032(-14.79%), 3: -0.1225(1.09%)
 - MI (in): 0: -0.0130, 1: -0.0225(73.43%), 2: -0.0255(96.79%), 3: -0.0196(50.98%)
 - MI grad (all): 0: -0.0013, 1: -0.0063(375.51%), 2: -0.0028(113.72%), 3: -0.0094(606.23%)
 - MI grad (in): 0: -0.0013, 1: -0.0020(55.80%), 2: -0.0020(58.17%), 3: -0.0024(86.23%)
 - DSC (masks): 0: 0.7931, 1: 0.9658(21.78%), 2: 0.9528(20.14%), 3: 0.9747(22.90%)

Metrics (<0% is better)
 - HD95 (masks): 0: 9.5000, 1: 13.0000(36.84%), 2: 4.0000(-57.89%), 3: 12.5000(31.58%)
 - ABD (masks): 0: 3.6195, 1: 2.1262(-41.26%), 2: 0.9347(-74.18%), 3: 2.2444(-37.99%)
 - TRE (mm): 0: 8.0948, 1: 6.1508(-24.02%), 2: 5.2278(-35.42%), 3: 6.5772(-18.75%)

Final metrics
 - CPU mean time per image: 6.27ms;
 - CUDA mean time per image: 233.68ms
IVO mi_all_base: mean: -0.0825 +/- 0.0386 (N=2) | median : -0.0825 | min : -0.1211 | max : -0.0438
     IVO mi_all: mean: -0.0871 +/- 0.0208 (N=2) | median : -0.0871 | min : 