<a href="https://colab.research.google.com/github/Debayan2004/BR-Tumor-Segmentation/blob/main/compute_parameters.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
import tensorflow as tf
import numpy as np

from scipy.spatial.distance import directed_hausdorff

In [3]:
def calculate_class_weights(y_true, num_classes=11, irrelevant_classes=(0, 10, 11), max_weight=50):
    """
    Calculate class weights based on class proportions in the training dataset.
    Classes with fewer voxels (rare classes) will have higher weights.

    The weight of class c is ``min(log1p(total_voxels / (count_c + 1e-6)), max_weight)``,
    and is then forced to zero for every class listed in ``irrelevant_classes``.
    A class that never occurs in ``y_true`` has ``count_c == 0`` and therefore
    receives a large weight (bounded by ``max_weight``) unless it is listed as
    irrelevant.

    Args:
        y_true: Ground truth labels, one-hot encoded with the class axis last,
            e.g. [batch_size, depth, height, width, num_classes]. Any number of
            leading axes is accepted; integer and boolean-like numeric dtypes
            are cast to float32.
        num_classes: Total number of classes in the dataset (length of the class axis).
        irrelevant_classes: Class indices whose weight is set to zero.
            NOTE(review): indices outside [0, num_classes) are expected to
            contribute an all-zero row in tf.one_hot and thus have no effect;
            the default contains 11 while num_classes defaults to 11 — confirm
            this is intended.
        max_weight: Maximum allowable weight for any class to prevent excessive weighting.

    Returns:
        class_weights: float32 tensor of shape [num_classes], where the index
        corresponds to the class.
    """
    # Counts must be floating point: the epsilon and the division below are
    # not defined for integer tensors.
    y_true = tf.cast(y_true, tf.float32)

    # Collapse every leading axis (batch and spatial) so that only the class
    # axis remains, then count the voxels assigned to each class.
    flat_labels = tf.reshape(y_true, [-1, tf.shape(y_true)[-1]])
    num_voxels_per_class = tf.reduce_sum(flat_labels, axis=0)  # Shape: [num_classes]

    # Total number of labelled voxels across all classes
    total_voxels = tf.reduce_sum(num_voxels_per_class)

    # Inverse proportionality; the epsilon avoids division by zero for absent classes
    class_weights = total_voxels / (num_voxels_per_class + 1e-6)

    # Compress the dynamic range: log1p(x) = log(x + 1)
    class_weights = tf.math.log1p(class_weights)

    # Cap weights to avoid extreme values
    class_weights = tf.minimum(class_weights, max_weight)

    # Build a 0/1 vector marking the irrelevant classes and zero out their weights.
    # The explicit int32 cast keeps an empty sequence of indices valid.
    irrelevant_indices = tf.cast(irrelevant_classes, tf.int32)
    irrelevant_mask = tf.reduce_sum(tf.one_hot(irrelevant_indices, depth=num_classes), axis=0)
    class_weights = class_weights * (1 - irrelevant_mask)

    return class_weights

In [4]:
def normalized_categorical_crossentropy_with_weights(y_true, y_pred, class_weights):
    """
    Class-weighted categorical cross-entropy averaged over every voxel.

    Each voxel's cross-entropy is multiplied by the weight of its true class
    (the argmax of its one-hot label) and the products are averaged over all
    voxels, including those whose class weight is zero.

    Args:
        y_true: One-hot ground truth with the class axis last
            (e.g. [batch_size, height, width, num_classes]).
        y_pred: Predicted class probabilities (after softmax), same shape as y_true.
        class_weights: Tensor or list with one weight per class.

    Returns:
        Scalar tensor holding the mean weighted cross-entropy.
    """
    def _as_voxel_rows(tensor):
        # Collapse all leading axes: one row per voxel, one column per class.
        return tf.reshape(tensor, [-1, tf.shape(tensor)[-1]])

    weights = tf.convert_to_tensor(class_weights, dtype=tf.float32)

    labels = _as_voxel_rows(y_true)
    probabilities = _as_voxel_rows(y_pred)

    # Unreduced cross-entropy: one value per voxel.
    per_voxel_ce = tf.keras.losses.categorical_crossentropy(labels, probabilities)

    # Look up the weight that belongs to each voxel's true class.
    per_voxel_weight = tf.gather(weights, tf.argmax(labels, axis=-1))

    return tf.reduce_mean(per_voxel_ce * per_voxel_weight)

In [5]:
def custom_loss_function(y_true, y_pred, class_weights):
    """
    Class-weighted loss for imbalanced segmentation datasets.

    Thin wrapper that delegates to
    normalized_categorical_crossentropy_with_weights.

    Args:
        y_true: One-hot ground truth labels
            (shape: [batch_size, depth, height, width, num_classes]).
        y_pred: Softmax probabilities with the same shape as y_true.
        class_weights: Per-class weights tensor computed from the dataset.

    Returns:
        Loss value for the batch.
    """
    return normalized_categorical_crossentropy_with_weights(
        y_true,
        y_pred,
        class_weights,
    )


