In [None]:
import os
import csv
import numpy as np
import shutil
from PIL import Image
from tqdm import tqdm
import random
from sklearn.model_selection import train_test_split
import json

class ChangeDetectionFloodPreprocessor:
    """Split a paired pre-/post-flood image dataset and export it for training.

    The input directory is expected to contain three sibling folders, ``A``
    (pre-flood images), ``B`` (post-flood images) and ``Label`` (ground-truth
    masks), holding PNG files that share the same file name. Every complete
    triplet is assigned to a train/val/test split, converted to float32
    arrays in [0, 1] (masks are binarised), and written both as ``.npy`` and
    as ``.png`` under ``output_path/<split>/{pre_flood,post_flood,masks}``.

    WARNING: constructing an instance deletes ``output_path`` if it already
    exists (see ``create_output_dirs``).
    """

    # Names of the splits and of the per-split sub-folders that are created.
    _SPLITS = ('train', 'val', 'test')
    _SUBDIRS = ('pre_flood', 'post_flood', 'masks')

    def __init__(self, base_path, output_path, img_size=256, test_size=0.15, val_size=0.15, random_state=42):
        """
        Preprocessor for change detection flood dataset

        Args:
            base_path: Path to dataset containing A, B, Label folders
            output_path: Path to save preprocessed data (removed and recreated)
            img_size: Target image size (should be 256 since input is already 256x256)
            test_size: Fraction of the whole dataset used for the test set
            val_size: Fraction of the whole dataset used for the validation set
                (it is rescaled internally before the second split)
            random_state: Random seed for reproducible splits
        """
        self.base_path = base_path
        self.output_path = output_path
        self.img_size = img_size
        self.test_size = test_size
        self.val_size = val_size
        self.random_state = random_state

        # Set random seed for reproducibility
        random.seed(random_state)
        np.random.seed(random_state)

        # Define folder paths
        self.pre_flood_dir = os.path.join(base_path, 'A')  # Pre-flood images
        self.post_flood_dir = os.path.join(base_path, 'B')  # Post-flood images
        self.label_dir = os.path.join(base_path, 'Label')  # Ground truth masks

        # Per-split counters; kept as plain Python ints so they stay exact and
        # can be written to JSON.
        self.stats = {
            split: {'count': 0, 'flood_pixels': 0, 'total_pixels': 0}
            for split in self._SPLITS
        }

        self.create_output_dirs()

    def create_output_dirs(self):
        """Create output directory structure (wipes any existing output_path)."""
        if os.path.exists(self.output_path):
            shutil.rmtree(self.output_path)

        for split in self._SPLITS:
            for subdir in self._SUBDIRS:
                os.makedirs(os.path.join(self.output_path, split, subdir), exist_ok=True)

        # Create splits directory
        os.makedirs(os.path.join(self.output_path, 'splits'), exist_ok=True)

        print(f"Created output directories at {self.output_path}")

    def get_image_list(self):
        """Get list of all image names and verify they exist in all folders.

        Returns:
            Sorted list of PNG file names present in A, B and Label.
        """
        print("Scanning dataset for image files...")

        # Get all PNG files from pre-flood folder
        pre_flood_files = sorted(
            f for f in os.listdir(self.pre_flood_dir) if f.endswith('.png')
        )

        valid_images = []
        missing_files = []
        source_dirs = (self.pre_flood_dir, self.post_flood_dir, self.label_dir)

        for img_name in tqdm(pre_flood_files, desc="Verifying image triplets"):
            # A sample is usable only if all three files exist
            if all(os.path.exists(os.path.join(d, img_name)) for d in source_dirs):
                valid_images.append(img_name)
            else:
                missing_files.append(img_name)

        print(f"Found {len(valid_images)} valid image triplets")
        if missing_files:
            print(f"Warning: {len(missing_files)} images missing from one or more folders")
            print(f"First few missing: {missing_files[:5]}")

        return valid_images

    def create_data_splits(self, image_list):
        """Create train/validation/test splits and save them.

        Writes ``<split>_images.csv`` and ``split_info.json`` into
        ``output_path/splits``.

        Args:
            image_list: File names of the valid image triplets.

        Returns:
            Dict mapping 'train', 'val' and 'test' to lists of file names.
        """
        print(f"Creating data splits with random_state={self.random_state}")
        print(f"Test size: {self.test_size}, Validation size: {self.val_size}")

        # First split: separate test set
        train_val_images, test_images = train_test_split(
            image_list,
            test_size=self.test_size,
            random_state=self.random_state,
            shuffle=True
        )

        # Second split: val_size is a fraction of the whole dataset, so rescale
        # it to a fraction of what is left after removing the test set.
        val_size_adjusted = self.val_size / (1 - self.test_size)

        train_images, val_images = train_test_split(
            train_val_images,
            test_size=val_size_adjusted,
            random_state=self.random_state,
            shuffle=True
        )

        splits = {
            'train': train_images,
            'val': val_images,
            'test': test_images
        }

        total = len(image_list)

        # Print split information
        print(f"Data splits created:")
        print(f"  Train: {len(train_images)} images ({len(train_images)/total*100:.1f}%)")
        print(f"  Validation: {len(val_images)} images ({len(val_images)/total*100:.1f}%)")
        print(f"  Test: {len(test_images)} images ({len(test_images)/total*100:.1f}%)")

        # Save splits to CSV files
        splits_dir = os.path.join(self.output_path, 'splits')

        for split_name, images in splits.items():
            csv_path = os.path.join(splits_dir, f'{split_name}_images.csv')
            with open(csv_path, 'w', newline='') as csvfile:
                writer = csv.writer(csvfile)
                writer.writerow(['image_name'])  # Header
                writer.writerows([img_name] for img_name in images)
            print(f"Saved {split_name} split to {csv_path}")

        # Save split info as JSON for reference
        split_info = {
            'total_images': total,
            'splits': {
                name: {'count': len(images), 'percentage': len(images)/total*100}
                for name, images in splits.items()
            },
            'random_state': self.random_state,
            'test_size': self.test_size,
            'val_size': self.val_size
        }

        with open(os.path.join(splits_dir, 'split_info.json'), 'w') as f:
            json.dump(split_info, f, indent=2)

        return splits

    def load_and_process_image(self, image_path):
        """Load an image as an RGB float32 array scaled to [0, 1].

        Returns:
            Array of shape (img_size, img_size, 3), or None if loading failed
            (the error is printed, not raised).
        """
        try:
            # The context manager closes the underlying file handle once the
            # pixel data has been copied into the array.
            with Image.open(image_path) as source:
                img = source if source.mode == 'RGB' else source.convert('RGB')

                # Resize if needed (though should already be 256x256)
                if img.size != (self.img_size, self.img_size):
                    img = img.resize((self.img_size, self.img_size), Image.LANCZOS)

                # Convert to numpy array and normalize to [0, 1]
                return np.array(img, dtype=np.float32) / 255.0

        except Exception as e:
            print(f"Error processing image {image_path}: {str(e)}")
            return None

    def load_and_process_mask(self, mask_path):
        """Load a mask as a binary float32 array (values 0.0 or 1.0).

        Returns:
            Array of shape (img_size, img_size), or None if loading failed
            (the error is printed, not raised).
        """
        try:
            with Image.open(mask_path) as source:
                mask = source if source.mode == 'L' else source.convert('L')

                # Nearest-neighbour keeps the labels crisp when resizing
                if mask.size != (self.img_size, self.img_size):
                    mask = mask.resize((self.img_size, self.img_size), Image.NEAREST)

                mask_array = np.array(mask, dtype=np.float32)

            # Normalize to [0, 1] and threshold to create the binary mask
            mask_array = mask_array / 255.0
            return (mask_array > 0.5).astype(np.float32)

        except Exception as e:
            print(f"Error processing mask {mask_path}: {str(e)}")
            return None

    def update_stats(self, split, mask):
        """Add one processed mask to the counters of ``split``.

        The flood-pixel count is converted to a Python int: accumulating the
        NumPy float32 scalar returned by ``mask.sum()`` can lose precision on
        large datasets and is not JSON-serialisable.
        """
        split_stats = self.stats[split]
        split_stats['count'] += 1
        split_stats['flood_pixels'] += int(mask.sum())
        split_stats['total_pixels'] += int(mask.size)

    @staticmethod
    def _to_uint8(array):
        """Scale a [0, 1] float array to 0-255 uint8, rounding to nearest.

        Rounding (instead of truncating) guards against float32 round-off
        leaving a value fractionally below the intended integer.
        """
        return np.rint(array * 255).astype(np.uint8)

    def process_split(self, split_name, image_list):
        """Process images for a specific split.

        Each sample is written as NPY (exact values) and PNG (preview).
        Samples for which any of the three files fails to load are skipped.
        """
        print(f"Processing {split_name} split...")

        split_dir = os.path.join(self.output_path, split_name)

        for img_name in tqdm(image_list, desc=f"Processing {split_name}"):
            pre_flood_array = self.load_and_process_image(
                os.path.join(self.pre_flood_dir, img_name))
            if pre_flood_array is None:
                continue

            post_flood_array = self.load_and_process_image(
                os.path.join(self.post_flood_dir, img_name))
            if post_flood_array is None:
                continue

            mask_array = self.load_and_process_mask(
                os.path.join(self.label_dir, img_name))
            if mask_array is None:
                continue

            base_name = os.path.splitext(img_name)[0]
            outputs = (
                ('pre_flood', pre_flood_array),
                ('post_flood', post_flood_array),
                ('masks', mask_array),
            )

            for subdir, array in outputs:
                target = os.path.join(split_dir, subdir, base_name)
                # NPY keeps the exact float32 values used for training
                np.save(f'{target}.npy', array)
                # PNG copy for visual inspection; the image mode is inferred
                # from the uint8 array's shape.
                Image.fromarray(self._to_uint8(array)).save(f'{target}.png')

            # Update statistics
            self.update_stats(split_name, mask_array)

    def print_stats(self):
        """Print sample counts and flood-pixel percentage per split."""
        print("\nDataset Statistics:")
        print("=" * 50)
        for split, stat in self.stats.items():
            if stat['count'] > 0:
                flood_percentage = 100 * stat['flood_pixels'] / stat['total_pixels']
                print(f"{split.upper()} set: {stat['count']} samples, "
                      f"Flood pixels: {flood_percentage:.2f}%")
        print("=" * 50)

    def calculate_class_weights(self):
        """Calculate class weights to handle imbalance.

        Returns:
            Array ``[weight_non_flood, weight_flood]``. The non-flood weight is
            1.0 and the flood weight is the non-flood/flood pixel ratio of the
            training split; ``[1.0, 1.0]`` if no training pixels were counted.
        """
        train_stats = self.stats['train']
        if train_stats['total_pixels'] > 0:
            pos_ratio = train_stats['flood_pixels'] / train_stats['total_pixels']
            neg_ratio = 1 - pos_ratio

            # Class weights inversely proportional to class frequency
            weight_non_flood = 1.0
            weight_flood = neg_ratio / pos_ratio if pos_ratio > 0 else 1.0

            print(f"\nClass weights for handling imbalance:")
            print(f"Weight for non-flood (0): {weight_non_flood:.4f}")
            print(f"Weight for flood (1): {weight_flood:.4f}")

            return np.array([weight_non_flood, weight_flood])
        return np.array([1.0, 1.0])

    def run_preprocessing(self):
        """Run the complete preprocessing pipeline.

        Raises:
            ValueError: If no complete image triplet is found.
        """
        print("Starting Change Detection Flood Preprocessing...")
        print(f"Input path: {self.base_path}")
        print(f"Output path: {self.output_path}")
        print(f"Expected image range: image_1.png to image_5360.png")
        print(f"Image size: {self.img_size}x{self.img_size}")
        print("=" * 50)

        # Step 1: Get list of valid images
        image_list = self.get_image_list()
        if len(image_list) == 0:
            raise ValueError("No valid image triplets found!")

        # Step 2: Create data splits
        splits = self.create_data_splits(image_list)

        # Step 3: Process each split
        for split_name, images in splits.items():
            self.process_split(split_name, images)

        # Step 4: Print statistics and calculate class weights
        self.print_stats()
        weights = self.calculate_class_weights()

        # Step 5: Save class weights
        weights_path = os.path.join(self.output_path, 'class_weights.npy')
        np.save(weights_path, weights)
        print(f"\nClass weights saved to: {weights_path}")

        # Step 6: Save preprocessing info
        preprocessing_info = {
            'dataset_type': 'change_detection_flood',
            'input_format': 'pre_flood + post_flood -> flood_mask',
            'total_images': len(image_list),
            'image_size': [self.img_size, self.img_size],
            'splits': {
                'train': len(splits['train']),
                'val': len(splits['val']),
                'test': len(splits['test'])
            },
            'random_state': self.random_state,
            'class_weights': weights.tolist(),
            'statistics': self.stats
        }

        info_path = os.path.join(self.output_path, 'preprocessing_info.json')
        with open(info_path, 'w') as f:
            json.dump(preprocessing_info, f, indent=2)

        print(f"\nPreprocessing complete! Processed data saved to: {self.output_path}")
        print("\nOutput structure:")
        layout_lines = (
            "preprocessed/",
            "├── train/",
            "│   ├── pre_flood/ (NPY and PNG files)",
            "│   ├── post_flood/ (NPY and PNG files)",
            "│   └── masks/ (NPY and PNG files)",
            "├── val/",
            "│   ├── pre_flood/ (NPY and PNG files)",
            "│   ├── post_flood/ (NPY and PNG files)",
            "│   └── masks/ (NPY and PNG files)",
            "├── test/",
            "│   ├── pre_flood/ (NPY and PNG files)",
            "│   ├── post_flood/ (NPY and PNG files)",
            "│   └── masks/ (NPY and PNG files)",
            "├── splits/",
            "│   ├── train_images.csv",
            "│   ├── val_images.csv",
            "│   ├── test_images.csv",
            "│   └── split_info.json",
            "├── class_weights.npy",
            "└── preprocessing_info.json",
        )
        for line in layout_lines:
            print(line)
        print("=" * 50)

