In [None]:
import numpy as np
import os
import gc
import cv2
import re
from tensorflow.keras.utils import to_categorical

# Model input geometry: 224x224 RGB images.
IMG_HEIGHT, IMG_WIDTH, CHANNELS = 224, 224, 3

# Mask palette: RGB colour -> class index.
CLASS_MAP = {
    (255, 0, 0): 1,  # Brain
    (0, 255, 0): 2,  # CSP
    (0, 0, 255): 3,  # LV
    (0, 0, 0): 0,    # Background
}
NUM_CLASSES = len(CLASS_MAP)  # 4: Background, Brain, CSP, LV

# Augmented dataset before it was split.
image_dir = r"D:\augmented_dataset\images"
mask_dir = r"D:\augmented_dataset\masks"

# Train / val / test split folders.
_SPLIT_ROOT = r"D:\Updated"
train_image_dir = _SPLIT_ROOT + r"\train\images"
train_mask_dir = _SPLIT_ROOT + r"\train\masks"
val_image_dir = _SPLIT_ROOT + r"\val\images"
val_mask_dir = _SPLIT_ROOT + r"\val\masks"
test_image_dir = _SPLIT_ROOT + r"\test\images"
test_mask_dir = _SPLIT_ROOT + r"\test\masks"

# Natural ("human") ordering so that img2 sorts before img10.
def natural_sort_key(s):
    """Return a sort key where digit runs compare as ints and text case-insensitively.

    Example: ``"Img10.png"`` -> ``['img', 10, '.png']``.
    """
    def _convert(chunk):
        return int(chunk) if chunk.isdigit() else chunk.lower()

    return list(map(_convert, re.split(r'(\d+)', s)))

# Decode an RGB colour-coded mask into class indices.
def rgb_to_class(mask_array, class_map=None):
    """Convert an RGB mask to a single-channel class-index mask.

    Args:
        mask_array: array of shape (H, W, 3) whose channel order matches the
            keys of ``class_map`` (RGB here).
        class_map: optional mapping ``(r, g, b) -> class index``. Defaults to
            the module-level ``CLASS_MAP`` so existing callers are unchanged;
            pass another palette to reuse this for a different label scheme.

    Returns:
        uint8 array of shape (H, W). Pixels whose colour is not an exact key of
        ``class_map`` stay 0.
    """
    if class_map is None:
        class_map = CLASS_MAP
    height, width, _ = mask_array.shape
    class_mask = np.zeros((height, width), dtype=np.uint8)

    for rgb, class_idx in class_map.items():
        # Exact match on every channel; noisy/blended colours fall through to 0.
        matches = np.all(mask_array == rgb, axis=-1)
        class_mask[matches] = class_idx

    return class_mask

# Load, resize and one-hot encode an image/mask folder pair (224x224).
def preprocess_filtered_dataset(image_dir, mask_dir):
    """Preprocess images & masks: normalize, resize, and convert masks to one-hot encoding.

    Images and masks are paired *by position* after natural-sorting both
    directory listings. The two folders must therefore hold the same number
    of files in matching order; the file names themselves may differ.

    Args:
        image_dir: folder of input images; read with OpenCV and converted BGR -> RGB.
        mask_dir: folder of colour-coded masks, decoded with ``rgb_to_class``.

    Returns:
        Tuple ``(X, y)``.
        - ``X``: float32, shape ``(N, IMG_HEIGHT, IMG_WIDTH, CHANNELS)``,
          pixel values divided by 255.
        - ``y``: float32 one-hot, shape ``(N, IMG_HEIGHT, IMG_WIDTH, NUM_CLASSES)``.

    Raises:
        IOError: if OpenCV cannot decode one of the image or mask files.
    """
    image_filenames = sorted(os.listdir(image_dir), key=natural_sort_key)
    mask_filenames = sorted(os.listdir(mask_dir), key=natural_sort_key)

    # zip() below silently truncates to the shorter listing. That drops
    # samples and usually means the positional pairing is off, so report it.
    if len(image_filenames) != len(mask_filenames):
        print(
            f"WARNING: {len(image_filenames)} files in {image_dir!r} but "
            f"{len(mask_filenames)} in {mask_dir!r}; only the first "
            f"{min(len(image_filenames), len(mask_filenames))} positional pairs are used."
        )

    valid_image_paths = []
    valid_mask_paths = []

    for img_file, mask_file in zip(image_filenames, mask_filenames):
        img_path = os.path.join(image_dir, img_file)
        mask_path = os.path.join(mask_dir, mask_file)

        # listdir() also returns sub-directories; only regular files can be decoded.
        if os.path.isfile(img_path) and os.path.isfile(mask_path):
            valid_image_paths.append(img_path)
            valid_mask_paths.append(mask_path)
        else:
            print(f"‚ö†Ô∏è Skipping {img_file}: Missing image or mask")

    num_images = len(valid_image_paths)

    X = np.zeros((num_images, IMG_HEIGHT, IMG_WIDTH, CHANNELS), dtype=np.float32)
    y = np.zeros((num_images, IMG_HEIGHT, IMG_WIDTH, NUM_CLASSES), dtype=np.float32)

    print(f"üöÄ Processing {num_images} filtered images and masks...")

    for idx, (img_path, mask_path) in enumerate(zip(valid_image_paths, valid_mask_paths)):
        if idx % 100 == 0:
            print(f"‚úÖ Processed {idx}/{num_images} images")

        img = cv2.imread(img_path)  # BGR; None if the file cannot be decoded
        if img is None:
            raise IOError(f"Could not read image file: {img_path}")
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = cv2.resize(img, (IMG_WIDTH, IMG_HEIGHT))  # cv2 takes (width, height)
        X[idx] = img.astype(np.float32) / 255.0

        mask = cv2.imread(mask_path)
        if mask is None:
            raise IOError(f"Could not read mask file: {mask_path}")
        mask = cv2.cvtColor(mask, cv2.COLOR_BGR2RGB)
        # Nearest-neighbour keeps the exact palette colours. Blending would
        # create colours that rgb_to_class maps to background.
        mask = cv2.resize(mask, (IMG_WIDTH, IMG_HEIGHT), interpolation=cv2.INTER_NEAREST)
        y[idx] = to_categorical(rgb_to_class(mask), num_classes=NUM_CLASSES)

    # No per-image gc.collect(): every temporary is rebound on the next
    # iteration, so reference counting already frees it. A full collection
    # per image only slows loading down.
    return X, y

from sklearn.model_selection import train_test_split

# Only the held-out test split is loaded here; train/val are loaded in their own cell.
X_test, y_test = preprocess_filtered_dataset(test_image_dir, test_mask_dir)
assert len(X_test) > 0, f"No test samples found in {test_image_dir!r}"

print("\n‚úÖ Dataset Splits:")
print(f"  - Test set: {X_test.shape}, {y_test.shape}")

In [None]:
import numpy as np
import os
import gc
import cv2
import re
from tensorflow.keras.utils import to_categorical

# Model input geometry: 224x224 RGB images.
IMG_HEIGHT, IMG_WIDTH, CHANNELS = 224, 224, 3

# Mask palette: RGB colour -> class index.
CLASS_MAP = {
    (255, 0, 0): 1,  # Brain
    (0, 255, 0): 2,  # CSP
    (0, 0, 255): 3,  # LV
    (0, 0, 0): 0,    # Background
}
NUM_CLASSES = len(CLASS_MAP)  # 4: Background, Brain, CSP, LV

# Augmented dataset before it was split.
image_dir = r"D:\augmented_dataset\images"
mask_dir = r"D:\augmented_dataset\masks"

# Train / val / test split folders.
_SPLIT_ROOT = r"D:\Updated"
train_image_dir = _SPLIT_ROOT + r"\train\images"
train_mask_dir = _SPLIT_ROOT + r"\train\masks"
val_image_dir = _SPLIT_ROOT + r"\val\images"
val_mask_dir = _SPLIT_ROOT + r"\val\masks"
test_image_dir = _SPLIT_ROOT + r"\test\images"
test_mask_dir = _SPLIT_ROOT + r"\test\masks"

# Natural ("human") ordering so that img2 sorts before img10.
def natural_sort_key(s):
    """Return a sort key where digit runs compare as ints and text case-insensitively.

    Example: ``"Img10.png"`` -> ``['img', 10, '.png']``.
    """
    def _convert(chunk):
        return int(chunk) if chunk.isdigit() else chunk.lower()

    return list(map(_convert, re.split(r'(\d+)', s)))

# Decode an RGB colour-coded mask into class indices.
def rgb_to_class(mask_array, class_map=None):
    """Convert an RGB mask to a single-channel class-index mask.

    Args:
        mask_array: array of shape (H, W, 3) whose channel order matches the
            keys of ``class_map`` (RGB here).
        class_map: optional mapping ``(r, g, b) -> class index``. Defaults to
            the module-level ``CLASS_MAP`` so existing callers are unchanged;
            pass another palette to reuse this for a different label scheme.

    Returns:
        uint8 array of shape (H, W). Pixels whose colour is not an exact key of
        ``class_map`` stay 0.
    """
    if class_map is None:
        class_map = CLASS_MAP
    height, width, _ = mask_array.shape
    class_mask = np.zeros((height, width), dtype=np.uint8)

    for rgb, class_idx in class_map.items():
        # Exact match on every channel; noisy/blended colours fall through to 0.
        matches = np.all(mask_array == rgb, axis=-1)
        class_mask[matches] = class_idx

    return class_mask

# Load, resize and one-hot encode an image/mask folder pair (224x224).
def preprocess_filtered_dataset(image_dir, mask_dir):
    """Preprocess images & masks: normalize, resize, and convert masks to one-hot encoding.

    Images and masks are paired *by position* after natural-sorting both
    directory listings. The two folders must therefore hold the same number
    of files in matching order; the file names themselves may differ.

    Args:
        image_dir: folder of input images; read with OpenCV and converted BGR -> RGB.
        mask_dir: folder of colour-coded masks, decoded with ``rgb_to_class``.

    Returns:
        Tuple ``(X, y)``.
        - ``X``: float32, shape ``(N, IMG_HEIGHT, IMG_WIDTH, CHANNELS)``,
          pixel values divided by 255.
        - ``y``: float32 one-hot, shape ``(N, IMG_HEIGHT, IMG_WIDTH, NUM_CLASSES)``.

    Raises:
        IOError: if OpenCV cannot decode one of the image or mask files.
    """
    image_filenames = sorted(os.listdir(image_dir), key=natural_sort_key)
    mask_filenames = sorted(os.listdir(mask_dir), key=natural_sort_key)

    # zip() below silently truncates to the shorter listing. That drops
    # samples and usually means the positional pairing is off, so report it.
    if len(image_filenames) != len(mask_filenames):
        print(
            f"WARNING: {len(image_filenames)} files in {image_dir!r} but "
            f"{len(mask_filenames)} in {mask_dir!r}; only the first "
            f"{min(len(image_filenames), len(mask_filenames))} positional pairs are used."
        )

    valid_image_paths = []
    valid_mask_paths = []

    for img_file, mask_file in zip(image_filenames, mask_filenames):
        img_path = os.path.join(image_dir, img_file)
        mask_path = os.path.join(mask_dir, mask_file)

        # listdir() also returns sub-directories; only regular files can be decoded.
        if os.path.isfile(img_path) and os.path.isfile(mask_path):
            valid_image_paths.append(img_path)
            valid_mask_paths.append(mask_path)
        else:
            print(f"‚ö†Ô∏è Skipping {img_file}: Missing image or mask")

    num_images = len(valid_image_paths)

    X = np.zeros((num_images, IMG_HEIGHT, IMG_WIDTH, CHANNELS), dtype=np.float32)
    y = np.zeros((num_images, IMG_HEIGHT, IMG_WIDTH, NUM_CLASSES), dtype=np.float32)

    print(f"üöÄ Processing {num_images} filtered images and masks...")

    for idx, (img_path, mask_path) in enumerate(zip(valid_image_paths, valid_mask_paths)):
        if idx % 100 == 0:
            print(f"‚úÖ Processed {idx}/{num_images} images")

        img = cv2.imread(img_path)  # BGR; None if the file cannot be decoded
        if img is None:
            raise IOError(f"Could not read image file: {img_path}")
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = cv2.resize(img, (IMG_WIDTH, IMG_HEIGHT))  # cv2 takes (width, height)
        X[idx] = img.astype(np.float32) / 255.0

        mask = cv2.imread(mask_path)
        if mask is None:
            raise IOError(f"Could not read mask file: {mask_path}")
        mask = cv2.cvtColor(mask, cv2.COLOR_BGR2RGB)
        # Nearest-neighbour keeps the exact palette colours. Blending would
        # create colours that rgb_to_class maps to background.
        mask = cv2.resize(mask, (IMG_WIDTH, IMG_HEIGHT), interpolation=cv2.INTER_NEAREST)
        y[idx] = to_categorical(rgb_to_class(mask), num_classes=NUM_CLASSES)

    # No per-image gc.collect(): every temporary is rebound on the next
    # iteration, so reference counting already frees it. A full collection
    # per image only slows loading down.
    return X, y

from sklearn.model_selection import train_test_split

# Train/val splits; the test split is loaded in its own cell.
X_train, y_train = preprocess_filtered_dataset(train_image_dir, train_mask_dir)
X_val, y_val = preprocess_filtered_dataset(val_image_dir, val_mask_dir)

print("\n‚úÖ Dataset Splits:")
print(f"  - Training set: {X_train.shape}, {y_train.shape}")
print(f"  - Validation set: {X_val.shape}, {y_val.shape}")

In [None]:
# Number of segmentation classes: Background, Brain, CSP, LV (indices 0-3, see CLASS_MAP).
NUM_CLASSES = 4

# Dice coefficient pooled over pixels AND class channels (not a per-class mean).
def dice_coefficient(y_true, y_pred):
    """Soft Dice score per sample, averaged over the batch.

    The sums run over axes 1, 2 and 3 together. This gives one micro-averaged
    Dice per sample that includes the background channel, not a mean of
    per-class Dice scores. It assumes a channels-last (batch, H, W, classes)
    layout (TODO confirm).
    """
    eps = 1e-15
    truth = tf.cast(y_true, tf.float32)
    pred = tf.cast(y_pred, tf.float32)
    pooled_axes = [1, 2, 3]
    overlap = tf.reduce_sum(truth * pred, axis=pooled_axes)
    total = tf.reduce_sum(truth, axis=pooled_axes) + tf.reduce_sum(pred, axis=pooled_axes)
    per_sample = (2. * overlap + eps) / (total + eps)
    return tf.reduce_mean(per_sample)

# Weighted categorical cross-entropy: each pixel is weighted by its true class.
def weighted_categorical_crossentropy(y_true, y_pred):
    """Pixel-wise categorical cross-entropy weighted by each pixel's ground-truth class.

    The weight of a pixel is ``sum_c(y_true[..., c] * w[c])``, i.e. the weight
    of its true class. Multiplying every pixel by ``sum(w)`` instead would be a
    constant scale (~113) that never re-balances the classes.

    Args:
        y_true: one-hot targets, class axis last.
        y_pred: predicted class probabilities (not logits), class axis last.

    Returns:
        Scalar mean weighted loss over all pixels in the batch.
    """
    # One weight per class index 0..3; presumably inverse class frequencies (confirm).
    class_weights = tf.constant([0.3776, 0.7605, 65.8554, 46.2381], dtype=tf.float32)
    y_true = tf.cast(y_true, tf.float32)
    y_pred = tf.clip_by_value(tf.cast(y_pred, tf.float32), tf.keras.backend.epsilon(), 1.0)
    pixel_ce = -tf.reduce_sum(y_true * tf.math.log(y_pred), axis=-1)
    # class_weights broadcasts along the last (class) axis.
    pixel_weights = tf.reduce_sum(y_true * class_weights, axis=-1)
    return tf.reduce_mean(pixel_ce * pixel_weights)

# Dice loss: 1 - mean soft Dice over every (sample, class) pair.
def dice_loss(y_true, y_pred):
    """Soft Dice loss averaged over samples and class channels.

    The sums run over axes 1 and 2 only, so each class channel gets its own
    Dice per sample. This assumes (batch, H, W, classes) (TODO confirm).
    """
    eps = 1e-6
    truth = tf.cast(y_true, y_pred.dtype)
    spatial = [1, 2]
    overlap = tf.reduce_sum(truth * y_pred, axis=spatial)
    denom = tf.reduce_sum(truth, axis=spatial) + tf.reduce_sum(y_pred, axis=spatial)
    per_class_dice = (2. * overlap + eps) / (denom + eps)
    return 1 - tf.reduce_mean(per_class_dice)

# Lovasz-Softmax loss (Berman et al.), computed over the whole flattened batch.
def lovasz_softmax_loss(y_true, y_pred, ignore_background=False):
    """Lovasz-Softmax surrogate for the per-class Jaccard (IoU) loss.

    For each class, the pixels of the whole batch are flattened together.
    Classes absent from the batch are still included.

    Args:
        y_true: one-hot targets, class axis last.
        y_pred: class probabilities (already softmax-ed), class axis last.
        ignore_background: if True, skip class 0 and average over classes 1..C-1.

    Returns:
        Scalar mean of the per-class Lovasz losses.
    """
    y_true = tf.cast(y_true, tf.float32)
    y_pred = tf.cast(y_pred, tf.float32)

    num_classes = tf.shape(y_true)[-1]
    start_class = tf.constant(1 if ignore_background else 0)

    def compute_class_loss(c):
        """Lovasz extension loss for class channel ``c``."""
        y_true_flat = tf.reshape(y_true[..., c], [-1])
        y_pred_flat = tf.reshape(y_pred[..., c], [-1])

        # Per-pixel errors sorted in decreasing order.
        errors = tf.abs(y_true_flat - y_pred_flat)
        errors_sorted, perm = tf.nn.top_k(errors, k=tf.shape(errors)[0], sorted=True)
        y_true_sorted = tf.gather(y_true_flat, perm)

        # Gradient of the Lovasz extension of the Jaccard loss w.r.t. sorted errors.
        gts = tf.reduce_sum(y_true_sorted)
        intersection = gts - tf.cumsum(y_true_sorted)
        union = gts + tf.cumsum(1. - y_true_sorted)
        jaccard = 1. - intersection / union
        grad = tf.concat([[jaccard[0]], jaccard[1:] - jaccard[:-1]], 0)

        return tf.tensordot(errors_sorted, grad, axes=1)

    # Size the array by the number of classes actually computed and write at
    # c - start_class. Otherwise, with ignore_background=True, an unwritten
    # slot for class 0 would be stacked (an error, or an extra zero counted in
    # the mean).
    losses = tf.TensorArray(dtype=tf.float32, size=num_classes - start_class)

    def loop_cond(c, losses):
        return tf.less(c, num_classes)

    def loop_body(c, losses):
        losses = losses.write(c - start_class, compute_class_loss(c))
        return c + 1, losses

    _, losses = tf.while_loop(loop_cond, loop_body, [start_class, losses])
    return tf.reduce_mean(losses.stack())

# Combined loss = Lovasz-Softmax + per-class soft Dice.
def combined_loss(y_true, y_pred, from_logits=False):
    """Sum of a per-class soft Dice loss and the Lovasz-Softmax loss.

    The models in this notebook end in a ``softmax`` Conv2D, so ``y_pred`` is
    already a probability map. Applying ``tf.nn.softmax`` to it again would
    compress the probabilities into a narrow band and weaken the Lovasz term.
    Softmax is therefore applied only when ``from_logits=True``.

    Args:
        y_true: one-hot targets, class axis last.
        y_pred: class probabilities (or logits if ``from_logits``), class axis last.
        from_logits: set True only for a model whose output has no softmax.

    Returns:
        Scalar loss.
    """
    smooth = 1e-6
    y_true = tf.cast(y_true, y_pred.dtype)
    probs = tf.nn.softmax(y_pred) if from_logits else y_pred

    # Per-sample, per-class Dice over the spatial axes.
    intersection = tf.reduce_sum(y_true * probs, axis=[1, 2])
    union = tf.reduce_sum(y_true, axis=[1, 2]) + tf.reduce_sum(probs, axis=[1, 2])
    dice_loss_val = tf.reduce_mean(1 - (2. * intersection + smooth) / (union + smooth))

    lovasz_loss_val = lovasz_softmax_loss(y_true, probs, ignore_background=False)
    return lovasz_loss_val + dice_loss_val

In [None]:
import numpy as np
import os
import gc
import cv2
import re
from tensorflow.keras.utils import to_categorical

# Model input geometry: 224x224 RGB images.
IMG_HEIGHT, IMG_WIDTH, CHANNELS = 224, 224, 3

# Mask palette: RGB colour -> class index.
CLASS_MAP = {
    (255, 0, 0): 1,  # Brain
    (0, 255, 0): 2,  # CSP
    (0, 0, 255): 3,  # LV
    (0, 0, 0): 0,    # Background
}
NUM_CLASSES = len(CLASS_MAP)  # 4: Background, Brain, CSP, LV

# Perturbed copies of the test images (robustness evaluation); all of them
# are paired with the same test mask folder.
_TEST_ROOT = r"D:\Updated\test"
test_image_blur_40_dir = _TEST_ROOT + r"\images - (Blur 40%)"
test_image_blur_20_dir = _TEST_ROOT + r"\images - (Blur 20%)"
test_image_bright_dir = _TEST_ROOT + r"\images - (Brightess enhanced)"
test_image_dark_dir = _TEST_ROOT + r"\images - (Brightess reduction)"
test_mask_dir = _TEST_ROOT + r"\masks"

# Natural ("human") ordering so that img2 sorts before img10.
def natural_sort_key(s):
    """Return a sort key where digit runs compare as ints and text case-insensitively.

    Example: ``"Img10.png"`` -> ``['img', 10, '.png']``.
    """
    def _convert(chunk):
        return int(chunk) if chunk.isdigit() else chunk.lower()

    return list(map(_convert, re.split(r'(\d+)', s)))

# Decode an RGB colour-coded mask into class indices.
def rgb_to_class(mask_array, class_map=None):
    """Convert an RGB mask to a single-channel class-index mask.

    Args:
        mask_array: array of shape (H, W, 3) whose channel order matches the
            keys of ``class_map`` (RGB here).
        class_map: optional mapping ``(r, g, b) -> class index``. Defaults to
            the module-level ``CLASS_MAP`` so existing callers are unchanged;
            pass another palette to reuse this for a different label scheme.

    Returns:
        uint8 array of shape (H, W). Pixels whose colour is not an exact key of
        ``class_map`` stay 0.
    """
    if class_map is None:
        class_map = CLASS_MAP
    height, width, _ = mask_array.shape
    class_mask = np.zeros((height, width), dtype=np.uint8)

    for rgb, class_idx in class_map.items():
        # Exact match on every channel; noisy/blended colours fall through to 0.
        matches = np.all(mask_array == rgb, axis=-1)
        class_mask[matches] = class_idx

    return class_mask

# Load, resize and one-hot encode an image/mask folder pair (224x224).
def preprocess_filtered_dataset(image_dir, mask_dir):
    """Preprocess images & masks: normalize, resize, and convert masks to one-hot encoding.

    Images and masks are paired *by position* after natural-sorting both
    directory listings. The two folders must therefore hold the same number
    of files in matching order; the file names themselves may differ.

    Args:
        image_dir: folder of input images; read with OpenCV and converted BGR -> RGB.
        mask_dir: folder of colour-coded masks, decoded with ``rgb_to_class``.

    Returns:
        Tuple ``(X, y)``.
        - ``X``: float32, shape ``(N, IMG_HEIGHT, IMG_WIDTH, CHANNELS)``,
          pixel values divided by 255.
        - ``y``: float32 one-hot, shape ``(N, IMG_HEIGHT, IMG_WIDTH, NUM_CLASSES)``.

    Raises:
        IOError: if OpenCV cannot decode one of the image or mask files.
    """
    image_filenames = sorted(os.listdir(image_dir), key=natural_sort_key)
    mask_filenames = sorted(os.listdir(mask_dir), key=natural_sort_key)

    # zip() below silently truncates to the shorter listing. That drops
    # samples and usually means the positional pairing is off, so report it.
    if len(image_filenames) != len(mask_filenames):
        print(
            f"WARNING: {len(image_filenames)} files in {image_dir!r} but "
            f"{len(mask_filenames)} in {mask_dir!r}; only the first "
            f"{min(len(image_filenames), len(mask_filenames))} positional pairs are used."
        )

    valid_image_paths = []
    valid_mask_paths = []

    for img_file, mask_file in zip(image_filenames, mask_filenames):
        img_path = os.path.join(image_dir, img_file)
        mask_path = os.path.join(mask_dir, mask_file)

        # listdir() also returns sub-directories; only regular files can be decoded.
        if os.path.isfile(img_path) and os.path.isfile(mask_path):
            valid_image_paths.append(img_path)
            valid_mask_paths.append(mask_path)
        else:
            print(f"‚ö†Ô∏è Skipping {img_file}: Missing image or mask")

    num_images = len(valid_image_paths)

    X = np.zeros((num_images, IMG_HEIGHT, IMG_WIDTH, CHANNELS), dtype=np.float32)
    y = np.zeros((num_images, IMG_HEIGHT, IMG_WIDTH, NUM_CLASSES), dtype=np.float32)

    print(f"üöÄ Processing {num_images} filtered images and masks...")

    for idx, (img_path, mask_path) in enumerate(zip(valid_image_paths, valid_mask_paths)):
        if idx % 100 == 0:
            print(f"‚úÖ Processed {idx}/{num_images} images")

        img = cv2.imread(img_path)  # BGR; None if the file cannot be decoded
        if img is None:
            raise IOError(f"Could not read image file: {img_path}")
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = cv2.resize(img, (IMG_WIDTH, IMG_HEIGHT))  # cv2 takes (width, height)
        X[idx] = img.astype(np.float32) / 255.0

        mask = cv2.imread(mask_path)
        if mask is None:
            raise IOError(f"Could not read mask file: {mask_path}")
        mask = cv2.cvtColor(mask, cv2.COLOR_BGR2RGB)
        # Nearest-neighbour keeps the exact palette colours. Blending would
        # create colours that rgb_to_class maps to background.
        mask = cv2.resize(mask, (IMG_WIDTH, IMG_HEIGHT), interpolation=cv2.INTER_NEAREST)
        y[idx] = to_categorical(rgb_to_class(mask), num_classes=NUM_CLASSES)

    # No per-image gc.collect(): every temporary is rebound on the next
    # iteration, so reference counting already frees it. A full collection
    # per image only slows loading down.
    return X, y

from sklearn.model_selection import train_test_split

# All four perturbed test sets share the same ground-truth masks. Keep the
# masks from the first load as y_test and check that every other load
# agrees. A missing or extra image in one variant folder (which shifts the
# positional image/mask pairing) then fails loudly instead of silently.
X_test_blur_40, y_test = preprocess_filtered_dataset(test_image_blur_40_dir, test_mask_dir)


def _load_test_variant(variant_dir):
    """Load one perturbed test-image folder and verify its masks equal ``y_test``."""
    X_variant, y_variant = preprocess_filtered_dataset(variant_dir, test_mask_dir)
    if not np.array_equal(y_variant, y_test):
        raise ValueError(
            f"Masks loaded alongside {variant_dir!r} differ from y_test; "
            "check that the folder holds exactly one image per mask."
        )
    return X_variant


X_test_blur_20 = _load_test_variant(test_image_blur_20_dir)
X_test_bright = _load_test_variant(test_image_bright_dir)
X_test_dark = _load_test_variant(test_image_dark_dir)

In [None]:
import cv2
import os
from tqdm import tqdm

# Perturbation used to produce each robustness-test folder under base_dir.
variant_transforms = {
    "images - (Blur 20%)": lambda img: cv2.GaussianBlur(img, (5, 5), 0),
    "images - (Blur 40%)": lambda img: cv2.GaussianBlur(img, (11, 11), 0),
    "images - (Brightess reduction)": lambda img: cv2.convertScaleAbs(img, alpha=0.7, beta=0),
    "images - (Brightess enhanced)": lambda img: cv2.convertScaleAbs(img, alpha=1.3, beta=0),
}

# Test split root. The unperturbed test images are expected in
# "<base_dir>\images", under the same file names as in each variant folder.
base_dir = r"D:\Updated\test"
clean_dir = os.path.join(base_dir, "images")

# Regenerate every variant image from its clean source rather than
# transforming the variant file in place. In-place updates are not
# idempotent: each re-run would blur/rescale an already perturbed image
# again and re-encode it lossily.
n_failed = 0
for folder_name, transform_fn in variant_transforms.items():
    folder_path = os.path.join(base_dir, folder_name)
    image_filenames = sorted(os.listdir(folder_path))
    total = len(image_filenames)

    print(f"\nüîß Updating: {folder_name} ({total} images)")
    for filename in tqdm(image_filenames):
        img = cv2.imread(os.path.join(clean_dir, filename))
        if img is None:
            # Leave the variant file untouched if no clean source can be read.
            print(f"WARNING: skipping {filename}: clean source missing or unreadable in {clean_dir}")
            n_failed += 1
            continue

        success = cv2.imwrite(os.path.join(folder_path, filename), transform_fn(img))
        if not success:
            n_failed += 1
            print(f"‚ùå Failed to overwrite: {filename}")

if n_failed:
    print(f"\nWARNING: {n_failed} image(s) were skipped or could not be written; see messages above.")
else:
    print("\n‚úÖ All folders updated successfully.")


In [None]:
import tensorflow as tf
from tensorflow.keras import layers, Input
from tensorflow.keras.models import Model

# Model input geometry and number of segmentation classes.
IMG_HEIGHT, IMG_WIDTH = 224, 224
CHANNELS = 3      # RGB
NUM_CLASSES = 4   # Brain, CSP, LV, Background

def conv_block(inputs, filters, kernel_size=(3, 3), padding='same', strides=1):
    """
    Two stacked Conv2D -> BatchNormalization -> ReLU units.

    Only the first convolution uses ``strides``; the second always uses stride 1.
    """
    x = inputs
    for conv_strides in (strides, 1):
        x = layers.Conv2D(filters, kernel_size, strides=conv_strides, padding=padding)(x)
        x = layers.BatchNormalization()(x)
        x = layers.Activation('relu')(x)
    return x

def build_segnet(input_shape, num_classes):
    """
    Build the encoder-decoder segmentation network used as "SegNet" here.

    Despite the name, the decoder concatenates encoder feature maps
    (U-Net-style skips) rather than reusing pooling indices as in the
    original SegNet. There are four encoder stages (64-512 filters, each
    followed by 2x2 max-pooling), a 1024-filter bridge, four mirrored decoder
    stages (2x upsample + skip concatenation), and a 1x1 softmax head.
    Because of the four pool/upsample pairs, the input height and width
    should be divisible by 16 so that the concatenated shapes match.
    """
    inputs = Input(input_shape)

    # Encoder: keep each stage's output for the matching decoder skip.
    skips = []
    x = inputs
    for filters in (64, 128, 256, 512):
        x = conv_block(x, filters)
        skips.append(x)
        x = layers.MaxPooling2D(pool_size=(2, 2), padding='same')(x)

    # Bridge
    x = conv_block(x, 1024)

    # Decoder: deepest skip first.
    for filters, skip in zip((512, 256, 128, 64), reversed(skips)):
        x = layers.UpSampling2D(size=(2, 2))(x)
        x = layers.concatenate([x, skip], axis=-1)
        x = conv_block(x, filters)

    outputs = layers.Conv2D(num_classes, (1, 1), activation='softmax')(x)
    return Model(inputs=[inputs], outputs=[outputs])

# Instantiate the encoder-decoder above for 224x224 RGB inputs.
model = build_segnet((IMG_HEIGHT, IMG_WIDTH, CHANNELS), NUM_CLASSES)

In [None]:
import tensorflow as tf
from tensorflow.keras import backend as K
from tensorflow.keras.models import load_model

# Number of segmentation classes (Background, Brain, CSP, LV). Also sets the
# MeanIoU class count and the DiceClass0..DiceClass{NUM_CLASSES-1} metrics registered below.
NUM_CLASSES = 4

# Dice coefficient pooled over pixels AND class channels (not a per-class mean).
def dice_coefficient(y_true, y_pred):
    """Soft Dice score per sample, averaged over the batch.

    The sums run over axes 1, 2 and 3 together. This gives one micro-averaged
    Dice per sample that includes the background channel, not a mean of
    per-class Dice scores. It assumes a channels-last (batch, H, W, classes)
    layout (TODO confirm).
    """
    eps = 1e-15
    truth = tf.cast(y_true, tf.float32)
    pred = tf.cast(y_pred, tf.float32)
    pooled_axes = [1, 2, 3]
    overlap = tf.reduce_sum(truth * pred, axis=pooled_axes)
    total = tf.reduce_sum(truth, axis=pooled_axes) + tf.reduce_sum(pred, axis=pooled_axes)
    per_sample = (2. * overlap + eps) / (total + eps)
    return tf.reduce_mean(per_sample)

# Weighted categorical cross-entropy: each pixel is weighted by its true class.
def weighted_categorical_crossentropy(y_true, y_pred):
    """Pixel-wise categorical cross-entropy weighted by each pixel's ground-truth class.

    The weight of a pixel is ``sum_c(y_true[..., c] * w[c])``, i.e. the weight
    of its true class. Multiplying every pixel by ``sum(w)`` instead would be a
    constant scale (~113) that never re-balances the classes.

    Args:
        y_true: one-hot targets, class axis last.
        y_pred: predicted class probabilities (not logits), class axis last.

    Returns:
        Scalar mean weighted loss over all pixels in the batch.
    """
    # One weight per class index 0..3; presumably inverse class frequencies (confirm).
    class_weights = tf.constant([0.3776, 0.7605, 65.8554, 46.2381], dtype=tf.float32)
    y_true = tf.cast(y_true, tf.float32)
    y_pred = tf.clip_by_value(tf.cast(y_pred, tf.float32), tf.keras.backend.epsilon(), 1.0)
    pixel_ce = -tf.reduce_sum(y_true * tf.math.log(y_pred), axis=-1)
    # class_weights broadcasts along the last (class) axis.
    pixel_weights = tf.reduce_sum(y_true * class_weights, axis=-1)
    return tf.reduce_mean(pixel_ce * pixel_weights)

# Dice loss: 1 - mean soft Dice over every (sample, class) pair.
def dice_loss(y_true, y_pred):
    """Soft Dice loss averaged over samples and class channels.

    The sums run over axes 1 and 2 only, so each class channel gets its own
    Dice per sample. This assumes (batch, H, W, classes) (TODO confirm).
    """
    eps = 1e-6
    truth = tf.cast(y_true, y_pred.dtype)
    spatial = [1, 2]
    overlap = tf.reduce_sum(truth * y_pred, axis=spatial)
    denom = tf.reduce_sum(truth, axis=spatial) + tf.reduce_sum(y_pred, axis=spatial)
    per_class_dice = (2. * overlap + eps) / (denom + eps)
    return 1 - tf.reduce_mean(per_class_dice)

