In [2]:
# - standard packages
import os, sys

# - third party packages
import numpy as np
import torch
from torch.utils.data import Dataset
from torchvision.transforms import (
    Resize,
    CenterCrop,
    Normalize,
    functional,
)
from sklearn.preprocessing import MinMaxScaler
import nibabel as nib
from batchgenerators.utilities.file_and_folder_operations import *
from nnunet.paths import preprocessing_output_dir
from nnunet.training.dataloading.dataset_loading import *

# - local source
sys.path.append('..')
from augment import RandAugmentWithLabels, _apply_op


In [98]:
class MNMDataset(Dataset):
    """Slices of the M&Ms challenge data for one vendor, held in memory.

    Cases are read from the nnUNet preprocessed folder of
    ``Task679_heart_mnms``. Every case is brought to a fixed in-plane size,
    optionally filtered slice-wise by its label content, and all kept slices
    are concatenated into ``self.input`` and ``self.target``.

    Args:
        vendor: Vendor identifier. A case is used when ``f"_{vendor}"``
            occurs in its dataset key (e.g. ``'A'``).
        debug: If True, only the first 50 slices are kept.
        selection: Slice filter, one of
            ``'all'`` (keep every slice),
            ``'non_empty_target'`` (label map has a non-zero sum),
            ``'all_classes'`` (at least 4 distinct label values),
            ``'single_class_missing'`` (exactly 3 distinct label values).
        adapt_size: ``"crop"`` center-crops to 256 x 256, ``"resize"``
            resizes to 256 x 224.

    Raises:
        ValueError: If ``selection`` or ``adapt_size`` is not a known option.
        RuntimeError: If no case matches ``vendor``.
    """
    # 1: Siemens
    # 2: Philips
    # 3: GE
    # 4: Canon
    # NOTE(review): vendors are passed as letters (e.g. 'A'); presumably A-D
    # correspond to 1-4 above -- confirm against the challenge documentation.

    _SELECTIONS = ("all", "non_empty_target", "all_classes", "single_class_missing")
    _ADAPT_MODES = ("crop", "resize")

    def __init__(self, vendor: str, debug: bool = False, selection: str = 'all', adapt_size: str = "crop"):
        # Validate before any (slow) disk access so typos fail immediately.
        if selection not in self._SELECTIONS:
            raise ValueError(
                f"Unknown selection {selection!r}; expected one of {self._SELECTIONS}"
            )
        if adapt_size not in self._ADAPT_MODES:
            raise ValueError(
                f"Unknown adapt_size {adapt_size!r}; expected one of {self._ADAPT_MODES}"
            )

        self.vendor = vendor
        self.debug = debug
        self.selection = selection
        # Exactly one of the two transforms is set; the other stays None so
        # that the loader can tell which size adaptation was requested.
        self.crop = None
        self.resize = None
        if adapt_size == "crop":
            self.crop = CenterCrop([256, 256])
        else:
            self.resize = Resize((256, 224))
        self._get_dataset_information()
        self._load_selected_cases()

    def _get_dataset_information(self) -> None:
        """Load the nnUNet dataset index and collect the keys of this vendor."""
        # We only want to load M&M challenge data with this class. The calls
        # below come from the nnUNet repo and load the dataset information,
        # which is then used to pick the files of a specific vendor.
        task = "Task679_heart_mnms"  # "old/Task679_mnm" #
        stage_dir = join(preprocessing_output_dir, task, "nnUNetData_plans_v2.1_2D_stage0")
        self.dataset_info = load_dataset(stage_dir)
        unpack_dataset(stage_dir)
        # select keys that are relevant for a specific vendor
        self.vendor_keys = [
            key for key in self.dataset_info.keys() if f"_{self.vendor}" in key
        ]

    def _adapt_spatial_size(self, data: torch.Tensor) -> torch.Tensor:
        """Bring one case of shape (2, slices, H, W) to the configured size.

        Channel 0 is the image, channel 1 the label map.
        """
        if self.crop is not None:
            return self.crop(data)
        # Resizing interpolates intensities, which would create invalid,
        # non-integer labels. Only the image goes through the Resize
        # transform; the label map is resampled with nearest neighbour to the
        # same output size.
        image = self.resize(data[:1])
        target = torch.nn.functional.interpolate(
            data[1:], size=tuple(image.shape[-2:]), mode="nearest"
        )
        return torch.cat([image, target], dim=0)

    def _slice_mask(self, target: torch.Tensor):
        """Return a 1D boolean mask over slices, or None to keep all slices.

        Args:
            target: Label map of one case with shape (slices, H, W).
        """
        if self.selection == "all":
            return None
        if self.selection == "non_empty_target":
            # mask slices with empty targets
            return target.sum((1, 2)) > 0

        # Number of distinct label values per slice (background included).
        n_labels = torch.tensor(
            [len(torch.unique(slc)) for slc in target], dtype=torch.long
        )
        if self.selection == "all_classes":
            # keep only slices that contain all 4 classes
            return n_labels >= 4
        if self.selection == "single_class_missing":
            # keep only slices that contain 3 out of the 4 classes
            return n_labels == 3
        raise ValueError(
            f"Unknown selection {self.selection!r}; expected one of {self._SELECTIONS}"
        )

    def _load_selected_cases(self) -> None:
        """Read all cases of the vendor and build ``data``/``input``/``target``."""
        if not self.vendor_keys:
            raise RuntimeError(f"No cases found for vendor {self.vendor!r}")

        cases = []
        for key in self.vendor_keys:
            # load data as numpy array from nnUNet preprocessed folder; the
            # context manager closes the .npz archive again
            npz_path = self.dataset_info[key]["data_file"][:-4] + ".npz"
            with np.load(npz_path) as archive:
                data_np = archive["data"]
            data = torch.from_numpy(data_np)
            # Crop or resize to the in-plane size configured in __init__.
            data = self._adapt_spatial_size(data)
            # merge background classes -1 and 0
            data[1][data[1] < 0] = 0

            mask = self._slice_mask(data[1])
            if mask is not None:
                assert mask.dim() == 1
                data = data[:, mask]

            cases.append(data)

        # cat list to single tensor of shape (2, slices, 1, H, W)
        self.data = torch.cat(cases, dim=1).unsqueeze(2)

        # for debugging purposes only take the first 50 slices
        if self.debug:
            self.data = self.data[:, :50]
        # split data into input and target (both are views into self.data)
        self.input = self.data[0]
        self.target = self.data[1]
        # swap values of 3 and 1 in target so that it is similar to the ACDC
        # data convention; both masks are taken before writing so that no
        # sentinel value is needed
        is_one = self.target == 1
        is_three = self.target == 3
        self.target[is_one] = 3
        self.target[is_three] = 1

    def __len__(self):
        """Number of slices in the dataset."""
        return self.data.shape[1]

    def __getitem__(self, idx):
        """Return the slice ``idx`` as a dict of input, target and voxel_dim."""
        return {
            "input": self.input[idx],
            "target": self.target[idx],
            # constant value, not derived from the spacing of the case
            "voxel_dim": torch.tensor([1.0, 1.0, 1.0]),
        }