In [6]:
def get_loss_function(y_true_sample, num_classes=11, irrelevant_classes=(0, 10, 11)):
    """
    Prepare the custom loss function by excluding irrelevant classes.

    The class weights are computed once, from ``y_true_sample``, when this
    factory is called; the returned closure reuses those fixed weights on
    every invocation.

    Args:
        y_true_sample: A sample of the one-hot ground truth labels (class axis
            last) used to calculate the weights. Integer labels are accepted
            and cast to float32.
        num_classes: Total number of classes in the dataset.
        irrelevant_classes: Class indices to be excluded from the weighting
            calculation; their voxels are removed from the sample and their
            weights end up as zero.

    Returns:
        A function ``loss_fn(y_true, y_pred)`` ready for use as a loss.
    """
    # The mask below is float32, so the labels must be float32 as well;
    # multiplying an integer tensor by it would raise a dtype error.
    y_true_sample = tf.cast(y_true_sample, tf.float32)

    # Remove the voxels of irrelevant classes so they do not count towards
    # the total used to derive the class proportions.
    irrelevant_indices = tf.cast(irrelevant_classes, tf.int32)
    mask = tf.reduce_sum(tf.one_hot(irrelevant_indices, depth=num_classes), axis=0)
    y_true_sample = y_true_sample * (1 - mask)

    # Calculate class weights excluding irrelevant classes
    class_weights = calculate_class_weights(
        y_true_sample,
        num_classes=num_classes,
        irrelevant_classes=irrelevant_classes,
    )

    def loss_fn(y_true, y_pred):
        # Apply the custom loss function with the precomputed weights
        return custom_loss_function(y_true, y_pred, class_weights)

    return loss_fn

In [7]:
# Dice Coefficient Function
def dice_coefficient(y_true, y_pred, threshold=0.5):
    """
    Calculate the Dice coefficient for multi-class segmentation.

    Predictions are binarised with ``y_pred > threshold``. With the default
    threshold of 0.5 this gives the same result as rounding for probabilities
    in [0, 1].

    Args:
        y_true: Ground truth mask; rounded to 0/1 and cast to the dtype of y_pred.
        y_pred: Predicted probabilities (floating point), same shape as y_true.
        threshold: Probability above which a prediction counts as foreground.

    Returns:
        Scalar tensor: mean of ``2 * |A ∩ B| / (|A| + |B| + 1e-6)`` over the
        entries that remain after summing axes 1, 2 and 3. When both masks are
        empty the score for that entry is 0.
        NOTE(review): assumes 5-D input [batch, depth, height, width, classes],
        in which case one score per (sample, class) is averaged — confirm
        against callers.
    """
    smooth = 1e-6  # Avoid division by zero
    reduce_axes = [1, 2, 3]

    y_pred = tf.convert_to_tensor(y_pred)
    # Ground truth is expected to be 0/1 already; rounding only cleans it up.
    y_true = tf.round(tf.cast(y_true, y_pred.dtype))
    # Honour the caller-supplied threshold when binarising the predictions.
    y_pred = tf.cast(y_pred > threshold, y_pred.dtype)

    intersection = tf.reduce_sum(y_true * y_pred, axis=reduce_axes)
    # Sum of the two volumes (the Dice denominator, not a set union).
    total = tf.reduce_sum(y_true, axis=reduce_axes) + tf.reduce_sum(y_pred, axis=reduce_axes)
    dice = 2 * intersection / (total + smooth)
    return tf.reduce_mean(dice)


In [8]:
# Volume Similarity Function
def volume_similarity(y_true, y_pred):
    """
    Calculate the volumetric similarity for multi-class segmentation.

    Uses ``VS = 1 - ||A| - |B|| / (|A| + |B|)``, where |A| and |B| are the
    voxel counts of the ground truth and the prediction. The score is 1 when
    both volumes are equal (including both empty) and approaches 0 when one
    of them is empty. It compares volumes only and ignores overlap.

    Args:
        y_true: Ground truth mask; rounded to 0/1.
        y_pred: Predicted probabilities, same shape and dtype as y_true;
            rounded to 0/1.

    Returns:
        Scalar tensor: mean similarity over the entries that remain after
        summing axes 1, 2 and 3.
        NOTE(review): assumes 5-D input [batch, depth, height, width, classes]
        — confirm against callers.
    """
    reduce_axes = [1, 2, 3]

    y_true = tf.round(y_true)  # Convert to binary values
    y_pred = tf.round(y_pred)  # Convert to binary values

    true_volume = tf.reduce_sum(y_true, axis=reduce_axes)
    pred_volume = tf.reduce_sum(y_pred, axis=reduce_axes)

    # The epsilon keeps the ratio finite when both volumes are zero.
    volume_sim = 1.0 - tf.abs(true_volume - pred_volume) / (true_volume + pred_volume + 1e-6)
    return tf.reduce_mean(volume_sim)  # Mean over the remaining entries


