In [1]:
import rasterio
import cv2
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import os
from scipy.spatial import cKDTree
import seaborn as sns
import heapq
from tqdm import tqdm  
import math
import rasterio
from scipy.ndimage import uniform_filter  
from astropy.convolution import convolve, Tophat2DKernel, Box2DKernel
import warnings
from scipy.spatial.distance import cdist
from scipy import signal
import time
from pathlib import Path
from skimage import morphology, measure

  "class": algorithms.Blowfish,


In [2]:
############################################### Read datasets ################################################
##############################################################################################################
##############################################################################################################
##Prior to loading the Digital Elevation Model (DEM) and original low resolution water mask, essential 
#preprocessing must be performed to ensure proper code execution. For the DEM data, this includes filling sinks 
#to remove hydrological depressions, along with resampling and reprojecting to match the water mask raster's grid
#resolution, spatial extent, and coordinate reference system. The DEM should be stored in float16 or float32 format. 
##These preprocessing steps can be completed using ArcGIS, ArcGIS Pro, or equivalent Python libraries. 
##The term 'original low resolution water mask' specifically refers to a water mask that has been resampled from 
#its native coarse resolution (e.g., 250m) to a finer resolution (e.g., 30m) using nearest-neighbor interpolation,
#designated as NEAREST in resampling operations.
dem_path = r'C:/Users/XXXX/dem_GDEMV3.tif'
water_path = r'C:/Users/XXXX/Original_flood.tif'
# DEM: keep the raw band plus the georeferencing used by later cells.
with rasterio.open(dem_path) as dem_src:
    dem = dem_src.read(1)
    dem_extent = [dem_src.bounds.left, dem_src.bounds.right, dem_src.bounds.bottom, dem_src.bounds.top]
    dem_transform = dem_src.transform
    dem_nodata = dem_src.nodata
# Water mask: binarise to 1 = water, 0 = dry or nodata.
with rasterio.open(water_path) as water_src:
    water_raw = water_src.read(1)
    # The source nodata value must be excluded BEFORE binarising: a positive nodata
    # value (e.g. 255) would otherwise pass the `> 0` test and be counted as water.
    if water_src.nodata is None:
        water_valid = np.ones(water_raw.shape, dtype=bool)
    else:
        water_valid = water_raw != water_src.nodata
    flood = np.where((water_raw > 0) & water_valid, 1, 0).astype(np.uint8)
    water_extent = [water_src.bounds.left, water_src.bounds.right, water_src.bounds.bottom, water_src.bounds.top]
    # Same mask as `flood`, stored as int8 (used for the water-area count).
    flood_array = flood.astype(np.int8)
    water_meta = water_src.meta.copy()
    water_meta.update({
        'driver': 'GTiff',
        'count': 1,
        'dtype': 'float32',
        # Downstream cells treat 0 as "no value" (filter_boundary_wse skips zeros,
        # masks are built with `> 0`), so the nodata marker is pinned to 0 to keep
        # them consistent even if the input declares a different or no nodata value.
        'nodata': 0,
    })
##############################################################################################################
##############################################################################################################

In [3]:
######################################### Parameter Settings #################################################
##############################################################################################################
##############################################################################################################
##The following eight parameters are the primary user-configurable settings in SRWD-Hydro. Parameters marked 
#with an asterisk (*) are not recommended for adjustment; they are reserved for future optimization. 
##diffs_data: Initial filter for growth starting points (seeds): a point is kept only if at least one of its 
#8-connected neighbours differs from it by no more than this threshold (m). 
##LOWER_PERCENTILE & UPPER_PERCENTILE: Secondary filter applied to the seeds that pass the first filter; only 
#values within this percentile range are kept, which limits the influence of extremely low/high values and 
#reduces modelling noise. 
##Area_half_distance: Flooded area (km²) at which the propagation distance reaches half of its maximum, as 
#detailed in FLEXTH ("Water depth estimate and flood extent enhancement for satellite-based inundation maps").
##Dmax: Maximum propagation distance (km), as detailed in FLEXTH (same reference as above).
##n_smooth: Number of smoothing iterations applied to the computed Water Surface Elevation (WSE) raster to 
#reduce jagged (stair-stepped) extreme water surfaces.
##num_smooth*: Number of DEM smoothing iterations (adjustment not recommended). 
##Constant_distance*: Scales global water propagation (range 0-1): higher values reduce propagation and lower 
#values enhance it.

