# Pipeline for coregistering a SkySat DEM to a reference DEM and differencing the two

In [None]:
import os, glob
import matplotlib.pyplot as plt
import matplotlib
import xarray as xr
import rioxarray as rxr
import rasterio as rio
import numpy as np
from tqdm.auto import tqdm
import xdem
import geoutils as gu
from skimage.filters import threshold_otsu
import math
import pandas as pd
import geopandas as gpd
import json
from affine import Affine

## Define paths in directory

In [None]:
# Define path to this code repo
base_path = '/Users/raineyaberle/Research/PhD/SnowDEMs/snow-dems/'

# Define site name and source DEM date for convenience
data_dir = '/Volumes/LaCie/raineyaberle/Research/PhD/SkySat-Stereo/study-sites/MCS'
refdem_fn = os.path.join(data_dir, 'refdem', 'MCS_REFDEM_WGS84_CHM.tif')
sourcedem_fn = os.path.join(data_dir, '20240420', 'MCS_20240420-1_DEM.tif')
ortho_fn = os.path.join(data_dir, '20240420', 'MCS_20240420-1_4band_orthomosaic.tif')
roads_vector_fn = os.path.join(data_dir, 'roads', 'MCS_roads_polygon.shp')

# Define path for results
out_dir = '/Volumes/LaCie/raineyaberle/Research/PhD/SkySat-Stereo/snow_depth_maps'
job_name = 'MCS_20240420-1_reftrees_nosourcetrees'

# Check that every input file used by later cells exists. Only a message is
# printed (the notebook keeps running), so fix any reported path before continuing.
for input_label, input_fn in [('Reference DEM', refdem_fn),
                              ('Source DEM', sourcedem_fn),
                              ('Orthomosaic', ortho_fn),
                              ('Roads vector', roads_vector_fn)]:
    if not os.path.exists(input_fn):
        print(f'{input_label} not found, check path to file before continuing.')
if not os.path.exists(out_dir):
    # makedirs also creates missing parent directories (os.mkdir would fail there)
    os.makedirs(out_dir)
    print('Created directory for output files:', out_dir)

# Define all output directories and files
# New output directory for full job
out_dir = os.path.join(out_dir, job_name)
# step 1: trees and stable surfaces (snow-free) masks
masks_dir = os.path.join(out_dir, 'land_cover_masks')
# step 2: Initial coregistration and differencing
coreg_init_dir = os.path.join(out_dir, 'coreg_initial_diff')
# step 3: Quadratic bias correction
deramp_dir = os.path.join(out_dir, 'deramping')
# step 4: Final coregistration and differencing
coreg_final_dir = os.path.join(out_dir, 'coreg_final_diff')

# Create directories (no-op for those that already exist)
for directory in [out_dir, masks_dir, coreg_init_dir, deramp_dir, coreg_final_dir]:
    os.makedirs(directory, exist_ok=True)

## Create land cover masks for trees, roads, snow, and stable surfaces

In [None]:
trees_mask_fn = os.path.join(masks_dir, 'trees_mask.tif')
roads_mask_fn = os.path.join(masks_dir, 'roads_mask.tif')
snow_mask_fn = os.path.join(masks_dir, 'snow_mask.tif')
ss_mask_fn = os.path.join(masks_dir, 'stable_surfaces_mask.tif')
fig_fn = os.path.join(masks_dir, 'land_cover_masks.png')

# The stable surfaces mask is the last of the four masks written below, so its
# presence is used as the "already done" marker for this whole cell.
if not os.path.exists(ss_mask_fn):
    # Load orthomosaic
    ortho_rxr = rxr.open_rasterio(ortho_fn)
    # Rearrange ortho into one variable per band (assumes band order blue, green, red, NIR)
    band_names = ['blue', 'green', 'red', 'NIR']
    ortho = xr.Dataset(coords=dict(y=ortho_rxr.y.data, x=ortho_rxr.x.data))
    for i, band_name in enumerate(band_names):
        ortho[band_name] = (('y', 'x'), ortho_rxr.data[i, :])
    # 0 marks no-data; all other values are divided by 1e4 (presumably DN -> reflectance)
    ortho = xr.where(ortho == 0, np.nan, ortho / 1e4)
    ortho = ortho.rio.write_crs(ortho_rxr.rio.crs)

    # Classify trees and other vegetation with an NDVI-style index from NIR and green.
    # No-data pixels are NaN in every band, so the index is NaN there and the
    # threshold test evaluates to False (0).
    ndvi = (ortho.NIR - ortho.green) / (ortho.NIR + ortho.green)
    ndvi_threshold = 0.1
    trees_mask = (ndvi >= ndvi_threshold).astype(int)
    trees_mask = trees_mask.assign_attrs({'Description': 'Constructed by thresholding the NDVI of the orthomosaic image',
                                          'NDVI bands': 'NIR, green',
                                          'NDVI threshold': f'{ndvi_threshold}'})

    # Convert roads to a rasterized mask on the ortho grid: clip with drop=False keeps
    # the full grid and blanks pixels outside the road polygons; blanked (NaN) -> 0.
    roads_vector = gpd.read_file(roads_vector_fn)  # Load roads vector
    roads_mask = ortho.blue.rio.clip(roads_vector.geometry.values, roads_vector.crs, drop=False)
    roads_mask = xr.where(np.isnan(roads_mask), 0, 1)
    roads_mask = roads_mask.assign_attrs({'Description': 'Roads mask constructed from the Source by buffering, rasterizing, and interpolating the shapefile to the orthomosaic image grid.',
                                          'Source': 'U.S. Geological Survey National Transportation Dataset for Idaho (published 20240215) Shapefile: https://www.sciencebase.gov/catalog/item/5a5f36bfe4b06e28e9bfc1be'})

    # Classify snow with an NDSI-style index from red and NIR; tree and road pixels
    # are never labelled snow.
    ndsi = (ortho.red - ortho.NIR) / (ortho.red + ortho.NIR)
    ndsi_threshold = 0.1
    snow_mask = ((ndsi >= ndsi_threshold) & (trees_mask == 0) & (roads_mask == 0)).astype(int)
    snow_mask = snow_mask.assign_attrs({'Description': 'Snow mask constructed by thresholding the NDSI of the orthomosaic image',
                                        'NDSI bands': 'red, NIR',
                                        'NDSI threshold': f'{ndsi_threshold}'})

    # Create stable surfaces mask (unclassified + roads): everything that is
    # neither snow nor trees. Roads are included because snow_mask excludes them.
    ss_mask = ((snow_mask == 0) & (trees_mask == 0)).astype(int)
    ss_mask = ss_mask.assign_attrs({'Description': 'Stable surfaces mask includes all road-covered, snow-free, and tree-free surfaces according to the trees_mask, snow_mask, and roads_mask files.'})

    # Plot
    # Image extent in km, shared by both panels
    extent_km = (np.min(ortho.x.data) / 1e3, np.max(ortho.x.data) / 1e3,
                 np.min(ortho.y.data) / 1e3, np.max(ortho.y.data) / 1e3)
    plt.rcParams.update({'font.size': 12, 'font.sans-serif': 'Arial'})
    fig, ax = plt.subplots(2, 1, figsize=(8, 16))
    ax[0].imshow(np.dstack([ortho.red, ortho.green, ortho.blue]) * 0.5, extent=extent_km)
    ax[0].set_title('RGB orthoimage')
    xmin, xmax = ax[0].get_xlim()
    ymin, ymax = ax[0].get_ylim()
    # Iterate over masks; each is drawn with transparent 0s and a solid colour for 1s
    colors = [(77/255, 175/255, 74/255, 1),   # trees
              (55/255, 126/255, 184/255, 1),  # snow
              (166/255, 86/255, 40/255, 1)]   # roads
    for color, mask, label in zip(colors,
                                  [trees_mask, snow_mask, roads_mask],
                                  ['trees_mask', 'snow_mask', 'roads_mask']):
        cmap = matplotlib.colors.ListedColormap([(1, 1, 1, 0), color])
        ax[1].imshow(mask.data, cmap=cmap, clim=(0, 1), extent=extent_km)
        # plot dummy point for legend
        ax[1].plot(0, 0, 's', color=color, markersize=5, label=label)
    ax[1].set_title('Land cover masks')
    # reset axes limits (the dummy legend points at (0, 0) expand them)
    ax[1].set_xlim(xmin, xmax)
    ax[1].set_ylim(ymin, ymax)
    ax[1].legend(loc='lower right', markerscale=2)
    fig.tight_layout()

    # Save to file
    for mask, mask_fn in zip([trees_mask, roads_mask, snow_mask, ss_mask],
                             [trees_mask_fn, roads_mask_fn, snow_mask_fn, ss_mask_fn]):
        # xr.where drops attrs by default, so carry the Description/threshold
        # metadata over explicitly before writing it to the GeoTIFF.
        mask_attrs = dict(mask.attrs)
        mask = xr.where(np.isnan(ortho.green), -9999, mask)
        mask = mask.astype(np.int16)
        mask = mask.assign_attrs({**mask_attrs, '_FillValue': -9999})
        mask = mask.rio.write_crs(ortho_rxr.rio.crs)
        mask.rio.to_raster(mask_fn)
        print('Mask saved to file:', mask_fn)
    fig.savefig(fig_fn, dpi=300, bbox_inches='tight')
    print('Figure saved to file:', fig_fn)
    # Show only after saving, so the written figure does not depend on the
    # backend's show() behaviour.
    plt.show()