In [9]:
# Hausdorff Distance Calculation Function
def HausdorffDist(A, B):
    """
    Compute the Hausdorff distance between two point clouds.

    Args:
        A: Array-like of shape (n_points_a, n_dims).
        B: Array-like of shape (n_points_b, n_dims).

    Returns:
        The symmetric Hausdorff distance: the larger of the two directed
        distances max_a min_b ||a - b|| and max_b min_a ||a - b||.

    Raises:
        ValueError: If either point cloud is empty (raised by NumPy's min
            reduction over an empty axis).
    """
    A = np.asarray(A, dtype=float)
    B = np.asarray(B, dtype=float)

    # Pairwise squared distances via ||a||^2 + ||b||^2 - 2 a.b
    sq_norm_a = np.sum(A * A, axis=1)
    sq_norm_b = np.sum(B * B, axis=1)
    sq_dist = sq_norm_a[:, np.newaxis] + sq_norm_b[np.newaxis, :] - 2 * np.dot(A, B.T)

    # Round-off can push near-zero entries slightly negative; clamp so that
    # the square root never yields NaN.
    D_mat = np.sqrt(np.maximum(sq_dist, 0.0))

    # Directed distances in both directions; the Hausdorff distance is the larger.
    dist_b_to_a = np.max(np.min(D_mat, axis=0))
    dist_a_to_b = np.max(np.min(D_mat, axis=1))
    dH = max(dist_b_to_a, dist_a_to_b)
    return dH

In [10]:
# Modified Hausdorff Distance Calculation Function
def ModHausdorffDist(A, B):
    """
    Compute the Modified Hausdorff Distance (MHD) between two point clouds.

    Args:
        A: Array-like of shape (n_points_a, n_dims).
        B: Array-like of shape (n_points_b, n_dims).

    Returns:
        A tuple ``(MHD, FHD, RHD)`` where FHD is the mean, over the points of
        A, of the distance to the nearest point of B; RHD is the same with the
        roles of A and B swapped; and MHD is the larger of the two.

    Raises:
        ValueError: If either point cloud is empty (raised by NumPy's min
            reduction over an empty axis).
    """
    A = np.asarray(A, dtype=float)
    B = np.asarray(B, dtype=float)

    # Pairwise squared distances via ||a||^2 + ||b||^2 - 2 a.b
    sq_norm_a = np.sum(A * A, axis=1)
    sq_norm_b = np.sum(B * B, axis=1)
    sq_dist = sq_norm_a[:, np.newaxis] + sq_norm_b[np.newaxis, :] - 2 * np.dot(A, B.T)

    # Round-off can push near-zero entries slightly negative; clamp so that
    # the square root never yields NaN.
    D_mat = np.sqrt(np.maximum(sq_dist, 0.0))

    # Forward distance: for each point of A (row), nearest point of B, averaged
    FHD = np.mean(np.min(D_mat, axis=1))
    # Reverse distance: for each point of B (column), nearest point of A, averaged
    RHD = np.mean(np.min(D_mat, axis=0))
    # The modified Hausdorff distance is the larger of the two
    MHD = np.max(np.array([FHD, RHD]))
    return MHD, FHD, RHD

# Hausdorff Distance Wrapper for TensorFlow
def hausdorff_distance(y_true, y_pred):
    """
    Calculate the Hausdorff distance between the true and predicted binary masks.

    Both inputs are rounded to 0/1 and every non-zero voxel (not only the
    contour) is treated as a point whose coordinates are its array indices.
    The tensors are converted to NumPy only when TensorFlow executes eagerly.

    Args:
        y_true: Ground truth mask.
        y_pred: Predicted mask or probabilities, same rank as y_true.

    Returns:
        The symmetric Hausdorff distance in index (voxel) units. Returns 0.0
        when both masks are empty and ``inf`` when exactly one of them is empty.
    """
    # Convert to binary mask (0 or 1)
    y_true = tf.round(y_true)
    y_pred = tf.round(y_pred)

    # NumPy/SciPy need concrete values, which are only available in eager mode
    y_true = y_true.numpy() if tf.executing_eagerly() else y_true
    y_pred = y_pred.numpy() if tf.executing_eagerly() else y_pred

    # Index coordinates of all foreground voxels, one row per voxel
    true_points = np.argwhere(y_true > 0)
    pred_points = np.argwhere(y_pred > 0)

    # The directed distance is not defined for an empty point set, so handle
    # those cases explicitly instead of passing empty arrays to SciPy.
    true_is_empty = len(true_points) == 0
    pred_is_empty = len(pred_points) == 0
    if true_is_empty and pred_is_empty:
        return np.float64(0.0)
    if true_is_empty or pred_is_empty:
        return np.float64(np.inf)

    # Directed Hausdorff distance in each direction
    forward_hausdorff = directed_hausdorff(true_points, pred_points)[0]
    reverse_hausdorff = directed_hausdorff(pred_points, true_points)[0]

    # The Hausdorff distance is the maximum of these two directed distances
    return np.max([forward_hausdorff, reverse_hausdorff])