# Testing horizontal coregistration using trees

In [None]:
import os
import matplotlib.pyplot as plt
import xdem
import numpy as np
import xarray as xr
import rioxarray as rxr
import geoutils as gu

In [None]:
# Define inputs: source/reference DEM paths, land-cover masks for this job,
# and the output directory for the tree-coregistration tests.
data_dir = '/Volumes/LaCie/raineyaberle/Research/PhD/SkySat-Stereo/study-sites/MCS'
refdem_fn = os.path.join(data_dir, 'refdem', 'MCS_REFDEM_WGS84_CHM.tif')
sourcedem_fn = os.path.join(data_dir, '20240420', 'MCS_20240420-1_DEM.tif')
ortho_fn = os.path.join(data_dir, '20240420', 'MCS_20240420-1_4band_orthomosaic.tif')

job_dir = '/Volumes/LaCie/raineyaberle/Research/PhD/SkySat-Stereo/snow_depth_maps/MCS_20240420-1_ref+chm_sourcetrees'
masks_dir = os.path.join(job_dir, 'land_cover_masks')
trees_mask_fn = os.path.join(masks_dir, 'trees_mask.tif')
roads_mask_fn = os.path.join(masks_dir, 'roads_mask.tif')
snow_mask_fn = os.path.join(masks_dir, 'snow_mask.tif')
ss_mask_fn = os.path.join(masks_dir, 'stable_surfaces_mask.tif')

out_dir = os.path.join(job_dir, 'testing_tree_coreg')
# makedirs also creates any missing parent directories (os.mkdir would raise),
# and exist_ok avoids failing if the directory appears between check and creation.
if not os.path.isdir(out_dir):
    os.makedirs(out_dir, exist_ok=True)
    print('Made directory for outputs:', out_dir)
    print('Made directory for outputs:', out_dir)

In [None]:
# Apply initial deramping on stable surfaces.
# Fits xdem's Deramp (poly_order=2) between the reference and source DEMs using
# only stable-surface pixels as inliers, applies it to the source DEM, saves the
# deramped DEM and a before/after diagnostic figure. Skipped if the deramped DEM
# already exists on disk.
vmin, vmax = -10, 10  # colour / histogram limits for elevation differences [m]

# Define output file names
dem_deramped_fn = os.path.join(out_dir, os.path.basename(sourcedem_fn).replace('.tif', '_deramped.tif'))
fig_fn = os.path.join(out_dir, os.path.basename(sourcedem_fn).replace('.tif', '_deramp_correction.png'))

if not os.path.exists(dem_deramped_fn):
    # Load input files and put the source DEM and the mask on the reference grid
    refdem = xdem.DEM(refdem_fn)
    sourcedem = xdem.DEM(sourcedem_fn)
    sourcedem = sourcedem.reproject(refdem)
    ss_mask = gu.Raster(ss_mask_fn, load_data=True)
    # NOTE(review): no resampling method is passed, so geoutils' default is used
    # for this 0/1 mask; nearest-neighbour may be preferable to keep it categorical.
    ss_mask = ss_mask.reproject(refdem)
    ss_mask = (ss_mask == 1)  # convert to boolean mask

    # Calculate difference before deramping (source minus reference).
    # Must use `refdem` here: `ref_dem` is only defined in a later cell.
    diff_before = sourcedem - refdem

    # Fit and apply Deramp object
    print('Fitting deramper...')
    deramp = xdem.coreg.Deramp(poly_order=2)
    deramp.fit(refdem, sourcedem, inlier_mask=ss_mask)
    meta = deramp.meta
    print(meta)
    dem_deramped = deramp.apply(sourcedem)

    # Save corrected DEM
    dem_deramped.save(dem_deramped_fn)
    print('Deramped DEM saved to file:', dem_deramped_fn)

    # Calculate difference after deramping (same sign convention as before)
    diff_after = dem_deramped - refdem

    # Plot results
    print('Plotting deramp correction results...')
    bins = np.linspace(vmin, vmax, num=100)
    fig, ax = plt.subplots(2, 2, figsize=(10, 10))
    ax = ax.flatten()
    diff_before.plot(cmap='coolwarm_r', vmin=vmin, vmax=vmax, ax=ax[0])
    ax[0].set_title('dDEM')
    diff_after.plot(cmap='coolwarm_r', vmin=vmin, vmax=vmax, ax=ax[1])
    ax[1].set_title('Deramped dDEM')
    # np.ma.compressed keeps only unmasked (valid) pixels, so masked cells'
    # underlying fill values are never binned.
    ax[2].hist(np.ma.compressed(diff_before.data), color='grey', bins=bins)
    ax[2].set_xlim(vmin, vmax)
    ax[2].set_xlabel('Elevation differences (all surfaces) [m]')
    ax[3].hist(np.ma.compressed(diff_after.data), color='grey', bins=bins)
    ax[3].set_xlim(vmin, vmax)
    ax[3].set_xlabel('Elevation differences (all surfaces) [m]')
    fig.tight_layout()

    # Save figure before showing it, as matplotlib recommends: show() may close
    # the figure, after which saving can produce an empty image.
    fig.savefig(fig_fn, dpi=300, bbox_inches='tight')
    print('Figure saved to file:', fig_fn)
    plt.show()

