In [None]:
import numpy as np
import os
import gc
import cv2
import re
from tensorflow.keras.utils import to_categorical

# ‚úÖ Constants for 224x224
# Network input geometry (224x224 RGB) and number of segmentation labels.
IMG_HEIGHT, IMG_WIDTH = 224, 224
CHANNELS = 3      # RGB images
NUM_CLASSES = 4   # Brain, CSP, LV, Background

# Lookup table: mask colour (R, G, B) -> integer class index.
CLASS_MAP = dict([
    ((255, 0, 0), 1),  # Brain
    ((0, 255, 0), 2),  # CSP
    ((0, 0, 255), 3),  # LV
    ((0, 0, 0), 0),    # Background
])

# Folders holding the full augmented dataset.
image_dir, mask_dir = r"D:\augmented_dataset\images", r"D:\augmented_dataset\masks"

# Folders holding the already-split data, one (images, masks) pair per split.
train_image_dir, train_mask_dir = r"D:\Updated\train\images", r"D:\Updated\train\masks"
val_image_dir, val_mask_dir = r"D:\Updated\val\images", r"D:\Updated\val\masks"
test_image_dir, test_mask_dir = r"D:\Updated\test\images", r"D:\Updated\test\masks"

# ‚úÖ Fix sorting issue using natural sorting
def natural_sort_key(s):
    """Return a sort key that orders embedded numbers numerically.

    The string is split on runs of decimal digits; digit runs become ``int``
    and everything else is lower-cased, so ``"img2"`` sorts before ``"img10"``.

    Args:
        s: Filename (or any string) to build a key for.

    Returns:
        A list alternating lower-cased text chunks and integers.
    """
    # ``isdecimal`` (not ``isdigit``) matches exactly what ``\d`` splits out and
    # what ``int`` accepts. ``isdigit`` is also true for characters such as
    # superscript digits, which stay in the text chunks and would make
    # ``int`` raise ValueError.
    return [
        int(chunk) if chunk.isdecimal() else chunk.lower()
        for chunk in re.split(r'(\d+)', s)
    ]

# ‚úÖ Convert RGB mask to class index mask
def rgb_to_class(mask_array):
    """Map an (H, W, 3) colour-coded mask to an (H, W) uint8 class-index mask.

    Pixels whose colour is not a key of CLASS_MAP keep the initial value 0.
    """
    rows, cols, _ = mask_array.shape
    class_mask = np.zeros((rows, cols), dtype=np.uint8)

    for rgb in CLASS_MAP:
        # A pixel belongs to the class only when all three channels match.
        same_colour = (mask_array == rgb).all(axis=-1)
        class_mask[same_colour] = CLASS_MAP[rgb]

    return class_mask

# ‚úÖ Preprocess Filtered Dataset for 224x224
def preprocess_filtered_dataset(image_dir, mask_dir):
    """Load an image/mask folder pair into normalised arrays for training.

    Files in both folders are naturally sorted and paired by position. Each
    image is converted to RGB, resized to (IMG_HEIGHT, IMG_WIDTH) and scaled
    to [0, 1]; each mask is resized with nearest-neighbour interpolation,
    mapped to class indices with ``rgb_to_class`` and one-hot encoded.

    Args:
        image_dir: Directory containing the input images.
        mask_dir: Directory containing the colour-coded masks.

    Returns:
        Tuple ``(X, y)`` of float32 arrays with shapes
        ``(n, IMG_HEIGHT, IMG_WIDTH, CHANNELS)`` and
        ``(n, IMG_HEIGHT, IMG_WIDTH, NUM_CLASSES)``, where ``n`` is the number
        of pairs that could actually be read.
    """
    def _sorted_files(directory):
        # Regular files only: os.listdir also returns sub-directories (for
        # example Jupyter's ``.ipynb_checkpoints``) that cv2.imread cannot read.
        names = [
            name for name in os.listdir(directory)
            if os.path.isfile(os.path.join(directory, name))
        ]
        return sorted(names, key=natural_sort_key)

    image_filenames = _sorted_files(image_dir)
    mask_filenames = _sorted_files(mask_dir)

    # Pairing is positional, so a count mismatch means zip() drops the surplus
    # files and the remaining pairs may be misaligned. Make that visible.
    if len(image_filenames) != len(mask_filenames):
        print(
            f"WARNING: found {len(image_filenames)} images but "
            f"{len(mask_filenames)} masks; only the first "
            f"{min(len(image_filenames), len(mask_filenames))} positional pairs "
            "are used and they may be misaligned"
        )

    file_pairs = list(zip(image_filenames, mask_filenames))
    num_images = len(file_pairs)

    # Pre-allocate for the maximum possible number of samples.
    X = np.zeros((num_images, IMG_HEIGHT, IMG_WIDTH, CHANNELS), dtype=np.float32)
    y = np.zeros((num_images, IMG_HEIGHT, IMG_WIDTH, NUM_CLASSES), dtype=np.float32)  # One-hot encoded masks

    print(f"üöÄ Processing {num_images} filtered images and masks...")

    loaded = 0  # next free row in X / y
    for idx, (img_file, mask_file) in enumerate(file_pairs):
        if idx % 100 == 0:
            print(f"‚úÖ Processed {idx}/{num_images} images")

        # cv2.imread returns None (it does not raise) for unreadable files.
        img = cv2.imread(os.path.join(image_dir, img_file))    # BGR
        mask = cv2.imread(os.path.join(mask_dir, mask_file))   # BGR
        if img is None or mask is None:
            print(f"‚ö†Ô∏è Skipping {img_file}: Missing image or mask")
            continue

        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = cv2.resize(img, (IMG_WIDTH, IMG_HEIGHT))
        X[loaded] = img.astype(np.float32) / 255.0  # scale to [0, 1]

        mask = cv2.cvtColor(mask, cv2.COLOR_BGR2RGB)
        # Nearest-neighbour keeps the mask colours exact, so they still match CLASS_MAP.
        mask = cv2.resize(mask, (IMG_WIDTH, IMG_HEIGHT), interpolation=cv2.INTER_NEAREST)
        y[loaded] = to_categorical(rgb_to_class(mask), num_classes=NUM_CLASSES)

        loaded += 1

    # Per-iteration gc.collect() was dropped: the temporaries are released by
    # reference counting, and a full collection per image only slows the loop.

    # Trim rows reserved for pairs that had to be skipped.
    return X[:loaded], y[:loaded]

from sklearn.model_selection import train_test_split

# ‚úÖ Process dataset splits
# Preprocess the three pre-split folders in train -> val -> test order.
(X_train, y_train), (X_val, y_val), (X_test, y_test) = (
    preprocess_filtered_dataset(split_images, split_masks)
    for split_images, split_masks in (
        (train_image_dir, train_mask_dir),
        (val_image_dir, val_mask_dir),
        (test_image_dir, test_mask_dir),
    )
)

# Report the array shapes of every split.
print("\n‚úÖ Dataset Splits:")
print("\n".join(
    f"  - {split_label} set: {features.shape}, {targets.shape}"
    for split_label, features, targets in (
        ("Training", X_train, y_train),
        ("Validation", X_val, y_val),
        ("Test", X_test, y_test),
    )
))

In [None]:
from tensorflow.keras.preprocessing.image import ImageDataGenerator

# Random geometric augmentation settings for the training images.
train_datagen = ImageDataGenerator(
    rotation_range=30,
    width_shift_range=0.2,
    height_shift_range=0.2,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
)

