In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
import xdem
import geoutils as gu
import skimage
from matplotlib.colors import ListedColormap
import shutil
import subprocess
import multiprocessing
import sys

## Define input and output paths

In [None]:
# Input/output locations. Set SKYSAT_DATA_PATH / ASP_BIN_PATH to run on another machine;
# the defaults are the original author's local paths.
data_path = os.environ.get('SKYSAT_DATA_PATH',
                           '/Volumes/LaCie/raineyaberle/Research/PhD/SkySat-Stereo/study-sites/ID-MCS')
asp_path = os.environ.get('ASP_BIN_PATH',
                          '/Users/raineyaberle/Research/PhD/SnowDEMs/StereoPipeline-3.5.0-alpha-2024-10-05-x86_64-OSX/bin')
out_path = os.path.join(data_path, '20240420', 'proc_out')

# snow-covered DEM and its orthomosaic
dem_fn = os.path.join(out_path, 'MCS_20240420_TOAR_DEM.tif')
ortho_fn = os.path.join(out_path, 'MCS_20240420_TOAR_orthomosaic.tif')

# reference DEM
refdem_fn = os.path.join(data_path, 'refdem', 'MCS_REFDEM_WGS84_filled.tif')

# Create the output directory (and any missing parent directories) if it doesn't exist
os.makedirs(out_path, exist_ok=True)

## Define utility functions

In [None]:
def identify_optimal_utm(lon: float, lat: float):
    """
    Return the EPSG code of the standard UTM zone that contains a WGS84 point.

    Zones are 6 degrees wide starting at -180 longitude. There are no special cases
    for the irregular Norway/Svalbard zones.

    Parameters
    ----------
    lon: float
        longitude in degrees
    lat: float
        latitude in degrees

    Returns
    ----------
    str
        'EPSG:326ZZ' for lat >= 0 (north), otherwise 'EPSG:327ZZ' (south),
        where ZZ is the zero-padded zone number 01-60
    """
    zone = int(np.floor((lon + 180) / 6) % 60) + 1
    hemisphere_prefix = '326' if lat >= 0 else '327'
    return f'EPSG:{hemisphere_prefix}{zone:02d}'

def run_cmd(bin, args, **kw):
    """
    Run an executable (e.g. an ASP tool) with the given arguments and raise if it fails.

    Parameters
    ----------
    bin: str
        executable name (looked up on PATH) or path to the executable
    args: list of str
        command-line arguments passed to the executable
    **kw:
        msg: str, optional
            step name used in the failure message (defaults to the executable's base name)

    Returns
    ----------
    None

    Raises
    ----------
    SystemExit
        if the executable cannot be found
    Exception
        if the process cannot be started or exits with a non-zero code
    """
    # locate execution file
    binpath = shutil.which(bin)
    if binpath is None:
        msg = ("Unable to find executable %s\n"
        "Install ASP and ensure it is in your PATH env variable\n"
       "https://ti.arc.nasa.gov/tech/asr/intelligent-robotics/ngt/stereo/")
        sys.exit(msg)

    # construct command (list() also accepts tuples of arguments)
    call = [binpath] + list(args)
    print(call)
    try:
        code = subprocess.call(call, shell=False)
    except OSError as e:
        raise Exception('%s: %s' % (binpath, e)) from e
    if code != 0:
        # Callers may omit msg=; fall back to the executable name instead of raising KeyError
        step = kw.get('msg', os.path.basename(binpath))
        raise Exception('ASP step ' + step + ' failed (exit code %d)' % code)


def save_raster(raster, raster_fn, out_dtype='float32', out_nodata=-9999):
    """
    Cast a raster to a target dtype, set its nodata value and write it to disk.

    Parameters
    ----------
    raster: geoutils Raster (or subclass such as xdem DEM)
        raster to save; the caller's object itself is not cast (astype result is saved)
    raster_fn: str | Path
        output file name
    out_dtype: str
        output data type (default 'float32')
    out_nodata: int | float
        nodata value, applied with update_array=True (default -9999)
    """
    converted = raster.astype(out_dtype)
    converted.set_nodata(out_nodata, update_array=True)
    converted.save(raster_fn, dtype=out_dtype)
    print('Raster saved to file:', raster_fn)
    

