# Add canopy height model to reference DEM

In [1]:
import os
import xdem
import geoutils as gu
import matplotlib.pyplot as plt
import matplotlib
import numpy as np
import rioxarray as rxr
import xarray as xr

In [2]:
# Define input layers
# Root folder of the study site; both rasters live in its 'refdem' subfolder.
data_dir = '/Volumes/LaCie/raineyaberle/Research/PhD/Skysat-Stereo/study-sites/MCS/'
refdem_dir = os.path.join(data_dir, 'refdem')
# Reference DEM and canopy height model (CHM) file names
refdem_fn = os.path.join(refdem_dir, 'MCS_REFDEM_WGS84.tif')
chm_fn = os.path.join(refdem_dir, 'chm_mcs_1m.tif')

In [15]:
# Explore the canopy height model
# Open the CHM raster and drop any length-1 dimensions
chm = rxr.open_rasterio(chm_fn)
chm = chm.squeeze()
# Keep the CRS reported by the opened raster
crs = chm.rio.crs
# Remove no data values: anything below -1e10 is replaced with NaN
is_nodata = chm < -1e10
chm = xr.where(is_nodata, np.nan, chm)

# Plot pre-masked data
# plt.figure(figsize=(12,12))
# lt0_mask = xr.where(chm < 0, 1, np.nan)
# im = plt.imshow(lt0_mask.data, cmap='Reds', clim=(0,1))
# plt.colorbar(orientation='horizontal', label='Canopy height [m]', shrink=0.8)
# plt.xticks([])
# plt.yticks([])
# # ax[1].hist(np.ravel(chm.data), bins=100)
# # ax[1].set_xlabel('Canopy height [m]')
# plt.show()

# # Mask values < 0
# chm = xr.where(chm < 0, np.nan, chm)

# # plot
# fig, ax = plt.subplots(1, 2, figsize=(10,5))
# im = ax[0].imshow(chm.data, cmap='Greens', clim=(0,30))
# fig.colorbar(im, ax=ax[0], orientation='horizontal', label='Canopy height [m]', shrink=0.8)
# ax[0].set_xticks([])
# ax[0].set_yticks([])
# ax[1].hist(np.ravel(chm.data), bins=100)
# ax[1].set_xlabel('Canopy height [m]')
# plt.show()

In [23]:
# Fraction of valid (non-NaN) CHM pixels whose canopy height is negative.
# count_nonzero counts the True entries directly instead of building the
# full index array that argwhere returns just to take its length.
n_negative = np.count_nonzero(chm.data < 0)
n_valid = np.count_nonzero(~np.isnan(chm.data))
n_negative / n_valid

0.045575943295401926

In [None]:
# Add vegetation to bare earth DEM
# Load the reference DEM and the canopy height model (CHM)
refdem = xdem.DEM(refdem_fn)
chm = gu.Raster(chm_fn, load_data=True)

# Reproject the CHM using the reference DEM as the target, then sum the two
chm = chm.reproject(refdem)
refdem_chm = refdem + chm
# NOTE(review): the CHM is added without any filtering; the exploration cell
# above found negative canopy heights in the raw CHM - confirm that adding
# them to the DEM is intended.

# Plot the DEM, the CHM, and their sum side by side
fig, ax = plt.subplots(1, 3, figsize=(12,5))
panels = (
    (refdem, {'cmap': 'terrain'}),
    (chm, {'cmap': 'Greens', 'vmin': 0, 'vmax': 45}),
    (refdem_chm, {'cmap': 'terrain'}),
)
for axis, (raster, plot_kwargs) in zip(ax, panels):
    raster.plot(ax=axis, **plot_kwargs)
fig.tight_layout()
plt.show()

# Save next to the reference DEM with a '_CHM' suffix
out_fn = refdem_fn.replace('.tif', '_CHM.tif')
refdem_chm.save(out_fn)
print('Reference DEM + CHM saved to file:', out_fn)

In [None]:
# Mask areas where veg > threshold, keeping pixels with canopy height <= threshold
# (uses `chm`, `refdem` and `refdem_chm` from the previous cell)
threshold = 0
# Strict comparison so that pixels exactly at the threshold are kept; this
# matches the "lte" / "<=" wording in the output file name and message below.
# The previous `>=` also removed pixels equal to the threshold.
mask = (chm > threshold)
# Combine the vegetation mask with pixels already masked in the DEM + CHM sum.
# getmaskarray always returns a full boolean array, even if nothing is masked.
invalid = (mask.data == 1) | np.ma.getmaskarray(refdem_chm.data)
masked_data = np.ma.masked_where(invalid, refdem_chm.data)
refdem_chm_masked = gu.Raster.from_array(data=masked_data,
                                         transform=refdem.transform,
                                         crs=refdem.crs,
                                         nodata=-9999)
# Plot
refdem_chm_masked.plot()

# Save to file
out_fn = refdem_fn.replace('.tif', f'_CHM-lte-{threshold}m.tif')
refdem_chm_masked.save(out_fn)
print(f'Reference DEM + CHM <= {threshold} m saved to file:', out_fn)