## Install Dependencies

In [1]:
!pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu126

Looking in indexes: https://download.pytorch.org/whl/cu126
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch)
  Downloading https://download.pytorch.org/whl/cu126/nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch)
  Downloading https://download.pytorch.org/whl/cu126/nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch)
  Downloading https://download.pytorch.org/whl/cu126/nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch)
  Downloading https://download.pytorch.org/whl/cu126/nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch)
  Downloading https://download.pytorch.org/whl/cu126/nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.met

In [2]:
!pip3 install nnunetv2 acvl-utils cupy-cuda12x

Collecting nnunetv2
  Downloading nnunetv2-2.6.0.tar.gz (206 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/206.3 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m206.3/206.3 kB[0m [31m5.9 MB/s[0m eta [36m0:00:00[0m
[?25h  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Collecting acvl-utils
  Downloading acvl_utils-0.2.5.tar.gz (29 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting dynamic-network-architectures<0.4,>=0.3.1 (from nnunetv2)
  Downloading dynamic_network_architectures-0.3.1.tar.gz (20 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting dicom2nifti (from nnunetv2)
  Downloading dicom2nifti-2.5.1-py3-none-any.whl.metadata (1.3 kB)
Collecting batchgenerators>=0.25.1 (from nnunetv2)
  Downloading batchgenerators-0.25.1.tar.gz (76 kB)
[2K

## Define Functions

In [1]:
from abc import abstractmethod, ABC
from typing import Optional, Dict, Union, Tuple, List

import numpy as np
import cupy as cp
import torch
from nnunetv2.configuration import ANISO_THRESHOLD
from cupyx.scipy import ndimage


class ImageNormalization(ABC):
    leaves_pixels_outside_mask_at_zero_if_use_mask_for_norm_is_true = None

    def __init__(self, use_mask_for_norm: Optional[bool] = None, intensityproperties: Optional[Dict] = None,
                 target_dtype: torch.dtype = torch.float32):
        assert use_mask_for_norm is None or isinstance(use_mask_for_norm, bool)
        self.use_mask_for_norm = use_mask_for_norm
        assert isinstance(intensityproperties, dict) or intensityproperties is None
        self.intensityproperties = intensityproperties
        self.target_dtype = target_dtype

    @abstractmethod
    def run(self, image: torch.Tensor, seg: Optional[torch.Tensor] = None) -> torch.Tensor:
        raise NotImplementedError


class CTNormalization(ImageNormalization):
    leaves_pixels_outside_mask_at_zero_if_use_mask_for_norm_is_true = False

    def run(self, image: torch.Tensor, seg: Optional[torch.Tensor] = None) -> torch.Tensor:
        assert self.intensityproperties is not None, "CTNormalization requires intensity properties"
        mean_intensity = self.intensityproperties['mean']
        std_intensity = self.intensityproperties['std']
        lower_bound = self.intensityproperties['percentile_00_5']
        upper_bound = self.intensityproperties['percentile_99_5']

        image = image.to(dtype=self.target_dtype)
        image = torch.clamp(image, lower_bound, upper_bound)
        image = (image - mean_intensity) / max(std_intensity, 1e-8)
        return image


def create_nonzero_mask(data):
    """

    :param data:
    :return: the mask is True where the data is nonzero
    """
    assert data.ndim in (3, 4), "data must have shape (C, X, Y, Z) or shape (C, X, Y)"
    nonzero_mask = data[0] != 0
    for c in range(1, data.shape[0]):
        nonzero_mask |= data[c] != 0
    filled_mask = ndimage.binary_fill_holes(nonzero_mask)
    return filled_mask


def get_bbox_from_mask(mask: cp.ndarray) -> List[List[int]]:
    """
    ALL bounding boxes in acvl_utils and nnU-Netv2 are half open interval [start, end)!
    - Alignment with Python Slicing
    - Ease of Subdivision
    - Consistency in Multi-Dimensional Arrays
    - Precedent in Computer Graphics

    This implementation uses CuPy for GPU acceleration. The mask should be a CuPy array.

    Args:
        mask (cp.ndarray): 3D mask array on GPU

    Returns:
        List[List[int]]: Bounding box coordinates as [[minz, maxz], [minx, maxx], [miny, maxy]]
    """
    Z, X, Y = mask.shape
    minzidx, maxzidx, minxidx, maxxidx, minyidx, maxyidx = 0, Z, 0, X, 0, Y

    # Create range arrays on GPU
    zidx = cp.arange(Z)
    xidx = cp.arange(X)
    yidx = cp.arange(Y)

    # Z dimension
    for z in zidx.get():  # .get() to iterate over CPU array
        if cp.any(mask[z]).get():  # .get() to get boolean result to CPU
            minzidx = z
            break
    for z in zidx[::-1].get():
        if cp.any(mask[z]).get():
            maxzidx = z + 1
            break

    # X dimension
    for x in xidx.get():
        if cp.any(mask[:, x]).get():
            minxidx = x
            break
    for x in xidx[::-1].get():
        if cp.any(mask[:, x]).get():
            maxxidx = x + 1
            break

    # Y dimension
    for y in yidx.get():
        if cp.any(mask[:, :, y]).get():
            minyidx = y
            break
    for y in yidx[::-1].get():
        if cp.any(mask[:, :, y]).get():
            maxyidx = y + 1
            break

    return [[minzidx, maxzidx], [minxidx, maxxidx], [minyidx, maxyidx]]

def bounding_box_to_slice(bounding_box: List[List[int]]):
    """
    ALL bounding boxes in acvl_utils and nnU-Netv2 are half open interval [start, end)!
    - Alignment with Python Slicing
    - Ease of Subdivision
    - Consistency in Multi-Dimensional Arrays
    - Precedent in Computer Graphics
    https://chatgpt.com/share/679203ec-3fbc-8013-a003-13a7adfb1e73
    """
    return tuple([slice(*i) for i in bounding_box])


def crop_to_nonzero(data, seg=None, nonzero_label=-1):
    """

    :param data:
    :param seg:
    :param nonzero_label: this will be written into the segmentation map
    :return:
    """
    nonzero_mask = create_nonzero_mask(data)
    bbox = get_bbox_from_mask(nonzero_mask)
    slicer = bounding_box_to_slice(bbox)
    nonzero_mask = nonzero_mask[slicer][None]

    slicer = (slice(None),) + slicer
    data = data[slicer]
    if seg is not None:
        seg = seg[slicer]
        seg[(seg == 0) & (~nonzero_mask)] = nonzero_label
    else:
        seg = np.where(nonzero_mask, np.int8(0), np.int8(nonzero_label))
    return data, seg, bbox


def compute_new_shape(old_shape: Union[Tuple[int, ...], List[int], np.ndarray],
                      old_spacing: Union[Tuple[float, ...], List[float], np.ndarray],
                      new_spacing: Union[Tuple[float, ...], List[float], np.ndarray]) -> np.ndarray:
    assert len(old_spacing) == len(old_shape)
    assert len(old_shape) == len(new_spacing)
    new_shape = np.array([int(round(i / j * k)) for i, j, k in zip(old_spacing, new_spacing, old_shape)])
    return new_shape


def fast_resample_data_or_seg_to_shape(data: Union[torch.Tensor, np.ndarray],
                                       new_shape: Union[Tuple[int, ...], List[int], np.ndarray],
                                       current_spacing: Union[Tuple[float, ...], List[float], np.ndarray],
                                       new_spacing: Union[Tuple[float, ...], List[float], np.ndarray],
                                       is_seg: bool = False,
                                       order: int = 3, order_z: int = 0,
                                       force_separate_z: Union[bool, None] = False,
                                       separate_z_anisotropy_threshold: float = ANISO_THRESHOLD):
    use_gpu = True
    device = torch.device("cuda" if use_gpu else "cpu")
    order_to_mode_map = {
        0: "nearest",
        1: "trilinear" if new_shape[0] > 1 else "bilinear",
        2: "trilinear" if new_shape[0] > 1 else "bilinear",
        3: "trilinear" if new_shape[0] > 1 else "bicubic",
        4: "trilinear" if new_shape[0] > 1 else "bicubic",
        5: "trilinear" if new_shape[0] > 1 else "bicubic",
    }
    resize_fn = torch.nn.functional.interpolate
    kwargs = {
        'mode': order_to_mode_map[order],
        'align_corners': False
    }
    shape = np.array(data[0].shape)
    new_shape = np.array(new_shape)
    if np.any(shape != new_shape):
        if not isinstance(data, torch.Tensor):
            #torch_data = torch.from_numpy(data).float()
            torch_data = torch.as_tensor(data.get())
        else:
            torch_data = data.float()
        if new_shape[0] == 1:
            torch_data = torch_data.transpose(1, 0)
            new_shape = new_shape[1:]
        else:
            torch_data = torch_data.unsqueeze(0)

        torch_data = resize_fn(torch_data.to(device), tuple(new_shape), **kwargs)

        if new_shape[0] == 1:
            torch_data = torch_data.transpose(1, 0)
        else:
            torch_data = torch_data.squeeze(0)

        # if use_gpu:
        #     torch_data = torch_data.cpu()
        reshaped_final_data = torch_data
        # if isinstance(data, np.ndarray):
        #     reshaped_final_data = torch_data.numpy().astype(dtype_data)
        # else:
        #     reshaped_final_data = torch_data.to(dtype_data)

        #print(f"Reshaped data from {shape} to {new_shape}")
        #print(f"reshaped_final_data shape: {reshaped_final_data.shape}")
        assert reshaped_final_data.ndim == 4, f"reshaped_final_data.shape = {reshaped_final_data.shape}"
        return reshaped_final_data
    else:
        print("no resampling necessary")
        return data

## Upload Data

In [2]:
from google.colab import drive
from os import mkdir
from os.path import exists
from shutil import copy

drive.mount('/content/drive')
if not exists("inference_test_cases"):
    mkdir("inference_test_cases")
if not exists("inference_test_cases.zip"):
    copy("/content/drive/MyDrive/inference_test_cases.zip", "inference_test_cases.zip")
!unzip inference_test_cases.zip -d inference_test_cases

Mounted at /content/drive
Archive:  inference_test_cases.zip
 extracting: inference_test_cases/FLARETs_0001_0000.nii.gz  
 extracting: inference_test_cases/FLARETs_0015_0000.nii.gz  
 extracting: inference_test_cases/FLARETs_0018_0000.nii.gz  
 extracting: inference_test_cases/FLARETs_0036_0000.nii.gz  
 extracting: inference_test_cases/FLARETs_0046_0000.nii.gz  


## Run Inference

In [7]:
from os import listdir
from nnunetv2.imageio.simpleitk_reader_writer import SimpleITKIO
from time import time


def main() -> dict[str, list[float]]:
    CT_configuration = {
        "transpose_forward": [
            0,
            1,
            2
        ],
        "spacing": [
            2.5,
            0.7958984971046448,
            0.7958984971046448
        ],
        'intensity_prop': {
            "max": 3071.0,
            "mean": 97.29716491699219,
            "median": 118.0,
            "min": -1024.0,
            "percentile_00_5": -958.0,
            "percentile_99_5": 270.0,
            "std": 137.8484649658203
        }
    }

    normalization = CTNormalization(intensityproperties=CT_configuration['intensity_prop'])
    cases = listdir("inference_test_cases")
    cases.sort()
    results = {"cropping": [], "normalization": [], "resampling": []}
    for case in cases:
        image, properties = SimpleITKIO().read_images([f"inference_test_cases/{case[:case.find('.')]}"])
        image = torch.from_numpy(image).to(dtype=torch.float32, memory_format=torch.contiguous_format).to('cuda')
        data = image.clone()
        data = data.permute([0, *[i + 1 for i in CT_configuration['transpose_forward']]])
        original_spacing = [properties['spacing'][i] for i in CT_configuration['transpose_forward']]
        t0 = time()
        data, seg, bbox = crop_to_nonzero(data.cpu())
        torch.cuda.synchronize()
        results["cropping"].append(time() - t0)
        target_spacing = CT_configuration['spacing']
        if len(target_spacing) < len(data.shape[1:]):
            target_spacing = [original_spacing[0]] + target_spacing
        new_shape = compute_new_shape(data.shape[1:], original_spacing, target_spacing)
        t0 = time()
        data = normalization.run(data.cuda())
        torch.cuda.synchronize()
        results["normalization"].append(time() - t0)
        fast_resample_data_or_seg_to_shape(data, new_shape, original_spacing, target_spacing)
        torch.cuda.synchronize()
        results["resampling"].append(time() - t0)
    return results


def add_up(table: dict[str, list[float]], entry: str, n: int, num_cases: int) -> list[float]:
    r = []
    for i in range(num_cases):
        s = 0
        for j in range(n):
            s += table[entry][i + j * num_cases]
        r.append(s / n)
    return r


_final = {"cropping": [], "normalization": [], "resampling": []}
_n = 10
_num_cases = 5
for _b in range(_n):
    print(f"Porcessing batch {_b}")
    _r = main()
    _final["cropping"] += _r["cropping"]
    _final["normalization"] += _r["normalization"]
    _final["resampling"] += _r["resampling"]
print(add_up(_final, "cropping", _n, _num_cases))
print(add_up(_final, "normalization", _n, _num_cases))
print(add_up(_final, "resampling", _n, _num_cases))

Porcessing batch 0
Porcessing batch 1
Porcessing batch 2
Porcessing batch 3
Porcessing batch 4
Porcessing batch 5
Porcessing batch 6
Porcessing batch 7
Porcessing batch 8
Porcessing batch 9
[2.4736659288406373, 3.030393648147583, 1.2151961088180543, 0.6216495752334594, 1.7759423732757569]
[0.14339048862457277, 0.15852503776550292, 0.06236412525177002, 0.03585348129272461, 0.0936197280883789]
[0.15676624774932862, 0.17111523151397706, 0.07584750652313232, 0.04186503887176514, 0.10113983154296875]