In [102]:
# Vendor 'A' cases, restricted to slices with exactly three distinct label values.
mnm = MNMDataset(
    vendor='A',
    selection='single_class_missing',
    debug=False,
)

loading dataset
loading all case properties


In [103]:
# Shape of the stacked label maps.
mnm.target.size()

torch.Size([292, 1, 256, 256])

In [87]:
# Flatten (at most) the first 100 label maps to one row each and de-duplicate
# the columns. Note that dim=1 yields unique pixel columns across these rows,
# not the set of labels per slice. The row count is taken from the slice
# itself so the cell also works when fewer than 100 slices are loaded
# (e.g. with debug=True).
target_preview = mnm.target[:100]
torch.unique(target_preview.reshape(len(target_preview), -1), dim=1)

tensor([[0., 0., 0.,  ..., 3., 3., 3.],
        [0., 0., 0.,  ..., 3., 3., 3.],
        [0., 0., 0.,  ..., 3., 3., 3.],
        ...,
        [0., 0., 0.,  ..., 2., 2., 2.],
        [0., 0., 0.,  ..., 2., 2., 2.],
        [0., 0., 0.,  ..., 3., 2., 2.]])

In [104]:
# Number of distinct label values in every selected slice.
lengths = [len(torch.unique(slice_target)) for slice_target in mnm.target]