else:
    print('Land cover masks already exist in file, skipping...')

In [None]:
# Optional: Mask trees in DEM using the trees mask written by the previous cell
sourcedem_masked_fn = os.path.join(masks_dir, os.path.basename(sourcedem_fn).replace('.tif', '_trees-masked.tif'))

if not os.path.exists(sourcedem_masked_fn):
    sourcedem = rxr.open_rasterio(sourcedem_fn).squeeze()
    # save some traits from input file (xr.where below does not keep them)
    crs = sourcedem.rio.crs
    attrs = sourcedem.attrs
    trees_mask = rxr.open_rasterio(trees_mask_fn).squeeze()
    # Resample the trees mask onto the DEM grid and take its VALUES (.data), not a
    # coordinate. Nearest-neighbour keeps the mask strictly 0/1/-9999; linear
    # interpolation would create fractional values that never equal 1. Pixels
    # outside the mask's extent come back NaN and are therefore left unmasked.
    trees_mask = trees_mask.interp(x=sourcedem.x.data, y=sourcedem.y.data, method='nearest').data
    # Mask source DEM where trees_mask == 1
    sourcedem_masked = xr.where(trees_mask == 1, -9999, sourcedem)
    # Save to file
    sourcedem_masked = sourcedem_masked.rio.write_crs(crs)
    sourcedem_masked = sourcedem_masked.assign_attrs(attrs)
    sourcedem_masked.rio.to_raster(sourcedem_masked_fn)
    print('Masked source DEM saved to file:', sourcedem_masked_fn)
    sourcedem_masked.plot()
    plt.show()

else:
    print('Masked source DEM already exists in file, skipping.')
    


In [None]:
# # -----Use roads polygon for plotting differences
# if not os.path.exists(roads_mask_fn):
#     refdem = xdem.DEM(refdem_fn)
#     roads_fn = os.path.join(data_dir, 'MCS_roads_polygon.shp')
#     roads_vector = gu.Vector(roads_fn)
#     roads_raster = roads_vector.rasterize(refdem)
#     roads_raster.data.mask = refdem.data.mask
#     roads_raster = roads_raster.astype(np.int16)
#     roads_raster.set_nodata(-9999)
#     roads_raster.save(roads_mask_fn)
#     print('Roads raster saved to file:', roads_mask_fn)
    
#     roads_raster.plot()
#     plt.show()

# else:
#     print('Roads mask already exists in file.')
    

In [None]:
# # -----Apply Otsu thresholding to the orthomosaic
# # outputs
# ss_mask_fn = os.path.join(stable_surfaces_dir, 'stable_surfaces.tif')
# ss_mask_fig_fn = ss_mask_fn.replace('.tif', '.png')

# # Check if stable surfaces already exist in file
# if os.path.exists(ss_mask_fn):
#     print('Stable surfaces mask already exists in file.')

# else:

#     # Open orthomosaic
#     ortho = rxr.open_rasterio(ortho_fn)
#     # Save CRS for later
#     crs = ortho.rio.crs 
#     # Grab data
#     image = ortho.data[0]
#     # Create no-data mask
#     nodata_mask = image == 0
    
#     # Calculate and apply Otsu threshold
#     otsu_thresh = threshold_otsu(image)
#     ss_mask = (image < otsu_thresh).astype(float)

#     # Apply nodata mask
#     ss_mask[nodata_mask] = np.nan
#     image[nodata_mask] = np.nan
    
