In [None]:
# Mount Google Drive at /content/drive (Colab only); the dataset directories
# configured later in this notebook are paths under this mount point.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
import os
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.applications import ResNet50
from tensorflow.keras.applications import EfficientNetB0
from tensorflow.keras.preprocessing.image import ImageDataGenerator
import matplotlib.pyplot as plt
import cv2
from sklearn.model_selection import train_test_split
from IPython.display import clear_output

In [None]:
# --- 0. Configuration and Setup ---
# Network input geometry (pixels), number of output channels and batch size.
IMG_WIDTH, IMG_HEIGHT = 256, 256
NUM_CLASSES = 1
BATCH_SIZE = 8

# Epoch budget for each of the three training stages.
EPOCHS_DECODER_TRAINING = 1
EPOCHS_FULL_FINE_TUNE = 1
EPOCHS_PSEUDO_LABELING_TRAINING = 1

# Learning rate for each stage; it shrinks from stage to stage.
LEARNING_RATE_DECODER = 1e-4
LEARNING_RATE_FINE_TUNE = 1e-5
LEARNING_RATE_PSEUDO_LABELING = 5e-6

# Acceptance criteria used when filtering pseudo-labels.
PSEUDO_LABEL_CONFIDENCE_THRESHOLD = 0.95
PSEUDO_LABEL_MIN_MASK_PIXELS = 10

# Dataset locations on the mounted Drive.
DATA_DIR = '/content/drive/MyDrive/Colab/PHD/Cerebellar/data_cell_new_200'
LABELED_IMAGES_DIR = os.path.join(DATA_DIR, 'image')
LABELED_MASKS_DIR = os.path.join(DATA_DIR, 'ground_truth')
UNLABELED_IMAGES_DIR = '/content/drive/MyDrive/Colab/PHD/Cerebellar/Unlabeled'

# Create any missing data directory so later os.listdir calls do not fail.
for _data_dir in (LABELED_IMAGES_DIR, LABELED_MASKS_DIR, UNLABELED_IMAGES_DIR):
    os.makedirs(_data_dir, exist_ok=True)

In [None]:
# --- 1. Data Loading and Preprocessing Functions ---

def load_image_and_mask(image_path, mask_path, target_size=(IMG_WIDTH, IMG_HEIGHT)):
    """Load one RGB image and its binary mask, resized to `target_size`.

    Args:
        image_path: Path of the image file (read with `cv2.imread`).
        mask_path: Path of the mask file (read as a single-channel image).
        target_size: `(width, height)` tuple passed to `cv2.resize`.

    Returns:
        A tuple `(image, mask)`:
        - image: float32 array with three RGB channels, scaled to [0, 1].
        - mask: float32 array with a trailing channel axis; 1.0 wherever the
          stored mask pixel is non-zero, 0.0 elsewhere.

    Raises:
        OSError: If either file is missing or cannot be decoded as an image.
    """
    image = cv2.imread(image_path)
    # cv2.imread signals failure by returning None instead of raising, which
    # would otherwise surface as an opaque cv2.error inside cvtColor below.
    if image is None:
        raise OSError(f"Could not read image file: {image_path}")
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    image = cv2.resize(image, target_size)
    image = image.astype(np.float32) / 255.0

    mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
    if mask is None:
        raise OSError(f"Could not read mask file: {mask_path}")
    # Nearest-neighbour keeps the mask free of interpolated in-between values.
    mask = cv2.resize(mask, target_size, interpolation=cv2.INTER_NEAREST)
    mask = np.expand_dims(mask, axis=-1)
    mask = (mask > 0).astype(np.float32)

    return image, mask

def load_dataset(image_dir, mask_dir):
    """Load every image in `image_dir` that has a matching mask in `mask_dir`.

    The mask for `<name><ext>` is expected at `<name>_G<ext>` (same extension).
    Images without such a mask are reported and skipped.

    Args:
        image_dir: Directory containing .png/.jpg/.jpeg images.
        mask_dir: Directory containing the corresponding `_G` masks.

    Returns:
        `(images, masks)` as two NumPy arrays stacked in sorted filename order.

    Raises:
        ValueError: If no image/mask pair could be matched.
    """
    supported_exts = ('.png', '.jpg', '.jpeg')
    candidates = sorted(
        name for name in os.listdir(image_dir) if name.lower().endswith(supported_exts)
    )

    pairs = []
    for img_name in candidates:
        stem, suffix = os.path.splitext(img_name)
        mask_path = os.path.join(mask_dir, f"{stem}_G{suffix}")

        if not os.path.exists(mask_path):
            print(f"Warning: No mask found for {img_name} at {mask_path}. Skipping this pair.")
            continue

        pairs.append(load_image_and_mask(os.path.join(image_dir, img_name), mask_path))

    if not pairs:
        raise ValueError("No matching image/mask pairs found with the '_G' suffix convention. Please check your filenames.")

    loaded_images, loaded_masks = zip(*pairs)
    return np.array(loaded_images), np.array(loaded_masks)


In [None]:
# Load the labeled image/mask pairs.
try:
    X_labeled, y_labeled = load_dataset(LABELED_IMAGES_DIR, LABELED_MASKS_DIR)
    print(f"Loaded {len(X_labeled)} labeled images and masks.")
except Exception as e:
    print(f"Error loading labeled data: {e}")
    print("Please ensure your data is organized as specified and filenames match.")
    # Fall back to empty arrays so the split below takes its "no labeled data"
    # branch instead of failing with a NameError on an undefined X_labeled.
    X_labeled, y_labeled = np.array([]), np.array([])


# Load unlabeled images (same preprocessing as the labeled images: RGB,
# resized to the network input size, scaled to [0, 1]).
try:
    unlabeled_image_filenames = sorted([f for f in os.listdir(UNLABELED_IMAGES_DIR) if f.lower().endswith(('.png', '.jpg', '.jpeg'))])
    X_unlabeled = []
    for img_name in unlabeled_image_filenames:
        img_path = os.path.join(UNLABELED_IMAGES_DIR, img_name)
        image = cv2.imread(img_path)
        # cv2.imread returns None for unreadable files; skip just that file
        # rather than discarding the whole unlabeled set.
        if image is None:
            print(f"Warning: Could not read unlabeled image {img_path}. Skipping.")
            continue
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image = cv2.resize(image, (IMG_WIDTH, IMG_HEIGHT))
        image = image.astype(np.float32) / 255.0
        X_unlabeled.append(image)
    X_unlabeled = np.array(X_unlabeled)
    print(f"Loaded {len(X_unlabeled)} unlabeled images.")
except Exception as e:
    print(f"Error loading unlabeled data: {e}")
    print("Please ensure your 'Unlabeled' directory contains images.")
    X_unlabeled = np.array([])


if len(X_labeled) > 0:
    # First split: 75% training, 25% held out.
    X_train, X_val_, y_train, y_val_ = train_test_split(X_labeled, y_labeled, test_size=0.25, random_state=42)
    # Second split of the held-out part. train_test_split returns the larger
    # "train" portion first, so X_test receives 80% of the held-out samples
    # and X_val the remaining 20%.
    X_test, X_val, y_test, y_val = train_test_split(X_val_, y_val_, test_size=0.2, random_state=42)
    print(f"Training on {len(X_train)} images, validating on {len(X_val)} images, testing on {len(X_test)} images.")
else:
    print("Skipping data split as no labeled data was loaded.")
    X_train, X_val, y_train, y_val = np.array([]), np.array([]), np.array([]), np.array([])
    X_test, y_test = np.array([]), np.array([])

Loaded 200 labeled images and masks.
Loaded 51 unlabeled images.
Training on 150 images, validating on 10 images, testing on 40 images.


In [None]:
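# --- 2. Baseline U-Net (disabled: kept inside a string literal for reference only; the attention U-Net defined later in this notebook is the one actually used) ---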
"""
def unet_model(input_size=(IMG_WIDTH, IMG_HEIGHT, 3), num_classes=1, backbone='resnet50'):
    inputs = keras.Input(input_size)

    # Encoder
    if backbone == 'resnet50':
        base_model = ResNet50(weights='imagenet', include_top=False, input_tensor=inputs)
        skip_connections = [
            base_model.get_layer('conv1_relu').output,    # After initial conv (high resolution)
            base_model.get_layer('conv2_block3_out').output, # After block 2
            base_model.get_layer('conv3_block4_out').output, # After block 3
            base_model.get_layer('conv4_block6_out').output, # After block 4 (lowest resolution before bottleneck)
        ]
        bottleneck = base_model.get_layer('conv5_block3_out').output

    elif backbone == 'efficientnetb0':
        base_model = EfficientNetB0(weights='imagenet', include_top=False, input_tensor=inputs)
        # Define skip connection layers for EfficientNetB0.
        skip_connections = [
            base_model.get_layer('block2a_expand_activation').output,
            base_model.get_layer('block3a_expand_activation').output,
            base_model.get_layer('block4a_expand_activation').output,
            base_model.get_layer('block5a_expand_activation').output,
        ]
        bottleneck = base_model.get_layer('top_activation').output

    else:
        raise ValueError("Unsupported backbone. Choose 'resnet50' or 'efficientnetb0'.")

    # Decoder (Upsampling Path)
    x = bottleneck
    skip_connections.reverse()

    for i, skip_feature in enumerate(skip_connections):
        x = layers.UpSampling2D(size=(2, 2), interpolation='bilinear')(x)
        x = layers.Concatenate()([x, skip_feature])
        filters = max(32, 256 // (2**i))
        x = layers.Conv2D(filters, 3, activation='relu', padding='same', kernel_initializer='he_normal')(x)
        x = layers.Conv2D(filters, 3, activation='relu', padding='same', kernel_initializer='he_normal')(x)
        x = layers.Dropout(0.3)(x)
    x = layers.UpSampling2D(size=(2, 2), interpolation='bilinear')(x)
    x = layers.Conv2D(16, 3, activation='relu', padding='same', kernel_initializer='he_normal')(x)
    x = layers.Conv2D(16, 3, activation='relu', padding='same', kernel_initializer='he_normal')(x)
    x = layers.Dropout(0.2)(x)

    if num_classes == 1:
        outputs = layers.Conv2D(1, 1, activation='sigmoid')(x)
    else:
        outputs = layers.Conv2D(num_classes, 1, activation='softmax')(x)

    model = keras.Model(inputs=inputs, outputs=outputs)
    return model
"""

In [None]:
from tensorflow.keras import backend as K
from tensorflow.keras import layers
import tensorflow as tf
# --- 2. Build the Novel U-Net Model (CoSE-U-Net) ---

# Squeeze-and-Excitation (SE) Block


def se_block(input_feature, ratio=32):
    """Squeeze-and-Excitation block: re-weight channels with a learned gate.

    The feature map is globally average-pooled, passed through a two-layer
    bottleneck (ReLU then sigmoid) and the resulting per-channel weights are
    multiplied back onto the input.

    Args:
        input_feature: Keras tensor whose last axis is the channel axis and
            has a statically known size.
        ratio: Channel reduction factor of the bottleneck.

    Returns:
        A tensor with the same shape as `input_feature`.
    """
    filters = K.int_shape(input_feature)[-1]
    assert filters is not None, "Channel dimension must be defined."

    # Never let the bottleneck collapse to zero units: with fewer channels
    # than `ratio` (e.g. 16 channels and ratio=32) `filters // ratio` is 0.
    squeeze_units = max(1, filters // ratio)

    se = layers.GlobalAveragePooling2D()(input_feature)  # (None, filters)
    se = layers.Dense(squeeze_units, activation='relu', kernel_initializer='he_normal', use_bias=False)(se)
    se = layers.Dense(filters, activation='sigmoid', kernel_initializer='he_normal', use_bias=False)(se)
    # (None, 1, 1, filters) so the gate broadcasts over the spatial axes.
    # A built-in Reshape layer is used rather than a Lambda with a Python closure.
    se = layers.Reshape((1, 1, filters))(se)
    return layers.Multiply()([input_feature, se])

def coord_attention_block(x):
    """Coordinate-attention block: gate features along height and width.

    The input is average-pooled separately along the width and the height, the
    two pooled descriptors are transformed jointly by a shared 1x1 convolution,
    split again, and turned into two sigmoid attention maps that are multiplied
    onto the input.

    Args:
        x: Keras feature tensor laid out as (batch, height, width, channels)
            with a statically known channel count.

    Returns:
        A tensor with the same shape as `x`.
    """
    filters = x.shape[-1]
    assert filters is not None, "The channel dimension must be defined."
    reduced_filters = max(8, filters // 32)
    height, width = x.shape[1], x.shape[2]

    # Height descriptor: mean over the width axis -> (B, H, 1, C)
    pool_h = layers.Lambda(lambda t: tf.reduce_mean(t, axis=2, keepdims=True))(x)

    # Width descriptor: mean over the height axis -> (B, 1, W, C),
    # then transposed to (B, W, 1, C) so both descriptors share a layout.
    pool_w = layers.Lambda(lambda t: tf.reduce_mean(t, axis=1, keepdims=True))(x)
    pool_w = layers.Lambda(lambda t: tf.transpose(t, [0, 2, 1, 3]))(pool_w)

    # Stack the two descriptors along axis 1 -> (B, H+W, 1, C)
    concat = layers.Concatenate(axis=1)([pool_h, pool_w])

    # Shared transformation
    shared = layers.Conv2D(reduced_filters, kernel_size=1, activation='relu', padding='same')(concat)
    shared = layers.BatchNormalization()(shared)

    # Split back into the height part (first H rows) and the width part (last
    # W rows). An equal two-way split is only correct for square feature maps,
    # so use the explicit sizes whenever the static shape is known.
    if height is not None and width is not None:
        split_sizes = [height, width]
    else:
        split_sizes = 2  # unknown static shape: assumes H == W
    split_h, split_w = layers.Lambda(
        lambda t: tf.split(t, num_or_size_splits=split_sizes, axis=1)
    )(shared)
    split_w = layers.Lambda(lambda t: tf.transpose(t, [0, 2, 1, 3]))(split_w)  # (B, 1, W, C')

    # Attention maps: (B, H, 1, C) and (B, 1, W, C)
    attn_h = layers.Conv2D(filters, kernel_size=1, activation='sigmoid', padding='same')(split_h)
    attn_w = layers.Conv2D(filters, kernel_size=1, activation='sigmoid', padding='same')(split_w)

    # Apply both attentions (broadcast over the pooled axes)
    out = layers.Multiply()([x, attn_h, attn_w])
    return out


def unet_model(input_size=(256, 256, 3), num_classes=1, backbone='resnet50'):
    """Build the attention U-Net on top of an ImageNet-pretrained encoder.

    Four encoder feature maps are passed through `se_block` and used as skip
    connections; the encoder's deepest feature map is passed through
    `coord_attention_block` and used as the bottleneck. The decoder upsamples
    five times by a factor of two (four skip stages plus one final stage).

    Args:
        input_size: Shape of one input image, `(height, width, channels)`.
        num_classes: 1 gives a single sigmoid output channel; any other value
            gives a `num_classes`-channel softmax output.
        backbone: `'resnet50'` or `'efficientnetb0'`.

    Returns:
        An uncompiled `keras.Model`.

    Raises:
        ValueError: If `backbone` is not one of the supported names.

    NOTE(review): the loaders in this notebook scale images to [0, 1], while
    Keras documents EfficientNet as expecting [0, 255] inputs (it rescales
    internally) and ResNet50 as expecting `preprocess_input` -- confirm the
    intended preprocessing.
    """
    inputs = keras.Input(input_size)

    # Per-backbone description: constructor, skip layers (shallow -> deep)
    # and the layer whose output serves as the bottleneck.
    if backbone == 'resnet50':
        encoder_cls = ResNet50
        skip_layer_names = (
            'conv1_relu',
            'conv2_block3_out',
            'conv3_block4_out',
            'conv4_block6_out',
        )
        bottleneck_layer_name = 'conv5_block3_out'
    elif backbone == 'efficientnetb0':
        encoder_cls = EfficientNetB0
        skip_layer_names = (
            'block2a_expand_activation',
            'block3a_expand_activation',
            'block4a_expand_activation',
            'block5a_expand_activation',
        )
        bottleneck_layer_name = 'top_activation'
    else:
        raise ValueError("Unsupported backbone. Choose 'resnet50' or 'efficientnetb0'.")

    # Encoder (downsampling path) built directly on `inputs`.
    base_model = encoder_cls(weights='imagenet', include_top=False, input_tensor=inputs)

    # SE-enhanced skip features, shallowest first.
    skip_connections = [
        se_block(base_model.get_layer(layer_name).output)
        for layer_name in skip_layer_names
    ]

    # Coordinate attention on the deepest encoder feature map.
    x = coord_attention_block(base_model.get_layer(bottleneck_layer_name).output)

    # Decoder (upsampling path): consume the skips from deepest to shallowest.
    for stage, skip_feature in enumerate(reversed(skip_connections)):
        x = layers.UpSampling2D(size=(2, 2), interpolation='bilinear')(x)
        x = layers.Concatenate()([x, skip_feature])

        # The filter count halves at every stage (128 at stage 0), floored at 16.
        stage_filters = max(16, 128 // (2 ** stage))
        for _ in range(2):
            x = layers.Conv2D(stage_filters, 3, activation='relu', padding='same',
                              kernel_initializer='he_normal')(x)
        x = layers.Dropout(0.3)(x)
        x = se_block(x)

    # Final upsampling stage (no skip connection).
    x = layers.UpSampling2D(size=(2, 2), interpolation='bilinear')(x)
    for _ in range(2):
        x = layers.Conv2D(8, 3, activation='relu', padding='same',
                          kernel_initializer='he_normal')(x)
    x = layers.Dropout(0.2)(x)

    # Segmentation head.
    if num_classes == 1:
        head = layers.Conv2D(1, 1, activation='sigmoid')
    else:
        head = layers.Conv2D(num_classes, 1, activation='softmax')
    outputs = head(x)

    return keras.Model(inputs=inputs, outputs=outputs)


In [None]:
# --- 3. Define Loss Functions and Metrics ---

def dice_coeff(y_true, y_pred):
    """Soft Dice coefficient over all elements of the batch.

    Both tensors are flattened and cast to float32; a small constant keeps the
    ratio defined when both are all-zero.
    """
    eps = 1e-6
    truth = tf.cast(tf.reshape(y_true, [-1]), tf.float32)
    pred = tf.cast(tf.reshape(y_pred, [-1]), tf.float32)
    overlap = tf.reduce_sum(truth * pred)
    total = tf.reduce_sum(truth) + tf.reduce_sum(pred)
    return (2. * overlap + eps) / (total + eps)

def dice_loss(y_true, y_pred):
    """Dice loss: one minus the soft Dice coefficient."""
    overlap_score = dice_coeff(y_true, y_pred)
    return 1.0 - overlap_score

def bce_dice_loss(y_true, y_pred):
    """Equal-weight combination of binary cross-entropy and Dice loss."""
    bce = tf.keras.losses.BinaryCrossentropy()
    bce_term = bce(y_true, y_pred)
    dice_term = dice_loss(y_true, y_pred)
    return 0.5 * bce_term + 0.5 * dice_term

#-------------------------------------------------------------------------------
#-------------------------------------------------------------------------------
#-------------------------------------------------------------------------------
#-------------------------------------------------------------------------------
# METRICS

def jaccard_index(y_true, y_pred, smooth=1e-6):
    """Soft Jaccard index (intersection over union) across the whole batch.

    `smooth` is added to numerator and denominator to avoid division by zero.
    """
    truth, pred = K.flatten(y_true), K.flatten(y_pred)
    overlap = K.sum(truth * pred)
    union = K.sum(truth) + K.sum(pred) - overlap
    return (overlap + smooth) / (union + smooth)

def dice_coefficient(y_true, y_pred, smooth=1e-6):
    """Soft Dice coefficient across the whole batch, used as a metric.

    `smooth` is added to numerator and denominator to avoid division by zero.
    """
    truth, pred = K.flatten(y_true), K.flatten(y_pred)
    overlap = K.sum(truth * pred)
    total = K.sum(truth) + K.sum(pred)
    return (2. * overlap + smooth) / (total + smooth)

# Precision
def precision_metric(y_true, y_pred, smooth=1e-6):
    """Pixel precision: true positives / predicted positives.

    Both inputs are clipped to [0, 1] and rounded to binary masks first;
    `smooth` keeps the ratio defined when nothing is predicted positive.
    """
    pred_mask = K.round(K.clip(y_pred, 0, 1))
    true_mask = K.round(K.clip(y_true, 0, 1))
    hits = K.sum(true_mask * pred_mask)
    return (hits + smooth) / (K.sum(pred_mask) + smooth)

# Recall
def recall_metric(y_true, y_pred, smooth=1e-6):
    """Pixel recall: true positives / actual positives.

    Both inputs are clipped to [0, 1] and rounded to binary masks first;
    `smooth` keeps the ratio defined when the ground truth has no positives.
    """
    pred_mask = K.round(K.clip(y_pred, 0, 1))
    true_mask = K.round(K.clip(y_true, 0, 1))
    hits = K.sum(true_mask * pred_mask)
    return (hits + smooth) / (K.sum(true_mask) + smooth)

# Accuracy
def accuracy_metric(y_true, y_pred):
    """Pixel accuracy: fraction of elements whose binarised values agree."""
    pred_mask = K.round(K.clip(y_pred, 0, 1))
    true_mask = K.round(K.clip(y_true, 0, 1))
    agreement = K.equal(true_mask, pred_mask)
    return K.mean(agreement)


In [None]:
# --- 4. Model Instantiation and Training Stages ---

model = unet_model(backbone='efficientnetb0')
model.summary()


def _is_backbone_layer(layer_name):
    """Return True if `layer_name` follows the pretrained encoder's naming.

    Keras ResNet50 layers are named `conv<digit>_...` (e.g. `conv1_conv`,
    `conv2_block1_1_conv`); Keras EfficientNet layers are named `stem_...`,
    `block<N><letter>_...` or `top_...`. Layers created by the decoder and
    attention blocks get auto-generated names such as `conv2d_3` or `dense_1`,
    which must NOT match: a plain `startswith('conv')` test would freeze the
    decoder convolutions as well.
    """
    if layer_name.startswith(('stem_', 'block', 'top_')):
        return True
    if layer_name.startswith('conv') and '_' in layer_name:
        # 'conv2_block1_1_conv' -> '2' (backbone); 'conv2d_1' -> '2d' (decoder)
        return layer_name[4:layer_name.index('_')].isdigit()
    return False


# Freeze the base model (encoder) layers so that Stage 1 trains only the
# decoder and the attention blocks.
if isinstance(model.layers[1], tf.keras.Model):
    print(f"Freezing base model (encoder): {model.layers[1].name}")
    model.layers[1].trainable = False
else:
    # The backbone was built on the model's own input tensor, so its layers
    # are flattened into `model.layers` and have to be selected by name.
    print("Manually freezing backbone layers (assuming ResNet/EfficientNet naming conventions)...")
    frozen_layer_count = 0
    for layer in model.layers:
        if _is_backbone_layer(layer.name):
            layer.trainable = False
            frozen_layer_count += 1
    print(f"Froze {frozen_layer_count} of {len(model.layers)} layers.")

model.compile(optimizer=keras.optimizers.Adam(learning_rate=LEARNING_RATE_DECODER),
              loss=bce_dice_loss, metrics=[jaccard_index, dice_coefficient, precision_metric,
              recall_metric, accuracy_metric])

In [None]:
def _center_zoom(array, zoom_scale, interpolation):
    """Upscale `array` by `zoom_scale` and crop the centre back to its original size."""
    h, w = array.shape[:2]
    new_h, new_w = int(h * zoom_scale), int(w * zoom_scale)
    zoomed = cv2.resize(array, (new_w, new_h), interpolation=interpolation)
    if zoomed.ndim == 2:  # cv2 may drop a trailing single-channel axis
        zoomed = np.expand_dims(zoomed, axis=-1)
    top = (new_h - h) // 2
    left = (new_w - w) // 2
    return zoomed[top:top + h, left:left + w, :]


def _geometric_variants(array, zoom_scale, interpolation):
    """Return [original, centre-zoomed, horizontally flipped, vertically flipped]."""
    return [
        array,
        _center_zoom(array, zoom_scale, interpolation),
        np.fliplr(array).astype(np.float32),
        np.flipud(array).astype(np.float32),
    ]


def apply_paper_augmentation(image, mask):
    """Expand one image/mask pair into 32 augmented pairs.

    The image is converted to grayscale, median-filtered (5x5) and converted
    back to three channels. The mask goes through four morphological
    operations (dilate, erode, open, close) with a 3x3 and a 5x5 kernel,
    giving 8 mask variants. Each (denoised image, mask variant) pair is then
    emitted in four geometric versions: original, 1.2x centre zoom,
    horizontal flip and vertical flip -- 8 * 4 = 32 pairs in total, which is
    the per-image count the training generator and `steps_per_epoch`
    computations in this notebook assume.

    Args:
        image: float RGB image with values in [0, 1], shape (H, W, 3).
        mask: float binary mask with values in {0, 1}; a trailing channel
            axis is optional.

    Returns:
        `(augmented_images, augmented_masks, img_denoised_rgb)`: two parallel
        lists of float32 arrays (masks have a trailing channel axis), plus the
        denoised image on its own.
    """
    zoom_scale = 1.2

    # 1. Noise removal on the image (grayscale median filter).
    img_uint8 = (image * 255).astype(np.uint8)
    img_gray = cv2.cvtColor(img_uint8, cv2.COLOR_RGB2GRAY)
    img_denoised_gray = cv2.medianBlur(img_gray, 5)
    img_denoised_rgb = cv2.cvtColor(img_denoised_gray, cv2.COLOR_GRAY2RGB).astype(np.float32) / 255.0

    # Every mask variant is paired with the same denoised image, so its four
    # geometric versions are computed once instead of once per mask variant.
    image_variants = _geometric_variants(img_denoised_rgb, zoom_scale, cv2.INTER_LINEAR)

    # 2. Morphological operations on the mask (4 operations x 2 kernel sizes).
    mask_uint8 = (mask * 255).astype(np.uint8)
    morph_ops = (cv2.MORPH_DILATE, cv2.MORPH_ERODE, cv2.MORPH_OPEN, cv2.MORPH_CLOSE)
    kernels = (np.ones((3, 3), np.uint8), np.ones((5, 5), np.uint8))

    augmented_images = []
    augmented_masks = []
    for op_code in morph_ops:
        for kernel in kernels:
            morphed = cv2.morphologyEx(mask_uint8, op_code, kernel)
            morphed = morphed.astype(np.float32) / 255.0
            if morphed.ndim == 2:
                morphed = np.expand_dims(morphed, axis=-1)

            # 3. Geometric versions of this mask variant, in the same order
            #    as `image_variants`. Each version is appended exactly once
            #    (the flips used to be appended twice, yielding 48 pairs).
            mask_variants = _geometric_variants(morphed, zoom_scale, cv2.INTER_NEAREST)
            # Re-binarise the zoomed mask after resizing.
            mask_variants[1] = (mask_variants[1] > 0.5).astype(np.float32)

            augmented_images.extend(image_variants)
            augmented_masks.extend(mask_variants)

    return augmented_images, augmented_masks, img_denoised_rgb # Return denoised image too


def train_generator(images, masks, batch_size):
    """Endlessly yield batches of augmented (image, mask) arrays.

    Each pass ("epoch") visits the source samples in a fresh random order,
    expands every sample with `apply_paper_augmentation` and groups the
    augmented pairs into batches of `batch_size`; a final smaller batch is
    yielded if the pass does not divide evenly. The number of augmented
    samples actually produced is printed at the end of every pass.

    Args:
        images: Indexable collection of source images.
        masks: Indexable collection of masks, parallel to `images`.
        batch_size: Number of augmented pairs per yielded batch.

    Yields:
        `(batch_images, batch_masks)` as NumPy arrays.
    """
    num_samples = len(images)
    if num_samples == 0:
        # Nothing to augment: finish immediately instead of spinning forever
        # in the endless loop below without ever yielding.
        return

    while True:
        order = np.random.permutation(num_samples)

        yielded_samples = 0
        batch_images = []
        batch_masks = []

        for sample_idx in order:
            aug_imgs, aug_msks, _ = apply_paper_augmentation(images[sample_idx], masks[sample_idx])

            for aug_img, aug_msk in zip(aug_imgs, aug_msks):
                batch_images.append(aug_img)
                batch_masks.append(aug_msk)

                if len(batch_images) == batch_size:
                    yield np.array(batch_images), np.array(batch_masks)
                    yielded_samples += batch_size
                    batch_images = []
                    batch_masks = []

        # Yield any remaining samples in the last partial batch
        if batch_images:
            yield np.array(batch_images), np.array(batch_masks)
            yielded_samples += len(batch_images)

        # Report the measured count rather than an assumed per-image multiplier.
        print(f"\nTotal augmented images generated per epoch: {yielded_samples}")

# Augmenting generator for training; plain cached tf.data pipeline for validation.
train_gen = train_generator(X_train, y_train, BATCH_SIZE)
val_gen = (
    tf.data.Dataset.from_tensor_slices((X_val, y_val))
    .batch(BATCH_SIZE)
    .cache()
    .prefetch(buffer_size=tf.data.AUTOTUNE)
)

# Batches per epoch: ceil(len(X_train) * 32 / BATCH_SIZE), where 32 is the
# number of augmented samples assumed per source image. Evaluates to 0 when
# there is no training data.
steps_per_epoch_stage1 = (len(X_train) * 32 + BATCH_SIZE - 1) // BATCH_SIZE


In [None]:
# --- Stage 1: Train only the decoder (freeze encoder) ---
print("\n--- Stage 1: Training Decoder (Encoder Frozen) ---")

if len(X_train) == 0:
    print("Skipping Stage 1 training as no training data is available.")
    history_decoder = None # No history when training is skipped
else:
    # Stop when validation loss stalls and halve the learning rate on plateaus.
    stage1_callbacks = [
        keras.callbacks.EarlyStopping(patience=7, restore_best_weights=True, monitor='val_loss'),
        keras.callbacks.ReduceLROnPlateau(factor=0.5, patience=2, min_lr=1e-7, verbose=1, monitor='val_loss'),
    ]
    history_decoder = model.fit(
        train_gen,
        steps_per_epoch=steps_per_epoch_stage1, # Number of batches per epoch
        epochs=EPOCHS_DECODER_TRAINING,
        validation_data=val_gen,
        callbacks=stage1_callbacks,
    )


In [None]:
def generate_pseudo_labels(model_to_predict, unlabeled_images, confidence_threshold, min_mask_pixels,
                           min_confident_fraction=0.9, batch_size=None):
    """Predict masks for unlabeled images and keep only the confident ones.

    A pixel counts as confident when its predicted probability is at least
    `confidence_threshold` (foreground) or at most `1 - confidence_threshold`
    (background). An image is accepted when more than `min_confident_fraction`
    of its pixels are confident AND its thresholded mask (probability > 0.5)
    contains at least `min_mask_pixels` foreground pixels.

    Args:
        model_to_predict: Object exposing a Keras-style
            `predict(x, batch_size=..., verbose=...)` method that returns one
            probability map per image.
        unlabeled_images: Array of images to pseudo-label.
        confidence_threshold: Probability at or above which a pixel is
            confidently foreground.
        min_mask_pixels: Minimum number of foreground pixels in an accepted mask.
        min_confident_fraction: Fraction of confident pixels an image must
            exceed to be accepted.
        batch_size: Prediction batch size; defaults to `BATCH_SIZE * 2`.

    Returns:
        `(images, masks)` NumPy arrays of the accepted samples; masks carry a
        trailing channel axis. Both arrays are empty when nothing is accepted
        or when there is no input.
    """
    if len(unlabeled_images) == 0:
        print("No unlabeled images to generate pseudo-labels for.")
        return np.array([]), np.array([])

    if batch_size is None:
        batch_size = BATCH_SIZE * 2

    print(f"\nGenerating predictions for {len(unlabeled_images)} unlabeled images...")
    # Predict raw probability maps for all unlabeled images
    raw_predictions = model_to_predict.predict(unlabeled_images, batch_size=batch_size, verbose=1)

    print(f"Filtering pseudo-labels: confident pixel fraction > {min_confident_fraction} "
          f"(confidence threshold >= {confidence_threshold}) and mask pixels >= {min_mask_pixels}...")

    pseudo_labeled_images = []
    pseudo_labeled_masks = []
    for image, prediction in zip(unlabeled_images, raw_predictions):
        prob_map = np.asarray(prediction)
        # Drop only the trailing channel axis; a blanket squeeze() would also
        # collapse spatial axes of size 1.
        if prob_map.ndim == 3 and prob_map.shape[-1] == 1:
            prob_map = prob_map[..., 0]

        pseudo_mask = (prob_map > 0.5).astype(np.float32)
        confident = (prob_map >= confidence_threshold) | (prob_map <= (1.0 - confidence_threshold))
        confident_fraction = np.count_nonzero(confident) / prob_map.size

        if confident_fraction > min_confident_fraction and np.sum(pseudo_mask) >= min_mask_pixels:
            pseudo_labeled_images.append(image)
            pseudo_labeled_masks.append(np.expand_dims(pseudo_mask, axis=-1)) # Add channel dim back

    print(f"Generated {len(pseudo_labeled_images)} confident pseudo-labeled samples out of {len(unlabeled_images)} unlabeled images.")
    return np.array(pseudo_labeled_images), np.array(pseudo_labeled_masks)


In [None]:
# --- Stage 2: Fine-tune the entire network ---
print("\n--- Stage 2: Fine-tuning Full Network (All Layers Unfrozen) ---")

# Make every layer trainable again, then re-compile so the change takes
# effect, using the smaller fine-tuning learning rate.
for layer in model.layers:
    layer.trainable = True

stage2_metrics = [jaccard_index, dice_coefficient, precision_metric,
                  recall_metric, accuracy_metric]
model.compile(optimizer=keras.optimizers.Adam(learning_rate=LEARNING_RATE_FINE_TUNE),
              loss=bce_dice_loss,
              metrics=stage2_metrics)

# Batches per epoch: ceil(len(X_train) * 32 / BATCH_SIZE); 0 without training data.
steps_per_epoch_stage2 = (len(X_train) * 32 + BATCH_SIZE - 1) // BATCH_SIZE

# Train the model for Stage 2 (full network fine-tuning).
if len(X_train) == 0:
    print("Skipping Stage 2 training as no training data is available.")
    history_full = None # No history when training is skipped
else:
    stage2_callbacks = [
        # Early stopping with more patience than Stage 1.
        keras.callbacks.EarlyStopping(patience=15, restore_best_weights=True, monitor='val_loss'),
        # Reduce learning rate when validation loss plateaus.
        keras.callbacks.ReduceLROnPlateau(factor=0.5, patience=3, min_lr=1e-8, verbose=1, monitor='val_loss'),
    ]
    history_full = model.fit(
        train_gen, # Same generator over X_train as in Stage 1
        steps_per_epoch=steps_per_epoch_stage2,
        epochs=EPOCHS_FULL_FINE_TUNE,
        validation_data=val_gen,
        callbacks=stage2_callbacks,
    )

In [None]:
# --- NEW: Stage 3: Pseudo-labeling and Re-training on Combined Data ---
print("\n--- Stage 3: Pseudo-labeling and Re-training on Combined Data ---")

if len(X_unlabeled) > 0 and len(X_train) > 0:
    pseudo_images, pseudo_masks = generate_pseudo_labels(
        model, X_unlabeled, PSEUDO_LABEL_CONFIDENCE_THRESHOLD, PSEUDO_LABEL_MIN_MASK_PIXELS
    )

    if len(pseudo_images) > 0:
        # Only the training split may be combined with the pseudo-labels.
        # X_test/y_test are used for the final evaluation, so training on
        # them would leak test data and invalidate the reported test metrics;
        # X_val/y_val drive early stopping and must stay held out as well.
        X_labeled_for_pseudo_training = X_train
        y_labeled_for_pseudo_training = y_train

        X_combined_train_pseudo = np.concatenate((X_labeled_for_pseudo_training, pseudo_images), axis=0)
        y_combined_train_pseudo = np.concatenate((y_labeled_for_pseudo_training, pseudo_masks), axis=0)
        print(f"Combined training data size for Stage 3: {len(X_combined_train_pseudo)} (Original Labeled: {len(X_labeled_for_pseudo_training)}, Pseudo: {len(pseudo_images)})")

        # Re-compile the model for Stage 3 with a very small learning rate
        model.compile(optimizer=keras.optimizers.Adam(learning_rate=LEARNING_RATE_PSEUDO_LABELING),
                      loss=bce_dice_loss,
                      metrics=[jaccard_index, dice_coefficient, precision_metric,
                      recall_metric, accuracy_metric])

        # Create a new generator for the combined dataset (with custom aug again)
        combined_train_gen_pseudo = train_generator(
            X_combined_train_pseudo, y_combined_train_pseudo, BATCH_SIZE
        )

        # Batches per epoch: ceil(samples * 32 / BATCH_SIZE), 32 being the
        # number of augmented samples assumed per source image.
        steps_per_epoch_stage3 = (len(X_combined_train_pseudo) * 32) // BATCH_SIZE
        if (len(X_combined_train_pseudo) * 32) % BATCH_SIZE != 0:
            steps_per_epoch_stage3 += 1

        print(f"Starting Stage 3 training on combined data for {EPOCHS_PSEUDO_LABELING_TRAINING} epochs...")
        history_pseudo_labeling = model.fit(
            combined_train_gen_pseudo,
            steps_per_epoch=steps_per_epoch_stage3,
            epochs=EPOCHS_PSEUDO_LABELING_TRAINING,
            validation_data=val_gen, # Still validate on the original held-out validation set (X_val, y_val)
            callbacks=[
                keras.callbacks.EarlyStopping(patience=20, restore_best_weights=True, monitor='val_loss'), # More patience
                keras.callbacks.ReduceLROnPlateau(factor=0.5, patience=5, min_lr=1e-9, verbose=1, monitor='val_loss')
            ]
        )
        print("\nStage 3: Pseudo-labeling training complete!")
    else:
        print("No confident pseudo-labels generated. Skipping Stage 3 training.")
        history_pseudo_labeling = None
else:
    print("Not enough unlabeled or labeled data to perform pseudo-labeling. Skipping Stage 3 training.")
    history_pseudo_labeling = None


In [None]:
# Evaluate on the held-out test split; model.evaluate returns the loss
# followed by the metrics in the order they were passed to model.compile.
test_dataset = (
    tf.data.Dataset.from_tensor_slices((X_test, y_test))
    .batch(BATCH_SIZE)
    .cache()
    .prefetch(buffer_size=tf.data.AUTOTUNE)
)
evaluation_results = model.evaluate(test_dataset)
test_loss, test_jaccard, test_dice, test_precision, test_recall, test_accuracy = evaluation_results

for metric_label, metric_value in (
    ("Test Loss", test_loss),
    ("Jaccard Index", test_jaccard),
    ("Dice Coefficient", test_dice),
    ("Precision", test_precision),
    ("Recall", test_recall),
    ("Accuracy", test_accuracy),
):
    print(f"{metric_label}: {metric_value:.4f}")


In [None]:
# --- 5. Evaluation and Visualization ---

def plot_history(history, title):
    """Plot training/validation loss and Dice coefficient side by side.

    Args:
        history: Object returned by `model.fit` (its `.history` dict must hold
            'loss', 'val_loss', 'dice_coefficient' and 'val_dice_coefficient'),
            or None, in which case a message is printed and nothing is drawn.
        title: Prefix used in both subplot titles.
    """
    if history is None:
        print(f"No history to plot for {title}.")
        return

    # (history key, legend suffix, subplot heading, y-axis label) per panel
    panels = (
        ('loss', 'Loss', 'Loss', 'Loss'),
        ('dice_coefficient', 'Dice Coeff', 'Dice Coefficient', 'Dice Coeff'),
    )

    plt.figure(figsize=(12, 5))
    curves = history.history
    for position, (key, legend_suffix, heading, y_label) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        plt.plot(curves[key], label=f'Train {legend_suffix}')
        plt.plot(curves[f'val_{key}'], label=f'Validation {legend_suffix}')
        plt.title(f'{title} - {heading}')
        plt.xlabel('Epoch')
        plt.ylabel(y_label)
        plt.legend()
        plt.grid(True)
    plt.show()

In [None]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import os

# IMPORTANT: 'my_trained_model' (the main model object), X_test and y_test
# should be defined before running this cell. The fallbacks below exist only
# so the visualization code can be exercised without them.

# Fall back to random data only when no usable test set exists.
if 'X_test' not in globals() or not isinstance(X_test, np.ndarray) or X_test.size == 0:
    print("WARNING: X_test/y_test not found. Creating DUMMY data for visualization testing.")
    X_test = np.random.rand(10, 128, 128, 3).astype(np.float32) # Adjust shape as needed
    y_test = np.random.randint(0, 2, size=(10, 128, 128, 1)).astype(np.float32) # Adjust shape as needed

if 'my_trained_model' not in globals():
    if isinstance(globals().get('model'), tf.keras.Model):
        # Reuse the model trained earlier in this notebook under the name
        # `model` instead of silently visualizing an untrained dummy.
        my_trained_model = model
    else:
        print("WARNING: 'my_trained_model' not found. Creating a DUMMY model for visualization testing.")
        # Match the dummy's input shape to the data it will be asked to predict on.
        input_shape = X_test.shape[1:]
        dummy_inputs = tf.keras.layers.Input(input_shape)
        dummy_conv = tf.keras.layers.Conv2D(1, (1, 1), activation='sigmoid')(dummy_inputs)
        my_trained_model = tf.keras.Model(inputs=dummy_inputs, outputs=dummy_conv)


def _load_comparison_models(model_paths, display_names):
    """Load every comparison model that can be read; return a list of (display_name, model)."""
    loaded = []
    print("\n--- Attempting to load additional models ---")
    for display_name, path in zip(display_names, model_paths):
        if not path or not os.path.exists(path):
            print(f"ERROR: File for {display_name} NOT FOUND at '{path}'. Skipping this model.")
            continue
        if path.endswith('.ipynb'):
            print(f"ERROR: {display_name} path '{path}' is a .ipynb file. Keras cannot load models from notebooks. Please provide the .h5 model file path. Skipping this model.")
            continue
        try:
            # compile=False: the model is only used for inference, so the saved
            # loss/metrics (which may be custom objects unknown to this notebook)
            # do not need to be deserialized.
            model_loaded = tf.keras.models.load_model(path, compile=False)
        except Exception as e:
            # Best effort: a corrupted or incompatible file drops one column, not the figure.
            print(f"ERROR loading {display_name} from '{path}': {e}. Skipping this model.")
            continue
        loaded.append((display_name, model_loaded))
        print(f"SUCCESS: Loaded {display_name} from '{path}'.")
    print("--- Finished loading additional models ---\n")
    return loaded


def visualize_predictions(main_model_object, model1_path, model2_path, model3_path, model4_path,
                          X_set, y_set, num_samples=5):
    """
    Visualizes original image, ground truth, main model's prediction, and
    predictions from four other specified models in a grid.

    Each row is one sample (the first ``num_samples`` entries of ``X_set``);
    predictions are binarized at 0.5 before display. Comparison models that
    cannot be loaded, or that fail to predict on the batch, are reported and
    left out of the grid.

    Args:
        main_model_object (tf.keras.Model): Your primary Keras model object (passed directly).
        model1_path (str): Path to the first comparison model (FCRBUnet).
        model2_path (str): Path to the second comparison model (ECAUnet).
        model3_path (str): Path to the third comparison model (Auto_ReUnet).
        model4_path (str): Path to the fourth comparison model (Dual_CNN).
        X_set (np.array): Input images for prediction.
        y_set (np.array): Ground truth masks.
        num_samples (int): Number of samples (rows) to display.

    Returns:
        None. The figure is shown with ``plt.show()``.
    """
    # Never index past the shorter of the two arrays, and bail out before any
    # model is loaded when there is nothing to draw.
    num_samples_to_display = min(num_samples, len(X_set), len(y_set))
    if num_samples_to_display <= 0:
        print("No data to visualize predictions.")
        return

    comparison_models = _load_comparison_models(
        [model1_path, model2_path, model3_path, model4_path],
        ["FCRBUnet", "ECAUnet", "Auto_ReUnet", "Dual_CNN"],
    )

    # --- Make predictions ---
    batch = X_set[:num_samples_to_display]
    prediction_columns = [("My Model", main_model_object.predict(batch))]
    for display_name, model_obj in comparison_models:
        try:
            prediction_columns.append((display_name, model_obj.predict(batch)))
        except Exception as e:
            # e.g. the comparison model was trained on a different input size.
            print(f"ERROR predicting with {display_name}: {e}. Skipping this model.")

    # --- Setup Plotting ---
    total_columns = 2 + len(prediction_columns)  # original + ground truth + one per model
    print(f"Generating plot with {num_samples_to_display} rows and {total_columns} columns.")
    fig, axes = plt.subplots(
        num_samples_to_display, total_columns,
        figsize=(total_columns * 3, num_samples_to_display * 3),
        squeeze=False,  # keep a 2-D axes array even for a single row
    )

    for row in range(num_samples_to_display):
        panels = [
            ("Original", X_set[row], None),
            ("Ground Truth", np.squeeze(y_set[row]), 'gray'),
        ]
        panels.extend(
            (display_name, (np.squeeze(preds[row]) > 0.5).astype(np.float32), 'gray')
            for display_name, preds in prediction_columns
        )
        for ax, (column_title, image, cmap) in zip(axes[row], panels):
            ax.imshow(image, cmap=cmap)
            ax.axis('off')
            if row == 0:  # column headers only on the first row
                ax.set_title(column_title, fontsize=10)

    fig.tight_layout()
    plt.show()


# --- Compare the main model with the four reference models on the test set ---
#
# Requires `model`, `X_test` and `y_test` to be defined by earlier cells.

# Saved weights of the comparison models (.h5 files on Drive).
# NOTE(review): the third file is spelled 'Aut_ReUnet' while the plot column is
# labelled 'Auto_ReUnet' — confirm the file name on Drive.
model1_path = '/content/drive/MyDrive/Colab/PHD/Cerebellar/Methods/FCRBUnet_model.h5'
model2_path = '/content/drive/MyDrive/Colab/PHD/Cerebellar/Methods/ECAUnet_model.h5'
model3_path = '/content/drive/MyDrive/Colab/PHD/Cerebellar/Methods/Aut_ReUnet_model.h5'
model4_path = '/content/drive/MyDrive/Colab/PHD/Cerebellar/Methods/Dual_CNN_model.h5'

print("--- Starting Visualization Script ---")
visualize_predictions(model, model1_path, model2_path, model3_path, model4_path,
                      X_test, y_test, num_samples=5)


In [None]:
# --- NEW: Test-Time Augmentation (TTA) Prediction Function ---
def predict_with_tta(model, image, tta_transforms):
    """
    Predict a single image with test-time augmentation.

    The model is run on the untouched image and on every transformed copy;
    each augmented prediction is mapped back with the paired inverse transform
    and the element-wise mean of all predictions is returned.

    Args:
        model: Object exposing ``predict(batch, verbose=0)``.
        image: One un-batched input array.
        tta_transforms: Iterable of ``(transform, inverse_transform)`` callables.

    Returns:
        The averaged prediction for ``image`` (batch axis removed).
    """
    def _predict_single(sample):
        # predict() works on batches: add a leading axis, then drop it again.
        return model.predict(np.expand_dims(sample, axis=0), verbose=0)[0]

    collected = [_predict_single(image)]
    collected.extend(
        undo(_predict_single(apply(image)))
        for apply, undo in tta_transforms
    )
    return np.mean(collected, axis=0)

# Geometric transforms used for test-time augmentation. Every helper returns a
# float32 copy of its input.
def flip_lr(img_or_mask):
    """Mirror the array left-to-right (reverse axis 1)."""
    return np.flip(img_or_mask, axis=1).astype(np.float32)


def flip_ud(img_or_mask):
    """Mirror the array top-to-bottom (reverse axis 0)."""
    return np.flip(img_or_mask, axis=0).astype(np.float32)


def rotate_90(img_or_mask):
    """Rotate by 90 degrees counter-clockwise in the plane of the first two axes."""
    return np.rot90(img_or_mask).astype(np.float32)


def rotate_270(img_or_mask):
    """Rotate by 90 degrees clockwise (three counter-clockwise quarter turns); undoes rotate_90."""
    return np.rot90(img_or_mask, k=3).astype(np.float32)


# (forward transform, inverse transform) pairs: the inverse maps a prediction made
# on the transformed image back onto the original pixel grid.
TTA_TRANSFORMS = [
    (flip_lr, flip_lr),        # horizontal mirror is its own inverse
    (flip_ud, flip_ud),        # vertical mirror is its own inverse
    (rotate_90, rotate_270),   # 90 degrees counter-clockwise, undone clockwise
    (rotate_270, rotate_90),   # 90 degrees clockwise, undone counter-clockwise
]

In [None]:
# Training curves for the two stages that always run.
plot_history(history_decoder, "Decoder Training")
plot_history(history_full, "Full Fine-tuning")

# The pseudo-labeling history is plotted only if that variable exists and is not None.
if locals().get('history_pseudo_labeling') is not None:
    plot_history(history_pseudo_labeling, "Pseudo-labeling Fine-tuning")


print("\nVisualizing predictions on validation set...")
visualize_predictions(model, model1_path, model2_path, model3_path, model4_path,
                      X_val, y_val, num_samples=5)


print("\nVisualizing predictions on test set (without TTA)...")
# Shows the first 5 test samples (visualize_predictions takes a leading slice,
# not a random draw), predicted without TTA.
visualize_predictions(model, model1_path, model2_path, model3_path, model4_path,
                      X_test, y_test, num_samples=5)

In [None]:

# --- Final Evaluation on Test Set with Test-Time Augmentation (TTA) ---
print("\n--- Final Evaluation on Test Set with Test-Time Augmentation (TTA) ---")

if len(X_test) > 0:
    tta_predictions_list = []
    print(f"Applying TTA to {len(X_test)} test images...")
    for i, img in enumerate(X_test):
        tta_pred = predict_with_tta(model, img, TTA_TRANSFORMS)
        tta_predictions_list.append(tta_pred)
        # Refresh the progress line every 10 images and once more at the end.
        if (i + 1) % 10 == 0 or (i + 1) == len(X_test):
            clear_output(wait=True)
            print(f"Processed {i + 1}/{len(X_test)} test images with TTA.")

    tta_predictions = np.array(tta_predictions_list)

    # Calculate metrics with TTA predictions: Dice is computed per image on the
    # binarized (> 0.5) prediction and then averaged over the test set.
    tta_dice_scores = []
    for i in range(len(X_test)):
        bin_tta_pred = (tta_predictions[i] > 0.5).astype(np.float32)
        tta_dice_scores.append(dice_coeff(y_test[i], bin_tta_pred).numpy())

    overall_tta_dice = np.mean(tta_dice_scores)

    print(f"\nOverall Test Dice Coefficient (with TTA): {overall_tta_dice:.4f}")

    # Let's also evaluate without TTA for comparison.
    # NOTE(review): model.evaluate reports the compiled metrics, which may be computed
    # differently from the per-image Dice above — confirm the two numbers are comparable.
    test_dataset = (
        tf.data.Dataset.from_tensor_slices((X_test, y_test))
        .batch(BATCH_SIZE)
        .cache()
        .prefetch(buffer_size=tf.data.AUTOTUNE)
    )
    non_tta_test_loss, non_tta_test_dice, non_tta_test_accuracy = model.evaluate(test_dataset)
    print(f"Overall Test Dice Coefficient (without TTA): {non_tta_test_dice:.4f}")
    print(f"Overall Test Loss (without TTA): {non_tta_test_loss:.4f}")
    print(f"Overall Test Accuracy (without TTA): {non_tta_test_accuracy:.4f}")


    # --- Visualizing TTA predictions on Test Set ---
    print("\nVisualizing predictions on test set (with TTA)...")
    # Binarize up to the first 5 TTA predictions; slicing copes with smaller test sets.
    tta_visual_preds = (tta_predictions[:5] > 0.5).astype(np.float32)
    num_visual = len(tta_visual_preds)

    # The original condition was empty (`if :`), which is a SyntaxError.
    if num_visual > 0:
        plt.figure(figsize=(15, 3 * num_visual)) # Height follows the number of rows drawn
        for i in range(num_visual):
            plt.subplot(num_visual, 3, i*3 + 1)
            plt.imshow(X_test[i])
            plt.title(f'Test Original Image {i+1}')
            plt.axis('off')

            plt.subplot(num_visual, 3, i*3 + 2)
            plt.imshow(y_test[i].squeeze(), cmap='gray')
            plt.title(f'Test True Mask {i+1}')
            plt.axis('off')

            plt.subplot(num_visual, 3, i*3 + 3)
            plt.imshow(tta_visual_preds[i].squeeze(), cmap='gray')
            plt.title(f'Test Predicted Mask (TTA) {i+1}')
            plt.axis('off')
        plt.tight_layout()
        plt.show()

else:
    print("No test data available for evaluation.")

In [None]:
import matplotlib.pyplot as plt


def visualize_augmentation_examples(original_image, original_mask):
    """
    Show the original image, its denoised version and every augmented image in a grid.

    The images come from ``apply_paper_augmentation(original_image, original_mask)``;
    the augmented masks it returns are ignored. The grid has 5 columns and as many
    rows as needed, and unused cells are hidden.

    Args:
        original_image: Image to augment; nothing is drawn when it is None.
        original_mask: Matching mask; nothing is drawn when it is None.

    Returns:
        None. The figure is shown with ``plt.show()``.
    """
    if original_image is None or original_mask is None:
        print("Original image or mask is None. Cannot visualize augmentations.")
        return

    # Get augmented images and masks. We will ONLY use the images for plotting.
    augmented_images_list, _, denoised_original_image = apply_paper_augmentation(original_image, original_mask)
    # list(...) so concatenation also works if the helper returns a tuple or array.
    augmented_images = list(augmented_images_list)

    # --- Collect all images and their titles in display order ---
    all_images_to_display = [original_image, denoised_original_image] + augmented_images
    # Titles are generated so they always match the number of augmentations returned
    # (a fixed list would raise IndexError as soon as more than 8 come back).
    image_titles = ['Original Image', 'Denoised Image'] + [
        f'Augmented (Type {k})' for k in range(1, len(augmented_images) + 1)
    ]
    num_total_images = len(all_images_to_display)

    # --- Define Grid Size ---
    num_cols = 5
    # Integer ceiling division; avoids depending on a `math` import.
    num_rows = (num_total_images + num_cols - 1) // num_cols

    # --- Create Figure and Subplots ---
    # 3.5 inches per cell; squeeze=False keeps a 2-D axes array even for a single row.
    fig, axes = plt.subplots(num_rows, num_cols,
                             figsize=(num_cols * 3.5, num_rows * 3.5), squeeze=False)
    axes = axes.ravel()

    fig.suptitle('Data Augmentation Examples (Images Only for Article)', fontsize=16, y=1.02)

    # --- Plot Each Image ---
    for ax, image, image_title in zip(axes, all_images_to_display, image_titles):
        ax.imshow(image)
        ax.set_title(image_title, fontsize=9)
        ax.axis('off')

    # --- Hide any remaining empty subplots ---
    for ax in axes[num_total_images:]:
        ax.set_visible(False)

    # 'rect' leaves room at the top for the suptitle.
    plt.tight_layout(rect=[0, 0.03, 1, 0.98])
    plt.show()


# --- Example: show the augmentations of one randomly chosen training sample ---
#
# X_train and y_train must already be loaded by an earlier cell.

if 'X_train' not in locals() or len(X_train) == 0:
    # Nothing to draw: tell the user what is expected instead of failing.
    print("\nError: X_train not found or is empty. Please ensure X_train and y_train are loaded.")
    print("Example: X_train = np.random.rand(10, 128, 128, 3).astype(np.float32)")
    print("         y_train = np.random.randint(0, 2, size=(10, 128, 128, 1)).astype(np.float32)")
else:
    # The draw is unseeded here, so a different sample may be shown on every run.
    sample_index = np.random.randint(0, len(X_train))
    sample_image = X_train[sample_index]
    sample_mask = y_train[sample_index]

    print(f"\n--- Visualizing Data Augmentation (Images Only) for Sample {sample_index} ---")
    visualize_augmentation_examples(sample_image, sample_mask)