In [None]:
# Mount Google Drive at /content/drive; the dataset, checkpoint and output
# paths used by the later cells all live under /content/drive/MyDrive.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# CELL 1: Import necessary libraries
import os
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers, models, callbacks, optimizers
from tensorflow.keras.utils import to_categorical
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, f1_score, precision_score, recall_score
import random
import glob
from tensorflow.keras import backend as K

# Seed NumPy, TensorFlow and the stdlib RNG with one shared value for reproducibility
SEED = 42
np.random.seed(SEED)
tf.random.set_seed(SEED)
random.seed(SEED)

# Report the TensorFlow build and whether a GPU is visible to it
gpu_devices = tf.config.list_physical_devices('GPU')
print("TensorFlow version:", tf.__version__)
print("GPU Available:", len(gpu_devices) > 0)
print("GPU Devices:", gpu_devices)

In [None]:
# CELL 2: Define helper functions for data loading
def load_image(image_path):
    """Load one preprocessed image as a float32 array.

    Args:
        image_path: Location of the image file. Accepts a ``str``, ``bytes``,
            path-like object, or an object exposing ``.numpy()`` (the eager
            string tensor that ``tf.py_function`` passes in).

    Returns:
        The image data as an ``np.float32`` array. ``.npy`` files are read
        with ``np.load``; any other extension is read with ``plt.imread``,
        and a fourth (alpha) channel is dropped if present.
    """
    # tf.py_function hands the path over as an eager tensor wrapping bytes
    if hasattr(image_path, 'numpy'):
        image_path = image_path.numpy()
    image_path = os.fspath(image_path)
    if isinstance(image_path, bytes):
        image_path = image_path.decode('utf-8')

    if image_path.lower().endswith('.npy'):
        image = np.load(image_path)
    else:
        # create_dataset falls back to PNG images when a split has no .npy
        # files; np.load cannot read those, so decode them as images instead.
        image = plt.imread(image_path)
        if image.ndim == 3 and image.shape[2] == 4:
            image = image[:, :, :3]

    # Explicit cast: the tf.data pipeline declares Tout=tf.float32
    return image.astype(np.float32)

def load_mask(mask_path):
    """Read a mask image and return it as a binary float32 array.

    Args:
        mask_path: Path to the mask file, either a plain string or a
            ``tf.Tensor`` holding the UTF-8 encoded path.

    Returns:
        ``np.float32`` array that is 1.0 wherever the (first-channel) pixel
        value is greater than zero and 0.0 elsewhere.
    """
    # Paths routed through tf.py_function arrive as string tensors
    path = mask_path.numpy().decode('utf-8') if isinstance(mask_path, tf.Tensor) else mask_path
    raw = plt.imread(path)

    # For multi-channel images only the first channel is used
    has_channels = raw.ndim > 2 and raw.shape[2] > 1
    plane = raw[:, :, 0] if has_channels else raw

    # Any non-zero pixel counts as foreground
    return np.greater(plane, 0).astype(np.float32)

