In [1]:
import os
import numpy as np
import matplotlib.pyplot as plt
import glob
import monai
from PIL import Image
import torch
from monai.visualize import blend_images, matshow3d, plot_2d_or_3d_image
from tqdm.notebook import tqdm

In [9]:
# Each entry under DATA_ROOT is treated as one patient folder. Inside it, a frame
# image "<stem>.nii.gz" is paired with its label volume "<stem>_gt.nii.gz".
DATA_ROOT = "data/training"
NII_SUFFIX = ".nii.gz"


def _frame_stems(folder):
    """Return sorted ``folder + stem`` paths for the frame images in ``folder``.

    A file qualifies when its name contains "frame", does not contain "gt",
    and ends in ".nii.gz". That suffix is stripped so LoadNIFTI can append
    ".nii.gz" / "_gt.nii.gz" itself. ``folder`` must end with "/".
    """
    return [
        folder + name[: -len(NII_SUFFIX)]
        for name in sorted(os.listdir(folder))
        if "frame" in name and "gt" not in name and name.endswith(NII_SUFFIX)
    ]


# Sorted so the record order is identical on every filesystem / re-run.
# Non-directories (stray metadata files) are skipped instead of crashing listdir.
patient_folders = [
    DATA_ROOT + "/" + name + "/"
    for name in sorted(os.listdir(DATA_ROOT))
    if os.path.isdir(os.path.join(DATA_ROOT, name))
]
# NOTE(review): train and test currently use the *same* patients, so any "test"
# result is measured on training data. Split patient_folders before evaluating.
train_patients = patient_folders
test_patients = patient_folders

patient_files = [_frame_stems(folder) for folder in train_patients]
patient_files_flattened = [stem for stems in patient_files for stem in stems]

test_patient_files = [_frame_stems(folder) for folder in test_patients]
test_patient_files_flattened = [stem for stems in test_patient_files for stem in stems]

# One MONAI-style record per frame; LoadNIFTI reads the 'img' stem.
images = [{'img': x} for x in patient_files_flattened]
test_images = [{'img': x} for x in test_patient_files_flattened]

In [10]:
from evaluate import load_nii
class LoadNIFTI(monai.transforms.Transform):
    """Load a NIfTI image volume together with its ``_gt`` label volume.

    ``sample['img']`` is a path stem. The image is read from ``<stem>.nii.gz``
    and the mask from ``<stem>_gt.nii.gz``, both via ``load_nii``. Each
    array has its third axis moved to the front.

    Returns a new dict with these keys:

    * ``'img'``: the image array.
    * ``'mask'``: the mask array.
    * ``'name'``: the original input dict.
    * ``'scaling'``: the image header's ``pixdim`` entry.
    """

    def __init__(self, keys=None):
        # ``keys`` is accepted for MONAI-style construction but is not used.
        pass

    def __call__(self, sample):
        stem = sample['img']
        volume, _, header = load_nii(stem + ".nii.gz")
        labels, _, _ = load_nii(stem + "_gt.nii.gz")
        return {
            'img': np.moveaxis(volume, 2, 0),
            'mask': np.moveaxis(labels, 2, 0),
            'name': sample,
            'scaling': header['pixdim'],
        }
    
    
class SplitMask(monai.transforms.Transform):
    """Turn a single-channel, intensity-scaled label mask into three binary channels.

    Reads channel 0 of ``sample['mask']`` and thresholds it into three float
    (1.0 / 0.0) channels:

    * values strictly between 0.2 and 0.5;
    * values strictly between 0.5 and 0.8;
    * values above 0.8.

    Voxels at or below 0.2, or exactly equal to 0.5 or 0.8, fall in no
    channel. The input dict is updated in place and returned.

    NOTE(review): these thresholds assume the integer labels were min-max
    scaled to [0, 1] by the preceding ScaleIntensityd step (e.g. labels 0..3
    become 0, 1/3, 2/3, 1). If a volume lacks its largest label, that scaling
    shifts the other labels into the wrong channel, and a label scaled to
    exactly 0.5 is dropped entirely. Confirm every mask contains the full label
    set, or split on the raw labels instead.
    """

    def __init__(self, keys=None):
        # ``keys`` is accepted for MONAI-style construction but is not used.
        pass

    def __call__(self, sample):
        labels = sample['mask'][0]

        def band(lower, upper):
            # 1.0 where lower < value < upper (both bounds exclusive), else 0.0.
            return np.where(np.logical_and(labels > lower, labels < upper), 1.0, 0.0)

        channels = [
            band(0.2, 0.5),
            band(0.5, 0.8),
            np.where(labels > 0.8, 1.0, 0.0),
        ]
        sample['mask'] = np.array(channels)
        return sample
    