def reproject_to_utm(raster_fn, out_path):
    """
    Reproject a raster to the UTM zone containing its center, unless it is already in that zone.

    Parameters
    ----------
    raster_fn: str | Path
        file name of the input raster
    out_path: str | Path
        folder where the reprojected raster is written

    Returns
    ----------
    str
        raster_fn if the raster is already in the optimal UTM CRS; otherwise the path
        '<out_path>/<basename>_EPSG<code>.tif' (reused without recomputing if it already exists)
    """
    # Load raster
    raster = gu.Raster(raster_fn)

    # Determine optimal UTM zone from the center of the raster footprint in WGS84.
    # (transform[2], transform[5] would be the upper-left corner, not the center.)
    bounds_wgs = raster.reproject(crs="EPSG:4326").bounds
    cen_lon = (bounds_wgs.left + bounds_wgs.right) / 2
    cen_lat = (bounds_wgs.bottom + bounds_wgs.top) / 2
    epsg_utm = identify_optimal_utm(cen_lon, cen_lat)
    print(f"Optimal UTM CRS = {epsg_utm}")

    # Check if raster already in optimal UTM zone (to_epsg() may return None for custom CRSs)
    current_epsg = raster.crs.to_epsg()
    if current_epsg is not None and f"EPSG:{current_epsg}" == epsg_utm:
        print('Raster already in optimal UTM zone, skipping reprojection.')
        return raster_fn

    # Define output file name (double quotes inside keep the f-string valid before Python 3.12)
    epsg_suffix = epsg_utm.replace(":", "")
    raster_out_fn = os.path.join(out_path,
                                 os.path.splitext(os.path.basename(raster_fn))[0] + f'_{epsg_suffix}.tif')

    # Check if file already exists
    if os.path.exists(raster_out_fn):
        print('Raster in optimal UTM zone already exists in file.')
        return raster_out_fn

    # Reproject and save
    raster_utm = raster.reproject(crs=epsg_utm)
    save_raster(raster_utm, raster_out_fn)

    return raster_out_fn


def _load_ortho_bands(dem_fn, ortho_fn):
    """Load the orthomosaic on the DEM grid and return its per-band rasters as a list."""
    dem = xdem.DEM(dem_fn)
    ortho = gu.Raster(ortho_fn, load_data=True).reproject(dem)
    # Flag the 65534 fill value, then use 0 as the nodata value
    ortho[ortho == 65534] = -9999
    ortho.set_nodata(0, update_array=True)
    # Convert scaled digital numbers to reflectance.
    # NOTE(review): Planet reflectance products are commonly scaled by 1e4 — confirm 1e5 is intended.
    ortho = ortho / 1e5
    # assumes bands are blue, green, red, NIR plus one trailing band (e.g. alpha) that is skipped — TODO confirm
    return [gu.Raster.from_array(ortho.data[i],
                                 transform=ortho.transform,
                                 crs=ortho.crs,
                                 nodata=ortho.nodata) for i in range(len(ortho.bands) - 1)]


def _classify_land_cover(green, red, nir, ndsi_threshold=0.1, green_threshold=0.1):
    """Return (vegetation, snow, stable surfaces) masks from NDVI (Otsu threshold) and NDSI."""
    # vegetation mask: Otsu threshold computed on valid NDVI pixels only
    # (.data.data would also include masked fill values and NaN/inf from zero denominators)
    ndvi = (nir - red) / (nir + red)
    ndvi_valid = np.ma.masked_invalid(ndvi.data).compressed()
    if ndvi_valid.size == 0:
        raise ValueError('No valid NDVI pixels; cannot compute vegetation threshold')
    ndvi_threshold = skimage.filters.threshold_otsu(ndvi_valid)
    print('NDVI threshold for vegetation mask = ', ndvi_threshold)
    veg_mask = (ndvi > ndvi_threshold)

    # snow mask: bright pixels with high NDSI
    ndsi = (green - nir) / (green + nir)
    snow_mask = (ndsi > ndsi_threshold) & (green > green_threshold)

    # stable surfaces: neither snow nor vegetation
    ss_mask = (~snow_mask & ~veg_mask)
    return veg_mask, snow_mask, ss_mask