def create_dataset(base_path, split, batch_size=8, shuffle=True, image_shape=(256, 256, 3)):
    """Create a batched ``tf.data.Dataset`` of (image, mask) pairs for one split.

    Files are discovered under ``<base_path>/<split>/images`` (``*.npy``,
    falling back to ``*.png``) and ``<base_path>/<split>/masks`` (``*.png``).
    Images and masks are paired by position after sorting both path lists.

    Args:
        base_path: Root directory of the preprocessed data.
        split: Sub-directory name of the split (e.g. 'train', 'val', 'test').
        batch_size: Number of samples per batch.
        shuffle: Whether to shuffle the (path) dataset with a fixed seed.
        image_shape: Expected (height, width, channels) of every image; masks
            are expected to be (height, width).

    Returns:
        Dataset yielding ``(image, mask)`` batches where the mask carries a
        trailing channel dimension.

    Raises:
        ValueError: If no images or masks are found, or if their counts differ.
    """
    image_dir = os.path.join(base_path, split, 'images')
    img_paths = sorted(glob.glob(os.path.join(image_dir, '*.npy')))
    # If no .npy files, try .png files
    if len(img_paths) == 0:
        img_paths = sorted(glob.glob(os.path.join(image_dir, '*.png')))
        if img_paths:
            print(f"Using PNG files for images in {split} split")

    mask_paths = sorted(glob.glob(os.path.join(base_path, split, 'masks', '*.png')))

    if len(img_paths) == 0 or len(mask_paths) == 0:
        raise ValueError(f"No images or masks found in {base_path}/{split}")

    print(f"Found {len(img_paths)} images and {len(mask_paths)} masks for {split}")

    # Pairing is positional, so a count mismatch would make zip() silently
    # drop files and attach masks to the wrong images.
    if len(img_paths) != len(mask_paths):
        raise ValueError(
            f"Image/mask count mismatch in {base_path}/{split}: "
            f"{len(img_paths)} images vs {len(mask_paths)} masks"
        )

    height, width = image_shape[0], image_shape[1]

    def _load_pair(img_path, mask_path):
        """Load one image/mask pair and attach static shapes."""
        image = tf.py_function(func=load_image, inp=[img_path], Tout=tf.float32)
        mask = tf.py_function(func=load_mask, inp=[mask_path], Tout=tf.float32)
        # py_function outputs have unknown shape; restore it for the model
        image = tf.ensure_shape(image, list(image_shape))
        mask = tf.ensure_shape(mask, [height, width])
        # Add channel dimension to mask
        return image, tf.expand_dims(mask, axis=-1)

    # Dataset of (image path, mask path) pairs
    dataset = tf.data.Dataset.zip((
        tf.data.Dataset.from_tensor_slices(img_paths),
        tf.data.Dataset.from_tensor_slices(mask_paths),
    ))

    # Shuffle paths (cheap) before any file is loaded
    if shuffle:
        dataset = dataset.shuffle(buffer_size=len(img_paths), seed=42)

    dataset = dataset.map(_load_pair, num_parallel_calls=tf.data.AUTOTUNE)

    # Batch and prefetch
    return dataset.batch(batch_size).prefetch(tf.data.AUTOTUNE)




def plot_accuracy(history):
    """Plot training and validation accuracy curves.

    Args:
        history: Object whose ``history`` attribute maps metric names to
            per-epoch value lists (as returned by ``model.fit``). The
            'binary_accuracy' / 'val_binary_accuracy' entries are required;
            'accuracy' / 'val_accuracy' are plotted only if both exist.
    """
    logs = history.history
    plt.figure(figsize=(10, 6))

    curves = [
        ('binary_accuracy', 'Training Binary Accuracy'),
        ('val_binary_accuracy', 'Validation Binary Accuracy'),
    ]
    # The generic accuracy pair is optional and drawn only when complete
    if {'accuracy', 'val_accuracy'} <= logs.keys():
        curves.append(('accuracy', 'Training Accuracy'))
        curves.append(('val_accuracy', 'Validation Accuracy'))

    for key, label in curves:
        plt.plot(logs[key], label=label)

    plt.title('Model Accuracy')
    plt.xlabel('Epochs')
    plt.ylabel('Accuracy')
    plt.legend(loc='upper left')
    plt.show()


def visualize_samples(dataset, num_samples=3):
    """Show the three image channels and the mask for samples of a dataset.

    One figure row is drawn per sample: the three image channels in
    grayscale followed by the mask. Samples are taken in iteration order
    from the batches of ``dataset`` until ``num_samples`` rows are filled.

    Args:
        dataset: Batched dataset yielding ``(images, masks)`` tensors.
        num_samples: Number of samples (figure rows) to display.
    """
    channel_titles = ("VH Channel", "VV Channel", "VH/VV Ratio")
    plt.figure(figsize=(15, 5 * num_samples))

    row = 0
    for images, masks in dataset:
        for j in range(images.shape[0]):
            if row >= num_samples:
                break

            image = images[j].numpy()
            mask = masks[j].numpy().squeeze()

            # Normalize image for visualization (0-1 range)
            image_viz = (image - image.min()) / (image.max() - image.min() + 1e-8)

            # Every sample gets its own row; previously all images of a batch
            # were drawn into the same subplot slots and overwrote each other.
            for col, title in enumerate(channel_titles):
                plt.subplot(num_samples, 4, row * 4 + col + 1)
                plt.imshow(image_viz[:, :, col], cmap='gray')
                plt.title(f"{title} - Sample {row+1}")
                plt.axis('off')

            # Display mask
            plt.subplot(num_samples, 4, row * 4 + 4)
            plt.imshow(mask, cmap='Blues')
            plt.title(f"Flood Mask - Sample {row+1}")
            plt.axis('off')

            row += 1

        if row >= num_samples:
            break

    plt.tight_layout()
    plt.show()

