# Pipeline for coregistering and differencing SkySat DEM from reference DEM

In [None]:
import os, glob
import matplotlib.pyplot as plt
import matplotlib
import xarray as xr
import rioxarray as rxr
import numpy as np
from tqdm.auto import tqdm
import xdem
import geoutils as gu
from skimage.filters import threshold_otsu
import math
import pandas as pd
import json
from affine import Affine

## Define paths to DEMs

In [None]:
# Define input file paths (reference DEM, source DEM, and source orthomosaic)
data_dir = '/Volumes/LaCie/raineyaberle/Research/PhD/SkySat-Stereo/study-sites/MCS'
refdem_fn = os.path.join(data_dir, 'refdem', 'MCS_REFDEM_WGS84.tif')
sourcedem_fn = os.path.join(data_dir, '20240420', 'MCS_20240420-1_DEM.tif')
ortho_fn = os.path.join(data_dir, '20240420', 'MCS_20240420-1_orthomosaic.tif')

# Define path for output snow depth images
out_dir = '/Volumes/LaCie/raineyaberle/Research/PhD/SkySat-Stereo/snow_depth_maps'
job_name = 'MCS_20240420-1'

# Check that all input files exist (warn only; later cells will fail on a missing file).
# The orthomosaic is checked too because the Otsu stable-surfaces step reads it.
for input_label, input_fn in [('Reference DEM', refdem_fn),
                              ('Source DEM', sourcedem_fn),
                              ('Orthomosaic', ortho_fn)]:
    if not os.path.exists(input_fn):
        print(f'{input_label} not found, check path to file before continuing.')

# Create the top-level output directory. os.mkdir (not os.makedirs) is deliberate:
# if the external drive is not mounted, this fails instead of silently creating
# the /Volumes/... path on the boot disk.
if not os.path.exists(out_dir):
    os.mkdir(out_dir)
    print('Created directory for output files:', out_dir)

In [None]:
# Define all output directories and files
# New output directory for the full job. The basename check keeps this cell
# idempotent: without it, re-running the cell would nest job_name a second
# time (out_dir/job_name/job_name).
if os.path.basename(os.path.normpath(out_dir)) != job_name:
    out_dir = os.path.join(out_dir, job_name)
# step 1: stable surfaces (snow-free) masks (roads polygon, Otsu thresholding)
stable_surfaces_dir = os.path.join(out_dir, 'stable_surfaces')
# step 2: Initial coregistration and differencing
coreg_init_dir = os.path.join(out_dir, 'coreg_initial_diff')
# step 3: Polynomial (deramping) vertical bias correction
deramp_dir = os.path.join(out_dir, 'deramping')
# step 4: Final coregistration and differencing
coreg_final_dir = os.path.join(out_dir, 'coreg_final_diff')

# Create directories. The job directory is listed first because os.mkdir
# does not create missing parents.
for directory in [out_dir, stable_surfaces_dir, coreg_init_dir, deramp_dir, coreg_final_dir]:
    if not os.path.exists(directory):
        os.mkdir(directory)

## 1. Create stable surfaces (snow-free) mask

In [None]:
# -----Rasterize the roads polygon onto the reference DEM grid
# (used below as the stable-surfaces mask for the coregistration statistics)
roads_mask_fn = os.path.join(stable_surfaces_dir, 'roads_mask.tif')

if not os.path.exists(roads_mask_fn):
    # The reference DEM defines the output grid and its no-data footprint,
    # so it must be loaded before rasterizing.
    refdem = xdem.DEM(refdem_fn)
    roads_fn = os.path.join(data_dir, 'MCS_roads_polygon.shp')
    roads_vector = gu.Vector(roads_fn)
    # Burn a constant 1 inside every road polygon, because downstream code
    # selects stable pixels with `== 1`.
    # NOTE(review): geoutils' default burn value may be per-feature
    # (index + 1); confirm for the installed version.
    roads_raster = roads_vector.rasterize(refdem, in_value=1)
    # Carry over the reference DEM's no-data mask
    roads_raster.data.mask = refdem.data.mask
    roads_raster = roads_raster.astype(np.int16)
    roads_raster.set_nodata(-9999)
    roads_raster.save(roads_mask_fn)
    print('Roads raster saved to file:', roads_mask_fn)

    roads_raster.plot()
    plt.show()

else:
    print('Roads mask already exists in file.')
    

In [None]:
# -----Apply Otsu thresholding to the orthomosaic
# Pixels darker than the Otsu threshold are labelled stable (1) and brighter
# pixels unstable (0); no-data pixels are written as -9999.
# outputs
ss_mask_fn = os.path.join(stable_surfaces_dir, 'stable_surfaces.tif')
ss_mask_fig_fn = os.path.join(stable_surfaces_dir, 'stable_surfaces.png')

# Check if stable surfaces already exist in file
if os.path.exists(ss_mask_fn):
    print('Stable surfaces mask already exists in file.')