def _plot_land_cover_masks(blue, green, red, veg_mask, snow_mask, ss_mask, fig_fn):
    """Plot an RGB composite beside the land cover masks and save the figure to fig_fn."""
    fig, ax = plt.subplots(1, 2, figsize=(10, 5))

    # RGB composite: imshow ignores clim for RGB input, so stretch reflectance 0-0.5 to 0-1 explicitly
    rgb = np.ma.dstack([red.data, green.data, blue.data]).filled(0)
    rgb = np.clip(rgb / 0.5, 0, 1)
    bounds = red.bounds
    # imshow's extent order is (left, right, bottom, top), unlike rasterio's (left, bottom, right, top)
    ax[0].imshow(rgb, extent=(bounds.left, bounds.right, bounds.bottom, bounds.top))
    ax[0].set_title('RGB composite')

    # Land cover masks: value 1 gets the class color, value 0 is transparent
    snow_color = (65/255, 182/255, 196/255, 0.7)
    veg_color = (35/255, 132/255, 67/255, 0.7)
    ss_color = (115/255, 115/255, 115/255, 0.7)
    for mask, color in ((snow_mask, snow_color), (veg_mask, veg_color), (ss_mask, ss_color)):
        mask.plot(cmap=ListedColormap([(0, 0, 0, 0), color]),
                  vmin=0, vmax=1, add_cbar=False, ax=ax[1])
    ax[1].set_title('Land cover masks')

    # dummy points for legend (axis limits are restored afterwards)
    xmin, xmax = ax[1].get_xlim()
    ymin, ymax = ax[1].get_ylim()
    for color, label in ((snow_color, 'snow'), (veg_color, 'vegetation'), (ss_color, 'stable surfaces')):
        ax[1].plot(0, 0, 's', markerfacecolor=color, markeredgecolor='gray', label=label)
    ax[1].set_xlim(xmin, xmax)
    ax[1].set_ylim(ymin, ymax)
    ax[1].legend(loc='upper left', markerscale=2)

    fig.savefig(fig_fn, dpi=300, bbox_inches='tight')
    print('Figure saved to file:', fig_fn)
    plt.show()


def construct_land_cover_masks(dem_fn, ortho_fn, out_path):
    """
    Classify vegetation, snow and stable surfaces from the orthomosaic, save the masks, and plot them.

    Masks are computed only if 'mask_stable_surfaces.tif' does not exist yet; otherwise the
    saved masks are loaded. The figure 'masks.png' is produced only if it does not exist yet.

    Parameters
    ----------
    dem_fn: str | Path
        file name of the DEM whose grid the orthomosaic is resampled onto
    ortho_fn: str | Path
        file name of the orthomosaic
    out_path: str | Path
        folder for 'mask_vegetation.tif', 'mask_snow.tif', 'mask_stable_surfaces.tif' and 'masks.png'
    """
    # Define output file names
    veg_mask_fn = os.path.join(out_path, 'mask_vegetation.tif')
    snow_mask_fn = os.path.join(out_path, 'mask_snow.tif')
    ss_mask_fn = os.path.join(out_path, 'mask_stable_surfaces.tif')

    bands = None
    if not os.path.exists(ss_mask_fn):
        bands = _load_ortho_bands(dem_fn, ortho_fn)
        blue, green, red, nir = bands
        veg_mask, snow_mask, ss_mask = _classify_land_cover(green, red, nir)

        # save to file (0 = nodata, so only class pixels are valid in the saved masks)
        for raster, raster_fn in [[veg_mask, veg_mask_fn], [snow_mask, snow_mask_fn], [ss_mask, ss_mask_fn]]:
            save_raster(raster, raster_fn, out_dtype='int8', out_nodata=0)
    else:
        veg_mask = gu.Raster(veg_mask_fn)
        snow_mask = gu.Raster(snow_mask_fn)
        ss_mask = gu.Raster(ss_mask_fn)

    # Plot and save figure
    fig_fn = os.path.join(out_path, 'masks.png')
    if not os.path.exists(fig_fn):
        if bands is None:
            # Masks came from disk, so the orthomosaic bands still need loading for the RGB panel
            bands = _load_ortho_bands(dem_fn, ortho_fn)
        blue, green, red, _nir = bands
        _plot_land_cover_masks(blue, green, red, veg_mask, snow_mask, ss_mask, fig_fn)


