# Coregister SkySat DEMs to a reference DEM

In [1]:
import xdem
import geoutils as gu
import os
import matplotlib.pyplot as plt
import numpy as np

In [None]:
# Define paths to input files.
# Every path is built from one study-site base directory so the directory tree
# is spelled identically everywhere (this matters on case-sensitive file systems).
date = '20240403'
site_name = 'SitKusa'
base_path = f'/Volumes/LaCie/raineyaberle/Research/PhD/SkySat-Stereo/study-sites/{site_name}'
dem_fn = os.path.join(base_path, date, f'{site_name}_{date}_DEM.tif')
refdem_fn = os.path.join(base_path, 'refdem', f'{site_name}_ArcticDEM_buffer30km_filled.tif')
glacier_outlines_fn = os.path.join(base_path, 'glacier_boundaries', f'{site_name}_boundaries_RGIv6.shp')
# Output is written next to the input DEM with a descriptive suffix
out_fn = dem_fn.replace('.tif', '_corrected_coregistered.tif')

# Load input files
refdem = xdem.DEM(refdem_fn)
dem = xdem.DEM(dem_fn)

# Clip the reference DEM to the bounds of the DEM to be aligned, padded on every
# side by a 1 km buffer (assumes the CRS units are metres -- TODO confirm)
clip_buffer = 1e3
clip_geom = [dem.bounds.left - clip_buffer, dem.bounds.bottom - clip_buffer,
             dem.bounds.right + clip_buffer, dem.bounds.top + clip_buffer]
refdem = refdem.crop(clip_geom)

# Load glacier outlines to use as mask
glacier_outlines = gu.Vector(glacier_outlines_fn)

# Reproject the DEM onto the (cropped) reference DEM grid so the two can be
# differenced pixel by pixel
dem = dem.reproject(refdem)
# Mask anomalous points: negative elevations are replaced with NaN
dem.data[dem.data < 0] = np.nan

# Bring the outlines into the reference DEM's georeferencing, then invert the
# glacier mask to obtain the stable (non-glacier) surfaces mask
glacier_outlines = glacier_outlines.reproject(refdem)
ss_mask = ~glacier_outlines.create_mask(refdem)

# Plot the reprojected DEM with the glacier outlines on top, limited to the
# clipping extent
fig, ax = plt.subplots()
dem.plot(ax=ax, cmap='terrain')
glacier_outlines.plot(ax=ax, facecolor='None', edgecolor='k')
ax.set_xlim(clip_geom[0], clip_geom[2])
ax.set_ylim(clip_geom[1], clip_geom[3])
plt.show()


In [None]:
# Elevation difference (DEM minus reference) before any correction is applied
print('Calculating dDEM before adjustments...')
ddem_before = dem - refdem

# Restrict the difference to stable (non-glacier) surfaces
ddem_before_ss = ddem_before[ss_mask]

# Median over stable surfaces, skipping NaNs and entries equal to -9999
# (presumably the nodata value -- TODO confirm against the rasters)
ss_values_before = ddem_before_ss.data
ddem_before_ss_med = np.nanmedian(ss_values_before[ss_values_before != -9999])

# Spread over stable surfaces, as computed by xdem's NMAD helper
ddem_before_ss_nmad = xdem.spatialstats.nmad(ddem_before_ss)

def correct_dem(r, s, ss_mask):
    """Run ICP, deramping, and Nuth & Kaab coregistration in sequence.

    Each step is fitted with ``r`` and the current version of ``s``, using
    ``ss_mask`` as the third positional argument to ``fit`` (presumably xdem's
    inlier mask -- confirm against the installed xdem version), and is then
    applied to the output of the previous step.

    Parameters
    ----------
    r
        Reference DEM, passed first to every ``fit`` call.
    s
        DEM to be corrected; it is passed through all three steps.
    ss_mask
        Mask handed to every ``fit`` call (the caller passes the
        stable-surfaces mask).

    Returns
    -------
    The DEM after all three corrections have been applied.
    """
    # (progress label, coregistration class, whether to print the fitted metadata)
    steps = (
        ('ICP', xdem.coreg.ICP, False),
        ('Deramp', xdem.coreg.Deramp, True),
        ('Nuth and Kaab coregistration', xdem.coreg.NuthKaab, True),
    )

    current = s
    for label, coreg_cls, show_meta in steps:
        print(label)
        fitted = coreg_cls().fit(r, current, ss_mask)
        current = fitted.apply(current)
        if show_meta:
            print(fitted.meta)
    return current