# No train_datagen.fit(X_train) here: fit() only computes dataset statistics
# for featurewise_center / featurewise_std_normalization / zca_whitening.
# None of those options is enabled, so the call would only make a full copy
# of the training array in memory.

In [None]:
# Sanity check: which class indices occur in each split's one-hot masks.
print(f"Unique values in y_train: {np.unique(y_train.argmax(axis=-1))}")
print(f"Unique values in y_val: {np.unique(y_val.argmax(axis=-1))}")
print(f"Unique values in y_test: {np.unique(y_test.argmax(axis=-1))}")

In [None]:
# import tensorflow as tf
# from tensorflow.keras import Input, Model
# from tensorflow.keras.layers import Conv2DTranspose, Conv2D, Concatenate, BatchNormalization, Activation
# from tensorflow.keras.applications import Xception

# class ResizeLayer(tf.keras.layers.Layer):
#     """Custom layer to resize images."""
#     def __init__(self, target_size, **kwargs):
#         super(ResizeLayer, self).__init__(**kwargs)
#         self.target_size = target_size
    
#     def call(self, inputs):
#         return tf.image.resize(inputs, self.target_size, method='bilinear')
    
#     def get_config(self):
#         config = super(ResizeLayer, self).get_config()
#         config.update({"target_size": self.target_size})
#         return config

# def find_best_skip_layers(base_model, target_sizes):
#     """Finds the best matching layers for skip connections (based on spatial size)."""
#     layer_dict = {layer.name: layer.output.shape for layer in base_model.layers}
#     skip_connections = []
#     for target_size in target_sizes:
#         best_match = None
#         min_diff = float('inf')
#         for layer_name, shape in layer_dict.items():
#             if len(shape) == 4:  # Ensure it's a conv layer
#                 h, w, c = shape[1], shape[2], shape[3]  # Extract spatial dims & channels
#                 # Prioritize spatial dimensions match
#                 diff = abs(h - target_size[0]) + abs(w - target_size[1])
#                 if diff < min_diff:
#                     min_diff = diff
#                     best_match = layer_name
#         skip_connections.append(best_match)
#     return skip_connections

# def conv_block(x, filters, kernel_size=3, padding='same', activation='relu'):
#     """Helper function for creating a conv block with BN and activation."""
#     x = Conv2D(filters, kernel_size, padding=padding)(x)
#     x = BatchNormalization()(x)
#     x = Activation(activation)(x)
#     return x

# def build_unetpp_with_xception(input_shape=(224, 224, 3), num_classes=4):
#     """
#     Build UNet++ model with Xception backbone
    
#     Args:
#         input_shape: Input shape of the image
#         num_classes: Number of output classes
        
#     Returns:
#         Keras Model instance with UNet++ architecture
#     """
#     inputs = Input(shape=input_shape)
    
#     # Use Xception as the encoder backbone
#     # Note: Xception requires input size >= 71x71
#     base_model = Xception(
#         input_tensor=inputs, 
#         include_top=False, 
#         weights='imagenet'
#     )
    
#     # Define target sizes for skip connections at different levels
#     # These are approximate sizes for a 224x224 input with Xception
#     target_skip_sizes = [(56, 56, 128), (28, 28, 256), (14, 14, 728), (7, 7, 2048)]
    
#     # Find optimal skip connection layers
#     skip_layer_names = find_best_skip_layers(base_model, target_skip_sizes)
#     print("Selected skip connection layers:", skip_layer_names)
    
#     # Extract skip connections
#     skip1 = base_model.get_layer(skip_layer_names[0]).output  # ~56x56
#     skip2 = base_model.get_layer(skip_layer_names[1]).output  # ~28x28
#     skip3 = base_model.get_layer(skip_layer_names[2]).output  # ~14x14
#     skip4 = base_model.get_layer(skip_layer_names[3]).output  # ~7x7
    
#     # Bottleneck (deepest feature map)
#     bottleneck = skip4  # 7x7 feature map
    
#     # ----- UNet++ Architecture -----
#     # UNet++ adds nested skip pathways compared to regular UNet
    
#     # First decoder level (L4)
#     # From bottleneck to 14x14
#     up4 = Conv2DTranspose(512, (3, 3), strides=(2, 2), padding='same')(bottleneck)
#     up4 = BatchNormalization()(up4)
#     up4 = Activation('relu')(up4)
    
#     # Match sizes if needed
#     if up4.shape[1] != skip3.shape[1] or up4.shape[2] != skip3.shape[2]:
#         up4 = ResizeLayer(target_size=(skip3.shape[1], skip3.shape[2]))(up4)
    
#     # Create node X^0_4 (first nested dense block)
#     x0_4 = Concatenate()([up4, skip3])
#     x0_4 = conv_block(x0_4, 512)
#     x0_4 = conv_block(x0_4, 512)
    
#     # Second decoder level (L3)
#     # From x0_4 to 28x28
#     up3 = Conv2DTranspose(256, (3, 3), strides=(2, 2), padding='same')(x0_4)
#     up3 = BatchNormalization()(up3)
#     up3 = Activation('relu')(up3)
    
#     # Match sizes if needed
#     if up3.shape[1] != skip2.shape[1] or up3.shape[2] != skip2.shape[2]:
#         up3 = ResizeLayer(target_size=(skip2.shape[1], skip2.shape[2]))(up3)
    
#     # Create node X^0_3
#     x0_3 = Concatenate()([up3, skip2])
#     x0_3 = conv_block(x0_3, 256)
#     x0_3 = conv_block(x0_3, 256)
    
#     # Create node X^1_3 (second nested dense block)
#     up1_3 = Conv2DTranspose(256, (3, 3), strides=(2, 2), padding='same')(bottleneck)
#     up1_3 = BatchNormalization()(up1_3)
#     up1_3 = Activation('relu')(up1_3)
#     if up1_3.shape[1] != x0_3.shape[1] or up1_3.shape[2] != x0_3.shape[2]:
#         up1_3 = ResizeLayer(target_size=(x0_3.shape[1], x0_3.shape[2]))(up1_3)
#     x1_3 = Concatenate()([up1_3, x0_3])
#     x1_3 = conv_block(x1_3, 256)
#     x1_3 = conv_block(x1_3, 256)
    
#     # Third decoder level (L2)
#     # From x0_3 to 56x56
#     up2 = Conv2DTranspose(128, (3, 3), strides=(2, 2), padding='same')(x0_3)
#     up2 = BatchNormalization()(up2)
#     up2 = Activation('relu')(up2)
    
#     # Match sizes if needed
#     if up2.shape[1] != skip1.shape[1] or up2.shape[2] != skip1.shape[2]:
#         up2 = ResizeLayer(target_size=(skip1.shape[1], skip1.shape[2]))(up2)
    
#     # Create node X^0_2
#     x0_2 = Concatenate()([up2, skip1])
#     x0_2 = conv_block(x0_2, 128)
#     x0_2 = conv_block(x0_2, 128)
    
#     # Create node X^1_2 (nested dense block)
#     up1_2 = Conv2DTranspose(128, (3, 3), strides=(2, 2), padding='same')(x0_3)
#     up1_2 = BatchNormalization()(up1_2)
#     up1_2 = Activation('relu')(up1_2)
#     if up1_2.shape[1] != x0_2.shape[1] or up1_2.shape[2] != x0_2.shape[2]:
#         up1_2 = ResizeLayer(target_size=(x0_2.shape[1], x0_2.shape[2]))(up1_2)
#     x1_2 = Concatenate()([up1_2, x0_2])
#     x1_2 = conv_block(x1_2, 128)
#     x1_2 = conv_block(x1_2, 128)
    
