In [None]:
#dataset preprocessing with newband1 and 2

!pip install rasterio

import os
import csv
import numpy as np
import rasterio
from PIL import Image
from tqdm import tqdm
import shutil
from skimage.transform import resize

class ResUNetPreprocessor:
    """Convert Sen1Floods11 hand-labelled Sentinel-1 chips into training arrays.

    For every (image, label) pair listed in a split CSV this class:

    * reads raster bands 1 and 2 of the image (treated as VH and VV),
    * derives three feature bands (see ``compute_bands``),
    * clips each band to its own 1st-99th percentile and resizes it to
      ``img_size`` x ``img_size``,
    * z-score normalises with per-band statistics from the training split,
    * writes a PNG preview plus the exact ``.npy`` array, and
    * writes a binary flood mask and a validity mask (label ``-1`` = invalid).

    NOTE: constructing an instance deletes ``output_path`` if it already exists.
    """

    def __init__(self, base_path, output_path, img_size=512):  # Changed from 256 to 512
        """
        Args:
            base_path: Root directory of the dataset (contains ``data/`` and ``splits/``).
            output_path: Destination directory; it is wiped and recreated.
            img_size: Side length in pixels that every image and mask is resized to.
        """
        self.base_path = base_path
        self.output_path = output_path
        self.img_size = img_size

        # Filled in by calculate_normalization_params(); while None,
        # process_image() saves un-normalised values.
        self.norm_means = None
        self.norm_stds = None

        self.create_output_dirs()

        # Dataset layout relative to base_path.
        self.image_dir = os.path.join('data', 'flood_events', 'HandLabeled', 'S1Hand')
        self.mask_dir = os.path.join('data', 'flood_events', 'HandLabeled', 'LabelHand')
        self.splits_dir = os.path.join('splits', 'flood_handlabeled')

        # Per-split counters; pixel counts only cover valid (label != -1) pixels.
        self.stats = {
            split: {'count': 0, 'flood_pixels': 0, 'total_valid_pixels': 0, 'invalid_pixels': 0}
            for split in ('train', 'val', 'test')
        }

    def create_output_dirs(self):
        """Recreate ``output_path`` with images/masks/validity_masks dirs per split.

        Any existing ``output_path`` tree is deleted first.
        """
        if os.path.exists(self.output_path):
            shutil.rmtree(self.output_path)

        for split in ['train', 'val', 'test']:
            for sub in ('images', 'masks', 'validity_masks'):
                os.makedirs(os.path.join(self.output_path, split, sub), exist_ok=True)

        print(f"Created output directories at {self.output_path}")

    @staticmethod
    def _is_data_row(row):
        """Heuristic: a data row names a GeoTIFF in its first column."""
        return row[0].strip().lower().endswith(('.tif', '.tiff'))

    def read_csv_file(self, csv_path):
        """Return the ``(image_filename, mask_filename)`` pairs listed in a split CSV.

        The first row is treated as a header only when it does not look like a
        data row (its first cell is not a ``.tif``/``.tiff`` name). Header-less
        split files therefore keep their first sample, while files with a header
        behave as before. Rows with fewer than two columns (e.g. blank lines)
        are ignored.

        Raises:
            FileNotFoundError: if ``csv_path`` does not exist.
        """
        if not os.path.exists(csv_path):
            raise FileNotFoundError(f"CSV file not found: {csv_path}")

        with open(csv_path, 'r', newline='') as f:
            rows = [row for row in csv.reader(f) if len(row) >= 2]

        if rows and not self._is_data_row(rows[0]):
            rows = rows[1:]  # Skip header
        return [(row[0], row[1]) for row in rows]

    def compute_bands(self, vh, vv):
        """
        Compute the three bands according to the specified formulas:
        Band 1: VV
        Band 2: NewBand1 = (VH - VV) / (VH + VV)
        Band 3: NewBand2 = sqrt((VH^2 + VV^2) / 2)

        A small epsilon is added to the NewBand1 denominator only. When
        ``VH + VV`` is close to zero, NewBand1 can still become very large;
        the percentile clipping applied afterwards bounds such outliers.
        """
        eps = 1e-8
        band1 = vv
        band2 = np.divide(vh - vv, vh + vv + eps)
        band3 = np.sqrt((vh**2 + vv**2) / 2)
        return band1, band2, band3

    def _read_clip_resize(self, im_path):
        """Read one chip and return its clipped, resized bands as ``(3, img_size, img_size)``.

        Shared by both passes so the statistics and the saved data always go
        through exactly the same pipeline. Exceptions from reading or resizing
        propagate to the caller.
        """
        with rasterio.open(im_path) as src:
            # Band order assumed to be (VH, VV), as in the original pipeline.
            vh = src.read(1)
            vv = src.read(2)

        # Map NaN and +/-inf to 0. The nan_to_num default would turn inf into
        # float max, which overflows in vh**2 and poisons the percentiles below.
        vh = np.nan_to_num(vh, nan=0.0, posinf=0.0, neginf=0.0)
        vv = np.nan_to_num(vv, nan=0.0, posinf=0.0, neginf=0.0)

        arr_x = np.stack(self.compute_bands(vh, vv), axis=0)

        # Clip extreme values per image and per band (common in SAR preprocessing).
        for i in range(arr_x.shape[0]):
            v_min, v_max = np.percentile(arr_x[i], [1, 99])
            arr_x[i] = np.clip(arr_x[i], v_min, v_max)

        target = (self.img_size, self.img_size)
        return np.stack([resize(band, target, preserve_range=True) for band in arr_x], axis=0)

    def process_image_for_stats(self, im_path):
        """First pass: return the clipped, resized ``(3, H, W)`` array, or None on error."""
        try:
            return self._read_clip_resize(im_path)
        except Exception as e:
            print(f"Error processing image for stats {im_path}: {str(e)}")
            return None

    @staticmethod
    def _merge_moments(count, mean, m2, values):
        """Fold ``values`` into running ``(count, mean, M2)`` moments.

        Uses the parallel update of Chan et al.; ``M2`` is the sum of squared
        deviations from the mean, so the population variance is ``M2 / count``.
        """
        values = np.asarray(values, dtype=np.float64).ravel()
        n_b = values.size
        if n_b == 0:
            return count, mean, m2
        mean_b = values.mean()
        m2_b = np.square(values - mean_b).sum()
        total = count + n_b
        delta = mean_b - mean
        new_mean = mean + delta * n_b / total
        new_m2 = m2 + m2_b + delta * delta * count * n_b / total
        return total, new_mean, new_m2

    def calculate_normalization_params(self, train_csv_path):
        """Compute per-band mean and population std over the training split.

        Moments are accumulated image by image. Memory therefore does not grow
        with the number of images; concatenating every pixel would. Bands with
        no data, zero spread or a non-finite spread fall back to std 1.0 (and
        mean 0.0 when empty), so normalisation never divides by zero. The
        parameters are stored on the instance and saved to
        ``<output_path>/normalization_params.npy``.
        """
        print("Calculating normalization parameters from training data...")

        file_pairs = self.read_csv_file(train_csv_path)

        counts = np.zeros(3, dtype=np.float64)
        means = np.zeros(3, dtype=np.float64)
        m2s = np.zeros(3, dtype=np.float64)

        for im_fname, _ in tqdm(file_pairs, desc="Collecting statistics"):
            im_path = os.path.join(self.base_path, self.image_dir, im_fname)
            if not os.path.exists(im_path):
                continue

            arr_x = self.process_image_for_stats(im_path)
            if arr_x is None:
                continue
            for i in range(3):
                counts[i], means[i], m2s[i] = self._merge_moments(counts[i], means[i], m2s[i], arr_x[i])

        has_data = counts > 0
        variances = np.divide(m2s, counts, out=np.zeros_like(m2s), where=has_data)
        stds = np.sqrt(variances)
        stds = np.where(has_data & np.isfinite(stds) & (stds > 0), stds, 1.0)

        self.norm_means = means
        self.norm_stds = stds

        print(f"Calculated normalization parameters for {self.img_size}×{self.img_size} images:")
        print(f"Band 1 (VV): mean={self.norm_means[0]:.4f}, std={self.norm_stds[0]:.4f}")
        print(f"Band 2 (NewBand1): mean={self.norm_means[1]:.4f}, std={self.norm_stds[1]:.4f}")
        print(f"Band 3 (NewBand2): mean={self.norm_means[2]:.4f}, std={self.norm_stds[2]:.4f}")

        norm_params_path = os.path.join(self.output_path, 'normalization_params.npy')
        np.save(norm_params_path, {'means': self.norm_means, 'stds': self.norm_stds})
        print(f"Normalization parameters saved to: {norm_params_path}")

    def process_image(self, im_path):
        """Second pass: return ``(normalised HWC array, [0, 1] preview)``.

        Returns ``(None, None)`` if the image cannot be processed. The preview
        is min-max scaled jointly over all three channels.
        """
        try:
            arr_x = self._read_clip_resize(im_path)

            if self.norm_means is not None and self.norm_stds is not None:
                arr_x = (arr_x - self.norm_means.reshape(3, 1, 1)) / self.norm_stds.reshape(3, 1, 1)

            arr_x = np.transpose(arr_x, (1, 2, 0))  # CHW -> HWC

            eps = 1e-8
            arr_x_viz = (arr_x - arr_x.min()) / (arr_x.max() - arr_x.min() + eps)
            return arr_x, arr_x_viz

        except Exception as e:
            print(f"Error processing image {im_path}: {str(e)}")
            return None, None

    def process_mask(self, mask_path):
        """
        Process mask preserving -1 values and creating validity mask
        Returns: (ground_truth_mask, validity_mask), or (None, None) on error.

        ground_truth_mask is 1 where the label is > 0 and 0 elsewhere (including
        -1). validity_mask is 1 where the label is not -1.
        """
        try:
            with rasterio.open(mask_path) as src:
                arr_y = src.read(1)

            # Nearest neighbour with anti-aliasing disabled, so labels are never
            # blended even when downsampling.
            arr_y = resize(arr_y, (self.img_size, self.img_size), order=0,
                           preserve_range=True, anti_aliasing=False)

            validity_mask = (arr_y != -1).astype(np.uint8)
            # -1 is not > 0, so invalid pixels become 0 here; the validity mask
            # records that they should be ignored during training.
            ground_truth_mask = (arr_y > 0).astype(np.uint8)

            return ground_truth_mask, validity_mask

        except Exception as e:
            print(f"Error processing mask {mask_path}: {str(e)}")
            return None, None

    def save_png(self, arr, save_path, mode='RGB'):
        """Save an array with values in [0, 1] as an 8-bit PNG (values are scaled by 255)."""
        img = Image.fromarray((arr * 255).astype(np.uint8), mode=mode)
        img.save(save_path)

    def save_npy(self, arr, save_path):
        """Save raw array data as NPY file for preserving exact values"""
        np.save(save_path, arr)

    def update_stats(self, split, ground_truth_mask, validity_mask):
        """Accumulate sample and pixel counts for ``split`` (flood pixels counted on valid pixels only)."""
        valid_pixels = validity_mask.astype(bool)
        split_stats = self.stats[split]

        split_stats['count'] += 1
        split_stats['flood_pixels'] += int(ground_truth_mask[valid_pixels].sum())
        split_stats['total_valid_pixels'] += int(valid_pixels.sum())
        split_stats['invalid_pixels'] += int((~valid_pixels).sum())

    def process_dataset(self, split_name, csv_path):
        """Process every pair listed in ``csv_path`` and write outputs for ``split_name``.

        Sample ``idx`` is written as ``<split>_<idx:04d>.png/.npy`` into
        ``images/``, ``masks/`` and ``validity_masks/``. A sample is written
        only if BOTH its image and its mask were processed successfully, so
        the directories always contain matching file names.
        """
        print(f"Processing {split_name} dataset with {self.img_size}×{self.img_size} resolution...")
        file_pairs = self.read_csv_file(csv_path)
        output_dir = os.path.join(self.output_path, split_name)

        for idx, (im_fname, mask_fname) in enumerate(tqdm(file_pairs, desc=f"Processing {split_name}")):
            im_path = os.path.join(self.base_path, self.image_dir, im_fname)
            mask_path = os.path.join(self.base_path, self.mask_dir, mask_fname)

            if not os.path.exists(im_path) or not os.path.exists(mask_path):
                print(f"Warning: Files not found - {im_path} or {mask_path}")
                continue

            arr_x, arr_x_viz = self.process_image(im_path)
            ground_truth_mask, validity_mask = self.process_mask(mask_path)
            # Both helpers return (None, None) on failure. Skip the whole pair so
            # images/ and masks/ stay aligned file-for-file.
            if arr_x is None or ground_truth_mask is None:
                print(f"Warning: Skipping {im_fname} - image or mask could not be processed")
                continue

            stem = f'{split_name}_{idx:04d}'

            # Image: PNG preview plus exact normalised values.
            self.save_png(arr_x_viz, os.path.join(output_dir, 'images', f'{stem}.png'), mode='RGB')
            self.save_npy(arr_x, os.path.join(output_dir, 'images', f'{stem}.npy'))

            # Masks: PNG previews plus exact values.
            self.save_png(ground_truth_mask, os.path.join(output_dir, 'masks', f'{stem}.png'), mode='L')
            self.save_png(validity_mask, os.path.join(output_dir, 'validity_masks', f'{stem}.png'), mode='L')
            self.save_npy(ground_truth_mask, os.path.join(output_dir, 'masks', f'{stem}.npy'))
            self.save_npy(validity_mask, os.path.join(output_dir, 'validity_masks', f'{stem}.npy'))

            self.update_stats(split_name, ground_truth_mask, validity_mask)

    def print_stats(self):
        """Print per-split sample counts and valid/invalid/flood pixel percentages."""
        print(f"\nDataset Statistics for {self.img_size}×{self.img_size} images (Only Valid Pixels):")
        print("=" * 70)
        for split, stat in self.stats.items():
            if stat['count'] > 0:
                total_pixels = stat['total_valid_pixels'] + stat['invalid_pixels']
                flood_percentage = 100 * stat['flood_pixels'] / stat['total_valid_pixels'] if stat['total_valid_pixels'] > 0 else 0
                invalid_percentage = 100 * stat['invalid_pixels'] / total_pixels if total_pixels > 0 else 0
                print(f"{split.upper()} set: {stat['count']} samples")
                print(f"  Valid pixels: {stat['total_valid_pixels']:,} ({100-invalid_percentage:.1f}%)")
                print(f"  Invalid pixels: {stat['invalid_pixels']:,} ({invalid_percentage:.1f}%)")
                print(f"  Flood pixels (of valid): {stat['flood_pixels']:,} ({flood_percentage:.2f}%)")
                print()
        print("=" * 70)

    def calculate_class_weights(self):
        """Return ``[w_non_flood, w_flood]`` from the training split's valid pixels.

        ``w_non_flood`` is 1.0 and ``w_flood`` is ``neg_ratio / pos_ratio``.
        The result is ``[1.0, 1.0]`` when there are no valid training pixels
        or no flood pixels.
        """
        train = self.stats['train']
        if train['total_valid_pixels'] > 0:
            pos_ratio = train['flood_pixels'] / train['total_valid_pixels']
            neg_ratio = 1 - pos_ratio

            weight_non_flood = 1.0
            weight_flood = neg_ratio / pos_ratio if pos_ratio > 0 else 1.0

            print(f"\nClass weights for handling imbalance (valid pixels only, {self.img_size}×{self.img_size}):")
            print(f"Weight for non-flood (0): {weight_non_flood:.4f}")
            print(f"Weight for flood (1): {weight_flood:.4f}")

            return np.array([weight_non_flood, weight_flood])
        return np.array([1.0, 1.0])

