## Install Dependencies

In [None]:
!pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu126

In [None]:
!pip3 install nnunetv2 acvl-utils cupy-cuda12x

## Define Functions

In [None]:
from abc import abstractmethod, ABC
from typing import Optional, Dict, Union, Tuple, List

import numpy as np
import cupy as cp
import torch
from nnunetv2.configuration import ANISO_THRESHOLD
from cupyx.scipy import ndimage
from nnunetv2.utilities.helpers import empty_cache
import gc

# Static settings for the CT model that the rest of this notebook reads.
CT_configuration = {
    # Order in which the spatial axes are taken when permuting image data and
    # when re-ordering per-case spacing values (identity permutation here).
    "transpose_forward": [0, 1, 2],
    # One spacing value per spatial axis; used as the "current" spacing of the
    # predictions (presumably the spacing the network works at - confirm).
    "spacing": [2.5, 0.7958984971046448, 0.7958984971046448],
    # Dataset intensity statistics. CTNormalization.run reads 'mean', 'std',
    # 'percentile_00_5' and 'percentile_99_5' from a dict of this form.
    "intensity_prop": {
        "max": 3071.0,
        "mean": 97.29716491699219,
        "median": 118.0,
        "min": -1024.0,
        "percentile_00_5": -958.0,
        "percentile_99_5": 270.0,
        "std": 137.8484649658203,
    },
}

def logit_to_segment(predicted_logits):
    """Collapse per-class scores into a label map.

    The maximum is taken over dimension 0. Where that maximum is at least 0.5
    the index of the winning entry becomes the label; everywhere else the
    label is 0.

    :param predicted_logits: tensor whose dimension 0 enumerates the classes
    :return: integer tensor with dimension 0 removed
    """
    best_score, best_class = predicted_logits.max(dim=0)
    background = torch.tensor(0, device=predicted_logits.device)
    return torch.where(best_score >= 0.5, best_class, background)

def resize_by_chunk(torch_data, new_shape, chunk_size = 300):
    """Resize class scores chunk by chunk and return the resulting label map.

    The input is kept on the CPU and cut along its dimension 2 into
    ``new_shape[0] // chunk_size + 1`` pieces. Each piece is moved to the GPU,
    interpolated trilinearly (``align_corners=True``) to its share of the
    target extent, converted to labels with ``logit_to_segment`` and written
    into a host-side NumPy buffer.

    :param torch_data: tensor indexed as ``[:, :, z, ...]``; the first
        dimension is dropped after interpolation (presumably a 5-D
        ``(1, C, Z, X, Y)`` batch - confirm against the caller)
    :param new_shape: target spatial shape; entry 0 is the chunked axis
    :param chunk_size: divisor that determines the number of chunks
    :return: CPU tensor of shape ``new_shape`` holding the labels (float64,
        because the buffer comes from ``np.zeros``)
    """
    def _chunk_edges(length, parts):
        # parts equally wide pieces; the last one absorbs the remainder.
        width = int(length / parts)
        return [k * width for k in range(parts)] + [length]

    # Rebind the parameter itself (not a new name) so this frame holds no
    # reference to a possible GPU copy before the cache is emptied.
    torch_data = torch_data.detach().cpu()
    torch.cuda.empty_cache()

    n_chunks = new_shape[0] // chunk_size + 1
    label_volume = np.zeros(new_shape)
    src_edges = _chunk_edges(torch_data.shape[2], n_chunks)
    dst_edges = _chunk_edges(new_shape[0], n_chunks)

    for k in range(n_chunks):
        src_lo, src_hi = src_edges[k], src_edges[k + 1]
        dst_lo, dst_hi = dst_edges[k], dst_edges[k + 1]
        out_size = [dst_hi - dst_lo, *new_shape[1:]]

        chunk = torch_data[:, :, src_lo:src_hi].cuda()
        chunk = torch.nn.functional.interpolate(chunk, mode='trilinear', size=out_size, align_corners=True)[0]
        label_volume[dst_lo:dst_hi] = logit_to_segment(chunk).cpu()

        # Drop the GPU chunk before the next one is uploaded.
        del chunk
        torch.cuda.empty_cache()

    return torch.from_numpy(label_volume)