# Lovasz-Softmax loss (Berman et al.), computed over the whole flattened batch.
def lovasz_softmax_loss(y_true, y_pred, ignore_background=False):
    """Lovasz-Softmax surrogate for the per-class Jaccard (IoU) loss.

    For each class, the pixels of the whole batch are flattened together.
    Classes absent from the batch are still included.

    Args:
        y_true: one-hot targets, class axis last.
        y_pred: class probabilities (already softmax-ed), class axis last.
        ignore_background: if True, skip class 0 and average over classes 1..C-1.

    Returns:
        Scalar mean of the per-class Lovasz losses.
    """
    y_true = tf.cast(y_true, tf.float32)
    y_pred = tf.cast(y_pred, tf.float32)

    num_classes = tf.shape(y_true)[-1]
    start_class = tf.constant(1 if ignore_background else 0)

    def compute_class_loss(c):
        """Lovasz extension loss for class channel ``c``."""
        y_true_flat = tf.reshape(y_true[..., c], [-1])
        y_pred_flat = tf.reshape(y_pred[..., c], [-1])

        # Per-pixel errors sorted in decreasing order.
        errors = tf.abs(y_true_flat - y_pred_flat)
        errors_sorted, perm = tf.nn.top_k(errors, k=tf.shape(errors)[0], sorted=True)
        y_true_sorted = tf.gather(y_true_flat, perm)

        # Gradient of the Lovasz extension of the Jaccard loss w.r.t. sorted errors.
        gts = tf.reduce_sum(y_true_sorted)
        intersection = gts - tf.cumsum(y_true_sorted)
        union = gts + tf.cumsum(1. - y_true_sorted)
        jaccard = 1. - intersection / union
        grad = tf.concat([[jaccard[0]], jaccard[1:] - jaccard[:-1]], 0)

        return tf.tensordot(errors_sorted, grad, axes=1)

    # Size the array by the number of classes actually computed and write at
    # c - start_class. Otherwise, with ignore_background=True, an unwritten
    # slot for class 0 would be stacked (an error, or an extra zero counted in
    # the mean).
    losses = tf.TensorArray(dtype=tf.float32, size=num_classes - start_class)

    def loop_cond(c, losses):
        return tf.less(c, num_classes)

    def loop_body(c, losses):
        losses = losses.write(c - start_class, compute_class_loss(c))
        return c + 1, losses

    _, losses = tf.while_loop(loop_cond, loop_body, [start_class, losses])
    return tf.reduce_mean(losses.stack())

# Combined loss = Lovasz-Softmax + per-class soft Dice.
def combined_loss(y_true, y_pred, from_logits=False):
    """Sum of a per-class soft Dice loss and the Lovasz-Softmax loss.

    The models in this notebook end in a ``softmax`` Conv2D, so ``y_pred`` is
    already a probability map. Applying ``tf.nn.softmax`` to it again would
    compress the probabilities into a narrow band and weaken the Lovasz term.
    Softmax is therefore applied only when ``from_logits=True``.

    Args:
        y_true: one-hot targets, class axis last.
        y_pred: class probabilities (or logits if ``from_logits``), class axis last.
        from_logits: set True only for a model whose output has no softmax.

    Returns:
        Scalar loss.
    """
    smooth = 1e-6
    y_true = tf.cast(y_true, y_pred.dtype)
    probs = tf.nn.softmax(y_pred) if from_logits else y_pred

    # Per-sample, per-class Dice over the spatial axes.
    intersection = tf.reduce_sum(y_true * probs, axis=[1, 2])
    union = tf.reduce_sum(y_true, axis=[1, 2]) + tf.reduce_sum(probs, axis=[1, 2])
    dice_loss_val = tf.reduce_mean(1 - (2. * intersection + smooth) / (union + smooth))

    lovasz_loss_val = lovasz_softmax_loss(y_true, probs, ignore_background=False)
    return lovasz_loss_val + dice_loss_val

class DiceCoefficient(tf.keras.metrics.Metric):
    """Soft Dice score of one class channel, averaged over every sample seen.

    The state holds the running sum of per-sample Dice scores and the sample
    count. ``result()`` is therefore the mean since the last reset, not just
    the most recent batch's value (assigning instead of accumulating would
    make an epoch's reported Dice equal the last batch's).

    Args:
        class_idx: class channel to score. It defaults to 0 so that
            deserialising a config without ``class_idx`` still works.
        name: metric name; defaults to ``"DiceClass{class_idx}"``.
    """

    def __init__(self, class_idx=0, name=None, **kwargs):  # default class_idx=0 avoids a missing-arg error
        if name is None:
            name = f"DiceClass{class_idx}"
        super(DiceCoefficient, self).__init__(name=name, **kwargs)
        self.class_idx = class_idx
        self.dice_sum = self.add_weight(name="dice_sum", initializer="zeros")
        self.sample_count = self.add_weight(name="sample_count", initializer="zeros")

    def update_state(self, y_true, y_pred, sample_weight=None):
        """Accumulate the per-sample Dice of ``class_idx`` for this batch (``sample_weight`` is ignored)."""
        y_true = tf.cast(y_true, tf.float32)
        y_pred = tf.cast(y_pred, tf.float32)
        y_true_class = y_true[..., self.class_idx]
        y_pred_class = y_pred[..., self.class_idx]
        intersection = tf.reduce_sum(y_true_class * y_pred_class, axis=[1, 2])
        union = tf.reduce_sum(y_true_class, axis=[1, 2]) + tf.reduce_sum(y_pred_class, axis=[1, 2])
        dice = (2. * intersection + 1e-6) / (union + 1e-6)
        self.dice_sum.assign_add(tf.reduce_sum(dice))
        self.sample_count.assign_add(tf.cast(tf.size(dice), tf.float32))

    def result(self):
        # 0 before any update_state call.
        return tf.math.divide_no_nan(self.dice_sum, self.sample_count)

    def reset_state(self):
        self.dice_sum.assign(0.0)
        self.sample_count.assign(0.0)

    def get_config(self):
        config = super().get_config()
        config.update({"class_idx": self.class_idx})
        return config

    @classmethod
    def from_config(cls, config):
        if "class_idx" not in config:
            # Recover the class index from a name like "DiceClass2".
            name = config.get("name", "DiceClass0")
            if name.startswith("DiceClass"):
                config["class_idx"] = int(name.replace("DiceClass", ""))
            else:
                config["class_idx"] = 0
        return cls(**config)

# Build a DiceCoefficient metric from a "DiceClass<k>" name.
def dice_metric_loader(name):
    """Return ``DiceCoefficient(class_idx=k)`` for a name of the form ``"DiceClass<k>"``.

    Raises:
        ValueError: if ``name`` does not start with ``"DiceClass"``.
    """
    prefix = "DiceClass"
    if not name.startswith(prefix):
        raise ValueError(f"Unknown Dice metric name: {name}")
    return DiceCoefficient(class_idx=int(name.replace(prefix, "")))

# Custom objects needed to deserialise the saved model (losses + metrics).
custom_objects = {
    'combined_loss': combined_loss,
    'lovasz_softmax_loss': lovasz_softmax_loss,
    # NOTE(review): an instance (not the class) is registered here; kept as-is
    # because changing it could alter how the saved compile config is rebuilt.
    'MeanIoU': tf.keras.metrics.MeanIoU(num_classes=NUM_CLASSES),
    'DiceCoefficient': DiceCoefficient,
}

# Register ready-made per-class metrics "DiceClass0" .. "DiceClass{NUM_CLASSES-1}".
for i in range(NUM_CLASSES):
    metric_name = f'DiceClass{i}'
    custom_objects[metric_name] = dice_metric_loader(metric_name)

MODEL_PATH = 'C:\\Users\\User\\best_unet_model_onlineDA_128_lovaszloss_segnet.keras'
model_segnet = load_model(MODEL_PATH, custom_objects=custom_objects)

print("‚úÖ Model loaded successfully.")

In [None]:

import tensorflow as tf
from tensorflow.keras import Input, Model
from tensorflow.keras.layers import Conv2D, UpSampling2D, Concatenate, BatchNormalization, Activation
from tensorflow.keras.applications import InceptionResNetV2
import gc

# Model input geometry: 224x224 RGB.
IMG_HEIGHT, IMG_WIDTH, CHANNELS = 224, 224, 3
# Output classes: Brain, CSP, LV, Background.
NUM_CLASSES = 4

class ResizeLayer(tf.keras.layers.Layer):
    """Resize the spatial dimensions of the input with bilinear interpolation.

    Args:
        target_size: ``(height, width)`` passed to ``tf.image.resize``.
    """

    def __init__(self, target_size, **kwargs):
        super().__init__(**kwargs)
        self.target_size = target_size

    def call(self, inputs):
        return tf.image.resize(inputs, size=self.target_size, method='bilinear')

    def get_config(self):
        # target_size is needed to rebuild the layer from a saved model.
        return {**super().get_config(), "target_size": self.target_size}

def conv_block(x, filters, kernel_size=3, padding='same', activation='relu'):
    """Apply (Conv2D -> BatchNormalization -> activation) twice.

    The second convolution deepens the block (more parameters) without
    changing the spatial size.
    """
    for _ in range(2):
        x = Conv2D(filters, kernel_size, padding=padding)(x)
        x = BatchNormalization()(x)
        x = Activation(activation)(x)
    return x

def build_full_inceptionresnetv2_unet(input_shape=(IMG_HEIGHT, IMG_WIDTH, CHANNELS), num_classes=NUM_CLASSES):
    """
    Build a U-Net-style segmentation model on an InceptionResNetV2 encoder.

    The encoder is the ImageNet-pretrained Keras InceptionResNetV2 (no top,
    all layers trainable). Five of its feature maps are tapped: two stem
    activations, ``block35_10_ac``, ``block17_20_ac`` and ``conv_7b_ac``.

    Each decoder stage does the following:
    - upsamples x2;
    - resizes with ``ResizeLayer`` to the matching skip's exact spatial size
      (the backbone produces odd sizes, so x2 alone does not line up);
    - runs a ``conv_block``;
    - concatenates a 1x1-reduced skip connection.

    A final x2 upsample (plus a resize if needed) restores the input
    resolution before a 1x1 softmax head.

    NOTE(review): for a 224x224x3 input the tapped maps should be roughly
    111x111x32, 54x54x80, 25x25x320, 12x12x1088 and 5x5x1536 (derived from
    the Keras InceptionResNetV2 definition); confirm with ``model.summary()``.
    The total parameter count is not asserted here (see ``model.count_params()``).

    Args:
        input_shape: Input shape of the image ``(H, W, C)``.
        num_classes: Number of output classes (softmax channels).

    Returns:
        Keras Model mapping an image to per-pixel class probabilities of
        shape ``(H, W, num_classes)``.
    """
    # Input layer (batch size left unspecified)
    inputs = Input(shape=input_shape)

    # ImageNet-pretrained InceptionResNetV2 used as the encoder
    base_model = InceptionResNetV2(
        input_tensor=inputs,
        include_top=False,
        weights='imagenet',
        pooling=None
    )

    # Fine-tune the whole backbone.
    for layer in base_model.layers:
        layer.trainable = True

    # The two stem taps have auto-generated names ('activation' and
    # 'activation_3' in a fresh session). Keras numbers such names with a
    # session-wide counter, so Activation layers created earlier in the same
    # kernel (e.g. by building the SegNet model) shift them, and
    # get_layer('activation') then raises ValueError. Select them by
    # position among the backbone's own Activation layers instead.
    stem_activations = [layer for layer in base_model.layers if isinstance(layer, Activation)]
    encoder1 = stem_activations[0].output  # 'activation' in a fresh session
    encoder2 = stem_activations[3].output  # 'activation_3' in a fresh session
    # These three are explicitly named inside the Keras application.
    encoder3 = base_model.get_layer('block35_10_ac').output
    encoder4 = base_model.get_layer('block17_20_ac').output
    encoder5 = base_model.get_layer('conv_7b_ac').output

    # Project the deepest features to 512 channels to keep the decoder small.
    bottleneck = Conv2D(512, 1, padding='same')(encoder5)
    bottleneck = BatchNormalization()(bottleneck)
    bottleneck = Activation('relu')(bottleneck)

    # Stage 4: x2 upsample, then resize to encoder4's spatial size.
    up4 = UpSampling2D(size=(2, 2))(bottleneck)
    up4 = ResizeLayer(target_size=(encoder4.shape[1], encoder4.shape[2]))(up4)
    up4 = conv_block(up4, 512, kernel_size=3)

    # 1x1 reduction of the skip before concatenation
    skip4 = Conv2D(256, 1, padding='same')(encoder4)
    skip4 = BatchNormalization()(skip4)
    skip4 = Activation('relu')(skip4)

    merge4 = Concatenate()([up4, skip4])
    merge4 = conv_block(merge4, 384)

    # Stage 3: x2 upsample, then resize to encoder3's spatial size.
    up3 = UpSampling2D(size=(2, 2))(merge4)
    up3 = ResizeLayer(target_size=(encoder3.shape[1], encoder3.shape[2]))(up3)
    up3 = conv_block(up3, 384, kernel_size=3)

    skip3 = Conv2D(128, 1, padding='same')(encoder3)
    skip3 = BatchNormalization()(skip3)
    skip3 = Activation('relu')(skip3)

    merge3 = Concatenate()([up3, skip3])
    merge3 = conv_block(merge3, 192)

    # Stage 2: x2 upsample, then resize to encoder2's spatial size.
    up2 = UpSampling2D(size=(2, 2))(merge3)
    up2 = ResizeLayer(target_size=(encoder2.shape[1], encoder2.shape[2]))(up2)
    up2 = conv_block(up2, 192, kernel_size=3)

    skip2 = Conv2D(96, 1, padding='same')(encoder2)
    skip2 = BatchNormalization()(skip2)
    skip2 = Activation('relu')(skip2)

    merge2 = Concatenate()([up2, skip2])
    merge2 = conv_block(merge2, 96)

    # Stage 1: x2 upsample, then resize to encoder1's spatial size.
    up1 = UpSampling2D(size=(2, 2))(merge2)
    up1 = ResizeLayer(target_size=(encoder1.shape[1], encoder1.shape[2]))(up1)
    up1 = conv_block(up1, 96, kernel_size=3)

    skip1 = Conv2D(48, 1, padding='same')(encoder1)
    skip1 = BatchNormalization()(skip1)
    skip1 = Activation('relu')(skip1)

    merge1 = Concatenate()([up1, skip1])
    merge1 = conv_block(merge1, 48)

    # Final x2 upsample towards the input resolution.
    up_final = UpSampling2D(size=(2, 2))(merge1)
    up_final = conv_block(up_final, 32)

    # The stem's odd sizes mean x2 may not land exactly on the input size.
    if up_final.shape[1] != input_shape[0] or up_final.shape[2] != input_shape[1]:
        up_final = ResizeLayer(target_size=(input_shape[0], input_shape[1]))(up_final)

    # Per-pixel softmax head; float32 output keeps the loss numerics in float32.
    outputs = Conv2D(num_classes, 1, activation='softmax', dtype='float32')(up_final)

    model = Model(inputs=inputs, outputs=outputs)

    return model

# Release memory from earlier cells and reset Keras' global state *before*
# building. clear_session() also resets the counters Keras uses for
# auto-generated layer names, so this model's names are deterministic. The
# new model stays referenced by `model`, so clearing afterwards could not
# free it anyway.
gc.collect()
tf.keras.backend.clear_session()

print("Creating full InceptionResNetV2-UNet model...")
model = build_full_inceptionresnetv2_unet(input_shape=(IMG_HEIGHT, IMG_WIDTH, CHANNELS), num_classes=NUM_CLASSES)
print("Model created successfully!")

model.summary()

In [None]:
import tensorflow as tf
from tensorflow.keras import backend as K
from tensorflow.keras.models import load_model

class ResizeLayer(tf.keras.layers.Layer):
    """Bilinearly resize every input feature map to a fixed ``target_size``.

    Args:
        target_size: (height, width) passed straight to ``tf.image.resize``.
    """

    def __init__(self, target_size, **kwargs):
        super().__init__(**kwargs)
        self.target_size = target_size

    def call(self, inputs):
        return tf.image.resize(inputs, self.target_size, method='bilinear')

    def get_config(self):
        # target_size must be serialized so saved models can rebuild this layer.
        return {**super().get_config(), "target_size": self.target_size}

# Segmentation classes: Background, Brain, CSP, LV (adjust if the label set changes)
NUM_CLASSES: int = 4

# Dice coefficient pooled over all classes, averaged over the batch
def dice_coefficient(y_true, y_pred):
    """Soft Dice with one score per sample, then the batch mean.

    Sums run over axes 1, 2 and 3 (presumably H, W and classes — confirm), so all
    classes of a sample are pooled into a single Dice value.
    """
    eps = 1e-15
    truth = tf.cast(y_true, tf.float32)
    probs = tf.cast(y_pred, tf.float32)
    per_sample_axes = [1, 2, 3]
    overlap = tf.reduce_sum(truth * probs, axis=per_sample_axes)
    total = tf.reduce_sum(truth, axis=per_sample_axes) + tf.reduce_sum(probs, axis=per_sample_axes)
    return tf.reduce_mean((2. * overlap + eps) / (total + eps))

# Weighted categorical cross-entropy
def weighted_categorical_crossentropy(y_true, y_pred):
    """Pixel-wise categorical cross-entropy, each pixel weighted by its true class.

    Args:
        y_true: one-hot ground truth; the last axis indexes the classes.
        y_pred: predicted class probabilities with the same shape as ``y_true``.

    Returns:
        Scalar: mean weighted cross-entropy over all pixels in the batch.
    """
    # One weight per class index 0..3 (presumably inverse class frequencies — confirm source).
    class_weights = tf.constant([0.3776, 0.7605, 65.8554, 46.2381], dtype=tf.float32)
    y_true = tf.cast(y_true, tf.float32)
    y_pred = tf.clip_by_value(tf.cast(y_pred, tf.float32), K.epsilon(), 1.0)
    loss = -tf.reduce_sum(y_true * tf.math.log(y_pred), axis=-1)
    # Pick each pixel's weight from its one-hot label; class_weights broadcasts over the
    # last axis. Summing class_weights on its own yields one constant for every pixel,
    # which only rescales the loss and makes the class weighting a no-op.
    pixel_weights = tf.reduce_sum(y_true * class_weights, axis=-1)
    return tf.reduce_mean(loss * pixel_weights)

# Soft Dice loss
def dice_loss(y_true, y_pred):
    """One minus the mean soft Dice, computed per sample and per class.

    Sums cover axes 1 and 2 only (spatial, assuming (batch, H, W, classes) — confirm),
    giving a (batch, classes) Dice map that is then averaged.
    """
    eps = 1e-6
    truth = tf.cast(y_true, y_pred.dtype)
    spatial = [1, 2]
    overlap = tf.reduce_sum(truth * y_pred, axis=spatial)
    total = tf.reduce_sum(truth, axis=spatial) + tf.reduce_sum(y_pred, axis=spatial)
    return 1 - tf.reduce_mean((2. * overlap + eps) / (total + eps))

# Lovász-Softmax loss
def lovasz_softmax_loss(y_true, y_pred, ignore_background=False):
    """Lovász-Softmax loss averaged over classes, computed on the flattened batch.

    Args:
        y_true: one-hot ground truth; the last axis indexes classes.
        y_pred: class probabilities with the same shape as ``y_true``.
        ignore_background: if True, class 0 is left out of the loss and of the average.

    Returns:
        Scalar mean, over the scored classes, of the Lovász extension of the Jaccard loss.
    """
    y_true = tf.cast(y_true, tf.float32)
    y_pred = tf.cast(y_pred, tf.float32)

    num_classes = tf.shape(y_true)[-1]
    start_class = tf.constant(1 if ignore_background else 0)

    def compute_class_loss(c):
        """Lovász hinge for class ``c`` over every pixel of the batch."""
        y_true_flat = tf.reshape(y_true[..., c], [-1])
        y_pred_flat = tf.reshape(y_pred[..., c], [-1])

        # Absolute errors sorted in decreasing order (top_k with k = length is a full sort).
        errors = tf.abs(y_true_flat - y_pred_flat)
        errors_sorted, perm = tf.nn.top_k(errors, k=tf.shape(errors)[0], sorted=True)
        y_true_sorted = tf.gather(y_true_flat, perm)

        # Gradient of the Lovász extension of the Jaccard loss w.r.t. the sorted errors.
        gts = tf.reduce_sum(y_true_sorted)
        intersection = gts - tf.cumsum(y_true_sorted)
        union = gts + tf.cumsum(1. - y_true_sorted)
        jaccard = 1. - intersection / union
        grad = tf.concat([[jaccard[0]], jaccard[1:] - jaccard[:-1]], 0)

        return tf.tensordot(errors_sorted, grad, axes=1)

    # One slot per *scored* class, indexed by (c - start_class). Sizing the array by
    # num_classes while starting at class 1 left slot 0 unwritten, which either fails
    # at stack() or is read back as 0 and drags the class mean down.
    losses = tf.TensorArray(dtype=tf.float32, size=num_classes - start_class)

    def loop_cond(c, losses):
        return tf.less(c, num_classes)

    def loop_body(c, losses):
        losses = losses.write(c - start_class, compute_class_loss(c))
        return c + 1, losses

    _, losses = tf.while_loop(loop_cond, loop_body, [start_class, losses])
    return tf.reduce_mean(losses.stack())

# Combined loss: soft Dice + Lovász-Softmax
def combined_loss(y_true, y_pred):
    """Soft Dice loss plus Lovász-Softmax loss.

    Args:
        y_true: one-hot ground truth; cast to ``y_pred``'s dtype.
        y_pred: per-pixel class *probabilities*. The segmentation heads in this notebook
            end in ``activation='softmax'``, so no further normalisation is applied.

    Returns:
        Scalar ``lovasz + dice`` loss.
    """
    smooth = 1e-6
    y_true = tf.cast(y_true, y_pred.dtype)
    # Dice per sample and per class (sums over axes 1 and 2), then averaged.
    intersection = tf.reduce_sum(y_true * y_pred, axis=[1, 2])
    union = tf.reduce_sum(y_true, axis=[1, 2]) + tf.reduce_sum(y_pred, axis=[1, 2])
    dice_loss_val = tf.reduce_mean(1 - (2. * intersection + smooth) / (union + smooth))

    # y_pred is already softmax output. Re-applying tf.nn.softmax (as this loss used to)
    # maps probabilities in [0, 1] to values at most e (~2.7x) apart, compressing them
    # toward uniform and largely flattening the Lovász term.
    lovasz_loss_val = lovasz_softmax_loss(y_true, y_pred, ignore_background=False)
    return lovasz_loss_val + dice_loss_val

class DiceCoefficient(tf.keras.metrics.Metric):
    """Soft Dice coefficient for one class, averaged over every batch since the last reset.

    Args:
        class_idx: index along the last (class) axis of ``y_true`` / ``y_pred`` to score.
            Defaults to 0 so deserialization never fails on a missing argument.
        name: metric name; defaults to ``DiceClass{class_idx}`` so ``from_config`` can
            recover the index from the name alone.
    """

    def __init__(self, class_idx=0, name=None, **kwargs):
        if name is None:
            name = f"DiceClass{class_idx}"
        super(DiceCoefficient, self).__init__(name=name, **kwargs)
        self.class_idx = class_idx
        # Running sum of per-batch Dice values and number of batches seen.
        self.dice = self.add_weight(name="dice", initializer="zeros")
        self.count = self.add_weight(name="count", initializer="zeros")

    def update_state(self, y_true, y_pred, sample_weight=None):
        """Accumulate the batch's mean per-image Dice for ``class_idx`` (``sample_weight`` is ignored)."""
        y_true = tf.cast(y_true, self.dtype)
        y_pred = tf.cast(y_pred, self.dtype)
        y_true_class = y_true[..., self.class_idx]
        y_pred_class = y_pred[..., self.class_idx]
        # Sums over axes 1 and 2 give one Dice value per image (presumably (batch, H, W, C) — confirm).
        intersection = tf.reduce_sum(y_true_class * y_pred_class, axis=[1, 2])
        union = tf.reduce_sum(y_true_class, axis=[1, 2]) + tf.reduce_sum(y_pred_class, axis=[1, 2])
        dice = (2. * intersection + 1e-6) / (union + 1e-6)
        # Accumulate rather than assign: assigning made result() report only the last
        # batch of an epoch/evaluation instead of the average over all batches.
        self.dice.assign_add(tf.reduce_mean(dice))
        self.count.assign_add(1.0)

    def result(self):
        """Mean of the per-batch Dice values (0 before any update)."""
        return tf.math.divide_no_nan(self.dice, self.count)

    def get_config(self):
        config = super().get_config()
        config.update({"class_idx": self.class_idx})
        return config

    @classmethod
    def from_config(cls, config):
        """Rebuild the metric, inferring ``class_idx`` from a ``DiceClass<k>`` name if absent."""
        if "class_idx" not in config:
            name = config.get("name", "DiceClass0")
            if name.startswith("DiceClass"):
                config["class_idx"] = int(name.replace("DiceClass", ""))
            else:
                config["class_idx"] = 0
        return cls(**config)

# Build a DiceCoefficient from its metric name
def dice_metric_loader(name):
    """Return ``DiceCoefficient(class_idx=k)`` for a metric named ``DiceClass<k>``.

    Raises:
        ValueError: if ``name`` does not start with "DiceClass".
    """
    prefix = "DiceClass"
    if not name.startswith(prefix):
        raise ValueError(f"Unknown Dice metric name: {name}")
    return DiceCoefficient(class_idx=int(name.replace(prefix, "")))

# Objects Keras needs to deserialize the saved InceptionResNetV2 U-Net++ model
custom_objects = dict(
    ResizeLayer=ResizeLayer,
    combined_loss=combined_loss,
    lovasz_softmax_loss=lovasz_softmax_loss,
    MeanIoU=tf.keras.metrics.MeanIoU(num_classes=NUM_CLASSES),
    DiceCoefficient=DiceCoefficient,
)

# One entry per class-wise Dice metric name (DiceClass0 .. DiceClass{NUM_CLASSES-1})
for i in range(NUM_CLASSES):
    custom_objects[f'DiceClass{i}'] = dice_metric_loader(f'DiceClass{i}')

model_inceptionresnetv2 = load_model(
    'lovaszloss_unet++_inceptionresnetv2.keras',
    custom_objects=custom_objects,
)

print("‚úÖ Model loaded successfully.")

In [None]:
import tensorflow as tf
from tensorflow.keras import Input, Model, Sequential
from tensorflow.keras.layers import Conv2D, MaxPooling2D, BatchNormalization, Activation, Add
from tensorflow.keras.layers import Dense, Dropout, Layer, Reshape, Permute, Multiply, Concatenate
from tensorflow.keras.layers import GlobalAveragePooling2D, LayerNormalization, UpSampling2D
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.losses import CategoricalCrossentropy
from tensorflow.keras.applications import EfficientNetB4

class ResizeToMatchLayer(Layer):
    """Bilinearly resize ``inputs[0]`` to the spatial size (axes 1 and 2) of ``inputs[1]``.

    The reference size is read at run time with ``tf.shape``; the channel count of
    ``inputs[0]`` is kept.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def call(self, inputs):
        source, reference = inputs
        ref_dims = tf.shape(reference)
        return tf.image.resize(source, [ref_dims[1], ref_dims[2]], method='bilinear')

    def compute_output_shape(self, input_shape):
        source_shape, reference_shape = input_shape[0], input_shape[1]
        # (batch of source, H and W of reference, channels of source)
        return (source_shape[0], reference_shape[1], reference_shape[2], source_shape[3])

def conv_block(x, filters, kernel_size=3, strides=1, padding='same', use_bn=True, activation='relu'):
    """Conv2D, then optional BatchNormalization, then optional activation.

    Args:
        x: input tensor.
        filters, kernel_size, strides, padding: forwarded to ``Conv2D``.
        use_bn: insert a ``BatchNormalization`` after the convolution when True.
        activation: activation name for ``Activation``; falsy values skip it.
    """
    out = Conv2D(filters, kernel_size, strides=strides, padding=padding)(x)
    out = BatchNormalization()(out) if use_bn else out
    return Activation(activation)(out) if activation else out

def attention_gate(x, g, inter_channels):
    """Additive attention gate in the style of Attention U-Net.

    Args:
        x: skip-connection feature map from the encoder; this is what gets gated.
        g: gating signal from the decoder; resized to ``x``'s spatial size first.
        inter_channels: channel count of the bias-free 1x1 projections of ``x`` and ``g``.

    Returns:
        ``x`` multiplied element-wise by a single-channel sigmoid attention map.
    """
    gate = ResizeToMatchLayer()([g, x])

    x_proj = Conv2D(inter_channels, 1, use_bias=False, padding='same')(x)
    g_proj = Conv2D(inter_channels, 1, use_bias=False, padding='same')(gate)

    relu = Activation('relu')
    joint = relu(Add()([x_proj, g_proj]))

    # Collapse to one channel and squash to (0, 1) attention coefficients.
    psi = Conv2D(1, 1, use_bias=False, padding='same')(joint)
    coefficients = Activation('sigmoid')(psi)

    return Multiply()([x, coefficients])

def decoder_block(x, skip_connection, filters, use_attention=True):
    """One Attention U-Net decoder stage.

    Upsamples ``x`` by 2 (bilinear), snaps it to the skip's spatial size, optionally
    gates the skip with ``attention_gate``, concatenates both and applies two 3x3
    ``conv_block``s with ``filters`` channels.
    """
    upsampled = UpSampling2D(size=(2, 2), interpolation='bilinear')(x)
    upsampled = ResizeToMatchLayer()([upsampled, skip_connection])

    gated_skip = (
        attention_gate(skip_connection, upsampled, filters // 2)
        if use_attention
        else skip_connection
    )

    merged = Concatenate()([upsampled, gated_skip])
    for _ in range(2):
        merged = conv_block(merged, filters, 3, padding='same')
    return merged

def build_efficientnet_attention_unet(input_shape, num_classes):
    """
    Attention U-Net with an ImageNet-pretrained EfficientNetB4 encoder.

    All backbone layers are left trainable. Four encoder stages are tapped as skip
    connections (1/2, 1/4, 1/8 and 1/16 scale; 112, 56, 28 and 14 px for a 224x224
    input) and 'top_activation' (1/32 scale) is the bridge. Every tap is first
    projected by ``conv_block`` to a small channel count to limit parameters.

    NOTE(review): Keras' EfficientNet applications include their own input rescaling
    and expect raw [0, 255] pixels — confirm how images fed to this model are scaled.

    Args:
        input_shape: (height, width, channels) of the input images.
        num_classes: number of channels of the softmax segmentation map.

    Returns:
        A Keras ``Model`` producing per-pixel class probabilities.
    """
    inputs = Input(shape=input_shape)
    encoder = EfficientNetB4(weights='imagenet', include_top=False, input_tensor=inputs)

    base_filters = 32
    # (encoder layer name, filter multiplier), shallowest first
    skip_taps = [('block1b_add', 1), ('block2d_add', 2), ('block3d_add', 4), ('block5e_add', 8)]
    projected_skips = [
        conv_block(encoder.get_layer(layer_name).output, base_filters * mult)
        for layer_name, mult in skip_taps
    ]

    x = conv_block(encoder.get_layer('top_activation').output, base_filters * 16)

    # Decode deepest-first: 1/32 -> 1/16 -> 1/8 -> 1/4 -> 1/2
    for skip, (_, mult) in zip(reversed(projected_skips), reversed(skip_taps)):
        x = decoder_block(x, skip, base_filters * mult, use_attention=True)

    x = UpSampling2D(size=(2, 2), interpolation='bilinear')(x)
    outputs = Conv2D(num_classes, 1, padding='same', activation='softmax')(x)
    return Model(inputs=inputs, outputs=outputs)

# Build the model
# model = build_efficientnet_attention_unet(input_shape=(224, 224, 3), num_classes=4)

# Print model summary
# model.summary()

In [None]:
def dice_coefficient(y_true, y_pred):
    """Soft Dice with one score per sample, then the batch mean.

    Sums run over axes 1, 2 and 3 (presumably H, W and classes — confirm), so all
    classes of a sample are pooled into a single Dice value.
    """
    eps = 1e-15
    truth = tf.cast(y_true, tf.float32)
    probs = tf.cast(y_pred, tf.float32)
    per_sample_axes = [1, 2, 3]
    overlap = tf.reduce_sum(truth * probs, axis=per_sample_axes)
    total = tf.reduce_sum(truth, axis=per_sample_axes) + tf.reduce_sum(probs, axis=per_sample_axes)
    return tf.reduce_mean((2. * overlap + eps) / (total + eps))

# Weighted categorical cross-entropy
def weighted_categorical_crossentropy(y_true, y_pred):
    """Pixel-wise categorical cross-entropy, each pixel weighted by its true class.

    Args:
        y_true: one-hot ground truth; the last axis indexes the classes.
        y_pred: predicted class probabilities with the same shape as ``y_true``.

    Returns:
        Scalar: mean weighted cross-entropy over all pixels in the batch.
    """
    # One weight per class index 0..3 (presumably inverse class frequencies — confirm source).
    class_weights = tf.constant([0.3776, 0.7605, 65.8554, 46.2381], dtype=tf.float32)
    y_true = tf.cast(y_true, tf.float32)
    y_pred = tf.clip_by_value(tf.cast(y_pred, tf.float32), K.epsilon(), 1.0)
    loss = -tf.reduce_sum(y_true * tf.math.log(y_pred), axis=-1)
    # Pick each pixel's weight from its one-hot label; class_weights broadcasts over the
    # last axis. Summing class_weights on its own yields one constant for every pixel,
    # which only rescales the loss and makes the class weighting a no-op.
    pixel_weights = tf.reduce_sum(y_true * class_weights, axis=-1)
    return tf.reduce_mean(loss * pixel_weights)

# Soft Dice loss
def dice_loss(y_true, y_pred):
    """One minus the mean soft Dice, computed per sample and per class.

    Sums cover axes 1 and 2 only (spatial, assuming (batch, H, W, classes) — confirm),
    giving a (batch, classes) Dice map that is then averaged.
    """
    eps = 1e-6
    truth = tf.cast(y_true, y_pred.dtype)
    spatial = [1, 2]
    overlap = tf.reduce_sum(truth * y_pred, axis=spatial)
    total = tf.reduce_sum(truth, axis=spatial) + tf.reduce_sum(y_pred, axis=spatial)
    return 1 - tf.reduce_mean((2. * overlap + eps) / (total + eps))

# Custom Dice coefficient metric for each class
class DiceCoefficient(tf.keras.metrics.Metric):
    """Soft Dice coefficient for one class, averaged over every batch since the last reset.

    Args:
        class_idx: index along the last (class) axis of ``y_true`` / ``y_pred`` to score.
            Defaults to 0 so deserialization never fails on a missing argument.
        name: metric name; defaults to ``DiceClass{class_idx}`` so ``from_config`` can
            recover the index from the name alone.
    """

    def __init__(self, class_idx=0, name=None, **kwargs):
        if name is None:
            name = f"DiceClass{class_idx}"
        super(DiceCoefficient, self).__init__(name=name, **kwargs)
        self.class_idx = class_idx
        # Running sum of per-batch Dice values and number of batches seen.
        self.dice = self.add_weight(name="dice", initializer="zeros")
        self.count = self.add_weight(name="count", initializer="zeros")

    def update_state(self, y_true, y_pred, sample_weight=None):
        """Accumulate the batch's mean per-image Dice for ``class_idx`` (``sample_weight`` is ignored)."""
        y_true = tf.cast(y_true, self.dtype)
        y_pred = tf.cast(y_pred, self.dtype)
        y_true_class = y_true[..., self.class_idx]
        y_pred_class = y_pred[..., self.class_idx]
        # Sums over axes 1 and 2 give one Dice value per image (presumably (batch, H, W, C) — confirm).
        intersection = tf.reduce_sum(y_true_class * y_pred_class, axis=[1, 2])
        union = tf.reduce_sum(y_true_class, axis=[1, 2]) + tf.reduce_sum(y_pred_class, axis=[1, 2])
        dice = (2. * intersection + 1e-6) / (union + 1e-6)
        # Accumulate rather than assign: assigning made result() report only the last
        # batch of an epoch/evaluation instead of the average over all batches.
        self.dice.assign_add(tf.reduce_mean(dice))
        self.count.assign_add(1.0)

    def result(self):
        """Mean of the per-batch Dice values (0 before any update)."""
        return tf.math.divide_no_nan(self.dice, self.count)

    def get_config(self):
        # Needed so models compiled with this metric can be saved and reloaded.
        config = super().get_config()
        config.update({"class_idx": self.class_idx})
        return config

    @classmethod
    def from_config(cls, config):
        """Rebuild the metric, inferring ``class_idx`` from a ``DiceClass<k>`` name if absent."""
        if "class_idx" not in config:
            name = config.get("name", "DiceClass0")
            if name.startswith("DiceClass"):
                config["class_idx"] = int(name.replace("DiceClass", ""))
            else:
                config["class_idx"] = 0
        return cls(**config)