# Kaggle environment setup: define the paths first so the banner below always
# reports the values that are actually used.
base_path = "/kaggle/input/sen1floods11-essentials/v1.2"
output_path = "/kaggle/working/preprocessed_512"  # Updated output path

print("Input path:", base_path)
print("Output path:", output_path)
print("\nBand Configuration:")
print("\n".join((
    "Band 1: VV",
    "Band 2: NewBand1 = (VH - VV) / (VH + VV)",
    "Band 3: NewBand2 = sqrt((VH² + VV²) / 2)",
)))

# Split CSVs live under <base_path>/splits/flood_handlabeled/.
train_csv_path = os.path.join(base_path, 'splits', 'flood_handlabeled', 'flood_train_data.csv')
val_csv_path = os.path.join(base_path, 'splits', 'flood_handlabeled', 'flood_val_data.csv')
test_csv_path = os.path.join(base_path, 'splits', 'flood_handlabeled', 'flood_test_data.csv')

# Note: constructing the preprocessor wipes and recreates output_path.
preprocessor = ResUNetPreprocessor(base_path, output_path, img_size=512)

# Pass 1: per-band normalisation statistics come from the training split only.
preprocessor.calculate_normalization_params(train_csv_path)

print("\nStarting dataset preprocessing with 512×512 resolution and calculated normalization parameters...")

