In [1]:
import os
import sys

import torch
import torch.nn.functional as F
sys.path.append(os.path.join('..', '..', 'tta_uia_segmentation', 'src'))

from dataset.dataset_in_memory import get_datasets
from tta_uia_segmentation.src.utils.io import load_config
from tta_uia_segmentation.src.utils.loss import dice_score
from tta_uia_segmentation.src.dataset.utils import onehot_to_class, class_to_onehot


def compare_downsampled(ds, ds_downsampled):
    dices_fg = []
    for vol_i in range(len(ds)):
        _, label_gt, *_ = ds[vol_i]
        
        _, label_downsampled, *_ = ds_downsampled[vol_i]
        
        #label_downsampled = onehot_to_class(label_downsampled.permute(1, 0, 2, 3)).permute(1, 0, 2, 3)
        label_downsampled = F.interpolate(label_downsampled.unsqueeze(0).float(),
                                          size=label_gt.shape[-3:], mode='nearest').round().int()
        
        _, dice_fg = dice_score(label_gt.unsqueeze(0), label_downsampled.unsqueeze(0),
                                reduction='mean', soft=False)
        dice_fg = dice_fg.item()
        dices_fg.append(dice_fg) 
        print(f'Volume {vol_i}, Dice score: {dices_fg}')
        print('\n\n')
    
    print(f'Mean Dice score: {sum(dices_fg) / len(dices_fg)}')

# HCP dataset

In [2]:
# Load Normalizer and Segmentation model
dataset_name        = 'hcp_t1'
split               = 'train'
device              = 'cuda' if torch.cuda.is_available() else 'cpu' 

model_params        = load_config('/scratch_net/biwidl319/jbermeo/MastersThesisUIASegmentation/config/models.yaml')
dataset_params      = load_config('/scratch_net/biwidl319/jbermeo/MastersThesisUIASegmentation/config/datasets.yaml')
dataset_params      = dataset_params[dataset_name]


# Load dataset with original preprocessed images (val)
(ds, )  = get_datasets(
        splits          = [split],
        paths           = dataset_params['paths_processed'],
        paths_original  = dataset_params['paths_original'], 
        image_size      = (256, 256, 256),
        resolution_proc = dataset_params['resolution_proc'],
        dim_proc        = dataset_params['dim'],
        n_classes       = dataset_params['n_classes'],
        aug_params      = None,
        deformation     = None,
        load_original   = True,
        bg_suppression_opts = None
    )

(ds_down, )  = get_datasets(
        splits          = [split],
        paths           = dataset_params['paths_processed'],
        paths_original  = dataset_params['paths_original'], 
        image_size      = (256, 256, 256),
        resolution_proc = dataset_params['resolution_proc'],
        dim_proc        = dataset_params['dim'],
        n_classes       = dataset_params['n_classes'],
        rescale_factor  = [1, 0.5, 0.5],
        aug_params      = None,
        deformation     = None,
        load_original   = True,
        bg_suppression_opts = None
   )


  return torch._C._cuda_getDeviceCount() > 0




In [3]:
compare_downsampled(ds, ds_down)

torch.Size([15, 256, 128, 128])
torch.Size([1, 15, 256, 256, 256])
Volume 0, Dice score: [0.9180825352668762]



torch.Size([15, 256, 128, 128])
torch.Size([1, 15, 256, 256, 256])
Volume 1, Dice score: [0.9180825352668762, 0.9150625467300415]



torch.Size([15, 256, 128, 128])
torch.Size([1, 15, 256, 256, 256])
Volume 2, Dice score: [0.9180825352668762, 0.9150625467300415, 0.9144259095191956]



torch.Size([15, 256, 128, 128])
torch.Size([1, 15, 256, 256, 256])
Volume 3, Dice score: [0.9180825352668762, 0.9150625467300415, 0.9144259095191956, 0.9169603586196899]



torch.Size([15, 256, 128, 128])
torch.Size([1, 15, 256, 256, 256])
Volume 4, Dice score: [0.9180825352668762, 0.9150625467300415, 0.9144259095191956, 0.9169603586196899, 0.9111089706420898]



torch.Size([15, 256, 128, 128])
torch.Size([1, 15, 256, 256, 256])
Volume 5, Dice score: [0.9180825352668762, 0.9150625467300415, 0.9144259095191956, 0.9169603586196899, 0.9111089706420898, 0.9201210737228394]