# Class-wise metrics: one Dice per class plus a mean IoU
def class_wise_metrics(num_classes=4):
    """Metrics list for ``model.compile``: ``DiceClass0..N-1`` followed by a one-hot mean IoU.

    Args:
        num_classes: number of classes (length of the last axis of y_true / y_pred).

    Returns:
        ``[DiceCoefficient(0), ..., DiceCoefficient(num_classes - 1), OneHotMeanIoU]``.
    """
    dice_metrics = [DiceCoefficient(class_idx) for class_idx in range(num_classes)]
    # Labels are one-hot and predictions are softmax probabilities. MeanIoU expects
    # integer class ids and casts its inputs to int, which turns probabilities into 0s;
    # OneHotMeanIoU takes the argmax over the class axis before counting.
    return dice_metrics + [tf.keras.metrics.OneHotMeanIoU(num_classes=num_classes)]

# Rebuild the EfficientNetB4 Attention U-Net, compile it, then restore trained weights
# from the weights-only HDF5 file (the architecture comes from the builder above).
model_efficientnetb4 = build_efficientnet_attention_unet(
    input_shape=(224, 224, 3),
    num_classes=4,
)
model_efficientnetb4.compile(
    optimizer=Adam(learning_rate=1e-4),
    loss=combined_loss,
    metrics=class_wise_metrics(num_classes=4),
)
model_efficientnetb4.load_weights("efficientnet_attention_unet_weights.h5")

In [None]:
import tensorflow as tf
from tensorflow.keras import layers, Model, Input
from tensorflow.keras.applications import Xception

# Image geometry and label count for the 224x224 RGB dataset
IMG_HEIGHT = IMG_WIDTH = 224
CHANNELS = 3        # RGB
NUM_CLASSES = 4     # Brain, CSP, LV, Background

def convolution_block(inputs, filters, kernel_size=3, dilation_rate=1, padding='same', use_bias=False):
    """
    Conv2D (optionally dilated) -> BatchNormalization -> ReLU.

    Args:
        inputs: input tensor.
        filters, kernel_size, dilation_rate, padding, use_bias: forwarded to ``Conv2D``.
    """
    conv = layers.Conv2D(
        filters,
        kernel_size,
        padding=padding,
        dilation_rate=dilation_rate,
        use_bias=use_bias,
    )
    normed = layers.BatchNormalization()(conv(inputs))
    return layers.ReLU()(normed)

def ASPP(inputs):
    """
    Atrous Spatial Pyramid Pooling head for DeepLabV3+.

    Four parallel 256-filter branches (a 1x1 conv and 3x3 convs at dilation 6, 12 and 18)
    plus an image-pooling branch, concatenated and fused by a 1x1 conv. ``inputs`` must
    have static spatial dims, because the pooled branch is upsampled by
    ``inputs.shape[1:3]``.
    """
    branches = [convolution_block(inputs, 256, kernel_size=1, dilation_rate=1)]
    branches += [
        convolution_block(inputs, 256, kernel_size=3, dilation_rate=rate)
        for rate in (6, 12, 18)
    ]

    # Image-level context: global average -> 1x1 conv -> broadcast back to full size
    pooled = layers.GlobalAveragePooling2D()(inputs)
    pooled = layers.Reshape((1, 1, inputs.shape[-1]))(pooled)
    pooled = convolution_block(pooled, 256, kernel_size=1)
    branches.append(layers.UpSampling2D(size=(inputs.shape[1], inputs.shape[2]))(pooled))

    return convolution_block(layers.Concatenate()(branches), 256, kernel_size=1)

def build_deeplabv3_plus_xception(input_shape, num_classes):
    """
    DeepLabV3+ model with an ImageNet-pretrained Xception backbone.

    Args:
        input_shape: (height, width, channels) of the input images. The output has the
            same height and width (previously the module-level IMG_HEIGHT/IMG_WIDTH were
            used regardless of this argument).
        num_classes: number of channels of the softmax head.

    Returns:
        A Keras ``Model`` mapping images to per-pixel class probabilities.
    """
    inputs = Input(input_shape)
    out_height, out_width = input_shape[0], input_shape[1]

    # Keras' Xception backbone. Its final feature map is 1/32 of the input (7x7 for
    # 224x224), not the output-stride-16 variant described in the DeepLab paper.
    base_model = Xception(
        input_tensor=inputs,
        include_top=False,
        weights='imagenet'
    )

    # Fine-tune the whole backbone
    for layer in base_model.layers:
        layer.trainable = True

    # Low-level features from the entry flow, high-level features from the exit flow
    low_level_features = base_model.get_layer('block4_sepconv2_bn').output
    high_level_features = base_model.output

    low_level_features = convolution_block(low_level_features, 48, kernel_size=1)
    x = ASPP(high_level_features)

    # Upsample ASPP output to the low-level feature size
    hl_shape = high_level_features.shape
    ll_shape = low_level_features.shape
    x = layers.UpSampling2D(
        size=(ll_shape[1] // hl_shape[1], ll_shape[2] // hl_shape[2]),
        interpolation='bilinear',
    )(x)
    # Integer factors undershoot when the sizes are not exact multiples; snap to the skip.
    if x.shape[1] != ll_shape[1] or x.shape[2] != ll_shape[2]:
        x = layers.Resizing(ll_shape[1], ll_shape[2], interpolation='bilinear')(x)

    x = layers.Concatenate()([x, low_level_features])
    x = convolution_block(x, 256, kernel_size=3)
    x = convolution_block(x, 256, kernel_size=3)

    # Final upsampling to the input resolution
    x = layers.UpSampling2D(
        size=(out_height // x.shape[1], out_width // x.shape[2]),
        interpolation='bilinear',
    )(x)
    # A Reshape cannot change spatial size (it raised on any mismatch), so resize instead.
    if x.shape[1] != out_height or x.shape[2] != out_width:
        x = layers.Resizing(out_height, out_width, interpolation='bilinear')(x)

    outputs = layers.Conv2D(num_classes, kernel_size=1, padding='same', activation='softmax')(x)

    model = Model(inputs=inputs, outputs=outputs)
    return model

# Build the DeepLabV3+ / Xception model at the configured image size
# (this rebinds `model`, which an earlier cell used for the InceptionResNetV2 U-Net).
model = build_deeplabv3_plus_xception(
    input_shape=(IMG_HEIGHT, IMG_WIDTH, CHANNELS),
    num_classes=NUM_CLASSES,
)

# model.summary()  # uncomment to inspect the architecture

In [None]:
import tensorflow as tf
from tensorflow.keras import backend as K
from tensorflow.keras.models import load_model

# Segmentation classes: Background, Brain, CSP, LV (adjust if the label set changes)
NUM_CLASSES: int = 4

# Dice coefficient pooled over all classes, averaged over the batch
def dice_coefficient(y_true, y_pred):
    """Soft Dice with one score per sample, then the batch mean.

    Sums run over axes 1, 2 and 3 (presumably H, W and classes — confirm), so all
    classes of a sample are pooled into a single Dice value.
    """
    eps = 1e-15
    truth = tf.cast(y_true, tf.float32)
    probs = tf.cast(y_pred, tf.float32)
    per_sample_axes = [1, 2, 3]
    overlap = tf.reduce_sum(truth * probs, axis=per_sample_axes)
    total = tf.reduce_sum(truth, axis=per_sample_axes) + tf.reduce_sum(probs, axis=per_sample_axes)
    return tf.reduce_mean((2. * overlap + eps) / (total + eps))

# Weighted categorical cross-entropy
def weighted_categorical_crossentropy(y_true, y_pred):
    """Pixel-wise categorical cross-entropy, each pixel weighted by its true class.

    Args:
        y_true: one-hot ground truth; the last axis indexes the classes.
        y_pred: predicted class probabilities with the same shape as ``y_true``.

    Returns:
        Scalar: mean weighted cross-entropy over all pixels in the batch.
    """
    # One weight per class index 0..3 (presumably inverse class frequencies — confirm source).
    class_weights = tf.constant([0.3776, 0.7605, 65.8554, 46.2381], dtype=tf.float32)
    y_true = tf.cast(y_true, tf.float32)
    y_pred = tf.clip_by_value(tf.cast(y_pred, tf.float32), K.epsilon(), 1.0)
    loss = -tf.reduce_sum(y_true * tf.math.log(y_pred), axis=-1)
    # Pick each pixel's weight from its one-hot label; class_weights broadcasts over the
    # last axis. Summing class_weights on its own yields one constant for every pixel,
    # which only rescales the loss and makes the class weighting a no-op.
    pixel_weights = tf.reduce_sum(y_true * class_weights, axis=-1)
    return tf.reduce_mean(loss * pixel_weights)

# Soft Dice loss
def dice_loss(y_true, y_pred):
    """One minus the mean soft Dice, computed per sample and per class.

    Sums cover axes 1 and 2 only (spatial, assuming (batch, H, W, classes) — confirm),
    giving a (batch, classes) Dice map that is then averaged.
    """
    eps = 1e-6
    truth = tf.cast(y_true, y_pred.dtype)
    spatial = [1, 2]
    overlap = tf.reduce_sum(truth * y_pred, axis=spatial)
    total = tf.reduce_sum(truth, axis=spatial) + tf.reduce_sum(y_pred, axis=spatial)
    return 1 - tf.reduce_mean((2. * overlap + eps) / (total + eps))

# Lovász-Softmax loss
def lovasz_softmax_loss(y_true, y_pred, ignore_background=False):
    """Lovász-Softmax loss averaged over classes, computed on the flattened batch.

    Args:
        y_true: one-hot ground truth; the last axis indexes classes.
        y_pred: class probabilities with the same shape as ``y_true``.
        ignore_background: if True, class 0 is left out of the loss and of the average.

    Returns:
        Scalar mean, over the scored classes, of the Lovász extension of the Jaccard loss.
    """
    y_true = tf.cast(y_true, tf.float32)
    y_pred = tf.cast(y_pred, tf.float32)

    num_classes = tf.shape(y_true)[-1]
    start_class = tf.constant(1 if ignore_background else 0)

    def compute_class_loss(c):
        """Lovász hinge for class ``c`` over every pixel of the batch."""
        y_true_flat = tf.reshape(y_true[..., c], [-1])
        y_pred_flat = tf.reshape(y_pred[..., c], [-1])

        # Absolute errors sorted in decreasing order (top_k with k = length is a full sort).
        errors = tf.abs(y_true_flat - y_pred_flat)
        errors_sorted, perm = tf.nn.top_k(errors, k=tf.shape(errors)[0], sorted=True)
        y_true_sorted = tf.gather(y_true_flat, perm)

        # Gradient of the Lovász extension of the Jaccard loss w.r.t. the sorted errors.
        gts = tf.reduce_sum(y_true_sorted)
        intersection = gts - tf.cumsum(y_true_sorted)
        union = gts + tf.cumsum(1. - y_true_sorted)
        jaccard = 1. - intersection / union
        grad = tf.concat([[jaccard[0]], jaccard[1:] - jaccard[:-1]], 0)

        return tf.tensordot(errors_sorted, grad, axes=1)

    # One slot per *scored* class, indexed by (c - start_class). Sizing the array by
    # num_classes while starting at class 1 left slot 0 unwritten, which either fails
    # at stack() or is read back as 0 and drags the class mean down.
    losses = tf.TensorArray(dtype=tf.float32, size=num_classes - start_class)

    def loop_cond(c, losses):
        return tf.less(c, num_classes)

    def loop_body(c, losses):
        losses = losses.write(c - start_class, compute_class_loss(c))
        return c + 1, losses

    _, losses = tf.while_loop(loop_cond, loop_body, [start_class, losses])
    return tf.reduce_mean(losses.stack())

# Combined loss: soft Dice + Lovász-Softmax
def combined_loss(y_true, y_pred):
    """Soft Dice loss plus Lovász-Softmax loss.

    Args:
        y_true: one-hot ground truth; cast to ``y_pred``'s dtype.
        y_pred: per-pixel class *probabilities*. The segmentation heads in this notebook
            end in ``activation='softmax'``, so no further normalisation is applied.

    Returns:
        Scalar ``lovasz + dice`` loss.
    """
    smooth = 1e-6
    y_true = tf.cast(y_true, y_pred.dtype)
    # Dice per sample and per class (sums over axes 1 and 2), then averaged.
    intersection = tf.reduce_sum(y_true * y_pred, axis=[1, 2])
    union = tf.reduce_sum(y_true, axis=[1, 2]) + tf.reduce_sum(y_pred, axis=[1, 2])
    dice_loss_val = tf.reduce_mean(1 - (2. * intersection + smooth) / (union + smooth))

    # y_pred is already softmax output. Re-applying tf.nn.softmax (as this loss used to)
    # maps probabilities in [0, 1] to values at most e (~2.7x) apart, compressing them
    # toward uniform and largely flattening the Lovász term.
    lovasz_loss_val = lovasz_softmax_loss(y_true, y_pred, ignore_background=False)
    return lovasz_loss_val + dice_loss_val

class DiceCoefficient(tf.keras.metrics.Metric):
    """Soft Dice coefficient for one class, averaged over every batch since the last reset.

    Args:
        class_idx: index along the last (class) axis of ``y_true`` / ``y_pred`` to score.
            Defaults to 0 so deserialization never fails on a missing argument.
        name: metric name; defaults to ``DiceClass{class_idx}`` so ``from_config`` can
            recover the index from the name alone.
    """

    def __init__(self, class_idx=0, name=None, **kwargs):
        if name is None:
            name = f"DiceClass{class_idx}"
        super(DiceCoefficient, self).__init__(name=name, **kwargs)
        self.class_idx = class_idx
        # Running sum of per-batch Dice values and number of batches seen.
        self.dice = self.add_weight(name="dice", initializer="zeros")
        self.count = self.add_weight(name="count", initializer="zeros")

    def update_state(self, y_true, y_pred, sample_weight=None):
        """Accumulate the batch's mean per-image Dice for ``class_idx`` (``sample_weight`` is ignored)."""
        y_true = tf.cast(y_true, self.dtype)
        y_pred = tf.cast(y_pred, self.dtype)
        y_true_class = y_true[..., self.class_idx]
        y_pred_class = y_pred[..., self.class_idx]
        # Sums over axes 1 and 2 give one Dice value per image (presumably (batch, H, W, C) — confirm).
        intersection = tf.reduce_sum(y_true_class * y_pred_class, axis=[1, 2])
        union = tf.reduce_sum(y_true_class, axis=[1, 2]) + tf.reduce_sum(y_pred_class, axis=[1, 2])
        dice = (2. * intersection + 1e-6) / (union + 1e-6)
        # Accumulate rather than assign: assigning made result() report only the last
        # batch of an epoch/evaluation instead of the average over all batches.
        self.dice.assign_add(tf.reduce_mean(dice))
        self.count.assign_add(1.0)

    def result(self):
        """Mean of the per-batch Dice values (0 before any update)."""
        return tf.math.divide_no_nan(self.dice, self.count)

    def get_config(self):
        config = super().get_config()
        config.update({"class_idx": self.class_idx})
        return config

    @classmethod
    def from_config(cls, config):
        """Rebuild the metric, inferring ``class_idx`` from a ``DiceClass<k>`` name if absent."""
        if "class_idx" not in config:
            name = config.get("name", "DiceClass0")
            if name.startswith("DiceClass"):
                config["class_idx"] = int(name.replace("DiceClass", ""))
            else:
                config["class_idx"] = 0
        return cls(**config)

# Build a DiceCoefficient from its metric name
def dice_metric_loader(name):
    """Return ``DiceCoefficient(class_idx=k)`` for a metric named ``DiceClass<k>``.

    Raises:
        ValueError: if ``name`` does not start with "DiceClass".
    """
    prefix = "DiceClass"
    if not name.startswith(prefix):
        raise ValueError(f"Unknown Dice metric name: {name}")
    return DiceCoefficient(class_idx=int(name.replace(prefix, "")))



# Objects Keras needs to deserialize the saved DeepLabV3+ / Xception model
custom_objects = dict(
    combined_loss=combined_loss,
    lovasz_softmax_loss=lovasz_softmax_loss,
    MeanIoU=tf.keras.metrics.MeanIoU(num_classes=NUM_CLASSES),
    DiceCoefficient=DiceCoefficient,
)

# One entry per class-wise Dice metric name (DiceClass0 .. DiceClass{NUM_CLASSES-1})
for i in range(NUM_CLASSES):
    custom_objects[f'DiceClass{i}'] = dice_metric_loader(f'DiceClass{i}')

model_xception = load_model(
    'lovaszloss_deeplabv3_xception.keras',
    custom_objects=custom_objects,
)

In [None]:
import numpy as np
import tensorflow as tf
import scipy.spatial.distance as dist
from scipy.ndimage import binary_erosion
import psutil
import gc
import logging

# Module logging at INFO level. basicConfig only installs a handler if the root
# logger has none yet; `logger` is what the evaluation helpers below write to.
logging.basicConfig(level="INFO")
logger = logging.getLogger(name=__name__)

# System RAM snapshot
def check_memory():
    """Log total, available and used RAM (GiB, labelled "GB") and the percent in use.

    Returns:
        The ``psutil.virtual_memory()`` snapshot that was logged.
    """
    memory = psutil.virtual_memory()
    log = logger.info
    log(f"Total RAM: {memory.total / (1024**3):.2f} GB")
    log(f"Available RAM: {memory.available / (1024**3):.2f} GB")
    log(f"Used RAM: {memory.used / (1024**3):.2f} GB")
    log(f"Memory percentage used: {memory.percent:.1f}%")
    return memory

def _sample_evenly(points, max_points):
    """Return at most ``max_points`` rows of ``points``, taken at evenly spaced indices."""
    if len(points) <= max_points:
        return points
    # linspace covers the whole array. A floor-divided stride followed by [:max_points]
    # can keep only a leading prefix instead (e.g. 1999 points, cap 1000 -> step 1).
    idx = np.linspace(0, len(points) - 1, num=max_points).round().astype(int)
    return points[idx]


def compute_surface_distances_optimized(pred, true, max_points=1000):
    """
    Memory-optimized distances between the boundary points of two masks.

    Args:
        pred: predicted mask (array-like; any non-zero value counts as foreground).
        true: ground-truth mask of the same shape.
        max_points: maximum number of boundary points kept per mask; longer boundaries
            are subsampled at evenly spaced positions.

    Returns:
        ``(dist_pred_to_true, dist_true_to_pred)``. Normally the Euclidean distance
        matrices of shape (P, T) and (T, P); the second is a transposed view of the first.
        If one matrix would exceed ~2 GB, ``compute_distances_chunked`` is used and
        per-point minimum distances of shape (P, 1) and (T, 1) are returned instead.
        If neither mask has boundary points: ``[[0]], [[0]]``; if exactly one lacks
        them: ``[[inf]], [[inf]]``.
    """
    # Plain numpy cast; going through tf.cast(...).numpy() only added a device round-trip.
    pred = np.asarray(pred).astype(bool)
    true = np.asarray(true).astype(bool)

    pred_boundary = get_boundary_points(pred)
    true_boundary = get_boundary_points(true)

    logger.info(f"Pred boundary points: {len(pred_boundary)}, True boundary points: {len(true_boundary)}")

    # Edge cases: both boundaries empty -> perfect match; only one empty -> undefined (inf)
    if len(pred_boundary) == 0 and len(true_boundary) == 0:
        return np.array([[0]]), np.array([[0]])
    if len(pred_boundary) == 0 or len(true_boundary) == 0:
        return np.array([[np.inf]]), np.array([[np.inf]])

    if len(pred_boundary) > max_points:
        pred_boundary = _sample_evenly(pred_boundary, max_points)
        logger.info(f"Systematically sampled pred boundary to {len(pred_boundary)} points")

    if len(true_boundary) > max_points:
        true_boundary = _sample_evenly(true_boundary, max_points)
        logger.info(f"Systematically sampled true boundary to {len(true_boundary)} points")

    # float64 distance matrix: 8 bytes per entry
    estimated_memory_gb = (len(pred_boundary) * len(true_boundary) * 8) / (1024**3)
    logger.info(f"Estimated memory needed: {estimated_memory_gb:.2f} GB")

    if estimated_memory_gb > 2.0:  # If > 2GB, use chunked approach
        return compute_distances_chunked(pred_boundary, true_boundary)

    dist_pred_to_true = dist.cdist(pred_boundary, true_boundary, 'euclidean')
    # Euclidean distance is symmetric: reuse the matrix instead of a second cdist call.
    return dist_pred_to_true, dist_pred_to_true.T

def _chunked_min_distances(src, dst, chunk_size):
    """For each row of ``src``, the Euclidean distance to its nearest row of ``dst``."""
    mins = np.full(len(src), np.inf)
    if len(dst) == 0:
        # No target points: every distance stays +inf (np.min over an empty axis would raise).
        return mins
    for start in range(0, len(src), chunk_size):
        stop = min(start + chunk_size, len(src))
        # Only a (chunk_size, len(dst)) block is alive at a time; it is freed by
        # refcounting as soon as .min() returns, so no gc.collect() is needed.
        mins[start:stop] = dist.cdist(src[start:stop], dst, 'euclidean').min(axis=1)
    return mins


def compute_distances_chunked(pred_boundary, true_boundary, chunk_size=500):
    """
    Per-point nearest-neighbour distances between two point sets, computed in chunks.

    Args:
        pred_boundary: (P, D) array of predicted boundary coordinates.
        true_boundary: (T, D) array of ground-truth boundary coordinates.
        chunk_size: number of source rows per ``cdist`` call.

    Returns:
        ``(min_pred_to_true, min_true_to_pred)`` with shapes (P, 1) and (T, 1), so that
        row-wise ``min`` in the metric functions is a no-op. Distances are +inf when the
        opposite set is empty.
    """
    logger.info("Using chunked computation...")

    min_pred_to_true = _chunked_min_distances(pred_boundary, true_boundary, chunk_size)
    min_true_to_pred = _chunked_min_distances(true_boundary, pred_boundary, chunk_size)

    return min_pred_to_true.reshape(-1, 1), min_true_to_pred.reshape(-1, 1)

def get_boundary_points(mask):
    """Coordinates of the mask's boundary elements.

    A boundary element is a True element removed by one ``binary_erosion`` step. The
    default structuring element spans *all* axes of ``mask``, so for a 3-D array the
    last axis is eroded as if it were spatial. There is no thinning option.
    NOTE(review): (H, W, C) one-hot masks therefore erode across channels — confirm intended.

    Returns:
        Integer array of shape (K, mask.ndim) from ``np.argwhere``; an all-False mask
        yields an empty float array of shape (0, mask.ndim).
    """
    has_foreground = np.any(mask)
    if has_foreground:
        edge = mask & ~binary_erosion(mask)
        return np.argwhere(edge)
    return np.array([]).reshape(0, mask.ndim)

def hausdorff_distance_optimized(dist_pred_to_true, dist_true_to_pred):
    """Symmetric Hausdorff distance from two directed distance arrays.

    Each argument is either a full distance matrix (rows = source points) or a (N, 1)
    array of per-point minima; the row-wise minimum handles both. Returns ``np.inf``
    if either array is empty.
    """
    if 0 in (dist_pred_to_true.size, dist_true_to_pred.size):
        return np.inf

    # Directed HD = worst-case nearest-neighbour distance in each direction
    directed = [np.min(d, axis=1).max() for d in (dist_pred_to_true, dist_true_to_pred)]
    return max(directed)

def average_symmetric_surface_distance_optimized(dist_pred_to_true, dist_true_to_pred):
    """Average surface distance as the mean of the two directed mean distances.

    Each argument is a full distance matrix or a (N, 1) array of per-point minima.
    Note this averages the two directions equally, which differs from the
    point-count-weighted ASSD when the boundaries have different sizes.
    Returns ``np.inf`` if either array is empty.
    """
    if 0 in (dist_pred_to_true.size, dist_true_to_pred.size):
        return np.inf

    forward = np.min(dist_pred_to_true, axis=1).mean()
    backward = np.min(dist_true_to_pred, axis=1).mean()
    return (forward + backward) / 2

def calculate_mean_hd_and_asd_optimized(model, x_test, y_test, max_boundary_points=1000):
    """
    Memory-optimized calculation of the mean Hausdorff distance (HD) and average
    symmetric surface distance (ASD) over a test set.

    Samples are predicted one at a time, binarized and compared with their ground
    truth through ``compute_surface_distances_optimized``. A sample is skipped
    (and logged) when its HD or ASD is infinite (presumably an empty predicted or
    ground-truth boundary) or when processing it raises.

    Args:
        model: Model exposing ``predict`` (called with ``verbose=0``).
        x_test: Indexable inputs; each item gets a leading batch axis added.
        y_test: Ground-truth masks aligned index-by-index with ``x_test``.
        max_boundary_points: Forwarded to ``compute_surface_distances_optimized``.

    Returns:
        ``(mean_hd, mean_asd)``, or ``(nan, nan)`` when no sample was valid.

    NOTE(review): prediction and target keep all of their axes (presumably
    including the one-hot class axis), so boundaries and distances are computed
    on that whole array rather than per class. Confirm this is the intended
    definition before comparing against per-structure values.
    """
    logger.info("Checking system memory:")
    check_memory()

    all_hd = []
    all_asd = []
    n_samples = len(x_test)

    for i in range(n_samples):
        sample_no = i + 1  # 1-based number used in every log message
        logger.info(f"Processing sample {sample_no}/{n_samples}")
        try:
            y_pred = model.predict(np.expand_dims(x_test[i], axis=0), verbose=0)
            y_true = y_test[i]

            # Values in [0, 1] are thresholded at 0.5; anything else is treated
            # as already binary (non-zero -> True).
            if y_pred.max() <= 1.0 and y_pred.min() >= 0.0:
                y_pred_binary = (y_pred[0] > 0.5).astype(bool)
            else:
                y_pred_binary = y_pred[0].astype(bool)

            y_true_binary = y_true.astype(bool)

            logger.info(f"Mask shapes - Pred: {y_pred_binary.shape}, True: {y_true_binary.shape}")

            # Compute surface distances
            dist_pred_to_true, dist_true_to_pred = compute_surface_distances_optimized(
                y_pred_binary, y_true_binary, max_boundary_points
            )

            # Compute metrics
            hd = hausdorff_distance_optimized(dist_pred_to_true, dist_true_to_pred)
            asd = average_symmetric_surface_distance_optimized(dist_pred_to_true, dist_true_to_pred)

            if not np.isinf(hd) and not np.isinf(asd):
                all_hd.append(hd)
                all_asd.append(asd)
                logger.info(f"HD: {hd:.4f}, ASD: {asd:.4f}")
            else:
                logger.info(f"Skipping sample {sample_no} due to empty mask(s)")

            # Release the per-sample distance arrays before the next prediction.
            del dist_pred_to_true, dist_true_to_pred
            gc.collect()

        except Exception as e:
            # Best effort: a single failing sample must not abort the evaluation.
            logger.error(f"Error processing sample {sample_no}: {e}")
            continue

        # Report memory after every processed sample (the header alone said
        # nothing about memory).
        logger.info("Checking system memory:")
        check_memory()

    if len(all_hd) == 0:
        logger.warning("No valid samples to compute metrics")
        return np.nan, np.nan

    mean_hd = np.mean(all_hd)
    mean_asd = np.mean(all_asd)

    return mean_hd, mean_asd

# Log current memory usage before the memory-heavy surface-distance evaluation.
logger.info("=== SYSTEM MEMORY CHECK ===")
memory_info = check_memory()
logger.info("End of system memory check.")

# Surface-distance evaluation (HD / ASD) for model_segnet.
logger.info("=== RUNNING OPTIMIZED SURFACE DISTANCE CALCULATION ===")
try:
    logger.info(f"Model: {model_segnet.name}")
    # Use fewer boundary points to reduce memory usage
    mean_hd, mean_asd = calculate_mean_hd_and_asd_optimized(
        model_segnet, X_test, y_test, max_boundary_points=500
    )
    logger.info("\nFinal Results:")
    logger.info(f"Mean Hausdorff Distance: {mean_hd:.4f}")
    logger.info(f"Mean Average Symmetric Surface Distance: {mean_asd:.4f}")
except Exception as e:
    # logger.exception logs the same "Error: ..." line and also the traceback.
    logger.exception(f"Error: {e}")

In [None]:
# Surface-distance evaluation (HD / ASD) for model_xception.
logger.info("=== RUNNING OPTIMIZED SURFACE DISTANCE CALCULATION ===")
try:
    logger.info(f"Model: {model_xception.name}")
    # Use fewer boundary points to reduce memory usage
    mean_hd, mean_asd = calculate_mean_hd_and_asd_optimized(
        model_xception, X_test, y_test, max_boundary_points=500
    )
    logger.info("\nFinal Results:")
    logger.info(f"Mean Hausdorff Distance: {mean_hd:.4f}")
    logger.info(f"Mean Average Symmetric Surface Distance: {mean_asd:.4f}")
except Exception as e:
    # logger.exception logs the same "Error: ..." line and also the traceback.
    logger.exception(f"Error: {e}")

In [None]:
# Surface-distance evaluation (HD / ASD) for model_efficientnetb4.
logger.info("=== RUNNING OPTIMIZED SURFACE DISTANCE CALCULATION ===")
try:
    logger.info(f"Model: {model_efficientnetb4.name}")
    # Use fewer boundary points to reduce memory usage
    mean_hd, mean_asd = calculate_mean_hd_and_asd_optimized(
        model_efficientnetb4, X_test, y_test, max_boundary_points=500
    )
    logger.info("\nFinal Results:")
    logger.info(f"Mean Hausdorff Distance: {mean_hd:.4f}")
    logger.info(f"Mean Average Symmetric Surface Distance: {mean_asd:.4f}")
except Exception as e:
    # logger.exception logs the same "Error: ..." line and also the traceback.
    logger.exception(f"Error: {e}")

In [None]:
# Surface-distance evaluation (HD / ASD) for model_inceptionresnetv2.
logger.info("=== RUNNING OPTIMIZED SURFACE DISTANCE CALCULATION ===")
try:
    logger.info(f"Model: {model_inceptionresnetv2.name}")
    # Use fewer boundary points to reduce memory usage
    mean_hd, mean_asd = calculate_mean_hd_and_asd_optimized(
        model_inceptionresnetv2, X_test, y_test, max_boundary_points=500
    )
    logger.info("\nFinal Results:")
    logger.info(f"Mean Hausdorff Distance: {mean_hd:.4f}")
    logger.info(f"Mean Average Symmetric Surface Distance: {mean_asd:.4f}")
except Exception as e:
    # logger.exception logs the same "Error: ..." line and also the traceback.
    logger.exception(f"Error: {e}")

In [None]:
# Surface-distance evaluation (HD / ASD) for the distilled student.
# NOTE: `student_model` is created in the UNetPlusPlus cell further down the
# notebook (and its trained weights are loaded later still); on a fresh kernel
# run this cell only after those cells.
logger.info("=== RUNNING OPTIMIZED SURFACE DISTANCE CALCULATION ===")
try:
    logger.info(f"Model: {student_model.name}")
    # Use fewer boundary points to reduce memory usage
    mean_hd, mean_asd = calculate_mean_hd_and_asd_optimized(
        student_model, X_test, y_test, max_boundary_points=500
    )
    logger.info("\nFinal Results:")
    logger.info(f"Mean Hausdorff Distance: {mean_hd:.4f}")
    logger.info(f"Mean Average Symmetric Surface Distance: {mean_asd:.4f}")
except Exception as e:
    # logger.exception logs the same "Error: ..." line and also the traceback.
    logger.exception(f"Error: {e}")

<h1>Knowledge Distillation</h1>

In [None]:
from tensorflow.keras.preprocessing.image import ImageDataGenerator

# Data Augmentation configuration for the training set
train_datagen = ImageDataGenerator(
    rotation_range=30,
    width_shift_range=0.2,
    height_shift_range=0.2,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
)
# No `train_datagen.fit(X_train)`: fit() only computes the statistics needed by
# featurewise_center / featurewise_std_normalization / zca_whitening. None of
# these is enabled here, so the call was a full pass over X_train with no effect.

In [None]:
import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, BatchNormalization, Activation, UpSampling2D, Concatenate, Add

def res_conv_block(inputs, filters, kernel_size=(3, 3), padding="same", use_batch_norm=True):
    """
    Residual block: Conv -> [BN] -> ReLU -> Conv -> [BN], added to a shortcut, then ReLU.

    The shortcut is ``inputs`` itself when its channel count already equals
    ``filters``; otherwise it is projected with a 1x1 convolution (followed by BN
    when ``use_batch_norm``) so the two branches can be summed.

    Args:
        inputs: Input tensor.
        filters: Number of output channels.
        kernel_size: Kernel size of the two main-branch convolutions.
        padding: Padding mode for every convolution in the block.
        use_batch_norm: Whether to insert BatchNormalization after each convolution.

    Returns:
        Output tensor with ``filters`` channels.
    """
    def conv_bn(tensor, n_filters, ksize):
        # Convolution, optionally followed by batch normalization.
        out = Conv2D(n_filters, ksize, padding=padding)(tensor)
        if use_batch_norm:
            out = BatchNormalization()(out)
        return out

    # Main branch (layers are created in the same order as before so auto-generated
    # layer names / weight order stay stable).
    main = conv_bn(inputs, filters, kernel_size)
    main = Activation("relu")(main)
    main = conv_bn(main, filters, kernel_size)

    needs_projection = inputs.shape[-1] != filters
    residual = conv_bn(inputs, filters, (1, 1)) if needs_projection else inputs

    merged = Add()([main, residual])
    return Activation("relu")(merged)


def UNetPlusPlus(input_shape=(224, 224, 3), num_classes=4, filters=(24, 48, 96, 192), use_batch_norm=True):
    """
    Enhanced UNet++ (three pooling stages) with residual convolution blocks.

    Node names follow the UNet++ convention ``convI_J``: encoder depth ``I`` after
    ``J`` nested-skip stages.

    Args:
        input_shape: Input image dimensions (height, width, channels). Known
            height/width must be divisible by 8 (three 2x poolings) so upsampled
            maps line up with their skip connections.
        num_classes: Number of output classes for segmentation.
        filters: Sequence of at least four filter counts, one per depth level
            (only the first four are used).
        use_batch_norm: Whether to use batch normalization.

    Returns:
        A ``tf.keras.Model`` producing per-pixel softmax class probabilities with
        ``num_classes`` channels.

    Raises:
        ValueError: if fewer than four filter counts are given, or a known
            spatial dimension is not divisible by 8.
    """
    if len(filters) < 4:
        raise ValueError(f"UNetPlusPlus expects at least 4 filter sizes, got {len(filters)}")
    for dim in input_shape[:2]:
        if dim is not None and dim % 8 != 0:
            raise ValueError(
                f"Input height/width must be divisible by 8 (three 2x poolings), got {tuple(input_shape[:2])}"
            )

    # NOTE: layer creation order below is unchanged so that weights saved from
    # earlier versions of this model still load by layer order.
    inputs = Input(input_shape)

    # Encoder (Downsampling path)
    conv0_0 = res_conv_block(inputs, filters[0], use_batch_norm=use_batch_norm)
    pool0 = MaxPooling2D(pool_size=(2, 2))(conv0_0)

    conv1_0 = res_conv_block(pool0, filters[1], use_batch_norm=use_batch_norm)
    pool1 = MaxPooling2D(pool_size=(2, 2))(conv1_0)

    conv2_0 = res_conv_block(pool1, filters[2], use_batch_norm=use_batch_norm)
    pool2 = MaxPooling2D(pool_size=(2, 2))(conv2_0)

    conv3_0 = res_conv_block(pool2, filters[3], use_batch_norm=use_batch_norm)

    # Decoder (Upsampling path with nested dense skip connections)
    # First nested stage: X^{2,1}, X^{1,1}, X^{0,1}
    up1_0 = UpSampling2D(size=(2, 2), interpolation='bilinear')(conv3_0)
    concat2_1 = Concatenate()([up1_0, conv2_0])
    conv2_1 = res_conv_block(concat2_1, filters[2], use_batch_norm=use_batch_norm)

    up0_1 = UpSampling2D(size=(2, 2), interpolation='bilinear')(conv2_0)
    concat1_1 = Concatenate()([up0_1, conv1_0])
    conv1_1 = res_conv_block(concat1_1, filters[1], use_batch_norm=use_batch_norm)

    up0_2 = UpSampling2D(size=(2, 2), interpolation='bilinear')(conv1_0)
    concat0_1 = Concatenate()([up0_2, conv0_0])
    conv0_1 = res_conv_block(concat0_1, filters[0], use_batch_norm=use_batch_norm)

    # Second nested stage: X^{1,2}, X^{0,2}
    up1_1 = UpSampling2D(size=(2, 2), interpolation='bilinear')(conv2_1)
    concat1_2 = Concatenate()([up1_1, conv1_0, conv1_1])
    conv1_2 = res_conv_block(concat1_2, filters[1], use_batch_norm=use_batch_norm)

    up0_3 = UpSampling2D(size=(2, 2), interpolation='bilinear')(conv1_1)
    concat0_2 = Concatenate()([up0_3, conv0_0, conv0_1])
    conv0_2 = res_conv_block(concat0_2, filters[0], use_batch_norm=use_batch_norm)

    # Third nested stage: X^{0,3}
    up0_4 = UpSampling2D(size=(2, 2), interpolation='bilinear')(conv1_2)
    concat0_3 = Concatenate()([up0_4, conv0_0, conv0_1, conv0_2])
    conv0_3 = res_conv_block(concat0_3, filters[0], use_batch_norm=use_batch_norm)

    # Output segmentation map (single output, softmax over classes)
    outputs = Conv2D(num_classes, (1, 1), activation='softmax')(conv0_3)

    model = Model(inputs=[inputs], outputs=[outputs])

    return model

# Build the UNet++ student (224x224 RGB input, 4 classes) and print its layer summary.
student_model = UNetPlusPlus(num_classes=4, input_shape=(224, 224, 3))
student_model.summary()

In [None]:
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau, ModelCheckpoint

callbacks = [
    # Stop after 10 epochs without val_loss improvement and roll back to the best weights.
    EarlyStopping(monitor='val_loss', patience=10, restore_best_weights=True),
    # Halve the learning rate after 3 stagnant epochs, never going below 1e-6.
    ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=3, min_lr=1e-6),
    # Keep only the best (lowest val_loss) model on disk.
    ModelCheckpoint(filepath='best_student_unetplusplus.keras', monitor='val_loss', save_best_only=True),
]

def create_train_generator(X, y, batch_size=16):
    """Yield endless augmented (image, mask) batches with matching random transforms.

    Args:
        X: Training images (rank-4 array, as required by ImageDataGenerator.flow).
        y: Masks aligned with ``X`` (presumably one-hot encoded; confirm).
        batch_size: Samples per yielded batch.

    Yields:
        ``(X_batch, y_batch)`` tuples.
    """
    data_gen_args = dict(
        rotation_range=20,
        width_shift_range=0.1,
        height_shift_range=0.1,
        zoom_range=0.2,
        horizontal_flip=True,
        fill_mode='nearest'
    )

    image_datagen = ImageDataGenerator(**data_gen_args)
    # Masks are resampled with nearest-neighbour interpolation (order 0). The
    # default bilinear order=1 blends neighbouring one-hot vectors into
    # fractional labels along every rotated/shifted/zoomed edge.
    mask_datagen = ImageDataGenerator(**data_gen_args, interpolation_order=0)

    # A shared seed keeps shuffling and random transform parameters identical
    # for the image and mask streams, as long as they are advanced in lockstep.
    seed = 42
    image_generator = image_datagen.flow(X, batch_size=batch_size, seed=seed)
    mask_generator = mask_datagen.flow(y, batch_size=batch_size, seed=seed)

    while True:
        X_batch = next(image_generator)
        y_batch = next(mask_generator)
        yield X_batch, y_batch

train_generator = create_train_generator(X_train, y_train, batch_size=16)

<h1>Hyperparameter Analysis</h1>

In [None]:
import tensorflow as tf
import numpy as np
from itertools import product
from tensorflow.keras.callbacks import Callback

# === Hyperparameter Space ===
batch_sizes = [8, 16]
temperatures = [1, 3, 9]
alphas = [0.5]
learning_rates = [1e-4, 1e-3]
# Optimizer classes keyed by the label reported in the results table;
# insertion order fixes the order of the trials.
optimizers_dict = dict(
    adam=tf.keras.optimizers.Adam,
    sgd=tf.keras.optimizers.SGD,
    rmsprop=tf.keras.optimizers.RMSprop,
)

# === Result Logging ===
# One dict per trial; turned into a DataFrame after the sweep.
results = list()

# === Callback that prints the epoch's metric dict ===
class LoggingCallback(Callback):
    """Print the logs dictionary Keras passes at the end of each epoch."""

    def on_epoch_end(self, epoch, logs=None):
        # Keras epochs are 0-based; show them 1-based.
        human_epoch = epoch + 1
        print(f"Epoch {human_epoch}: {logs}")

# === Training Loop ===
def _mean_dice_from_logs(logs):
    """Mean of the per-class ``DiceClass<i>`` entries of an evaluate() result dict (NaN if none)."""
    dice_values = [value for key, value in logs.items() if key.startswith("DiceClass")]
    return float(np.mean(dice_values)) if dice_values else np.nan


trial_num = 1
for batch_size, temp, alpha, lr, (opt_name, opt_class) in product(
    batch_sizes, temperatures, alphas, learning_rates, optimizers_dict.items()
):
    print(f"\n=== Trial {trial_num} ===")
    print(f"Batch Size: {batch_size}, Temp: {temp}, Alpha: {alpha}, LR: {lr}, Optimizer: {opt_name}")

    # === Prepare Optimizer ===
    optimizer = opt_class(learning_rate=lr)

    # === Instantiate Student Model from Scratch ===
    student = UNetPlusPlus(input_shape=(224, 224, 3), num_classes=4)

    # === KD Trainer Setup ===
    # NOTE: KDTrainer and teacher_model are defined in the next cell; run that
    # cell first on a fresh kernel.
    kd_model = KDTrainer(
        student=student,
        teacher=teacher_model,
        alpha=alpha,
        temperature=temp
    )

    kd_model.compile(
        optimizer=optimizer,
        metrics=class_wise_metrics(num_classes=4)
    )

    # === Generator and Training ===
    train_generator = create_train_generator(X_train, y_train, batch_size=batch_size)
    steps_per_epoch = len(X_train) // batch_size

    history = kd_model.fit(
        train_generator,
        steps_per_epoch=steps_per_epoch,
        validation_data=(X_val, y_val),
        epochs=5,
        verbose=0,
        callbacks=[LoggingCallback()]
    )

    # === Evaluate & Log ===
    # return_dict=True yields {metric_name: value}. The Dice metrics are named
    # DiceClass0..DiceClass{n-1}; there is no "dice_coef" entry, so the old
    # lookup always produced NaN.
    metric_dict = kd_model.evaluate(X_val, y_val, verbose=0, return_dict=True)

    results.append({
        "Trial": trial_num,
        "Dice Coefficient": _mean_dice_from_logs(metric_dict),
        "Batch Size": batch_size,
        "Temperature": temp,
        "Alpha": alpha,
        "Learning Rate": lr,
        "Optimizer": opt_name
    })

    trial_num += 1

# === Print Summary Table ===
import pandas as pd
results_df = pd.DataFrame(results)
print("\n=== Trial Summary ===")
print(results_df[["Trial", "Dice Coefficient", "Batch Size", "Temperature", "Optimizer", "Learning Rate"]])

In [None]:
import tensorflow as tf
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau, ModelCheckpoint

# Teacher-ensemble members; the order must line up with the voting-weight list.
models = [model_xception, model_segnet, model_inceptionresnetv2, model_efficientnetb4]

class WeightedSoftVotingEnsemble(tf.keras.Model):
    """Weighted soft-voting ensemble of segmentation models.

    Each member's output is (optionally) passed through a softmax over the last
    axis, scaled by its normalised weight and summed. By default ``call`` then
    converts the summed map to a one-hot argmax map (hard labels).

    Args:
        models: Non-empty list of models whose outputs can be summed (presumably
            ``(batch, H, W, num_classes)``; confirm when adding members).
        weights: Optional weights, one per model; normalised to sum to 1.
            Defaults to uniform weights.
        apply_softmax: Apply a softmax to every member whose name does not
            contain "efficientnet" (such members are treated as already
            producing probabilities).
        return_one_hot: If True (default), ``call`` returns one-hot argmax
            predictions. If False, it returns the weighted average probabilities,
            which is what soft-target knowledge distillation needs.

    Raises:
        ValueError: if ``models`` is empty, ``weights`` does not have one entry
            per model, or the weights sum to zero.
    """

    def __init__(self, models, weights=None, apply_softmax=True, return_one_hot=True):
        super(WeightedSoftVotingEnsemble, self).__init__()
        if len(models) == 0:
            raise ValueError("WeightedSoftVotingEnsemble needs at least one model")
        self.models = models
        self.apply_softmax = apply_softmax
        self.return_one_hot = return_one_hot

        if weights is None:
            weights = [1.0 / len(models)] * len(models)
        else:
            if len(weights) != len(models):
                raise ValueError(f"Got {len(weights)} weights for {len(models)} models")
            total = sum(weights)
            if total == 0:
                raise ValueError("Ensemble weights must not sum to zero")
            weights = [w / total for w in weights]

        self.model_weights = tf.constant(weights, dtype=tf.float32)

    def _weighted_probs(self, x, training=False):
        """Weighted sum of the members' class probabilities, shape [B, H, W, C]."""
        weighted_sum = 0
        for i, model in enumerate(self.models):
            output = model(x, training=training)

            # Heuristic: members named "*efficientnet*" already end in a softmax.
            is_softmaxed = (
                hasattr(model, "name") and "efficientnet" in model.name.lower()
            )

            if self.apply_softmax and not is_softmaxed:
                probs = tf.nn.softmax(output, axis=-1)
            else:
                probs = output

            weighted_sum += self.model_weights[i] * probs
        return weighted_sum

    def call(self, x, training=False):
        avg_prob = self._weighted_probs(x, training=training)  # shape: [B, H, W, C]
        if not self.return_one_hot:
            return avg_prob

        # Convert to one-hot for metric compatibility; fall back to the dynamic
        # depth when the channel dimension is not statically known.
        depth = avg_prob.shape[-1]
        if depth is None:
            depth = tf.shape(avg_prob)[-1]
        return tf.one_hot(tf.argmax(avg_prob, axis=-1), depth=depth)  # [B, H, W, C]

# Per-model voting weights, in the same order as `models`.
final_weights = [0.255, 0.2427, 0.2515, 0.2508]
ensemble_model = WeightedSoftVotingEnsemble(models=models, weights=final_weights, apply_softmax=True)

ensemble_model.compile(
    optimizer=tf.keras.optimizers.RMSprop(learning_rate=0.0001),
    loss=combined_loss,
    metrics=class_wise_metrics(num_classes=4),
)

# The compiled ensemble serves as the (frozen) teacher for distillation.
teacher_model = ensemble_model

def distillation_loss(y_true, y_student_logits, y_teacher_probs, alpha=0.5, temperature=3.0, from_logits=False):
    """Knowledge-distillation loss: hard-label loss plus a temperature-scaled KL term.

    Args:
        y_true: One-hot ground truth.
        y_student_logits: Student output. Despite the name, this is treated as
            class probabilities unless ``from_logits=True``. The UNet++ student
            ends in a softmax, and the CategoricalCrossentropy term below has
            always treated this tensor as probabilities.
        y_teacher_probs: Teacher class probabilities (a one-hot map also works).
        alpha: Weight of the hard-label loss; ``1 - alpha`` weights the KL term.
        temperature: Softening temperature T; the KL term is scaled by T**2.
        from_logits: Set True when ``y_student_logits`` are real pre-softmax logits.

    Returns:
        Scalar loss tensor.
    """
    eps = tf.keras.backend.epsilon()

    if from_logits:
        student_logits = y_student_logits
        student_probs = tf.nn.softmax(y_student_logits)
    else:
        student_probs = y_student_logits
        # log(p) are valid logits for p (softmax(log p) == p), so dividing them by
        # T gives the tempered distribution p**(1/T) / Z. Softmaxing the
        # probabilities themselves squashes every distribution towards uniform.
        student_logits = tf.math.log(tf.clip_by_value(y_student_logits, eps, 1.0))
    teacher_logits = tf.math.log(tf.clip_by_value(y_teacher_probs, eps, 1.0))

    # Softened predictions for KL
    student_soft = tf.nn.softmax(student_logits / temperature)
    teacher_soft = tf.nn.softmax(teacher_logits / temperature)

    # Soft loss: KL divergence
    kl_loss = tf.keras.losses.KLDivergence()(teacher_soft, student_soft)

    # Hard loss: custom combined loss (Dice + Lovasz) plus cross-entropy, on probabilities.
    ce_loss = combined_loss(y_true, student_probs) + tf.keras.losses.CategoricalCrossentropy()(y_true, student_probs)

    # Combine them
    return alpha * ce_loss + (1 - alpha) * (temperature ** 2) * kl_loss

# === KD Wrapper Model ===
class KDTrainer(tf.keras.Model):
    """Train ``student`` by knowledge distillation from a frozen ``teacher``.

    ``fit`` minimises ``distillation_loss`` with respect to the student's
    variables only. ``evaluate`` reports ``combined_loss`` of the student against
    the labels. Both report the user-supplied metrics plus a running-mean ``loss``.

    Args:
        student: Model being trained.
        teacher: Model providing targets; called with ``training=False`` and
            never updated.
        alpha: Weight of the hard-label part of the distillation loss.
        temperature: Distillation temperature.
    """

    def __init__(self, student, teacher, alpha=0.5, temperature=3.0):
        super(KDTrainer, self).__init__()
        self.student = student
        self.teacher = teacher
        self.alpha = alpha
        self.temperature = temperature
        # Averages the per-batch loss over an epoch / evaluation run. Returning
        # the raw batch loss made the logged `loss` and `val_loss` (which the
        # callbacks monitor) reflect only the last batch.
        self.loss_tracker = tf.keras.metrics.Mean(name="loss")

    def compile(self, optimizer, metrics):
        super().compile()
        self.optimizer = optimizer
        self.metrics_list = metrics

    @property
    def metrics(self):
        # Listing the stateful metrics lets Keras reset them at the start of
        # every epoch and every evaluate() call.
        tracked = []
        if hasattr(self, "loss_tracker"):
            tracked.append(self.loss_tracker)
        tracked.extend(getattr(self, "metrics_list", []))
        return tracked

    def call(self, x, training=False):
        # Inference goes through the student only.
        return self.student(x, training=training)

    def train_step(self, data):
        x, y_true = data
        y_true = tf.cast(y_true, tf.float32)

        # The teacher is frozen: run it outside the tape so its activations are
        # not recorded for backpropagation.
        teacher_probs = tf.stop_gradient(self.teacher(x, training=False))

        with tf.GradientTape() as tape:
            student_logits = self.student(x, training=True)               # [B, H, W, C]
            loss = distillation_loss(
                y_true, student_logits, teacher_probs,
                alpha=self.alpha, temperature=self.temperature
            )

        grads = tape.gradient(loss, self.student.trainable_variables)
        self.optimizer.apply_gradients(zip(grads, self.student.trainable_variables))

        self.loss_tracker.update_state(loss)
        for metric in self.metrics_list:
            metric.update_state(y_true, student_logits)

        return {m.name: m.result() for m in self.metrics_list} | {"loss": self.loss_tracker.result()}

    def test_step(self, data):
        x, y_true = data
        y_true = tf.cast(y_true, tf.float32)
        y_pred = self.student(x, training=False)
        loss = combined_loss(y_true, y_pred)

        self.loss_tracker.update_state(loss)
        for metric in self.metrics_list:
            metric.update_state(y_true, y_pred)

        return {m.name: m.result() for m in self.metrics_list} | {"loss": self.loss_tracker.result()}

# === Distill student_model from the ensemble teacher (alpha=0.5, T=1) ===
kd_model = KDTrainer(student_model, teacher_model, alpha=0.5, temperature=1.0)

# === Compile with Adam (lr=1e-4) and the per-class metrics ===
kd_model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4), metrics=class_wise_metrics(num_classes=4))