# Pass 2: write every split using the training statistics.
preprocessor.process_dataset('train', train_csv_path)
preprocessor.process_dataset('val', val_csv_path)
preprocessor.process_dataset('test', test_csv_path)

# Summaries and class weights (weights are derived from the training split).
preprocessor.print_stats()
weights = preprocessor.calculate_class_weights()

weights_path = os.path.join(output_path, 'class_weights.npy')
np.save(weights_path, weights)
print(f"\nClass weights saved to: {weights_path}")

print(f"\nPreprocessing complete! 512×512 processed data saved to: {output_path}")
print("\nOutput structure:")
print("\n".join(
    ["preprocessed_512/"]
    + [
        tree_line
        for split_label in ("train", "val", "test")
        for tree_line in (
            f"├── {split_label}/",
            "│   ├── images/ (PNG and NPY files - 512×512)",
            "│   ├── masks/ (PNG and NPY files - ground truth - 512×512)",
            "│   └── validity_masks/ (PNG and NPY files - valid pixel indicators - 512×512)",
        )
    ]
    + ["├── class_weights.npy", "└── normalization_params.npy"]
))

print("\n".join((
    "\n MEMORY USAGE NOTE:",
    "512×512 images use 4x more memory than 256×256",
    "You may need to reduce batch size during training",
    "Recommended batch size: 16-24 instead of 32 depending on GPU memory",
)))

In [None]:
# UNet for 512x512 - Standard Architecture
# Optimized for Tesla P100 16GB GPU

import os
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers, models, callbacks, optimizers
from tensorflow.keras import mixed_precision
import matplotlib.pyplot as plt
import glob
import json
from PIL import Image
import gc

# Enable mixed precision: layers created without an explicit dtype compute in
# float16 while keeping float32 variables. Must run before any layer is built.
policy = mixed_precision.Policy('mixed_float16')
mixed_precision.set_global_policy(policy)

# Configuration
INPUT_SIZE = 512  # must match img_size used by the preprocessing step
# NOTE(review): the preprocessing step saves 3 bands [VV, NewBand1, NewBand2];
# the loader keeps only the first INPUT_CHANNELS of them, so with 2 the model
# sees VV and NewBand1 and NewBand2 is dropped. Confirm this is intended.
INPUT_CHANNELS = 2
BATCH_SIZE = 2
EPOCHS = 200  # upper bound; EarlyStopping may end training earlier
LEARNING_RATE = 1e-4  # initial Adam learning rate (ReduceLROnPlateau may lower it)

# Paths
DATA_PATH = "/kaggle/working/preprocessed_512"  # output directory of the preprocessing step
OUTPUT_PATH = "/kaggle/working/unet_standard"  # checkpoints, logs and results go here
os.makedirs(OUTPUT_PATH, exist_ok=True)

# Set random seeds: NumPy drives the generator's shuffling/augmentation,
# TensorFlow's global seed affects TF random ops (e.g. weight initialisation).
np.random.seed(42)
tf.random.set_seed(42)

def check_gpu_memory():
    """Report how many GPUs TensorFlow can see.

    Despite the name, no memory figures are queried; only the list of
    physical GPU devices is inspected.

    Returns:
        True if at least one GPU is visible. False if none is visible or the
        device query itself raises.
    """
    try:
        devices = tf.config.list_physical_devices('GPU')
    except Exception as e:
        print(f"Could not check GPU: {e}")
        return False

    if not devices:
        print("No GPU detected")
        return False

    print(f"GPU detected: {len(devices)} device(s)")
    return True

gpu_available = check_gpu_memory()

def load_data_batch_by_batch(base_path, split, max_samples=None, channels=None):
    """Load a preprocessed split into two float16 arrays.

    Images are read from ``<base_path>/<split>/images/*.npy`` in sorted order.
    Each image is paired with the file of the same name in
    ``<base_path>/<split>/masks/``. Pairing by name, not by position in two
    independently sorted lists, means a missing file cannot silently shift
    every later image/mask pair.

    Args:
        base_path: Root directory written by the preprocessing step.
        split: Split sub-directory, e.g. ``'train'``.
        max_samples: If truthy, only the first ``max_samples`` images (sorted
            by file name) are loaded.
        channels: Number of leading channels to keep from each HWC image.
            Defaults to the module-level ``INPUT_CHANNELS``.

    Returns:
        ``(images, masks)`` as float16 arrays of shape
        ``(N, H, W, channels)`` and ``(N, *mask_shape)``.

    Raises:
        ValueError: if the split contains no images.
        FileNotFoundError: if an image has no mask with the same file name.
    """
    if channels is None:
        channels = INPUT_CHANNELS

    print(f"Loading {split} data from {base_path}/{split}")

    img_files = sorted(glob.glob(os.path.join(base_path, split, 'images', '*.npy')))
    if len(img_files) == 0:
        raise ValueError(f"No data found in {base_path}/{split}")

    if max_samples:
        img_files = img_files[:max_samples]

    mask_dir = os.path.join(base_path, split, 'masks')
    mask_files = [os.path.join(mask_dir, os.path.basename(f)) for f in img_files]
    missing = [f for f in mask_files if not os.path.exists(f)]
    if missing:
        raise FileNotFoundError(
            f"{len(missing)} image(s) in {base_path}/{split} have no matching mask, e.g. {missing[0]}"
        )

    print(f"Processing {len(img_files)} files for {split}")

    # Preallocate from the first sample's shape and fill in place. Collecting
    # Python lists and then calling np.array() briefly holds two full copies
    # of the split in memory.
    n_samples = len(img_files)
    images = masks = None
    for i, (img_file, mask_file) in enumerate(zip(img_files, mask_files)):
        img = np.load(img_file)[:, :, :channels]
        mask = np.load(mask_file)
        if images is None:
            images = np.empty((n_samples,) + img.shape, dtype=np.float16)
            masks = np.empty((n_samples,) + mask.shape, dtype=np.float16)
        images[i] = img  # casts to float16 on assignment
        masks[i] = mask

    print(f"{split} shape: images {images.shape}, masks {masks.shape}")
    return images, masks

class MemoryEfficientGenerator(tf.keras.utils.Sequence):
    """Keras ``Sequence`` that serves mini-batches from in-memory arrays.

    The arrays may be stored compactly (e.g. float16); each batch is cast to
    float32 on the fly. Masks gain a trailing channel axis. With
    ``augment=True``, each image/mask pair is independently flipped
    left-right and/or up-down with probability 0.5 each.
    """

    def __init__(self, images, masks, batch_size=2, shuffle=True, augment=False, **kwargs):
        """
        Args:
            images: Sample array indexable with an integer index array; assumed
                ``(N, H, W, C)`` as produced by the loader in this cell.
            masks: Per-sample masks aligned with ``images`` (same length).
            batch_size: Samples per batch; the last batch may be smaller.
            shuffle: Reshuffle the sample order now and after every epoch.
            augment: Apply random flips in ``__getitem__``.
            **kwargs: Forwarded to the Keras base class (e.g. ``workers`` for
                Keras 3's ``PyDataset``).

        Raises:
            ValueError: if ``images`` and ``masks`` differ in length.
        """
        # Keras 3 (PyDataset) expects subclasses to call the base initialiser;
        # with no kwargs this is also harmless under tf.keras 2.x.
        super().__init__(**kwargs)
        if len(images) != len(masks):
            raise ValueError(f"images ({len(images)}) and masks ({len(masks)}) must have the same length")
        self.images = images
        self.masks = masks
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.augment = augment
        self.indices = np.arange(len(self.images))
        self.on_epoch_end()

    def __len__(self):
        """Number of batches per epoch (ceil of samples / batch_size)."""
        return int(np.ceil(len(self.images) / self.batch_size))

    def __getitem__(self, index):
        """Return ``(images, masks[..., None])`` for batch ``index`` as float32."""
        start_idx = index * self.batch_size
        end_idx = min((index + 1) * self.batch_size, len(self.images))
        batch_indices = self.indices[start_idx:end_idx]

        batch_images = self.images[batch_indices].astype(np.float32)
        batch_masks = self.masks[batch_indices].astype(np.float32)

        if self.augment:
            batch_images, batch_masks = self.augment_batch(batch_images, batch_masks)

        # Add a channel axis so masks line up with a single-channel output.
        batch_masks = np.expand_dims(batch_masks, axis=-1)

        return batch_images, batch_masks

    def augment_batch(self, images, masks):
        """Apply the same random horizontal/vertical flips to each image and its mask."""
        augmented_images = []
        augmented_masks = []

        for img, mask in zip(images, masks):
            if np.random.random() > 0.5:
                img = np.fliplr(img)
                mask = np.fliplr(mask)

            if np.random.random() > 0.5:
                img = np.flipud(img)
                mask = np.flipud(mask)

            augmented_images.append(img)
            augmented_masks.append(mask)

        # np.array copies the (possibly negatively strided) flipped views.
        return np.array(augmented_images), np.array(augmented_masks)

    def on_epoch_end(self):
        """Reshuffle the sample order in place when shuffling is enabled."""
        if self.shuffle:
            np.random.shuffle(self.indices)