#     # Create node X^2_2 (nested dense block)
#     up2_2 = Conv2DTranspose(128, (3, 3), strides=(4, 4), padding='same')(x1_3)
#     up2_2 = BatchNormalization()(up2_2)
#     up2_2 = Activation('relu')(up2_2)
#     if up2_2.shape[1] != x1_2.shape[1] or up2_2.shape[2] != x1_2.shape[2]:
#         up2_2 = ResizeLayer(target_size=(x1_2.shape[1], x1_2.shape[2]))(up2_2)
#     x2_2 = Concatenate()([up2_2, x1_2])
#     x2_2 = conv_block(x2_2, 128)
#     x2_2 = conv_block(x2_2, 128)
    
#     # Fourth decoder level (L1) - to original size
#     # From x0_2 to 224x224
#     up1 = Conv2DTranspose(64, (3, 3), strides=(4, 4), padding='same')(x0_2)
#     up1 = BatchNormalization()(up1)
#     up1 = Activation('relu')(up1)
    
#     # Create node X^0_1
#     x0_1 = up1
#     x0_1 = conv_block(x0_1, 64)
#     x0_1 = conv_block(x0_1, 64)
    
#     # Create node X^1_1
#     up1_1 = Conv2DTranspose(64, (3, 3), strides=(4, 4), padding='same')(x0_2)
#     up1_1 = BatchNormalization()(up1_1)
#     up1_1 = Activation('relu')(up1_1)
#     if up1_1.shape[1] != x0_1.shape[1] or up1_1.shape[2] != x0_1.shape[2]:
#         up1_1 = ResizeLayer(target_size=(x0_1.shape[1], x0_1.shape[2]))(up1_1)
#     x1_1 = Concatenate()([up1_1, x0_1])
#     x1_1 = conv_block(x1_1, 64)
#     x1_1 = conv_block(x1_1, 64)
    
#     # Create node X^2_1
#     up2_1 = Conv2DTranspose(64, (3, 3), strides=(8, 8), padding='same')(x1_2)
#     up2_1 = BatchNormalization()(up2_1)
#     up2_1 = Activation('relu')(up2_1)
#     if up2_1.shape[1] != x1_1.shape[1] or up2_1.shape[2] != x1_1.shape[2]:
#         up2_1 = ResizeLayer(target_size=(x1_1.shape[1], x1_1.shape[2]))(up2_1)
#     x2_1 = Concatenate()([up2_1, x1_1])
#     x2_1 = conv_block(x2_1, 64)
#     x2_1 = conv_block(x2_1, 64)
    
#     # Create node X^3_1
#     up3_1 = Conv2DTranspose(64, (3, 3), strides=(16, 16), padding='same')(x2_2)
#     up3_1 = BatchNormalization()(up3_1)
#     up3_1 = Activation('relu')(up3_1)
#     if up3_1.shape[1] != x2_1.shape[1] or up3_1.shape[2] != x2_1.shape[2]:
#         up3_1 = ResizeLayer(target_size=(x2_1.shape[1], x2_1.shape[2]))(up3_1)
#     x3_1 = Concatenate()([up3_1, x2_1])
#     x3_1 = conv_block(x3_1, 64)
#     x3_1 = conv_block(x3_1, 64)
    
#     # Final resize to ensure output is exactly input_shape size
#     if x3_1.shape[1] != input_shape[0] or x3_1.shape[2] != input_shape[1]:
#         x3_1 = ResizeLayer(target_size=(input_shape[0], input_shape[1]))(x3_1)
    
#     # Segmentation head
#     outputs = Conv2D(num_classes, (1, 1), activation='softmax')(x3_1)
    
#     # Create model
#     model = Model(inputs=inputs, outputs=outputs)
    
#     return model

# # Model parameters
# IMG_HEIGHT = 224
# IMG_WIDTH = 224
# CHANNELS = 3
# NUM_CLASSES = 4

# # Build and summarize model
# model = build_unetpp_with_xception(input_shape=(IMG_HEIGHT, IMG_WIDTH, CHANNELS), num_classes=NUM_CLASSES)
# model.summary()

import tensorflow as tf
from tensorflow.keras import Input, Model
from tensorflow.keras.layers import Conv2D, UpSampling2D, Concatenate, BatchNormalization, Activation
from tensorflow.keras.applications import Xception

# Set the global Keras dtype policy to 'mixed_float16' to reduce memory usage.
# This runs at module level, before the model-building code in this cell.
# The segmentation head in build_memory_efficient_unetpp_xception overrides
# the policy with dtype='float32' for its softmax output.
tf.keras.mixed_precision.set_global_policy('mixed_float16')

class ResizeLayer(tf.keras.layers.Layer):
    """Keras layer that bilinearly resizes feature maps to a fixed spatial size.

    Args:
        target_size: ``(height, width)`` passed to ``tf.image.resize``.
        **kwargs: Forwarded to ``tf.keras.layers.Layer``.
    """
    def __init__(self, target_size, **kwargs):
        super(ResizeLayer, self).__init__(**kwargs)
        self.target_size = target_size

    def call(self, inputs):
        """Resize ``inputs`` to ``target_size``, keeping a floating input dtype."""
        resized = tf.image.resize(inputs, self.target_size, method='bilinear')
        # Bilinear tf.image.resize always returns float32. Under a mixed
        # precision policy that would silently promote float16 activations,
        # so cast back to the incoming floating dtype. Non-floating inputs keep
        # the float32 result, as before.
        input_dtype = tf.as_dtype(inputs.dtype)
        if input_dtype.is_floating and resized.dtype != input_dtype:
            resized = tf.cast(resized, input_dtype)
        return resized

    def get_config(self):
        """Return the layer config including ``target_size`` for serialisation."""
        config = super(ResizeLayer, self).get_config()
        config.update({"target_size": self.target_size})
        return config

def find_best_skip_layers(base_model, target_sizes):
    """Pick, for each target size, the layer whose output spatial size is closest.

    Only layers with a rank-4 output and fully known height and width are
    considered. The distance is ``|h - target_h| + |w - target_w|``; on ties
    the earliest layer in ``base_model.layers`` wins.

    Args:
        base_model: Object exposing ``layers``, each with ``name`` and
            ``output.shape``.
        target_sizes: Iterable of sequences whose first two items are the
            desired (height, width). Further items are ignored.

    Returns:
        List with one layer name per target size (``None`` when no layer
        qualifies).
    """
    layer_shapes = {layer.name: layer.output.shape for layer in base_model.layers}
    skip_connections = []
    for target_size in target_sizes:
        target_h, target_w = target_size[0], target_size[1]
        best_match = None
        min_diff = float('inf')
        for layer_name, shape in layer_shapes.items():
            if len(shape) != 4:  # only (batch, h, w, channels) feature maps
                continue
            height, width = shape[1], shape[2]
            # Dynamic spatial dims are reported as None and cannot be compared
            # numerically; skip them instead of raising TypeError.
            if height is None or width is None:
                continue
            diff = abs(height - target_h) + abs(width - target_w)
            if diff < min_diff:  # strict '<' keeps the first best match
                min_diff = diff
                best_match = layer_name
        skip_connections.append(best_match)
    return skip_connections