from datetime import datetime

# Timestamped checkpoint path so repeated runs do not overwrite each other.
timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
checkpoint_path = f"best_student_unetplusplus_{timestamp}"

callbacks = [
    EarlyStopping(
        monitor='val_loss',
        patience=15,
        restore_best_weights=True
    ),
    ReduceLROnPlateau(
        monitor='val_loss',
        factor=0.5,
        patience=6,
        min_lr=1e-7
    ),
    # Weights-only checkpoint. With save_weights_only=True and a path without an
    # ".h5" suffix, TF2 Keras writes TensorFlow checkpoint files (not a
    # SavedModel). `save_format` is not a ModelCheckpoint argument (TF2 swallowed
    # it via **kwargs; Keras 3 rejects it), so it is not passed.
    ModelCheckpoint(
        filepath=checkpoint_path,
        monitor='val_loss',
        save_best_only=True,
        save_weights_only=True,
    )
]


from tensorflow.keras.preprocessing.image import ImageDataGenerator

def create_train_generator(X, y, batch_size=16):
    """Yield endless augmented float32 (image, mask) batches with matching random transforms.

    Args:
        X: Training images (rank-4 array, as required by ImageDataGenerator.flow).
        y: Masks aligned with ``X`` (presumably one-hot encoded; confirm).
        batch_size: Samples per yielded batch.

    Yields:
        ``(X_batch, y_batch)`` tuples cast to float32.
    """
    data_gen_args = dict(
        rotation_range=20,
        width_shift_range=0.1,
        height_shift_range=0.1,
        zoom_range=0.2,
        horizontal_flip=True,
        fill_mode='nearest'
    )

    image_datagen = ImageDataGenerator(**data_gen_args)
    # Masks are resampled with nearest-neighbour interpolation (order 0). The
    # default bilinear order=1 blends neighbouring one-hot vectors into
    # fractional labels along every rotated/shifted/zoomed edge.
    mask_datagen = ImageDataGenerator(**data_gen_args, interpolation_order=0)

    # A shared seed keeps shuffling and random transform parameters identical
    # for the image and mask streams, as long as they are advanced in lockstep.
    seed = 42
    image_generator = image_datagen.flow(X, batch_size=batch_size, seed=seed)
    mask_generator = mask_datagen.flow(y, batch_size=batch_size, seed=seed)

    while True:
        X_batch = next(image_generator)
        y_batch = next(mask_generator)
        yield X_batch.astype('float32'), y_batch.astype('float32')

# Distillation run: batches of 8, at most 100 epochs, stopped early by the callbacks.
batch_size = 8
steps_per_epoch = len(X_train) // batch_size
train_generator = create_train_generator(X_train, y_train, batch_size=batch_size)

history = kd_model.fit(
    train_generator,
    epochs=100,
    steps_per_epoch=steps_per_epoch,
    validation_data=(X_val, y_val),
    callbacks=callbacks,
)

In [None]:
# Persist the distilled student's weights in HDF5 format.
student_model.save_weights(filepath="student_model_weights_final.h5")

In [None]:
# Restore the distilled student's weights saved above in HDF5 format.
student_model.load_weights(filepath="student_model_weights_final.h5")

In [None]:
import tensorflow as tf
from tensorflow.keras import backend as K
from tensorflow.keras.models import load_model

# Number of segmentation classes (adjust if needed).
NUM_CLASSES = 4

# Pooled Dice coefficient
def dice_coefficient(y_true, y_pred):
    """Pooled Dice score per sample, averaged over the batch.

    Sums run over axes 1, 2 and 3 together. For (batch, H, W, classes) inputs
    this pools every class channel (background included) into one score per
    sample, rather than averaging per-class Dice scores.
    """
    smooth = 1e-15
    truth = tf.cast(y_true, tf.float32)
    pred = tf.cast(y_pred, tf.float32)
    pooled_axes = [1, 2, 3]
    overlap = tf.reduce_sum(truth * pred, axis=pooled_axes)
    total = tf.reduce_sum(truth, axis=pooled_axes) + tf.reduce_sum(pred, axis=pooled_axes)
    per_sample = (2. * overlap + smooth) / (total + smooth)
    return tf.reduce_mean(per_sample)

# ‚úÖ Weighted Categorical Crossentropy
def weighted_categorical_crossentropy(y_true, y_pred, class_weights=None):
    """Pixel-wise categorical cross-entropy weighted by each pixel's true class.

    Args:
        y_true: One-hot ground truth; the last axis indexes classes.
        y_pred: Predicted class probabilities with the same shape (clipped to
            [epsilon, 1] before the log).
        class_weights: Optional sequence with one weight per class. Defaults to
            the hard-coded 4-class weights [0.3776, 0.7605, 65.8554, 46.2381].

    Returns:
        Scalar mean of the weighted per-pixel cross-entropy.
    """
    if class_weights is None:
        class_weights = [0.3776, 0.7605, 65.8554, 46.2381]
    class_weights = tf.cast(tf.convert_to_tensor(class_weights), tf.float32)

    y_true = tf.cast(y_true, tf.float32)
    y_pred = tf.clip_by_value(tf.cast(y_pred, tf.float32), K.epsilon(), 1.0)

    # Unweighted cross-entropy per pixel: -sum_c y_true_c * log(y_pred_c).
    pixel_ce = -tf.reduce_sum(y_true * tf.math.log(y_pred), axis=-1)
    # Each pixel is weighted by the weight of its true class. Summing the weight
    # vector itself only scaled every pixel by the same constant (~113.2), so no
    # class weighting was applied.
    pixel_weights = tf.reduce_sum(y_true * class_weights, axis=-1)
    return tf.reduce_mean(pixel_ce * pixel_weights)

# ‚úÖ Dice Loss
def dice_loss(y_true, y_pred):
    """Soft Dice loss: 1 minus the mean Dice over (sample, class) pairs.

    Sums run over axes 1 and 2, presumably the spatial axes of
    (batch, H, W, classes) tensors (confirm for other layouts).
    """
    eps = 1e-6
    y_true = tf.cast(y_true, y_pred.dtype)
    spatial = [1, 2]
    overlap = tf.reduce_sum(y_true * y_pred, axis=spatial)
    denom = tf.reduce_sum(y_true, axis=spatial) + tf.reduce_sum(y_pred, axis=spatial)
    per_pair_dice = (2. * overlap + eps) / (denom + eps)
    return 1 - tf.reduce_mean(per_pair_dice)

# ‚úÖ Lov√°sz-Softmax Loss
def lovasz_softmax_loss(y_true, y_pred, ignore_background=False):
    """Lovász-Softmax loss averaged over classes.

    Each class's errors are computed over the whole flattened batch, not per image.

    Args:
        y_true: One-hot ground truth; the last axis indexes classes.
        y_pred: Class probabilities with the same shape.
        ignore_background: If True, class 0 is excluded from the average.

    Returns:
        Scalar loss tensor.
    """
    y_true = tf.cast(y_true, tf.float32)
    y_pred = tf.cast(y_pred, tf.float32)

    num_classes = tf.shape(y_true)[-1]
    start_class = tf.constant(1 if ignore_background else 0)

    def compute_class_loss(c):
        y_true_class = y_true[..., c]
        y_pred_class = y_pred[..., c]

        y_true_flat = tf.reshape(y_true_class, [-1])
        y_pred_flat = tf.reshape(y_pred_class, [-1])

        # Sort errors in decreasing order and build the Lovász extension
        # gradient of the Jaccard loss.
        errors = tf.abs(y_true_flat - y_pred_flat)
        errors_sorted, perm = tf.nn.top_k(errors, k=tf.shape(errors)[0], sorted=True)
        y_true_sorted = tf.gather(y_true_flat, perm)

        gts = tf.reduce_sum(y_true_sorted)
        intersection = gts - tf.cumsum(y_true_sorted)
        union = gts + tf.cumsum(1. - y_true_sorted)
        jaccard = 1. - intersection / union
        grad = tf.concat([[jaccard[0]], jaccard[1:] - jaccard[:-1]], 0)

        return tf.tensordot(errors_sorted, grad, axes=1)

    # One slot per evaluated class, written at index c - start_class. Sizing the
    # array to num_classes left slot 0 unwritten when ignore_background=True,
    # which either fails in stack() or adds a spurious 0 to the mean.
    losses = tf.TensorArray(dtype=tf.float32, size=num_classes - start_class)

    def loop_cond(c, losses):
        return tf.less(c, num_classes)

    def loop_body(c, losses):
        loss_c = compute_class_loss(c)
        losses = losses.write(c - start_class, loss_c)
        return c + 1, losses

    _, losses = tf.while_loop(loop_cond, loop_body, [start_class, losses])
    return tf.reduce_mean(losses.stack())

# ‚úÖ Combined Loss
def combined_loss(y_true, y_pred):
    """Dice loss on ``y_pred`` plus Lovász-Softmax loss on ``softmax(y_pred)``.

    NOTE(review): the Dice term uses ``y_pred`` as-is, while the Lovász term
    applies a softmax first. For models whose last layer is already a softmax
    (e.g. UNetPlusPlus) this is a second softmax. Confirm which output
    convention this loss is meant for.
    """
    eps = 1e-6
    y_true = tf.cast(y_true, y_pred.dtype)
    spatial = [1, 2]
    overlap = tf.reduce_sum(y_true * y_pred, axis=spatial)
    denom = tf.reduce_sum(y_true, axis=spatial) + tf.reduce_sum(y_pred, axis=spatial)
    dice_term = tf.reduce_mean(1 - (2. * overlap + eps) / (denom + eps))

    lovasz_term = lovasz_softmax_loss(y_true, tf.nn.softmax(y_pred), ignore_background=False)
    return lovasz_term + dice_term

class DiceCoefficient(tf.keras.metrics.Metric):
    """Dice score of a single class, averaged over every sample seen since the last reset.

    For each sample, Dice = (2 * |T ∩ P| + 1e-6) / (|T| + |P| + 1e-6), computed
    on channel ``class_idx`` with sums over axes 1 and 2. ``sample_weight`` is
    ignored.

    Args:
        class_idx: Index of the class channel to score.
        name: Metric name; defaults to ``f"DiceClass{class_idx}"``.
    """

    def __init__(self, class_idx=0, name=None, **kwargs):  # default class_idx=0 avoids a missing-arg error on deserialisation
        if name is None:
            name = f"DiceClass{class_idx}"
        super(DiceCoefficient, self).__init__(name=name, **kwargs)
        self.class_idx = class_idx
        # Running totals. Overwriting the state with the latest batch made
        # result() report only the last batch instead of the epoch/evaluation mean.
        self.dice_sum = self.add_weight(name="dice_sum", initializer="zeros")
        self.sample_count = self.add_weight(name="sample_count", initializer="zeros")

    def update_state(self, y_true, y_pred, sample_weight=None):
        y_true_class = tf.cast(y_true[..., self.class_idx], self.dtype)
        y_pred_class = tf.cast(y_pred[..., self.class_idx], self.dtype)
        intersection = tf.reduce_sum(y_true_class * y_pred_class, axis=[1, 2])
        union = tf.reduce_sum(y_true_class, axis=[1, 2]) + tf.reduce_sum(y_pred_class, axis=[1, 2])
        dice = (2. * intersection + 1e-6) / (union + 1e-6)  # one value per sample
        self.dice_sum.assign_add(tf.reduce_sum(dice))
        self.sample_count.assign_add(tf.cast(tf.size(dice), self.dtype))

    def result(self):
        return tf.math.divide_no_nan(self.dice_sum, self.sample_count)

    def get_config(self):
        config = super().get_config()
        config.update({"class_idx": self.class_idx})
        return config

    @classmethod
    def from_config(cls, config):
        if "class_idx" not in config:
            # Try to extract class index from name like "DiceClass2"
            name = config.get("name", "DiceClass0")
            if name.startswith("DiceClass"):
                config["class_idx"] = int(name.replace("DiceClass", ""))
            else:
                config["class_idx"] = 0
        return cls(**config)

# ‚úÖ Helper to load Dice metrics by name
def dice_metric_loader(name):
    """Rebuild a ``DiceCoefficient`` metric from its serialised name.

    Args:
        name: Metric name of the form ``"DiceClass<idx>"``.

    Returns:
        ``DiceCoefficient`` configured for class ``<idx>``.

    Raises:
        ValueError: if ``name`` does not start with ``"DiceClass"``.
    """
    prefix = "DiceClass"
    if not name.startswith(prefix):
        raise ValueError(f"Unknown Dice metric name: {name}")
    return DiceCoefficient(class_idx=int(name.replace(prefix, "")))

In [None]:
def class_wise_metrics(num_classes=4):
    """Per-class Dice metrics plus a mean IoU that accepts one-hot / probability inputs.

    ``tf.keras.metrics.MeanIoU`` expects integer class labels. Fed one-hot
    targets and probability maps, it flattened the channel axis and truncated
    the probabilities to 0, so the reported IoU was meaningless. A one-hot-aware
    IoU (argmax over the last axis) is used instead, keeping the ``mean_io_u``
    name.

    Args:
        num_classes: Number of classes (one Dice metric per class).

    Returns:
        ``[DiceCoefficient(0), ..., DiceCoefficient(num_classes - 1), mean IoU]``.
    """
    dice_metrics = [DiceCoefficient(i) for i in range(num_classes)]

    one_hot_mean_iou = getattr(tf.keras.metrics, "OneHotMeanIoU", None)
    if one_hot_mean_iou is not None:
        iou = one_hot_mean_iou(num_classes=num_classes, name="mean_io_u")
    else:
        class _ArgmaxMeanIoU(tf.keras.metrics.MeanIoU):
            """MeanIoU fed with argmax class indices (fallback for older TensorFlow)."""

            def update_state(self, y_true, y_pred, sample_weight=None):
                return super().update_state(
                    tf.argmax(y_true, axis=-1), tf.argmax(y_pred, axis=-1), sample_weight
                )

        iou = _ArgmaxMeanIoU(num_classes=num_classes, name="mean_io_u")

    return dice_metrics + [iou]

In [None]:
def student_eval_loss(y_true, y_pred):
    """Student loss: ``combined_loss`` (Dice + Lovász) plus categorical cross-entropy.

    The cross-entropy treats ``y_pred`` as probabilities (Keras default
    ``from_logits=False``). Returns a scalar tensor; the previous one-element
    list evaluated to the same value only after Keras reduced it.
    """
    cce = tf.keras.losses.CategoricalCrossentropy()
    return combined_loss(y_true, y_pred) + cce(y_true, y_pred)

# Compile the student on its own (outside the KD wrapper) for direct evaluation / fine-tuning.
student_model.compile(
    optimizer=tf.keras.optimizers.RMSprop(learning_rate=0.0001),
    loss=student_eval_loss,
    metrics=class_wise_metrics(num_classes=4),
)

In [None]:
from tensorflow.keras.utils import Sequence
import cv2
import numpy as np
import os

class ImageMaskGenerator(Sequence):
    """Keras ``Sequence`` that streams (image, one-hot mask) batches from disk.

    Images are read with OpenCV, converted BGR->RGB, resized to ``img_size`` and
    scaled to [0, 1] as float32. Masks are read the same way, resized with
    nearest-neighbour interpolation, mapped from colour to class index through
    ``CLASS_MAP`` (pixels matching no colour become class 0) and one-hot encoded.

    Args:
        image_paths: Image file paths.
        mask_paths: Mask file paths, aligned index-by-index with ``image_paths``.
        batch_size: Samples per batch (the last batch may be smaller).
        num_classes: Depth of the one-hot mask encoding.
        img_size: Target size passed to ``cv2.resize``, i.e. (width, height).
        shuffle: Reshuffle the sample order at the end of every epoch.

    Raises:
        ValueError: if ``image_paths`` and ``mask_paths`` differ in length.
    """

    def __init__(self, image_paths, mask_paths, batch_size=4, num_classes=4, img_size=(224, 224), shuffle=True):
        super().__init__()
        if len(image_paths) != len(mask_paths):
            raise ValueError(
                f"Got {len(image_paths)} image paths but {len(mask_paths)} mask paths"
            )
        self.image_paths = image_paths
        self.mask_paths = mask_paths
        self.batch_size = batch_size
        self.num_classes = num_classes
        self.img_size = img_size
        self.shuffle = shuffle
        self.indices = np.arange(len(self.image_paths))
        self.CLASS_MAP = {
            (255, 0, 0): 1,
            (0, 255, 0): 2,
            (0, 0, 255): 3,
            (0, 0, 0): 0,
        }
        self.on_epoch_end()

    def __len__(self):
        return int(np.ceil(len(self.image_paths) / self.batch_size))

    def __getitem__(self, index):
        batch_indices = self.indices[index * self.batch_size:(index + 1) * self.batch_size]
        batch_images = []
        batch_masks = []

        for i in batch_indices:
            img = self._read_rgb(self.image_paths[i])
            img = cv2.resize(img, self.img_size)
            img = img.astype(np.float32) / 255.0

            mask = self._read_rgb(self.mask_paths[i])
            # Nearest-neighbour keeps mask colours exact (no blended label colours).
            mask = cv2.resize(mask, self.img_size, interpolation=cv2.INTER_NEAREST)
            mask = self.rgb_to_class(mask)
            mask = tf.keras.utils.to_categorical(mask, num_classes=self.num_classes)

            batch_images.append(img)
            batch_masks.append(mask)

        return np.array(batch_images), np.array(batch_masks)

    def on_epoch_end(self):
        if self.shuffle:
            np.random.shuffle(self.indices)

    @staticmethod
    def _read_rgb(path):
        """Read ``path`` as an RGB array, failing with a clear error if OpenCV cannot."""
        bgr = cv2.imread(path)
        if bgr is None:
            # cv2.imread returns None for missing/unreadable files, which would
            # otherwise surface as an opaque error inside cv2.cvtColor.
            raise FileNotFoundError(f"Could not read image file: {path}")
        return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)

    def rgb_to_class(self, mask_array):
        """Map an (H, W, 3) RGB mask to an (H, W) uint8 class-index mask (exact colour matches only)."""
        h, w, _ = mask_array.shape
        class_mask = np.zeros((h, w), dtype=np.uint8)
        for rgb, class_idx in self.CLASS_MAP.items():
            matches = np.all(mask_array == rgb, axis=-1)
            class_mask[matches] = class_idx
        return class_mask


