In [1]:
!pip install monai



In [2]:
import os
import nibabel as nib
import numpy as np
from glob import glob
import psutil

"""
    Get data inputs, assumes CT images and label masks have corresponding names and indices.
    Analyze the nifti datasets for MONAI parameter adjustments
    :param str in_dir: file path of data.
"""
def prepare_and_configure(in_dir):
    image_dict = {}
    label_dict = {}

    # find all .nii files under in_dir
    nii_files = glob(os.path.join(in_dir, "**", "*.nii"), recursive=True)

    for filepath in nii_files:
        filename = os.path.basename(filepath)
        if filename.startswith("volume-"):
            idx = int(filename.split("-")[1].split(".")[0])
            image_dict[idx] = filepath
        elif filename.startswith("segmentation-"):
            idx = int(filename.split("-")[1].split(".")[0])
            label_dict[idx] = filepath

    # match image and label by idx
    
    # matched_keys = sorted(set(image_dict.keys()) & set(label_dict.keys()))
    # test model syntax first
    matched_keys = sorted(set(image_dict.keys()) & set(label_dict.keys()))[:20]

    all_files = [{"image": image_dict[k], "label": label_dict[k]} for k in matched_keys]

    # split 80% train / 20% validation
    split_idx = int(0.8 * len(all_files))
    train_files = all_files[:split_idx]
    validation_files = all_files[split_idx:]
    
    # analyze voxel sizes and shapes
    voxel_sizes = []
    shapes = []
    for k in matched_keys:
        img = nib.load(image_dict[k])
        data = img.get_fdata()
        voxel_sizes.append(img.header.get_zooms())
        shapes.append(data.shape)

    # pixdim based on variables in https://github.com/Project-MONAI/tutorials/blob/main/3d_label/spleen_label_3d.ipynb
    mean_spacing = np.mean(voxel_sizes, axis=0)
    mean_shape = np.mean(shapes, axis=0)

    if isinstance(mean_spacing, np.ndarray):  # Expected case
        pixdim = tuple(round(s, 2) for s in mean_spacing)
    else:  # Fallback for scalar
        pixdim = (round(mean_spacing, 2),)        

    # default for soft tissue
    a_min, a_max = -200, 250

    # detect GPU & RAM memory
    try:
        import GPUtil
        gpus = GPUtil.getGPUs()
        mem_free_gpu = max([gpu.memoryFree for gpu in gpus])  # in MB
    except Exception:
        mem_free_gpu = 0  # fallback to CPU

    mem_free_ram = psutil.virtual_memory().available // (1024 * 1024)

    # adjust preprocessing resolution based on memory
    # values are randomized based on https://docs.monai.io/en/stable/transforms.html
    if mem_free_gpu >= 20000:
        spatial_size = [256, 256, 256]
        batch_size = 2
    elif mem_free_gpu >= 10000:
        spatial_size = [192, 192, 128]
        batch_size = 1
    elif mem_free_gpu >= 4000:
        spatial_size = [128, 128, 64]
        batch_size = 1
    else:
        spatial_size = [96, 96, 64]
        batch_size = 1

    return {
        "train_files": train_files,
        "validation_files": validation_files,
        "pixdim": pixdim,
        "a_min": a_min,
        "a_max": a_max,
        "spatial_size": spatial_size,
        "batch_size": batch_size,
        "mem_free_gpu": mem_free_gpu,
        "mem_free_ram": mem_free_ram,
    }

In [3]:
import re
from glob import glob
from monai.transforms import (
    Compose,
    EnsureChannelFirstD,
    LoadImaged,
    Resized,
    ToTensord,
    Spacingd,
    Orientationd,
    ScaleIntensityRanged,
    CropForegroundd,
    RandCropByPosNegLabeld,
)
from monai.data import DataLoader, Dataset, CacheDataset
from monai.utils import set_determinism

