# Coregister DEMs to a reference DEM using the Nuth and Kääb (2011) method and a stable-surface correction

In [1]:
import xdem
import matplotlib.pyplot as plt
import numpy as np
import os
import glob
# NOTE(review): rio, pyproj and MAD are not used in any visible cell.
import rasterio as rio
import geoutils as gu
import pyproj
from scipy.stats import median_abs_deviation as MAD
import warnings
# NOTE(review): this silences every warning for the whole notebook, including
# numpy's "'partition' will ignore the 'mask' of the MaskedArray" warning that
# np.median can emit on masked raster values -- consider narrowing the filter
# to the specific warnings that are actually noise.
warnings.filterwarnings('ignore')

## Load DEMs

In [3]:
# Input file paths
data_path = '/Users/rdcrlrka/Research/PhD/study_sites/MCS/'
# Reference DEM that the source DEM(s) are coregistered to
ref_dem_fn = os.path.join(data_path, 'refdem', 'MCS_REFDEM_WGS84.tif')
# Source DEM to be coregistered; later cells read it as `tba_dem_fn`
tba_dem_fn = os.path.join(data_path, '20240420', 'MCS_20240420_DEM.tif')
# The previous variable names are kept as aliases so code still using them works
refdem_fn = ref_dem_fn
sourcedem_fn = tba_dem_fn

In [None]:
# Reference DEM: xdem.DEM needs a single-band raster, so only band 1 is read
# through geoutils and the resulting Raster is then wrapped as an xdem DEM.
ref_dem = xdem.DEM(
    gu.Raster(ref_dem_fn, bands=1, load_data=True),
)
# Declare 0 as the nodata value (update_array=False: presumably the stored
# pixel values are left untouched -- confirm against geoutils set_nodata docs)
ref_dem.set_nodata(0, update_array=False)

# Quick-look map of the reference DEM, titled with its file name
ref_dem.plot()
plt.title(os.path.basename(ref_dem_fn))
plt.show()

In [None]:
def _ss_median_nmad(diff, ss_mask):
    """Return the median and NMAD of ``diff`` over the stable-surface mask.

    ``ss_mask`` is reprojected onto ``diff``'s grid before indexing; the
    reprojected mask is returned as well so the caller can reuse it.

    Returns
    -------
    (median, nmad, reprojected_mask)
        ``median`` is NaN when no stable-surface pixel is valid.
    """
    ss_mask = ss_mask.reproject(diff)
    diff_ss = diff[ss_mask]
    # diff[ss_mask] may be a masked array (nodata pixels masked). np.median does
    # not honour the mask and would include the fill values; np.ma.median does,
    # and it also accepts plain ndarrays. A fully masked input yields
    # np.ma.masked, which is turned into NaN here.
    med = float(np.ma.filled(np.ma.median(diff_ss), np.nan))
    nmad = xdem.spatialstats.nmad(diff_ss)
    return med, nmad, ss_mask


def _ss_title(heading, med, nmad):
    """Build a subplot title with the stable-surface median and NMAD (3 decimals)."""
    return f'{heading}\nSS median = {np.round(med, 3)}, SS NMAD = {np.round(nmad, 3)}'


def _finite_values(raster):
    """Return the unmasked, finite pixel values of ``raster`` as a 1-D array."""
    values = np.ma.compressed(raster.data)
    return values[np.isfinite(values)]