else:

    # Open orthomosaic
    ortho = rxr.open_rasterio(ortho_fn)
    # Save CRS for later
    crs = ortho.rio.crs
    # Use the first band as float so no-data pixels can be set to NaN
    # (assigning NaN into an integer array raises a ValueError)
    image = ortho.data[0].astype(float)
    # Pixels equal to 0 are treated as no data
    nodata_mask = image == 0
    if nodata_mask.all():
        raise ValueError(f'Orthomosaic has no valid (non-zero) pixels: {ortho_fn}')

    # Calculate the Otsu threshold from valid pixels only, so the no-data
    # zeros do not bias it, then apply it
    valid_pixels = image[~nodata_mask]
    otsu_thresh = threshold_otsu(valid_pixels)
    ss_mask = (image < otsu_thresh).astype(float)

    # Apply nodata mask
    ss_mask[nodata_mask] = np.nan
    image[nodata_mask] = np.nan

    # Map extent in km, shared by both image panels
    extent = (np.min(ortho.x.data)/1e3, np.max(ortho.x.data)/1e3,
              np.min(ortho.y.data)/1e3, np.max(ortho.y.data)/1e3)

    # Plot results
    fig, ax = plt.subplots(1, 3, figsize=(12, 5))
    im = ax[0].imshow(image, cmap=plt.cm.gray, extent=extent)
    fig.colorbar(im, ax=ax[0], orientation='horizontal', shrink=0.9)
    ax[0].set_title('Original')
    # Histogram of valid pixels only (NaNs break the automatic bin range)
    ax[1].hist(valid_pixels, bins=100, color='grey')
    ax[1].set_title('Histogram')
    ax[1].axvline(otsu_thresh, color='r')
    ylim = ax[1].get_ylim()
    ax[1].text(otsu_thresh + (valid_pixels.max() - valid_pixels.min())*0.1,
               ylim[0] + (ylim[1] - ylim[0])*0.9,
               f'Otsu threshold = \n{np.round(otsu_thresh, 2)}', color='r')
    cmap_binary = matplotlib.colors.ListedColormap(['w', 'k'])
    im = ax[2].imshow(ss_mask, cmap=cmap_binary, extent=extent)
    cbar = fig.colorbar(im, ax=ax[2], orientation='horizontal', shrink=0.9, ticks=[0.25, 0.75])
    cbar.ax.set_xticklabels(['unstable', 'stable'])
    ax[2].set_title('Stable surfaces mask')

    # Save the figure before showing it, so the saved file does not depend on the backend
    fig.savefig(ss_mask_fig_fn, dpi=300, bbox_inches='tight')
    print('Figure saved to file:', ss_mask_fig_fn)
    plt.show()

    # Save results to file (-9999 = no data; int16 like the roads mask)
    ss_mask[np.isnan(ss_mask)] = -9999
    ss_mask = ss_mask.astype(np.int16)
    ss_mask_xr = xr.DataArray(data = ss_mask,
                              coords = dict(y = ortho.y, x = ortho.x),
                              attrs = dict(Description = 'Stable surfaces mask generated using Otsu thresholding of the input image. 1 = stable, 0 = unstable',
                                           InputImage = os.path.basename(ortho_fn),
                                           OtsuThreshold = otsu_thresh,
                                           _FillValue = -9999)
                              )
    ss_mask_xr = ss_mask_xr.rio.write_crs(crs)
    ss_mask_xr.rio.to_raster(ss_mask_fn)
    print('Stable surfaces mask saved to file:', ss_mask_fn)


## 2. Initial coregistration and differencing

In [None]:
# -----Define functions for coregistration
def create_coreg_object(coreg_name):
    """Build an xdem coregistration object from a method name or a list of names.

    Parameters
    ----------
    coreg_name : str or list/tuple of str
        Name(s) of classes in ``xdem.coreg`` (e.g. ``'NuthKaab'`` or
        ``['ICP', 'NuthKaab']``). A sequence is combined, in order, with ``+``
        into a single object. ``'BiasCorr'`` is constructed with
        ``bias_vars=["elevation", "slope", "aspect"]``; every other class is
        constructed with no arguments.

    Returns
    -------
    The instantiated coregistration object. If ``coreg_name`` is neither a
    string nor a list/tuple, a message is printed and None is returned.

    Raises
    ------
    ValueError
        If a name is not an attribute of ``xdem.coreg``, or the sequence is empty.
    """
    def _instantiate(name):
        # Look up one xdem.coreg class by name and construct it
        try:
            coreg_cls = getattr(xdem.coreg, name)
        except AttributeError:
            raise ValueError(f"Coregistration method '{name}' not found.") from None
        if name == 'BiasCorr':
            return coreg_cls(bias_vars=["elevation", "slope", "aspect"])
        return coreg_cls()

    if isinstance(coreg_name, str):
        return _instantiate(coreg_name)

    if isinstance(coreg_name, (list, tuple)):
        if not coreg_name:
            raise ValueError('coreg_method list is empty; provide at least one method name.')
        coreg_class = _instantiate(coreg_name[0])
        for name in coreg_name[1:]:
            coreg_class += _instantiate(name)
        return coreg_class

    print('coreg_method format not recognized, exiting...')
    return None


def differences_vs_slope_aspect(dem, diffs):
    """Box-plot elevation differences binned by elevation, slope and aspect.

    Slope and aspect are derived from ``dem`` with ``xdem.terrain``. Each
    pixel's difference from ``diffs`` is grouped into 50 m elevation bins,
    2.5 degree slope bins and 22.5 degree aspect bins, and drawn as box plots
    with outliers hidden. Pixels with any missing value are dropped.

    Parameters
    ----------
    dem : xdem.DEM
        DEM used for the elevation/slope/aspect binning.
    diffs : raster with a ``.data`` array on the same grid as ``dem``
        Elevation differences [m].

    Returns
    -------
    matplotlib.figure.Figure
        Three stacked panels: elevation, slope, aspect.
    """
    # Calculate slope and aspect from DEM
    slope = xdem.terrain.slope(dem)
    aspect = xdem.terrain.aspect(dem)

    # Compile differences, elevations, slopes, and aspects in a dataframe
    df = pd.DataFrame({'diff': np.ravel(diffs.data),
                       'elev': np.ravel(dem.data),
                       'slope': np.ravel(slope.data),
                       'aspect': np.ravel(aspect.data)})
    df.dropna(inplace=True)
    df.reset_index(drop=True, inplace=True)

    # Bin edges
    bin_min = np.floor(dem.data.min() / 5) * 5
    elev_bins = np.arange(bin_min, dem.data.max() + 50, step=50)
    # Slope edges reach at least 40 degrees, and further when steeper pixels
    # exist, so steep terrain is not silently left out of the plot
    slope_top = 40.0
    if len(df) > 0:
        slope_top = max(slope_top, float(np.ceil(df['slope'].max() / 2.5) * 2.5))
    slope_bins = np.arange(0, slope_top + 1, step=2.5)
    aspect_bins = np.arange(0, 361, step=22.5)
    # include_lowest=True keeps values equal to the first edge (flat 0-degree
    # slopes, 0-degree aspects, a minimum elevation on a multiple of 5 m),
    # which pd.cut's right-closed intervals would otherwise drop
    df['elev_bin'] = pd.cut(df['elev'], bins=elev_bins, include_lowest=True)
    df['slope_bin'] = pd.cut(df['slope'], bins=slope_bins, include_lowest=True)
    df['aspect_bin'] = pd.cut(df['aspect'], bins=aspect_bins, include_lowest=True)

    # Plot one box-plot panel per binning variable
    panels = [('elev_bin', 'Elevation range [m]'),
              ('slope_bin', 'Slope range [degrees]'),
              ('aspect_bin', 'Aspect range [degrees from North]')]
    fig2, ax = plt.subplots(3, 1, figsize=(8,14))
    for axis, (bin_col, xlabel) in zip(ax, panels):
        df.boxplot(column='diff', by=bin_col, showfliers=False, patch_artist=True, ax=axis,
                   boxprops=dict(color='k'), medianprops=dict(color='w', linewidth=1.5),
                   whiskerprops=dict(color='k'))
        axis.set_title('')
        axis.set_xlabel(xlabel)
        axis.set_ylabel('Differences [m]')
    # Remove the automatic "Boxplot grouped by ..." super-title added by pandas
    fig2.suptitle('')
    fig2.tight_layout()

    return fig2