# Seed filter A: max WSE difference (m) to at least one 8-connected neighbour.
diffs_data = 0.01
# Seed filter B: percentile window applied to the seeds that survive filter A.
LOWER_PERCENTILE, UPPER_PERCENTILE = 15, 95
PER = UPPER_PERCENTILE  # only used to tag the output file names
# FLEXTH-style propagation limit.
Area_half_distance = 100  # area (km²) at which propagation reaches half of Dmax
Dmax = 15  # maximum propagation distance (km)
# Smoothing iterations: WSE raster, DEM.
n_smooth, num_smooth = 20, 1
# Global propagation damping factor (0-1).
Constant_distance = 1
##############################################################################################################
##############################################################################################################

In [4]:
#################### Do not change the following codes unless you have better ideas ##########################
########################################## Calculate the restrictions ########################################
##############################################################################################################
##############################################################################################################
# Pixel size in metres. The DEM transform is taken to be in degrees, and 1e5 is the
# approximate metres-per-degree factor this workflow uses.
# NOTE(review): ~111 320 m per degree of latitude; confirm the intended factor / CRS.
res_x, res_y = (abs(step) * 1e5 for step in (dem_transform.a, dem_transform.e))

# Area covered by the original water mask.
water_pixel_count = np.count_nonzero(flood_array)
Area_each_pixel = res_x * res_y  # (m²) per pixel
water_Area_m2 = water_pixel_count * Area_each_pixel
water_Area_km2 = water_Area_m2 / 1e6

# Maximum propagation distance (m): Dmax (km) scaled by 1 - 2^(-A / Area_half_distance),
# which equals 1/2 when the water area A equals Area_half_distance.
dmax = Dmax * 1000 * (1 - pow(2, -water_Area_km2 / Area_half_distance))
# Step budget for the step-limited region growing, in x-pixels, plus one.
max_steps = math.floor(dmax / res_x) + 1
###############################################################################################################
###############################################################################################################

In [5]:
########################################### Boundary processing ##############################################
##############################################################################################################
##############################################################################################################
kernel_boundary = np.ones((3,3), np.uint8)
dilated = cv2.dilate(flood, kernel_boundary)  # Outer boundary
eroded = cv2.erode(flood, kernel_boundary)    # Inner boundary
outer_boundary = dilated - flood   # pixels added by a 3x3 dilation of the water mask
inner_boundary = flood - eroded    # water pixels removed by a 3x3 erosion
outer_dem = dem[outer_boundary == 1]
inner_dem = dem[inner_boundary == 1]
# DEM processing: replace nodata with 0.
# cv2.filter2D does not accept float16 input and, with ddepth=-1, returns integer output
# for integer input (rounding the 3x3 mean), so any DEM that is not already float32/float64
# is promoted to float32 before smoothing.
dem = np.where(dem == dem_nodata, 0, dem)
if dem.dtype not in (np.float32, np.float64):
    dem = dem.astype(np.float32)
kernel_average = np.ones((3,3), dtype=np.float32) / 9.0
smoothed_dem = dem.copy()
for _ in range(num_smooth):
    # 3x3 mean filter with constant border padding.
    smoothed_dem = cv2.filter2D(
        smoothed_dem,
        -1,
        kernel_average,
        borderType=cv2.BORDER_CONSTANT,
    )
# Boundary DEM values (smoothed)
outer_dem_average = smoothed_dem[outer_boundary == 1]
inner_dem_average = smoothed_dem[inner_boundary == 1]
# Rasters holding the smoothed DEM on each boundary, 0 elsewhere
outer_tifsave = np.zeros_like(dem, dtype=np.float32)
inner_tifsave = np.zeros_like(dem, dtype=np.float32)
outer_tifsave[outer_boundary == 1] = outer_dem_average.astype(np.float32)
inner_tifsave[inner_boundary == 1] = inner_dem_average.astype(np.float32)
# Boundary WSE: for each inner-boundary pixel, the mean of its own smoothed DEM and the
# lowest smoothed DEM among the equally-nearest outer-boundary pixels.
ib_rows, ib_cols = np.where((inner_boundary == 1) & (inner_tifsave != water_meta['nodata']))
ob_rows, ob_cols = np.where((outer_boundary == 1) & (outer_tifsave != water_meta['nodata']))
ob_points = np.column_stack((ob_rows, ob_cols))
ob_kdtree = cKDTree(ob_points)
boundary_WSE = np.full_like(inner_tifsave, fill_value=water_meta['nodata'], dtype=np.float32)

