In [1]:
import os
import sys

# Make the directory one level above the notebook's working directory importable,
# so that the `from src...` import below can resolve.
module_path = os.path.abspath(os.pardir)
if module_path not in sys.path:
    sys.path.append(module_path)

from src.utils.labels_mapping import RelabelByModality

import pandas as pd
import numpy as np
import torch
import nibabel as nib

from monai.transforms import (
    Compose, 
    LoadImaged, 
    EnsureChannelFirstd,
    AsDiscrete,
    SaveImaged
)

from monai.data import DataLoader, Dataset, decollate_batch
from monai.networks.nets import UNet
from monai.inferers import sliding_window_inference

In [2]:
# Build the test-set pipeline: per-case loading/relabelling transforms, then a
# Dataset/DataLoader over the cases listed in the test split CSV.

preprocessing = Compose(transforms=[
    LoadImaged(keys=['image', 'label']),
    EnsureChannelFirstd(keys=['image', 'label']),
    RelabelByModality(keys=['label']),
])

# Paths in the CSV are prefixed with '..' so they resolve from the notebook's working
# directory (assumed to sit one level below the project root, as for module_path).
test_data = (
    pd.read_csv('../data/processed/test_split.csv')
    .drop(columns=['file_name'])
    .rename(columns={'image_path': 'image', 'label_path': 'label'})
    .assign(
        image=lambda frame: frame['image'].map(lambda rel: os.path.join('..', rel)),
        label=lambda frame: frame['label'].map(lambda rel: os.path.join('..', rel)),
    )
)
# One dict per case, e.g. [{'image': ..., 'label': ..., 'modality': ...}, ...]
test_dict = test_data.to_dict(orient='records')

test_ds = Dataset(data=test_dict, transform=preprocessing)
test_loader = DataLoader(test_ds, batch_size=1, shuffle=False, num_workers=4)

In [3]:
# Load model
CHANNELS = (16, 32, 64, 128, 256)
NUM_RES_UNITS = 2
OUT_CHANNELS = 49  # number of predicted classes; must match the trained checkpoint
MODEL_PATH = '../models/unet3d_PatchSize_increased(96-_128)/best_model_fold_1.pth'

# Fall back to CPU so the notebook still runs (slowly) on machines without a GPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = UNet(
    spatial_dims=3,
    in_channels=1,
    out_channels=OUT_CHANNELS,
    channels=CHANNELS,
    strides=(2, 2, 2, 2),
    num_res_units=NUM_RES_UNITS,
    bias=False,
    dropout=0.1,
).to(device)

# map_location: load tensors straight onto `device`, whatever device they were saved from.
# weights_only: the file is consumed directly by load_state_dict, so only tensors/containers
# need unpickling; this refuses arbitrary pickled objects.
model.load_state_dict(torch.load(MODEL_PATH, map_location=device, weights_only=True))

<All keys matched successfully>

In [4]:
def infer_volume(model, inputs, vol_size=(96, 96, 96), sw_patch_size=4, overlap=0.25, **sw_kwargs):
    """Infer a 3D image volume using sliding window inference.

    The model is switched to eval mode and run without gradient tracking.

    Args:
        model: network applied to every window (used as the ``predictor``).
        inputs: batched input tensor handed unchanged to ``sliding_window_inference``
            (presumably (B, C, H, W, D) for this 3D network — confirm against caller).
        vol_size: spatial size of each sliding window (``roi_size``).
        sw_patch_size: number of windows pushed through the model per forward pass.
            Despite its name this is the sliding-window *batch* size (``sw_batch_size``);
            the name is kept for backward compatibility.
        overlap: fractional overlap between neighbouring windows.
        **sw_kwargs: extra keyword arguments forwarded unchanged to
            ``sliding_window_inference`` (e.g. ``mode="gaussian"``).

    Returns:
        The stitched model output covering the whole input volume.
    """
    model.eval()
    with torch.no_grad():
        return sliding_window_inference(
            inputs=inputs,
            roi_size=vol_size,
            sw_batch_size=sw_patch_size,
            predictor=model,
            overlap=overlap,
            **sw_kwargs,
        )

In [None]:
# Collapse the per-class output channels of one prediction into a single label map.
post_pred = Compose(transforms=[AsDiscrete(argmax=True)])

# Undo the modality-specific relabelling on 'pred' and 'label', then write 'pred' to disk
# with the 'pred' postfix, directly inside output_dir (no per-case subfolder).
# NOTE(review): this cell shows In [None], and the writes logged by the inference cell go to
# ../data/preprocessed/predictions — that output came from an earlier definition; re-run.
postprocessing = Compose(transforms=[
    RelabelByModality(keys=['pred', 'label'], reverse=True),
    SaveImaged(
        keys=['pred'],
        output_dir='../data/processed/predictions',
        output_postfix='pred',
        separate_folder=False,
    ),
])

In [9]:
# Apply model on test set and save predictions
model.eval()

with torch.no_grad():
    for batch in test_loader:
        # Only the image is needed on the GPU; labels are not used for inference.
        inputs = batch['image'].to(device)

        # 128^3 windows, presumably matching the training patch size named in the
        # checkpoint directory — confirm.
        logits = infer_volume(model, inputs, vol_size=(128, 128, 128), sw_patch_size=1, overlap=0.1)

        # Split the collated batch back into per-sample dicts: the dictionary post-transforms
        # must see un-batched data, exactly as the forward RelabelByModality did during
        # preprocessing (collation turns e.g. 'modality' into a list).
        for sample, sample_logits in zip(decollate_batch(batch), decollate_batch(logits)):
            sample['pred'] = post_pred(sample_logits)  # argmax -> single-channel label map
            postprocessing(sample)  # relabel back and save the prediction

2025-10-20 18:00:02,662 INFO image_writer.py:197 - writing: ../data/preprocessed/predictions/topcow_ct_021_preproc_pred.nii.gz
2025-10-20 18:00:07,144 INFO image_writer.py:197 - writing: ../data/preprocessed/predictions/topcow_ct_005_preproc_pred.nii.gz
2025-10-20 18:00:20,839 INFO image_writer.py:197 - writing: ../data/preprocessed/predictions/topcow_ct_012_preproc_pred.nii.gz
2025-10-20 18:00:28,696 INFO image_writer.py:197 - writing: ../data/preprocessed/predictions/topcow_ct_003_preproc_pred.nii.gz
2025-10-20 18:00:41,906 INFO image_writer.py:197 - writing: ../data/preprocessed/predictions/topcow_ct_022_preproc_pred.nii.gz
2025-10-20 18:00:49,845 INFO image_writer.py:197 - writing: ../data/preprocessed/predictions/topcow_mr_007_preproc_pred.nii.gz
2025-10-20 18:00:57,525 INFO image_writer.py:197 - writing: ../data/preprocessed/predictions/topcow_mr_005_preproc_pred.nii.gz
2025-10-20 18:01:06,980 INFO image_writer.py:197 - writing: ../data/preprocessed/predictions/topcow_mr_017_prep