Much of this code is from the pysheds documentation: https://github.com/mdbartos/pysheds

In [None]:
# Libraries

# pysheds
from pysheds.grid import Grid

# plotting
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import colors
import matplotlib as matplot
import seaborn as sns

# other
import time
import copy

In [None]:
# Settings
# Path to the input DEM raster file; both the Grid and the elevation raster
# are read from this file in the next cell.
DATA_FILE = "../data/washington_small.tif"

In [None]:
# Open rasters
# `grid` is built from the raster file and is reused below for plot extents
# (grid.extent), the bounding box (grid.bbox) and every hydrology operation.
grid = Grid.from_raster(DATA_FILE)
# `dem` holds the elevation values read from the same file.
dem = grid.read_raster(DATA_FILE)

In [None]:
# Plot the DEM
fig, ax = plt.subplots(figsize=(8,6))
fig.patch.set_alpha(0)  # transparent figure background

# Elevation raster drawn above the grid lines (zorder 1 vs 0)
dem_image = plt.imshow(dem, extent=grid.extent, cmap='terrain', zorder=1)
fig.colorbar(dem_image, ax=ax, label='Elevation (m)')

ax.grid(zorder=0)
ax.set_title('Digital elevation map', size=14)
ax.set_xlabel('Longitude')
ax.set_ylabel('Latitude')
fig.tight_layout()

In [None]:
# Condition the DEM: fill pits, fill depressions, then resolve flats.
# (Took about 40 seconds on my laptop for "washington_medium.tif".)
# NOTE: "washington_medium.tif" is already conditioned, so this step is
# disabled by default. Set CONDITION_DEM = True to run it on a raw DEM.
# NOTE(review): DATA_FILE currently points at "washington_small.tif"; confirm
# whether that file is also already conditioned.
CONDITION_DEM = False

if CONDITION_DEM:
    start_time = time.perf_counter()

    # Fill pits in DEM
    pit_filled_dem = grid.fill_pits(dem)

    # Fill depressions in DEM
    flooded_dem = grid.fill_depressions(pit_filled_dem)

    # Resolve flats in DEM
    inflated_dem = grid.resolve_flats(flooded_dem)

    execution_time = time.perf_counter() - start_time
    print(f"Execution time: {execution_time:.6f} seconds")

    preprocessed_dem = inflated_dem
else:
    # Use the DEM as read from disk
    preprocessed_dem = dem


In [None]:
# Determine D8 flow directions from DEM

# ESRI scheme that specifies directions as numbers (which is the default):
#   North: 64
#   Northeast: 128
#   East: 1
#   Southeast: 2
#   South: 4
#   Southwest: 8
#   West: 16
#   Northwest: 32
# A dirmap tuple lists the codes in the order (N, NE, E, SE, S, SW, W, NW),
# matching the default tuple below.

#dirmap = (64, 128, 1, 2, 4, 8, 16, 32) # this is the default value

# Custom codes 1..8, numbered clockwise starting from east
# (E=1, SE=2, S=3, SW=4, W=5, NW=6, N=7, NE=8). This lets the D8 codes use
# the same cyclic colormap as the Dinf angle plot.
# NOTE(review): Dinf angles are conventionally measured counterclockwise from
# east (Tarboton). If pysheds follows that convention, the two plots rotate
# through the colormap in opposite directions. Confirm before relying on the
# plots matching.
dirmap = (7, 8, 1, 2, 3, 4, 5, 6)

# Compute flow directions using D8
fdir = grid.flowdir(preprocessed_dem, dirmap=dirmap)

In [None]:
# Compute flow directions using Dinf
# Dinf routing yields a continuous angle per cell rather than one of eight
# codes, so no dirmap is passed. The plot below assumes the angles lie in
# 0..2*pi (presumably radians; confirm against the pysheds docs).
fdir_dinf = grid.flowdir(preprocessed_dem, routing='dinf')

# this is slightly slower than D8, but should give better accuracy

In [None]:
# Plot D8 flow direction

CMAP = "twilight_shifted" # cyclic colormap, so the first and last direction codes get similar colors

fig = plt.figure(figsize=(8,6))
fig.patch.set_alpha(0)