def conv_block(x, filters, kernel_size=3, padding='same', activation='relu'):
    """Apply Conv2D -> BatchNormalization -> Activation to ``x`` and return the result."""
    convolved = Conv2D(filters, kernel_size, padding=padding)(x)
    normalised = BatchNormalization()(convolved)
    return Activation(activation)(normalised)

def build_memory_efficient_unetpp_xception(input_shape=(224, 224, 3), num_classes=4, batch_size=None):
    """
    Build a reduced, UNet++-style segmentation model on an Xception encoder.

    The encoder is ``Xception(include_top=False, weights='imagenet')`` built on
    this function's own ``Input`` tensor. Four encoder activations, chosen by
    ``find_best_skip_layers``, feed a decoder made of ``UpSampling2D`` +
    ``conv_block`` stages that ends in a 1x1 softmax convolution.

    Args:
        input_shape: Input shape of the image, (height, width, channels).
            Must be a tuple when ``batch_size`` is given, because it is
            concatenated to ``(batch_size,)``.
        num_classes: Number of output channels of the softmax head.
        batch_size: If truthy, the input is created with this fixed batch
            size via ``batch_shape``; otherwise the batch dimension is left
            unspecified.

    Returns:
        Uncompiled Keras ``Model`` mapping ``inputs`` to the float32 softmax
        output of the segmentation head.

    Side effects:
        Prints the names of the selected skip-connection layers.
    """
    # Optionally fix the batch dimension of the input tensor.
    if batch_size:
        inputs = Input(batch_shape=(batch_size,) + input_shape)
    else:
        inputs = Input(shape=input_shape)

    # Xception without its classification top, wired to `inputs`.
    base_model = Xception(
        input_tensor=inputs, 
        include_top=False, 
        weights='imagenet'
    )

    # Freeze the first 50 entries of base_model.layers (not trained).
    for layer in base_model.layers[:50]:
        layer.trainable = False

    # Desired (height, width, channels) of the four skip tensors.
    # find_best_skip_layers only reads the first two items of each tuple;
    # the channel counts are informational.
    target_skip_sizes = [(56, 56, 128), (28, 28, 256), (14, 14, 728), (7, 7, 2048)]

    # Choose the encoder layers closest to those spatial sizes.
    skip_layer_names = find_best_skip_layers(base_model, target_sizes=target_skip_sizes)
    print("Selected skip connection layers:", skip_layer_names)

    # Encoder activations used as skips. The sizes in the trailing comments are
    # the author's approximate values for a 224x224 input -- TODO confirm.
    skip1 = base_model.get_layer(skip_layer_names[0]).output  # ~56x56
    skip2 = base_model.get_layer(skip_layer_names[1]).output  # ~28x28
    skip3 = base_model.get_layer(skip_layer_names[2]).output  # ~14x14
    skip4 = base_model.get_layer(skip_layer_names[3]).output  # ~7x7

    # 1x1 convolutions (no activation) shrink the skip channel counts.
    skip1 = Conv2D(64, 1, padding='same')(skip1)
    skip2 = Conv2D(128, 1, padding='same')(skip2)
    skip3 = Conv2D(256, 1, padding='same')(skip3)

    # Bottleneck: deepest skip reduced to 512 channels.
    bottleneck = Conv2D(512, 1, padding='same')(skip4)  # ~7x7 feature map

    # ----- Decoder -----

    # Level 4: bottleneck upsampled x2, then one conv block with 256 filters.
    up4 = UpSampling2D(size=(2, 2))(bottleneck)
    up4 = conv_block(up4, 256)

    # Resize only when the upsampled map does not match the skip's height/width.
    if up4.shape[1] != skip3.shape[1] or up4.shape[2] != skip3.shape[2]:
        up4 = ResizeLayer(target_size=(skip3.shape[1], skip3.shape[2]))(up4)

    # Node x0_4: fuse with skip3, single conv block.
    x0_4 = Concatenate()([up4, skip3])
    x0_4 = conv_block(x0_4, 256)

    # Level 3: x0_4 upsampled x2, conv block with 128 filters.
    up3 = UpSampling2D(size=(2, 2))(x0_4)
    up3 = conv_block(up3, 128)

    # Resize only on a spatial mismatch with skip2.
    if up3.shape[1] != skip2.shape[1] or up3.shape[2] != skip2.shape[2]:
        up3 = ResizeLayer(target_size=(skip2.shape[1], skip2.shape[2]))(up3)

    # Node x0_3: fuse with skip2.
    x0_3 = Concatenate()([up3, skip2])
    x0_3 = conv_block(x0_3, 128)

    # Node x1_3: bottleneck upsampled x4 directly, fused with x0_3.
    # NOTE(review): x1_3 is never used after this point, so these layers are
    # not on the path to `outputs` -- confirm whether this is intended.
    up1_3 = UpSampling2D(size=(4, 4))(bottleneck)
    up1_3 = conv_block(up1_3, 128)

    if up1_3.shape[1] != x0_3.shape[1] or up1_3.shape[2] != x0_3.shape[2]:
        up1_3 = ResizeLayer(target_size=(x0_3.shape[1], x0_3.shape[2]))(up1_3)

    x1_3 = Concatenate()([up1_3, x0_3])
    x1_3 = conv_block(x1_3, 128)

    # Level 2: x0_3 upsampled x2, conv block with 64 filters.
    up2 = UpSampling2D(size=(2, 2))(x0_3)
    up2 = conv_block(up2, 64)

    # Resize only on a spatial mismatch with skip1.
    if up2.shape[1] != skip1.shape[1] or up2.shape[2] != skip1.shape[2]:
        up2 = ResizeLayer(target_size=(skip1.shape[1], skip1.shape[2]))(up2)

    # Node x0_2: fuse with skip1.
    x0_2 = Concatenate()([up2, skip1])
    x0_2 = conv_block(x0_2, 64)

    # Node x1_2: a second upsampling branch from x0_3 (the same source as up2,
    # with its own conv block), fused with x0_2.
    up1_2 = UpSampling2D(size=(2, 2))(x0_3)
    up1_2 = conv_block(up1_2, 64)

    if up1_2.shape[1] != x0_2.shape[1] or up1_2.shape[2] != x0_2.shape[2]:
        up1_2 = ResizeLayer(target_size=(x0_2.shape[1], x0_2.shape[2]))(up1_2)

    x1_2 = Concatenate()([up1_2, x0_2])
    x1_2 = conv_block(x1_2, 64)

    # No x2_2 node is built; the next level is fed from x0_2 and x1_2 directly.

    # Level 1: x0_2 upsampled x4, conv block with 32 filters.
    up1 = UpSampling2D(size=(4, 4))(x0_2)
    up1 = conv_block(up1, 32)

    # Node x0_1 is the upsampled branch itself (no skip at this level).
    x0_1 = up1

    # Node x1_1: x1_2 upsampled x4, fused with x0_1.
    up1_1 = UpSampling2D(size=(4, 4))(x1_2)
    up1_1 = conv_block(up1_1, 32)

    if up1_1.shape[1] != x0_1.shape[1] or up1_1.shape[2] != x0_1.shape[2]:
        up1_1 = ResizeLayer(target_size=(x0_1.shape[1], x0_1.shape[2]))(up1_1)

    x1_1 = Concatenate()([up1_1, x0_1])
    x1_1 = conv_block(x1_1, 32)

    # x1_1 is the last decoder node; no deeper nested nodes are built.

    # Force the spatial size to input_shape[:2] if the decoder did not land on it.
    if x1_1.shape[1] != input_shape[0] or x1_1.shape[2] != input_shape[1]:
        x1_1 = ResizeLayer(target_size=(input_shape[0], input_shape[1]))(x1_1)

    # Segmentation head: 1x1 softmax conv, explicitly float32.
    outputs = Conv2D(num_classes, (1, 1), activation='softmax', dtype='float32')(x1_1)

    # Assemble the functional model.
    model = Model(inputs=inputs, outputs=outputs)

    return model