def calculate_stable_surface_stats(diff_dem, ss_mask):
    """Restrict a difference raster to stable surfaces and summarize it.

    Parameters
    ----------
    diff_dem : raster with ``.data`` (possibly masked), ``.transform``, ``.crs``
        Elevation differences [m].
    ss_mask : raster on the same grid as ``diff_dem``
        Pixels whose ``.data == 1`` (or True) are stable surfaces.

    Returns
    -------
    diff_dem_ss : gu.Raster
        ``diff_dem`` values on stable, valid pixels; -9999 (nodata) elsewhere.
    diff_dem_ss_median : float
        Median of the stable-surface differences (NaN if there are none).
    diff_dem_ss_nmad : float
        NMAD of the stable-surface differences (NaN if there are none).
    """
    diff_values = np.ma.getdata(diff_dem.data)
    # A pixel counts only if it is stable AND valid in diff_dem. Using the
    # explicit masks avoids pulling in values stored under masked pixels,
    # which np.where on masked arrays would otherwise do.
    stable = np.ma.filled(ss_mask.data == 1, False) & ~np.ma.getmaskarray(diff_dem.data)
    ss_masked_data = np.where(stable, diff_values, -9999)
    diff_dem_ss = gu.Raster.from_array(
        ss_masked_data,
        transform=diff_dem.transform,
        crs=diff_dem.crs,
        nodata=-9999)

    stable_values = diff_values[stable]
    stable_values = stable_values[np.isfinite(stable_values)]
    if stable_values.size == 0:
        return diff_dem_ss, np.nan, np.nan
    diff_dem_ss_median = np.median(stable_values)
    # NMAD over stable surfaces only (previously computed over all surfaces)
    diff_dem_ss_nmad = xdem.spatialstats.nmad(stable_values)
    return diff_dem_ss, diff_dem_ss_median, diff_dem_ss_nmad

def plot_coreg_dh_results(dh_before, dh_before_ss, dh_before_ss_med, dh_before_ss_nmad,
                          dh_after, dh_after_ss, dh_after_ss_med, dh_after_ss_nmad,
                          dh_after_ss_adj, dh_after_ss_adj_ss, dh_after_ss_adj_ss_med, dh_after_ss_adj_ss_nmad,
                          vmin=-10, vmax=10):
    """Plot difference maps and histograms for three processing stages.

    Rows: before coregistration, after coregistration, and after subtracting
    the stable-surface (SS) median difference. The left column shows the
    difference map (colour limits ``vmin``/``vmax``, axes labelled in km).
    The right column shows histograms over [vmin, vmax] of all differences
    (grey, left y-axis) and of stable-surface differences (magenta, right
    y-axis).

    Parameters
    ----------
    dh_before, dh_after, dh_after_ss_adj : rasters with ``.plot`` and ``.data``
        Full difference rasters for each stage [m].
    dh_before_ss, dh_after_ss, dh_after_ss_adj_ss : rasters with ``.data``
        The same differences restricted to stable surfaces.
    *_med, *_nmad : float
        Stable-surface median and NMAD shown in each map title.
    vmin, vmax : float
        Colour limits and histogram range [m].

    Returns
    -------
    matplotlib.figure.Figure
    """
    fig, ax = plt.subplots(3, 2, figsize=(10,16))
    rows = [
        ('Difference before coreg.', dh_before, dh_before_ss, dh_before_ss_med, dh_before_ss_nmad),
        ('Difference after coreg.', dh_after, dh_after_ss, dh_after_ss_med, dh_after_ss_nmad),
        ('Difference after coreg. - median SS diff.', dh_after_ss_adj, dh_after_ss_adj_ss,
         dh_after_ss_adj_ss_med, dh_after_ss_adj_ss_nmad),
    ]
    # Label map axes in km without replacing the tick locator
    km_formatter = matplotlib.ticker.FuncFormatter(lambda value, _pos: f'{value / 1e3:g}')
    for (title, dh, dh_ss, ss_med, ss_nmad), (map_ax, hist_ax) in zip(rows, ax):
        # plot dh
        dh.plot(cmap="coolwarm_r", ax=map_ax, vmin=vmin, vmax=vmax)
        map_ax.set_title(f'{title} \nSS median = {np.round(ss_med, 3)}, SS NMAD = {np.round(ss_nmad, 3)}')
        map_ax.xaxis.set_major_formatter(km_formatter)
        map_ax.yaxis.set_major_formatter(km_formatter)
        map_ax.set_xlabel('Easting [km]')
        map_ax.set_ylabel('Northing [km]')
        # plot histograms. Masked/nodata pixels are excluded (np.ma.compressed),
        # and bins span only the displayed range so outliers do not flatten them.
        hist_ax.hist(np.ma.compressed(dh.data), bins=50, range=(vmin, vmax),
                     color='grey', alpha=0.8, label='All surfaces')
        hist_ax.legend(loc='upper left')
        hist_ax.set_xlabel('Differences [m]')
        hist_ax.set_ylabel('Counts')
        hist_ax.set_xlim(vmin, vmax)
        ss_ax = hist_ax.twinx()
        ss_ax.hist(np.ma.compressed(dh_ss.data), bins=50, range=(vmin, vmax),
                   color='m', alpha=0.8, label='Stable surfaces')
        ss_ax.legend(loc='upper right')
        ss_ax.spines['right'].set_color('m')
        ss_ax.tick_params(axis='y', colors='m')
    fig.tight_layout()
    plt.show()
    return fig
                                           