def coregister(coreg_obj, ref_dem, source_dem, ss_mask, vmin=-1, vmax=1, fit_on_stable=False):
    """Coregister ``source_dem`` to ``ref_dem`` and plot the elevation differences.

    Draws a 2x3 figure. The top row shows difference maps before
    coregistration, after coregistration, and after also subtracting the
    median stable-surface (SS) difference. The bottom row shows histograms of
    the last two maps. Each map title reports the median and NMAD of the
    difference over the stable-surface mask.

    Parameters
    ----------
    coreg_obj : xdem coregistration object (e.g. ``xdem.coreg.NuthKaab()``);
        it is fitted in place by this function.
    ref_dem : reference DEM.
    source_dem : DEM to align. It is subtracted directly from ``ref_dem``, so
        it is presumably already on ``ref_dem``'s grid.
    ss_mask : stable-surface mask; reprojected onto each difference raster.
    vmin, vmax : colour limits of the difference maps.
    fit_on_stable : if True, pass the stable-surface mask to ``coreg_obj.fit``
        as ``inlier_mask`` so the transformation is estimated on stable terrain
        only. The default False fits without an inlier mask (original behaviour).

    Returns
    -------
    diff_after : aligned DEM minus ``ref_dem``.
    diff_after_adj : ``diff_after`` minus its stable-surface median.
    """
    fig, ax = plt.subplots(2, 3, figsize=(12, 8))
    ax = ax.flatten()

    # Difference before coregistration
    diff_before = source_dem - ref_dem
    diff_before.set_nodata(0)
    med_before_ss, nmad_before_ss, ss_mask = _ss_median_nmad(diff_before, ss_mask)
    diff_before.plot(cmap="coolwarm_r", vmin=vmin, vmax=vmax, ax=ax[0])
    ax[0].set_title(_ss_title('Difference before coreg. ', med_before_ss, nmad_before_ss))

    # Fit and apply the coregistration
    if fit_on_stable:
        # ss_mask was just reprojected onto diff_before's grid, which is
        # presumably ref_dem's grid as an inlier mask requires
        coreg_obj.fit(ref_dem, source_dem, inlier_mask=ss_mask)
    else:
        coreg_obj.fit(ref_dem, source_dem)
    aligned_dem = coreg_obj.apply(source_dem)

    # Difference after coregistration
    diff_after = aligned_dem - ref_dem
    med_after_ss, nmad_after_ss, ss_mask = _ss_median_nmad(diff_after, ss_mask)
    diff_after.plot(cmap="coolwarm_r", vmin=vmin, vmax=vmax, ax=ax[1])
    ax[1].set_title(_ss_title('Difference after coreg. ', med_after_ss, nmad_after_ss))
    ax[1].set_yticklabels([])
    # Histogram only valid pixels, not masked fill values
    ax[4].hist(_finite_values(diff_after), bins=50)
    ax[4].set_xlabel('Difference [m]')
    ax[4].set_ylabel('Counts')

    # Stable-surface correction: remove the remaining median offset
    diff_after_adj = diff_after - med_after_ss
    med_after_adj_ss, nmad_after_adj_ss, _ = _ss_median_nmad(diff_after_adj, ss_mask)
    diff_after_adj.plot(cmap="coolwarm_r", vmin=vmin, vmax=vmax, ax=ax[2])
    ax[2].set_title(_ss_title('Difference after coreg. and \nremoving median SS difference',
                              med_after_adj_ss, nmad_after_adj_ss))
    ax[2].set_yticklabels([])
    ax[5].hist(_finite_values(diff_after_adj), bins=50)
    ax[5].set_xlabel('Difference [m]')
    ax[5].set_ylabel('Counts')

    # The bottom-left panel is unused
    ax[3].remove()
    fig.subplots_adjust(wspace=0.2)
    plt.show()

    return diff_after, diff_after_adj

In [None]:
import re