#     # Plot results
#     fig, axes = plt.subplots(1, 3, figsize=(12, 5))
#     ax = axes.ravel()
#     im = ax[0].imshow(image, cmap=plt.cm.gray,
#                       extent=(np.min(ortho.x.data)/1e3, np.max(ortho.x.data)/1e3,
#                               np.min(ortho.y.data)/1e3, np.max(ortho.y.data)/1e3))
#     fig.colorbar(im, ax=ax[0], orientation='horizontal', shrink=0.9)
#     ax[0].set_title('Original')
#     ax[1].hist(image.ravel(), bins=100, color='grey')
#     ax[1].set_title('Histogram')
#     ax[1].axvline(otsu_thresh, color='r')
#     ax[1].text(otsu_thresh + (np.nanmax(image)-np.nanmin(image))*0.1,
#                ax[1].get_ylim()[0] + (ax[1].get_ylim()[1]-ax[1].get_ylim()[0])*0.9,
#                f'Otsu threshold = \n{np.round(otsu_thresh, 2)}', color='r')
#     cmap_binary = matplotlib.colors.ListedColormap(['w', 'k'])
#     im = ax[2].imshow(ss_mask, cmap=cmap_binary,
#                       extent=(np.min(ortho.x.data)/1e3, np.max(ortho.x.data)/1e3,
#                               np.min(ortho.y.data)/1e3, np.max(ortho.y.data)/1e3))
#     cbar = fig.colorbar(im, ax=ax[2], orientation='horizontal', shrink=0.9, ticks=[0.25, 0.75])
#     cbar.ax.set_xticklabels(['unstable', 'stable'])
#     ax[2].set_title('Stable surfaces mask')

#     plt.show()

#     # Save results to file
#     ss_mask[np.isnan(ss_mask)] = -9999
#     ss_mask = ss_mask.astype(int)
#     ss_mask_xr = xr.DataArray(data = ss_mask,
#                               coords = dict(y = ortho.y, x = ortho.x), 
#                               attrs = dict(Description = 'Stable surfaces mask generated using Otsu thresholding of the input image. 1 = stable, 0 = unstable',
#                                            InputImage = os.path.basename(ortho_fn),
#                                            OtsuThreshold = otsu_thresh,
#                                            _FillValue = -9999)
#                               )
#     ss_mask_xr = ss_mask_xr.rio.write_crs(crs)
#     ss_mask_xr.rio.to_raster(ss_mask_fn)
#     print('Stable surfaces mask saved to file:', ss_mask_fn)
    
#     fig.savefig(ss_mask_fig_fn, dpi=300, bbox_inches='tight')
#     print('Figure saved to file:', ss_mask_fig_fn)


In [None]:
# # Define output files
# trees_mask_fn = os.path.join(stable_surfaces_dir, 'trees_mask.tif')
# ss_mask_fn = os.path.join(stable_surfaces_dir, 'stable_surfaces_mask.tif')
# snow_mask_fn = os.path.join(stable_surfaces_dir, 'snow_mask.tif')
# sourcedem_masked_fn = os.path.join(stable_surfaces_dir, 
#                                    os.path.basename(sourcedem_fn).replace('.tif', '_masked-trees.tif'))
# fig_fn = os.path.join(stable_surfaces_dir, 'masks.png')
    
# if ((not os.path.exists(trees_mask_fn)) or (not os.path.exists(ss_mask_fn)) 
#     or (not os.path.exists(snow_mask_fn)) or (not os.path.exists(sourcedem_masked_fn))):
#     # Load orthomosaic
#     ortho = gu.Raster(ortho_fn, load_data=True)
#     ortho.set_nodata(0)
#     ortho = ortho / 1e4
#     # Reproject to input DEM
#     sourcedem = xdem.DEM(sourcedem_fn)
#     ortho = ortho.reproject(sourcedem)


#     # Create trees mask
#     ndvi = (ortho.data[3] - ortho.data[2]) / (ortho.data[3] + ortho.data[2])
#     trees_mask = (ndvi >= 0.1).astype(int)
#     trees_mask.data[ortho.data[3]==0] = -9999
#     trees_mask_gu = gu.Raster.from_array(data=trees_mask, 
#                                          transform=ortho.transform, 
#                                          crs=ortho.crs, 
#                                          nodata=-9999)
#     trees_mask_gu.save(trees_mask_fn)
#     print('Trees mask saved to file:', trees_mask_fn)

#     # Create snow mask
#     ndsi = (ortho.data[1] - ortho.data[3]) / (ortho.data[1] + ortho.data[3])
#     snow_mask = (ndsi >= 0.1).astype(int)
#     snow_mask.data[ortho.data[3]==0] = -9999
#     snow_mask_gu = gu.Raster.from_array(data=snow_mask, 
#                                         transform=ortho.transform, 
#                                         crs=ortho.crs, 
#                                         nodata=-9999)
#     snow_mask_gu.save(snow_mask_fn)
#     print('Snow mask saved to file:', snow_mask_fn)

#     # Create stable surfaces mask (snow-free)
#     ss_mask = (snow_mask==0).astype(int) #np.logical_and(trees_mask==0, snow_mask==0).astype(int)
#     ss_mask.data[ortho.data[3]==0] = -9999
#     ss_mask_gu = gu.Raster.from_array(data=ss_mask, 
#                                       transform=ortho.transform, 
#                                       crs=ortho.crs, 
#                                       nodata=-9999)
#     ss_mask_gu.save(ss_mask_fn)
#     print('Snow mask saved to file:', ss_mask_fn)
    
#     # Mask trees in DEM
#     # trees_mask_gu = trees_mask_gu.reproject(sourcedem)
#     new_mask = np.logical_or(sourcedem.data==-9999, trees_mask_gu.data.data==1)
#     sourcedem_masked = sourcedem
#     sourcedem_masked.set_mask(new_mask)
#     sourcedem_masked_gu = gu.Raster.from_array(data=sourcedem.data,
#                                                transform=ortho.transform, 
#                                                crs=ortho.crs, 
#                                                nodata=-9999)
#     sourcedem_masked_gu.save(sourcedem_masked_fn)
#     print('Masked DEM saved to file:', sourcedem_masked_fn)

