## Post-processing steps for Clay

## Georeferencing of Clay output: post-processing step 1 (for Sentinel-1 and Sentinel-2, where the original image size and the resized size are the same)

In [None]:

import os
import re
import rasterio
from rasterio.transform import from_bounds
from PIL import Image
import numpy as np

# Input/output locations for this run
ref_input_folder = r"D:\CVPR\Flood_Results\cambodia\ground_truth"   # reference GeoTIFFs
mask_input_folder = r"D:\CVPR\Flood_Results\cambodia\Clay_S2"    # predicted mask PNGs
output_folder = os.path.join(mask_input_folder, "georeferenced_masks")
os.makedirs(output_folder, exist_ok=True)

# Index the reference rasters by file name with ".tif" removed
ref_images = {}
for ref_name in os.listdir(ref_input_folder):
    if ref_name.endswith(".tif"):
        ref_images[ref_name.replace(".tif", "")] = os.path.join(ref_input_folder, ref_name)

# Every PNG in the prediction folder is a candidate mask
mask_files = [png_name for png_name in os.listdir(mask_input_folder) if png_name.endswith(".png")]

# Prediction names are expected to look like pred_<base>_naip-new_chip_<n>.npy.png
chip_name_pattern = re.compile(r"pred_([\w\d_]+)_naip-new_chip_\d+\.npy\.png")

for mask_file in mask_files:
    base_name_match = chip_name_pattern.match(mask_file)
    if base_name_match is None:
        print(f"Skipping {mask_file}, could not extract base name.")
        continue

    base_name = base_name_match.group(1)
    if base_name not in ref_images:
        print(f"Skipping {mask_file}, no matching reference image found for: {base_name}.tif")
        continue

    ref_image_path = ref_images[base_name]
    mask_image_path = os.path.join(mask_input_folder, mask_file)
    output_path = os.path.join(output_folder, mask_file.replace(".png", ".tif"))

    # Pull CRS and extent from the matching reference raster
    with rasterio.open(ref_image_path) as ref:
        ref_crs, ref_transform, ref_bounds = ref.crs, ref.transform, ref.bounds

    # Load the prediction as a single-band (grayscale) array
    with Image.open(mask_image_path) as mask_img:
        mask_pixels = np.array(mask_img.convert("L"))

    n_rows, n_cols = mask_pixels.shape

    # Stretch the mask grid over the full extent of the reference raster
    stretched_transform = from_bounds(
        ref_bounds.left, ref_bounds.bottom, ref_bounds.right, ref_bounds.top,
        n_cols, n_rows
    )

    # Write the mask out as a one-band GeoTIFF carrying the reference CRS
    geotiff_profile = dict(
        driver="GTiff",
        height=n_rows,
        width=n_cols,
        count=1,
        dtype=mask_pixels.dtype,
        crs=ref_crs,
        transform=stretched_transform,
    )
    with rasterio.open(output_path, "w", **geotiff_profile) as dst:
        dst.write(mask_pixels, 1)

    print(f"Saved georeferenced mask: {output_path}")

print("Processing complete for all images.")


### PlanetScope (PS) case: images are split into smaller chips (512x512)

In [None]:


# Input/output locations for the PlanetScope chips
ref_input_folder = r"D:\CVPR\Flood_Results\cambodia\PS_image"   # reference GeoTIFFs
mask_input_folder = r"D:\CVPR\Flood_Results\cambodia\Clay_PS"    # predicted mask PNGs
output_folder = os.path.join(mask_input_folder, "georeferenced_masks")
os.makedirs(output_folder, exist_ok=True)

# Index the reference rasters by file name with ".tif" removed
# (no other suffix is stripped here)
ref_images = {}
for ref_name in os.listdir(ref_input_folder):
    if ref_name.endswith(".tif"):
        ref_images[ref_name.replace(".tif", "")] = os.path.join(ref_input_folder, ref_name)

# Every PNG in the prediction folder is a candidate mask chip
mask_files = [png_name for png_name in os.listdir(mask_input_folder) if png_name.endswith(".png")]

# Chip names are expected to look like pred_MEK_<id>_naip-new_chip_<n>.npy.png;
# the "MEK_" prefix is region-specific and must be changed for other regions
chip_name_pattern = re.compile(r"pred_(MEK_\d+)_naip-new_chip_\d+\.npy\.png")

for mask_file in mask_files:
    base_name_match = chip_name_pattern.match(mask_file)
    if base_name_match is None:
        print(f"Skipping {mask_file}, could not extract base name.")
        continue

    base_name = base_name_match.group(1)
    if base_name not in ref_images:
        print(f"Skipping {mask_file}, no matching reference image found for: {base_name}.tif")
        continue

    ref_image_path = ref_images[base_name]
    mask_image_path = os.path.join(mask_input_folder, mask_file)
    output_path = os.path.join(output_folder, mask_file.replace(".png", ".tif"))

    # Pull CRS and extent from the matching reference raster
    with rasterio.open(ref_image_path) as ref:
        ref_crs, ref_transform, ref_bounds = ref.crs, ref.transform, ref.bounds

    # Load the chip prediction as a single-band (grayscale) array
    with Image.open(mask_image_path) as mask_img:
        chip_pixels = np.array(mask_img.convert("L"))

    n_rows, n_cols = chip_pixels.shape

    # Each chip is given the bounds of the whole reference raster
    stretched_transform = from_bounds(
        ref_bounds.left, ref_bounds.bottom, ref_bounds.right, ref_bounds.top,
        n_cols, n_rows
    )

    # Write the chip out as a one-band GeoTIFF carrying the reference CRS
    geotiff_profile = dict(
        driver="GTiff",
        height=n_rows,
        width=n_cols,
        count=1,
        dtype=chip_pixels.dtype,
        crs=ref_crs,
        transform=stretched_transform,
    )
    with rasterio.open(output_path, "w", **geotiff_profile) as dst:
        dst.write(chip_pixels, 1)

    print(f"Saved georeferenced mask: {output_path}")

print("Processing complete for all images.")


### Note: None of the steps below are needed for Sentinel-1 and Sentinel-2, because those models are trained and run on the original image size

### Mosaicking: Clay post-processing step 2

In [None]:
from collections import defaultdict
import rasterio.transform

# ExitStack guarantees every opened chip is closed, even if reading or
# writing fails part-way through a group.
from contextlib import ExitStack