def _conv_block(x, filters):
    """Two 3x3 ReLU convolutions with 'same' padding (one U-Net stage)."""
    x = layers.Conv2D(filters, 3, activation='relu', padding='same')(x)
    return layers.Conv2D(filters, 3, activation='relu', padding='same')(x)


def build_unet(input_shape=(512, 512, 2)):
    """Standard U-Net architecture.

    Four encoder stages (64-512 filters) with 2x2 max-pooling, a 1024-filter
    bottleneck, and four decoder stages that upsample with 2x2 transposed
    convolutions and concatenate the matching encoder output. A 1x1 sigmoid
    convolution produces a single-channel probability map.

    Hidden layers follow the global Keras dtype policy, so the
    ``mixed_float16`` policy set in this cell actually takes effect. Pinning
    every layer to float32 silently disables it. The final layer is kept in
    float32 so probabilities and the loss are computed in full precision.

    Args:
        input_shape: ``(H, W, C)``; H and W must be divisible by 16.

    Returns:
        An uncompiled ``tf.keras.Model`` named ``'UNet'``.
    """
    inputs = layers.Input(shape=input_shape, dtype='float32')

    # Encoder: keep each stage's output for the skip connections.
    skips = []
    x = inputs
    for filters in (64, 128, 256, 512):
        x = _conv_block(x, filters)
        skips.append(x)
        x = layers.MaxPooling2D(pool_size=(2, 2))(x)

    # Bottom
    x = _conv_block(x, 1024)

    # Decoder: mirror the encoder, deepest skip first.
    for filters, skip in zip((512, 256, 128, 64), reversed(skips)):
        x = layers.Conv2DTranspose(filters, 2, strides=(2, 2), padding='same')(x)
        x = layers.concatenate([x, skip])
        x = _conv_block(x, filters)

    outputs = layers.Conv2D(1, 1, activation='sigmoid', dtype='float32')(x)

    return models.Model(inputs=inputs, outputs=outputs, name='UNet')

print("Loading datasets...")
train_images, train_masks = load_data_batch_by_batch(DATA_PATH, 'train')
val_images, val_masks = load_data_batch_by_batch(DATA_PATH, 'val')
test_images, test_masks = load_data_batch_by_batch(DATA_PATH, 'test')

print(f"\nDataset sizes:")
print(f"Train: {train_images.shape[0]} samples")
print(f"Val: {val_images.shape[0]} samples")
print(f"Test: {test_images.shape[0]} samples")

train_gen = MemoryEfficientGenerator(train_images, train_masks, BATCH_SIZE, shuffle=True, augment=True)
val_gen = MemoryEfficientGenerator(val_images, val_masks, BATCH_SIZE, shuffle=False, augment=False)
test_gen = MemoryEfficientGenerator(test_images, test_masks, BATCH_SIZE, shuffle=False, augment=False)

print(f"\nBatches per epoch: Train={len(train_gen)}, Val={len(val_gen)}")

print("\nBuilding U-Net model...")
model = build_unet(input_shape=(INPUT_SIZE, INPUT_SIZE, INPUT_CHANNELS))

print(f"Model parameters: {model.count_params():,}")

# Loss scaling guards float16 gradients against underflow under mixed precision.
optimizer = mixed_precision.LossScaleOptimizer(
    optimizers.Adam(learning_rate=LEARNING_RATE)
)

model.compile(
    optimizer=optimizer,
    loss='binary_crossentropy',
    # Explicit metric objects avoid relying on string-alias lookup, which
    # differs between Keras versions; the names keep the logged keys stable.
    metrics=[
        'accuracy',
        tf.keras.metrics.Precision(name='precision'),
        tf.keras.metrics.Recall(name='recall'),
    ],
)

checkpoint_path = os.path.join(OUTPUT_PATH, "best_model.keras")
log_dir = os.path.join(OUTPUT_PATH, "logs")
os.makedirs(log_dir, exist_ok=True)

callbacks_list = [
    callbacks.ModelCheckpoint(
        filepath=checkpoint_path,
        monitor='val_accuracy',
        save_best_only=True,
        mode='max',
        verbose=1
    ),
    callbacks.EarlyStopping(
        monitor='val_accuracy',
        patience=30,
        restore_best_weights=True,
        mode='max',
        verbose=1
    ),
    callbacks.ReduceLROnPlateau(
        monitor='val_accuracy',
        factor=0.5,
        patience=10,
        min_lr=1e-7,
        mode='max',  # explicit; accuracy improves upwards
        verbose=1
    ),
    callbacks.CSVLogger(
        os.path.join(OUTPUT_PATH, 'training_log.csv')
    ),
    callbacks.LambdaCallback(
        on_epoch_end=lambda epoch, logs: gc.collect()
    )
]

print("\nStarting training...")

try:
    history = model.fit(
        train_gen,
        validation_data=val_gen,
        epochs=EPOCHS,
        callbacks=callbacks_list,
        verbose=1
    )
    print("Training completed")

except Exception as e:
    print(f"Training failed: {e}")
    raise

print("\nEvaluating model...")

best_model = models.load_model(checkpoint_path)
# return_dict=True pairs each value with its own metric name. Zipping
# metrics_names with the returned list can mislabel results under Keras 3,
# where metrics_names may collapse compiled metrics into 'compile_metrics'.
test_results = best_model.evaluate(test_gen, verbose=1, return_dict=True)

print(f"\nTest Results:")
for metric_name, value in test_results.items():
    print(f"{metric_name}: {value:.4f}")

test_metrics = {
    'configuration': 'standard_unet',
    'input_size': INPUT_SIZE,
    'input_channels': INPUT_CHANNELS,
    'batch_size': BATCH_SIZE,
    'mixed_precision': True,
    'model_parameters': int(model.count_params())
}
test_metrics.update({metric_name: float(value) for metric_name, value in test_results.items()})

with open(os.path.join(OUTPUT_PATH, 'test_results.json'), 'w') as f:
    json.dump(test_metrics, f, indent=4)

def visualize_sample_predictions(model, images, masks, num_samples=3):
    """Plot input channels, ground truth and predicted probability for a few samples.

    One row per sample. The columns show channel 0, channel 1, the ground-truth
    mask and the model's predicted probability map. Channel titles follow this
    notebook's preprocessing, which saves bands [VV, NewBand1, NewBand2]. The
    figure is saved to ``<OUTPUT_PATH>/sample_predictions.png`` and then shown.

    Args:
        model: Keras model mapping ``(1, H, W, C)`` to ``(1, H, W, 1)``.
        images: ``(N, H, W, C)`` array with ``C >= 2``.
        masks: ``(N, H, W)`` array of ground-truth masks.
        num_samples: Maximum number of leading samples to plot.
    """
    n_rows = min(num_samples, len(images))
    if n_rows <= 0:
        print("No samples to visualize")
        return

    plt.figure(figsize=(15, 5 * n_rows))

    for i in range(n_rows):
        # Cast compact float16 storage to float32 for the model and for matplotlib.
        img = images[i:i + 1].astype(np.float32)
        mask = np.asarray(masks[i], dtype=np.float32)

        pred = model.predict(img, verbose=0)[0].squeeze()

        # Joint min-max scaling of both channels for display.
        img_display = img[0]
        img_display = (img_display - img_display.min()) / (img_display.max() - img_display.min() + 1e-8)

        panels = (
            (img_display[:, :, 0], 'gray', f'Sample {i+1} - VV'),
            (img_display[:, :, 1], 'gray', f'Sample {i+1} - NewBand1'),
            (mask, 'Blues', 'Ground Truth'),
            (pred, 'Blues', 'Prediction'),
        )
        for col, (data, cmap, title) in enumerate(panels):
            plt.subplot(n_rows, 4, i * 4 + col + 1)
            plt.imshow(data, cmap=cmap)
            plt.title(title)
            plt.axis('off')

    plt.tight_layout()
    plt.savefig(os.path.join(OUTPUT_PATH, 'sample_predictions.png'), dpi=100, bbox_inches='tight')
    plt.show()