#     # Plot results
#     fig, ax = plt.subplots(3, 2, figsize=(12,15))
#     ax = ax.flatten()
#     cmap = plt.get_cmap('Greys', 2) # binary colormap for masks
#     # NDVI
#     im = ax[0].imshow(ndvi, cmap='PRGn', vmin=-0.5, vmax=0.5)
#     fig.colorbar(im, ax=ax[0], orientation='horizontal', shrink=0.5)
#     ax[0].set_title('NDVI')
#     # trees mask
#     im = ax[1].imshow(trees_mask, cmap=cmap, vmin=0, vmax=1)
#     cbar = fig.colorbar(im, ax=ax[1], orientation='horizontal', shrink=0.5, ticks=[0.25, 0.75])
#     cbar.ax.set_xticklabels(['Tree-free', 'Tree-covered'])
#     ax[1].set_title('Trees mask')
#     # NDSI
#     im = ax[2].imshow(ndsi, cmap='BrBG', vmin=-0.5, vmax=0.5)
#     fig.colorbar(im, ax=ax[2], orientation='horizontal', shrink=0.5)
#     ax[2].set_title('NDSI')
#     # snow mask
#     im = ax[3].imshow(snow_mask, cmap=cmap, vmin=0, vmax=1)
#     cbar = fig.colorbar(im, ax=ax[3], orientation='horizontal', shrink=0.5, ticks=[0.25, 0.75])
#     cbar.ax.set_xticklabels(['Snow-free', 'Snow-covered'])
#     ax[3].set_title('Snow mask')
#     # stable surfaces 
#     ax[4].imshow(ss_mask, cmap=cmap, vmin=0, vmax=1)
#     cbar = fig.colorbar(im, ax=ax[4], orientation='horizontal', shrink=0.5, ticks=[0.25, 0.75])
#     cbar.ax.set_xticklabels(['Unstable', 'Stable'])
#     ax[4].set_title('Stable surfaces mask')
#     # masked DEM
#     im = ax[5].imshow(sourcedem_masked_gu.data, cmap='terrain')
#     fig.colorbar(im, ax=ax[5], orientation='horizontal', shrink=0.5, label='Elevation [m]')
#     ax[5].set_title('Masked DEM')
#     for axis in ax:
#         axis.set_xticks([])
#         axis.set_yticks([])
#     fig.tight_layout()
#     plt.show()
    
#     fig.savefig(fig_fn, dpi=300, bbox_inches='tight')
#     print('Figure saved to file:', fig_fn)
    
# else:
#     print('Masks already exist in file, skipping.')

## Initial coregistration and differencing

In [None]:
# -----Define functions for coregistration
def create_coreg_object(coreg_name):
    """Build an xdem coregistration object from a method name or a list of names.

    Parameters
    ----------
    coreg_name : str or list of str
        Name of a class in ``xdem.coreg`` (e.g. ``'NuthKaab'``), or a list of
        such names that are chained, in order, into one pipeline with ``+``.
        ``'BiasCorr'`` is instantiated with
        ``bias_vars=["elevation", "slope", "aspect"]``; every other class is
        instantiated with no arguments.

    Returns
    -------
    The coregistration object, or None (after printing a message) if
    ``coreg_name`` is neither a str nor a list.

    Raises
    ------
    ValueError
        If a requested name is not found in ``xdem.coreg``, or the list is empty.
    """
    def _build_step(name):
        # Instantiate a single xdem.coreg step by class name.
        try:
            coreg_cls = getattr(xdem.coreg, name)
        except AttributeError as err:
            raise ValueError(f"Coregistration method '{name}' not found.") from err
        if name == 'BiasCorr':
            return coreg_cls(bias_vars=["elevation", "slope", "aspect"])
        return coreg_cls()

    if isinstance(coreg_name, str):
        return _build_step(coreg_name)
    if isinstance(coreg_name, list):
        if not coreg_name:
            raise ValueError('coreg_name list is empty; give at least one coregistration method.')
        coreg_obj = _build_step(coreg_name[0])
        for name in coreg_name[1:]:
            coreg_obj += _build_step(name)
        return coreg_obj
    print('coreg_method format not recognized, exiting...')
    return None

def calculate_stable_surface_stats(diff_dem, ss_mask):
    """Restrict an elevation-difference raster to stable surfaces and summarize it.

    Parameters
    ----------
    diff_dem : geoutils.Raster or xdem.DEM
        Elevation differences; ``diff_dem.data`` may be a masked array.
    ss_mask : geoutils.Raster
        Stable surfaces mask (1 = stable). Only ``ss_mask.data`` is compared
        with ``diff_dem.data`` pixel-by-pixel, so both must already be on the
        same grid; georeferencing is NOT checked here.

    Returns
    -------
    diff_dem_ss : geoutils.Raster
        Copy of the differences with every non-stable or invalid pixel set to
        the nodata value -9999.
    diff_dem_ss_median : float
        Median of the stable-surface differences (NaN if there are none).
    diff_dem_ss_nmad : float
        NMAD of the stable-surface differences (NaN if there are none).
    """
    diff_values = np.ma.getdata(diff_dem.data)
    # np.where ignores masked-array masks, so masked (no-data) difference cells
    # must be excluded explicitly or their fill values would enter the stats.
    stable = ((np.ma.filled(ss_mask.data, 0) == 1)
              & ~np.ma.getmaskarray(diff_dem.data)
              & np.isfinite(diff_values))
    ss_masked_data = np.where(stable, diff_values, -9999)
    diff_dem_ss = gu.Raster.from_array(
        ss_masked_data,
        transform=diff_dem.transform,
        crs=diff_dem.crs,
        nodata=-9999)

    # `stable` has the same (broadcast) shape as ss_masked_data
    stable_values = ss_masked_data[stable]
    if stable_values.size == 0:
        return diff_dem_ss, np.nan, np.nan
    diff_dem_ss_median = np.median(stable_values)
    # NMAD over the stable surfaces only (previously computed over all surfaces)
    diff_dem_ss_nmad = xdem.spatialstats.nmad(stable_values)
    return diff_dem_ss, diff_dem_ss_median, diff_dem_ss_nmad