def align_dem(refdem_fn=None, dem_fn=None, out_dir=None, asp_dir=os.getcwd(),
              stable_only=False, ss_mask_fn=None, max_displacement=100, tr=2, threads=0):
    """
    Align a DEM to a reference DEM in two steps.
    First, use ASP's pc_align with the Iterative Closest Point (point-to-plane) method for
    translational and rotational alignment. Then, use xdem's Nuth and Kaab method for subpixel
    translational alignment. Intermediate outputs found in out_dir are reused.

    Parameters
    ----------
    refdem_fn: str | Path
        file name of the reference DEM
    dem_fn: str | Path
        file name of the DEM to be aligned
    out_dir: str | Path
        path to the folder where outputs will be saved (created if missing)
    asp_dir: str | Path
        path to Ames Stereo Pipeline "bin" folder (default: working directory when this was defined)
    stable_only: bool
        whether to fit the alignment only over stable surfaces; the fitted transforms are then
        applied to the full DEM with pc_align and gridded with point2dem
    ss_mask_fn: str | Path
        file name of the stable surfaces mask (pixels equal to 1 are stable), required if stable_only=True
    max_displacement: float | int
        maximum displacement of the ICP alignment, passed to pc_align
    tr: float | int
        target resolution passed to point2dem
    threads: int
        number of threads for pc_align and point2dem
        (0: 75% of the CPU count, rounded down, minimum 1)

    Returns
    ----------
    str
        file name of the aligned DEM: 'run-nk-DEM.tif' (xdem output) when stable_only=False,
        or the point2dem grid of the fully transformed DEM when stable_only=True

    Raises
    ----------
    ValueError
        if stable_only=True and ss_mask_fn is not given
    """
    if stable_only and ss_mask_fn is None:
        raise ValueError('ss_mask_fn is required when stable_only=True')
    os.makedirs(out_dir, exist_ok=True)

    # Resolve the thread count up front; the nested helpers read `threads` when called
    if threads == 0:
        ncpu = multiprocessing.cpu_count()
        threads = max(1, int(ncpu * 0.75))
        print(f"Detected {ncpu} CPUs, using {threads} threads")

    def grid_point_cloud(pc_fn):
        # Grid a pc_align point cloud with point2dem using output prefix <pc_fn without extension>
        grid_args = ['--tr', str(tr),
                     '--threads', str(threads),
                     pc_fn,
                     '-o', os.path.splitext(pc_fn)[0]]
        run_cmd(os.path.join(asp_dir, 'point2dem'), grid_args)

    def get_stable_surface_dem():
        # Write (once) a copy of dem_fn masked to stable surfaces and return its file name
        dem_ss_fn = os.path.splitext(dem_fn)[0] + '_ss.tif'
        if not os.path.exists(dem_ss_fn):
            print('Masking DEM to stable surfaces for alignment...')
            dem = xdem.DEM(dem_fn, load_data=True)
            # Resample the mask onto the DEM grid (nearest keeps it binary) so set_mask gets matching shapes
            ss_mask = gu.Raster(ss_mask_fn, load_data=True).reproject(dem, resampling='nearest')
            dem.set_mask(~(ss_mask == 1))
            save_raster(dem, dem_ss_fn)
        return dem_ss_fn

    def apply_transform(input_dem, transform_txt, output_prefix):
        # Apply a saved transform with pc_align (0 iterations) and grid the transformed points
        pc_out = output_prefix + '-trans_source.tif'
        grid_out = output_prefix + '-trans_source-DEM.tif'
        if os.path.exists(grid_out):
            print(f'Output already exists for {output_prefix}, skipping transform.')
            return grid_out
        align_args = [
            '--max-displacement', '-1',
            '--num-iterations', '0',
            '--threads', str(threads),
            '--initial-transform', transform_txt,
            '--save-transformed-source-points',
            '-o', output_prefix,
            refdem_fn, input_dem
        ]
        run_cmd(os.path.join(asp_dir, 'pc_align'), align_args)
        grid_point_cloud(pc_out)
        return grid_out

    # Select DEM to use for fitting the alignment
    dem_coreg_fn = get_stable_surface_dem() if stable_only else dem_fn

    print('Aligning the DEMs in two rounds:\n'
          '1) ICP for translation and rotation\n'
          '2) Nuth and Kaab for subpixel translation\n')

    # Round 1: ICP
    out_prefix_icp = os.path.join(out_dir, 'run-icp')
    out_transform_icp_fn = out_prefix_icp + '-transform.txt'
    out_dem_icp_fn = out_prefix_icp + '-trans_source-DEM.tif'

    if not os.path.exists(out_transform_icp_fn):
        print('Running ICP alignment...')
        align_args = [
            '--max-displacement', str(max_displacement),
            '--threads', str(threads),
            '--highest-accuracy',
            '--alignment-method', 'point-to-plane',
            '--save-transformed-source-points',
            '-o', out_prefix_icp,
            refdem_fn, dem_coreg_fn
        ]
        run_cmd(os.path.join(asp_dir, 'pc_align'), align_args)
    else:
        print('ICP transform already exists, skipping pc_align.')
    # Grid separately so an earlier run interrupted after pc_align is completed
    if not os.path.exists(out_dem_icp_fn):
        grid_point_cloud(out_prefix_icp + '-trans_source.tif')

    # Round 2: Nuth and Kaab
    print('Running Nuth and Kaab alignment...')
    out_dem_nk_fn = os.path.join(out_dir, 'run-nk-DEM.tif')
    out_transform_nk_fn = os.path.join(out_dir, 'run-nk-transform.txt')

    # The NK DEM is only written when stable_only=False, so it must exist too in that case
    nk_done = os.path.exists(out_transform_nk_fn) and (stable_only or os.path.exists(out_dem_nk_fn))
    if not nk_done:
        dem_icp = xdem.DEM(out_dem_icp_fn, load_data=True)
        refdem = xdem.DEM(refdem_fn, load_data=True).reproject(dem_icp)

        nk = xdem.coreg.NuthKaab().fit(refdem, dem_icp)
        if not stable_only:
            save_raster(nk.apply(dem_icp), out_dem_nk_fn)

        shifts = nk._meta['outputs']['affine']
        tx, ty, tz = shifts['shift_x'], shifts['shift_y'], shifts['shift_z']
        # NOTE(review): this matrix holds xdem shifts in the DEM's projected CRS, but it is later
        # fed to pc_align --initial-transform, whose transforms for geo-referenced inputs are
        # (per ASP docs, to be confirmed) expressed in ECEF. Verify frame and sign convention
        # before trusting the stable_only=True output.
        transform_matrix = np.array([
            [1, 0, 0, tx],
            [0, 1, 0, ty],
            [0, 0, 1, tz],
            [0, 0, 0, 1]
        ])
        np.savetxt(out_transform_nk_fn, transform_matrix, fmt="%.6f")
        print("Nuth and Kaab transform matrix saved:", out_transform_nk_fn)
        print(transform_matrix)

    # Apply both transforms to full DEM if stable_only was used
    if stable_only:
        print('Applying ICP and NK transforms to full DEM...')
        # Step 1: Apply ICP to full DEM
        out_prefix_full_icp = os.path.join(out_dir, 'run-full-icp')
        full_icp_out = apply_transform(dem_fn, out_transform_icp_fn, out_prefix_full_icp)

        # Step 2: Apply NK to ICP-transformed DEM
        out_prefix_final = os.path.join(out_dir, 'run-full-final')
        final_aligned_dem = apply_transform(full_icp_out, out_transform_nk_fn, out_prefix_final)
        return final_aligned_dem

    return out_dem_nk_fn


