In [None]:
import numpy as np
import os
import gc
import cv2
import re
from tensorflow.keras.utils import to_categorical

# ✅ Input geometry: every image and mask is resized to 224x224
IMG_HEIGHT, IMG_WIDTH = 224, 224
CHANNELS = 3      # RGB images
NUM_CLASSES = 4   # Background (0), Brain (1), CSP (2), LV (3)

# ✅ Mask colour (RGB) -> integer class index
CLASS_MAP = {
    (255, 0, 0): 1,   # Brain
    (0, 255, 0): 2,   # CSP
    (0, 0, 255): 3,   # LV
    (0, 0, 0): 0,     # Background
}

# Augmented source data (NOTE(review): not referenced by the code visible here)
image_dir = r"D:\augmented_dataset\images"
mask_dir = r"D:\augmented_dataset\masks"

# ✅ Pre-split destination directories (train / val / test)
train_image_dir = r"D:\Updated\train\images"
train_mask_dir = r"D:\Updated\train\masks"

val_image_dir = r"D:\Updated\val\images"
val_mask_dir = r"D:\Updated\val\masks"

test_image_dir = r"D:\Updated\test\images"
test_mask_dir = r"D:\Updated\test\masks"

# ✅ Natural ("human") ordering for file names
def natural_sort_key(s):
    """Return a sort key that compares digit runs by numeric value, case-insensitively.

    For example "img2.png" sorts before "img10.png", whereas plain string
    sorting would put "img10.png" first.
    """
    key = []
    for chunk in re.split(r'(\d+)', s):
        key.append(int(chunk) if chunk.isdigit() else chunk.lower())
    return key

# ✅ RGB colour mask -> single-channel class-index mask
def rgb_to_class(mask_array):
    """Map an (H, W, 3) RGB mask to an (H, W) uint8 mask of class indices.

    Each pixel whose colour exactly equals a key of CLASS_MAP receives that
    key's class index; pixels matching no colour stay 0 (background).
    """
    h, w, _ = mask_array.shape
    labels = np.zeros((h, w), dtype=np.uint8)

    for colour, label in CLASS_MAP.items():
        # exact match on all three channels
        hit = (mask_array == colour).all(axis=-1)
        labels[hit] = label

    return labels

# ✅ Preprocess a pre-split image/mask directory pair at 224x224
def preprocess_filtered_dataset(image_dir, mask_dir):
    """Load an image directory and its mask directory into model-ready arrays.

    Images and masks are paired by position after natural-sorting each
    directory listing, so both directories must hold the same number of files
    whose sorted orders correspond.

    Images: BGR -> RGB, resized to (IMG_HEIGHT, IMG_WIDTH), scaled to [0, 1].
    Masks: BGR -> RGB, resized with nearest-neighbour interpolation, mapped to
    class indices with ``rgb_to_class`` and one-hot encoded to NUM_CLASSES.

    Args:
        image_dir: directory containing the input images.
        mask_dir: directory containing the RGB colour-coded masks.

    Returns:
        (X, y): float32 arrays of shape (n, IMG_HEIGHT, IMG_WIDTH, CHANNELS) and
        (n, IMG_HEIGHT, IMG_WIDTH, NUM_CLASSES), where n counts only the pairs
        in which both files could be decoded.

    Raises:
        ValueError: if the two directories hold a different number of files;
            positional pairing would then silently misalign images and masks.
    """
    image_filenames = sorted(os.listdir(image_dir), key=natural_sort_key)
    mask_filenames = sorted(os.listdir(mask_dir), key=natural_sort_key)

    # zip() would silently truncate and shift every later pair; fail loudly instead.
    if len(image_filenames) != len(mask_filenames):
        raise ValueError(
            f"Found {len(image_filenames)} files in {image_dir} but "
            f"{len(mask_filenames)} in {mask_dir}; images and masks cannot be paired by order."
        )

    num_pairs = len(image_filenames)

    # ✅ Pre-allocate outputs; rows of skipped pairs are trimmed before returning
    X = np.zeros((num_pairs, IMG_HEIGHT, IMG_WIDTH, CHANNELS), dtype=np.float32)
    y = np.zeros((num_pairs, IMG_HEIGHT, IMG_WIDTH, NUM_CLASSES), dtype=np.float32)  # One-hot encoded masks

    print(f"🚀 Processing {num_pairs} filtered images and masks...")

    n_loaded = 0
    for idx, (img_file, mask_file) in enumerate(zip(image_filenames, mask_filenames)):
        if idx % 100 == 0:
            print(f"✅ Processed {idx}/{num_pairs} images")

        img_path = os.path.join(image_dir, img_file)
        mask_path = os.path.join(mask_dir, mask_file)

        # cv2.imread returns None (instead of raising) for unreadable / non-image files
        img = cv2.imread(img_path)
        mask = cv2.imread(mask_path)
        if img is None or mask is None:
            print(f"⚠️ Skipping {img_file}: could not read image or mask ({img_path}, {mask_path})")
            continue

        # ✅ Image: BGR -> RGB, resize (cv2 takes (width, height)), scale to [0, 1]
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = cv2.resize(img, (IMG_WIDTH, IMG_HEIGHT))
        X[n_loaded] = img.astype(np.float32) / 255.0

        # ✅ Mask: nearest-neighbour resize keeps colours exact so rgb_to_class can match them
        mask = cv2.cvtColor(mask, cv2.COLOR_BGR2RGB)
        mask = cv2.resize(mask, (IMG_WIDTH, IMG_HEIGHT), interpolation=cv2.INTER_NEAREST)
        y[n_loaded] = to_categorical(rgb_to_class(mask), num_classes=NUM_CLASSES)
        n_loaded += 1
        # No per-image gc.collect(): the temporaries are simply rebound next
        # iteration, and a full collection per image only slows the loop.

    return X[:n_loaded], y[:n_loaded]

