## Install Dependencies

In [None]:
# Install PyTorch, torchvision and torchaudio from the PyTorch wheel index for the cu126 (CUDA 12.6) builds.
!pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu126

In [None]:
# Install nnunetv2 and acvl-utils (both imported by the cells below) plus the cupy-cuda12x build of CuPy.
!pip3 install nnunetv2 acvl-utils cupy-cuda12x

## Define Functions

In [None]:
from abc import abstractmethod, ABC
from typing import Optional, Dict, Union, Tuple, List

import numpy as np
import torch
from acvl_utils.cropping_and_padding.bounding_boxes import get_bbox_from_mask, bounding_box_to_slice
from nnunetv2.configuration import ANISO_THRESHOLD
from scipy.ndimage import binary_fill_holes


class ImageNormalization(ABC):
    """Abstract base for intensity normalization schemes.

    Concrete subclasses implement :meth:`run` and are expected to override
    the class-level flag declared below.
    """

    # Left as None on the base class; concrete schemes override it with a bool.
    leaves_pixels_outside_mask_at_zero_if_use_mask_for_norm_is_true = None

    def __init__(
        self,
        use_mask_for_norm: Optional[bool] = None,
        intensityproperties: Optional[Dict] = None,
        target_dtype: torch.dtype = torch.float32,
    ):
        """Validate and store the normalization settings.

        :param use_mask_for_norm: ``True``, ``False`` or ``None``; any other type fails the assertion
        :param intensityproperties: dict of intensity statistics, or ``None``
        :param target_dtype: dtype stored on the instance for subclasses to use
        """
        assert isinstance(use_mask_for_norm, bool) or use_mask_for_norm is None
        self.use_mask_for_norm = use_mask_for_norm
        assert intensityproperties is None or isinstance(intensityproperties, dict)
        self.intensityproperties = intensityproperties
        self.target_dtype = target_dtype

    @abstractmethod
    def run(self, image: torch.Tensor, seg: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Normalize ``image`` (optionally using ``seg``) and return the result.

        Must be implemented by subclasses.
        """
        raise NotImplementedError


class CTNormalization(ImageNormalization):
    """Clip intensities to stored percentiles, then standardize with stored mean/std."""

    leaves_pixels_outside_mask_at_zero_if_use_mask_for_norm_is_true = False

    def run(self, image: torch.Tensor, seg: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Normalize ``image`` using the instance's intensity properties.

        The image is cast to ``self.target_dtype``, clamped to the
        ``percentile_00_5``/``percentile_99_5`` range and then shifted by
        ``mean`` and scaled by ``std``. ``seg`` is accepted but not used.

        :param image: tensor to normalize
        :param seg: ignored
        :return: the normalized tensor
        """
        assert self.intensityproperties is not None, "CTNormalization requires intensity properties"
        props = self.intensityproperties
        center = props['mean']
        spread = props['std']
        clip_low = props['percentile_00_5']
        clip_high = props['percentile_99_5']

        clipped = torch.clamp(image.to(dtype=self.target_dtype), min=clip_low, max=clip_high)
        # Floor the divisor so a (near-)zero std cannot cause a division by zero.
        divisor = max(spread, 1e-8)
        return (clipped - center) / divisor


def create_nonzero_mask(data):
    """
    Build a hole-filled mask of all positions that are nonzero in at least one channel.

    The per-channel nonzero masks are OR-ed together and the result is passed
    through ``scipy.ndimage.binary_fill_holes``. Because scipy operates on host
    NumPy arrays, a mask computed from a torch tensor (on any device) or from an
    array exposing ``.get()`` (e.g. CuPy) is copied to host memory first.

    :param data: channel-first array/tensor of shape (C, X, Y, Z) or (C, X, Y)
    :return: boolean NumPy array without the channel axis; the mask is True where the data is nonzero
        and inside holes enclosed by nonzero regions
    """
    assert data.ndim in (3, 4), "data must have shape (C, X, Y, Z) or shape (C, X, Y)"
    nonzero_mask = data[0] != 0
    for channel in data[1:]:
        nonzero_mask |= channel != 0

    # binary_fill_holes cannot read device memory (a CUDA tensor raises on
    # implicit NumPy conversion), so bring the mask to the host explicitly.
    if isinstance(nonzero_mask, torch.Tensor):
        nonzero_mask = nonzero_mask.cpu().numpy()
    elif hasattr(nonzero_mask, "get"):
        nonzero_mask = nonzero_mask.get()
    return binary_fill_holes(nonzero_mask)


def crop_to_nonzero(data, seg=None, nonzero_label=-1):
    """
    Crop ``data`` (and ``seg``) to the bounding box of the nonzero region of ``data``.

    :param data: channel-first array; the mask is computed with ``create_nonzero_mask``
    :param seg: optional segmentation with a leading channel axis. When given, its cropped
        slice is updated in place: positions that are 0 in ``seg`` and outside the nonzero
        mask are set to ``nonzero_label``. When omitted, a new int8 map is created that is
        0 inside the mask and ``nonzero_label`` outside
    :param nonzero_label: this will be written into the segmentation map
    :return: tuple ``(cropped data, cropped/created seg, bbox)``
    """
    foreground = create_nonzero_mask(data)
    bbox = get_bbox_from_mask(foreground)
    spatial_slices = bounding_box_to_slice(bbox)

    # Add a leading axis so the cropped mask broadcasts against channel-first arrays.
    cropped_mask = foreground[spatial_slices][None]
    # Keep every channel, crop only the spatial axes.
    channel_first_slices = (slice(None),) + spatial_slices
    data = data[channel_first_slices]

    if seg is None:
        seg = np.where(cropped_mask, np.int8(0), np.int8(nonzero_label))
        return data, seg, bbox

    seg = seg[channel_first_slices]
    seg[(seg == 0) & (~cropped_mask)] = nonzero_label
    return data, seg, bbox


def compute_new_shape(old_shape: Union[Tuple[int, ...], List[int], np.ndarray],
                      old_spacing: Union[Tuple[float, ...], List[float], np.ndarray],
                      new_spacing: Union[Tuple[float, ...], List[float], np.ndarray]) -> np.ndarray:
    """Return the shape obtained when resampling from ``old_spacing`` to ``new_spacing``.

    Each axis length is scaled by ``old_spacing / new_spacing`` and rounded to
    the nearest integer with Python's built-in ``round``.

    :param old_shape: current size per axis
    :param old_spacing: current spacing per axis
    :param new_spacing: target spacing per axis
    :return: array with the new size per axis
    """
    assert len(old_spacing) == len(old_shape)
    assert len(old_shape) == len(new_spacing)

    resampled_extents = []
    for extent, spacing_before, spacing_after in zip(old_shape, old_spacing, new_spacing):
        scaled_extent = spacing_before / spacing_after * extent
        resampled_extents.append(int(round(scaled_extent)))
    return np.array(resampled_extents)


def fast_resample_data_or_seg_to_shape(data: Union[torch.Tensor, np.ndarray],
                                       new_shape: Union[Tuple[int, ...], List[int], np.ndarray],
                                       current_spacing: Union[Tuple[float, ...], List[float], np.ndarray],
                                       new_spacing: Union[Tuple[float, ...], List[float], np.ndarray],
                                       is_seg: bool = False,
                                       order: int = 3, order_z: int = 0,
                                       force_separate_z: Union[bool, None] = False,
                                       separate_z_anisotropy_threshold: float = ANISO_THRESHOLD):
    """Resample a channel-first 4D volume to ``new_shape`` using ``torch.nn.functional.interpolate``.

    The interpolation runs on CUDA when available and on the CPU otherwise.
    When the first target axis has length 1 the volume is resampled in 2D
    (that axis is treated as the batch dimension); otherwise it is resampled
    in 3D.

    :param data: array/tensor whose ``data[0].shape`` is the current spatial shape, i.e. laid
        out as (C, X, Y, Z). torch tensors, NumPy arrays and arrays exposing ``.get()``
        (e.g. CuPy) are accepted
    :param new_shape: target spatial shape (three entries)
    :param current_spacing: accepted for interface compatibility; not used
    :param new_spacing: accepted for interface compatibility; not used
    :param is_seg: accepted for interface compatibility; not used
    :param order: interpolation order 0-5, mapped to a torch mode (0 -> nearest,
        1-2 -> (bi/tri)linear, 3-5 -> bicubic in 2D and trilinear in 3D)
    :param order_z: accepted for interface compatibility; not used
    :param force_separate_z: accepted for interface compatibility; not used
    :param separate_z_anisotropy_threshold: accepted for interface compatibility; not used
    :return: float32 tensor of shape (C, *new_shape) on the compute device, or ``data``
        itself (unchanged) when it already has the target shape
    :raises KeyError: if ``order`` is not in 0-5
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    new_shape = np.array(new_shape)
    # Decide the 2D/3D layout once, before the target size is trimmed below, so
    # the layout is undone the same way it was applied.
    single_slice = bool(new_shape[0] == 1)

    linear_mode = "bilinear" if single_slice else "trilinear"
    cubic_mode = "bicubic" if single_slice else "trilinear"
    order_to_mode_map = {
        0: "nearest",
        1: linear_mode,
        2: linear_mode,
        3: cubic_mode,
        4: cubic_mode,
        5: cubic_mode,
    }
    mode = order_to_mode_map[order]
    kwargs = {'mode': mode}
    if mode != "nearest":
        # interpolate() rejects align_corners for the non-interpolating "nearest" mode.
        kwargs['align_corners'] = False

    shape = np.array(data[0].shape)
    if not np.any(shape != new_shape):
        print("no resampling necessary")
        return data

    if isinstance(data, torch.Tensor):
        torch_data = data.float()
    else:
        # Arrays with .get() (CuPy) are copied to the host; anything else is
        # treated as array-like. The float cast mirrors the tensor path.
        host_data = data.get() if hasattr(data, "get") else np.asarray(data)
        torch_data = torch.as_tensor(host_data).float()

    if single_slice:
        # (C, 1, X, Y) -> (1, C, X, Y): the singleton axis becomes the batch axis.
        torch_data = torch_data.transpose(1, 0)
        target_size = tuple(int(extent) for extent in new_shape[1:])
    else:
        # (C, X, Y, Z) -> (1, C, X, Y, Z)
        torch_data = torch_data.unsqueeze(0)
        target_size = tuple(int(extent) for extent in new_shape)

    torch_data = torch.nn.functional.interpolate(torch_data.to(device), target_size, **kwargs)

    if single_slice:
        torch_data = torch_data.transpose(1, 0)
    else:
        torch_data = torch_data.squeeze(0)

    reshaped_final_data = torch_data
    assert reshaped_final_data.ndim == 4, f"reshaped_final_data.shape = {reshaped_final_data.shape}"
    return reshaped_final_data

## Upload Data

In [None]:
from google.colab import drive
from os import mkdir
from os.path import exists
from shutil import copy

drive.mount('/content/drive')
if not exists("inference_test_cases"):
    mkdir("inference_test_cases")
if not exists("inference_test_cases.zip"):
    copy("/content/drive/MyDrive/inference_test_cases.zip", "inference_test_cases.zip")
!unzip inference_test_cases.zip -d inference_test_cases

## Run Inference

In [None]:
from os import listdir
from nnunetv2.imageio.simpleitk_reader_writer import SimpleITKIO
from time import time


def main() -> dict[str, list[float]]:
    """Time cropping, normalization and resampling for every file in ``inference_test_cases``.

    Each file is read with ``SimpleITKIO``, moved to the GPU as float32 and then
    run through ``crop_to_nonzero``, ``CTNormalization.run`` and
    ``fast_resample_data_or_seg_to_shape``. The GPU is synchronized before a
    timer starts and before it stops, so each measurement covers only its own
    stage.

    :return: dict with keys ``"cropping"``, ``"normalization"`` and ``"resampling"``, each a
        list holding one duration in seconds per processed file
    """
    from os.path import join
    from time import perf_counter

    CT_configuration = {
        "transpose_forward": [
            0,
            1,
            2
        ],
        "spacing": [
            2.5,
            0.7958984971046448,
            0.7958984971046448
        ],
        'intensity_prop': {
            "max": 3071.0,
            "mean": 97.29716491699219,
            "median": 118.0,
            "min": -1024.0,
            "percentile_00_5": -958.0,
            "percentile_99_5": 270.0,
            "std": 137.8484649658203
        }
    }

    input_dir = "inference_test_cases"
    normalization = CTNormalization(intensityproperties=CT_configuration['intensity_prop'])
    reader = SimpleITKIO()
    results = {"cropping": [], "normalization": [], "resampling": []}
    for case in sorted(listdir(input_dir)):
        # listdir() yields bare file names; the reader needs the full path including the extension.
        image, properties = reader.read_images([join(input_dir, case)])
        data = torch.from_numpy(image).to(dtype=torch.float32, memory_format=torch.contiguous_format).to('cuda')
        data = data.permute([0, *[i + 1 for i in CT_configuration['transpose_forward']]])
        original_spacing = [properties['spacing'][i] for i in CT_configuration['transpose_forward']]

        # Drain pending GPU work (transfer/permute) so it is not billed to the cropping stage.
        torch.cuda.synchronize()
        t0 = perf_counter()
        data, _, _ = crop_to_nonzero(data)
        torch.cuda.synchronize()
        results["cropping"].append(perf_counter() - t0)

        target_spacing = CT_configuration['spacing']
        if len(target_spacing) < len(data.shape[1:]):
            # Configured spacing has fewer axes than the data: keep the original spacing of the first axis.
            target_spacing = [original_spacing[0]] + target_spacing
        new_shape = compute_new_shape(data.shape[1:], original_spacing, target_spacing)

        torch.cuda.synchronize()
        t0 = perf_counter()
        data = normalization.run(data)
        torch.cuda.synchronize()
        results["normalization"].append(perf_counter() - t0)

        # Restart the timer so the resampling figure does not include normalization time.
        t0 = perf_counter()
        fast_resample_data_or_seg_to_shape(data, new_shape, original_spacing, target_spacing)
        torch.cuda.synchronize()
        results["resampling"].append(perf_counter() - t0)
    return results


# Run the benchmark and print the per-stage timing lists returned by main().
print(main())