else:
    print('Deramped DEM already exists in file, skipping.')

In [None]:
# Solve for horizontal shift on snow-free surfaces with highest cross-correlation.
# This cell loads the reference DEM and the deramped (to-be-aligned) DEM on the
# same grid and masks snow-covered pixels in the to-be-aligned DEM.
# NOTE(review): dem_xy_shift_fn and xy_shifts_fn are defined here but not written
# by this cell.

# Define input file names
tba_dem_fn = dem_deramped_fn
ref_dem_fn = refdem_fn

# Define output file names
dem_xy_shift_fn = tba_dem_fn.replace('.tif', '_xy-shifted.tif')
xy_shifts_fn = os.path.join(out_dir, 'xy_shifts.json')

if not os.path.exists(dem_xy_shift_fn):
    # Load input files and resample the to-be-aligned DEM onto the reference grid
    ref_dem = xdem.DEM(ref_dem_fn)
    tba_dem = xdem.DEM(tba_dem_fn)
    tba_dem = tba_dem.reproject(ref_dem)

    # Mask snow-covered surfaces in tba_dem
    snow_mask = gu.Raster(snow_mask_fn, load_data=True)
    snow_mask = snow_mask.reproject(ref_dem)
    snow_mask = (snow_mask == 1)  # Convert to boolean mask
    # Parenthesised on purpose: in Python `|` binds tighter than `==`, so the
    # unparenthesised form would compare the OR-ed mask to 1.
    tba_dem.set_mask(tba_dem.data.mask | (snow_mask.data.data == 1))

else:
    print('xy-shifted DEM already exists in file, skipping.')

In [None]:
from scipy.signal import correlate2d
from scipy.ndimage import shift
from scipy.optimize import minimize

def subpixel_cross_correlation(image1, image2, max_lag=3.0, order=1):
    """Estimate the sub-pixel (row, column) shift that best aligns image2 to image1.

    image2 is shifted with ``scipy.ndimage.shift`` and the shift that maximises
    the zero-normalised cross-correlation with image1 is found by bounded
    Powell minimisation of the negative correlation.

    Parameters
    ----------
    image1, image2 : 2-D array_like of identical shape
        Reference and to-be-aligned images. Non-finite values (NaN/inf) are
        treated as missing and excluded from the correlation.
    max_lag : float
        Maximum absolute shift searched along each axis, in pixels.
    order : int
        Spline order passed to ``scipy.ndimage.shift``. Orders 0/1 keep NaNs
        local; higher orders prefilter the whole image, so a single NaN can
        contaminate entire rows/columns.

    Returns
    -------
    numpy.ndarray
        ``[row_shift, col_shift]`` to apply to image2 (positive values move its
        content towards higher row/column indices).

    Raises
    ------
    ValueError
        If the two images do not have the same shape.
    """
    reference = np.asarray(image1, dtype=float)
    moving = np.asarray(image2, dtype=float)
    if reference.shape != moving.shape:
        raise ValueError(
            f"image1 and image2 must have the same shape, got {reference.shape} and {moving.shape}"
        )
    reference_valid = np.isfinite(reference)

    def negative_correlation(lag):
        # Shift image2 by sub-pixel lags (lag[0] is vertical, lag[1] is horizontal)
        shifted = shift(moving, shift=lag, order=order, mode='nearest')
        valid = reference_valid & np.isfinite(shifted)
        if np.count_nonzero(valid) < 2:
            return 0.0  # no usable overlap: treat as uncorrelated
        a = reference[valid]
        b = shifted[valid]
        a = a - a.mean()
        b = b - b.mean()
        denom = np.sqrt(np.dot(a, a) * np.dot(b, b))
        if denom == 0:
            return 0.0  # a constant image has no defined correlation
        # Negative because the optimiser minimises and we want maximum correlation
        return -float(np.dot(a, b) / denom)

    # Limit the search space to +/- max_lag pixels along each axis
    bounds = [(-max_lag, max_lag), (-max_lag, max_lag)]
    result = minimize(negative_correlation, x0=[0.0, 0.0], bounds=bounds, method='Powell')
    return result.x

# Build float copies of both DEMs with every masked pixel (nodata, plus the snow
# mask applied to tba_dem) set to NaN. Copies avoid mutating the DEM objects, and
# NaN (rather than 0) avoids artificial elevation cliffs at masked pixels.
image1 = np.ma.filled(ref_dem.data.astype('float64'), np.nan)
image2 = np.ma.filled(tba_dem.data.astype('float64'), np.nan)

best_lag = subpixel_cross_correlation(image1, image2, max_lag=3.0)

print(f"The best sub-pixel lag is: (Vertical: {best_lag[0]:.4f}, Horizontal: {best_lag[1]:.4f})")

In [None]:
# Plot results
fig, ax = plt.subplots(1, 2, figsize=(10,5))
# correlation before
ax[0].imshow(image1 * image2, 
# correlation after