## Reproject rasters to optimal UTM zone

In [None]:
# Put every input raster in the UTM zone that contains it. Each name is rebound to the
# reprojected file, or left unchanged if the raster was already in that zone.
dem_fn, ortho_fn, refdem_fn = [reproject_to_utm(fn, out_path) for fn in (dem_fn, ortho_fn, refdem_fn)]

In [None]:
# Quick look: coregister the full DEM (no stable-surface masking) with ICP followed by
# Nuth and Kaab, then map the remaining elevation difference. Nothing is saved here.
dem = xdem.DEM(dem_fn)
refdem = xdem.DEM(refdem_fn).reproject(dem)

coreg = xdem.coreg.CoregPipeline([xdem.coreg.ICP(), xdem.coreg.NuthKaab()]).fit(refdem, dem)
dem_coreg = coreg.apply(dem)

ddem = dem_coreg - refdem

fig, ax = plt.subplots(figsize=(8, 6))
ddem.plot(ax=ax, cmap='coolwarm_r', vmin=-5, vmax=5)
ax.set_title('dDEM after ICP + Nuth and Kaab (all surfaces)')
plt.show()

## Construct land cover masks

In [None]:
# Classify snow / vegetation / stable surfaces from the orthomosaic; masks and figure go to out_path
construct_land_cover_masks(dem_fn=dem_fn, ortho_fn=ortho_fn, out_path=out_path)