def plot_coreg_dh_results(dh_before, dh_before_ss, dh_before_ss_med, dh_before_ss_nmad,
                          dh_after, dh_after_ss, dh_after_ss_med, dh_after_ss_nmad,
                          dh_after_ss_adj, dh_after_ss_adj_ss, dh_after_ss_adj_ss_med, dh_after_ss_adj_ss_nmad,
                          vmin=-10, vmax=10):
    """Plot maps and histograms of elevation differences at three processing stages.

    Rows: before coregistration, after coregistration, and after additionally
    subtracting the stable-surface (SS) median. For each stage the arguments
    are the difference raster, its stable-surface-only raster, and the SS
    median and NMAD (shown in the map title).

    Left column: map of the differences (km axis labels). Right column:
    histogram of all valid differences (grey, left y-axis) overlaid with the
    stable-surface differences (magenta, right y-axis).

    Parameters
    ----------
    vmin, vmax : float
        Colour limits of the maps and range of the histogram bins [m].

    Returns
    -------
    matplotlib.figure.Figure
        The figure (already shown with ``plt.show()``).
    """
    def _valid_values(raster):
        # Drop masked (no-data) pixels and non-finite values before binning
        values = np.ma.compressed(raster.data)
        return values[np.isfinite(values)]

    stages = [
        ('Difference before coreg.', dh_before, dh_before_ss, dh_before_ss_med, dh_before_ss_nmad),
        ('Difference after coreg.', dh_after, dh_after_ss, dh_after_ss_med, dh_after_ss_nmad),
        ('Difference after coreg. - median SS diff.', dh_after_ss_adj, dh_after_ss_adj_ss,
         dh_after_ss_adj_ss_med, dh_after_ss_adj_ss_nmad),
    ]
    fig, ax = plt.subplots(3, 2, figsize=(10, 16))
    bins = np.linspace(vmin, vmax, num=100)
    for (title, dh, dh_ss, ss_med, ss_nmad), (map_ax, hist_ax) in zip(stages, ax):
        # plot dh
        dh.plot(cmap="coolwarm_r", ax=map_ax, vmin=vmin, vmax=vmax)
        map_ax.set_title(f'{title} \nSS median = {np.round(ss_med, 3)}, SS NMAD = {np.round(ss_nmad, 3)}')
        # Relabel map ticks from map units to km (tick positions are unchanged)
        map_ax.set_xticks(map_ax.get_xticks())
        map_ax.set_xticklabels(np.divide(map_ax.get_xticks(), 1e3).astype(str))
        map_ax.set_yticks(map_ax.get_yticks())
        map_ax.set_yticklabels(np.divide(map_ax.get_yticks(), 1e3).astype(str))
        map_ax.set_xlabel('Easting [km]')
        map_ax.set_ylabel('Northing [km]')
        # plot histograms
        hist_ax.hist(_valid_values(dh), bins=bins, color='grey', alpha=0.8, label='All surfaces')
        hist_ax.legend(loc='upper left')
        hist_ax.set_xlabel('Differences [m]')
        hist_ax.set_ylabel('Counts')
        hist_ax.set_xlim(vmin, vmax)
        # Stable surfaces on a twin axis, since they are far fewer pixels
        ss_ax = hist_ax.twinx()
        ss_ax.hist(_valid_values(dh_ss), bins=bins, color='m', alpha=0.8, label='Stable surfaces')
        ss_ax.legend(loc='upper right')
        ss_ax.spines['right'].set_color('m')
        ss_ax.tick_params(axis='y', colors='m')
    fig.tight_layout()
    plt.show()
    return fig
                                           
def coregister_difference_dems(ref_dem_fn=None, source_dem_fn=None, ss_mask_fn=None, out_dir=None,
                               coreg_method='NuthKaab', coreg_stable_only=False, vmin=-10, vmax=10):
    """Coregister a source DEM to a reference DEM, difference them, and save results.

    Steps: resample the source DEM and stable surfaces mask onto the reference
    grid; difference before coregistration; fit and apply the coregistration;
    shift the stable surfaces mask horizontally by the fitted shift (if any);
    difference after coregistration; subtract the stable-surface median
    difference; save outputs and plot.

    Parameters
    ----------
    ref_dem_fn, source_dem_fn : str
        Reference DEM and DEM to be aligned (band 1 read, nodata -9999).
    ss_mask_fn : str
        Stable surfaces mask raster (1 = stable, nodata -9999).
    out_dir : str
        Directory for all outputs.
    coreg_method : str or list of str
        Passed to ``create_coreg_object``.
    coreg_stable_only : bool
        If True, fit the coregistration only over stable surfaces.
    vmin, vmax : float
        Colour/histogram limits of the results figure [m].

    Returns
    -------
    None. Writes to ``out_dir``: ``coregistration_fit.json``,
    ``dem_coregistered.tif``, ``stable_surfaces_coregistered.tif``,
    ``ddem.tif`` and ``ddem_results.png``.
    """
    # Define output file names
    coreg_meta_fn = os.path.join(out_dir, 'coregistration_fit.json')
    dem_coreg_fn = os.path.join(out_dir, 'dem_coregistered.tif')
    ss_mask_shift_fn = os.path.join(out_dir, 'stable_surfaces_coregistered.tif')
    ddem_fn = os.path.join(out_dir, 'ddem.tif')
    fig1_fn = os.path.join(out_dir, 'ddem_results.png')

    # Load input files
    ref_dem = xdem.DEM(gu.Raster(ref_dem_fn, load_data=True, bands=1, nodata=-9999))
    tba_dem = xdem.DEM(gu.Raster(source_dem_fn, load_data=True, bands=1, nodata=-9999))
    ss_mask = gu.Raster(ss_mask_fn, load_data=True, nodata=-9999)

    # Reproject source DEM and stable surfaces mask to reference DEM grid.
    # Nearest-neighbour keeps the mask values exactly 0/1 (bilinear would blur edges).
    tba_dem = tba_dem.reproject(ref_dem)
    ss_mask = ss_mask.reproject(ref_dem, resampling="nearest")

    # Calculate differences before coregistration
    print('Calculating differences before coregistration...')
    diff_before = tba_dem - ref_dem
    # Calculate stable surface stats
    diff_before_ss, diff_before_ss_median, diff_before_ss_nmad = calculate_stable_surface_stats(diff_before, ss_mask)

    # Create and fit the coregistration object
    print('Coregistering source DEM to reference DEM...')
    coreg_obj = create_coreg_object(coreg_method)
    if coreg_stable_only:
        # xdem expects a boolean inlier mask (True = use pixel in the fit)
        coreg_obj.fit(ref_dem, tba_dem, inlier_mask=(ss_mask == 1))
    else:
        coreg_obj.fit(ref_dem, tba_dem)
    # Save the fit coregistration metadata
    meta = coreg_obj.meta
    with open(coreg_meta_fn, 'w') as f:
        json.dump(meta, f)
    print('Coregistration fit saved to file:', coreg_meta_fn)

    # Apply the coregistration object to the source DEM
    aligned_dem = coreg_obj.apply(tba_dem)

    # Apply horizontal components of the coregistration object to the stable surfaces mask if applicable
    if 'shift_x' in meta:
        dx, dy = meta['shift_x'], meta['shift_y']
        # NOTE(review): assumes xdem moves the DEM georeference by +shift_x/+shift_y
        # in map units — confirm for the installed xdem version.
        # The translation is composed on the LEFT so dx/dy are added to the origin
        # in map units; `transform * Affine.translation(dx, dy)` would instead treat
        # them as pixel offsets (scaled by pixel size, y sign flipped).
        ss_mask_shift = ss_mask.copy()
        ss_mask_shift.transform = Affine.translation(dx, dy) * ss_mask.transform
        # Resample the shifted mask back onto the reference grid: the stats below
        # compare raw arrays pixel-by-pixel, so a transform-only change would be ignored.
        ss_mask = ss_mask_shift.reproject(ref_dem, resampling="nearest")

    # Calculate differences after coregistration
    print('Calculating differences after coregistration...')
    diff_after = aligned_dem - ref_dem

    # Calculate stable surfaces stats
    diff_after_ss, diff_after_ss_median, diff_after_ss_nmad = calculate_stable_surface_stats(diff_after, ss_mask)

    # Subtract the median difference over stable surfaces
    aligned_dem = aligned_dem - diff_after_ss_median
    diff_after_ss_adj = diff_after - diff_after_ss_median

    # Save coregistered DEM, stable surfaces, and dDEM to file
    aligned_dem.save(dem_coreg_fn)
    print('Coregistered DEM saved to file:', dem_coreg_fn)
    ss_mask.save(ss_mask_shift_fn)
    print('Coregistered stable surfaces mask saved to file:', ss_mask_shift_fn)
    diff_after_ss_adj.save(ddem_fn)
    print('dDEM saved to file:', ddem_fn)

    # Re-calculate stable surfaces stats
    diff_after_ss_adj_ss, diff_after_ss_adj_ss_median, diff_after_ss_adj_ss_nmad = calculate_stable_surface_stats(diff_after_ss_adj, ss_mask)

    # Plot results
    print('Plotting dDEM results...')
    fig1 = plot_coreg_dh_results(diff_before, diff_before_ss, diff_before_ss_median, diff_before_ss_nmad,
                                 diff_after, diff_after_ss, diff_after_ss_median, diff_after_ss_nmad,
                                 diff_after_ss_adj, diff_after_ss_adj_ss, diff_after_ss_adj_ss_median, diff_after_ss_adj_ss_nmad,
                                 vmin=vmin, vmax=vmax)
    fig1.savefig(fig1_fn)
    print('Figure saved to file:', fig1_fn)

    return