print("\nCreating visualizations...")
visualize_sample_predictions(best_model, test_images, test_masks, num_samples=3)

# The generators hold references to the data arrays, so they must be dropped
# too; otherwise deleting the array names frees no memory.
del train_gen, val_gen, test_gen
del train_images, val_images, test_images
del train_masks, val_masks, test_masks
gc.collect()

print("Done")

In [None]:
# Comprehensive Model Evaluation

import os
import numpy as np
import tensorflow as tf
from tensorflow.keras import models
import matplotlib.pyplot as plt
import seaborn as sns
import glob
import json
import pandas as pd
from PIL import Image
from sklearn.metrics import confusion_matrix, classification_report
import gc


# Paths (MODEL_PATH / DATA_PATH are overwritten below by the first existing
# candidate from the search lists)
MODEL_PATH = "/kaggle/working/unet_paper_memory_optimized/best_model_memory_opt.keras"
DATA_PATH = "/kaggle/working/preprocessed_512"
OUTPUT_PATH = "/kaggle/working/comprehensive_evaluation"

# Candidate locations, searched in order; the first one that exists wins.
ALTERNATIVE_MODEL_PATHS = [
    "/kaggle/working/unet_paper_memory_optimized/best_model_memory_opt.keras",
    "/kaggle/working/flood_unet_paper_config_512/best_unet_model_paper_config_cosine_annealing.keras",
    "/kaggle/working/unet_paper_exact/best_unet_paper_model.keras",
    # Checkpoint written by the standard U-Net training cell in this notebook.
    # Listed last so existing precedence is unchanged.
    "/kaggle/working/unet_standard/best_model.keras",
]

ALTERNATIVE_DATA_PATHS = [
    "/kaggle/working/preprocessed_512",
    "/kaggle/working/preprocessed_256",
    "/kaggle/working/preprocessed"
]

print(" Searching for model and data...")

# Find model
working_model_path = next((p for p in ALTERNATIVE_MODEL_PATHS if os.path.exists(p)), None)
if working_model_path is None:
    print(" No model found. Please check these locations:")
    for candidate in ALTERNATIVE_MODEL_PATHS:
        print(f"   • {candidate}")
    raise FileNotFoundError("No trained model found")
print(f" Found model: {working_model_path}")

MODEL_PATH = working_model_path

# Find data: a candidate qualifies if it has a 'test' sub-directory.
working_data_path = next(
    (p for p in ALTERNATIVE_DATA_PATHS if os.path.exists(os.path.join(p, 'test'))), None
)
if working_data_path is None:
    print(" No test data found. Please check these locations:")
    for candidate in ALTERNATIVE_DATA_PATHS:
        print(f"   • {candidate}/test/")
    raise FileNotFoundError("No test data found")
print(f" Found data: {working_data_path}")

DATA_PATH = working_data_path

# Create output directory
os.makedirs(OUTPUT_PATH, exist_ok=True)

# Configuration
INPUT_SIZE = 512
INPUT_CHANNELS = 2
BATCH_SIZE = 2  # Same as training for consistency
THRESHOLD = 0.5  # Binary classification threshold

print("=" * 70)
print(" COMPREHENSIVE MODEL EVALUATION")
print("=" * 70)
print(f" Model: {MODEL_PATH}")
print(f" Data: {DATA_PATH}")
print(f" Output: {OUTPUT_PATH}")
print(f" Input Size: {INPUT_SIZE}×{INPUT_SIZE}")
print(f" Threshold: {THRESHOLD}")
print("=" * 70)

# =====================================================
# DATA LOADING FUNCTIONS
# =====================================================

def _pair_test_files(img_files, mask_files):
    """Pair image and mask files by basename, falling back to positional truncation.

    Matching by name keeps every image with its own mask when one directory
    has extra or missing files. If the two lists share no basenames (for
    example, different naming schemes), both are truncated to the shorter
    length as a best effort.
    """
    masks_by_name = {os.path.basename(p): p for p in mask_files}
    paired = [
        (img_path, masks_by_name[os.path.basename(img_path)])
        for img_path in img_files
        if os.path.basename(img_path) in masks_by_name
    ]
    if paired:
        return [img for img, _ in paired], [msk for _, msk in paired]

    min_len = min(len(img_files), len(mask_files))
    return img_files[:min_len], mask_files[:min_len]


def _load_test_sample(i, img_file, mask_file, n_channels):
    """Load and validate one image/mask pair; return (img, mask) or None to skip it."""
    try:
        img = np.load(img_file).astype(np.float32)
        if img.ndim != 3:
            print(f" Skipping {img_file}: invalid shape {img.shape}")
            return None
        # Images are sliced channel-last, so the last axis must hold enough bands
        if img.shape[2] < n_channels:
            print(f" Skipping {img_file}: expected at least {n_channels} channels, got {img.shape[2]}")
            return None
        img = img[:, :, :n_channels]  # Take VV, VH channels

        mask = np.load(mask_file).astype(np.float32)
        if mask.ndim != 2:
            print(f" Skipping {mask_file}: invalid shape {mask.shape}")
            return None

        if img.shape[:2] != mask.shape:
            print(f" Skipping sample {i}: shape mismatch img{img.shape[:2]} vs mask{mask.shape}")
            return None

        return img, mask
    except Exception as e:
        # Best effort: a single unreadable file is reported and skipped
        print(f" Error loading sample {i}: {e}")
        return None


def load_test_data(base_path, channels=None):
    """Load the preprocessed test split as stacked float32 arrays.

    Reads ``<base_path>/test/images/*.npy`` and ``<base_path>/test/masks/*.npy``.
    Each image must be 3-D and is sliced channel-last to its first ``channels``
    bands. Each mask must be 2-D with the image's height and width. Samples
    failing these checks, or failing to load, are reported and skipped.

    Args:
        base_path: Root directory of the preprocessed dataset.
        channels: Number of leading channels to keep. Defaults to the
            module-level ``INPUT_CHANNELS``.

    Returns:
        (images, masks): arrays of shape (N, H, W, channels) and (N, H, W).

    Raises:
        ValueError: if either directory has no ``.npy`` files, no sample
            survives validation, or the surviving samples differ in size.
    """
    n_channels = INPUT_CHANNELS if channels is None else channels
    print("Loading test dataset...")

    img_files = sorted(glob.glob(os.path.join(base_path, 'test', 'images', '*.npy')))
    mask_files = sorted(glob.glob(os.path.join(base_path, 'test', 'masks', '*.npy')))

    if not img_files:
        raise ValueError(f"No test images found in {base_path}/test/images/")

    if not mask_files:
        raise ValueError(f"No test masks found in {base_path}/test/masks/")

    if len(img_files) != len(mask_files):
        print(f" Warning: {len(img_files)} images but {len(mask_files)} masks")
        # Positional truncation would shift every later pair; match by name instead
        img_files, mask_files = _pair_test_files(img_files, mask_files)

    print(f"Found {len(img_files)} test samples")

    images = []
    masks = []
    for i, (img_file, mask_file) in enumerate(zip(img_files, mask_files)):
        sample = _load_test_sample(i, img_file, mask_file, n_channels)
        if sample is not None:
            images.append(sample[0])
            masks.append(sample[1])

        # Progress is reported for every 50th file, including skipped ones
        if (i + 1) % 50 == 0:
            print(f"Processed {i + 1}/{len(img_files)} files, {len(images)} valid")

    if not images:
        raise ValueError("No valid test samples loaded!")

    # np.stack gives a clear error if validated samples still differ in size
    images = np.stack(images)
    masks = np.stack(masks)

    print(f" Test data loaded: {len(images)} samples")
    print(f"Images shape: {images.shape}, Masks shape: {masks.shape}")

    return images, masks

