# DEM post-processing: vertical adjustment with GCPs and error assessment

In [2]:
import os
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import xdem
import geoutils as gu
import json
import seaborn as sns

# Define paths in directory
data_dir = "/Volumes/LaCie/raineyaberle/Research/PhD/Skysat-Stereo/study-sites/"
site_name = "JacksonPeak"
date = "20240420"
dem_fn = os.path.join(data_dir, site_name, date, f"ba+DEMuncertainty1mAll_{site_name}_{date}_DEM.tif")
# Alternative reference DEM:
# refdem_fn = os.path.join(data_dir, site_name, 'refdem', f'{site_name}_REFDEM_WGS84.tif')
refdem_fn = os.path.join(data_dir, site_name, "refdem", "USGS_LPC_ID_FEMAHQ_2018_D18_merged_filtered.tif")
ss_mask_fn = os.path.join(data_dir, site_name, date, "stable_surfaces", "stable_surfaces_mask.tif")
snow_mask_fn = os.path.join(data_dir, site_name, date, "stable_surfaces", "snow_mask.tif")
# Alternative GCPs (highway shapefile, expected dDEM of 0 m):
# gcp_fn = "/Volumes/LaCie/raineyaberle/Research/PhD/SkySat-Stereo/ITD_Functional_Class/ITD_HWY_21.shp"
# gcp_elev = 0
gcp_fn = os.path.join(data_dir, site_name, "snotel", "JacksonPeak_snotel_site_info.gpkg")
# Expected dDEM value (m) at the GCPs: the vertical shift forces the median dDEM there to this value.
# NOTE(review): presumably the SNOTEL-reported snow depth on `date` -- confirm.
gcp_elev = 1.45
out_dir = os.path.join(data_dir, site_name, date, "post_process")

# Check that all input files exist. Report every missing one, then stop,
# so later cells do not fail with less informative errors.
inputs = [(dem_fn, 'DEM'), (refdem_fn, 'Reference DEM'),
          (ss_mask_fn, 'Stable surfaces'), (snow_mask_fn, 'Snow mask'), (gcp_fn, "GCP")]
missing = [name for path, name in inputs if not os.path.exists(path)]
for name in missing:
    print(f"{name} not found, please fix before continuing")
if missing:
    raise FileNotFoundError(f"Missing input file(s): {', '.join(missing)}")

# Make output directory, including any missing parents (e.g. the date folder)
os.makedirs(out_dir, exist_ok=True)

## Vertical adjustment

In [None]:
dem_zshift_fn = os.path.join(out_dir, os.path.splitext(os.path.basename(dem_fn))[0] + '_GCPshift.tif')
if not os.path.exists(dem_zshift_fn):
    # Load input files; the reference DEM and GCPs are reprojected onto the DEM grid/CRS
    dem = xdem.DEM(dem_fn)
    refdem = xdem.DEM(refdem_fn).reproject(dem)
    gcp = gu.Vector(gcp_fn).reproject(dem)
    # Rasterize the GCP geometries into a mask on the DEM grid
    gcp = gcp.create_mask(dem)

    # Calculate dDEM (DEM minus reference DEM)
    ddem = dem - refdem

    # Use NaN as no-data so masked pixels drop out of the nan-aware statistics below
    dem.set_nodata(np.nan, update_array=True)
    ddem.set_nodata(np.nan, update_array=True)

    # Sample dDEM at GCP pixels (masked entries become NaN)
    ddem_gcp = ddem[gcp]
    gcp_values = np.ma.filled(ddem_gcp, np.nan).astype(float).ravel()
    if not np.isfinite(gcp_values).any():
        # Without this check the shift would be NaN and an all-NaN DEM would be written
        raise ValueError(
            f"No valid dDEM values at the GCP location(s) from {gcp_fn}; "
            "check that the GCPs overlap both the DEM and the reference DEM."
        )
    ddem_gcp_median = np.nanmedian(gcp_values)
    # Shift so that the median dDEM at the GCPs equals the expected value gcp_elev
    zshift = -ddem_gcp_median + gcp_elev
    print(f"Vertical adjustment from GCP = {np.round(float(zshift), 2)} m")

    # Apply vertical adjustment to DEM (and to dDEM for the preview plot)
    dem_zshift = dem + zshift
    ddem_zshift = ddem + zshift
    # Save to file
    dem_zshift.save(dem_zshift_fn)
    print('Shifted DEM saved to file:', dem_zshift_fn)

    ddem_zshift.plot(cmap='coolwarm_r', vmin=-5, vmax=5)
    plt.show()

    # Make serializable dictionary of results (values are JSON-encoded strings, as before)
    zshift_dict = {'original_GCP_dDEM_values_m': json.dumps([float(x) for x in gcp_values]),
                   'GCP_dDEM_median_m': json.dumps(float(ddem_gcp_median)),
                   'DEM_vertical_shift_m': json.dumps(float(zshift))}

    # Save to file next to the shifted DEM
    zshift_dict_fn = os.path.splitext(dem_zshift_fn)[0] + '.json'
    with open(zshift_dict_fn, "w") as f:
        json.dump(zshift_dict, f)
    print('Vertical shift dictionary saved to file:', zshift_dict_fn)