In [None]:
# CELL 3: Define custom metrics and loss functions
def dice_coefficient(y_true, y_pred, smooth=1e-6):
    """Soft Dice coefficient computed over all elements of both tensors.

    Args:
        y_true: Ground truth masks
        y_pred: Predicted masks
        smooth: Constant added to numerator and denominator to avoid
            division by zero

    Returns:
        (2 * sum(y_true * y_pred) + smooth) / (sum(y_true) + sum(y_pred) + smooth)
    """
    truth = K.flatten(y_true)
    pred = K.flatten(y_pred)
    overlap = K.sum(truth * pred)
    numerator = 2. * overlap + smooth
    denominator = K.sum(truth) + K.sum(pred) + smooth
    return numerator / denominator

def dice_loss(y_true, y_pred):
    """Return one minus the Dice coefficient of the two masks."""
    overlap_score = dice_coefficient(y_true, y_pred)
    return 1 - overlap_score

def bce_dice_loss(y_true, y_pred):
    """Sum of binary cross-entropy and Dice loss for the given masks."""
    cross_entropy_fn = tf.keras.losses.BinaryCrossentropy()
    cross_entropy = cross_entropy_fn(y_true, y_pred)
    return cross_entropy + dice_loss(y_true, y_pred)

def iou_score(y_true, y_pred, smooth=1e-6):
    """Soft IoU (Intersection over Union) over all elements of both tensors.

    Args:
        y_true: Ground truth masks
        y_pred: Predicted masks
        smooth: Constant added to numerator and denominator to avoid
            division by zero

    Returns:
        (intersection + smooth) / (union + smooth)
    """
    truth = K.flatten(y_true)
    pred = K.flatten(y_pred)
    overlap = K.sum(truth * pred)
    # Union = |A| + |B| - |A ∩ B|
    combined = K.sum(truth) + K.sum(pred) - overlap
    return (overlap + smooth) / (combined + smooth)

def f1_score_metric(y_true, y_pred, smooth=1e-6):
    """Soft F1 score built from smoothed precision and recall.

    Args:
        y_true: Ground truth masks
        y_pred: Predicted masks
        smooth: Constant used to avoid division by zero, both in the
            precision/recall ratios and in the final F1 ratio

    Returns:
        2 * precision * recall / (precision + recall + smooth)
    """
    truth = K.flatten(y_true)
    pred = K.flatten(y_pred)

    hits = K.sum(truth * pred)
    n_predicted = K.sum(pred)
    n_actual = K.sum(truth)

    # Smoothed ratios, identical to precision_metric / recall_metric
    precision = (hits + smooth) / (n_predicted + smooth)
    recall = (hits + smooth) / (n_actual + smooth)

    return 2 * (precision * recall) / (precision + recall + smooth)

def precision_metric(y_true, y_pred, smooth=1e-6):
    """Soft precision: overlap divided by the sum of the predictions.

    Args:
        y_true: Ground truth masks
        y_pred: Predicted masks
        smooth: Constant added to numerator and denominator to avoid
            division by zero

    Returns:
        (sum(y_true * y_pred) + smooth) / (sum(y_pred) + smooth)
    """
    truth = K.flatten(y_true)
    pred = K.flatten(y_pred)
    hits = K.sum(truth * pred)
    n_predicted = K.sum(pred)
    return (hits + smooth) / (n_predicted + smooth)

def recall_metric(y_true, y_pred, smooth=1e-6):
    """Soft recall: overlap divided by the sum of the ground truth.

    Args:
        y_true: Ground truth masks
        y_pred: Predicted masks
        smooth: Constant added to numerator and denominator to avoid
            division by zero

    Returns:
        (sum(y_true * y_pred) + smooth) / (sum(y_true) + smooth)
    """
    truth = K.flatten(y_true)
    pred = K.flatten(y_pred)
    hits = K.sum(truth * pred)
    n_actual = K.sum(truth)
    return (hits + smooth) / (n_actual + smooth)

In [None]:
# CELL 4: Define ResUNet model architecture
def conv_block(inputs, filters, kernel_size=3, strides=1, padding='same'):
    """Apply Conv2D, then BatchNormalization, then ReLU to ``inputs``."""
    convolution = layers.Conv2D(filters, kernel_size, strides=strides, padding=padding)
    features = convolution(inputs)
    features = layers.BatchNormalization()(features)
    return layers.ReLU()(features)