# Example usage for Kaggle environment.
# Runs the whole preprocessing pipeline with the default 70/15/15 split when
# this cell/script is executed directly.
if __name__ == "__main__":
    print("Change Detection Flood Preprocessing for Kaggle Environment")
    print("=" * 60)
    
    # Kaggle paths
    base_path = "/kaggle/input/s1gfloods"  # Contains A, B, Label folders
    output_path = "/kaggle/working/preprocessed_change_detection"
    
    # Initialize and run preprocessor.
    # NOTE: the constructor removes output_path if it already exists before
    # recreating the directory tree.
    preprocessor = ChangeDetectionFloodPreprocessor(
        base_path=base_path,
        output_path=output_path,
        img_size=256,
        test_size=0.15,    # 15% for test
        val_size=0.15,     # 15% for validation
        random_state=42    # For reproducible splits
    )
    
    # Run preprocessing (scan, split, convert, write stats and class weights)
    preprocessor.run_preprocessing()
    
    print("\nPreprocessing completed successfully!")
    print("Ready for training with change detection model.")

In [None]:
# Import necessary libraries
import os
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers, models, callbacks, optimizers
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix
import random
import glob
from tensorflow.keras import backend as K
import json
import pandas as pd
from PIL import Image
import math

# Set random seeds for reproducibility.
# Seeds NumPy, TensorFlow and Python's `random` with the same value (42)
# before any dataset or model is built.
np.random.seed(42)
tf.random.set_seed(42)
random.seed(42)

# Report the runtime environment so the notebook log shows which TensorFlow
# build ran and whether a GPU was visible to it.
print("TensorFlow version:", tf.__version__)
print("GPU Available:", len(tf.config.list_physical_devices('GPU')) > 0)
print("GPU Devices:", tf.config.list_physical_devices('GPU'))

