# [**Adaptive Real-Time Multi-Loss Function Optimization Using Dynamic Memory Fusion Framework: A Case Study on Breast Cancer Segmentation**](https://arxiv.org/pdf/2410.19745)  

---

Amin Golnari, Mostafa Diba. **"Adaptive Real-Time Multi-Loss Function Optimization Using Dynamic Memory Fusion Framework: A Case Study on Breast Cancer Segmentation"** Preprint. [https://doi.org/10.48550/arXiv.2410.19745](https://doi.org/10.48550/arXiv.2410.19745)


**Amin Golnari <sup>a</sup>**, **Mostafa Diba <sup>a</sup>** <br>
a) Faculty of Electrical Engineering, Shahrood University of Technology, Shahrood, Iran <br>

**Link to the preprint version:** [Click here](https://arxiv.org/pdf/2410.19745)

**Or run this Python code on Google Colab:**    

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/amingolnari/Demo-Dynamic-Memory-Fusion-Framework/blob/main/DynamicMemoryFusion.ipynb)

**Compatibility:**
- TensorFlow: 2.17.0
- Keras: 3.4.1
- TensorFlow Datasets: 4.9.6

---

## **Framework Overview**

The **Dynamic Memory Fusion Framework** is an adaptive, real-time, multi-loss function optimization framework designed to improve the performance of deep learning tasks.

The framework removes the need to hand-tune the weights of a multi-loss objective by adjusting the loss weights dynamically during training. It uses the history of each loss to update the weights and adds an auxiliary loss function that supports the model, especially in the early stages of training. It also introduces the **Class-Balanced Dice Loss** function, which addresses class imbalance, a common obstacle to accurate segmentation.

---

## **Key Features:**

1. **Dynamic Weight Adjustment**:
   - The framework dynamically adjusts the weighting of multiple loss functions in real time based on historical loss values.
   - This adaptation ensures better performance across different training stages without the need for manual fine-tuning.

2. **Class-Balanced Dice Loss (CB-Dice)**:
   - A novel loss function that handles class imbalance by focusing more on underrepresented classes, improving the overall segmentation accuracy.

3. **Auxiliary Loss Functions**:
   - Auxiliary loss functions are employed to assist in the early stages of training, helping the model converge faster.

4. **Breast Ultrasound Dataset**:
   - The framework has been tested on breast ultrasound datasets, demonstrating improvements in segmentation accuracy and robustness across various metrics, such as Dice score and IoU (Intersection over Union).
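
As a rough illustration of the dynamic weight adjustment listed under feature 1, the sketch below mirrors the variance-based (`'var'`) weighting in plain NumPy. It is a simplified sketch, not the framework's implementation: the function name `variance_weights` and the two toy loss histories are assumptions made for this example. Each history is log-scaled and min-max normalized, and each loss then receives a weight proportional to the variance of its normalized history.

```python
import numpy as np

def variance_weights(histories, eps=1e-7):
    """Return one weight per loss, proportional to the variance of its history."""
    variances = []
    for history in histories:
        h = np.asarray(history, dtype=np.float64)
        h = np.sign(h) * np.log1p(np.abs(h))            # symmetric log scaling
        h = (h - h.min()) / (h.max() - h.min() + eps)   # min-max scaling to [0, 1]
        variances.append(h.var())
    variances = np.asarray(variances)
    # Normalize so the weights sum to (approximately) 1
    return variances / (variances.sum() + eps)

# Toy histories (made-up numbers, 30 recorded steps each)
steps = np.arange(30)
dice_history = np.linspace(1.0, 0.2, 30)    # steadily decreasing loss
bce_history = 0.6 + 0.01 * np.sin(steps)    # loss hovering around a plateau

weights = variance_weights([dice_history, bce_history])
print(weights)
```

Because each history is min-max scaled before the variance is taken, the resulting weights reflect the shape of each loss curve rather than its absolute magnitude.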

---

## **Usage of CB-Dice Loss**

The **CB-Dice Loss** function targets image segmentation tasks with significant class imbalance. In this framework, the Dice loss is modified so that each class's Dice score is weighted by its normalized inverse frequency in the ground truth. Underrepresented classes (e.g., cancerous regions in medical images) therefore contribute more to the loss, which improves accuracy on minority-class segmentation.
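
The weighting idea can be sketched in NumPy as follows. This is a simplified sketch under stated assumptions, not the notebook's TensorFlow implementation: it pools the confusion counts over the whole batch, expects one-hot inputs of shape `(batch, height, width, classes)`, and the toy 4×4 mask is made up for the example.

```python
import numpy as np

def class_balanced_dice_loss(y_true, y_pred, eps=1e-6):
    """Class-balanced Dice loss for one-hot arrays of shape (B, H, W, C)."""
    y_true = np.asarray(y_true, dtype=np.float64)
    y_pred = np.asarray(y_pred, dtype=np.float64)
    num_classes = y_true.shape[-1]

    # Inverse-frequency class weights, normalized to sum to 1
    ratios = y_true.sum(axis=(0, 1, 2)) / np.prod(y_true.shape[:-1])
    weights = 1.0 / (ratios + eps)
    weights /= weights.sum()

    # Per-class Dice score from hard (argmax) predictions
    true_cls = y_true.argmax(axis=-1)
    pred_cls = y_pred.argmax(axis=-1)
    dice = np.zeros(num_classes)
    for c in range(num_classes):
        tp = np.sum((true_cls == c) & (pred_cls == c))
        fp = np.sum((true_cls != c) & (pred_cls == c))
        fn = np.sum((true_cls == c) & (pred_cls != c))
        dice[c] = 2 * tp / (2 * tp + fp + fn + eps)

    return 1.0 - np.sum(weights * dice)

# Toy mask: 2 foreground pixels out of 16
labels = np.zeros((1, 4, 4), dtype=int)
labels[0, 1, 1:3] = 1
y_true = np.eye(2)[labels]                        # one-hot, shape (1, 4, 4, 2)
all_background = np.eye(2)[np.zeros_like(labels)]

loss_perfect = class_balanced_dice_loss(y_true, y_true)
loss_missed = class_balanced_dice_loss(y_true, all_background)
print(loss_perfect, loss_missed)
```

In this toy case the foreground class covers 12.5% of the pixels and receives a weight of about 0.875, so a prediction that misses the lesion entirely is penalized far more heavily than an unweighted mean of per-class Dice scores would penalize it.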

---

If our work is helpful to you, please kindly cite our paper as:

    @article{GOLNARI2024DMF,
       title={Adaptive Real-Time Multi-Loss Function Optimization Using Dynamic Memory Fusion Framework: A Case Study on Breast Cancer Segmentation},
       author={Golnari, Amin and Diba, Mostafa},
       journal={arXiv preprint arXiv:2410.19745},
       year={2024},
       doi={10.48550/arXiv.2410.19745}
    }


In [None]:
# Standard libraries
import os         # For handling file and directory operations
import shutil     # For performing high-level file operations like copying and deleting

# Deep learning frameworks
import keras             # High-level neural networks API running on top of TensorFlow
import tensorflow as tf  # TensorFlow is an open-source machine learning library

# Data visualization
import matplotlib.pyplot as plt  # For plotting graphs and visualizing data

# TensorFlow datasets
import tensorflow_datasets as tfds  # Library of ready-to-use datasets for machine learning

# Numerical operations
import numpy as np  # For efficient array and numerical operations

# Specific modules from Keras
from keras import layers, models  # For defining layers and building models in Keras
from keras import backend as K    # Low-level tensor manipulation and backend operations in Keras

# Scikit-learn's train_test_split utility
from sklearn.model_selection import train_test_split  # For splitting datasets into training and testing sets

In [None]:
class Metrics:
    """
    A class to compute various evaluation metrics and loss functions for classification models,
    supporting both binary and multi-class scenarios. It leverages TensorFlow and Keras backend
    for efficient computation and integration with neural network models.
    """

    def _confusion_matrix_elements(self, y_true, y_pred, class_id, is_binary, across_batch=False):
        """
        Computes the elements of the confusion matrix (True Positives, False Positives,
        True Negatives, False Negatives) for a specific class.

        Args:
            y_true (Tensor): Ground truth labels.
            y_pred (Tensor): Predicted labels.
            class_id (int): The class identifier for which to compute the confusion matrix elements.
            is_binary (bool): Flag indicating if the classification is binary.
            across_batch (bool): Whether to compute metrics across the entire batch.

        Returns:
            Tuple of Tensors: (TP, FP, TN, FN)
        """
        if is_binary:
            # Threshold predictions for binary classification
            y_pred = tf.where(y_pred >= 0.5, 1.0, 0.0)
            y_pred_binary = tf.cast(y_pred == class_id, tf.float32)
            y_true_binary = tf.cast(y_true == class_id, tf.float32)
        else:
            # For multi-class classification, use argmax to determine predicted class
            y_true_binary = tf.cast(tf.argmax(y_true, axis=-1) == class_id, tf.float32)
            y_pred_binary = tf.cast(tf.argmax(y_pred, axis=-1) == class_id, tf.float32)

        # Determine the shape of the tensors to handle different dimensions
        if tf.rank(y_true_binary) == 2:  # [height, width]
            TP = tf.reduce_sum(tf.cast((y_true_binary == 1) & (y_pred_binary == 1), tf.float32))
            FP = tf.reduce_sum(tf.cast((y_true_binary == 0) & (y_pred_binary == 1), tf.float32))
            TN = tf.reduce_sum(tf.cast((y_true_binary == 0) & (y_pred_binary == 0), tf.float32))
            FN = tf.reduce_sum(tf.cast((y_true_binary == 1) & (y_pred_binary == 0), tf.float32))
        else:
            # For 3D tensors (batch, height, width), reduce across spatial dimensions
            TP = tf.reduce_sum(tf.cast((y_true_binary == 1) & (y_pred_binary == 1), tf.float32), axis=[1, 2])
            FP = tf.reduce_sum(tf.cast((y_true_binary == 0) & (y_pred_binary == 1), tf.float32), axis=[1, 2])
            TN = tf.reduce_sum(tf.cast((y_true_binary == 0) & (y_pred_binary == 0), tf.float32), axis=[1, 2])
            FN = tf.reduce_sum(tf.cast((y_true_binary == 1) & (y_pred_binary == 0), tf.float32), axis=[1, 2])

            if not across_batch:
                # Average over the batch dimension
                TP = tf.reduce_mean(TP)
                FP = tf.reduce_mean(FP)
                TN = tf.reduce_mean(TN)
                FN = tf.reduce_mean(FN)

        return TP, FP, TN, FN

    def _get_num_classes(self, y_true, y_pred):
        """
        Determines the number of classes and whether the classification is binary.

        Args:
            y_true (numpy.ndarray): Ground truth labels.
            y_pred (numpy.ndarray): Predicted labels.

        Returns:
            Tuple: (number_of_classes, is_binary)
        """
        if np.shape(y_pred)[-1] == 1:  # Binary classification with sigmoid output
            return 2, True
        elif np.shape(y_pred)[-1] == 2:  # Binary classification with softmax output
            return 2, False
        else:  # Multi-class classification with softmax output
            return np.shape(y_pred)[-1], False

    def _compute_metric(self, y_true, y_pred, metric_fn, across_classes=True):
        """
        Computes a specified metric across all classes.

        Args:
            y_true (Tensor): Ground truth labels.
            y_pred (Tensor): Predicted labels.
            metric_fn (function): Function to compute the metric based on TP, FP, TN, FN.
            across_classes (bool): Whether to average the metric across classes.

        Returns:
            Tensor: Computed metric.
        """
        num_classes, is_binary = self._get_num_classes(y_true, y_pred)
        metric_sum = 0.0

        if across_classes:
            # Aggregate metric across all classes
            for class_id in range(num_classes):
                TP, FP, TN, FN = self._confusion_matrix_elements(y_true, y_pred, class_id, is_binary)
                metric_sum += metric_fn(TP, FP, TN, FN)

            return metric_sum / tf.cast(num_classes, tf.float32)
        else:
            # Compute metric for each class individually
            metric = []
            for class_id in range(num_classes):
                TP, FP, TN, FN = self._confusion_matrix_elements(y_true, y_pred, class_id, is_binary)
                metric.append(metric_fn(TP, FP, TN, FN))

            return metric

    # ----------------------- Evaluation Metrics -----------------------

    def precision(self, y_true, y_pred):
        """
        Calculates the precision metric.

        Args:
            y_true (Tensor): Ground truth labels.
            y_pred (Tensor): Predicted labels.

        Returns:
            Tensor: Precision score.
        """
        return self._compute_metric(
            y_true, y_pred,
            lambda TP, FP, TN, FN: TP / (TP + FP + K.epsilon())
        )

    def recall(self, y_true, y_pred):
        """
        Calculates the recall metric.

        Args:
            y_true (Tensor): Ground truth labels.
            y_pred (Tensor): Predicted labels.

        Returns:
            Tensor: Recall score.
        """
        return self._compute_metric(
            y_true, y_pred,
            lambda TP, FP, TN, FN: TP / (TP + FN + K.epsilon())
        )

    def f1_score(self, y_true, y_pred):
        """
        Calculates the F1 score, the harmonic mean of precision and recall.

        Args:
            y_true (Tensor): Ground truth labels.
            y_pred (Tensor): Predicted labels.

        Returns:
            Tensor: F1 score.
        """
        precision_value = self.precision(y_true, y_pred)
        recall_value = self.recall(y_true, y_pred)
        return 2 * (precision_value * recall_value) / (precision_value + recall_value + K.epsilon())

    def accuracy(self, y_true, y_pred):
        """
        Calculates the accuracy metric.

        Args:
            y_true (Tensor): Ground truth labels.
            y_pred (Tensor): Predicted labels.

        Returns:
            Tensor: Accuracy score.
        """
        return self._compute_metric(
            y_true, y_pred,
            lambda TP, FP, TN, FN: (TP + TN) / (TP + FP + TN + FN + K.epsilon())
        )

    def dice(self, y_true, y_pred):
        """
        Calculates the Dice coefficient.

        Args:
            y_true (Tensor): Ground truth labels.
            y_pred (Tensor): Predicted labels.

        Returns:
            Tensor: Dice coefficient.
        """
        return self._compute_metric(
            y_true, y_pred,
            lambda TP, FP, TN, FN: (2 * TP) / (2 * TP + FP + FN + K.epsilon())
        )

    def iou(self, y_true, y_pred):
        """
        Calculates the Intersection over Union (IoU) metric.

        Args:
            y_true (Tensor): Ground truth labels.
            y_pred (Tensor): Predicted labels.

        Returns:
            Tensor: IoU score.
        """
        return self._compute_metric(
            y_true, y_pred,
            lambda TP, FP, TN, FN: TP / (TP + FP + FN + K.epsilon())
        )

    def specificity(self, y_true, y_pred):
        """
        Calculates the specificity metric.

        Args:
            y_true (Tensor): Ground truth labels.
            y_pred (Tensor): Predicted labels.

        Returns:
            Tensor: Specificity score.
        """
        return self._compute_metric(
            y_true, y_pred,
            lambda TP, FP, TN, FN: TN / (TN + FP + K.epsilon())
        )

    def sensitivity(self, y_true, y_pred):
        """
        Calculates the sensitivity metric, equivalent to recall.

        Args:
            y_true (Tensor): Ground truth labels.
            y_pred (Tensor): Predicted labels.

        Returns:
            Tensor: Sensitivity score.
        """
        return self._compute_metric(
            y_true, y_pred,
            lambda TP, FP, TN, FN: TP / (TP + FN + K.epsilon())
        )

    def fp_rate(self, y_true, y_pred):
        """
        Calculates the False Positive Rate (FPR).

        Args:
            y_true (Tensor): Ground truth labels.
            y_pred (Tensor): Predicted labels.

        Returns:
            Tensor: FPR score.
        """
        return self._compute_metric(
            y_true, y_pred,
            lambda TP, FP, TN, FN: FP / (FP + TN + K.epsilon())
        )

    def fn_rate(self, y_true, y_pred):
        """
        Calculates the False Negative Rate (FNR).

        Args:
            y_true (Tensor): Ground truth labels.
            y_pred (Tensor): Predicted labels.

        Returns:
            Tensor: FNR score.
        """
        return self._compute_metric(
            y_true, y_pred,
            lambda TP, FP, TN, FN: FN / (FN + TP + K.epsilon())
        )

    def negative_predictive(self, y_true, y_pred):
        """
        Calculates the Negative Predictive Value (NPV).

        Args:
            y_true (Tensor): Ground truth labels.
            y_pred (Tensor): Predicted labels.

        Returns:
            Tensor: NPV score.
        """
        return self._compute_metric(
            y_true, y_pred,
            lambda TP, FP, TN, FN: TN / (TN + FN + K.epsilon())
        )

    # ----------------------- Loss Functions -----------------------

    def dice_loss(self, y_true, y_pred):
        """
        Calculates the Dice loss, which is 1 minus the Dice coefficient.

        Args:
            y_true (Tensor): Ground truth labels.
            y_pred (Tensor): Predicted labels.

        Returns:
            Tensor: Dice loss.
        """
        return 1 - self.dice(y_true, y_pred)

    def iou_loss(self, y_true, y_pred):
        """
        Calculates the IoU loss, which is 1 minus the IoU score.

        Args:
            y_true (Tensor): Ground truth labels.
            y_pred (Tensor): Predicted labels.

        Returns:
            Tensor: IoU loss.
        """
        return 1 - self.iou(y_true, y_pred)

    def precision_loss(self, y_true, y_pred):
        """
        Calculates the Precision loss, which is 1 minus the precision score.

        Args:
            y_true (Tensor): Ground truth labels.
            y_pred (Tensor): Predicted labels.

        Returns:
            Tensor: Precision loss.
        """
        return 1 - self.precision(y_true, y_pred)

    def recall_loss(self, y_true, y_pred):
        """
        Calculates the Recall loss, which is 1 minus the recall score.

        Args:
            y_true (Tensor): Ground truth labels.
            y_pred (Tensor): Predicted labels.

        Returns:
            Tensor: Recall loss.
        """
        return 1 - self.recall(y_true, y_pred)

    def f1_score_loss(self, y_true, y_pred):
        """
        Calculates the F1 score loss, which is 1 minus the F1 score.

        Args:
            y_true (Tensor): Ground truth labels.
            y_pred (Tensor): Predicted labels.

        Returns:
            Tensor: F1 score loss.
        """
        return 1 - self.f1_score(y_true, y_pred)

    def focal_loss(self, y_true, y_pred, gamma=2.0, alpha=0.25):
        """
        Computes a focal-style loss to address class imbalance. The focal term is
        applied to each class's precision rather than to per-pixel probabilities.

        Args:
            y_true (Tensor): Ground truth labels.
            y_pred (Tensor): Predicted labels (probabilities).
            gamma (float): Focusing parameter that adjusts the rate at which easy examples are down-weighted.
            alpha (float): Weighting factor for the rare class.

        Returns:
            Tensor: Focal Loss value.
        """
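        def focal_term(TP, FP, TN, FN):
            # Per-class precision plays the role of the predicted probability p_t
            p_t = TP / (TP + FP + K.epsilon())
            # Add epsilon inside the log so that p_t == 0 does not produce -inf
            return -alpha * tf.pow(1 - p_t, gamma) * tf.math.log(p_t + K.epsilon())

        return self._compute_metric(y_true, y_pred, focal_term)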


    def tversky_loss(self, y_true, y_pred, alpha=0.7, beta=0.3):
        """
        Computes the Tversky loss for imbalanced data.

        Args:
            y_true (Tensor): Ground truth labels.
            y_pred (Tensor): Predicted labels.
            alpha (float): Weight for false positives.
            beta (float): Weight for false negatives.

        Returns:
            Tensor: Tversky loss value.
        """
        return self._compute_metric(
            y_true, y_pred,
            lambda TP, FP, TN, FN: 1 - (TP + K.epsilon()) / (TP + alpha * FP + beta * FN + K.epsilon())
        )

    # ----------------------- Class Balanced Dice Metric and Loss -----------------------

    def _compute_class_weights(self, y_true):
        """
        Computes class weights based on the frequency of each class in the ground truth.

        Args:
            y_true (Tensor): Ground truth labels.

        Returns:
            Tensor: Normalized class weights.
        """
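        epsilon = 1e-6
        y_true = tf.cast(y_true, tf.float32)  # Integer masks cannot be divided by a float tensor
        total_pixels = tf.reduce_prod(tf.shape(y_true)[:-1])
        class_pixel_counts = tf.reduce_sum(y_true, axis=[0, 1, 2])
        class_ratios = class_pixel_counts / tf.cast(total_pixels, tf.float32)
        class_weights = 1 / (class_ratios + epsilon)
        # Normalize weights to sum to 1 (tf.reduce_sum keeps the computation in TensorFlow)
        class_weights = class_weights / tf.reduce_sum(class_weights)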

        return class_weights

    def class_balanced_dice_score(self, y_true, y_pred):
        """
        Computes a class-balanced Dice score by weighting each class's Dice score.

        Args:
            y_true (Tensor): Ground truth labels.
            y_pred (Tensor): Predicted labels.

        Returns:
            Tensor: Class-balanced Dice score.
        """
        # Compute Dice score for each class without averaging across classes
        loss_score = self._compute_metric(
            y_true, y_pred,
            lambda TP, FP, TN, FN: (2 * TP) / (2 * TP + FP + FN + K.epsilon()),
            across_classes=False
        )
        class_weights = self._compute_class_weights(y_true)
        score = tf.reduce_sum(class_weights * loss_score)  # Weighted sum of Dice score
        return score

    def class_balanced_dice_loss(self, y_true, y_pred):
        """
        Calculates the class-balanced Dice loss.

        Args:
            y_true (Tensor): Ground truth labels.
            y_pred (Tensor): Predicted labels.

        Returns:
            Tensor: Class-balanced Dice loss.
        """
        return 1 - self.class_balanced_dice_score(y_true, y_pred)


# Instantiate the Metrics class
metrics_instance = Metrics()

# ----------------------- Metric Functions -----------------------

def precision(y_true, y_pred):
    """Wrapper function to compute precision using the Metrics instance."""
    return metrics_instance.precision(y_true, y_pred)

def recall(y_true, y_pred):
    """Wrapper function to compute recall using the Metrics instance."""
    return metrics_instance.recall(y_true, y_pred)

def f1_score(y_true, y_pred):
    """Wrapper function to compute F1 score using the Metrics instance."""
    return metrics_instance.f1_score(y_true, y_pred)

def accuracy(y_true, y_pred):
    """Wrapper function to compute accuracy using the Metrics instance."""
    return metrics_instance.accuracy(y_true, y_pred)

def dice(y_true, y_pred):
    """Wrapper function to compute Dice coefficient using the Metrics instance."""
    return metrics_instance.dice(y_true, y_pred)

def iou(y_true, y_pred):
    """Wrapper function to compute Intersection over Union (IoU) using the Metrics instance."""
    return metrics_instance.iou(y_true, y_pred)

def specificity(y_true, y_pred):
    """Wrapper function to compute specificity using the Metrics instance."""
    return metrics_instance.specificity(y_true, y_pred)

def sensitivity(y_true, y_pred):
    """Wrapper function to compute sensitivity using the Metrics instance."""
    return metrics_instance.sensitivity(y_true, y_pred)

def dice_loss(y_true, y_pred):
    """Wrapper function to compute Dice loss using the Metrics instance."""
    return metrics_instance.dice_loss(y_true, y_pred)

def iou_loss(y_true, y_pred):
    """Wrapper function to compute IoU loss using the Metrics instance."""
    return metrics_instance.iou_loss(y_true, y_pred)

def precision_loss(y_true, y_pred):
    """Wrapper function to compute Precision loss using the Metrics instance."""
    return metrics_instance.precision_loss(y_true, y_pred)

def recall_loss(y_true, y_pred):
    """Wrapper function to compute Recall loss using the Metrics instance."""
    return metrics_instance.recall_loss(y_true, y_pred)

def f1_score_loss(y_true, y_pred):
    """Wrapper function to compute F1 score loss using the Metrics instance."""
    return metrics_instance.f1_score_loss(y_true, y_pred)

def tversky_loss(y_true, y_pred):
    """Wrapper function to compute Tversky loss using the Metrics instance."""
    return metrics_instance.tversky_loss(y_true, y_pred)

def focal_loss(y_true, y_pred):
    """Wrapper function to compute Focal loss using the Metrics instance."""
    return metrics_instance.focal_loss(y_true, y_pred)

def class_balanced_dice_score(y_true, y_pred):
    """Wrapper function to compute class-balanced Dice score using the Metrics instance."""
    return metrics_instance.class_balanced_dice_score(y_true, y_pred)

def class_balanced_dice_loss(y_true, y_pred):
    """Wrapper function to compute class-balanced Dice loss using the Metrics instance."""
    return metrics_instance.class_balanced_dice_loss(y_true, y_pred)
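
As a sanity check on the formulas used by the `Metrics` class above, the following NumPy sketch works through a tiny binary example by hand. It is an illustration only and makes two simplifying assumptions: it scores the foreground class alone, whereas `_compute_metric` averages over all classes, and the helper name `confusion_counts` and the toy arrays are made up for the example.

```python
import numpy as np

def confusion_counts(y_true, y_pred, threshold=0.5):
    """Count TP, FP, TN, FN for the foreground class of a binary mask."""
    truth = np.asarray(y_true) >= 0.5
    pred = np.asarray(y_pred) >= threshold   # threshold the predicted probabilities
    tp = int(np.sum(truth & pred))
    fp = int(np.sum(~truth & pred))
    tn = int(np.sum(~truth & ~pred))
    fn = int(np.sum(truth & ~pred))
    return tp, fp, tn, fn

# Toy 2x3 mask and predicted probabilities
y_true = np.array([[1, 1, 0], [0, 0, 0]])
y_pred = np.array([[0.9, 0.2, 0.7], [0.1, 0.3, 0.4]])

tp, fp, tn, fn = confusion_counts(y_true, y_pred)
eps = 1e-7  # small constant to avoid division by zero
scores = {
    "precision": tp / (tp + fp + eps),
    "recall": tp / (tp + fn + eps),
    "dice": 2 * tp / (2 * tp + fp + fn + eps),
    "iou": tp / (tp + fp + fn + eps),
    "specificity": tn / (tn + fp + eps),
}
print((tp, fp, tn, fn), scores)
```

Here one foreground pixel is found, one is missed, and one background pixel is flagged by mistake, giving TP = 1, FP = 1, TN = 3, FN = 1, a Dice score of about 0.5, and an IoU of about 0.33.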

In [None]:
class DynamicMemoryFusionLoss(tf.keras.losses.Loss):
    """
    A custom Keras loss function that dynamically fuses multiple loss functions
    by adjusting their weights based on historical performance. This approach
    aims to balance the contribution of each loss function during training,
    enhancing model performance especially in complex tasks with multiple objectives.
    """

    def __init__(self,
                 loss_functions,
                 initial_loss_weights=None,
                 history_size=50,
                 history_size_max=100,
                 batch_track_steps=None,
                 gama=1,
                 decay_rate=0.05,
                 auxiliary_loss_function=None,
                 weighting_method='var',
                 training_phase=True,
                 name="DynamicMemoryFusionLoss",
                 **kwargs):
        """
        Initializes the DynamicMemoryFusionLoss.

        Args:
            loss_functions (list of callable): A list of loss functions to be fused.
            initial_loss_weights (list of float, optional): Initial weights for each loss function.
                If None, weights are initialized uniformly.
            history_size (int, optional): The number of past loss values to consider for adjusting weights.
                Minimum is set to 30.
            history_size_max (int, optional): The maximum number of historical loss values to store.
                Minimum is set to 90.
            batch_track_steps (int, optional): Number of steps between tracking updates.
                If None, tracking occurs every step.
            gama (float, optional): Scaling factor for the auxiliary loss.
            decay_rate (float, optional): Decay rate for scaling the auxiliary loss over steps.
            auxiliary_loss_function (callable, optional): An additional loss function to include.
            weighting_method (str, optional): Method for adjusting loss weights. Options are
                'var' (variance-based), 'bayes', or 'mad' (median absolute deviation). Defaults to 'var';
                an unrecognized value falls back to 'mad'.
            training_phase (bool, optional): Flag indicating if the loss is being used in training phase.
                If False, weight adjustments are not performed.
            name (str, optional): Name of the loss function.
            **kwargs: Additional keyword arguments for the parent class.
        """
        super(DynamicMemoryFusionLoss, self).__init__(name=name, **kwargs)

        # Store the loss functions and initialize their weights
        self.loss_functions = loss_functions
        if initial_loss_weights is not None:
            self.loss_weights = initial_loss_weights
        else:
            # Initialize weights uniformly if not provided
            self.loss_weights = [1.0 / len(loss_functions)] * len(loss_functions)

        # Set history size parameters with minimum constraints
        self.history_size = max(30, history_size)
        self.history_size_max = max(90, history_size_max)

        # Initialize batch tracking parameters
        self.batch_track_steps = 0 if batch_track_steps is None else batch_track_steps
        self.batch_track = self.batch_track_steps

        # Set scaling and decay parameters
        self.gama = gama
        self.decay_rate = decay_rate

        # Auxiliary loss function and weighting method
        self.auxiliary_loss_function = auxiliary_loss_function
        if weighting_method in ['var', 'bayes', 'mad']:
            self.weighting_method = weighting_method
        else:
            self.weighting_method = 'mad'  # Fall back to 'mad' if an invalid method is provided

        # Flag to indicate if in training phase
        self.training_phase = training_phase

        # Initialize priors for Bayesian weighting
        self.priors = np.array([1.0 / len(loss_functions)] * len(loss_functions))

        # Initialize history storage for each loss function
        self.loss_histories = [[] for _ in loss_functions]

        # Initialize step counter
        self.step = 0

    def _ema(self, values, alpha=0.9):
        """
        Computes the Exponential Moving Average (EMA) of the provided values.

        Args:
            values (Tensor): Tensor of values to compute EMA on.
            alpha (float, optional): Smoothing factor for EMA. Defaults to 0.9.

        Returns:
            float: The EMA of the values.
        """
        values = values.numpy()
        ema_value = values[0]
        for i in range(1, len(values)):
            ema_value = alpha * values[i] + (1 - alpha) * ema_value

        return ema_value

    def _batch_track(self):
        """
        Manages the tracking of batches to control when weight adjustments occur.
        """
        if self.batch_track_steps:
            self.batch_track -= 1
            if self.batch_track == 0:
                self.batch_track = self.batch_track_steps
                self.step += 1
        else:
            self.step += 1

    def _update_loss_histories(self, loss_values):
        """
        Updates the historical loss values for each loss function.

        Args:
            loss_values (list of float): Current loss values for each loss function.
        """
        for i, loss_value in enumerate(loss_values):
            self.loss_histories[i].append(loss_value)
            # Ensure the history does not exceed the maximum size
            if len(self.loss_histories[i]) > self.history_size_max:
                self.loss_histories[i].pop(0)

    def _compute_median(self, tensor):
        """
        Computes the median of a tensor.

        Args:
            tensor (Tensor): Input tensor.

        Returns:
            Tensor: Median value.
        """
        tensor = tf.reshape(tensor, [-1])
        sorted_tensor = tf.sort(tensor)
        num_elements = tf.size(sorted_tensor)
        is_even = tf.equal(tf.math.mod(num_elements, 2), 0)  # True when the element count is even
        median_idx = num_elements // 2
        if is_even:
            # Average the two middle values for even-sized tensors
            median = tf.reduce_mean([sorted_tensor[median_idx - 1], sorted_tensor[median_idx]])
        else:
            # Middle value for odd-sized tensors
            median = sorted_tensor[median_idx]

        return median

    def _compute_mad(self, loss_history):
        """
        Computes the Median Absolute Deviation (MAD) of a loss history.

        Args:
            loss_history (Tensor): Tensor containing historical loss values.

        Returns:
            Tensor: MAD value.
        """
        median = self._compute_median(loss_history)
        abs_deviations = tf.math.abs(loss_history - median)
        mad = self._compute_median(abs_deviations)

        return mad

    def _minmax_scaling(self, loss):
        """
        Applies min-max scaling to the loss values.

        Args:
            loss (Tensor): Tensor of loss values.

        Returns:
            Tensor: Scaled loss values between 0 and 1.
        """
        min_val = tf.reduce_min(loss)
        max_val = tf.reduce_max(loss)
        scaled_loss = (loss - min_val) / (max_val - min_val + K.epsilon())

        return scaled_loss

    def _symmetric_log_scaling(self, loss):
        """
        Applies symmetric logarithmic scaling to the loss values.

        Args:
            loss (Tensor): Tensor of loss values.

        Returns:
            Tensor: Scaled loss values with logarithmic transformation.
        """
        # sign(loss) * log(1 + |loss|): log1p(loss) for non-negative values and
        # -log1p(-loss) for negative values, without evaluating log1p outside its domain
        scaled_loss = tf.sign(loss) * tf.math.log1p(tf.abs(loss))

        return scaled_loss

    def _adjust_loss_weights(self):
        """
        Adjusts the weights of each loss function based on their historical performance
        using the specified weighting method.
        """
        if self.weighting_method == 'var':
            if len(self.loss_histories[0]) >= self.history_size:
                normalized_histories = []
                for history in self.loss_histories:
                    history_tensor = tf.convert_to_tensor(history)
                    normalized_log = self._symmetric_log_scaling(history_tensor)
                    normalized_minmax = self._minmax_scaling(normalized_log)
                    normalized_histories.append(normalized_minmax)

                # Compute variance for each normalized loss history
                variances = [tf.math.reduce_variance(history) for history in normalized_histories]

                if len(variances) == len(self.loss_functions):
                    total_variance = tf.reduce_sum(variances)
                    # Guard against division by zero when every history is constant
                    normalized_weights = [v / (total_variance + K.epsilon()) for v in variances]
                    self.loss_weights = normalized_weights

        elif self.weighting_method in ['bayes', 'mad']:
            if len(self.loss_histories[0]) >= self.history_size:
                normalized_histories = []
                for history in self.loss_histories:
                    history_tensor = tf.convert_to_tensor(history)
                    normalized_log = self._symmetric_log_scaling(history_tensor)
                    normalized_minmax = self._minmax_scaling(normalized_log)
                    normalized_histories.append(normalized_minmax)

                # Compute MAD for each normalized loss history
                mad_values = tf.stack([self._compute_mad(history) for history in normalized_histories])

                if len(mad_values) == len(self.loss_functions):
                    if self.weighting_method == 'bayes':
                        # Bayesian-style weighting: posterior is proportional to likelihood * prior,
                        # with the inverse MAD serving as the likelihood
                        likelihoods = 1.0 / (mad_values + K.epsilon())
                        posteriors = likelihoods * self.priors
                        posteriors /= tf.reduce_sum(posteriors)
                        self.loss_weights = posteriors
                        self.priors = posteriors  # The posterior becomes the prior for the next update
                    else:
                        # MAD-based weighting: inverse MAD normalized
                        self.loss_weights = 1.0 / (mad_values + K.epsilon())
                        self.loss_weights /= tf.reduce_sum(self.loss_weights)

    def call(self, y_true, y_pred):
        """
        Computes the dynamic fused loss.

        Args:
            y_true (Tensor): Ground truth labels.
            y_pred (Tensor): Predicted labels.

        Returns:
            Tensor: The computed dynamic fused loss.
        """
        # Compute individual loss values from each loss function
        loss_values = [loss_function(y_true, y_pred) for loss_function in self.loss_functions]

        if self.training_phase:
            # Update loss histories with current loss values
            self._update_loss_histories(loss_values)
            # Adjust loss weights based on historical performance
            self._adjust_loss_weights()

        # Compute the weighted sum of loss values
        weighted_loss = tf.reduce_sum([w * loss for w, loss in zip(self.loss_weights, loss_values)])

        # Incorporate auxiliary loss if provided
        if self.auxiliary_loss_function is not None:
            auxiliary_loss = self.auxiliary_loss_function(y_true, y_pred)
            # Apply scaling and decay to the auxiliary loss
            weighted_loss += (self.gama * tf.math.exp(-self.decay_rate * self.step)) * auxiliary_loss

            if self.training_phase:
                # Update batch tracking for auxiliary loss
                self._batch_track()

        return weighted_loss
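
The three weighting rules in `_adjust_loss_weights` reduce to simple normalizations once a dispersion statistic is known for each loss. The framework-free sketch below restates them in plain Python. The function names (`variance_weights`, `inverse_mad_weights`, `bayes_update`) and the `EPS` constant are illustrative and not part of the class, and the input numbers are made up rather than measured.

```python
EPS = 1e-7  # stand-in for K.epsilon(); an assumption for this sketch


def variance_weights(variances):
    """'var' rule: each weight is that loss's share of the total variance."""
    total = sum(variances) + EPS  # EPS avoids division by zero in this sketch
    return [v / total for v in variances]


def inverse_mad_weights(mad_values):
    """'mad' rule: weights proportional to 1 / MAD, normalized to sum to 1."""
    inverse = [1.0 / (m + EPS) for m in mad_values]
    total = sum(inverse)
    return [i / total for i in inverse]


def bayes_update(mad_values, priors):
    """'bayes' rule: posterior proportional to (1 / MAD) * prior."""
    unnormalized = [p / (m + EPS) for m, p in zip(mad_values, priors)]
    total = sum(unnormalized)
    return [u / total for u in unnormalized]


if __name__ == "__main__":
    stats = [0.02, 0.05, 0.03]  # hypothetical dispersion values for three losses
    print(variance_weights(stats))
    print(inverse_mad_weights(stats))
    print(bayes_update(stats, priors=[1 / 3, 1 / 3, 1 / 3]))
```

With uniform priors, `bayes_update` returns the same weights as `inverse_mad_weights`. The `'var'` rule favors the loss whose normalized history fluctuates most, while the two MAD-based rules favor the most stable one.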

In [None]:
# A function defining a downsampling block with MaxPooling2D
def downsampling_block(x, filters, kernel_size=3, padding='same'):
    """
    Applies a convolutional downsampling block to the input using MaxPooling.

    Args:
        x (Tensor): Input tensor.
        filters (int): Number of filters for the convolution.
        kernel_size (int, optional): Size of the convolution kernel. Defaults to 3.
        padding (str, optional): Padding strategy, typically 'same'. Defaults to 'same'.

    Returns:
        tuple: (pooled_x, x), where pooled_x is the max-pooled output and x is the pre-pooling
        tensor (after Conv2D, BatchNormalization, and ReLU) used for the skip connection.
    """
    x = layers.Conv2D(filters, kernel_size, padding=padding)(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation('relu')(x)
    pooled_x = layers.MaxPooling2D(pool_size=(2, 2))(x)  # Apply MaxPooling for downsampling
    return pooled_x, x  # Return both the pooled output and the pre-pooled tensor (for skip connection)

# A function defining an upsampling block, which increases the spatial dimensions of the input.
def upsampling_block(x, skip, filters, kernel_size=3, stride=2, padding='same'):
    """
    Applies a convolutional upsampling block to the input and concatenates with skip connections.

    Args:
        x (Tensor): Input tensor.
        skip (Tensor): Skip connection tensor from the downsampling path.
        filters (int): Number of filters for the transposed convolution.
        kernel_size (int, optional): Size of the convolution kernel. Defaults to 3.
        stride (int, optional): Stride for the transposed convolution. Defaults to 2 for upsampling.
        padding (str, optional): Padding strategy, typically 'same'. Defaults to 'same'.

    Returns:
        Tensor: Output tensor after applying Conv2DTranspose, BatchNormalization, and ReLU activation,
        concatenated with the skip connection along the channel axis.
    """
    # Apply a transposed convolution for upsampling
    x = layers.Conv2DTranspose(filters, kernel_size, strides=stride, padding=padding)(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation('relu')(x)
    # Concatenate the skip connection from the downsampling path
    x = layers.Concatenate()([x, skip])
    return x

# Function to create a U-Net model with skip connections
def get_unet_model(img_size, filters, num_classes):
    """
    Builds a U-Net model for image segmentation.

    Args:
        img_size (tuple): Shape of the input image (height, width, channels).
        filters (int): Initial number of filters for the convolutional layers.
        num_classes (int): Number of output classes for the segmentation task.

    Returns:
        Model: An uncompiled U-Net model for segmentation (training uses a custom loop).
    """
    inputs = tf.keras.Input(shape=img_size)

    # Downsampling path
    downsampling_layers = []
    x = inputs
    while x.shape[1] > 8 and x.shape[2] > 8:
        x, skip_connection = downsampling_block(x, filters)  # Get both downsampled and skip connection
        downsampling_layers.append(skip_connection)  # Save the layer for skip connection
        filters *= 2

    # Upsampling path with skip connections
    for skip in reversed(downsampling_layers):
        filters //= 2
        x = upsampling_block(x, skip, filters)

    # Final output layer
    outputs = layers.Conv2D(num_classes, 1, activation='softmax')(x)

    # Create and return the U-Net model
    model = models.Model(inputs, outputs)
    return model
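
The depth of the U-Net above is not fixed; it follows from the input size, because the encoder keeps halving until a side reaches 8 pixels. The sketch below traces the resulting resolutions and channel counts in plain Python, without building any layers. `unet_schedule` and `min_size` are names introduced only for this illustration, and the trace assumes input sides that stay even at every halving.

```python
def unet_schedule(height, width, base_filters, min_size=8):
    """Trace the encoder/decoder stages produced by a halve-until-min_size loop."""
    encoder, filters = [], base_filters
    h, w = height, width
    while h > min_size and w > min_size:
        encoder.append((h, w, filters))  # skip-connection resolution and channels
        h, w = h // 2, w // 2            # 2x2 max pooling halves each dimension
        filters *= 2
    decoder = []
    for skip_h, skip_w, skip_filters in reversed(encoder):
        filters //= 2
        # The transposed convolution emits `filters` channels;
        # concatenation then adds the skip connection's channels
        decoder.append((skip_h, skip_w, filters + skip_filters))
    return encoder, (h, w), decoder


encoder, bottleneck, decoder = unet_schedule(128, 128, 32)
print(encoder)     # [(128, 128, 32), (64, 64, 64), (32, 32, 128), (16, 16, 256)]
print(bottleneck)  # (8, 8)
print(decoder)     # [(16, 16, 512), (32, 32, 256), (64, 64, 128), (128, 128, 64)]
```

For a 128x128 input with 32 base filters this gives four encoder stages and four decoder stages, so the final 1x1 convolution receives 64 channels at full resolution.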

In [None]:
def train_model(model, data_train, data_val, data_test, loss_function, optimizer, metrics, num_epochs):
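    """
    Trains the model with a custom loop, validates after every epoch, and
    evaluates on the test set once training finishes.

    Args:
        model: Keras model to train.
        data_train, data_val, data_test: Batched datasets yielding (inputs, targets) pairs.
        loss_function: Loss object exposing a `training_phase` flag.
        optimizer: Keras optimizer used to apply the gradients.
        metrics (dict): Maps metric names to callables taking (y_true, y_pred).
        num_epochs (int): Number of training epochs.

    Returns:
        Model: The trained model.
    """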

    # Main training loop for the specified number of epochs
    for epoch in range(num_epochs):
        # Initialize metrics for the current epoch
        epoch_metrics = {name: tf.keras.metrics.Mean() for name in metrics.keys()}  # To store per-metric averages
        epoch_loss_avg = tf.keras.metrics.Mean()  # To compute average loss over the epoch

        # Iterate through batches of the training dataset
        for batch_index, (x_batch, y_batch) in enumerate(data_train):
            with tf.GradientTape() as tape:
                # Make predictions with the model
                y_pred = model(x_batch, training=True)
                # Compute the loss for the current batch
                loss_value = loss_function(y_batch, y_pred)

            # Compute gradients and apply them to update the model weights
            grads = tape.gradient(loss_value, model.trainable_variables)
            optimizer.apply_gradients(zip(grads, model.trainable_variables))
            del tape  # Free up memory by deleting the gradient tape

            # Update the average loss and metrics for the current epoch
            epoch_loss_avg.update_state(loss_value)
            for name, metric in metrics.items():
                epoch_metrics[name].update_state(metric(y_batch, y_pred))

            # Log progress for each batch in the current epoch
            log_message = f"[Epoch {epoch + 1:03d} - Batch {batch_index + 1:04d}] "
            log_message += f"Loss: {loss_value.numpy():.4f} | "
            log_message += " | ".join([f"{name.capitalize()}: {metric.result().numpy():.4f}" for name, metric in epoch_metrics.items()])
            print(log_message)

        # Validation phase: evaluate the model on the validation dataset
        val_loss_avg = tf.keras.metrics.Mean()  # To store average validation loss
        val_metrics = {name: tf.keras.metrics.Mean() for name in metrics.keys()}  # Store validation metrics

        # Pause the loss function's training-only state updates (loss histories and dynamic weights)
        loss_function.training_phase = False

        # Iterate over validation dataset
        for x_val, y_val in data_val:
            y_val_pred = model(x_val, training=False)  # Make predictions for validation
            val_loss_value = loss_function(y_val, y_val_pred)  # Compute validation loss

            val_loss_avg.update_state(val_loss_value)
            for name, metric in metrics.items():
                val_metrics[name].update_state(metric(y_val, y_val_pred))

        # Switch loss function back to training mode for the next epoch
        loss_function.training_phase = True

        # Log validation results for the epoch
        log_message = f"[Epoch {epoch + 1:03d} - Validation] "
        log_message += f"Loss: {val_loss_avg.result().numpy():.4f} | "
        log_message += " | ".join([f"{name.capitalize()}: {metric.result().numpy():.4f}" for name, metric in val_metrics.items()])
        print(log_message)

    # Test phase: Evaluate the model on the test dataset after training is completed
    loss_function.training_phase = False  # Disable training-specific behaviors for loss calculation

    test_loss_avg = tf.keras.metrics.Mean()  # To store average test loss
    test_metrics = {name: tf.keras.metrics.Mean() for name in metrics.keys()}  # Store test metrics

    # Iterate over the test dataset
    for x_test, y_test in data_test:
        y_test_pred = model.predict(x_test, verbose=0)  # Make predictions for test data
        test_loss_value = loss_function(y_test, y_test_pred)  # Compute test loss

        test_loss_avg.update_state(test_loss_value)
        for name, metric in metrics.items():
            test_metrics[name].update_state(metric(y_test, y_test_pred))

    # Log the test results
    log_message = "[Test] "
    log_message += f"Loss: {test_loss_avg.result().numpy():.4f} | "
    log_message += " | ".join([f"{name.capitalize()}: {metric.result().numpy():.4f}" for name, metric in test_metrics.items()])
    print(log_message)

    return model  # Return the trained model
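
`train_model` switches `loss_function.training_phase` off and on by hand around each evaluation pass. One possible way to make that toggle harder to forget is a small context manager, sketched below. `evaluation_mode` and `DummyLoss` are hypothetical names used only for this illustration; the only assumption about the real loss object is that it exposes a `training_phase` attribute, as the loop above already relies on.

```python
from contextlib import contextmanager


@contextmanager
def evaluation_mode(loss_function):
    """Temporarily set `training_phase` to False and restore it afterwards."""
    previous = loss_function.training_phase
    loss_function.training_phase = False
    try:
        yield loss_function
    finally:
        # Runs even if evaluation raises, so the flag is never left switched off
        loss_function.training_phase = previous


class DummyLoss:
    """Stand-in for a loss object that exposes a `training_phase` flag."""
    training_phase = True


loss = DummyLoss()
with evaluation_mode(loss):
    print(loss.training_phase)  # False
print(loss.training_phase)      # True
```

Inside the training loop, the validation pass could then sit under `with evaluation_mode(loss_function):` instead of the paired assignments.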

In [None]:
def prepare_breast_dataset():
    """
    Prepares the breast ultrasound dataset by creating necessary directories,
    downloading the dataset from Kaggle, extracting images and masks, and moving
    them to appropriate folders for further use.

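    Steps:
        1. Create necessary directories for the dataset.
        2. Download the dataset from Kaggle using the Kaggle API.
        3. Unzip the dataset into a temporary location.
        4. Move the images from the unzipped folder to a dedicated directory.
        5. Move the masks from the unzipped folder to a dedicated directory.
        6. Clean up the temporary folder after processing.

    Note:
        The absolute paths below assume a Colab runtime whose working directory is /content.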

    Returns:
        None
    """

    # Step 1: Create directories for storing the dataset (images and masks)
    os.makedirs('Dataset', exist_ok=True)  # Create a 'Dataset' folder if it doesn't already exist
    os.makedirs('Breast-Dataset/images', exist_ok=True)  # Create folder for storing images
    os.makedirs('Breast-Dataset/masks', exist_ok=True)  # Create folder for storing masks

    # Step 2: Download the breast ultrasound dataset from Kaggle (force overwrite if exists)
    # The dataset is downloaded as a .zip file.
    !kaggle datasets download -d jocelyndumlao/bus-synthetic-dataset --force

    # Step 3: Unzip the downloaded dataset into the 'Dataset' folder
    # The -o option overwrites existing files, ensuring the latest dataset is used.
    !unzip -o "/content/bus-synthetic-dataset.zip" -d "/content/Dataset"

    # Step 4: Move all image files from the extracted dataset folder to the 'Breast-Dataset/images' directory
    images_primary_path = '/content/Dataset/BUS Synthetic Dataset/BUS_synthetic_dataset/images'  # Path to extracted images
    images_destination_path = '/content/Breast-Dataset/images'  # Destination for images
    for item in os.listdir(images_primary_path):
        # Skip non-PNG files (only move .png image files)
        if not item.endswith('.png'):
            continue
        # Move the image file from the source folder to the destination folder
        source_item = os.path.join(images_primary_path, item)
        destination_item = os.path.join(images_destination_path, item)
        shutil.move(source_item, destination_item)

    # Step 5: Move all mask files from the extracted dataset folder to the 'Breast-Dataset/masks' directory
    masks_primary_path = '/content/Dataset/BUS Synthetic Dataset/BUS_synthetic_dataset/masks'  # Path to extracted masks
    masks_destination_path = '/content/Breast-Dataset/masks'  # Destination for masks
    for item in os.listdir(masks_primary_path):
        # Skip non-PNG files (only move .png mask files)
        if not item.endswith('.png'):
            continue
        # Move the mask file from the source folder to the destination folder
        source_item = os.path.join(masks_primary_path, item)
        destination_item = os.path.join(masks_destination_path, item)
        shutil.move(source_item, destination_item)

    # Step 6: Clean up by removing the 'Dataset' directory after processing
    # Try removing the 'Dataset' folder to save space after moving files
    try:
        shutil.rmtree('Dataset')  # Remove the temporary 'Dataset' directory
    except OSError as e:
        # Print error if the directory cannot be removed
        print(f"Error: 'Dataset' could not be removed. {e}")

In [None]:
prepare_breast_dataset()

In [None]:
def load_image_and_mask(image_path, mask_path):
    """
    Loads an image and its corresponding segmentation mask from file paths.

    Args:
        image_path (str): Path to the image file.
        mask_path (str): Path to the mask file.

    Returns:
        image (Tensor): Loaded image tensor with shape [height, width, 1].
        mask (Tensor): Loaded mask tensor with shape [height, width, 1].
    """
    # Load the image as a raw byte file and decode it into a grayscale PNG image (1 channel)
    image = tf.io.read_file(image_path)
    image = tf.image.decode_png(image, channels=1)

    # Load the mask as a raw byte file and decode it into a grayscale PNG image (1 channel)
    mask = tf.io.read_file(mask_path)
    mask = tf.image.decode_png(mask, channels=1)

    return image, mask


def preprocess_data(image_path, mask_path):
    """
    Preprocesses the image and mask by resizing and normalizing the image,
    and converting the mask to one-hot encoded format.

    Args:
        image_path (str): Path to the image file.
        mask_path (str): Path to the mask file.

    Returns:
        image (Tensor): Preprocessed image tensor (resized, normalized).
        segmentation_mask (Tensor): Preprocessed segmentation mask tensor (resized, one-hot encoded).
    """
    # Load image and mask using helper function
    image, segmentation_mask = load_image_and_mask(image_path, mask_path)

    # Resize the image and mask to the desired input dimensions (input_height, input_width)
    resize_layer_image = keras.layers.Resizing(input_height, input_width, interpolation="bilinear")
    resize_layer_mask = keras.layers.Resizing(input_height, input_width, interpolation="nearest")

    # Apply the resizing layers to both image and mask
    image = resize_layer_image(image)
    segmentation_mask = resize_layer_mask(segmentation_mask)

    # Normalize the image to a range of [0, 1] by dividing pixel values by 255
    image = tf.cast(image, tf.float32) / 255.0

    # Scale the mask to [0, 1] and cast to int32 class indices; the cast truncates, so only pixels equal to 255 map to class 1
    segmentation_mask = tf.cast(segmentation_mask, tf.float32) / 255.0
    segmentation_mask = tf.cast(segmentation_mask, tf.int32)

    # One-hot encode the segmentation mask to have shape [height, width, num_classes]
    segmentation_mask = tf.one_hot(segmentation_mask[..., 0], num_classes)

    return image, segmentation_mask


def create_dataset(image_paths, mask_paths, batch_size):
    """
    Creates a TensorFlow dataset from a list of image and mask paths,
    applies preprocessing, batching, shuffling, and prefetching for efficiency.

    Args:
        image_paths (list of str): List of image file paths.
        mask_paths (list of str): List of mask file paths.
        batch_size (int): The size of batches to create.

    Returns:
        dataset (tf.data.Dataset): A batched and prefetched dataset ready for training.
    """
    # Create a dataset from the image and mask file paths
    dataset = tf.data.Dataset.from_tensor_slices((image_paths, mask_paths))

    # Apply the preprocessing function to the dataset in parallel
    dataset = dataset.map(preprocess_data, num_parallel_calls=tf.data.AUTOTUNE)

    # Batch the dataset, then shuffle; shuffling after batching reorders whole batches, not individual samples
    dataset = dataset.batch(batch_size)
    dataset = dataset.shuffle(buffer_size=1024).prefetch(buffer_size=tf.data.AUTOTUNE)  # Prefetch for faster training

    return dataset
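
The mask handling in `preprocess_data` divides by 255 and then casts to an integer, which truncates rather than rounds. The plain-Python sketch below mirrors that arithmetic for a single pixel, assuming 8-bit masks with the foreground stored as 255. `mask_value_to_class` and `one_hot` are illustrative helpers, not functions from the notebook.

```python
def mask_value_to_class(pixel_value):
    """Mirror the scale-then-truncate step: int(pixel / 255.0)."""
    return int(pixel_value / 255.0)


def one_hot(index, depth):
    """One-hot encode a single class index as a list of floats."""
    return [1.0 if i == index else 0.0 for i in range(depth)]


for value in (0, 127, 254, 255):
    label = mask_value_to_class(value)
    print(value, label, one_hot(label, 2))
```

Only a pixel value of exactly 255 ends up in class 1; any intermediate gray level falls into class 0. Nearest-neighbor resizing, as used for the masks, does not create new gray levels, so a strictly binary mask stays binary.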

In [None]:
# Directory paths where the images and masks are stored
image_dir = "Breast-Dataset/images"  # Path to the image dataset
mask_dir = "Breast-Dataset/masks"    # Path to the mask dataset

# Input image dimensions and number of segmentation classes
input_height = 128  # Height to resize the input images
input_width = 128   # Width to resize the input images
num_classes = 2     # Number of segmentation classes (e.g., background and tumor)

# Hyperparameters for training
batch_size = 32        # Number of samples per batch during training
num_filters = 32       # Initial number of filters in the convolutional layers
learning_rate = 1e-3   # Learning rate for the optimizer
num_epochs = 20        # Number of epochs to train the model

In [None]:
# Get the full paths of image and mask files, ensuring only .png files are included
image_paths = sorted([os.path.join(image_dir, fname) for fname in os.listdir(image_dir) if fname.endswith('.png')])
mask_paths = sorted([os.path.join(mask_dir, fname) for fname in os.listdir(mask_dir) if fname.endswith('.png')])

# Split the dataset into training (70%) and temporary (30%) sets for validation/testing
train_images, temp_images, train_masks, temp_masks = train_test_split(
    image_paths, mask_paths, test_size=0.3, random_state=42)

# Further split the temporary set equally into validation (15%) and test (15%) sets
val_images, test_images, val_masks, test_masks = train_test_split(
    temp_images, temp_masks, test_size=0.5, random_state=42)

# Print the number of images in each subset for confirmation
print(f"Number of training images: {len(train_images)}")
print(f"Number of validation images: {len(val_images)}")
print(f"Number of test images: {len(test_images)}")
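
The two chained `train_test_split` calls above produce a 70/15/15 partition. The sketch below works through the size arithmetic only, assuming the held-out share is rounded up and the remainder goes to the other part. The function name and the dataset size of 1000 are made up for illustration and say nothing about the actual number of images.

```python
import math


def two_stage_split_sizes(n_samples, holdout_fraction=0.3, test_share=0.5):
    """Return (train, val, test) sizes for a split followed by a split of the holdout."""
    n_holdout = math.ceil(holdout_fraction * n_samples)  # assumption: held-out size rounds up
    n_train = n_samples - n_holdout
    n_test = math.ceil(test_share * n_holdout)
    n_val = n_holdout - n_test
    return n_train, n_val, n_test


train, val, test = two_stage_split_sizes(1000)  # 1000 is a made-up dataset size
print(train, val, test)
```

The three sizes always add up to the original count, and the validation and test sets differ by at most one sample.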

In [None]:
# Create TensorFlow datasets for training, validation, and testing by processing the respective image and mask paths
data_train = create_dataset(train_images, train_masks, batch_size)  # Create training dataset
data_val = create_dataset(val_images, val_masks, batch_size)        # Create validation dataset
data_test = create_dataset(test_images, test_masks, batch_size)     # Create testing dataset

In [None]:
# Get the next batch of images and masks from the training dataset
images, masks = next(iter(data_train))

# Convert the first 16 images and masks to NumPy arrays and ensure the data type is float32
images_np = images[:16].numpy().astype("float32")
masks_np = masks[:16].numpy().astype("float32")

# Create a 4x4 grid of subplots for visualization
fig, axs = plt.subplots(4, 4, figsize=(10, 10))

# Loop through the first 16 images to display them along with their masks
for i in range(16):
    ax = axs[i // 4, i % 4]  # Get the appropriate subplot axis
    ax.imshow(images_np[i], cmap="magma")  # Display the image using the 'magma' colormap
    ax.imshow(np.argmax(masks_np[i].squeeze(), axis=-1), cmap="gray", alpha=0.5)  # Overlay the mask with transparency
    ax.axis("off")  # Hide the axis ticks and labels

# Adjust layout to prevent overlap and show the plot
plt.tight_layout()
plt.show()

In [None]:
# Create a U-Net model with the specified input dimensions, number of filters, and number of output classes
model = get_unet_model((input_height, input_width, 1), filters=num_filters, num_classes=num_classes)

# Display the model architecture summary, including layer details and the total number of parameters
model.summary()

In [None]:
# Initialize the custom loss function for training
loss_function = DynamicMemoryFusionLoss(
    # List of loss functions to be combined in the final loss computation
    loss_functions=[
        tf.keras.losses.CategoricalCrossentropy(),     # Standard categorical cross-entropy loss
        iou_loss,                                      # Intersection over Union (IoU) loss
        dice_loss                                      # Dice coefficient loss
    ],
    history_size=50,                                   # Number of past loss values to consider for dynamic weighting
    history_size_max=150,                              # Maximum size of the loss history for tracking purposes
    batch_track_steps=100,                             # Number of batches to track for auxiliary loss function decay rate updates
    auxiliary_loss_function=class_balanced_dice_loss,  # Auxiliary loss function to address class imbalance
    gama=5,                                            # Hyperparameter for controlling the contribution of the auxiliary loss
    decay_rate=0.05,                                   # Decay rate for the auxiliary loss over time
    weighting_method='var',                            # Method used for weighting losses; in this case, based on variance
    training_phase=True                                # Flag indicating whether the model is in training mode
)


# Initialize the Adam optimizer with a specified learning rate for model training
optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)
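
The auxiliary loss configured above is scaled by `gama * exp(-decay_rate * step)`, so its influence fades as `step` grows. The sketch below evaluates that coefficient for the values used in this cell. How `step` advances is governed by `_batch_track` and `batch_track_steps`, which this sketch does not model; it treats `step` as a plain integer counter, and the helper names are illustrative.

```python
import math

GAMA = 5.0         # matches the `gama` argument above
DECAY_RATE = 0.05  # matches the `decay_rate` argument above


def auxiliary_coefficient(step, gama=GAMA, decay_rate=DECAY_RATE):
    """Scale applied to the auxiliary loss: gama * exp(-decay_rate * step)."""
    return gama * math.exp(-decay_rate * step)


def steps_until_below(threshold, gama=GAMA, decay_rate=DECAY_RATE):
    """Smallest integer step at which the coefficient drops below `threshold`."""
    return max(0, math.floor(math.log(gama / threshold) / decay_rate) + 1)


for step in (0, 10, 50, 100):
    print(step, round(auxiliary_coefficient(step), 4))
print(steps_until_below(0.01))
```

The coefficient starts at `gama` and halves roughly every `ln(2) / decay_rate` steps, about 14 steps for the settings shown.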

In [None]:
# Define a dictionary to store various evaluation metrics for model performance
metrics = {
    "dice": dice,                               # Dice coefficient metric for measuring overlap between predicted and true masks
    "iou": iou,                                 # Intersection over Union metric for evaluating segmentation accuracy
    "f1_score": f1_score,                       # F1 Score to balance precision and recall
    "precision": precision,                     # Precision metric to evaluate the correctness of positive predictions
    "recall": recall,                           # Recall metric to assess the model's ability to identify true positives
    "cb-dice": class_balanced_dice_score        # Class Balanced Dice Score for handling class imbalances in segmentation tasks
}

In [None]:
"""
 Train the model using the specified training, validation, and test datasets,
 along with the defined loss function, optimizer, metrics, and number of epochs.
"""
model = train_model(model=model,                   # The model to be trained
                    data_train=data_train,         # Training dataset containing image-mask pairs
                    data_val=data_val,             # Validation dataset for evaluating model performance during training
                    data_test=data_test,           # Test dataset for final evaluation after training
                    loss_function=loss_function,   # Custom loss function to guide the training process
                    optimizer=optimizer,           # Optimizer to update model weights during training
                    metrics=metrics,               # Dictionary of metrics for tracking model performance
                    num_epochs=num_epochs)         # Number of epochs for which to train the model

In [None]:
# Retrieve a batch of images and corresponding masks from the test dataset
images, masks = next(iter(data_test))

# Convert the first 4 images to NumPy arrays and ensure they are of type float32
images_np = images[:4].numpy().astype("float32")

# Generate predictions for the input images using the trained model
pred_masks = model.predict(images, verbose=0)

# Create a 2x2 grid of subplots for visualization
fig, axs = plt.subplots(2, 2, figsize=(10, 10))

# Loop through the first 4 images and display them along with their predicted masks
for i in range(4):
    ax = axs[i // 2, i % 2]  # Determine subplot position
    ax.imshow(images_np[i], cmap="magma")  # Display the original image
    ax.imshow(np.argmax(pred_masks[i].squeeze(), axis=-1), cmap="gray", alpha=0.5)  # Overlay the predicted mask
    ax.axis("off")  # Hide the axis ticks and labels

# Adjust the layout to prevent overlap and display the plots
plt.tight_layout()
plt.show()

In [None]:
# Define hyperparameters and settings for the model training
num_classes = 4                # Depth of the one-hot masks (TFDS pet mask labels are 1, 2, 3; index 0 is included as well)
input_height = 128             # Height of the input images
input_width = 128              # Width of the input images
learning_rate = 1e-3           # Learning rate for the optimizer
num_epochs = 20                # Number of epochs for training
batch_size = 32                # Number of samples per batch
num_filters = 32               # Initial number of filters in the convolutional layers
shuffle = True                 # Flag to indicate whether to shuffle the dataset during training

In [None]:
"""
 Load the Oxford-IIIT Pet dataset from TensorFlow Datasets (TFDS).
 The official training split is divided 80/20 into training and validation sets, and the official test split is used for testing.
 Examples are loaded in batches of the specified size, and the order of the input files can be shuffled.
"""
(data_train, data_val, data_test) = tfds.load(
    name = "oxford_iiit_pet",                        # Name of the dataset to load
    split = ["train[:80%]", "train[80%:]", "test"],  # Splits the training set into 80% for training and 20% for validation
    batch_size = batch_size,                         # Number of samples per batch
    shuffle_files = shuffle,                         # Flag to indicate whether to shuffle the dataset files
)


In [None]:
def preprocess_data(section):
    # Extract the image and segmentation mask from the input section
    image = section["image"]
    segmentation_mask = section["segmentation_mask"]

    # Create resizing layers for both image and segmentation mask
    resize_layer_image = keras.layers.Resizing(input_height, input_width, interpolation="bilinear")
    resize_layer_mask = keras.layers.Resizing(input_height, input_width, interpolation="nearest")

    # Resize the image and segmentation mask
    image = resize_layer_image(image)
    segmentation_mask = resize_layer_mask(segmentation_mask)

    # Normalize the image to the range [0, 1]
    image = tf.cast(image, tf.float32) / 255.0

    # Convert the segmentation mask to one-hot encoded format
    # Use the last channel of the mask and cast to int32 for one-hot encoding
    segmentation_mask = tf.one_hot(tf.cast(segmentation_mask[..., -1], tf.int32), num_classes)

    return image, segmentation_mask

# Preprocess the training dataset, then shuffle and prefetch for better performance
data_train = (
    data_train.map(preprocess_data, num_parallel_calls=tf.data.AUTOTUNE)  # Apply preprocessing function
    .shuffle(buffer_size=1024)  # Shuffle the order of the (already batched) elements
    .prefetch(buffer_size=tf.data.AUTOTUNE)  # Let tf.data pick the prefetch depth instead of buffering up to 1024 batches
)

# Preprocess the validation dataset in the same way as the training dataset
data_val = (
    data_val.map(preprocess_data, num_parallel_calls=tf.data.AUTOTUNE)  # Apply preprocessing function
    .shuffle(buffer_size=1024)  # Shuffle the order of the (already batched) elements
    .prefetch(buffer_size=tf.data.AUTOTUNE)  # Let tf.data pick the prefetch depth instead of buffering up to 1024 batches
)

# Preprocess the test dataset in the same way as the training dataset
data_test = (
    data_test.map(preprocess_data, num_parallel_calls=tf.data.AUTOTUNE)  # Apply preprocessing function
    .shuffle(buffer_size=1024)  # Shuffle the order of the (already batched) elements
    .prefetch(buffer_size=tf.data.AUTOTUNE)  # Let tf.data pick the prefetch depth instead of buffering up to 1024 batches
)
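
The pipelines above call `.shuffle` on datasets that were already batched when they were loaded, so the shuffle reorders batches rather than individual samples. The pure-Python analogy below contrasts the two orderings. It does not use `tf.data`; `batch` and `shuffled` are illustrative helpers, and the integers stand in for samples.

```python
import random


def batch(items, batch_size):
    """Group consecutive items into lists of length `batch_size`."""
    return [items[i:i + batch_size] for i in range(0, len(items), batch_size)]


def shuffled(items, seed):
    """Return a shuffled copy using a local, seeded random generator."""
    copy = list(items)
    random.Random(seed).shuffle(copy)
    return copy


samples = list(range(8))
batch_then_shuffle = shuffled(batch(samples, 2), seed=0)
shuffle_then_batch = batch(shuffled(samples, seed=0), 2)

print(batch_then_shuffle)  # batches change order, but each still holds a consecutive pair
print(shuffle_then_batch)  # samples are mixed across batches
```

If sample-level mixing matters for training, one option is to shuffle before batching, at the cost of a shuffle buffer that holds individual examples.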

In [None]:
# Get a batch of images and masks from the training dataset
images, masks = next(iter(data_train))

# Convert the first 16 images and masks to NumPy arrays with float32 type
images_np = images[:16].numpy().astype("float32")
masks_np = masks[:16].numpy().astype("float32")

# Create a 4x4 grid for displaying images and their corresponding masks
fig, axs = plt.subplots(4, 4, figsize=(10, 10))

# Iterate over the first 16 images and masks to display them
for i in range(16):
    ax = axs[i // 4, i % 4]  # Get the appropriate subplot axis
    ax.imshow(images_np[i])  # Display the image
    # Overlay the ground-truth mask on top of the image using a colormap
    ax.imshow(np.argmax(masks_np[i].squeeze(), axis=-1), cmap="turbo", alpha=0.5)
    ax.axis("off")  # Hide the axis for better visualization

# Adjust layout to avoid overlap and show the plot
plt.tight_layout()
plt.show()

In [None]:
print(f"Input image shape: {np.shape(images)[1:]}")
print(f"Mask shape: {np.shape(masks)[1:]}")
print(f"Number of classes: {np.shape(masks)[-1]}")

In [None]:
# Create a U-Net model with the specified input dimensions, number of filters, and number of output classes
model = get_unet_model((input_height, input_width, 3), filters=num_filters, num_classes=num_classes)

# Display the model architecture summary, including layer details and the total number of parameters
model.summary()

In [None]:
# Initialize the custom loss function for training
loss_function = DynamicMemoryFusionLoss(
    # List of loss functions to be combined in the final loss computation
    loss_functions=[
        tf.keras.losses.CategoricalCrossentropy(),     # Standard categorical cross-entropy loss
        iou_loss,                                      # Intersection over Union (IoU) loss
        dice_loss                                      # Dice coefficient loss
    ],
    history_size=100,                                  # Number of past loss values to consider for dynamic weighting
    history_size_max=250,                              # Maximum size of the loss history for tracking purposes
    batch_track_steps=100,                             # Number of batches to track for auxiliary loss function decay rate updates
    auxiliary_loss_function=class_balanced_dice_loss,  # Auxiliary loss function to address class imbalance
    gama=5,                                            # Hyperparameter for controlling the contribution of the auxiliary loss
    decay_rate=0.05,                                   # Decay rate for the auxiliary loss over time
    weighting_method='var',                            # Method used for weighting losses; in this case, based on variance
    training_phase=True                                # Flag indicating whether the model is in training mode
)


# Initialize the Adam optimizer with a specified learning rate for model training
optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)

In [None]:
# Define a dictionary to store various evaluation metrics for model performance
metrics = {
    "dice": dice,                               # Dice coefficient metric for measuring overlap between predicted and true masks
    "iou": iou,                                 # Intersection over Union metric for evaluating segmentation accuracy
    "f1_score": f1_score,                       # F1 Score to balance precision and recall
    "precision": precision,                     # Precision metric to evaluate the correctness of positive predictions
    "recall": recall,                           # Recall metric to assess the model's ability to identify true positives
    "cb-dice": class_balanced_dice_score        # Class Balanced Dice Score for handling class imbalances in segmentation tasks
}

In [None]:
"""
 Train the model using the specified training, validation, and test datasets,
 along with the defined loss function, optimizer, metrics, and number of epochs.
"""
model = train_model(model=model,                   # The model to be trained
                    data_train=data_train,         # Training dataset containing image-mask pairs
                    data_val=data_val,             # Validation dataset for evaluating model performance during training
                    data_test=data_test,           # Test dataset for final evaluation after training
                    loss_function=loss_function,   # Custom loss function to guide the training process
                    optimizer=optimizer,           # Optimizer to update model weights during training
                    metrics=metrics,               # Dictionary of metrics for tracking model performance
                    num_epochs=num_epochs)         # Number of epochs for which to train the model

In [None]:
# Get a batch of images and masks from the test dataset
images, masks = next(iter(data_test))

# Convert the first 16 images to NumPy arrays with float32 type
images_np = images[:16].numpy().astype("float32")
# Generate predicted masks using the trained model on the test images
pred_masks = model.predict(images, verbose=0)

# Create a 4x4 grid for displaying images and their corresponding predicted masks
fig, axs = plt.subplots(4, 4, figsize=(10, 10))

# Iterate over the first 16 images and their predicted masks for display
for i in range(16):
    ax = axs[i // 4, i % 4]  # Get the appropriate subplot axis
    ax.imshow(images_np[i])  # Display the original image
    # Overlay the predicted mask on top of the image using a colormap
    ax.imshow(np.argmax(pred_masks[i].squeeze(), axis=-1), cmap="turbo", alpha=0.5)
    ax.axis("off")  # Hide the axis for better visualization

# Adjust layout to avoid overlap and show the plot
plt.tight_layout()
plt.show()