def fast_resample_logit_to_shape(torch_data: Union[torch.Tensor, np.ndarray],
                                  new_shape: Union[Tuple[int, ...], List[int], np.ndarray],
                                  current_spacing: Union[Tuple[float, ...], List[float], np.ndarray],
                                  new_spacing: Union[Tuple[float, ...], List[float], np.ndarray],
                                  is_seg: bool = False,
                                  order: int = 3, order_z: int = 0,
                                  force_separate_z: Union[bool, None] = False,
                                  separate_z_anisotropy_threshold: float = ANISO_THRESHOLD):
    """Resample channel-first class scores to ``new_shape`` on the GPU.

    ``current_spacing``, ``new_spacing``, ``is_seg``, ``order_z``,
    ``force_separate_z`` and ``separate_z_anisotropy_threshold`` are accepted
    for signature compatibility but are never read.

    :param torch_data: scores with the channel axis first; ``torch_data[0].shape``
        is compared against ``new_shape``
    :param new_shape: target spatial shape
    :param order: key into the order -> interpolation-mode table
        (0 nearest, 1-2 (tri/bi)linear, 3-5 trilinear or bicubic)
    :return: the input unchanged when the shape already matches; otherwise the
        resampled scores with the channel axis first. NOTE: when the first
        target axis is 600 or larger the work is delegated to
        ``resize_by_chunk``, which returns a label map of shape ``new_shape``
        on the CPU instead of scores.
    """
    use_gpu = True
    device = torch.device("cuda" if use_gpu else "cpu")
    order_to_mode_map = {
        0: "nearest",
        1: "trilinear" if new_shape[0] > 1 else "bilinear",
        2: "trilinear" if new_shape[0] > 1 else "bilinear",
        3: "trilinear" if new_shape[0] > 1 else "bicubic",
        4: "trilinear" if new_shape[0] > 1 else "bicubic",
        5: "trilinear" if new_shape[0] > 1 else "bicubic",
    }
    resize_fn = torch.nn.functional.interpolate
    kwargs = {
        'mode': order_to_mode_map[order],
        'align_corners': False,
    }
    shape = np.array(torch_data[0].shape)
    new_shape = np.array(new_shape)
    if not np.any(shape != new_shape):
        print("no resampling necessary")
        return torch_data

    # Decide once, on the full target shape. Re-testing new_shape[0] after it
    # has been truncated would look at the wrong axis.
    single_slice = new_shape[0] == 1
    if single_slice:
        # Use the length-1 axis as batch axis so a 2-D interpolation applies.
        torch_data = torch_data.transpose(1, 0)
        target_size = tuple(new_shape[1:])
    else:
        torch_data = torch_data.unsqueeze(0)
        target_size = tuple(new_shape)
    gc.collect()
    empty_cache(device)

    if single_slice or new_shape[0] < 600:
        torch_data = resize_fn(torch_data.to(device), target_size, **kwargs)
        # Undo the layout change made above.
        if single_slice:
            torch_data = torch_data.transpose(1, 0)
        else:
            torch_data = torch_data.squeeze(0)
    else:
        # resize_by_chunk moves its input to the CPU first and uploads one
        # chunk at a time, so the full tensor is not pushed to the GPU here.
        torch_data = resize_by_chunk(torch_data, target_size)
    return torch_data