# Folder containing the georeferenced chip GeoTIFFs
folder_path = r"D:\CVPR\Flood_Results\cambodia\Clay_PS\georeferenced_masks"

# List all TIFF files in the folder
tif_files = [f for f in os.listdir(folder_path) if f.endswith(".tif")]

# Group files by their base name (everything before "_naip-new_chip_");
# the remainder of the name (e.g. "0.npy.tif") identifies the chip
file_groups = defaultdict(dict)

for file in tif_files:
    parts = file.rsplit("_naip-new_chip_", 1)
    if len(parts) == 2:
        base_name, suffix = parts
        file_groups[base_name][suffix] = os.path.join(folder_path, file)

# Output directory
output_dir = os.path.join(folder_path, "mosaicked_outputs")
os.makedirs(output_dir, exist_ok=True)

# Store output file paths
output_files = []

# The four chips that make up one mosaic. Placement used below:
#   chip 0 -> top-left      chip 2 -> top-right
#   chip 1 -> bottom-left   chip 3 -> bottom-right
required_suffixes = ["0.npy.tif", "1.npy.tif", "2.npy.tif", "3.npy.tif"]

# Process each group
for base_name, files in file_groups.items():
    missing = [suffix for suffix in required_suffixes if suffix not in files]
    if missing:
        # Report incomplete groups instead of dropping them silently
        print(f"Skipping {base_name}, missing chip(s): {', '.join(missing)}")
        continue

    with ExitStack() as stack:
        rasters = {
            suffix: stack.enter_context(rasterio.open(files[suffix]))
            for suffix in required_suffixes
        }
        chip0, chip1, chip2, chip3 = (rasters[suffix] for suffix in required_suffixes)
        meta = chip0.meta.copy()

        # Rebuild the transform from chip 0's pixel size and top-left corner
        # (any rotation/shear terms are set to zero)
        chip0_transform = chip0.transform
        corrected_transform = rasterio.transform.Affine(
            chip0_transform.a, 0, chip0_transform.c,
            0, chip0_transform.e, chip0_transform.f
        )

        # Mosaic size: two chips wide, two chips high
        meta.update({
            "width": chip0.width + chip2.width,
            "height": chip0.height + chip1.height,
            "transform": corrected_transform
        })

        # Create an empty mosaic array
        mosaic = np.zeros((meta["count"], meta["height"], meta["width"]), dtype=meta["dtype"])

        # Place each raster in the correct position
        mosaic[:, 0:chip0.height, 0:chip0.width] = chip0.read()
        mosaic[:, chip0.height:, 0:chip0.width] = chip1.read()
        mosaic[:, 0:chip2.height, chip0.width:] = chip2.read()
        mosaic[:, chip2.height:, chip0.width:] = chip3.read()

        # Save the mosaic
        output_path_qgis = os.path.join(output_dir, f"{base_name}.tif")
        with rasterio.open(output_path_qgis, "w", **meta) as dst:
            dst.write(mosaic)

    # Store output file path
    output_files.append(output_path_qgis)
    print(f"Corrected mosaic saved: {output_path_qgis}")

# Print all output file paths
print("\nAll mosaicked files saved:")
for file in output_files:
    print(file)


### Step 3: Georeference again using the ground truth data, to match the ground truth extent and spatial resolution

In [None]:

# Input/output locations for re-georeferencing the mosaics
reference_folder = r"D:\CVPR\Flood_Results\cambodia\PS_image"  # reference GeoTIFFs
mask_folder = r"D:\CVPR\Flood_Results\cambodia\Clay_PS\georeferenced_masks\mosaicked_outputs"  # mosaicked predictions
output_folder = r"D:\CVPR\Flood_Results\cambodia\Clay_PS\georeferenced_masks\mosaicked_outputs\geo"  # corrected outputs
os.makedirs(output_folder, exist_ok=True)

# Index the reference rasters by file name with ".tif" removed
ref_images = {}
for ref_name in os.listdir(reference_folder):
    if ref_name.endswith(".tif"):
        ref_images[ref_name.replace(".tif", "")] = os.path.join(reference_folder, ref_name)

# Only mosaics named pred_*.tif are processed
mask_files = [
    tif_name for tif_name in os.listdir(mask_folder)
    if tif_name.startswith("pred_") and tif_name.endswith(".tif")
]

# Show what was discovered before processing
print("📂 Reference images found:", list(ref_images.keys()))
print("📂 Predicted mask images found:", mask_files)

# The reference name is whatever sits between "pred_" and ".tif"
mosaic_name_pattern = re.compile(r"pred_(.+)\.tif")

for mask_file in mask_files:
    base_name_match = mosaic_name_pattern.match(mask_file)
    if base_name_match is None:
        print(f"⚠️ Skipping {mask_file}, could not extract base name.")
        continue

    base_name = base_name_match.group(1)
    if base_name not in ref_images:
        print(f"⚠️ Skipping {mask_file}, no matching reference image found for: {base_name}.tif")
        continue

    ref_image_path = ref_images[base_name]
    mask_image_path = os.path.join(mask_folder, mask_file)
    output_path = os.path.join(output_folder, f"{base_name}.tif")  # output takes the reference name

    print(f"🔄 Processing: {mask_file} → Saving as: {output_path}")

    # Pull CRS and extent from the matching reference raster
    with rasterio.open(ref_image_path) as ref:
        ref_crs, ref_transform, ref_bounds = ref.crs, ref.transform, ref.bounds

    # Read band 1 of the mosaicked prediction
    with rasterio.open(mask_image_path) as mask:
        mosaic_pixels = mask.read(1)

    # Nothing to write if the mosaic holds no pixels
    if mosaic_pixels is None or mosaic_pixels.size == 0:
        print(f"⚠️ Warning: Empty mask file detected: {mask_file}")
        continue

    n_rows, n_cols = mosaic_pixels.shape

    # Stretch the mosaic grid over the full extent of the reference raster
    stretched_transform = from_bounds(
        ref_bounds.left, ref_bounds.bottom, ref_bounds.right, ref_bounds.top,
        n_cols, n_rows
    )

    # Write the mosaic out as a one-band GeoTIFF carrying the reference CRS
    geotiff_profile = dict(
        driver="GTiff",
        height=n_rows,
        width=n_cols,
        count=1,
        dtype=mosaic_pixels.dtype,
        crs=ref_crs,
        transform=stretched_transform,
    )
    with rasterio.open(output_path, "w", **geotiff_profile) as dst:
        dst.write(mosaic_pixels, 1)

    print(f"✅ Saved georeferenced mask: {output_path}")