# Comprehensive precision control and TensorFlow configuration
def setup_tensorflow_for_stable_training():
    """Force a conservative, float32-only TensorFlow configuration.

    Disables mixed precision, the listed graph-optimizer passes, XLA JIT and
    TF32 execution, enables GPU memory growth, sets the device policy to
    'silent', and prints a summary of what was applied.
    """
    import os

    # Global dtype policy: plain float32 for Keras layers and the backend.
    tf.keras.mixed_precision.set_global_policy('float32')
    tf.keras.backend.set_floatx('float32')

    # Switch off the graph-optimizer passes listed here.
    optimizer_switches = {
        'auto_mixed_precision': False,
        'disable_meta_optimizer': True,
        'constant_folding': False,
        'arithmetic_optimization': False,
        'loop_optimization': False,
        'function_optimization': False
    }
    tf.config.optimizer.set_experimental_options(optimizer_switches)

    # No XLA JIT compilation and no TF32 math.
    tf.config.optimizer.set_jit(False)
    tf.config.experimental.enable_tensor_float_32_execution(False)

    # Let each visible GPU allocate memory on demand. A RuntimeError from
    # TensorFlow is reported as a warning and stops further GPU configuration.
    try:
        for device in tf.config.experimental.list_physical_devices('GPU'):
            tf.config.experimental.set_memory_growth(device, True)
    except RuntimeError as e:
        print(f"GPU configuration warning: {e}")

    # Silently copy tensors between devices instead of raising.
    tf.config.experimental.set_device_policy('silent')

    # NOTE(review): these variables are set after `tensorflow` has already been
    # imported - confirm they still take effect at this point.
    os.environ.update({
        'TF_ENABLE_AUTO_MIXED_PRECISION': '0',
        'TF_CPP_MIN_LOG_LEVEL': '2',  # Reduce log spam
    })

    summary = (
        "TensorFlow configured for maximum stability float32 training",
        "- Mixed precision: Completely disabled",
        "- XLA JIT: Disabled",
        "- TF32: Disabled",
        "- Auto mixed precision: Disabled",
        "- All optimizations: Disabled for stability",
        "- Device policy: Silent (handles CPU/GPU conflicts)",
        "- Default dtype: float32",
    )
    for line in summary:
        print(line)

# Apply comprehensive TensorFlow configuration
# (runs once, at cell execution time, before any dataset or model is created).
setup_tensorflow_for_stable_training()

# Define paths
# BASE_PATH: root of the preprocessed dataset, with one sub-folder per split.
# OUTPUT_PATH: where figures produced in this notebook are written; created
# here if it does not exist yet.
BASE_PATH = "/kaggle/working/preprocessed_change_detection"
OUTPUT_PATH = "/kaggle/working/flood_change_detection_model"
os.makedirs(OUTPUT_PATH, exist_ok=True)

# Define helper functions for data loading (CHANGE DETECTION VERSION)
def load_image_pair(pre_flood_path, post_flood_path):
    """Read a pre-/post-flood pair of .npy files and stack them channel-wise.

    Each path may be a plain string or a string tensor (as passed in by
    ``tf.py_function``). Both arrays are cast to float32 and concatenated on
    the last axis, pre-flood channels first, post-flood channels after.
    """
    # Resolve both paths to Python strings before touching the disk.
    resolved = [
        path.numpy().decode('utf-8') if isinstance(path, tf.Tensor) else path
        for path in (pre_flood_path, post_flood_path)
    ]

    # Load in order (pre, post) so the channel layout stays pre-then-post.
    arrays = [np.load(path).astype(np.float32) for path in resolved]

    return np.concatenate(arrays, axis=-1)

def load_mask(mask_path):
    """Read a mask stored as .npy and return it as a float32 array.

    ``mask_path`` may be a plain string or a string tensor (as passed in by
    ``tf.py_function``).
    """
    path = (
        mask_path.numpy().decode('utf-8')
        if isinstance(mask_path, tf.Tensor)
        else mask_path
    )
    return np.load(path).astype(np.float32)

# Enhanced data augmentation for change detection
def enhanced_change_detection_augmentation():
    """Build the augmentation function applied to training samples.

    Returns:
        A callable ``augment_fn(image, mask) -> (image, mask)`` suitable for
        ``tf.data.Dataset.map``. Geometric transforms (flips, 90-degree
        rotations) are applied identically to image and mask; photometric
        transforms (brightness shift, intensity scaling, Gaussian noise) are
        applied to the image only. Both outputs are float32.

    Note:
        In this file the function is mapped over unbatched samples where the
        image holds the pre- and post-flood channels stacked together, so every
        transform affects both acquisitions in the same way.
    """
    def augment_fn(image, mask):
        # Ensure float32 precision at the start
        image = tf.cast(image, tf.float32)
        mask = tf.cast(mask, tf.float32)
        
        # Random horizontal flip with probability 0.5
        # (apply to both pre and post images, and to the mask)
        if tf.random.uniform([]) > 0.5:
            image = tf.image.flip_left_right(image)
            mask = tf.image.flip_left_right(mask)
        
        # Random vertical flip with probability 0.5
        if tf.random.uniform([]) > 0.5:
            image = tf.image.flip_up_down(image)
            mask = tf.image.flip_up_down(mask)
        
        # Random rotation by k * 90 degrees, k drawn from {0, 1, 2, 3};
        # the same k is used for image and mask so they stay aligned.
        k = tf.random.uniform([], 0, 4, dtype=tf.int32)
        image = tf.image.rot90(image, k)
        mask = tf.image.rot90(mask, k)
        
        # Random brightness adjustment, applied when the draw exceeds 0.7
        if tf.random.uniform([]) > 0.7:
            # One scalar offset in [-0.1, 0.1) is added to every channel, so
            # pre- and post-flood images receive the same brightness change.
            brightness_delta = tf.random.uniform([], -0.1, 0.1)
            image = tf.clip_by_value(image + brightness_delta, 0.0, 1.0)
        
        # Random "contrast" adjustment, applied when the draw exceeds 0.7.
        # NOTE: this multiplies intensities by a factor in [0.9, 1.1) (scaling
        # about zero), it does not stretch values around the image mean.
        if tf.random.uniform([]) > 0.7:
            contrast_factor = tf.random.uniform([], 0.9, 1.1)
            image = tf.clip_by_value(image * contrast_factor, 0.0, 1.0)
        
        # Gaussian noise (simulate sensor noise), applied when the draw
        # exceeds 0.6; independent per pixel and channel, stddev 0.01.
        if tf.random.uniform([]) > 0.6:
            noise = tf.random.normal(tf.shape(image), mean=0, stddev=0.01)
            image = tf.clip_by_value(image + noise, 0.0, 1.0)
        
        # Ensure outputs are float32
        return tf.cast(image, tf.float32), tf.cast(mask, tf.float32)
    
    return augment_fn

def create_change_detection_dataset(base_path, split, batch_size=32, shuffle=True, use_augmentation=False):
    """Create a TensorFlow dataset for change detection.

    Samples are read from ``base_path/split/{pre_flood,post_flood,masks}/*.npy``
    and paired by file name. Each element is ``(image, mask)`` with ``image``
    of shape (256, 256, 6) (pre-flood channels followed by post-flood
    channels) and ``mask`` of shape (256, 256, 1), both float32.

    Args:
        base_path: Root directory of the preprocessed dataset.
        split: Sub-directory to read ('train', 'val' or 'test').
        batch_size: Number of samples per batch.
        shuffle: Whether to shuffle the file list (seeded with 42).
        use_augmentation: Apply random augmentation; only honoured when
            ``split == 'train'``.

    Returns:
        ``(dataset, num_samples)``. The dataset repeats indefinitely, so the
        caller must bound iteration (e.g. with a step count derived from
        ``num_samples``).

    Raises:
        ValueError: If a folder is empty, the folders hold different numbers
            of files, or the file names do not line up across folders.
    """
    def list_npy(subdir):
        return sorted(glob.glob(os.path.join(base_path, split, subdir, '*.npy')))

    pre_flood_paths = list_npy('pre_flood')
    post_flood_paths = list_npy('post_flood')
    mask_paths = list_npy('masks')

    if len(pre_flood_paths) == 0 or len(post_flood_paths) == 0 or len(mask_paths) == 0:
        raise ValueError(f"No images found in {base_path}/{split}")

    if not (len(pre_flood_paths) == len(post_flood_paths) == len(mask_paths)):
        raise ValueError(f"Mismatch in number of images: pre={len(pre_flood_paths)}, post={len(post_flood_paths)}, masks={len(mask_paths)}")

    # Samples are paired purely by position in the sorted lists, so equal
    # counts are not enough: a file missing in one folder and an extra file in
    # another would silently pair images with the wrong mask.
    for pre_path, post_path, mask_path in zip(pre_flood_paths, post_flood_paths, mask_paths):
        names = {os.path.basename(p) for p in (pre_path, post_path, mask_path)}
        if len(names) != 1:
            raise ValueError(
                f"File names do not line up in {base_path}/{split}: "
                f"pre={os.path.basename(pre_path)}, "
                f"post={os.path.basename(post_path)}, "
                f"mask={os.path.basename(mask_path)}"
            )

    print(f"Found {len(pre_flood_paths)} image pairs and masks for {split}")

    # One dataset element per (pre, post, mask) path triple
    dataset = tf.data.Dataset.zip((
        tf.data.Dataset.from_tensor_slices(pre_flood_paths),
        tf.data.Dataset.from_tensor_slices(post_flood_paths),
        tf.data.Dataset.from_tensor_slices(mask_paths),
    ))

    # Shuffle paths (cheap) rather than loaded arrays
    if shuffle:
        dataset = dataset.shuffle(buffer_size=len(pre_flood_paths), seed=42)

    # Load the arrays with NumPy through py_function
    dataset = dataset.map(
        lambda pre_path, post_path, mask_path: (
            tf.py_function(
                func=load_image_pair,
                inp=[pre_path, post_path],
                Tout=tf.float32
            ),
            tf.py_function(
                func=load_mask,
                inp=[mask_path],
                Tout=tf.float32
            )
        ),
        num_parallel_calls=tf.data.AUTOTUNE
    )

    # py_function drops static shape information: restore it, keep float32,
    # and add the channel dimension to the mask.
    dataset = dataset.map(
        lambda x, y: (
            tf.ensure_shape(tf.cast(x, tf.float32), [256, 256, 6]),
            tf.expand_dims(
                tf.ensure_shape(tf.cast(y, tf.float32), [256, 256]), axis=-1
            )
        )
    )

    # Apply data augmentation only for training
    if use_augmentation and split == 'train':
        augment_fn = enhanced_change_detection_augmentation()
        dataset = dataset.map(augment_fn, num_parallel_calls=tf.data.AUTOTUNE)

    # Important: Add repeat to prevent dataset exhaustion
    dataset = dataset.repeat()

    # Batch and prefetch
    dataset = dataset.batch(batch_size).prefetch(tf.data.AUTOTUNE)

    return dataset, len(pre_flood_paths)

