In [6]:
import os
import numpy as np
from PIL import Image
import tensorflow as tf
import imageio
import tifffile

# -----------------------------
# Load and Preprocess Function
# -----------------------------


def load_image(path, size=(256, 256)):
    """Load an image file as a grayscale float array scaled to [0, 1].

    ``.tif``/``.tiff`` files are read with ``tifffile``; any TIFF whose dtype
    is not ``uint8`` is first rescaled to 8 bits by dividing by its maximum.
    Every other format is opened with PIL. The image is then resized to
    ``size`` and converted to PIL mode ``'L'`` (single channel).

    Args:
        path: Path to the image file. The (case-insensitive) extension
            decides which reader is used.
        size: ``(width, height)`` passed to ``PIL.Image.resize``.
            Defaults to ``(256, 256)``, the previously hard-coded size.

    Returns:
        numpy.ndarray: Float array of shape ``(height, width)`` with values
        in ``[0, 1]``.
    """
    if path.lower().endswith(('.tif', '.tiff')):
        img = tifffile.imread(path)
        if img.dtype != 'uint8':
            # Replace NaN/inf (e.g. nodata pixels) with 0 so they cannot
            # poison the maximum. The dtype is preserved, so for finite data
            # the scaling below equals the original `img / img.max() * 255`.
            img = np.nan_to_num(img, nan=0, posinf=0, neginf=0)
            peak = img.max()
            if peak > 0:
                img = img / peak * 255
            else:
                # All-zero (or all-non-positive) image: dividing by a zero or
                # negative peak used to produce NaN / out-of-range garbage.
                img = np.zeros(img.shape)
            # Clip before casting so negative samples cannot wrap around.
            img = np.clip(img, 0, 255).astype(np.uint8)
            # NOTE(review): rasters with meaningful negative values (e.g.
            # dB-scaled data) might be better served by min-max scaling;
            # max-scaling is kept so results match the previous behaviour
            # (presumably also the training preprocessing -- confirm).
        # NOTE(review): assumes a 2-D or channels-last (H, W, 3/4) array;
        # other band layouts are not accepted by Image.fromarray.
        img = Image.fromarray(img)
        img = img.resize(size).convert('L')
    else:
        # The context manager releases the file handle; resize() forces the
        # pixel data to load and returns a new, independent image.
        with Image.open(path) as src:
            img = src.resize(size).convert('L')

    return np.array(img) / 255.0



def get_stacked_images(before_dir, after_dir):
    """Load paired before/after images and stack them channel-wise.

    Image files (``.png``, ``.jpg``, ``.jpeg``, ``.tif``, ``.tiff``) in each
    directory are sorted by name and paired by position: the i-th image in
    ``before_dir`` goes with the i-th image in ``after_dir``.

    Non-image entries (``.DS_Store``, ``Thumbs.db``, subdirectories, ...) are
    filtered out of each listing *before* pairing. Filtering after zipping
    let one stray file shift every later pair out of alignment.

    Args:
        before_dir: Directory containing the "before" images.
        after_dir: Directory containing the "after" images.

    Returns:
        tuple: ``(inputs, filenames)``. ``inputs`` is a numpy array of
        ``np.stack([before, after], axis=-1)`` arrays (shape
        ``(N, 256, 256, 2)`` with ``load_image``'s default size), and
        ``filenames`` holds the extension-less names of the "before" files.
        Both are empty when no image files are found.

    Raises:
        ValueError: If the directories hold different numbers of image
            files, because positional pairing would then be wrong.
    """
    image_exts = ('.png', '.jpg', '.jpeg', '.tif', '.tiff')

    def _image_files(directory):
        # Sorted names of regular files carrying a recognised image extension.
        return sorted(
            name for name in os.listdir(directory)
            if name.lower().endswith(image_exts)
            and os.path.isfile(os.path.join(directory, name))
        )

    before_files = _image_files(before_dir)
    after_files = _image_files(after_dir)
    if len(before_files) != len(after_files):
        raise ValueError(
            f"Found {len(before_files)} image(s) in {before_dir!r} but "
            f"{len(after_files)} in {after_dir!r}; images are paired by "
            f"sorted position, so the counts must match."
        )

    inputs = []
    filenames = []
    for bfile, afile in zip(before_files, after_files):
        before = load_image(os.path.join(before_dir, bfile))
        after = load_image(os.path.join(after_dir, afile))
        inputs.append(np.stack([before, after], axis=-1))  # (H, W, 2)
        filenames.append(os.path.splitext(bfile)[0])  # use base name

    return np.array(inputs), filenames