# ------Iterate over source DEMs
# NOTE(review): `ss_fns` (list of stable-surface mask files) and `ref_ss`
# (reference stable-surface raster) are not defined in any visible cell;
# confirm they are created before this cell is run.
for fn in [tba_dem_fn]:
    print(f'\n{os.path.basename(fn)}')

    # Load DEM; xdem.DEM needs a single band, so read band 1 via geoutils
    tba_dem = xdem.DEM(gu.Raster(fn, load_data=True, bands=1))
    tba_dem.set_nodata(0, update_mask=False, update_array=False)

    # Reproject to the reference DEM coordinates, etc.
    tba_dem = tba_dem.reproject(ref_dem)

    # Find the stable-surface mask with the same date as this DEM. The date is
    # taken as the first 8-digit run in the file name; slicing the first 8
    # characters would give 'MCS_2024' for 'MCS_20240420_DEM.tif' and match
    # masks from any 2024 date.
    date_match = re.search(r'\d{8}', os.path.basename(fn))
    if date_match is None:
        raise ValueError(f'No 8-digit date found in DEM file name: {fn}')
    dem_date = date_match.group()
    candidate_ss_fns = [x for x in ss_fns if dem_date in os.path.basename(x)]
    if not candidate_ss_fns:
        raise FileNotFoundError(f'No stable-surface mask matching date {dem_date} in ss_fns')
    tba_ss_fn = candidate_ss_fns[0]
    tba_ss = gu.Raster(tba_ss_fn, load_data=True)

    # Reproject the source stable-surface mask onto the reference mask's grid
    tba_ss = tba_ss.reproject(ref_ss, nodata=0)

    # Combine masks: pixels equal to 1 in both the reference and source masks
    combined_mask_array = (ref_ss.data == 1) & (tba_ss.data == 1)
    combined_mask = gu.Mask.from_array(combined_mask_array,
                                       transform=ref_ss.transform,
                                       crs=ref_ss.crs,
                                       nodata=0)
    combined_mask.set_nodata(0)

    # -----Apply Nuth and Kaab registration
    print('\nNuth and Kaab coregistration...')
    coreg_obj = xdem.coreg.NuthKaab()
    diff_nk, diff_adj_nk = coregister(coreg_obj, ref_dem, tba_dem, combined_mask)

    # -----Apply gradient descent coregistration
    # print('\nGradient descent coregistration...')
    # coreg_obj = xdem.coreg.GradientDescending()
    # diff_gd, diff_adj_gd = coregister(coreg_obj, ref_dem, tba_dem, combined_mask)

In [None]:
# Large map of the median-adjusted Nuth and Kaab difference returned by the
# loop above, on a tighter colour range than the +/-1 default of coregister().
# The axes are created explicitly and passed to plot(): geoutils' handling of
# ax=None differs between versions and may not draw on this 10x10 figure.
fig, ax = plt.subplots(figsize=(10, 10))
diff_adj_nk.plot(cmap='coolwarm_r', vmin=-0.5, vmax=0.5, ax=ax)
plt.show()

In [None]:
# List the '*S2_SR*.tif' files in data_path (presumably Sentinel-2 surface reflectance), sorted by name; echoed as cell output
sorted(glob.glob(os.path.join(data_path, '*S2_SR*.tif')))

In [None]:
# Display the path of the DEM being coregistered (echoed as cell output)
tba_dem_fn

In [None]:
import xarray as xr

In [None]:
# Sentinel-2 image quick-look for visual context
im_fn = os.path.join(data_path, '20240318_S2_SR_HARMONIZED.tif')
# Load the values into memory inside the context manager so the file handle is
# released on exit, then divide by 1e4 (presumably the reflectance scale factor
# of the stored integers) to bring values to a 0-1 scale.
with xr.open_dataset(im_fn) as im_ds:
    im_xr = im_ds.load() / 1e4

# True-colour composite from array indices 3, 2, 1 -- assumes index 0 is band B1,
# so these would be B4 (red), B3 (green), B2 (blue); confirm the export band order.
rgb = np.dstack([im_xr.band_data.data[i] for i in (3, 2, 1)])
# imshow clips float RGB to [0, 1] anyway; clipping here avoids its warning
plt.imshow(np.clip(rgb, 0, 1),
           extent=(np.min(im_xr.x.data), np.max(im_xr.x.data), np.min(im_xr.y.data), np.max(im_xr.y.data)))
plt.show()

In [None]:
# Display the path of the reference DEM (echoed as cell output)
ref_dem_fn