def visualize_change_detection_samples(dataset, num_samples=3):
    """Plot the first sample of each of ``num_samples`` batches.

    Every row shows four panels: pre-flood image, post-flood image, their
    absolute difference, and the flood mask. The figure is saved as
    ``change_detection_samples.png`` under ``OUTPUT_PATH`` and then shown.
    """
    plt.figure(figsize=(20, 5*num_samples))

    for i, (images, masks) in enumerate(dataset.take(num_samples)):
        # Only the first element of each batch is drawn; skip empty batches.
        if images.shape[0] < 1:
            continue

        stacked = images[0].numpy()
        mask = masks[0].numpy().squeeze()

        # Undo the channel stacking: first 3 channels pre, last 3 post.
        pre_flood = stacked[:, :, :3]
        post_flood = stacked[:, :, 3:]
        difference = np.abs(post_flood - pre_flood)

        panels = (
            ("Pre-Flood", pre_flood, {}),
            ("Post-Flood", post_flood, {}),
            ("Difference", difference, {}),
            ("Flood Mask", mask, {'cmap': 'Blues'}),
        )

        for column, (label, picture, imshow_kwargs) in enumerate(panels, start=1):
            plt.subplot(num_samples, 4, i*4 + column)
            plt.imshow(picture, **imshow_kwargs)
            plt.title(f"{label} - Sample {i+1}")
            plt.axis('off')

    plt.tight_layout()
    plt.savefig(os.path.join(OUTPUT_PATH, 'change_detection_samples.png'))
    plt.show()

# Define custom metrics with consistent float32 casting (same as before)
def dice_coefficient(y_true, y_pred, smooth=1e-6):
    """Soft Dice coefficient over all elements of the batch, in float32.

    Computes ``(2 * sum(t * p) + smooth) / (sum(t) + sum(p) + smooth)`` on the
    flattened tensors; predictions are used as given (no thresholding).
    """
    truth = tf.reshape(tf.cast(y_true, tf.float32), [-1])
    pred = tf.reshape(tf.cast(y_pred, tf.float32), [-1])

    overlap = tf.reduce_sum(truth * pred)
    numerator = 2. * overlap + smooth
    denominator = tf.reduce_sum(truth) + tf.reduce_sum(pred) + smooth

    return tf.cast(numerator / denominator, tf.float32)

def dice_loss(y_true, y_pred, smooth=1e-6):
    """Dice loss: one minus ``dice_coefficient``, returned as float32."""
    dice = dice_coefficient(y_true, y_pred, smooth)
    loss = 1 - dice
    return tf.cast(loss, tf.float32)

def optimized_focal_tversky_loss(y_true, y_pred, alpha=0.7, beta=0.3, gamma=1.5, smooth=1e-6):
    """Focal Tversky loss over all elements of the batch, in float32.

    The Tversky index is
    ``(TP + smooth) / (TP + alpha * FN + beta * FP + smooth)`` with soft
    counts taken from the flattened tensors; the loss is
    ``(1 - tversky) ** gamma``. ``alpha`` weights false negatives and ``beta``
    weights false positives.
    """
    truth = tf.reshape(tf.cast(y_true, tf.float32), [-1])
    pred = tf.reshape(tf.cast(y_pred, tf.float32), [-1])

    # Soft confusion-matrix entries
    tp = tf.reduce_sum(truth * pred)
    fn = tf.reduce_sum(truth * (1 - pred))
    fp = tf.reduce_sum((1 - truth) * pred)

    tversky_index = (tp + smooth) / (tp + alpha * fn + beta * fp + smooth)
    loss = tf.pow(1 - tversky_index, gamma)

    return tf.cast(loss, tf.float32)

def iou_score(y_true, y_pred, smooth=1e-6):
    """Soft IoU (Jaccard) score over all elements of the batch, in float32.

    Computes ``(I + smooth) / (U + smooth)`` with ``I = sum(t * p)`` and
    ``U = sum(t) + sum(p) - I`` on the flattened tensors.
    """
    truth = tf.reshape(tf.cast(y_true, tf.float32), [-1])
    pred = tf.reshape(tf.cast(y_pred, tf.float32), [-1])

    overlap = tf.reduce_sum(truth * pred)
    combined = tf.reduce_sum(truth) + tf.reduce_sum(pred) - overlap

    return tf.cast((overlap + smooth) / (combined + smooth), tf.float32)

def f1_score_metric(y_true, y_pred, smooth=1e-6):
    """Soft F1 score over all elements of the batch, in float32.

    Precision and recall are computed from soft counts (each smoothed by
    ``smooth``) and combined as ``2 * P * R / (P + R + smooth)``.
    """
    truth = tf.reshape(tf.cast(y_true, tf.float32), [-1])
    pred = tf.reshape(tf.cast(y_pred, tf.float32), [-1])

    tp = tf.reduce_sum(truth * pred)
    pred_pos = tf.reduce_sum(pred)
    true_pos_total = tf.reduce_sum(truth)

    p = (tp + smooth) / (pred_pos + smooth)
    r = (tp + smooth) / (true_pos_total + smooth)

    harmonic = 2 * (p * r) / (p + r + smooth)
    return tf.cast(harmonic, tf.float32)

def precision_metric(y_true, y_pred, smooth=1e-6):
    """Soft precision over all elements of the batch, in float32.

    Computes ``(sum(t * p) + smooth) / (sum(p) + smooth)`` on the flattened
    tensors.
    """
    truth = tf.reshape(tf.cast(y_true, tf.float32), [-1])
    pred = tf.reshape(tf.cast(y_pred, tf.float32), [-1])

    tp = tf.reduce_sum(truth * pred)
    pred_pos = tf.reduce_sum(pred)

    return tf.cast((tp + smooth) / (pred_pos + smooth), tf.float32)

def recall_metric(y_true, y_pred, smooth=1e-6):
    """Soft recall over all elements of the batch, in float32.

    Computes ``(sum(t * p) + smooth) / (sum(t) + smooth)`` on the flattened
    tensors.
    """
    truth = tf.reshape(tf.cast(y_true, tf.float32), [-1])
    pred = tf.reshape(tf.cast(y_pred, tf.float32), [-1])

    tp = tf.reduce_sum(truth * pred)
    actual_pos = tf.reduce_sum(truth)

    return tf.cast((tp + smooth) / (actual_pos + smooth), tf.float32)

# Modified ResUNet architecture for 6-channel input (CHANGE DETECTION)
def conv_block(inputs, filters, kernel_size=3, strides=1, padding='same'):
    """Apply Conv2D -> BatchNormalization -> ReLU, every layer in float32."""
    convolved = layers.Conv2D(
        filters,
        kernel_size,
        strides=strides,
        padding=padding,
        dtype='float32',
    )(inputs)
    normalized = layers.BatchNormalization(dtype='float32')(convolved)
    return layers.ReLU(dtype='float32')(normalized)