def convert_predicted_logits_to_segmentation_with_correct_shape(predicted_logits: Union[torch.Tensor, np.ndarray],
                                                                properties_dict: dict,
                                                                ):
    """Resample class scores to the cropped image shape and turn them into labels.

    :param predicted_logits: channel-first class scores
    :param properties_dict: must provide ``'spacing'`` and
        ``'shape_after_cropping_and_before_resampling'``
    :return: label tensor with the spatial shape stored under
        ``'shape_after_cropping_and_before_resampling'``
    """
    target_shape = properties_dict['shape_after_cropping_and_before_resampling']
    spacing_transposed = [properties_dict['spacing'][i] for i in CT_configuration['transpose_forward']]
    if len(CT_configuration['spacing']) == len(target_shape):
        current_spacing = CT_configuration['spacing']
    else:
        # Configured spacing has one axis fewer than the data: prepend the
        # case's own spacing for the leading axis.
        current_spacing = [spacing_transposed[0], *CT_configuration['spacing']]

    resampled = fast_resample_logit_to_shape(predicted_logits,
                                             target_shape,
                                             current_spacing,
                                             spacing_transposed)

    # For large targets fast_resample_logit_to_shape goes through
    # resize_by_chunk and already returns labels of exactly the target shape
    # (no channel axis). Reducing over dimension 0 again would collapse the
    # first spatial axis, so return the labels directly, as integers.
    if tuple(resampled.shape) == tuple(int(s) for s in target_shape):
        return resampled.long()

    return logit_to_segment(resampled)