def residual_block(inputs, filters, kernel_size=3, strides=1):
    """Two stacked conv blocks plus a shortcut, merged by addition and ReLU.

    The shortcut is the unchanged input when stride is 1 and the input
    already has ``filters`` channels; otherwise it is a 1x1 convolution
    (with the same stride) followed by batch normalization.
    """
    main_path = conv_block(inputs, filters, kernel_size, strides)
    main_path = conv_block(main_path, filters, kernel_size, 1)

    identity_ok = strides <= 1 and inputs.shape[-1] == filters
    if identity_ok:
        shortcut = inputs
    else:
        # Project the input so it can be added to the main path
        shortcut = layers.Conv2D(filters, kernel_size=1, strides=strides, padding='same')(inputs)
        shortcut = layers.BatchNormalization()(shortcut)

    merged = layers.add([main_path, shortcut])
    return layers.ReLU()(merged)

def build_resunet_model(input_shape=(256, 256, 3), num_classes=1):
    """Assemble the ResUNet encoder/decoder network.

    Args:
        input_shape: Input image shape (height, width, channels)
        num_classes: Number of output channels; 1 yields a sigmoid output,
            more than 1 yields a softmax output

    Returns:
        The Keras model (not compiled; the caller compiles it)
    """
    encoder_filters = (64, 128, 256)

    inputs = layers.Input(input_shape)

    # Stem: 7x7 convolution at full resolution
    x = conv_block(inputs, 64, kernel_size=7, strides=1)

    # Encoder: residual block, remember its output, then halve the resolution
    skips = []
    for filters in encoder_filters:
        skip = residual_block(x, filters)
        skips.append(skip)
        x = layers.MaxPooling2D(2)(skip)

    # Bridge
    x = residual_block(x, 512)

    # Decoder: upsample, convolve, concatenate the matching skip, residual block
    for filters, skip in zip(reversed(encoder_filters), reversed(skips)):
        x = layers.UpSampling2D(2)(x)
        x = conv_block(x, filters)
        x = layers.Concatenate()([x, skip])
        x = residual_block(x, filters)

    # 1x1 output convolution; activation depends on the number of classes
    activation = 'softmax' if num_classes > 1 else 'sigmoid'
    outputs = layers.Conv2D(num_classes, kernel_size=1, activation=activation)(x)

    return models.Model(inputs=inputs, outputs=outputs)


In [None]:
# CELL 5: Load and visualize data
# Set paths
base_path = "/content/drive/MyDrive/ResUNet_preprocessed"  # Update this to your actual path

# Create datasets
batch_size = 8
train_dataset = create_dataset(base_path, 'train', batch_size=batch_size)
# Evaluation splits are not shuffled: iteration then follows the sorted file
# order, so per-image results can be traced back to their files.
val_dataset = create_dataset(base_path, 'val', batch_size=batch_size, shuffle=False)
test_dataset = create_dataset(base_path, 'test', batch_size=batch_size, shuffle=False)

# Visualize some samples
visualize_samples(train_dataset, num_samples=3)


In [None]:
# CELL 6: Build and compile model (updated)
# Input size; must match what the preprocessing produced
input_shape = (256, 256, 3)
model = build_resunet_model(input_shape)

# Metrics reported for every training and validation epoch
training_metrics = [
    dice_coefficient,
    iou_score,
    'binary_accuracy',
    f1_score_metric,
    precision_metric,
    recall_metric,
]

# Adam optimizer with the combined BCE + Dice loss
model.compile(
    loss=bce_dice_loss,
    optimizer=optimizers.Adam(learning_rate=0.001),
    metrics=training_metrics,
)

# Display model summary
model.summary()

In [None]:
# CELL 7: Define callbacks (updated)
# Where the best model is written; make sure the folder exists
checkpoint_path = "/content/drive/MyDrive/flood_resunet_weights/model.h5"
os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)

# Keep only the model with the highest validation Dice
checkpoint_cb = callbacks.ModelCheckpoint(
    filepath=checkpoint_path,
    monitor='val_dice_coefficient',
    save_best_only=True,
    mode='max'
)

# Stop after 10 epochs without improvement and restore the best weights
early_stopping_cb = callbacks.EarlyStopping(
    monitor='val_dice_coefficient',
    patience=10,
    restore_best_weights=True,
    mode='max'
)

# Halve the learning rate after 5 stagnant epochs, down to 1e-6
reduce_lr_cb = callbacks.ReduceLROnPlateau(
    monitor='val_dice_coefficient',
    factor=0.5,
    patience=5,
    min_lr=1e-6,
    mode='max'
)