def multi_scale_context_module(inputs):
    """Fuse the input with dilated-convolution and global-context branches.

    Branches: the input itself, three 3x3 convolutions with dilation rates 2,
    4 and 8 (each producing a quarter of the input channels), and a global
    branch (global average pooling followed by a Dense layer, tiled back to
    the spatial size). All branches are concatenated and projected back to the
    input channel count with a 1x1 convolution.
    """
    height, width, channels = inputs.shape[1], inputs.shape[2], inputs.shape[-1]
    quarter = channels // 4

    # Dilated branches, built in order of increasing dilation rate
    dilated_branches = [
        layers.Conv2D(quarter, 3, padding='same', dilation_rate=rate, activation='relu')(inputs)
        for rate in (2, 4, 8)
    ]

    # Global descriptor: pool -> dense -> repeat for every pixel -> reshape
    pooled = layers.GlobalAveragePooling2D()(inputs)
    squeezed = layers.Dense(quarter, activation='relu')(pooled)
    tiled = layers.RepeatVector(height * width)(squeezed)
    global_branch = layers.Reshape((height, width, quarter))(tiled)

    # Identity branch first, then dilated branches, then global context
    merged = layers.Concatenate()([inputs, *dilated_branches, global_branch])

    # 1x1 projection back to the original number of channels
    return layers.Conv2D(channels, 1, activation='relu', padding='same')(merged)