# Model hyper-parameters for this cell.
IMG_HEIGHT, IMG_WIDTH, CHANNELS = 224, 224, 3
NUM_CLASSES = 4
BATCH_SIZE = None  # passed through as batch_size; None leaves the batch dimension unspecified

# Instantiate the network.
model = build_memory_efficient_unetpp_xception(
    (IMG_HEIGHT, IMG_WIDTH, CHANNELS),
    NUM_CLASSES,
    BATCH_SIZE,
)

# Emit the usage hints in a single write (same text, one line per hint).
print("\n".join((
    "\nMemory Efficiency Tips:",
    "1. This model uses mixed precision (float16) for most operations",
    "2. Use a fixed batch size of 8 or less",
    "3. Consider reducing image size further if memory issues persist",
    "4. Clear session between epochs with tf.keras.backend.clear_session()",
    "5. Update your generator to match the fixed batch size:",
    "   train_generator = create_train_generator(X_train, y_train, batch_size=8)",
)))

model.summary()

In [None]:
# import tensorflow as tf
# from tensorflow.keras import backend as K
# from tensorflow.keras.optimizers import Adam
# from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau, ModelCheckpoint

# # ‚úÖ Dice Coefficient Metric
# def dice_coefficient(y_true, y_pred):
#     smooth = 1e-15
#     y_true = tf.cast(y_true, tf.float32)
#     y_pred = tf.cast(y_pred, tf.float32)
#     intersection = tf.reduce_sum(y_true * y_pred, axis=[1,2,3])
#     union = tf.reduce_sum(y_true, axis=[1,2,3]) + tf.reduce_sum(y_pred, axis=[1,2,3])
#     dice = (2. * intersection + smooth) / (union + smooth)
#     return tf.reduce_mean(dice)

# def weighted_categorical_crossentropy(y_true, y_pred):
#     class_weights = tf.constant([0.3794, 0.7521, 69.7061, 49.3458], dtype=tf.float32)

#     # Ensure y_true has the same shape as y_pred
#     y_true = tf.cast(y_true, tf.float32)  # Make sure it's float32 for numerical stability
#     y_pred = tf.clip_by_value(y_pred, K.epsilon(), 1.0)  # Avoid log(0)

#     # Compute categorical cross-entropy
#     loss = -tf.reduce_sum(y_true * tf.math.log(y_pred), axis=-1)  # Sum over the last axis (class axis)

#     # Reshape the class weights to match the loss shape
#     class_weights = tf.reshape(class_weights, (1, 1, 1, NUM_CLASSES))  # [1, 1, 1, 4]

#     # Apply the class weights
#     weighted_loss = loss * tf.reduce_sum(class_weights, axis=-1)  # Broadcast weights over the batch and spatial dimensions

 #     return tf.reduce_mean(weighted_loss)

# # ‚úÖ Dice Loss
# def dice_loss(y_true, y_pred):
#     smooth = 1e-6
#     y_true = tf.cast(y_true, y_pred.dtype)
#     intersection = tf.reduce_sum(y_true * y_pred, axis=[1, 2])
#     union = tf.reduce_sum(y_true, axis=[1, 2]) + tf.reduce_sum(y_pred, axis=[1, 2])
#     dice = (2. * intersection + smooth) / (union + smooth)
#     return 1 - tf.reduce_mean(dice)

# # ‚úÖ Combined Loss Function
# def combined_loss(y_true, y_pred):
#     return weighted_categorical_crossentropy(y_true, y_pred) + dice_loss(y_true, y_pred)

# # ‚úÖ Custom Dice Coefficient Metric for Each Class
# class DiceCoefficient(tf.keras.metrics.Metric):
#     def __init__(self, class_idx, name=None, **kwargs):  
#         if name is None:
#             name = f"DiceClass{class_idx}"  
#         super(DiceCoefficient, self).__init__(name=name, **kwargs)
        
#         self.class_idx = class_idx
#         self.dice = self.add_weight(name="dice", initializer="zeros")

#     def update_state(self, y_true, y_pred, sample_weight=None):
#         y_true_class = y_true[..., self.class_idx]
#         y_pred_class = y_pred[..., self.class_idx]

#         intersection = tf.reduce_sum(y_true_class * y_pred_class, axis=[1, 2])
#         union = tf.reduce_sum(y_true_class, axis=[1, 2]) + tf.reduce_sum(y_pred_class, axis=[1, 2])

#         dice = (2. * intersection + 1e-6) / (union + 1e-6)
#         self.dice.assign(tf.reduce_mean(dice))

#     def result(self):
#         return self.dice

# # ‚úÖ Function to Get Class-wise Metrics
# def class_wise_metrics(num_classes=4):
#     return [DiceCoefficient(i) for i in range(num_classes)] + [tf.keras.metrics.MeanIoU(num_classes=num_classes)]

# # ‚úÖ Compile the Model
# model.compile(
#     optimizer=Adam(learning_rate=1e-4),
#     loss=combined_loss,
#     metrics=class_wise_metrics(4)  # Number of classes (adjust if needed)
# )

# # ‚úÖ Train the Model with the **Filtered Dataset**
# history = model.fit(
#     X_train, y_train,  # Use the loaded and split data
#     validation_data=(X_val, y_val),
#     epochs=50,
#     batch_size=16,  # Adjust based on available resources
#     callbacks=[
#         EarlyStopping(
#             monitor='val_loss', 
#             patience=7, 
#             restore_best_weights=True
#         ),
#         ReduceLROnPlateau(
#             monitor='val_loss', 
#             factor=0.5, 
#             patience=3, 
#             min_lr=1e-6
#         ),
#         ModelCheckpoint(
#             '/kaggle/working/best_unet_model_filtered.keras',  # Save in Kaggle working directory with .keras extension
#             monitor='val_loss',
#             save_best_only=True
#         )
#     ]
# )
import tensorflow as tf
from tensorflow.keras import backend as K
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau, ModelCheckpoint
from tensorflow.keras.preprocessing.image import ImageDataGenerator

# ‚úÖ Dice Coefficient Metric
def dice_coefficient(y_true, y_pred):
    """Mean soft Dice score over the batch, summing over axes 1, 2 and 3 per sample."""
    eps = 1e-15
    reduce_axes = [1,2,3]
    truth = tf.cast(y_true, tf.float32)
    probs = tf.cast(y_pred, tf.float32)
    overlap = tf.reduce_sum(truth * probs, axis=reduce_axes)
    total = tf.reduce_sum(truth, axis=reduce_axes) + tf.reduce_sum(probs, axis=reduce_axes)
    per_sample = (2. * overlap + eps) / (total + eps)
    return tf.reduce_mean(per_sample)

