In [13]:
import os
import glob
import monai
from PIL import Image
import torch
import nibabel as nib
import matplotlib.pyplot as plt
from tqdm import tqdm
import numpy as np
from monai.apps import CrossValidation
from abc import ABC, abstractmethod
import sys


In [14]:
data_path = "./preprocessed/"
# isdir (rather than exists) so a path that points to a *file* gets the friendly
# message below instead of a NotADirectoryError raised by os.listdir.
if not os.path.isdir(data_path):
    print("Please update your data path to an existing folder.")
elif not {"training", "testing"}.issubset(os.listdir(data_path)):
    print("Please update your data path to the correct folder (should contain training and testing folders).")
else:
    print("Congrats! You selected the correct folder :)")

Congrats! You selected the correct folder :)


In [15]:
class CVDataset(ABC, monai.data.CacheDataset):
    """
    Abstract CacheDataset meant to be passed as ``dataset_cls`` to ``CrossValidation``.

    The constructor first hands the full data list to ``_split_datalist`` and only
    the list it returns is given to ``CacheDataset`` (and therefore cached).
    Subclasses must implement ``_split_datalist``.
    """

    def __init__(
        self,
        data,
        transform,
        cache_num=sys.maxsize,
        cache_rate=1.0,
        num_workers=4,
    ) -> None:
        # Narrow the data list to this dataset's split before anything is cached.
        subset = self._split_datalist(datalist=data)
        cache_options = dict(cache_num=cache_num, cache_rate=cache_rate, num_workers=num_workers)
        monai.data.CacheDataset.__init__(self, subset, transform, **cache_options)

    @abstractmethod
    def _split_datalist(self, datalist):
        """Subclass hook: return the part of ``datalist`` that this dataset should hold."""
        raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.")

In [16]:
# print(img.header.structarr['pixdim'])

def build_dict_ACDC(data_path, modes='training', heart_mode='Off'):
    """
    Build the list of image/mask path pairs for the preprocessed ACDC PNG slices.

    For every patient folder ``<data_path>/<mode>/patientXXX`` all ``*.png`` files
    that are not masks are collected (sorted, so the order is reproducible) and each
    one is paired with the mask stored next to it as ``<image name>_gt.png``.
    Images without such a mask are skipped.

    Args:
        data_path (str): path to the root folder of the data set.
        modes (str): subset used. Must be 'training', 'val', 'testing' or 'all'
            ('all' = 'training' followed by 'testing'). 'testing' scans patients
            101-150; every other subset scans patients 1-100.
        heart_mode (str): 'Off' returns a single list. Any other value splits the
            pairs into two lists: images whose *file name* contains 'ED', and all others.

    Returns:
        List[Dict[str, str]] of {'img': path, 'mask': path} when heart_mode == 'Off',
        otherwise a tuple (ed_dicts, other_dicts) of such lists.

    Raises:
        ValueError: if ``modes`` is not one of the accepted values.
    """
    if modes not in ["training", "val", "testing", "all"]:
        raise ValueError(f"Please choose a mode in ['training', 'val', 'testing', 'all']. Current mode is {modes}.")

    subsets = ['training', 'testing'] if modes == 'all' else [modes]

    dicts = []   # every pair (heart_mode 'Off') or only the 'ED' pairs
    dicts2 = []  # the non-'ED' pairs; filled only when heart_mode != 'Off'

    for mode in subsets:
        # Patient range is chosen per subset, so it never depends on the loop order.
        i_begin, i_end = (101, 151) if mode == 'testing' else (1, 101)

        for i in tqdm(range(i_begin, i_end)):
            patient_dir = os.path.join(data_path, mode, 'patient{:03}'.format(i))
            # A glob character class such as '*[!gt].png' would drop ANY file whose
            # last character is 'g' or 't', so masks are excluded explicitly instead.
            # sorted(): glob order is filesystem dependent, and the order of this list
            # determines the (seeded) cross-validation fold assignment downstream.
            mri_paths = sorted(
                path for path in glob.glob(os.path.join(patient_dir, '*.png'))
                if not path.endswith('_gt.png')
            )
            for mri_path in mri_paths:
                mask_path = os.path.splitext(mri_path)[0] + '_gt.png'
                if not os.path.exists(mask_path):
                    continue
                pair = {'img': mri_path, 'mask': mask_path}
                # Only the file name decides the split; directory names (e.g. in
                # data_path) must not be able to contain the 'ED' marker.
                if heart_mode == 'Off' or 'ED' in os.path.basename(mri_path):
                    dicts.append(pair)
                else:
                    dicts2.append(pair)

    if heart_mode == 'Off':
        return dicts
    return dicts, dicts2

                    
# Every training image/mask pair (patients 1-100).
All_data = build_dict_ACDC(data_path, modes="training")
# With heart_mode enabled the FIRST returned list holds the slices whose file name
# contains 'ED' and the second holds all others (presumably ES):
# ED, ES = build_dict_ACDC(data_path, heart_mode="On")
print(len(All_data))

100%|██████████| 100/100 [00:06<00:00, 14.71it/s]

1902