import os

def load_paths(image_dir, mask_dir, extension='.png'):
    """Return sorted lists of image and mask file paths that are paired by position.

    Files are selected by extension (case-insensitive) and sorted by full path.
    Images and masks are therefore paired index-by-index, so both directories are
    expected to use matching file names.

    Args:
        image_dir: Directory containing the images.
        mask_dir: Directory containing the masks.
        extension: File extension to keep (default ``'.png'``).

    Returns:
        ``(image_paths, mask_paths)``.

    Raises:
        ValueError: if the directories contain different numbers of matching
            files. Positional pairing would then silently misalign every pair
            after the first gap.
    """
    ext = extension.lower()

    def _collect(directory):
        # Full paths of files in `directory` whose name ends with `ext`, sorted.
        return sorted(
            os.path.join(directory, f) for f in os.listdir(directory) if f.lower().endswith(ext)
        )

    images = _collect(image_dir)
    masks = _collect(mask_dir)
    if len(images) != len(masks):
        raise ValueError(
            f"Found {len(images)} images in {image_dir!r} but {len(masks)} masks in {mask_dir!r}"
        )
    return images, masks

# Aligned (image, mask) path lists for the training and validation splits.
train_imgs, train_masks = load_paths(image_dir=train_image_dir, mask_dir=train_mask_dir)
val_imgs, val_masks = load_paths(image_dir=val_image_dir, mask_dir=val_mask_dir)

In [None]:
# Training batches are reshuffled every epoch. Validation keeps a fixed order so
# validation results do not depend on a random shuffle.
train_gen = ImageMaskGenerator(train_imgs, train_masks, batch_size=8)
val_gen = ImageMaskGenerator(val_imgs, val_masks, batch_size=8, shuffle=False)

In [None]:
import tensorflow as tf
from tensorflow.keras import backend as K
import numpy as np
import subprocess
import gc
import time

# ‚úÖ Dice Coefficient Metric for Each Class
class DiceCoefficient(tf.keras.metrics.Metric):
    """Dice score of a single class, averaged over every sample seen since the last reset.

    For each sample, Dice = (2 * |T ∩ P| + 1e-6) / (|T| + |P| + 1e-6), computed
    on channel ``class_idx`` with sums over axes 1 and 2. ``sample_weight`` is
    ignored.

    Args:
        class_idx: Index of the class channel to score.
        name: Metric name; defaults to ``f"DiceClass{class_idx}"``.
    """

    def __init__(self, class_idx=0, name=None, **kwargs):
        if name is None:
            name = f"DiceClass{class_idx}"
        super().__init__(name=name, **kwargs)
        self.class_idx = class_idx
        # Running totals. Overwriting the state with the latest batch made
        # result() report only the last batch instead of the epoch/evaluation mean.
        self.dice_sum = self.add_weight(name="dice_sum", initializer="zeros")
        self.sample_count = self.add_weight(name="sample_count", initializer="zeros")

    def update_state(self, y_true, y_pred, sample_weight=None):
        y_true_class = tf.cast(y_true[..., self.class_idx], self.dtype)
        y_pred_class = tf.cast(y_pred[..., self.class_idx], self.dtype)
        intersection = tf.reduce_sum(y_true_class * y_pred_class, axis=[1, 2])
        union = tf.reduce_sum(y_true_class, axis=[1, 2]) + tf.reduce_sum(y_pred_class, axis=[1, 2])
        dice = (2. * intersection + 1e-6) / (union + 1e-6)  # one value per sample
        self.dice_sum.assign_add(tf.reduce_sum(dice))
        self.sample_count.assign_add(tf.cast(tf.size(dice), self.dtype))

    def result(self):
        return tf.math.divide_no_nan(self.dice_sum, self.sample_count)

    def get_config(self):
        # Serialise class_idx so the metric can be rebuilt with cls(**config).
        config = super().get_config()
        config.update({"class_idx": self.class_idx})
        return config

def class_wise_metrics(num_classes=4):
    """Per-class Dice metrics plus a mean IoU that accepts one-hot / probability inputs.

    ``tf.keras.metrics.MeanIoU`` expects integer class labels. Fed one-hot
    targets and probability maps, it flattened the channel axis and truncated
    the probabilities to 0. A one-hot-aware IoU (argmax over the last axis) is
    used instead, keeping the ``mean_io_u`` name.

    Args:
        num_classes: Number of classes (one Dice metric per class).

    Returns:
        ``[DiceCoefficient(0), ..., DiceCoefficient(num_classes - 1), mean IoU]``.
    """
    dice_metrics = [DiceCoefficient(i) for i in range(num_classes)]

    one_hot_mean_iou = getattr(tf.keras.metrics, "OneHotMeanIoU", None)
    if one_hot_mean_iou is not None:
        iou = one_hot_mean_iou(num_classes=num_classes, name="mean_io_u")
    else:
        class _ArgmaxMeanIoU(tf.keras.metrics.MeanIoU):
            """MeanIoU fed with argmax class indices (fallback for older TensorFlow)."""

            def update_state(self, y_true, y_pred, sample_weight=None):
                return super().update_state(
                    tf.argmax(y_true, axis=-1), tf.argmax(y_pred, axis=-1), sample_weight
                )

        iou = _ArgmaxMeanIoU(num_classes=num_classes, name="mean_io_u")

    return dice_metrics + [iou]

# ‚úÖ Your existing combined loss setup
def dice_loss(y_true, y_pred):
    """Soft Dice loss: 1 minus the mean Dice over (sample, class) pairs.

    Sums run over axes 1 and 2, presumably the spatial axes of
    (batch, H, W, classes) tensors (confirm for other layouts).
    """
    eps = 1e-6
    y_true = tf.cast(y_true, y_pred.dtype)
    spatial = [1, 2]
    overlap = tf.reduce_sum(y_true * y_pred, axis=spatial)
    denom = tf.reduce_sum(y_true, axis=spatial) + tf.reduce_sum(y_pred, axis=spatial)
    per_pair_dice = (2. * overlap + eps) / (denom + eps)
    return 1 - tf.reduce_mean(per_pair_dice)

def lovasz_softmax_loss(y_true, y_pred, ignore_background=False):
    """Lovász-Softmax loss averaged over classes.

    Each class's errors are computed over the whole flattened batch, not per image.

    Args:
        y_true: One-hot ground truth; the last axis indexes classes.
        y_pred: Class probabilities with the same shape.
        ignore_background: If True, class 0 is excluded from the average.

    Returns:
        Scalar loss tensor.
    """
    y_true = tf.cast(y_true, tf.float32)
    y_pred = tf.cast(y_pred, tf.float32)
    num_classes = tf.shape(y_true)[-1]
    start_class = 1 if ignore_background else 0

    def compute_class_loss(c):
        # Lovász extension of the Jaccard loss for class c.
        y_true_class = y_true[..., c]
        y_pred_class = y_pred[..., c]
        y_true_flat = tf.reshape(y_true_class, [-1])
        y_pred_flat = tf.reshape(y_pred_class, [-1])
        errors = tf.abs(y_true_flat - y_pred_flat)
        errors_sorted, perm = tf.nn.top_k(errors, k=tf.shape(errors)[0], sorted=True)
        y_true_sorted = tf.gather(y_true_flat, perm)
        gts = tf.reduce_sum(y_true_sorted)
        intersection = gts - tf.cumsum(y_true_sorted)
        union = gts + tf.cumsum(1. - y_true_sorted)
        jaccard = 1. - intersection / union
        grad = tf.concat([[jaccard[0]], jaccard[1:] - jaccard[:-1]], 0)
        return tf.tensordot(errors_sorted, grad, axes=1)

    # One slot per evaluated class, written at index c - start_class. Sizing the
    # array to num_classes left slot 0 unwritten when ignore_background=True,
    # which either fails in stack() or adds a spurious 0 to the mean.
    losses = tf.TensorArray(dtype=tf.float32, size=num_classes - start_class)
    def loop_cond(c, _): return c < num_classes
    def loop_body(c, losses): return c + 1, losses.write(c - start_class, compute_class_loss(c))
    _, losses = tf.while_loop(loop_cond, loop_body, [start_class, losses])
    return tf.reduce_mean(losses.stack())

def combined_loss(y_true, y_pred):
    """``dice_loss`` on ``y_pred`` plus Lovász-Softmax loss on ``softmax(y_pred)``.

    NOTE(review): the Dice term uses ``y_pred`` as-is, while the Lovász term
    applies a softmax first. For softmax-terminated models this is a second
    softmax. Confirm the intended output convention.
    """
    dice_term = dice_loss(y_true, y_pred)
    probs = tf.nn.softmax(y_pred)
    lovasz_term = lovasz_softmax_loss(y_true, probs, ignore_background=False)
    return dice_term + lovasz_term

def student_eval_loss(y_true, y_pred):
    """Student loss: ``combined_loss`` plus categorical cross-entropy on probabilities."""
    cce = tf.keras.losses.CategoricalCrossentropy()
    return combined_loss(y_true, y_pred) + cce(y_true, y_pred)

In [None]:
def _start_power_sampler(interval_ms=500):
    """Start ``nvidia-smi`` logging GPU power draw (W) to a temp file; return (proc, file) or None."""
    import tempfile

    # A temp file instead of a pipe: a pipe that is only read at the end can fill
    # up on long runs and stall nvidia-smi, silently dropping later samples.
    log_file = tempfile.TemporaryFile(mode="w+")
    try:
        proc = subprocess.Popen(
            ['nvidia-smi', '--query-gpu=power.draw', '--format=csv,noheader,nounits', '-lms', str(interval_ms)],
            stdout=log_file,
            stderr=subprocess.DEVNULL,
        )
    except OSError:
        # nvidia-smi missing / not executable: run without power sampling.
        log_file.close()
        return None
    return proc, log_file


def _stop_power_sampler(sampler):
    """Stop a sampler from ``_start_power_sampler`` and return its parsed readings in watts."""
    if sampler is None:
        return []
    proc, log_file = sampler
    try:
        proc.terminate()
        try:
            proc.wait(timeout=5)
        except subprocess.TimeoutExpired:
            proc.kill()
            proc.wait()
        log_file.seek(0)
        lines = log_file.read().splitlines()
    finally:
        log_file.close()

    values = []
    for line in lines:
        try:
            values.append(float(line))
        except ValueError:
            continue  # e.g. "[N/A]" or a line cut short by terminate()
    return values


def run_training_student_model(train_gen, val_gen, model_builder_fn, weights_path,
                               batch_size=8, epochs=3, repeats=1):
    """Benchmark wall-clock time and average GPU power while training a model.

    For each repeat, a fresh model is built with ``model_builder_fn``, compiled
    with ``student_eval_loss`` and the class-wise metrics, and initialised from
    ``weights_path``. It is then trained for ``epochs`` epochs while
    ``nvidia-smi`` samples GPU power every 500 ms.

    Args:
        train_gen: Training data passed to ``model.fit``.
        val_gen: Validation data passed to ``model.fit``.
        model_builder_fn: Zero-argument callable returning an uncompiled model.
        weights_path: Weights file loaded into every freshly built model.
        batch_size: Unused (the generators fix the batch size); kept for
            backward compatibility.
        epochs: Training epochs per repeat.
        repeats: Number of independent repeats.

    Returns:
        ``(epoch_times, power_samples)``, each with ``epochs`` entries per repeat.
        Times are the repeat's average seconds per epoch; power is the repeat's
        average draw in watts, or NaN when sampling failed.
    """
    epoch_times_all = []
    power_samples_all = []

    for r in range(repeats):
        print(f"\nüîÅ Repeat {r+1}/{repeats}")

        K.clear_session()
        gc.collect()

        model = model_builder_fn()
        model.compile(
            optimizer=tf.keras.optimizers.RMSprop(learning_rate=0.0001),
            loss=student_eval_loss,
            metrics=class_wise_metrics(num_classes=4)
        )
        model.load_weights(weights_path)

        start = time.time()
        sampler = _start_power_sampler(interval_ms=500)
        try:
            model.fit(
                train_gen,
                validation_data=val_gen,
                epochs=epochs,
                verbose=1
            )
            end = time.time()
        finally:
            # Always stop nvidia-smi, even if fit() raises, so it is not leaked.
            power_values = _stop_power_sampler(sampler)

        total_time = end - start
        avg_epoch_time = total_time / epochs
        epoch_times_all.extend([avg_epoch_time] * epochs)

        if power_values:
            avg_power = float(np.mean(power_values))
            power_samples_all.extend([avg_power] * epochs)
            print(f"‚ö° Avg Power: {avg_power:.2f} W")
        else:
            print("‚ö†Ô∏è Power log failed.")
            power_samples_all.extend([np.nan] * epochs)

        del model
        gc.collect()
        K.clear_session()

    return epoch_times_all, power_samples_all

def build_student_model():
    """Construct a new UNetPlusPlus student for 224x224x3 inputs with 4 output classes."""
    student_input_shape = (224, 224, 3)
    return UNetPlusPlus(input_shape=student_input_shape, num_classes=4)

In [None]:
import numpy as np

# Fine-tune the student once and measure wall-clock time and GPU power per epoch.
# NOTE(review): `train_gen`, `val_gen` and "student_model_weights_final.h5" must exist from
# earlier cells / disk; they are not created in the visible part of the notebook.
epoch_times, power_vals = run_training_student_model(
    train_gen=train_gen,
    val_gen=val_gen,
    model_builder_fn=build_student_model,
    weights_path="student_model_weights_final.h5",
    batch_size=8,
    epochs=3,
    repeats=1
)

# Seconds per epoch, averaged over all epochs of all repeats.
mean_time = np.mean(epoch_times)
# Watts; nanmean skips repeats whose power log failed (stored as NaN).
mean_power = np.nanmean(power_vals)
# Energy per epoch in watt-hours: seconds * watts / 3600 s per hour.
energy_wh = (mean_time * mean_power) / 3600

print("\nüìä Summary:")
print(f"‚è±Ô∏è Avg time/epoch: {mean_time:.2f} s")
print(f"‚ö° Avg GPU power: {mean_power:.2f} W")
print(f"üîã Avg energy/epoch: {energy_wh:.4f} Wh")

In [None]:
# Estimate achieved training throughput in GFLOPS (GFLOP per second, not GFLOPs per epoch),
# assuming 4 GFLOPs per sample for one forward pass.
# NOTE(review): the factor 2 presumably approximates forward + backward cost; backward is often
# estimated at ~2x forward (~3x total) — confirm the intended multiplier.
# Uses `mean_time` (seconds/epoch) from the summary cell and `X_train` from earlier cells;
# training ran on `train_gen`, so confirm len(X_train) matches the samples seen per epoch.
samples_per_epoch = len(X_train)
estimated_flops_per_sample = 4e9  # 4 GFLOPs
gflops = (2 * estimated_flops_per_sample * samples_per_epoch) / (mean_time * 1e9)
print(f"‚öôÔ∏è  Estimated GFLOPS: {gflops:.2f}")

In [None]:
def calculate_gflops(model, input_res=(224, 224, 3)):
    """Count the floating-point operations of one forward pass of `model` at batch size 1.

    The model call is traced into a ConcreteFunction for a (1, *input_res)
    input, its variables are frozen into constants, and the frozen graph is
    analysed with the TF1 profiler's `float_operation()` view.

    NOTE(review): this function is (re)defined in several cells; the most
    recently executed definition is the one in effect.

    Args:
        model: Keras model; `model.inputs[0].dtype` sets the traced input dtype.
        input_res: Per-sample input shape, without the batch dimension.

    Returns:
        Total float operations reported by the profiler, divided by 1e9.

    Raises:
        ValueError: if the profiler returns no result.
    """
    import tensorflow as tf
    from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2

    input_spec = tf.TensorSpec((1,) + tuple(input_res), model.inputs[0].dtype)

    # Trace the model into a graph and freeze its variables so the profiler sees concrete ops.
    concrete_fn = tf.function(lambda x: model(x)).get_concrete_function(input_spec)
    frozen_func = convert_variables_to_constants_v2(concrete_fn)
    graph_def = frozen_func.graph.as_graph_def()

    with tf.Graph().as_default() as graph:
        tf.import_graph_def(graph_def, name="")
        builder = tf.compat.v1.profiler.ProfileOptionBuilder
        # with_empty_output(): return the counts without printing a per-op table into the notebook.
        opts = builder(builder.float_operation()).with_empty_output().build()
        flops = tf.compat.v1.profiler.profile(
            graph=graph,
            run_meta=tf.compat.v1.RunMetadata(),
            cmd='op',
            options=opts,
        )

    if flops is None:
        raise ValueError("‚ùå Could not compute FLOPs. Check if the model contains unsupported ops.")

    return flops.total_float_ops / 1e9

# Count FLOPs of a freshly constructed UNet++ student at 224x224x3.
# NOTE(review): this rebinds the global `student_model`; later cells call
# `student_model.evaluate(...)`, which would then use this new instance instead of any
# trained/compiled student from earlier cells — consider a separate variable name.
student_model = UNetPlusPlus(input_shape=(224, 224, 3), num_classes=4)
gflops = calculate_gflops(student_model)
print(f"üìè Model GFLOPs: {gflops:.2f} GFLOPs")

In [None]:
import tensorflow as tf
from tensorflow.keras import backend as K
from tensorflow.keras.models import load_model

# Number of segmentation classes (Background, Brain, CSP, LV); must match the models' output channels.
NUM_CLASSES = 4

# Dice coefficient, pooled over classes (not a per-class mean)
def dice_coefficient(y_true, y_pred):
    """Soft Dice score with classes pooled together, averaged over the batch.

    For each sample, intersection and cardinality are summed over axes 1, 2
    and 3 together (for (batch, H, W, classes) inputs — assumed layout — that
    is the spatial and class axes), so classes are pooled rather than scored
    one by one. The per-sample scores are then averaged. A tiny smoothing term
    (1e-15) avoids 0/0 for empty masks.
    """
    eps = 1e-15
    reduce_axes = [1, 2, 3]
    truth = tf.cast(y_true, tf.float32)
    probs = tf.cast(y_pred, tf.float32)
    overlap = tf.reduce_sum(truth * probs, axis=reduce_axes)
    denom = tf.reduce_sum(truth, axis=reduce_axes) + tf.reduce_sum(probs, axis=reduce_axes)
    per_sample = (2. * overlap + eps) / (denom + eps)
    return tf.reduce_mean(per_sample)

# Weighted categorical crossentropy
def weighted_categorical_crossentropy(y_true, y_pred):
    """Class-weighted categorical crossentropy for one-hot targets and probability predictions.

    Each pixel's term is scaled by the weight of its true class:
        loss = mean over pixels of  -sum_c w_c * y_true_c * log(y_pred_c)
    The previous version multiplied every pixel by the *sum* of all weights,
    a constant that only rescaled the loss and re-weighted no class.
    """
    # Per-class weights in channel order 0..3 (presumably inverse-frequency weights — TODO confirm source).
    class_weights = tf.constant([0.3776, 0.7605, 65.8554, 46.2381], dtype=tf.float32)
    y_true = tf.cast(y_true, tf.float32)
    # Clip to avoid log(0); y_pred is assumed to already hold probabilities (e.g. softmax output).
    y_pred = tf.clip_by_value(tf.cast(y_pred, tf.float32), K.epsilon(), 1.0)
    # Broadcasting over the last (class) axis applies w_c to channel c for any leading shape.
    per_pixel = -tf.reduce_sum(y_true * tf.math.log(y_pred) * class_weights, axis=-1)
    return tf.reduce_mean(per_pixel)

# Dice loss
def dice_loss(y_true, y_pred):
    """Soft Dice loss: 1 minus the mean per-sample, per-class Dice score.

    Sums run over axes 1 and 2 only (the spatial axes for (batch, H, W, classes)
    tensors — assumed layout), giving one Dice value per sample and class.
    """
    eps = 1e-6
    spatial_axes = [1, 2]
    target = tf.cast(y_true, y_pred.dtype)
    overlap = tf.reduce_sum(target * y_pred, axis=spatial_axes)
    total = tf.reduce_sum(target, axis=spatial_axes) + tf.reduce_sum(y_pred, axis=spatial_axes)
    dice_per_class = (2. * overlap + eps) / (total + eps)
    return 1 - tf.reduce_mean(dice_per_class)

# Lovász-Softmax loss
def lovasz_softmax_loss(y_true, y_pred, ignore_background=False):
    """Lovász-Softmax loss, averaged over classes.

    For each class c, the absolute errors |y_true[..., c] - y_pred[..., c]| of
    all pixels in the batch (flattened together) are sorted in decreasing order
    and dotted with the gradient of the Lovász extension of the Jaccard loss.

    Args:
        y_true: One-hot targets, class axis last.
        y_pred: Class probabilities with the same shape as `y_true`.
        ignore_background: If True, class 0 is left out of the average.

    Returns:
        Scalar mean of the per-class losses that were computed.
    """
    y_true = tf.cast(y_true, tf.float32)
    y_pred = tf.cast(y_pred, tf.float32)

    num_classes = tf.shape(y_true)[-1]
    start_class = tf.constant(1 if ignore_background else 0)

    def compute_class_loss(c):
        y_true_class = y_true[..., c]
        y_pred_class = y_pred[..., c]

        y_true_flat = tf.reshape(y_true_class, [-1])
        y_pred_flat = tf.reshape(y_pred_class, [-1])

        # Sort errors in decreasing order and reorder the ground truth accordingly.
        errors = tf.abs(y_true_flat - y_pred_flat)
        errors_sorted, perm = tf.nn.top_k(errors, k=tf.shape(errors)[0], sorted=True)
        y_true_sorted = tf.gather(y_true_flat, perm)

        # Gradient of the Lovász extension of the Jaccard loss w.r.t. the sorted errors.
        gts = tf.reduce_sum(y_true_sorted)
        intersection = gts - tf.cumsum(y_true_sorted)
        union = gts + tf.cumsum(1. - y_true_sorted)
        jaccard = 1. - intersection / union
        grad = tf.concat([[jaccard[0]], jaccard[1:] - jaccard[:-1]], 0)

        return tf.tensordot(errors_sorted, grad, axes=1)

    # Only classes start_class..num_classes-1 are computed, so size the array for exactly those.
    # Previously index 0 stayed unwritten when ignore_background=True yet was still stacked.
    losses = tf.TensorArray(dtype=tf.float32, size=num_classes - start_class)

    def loop_cond(c, losses):
        return tf.less(c, num_classes)

    def loop_body(c, losses):
        losses = losses.write(c - start_class, compute_class_loss(c))
        return c + 1, losses

    _, losses = tf.while_loop(loop_cond, loop_body, [start_class, losses])
    return tf.reduce_mean(losses.stack())

# Combined loss: soft Dice + Lovász-Softmax
def combined_loss(y_true, y_pred):
    """Sum of a soft Dice loss and the Lovász-Softmax loss.

    The Dice part sums over axes 1-2 per sample and class and averages
    (1 - dice). The Lovász part is evaluated on `tf.nn.softmax(y_pred)`.

    NOTE(review): the Dice term uses `y_pred` as-is while the Lovász term
    re-applies softmax. The models built in the visible cells end in a softmax
    layer, so the Lovász term likely sees doubly-softmaxed probabilities —
    confirm whether `y_pred` is meant to be logits or probabilities.
    """
    eps = 1e-6
    target = tf.cast(y_true, y_pred.dtype)
    overlap = tf.reduce_sum(target * y_pred, axis=[1, 2])
    cardinality = tf.reduce_sum(target, axis=[1, 2]) + tf.reduce_sum(y_pred, axis=[1, 2])
    dice_term = tf.reduce_mean(1 - (2. * overlap + eps) / (cardinality + eps))

    lovasz_term = lovasz_softmax_loss(target, tf.nn.softmax(y_pred), ignore_background=False)
    return lovasz_term + dice_term

class DiceCoefficient(tf.keras.metrics.Metric):
    """Soft Dice score of one class, averaged over all batches since the last reset.

    Each batch contributes the mean per-sample Dice of channel `class_idx`
    (sums over axes 1 and 2); `result()` returns the mean of those batch
    values. Previously the state was overwritten each batch, so evaluation
    reported only the last batch. `sample_weight` is accepted for API
    compatibility but ignored.
    """

    def __init__(self, class_idx=0, name=None, **kwargs):  # default class_idx=0 so configs without it still load
        if name is None:
            name = f"DiceClass{class_idx}"
        super(DiceCoefficient, self).__init__(name=name, **kwargs)
        self.class_idx = class_idx
        # Running sum of per-batch Dice values and number of batches; result = sum / count.
        self.dice = self.add_weight(name="dice", initializer="zeros")
        self.count = self.add_weight(name="count", initializer="zeros")

    def update_state(self, y_true, y_pred, sample_weight=None):
        # Cast both so integer/one-hot targets multiply cleanly with float predictions.
        y_true_class = tf.cast(y_true[..., self.class_idx], self.dtype)
        y_pred_class = tf.cast(y_pred[..., self.class_idx], self.dtype)
        intersection = tf.reduce_sum(y_true_class * y_pred_class, axis=[1, 2])
        union = tf.reduce_sum(y_true_class, axis=[1, 2]) + tf.reduce_sum(y_pred_class, axis=[1, 2])
        dice = (2. * intersection + 1e-6) / (union + 1e-6)
        self.dice.assign_add(tf.reduce_mean(dice))
        self.count.assign_add(1.0)

    def result(self):
        # divide_no_nan: 0 before any update instead of NaN.
        return tf.math.divide_no_nan(self.dice, self.count)

    def get_config(self):
        config = super().get_config()
        config.update({"class_idx": self.class_idx})
        return config

    @classmethod
    def from_config(cls, config):
        config = dict(config)
        if "class_idx" not in config:
            # Recover the class index from names like "DiceClass2"; fall back to 0 otherwise.
            name = config.get("name", "DiceClass0")
            suffix = name[len("DiceClass"):] if name.startswith("DiceClass") else ""
            config["class_idx"] = int(suffix) if suffix.isdigit() else 0
        return cls(**config)

# Helper: build a DiceCoefficient from a "DiceClass<k>" metric name
def dice_metric_loader(name):
    """Return a DiceCoefficient for names of the form 'DiceClass<k>'.

    Raises:
        ValueError: if `name` does not start with 'DiceClass', or (from `int`)
            if the remaining suffix is not an integer.
    """
    prefix = "DiceClass"
    if not name.startswith(prefix):
        raise ValueError(f"Unknown Dice metric name: {name}")
    class_idx = int(name.replace(prefix, ""))
    return DiceCoefficient(class_idx=class_idx)

# Custom objects needed to deserialize the saved SegNet model (presumably the losses/metrics
# named in its saved compile config).
# NOTE(review): 'MeanIoU' and the 'DiceClass<i>' entries map to metric *instances* rather than
# classes/functions; confirm the installed Keras version resolves them as intended.
custom_objects = {
    'combined_loss': combined_loss,
    'lovasz_softmax_loss': lovasz_softmax_loss,
    'MeanIoU': tf.keras.metrics.MeanIoU(num_classes=NUM_CLASSES),
    'DiceCoefficient': DiceCoefficient,
}

# Register one DiceCoefficient per class under the names DiceClass0 .. DiceClass{NUM_CLASSES-1}.
for i in range(NUM_CLASSES):
    custom_objects[f'DiceClass{i}'] = dice_metric_loader(f'DiceClass{i}')

# Load the trained SegNet model.
# NOTE(review): hard-coded absolute local path; consider a configurable model directory.
model_segnet = load_model('C:\\Users\\User\\best_unet_model_onlineDA_128_lovaszloss_segnet.keras', custom_objects=custom_objects)

print("‚úÖ Model loaded successfully.")

In [None]:
def calculate_gflops(model, input_res=(224, 224, 3)):
    """Count the floating-point operations of one forward pass of `model` at batch size 1.

    The model call is traced into a ConcreteFunction for a (1, *input_res)
    input, its variables are frozen into constants, and the frozen graph is
    analysed with the TF1 profiler's `float_operation()` view.

    NOTE(review): this function is (re)defined in several cells; the most
    recently executed definition is the one in effect.

    Args:
        model: Keras model; `model.inputs[0].dtype` sets the traced input dtype.
        input_res: Per-sample input shape, without the batch dimension.

    Returns:
        Total float operations reported by the profiler, divided by 1e9.

    Raises:
        ValueError: if the profiler returns no result.
    """
    import tensorflow as tf
    from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2

    # Only the shape/dtype are needed; no dummy tensor has to be allocated.
    input_spec = tf.TensorSpec((1,) + tuple(input_res), model.inputs[0].dtype)

    # Trace the model into a graph and freeze its variables so the profiler sees concrete ops.
    concrete_fn = tf.function(lambda x: model(x)).get_concrete_function(input_spec)
    frozen_func = convert_variables_to_constants_v2(concrete_fn)
    graph_def = frozen_func.graph.as_graph_def()

    with tf.Graph().as_default() as graph:
        tf.import_graph_def(graph_def, name="")
        builder = tf.compat.v1.profiler.ProfileOptionBuilder
        # with_empty_output(): return the counts without printing a per-op table into the notebook.
        opts = builder(builder.float_operation()).with_empty_output().build()
        flops = tf.compat.v1.profiler.profile(
            graph=graph,
            run_meta=tf.compat.v1.RunMetadata(),
            cmd='op',
            options=opts,
        )

    if flops is None:
        raise ValueError("‚ùå Could not compute FLOPs. Check if the model contains unsupported ops.")

    return flops.total_float_ops / 1e9

# FLOPs of the loaded SegNet model at the default 224x224x3 input.
# NOTE(review): the checkpoint name mentions "128"; if the model has a fixed 128x128 input,
# pass input_res=(128, 128, 3) — check model_segnet.input_shape.
gflops = calculate_gflops(model_segnet)
print(f"üìè Model GFLOPs: {gflops:.2f} GFLOPs")

In [None]:

import tensorflow as tf
from tensorflow.keras import Input, Model
from tensorflow.keras.layers import Conv2D, UpSampling2D, Concatenate, BatchNormalization, Activation
from tensorflow.keras.applications import InceptionResNetV2
import gc

# Constants (re-declared so this cell can run on its own; same values as the notebook's first cell)
IMG_HEIGHT = 224
IMG_WIDTH = 224
CHANNELS = 3
NUM_CLASSES = 4  # Brain, CSP, LV, Background

class ResizeLayer(tf.keras.layers.Layer):
    """Keras layer that bilinearly resizes its input to a fixed spatial size.

    Args:
        target_size: (height, width) forwarded to `tf.image.resize`.
    """

    def __init__(self, target_size, **kwargs):
        super(ResizeLayer, self).__init__(**kwargs)
        self.target_size = target_size

    def call(self, inputs):
        size = self.target_size
        return tf.image.resize(inputs, size, method='bilinear')

    def get_config(self):
        base_config = super(ResizeLayer, self).get_config()
        return {**base_config, "target_size": self.target_size}

def conv_block(x, filters, kernel_size=3, padding='same', activation='relu'):
    """Two stacked Conv2D -> BatchNormalization -> Activation stages with identical settings.

    NOTE(review): a later cell redefines `conv_block` as a single stage with a different
    signature; re-running cells out of order changes which version the builders use.
    """
    out = x
    for _ in range(2):
        out = Conv2D(filters, kernel_size, padding=padding)(out)
        out = BatchNormalization()(out)
        out = Activation(activation)(out)
    return out

def _project_channels(x, filters):
    """1x1 Conv2D -> BatchNormalization -> ReLU, used to shrink channel counts."""
    x = Conv2D(filters, 1, padding='same')(x)
    x = BatchNormalization()(x)
    return Activation('relu')(x)


def build_full_inceptionresnetv2_unet(input_shape=(IMG_HEIGHT, IMG_WIDTH, CHANNELS), num_classes=NUM_CLASSES):
    """
    Build a U-Net decoder on top of an ImageNet-pretrained InceptionResNetV2 encoder
    (sized for roughly 60-70M parameters).

    Five encoder taps feed a decoder that, at each level, upsamples by 2,
    resizes to the skip's exact spatial size (the backbone's 'valid'
    convolutions give non-power-of-two sizes), applies `conv_block`, and
    concatenates a 1x1-projected skip. All backbone layers are trainable.

    Args:
        input_shape: (height, width, channels) of the input images.
        num_classes: Number of output classes.

    Returns:
        An uncompiled Keras Model producing per-pixel softmax probabilities at
        the input resolution.
    """
    inputs = Input(shape=input_shape)

    base_model = InceptionResNetV2(
        input_tensor=inputs,
        include_top=False,
        weights='imagenet',
        pooling=None
    )

    # Fine-tune the whole backbone.
    for layer in base_model.layers:
        layer.trainable = True

    # The two stem taps have auto-generated names ("activation", "activation_3") that shift when
    # other Keras layers were created earlier in the session, so select them by position instead:
    # the stem is a plain chain, so its Activation layers come first in the model's layer list.
    stem_activations = [layer for layer in base_model.layers if isinstance(layer, Activation)]
    # Expected shapes for a 224x224 input (verify with base_model.summary()):
    encoder1 = stem_activations[0].output                     # 111x111x32  (= 'activation')
    encoder2 = stem_activations[3].output                     # 54x54x80    (= 'activation_3')
    encoder3 = base_model.get_layer('block35_10_ac').output   # 25x25x320
    encoder4 = base_model.get_layer('block17_20_ac').output   # 12x12x1088
    encoder5 = base_model.get_layer('conv_7b_ac').output      # 5x5x1536

    # Bottleneck: compress the deepest features to 512 channels to keep the decoder small.
    x = _project_channels(encoder5, 512)

    # Decoder, deepest to shallowest: (skip features, up-conv filters, skip filters, merge filters).
    # NOTE(review): `conv_block` is resolved at call time; a later cell redefines it with a single stage.
    decoder_levels = [
        (encoder4, 512, 256, 384),
        (encoder3, 384, 128, 192),
        (encoder2, 192, 96, 96),
        (encoder1, 96, 48, 48),
    ]
    for skip_features, up_filters, skip_filters, merge_filters in decoder_levels:
        up = UpSampling2D(size=(2, 2))(x)
        up = ResizeLayer(target_size=(skip_features.shape[1], skip_features.shape[2]))(up)
        up = conv_block(up, up_filters, kernel_size=3)
        skip = _project_channels(skip_features, skip_filters)
        x = conv_block(Concatenate()([up, skip]), merge_filters)

    # Final x2 upsampling towards the input resolution, then an exact resize if sizes still differ.
    up_final = UpSampling2D(size=(2, 2))(x)
    up_final = conv_block(up_final, 32)
    if up_final.shape[1] != input_shape[0] or up_final.shape[2] != input_shape[1]:
        up_final = ResizeLayer(target_size=(input_shape[0], input_shape[1]))(up_final)

    # float32 head keeps the softmax output in float32 even under mixed precision.
    outputs = Conv2D(num_classes, 1, activation='softmax', dtype='float32')(up_final)

    return Model(inputs=inputs, outputs=outputs)

