Import libraries

In [None]:
import os
import random

from collections import OrderedDict
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from matplotlib import pyplot as plt
from matplotlib.backends.backend_agg import FigureCanvasAgg
from tqdm import tqdm
from skimage.exposure import rescale_intensity
from skimage.io import imread, imsave
from skimage.transform import resize, rescale, rotate
from torch.utils.data import Dataset
from torchvision.transforms import Compose

Set up the helper functions for later use. Each function includes a clear explanation in case further changes are needed.

In [None]:
def _nonzero_extent(foreground, axis):
    """Return the slice along *axis* spanning every index that holds a True value."""
    # Collapse all other axes so only the requested axis remains.
    other_axes = tuple(a for a in range(foreground.ndim) if a != axis)
    indices = np.nonzero(np.any(foreground, axis=other_axes))[0]
    return slice(indices.min(), indices.max() + 1)


def crop_sample(x):
    """Crop a volume and its mask to the bounding box of the volume's bright voxels.

    Voxels of ``volume`` below 10% of its global maximum are set to 0, then
    both arrays are cropped along their first three axes to the smallest box
    that still contains every remaining non-zero voxel of ``volume``.

    Args:
        x: ``(volume, mask)`` pair. Both arrays are indexed on their first
            three axes with the same ranges; ``volume`` may carry further
            trailing axes (e.g. channels), which are kept whole.

    Returns:
        ``(volume, mask)`` cropped to the same ranges on axes 0-2. The returned
        volume has the 10% threshold applied. If no voxel survives the
        threshold, both arrays are returned uncropped (thresholded volume,
        original mask) instead of raising.

    The caller's ``volume`` array is not modified.
    """
    volume, mask = x

    # Threshold a copy: writing into the caller's array would be a hidden
    # side effect of what looks like a pure cropping function.
    volume = np.copy(volume)
    volume[volume < np.max(volume) * 0.1] = 0

    # After thresholding no voxel is negative, so "any non-zero" along the
    # other axes is equivalent to the original "max projection != 0".
    foreground = volume != 0
    if not foreground.any():
        # Nothing to crop to; min/max of an empty index array would raise.
        return volume, mask

    box = tuple(_nonzero_extent(foreground, axis) for axis in range(3))
    return volume[box], mask[box]


def pad_sample(x):
    """Zero-pad a volume and its mask so axes 1 and 2 have equal length.

    Whichever of axis 1 or axis 2 is shorter is padded with zeros on both
    sides; when the size gap is odd, the extra row/column goes on the far
    side.

    Args:
        x: ``(volume, mask)`` pair. The mask is padded on three axes; the
            volume gets the same padding plus an unpadded fourth axis.

    Returns:
        ``(volume, mask)``. When axes 1 and 2 already match, the input arrays
        themselves are returned.
    """
    volume, mask = x
    rows, cols = volume.shape[1], volume.shape[2]
    if rows == cols:
        return volume, mask

    half_gap = abs(rows - cols) / 2.0
    grow = (int(np.floor(half_gap)), int(np.ceil(half_gap)))
    keep = (0, 0)

    # Grow the shorter of the two in-plane axes.
    spatial_pad = (keep, keep, grow) if rows > cols else (keep, grow, keep)

    padded_mask = np.pad(mask, spatial_pad, mode="constant", constant_values=0)
    padded_volume = np.pad(
        volume, spatial_pad + (keep,), mode="constant", constant_values=0
    )
    return padded_volume, padded_mask


def resize_sample(x, size=256):
    """Resize a volume and its mask so axes 1 and 2 both have length ``size``.

    Axis 0 and any axes after axis 2 keep their original length; only
    axes 1 and 2 are resampled.

    Args:
        x: ``(volume, mask)`` pair sharing their first three axes. ``volume``
            may carry trailing axes (e.g. channels), which are preserved.
        size: target length for axes 1 and 2.

    Returns:
        ``(volume, mask)`` resized with ``skimage.transform.resize``: the mask
        with order-0 (nearest-neighbour) interpolation, the volume with an
        order-2 (quadratic) spline. Anti-aliasing is disabled for both and
        out-of-range samples use constant 0.
    """
    volume, mask = x

    out_shape = (volume.shape[0], size, size)

    # Nearest-neighbour keeps mask labels from being blended.
    mask = resize(
        mask,
        output_shape=out_shape,
        order=0,
        mode="constant",
        cval=0,
        anti_aliasing=False,
    )

    # Keep every trailing axis of the volume unchanged. Slicing shape[3:]
    # (rather than indexing shape[3]) also accepts a volume with no channel axis.
    volume = resize(
        volume,
        output_shape=out_shape + volume.shape[3:],
        order=2,  # quadratic spline (order 3 would be cubic)
        mode="constant",
        cval=0,
        anti_aliasing=False,
    )

    return volume, mask


def normalize_volume(volume):
    """Rescale a volume's intensities and standardize them per channel.

    Intensities are first passed through ``rescale_intensity`` with the
    volume-wide 10th and 99th percentiles as ``in_range``. Mean and standard
    deviation are then computed over axes 0-2 (one value per entry of any
    remaining axis, e.g. per channel). The volume is shifted by the mean and
    divided by the standard deviation.

    Args:
        volume: intensity array; statistics are reduced over its first three
            axes.

    Returns:
        The standardized volume. A channel whose standard deviation is 0 after
        rescaling is only mean-centered (it becomes all zeros) rather than
        turning into NaN through a 0/0 division.
    """
    # Robust intensity window: ignore the darkest 10% and brightest 1%.
    p10 = np.percentile(volume, 10)
    p99 = np.percentile(volume, 99)
    volume = rescale_intensity(volume, in_range=(p10, p99))

    m = np.mean(volume, axis=(0, 1, 2))
    s = np.std(volume, axis=(0, 1, 2))

    # A constant channel has std 0; dividing by 1 leaves it at 0 after
    # centering instead of producing NaNs that would poison training.
    s = np.where(s == 0, 1.0, s)

    return (volume - m) / s