# TensorBoard logs, written once per epoch
tensorboard_cb = callbacks.TensorBoard(
    log_dir='/content/drive/MyDrive/flood_resunet_logs',
    histogram_freq=1,
    update_freq='epoch',
    write_graph=True,
    write_images=True,
    profile_batch=0
)

# CSV logger to save all metrics
csv_logger_cb = callbacks.CSVLogger(
    '/content/drive/MyDrive/flood_resunet_training_log.csv',
    separator=',',
    append=False
)

callbacks_list = [
    checkpoint_cb,
    early_stopping_cb,
    reduce_lr_cb,
    tensorboard_cb,
    csv_logger_cb,
]

In [None]:
# CELL 8: Train model
# The datasets are finite and not repeated, so Keras is left to run one full
# pass per epoch. Passing steps_per_epoch = n_files // batch_size made the
# iterator run out of data whenever n_files is not a multiple of batch_size
# (and evaluates to 0 when a split only has PNG images).
#
# callbacks_list comes from CELL 7; it is not redefined here so that the
# CSVLogger and TensorBoard settings configured there actually take effect.

# Train model
epochs = 50
history = model.fit(
    train_dataset,
    validation_data=val_dataset,
    epochs=epochs,
    callbacks=callbacks_list
)



In [None]:
# CELL 9: Plot training history (updated)
def plot_history(history):
    """
    Plot every tracked training/validation metric on a 3x2 grid, save the
    figure to Drive as a PNG and display it.

    Args:
        history: Object whose ``history`` attribute maps metric names to
            per-epoch value lists (as returned by ``model.fit``).
    """
    logs = history.history

    # One entry per subplot: (title, y-axis label, [(history key, legend label), ...])
    panels = [
        ('Loss', 'Loss', [
            ('loss', 'Training Loss'),
            ('val_loss', 'Validation Loss'),
        ]),
        ('Dice Coefficient', 'Dice', [
            ('dice_coefficient', 'Training Dice'),
            ('val_dice_coefficient', 'Validation Dice'),
        ]),
        ('IoU Score', 'IoU', [
            ('iou_score', 'Training IoU'),
            ('val_iou_score', 'Validation IoU'),
        ]),
        ('Binary Accuracy', 'Accuracy', [
            ('binary_accuracy', 'Training Accuracy'),
            ('val_binary_accuracy', 'Validation Accuracy'),
        ]),
        ('F1 Score', 'F1', [
            ('f1_score_metric', 'Training F1'),
            ('val_f1_score_metric', 'Validation F1'),
        ]),
        ('Precision and Recall', 'Score', [
            ('precision_metric', 'Training Precision'),
            ('val_precision_metric', 'Validation Precision'),
            ('recall_metric', 'Training Recall'),
            ('val_recall_metric', 'Validation Recall'),
        ]),
    ]

    plt.figure(figsize=(20, 15))

    for position, (title, y_label, curves) in enumerate(panels, start=1):
        plt.subplot(3, 2, position)
        for key, label in curves:
            plt.plot(logs[key], label=label)
        plt.title(title)
        plt.xlabel('Epoch')
        plt.ylabel(y_label)
        plt.legend()
        plt.grid(True, linestyle='--', alpha=0.6)

    plt.tight_layout()
    plt.savefig('/content/drive/MyDrive/flood_resunet_training_history.png', dpi=300, bbox_inches='tight')
    plt.show()




# Draw the curves for the run recorded in `history`; plot_history also saves the figure to Drive.
plot_history(history)

In [None]:
# CELL 10: Evaluate model on test set (updated)
import json

# Load best model with all custom metrics
best_model = models.load_model(checkpoint_path, custom_objects={
    'dice_coefficient': dice_coefficient,
    'dice_loss': dice_loss,
    'bce_dice_loss': bce_dice_loss,
    'iou_score': iou_score,
    'f1_score_metric': f1_score_metric,
    'precision_metric': precision_metric,
    'recall_metric': recall_metric
})

# Evaluate on test set. return_dict=True lets Keras name every value itself,
# instead of pairing the result list with metrics_names by position.
test_results = best_model.evaluate(test_dataset, return_dict=True)
print("\nTest Results:")
for metric_name, value in test_results.items():
    print(f"{metric_name}: {value:.4f}")

# Save test results to file (plain floats so they are JSON serializable)
test_metrics = {metric_name: float(value) for metric_name, value in test_results.items()}
with open('/content/drive/MyDrive/flood_resunet_test_metrics.json', 'w') as f:
    json.dump(test_metrics, f, indent=4)
