In [None]:
!pip install monai

In [None]:
import os
import nibabel as nib
import numpy as np
from glob import glob
import psutil

"""
    Get data inputs, assumes CT volumes and segmentation masks have corresponding names and indices.
    Analyze the nifti datasets for MONAI parameter adjustments
    :param str in_dir: file path of data.
"""
def prepare_and_configure(in_dir):
    volume_dict = {}
    segmentation_dict = {}

    # find all .nii files under in_dir
    nii_files = glob(os.path.join(in_dir, "**", "*.nii"), recursive=True)

    for filepath in nii_files:
        filename = os.path.basename(filepath)
        if filename.startswith("volume-"):
            idx = int(filename.split("-")[1].split(".")[0])
            volume_dict[idx] = filepath
        elif filename.startswith("segmentation-"):
            idx = int(filename.split("-")[1].split(".")[0])
            segmentation_dict[idx] = filepath

    # match volume and segmentation by idx
    matched_keys = sorted(set(volume_dict.keys()) & set(segmentation_dict.keys()))
    all_files = [{"vol": volume_dict[k], "seg": segmentation_dict[k]} for k in matched_keys]

    # split 80% train / 20% validation
    split_idx = int(0.8 * len(all_files))
    train_files = all_files[:split_idx]
    validation_files = all_files[split_idx:]
    
    # analyze voxel sizes and shapes
    voxel_sizes = []
    shapes = []
    for k in matched_keys:
        img = nib.load(volume_dict[k])
        data = img.get_fdata()
        voxel_sizes.append(img.header.get_zooms())
        shapes.append(data.shape)

    # pixdim based on variables in https://github.com/Project-MONAI/tutorials/blob/main/3d_segmentation/spleen_segmentation_3d.ipynb
    mean_spacing = np.mean(voxel_sizes, axis=0)
    mean_shape = np.mean(shapes, axis=0)
    pixdim = tuple(round(s, 2) for s in mean_spacing)

    # default for soft tissue
    a_min, a_max = -200, 250

    # detect GPU & RAM memory
    try:
        import GPUtil
        gpus = GPUtil.getGPUs()
        mem_free_gpu = max([gpu.memoryFree for gpu in gpus])  # in MB
    except Exception:
        mem_free_gpu = 0  # fallback to CPU

    mem_free_ram = psutil.virtual_memory().available // (1024 * 1024)

    # adjust preprocessing resolution based on memory
    # values are randomized based on https://docs.monai.io/en/stable/transforms.html
    if mem_free_gpu >= 20000:
        spatial_size = [256, 256, 256]
        batch_size = 2
    elif mem_free_gpu >= 10000:
        spatial_size = [192, 192, 128]
        batch_size = 1
    elif mem_free_gpu >= 4000:
        spatial_size = [128, 128, 64]
        batch_size = 1
    else:
        spatial_size = [96, 96, 64]
        batch_size = 1

    return {
        "train_files": train_files,
        "validation_files": validation_files,
        "pixdim": pixdim,
        "a_min": a_min,
        "a_max": a_max,
        "spatial_size": spatial_size,
        "batch_size": batch_size,
        "mem_free_gpu": mem_free_gpu,
        "mem_free_ram": mem_free_ram,
    }

In [None]:
import re
from glob import glob

from monai.data import DataLoader, Dataset, CacheDataset
from monai.transforms import (
    Compose,
    CropForegroundd,
    EnsureChannelFirstD,
    LoadImaged,
    Orientationd,
    RandCropByPosNegLabeld,
    Resized,
    ScaleIntensityRanged,
    Spacingd,
    ToTensord,
)
from monai.utils import set_determinism