# First correction pass (ICP -> deramp -> Nuth & Kaab)
print('Correcting DEM...')
dem_corr = correct_dem(refdem, dem, ss_mask)
# Mask large anomalies (more than 100 away from the reference, in DEM units)
# to improve the second round of corrections
dem_corr.data[np.abs(dem_corr.data - refdem.data) > 100] = np.nan
dem_corr = correct_dem(refdem, dem_corr, ss_mask)

print('Calculating difference after bias correction...')
ddem_after = dem_corr - refdem
ddem_after_ss = ddem_after[ss_mask]
# Median over stable surfaces, skipping NaNs and entries equal to -9999
# (presumably the nodata value -- TODO confirm against the rasters)
ddem_after_ss_med = np.nanmedian(ddem_after_ss.data[ddem_after_ss.data != -9999])
# NMAD is unaffected by the constant shift applied below
ddem_after_ss_nmad = xdem.spatialstats.nmad(ddem_after_ss)

# Apply a vertical correction using the median dDEM value at stable surfaces.
# The same constant is removed from the full dDEM, the stable-surface dDEM and
# the DEM itself, so the plotted maps, both histograms and the saved file agree.
ddem_after = ddem_after - ddem_after_ss_med
ddem_after_ss = ddem_after_ss - ddem_after_ss_med
final_dem = dem_corr - ddem_after_ss_med
# By construction the stable-surface median is now zero
ddem_after_ss_med = 0

# Save results to file: write the vertically corrected DEM, not the
# intermediate dem_corr, so the file matches the statistics reported below
final_dem.save(out_fn)
print('Corrected, coregistered DEM saved to file:', out_fn)

# Plot results
fig_fn = out_fn.replace('.tif', '.png')
plt.rcParams.update({'font.sans-serif':'Arial', 'font.size':12})
vmin, vmax = -100, 100
fig, ax = plt.subplots(1, 3, figsize=(14,5))
# Left: dDEM before any correction
ddem_before.plot(ax=ax[0], cmap='coolwarm_r', vmin=vmin, vmax=vmax)
ax[0].set_title('dDEM')
# Middle: dDEM after coregistration and vertical shift
ddem_after.plot(ax=ax[1], cmap='coolwarm_r', vmin=vmin, vmax=vmax)
ax[1].set_title('Corrected, coregistered dDEM')
# Right: histograms of the corrected dDEM, all surfaces vs. stable surfaces.
# The stable-surface counts go on a twin y-axis so both distributions are legible.
bins = np.linspace(vmin, vmax, 100)
ax[2].hist(ddem_after.data.ravel(), bins=bins, color='gray', alpha=0.8, label='All surfaces')
ax2 = ax[2].twinx()
hist = ax2.hist(ddem_after_ss.data.ravel(), bins=bins, color='m', alpha=0.8, label='Stable surfaces')
# Leave headroom above the tallest bar so the legend does not cover the bars
ax2.set_ylim(0, np.nanmax(hist[0])*1.4)
ax2.spines['right'].set_color('m')
ax2.yaxis.label.set_color('m')
ax2.tick_params(colors='m', which='both')
ax[2].set_xlim(vmin, vmax)
# Merge the legend entries of both y-axes into a single legend
handles1, labels1 = ax[2].get_legend_handles_labels()
handles2, labels2 = ax2.get_legend_handles_labels()
handles, labels = handles1+handles2, labels1+labels2
ax[2].legend(handles, labels, loc='best')
ax[2].set_title(f'SS median = {np.round(ddem_after_ss_med, 3)} m\nSS NMAD = {np.round(ddem_after_ss_nmad, 3)} m')
fig.tight_layout()
# Save figure to file before showing it: on non-inline backends the figure may
# be released once plt.show() returns, which would produce an empty image
fig.savefig(fig_fn, dpi=300, bbox_inches='tight')
print('Figure saved to file:', fig_fn)
plt.show()