# Build the InceptionResNetV2 U-Net; the next cells measure its FLOPs and inference speed via `model`.
print("Creating full InceptionResNetV2-UNet model...")
model = build_full_inceptionresnetv2_unet(input_shape=(IMG_HEIGHT, IMG_WIDTH, CHANNELS), num_classes=NUM_CLASSES)
print("Model created successfully!")

# Free Python garbage and reset Keras global state (e.g. auto-generated layer names).
# NOTE(review): `model` stays referenced and is used after clear_session(); confirm this is
# safe with the installed Keras version.
gc.collect()
tf.keras.backend.clear_session()

In [None]:
def calculate_gflops(model, input_res=(224, 224, 3)):
    """Count the floating-point operations of one forward pass of `model` at batch size 1.

    The model call is traced into a ConcreteFunction for a (1, *input_res)
    input, its variables are frozen into constants, and the frozen graph is
    analysed with the TF1 profiler's `float_operation()` view.

    NOTE(review): this function is (re)defined in several cells; the most
    recently executed definition is the one in effect.

    Args:
        model: Keras model; `model.inputs[0].dtype` sets the traced input dtype.
        input_res: Per-sample input shape, without the batch dimension.

    Returns:
        Total float operations reported by the profiler, divided by 1e9.

    Raises:
        ValueError: if the profiler returns no result.
    """
    import tensorflow as tf
    from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2

    input_spec = tf.TensorSpec((1,) + tuple(input_res), model.inputs[0].dtype)

    # Trace the model into a graph and freeze its variables so the profiler sees concrete ops.
    concrete_fn = tf.function(lambda x: model(x)).get_concrete_function(input_spec)
    frozen_func = convert_variables_to_constants_v2(concrete_fn)
    graph_def = frozen_func.graph.as_graph_def()

    with tf.Graph().as_default() as graph:
        tf.import_graph_def(graph_def, name="")
        builder = tf.compat.v1.profiler.ProfileOptionBuilder
        # with_empty_output(): return the counts without printing a per-op table into the notebook.
        opts = builder(builder.float_operation()).with_empty_output().build()
        flops = tf.compat.v1.profiler.profile(
            graph=graph,
            run_meta=tf.compat.v1.RunMetadata(),
            cmd='op',
            options=opts,
        )

    if flops is None:
        raise ValueError("‚ùå Could not compute FLOPs. Check if the model contains unsupported ops.")

    return flops.total_float_ops / 1e9

# FLOPs of the model currently bound to `model` (the InceptionResNetV2 U-Net when cells run in order).
gflops = calculate_gflops(model)
print(f"üìè Model GFLOPs: {gflops:.2f} GFLOPs")

In [None]:
import time
import numpy as np

def measure_inference_speed(model, input_shape=(224, 224, 3), batch_size=1, num_runs=100,
                            warmup_runs=10):
    """Measure per-image latency and throughput of `model.predict` on random input.

    Args:
        model: Object with a Keras-style `predict(x, verbose=...)` method.
        input_shape: Per-sample input shape, without the batch dimension.
        batch_size: Images per `predict` call; must be positive.
        num_runs: Number of timed calls; must be positive.
        warmup_runs: Untimed calls made first so tracing/allocation is excluded.

    Returns:
        (ms_per_image, images_per_second).

    Raises:
        ValueError: if `num_runs` or `batch_size` is not positive.
    """
    if num_runs <= 0 or batch_size <= 0:
        raise ValueError("num_runs and batch_size must be positive")

    dummy_input = np.random.rand(batch_size, *input_shape).astype(np.float32)

    # verbose=0: otherwise Keras prints a progress bar on every call, adding console I/O to the
    # measured time and flooding the cell output.
    for _ in range(warmup_runs):
        model.predict(dummy_input, verbose=0)

    # perf_counter is monotonic and high-resolution, unlike time.time().
    start = time.perf_counter()
    for _ in range(num_runs):
        model.predict(dummy_input, verbose=0)
    total_time = time.perf_counter() - start

    avg_time_per_image = total_time / (num_runs * batch_size)
    fps = 1.0 / avg_time_per_image if avg_time_per_image > 0 else float("inf")
    return avg_time_per_image * 1000, fps  # ms per image, images per second

# Latency / FPS of `model` at batch size 1 (the InceptionResNetV2 U-Net when cells run in order;
# a later cell rebinds `model` to the EfficientNet Attention U-Net).
ms, fps = measure_inference_speed(model)
print(f"Inference Time: {ms:.2f} ms/image | FPS: {fps:.2f}")


In [None]:
import tensorflow as tf
from tensorflow.keras import Input, Model, Sequential
from tensorflow.keras.layers import Conv2D, MaxPooling2D, BatchNormalization, Activation, Add
from tensorflow.keras.layers import Dense, Dropout, Layer, Reshape, Permute, Multiply, Concatenate
from tensorflow.keras.layers import GlobalAveragePooling2D, LayerNormalization, UpSampling2D
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.losses import CategoricalCrossentropy
from tensorflow.keras.applications import EfficientNetB4

class ResizeToMatchLayer(Layer):
    """Bilinearly resize a tensor to the spatial size of a reference tensor.

    Called on `[x, target]`; returns `x` resized to target's (height, width),
    read dynamically via `tf.shape`, keeping x's batch and channel dimensions.
    """

    def __init__(self, **kwargs):
        super(ResizeToMatchLayer, self).__init__(**kwargs)

    def call(self, inputs):
        source, reference = inputs
        ref_dims = tf.shape(reference)
        return tf.image.resize(source, [ref_dims[1], ref_dims[2]], method='bilinear')

    def compute_output_shape(self, input_shape):
        source_shape, reference_shape = input_shape[0], input_shape[1]
        return (source_shape[0], reference_shape[1], reference_shape[2], source_shape[3])

def conv_block(x, filters, kernel_size=3, strides=1, padding='same', use_bn=True, activation='relu'):
    """Single Conv2D, optionally followed by BatchNormalization and an activation.

    NOTE(review): this replaces the two-stage `conv_block` defined in the
    InceptionResNetV2 cell; code executed after this cell gets this version.
    """
    y = Conv2D(filters, kernel_size, strides=strides, padding=padding)(x)
    y = BatchNormalization()(y) if use_bn else y
    return Activation(activation)(y) if activation else y

def attention_gate(x, g, inter_channels):
    """
    Additive attention gate in the style of Attention U-Net.

    Args:
        x: Skip-connection features from the encoder (the tensor being gated).
        g: Gating signal from the decoder; resized to x's spatial size first.
        inter_channels: Channels of the intermediate 1x1 projections.

    Returns:
        `x` multiplied element-wise by a single-channel sigmoid attention map.
    """
    gate = ResizeToMatchLayer()([g, x])

    def project(tensor, channels):
        return Conv2D(channels, 1, use_bias=False, padding='same')(tensor)

    theta_x = project(x, inter_channels)
    phi_g = project(gate, inter_channels)
    combined = Activation('relu')(Add()([theta_x, phi_g]))

    psi = project(combined, 1)
    coefficients = Activation('sigmoid')(psi)
    return Multiply()([x, coefficients])