def _json_default(obj):
    """Make coregistration metadata JSON-serializable.

    Used as ``json.dump(..., default=_json_default)``. NumPy scalars and arrays
    become Python numbers/lists. Any other unsupported object (e.g. a function
    stored in the metadata) is written as its ``repr`` so the file can still
    be saved.
    """
    if isinstance(obj, np.generic):
        return obj.item()
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    return repr(obj)


def _shift_mask_to_grid(mask, dx, dy, target):
    """Translate ``mask`` by (dx, dy) map units and resample it onto ``target``'s grid."""
    shifted = mask.copy()
    # Pre-multiplying applies the translation in map (CRS) units. The reverse
    # order, ``transform * Affine.translation(dx, dy)``, would translate by
    # dx/dy *pixels*: scaled by the pixel size, with y flipped for north-up rasters.
    shifted.transform = Affine.translation(dx, dy) * mask.transform
    # Resample onto the target grid so that pixel-by-pixel comparisons with
    # rasters on that grid actually reflect the shift.
    return shifted.reproject(target)


def coregister_difference_dems(ref_dem_fn=None, source_dem_fn=None, ss_mask_fn=None, out_dir=None, 
                               coreg_method='NuthKaab', coreg_stable_only=False, plot_terrain_results=True,
                               vmin=-10, vmax=10):
    """Coregister a source DEM to a reference DEM and difference them.

    Steps:
    1. Reproject the source DEM and the stable-surface mask onto the reference grid.
    2. Difference the DEMs before coregistration.
    3. Fit and apply ``coreg_method``, then shift the mask by any estimated
       horizontal offset.
    4. Difference the DEMs again and subtract the stable-surface median difference.
    5. Save the outputs and figures to ``out_dir``.

    Parameters
    ----------
    ref_dem_fn, source_dem_fn : str
        Paths to the reference DEM and the DEM to be aligned (band 1 is read).
    ss_mask_fn : str
        Path to a stable-surface raster where 1 = stable and -9999 = no data.
    out_dir : str
        Directory for outputs. It is not created here, so it must already exist.
    coreg_method : str or list of str
        Passed to ``create_coreg_object``.
    coreg_stable_only : bool
        If True, fit the coregistration on stable surfaces only.
    plot_terrain_results : bool
        If True, also save box plots of the final dDEM vs elevation/slope/aspect.
    vmin, vmax : float
        Colour and histogram limits [m] for the results figure.

    Files written to ``out_dir``: coregistration_fit.json, dem_coregistered.tif,
    stable_surfaces_coregistered.tif, ddem.tif, ddem_results.png and, optionally,
    ddem_terrain_boxplots.png.
    """
    # Define output file names
    coreg_meta_fn = os.path.join(out_dir, 'coregistration_fit.json')
    dem_coreg_fn = os.path.join(out_dir, 'dem_coregistered.tif')
    ss_mask_shift_fn = os.path.join(out_dir, 'stable_surfaces_coregistered.tif')
    ddem_fn = os.path.join(out_dir, 'ddem.tif')
    fig1_fn = os.path.join(out_dir, 'ddem_results.png')
    fig2_fn = os.path.join(out_dir, 'ddem_terrain_boxplots.png')

    # Load input files
    ref_dem = xdem.DEM(gu.Raster(ref_dem_fn, load_data=True, bands=1))
    tba_dem = xdem.DEM(gu.Raster(source_dem_fn, load_data=True, bands=1))
    ss_mask = gu.Raster(ss_mask_fn, load_data=True, nodata=-9999)
    ss_mask = (ss_mask == 1)  # Convert to boolean mask (True = stable)

    # Reproject source DEM and stable surfaces mask to reference DEM grid
    tba_dem = tba_dem.reproject(ref_dem)
    ss_mask = ss_mask.reproject(ref_dem)

    # Calculate differences before coregistration
    print('Calculating differences before coregistration...')
    diff_before = tba_dem - ref_dem
    # Calculate stable surface stats
    diff_before_ss, diff_before_ss_median, diff_before_ss_nmad = calculate_stable_surface_stats(diff_before, ss_mask)

    # Create and fit the coregistration object
    print('Coregistering source DEM to reference DEM...')
    coreg_obj = create_coreg_object(coreg_method)
    if coreg_obj is None:
        raise ValueError(f'coreg_method not recognized: {coreg_method!r}')
    if coreg_stable_only:
        coreg_obj.fit(ref_dem, tba_dem, ss_mask)
    else:
        coreg_obj.fit(ref_dem, tba_dem)
    # Save the fit coregistration metadata (NumPy values converted by _json_default)
    meta = coreg_obj.meta
    with open(coreg_meta_fn, 'w') as f:
        json.dump(meta, f, default=_json_default)
    print('Coregistration fit saved to file:', coreg_meta_fn)

    # Apply the coregistration object to the source DEM
    aligned_dem = coreg_obj.apply(tba_dem)

    # Apply the horizontal shift (if the method estimated one) to the stable
    # surfaces mask, keeping the mask on the reference grid
    if 'shift_x' in meta:
        # NOTE(review): assumes meta['shift_x'] / meta['shift_y'] are map-unit
        # offsets with the same sign convention xdem uses when applying them;
        # confirm for the installed xdem version.
        ss_mask = _shift_mask_to_grid(ss_mask, meta['shift_x'], meta['shift_y'], ref_dem)

    # Save coregistered DEM and stable surfaces to file
    aligned_dem.save(dem_coreg_fn)
    print('Coregistered DEM saved to file:', dem_coreg_fn)
    ss_mask.save(ss_mask_shift_fn)
    print('Coregistered stable surfaces mask saved to file:', ss_mask_shift_fn)

    # Calculate differences after coregistration
    print('Calculating differences after coregistration...')
    diff_after = aligned_dem - ref_dem

    # Calculate stable surfaces stats
    diff_after_ss, diff_after_ss_median, diff_after_ss_nmad = calculate_stable_surface_stats(diff_after, ss_mask)

    # Subtract the median difference over stable surfaces. With no valid
    # stable pixels the median is NaN, which would turn the whole dDEM into
    # NaN, so the adjustment is skipped in that case.
    if np.isfinite(diff_after_ss_median):
        diff_after_ss_adj = diff_after - diff_after_ss_median
    else:
        print('WARNING: no valid stable-surface differences after coregistration; '
              'skipping the median adjustment.')
        diff_after_ss_adj = diff_after

    # Save dDEM to file
    diff_after_ss_adj.save(ddem_fn)
    print('dDEM saved to file:', ddem_fn)

    # Re-calculate stable surfaces stats
    diff_after_ss_adj_ss, diff_after_ss_adj_ss_median, diff_after_ss_adj_ss_nmad = calculate_stable_surface_stats(diff_after_ss_adj, ss_mask)

    # Plot results
    print('Plotting dDEM results...')
    fig1 = plot_coreg_dh_results(diff_before, diff_before_ss, diff_before_ss_median, diff_before_ss_nmad,
                                 diff_after, diff_after_ss, diff_after_ss_median, diff_after_ss_nmad,
                                 diff_after_ss_adj, diff_after_ss_adj_ss, diff_after_ss_adj_ss_median, diff_after_ss_adj_ss_nmad,
                                 vmin=vmin, vmax=vmax)
    fig1.savefig(fig1_fn, dpi=300, bbox_inches='tight')
    print('Figure saved to file:', fig1_fn)

    # Calculate differences as a function of elevation, slope and aspect
    if plot_terrain_results:
        print('Plotting differences as a function of elevation, slope, and aspect...')
        fig2 = differences_vs_slope_aspect(ref_dem, diff_after_ss_adj)
        fig2.savefig(fig2_fn, dpi=300, bbox_inches='tight')
        print('Figure saved to file:', fig2_fn)

    return