# ‚úÖ Weighted Categorical Crossentropy
def weighted_categorical_crossentropy(y_true, y_pred):
    """Class-weighted categorical cross-entropy for one-hot targets.

    Each pixel's cross-entropy is scaled by the weight of its true class.

    Args:
        y_true: One-hot targets, classes on the last axis (4 classes).
        y_pred: Predicted class probabilities, same shape as ``y_true``.

    Returns:
        Scalar tensor: mean weighted cross-entropy over all pixels.
    """
    class_weights = tf.constant([0.3776, 0.7605, 65.8554, 46.2381], dtype=tf.float32)
    y_true = tf.cast(y_true, tf.float32)
    # Cast before clipping so float16 predictions cannot clash with float32 targets.
    y_pred = tf.clip_by_value(tf.cast(y_pred, tf.float32), K.epsilon(), 1.0)  # avoid log(0)
    pixel_ce = -tf.reduce_sum(y_true * tf.math.log(y_pred), axis=-1)
    # Select the weight of each pixel's true class. The previous code
    # multiplied by the sum of all class weights, a constant that rescales
    # the loss without weighting any class.
    pixel_weight = tf.reduce_sum(y_true * class_weights, axis=-1)
    return tf.reduce_mean(pixel_ce * pixel_weight)

# ‚úÖ Dice Loss
def dice_loss(y_true, y_pred):
    """Soft Dice loss: 1 minus the mean Dice score, reducing over axes 1 and 2."""
    eps = 1e-6
    spatial_axes = [1, 2]
    truth = tf.cast(y_true, y_pred.dtype)
    overlap = tf.reduce_sum(truth * y_pred, axis=spatial_axes)
    total = tf.reduce_sum(truth, axis=spatial_axes) + tf.reduce_sum(y_pred, axis=spatial_axes)
    score = (2. * overlap + eps) / (total + eps)
    return 1 - tf.reduce_mean(score)

# ‚úÖ Custom Dice Coefficient Metric for Each Class
class DiceCoefficient(tf.keras.metrics.Metric):
    """Streaming soft Dice coefficient for a single class channel.

    The metric averages the per-sample Dice score of channel ``class_idx``
    over every sample seen since the last reset.

    Args:
        class_idx: Index on the last axis of the channel to evaluate.
        name: Metric name; defaults to ``f"DiceClass{class_idx}"``.
        **kwargs: Forwarded to ``tf.keras.metrics.Metric``.
    """
    def __init__(self, class_idx, name=None, **kwargs):
        if name is None:
            name = f"DiceClass{class_idx}"
        super(DiceCoefficient, self).__init__(name=name, **kwargs)
        self.class_idx = class_idx
        # Running sum of per-sample Dice scores and number of samples seen.
        # Accumulating (instead of overwriting with the latest batch) makes the
        # reported value cover the whole epoch / validation pass.
        self.dice_sum = self.add_weight(name="dice_sum", initializer="zeros")
        self.sample_count = self.add_weight(name="sample_count", initializer="zeros")

    def update_state(self, y_true, y_pred, sample_weight=None):
        """Accumulate the Dice scores of one batch (``sample_weight`` is ignored)."""
        # Cast both tensors to the metric dtype so mixed dtypes cannot clash.
        y_true_class = tf.cast(y_true[..., self.class_idx], self.dtype)
        y_pred_class = tf.cast(y_pred[..., self.class_idx], self.dtype)
        intersection = tf.reduce_sum(y_true_class * y_pred_class, axis=[1, 2])
        union = tf.reduce_sum(y_true_class, axis=[1, 2]) + tf.reduce_sum(y_pred_class, axis=[1, 2])
        dice = (2. * intersection + 1e-6) / (union + 1e-6)
        self.dice_sum.assign_add(tf.reduce_sum(dice))
        self.sample_count.assign_add(tf.cast(tf.size(dice), self.dtype))

    def result(self):
        """Return the mean Dice over all accumulated samples (0 before any update)."""
        return self.dice_sum / tf.maximum(self.sample_count, 1.0)

    def get_config(self):
        """Include ``class_idx`` so the metric can be re-created from its config."""
        config = super(DiceCoefficient, self).get_config()
        config.update({"class_idx": self.class_idx})
        return config

# ‚úÖ Function to Get Class-wise Metrics
def class_wise_metrics(num_classes=4):
    """Build the metric list: one Dice metric per class plus mean IoU.

    Args:
        num_classes: Number of classes on the last axis of targets and predictions.

    Returns:
        List of ``num_classes`` ``DiceCoefficient`` metrics followed by one
        mean-IoU metric.
    """
    dice_metrics = [DiceCoefficient(i) for i in range(num_classes)]
    # The targets are one-hot and the model outputs per-class probabilities.
    # Plain MeanIoU expects integer class indices; OneHotMeanIoU takes the
    # argmax of both first. The name keeps the history key of the previous
    # MeanIoU metric.
    mean_iou = tf.keras.metrics.OneHotMeanIoU(num_classes=num_classes, name="mean_io_u")
    return dice_metrics + [mean_iou]

# ‚úÖ Create Data Generator
def create_train_generator(X, y, batch_size=16):
    """Yield endlessly augmented ``(image_batch, mask_batch)`` pairs.

    Images and masks go through two ``ImageDataGenerator`` flows configured
    with the same random transforms and the same seed, so each mask batch is
    produced with the same seed-driven shuffling and transform parameters as
    its image batch.

    Args:
        X: Image array, samples on the first axis.
        y: One-hot mask array aligned with ``X``.
        batch_size: Number of samples per yielded batch.

    Yields:
        Tuple ``(X_batch, y_batch)``; the generator never terminates.
    """
    data_gen_args = dict(
        rotation_range=20,
        width_shift_range=0.1,
        height_shift_range=0.1,
        zoom_range=0.2,
        horizontal_flip=True,
        fill_mode='nearest'
    )

    image_datagen = ImageDataGenerator(**data_gen_args)
    # Masks are warped with nearest-neighbour interpolation (order 0). The
    # default spline order 1 blends neighbouring pixels and turns one-hot
    # labels into fractional values along class borders.
    mask_datagen = ImageDataGenerator(interpolation_order=0, **data_gen_args)

    seed = 42  # shared seed keeps both flows in lock-step
    image_generator = image_datagen.flow(X, batch_size=batch_size, seed=seed)
    mask_generator = mask_datagen.flow(y, batch_size=batch_size, seed=seed)

    while True:
        X_batch = next(image_generator)
        y_batch = next(mask_generator)
        yield X_batch, y_batch

# ‚úÖ Create the generator
# Endless generator of augmented (image, mask) batches of 8 samples.
# The steps_per_epoch passed to model.fit must be computed with this same
# batch size.
train_generator = create_train_generator(X_train, y_train, batch_size=8)

# ‚úÖ Compile the Model
import tensorflow as tf
import numpy as np