print("Test metrics saved to JSON file")

In [None]:
# CELL 11: Visualize predictions
def visualize_predictions(model, dataset, num_samples=5, images_per_batch=3,
                          save_path='/content/drive/MyDrive/flood_resunet_predictions.png'):
    """Visualize model predictions against ground truth.

    For each of the first ``num_samples`` batches, up to ``images_per_batch``
    images are drawn, one figure row each: a composite of the first two image
    channels, the ground-truth mask and the predicted probability map
    (titled with the per-image Dice at threshold 0.5).

    Args:
        model: Model exposing ``predict``.
        dataset: Batched dataset yielding ``(images, masks)`` tensors.
        num_samples: Number of batches to draw from.
        images_per_batch: Maximum number of images shown per batch.
        save_path: Where to save the figure; falsy to skip saving.
    """
    n_rows = num_samples * images_per_batch
    # Height scales with the real number of rows so panels are not squashed
    plt.figure(figsize=(15, 5 * n_rows))

    for i, (images, masks) in enumerate(dataset.take(num_samples)):
        # verbose=0: no progress bar per batch
        preds = model.predict(images, verbose=0)

        for j in range(min(images.shape[0], images_per_batch)):
            image = images[j].numpy()
            mask = masks[j].numpy()
            pred = preds[j]

            # Normalize image for visualization (0-1 range)
            image_viz = (image - image.min()) / (image.max() - image.min() + 1e-8)

            # Composite RGB using the first two channels
            rgb_viz = np.stack([
                image_viz[:, :, 0],  # VH as R
                image_viz[:, :, 1],  # VV as G
                (image_viz[:, :, 0] + image_viz[:, :, 1]) / 2  # Average as B
            ], axis=-1)
            rgb_viz = np.clip(rgb_viz, 0, 1)

            # Convert masks to binary
            mask_binary = (mask > 0.5).astype(np.float32)
            pred_binary = (pred > 0.5).astype(np.float32)

            # Per-image Dice; two empty masks agree perfectly, which matches
            # what the smoothed training metric reports for that case.
            overlap = np.sum(mask_binary * pred_binary)
            total = np.sum(mask_binary) + np.sum(pred_binary)
            dice = 1.0 if total == 0 else 2 * overlap / total

            row_idx = i * images_per_batch + j

            # Plot original image
            plt.subplot(n_rows, 3, row_idx * 3 + 1)
            plt.imshow(rgb_viz)
            plt.title(f"SAR Image - Sample {row_idx+1}")
            plt.axis('off')

            # Plot ground truth mask
            plt.subplot(n_rows, 3, row_idx * 3 + 2)
            plt.imshow(mask.squeeze(), cmap='Blues')
            plt.title("Ground Truth")
            plt.axis('off')

            # Plot prediction
            plt.subplot(n_rows, 3, row_idx * 3 + 3)
            plt.imshow(pred.squeeze(), cmap='Blues')
            plt.title(f"Prediction (Dice={dice:.3f})")
            plt.axis('off')

    plt.tight_layout()
    if save_path:
        plt.savefig(save_path)
    plt.show()

# Visualize predictions on test set, using the model loaded from the best checkpoint
visualize_predictions(best_model, test_dataset, num_samples=5)