In [None]:
# Initial pass: Nuth & Kaab coregistration of the tree-masked source DEM to the
# reference DEM (fit over all surfaces), followed by differencing.
coregister_difference_dems(
    ref_dem_fn=refdem_fn,
    source_dem_fn=sourcedem_masked_fn,
    ss_mask_fn=ss_mask_fn,
    out_dir=coreg_init_dir,
    coreg_method='NuthKaab',
    coreg_stable_only=False,
)


## Vertical bias correction

In [None]:
def deramp_dem(tba_dem_fn=None, ss_mask_fn=None, ref_dem_fn=None, out_dir=None, poly_order=2,
               vmin=-5, vmax=5, max_ddem=5):
    """Remove a 2-D polynomial trend from a DEM relative to a reference DEM.

    Fits ``xdem.coreg.Deramp`` to the differences between the DEM and the
    reference over stable surfaces, applies the correction, and saves the
    corrected DEM plus a before/after figure. See the xDEM example:
    https://xdem.readthedocs.io/en/stable/advanced_examples/plot_deramp.html

    Parameters
    ----------
    tba_dem_fn : str
        DEM to be corrected. It is differenced directly with the reference DEM,
        so it is expected to share the reference grid.
    ss_mask_fn : str
        Stable surfaces mask raster (1 = stable); resampled onto the reference grid.
    ref_dem_fn : str
        Reference DEM.
    out_dir : str
        Output directory.
    poly_order : int
        Order of the fitted polynomial surface.
    vmin, vmax : float
        Colour and histogram limits of the figure [m].
    max_ddem : float
        DEM pixels where (DEM - reference) exceeds this value [m] are masked
        before fitting (probably trees); they are also masked in the saved
        corrected DEM.

    Returns
    -------
    None. Writes ``<name>_deramped.tif`` and ``<name>_deramp_correction.png``
    to ``out_dir``.
    """
    # Define output file names
    dem_corrected_fn = os.path.join(out_dir, os.path.basename(tba_dem_fn).replace('.tif', '_deramped.tif'))
    fig_fn = os.path.join(out_dir, os.path.basename(tba_dem_fn).replace('.tif', '_deramp_correction.png'))

    # Load input files
    tba_dem = xdem.DEM(tba_dem_fn)
    ref_dem = xdem.DEM(ref_dem_fn)
    # Put the mask on the reference grid (nearest keeps it 0/1) so it matches the
    # DEMs pixel-for-pixel, then convert to a boolean inlier mask (True = stable).
    ss_mask = gu.Raster(ss_mask_fn, load_data=True)
    ss_mask = ss_mask.reproject(ref_dem, resampling="nearest")
    ss_mask = (ss_mask == 1)

    # Calculate difference before
    diff_before = tba_dem - ref_dem

    # Mask values in DEM where dDEM > max_ddem (probably trees). Filling the masked
    # comparison with False prevents already-masked cells from being selected via
    # their fill values; assigning np.ma.masked marks the pixels as no-data.
    too_high = np.ma.filled(diff_before.data > max_ddem, False)
    tba_dem.data[too_high] = np.ma.masked

    # Fit and apply Deramp object
    deramp = xdem.coreg.Deramp(poly_order=poly_order)
    deramp.fit(ref_dem, tba_dem, inlier_mask=ss_mask)
    # NOTE: the fit metadata is only printed; it holds callables (fit_func,
    # fit_optimizer) that json cannot serialize.
    meta = deramp.meta
    print(meta)
    dem_corrected = deramp.apply(tba_dem)

    # Save corrected DEM
    dem_corrected.save(dem_corrected_fn)
    print('Deramped DEM saved to file:', dem_corrected_fn)

    # Calculate difference after
    diff_after = dem_corrected - ref_dem

    # Plot results: maps on top, histograms of valid (unmasked) differences below
    bins = np.linspace(vmin, vmax, num=100)
    fig, ax = plt.subplots(2, 2, figsize=(10, 10))
    ax = ax.flatten()
    diff_before.plot(cmap='coolwarm_r', vmin=vmin, vmax=vmax, ax=ax[0])
    ax[0].set_title('dDEM')
    diff_after.plot(cmap='coolwarm_r', vmin=vmin, vmax=vmax, ax=ax[1])
    ax[1].set_title('Deramped dDEM')
    for hist_ax, ddem in zip(ax[2:], [diff_before, diff_after]):
        hist_ax.hist(np.ma.compressed(ddem.data), color='grey', bins=bins)
        hist_ax.set_xlim(vmin, vmax)
        hist_ax.set_xlabel('Elevation differences (all surfaces) [m]')
    fig.tight_layout()

    # Save figure before showing it
    fig.savefig(fig_fn, dpi=300, bbox_inches='tight')
    print('Figure saved to file:', fig_fn)
    plt.show()

    return
    