def lovasz_softmax_loss(y_true, y_pred, ignore_background=False):
    """Lovasz-Softmax style loss averaged over classes.

    For every class the absolute errors of all pixels in the batch are sorted
    in descending order and dotted with the discrete gradient of the Jaccard
    loss.

    Args:
        y_true: One-hot targets, classes on the last axis.
        y_pred: Class probabilities, same shape as ``y_true``.
        ignore_background: If True, class 0 is left out of the average.

    Returns:
        Scalar float32 tensor: mean of the per-class losses.
    """
    y_true = tf.cast(y_true, tf.float32)
    y_pred = tf.cast(y_pred, tf.float32)

    num_classes = tf.shape(y_true)[-1]
    first_class = 1 if ignore_background else 0

    def compute_class_loss(c):
        y_true_class = y_true[..., c]
        y_pred_class = y_pred[..., c]

        y_true_flat = tf.reshape(y_true_class, [-1])
        y_pred_flat = tf.reshape(y_pred_class, [-1])

        # Sort the pixel errors from largest to smallest.
        errors = tf.abs(y_true_flat - y_pred_flat)
        errors_sorted, perm = tf.nn.top_k(errors, k=tf.shape(errors)[0], sorted=True)
        y_true_sorted = tf.gather(y_true_flat, perm)

        # Discrete gradient of the Jaccard loss along the sorted errors.
        gts = tf.reduce_sum(y_true_sorted)
        intersection = gts - tf.cumsum(y_true_sorted)
        union = gts + tf.cumsum(1. - y_true_sorted)
        jaccard = 1. - intersection / union
        grad = tf.concat([[jaccard[0]], jaccard[1:] - jaccard[:-1]], 0)

        return tf.tensordot(errors_sorted, grad, axes=1)

    # One slot per evaluated class. Sizing the array with num_classes while
    # starting the loop at class 1 left slot 0 unwritten, yet it was still
    # stacked and averaged.
    losses = tf.TensorArray(dtype=tf.float32, size=num_classes - first_class)

    def loop_cond(c, losses):
        return tf.less(c, num_classes)

    def loop_body(c, losses):
        loss_c = compute_class_loss(c)
        losses = losses.write(c - first_class, loss_c)
        return c + 1, losses

    _, losses = tf.while_loop(loop_cond, loop_body, [tf.constant(first_class), losses])
    return tf.reduce_mean(losses.stack())


def combined_loss(y_true, y_pred):
    """Training loss: soft Dice loss plus the Lovasz loss.

    Args:
        y_true: Target tensor, cast to the dtype of ``y_pred``.
        y_pred: Model output tensor.

    Returns:
        Scalar tensor ``lovasz_loss + dice_loss``.
    """
    eps = 1e-6
    # assumes tensors are (batch, height, width, classes) so that axes 1 and 2
    # are the spatial ones -- TODO confirm
    reduce_axes = [1, 2]
    y_true = tf.cast(y_true, y_pred.dtype)

    # Soft Dice: one score per remaining (non-reduced) entry, then averaged.
    overlap = tf.reduce_sum(y_true * y_pred, axis=reduce_axes)
    total = tf.reduce_sum(y_true, axis=reduce_axes) + tf.reduce_sum(y_pred, axis=reduce_axes)
    dice_scores = (2. * overlap + eps) / (total + eps)
    dice_term = 1 - tf.reduce_mean(dice_scores)

    # NOTE(review): the Dice term above consumes y_pred as-is, while the
    # Lovasz term below receives softmax(y_pred). If the model already ends in
    # a softmax this applies softmax twice; if it emits logits the Dice term
    # is computed on logits. Confirm the model's output activation.
    lovasz_term = lovasz_softmax_loss(
        y_true,
        tf.nn.softmax(y_pred),
        ignore_background=False,
    )

    return lovasz_term + dice_term


# Usage in model compilation
model.compile(
    optimizer=Adam(learning_rate=0.0001),
    loss=combined_loss,
    metrics=class_wise_metrics(4)  # Number of classes
)

# Train the model with the data generator.
# The generator never terminates, so Keras relies on steps_per_epoch to decide
# where an epoch ends. It has to be computed from the generator's own batch
# size (train_generator is created with batch_size=8); dividing by 16 made
# every "epoch" cover only about half of X_train.
TRAIN_BATCH_SIZE = 8  # must equal the batch_size passed to create_train_generator
history = model.fit(
    train_generator,
    steps_per_epoch=len(X_train) // TRAIN_BATCH_SIZE,
    validation_data=(X_val, y_val),
    epochs=100,
    callbacks=[
        # Stop after 10 epochs without val_loss improvement and roll back to
        # the best weights seen.
        EarlyStopping(
            monitor='val_loss',
            patience=10,
            restore_best_weights=True
        ),
        # Halve the learning rate after 3 stagnant epochs, down to 1e-6.
        ReduceLROnPlateau(
            monitor='val_loss',
            factor=0.5,
            patience=3,
            min_lr=1e-6
        ),
        # Keep only the checkpoint with the lowest val_loss on disk.
        ModelCheckpoint(
            'lovaszloss_unet++_xception.keras',
            monitor='val_loss',
            save_best_only=True
        )
    ]
)

In [None]:
import matplotlib.pyplot as plt

# Per-epoch metric series recorded by model.fit.
history_dict = history.history

fig = plt.figure(figsize=(12, 6))

# Left panel: training vs. validation loss.
loss_ax = fig.add_subplot(1, 2, 1)
loss_ax.plot(history_dict['loss'], label='Training Loss')
loss_ax.plot(history_dict['val_loss'], label='Validation Loss')
loss_ax.set_title('Loss Curves')
loss_ax.set_xlabel('Epochs')
loss_ax.set_ylabel('Loss')
loss_ax.legend()

# Right panel: drawn only when an 'accuracy' series was recorded.
if 'accuracy' in history_dict:
    acc_ax = fig.add_subplot(1, 2, 2)
    acc_ax.plot(history_dict['accuracy'], label='Training Accuracy')
    acc_ax.plot(history_dict['val_accuracy'], label='Validation Accuracy')
    acc_ax.set_title('Accuracy Curves')
    acc_ax.set_xlabel('Epochs')
    acc_ax.set_ylabel('Accuracy')
    acc_ax.legend()

plt.tight_layout()
plt.show()

In [None]:
import numpy as np
import tensorflow as tf
from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score
from scipy.spatial.distance import directed_hausdorff
from tensorflow.keras.utils import to_categorical

# ✅ RGB to Class Index Conversion (for the test masks)
RGB_TO_CLASS = {
    (255, 0, 0): 1,  # Brain
    (0, 255, 0): 2,  # CSP
    (0, 0, 255): 3,  # LV
    (0, 0, 0): 0     # Background
}


def rgb_to_class_mask(rgb_mask):
    """Translate a colour-coded mask into a 2-D array of class indices.

    Args:
        rgb_mask: Array whose last axis holds the colour triple of each pixel
            and whose first two axes are the pixel grid.

    Returns:
        Integer array of shape ``rgb_mask.shape[:2]``. Pixels whose colour
        exactly equals a key of ``RGB_TO_CLASS`` get that key's index; every
        other pixel stays 0.
    """
    rows, cols = rgb_mask.shape[:2]
    labels = np.full((rows, cols), 0, dtype=int)

    for colour, label in RGB_TO_CLASS.items():
        # A pixel matches only if all of its channels equal the colour.
        hits = np.equal(rgb_mask, np.array(colour)).all(axis=-1)
        np.putmask(labels, hits, label)

    return labels