print("🎉 Processing complete for all images.")


### Georeferencing Prithvi outputs

In [None]:
from rasterio.enums import Resampling

# Input/output locations for the Prithvi outputs
input_folder = r"D:\CVPR\Flood_Results\cambodia\Prithvi_S2"  # images to transform
ref_folder = r"D:\CVPR\Flood_Results\cambodia\ground_truth"  # reference rasters
output_folder = os.path.join(input_folder, "georeferenced_and_resized")
os.makedirs(output_folder, exist_ok=True)

# Every GeoTIFF directly inside the input folder is a candidate
input_files = [tif_name for tif_name in os.listdir(input_folder) if tif_name.endswith(".tif")]

for input_file in input_files:
    # Files whose name starts with "mask" are treated as reference masks, not inputs
    if input_file.startswith("mask"):
        continue

    # The reference raster is looked up under the same file name as the input
    # (no suffix is added)
    reference_file = input_file
    ref_image_path = os.path.join(ref_folder, reference_file)
    input_image_path = os.path.join(input_folder, input_file)
    output_path = os.path.join(output_folder, input_file)

    if not os.path.exists(ref_image_path):
        print(f"Skipping {input_file}, reference image not found: {reference_file}")
        continue

    # Take CRS, transform and pixel dimensions from the reference raster
    with rasterio.open(ref_image_path) as ref:
        ref_crs, ref_transform = ref.crs, ref.transform
        ref_width, ref_height = ref.width, ref.height

    # Resize the input to the reference grid.
    # NOTE(review): bilinear resampling blends neighbouring values; if these are
    # class masks, nearest-neighbour may be the intended filter - confirm.
    with Image.open(input_image_path) as img:
        img_resized = img.resize((ref_width, ref_height), Image.BILINEAR)
        img_data = np.array(img_resized)

    # Write the resized image as a one-band GeoTIFF on the reference grid
    geotiff_profile = dict(
        driver="GTiff",
        height=ref_height,
        width=ref_width,
        count=1,  # a single-band input is assumed
        dtype=img_data.dtype,
        crs=ref_crs,
        transform=ref_transform,
    )
    with rasterio.open(output_path, "w", **geotiff_profile) as dst:
        dst.write(img_data, 1)

    print(f"✅ Saved transformed and resized image: {output_path}")

print("🎉 Processing complete for all images.")


### Plotting and saving 

In [None]:
import matplotlib.pyplot as plt

# Root of the plotting inputs; point this at the extracted results folder
base_dir = r"D:\CVPR\Flood_Results\plot"
# Input scenes and their ground-truth masks live in sibling sub-folders
image_dir, ground_truth_dir = (
    os.path.join(base_dir, subfolder) for subfolder in ("image", "ground_truth")
)
# One sub-folder of base_dir per model, in the order the panels are drawn
model_folders = [
    "Atten_UNet",
    "Prithvi_600m",
    "Transnorm",
    "UNet",
    "UViT_100",
    "UViT_600",
    "Clay",
]

# Function to read and visualize an image using a 4-3-2 band combination
def read_image_bands(image_path):
    """Read a raster and return an (H, W, 3) float array scaled to [0, 1].

    Rasters with four or more bands are shown as a 4-3-2 composite; anything
    with fewer bands is shown as grayscale by repeating band 1.

    Each band is divided by its own maximum. A band whose maximum is 0 is
    returned as all zeros instead of being divided (0/0 would produce NaN).

    Parameters
    ----------
    image_path : str
        Path of the raster, passed to ``rasterio.open``.

    Returns
    -------
    numpy.ndarray
        Stack of the three scaled bands along the last axis.
    """
    def _scale_by_max(band):
        peak = band.max()
        if peak == 0:
            # All-zero band: avoid 0/0 -> NaN
            return np.zeros(band.shape, dtype=float)
        return np.clip(band / peak, 0, 1)

    with rasterio.open(image_path) as src:
        if src.count >= 4:
            r, g, b = src.read(4), src.read(3), src.read(2)
        else:
            r = g = b = src.read(1)

    return np.dstack([_scale_by_max(r), _scale_by_max(g), _scale_by_max(b)])

# Function to read a grayscale image (for ground truth and similar single-band images)
def read_grayscale_image(image_path):
    """Read band 1 of a raster and scale it to [0, 1] by its maximum.

    A band whose maximum is 0 is returned as all zeros instead of being
    divided (0/0 would produce NaN).

    Parameters
    ----------
    image_path : str
        Path of the raster, passed to ``rasterio.open``.

    Returns
    -------
    numpy.ndarray
        2-D float array with values clipped to [0, 1].
    """
    with rasterio.open(image_path) as src:
        band = src.read(1)

    peak = band.max()
    if peak == 0:
        # All-zero band: avoid 0/0 -> NaN
        return np.zeros(band.shape, dtype=float)
    return np.clip(band / peak, 0, 1)

# Collect the GeoTIFF scenes to plot, in alphabetical order
image_files = sorted(name for name in os.listdir(image_dir) if name.endswith(".tif"))

# One figure per scene: input composite, ground truth, then one panel per model
for img_file in image_files:
    image_rgb = read_image_bands(os.path.join(image_dir, img_file))
    ground_truth_gray = read_grayscale_image(os.path.join(ground_truth_dir, img_file))

    # Each model's output is expected under base_dir/<model>/<same file name>
    model_images = [
        read_image_bands(os.path.join(base_dir, model, img_file))
        for model in model_folders
    ]

    fig, axes = plt.subplots(1, len(model_folders) + 2, figsize=(25, 5))
    source_ax, truth_ax = axes[0], axes[1]

    # Panel 0: the input scene
    source_ax.imshow(image_rgb)
    source_ax.set_title(f"Original: {img_file} (4-3-2)")
    source_ax.axis("off")

    # Panel 1: the ground truth in grayscale
    truth_ax.imshow(ground_truth_gray, cmap="gray")
    truth_ax.set_title("Ground Truth (Grayscale)")
    truth_ax.axis("off")

    # Remaining panels: one per model, titled with the model folder name
    for model_ax, model_rgb, title in zip(axes[2:], model_images, model_folders):
        model_ax.imshow(model_rgb)
        model_ax.set_title(f"{title} (4-3-2)")
        model_ax.axis("off")

    plt.tight_layout()
    plt.show()


## Saving the plots with a different stretch