"""
    Use MONAI transforms to prepares data for segmentation.
    Voxel: 3D grid representation of data.
    
    :param tuple pixdim: standard voxel spacing (in millimeters) for resampling the images in the x, y, and z dimensions.
    :param int a_min: intensity voxel min for CT scans (less are clipped before scaling).
    :param int a_max: intensity voxel max for CT scans (more are clipped before scaling).
    :param int array spatial_size: output size (in voxel) to which each image and label volume will be resized. AKA input size for the neural network.
    :param int batch_size: adjyst batch size, default is 1.
    :return PyTorch DataLoader objects: used to train neural network.
"""
def preprocess(pixdim, a_min, a_max, spatial_size, batch_size, cache, train_files, validation_files):

    # reproduce training results
    set_determinism(seed=0)

    # and apply transformations to them
    # parameters from https://github.com/Project-MONAI/tutorials/blob/main/3d_segmentation/spleen_segmentation_3d.ipynb
    train_transforms = Compose([
        LoadImaged(keys=["vol", "seg"]),
        EnsureChannelFirstD(keys=["vol", "seg"]),
        ScaleIntensityRanged(keys=["vol"], a_min=a_min, a_max=a_max, b_min=0.0, b_max=1.0, clip=True),
        CropForegroundd(keys=["vol", "seg"], source_key="vol"),
        Orientationd(keys=["vol", "seg"], axcodes="RAS"),
        Spacingd(keys=["vol", "seg"], pixdim=pixdim, mode=("bilinear", "nearest")),
        RandCropByPosNegLabeld(
            keys=["vol", "seg"],
            label_key="seg",
            spatial_size=spatial_size,  # use your configured size here
            pos=1, neg=1,
            num_samples=4,
            image_key="vol",
            image_threshold=0,
        ),
        ToTensord(keys=["vol", "seg"]),
    ])

    # transforms for validation data
    validation_transforms = Compose([
        LoadImaged(keys=["vol", "seg"]),
        EnsureChannelFirstD(keys=["vol", "seg"]),
        ScaleIntensityRanged(keys=["vol"], a_min=a_min, a_max=a_max, b_min=0.0, b_max=1.0, clip=True),
        CropForegroundd(keys=["vol", "seg"], source_key="vol"),
        Orientationd(keys=["vol", "seg"], axcodes="RAS"),
        Spacingd(keys=["vol", "seg"], pixdim=pixdim, mode=("bilinear", "nearest")),
    ])

    if cache >= 16000:
        train_ds = CacheDataset(data=train_files, transform=train_transforms, cache_rate=1.0)
        val_ds = CacheDataset(data=val_files, transform=val_transforms, cache_rate=1.0)

        # train_ds = CacheDataset(data=train_files, transform=train_transforms, cache_rate=1.0, num_workers=4)
        # val_ds = CacheDataset(data=val_files, transform=val_transforms, cache_rate=1.0, num_workers=4)
    else:
        train_ds = Dataset(data=train_files, transform=train_transforms)
        validation_ds = Dataset(data=validation_files, transform=validation_transforms)

        # train_ds = Dataset(data=train_files, transform=train_transforms, num_workers=4)
        # validation_ds = Dataset(data=validation_files, transform=validation_transforms, num_workers=4)

    train_loader = DataLoader(train_ds, batch_size=batch_size)
    validation_loader = DataLoader(validation_ds, batch_size=batch_size)

    # use RandCropByPosNegLabeld to generate 2 x 4 images for network training
    # train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=4)
    # validation_loader = DataLoader(validation_ds, batch_size=batch_size, num_workers=4)

    return train_loader, validation_loader

In [None]:
# Usage flow.
#
# Step 1: scan the input directory (currently the Kaggle dataset mount) and
# derive the data split plus preprocessing/memory settings from it.
params = prepare_and_configure(in_dir="/kaggle/input")

# Step 2: echo every derived setting so the chosen configuration is visible.
print("prepare_and_configure:")
for k, v in params.items():
    print(f"{k}: {v}")

# Step 3: build the data loaders. The free-RAM figure is passed as ``cache``
# so preprocess() can decide whether in-memory caching is affordable.
loader_settings = {
    name: params[name]
    for name in ("pixdim", "a_min", "a_max", "spatial_size", "batch_size", "train_files", "validation_files")
}
loader_settings["cache"] = params["mem_free_ram"]
train_loader, validation_loader = preprocess(**loader_settings)

print(train_loader)
print(validation_loader)