class EvaluationGenerator(tf.keras.utils.Sequence):
    """Minimal Keras ``Sequence`` serving (images, masks) batches for evaluation.

    The final batch may be smaller than ``batch_size``. Masks that are 3-D
    (batch, height, width) get a trailing channel axis. Masks of any other
    rank are returned unchanged.
    """

    def __init__(self, images, masks, batch_size=2, **kwargs):
        """Store the arrays to batch.

        Args:
            images: Array indexed by sample along axis 0.
            masks: Array aligned with ``images`` along axis 0.
            batch_size: Samples per batch; must be >= 1.
            **kwargs: Forwarded to the Keras base class. Keras 3's
                ``PyDataset`` accepts ``workers``, ``use_multiprocessing``
                and ``max_queue_size``.
        """
        # Keras 3 aliases utils.Sequence to PyDataset, whose __init__ sets the
        # worker/queue attributes and warns when it is never called.
        super().__init__(**kwargs)
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")
        self.images = images
        self.masks = masks
        self.batch_size = batch_size
        self.n_samples = len(images)

    @staticmethod
    def _with_channel_axis(masks):
        """Append a channel axis to 3-D (batch, H, W) masks; pass other ranks through."""
        if masks.ndim == 3:
            return np.expand_dims(masks, axis=-1)
        return masks

    def __len__(self):
        """Number of batches, counting a trailing partial batch."""
        return int(np.ceil(self.n_samples / self.batch_size))

    def __getitem__(self, index):
        """Return the ``index``-th (images, masks) batch."""
        start_idx = index * self.batch_size
        end_idx = min(start_idx + self.batch_size, self.n_samples)
        batch_images = self.images[start_idx:end_idx]
        batch_masks = self._with_channel_axis(self.masks[start_idx:end_idx])
        return batch_images, batch_masks

    def get_all_data(self):
        """Return every image and mask at once, with masks in channel-last form."""
        return self.images, self._with_channel_axis(self.masks)

# =====================================================
# COMPREHENSIVE METRICS CALCULATION
# =====================================================

def calculate_pixel_metrics(y_true, y_pred, threshold=0.5, true_threshold=0.5):
    """Compute pixel-wise binary segmentation metrics over all elements.

    Args:
        y_true: Ground-truth array of any shape. Values >= ``true_threshold``
            count as flood (positive).
        y_pred: Predicted scores with the same number of elements as
            ``y_true``. Values >= ``threshold`` count as flood.
        threshold: Decision threshold applied to the predictions.
        true_threshold: Threshold used to binarize the ground truth. It is
            independent of ``threshold``, so sweeping the prediction cut-off
            does not also move the reference labels.

    Returns:
        dict with the confusion-matrix counts (``true_positives``,
        ``true_negatives``, ``false_positives``, ``false_negatives``; ints)
        followed by ``accuracy``, ``precision``, ``recall``, ``specificity``,
        ``f1_score``, ``iou``, ``dice`` and ``jaccard`` (floats). A 1e-8
        epsilon in each denominator yields 0 rather than a division error
        when a class is absent.

    Raises:
        ValueError: if ``y_true`` and ``y_pred`` have different element counts.
    """
    truth = np.asarray(y_true).ravel() >= true_threshold
    pred = np.asarray(y_pred).ravel() >= threshold
    # Without this check, a size-1 input would silently broadcast against the other
    if truth.size != pred.size:
        raise ValueError(f"y_true has {truth.size} elements but y_pred has {pred.size}")

    # Confusion-matrix counts on boolean masks (exact integer counting)
    tp = int(np.count_nonzero(truth & pred))
    tn = int(np.count_nonzero(~truth & ~pred))
    fp = int(np.count_nonzero(~truth & pred))
    fn = int(np.count_nonzero(truth & ~pred))

    eps = 1e-8
    precision = tp / (tp + fp + eps)   # positive predictive value
    recall = tp / (tp + fn + eps)      # sensitivity / true positive rate
    iou = tp / (tp + fp + fn + eps)    # intersection over union (Jaccard)

    return {
        'true_positives': tp,
        'true_negatives': tn,
        'false_positives': fp,
        'false_negatives': fn,
        'accuracy': (tp + tn) / (tp + tn + fp + fn + eps),
        'precision': precision,
        'recall': recall,
        'specificity': tn / (tn + fp + eps),  # true negative rate
        'f1_score': 2 * (precision * recall) / (precision + recall + eps),
        'iou': iou,
        'dice': 2 * tp / (2 * tp + fp + fn + eps),
        'jaccard': iou,
    }

def calculate_sample_wise_metrics(y_true, y_pred, threshold=0.5):
    """Score each sample independently with ``calculate_pixel_metrics``.

    Args:
        y_true: Ground-truth masks indexed by sample along the first axis.
        y_pred: Predictions indexed the same way as ``y_true``.
        threshold: Decision threshold forwarded to ``calculate_pixel_metrics``.

    Returns:
        list of metric dicts, one per sample in ``y_true`` order. Each dict
        ends with a ``'sample_id'`` key holding the sample's position.
    """
    return [
        {**calculate_pixel_metrics(truth, y_pred[sample_id], threshold), 'sample_id': sample_id}
        for sample_id, truth in enumerate(y_true)
    ]

def calculate_class_distribution(y_true, y_pred, threshold=0.5, true_threshold=0.5):
    """Summarize how many pixels are flood in the labels and in the predictions.

    Args:
        y_true: Ground-truth array of any shape. Values >= ``true_threshold``
            are flood.
        y_pred: Predicted scores. Values >= ``threshold`` are flood.
        threshold: Decision threshold for the predictions.
        true_threshold: Binarization threshold for the ground truth. It is
            kept separate so that changing ``threshold`` does not relabel the
            reference.

    Returns:
        dict with integer pixel counts and flood percentages (0-100). A
        percentage is 0.0 when its input is empty.
    """
    y_true_binary = np.asarray(y_true).ravel() >= true_threshold
    y_pred_binary = np.asarray(y_pred).ravel() >= threshold

    # Count on boolean masks. Summing a float32 0/1 array stops representing
    # every integer exactly once the total exceeds 2**24 pixels.
    true_total = int(y_true_binary.size)
    pred_total = int(y_pred_binary.size)
    true_flood = int(np.count_nonzero(y_true_binary))
    pred_flood = int(np.count_nonzero(y_pred_binary))

    def _percent(count, total):
        return 100.0 * count / total if total else 0.0

    return {
        'total_pixels': true_total,
        'true_flood_pixels': true_flood,
        'true_non_flood_pixels': true_total - true_flood,
        'pred_flood_pixels': pred_flood,
        'pred_non_flood_pixels': pred_total - pred_flood,
        'true_flood_percentage': _percent(true_flood, true_total),
        'pred_flood_percentage': _percent(pred_flood, pred_total),
    }

# =====================================================
# LOAD MODEL AND DATA
# =====================================================

print("\n Loading trained model...")
if not os.path.exists(MODEL_PATH):
    print(f" Model not found at {MODEL_PATH}")
    print("Available models:")
    model_dir = os.path.dirname(MODEL_PATH)
    if os.path.exists(model_dir):
        for f in os.listdir(model_dir):
            if f.endswith(('.keras', '.h5')):
                print(f"   • {os.path.join(model_dir, f)}")
    raise FileNotFoundError(f"Model not found at {MODEL_PATH}")

try:
    model = models.load_model(MODEL_PATH)
    print(" Model loaded successfully")
    print(f"Model input shape: {model.input_shape}")
    print(f"Model output shape: {model.output_shape}")
    print(f"Model parameters: {model.count_params():,}")
except Exception as e:
    print(f" Failed to load model: {e}")
    raise

# Load test data
try:
    test_images, test_masks = load_test_data(DATA_PATH)
    print(" Test data loaded successfully")
except Exception as e:
    print(f" Failed to load test data: {e}")
    print(f"Check if path exists: {DATA_PATH}")
    print(f"Expected structure: {DATA_PATH}/test/images/*.npy and {DATA_PATH}/test/masks/*.npy")
    raise