In [None]:
# CELL 12: Calculate and report detailed metrics (updated)
def calculate_metrics(model, dataset, threshold=0.5,
                      per_image_csv_path='/content/drive/MyDrive/flood_resunet_per_image_metrics.csv'):
    """Calculate pixel-level and per-image segmentation metrics on a dataset.

    Confusion-matrix counts are accumulated image by image, so memory use
    does not grow with the number of pixels in the dataset.

    Args:
        model: Model exposing ``predict``.
        dataset: Iterable of ``(images, masks)`` batches; each mask must
            provide ``.numpy()``.
        threshold: Probability above which a predicted pixel counts as flood.
        per_image_csv_path: Destination of the per-image metrics CSV; pass
            ``None`` to skip writing it.

    Returns:
        Dict with the keys 'overall', 'per_image_mean', 'confusion_matrix'
        and 'per_class' holding plain Python floats/ints.

    Raises:
        ValueError: If the dataset yields no images.
    """
    eps = 1e-8
    image_metrics = []
    tp = fp = tn = fn = 0

    for images, masks in dataset:
        preds = model.predict(images, verbose=0)

        for i in range(len(preds)):
            # Ground truth is already 0/1, so it is binarized at a fixed 0.5
            # rather than at the (tunable) prediction threshold.
            mask_binary = masks[i].numpy().flatten() > 0.5
            pred_binary = np.asarray(preds[i]).flatten() > threshold

            # Per-image confusion matrix
            img_tp = int(np.sum(mask_binary & pred_binary))
            img_fp = int(np.sum(~mask_binary & pred_binary))
            img_fn = int(np.sum(mask_binary & ~pred_binary))
            img_tn = int(np.sum(~mask_binary & ~pred_binary))

            tp += img_tp
            fp += img_fp
            fn += img_fn
            tn += img_tn

            # Dice / IoU: two empty masks are a perfect match, consistent
            # with the smoothed training metrics.
            n_errors_or_hits = img_tp + img_fp + img_fn
            if n_errors_or_hits == 0:
                dice = 1.0
                iou = 1.0
            else:
                dice = (2. * img_tp) / (2 * img_tp + img_fp + img_fn)
                iou = img_tp / n_errors_or_hits

            img_precision = img_tp / (img_tp + img_fp + eps)
            img_recall = img_tp / (img_tp + img_fn + eps)
            img_f1 = 2 * (img_precision * img_recall) / (img_precision + img_recall + eps)
            img_accuracy = (img_tp + img_tn) / (img_tp + img_tn + img_fp + img_fn)

            image_metrics.append({
                'dice': dice,
                'iou': iou,
                'precision': img_precision,
                'recall': img_recall,
                'f1': img_f1,
                'accuracy': img_accuracy,
                'tp': img_tp,
                'fp': img_fp,
                'tn': img_tn,
                'fn': img_fn
            })

    if not image_metrics:
        raise ValueError("Dataset yielded no images; nothing to evaluate")

    # Overall metrics from the accumulated pixel counts
    precision = tp / (tp + fp + eps)
    recall = tp / (tp + fn + eps)
    f1 = 2 * (precision * recall) / (precision + recall + eps)
    accuracy = (tp + tn) / (tp + tn + fp + fn)
    specificity = tn / (tn + fp + eps)
    iou = tp / (tp + fp + fn + eps)
    overall_dice = (2 * tp) / (2 * tp + fp + fn + eps)

    # Mean Dice and IoU across all images
    mean_dice = float(np.mean([m['dice'] for m in image_metrics]))
    mean_iou = float(np.mean([m['iou'] for m in image_metrics]))

    print("\n======== Detailed Metrics ========")
    print(f"Overall Metrics (pixel-level):")
    print(f"Accuracy: {accuracy:.4f}")
    print(f"Precision: {precision:.4f}")
    print(f"Recall/Sensitivity: {recall:.4f}")
    print(f"Specificity: {specificity:.4f}")
    print(f"F1 Score: {f1:.4f}")
    print(f"IoU: {iou:.4f}")
    print(f"Dice Coefficient: {overall_dice:.4f}")

    print("\nMean Per-Image Metrics:")
    print(f"Mean Dice: {mean_dice:.4f}")
    print(f"Mean IoU: {mean_iou:.4f}")

    print("\nConfusion Matrix:")
    print(f"True Positives: {tp}")
    print(f"False Positives: {fp}")
    print(f"True Negatives: {tn}")
    print(f"False Negatives: {fn}")

    # Per-class metrics (for binary segmentation)
    print("\nPer-Class Metrics:")
    # Background (class 0): roles of positives and negatives are swapped
    bg_precision = tn / (tn + fn + eps)
    bg_recall = tn / (tn + fp + eps)
    bg_f1 = 2 * (bg_precision * bg_recall) / (bg_precision + bg_recall + eps)
    print(f"Background - Precision: {bg_precision:.4f}, Recall: {bg_recall:.4f}, F1: {bg_f1:.4f}")

    # Flood (class 1) is the positive class of the overall metrics
    flood_precision = precision
    flood_recall = recall
    flood_f1 = f1
    print(f"Flood - Precision: {flood_precision:.4f}, Recall: {flood_recall:.4f}, F1: {flood_f1:.4f}")

    # Save per-image metrics to CSV
    if per_image_csv_path:
        import pandas as pd
        img_metrics_df = pd.DataFrame(image_metrics)
        img_metrics_df.to_csv(per_image_csv_path, index_label='image_id')
        print("\nPer-image metrics saved to CSV file")

    return {
        'overall': {
            'accuracy': float(accuracy),
            'precision': float(precision),
            'recall': float(recall),
            'specificity': float(specificity),
            'f1': float(f1),
            'iou': float(iou),
            'dice': float(overall_dice)
        },
        'per_image_mean': {
            'dice': float(mean_dice),
            'iou': float(mean_iou)
        },
        'confusion_matrix': {
            'tn': int(tn),
            'fp': int(fp),
            'fn': int(fn),
            'tp': int(tp)
        },
        'per_class': {
            'background': {
                'precision': float(bg_precision),
                'recall': float(bg_recall),
                'f1': float(bg_f1)
            },
            'flood': {
                'precision': float(flood_precision),
                'recall': float(flood_recall),
                'f1': float(flood_f1)
            }
        }
    }