# Direction codes in ascending order. The color scale spans exactly these
# codes rather than a hard-coded 1..8, so the plot also works with the
# default ESRI dirmap.
codes = sorted(dirmap)

# dirmap lists its codes in the order N, NE, E, SE, S, SW, W, NW; map each
# code to its compass name for the colorbar labels.
compass = dict(zip(dirmap, ('N', 'NE', 'E', 'SE', 'S', 'SW', 'W', 'NW')))

# Mask every cell that does not hold one of the eight direction codes
# (presumably pysheds' pit/flat/nodata markers; confirm for the installed
# version). Unmasked, those values fall below vmin and would be painted with
# the same color as the smallest code. Masked cells are drawn transparent.
fdir_values = np.asarray(fdir)
fdir_masked = np.ma.masked_where(~np.isin(fdir_values, codes), fdir_values)

plt.imshow(fdir_masked, extent=grid.extent, cmap=CMAP, zorder=2,
           vmin=codes[0], vmax=codes[-1])

# Place the colorbar bin edges halfway between neighboring codes, so each
# code's tick sits in the middle of its own color band rather than on the
# edge between two bands.
first_edge = codes[0] - (codes[1] - codes[0]) / 2
last_edge = codes[-1] + (codes[-1] - codes[-2]) / 2
boundaries = ([first_edge]
              + [(lo + hi) / 2 for lo, hi in zip(codes, codes[1:])]
              + [last_edge])
cbar = plt.colorbar(boundaries=boundaries, values=codes, ticks=codes)
cbar.set_ticklabels([f"{code} ({compass[code]})" for code in codes])
plt.xlabel('Longitude')
plt.ylabel('Latitude')
plt.title('Flow direction grid (D8)', size=14)
plt.grid(zorder=-1)
plt.tight_layout()

In [None]:
# Plot DInf flow direction

fig = plt.figure(figsize=(8,6))
fig.patch.set_alpha(0)

# Only values in [0, 2*pi] are plotted as flow angles. Everything else is
# masked and drawn in red: NaN cells, plus any negative markers pysheds may
# use for pits/flats (confirm for the installed version). Unmasked, negative
# markers would fall below vmin and share the color of angle 0.
fdir_dinf_values = np.asarray(fdir_dinf, dtype=float)
with np.errstate(invalid='ignore'):  # comparisons against NaN are expected
    is_angle = (fdir_dinf_values >= 0) & (fdir_dinf_values <= 2 * np.pi)
fdir_dinf_masked = np.ma.masked_where(~is_angle, fdir_dinf_values)

# Copy the colormap so set_bad does not modify matplotlib's shared instance
dinf_cmap = copy.copy(plt.get_cmap(CMAP))
dinf_cmap.set_bad('red', 1.)

plt.imshow(fdir_dinf_masked, extent=grid.extent, cmap=dinf_cmap, zorder=2, vmin = 0, vmax = 2 * np.pi)
ticks = [0, np.pi, 2 * np.pi]
cbar = plt.colorbar(ticks = ticks)
cbar.set_ticklabels(["0 (east?)", "pi (west?)", "two pi (east?)"])
plt.xlabel('Longitude')
plt.ylabel('Latitude')
plt.title('Flow direction grid (DInf)', size=14)
plt.grid(zorder=-1)
plt.tight_layout()

In [None]:
# Calculate flow accumulation
# Uses the D8 directions. dirmap must be the same tuple that produced `fdir`
# so the direction codes are interpreted consistently.
acc = grid.accumulation(fdir, dirmap=dirmap)

In [None]:
# Plot flow accumulation for D8
fig, ax = plt.subplots(figsize=(8,6))
fig.patch.set_alpha(0)
plt.grid(True, zorder=0)
# Accumulation values are upstream-cell counts, not direction codes, so they
# span several orders of magnitude. A log scale from 1 to the maximum keeps
# the channel network visible; fixed limits of 1..8 would saturate every cell
# with more than 8 upstream cells. This is independent of the dirmap scheme.
im = ax.imshow(acc, extent=grid.extent, zorder=2,
               cmap='cubehelix',
               norm=colors.LogNorm(1, acc.max()),
               interpolation='bilinear')
plt.colorbar(im, ax=ax, label='Upstream Cells')
plt.title('Flow Accumulation', size=14)
plt.xlabel('Longitude')
plt.ylabel('Latitude')
plt.tight_layout()

