# Coregister SkySat DEMs to a reference DEM

In [1]:
import xdem
import geoutils as gu
import os
import matplotlib.pyplot as plt
import numpy as np

In [None]:
# Define paths to input files.
# Every input lives under one study-site directory. Building all paths from a
# single base keeps the "SkySat-Stereo" spelling consistent (the reference-DEM
# path previously used "Skysat-Stereo", which breaks on case-sensitive filesystems).
date = '20211010'
site_dir = '/Volumes/LaCie/raineyaberle/Research/PhD/SkySat-Stereo/study-sites/SitKusa'
tba_dem_fn = os.path.join(site_dir, date, f'SitKusa_{date}_DEM.tif')
ref_dem_fn = os.path.join(site_dir, 'refdem', 'SitKusa_ArcticDEM_buffer30km_filled.tif')
glacier_outlines_fn = os.path.join(site_dir, 'glacier_boundaries', 'SitKusa_boundaries_RGIv6.shp')
# Output goes next to the input DEM. Only the final extension is swapped, so a
# '.tif' appearing elsewhere in the path can never be rewritten.
out_fn = os.path.splitext(tba_dem_fn)[0] + '_deramped_coregistered.tif'

# Load input files
ref_dem = xdem.DEM(ref_dem_fn)
tba_dem = xdem.DEM(tba_dem_fn)
# Clip reference DEM to tba_dem bounds + 1 km.
# NOTE(review): the buffer is added to tba_dem's bounds and passed to
# ref_dem.crop unchanged, so this assumes both DEMs share a projected CRS in
# metres -- confirm.
clip_buffer = 1e3
clip_geom = [tba_dem.bounds.left - clip_buffer, tba_dem.bounds.bottom - clip_buffer,
             tba_dem.bounds.right + clip_buffer, tba_dem.bounds.top + clip_buffer]
# inplace=False is explicit because older geoutils releases default crop() to
# inplace=True and return None, which would silently turn ref_dem into None here.
ref_dem = ref_dem.crop(clip_geom, inplace=False)
# Load glacier outlines to use as mask
glacier_outlines = gu.Vector(glacier_outlines_fn)
# Reproject to reference DEM grid so the DEM difference and mask align pixel-for-pixel
tba_dem = tba_dem.reproject(ref_dem)
glacier_outlines = glacier_outlines.reproject(ref_dem)
ss_mask = ~glacier_outlines.create_mask(ref_dem) # stable (non-glacier) surfaces mask

# Plot the reprojected DEM with glacier outlines, zoomed to the clipped extent
fig, ax = plt.subplots()
tba_dem.plot(ax=ax, cmap='terrain')
glacier_outlines.plot(ax=ax, facecolor='None', edgecolor='k')
ax.set_xlim(clip_geom[0], clip_geom[2])
ax.set_ylim(clip_geom[1], clip_geom[3])
plt.show()


In [None]:
# Elevation difference before any correction (tba_dem was reprojected onto the
# ref_dem grid in the previous cell, so the rasters subtract directly).
diff_before = tba_dem - ref_dem

# Step 1: fit a polynomial ramp of order 2 using only the ss_mask pixels, then remove it.
print('Deramping...')
deramp = xdem.coreg.Deramp(poly_order=2)
deramp.fit(ref_dem, tba_dem, ss_mask)
dem_deramp = deramp.apply(tba_dem)

# Step 2: coregister the deramped DEM with xdem's NuthKaab method, same mask.
print('Coregistering...')
coreg = xdem.coreg.NuthKaab()
coreg.fit(ref_dem, dem_deramp, ss_mask)
dem_deramp_coreg = coreg.apply(dem_deramp)

# Residual elevation difference after both corrections
diff_after = dem_deramp_coreg - ref_dem

# Persist the corrected DEM
dem_deramp_coreg.save(out_fn)
print('Deramped, coregistered DEM saved to file:', out_fn)

# Side-by-side maps of the difference before/after, on a shared +/-100 colour scale
panels = [(diff_before, 'dDEM'), (diff_after, 'Deramped, coregistered dDEM')]
fig, ax = plt.subplots(1, 2, figsize=(12,5))
for panel_ax, (ddem, title) in zip(ax, panels):
    ddem.plot(ax=panel_ax, cmap='coolwarm_r', vmin=-100, vmax=100)
    panel_ax.set_title(title)
plt.show()