print(lengths)

[3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3]


In [21]:
# For each case and each foreground label 1-3: print the shape of the stored
# class locations, or the case key when the label has no entry.
for case_key in mnm.dataset_info:
    class_locations = mnm.dataset_info[case_key]['properties']['class_locations']
    for label in (1, 2, 3):
        if label in class_locations.keys():
            print(class_locations[label].shape)
        else:
            print(case_key)

(10000, 3)
(9479, 3)
(9771, 3)
(10000, 3)
(8021, 3)
(4979, 3)
(7280, 3)
(9957, 3)
(5759, 3)
(1995, 3)
(9182, 3)
(2220, 3)
(7921, 3)
(10000, 3)
(8959, 3)
(10000, 3)
(10000, 3)
(10000, 3)
(10000, 3)
(8524, 3)
(10000, 3)
(8685, 3)
(7125, 3)
(7990, 3)
(4017, 3)
(6275, 3)
(4256, 3)
(10000, 3)
(5401, 3)
(8203, 3)
(3882, 3)
(5351, 3)
(3962, 3)
(9375, 3)
(4776, 3)
(9173, 3)
(10000, 3)
(7576, 3)
(10000, 3)
(4569, 3)
(7071, 3)
(5457, 3)
(1945, 3)
(8530, 3)
(3296, 3)
(10000, 3)
(7264, 3)
(10000, 3)
(10000, 3)
(8739, 3)
(8639, 3)
(7541, 3)
(7379, 3)
(4474, 3)
(7027, 3)
(4765, 3)
(7671, 3)
(3149, 3)
(3828, 3)
(3782, 3)
(1794, 3)
(3628, 3)
(1468, 3)
(5372, 3)
(3541, 3)
(4566, 3)
(9328, 3)
(6249, 3)
(9715, 3)
(4171, 3)
(5059, 3)
(3751, 3)
(1254, 3)
(4205, 3)
(1223, 3)
(4809, 3)
(3394, 3)
(4584, 3)
(5012, 3)
(6133, 3)
(3741, 3)
(9381, 3)
(5781, 3)
(5216, 3)
(10000, 3)
(7655, 3)
(9644, 3)
(2453, 3)
(9385, 3)
(2328, 3)
(10000, 3)
(10000, 3)
(5801, 3)
(10000, 3)
(10000, 3)
(3686, 3)
(6007, 3)
(5398, 3)
(

In [14]:
# Inspect the dataset entry (file paths and properties) of a single case.
case_key = 'A0S9V9_0000_A_1'
mnm.dataset_info[case_key]

OrderedDict([('data_file',
              '/home/lennartz/repos/nnUNet/data/nnUNet_preprocessed/Task679_heart_mnms/nnUNetData_plans_v2.1_2D_stage0/A0S9V9_0000_A_1.npz'),
             ('properties_file',
              '/home/lennartz/repos/nnUNet/data/nnUNet_preprocessed/Task679_heart_mnms/nnUNetData_plans_v2.1_2D_stage0/A0S9V9_0000_A_1.pkl'),
             ('properties',
              OrderedDict([('original_size_of_raw_data',
                            array([ 13, 256, 216])),
                           ('original_spacing',
                            array([9.52000046, 1.328125  , 1.328125  ])),
                           ('list_of_data_files',
                            ['/home/lennartz/repos/nnUNet/data/nnUNet_raw_data_base/nnUNet_raw_data/Task679_heart_mnms/imagesTr/A0S9V9_0000_A_1_0000.nii.gz']),
                           ('seg_file',
                            '/home/lennartz/repos/nnUNet/data/nnUNet_raw_data_base/nnUNet_raw_data/Task679_heart_mnms/labelsTr/A0S9V9_0000_A_1.nii