In [None]:
# Display the grid's bounding box. The following cells index it as
# (xmin, ymin, xmax, ymax), e.g. to place the pour point at its center.
grid.bbox

In [None]:
# Delineate a catchment

# Pour point: the center of the grid's bounding box.
# Other pour points tried earlier: (-97.294, 32.737), (-113.0495, 47.2395),
# (-114.0, 48.0).
xmin, ymin, xmax, ymax = grid.bbox
x = xmin + (xmax - xmin) / 2
y = ymin + (ymax - ymin) / 2

# Snap the pour point to a high-accumulation cell (more than 10000 upstream
# cells). To skip snapping instead, use: x_snap, y_snap = x, y
channel_mask = acc > 10000
x_snap, y_snap = grid.snap_to_mask(channel_mask, (x, y))

# Delineate the catchment draining to the snapped pour point
catch = grid.catchment(x=x_snap, y=y_snap, fdir=fdir, dirmap=dirmap,
                       xytype='coordinate')

# Optional: clip the grid to the catchment. The deep copy keeps `grid`
# reusable.
#   clipped_grid = copy.deepcopy(grid)
#   clipped_grid.clip_to(catch)
#   clipped_catch = clipped_grid.view(catch)
#   grid_use = clipped_grid

# Default: no clipping; later cells work on the full grid.
grid_use = grid
clipped_catch = catch

In [None]:
# Plot the catchment
fig, ax = plt.subplots(figsize=(8,6))
fig.patch.set_alpha(0)

plt.grid(True, zorder=0)
# Keep catchment cells and turn everything else into NaN, which imshow leaves
# transparent, so only the catchment is drawn.
im = ax.imshow(np.where(clipped_catch, clipped_catch, np.nan), extent=grid_use.extent,
               zorder=1, cmap='Greys_r')

# this shows the pour point
plt.scatter([x_snap], [y_snap], c='red', s=50, marker='o')

plt.xlabel('Longitude')
plt.ylabel('Latitude')
plt.title('Delineated Catchment', size=14)
plt.tight_layout()

In [None]:
# Extract river network
# Cells with more than 50 upstream cells are treated as channels. The result
# is GeoJSON-like: the plotting cell below reads
# branches['features'][i]['geometry']['coordinates'].
branches = grid_use.extract_river_network(fdir, acc > 50, dirmap=dirmap)

In [None]:
# Plot the river network

sns.set_palette('husl')  # each branch gets the next color from this palette
fig, ax = plt.subplots(figsize=(8.5,6.5))

# Frame the plot to the grid's bounding box (xmin, ymin, xmax, ymax)
ax.set_xlim(grid_use.bbox[0], grid_use.bbox[2])
ax.set_ylim(grid_use.bbox[1], grid_use.bbox[3])
ax.set_aspect('equal')

# Snapped pour point, drawn on top of the channels
plt.scatter([x_snap], [y_snap], c='red', s=50, marker='o', zorder=3)

# One polyline per branch; coordinates are (x, y) pairs
for feature in branches['features']:
    coords = np.asarray(feature['geometry']['coordinates'])
    ax.plot(coords[:, 0], coords[:, 1])

_ = ax.set_title('D8 channels', size=14)

In [None]:
# Calculate distance to outlet from each cell
# The outlet is the snapped pour point. Distances presumably follow the D8
# paths in `fdir`; the plot below labels them in cells.
dist = grid_use.distance_to_outlet(x=x_snap, y=y_snap, fdir=fdir, dirmap=dirmap,
                               xytype='coordinate')

In [None]:
# Plot flow distance to the outlet
fig, ax = plt.subplots(figsize=(8,6))
fig.patch.set_alpha(0)
plt.grid(True, zorder=0)
im = ax.imshow(dist, extent=grid_use.extent, zorder=2,
               cmap='cubehelix_r')
plt.colorbar(im, ax=ax, label='Distance to outlet (cells)')

# this shows the pour point
plt.scatter([x_snap], [y_snap], c='red', s=50, marker='o', zorder=3)

plt.xlabel('Longitude')
plt.ylabel('Latitude')
plt.title('Flow Distance', size=14)
plt.tight_layout()