In [None]:
# Initial Nuth & Kaab coregistration of the raw SkySat DEM to the reference DEM.
# With coreg_stable_only=False the roads mask is not used to fit the shift.
# It is used for the stable-surface statistics and saved alongside the
# coregistered DEM for the deramping step.
coregister_difference_dems(
    ref_dem_fn=refdem_fn,
    source_dem_fn=sourcedem_fn,
    ss_mask_fn=roads_mask_fn,
    out_dir=coreg_init_dir,
    coreg_method='NuthKaab',
    coreg_stable_only=False,
    plot_terrain_results=True,
)


## 3. Vertical bias correction (polynomial deramping over stable surfaces)

In [None]:
def deramp_dem(tba_dem_fn=None, ss_mask_fn=None, ref_dem_fn=None, out_dir=None, poly_order=2):
    """Fit and remove a 2-D polynomial vertical trend ("ramp") from a DEM.

    An ``xdem.coreg.Deramp`` of order ``poly_order`` is fitted between
    ``ref_dem_fn`` and ``tba_dem_fn`` using stable-surface pixels as inliers.
    Pixels where the DEM is more than 5 m above the reference are set to NaN
    before fitting; these pixels are NaN in the saved corrected DEM as well.
    See the xDEM example:
    https://xdem.readthedocs.io/en/stable/advanced_examples/plot_deramp.html

    Parameters
    ----------
    tba_dem_fn : str
        DEM to correct, on the same grid as ``ref_dem_fn``.
    ss_mask_fn : str
        Stable-surface raster on the same grid; pixels equal to 1 are stable.
    ref_dem_fn : str
        Reference DEM.
    out_dir : str
        Existing output directory.
    poly_order : int
        Order of the fitted polynomial surface.

    Files written to ``out_dir``, named after ``tba_dem_fn``: ``*_deramped.tif``
    and ``*_deramp_correction.png``.
    """
    # Define output file names (derived from the input DEM's file name)
    dem_corrected_fn = os.path.join(out_dir, os.path.basename(tba_dem_fn).replace('.tif', '_deramped.tif'))
    # deramp_meta_fn is only used by the commented-out metadata export below
    deramp_meta_fn = os.path.join(out_dir, os.path.basename(tba_dem_fn).replace('.tif', '_deramp_fit.json'))
    fig_fn = os.path.join(out_dir, os.path.basename(tba_dem_fn).replace('.tif', '_deramp_correction.png'))

    # Load input files
    tba_dem = xdem.DEM(tba_dem_fn)
    ss_mask = gu.Raster(ss_mask_fn, load_data=True)
    # Convert to a boolean inlier mask. Stable surfaces are stored as 1 (the
    # coregistration step saves the boolean `mask == 1`, and the Otsu mask uses
    # 1 = stable), so select 1. Selecting 0 would fit the ramp to the unstable
    # (snow-covered) pixels instead.
    ss_mask = (ss_mask == 1)
    ref_dem = xdem.DEM(ref_dem_fn)

    # Calculate difference before
    diff_before = tba_dem - ref_dem

    # Mask values in DEM where dDEM > 5 m (probably trees). Masked dDEM pixels
    # are treated as "not > 5" instead of using whatever value is stored under
    # the mask.
    # NOTE(review): these pixels also remain NaN in the saved deramped DEM;
    # confirm that is intended rather than excluding them from the fit only.
    too_high = np.ma.filled(diff_before.data > 5, False)
    tba_dem.data[too_high] = np.nan

    # Fit and apply Deramp object
    deramp = xdem.coreg.Deramp(poly_order=poly_order)
    deramp.fit(ref_dem, tba_dem, inlier_mask=ss_mask)
    meta = deramp.meta
    print(meta)
    dem_corrected = deramp.apply(tba_dem)

    # Save corrected DEM
    dem_corrected.save(dem_corrected_fn)
    print('Deramped DEM saved to file:', dem_corrected_fn)

    # Save Deramp fit metadata
    # keys_sub = [x for x in list(meta.keys()) if (x!= 'fit_func') & (x!='fit_optimizer')] # can't serialize functions in JSON
    # meta_sub = {key: meta[key] for key in keys_sub}
    # with open(deramp_meta_fn, 'w') as f:
    #     json.dump(meta_sub, f)
    # print('Deramp fit saved to file:', deramp_meta_fn)

    # Calculate difference after
    diff_after = dem_corrected - ref_dem

    # Plot results
    fig, ax = plt.subplots(2, 2, figsize=(10,10))
    ax = ax.flatten()
    diff_before.plot(cmap='coolwarm_r', vmin=-5, vmax=5, ax=ax[0])
    ax[0].set_title('dDEM')
    diff_after.plot(cmap='coolwarm_r', vmin=-5, vmax=5, ax=ax[1])
    ax[1].set_title('Deramped dDEM')
    # Histograms of unmasked differences, binned over the displayed range only
    for hist_ax, diff in zip(ax[2:], [diff_before, diff_after]):
        hist_ax.hist(np.ma.compressed(diff.data), bins=100, range=(-10, 10))
        hist_ax.set_xlim(-10, 10)
        hist_ax.set_xlabel('Differences [m]')
    ax[2].set_ylabel('Counts')

    # Save figure (before showing it, so the saved file does not depend on the backend)
    fig.savefig(fig_fn, dpi=300, bbox_inches='tight')
    print('Figure saved to file:', fig_fn)
    plt.show()

    return
    

