In [5]:
import os
import sys

import torch
import torch.nn.functional as F
sys.path.append(os.path.join('..', '..', 'tta_uia_segmentation', 'src'))

from dataset.dataset_in_memory import get_datasets
from tta_uia_segmentation.src.utils.io import load_config
from tta_uia_segmentation.src.utils.loss import dice_score
from tta_uia_segmentation.src.dataset.utils import onehot_to_class, class_to_onehot


def compare_downsampled(ds, ds_downsampled):
    """Compare full-resolution labels with downsampled labels upsampled back to full size.

    For every volume, the label from ``ds_downsampled`` is resized to the spatial
    size of the matching label in ``ds`` with nearest-neighbour interpolation.
    The two labels are then scored with ``dice_score`` (hard, ``reduction='mean'``).
    The per-volume score and the mean over all volumes are printed.

    Args:
        ds: indexable dataset whose items unpack as ``(image, label, ...)``;
            provides the reference labels.
        ds_downsampled: dataset with the same volumes in the same order, whose
            labels are stored at a lower spatial resolution.

    Returns:
        list[float]: the foreground Dice score of each volume (second output of
        ``dice_score``), in dataset order.

    Raises:
        ValueError: if ``ds`` is empty or the two datasets differ in length.
    """
    if len(ds) == 0:
        raise ValueError('ds is empty; nothing to compare')
    if len(ds) != len(ds_downsampled):
        raise ValueError(
            f'Datasets differ in length: {len(ds)} vs {len(ds_downsampled)}')

    dices_fg = []
    for vol_i in range(len(ds)):
        _, label_gt, *_ = ds[vol_i]
        _, label_downsampled, *_ = ds_downsampled[vol_i]

        # Add exactly one batch dim to each label so both reach dice_score with the
        # same rank -- assumes labels are channel-first (C, D, H, W); TODO confirm.
        label_gt = label_gt.unsqueeze(0)
        label_upsampled = F.interpolate(
            label_downsampled.unsqueeze(0).float(),
            size=label_gt.shape[-3:],
            mode='nearest',
        ).round().int()

        _, dice_fg = dice_score(label_gt, label_upsampled,
                                reduction='mean', soft=False)
        dice_fg = dice_fg.item()
        dices_fg.append(dice_fg)
        print(f'Volume {vol_i}, Dice score: {dice_fg}')

    mean_dice_fg = sum(dices_fg) / len(dices_fg)
    print(f'Mean Dice score: {mean_dice_fg}')
    return dices_fg

# WMH dataset (with SynthSeg labels)

In [6]:
# Load the dataset config and the same split at three resolutions: the processed
# resolution, plus two versions rescaled by `rescale_factor` (presumably a per-axis
# scale applied to the last two spatial axes -- TODO confirm in get_datasets).
dataset_name        = 'umc_w_synthseg_labels'
split               = 'train'
image_size          = (48, 256, 256)
# NOTE(review): device and model_params are not used below; kept in case other cells rely on them.
device              = 'cuda' if torch.cuda.is_available() else 'cpu'

# NOTE(review): absolute cluster path; consider making this configurable.
config_dir          = '/scratch_net/biwidl319/jbermeo/MastersThesisUIASegmentation/config'
model_params        = load_config(os.path.join(config_dir, 'models.yaml'))
dataset_params      = load_config(os.path.join(config_dir, 'datasets.yaml'))
dataset_params      = dataset_params[dataset_name]

# Arguments shared by every get_datasets call below; only rescale_factor differs.
dataset_kwargs = dict(
    splits              = [split],
    paths               = dataset_params['paths_processed'],
    paths_original      = dataset_params['paths_original'],
    image_size          = image_size,
    resolution_proc     = dataset_params['resolution_proc'],
    dim_proc            = dataset_params['dim'],
    n_classes           = dataset_params['n_classes'],
    aug_params          = None,
    deformation         = None,
    load_original       = True,
    bg_suppression_opts = None,
)

# Reference labels at the processed resolution (train split).
(ds, ) = get_datasets(**dataset_kwargs)
# Same volumes, rescaled by 0.5 and 0.25 on the last two axes.
(ds_half, ) = get_datasets(**dataset_kwargs, rescale_factor=[1, 0.5, 0.5])
(ds_quarter, ) = get_datasets(**dataset_kwargs, rescale_factor=[1, 0.25, 0.25])
    




In [7]:
# Dice of the full-resolution labels vs. the ds_half labels upsampled back to full size.
compare_downsampled(ds=ds, ds_downsampled=ds_half);

Volume 0, Dice score: 0.8978961706161499