In [None]:
# Root of the plotting inputs; point this at the extracted results folder
base_dir = r"D:\CVPR\Flood_Results\plot"
# Input scenes and their ground-truth masks live in sibling sub-folders
image_dir, ground_truth_dir = (
    os.path.join(base_dir, subfolder) for subfolder in ("image", "ground_truth")
)
# One sub-folder of base_dir per model, in the order the columns are drawn
model_folders = [
    "Prithvi-UperNet",
    "Clay",
    "UViT100m",
    "UViT600m",
    "Transnorm",
    "UNet",
    "Atten_UNet",
]

# Saved figures go here; the folder is created on first run
output_dir = os.path.join(base_dir, "saved_plots")
os.makedirs(output_dir, exist_ok=True)

# Function to read and visualize an image using a 4-3-2 band combination
def read_image_bands(image_path):
    """Read a raster and return an (H, W, 3) array stretched to [0, 1].

    Band selection depends on the band count: 4 or more bands use 4-3-2,
    exactly 3 bands use 3-2-1, and a single band is repeated as grayscale.
    Any other band count raises ``ValueError``.

    Each selected band is stretched independently between its 2nd and 98th
    percentiles and clipped to [0, 1].
    """
    def _percentile_stretch(band):
        low, high = np.percentile(band, (2, 98))
        # The small epsilon keeps the denominator non-zero for constant bands
        scaled = (band - low) / (high - low + 1e-6)
        return np.clip(scaled, 0, 1)

    with rasterio.open(image_path) as src:
        num_bands = src.count
        print(f"Reading {image_path} - Bands available: {num_bands}")  # Debugging print

        if num_bands >= 4:
            band_indices = (4, 3, 2)
        elif num_bands == 3:
            band_indices = (3, 2, 1)
        elif num_bands == 1:
            band_indices = (1,)
        else:
            raise ValueError(f"Unexpected band count: {num_bands} in {image_path}")

        selected = [src.read(index) for index in band_indices]
        if len(selected) == 1:
            # Single band: reuse the same array for all three channels
            selected = selected * 3

        channels = [_percentile_stretch(band) for band in selected]

    return np.dstack(channels)



# Function to read a grayscale image (for ground truth and similar single-band images)
def read_grayscale_image(image_path):
    """Read band 1 of a raster and scale it to [0, 1] by its maximum.

    A band whose maximum is 0 is returned as all zeros instead of being
    divided (0/0 would produce NaN).

    Parameters
    ----------
    image_path : str
        Path of the raster, passed to ``rasterio.open``.

    Returns
    -------
    numpy.ndarray
        2-D float array with values clipped to [0, 1].
    """
    with rasterio.open(image_path) as src:
        band = src.read(1)

    peak = band.max()
    if peak == 0:
        # All-zero band: avoid 0/0 -> NaN
        return np.zeros(band.shape, dtype=float)
    return np.clip(band / peak, 0, 1)

# Rectangle draws the thin border around every panel. It is imported here
# because no earlier visible cell imports it, and the loop below would
# otherwise raise NameError on a fresh kernel.
from matplotlib.patches import Rectangle

# Get list of image filenames
image_files = sorted([f for f in os.listdir(image_dir) if f.endswith(".tif")])

# Process images in batches of 4 (one figure per batch, one row per image)
batch_size = 4
num_batches = (len(image_files) + batch_size - 1) // batch_size  # ceiling division

# Column headers, shown on the top row of each figure only
column_titles = ["Image", "Ground Truth"] + model_folders

for batch_idx in range(num_batches):
    batch_files = image_files[batch_idx * batch_size : (batch_idx + 1) * batch_size]
    num_rows = len(batch_files)
    num_cols = len(model_folders) + 2  # Original + Ground Truth + Models

    # Figure height grows with the number of rows; squeeze=False keeps `axes`
    # two-dimensional even when the last batch has a single row
    fig, axes = plt.subplots(
        num_rows, num_cols, figsize=(30, 3 * num_rows), squeeze=False,
        gridspec_kw={'hspace': 0.02, 'wspace': 0.01}  # Minimized gaps
    )

    for row, img_file in enumerate(batch_files):
        img_path = os.path.join(image_dir, img_file)
        gt_path = os.path.join(ground_truth_dir, img_file)

        # Read image and ground truth
        image_rgb = read_image_bands(img_path)
        ground_truth_gray = read_grayscale_image(gt_path)

        # Read model outputs (expected under base_dir/<model>/<same file name>)
        model_images = [
            read_image_bands(os.path.join(base_dir, model, img_file))
            for model in model_folders
        ]

        # (array, colormap) for each column: image, ground truth, then the models
        panels = [(image_rgb, None), (ground_truth_gray, "gray")]
        panels += [(model_rgb, None) for model_rgb in model_images]

        for ax, (panel, cmap) in zip(axes[row], panels):
            ax.imshow(panel, cmap=cmap)
            ax.axis("off")
            # Thin black frame in axes coordinates so it hugs the panel
            ax.add_patch(Rectangle((0, 0), 1, 1, transform=ax.transAxes,
                                   color="black", fill=False, lw=0.5))

    # Set column titles only on the top row
    for col, title in enumerate(column_titles):
        axes[0][col].set_title(title, fontsize=18, fontweight='bold', pad=5)

    # Save the batch plot
    output_path = os.path.join(output_dir, f"batch_{batch_idx + 1}.png")
    fig.savefig(output_path, dpi=500, bbox_inches='tight', pad_inches=0.1)
    plt.close(fig)  # Close the figure to free memory

print(f"All plots saved in {output_dir}")


### Plotting single image from each sensor

In [None]:
## with border
import matplotlib.patches as patches  # Import patches for adding borders

# Root folder holding one sub-folder per sensor, per model, and the ground truth
base_dir = r"D:\CVPR\Flood_Results\ghana"

# Sensor key -> sub-folder with that sensor's images (order fixes the plot rows)
sensor_folders = {}
sensor_folders["PS"] = "PS_image"
sensor_folders["S2"] = "Sen2_image"
sensor_folders["S1"] = "Sen1_image"

# Sensor key -> sub-folder of each model's predictions for that sensor
model_mapping = {}
model_mapping["PS"] = {"Clay": "Clay_PS", "Prithvi": "Prithvi_PS"}
model_mapping["S2"] = {"Clay": "Clay_S2", "Prithvi": "Prithvi_S2"}
model_mapping["S1"] = {"Clay": "Clay_S1", "Prithvi": "Prithvi_S1"}