In [None]:
## First round ##
# Remove a 3rd-order polynomial ramp from the initially coregistered DEM.
# Output: <deramp_dir>/dem_coregistered_deramped.tif
deramp_dem(
    tba_dem_fn=os.path.join(coreg_init_dir, 'dem_coregistered.tif'),
    ss_mask_fn=os.path.join(coreg_init_dir, 'stable_surfaces_coregistered.tif'),
    ref_dem_fn=refdem_fn,
    out_dir=deramp_dir,
    poly_order=3,
)

In [None]:
## Second round ##
# Deramp the first-round output again with the same stable-surfaces mask.
# Output: <deramp_dir>/dem_coregistered_deramped_deramped.tif
deramp_dem(
    tba_dem_fn=os.path.join(deramp_dir, 'dem_coregistered_deramped.tif'),
    ss_mask_fn=os.path.join(coreg_init_dir, 'stable_surfaces_coregistered.tif'),
    ref_dem_fn=refdem_fn,
    out_dir=deramp_dir,
    poly_order=3,
)

## 4. Final coregistration and differencing

In [None]:
# Final Nuth & Kaab coregistration of the twice-deramped DEM, fitted on all
# surfaces; the roads mask again supplies the stable-surface statistics.
coregister_difference_dems(
    ref_dem_fn=refdem_fn,
    source_dem_fn=os.path.join(deramp_dir, 'dem_coregistered_deramped_deramped.tif'),
    ss_mask_fn=roads_mask_fn,
    out_dir=coreg_final_dir,
    coreg_method='NuthKaab',
    coreg_stable_only=False,
    plot_terrain_results=True,
)

### ICP + Nuth and Kaab - coregister all surfaces

Suited to DEMs with large rotations relative to the reference DEM.

In [None]:
# diff_after = coregister_difference_dems(ref_dem_fn=refdem_fn, 
#                                         source_dem_fn=sourcedem_fn, 
#                                         ss_mask_fn=ss_fn, 
#                                         out_dir=out_dir,
#                                         coreg_method=['ICP', 'NuthKaab'], 
#                                         coreg_stable_only=False,
#                                         ss_method=ss_method)

In [None]:
# # -----Adjust vertically using SNOTEL snow depth
# # Load dDEM
# ddem_fn = os.path.join(out_dir, 'MCS_20240420-1_DEM_masked-NMAD2m-refl0_ss-roads_ICP-NuthKaab-all-surfaces_differences.tif')
# ddem = rxr.open_rasterio(ddem_fn)
# ddem_crs = ddem.rio.crs
# ddem = xr.where(ddem < -1e10, np.nan, ddem)

# # Load SNOTEL data vars
# snotel_fn = os.path.join(data_dir, 'snotel', 'MCS_2020-01-01_2024-06-07_adj.csv')
# snotel = pd.read_csv(snotel_fn)
# # subset to DEM date
# snotel['datetime'] = pd.DatetimeIndex(snotel['datetime'])
# sourcedem_dt = pd.Timestamp(np.datetime64(f'{sourcedem_date[0:4]}-{sourcedem_date[4:6]}-{sourcedem_date[6:]}'))
# Idate = np.argmin([abs(x - sourcedem_dt) for x in snotel['datetime'].values])
# snotel_date = snotel.iloc[Idate]
# snotel_sd = snotel_date['SNWD_m']
# print(f'SNOTEL snow depth = {snotel_sd} m')