class ImageNormalization(ABC):
    """Abstract base for intensity normalization schemes.

    Stores the configuration shared by all schemes; subclasses implement
    ``run``.
    """
    # Subclasses override this with a bool; the base class leaves it undefined.
    leaves_pixels_outside_mask_at_zero_if_use_mask_for_norm_is_true = None

    def __init__(self, use_mask_for_norm: Optional[bool] = None, intensityproperties: Optional[Dict] = None,
                 target_dtype: torch.dtype = torch.float32):
        """Validate and store the normalization settings.

        :param use_mask_for_norm: ``None`` or a bool
        :param intensityproperties: ``None`` or a dict of intensity statistics
        :param target_dtype: dtype stored for use by ``run`` implementations
        """
        assert isinstance(use_mask_for_norm, bool) or use_mask_for_norm is None
        assert intensityproperties is None or isinstance(intensityproperties, dict)
        self.use_mask_for_norm = use_mask_for_norm
        self.intensityproperties = intensityproperties
        self.target_dtype = target_dtype

    @abstractmethod
    def run(self, image: torch.Tensor, seg: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Return the normalized image; must be provided by subclasses."""
        raise NotImplementedError


class CTNormalization(ImageNormalization):
    """Clip to the stored percentile range, then standardize with mean and std."""
    leaves_pixels_outside_mask_at_zero_if_use_mask_for_norm_is_true = False

    def run(self, image: torch.Tensor, seg: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Normalize ``image`` using the stored intensity properties.

        The image is cast to ``self.target_dtype``, clamped to
        ``[percentile_00_5, percentile_99_5]``, shifted by ``mean`` and divided
        by ``std`` (floored at 1e-8 to avoid dividing by zero). ``seg`` is
        not used.
        """
        assert self.intensityproperties is not None, "CTNormalization requires intensity properties"
        props = self.intensityproperties
        mean, std = props['mean'], props['std']
        low, high = props['percentile_00_5'], props['percentile_99_5']

        clipped = torch.clamp(image.to(dtype=self.target_dtype), low, high)
        return (clipped - mean) / max(std, 1e-8)


def create_nonzero_mask(data):
    """Build a hole-filled mask of every position that is nonzero in any channel.

    :param data: channel-first array with 3 or 4 dimensions
    :return: mask without the channel axis; True where some channel is
        nonzero, with enclosed holes filled by ``ndimage.binary_fill_holes``
    """
    assert data.ndim in (3, 4), "data must have shape (C, X, Y, Z) or shape (C, X, Y)"
    foreground = data[0] != 0
    # OR the remaining channels into the first channel's mask, in place.
    for channel in data[1:]:
        foreground |= channel != 0
    return ndimage.binary_fill_holes(foreground)


def get_bbox_from_mask(mask: cp.ndarray) -> List[List[int]]:
    """
    Return the tight bounding box of the True region of ``mask``.

    ALL bounding boxes in acvl_utils and nnU-Netv2 are half open interval [start, end)!
    - Alignment with Python Slicing
    - Ease of Subdivision
    - Consistency in Multi-Dimensional Arrays
    - Precedent in Computer Graphics

    Each axis is handled with a single reduction over the other axes and one
    transfer of the resulting 1-D occupancy vector to the host. This avoids a
    device-to-host synchronization for every slice that is probed.

    Args:
        mask: mask array, normally a 3D CuPy array on the GPU. NumPy arrays
            and other ranks are handled as well.

    Returns:
        List[List[int]]: one ``[start, end)`` pair of Python ints per axis,
        e.g. [[minz, maxz], [minx, maxx], [miny, maxy]] for a 3D mask. An axis
        along which nothing is set keeps its full extent ``[0, size]``.
    """
    bbox = []
    for axis, extent in enumerate(mask.shape):
        other_axes = tuple(a for a in range(mask.ndim) if a != axis)
        occupied = mask.any(axis=other_axes)

        # CuPy arrays expose .get() for the device-to-host copy; anything else
        # is assumed to be host memory already.
        fetch = getattr(occupied, "get", None)
        occupied = fetch() if callable(fetch) else np.asarray(occupied)

        hits = np.flatnonzero(occupied)
        if hits.size:
            bbox.append([int(hits[0]), int(hits[-1]) + 1])
        else:
            bbox.append([0, int(extent)])
    return bbox

def bounding_box_to_slice(bounding_box: List[List[int]]):
    """
    Turn a bounding box into a tuple of slices, one per axis.

    ALL bounding boxes in acvl_utils and nnU-Netv2 are half open interval
    [start, end), so each entry maps directly onto ``slice(*entry)``.
    """
    return tuple(slice(*bounds) for bounds in bounding_box)


def crop_to_nonzero(data, seg=None, nonzero_label=-1):
    """
    Crop ``data`` (and ``seg``) to the bounding box of the nonzero region.

    :param data: channel-first array; the box is computed from its nonzero mask
    :param seg: optional array cropped with the same box. Entries that are 0
        and lie outside the nonzero mask are overwritten with ``nonzero_label``.
        When omitted, a new map is created holding 0 inside the mask and
        ``nonzero_label`` outside.
    :param nonzero_label: this will be written into the segmentation map
    :return: cropped data, segmentation map, bounding box
    """
    mask = create_nonzero_mask(data)
    bbox = get_bbox_from_mask(mask)
    spatial = bounding_box_to_slice(bbox)
    # Crop the mask and give it a leading channel axis.
    mask = mask[spatial][None]

    # Keep every channel, crop the spatial axes.
    with_channels = (slice(None), *spatial)
    data = data[with_channels]
    if seg is None:
        seg = np.where(mask, np.int8(0), np.int8(nonzero_label))
    else:
        seg = seg[with_channels]
        seg[(seg == 0) & (~mask)] = nonzero_label
    return data, seg, bbox


def compute_new_shape(old_shape: Union[Tuple[int, ...], List[int], np.ndarray],
                      old_spacing: Union[Tuple[float, ...], List[float], np.ndarray],
                      new_spacing: Union[Tuple[float, ...], List[float], np.ndarray]) -> np.ndarray:
    """Return the shape an array takes when resampled from ``old_spacing`` to ``new_spacing``.

    Each axis length is scaled by ``old_spacing / new_spacing`` and rounded to
    the nearest integer. All three arguments must have the same length.
    """
    assert len(old_spacing) == len(old_shape)
    assert len(old_shape) == len(new_spacing)
    scaled = []
    for current, target, extent in zip(old_spacing, new_spacing, old_shape):
        scaled.append(int(round(current / target * extent)))
    return np.array(scaled)


def fast_resample_data_or_seg_to_shape(data: Union[torch.Tensor, np.ndarray],
                                       new_shape: Union[Tuple[int, ...], List[int], np.ndarray],
                                       current_spacing: Union[Tuple[float, ...], List[float], np.ndarray],
                                       new_spacing: Union[Tuple[float, ...], List[float], np.ndarray],
                                       is_seg: bool = False,
                                       order: int = 3, order_z: int = 0,
                                       force_separate_z: Union[bool, None] = False,
                                       separate_z_anisotropy_threshold: float = ANISO_THRESHOLD):
    """Resample channel-first data to ``new_shape`` with ``torch`` interpolation on the GPU.

    ``current_spacing``, ``new_spacing``, ``is_seg``, ``order_z``,
    ``force_separate_z`` and ``separate_z_anisotropy_threshold`` are accepted
    for signature compatibility but are never read.

    :param data: channel-first torch tensor, CuPy array or NumPy array;
        ``data[0].shape`` is compared against ``new_shape``
    :param new_shape: target spatial shape
    :param order: key into the order -> interpolation-mode table
        (0 nearest, 1-2 (tri/bi)linear, 3-5 trilinear or bicubic)
    :return: ``data`` itself when the shape already matches; otherwise a 4-D
        float32 tensor on the GPU with the channel axis first
    """
    use_gpu = True
    device = torch.device("cuda" if use_gpu else "cpu")
    order_to_mode_map = {
        0: "nearest",
        1: "trilinear" if new_shape[0] > 1 else "bilinear",
        2: "trilinear" if new_shape[0] > 1 else "bilinear",
        3: "trilinear" if new_shape[0] > 1 else "bicubic",
        4: "trilinear" if new_shape[0] > 1 else "bicubic",
        5: "trilinear" if new_shape[0] > 1 else "bicubic",
    }
    resize_fn = torch.nn.functional.interpolate
    kwargs = {
        'mode': order_to_mode_map[order],
        'align_corners': False
    }
    shape = np.array(data[0].shape)
    new_shape = np.array(new_shape)
    if not np.any(shape != new_shape):
        print("no resampling necessary")
        return data

    if isinstance(data, torch.Tensor):
        torch_data = data.float()
    else:
        # CuPy arrays need an explicit .get() to reach host memory; NumPy
        # arrays are used as they are. Cast to float like the tensor branch,
        # since the interpolating modes need a floating dtype.
        host_data = data.get() if hasattr(data, "get") else data
        torch_data = torch.as_tensor(host_data).float()

    # Decide once, on the full target shape. Re-testing new_shape[0] after it
    # has been truncated would look at the wrong axis and leave a 3-D result.
    single_slice = new_shape[0] == 1
    if single_slice:
        # Use the length-1 axis as batch axis so a 2-D interpolation applies.
        torch_data = torch_data.transpose(1, 0)
        target_size = tuple(new_shape[1:])
    else:
        torch_data = torch_data.unsqueeze(0)
        target_size = tuple(new_shape)

    torch_data = resize_fn(torch_data.to(device), target_size, **kwargs)

    # Undo the layout change made above.
    if single_slice:
        torch_data = torch_data.transpose(1, 0)
    else:
        torch_data = torch_data.squeeze(0)

    reshaped_final_data = torch_data
    assert reshaped_final_data.ndim == 4, f"reshaped_final_data.shape = {reshaped_final_data.shape}"
    return reshaped_final_data

## Run Inference

In [None]:
from os import listdir
from time import time
from torch import load
from nnunetv2.imageio.simpleitk_reader_writer import SimpleITKIO


def main() -> dict[str, list[float]]:
    """Time logit resampling and logit-to-label conversion for every test case.

    For each entry of ``inference_test_cases`` the image is read and cropped
    to its nonzero region, the matching stored logits are loaded, and two
    steps are timed separately:

    * ``"resampling"``: resampling the logits to the cropped image shape
    * ``"conversion"``: converting the resampled logits to a label map

    :return: dict with one wall-clock duration (seconds) per case under each
        of the keys ``"resampling"`` and ``"conversion"``
    """
    # listdir order is arbitrary; sort so every call visits the cases in the
    # same order (add_up pairs timings across batches purely by position).
    cases = sorted(listdir("inference_test_cases"))
    results = {"resampling": [], "conversion": []}
    reader = SimpleITKIO()
    forward_axes = CT_configuration['transpose_forward']
    for case in cases:
        # Everything before the first '.' of the directory entry.
        # NOTE(review): the image path below carries no file extension -
        # confirm this matches how the test cases are stored.
        stem = case[:case.find('.')]
        image, properties = reader.read_images([f"inference_test_cases/{stem}"])
        image = torch.from_numpy(image).to(dtype=torch.float32, memory_format=torch.contiguous_format).to('cuda')
        data = image.permute([0, *[i + 1 for i in forward_axes]])
        properties['shape_before_cropping'] = data.shape[1:]
        # NOTE: torch.load unpickles the file; only use with trusted fixtures.
        logit = load(f"inference_test_logit/{stem}.pt".replace("_0000", "")).to("cuda")
        data, _, bbox = crop_to_nonzero(data, None)
        torch.cuda.synchronize()
        properties['bbox_used_for_cropping'] = bbox
        target_shape = data.shape[1:]
        properties['shape_after_cropping_and_before_resampling'] = target_shape
        spacing_transposed = [properties['spacing'][i] for i in forward_axes]
        if len(CT_configuration['spacing']) == len(target_shape):
            current_spacing = CT_configuration['spacing']
        else:
            current_spacing = [spacing_transposed[0], *CT_configuration['spacing']]

        t0 = time()
        predicted_logit = fast_resample_data_or_seg_to_shape(logit, target_shape,
                                                             current_spacing, spacing_transposed)
        # CUDA kernels run asynchronously; wait for them before reading the clock.
        torch.cuda.synchronize()
        results["resampling"].append(time() - t0)

        t0 = time()
        # The conversion needs the case properties (target shape and spacing).
        convert_predicted_logits_to_segmentation_with_correct_shape(predicted_logit, properties)
        torch.cuda.synchronize()
        results["conversion"].append(time() - t0)
    return results


def add_up(table: dict[str, list[float]], entry: str, n: int, num_cases: int) -> list[float]:
    """Average per-case values over ``n`` batches stored back to back.

    ``table[entry]`` is read as ``n`` consecutive batches of ``num_cases``
    values, so the value of case ``i`` in batch ``j`` sits at index
    ``i + j * num_cases``.

    :return: list of ``num_cases`` means, one per case
    """
    averages = []
    for case_idx in range(num_cases):
        total = 0
        # Visit this case's slot in each of the n batches.
        for position in range(case_idx, case_idx + n * num_cases, num_cases):
            total += table[entry][position]
        averages.append(total / n)
    return averages


# Benchmark driver: run main() _n times and print, per case, the mean duration
# of the resampling step and of the conversion step.
_final = {"resampling": [], "conversion": []}
_n = 10
_num_cases = 0
for _b in range(_n):
    print(f"Processing batch {_b}")
    _r = main()
    # main() returns one timing per case, so take the case count from the
    # data instead of hard-coding it to the size of one particular test set.
    _num_cases = len(_r["resampling"])
    _final["resampling"] += _r["resampling"]
    _final["conversion"] += _r["conversion"]
print(add_up(_final, "resampling", _n, _num_cases))
print(add_up(_final, "conversion", _n, _num_cases))