from sklearn.model_selection import train_test_split

# ✅ Load each pre-split directory pair (train, then val, then test)
X_train, y_train = preprocess_filtered_dataset(image_dir=train_image_dir, mask_dir=train_mask_dir)
X_val, y_val = preprocess_filtered_dataset(image_dir=val_image_dir, mask_dir=val_mask_dir)
X_test, y_test = preprocess_filtered_dataset(image_dir=test_image_dir, mask_dir=test_mask_dir)

# ✅ Report the array shapes of every split
print("\n‚úÖ Dataset Splits:")
print(f"  - Training set: {X_train.shape}, {y_train.shape}")
print(f"  - Validation set: {X_val.shape}, {y_val.shape}")
print(f"  - Test set: {X_test.shape}, {y_test.shape}")

In [None]:
from tensorflow.keras.preprocessing.image import ImageDataGenerator

# Data Augmentation configuration for the training set.
# NOTE(review): `train_datagen` is not used by `create_train_generator`, which
# builds its own image/mask generators with different ranges.
train_datagen = ImageDataGenerator(
    rotation_range=30,
    width_shift_range=0.2,
    height_shift_range=0.2,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
)

# No train_datagen.fit(X_train): fit() only computes data statistics for
# featurewise_center / featurewise_std_normalization / zca_whitening, none of
# which is enabled here, so it would only cost time and a temporary copy of X_train.

In [None]:
# Sanity check: which class indices actually occur in each split's one-hot masks
print("Unique values in y_train:", np.unique(y_train.argmax(axis=-1)))
print("Unique values in y_val:", np.unique(y_val.argmax(axis=-1)))
print("Unique values in y_test:", np.unique(y_test.argmax(axis=-1)))

In [None]:
import tensorflow as tf
from tensorflow.keras import layers, Model, Input
from tensorflow.keras.layers import Conv2D, BatchNormalization, Activation, Dense
from tensorflow.keras.layers import GlobalAveragePooling2D, Reshape, Concatenate, Dropout

# Model input geometry and number of output classes (same values as the data cell)
IMG_HEIGHT, IMG_WIDTH, CHANNELS = 224, 224, 3
NUM_CLASSES = 4  # Brain, CSP, LV, Background

# =========================================================================
# Helper functions for purely functional approach (no custom layer classes)
# =========================================================================

def convolution_block(inputs, filters, kernel_size=3, strides=1, padding='same', use_bias=False, dilation_rate=1):
    """Conv2D -> BatchNormalization -> GELU.

    Args:
        inputs: 4-D feature tensor.
        filters: number of output channels.
        kernel_size, strides, padding, dilation_rate: forwarded to Conv2D.
        use_bias: Conv2D bias (off by default since BatchNormalization follows).

    Returns:
        The activated feature map.
    """
    conv = Conv2D(filters, kernel_size, strides=strides, padding=padding,
                  use_bias=use_bias, dilation_rate=dilation_rate)
    features = conv(inputs)
    features = BatchNormalization()(features)
    return Activation('gelu')(features)