if len(ib_rows) > 0 and len(ob_points) > 0:
    ib_points = np.column_stack((ib_rows, ib_cols))
    # One batched query for the distance to the nearest outer pixel ...
    min_distances, _ = ob_kdtree.query(ib_points, k=1)
    # ... and one batched query for every outer pixel tied at that distance.
    candidate_lists = ob_kdtree.query_ball_point(ib_points, min_distances + 1e-6)
    for ib_row, ib_col, candidates in zip(ib_rows, ib_cols, candidate_lists):
        if len(candidates) == 0:
            continue  # leave nodata
        # Ties are resolved by taking the lowest smoothed elevation.
        ot_value = outer_tifsave[ob_rows[candidates], ob_cols[candidates]].min()
        if ot_value != water_meta['nodata']:
            boundary_WSE[ib_row, ib_col] = (inner_tifsave[ib_row, ib_col] + ot_value) / 2.0
##############################################################################################################
##############################################################################################################

In [6]:
############################################ Boundary filtering ##############################################
##############################################################################################################
##############################################################################################################
# Boundary filtering A
def filter_boundary_wse(boundary_WSE, diffs_threshold=None):
    """Keep boundary WSE pixels that agree with at least one 8-connected neighbour.

    A non-zero pixel is kept when some non-zero neighbour in its 3x3 window differs
    from it by no more than the threshold; every other pixel becomes 0. Zero is treated
    as "no value" both for the centre pixel and for neighbours. Pixels outside the
    raster are ignored.

    Parameters
    ----------
    boundary_WSE : 2-D array
        Boundary water-surface elevations, 0 where there is no value.
    diffs_threshold : float, optional
        Maximum allowed absolute difference. Defaults to the notebook-level
        ``diffs_data`` parameter, which preserves the original behaviour.

    Returns
    -------
    float32 array of the same shape with rejected pixels set to 0.
    """
    if diffs_threshold is None:
        diffs_threshold = diffs_data
    wse = np.asarray(boundary_WSE)
    rows, cols = wse.shape
    # Zero-padding makes out-of-raster neighbours look like "no value".
    padded = np.zeros((rows + 2, cols + 2), dtype=wse.dtype)
    padded[1:-1, 1:-1] = wse
    keep = np.zeros((rows, cols), dtype=bool)
    # Vectorised replacement for a per-pixel Python loop: compare every pixel with
    # each of its 8 neighbours at once by shifting the padded raster.
    for dy in (-1, 0, 1):
        for dx in (-1, 0, 1):
            if dy == 0 and dx == 0:
                continue
            neighbour = padded[1 + dy:1 + dy + rows, 1 + dx:1 + dx + cols]
            keep |= (neighbour != 0) & (np.abs(neighbour - wse) <= diffs_threshold)
    keep &= wse != 0
    return np.where(keep, wse, 0).astype(np.float32)
boundary_WSE_small_part = filter_boundary_wse(boundary_WSE)
##############################################################################################################
# Boundary filtering B: percentile truncation of the surviving seed values
nodata = water_meta['nodata']
valid_data = boundary_WSE_small_part[boundary_WSE_small_part != nodata]
if valid_data.size == 0:
    raise ValueError(
        "No boundary WSE values survived boundary filtering A; "
        "check the input rasters or increase diffs_data."
    )
lower_threshold = np.percentile(valid_data, LOWER_PERCENTILE)
upper_threshold = np.percentile(valid_data, UPPER_PERCENTILE)
print(f"Quantile truncation: WSE values ≥ {lower_threshold:.2f} and ≤ {upper_threshold:.2f} ")

# Drop only values strictly outside [lower, upper], so the inclusive range printed above is
# exactly what is kept (and percentiles 0/100 no longer discard the min/max values).
boundary_WSE_trimmed = np.where(
    ((boundary_WSE_small_part < lower_threshold) |
     (boundary_WSE_small_part > upper_threshold)) &
    (boundary_WSE_small_part != nodata),
    nodata, boundary_WSE_small_part)
##############################################################################################################
# Ensure WSE >= DEM
boundary_WSE_trimmed = np.where(
    (boundary_WSE_trimmed >= dem) | (boundary_WSE_trimmed == water_meta['nodata']),
    boundary_WSE_trimmed,
    water_meta['nodata'])