"""
    Use MONAI transforms to prepares data for segmentation.
    Voxel: 3D grid representation of data.
    
    :param tuple pixdim: standard voxel spacing (in millimeters) for resampling the images in the x, y, and z dimensions.
    :param int a_min: intensity voxel min for CT scans (less are clipped before scaling).
    :param int a_max: intensity voxel max for CT scans (more are clipped before scaling).
    :param int array spatial_size: output size (in voxel) to which each image and label volume will be resized. AKA input size for the neural network.
    :param int batch_size: adjyst batch size, default is 1.
    :return PyTorch DataLoader objects: used to train neural network.
"""
def preprocess(pixdim, a_min, a_max, spatial_size, batch_size, cache, train_files, validation_files):

    # reproduce training results
    set_determinism(seed=0)

    # and apply transformations to them
    # parameters from https://github.com/Project-MONAI/tutorials/blob/main/3d_segmentation/spleen_segmentation_3d.ipynb
    train_transforms = Compose([
        LoadImaged(keys=["image", "label"]),
        EnsureChannelFirstD(keys=["image", "label"]),
        ScaleIntensityRanged(keys=["image"], a_min=a_min, a_max=a_max, b_min=0.0, b_max=1.0, clip=True),
        CropForegroundd(keys=["image", "label"], source_key="image"),
        Orientationd(keys=["image", "label"], axcodes="RAS"),
        Spacingd(keys=["image", "label"], pixdim=pixdim, mode=("bilinear", "nearest")),
        RandCropByPosNegLabeld(
            keys=["image", "label"],
            label_key="label",
            spatial_size=spatial_size,  # use your configured size here
            pos=1, neg=1,
            num_samples=4,
            image_key="image",
            image_threshold=0,
        ),
        ToTensord(keys=["image", "label"]),
    ])

    # transforms for validation data
    validation_transforms = Compose([
        LoadImaged(keys=["image", "label"]),
        EnsureChannelFirstD(keys=["image", "label"]),
        ScaleIntensityRanged(keys=["image"], a_min=a_min, a_max=a_max, b_min=0.0, b_max=1.0, clip=True),
        CropForegroundd(keys=["image", "label"], source_key="image"),
        Orientationd(keys=["image", "label"], axcodes="RAS"),
        Spacingd(keys=["image", "label"], pixdim=pixdim, mode=("bilinear", "nearest")),
    ])

    if cache is not None and cache >= 16000:
        train_ds = CacheDataset(data=train_files, transform=train_transforms, cache_rate=1.0)
        validation_ds = CacheDataset(data=validation_files, transform=validation_transforms, cache_rate=1.0)

        # train_ds = CacheDataset(data=train_files, transform=train_transforms, cache_rate=1.0, num_workers=4)
        # validation_ds = CacheDataset(data=validation_files, transform=validation_transforms, cache_rate=1.0, num_workers=4)
    else:
        train_ds = Dataset(data=train_files, transform=train_transforms)
        validation_ds = Dataset(data=validation_files, transform=validation_transforms)

        # train_ds = Dataset(data=train_files, transform=train_transforms, num_workers=4)
        # validation_ds = Dataset(data=validation_files, transform=validation_transforms, num_workers=4)

    train_loader = DataLoader(train_ds, batch_size=batch_size)
    validation_loader = DataLoader(validation_ds, batch_size=batch_size)

    # use RandCropByPosNegLabeld to generate 2 x 4 images for network training
    # train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=4)
    # validation_loader = DataLoader(validation_ds, batch_size=batch_size, num_workers=4)

    return train_loader, validation_loader

2025-07-28 23:41:17.984505: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1753746078.007105     289 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1753746078.013974     289 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


In [4]:
from monai.transforms import (
    AsDiscrete,
    Compose,
)
from torch.cuda.amp import GradScaler, autocast

def train (model, train_loader, validation_loader, loss_function, optimizer, dice_metric):
    max_epochs = 600
    val_interval = 2
    best_metric = -1
    best_metric_epoch = -1
    epoch_loss_values = []
    metric_values = []
    post_pred = Compose([AsDiscrete(argmax=True, to_onehot=2)])
    post_label = Compose([AsDiscrete(to_onehot=2)])
    scaler = GradScaler()
    
    for epoch in range(max_epochs):
        print("-" * 10)
        print(f"epoch {epoch + 1}/{max_epochs}")
        model.train()
        epoch_loss = 0
        step = 0
        print(f"Number of batches in train_loader: {len(train_loader)}")
        for batch_data in train_loader:
            step += 1
            inputs, labels = (
                batch_data["image"].to(device), # volume
                batch_data["label"].to(device), # segment
            )

            # labels.shape: torch.Size([4, 1, 96, 96, 64]), dtype: torch.float32
            # unique label values: tensor([0., 1., 2.], device='cuda:0')
            labels = labels.squeeze(1) # remove channel dimension [B, D, H, W]
            labels = labels.long()
            labels = torch.clamp(labels, 0, 1)     # ensure binary labels
            
            optimizer.zero_grad()
            with autocast():
                outputs = model(inputs)

                print(f"outputs.shape: {outputs.shape}, dtype: {outputs.dtype}")
                print(f"labels.shape: {labels.shape}, dtype: {labels.dtype}")
                print(f"unique label values: {torch.unique(labels)}")
                
                loss = loss_function(outputs, labels)
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
                
            # loss.backward()
            # optimizer.step()
            print(f"loss.item: {loss.item()}")
            epoch_loss += loss.item()
            print(f"{step}/{len(train_ds) // train_loader.batch_size}, " f"train_loss: {loss.item():.4f}")
        epoch_loss /= step
        epoch_loss_values.append(epoch_loss)
        print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")
    
        if (epoch + 1) % val_interval == 0:
            model.eval()
            with torch.no_grad():
                for validation_data in validation_loader:
                    val_inputs, val_labels = (
                        validation_data["image"].to(device),
                        validation_data["label"].to(device),
                    )
                    # roi_size = (160, 160, 160)
                    roi_size = (64, 64, 64)
                    sw_batch_size = 4
                    val_outputs = sliding_window_inference(val_inputs, roi_size, sw_batch_size, model)
                    val_outputs = [post_pred(i) for i in decollate_batch(val_outputs)]
                    val_labels = [post_label(i) for i in decollate_batch(val_labels)]
                    # compute metric for current iteration
                    dice_metric(y_pred=val_outputs, y=val_labels)

                    # free memory
                    del val_inputs, val_labels, val_outputs
                    torch.cuda.empty_cache()
    
                # aggregate the final mean dice result
                metric = dice_metric.aggregate().item()
                # reset the status for next validation round
                dice_metric.reset()
    
                metric_values.append(metric)
                if metric > best_metric:
                    best_metric = metric
                    best_metric_epoch = epoch + 1
                    torch.save(model.state_dict(), os.path.join(root_dir, "best_metric_model.pth"))
                    print("saved new best metric model")
                print(
                    f"current epoch: {epoch + 1} current mean dice: {metric:.4f}"
                    f"\nbest mean dice: {best_metric:.4f} "
                    f"at epoch: {best_metric_epoch}"
                )
                
    print(f"train completed, best_metric: {best_metric:.4f} " f"at epoch: {best_metric_epoch}")
    plot(epoch_loss_values, val_interval, metric_values)