# # Load SNOTEL site info
# from shapely import wkt
# snotel_info_fn = os.path.join(data_dir, 'snotel', 'MCS_SNOTEL_site_info.csv')
# snotel_info = pd.read_csv(snotel_info_fn, index_col='Unnamed: 0')
# snotel_info['geometry'] = snotel_info['geometry'].apply(wkt.loads)
# snotel_info = gpd.GeoDataFrame(snotel_info, geometry='geometry', crs='EPSG:4326')
# # reproject to DEM CRS
# snotel_info.to_crs(ddem_crs, inplace=True)

# # Sample dDEM at SNOTEL location
# ddem_sample = ddem.sel(x=snotel_info.geometry[0].coords.xy[0][0], y=snotel_info.geometry[0].coords.xy[1][0], method='nearest').data[0]
# print(f'Elevation diff. at SNOTEL location = {ddem_sample} m')

# # Add the difference from the dDEM
# diff = snotel_sd - ddem_sample
# ddem_adj = ddem + diff

# # Plot results
# fig, ax = plt.subplots(2, 2, figsize=(14,8), gridspec_kw={'height_ratios':[2,1]})
# ax = ax.flatten()
# # Diff before SNOTEL adjustment
# im1 = ax[0].imshow(ddem.data[0], cmap="coolwarm_r", vmin=-5, vmax=5,
#                    extent=(np.min(ddem.x.data)/1e3, np.max(ddem.x.data)/1e3, 
#                            np.min(ddem.y.data)/1e3, np.max(ddem.y.data)/1e3))
# ax[0].set_ylabel('Northing [m]')
# ax[0].set_xlabel('Easting [m]')
# ax[0].set_title('dDEM before SNOTEL adjustment')
# ax[2].hist(np.ravel(ddem.data), bins=100)
# ax[2].set_xlim(-10,10)
# # Diff after SNOTEL adjustment
# im2 = ax[1].imshow(ddem_adj.data[0], cmap="coolwarm_r", vmin=-5, vmax=5,
#                    extent=(np.min(ddem.x.data)/1e3, np.max(ddem.x.data)/1e3, 
#                            np.min(ddem.y.data)/1e3, np.max(ddem.y.data)/1e3))
# ax[1].set_xlabel('Easting [m]')
# ax[1].set_title('dDEM after SNOTEL adjustment')
# ax[3].hist(np.ravel(ddem_adj.data), bins=100)
# ax[3].set_xlim(-10,10)
# for axis in [ax[0], ax[1]]:
#     axis.plot(snotel_info.geometry[0].coords.xy[0][0]/1e3, snotel_info.geometry[0].coords.xy[1][0]/1e3, 
#               '*k', markersize=10, label='SNOTEL site')
#     axis.legend(loc='upper right')
# for im in [im1, im2]:
#     fig.colorbar(im, orientation='vertical', label='Elevation difference [m]', shrink=0.8)
# for axis in [ax[2], ax[3]]:
#     axis.set_xlabel('Elevation differences [m]')
#     axis.vlines(0, 0, 1.7e7, color='k')
# ax[2].set_ylabel('Counts')
# plt.show()

# fig_fn = ddem_fn.replace('.tif', '_SNOTEL-adj.png')
# fig.savefig(fig_fn, dpi=300, bbox_inches='tight')
# print('Figure saved to file:', fig_fn)

In [None]:
# refdem = xr.open_dataset(refdem_fn)
# ortho_fn = os.path.join(data_dir, '20240420', 'MCS_20240420-1_orthomosaic.tif')
# ortho = xr.open_dataset(ortho_fn)
# # interpolate to refdem coords
# ortho= ortho.sel(x=refdem.x.data, y=refdem.y.data, method='nearest')
# # mask where refdem is no data
# ortho_masked = ortho.copy(deep=True)
# ortho_masked['refdem'] = (('band', 'y', 'x'), refdem.band_data.data)
# ortho_masked = xr.where(np.isnan(ortho_masked['refdem']), np.nan, ortho)

# fig, ax = plt.subplots(1, 2, figsize=(12,6))
# ax[0].imshow(ortho['band_data'].data[0], cmap='Greys_r', clim=(0,5e4),
#              extent=(np.min(ortho.x.data), np.max(ortho.x.data),
#                      np.min(ortho.y.data), np.max(ortho.y.data)))
# ax[0].set_title('Orthomosaic')
# ax[1].imshow(ortho_masked['band_data'].data[0], cmap='Greys_r', clim=(0,5e4),
#              extent=(np.min(ortho_masked.x.data), np.max(ortho_masked.x.data),
#                      np.min(ortho_masked.y.data), np.max(ortho_masked.y.data)))
# ax[1].set_title('Orthomosaic cropped to reference DEM coverage')
# plt.show()

### Nuth and Kaab + Tilt - coregister all surfaces

Suited to DEMs with small rotations (tilts) relative to the reference DEM.

In [None]:
# diff_after = coregister_difference_dems(ref_dem_fn=refdem_fn, 
#                                         source_dem_fn=sourcedem_fn, 
#                                         ss_mask_fn=ss_fn, 
#                                         out_dir=out_dir,
#                                         coreg_method=['NuthKaab', 'Tilt'], 
#                                         coreg_stable_only=False,
#                                         ss_method=ss_method)

In [None]:
# # -----Adjust vertically using SNOTEL snow depth
# # Load dDEM
# ddem_fn = os.path.join(out_dir, 'MCS_20240420-1_DEM_masked-NMAD2m-refl0_ss-roads_NuthKaab-Tilt-all-surfaces_differences.tif')
# ddem = rxr.open_rasterio(ddem_fn)
# ddem_crs = ddem.rio.crs
# ddem = xr.where(ddem < -1e10, np.nan, ddem)

# # Load SNOTEL data vars
# snotel_fn = os.path.join(data_dir, 'snotel', 'MCS_2020-01-01_2024-06-07_adj.csv')
# snotel = pd.read_csv(snotel_fn)
# # subset to DEM date
# snotel['datetime'] = pd.DatetimeIndex(snotel['datetime'])
# sourcedem_dt = pd.Timestamp(np.datetime64(f'{sourcedem_date[0:4]}-{sourcedem_date[4:6]}-{sourcedem_date[6:]}'))
# Idate = np.argmin([abs(x - sourcedem_dt) for x in snotel['datetime'].values])
# snotel_date = snotel.iloc[Idate]
# snotel_sd = snotel_date['SNWD_m']
# print(f'SNOTEL snow depth = {snotel_sd} m')

