In [None]:
# import statements
import numpy as np
import rasterio
from scipy.spatial import cKDTree

In [None]:
# Input/output file paths -- fill these in before running the notebook.

# Depressionless DEM raster (input).
input_dem = ''

# Rasterized drainage-line (DL) data (input). This could also be recomputed,
# but here it is taken from the temporary outputs of the stream order raster.
input_dl_raster = ''

# Destination GeoTIFF for the computed distance raster (output).
output_path = ''

In [None]:
# Load band 1 of the depressionless DEM and keep its affine transform.
with rasterio.open(input_dem) as src:
    elevation = src.read(1)
    elevation_transform = src.transform

# Load band 1 of the rasterized drainage lines.
# NOTE(review): nodata values are not masked in either raster; a DEM nodata
# value would be treated as a real elevation downstream -- confirm this is OK.
with rasterio.open(input_dl_raster) as src:
    drainage_lines = src.read(1)

# The distance computation indexes `elevation` with a mask built from
# `drainage_lines`, so both rasters must be on the same pixel grid. Fail early
# with a clear message instead of misaligning or raising an obscure IndexError.
if drainage_lines.shape != elevation.shape:
    raise ValueError(
        f"drainage-line raster shape {drainage_lines.shape} does not match "
        f"DEM shape {elevation.shape}; both rasters must share the same grid"
    )

In [None]:

# For every non-drainage pixel, compute the Euclidean distance (in pixel
# index units) to the nearest drainage-line pixel whose elevation is greater
# than or equal to the pixel's own elevation, then write the result as an
# int32 GeoTIFF. Drainage pixels themselves, and pixels for which no such
# drainage pixel exists, are written as -1.


def _min_distance_to_higher_drainage(elevation, drainage_mask):
    """Per-pixel distance to the nearest drainage pixel at >= the pixel's elevation.

    Parameters
    ----------
    elevation : 2-D ndarray
        Elevation grid.
    drainage_mask : 2-D bool ndarray, same shape as ``elevation``
        True where the pixel belongs to a drainage line.

    Returns
    -------
    float64 ndarray, same shape as ``elevation``
        Distance in (row, col) index space. ``np.inf`` for drainage pixels and
        for pixels with no drainage pixel at a higher-or-equal elevation.
    """
    distances = np.full(elevation.shape, np.inf)

    drain_rc = np.argwhere(drainage_mask)
    drain_elev = elevation[drainage_mask]
    # `nan >= x` is always False, so NaN-elevation drainage pixels can never
    # be a valid target; drop them up front.
    keep = ~np.isnan(drain_elev)
    drain_rc, drain_elev = drain_rc[keep], drain_elev[keep]
    # Once sorted by elevation, the drainage pixels with elevation >= e form a
    # contiguous suffix drain_*[start:], where start = searchsorted(e, 'left').
    order = np.argsort(drain_elev, kind="stable")
    drain_rc, drain_elev = drain_rc[order], drain_elev[order]

    target_mask = ~drainage_mask
    target_rc = np.argwhere(target_mask)
    target_elev = elevation[target_mask]
    # First valid candidate index per target pixel. NaN targets sort past the
    # end and so get no candidates, matching `>=` being False for NaN.
    start_idx = np.searchsorted(drain_elev, target_elev, side="left")

    # Targets sharing a start index share the same candidate set, so build one
    # k-d tree per distinct candidate set (instead of one per pixel) and query
    # all of its targets in a single vectorized call.
    has_candidates = np.flatnonzero(start_idx < drain_elev.size)
    by_start = has_candidates[np.argsort(start_idx[has_candidates], kind="stable")]
    sorted_starts = start_idx[by_start]
    group_starts, group_lo = np.unique(sorted_starts, return_index=True)
    group_hi = np.append(group_lo[1:], sorted_starts.size)
    for start, lo, hi in zip(group_starts, group_lo, group_hi):
        members = target_rc[by_start[lo:hi]]
        group_dist, _ = cKDTree(drain_rc[start:]).query(members, k=1)
        distances[members[:, 0], members[:, 1]] = group_dist
    return distances


def _round_distances(distances, fill_value=-1):
    """Round distances to the nearest integer as int32; non-finite entries become ``fill_value``.

    Casting inf/NaN to an integer dtype is undefined in NumPy (x86 yields
    INT_MIN, ARM saturates to INT_MAX), so non-finite entries are filled
    explicitly rather than cast and patched afterwards.
    """
    rounded = np.full(distances.shape, fill_value, dtype=np.int32)
    finite = np.isfinite(distances)
    rounded[finite] = np.rint(distances[finite]).astype(np.int32)
    return rounded


# Drainage-line pixels are those with a positive value in the DL raster.
# NOTE(review): a positive nodata value in the DL raster would be counted as
# drainage here -- confirm against the raster's metadata.
drainage_line_mask = drainage_lines > 0

min_distances = _min_distance_to_higher_drainage(elevation, drainage_line_mask)
min_distances_rounded = _round_distances(min_distances, fill_value=-1)

# Georeference the output with the DEM whose grid `elevation` came from,
# instead of reusing the (already closed) `src` handle left over from the
# loading cell.
with rasterio.open(input_dem) as dem_src:
    dem_crs = dem_src.crs
    dem_transform = dem_src.transform

# NOTE(review): -1 is used as a sentinel but is not declared as the raster's
# nodata value; consider passing nodata=-1 if downstream tools should mask it.
with rasterio.open(output_path, 'w',
                   driver='GTiff',
                   height=elevation.shape[0],
                   width=elevation.shape[1],
                   count=1,
                   dtype=rasterio.int32,
                   crs=dem_crs,
                   transform=dem_transform) as dst:
    dst.write(min_distances_rounded, 1)