# Ground truth masks, shared by all sensors
ground_truth_dir = os.path.join(base_dir, "ground_truth")

# Saved figures go here; the folder is created on first run
output_dir = os.path.join(base_dir, "saved_plots")
os.makedirs(output_dir, exist_ok=True)

# Function to read and visualize an image using a robust 2-98 percentile normalization
def read_image_bands(image_path):
    """Read a raster and return an (H, W, 3) array stretched to [0, 1].

    Band selection depends on the band count:

    * 3 or more bands: bands 3-2-1 are used as R-G-B.
    * 2 bands: band 1 is red, band 2 is green, and blue is their mean. The
      mean is computed in float64 so unsigned integer bands cannot wrap
      around when added.
    * 1 band: the band is repeated as grayscale.

    Any other band count raises ``ValueError``. Each channel is stretched
    independently between its 2nd and 98th percentiles and clipped to [0, 1].
    """
    def normalize_band(band):
        p2, p98 = np.percentile(band, (2, 98))  # Robust scaling
        return np.clip((band - p2) / (p98 - p2 + 1e-6), 0, 1)  # Avoid zero division

    with rasterio.open(image_path) as src:
        num_bands = src.count
        print(f"Reading {image_path} - Bands available: {num_bands}")  # Debugging print

        if num_bands >= 3:
            r, g, b = src.read(3), src.read(2), src.read(1)  # 3-2-1 composite
        elif num_bands == 2:
            r, g = src.read(1), src.read(2)  # Use both bands
            # Cast before adding: e.g. uint8 200 + 200 would wrap to 144
            b = (r.astype(np.float64) + g.astype(np.float64)) / 2
        elif num_bands == 1:
            r = g = b = src.read(1)  # Single band shown as grayscale
        else:
            raise ValueError(f"Unexpected band count: {num_bands} in {image_path}")

    return np.dstack([normalize_band(r), normalize_band(g), normalize_band(b)])

# Function to read a grayscale image
def read_grayscale_image(image_path):
    """Read band 1 of a raster and scale it to [0, 1] by its maximum.

    A band whose maximum is 0 is returned as all zeros instead of being
    divided (0/0 would produce NaN).

    Parameters
    ----------
    image_path : str
        Path of the raster, passed to ``rasterio.open``.

    Returns
    -------
    numpy.ndarray
        2-D float array with values clipped to [0, 1].
    """
    with rasterio.open(image_path) as src:
        band = src.read(1)

    peak = band.max()
    if peak == 0:
        # All-zero band: avoid 0/0 -> NaN
        return np.zeros(band.shape, dtype=float)
    return np.clip(band / peak, 0, 1)

# Scene list is taken from the PS folder. Only GeoTIFFs are kept so that
# sidecar or other stray files in the folder are not passed to rasterio,
# and so the ".tif" -> ".png" output name below always applies.
image_files = sorted(
    f for f in os.listdir(os.path.join(base_dir, sensor_folders["PS"]))
    if f.endswith(".tif")
)

# Row titles corresponding to sensors (same order as sensor_folders)
row_titles = ["PlanetScope", "Sentinel 2", "Sentinel 1"]
# Extra upward shift (in figure coordinates) applied to some row titles
row_title_offsets = {"Sentinel 2": 0.05, "Sentinel 1": 0.075}

column_titles = ["Image", "Ground truth", "Clay", "Prithvi"]

# Process each image one at a time: one figure, one row per sensor
for img_file in image_files:
    fig, axes = plt.subplots(
        len(sensor_folders), 4, figsize=(14.5, 3.5 * len(sensor_folders)),
        gridspec_kw={'hspace': 0.02, 'wspace': -0.4}  # Reduce whitespace
    )

    for row, (sensor, folder) in enumerate(sensor_folders.items()):
        img_path = os.path.join(base_dir, folder, img_file)
        gt_path = os.path.join(ground_truth_dir, img_file)

        if not os.path.exists(img_path):
            print(f"Skipping {img_file} for {sensor}, file not found.")
            # Leave the row blank rather than showing empty axes with ticks
            for ax in axes[row]:
                ax.axis("off")
            continue

        # Read images
        image_rgb = read_image_bands(img_path)
        ground_truth_gray = read_grayscale_image(gt_path)
        clay_rgb = read_image_bands(os.path.join(base_dir, model_mapping[sensor]["Clay"], img_file))
        prithvi_rgb = read_image_bands(os.path.join(base_dir, model_mapping[sensor]["Prithvi"], img_file))

        # Plot images
        axes[row][0].imshow(image_rgb)
        axes[row][1].imshow(ground_truth_gray, cmap="gray")
        axes[row][2].imshow(clay_rgb)
        axes[row][3].imshow(prithvi_rgb)

        # Add a black border to each panel and hide the axes
        for col in range(4):
            rect = patches.Rectangle(
                (0, 0), 1, 1,  # (x, y) and width, height in axes coordinates
                transform=axes[row][col].transAxes,
                linewidth=3, edgecolor='black', facecolor='none'
            )
            axes[row][col].add_patch(rect)
            axes[row][col].axis("off")

    # Set column titles on the top row
    for col, title in enumerate(column_titles):
        axes[0][col].set_title(title, fontsize=18, fontweight='bold', pad=10)

    # Row titles are drawn as figure text, so they stay visible with the axes off
    for i, title in enumerate(row_titles):
        y_offset = row_title_offsets.get(title, 0)
        fig.text(-0.02, 1 - (i + 0.5) / len(row_titles) + y_offset, title,
                 fontsize=18, fontweight='bold', ha='right', va='center', rotation=90)

    # Slightly adjust layout to prevent cutting off labels
    fig.subplots_adjust(left=-0.095, top=0.95)

    # Save the figure
    output_path = os.path.join(output_dir, f"{img_file.replace('.tif', '.png')}")
    fig.savefig(output_path, dpi=300, bbox_inches='tight', pad_inches=0.1)
    plt.close(fig)  # Free memory

print(f"All plots saved in {output_dir}")


### Post-processing to define the CRS of the UNet, UViT, and TransNorm outputs

In [None]:
# Input/output locations for this run
mask_input_folder = r"D:\CVPR\Flood_Results\Spain\UViT_600"  # folder with *_msk_pred.png
ref_input_folder = r"D:\CVPR\Flood_Results\Spain\ground_truth"  # folder with *_lc.tif reference rasters
output_folder = os.path.join(mask_input_folder, "georeferenced_masks")
os.makedirs(output_folder, exist_ok=True)