def mlp_block(inputs, hidden_dim, dropout_rate=0.0):
    """Transformer feed-forward network: Dense(GELU) -> Dropout -> Dense -> Dropout.

    The output width equals the last dimension of ``inputs`` so the result can
    be added back residually.
    """
    out_dim = inputs.shape[-1]
    hidden = Dense(hidden_dim, activation='gelu')(inputs)
    hidden = Dropout(dropout_rate)(hidden)
    projected = Dense(out_dim)(hidden)
    return Dropout(dropout_rate)(projected)

def window_attention_block(inputs, num_heads):
    """Multi-head self-attention over all spatial positions of a (batch, H, W, C) map.

    NOTE: despite the name, no window partitioning is done; every one of the
    H*W positions attends to every other, so the attention map is
    (H*W) x (H*W) per head and memory grows quadratically with resolution.

    Args:
        inputs: 4-D feature tensor with static H, W and C.
        num_heads: number of attention heads; key_dim is C // num_heads.

    Returns:
        Tensor with the same (H, W, C) shape as ``inputs``.
    """
    shape = tf.keras.backend.int_shape(inputs)
    height, width, channels = shape[1], shape[2], shape[3]

    # (batch, H, W, C) -> (batch, H*W, C) token sequence
    tokens = Reshape((-1, channels))(inputs)

    mha = layers.MultiHeadAttention(num_heads=num_heads, key_dim=channels // num_heads)
    attended = mha(tokens, tokens)  # self-attention: query == value

    return Reshape((height, width, channels))(attended)

def swin_transformer_block(inputs, dim, num_heads, mlp_ratio=4.0, dropout_rate=0.0):
    """Pre-norm transformer block on a 4-D feature map.

    x = x + Attention(LayerNorm(x)); x = x + MLP(LayerNorm(x)).
    ``dim`` must equal the channel count of ``inputs``; the MLP hidden width
    is ``int(dim * mlp_ratio)``.
    """
    # Attention sub-layer with residual connection
    attn_in = layers.LayerNormalization(epsilon=1e-5)(inputs)
    x = layers.add([inputs, window_attention_block(attn_in, num_heads)])

    # Feed-forward sub-layer: flatten to tokens, apply MLP, restore spatial layout
    ffn_in = layers.LayerNormalization(epsilon=1e-5)(x)
    ffn = Reshape((-1, dim))(ffn_in)
    ffn = mlp_block(ffn, int(dim * mlp_ratio), dropout_rate)
    spatial_dims = tf.keras.backend.int_shape(ffn_in)[1:-1]
    ffn = Reshape(spatial_dims + (dim,))(ffn)

    return layers.add([x, ffn])

def downsample_block(inputs, out_dim):
    """LayerNorm then a 2x2 stride-2 conv: halves H and W and projects to ``out_dim`` channels."""
    normed = layers.LayerNormalization(epsilon=1e-5)(inputs)
    return Conv2D(out_dim, kernel_size=2, strides=2, padding='same')(normed)

def swin_transformer_stage(inputs, dim, depth, num_heads, downsample=True):
    """Stack ``depth`` transformer blocks, optionally followed by a downsample.

    Returns:
        (x, stage_output): ``stage_output`` is the feature map after the blocks;
        ``x`` is the same tensor, or its downsampled version (2x fewer rows and
        columns, ``dim * 2`` channels) when ``downsample`` is True.
    """
    x = inputs
    for _ in range(depth):
        x = swin_transformer_block(x, dim=dim, num_heads=num_heads,
                                   mlp_ratio=4.0, dropout_rate=0.1)

    stage_output = x
    if downsample:
        x = downsample_block(x, dim * 2)
    return x, stage_output

def patch_embedding(inputs, embed_dim=128):
    """Embed non-overlapping 4x4 patches: stride-4 Conv2D -> BatchNormalization -> GELU."""
    layer_stack = [
        Conv2D(embed_dim, kernel_size=4, strides=4, padding='same'),
        BatchNormalization(),
        Activation('gelu'),
    ]
    x = inputs
    for layer in layer_stack:
        x = layer(x)
    return x

def ASPP(inputs):
    """Atrous Spatial Pyramid Pooling head (DeepLabV3+ style).

    Branches, each 512 channels: a 1x1 conv, 3x3 convs with dilation 6/12/18/24,
    and an image-pooling branch broadcast back to the input grid. They are
    concatenated and fused by a 1x1 conv followed by a 3x3 conv.

    NOTE: with a 224x224 input the encoder output is 7x7, so for dilation >= 7
    every off-centre tap falls in the zero padding and those branches act as
    1x1 convolutions.
    """
    shape = tf.keras.backend.int_shape(inputs)

    branches = [convolution_block(inputs, 512, kernel_size=1)]
    for rate in (6, 12, 18, 24):
        branches.append(convolution_block(inputs, 512, kernel_size=3, dilation_rate=rate))

    # Image-level context: global pool -> 1x1 conv -> upsample to (H, W)
    pooled = GlobalAveragePooling2D()(inputs)
    pooled = Reshape((1, 1, shape[3]))(pooled)
    pooled = convolution_block(pooled, 512, kernel_size=1)
    pooled = layers.UpSampling2D(size=(shape[1], shape[2]), interpolation='bilinear')(pooled)
    branches.append(pooled)

    merged = Concatenate()(branches)
    merged = convolution_block(merged, 512, kernel_size=1)
    return convolution_block(merged, 512, kernel_size=3)

def build_deeplabv3_plus_swinv2_functional(input_shape, num_classes):
    """DeepLabV3+-style segmentation network on a Swin-like transformer encoder.

    Encoder: stride-4 patch embedding, then four transformer stages with
    depths [2, 2, 18, 2] and heads [4, 8, 16, 32]; each of the first three
    stages ends with a stride-2 downsample that doubles the channel count.
    Decoder: ASPP on the last stage, upsampled and concatenated with the
    (1x1-projected) patch-embedding output, refined, then upsampled back to
    the input resolution.

    Args:
        input_shape: (height, width, channels) of the input images.
        num_classes: number of output channels of the final softmax.

    Returns:
        An uncompiled Keras ``Model`` mapping images to per-pixel class probabilities.
    """
    inputs = Input(shape=input_shape)

    # SwinV2-Base-like configuration
    embed_dim = 128
    depths = [2, 2, 18, 2]
    num_heads = [4, 8, 16, 32]

    # Stride-4 patch embedding (56x56 for a 224x224 input), reused as the skip connection
    x = patch_embedding(inputs, embed_dim=embed_dim)
    low_level_features = x

    current_dim = embed_dim
    last_stage = len(depths) - 1
    for i, (depth, heads) in enumerate(zip(depths, num_heads)):
        # The pre-downsampling stage output is not used by this decoder.
        x, _ = swin_transformer_stage(
            x, dim=current_dim, depth=depth,
            num_heads=heads, downsample=(i < last_stage)
        )
        if i < last_stage:
            current_dim *= 2

    x = ASPP(x)

    low_level_features = convolution_block(low_level_features, 128, kernel_size=1)
    low_level_shape = tf.keras.backend.int_shape(low_level_features)

    # kernel_size == strides: with kernel 3 and stride 4, one row and one column
    # of every 4x4 output tile got no kernel contribution (bias only), leaving a
    # regular grid artefact. A 4x4 kernel covers each output pixel exactly once.
    x = layers.Conv2DTranspose(512, kernel_size=4, strides=4, padding='same')(x)
    # Stride-4 upsampling of the stride-32 map lands at stride 8; resize to the skip size.
    x = layers.Resizing(low_level_shape[1], low_level_shape[2])(x)

    x = Concatenate()([x, low_level_features])

    # Decoder refinement
    x = convolution_block(x, 512, kernel_size=3)
    x = convolution_block(x, 512, kernel_size=3)
    x = convolution_block(x, 256, kernel_size=3)

    # Back to input resolution (kernel == stride for full coverage, as above)
    x = layers.Conv2DTranspose(256, kernel_size=4, strides=4, padding='same')(x)
    x = layers.Resizing(input_shape[0], input_shape[1])(x)
    x = convolution_block(x, 128, kernel_size=3)

    outputs = Conv2D(num_classes, kernel_size=1, padding='same', activation='softmax')(x)

    return Model(inputs=inputs, outputs=outputs)

# Build the segmentation model for 224x224 RGB input and show its layer summary
model = build_deeplabv3_plus_swinv2_functional(
    num_classes=NUM_CLASSES,
    input_shape=(IMG_HEIGHT, IMG_WIDTH, CHANNELS),
)
model.summary()

In [None]:
import tensorflow as tf
from tensorflow.keras import backend as K
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau, ModelCheckpoint
from tensorflow.keras.preprocessing.image import ImageDataGenerator

# ✅ Dice coefficient (TensorFlow)
def dice_coefficient(y_true, y_pred):
    """Soft Dice over all pixels and channels of each sample, averaged over the batch.

    NOTE(review): this name is redefined further down the notebook by a NumPy
    version, and this TF version is not passed to model.compile in the code shown.
    """
    eps = 1e-15
    truth = tf.cast(y_true, tf.float32)
    pred = tf.cast(y_pred, tf.float32)
    per_sample_axes = [1, 2, 3]
    overlap = tf.reduce_sum(truth * pred, axis=per_sample_axes)
    total = tf.reduce_sum(truth, axis=per_sample_axes) + tf.reduce_sum(pred, axis=per_sample_axes)
    return tf.reduce_mean((2. * overlap + eps) / (total + eps))

# ✅ Weighted Categorical Crossentropy
def weighted_categorical_crossentropy(y_true, y_pred):
    """Pixel-wise categorical cross-entropy, each pixel weighted by its true class.

    Class weights are indexed by class id (0 background, 1 brain, 2 CSP, 3 LV).
    They look like inverse-frequency ("balanced") weights, so the mean weight
    over all pixels is about 1.

    Args:
        y_true: one-hot targets with the class axis last.
        y_pred: class probabilities with the same shape.

    Returns:
        Scalar mean of the weighted per-pixel losses.
    """
    class_weights = tf.constant([0.3776, 0.7605, 65.8554, 46.2381], dtype=tf.float32)
    y_true = tf.cast(y_true, tf.float32)
    y_pred = tf.clip_by_value(tf.cast(y_pred, tf.float32), K.epsilon(), 1.0)

    # Per-pixel cross-entropy: -sum_c y_true_c * log(p_c)
    loss = -tf.reduce_sum(y_true * tf.math.log(y_pred), axis=-1)

    # Pick each pixel's ground-truth class weight. Previously the weight vector
    # itself was summed, so every pixel got the same constant (~113.2) and the
    # rare classes (CSP, LV) were not up-weighted at all.
    pixel_weights = tf.reduce_sum(y_true * class_weights, axis=-1)
    return tf.reduce_mean(loss * pixel_weights)

# ✅ Dice loss
def dice_loss(y_true, y_pred):
    """1 - mean soft Dice, where Dice is computed per sample and per class over (H, W)."""
    eps = 1e-6
    target = tf.cast(y_true, y_pred.dtype)
    spatial_axes = [1, 2]
    overlap = tf.reduce_sum(target * y_pred, axis=spatial_axes)
    denom = tf.reduce_sum(target, axis=spatial_axes) + tf.reduce_sum(y_pred, axis=spatial_axes)
    per_class_dice = (2. * overlap + eps) / (denom + eps)
    return 1 - tf.reduce_mean(per_class_dice)

# ✅ Combined loss
def combined_loss(y_true, y_pred):
    """Sum of the class-weighted cross-entropy and the Dice loss."""
    ce_term = weighted_categorical_crossentropy(y_true, y_pred)
    dice_term = dice_loss(y_true, y_pred)
    return ce_term + dice_term

# ✅ Custom Dice Coefficient Metric for Each Class
class DiceCoefficient(tf.keras.metrics.Metric):
    """Soft Dice for a single class, averaged over every sample seen since the last reset.

    Per sample, Dice = (2 * sum(t * p) + 1e-6) / (sum(t) + sum(p) + 1e-6) over the
    spatial axes of channel ``class_idx``. ``sample_weight`` is ignored.

    Args:
        class_idx: index of the class channel on the last axis.
        name: metric name; defaults to ``"DiceClass<class_idx>"``.
    """

    def __init__(self, class_idx, name=None, **kwargs):
        if name is None:
            name = f"DiceClass{class_idx}"
        super(DiceCoefficient, self).__init__(name=name, **kwargs)
        self.class_idx = class_idx
        # Running sum of per-sample Dice and number of samples. Accumulating
        # (instead of assigning) makes the epoch value cover all batches, not
        # just the last one; the base-class reset_state() zeroes both.
        self.dice_sum = self.add_weight(name="dice_sum", initializer="zeros")
        self.sample_count = self.add_weight(name="sample_count", initializer="zeros")

    def update_state(self, y_true, y_pred, sample_weight=None):
        y_true_class = tf.cast(y_true[..., self.class_idx], tf.float32)
        y_pred_class = tf.cast(y_pred[..., self.class_idx], tf.float32)
        intersection = tf.reduce_sum(y_true_class * y_pred_class, axis=[1, 2])
        union = tf.reduce_sum(y_true_class, axis=[1, 2]) + tf.reduce_sum(y_pred_class, axis=[1, 2])
        dice = (2. * intersection + 1e-6) / (union + 1e-6)
        self.dice_sum.assign_add(tf.reduce_sum(dice))
        self.sample_count.assign_add(tf.cast(tf.size(dice), tf.float32))

    def result(self):
        return tf.math.divide_no_nan(self.dice_sum, self.sample_count)

    def get_config(self):
        # Needed so the metric can be re-created when a saved model is loaded.
        config = super(DiceCoefficient, self).get_config()
        config.update({"class_idx": self.class_idx})
        return config

# ✅ Function to Get Class-wise Metrics
def class_wise_metrics(num_classes=4):
    """Per-class Dice metrics plus mean IoU, for one-hot targets and softmax outputs.

    Args:
        num_classes: number of classes (one DiceCoefficient per class index).

    Returns:
        List of Keras metrics suitable for ``model.compile(metrics=...)``.
    """
    per_class_dice = [DiceCoefficient(i) for i in range(num_classes)]
    # MeanIoU expects integer class ids: given one-hot targets and softmax
    # probabilities it casts the probabilities to int (almost all 0) and scores
    # meaningless labels. OneHotMeanIoU takes the argmax of both tensors first.
    return per_class_dice + [tf.keras.metrics.OneHotMeanIoU(num_classes=num_classes)]

# ✅ Create Data Generator
def create_train_generator(X, y, batch_size=6):
    """Endlessly yield (image_batch, mask_batch) pairs with identical random augmentation.

    Two ImageDataGenerators share the augmentation settings and the same seed,
    which is the Keras pattern for transforming images and masks together.

    Args:
        X: image array (N, H, W, C).
        y: one-hot mask array (N, H, W, NUM_CLASSES).
        batch_size: samples per yielded batch.
    """
    data_gen_args = dict(
        rotation_range=20,
        width_shift_range=0.1,
        height_shift_range=0.1,
        zoom_range=0.2,
        horizontal_flip=True,
        fill_mode='nearest'
    )

    image_datagen = ImageDataGenerator(**data_gen_args)
    # interpolation_order=0 (nearest neighbour) keeps warped masks strictly one-hot;
    # the default order 1 (bilinear) blends neighbouring class vectors at boundaries.
    mask_datagen = ImageDataGenerator(**data_gen_args, interpolation_order=0)

    seed = 42
    image_generator = image_datagen.flow(X, batch_size=batch_size, seed=seed)
    mask_generator = mask_datagen.flow(y, batch_size=batch_size, seed=seed)

    while True:
        # image first, then mask, so both iterators advance in lock-step
        yield next(image_generator), next(mask_generator)

# ✅ Create the generator
# One constant drives both the generator and steps_per_epoch. Previously batches
# of 3 were drawn but steps_per_epoch was len(X_train) // 16, so every "epoch"
# covered only ~3/16 of the training set.
BATCH_SIZE = 3
train_generator = create_train_generator(X_train, y_train, batch_size=BATCH_SIZE)
steps_per_epoch = (len(X_train) + BATCH_SIZE - 1) // BATCH_SIZE  # ceil: one full pass per epoch

# ✅ Compile the Model
model.compile(
    optimizer=Adam(learning_rate=1e-4),
    loss=combined_loss,
    metrics=class_wise_metrics(NUM_CLASSES)
)

# ✅ Train the Model with the Data Generator
history = model.fit(
    train_generator,
    steps_per_epoch=steps_per_epoch,
    validation_data=(X_val, y_val),
    # Without this, validation on in-memory arrays uses Keras' default batch
    # of 32 (~10x the training batch), which can exhaust memory.
    validation_batch_size=BATCH_SIZE,
    epochs=100,
    callbacks=[
        EarlyStopping(
            monitor='val_loss',
            patience=7,
            restore_best_weights=True
        ),
        ReduceLROnPlateau(
            monitor='val_loss',
            factor=0.5,
            patience=3,
            min_lr=1e-6
        ),
        # NOTE(review): file name says "Mobilevit" but the model is the Swin-based DeepLabV3+.
        ModelCheckpoint(
            'best_DeepLabV3+_Mobilevit.keras',
            monitor='val_loss',
            save_best_only=True
        )
    ]
)

In [None]:
import matplotlib.pyplot as plt

# Training history recorded by model.fit
history_dict = history.history

# The model is compiled without an "accuracy" metric, so plot the per-class
# Dice metrics (logged as "DiceClass<k>" / "val_DiceClass<k>") instead.
dice_keys = sorted(k for k in history_dict if k.startswith('DiceClass'))

fig, (ax_loss, ax_dice) = plt.subplots(1, 2, figsize=(12, 6))

# Loss curves
ax_loss.plot(history_dict['loss'], label='Training Loss')
ax_loss.plot(history_dict['val_loss'], label='Validation Loss')
ax_loss.set(title='Loss Curves', xlabel='Epochs', ylabel='Loss')
ax_loss.legend()

# Per-class Dice curves (solid: training, dashed: validation)
for key in dice_keys:
    ax_dice.plot(history_dict[key], label=f'Training {key}')
    val_key = f'val_{key}'
    if val_key in history_dict:
        ax_dice.plot(history_dict[val_key], linestyle='--', label=f'Validation {key}')
ax_dice.set(title='Per-class Dice Curves', xlabel='Epochs', ylabel='Dice')
ax_dice.legend()

plt.tight_layout()
plt.show()

In [None]:
print(y_test.shape)  # Expected (num_samples, 224, 224, 4): one-hot masks, one channel per class
print(X_test.shape)  # Expected (num_samples, 224, 224, 3): RGB images scaled to [0, 1]

In [None]:
import numpy as np
import tensorflow as tf
from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score
from scipy.spatial.distance import directed_hausdorff
from tensorflow.keras.utils import to_categorical

# ✅ Mask colour (RGB) -> class index, for converting raw test masks
RGB_TO_CLASS = dict([
    ((255, 0, 0), 1),  # Brain
    ((0, 255, 0), 2),  # CSP
    ((0, 0, 255), 3),  # LV
    ((0, 0, 0), 0),    # Background
])

# ✅ RGB mask -> integer class-index mask
def rgb_to_class_mask(rgb_mask):
    """Map an (H, W, 3) RGB mask to an (H, W) int mask using RGB_TO_CLASS.

    Pixels whose colour matches no entry keep the value 0 (background).
    """
    labels = np.zeros(rgb_mask.shape[:2], dtype=int)
    for colour, label in RGB_TO_CLASS.items():
        labels[np.all(rgb_mask == np.array(colour), axis=-1)] = label
    return labels

# ✅ Dice Similarity Coefficient (NumPy, binary masks)
def dice_coefficient(y_true, y_pred):
    """Dice = (2 * |A ∩ B| + eps) / (|A| + |B| + eps); two empty masks give 1.0."""
    smooth = 1e-15
    numerator = 2. * np.sum(y_true * y_pred) + smooth
    denominator = np.sum(y_true) + np.sum(y_pred) + smooth
    return numerator / denominator

# ✅ Function to calculate IoU (Intersection over Union)
def iou(y_true, y_pred):
    """IoU = (|A ∩ B| + eps) / (|A ∪ B| + eps) for binary masks.

    The smoothing term is now added to the numerator too, so a class absent
    from both masks scores 1.0, consistent with ``dice_coefficient``. Before,
    such a correct "absent" prediction scored 0 IoU but 1 Dice.
    """
    smooth = 1e-15
    intersection = np.sum(y_true * y_pred)
    union = np.sum(y_true) + np.sum(y_pred) - intersection
    return (intersection + smooth) / (union + smooth)

# ✅ Symmetric Hausdorff distance between two binary masks
def hausdorff_distance(y_true, y_pred):
    """Max of the two directed Hausdorff distances between the foreground pixel sets.

    Returns ``float('inf')`` when either mask has no pixel equal to 1.
    """
    gt_points = np.argwhere(y_true == 1)
    pred_points = np.argwhere(y_pred == 1)

    if gt_points.shape[0] == 0 or pred_points.shape[0] == 0:
        return float('inf')

    return max(directed_hausdorff(gt_points, pred_points)[0],
               directed_hausdorff(pred_points, gt_points)[0])

# ✅ Function to calculate Average Surface Distance (ASD)
def average_surface_distance(y_true, y_pred):
    """Mean distance from each ground-truth foreground pixel to its nearest predicted one.

    NOTE: this is directed (true -> pred only) and uses every foreground pixel,
    not just boundary pixels, so it differs from the symmetric surface-based ASD.

    Returns ``float('inf')`` when either mask has no pixel equal to 1.
    """
    from scipy.spatial import cKDTree

    true_points = np.argwhere(y_true == 1)
    pred_points = np.argwhere(y_pred == 1)

    if len(true_points) == 0 or len(pred_points) == 0:
        return float('inf')  # Return inf if no points for either true or pred class

    # A KD-tree nearest-neighbour query gives the same Euclidean minima as the
    # former per-point Python loop, in O(N log M) instead of O(N * M).
    distances, _ = cKDTree(pred_points).query(true_points)
    return np.mean(distances)

# ✅ Function to evaluate the model on the test set class-wise
def evaluate_classwise_metrics(model, X_test, y_test, num_classes=4, batch_size=16,
                               include_distances=False):
    """Print per-class segmentation metrics and the Keras test metrics.

    For every test image and class index, the ground truth (argmax of
    ``y_test``) and the prediction (argmax of ``model.predict``) are turned into
    binary masks ``== class_idx``. Dice, IoU, precision, recall, F1 and pixel
    accuracy are computed on each pair and then averaged over images.

    Args:
        model: compiled Keras model whose output's last axis holds class scores.
        X_test: test inputs passed to ``model.predict`` and ``model.evaluate``.
        y_test: one-hot ground truth with the class axis last.
        num_classes: number of class indices (0 .. num_classes-1) to report.
        batch_size: batch size for prediction and evaluation.
        include_distances: also compute Hausdorff distance and average surface
            distance (slow). These return ``inf`` when the class is absent from
            either mask, so they are averaged over finite values only.
    """
    # Class-index maps, (N, H, W)
    y_pred = np.argmax(model.predict(X_test, batch_size=batch_size), axis=-1)
    y_test_class = np.argmax(y_test, axis=-1)

    metric_keys = ('dice', 'iou', 'precision', 'recall', 'f1', 'accuracy', 'hausdorff', 'asd')
    class_metrics = {c: {k: [] for k in metric_keys} for c in range(num_classes)}

    for true_mask, pred_mask in zip(y_test_class, y_pred):
        for class_idx in range(num_classes):
            true_class_mask = (true_mask == class_idx).astype(int)
            pred_class_mask = (pred_mask == class_idx).astype(int)
            true_flat = true_class_mask.ravel()
            pred_flat = pred_class_mask.ravel()

            per_metric = class_metrics[class_idx]
            per_metric['dice'].append(dice_coefficient(true_class_mask, pred_class_mask))
            per_metric['iou'].append(iou(true_class_mask, pred_class_mask))
            per_metric['precision'].append(precision_score(true_flat, pred_flat, zero_division=0))
            per_metric['recall'].append(recall_score(true_flat, pred_flat, zero_division=0))
            per_metric['f1'].append(f1_score(true_flat, pred_flat, zero_division=0))
            per_metric['accuracy'].append(accuracy_score(true_flat, pred_flat))
            if include_distances:
                per_metric['hausdorff'].append(hausdorff_distance(true_class_mask, pred_class_mask))
                per_metric['asd'].append(average_surface_distance(true_class_mask, pred_class_mask))

    print("Class-wise test metrics (mean over test images)")
    print('-' * 60)
    for class_idx in range(num_classes):
        per_metric = class_metrics[class_idx]
        print(f"Class {class_idx}:")
        print(f"  Dice Coefficient: {np.mean(per_metric['dice']) * 100:.2f}%")
        print(f"  IoU: {np.mean(per_metric['iou']) * 100:.2f}%")
        print(f"  Precision: {np.mean(per_metric['precision']) * 100:.2f}%")
        print(f"  Recall: {np.mean(per_metric['recall']) * 100:.2f}%")
        print(f"  F1 Score: {np.mean(per_metric['f1']) * 100:.2f}%")
        print(f"  Accuracy: {np.mean(per_metric['accuracy']) * 100:.2f}%")
        if include_distances:
            for key, label in (('hausdorff', 'Hausdorff Distance'), ('asd', 'Average Surface Distance')):
                finite = [v for v in per_metric[key] if np.isfinite(v)]
                shown = f"{np.mean(finite):.4f}" if finite else "n/a"
                print(f"  {label}: {shown} (over {len(finite)} images)")
        print('-' * 60)

    # return_dict=True pairs each value with its metric name. Zipping
    # model.metrics_names with the flat evaluate() list can mislabel or drop
    # values (e.g. Keras 3 reports compiled metrics as one "compile_metrics" name).
    results = model.evaluate(X_test, y_test, batch_size=batch_size, return_dict=True)
    print(f"Test Loss: {results['loss']:.4f}")
    for metric, value in results.items():
        if metric != 'loss':
            print(f"{metric}: {value:.4f}")

# ✅ Class-wise evaluation on the held-out test split
evaluate_classwise_metrics(model=model, X_test=X_test, y_test=y_test)