# ✅ Function to calculate Dice Similarity Coefficient (DSC)
def dice_coefficient(y_true, y_pred):
    """Dice similarity coefficient of two masks.

    Args:
        y_true: Reference mask array.
        y_pred: Predicted mask array, multiplied element-wise with ``y_true``.

    Returns:
        ``(2 * sum(y_true * y_pred) + eps) / (sum(y_true) + sum(y_pred) + eps)``
        with ``eps = 1e-15``; two all-zero masks therefore score 1.0.
    """
    eps = 1e-15
    overlap = np.sum(y_true * y_pred)
    mass = np.sum(y_true) + np.sum(y_pred)
    numerator = 2. * overlap + eps
    denominator = mass + eps
    return numerator / denominator

# ✅ Function to calculate IoU (Intersection over Union)
def iou(y_true, y_pred):
    """Intersection over Union (Jaccard index) of two masks.

    The same smoothing term is added to numerator and denominator, matching
    ``dice_coefficient``. Without it, two empty masks (a class that is absent
    and correctly not predicted) scored 0.0 here but 1.0 for Dice, which made
    the averaged per-class numbers disagree.

    Args:
        y_true: Reference mask array.
        y_pred: Predicted mask array, multiplied element-wise with ``y_true``.

    Returns:
        ``(intersection + eps) / (union + eps)`` with ``eps = 1e-15``;
        1.0 when both masks are empty.
    """
    smooth = 1e-15
    intersection = np.sum(y_true * y_pred)
    union = np.sum(y_true) + np.sum(y_pred) - intersection
    return (intersection + smooth) / (union + smooth)

# ✅ Function to calculate Hausdorff Distance
def hausdorff_distance(y_true, y_pred):
    """Symmetric Hausdorff distance between the ``== 1`` pixels of two masks.

    Args:
        y_true: Reference mask array.
        y_pred: Predicted mask array.

    Returns:
        The larger of the two directed Hausdorff distances, measured in index
        units, or ``float('inf')`` when either mask has no pixel equal to 1.
    """
    # One row of indices per pixel equal to 1.
    coords_true = np.transpose(np.nonzero(y_true == 1))
    coords_pred = np.transpose(np.nonzero(y_pred == 1))

    # The distance is undefined against an empty point set.
    if coords_true.shape[0] == 0 or coords_pred.shape[0] == 0:
        return float('inf')

    both_directions = (
        directed_hausdorff(coords_true, coords_pred)[0],
        directed_hausdorff(coords_pred, coords_true)[0],
    )
    return max(both_directions)

# ✅ Function to calculate Average Surface Distance (ASD)
def average_surface_distance(y_true, y_pred):
    """Mean distance from each reference pixel to its nearest predicted pixel.

    The measure is directed (reference -> prediction only) and uses every
    pixel equal to 1, not just boundary pixels. Nearest neighbours are found
    with a KD-tree instead of a Python loop over all point pairs, which gives
    the same value at a fraction of the cost on full-size masks.

    Args:
        y_true: Reference mask array.
        y_pred: Predicted mask array.

    Returns:
        Mean Euclidean distance in index units, or ``float('inf')`` when
        either mask has no pixel equal to 1.
    """
    # Local import keeps this function self-contained; SciPy is already a
    # dependency of this notebook (directed_hausdorff).
    from scipy.spatial import cKDTree

    # One row of indices per pixel equal to 1.
    true_points = np.argwhere(y_true == 1)
    pred_points = np.argwhere(y_pred == 1)

    if len(true_points) == 0 or len(pred_points) == 0:
        return float('inf')  # No points for either the true or the predicted class

    # Distance from every reference point to its closest predicted point.
    nearest_distances, _ = cKDTree(pred_points).query(true_points, k=1)
    return np.mean(nearest_distances)

# ✅ Function to evaluate the model on the test set class-wise
def evaluate_classwise_metrics(model, X_test, y_test, num_classes=4, batch_size=16):
    """Print per-class segmentation scores and the model's own test metrics.

    Args:
        model: Object providing ``predict``, ``evaluate`` and
            ``metrics_names``.
        X_test: Test inputs; ``len(X_test)`` samples are scored.
        y_test: Test targets with the class axis last (assumed one-hot; the
            argmax over that axis is taken as the reference label).
        num_classes: Number of class indices to score, starting at 0.
        batch_size: Batch size forwarded to ``predict`` and ``evaluate``.

    Returns:
        None. All results are written to stdout.
    """
    # Collapse the class axis into one label per pixel.
    pred_labels = np.argmax(model.predict(X_test, batch_size=batch_size), axis=-1)
    true_labels = np.argmax(y_test, axis=-1)

    metric_keys = ('dice', 'iou', 'precision', 'recall', 'f1', 'accuracy', 'hausdorff', 'asd')
    class_metrics = {
        class_idx: {key: [] for key in metric_keys}
        for class_idx in range(num_classes)
    }

    # Score every sample one class at a time on binary (class vs. rest) masks.
    for sample_idx in range(len(X_test)):
        true_mask = true_labels[sample_idx]
        pred_mask = pred_labels[sample_idx]

        for class_idx in range(num_classes):
            true_class_mask = (true_mask == class_idx).astype(int)
            pred_class_mask = (pred_mask == class_idx).astype(int)
            # The scikit-learn scorers expect 1-D label vectors.
            true_flat = true_class_mask.flatten()
            pred_flat = pred_class_mask.flatten()

            scores = class_metrics[class_idx]
            scores['dice'].append(dice_coefficient(true_class_mask, pred_class_mask))
            scores['iou'].append(iou(true_class_mask, pred_class_mask))
            scores['precision'].append(precision_score(true_flat, pred_flat, zero_division=0))
            scores['recall'].append(recall_score(true_flat, pred_flat, zero_division=0))
            scores['f1'].append(f1_score(true_flat, pred_flat, zero_division=0))
            scores['accuracy'].append(accuracy_score(true_flat, pred_flat))
            # Hausdorff distance and average surface distance are not
            # computed here, so the 'hausdorff' and 'asd' lists stay empty.

    # NOTE(review): this header names Hausdorff/ASD columns and suggests a
    # table, but the report below is a per-class list without those two.
    print(f"{'Class':<10}{'Dice Coefficient (%)':<20}{'IoU (%)':<20}{'Precision (%)':<20}{'Recall (%)':<20}{'F1 Score (%)':<20}{'Accuracy (%)':<20}{'Hausdorff Distance':<20}{'Avg Surface Distance':<20}")
    print('-' * 180)

    # (printed label, key in class_metrics) in report order.
    report_rows = (
        ('Dice Coefficient', 'dice'),
        ('IoU', 'iou'),
        ('Precision', 'precision'),
        ('Recall', 'recall'),
        ('F1 Score', 'f1'),
        ('Accuracy', 'accuracy'),
    )
    for class_idx in range(num_classes):
        print(f"Class {class_idx}:")
        for label, key in report_rows:
            print(f"  {label}: {np.mean(class_metrics[class_idx][key]) * 100:.2f}%")
        print("-" * 180)

    # Loss and compiled metrics as reported by the model itself.
    test_loss, *compiled_values = model.evaluate(X_test, y_test, batch_size=batch_size)
    print(f"Test Loss: {test_loss:.4f}")

    compiled_names = model.metrics_names[1:]  # first entry is the loss
    for metric, value in zip(compiled_names, compiled_values):
        print(f"{metric}: {value:.4f}")

# ✅ Call the evaluation function on the test set class-wise
# Runs prediction, per-class scoring and model.evaluate on X_test / y_test with
# the function's defaults (num_classes=4, batch_size=16). Results are printed;
# nothing is returned.
evaluate_classwise_metrics(model, X_test, y_test)