# # Load SNOTEL site info
# from shapely import wkt
# snotel_info_fn = os.path.join(data_dir, 'snotel', 'MCS_SNOTEL_site_info.csv')
# snotel_info = pd.read_csv(snotel_info_fn, index_col='Unnamed: 0')
# snotel_info['geometry'] = snotel_info['geometry'].apply(wkt.loads)
# snotel_info = gpd.GeoDataFrame(snotel_info, geometry='geometry', crs='EPSG:4326')
# # reproject to DEM CRS
# snotel_info.to_crs(ddem_crs, inplace=True)

# # Sample dDEM at SNOTEL location
# ddem_sample = ddem.sel(x=snotel_info.geometry[0].coords.xy[0][0], y=snotel_info.geometry[0].coords.xy[1][0], method='nearest').data[0]
# print(f'Elevation diff. at SNOTEL location = {ddem_sample} m')

# # Add the difference from the dDEM
# diff = snotel_sd - ddem_sample
# ddem_adj = ddem + diff

# # Plot results
# fig, ax = plt.subplots(2, 2, figsize=(14,8), gridspec_kw={'height_ratios':[2,1]})
# ax = ax.flatten()
# # Diff before SNOTEL adjustment
# im1 = ax[0].imshow(ddem.data[0], cmap="coolwarm_r", vmin=-5, vmax=5,
#                    extent=(np.min(ddem.x.data)/1e3, np.max(ddem.x.data)/1e3, 
#                            np.min(ddem.y.data)/1e3, np.max(ddem.y.data)/1e3))
# ax[0].set_ylabel('Northing [m]')
# ax[0].set_xlabel('Easting [m]')
# ax[0].set_title('dDEM before SNOTEL adjustment')
# ax[2].hist(np.ravel(ddem.data), bins=100)
# ax[2].set_xlim(-10,10)
# # Diff after SNOTEL adjustment
# im2 = ax[1].imshow(ddem_adj.data[0], cmap="coolwarm_r", vmin=-5, vmax=5,
#                    extent=(np.min(ddem.x.data)/1e3, np.max(ddem.x.data)/1e3, 
#                            np.min(ddem.y.data)/1e3, np.max(ddem.y.data)/1e3))
# ax[1].set_xlabel('Easting [m]')
# ax[1].set_title('dDEM after SNOTEL adjustment')
# ax[3].hist(np.ravel(ddem_adj.data), bins=100)
# ax[3].set_xlim(-10,10)
# for axis in [ax[0], ax[1]]:
#     axis.plot(snotel_info.geometry[0].coords.xy[0][0]/1e3, snotel_info.geometry[0].coords.xy[1][0]/1e3, 
#               '*k', markersize=10, label='SNOTEL site')
#     axis.legend(loc='upper right')
# for im in [im1, im2]:
#     fig.colorbar(im, orientation='vertical', label='Elevation difference [m]', shrink=0.8)
# for axis in [ax[2], ax[3]]:
#     axis.set_xlabel('Elevation differences [m]')
#     axis.vlines(0, 0, 1.7e7, color='k')
# ax[2].set_ylabel('Counts')
# plt.show()

# fig_fn = ddem_fn.replace('.tif', '_SNOTEL-adj.png')
# fig.savefig(fig_fn, dpi=300, bbox_inches='tight')
# print('Figure saved to file:', fig_fn)

### Bias Corr + ICP + Nuth and Kaab - coregister all surfaces

Suited to large shifts, large rotations, and high levels of noise.

In [None]:
# diff_after = coregister_difference_dems(ref_dem_fn=refdem_fn, 
#                                         source_dem_fn=sourcedem_fn, 
#                                         ss_mask_fn=ss_fn, 
#                                         out_dir=out_dir,
#                                         coreg_method=['BiasCorr', 'ICP', 'NuthKaab'], 
#                                         coreg_stable_only=False,
#                                         ss_method=ss_method)

### Nuth and Kaab - coregister all surfaces

In [None]:
# diff_after = coregister_difference_dems(ref_dem_fn=refdem_fn, 
#                                         source_dem_fn=sourcedem_fn, 
#                                         ss_mask_fn=ss_fn, 
#                                         out_dir=out_dir,
#                                         coreg_method='NuthKaab', 
#                                         coreg_stable_only=False,
#                                         ss_method=ss_method)

### Nuth and Kaab - coregister stable surfaces only

In [None]:
# diff_after = coregister_difference_dems(ref_dem_fn=refdem_fn, 
#                                         source_dem_fn=sourcedem_fn, 
#                                         ss_mask_fn=ss_fn, 
#                                         out_dir=out_dir,
#                                         coreg_method='NuthKaab', 
#                                         coreg_stable_only=True,
#                                         ss_method=ss_method)

### Gradient Descending - coregister all surfaces

In [None]:
# diff_after = coregister_difference_dems(ref_dem_fn=refdem_fn, 
#                                         source_dem_fn=sourcedem_fn, 
#                                         ss_mask_fn=ss_fn, 
#                                         out_dir=out_dir,
#                                         coreg_method='GradientDescending', 
#                                         coreg_stable_only=False,
#                                         ss_method=ss_method)

### Gradient Descending - coregister stable surfaces only

In [None]:
# diff_after = coregister_difference_dems(ref_dem_fn=refdem_fn, 
#                                         source_dem_fn=sourcedem_fn, 
#                                         ss_mask_fn=ss_fn, 
#                                         out_dir=out_dir
#                                         coreg_method='GradientDescending', 
#                                         coreg_stable_only=True,
#                                         ss_method=ss_method)