class ScaleDims(monai.transforms.Transform):
    """Resample ``'img'`` and ``'mask'`` to a fixed voxel spacing.

    The voxel sizes come from ``sample['scaling']``, the ``pixdim`` array
    stored by LoadNIFTI. Each axis is zoomed by ``current / target`` spacing,
    so a factor above 1 upsamples.

    The spacings are read as (pixdim[3], pixdim[1], pixdim[2]) because
    LoadNIFTI moved the third volume axis to the front. In NIfTI, pixdim[1..3]
    are the voxel sizes of array axes 0..2.

    Images are interpolated with ``'area'`` and masks with ``'nearest'`` so
    labels stay binary. ``keep_size=False`` lets the output shape change with
    the spacing.

    Args:
        keys: accepted for MONAI-style construction but not used. The
            transform always acts on ``'img'`` and ``'mask'``.
        target_spacing: desired spacing for the (slice, row, column) axes, in
            the same units as ``pixdim``. Defaults to ``(10.0, 1.5, 1.5)``,
            the values that were previously hard-coded.

    Raises:
        ValueError: if ``target_spacing`` is not three positive numbers.
    """

    def __init__(self, keys=None, target_spacing=(10.0, 1.5, 1.5)):
        target_spacing = tuple(float(s) for s in target_spacing)
        if len(target_spacing) != 3 or any(s <= 0 for s in target_spacing):
            raise ValueError(f"target_spacing must be three positive numbers, got {target_spacing}")
        self.target_spacing = target_spacing

    def __call__(self, sample):
        scaling = sample['scaling']
        # Reorder to match the (slice, row, column) layout produced by LoadNIFTI.
        current_spacing = (scaling[3], scaling[1], scaling[2])
        zoom = tuple(cur / tgt for cur, tgt in zip(current_spacing, self.target_spacing))
        return monai.transforms.Zoomd(
            keys=['img', 'mask'],
            mode=['area', 'nearest'],
            zoom=zoom,
            keep_size=False,
        )(sample)


In [11]:
# Define transforms for loading the dataset.
# Train and test pipelines share the same deterministic preprocessing. The
# training pipeline additionally applies random augmentation.
SEED = 0

# Seeds torch / numpy / random. MONAI's DataLoader derives the transforms'
# random state from torch's generator, so this is intended to make both the
# augmentation and the shuffling reproducible across Restart & Run All.
# Side effect: it also sets cuDNN to deterministic (slower) mode.
monai.utils.set_determinism(seed=SEED)


def _preprocessing_steps():
    """Return fresh instances of the preprocessing shared by both pipelines.

    The steps run in this order:

    1. Load an image/mask pair.
    2. Add a channel axis.
    3. Scale both arrays to [0, 1].
    4. Split the mask into three binary channels.
    5. Resample to the fixed spacing used by ScaleDims.
    """
    return [
        LoadNIFTI(),
        monai.transforms.AddChanneld(keys=['img', 'mask']),
        # SplitMask's thresholds rely on the mask being min-max scaled here too.
        monai.transforms.ScaleIntensityd(keys=['img', 'mask'], minv=0.0, maxv=1.0),
        SplitMask(),
        ScaleDims(),
    ]


# Output shapes vary per volume (ScaleDims keeps keep_size=False), so using a
# batch size above 1 would need an extra pad/crop step.
compose_transform = monai.transforms.Compose(
    _preprocessing_steps()
    + [
        monai.transforms.RandRotated(keys=['img', 'mask'], range_x=np.pi / 4, prob=1, mode=['bilinear', 'nearest']),
        monai.transforms.RandZoomd(keys=['img', 'mask'], prob=0.5, mode=['area', 'nearest']),
        monai.transforms.RandGridDistortiond(keys=['img', 'mask'], mode=['bilinear', 'nearest']),
        monai.transforms.RandFlipd(keys=['img', 'mask'], prob=0.5, spatial_axis=1),
        monai.transforms.ScaleIntensityd(keys=['mask'], minv=0.0, maxv=1.0),
    ]
)

test_transform = monai.transforms.Compose(
    _preprocessing_steps()
    + [monai.transforms.ScaleIntensityd(keys=['mask'], minv=0.0, maxv=1.0)]
)


In [14]:
# The CacheDataset caches the output of the transforms that come before the
# first random one. The random augmentation is re-applied on every access.
train_dict_list = list(images)
dataset = monai.data.CacheDataset(train_dict_list, transform=compose_transform)
test_dict_list = list(test_images)
# TODO: build the test CacheDataset / DataLoader from test_dict_list with
# test_transform once evaluation is wired up.

data_loader = monai.data.DataLoader(dataset, batch_size=1, shuffle=True)


Loading dataset: 100%|██████████| 200/200 [00:01<00:00, 105.03it/s]


In [None]:
for d in 