# In [None]:
import os

"""
    Get data inputs. Assumes CT volumes and segmentation masks have corresponding names and indices.
    :param str in_dir: file path of data.
"""
def prepare(in_dir="/kaggle/input"):
    """
    Collect paired CT volume / segmentation files and split them into train and test lists.

    Walks ``in_dir`` recursively and considers every ``.nii`` file whose name starts with
    ``volume-`` or ``segmentation-``. The integer between the first ``-`` and the first
    ``.`` is used as the index that pairs a volume with its segmentation. Files whose
    index is not an integer are skipped, and indices present on only one side are dropped.
    If the same index is seen more than once, the file visited last wins.

    :param str in_dir: root directory to search.
    :return: tuple ``(train_files, test_files)``. Each is a list of
        ``{"vol": <path>, "seg": <path>}`` dicts sorted by index; the first 80% of the
        pairs (rounded down) go to train and the remainder to test.
    """
    volume_dict = {}
    segmentation_dict = {}

    for dirname, _, filenames in os.walk(in_dir):
        for filename in filenames:
            if not filename.endswith(".nii"):
                continue

            if filename.startswith("volume-"):
                target = volume_dict
            elif filename.startswith("segmentation-"):
                target = segmentation_dict
            else:
                continue

            try:
                idx = int(filename.split("-")[1].split(".")[0])
            except ValueError:
                # e.g. "volume-notes.nii": no numeric index, so it cannot be paired.
                # Skip it instead of aborting the whole directory scan.
                continue

            target[idx] = os.path.join(dirname, filename)

    # match volume and segmentation by idx; sorting makes the split independent of walk order
    matched_keys = sorted(volume_dict.keys() & segmentation_dict.keys())
    all_files = [{"vol": volume_dict[k], "seg": segmentation_dict[k]} for k in matched_keys]

    # split into 80% train and 20% test
    split_idx = int(0.8 * len(all_files))
    train_files = all_files[:split_idx]
    test_files = all_files[split_idx:]

    return train_files, test_files

# In [None]:
'''
Written by abushaah with sources:
https://github.com/Project-MONAI/MONAI
https://github.com/amine0110/Liver-Segmentation-Using-Monai-and-PyTorch/blob/main/preporcess.py
'''

import re
from glob import glob
from monai.transforms import (
    Compose,
    EnsureChannelFirstD,
    LoadImaged,
    Resized,
    ToTensord,
    Spacingd,
    Orientationd,
    ScaleIntensityRanged,
    CropForegroundd,
)
from monai.data import DataLoader, Dataset, CacheDataset
from monai.utils import set_determinism

"""
    Use MONAI transforms to prepare data for segmentation.
    Voxel: 3D grid representation of data.
    
    :param tuple pixdim: standard voxel spacing (in millimeters) for resampling the images in the x, y, and z dimensions.
    :param int a_min: intensity voxel min for CT scans (less are clipped before scaling).
    :param int a_max: intensity voxel max for CT scans (more are clipped before scaling).
    :param int array spatial_size: output size (in voxel) to which each image and label volume will be resized. AKA input size for the neural network.
    :param int batch_size: adjust batch size, default is 1.
    :return PyTorch DataLoader objects: used to train neural network.
"""
def preprocess(pixdim=(1.5, 1.5, 1.0), a_min=-200, a_max=200, spatial_size=[128, 128, 64], batch_size=1, in_dir=None):
    """
    Build train and test DataLoaders of preprocessed CT volumes and segmentation masks.

    File pairs come from ``prepare``; each pair is a dict with a ``"vol"`` and a ``"seg"``
    path. Both splits go through the same deterministic pipeline: load, move channels
    first, resample to ``pixdim``, reorient to RAS, clip and rescale the volume intensity
    to [0, 1], crop to the volume's foreground, resize to ``spatial_size`` and convert to
    tensors.

    :param tuple pixdim: target voxel spacing passed to ``Spacingd``.
    :param int a_min: lower intensity bound; values below are clipped before scaling.
    :param int a_max: upper intensity bound; values above are clipped before scaling.
    :param spatial_size: output size (in voxels) each volume and mask is resized to.
        The default list is only read, never modified.
    :param int batch_size: number of samples per batch in both loaders.
    :param in_dir: directory handed to ``prepare``; ``None`` keeps ``prepare``'s own default.
    :return: tuple ``(train_loader, test_loader)`` of MONAI ``DataLoader`` objects.
    """
    train_files, test_files = prepare() if in_dir is None else prepare(in_dir)

    # reproduce training results (seeds the global random state)
    set_determinism(seed=0)

    keys = ["vol", "seg"]

    def build_transforms():
        # Train and test use the same steps; a fresh Compose per dataset keeps them independent.
        return Compose([
            LoadImaged(keys=keys),
            EnsureChannelFirstD(keys=keys),
            # image is interpolated, mask uses nearest so label values stay intact
            Spacingd(keys=keys, pixdim=pixdim, mode=("bilinear", "nearest")),
            Orientationd(keys=keys, axcodes="RAS"),
            ScaleIntensityRanged(keys=["vol"], a_min=a_min, a_max=a_max, b_min=0.0, b_max=1.0, clip=True),
            CropForegroundd(keys=keys, source_key="vol"),
            # Per-key mode: "area" (the Resized default) for the volume, "nearest" for the
            # mask. Without it the mask would be area-averaged into fractional label values.
            Resized(keys=keys, spatial_size=spatial_size, mode=("area", "nearest")),
            ToTensord(keys=keys),
        ])

    train_ds = Dataset(data=train_files, transform=build_transforms())
    test_ds = Dataset(data=test_files, transform=build_transforms())

    train_loader = DataLoader(train_ds, batch_size=batch_size)
    test_loader = DataLoader(test_ds, batch_size=batch_size)

    return train_loader, test_loader