Volume 1, Dice score: 0.884945273399353



Volume 2, Dice score: 0.8913753628730774



Volume 3, Dice score: 0.8939829468727112



Volume 4, Dice score: 0.9011459946632385



Volume 5, Dice score: 0.9005327820777893



Volume 6, Dice score: 0.8904916048049927



Volume 7, Dice score: 0.9025124907493591



Volume 8, Dice score: 0.8830072283744812



Volume 9, Dice score: 0.9017219543457031



Mean Dice score: 0.8947611808776855


In [8]:
# Dice of the full-resolution labels vs. the ds_quarter labels upsampled back to full size.
compare_downsampled(ds=ds, ds_downsampled=ds_quarter);

Volume 0, Dice score: 0.7404664754867554



Volume 1, Dice score: 0.7156409025192261



Volume 2, Dice score: 0.732082724571228



Volume 3, Dice score: 0.7218263149261475



Volume 4, Dice score: 0.7403273582458496



Volume 5, Dice score: 0.7417212724685669



Volume 6, Dice score: 0.7193854451179504



Volume 7, Dice score: 0.744893491268158



Volume 8, Dice score: 0.7107737064361572



Volume 9, Dice score: 0.7440729141235352



Mean Dice score: 0.7311190605163574


# WMH dataset (lesion labels only)

In [9]:
# Load the dataset config and the same split at three resolutions: the processed
# resolution, plus two versions rescaled by `rescale_factor` (presumably a per-axis
# scale applied to the last two spatial axes -- TODO confirm in get_datasets).
dataset_name        = 'umc'
split               = 'train'
image_size          = (48, 256, 256)
# NOTE(review): device and model_params are not used below; kept in case other cells rely on them.
device              = 'cuda' if torch.cuda.is_available() else 'cpu'

# NOTE(review): absolute cluster path; consider making this configurable.
config_dir          = '/scratch_net/biwidl319/jbermeo/MastersThesisUIASegmentation/config'
model_params        = load_config(os.path.join(config_dir, 'models.yaml'))
dataset_params      = load_config(os.path.join(config_dir, 'datasets.yaml'))
dataset_params      = dataset_params[dataset_name]

# Arguments shared by every get_datasets call below; only rescale_factor differs.
dataset_kwargs = dict(
    splits              = [split],
    paths               = dataset_params['paths_processed'],
    paths_original      = dataset_params['paths_original'],
    image_size          = image_size,
    resolution_proc     = dataset_params['resolution_proc'],
    dim_proc            = dataset_params['dim'],
    n_classes           = dataset_params['n_classes'],
    aug_params          = None,
    deformation         = None,
    load_original       = True,
    bg_suppression_opts = None,
)

# Reference labels at the processed resolution (train split).
(ds, ) = get_datasets(**dataset_kwargs)
# Same volumes, rescaled by 0.5 and 0.25 on the last two axes.
(ds_half, ) = get_datasets(**dataset_kwargs, rescale_factor=[1, 0.5, 0.5])
(ds_quarter, ) = get_datasets(**dataset_kwargs, rescale_factor=[1, 0.25, 0.25])
    




In [10]:
# Dice of the full-resolution labels vs. the ds_half labels upsampled back to full size.
compare_downsampled(ds=ds, ds_downsampled=ds_half);

Volume 0, Dice score: 0.8132330775260925



Volume 1, Dice score: 0.6258992552757263



Volume 2, Dice score: 0.6776785850524902



Volume 3, Dice score: 0.8073963522911072



Volume 4, Dice score: 0.8283761143684387



Volume 5, Dice score: 0.8060324192047119



Volume 6, Dice score: 0.64371258020401



Volume 7, Dice score: 0.8435158729553223



Volume 8, Dice score: 0.5855379104614258



Volume 9, Dice score: 0.8724594116210938



Mean Dice score: 0.7503841578960418


In [11]:
# Dice of the full-resolution labels vs. the ds_quarter labels upsampled back to full size.
compare_downsampled(ds=ds, ds_downsampled=ds_quarter);

Volume 0, Dice score: 0.5445148348808289



Volume 1, Dice score: 0.25360578298568726



Volume 2, Dice score: 0.37170475721359253



Volume 3, Dice score: 0.5474470853805542



Volume 4, Dice score: 0.5835620164871216



Volume 5, Dice score: 0.5357488989830017



Volume 6, Dice score: 0.27941176295280457



Volume 7, Dice score: 0.6267862915992737



Volume 8, Dice score: 0.20840950310230255



Volume 9, Dice score: 0.6717285513877869



Mean Dice score: 0.4622919484972954