# Calculate comprehensive metrics on test set (this also prints the report
# and writes the per-image CSV from inside calculate_metrics)
detailed_metrics = calculate_metrics(best_model, test_dataset)

# Save detailed metrics to file; `json` must already have been imported by an earlier cell
with open('/content/drive/MyDrive/flood_resunet_detailed_metrics.json', 'w') as f:
    json.dump(detailed_metrics, f, indent=4)

print("Detailed metrics saved!")

In [None]:
# CELL 13: Save model and metrics
import json

# Save model in TensorFlow SavedModel format (recommended)
best_model.save('/content/drive/MyDrive/flood_resunet_model')

# Save metrics to file. The metrics dict produced by calculate_metrics is
# named `detailed_metrics`; the previous reference to an undefined `metrics`
# variable raised NameError.
metrics_to_save = dict(detailed_metrics)
# Convert confusion matrix values to int for JSON serialization
metrics_to_save['confusion_matrix'] = {
    k: int(v) for k, v in detailed_metrics['confusion_matrix'].items()
}
with open('/content/drive/MyDrive/flood_resunet_metrics.json', 'w') as f:
    json.dump(metrics_to_save, f, indent=4)

print("Model and metrics saved!")

In [None]:
# CELL 14: Create example inference function for new images
def predict_flood(model, image_path, output_path=None, threshold=0.5):
    """Predict a binary flood mask for one preprocessed image.

    Args:
        model: Model exposing ``predict`` (e.g. the loaded ResUNet)
        image_path: Path to input .npy file (preprocessed)
        output_path: Path to save a three-panel visualization (optional)
        threshold: Probability above which a pixel is marked as flood

    Returns:
        ``np.uint8`` array with 1 where the predicted probability exceeds
        ``threshold`` and 0 elsewhere
    """
    # Cast to float32 so inference sees the same dtype as the training pipeline
    image = np.load(image_path).astype(np.float32)

    # Add batch dimension, predict, and drop it again
    image_batch = np.expand_dims(image, axis=0)
    prediction = model.predict(image_batch)[0]

    # Remove singleton dimensions (e.g. the channel axis of a 1-class output)
    prediction = prediction.squeeze()

    # Create binary mask
    binary_mask = (prediction > threshold).astype(np.uint8)

    if output_path:
        fig = plt.figure(figsize=(15, 5))
        try:
            # Normalize image for visualization
            image_viz = (image - image.min()) / (image.max() - image.min() + 1e-8)

            # Create composite RGB
            rgb_viz = np.stack([
                image_viz[:, :, 0],  # VH as R
                image_viz[:, :, 1],  # VV as G
                (image_viz[:, :, 0] + image_viz[:, :, 1]) / 2  # Average as B
            ], axis=-1)

            # Plot original image
            plt.subplot(1, 3, 1)
            plt.imshow(rgb_viz)
            plt.title("SAR Image")
            plt.axis('off')

            # Plot probability map
            plt.subplot(1, 3, 2)
            plt.imshow(prediction, cmap='Blues')
            plt.colorbar(fraction=0.046, pad=0.04)
            plt.title("Flood Probability")
            plt.axis('off')

            # Plot binary mask
            plt.subplot(1, 3, 3)
            plt.imshow(binary_mask, cmap='Blues')
            plt.title("Flood Mask (Binary)")
            plt.axis('off')

            plt.tight_layout()
            plt.savefig(output_path)
        finally:
            # Release the figure even when plotting or saving fails
            plt.close(fig)

    return binary_mask

# Example usage: the call below sits inside a string literal, so it does not run.
# Copy it into a cell (and adjust the paths) when needed.
"""
test_image_path = "/content/drive/MyDrive/ResUNet_preprocessed/test/images/test_0001.npy"
output_path = "/content/drive/MyDrive/test_prediction.png"
flood_mask = predict_flood(best_model, test_image_path, output_path)
"""