def _input_shape_compatible(data_shape, model_shape):
    """Return True when every fixed model input dimension matches the data.

    ``None`` entries in ``model_shape`` are dynamic dimensions (for example a
    fully convolutional model built with unspecified height/width) and match
    any size. A plain tuple comparison would report a false mismatch for them.
    """
    if len(data_shape) != len(model_shape):
        return False
    return all(m is None or m == d for d, m in zip(data_shape, model_shape))


# Validate data compatibility
print("\n Data validation:")
print(f"Test images shape: {test_images.shape}")
print(f"Test masks shape: {test_masks.shape}")
print(f"Model expects: {model.input_shape}")

if not _input_shape_compatible(test_images.shape[1:], model.input_shape[1:]):
    print(" Warning: Input shape mismatch!")
    print(f"Model expects: {model.input_shape[1:]}")
    print(f"Data provides: {test_images.shape[1:]}")

# Create evaluation generator (keeping for compatibility)
test_gen = EvaluationGenerator(test_images, test_masks, BATCH_SIZE)

print(f"Test samples: {len(test_images)}")
print(f"Evaluation batches: {len(test_gen)}")

# Smoke-test one small batch so failures surface before the full run
print("\n Testing model prediction...")
try:
    test_batch = test_images[:min(BATCH_SIZE, len(test_images))]
    test_pred = model.predict(test_batch, verbose=0)
    print(f" Test prediction successful: {test_pred.shape}")
    print(f"Prediction range: [{test_pred.min():.3f}, {test_pred.max():.3f}]")
except Exception as e:
    print(f" Test prediction failed: {e}")
    raise

# =====================================================
# GENERATE PREDICTIONS
# =====================================================

print("\n Generating predictions...")


def _drop_channel_axis(preds):
    """Return ``preds`` without a trailing singleton channel axis.

    Only the last axis of a 4-D (batch, H, W, 1) array is removed. A bare
    ``squeeze()`` would also drop the batch axis when the test set holds a
    single image. The result would be (H, W) predictions that no longer line
    up sample-by-sample with the (N, H, W) masks.
    """
    if preds.ndim == 4 and preds.shape[-1] == 1:
        return preds[..., 0]
    return preds


# Method 1: Direct prediction (more reliable)
print("Using direct prediction method...")

try:
    # Predict directly on all test images
    predictions = _drop_channel_axis(
        model.predict(test_images, batch_size=BATCH_SIZE, verbose=1)
    )

    print(f" Predictions generated: {predictions.shape}")
    print(f"Prediction range: [{predictions.min():.3f}, {predictions.max():.3f}]")