def channel_attention(inputs, ratio=16):
    """Re-weight the channels of ``inputs`` with a squeeze-and-excitation style gate.

    The input is global-average-pooled, reshaped to (1, 1, channels), squeezed
    by a 1x1 convolution (ELU), passed through dropout, expanded back by a
    sigmoid 1x1 convolution, and multiplied with the input.

    Args:
        inputs: Feature tensor whose last dimension is the channel count.
        ratio: Channel reduction factor for the squeeze step; the reduced
            width never drops below 8.

    Returns:
        ``inputs`` multiplied by the per-channel gate.
    """
    num_channels = int(inputs.shape[-1])
    squeezed_channels = max(num_channels // ratio, 8)

    gate = layers.GlobalAveragePooling2D(dtype='float32')(inputs)
    gate = layers.Reshape((1, 1, num_channels))(gate)

    squeeze = layers.Conv2D(squeezed_channels, kernel_size=1, activation='elu', padding='same', dtype='float32')
    gate = squeeze(gate)
    gate = layers.Dropout(0.1, dtype='float32')(gate)

    excite = layers.Conv2D(num_channels, kernel_size=1, activation='sigmoid', padding='same', dtype='float32')
    gate = excite(gate)

    return layers.Multiply(dtype='float32')([inputs, gate])

class ChannelMean(layers.Layer):
    """Layer that averages over the last (channel) axis, keeping that axis with size 1.

    Implemented as a named layer class rather than a Lambda layer.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def call(self, inputs):
        # keepdims=True keeps the reduced axis so the result stays rank-compatible with the input.
        return tf.reduce_mean(inputs, axis=-1, keepdims=True)

    def get_config(self):
        # No extra constructor arguments: the base layer config is sufficient.
        base_config = super().get_config()
        return base_config

class ChannelMax(layers.Layer):
    """Layer that takes the maximum over the last (channel) axis, keeping that axis with size 1.

    Implemented as a named layer class rather than a Lambda layer.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def call(self, inputs):
        # keepdims=True keeps the reduced axis so the result stays rank-compatible with the input.
        return tf.reduce_max(inputs, axis=-1, keepdims=True)

    def get_config(self):
        # No extra constructor arguments: the base layer config is sufficient.
        base_config = super().get_config()
        return base_config

def spatial_attention(inputs):
    """Re-weight every spatial position of ``inputs`` with a learned attention map.

    The channel-wise mean and max maps (ChannelMean / ChannelMax layers) are
    concatenated and passed through a 7x7 sigmoid convolution that yields a
    single-channel map; the input is multiplied by that map.

    Args:
        inputs: Feature tensor with channels on the last axis.

    Returns:
        ``inputs`` multiplied by the single-channel attention map.
    """
    pooled_maps = [
        ChannelMean(dtype='float32')(inputs),
        ChannelMax(dtype='float32')(inputs),
    ]
    stacked = layers.Concatenate(axis=-1, dtype='float32')(pooled_maps)

    gate_conv = layers.Conv2D(1, kernel_size=7, padding='same', activation='sigmoid', dtype='float32')
    gate = gate_conv(stacked)

    return layers.Multiply(dtype='float32')([inputs, gate])

def attention_residual_block(inputs, filters, kernel_size=3, strides=1, dropout_rate=0.1):
    """Residual block: two conv blocks with dropout, channel + spatial attention, skip add.

    Args:
        inputs: Input feature tensor.
        filters: Output channel count of the block.
        kernel_size: Kernel size of both conv blocks.
        strides: Stride of the first conv block (the second always uses 1).
        dropout_rate: Dropout rate applied between the two conv blocks.

    Returns:
        ReLU of (attended features + shortcut).
    """
    body = conv_block(inputs, filters, kernel_size, strides)
    body = layers.Dropout(dropout_rate)(body)
    body = conv_block(body, filters, kernel_size, 1)

    body = channel_attention(body)
    body = spatial_attention(body)

    # The identity shortcut is only valid when neither the resolution nor the
    # channel count changes; otherwise project with a strided 1x1 conv + BN.
    identity_ok = not (strides > 1 or inputs.shape[-1] != filters)
    if identity_ok:
        residual = inputs
    else:
        residual = layers.Conv2D(filters, kernel_size=1, strides=strides, padding='same')(inputs)
        residual = layers.BatchNormalization()(residual)

    merged = layers.Add()([body, residual])
    return layers.ReLU()(merged)

def build_change_detection_resunet(input_shape=(256, 256, 6), num_classes=1):
    """Build the attention ResUNet for change detection on a stacked image pair.

    Layout: a 3x3 conv + BN stem, a 7x7 conv block, three encoder stages
    (attention residual block followed by 2x max pooling), a bridge (attention
    residual block + multi-scale context module), three decoder stages
    (bilinear 2x upsampling, conv block, concatenation with the matching
    encoder output, attention residual block) and a 1x1 sigmoid output conv.

    Args:
        input_shape: Shape of one input sample; defaults to (256, 256, 6),
            i.e. pre-flood and post-flood images stacked on the channel axis.
        num_classes: Number of output channels of the final sigmoid layer.

    Returns:
        The (uncompiled) Keras model.
    """
    # (filters, dropout_rate) for encoder stages, shallow -> deep. The decoder
    # walks the same table in reverse so each stage mirrors its encoder stage.
    stage_settings = ((64, 0.1), (128, 0.1), (256, 0.15))

    inputs = layers.Input(input_shape, dtype='float32')

    # Stem: first convolution over all six channels, then batch normalization.
    stem = layers.Conv2D(32, kernel_size=3, padding='same', activation='relu', dtype='float32')(inputs)
    stem = layers.BatchNormalization(dtype='float32')(stem)

    # Large-kernel conv block before the encoder.
    x = conv_block(stem, 64, kernel_size=7, strides=1)

    # Encoder: keep each pre-pooling output for the skip connections.
    encoder_outputs = []
    for filters, rate in stage_settings:
        stage_output = simple_attention_residual_block(x, filters, dropout_rate=rate)
        encoder_outputs.append(stage_output)
        x = layers.MaxPooling2D(2, dtype='float32')(stage_output)

    # Bridge with multi-scale context.
    x = simple_attention_residual_block(x, 512, dropout_rate=0.2)
    x = multi_scale_context_module(x)

    # Decoder: deepest skip connection first.
    for (filters, rate), skip in zip(reversed(stage_settings), reversed(encoder_outputs)):
        x = layers.UpSampling2D(2, interpolation='bilinear', dtype='float32')(x)
        x = conv_block(x, filters)
        x = layers.Concatenate(dtype='float32')([x, skip])
        x = simple_attention_residual_block(x, filters, dropout_rate=rate)

    outputs = layers.Conv2D(num_classes, kernel_size=1, activation='sigmoid', dtype='float32')(x)

    return models.Model(inputs=inputs, outputs=outputs)

def simple_spatial_attention(inputs):
    """Spatial attention built from ChannelMean / ChannelMax layers (no Lambda layers).

    This function used to be a line-for-line copy of ``spatial_attention``; it
    now delegates to it so the two cannot drift apart. The name is kept for
    existing callers.

    Args:
        inputs: Feature tensor with channels on the last axis.

    Returns:
        ``inputs`` multiplied by a single-channel sigmoid attention map.
    """
    return spatial_attention(inputs)

def simple_attention_residual_block(inputs, filters, kernel_size=3, strides=1, dropout_rate=0.1):
    """Residual block with channel and spatial attention (no Lambda layers).

    This function used to duplicate ``attention_residual_block`` statement for
    statement (the only difference was calling ``simple_spatial_attention``,
    itself identical to ``spatial_attention``). It now delegates so there is a
    single implementation to maintain. The name and signature are kept for
    existing callers.

    Args:
        inputs: Input feature tensor.
        filters: Output channel count of the block.
        kernel_size: Kernel size of both conv blocks.
        strides: Stride of the first conv block.
        dropout_rate: Dropout rate applied between the two conv blocks.

    Returns:
        ReLU of (attended features + shortcut).
    """
    return attention_residual_block(
        inputs,
        filters,
        kernel_size=kernel_size,
        strides=strides,
        dropout_rate=dropout_rate,
    )

def build_simple_change_detection_resunet(input_shape=(256, 256, 6), num_classes=1):
    """Build a plain U-Net-style fallback model without attention or Lambda layers.

    Three encoder stages (3x3 conv + BN + ReLU, then 2x max pooling), a
    512-filter bridge, and three decoder stages (2x upsampling, concatenation
    with the matching encoder output, 3x3 conv + BN + ReLU), followed by a
    1x1 sigmoid output convolution.

    Args:
        input_shape: Shape of one input sample; defaults to (256, 256, 6).
        num_classes: Number of output channels of the final sigmoid layer.

    Returns:
        The (uncompiled) Keras model.
    """
    def conv_bn_relu(tensor, filters):
        # 3x3 'same' convolution followed by batch norm and ReLU, all float32.
        tensor = layers.Conv2D(filters, 3, padding='same', dtype='float32')(tensor)
        tensor = layers.BatchNormalization(dtype='float32')(tensor)
        return layers.ReLU(dtype='float32')(tensor)

    inputs = layers.Input(input_shape, dtype='float32')

    # Encoder: remember each stage's activation before pooling.
    x = inputs
    encoder_outputs = []
    for filters in (64, 128, 256):
        x = conv_bn_relu(x, filters)
        encoder_outputs.append(x)
        x = layers.MaxPooling2D(2, dtype='float32')(x)

    # Bridge
    x = conv_bn_relu(x, 512)

    # Decoder: deepest skip connection first.
    for filters, skip in zip((256, 128, 64), reversed(encoder_outputs)):
        x = layers.UpSampling2D(2, dtype='float32')(x)
        x = layers.Concatenate(dtype='float32')([x, skip])
        x = conv_bn_relu(x, filters)

    # Output
    outputs = layers.Conv2D(num_classes, 1, activation='sigmoid', dtype='float32')(x)

    return models.Model(inputs=inputs, outputs=outputs)

# Advanced Learning Rate Scheduling (same as before)
def warmup_cosine_schedule(epoch, warmup_epochs=15, total_epochs=150, base_lr=0.001, min_lr=1e-7):
    """Return the learning rate for ``epoch`` under linear warmup + cosine annealing.

    During the first ``warmup_epochs`` epochs the rate rises linearly and
    reaches ``base_lr`` on the last warmup epoch. Afterwards it follows half a
    cosine period from ``base_lr`` down to ``min_lr`` at ``total_epochs``.

    Args:
        epoch: Zero-based epoch index.
        warmup_epochs: Number of warmup epochs.
        total_epochs: Epoch index at which the rate reaches ``min_lr``.
        base_lr: Peak learning rate.
        min_lr: Floor learning rate.

    Returns:
        The learning rate as a float.
    """
    if epoch < warmup_epochs:
        return base_lr * (epoch + 1) / warmup_epochs

    decay_epochs = total_epochs - warmup_epochs
    if decay_epochs <= 0:
        # No annealing phase is configured; avoid dividing by zero and stay at
        # the floor once warmup is over.
        return min_lr

    # Clamp so epochs beyond total_epochs stay at min_lr instead of riding the
    # cosine back up.
    progress = min(1.0, (epoch - warmup_epochs) / decay_epochs)
    return min_lr + (base_lr - min_lr) * 0.5 * (1 + math.cos(math.pi * progress))

def get_optimized_callbacks(output_path, total_epochs=150):
    """Build the list of Keras callbacks used for training.

    The list contains, in order: the warmup + cosine learning-rate scheduler,
    a checkpoint of the best validation IoU model, a checkpoint of the best
    validation Dice model, early stopping on validation IoU, reduce-on-plateau
    on validation IoU, TensorBoard logging and a CSV logger.

    Args:
        output_path: Directory that receives checkpoints, TensorBoard logs and
            the CSV training log.
        total_epochs: Total number of epochs, forwarded to the cosine schedule.

    Returns:
        A list of callback instances to pass to ``model.fit``.
    """

    callbacks_list = [
        callbacks.LearningRateScheduler(
            lambda epoch: warmup_cosine_schedule(
                epoch,
                warmup_epochs=15,
                total_epochs=total_epochs,
                base_lr=0.001,
                min_lr=1e-7
            ),
            verbose=1
        ),

        callbacks.ModelCheckpoint(
            filepath=os.path.join(output_path, "best_iou_change_detection_model.keras"),
            monitor='val_iou_score',
            save_best_only=True,
            mode='max',
            verbose=1,
            save_weights_only=False
        ),

        # The Dice checkpoint must track the Dice metric; it previously
        # monitored 'val_iou_score', which made this file a duplicate of the
        # IoU checkpoint above.
        callbacks.ModelCheckpoint(
            filepath=os.path.join(output_path, "best_dice_change_detection_model.keras"),
            monitor='val_dice_coefficient',
            save_best_only=True,
            mode='max',
            verbose=1,
            save_weights_only=False
        ),

        callbacks.EarlyStopping(
            monitor='val_iou_score',
            patience=40,
            restore_best_weights=True,
            mode='max',
            verbose=1
        ),

        # NOTE(review): LearningRateScheduler assigns the learning rate at the
        # start of every epoch, so a reduction made here at epoch end appears
        # to be overwritten on the next epoch - confirm whether this callback
        # is still wanted alongside the scheduler.
        callbacks.ReduceLROnPlateau(
            monitor='val_iou_score',
            factor=0.5,
            patience=12,
            min_lr=1e-8,
            mode='max',
            verbose=1
        ),

        callbacks.TensorBoard(
            log_dir=os.path.join(output_path, "logs"),
            histogram_freq=1,
            update_freq='epoch',
            write_graph=True,
            write_images=True,
            profile_batch=0
        ),

        callbacks.CSVLogger(
            os.path.join(output_path, 'change_detection_training_log.csv'),
            separator=',',
            append=False
        )
    ]

    return callbacks_list

def plot_history(history):
    """Plot the training and validation curves of every tracked metric.

    Draws a 3x2 grid (loss, Dice, IoU, binary accuracy, F1, precision/recall),
    saves it as 'change_detection_training_history.png' under OUTPUT_PATH and
    shows the figure.

    Args:
        history: Object whose ``history`` attribute maps metric names to
            per-epoch value sequences (as returned by ``model.fit``).
    """
    # One entry per subplot: (title, y-axis label, [(history key, legend label), ...]).
    panels = [
        ('Loss', 'Loss', [
            ('loss', 'Training Loss'),
            ('val_loss', 'Validation Loss'),
        ]),
        ('Dice Coefficient', 'Dice', [
            ('dice_coefficient', 'Training Dice'),
            ('val_dice_coefficient', 'Validation Dice'),
        ]),
        ('IoU Score', 'IoU', [
            ('iou_score', 'Training IoU'),
            ('val_iou_score', 'Validation IoU'),
        ]),
        ('Binary Accuracy', 'Accuracy', [
            ('binary_accuracy', 'Training Accuracy'),
            ('val_binary_accuracy', 'Validation Accuracy'),
        ]),
        ('F1 Score', 'F1', [
            ('f1_score_metric', 'Training F1'),
            ('val_f1_score_metric', 'Validation F1'),
        ]),
        ('Precision and Recall', 'Score', [
            ('precision_metric', 'Training Precision'),
            ('val_precision_metric', 'Validation Precision'),
            ('recall_metric', 'Training Recall'),
            ('val_recall_metric', 'Validation Recall'),
        ]),
    ]

    plt.figure(figsize=(20, 15))

    for position, (title, y_label, curves) in enumerate(panels, start=1):
        plt.subplot(3, 2, position)
        for key, legend_label in curves:
            plt.plot(history.history[key], label=legend_label)
        plt.title(title)
        plt.xlabel('Epoch')
        plt.ylabel(y_label)
        plt.legend()
        plt.grid(True, linestyle='--', alpha=0.6)

    plt.tight_layout()
    figure_path = os.path.join(OUTPUT_PATH, 'change_detection_training_history.png')
    plt.savefig(figure_path, dpi=300, bbox_inches='tight')
    plt.show()

def visualize_change_detection_predictions(model, dataset, num_samples=5):
    """Plot predictions next to their inputs and ground truth for a few batches.

    Takes ``num_samples`` batches from ``dataset`` and draws one row per batch
    using the first element of that batch: pre-flood image, post-flood image,
    their absolute difference, the ground-truth mask and the predicted map
    (titled with the Dice score at a 0.5 threshold). The figure is saved as
    'change_detection_predictions.png' under OUTPUT_PATH and shown.

    Args:
        model: Model exposing ``predict``.
        dataset: Dataset yielding (images, masks) batches and supporting ``take``.
        num_samples: Number of batches, and therefore rows, to draw.
    """
    n_cols = 5
    plt.figure(figsize=(20, 5*num_samples))

    for row, (images, masks) in enumerate(dataset.take(num_samples)):
        preds = model.predict(images, verbose=0)

        # Only the first element of each batch is drawn; nothing to draw for an empty batch.
        if images.shape[0] < 1:
            continue

        image = images[0].numpy()
        mask = masks[0].numpy()
        pred = preds[0]

        # Channels 0-2 are the pre-flood image, the remaining channels the post-flood image.
        pre_flood = image[:, :, :3]
        post_flood = image[:, :, 3:]
        difference = np.abs(post_flood - pre_flood)

        # Dice of the thresholded mask and prediction for this sample.
        mask_binary = (mask > 0.5).astype(np.float32)
        pred_binary = (pred > 0.5).astype(np.float32)
        dice = np.sum(2 * mask_binary * pred_binary) / (np.sum(mask_binary) + np.sum(pred_binary) + 1e-8)

        # (image, title, extra imshow keyword arguments) for each column, left to right.
        columns = (
            (pre_flood, f"Pre-Flood - Sample {row+1}", {}),
            (post_flood, f"Post-Flood - Sample {row+1}", {}),
            (difference, f"Difference - Sample {row+1}", {}),
            (mask.squeeze(), "Ground Truth", {'cmap': 'Blues'}),
            (pred.squeeze(), f"Prediction (Dice={dice:.3f})", {'cmap': 'Blues'}),
        )
        for col, (picture, caption, imshow_kwargs) in enumerate(columns, start=1):
            plt.subplot(num_samples, n_cols, row * n_cols + col)
            plt.imshow(picture, **imshow_kwargs)
            plt.title(caption)
            plt.axis('off')

    plt.tight_layout()
    plt.savefig(os.path.join(OUTPUT_PATH, 'change_detection_predictions.png'))
    plt.show()

def calculate_metrics(model, dataset, num_batches=None, threshold=0.5):
    """Compute pixel-level and per-image segmentation metrics over a dataset.

    Predictions and masks are binarized with ``threshold``. For every image a
    confusion count (tp/fp/tn/fn) and the derived Dice, IoU, precision, recall,
    F1 and accuracy are recorded; the counts are also summed to obtain the
    overall pixel-level metrics. Results are printed, the per-image table is
    written to 'change_detection_per_image_metrics.csv' under OUTPUT_PATH, and
    a summary dictionary is returned.

    The pixel-level totals are kept as four running integers instead of
    collecting every pixel label in Python lists, which keeps memory use
    constant while producing the same counts.

    Args:
        model: Model exposing ``predict``.
        dataset: Iterable yielding (images, masks) batches of tensors.
        num_batches: Maximum number of batches to evaluate; ``None`` iterates
            the whole dataset.
        threshold: Cut-off applied to both predictions and masks.

    Returns:
        dict with keys 'overall' (accuracy, precision, recall, specificity,
        f1, iou, dice), 'per_image_mean' (dice, iou) and 'confusion_matrix'
        (tn, fp, fn, tp). Ratios are NaN when no image was evaluated.
    """
    eps = 1e-8  # keeps ratios finite when a denominator count is zero
    nan = float('nan')

    dice_scores = []
    iou_scores = []
    image_metrics = []

    # Running pixel-level confusion counts over every evaluated image.
    tp = fp = tn = fn = 0

    for batch_index, (images, masks) in enumerate(dataset):
        if num_batches is not None and batch_index >= num_batches:
            break

        preds = model.predict(images, verbose=0)

        for mask_tensor, pred in zip(masks, preds):
            mask_binary = mask_tensor.numpy().flatten() > threshold
            pred_binary = pred.flatten() > threshold

            # Per-image confusion counts.
            img_tp = int(np.count_nonzero(mask_binary & pred_binary))
            img_fp = int(np.count_nonzero(~mask_binary & pred_binary))
            img_fn = int(np.count_nonzero(mask_binary & ~pred_binary))
            n_pixels = int(mask_binary.size)
            img_tn = n_pixels - img_tp - img_fp - img_fn

            tp += img_tp
            fp += img_fp
            fn += img_fn
            tn += img_tn

            # sum(mask) + sum(pred) == 2*tp + fp + fn; union == tp + fp + fn.
            dice = (2. * img_tp) / (2 * img_tp + img_fp + img_fn + eps)
            iou = img_tp / (img_tp + img_fp + img_fn + eps)
            dice_scores.append(dice)
            iou_scores.append(iou)

            img_precision = img_tp / (img_tp + img_fp + eps)
            img_recall = img_tp / (img_tp + img_fn + eps)
            img_f1 = 2 * (img_precision * img_recall) / (img_precision + img_recall + eps)
            img_accuracy = (img_tp + img_tn) / n_pixels if n_pixels else nan

            image_metrics.append({
                'dice': dice,
                'iou': iou,
                'precision': img_precision,
                'recall': img_recall,
                'f1': img_f1,
                'accuracy': img_accuracy,
                'tp': img_tp,
                'fp': img_fp,
                'tn': img_tn,
                'fn': img_fn
            })

    # Overall metrics from the accumulated pixel counts.
    total_pixels = tp + tn + fp + fn
    precision = tp / (tp + fp + eps)
    recall = tp / (tp + fn + eps)
    f1 = 2 * (precision * recall) / (precision + recall + eps)
    accuracy = (tp + tn) / total_pixels if total_pixels else nan
    specificity = tn / (tn + fp + eps)
    iou = tp / (tp + fp + fn + eps)
    overall_dice = (2 * tp) / (2 * tp + fp + fn + eps)

    mean_dice = float(np.mean(dice_scores)) if dice_scores else nan
    mean_iou = float(np.mean(iou_scores)) if iou_scores else nan

    print("\n======== Change Detection Metrics ========")
    print(f"Overall Metrics (pixel-level):")
    print(f"Accuracy: {accuracy:.4f}")
    print(f"Precision: {precision:.4f}")
    print(f"Recall/Sensitivity: {recall:.4f}")
    print(f"Specificity: {specificity:.4f}")
    print(f"F1 Score: {f1:.4f}")
    print(f"IoU: {iou:.4f}")
    print(f"Dice Coefficient: {overall_dice:.4f}")

    print("\nMean Per-Image Metrics:")
    print(f"Mean Dice: {mean_dice:.4f}")
    print(f"Mean IoU: {mean_iou:.4f}")

    print("\nConfusion Matrix:")
    print(f"True Positives: {tp}")
    print(f"False Positives: {fp}")
    print(f"True Negatives: {tn}")
    print(f"False Negatives: {fn}")

    # Save per-image metrics to CSV
    img_metrics_df = pd.DataFrame(image_metrics)
    img_metrics_df.to_csv(os.path.join(OUTPUT_PATH, 'change_detection_per_image_metrics.csv'), index_label='image_id')
    print("\nPer-image metrics saved to CSV file")

    return {
        'overall': {
            'accuracy': float(accuracy),
            'precision': float(precision),
            'recall': float(recall),
            'specificity': float(specificity),
            'f1': float(f1),
            'iou': float(iou),
            'dice': float(overall_dice)
        },
        'per_image_mean': {
            'dice': float(mean_dice),
            'iou': float(mean_iou)
        },
        'confusion_matrix': {
            'tn': int(tn),
            'fp': int(fp),
            'fn': int(fn),
            'tp': int(tp)
        }
    }

# Main execution for Change Detection Training
# Run-wide settings: BATCH_SIZE and EPOCHS are reused below for dataset
# creation, step calculation, callbacks and model.fit.
BATCH_SIZE = 24
EPOCHS = 50
USE_AUGMENTATION = True

# Print a banner summarising the configuration of this run.
print("=== CHANGE DETECTION FLOOD TRAINING ===")
print(f"Configuration:")
print(f"- Input: Pre-flood + Post-flood images (6 channels)")
print(f"- Output: Flood detection mask")
print(f"- Batch Size: {BATCH_SIZE}")
print(f"- Total Epochs: {EPOCHS}")
print(f"- Data Augmentation: {USE_AUGMENTATION}")
print(f"- Architecture: Enhanced ResUNet for Change Detection")
print("="*50)

# Create datasets for change detection
# Each call is unpacked into a (dataset, size) pair; augmentation is requested
# only for the training split, validation and test are created without it.
train_dataset, train_size = create_change_detection_dataset(
    BASE_PATH, 'train', 
    batch_size=BATCH_SIZE, 
    use_augmentation=USE_AUGMENTATION
)
val_dataset, val_size = create_change_detection_dataset(
    BASE_PATH, 'val', 
    batch_size=BATCH_SIZE, 
    use_augmentation=False
)
test_dataset, test_size = create_change_detection_dataset(
    BASE_PATH, 'test', 
    batch_size=BATCH_SIZE, 
    use_augmentation=False
)

# The sizes are reported as numbers of image pairs.
print(f"Training dataset size: {train_size} image pairs")
print(f"Validation dataset size: {val_size} image pairs")
print(f"Test dataset size: {test_size} image pairs")

# Visualize some samples
visualize_change_detection_samples(train_dataset, num_samples=2)

# Build change detection model
input_shape = (256, 256, 6)  # 6 channels for pre+post flood

# Try the attention ResUNet first; if building it raises, report the error and
# fall back to the plain encoder-decoder variant.
print("Building Change Detection ResUNet model...")
try:
    model = build_change_detection_resunet(input_shape)
    print("✓ Enhanced Change Detection ResUNet model built successfully")
except Exception as e:
    print(f"⚠ Enhanced model failed: {e}")
    print("Falling back to simpler Change Detection ResUNet model...")
    model = build_simple_change_detection_resunet(input_shape)
    print("✓ Simple Change Detection ResUNet model built successfully")

# Get optimized optimizer
# AdamW with weight decay; clipnorm=1.0 enables gradient-norm clipping.
# NOTE(review): the learning rate given here is replaced each epoch by the
# LearningRateScheduler callback, assuming get_optimized_callbacks includes
# one - confirm.
optimizer = optimizers.AdamW(
    learning_rate=0.001,
    weight_decay=0.01,
    beta_1=0.9,
    beta_2=0.999,
    epsilon=1e-7,
    clipnorm=1.0
)

# Compile model
# The metric functions listed here determine the history keys plotted by
# plot_history and the 'val_*' names monitored by the callbacks.
model.compile(
    optimizer=optimizer,
    loss=optimized_focal_tversky_loss,
    metrics=[
        dice_coefficient,
        iou_score,
        'binary_accuracy',
        f1_score_metric,
        precision_metric,
        recall_metric
    ]
)

model.summary()

# Get callbacks
callbacks_list = get_optimized_callbacks(OUTPUT_PATH, EPOCHS)

# Calculate steps
# Whole batches per epoch, clamped to at least one step for small splits.
# NOTE(review): passing explicit step counts presumably means the datasets
# repeat indefinitely - verify against create_change_detection_dataset.
steps_per_epoch = max(1, train_size // BATCH_SIZE)
validation_steps = max(1, val_size // BATCH_SIZE)

print(f"Steps per epoch: {steps_per_epoch}")
print(f"Validation steps: {validation_steps}")

# Train the change detection model
history = model.fit(
    train_dataset,
    validation_data=val_dataset,
    epochs=EPOCHS,
    steps_per_epoch=steps_per_epoch,
    validation_steps=validation_steps,
    callbacks=callbacks_list,
    verbose=1
)

# Plot training history
plot_history(history)

# Load and evaluate best model
# Path of the checkpoint written by the validation-IoU ModelCheckpoint callback.
best_iou_model_path = os.path.join(OUTPUT_PATH, "best_iou_change_detection_model.keras")

# Custom objects dictionary with all custom functions and layers
# Maps the names of the custom loss, metrics and layer classes to their Python
# objects so load_model can resolve them.
custom_objects_dict = {
    'optimized_focal_tversky_loss': optimized_focal_tversky_loss,
    'dice_coefficient': dice_coefficient,
    'iou_score': iou_score,
    'f1_score_metric': f1_score_metric,
    'precision_metric': precision_metric,
    'recall_metric': recall_metric,
    'ChannelMean': ChannelMean,
    'ChannelMax': ChannelMax
}

try:
    # Try loading with safe mode first
    # (no safe_mode argument is passed, so this relies on load_model's default)
    best_iou_model = models.load_model(best_iou_model_path, custom_objects=custom_objects_dict)
    print("✓ Model loaded successfully with safe mode")
except ValueError as e:
    # Retry only when the error message mentions both "Lambda" and "unsafe";
    # any other ValueError is reported and re-raised.
    if "Lambda" in str(e) and "unsafe" in str(e):
        print("⚠ Model contains Lambda layers, loading with safe_mode=False...")
        # If the model contains Lambda layers, we need to disable safe mode
        # This is safe in our case since we know the model architecture
        # NOTE: this switches unsafe deserialization on through keras.config
        # rather than as an argument of this one load_model call.
        import keras
        keras.config.enable_unsafe_deserialization()
        best_iou_model = models.load_model(best_iou_model_path, custom_objects=custom_objects_dict)
        print("✓ Model loaded successfully with unsafe deserialization")
    else:
        print(f"✗ Failed to load model: {e}")
        raise e

# Evaluate on test set
print("Evaluating best IoU model on test set...")
test_steps = max(1, test_size // BATCH_SIZE)
# Ask evaluate() for a name -> value dict instead of zipping the result list
# with `metrics_names`: the two can disagree in length (on Keras 3 the
# compiled metrics may be reported as a single aggregate name), which would
# silently drop or mislabel values.
test_metrics = {
    metric_name: float(value)
    for metric_name, value in best_iou_model.evaluate(
        test_dataset, steps=test_steps, verbose=1, return_dict=True
    ).items()
}
# Kept for backward compatibility with code that expects the list form.
test_results = list(test_metrics.values())
print("\nChange Detection Test Results:")
for metric_name, value in test_metrics.items():
    print(f"{metric_name}: {value:.4f}")

# Save test results
with open(os.path.join(OUTPUT_PATH, 'change_detection_test_metrics.json'), 'w') as f:
    json.dump(test_metrics, f, indent=4)

# Visualize predictions
visualize_change_detection_predictions(best_iou_model, test_dataset, num_samples=5)

# Calculate detailed metrics
# Evaluate at most 15 batches but always at least one: with fewer test images
# than BATCH_SIZE the integer division is 0, and num_batches=0 would make
# calculate_metrics stop before evaluating anything.
test_batches_for_eval = max(1, min(15, test_size // BATCH_SIZE))
print(f"Calculating detailed metrics on {test_batches_for_eval} test batches...")
detailed_metrics = calculate_metrics(best_iou_model, test_dataset, num_batches=test_batches_for_eval)

# Save detailed metrics
with open(os.path.join(OUTPUT_PATH, 'change_detection_detailed_metrics.json'), 'w') as f:
    json.dump(detailed_metrics, f, indent=4)

# Save final model
best_iou_model.save(os.path.join(OUTPUT_PATH, 'final_change_detection_flood_model.keras'))

print("\n" + "="*50)
print("CHANGE DETECTION TRAINING COMPLETE!")
print("="*50)
print("Key Features:")
print("✓ 6-channel input (pre-flood + post-flood images)")
print("✓ Enhanced ResUNet architecture for change detection")
print("✓ Optimized for flood detection from temporal image pairs")
print("✓ Comprehensive evaluation metrics")
print("✓ Robust data augmentation for change detection")
print("="*50)
print(f"All outputs saved to: {OUTPUT_PATH}")
print("="*50)