else:
    print('Shifted DEM already exists in file, skipping.')

    

## dDEM statistics over stable surfaces

In [None]:
stats_fig_fn = os.path.splitext(dem_zshift_fn)[0] + '_dDEM_stats.png'
if not os.path.exists(stats_fig_fn):
    # Load shifted DEM and reference DEM (on the shifted-DEM grid)
    dem = xdem.DEM(dem_zshift_fn)
    refdem = xdem.DEM(refdem_fn).reproject(dem)
    ddem = dem - refdem

    # Keep only stable surfaces
    ss_mask = gu.Raster(ss_mask_fn).reproject(dem)
    ss_mask = (ss_mask == 1)
    ddem_ss = ddem[ss_mask]

    # Unmasked, finite residuals only. The raw buffer behind a masked array still holds
    # no-data fill values, which would otherwise leak into percentiles and histograms.
    ddem_valid = np.ma.masked_invalid(ddem.data).compressed()
    ddem_ss_valid = np.ma.masked_invalid(ddem_ss).compressed()
    if ddem_ss_valid.size == 0:
        raise ValueError("No valid dDEM values over stable surfaces; check the stable surfaces mask.")

    # Calculate median, NMAD, and quartiles over stable surfaces
    ddem_ss_median, ddem_ss_nmad = np.ma.median(ddem_ss), xdem.spatialstats.nmad(ddem_ss)
    ddem_ss_p25, ddem_ss_p75 = np.percentile(ddem_ss_valid, [25, 75])

    # Plot
    plt.rcParams.update({'font.size': 12, 'font.sans-serif': "Arial"})
    fig, ax = plt.subplots(1, 2, figsize=(10,5), gridspec_kw={'width_ratios': [1,1.5]})
    # dDEM map plot
    ddem.plot(ax=ax[0], cmap='coolwarm_r', vmin=-5, vmax=5, cbar_title='meters')
    ax[0].set_title('dDEM')
    ax[0].set_xticks([])
    ax[0].set_yticks([])
    # Histograms of elevation residuals; 0.2 m bins spanning [-5, 5] inclusive
    bins = np.linspace(-5, 5, 51)
    ax[1].hist(ddem_valid, bins=bins,
               facecolor='gray', edgecolor='k', linewidth=0.5, label='All surfaces')
    # Stable surfaces on a twin y-axis since they are a (much smaller) subset
    ax1 = ax[1].twinx()
    hist = ax1.hist(ddem_ss_valid, bins=bins,
                    facecolor='m', alpha=0.5, edgecolor='k', linewidth=0.5, label='Stable surfaces')
    ax1.set_ylim(0, np.nanmax(hist[0]) * 1.5)
    ax1.spines['right'].set_color('m')
    ax1.tick_params(axis='y', color='m', labelcolor='m')
    # Lines for stats
    ax[1].axvline(ddem_ss_median, linestyle='-', color='k', linewidth=2, label=f"Median: {np.round(float(ddem_ss_median), 2)} m")
    # NOTE: NMAD is a spread, not a location; it is drawn at x = NMAD only as a visual reference for its magnitude.
    ax[1].axvline(ddem_ss_nmad, linestyle='--', color='k', linewidth=1, label=f"NMAD: {np.round(float(ddem_ss_nmad), 2)} m")
    ax[1].axvline(ddem_ss_p25, linestyle='--', color='gray', linewidth=1, label="P$_{25}$: " + f"{np.round(float(ddem_ss_p25), 2)} m")
    ax[1].axvline(ddem_ss_p75, linestyle='--', color='gray', linewidth=1, label="P$_{75}$: " + f"{np.round(float(ddem_ss_p75), 2)} m")
    ax[1].set_xlabel('Elevation residual [m]')
    ax[1].set_ylabel('Frequency')
    ax[1].set_xlim(np.min(bins), np.max(bins))
    # Combine legend entries from both y-axes into one legend
    handles1, labels1 = ax1.get_legend_handles_labels()
    handles2, labels2 = ax[1].get_legend_handles_labels()
    ax[1].legend(handles1 + handles2, labels1 + labels2, loc='upper left')

    fig.tight_layout()

    # Save before showing: some backends close or clear figures on show()
    fig.savefig(stats_fig_fn)
    print('dDEM statistics plot saved to file:', stats_fig_fn)
    plt.show()
else:
    print('dDEM statistics plot already exists in file, skipping.')
    

## Construct snow depth map