##############################################################################################################
##############################################################################################################

Quantile truncation: WSE values ≥ 1.50 and ≤ 5.00 


In [7]:
######################################### Region growing algorithms ##########################################
##############################################################################################################
##############################################################################################################
def region_growing_A(dem, boundary_trimmed, nodata=0):
    """Propagate seed water-surface elevations over the DEM with no distance limit.

    Priority flood: the highest WSE in the queue is expanded first, and a pixel takes
    a WSE when it is 8-connected to a pixel holding that WSE, its DEM is not above it,
    and it currently holds a smaller value.

    Parameters
    ----------
    dem : 2-D array
        Ground elevation.
    boundary_trimmed : 2-D array, same shape as ``dem``
        Seed WSE values; ``nodata`` marks "no seed".
    nodata : scalar
        Marker for pixels without a value.

    Returns
    -------
    float32 array with the propagated WSE, ``nodata`` where nothing arrived.
    """
    result = np.where(boundary_trimmed != nodata, boundary_trimmed, nodata).astype(np.float32)
    n_rows, n_cols = dem.shape
    seeds = np.argwhere(result != nodata)
    # heapq is a min-heap, so levels are negated to pop the highest WSE first.
    queue = [(-result[r, c], r, c) for r, c in seeds]
    heapq.heapify(queue)
    neighbour_steps = [(dr, dc) for dr in (-1, 0, 1) for dc in (-1, 0, 1) if (dr, dc) != (0, 0)]

    with tqdm(total=len(seeds), desc="Region Growing A", leave=False) as pbar:
        while queue:
            neg_level, r, c = heapq.heappop(queue)
            level = -neg_level
            # Stale entry: this pixel was raised after the entry was queued.
            if result[r, c] > level:
                continue
            for dr, dc in neighbour_steps:
                nr, nc = r + dr, c + dc
                if not (0 <= nr < n_rows and 0 <= nc < n_cols):
                    continue
                if dem[nr, nc] > level or not (level > result[nr, nc]):
                    continue
                result[nr, nc] = level
                heapq.heappush(queue, (-level, nr, nc))
                pbar.update(1)

    return result
##############################################################################################################
def region_growing_B(dem, boundary_trimmed, nodata=0, max_steps=100):
    """Propagate seed water-surface elevations over the DEM, limited to ``max_steps`` steps.

    This is the same priority flood as the unlimited variant, except that a pixel reached
    after ``max_steps`` 8-connected steps from its seed is set but not expanded further.

    Parameters
    ----------
    dem : 2-D array
        Ground elevation.
    boundary_trimmed : 2-D array, same shape as ``dem``
        Seed WSE values; ``nodata`` marks "no seed".
    nodata : scalar
        Marker for pixels without a value.
    max_steps : int
        Maximum number of 8-connected steps from a seed.

    Returns
    -------
    float32 array with the propagated WSE, ``nodata`` where nothing arrived.
    """
    result = np.where(boundary_trimmed != nodata, boundary_trimmed, nodata).astype(np.float32)

    heap = []
    seed_points = np.argwhere(result != nodata)
    # Heap key (-WSE, steps, y, x): highest WSE first and, among equal WSE, fewest steps
    # first. Ordering ties by (y, x) instead let a pixel be claimed via a long detour,
    # which made the max_steps limit direction-dependent and cut off reachable pixels.
    for (y, x) in seed_points:
        heapq.heappush(heap, (-result[y, x], 0, y, x))

    offsets = [(-1,-1), (-1,0), (-1,1),
               (0,-1),          (0,1),
               (1,-1),  (1,0), (1,1)]

    with tqdm(total=len(seed_points), desc="Region Growing B", leave=False) as pbar:
        while heap:
            neg_val, steps, y, x = heapq.heappop(heap)
            current_val = -neg_val

            # Stale entry: the pixel was raised after this entry was queued.
            if result[y, x] > current_val:
                continue
            # Step budget exhausted: keep the pixel but do not expand it.
            if steps >= max_steps:
                continue

            for dy, dx in offsets:
                ny, nx = y + dy, x + dx
                if 0 <= ny < dem.shape[0] and 0 <= nx < dem.shape[1]:
                    if dem[ny, nx] <= current_val and result[ny, nx] < current_val:
                        result[ny, nx] = current_val
                        heapq.heappush(heap, (-current_val, steps + 1, ny, nx))
                        pbar.update(1)

    return result