# Known reference names: every .tif file name with "_lc.tif" removed
ref_images = set()
for ref_name in os.listdir(ref_input_folder):
    if ref_name.endswith(".tif"):
        ref_images.add(ref_name.replace("_lc.tif", ""))

# Predictions are the PNGs ending in "_msk_pred.png"
mask_files = [png_name for png_name in os.listdir(mask_input_folder) if png_name.endswith("_msk_pred.png")]

# The base name may be wrapped in optional brackets/quotes: ['<base>.tif']_msk_pred.png
pred_name_pattern = re.compile(r"\[?'?([\w\d_]+)\.tif'?\]?_msk_pred\.png")

for mask_file in mask_files:
    base_name_match = pred_name_pattern.search(mask_file)
    if base_name_match is None:
        print(f"Skipping {mask_file}, could not extract base name.")
        continue

    base_name = base_name_match.group(1)
    if base_name not in ref_images:
        print(f"Skipping {mask_file}, no matching reference image found for: {base_name}_lc.tif")
        continue

    ref_image_path = os.path.join(ref_input_folder, f"{base_name}_lc.tif")
    mask_image_path = os.path.join(mask_input_folder, mask_file)
    output_path = os.path.join(output_folder, f"{base_name}_msk_pred.tif")

    # Pull CRS and extent from the matching reference raster
    with rasterio.open(ref_image_path) as ref:
        ref_crs, ref_transform, ref_bounds = ref.crs, ref.transform, ref.bounds

    # Load the prediction as a single-band (grayscale) array
    with Image.open(mask_image_path) as mask_img:
        mask_pixels = np.array(mask_img.convert("L"))

    n_rows, n_cols = mask_pixels.shape

    # Stretch the mask grid over the full extent of the reference raster
    stretched_transform = from_bounds(
        ref_bounds.left, ref_bounds.bottom, ref_bounds.right, ref_bounds.top,
        n_cols, n_rows
    )

    # Write the mask out as a one-band GeoTIFF carrying the reference CRS
    geotiff_profile = dict(
        driver="GTiff",
        height=n_rows,
        width=n_cols,
        count=1,
        dtype=mask_pixels.dtype,
        crs=ref_crs,
        transform=stretched_transform,
    )
    with rasterio.open(output_path, "w", **geotiff_profile) as dst:
        dst.write(mask_pixels, 1)

    print(f"✅ Saved georeferenced mask: {output_path}")

print("🎉 Processing complete for all images.")


### Minor Changes in saving SAR (sen1) outputs

In [None]:

# Root folder holding one sub-folder per sensor, per model, and the ground truth
base_dir = r"D:\CVPR\Flood_Results\Spain\main_figure"

# Sensor key -> sub-folder with that sensor's images (order fixes the plot rows)
sensor_folders = {}
sensor_folders["PS"] = "PS_image"
sensor_folders["S2"] = "Sen2_image"
sensor_folders["S1"] = "Sen1_image"

# Sensor key -> sub-folder of each model's predictions for that sensor
model_mapping = {}
model_mapping["PS"] = {"Clay": "Clay_PS", "Prithvi": "Prithvi_PS"}
model_mapping["S2"] = {"Clay": "Clay_S2", "Prithvi": "Prithvi_S2"}
model_mapping["S1"] = {"Clay": "Clay_S1", "Prithvi": "Prithvi_S1"}

# Ground truth masks, shared by all sensors
ground_truth_dir = os.path.join(base_dir, "ground_truth")

# Saved figures go here; the folder is created on first run
output_dir = os.path.join(base_dir, "saved_plots")
os.makedirs(output_dir, exist_ok=True)

# Function to read and visualize an image using a 4-3-2 band combination
def read_image_bands(image_path):
    """Read a raster and return an (H, W, 3) float array scaled to [0, 1].

    Rasters with four or more bands are shown as a 4-3-2 composite; anything
    with fewer bands is shown as grayscale by repeating band 1.

    Each band is divided by its own maximum. A band whose maximum is 0 is
    returned as all zeros instead of being divided (0/0 would produce NaN).

    Parameters
    ----------
    image_path : str
        Path of the raster, passed to ``rasterio.open``.

    Returns
    -------
    numpy.ndarray
        Stack of the three scaled bands along the last axis.
    """
    def _scale_by_max(band):
        peak = band.max()
        if peak == 0:
            # All-zero band: avoid 0/0 -> NaN
            return np.zeros(band.shape, dtype=float)
        return np.clip(band / peak, 0, 1)

    with rasterio.open(image_path) as src:
        if src.count >= 4:
            r, g, b = src.read(4), src.read(3), src.read(2)
        else:
            r = g = b = src.read(1)

    return np.dstack([_scale_by_max(r), _scale_by_max(g), _scale_by_max(b)])

# Function to read a grayscale image (for ground truth and similar single-band images)
def read_grayscale_image(image_path):
    """Read band 1 of a raster and scale it to [0, 1] by its maximum.

    A band whose maximum is 0 is returned as all zeros instead of being
    divided (0/0 would produce NaN).

    Parameters
    ----------
    image_path : str
        Path of the raster, passed to ``rasterio.open``.

    Returns
    -------
    numpy.ndarray
        2-D float array with values clipped to [0, 1].
    """
    with rasterio.open(image_path) as src:
        band = src.read(1)

    peak = band.max()
    if peak == 0:
        # All-zero band: avoid 0/0 -> NaN
        return np.zeros(band.shape, dtype=float)
    return np.clip(band / peak, 0, 1)

# Scene list is taken from the PS folder. Only GeoTIFFs are kept so that
# sidecar or other stray files in the folder are not passed to rasterio,
# and so the ".tif" -> ".png" output name below always applies.
image_files = sorted(
    f for f in os.listdir(os.path.join(base_dir, sensor_folders["PS"]))
    if f.endswith(".tif")
)

column_titles = ["Image", "GT", "Clay", "Prithvi"]