In [18]:
class LoadMriData(monai.transforms.Transform):
    """
    Custom MONAI transform that loads one preprocessed MRI slice and its label mask.

    Both files are read as 8-bit grayscale images. The mask's grey levels are
    remapped to class indices: 255 -> 1, 85 -> 2, 170 -> 3, any other value -> 0.
    Defining a custom transform only requires overriding __init__ and __call__.
    """

    # Grey level -> class index lookup table; every unlisted grey level maps to 0.
    # NOTE(review): confirm this grey-level/class pairing matches the preprocessing step.
    _LABEL_LUT = np.zeros(256, dtype=np.uint8)
    _LABEL_LUT[[255, 85, 170]] = [1, 2, 3]

    def __init__(self, keys=None):
        # `keys` is accepted (like dictionary transforms) but not used.
        pass

    @staticmethod
    def _load_gray(path):
        """Read an image file and return it as an 8-bit grayscale numpy array."""
        with Image.open(path) as im:
            return np.array(im.convert('L'), dtype=np.uint8)

    def __call__(self, sample):
        """
        Args:
            sample (dict): must contain the file paths under 'img' and 'mask'.

        Returns:
            dict with 'img' (uint8 array), 'mask' (uint8 class-index array) and
            'img_meta_dict' / 'mask_meta_dict' each holding an identity affine.
        """
        image = self._load_gray(sample['img'])
        # Indexing the LUT with the uint8 mask remaps every pixel in one vectorised step.
        mask = self._LABEL_LUT[self._load_gray(sample['mask'])]
        return {'img': image, 'mask': mask, 'img_meta_dict': {'affine': np.eye(2)},
                'mask_meta_dict': {'affine': np.eye(2)}}

In [22]:
train_dict_list = build_dict_ACDC(data_path)

# Training: load, add a channel axis, scale image intensities to [0, 1], then a random
# rotation (bilinear for the image, nearest for the mask so class indices stay integral).
composed_transform = monai.transforms.Compose([
    LoadMriData(),
    monai.transforms.AddChanneld(keys=['img', 'mask']),
    monai.transforms.ScaleIntensityd(keys=['img'], minv=0, maxv=1),
    monai.transforms.RandRotated(keys=['img', 'mask'], range_x=np.pi, prob=6/7,
                                 mode=['bilinear', 'nearest'], padding_mode=['zeros', 'zeros']),
])

# Validation: the same deterministic preprocessing as training (channel axis and
# intensity scaling) but no random augmentation, so validation inputs have the shape
# and value range the network is trained on.
val_transform = monai.transforms.Compose([
    LoadMriData(),
    monai.transforms.AddChanneld(keys=['img', 'mask']),
    monai.transforms.ScaleIntensityd(keys=['img'], minv=0, maxv=1),
])

num = 5  # number of cross-validation folds
folds = list(range(num))

cvdataset = CrossValidation(
    dataset_cls=CVDataset,
    data=train_dict_list,
    nfolds=num,
    seed=12345,
    transform=composed_transform,
)

# Fold i is held out for validation; all remaining folds form its training set.
train_dss = [cvdataset.get_dataset(folds=folds[0:i] + folds[(i + 1):]) for i in folds]
val_dss = [cvdataset.get_dataset(folds=i, transform=val_transform) for i in folds]

train_loaders = [monai.data.DataLoader(train_dss[i], batch_size=2, shuffle=True, num_workers=4) for i in folds]
val_loaders = [monai.data.DataLoader(val_dss[i], batch_size=1, num_workers=4) for i in folds]

print("Finished")

100%|██████████| 100/100 [00:00<00:00, 1950.39it/s]
Loading dataset: 100%|██████████| 1521/1521 [00:06<00:00, 219.12it/s]
Loading dataset: 100%|██████████| 1521/1521 [00:07<00:00, 205.15it/s]
Loading dataset: 100%|██████████| 1522/1522 [00:07<00:00, 207.48it/s]
Loading dataset: 100%|██████████| 1522/1522 [00:07<00:00, 207.18it/s]
Loading dataset: 100%|██████████| 1522/1522 [00:07<00:00, 205.93it/s]
Loading dataset: 100%|██████████| 381/381 [00:00<00:00, 561.95it/s]
Loading dataset: 100%|██████████| 381/381 [00:00<00:00, 581.20it/s]
Loading dataset: 100%|██████████| 380/380 [00:00<00:00, 891.87it/s]
Loading dataset: 100%|██████████| 380/380 [00:00<00:00, 687.46it/s]
Loading dataset: 100%|██████████| 380/380 [00:00<00:00, 687.92it/s]


Finished


In [25]:
# Sanity check: list the training DataLoaders (one per cross-validation fold).
print(str(train_loaders))

[<monai.data.dataloader.DataLoader object at 0x7f8009686a00>, <monai.data.dataloader.DataLoader object at 0x7f80038b1fd0>, <monai.data.dataloader.DataLoader object at 0x7f8000d007f0>, <monai.data.dataloader.DataLoader object at 0x7f8000cf5a90>, <monai.data.dataloader.DataLoader object at 0x7f8000cf5310>]


In [26]:
# Raw ACDC NIfTI database (not the preprocessed PNGs). Set the ACDC_RAW_PATH environment
# variable to point elsewhere instead of editing an absolute, machine-specific path.
acdc_raw_path = os.environ.get("ACDC_RAW_PATH", "/home/jovyan/Desktop/Deep-Learning-ACDC-challenge/database")
image = nib.load(os.path.join(acdc_raw_path, "training", "patient100", "patient100_frame01_gt.nii.gz"))

In [31]:
# Physical volume of one voxel. get_zooms() returns one value per image dimension, so a
# 4D (cine) volume has an extra non-spatial zoom; keep only the three spatial ones.
# NOTE(review): assumes the header's spatial unit is mm (check header.get_xyzt_units()).
sx, sy, sz = image.header.get_zooms()[:3]
volume_vox = sx * sy * sz
print(str(volume_vox) + ' mm^3')

17.313034 mm^3