In [None]:
## First round ##
# Deramp the initially coregistered DEM with a 2nd-order polynomial surface
# fitted over the coregistered stable surfaces.
deramp_dem(
    tba_dem_fn=os.path.join(coreg_init_dir, 'dem_coregistered.tif'),
    ss_mask_fn=os.path.join(coreg_init_dir, 'stable_surfaces_coregistered.tif'),
    ref_dem_fn=refdem_fn,
    out_dir=deramp_dir,
    poly_order=2,
)

In [None]:
# ## Second round ## - Didn't make much difference for MCS
# in_dem_fn = glob.glob(os.path.join(deramp_dir, '*_deramped.tif'))[0]
# in_ss_mask_fn = os.path.join(coreg_init_dir, 'stable_surfaces_coregistered.tif')
# deramp_dem(tba_dem_fn = in_dem_fn, 
#            ss_mask_fn = in_ss_mask_fn, 
#            ref_dem_fn = refdem_fn, 
#            out_dir = deramp_dir, 
#            poly_order = 3)

## Final coregistration and differencing

In [None]:
# Pick the most recently written deramped DEM. glob order is arbitrary, and an
# extra deramping round leaves several *_deramped.tif files in deramp_dir.
deramped_fns = glob.glob(os.path.join(deramp_dir, '*_deramped.tif'))
if not deramped_fns:
    raise FileNotFoundError(f'No deramped DEM found in {deramp_dir}; run the deramping step first.')
in_dem_fn = max(deramped_fns, key=os.path.getmtime)
in_ss_mask_fn = os.path.join(coreg_init_dir, 'stable_surfaces_coregistered.tif')

# Final pass: re-coregister the deramped DEM and difference it from the reference
coregister_difference_dems(ref_dem_fn=refdem_fn,
                           source_dem_fn=in_dem_fn,
                           ss_mask_fn=in_ss_mask_fn,
                           out_dir=coreg_final_dir,
                           coreg_method='NuthKaab',
                           coreg_stable_only=False,
                           vmin=-5, vmax=5)

## Remove large negative biases, likely caused by vegetation differences between the DEMs

In [None]:
# # Load dDEM and stable surfaces mask
# ddem_fn = os.path.join(coreg_final_dir, 'ddem.tif')
# ss_mask_fn = os.path.join(stable_surfaces_dir, 'roads_mask.tif')

# ddem = gu.Raster(ddem_fn, load_data=True)
# ss_mask = gu.Raster(ss_mask_fn, load_data=True)
# # ss_mask = (ss_mask == 1) # convert to boolean mask

# # Mask values in dDEM below a certain threshold
# threshold = -4
# ddem_thresh = ddem.copy()
# ddem_thresh.data = np.where(ddem.data <= threshold, np.nan, ddem.data)

# # Shift vertically using median stable surfaces difference
# ddem_ss, ddem_ss_median, ddem_ss_nmad = calculate_stable_surface_stats(ddem_thresh, ss_mask)
# ddem_shift = ddem_thresh - ddem_ss_median

# # Plot
# fig, ax = plt.subplots(2, 2, figsize=(10,10))
# ax = ax.flatten()
# ddem.plot(ax=ax[0], cmap='coolwarm_r', vmin=-5, vmax=5)
# ax[0].set_title('dDEM')
# ddem_shift.plot(ax=ax[1], cmap='coolwarm_r', vmin=-5, vmax=5)
# ax[1].set_title('dDEM, masked and shifted')
# bins = np.linspace(-10,10,100)
# ax[2].hist(np.ravel(ddem.data), bins=bins, color='grey', alpha=0.8, label='dDEM')
# ax[2].set_xlabel('Elevation difference [m]')
# ax[3].hist(np.ravel(ddem_shift.data), bins=bins, color='grey', alpha=0.8, label='dDEM, masked and shifted')
# ax[3].set_xlabel('Elevation difference [m]')

# plt.show()

## Remove Off-Terrain Objects (ROTO)

In [None]:
# Disabled (commented-out) experiment: brute-force search over WhiteboxTools
# RemoveOffTerrainObjects (ROTO) parameters. For each (slope, filter size) pair,
# the coregistered source DSM is filtered and hillshaded. That hillshade is then
# compared against the hillshade of the reference DEM over stable surfaces. The
# pair with the lowest RMSE is reported. Uncomment to run (requires the
# whitebox_tools package).

# from whitebox_tools import WhiteboxTools
# from scipy.optimize import minimize

# wbt = WhiteboxTools()
# wbt.set_whitebox_dir('/Users/raineyaberle/opt/anaconda3/envs/snow-dems/bin')

# out_path = os.path.join(out_dir, 'ROTO')
# os.makedirs(out_path, exist_ok=True)

# def compute_hillshade(h, azimuth=315, altitude=45):
#     """
#     Compute a hillshade of DEM `h` with xdem.terrain.hillshade.
#     """
#     return xdem.terrain.hillshade(h, azimuth=azimuth, altitude=altitude)

# def suppress_output(message):
#     """WhiteboxTools progress callback that discards every message."""
#     pass

# def hillshade_difference(params, dsm_fn, hillshade_dem, stable_mask, out_path):
#     """
#     Run ROTO on `dsm_fn` with params=(slope, filter_size) and return the RMSE of
#     the absolute hillshade difference against `hillshade_dem` inside `stable_mask`.
#     Filtered DSMs are written to `out_path` and reused if they already exist.
#     """
#     # Named filter_size (not `filter`) to avoid shadowing the builtin
#     slope, filter_size = params
#     out_fn = os.path.join(out_path, f'ROTO_slope{slope}_filter{filter_size}.tif')
#     if not os.path.exists(out_fn):
#         wbt.remove_off_terrain_objects(dsm_fn, output=out_fn, slope=slope,
#                                        filter=filter_size, callback=suppress_output)

#     # Load the filtered DSM and resample it onto the reference hillshade grid so
#     # the two hillshades can be differenced pixel by pixel
#     filtered_dsm = xdem.DEM(out_fn)
#     filtered_dsm = filtered_dsm.reproject(hillshade_dem)
#     hillshade_dsm = compute_hillshade(filtered_dsm)

#     # Hillshade difference restricted to stable surfaces
#     difference = np.abs(hillshade_dem - hillshade_dsm)
#     difference_masked = difference[stable_mask]

#     # Root-mean-square of the masked differences
#     return np.sqrt(np.mean(difference_masked ** 2))