## Coregister DEM to reference DEM

In [None]:
import geopandas as gpd

# Define outputs
dem_aligned_fn = os.path.join(out_path, 'DEM_aligned.tif')
ddem_fn = os.path.join(out_path, 'dDEM.tif')
ss_mask_fn = os.path.join(out_path, 'mask_stable_surfaces.tif')

# SNOTEL station used to set the vertical offset of the dDEM (WGS84 longitude/latitude)
SNOTEL_LON, SNOTEL_LAT = -115.66587829589844, 43.93199920654297
# Snow depth at the SNOTEL station on the acquisition date — presumably metres like the dDEM; TODO confirm
SNOTEL_SNOW_DEPTH = 1.32

if not os.path.exists(dem_aligned_fn):

    # Load inputs (reference DEM resampled onto the DEM grid)
    refdem_orig_fn = os.path.join(data_path, 'refdem', 'MCS_REFDEM_WGS84_filled.tif')
    dem = xdem.DEM(dem_fn)
    refdem = xdem.DEM(refdem_orig_fn).reproject(dem)

    # Mask DEM to stable surfaces so the coregistration is fit only there
    # (nearest-neighbour resampling keeps the mask binary)
    ss_mask = gu.Raster(ss_mask_fn, load_data=True).reproject(dem, resampling='nearest')
    dem_ss = dem.copy()
    dem_ss.set_mask(~(ss_mask == 1).data.filled(False))

    # Coregister: fit on stable surfaces, apply to the full DEM
    coreg = (xdem.coreg.CoregPipeline([xdem.coreg.Deramp(),
                                       xdem.coreg.LZD(),
                                       xdem.coreg.NuthKaab()])
             .fit(refdem, dem_ss))
    dem_aligned = coreg.apply(dem)

    # Calculate differential DEM (dDEM)
    ddem = dem_aligned - refdem

    # Vertical adjustment: shift so the dDEM equals the SNOTEL snow depth at the station
    snotel_gdf = gpd.GeoDataFrame(geometry=gpd.points_from_xy([SNOTEL_LON], [SNOTEL_LAT]),
                                  crs='EPSG:4326').to_crs(dem.crs)
    snotel_pt = snotel_gdf.geometry.iloc[0]
    ddem_snotel = ddem.interp_points(points=(snotel_pt.x, snotel_pt.y))[0]
    if np.ma.is_masked(ddem_snotel) or not np.isfinite(ddem_snotel):
        raise ValueError('dDEM has no valid value at the SNOTEL station; cannot apply vertical adjustment')
    zshift = SNOTEL_SNOW_DEPTH - ddem_snotel
    print(f'Vertical shift from SNOTEL: {zshift:.3f}')
    # Shift the aligned DEM (not the raw input DEM) so the saved DEM and dDEM stay consistent
    dem_aligned = dem_aligned + zshift
    ddem = ddem + zshift

    # Save aligned DEM and dDEM to file
    save_raster(ddem, ddem_fn)
    save_raster(dem_aligned, dem_aligned_fn)

    # plot results
    fig, ax = plt.subplots(figsize=(8, 6))
    ddem.plot(ax=ax, cmap='coolwarm_r', vmin=-5, vmax=5)
    ax.set_title('dDEM after coregistration and SNOTEL adjustment')
    plt.show()