torch.Size([15, 256, 1

In [4]:
(ds_down, )  = get_datasets(
        splits          = [split],
        paths           = dataset_params['paths_processed'],
        paths_original  = dataset_params['paths_original'], 
        image_size      = (256, 256, 256),
        resolution_proc = dataset_params['resolution_proc'],
        dim_proc        = dataset_params['dim'],
        n_classes       = dataset_params['n_classes'],
        rescale_factor  = [1, 0.25, 0.25],
        aug_params      = None,
        deformation     = None,
        load_original   = True,
        bg_suppression_opts = None
    )


compare_downsampled(ds, ds_down)

torch.Size([15, 256, 64, 64])
torch.Size([1, 15, 256, 256, 256])
Volume 0, Dice score: [0.789800226688385]



torch.Size([15, 256, 64, 64])
torch.Size([1, 15, 256, 256, 256])
Volume 1, Dice score: [0.789800226688385, 0.7796684503555298]



torch.Size([15, 256, 64, 64])
torch.Size([1, 15, 256, 256, 256])
Volume 2, Dice score: [0.789800226688385, 0.7796684503555298, 0.781373918056488]



torch.Size([15, 256, 64, 64])
torch.Size([1, 15, 256, 256, 256])
Volume 3, Dice score: [0.789800226688385, 0.7796684503555298, 0.781373918056488, 0.7742565274238586]



torch.Size([15, 256, 64, 64])
torch.Size([1, 15, 256, 256, 256])
Volume 4, Dice score: [0.789800226688385, 0.7796684503555298, 0.781373918056488, 0.7742565274238586, 0.7718745470046997]



torch.Size([15, 256, 64, 64])
torch.Size([1, 15, 256, 256, 256])
Volume 5, Dice score: [0.789800226688385, 0.7796684503555298, 0.781373918056488, 0.7742565274238586, 0.7718745470046997, 0.7832061052322388]



torch.Size([15, 256, 64, 64])
torch.Size([1,

# WMH dataset

In [None]:
# Load Normalizer and Segmentation model
dataset_name        = 'umc'
split               = 'train'
image_size          = (48, 256, 256)
device              = 'cuda' if torch.cuda.is_available() else 'cpu' 

model_params        = load_config('/scratch_net/biwidl319/jbermeo/MastersThesisUIASegmentation/config/models.yaml')
dataset_params      = load_config('/scratch_net/biwidl319/jbermeo/MastersThesisUIASegmentation/config/datasets.yaml')
dataset_params      = dataset_params[dataset_name]


# Load dataset with original preprocessed images (val)
(ds, )  = get_datasets(
        splits          = [split],
        paths           = dataset_params['paths_processed'],
        paths_original  = dataset_params['paths_original'], 
        image_size      = image_size,
        resolution_proc = dataset_params['resolution_proc'],
        dim_proc        = dataset_params['dim'],
        n_classes       = dataset_params['n_classes'],
        aug_params      = None,
        deformation     = None,
        load_original   = True,
        bg_suppression_opts = None
    )

(ds_half, )  = get_datasets(
        splits          = [split],
        paths           = dataset_params['paths_processed'],
        paths_original  = dataset_params['paths_original'], 
        image_size      = image_size,
        resolution_proc = dataset_params['resolution_proc'],
        dim_proc        = dataset_params['dim'],
        n_classes       = dataset_params['n_classes'],
        rescale_factor  = [1, 0.5, 0.5],
        aug_params      = None,
        deformation     = None,
        load_original   = True,
        bg_suppression_opts = None
    )

(ds_quarter, )  = get_datasets(
        splits          = [split],
        paths           = dataset_params['paths_processed'],
        paths_original  = dataset_params['paths_original'], 
        image_size      = image_size,
        resolution_proc = dataset_params['resolution_proc'],
        dim_proc        = dataset_params['dim'],
        n_classes       = dataset_params['n_classes'],
        rescale_factor  = [1, 0.25, 0.25],
        aug_params      = None,
        deformation     = None,
        load_original   = True,
        bg_suppression_opts = None
    )
    


In [None]:
compare_downsampled(ds, ds_half)

torch.Size([15, 256, 128, 128])
torch.Size([1, 15, 256, 256, 256])
0.9166433215141296
Volume 0, Dice score: [0.9166433215141296]



torch.Size([15, 256, 128, 128])
torch.Size([1, 15, 256, 256, 256])
0.9136461019515991
Volume 1, Dice score: [0.9166433215141296, 0.9136461019515991]



torch.Size([15, 256, 128, 128])
torch.Size([1, 15, 256, 256, 256])
0.9127492904663086
Volume 2, Dice score: [0.9166433215141296, 0.9136461019515991, 0.9127492904663086]



torch.Size([15, 256, 128, 128])
torch.Size([1, 15, 256, 256, 256])
0.9142723083496094
Volume 3, Dice score: [0.9166433215141296, 0.9136461019515991, 0.9127492904663086, 0.9142723083496094]



torch.Size([15, 256, 128, 128])
torch.Size([1, 15, 256, 256, 256])
0.9141612648963928
Volume 4, Dice score: [0.9166433215141296, 0.9136461019515991, 0.9127492904663086, 0.9142723083496094, 0.9141612648963928]



Mean Dice score: 0.9142944574356079


In [None]:
compare_downsampled(ds, ds_quarter)

torch.Size([15, 256, 64, 64])
torch.Size([1, 15, 256, 256, 256])
0.7871463894844055
Volume 0, Dice score: [0.7871463894844055]



torch.Size([15, 256, 64, 64])
torch.Size([1, 15, 256, 256, 256])
0.7765485048294067
Volume 1, Dice score: [0.7871463894844055, 0.7765485048294067]



torch.Size([15, 256, 64, 64])
torch.Size([1, 15, 256, 256, 256])
0.7676004767417908
Volume 2, Dice score: [0.7871463894844055, 0.7765485048294067, 0.7676004767417908]



torch.Size([15, 256, 64, 64])
torch.Size([1, 15, 256, 256, 256])
0.77845299243927
Volume 3, Dice score: [0.7871463894844055, 0.7765485048294067, 0.7676004767417908, 0.77845299243927]



torch.Size([15, 256, 64, 64])
torch.Size([1, 15, 256, 256, 256])
0.7739769816398621
Volume 4, Dice score: [0.7871463894844055, 0.7765485048294067, 0.7676004767417908, 0.77845299243927, 0.7739769816398621]



Mean Dice score: 0.7767450690269471