# Process each image one at a time: one figure, one row per sensor
for img_file in image_files:
    fig, axes = plt.subplots(
        len(sensor_folders), 4, figsize=(16, 3.5 * len(sensor_folders)),
        gridspec_kw={'hspace': 0.02, 'wspace': 0.05}  # Reduce whitespace
    )

    for row, (sensor, folder) in enumerate(sensor_folders.items()):
        img_path = os.path.join(base_dir, folder, img_file)
        gt_path = os.path.join(ground_truth_dir, img_file)

        if not os.path.exists(img_path):
            print(f"Skipping {img_file} for {sensor}, file not found.")
            # Leave the row blank rather than showing empty axes with ticks
            for ax in axes[row]:
                ax.axis("off")
            continue

        # Read images
        image_rgb = read_image_bands(img_path)
        ground_truth_gray = read_grayscale_image(gt_path)
        clay_rgb = read_image_bands(os.path.join(base_dir, model_mapping[sensor]["Clay"], img_file))
        prithvi_rgb = read_image_bands(os.path.join(base_dir, model_mapping[sensor]["Prithvi"], img_file))

        # Plot images
        axes[row, 0].imshow(image_rgb)
        axes[row, 1].imshow(ground_truth_gray, cmap="gray")
        axes[row, 2].imshow(clay_rgb)
        axes[row, 3].imshow(prithvi_rgb)

        # First column carries the sensor name as its y-label. axis("off")
        # would hide the label together with the ticks and spines, so only
        # the ticks and spines are hidden here.
        label_ax = axes[row, 0]
        label_ax.set_ylabel(sensor, fontsize=18, fontweight='bold', rotation=90, labelpad=10)
        label_ax.set_xticks([])
        label_ax.set_yticks([])
        for spine in label_ax.spines.values():
            spine.set_visible(False)

        # Remaining columns need no decorations at all
        for col in range(1, 4):
            axes[row, col].axis("off")

    # Set column titles on the top row
    for col, title in enumerate(column_titles):
        axes[0][col].set_title(title, fontsize=18, fontweight='bold', pad=10)

    # Slightly adjust layout to prevent cutting off labels
    fig.subplots_adjust(left=0.1, top=0.95)

    # Save the figure
    output_path = os.path.join(output_dir, f"{img_file.replace('.tif', '.png')}")
    fig.savefig(output_path, dpi=400, bbox_inches='tight', pad_inches=0.1)
    plt.close(fig)  # Free memory

print(f"All plots saved in {output_dir}")


In [None]:

# Root folder holding one sub-folder per sensor, per model, and the ground truth
base_dir = r"D:\CVPR\Flood_Results\Spain\main_figure"

# Sensor key -> sub-folder with that sensor's images (order fixes the plot rows)
sensor_folders = dict([
    ("PS", "PS_image"),
    ("S2", "Sen2_image"),
    ("S1", "Sen1_image"),
])

# Sensor key -> sub-folder of each model's predictions for that sensor
model_mapping = dict([
    ("PS", {"Clay": "Clay_PS", "Prithvi": "Prithvi_PS"}),
    ("S2", {"Clay": "Clay_S2", "Prithvi": "Prithvi_S2"}),
    ("S1", {"Clay": "Clay_S1", "Prithvi": "Prithvi_S1"}),
])

# Ground truth masks, shared by all sensors
ground_truth_dir = os.path.join(base_dir, "ground_truth")

# Saved figures go here; the folder is created on first run
output_dir = os.path.join(base_dir, "saved_plots")
os.makedirs(output_dir, exist_ok=True)

def read_image_bands(image_path):
    """Read a raster and return a 3-channel float array ready for ``imshow``.

    When the raster has at least four bands, bands 4, 3 and 2 are used as the
    red, green and blue channels. Otherwise band 1 is replicated into all three
    channels, so single-band rasters render as grayscale.

    Each channel is scaled independently by its own maximum and clipped to
    [0, 1]. A channel whose maximum is 0 (e.g. an all-zero mask) is returned as
    zeros rather than being divided by zero, which would yield NaNs.

    Parameters
    ----------
    image_path : str
        Path to a raster readable by ``rasterio.open``.

    Returns
    -------
    numpy.ndarray
        Array of shape (height, width, 3) with values in [0, 1].
    """
    with rasterio.open(image_path) as src:
        if src.count >= 4:
            r, g, b = src.read(4), src.read(3), src.read(2)
        else:
            r = g = b = src.read(1)

    def _scale(band):
        peak = band.max()
        if peak == 0:
            # 0 / 0 would turn the whole channel into NaN; keep it black instead
            return np.zeros(band.shape, dtype=float)
        return np.clip(band / peak, 0, 1)

    return np.dstack([_scale(band) for band in (r, g, b)])

def read_grayscale_image(image_path):
    """Read band 1 of a raster and scale it to [0, 1] for grayscale display.

    The band is divided by its own maximum and clipped to [0, 1]. A band whose
    maximum is 0 (e.g. an all-zero mask) is returned as zeros rather than being
    divided by zero, which would yield NaNs.

    Parameters
    ----------
    image_path : str
        Path to a raster readable by ``rasterio.open``.

    Returns
    -------
    numpy.ndarray
        2-D array with values in [0, 1].
    """
    with rasterio.open(image_path) as src:
        band = src.read(1)

    peak = band.max()
    if peak == 0:
        # 0 / 0 would turn the whole image into NaN; keep it black instead
        return np.zeros(band.shape, dtype=float)
    return np.clip(band / peak, 0, 1)

# The PlanetScope folder defines which figures are built. Keep only GeoTIFFs so
# that other files in the folder are never handed to rasterio.
image_files = sorted(
    name
    for name in os.listdir(os.path.join(base_dir, sensor_folders["PS"]))
    if name.lower().endswith((".tif", ".tiff"))
)

# Row titles corresponding to sensors, in the same order as sensor_folders
row_titles = ["PlanetScope", "Sentinel 2", "Sentinel 1"]

# Hand-tuned upward shifts (figure fraction) applied to individual row titles
row_title_y_offsets = {"Sentinel 2": 0.05, "Sentinel 1": 0.075}

# Column titles, one per panel in a row
column_titles = ["Image", "Ground truth", "Clay", "Prithvi"]