In [5]:
import matplotlib.pyplot as plt

def plot(epoch_loss_values, val_interval, metric_values):
    plt.figure("train", (12, 6))
    plt.subplot(1, 2, 1)
    plt.title("Epoch Average Loss")
    x = [i + 1 for i in range(len(epoch_loss_values))]
    y = epoch_loss_values
    plt.xlabel("epoch")
    plt.plot(x, y)
    plt.subplot(1, 2, 2)
    plt.title("Val Mean Dice")
    x = [val_interval * (i + 1) for i in range(len(metric_values))]
    y = metric_values
    plt.xlabel("epoch")
    plt.plot(x, y)
    plt.show()

In [6]:
import torch
from monai.networks.nets import UNet
from monai.losses import DiceLoss
from monai.metrics import DiceMetric
from monai.networks.layers import Norm

# if __name__ == '__main__':
    
# 1. user input (for now, it is kaggle data set)
print("Preparing dataset directory and machine specifications...")
params = prepare_and_configure(in_dir="/kaggle/input")
print("Complete")

# show output
# print("prepare_and_configure:")
# for k, v in params.items():
#     print(f"{k}: {v}")

# 2. preprocess
print("Preprocess dataset with MONAI transforms...")
train_loader, validation_loader = preprocess(pixdim=params['pixdim'], a_min=params['a_min'], a_max=params['a_max'], spatial_size=params['spatial_size'], batch_size=params['batch_size'], cache=params['mem_free_ram'], train_files=params['train_files'], validation_files=params['validation_files'])
# print(train_loader)
# print(validation_loader)
print("Complete")

# 3. build U-net
print("Building U-Net Model...")
device = torch.device("cuda:0")
model = UNet(
    spatial_dims=3,
    in_channels=1,
    out_channels=2,
    channels=(8, 16, 32, 64),
    strides=(2, 2, 2, 2),
    num_res_units=1,
    norm=Norm.BATCH,
).to(device)
print("Complete")

print("Evaluating model parameters...")
loss_function = DiceLoss(to_onehot_y=True, softmax=True)
optimizer = torch.optim.Adam(model.parameters(), 1e-4)
dice_metric = DiceMetric(include_background=False, reduction="mean")
print("Complete")

# 4. train
print("Training model...")
train(model, train_loader, validation_loader, loss_function, optimizer, dice_metric)
print("Complete")

# 5. test/validate

Preparing dataset directory and machine specifications...
Complete
Preprocess dataset with MONAI transforms...


Loading dataset: 100%|██████████| 16/16 [04:58<00:00, 18.63s/it]
Loading dataset: 100%|██████████| 4/4 [01:31<00:00, 22.76s/it]


Complete
Building U-Net Model...
Complete
Evaluating model parameters...
Complete
Training model...
----------
epoch 1/600
Number of batches in train_loader: 16


  scaler = GradScaler()
  with autocast():


outputs.shape: torch.Size([4, 2, 96, 96, 64]), dtype: torch.float16
labels.shape: torch.Size([4, 96, 96, 64]), dtype: torch.int64
unique label values: tensor([0, 1], device='cuda:0')


AssertionError: labels should have a channel with length equal to one.