<a href="https://colab.research.google.com/github/ProjectNeura/bwlab_eval/blob/main/postprocess_eval_s.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Install Dependencies

In [1]:
!pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu126

Looking in indexes: https://download.pytorch.org/whl/cu126
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch)
  Downloading https://download.pytorch.org/whl/cu126/nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch)
  Downloading https://download.pytorch.org/whl/cu126/nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch)
  Downloading https://download.pytorch.org/whl/cu126/nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch)
  Downloading https://download.pytorch.org/whl/cu126/nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch)
  Downloading https://download.pytorch.org/whl/cu126/nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.met

In [2]:
!pip3 install nnunetv2 acvl-utils

Collecting nnunetv2
  Downloading nnunetv2-2.6.0.tar.gz (206 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/206.3 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m206.3/206.3 kB[0m [31m11.7 MB/s[0m eta [36m0:00:00[0m
[?25h  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Collecting acvl-utils
  Downloading acvl_utils-0.2.5.tar.gz (29 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting dynamic-network-architectures<0.4,>=0.3.1 (from nnunetv2)
  Downloading dynamic_network_architectures-0.3.1.tar.gz (20 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting dicom2nifti (from nnunetv2)
  Downloading dicom2nifti-2.6.0-py3-none-any.whl.metadata (1.5 kB)
Collecting batchgenerators>=0.25.1 (from nnunetv2)
  Downloading batchgenerators-0.25.1.tar.gz (76 kB)
[2

## Define Functions

In [1]:
from collections import OrderedDict
from copy import deepcopy
from typing import Union, Tuple, List

import numpy as np
import pandas as pd
import torch
from batchgenerators.augmentations.utils import resize_segmentation
from nnunetv2.configuration import ANISO_THRESHOLD
from scipy.ndimage import map_coordinates
from scipy.ndimage import binary_fill_holes
from acvl_utils.cropping_and_padding.bounding_boxes import get_bbox_from_mask, bounding_box_to_slice
from skimage.transform import resize


def create_nonzero_mask(data):
    """

    :param data:
    :return: the mask is True where the data is nonzero
    """
    assert data.ndim in (3, 4), "data must have shape (C, X, Y, Z) or shape (C, X, Y)"
    nonzero_mask = data[0] != 0
    for c in range(1, data.shape[0]):
        nonzero_mask |= data[c] != 0
    return binary_fill_holes(nonzero_mask)


def crop_to_nonzero(data, seg=None, nonzero_label=-1):
    """

    :param data:
    :param seg:
    :param nonzero_label: this will be written into the segmentation map
    :return:
    """
    nonzero_mask = create_nonzero_mask(data)
    bbox = get_bbox_from_mask(nonzero_mask)
    slicer = bounding_box_to_slice(bbox)
    nonzero_mask = nonzero_mask[slicer][None]

    slicer = (slice(None), ) + slicer
    data = data[slicer]
    if seg is not None:
        seg = seg[slicer]
        seg[(seg == 0) & (~nonzero_mask)] = nonzero_label
    else:
        seg = np.where(nonzero_mask, np.int8(0), np.int8(nonzero_label))
    return data, seg, bbox


def get_do_separate_z(spacing: Union[Tuple[float, ...], List[float], np.ndarray], anisotropy_threshold=ANISO_THRESHOLD):
    do_separate_z = (np.max(spacing) / np.min(spacing)) > anisotropy_threshold
    return do_separate_z


def get_lowres_axis(new_spacing: Union[Tuple[float, ...], List[float], np.ndarray]):
    axis = np.where(max(new_spacing) / np.array(new_spacing) == 1)[0]  # find which axis is anisotropic
    return axis


def determine_do_sep_z_and_axis(
        force_separate_z: bool,
        current_spacing,
        new_spacing,
        separate_z_anisotropy_threshold: float = ANISO_THRESHOLD) -> Tuple[bool, Union[int, None]]:
    if force_separate_z is not None:
        do_separate_z = force_separate_z
        if force_separate_z:
            axis = get_lowres_axis(current_spacing)
        else:
            axis = None
    else:
        if get_do_separate_z(current_spacing, separate_z_anisotropy_threshold):
            do_separate_z = True
            axis = get_lowres_axis(current_spacing)
        elif get_do_separate_z(new_spacing, separate_z_anisotropy_threshold):
            do_separate_z = True
            axis = get_lowres_axis(new_spacing)
        else:
            do_separate_z = False
            axis = None

    if axis is not None:
        if len(axis) == 3:
            do_separate_z = False
            axis = None
        elif len(axis) == 2:
            # this happens for spacings like (0.24, 1.25, 1.25) for example. In that case we do not want to resample
            # separately in the out of plane axis
            do_separate_z = False
            axis = None
        else:
            axis = axis[0]
    return do_separate_z, axis


def resample_data_or_seg(data: np.ndarray, new_shape: Union[Tuple[float, ...], List[float], np.ndarray],
                         is_seg: bool = False, axis: Union[None, int] = None, order: int = 3,
                         do_separate_z: bool = False, order_z: int = 0, dtype_out=None):
    """
    separate_z=True will resample with order 0 along z
    :param data:
    :param new_shape:
    :param is_seg:
    :param axis:
    :param order:
    :param do_separate_z:
    :param order_z: only applies if do_separate_z is True
    :return:
    """
    assert data.ndim == 4, "data must be (c, x, y, z)"
    assert len(new_shape) == data.ndim - 1

    if is_seg:
        resize_fn = resize_segmentation
        kwargs = OrderedDict()
    else:
        resize_fn = resize
        kwargs = {'mode': 'edge', 'anti_aliasing': False}
    shape = np.array(data[0].shape)
    new_shape = np.array(new_shape)
    if dtype_out is None:
        dtype_out = data.dtype
    reshaped_final = np.zeros((data.shape[0], *new_shape), dtype=dtype_out)
    if np.any(shape != new_shape):
        data = data.astype(float, copy=False)
        if do_separate_z:
            # print("separate z, order in z is", order_z, "order inplane is", order)
            assert axis is not None, 'If do_separate_z, we need to know what axis is anisotropic'
            if axis == 0:
                new_shape_2d = new_shape[1:]
            elif axis == 1:
                new_shape_2d = new_shape[[0, 2]]
            else:
                new_shape_2d = new_shape[:-1]

            for c in range(data.shape[0]):
                tmp = deepcopy(new_shape)
                tmp[axis] = shape[axis]
                reshaped_here = np.zeros(tmp)
                for slice_id in range(shape[axis]):
                    if axis == 0:
                        reshaped_here[slice_id] = resize_fn(data[c, slice_id], new_shape_2d, order, **kwargs)
                    elif axis == 1:
                        reshaped_here[:, slice_id] = resize_fn(data[c, :, slice_id], new_shape_2d, order, **kwargs)
                    else:
                        reshaped_here[:, :, slice_id] = resize_fn(data[c, :, :, slice_id], new_shape_2d, order,
                                                                  **kwargs)
                if shape[axis] != new_shape[axis]:

                    # The following few lines are blatantly copied and modified from sklearn's resize()
                    rows, cols, dim = new_shape[0], new_shape[1], new_shape[2]
                    orig_rows, orig_cols, orig_dim = reshaped_here.shape

                    # align_corners=False
                    row_scale = float(orig_rows) / rows
                    col_scale = float(orig_cols) / cols
                    dim_scale = float(orig_dim) / dim

                    map_rows, map_cols, map_dims = np.mgrid[:rows, :cols, :dim]
                    map_rows = row_scale * (map_rows + 0.5) - 0.5
                    map_cols = col_scale * (map_cols + 0.5) - 0.5
                    map_dims = dim_scale * (map_dims + 0.5) - 0.5

                    coord_map = np.array([map_rows, map_cols, map_dims])
                    if not is_seg or order_z == 0:
                        reshaped_final[c] = map_coordinates(reshaped_here, coord_map, order=order_z, mode='nearest')[
                            None]
                    else:
                        unique_labels = np.sort(pd.unique(reshaped_here.ravel()))  # np.unique(reshaped_data)
                        for i, cl in enumerate(unique_labels):
                            reshaped_final[c][np.round(
                                map_coordinates((reshaped_here == cl).astype(float), coord_map, order=order_z,
                                                mode='nearest')) > 0.5] = cl
                else:
                    reshaped_final[c] = reshaped_here
        else:
            # print("no separate z, order", order)
            for c in range(data.shape[0]):
                reshaped_final[c] = resize_fn(data[c], new_shape, order, **kwargs)
        return reshaped_final
    else:
        # print("no resampling necessary")
        return data


def resample_data_or_seg_to_shape(data: Union[torch.Tensor, np.ndarray],
                                  new_shape: Union[Tuple[int, ...], List[int], np.ndarray],
                                  current_spacing: Union[Tuple[float, ...], List[float], np.ndarray],
                                  new_spacing: Union[Tuple[float, ...], List[float], np.ndarray],
                                  is_seg: bool = False,
                                  order: int = 3, order_z: int = 0,
                                  force_separate_z: Union[bool, None] = False,
                                  separate_z_anisotropy_threshold: float = ANISO_THRESHOLD):
    """
    needed for segmentation export. Stupid, I know
    """
    if isinstance(data, torch.Tensor):
        data = data.numpy()

    do_separate_z, axis = determine_do_sep_z_and_axis(force_separate_z, current_spacing, new_spacing,
                                                      separate_z_anisotropy_threshold)

    if data is not None:
        assert data.ndim == 4, "data must be c x y z"

    data_reshaped = resample_data_or_seg(data, new_shape, is_seg, axis, order, do_separate_z, order_z=order_z)
    return data_reshaped


@torch.inference_mode()
def convert_probabilities_to_segmentation(predicted_probabilities: Union[np.ndarray, torch.Tensor]) -> \
        Union[np.ndarray, torch.Tensor]:
    """
    assumes that inference_nonlinearity was already applied!

    predicted_probabilities has to have shape (c, x, y(, z)) where c is the number of classes/regions
    """
    if not isinstance(predicted_probabilities, (np.ndarray, torch.Tensor)):
        raise RuntimeError(f"Unexpected input type. Expected np.ndarray or torch.Tensor,"
                           f" got {type(predicted_probabilities)}")

    # numpy is faster than torch. :facepalm:
    is_numpy = isinstance(predicted_probabilities, np.ndarray)
    if not is_numpy:
        predicted_probabilities = predicted_probabilities.numpy()
    segmentation = predicted_probabilities.argmax(0)
    if not is_numpy:
        segmentation = torch.from_numpy(segmentation)

    return segmentation


@torch.inference_mode()
def convert_logits_to_segmentation(predicted_logits: Union[np.ndarray, torch.Tensor]) -> \
        Union[np.ndarray, torch.Tensor]:
    input_is_numpy = isinstance(predicted_logits, np.ndarray)
    # we can skip this step if we do not have region. Argmax is the same between logits or probabilities

    probabilities = predicted_logits
    if input_is_numpy and isinstance(probabilities, torch.Tensor):
        probabilities = probabilities.cpu().numpy()
    return convert_probabilities_to_segmentation(probabilities)

## Upload Data

In [2]:
from google.colab import drive
from os import mkdir
from os.path import exists
from shutil import copy

drive.mount('/content/drive')
!rm -r inference_test_cases
mkdir("inference_test_cases")
if not exists("inference_test_cases.zip"):
    copy("/content/drive/MyDrive/inference_test_cases.zip", "inference_test_cases.zip")
!unzip inference_test_cases.zip -d inference_test_cases

!rm -r inference_test_logit
mkdir("inference_test_logit")
if not exists("inference_test_logit.zip"):
    copy("/content/drive/MyDrive/inference_test_logit.zip", "inference_test_logit.zip")
!unzip inference_test_logit.zip -d inference_test_logit

Mounted at /content/drive
rm: cannot remove 'inference_test_cases': No such file or directory
Archive:  inference_test_cases.zip
 extracting: inference_test_cases/FLARETs_0001_0000.nii.gz  
 extracting: inference_test_cases/FLARETs_0015_0000.nii.gz  
 extracting: inference_test_cases/FLARETs_0018_0000.nii.gz  
 extracting: inference_test_cases/FLARETs_0036_0000.nii.gz  
 extracting: inference_test_cases/FLARETs_0046_0000.nii.gz  
rm: cannot remove 'inference_test_logit': No such file or directory
Archive:  inference_test_logit.zip
  inflating: inference_test_logit/FLARETs_0001.pt  
  inflating: inference_test_logit/FLARETs_0015.pt  
  inflating: inference_test_logit/FLARETs_0018.pt  
  inflating: inference_test_logit/FLARETs_0036.pt  
  inflating: inference_test_logit/FLARETs_0046.pt  


In [3]:
!ls inference_test_logit -l --block-size=M

total 5828M
-rw-r--r-- 1 root root 1512M Mar 26 15:21 FLARETs_0001.pt
-rw-r--r-- 1 root root 1385M Mar 26 15:15 FLARETs_0015.pt
-rw-r--r-- 1 root root 1483M Mar 26 15:21 FLARETs_0018.pt
-rw-r--r-- 1 root root  638M Mar 26 14:29 FLARETs_0036.pt
-rw-r--r-- 1 root root  811M Mar 26 14:48 FLARETs_0046.pt


In [4]:
!rm inference_test_cases/FLARETs_0001_0000.nii.gz inference_test_logit/FLARETs_0001.pt inference_test_cases/FLARETs_0018_0000.nii.gz inference_test_logit/FLARETs_0018.pt inference_test_cases/FLARETs_0015_0000.nii.gz inference_test_logit/FLARETs_0015.pt

## Run Inference

In [5]:
from os import listdir
from time import time
from torch import load
from nnunetv2.imageio.simpleitk_reader_writer import SimpleITKIO


def main() -> dict[str, list[float]]:
    CT_configuration = {
        "transpose_forward": [
            0,
            1,
            2
        ],
        "spacing": [
            2.5,
            0.7958984971046448,
            0.7958984971046448
        ],
        'intensity_prop': {
            "max": 3071.0,
            "mean": 97.29716491699219,
            "median": 118.0,
            "min": -1024.0,
            "percentile_00_5": -958.0,
            "percentile_99_5": 270.0,
            "std": 137.8484649658203
        }}

    cases = listdir("inference_test_cases")
    results = {"resampling": [], "conversion": []}
    for case in cases:
        image, properties = SimpleITKIO().read_images([f"inference_test_cases/{case[:case.find('.')]}"])
        logit = load(f"inference_test_logit/{case[:case.find('.')]}.pt".replace("_0000", "")).numpy()
        data = image.astype(np.float32)
        data = data.transpose([0, *[i + 1 for i in CT_configuration['transpose_forward']]])
        shape_before_cropping = data.shape[1:]
        properties['shape_before_cropping'] = shape_before_cropping
        data, seg, bbox = crop_to_nonzero(data, None)
        properties['bbox_used_for_cropping'] = bbox
        properties['shape_after_cropping_and_before_resampling'] = data.shape[1:]
        spacing_transposed = [properties['spacing'][i] for i in CT_configuration['transpose_forward']]
        current_spacing = CT_configuration['spacing'] if len(CT_configuration['spacing']) == len(
            properties['shape_after_cropping_and_before_resampling']) else [spacing_transposed[0],
                                                                            *CT_configuration['spacing']]
        t0 = time()
        predicted_logit = resample_data_or_seg_to_shape(logit, properties['shape_after_cropping_and_before_resampling'],
                                                        current_spacing, [properties['spacing'][i] for i in
                                                                          CT_configuration['transpose_forward']])
        results["resampling"].append(time() - t0)
        t0 = time()
        convert_logits_to_segmentation(predicted_logit)
        results["conversion"].append(time() - t0)
    return results


def add_up(table: dict[str, list[float]], entry: str, n: int, num_cases: int) -> list[float]:
    r = []
    for i in range(num_cases):
        s = 0
        for j in range(n):
            s += table[entry][i + j * num_cases]
        r.append(s / n)
    return r


_final = {"resampling": [], "conversion": []}
_n = 10
_num_cases = 2
for _b in range(_n):
    print(f"Porcessing batch {_b}")
    _r = main()
    _final["resampling"] += _r["resampling"]
    _final["conversion"] += _r["conversion"]
print(add_up(_final, "resampling", _n, _num_cases))
print(add_up(_final, "conversion", _n, _num_cases))

Porcessing batch 0
Porcessing batch 1
Porcessing batch 2
Porcessing batch 3
Porcessing batch 4
Porcessing batch 5
Porcessing batch 6
Porcessing batch 7
Porcessing batch 8
Porcessing batch 9
[584.4420743227005, 217.17762620449065]
[7.661114120483399, 2.7208886623382567]