In [None]:
sd_fn = os.path.splitext(dem_zshift_fn)[0] + '_snow_depth.tif'
if not os.path.exists(sd_fn):
    # Load input files on the shifted-DEM grid
    dem = xdem.DEM(dem_zshift_fn)
    refdem = xdem.DEM(refdem_fn).reproject(dem)
    snow_mask = gu.Raster(snow_mask_fn).reproject(dem)
    snow_mask = (snow_mask == 1)

    # Calculate dDEM (shifted DEM minus reference DEM), used here as snow depth
    ddem = dem - refdem

    # Mask snow-free surfaces
    ddem.set_mask(~snow_mask)

    # Save to file
    ddem.save(sd_fn)
    print('Snow depth map saved to file:', sd_fn)

    # Use only unmasked, finite values for the histogram. Masking may leave the snow-free
    # dDEM values in the underlying array, and matplotlib's hist may not honor the mask.
    sd_values = np.ma.masked_invalid(ddem.data).compressed()

    # Plot map and distribution of snow depth
    fig, ax = plt.subplots(1, 2, figsize=(10,5))
    ddem.plot(cmap='Blues', vmin=0, vmax=6, ax=ax[0], cbar_title='Snow depth [m]')
    ax[1].hist(sd_values, bins=np.arange(-1, 6.1, step=0.2), facecolor='skyblue', edgecolor='k', linewidth=0.5)
    ax[1].set_yticks([])
    ax[1].set_xlabel('Snow depth [m]')
    ax[1].set_xlim(-1,6)
    fig.tight_layout()
    plt.show()

else:
    print('Snow depth map already exists in file, skipping.')
    

## Estimate the resolution at which to sample terrain parameters from modeled semivariogram ranges

In [None]:
res_dict_fn = os.path.join(data_dir, site_name, 'refdem', 'refdem_terrain_sampling_resolutions.json')
if not os.path.exists(res_dict_fn):

    def calculate_variogram_range(raster):
        """Fit a spherical variogram to ``raster`` and return its modeled range.

        The empirical variogram and fitted model are plotted. The range is rounded
        to a whole number of map units (presumably meters -- depends on the CRS).
        """
        # Calculate the empirical variogram
        print('Calculating the empirical variogram...')
        var = xdem.spatialstats.sample_empirical_variogram(raster)
        # Fit a single spherical variogram model
        print('Modeling the variogram...')
        func_sum_vgm, params_vgm = xdem.spatialstats.fit_sum_model_variogram(
            list_models=["Spherical"], empirical_variogram=var
            )
        # Plot
        xdem.spatialstats.plot_variogram(var,
                                        list_fit_fun=[func_sum_vgm],
                                        xscale_range_split=[10, 100, 1000, 10000])
        plt.show()
        # Use the modeled range as the correlation length
        vgm_range = np.round(float(params_vgm['range'].values[0]))
        print('Modeled range = ', vgm_range)
        return vgm_range

    def calculate_terrain_parameters(refelev, elev):
        """Smooth elevation, slope, and aspect at their variogram-range resolutions.

        Slope and aspect are derived from ``refelev``. Each variable is resampled
        to a pixel size equal to its modeled variogram range. It is then reprojected
        back onto the original grid of ``elev``.

        Returns (elev, slope, aspect, res_dict).
        """
        # Calculate terrain params
        slope = refelev.slope()
        aspect = refelev.aspect()

        # Calculate best sampling resolution from modeled variogram range
        res0 = refelev.res
        refelev_res = calculate_variogram_range(refelev)
        slope_res = calculate_variogram_range(slope)
        # NOTE(review): aspect is circular (0 and 360 coincide); its variogram range is a rough proxy.
        aspect_res = calculate_variogram_range(aspect)

        # Resample using new res
        elev_coarse = elev.reproject(res=(refelev_res, refelev_res))
        slope_coarse = slope.reproject(res=(slope_res, slope_res))
        aspect_coarse = aspect.reproject(res=(aspect_res, aspect_res))

        # Reproject back onto the original DEM grid for later calculations.
        # (Reprojecting onto the coarse elevation raster would be a no-op for elev
        # and would leave slope/aspect on the wrong grid.)
        elev_out = elev_coarse.reproject(elev)
        slope_out = slope_coarse.reproject(elev)
        aspect_out = aspect_coarse.reproject(elev)

        # Set no-data values to NaN
        elev_out.set_nodata(np.nan)
        slope_out.set_nodata(np.nan)
        aspect_out.set_nodata(np.nan)

        # Save results in dictionary.
        # NOTE: the keys say "sill", but the values are the original pixel size and the
        # variogram *ranges*. Keys are kept unchanged for compatibility with existing JSON files.
        res_dict = {'original_sill_m': json.dumps(res0[0]),
                    'elevation_sill_m': json.dumps(refelev_res),
                    'slope_sill_m': json.dumps(slope_res),
                    'aspect_sill_m': json.dumps(aspect_res)}

        return elev_out, slope_out, aspect_out, res_dict

    # Load input files
    dem = xdem.DEM(dem_zshift_fn)
    refdem = xdem.DEM(refdem_fn).reproject(dem)

    # Calculate terrain parameters
    elev, slope, aspect, res_dict = calculate_terrain_parameters(refdem, dem)
    with open(res_dict_fn, 'w') as f:
        json.dump(res_dict, f)
    print('Terrain sampling resolutions saved to file:', res_dict_fn)

else:
    print('Estimated terrain sampling already exists, loaded from file.')
    with open(res_dict_fn, 'r') as f:
        res_dict = json.load(f)

res_dict



## Estimate spatial correlation in errors