def decoder_block(x, skip_connection, filters, use_attention=True):
    """One Attention U-Net decoder stage.

    Bilinearly upsamples `x` by 2 and resizes it to the skip's spatial size;
    optionally gates the skip with `attention_gate` (the upsampled x is the
    gating signal, filters // 2 intermediate channels); concatenates
    [x, skip] and applies two 3x3 `conv_block`s with `filters` channels.
    """
    upsampled = UpSampling2D(size=(2, 2), interpolation='bilinear')(x)
    upsampled = ResizeToMatchLayer()([upsampled, skip_connection])

    if use_attention:
        gated_skip = attention_gate(skip_connection, upsampled, filters // 2)
    else:
        gated_skip = skip_connection

    merged = Concatenate()([upsampled, gated_skip])
    for _ in range(2):
        merged = conv_block(merged, filters, 3, padding='same')
    return merged

def build_efficientnet_attention_unet(input_shape, num_classes):
    """
    Attention U-Net segmentation model with an ImageNet-pretrained EfficientNetB4 encoder.

    Four encoder taps (1/2, 1/4, 1/8, 1/16 of the input) and `top_activation`
    (1/32) each pass through a `conv_block` that reduces them to 32, 64, 128,
    256 and 512 channels. They are decoded coarse-to-fine with attention-gated
    `decoder_block`s, followed by a last bilinear x2 upsampling and a per-pixel
    softmax. No backbone layers are frozen.

    NOTE(review): Keras' EfficientNet models include their own input rescaling
    and expect pixel values in [0, 255]; confirm how the images fed to this
    model are normalised.

    Args:
        input_shape: (height, width, channels) of the input images.
        num_classes: Number of segmentation classes.

    Returns:
        An uncompiled Keras Model.
    """
    inputs = Input(shape=input_shape)
    backbone = EfficientNetB4(
        weights='imagenet',
        include_top=False,
        input_tensor=inputs
    )

    width = 32  # channels at the finest decoder level; doubles at each coarser level

    # Encoder taps from fine to coarse, and the channel count each is reduced to.
    tap_names = ['block1b_add', 'block2d_add', 'block3d_add', 'block5e_add']
    tap_widths = [width, width * 2, width * 4, width * 8]
    taps = [backbone.get_layer(layer_name).output for layer_name in tap_names]
    bottleneck = backbone.get_layer('top_activation').output

    reduced_taps = [conv_block(tap, tap_width) for tap, tap_width in zip(taps, tap_widths)]
    decoded = conv_block(bottleneck, width * 16)

    # Decode from the deepest skip to the shallowest.
    for skip, stage_width in zip(reversed(reduced_taps), reversed(tap_widths)):
        decoded = decoder_block(decoded, skip, stage_width, use_attention=True)

    upsampled = UpSampling2D(size=(2, 2), interpolation='bilinear')(decoded)
    outputs = Conv2D(num_classes, 1, padding='same', activation='softmax')(upsampled)
    return Model(inputs=inputs, outputs=outputs)

# Build the EfficientNetB4 Attention U-Net (rebinds the global `model`, previously the InceptionResNetV2 U-Net).
model = build_efficientnet_attention_unet(input_shape=(224, 224, 3), num_classes=4)

In [None]:
# FLOPs of the model currently bound to `model` (the EfficientNetB4 Attention U-Net when cells run in order).
gflops = calculate_gflops(model)
print(f"üìè Model GFLOPs: {gflops:.2f} GFLOPs")

In [None]:
import tensorflow as tf
from tensorflow.keras import layers, Model, Input
from tensorflow.keras.applications import Xception

# Constants for 224x224 images (re-declared so the DeepLabV3+ cell is self-contained)
IMG_HEIGHT = 224
IMG_WIDTH = 224
CHANNELS = 3
NUM_CLASSES = 4  # Brain, CSP, LV, Background

def convolution_block(inputs, filters, kernel_size=3, dilation_rate=1, padding='same', use_bias=False):
    """
    Conv2D (optionally dilated; bias-free by default) -> BatchNormalization -> ReLU.
    """
    conv = layers.Conv2D(
        filters,
        kernel_size,
        padding=padding,
        dilation_rate=dilation_rate,
        use_bias=use_bias,
    )
    features = conv(inputs)
    features = layers.BatchNormalization()(features)
    return layers.ReLU()(features)

def ASPP(inputs):
    """
    Atrous Spatial Pyramid Pooling head for DeepLabV3+, 256 channels per branch.

    Branches: a 1x1 conv, three 3x3 convs with dilation 6, 12 and 18, and an
    image-level branch (global average pool -> 1x1 conv -> upsample back to
    the input's static spatial size). The five branches are concatenated and
    fused with a 1x1 conv. `inputs` must have static height and width.
    """
    branches = [convolution_block(inputs, 256, kernel_size=1, dilation_rate=1)]
    branches += [convolution_block(inputs, 256, kernel_size=3, dilation_rate=rate) for rate in (6, 12, 18)]

    height, width, channels = inputs.shape[1], inputs.shape[2], inputs.shape[-1]
    pooled = layers.GlobalAveragePooling2D()(inputs)
    pooled = layers.Reshape((1, 1, channels))(pooled)
    pooled = convolution_block(pooled, 256, kernel_size=1)
    branches.append(layers.UpSampling2D(size=(height, width))(pooled))

    fused = layers.Concatenate()(branches)
    return convolution_block(fused, 256, kernel_size=1)

def _upsample_to(x, target_h, target_w):
    """Bilinearly upsample `x` to (target_h, target_w): integer UpSampling2D, then an exact Resizing if needed."""
    h_factor = max(1, target_h // x.shape[1])
    w_factor = max(1, target_w // x.shape[2])
    x = layers.UpSampling2D(size=(h_factor, w_factor), interpolation='bilinear')(x)
    if x.shape[1] != target_h or x.shape[2] != target_w:
        # Integer upsampling cannot reach sizes that are not exact multiples; resize the rest of the way.
        x = layers.Resizing(target_h, target_w, interpolation='bilinear')(x)
    return x


def build_deeplabv3_plus_xception(input_shape, num_classes):
    """
    DeepLabV3+ segmentation model with an ImageNet-pretrained Xception encoder.

    The encoder's final features go through ASPP and are upsampled to the size
    of the low-level `block4_sepconv2_bn` features. They are concatenated with
    a 48-channel projection of those features, refined by two 3x3 conv blocks,
    and upsampled to the input resolution for a per-pixel softmax. The output
    size now follows `input_shape` (previously the global IMG_HEIGHT/IMG_WIDTH).
    Non-integer scale factors no longer crash in a `Reshape`.

    Args:
        input_shape: (height, width, channels) of the input images.
        num_classes: Number of output classes.

    Returns:
        An uncompiled Keras Model producing (height, width, num_classes) probabilities.
    """
    inputs = Input(input_shape)
    out_h, out_w = input_shape[0], input_shape[1]

    # Keras' Xception backbone has output stride 32 (7x7 features for a 224x224 input).
    base_model = Xception(
        input_tensor=inputs,
        include_top=False,
        weights='imagenet'
    )

    # Fine-tune the whole backbone.
    for layer in base_model.layers:
        layer.trainable = True

    # Low-level features from the end of the entry flow (before block4's pooling; 1/8 scale).
    low_level_features = base_model.get_layer('block4_sepconv2_bn').output
    # High-level features from the exit flow.
    high_level_features = base_model.output

    low_level_features = convolution_block(low_level_features, 48, kernel_size=1)
    x = ASPP(high_level_features)

    # ASPP keeps the spatial size, so upsample the ASPP output to the low-level feature size.
    x = _upsample_to(x, low_level_features.shape[1], low_level_features.shape[2])
    x = layers.Concatenate()([x, low_level_features])

    x = convolution_block(x, 256, kernel_size=3)
    x = convolution_block(x, 256, kernel_size=3)

    # Back to the input resolution.
    x = _upsample_to(x, out_h, out_w)

    outputs = layers.Conv2D(num_classes, kernel_size=1, padding='same', activation='softmax')(x)
    return Model(inputs=inputs, outputs=outputs)

# Build the DeepLabV3+ (Xception) model and count its FLOPs at 224x224x3.
model_xception = build_deeplabv3_plus_xception(input_shape=(IMG_HEIGHT, IMG_WIDTH, CHANNELS), 
                                     num_classes=NUM_CLASSES)

gflops = calculate_gflops(model_xception)
print(f"üìè Model GFLOPs: {gflops:.2f} GFLOPs")

In [None]:
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

# Hard-coded efficiency metrics per model, presumably copied from earlier runs of the cells
# above — TODO record which run/hardware produced each number.
# "Xception" presumably denotes the DeepLabV3+ model with the Xception backbone.
data = {
    "Model": ["Unet++(InceptionResNetV2)", "AttentionUnet(EfficientNetB4)", "Xception", "SegNet", "LightSeg"],
    "Training Time (s)": [190.62, 149.88, 102.30, 250.62, 164.84],
    "GFLOPs": [34.78, 8.26, 17.15, 85.32, 18.89],
    "Energy (Wh)": [5.56, 4.00, 3.28, 9.93, 5.51],
    "Params (M)": [69.3, 29.7, 37.8, 31.4, 1.33],
    "Size (MB)": [796, 114, 432, 359, 5.35]
}

df = pd.DataFrame(data)

# Set seaborn style
sns.set(style="whitegrid")

# Metrics to plot: one bar chart per metric
metrics = ["Training Time (s)", "GFLOPs", "Energy (Wh)", "Params (M)", "Size (MB)"]

# One figure per metric, models on the x-axis.
# NOTE(review): passing `palette` without `hue` is deprecated in recent seaborn versions;
# `hue="Model", legend=False` is the forward-compatible form.
for metric in metrics:
    plt.figure(figsize=(8, 5))
    sns.barplot(x="Model", y=metric, data=df, palette="coolwarm")
    plt.title(f"Comparison of Models by {metric}")
    plt.ylabel(metric)
    plt.xlabel("Model")
    plt.xticks(rotation=15)
    plt.tight_layout()
    plt.show()

In [None]:
import matplotlib.pyplot as plt

# Plot training/validation curves from a Keras History object.
# NOTE(review): `history` is not created in the visible cells; it must be the return value of an
# earlier `model.fit(...)`, otherwise this cell fails on a fresh kernel.
history_dict = history.history

plt.figure(figsize=(12, 6))

# Loss panel. The validation curve exists only if fit() received validation data.
plt.subplot(1, 2, 1)
plt.plot(history_dict['loss'], label='Training Loss')
if 'val_loss' in history_dict:
    plt.plot(history_dict['val_loss'], label='Validation Loss')
plt.title('Loss Curves')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()

# Accuracy panel, only when an 'accuracy' metric was tracked.
if 'accuracy' in history_dict:
    plt.subplot(1, 2, 2)
    plt.plot(history_dict['accuracy'], label='Training Accuracy')
    if 'val_accuracy' in history_dict:
        plt.plot(history_dict['val_accuracy'], label='Validation Accuracy')
    plt.title('Accuracy Curves')
    plt.xlabel('Epochs')
    plt.ylabel('Accuracy')
    plt.legend()

plt.tight_layout()
plt.show()


In [None]:
# Evaluate `kd_model` (presumably the knowledge-distillation model) on the test split.
# NOTE(review): `kd_model`, `X_test`, `y_test` and `batch_size` are defined outside the visible cells.
kd_model.evaluate(X_test, y_test, batch_size=batch_size)

In [None]:
# Evaluate `student_model` on the test split with batch size 8.
# NOTE(review): the FLOPs cell rebinds `student_model` to a freshly built UNet++; make sure the
# intended (trained, compiled) model is bound to this name when this cell runs.
student_model.evaluate(X_test, y_test, batch_size=8)

In [None]:
# Same evaluation as the previous cell.
# NOTE(review): duplicate cell — consider removing one of the two.
student_model.evaluate(X_test, y_test, batch_size=8)

In [None]:
import numpy as np
import tensorflow as tf
import scipy.spatial.distance as dist
from scipy.ndimage import binary_erosion
import psutil
import gc
import logging

# Set up INFO-level logging for the HD/ASD evaluation helpers below.
# basicConfig has no effect if the root logger already has handlers (e.g. configured earlier in this kernel).
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def check_memory():
    """Log total, available and used system RAM (GiB) plus percent used; return the psutil snapshot."""
    snapshot = psutil.virtual_memory()
    gib = 1024 ** 3
    logger.info(f"Total RAM: {snapshot.total / gib:.2f} GB")
    logger.info(f"Available RAM: {snapshot.available / gib:.2f} GB")
    logger.info(f"Used RAM: {snapshot.used / gib:.2f} GB")
    logger.info(f"Memory percentage used: {snapshot.percent:.1f}%")
    return snapshot


def get_boundary_points(mask):
    """Return the coordinates of the boundary pixels of a mask.

    The mask is treated as foreground wherever it is non-zero. A foreground
    pixel is on the boundary if one binary erosion with SciPy's default
    structuring element removes it. Erosion pads with 0, so foreground pixels
    on the array edge count as boundary.

    Previously a non-bool mask (e.g. class index 2) was combined with the
    eroded mask via bitwise `&`, and 2 & True == 0, so no boundary was found.

    Returns:
        Integer array of shape (k, mask.ndim); (0, mask.ndim) if the mask is empty.
    """
    mask = np.asarray(mask, dtype=bool)
    if not mask.any():
        return np.empty((0, mask.ndim), dtype=int)
    interior = binary_erosion(mask)
    return np.argwhere(mask & ~interior)


def compute_surface_distances_optimized(pred, true, max_points=1000):
    """
    Distances between the boundary points of two masks, with point thinning.

    Both masks are cast to bool and reduced to boundary pixels. A boundary
    with more than `max_points` points is thinned by keeping every k-th point
    (k = len // max_points) and truncating to `max_points`.

    Returns:
        (pred->true, true->pred) distance arrays whose row-wise minimum is each
        point's distance to the other boundary:
        * both boundaries empty -> ([[0]], [[0]])
        * exactly one empty     -> ([[inf]], [[inf]])
        * estimated full matrix above 2 GB -> (n, 1) per-point minima from
          `compute_distances_chunked`
        * otherwise the full `cdist` matrices.
    """
    pred_mask = tf.cast(pred, tf.bool).numpy()
    true_mask = tf.cast(true, tf.bool).numpy()

    pred_pts = get_boundary_points(pred_mask)
    true_pts = get_boundary_points(true_mask)

    pred_empty = len(pred_pts) == 0
    true_empty = len(true_pts) == 0
    if pred_empty and true_empty:
        return np.array([[0]]), np.array([[0]])
    if pred_empty or true_empty:
        return np.array([[np.inf]]), np.array([[np.inf]])

    if len(pred_pts) > max_points:
        stride = len(pred_pts) // max_points
        pred_pts = pred_pts[::stride][:max_points]
        logger.info(f"Sampled pred boundary to {len(pred_pts)}")

    if len(true_pts) > max_points:
        stride = len(true_pts) // max_points
        true_pts = true_pts[::stride][:max_points]
        logger.info(f"Sampled true boundary to {len(true_pts)}")

    # 8 bytes per float64 distance in one full matrix.
    estimated_memory_gb = (len(pred_pts) * len(true_pts) * 8) / (1024**3)
    logger.info(f"Estimated memory needed: {estimated_memory_gb:.2f} GB")

    if estimated_memory_gb > 2.0:
        return compute_distances_chunked(pred_pts, true_pts)
    return (
        dist.cdist(pred_pts, true_pts, 'euclidean'),
        dist.cdist(true_pts, pred_pts, 'euclidean'),
    )


def compute_distances_chunked(pred_boundary, true_boundary, chunk_size=500):
    """
    Nearest-neighbour boundary distances computed ``chunk_size`` rows at a time.

    Only one ``(chunk_size, M)`` distance block is held in memory at once.

    Returns:
        ``(P, 1)`` and ``(T, 1)`` column vectors: for each pred point the
        distance to the closest true point, and vice versa.
    """
    logger.info("Using chunked distance computation...")

    def _row_minima(src, dst):
        # Minimum distance from every row of `src` to any row of `dst`.
        out = np.full(len(src), np.inf)
        for start in range(0, len(src), chunk_size):
            block = dist.cdist(src[start:start + chunk_size], dst, 'euclidean')
            out[start:start + chunk_size] = block.min(axis=1)
            del block
            gc.collect()
        return out

    min_pred_to_true = _row_minima(pred_boundary, true_boundary)
    min_true_to_pred = _row_minima(true_boundary, pred_boundary)
    return min_pred_to_true.reshape(-1, 1), min_true_to_pred.reshape(-1, 1)


def hausdorff_distance_95(dist_pred_to_true, dist_true_to_pred):
    """
    Symmetric 95th-percentile Hausdorff distance (HD95).

    Each argument is a 2-D array whose row-wise minimum is, per boundary point
    of one mask, the distance to the other mask's boundary. The 95th percentile
    is taken in each direction and the larger one is returned; ``np.inf`` if
    either array is empty.
    """
    if 0 in (dist_pred_to_true.size, dist_true_to_pred.size):
        return np.inf
    directional = [
        np.percentile(np.min(d, axis=1), 95)
        for d in (dist_pred_to_true, dist_true_to_pred)
    ]
    return max(directional)

def average_symmetric_surface_distance_optimized(dist_pred_to_true, dist_true_to_pred):
    """
    Average symmetric surface distance: the mean of the two directional mean
    nearest-boundary distances. ``np.inf`` if either array is empty.
    """
    if 0 in (dist_pred_to_true.size, dist_true_to_pred.size):
        return np.inf
    mean_pred_side = np.mean(np.min(dist_pred_to_true, axis=1))
    mean_true_side = np.mean(np.min(dist_true_to_pred, axis=1))
    return (mean_pred_side + mean_true_side) / 2


def calculate_hd_asd_combined_mask(model, x_test, y_test, batch_size=8, max_boundary_points=1000, foreground_class_indices=[1, 2, 3]):
    """
    Mean HD and ASD over one binary mask per image made by merging foreground classes.

    Prediction scores and one-hot labels are reduced with ``argmax`` over the
    last axis. Every pixel whose class is in ``foreground_class_indices`` is
    foreground. Samples where either metric is infinite are skipped.

    Returns:
        ``(mean_hd, mean_asd)``, or ``(nan, nan)`` if no sample was usable.
    """
    logger.info("Running combined-mask HD/ASD computation...")
    check_memory()

    hd_scores, asd_scores = [], []

    # Single batched forward pass over the whole test set.
    y_preds = model.predict(x_test, batch_size=batch_size, verbose=1)

    for i in range(len(y_preds)):
        pred_fg = np.isin(np.argmax(y_preds[i], axis=-1), foreground_class_indices)
        true_fg = np.isin(np.argmax(y_test[i], axis=-1), foreground_class_indices)

        dist_pred_to_true, dist_true_to_pred = compute_surface_distances_optimized(
            pred_fg, true_fg, max_boundary_points
        )

        # NOTE(review): this cell defines hausdorff_distance_95 but calls
        # hausdorff_distance_optimized, which must be defined elsewhere in the
        # notebook. Confirm it exists and which HD variant it computes.
        hd = hausdorff_distance_optimized(dist_pred_to_true, dist_true_to_pred)
        asd = average_symmetric_surface_distance_optimized(dist_pred_to_true, dist_true_to_pred)

        if np.isinf(hd) or np.isinf(asd):
            logger.warning(f"Skipping sample {i} due to empty masks")
        else:
            hd_scores.append(hd)
            asd_scores.append(asd)
            logger.info(f"Sample {i}: HD = {hd:.4f}, ASD = {asd:.4f}")

        del dist_pred_to_true, dist_true_to_pred
        gc.collect()

    if not hd_scores:
        logger.error("No valid samples for HD/ASD calculation")
        return np.nan, np.nan

    return np.mean(hd_scores), np.mean(asd_scores)

# Evaluate the student model on the union of classes 1, 2, 3 (e.g. whole fetal head).
mean_hd, mean_asd = calculate_hd_asd_combined_mask(
    model=student_model,
    x_test=X_test,
    y_test=y_test,
    batch_size=8,
)

for label, value in (("HD", mean_hd), ("ASD", mean_asd)):
    print(f"Combined {label}: {value:.4f}")

In [None]:
import numpy as np
import tensorflow as tf
import scipy.spatial.distance as dist
from scipy.ndimage import binary_erosion
import psutil
import gc
import logging

# Setup logger. NOTE: logging.basicConfig() does nothing if the root logger
# already has handlers (an earlier cell also calls it), so re-running this
# line does not change the existing logging configuration.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def check_memory():
    """
    Log a snapshot of system RAM: total, available, used, and percent used.

    Returns:
        The ``psutil.virtual_memory()`` result, for callers that want raw numbers.
    """
    snapshot = psutil.virtual_memory()
    gib = 1024**3
    logger.info(f"Total RAM: {snapshot.total / gib:.2f} GB")
    logger.info(f"Available RAM: {snapshot.available / gib:.2f} GB")
    logger.info(f"Used RAM: {snapshot.used / gib:.2f} GB")
    logger.info(f"Memory usage: {snapshot.percent:.1f}%")
    return snapshot


def get_boundary_points(mask):
    """
    Extract boundary points from a binary mask.

    A pixel is on the boundary if it is foreground in ``mask`` but is removed
    by one ``binary_erosion`` pass (default structuring element; pixels outside
    the array count as background, so foreground touching the array edge is
    boundary).

    Args:
        mask: Array-like mask; any nonzero value is treated as foreground.

    Returns:
        ``(N, mask.ndim)`` integer array of boundary coordinates in
        ``np.argwhere`` (row-major) order; ``(0, mask.ndim)`` if the mask has
        no foreground.
    """
    # Coerce to bool: float masks do not support `&`/`~`, and integer masks with
    # values other than 1 would give wrong results (e.g. 2 & True == 0).
    mask = np.asarray(mask, dtype=bool)
    if not mask.any():
        return np.empty((0, mask.ndim), dtype=int)
    eroded = binary_erosion(mask)
    return np.argwhere(mask & ~eroded)


def compute_distances_chunked(pred_boundary, true_boundary, chunk_size=500):
    """
    Nearest-neighbour boundary distances computed ``chunk_size`` rows at a time.

    Only one ``(chunk_size, M)`` distance block is held in memory at once.

    Returns:
        ``(P, 1)`` and ``(T, 1)`` column vectors: for each pred point the
        distance to the closest true point, and vice versa.
    """
    def _row_minima(src, dst):
        # Minimum distance from every row of `src` to any row of `dst`.
        out = np.full(len(src), np.inf)
        for start in range(0, len(src), chunk_size):
            block = dist.cdist(src[start:start + chunk_size], dst, 'euclidean')
            out[start:start + chunk_size] = block.min(axis=1)
            del block
            gc.collect()
        return out

    min_pred_to_true = _row_minima(pred_boundary, true_boundary)
    min_true_to_pred = _row_minima(true_boundary, pred_boundary)
    return min_pred_to_true.reshape(-1, 1), min_true_to_pred.reshape(-1, 1)


def compute_surface_distances_optimized(pred, true, max_points=1000):
    """
    Efficient surface distance computation using boundary sampling.

    Args:
        pred: Predicted mask; anything ``np.asarray`` can turn into a boolean
            array (NumPy array, eager ``tf.Tensor``, ...). Nonzero = foreground.
        true: Ground-truth mask with the same shape as ``pred``.
        max_points: Maximum number of boundary points kept per mask. Longer
            boundaries are subsampled at evenly spaced positions over the whole
            point list. ``None`` or a value <= 0 disables sampling.

    Returns:
        ``(dist_pred_to_true, dist_true_to_pred)``: 2-D arrays whose row-wise
        minimum is the distance from each boundary point of one mask to the
        other mask's boundary.
        - Normally these are full ``(P, T)`` / ``(T, P)`` matrices.
        - When one full float64 matrix is estimated above 2 GB, they are
          ``(P, 1)`` / ``(T, 1)`` minima from ``compute_distances_chunked``.
        - ``[[0]]`` twice if both masks are empty.
        - ``[[inf]]`` twice if exactly one mask is empty.
    """
    # np.asarray handles NumPy arrays and eager tensors alike and avoids a
    # TensorFlow round trip per sample.
    pred = np.asarray(pred, dtype=bool)
    true = np.asarray(true, dtype=bool)

    pred_boundary = get_boundary_points(pred)
    true_boundary = get_boundary_points(true)

    if len(pred_boundary) == 0 or len(true_boundary) == 0:
        if len(pred_boundary) == 0 and len(true_boundary) == 0:
            return np.array([[0]]), np.array([[0]])
        return np.array([[np.inf]]), np.array([[np.inf]])

    def _subsample(points):
        # Evenly spaced indices spanning the entire point list. Slicing with
        # `[::n // max_points][:max_points]` instead drops the tail whenever n is
        # not a multiple of max_points (n=1500 would keep only the first 1000
        # row-major points, i.e. the top part of the contour).
        if max_points is None or max_points <= 0 or len(points) <= max_points:
            return points
        keep = np.linspace(0, len(points), num=max_points, endpoint=False).astype(int)
        return points[keep]

    pred_boundary = _subsample(pred_boundary)
    true_boundary = _subsample(true_boundary)

    est_mem_gb = (len(pred_boundary) * len(true_boundary) * 8) / (1024**3)
    if est_mem_gb > 2.0:
        return compute_distances_chunked(pred_boundary, true_boundary)

    dptt = dist.cdist(pred_boundary, true_boundary, 'euclidean')
    # Euclidean distance is symmetric, so the reverse matrix is just the transpose.
    return dptt, dptt.T


def hausdorff_distance_95(dist_pred_to_true, dist_true_to_pred):
    """
    Symmetric 95th-percentile Hausdorff distance (HD95).

    Each argument is a 2-D array whose row-wise minimum is, per boundary point
    of one mask, the distance to the other mask's boundary. The 95th percentile
    is taken in each direction and the larger one is returned; ``np.inf`` if
    either array is empty.
    """
    if 0 in (dist_pred_to_true.size, dist_true_to_pred.size):
        return np.inf
    directional = [
        np.percentile(np.min(d, axis=1), 95)
        for d in (dist_pred_to_true, dist_true_to_pred)
    ]
    return max(directional)


def average_symmetric_surface_distance(dist_pred_to_true, dist_true_to_pred):
    """
    Average symmetric surface distance (ASD): the mean of the two directional
    mean nearest-boundary distances. ``np.inf`` if either array is empty.
    """
    if 0 in (dist_pred_to_true.size, dist_true_to_pred.size):
        return np.inf
    mean_pred_side = np.mean(np.min(dist_pred_to_true, axis=1))
    mean_true_side = np.mean(np.min(dist_true_to_pred, axis=1))
    return (mean_pred_side + mean_true_side) / 2


def calculate_hd95_asd_combined_mask(model, x_test, y_test, batch_size=8, max_boundary_points=1000, foreground_class_indices=[1, 2, 3]):
    """
    Compute HD95 and ASD over combined foreground (e.g., brain, CSP, LV).

    All classes in ``foreground_class_indices`` are merged into one binary mask
    per image, for both prediction and ground truth, before boundary distances
    are measured.

    Args:
        model: Keras-style model; ``model.predict`` returns per-pixel class
            scores with classes on the last axis.
        x_test: Test inputs passed to ``model.predict``.
        y_test: One-hot ground truth with classes on the last axis; one entry
            per prediction.
        batch_size: Prediction batch size.
        max_boundary_points: Per-mask boundary point cap forwarded to
            ``compute_surface_distances_optimized``.
        foreground_class_indices: Class indices treated as foreground (read only).

    Returns:
        ``(mean_hd95, mean_asd)`` over samples where both metrics are finite,
        or ``(nan, nan)`` if none are. Excluded samples are counted in the log.

    Raises:
        ValueError: if the number of predictions differs from ``len(y_test)``.
    """
    logger.info("🔍 Starting combined-mask HD95 + ASD evaluation")
    check_memory()

    all_hd95 = []
    all_asd = []
    skipped = 0

    y_preds = model.predict(x_test, batch_size=batch_size, verbose=1)
    # Fail clearly instead of an IndexError mid-loop (y_test shorter) or
    # silently ignoring extra labels (y_test longer).
    if len(y_preds) != len(y_test):
        raise ValueError(
            f"Prediction count ({len(y_preds)}) does not match y_test length ({len(y_test)})."
        )

    for i in range(len(y_preds)):
        # Convert softmax outputs / one-hot labels to class-index maps
        y_pred_labels = np.argmax(y_preds[i], axis=-1)
        y_true_labels = np.argmax(y_test[i], axis=-1)

        # Create combined foreground binary masks
        y_pred_binary = np.isin(y_pred_labels, foreground_class_indices)
        y_true_binary = np.isin(y_true_labels, foreground_class_indices)

        dist_pred_to_true, dist_true_to_pred = compute_surface_distances_optimized(
            y_pred_binary, y_true_binary, max_boundary_points
        )

        hd95 = hausdorff_distance_95(dist_pred_to_true, dist_true_to_pred)
        asd = average_symmetric_surface_distance(dist_pred_to_true, dist_true_to_pred)

        if not np.isinf(hd95) and not np.isinf(asd):
            all_hd95.append(hd95)
            all_asd.append(asd)
            logger.info(f"Sample {i:03d} | HD95 = {hd95:.2f} | ASD = {asd:.2f}")
        else:
            skipped += 1
            logger.warning(f"⚠️ Sample {i} skipped due to empty mask")

        del dist_pred_to_true, dist_true_to_pred
        gc.collect()

    if skipped:
        # With the distance helper in this cell, an infinite distance means only
        # one of the two masks was empty (missed or spurious foreground), so
        # excluding these samples makes the means optimistic; report the count.
        logger.warning(f"⚠️ {skipped}/{len(y_preds)} samples excluded from HD95/ASD (infinite distance)")

    if len(all_hd95) == 0:
        logger.error("❌ No valid samples for HD95/ASD computation.")
        return np.nan, np.nan

    # Summary stats
    mean_hd95 = np.mean(all_hd95)
    median_hd95 = np.median(all_hd95)
    mean_asd = np.mean(all_asd)

    logger.info("\n📊 Final Evaluation Results:")
    logger.info(f"Mean HD95: {mean_hd95:.4f}")
    logger.info(f"Median HD95: {median_hd95:.4f}")
    logger.info(f"Mean ASD: {mean_asd:.4f}")

    return mean_hd95, mean_asd

# Combined-foreground HD95 / ASD for the student model.
mean_hd95, mean_asd = calculate_hd95_asd_combined_mask(
    model=student_model,
    x_test=X_test,
    y_test=y_test,
    batch_size=8,
    foreground_class_indices=[1, 2, 3]  # Brain + CSP + LV
)

# Output strings use the intended ✅ character (previously mis-encoded).
print(f"\n✅ Combined HD95: {mean_hd95:.4f}")
print(f"✅ Combined ASD:  {mean_asd:.4f}")

In [None]:
import numpy as np
import tensorflow as tf
import scipy.spatial.distance as dist
from scipy.ndimage import binary_erosion
import psutil
import gc
import logging

# Setup logger. NOTE: logging.basicConfig() does nothing if the root logger
# already has handlers (earlier cells also call it), so re-running this
# line does not change the existing logging configuration.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def check_memory():
    """
    Log a snapshot of system RAM: total, available, used, and percent used.

    Returns:
        The ``psutil.virtual_memory()`` result, for callers that want raw numbers.
    """
    snapshot = psutil.virtual_memory()
    gib = 1024**3
    logger.info(f"Total RAM: {snapshot.total / gib:.2f} GB")
    logger.info(f"Available RAM: {snapshot.available / gib:.2f} GB")
    logger.info(f"Used RAM: {snapshot.used / gib:.2f} GB")
    logger.info(f"Memory usage: {snapshot.percent:.1f}%")
    return snapshot


def get_boundary_points(mask):
    """
    Extract boundary points from a binary mask.

    A pixel is on the boundary if it is foreground in ``mask`` but is removed
    by one ``binary_erosion`` pass (default structuring element; pixels outside
    the array count as background, so foreground touching the array edge is
    boundary).

    Args:
        mask: Array-like mask; any nonzero value is treated as foreground.

    Returns:
        ``(N, mask.ndim)`` integer array of boundary coordinates in
        ``np.argwhere`` (row-major) order; ``(0, mask.ndim)`` if the mask has
        no foreground.
    """
    # Coerce to bool: float masks do not support `&`/`~`, and integer masks with
    # values other than 1 would give wrong results (e.g. 2 & True == 0).
    mask = np.asarray(mask, dtype=bool)
    if not mask.any():
        return np.empty((0, mask.ndim), dtype=int)
    eroded = binary_erosion(mask)
    return np.argwhere(mask & ~eroded)


def compute_distances_chunked(pred_boundary, true_boundary, chunk_size=500):
    """
    Nearest-neighbour boundary distances computed ``chunk_size`` rows at a time.

    Only one ``(chunk_size, M)`` distance block is held in memory at once.

    Returns:
        ``(P, 1)`` and ``(T, 1)`` column vectors: for each pred point the
        distance to the closest true point, and vice versa.
    """
    def _row_minima(src, dst):
        # Minimum distance from every row of `src` to any row of `dst`.
        out = np.full(len(src), np.inf)
        for start in range(0, len(src), chunk_size):
            block = dist.cdist(src[start:start + chunk_size], dst, 'euclidean')
            out[start:start + chunk_size] = block.min(axis=1)
            del block
            gc.collect()
        return out

    min_pred_to_true = _row_minima(pred_boundary, true_boundary)
    min_true_to_pred = _row_minima(true_boundary, pred_boundary)
    return min_pred_to_true.reshape(-1, 1), min_true_to_pred.reshape(-1, 1)


def compute_surface_distances_optimized(pred, true, max_points=1000):
    """
    Efficient surface distance computation using boundary sampling.

    Args:
        pred: Predicted mask; anything ``np.asarray`` can turn into a boolean
            array (NumPy array, eager ``tf.Tensor``, ...). Nonzero = foreground.
        true: Ground-truth mask with the same shape as ``pred``.
        max_points: Maximum number of boundary points kept per mask. Longer
            boundaries are subsampled at evenly spaced positions over the whole
            point list. ``None`` or a value <= 0 disables sampling.

    Returns:
        ``(dist_pred_to_true, dist_true_to_pred)``: 2-D arrays whose row-wise
        minimum is the distance from each boundary point of one mask to the
        other mask's boundary.
        - Normally these are full ``(P, T)`` / ``(T, P)`` matrices.
        - When one full float64 matrix is estimated above 2 GB, they are
          ``(P, 1)`` / ``(T, 1)`` minima from ``compute_distances_chunked``.
        - ``[[0]]`` twice if both masks are empty.
        - ``[[inf]]`` twice if exactly one mask is empty.
    """
    # np.asarray handles NumPy arrays and eager tensors alike and avoids a
    # TensorFlow round trip per sample.
    pred = np.asarray(pred, dtype=bool)
    true = np.asarray(true, dtype=bool)

    pred_boundary = get_boundary_points(pred)
    true_boundary = get_boundary_points(true)

    if len(pred_boundary) == 0 or len(true_boundary) == 0:
        if len(pred_boundary) == 0 and len(true_boundary) == 0:
            return np.array([[0]]), np.array([[0]])
        return np.array([[np.inf]]), np.array([[np.inf]])

    def _subsample(points):
        # Evenly spaced indices spanning the entire point list. Slicing with
        # `[::n // max_points][:max_points]` instead drops the tail whenever n is
        # not a multiple of max_points (n=1500 would keep only the first 1000
        # row-major points, i.e. the top part of the contour).
        if max_points is None or max_points <= 0 or len(points) <= max_points:
            return points
        keep = np.linspace(0, len(points), num=max_points, endpoint=False).astype(int)
        return points[keep]

    pred_boundary = _subsample(pred_boundary)
    true_boundary = _subsample(true_boundary)

    est_mem_gb = (len(pred_boundary) * len(true_boundary) * 8) / (1024**3)
    if est_mem_gb > 2.0:
        return compute_distances_chunked(pred_boundary, true_boundary)

    dptt = dist.cdist(pred_boundary, true_boundary, 'euclidean')
    # Euclidean distance is symmetric, so the reverse matrix is just the transpose.
    return dptt, dptt.T


def hausdorff_distance_95(dist_pred_to_true, dist_true_to_pred):
    """
    Symmetric 95th-percentile Hausdorff distance (HD95).

    Each argument is a 2-D array whose row-wise minimum is, per boundary point
    of one mask, the distance to the other mask's boundary. The 95th percentile
    is taken in each direction and the larger one is returned; ``np.inf`` if
    either array is empty.
    """
    if 0 in (dist_pred_to_true.size, dist_true_to_pred.size):
        return np.inf
    directional = [
        np.percentile(np.min(d, axis=1), 95)
        for d in (dist_pred_to_true, dist_true_to_pred)
    ]
    return max(directional)


def average_symmetric_surface_distance(dist_pred_to_true, dist_true_to_pred):
    """
    Average symmetric surface distance (ASD): the mean of the two directional
    mean nearest-boundary distances. ``np.inf`` if either array is empty.
    """
    if 0 in (dist_pred_to_true.size, dist_true_to_pred.size):
        return np.inf
    mean_pred_side = np.mean(np.min(dist_pred_to_true, axis=1))
    mean_true_side = np.mean(np.min(dist_true_to_pred, axis=1))
    return (mean_pred_side + mean_true_side) / 2


def calculate_hd95_asd_combined_mask(model, x_test, y_test, batch_size=8, max_boundary_points=1000, foreground_class_indices=[1, 2, 3]):
    """
    Compute HD95 and ASD over combined foreground (e.g., brain, CSP, LV).

    All classes in ``foreground_class_indices`` are merged into one binary mask
    per image, for both prediction and ground truth, before boundary distances
    are measured.

    Args:
        model: Keras-style model; ``model.predict`` returns per-pixel class
            scores with classes on the last axis.
        x_test: Test inputs passed to ``model.predict``.
        y_test: One-hot ground truth with classes on the last axis; one entry
            per prediction.
        batch_size: Prediction batch size.
        max_boundary_points: Per-mask boundary point cap forwarded to
            ``compute_surface_distances_optimized``.
        foreground_class_indices: Class indices treated as foreground (read only).

    Returns:
        ``(mean_hd95, mean_asd)`` over samples where both metrics are finite,
        or ``(nan, nan)`` if none are. Excluded samples are counted in the log.

    Raises:
        ValueError: if the number of predictions differs from ``len(y_test)``.
    """
    logger.info("🔍 Starting combined-mask HD95 + ASD evaluation")
    check_memory()

    all_hd95 = []
    all_asd = []
    skipped = 0

    y_preds = model.predict(x_test, batch_size=batch_size, verbose=1)
    # Fail clearly instead of an IndexError mid-loop (y_test shorter) or
    # silently ignoring extra labels (y_test longer).
    if len(y_preds) != len(y_test):
        raise ValueError(
            f"Prediction count ({len(y_preds)}) does not match y_test length ({len(y_test)})."
        )

    for i in range(len(y_preds)):
        # Convert softmax outputs / one-hot labels to class-index maps
        y_pred_labels = np.argmax(y_preds[i], axis=-1)
        y_true_labels = np.argmax(y_test[i], axis=-1)

        # Create combined foreground binary masks
        y_pred_binary = np.isin(y_pred_labels, foreground_class_indices)
        y_true_binary = np.isin(y_true_labels, foreground_class_indices)

        dist_pred_to_true, dist_true_to_pred = compute_surface_distances_optimized(
            y_pred_binary, y_true_binary, max_boundary_points
        )

        hd95 = hausdorff_distance_95(dist_pred_to_true, dist_true_to_pred)
        asd = average_symmetric_surface_distance(dist_pred_to_true, dist_true_to_pred)

        if not np.isinf(hd95) and not np.isinf(asd):
            all_hd95.append(hd95)
            all_asd.append(asd)
            logger.info(f"Sample {i:03d} | HD95 = {hd95:.2f} | ASD = {asd:.2f}")
        else:
            skipped += 1
            logger.warning(f"⚠️ Sample {i} skipped due to empty mask")

        del dist_pred_to_true, dist_true_to_pred
        gc.collect()

    if skipped:
        # With the distance helper in this cell, an infinite distance means only
        # one of the two masks was empty (missed or spurious foreground), so
        # excluding these samples makes the means optimistic; report the count.
        logger.warning(f"⚠️ {skipped}/{len(y_preds)} samples excluded from HD95/ASD (infinite distance)")

    if len(all_hd95) == 0:
        logger.error("❌ No valid samples for HD95/ASD computation.")
        return np.nan, np.nan

    # Summary stats
    mean_hd95 = np.mean(all_hd95)
    median_hd95 = np.median(all_hd95)
    mean_asd = np.mean(all_asd)

    logger.info("\n📊 Final Evaluation Results:")
    logger.info(f"Mean HD95: {mean_hd95:.4f}")
    logger.info(f"Median HD95: {median_hd95:.4f}")
    logger.info(f"Mean ASD: {mean_asd:.4f}")

    return mean_hd95, mean_asd

# Combined-foreground HD95 / ASD for the Xception model.
mean_hd95, mean_asd = calculate_hd95_asd_combined_mask(
    model=model_xception,
    x_test=X_test,
    y_test=y_test,
    batch_size=16,
    foreground_class_indices=[1, 2, 3]  # Brain + CSP + LV
)

# Output strings use the intended ✅ character (previously mis-encoded).
print(f"\n✅ Combined HD95: {mean_hd95:.4f}")
print(f"✅ Combined ASD:  {mean_asd:.4f}")

In [None]:
import numpy as np
import tensorflow as tf
import scipy.spatial.distance as dist
from scipy.ndimage import binary_erosion
import psutil
import gc
import logging
import matplotlib.pyplot as plt

# Setup logger. NOTE: logging.basicConfig() does nothing if the root logger
# already has handlers (earlier cells also call it), so re-running this
# line does not change the existing logging configuration.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

def check_memory():
    """
    Log a snapshot of system RAM: total, available, used, and percent used.

    Returns:
        The ``psutil.virtual_memory()`` result, for callers that want raw numbers.
    """
    snapshot = psutil.virtual_memory()
    gib = 1024**3
    logger.info(f"Total RAM: {snapshot.total / gib:.2f} GB")
    logger.info(f"Available RAM: {snapshot.available / gib:.2f} GB")
    logger.info(f"Used RAM: {snapshot.used / gib:.2f} GB")
    logger.info(f"Memory usage: {snapshot.percent:.1f}%")
    return snapshot

def get_boundary_points(mask):
    """
    Extract boundary points from a binary mask.

    A pixel is on the boundary if it is foreground in ``mask`` but is removed
    by one ``binary_erosion`` pass (default structuring element; pixels outside
    the array count as background, so foreground touching the array edge is
    boundary).

    Args:
        mask: Array-like mask; any nonzero value is treated as foreground.

    Returns:
        ``(N, mask.ndim)`` integer array of boundary coordinates in
        ``np.argwhere`` (row-major) order; ``(0, mask.ndim)`` if the mask has
        no foreground.
    """
    # Coerce to bool: float masks do not support `&`/`~`, and integer masks with
    # values other than 1 would give wrong results (e.g. 2 & True == 0).
    mask = np.asarray(mask, dtype=bool)
    if not mask.any():
        return np.empty((0, mask.ndim), dtype=int)
    eroded = binary_erosion(mask)
    return np.argwhere(mask & ~eroded)

def compute_distances_chunked(pred_boundary, true_boundary, chunk_size=500):
    """
    Nearest-neighbour boundary distances computed ``chunk_size`` rows at a time.

    Only one ``(chunk_size, M)`` distance block is held in memory at once.

    Returns:
        ``(P, 1)`` and ``(T, 1)`` column vectors: for each pred point the
        distance to the closest true point, and vice versa.
    """
    def _row_minima(src, dst):
        # Minimum distance from every row of `src` to any row of `dst`.
        out = np.full(len(src), np.inf)
        for start in range(0, len(src), chunk_size):
            block = dist.cdist(src[start:start + chunk_size], dst, 'euclidean')
            out[start:start + chunk_size] = block.min(axis=1)
            del block
            gc.collect()
        return out

    min_pred_to_true = _row_minima(pred_boundary, true_boundary)
    min_true_to_pred = _row_minima(true_boundary, pred_boundary)
    return min_pred_to_true.reshape(-1, 1), min_true_to_pred.reshape(-1, 1)

def compute_surface_distances_optimized(pred, true, max_points=1000):
    """
    Efficient surface distance computation using boundary sampling.

    Args:
        pred: Predicted mask; anything ``np.asarray`` can turn into a boolean
            array (NumPy array, eager ``tf.Tensor``, ...). Nonzero = foreground.
        true: Ground-truth mask with the same shape as ``pred``.
        max_points: Maximum number of boundary points kept per mask. Longer
            boundaries are subsampled at evenly spaced positions over the whole
            point list. ``None`` or a value <= 0 disables sampling.

    Returns:
        ``(dist_pred_to_true, dist_true_to_pred)``: 2-D arrays whose row-wise
        minimum is the distance from each boundary point of one mask to the
        other mask's boundary.
        - Normally these are full ``(P, T)`` / ``(T, P)`` matrices.
        - When one full float64 matrix is estimated above 2 GB, they are
          ``(P, 1)`` / ``(T, 1)`` minima from ``compute_distances_chunked``.
        - ``[[0]]`` twice if both masks are empty.
        - ``[[inf]]`` twice if exactly one mask is empty.
    """
    # np.asarray handles NumPy arrays and eager tensors alike and avoids a
    # TensorFlow round trip per sample.
    pred = np.asarray(pred, dtype=bool)
    true = np.asarray(true, dtype=bool)

    pred_boundary = get_boundary_points(pred)
    true_boundary = get_boundary_points(true)

    if len(pred_boundary) == 0 or len(true_boundary) == 0:
        if len(pred_boundary) == 0 and len(true_boundary) == 0:
            return np.array([[0]]), np.array([[0]])
        return np.array([[np.inf]]), np.array([[np.inf]])

    def _subsample(points):
        # Evenly spaced indices spanning the entire point list. Slicing with
        # `[::n // max_points][:max_points]` instead drops the tail whenever n is
        # not a multiple of max_points (n=1500 would keep only the first 1000
        # row-major points, i.e. the top part of the contour).
        if max_points is None or max_points <= 0 or len(points) <= max_points:
            return points
        keep = np.linspace(0, len(points), num=max_points, endpoint=False).astype(int)
        return points[keep]

    pred_boundary = _subsample(pred_boundary)
    true_boundary = _subsample(true_boundary)

    est_mem_gb = (len(pred_boundary) * len(true_boundary) * 8) / (1024**3)
    if est_mem_gb > 2.0:
        return compute_distances_chunked(pred_boundary, true_boundary)

    dptt = dist.cdist(pred_boundary, true_boundary, 'euclidean')
    # Euclidean distance is symmetric, so the reverse matrix is just the transpose.
    return dptt, dptt.T

def hausdorff_distance_95(dist_pred_to_true, dist_true_to_pred):
    """
    Symmetric 95th-percentile Hausdorff distance (HD95).

    Each argument is a 2-D array whose row-wise minimum is, per boundary point
    of one mask, the distance to the other mask's boundary. The 95th percentile
    is taken in each direction and the larger one is returned; ``np.inf`` if
    either array is empty.
    """
    if 0 in (dist_pred_to_true.size, dist_true_to_pred.size):
        return np.inf
    directional = [
        np.percentile(np.min(d, axis=1), 95)
        for d in (dist_pred_to_true, dist_true_to_pred)
    ]
    return max(directional)

def average_symmetric_surface_distance(dist_pred_to_true, dist_true_to_pred):
    """
    Average symmetric surface distance (ASD): the mean of the two directional
    mean nearest-boundary distances. ``np.inf`` if either array is empty.
    """
    if 0 in (dist_pred_to_true.size, dist_true_to_pred.size):
        return np.inf
    mean_pred_side = np.mean(np.min(dist_pred_to_true, axis=1))
    mean_true_side = np.mean(np.min(dist_true_to_pred, axis=1))
    return (mean_pred_side + mean_true_side) / 2


# Storage for visualization: calculate_hd95_asd_combined_mask (below) appends
# (sample_index, hd95, pred_labels, true_labels) tuples here. Re-running this
# cell resets the list; calling the evaluation twice without re-running it
# accumulates duplicates.
worst_hd95_results = list()

def calculate_hd95_asd_combined_mask(model, x_test, y_test, batch_size=8, max_boundary_points=1000, foreground_class_indices=[1, 2, 3]):
    """
    Compute HD95 and ASD over combined foreground and record per-sample results.

    All classes in ``foreground_class_indices`` are merged into one binary mask
    per image, for both prediction and ground truth. For every sample with
    finite metrics, a ``(sample_index, hd95, pred_labels, true_labels)`` tuple
    is appended to the module-level ``worst_hd95_results`` list for later
    plotting.

    Args:
        model: Keras-style model; ``model.predict`` returns per-pixel class
            scores with classes on the last axis.
        x_test: Test inputs passed to ``model.predict``.
        y_test: One-hot ground truth with classes on the last axis; one entry
            per prediction.
        batch_size: Prediction batch size.
        max_boundary_points: Per-mask boundary point cap forwarded to
            ``compute_surface_distances_optimized``.
        foreground_class_indices: Class indices treated as foreground (read only).

    Returns:
        ``(mean_hd95, mean_asd)`` over samples where both metrics are finite,
        or ``(nan, nan)`` if none are. Excluded samples are counted in the log.

    Raises:
        ValueError: if the number of predictions differs from ``len(y_test)``.
    """
    logger.info("🔍 Starting combined-mask HD95 + ASD evaluation")
    check_memory()

    all_hd95 = []
    all_asd = []
    skipped = 0

    y_preds = model.predict(x_test, batch_size=batch_size, verbose=1)
    # Fail clearly instead of an IndexError mid-loop (y_test shorter) or
    # silently ignoring extra labels (y_test longer).
    if len(y_preds) != len(y_test):
        raise ValueError(
            f"Prediction count ({len(y_preds)}) does not match y_test length ({len(y_test)})."
        )

    for i in range(len(y_preds)):
        y_pred_labels = np.argmax(y_preds[i], axis=-1)
        y_true_labels = np.argmax(y_test[i], axis=-1)

        y_pred_binary = np.isin(y_pred_labels, foreground_class_indices)
        y_true_binary = np.isin(y_true_labels, foreground_class_indices)

        dist_pred_to_true, dist_true_to_pred = compute_surface_distances_optimized(
            y_pred_binary, y_true_binary, max_boundary_points
        )

        hd95 = hausdorff_distance_95(dist_pred_to_true, dist_true_to_pred)
        asd = average_symmetric_surface_distance(dist_pred_to_true, dist_true_to_pred)

        if not np.isinf(hd95) and not np.isinf(asd):
            all_hd95.append(hd95)
            all_asd.append(asd)
            # Keep both label maps for plotting, stored in the smallest integer
            # dtype that fits. argmax returns the platform index type (typically
            # int64), which is 8x larger and is kept for every valid sample.
            label_dtype = np.min_scalar_type(int(max(y_pred_labels.max(), y_true_labels.max())))
            worst_hd95_results.append(
                (i, hd95, y_pred_labels.astype(label_dtype), y_true_labels.astype(label_dtype))
            )
            logger.info(f"Sample {i:03d} | HD95 = {hd95:.2f} | ASD = {asd:.2f}")
        else:
            skipped += 1
            logger.warning(f"⚠️ Sample {i} skipped due to empty mask")

        del dist_pred_to_true, dist_true_to_pred
        gc.collect()

    if skipped:
        # With the distance helper in this cell, an infinite distance means only
        # one of the two masks was empty (missed or spurious foreground), so
        # excluding these samples makes the means optimistic; report the count.
        logger.warning(f"⚠️ {skipped}/{len(y_preds)} samples excluded from HD95/ASD (infinite distance)")

    if len(all_hd95) == 0:
        logger.error("❌ No valid samples for HD95/ASD computation.")
        return np.nan, np.nan

    mean_hd95 = np.mean(all_hd95)
    median_hd95 = np.median(all_hd95)
    mean_asd = np.mean(all_asd)

    logger.info("\n📊 Final Evaluation Results:")
    logger.info(f"Mean HD95: {mean_hd95:.4f}")
    logger.info(f"Median HD95: {median_hd95:.4f}")
    logger.info(f"Mean ASD: {mean_asd:.4f}")

    return mean_hd95, mean_asd


# ‚úÖ New: Visualization function
def visualize_worst_predictions(X_test, worst_results, top_n=5, class_colors=None):
    if class_colors is None:
        # Default class color map: background, brain, CSP, LV
        class_colors = {
            0: [0, 0, 0],
            1: [255, 0, 0],
            2: [0, 255, 0],
            3: [0, 0, 255],
        }

    worst_results = sorted(worst_results, key=lambda x: x[1], reverse=True)[:top_n]

    for idx, (i, hd95, y_pred_label, y_true_label) in enumerate(worst_results):
        img = X_test[i]

        def color_mask(label_mask):
            color_mask = np.zeros((label_mask.shape[0], label_mask.shape[1], 3), dtype=np.uint8)
            for class_idx, color in class_colors.items():
                color_mask[label_mask == class_idx] = color
            return color_mask

        pred_rgb = color_mask(y_pred_label)
        true_rgb = color_mask(y_true_label)
        error_map = (y_pred_label != y_true_label).astype(np.uint8)

        fig, axs = plt.subplots(1, 4, figsize=(20, 5))
        axs[0].imshow(img)
        axs[0].set_title(f"Input Image [{i}]")
        axs[1].imshow(true_rgb)
        axs[1].set_title("Ground Truth")
        axs[2].imshow(pred_rgb)
        axs[2].set_title("Prediction")
        axs[3].imshow(error_map, cmap='hot')
        axs[3].set_title(f"Error Map (HD95={hd95:.2f})")

        for ax in axs:
            ax.axis('off')
        plt.tight_layout()
        plt.show()

# Evaluate the student model; each valid sample is also appended to
# `worst_hd95_results` for the plots below.
mean_hd95, mean_asd = calculate_hd95_asd_combined_mask(
    student_model, X_test, y_test, batch_size=8, foreground_class_indices=[1, 2, 3]
)

# Show the five samples with the highest HD95.
visualize_worst_predictions(X_test, worst_results=worst_hd95_results, top_n=5)

In [None]:
import numpy as np
import tensorflow as tf
from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score
from scipy.spatial.distance import directed_hausdorff
from tensorflow.keras.utils import to_categorical

# Colour -> class-index lookup used to decode colour-coded test masks.
RGB_TO_CLASS = dict([
    ((255, 0, 0), 1),  # Brain
    ((0, 255, 0), 2),  # CSP
    ((0, 0, 255), 3),  # LV
    ((0, 0, 0), 0),    # Background
])

# Convert colour-coded RGB masks to class-index masks
def rgb_to_class_mask(rgb_mask, color_map=None):
    """
    Convert a colour-coded mask to a 2-D class-index mask.

    Args:
        rgb_mask: (H, W, C) array with C >= 3. Only the first three channels are
            compared, so RGBA masks are accepted as well.
        color_map: Optional mapping ``(R, G, B) -> class index``; defaults to
            the module-level ``RGB_TO_CLASS``.

    Returns:
        (H, W) integer array. Pixels whose colour is not in the map stay 0
        (background).
    """
    if color_map is None:
        color_map = RGB_TO_CLASS
    rgb = np.asarray(rgb_mask)[..., :3]
    class_mask = np.zeros(rgb.shape[:2], dtype=int)
    for color, class_idx in color_map.items():
        # Exact colour match on all three channels
        class_mask[np.all(rgb == np.array(color), axis=-1)] = class_idx
    return class_mask

# Dice Similarity Coefficient (DSC)
def dice_coefficient(y_true, y_pred):
    """
    Dice score 2|A∩B| / (|A| + |B|) for binary masks.

    A tiny epsilon is added to numerator and denominator, so two empty masks
    score 1.0.
    """
    eps = 1e-15
    overlap = np.sum(y_true * y_pred)
    total = np.sum(y_true) + np.sum(y_pred)
    numerator = 2. * overlap + eps
    return numerator / (total + eps)

# IoU (Intersection over Union)
def iou(y_true, y_pred):
    """
    Intersection over union |A∩B| / |A∪B| for binary masks.

    Smoothed the same way as ``dice_coefficient``: two empty masks score 1.0
    rather than 0.0. This keeps IoU = Dice / (2 - Dice) true for every sample,
    so per-class averages of the two metrics stay comparable.
    """
    smooth = 1e-15
    intersection = np.sum(y_true * y_pred)
    union = np.sum(y_true) + np.sum(y_pred) - intersection
    return (intersection + smooth) / (union + smooth)

# Hausdorff Distance
def hausdorff_distance(y_true, y_pred):
    """
    Symmetric Hausdorff distance, in pixels, between the sets of pixels equal
    to 1 in each mask.

    All foreground pixels are used, not only boundary pixels. Returns ``inf``
    if either mask has no such pixel.
    """
    true_points = np.argwhere(y_true == 1)
    pred_points = np.argwhere(y_pred == 1)

    if len(true_points) == 0 or len(pred_points) == 0:
        return float('inf')  # Return inf if no points for either true or pred class

    directed = (
        directed_hausdorff(true_points, pred_points)[0],
        directed_hausdorff(pred_points, true_points)[0],
    )
    return max(directed)

# Average Surface Distance (ASD), one-directional: ground truth -> prediction
def average_surface_distance(y_true, y_pred):
    """
    Mean distance from each pixel equal to 1 in ``y_true`` to the nearest
    pixel equal to 1 in ``y_pred``.

    Despite the name, this uses all foreground pixels (not only boundaries)
    and only the true->pred direction. Returns ``inf`` if either mask is empty.
    """
    true_points = np.argwhere(y_true == 1)
    pred_points = np.argwhere(y_pred == 1)

    if len(true_points) == 0 or len(pred_points) == 0:
        return float('inf')  # Return inf if no points for either true or pred class

    nearest = [np.linalg.norm(pred_points - point, axis=1).min() for point in true_points]
    return np.mean(nearest)

# Evaluate the model on the test set class-wise
def evaluate_classwise_metrics(model, X_test, y_test, num_classes=4, batch_size=8):
    """
    Print per-class metrics averaged over test samples, then ``model.evaluate`` results.

    For each sample and each class, the class is treated as a one-vs-rest
    binary mask. Dice, IoU, precision, recall, F1 and pixel accuracy are
    computed and then averaged per class.

    Args:
        model: Keras-style model exposing ``predict``, ``evaluate`` and
            ``metrics_names``.
        X_test: Test inputs.
        y_test: One-hot ground truth; ``argmax`` over the last axis gives class
            indices.
        num_classes: Number of classes (0: Background, 1: Brain, 2: CSP, 3: LV).
        batch_size: Batch size for prediction and evaluation.
    """
    # Predict in batches and convert scores to class-index maps
    y_pred = np.argmax(model.predict(X_test, batch_size=batch_size), axis=-1)
    # y_test is one-hot encoded -> class-index maps
    y_test_class = np.argmax(y_test, axis=-1)

    # 'hausdorff' / 'asd' are kept for the optional (disabled) calls below.
    class_metrics = {i: {'dice': [], 'iou': [], 'precision': [], 'recall': [], 'f1': [], 'accuracy': [], 'hausdorff': [], 'asd': []} for i in range(num_classes)}

    for i in range(len(X_test)):
        true_mask = y_test_class[i]
        pred_mask = y_pred[i]

        for class_idx in range(num_classes):
            true_class_mask = (true_mask == class_idx).astype(int)
            pred_class_mask = (pred_mask == class_idx).astype(int)
            true_flat = true_class_mask.ravel()
            pred_flat = pred_class_mask.ravel()

            metrics = class_metrics[class_idx]
            metrics['dice'].append(dice_coefficient(true_class_mask, pred_class_mask))
            metrics['iou'].append(iou(true_class_mask, pred_class_mask))
            metrics['precision'].append(precision_score(true_flat, pred_flat, zero_division=0))
            metrics['recall'].append(recall_score(true_flat, pred_flat, zero_division=0))
            metrics['f1'].append(f1_score(true_flat, pred_flat, zero_division=0))
            metrics['accuracy'].append(accuracy_score(true_flat, pred_flat))
            # Optional per-class surface distances (disabled):
            # metrics['hausdorff'].append(hausdorff_distance(true_class_mask, pred_class_mask))
            # metrics['asd'].append(average_surface_distance(true_class_mask, pred_class_mask))

    # The header only names what is printed below; the per-class output is a
    # list, not a table, and HD/ASD are not computed here.
    print("Class-wise metrics (mean over test samples)")
    print('-' * 180)

    for class_idx in range(num_classes):
        metrics = class_metrics[class_idx]
        print(f"Class {class_idx}:")
        print(f"  Dice Coefficient: {np.mean(metrics['dice']) * 100:.2f}%")
        print(f"  IoU: {np.mean(metrics['iou']) * 100:.2f}%")
        print(f"  Precision: {np.mean(metrics['precision']) * 100:.2f}%")
        print(f"  Recall: {np.mean(metrics['recall']) * 100:.2f}%")
        print(f"  F1 Score: {np.mean(metrics['f1']) * 100:.2f}%")
        print(f"  Accuracy: {np.mean(metrics['accuracy']) * 100:.2f}%")
        print("-" * 180)

    # model.evaluate returns a bare scalar when the model was compiled without
    # metrics; flatten so the unpacking below works in both cases.
    results = list(np.ravel(model.evaluate(X_test, y_test, batch_size=batch_size)))
    test_loss, *test_metrics = results
    print(f"Test Loss: {test_loss:.4f}")

    for metric, value in zip(model.metrics_names[1:], test_metrics):
        print(f"{metric}: {value:.4f}")

# Class-wise evaluation of the student model on the test set
evaluate_classwise_metrics(model=student_model, X_test=X_test, y_test=y_test)

In [None]:
import numpy as np
import tensorflow as tf
from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score
from scipy.spatial.distance import directed_hausdorff
from tensorflow.keras.utils import to_categorical

# Colour -> class-index lookup used to decode colour-coded test masks.
RGB_TO_CLASS = dict([
    ((255, 0, 0), 1),  # Brain
    ((0, 255, 0), 2),  # CSP
    ((0, 0, 255), 3),  # LV
    ((0, 0, 0), 0),    # Background
])

# Convert colour-coded RGB masks to class-index masks
def rgb_to_class_mask(rgb_mask, color_map=None):
    """
    Convert a colour-coded mask to a 2-D class-index mask.

    Args:
        rgb_mask: (H, W, C) array with C >= 3. Only the first three channels are
            compared, so RGBA masks are accepted as well.
        color_map: Optional mapping ``(R, G, B) -> class index``; defaults to
            the module-level ``RGB_TO_CLASS``.

    Returns:
        (H, W) integer array. Pixels whose colour is not in the map stay 0
        (background).
    """
    if color_map is None:
        color_map = RGB_TO_CLASS
    rgb = np.asarray(rgb_mask)[..., :3]
    class_mask = np.zeros(rgb.shape[:2], dtype=int)
    for color, class_idx in color_map.items():
        # Exact colour match on all three channels
        class_mask[np.all(rgb == np.array(color), axis=-1)] = class_idx
    return class_mask

# Dice Similarity Coefficient (DSC)
def dice_coefficient(y_true, y_pred):
    """
    Dice score 2|A∩B| / (|A| + |B|) for binary masks.

    A tiny epsilon is added to numerator and denominator, so two empty masks
    score 1.0.
    """
    eps = 1e-15
    overlap = np.sum(y_true * y_pred)
    total = np.sum(y_true) + np.sum(y_pred)
    numerator = 2. * overlap + eps
    return numerator / (total + eps)

# IoU (Intersection over Union)
def iou(y_true, y_pred):
    """
    Intersection over union |A∩B| / |A∪B| for binary masks.

    Smoothed the same way as ``dice_coefficient``: two empty masks score 1.0
    rather than 0.0. This keeps IoU = Dice / (2 - Dice) true for every sample,
    so per-class averages of the two metrics stay comparable.
    """
    smooth = 1e-15
    intersection = np.sum(y_true * y_pred)
    union = np.sum(y_true) + np.sum(y_pred) - intersection
    return (intersection + smooth) / (union + smooth)

def hausdorff_distance_debug(y_true, y_pred, threshold=0.5, verbose=True):
    """
    Hausdorff distance between the foreground pixel sets of two masks, with
    optional debug printing.

    Args:
        y_true: Ground-truth mask or probability map; values ``> threshold``
            are foreground.
        y_pred: Predicted mask or probability map, thresholded the same way.
        threshold: Binarisation threshold applied to both inputs.
        verbose: If True (default), print shapes / unique values and an
            empty-mask notice. Pass False when looping over many samples.

    Returns:
        Symmetric Hausdorff distance in pixels over all foreground pixels (not
        only boundaries), or ``inf`` if either mask is empty.
    """
    # Apply threshold to convert predicted probabilities (if any) to binary
    y_true = (y_true > threshold).astype(np.uint8)
    y_pred = (y_pred > threshold).astype(np.uint8)

    if verbose:
        print(f"[HD] y_true shape: {y_true.shape}, unique: {np.unique(y_true)}")
        print(f"[HD] y_pred shape: {y_pred.shape}, unique: {np.unique(y_pred)}")

    true_points = np.array(np.where(y_true == 1)).T
    pred_points = np.array(np.where(y_pred == 1)).T

    if len(true_points) == 0 or len(pred_points) == 0:
        if verbose:
            # Message uses the intended ❌ and em dash (previously mis-encoded).
            print("[HD] ❌ Empty mask detected — returning inf")
        return float('inf')

    forward_hausdorff = directed_hausdorff(true_points, pred_points)[0]
    reverse_hausdorff = directed_hausdorff(pred_points, true_points)[0]
    return max(forward_hausdorff, reverse_hausdorff)


def average_surface_distance_debug(y_true, y_pred, threshold=0.5, verbose=True):
    """Mean distance from each foreground pixel of ``y_true`` to the nearest one of ``y_pred``.

    NOTE(review): despite the name, this uses *all* foreground pixels (not only
    boundary/surface pixels) and is one-directional (true -> pred only), so it
    is not the symmetric ASD usually reported for segmentation.

    Args:
        y_true: ground-truth mask, binarised with ``> threshold``.
        y_pred: predicted mask, binarised with ``> threshold``.
        threshold: binarisation cut-off.
        verbose: print shapes / unique values and empty-mask warnings
            (default True, the original behaviour).

    Returns:
        Mean Euclidean distance in pixel-index units, or ``inf`` when either
        mask has no foreground pixel.
    """
    # Local import keeps this cell runnable on a fresh kernel.
    from scipy.spatial import cKDTree

    y_true = (y_true > threshold).astype(np.uint8)
    y_pred = (y_pred > threshold).astype(np.uint8)

    if verbose:
        print(f"[ASD] y_true shape: {y_true.shape}, unique: {np.unique(y_true)}")
        print(f"[ASD] y_pred shape: {y_pred.shape}, unique: {np.unique(y_pred)}")

    true_points = np.argwhere(y_true == 1)
    pred_points = np.argwhere(y_pred == 1)

    if len(true_points) == 0 or len(pred_points) == 0:
        if verbose:
            print("[ASD] ❌ Empty mask detected — returning inf")
        return float('inf')

    # KD-tree nearest-neighbour query replaces a Python loop that measured every
    # true pixel against every predicted pixel (O(N*M); very slow on large
    # classes such as background in 224x224 masks).
    distances, _ = cKDTree(pred_points).query(true_points)
    return np.mean(distances)

def _finite_mean(values):
    """Return ``(mean of the finite entries or nan, number of finite entries)``."""
    arr = np.asarray(values, dtype=float)
    finite = arr[np.isfinite(arr)]
    return (finite.mean() if finite.size else float('nan')), int(finite.size)


# Evaluate the model on the test set, class by class
def evaluate_classwise_metrics(model, X_test, y_test, num_classes=4, batch_size=8):
    """Print per-class Hausdorff / average surface distance and the model's compiled metrics.

    Args:
        model: Keras model; ``predict`` output is arg-maxed over the last axis to
            get predicted class indices, ``evaluate`` supplies compiled metrics.
        X_test: test inputs passed to ``predict`` / ``evaluate``.
        y_test: one-hot ground-truth masks (class axis last).
        num_classes: classes evaluated, indices ``0 .. num_classes - 1``.
        batch_size: batch size for ``predict`` and ``evaluate``.

    Distances come from ``hausdorff_distance_debug`` and
    ``average_surface_distance_debug``, which return ``inf`` when a class is
    absent from the ground truth or the prediction. Those samples are excluded
    from the printed means (one ``inf`` would otherwise make the whole class
    mean ``inf``); the number of samples actually averaged is printed too.
    """
    y_pred = np.argmax(model.predict(X_test, batch_size=batch_size), axis=-1)
    y_test_class = np.argmax(y_test, axis=-1)  # one-hot -> class index

    class_metrics = {i: {'hausdorff': [], 'asd': []} for i in range(num_classes)}

    for i in range(len(X_test)):
        true_mask = y_test_class[i]
        pred_mask = y_pred[i]
        for class_idx in range(num_classes):
            true_class_mask = (true_mask == class_idx).astype(int)
            pred_class_mask = (pred_mask == class_idx).astype(int)
            class_metrics[class_idx]['hausdorff'].append(
                hausdorff_distance_debug(true_class_mask, pred_class_mask))
            class_metrics[class_idx]['asd'].append(
                average_surface_distance_debug(true_class_mask, pred_class_mask))

    n_samples = len(X_test)
    # Header lists only what is actually reported below (distances in pixels).
    print("Class-wise boundary metrics (pixel units, empty-mask samples excluded)")
    print('-' * 80)
    for class_idx in range(num_classes):
        hd_mean, hd_n = _finite_mean(class_metrics[class_idx]['hausdorff'])
        asd_mean, asd_n = _finite_mean(class_metrics[class_idx]['asd'])
        print(f"Class {class_idx}:")
        print(f"  Hausdorff Distance: {hd_mean:.4f} (over {hd_n}/{n_samples} samples)")
        print(f"  Average Surface Distance: {asd_mean:.4f} (over {asd_n}/{n_samples} samples)")
        print('-' * 80)

    # Overall compiled metrics. The original loop zipped model.metrics_names[1:]
    # with an undefined `test_metrics` (NameError); fetch them from evaluate()
    # as a name -> value dict, which is robust to Keras version differences.
    test_metrics = model.evaluate(X_test, y_test, batch_size=batch_size, verbose=0, return_dict=True)
    for metric, value in test_metrics.items():
        if metric == 'loss':
            continue  # the original skipped the first entry (the loss)
        print(f"{metric}: {value:.4f}")

# Run the class-wise boundary evaluation on the student model's test-set predictions.
evaluate_classwise_metrics(model=student_model, X_test=X_test, y_test=y_test)

In [None]:
# Debug: compare ground-truth vs. prediction shapes.
# NOTE(review): `y_true` / `y_pred` are only assigned as locals inside the metric
# functions in this section; unless they are defined at notebook level elsewhere,
# this cell raises NameError on a fresh kernel — TODO confirm the intended arrays
# (e.g. `y_test` and the model's predictions).
print(f"true shape: {y_true.shape}, pred shape: {y_pred.shape}")

In [None]:
import numpy as np
import tensorflow as tf
from sklearn.metrics import accuracy_score
from scipy.spatial.distance import directed_hausdorff
import matplotlib.pyplot as plt

# --- METRICS ---

def dice_coefficient(y_true, y_pred):
    """Smoothed Dice score between two binary masks.

    NOTE(review): this re-defines the identical function from an earlier cell.
    Two empty masks score 1.0 thanks to the 1e-15 smoothing term.
    """
    eps = 1e-15
    shared = np.sum(y_true * y_pred)
    sizes = np.sum(y_true) + np.sum(y_pred)
    return (2. * shared + eps) / (sizes + eps)

def iou(y_true, y_pred):
    """Intersection-over-Union of two binary masks (0.0 when both are empty).

    NOTE(review): this re-defines the identical function from an earlier cell.
    """
    overlap = np.sum(y_true * y_pred)
    denom = np.sum(y_true) + np.sum(y_pred) - overlap
    return overlap / (denom + 1e-15)

def hausdorff_distance_debug(y_true, y_pred, threshold=0.5, verbose=True):
    """Symmetric Hausdorff distance between two binarised masks.

    Same contract as the earlier definition in this notebook: both masks are
    binarised with ``> threshold`` and the larger directed Hausdorff distance
    between their foreground pixel coordinates is returned (pixel-index units).

    Args:
        y_true: ground-truth mask.
        y_pred: predicted mask.
        threshold: binarisation cut-off (was hard-coded to 0.5; default unchanged).
        verbose: print shape / unique-value diagnostics (default True, as before).

    Returns:
        float distance, or ``inf`` when either mask has no foreground pixel.
    """
    y_true = (y_true > threshold).astype(np.uint8)
    y_pred = (y_pred > threshold).astype(np.uint8)
    if verbose:
        print(f"[HD] y_true shape: {y_true.shape}, unique: {np.unique(y_true)}")
        print(f"[HD] y_pred shape: {y_pred.shape}, unique: {np.unique(y_pred)}")

    true_points = np.argwhere(y_true == 1)
    pred_points = np.argwhere(y_pred == 1)

    if len(true_points) == 0 or len(pred_points) == 0:
        return float('inf')
    return max(
        directed_hausdorff(true_points, pred_points)[0],
        directed_hausdorff(pred_points, true_points)[0]
    )

def average_surface_distance_debug(y_true, y_pred, threshold=0.5, verbose=True):
    """Mean distance from each ``y_true`` foreground pixel to the nearest ``y_pred`` one.

    NOTE(review): uses all foreground pixels (not just the boundary) and only
    the true -> pred direction, so it is not the symmetric surface ASD.

    Args:
        y_true: ground-truth mask, binarised with ``> threshold``.
        y_pred: predicted mask, binarised with ``> threshold``.
        threshold: binarisation cut-off (was hard-coded to 0.5; default unchanged).
        verbose: print shape / unique-value diagnostics (default True, as before).

    Returns:
        Mean Euclidean distance in pixel-index units, or ``inf`` if either mask is empty.
    """
    from scipy.spatial import cKDTree

    y_true = (y_true > threshold).astype(np.uint8)
    y_pred = (y_pred > threshold).astype(np.uint8)
    if verbose:
        print(f"[ASD] y_true shape: {y_true.shape}, unique: {np.unique(y_true)}")
        print(f"[ASD] y_pred shape: {y_pred.shape}, unique: {np.unique(y_pred)}")

    true_points = np.argwhere(y_true == 1)
    pred_points = np.argwhere(y_pred == 1)
    if len(true_points) == 0 or len(pred_points) == 0:
        return float('inf')

    # KD-tree nearest-neighbour lookup instead of an O(N*M) Python loop.
    distances, _ = cKDTree(pred_points).query(true_points)
    return np.mean(distances)

# --- EVALUATION ---

def evaluate_with_debug(model, X_test, y_test, num_classes=4, batch_size=8, threshold=0.3, use_argmax=False):
    """Collect per-sample, per-class Dice, IoU, Hausdorff distance and ASD.

    Args:
        model: model exposing ``predict``; its output is indexed as
            ``y_pred[i, :, :, class_idx]`` (assumed (N, H, W, C) — confirm).
        X_test: test inputs passed to ``model.predict``.
        y_test: one-hot ground truth; decoded with arg-max over the last axis.
        num_classes: classes evaluated, indices ``0 .. num_classes - 1``.
        batch_size: batch size for ``model.predict``.
        threshold: cut-off used to binarise each predicted class channel when
            ``use_argmax`` is False. NOTE(review): only meaningful if the model
            outputs probabilities, not logits — confirm for the student.
        use_argmax: if True, predicted masks come from the per-pixel arg-max
            (mutually exclusive classes, decoded like ``y_test``). The default
            False keeps the original per-channel thresholding.

    Returns:
        dict ``{class_idx: {'dice': [...], 'iou': [...], 'hd': [...], 'asd': [...]}}``
        with one entry per sample; ``hd``/``asd`` are NaN when either mask is empty.
    """
    y_pred = model.predict(X_test, batch_size=batch_size)
    y_true_class = np.argmax(y_test, axis=-1)
    # Previously computed unconditionally and never used; now backs use_argmax.
    y_pred_class = np.argmax(y_pred, axis=-1) if use_argmax else None

    metrics = {i: {'dice': [], 'iou': [], 'hd': [], 'asd': []} for i in range(num_classes)}

    for i in range(len(X_test)):
        for class_idx in range(num_classes):
            true_mask = (y_true_class[i] == class_idx).astype(np.uint8)
            if use_argmax:
                pred_mask = (y_pred_class[i] == class_idx).astype(np.uint8)
            else:
                # Independent per-channel threshold: a pixel may count for several classes.
                pred_mask = (y_pred[i, :, :, class_idx] > threshold).astype(np.uint8)

            dice = dice_coefficient(true_mask, pred_mask)
            iou_score = iou(true_mask, pred_mask)

            # Distances are undefined for an empty mask; record NaN so that
            # nanmean-based summaries skip them.
            if np.any(true_mask) and np.any(pred_mask):
                hd = hausdorff_distance_debug(true_mask, pred_mask)
                asd = average_surface_distance_debug(true_mask, pred_mask)
            else:
                hd = np.nan
                asd = np.nan

            metrics[class_idx]['dice'].append(dice)
            metrics[class_idx]['iou'].append(iou_score)
            metrics[class_idx]['hd'].append(hd)
            metrics[class_idx]['asd'].append(asd)

    return metrics

# --- SUMMARY ---

def summarize_metrics(metrics):
    """Turn ``evaluate_with_debug`` output into a one-row-per-class DataFrame.

    Dice and IoU are reported as percentages; Hausdorff and ASD are left in
    their native units. All means ignore NaN entries (``np.nanmean``).

    Args:
        metrics: ``{class_idx: {'dice': [...], 'iou': [...], 'hd': [...], 'asd': [...]}}``.

    Returns:
        pandas.DataFrame with columns Class, Dice (%), IoU (%), Hausdorff, ASD.
    """
    import pandas as pd
    records = [
        {
            'Class': cls,
            'Dice (%)': np.nanmean(vals['dice']) * 100,
            'IoU (%)': np.nanmean(vals['iou']) * 100,
            'Hausdorff': np.nanmean(vals['hd']),
            'ASD': np.nanmean(vals['asd']),
        }
        for cls, vals in metrics.items()
    ]
    return pd.DataFrame(records)

# --- USAGE EXAMPLE ---
# Per-sample, per-class Dice / IoU / HD / ASD for the student, then a per-class summary table.
metrics = evaluate_with_debug(model=student_model, X_test=X_test, y_test=y_test)
summary_df = summarize_metrics(metrics=metrics)

In [None]:
# Rich (HTML) display of the per-class summary instead of a plain-text print.
summary_df

In [None]:
# Clean test-set evaluation of the distilled student (KDTrainer.test_step scores the student only).
# NOTE(review): `kd_model` and `batch_size` are assigned in later cells of this section — confirm they exist on a fresh Run All.
kd_model.evaluate(x=X_test, y=y_test, batch_size=batch_size)

In [None]:
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from sklearn.preprocessing import MinMaxScaler

# Raw efficiency figures per model; arrows in the column names mark the
# direction the author labels as better (↓ lower, ↑ higher).
data = {
    "Model": ["Unet++(InceptionResNetV2)", "AttentionUnet(EfficientNetB4)", "Deeplabv3+(Xception)", "CNN(SegNet)", "LightSeg"],
    "Training Time (s) ↓": [190.62, 149.88, 102.30, 250.62, 164.84],
    "GFLOPs ↑": [34.78, 8.26, 17.15, 85.32, 18.89],
    "Energy (Wh) ↓": [5.56, 4.00, 3.28, 9.93, 5.51],
    "Params (M) ↓": [69.3, 29.7, 37.8, 31.4, 1.33],
    "Size (MB) ↓": [796, 114, 432, 359, 5.35]
}
df = pd.DataFrame(data)

# Metric columns, named `efficiency_metrics` rather than `metrics` so that the
# per-class evaluation results stored in `metrics` by an earlier cell are not
# overwritten by this plotting cell.
efficiency_metrics = [col for col in df.columns if col != "Model"]

# Min-max normalise each metric to [0, 1] so all metrics share one axis.
scaler = MinMaxScaler()
df_norm = df.copy()
df_norm[efficiency_metrics] = scaler.fit_transform(df[efficiency_metrics])

# Long format for seaborn; keep the original values alongside for annotations.
df_norm_melted = df_norm.melt(id_vars="Model", value_vars=efficiency_metrics, var_name="Metric", value_name="Normalized Value")
df_orig_melted = df.melt(id_vars="Model", value_vars=efficiency_metrics, var_name="Metric", value_name="Original Value")
df_plot = pd.merge(df_norm_melted, df_orig_melted, on=["Model", "Metric"])

plt.figure(figsize=(16, 8))
ax = sns.barplot(x="Metric", y="Normalized Value", hue="Model", data=df_plot, palette="coolwarm", edgecolor="black")

# Annotate each bar with its ORIGINAL (un-normalised) value. Models are taken in
# order of appearance, which is the hue order seaborn uses by default.
model_order = df_plot["Model"].unique()
n_models = len(model_order)
# Assumes seaborn's default total bar-group width of 0.8, split evenly per hue.
dodge_width = 0.8 / n_models

for i, metric in enumerate(efficiency_metrics):
    metric_data = df_plot[df_plot["Metric"] == metric]
    for j, model in enumerate(model_order):
        model_data = metric_data[metric_data["Model"] == model]
        if model_data.empty:
            continue
        original_value = model_data["Original Value"].iloc[0]
        x_pos = i + (j - n_models / 2 + 0.5) * dodge_width
        y_pos = model_data["Normalized Value"].iloc[0]

        # Size and parameter counts get one decimal, everything else two.
        fmt = ".1f" if ("Size" in metric or "Params" in metric) else ".2f"
        ax.annotate(f"{original_value:{fmt}}",
                    (x_pos, y_pos),
                    ha="center", va="bottom", fontsize=9, fontweight="bold",
                    xytext=(0, 3), textcoords="offset points")

plt.title("Normalized Comparison of Model Efficiency Metrics", fontsize=16, fontweight="bold")
plt.ylabel("Normalized Value (0–1)", fontsize=12, fontweight="bold")
plt.xlabel("Metric", fontsize=12, fontweight="bold")
plt.xticks(fontsize=12, fontweight="bold")
plt.yticks(fontsize=12, fontweight="bold")
plt.legend(title="Model", bbox_to_anchor=(1.02, 1), loc="upper left", fontsize=12, title_fontsize=11)
sns.despine()
plt.tight_layout()
plt.show()

# Alternative, simpler annotation using Axes.bar_label.
# Note: these labels show the NORMALIZED values (two decimals), not the raw ones.
plt.figure(figsize=(16, 8))
ax = sns.barplot(x="Metric", y="Normalized Value", hue="Model", data=df_plot, palette="coolwarm", edgecolor="black")

# One container per hue level (model); label every bar in it.
for container in ax.containers:
    ax.bar_label(container, labels=[f"{v:.2f}" for v in container.datavalues],
                 fontsize=10, fontweight="bold", padding=3)

plt.title("Normalized Comparison of Model Efficiency Metrics (Alternative)", fontsize=16, fontweight="bold")
plt.ylabel("Normalized Value (0–1)", fontsize=12, fontweight="bold")
plt.xlabel("Metric", fontsize=12, fontweight="bold")
plt.xticks(fontsize=11, fontweight="bold")
plt.yticks(fontsize=11, fontweight="bold")
plt.legend(title="Model", bbox_to_anchor=(1.02, 1), loc="upper left", fontsize=10, title_fontsize=11)
sns.despine()
plt.tight_layout()
plt.show()

In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Mean IoU and mean Dice (%) per model, as reported in the comparison chart.
# Named `model_names` so the Keras model list `models` built later is not confused with it.
model_names = ['InceptionResNetV2', 'Xception', 'EfficientNetB4', 'SegNet', 'Soft Voting',
               'Majority Voting', 'Weighted Soft Voting', 'Proposed Lightseg']
mean_iou = [83.34, 83.25, 89.08, 83.32, 84.21, 85.26, 85.26, 80.39]
mean_dice = [92.07, 91.62, 91.56, 91.34, 91.82, 93.07, 93.09, 92.49]

fig, ax = plt.subplots(figsize=(14, 8))

# Grouped vertical bars: IoU left of each tick, Dice right of it.
bar_width = 0.35
x = np.arange(len(model_names))
ax.bar(x - bar_width / 2, mean_iou, bar_width, label='Mean IoU',
       color='#FFB6C1', alpha=0.8, edgecolor='black', linewidth=0.5)
ax.bar(x + bar_width / 2, mean_dice, bar_width, label='Mean Dice Coefficient',
       color='#DDA0DD', alpha=0.8, edgecolor='black', linewidth=0.5)

# Value labels. The loop variables must NOT be called `iou`: that would
# overwrite the `iou()` metric function defined earlier in the notebook and
# break any later re-run of the evaluation cells.
for i, (iou_value, dice_value) in enumerate(zip(mean_iou, mean_dice)):
    ax.text(i - bar_width / 2, iou_value + 0.2, f'{iou_value:.2f}',
            ha='center', va='bottom', fontweight='bold', fontsize=10)
    ax.text(i + bar_width / 2, dice_value + 0.2, f'{dice_value:.2f}',
            ha='center', va='bottom', fontweight='bold', fontsize=10)

ax.set_xticks(x)
ax.set_xticklabels(model_names, fontsize=11, rotation=45, ha='right')
ax.set_ylabel('Score', fontsize=13, fontweight='bold')
ax.set_ylim(75, 96)
ax.legend(loc='upper left', fontsize=11)
ax.grid(axis='y', alpha=0.3, linestyle='--')
ax.set_title('Model Performance Comparison\n(Mean IoU vs Mean Dice Coefficient)',
             fontsize=14, fontweight='bold', pad=20)

plt.tight_layout()
plt.show()

In [None]:
# Re-check the distilled student on the clean test set before the external-validation runs.
kd_model.evaluate(x=X_test, y=y_test, batch_size=batch_size)

<h1>External Validation</h1>

In [None]:
import tensorflow as tf
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau, ModelCheckpoint

# Members of the soft-voting teacher ensemble; order must match `final_weights`.
models = [model_xception, model_segnet, model_inceptionresnetv2, model_efficientnetb4]

class WeightedSoftVotingEnsemble(tf.keras.Model):
    """Weighted soft-voting ensemble over several segmentation models.

    Each member is called on the same input. Its output is passed through a
    softmax over the last axis (unless ``apply_softmax`` is False, or the
    member's name contains "efficientnet", which is treated as already
    producing probabilities), scaled by the member's normalised weight, and summed.

    Args:
        models: non-empty list of Keras models sharing an input.
        weights: optional per-model weights in the same order as ``models``;
            normalised to sum to 1. ``None`` gives equal weights.
        apply_softmax: whether to softmax member outputs (see above).
        one_hot: if True (default, original behaviour) ``call`` returns a
            one-hot map of the arg-max class so the ensemble can be scored with
            the same metrics as its members; if False it returns the weighted
            average probabilities (e.g. for use as soft distillation targets).

    Raises:
        ValueError: if ``models`` is empty, ``len(weights) != len(models)``, or
            the weights sum to zero.
    """

    def __init__(self, models, weights=None, apply_softmax=True, one_hot=True):
        super(WeightedSoftVotingEnsemble, self).__init__()
        if len(models) == 0:
            raise ValueError("WeightedSoftVotingEnsemble needs at least one model")
        self.models = models
        self.apply_softmax = apply_softmax
        self.one_hot = one_hot

        if weights is None:
            weights = [1.0 / len(models)] * len(models)
        else:
            if len(weights) != len(models):
                # Previously a shorter list failed only inside call(), and a
                # longer one was normalised together with weights that were then
                # silently ignored.
                raise ValueError(
                    f"Expected {len(models)} weights (one per model), got {len(weights)}"
                )
            total = sum(weights)
            if total == 0:
                raise ValueError("Ensemble weights must not sum to zero")
            weights = [w / total for w in weights]

        self.model_weights = tf.constant(weights, dtype=tf.float32)

    def call(self, x, training=False):
        weighted_sum = 0
        for i, model in enumerate(self.models):
            output = model(x, training=training)

            # Name-based heuristic: EfficientNet members are assumed to end in a softmax already.
            is_softmaxed = (
                hasattr(model, "name") and "efficientnet" in model.name.lower()
            )

            if self.apply_softmax and not is_softmaxed:
                probs = tf.nn.softmax(output, axis=-1)
            else:
                probs = output

            weighted_sum += self.model_weights[i] * probs

        avg_prob = weighted_sum  # shape: [B, H, W, C]
        if not self.one_hot:
            return avg_prob

        # Convert to one-hot for metric compatibility (arg-max is not differentiable).
        one_hot_pred = tf.one_hot(tf.argmax(avg_prob, axis=-1), depth=avg_prob.shape[-1])
        return one_hot_pred  # [B, H, W, C]

# Per-member voting weights (same order as `models`); the ensemble renormalises them.
final_weights = [0.255, 0.2427, 0.2515, 0.2508]

ensemble_model = WeightedSoftVotingEnsemble(models=models, weights=final_weights, apply_softmax=True)
ensemble_model.compile(
    optimizer=tf.keras.optimizers.RMSprop(learning_rate=0.0001),
    loss=combined_loss,
    metrics=class_wise_metrics(num_classes=4),
)

# The ensemble is the teacher for knowledge distillation below (only the student is optimised).
teacher_model = ensemble_model

def distillation_loss(y_true, y_student_logits, y_teacher_probs, alpha=0.5, temperature=3.0):
    """Knowledge-distillation loss: hard-label term blended with a softened KL term.

    Returns ``alpha * hard + (1 - alpha) * T**2 * KL(teacher_soft || student_soft)``
    where ``hard = combined_loss(y_true, student) + CCE(y_true, student)``.

    Args:
        y_true: one-hot ground truth.
        y_student_logits: student outputs (treated as logits for the KL term).
        y_teacher_probs: teacher outputs as PROBABILITIES. The ensemble teacher
            returns a one-hot / probability map, not logits.
        alpha: weight of the hard-label term.
        temperature: softening temperature T.
    """
    eps = 1e-7

    # NOTE(review): the student tensor is softmaxed here as logits, but is also
    # passed to CategoricalCrossentropy() with its default from_logits=False
    # below. If student_model already ends in a softmax, drop this softmax; if
    # it emits logits, use from_logits=True — TODO confirm.
    student_soft = tf.nn.softmax(y_student_logits / temperature)

    # Teacher outputs are probabilities, so feeding them straight to softmax is
    # wrong (softmax of a 4-class one-hot vector is ~[0.48, 0.17, 0.17, 0.17]).
    # Temperature-scale in log space instead: softmax(log p / T) == p**(1/T)
    # renormalised, which is ~p itself when T == 1.
    teacher_soft = tf.nn.softmax(tf.math.log(y_teacher_probs + eps) / temperature)

    # Soft loss: KL(teacher || student)
    kl_loss = tf.keras.losses.KLDivergence()(teacher_soft, student_soft)

    # Hard loss: custom combined loss (Dice + Lovasz) plus categorical cross-entropy
    ce_loss = combined_loss(y_true, y_student_logits) + tf.keras.losses.CategoricalCrossentropy()(y_true, y_student_logits)

    # T**2 keeps the soft-term gradient scale comparable across temperatures.
    return alpha * ce_loss + (1 - alpha) * (temperature ** 2) * kl_loss

# === KD Wrapper Model ===
class KDTrainer(tf.keras.Model):
    """Knowledge-distillation wrapper that trains ``student`` against a frozen ``teacher``.

    ``train_step`` optimises only the student's variables on ``distillation_loss``;
    ``test_step`` scores the student alone with ``combined_loss``. The loss
    reported is a running mean over the epoch / evaluation, and all metrics are
    reset by Keras at the start of every epoch and every ``evaluate`` call.
    """

    def __init__(self, student, teacher, alpha=0.5, temperature=3.0):
        super(KDTrainer, self).__init__()
        self.student = student
        self.teacher = teacher
        self.alpha = alpha
        self.temperature = temperature
        # Running mean of the loss (the original reported only the last batch's loss).
        self.loss_tracker = tf.keras.metrics.Mean(name="loss")
        self.metrics_list = []

    def compile(self, optimizer, metrics):
        # Hand the optimizer to Keras' compile so it is registered properly.
        super().compile(optimizer=optimizer)
        self.metrics_list = list(metrics)

    @property
    def metrics(self):
        # Keras resets exactly the metrics listed here before each epoch and each
        # evaluate(). Without this override, metrics_list was never reset, so the
        # successive kd_model.evaluate(...) robustness runs accumulated the state
        # of every previous run.
        return [self.loss_tracker] + list(self.metrics_list)

    def _update_metrics(self, y_true, y_pred, loss):
        """Update the loss tracker and user metrics; return their current values."""
        self.loss_tracker.update_state(loss)
        for metric in self.metrics_list:
            metric.update_state(y_true, y_pred)
        return {m.name: m.result() for m in self.metrics}

    def train_step(self, data):
        x, y_true = data
        y_true = tf.cast(y_true, tf.float32)

        # The teacher is never updated, so run it outside the tape.
        teacher_probs = self.teacher(x, training=False)

        with tf.GradientTape() as tape:
            student_logits = self.student(x, training=True)               # [B, H, W, C]
            loss = distillation_loss(
                y_true, student_logits, teacher_probs,
                alpha=self.alpha, temperature=self.temperature
            )

        grads = tape.gradient(loss, self.student.trainable_variables)
        self.optimizer.apply_gradients(zip(grads, self.student.trainable_variables))

        return self._update_metrics(y_true, student_logits, loss)

    def test_step(self, data):
        x, y_true = data
        y_true = tf.cast(y_true, tf.float32)
        y_pred = self.student(x, training=False)
        loss = combined_loss(y_true, y_pred)
        return self._update_metrics(y_true, y_pred, loss)

# === Instantiate and compile the distillation trainer ===
# temperature=1.0 here overrides the class default of 3.0.
kd_model = KDTrainer(student=student_model, teacher=teacher_model, alpha=0.5, temperature=1.0)

kd_model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
    metrics=class_wise_metrics(num_classes=4),
)


<h3>Brightness</h3>

In [None]:
# Batch size passed to every evaluate() call in the robustness sections below.
batch_size = 8

In [None]:
# Xception on the brightness-perturbed test images.
model_xception.evaluate(x=X_test_bright, y=y_test, batch_size=batch_size)

In [None]:
# SegNet on the brightness-perturbed test images.
model_segnet.evaluate(x=X_test_bright, y=y_test, batch_size=batch_size)

In [None]:
# InceptionResNetV2 on the brightness-perturbed test images.
model_inceptionresnetv2.evaluate(x=X_test_bright, y=y_test, batch_size=batch_size)

In [None]:
# EfficientNetB4 on the brightness-perturbed test images.
model_efficientnetb4.evaluate(x=X_test_bright, y=y_test, batch_size=batch_size)

In [None]:
# Weighted soft-voting teacher ensemble on the brightness-perturbed test images.
teacher_model.evaluate(x=X_test_bright, y=y_test, batch_size=batch_size)

In [None]:
# Distilled student (via KDTrainer) on the brightness-perturbed test images.
kd_model.evaluate(x=X_test_bright, y=y_test, batch_size=batch_size)

<h3>Dark</h3>

In [None]:
# Xception on the darkened test images.
model_xception.evaluate(x=X_test_dark, y=y_test, batch_size=batch_size)

In [None]:
# SegNet on the darkened test images.
model_segnet.evaluate(x=X_test_dark, y=y_test, batch_size=batch_size)

In [None]:
# InceptionResNetV2 on the darkened test images.
model_inceptionresnetv2.evaluate(x=X_test_dark, y=y_test, batch_size=batch_size)

In [None]:
# EfficientNetB4 on the darkened test images.
model_efficientnetb4.evaluate(x=X_test_dark, y=y_test, batch_size=batch_size)

In [None]:
# Teacher ensemble on the darkened test images.
teacher_model.evaluate(x=X_test_dark, y=y_test, batch_size=batch_size)

In [None]:
# Distilled student on the darkened test images.
kd_model.evaluate(x=X_test_dark, y=y_test, batch_size=batch_size)

<h3>Blur 20%</h3>

In [None]:
# Xception on the 20 % blur test images.
model_xception.evaluate(x=X_test_blur_20, y=y_test, batch_size=batch_size)

In [None]:
# SegNet on the 20 % blur test images.
model_segnet.evaluate(x=X_test_blur_20, y=y_test, batch_size=batch_size)

In [None]:
# InceptionResNetV2 on the 20 % blur test images.
model_inceptionresnetv2.evaluate(x=X_test_blur_20, y=y_test, batch_size=batch_size)

In [None]:
# EfficientNetB4 on the 20 % blur test images.
model_efficientnetb4.evaluate(x=X_test_blur_20, y=y_test, batch_size=batch_size)

In [None]:
# Teacher ensemble on the 20 % blur test images.
teacher_model.evaluate(x=X_test_blur_20, y=y_test, batch_size=batch_size)

In [None]:
# Distilled student on the 20 % blur test images.
kd_model.evaluate(x=X_test_blur_20, y=y_test, batch_size=batch_size)

<h3>Blur 40%</h3>

In [None]:
# Xception on the 40 % blur test images.
model_xception.evaluate(x=X_test_blur_40, y=y_test, batch_size=batch_size)

In [None]:
# SegNet on the 40 % blur test images.
model_segnet.evaluate(x=X_test_blur_40, y=y_test, batch_size=batch_size)

In [None]:
# InceptionResNetV2 on the 40 % blur test images.
model_inceptionresnetv2.evaluate(x=X_test_blur_40, y=y_test, batch_size=batch_size)

In [None]:
# EfficientNetB4 on the 40 % blur test images.
model_efficientnetb4.evaluate(x=X_test_blur_40, y=y_test, batch_size=batch_size)

In [None]:
# Teacher ensemble on the 40 % blur test images.
teacher_model.evaluate(x=X_test_blur_40, y=y_test, batch_size=batch_size)

In [None]:
# Distilled student on the 40 % blur test images.
kd_model.evaluate(x=X_test_blur_40, y=y_test, batch_size=batch_size)