# -----------------------------
# Dice Loss (for model loading)
# -----------------------------

def dice_loss(y_true, y_pred, smooth=1e-6):
    """Soft Dice loss: ``1 - (2*sum(t*p) + smooth) / (sum(t) + sum(p) + smooth)``.

    Both tensors are flattened completely, so the overlap is computed over
    the whole batch at once rather than per sample. ``smooth`` keeps the
    ratio defined when both inputs are all zeros.

    Passed to ``load_model`` via ``custom_objects`` so that a saved model
    referring to ``"dice_loss"`` can be deserialized (presumably the model
    was compiled with this loss).
    """
    backend = tf.keras.backend
    truth = backend.flatten(y_true)
    pred = backend.flatten(y_pred)

    overlap = backend.sum(truth * pred)
    total = backend.sum(truth) + backend.sum(pred)

    dice = (2. * overlap + smooth) / (total + smooth)
    return 1 - dice

# -----------------------------
# Inference & Save Masks
# -----------------------------

def run_inference(model_path, before_dir, after_dir, output_dir, *, threshold=0.5):
    """Predict binary masks for before/after image pairs and save them as PNGs.

    Args:
        model_path: Path to the saved Keras model. ``dice_loss`` is supplied
            through ``custom_objects`` so it can be resolved on load.
        before_dir: Directory of "before" images.
        after_dir: Directory of "after" images (paired by sorted position).
        output_dir: Destination folder; created if it does not exist.
        threshold: Probability above which a pixel counts as foreground.
            Defaults to 0.5, the previously hard-coded value.

    Each mask is written as ``<before-name>_pred_mask.png`` with values 0/255.
    If no image pairs are found, a message is printed and nothing is written.
    """
    print("Loading model...")
    model = tf.keras.models.load_model(model_path, custom_objects={"dice_loss": dice_loss})

    print("Reading images...")
    input_images, filenames = get_stacked_images(before_dir, after_dir)
    if not filenames:
        # model.predict on an empty array fails with an opaque shape error.
        print(f"No image pairs found in {before_dir!r} and {after_dir!r}; nothing to do.")
        return

    print("=Running inference...")
    predictions = model.predict(input_images, verbose=1)
    binary_masks = (predictions > threshold).astype(np.uint8)

    print("=Saving masks...")
    os.makedirs(output_dir, exist_ok=True)
    for mask, name in zip(binary_masks, filenames):
        # squeeze() drops singleton axes (e.g. a trailing channel of size 1).
        mask_img = (mask.squeeze() * 255).astype(np.uint8)
        imageio.imwrite(os.path.join(output_dir, f"{name}_pred_mask.png"), mask_img)

    print(f"Masks saved to: {output_dir}")

# -----------------------------
# Entry Point
# -----------------------------

if __name__ == "__main__":
    # Paths for this run -- adjust to your environment.
    # .tif/.tiff inputs are handled by load_image alongside PNG/JPEG.
    model_path = r"D:\Projects\Flood_Mapping\best_unet_model_2.keras"
    before_dir, after_dir = "test_before", "test_after"  # paired input folders
    output_dir = "predicted_masks"  # run_inference creates it if missing

    run_inference(
        model_path=model_path,
        before_dir=before_dir,
        after_dir=after_dir,
        output_dir=output_dir,
    )


Loading model...
Reading images...
=Running inference...
1/1 ━━━━━━━━━━━━━━━━━━━━ 10s 10s/step
=Saving masks...
Masks saved to: predicted_masks