# Build and save one figure per image: one row per sensor, one column per panel
for img_file in image_files:
    fig, axes = plt.subplots(
        len(sensor_folders), 4, figsize=(15, 3.5 * len(sensor_folders)),
        gridspec_kw={'hspace': 0.02, 'wspace': -0.39}  # Reduce whitespace
    )

    # Hide ticks and frames on every panel up front, so a row that gets
    # skipped below is left blank instead of showing empty framed axes.
    for ax in axes.flat:
        ax.axis("off")

    for row, (sensor, folder) in enumerate(sensor_folders.items()):
        # All four inputs needed for this row: sensor image, ground truth,
        # Clay prediction and Prithvi prediction.
        panel_paths = [
            os.path.join(base_dir, folder, img_file),
            os.path.join(ground_truth_dir, img_file),
            os.path.join(base_dir, model_mapping[sensor]["Clay"], img_file),
            os.path.join(base_dir, model_mapping[sensor]["Prithvi"], img_file),
        ]

        # Skip the row if any input is missing rather than failing mid-figure
        missing = [path for path in panel_paths if not os.path.exists(path)]
        if missing:
            print(f"Skipping {img_file} for {sensor}, file not found: {', '.join(missing)}")
            continue

        img_path, gt_path, clay_path, prithvi_path = panel_paths

        # Row labels are drawn with fig.text further down; an Axes y-label
        # would not be visible because axis("off") hides axis labels.
        axes[row, 0].imshow(read_image_bands(img_path))
        axes[row, 1].imshow(read_grayscale_image(gt_path), cmap="gray")
        axes[row, 2].imshow(read_image_bands(clay_path))
        axes[row, 3].imshow(read_image_bands(prithvi_path))

    # Set column titles on the top row only
    for col, title in enumerate(column_titles):
        axes[0, col].set_title(title, fontsize=18, fontweight='bold', pad=10)

    # Row titles: vertically centred on an even split of the figure height,
    # plus the per-title upward shift defined above
    for i, title in enumerate(row_titles):
        y_pos = 1 - (i + 0.5) / len(row_titles) + row_title_y_offsets.get(title, 0)
        fig.text(-0.02, y_pos, title,
                 fontsize=18, fontweight='bold', ha='right', va='center', rotation=90)

    # Slightly adjust layout to prevent cutting off labels
    fig.subplots_adjust(left=-0.095, top=0.95)

    # Save the figure under the image's base name with a .png extension
    output_path = os.path.join(output_dir, os.path.splitext(img_file)[0] + ".png")
    fig.savefig(output_path, dpi=400, bbox_inches='tight', pad_inches=0.1)
    plt.close(fig)  # Free memory

print(f"All plots saved in {output_dir}")


## Plotting data size vs IOU and model's performance on smaller data

In [None]:

# Per-epoch training logs for each model; the code below reads the
# "epoch" and "time_seconds" columns from each file.
clay_file_path = r"D:\CVPR\Flood_Results/Clay.csv"
prithvi_file_path = r"D:\CVPR\Flood_Results/Prithvi.csv"

# Read CSV files
clay_df = pd.read_csv(clay_file_path)
prithvi_df = pd.read_csv(prithvi_file_path)

# Convert time from seconds to minutes
clay_df["time_minutes"] = clay_df["time_seconds"] / 60
prithvi_df["time_minutes"] = prithvi_df["time_seconds"] / 60

# Training data sizes
x_values = [2, 5, 10, 20, 50, 100, 150, 200, 250, 300, 350]

# IOU values for different models (one entry per training data size)
clay_iou = [0.58, 0.64, 0.64, 0.64, 0.754, 0.7553, 0.7524, 0.7563, 0.7157, 0.7095, 0.7809]
prithvi_iou = [0.15, 0.24, 0.44, 0.45, 0.59, 0.62, 0.65, 0.67, 0.65, 0.66, 0.66]

# Create figure and subplots (vertically arranged)
fig, axes = plt.subplots(2, 1, figsize=(8, 10))

# Define colors for consistency
clay_color = 'blue'
prithvi_color = 'red'

# First subplot: Training Time Comparison
axes[0].plot(clay_df["epoch"], clay_df["time_minutes"], label="Clay", marker='o', linestyle='-', color=clay_color, markersize=5, linewidth=1.5)
axes[0].plot(prithvi_df["epoch"], prithvi_df["time_minutes"], label="Prithvi", marker='s', linestyle='--', color=prithvi_color, markersize=5, linewidth=1.5)

# Upper y-limit: the largest time of either model plus 2 minutes of headroom
time_axis_top = max(clay_df["time_minutes"].max(), prithvi_df["time_minutes"].max()) + 2

axes[0].set_xlabel("Epochs", fontsize=11)
axes[0].set_ylabel("Training Time (minutes)", fontsize=11)
axes[0].set_ylim(0, time_axis_top)
axes[0].set_yticks(range(0, int(time_axis_top), 2))  # a tick every 2 minutes
axes[0].legend(fontsize=12)
axes[0].tick_params(axis='both', labelsize=12)
axes[0].text(0.05, 0.9, "(a)", transform=axes[0].transAxes, fontsize=16)

# Second subplot: IOU Performance
axes[1].plot(x_values, clay_iou, marker='o', linestyle='-', label='Clay', color=clay_color, markersize=5, linewidth=2)
axes[1].plot(x_values, prithvi_iou, marker='s', linestyle='--', label='Prithvi', color=prithvi_color, markersize=5, linewidth=2)

axes[1].set_xlabel('Training Data Size', fontsize=12)
axes[1].set_ylabel('Model IOU', fontsize=12)
axes[1].legend(fontsize=11)
axes[1].tick_params(axis='both', labelsize=12)
axes[1].text(0.05, 0.9, "(b)", transform=axes[1].transAxes, fontsize=16)

# Run tight_layout while the figure contains only the two subplots. The inset
# below is placed with fig.add_axes, which tight_layout cannot manage; calling
# tight_layout after adding it emits a "not compatible with tight_layout" warning.
fig.tight_layout()

# Inset plot for 0-10 training data range (first three data sizes)
ax_inset = fig.add_axes([0.71, 0.17, 0.25, 0.2])  # [left, bottom, width, height] in figure coordinates
ax_inset.plot(x_values[:3], clay_iou[:3], marker='o', linestyle='-', label='Clay', color=clay_color, markersize=5, linewidth=2)
ax_inset.plot(x_values[:3], prithvi_iou[:3], marker='s', linestyle='--', label='Prithvi', color=prithvi_color, markersize=5, linewidth=2)

ax_inset.set_xlabel('Train Data (0-10)', fontsize=9)
ax_inset.set_ylabel('IOU', fontsize=9)
ax_inset.set_xticks(x_values[:3])
ax_inset.tick_params(axis='both', labelsize=9)

fig.savefig(r"D:\CVPR\Flood_Results\model_comparison.png", dpi=300)
# Show plot
plt.show()