## Calculate snow depths and stable surface errors

In [None]:
# Define outputs
ddem_snow_fn = os.path.join(out_path, 'dDEM_snow.tif')
ddem_ss_fn = os.path.join(out_path, 'dDEM_stable.tif')
fig_fn = os.path.join(out_path, 'dDEM_results.png')

if not os.path.exists(ddem_snow_fn):

    # Load inputs. Masks are resampled onto the dDEM grid (nearest keeps them binary)
    # instead of onto a `dem` variable left over from an earlier cell.
    ddem = xdem.DEM(ddem_fn)
    snow_mask = gu.Raster(os.path.join(out_path, 'mask_snow.tif')).reproject(ddem, resampling='nearest') == 1
    ss_mask = gu.Raster(os.path.join(out_path, 'mask_stable_surfaces.tif')).reproject(ddem, resampling='nearest') == 1

    # Snow depths: dDEM restricted to snow-covered pixels
    ddem_snow = ddem.copy()
    ddem_snow.set_mask(~snow_mask.data.filled(False))

    # Stable surface errors: dDEM restricted to snow- and vegetation-free pixels
    ddem_ss = ddem.copy()
    ddem_ss.set_mask(~ss_mask.data.filled(False))

    # save outputs to file
    for raster, raster_fn in [[ddem_snow, ddem_snow_fn], [ddem_ss, ddem_ss_fn]]:
        save_raster(raster, raster_fn)

    # plot maps (top) and histograms (bottom); .compressed() keeps only unmasked pixels
    # so masked pixels are not counted in the histograms
    ss_median = np.ma.median(ddem_ss.data)
    snow_median = np.ma.median(ddem_snow.data)
    fig, ax = plt.subplots(2, 2, figsize=(12, 5), gridspec_kw={'height_ratios': [2, 1]})

    ddem_ss.plot(ax=ax[0, 0], cmap='coolwarm_r', vmin=-5, vmax=5)
    ax[0, 0].set_title('Stable surfaces: dDEM')
    ax[1, 0].hist(ddem_ss.data.compressed(), bins=np.linspace(-5, 5, 100),
                  facecolor='gray', edgecolor='gray', alpha=0.8)
    ax[1, 0].axvline(ss_median, color='k', label='Median')
    ax[1, 0].set_xlim(-5, 5)
    ax[1, 0].set_xlabel('dDEM')
    ax[1, 0].legend()

    ddem_snow.plot(ax=ax[0, 1], cmap='Blues', vmin=-1, vmax=5)
    ax[0, 1].set_title('Snow-covered: snow depth')
    ax[1, 1].hist(ddem_snow.data.compressed(), bins=np.linspace(0, 5, 100),
                  facecolor='blue', edgecolor='blue', alpha=0.8)
    ax[1, 1].axvline(snow_median, color='k', label='Median')
    ax[1, 1].set_xlim(0, 5)
    ax[1, 1].set_xlabel('Snow depth')
    ax[1, 1].legend()

    fig.savefig(fig_fn, dpi=300, bbox_inches='tight')
    print('Figure saved to file:', fig_fn)
    plt.show()


In [None]:
# Stable-surface error summary: median (bias) and NMAD of the dDEM over stable terrain.
# Reload from file so this cell also works on a fresh kernel when the cell above skipped its work.
ddem_ss = xdem.DEM(ddem_ss_fn)
np.ma.median(ddem_ss.data), gu.stats.nmad(ddem_ss.data)