# def brute_force_search(dsm_fn, dem_fn, stable_mask_fn, out_path):
#     """
#     Evaluate every (slope, filter_size) combination with hillshade_difference and
#     return the (slope, filter_size) pair that gives the lowest RMSE.
#     """
#     # Reference hillshade, computed once and reused for every combination
#     dem = xdem.DEM(dem_fn)
#     hillshade_dem = compute_hillshade(dem)

#     # Stable-surface mask (pixels equal to 1), resampled onto the reference DEM grid
#     stable_mask = gu.Raster(stable_mask_fn, load_data=True)
#     stable_mask = (stable_mask == 1)
#     stable_mask.set_nodata(False)
#     stable_mask = stable_mask.reproject(dem)

#     best_params = None
#     best_rmse = float('inf')

#     slope_range = range(2, 7, 2)         # slope thresholds: 2, 4, 6
#     filter_size_range = range(5, 11, 5)  # filter sizes: 5, 10

#     print('Slope \t\t Filter \t\t RMSE')
#     for slope in slope_range:
#         for filter_size in filter_size_range:
#             params = (slope, filter_size)
#             rmse = hillshade_difference(params, dsm_fn, hillshade_dem, stable_mask, out_path)
#             print(f'{slope} \t\t {filter_size} \t\t {rmse}')
#             if rmse < best_rmse:
#                 best_rmse = rmse
#                 best_params = params

#     # NaN/inf RMSEs never beat the initial inf, leaving best_params unset
#     if best_params is None:
#         raise ValueError('No ROTO parameter combination produced a finite RMSE.')

#     # Use the best combination, not the last one evaluated by the loop
#     best_slope, best_filter = best_params
#     print(f"\nBest params: Slope={best_slope}, Filter Size={best_filter} with RMSE={best_rmse}")

#     return best_slope, best_filter

# dsm_fn = os.path.join(coreg_init_dir, 'dem_coregistered.tif')
# dem_fn = refdem_fn
# stable_mask_fn = ss_mask_fn

# best_slope, best_filter = brute_force_search(dsm_fn, dem_fn, stable_mask_fn, out_path)
# print(f"Optimal parameters: Slope={best_slope}, Filter Size={best_filter}")

# best_dem_fn = os.path.join(out_path, f'ROTO_slope{best_slope}_filter{best_filter}.tif')

### ICP + Nuth and Kaab - coregister all surfaces

Intended for large rotations between the source and reference DEMs.

In [None]:
# Disabled alternative (uncomment to run): coreg_method=['ICP', 'NuthKaab'],
# coregistering on all surfaces rather than stable ones only (coreg_stable_only=False).
# diff_after = coregister_difference_dems(
#     ref_dem_fn=refdem_fn,
#     source_dem_fn=sourcedem_fn,
#     ss_mask_fn=ss_fn,
#     out_dir=out_dir,
#     coreg_method=['ICP', 'NuthKaab'],
#     coreg_stable_only=False,
#     ss_method=ss_method,
# )

### Nuth and Kaab + Tilt - coregister all surfaces

Intended for small rotations between the source and reference DEMs.

In [None]:
# Disabled alternative (uncomment to run): coreg_method=['NuthKaab', 'Tilt'],
# coregistering on all surfaces rather than stable ones only (coreg_stable_only=False).
# diff_after = coregister_difference_dems(
#     ref_dem_fn=refdem_fn,
#     source_dem_fn=sourcedem_fn,
#     ss_mask_fn=ss_fn,
#     out_dir=out_dir,
#     coreg_method=['NuthKaab', 'Tilt'],
#     coreg_stable_only=False,
#     ss_method=ss_method,
# )

### Bias Corr + ICP + Nuth and Kaab - coregister all surfaces

Intended for large shifts, large rotations, and high noise levels.

In [None]:
# Disabled alternative (uncomment to run): coreg_method=['BiasCorr', 'ICP', 'NuthKaab'],
# coregistering on all surfaces rather than stable ones only (coreg_stable_only=False).
# diff_after = coregister_difference_dems(
#     ref_dem_fn=refdem_fn,
#     source_dem_fn=sourcedem_fn,
#     ss_mask_fn=ss_fn,
#     out_dir=out_dir,
#     coreg_method=['BiasCorr', 'ICP', 'NuthKaab'],
#     coreg_stable_only=False,
#     ss_method=ss_method,
# )

### Nuth and Kaab - coregister all surfaces

In [None]:
# Disabled alternative (uncomment to run): coreg_method='NuthKaab',
# coregistering on all surfaces rather than stable ones only (coreg_stable_only=False).
# diff_after = coregister_difference_dems(
#     ref_dem_fn=refdem_fn,
#     source_dem_fn=sourcedem_fn,
#     ss_mask_fn=ss_fn,
#     out_dir=out_dir,
#     coreg_method='NuthKaab',
#     coreg_stable_only=False,
#     ss_method=ss_method,
# )

### Nuth and Kaab - coregister stable surfaces only

In [None]:
# Disabled alternative (uncomment to run): coreg_method='NuthKaab',
# coregistering on stable surfaces only (coreg_stable_only=True).
# diff_after = coregister_difference_dems(
#     ref_dem_fn=refdem_fn,
#     source_dem_fn=sourcedem_fn,
#     ss_mask_fn=ss_fn,
#     out_dir=out_dir,
#     coreg_method='NuthKaab',
#     coreg_stable_only=True,
#     ss_method=ss_method,
# )

### Gradient Descending - coregister all surfaces

In [None]:
# Disabled alternative (uncomment to run): coreg_method='GradientDescending',
# coregistering on all surfaces rather than stable ones only (coreg_stable_only=False).
# diff_after = coregister_difference_dems(
#     ref_dem_fn=refdem_fn,
#     source_dem_fn=sourcedem_fn,
#     ss_mask_fn=ss_fn,
#     out_dir=out_dir,
#     coreg_method='GradientDescending',
#     coreg_stable_only=False,
#     ss_method=ss_method,
# )

### Gradient Descending - coregister stable surfaces only

In [None]:
# Disabled alternative (uncomment to run): coreg_method='GradientDescending',
# coregistering on stable surfaces only (coreg_stable_only=True).
# The comma after out_dir=out_dir is required; without it the call is a SyntaxError.
# diff_after = coregister_difference_dems(
#     ref_dem_fn=refdem_fn,
#     source_dem_fn=sourcedem_fn,
#     ss_mask_fn=ss_fn,
#     out_dir=out_dir,
#     coreg_method='GradientDescending',
#     coreg_stable_only=True,
#     ss_method=ss_method,
# )