# Mask output SkySat DEMs using the per-pixel NMAD (normalized median absolute deviation)

In [None]:
import xarray as xr
import rioxarray as rxr
import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap
import numpy as np
import os, glob
from scipy.stats import iqr

## Define paths to data files

In [None]:
# Input file names
site_name = 'MCS'
dem_date = '20240420'
data_path = f'/Users/rdcrlrka/Research/PhD/study-sites/{site_name}/{dem_date}'
ortho_fn = os.path.join(data_path, f'{site_name}_{dem_date}_orthomosaic.tif')
dem_fn = os.path.join(data_path, f'{site_name}_{dem_date}_DEM.tif')
nmad_fn = os.path.join(data_path, f'{site_name}_{dem_date}_nmad_mos.tif')

# Masked DEM (output) file name. Only the trailing extension is swapped, so a
# '.tif' appearing elsewhere in the path (e.g. in a directory name) is untouched.
dem_masked_fn = os.path.splitext(dem_fn)[0] + '_masked.tif'

# Warn (without stopping the notebook) about any input that cannot be found
for fn in (dem_fn, nmad_fn, ortho_fn):
    if not os.path.exists(fn):
        print('File does not exist, check name:', fn)

## Load DEM, NMAD, and orthomosaic

In [None]:
# Load files
def _load_masked(fn):
    """Open a raster with rioxarray, with its nodata pixels set to NaN.

    ``masked=True`` lets rioxarray do the nodata replacement. This also works
    when the nodata value is itself NaN, or when no nodata value is set at all.
    """
    return rxr.open_rasterio(fn, masked=True)


def _extent_km(da):
    """Return an imshow ``extent`` (left, right, bottom, top) in CRS units / 1000.

    Uses the raster's pixel-edge bounds rather than the min/max pixel-centre
    coordinates, so the image is not offset by half a pixel.
    NOTE(review): the /1e3 assumes a projected CRS in metres (i.e. km axes).
    """
    left, bottom, right, top = da.rio.bounds()
    return (left / 1e3, right / 1e3, bottom / 1e3, top / 1e3)


ortho = _load_masked(ortho_fn)
dem = _load_masked(dem_fn)
nmad = _load_masked(nmad_fn)

# Plot
cbar_shrink = 0.7
fig, ax = plt.subplots(2, 2, figsize=(10, 10))
ax = ax.flatten()
ortho_im = ax[0].imshow(ortho.data[0], cmap='Greys_r', extent=_extent_km(ortho))
fig.colorbar(ortho_im, ax=ax[0], shrink=cbar_shrink, label='Reflectance')
dem_im = ax[1].imshow(dem.data[0], cmap='terrain', extent=_extent_km(dem))
fig.colorbar(dem_im, ax=ax[1], shrink=cbar_shrink, label='Elevation [m]')
nmad_im = ax[2].imshow(nmad.data[0], cmap='Reds', extent=_extent_km(nmad))
fig.colorbar(nmad_im, ax=ax[2], shrink=cbar_shrink, label='NMAD [m]')
# Drop nodata (NaN) pixels: np.histogram cannot auto-detect a bin range when
# the data contain NaN.
nmad_flat = np.ravel(nmad.data)
ax[3].hist(nmad_flat[~np.isnan(nmad_flat)], bins=100)
ax[3].set_title('NMAD [m]')
ax[3].set_ylabel('Pixel count')

plt.show()

## Test some methods for automatically setting an NMAD threshold

In [None]:
# Interpolate DEM to NMAD coordinates (nearest neighbour, so only existing
# elevation values are used) so the two rasters can be compared pixel by pixel
dem_interp = dem.interp(x=nmad.x, y=nmad.y, method='nearest')

# Quick visual check of the interpolated DEM, in its own labelled figure
fig, ax = plt.subplots()
ax.imshow(dem_interp.data[0])
ax.set_title('DEM interpolated to NMAD grid')
plt.show()

# White -> dark red colormap for the NMAD panels
cmap = LinearSegmentedColormap.from_list("Custom", ['w', '#67000d'])

# Flatten the NMAD once; the statistics below all ignore NaN (nodata) pixels
nmad_vals = np.ravel(nmad.data)

# Define thresholds as (value, label) pairs.
# NOTE: '1.5 x IQR' is an absolute threshold of 1.5 times the interquartile
# range, not Tukey's outlier fence (Q3 + 1.5 x IQR).
thresholds = [
    (1.5 * iqr(nmad_vals, nan_policy='omit'), '1.5 x IQR(NMAD)'),
    (np.nanpercentile(nmad_vals, 95), 'P$_{95}$(NMAD)'),
    (np.nanpercentile(nmad_vals, 99), 'P$_{99}$(NMAD)'),
    (5, 'value'),
    (10, 'value'),
    (20, 'value'),
]

# Iterate over thresholds
for thresh, thresh_name in thresholds:
    # 1 where NMAD >= threshold, else 0. NaN NMAD pixels compare False, so the
    # DEM is kept there -- NOTE(review): confirm that is the intended behaviour.
    mask = xr.where(nmad >= thresh, 1, 0)
    dem_interp_masked = xr.where(mask == 1, np.nan, dem_interp)
    # plot
    fig, ax = plt.subplots(1, 2, figsize=(12, 5))
    # Colour scale tops out at the threshold, so all masked pixels are drawn
    # in the darkest red
    nmad_im = ax[0].imshow(nmad.data[0], cmap=cmap, clim=(0, thresh))
    fig.colorbar(nmad_im, ax=ax[0], label='NMAD [m]')
    ax[0].set_title('NMAD')
    dem_masked_im = ax[1].imshow(dem_interp_masked.data[0])
    fig.colorbar(dem_masked_im, ax=ax[1], label='Elevation [m]')
    ax[1].set_title('Masked DEM')
    fig.suptitle(f"Mask NMAD >= {thresh_name} ({np.round(thresh, 2)})")
    plt.show()