##############################################################################################################
# Execute region growing
# Growth A: unlimited propagation of the trimmed boundary WSE, then restricted to
# pixels inside the original flood mask (flood == 1).
grown_result_A = region_growing_A(dem, boundary_WSE_trimmed, nodata=water_meta['nodata'])
inner_grown_result = np.where(flood == 1,grown_result_A,water_meta['nodata']).astype(np.float32) 
grown_result_A = inner_grown_result
# Growth B: the same propagation over the whole raster, limited to max_steps
# 8-connected steps from the seeds (max_steps is derived from dmax in the restrictions cell).
grown_result_B = region_growing_B(dem, boundary_WSE_trimmed, nodata=water_meta['nodata'], max_steps=max_steps)
#combination
# Per-pixel maximum of A and B; a pixel is nodata only if both growths left it as nodata.
grown_result_ALL = np.maximum(grown_result_A, grown_result_B)
grown_result_ALL = np.where((grown_result_A == water_meta['nodata']) & (grown_result_B == water_meta['nodata']),
                           water_meta['nodata'],
                           grown_result_ALL)
# 1 where the combined WSE is valid and positive, nodata elsewhere.
# NOTE(review): not used by the later cells visible in this notebook.
grown_result_Water_mask = np.where((grown_result_ALL != water_meta['nodata']) & (grown_result_ALL > 0),1,water_meta['nodata'])

# Distance decay for pixels that have a combined WSE but no value from growth A:
# their WSE is moved from the WSE of the nearest A pixel toward that A pixel's DEM,
# linearly in distance / dmax (clipped to [0, 1]) and scaled by Constant_distance.
a_coords = np.argwhere(grown_result_A != water_meta['nodata'])
if len(a_coords) == 0:
    # No A pixels to measure distance from: keep the combined result as is.
    adjusted_result = grown_result_ALL.copy()
else:
    a_tree = cKDTree(a_coords)
    # Target pixels: valid in the combined result but nodata in A.
    mask = (grown_result_A == water_meta['nodata']) & (grown_result_ALL != water_meta['nodata'])
    y_coords, x_coords = np.where(mask)
    
    if len(y_coords) > 0:
        target_points = np.column_stack((y_coords, x_coords))
        # Nearest A pixel (in pixel units) for every target pixel.
        distances, indices = a_tree.query(target_points, k=1)
        a_points = a_coords[indices]
        a_y, a_x = a_points[:, 0], a_points[:, 1]
        A_vals = grown_result_A[a_y, a_x]
        A_dems = dem[a_y, a_x]
        current_dems = dem[y_coords, x_coords]
        dx_pixels = (x_coords - a_x).astype(np.float32)
        dy_pixels = (y_coords - a_y).astype(np.float32)
        # Metric distance using the pixel sizes from the restrictions cell.
        distance_m = np.hypot(dx_pixels * res_x, dy_pixels * res_y)  
        factor = np.clip(distance_m / dmax, 0.0, 1.0)  
        new_vals = A_vals - (A_vals - A_dems) * factor * Constant_distance
        # Discard decayed values below the local ground or equal to the nodata marker.
        valid_mask = (new_vals >= current_dems) & (new_vals != water_meta['nodata'])
        adjusted_vals = np.where(valid_mask, new_vals, water_meta['nodata'])
        adjusted_result = grown_result_ALL.copy()
        adjusted_result[y_coords, x_coords] = adjusted_vals
    else:
        adjusted_result = grown_result_ALL.copy()

#making sure that WSE>dem
# Any remaining WSE below the DEM is replaced by nodata.
final_result_way_1 = np.where(
    (adjusted_result >= dem) | (adjusted_result == water_meta['nodata']),
    adjusted_result,
    water_meta['nodata']
)
##############################################################################################################
##############################################################################################################

                                                                                                                       

In [8]:
############################################# smooth the WSE for n times ####################################
##############################################################################################################
##############################################################################################################
def smooth_multiple_times(data, n, nodata_val):
    """Apply ``n`` passes of a nodata-aware 3x3 mean filter.

    In each pass, every valid pixel becomes the mean of the valid pixels in its 3x3
    window ('nearest' edge mode). Pixels equal to ``nodata_val`` in the input are
    excluded from every mean and remain ``nodata_val``.

    Returns
    -------
    float32 array of the same shape as ``data``.
    """
    invalid = data == nodata_val
    # The valid/invalid layout never changes between passes, so the per-window
    # fraction of valid pixels is computed once instead of every pass.
    valid_fraction = uniform_filter((~invalid).astype(float), size=3, mode='nearest')
    smoothed = data.copy()
    for _ in range(n):
        window_mean = uniform_filter(np.where(invalid, 0, smoothed), size=3, mode='nearest')
        with np.errstate(invalid='ignore', divide='ignore'):
            window_mean = window_mean / valid_fraction
        window_mean[invalid] = nodata_val
        smoothed = window_mean
    return smoothed.astype(np.float32)