except Exception as e:
    print(f" Direct prediction failed: {e}")
    print(" Trying batch-by-batch prediction...")

    # Method 2: Batch-by-batch prediction (fallback, e.g. when memory is tight)
    all_predictions = []

    for i in range(0, len(test_images), BATCH_SIZE):
        batch_images = test_images[i:i + BATCH_SIZE]
        all_predictions.append(model.predict(batch_images, verbose=0))

        if (i // BATCH_SIZE + 1) % 10 == 0:
            print(f"Processed {i // BATCH_SIZE + 1} batches")

    predictions = _drop_channel_axis(np.concatenate(all_predictions, axis=0))

    print(f" Predictions generated: {predictions.shape}")
    print(f"Prediction range: [{predictions.min():.3f}, {predictions.max():.3f}]")

# Metrics compare predictions and masks element-by-element, so shapes must match
if predictions.shape != test_masks.shape:
    print(f" Shape mismatch: predictions {predictions.shape} vs masks {test_masks.shape}")
    raise ValueError(
        f"Predictions {predictions.shape} do not match masks {test_masks.shape}; "
        "check the model's output resolution and channel count"
    )

# =====================================================
# CALCULATE COMPREHENSIVE METRICS
# =====================================================

print("\n Calculating comprehensive metrics...")

# Dataset-level confusion-matrix metrics over every test pixel
overall_metrics = calculate_pixel_metrics(test_masks, predictions, THRESHOLD)
# The same metrics computed image by image (one dict per sample)
sample_metrics = calculate_sample_wise_metrics(test_masks, predictions, THRESHOLD)
# Flood / non-flood pixel counts for labels and predictions
class_stats = calculate_class_distribution(test_masks, predictions, THRESHOLD)

# One row per test sample, for summary statistics and plotting
sample_df = pd.DataFrame(sample_metrics)

print(" Metrics calculated successfully")


def _print_evaluation_report():
    """Print overall, confusion-matrix, class-balance and per-sample summaries."""
    print("\n" + "=" * 70)
    print(" COMPREHENSIVE EVALUATION RESULTS")
    print("=" * 70)

    print("\n OVERALL PERFORMANCE:")
    score_rows = (
        ("Accuracy", "accuracy"),
        ("IoU Score", "iou"),
        ("Dice Coefficient", "dice"),
        ("F1 Score", "f1_score"),
        ("Precision", "precision"),
        ("Recall", "recall"),
        ("Specificity", "specificity"),
    )
    for label, key in score_rows:
        score = overall_metrics[key]
        print(f"   • {label}: {score:.4f} ({score*100:.2f}%)")

    print("\n CONFUSION MATRIX:")
    count_rows = (
        ("True Positives", "true_positives"),
        ("True Negatives", "true_negatives"),
        ("False Positives", "false_positives"),
        ("False Negatives", "false_negatives"),
    )
    for label, key in count_rows:
        print(f"   • {label}: {overall_metrics[key]:,}")

    print("\n CLASS DISTRIBUTION:")
    print(f"   • Total Pixels: {class_stats['total_pixels']:,}")
    for label, prefix in (("True Flood Pixels", "true"), ("Predicted Flood Pixels", "pred")):
        pixels = class_stats[prefix + '_flood_pixels']
        share = class_stats[prefix + '_flood_percentage']
        print(f"   • {label}: {pixels:,} ({share:.2f}%)")

    print("\n SAMPLE-WISE STATISTICS:")
    for label, column in (("Mean IoU", "iou"), ("Mean Dice", "dice"), ("Mean F1", "f1_score")):
        series = sample_df[column]
        print(f"   • {label}: {series.mean():.4f} ± {series.std():.4f}")
    print(f"   • Best Sample IoU: {sample_df['iou'].max():.4f}")
    print(f"   • Worst Sample IoU: {sample_df['iou'].min():.4f}")


_print_evaluation_report()


# =====================================================
# SAVE RESULTS
# =====================================================

print("\n Saving results...")


def _paper_comparison():
    """Pair the paper's IoU/accuracy/F1 with ours and add our-minus-paper deltas."""
    # (short name used in the output keys, metric key in both result dicts)
    pairs = (('iou', 'iou'), ('accuracy', 'accuracy'), ('f1', 'f1_score'))
    comparison = {f'paper_{short}': paper_results[key] for short, key in pairs}
    comparison.update({f'our_{short}': overall_metrics[key] for short, key in pairs})
    comparison.update(
        {f'{short}_difference': overall_metrics[key] - paper_results[key] for short, key in pairs}
    )
    return comparison


# Headline numbers, class balance and the comparison against the paper
overall_results = {
    'model_path': MODEL_PATH,
    'test_samples': len(test_images),
    'threshold': THRESHOLD,
    'overall_metrics': overall_metrics,
    'class_statistics': class_stats,
    'paper_comparison': _paper_comparison(),
}

with open(os.path.join(OUTPUT_PATH, 'comprehensive_results.json'), 'w') as f:
    json.dump(overall_results, f, indent=4)

# Per-sample metrics, one row per test image
sample_df.to_csv(os.path.join(OUTPUT_PATH, 'sample_wise_metrics.csv'), index=False)

print(f" Results saved to {OUTPUT_PATH}")

# =====================================================
# VISUALIZATIONS
# =====================================================

def _finish_figure(filename):
    """Tighten the layout, save the current figure to OUTPUT_PATH at 300 dpi, then show it."""
    plt.tight_layout()
    plt.savefig(os.path.join(OUTPUT_PATH, filename), dpi=300, bbox_inches='tight')
    plt.show()


def _plot_metrics_comparison():
    """Bar chart of our overall metrics beside the paper's reported values."""
    plt.figure(figsize=(12, 8))

    names = ['Accuracy', 'IoU', 'Dice', 'F1-Score', 'Precision', 'Recall', 'Specificity']
    keys = ['accuracy', 'iou', 'dice', 'f1_score', 'precision', 'recall', 'specificity']
    ours = [overall_metrics[k] for k in keys]
    # The paper reports only accuracy, IoU and F1, i.e. bar positions 0, 1 and 3
    paper = {
        0: paper_results['accuracy'],
        1: paper_results['iou'],
        3: paper_results['f1_score'],
    }

    centers = np.arange(len(names))
    bar_width = 0.35

    plt.bar(centers - bar_width/2, ours, bar_width, label='Our Model', color='skyblue', alpha=0.8)
    plt.bar([centers[pos] + bar_width/2 for pos in paper], list(paper.values()), bar_width,
            label='Paper Results', color='lightcoral', alpha=0.8)

    plt.xlabel('Metrics')
    plt.ylabel('Score')
    plt.title('Performance Comparison: Our Model vs Paper')
    plt.xticks(centers, names, rotation=45)
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.ylim(0, 1)

    # Value labels just above each bar
    for pos, val in enumerate(ours):
        plt.text(pos - bar_width/2, val + 0.01, f'{val:.3f}', ha='center', va='bottom', fontsize=9)
    for pos, val in paper.items():
        plt.text(pos + bar_width/2, val + 0.01, f'{val:.3f}', ha='center', va='bottom', fontsize=9)

    _finish_figure('metrics_comparison.png')


def _plot_sample_distributions():
    """Histogram of per-sample IoU plus box plots of per-sample IoU/Dice/F1."""
    plt.figure(figsize=(12, 6))

    iou_scores = sample_df['iou']
    plt.subplot(1, 2, 1)
    plt.hist(iou_scores, bins=20, alpha=0.7, color='skyblue', edgecolor='black')
    plt.axvline(iou_scores.mean(), color='red', linestyle='--', label=f'Mean: {iou_scores.mean():.3f}')
    plt.xlabel('IoU Score')
    plt.ylabel('Number of Samples')
    plt.title('Distribution of Sample-wise IoU Scores')
    plt.legend()
    plt.grid(True, alpha=0.3)

    plt.subplot(1, 2, 2)
    plt.boxplot([sample_df[col] for col in ('iou', 'dice', 'f1_score')],
                labels=['IoU', 'Dice', 'F1'])
    plt.ylabel('Score')
    plt.title('Sample-wise Metrics Distribution')
    plt.grid(True, alpha=0.3)

    _finish_figure('sample_distributions.png')


def _plot_confusion_matrix():
    """Heatmap of the overall confusion matrix as a percentage of all pixels."""
    plt.figure(figsize=(8, 6))

    # Rows are actual (non-flood, flood), columns are predicted (non-flood, flood)
    counts = np.array([
        [overall_metrics['true_negatives'], overall_metrics['false_positives']],
        [overall_metrics['false_negatives'], overall_metrics['true_positives']],
    ])
    share = counts / counts.sum() * 100

    sns.heatmap(share, annot=True, fmt='.2f', cmap='Blues',
                xticklabels=['Predicted Non-Flood', 'Predicted Flood'],
                yticklabels=['Actual Non-Flood', 'Actual Flood'],
                cbar_kws={'label': 'Percentage of Total Pixels'})

    plt.title('Confusion Matrix (% of Total Pixels)')
    _finish_figure('confusion_matrix.png')


def _plot_sample_predictions():
    """Image / ground truth / prediction triplets for the 3 worst and 3 best IoU samples."""
    plt.figure(figsize=(20, 12))

    ranked_ids = np.concatenate([
        sample_df.nsmallest(3, 'iou')['sample_id'].values,
        sample_df.nlargest(3, 'iou')['sample_id'].values,
    ])

    # 2 x 9 grid: each sample occupies three consecutive cells
    for slot, idx in enumerate(ranked_ids):
        first_cell = slot * 3 + 1

        plt.subplot(2, 9, first_cell)
        vv = test_images[idx][:, :, 0]  # VV channel, min-max scaled for display
        vv = (vv - vv.min()) / (vv.max() - vv.min() + 1e-8)
        plt.imshow(vv, cmap='gray')
        plt.title(f'Sample {idx}\nVV Channel')
        plt.axis('off')

        plt.subplot(2, 9, first_cell + 1)
        plt.imshow(test_masks[idx], cmap='Blues', vmin=0, vmax=1)
        plt.title('Ground Truth')
        plt.axis('off')

        plt.subplot(2, 9, first_cell + 2)
        plt.imshow(predictions[idx], cmap='Blues', vmin=0, vmax=1)
        sample_iou = sample_df.loc[sample_df['sample_id'] == idx, 'iou'].iloc[0]
        plt.title(f'Prediction\nIoU: {sample_iou:.3f}')
        plt.axis('off')

    plt.suptitle('Worst 3 Samples (top) vs Best 3 Samples (bottom)', fontsize=16)
    _finish_figure('sample_predictions.png')


def create_comprehensive_visualizations():
    """Render and save the four evaluation figures into OUTPUT_PATH.

    Reads the module-level results (overall_metrics, paper_results, sample_df,
    test_images, test_masks, predictions). Writes metrics_comparison.png,
    sample_distributions.png, confusion_matrix.png and sample_predictions.png.
    """
    _plot_metrics_comparison()
    _plot_sample_distributions()
    _plot_confusion_matrix()
    _plot_sample_predictions()


print("\n Creating visualizations...")
create_comprehensive_visualizations()

# =====================================================
# THRESHOLD ANALYSIS
# =====================================================

def threshold_analysis():
    """Sweep the prediction threshold and report how each metric responds.

    Reads the module-level ``test_masks``, ``predictions``, ``THRESHOLD`` and
    ``OUTPUT_PATH``. Saves ``threshold_analysis.png`` and
    ``threshold_analysis.csv`` under ``OUTPUT_PATH`` and prints the IoU-optimal
    threshold.

    Returns:
        pandas.DataFrame with one row of metrics per swept threshold.
    """
    print("\n Performing threshold analysis...")

    # np.arange accumulates float error (e.g. 0.30000000000000004); rounding
    # keeps the CSV readable and the default-threshold lookup reliable.
    thresholds = np.round(np.arange(0.1, 1.0, 0.1), 2)

    # Binarize the reference labels once, at 0.5, so that only the prediction
    # cut-off varies across the sweep. Masks that are not strictly 0/1 (e.g. if
    # they were resized with interpolation) would otherwise be re-thresholded too.
    binary_masks = (test_masks >= 0.5).astype(np.float32)

    threshold_results = []
    for thresh in thresholds:
        metrics = calculate_pixel_metrics(binary_masks, predictions, thresh)
        metrics['threshold'] = float(thresh)
        threshold_results.append(metrics)

    threshold_df = pd.DataFrame(threshold_results)

    # Plot threshold analysis
    plt.figure(figsize=(12, 8))

    metrics_to_plot = ['accuracy', 'iou', 'dice', 'f1_score', 'precision', 'recall']

    for metric in metrics_to_plot:
        plt.plot(threshold_df['threshold'], threshold_df[metric], marker='o', label=metric.replace('_', ' ').title())

    plt.axvline(x=THRESHOLD, color='red', linestyle='--', alpha=0.7, label='Default Threshold')
    plt.xlabel('Threshold')
    plt.ylabel('Score')
    plt.title('Performance vs Classification Threshold')
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.xlim(0.1, 0.9)
    plt.ylim(0, 1)

    plt.tight_layout()
    plt.savefig(os.path.join(OUTPUT_PATH, 'threshold_analysis.png'), dpi=300, bbox_inches='tight')
    plt.show()

    # Save threshold analysis
    threshold_df.to_csv(os.path.join(OUTPUT_PATH, 'threshold_analysis.csv'), index=False)

    # Find optimal threshold for IoU
    optimal_idx = threshold_df['iou'].idxmax()
    optimal_threshold = threshold_df.loc[optimal_idx, 'threshold']
    optimal_iou = threshold_df.loc[optimal_idx, 'iou']

    print(f" Optimal threshold for IoU: {optimal_threshold:.1f} (IoU: {optimal_iou:.4f})")

    # np.isclose rather than ==, so float representation cannot hide the row
    default_iou = threshold_df.loc[np.isclose(threshold_df['threshold'], THRESHOLD), 'iou']
    if default_iou.empty:
        print(f" Default threshold ({THRESHOLD}) is not part of the sweep")
    else:
        print(f" Default threshold ({THRESHOLD}) IoU: {default_iou.iloc[0]:.4f}")

    return threshold_df


threshold_results = threshold_analysis()

