In [None]:
import numpy as np
import os
import gc
import cv2
import re
from tensorflow.keras.utils import to_categorical

# ‚úÖ Constants for 224x224
# Target geometry: every image and mask is resized to 224x224, 3 channels.
IMG_HEIGHT, IMG_WIDTH = 224, 224
CHANNELS = 3
NUM_CLASSES = 4  # Background, Brain, CSP, LV

# Mask colour (R, G, B) -> class index. Insertion order is kept as in the
# original definition.
CLASS_MAP = dict([
    ((255, 0, 0), 1),  # Brain
    ((0, 255, 0), 2),  # CSP
    ((0, 0, 255), 3),  # LV
    ((0, 0, 0), 0),    # Background
])

# Source (augmented) dataset.
image_dir, mask_dir = r"D:\augmented_dataset\images", r"D:\augmented_dataset\masks"

# Destination directories of the train / val / test splits.
train_image_dir, train_mask_dir = r"D:\Updated\train\images", r"D:\Updated\train\masks"
val_image_dir, val_mask_dir = r"D:\Updated\val\images", r"D:\Updated\val\masks"
test_image_dir, test_mask_dir = r"D:\Updated\test\images", r"D:\Updated\test\masks"

# ‚úÖ Fix sorting issue using natural sorting
def natural_sort_key(s):
    """Return a sort key that orders embedded numbers by value.

    The string is split on runs of digits; digit runs become ``int`` and the
    remaining pieces are lower-cased, so ``"img2"`` sorts before ``"img10"``.
    """
    key = []
    for chunk in re.split(r'(\d+)', s):
        if chunk.isdigit():
            key.append(int(chunk))
        else:
            key.append(chunk.lower())
    return key

# ‚úÖ Convert RGB mask to class index mask
def rgb_to_class(mask_array):
    """Turn a colour-coded mask into a 2-D array of class indices.

    ``mask_array`` must be 3-dimensional (height, width, channels). Every
    colour in ``CLASS_MAP`` is matched exactly across the channel axis; pixels
    whose colour is not listed keep the initial value 0.

    Returns a ``uint8`` array of shape (height, width).
    """
    rows, cols, _ = mask_array.shape
    labels = np.zeros((rows, cols), dtype=np.uint8)

    for colour, label in CLASS_MAP.items():
        # True only where all channels equal the reference colour.
        hits = (mask_array == colour).all(axis=-1)
        labels[hits] = label

    return labels

# ‚úÖ Preprocess Filtered Dataset for 224x224
def preprocess_filtered_dataset(image_dir, mask_dir):
    """Load one image/mask directory pair into model-ready arrays.

    Files in both directories are listed, naturally sorted and paired by
    position (the i-th image with the i-th mask). Images are converted to RGB,
    resized to (IMG_HEIGHT, IMG_WIDTH) and scaled to [0, 1]. Masks are resized
    with nearest-neighbour interpolation, mapped to class indices with
    ``rgb_to_class`` and one-hot encoded.

    Pairs in which either file cannot be decoded are skipped with a warning
    instead of crashing later inside ``cv2.cvtColor``.

    Args:
        image_dir: directory containing the input images.
        mask_dir: directory containing the colour-coded masks.

    Returns:
        ``(X, y)`` where ``X`` is float32 of shape
        (n, IMG_HEIGHT, IMG_WIDTH, CHANNELS) and ``y`` is float32 of shape
        (n, IMG_HEIGHT, IMG_WIDTH, NUM_CLASSES); ``n`` is the number of pairs
        that were read successfully.
    """

    def _sorted_files(directory):
        # Keep regular files only: sub-directories (e.g. .ipynb_checkpoints)
        # cannot be decoded and would shift the positional pairing.
        names = [f for f in os.listdir(directory)
                 if os.path.isfile(os.path.join(directory, f))]
        return sorted(names, key=natural_sort_key)

    image_filenames = _sorted_files(image_dir)
    mask_filenames = _sorted_files(mask_dir)

    # zip() below silently drops the surplus files, so make a mismatch visible:
    # with positional pairing it usually means images and masks are misaligned.
    if len(image_filenames) != len(mask_filenames):
        print(f"WARNING: {len(image_filenames)} images but {len(mask_filenames)} masks; "
              "pairing by position may be misaligned")

    pairs = list(zip(image_filenames, mask_filenames))
    num_images = len(pairs)

    # Pre-allocate for the upper bound; trimmed to the number of good pairs on return.
    X = np.zeros((num_images, IMG_HEIGHT, IMG_WIDTH, CHANNELS), dtype=np.float32)
    y = np.zeros((num_images, IMG_HEIGHT, IMG_WIDTH, NUM_CLASSES), dtype=np.float32)

    print(f"üöÄ Processing {num_images} filtered images and masks...")

    kept = 0
    for idx, (img_file, mask_file) in enumerate(pairs):
        if idx % 100 == 0:
            print(f"‚úÖ Processed {idx}/{num_images} images")

        # cv2.imread returns None (it does not raise) for missing/undecodable files.
        img = cv2.imread(os.path.join(image_dir, img_file))
        mask = cv2.imread(os.path.join(mask_dir, mask_file))
        if img is None or mask is None:
            print(f"‚ö†Ô∏è Skipping {img_file}: Missing image or mask")
            continue

        # Image: BGR -> RGB, resize, scale to [0, 1].
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = cv2.resize(img, (IMG_WIDTH, IMG_HEIGHT))
        img = img.astype(np.float32) / 255.0

        # Mask: nearest-neighbour keeps the exact class colours while resizing.
        mask = cv2.cvtColor(mask, cv2.COLOR_BGR2RGB)
        mask = cv2.resize(mask, (IMG_WIDTH, IMG_HEIGHT), interpolation=cv2.INTER_NEAREST)
        class_mask = rgb_to_class(mask)

        X[kept] = img
        y[kept] = to_categorical(class_mask, num_classes=NUM_CLASSES)
        kept += 1
        # No per-image gc.collect(): the temporaries are released when the
        # names are rebound, and a full collection per image only slows the loop.

    return X[:kept], y[:kept]

from sklearn.model_selection import train_test_split

# # ‚úÖ Process dataset splits
# Load the training and validation splits into memory; the test split is
# loaded by a later cell.
X_train, y_train = preprocess_filtered_dataset(train_image_dir, train_mask_dir)
X_val, y_val = preprocess_filtered_dataset(val_image_dir, val_mask_dir)
# X_test, y_test = preprocess_filtered_dataset(test_image_dir, test_mask_dir)

# Report the array shapes of each loaded split.
print("\n‚úÖ Dataset Splits:")
print(f"  - Training set: {X_train.shape}, {y_train.shape}")
print(f"  - Validation set: {X_val.shape}, {y_val.shape}")
# print(f"  - Test set: {X_test.shape}, {y_test.shape}")

In [None]:
import numpy as np
import os
import gc
import cv2
import re
from tensorflow.keras.utils import to_categorical

# ‚úÖ Constants for 224x224
# Target geometry: every image and mask is resized to 224x224, 3 channels.
IMG_HEIGHT, IMG_WIDTH = 224, 224
CHANNELS = 3
NUM_CLASSES = 4  # Background, Brain, CSP, LV

# Mask colour (R, G, B) -> class index. Insertion order is kept as in the
# original definition.
CLASS_MAP = dict([
    ((255, 0, 0), 1),  # Brain
    ((0, 255, 0), 2),  # CSP
    ((0, 0, 255), 3),  # LV
    ((0, 0, 0), 0),    # Background
])

# Source (augmented) dataset.
image_dir, mask_dir = r"D:\augmented_dataset\images", r"D:\augmented_dataset\masks"

# Split directories; in this cell the "test" pair points at the synthetic dataset.
train_image_dir, train_mask_dir = r"D:\Updated\train\images", r"D:\Updated\train\masks"
val_image_dir, val_mask_dir = r"D:\Updated\val\images", r"D:\Updated\val\masks"
test_image_dir, test_mask_dir = (
    r"D:\Updated\synthetic_dataset\images",
    r"D:\Updated\synthetic_dataset\masks",
)

# ‚úÖ Fix sorting issue using natural sorting
def natural_sort_key(s):
    """Return a sort key that orders embedded numbers by value.

    Digit runs are converted to ``int``; all other pieces are lower-cased.
    """
    def _convert(piece):
        # Digit runs compare numerically, everything else case-insensitively.
        return int(piece) if piece.isdigit() else piece.lower()

    pieces = re.split(r'(\d+)', s)
    return list(map(_convert, pieces))

# ‚úÖ Convert RGB mask to class index mask
def rgb_to_class(mask_array):
    """Turn a colour-coded mask into a 2-D array of class indices.

    ``mask_array`` must be 3-dimensional (height, width, channels). Every
    colour in ``CLASS_MAP`` is matched exactly across the channel axis; pixels
    whose colour is not listed keep the initial value 0.

    Returns a ``uint8`` array of shape (height, width).
    """
    rows, cols, _ = mask_array.shape
    labels = np.zeros((rows, cols), dtype=np.uint8)

    for colour, label in CLASS_MAP.items():
        # True only where all channels equal the reference colour.
        hits = (mask_array == colour).all(axis=-1)
        labels[hits] = label

    return labels

# ‚úÖ Preprocess Filtered Dataset for 224x224
def preprocess_filtered_dataset(image_dir, mask_dir):
    """Load one image/mask directory pair into model-ready arrays.

    Files in both directories are listed, naturally sorted and paired by
    position (the i-th image with the i-th mask). Images are converted to RGB,
    resized to (IMG_HEIGHT, IMG_WIDTH) and scaled to [0, 1]. Masks are resized
    with nearest-neighbour interpolation, mapped to class indices with
    ``rgb_to_class`` and one-hot encoded.

    Pairs in which either file cannot be decoded are skipped with a warning
    instead of crashing later inside ``cv2.cvtColor``.

    Args:
        image_dir: directory containing the input images.
        mask_dir: directory containing the colour-coded masks.

    Returns:
        ``(X, y)`` where ``X`` is float32 of shape
        (n, IMG_HEIGHT, IMG_WIDTH, CHANNELS) and ``y`` is float32 of shape
        (n, IMG_HEIGHT, IMG_WIDTH, NUM_CLASSES); ``n`` is the number of pairs
        that were read successfully.
    """

    def _sorted_files(directory):
        # Keep regular files only: sub-directories (e.g. .ipynb_checkpoints)
        # cannot be decoded and would shift the positional pairing.
        names = [f for f in os.listdir(directory)
                 if os.path.isfile(os.path.join(directory, f))]
        return sorted(names, key=natural_sort_key)

    image_filenames = _sorted_files(image_dir)
    mask_filenames = _sorted_files(mask_dir)

    # zip() below silently drops the surplus files, so make a mismatch visible:
    # with positional pairing it usually means images and masks are misaligned.
    if len(image_filenames) != len(mask_filenames):
        print(f"WARNING: {len(image_filenames)} images but {len(mask_filenames)} masks; "
              "pairing by position may be misaligned")

    pairs = list(zip(image_filenames, mask_filenames))
    num_images = len(pairs)

    # Pre-allocate for the upper bound; trimmed to the number of good pairs on return.
    X = np.zeros((num_images, IMG_HEIGHT, IMG_WIDTH, CHANNELS), dtype=np.float32)
    y = np.zeros((num_images, IMG_HEIGHT, IMG_WIDTH, NUM_CLASSES), dtype=np.float32)

    print(f"üöÄ Processing {num_images} filtered images and masks...")

    kept = 0
    for idx, (img_file, mask_file) in enumerate(pairs):
        if idx % 100 == 0:
            print(f"‚úÖ Processed {idx}/{num_images} images")

        # cv2.imread returns None (it does not raise) for missing/undecodable files.
        img = cv2.imread(os.path.join(image_dir, img_file))
        mask = cv2.imread(os.path.join(mask_dir, mask_file))
        if img is None or mask is None:
            print(f"‚ö†Ô∏è Skipping {img_file}: Missing image or mask")
            continue

        # Image: BGR -> RGB, resize, scale to [0, 1].
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = cv2.resize(img, (IMG_WIDTH, IMG_HEIGHT))
        img = img.astype(np.float32) / 255.0

        # Mask: nearest-neighbour keeps the exact class colours while resizing.
        mask = cv2.cvtColor(mask, cv2.COLOR_BGR2RGB)
        mask = cv2.resize(mask, (IMG_WIDTH, IMG_HEIGHT), interpolation=cv2.INTER_NEAREST)
        class_mask = rgb_to_class(mask)

        X[kept] = img
        y[kept] = to_categorical(class_mask, num_classes=NUM_CLASSES)
        kept += 1
        # No per-image gc.collect(): the temporaries are released when the
        # names are rebound, and a full collection per image only slows the loop.

    return X[:kept], y[:kept]

from sklearn.model_selection import train_test_split

# # ‚úÖ Process dataset splits
# X_train, y_train = preprocess_filtered_dataset(train_image_dir, train_mask_dir)
# X_val, y_val = preprocess_filtered_dataset(val_image_dir, val_mask_dir)
# Only the test split is loaded in this cell; test_image_dir / test_mask_dir
# are the directories assigned earlier in the same cell.
X_test, y_test = preprocess_filtered_dataset(test_image_dir, test_mask_dir)

# Report the array shapes of the loaded split.
print("\n‚úÖ Dataset Splits:")
# print(f"  - Training set: {X_train.shape}, {y_train.shape}")
# print(f"  - Validation set: {X_val.shape}, {y_val.shape}")
print(f"  - Test set: {X_test.shape}, {y_test.shape}")

In [None]:
# (Re)load the validation split on its own, e.g. after a cell that only loaded the test split.
X_val, y_val = preprocess_filtered_dataset(val_image_dir, val_mask_dir)

In [None]:
import tensorflow as tf
from tensorflow.keras import layers, Input
from tensorflow.keras.models import Model

# Constants for 224x224 images
IMG_HEIGHT, IMG_WIDTH = 224, 224  # model input size (previously 256x256)
CHANNELS = 3                      # RGB images
NUM_CLASSES = 4                   # Brain, CSP, LV, Background

def conv_block(inputs, filters, kernel_size=(3, 3), padding='same', strides=1):
    """Apply two Conv2D -> BatchNormalization -> ReLU stages to ``inputs``.

    Args:
        inputs: input tensor.
        filters: number of filters of both convolutions.
        kernel_size: kernel size of both convolutions.
        padding: padding mode of both convolutions.
        strides: stride of the FIRST convolution only; the second one always
            uses a stride of 1.

    Returns:
        The output tensor of the second ReLU.
    """
    x = inputs
    for stage_strides in (strides, 1):
        x = layers.Conv2D(filters, kernel_size, strides=stage_strides, padding=padding)(x)
        x = layers.BatchNormalization()(x)
        x = layers.Activation('relu')(x)
    return x

def build_segnet(input_shape, num_classes):
    """Build the encoder/decoder segmentation network.

    Encoder: four ``conv_block`` stages (64, 128, 256, 512 filters), each
    followed by 2x2 max pooling. Bridge: one ``conv_block`` with 1024 filters.
    Decoder: four stages that upsample by 2, concatenate the encoder output of
    the matching resolution (skip connection) and apply a ``conv_block``
    (512, 256, 128, 64 filters). A final 1x1 convolution with softmax yields
    per-pixel class probabilities.

    Note: despite the name, the decoder uses feature concatenation rather than
    pooling indices.

    Args:
        input_shape: shape of one input sample, e.g. (height, width, channels).
        num_classes: number of output channels of the softmax layer.

    Returns:
        An uncompiled Keras ``Model``.
    """
    encoder_filters = (64, 128, 256, 512)
    decoder_filters = (512, 256, 128, 64)

    inputs = Input(input_shape)

    # Encoder: remember each pre-pooling feature map for the skip connections.
    skips = []
    x = inputs
    for filters in encoder_filters:
        x = conv_block(x, filters)
        skips.append(x)
        x = layers.MaxPooling2D(pool_size=(2, 2), padding='same')(x)

    # Bridge
    x = conv_block(x, 1024)

    # Decoder: deepest skip connection first.
    for filters, skip in zip(decoder_filters, reversed(skips)):
        x = layers.UpSampling2D(size=(2, 2))(x)
        x = layers.concatenate([x, skip], axis=-1)
        x = conv_block(x, filters)

    # Output
    outputs = layers.Conv2D(num_classes, (1, 1), activation='softmax')(x)

    return Model(inputs=[inputs], outputs=[outputs])

# Build model
# Instantiate the network for 224x224 RGB inputs and 4 output classes
# (values taken from the constants defined at the top of this cell).
model = build_segnet(input_shape=(IMG_HEIGHT, IMG_WIDTH, CHANNELS), 
                     num_classes=NUM_CLASSES)

# Print model summary
model.summary()

In [None]:
# import tensorflow as tf
# from tensorflow.keras import backend as K
# from tensorflow.keras.optimizers import Adam
# from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau, ModelCheckpoint

# # ‚úÖ Dice Coefficient Metric
# def dice_coefficient(y_true, y_pred):
#     smooth = 1e-15
#     y_true = tf.cast(y_true, tf.float32)
#     y_pred = tf.cast(y_pred, tf.float32)
#     intersection = tf.reduce_sum(y_true * y_pred, axis=[1,2,3])
#     union = tf.reduce_sum(y_true, axis=[1,2,3]) + tf.reduce_sum(y_pred, axis=[1,2,3])
#     dice = (2. * intersection + smooth) / (union + smooth)
#     return tf.reduce_mean(dice)

# def weighted_categorical_crossentropy(y_true, y_pred):
#     class_weights = tf.constant([0.3794, 0.7521, 69.7061, 49.3458], dtype=tf.float32)

#     # Ensure y_true has the same shape as y_pred
#     y_true = tf.cast(y_true, tf.float32)  # Make sure it's float32 for numerical stability
#     y_pred = tf.clip_by_value(y_pred, K.epsilon(), 1.0)  # Avoid log(0)

#     # Compute categorical cross-entropy
#     loss = -tf.reduce_sum(y_true * tf.math.log(y_pred), axis=-1)  # Sum over the last axis (class axis)

#     # Reshape the class weights to match the loss shape
#     class_weights = tf.reshape(class_weights, (1, 1, 1, NUM_CLASSES))  # [1, 1, 1, 4]

#     # Apply the class weights
#     weighted_loss = loss * tf.reduce_sum(class_weights, axis=-1)  # Broadcast weights over the batch and spatial dimensions

 #     return tf.reduce_mean(weighted_loss)

# # ‚úÖ Dice Loss
# def dice_loss(y_true, y_pred):
#     smooth = 1e-6
#     y_true = tf.cast(y_true, y_pred.dtype)
#     intersection = tf.reduce_sum(y_true * y_pred, axis=[1, 2])
#     union = tf.reduce_sum(y_true, axis=[1, 2]) + tf.reduce_sum(y_pred, axis=[1, 2])
#     dice = (2. * intersection + smooth) / (union + smooth)
#     return 1 - tf.reduce_mean(dice)

# # ‚úÖ Combined Loss Function
# def combined_loss(y_true, y_pred):
#     return weighted_categorical_crossentropy(y_true, y_pred) + dice_loss(y_true, y_pred)

# # ‚úÖ Custom Dice Coefficient Metric for Each Class
# class DiceCoefficient(tf.keras.metrics.Metric):
#     def __init__(self, class_idx, name=None, **kwargs):  
#         if name is None:
#             name = f"DiceClass{class_idx}"  
#         super(DiceCoefficient, self).__init__(name=name, **kwargs)
        
#         self.class_idx = class_idx
#         self.dice = self.add_weight(name="dice", initializer="zeros")

#     def update_state(self, y_true, y_pred, sample_weight=None):
#         y_true_class = y_true[..., self.class_idx]
#         y_pred_class = y_pred[..., self.class_idx]

#         intersection = tf.reduce_sum(y_true_class * y_pred_class, axis=[1, 2])
#         union = tf.reduce_sum(y_true_class, axis=[1, 2]) + tf.reduce_sum(y_pred_class, axis=[1, 2])

#         dice = (2. * intersection + 1e-6) / (union + 1e-6)
#         self.dice.assign(tf.reduce_mean(dice))

#     def result(self):
#         return self.dice

# # ‚úÖ Function to Get Class-wise Metrics
# def class_wise_metrics(num_classes=4):
#     return [DiceCoefficient(i) for i in range(num_classes)] + [tf.keras.metrics.MeanIoU(num_classes=num_classes)]

# # ‚úÖ Compile the Model
# model.compile(
#     optimizer=Adam(learning_rate=1e-4),
#     loss=combined_loss,
#     metrics=class_wise_metrics(4)  # Number of classes (adjust if needed)
# )

# # ‚úÖ Train the Model with the **Filtered Dataset**
# history = model.fit(
#     X_train, y_train,  # Use the loaded and split data
#     validation_data=(X_val, y_val),
#     epochs=50,
#     batch_size=16,  # Adjust based on available resources
#     callbacks=[
#         EarlyStopping(
#             monitor='val_loss', 
#             patience=7, 
#             restore_best_weights=True
#         ),
#         ReduceLROnPlateau(
#             monitor='val_loss', 
#             factor=0.5, 
#             patience=3, 
#             min_lr=1e-6
#         ),
#         ModelCheckpoint(
#             '/kaggle/working/best_unet_model_filtered.keras',  # Save in Kaggle working directory with .keras extension
#             monitor='val_loss',
#             save_best_only=True
#         )
#     ]
# )
import tensorflow as tf
from tensorflow.keras import backend as K
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau, ModelCheckpoint
from tensorflow.keras.preprocessing.image import ImageDataGenerator

# ‚úÖ Dice Coefficient Metric
def dice_coefficient(y_true, y_pred):
    """Mean soft Dice coefficient over the batch.

    Both inputs are cast to float32 and reduced over axes 1, 2 and 3, giving
    one Dice value per sample (all classes pooled); the batch mean is returned.
    """
    eps = 1e-15  # guards against 0/0 when both tensors are empty
    reduce_axes = [1,2,3]
    truth = tf.cast(y_true, tf.float32)
    pred = tf.cast(y_pred, tf.float32)
    overlap = tf.reduce_sum(truth * pred, axis=reduce_axes)
    # Sum of both "areas" (the Dice denominator), not a set union.
    total = tf.reduce_sum(truth, axis=reduce_axes) + tf.reduce_sum(pred, axis=reduce_axes)
    per_sample = (2. * overlap + eps) / (total + eps)
    return tf.reduce_mean(per_sample)

# ‚úÖ Weighted Categorical Crossentropy
def weighted_categorical_crossentropy(y_true, y_pred):
    """Class-weighted categorical cross-entropy for one-hot targets.

    Each pixel's cross-entropy term is scaled by the weight of its true class.
    The previous version multiplied every pixel by the SUM of all class
    weights, i.e. by one constant, so no class was actually re-weighted.

    Args:
        y_true: one-hot targets with NUM_CLASSES channels on the last axis.
        y_pred: predicted class probabilities of the same shape.

    Returns:
        Scalar tensor: mean weighted cross-entropy over all pixels.
    """
    # Hard-coded per-class weights in class-index order
    # (presumably inverse class frequencies -- TODO confirm).
    class_weights = tf.constant([0.3776, 0.7605, 65.8554, 46.2381], dtype=tf.float32)
    class_weights = tf.reshape(class_weights, (1, 1, 1, NUM_CLASSES))

    y_true = tf.cast(y_true, tf.float32)
    y_pred = tf.clip_by_value(y_pred, K.epsilon(), 1.0)  # avoid log(0)

    # y_true is one-hot, so only the true class contributes, scaled by its weight.
    weighted_loss = -tf.reduce_sum(class_weights * y_true * tf.math.log(y_pred), axis=-1)
    return tf.reduce_mean(weighted_loss)

# ‚úÖ Dice Loss
def dice_loss(y_true, y_pred):
    """Soft Dice loss: ``1 - mean(Dice)``.

    Sums are taken over axes 1 and 2, so for (batch, H, W, classes) inputs one
    Dice value is computed per sample and class before averaging.
    """
    eps = 1e-6
    spatial_axes = [1, 2]
    truth = tf.cast(y_true, y_pred.dtype)
    overlap = tf.reduce_sum(truth * y_pred, axis=spatial_axes)
    total = tf.reduce_sum(truth, axis=spatial_axes) + tf.reduce_sum(y_pred, axis=spatial_axes)
    mean_dice = tf.reduce_mean((2. * overlap + eps) / (total + eps))
    return 1 - mean_dice

# ‚úÖ Custom Dice Coefficient Metric for Each Class
class DiceCoefficient(tf.keras.metrics.Metric):
    """Soft Dice coefficient of a single class, averaged over all samples seen.

    The metric accumulates the per-sample Dice values between resets, so the
    reported value is the mean over the whole epoch / evaluation run. The
    previous version overwrote its state on every batch and therefore reported
    the Dice of the last batch only.

    Args:
        class_idx: index of the class channel (last axis) to evaluate.
        name: metric name; defaults to ``DiceClass<class_idx>``.
    """

    def __init__(self, class_idx, name=None, **kwargs):
        if name is None:
            name = f"DiceClass{class_idx}"
        super(DiceCoefficient, self).__init__(name=name, **kwargs)
        self.class_idx = class_idx
        # Running sum of per-sample Dice values and number of samples seen.
        self.dice_sum = self.add_weight(name="dice_sum", initializer="zeros")
        self.num_samples = self.add_weight(name="num_samples", initializer="zeros")

    def update_state(self, y_true, y_pred, sample_weight=None):
        """Accumulate the Dice of every sample in the batch (``sample_weight`` is ignored)."""
        y_true_class = tf.cast(y_true[..., self.class_idx], tf.float32)
        y_pred_class = tf.cast(y_pred[..., self.class_idx], tf.float32)
        intersection = tf.reduce_sum(y_true_class * y_pred_class, axis=[1, 2])
        union = tf.reduce_sum(y_true_class, axis=[1, 2]) + tf.reduce_sum(y_pred_class, axis=[1, 2])
        dice = (2. * intersection + 1e-6) / (union + 1e-6)  # one value per sample
        self.dice_sum.assign_add(tf.reduce_sum(dice))
        self.num_samples.assign_add(tf.cast(tf.size(dice), tf.float32))

    def result(self):
        """Mean Dice over all samples since the last reset (0 if none were seen)."""
        return tf.math.divide_no_nan(self.dice_sum, self.num_samples)

# ‚úÖ Function to Get Class-wise Metrics
def class_wise_metrics(num_classes=4):
    """Return the metric list used at compile time.

    Contains one ``DiceCoefficient`` per class plus a mean-IoU metric.
    ``OneHotMeanIoU`` is used because targets are one-hot and predictions are
    per-class probabilities; plain ``MeanIoU`` expects integer class labels and
    would produce meaningless values on such inputs.
    """
    dice_metrics = [DiceCoefficient(i) for i in range(num_classes)]
    mean_iou = tf.keras.metrics.OneHotMeanIoU(num_classes=num_classes)
    return dice_metrics + [mean_iou]

# ‚úÖ Create Data Generator
def create_train_generator(X, y, batch_size=16):
    """Yield endlessly augmented ``(image_batch, mask_batch)`` pairs.

    Two ``ImageDataGenerator`` instances with the same geometric settings and
    the same seed are used so that images and masks are meant to receive
    identical shuffling and transforms. Masks are warped with nearest-neighbour
    interpolation (``interpolation_order=0``); the default order-1
    interpolation blends neighbouring pixels and turns one-hot masks into
    fractional values along class borders.

    Args:
        X: image array, first axis is the sample axis.
        y: one-hot mask array aligned with ``X``.
        batch_size: number of samples per yielded batch.

    Yields:
        Tuples ``(X_batch, y_batch)``; the generator never terminates.
    """
    data_gen_args = dict(
        rotation_range=20,
        width_shift_range=0.1,
        height_shift_range=0.1,
        zoom_range=0.2,
        horizontal_flip=True,
        fill_mode='nearest'
    )

    image_datagen = ImageDataGenerator(**data_gen_args)
    # Order 0 = nearest neighbour: keeps mask values exactly 0/1 after warping.
    mask_datagen = ImageDataGenerator(interpolation_order=0, **data_gen_args)

    # Same seed for both flows so images and masks stay in step.
    seed = 42
    image_generator = image_datagen.flow(X, batch_size=batch_size, seed=seed)
    mask_generator = mask_datagen.flow(y, batch_size=batch_size, seed=seed)

    while True:
        X_batch = next(image_generator)
        y_batch = next(mask_generator)
        yield X_batch, y_batch

# ‚úÖ Create the generator
# Batch size 16 must match the divisor used for steps_per_epoch in model.fit below.
train_generator = create_train_generator(X_train, y_train, batch_size=16)

# ‚úÖ Compile the Model
import tensorflow as tf
import numpy as np

# ‚úÖ Dice Loss
def dice_loss(y_true, y_pred):
    """Soft Dice loss: ``1 - mean(Dice)``.

    Sums are taken over axes 1 and 2, so for (batch, H, W, classes) inputs one
    Dice value is computed per sample and class before averaging.
    (Same definition as the ``dice_loss`` earlier in this cell.)
    """
    eps = 1e-6
    spatial_axes = [1, 2]
    truth = tf.cast(y_true, y_pred.dtype)
    overlap = tf.reduce_sum(truth * y_pred, axis=spatial_axes)
    total = tf.reduce_sum(truth, axis=spatial_axes) + tf.reduce_sum(y_pred, axis=spatial_axes)
    mean_dice = tf.reduce_mean((2. * overlap + eps) / (total + eps))
    return 1 - mean_dice

def lovasz_softmax_loss(y_true, y_pred, ignore_background=False):
    """Lovasz-Softmax loss averaged over classes.

    For every class the absolute errors between target and prediction are
    sorted in decreasing order and combined with the discrete gradient of the
    Jaccard loss computed from the equally sorted targets.

    Args:
        y_true: one-hot targets, classes on the last axis.
        y_pred: predicted class probabilities of the same shape.
        ignore_background: if True, class 0 is left out of the average.

    Returns:
        Scalar tensor: mean of the per-class losses that were evaluated.
    """
    y_true = tf.cast(y_true, tf.float32)
    y_pred = tf.cast(y_pred, tf.float32)

    num_classes = tf.shape(y_true)[-1]
    first_class = 1 if ignore_background else 0

    def compute_class_loss(c):
        y_true_flat = tf.reshape(y_true[..., c], [-1])
        y_pred_flat = tf.reshape(y_pred[..., c], [-1])

        # Sort all pixels by decreasing error (top_k with k = number of pixels).
        errors = tf.abs(y_true_flat - y_pred_flat)
        errors_sorted, perm = tf.nn.top_k(errors, k=tf.shape(errors)[0], sorted=True)
        y_true_sorted = tf.gather(y_true_flat, perm)

        # Discrete gradient of the Jaccard loss w.r.t. the sorted errors.
        gts = tf.reduce_sum(y_true_sorted)
        intersection = gts - tf.cumsum(y_true_sorted)
        union = gts + tf.cumsum(1. - y_true_sorted)
        jaccard = 1. - intersection / union
        grad = tf.concat([[jaccard[0]], jaccard[1:] - jaccard[:-1]], 0)

        return tf.tensordot(errors_sorted, grad, axes=1)

    # One slot per evaluated class. Sizing the array to the evaluated classes
    # and writing at ``c - first_class`` keeps every slot written; previously
    # slot 0 was never written when the background was skipped.
    losses = tf.TensorArray(dtype=tf.float32, size=num_classes - first_class)

    def loop_cond(c, losses):
        return tf.less(c, num_classes)

    def loop_body(c, losses):
        losses = losses.write(c - first_class, compute_class_loss(c))
        return c + 1, losses

    _, losses = tf.while_loop(loop_cond, loop_body, [tf.constant(first_class), losses])
    return tf.reduce_mean(losses.stack())


def combined_loss(y_true, y_pred):
    """Training loss: Lovasz-Softmax loss plus soft Dice loss.

    ``y_pred`` is expected to hold class probabilities already (the network
    built in this notebook ends in a softmax layer), so it is passed to both
    terms unchanged. The previous version applied ``tf.nn.softmax`` a second
    time before the Lovasz term, which flattens the probabilities.

    Args:
        y_true: one-hot targets.
        y_pred: predicted class probabilities.

    Returns:
        Scalar tensor ``lovasz_softmax_loss + dice_loss``.
    """
    y_true = tf.cast(y_true, y_pred.dtype)
    lovasz_loss_val = lovasz_softmax_loss(y_true, y_pred, ignore_background=False)
    dice_loss_val = dice_loss(y_true, y_pred)
    return lovasz_loss_val + dice_loss_val


# Usage in model compilation
# Compile with Adam (lr 1e-4), the combined Lovasz + Dice loss and the
# per-class Dice / mean-IoU metrics.
model.compile(
    optimizer=Adam(learning_rate=0.0001),
    loss=combined_loss,
    metrics=class_wise_metrics(4)  # Number of classes
)

# Train on the augmenting generator; validation uses the in-memory arrays.
# The generator is endless, so steps_per_epoch defines the epoch length
# (16 is the batch size the generator was created with).
history = model.fit(
    train_generator,
    steps_per_epoch=len(X_train) // 16,
    validation_data=(X_val, y_val),
    epochs=100,
    callbacks=[
        # Stop after 10 epochs without val_loss improvement and roll back to the best weights.
        EarlyStopping(
            monitor='val_loss', 
            patience=10, 
            restore_best_weights=True
        ),
        # Halve the learning rate after 3 stagnant epochs, down to 1e-6.
        ReduceLROnPlateau(
            monitor='val_loss', 
            factor=0.5, 
            patience=3, 
            min_lr=1e-6
        ),
        # Keep only the best model (by val_loss); path is relative to the working directory.
        ModelCheckpoint(
            'best_unet_model_onlineDA_128_lovaszloss_segnet.keras',
            monitor='val_loss',
            save_best_only=True
        )
    ]
)

In [None]:
import tensorflow as tf
from tensorflow.keras import backend as K
from tensorflow.keras.models import load_model

# Number of classes (adjust if needed)
NUM_CLASSES = 4  # used below by weighted_categorical_crossentropy and the custom_objects setup

# ‚úÖ Dice Coefficient (Mean across all classes)
def dice_coefficient(y_true, y_pred):
    """Mean soft Dice coefficient over the batch.

    Both inputs are cast to float32 and reduced over axes 1, 2 and 3, giving
    one Dice value per sample (all classes pooled); the batch mean is returned.
    """
    eps = 1e-15  # guards against 0/0 when both tensors are empty
    reduce_axes = [1, 2, 3]
    truth = tf.cast(y_true, tf.float32)
    pred = tf.cast(y_pred, tf.float32)
    overlap = tf.reduce_sum(truth * pred, axis=reduce_axes)
    # Sum of both "areas" (the Dice denominator), not a set union.
    total = tf.reduce_sum(truth, axis=reduce_axes) + tf.reduce_sum(pred, axis=reduce_axes)
    per_sample = (2. * overlap + eps) / (total + eps)
    return tf.reduce_mean(per_sample)

# ‚úÖ Weighted Categorical Crossentropy
def weighted_categorical_crossentropy(y_true, y_pred):
    """Class-weighted categorical cross-entropy for one-hot targets.

    Each pixel's cross-entropy term is scaled by the weight of its true class.
    The previous version multiplied every pixel by the SUM of all class
    weights, i.e. by one constant, so no class was actually re-weighted.

    Args:
        y_true: one-hot targets with NUM_CLASSES channels on the last axis.
        y_pred: predicted class probabilities of the same shape.

    Returns:
        Scalar tensor: mean weighted cross-entropy over all pixels.
    """
    # Hard-coded per-class weights in class-index order
    # (presumably inverse class frequencies -- TODO confirm).
    class_weights = tf.constant([0.3776, 0.7605, 65.8554, 46.2381], dtype=tf.float32)
    class_weights = tf.reshape(class_weights, (1, 1, 1, NUM_CLASSES))

    y_true = tf.cast(y_true, tf.float32)
    y_pred = tf.clip_by_value(y_pred, K.epsilon(), 1.0)  # avoid log(0)

    # y_true is one-hot, so only the true class contributes, scaled by its weight.
    weighted_loss = -tf.reduce_sum(class_weights * y_true * tf.math.log(y_pred), axis=-1)
    return tf.reduce_mean(weighted_loss)

# ‚úÖ Dice Loss
def dice_loss(y_true, y_pred):
    """Soft Dice loss: ``1 - mean(Dice)``.

    Sums are taken over axes 1 and 2, so for (batch, H, W, classes) inputs one
    Dice value is computed per sample and class before averaging.
    """
    eps = 1e-6
    spatial_axes = [1, 2]
    truth = tf.cast(y_true, y_pred.dtype)
    overlap = tf.reduce_sum(truth * y_pred, axis=spatial_axes)
    total = tf.reduce_sum(truth, axis=spatial_axes) + tf.reduce_sum(y_pred, axis=spatial_axes)
    mean_dice = tf.reduce_mean((2. * overlap + eps) / (total + eps))
    return 1 - mean_dice

# ‚úÖ Lov√°sz-Softmax Loss
def lovasz_softmax_loss(y_true, y_pred, ignore_background=False):
    """Lovasz-Softmax loss averaged over classes.

    For every class the absolute errors between target and prediction are
    sorted in decreasing order and combined with the discrete gradient of the
    Jaccard loss computed from the equally sorted targets.

    Args:
        y_true: one-hot targets, classes on the last axis.
        y_pred: predicted class probabilities of the same shape.
        ignore_background: if True, class 0 is left out of the average.

    Returns:
        Scalar tensor: mean of the per-class losses that were evaluated.
    """
    y_true = tf.cast(y_true, tf.float32)
    y_pred = tf.cast(y_pred, tf.float32)

    num_classes = tf.shape(y_true)[-1]
    first_class = 1 if ignore_background else 0

    def compute_class_loss(c):
        y_true_flat = tf.reshape(y_true[..., c], [-1])
        y_pred_flat = tf.reshape(y_pred[..., c], [-1])

        # Sort all pixels by decreasing error (top_k with k = number of pixels).
        errors = tf.abs(y_true_flat - y_pred_flat)
        errors_sorted, perm = tf.nn.top_k(errors, k=tf.shape(errors)[0], sorted=True)
        y_true_sorted = tf.gather(y_true_flat, perm)

        # Discrete gradient of the Jaccard loss w.r.t. the sorted errors.
        gts = tf.reduce_sum(y_true_sorted)
        intersection = gts - tf.cumsum(y_true_sorted)
        union = gts + tf.cumsum(1. - y_true_sorted)
        jaccard = 1. - intersection / union
        grad = tf.concat([[jaccard[0]], jaccard[1:] - jaccard[:-1]], 0)

        return tf.tensordot(errors_sorted, grad, axes=1)

    # One slot per evaluated class. Sizing the array to the evaluated classes
    # and writing at ``c - first_class`` keeps every slot written; previously
    # slot 0 was never written when the background was skipped.
    losses = tf.TensorArray(dtype=tf.float32, size=num_classes - first_class)

    def loop_cond(c, losses):
        return tf.less(c, num_classes)

    def loop_body(c, losses):
        losses = losses.write(c - first_class, compute_class_loss(c))
        return c + 1, losses

    _, losses = tf.while_loop(loop_cond, loop_body, [tf.constant(first_class), losses])
    return tf.reduce_mean(losses.stack())

# ‚úÖ Combined Loss
def combined_loss(y_true, y_pred):
    """Loss registered for model loading: Lovasz-Softmax loss plus soft Dice loss.

    ``y_pred`` is expected to hold class probabilities already (the network
    built in this notebook ends in a softmax layer), so it is passed to both
    terms unchanged. The previous version applied ``tf.nn.softmax`` a second
    time before the Lovasz term, which flattens the probabilities.

    Args:
        y_true: one-hot targets.
        y_pred: predicted class probabilities.

    Returns:
        Scalar tensor ``lovasz_softmax_loss + dice_loss``.
    """
    y_true = tf.cast(y_true, y_pred.dtype)
    lovasz_loss_val = lovasz_softmax_loss(y_true, y_pred, ignore_background=False)
    dice_loss_val = dice_loss(y_true, y_pred)
    return lovasz_loss_val + dice_loss_val

class DiceCoefficient(tf.keras.metrics.Metric):
    """Serialisable soft Dice coefficient of a single class.

    The metric accumulates the per-sample Dice values between resets, so the
    reported value is the mean over the whole epoch / evaluation run. The
    previous version overwrote its state on every batch and therefore reported
    the Dice of the last batch only.

    Args:
        class_idx: index of the class channel (last axis) to evaluate. It has a
            default so the class can be rebuilt from configs lacking the key.
        name: metric name; defaults to ``DiceClass<class_idx>``.
    """

    def __init__(self, class_idx=0, name=None, **kwargs):
        if name is None:
            name = f"DiceClass{class_idx}"
        super(DiceCoefficient, self).__init__(name=name, **kwargs)
        self.class_idx = class_idx
        # Running sum of per-sample Dice values and number of samples seen.
        self.dice_sum = self.add_weight(name="dice_sum", initializer="zeros")
        self.num_samples = self.add_weight(name="num_samples", initializer="zeros")

    def update_state(self, y_true, y_pred, sample_weight=None):
        """Accumulate the Dice of every sample in the batch (``sample_weight`` is ignored)."""
        y_true_class = tf.cast(y_true[..., self.class_idx], tf.float32)
        y_pred_class = tf.cast(y_pred[..., self.class_idx], tf.float32)
        intersection = tf.reduce_sum(y_true_class * y_pred_class, axis=[1, 2])
        union = tf.reduce_sum(y_true_class, axis=[1, 2]) + tf.reduce_sum(y_pred_class, axis=[1, 2])
        dice = (2. * intersection + 1e-6) / (union + 1e-6)  # one value per sample
        self.dice_sum.assign_add(tf.reduce_sum(dice))
        self.num_samples.assign_add(tf.cast(tf.size(dice), tf.float32))

    def result(self):
        """Mean Dice over all samples since the last reset (0 if none were seen)."""
        return tf.math.divide_no_nan(self.dice_sum, self.num_samples)

    def get_config(self):
        """Add ``class_idx`` to the base config so the metric can be re-created."""
        config = super().get_config()
        config.update({"class_idx": self.class_idx})
        return config

    @classmethod
    def from_config(cls, config):
        """Rebuild the metric; recover ``class_idx`` from the name if the key is missing.

        Configs saved by a class without ``get_config`` carry no ``class_idx``;
        in that case the index is parsed from a name such as ``"DiceClass2"``,
        falling back to 0 for any other name.
        """
        config = dict(config)  # do not mutate the caller's dictionary
        if "class_idx" not in config:
            name = config.get("name", "DiceClass0")
            if name.startswith("DiceClass"):
                config["class_idx"] = int(name.replace("DiceClass", ""))
            else:
                config["class_idx"] = 0
        return cls(**config)

# ‚úÖ Helper to load Dice metrics by name
def dice_metric_loader(name):
    """Create the ``DiceCoefficient`` instance that a ``DiceClass<N>`` name refers to.

    Raises:
        ValueError: if ``name`` does not start with ``"DiceClass"`` (or if the
            remainder is not an integer).
    """
    if not name.startswith("DiceClass"):
        raise ValueError(f"Unknown Dice metric name: {name}")
    suffix = name.replace("DiceClass", "")
    return DiceCoefficient(class_idx=int(suffix))

# ‚úÖ Register all custom objects for loading the model
# Names Keras must be able to resolve while deserialising the saved model.
# Classes and functions are registered, not instances: the loader builds its
# own objects via ``from_config`` (a MeanIoU *instance* was registered before).
custom_objects = {
    'combined_loss': combined_loss,
    'lovasz_softmax_loss': lovasz_softmax_loss,
    'MeanIoU': tf.keras.metrics.MeanIoU,
    'DiceCoefficient': DiceCoefficient,
}

# Additionally expose the per-class metrics under their metric names
# (DiceClass0 .. DiceClass<NUM_CLASSES-1>), kept as in the original setup.
for i in range(NUM_CLASSES):
    custom_objects[f'DiceClass{i}'] = dice_metric_loader(f'DiceClass{i}')

# Load the trained model together with its compile-time loss and metrics.
model_segnet = load_model('C:\\Users\\User\\best_unet_model_onlineDA_128_lovaszloss_segnet.keras', custom_objects=custom_objects)

print("‚úÖ Model loaded successfully.")

In [None]:
from tensorflow.keras.utils import Sequence
import cv2
import numpy as np
import os

class ImageMaskGenerator(Sequence):
    """Keras ``Sequence`` that loads image / mask pairs from disk per batch.

    Images are read with OpenCV, converted to RGB, resized to ``img_size`` and
    scaled to [0, 1]. Masks are resized with nearest-neighbour interpolation,
    mapped from colours to class indices and one-hot encoded.

    Args:
        image_paths: list of image file paths.
        mask_paths: list of mask file paths; ``mask_paths[i]`` belongs to
            ``image_paths[i]``.
        batch_size: samples per batch (the last batch may be smaller).
        num_classes: depth of the one-hot encoding.
        img_size: size passed to ``cv2.resize`` for both images and masks.
        shuffle: reshuffle the sample order at construction and after every epoch.
        **kwargs: forwarded to the ``Sequence`` base class.

    Raises:
        ValueError: if the two path lists differ in length.
    """

    def __init__(self, image_paths, mask_paths, batch_size=4, num_classes=4, img_size=(224, 224),
                 shuffle=True, **kwargs):
        # Newer Keras versions warn when the base initialiser is not called.
        super().__init__(**kwargs)
        if len(image_paths) != len(mask_paths):
            raise ValueError(
                f"Number of images ({len(image_paths)}) and masks ({len(mask_paths)}) differ"
            )
        self.image_paths = image_paths
        self.mask_paths = mask_paths
        self.batch_size = batch_size
        self.num_classes = num_classes
        self.img_size = img_size
        self.shuffle = shuffle
        self.indices = np.arange(len(self.image_paths))
        # Mask colour (R, G, B) -> class index.
        self.CLASS_MAP = {
            (255, 0, 0): 1,
            (0, 255, 0): 2,
            (0, 0, 255): 3,
            (0, 0, 0): 0,
        }
        self.on_epoch_end()

    def __len__(self):
        """Number of batches per epoch (a final partial batch is included)."""
        return int(np.ceil(len(self.image_paths) / self.batch_size))

    @staticmethod
    def _read_rgb(path):
        """Read ``path`` with OpenCV and return it as RGB; raise if it cannot be decoded."""
        bgr = cv2.imread(path)
        if bgr is None:
            # cv2.imread signals failure with None; fail here with the offending path
            # instead of with an opaque error inside cvtColor.
            raise FileNotFoundError(f"Could not read image file: {path}")
        return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)

    def __getitem__(self, index):
        """Return batch ``index`` as ``(images, one_hot_masks)`` numpy arrays."""
        batch_indices = self.indices[index * self.batch_size:(index + 1) * self.batch_size]
        batch_images = []
        batch_masks = []

        for i in batch_indices:
            img = self._read_rgb(self.image_paths[i])
            img = cv2.resize(img, self.img_size)
            img = img.astype(np.float32) / 255.0

            mask = self._read_rgb(self.mask_paths[i])
            # Nearest-neighbour keeps the exact class colours while resizing.
            mask = cv2.resize(mask, self.img_size, interpolation=cv2.INTER_NEAREST)
            mask = self.rgb_to_class(mask)
            mask = tf.keras.utils.to_categorical(mask, num_classes=self.num_classes)

            batch_images.append(img)
            batch_masks.append(mask)

        return np.array(batch_images), np.array(batch_masks)

    def on_epoch_end(self):
        """Reshuffle the sample order in place when ``shuffle`` is enabled."""
        if self.shuffle:
            np.random.shuffle(self.indices)

    def rgb_to_class(self, mask_array):
        """Map a (H, W, channels) colour mask to a (H, W) uint8 class-index mask.

        Colours missing from ``self.CLASS_MAP`` keep the initial value 0.
        """
        h, w, _ = mask_array.shape
        class_mask = np.zeros((h, w), dtype=np.uint8)
        for rgb, class_idx in self.CLASS_MAP.items():
            matches = np.all(mask_array == rgb, axis=-1)
            class_mask[matches] = class_idx
        return class_mask


import os

def load_paths(image_dir, mask_dir):
    """Return ``(image_paths, mask_paths)``: sorted full paths of the ``.png`` files.

    Each directory is listed and sorted independently (plain string order);
    only names ending in lower-case ``.png`` are kept.
    """
    def _png_paths(folder):
        found = []
        for entry in os.listdir(folder):
            if entry.endswith('.png'):
                found.append(os.path.join(folder, entry))
        found.sort()
        return found

    images = _png_paths(image_dir)
    masks = _png_paths(mask_dir)
    return images, masks

# Collect the .png paths of the train and validation splits
# (directories come from the constants defined in an earlier cell).
train_imgs, train_masks = load_paths(train_image_dir, train_mask_dir)
val_imgs, val_masks = load_paths(val_image_dir, val_mask_dir)


In [None]:
# Disk-backed batch generators; run_training reads these two module-level names.
# Both use the class default shuffle=True, so the validation order is shuffled as well.
train_gen = ImageMaskGenerator(train_imgs, train_masks, batch_size=8)
val_gen = ImageMaskGenerator(val_imgs, val_masks, batch_size=8)


In [None]:
def run_training(X_train, y_train, X_val, y_val, batch_size=8, epochs=3, repeats=1):
    """Benchmark wall-clock time and GPU power draw of short training runs.

    For every repeat the saved model is reloaded, ``nvidia-smi`` is started in
    the background to log power draw (``-lms 500``), and the model is fitted
    for ``epochs`` epochs.

    Note: training reads the module-level ``train_gen`` / ``val_gen``
    generators and ``custom_objects``. The array arguments and ``batch_size``
    are accepted for backward compatibility but are not used.

    Args:
        X_train, y_train, X_val, y_val: Unused (see note above).
        batch_size: Unused; the generators define the batch size.
        epochs: Number of epochs per repeat.
        repeats: Number of times the whole measurement is repeated.

    Returns:
        Tuple ``(epoch_times_all, power_samples_all)``. Each repeat contributes
        ``epochs`` entries holding that repeat's average epoch time in seconds
        and its average power in watts (NaN when power logging failed).
    """
    # Imported locally so the function does not depend on another cell having
    # imported these names first.
    import gc
    import subprocess
    import tempfile
    import time

    import numpy as np
    import tensorflow as tf
    from tensorflow.keras import backend as K

    epoch_times_all = []
    power_samples_all = []

    for r in range(repeats):
        print(f"\nüîÅ Repeat {r+1}/{repeats}")

        # Start every repeat from a clean Keras state.
        K.clear_session()
        gc.collect()

        model = tf.keras.models.load_model(
            'C:\\Users\\User\\best_unet_model_onlineDA_128_lovaszloss_segnet.keras',
            custom_objects=custom_objects
        )

        # Power samples go to a temporary file rather than a PIPE: nothing
        # reads the pipe while fit() runs, so a full pipe buffer could stall
        # nvidia-smi and truncate the measurement.
        power_log = tempfile.TemporaryFile(mode='w+')
        power_proc = None
        power_error = None
        try:
            start = time.time()

            try:
                power_proc = subprocess.Popen(
                    ['nvidia-smi', '--query-gpu=power.draw', '--format=csv,noheader,nounits', '-lms', '500'],
                    stdout=power_log,
                    stderr=subprocess.DEVNULL
                )
            except OSError as exc:
                # nvidia-smi is missing: still time the run, report NaN power.
                power_error = exc

            model.fit(
                train_gen,
                validation_data=val_gen,
                epochs=epochs,
                verbose=1
            )

            end = time.time()
        finally:
            # Always stop the sampler, also when fit() raises.
            if power_proc is not None:
                power_proc.terminate()
                power_proc.wait()
            power_log.seek(0)
            power_output = power_log.read()
            power_log.close()

        total_time = end - start
        avg_epoch_time = total_time / epochs
        epoch_times_all.extend([avg_epoch_time] * epochs)

        # Handle power logs
        try:
            if power_error is not None:
                raise power_error
            power_values = [float(line) for line in power_output.strip().split('\n') if line.strip()]
            if not power_values:
                raise ValueError("nvidia-smi produced no power samples")
            avg_power = np.mean(power_values)
            power_samples_all.extend([avg_power] * epochs)
            print(f"‚ö° Avg Power: {avg_power:.2f} W")
        except Exception:
            print("‚ö†Ô∏è Power log failed.")
            power_samples_all.extend([np.nan] * epochs)

        # Final cleanup
        del model
        gc.collect()
        K.clear_session()

    return epoch_times_all, power_samples_all

In [None]:
# Run the benchmark and collect per-epoch timing and power figures.
epoch_times, power_vals = run_training(X_train, y_train, X_val, y_val)

# Compute stats
epoch_times = np.array(epoch_times)
power_vals = np.array(power_vals)

mean_time = np.mean(epoch_times)
# NOTE(review): run_training stores one average per repeat, repeated once per
# epoch, so with a single repeat this standard deviation is 0 by construction.
std_time = np.std(epoch_times)

# nanmean ignores entries recorded as NaN when power logging failed.
mean_power = np.nanmean(power_vals)
# watts * seconds = joules; dividing by 3600 converts to watt-hours.
energy_per_epoch_wh = (mean_power * mean_time) / 3600

# Estimate throughput (GFLOP per second) from an assumed cost of 4 GFLOPs per sample.
# NOTE(review): the meaning of the factor 2 is not stated in this cell - confirm.
samples_per_epoch = len(X_train)
estimated_flops_per_sample = 4e9  # 4 GFLOPs
gflops = (2 * estimated_flops_per_sample * samples_per_epoch) / (mean_time * 1e9)

print("\nüìä Summary:")
print(f"‚è±Ô∏è  Average epoch time: {mean_time:.2f} ¬± {std_time:.2f} sec")
print(f"‚öôÔ∏è  Estimated GFLOPS: {gflops:.2f}")
print(f"‚ö° Average power: {mean_power:.2f} W")
print(f"üîã Energy per epoch: {energy_per_epoch_wh:.4f} Wh")

In [None]:

import tensorflow as tf
from tensorflow.keras import Input, Model
from tensorflow.keras.layers import Conv2D, UpSampling2D, Concatenate, BatchNormalization, Activation
from tensorflow.keras.applications import InceptionResNetV2
import gc

# Constants
IMG_HEIGHT = 224
IMG_WIDTH = 224
CHANNELS = 3
NUM_CLASSES = 4  # Brain, CSP, LV, Background

class ResizeLayer(tf.keras.layers.Layer):
    """Keras layer that bilinearly resizes its input to a fixed spatial size."""

    def __init__(self, target_size, **kwargs):
        """Remember the output size; remaining kwargs go to the base ``Layer``."""
        super().__init__(**kwargs)
        self.target_size = target_size

    def call(self, inputs):
        """Resize ``inputs`` to ``self.target_size`` with bilinear interpolation."""
        resized = tf.image.resize(inputs, self.target_size, method='bilinear')
        return resized

    def get_config(self):
        """Return the base config extended with ``target_size`` for serialization."""
        base_config = super().get_config()
        base_config["target_size"] = self.target_size
        return base_config

def conv_block(x, filters, kernel_size=3, padding='same', activation='relu'):
    """Apply Conv2D -> BatchNormalization -> Activation twice in a row.

    Both passes use the same filter count, kernel size, padding and activation.
    """
    for _ in range(2):
        x = Conv2D(filters, kernel_size, padding=padding)(x)
        x = BatchNormalization()(x)
        x = Activation(activation)(x)
    return x

def build_full_inceptionresnetv2_unet(input_shape=(IMG_HEIGHT, IMG_WIDTH, CHANNELS), num_classes=NUM_CLASSES):
    """
    Build a U-Net style segmentation model on an ImageNet-pretrained
    InceptionResNetV2 encoder whose layers are all left trainable.

    The decoder runs one identical stage per skip level (deepest first):
    upsample x2, resize to the skip tensor's spatial size, refine with a conv
    block, concatenate with a 1x1-reduced copy of the skip tensor and fuse the
    result with another conv block.

    Args:
        input_shape: (height, width, channels) of the input image.
        num_classes: Number of output classes (channels of the softmax output).

    Returns:
        Keras Model mapping an input image to per-pixel class probabilities
        at the input's height and width.
    """
    # Input layer (no fixed batch size)
    inputs = Input(shape=input_shape)

    # Pretrained backbone without its classification head.
    base_model = InceptionResNetV2(
        input_tensor=inputs,
        include_top=False,
        weights='imagenet',
        pooling=None
    )

    # Fine-tune the whole encoder.
    for layer in base_model.layers:
        layer.trainable = True

    # Skip connections, shallowest (encoder1) to deepest (encoder5).
    # NOTE(review): 'activation' and 'activation_3' look like auto-numbered
    # layer names; presumably they only resolve when the backbone is the first
    # model built in a fresh Keras session - confirm before reusing this in a
    # long-running kernel.
    encoder1 = base_model.get_layer('activation').output
    encoder2 = base_model.get_layer('activation_3').output
    encoder3 = base_model.get_layer('block35_10_ac').output
    encoder4 = base_model.get_layer('block17_20_ac').output
    encoder5 = base_model.get_layer('conv_7b_ac').output

    # Compress the deepest features to 512 channels to limit decoder size.
    bottleneck = Conv2D(512, 1, padding='same')(encoder5)
    bottleneck = BatchNormalization()(bottleneck)
    bottleneck = Activation('relu')(bottleneck)

    def _decoder_stage(x, skip_source, up_filters, skip_filters, merge_filters):
        """Run one decoder level: upsample, align with the skip tensor, fuse."""
        x = UpSampling2D(size=(2, 2))(x)
        # Doubling does not always land on the skip tensor's size, so the
        # upsampled map is resized to the skip tensor's height and width.
        x = ResizeLayer(target_size=(skip_source.shape[1], skip_source.shape[2]))(x)
        x = conv_block(x, up_filters, kernel_size=3)

        # Reduce the skip tensor's channels before concatenation.
        skip = Conv2D(skip_filters, 1, padding='same')(skip_source)
        skip = BatchNormalization()(skip)
        skip = Activation('relu')(skip)

        x = Concatenate()([x, skip])
        return conv_block(x, merge_filters)

    # (skip tensor, filters after upsampling, skip filters, filters after merge)
    decoder_plan = (
        (encoder4, 512, 256, 384),
        (encoder3, 384, 128, 192),
        (encoder2, 192, 96, 96),
        (encoder1, 96, 48, 48),
    )
    decoded = bottleneck
    for skip_source, up_filters, skip_filters, merge_filters in decoder_plan:
        decoded = _decoder_stage(decoded, skip_source, up_filters, skip_filters, merge_filters)

    # Final upsampling towards the input resolution.
    up_final = UpSampling2D(size=(2, 2))(decoded)
    up_final = conv_block(up_final, 32)

    # Ensure the final size matches the input exactly.
    if up_final.shape[1] != input_shape[0] or up_final.shape[2] != input_shape[1]:
        up_final = ResizeLayer(target_size=(input_shape[0], input_shape[1]))(up_final)

    # Segmentation head; the output dtype is pinned to float32.
    outputs = Conv2D(num_classes, 1, activation='softmax', dtype='float32')(up_final)

    return Model(inputs=inputs, outputs=outputs)

# Build the InceptionResNetV2 U-Net with the notebook-level constants.
print("Creating full InceptionResNetV2-UNet model...")
model = build_full_inceptionresnetv2_unet(input_shape=(IMG_HEIGHT, IMG_WIDTH, CHANNELS), num_classes=NUM_CLASSES)
print("Model created successfully!")

# Clear memory
# NOTE(review): clear_session() runs after the model has been built and before
# it is summarized - confirm this ordering is intended.
gc.collect()
tf.keras.backend.clear_session()

# Model summary
model.summary()

In [None]:
import tensorflow as tf
from tensorflow.keras import backend as K
from tensorflow.keras.models import load_model

class ResizeLayer(tf.keras.layers.Layer):
    """Keras layer that bilinearly resizes its input to a fixed spatial size."""

    def __init__(self, target_size, **kwargs):
        """Remember the output size; remaining kwargs go to the base ``Layer``."""
        super().__init__(**kwargs)
        self.target_size = target_size

    def call(self, inputs):
        """Resize ``inputs`` to ``self.target_size`` with bilinear interpolation."""
        resized = tf.image.resize(inputs, self.target_size, method='bilinear')
        return resized

    def get_config(self):
        """Return the base config extended with ``target_size`` for serialization."""
        base_config = super().get_config()
        base_config["target_size"] = self.target_size
        return base_config

# Number of classes (adjust if needed)
NUM_CLASSES = 4

# Dice coefficient (computed per sample over all pixels and classes, then averaged over the batch)
def dice_coefficient(y_true, y_pred):
    """Return the batch mean of per-sample Dice scores.

    Both inputs are cast to float32 and reduced over axes 1, 2 and 3, giving
    one Dice value per sample before averaging.
    """
    eps = 1e-15
    truth = tf.cast(y_true, tf.float32)
    pred = tf.cast(y_pred, tf.float32)
    sample_axes = [1, 2, 3]
    overlap = tf.reduce_sum(truth * pred, axis=sample_axes)
    total = tf.reduce_sum(truth, axis=sample_axes) + tf.reduce_sum(pred, axis=sample_axes)
    return tf.reduce_mean((2. * overlap + eps) / (total + eps))

# Weighted categorical cross-entropy
def weighted_categorical_crossentropy(y_true, y_pred):
    """Categorical cross-entropy with each pixel weighted by its true class.

    The previous version multiplied every pixel by the sum of all class
    weights, a constant factor that left the classes equally weighted.

    Args:
        y_true: One-hot labels with the class axis last (four classes, matching
            the hard-coded weights).
        y_pred: Predicted class probabilities of the same shape.

    Returns:
        Scalar tensor: mean weighted cross-entropy over all pixels.
    """
    class_weights = tf.constant([0.3776, 0.7605, 65.8554, 46.2381], dtype=tf.float32)
    y_true = tf.cast(y_true, tf.float32)
    # Clip away exact zeros so the logarithm stays finite.
    y_pred = tf.clip_by_value(y_pred, tf.keras.backend.epsilon(), 1.0)
    loss = -tf.reduce_sum(y_true * tf.math.log(y_pred), axis=-1)
    # Broadcast the weights along the class axis and pick each pixel's weight
    # through its one-hot label.
    class_weights = tf.reshape(class_weights, (1, 1, 1, -1))
    pixel_weights = tf.reduce_sum(y_true * class_weights, axis=-1)
    weighted_loss = loss * pixel_weights
    return tf.reduce_mean(weighted_loss)

# Dice loss
def dice_loss(y_true, y_pred):
    """Return 1 minus the mean Dice score.

    Dice is computed from sums over axes 1 and 2 (presumably the spatial axes
    of a batch-height-width-class tensor) and then averaged over the
    remaining axes.
    """
    eps = 1e-6
    truth = tf.cast(y_true, y_pred.dtype)
    reduce_axes = [1, 2]
    overlap = tf.reduce_sum(truth * y_pred, axis=reduce_axes)
    total = tf.reduce_sum(truth, axis=reduce_axes) + tf.reduce_sum(y_pred, axis=reduce_axes)
    dice_scores = (2. * overlap + eps) / (total + eps)
    return 1 - tf.reduce_mean(dice_scores)

# Lovasz-Softmax loss
def lovasz_softmax_loss(y_true, y_pred, ignore_background=False):
    """Lovasz-Softmax style loss averaged over the evaluated classes.

    For every class the absolute errors ``|y_true - y_pred|`` are sorted in
    decreasing order and combined (dot product) with the discrete gradient of
    the Jaccard loss along that ordering; the per-class values are averaged.

    Args:
        y_true: Labels with the class axis last; expected to be one-hot (0/1),
            since the Jaccard terms are cumulative sums of these values.
        y_pred: Predicted class scores of the same shape.
        ignore_background: If True, class 0 is left out of the average.

    Returns:
        Scalar float32 tensor.
    """
    y_true = tf.cast(y_true, tf.float32)
    y_pred = tf.cast(y_pred, tf.float32)

    num_classes = tf.shape(y_true)[-1]
    first_class = 1 if ignore_background else 0

    def compute_class_loss(c):
        y_true_class = y_true[..., c]
        y_pred_class = y_pred[..., c]

        y_true_flat = tf.reshape(y_true_class, [-1])
        y_pred_flat = tf.reshape(y_pred_class, [-1])

        # Sort all pixels by decreasing error (top_k with k = number of pixels).
        errors = tf.abs(y_true_flat - y_pred_flat)
        errors_sorted, perm = tf.nn.top_k(errors, k=tf.shape(errors)[0], sorted=True)
        y_true_sorted = tf.gather(y_true_flat, perm)

        # Jaccard loss after including the first k sorted pixels, then its
        # first difference, which weights the sorted errors.
        gts = tf.reduce_sum(y_true_sorted)
        intersection = gts - tf.cumsum(y_true_sorted)
        union = gts + tf.cumsum(1. - y_true_sorted)
        jaccard = 1. - intersection / union
        grad = tf.concat([[jaccard[0]], jaccard[1:] - jaccard[:-1]], 0)

        return tf.tensordot(errors_sorted, grad, axes=1)

    # One slot per class that is actually evaluated. Sizing the array by
    # num_classes while starting at class 1 would leave slot 0 unwritten.
    losses = tf.TensorArray(dtype=tf.float32, size=num_classes - first_class)

    def loop_cond(c, losses):
        return tf.less(c, num_classes)

    def loop_body(c, losses):
        losses = losses.write(c - first_class, compute_class_loss(c))
        return c + 1, losses

    _, losses = tf.while_loop(loop_cond, loop_body, [tf.constant(first_class), losses])
    return tf.reduce_mean(losses.stack())

# Combined loss (Dice + Lovasz-Softmax)
def combined_loss(y_true, y_pred):
    """Sum of a Dice loss and the Lovasz-Softmax loss.

    Args:
        y_true: One-hot labels with the class axis last.
        y_pred: Predicted class probabilities (the output of a softmax layer).

    Returns:
        Scalar tensor: ``lovasz_softmax_loss + dice_loss``.
    """
    smooth = 1e-6
    y_true = tf.cast(y_true, y_pred.dtype)
    intersection = tf.reduce_sum(y_true * y_pred, axis=[1, 2])
    union = tf.reduce_sum(y_true, axis=[1, 2]) + tf.reduce_sum(y_pred, axis=[1, 2])
    dice_loss_val = 1 - (2. * intersection + smooth) / (union + smooth)
    dice_loss_val = tf.reduce_mean(dice_loss_val)

    # y_pred is already a probability map: the Dice term above uses it as such
    # and the models built in this notebook end in a softmax layer. It is
    # therefore passed on unchanged; a second softmax would flatten the
    # distribution and weaken the Lovasz term.
    lovasz_loss_val = lovasz_softmax_loss(y_true, y_pred, ignore_background=False)
    return lovasz_loss_val + dice_loss_val

class DiceCoefficient(tf.keras.metrics.Metric):
    """Dice coefficient of a single class channel, exposed as a Keras metric.

    NOTE(review): ``update_state`` overwrites the stored value, so ``result()``
    reflects only the most recent batch rather than a running average over the
    epoch. Accumulating would need additional state variables; confirm that is
    acceptable for models saved with this metric before changing it.
    """

    def __init__(self, class_idx=0, name=None, **kwargs):
        """Create the metric for channel ``class_idx``.

        ``class_idx`` defaults to 0 so the metric can be rebuilt from configs
        that lack it; the default name is ``DiceClass<class_idx>``.
        """
        if name is None:
            name = f"DiceClass{class_idx}"
        super(DiceCoefficient, self).__init__(name=name, **kwargs)
        self.class_idx = class_idx
        self.dice = self.add_weight(name="dice", initializer="zeros")

    def update_state(self, y_true, y_pred, sample_weight=None):
        """Store the batch-mean Dice of this class; ``sample_weight`` is ignored."""
        # Align dtypes the same way the loss functions in this notebook do:
        # multiplying tensors of different dtypes raises an error.
        y_true = tf.cast(y_true, y_pred.dtype)
        y_true_class = y_true[..., self.class_idx]
        y_pred_class = y_pred[..., self.class_idx]
        intersection = tf.reduce_sum(y_true_class * y_pred_class, axis=[1, 2])
        union = tf.reduce_sum(y_true_class, axis=[1, 2]) + tf.reduce_sum(y_pred_class, axis=[1, 2])
        dice = (2. * intersection + 1e-6) / (union + 1e-6)
        self.dice.assign(tf.cast(tf.reduce_mean(dice), self.dice.dtype))

    def result(self):
        """Return the most recently stored Dice value."""
        return self.dice

    def get_config(self):
        """Return the base config extended with ``class_idx``."""
        config = super().get_config()
        config.update({"class_idx": self.class_idx})
        return config

    @classmethod
    def from_config(cls, config):
        """Rebuild the metric, recovering ``class_idx`` from the name if needed."""
        if "class_idx" not in config:
            # Try to extract class index from name like "DiceClass2"
            name = config.get("name", "DiceClass0")
            if name.startswith("DiceClass"):
                config["class_idx"] = int(name.replace("DiceClass", ""))
            else:
                config["class_idx"] = 0
        return cls(**config)

# Helper that builds a Dice metric from a name such as "DiceClass2"
def dice_metric_loader(name):
    """Build the ``DiceCoefficient`` metric described by a ``DiceClass<idx>`` name.

    Raises:
        ValueError: If ``name`` does not start with ``"DiceClass"``.
    """
    if not name.startswith("DiceClass"):
        raise ValueError(f"Unknown Dice metric name: {name}")
    index_text = name.replace("DiceClass", "")
    return DiceCoefficient(class_idx=int(index_text))

# Register every custom object referenced when deserializing the saved model.
# NOTE(review): 'MeanIoU' and the DiceClass entries map names to metric
# instances rather than classes - confirm this is what load_model expects.
custom_objects = {
    'ResizeLayer': ResizeLayer,
    'combined_loss': combined_loss,
    'lovasz_softmax_loss': lovasz_softmax_loss,
    'MeanIoU': tf.keras.metrics.MeanIoU(num_classes=NUM_CLASSES),
    'DiceCoefficient': DiceCoefficient,
}

# Add one entry per class: DiceClass0 ... DiceClass{NUM_CLASSES - 1}.
for i in range(NUM_CLASSES):
    custom_objects[f'DiceClass{i}'] = dice_metric_loader(f'DiceClass{i}')

# Load the saved model (path is relative to the working directory).
model_inceptionresnetv2 = load_model('lovaszloss_unet++_inceptionresnetv2.keras', custom_objects=custom_objects)

print("‚úÖ Model loaded successfully.")

In [None]:
# Rebuild custom_objects, replacing the dictionary defined in the previous cell.
# SAFE custom_objects (no functions returning arrays)
custom_objects = {
    'ResizeLayer': ResizeLayer,
    'combined_loss': combined_loss,
    'lovasz_softmax_loss': lovasz_softmax_loss,
    'MeanIoU': tf.keras.metrics.MeanIoU(num_classes=NUM_CLASSES),
    'DiceCoefficient': DiceCoefficient,  # needed to deserialize
}

# Add DiceClass0 ... DiceClass{NUM_CLASSES - 1} as metric instances.
for i in range(NUM_CLASSES):
    custom_objects[f'DiceClass{i}'] = DiceCoefficient(class_idx=i)


In [None]:
def run_training(train_gen, val_gen, model_path, custom_objects, batch_size=8, epochs=3, repeats=1):
    """Benchmark wall-clock time and GPU power draw of short training runs.

    For every repeat the model is reloaded from ``model_path``, ``nvidia-smi``
    is started in the background to log power draw (``-lms 500``), and the
    model is fitted on the generators for ``epochs`` epochs.

    Args:
        train_gen: Training data generator passed to ``model.fit``.
        val_gen: Validation data generator passed to ``model.fit``.
        model_path: Path of the saved Keras model to reload for every repeat.
        custom_objects: Custom objects forwarded to ``load_model``.
        batch_size: Unused; the generators define the batch size.
        epochs: Number of epochs per repeat.
        repeats: Number of times the whole measurement is repeated.

    Returns:
        Tuple ``(epoch_times_all, power_samples_all)``. Each repeat contributes
        ``epochs`` entries holding that repeat's average epoch time in seconds
        and its average power in watts (NaN when power logging failed).
    """
    import gc
    import subprocess
    import tempfile
    import time

    import numpy as np
    import tensorflow as tf
    from tensorflow.keras import backend as K

    epoch_times_all = []
    power_samples_all = []

    for r in range(repeats):
        print(f"\nüîÅ Repeat {r+1}/{repeats}")

        # Start every repeat from a clean Keras state.
        K.clear_session()
        gc.collect()

        # Reload the model so every repeat starts from the same weights.
        model = tf.keras.models.load_model(model_path, custom_objects=custom_objects)

        # Power samples go to a temporary file rather than a PIPE: nothing
        # reads the pipe while fit() runs, so a full pipe buffer could stall
        # nvidia-smi and truncate the measurement.
        power_log = tempfile.TemporaryFile(mode='w+')
        power_proc = None
        power_error = None
        try:
            start = time.time()

            try:
                power_proc = subprocess.Popen(
                    ['nvidia-smi', '--query-gpu=power.draw', '--format=csv,noheader,nounits', '-lms', '500'],
                    stdout=power_log,
                    stderr=subprocess.DEVNULL
                )
            except OSError as exc:
                # nvidia-smi is missing: still time the run, report NaN power.
                power_error = exc

            # Train from the generators (batches are loaded on demand).
            model.fit(
                train_gen,
                validation_data=val_gen,
                epochs=epochs,
                verbose=1
            )

            end = time.time()
        finally:
            # Always stop the sampler, also when fit() raises.
            if power_proc is not None:
                power_proc.terminate()
                power_proc.wait()
            power_log.seek(0)
            power_output = power_log.read()
            power_log.close()

        total_time = end - start
        avg_epoch_time = total_time / epochs
        epoch_times_all.extend([avg_epoch_time] * epochs)

        # Turn the logged samples into one average per repeat.
        try:
            if power_error is not None:
                raise power_error
            power_values = [float(line) for line in power_output.strip().split('\n') if line.strip()]
            if not power_values:
                raise ValueError("nvidia-smi produced no power samples")
            avg_power = np.mean(power_values)
            power_samples_all.extend([avg_power] * epochs)
            print(f"‚ö° Avg Power: {avg_power:.2f} W")
        except Exception as e:
            print(f"‚ö†Ô∏è Power log failed: {e}")
            power_samples_all.extend([np.nan] * epochs)

        # Cleanup
        del model
        gc.collect()
        K.clear_session()

    return epoch_times_all, power_samples_all

In [None]:
model_path = 'lovaszloss_unet++_inceptionresnetv2.keras'

# NOTE(review): run_training accepts batch_size but never forwards it; the
# effective batch size is the one train_gen / val_gen were constructed with.
epoch_times, power_vals = run_training(
    train_gen=train_gen,
    val_gen=val_gen,
    model_path=model_path,
    custom_objects=custom_objects,
    batch_size=4,
    epochs=3,
    repeats=1
)

# Compute stats
epoch_times = np.array(epoch_times)
power_vals = np.array(power_vals)

mean_time = np.mean(epoch_times)
std_time = np.std(epoch_times)

# nanmean ignores entries recorded as NaN when power logging failed.
mean_power = np.nanmean(power_vals)
# watts * seconds = joules; dividing by 3600 converts to watt-hours.
energy_per_epoch_wh = (mean_power * mean_time) / 3600

# Estimate throughput (GFLOP per second) from an assumed cost of 4 GFLOPs per sample.
# Training is driven by train_gen, which was built from train_imgs, so the
# sample count comes from that list instead of an X_train array that this
# generator-based workflow never creates.
samples_per_epoch = len(train_imgs)
estimated_flops_per_sample = 4e9  # 4 GFLOPs
gflops = (2 * estimated_flops_per_sample * samples_per_epoch) / (mean_time * 1e9)

print("\nüìä Summary:")
print(f"‚è±Ô∏è  Average epoch time: {mean_time:.2f} ¬± {std_time:.2f} sec")
print(f"‚öôÔ∏è  Estimated GFLOPS: {gflops:.2f}")
print(f"‚ö° Average power: {mean_power:.2f} W")
print(f"üîã Energy per epoch: {energy_per_epoch_wh:.4f} Wh")

In [None]:
import tensorflow as tf
from tensorflow.keras import Input, Model, Sequential
from tensorflow.keras.layers import Conv2D, MaxPooling2D, BatchNormalization, Activation, Add
from tensorflow.keras.layers import Dense, Dropout, Layer, Reshape, Permute, Multiply, Concatenate
from tensorflow.keras.layers import GlobalAveragePooling2D, LayerNormalization, UpSampling2D
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.losses import CategoricalCrossentropy
from tensorflow.keras.applications import EfficientNetB4

class ResizeToMatchLayer(Layer):
    """Resize the first input to the height and width of the second input."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def call(self, inputs):
        """Bilinearly resize ``inputs[0]`` to the spatial size of ``inputs[1]``."""
        source, reference = inputs
        # The size is read at run time from the reference tensor.
        reference_dims = tf.shape(reference)
        new_size = [reference_dims[1], reference_dims[2]]
        return tf.image.resize(source, new_size, method='bilinear')

    def compute_output_shape(self, input_shape):
        """Batch and channels come from the source, height and width from the reference."""
        source_shape = input_shape[0]
        reference_shape = input_shape[1]
        return (source_shape[0], reference_shape[1], reference_shape[2], source_shape[3])

def conv_block(x, filters, kernel_size=3, strides=1, padding='same', use_bn=True, activation='relu'):
    """Single Conv2D followed by optional BatchNormalization and activation.

    BatchNormalization is skipped when ``use_bn`` is falsy and the activation
    layer is skipped when ``activation`` is falsy.
    """
    out = Conv2D(filters, kernel_size, strides=strides, padding=padding)(x)
    out = BatchNormalization()(out) if use_bn else out
    return Activation(activation)(out) if activation else out

def attention_gate(x, g, inter_channels):
    """
    Additive attention gate that rescales a skip feature map.

    Args:
        x: Feature map from the skip connection (encoder side).
        g: Gating signal from the decoder; resized to x's spatial size first.
        inter_channels: Channel count of the intermediate 1x1 projections.

    Returns:
        ``x`` multiplied element-wise by a single-channel sigmoid attention map.
    """
    # Bring the gating signal to the skip tensor's height and width.
    gate = ResizeToMatchLayer()([g, x])

    # Bias-free 1x1 projections of both inputs into a common channel space.
    skip_projection = Conv2D(inter_channels, 1, use_bias=False, padding='same')(x)
    gate_projection = Conv2D(inter_channels, 1, use_bias=False, padding='same')(gate)

    # Sum the projections and apply ReLU.
    relu = Activation('relu')
    combined = relu(Add()([skip_projection, gate_projection]))

    # Collapse to one channel and squash to (0, 1) to get the attention map.
    attention_logits = Conv2D(1, 1, use_bias=False, padding='same')(combined)
    attention_map = Activation('sigmoid')(attention_logits)

    return Multiply()([x, attention_map])

def decoder_block(x, skip_connection, filters, use_attention=True):
    """One decoder step: upsample, align, optionally gate the skip, merge, refine.

    Args:
        x: Decoder feature map from the previous (coarser) level.
        skip_connection: Encoder feature map for this level.
        filters: Filters of the two refining conv blocks; the attention gate
            uses ``filters // 2`` intermediate channels.
        use_attention: Whether to pass the skip through ``attention_gate``.
    """
    upsampled = UpSampling2D(size=(2, 2), interpolation='bilinear')(x)
    # Match the skip tensor's spatial size so the tensors can be concatenated.
    upsampled = ResizeToMatchLayer()([upsampled, skip_connection])

    skip = (
        attention_gate(skip_connection, upsampled, filters // 2)
        if use_attention
        else skip_connection
    )

    merged = Concatenate()([upsampled, skip])
    for _ in range(2):
        merged = conv_block(merged, filters, 3, padding='same')
    return merged

def build_efficientnet_attention_unet(input_shape, num_classes):
    """
    Assemble an Attention U-Net whose encoder is an ImageNet-pretrained
    EfficientNetB4 without its classification head.

    Args:
        input_shape: Input image shape (height, width, channels).
        num_classes: Number of segmentation classes (softmax output channels).

    Returns:
        A Keras Model instance.
    """
    inputs = Input(shape=input_shape)

    # Encoder backbone; no layer is frozen here.
    base_model = EfficientNetB4(
        weights='imagenet',
        include_top=False,
        input_tensor=inputs
    )

    # Base width of the decoder; level k uses initial_filters * 2**k filters.
    initial_filters = 32

    # Encoder taps, shallowest first, followed by the bottleneck ("bridge").
    skip_layer_names = ('block1b_add', 'block2d_add', 'block3d_add', 'block5e_add')
    skip_tensors = [base_model.get_layer(layer_name).output for layer_name in skip_layer_names]
    bridge = base_model.get_layer('top_activation').output

    # Reduce channels on every skip and on the bridge to limit parameter count.
    skip_convs = [
        conv_block(tensor, initial_filters * (2 ** depth))
        for depth, tensor in enumerate(skip_tensors)
    ]
    bridge_conv = conv_block(bridge, initial_filters * 16)

    # Attention-gated decoder, walking from the deepest skip to the shallowest.
    decoded = bridge_conv
    for depth in range(len(skip_convs) - 1, -1, -1):
        decoded = decoder_block(
            decoded, skip_convs[depth], initial_filters * (2 ** depth), use_attention=True
        )

    # One more x2 upsampling, then the per-pixel softmax classifier.
    final = UpSampling2D(size=(2, 2), interpolation='bilinear')(decoded)
    outputs = Conv2D(num_classes, 1, padding='same', activation='softmax')(final)

    return Model(inputs=inputs, outputs=outputs)

# Build the EfficientNetB4 Attention U-Net for 224x224 RGB inputs and 4 classes.
model = build_efficientnet_attention_unet(input_shape=(224, 224, 3), num_classes=4)

# Print the layer-by-layer summary.
model.summary()

In [None]:
def dice_coefficient(y_true, y_pred):
    """Return the batch mean of per-sample Dice scores.

    Both inputs are cast to float32 and reduced over axes 1, 2 and 3, giving
    one Dice value per sample before averaging.
    """
    eps = 1e-15
    truth = tf.cast(y_true, tf.float32)
    pred = tf.cast(y_pred, tf.float32)
    sample_axes = [1, 2, 3]
    overlap = tf.reduce_sum(truth * pred, axis=sample_axes)
    total = tf.reduce_sum(truth, axis=sample_axes) + tf.reduce_sum(pred, axis=sample_axes)
    return tf.reduce_mean((2. * overlap + eps) / (total + eps))

# Weighted categorical cross-entropy
def weighted_categorical_crossentropy(y_true, y_pred):
    """Categorical cross-entropy with each pixel weighted by its true class.

    The previous version multiplied every pixel by the sum of all class
    weights, a constant factor that left the classes equally weighted.

    Args:
        y_true: One-hot labels with the class axis last (four classes, matching
            the hard-coded weights).
        y_pred: Predicted class probabilities of the same shape.

    Returns:
        Scalar tensor: mean weighted cross-entropy over all pixels.
    """
    class_weights = tf.constant([0.3776, 0.7605, 65.8554, 46.2381], dtype=tf.float32)
    y_true = tf.cast(y_true, tf.float32)
    # Clip away exact zeros so the logarithm stays finite. The epsilon is
    # taken via tf.keras.backend so this cell does not need `K` from another cell.
    y_pred = tf.clip_by_value(y_pred, tf.keras.backend.epsilon(), 1.0)
    loss = -tf.reduce_sum(y_true * tf.math.log(y_pred), axis=-1)
    # Broadcast the weights along the class axis and pick each pixel's weight
    # through its one-hot label.
    class_weights = tf.reshape(class_weights, (1, 1, 1, -1))
    pixel_weights = tf.reduce_sum(y_true * class_weights, axis=-1)
    weighted_loss = loss * pixel_weights
    return tf.reduce_mean(weighted_loss)

# Dice loss
def dice_loss(y_true, y_pred):
    """Return 1 minus the mean Dice score.

    Dice is computed from sums over axes 1 and 2 (presumably the spatial axes
    of a batch-height-width-class tensor) and then averaged over the
    remaining axes.
    """
    eps = 1e-6
    truth = tf.cast(y_true, y_pred.dtype)
    reduce_axes = [1, 2]
    overlap = tf.reduce_sum(truth * y_pred, axis=reduce_axes)
    total = tf.reduce_sum(truth, axis=reduce_axes) + tf.reduce_sum(y_pred, axis=reduce_axes)
    dice_scores = (2. * overlap + eps) / (total + eps)
    return 1 - tf.reduce_mean(dice_scores)

# Lovasz-Softmax loss
def lovasz_softmax_loss(y_true, y_pred, ignore_background=False):
    """Lovasz-Softmax style loss averaged over the evaluated classes.

    For every class the absolute errors ``|y_true - y_pred|`` are sorted in
    decreasing order and combined (dot product) with the discrete gradient of
    the Jaccard loss along that ordering; the per-class values are averaged.

    Args:
        y_true: Labels with the class axis last; expected to be one-hot (0/1),
            since the Jaccard terms are cumulative sums of these values.
        y_pred: Predicted class scores of the same shape.
        ignore_background: If True, class 0 is left out of the average.

    Returns:
        Scalar float32 tensor.
    """
    y_true = tf.cast(y_true, tf.float32)
    y_pred = tf.cast(y_pred, tf.float32)

    num_classes = tf.shape(y_true)[-1]
    first_class = 1 if ignore_background else 0

    def compute_class_loss(c):
        y_true_class = y_true[..., c]
        y_pred_class = y_pred[..., c]

        y_true_flat = tf.reshape(y_true_class, [-1])
        y_pred_flat = tf.reshape(y_pred_class, [-1])

        # Sort all pixels by decreasing error (top_k with k = number of pixels).
        errors = tf.abs(y_true_flat - y_pred_flat)
        errors_sorted, perm = tf.nn.top_k(errors, k=tf.shape(errors)[0], sorted=True)
        y_true_sorted = tf.gather(y_true_flat, perm)

        # Jaccard loss after including the first k sorted pixels, then its
        # first difference, which weights the sorted errors.
        gts = tf.reduce_sum(y_true_sorted)
        intersection = gts - tf.cumsum(y_true_sorted)
        union = gts + tf.cumsum(1. - y_true_sorted)
        jaccard = 1. - intersection / union
        grad = tf.concat([[jaccard[0]], jaccard[1:] - jaccard[:-1]], 0)

        return tf.tensordot(errors_sorted, grad, axes=1)

    # One slot per class that is actually evaluated. Sizing the array by
    # num_classes while starting at class 1 would leave slot 0 unwritten.
    losses = tf.TensorArray(dtype=tf.float32, size=num_classes - first_class)

    def loop_cond(c, losses):
        return tf.less(c, num_classes)

    def loop_body(c, losses):
        losses = losses.write(c - first_class, compute_class_loss(c))
        return c + 1, losses

    _, losses = tf.while_loop(loop_cond, loop_body, [tf.constant(first_class), losses])
    return tf.reduce_mean(losses.stack())

# Combined loss (Dice + Lovasz-Softmax)
def combined_loss(y_true, y_pred):
    """Sum of a Dice loss and the Lovasz-Softmax loss.

    Args:
        y_true: One-hot labels with the class axis last.
        y_pred: Predicted class probabilities (the output of a softmax layer).

    Returns:
        Scalar tensor: ``lovasz_softmax_loss + dice_loss``.
    """
    smooth = 1e-6
    y_true = tf.cast(y_true, y_pred.dtype)
    intersection = tf.reduce_sum(y_true * y_pred, axis=[1, 2])
    union = tf.reduce_sum(y_true, axis=[1, 2]) + tf.reduce_sum(y_pred, axis=[1, 2])
    dice_loss_val = 1 - (2. * intersection + smooth) / (union + smooth)
    dice_loss_val = tf.reduce_mean(dice_loss_val)

    # y_pred is already a probability map: the Dice term above uses it as such
    # and the models built in this notebook end in a softmax layer. It is
    # therefore passed on unchanged; a second softmax would flatten the
    # distribution and weaken the Lovasz term.
    lovasz_loss_val = lovasz_softmax_loss(y_true, y_pred, ignore_background=False)
    return lovasz_loss_val + dice_loss_val

In [None]:
def dice_coefficient(y_true, y_pred):
    """Return the batch mean of per-sample Dice scores.

    Both inputs are cast to float32 and reduced over axes 1, 2 and 3, giving
    one Dice value per sample before averaging.
    """
    eps = 1e-15
    truth = tf.cast(y_true, tf.float32)
    pred = tf.cast(y_pred, tf.float32)
    sample_axes = [1, 2, 3]
    overlap = tf.reduce_sum(truth * pred, axis=sample_axes)
    total = tf.reduce_sum(truth, axis=sample_axes) + tf.reduce_sum(pred, axis=sample_axes)
    return tf.reduce_mean((2. * overlap + eps) / (total + eps))

# Weighted categorical cross-entropy
def weighted_categorical_crossentropy(y_true, y_pred):
    """Categorical cross-entropy with each pixel weighted by its true class.

    The previous version multiplied every pixel by the sum of all class
    weights, a constant factor that left the classes equally weighted.

    Args:
        y_true: One-hot labels with the class axis last (four classes, matching
            the hard-coded weights).
        y_pred: Predicted class probabilities of the same shape.

    Returns:
        Scalar tensor: mean weighted cross-entropy over all pixels.
    """
    class_weights = tf.constant([0.3776, 0.7605, 65.8554, 46.2381], dtype=tf.float32)
    y_true = tf.cast(y_true, tf.float32)
    # Clip away exact zeros so the logarithm stays finite. The epsilon is
    # taken via tf.keras.backend so this cell does not need `K` from another cell.
    y_pred = tf.clip_by_value(y_pred, tf.keras.backend.epsilon(), 1.0)
    loss = -tf.reduce_sum(y_true * tf.math.log(y_pred), axis=-1)
    # Broadcast the weights along the class axis and pick each pixel's weight
    # through its one-hot label.
    class_weights = tf.reshape(class_weights, (1, 1, 1, -1))
    pixel_weights = tf.reduce_sum(y_true * class_weights, axis=-1)
    weighted_loss = loss * pixel_weights
    return tf.reduce_mean(weighted_loss)

# Dice loss
def dice_loss(y_true, y_pred):
    """Return 1 minus the mean Dice score.

    Dice is computed from sums over axes 1 and 2 (presumably the spatial axes
    of a batch-height-width-class tensor) and then averaged over the
    remaining axes.
    """
    eps = 1e-6
    truth = tf.cast(y_true, y_pred.dtype)
    reduce_axes = [1, 2]
    overlap = tf.reduce_sum(truth * y_pred, axis=reduce_axes)
    total = tf.reduce_sum(truth, axis=reduce_axes) + tf.reduce_sum(y_pred, axis=reduce_axes)
    dice_scores = (2. * overlap + eps) / (total + eps)
    return 1 - tf.reduce_mean(dice_scores)

# Custom Dice coefficient metric, one instance per class
class DiceCoefficient(tf.keras.metrics.Metric):
    """Dice coefficient of a single class channel, exposed as a Keras metric.

    NOTE(review): ``update_state`` overwrites the stored value, so ``result()``
    reflects only the most recent batch rather than a running average over the
    epoch. Accumulating would need additional state variables; confirm that is
    acceptable for models saved with this metric before changing it.
    """

    def __init__(self, class_idx, name=None, **kwargs):
        """Create the metric for channel ``class_idx``.

        When ``name`` is omitted it defaults to ``DiceClass<class_idx>``.
        """
        if name is None:
            name = f"DiceClass{class_idx}"
        super(DiceCoefficient, self).__init__(name=name, **kwargs)
        self.class_idx = class_idx
        self.dice = self.add_weight(name="dice", initializer="zeros")

    def update_state(self, y_true, y_pred, sample_weight=None):
        """Store the batch-mean Dice of this class; ``sample_weight`` is ignored."""
        # Align dtypes the same way the loss functions in this notebook do:
        # multiplying tensors of different dtypes raises an error.
        y_true = tf.cast(y_true, y_pred.dtype)
        y_true_class = y_true[..., self.class_idx]
        y_pred_class = y_pred[..., self.class_idx]
        intersection = tf.reduce_sum(y_true_class * y_pred_class, axis=[1, 2])
        union = tf.reduce_sum(y_true_class, axis=[1, 2]) + tf.reduce_sum(y_pred_class, axis=[1, 2])
        dice = (2. * intersection + 1e-6) / (union + 1e-6)
        self.dice.assign(tf.cast(tf.reduce_mean(dice), self.dice.dtype))

    def result(self):
        """Return the most recently stored Dice value."""
        return self.dice

    def get_config(self):
        """Return the base config extended with ``class_idx``.

        ``class_idx`` is a required constructor argument, so it has to be part
        of the config for the metric to be rebuilt from it.
        """
        config = super().get_config()
        config.update({"class_idx": self.class_idx})
        return config

# Build the list of class-wise metrics
def class_wise_metrics(num_classes=4):
    """Return one Dice metric per class followed by a mean-IoU metric.

    ``OneHotMeanIoU`` is used because the targets in this notebook are one-hot
    encoded and the models output per-class probabilities; plain ``MeanIoU``
    expects integer class labels.

    Args:
        num_classes: Number of segmentation classes.

    Returns:
        List of ``num_classes`` ``DiceCoefficient`` metrics plus one
        ``OneHotMeanIoU`` metric.
    """
    metrics = [DiceCoefficient(class_id) for class_id in range(num_classes)]
    metrics.append(tf.keras.metrics.OneHotMeanIoU(num_classes=num_classes))
    return metrics

# Rebuild the EfficientNetB4 Attention U-Net, compile it with the combined
# loss and class-wise metrics, then restore its weights from the HDF5 file.
model_efficientnetb4 = build_efficientnet_attention_unet(input_shape=(224, 224, 3), num_classes=4)
model_efficientnetb4.compile(
    optimizer=Adam(learning_rate=0.0001),
    loss=combined_loss,
    metrics=class_wise_metrics(4)  # Number of classes
)
# The weights path is relative to the working directory.
model_efficientnetb4.load_weights("efficientnet_attention_unet_weights.h5")

In [None]:
# Custom objects for the EfficientNet model: losses, the Dice metric class and
# a MeanIoU metric instance.
custom_objects_efficientnet = {
    'combined_loss': combined_loss,
    'dice_loss': dice_loss,
    'weighted_categorical_crossentropy': weighted_categorical_crossentropy,
    'DiceCoefficient': DiceCoefficient,
    'MeanIoU': tf.keras.metrics.MeanIoU(num_classes=NUM_CLASSES)
}

# Add DiceClass0 ... DiceClass{NUM_CLASSES - 1} as metric instances.
for i in range(NUM_CLASSES):
    custom_objects_efficientnet[f'DiceClass{i}'] = DiceCoefficient(class_idx=i)

def load_efficientnetb4_model():
    """Build, compile and restore the EfficientNetB4 Attention U-Net.

    Returns:
        The compiled model with weights loaded from
        ``efficientnet_attention_unet_weights.h5``.
    """
    net = build_efficientnet_attention_unet(input_shape=(224, 224, 3), num_classes=4)
    optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4)
    net.compile(optimizer=optimizer, loss=combined_loss, metrics=class_wise_metrics(4))
    net.load_weights("efficientnet_attention_unet_weights.h5")
    return net

def run_training(train_gen, val_gen, model_loader_fn, custom_objects, batch_size=8, epochs=3, repeats=1):
    """Time repeated training runs and sample GPU power draw via `nvidia-smi`.

    For every repeat a model is obtained from `model_loader_fn()`, trained with
    `model.fit`, and the wall-clock time of the whole `fit` call is divided
    evenly across the epochs (so all entries of one repeat are equal).

    Args:
        train_gen: Training data passed straight to `model.fit`.
        val_gen: Validation data passed to `model.fit`.
        model_loader_fn: Zero-argument callable returning a compiled model.
        custom_objects: Unused; kept so existing callers keep working.
        batch_size: Unused; kept so existing callers keep working.
        epochs: Number of epochs per repeat.
        repeats: Number of independent runs.

    Returns:
        Tuple `(epoch_times_all, power_samples_all)`; each list holds `epochs`
        entries per repeat: the average epoch time in seconds and the average
        sampled power in watts (NaN when no power reading was available).
    """
    import subprocess
    import time
    import numpy as np
    from tensorflow.keras import backend as K
    import gc

    epoch_times_all = []
    power_samples_all = []

    for r in range(repeats):
        print(f"\nüîÅ Repeat {r+1}/{repeats}")

        K.clear_session()
        gc.collect()

        # Build and load model
        model = model_loader_fn()

        # Start the power sampler; a missing nvidia-smi must not prevent training.
        power_error = None
        try:
            power_proc = subprocess.Popen(
                ['nvidia-smi', '--query-gpu=power.draw', '--format=csv,noheader,nounits', '-lms', '500'],
                stdout=subprocess.PIPE,
                stderr=subprocess.DEVNULL,
                text=True
            )
        except OSError as e:
            power_proc = None
            power_error = e

        power_text = ""
        # Timer starts right before fit() so sampler start-up is not counted.
        start = time.time()
        try:
            model.fit(
                train_gen,
                validation_data=val_gen,
                epochs=epochs,
                verbose=1
            )
            total_time = time.time() - start
        finally:
            # Always stop the sampler, even when fit() raises, so no
            # nvidia-smi process is left running in the background.
            if power_proc is not None:
                power_proc.terminate()
                try:
                    power_text, _ = power_proc.communicate(timeout=10)
                except subprocess.TimeoutExpired:
                    power_proc.kill()
                    power_text, _ = power_proc.communicate()

        avg_epoch_time = total_time / epochs
        epoch_times_all.extend([avg_epoch_time] * epochs)

        power_values = []
        if power_error is None:
            try:
                power_values = [float(line) for line in power_text.splitlines() if line.strip()]
            except ValueError as e:
                # e.g. nvidia-smi printed "[N/A]" instead of a number
                power_error = e
            if power_error is None and not power_values:
                power_error = "no power samples collected"

        if power_error is None:
            avg_power = np.mean(power_values)
            power_samples_all.extend([avg_power] * epochs)
            print(f"‚ö° Avg Power: {avg_power:.2f} W")
        else:
            print(f"‚ö†Ô∏è Power log failed: {power_error}")
            power_samples_all.extend([np.nan] * epochs)

        del model
        gc.collect()
        K.clear_session()

    return epoch_times_all, power_samples_all

In [None]:
import numpy as np
# Benchmark the EfficientNet model: one repeat of three epochs.
epoch_times, power_vals = run_training(
    train_gen=train_gen,
    val_gen=val_gen,
    model_loader_fn=load_efficientnetb4_model,  # pass the loader callable, not a path
    custom_objects=custom_objects_efficientnet,
    batch_size=4,
    epochs=3,
    repeats=1,
)

epoch_times = np.array(epoch_times)
power_vals = np.array(power_vals)

# Aggregate timing and power; nanmean skips repeats without a power reading.
mean_time = epoch_times.mean()
std_time = epoch_times.std()
mean_power = np.nanmean(power_vals)
energy_wh = mean_power * mean_time / 3600  # W * s -> Wh

# Throughput estimate from an assumed per-sample FLOP count.
samples_per_epoch = len(train_gen) * train_gen.batch_size
flops_per_sample = 4e9  # adjust if you have exact FLOPs
gflops = (2 * flops_per_sample * samples_per_epoch) / (mean_time * 1e9)

print("\nüìä Summary:")
print(f"‚è±Ô∏è  Avg epoch time: {mean_time:.2f} ¬± {std_time:.2f} sec")
print(f"‚öôÔ∏è  Estimated GFLOPS: {gflops:.2f}")
print(f"‚ö°  Avg power: {mean_power:.2f} W")
print(f"üîã  Avg energy/epoch: {energy_wh:.4f} Wh")

In [None]:
import tensorflow as tf
from tensorflow.keras import layers, Model, Input
from tensorflow.keras.applications import Xception

# Constants for 224x224 images
IMG_HEIGHT = 224
IMG_WIDTH = 224
CHANNELS = 3
NUM_CLASSES = 4  # class indices 0-3: Background, Brain, CSP, LV

def convolution_block(inputs, filters, kernel_size=3, dilation_rate=1, padding='same', use_bias=False):
    """Apply Conv2D -> BatchNormalization -> ReLU to `inputs` and return the result."""
    conv_layer = layers.Conv2D(
        filters,
        kernel_size,
        padding=padding,
        dilation_rate=dilation_rate,
        use_bias=use_bias,
    )
    features = conv_layer(inputs)
    features = layers.BatchNormalization()(features)
    return layers.ReLU()(features)

def ASPP(inputs):
    """Atrous Spatial Pyramid Pooling: four conv branches plus a pooled branch, fused by a 1x1 conv.

    `inputs` must have static height, width and channel dimensions because
    they are read from `inputs.shape` to size the pooled branch.
    """
    # (kernel_size, dilation_rate) for the convolutional branches, in order.
    branch_specs = ((1, 1), (3, 6), (3, 12), (3, 18))
    branches = [
        convolution_block(inputs, 256, kernel_size=kernel, dilation_rate=rate)
        for kernel, rate in branch_specs
    ]

    # Image-level branch: global average -> 1x1 conv -> tile back to the input's spatial size.
    pooled = layers.GlobalAveragePooling2D()(inputs)
    pooled = layers.Reshape((1, 1, inputs.shape[-1]))(pooled)
    pooled = convolution_block(pooled, 256, kernel_size=1)
    pooled = layers.UpSampling2D(size=(inputs.shape[1], inputs.shape[2]))(pooled)

    # Fuse all five branches.
    merged = layers.Concatenate()(branches + [pooled])
    return convolution_block(merged, 256, kernel_size=1)

def build_deeplabv3_plus_xception(input_shape, num_classes):
    """Build a DeepLabV3+-style segmentation model on a Keras Xception backbone.

    Args:
        input_shape: `(height, width, channels)` of the input images. Height
            and width must be static integers; the output is upsampled back
            to this size.
        num_classes: Number of output channels of the final softmax layer.

    Returns:
        A `Model` mapping images to per-pixel class probabilities of shape
        `(height, width, num_classes)`.
    """
    inputs = Input(input_shape)
    # The output resolution follows the model input, not module-level globals.
    target_height, target_width = input_shape[0], input_shape[1]

    # Stock Keras Xception backbone with ImageNet weights; the feature-map
    # sizes are read from the tensors below rather than assumed.
    base_model = Xception(
        input_tensor=inputs,
        include_top=False,
        weights='imagenet'
    )

    # Don't freeze any layers
    for layer in base_model.layers:
        layer.trainable = True

    # Low-level features come from 'block4_sepconv2_bn'; high-level features
    # are the backbone's final output.
    low_level_features = base_model.get_layer('block4_sepconv2_bn').output
    high_level_features = base_model.output

    # Reduce the low-level features to 48 channels.
    low_level_features = convolution_block(low_level_features, 48, kernel_size=1)

    # Process high-level features with ASPP
    x = ASPP(high_level_features)

    # Integer factor that brings the high-level map to the low-level map's size.
    hl_shape = high_level_features.shape
    ll_shape = low_level_features.shape
    h_factor = ll_shape[1] // hl_shape[1]
    w_factor = ll_shape[2] // hl_shape[2]

    x = layers.UpSampling2D(size=(h_factor, w_factor), interpolation='bilinear')(x)

    # Decoder: fuse both feature levels and refine.
    x = layers.Concatenate()([x, low_level_features])
    x = convolution_block(x, 256, kernel_size=3)
    x = convolution_block(x, 256, kernel_size=3)

    # Integer factor needed to reach the input resolution.
    current_shape = x.shape
    h_factor = target_height // current_shape[1]
    w_factor = target_width // current_shape[2]

    x = layers.UpSampling2D(size=(h_factor, w_factor), interpolation='bilinear')(x)

    # Reshape to the target size; when the upsampled map already has that size
    # the element order is unchanged.
    x = layers.Reshape((target_height, target_width, int(current_shape[3])))(x)

    # Per-pixel class probabilities.
    outputs = layers.Conv2D(num_classes, kernel_size=1, padding='same', activation='softmax')(x)

    model = Model(inputs=inputs, outputs=outputs)
    return model

# Build model
model = build_deeplabv3_plus_xception(
    input_shape=(IMG_HEIGHT, IMG_WIDTH, CHANNELS),
    num_classes=NUM_CLASSES,
)

# Show the layer-by-layer architecture.
model.summary()

In [None]:
import tensorflow as tf
from tensorflow.keras import backend as K
from tensorflow.keras.models import load_model

# Number of segmentation classes (adjust if needed); read by the code in this cell.
NUM_CLASSES = 4

# ✅ Dice Coefficient (mean across all classes)
def dice_coefficient(y_true, y_pred):
    """Batch-mean Dice score, summing each sample over axes 1, 2 and 3."""
    eps = 1e-15
    reduce_axes = [1, 2, 3]
    truth = tf.cast(y_true, tf.float32)
    pred = tf.cast(y_pred, tf.float32)
    overlap = tf.reduce_sum(truth * pred, axis=reduce_axes)
    total = tf.reduce_sum(truth, axis=reduce_axes) + tf.reduce_sum(pred, axis=reduce_axes)
    per_sample = (2. * overlap + eps) / (total + eps)
    return tf.reduce_mean(per_sample)

# ✅ Weighted Categorical Crossentropy
def weighted_categorical_crossentropy(y_true, y_pred):
    """Categorical cross-entropy with a fixed weight per class.

    Each class's `-y_true * log(y_pred)` term is scaled by that class's weight
    before summing over the class axis, so classes with larger weights
    contribute more to the loss. The previous implementation multiplied the
    unweighted loss by the sum of all weights, which is only a constant scale.

    Args:
        y_true: One-hot targets whose last axis has four entries.
        y_pred: Predicted probabilities with the same shape as `y_true`.

    Returns:
        Scalar tensor: mean weighted cross-entropy over all remaining axes.
    """
    # One weight per class index (last axis); broadcasts over leading axes.
    class_weights = tf.constant([0.3776, 0.7605, 65.8554, 46.2381], dtype=tf.float32)
    y_true = tf.cast(y_true, tf.float32)
    # Clip to avoid log(0).
    y_pred = tf.clip_by_value(y_pred, K.epsilon(), 1.0)
    weighted_loss = -tf.reduce_sum(class_weights * y_true * tf.math.log(y_pred), axis=-1)
    return tf.reduce_mean(weighted_loss)

# ✅ Dice Loss
def dice_loss(y_true, y_pred):
    """Return 1 minus the mean Dice score, with sums taken over axes 1 and 2."""
    eps = 1e-6
    spatial_axes = [1, 2]
    targets = tf.cast(y_true, y_pred.dtype)
    overlap = tf.reduce_sum(targets * y_pred, axis=spatial_axes)
    total = tf.reduce_sum(targets, axis=spatial_axes) + tf.reduce_sum(y_pred, axis=spatial_axes)
    score = (2. * overlap + eps) / (total + eps)
    return 1 - tf.reduce_mean(score)

# ✅ Lovász-Softmax Loss
def lovasz_softmax_loss(y_true, y_pred, ignore_background=False):
    """Mean of per-class Lovasz-style losses over the scored classes.

    Args:
        y_true: One-hot targets; the last axis indexes classes.
        y_pred: Per-class scores with the same shape as `y_true`.
        ignore_background: When True, class 0 is skipped and the mean is
            taken over the remaining classes only.

    Returns:
        Scalar float32 tensor.
    """
    y_true = tf.cast(y_true, tf.float32)
    y_pred = tf.cast(y_pred, tf.float32)

    num_classes = tf.shape(y_true)[-1]
    first_class = 1 if ignore_background else 0
    # Size the result array to the classes actually scored, so a skipped
    # background class leaves no unwritten slot to be stacked and averaged.
    num_scored = num_classes - first_class

    def compute_class_loss(c):
        y_true_class = y_true[..., c]
        y_pred_class = y_pred[..., c]

        y_true_flat = tf.reshape(y_true_class, [-1])
        y_pred_flat = tf.reshape(y_pred_class, [-1])

        # Sort absolute errors in descending order and reorder targets to match.
        errors = tf.abs(y_true_flat - y_pred_flat)
        errors_sorted, perm = tf.nn.top_k(errors, k=tf.shape(errors)[0], sorted=True)
        y_true_sorted = tf.gather(y_true_flat, perm)

        # Jaccard values along the sorted errors and their first differences.
        gts = tf.reduce_sum(y_true_sorted)
        intersection = gts - tf.cumsum(y_true_sorted)
        union = gts + tf.cumsum(1. - y_true_sorted)
        jaccard = 1. - intersection / union
        grad = tf.concat([[jaccard[0]], jaccard[1:] - jaccard[:-1]], 0)

        return tf.tensordot(errors_sorted, grad, axes=1)

    # Loop through classes using tf.while_loop
    losses = tf.TensorArray(dtype=tf.float32, size=num_scored)

    def loop_cond(c, losses):
        return tf.less(c, num_classes)

    def loop_body(c, losses):
        loss_c = compute_class_loss(c)
        # Offset the slot so the first scored class lands at index 0.
        losses = losses.write(c - first_class, loss_c)
        return c + 1, losses

    _, losses = tf.while_loop(loop_cond, loop_body, [tf.constant(first_class), losses])
    return tf.reduce_mean(losses.stack())

# ✅ Combined Loss (Lovász-Softmax + Dice)
def combined_loss(y_true, y_pred):
    """Sum of a Dice loss term and the Lovasz-Softmax loss term.

    NOTE(review): `tf.nn.softmax` is applied to `y_pred` before the Lovasz
    term; if the model already ends in a softmax layer this squashes the
    probabilities a second time - confirm that is intended.
    """
    eps = 1e-6
    spatial_axes = [1, 2]
    targets = tf.cast(y_true, y_pred.dtype)

    # Dice term: per-sample, per-channel score, then averaged.
    overlap = tf.reduce_sum(targets * y_pred, axis=spatial_axes)
    total = tf.reduce_sum(targets, axis=spatial_axes) + tf.reduce_sum(y_pred, axis=spatial_axes)
    dice_term = tf.reduce_mean(1 - (2. * overlap + eps) / (total + eps))

    # Lovasz term over all classes, background included.
    lovasz_term = lovasz_softmax_loss(targets, tf.nn.softmax(y_pred), ignore_background=False)
    return lovasz_term + dice_term

class DiceCoefficient(tf.keras.metrics.Metric):
    """Dice coefficient of one class, averaged over all samples since the last reset.

    State is accumulated across `update_state` calls (sum of per-sample Dice
    scores and number of samples), so `result()` is a running mean over the
    whole epoch rather than the value of the most recent batch.
    """

    def __init__(self, class_idx=0, name=None, **kwargs):
        """Create the metric for channel `class_idx`; the name defaults to `DiceClass{class_idx}`."""
        if name is None:
            name = f"DiceClass{class_idx}"
        super(DiceCoefficient, self).__init__(name=name, **kwargs)
        self.class_idx = class_idx
        # Running sum of per-sample Dice scores and the number of samples seen.
        self.dice_sum = self.add_weight(name="dice_sum", initializer="zeros")
        self.num_samples = self.add_weight(name="num_samples", initializer="zeros")

    def update_state(self, y_true, y_pred, sample_weight=None):
        """Add the per-sample Dice scores of this batch to the running totals.

        `sample_weight` is accepted for API compatibility but not used.
        """
        state_dtype = self.dice_sum.dtype
        y_true_class = tf.cast(y_true[..., self.class_idx], state_dtype)
        y_pred_class = tf.cast(y_pred[..., self.class_idx], state_dtype)
        intersection = tf.reduce_sum(y_true_class * y_pred_class, axis=[1, 2])
        union = tf.reduce_sum(y_true_class, axis=[1, 2]) + tf.reduce_sum(y_pred_class, axis=[1, 2])
        # One score per sample; 1e-6 keeps the ratio defined for empty masks.
        dice = (2. * intersection + 1e-6) / (union + 1e-6)
        self.dice_sum.assign_add(tf.reduce_sum(dice))
        self.num_samples.assign_add(tf.cast(tf.size(dice), state_dtype))

    def result(self):
        """Return the mean Dice score over all samples seen (0 before any update)."""
        return tf.math.divide_no_nan(self.dice_sum, self.num_samples)

    def get_config(self):
        """Serialize the base config plus `class_idx`."""
        config = super().get_config()
        config.update({"class_idx": self.class_idx})
        return config

    @classmethod
    def from_config(cls, config):
        """Rebuild the metric; a missing `class_idx` is recovered from a `DiceClass<N>` name."""
        config = dict(config)  # do not mutate the caller's dict
        if "class_idx" not in config:
            # Try to extract class index from name like "DiceClass2"
            name = config.get("name", "DiceClass0")
            if name.startswith("DiceClass"):
                config["class_idx"] = int(name.replace("DiceClass", ""))
            else:
                config["class_idx"] = 0
        return cls(**config)

# ✅ Helper to load Dice metrics by name
def dice_metric_loader(name):
    """Return a DiceCoefficient for a name of the form `DiceClass<N>`.

    Raises:
        ValueError: If `name` does not start with "DiceClass".
    """
    if not name.startswith("DiceClass"):
        raise ValueError(f"Unknown Dice metric name: {name}")
    index_text = name.replace("DiceClass", "")
    return DiceCoefficient(class_idx=int(index_text))



# Name -> object lookup Keras uses to resolve custom losses and metrics on load.
custom_objects = dict(
    combined_loss=combined_loss,
    lovasz_softmax_loss=lovasz_softmax_loss,
    MeanIoU=tf.keras.metrics.MeanIoU(num_classes=NUM_CLASSES),
    DiceCoefficient=DiceCoefficient,
)

# Register one Dice metric per class as DiceClass0 .. DiceClass{NUM_CLASSES-1}.
for i in range(NUM_CLASSES):
    custom_objects.update({f'DiceClass{i}': dice_metric_loader(f'DiceClass{i}')})

# Load the saved model.
model_xception = load_model(
    'lovaszloss_deeplabv3_xception.keras',
    custom_objects=custom_objects,
)

In [None]:
# Name -> object lookup for the custom losses and metrics of the Xception model.
custom_objects_xception = dict(
    combined_loss=combined_loss,
    lovasz_softmax_loss=lovasz_softmax_loss,
    MeanIoU=tf.keras.metrics.MeanIoU(num_classes=NUM_CLASSES),
    DiceCoefficient=DiceCoefficient,
)

# Register one Dice metric instance per class as DiceClass0 .. DiceClass{NUM_CLASSES-1}.
for i in range(NUM_CLASSES):
    custom_objects_xception.update({f'DiceClass{i}': DiceCoefficient(class_idx=i)})


In [None]:
def run_training(train_gen, val_gen, model_loader_fn, custom_objects, batch_size=8, epochs=3, repeats=1):
    """Time repeated training runs and sample GPU power draw via `nvidia-smi`.

    For every repeat a model is obtained from `model_loader_fn()`, trained with
    `model.fit`, and the wall-clock time of the whole `fit` call is divided
    evenly across the epochs (so all entries of one repeat are equal).

    Args:
        train_gen: Training data passed straight to `model.fit`.
        val_gen: Validation data passed to `model.fit`.
        model_loader_fn: Zero-argument callable returning a compiled model.
        custom_objects: Unused; kept so existing callers keep working.
        batch_size: Unused; kept so existing callers keep working.
        epochs: Number of epochs per repeat.
        repeats: Number of independent runs.

    Returns:
        Tuple `(epoch_times_all, power_samples_all)`; each list holds `epochs`
        entries per repeat: the average epoch time in seconds and the average
        sampled power in watts (NaN when no power reading was available).
    """
    import subprocess
    import time
    import numpy as np
    from tensorflow.keras import backend as K
    import gc

    epoch_times_all = []
    power_samples_all = []

    for r in range(repeats):
        print(f"\nüîÅ Repeat {r+1}/{repeats}")

        K.clear_session()
        gc.collect()

        # Build and load model
        model = model_loader_fn()

        # Start the power sampler; a missing nvidia-smi must not prevent training.
        power_error = None
        try:
            power_proc = subprocess.Popen(
                ['nvidia-smi', '--query-gpu=power.draw', '--format=csv,noheader,nounits', '-lms', '500'],
                stdout=subprocess.PIPE,
                stderr=subprocess.DEVNULL,
                text=True
            )
        except OSError as e:
            power_proc = None
            power_error = e

        power_text = ""
        # Timer starts right before fit() so sampler start-up is not counted.
        start = time.time()
        try:
            model.fit(
                train_gen,
                validation_data=val_gen,
                epochs=epochs,
                verbose=1
            )
            total_time = time.time() - start
        finally:
            # Always stop the sampler, even when fit() raises, so no
            # nvidia-smi process is left running in the background.
            if power_proc is not None:
                power_proc.terminate()
                try:
                    power_text, _ = power_proc.communicate(timeout=10)
                except subprocess.TimeoutExpired:
                    power_proc.kill()
                    power_text, _ = power_proc.communicate()

        avg_epoch_time = total_time / epochs
        epoch_times_all.extend([avg_epoch_time] * epochs)

        power_values = []
        if power_error is None:
            try:
                power_values = [float(line) for line in power_text.splitlines() if line.strip()]
            except ValueError as e:
                # e.g. nvidia-smi printed "[N/A]" instead of a number
                power_error = e
            if power_error is None and not power_values:
                power_error = "no power samples collected"

        if power_error is None:
            avg_power = np.mean(power_values)
            power_samples_all.extend([avg_power] * epochs)
            print(f"‚ö° Avg Power: {avg_power:.2f} W")
        else:
            print(f"‚ö†Ô∏è Power log failed: {power_error}")
            power_samples_all.extend([np.nan] * epochs)

        del model
        gc.collect()
        K.clear_session()

    return epoch_times_all, power_samples_all


In [None]:
def load_deeplabv3_xception_model():
    """Load the saved DeepLabV3+/Xception model, resolving custom objects via `custom_objects_xception`."""
    saved_model_path = 'lovaszloss_deeplabv3_xception.keras'
    restored = tf.keras.models.load_model(saved_model_path, custom_objects=custom_objects_xception)
    return restored

In [None]:
import numpy as np

# Benchmark the DeepLabV3+/Xception model: one repeat of three epochs.
epoch_times, power_vals = run_training(
    train_gen=train_gen,
    val_gen=val_gen,
    model_loader_fn=load_deeplabv3_xception_model,
    custom_objects=custom_objects_xception,
    batch_size=8,
    epochs=3,
    repeats=1,
)

epoch_times = np.array(epoch_times)
power_vals = np.array(power_vals)

# Aggregate timing and power; nanmean skips repeats without a power reading.
mean_time = epoch_times.mean()
std_time = epoch_times.std()
mean_power = np.nanmean(power_vals)
energy_wh = mean_power * mean_time / 3600  # W * s -> Wh

# Throughput estimate from an assumed per-sample FLOP count.
samples_per_epoch = len(train_gen) * train_gen.batch_size
flops_per_sample = 4e9  # Rough estimate for Xception-based DeepLab
gflops = (2 * flops_per_sample * samples_per_epoch) / (mean_time * 1e9)

print("\nüìä Summary:")
print(f"‚è±Ô∏è  Avg epoch time: {mean_time:.2f} ¬± {std_time:.2f} sec")
print(f"‚öôÔ∏è  Estimated GFLOPS: {gflops:.2f}")
print(f"‚ö°  Avg power: {mean_power:.2f} W")
print(f"üîã  Avg energy/epoch: {energy_wh:.4f} Wh")


In [None]:
import tensorflow as tf

class SoftVotingEnsemble(tf.keras.Model):
    """Average the member models' outputs and return the per-pixel argmax class."""

    def __init__(self, models, apply_softmax=True):
        """
        Args:
            models (list): Member models; each is called as `model(x, training=...)`.
            apply_softmax (bool): If True, softmax each member's output over the
                last axis before averaging.
        """
        super(SoftVotingEnsemble, self).__init__()
        self.models = models
        self.apply_softmax = apply_softmax

    def call(self, x, training=False):
        """Return the argmax over the last axis of the averaged member outputs."""
        running_total = 0
        for member in self.models:
            member_out = member(x, training=training)
            if self.apply_softmax:
                member_out = tf.nn.softmax(member_out, axis=-1)
            running_total = running_total + member_out

        mean_prob = running_total / len(self.models)
        return tf.argmax(mean_prob, axis=-1)

# Ensemble of two models, `model_segnet` and `model_inceptionresnetv2`;
# both must already exist in the kernel before this cell runs.
ensemble_model = SoftVotingEnsemble([model_segnet, model_inceptionresnetv2])

In [None]:
# class SoftVotingEnsemble(tf.keras.Model):
#     def __init__(self, models, apply_softmax=True, return_probs=False):
#         """
#         Args:
#             models (list): List of pretrained tf.keras.Model instances.
#             apply_softmax (bool): Apply softmax to logits before averaging.
#             return_probs (bool): If True, return averaged class probabilities (for loss/metrics).
#                                  If False, return argmax predictions (for direct inference).
#         """
#         super(SoftVotingEnsemble, self).__init__()
#         self.models = models
#         self.apply_softmax = apply_softmax
#         self.return_probs = return_probs

#     def call(self, x, training=False):
#         prob_sum = 0
#         for model in self.models:
#             logits = model(x, training=training)
#             probs = tf.nn.softmax(logits, axis=-1) if self.apply_softmax else logits
#             prob_sum += probs
#         avg_prob = prob_sum / len(self.models)
        
#         if self.return_probs:
#             return avg_prob  # shape: [B, H, W, C]
#         else:
#             return tf.argmax(avg_prob, axis=-1)  # shape: [B, H, W]

class SoftVotingEnsemble(tf.keras.Model):
    """Average member-model outputs; return either the averages or their argmax."""

    def __init__(self, models, apply_softmax=True, return_probs=False):
        """
        Args:
            models (list): Member models; each is called as `model(x, training=...)`.
            apply_softmax (bool): If True, softmax a member's output before
                averaging, except for members whose name contains "efficientnet".
            return_probs (bool): If True, `call` returns the averaged values;
                otherwise it returns their argmax over the last axis.
        """
        super(SoftVotingEnsemble, self).__init__()
        self.models = models
        self.apply_softmax = apply_softmax
        self.return_probs = return_probs

    def call(self, x, training=False):
        """Run every member on `x`, average the results, and format the output."""
        running_total = 0
        for member in self.models:
            member_out = member(x, training=training)

            # Members named like "efficientnet" are treated as already softmaxed.
            already_softmaxed = hasattr(member, "name") and "efficientnet" in member.name.lower()
            needs_softmax = self.apply_softmax and not already_softmaxed
            if needs_softmax:
                member_out = tf.nn.softmax(member_out, axis=-1)

            running_total = running_total + member_out

        mean_prob = running_total / len(self.models)
        return mean_prob if self.return_probs else tf.argmax(mean_prob, axis=-1)

In [None]:
import numpy as np
import tensorflow as tf
from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score
from scipy.spatial.distance import directed_hausdorff
from tensorflow.keras.utils import to_categorical

# ‚úÖ RGB to Class Index Conversion (for the test masks)
RGB_TO_CLASS = {
    (255, 0, 0): 1,  # Brain
    (0, 255, 0): 2,  # CSP
    (0, 0, 255): 3,  # LV
    (0, 0, 0): 0     # Background
}

def rgb_to_class_mask(rgb_mask):
    """Map an RGB mask to integer class indices using `RGB_TO_CLASS`.

    Pixels whose colour is not in the mapping keep the initial value 0.
    """
    labels = np.zeros(rgb_mask.shape[:2], dtype=int)
    for color, label in RGB_TO_CLASS.items():
        # True where all channels equal this class colour.
        hits = np.all(rgb_mask == np.array(color), axis=-1)
        labels[hits] = label
    return labels

# ‚úÖ Function to calculate Dice Similarity Coefficient (DSC)
def dice_coefficient(y_true, y_pred):
    """Dice score of two arrays, smoothed by 1e-15 so two empty masks score 1."""
    eps = 1e-15
    overlap = np.sum(y_true * y_pred)
    total = np.sum(y_true) + np.sum(y_pred)
    numerator = 2. * overlap + eps
    return numerator / (total + eps)

# ‚úÖ Function to calculate IoU (Intersection over Union)
def iou(y_true, y_pred):
    """Intersection over union of two arrays; 1e-15 in the denominator avoids division by zero."""
    overlap = np.sum(y_true * y_pred)
    combined = np.sum(y_true) + np.sum(y_pred) - overlap
    return overlap / (combined + 1e-15)

# ‚úÖ Function to calculate Hausdorff Distance
def hausdorff_distance(y_true, y_pred):
    """Symmetric Hausdorff distance between the coordinates where each mask equals 1.

    Returns `inf` when either mask has no such element.
    """
    true_points = np.argwhere(y_true == 1)
    pred_points = np.argwhere(y_pred == 1)

    if len(true_points) == 0 or len(pred_points) == 0:
        return float('inf')

    # Larger of the two directed distances.
    true_to_pred = directed_hausdorff(true_points, pred_points)[0]
    pred_to_true = directed_hausdorff(pred_points, true_points)[0]
    return max(true_to_pred, pred_to_true)

# ‚úÖ Function to calculate Average Surface Distance (ASD)
def average_surface_distance(y_true, y_pred):
    """Mean distance from each `y_true == 1` coordinate to its nearest `y_pred == 1` coordinate.

    All elements equal to 1 are used (not only boundary ones) and the distance
    is measured in one direction, true -> pred. Returns `inf` when either mask
    has no element equal to 1.
    """
    true_points = np.array(np.where(y_true == 1)).T
    pred_points = np.array(np.where(y_pred == 1)).T

    if len(true_points) == 0 or len(pred_points) == 0:
        return float('inf')

    nearest = [
        np.min(np.linalg.norm(pred_points - point, axis=1))
        for point in true_points
    ]
    return np.mean(nearest)

# ‚úÖ Function to evaluate the model on the test set class-wise
def evaluate_classwise_metrics(model, X_test, y_test, num_classes=4, batch_size=16):
    """Print per-class Dice, IoU, precision, recall, F1 and accuracy on a test set.

    Predictions and one-hot targets are converted to class-index masks; every
    metric is computed per sample on the binary mask of each class and the
    per-class mean is printed as a percentage.

    Args:
        model: Object with a `predict(X, batch_size=...)` method returning
            per-class scores on the last axis.
        X_test: Test inputs.
        y_test: One-hot encoded targets (classes on the last axis).
        num_classes: Number of classes to report.
        batch_size: Batch size forwarded to `model.predict`.

    Returns:
        None. Results are printed.
    """
    # Predict in batches and convert scores to class indices.
    y_pred = np.argmax(model.predict(X_test, batch_size=batch_size), axis=-1)

    # One-hot targets -> class indices.
    y_test_class = np.argmax(y_test, axis=-1)

    # Per-class lists of per-sample values. 'hausdorff' and 'asd' stay empty:
    # those metrics are currently disabled.
    class_metrics = {i: {'dice': [], 'iou': [], 'precision': [], 'recall': [], 'f1': [], 'accuracy': [], 'hausdorff': [], 'asd': []} for i in range(num_classes)}

    for i in range(len(X_test)):
        true_mask = y_test_class[i]
        pred_mask = y_pred[i]

        # For each class (0: Background, 1: Brain, 2: CSP, 3: LV)
        for class_idx in range(num_classes):
            true_class_mask = (true_mask == class_idx).astype(int)
            pred_class_mask = (pred_mask == class_idx).astype(int)
            # Flatten once; the sklearn scorers expect 1-D label arrays.
            true_flat = true_class_mask.flatten()
            pred_flat = pred_class_mask.flatten()

            scores = class_metrics[class_idx]
            scores['dice'].append(dice_coefficient(true_class_mask, pred_class_mask))
            scores['iou'].append(iou(true_class_mask, pred_class_mask))
            scores['precision'].append(precision_score(true_flat, pred_flat, zero_division=0))
            scores['recall'].append(recall_score(true_flat, pred_flat, zero_division=0))
            scores['f1'].append(f1_score(true_flat, pred_flat, zero_division=0))
            scores['accuracy'].append(accuracy_score(true_flat, pred_flat))

    # Print class-wise metrics in percentage
    print(f"{'Class':<10}{'Dice Coefficient (%)':<20}{'IoU (%)':<20}{'Precision (%)':<20}{'Recall (%)':<20}{'F1 Score (%)':<20}{'Accuracy (%)':<20}{'Hausdorff Distance':<20}{'Avg Surface Distance':<20}")
    print('-' * 180)

    for class_idx in range(num_classes):
        print(f"Class {class_idx}:")
        print(f"  Dice Coefficient: {np.mean(class_metrics[class_idx]['dice']) * 100:.2f}%")
        print(f"  IoU: {np.mean(class_metrics[class_idx]['iou']) * 100:.2f}%")
        print(f"  Precision: {np.mean(class_metrics[class_idx]['precision']) * 100:.2f}%")
        print(f"  Recall: {np.mean(class_metrics[class_idx]['recall']) * 100:.2f}%")
        print(f"  F1 Score: {np.mean(class_metrics[class_idx]['f1']) * 100:.2f}%")
        print(f"  Accuracy: {np.mean(class_metrics[class_idx]['accuracy']) * 100:.2f}%")
        print("-" * 180)

    # Overall loss/metrics via `model.evaluate` are not reported here. The old
    # trailing loop read `test_metrics`, which was never assigned, and raised
    # NameError after the table was printed.

# Run the class-wise evaluation for `model_segnet` on the test set.
evaluate_classwise_metrics(model_segnet, X_test, y_test)

In [None]:
def class_wise_metrics(num_classes=4):
    """Build the metric list: one DiceCoefficient per class index, then a MeanIoU."""
    metric_list = []
    for class_idx in range(num_classes):
        metric_list.append(DiceCoefficient(class_idx))
    metric_list.append(tf.keras.metrics.MeanIoU(num_classes=num_classes))
    return metric_list

# Ensemble that returns averaged member outputs (no extra softmax) so the loss
# and metrics below receive per-class values instead of argmax labels.
ensemble_model = SoftVotingEnsemble(
    models=[model_segnet, model_inceptionresnetv2], apply_softmax=False, return_probs=True
)

# compile() is what attaches the loss and metrics used by evaluate().
ensemble_model.compile(
    loss=combined_loss,
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
    metrics=class_wise_metrics(num_classes=4),
)

ensemble_model.evaluate(X_test, y_test, batch_size=16)


In [None]:
import numpy as np
import tensorflow as tf
from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score
from scipy.spatial.distance import directed_hausdorff

# ‚úÖ RGB to Class Index Conversion (for the test masks)
RGB_TO_CLASS = {
    (255, 0, 0): 1,  # Brain
    (0, 255, 0): 2,  # CSP
    (0, 0, 255): 3,  # LV
    (0, 0, 0): 0     # Background
}

def rgb_to_class_mask(rgb_mask):
    """Map an RGB mask to integer class indices using `RGB_TO_CLASS`.

    Pixels whose colour is not in the mapping keep the initial value 0.
    """
    labels = np.zeros(rgb_mask.shape[:2], dtype=int)
    for color, label in RGB_TO_CLASS.items():
        # True where all channels equal this class colour.
        hits = np.all(rgb_mask == np.array(color), axis=-1)
        labels[hits] = label
    return labels

# ‚úÖ Dice
def dice_coefficient(y_true, y_pred):
    """Dice score of two arrays, smoothed by 1e-15 so two empty masks score 1."""
    eps = 1e-15
    overlap = np.sum(y_true * y_pred)
    total = np.sum(y_true) + np.sum(y_pred)
    numerator = 2. * overlap + eps
    return numerator / (total + eps)

# ‚úÖ IoU
def iou(y_true, y_pred):
    """Intersection over union of two arrays; 1e-15 in the denominator avoids division by zero."""
    overlap = np.sum(y_true * y_pred)
    combined = np.sum(y_true) + np.sum(y_pred) - overlap
    return overlap / (combined + 1e-15)

# ‚úÖ Hausdorff Distance
def hausdorff_distance(y_true, y_pred):
    """Symmetric Hausdorff distance between the coordinates where each mask equals 1.

    Returns `inf` when either mask has no such element.
    """
    true_points = np.argwhere(y_true == 1)
    pred_points = np.argwhere(y_pred == 1)

    if len(true_points) == 0 or len(pred_points) == 0:
        return float('inf')

    # Larger of the two directed distances.
    true_to_pred = directed_hausdorff(true_points, pred_points)[0]
    pred_to_true = directed_hausdorff(pred_points, true_points)[0]
    return max(true_to_pred, pred_to_true)

# ‚úÖ Average Surface Distance (ASD)
def average_surface_distance(y_true, y_pred):
    """Mean distance from each `y_true == 1` coordinate to its nearest `y_pred == 1` coordinate.

    All elements equal to 1 are used (not only boundary ones) and the distance
    is measured in one direction, true -> pred. Returns `inf` when either mask
    has no element equal to 1.
    """
    true_points = np.array(np.where(y_true == 1)).T
    pred_points = np.array(np.where(y_pred == 1)).T
    if len(true_points) == 0 or len(pred_points) == 0:
        return float('inf')

    nearest = []
    for point in true_points:
        offsets = pred_points - point
        nearest.append(np.min(np.linalg.norm(offsets, axis=1)))
    return np.mean(nearest)

# # ‚úÖ Soft Voting Ensemble Model
# class SoftVotingEnsemble(tf.keras.Model):
#     def __init__(self, models, apply_softmax=True):
#         super(SoftVotingEnsemble, self).__init__()
#         self.models = models
#         self.apply_softmax = apply_softmax

#     def call(self, x, training=False):
#         prob_sum = 0
#         for model in self.models:
#             logits = model(x, training=training)
#             probs = tf.nn.softmax(logits, axis=-1) if self.apply_softmax else logits
#             prob_sum += probs
#         avg_prob = prob_sum / len(self.models)
#         final_pred = tf.argmax(avg_prob, axis=-1)
#         return final_pred

# ✅ Evaluation Function (for normal + ensemble models)
def evaluate_classwise_metrics(model, X_test, y_test, num_classes=4, batch_size=16, is_ensemble=False):
    """Print per-class segmentation metrics averaged over the test images.

    For every image and every class index, the binary "is this class" masks of
    ground truth and prediction are compared with Dice, IoU, precision,
    recall, F1 and pixel accuracy. The per-image values are averaged per class
    and printed as a table of percentages.

    Args:
        model: With ``is_ensemble=False``, an object whose ``predict`` output
            is arg-maxed over the last axis. With ``is_ensemble=True``, a
            callable invoked as ``model(batch, training=False)`` whose result
            exposes ``.numpy()``.
        X_test: Test inputs, sliced along the first axis in ``batch_size`` chunks.
        y_test: One-hot ground truth with the class axis last.
        num_classes: Class indices ``0 .. num_classes - 1`` are reported.
        batch_size: Batch size used for inference.
        is_ensemble: Selects the inference path described under ``model``.

    Returns:
        None. The table is written to stdout.
    """
    if is_ensemble:
        y_pred = []
        for start in range(0, len(X_test), batch_size):
            batch_pred = model(X_test[start:start + batch_size], training=False).numpy()
            y_pred.extend(batch_pred)
        y_pred = np.array(y_pred)
        # An ensemble may return class scores / one-hot maps (same rank as
        # y_test) instead of class indices; reduce those to indices so both
        # kinds of ensemble are scored correctly.
        if y_pred.ndim == np.ndim(y_test):
            y_pred = np.argmax(y_pred, axis=-1)
    else:
        y_pred = np.argmax(model.predict(X_test, batch_size=batch_size), axis=-1)

    y_test_class = np.argmax(y_test, axis=-1)

    metric_names = ('dice', 'iou', 'precision', 'recall', 'f1', 'accuracy')
    class_metrics = {
        class_idx: {name: [] for name in metric_names}
        for class_idx in range(num_classes)
    }

    for i in range(len(X_test)):
        true_mask = y_test_class[i]
        pred_mask = y_pred[i]
        for class_idx in range(num_classes):
            true_class_mask = (true_mask == class_idx).astype(int)
            pred_class_mask = (pred_mask == class_idx).astype(int)
            # Flatten once; the sklearn scorers take 1-D label vectors.
            true_flat = true_class_mask.ravel()
            pred_flat = pred_class_mask.ravel()

            scores = class_metrics[class_idx]
            scores['dice'].append(dice_coefficient(true_class_mask, pred_class_mask))
            scores['iou'].append(iou(true_class_mask, pred_class_mask))
            scores['precision'].append(precision_score(true_flat, pred_flat, zero_division=0))
            scores['recall'].append(recall_score(true_flat, pred_flat, zero_division=0))
            scores['f1'].append(f1_score(true_flat, pred_flat, zero_division=0))
            scores['accuracy'].append(accuracy_score(true_flat, pred_flat))

    # Header and rows share one column width so the values line up under their titles.
    titles = ('Dice Coef (%)', 'IoU (%)', 'Precision (%)', 'Recall (%)', 'F1 Score (%)', 'Accuracy (%)')
    col_width = 16
    header = f"{'Class':<10}" + "".join(f"{title:>{col_width}}" for title in titles)
    print(header)
    print("-" * len(header))

    for class_idx in range(num_classes):
        cells = "".join(
            f"{np.mean(class_metrics[class_idx][name]) * 100:>{col_width}.2f}"
            for name in metric_names
        )
        print(f"{class_idx:<10}{cells}")

# Soft-voting ensemble of model_segnet + model_xception; apply_softmax=True
# runs softmax on each member's output before the outputs are averaged.
ensemble_model = SoftVotingEnsemble(
    models=[model_segnet, model_xception],
    apply_softmax=True
)

# is_ensemble=True: the evaluator calls the ensemble batch by batch and prints
# the per-class metric table for X_test / y_test.
evaluate_classwise_metrics(ensemble_model, X_test, y_test, is_ensemble=True)

In [None]:
# Soft-voting ensemble of model_segnet + model_xception set up for Keras
# evaluate(): member outputs are averaged without an extra softmax, and
# return_probs=True needs a SoftVotingEnsemble definition that accepts it.
ensemble_model = SoftVotingEnsemble(
    models=[model_segnet, model_xception],
    apply_softmax=False,  # members are assumed to already output probabilities — TODO confirm
    return_probs=True     # presumably returns averaged probabilities instead of labels
)

# compile() attaches the loss and metrics that evaluate() reports; nothing is
# trained in this cell.
ensemble_model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
    loss=combined_loss,
    metrics=class_wise_metrics(num_classes=4)
)

# Evaluate on X_test / y_test in batches of 16.
ensemble_model.evaluate(X_test, y_test, batch_size=16)

In [None]:
# Soft-voting ensemble of model_segnet + model_efficientnetb4; apply_softmax=True
# runs softmax on each member's output before the outputs are averaged.
ensemble_model = SoftVotingEnsemble(
    models=[model_segnet, model_efficientnetb4],
    apply_softmax=True
)

# is_ensemble=True: the evaluator calls the ensemble batch by batch and prints
# the per-class metric table for X_test / y_test.
evaluate_classwise_metrics(ensemble_model, X_test, y_test, is_ensemble=True)

In [None]:
# Soft-voting ensemble of model_segnet + model_efficientnetb4 set up for Keras
# evaluate(): member outputs are averaged without an extra softmax, and
# return_probs=True needs a SoftVotingEnsemble definition that accepts it.
ensemble_model = SoftVotingEnsemble(
    models=[model_segnet, model_efficientnetb4],
    apply_softmax=False,      # members are assumed to already output probabilities — TODO confirm
    return_probs=True         # required for custom loss/metrics to work
)

# compile() attaches the loss and metrics that evaluate() reports; nothing is
# trained in this cell.
ensemble_model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
    loss=combined_loss,
    metrics=class_wise_metrics(num_classes=4)
)

# Evaluate on X_test / y_test in batches of 16.
ensemble_model.evaluate(X_test, y_test, batch_size=16)

In [None]:
# Soft-voting ensemble of model_xception + model_efficientnetb4; apply_softmax=True
# runs softmax on each member's output before the outputs are averaged.
ensemble_model = SoftVotingEnsemble(
    models=[model_xception, model_efficientnetb4],
    apply_softmax=True
)

# is_ensemble=True: the evaluator calls the ensemble batch by batch and prints
# the per-class metric table for X_test / y_test.
evaluate_classwise_metrics(ensemble_model, X_test, y_test, is_ensemble=True)

In [None]:
# Soft-voting ensemble of model_xception + model_efficientnetb4 set up for Keras
# evaluate(): member outputs are averaged without an extra softmax, and
# return_probs=True needs a SoftVotingEnsemble definition that accepts it.
ensemble_model = SoftVotingEnsemble(
    models=[model_xception, model_efficientnetb4],
    apply_softmax=False,      # members are assumed to already output probabilities — TODO confirm
    return_probs=True         # required for custom loss/metrics to work
)

# compile() attaches the loss and metrics that evaluate() reports; nothing is
# trained in this cell.
ensemble_model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
    loss=combined_loss,
    metrics=class_wise_metrics(num_classes=4)
)

# Evaluate on X_test / y_test in batches of 16.
ensemble_model.evaluate(X_test, y_test, batch_size=16)

In [None]:
# Soft-voting ensemble of model_inceptionresnetv2 + model_efficientnetb4;
# apply_softmax=True runs softmax on each member's output before averaging.
ensemble_model = SoftVotingEnsemble(
    models=[model_inceptionresnetv2, model_efficientnetb4],
    apply_softmax=True
)

# is_ensemble=True: the evaluator calls the ensemble batch by batch and prints
# the per-class metric table for X_test / y_test.
evaluate_classwise_metrics(ensemble_model, X_test, y_test, is_ensemble=True)

In [None]:
# Soft-voting ensemble of model_inceptionresnetv2 + model_efficientnetb4 set up
# for Keras evaluate(): member outputs are averaged without an extra softmax,
# and return_probs=True needs a SoftVotingEnsemble definition that accepts it.
ensemble_model = SoftVotingEnsemble(
    models=[model_inceptionresnetv2, model_efficientnetb4],
    apply_softmax=False,      # members are assumed to already output probabilities — TODO confirm
    return_probs=True         # required for custom loss/metrics to work
)

# compile() attaches the loss and metrics that evaluate() reports; nothing is
# trained in this cell.
ensemble_model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
    loss=combined_loss,
    metrics=class_wise_metrics(num_classes=4)
)

# Evaluate on X_test / y_test in batches of 16.
ensemble_model.evaluate(X_test, y_test, batch_size=16)

In [None]:
# Three-member soft-voting ensemble (model_segnet, model_inceptionresnetv2,
# model_efficientnetb4); apply_softmax=True runs softmax on each member's
# output before averaging.
ensemble_model = SoftVotingEnsemble(
    models=[model_segnet, model_inceptionresnetv2, model_efficientnetb4],
    apply_softmax=True
)

# is_ensemble=True: the evaluator calls the ensemble batch by batch and prints
# the per-class metric table for X_test / y_test.
evaluate_classwise_metrics(ensemble_model, X_test, y_test, is_ensemble=True)

In [None]:
# Three-member soft-voting ensemble (model_segnet, model_inceptionresnetv2,
# model_efficientnetb4) set up for Keras evaluate(): member outputs are
# averaged without an extra softmax, and return_probs=True needs a
# SoftVotingEnsemble definition that accepts it.
ensemble_model = SoftVotingEnsemble(
    models=[model_segnet, model_inceptionresnetv2, model_efficientnetb4],
    apply_softmax=False,      # all members are assumed to already output probabilities — TODO confirm
    return_probs=True         # required for custom loss/metrics to work
)

# compile() attaches the loss and metrics that evaluate() reports; nothing is
# trained in this cell.
ensemble_model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
    loss=combined_loss,
    metrics=class_wise_metrics(num_classes=4)
)

# Evaluate on X_test / y_test in batches of 16.
ensemble_model.evaluate(X_test, y_test, batch_size=16)

In [None]:
# Three-member soft-voting ensemble (model_xception, model_inceptionresnetv2,
# model_efficientnetb4); apply_softmax=True runs softmax on each member's
# output before averaging.
ensemble_model = SoftVotingEnsemble(
    models=[model_xception, model_inceptionresnetv2, model_efficientnetb4],
    apply_softmax=True
)

# is_ensemble=True: the evaluator calls the ensemble batch by batch and prints
# the per-class metric table for X_test / y_test.
evaluate_classwise_metrics(ensemble_model, X_test, y_test, is_ensemble=True)

In [None]:
# Three-member soft-voting ensemble (model_xception, model_inceptionresnetv2,
# model_efficientnetb4) set up for Keras evaluate(): member outputs are
# averaged without an extra softmax, and return_probs=True needs a
# SoftVotingEnsemble definition that accepts it.
ensemble_model = SoftVotingEnsemble(
    models=[model_xception, model_inceptionresnetv2, model_efficientnetb4],
    apply_softmax=False,      # all members are assumed to already output probabilities — TODO confirm
    return_probs=True         # required for custom loss/metrics to work
)

# compile() attaches the loss and metrics that evaluate() reports; nothing is
# trained in this cell.
ensemble_model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
    loss=combined_loss,
    metrics=class_wise_metrics(num_classes=4)
)

# Evaluate on X_test / y_test in batches of 16.
ensemble_model.evaluate(X_test, y_test, batch_size=16)

In [None]:
# Three-member soft-voting ensemble (model_segnet, model_xception,
# model_efficientnetb4); apply_softmax=True runs softmax on each member's
# output before averaging.
ensemble_model = SoftVotingEnsemble(
    models=[model_segnet, model_xception, model_efficientnetb4],
    apply_softmax=True
)

# is_ensemble=True: the evaluator calls the ensemble batch by batch and prints
# the per-class metric table for X_test / y_test.
evaluate_classwise_metrics(ensemble_model, X_test, y_test, is_ensemble=True)

In [None]:
# Three-member soft-voting ensemble (model_segnet, model_xception,
# model_efficientnetb4) set up for Keras evaluate(): member outputs are
# averaged without an extra softmax, and return_probs=True needs a
# SoftVotingEnsemble definition that accepts it.
ensemble_model = SoftVotingEnsemble(
    models=[model_segnet, model_xception, model_efficientnetb4],
    apply_softmax=False,      # all members are assumed to already output probabilities — TODO confirm
    return_probs=True         # required for custom loss/metrics to work
)

# compile() attaches the loss and metrics that evaluate() reports; nothing is
# trained in this cell.
ensemble_model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
    loss=combined_loss,
    metrics=class_wise_metrics(num_classes=4)
)

# Evaluate on X_test / y_test in batches of 16.
ensemble_model.evaluate(X_test, y_test, batch_size=16)

In [None]:
# Four-member soft-voting ensemble (model_segnet, model_xception,
# model_inceptionresnetv2, model_efficientnetb4); apply_softmax=True runs
# softmax on each member's output before averaging.
ensemble_model = SoftVotingEnsemble(
    models=[model_segnet, model_xception, model_inceptionresnetv2, model_efficientnetb4],
    apply_softmax=True
)

# is_ensemble=True: the evaluator calls the ensemble batch by batch and prints
# the per-class metric table for X_test / y_test.
evaluate_classwise_metrics(ensemble_model, X_test, y_test, is_ensemble=True)

In [None]:
# Four-member soft-voting ensemble (model_segnet, model_xception,
# model_inceptionresnetv2, model_efficientnetb4) set up for Keras evaluate():
# member outputs are averaged without an extra softmax, and return_probs=True
# needs a SoftVotingEnsemble definition that accepts it.
ensemble_model = SoftVotingEnsemble(
    models=[model_segnet, model_xception, model_inceptionresnetv2, model_efficientnetb4],
    apply_softmax=False,      # all members are assumed to already output probabilities — TODO confirm
    return_probs=True         # required for custom loss/metrics to work
)

# compile() attaches the loss and metrics that evaluate() reports; nothing is
# trained in this cell.
ensemble_model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
    loss=combined_loss,
    metrics=class_wise_metrics(num_classes=4)
)

# Evaluate on X_test / y_test in batches of 16.
ensemble_model.evaluate(X_test, y_test, batch_size=16)

In [None]:
import numpy as np
import tensorflow as tf
from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score
from scipy.spatial.distance import directed_hausdorff

# ✅ RGB to Class Index Conversion (for the test masks)
RGB_TO_CLASS = {
    (255, 0, 0): 1,  # Brain
    (0, 255, 0): 2,  # CSP
    (0, 0, 255): 3,  # LV
    (0, 0, 0): 0     # Background
}

# ✅ RGB mask to class mask
def rgb_to_class_mask(rgb_mask):
    """Convert a colour-coded mask into a 2-D integer class-index mask.

    Each pixel whose last-axis values exactly equal a key of ``RGB_TO_CLASS``
    receives that key's class index; every other pixel stays 0.
    """
    labels = np.zeros(rgb_mask.shape[:2], dtype=int)
    for colour, label in RGB_TO_CLASS.items():
        hits = (rgb_mask == np.array(colour)).all(axis=-1)
        labels[hits] = label
    return labels

# ✅ Dice
def dice_coefficient(y_true, y_pred):
    """Return the Dice overlap score of two masks.

    Computes ``(2 * sum(y_true * y_pred) + eps) / (sum(y_true) + sum(y_pred) + eps)``
    with ``eps = 1e-15``; two all-zero masks therefore score 1.0.
    """
    eps = 1e-15
    overlap = np.sum(y_true * y_pred)
    # Sum of both mask sizes (the Dice denominator), not a set union.
    total = np.sum(y_true) + np.sum(y_pred)
    numerator = 2. * overlap + eps
    denominator = total + eps
    return numerator / denominator

# ✅ IoU
def iou(y_true, y_pred):
    """Return the intersection-over-union (Jaccard index) of two binary masks.

    The smoothing term is added to both numerator and denominator, so two
    empty masks score 1.0. This matches ``dice_coefficient``, which also
    scores two empty masks as 1.0, and keeps IoU = D / (2 - D) valid for
    classes that are absent from both ground truth and prediction.

    Args:
        y_true: Binary (0/1) ground-truth mask.
        y_pred: Binary (0/1) predicted mask of the same shape.

    Returns:
        The IoU as a float in [0, 1].
    """
    smooth = 1e-15
    intersection = np.sum(y_true * y_pred)
    union = np.sum(y_true) + np.sum(y_pred) - intersection
    return (intersection + smooth) / (union + smooth)

# ✅ Hausdorff Distance
def hausdorff_distance(y_true, y_pred):
    """Return the symmetric Hausdorff distance between two masks.

    The point sets are the coordinates of all elements equal to 1 in each
    mask. Returns ``float('inf')`` when either mask has no such element.
    """
    coords_true = np.argwhere(y_true == 1)
    coords_pred = np.argwhere(y_pred == 1)
    if min(len(coords_true), len(coords_pred)) == 0:
        return float('inf')
    # directed_hausdorff returns (distance, index_1, index_2); only the distance is used.
    true_to_pred, _, _ = directed_hausdorff(coords_true, coords_pred)
    pred_to_true, _, _ = directed_hausdorff(coords_pred, coords_true)
    return max(true_to_pred, pred_to_true)

# ✅ Average Surface Distance (ASD)
def average_surface_distance(y_true, y_pred):
    """Return the mean nearest-pixel distance from ``y_true`` to ``y_pred``.

    For every element equal to 1 in ``y_true`` the Euclidean distance to the
    closest element equal to 1 in ``y_pred`` is taken, and those distances are
    averaged. The measure is one-directional (true -> pred) and uses every
    foreground pixel, not only boundary pixels. Returns ``float('inf')`` when
    either mask has no element equal to 1.
    """
    coords_true = np.argwhere(y_true == 1)
    coords_pred = np.argwhere(y_pred == 1)
    if not len(coords_true) or not len(coords_pred):
        return float('inf')
    nearest = []
    for point in coords_true:
        gaps = np.linalg.norm(coords_pred - point, axis=1)
        nearest.append(np.min(gaps))
    return np.mean(nearest)

# ✅ Soft Voting Ensemble Model
class SoftVotingEnsemble(tf.keras.Model):
    """Soft-voting ensemble: average the per-class scores of several models.

    Args:
        models: Models called as ``model(x, training=...)``. Their outputs are
            summed, so they must have compatible shapes (presumably
            ``[B, H, W, C]`` score maps — confirm against the members).
        apply_softmax: If True, softmax over the last axis is applied to every
            member output before averaging. If False, member outputs are
            averaged unchanged.
        return_probs: If False (the default, and the original behaviour),
            ``call`` returns the arg-max class index over the last axis. If
            True, it returns the averaged scores themselves, which is what
            ``compile``/``evaluate`` with a loss and metrics require.
    """

    def __init__(self, models, apply_softmax=True, return_probs=False):
        super(SoftVotingEnsemble, self).__init__()
        self.models = models
        self.apply_softmax = apply_softmax
        self.return_probs = return_probs

    def call(self, x, training=False):
        """Run every member on ``x`` and combine their outputs."""
        prob_sum = 0
        for model in self.models:
            logits = model(x, training=training)
            probs = tf.nn.softmax(logits, axis=-1) if self.apply_softmax else logits
            prob_sum += probs
        avg_prob = prob_sum / len(self.models)
        if self.return_probs:
            return avg_prob
        return tf.argmax(avg_prob, axis=-1)

# ✅ Evaluation Function (for normal + ensemble models)
def evaluate_classwise_metrics(model, X_test, y_test, num_classes=4, batch_size=16, is_ensemble=False):
    """Print per-class segmentation metrics averaged over the test images.

    For every image and every class index, the binary "is this class" masks of
    ground truth and prediction are compared with Dice, IoU, precision,
    recall, F1 and pixel accuracy. The per-image values are averaged per class
    and printed as a table of percentages.

    Args:
        model: With ``is_ensemble=False``, an object whose ``predict`` output
            is arg-maxed over the last axis. With ``is_ensemble=True``, a
            callable invoked as ``model(batch, training=False)`` whose result
            exposes ``.numpy()``.
        X_test: Test inputs, sliced along the first axis in ``batch_size`` chunks.
        y_test: One-hot ground truth with the class axis last.
        num_classes: Class indices ``0 .. num_classes - 1`` are reported.
        batch_size: Batch size used for inference.
        is_ensemble: Selects the inference path described under ``model``.

    Returns:
        None. The table is written to stdout.
    """
    if is_ensemble:
        y_pred = []
        for start in range(0, len(X_test), batch_size):
            batch_pred = model(X_test[start:start + batch_size], training=False).numpy()
            y_pred.extend(batch_pred)
        y_pred = np.array(y_pred)
        # An ensemble may return class scores / one-hot maps (same rank as
        # y_test) instead of class indices; reduce those to indices so both
        # kinds of ensemble are scored correctly.
        if y_pred.ndim == np.ndim(y_test):
            y_pred = np.argmax(y_pred, axis=-1)
    else:
        y_pred = np.argmax(model.predict(X_test, batch_size=batch_size), axis=-1)

    y_test_class = np.argmax(y_test, axis=-1)

    metric_names = ('dice', 'iou', 'precision', 'recall', 'f1', 'accuracy')
    class_metrics = {
        class_idx: {name: [] for name in metric_names}
        for class_idx in range(num_classes)
    }

    for i in range(len(X_test)):
        true_mask = y_test_class[i]
        pred_mask = y_pred[i]
        for class_idx in range(num_classes):
            true_class_mask = (true_mask == class_idx).astype(int)
            pred_class_mask = (pred_mask == class_idx).astype(int)
            # Flatten once; the sklearn scorers take 1-D label vectors.
            true_flat = true_class_mask.ravel()
            pred_flat = pred_class_mask.ravel()

            scores = class_metrics[class_idx]
            scores['dice'].append(dice_coefficient(true_class_mask, pred_class_mask))
            scores['iou'].append(iou(true_class_mask, pred_class_mask))
            scores['precision'].append(precision_score(true_flat, pred_flat, zero_division=0))
            scores['recall'].append(recall_score(true_flat, pred_flat, zero_division=0))
            scores['f1'].append(f1_score(true_flat, pred_flat, zero_division=0))
            scores['accuracy'].append(accuracy_score(true_flat, pred_flat))

    # Header and rows share one column width so the values line up under their titles.
    titles = ('Dice Coef (%)', 'IoU (%)', 'Precision (%)', 'Recall (%)', 'F1 Score (%)', 'Accuracy (%)')
    col_width = 16
    header = f"{'Class':<10}" + "".join(f"{title:>{col_width}}" for title in titles)
    print(header)
    print("-" * len(header))

    for class_idx in range(num_classes):
        cells = "".join(
            f"{np.mean(class_metrics[class_idx][name]) * 100:>{col_width}.2f}"
            for name in metric_names
        )
        print(f"{class_idx:<10}{cells}")

# Soft-voting ensemble of model_xception + model_inceptionresnetv2;
# apply_softmax=True runs softmax on each member's output before averaging.
ensemble_model = SoftVotingEnsemble(
    models=[model_xception, model_inceptionresnetv2],
    apply_softmax=True
)

# is_ensemble=True: the evaluator calls the ensemble batch by batch and prints
# the per-class metric table for X_test / y_test.
evaluate_classwise_metrics(ensemble_model, X_test, y_test, is_ensemble=True)

In [None]:
# Soft-voting ensemble of model_xception + model_inceptionresnetv2 set up for
# Keras evaluate(): member outputs are averaged without an extra softmax, and
# return_probs=True needs a SoftVotingEnsemble definition that accepts it.
ensemble_model = SoftVotingEnsemble(
    models=[model_xception, model_inceptionresnetv2],
    apply_softmax=False,  # members are assumed to already output probabilities — TODO confirm
    return_probs=True     # presumably returns averaged probabilities instead of labels
)

# compile() attaches the loss and metrics that evaluate() reports; nothing is
# trained in this cell.
ensemble_model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
    loss=combined_loss,
    metrics=class_wise_metrics(num_classes=4)
)

# Evaluate on X_test / y_test in batches of 16.
ensemble_model.evaluate(X_test, y_test, batch_size=16)

In [None]:
import numpy as np
import tensorflow as tf
from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score
from scipy.spatial.distance import directed_hausdorff

# ✅ RGB to Class Index Conversion (for the test masks)
RGB_TO_CLASS = {
    (255, 0, 0): 1,  # Brain
    (0, 255, 0): 2,  # CSP
    (0, 0, 255): 3,  # LV
    (0, 0, 0): 0     # Background
}

# ✅ RGB mask to class mask
def rgb_to_class_mask(rgb_mask):
    """Convert a colour-coded mask into a 2-D integer class-index mask.

    Each pixel whose last-axis values exactly equal a key of ``RGB_TO_CLASS``
    receives that key's class index; every other pixel stays 0.
    """
    labels = np.zeros(rgb_mask.shape[:2], dtype=int)
    for colour, label in RGB_TO_CLASS.items():
        hits = (rgb_mask == np.array(colour)).all(axis=-1)
        labels[hits] = label
    return labels

# ✅ Dice
def dice_coefficient(y_true, y_pred):
    """Return the Dice overlap score of two masks.

    Computes ``(2 * sum(y_true * y_pred) + eps) / (sum(y_true) + sum(y_pred) + eps)``
    with ``eps = 1e-15``; two all-zero masks therefore score 1.0.
    """
    eps = 1e-15
    overlap = np.sum(y_true * y_pred)
    # Sum of both mask sizes (the Dice denominator), not a set union.
    total = np.sum(y_true) + np.sum(y_pred)
    numerator = 2. * overlap + eps
    denominator = total + eps
    return numerator / denominator

# ✅ IoU
def iou(y_true, y_pred):
    """Return the intersection-over-union (Jaccard index) of two binary masks.

    The smoothing term is added to both numerator and denominator, so two
    empty masks score 1.0. This matches ``dice_coefficient``, which also
    scores two empty masks as 1.0, and keeps IoU = D / (2 - D) valid for
    classes that are absent from both ground truth and prediction.

    Args:
        y_true: Binary (0/1) ground-truth mask.
        y_pred: Binary (0/1) predicted mask of the same shape.

    Returns:
        The IoU as a float in [0, 1].
    """
    smooth = 1e-15
    intersection = np.sum(y_true * y_pred)
    union = np.sum(y_true) + np.sum(y_pred) - intersection
    return (intersection + smooth) / (union + smooth)

# ✅ Hausdorff Distance
def hausdorff_distance(y_true, y_pred):
    """Return the symmetric Hausdorff distance between two masks.

    The point sets are the coordinates of all elements equal to 1 in each
    mask. Returns ``float('inf')`` when either mask has no such element.
    """
    coords_true = np.argwhere(y_true == 1)
    coords_pred = np.argwhere(y_pred == 1)
    if min(len(coords_true), len(coords_pred)) == 0:
        return float('inf')
    # directed_hausdorff returns (distance, index_1, index_2); only the distance is used.
    true_to_pred, _, _ = directed_hausdorff(coords_true, coords_pred)
    pred_to_true, _, _ = directed_hausdorff(coords_pred, coords_true)
    return max(true_to_pred, pred_to_true)

# ✅ Average Surface Distance (ASD)
def average_surface_distance(y_true, y_pred):
    """Return the mean nearest-pixel distance from ``y_true`` to ``y_pred``.

    For every element equal to 1 in ``y_true`` the Euclidean distance to the
    closest element equal to 1 in ``y_pred`` is taken, and those distances are
    averaged. The measure is one-directional (true -> pred) and uses every
    foreground pixel, not only boundary pixels. Returns ``float('inf')`` when
    either mask has no element equal to 1.
    """
    coords_true = np.argwhere(y_true == 1)
    coords_pred = np.argwhere(y_pred == 1)
    if not len(coords_true) or not len(coords_pred):
        return float('inf')
    nearest = []
    for point in coords_true:
        gaps = np.linalg.norm(coords_pred - point, axis=1)
        nearest.append(np.min(gaps))
    return np.mean(nearest)

# ✅ Soft Voting Ensemble Model
class SoftVotingEnsemble(tf.keras.Model):
    """Soft-voting ensemble: average the per-class scores of several models.

    Args:
        models: Models called as ``model(x, training=...)``. Their outputs are
            summed, so they must have compatible shapes (presumably
            ``[B, H, W, C]`` score maps — confirm against the members).
        apply_softmax: If True, softmax over the last axis is applied to every
            member output before averaging. If False, member outputs are
            averaged unchanged.
        return_probs: If False (the default, and the original behaviour),
            ``call`` returns the arg-max class index over the last axis. If
            True, it returns the averaged scores themselves, which is what
            ``compile``/``evaluate`` with a loss and metrics require.
    """

    def __init__(self, models, apply_softmax=True, return_probs=False):
        super(SoftVotingEnsemble, self).__init__()
        self.models = models
        self.apply_softmax = apply_softmax
        self.return_probs = return_probs

    def call(self, x, training=False):
        """Run every member on ``x`` and combine their outputs."""
        prob_sum = 0
        for model in self.models:
            logits = model(x, training=training)
            probs = tf.nn.softmax(logits, axis=-1) if self.apply_softmax else logits
            prob_sum += probs
        avg_prob = prob_sum / len(self.models)
        if self.return_probs:
            return avg_prob
        return tf.argmax(avg_prob, axis=-1)

# ✅ Evaluation Function (for normal + ensemble models)
def evaluate_classwise_metrics(model, X_test, y_test, num_classes=4, batch_size=16, is_ensemble=False):
    """Print per-class segmentation metrics averaged over the test images.

    For every image and every class index, the binary "is this class" masks of
    ground truth and prediction are compared with Dice, IoU, precision,
    recall, F1 and pixel accuracy. The per-image values are averaged per class
    and printed as a table of percentages.

    Args:
        model: With ``is_ensemble=False``, an object whose ``predict`` output
            is arg-maxed over the last axis. With ``is_ensemble=True``, a
            callable invoked as ``model(batch, training=False)`` whose result
            exposes ``.numpy()``.
        X_test: Test inputs, sliced along the first axis in ``batch_size`` chunks.
        y_test: One-hot ground truth with the class axis last.
        num_classes: Class indices ``0 .. num_classes - 1`` are reported.
        batch_size: Batch size used for inference.
        is_ensemble: Selects the inference path described under ``model``.

    Returns:
        None. The table is written to stdout.
    """
    if is_ensemble:
        y_pred = []
        for start in range(0, len(X_test), batch_size):
            batch_pred = model(X_test[start:start + batch_size], training=False).numpy()
            y_pred.extend(batch_pred)
        y_pred = np.array(y_pred)
        # An ensemble may return class scores / one-hot maps (same rank as
        # y_test) instead of class indices; reduce those to indices so both
        # kinds of ensemble are scored correctly.
        if y_pred.ndim == np.ndim(y_test):
            y_pred = np.argmax(y_pred, axis=-1)
    else:
        y_pred = np.argmax(model.predict(X_test, batch_size=batch_size), axis=-1)

    y_test_class = np.argmax(y_test, axis=-1)

    metric_names = ('dice', 'iou', 'precision', 'recall', 'f1', 'accuracy')
    class_metrics = {
        class_idx: {name: [] for name in metric_names}
        for class_idx in range(num_classes)
    }

    for i in range(len(X_test)):
        true_mask = y_test_class[i]
        pred_mask = y_pred[i]
        for class_idx in range(num_classes):
            true_class_mask = (true_mask == class_idx).astype(int)
            pred_class_mask = (pred_mask == class_idx).astype(int)
            # Flatten once; the sklearn scorers take 1-D label vectors.
            true_flat = true_class_mask.ravel()
            pred_flat = pred_class_mask.ravel()

            scores = class_metrics[class_idx]
            scores['dice'].append(dice_coefficient(true_class_mask, pred_class_mask))
            scores['iou'].append(iou(true_class_mask, pred_class_mask))
            scores['precision'].append(precision_score(true_flat, pred_flat, zero_division=0))
            scores['recall'].append(recall_score(true_flat, pred_flat, zero_division=0))
            scores['f1'].append(f1_score(true_flat, pred_flat, zero_division=0))
            scores['accuracy'].append(accuracy_score(true_flat, pred_flat))

    # Header and rows share one column width so the values line up under their titles.
    titles = ('Dice Coef (%)', 'IoU (%)', 'Precision (%)', 'Recall (%)', 'F1 Score (%)', 'Accuracy (%)')
    col_width = 16
    header = f"{'Class':<10}" + "".join(f"{title:>{col_width}}" for title in titles)
    print(header)
    print("-" * len(header))

    for class_idx in range(num_classes):
        cells = "".join(
            f"{np.mean(class_metrics[class_idx][name]) * 100:>{col_width}.2f}"
            for name in metric_names
        )
        print(f"{class_idx:<10}{cells}")

# Optional: Print Hausdorff and ASD if needed
# Three-member soft-voting ensemble (model_xception, model_segnet,
# model_inceptionresnetv2); apply_softmax=True runs softmax on each member's
# output before averaging.
ensemble_model = SoftVotingEnsemble(
    models=[model_xception, model_segnet, model_inceptionresnetv2],
    apply_softmax=True
)

# is_ensemble=True: the evaluator calls the ensemble batch by batch and prints
# the per-class metric table for X_test / y_test.
evaluate_classwise_metrics(ensemble_model, X_test, y_test, is_ensemble=True)

In [None]:
# Three-member soft-voting ensemble (model_xception, model_segnet,
# model_inceptionresnetv2) set up for Keras evaluate(): member outputs are
# averaged without an extra softmax, and return_probs=True needs a
# SoftVotingEnsemble definition that accepts it.
ensemble_model = SoftVotingEnsemble(
    models=[model_xception, model_segnet, model_inceptionresnetv2],
    apply_softmax=False,  # all members are assumed to already output probabilities — TODO confirm
    return_probs=True     # presumably returns averaged probabilities instead of labels
)

# compile() attaches the loss and metrics that evaluate() reports; nothing is
# trained in this cell.
ensemble_model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
    loss=combined_loss,
    metrics=class_wise_metrics(num_classes=4)
)

# Evaluate on X_test / y_test in batches of 16.
ensemble_model.evaluate(X_test, y_test, batch_size=16)

In [None]:
# NOTE(review): this cell repeats the model_xception / model_segnet /
# model_inceptionresnetv2 soft-voting evaluation with identical arguments —
# confirm whether a different member combination was intended here.
ensemble_model = SoftVotingEnsemble(
    models=[model_xception, model_segnet, model_inceptionresnetv2],
    apply_softmax=False,  # all members are assumed to already output probabilities — TODO confirm
    return_probs=True     # needs a SoftVotingEnsemble definition that accepts this argument
)

# compile() attaches the loss and metrics that evaluate() reports; nothing is
# trained in this cell.
ensemble_model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
    loss=combined_loss,
    metrics=class_wise_metrics(num_classes=4)
)

# Evaluate on X_test / y_test in batches of 16.
ensemble_model.evaluate(X_test, y_test, batch_size=16)

<h1>Majority Voting</h1>

In [None]:
import tensorflow as tf
import numpy as np
from scipy import stats  # for mode
from tensorflow.keras.models import Model

class HardVotingEnsemble(tf.keras.Model):
    """Majority-vote (hard-voting) ensemble over per-pixel arg-max predictions.

    Every member predicts a class per pixel (arg-max over the last axis of its
    output); the ensemble returns, per pixel, the class chosen most often.
    Ties resolve to the smallest class index (``scipy.stats.mode`` behaviour).

    Args:
        models: Models called as ``model(x, training=...)``; expected to
            return ``[B, H, W, C]`` score maps.
    """

    def __init__(self, models):
        super(HardVotingEnsemble, self).__init__()
        self.models = models

    def call(self, x, training=False):
        """Return the majority class index per pixel as an int64 tensor ``[B, H, W]``."""
        votes = []
        for model in self.models:
            logits = model(x, training=training)
            votes.append(tf.argmax(logits, axis=-1))  # [B, H, W]

        stacked_votes = tf.stack(votes, axis=-1)  # [B, H, W, N_models]

        def majority_vote(votes_np):
            # keepdims=False pins the result to [B, H, W]; without it the
            # output rank depends on the installed SciPy version's default.
            winners = stats.mode(votes_np, axis=-1, keepdims=False)[0]
            return np.asarray(winners, dtype=np.int64)

        mode_preds = tf.numpy_function(
            func=majority_vote,
            inp=[stacked_votes],
            Tout=tf.int64
        )
        # tf.numpy_function drops static shape information; restore it from
        # the per-member predictions, which have the same shape as the result.
        mode_preds.set_shape(votes[0].shape)

        return mode_preds

In [None]:
# import tensorflow as tf
# import numpy as np
# from scipy import stats

# class HardVotingEnsemble(tf.keras.Model):
#     def __init__(self, models, num_classes):
#         super(HardVotingEnsemble, self).__init__()
#         self.models = models
#         self.num_classes = num_classes

#     def call(self, x, training=False):
#         predictions = []
#         for model in self.models:
#             logits = model(x, training=training)               # [B, H, W, C]
#             pred_mask = tf.argmax(logits, axis=-1)             # [B, H, W]
#             predictions.append(pred_mask)

#         stacked_preds = tf.stack(predictions, axis=0)          # [N_models, B, H, W]
#         stacked_preds = tf.transpose(stacked_preds, [1, 2, 3, 0])  # [B, H, W, N_models]

#         # Use numpy + scipy mode
#         def compute_mode(x):
#             mode, _ = stats.mode(x, axis=-1, keepdims=False)
#             return mode.astype(np.int32)

#         mode_preds = tf.numpy_function(
#             func=compute_mode,
#             inp=[stacked_preds],
#             Tout=tf.int32
#         )

#         # Manually set output shape: [B, H, W]
#         batch_size = tf.shape(x)[0]
#         height = tf.shape(x)[1]
#         width = tf.shape(x)[2]
#         mode_preds.set_shape([None, None, None])  # Symbolic shape for [B, H, W]

#         one_hot_preds = tf.one_hot(mode_preds, depth=self.num_classes)  # [B, H, W, C]
#         return one_hot_preds

# import tensorflow as tf
# import numpy as np
# from scipy import stats

# class HardVotingEnsemble(tf.keras.Model):
#     def __init__(self, models, num_classes, return_probs=True):
#         super(HardVotingEnsemble, self).__init__()
#         self.models = models
#         self.num_classes = num_classes
#         self.return_probs = return_probs

#     def call(self, x, training=False):
#         predictions = []

#         for model in self.models:
#             output = model(x, training=training)

#             # Handle EfficientNet-like models with built-in softmax
#             is_softmaxed = hasattr(model, "name") and "efficientnet" in model.name.lower()
#             probs = output if is_softmaxed else tf.nn.softmax(output, axis=-1)

#             pred_mask = tf.argmax(probs, axis=-1)  # [B, H, W]
#             predictions.append(pred_mask)

#         stacked_preds = tf.stack(predictions, axis=0)  # [N_models, B, H, W]
#         stacked_preds = tf.transpose(stacked_preds, [1, 2, 3, 0])  # [B, H, W, N_models]

#         # Use scipy mode to find majority vote
#         def compute_mode(x):
#             mode, _ = stats.mode(x, axis=-1, keepdims=False)
#             return mode.astype(np.int32)

#         mode_preds = tf.numpy_function(
#             func=compute_mode,
#             inp=[stacked_preds],
#             Tout=tf.int32
#         )
#         mode_preds.set_shape([None, None, None])  # [B, H, W]

#         if self.return_probs:
#             return tf.one_hot(mode_preds, depth=self.num_classes)  # [B, H, W, C]
#         else:
#             return mode_preds  # [B, H, W]

import tensorflow as tf
import numpy as np
from scipy import stats

class HardVotingEnsemble(tf.keras.Model):
    """Majority-vote (hard-voting) ensemble over per-pixel arg-max predictions.

    Every member predicts a class per pixel (arg-max over the last axis of its
    output); the ensemble picks, per pixel, the class chosen most often. Ties
    resolve to the smallest class index (``scipy.stats.mode`` behaviour).

    Args:
        models: Models called as ``model(x, training=...)``; expected to
            return ``[B, H, W, C]`` score maps.
        num_classes: If given, ``call`` returns the majority class as a
            one-hot map of this depth (``[B, H, W, C]``), the form Keras
            losses and metrics consume. If None (default), ``call`` returns
            the int32 class-index mask ``[B, H, W]``, the form the
            label-based evaluator in this notebook consumes.
    """

    def __init__(self, models, num_classes=None):
        super(HardVotingEnsemble, self).__init__()
        self.models = models
        self.num_classes = num_classes

    def call(self, x, training=False):
        """Run every member on ``x`` and return their per-pixel majority vote."""
        predictions = []
        for model in self.models:
            logits = model(x, training=training)               # expected [B, H, W, C]
            predictions.append(tf.argmax(logits, axis=-1))     # [B, H, W]

        stacked_preds = tf.stack(predictions, axis=-1)         # [B, H, W, N_models]

        # Majority vote via numpy + scipy; keepdims=False yields [B, H, W].
        def compute_mode(votes):
            mode, _ = stats.mode(votes, axis=-1, keepdims=False)
            return mode.astype(np.int32)

        mode_preds = tf.numpy_function(
            func=compute_mode,
            inp=[stacked_preds],
            Tout=tf.int32
        )

        # tf.numpy_function drops static shape information; restore it from
        # the per-member predictions, which have the same shape as the result.
        mode_preds.set_shape(predictions[0].shape)

        if self.num_classes is None:
            return mode_preds
        return tf.one_hot(mode_preds, depth=self.num_classes)  # [B, H, W, C]

In [None]:
import numpy as np
import tensorflow as tf
from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score
from scipy.spatial.distance import directed_hausdorff

# ✅ RGB to Class Index Conversion (for the test masks)
RGB_TO_CLASS = {
    (255, 0, 0): 1,  # Brain
    (0, 255, 0): 2,  # CSP
    (0, 0, 255): 3,  # LV
    (0, 0, 0): 0     # Background
}

# ✅ RGB mask to class mask
def rgb_to_class_mask(rgb_mask):
    """Convert a colour-coded mask into a 2-D integer class-index mask.

    Each pixel whose last-axis values exactly equal a key of ``RGB_TO_CLASS``
    receives that key's class index; every other pixel stays 0.
    """
    labels = np.zeros(rgb_mask.shape[:2], dtype=int)
    for colour, label in RGB_TO_CLASS.items():
        hits = (rgb_mask == np.array(colour)).all(axis=-1)
        labels[hits] = label
    return labels

# ✅ Dice
def dice_coefficient(y_true, y_pred):
    """Return the Dice overlap score of two masks.

    Computes ``(2 * sum(y_true * y_pred) + eps) / (sum(y_true) + sum(y_pred) + eps)``
    with ``eps = 1e-15``; two all-zero masks therefore score 1.0.
    """
    eps = 1e-15
    overlap = np.sum(y_true * y_pred)
    # Sum of both mask sizes (the Dice denominator), not a set union.
    total = np.sum(y_true) + np.sum(y_pred)
    numerator = 2. * overlap + eps
    denominator = total + eps
    return numerator / denominator

# ✅ IoU
def iou(y_true, y_pred):
    """Return the intersection-over-union (Jaccard index) of two binary masks.

    The smoothing term is added to both numerator and denominator, so two
    empty masks score 1.0. This matches ``dice_coefficient``, which also
    scores two empty masks as 1.0, and keeps IoU = D / (2 - D) valid for
    classes that are absent from both ground truth and prediction.

    Args:
        y_true: Binary (0/1) ground-truth mask.
        y_pred: Binary (0/1) predicted mask of the same shape.

    Returns:
        The IoU as a float in [0, 1].
    """
    smooth = 1e-15
    intersection = np.sum(y_true * y_pred)
    union = np.sum(y_true) + np.sum(y_pred) - intersection
    return (intersection + smooth) / (union + smooth)

# ✅ Hausdorff Distance
def hausdorff_distance(y_true, y_pred):
    """Return the symmetric Hausdorff distance between two masks.

    The point sets are the coordinates of all elements equal to 1 in each
    mask. Returns ``float('inf')`` when either mask has no such element.
    """
    coords_true = np.argwhere(y_true == 1)
    coords_pred = np.argwhere(y_pred == 1)
    if min(len(coords_true), len(coords_pred)) == 0:
        return float('inf')
    # directed_hausdorff returns (distance, index_1, index_2); only the distance is used.
    true_to_pred, _, _ = directed_hausdorff(coords_true, coords_pred)
    pred_to_true, _, _ = directed_hausdorff(coords_pred, coords_true)
    return max(true_to_pred, pred_to_true)

# ✅ Average Surface Distance (ASD)
def average_surface_distance(y_true, y_pred):
    """Return the mean nearest-pixel distance from ``y_true`` to ``y_pred``.

    For every element equal to 1 in ``y_true`` the Euclidean distance to the
    closest element equal to 1 in ``y_pred`` is taken, and those distances are
    averaged. The measure is one-directional (true -> pred) and uses every
    foreground pixel, not only boundary pixels. Returns ``float('inf')`` when
    either mask has no element equal to 1.
    """
    coords_true = np.argwhere(y_true == 1)
    coords_pred = np.argwhere(y_pred == 1)
    if not len(coords_true) or not len(coords_pred):
        return float('inf')
    nearest = []
    for point in coords_true:
        gaps = np.linalg.norm(coords_pred - point, axis=1)
        nearest.append(np.min(gaps))
    return np.mean(nearest)

# ✅ Soft Voting Ensemble Model
class SoftVotingEnsemble(tf.keras.Model):
    """Soft-voting ensemble: average the per-class scores of several models.

    Args:
        models: Models called as ``model(x, training=...)``. Their outputs are
            summed, so they must have compatible shapes (presumably
            ``[B, H, W, C]`` score maps — confirm against the members).
        apply_softmax: If True, softmax over the last axis is applied to every
            member output before averaging. If False, member outputs are
            averaged unchanged.
        return_probs: If False (the default, and the original behaviour),
            ``call`` returns the arg-max class index over the last axis. If
            True, it returns the averaged scores themselves, which is what
            ``compile``/``evaluate`` with a loss and metrics require.
    """

    def __init__(self, models, apply_softmax=True, return_probs=False):
        super(SoftVotingEnsemble, self).__init__()
        self.models = models
        self.apply_softmax = apply_softmax
        self.return_probs = return_probs

    def call(self, x, training=False):
        """Run every member on ``x`` and combine their outputs."""
        prob_sum = 0
        for model in self.models:
            logits = model(x, training=training)
            probs = tf.nn.softmax(logits, axis=-1) if self.apply_softmax else logits
            prob_sum += probs
        avg_prob = prob_sum / len(self.models)
        if self.return_probs:
            return avg_prob
        return tf.argmax(avg_prob, axis=-1)

# ✅ Evaluation Function (for normal + ensemble models)
def evaluate_classwise_metrics(model, X_test, y_test, num_classes=4, batch_size=16, is_ensemble=False):
    """Print per-class segmentation metrics averaged over the test images.

    For every image and every class index, the binary "is this class" masks of
    ground truth and prediction are compared with Dice, IoU, precision,
    recall, F1 and pixel accuracy. The per-image values are averaged per class
    and printed as a table of percentages.

    Args:
        model: With ``is_ensemble=False``, an object whose ``predict`` output
            is arg-maxed over the last axis. With ``is_ensemble=True``, a
            callable invoked as ``model(batch, training=False)`` whose result
            exposes ``.numpy()``.
        X_test: Test inputs, sliced along the first axis in ``batch_size`` chunks.
        y_test: One-hot ground truth with the class axis last.
        num_classes: Class indices ``0 .. num_classes - 1`` are reported.
        batch_size: Batch size used for inference.
        is_ensemble: Selects the inference path described under ``model``.

    Returns:
        None. The table is written to stdout.
    """
    if is_ensemble:
        y_pred = []
        for start in range(0, len(X_test), batch_size):
            batch_pred = model(X_test[start:start + batch_size], training=False).numpy()
            y_pred.extend(batch_pred)
        y_pred = np.array(y_pred)
        # An ensemble may return class scores / one-hot maps (same rank as
        # y_test) instead of class indices; reduce those to indices so both
        # kinds of ensemble are scored correctly.
        if y_pred.ndim == np.ndim(y_test):
            y_pred = np.argmax(y_pred, axis=-1)
    else:
        y_pred = np.argmax(model.predict(X_test, batch_size=batch_size), axis=-1)

    y_test_class = np.argmax(y_test, axis=-1)

    metric_names = ('dice', 'iou', 'precision', 'recall', 'f1', 'accuracy')
    class_metrics = {
        class_idx: {name: [] for name in metric_names}
        for class_idx in range(num_classes)
    }

    for i in range(len(X_test)):
        true_mask = y_test_class[i]
        pred_mask = y_pred[i]
        for class_idx in range(num_classes):
            true_class_mask = (true_mask == class_idx).astype(int)
            pred_class_mask = (pred_mask == class_idx).astype(int)
            # Flatten once; the sklearn scorers take 1-D label vectors.
            true_flat = true_class_mask.ravel()
            pred_flat = pred_class_mask.ravel()

            scores = class_metrics[class_idx]
            scores['dice'].append(dice_coefficient(true_class_mask, pred_class_mask))
            scores['iou'].append(iou(true_class_mask, pred_class_mask))
            scores['precision'].append(precision_score(true_flat, pred_flat, zero_division=0))
            scores['recall'].append(recall_score(true_flat, pred_flat, zero_division=0))
            scores['f1'].append(f1_score(true_flat, pred_flat, zero_division=0))
            scores['accuracy'].append(accuracy_score(true_flat, pred_flat))

    # Header and rows share one column width so the values line up under their titles.
    titles = ('Dice Coef (%)', 'IoU (%)', 'Precision (%)', 'Recall (%)', 'F1 Score (%)', 'Accuracy (%)')
    col_width = 16
    header = f"{'Class':<10}" + "".join(f"{title:>{col_width}}" for title in titles)
    print(header)
    print("-" * len(header))

    for class_idx in range(num_classes):
        cells = "".join(
            f"{np.mean(class_metrics[class_idx][name]) * 100:>{col_width}.2f}"
            for name in metric_names
        )
        print(f"{class_idx:<10}{cells}")

# Hard-voting (majority vote) ensemble of model_segnet + model_inceptionresnetv2.
# The evaluator's ensemble path treats the model output as class predictions,
# so the HardVotingEnsemble definition in effect must accept the model list as
# its only argument — NOTE(review): confirm which definition is live here.
ensemble_model = HardVotingEnsemble([model_segnet, model_inceptionresnetv2])

# Evaluate like before (same evaluation function as used with soft voting)
evaluate_classwise_metrics(ensemble_model, X_test, y_test, is_ensemble=True)

In [None]:
def combined_loss(y_true, y_pred):
    """Return the sum of a soft Dice loss and the Lovász-softmax loss.

    Both inputs are cast to float32. The Dice term is computed per sample and
    per channel by summing over axes 1 and 2, then averaged to a scalar. The
    Lovász term is delegated to ``lovasz_softmax_loss`` with the background
    class included.
    """
    eps = 1e-6
    targets = tf.cast(y_true, tf.float32)
    outputs = tf.cast(y_pred, tf.float32)

    spatial_axes = [1, 2]
    overlap = tf.reduce_sum(targets * outputs, axis=spatial_axes)
    mass = tf.reduce_sum(targets, axis=spatial_axes) + tf.reduce_sum(outputs, axis=spatial_axes)
    dice_term = tf.reduce_mean(1 - (2. * overlap + eps) / (mass + eps))

    # NOTE(review): the Dice term uses y_pred as-is while the Lovász term gets
    # softmax(y_pred) — confirm whether y_pred is expected to be logits or
    # probabilities here.
    lovasz_term = lovasz_softmax_loss(targets, tf.nn.softmax(outputs), ignore_background=False)
    return lovasz_term + dice_term


In [None]:
# Hard-voting ensemble of model_segnet and model_inceptionresnetv2, scored
# through Keras' own evaluate().
hard_ensemble = HardVotingEnsemble(num_classes=4, models=[model_segnet, model_inceptionresnetv2])

# compile() registers the loss and the per-class metrics that evaluate() reports.
hard_ensemble.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
    metrics=class_wise_metrics(num_classes=4),
    loss=combined_loss,
)

hard_ensemble.evaluate(X_test, y_test, batch_size=16)

In [None]:
import numpy as np
import tensorflow as tf
from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score
from scipy.spatial.distance import directed_hausdorff

# Colour -> class-index lookup used for the RGB test masks.
RGB_TO_CLASS = {
    (255, 0, 0): 1,  # Brain
    (0, 255, 0): 2,  # CSP
    (0, 0, 255): 3,  # LV
    (0, 0, 0): 0,    # Background
}


def rgb_to_class_mask(rgb_mask):
    """Convert an RGB mask into a 2-D integer mask of class indices.

    A pixel gets the index of the RGB_TO_CLASS colour it matches exactly on
    every channel; pixels matching none of the colours stay 0.
    """
    labels = np.zeros(rgb_mask.shape[:2], dtype=int)
    for colour, label in RGB_TO_CLASS.items():
        hits = (rgb_mask == np.array(colour)).all(axis=-1)
        labels[hits] = label
    return labels

# Dice
def dice_coefficient(y_true, y_pred):
    """Dice overlap between two masks: (2*|A.B| + eps) / (|A| + |B| + eps).

    With eps = 1e-15 in both numerator and denominator, two empty masks score 1.0.
    """
    eps = 1e-15
    total = np.sum(y_true) + np.sum(y_pred)
    overlap = np.sum(y_true * y_pred)
    return (2. * overlap + eps) / (total + eps)

# IoU
def iou(y_true, y_pred):
    """Intersection over union of two masks.

    Only the denominator is padded (by 1e-15), so two empty masks score 0.0.
    """
    overlap = np.sum(y_true * y_pred)
    combined = np.sum(y_true) + np.sum(y_pred) - overlap
    return overlap / (combined + 1e-15)

# Hausdorff Distance
def hausdorff_distance(y_true, y_pred):
    """Symmetric Hausdorff distance between the pixels equal to 1 in each mask.

    Returns ``inf`` when either mask has no such pixel; otherwise the larger of
    the two directed Hausdorff distances.
    """
    coords_true = np.transpose(np.nonzero(y_true == 1))
    coords_pred = np.transpose(np.nonzero(y_pred == 1))
    if len(coords_true) == 0 or len(coords_pred) == 0:
        return float('inf')
    return max(
        directed_hausdorff(coords_true, coords_pred)[0],
        directed_hausdorff(coords_pred, coords_true)[0],
    )

# Average Surface Distance (ASD)
def average_surface_distance(y_true, y_pred):
    """Mean distance from each true pixel to its nearest predicted pixel.

    Every pixel equal to 1 in ``y_true`` is used (not only boundary pixels) and
    only the true -> predicted direction is measured. Returns ``inf`` when
    either mask has no pixel equal to 1.
    """
    coords_true = np.transpose(np.nonzero(y_true == 1))
    coords_pred = np.transpose(np.nonzero(y_pred == 1))
    if len(coords_true) == 0 or len(coords_pred) == 0:
        return float('inf')

    nearest = []
    for point in coords_true:
        gaps = np.linalg.norm(coords_pred - point, axis=1)
        nearest.append(np.min(gaps))
    return np.mean(nearest)

# Soft Voting Ensemble Model
class SoftVotingEnsemble(tf.keras.Model):
    """Averages the member outputs and returns the argmax over the last axis."""

    def __init__(self, models, apply_softmax=True):
        """Store the member models and whether their outputs need a softmax."""
        super().__init__()
        self.models = models
        self.apply_softmax = apply_softmax

    def call(self, x, training=False):
        """Return class indices from the unweighted mean of the member outputs."""
        accumulated = 0
        for member in self.models:
            member_out = member(x, training=training)
            if self.apply_softmax:
                member_out = tf.nn.softmax(member_out, axis=-1)
            accumulated = accumulated + member_out
        mean_prob = accumulated / len(self.models)
        return tf.argmax(mean_prob, axis=-1)

# Evaluation function (for normal + ensemble models)
def evaluate_classwise_metrics(model, X_test, y_test, num_classes=4, batch_size=16, is_ensemble=False):
    """Print a table of per-class segmentation metrics averaged over the test images.

    For every image and class index a binary mask (pixel == class) is built for
    both ground truth and prediction; Dice, IoU, precision, recall, F1 and
    accuracy are computed on that pair and averaged per class over all images.

    Args:
        model: Model to score. With ``is_ensemble=True`` it is called directly as
            ``model(batch, training=False)`` and the result's ``.numpy()`` is
            used; otherwise ``model.predict`` is used and the argmax over the
            last axis is taken.
        X_test: Test inputs, sliced along the first axis.
        y_test: One-hot ground truth; class indices are its argmax over the last axis.
        num_classes: Number of class indices (0 .. num_classes - 1) to report.
        batch_size: Batch size used for inference.
        is_ensemble: Selects the direct-call inference path described above.

    Returns:
        None. The table is written to stdout.
    """
    if is_ensemble:
        y_pred = []
        for start in range(0, len(X_test), batch_size):
            preds = model(X_test[start:start + batch_size], training=False).numpy()
            y_pred.extend(preds)
        y_pred = np.array(y_pred)
    else:
        y_pred = model.predict(X_test, batch_size=batch_size)
        y_pred = np.argmax(y_pred, axis=-1)  # [B, H, W]

    y_test_class = np.argmax(y_test, axis=-1)

    # An ensemble may return per-class maps [B, H, W, C] (one-hot or
    # probabilities) instead of class indices [B, H, W]; reduce those to
    # indices so the per-class comparison below sees matching shapes.
    if y_pred.ndim == y_test_class.ndim + 1:
        y_pred = np.argmax(y_pred, axis=-1)

    metric_names = ('dice', 'iou', 'precision', 'recall', 'f1', 'accuracy', 'hausdorff', 'asd')
    class_metrics = {c: {name: [] for name in metric_names} for c in range(num_classes)}

    for i in range(len(X_test)):
        true_mask = y_test_class[i]
        pred_mask = y_pred[i]
        for class_idx in range(num_classes):
            true_class_mask = (true_mask == class_idx).astype(int)
            pred_class_mask = (pred_mask == class_idx).astype(int)
            # Flatten once; the scikit-learn scorers take 1-D label arrays.
            true_flat = true_class_mask.ravel()
            pred_flat = pred_class_mask.ravel()

            scores = class_metrics[class_idx]
            scores['dice'].append(dice_coefficient(true_class_mask, pred_class_mask))
            scores['iou'].append(iou(true_class_mask, pred_class_mask))
            scores['precision'].append(precision_score(true_flat, pred_flat, zero_division=0))
            scores['recall'].append(recall_score(true_flat, pred_flat, zero_division=0))
            scores['f1'].append(f1_score(true_flat, pred_flat, zero_division=0))
            scores['accuracy'].append(accuracy_score(true_flat, pred_flat))
            # Hausdorff / ASD collection is currently disabled:
            # scores['hausdorff'].append(hausdorff_distance(true_class_mask, pred_class_mask))
            # scores['asd'].append(average_surface_distance(true_class_mask, pred_class_mask))

    # Header and rows share one column spec so the numbers line up under their titles.
    columns = (
        ('Dice Coef (%)', 'dice', 15),
        ('IoU (%)', 'iou', 12),
        ('Precision (%)', 'precision', 17),
        ('Recall (%)', 'recall', 15),
        ('F1 Score (%)', 'f1', 17),
        ('Accuracy (%)', 'accuracy', 17),
    )
    header = f"{'Class':<10}" + "".join(f"{title:>{width}}" for title, _, width in columns)
    print(header)
    print("-" * len(header))

    for class_idx in range(num_classes):
        cells = "".join(
            f"{np.mean(class_metrics[class_idx][key]) * 100:>{width}.2f}"
            for _, key, width in columns
        )
        print(f"{class_idx:<10}{cells}")

# Hard-voting ensemble built from model_segnet and model_xception.
ensemble_model = HardVotingEnsemble(models=[model_segnet, model_xception])

# Print the per-class metrics table for the test split using the helper's
# ensemble code path (is_ensemble=True).
evaluate_classwise_metrics(ensemble_model, X_test, y_test, is_ensemble=True)

In [None]:
import numpy as np
import tensorflow as tf
from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score
from scipy.spatial.distance import directed_hausdorff

# Colour -> class-index lookup used for the RGB test masks.
RGB_TO_CLASS = {
    (255, 0, 0): 1,  # Brain
    (0, 255, 0): 2,  # CSP
    (0, 0, 255): 3,  # LV
    (0, 0, 0): 0,    # Background
}


def rgb_to_class_mask(rgb_mask):
    """Convert an RGB mask into a 2-D integer mask of class indices.

    A pixel gets the index of the RGB_TO_CLASS colour it matches exactly on
    every channel; pixels matching none of the colours stay 0.
    """
    labels = np.zeros(rgb_mask.shape[:2], dtype=int)
    for colour, label in RGB_TO_CLASS.items():
        hits = (rgb_mask == np.array(colour)).all(axis=-1)
        labels[hits] = label
    return labels

# Dice
def dice_coefficient(y_true, y_pred):
    """Dice overlap between two masks: (2*|A.B| + eps) / (|A| + |B| + eps).

    With eps = 1e-15 in both numerator and denominator, two empty masks score 1.0.
    """
    eps = 1e-15
    total = np.sum(y_true) + np.sum(y_pred)
    overlap = np.sum(y_true * y_pred)
    return (2. * overlap + eps) / (total + eps)

# IoU
def iou(y_true, y_pred):
    """Intersection over union of two masks.

    Only the denominator is padded (by 1e-15), so two empty masks score 0.0.
    """
    overlap = np.sum(y_true * y_pred)
    combined = np.sum(y_true) + np.sum(y_pred) - overlap
    return overlap / (combined + 1e-15)

# Hausdorff Distance
def hausdorff_distance(y_true, y_pred):
    """Symmetric Hausdorff distance between the pixels equal to 1 in each mask.

    Returns ``inf`` when either mask has no such pixel; otherwise the larger of
    the two directed Hausdorff distances.
    """
    coords_true = np.transpose(np.nonzero(y_true == 1))
    coords_pred = np.transpose(np.nonzero(y_pred == 1))
    if len(coords_true) == 0 or len(coords_pred) == 0:
        return float('inf')
    return max(
        directed_hausdorff(coords_true, coords_pred)[0],
        directed_hausdorff(coords_pred, coords_true)[0],
    )

# Average Surface Distance (ASD)
def average_surface_distance(y_true, y_pred):
    """Mean distance from each true pixel to its nearest predicted pixel.

    Every pixel equal to 1 in ``y_true`` is used (not only boundary pixels) and
    only the true -> predicted direction is measured. Returns ``inf`` when
    either mask has no pixel equal to 1.
    """
    coords_true = np.transpose(np.nonzero(y_true == 1))
    coords_pred = np.transpose(np.nonzero(y_pred == 1))
    if len(coords_true) == 0 or len(coords_pred) == 0:
        return float('inf')

    nearest = []
    for point in coords_true:
        gaps = np.linalg.norm(coords_pred - point, axis=1)
        nearest.append(np.min(gaps))
    return np.mean(nearest)

# Soft Voting Ensemble Model
class SoftVotingEnsemble(tf.keras.Model):
    """Averages the member outputs and returns the argmax over the last axis."""

    def __init__(self, models, apply_softmax=True):
        """Store the member models and whether their outputs need a softmax."""
        super().__init__()
        self.models = models
        self.apply_softmax = apply_softmax

    def call(self, x, training=False):
        """Return class indices from the unweighted mean of the member outputs."""
        accumulated = 0
        for member in self.models:
            member_out = member(x, training=training)
            if self.apply_softmax:
                member_out = tf.nn.softmax(member_out, axis=-1)
            accumulated = accumulated + member_out
        mean_prob = accumulated / len(self.models)
        return tf.argmax(mean_prob, axis=-1)

# Evaluation function (for normal + ensemble models)
def evaluate_classwise_metrics(model, X_test, y_test, num_classes=4, batch_size=16, is_ensemble=False):
    """Print a table of per-class segmentation metrics averaged over the test images.

    For every image and class index a binary mask (pixel == class) is built for
    both ground truth and prediction; Dice, IoU, precision, recall, F1 and
    accuracy are computed on that pair and averaged per class over all images.

    Args:
        model: Model to score. With ``is_ensemble=True`` it is called directly as
            ``model(batch, training=False)`` and the result's ``.numpy()`` is
            used; otherwise ``model.predict`` is used and the argmax over the
            last axis is taken.
        X_test: Test inputs, sliced along the first axis.
        y_test: One-hot ground truth; class indices are its argmax over the last axis.
        num_classes: Number of class indices (0 .. num_classes - 1) to report.
        batch_size: Batch size used for inference.
        is_ensemble: Selects the direct-call inference path described above.

    Returns:
        None. The table is written to stdout.
    """
    if is_ensemble:
        y_pred = []
        for start in range(0, len(X_test), batch_size):
            preds = model(X_test[start:start + batch_size], training=False).numpy()
            y_pred.extend(preds)
        y_pred = np.array(y_pred)
    else:
        y_pred = model.predict(X_test, batch_size=batch_size)
        y_pred = np.argmax(y_pred, axis=-1)  # [B, H, W]

    y_test_class = np.argmax(y_test, axis=-1)

    # An ensemble may return per-class maps [B, H, W, C] (one-hot or
    # probabilities) instead of class indices [B, H, W]; reduce those to
    # indices so the per-class comparison below sees matching shapes.
    if y_pred.ndim == y_test_class.ndim + 1:
        y_pred = np.argmax(y_pred, axis=-1)

    metric_names = ('dice', 'iou', 'precision', 'recall', 'f1', 'accuracy', 'hausdorff', 'asd')
    class_metrics = {c: {name: [] for name in metric_names} for c in range(num_classes)}

    for i in range(len(X_test)):
        true_mask = y_test_class[i]
        pred_mask = y_pred[i]
        for class_idx in range(num_classes):
            true_class_mask = (true_mask == class_idx).astype(int)
            pred_class_mask = (pred_mask == class_idx).astype(int)
            # Flatten once; the scikit-learn scorers take 1-D label arrays.
            true_flat = true_class_mask.ravel()
            pred_flat = pred_class_mask.ravel()

            scores = class_metrics[class_idx]
            scores['dice'].append(dice_coefficient(true_class_mask, pred_class_mask))
            scores['iou'].append(iou(true_class_mask, pred_class_mask))
            scores['precision'].append(precision_score(true_flat, pred_flat, zero_division=0))
            scores['recall'].append(recall_score(true_flat, pred_flat, zero_division=0))
            scores['f1'].append(f1_score(true_flat, pred_flat, zero_division=0))
            scores['accuracy'].append(accuracy_score(true_flat, pred_flat))
            # Hausdorff / ASD collection is currently disabled:
            # scores['hausdorff'].append(hausdorff_distance(true_class_mask, pred_class_mask))
            # scores['asd'].append(average_surface_distance(true_class_mask, pred_class_mask))

    # Header and rows share one column spec so the numbers line up under their titles.
    columns = (
        ('Dice Coef (%)', 'dice', 15),
        ('IoU (%)', 'iou', 12),
        ('Precision (%)', 'precision', 17),
        ('Recall (%)', 'recall', 15),
        ('F1 Score (%)', 'f1', 17),
        ('Accuracy (%)', 'accuracy', 17),
    )
    header = f"{'Class':<10}" + "".join(f"{title:>{width}}" for title, _, width in columns)
    print(header)
    print("-" * len(header))

    for class_idx in range(num_classes):
        cells = "".join(
            f"{np.mean(class_metrics[class_idx][key]) * 100:>{width}.2f}"
            for _, key, width in columns
        )
        print(f"{class_idx:<10}{cells}")

# Hard-voting ensemble built from model_xception and model_inceptionresnetv2.
ensemble_model = HardVotingEnsemble(models=[model_xception, model_inceptionresnetv2])

# Print the per-class metrics table for the test split using the helper's
# ensemble code path (is_ensemble=True).
evaluate_classwise_metrics(ensemble_model, X_test, y_test, is_ensemble=True)

In [None]:
# Hard-voting ensemble of model_xception and model_inceptionresnetv2, scored
# through Keras' own evaluate().
hard_ensemble = HardVotingEnsemble(num_classes=4, models=[model_xception, model_inceptionresnetv2])

# compile() registers the loss and the per-class metrics that evaluate() reports.
hard_ensemble.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
    metrics=class_wise_metrics(num_classes=4),
    loss=combined_loss,
)

hard_ensemble.evaluate(X_test, y_test, batch_size=16)

In [None]:
import numpy as np
import tensorflow as tf
from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score
from scipy.spatial.distance import directed_hausdorff

# Colour -> class-index lookup used for the RGB test masks.
RGB_TO_CLASS = {
    (255, 0, 0): 1,  # Brain
    (0, 255, 0): 2,  # CSP
    (0, 0, 255): 3,  # LV
    (0, 0, 0): 0,    # Background
}


def rgb_to_class_mask(rgb_mask):
    """Convert an RGB mask into a 2-D integer mask of class indices.

    A pixel gets the index of the RGB_TO_CLASS colour it matches exactly on
    every channel; pixels matching none of the colours stay 0.
    """
    labels = np.zeros(rgb_mask.shape[:2], dtype=int)
    for colour, label in RGB_TO_CLASS.items():
        hits = (rgb_mask == np.array(colour)).all(axis=-1)
        labels[hits] = label
    return labels

# Dice
def dice_coefficient(y_true, y_pred):
    """Dice overlap between two masks: (2*|A.B| + eps) / (|A| + |B| + eps).

    With eps = 1e-15 in both numerator and denominator, two empty masks score 1.0.
    """
    eps = 1e-15
    total = np.sum(y_true) + np.sum(y_pred)
    overlap = np.sum(y_true * y_pred)
    return (2. * overlap + eps) / (total + eps)

# IoU
def iou(y_true, y_pred):
    """Intersection over union of two masks.

    Only the denominator is padded (by 1e-15), so two empty masks score 0.0.
    """
    overlap = np.sum(y_true * y_pred)
    combined = np.sum(y_true) + np.sum(y_pred) - overlap
    return overlap / (combined + 1e-15)

# Hausdorff Distance
def hausdorff_distance(y_true, y_pred):
    """Symmetric Hausdorff distance between the pixels equal to 1 in each mask.

    Returns ``inf`` when either mask has no such pixel; otherwise the larger of
    the two directed Hausdorff distances.
    """
    coords_true = np.transpose(np.nonzero(y_true == 1))
    coords_pred = np.transpose(np.nonzero(y_pred == 1))
    if len(coords_true) == 0 or len(coords_pred) == 0:
        return float('inf')
    return max(
        directed_hausdorff(coords_true, coords_pred)[0],
        directed_hausdorff(coords_pred, coords_true)[0],
    )

# Average Surface Distance (ASD)
def average_surface_distance(y_true, y_pred):
    """Mean distance from each true pixel to its nearest predicted pixel.

    Every pixel equal to 1 in ``y_true`` is used (not only boundary pixels) and
    only the true -> predicted direction is measured. Returns ``inf`` when
    either mask has no pixel equal to 1.
    """
    coords_true = np.transpose(np.nonzero(y_true == 1))
    coords_pred = np.transpose(np.nonzero(y_pred == 1))
    if len(coords_true) == 0 or len(coords_pred) == 0:
        return float('inf')

    nearest = []
    for point in coords_true:
        gaps = np.linalg.norm(coords_pred - point, axis=1)
        nearest.append(np.min(gaps))
    return np.mean(nearest)

# Soft Voting Ensemble Model
class SoftVotingEnsemble(tf.keras.Model):
    """Averages the member outputs and returns the argmax over the last axis."""

    def __init__(self, models, apply_softmax=True):
        """Store the member models and whether their outputs need a softmax."""
        super().__init__()
        self.models = models
        self.apply_softmax = apply_softmax

    def call(self, x, training=False):
        """Return class indices from the unweighted mean of the member outputs."""
        accumulated = 0
        for member in self.models:
            member_out = member(x, training=training)
            if self.apply_softmax:
                member_out = tf.nn.softmax(member_out, axis=-1)
            accumulated = accumulated + member_out
        mean_prob = accumulated / len(self.models)
        return tf.argmax(mean_prob, axis=-1)

# Evaluation function (for normal + ensemble models)
def evaluate_classwise_metrics(model, X_test, y_test, num_classes=4, batch_size=16, is_ensemble=False):
    """Print a table of per-class segmentation metrics averaged over the test images.

    For every image and class index a binary mask (pixel == class) is built for
    both ground truth and prediction; Dice, IoU, precision, recall, F1 and
    accuracy are computed on that pair and averaged per class over all images.

    Args:
        model: Model to score. With ``is_ensemble=True`` it is called directly as
            ``model(batch, training=False)`` and the result's ``.numpy()`` is
            used; otherwise ``model.predict`` is used and the argmax over the
            last axis is taken.
        X_test: Test inputs, sliced along the first axis.
        y_test: One-hot ground truth; class indices are its argmax over the last axis.
        num_classes: Number of class indices (0 .. num_classes - 1) to report.
        batch_size: Batch size used for inference.
        is_ensemble: Selects the direct-call inference path described above.

    Returns:
        None. The table is written to stdout.
    """
    if is_ensemble:
        y_pred = []
        for start in range(0, len(X_test), batch_size):
            preds = model(X_test[start:start + batch_size], training=False).numpy()
            y_pred.extend(preds)
        y_pred = np.array(y_pred)
    else:
        y_pred = model.predict(X_test, batch_size=batch_size)
        y_pred = np.argmax(y_pred, axis=-1)  # [B, H, W]

    y_test_class = np.argmax(y_test, axis=-1)

    # An ensemble may return per-class maps [B, H, W, C] (one-hot or
    # probabilities) instead of class indices [B, H, W]; reduce those to
    # indices so the per-class comparison below sees matching shapes.
    if y_pred.ndim == y_test_class.ndim + 1:
        y_pred = np.argmax(y_pred, axis=-1)

    metric_names = ('dice', 'iou', 'precision', 'recall', 'f1', 'accuracy', 'hausdorff', 'asd')
    class_metrics = {c: {name: [] for name in metric_names} for c in range(num_classes)}

    for i in range(len(X_test)):
        true_mask = y_test_class[i]
        pred_mask = y_pred[i]
        for class_idx in range(num_classes):
            true_class_mask = (true_mask == class_idx).astype(int)
            pred_class_mask = (pred_mask == class_idx).astype(int)
            # Flatten once; the scikit-learn scorers take 1-D label arrays.
            true_flat = true_class_mask.ravel()
            pred_flat = pred_class_mask.ravel()

            scores = class_metrics[class_idx]
            scores['dice'].append(dice_coefficient(true_class_mask, pred_class_mask))
            scores['iou'].append(iou(true_class_mask, pred_class_mask))
            scores['precision'].append(precision_score(true_flat, pred_flat, zero_division=0))
            scores['recall'].append(recall_score(true_flat, pred_flat, zero_division=0))
            scores['f1'].append(f1_score(true_flat, pred_flat, zero_division=0))
            scores['accuracy'].append(accuracy_score(true_flat, pred_flat))
            # Hausdorff / ASD collection is currently disabled:
            # scores['hausdorff'].append(hausdorff_distance(true_class_mask, pred_class_mask))
            # scores['asd'].append(average_surface_distance(true_class_mask, pred_class_mask))

    # Header and rows share one column spec so the numbers line up under their titles.
    columns = (
        ('Dice Coef (%)', 'dice', 15),
        ('IoU (%)', 'iou', 12),
        ('Precision (%)', 'precision', 17),
        ('Recall (%)', 'recall', 15),
        ('F1 Score (%)', 'f1', 17),
        ('Accuracy (%)', 'accuracy', 17),
    )
    header = f"{'Class':<10}" + "".join(f"{title:>{width}}" for title, _, width in columns)
    print(header)
    print("-" * len(header))

    for class_idx in range(num_classes):
        cells = "".join(
            f"{np.mean(class_metrics[class_idx][key]) * 100:>{width}.2f}"
            for _, key, width in columns
        )
        print(f"{class_idx:<10}{cells}")

# Hard-voting ensemble built from model_xception, model_segnet and
# model_inceptionresnetv2.
ensemble_model = HardVotingEnsemble(
    models=[model_xception, model_segnet, model_inceptionresnetv2]
)

# Print the per-class metrics table for the test split using the helper's
# ensemble code path (is_ensemble=True).
evaluate_classwise_metrics(ensemble_model, X_test, y_test, is_ensemble=True)

In [None]:
# Hard-voting ensemble of model_xception, model_segnet and
# model_inceptionresnetv2, scored through Keras' own evaluate().
hard_ensemble = HardVotingEnsemble(
    num_classes=4,
    models=[model_xception, model_segnet, model_inceptionresnetv2],
)

# compile() registers the loss and the per-class metrics that evaluate() reports.
hard_ensemble.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
    metrics=class_wise_metrics(num_classes=4),
    loss=combined_loss,
)

hard_ensemble.evaluate(X_test, y_test, batch_size=16)

In [None]:
# Hard-voting ensemble built from model_segnet and model_efficientnetb4
# (constructor defaults are used for every other option).
ensemble_model = HardVotingEnsemble(models=[model_segnet, model_efficientnetb4])

# Print the per-class metrics table for the test split using the helper's
# ensemble code path (is_ensemble=True).
evaluate_classwise_metrics(ensemble_model, X_test, y_test, is_ensemble=True)

In [None]:
# Hard-voting ensemble of model_segnet and model_efficientnetb4, scored through
# Keras' own evaluate().
# The instance must be bound to `hard_ensemble`: the compile()/evaluate() calls
# below use that name, and binding it to anything else would silently re-score
# whichever ensemble an earlier cell left in `hard_ensemble`.
hard_ensemble = HardVotingEnsemble(
    models=[model_segnet, model_efficientnetb4],
    num_classes=4
)

# compile() registers the loss and the per-class metrics that evaluate() reports.
hard_ensemble.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
    loss=combined_loss,
    metrics=class_wise_metrics(num_classes=4)
)

hard_ensemble.evaluate(X_test, y_test, batch_size=16)

In [None]:
# Hard-voting ensemble built from model_xception and model_efficientnetb4.
ensemble_model = HardVotingEnsemble(
    models=[model_xception, model_efficientnetb4],
    num_classes=4,
)

# Print the per-class metrics table for the test split using the helper's
# ensemble code path (is_ensemble=True).
evaluate_classwise_metrics(ensemble_model, X_test, y_test, is_ensemble=True)

In [None]:
# Hard-voting ensemble of model_xception and model_efficientnetb4, scored
# through Keras' own evaluate().
hard_ensemble = HardVotingEnsemble(num_classes=4, models=[model_xception, model_efficientnetb4])

# compile() registers the loss and the per-class metrics that evaluate() reports.
hard_ensemble.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
    metrics=class_wise_metrics(num_classes=4),
    loss=combined_loss,
)

hard_ensemble.evaluate(X_test, y_test, batch_size=16)

In [None]:
# Hard-voting ensemble built from model_efficientnetb4 and
# model_inceptionresnetv2 (constructor defaults are used for every other option).
ensemble_model = HardVotingEnsemble(models=[model_efficientnetb4, model_inceptionresnetv2])

# Print the per-class metrics table for the test split using the helper's
# ensemble code path (is_ensemble=True).
evaluate_classwise_metrics(ensemble_model, X_test, y_test, is_ensemble=True)

In [None]:
# Hard-voting ensemble of model_efficientnetb4 and model_inceptionresnetv2,
# scored through Keras' own evaluate().
hard_ensemble = HardVotingEnsemble(
    num_classes=4,
    models=[model_efficientnetb4, model_inceptionresnetv2],
)

# compile() registers the loss and the per-class metrics that evaluate() reports.
hard_ensemble.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
    metrics=class_wise_metrics(num_classes=4),
    loss=combined_loss,
)

hard_ensemble.evaluate(X_test, y_test, batch_size=16)

In [None]:
# Hard-voting ensemble built from model_xception, model_segnet and
# model_efficientnetb4.
ensemble_model = HardVotingEnsemble(
    models=[model_xception, model_segnet, model_efficientnetb4]
)

# Print the per-class metrics table for the test split using the helper's
# ensemble code path (is_ensemble=True).
evaluate_classwise_metrics(ensemble_model, X_test, y_test, is_ensemble=True)

In [None]:
# Hard-voting ensemble of model_xception, model_segnet and
# model_efficientnetb4, scored through Keras' own evaluate().
hard_ensemble = HardVotingEnsemble(
    num_classes=4,
    models=[model_xception, model_segnet, model_efficientnetb4],
)

# compile() registers the loss and the per-class metrics that evaluate() reports.
hard_ensemble.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
    metrics=class_wise_metrics(num_classes=4),
    loss=combined_loss,
)

hard_ensemble.evaluate(X_test, y_test, batch_size=16)

In [None]:
# Hard-voting ensemble built from model_xception, model_efficientnetb4 and
# model_inceptionresnetv2.
ensemble_model = HardVotingEnsemble(
    models=[model_xception, model_efficientnetb4, model_inceptionresnetv2]
)

# Print the per-class metrics table for the test split using the helper's
# ensemble code path (is_ensemble=True).
evaluate_classwise_metrics(ensemble_model, X_test, y_test, is_ensemble=True)

In [None]:
# Hard-voting ensemble of model_xception, model_efficientnetb4 and
# model_inceptionresnetv2, scored through Keras' own evaluate().
hard_ensemble = HardVotingEnsemble(
    num_classes=4,
    models=[model_xception, model_efficientnetb4, model_inceptionresnetv2],
)

# compile() registers the loss and the per-class metrics that evaluate() reports.
hard_ensemble.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
    metrics=class_wise_metrics(num_classes=4),
    loss=combined_loss,
)

hard_ensemble.evaluate(X_test, y_test, batch_size=16)

In [None]:
# Hard-voting ensemble built from model_efficientnetb4, model_segnet and
# model_inceptionresnetv2.
ensemble_model = HardVotingEnsemble(
    models=[model_efficientnetb4, model_segnet, model_inceptionresnetv2]
)

# Print the per-class metrics table for the test split using the helper's
# ensemble code path (is_ensemble=True).
evaluate_classwise_metrics(ensemble_model, X_test, y_test, is_ensemble=True)

In [None]:
# Hard-voting ensemble of model_efficientnetb4, model_segnet and
# model_inceptionresnetv2, scored through Keras' own evaluate().
hard_ensemble = HardVotingEnsemble(
    num_classes=4,
    models=[model_efficientnetb4, model_segnet, model_inceptionresnetv2],
)

# compile() registers the loss and the per-class metrics that evaluate() reports.
hard_ensemble.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
    metrics=class_wise_metrics(num_classes=4),
    loss=combined_loss,
)

hard_ensemble.evaluate(X_test, y_test, batch_size=16)

In [None]:
# Hard-voting ensemble built from all four models.
ensemble_model = HardVotingEnsemble(
    models=[model_xception, model_segnet, model_inceptionresnetv2, model_efficientnetb4]
)

# Print the per-class metrics table for the test split using the helper's
# ensemble code path (is_ensemble=True).
evaluate_classwise_metrics(ensemble_model, X_test, y_test, is_ensemble=True)

In [None]:
# Hard-voting ensemble of all four models, scored through Keras' own evaluate().
hard_ensemble = HardVotingEnsemble(
    num_classes=4,
    models=[model_xception, model_segnet, model_inceptionresnetv2, model_efficientnetb4],
)

# compile() registers the loss and the per-class metrics that evaluate() reports.
hard_ensemble.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
    metrics=class_wise_metrics(num_classes=4),
    loss=combined_loss,
)

hard_ensemble.evaluate(X_test, y_test, batch_size=16)

<h1>Weighted Soft Voting</h1>

In [None]:
import optuna
import tensorflow as tf
import numpy as np

# Softmax wrapper
def get_softmax_preds(model, dataset):
    """Collect per-class probabilities from ``model`` for every batch of ``dataset``.

    Args:
        model: Callable as ``model(x_batch, training=False)``.
        dataset: Iterable of ``(x_batch, y_batch)`` pairs; the labels are ignored.

    Returns:
        The per-batch outputs concatenated along axis 0.
    """
    # Mirror WeightedSoftVotingEnsemble.call: members whose name contains
    # "efficientnet" are treated there as already emitting probabilities, so
    # they must not be softmaxed a second time here. Otherwise the weights
    # would be tuned on different probabilities than the ensemble uses.
    already_softmaxed = hasattr(model, "name") and "efficientnet" in model.name.lower()

    preds = []
    for x_batch, _ in dataset:
        output = model(x_batch, training=False)
        probs = output if already_softmaxed else tf.nn.softmax(output, axis=-1)  # [B, H, W, C]
        preds.append(probs)
    return tf.concat(preds, axis=0)

# One-hot true labels from dataset
def get_ground_truth(dataset):
    """Concatenate the label half of every ``(x, y)`` batch in ``dataset`` along axis 0."""
    labels = []
    for _, y_batch in dataset:
        labels.append(y_batch)
    return tf.concat(labels, axis=0)

# Dice calculation
def dice_score(y_true, y_pred, smooth=1e-6):
    """Mean smoothed Dice score, returned as a NumPy scalar.

    Products and sums are reduced over axes 1 and 2, giving one Dice value per
    remaining index; those values are then averaged.
    """
    spatial_axes = [1, 2]
    overlap = tf.reduce_sum(y_true * y_pred, axis=spatial_axes)
    mass = tf.reduce_sum(y_true, axis=spatial_axes) + tf.reduce_sum(y_pred, axis=spatial_axes)
    per_item = (2. * overlap + smooth) / (mass + smooth)
    return tf.reduce_mean(per_item).numpy()

# Optuna objective function
def objective(trial):
    """Score one candidate set of ensemble weights by validation Dice.

    Reads the module-level ``models``, ``all_softmax_preds`` and ``y_true``.
    One raw weight per model is sampled from [0, 1] under the names
    ``w0``, ``w1``, ...; the weights are normalised to sum to 1 before the
    cached predictions are blended.

    Raises:
        optuna.TrialPruned: If every sampled weight is exactly 0.
    """
    # Sample weights and normalize them
    raw_weights = [trial.suggest_float(f'w{i}', 0.0, 1.0) for i in range(len(models))]
    total = sum(raw_weights)
    if total == 0:
        # The range includes 0.0, so all weights can be zero. Prune the trial
        # instead of aborting the whole study with a ZeroDivisionError.
        raise optuna.TrialPruned()
    weights = [w / total for w in raw_weights]

    # Weighted ensemble
    ensemble_pred = sum(w * p for w, p in zip(weights, all_softmax_preds))
    return dice_score(y_true, ensemble_pred)

# Step 1: cache every member's validation predictions so that each trial only
# has to re-weight them.
models = [model_xception, model_segnet, model_inceptionresnetv2, model_efficientnetb4]
all_softmax_preds = [get_softmax_preds(member, val_dataset) for member in models]

# Step 2: the matching validation labels.
y_true = get_ground_truth(val_dataset)

# Step 3: 50 Optuna trials maximising the Dice objective.
study = optuna.create_study(direction='maximize')
study.optimize(objective, n_trials=50)

# Step 4: re-normalise the best trial's raw weights and round to 4 decimals.
best_raw = [study.best_trial.params[f'w{idx}'] for idx in range(len(models))]
total = sum(best_raw)
best_weights = [round(raw / total, 4) for raw in best_raw]

print("‚úÖ Best weights (normalized, rounded):", best_weights)


In [None]:
import tensorflow as tf

class WeightedSoftVotingEnsemble(tf.keras.Model):
    """Weighted soft-voting ensemble that returns one-hot class predictions.

    Each member's output is turned into class probabilities, scaled by that
    member's normalised weight and summed; the argmax of the sum over the last
    axis is returned one-hot encoded.
    """

    def __init__(self, models, weights=None, apply_softmax=True):
        """Store the members and their normalised weights.

        Args:
            models: Sequence of member models, each callable as
                ``model(x, training=...)``.
            weights: Optional per-model weights, one per entry of ``models``.
                They are rescaled to sum to 1; ``None`` means equal weights.
            apply_softmax: Apply a softmax over the last axis to member outputs.
                Members whose name contains "efficientnet" are always used as-is.

        Raises:
            ValueError: If ``weights`` does not have one entry per model.
        """
        super(WeightedSoftVotingEnsemble, self).__init__()
        self.models = models
        self.apply_softmax = apply_softmax

        if weights is None:
            weights = [1.0 / len(models)] * len(models)
        else:
            # A length mismatch would otherwise surface only inside call() (too
            # few weights) or silently skew the normalisation (too many).
            if len(weights) != len(models):
                raise ValueError(
                    f"Expected {len(models)} weights (one per model), got {len(weights)}"
                )
            total = sum(weights)
            weights = [w / total for w in weights]

        self.model_weights = tf.constant(weights, dtype=tf.float32)

    def call(self, x, training=False):
        """Return the one-hot weighted vote of all members for the batch ``x``."""
        weighted_sum = 0
        for i, model in enumerate(self.models):
            output = model(x, training=training)

            # Members named like EfficientNet are treated as already emitting
            # probabilities, so the softmax is skipped for them.
            # NOTE(review): this relies on the model's name — confirm it matches
            # how the members were built.
            is_softmaxed = (
                hasattr(model, "name") and "efficientnet" in model.name.lower()
            )

            if self.apply_softmax and not is_softmaxed:
                probs = tf.nn.softmax(output, axis=-1)
            else:
                probs = output

            weighted_sum += self.model_weights[i] * probs

        avg_prob = weighted_sum  # shape: [B, H, W, C]

        # Hard decision over the last axis, re-encoded one-hot for metric compatibility.
        one_hot_pred = tf.one_hot(tf.argmax(avg_prob, axis=-1), depth=avg_prob.shape[-1])
        return one_hot_pred  # [B, H, W, C]

        
# Equal raw weights; the ensemble constructor rescales them to sum to 1.
weights = [1, 1, 1, 1]

ensemble_model = WeightedSoftVotingEnsemble(
    weights=weights,
    models=[model_xception, model_segnet, model_inceptionresnetv2, model_efficientnetb4],
)

# compile() registers the loss and the per-class metrics that evaluate() reports.
ensemble_model.compile(metrics=class_wise_metrics(num_classes=4), loss=combined_loss)

ensemble_model.evaluate(X_test, y_test, batch_size=16)

In [None]:
import tensorflow as tf
import optuna
import numpy as np

# === Trained ensemble members ===
# Their order fixes which Optuna parameter (w0, w1, ...) belongs to which model.
models = [model_xception, model_segnet, model_inceptionresnetv2, model_efficientnetb4]

import tensorflow as tf
import optuna
import numpy as np

# === INPUTS ===
# models: your list of trained models
# X_val, y_val: validation data as numpy arrays or tensors (one-hot encoded y_val)
# Example: y_val shape = [B, H, W, C], with one-hot encoding

# === STEP 1: Batch-wise softmax predictions from each model ===
def get_softmax_preds_from_array(model, X, batch_size=16):
    """Run ``model`` over ``X`` in batches and return per-class probabilities.

    Args:
        model: Callable as ``model(x_batch, training=False)``.
        X: Inputs, sliced along the first axis.
        batch_size: Number of samples per forward pass.

    Returns:
        The per-batch outputs concatenated along axis 0.
    """
    # Mirror WeightedSoftVotingEnsemble.call: members whose name contains
    # "efficientnet" are treated there as already emitting probabilities, so
    # they must not be softmaxed a second time here. Otherwise the weights
    # would be tuned on different probabilities than the ensemble uses.
    already_softmaxed = hasattr(model, "name") and "efficientnet" in model.name.lower()

    soft_preds = []
    for i in range(0, len(X), batch_size):
        x_batch = X[i:i+batch_size]
        output = model(x_batch, training=False)
        probs = output if already_softmaxed else tf.nn.softmax(output, axis=-1)
        soft_preds.append(probs)
    return tf.concat(soft_preds, axis=0)

# === STEP 2: Mean Dice (one-hot predictions only) ===
def mean_dice_per_class(y_true, y_pred_soft, smooth=1e-6):
    """Mean over classes of the Dice score between truth and hard predictions.

    ``y_pred_soft`` is first reduced to a one-hot decision (argmax over the last
    axis). For each index of ``y_true``'s last axis, one smoothed Dice value is
    computed over all remaining elements; the values are averaged and returned
    as a NumPy scalar.
    """
    n_pred_classes = y_pred_soft.shape[-1]
    hard_pred = tf.one_hot(tf.argmax(y_pred_soft, axis=-1), depth=n_pred_classes)  # [B, H, W, C]

    per_class = []
    for c in range(y_true.shape[-1]):
        truth_c = y_true[..., c]
        guess_c = hard_pred[..., c]
        overlap = tf.reduce_sum(truth_c * guess_c)
        mass = tf.reduce_sum(truth_c) + tf.reduce_sum(guess_c)
        per_class.append((2. * overlap + smooth) / (mass + smooth))
    return tf.reduce_mean(tf.stack(per_class)).numpy()

# === STEP 3: cache each member's validation predictions and the labels ===
print("‚úÖ Precomputing predictions from all models...")
soft_preds_all = [get_softmax_preds_from_array(member, X_val) for member in models]
y_true_val = tf.convert_to_tensor(y_val, dtype=tf.float32)

# === STEP 4: Optuna Objective Function ===
def objective(trial):
    """Score one candidate set of ensemble weights by mean per-class Dice.

    Reads the module-level ``models``, ``soft_preds_all`` and ``y_true_val``.
    One raw weight per model is sampled from [0, 1] under the names
    ``w0``, ``w1``, ...; the weights are normalised to sum to 1 before the
    cached predictions are blended.

    Raises:
        optuna.TrialPruned: If every sampled weight is exactly 0.
    """
    raw_weights = [trial.suggest_float(f'w{i}', 0.0, 1.0) for i in range(len(models))]
    total = sum(raw_weights)
    if total == 0:
        # The range includes 0.0, so all weights can be zero. Prune the trial
        # instead of aborting the whole study with a ZeroDivisionError.
        raise optuna.TrialPruned()
    weights = [w / total for w in raw_weights]

    # Weighted average of soft predictions
    weighted_avg = sum(w * p for w, p in zip(weights, soft_preds_all))

    # Mean Dice is computed on the hard (one-hot) decision of the blend.
    score = mean_dice_per_class(y_true_val, weighted_avg)
    return score

# === STEP 5: 100 Optuna trials maximising the objective ===
print("üéØ Running Optuna for best ensemble weights...")
study = optuna.create_study(direction='maximize')
study.optimize(objective, n_trials=100)

# === STEP 6: re-normalise the best trial's raw weights, rounded to 4 decimals ===
best_raw_weights = [study.best_trial.params[f"w{idx}"] for idx in range(len(models))]
total = sum(best_raw_weights)
final_weights = [round(raw / total, 4) for raw in best_raw_weights]

print("\n‚úÖ Best Ensemble Weights (rounded):", final_weights)
print("‚úÖ Best Mean Dice Score:", round(study.best_value, 5))

# === Final WeightedSoftVotingEnsemble using best weights ===
class WeightedSoftVotingEnsemble(tf.keras.Model):
    """Weighted soft-voting ensemble that returns one-hot class predictions.

    Each member's output is turned into class probabilities, scaled by that
    member's normalised weight and summed; the argmax of the sum over the last
    axis is returned one-hot encoded.
    """

    def __init__(self, models, weights=None, apply_softmax=True):
        """Store the members and their normalised weights.

        Args:
            models: Sequence of member models, each callable as
                ``model(x, training=...)``.
            weights: Optional per-model weights, one per entry of ``models``.
                They are rescaled to sum to 1; ``None`` means equal weights.
            apply_softmax: Apply a softmax over the last axis to member outputs.
                Members whose name contains "efficientnet" are always used as-is.

        Raises:
            ValueError: If ``weights`` does not have one entry per model.
        """
        super(WeightedSoftVotingEnsemble, self).__init__()
        self.models = models
        self.apply_softmax = apply_softmax

        if weights is None:
            weights = [1.0 / len(models)] * len(models)
        else:
            # A length mismatch would otherwise surface only inside call() (too
            # few weights) or silently skew the normalisation (too many).
            if len(weights) != len(models):
                raise ValueError(
                    f"Expected {len(models)} weights (one per model), got {len(weights)}"
                )
            total = sum(weights)
            weights = [w / total for w in weights]

        self.model_weights = tf.constant(weights, dtype=tf.float32)

    def call(self, x, training=False):
        """Return the one-hot weighted vote of all members for the batch ``x``."""
        weighted_sum = 0
        for i, model in enumerate(self.models):
            output = model(x, training=training)

            # Members named like EfficientNet are treated as already emitting
            # probabilities, so the softmax is skipped for them.
            # NOTE(review): this relies on the model's name — confirm it matches
            # how the members were built.
            is_softmaxed = (
                hasattr(model, "name") and "efficientnet" in model.name.lower()
            )

            if self.apply_softmax and not is_softmaxed:
                probs = tf.nn.softmax(output, axis=-1)
            else:
                probs = output

            weighted_sum += self.model_weights[i] * probs

        avg_prob = weighted_sum  # shape: [B, H, W, C]

        # Hard decision over the last axis, re-encoded one-hot for metric compatibility.
        one_hot_pred = tf.one_hot(tf.argmax(avg_prob, axis=-1), depth=avg_prob.shape[-1])
        return one_hot_pred  # [B, H, W, C]

# === Final ensemble: all members in `models`, blended with the tuned weights ===
ensemble_model = WeightedSoftVotingEnsemble(models=models, weights=final_weights, apply_softmax=True)

In [None]:
# Weighted soft-voting ensemble over `models` using `final_weights`, scored
# through Keras' own evaluate().
ensemble_model = WeightedSoftVotingEnsemble(models=models, weights=final_weights, apply_softmax=True)

# compile() registers the loss and the per-class metrics that evaluate() reports.
ensemble_model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
    metrics=class_wise_metrics(num_classes=4),
    loss=combined_loss,
)

ensemble_model.evaluate(X_test, y_test, batch_size=16)

In [None]:
# Show the current per-model ensemble weights as this cell's output.
final_weights

In [None]:
# === Final WeightedSoftVotingEnsemble using best weights ===
class WeightedSoftVotingEnsemble(tf.keras.Model):
    """Weighted soft-voting ensemble over several segmentation models.

    Every member model is run on the same input, its output is turned into
    class probabilities, the probabilities are combined as a weighted sum and
    the arg-max class of that sum is returned as a one-hot tensor.

    Args:
        models: Sequence of callable Keras models whose outputs share one
            shape, with the class axis last.
        weights: Optional sequence with exactly one number per model. The
            values are rescaled so they sum to 1. ``None`` gives every model
            the same weight.
        apply_softmax: When True, a softmax over the last axis is applied to
            the output of every model whose ``name`` does not contain
            "efficientnet"; outputs of models matching that name are used
            as-is.

    Raises:
        ValueError: If ``models`` is empty, if the number of weights differs
            from the number of models, or if the weights do not sum to a
            positive number.
    """

    def __init__(self, models, weights=None, apply_softmax=True):
        super(WeightedSoftVotingEnsemble, self).__init__()

        if len(models) == 0:
            raise ValueError("WeightedSoftVotingEnsemble needs at least one model")

        self.models = models
        self.apply_softmax = apply_softmax

        if weights is None:
            weights = [1.0 / len(models)] * len(models)
        else:
            weights = list(weights)
            # In a notebook `models` and `final_weights` are often reassigned
            # in different cells; a stale pair would otherwise be normalised
            # and used without any visible error.
            if len(weights) != len(models):
                raise ValueError(
                    f"Expected {len(models)} weights (one per model), got {len(weights)}"
                )
            total = sum(weights)
            if not total > 0:
                raise ValueError(
                    f"Ensemble weights must sum to a positive value, got {total}"
                )
            weights = [w / total for w in weights]

        self.model_weights = tf.constant(weights, dtype=tf.float32)

    def call(self, x, training=False):
        """Return the one-hot arg-max of the weighted sum of member probabilities."""
        weighted_sum = 0
        for i, model in enumerate(self.models):
            output = model(x, training=training)

            # Name-based heuristic: models with "efficientnet" in their name
            # are treated as already producing probabilities.
            is_softmaxed = (
                hasattr(model, "name") and "efficientnet" in model.name.lower()
            )

            if self.apply_softmax and not is_softmaxed:
                probs = tf.nn.softmax(output, axis=-1)
            else:
                probs = output

            weighted_sum += self.model_weights[i] * probs

        avg_prob = weighted_sum  # weights sum to 1, so this is a weighted average

        # Hard (one-hot) predictions are returned instead of probabilities so
        # the output has the same form as one-hot targets.
        one_hot_pred = tf.one_hot(tf.argmax(avg_prob, axis=-1), depth=avg_prob.shape[-1])
        return one_hot_pred

In [None]:
# Four-model ensemble; the order of this list must match the order of the
# weights below (the i-th weight is applied to the i-th model).
models = [
    model_xception,
    model_segnet,
    model_inceptionresnetv2,
    model_efficientnetb4
]

# Hard-coded weights, one per model, summing to 1.0.
# NOTE(review): presumably copied from an earlier Optuna run — confirm the source.
final_weights = [0.3717, 0.301, 0.1892, 0.1381]

In [None]:

# Build the ensemble from the `models` / `final_weights` currently held in the
# notebook namespace (neither is assigned in this cell).
ensemble_model = WeightedSoftVotingEnsemble(
    models=models,
    weights=final_weights,
    apply_softmax=True
)

# compile() attaches the loss and metrics that evaluate() reports; this cell
# never calls fit(), so the optimizer is not used for training here.
ensemble_model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
    loss=combined_loss,
    metrics=class_wise_metrics(num_classes=4)
)

# Score the ensemble on the test split in batches of 16.
ensemble_model.evaluate(X_test, y_test, batch_size=16)

In [None]:
# Ensemble experiment: Xception + InceptionResNetV2 + EfficientNetB4.
# The same model combination also appears in other cells of this notebook.
models = [
    model_xception,
    model_inceptionresnetv2,
    model_efficientnetb4
]

# Per-model outputs of get_softmax_preds_from_array on X_val, computed once
# before the search.
# NOTE(review): `objective` is defined outside this cell; presumably it reads
# `soft_preds_all` and `y_true_val` from the notebook namespace — confirm.
print("‚úÖ Precomputing predictions from all models...")
soft_preds_all = [get_softmax_preds_from_array(m, X_val) for m in models]
y_true_val = tf.convert_to_tensor(y_val, dtype=tf.float32)

# 50-trial Optuna search maximising `objective`; no sampler seed is passed
# here, so repeated runs may pick different weights.
print("üéØ Running Optuna for best ensemble weights...")
study = optuna.create_study(direction='maximize')
study.optimize(objective, n_trials=50)

# === STEP 6: Extract & Format Best Weights ===
# Read parameters w0..w{n-1} from the best trial and rescale them to sum to 1
# (each rounded to 4 decimals).
best_raw_weights = [study.best_trial.params[f"w{i}"] for i in range(len(models))]
total = sum(best_raw_weights)
final_weights = [round(w / total, 4) for w in best_raw_weights]

print("\n‚úÖ Best Ensemble Weights (rounded):", final_weights)
print("‚úÖ Best Mean Dice Score:", round(study.best_value, 5))

print(final_weights)
# Build the ensemble with the tuned weights and score it on the test split.
ensemble_model = WeightedSoftVotingEnsemble(
    models=models,
    weights=final_weights,
    apply_softmax=True
)

# compile() attaches the loss and metrics that evaluate() reports; this cell
# never calls fit().
ensemble_model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
    loss=combined_loss,
    metrics=class_wise_metrics(num_classes=4)
)

ensemble_model.evaluate(X_test, y_test, batch_size=16)

In [None]:
# Ensemble experiment: Xception + InceptionResNetV2 + SegNet.
models = [
    model_xception,
    model_inceptionresnetv2,
    model_segnet
]

# Per-model outputs of get_softmax_preds_from_array on X_val, computed once
# before the search.
# NOTE(review): `objective` is defined outside this cell; presumably it reads
# `soft_preds_all` and `y_true_val` from the notebook namespace — confirm.
print("‚úÖ Precomputing predictions from all models...")
soft_preds_all = [get_softmax_preds_from_array(m, X_val) for m in models]
y_true_val = tf.convert_to_tensor(y_val, dtype=tf.float32)

# 50-trial Optuna search maximising `objective`; no sampler seed is passed
# here, so repeated runs may pick different weights.
print("üéØ Running Optuna for best ensemble weights...")
study = optuna.create_study(direction='maximize')
study.optimize(objective, n_trials=50)

# === STEP 6: Extract & Format Best Weights ===
# Read parameters w0..w{n-1} from the best trial and rescale them to sum to 1
# (each rounded to 4 decimals).
best_raw_weights = [study.best_trial.params[f"w{i}"] for i in range(len(models))]
total = sum(best_raw_weights)
final_weights = [round(w / total, 4) for w in best_raw_weights]

print("\n‚úÖ Best Ensemble Weights (rounded):", final_weights)
print("‚úÖ Best Mean Dice Score:", round(study.best_value, 5))

print(final_weights)
# Build the ensemble with the tuned weights and score it on the test split.
ensemble_model = WeightedSoftVotingEnsemble(
    models=models,
    weights=final_weights,
    apply_softmax=True
)

# compile() attaches the loss and metrics that evaluate() reports; this cell
# never calls fit().
ensemble_model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
    loss=combined_loss,
    metrics=class_wise_metrics(num_classes=4)
)

ensemble_model.evaluate(X_test, y_test, batch_size=16)

In [None]:
# Ensemble experiment: SegNet + InceptionResNetV2 + EfficientNetB4.
models = [
    model_segnet,
    model_inceptionresnetv2,
    model_efficientnetb4
]

# Per-model outputs of get_softmax_preds_from_array on X_val, computed once
# before the search.
# NOTE(review): `objective` is defined outside this cell; presumably it reads
# `soft_preds_all` and `y_true_val` from the notebook namespace — confirm.
print("‚úÖ Precomputing predictions from all models...")
soft_preds_all = [get_softmax_preds_from_array(m, X_val) for m in models]
y_true_val = tf.convert_to_tensor(y_val, dtype=tf.float32)

# 50-trial Optuna search maximising `objective`; no sampler seed is passed
# here, so repeated runs may pick different weights.
print("üéØ Running Optuna for best ensemble weights...")
study = optuna.create_study(direction='maximize')
study.optimize(objective, n_trials=50)

# === STEP 6: Extract & Format Best Weights ===
# Read parameters w0..w{n-1} from the best trial and rescale them to sum to 1
# (each rounded to 4 decimals).
best_raw_weights = [study.best_trial.params[f"w{i}"] for i in range(len(models))]
total = sum(best_raw_weights)
final_weights = [round(w / total, 4) for w in best_raw_weights]

print("\n‚úÖ Best Ensemble Weights (rounded):", final_weights)
print("‚úÖ Best Mean Dice Score:", round(study.best_value, 5))

print(final_weights)
# Build the ensemble with the tuned weights and score it on the test split.
ensemble_model = WeightedSoftVotingEnsemble(
    models=models,
    weights=final_weights,
    apply_softmax=True
)

# compile() attaches the loss and metrics that evaluate() reports; this cell
# never calls fit().
ensemble_model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
    loss=combined_loss,
    metrics=class_wise_metrics(num_classes=4)
)

ensemble_model.evaluate(X_test, y_test, batch_size=16)

In [None]:
# Ensemble experiment: SegNet + Xception + EfficientNetB4.
models = [
    model_segnet,
    model_xception,
    model_efficientnetb4
]

# Per-model outputs of get_softmax_preds_from_array on X_val, computed once
# before the search.
# NOTE(review): `objective` is defined outside this cell; presumably it reads
# `soft_preds_all` and `y_true_val` from the notebook namespace — confirm.
print("‚úÖ Precomputing predictions from all models...")
soft_preds_all = [get_softmax_preds_from_array(m, X_val) for m in models]
y_true_val = tf.convert_to_tensor(y_val, dtype=tf.float32)

# 50-trial Optuna search maximising `objective`; no sampler seed is passed
# here, so repeated runs may pick different weights.
print("üéØ Running Optuna for best ensemble weights...")
study = optuna.create_study(direction='maximize')
study.optimize(objective, n_trials=50)

# === STEP 6: Extract & Format Best Weights ===
# Read parameters w0..w{n-1} from the best trial and rescale them to sum to 1
# (each rounded to 4 decimals).
best_raw_weights = [study.best_trial.params[f"w{i}"] for i in range(len(models))]
total = sum(best_raw_weights)
final_weights = [round(w / total, 4) for w in best_raw_weights]

print("\n‚úÖ Best Ensemble Weights (rounded):", final_weights)
print("‚úÖ Best Mean Dice Score:", round(study.best_value, 5))

print(final_weights)
# Build the ensemble with the tuned weights and score it on the test split.
ensemble_model = WeightedSoftVotingEnsemble(
    models=models,
    weights=final_weights,
    apply_softmax=True
)

# compile() attaches the loss and metrics that evaluate() reports; this cell
# never calls fit().
ensemble_model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
    loss=combined_loss,
    metrics=class_wise_metrics(num_classes=4)
)

ensemble_model.evaluate(X_test, y_test, batch_size=16)

In [None]:
# Ensemble experiment: Xception + InceptionResNetV2.
models = [
    model_xception,
    model_inceptionresnetv2
]

# Per-model outputs of get_softmax_preds_from_array on X_val, computed once
# before the search.
# NOTE(review): `objective` is defined outside this cell; presumably it reads
# `soft_preds_all` and `y_true_val` from the notebook namespace — confirm.
print("‚úÖ Precomputing predictions from all models...")
soft_preds_all = [get_softmax_preds_from_array(m, X_val) for m in models]
y_true_val = tf.convert_to_tensor(y_val, dtype=tf.float32)

# 50-trial Optuna search maximising `objective`; no sampler seed is passed
# here, so repeated runs may pick different weights.
print("üéØ Running Optuna for best ensemble weights...")
study = optuna.create_study(direction='maximize')
study.optimize(objective, n_trials=50)

# === STEP 6: Extract & Format Best Weights ===
# Read parameters w0..w{n-1} from the best trial and rescale them to sum to 1
# (each rounded to 4 decimals).
best_raw_weights = [study.best_trial.params[f"w{i}"] for i in range(len(models))]
total = sum(best_raw_weights)
final_weights = [round(w / total, 4) for w in best_raw_weights]

print("\n‚úÖ Best Ensemble Weights (rounded):", final_weights)
print("‚úÖ Best Mean Dice Score:", round(study.best_value, 5))

print(final_weights)
# Build the ensemble with the tuned weights and score it on the test split.
ensemble_model = WeightedSoftVotingEnsemble(
    models=models,
    weights=final_weights,
    apply_softmax=True
)

# compile() attaches the loss and metrics that evaluate() reports; this cell
# never calls fit().
ensemble_model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
    loss=combined_loss,
    metrics=class_wise_metrics(num_classes=4)
)

ensemble_model.evaluate(X_test, y_test, batch_size=16)

In [None]:
# Ensemble experiment: InceptionResNetV2 + EfficientNetB4.
models = [
    model_inceptionresnetv2,
    model_efficientnetb4
]

# Per-model outputs of get_softmax_preds_from_array on X_val, computed once
# before the search.
# NOTE(review): `objective` is defined outside this cell; presumably it reads
# `soft_preds_all` and `y_true_val` from the notebook namespace — confirm.
print("‚úÖ Precomputing predictions from all models...")
soft_preds_all = [get_softmax_preds_from_array(m, X_val) for m in models]
y_true_val = tf.convert_to_tensor(y_val, dtype=tf.float32)

# 50-trial Optuna search maximising `objective`; no sampler seed is passed
# here, so repeated runs may pick different weights.
print("üéØ Running Optuna for best ensemble weights...")
study = optuna.create_study(direction='maximize')
study.optimize(objective, n_trials=50)

# === STEP 6: Extract & Format Best Weights ===
# Read parameters w0..w{n-1} from the best trial and rescale them to sum to 1
# (each rounded to 4 decimals).
best_raw_weights = [study.best_trial.params[f"w{i}"] for i in range(len(models))]
total = sum(best_raw_weights)
final_weights = [round(w / total, 4) for w in best_raw_weights]

print("\n‚úÖ Best Ensemble Weights (rounded):", final_weights)
print("‚úÖ Best Mean Dice Score:", round(study.best_value, 5))

print(final_weights)
# Build the ensemble with the tuned weights and score it on the test split.
ensemble_model = WeightedSoftVotingEnsemble(
    models=models,
    weights=final_weights,
    apply_softmax=True
)

# compile() attaches the loss and metrics that evaluate() reports; this cell
# never calls fit().
ensemble_model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
    loss=combined_loss,
    metrics=class_wise_metrics(num_classes=4)
)

ensemble_model.evaluate(X_test, y_test, batch_size=16)

In [None]:
# Ensemble experiment: InceptionResNetV2 + SegNet.
models = [
    model_inceptionresnetv2,
    model_segnet
]

# Per-model outputs of get_softmax_preds_from_array on X_val, computed once
# before the search.
# NOTE(review): `objective` is defined outside this cell; presumably it reads
# `soft_preds_all` and `y_true_val` from the notebook namespace — confirm.
print("‚úÖ Precomputing predictions from all models...")
soft_preds_all = [get_softmax_preds_from_array(m, X_val) for m in models]
y_true_val = tf.convert_to_tensor(y_val, dtype=tf.float32)

# 50-trial Optuna search maximising `objective`; no sampler seed is passed
# here, so repeated runs may pick different weights.
print("üéØ Running Optuna for best ensemble weights...")
study = optuna.create_study(direction='maximize')
study.optimize(objective, n_trials=50)

# === STEP 6: Extract & Format Best Weights ===
# Read parameters w0..w{n-1} from the best trial and rescale them to sum to 1
# (each rounded to 4 decimals).
best_raw_weights = [study.best_trial.params[f"w{i}"] for i in range(len(models))]
total = sum(best_raw_weights)
final_weights = [round(w / total, 4) for w in best_raw_weights]

print("\n‚úÖ Best Ensemble Weights (rounded):", final_weights)
print("‚úÖ Best Mean Dice Score:", round(study.best_value, 5))

print(final_weights)
# Build the ensemble with the tuned weights and score it on the test split.
ensemble_model = WeightedSoftVotingEnsemble(
    models=models,
    weights=final_weights,
    apply_softmax=True
)

# compile() attaches the loss and metrics that evaluate() reports; this cell
# never calls fit().
ensemble_model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
    loss=combined_loss,
    metrics=class_wise_metrics(num_classes=4)
)

ensemble_model.evaluate(X_test, y_test, batch_size=16)

In [None]:
# Ensemble experiment: Xception + SegNet.
models = [
    model_xception,
    model_segnet
]

# Per-model outputs of get_softmax_preds_from_array on X_val, computed once
# before the search.
# NOTE(review): `objective` is defined outside this cell; presumably it reads
# `soft_preds_all` and `y_true_val` from the notebook namespace — confirm.
print("‚úÖ Precomputing predictions from all models...")
soft_preds_all = [get_softmax_preds_from_array(m, X_val) for m in models]
y_true_val = tf.convert_to_tensor(y_val, dtype=tf.float32)

# 50-trial Optuna search maximising `objective`; no sampler seed is passed
# here, so repeated runs may pick different weights.
print("üéØ Running Optuna for best ensemble weights...")
study = optuna.create_study(direction='maximize')
study.optimize(objective, n_trials=50)

# === STEP 6: Extract & Format Best Weights ===
# Read parameters w0..w{n-1} from the best trial and rescale them to sum to 1
# (each rounded to 4 decimals).
best_raw_weights = [study.best_trial.params[f"w{i}"] for i in range(len(models))]
total = sum(best_raw_weights)
final_weights = [round(w / total, 4) for w in best_raw_weights]

print("\n‚úÖ Best Ensemble Weights (rounded):", final_weights)
print("‚úÖ Best Mean Dice Score:", round(study.best_value, 5))

print(final_weights)
# Build the ensemble with the tuned weights and score it on the test split.
ensemble_model = WeightedSoftVotingEnsemble(
    models=models,
    weights=final_weights,
    apply_softmax=True
)

# compile() attaches the loss and metrics that evaluate() reports; this cell
# never calls fit().
ensemble_model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
    loss=combined_loss,
    metrics=class_wise_metrics(num_classes=4)
)

ensemble_model.evaluate(X_test, y_test, batch_size=16)

In [None]:
# Ensemble experiment: Xception + EfficientNetB4.
models = [
    model_xception,
    model_efficientnetb4
]

# Per-model outputs of get_softmax_preds_from_array on X_val, computed once
# before the search.
# NOTE(review): `objective` is defined outside this cell; presumably it reads
# `soft_preds_all` and `y_true_val` from the notebook namespace — confirm.
print("‚úÖ Precomputing predictions from all models...")
soft_preds_all = [get_softmax_preds_from_array(m, X_val) for m in models]
y_true_val = tf.convert_to_tensor(y_val, dtype=tf.float32)

# 50-trial Optuna search maximising `objective`; no sampler seed is passed
# here, so repeated runs may pick different weights.
print("üéØ Running Optuna for best ensemble weights...")
study = optuna.create_study(direction='maximize')
study.optimize(objective, n_trials=50)

# === STEP 6: Extract & Format Best Weights ===
# Read parameters w0..w{n-1} from the best trial and rescale them to sum to 1
# (each rounded to 4 decimals).
best_raw_weights = [study.best_trial.params[f"w{i}"] for i in range(len(models))]
total = sum(best_raw_weights)
final_weights = [round(w / total, 4) for w in best_raw_weights]

print("\n‚úÖ Best Ensemble Weights (rounded):", final_weights)
print("‚úÖ Best Mean Dice Score:", round(study.best_value, 5))

print(final_weights)
# Build the ensemble with the tuned weights and score it on the test split.
ensemble_model = WeightedSoftVotingEnsemble(
    models=models,
    weights=final_weights,
    apply_softmax=True
)

# compile() attaches the loss and metrics that evaluate() reports; this cell
# never calls fit().
ensemble_model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
    loss=combined_loss,
    metrics=class_wise_metrics(num_classes=4)
)

ensemble_model.evaluate(X_test, y_test, batch_size=16)

In [None]:
# Ensemble experiment: SegNet + EfficientNetB4.
models = [
    model_segnet,
    model_efficientnetb4
]

# Per-model outputs of get_softmax_preds_from_array on X_val, computed once
# before the search.
# NOTE(review): `objective` is defined outside this cell; presumably it reads
# `soft_preds_all` and `y_true_val` from the notebook namespace — confirm.
print("‚úÖ Precomputing predictions from all models...")
soft_preds_all = [get_softmax_preds_from_array(m, X_val) for m in models]
y_true_val = tf.convert_to_tensor(y_val, dtype=tf.float32)

# 50-trial Optuna search maximising `objective`; no sampler seed is passed
# here, so repeated runs may pick different weights.
print("üéØ Running Optuna for best ensemble weights...")
study = optuna.create_study(direction='maximize')
study.optimize(objective, n_trials=50)

# === STEP 6: Extract & Format Best Weights ===
# Read parameters w0..w{n-1} from the best trial and rescale them to sum to 1
# (each rounded to 4 decimals).
best_raw_weights = [study.best_trial.params[f"w{i}"] for i in range(len(models))]
total = sum(best_raw_weights)
final_weights = [round(w / total, 4) for w in best_raw_weights]

print("\n‚úÖ Best Ensemble Weights (rounded):", final_weights)
print("‚úÖ Best Mean Dice Score:", round(study.best_value, 5))

print(final_weights)
# Build the ensemble with the tuned weights and score it on the test split.
ensemble_model = WeightedSoftVotingEnsemble(
    models=models,
    weights=final_weights,
    apply_softmax=True
)

# compile() attaches the loss and metrics that evaluate() reports; this cell
# never calls fit().
ensemble_model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
    loss=combined_loss,
    metrics=class_wise_metrics(num_classes=4)
)

ensemble_model.evaluate(X_test, y_test, batch_size=16)

In [None]:
# Ensemble experiment: Xception + InceptionResNetV2 + EfficientNetB4.
# The same model combination also appears in other cells of this notebook.
models = [
    model_xception,
    model_inceptionresnetv2,
    model_efficientnetb4
]

# Per-model outputs of get_softmax_preds_from_array on X_val, computed once
# before the search.
# NOTE(review): `objective` is defined outside this cell; presumably it reads
# `soft_preds_all` and `y_true_val` from the notebook namespace — confirm.
print("‚úÖ Precomputing predictions from all models...")
soft_preds_all = [get_softmax_preds_from_array(m, X_val) for m in models]
y_true_val = tf.convert_to_tensor(y_val, dtype=tf.float32)

# 50-trial Optuna search maximising `objective`; no sampler seed is passed
# here, so repeated runs may pick different weights.
print("üéØ Running Optuna for best ensemble weights...")
study = optuna.create_study(direction='maximize')
study.optimize(objective, n_trials=50)

# === STEP 6: Extract & Format Best Weights ===
# Read parameters w0..w{n-1} from the best trial and rescale them to sum to 1
# (each rounded to 4 decimals).
best_raw_weights = [study.best_trial.params[f"w{i}"] for i in range(len(models))]
total = sum(best_raw_weights)
final_weights = [round(w / total, 4) for w in best_raw_weights]

print("\n‚úÖ Best Ensemble Weights (rounded):", final_weights)
print("‚úÖ Best Mean Dice Score:", round(study.best_value, 5))

print(final_weights)
# Build the ensemble with the tuned weights and score it on the test split.
ensemble_model = WeightedSoftVotingEnsemble(
    models=models,
    weights=final_weights,
    apply_softmax=True
)

# compile() attaches the loss and metrics that evaluate() reports; this cell
# never calls fit().
ensemble_model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
    loss=combined_loss,
    metrics=class_wise_metrics(num_classes=4)
)

ensemble_model.evaluate(X_test, y_test, batch_size=16)

In [None]:
# Ensemble experiment: Xception + InceptionResNetV2 + EfficientNetB4.
# The same model combination also appears in other cells of this notebook.
models = [
    model_xception,
    model_inceptionresnetv2,
    model_efficientnetb4
]

# Per-model outputs of get_softmax_preds_from_array on X_val, computed once
# before the search.
# NOTE(review): `objective` is defined outside this cell; presumably it reads
# `soft_preds_all` and `y_true_val` from the notebook namespace — confirm.
print("‚úÖ Precomputing predictions from all models...")
soft_preds_all = [get_softmax_preds_from_array(m, X_val) for m in models]
y_true_val = tf.convert_to_tensor(y_val, dtype=tf.float32)

# 50-trial Optuna search maximising `objective`; no sampler seed is passed
# here, so repeated runs may pick different weights.
print("üéØ Running Optuna for best ensemble weights...")
study = optuna.create_study(direction='maximize')
study.optimize(objective, n_trials=50)

# === STEP 6: Extract & Format Best Weights ===
# Read parameters w0..w{n-1} from the best trial and rescale them to sum to 1
# (each rounded to 4 decimals).
best_raw_weights = [study.best_trial.params[f"w{i}"] for i in range(len(models))]
total = sum(best_raw_weights)
final_weights = [round(w / total, 4) for w in best_raw_weights]

print("\n‚úÖ Best Ensemble Weights (rounded):", final_weights)
print("‚úÖ Best Mean Dice Score:", round(study.best_value, 5))

print(final_weights)
# Build the ensemble with the tuned weights and score it on the test split.
ensemble_model = WeightedSoftVotingEnsemble(
    models=models,
    weights=final_weights,
    apply_softmax=True
)

# compile() attaches the loss and metrics that evaluate() reports; this cell
# never calls fit().
ensemble_model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
    loss=combined_loss,
    metrics=class_wise_metrics(num_classes=4)
)

ensemble_model.evaluate(X_test, y_test, batch_size=16)

In [None]:
# Ensemble experiment: Xception + InceptionResNetV2 + EfficientNetB4.
# The same model combination also appears in other cells of this notebook.
models = [
    model_xception,
    model_inceptionresnetv2,
    model_efficientnetb4
]

# Per-model outputs of get_softmax_preds_from_array on X_val, computed once
# before the search.
# NOTE(review): `objective` is defined outside this cell; presumably it reads
# `soft_preds_all` and `y_true_val` from the notebook namespace — confirm.
print("‚úÖ Precomputing predictions from all models...")
soft_preds_all = [get_softmax_preds_from_array(m, X_val) for m in models]
y_true_val = tf.convert_to_tensor(y_val, dtype=tf.float32)

# 50-trial Optuna search maximising `objective`; no sampler seed is passed
# here, so repeated runs may pick different weights.
print("üéØ Running Optuna for best ensemble weights...")
study = optuna.create_study(direction='maximize')
study.optimize(objective, n_trials=50)

# === STEP 6: Extract & Format Best Weights ===
# Read parameters w0..w{n-1} from the best trial and rescale them to sum to 1
# (each rounded to 4 decimals).
best_raw_weights = [study.best_trial.params[f"w{i}"] for i in range(len(models))]
total = sum(best_raw_weights)
final_weights = [round(w / total, 4) for w in best_raw_weights]

print("\n‚úÖ Best Ensemble Weights (rounded):", final_weights)
print("‚úÖ Best Mean Dice Score:", round(study.best_value, 5))

print(final_weights)
# Build the ensemble with the tuned weights and score it on the test split.
ensemble_model = WeightedSoftVotingEnsemble(
    models=models,
    weights=final_weights,
    apply_softmax=True
)

# compile() attaches the loss and metrics that evaluate() reports; this cell
# never calls fit().
ensemble_model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
    loss=combined_loss,
    metrics=class_wise_metrics(num_classes=4)
)

ensemble_model.evaluate(X_test, y_test, batch_size=16)

In [None]:
# Ensemble experiment: Xception + InceptionResNetV2 + EfficientNetB4.
# The same model combination also appears in other cells of this notebook.
models = [
    model_xception,
    model_inceptionresnetv2,
    model_efficientnetb4
]

# Per-model outputs of get_softmax_preds_from_array on X_val, computed once
# before the search.
# NOTE(review): `objective` is defined outside this cell; presumably it reads
# `soft_preds_all` and `y_true_val` from the notebook namespace — confirm.
print("‚úÖ Precomputing predictions from all models...")
soft_preds_all = [get_softmax_preds_from_array(m, X_val) for m in models]
y_true_val = tf.convert_to_tensor(y_val, dtype=tf.float32)

# 50-trial Optuna search maximising `objective`; no sampler seed is passed
# here, so repeated runs may pick different weights.
print("üéØ Running Optuna for best ensemble weights...")
study = optuna.create_study(direction='maximize')
study.optimize(objective, n_trials=50)

# === STEP 6: Extract & Format Best Weights ===
# Read parameters w0..w{n-1} from the best trial and rescale them to sum to 1
# (each rounded to 4 decimals).
best_raw_weights = [study.best_trial.params[f"w{i}"] for i in range(len(models))]
total = sum(best_raw_weights)
final_weights = [round(w / total, 4) for w in best_raw_weights]

print("\n‚úÖ Best Ensemble Weights (rounded):", final_weights)
print("‚úÖ Best Mean Dice Score:", round(study.best_value, 5))

print(final_weights)
# Build the ensemble with the tuned weights and score it on the test split.
ensemble_model = WeightedSoftVotingEnsemble(
    models=models,
    weights=final_weights,
    apply_softmax=True
)

# compile() attaches the loss and metrics that evaluate() reports; this cell
# never calls fit().
ensemble_model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
    loss=combined_loss,
    metrics=class_wise_metrics(num_classes=4)
)

ensemble_model.evaluate(X_test, y_test, batch_size=16)

In [None]:
# Ensemble experiment: Xception + InceptionResNetV2 + EfficientNetB4.
# The same model combination also appears in other cells of this notebook.
models = [
    model_xception,
    model_inceptionresnetv2,
    model_efficientnetb4
]

# Per-model outputs of get_softmax_preds_from_array on X_val, computed once
# before the search.
# NOTE(review): `objective` is defined outside this cell; presumably it reads
# `soft_preds_all` and `y_true_val` from the notebook namespace — confirm.
print("‚úÖ Precomputing predictions from all models...")
soft_preds_all = [get_softmax_preds_from_array(m, X_val) for m in models]
y_true_val = tf.convert_to_tensor(y_val, dtype=tf.float32)

# 50-trial Optuna search maximising `objective`; no sampler seed is passed
# here, so repeated runs may pick different weights.
print("üéØ Running Optuna for best ensemble weights...")
study = optuna.create_study(direction='maximize')
study.optimize(objective, n_trials=50)

# === STEP 6: Extract & Format Best Weights ===
# Read parameters w0..w{n-1} from the best trial and rescale them to sum to 1
# (each rounded to 4 decimals).
best_raw_weights = [study.best_trial.params[f"w{i}"] for i in range(len(models))]
total = sum(best_raw_weights)
final_weights = [round(w / total, 4) for w in best_raw_weights]

print("\n‚úÖ Best Ensemble Weights (rounded):", final_weights)
print("‚úÖ Best Mean Dice Score:", round(study.best_value, 5))

print(final_weights)
# Build the ensemble with the tuned weights and score it on the test split.
ensemble_model = WeightedSoftVotingEnsemble(
    models=models,
    weights=final_weights,
    apply_softmax=True
)

# compile() attaches the loss and metrics that evaluate() reports; this cell
# never calls fit().
ensemble_model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
    loss=combined_loss,
    metrics=class_wise_metrics(num_classes=4)
)

ensemble_model.evaluate(X_test, y_test, batch_size=16)

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import random

def decode_segmentation(mask, num_classes=4):
    """Turn a per-class score map into an RGB colour image.

    The class with the highest score along the last axis is picked for every
    position and painted with its fixed colour. ``num_classes`` is accepted
    but not consulted; the palette below always has four entries.

    Returns:
        uint8 array with the spatial shape of ``mask`` plus a trailing RGB axis.
    """
    class_ids = np.argmax(mask, axis=-1)

    # class index -> RGB colour
    palette = {
        0: (0, 0, 0),      # Background - Black
        1: (255, 0, 0),    # Brain - Red
        2: (0, 255, 0),    # CSP - Green
        3: (0, 0, 255),    # LV - Blue
    }

    rgb = np.zeros(class_ids.shape + (3,), dtype=np.uint8)
    for class_id, colour in palette.items():
        rgb[class_ids == class_id] = colour
    return rgb

def display_predictions(model, X, y_true, num_samples=5):
    """Plot ground-truth and predicted masks over randomly chosen images.

    Args:
        model: Object with a ``predict`` method that returns one score map
            per input image.
        X: Image array that can be indexed with a list of integers.
        y_true: One-hot mask array aligned with ``X``.
        num_samples: Number of distinct random samples to draw.

    Raises:
        ValueError: From ``random.sample`` when ``num_samples`` exceeds ``len(X)``.
    """
    indices = random.sample(range(len(X)), num_samples)
    X_samples = X[indices]
    y_samples = y_true[indices]

    preds = model.predict(X_samples)

    for i in range(num_samples):
        image = X_samples[i]
        # Images scaled to [0, 1] must be stretched to 0-255 before the uint8
        # cast; casting them directly truncates nearly every pixel to 0 and
        # the picture is drawn black.
        if image.max() <= 1.0:
            image = image * 255
        image = image.astype(np.uint8)

        true_mask = decode_segmentation(y_samples[i])
        pred_mask = decode_segmentation(preds[i])

        # Plot side-by-side
        plt.figure(figsize=(12, 5))

        # Ground truth overlay
        plt.subplot(1, 2, 1)
        plt.imshow(image)
        plt.imshow(true_mask, alpha=0.5)
        plt.title("Ground Truth Overlay")
        plt.axis('off')

        # Prediction overlay
        plt.subplot(1, 2, 2)
        plt.imshow(image)
        plt.imshow(pred_mask, alpha=0.5)
        plt.title("Prediction Overlay")
        plt.axis('off')

        plt.tight_layout()
        plt.show()

In [None]:
# Show ground-truth and predicted overlays for five random test samples.
# NOTE(review): `model` is not assigned in this cell — confirm which model
# object is in the notebook namespace when this runs.
display_predictions(model, X_test, y_test, num_samples=5)

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import random

# Function to decode one-hot encoded mask or softmax prediction into RGB
def decode_segmentation(mask, num_classes=4):
    """Colour-code the arg-max class of a one-hot mask or score map.

    ``num_classes`` is accepted but not consulted; four fixed colours are used.
    Returns a uint8 array shaped like ``mask`` without its last axis, plus a
    trailing axis of three colour components.
    """
    winners = np.argmax(mask, axis=-1)

    class_colours = (
        (0, 0, 0),       # Background - Black
        (255, 0, 0),     # Brain - Red
        (0, 255, 0),     # CSP - Green
        (0, 0, 255),     # LV - Blue
    )

    painted = np.zeros(winners.shape + (3,), dtype=np.uint8)
    class_id = 0
    for colour in class_colours:
        painted[winners == class_id] = colour
        class_id += 1
    return painted

# Function to overlay mask on an image
def overlay_mask_on_image(image, mask, alpha=0.5):
    """Blend a colour mask onto an image and return the uint8 result.

    Either input whose maximum is at most 1.0 is treated as normalised and is
    stretched to 0-255 first. The blend is ``(1 - alpha) * image + alpha * mask``.
    """
    base = image.copy()
    base = (base * 255).astype(np.uint8) if base.max() <= 1.0 else base
    mask = (mask * 255).astype(np.uint8) if mask.max() <= 1.0 else mask

    # Both operands are handed to OpenCV as float32 arrays.
    base = base.astype(np.float32)
    mask = mask.astype(np.float32)

    blended = cv2.addWeighted(base, 1 - alpha, mask, alpha, 0)
    return blended.astype(np.uint8)

# Main display function for N samples
def display_overlay_predictions(model, X, y_true, num_samples=5):
    """Show blended ground-truth and predicted overlays for random samples.

    ``num_samples`` distinct indices are drawn from ``X``; the model predicts
    on that batch, and for each sample one figure with two panels is shown.
    """
    import cv2
    chosen = random.sample(range(len(X)), num_samples)

    images = X[chosen]
    targets = y_true[chosen]
    predictions = model.predict(images)

    for sample_idx in range(num_samples):
        picture = images[sample_idx]
        target_rgb = decode_segmentation(targets[sample_idx])
        predicted_rgb = decode_segmentation(predictions[sample_idx])

        # Normalised images are stretched to 0-255 before blending.
        if picture.max() <= 1.0:
            picture = (picture * 255).astype(np.uint8)

        panels = [
            (overlay_mask_on_image(picture, target_rgb, alpha=0.5), "Ground Truth Overlay"),
            (overlay_mask_on_image(picture, predicted_rgb, alpha=0.5), "Predicted Overlay"),
        ]

        plt.figure(figsize=(12, 5))
        for position, (blended, caption) in enumerate(panels, start=1):
            plt.subplot(1, 2, position)
            plt.imshow(blended)
            plt.title(caption)
            plt.axis("off")

        plt.tight_layout()
        plt.show()

# Show blended overlays for five random test samples.
# NOTE(review): `model` is not assigned in this cell — confirm which model
# object is in the notebook namespace when this runs.
display_overlay_predictions(model, X_test, y_test, num_samples=5)

<h1>Knowledge Distillation</h1>

In [None]:
import os
import shutil
import numpy as np
from sklearn.model_selection import train_test_split

# Source folders of the augmented dataset (absolute paths on a local D: drive).
image_dir = r"D:\augmented_dataset\images"
mask_dir = r"D:\augmented_dataset\masks"

# Folders holding the train / validation / test split; each split has
# parallel `images` and `masks` directories.
train_image_dir = r"D:\Updated\train\images"
train_mask_dir = r"D:\Updated\train\masks"
val_image_dir = r"D:\Updated\val\images"
val_mask_dir = r"D:\Updated\val\masks"
test_image_dir = r"D:\Updated\test\images"
test_mask_dir = r"D:\Updated\test\masks"

import numpy as np
import os
import gc
import cv2
import re
from tensorflow.keras.utils import to_categorical

# ‚úÖ Constants for 224x224
IMG_HEIGHT = 224  # Target image height in pixels
IMG_WIDTH = 224   # Target image width in pixels
CHANNELS = 3  # RGB images
NUM_CLASSES = 4  # Brain, CSP, LV, Background

# Mask colour (R, G, B) -> class index. Only exact colour matches are mapped
# by rgb_to_class; every other colour stays at index 0 (background).
CLASS_MAP = {
    (255, 0, 0): 1,  # Brain
    (0, 255, 0): 2,  # CSP
    (0, 0, 255): 3,  # LV
    (0, 0, 0): 0,  # Background
}

# ‚úÖ Fix sorting issue using natural sorting
def natural_sort_key(s):
    """Build a sort key that orders embedded numbers by value.

    The string is split into alternating text and digit runs; digit runs
    become ints and text runs are lower-cased, so "img2" sorts before "img10".

    Args:
        s: File name (or any string) to build the key for.

    Returns:
        List of lower-cased strings and ints, usable as ``key=`` for sorting.
    """
    # str.isdecimal() accepts exactly the characters that r'\d' splits out.
    # str.isdigit() is broader (it also accepts e.g. superscript digits),
    # which let chunks through that int() cannot parse.
    return [int(text) if text.isdecimal() else text.lower() for text in re.split(r'(\d+)', s)]

# ‚úÖ Convert RGB mask to class index mask
def rgb_to_class(mask_array):
    """Map an RGB mask of shape (H, W, channels) to a 2-D class-index mask.

    Pixels whose colour equals a key of CLASS_MAP receive that key's class
    index; all remaining pixels keep index 0. The result is a uint8 array of
    shape (H, W).
    """
    rows, cols, _ = mask_array.shape
    labels = np.zeros((rows, cols), dtype=np.uint8)

    for colour in CLASS_MAP:
        # A pixel counts only when every channel equals the reference colour.
        same_colour = (mask_array == colour).all(axis=-1)
        labels[same_colour] = CLASS_MAP[colour]

    return labels

# ‚úÖ Preprocess Filtered Dataset for 224x224
def preprocess_filtered_dataset(image_dir, mask_dir):
    """Load all image/mask pairs from two folders into model-ready arrays.

    Files in both folders are listed, naturally sorted and paired by position
    (i-th image with i-th mask). Images are converted to RGB, resized to
    IMG_WIDTH x IMG_HEIGHT and scaled to [0, 1]. Masks are resized with
    nearest-neighbour interpolation, mapped from RGB colours to class indices
    and one-hot encoded.

    Args:
        image_dir: Folder containing the input images.
        mask_dir: Folder containing the RGB masks.

    Returns:
        Tuple ``(X, y)`` of float32 arrays with shapes
        (N, IMG_HEIGHT, IMG_WIDTH, CHANNELS) and
        (N, IMG_HEIGHT, IMG_WIDTH, NUM_CLASSES).

    Raises:
        OSError: If OpenCV cannot read one of the image or mask files.
    """
    image_filenames = sorted(os.listdir(image_dir), key=natural_sort_key)
    mask_filenames = sorted(os.listdir(mask_dir), key=natural_sort_key)

    # zip() below stops at the shorter list without any error, so a stray or
    # missing file would otherwise go unnoticed while shifting the pairing.
    if len(image_filenames) != len(mask_filenames):
        print(
            f"WARNING: found {len(image_filenames)} images but {len(mask_filenames)} masks; "
            "files are paired by sorted position and unmatched extras are ignored"
        )

    valid_image_paths = []
    valid_mask_paths = []

    for img_file, mask_file in zip(image_filenames, mask_filenames):
        img_path = os.path.join(image_dir, img_file)
        mask_path = os.path.join(mask_dir, mask_file)

        if os.path.exists(img_path) and os.path.exists(mask_path):
            valid_image_paths.append(img_path)
            valid_mask_paths.append(mask_path)
        else:
            print(f"‚ö†Ô∏è Skipping {img_file}: Missing image or mask")

    num_images = len(valid_image_paths)

    # Pre-allocate the outputs; y holds one-hot encoded masks.
    X = np.zeros((num_images, IMG_HEIGHT, IMG_WIDTH, CHANNELS), dtype=np.float32)
    y = np.zeros((num_images, IMG_HEIGHT, IMG_WIDTH, NUM_CLASSES), dtype=np.float32)

    print(f"üöÄ Processing {num_images} filtered images and masks...")

    for idx, (img_path, mask_path) in enumerate(zip(valid_image_paths, valid_mask_paths)):
        if idx % 100 == 0:
            print(f"‚úÖ Processed {idx}/{num_images} images")

        # cv2.imread signals failure by returning None instead of raising.
        img = cv2.imread(img_path)  # BGR channel order
        if img is None:
            raise OSError(f"Could not read image file: {img_path}")
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = cv2.resize(img, (IMG_WIDTH, IMG_HEIGHT))
        img = img.astype(np.float32) / 255.0  # scale to [0, 1]

        mask = cv2.imread(mask_path)  # BGR channel order
        if mask is None:
            raise OSError(f"Could not read mask file: {mask_path}")
        mask = cv2.cvtColor(mask, cv2.COLOR_BGR2RGB)
        # Nearest-neighbour keeps the mask colours exact; interpolating
        # would blend colours that rgb_to_class cannot match.
        mask = cv2.resize(mask, (IMG_WIDTH, IMG_HEIGHT), interpolation=cv2.INTER_NEAREST)

        class_mask = rgb_to_class(mask)

        X[idx] = img
        y[idx] = to_categorical(class_mask, num_classes=NUM_CLASSES)

    # The per-image temporaries are released when their names are rebound on
    # the next iteration, so a single collection after the loop is enough.
    gc.collect()

    return X, y

from sklearn.model_selection import train_test_split

# ‚úÖ Process dataset splits
# Load each split fully into memory as (images, one-hot masks) arrays.
X_train, y_train = preprocess_filtered_dataset(train_image_dir, train_mask_dir)
X_val, y_val = preprocess_filtered_dataset(val_image_dir, val_mask_dir)
X_test, y_test = preprocess_filtered_dataset(test_image_dir, test_mask_dir)

# Report the array shapes of every split as a quick sanity check.
print("\n‚úÖ Dataset Splits:")
print(f"  - Training set: {X_train.shape}, {y_train.shape}")
print(f"  - Validation set: {X_val.shape}, {y_val.shape}")
print(f"  - Test set: {X_test.shape}, {y_test.shape}")

In [None]:
from tensorflow.keras.preprocessing.image import ImageDataGenerator

# Data Augmentation configuration for the training set:
# random rotations, shifts, shear, zoom and horizontal flips.
# NOTE(review): a segmentation mask needs the same geometric transform as its
# image; how this generator is applied to masks is not visible here — confirm.
train_datagen = ImageDataGenerator(
    rotation_range=30,
    width_shift_range=0.2,
    height_shift_range=0.2,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
)

# Fit the augmentation parameters on the training data
# NOTE(review): none of featurewise_center / featurewise_std_normalization /
# zca_whitening is enabled above, so this fit() call looks unnecessary —
# confirm against the Keras ImageDataGenerator.fit documentation.
train_datagen.fit(X_train)

In [None]:
import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, Dropout, BatchNormalization, Activation, UpSampling2D, Concatenate

def conv_block(inputs, filters, kernel_size=(3, 3), padding="same", use_batch_norm=True):
    """
    Two stacked convolution stages, each: Conv2D -> optional BatchNormalization -> ReLU.

    Both stages use the same ``filters``, ``kernel_size`` and ``padding``.
    Returns the output tensor of the second stage.
    """
    features = inputs

    for _ in range(2):
        features = Conv2D(filters, kernel_size, padding=padding)(features)
        if use_batch_norm:
            features = BatchNormalization()(features)
        features = Activation("relu")(features)

    return features

def UNetPlusPlus(input_shape=(224, 224, 3), num_classes=4, filters=[24, 48, 96, 192], use_batch_norm=True):
    """
    UNet++ (Nested U-Net) model for multiclass segmentation

    Node names below follow the x<level><column> scheme: level 0 is full
    resolution and each deeper level is half the size of the one above.

    Args:
        input_shape: Input image dimensions (height, width, channels)
        num_classes: Number of output classes for segmentation
        filters: List of filter dimensions for each level (only read, never modified)
        use_batch_norm: Whether to use batch normalization

    Returns:
        tf.keras.Model: UNet++ model
    """
    def encode(tensor, level):
        # Plain convolution block at the given resolution level.
        return conv_block(tensor, filters[level], use_batch_norm=use_batch_norm)

    def nested_node(deeper, same_level, level):
        # Upsample the node from one level below, join it with the earlier
        # nodes of this level, then convolve.
        upsampled = UpSampling2D(size=(2, 2), interpolation='bilinear')(deeper)
        merged = Concatenate()([upsampled] + list(same_level))
        return conv_block(merged, filters[level], use_batch_norm=use_batch_norm)

    inputs = Input(input_shape)

    # Encoder (downsampling path)
    x00 = encode(inputs, 0)
    x10 = encode(MaxPooling2D(pool_size=(2, 2))(x00), 1)
    x20 = encode(MaxPooling2D(pool_size=(2, 2))(x10), 2)
    x30 = encode(MaxPooling2D(pool_size=(2, 2))(x20), 3)

    # Decoder with nested dense skip connections, column by column
    x21 = nested_node(x30, [x20], 2)
    x11 = nested_node(x20, [x10], 1)
    x01 = nested_node(x10, [x00], 0)

    x12 = nested_node(x21, [x10, x11], 1)
    x02 = nested_node(x11, [x00, x01], 0)

    x03 = nested_node(x12, [x00, x01, x02], 0)

    # Per-pixel class probabilities
    outputs = Conv2D(num_classes, (1, 1), activation='softmax')(x03)

    return Model(inputs=[inputs], outputs=[outputs])

# Example usage: build the UNet++ for 224x224 RGB inputs and 4 classes,
# keep it as `student_model`, and print its layer summary.
student_model = UNetPlusPlus(input_shape=(224, 224, 3), num_classes=4)
student_model.summary()

In [None]:
import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, BatchNormalization, Activation, UpSampling2D, Concatenate, Add

def res_conv_block(inputs, filters, kernel_size=(3, 3), padding="same", use_batch_norm=True):
    """
    Two-convolution block whose output is added to a shortcut of its input.

    Main path: Conv2D -> optional BatchNormalization -> ReLU -> Conv2D ->
    optional BatchNormalization. The shortcut is the input itself, or a 1x1
    convolution of it (with optional BatchNormalization) when the input's
    channel count differs from ``filters``. A final ReLU follows the addition.
    """
    def maybe_norm(tensor):
        # Batch normalization is applied only when requested.
        if use_batch_norm:
            return BatchNormalization()(tensor)
        return tensor

    # Main path
    main = Conv2D(filters, kernel_size, padding=padding)(inputs)
    main = maybe_norm(main)
    main = Activation("relu")(main)

    main = Conv2D(filters, kernel_size, padding=padding)(main)
    main = maybe_norm(main)

    # Shortcut path: project with a 1x1 convolution only on channel mismatch
    skip = inputs
    if skip.shape[-1] != filters:
        skip = Conv2D(filters, (1, 1), padding=padding)(skip)
        skip = maybe_norm(skip)

    summed = Add()([main, skip])
    return Activation("relu")(summed)


def UNetPlusPlus(input_shape=(224, 224, 3), num_classes=4, filters=[24, 48, 96, 192], use_batch_norm=True):
    """
    Build a UNet++ segmentation network whose nodes are residual conv blocks.

    The encoder has four resolution levels separated by 2x2 max-pooling. Every
    decoder node upsamples the node one level below it (bilinear, 2x),
    concatenates the result with all earlier nodes of its own level and runs a
    `res_conv_block`. The last level-0 node feeds a 1x1 softmax convolution.

    Args:
        input_shape: Input image dimensions (height, width, channels).
        num_classes: Number of channels of the softmax output.
        filters: Channel widths for levels 0..3 (indices 0-3 are read; the
            sequence is never modified).
        use_batch_norm: Forwarded to every `res_conv_block`.

    Returns:
        A Keras `Model` with a single input and a single softmax output.
    """
    def _nested_node(node_below, level_nodes, width):
        # Upsample the lower-level node, fuse it with this level's earlier nodes, convolve.
        upsampled = UpSampling2D(size=(2, 2), interpolation='bilinear')(node_below)
        fused = Concatenate()([upsampled, *level_nodes])
        return res_conv_block(fused, width, use_batch_norm=use_batch_norm)

    inputs = Input(input_shape)

    # Encoder column: each level pools the previous level's output before its block.
    conv0_0 = res_conv_block(inputs, filters[0], use_batch_norm=use_batch_norm)
    conv1_0 = res_conv_block(MaxPooling2D(pool_size=(2, 2))(conv0_0), filters[1], use_batch_norm=use_batch_norm)
    conv2_0 = res_conv_block(MaxPooling2D(pool_size=(2, 2))(conv1_0), filters[2], use_batch_norm=use_batch_norm)
    conv3_0 = res_conv_block(MaxPooling2D(pool_size=(2, 2))(conv2_0), filters[3], use_batch_norm=use_batch_norm)

    # First decoder column: each node sees only the encoder node of its level.
    conv2_1 = _nested_node(conv3_0, [conv2_0], filters[2])
    conv1_1 = _nested_node(conv2_0, [conv1_0], filters[1])
    conv0_1 = _nested_node(conv1_0, [conv0_0], filters[0])

    # Second decoder column: two same-level predecessors per node.
    conv1_2 = _nested_node(conv2_1, [conv1_0, conv1_1], filters[1])
    conv0_2 = _nested_node(conv1_1, [conv0_0, conv0_1], filters[0])

    # Third decoder column: the final full-resolution node.
    conv0_3 = _nested_node(conv1_2, [conv0_0, conv0_1, conv0_2], filters[0])

    # Per-pixel class distribution.
    outputs = Conv2D(num_classes, (1, 1), activation='softmax')(conv0_3)

    return Model(inputs=[inputs], outputs=[outputs])

# Build the residual UNet++ student for 224x224 RGB inputs with 4 output classes
# and print its layer-by-layer summary.
student_model = UNetPlusPlus(input_shape=(224, 224, 3), num_classes=4)
student_model.summary()

In [None]:
import tensorflow as tf
from tensorflow.keras import layers, Model, Input

# === Config ===
# NOTE(review): the same four names are assigned in the notebook's first cell;
# running this cell re-assigns them.
IMG_HEIGHT = 224  # input image height in pixels
IMG_WIDTH = 224  # input image width in pixels
CHANNELS = 3  # colour channels per input image
NUM_CLASSES = 4  # number of segmentation classes


# === Resize Layer ===
class ResizeLayer(tf.keras.layers.Layer):
    """Keras layer that bilinearly resizes its input to a fixed target size.

    Args:
        target_size: Value handed to `tf.image.resize` as the output size.
        **kwargs: Forwarded to the base `Layer` constructor.
    """

    def __init__(self, target_size, **kwargs):
        super().__init__(**kwargs)
        self.target_size = target_size

    def call(self, inputs):
        """Resize `inputs` to `self.target_size` using bilinear interpolation."""
        resized = tf.image.resize(inputs, self.target_size, method='bilinear')
        return resized

    def get_config(self):
        """Return the base layer config extended with `target_size`."""
        base_config = super().get_config()
        base_config['target_size'] = self.target_size
        return base_config


# === Conv Block ===
def conv_block(x, filters, kernel_size=3, strides=1):
    """Apply a bias-free, 'same'-padded Conv2D followed by BatchNormalization and ReLU.

    Args:
        x: Input feature-map tensor.
        filters: Number of output channels.
        kernel_size: Convolution kernel size.
        strides: Convolution stride.

    Returns:
        The activated feature map.
    """
    convolution = layers.Conv2D(filters, kernel_size, padding='same', strides=strides, use_bias=False)
    features = convolution(x)
    features = layers.BatchNormalization()(features)
    return layers.ReLU()(features)


# === MBConv Block ===
def mbconv_block(x, in_ch, out_ch, stride=1, expansion=4):
    """Inverted-bottleneck block: 1x1 expand -> 3x3 depthwise -> 1x1 project.

    Args:
        x: Input feature-map tensor.
        in_ch: Channel count used to size the expansion (`in_ch * expansion`)
            and to decide whether the skip connection applies.
        out_ch: Channel count produced by the projection convolution.
        stride: Stride of the depthwise convolution.
        expansion: Multiplier for the hidden channel count.

    Returns:
        The projected features, with the block input added when `stride == 1`
        and `in_ch == out_ch`.
    """
    block_input = x
    expanded_ch = in_ch * expansion

    # 1x1 expansion to the hidden width.
    y = layers.Conv2D(expanded_ch, 1, padding='same', use_bias=False)(block_input)
    y = layers.BatchNormalization()(y)
    y = layers.Activation('swish')(y)

    # 3x3 depthwise convolution; this is where any downsampling happens.
    y = layers.DepthwiseConv2D(3, strides=stride, padding='same', use_bias=False)(y)
    y = layers.BatchNormalization()(y)
    y = layers.Activation('swish')(y)

    # 1x1 projection, left without an activation.
    y = layers.Conv2D(out_ch, 1, padding='same', use_bias=False)(y)
    y = layers.BatchNormalization()(y)

    # The skip connection is only shape-compatible without striding or a channel change.
    skip_applies = stride == 1 and in_ch == out_ch
    if not skip_applies:
        return y
    return layers.Add()([block_input, y])


# === MLA Block ===
def mla_block(x, channels, stride=1):
    """Residual block with a depthwise-5x5 -> 1x1 -> ReLU branch.

    Despite the "MLA" name, no attention weights are computed here: the branch
    is a 5x5 depthwise convolution, a 1x1 convolution to `channels` and a ReLU,
    added to a residual path.

    Args:
        x: Input feature-map tensor.
        channels: Output channel count of the branch (and of the residual
            projection when one is needed).
        stride: When greater than 1, the input is average-pooled by this factor
            before both paths.

    Returns:
        Sum of the residual path and the convolutional branch.
    """
    if stride > 1:
        # Both paths start from the same downsampled tensor, so pool once and
        # share the result instead of running two identical pooling layers.
        x = layers.AveragePooling2D(pool_size=stride, strides=stride, padding='same')(x)

    residual = x
    if x.shape[-1] != channels:
        # Project the residual so its channel count matches the branch output.
        residual = layers.Conv2D(channels, 1, padding='same')(residual)

    attn = layers.DepthwiseConv2D(5, padding='same')(x)
    attn = layers.Conv2D(channels, 1, padding='same')(attn)
    attn = layers.ReLU()(attn)

    return layers.Add()([residual, attn])


# === EfficientViT Block ===
def efficientvit_block(x, in_ch, out_ch, stride=1):
    """Sum an `mbconv_block` branch and an `mla_block` branch fed by the same input.

    Args:
        x: Input feature-map tensor.
        in_ch: Input channel count passed to `mbconv_block`.
        out_ch: Output channel count of both branches.
        stride: Stride forwarded to both branches.

    Returns:
        Element-wise sum of the two branch outputs.
    """
    branches = [
        mbconv_block(x, in_ch, out_ch, stride=stride),
        mla_block(x, out_ch, stride=stride),
    ]
    return layers.Add()(branches)


# EfficientViT-B0 Encoder
def efficientvit_b0_encoder(inputs):
    """Five-stage encoder: a strided conv stem followed by four stride-2 blocks.

    Args:
        inputs: Input image tensor.

    Returns:
        List `[e0, e1, e2, e3, e4]` of feature maps with 16, 32, 64, 96 and 128
        channels. Every stage halves the resolution (112, 56, 28, 14 and 7
        pixels per side for a 224x224 input).
    """
    # Stem: stride-2 3x3 convolution -> BN -> swish.
    stem = layers.Conv2D(16, 3, strides=2, padding='same')(inputs)
    stem = layers.BatchNormalization()(stem)
    stem = layers.Activation('swish')(stem)

    features = [stem]
    # (input channels, output channels) of each stride-2 stage, shallow to deep.
    stage_channels = ((16, 32), (32, 64), (64, 96), (96, 128))
    for in_ch, out_ch in stage_channels:
        features.append(efficientvit_block(features[-1], in_ch, out_ch, stride=2))

    return features


def attention_gate(x, g, inter_channels):
    """Additive attention gate: scale skip features `x` by a mask derived from `x` and `g`.

    Args:
        x: Skip-connection features to be gated.
        g: Gating features; resized to the spatial size of `x` when they differ.
        inter_channels: Channel count of the intermediate 1x1 projections.

    Returns:
        `x` multiplied by a single-channel sigmoid mask.
    """
    skip_h, skip_w = tf.keras.backend.int_shape(x)[1:3]
    gate_h, gate_w = tf.keras.backend.int_shape(g)[1:3]

    # Bring the gating signal onto the skip connection's spatial grid.
    if (skip_h, skip_w) != (gate_h, gate_w):
        g = ResizeLayer((skip_h, skip_w))(g)

    # Project both inputs to a common width, add them and squash to one channel.
    skip_proj = layers.Conv2D(inter_channels, 1, padding='same')(x)
    gate_proj = layers.Conv2D(inter_channels, 1, padding='same')(g)
    combined = layers.Add()([skip_proj, gate_proj])
    combined = layers.Activation('relu')(combined)
    mask_logits = layers.Conv2D(1, 1, padding='same')(combined)
    mask = layers.Activation('sigmoid')(mask_logits)

    return layers.Multiply()([x, mask])


def build_attention_unet_with_efficientvit(input_shape=(224, 224, 3), num_classes=4):
    """Build an attention U-Net decoder on top of `efficientvit_b0_encoder`.

    Args:
        input_shape: (height, width, channels) of the input image.
        num_classes: Number of channels of the softmax output.

    Returns:
        A Keras `Model` mapping the input image to per-pixel class probabilities
        at the input resolution.
    """
    inputs = Input(shape=input_shape)

    # Encoder features from shallowest (e0) to the bottleneck (e4).
    e0, e1, e2, e3, e4 = efficientvit_b0_encoder(inputs)

    # Decoder stages pair an output width with the skip connection it consumes,
    # walking from the deepest skip to the shallowest.
    decoder_stages = zip([96, 64, 32, 16], [e3, e2, e1, e0])
    decoded = e4

    for width, skip in decoder_stages:
        # Learned 2x upsampling, then BN -> ReLU -> dropout.
        decoded = layers.Conv2DTranspose(width, 3, strides=2, padding='same')(decoded)
        decoded = layers.BatchNormalization()(decoded)
        decoded = layers.Activation('relu')(decoded)
        decoded = layers.Dropout(0.2)(decoded)

        gated_skip = attention_gate(skip, decoded, width // 2)

        # Align spatial sizes before concatenating, in case upsampling overshot.
        gated_hw = tf.keras.backend.int_shape(gated_skip)[1:3]
        if gated_hw != tf.keras.backend.int_shape(decoded)[1:3]:
            decoded = ResizeLayer((gated_skip.shape[1], gated_skip.shape[2]))(decoded)

        decoded = layers.Concatenate()([decoded, gated_skip])
        for _ in range(2):
            decoded = conv_block(decoded, width)

    # The shallowest skip is at half resolution, so one more 2x upsampling is needed.
    decoded = layers.Conv2DTranspose(64, 3, strides=2, padding='same')(decoded)
    for width in (64, 32):
        decoded = conv_block(decoded, width)

    full_hw = (input_shape[0], input_shape[1])
    if tf.keras.backend.int_shape(decoded)[1:3] != full_hw:
        decoded = ResizeLayer(full_hw)(decoded)

    outputs = layers.Conv2D(num_classes, 1, activation='softmax')(decoded)
    return Model(inputs, outputs)

# === Build and print the student model ===
# Reset Keras' global state before building, then construct the model from the
# config constants declared at the top of this cell.
tf.keras.backend.clear_session()
student_model = build_attention_unet_with_efficientvit(
    input_shape=(IMG_HEIGHT, IMG_WIDTH, CHANNELS),
    num_classes=NUM_CLASSES,
)
student_model.summary()

In [None]:
import copy

def clone_metric(metric):
    """Return a new metric object configured like `metric`.

    `DiceCoefficient` instances are rebuilt from their `class_idx`; any other
    metric is re-instantiated from its own class with `get_config()` as
    keyword arguments.
    """
    if not isinstance(metric, DiceCoefficient):
        metric_cls = type(metric)
        return metric_cls(**metric.get_config())
    return DiceCoefficient(class_idx=metric.class_idx)

class Distiller(tf.keras.Model):
    """Knowledge-distillation wrapper that trains `student` against `teacher`.

    The training loss is
    `alpha * combined_loss(y_true, softmax(student))
     + (1 - alpha) * KL(teacher || softmax(student / temperature))`.

    Args:
        student: Model being trained; only its variables receive gradients.
        teacher: Model providing the distillation targets (called with
            `training=False`).
        temperature: Divisor applied to the student output before the softmax
            used in the KL term.
        alpha: Weight of the supervised term; `1 - alpha` weights the KL term.
        metrics: Optional list of metric objects updated on training batches.
            Clones of them (see `clone_metric`) are used for validation.

    NOTE(review): the student output is passed through `tf.nn.softmax` here.
    The student builders visible in this notebook already end in a softmax
    layer - confirm whether the student handed to this class emits logits.
    """

    def __init__(self, student, teacher, temperature=3.0, alpha=0.5, metrics=None):
        super().__init__()
        self.student = student
        self.teacher = teacher
        self.temperature = temperature
        self.alpha = alpha
        self.kl_loss_fn = tf.keras.losses.KLDivergence()
        self.train_accuracy = tf.keras.metrics.CategoricalAccuracy(name="train_accuracy")
        self.val_accuracy = tf.keras.metrics.CategoricalAccuracy(name="val_accuracy")
        self.train_metrics = metrics or []
        # Separate objects for validation so the two phases do not share accumulators.
        self.val_metrics = [clone_metric(m) for m in self.train_metrics]

    def compile(self, optimizer):
        """Store the optimizer; loss and metrics are handled inside the custom steps."""
        super().compile()
        self.optimizer = optimizer

    def train_step(self, data):
        """Run one optimisation step on the student and return the training logs."""
        x, y_true = data
        # Teacher targets are computed outside the tape: no gradient flows into the teacher.
        teacher_soft = self.teacher(x, training=False)

        with tf.GradientTape() as tape:
            student_logits = self.student(x, training=True)
            # Computed once and reused for the supervised loss and all metrics.
            student_probs = tf.nn.softmax(student_logits)
            student_soft = tf.nn.softmax(student_logits / self.temperature)
            distill_loss = self.kl_loss_fn(teacher_soft, student_soft)
            supervised_loss = combined_loss(y_true, student_probs)
            loss = self.alpha * supervised_loss + (1 - self.alpha) * distill_loss

        grads = tape.gradient(loss, self.student.trainable_variables)
        self.optimizer.apply_gradients(zip(grads, self.student.trainable_variables))

        self.train_accuracy.update_state(y_true, student_probs)
        for m in self.train_metrics:
            m.update_state(y_true, student_probs)

        logs = {"loss": loss, "accuracy": self.train_accuracy.result()}
        for m in self.train_metrics:
            logs[m.name] = m.result()
        return logs

    def test_step(self, data):
        """Evaluate the student on one batch and return the validation logs.

        The returned "loss" is the supervised term only (no teacher pass). It is
        what Keras reports as `val_loss`, which callbacks monitoring `val_loss`
        need in order to work.
        """
        x, y_true = data
        y_pred = self.student(x, training=False)
        student_probs = tf.nn.softmax(y_pred)
        loss = combined_loss(y_true, student_probs)

        self.val_accuracy.update_state(y_true, student_probs)
        for m in self.val_metrics:
            m.update_state(y_true, student_probs)

        logs = {"loss": loss, "accuracy": self.val_accuracy.result()}
        for m in self.val_metrics:
            logs[m.name] = m.result()
        return logs

In [None]:
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau, ModelCheckpoint

# Training callbacks, all driven by the validation loss.
# NOTE(review): `val_loss` only exists if the model's validation step reports a
# "loss" entry - confirm the model trained with these callbacks does so.
callbacks = [
    # Stop after 10 epochs without improvement and roll back to the best weights.
    EarlyStopping(
        monitor='val_loss',
        patience=10,
        restore_best_weights=True
    ),
    # Halve the learning rate after 3 stagnant epochs, down to a floor of 1e-6.
    ReduceLROnPlateau(
        monitor='val_loss',
        factor=0.5,
        patience=3,
        min_lr=1e-6
    ),
    # Keep only the best model seen so far on disk.
    ModelCheckpoint(
        filepath='best_student_unetplusplus.keras',
        monitor='val_loss',
        save_best_only=True
    )
]

def create_train_generator(X, y, batch_size=16):
    """Yield an endless stream of identically augmented (image, mask) batches.

    Two `ImageDataGenerator`s with the same geometric settings and the same
    seed are advanced in lock-step, so each mask batch receives the same random
    transform as its image batch.

    Args:
        X: Array of training images.
        y: Array of training masks, aligned with `X` along the first axis
            (assumed one-hot encoded - TODO confirm against the preprocessing cell).
        batch_size: Number of samples per yielded batch.

    Yields:
        Tuples `(X_batch, y_batch)`.
    """
    # Imported locally because this cell does not import it itself; harmless if
    # an earlier cell already did.
    from tensorflow.keras.preprocessing.image import ImageDataGenerator

    data_gen_args = dict(
        rotation_range=20,
        width_shift_range=0.1,
        height_shift_range=0.1,
        zoom_range=0.2,
        horizontal_flip=True,
        fill_mode='nearest'
    )

    image_datagen = ImageDataGenerator(**data_gen_args)
    # Masks are warped with nearest-neighbour sampling (order 0). The default
    # linear interpolation would blend neighbouring labels into fractional
    # values along class borders.
    mask_datagen = ImageDataGenerator(interpolation_order=0, **data_gen_args)

    # The shared seed keeps shuffling and random transforms aligned between the two streams.
    seed = 42
    image_generator = image_datagen.flow(X, batch_size=batch_size, seed=seed)
    mask_generator = mask_datagen.flow(y, batch_size=batch_size, seed=seed)

    while True:
        X_batch = next(image_generator)
        y_batch = next(mask_generator)
        yield X_batch, y_batch

# Endless augmented training stream over the training arrays, 16 samples per batch.
train_generator = create_train_generator(X_train, y_train, batch_size=16)

In [None]:
import tensorflow as tf
import numpy as np
from scipy import stats

def combined_loss(y_true, y_pred):
    """Sum of a soft Dice loss and `lovasz_softmax_loss`.

    The Dice term sums over axes 1 and 2 and averages the resulting ratios; a
    smoothing constant of 1e-6 guards against empty denominators.

    Args:
        y_true: Ground-truth masks (cast to float32).
        y_pred: Predictions with the same shape (cast to float32).

    Returns:
        Scalar loss tensor.

    NOTE(review): the Dice term consumes `y_pred` as given, while the Lovasz
    term receives `tf.nn.softmax(y_pred)`. If callers pass probabilities, the
    Lovasz term sees a second softmax; if they pass logits, Dice is computed on
    logits. Confirm which input `lovasz_softmax_loss` expects.
    """
    eps = 1e-6
    targets = tf.cast(y_true, tf.float32)
    preds = tf.cast(y_pred, tf.float32)

    spatial_axes = [1, 2]
    overlap = tf.reduce_sum(targets * preds, axis=spatial_axes)
    total = tf.reduce_sum(targets, axis=spatial_axes) + tf.reduce_sum(preds, axis=spatial_axes)

    dice_ratios = 1 - (2. * overlap + eps) / (total + eps)
    dice_term = tf.reduce_mean(dice_ratios)

    lovasz_term = lovasz_softmax_loss(targets, tf.nn.softmax(preds), ignore_background=False)
    return lovasz_term + dice_term


class HardVotingEnsemble(tf.keras.Model):
    """Per-pixel majority vote over several segmentation models.

    Each member's output is reduced to a class-index map with argmax. The most
    frequent index at every position is selected with `scipy.stats.mode` and
    returned one-hot encoded.

    Args:
        models: Sequence of models, each mapping a batch to [B, H, W, C] scores.
        num_classes: Depth of the returned one-hot encoding.
    """

    def __init__(self, models, num_classes):
        super(HardVotingEnsemble, self).__init__()
        self.models = models
        self.num_classes = num_classes

    def call(self, x, training=False):
        """Return the one-hot majority vote, shaped [B, H, W, num_classes]."""
        # One [B, H, W] class-index map per member, stacked directly on a
        # trailing "model" axis -> [B, H, W, N_models].
        member_votes = [
            tf.argmax(model(x, training=training), axis=-1)
            for model in self.models
        ]
        stacked_preds = tf.stack(member_votes, axis=-1)

        def compute_mode(votes):
            # Runs eagerly on NumPy arrays via tf.numpy_function.
            mode, _ = stats.mode(votes, axis=-1, keepdims=False)
            return mode.astype(np.int32)

        mode_preds = tf.numpy_function(
            func=compute_mode,
            inp=[stacked_preds],
            Tout=tf.int32
        )

        # numpy_function loses static shape information; restore the rank ([B, H, W]).
        mode_preds.set_shape([None, None, None])

        return tf.one_hot(mode_preds, depth=self.num_classes)  # [B, H, W, C]

# Majority-vote ensemble over the four teacher networks.
hard_ensemble = HardVotingEnsemble(
    models=[model_xception, model_segnet, model_inceptionresnetv2, model_efficientnetb4],
    num_classes=4
)

# Compiled so the loss and class-wise metrics are attached to the ensemble.
# NOTE(review): the ensemble outputs one-hot votes, so the optimizer is
# presumably unused and this object is only meant for evaluation - confirm.
hard_ensemble.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
    loss=combined_loss,
    metrics=class_wise_metrics(num_classes=4)
)

In [None]:
import tensorflow as tf
from tensorflow.keras import layers, Model, Input

# === Config ===
# Geometry and class count passed to the student builder at the end of this cell.
IMG_HEIGHT = 224  # input image height in pixels
IMG_WIDTH = 224  # input image width in pixels
CHANNELS = 3  # colour channels per input image
NUM_CLASSES = 4  # number of segmentation classes


# === Resize Layer ===
class ResizeLayer(tf.keras.layers.Layer):
    """Keras layer that bilinearly resizes its input to a fixed target size.

    Args:
        target_size: Value handed to `tf.image.resize` as the output size.
        **kwargs: Forwarded to the base `Layer` constructor.
    """

    def __init__(self, target_size, **kwargs):
        super().__init__(**kwargs)
        self.target_size = target_size

    def call(self, inputs):
        """Resize `inputs` to `self.target_size` using bilinear interpolation."""
        resized = tf.image.resize(inputs, self.target_size, method='bilinear')
        return resized

    def get_config(self):
        """Return the base layer config extended with `target_size`."""
        base_config = super().get_config()
        base_config['target_size'] = self.target_size
        return base_config


# === Conv Block ===
def conv_block(x, filters, kernel_size=3, strides=1):
    """Apply a bias-free, 'same'-padded Conv2D followed by BatchNormalization and ReLU.

    Args:
        x: Input feature-map tensor.
        filters: Number of output channels.
        kernel_size: Convolution kernel size.
        strides: Convolution stride.

    Returns:
        The activated feature map.
    """
    convolution = layers.Conv2D(filters, kernel_size, padding='same', strides=strides, use_bias=False)
    features = convolution(x)
    features = layers.BatchNormalization()(features)
    return layers.ReLU()(features)


# === MBConv Block ===
def mbconv_block(x, in_ch, out_ch, stride=1, expansion=4):
    """Inverted-bottleneck block: 1x1 expand -> 3x3 depthwise -> 1x1 project.

    Args:
        x: Input feature-map tensor.
        in_ch: Channel count used to size the expansion (`in_ch * expansion`)
            and to decide whether the skip connection applies.
        out_ch: Channel count produced by the projection convolution.
        stride: Stride of the depthwise convolution.
        expansion: Multiplier for the hidden channel count.

    Returns:
        The projected features, with the block input added when `stride == 1`
        and `in_ch == out_ch`.
    """
    block_input = x
    expanded_ch = in_ch * expansion

    # 1x1 expansion to the hidden width.
    y = layers.Conv2D(expanded_ch, 1, padding='same', use_bias=False)(block_input)
    y = layers.BatchNormalization()(y)
    y = layers.Activation('swish')(y)

    # 3x3 depthwise convolution; this is where any downsampling happens.
    y = layers.DepthwiseConv2D(3, strides=stride, padding='same', use_bias=False)(y)
    y = layers.BatchNormalization()(y)
    y = layers.Activation('swish')(y)

    # 1x1 projection, left without an activation.
    y = layers.Conv2D(out_ch, 1, padding='same', use_bias=False)(y)
    y = layers.BatchNormalization()(y)

    # The skip connection is only shape-compatible without striding or a channel change.
    skip_applies = stride == 1 and in_ch == out_ch
    if not skip_applies:
        return y
    return layers.Add()([block_input, y])


# === MLA Block ===
def mla_block(x, channels, stride=1):
    """Residual block with a depthwise-5x5 -> 1x1 -> ReLU branch.

    Despite the "MLA" name, no attention weights are computed here: the branch
    is a 5x5 depthwise convolution, a 1x1 convolution to `channels` and a ReLU,
    added to a residual path.

    Args:
        x: Input feature-map tensor.
        channels: Output channel count of the branch (and of the residual
            projection when one is needed).
        stride: When greater than 1, the input is average-pooled by this factor
            before both paths.

    Returns:
        Sum of the residual path and the convolutional branch.
    """
    if stride > 1:
        # Both paths start from the same downsampled tensor, so pool once and
        # share the result instead of running two identical pooling layers.
        x = layers.AveragePooling2D(pool_size=stride, strides=stride, padding='same')(x)

    residual = x
    if x.shape[-1] != channels:
        # Project the residual so its channel count matches the branch output.
        residual = layers.Conv2D(channels, 1, padding='same')(residual)

    attn = layers.DepthwiseConv2D(5, padding='same')(x)
    attn = layers.Conv2D(channels, 1, padding='same')(attn)
    attn = layers.ReLU()(attn)

    return layers.Add()([residual, attn])


# === EfficientViT Block ===
def efficientvit_block(x, in_ch, out_ch, stride=1):
    """Sum an `mbconv_block` branch and an `mla_block` branch fed by the same input.

    Args:
        x: Input feature-map tensor.
        in_ch: Input channel count passed to `mbconv_block`.
        out_ch: Output channel count of both branches.
        stride: Stride forwarded to both branches.

    Returns:
        Element-wise sum of the two branch outputs.
    """
    branches = [
        mbconv_block(x, in_ch, out_ch, stride=stride),
        mla_block(x, out_ch, stride=stride),
    ]
    return layers.Add()(branches)


# EfficientViT-B0 Encoder
def efficientvit_b0_encoder(inputs):
    """Five-stage encoder: a strided conv stem followed by four stride-2 blocks.

    Args:
        inputs: Input image tensor.

    Returns:
        List `[e0, e1, e2, e3, e4]` of feature maps with 16, 32, 64, 96 and 128
        channels. Every stage halves the resolution (112, 56, 28, 14 and 7
        pixels per side for a 224x224 input).
    """
    # Stem: stride-2 3x3 convolution -> BN -> swish.
    stem = layers.Conv2D(16, 3, strides=2, padding='same')(inputs)
    stem = layers.BatchNormalization()(stem)
    stem = layers.Activation('swish')(stem)

    features = [stem]
    # (input channels, output channels) of each stride-2 stage, shallow to deep.
    stage_channels = ((16, 32), (32, 64), (64, 96), (96, 128))
    for in_ch, out_ch in stage_channels:
        features.append(efficientvit_block(features[-1], in_ch, out_ch, stride=2))

    return features


def attention_gate(x, g, inter_channels):
    """Additive attention gate: scale skip features `x` by a mask derived from `x` and `g`.

    Args:
        x: Skip-connection features to be gated.
        g: Gating features; resized to the spatial size of `x` when they differ.
        inter_channels: Channel count of the intermediate 1x1 projections.

    Returns:
        `x` multiplied by a single-channel sigmoid mask.
    """
    skip_h, skip_w = tf.keras.backend.int_shape(x)[1:3]
    gate_h, gate_w = tf.keras.backend.int_shape(g)[1:3]

    # Bring the gating signal onto the skip connection's spatial grid.
    if (skip_h, skip_w) != (gate_h, gate_w):
        g = ResizeLayer((skip_h, skip_w))(g)

    # Project both inputs to a common width, add them and squash to one channel.
    skip_proj = layers.Conv2D(inter_channels, 1, padding='same')(x)
    gate_proj = layers.Conv2D(inter_channels, 1, padding='same')(g)
    combined = layers.Add()([skip_proj, gate_proj])
    combined = layers.Activation('relu')(combined)
    mask_logits = layers.Conv2D(1, 1, padding='same')(combined)
    mask = layers.Activation('sigmoid')(mask_logits)

    return layers.Multiply()([x, mask])


def build_attention_unet_with_efficientvit(input_shape=(224, 224, 3), num_classes=4):
    """Build an attention U-Net decoder on top of `efficientvit_b0_encoder`.

    Args:
        input_shape: (height, width, channels) of the input image.
        num_classes: Number of channels of the softmax output.

    Returns:
        A Keras `Model` mapping the input image to per-pixel class probabilities
        at the input resolution.
    """
    inputs = Input(shape=input_shape)

    # Encoder features from shallowest (e0) to the bottleneck (e4).
    e0, e1, e2, e3, e4 = efficientvit_b0_encoder(inputs)

    # Decoder stages pair an output width with the skip connection it consumes,
    # walking from the deepest skip to the shallowest.
    decoder_stages = zip([96, 64, 32, 16], [e3, e2, e1, e0])
    decoded = e4

    for width, skip in decoder_stages:
        # Learned 2x upsampling, then BN -> ReLU -> dropout.
        decoded = layers.Conv2DTranspose(width, 3, strides=2, padding='same')(decoded)
        decoded = layers.BatchNormalization()(decoded)
        decoded = layers.Activation('relu')(decoded)
        decoded = layers.Dropout(0.2)(decoded)

        gated_skip = attention_gate(skip, decoded, width // 2)

        # Align spatial sizes before concatenating, in case upsampling overshot.
        gated_hw = tf.keras.backend.int_shape(gated_skip)[1:3]
        if gated_hw != tf.keras.backend.int_shape(decoded)[1:3]:
            decoded = ResizeLayer((gated_skip.shape[1], gated_skip.shape[2]))(decoded)

        decoded = layers.Concatenate()([decoded, gated_skip])
        for _ in range(2):
            decoded = conv_block(decoded, width)

    # The shallowest skip is at half resolution, so one more 2x upsampling is needed.
    decoded = layers.Conv2DTranspose(64, 3, strides=2, padding='same')(decoded)
    for width in (64, 32):
        decoded = conv_block(decoded, width)

    full_hw = (input_shape[0], input_shape[1])
    if tf.keras.backend.int_shape(decoded)[1:3] != full_hw:
        decoded = ResizeLayer(full_hw)(decoded)

    outputs = layers.Conv2D(num_classes, 1, activation='softmax')(decoded)
    return Model(inputs, outputs)

# Reset Keras' global state, then build the student from this cell's config
# constants and print its layer summary.
tf.keras.backend.clear_session()
student_model = build_attention_unet_with_efficientvit(
    input_shape=(IMG_HEIGHT, IMG_WIDTH, CHANNELS),
    num_classes=NUM_CLASSES
)
student_model.summary()

In [None]:
import tensorflow as tf
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau, ModelCheckpoint

# Ensemble members. The order matters: ensemble weights are matched to models by position.
models = [
    model_xception,
    model_segnet,
    model_inceptionresnetv2,
    model_efficientnetb4
]

class WeightedSoftVotingEnsemble(tf.keras.Model):
    """Weighted average of several models' class probabilities.

    `call` returns the one-hot argmax of the average (convenient for metrics);
    `average_probabilities` returns the average itself, e.g. for use as soft
    distillation targets.

    Args:
        models: Non-empty sequence of member models.
        weights: Optional per-model weights, matched to `models` by position
            and normalised to sum to 1. Uniform when omitted.
        apply_softmax: When True, a softmax is applied to every member whose
            name does not contain "efficientnet"; members whose name does are
            taken to already output probabilities.
            NOTE(review): confirm that the remaining members really emit logits.

    Raises:
        ValueError: If `models` is empty, if `weights` has a different length
            than `models`, or if `weights` sums to zero.
    """

    def __init__(self, models, weights=None, apply_softmax=True):
        super(WeightedSoftVotingEnsemble, self).__init__()
        if len(models) == 0:
            raise ValueError("WeightedSoftVotingEnsemble needs at least one model")

        self.models = models
        self.apply_softmax = apply_softmax

        if weights is None:
            weights = [1.0 / len(models)] * len(models)
        else:
            if len(weights) != len(models):
                raise ValueError(
                    f"Expected {len(models)} weights (one per model), got {len(weights)}"
                )
            total = sum(weights)
            if total == 0:
                raise ValueError("Ensemble weights must not sum to zero")
            weights = [w / total for w in weights]

        self.model_weights = tf.constant(weights, dtype=tf.float32)

    def average_probabilities(self, x, training=False):
        """Return the weighted average of the members' probabilities, [B, H, W, C]."""
        weighted_sum = 0
        for i, model in enumerate(self.models):
            output = model(x, training=training)

            # Name-based heuristic: these members are treated as already softmaxed.
            is_softmaxed = (
                hasattr(model, "name") and "efficientnet" in model.name.lower()
            )

            if self.apply_softmax and not is_softmaxed:
                probs = tf.nn.softmax(output, axis=-1)
            else:
                probs = output

            weighted_sum += self.model_weights[i] * probs

        return weighted_sum

    def call(self, x, training=False):
        """Return the one-hot encoded argmax of the averaged probabilities, [B, H, W, C]."""
        avg_prob = self.average_probabilities(x, training=training)
        # One-hot output keeps the ensemble compatible with the one-hot based metrics.
        return tf.one_hot(tf.argmax(avg_prob, axis=-1), depth=avg_prob.shape[-1])

# Per-model ensemble weights, in the same order as `models`; the constructor
# normalises them to sum to 1.
final_weights = [0.255, 0.2427, 0.2515, 0.2508]
ensemble_model = WeightedSoftVotingEnsemble(
    models=models,
    weights=final_weights,
    apply_softmax=True
)

# Attach the loss and the class-wise metrics to the ensemble.
ensemble_model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
    loss=combined_loss,
    metrics=class_wise_metrics(num_classes=4)
)

# The ensemble acts as the distillation teacher.
# NOTE(review): calling the ensemble returns one-hot predictions, so a student
# distilled from `teacher_model(x)` sees hard targets, not soft probabilities.
teacher_model = ensemble_model 

def distillation_loss(y_true, y_student_logits, y_teacher_probs, alpha=0.5, temperature=3.0):
    """Blend a hard-label loss with a temperature-softened KL term.

    Returns `alpha * combined_loss(y_true, y_student_logits)
    + (1 - alpha) * temperature**2 * KL(teacher_soft || student_soft)`, where
    both soft distributions are `softmax(input / temperature)`.

    Args:
        y_true: Ground-truth masks.
        y_student_logits: Student output.
        y_teacher_probs: Teacher output.
        alpha: Weight of the hard-label term.
        temperature: Softening temperature of the KL term.

    NOTE(review): student and teacher outputs receive the same
    divide-then-softmax treatment regardless of whether they are logits or
    probabilities, and the hard-label term receives the student output as
    passed in. Confirm what the callers actually provide.
    """
    def _soften(values):
        return tf.nn.softmax(values / temperature)

    student_soft = _soften(y_student_logits)
    teacher_soft = _soften(y_teacher_probs)

    # Soft term: divergence of the student's softened distribution from the teacher's.
    soft_term = tf.keras.losses.KLDivergence()(teacher_soft, student_soft)

    # Hard term: Dice + Lovasz against the ground truth.
    hard_term = combined_loss(y_true, y_student_logits)

    soft_weight = (1 - alpha) * (temperature ** 2)
    return alpha * hard_term + soft_weight * soft_term

# === KD Wrapper Model ===
class KDTrainer(tf.keras.Model):
    """Keras wrapper that trains `student` with `distillation_loss` against `teacher`.

    Args:
        student: Model being optimised.
        teacher: Model called with `training=False` to produce distillation targets.
        alpha: Weight of the hard-label term passed to `distillation_loss`.
        temperature: Softening temperature passed to `distillation_loss`.
    """

    def __init__(self, student, teacher, alpha=0.5, temperature=3.0):
        super(KDTrainer, self).__init__()
        self.student = student
        self.teacher = teacher
        self.alpha = alpha
        self.temperature = temperature

    def compile(self, optimizer, metrics):
        """Store the optimizer and the metric objects updated by the custom steps."""
        super().compile()
        self.optimizer = optimizer
        self.metrics_list = metrics

    def train_step(self, data):
        """Run one distillation step and return metric results plus the loss."""
        x, y_true = data
        y_true = tf.cast(y_true, tf.float32)

        # Only the student's variables are differentiated, so the teacher's
        # forward pass is kept outside the tape and none of its ops are recorded.
        # NOTE(review): the targets are whatever the teacher's call returns;
        # confirm they are soft probabilities and not hard one-hot predictions.
        teacher_probs = self.teacher(x, training=False)

        with tf.GradientTape() as tape:
            student_logits = self.student(x, training=True)               # [B, H, W, C]

            loss = distillation_loss(
                y_true, student_logits, teacher_probs,
                alpha=self.alpha, temperature=self.temperature
            )

        grads = tape.gradient(loss, self.student.trainable_variables)
        self.optimizer.apply_gradients(zip(grads, self.student.trainable_variables))

        for metric in self.metrics_list:
            metric.update_state(y_true, student_logits)

        return {m.name: m.result() for m in self.metrics_list} | {"loss": loss}

    def test_step(self, data):
        """Evaluate the student alone; the reported loss is `combined_loss` only."""
        x, y_true = data
        y_true = tf.cast(y_true, tf.float32)
        y_pred = self.student(x, training=False)
        loss = combined_loss(y_true, y_pred)

        for metric in self.metrics_list:
            metric.update_state(y_true, y_pred)

        return {m.name: m.result() for m in self.metrics_list} | {"loss": loss}

# === Instantiate KDTrainer ===
# Equal weighting (alpha=0.5) of the hard-label term and the distillation term,
# with a softening temperature of 3.
kd_model = KDTrainer(
    student=student_model,
    teacher=teacher_model,
    alpha=0.5,
    temperature=3.0
)

# === Compile ===
# KDTrainer.compile only takes the optimizer and the metric objects; the loss
# is computed inside its train/test steps.
kd_model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
    metrics=class_wise_metrics(num_classes=4)
)

from datetime import datetime

# A timestamped path so each run writes its own checkpoint.
# NOTE(review): the file name says "unetplusplus", but `student_model` in this
# notebook is built by build_attention_unet_with_efficientvit - consider renaming.
timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
checkpoint_path = f"best_student_unetplusplus_{timestamp}"

callbacks = [
    # Stop after 10 epochs without val_loss improvement and restore the best weights.
    EarlyStopping(
        monitor='val_loss',
        patience=10,
        restore_best_weights=True
    ),
    # Halve the learning rate after 3 stagnant epochs, down to a floor of 1e-6.
    ReduceLROnPlateau(
        monitor='val_loss',
        factor=0.5,
        patience=3,
        min_lr=1e-6
    ),
    # Save only the weights of the best epoch. `save_format` is not a
    # documented ModelCheckpoint argument, so it is no longer passed.
    # NOTE(review): confirm the weights-file naming the installed Keras version expects.
    ModelCheckpoint(
        filepath=checkpoint_path,
        monitor='val_loss',
        save_best_only=True,
        save_weights_only=True
    )
]


from tensorflow.keras.preprocessing.image import ImageDataGenerator

def create_train_generator(X, y, batch_size=16):
    """Yield an endless stream of identically augmented float32 (image, mask) batches.

    Two `ImageDataGenerator`s with the same geometric settings and the same
    seed are advanced in lock-step, so each mask batch receives the same random
    transform as its image batch.

    Args:
        X: Array of training images.
        y: Array of training masks, aligned with `X` along the first axis
            (assumed one-hot encoded - TODO confirm against the preprocessing cell).
        batch_size: Number of samples per yielded batch.

    Yields:
        Tuples `(X_batch, y_batch)`, both cast to float32.
    """
    data_gen_args = dict(
        rotation_range=20,
        width_shift_range=0.1,
        height_shift_range=0.1,
        zoom_range=0.2,
        horizontal_flip=True,
        fill_mode='nearest'
    )

    image_datagen = ImageDataGenerator(**data_gen_args)
    # Masks are warped with nearest-neighbour sampling (order 0). The default
    # linear interpolation would blend neighbouring labels into fractional
    # values along class borders.
    mask_datagen = ImageDataGenerator(interpolation_order=0, **data_gen_args)

    # The shared seed keeps shuffling and random transforms aligned between the two streams.
    seed = 42
    image_generator = image_datagen.flow(X, batch_size=batch_size, seed=seed)
    mask_generator = mask_datagen.flow(y, batch_size=batch_size, seed=seed)

    while True:
        X_batch = next(image_generator)
        y_batch = next(mask_generator)
        yield X_batch.astype('float32'), y_batch.astype('float32')

batch_size = 16
train_generator = create_train_generator(X_train, y_train, batch_size=batch_size)
# The generator never ends, so Keras needs an explicit epoch length.
# Integer division drops the final partial batch from the count.
steps_per_epoch = len(X_train) // batch_size


# Distill for up to 100 epochs; the callbacks may stop the run earlier.
history = kd_model.fit(
    train_generator,
    steps_per_epoch=steps_per_epoch,
    validation_data=(X_val, y_val),
    epochs=100,
    callbacks=callbacks
)

In [None]:
import tensorflow as tf
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau, ModelCheckpoint

# === Ensemble setup ===
# Ensemble members. The order matters: ensemble weights are matched to models by position.
models = [
    model_xception,
    model_segnet,
    model_inceptionresnetv2,
    model_efficientnetb4
]

class WeightedSoftVotingEnsemble(tf.keras.Model):
    """Weighted average of several models' class probabilities.

    `call` returns the one-hot argmax of the average (convenient for metrics);
    `average_probabilities` returns the average itself, e.g. for use as soft
    distillation targets.

    Args:
        models: Non-empty sequence of member models.
        weights: Optional per-model weights, matched to `models` by position
            and normalised to sum to 1. Uniform when omitted.
        apply_softmax: When True, a softmax is applied to every member whose
            name does not contain "efficientnet"; members whose name does are
            taken to already output probabilities.
            NOTE(review): confirm that the remaining members really emit logits.

    Raises:
        ValueError: If `models` is empty, if `weights` has a different length
            than `models`, or if `weights` sums to zero.
    """

    def __init__(self, models, weights=None, apply_softmax=True):
        super(WeightedSoftVotingEnsemble, self).__init__()
        if len(models) == 0:
            raise ValueError("WeightedSoftVotingEnsemble needs at least one model")

        self.models = models
        self.apply_softmax = apply_softmax

        if weights is None:
            weights = [1.0 / len(models)] * len(models)
        else:
            if len(weights) != len(models):
                raise ValueError(
                    f"Expected {len(models)} weights (one per model), got {len(weights)}"
                )
            total = sum(weights)
            if total == 0:
                raise ValueError("Ensemble weights must not sum to zero")
            weights = [w / total for w in weights]

        self.model_weights = tf.constant(weights, dtype=tf.float32)

    def average_probabilities(self, x, training=False):
        """Return the weighted average of the members' probabilities, [B, H, W, C]."""
        weighted_sum = 0
        for i, model in enumerate(self.models):
            output = model(x, training=training)

            # Name-based heuristic: these members are treated as already softmaxed.
            is_softmaxed = (
                hasattr(model, "name") and "efficientnet" in model.name.lower()
            )

            if self.apply_softmax and not is_softmaxed:
                probs = tf.nn.softmax(output, axis=-1)
            else:
                probs = output

            weighted_sum += self.model_weights[i] * probs

        return weighted_sum

    def call(self, x, training=False):
        """Return the one-hot encoded argmax of the averaged probabilities, [B, H, W, C]."""
        avg_prob = self.average_probabilities(x, training=training)
        # Note: this is a hard one-hot prediction, not a soft distribution.
        return tf.one_hot(tf.argmax(avg_prob, axis=-1), depth=avg_prob.shape[-1])

# Per-model ensemble weights, in the same order as `models`; the constructor
# normalises them to sum to 1.
final_weights = [0.255, 0.2427, 0.2515, 0.2508]
ensemble_model = WeightedSoftVotingEnsemble(models=models, weights=final_weights, apply_softmax=True)

# === SOFT OUTPUT EXTRACTOR (helper workaround) ===
def get_teacher_soft_output(ensemble_model, x):
    """Return the ensemble's weighted-average probabilities without the one-hot step.

    Mirrors the averaging done inside the ensemble's `call` (members run with
    `training=False`) but returns the soft average instead of its argmax.

    Args:
        ensemble_model: Object exposing `models`, `model_weights` and `apply_softmax`.
        x: Input batch fed to every member.

    Returns:
        Weighted sum of the members' probabilities.
    """
    def _member_probs(member):
        raw = member(x, training=False)
        # Members named "...efficientnet..." are treated as already softmaxed.
        already_softmaxed = hasattr(member, "name") and "efficientnet" in member.name.lower()
        if ensemble_model.apply_softmax and not already_softmaxed:
            return tf.nn.softmax(raw, axis=-1)
        return raw

    blended = 0
    for position, member in enumerate(ensemble_model.models):
        probs = _member_probs(member)
        blended += ensemble_model.model_weights[position] * probs

    return blended

# === Distillation Loss ===
def distillation_loss(y_true, y_student_logits, y_teacher_probs, alpha=0.2, temperature=5.0):
    """Blend a hard-label loss with a temperature-softened KL term.

    Returns `alpha * combined_loss(y_true, softmax(y_student_logits))
    + (1 - alpha) * temperature**2 * KL(teacher_soft || student_soft)`, where
    both soft distributions are `softmax(input / temperature)`.

    Args:
        y_true: Ground-truth masks.
        y_student_logits: Student output.
        y_teacher_probs: Teacher output.
        alpha: Weight of the hard-label term.
        temperature: Softening temperature of the KL term.

    NOTE(review): the teacher input is divided by the temperature and
    softmaxed exactly like the student input, even though it is named as
    probabilities. Confirm what the callers actually provide.
    """
    inv_scaled_student = y_student_logits / temperature
    inv_scaled_teacher = y_teacher_probs / temperature
    student_soft = tf.nn.softmax(inv_scaled_student)
    teacher_soft = tf.nn.softmax(inv_scaled_teacher)

    kl_term = tf.keras.losses.KLDivergence()(teacher_soft, student_soft)

    # Hard-label term: Dice + Lovasz on the student's softmax output.
    # (A plain categorical cross-entropy is a possible alternative here.)
    student_probs = tf.nn.softmax(y_student_logits)
    hard_term = combined_loss(y_true, student_probs)

    kl_weight = (1 - alpha) * (temperature ** 2)
    return alpha * hard_term + kl_weight * kl_term

# === KD Wrapper ===
class KDTrainer(tf.keras.Model):
    """Keras wrapper that trains `student` with `distillation_loss` against an ensemble teacher.

    Teacher targets come from `get_teacher_soft_output`, i.e. the ensemble's
    weighted-average probabilities rather than its one-hot `call` output.

    Args:
        student: Model being optimised.
        teacher: Ensemble passed to `get_teacher_soft_output`.
        alpha: Weight of the hard-label term passed to `distillation_loss`.
        temperature: Softening temperature passed to `distillation_loss`.

    NOTE(review): the student output is passed through `tf.nn.softmax` before
    metrics and the validation loss. The student builders visible in this
    notebook already end in a softmax layer - confirm the student emits logits.
    """

    def __init__(self, student, teacher, alpha=0.2, temperature=5.0):
        super(KDTrainer, self).__init__()
        self.student = student
        self.teacher = teacher
        self.alpha = alpha
        self.temperature = temperature

    def compile(self, optimizer, metrics):
        """Store the optimizer and the metric objects updated by the custom steps."""
        super().compile()
        self.optimizer = optimizer
        self.metrics_list = metrics

    def train_step(self, data):
        """Run one distillation step and return metric results plus the loss."""
        x, y_true = data
        y_true = tf.cast(y_true, tf.float32)

        # Only the student's variables are differentiated, so the teacher's
        # forward pass is kept outside the tape and none of its ops are recorded.
        teacher_probs = get_teacher_soft_output(self.teacher, x)

        with tf.GradientTape() as tape:
            student_logits = self.student(x, training=True)

            loss = distillation_loss(
                y_true, student_logits, teacher_probs,
                alpha=self.alpha, temperature=self.temperature
            )

        grads = tape.gradient(loss, self.student.trainable_variables)
        self.optimizer.apply_gradients(zip(grads, self.student.trainable_variables))

        student_soft = tf.nn.softmax(student_logits)
        for metric in self.metrics_list:
            metric.update_state(y_true, student_soft)

        return {m.name: m.result() for m in self.metrics_list} | {"loss": loss}

    def test_step(self, data):
        """Evaluate the student alone; the reported loss is `combined_loss` only."""
        x, y_true = data
        y_true = tf.cast(y_true, tf.float32)
        student_logits = self.student(x, training=False)
        student_soft = tf.nn.softmax(student_logits)
        loss = combined_loss(y_true, student_soft)

        for metric in self.metrics_list:
            metric.update_state(y_true, student_soft)

        return {m.name: m.result() for m in self.metrics_list} | {"loss": loss}

# === Build and compile the KD model ===
# alpha=0.2 puts most of the weight on the distillation term; temperature=5
# softens both distributions in that term.
kd_model = KDTrainer(
    student=student_model,
    teacher=ensemble_model,
    alpha=0.2,
    temperature=5.0
)

# KDTrainer.compile only takes the optimizer and the metric objects; the loss
# is computed inside its train/test steps.
kd_model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
    metrics=class_wise_metrics(num_classes=4)
)

# === Callbacks ===
from datetime import datetime
# A timestamped path so each run writes its own checkpoint.
# NOTE(review): the file name says "unetplusplus", but `student_model` in this
# notebook is built by build_attention_unet_with_efficientvit - consider renaming.
timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
checkpoint_path = f"best_student_unetplusplus_{timestamp}"

# All three callbacks watch val_loss. `save_format` is not a documented
# ModelCheckpoint argument, so it is no longer passed.
# NOTE(review): confirm the weights-file naming the installed Keras version expects.
callbacks = [
    EarlyStopping(monitor='val_loss', patience=10, restore_best_weights=True),
    ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=3, min_lr=1e-6),
    ModelCheckpoint(filepath=checkpoint_path, monitor='val_loss', save_best_only=True,
                    save_weights_only=True)
]

# === Data Generator ===
from tensorflow.keras.preprocessing.image import ImageDataGenerator

def create_train_generator(X, y, batch_size=16):
    """Yield an endless stream of identically augmented float32 (image, mask) batches.

    Two `ImageDataGenerator`s with the same geometric settings and the same
    seed are advanced in lock-step, so each mask batch receives the same random
    transform as its image batch.

    Args:
        X: Array of training images.
        y: Array of training masks, aligned with `X` along the first axis
            (assumed one-hot encoded - TODO confirm against the preprocessing cell).
        batch_size: Number of samples per yielded batch.

    Yields:
        Tuples `(X_batch, y_batch)`, both cast to float32.
    """
    data_gen_args = dict(
        rotation_range=20,
        width_shift_range=0.1,
        height_shift_range=0.1,
        zoom_range=0.2,
        horizontal_flip=True,
        fill_mode='nearest'
    )
    image_gen = ImageDataGenerator(**data_gen_args)
    # Masks are warped with nearest-neighbour sampling (order 0). The default
    # linear interpolation would blend neighbouring labels into fractional
    # values along class borders.
    mask_gen = ImageDataGenerator(interpolation_order=0, **data_gen_args)

    # The shared seed keeps shuffling and random transforms aligned between the two streams.
    seed = 42
    image_generator = image_gen.flow(X, batch_size=batch_size, seed=seed)
    mask_generator = mask_gen.flow(y, batch_size=batch_size, seed=seed)

    while True:
        X_batch = next(image_generator)
        y_batch = next(mask_generator)
        yield X_batch.astype('float32'), y_batch.astype('float32')

# === Train ===
batch_size = 16
train_generator = create_train_generator(X_train, y_train, batch_size=batch_size)
# The generator never ends, so Keras needs an explicit epoch length.
# Integer division drops the final partial batch from the count.
steps_per_epoch = len(X_train) // batch_size

# Distill for up to 100 epochs; the callbacks may stop the run earlier.
history = kd_model.fit(
    train_generator,
    steps_per_epoch=steps_per_epoch,
    validation_data=(X_val, y_val),
    epochs=100,
    callbacks=callbacks
)

In [None]:
import tensorflow as tf
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau, ModelCheckpoint

# Ensemble members. The order matters: ensemble weights are matched to models by position.
models = [
    model_xception,
    model_segnet,
    model_inceptionresnetv2,
    model_efficientnetb4
]

class WeightedSoftVotingEnsemble(tf.keras.Model):
    """Combine several models by a weighted sum of their class scores.

    Every member is called on the same input. Its output is optionally passed
    through a softmax, scaled by the member's normalised weight and summed.
    The argmax of that sum over the last axis is returned as a one-hot tensor.
    """

    def __init__(self, models, weights=None, apply_softmax=True):
        """Store the members and their normalised voting weights.

        Args:
            models: Non-empty sequence of callable models sharing one output shape.
            weights: Optional sequence with one non-negative weight per model.
                It is rescaled to sum to 1. ``None`` means equal weights.
            apply_softmax: When true, a softmax over the last axis is applied to
                every member output except those detected as already softmaxed
                (see ``call``).

        Raises:
            ValueError: If ``models`` is empty, if ``weights`` does not have one
                entry per model, or if the weights do not sum to a positive value.
        """
        super(WeightedSoftVotingEnsemble, self).__init__()
        if len(models) == 0:
            raise ValueError("WeightedSoftVotingEnsemble needs at least one model.")
        self.models = models
        self.apply_softmax = apply_softmax

        if weights is None:
            weights = [1.0 / len(models)] * len(models)
        else:
            # A mismatch would otherwise surface later as an indexing error
            # inside call(), or silently ignore surplus weights.
            if len(weights) != len(models):
                raise ValueError(
                    f"Expected {len(models)} weights (one per model), got {len(weights)}."
                )
            total = sum(weights)
            if total <= 0:
                raise ValueError("Ensemble weights must sum to a positive value.")
            weights = [w / total for w in weights]

        self.model_weights = tf.constant(weights, dtype=tf.float32)

    def call(self, x, training=False):
        """Return the one-hot encoded argmax of the weighted member scores.

        NOTE(review): the result is a hard one-hot tensor, not a probability
        map. Code that needs soft teacher probabilities (e.g. for distillation)
        does not get them from this method.
        """
        weighted_sum = 0
        for i, model in enumerate(self.models):
            output = model(x, training=training)

            # Members whose name contains "efficientnet" are treated as already
            # emitting probabilities and are not softmaxed again.
            is_softmaxed = (
                hasattr(model, "name") and "efficientnet" in model.name.lower()
            )

            if self.apply_softmax and not is_softmaxed:
                probs = tf.nn.softmax(output, axis=-1)
            else:
                probs = output

            weighted_sum += self.model_weights[i] * probs

        avg_prob = weighted_sum  # shape: [B, H, W, C]

        # Convert to one-hot for metric compatibility
        one_hot_pred = tf.one_hot(tf.argmax(avg_prob, axis=-1), depth=avg_prob.shape[-1])
        return one_hot_pred  # [B, H, W, C]

# Relative voting weights, one per entry of `models` (same order). The
# ensemble constructor rescales them to sum to 1.
final_weights = [0.255, 0.2427, 0.2515, 0.2508]
ensemble_model = WeightedSoftVotingEnsemble(
    models=models,
    weights=final_weights,
    apply_softmax=True
)

# Attach the loss and metrics that ensemble_model.evaluate() will report.
ensemble_model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
    loss=combined_loss,
    metrics=class_wise_metrics(num_classes=4)
)

# Alias under which the distillation code refers to the ensemble.
teacher_model = ensemble_model 

def distillation_loss(y_true, y_student_logits, y_teacher_probs, alpha=0.5, temperature=3.0):
    """Blend a supervised loss with a temperature-scaled KL distillation term.

    Args:
        y_true: Ground-truth targets passed to the supervised losses.
        y_student_logits: Student model output.
        y_teacher_probs: Teacher model output.
        alpha: Weight of the supervised term; the KL term gets ``1 - alpha``.
        temperature: Divisor applied to both outputs before the softmax. The KL
            term is multiplied by ``temperature ** 2``.

    Returns:
        ``alpha * hard + (1 - alpha) * temperature**2 * KL``.

    NOTE(review): a softmax is applied to ``y_teacher_probs`` although its name
    suggests probabilities, and ``CategoricalCrossentropy()`` is used with its
    default settings on ``y_student_logits``. Confirm what the student and the
    teacher actually emit (logits or probabilities).
    """
    # Supervised ("hard") part: the notebook's combined_loss (described in the
    # original comment as Dice + Lovasz) plus categorical cross-entropy.
    hard_term = combined_loss(y_true, y_student_logits) + tf.keras.losses.CategoricalCrossentropy()(y_true, y_student_logits)

    # Distillation ("soft") part: KL divergence between the softened outputs.
    softened_teacher = tf.nn.softmax(y_teacher_probs / temperature)
    softened_student = tf.nn.softmax(y_student_logits / temperature)
    soft_term = tf.keras.losses.KLDivergence()(softened_teacher, softened_student)

    soft_scale = (1 - alpha) * (temperature ** 2)
    return alpha * hard_term + soft_scale * soft_term

# === KD Wrapper Model ===
class KDTrainer(tf.keras.Model):
    """Keras wrapper that trains a student model against a teacher model.

    Training minimises ``distillation_loss``. Validation and evaluation report
    ``combined_loss`` on the student alone, so the logged ``loss`` and
    ``val_loss`` are not on the same scale.
    """

    def __init__(self, student, teacher, alpha=0.5, temperature=3.0):
        """
        Args:
            student: Model whose trainable variables are updated.
            teacher: Model called with ``training=False`` to produce targets.
            alpha: Weight of the supervised term in ``distillation_loss``.
            temperature: Softmax temperature forwarded to ``distillation_loss``.
        """
        super(KDTrainer, self).__init__()
        self.student = student
        self.teacher = teacher
        self.alpha = alpha
        self.temperature = temperature
        # Running mean of the per-batch loss. Keras logs whatever the last
        # step returned, so returning the raw batch tensor would make
        # `loss`/`val_loss` (monitored by the callbacks) the value of the final
        # batch only instead of an epoch-level average.
        self.loss_tracker = tf.keras.metrics.Mean(name="loss")
        self.metrics_list = []

    @property
    def metrics(self):
        """Metrics Keras resets at the start of each epoch / evaluation run."""
        return [self.loss_tracker, *self.metrics_list]

    def compile(self, optimizer, metrics):
        """Store the optimizer and the metric objects updated in each step.

        Args:
            optimizer: Keras optimizer applied to the student's variables.
            metrics: List of metric objects exposing ``update_state``/``result``.
        """
        super().compile()
        self.optimizer = optimizer
        self.metrics_list = metrics

    def train_step(self, data):
        """Run one optimisation step on the student and return the running logs."""
        x, y_true = data
        y_true = tf.cast(y_true, tf.float32)

        # The teacher is never updated, so its forward pass stays outside the
        # tape to avoid recording operations that are not differentiated.
        # NOTE(review): the WeightedSoftVotingEnsemble defined in this notebook
        # returns one-hot argmax predictions, not soft probabilities.
        teacher_probs = self.teacher(x, training=False)

        with tf.GradientTape() as tape:
            student_logits = self.student(x, training=True)               # [B, H, W, C]

            loss = distillation_loss(
                y_true, student_logits, teacher_probs,
                alpha=self.alpha, temperature=self.temperature
            )

        grads = tape.gradient(loss, self.student.trainable_variables)
        self.optimizer.apply_gradients(zip(grads, self.student.trainable_variables))

        self.loss_tracker.update_state(loss)
        for metric in self.metrics_list:
            metric.update_state(y_true, student_logits)

        return {m.name: m.result() for m in self.metrics}

    def test_step(self, data):
        """Evaluate the student on one batch using ``combined_loss`` only."""
        x, y_true = data
        y_true = tf.cast(y_true, tf.float32)
        y_pred = self.student(x, training=False)
        loss = combined_loss(y_true, y_pred)

        self.loss_tracker.update_state(loss)
        for metric in self.metrics_list:
            metric.update_state(y_true, y_pred)

        return {m.name: m.result() for m in self.metrics}

# === Instantiate KDTrainer ===
kd_model = KDTrainer(
    student=student_model,
    teacher=teacher_model,
    alpha=0.5,
    temperature=3.0
)

# === Compile ===
# KDTrainer.compile takes only an optimizer and a list of metric objects; the
# losses are computed inside its train_step/test_step.
kd_model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
    metrics=class_wise_metrics(num_classes=4)
)

from datetime import datetime

# Timestamped checkpoint prefix, so each run writes to its own path.
timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
checkpoint_path = f"best_student_unetplusplus_{timestamp}"

# All three callbacks watch the logged 'val_loss' value.
callbacks = [
    EarlyStopping(
        monitor='val_loss',
        patience=10,
        restore_best_weights=True
    ),
    ReduceLROnPlateau(
        monitor='val_loss',
        factor=0.5,
        patience=3,
        min_lr=1e-6
    ), 
    ModelCheckpoint(
    filepath=checkpoint_path,
    monitor='val_loss',
    save_best_only=True,
    save_weights_only=True,
    save_format='tf'  # NOTE(review): save_weights_only=True stores weights only; confirm the installed ModelCheckpoint honours `save_format`
    )
]


from tensorflow.keras.preprocessing.image import ImageDataGenerator

def create_train_generator(X, y, batch_size=16):
    """Yield an endless stream of jointly augmented ``(images, masks)`` batches.

    Two ``ImageDataGenerator`` objects with identical geometric settings are
    given the same seed, which is the usual Keras recipe for applying the same
    shuffle order and random transform to an image and its mask.

    Args:
        X: Array of input images, indexed by sample along the first axis.
        y: Array of masks aligned with ``X`` (assumed to hold per-pixel class
            labels such as one-hot channels -- TODO confirm against the caller).
        batch_size: Number of samples per yielded batch.

    Yields:
        ``(X_batch, y_batch)`` tuples, both cast to ``float32``.
    """
    augmentation = dict(
        rotation_range=20,
        width_shift_range=0.1,
        height_shift_range=0.1,
        zoom_range=0.2,
        horizontal_flip=True,
        fill_mode='nearest'
    )

    image_datagen = ImageDataGenerator(**augmentation)
    mask_datagen = ImageDataGenerator(**augmentation)
    # Masks carry labels, not intensities: resample them with nearest-neighbour
    # (spline order 0). The default order 1 blends neighbouring classes into
    # fractional values whenever a rotation/zoom/shift is applied. The attribute
    # is set after construction in case the installed tf.keras release does not
    # accept `interpolation_order` as a constructor argument.
    mask_datagen.interpolation_order = 0

    seed = 42
    image_batches = image_datagen.flow(X, batch_size=batch_size, seed=seed)
    mask_batches = mask_datagen.flow(y, batch_size=batch_size, seed=seed)

    # Never terminates; the caller bounds an epoch with `steps_per_epoch`.
    while True:
        X_batch = next(image_batches)
        y_batch = next(mask_batches)
        yield X_batch.astype('float32'), y_batch.astype('float32')

batch_size = 16
train_generator = create_train_generator(X_train, y_train, batch_size=batch_size)
# The generator loops forever, so the epoch length is computed explicitly.
# Integer division leaves a final partial batch out of the step count.
steps_per_epoch = len(X_train) // batch_size


# The fit call is commented out in this cell, so running the cell only defines
# the objects above and does not train.
# history = kd_model.fit(
#     train_generator,
#     steps_per_epoch=steps_per_epoch,
#     validation_data=(X_val, y_val),
#     epochs=100,
#     callbacks=callbacks
# )

In [None]:
def objective(trial):
    """Optuna objective: briefly train a fresh student copy, return its val_loss.

    Samples batch size, temperature, optimizer, alpha and learning rate from
    ``trial``, trains a clone of ``student_model`` for 5 epochs with
    ``KDTrainer`` and returns the ``val_loss`` of the last epoch. A trial that
    exhausts device memory is reported as pruned.

    Args:
        trial: Object providing ``suggest_categorical`` / ``suggest_float``.

    Returns:
        The final entry of ``history.history['val_loss']``.

    Raises:
        optuna.exceptions.TrialPruned: On ``tf.errors.ResourceExhaustedError``.
    """
    tf.keras.backend.clear_session()

    # Keep the suggest_* calls in this order; each one registers a parameter.
    batch_size = trial.suggest_categorical("batch_size", [8, 16, 32])
    temperature = trial.suggest_categorical("temperature", [1, 3, 5, 10])
    optimizer_name = trial.suggest_categorical("optimizer", ["Adam", "RMSprop", "Nadam", "SGD"])
    alpha = trial.suggest_float("alpha", 0.1, 0.9, step=0.2)
    lr = trial.suggest_float("learning_rate", 1e-5, 1e-3, log=True)

    # Name -> optimizer class; every sampled name has an entry.
    optimizer_factories = {
        "Adam": tf.keras.optimizers.Adam,
        "RMSprop": tf.keras.optimizers.RMSprop,
        "SGD": tf.keras.optimizers.SGD,
        "Nadam": tf.keras.optimizers.Nadam,
    }
    optimizer = optimizer_factories[optimizer_name](learning_rate=lr)

    # clone_model copies the architecture only, so every trial starts from
    # freshly initialised weights.
    trial_student = tf.keras.models.clone_model(student_model)

    trainer = KDTrainer(
        student=trial_student,
        teacher=teacher_model,
        alpha=alpha,
        temperature=temperature
    )
    trainer.compile(
        optimizer=optimizer,
        metrics=class_wise_metrics(num_classes=4)
    )

    batches = create_train_generator(X_train, y_train, batch_size=batch_size)
    validation_pair = (X_val, y_val)

    try:
        history = trainer.fit(
            batches,
            steps_per_epoch=len(X_train) // batch_size,
            validation_data=validation_pair,
            epochs=5,
            verbose=0
        )
    except tf.errors.ResourceExhaustedError:
        print(f"OOM at batch_size={batch_size}")
        tf.keras.backend.clear_session()
        raise optuna.exceptions.TrialPruned()

    return history.history['val_loss'][-1]

# Minimise the value returned by `objective` (the last-epoch val_loss).
study = optuna.create_study(direction="minimize")
study.optimize(objective, n_trials=48)  # 48 = 3 batch sizes x 4 temperatures x 4 optimizers; alpha and learning_rate are sampled on top of that
print("Best params:", study.best_params)

In [None]:
import tensorflow as tf
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau, ModelCheckpoint

# Member models of the ensemble. Order matters: the ensemble pairs each model
# with the weight at the same list position.
models = [
    model_xception,
    model_segnet,
    model_inceptionresnetv2,
    model_efficientnetb4
]

class WeightedSoftVotingEnsemble(tf.keras.Model):
    """Combine several models by a weighted sum of their class scores.

    Every member is called on the same input. Its output is optionally passed
    through a softmax, scaled by the member's normalised weight and summed.
    The argmax of that sum over the last axis is returned as a one-hot tensor.
    """

    def __init__(self, models, weights=None, apply_softmax=True):
        """Store the members and their normalised voting weights.

        Args:
            models: Non-empty sequence of callable models sharing one output shape.
            weights: Optional sequence with one non-negative weight per model.
                It is rescaled to sum to 1. ``None`` means equal weights.
            apply_softmax: When true, a softmax over the last axis is applied to
                every member output except those detected as already softmaxed
                (see ``call``).

        Raises:
            ValueError: If ``models`` is empty, if ``weights`` does not have one
                entry per model, or if the weights do not sum to a positive value.
        """
        super(WeightedSoftVotingEnsemble, self).__init__()
        if len(models) == 0:
            raise ValueError("WeightedSoftVotingEnsemble needs at least one model.")
        self.models = models
        self.apply_softmax = apply_softmax

        if weights is None:
            weights = [1.0 / len(models)] * len(models)
        else:
            # A mismatch would otherwise surface later as an indexing error
            # inside call(), or silently ignore surplus weights.
            if len(weights) != len(models):
                raise ValueError(
                    f"Expected {len(models)} weights (one per model), got {len(weights)}."
                )
            total = sum(weights)
            if total <= 0:
                raise ValueError("Ensemble weights must sum to a positive value.")
            weights = [w / total for w in weights]

        self.model_weights = tf.constant(weights, dtype=tf.float32)

    def call(self, x, training=False):
        """Return the one-hot encoded argmax of the weighted member scores.

        NOTE(review): the result is a hard one-hot tensor, not a probability
        map. Code that needs soft teacher probabilities (e.g. for distillation)
        does not get them from this method.
        """
        weighted_sum = 0
        for i, model in enumerate(self.models):
            output = model(x, training=training)

            # Members whose name contains "efficientnet" are treated as already
            # emitting probabilities and are not softmaxed again.
            is_softmaxed = (
                hasattr(model, "name") and "efficientnet" in model.name.lower()
            )

            if self.apply_softmax and not is_softmaxed:
                probs = tf.nn.softmax(output, axis=-1)
            else:
                probs = output

            weighted_sum += self.model_weights[i] * probs

        avg_prob = weighted_sum  # shape: [B, H, W, C]

        # Convert to one-hot for metric compatibility
        one_hot_pred = tf.one_hot(tf.argmax(avg_prob, axis=-1), depth=avg_prob.shape[-1])
        return one_hot_pred  # [B, H, W, C]

# Relative voting weights, one per entry of `models` (same order). The
# ensemble constructor rescales them to sum to 1.
final_weights = [0.255, 0.2427, 0.2515, 0.2508]
ensemble_model = WeightedSoftVotingEnsemble(
    models=models,
    weights=final_weights,
    apply_softmax=True
)

# Attach the loss and metrics that ensemble_model.evaluate() will report.
ensemble_model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
    loss=combined_loss,
    metrics=class_wise_metrics(num_classes=4)
)

# Alias under which the distillation code refers to the ensemble.
teacher_model = ensemble_model 

def distillation_loss(y_true, y_student_logits, y_teacher_probs, alpha=0.5, temperature=3.0):
    """Blend a supervised loss with a temperature-scaled KL distillation term.

    Args:
        y_true: Ground-truth targets passed to the supervised losses.
        y_student_logits: Student model output.
        y_teacher_probs: Teacher model output.
        alpha: Weight of the supervised term; the KL term gets ``1 - alpha``.
        temperature: Divisor applied to both outputs before the softmax. The KL
            term is multiplied by ``temperature ** 2``.

    Returns:
        ``alpha * hard + (1 - alpha) * temperature**2 * KL``.

    NOTE(review): a softmax is applied to ``y_teacher_probs`` although its name
    suggests probabilities, and ``CategoricalCrossentropy()`` is used with its
    default settings on ``y_student_logits``. Confirm what the student and the
    teacher actually emit (logits or probabilities).
    """
    # Supervised ("hard") part: the notebook's combined_loss (described in the
    # original comment as Dice + Lovasz) plus categorical cross-entropy.
    hard_term = combined_loss(y_true, y_student_logits) + tf.keras.losses.CategoricalCrossentropy()(y_true, y_student_logits)

    # Distillation ("soft") part: KL divergence between the softened outputs.
    softened_teacher = tf.nn.softmax(y_teacher_probs / temperature)
    softened_student = tf.nn.softmax(y_student_logits / temperature)
    soft_term = tf.keras.losses.KLDivergence()(softened_teacher, softened_student)

    soft_scale = (1 - alpha) * (temperature ** 2)
    return alpha * hard_term + soft_scale * soft_term

# === KD Wrapper Model ===
class KDTrainer(tf.keras.Model):
    """Keras wrapper that trains a student model against a teacher model.

    Training minimises ``distillation_loss``. Validation and evaluation report
    ``combined_loss`` on the student alone, so the logged ``loss`` and
    ``val_loss`` are not on the same scale.
    """

    def __init__(self, student, teacher, alpha=0.5, temperature=3.0):
        """
        Args:
            student: Model whose trainable variables are updated.
            teacher: Model called with ``training=False`` to produce targets.
            alpha: Weight of the supervised term in ``distillation_loss``.
            temperature: Softmax temperature forwarded to ``distillation_loss``.
        """
        super(KDTrainer, self).__init__()
        self.student = student
        self.teacher = teacher
        self.alpha = alpha
        self.temperature = temperature
        # Running mean of the per-batch loss. Keras logs whatever the last
        # step returned, so returning the raw batch tensor would make
        # `loss`/`val_loss` (monitored by the callbacks) the value of the final
        # batch only instead of an epoch-level average.
        self.loss_tracker = tf.keras.metrics.Mean(name="loss")
        self.metrics_list = []

    @property
    def metrics(self):
        """Metrics Keras resets at the start of each epoch / evaluation run."""
        return [self.loss_tracker, *self.metrics_list]

    def compile(self, optimizer, metrics):
        """Store the optimizer and the metric objects updated in each step.

        Args:
            optimizer: Keras optimizer applied to the student's variables.
            metrics: List of metric objects exposing ``update_state``/``result``.
        """
        super().compile()
        self.optimizer = optimizer
        self.metrics_list = metrics

    def train_step(self, data):
        """Run one optimisation step on the student and return the running logs."""
        x, y_true = data
        y_true = tf.cast(y_true, tf.float32)

        # The teacher is never updated, so its forward pass stays outside the
        # tape to avoid recording operations that are not differentiated.
        # NOTE(review): the WeightedSoftVotingEnsemble defined in this notebook
        # returns one-hot argmax predictions, not soft probabilities.
        teacher_probs = self.teacher(x, training=False)

        with tf.GradientTape() as tape:
            student_logits = self.student(x, training=True)               # [B, H, W, C]

            loss = distillation_loss(
                y_true, student_logits, teacher_probs,
                alpha=self.alpha, temperature=self.temperature
            )

        grads = tape.gradient(loss, self.student.trainable_variables)
        self.optimizer.apply_gradients(zip(grads, self.student.trainable_variables))

        self.loss_tracker.update_state(loss)
        for metric in self.metrics_list:
            metric.update_state(y_true, student_logits)

        return {m.name: m.result() for m in self.metrics}

    def test_step(self, data):
        """Evaluate the student on one batch using ``combined_loss`` only."""
        x, y_true = data
        y_true = tf.cast(y_true, tf.float32)
        y_pred = self.student(x, training=False)
        loss = combined_loss(y_true, y_pred)

        self.loss_tracker.update_state(loss)
        for metric in self.metrics_list:
            metric.update_state(y_true, y_pred)

        return {m.name: m.result() for m in self.metrics}

# === Instantiate KDTrainer ===
kd_model = KDTrainer(
    student=student_model,
    teacher=teacher_model,
    alpha=0.5,
    temperature=3.0
)

# === Compile ===
# KDTrainer.compile takes only an optimizer and a list of metric objects; the
# losses are computed inside its train_step/test_step.
kd_model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
    metrics=class_wise_metrics(num_classes=4)
)

from datetime import datetime

# Timestamped checkpoint prefix, so each run writes to its own path.
timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
checkpoint_path = f"best_student_unetplusplus_{timestamp}"

# All three callbacks watch the logged 'val_loss' value.
callbacks = [
    EarlyStopping(
        monitor='val_loss',
        patience=10,
        restore_best_weights=True
    ),
    ReduceLROnPlateau(
        monitor='val_loss',
        factor=0.5,
        patience=3,
        min_lr=1e-6
    ), 
    ModelCheckpoint(
    filepath=checkpoint_path,
    monitor='val_loss',
    save_best_only=True,
    save_weights_only=True,
    save_format='tf'  # NOTE(review): save_weights_only=True stores weights only; confirm the installed ModelCheckpoint honours `save_format`
    )
]


from tensorflow.keras.preprocessing.image import ImageDataGenerator

def create_train_generator(X, y, batch_size=16):
    """Yield an endless stream of jointly augmented ``(images, masks)`` batches.

    Two ``ImageDataGenerator`` objects with identical geometric settings are
    given the same seed, which is the usual Keras recipe for applying the same
    shuffle order and random transform to an image and its mask.

    Args:
        X: Array of input images, indexed by sample along the first axis.
        y: Array of masks aligned with ``X`` (assumed to hold per-pixel class
            labels such as one-hot channels -- TODO confirm against the caller).
        batch_size: Number of samples per yielded batch.

    Yields:
        ``(X_batch, y_batch)`` tuples, both cast to ``float32``.
    """
    augmentation = dict(
        rotation_range=20,
        width_shift_range=0.1,
        height_shift_range=0.1,
        zoom_range=0.2,
        horizontal_flip=True,
        fill_mode='nearest'
    )

    image_datagen = ImageDataGenerator(**augmentation)
    mask_datagen = ImageDataGenerator(**augmentation)
    # Masks carry labels, not intensities: resample them with nearest-neighbour
    # (spline order 0). The default order 1 blends neighbouring classes into
    # fractional values whenever a rotation/zoom/shift is applied. The attribute
    # is set after construction in case the installed tf.keras release does not
    # accept `interpolation_order` as a constructor argument.
    mask_datagen.interpolation_order = 0

    seed = 42
    image_batches = image_datagen.flow(X, batch_size=batch_size, seed=seed)
    mask_batches = mask_datagen.flow(y, batch_size=batch_size, seed=seed)

    # Never terminates; the caller bounds an epoch with `steps_per_epoch`.
    while True:
        X_batch = next(image_batches)
        y_batch = next(mask_batches)
        yield X_batch.astype('float32'), y_batch.astype('float32')

batch_size = 16
train_generator = create_train_generator(X_train, y_train, batch_size=batch_size)
# The generator loops forever, so the epoch length is given explicitly.
# Integer division leaves a final partial batch out of the step count.
steps_per_epoch = len(X_train) // batch_size


# Validation uses the (X_val, y_val) arrays directly, without augmentation.
history = kd_model.fit(
    train_generator,
    steps_per_epoch=steps_per_epoch,
    validation_data=(X_val, y_val),
    epochs=100,
    callbacks=callbacks
)

In [None]:
import matplotlib.pyplot as plt

# Per-epoch series recorded by the training run.
history_dict = history.history

plt.figure(figsize=(12, 6))

# Left panel: training vs. validation loss.
plt.subplot(1, 2, 1)
for series_key, series_label in (('loss', 'Training Loss'), ('val_loss', 'Validation Loss')):
    plt.plot(history_dict[series_key], label=series_label)
plt.title('Loss Curves')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()

# Right panel: drawn only when the history contains an 'accuracy' series.
if 'accuracy' in history_dict:
    plt.subplot(1, 2, 2)
    for series_key, series_label in (('accuracy', 'Training Accuracy'), ('val_accuracy', 'Validation Accuracy')):
        plt.plot(history_dict[series_key], label=series_label)
    plt.title('Accuracy Curves')
    plt.xlabel('Epochs')
    plt.ylabel('Accuracy')
    plt.legend()

plt.tight_layout()
plt.show()


In [None]:
import numpy as np
import tensorflow as tf
from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score
from scipy.spatial.distance import directed_hausdorff
from tensorflow.keras.utils import to_categorical

# ✅ RGB to Class Index Conversion (for the test masks)
RGB_TO_CLASS = {
    (255, 0, 0): 1,  # Brain
    (0, 255, 0): 2,  # CSP
    (0, 0, 255): 3,  # LV
    (0, 0, 0): 0,    # Background
}

# Convert an RGB-coded mask into a 2-D array of class indices.
def rgb_to_class_mask(rgb_mask):
    """Map every pixel of an RGB mask to its class index.

    Args:
        rgb_mask: Array whose last axis holds the colour triplet; the first two
            axes give the shape of the result.

    Returns:
        Integer array of shape ``rgb_mask.shape[:2]``. Pixels whose colour is
        not a key of ``RGB_TO_CLASS`` keep the initial value 0.
    """
    labels = np.zeros(rgb_mask.shape[:2], dtype=int)

    for colour, label in RGB_TO_CLASS.items():
        # A pixel matches only when all channels equal the reference colour.
        hits = (rgb_mask == np.array(colour)).all(axis=-1)
        labels[hits] = label

    return labels

# ✅ Function to calculate Dice Similarity Coefficient (DSC)
def dice_coefficient(y_true, y_pred):
    """Return the Dice coefficient of two binary masks.

    Computed as ``(2 * sum(y_true * y_pred) + eps) / (sum(y_true) + sum(y_pred) + eps)``
    with ``eps = 1e-15``, so two all-zero masks score 1.0.

    Args:
        y_true: Binary ground-truth array.
        y_pred: Binary prediction array of the same shape.
    """
    eps = 1e-15
    overlap = np.sum(y_true * y_pred)
    # Sum of both mask sizes (the Dice denominator, not a set union).
    combined_size = np.sum(y_true) + np.sum(y_pred)
    numerator = 2. * overlap + eps
    return numerator / (combined_size + eps)

# ✅ Function to calculate IoU (Intersection over Union)
def iou(y_true, y_pred):
    """Return the intersection-over-union of two binary masks.

    The same smoothing constant is added to numerator and denominator, so two
    empty masks (a class absent from both truth and prediction) score 1.0.
    This matches ``dice_coefficient``; without the numerator term that case
    scored 0 and pulled the per-class mean down.

    Args:
        y_true: Binary ground-truth array.
        y_pred: Binary prediction array of the same shape.
    """
    smooth = 1e-15
    intersection = np.sum(y_true * y_pred)
    union = np.sum(y_true) + np.sum(y_pred) - intersection
    return (intersection + smooth) / (union + smooth)

# ✅ Function to calculate Hausdorff Distance
def hausdorff_distance(y_true, y_pred):
    """Return the symmetric Hausdorff distance between two binary masks.

    The point sets are the coordinates of all elements equal to 1 in each mask.

    Args:
        y_true: Binary ground-truth array.
        y_pred: Binary prediction array.

    Returns:
        The larger of the two directed Hausdorff distances, or ``inf`` when
        either mask contains no element equal to 1.
    """
    truth_coords = np.argwhere(y_true == 1)
    pred_coords = np.argwhere(y_pred == 1)

    # With an empty point set the distance is undefined; report infinity.
    if min(len(truth_coords), len(pred_coords)) == 0:
        return float('inf')

    truth_to_pred = directed_hausdorff(truth_coords, pred_coords)[0]
    pred_to_truth = directed_hausdorff(pred_coords, truth_coords)[0]
    return max(truth_to_pred, pred_to_truth)

# ✅ Function to calculate Average Surface Distance (ASD)
def average_surface_distance(y_true, y_pred):
    """Return the mean distance from each true point to its nearest predicted point.

    The point sets are the coordinates of all elements equal to 1 (every
    foreground element, not only boundary elements). The measure is
    one-directional: true -> predicted.

    Args:
        y_true: Binary ground-truth array.
        y_pred: Binary prediction array.

    Returns:
        Mean nearest-neighbour Euclidean distance, or ``inf`` when either mask
        contains no element equal to 1.
    """
    truth_coords = np.argwhere(y_true == 1)
    pred_coords = np.argwhere(y_pred == 1)

    if min(len(truth_coords), len(pred_coords)) == 0:
        return float('inf')

    nearest = [
        np.min(np.linalg.norm(pred_coords - point, axis=1))
        for point in truth_coords
    ]
    return np.mean(nearest)

# ✅ Function to evaluate the model on the test set class-wise
def evaluate_classwise_metrics(model, X_test, y_test, num_classes=4, batch_size=16):
    """Print per-class segmentation metrics and the model's own evaluate() values.

    For every test sample and every class index a binary mask is derived from
    the argmax of the prediction and of ``y_test``. Dice, IoU, precision,
    recall, F1 and accuracy are computed per sample and averaged per class.
    Hausdorff and average-surface distance are currently disabled.

    Args:
        model: Object providing ``predict`` and ``evaluate`` (Keras model API).
        X_test: Test inputs, indexed by sample along the first axis.
        y_test: Targets with the class dimension on the last axis (the code
            takes ``argmax`` over it).
        num_classes: Number of class indices to report.
        batch_size: Batch size used for ``predict`` and ``evaluate``.

    Returns:
        None. All results are printed.
    """
    # Predict in batches
    y_pred = model.predict(X_test, batch_size=batch_size)
    y_pred = np.argmax(y_pred, axis=-1)  # Convert to class index prediction

    # Convert y_test to class index format (since it's one-hot encoded)
    y_test_class = np.argmax(y_test, axis=-1)

    # Initialize lists to store class-wise metrics
    class_metrics = {i: {'dice': [], 'iou': [], 'precision': [], 'recall': [], 'f1': [], 'accuracy': [], 'hausdorff': [], 'asd': []} for i in range(num_classes)}

    # Calculate metrics for each test sample
    for i in range(len(X_test)):
        true_mask = y_test_class[i]
        pred_mask = y_pred[i]

        # For each class (0: Background, 1: Brain, 2: CSP, 3: LV)
        for class_idx in range(num_classes):
            true_class_mask = (true_mask == class_idx).astype(int)
            pred_class_mask = (pred_mask == class_idx).astype(int)
            # Flatten once; the four sklearn scores below all need 1-D input.
            true_flat = true_class_mask.flatten()
            pred_flat = pred_class_mask.flatten()

            per_class = class_metrics[class_idx]
            per_class['dice'].append(dice_coefficient(true_class_mask, pred_class_mask))
            per_class['iou'].append(iou(true_class_mask, pred_class_mask))
            per_class['precision'].append(precision_score(true_flat, pred_flat, zero_division=0))
            per_class['recall'].append(recall_score(true_flat, pred_flat, zero_division=0))
            per_class['f1'].append(f1_score(true_flat, pred_flat, zero_division=0))
            per_class['accuracy'].append(accuracy_score(true_flat, pred_flat))
            # # Hausdorff Distance
            # class_metrics[class_idx]['hausdorff'].append(hausdorff_distance(true_class_mask, pred_class_mask))
            # # Average Surface Distance
            # class_metrics[class_idx]['asd'].append(average_surface_distance(true_class_mask, pred_class_mask))

    # Print class-wise metrics in percentage
    print(f"{'Class':<10}{'Dice Coefficient (%)':<20}{'IoU (%)':<20}{'Precision (%)':<20}{'Recall (%)':<20}{'F1 Score (%)':<20}{'Accuracy (%)':<20}{'Hausdorff Distance':<20}{'Avg Surface Distance':<20}")
    print('-' * 180)

    for class_idx in range(num_classes):
        print(f"Class {class_idx}:")
        print(f"  Dice Coefficient: {np.mean(class_metrics[class_idx]['dice']) * 100:.2f}%")
        print(f"  IoU: {np.mean(class_metrics[class_idx]['iou']) * 100:.2f}%")
        print(f"  Precision: {np.mean(class_metrics[class_idx]['precision']) * 100:.2f}%")
        print(f"  Recall: {np.mean(class_metrics[class_idx]['recall']) * 100:.2f}%")
        print(f"  F1 Score: {np.mean(class_metrics[class_idx]['f1']) * 100:.2f}%")
        print(f"  Accuracy: {np.mean(class_metrics[class_idx]['accuracy']) * 100:.2f}%")
        # print(f"  Hausdorff Distance: {np.mean(class_metrics[class_idx]['hausdorff']):.4f}")
        # print(f"  Average Surface Distance: {np.mean(class_metrics[class_idx]['asd']):.4f}")
        print("-" * 180)

    # Ask evaluate() for a name -> value mapping instead of a positional list.
    # Models with a custom train_step/test_step do not necessarily report the
    # loss first, so positional unpacking can attach values to the wrong names.
    results = model.evaluate(X_test, y_test, batch_size=batch_size, return_dict=True)
    if "loss" in results:
        print(f"Test Loss: {results['loss']:.4f}")

    for metric, value in results.items():
        if metric != "loss":
            print(f"{metric}: {value:.4f}")

# ✅ Call the evaluation function on the test set class-wise
# (uses the function's defaults: num_classes=4, batch_size=16).
evaluate_classwise_metrics(kd_model, X_test, y_test)

In [None]:
# Evaluate the distillation wrapper on the test arrays with the current batch_size.
kd_model.evaluate(X_test, y_test, batch_size=batch_size)

In [None]:
# Same evaluation call as the preceding cell.
kd_model.evaluate(X_test, y_test, batch_size=batch_size)

In [None]:
# Same evaluation call as the preceding cell.
kd_model.evaluate(X_test, y_test, batch_size=batch_size)

In [None]:
# Evaluate the bare student model (without the KDTrainer wrapper) on the test arrays.
student_model.evaluate(X_test, y_test, batch_size=batch_size)

In [None]:
# Evaluate the teacher model on the test arrays for comparison with the student.
teacher_model.evaluate(X_test, y_test, batch_size=batch_size)

In [None]:
# Same evaluation call as the preceding cell.
teacher_model.evaluate(X_test, y_test, batch_size=batch_size)

In [None]:
class WeightedSoftVotingEnsemble(tf.keras.Model):
    """Combine several models by a weighted sum of their class scores.

    Every member is called on the same input. Its output is optionally passed
    through a softmax, scaled by the member's normalised weight and summed.
    The argmax of that sum over the last axis is returned as a one-hot tensor.
    """

    def __init__(self, models, weights=None, apply_softmax=True):
        """Store the members and their normalised voting weights.

        Args:
            models: Non-empty sequence of callable models sharing one output shape.
            weights: Optional sequence with one non-negative weight per model.
                It is rescaled to sum to 1. ``None`` means equal weights.
            apply_softmax: When true, a softmax over the last axis is applied to
                every member output except those detected as already softmaxed
                (see ``call``).

        Raises:
            ValueError: If ``models`` is empty, if ``weights`` does not have one
                entry per model, or if the weights do not sum to a positive value.
        """
        super(WeightedSoftVotingEnsemble, self).__init__()
        if len(models) == 0:
            raise ValueError("WeightedSoftVotingEnsemble needs at least one model.")
        self.models = models
        self.apply_softmax = apply_softmax

        if weights is None:
            weights = [1.0 / len(models)] * len(models)
        else:
            # A mismatch would otherwise surface later as an indexing error
            # inside call(), or silently ignore surplus weights.
            if len(weights) != len(models):
                raise ValueError(
                    f"Expected {len(models)} weights (one per model), got {len(weights)}."
                )
            total = sum(weights)
            if total <= 0:
                raise ValueError("Ensemble weights must sum to a positive value.")
            weights = [w / total for w in weights]

        self.model_weights = tf.constant(weights, dtype=tf.float32)

    def call(self, x, training=False):
        """Return the one-hot encoded argmax of the weighted member scores.

        NOTE(review): the result is a hard one-hot tensor, not a probability
        map. Code that needs soft teacher probabilities (e.g. for distillation)
        does not get them from this method.
        """
        weighted_sum = 0
        for i, model in enumerate(self.models):
            output = model(x, training=training)

            # Members whose name contains "efficientnet" are treated as already
            # emitting probabilities and are not softmaxed again.
            is_softmaxed = (
                hasattr(model, "name") and "efficientnet" in model.name.lower()
            )

            if self.apply_softmax and not is_softmaxed:
                probs = tf.nn.softmax(output, axis=-1)
            else:
                probs = output

            weighted_sum += self.model_weights[i] * probs

        avg_prob = weighted_sum
        one_hot_pred = tf.one_hot(tf.argmax(avg_prob, axis=-1), depth=avg_prob.shape[-1])
        return one_hot_pred

# Relative voting weights, one per entry of `models` (same order). The
# ensemble constructor rescales them to sum to 1.
final_weights = [0.255, 0.2427, 0.2515, 0.2508]
ensemble_model = WeightedSoftVotingEnsemble(models=models, weights=final_weights, apply_softmax=True)

# Attach the loss and metrics that evaluate() reports below.
ensemble_model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
    loss=combined_loss,
    metrics=class_wise_metrics(num_classes=4)
)

# Score the rebuilt ensemble on the test arrays.
ensemble_model.evaluate(X_test, y_test, batch_size=batch_size)

In [None]:
# Evaluate the distillation wrapper on the test arrays with the current batch_size.
kd_model.evaluate(X_test, y_test, batch_size=batch_size)

In [None]:
# Same evaluation call as the preceding cell.
kd_model.evaluate(X_test, y_test, batch_size=batch_size)

In [None]:
import os
# Remove this weights file from the working directory if it already exists.
if os.path.exists("best_student_unetplusplus_distilled.weights.h5"):
    os.remove("best_student_unetplusplus_distilled.weights.h5")

In [None]:
# Load previously saved student weights from a file in the working directory.
student_model.load_weights("student_model_weights_final.h5")

In [None]:
class _OneHotArgmaxMeanIoU(tf.keras.metrics.MeanIoU):
    """MeanIoU that first reduces the last axis of both inputs to class indices.

    ``tf.keras.metrics.MeanIoU`` expects integer class labels. Given one-hot
    targets and score maps it would flatten the channel axis and compare the
    individual 0/1 entries instead of the per-pixel classes.
    """

    def update_state(self, y_true, y_pred, sample_weight=None):
        return super().update_state(
            tf.argmax(y_true, axis=-1),
            tf.argmax(y_pred, axis=-1),
            sample_weight,
        )


def class_wise_metrics(num_classes=4):
    """Build the metric list used when compiling the segmentation models.

    Args:
        num_classes: Number of segmentation classes.

    Returns:
        One ``DiceCoefficient(i)`` per class index followed by a mean-IoU
        metric named ``"mean_io_u"``. The mean-IoU metric assumes the class
        dimension is the last axis of both targets and predictions
        (TODO confirm for every model this list is attached to).
    """
    return [DiceCoefficient(i) for i in range(num_classes)] + [
        _OneHotArgmaxMeanIoU(num_classes=num_classes, name="mean_io_u")
    ]

def student_eval_loss(y_true, y_pred):
    """Loss used to score the bare student: combined_loss plus cross-entropy.

    Returns the loss tensor itself. The previous version wrapped it in a
    one-element list, which is not what a Keras loss callable is expected to
    return.

    Args:
        y_true: Ground-truth targets.
        y_pred: Student model output.
    """
    return combined_loss(y_true, y_pred) + tf.keras.losses.CategoricalCrossentropy()(y_true, y_pred)

# Compile the bare student so that evaluate() reports student_eval_loss and the
# class-wise metrics.
student_model.compile(
    optimizer= tf.keras.optimizers.RMSprop(learning_rate=0.0001),
    loss=student_eval_loss,
    metrics=class_wise_metrics(num_classes=4)
)

In [None]:
# Evaluate the compiled student on the test arrays with batch size 8.
student_model.evaluate(X_test, y_test, batch_size=8)

In [None]:
# Evaluate the compiled student on the test arrays with batch size 16.
student_model.evaluate(X_test, y_test, batch_size=16)

In [None]:
# Evaluate the compiled student on the test arrays with batch size 32.
student_model.evaluate(X_test, y_test, batch_size=32)