final_result_way_1_smoothed = smooth_multiple_times(final_result_way_1, n_smooth, water_meta['nodata'])

# Candidate flood pixels (positive smoothed WSE) and seed pixels (positive trimmed boundary WSE).
Grown_result_flood_mask = (final_result_way_1_smoothed > 0).astype(np.uint8)
boundary_mask = (boundary_WSE_trimmed > 0).astype(np.uint8)
combined_mask = ((Grown_result_flood_mask == 1) | (boundary_mask == 1)).astype(np.uint8)

# Label 8-connected regions of the union so each grown patch can be traced to a seed.
num_labels, labels = cv2.connectedComponents(combined_mask, connectivity=8)

boundary_coords = np.argwhere(boundary_mask == 1)
if boundary_coords.size == 0:
    # No seeds survived, so no grown pixel can be retained.
    Grown_result_flood_mask[:, :] = 0
else:
    boundary_labels = labels[boundary_coords[:, 0], boundary_coords[:, 1]]
    valid_labels = np.unique(boundary_labels[boundary_labels != 0])
    retained_mask = np.isin(labels, valid_labels)
    # Keep only grown pixels whose component contains at least one seed pixel.
    Grown_result_flood_mask = ((Grown_result_flood_mask == 1) & retained_mask).astype(np.uint8)

final_result_way_1_smoothed = np.where(Grown_result_flood_mask == 1, final_result_way_1_smoothed, water_meta['nodata'])
#############################################################################################################
#############################################################################################################

In [9]:
##################################### Calculate the Water Depth ##############################################
#############################################################################################################
#############################################################################################################
# Depth = WSE - ground, only inside the retained flood mask; nodata elsewhere.
water_depth = np.where(
    Grown_result_flood_mask == 1,
    final_result_way_1_smoothed - dem,
    water_meta['nodata'],
)
# WSE below ground gives a negative depth: clamp to 0, leaving nodata untouched.
water_depth = np.where(
    (water_depth != water_meta['nodata']) & (water_depth < 0),
    0,
    water_depth,
)
#############################################################################################################
#############################################################################################################

In [10]:
############################################## save results #################################################
#############################################################################################################
#############################################################################################################
# File-name tags encode the parameter set; WSE and WD also carry the WSE smoothing count.
filename_suffix_flood = f"{diffs_data:.2f}_{Area_half_distance}_{Dmax}_{PER}_{num_smooth}_{Constant_distance}"
filename_suffix_WSE = f"{diffs_data:.2f}_{Area_half_distance}_{Dmax}_{PER}_{num_smooth}_{Constant_distance}+{n_smooth}"
filename_suffix_WD = filename_suffix_WSE

# Write each result as a single-band float32 GeoTIFF next to the input water mask.
for file_name, raster in (
    (f"flood_{filename_suffix_flood}.tif", Grown_result_flood_mask),
    (f"WSE_Result_{filename_suffix_WSE}.tif", final_result_way_1_smoothed),
    (f"WD_Result_{filename_suffix_WD}.tif", water_depth),
):
    output_path = os.path.join(os.path.dirname(water_path), file_name)
    with rasterio.open(output_path, 'w', **water_meta) as dst:
        dst.write(raster.astype(np.float32), 1)
    print(f"path: {output_path}")
#############################################################################################################
#############################################################################################################

path: C:/Users/XXXX/Desktop/paper4/paper4_experiment_data/Coast_Waters/Changjiang_River/30m\flood_0.01_100_15_95_1_1.tif
path: C:/Users/XXXX/Desktop/paper4/paper4_experiment_data/Coast_Waters/Changjiang_River/30m\WSE_Result_0.01_100_15_95_1_1+20.tif
path: C:/Users/XXXX/Desktop/paper4/paper4_experiment_data/Coast_Waters/Changjiang_River/30m\WD_Result_0.01_100_15_95_1_1+20.tif
