# Pipeline for generating snow depth maps from a source DEM and a reference DEM

In [None]:
import os, glob
import geedim as gd
import ee
import matplotlib.pyplot as plt
import xarray as xr
import rioxarray as rxr
import numpy as np
from tqdm.auto import tqdm
import xdem
import geoutils as gu
import pyproj
from scipy.stats import median_abs_deviation as MAD
import math

## Define paths to DEMs

In [None]:
# Define site name and source DEM date for convenience
site_name = 'MCS'
sourcedem_date = '20240420'
data_path = f'/Users/rdcrlrka/Research/PhD/SkySat-Stereo/study-sites/{site_name}/'
# Reference DEM file name follows the '<site>_REFDEM_WGS84.tif' convention
refdem_fn = os.path.join(data_path, 'refdem', f'{site_name}_REFDEM_WGS84.tif')

# Locate the source DEM. Sort the matches so the choice is deterministic when several files match,
# and fail with a clear message (instead of an IndexError) when none do.
sourcedem_fns = sorted(glob.glob(os.path.join(data_path, sourcedem_date, '*DEM_masked.tif')))
if not sourcedem_fns:
    raise FileNotFoundError(
        f'Source DEM not found in {os.path.join(data_path, sourcedem_date)}, check path to file before continuing.')
sourcedem_fn = sourcedem_fns[0]

# Define path for output snow depth images
out_dir = '/Users/rdcrlrka/Research/PhD/SkySat-Stereo/snow_depth_maps/'

# Check that the reference DEM exists (warning only; it is needed for coregistration later)
if not os.path.exists(refdem_fn):
    print('Reference DEM not found, check path to file before continuing.')
# Create the output directory, including any missing parent directories
if not os.path.exists(out_dir):
    os.makedirs(out_dir)
    print('Created directory for output files:', out_dir)

## Construct stable surface mask for source DEM

In [None]:
def convert_wgs_to_utm(lon: float, lat: float):
    """
    Return the UTM EPSG code whose zone contains a WGS84 lon/lat point.

    The zone number is derived from the longitude (6-degree zones, wrapped into 1..60 and
    zero-padded to two digits). The hemisphere prefix is 326 (north, lat >= 0) or 327 (south).

    Parameters
    ----------
    lon: float
        longitude coordinate
    lat: float
        latitude coordinate

    Returns
    ----------
    epsg_code: str
        optimal UTM zone, e.g. "EPSG:32606"
    """
    zone_number = math.floor((lon + 180) / 6) % 60 + 1
    hemisphere = '326' if lat >= 0 else '327'
    return f'EPSG:{hemisphere}{zone_number:02d}'

def query_gee_for_image(dem_fn, dem_date, site_name, out_dir):
    """
    Query Google Earth Engine for the Sentinel-2 SR image closest in time to a DEM and download it.

    The COPERNICUS/S2_SR_HARMONIZED collection is searched within +/- two weeks of the DEM date over
    the DEM's lat/lon bounding box (geedim search with mask=True and fill_portion=70). The image whose
    date is nearest to the DEM date, before or after it, is downloaded at 10 m in the UTM zone of the
    DEM centroid. If the output file already exists, it is reused instead of downloaded.

    Parameters
    __________
    dem_fn: str
        path to the DEM file; its bounds are used for querying and clipping imagery
    dem_date: str
        date of the DEM (format = 'YYYYMMDD' or 'YYYY-MM-DD')
    site_name: str
        site name, used as the prefix of the output file name
    out_dir: str
        path to the directory where the image will be saved

    Returns
    __________
    im_xr: xarray.Dataset
        image with one data variable per reflectance band, divided by 1e4
    crs: str
        EPSG code of the UTM zone used for the download, e.g. "EPSG:32611"

    Raises
    __________
    ValueError
        if no image matches the search criteria
    """

    # -----Authenticate and initialize GEE (authenticate only if initialization fails)
    try:
        ee.Initialize()
    except Exception:
        ee.Authenticate()
        ee.Initialize()

    # -----Load DEM and grab lat lon image bounds (file is closed afterwards)
    with xr.open_dataset(dem_fn) as dem:
        dem_bounds = dem.rio.reproject('EPSG:4326').rio.bounds()

    # -----Estimate best UTM zone from the centroid of the bounds
    centroid_lon = (dem_bounds[0] + dem_bounds[2]) / 2
    centroid_lat = (dem_bounds[1] + dem_bounds[3]) / 2
    crs = convert_wgs_to_utm(centroid_lon, centroid_lat)
    print(f'Best UTM zone = {crs}')

    # -----Reformat image bounds for image querying and clipping (closed GeoJSON ring)
    region = {'type': 'Polygon',
              'coordinates': [[[dem_bounds[0], dem_bounds[1]],
                              [dem_bounds[2], dem_bounds[1]],
                              [dem_bounds[2], dem_bounds[3]],
                              [dem_bounds[0], dem_bounds[3]],
                              [dem_bounds[0], dem_bounds[1]]
                             ]]
             }

    # -----Define the start and end dates for query (within two weeks of DEM date)
    dem_date = dem_date.replace('-', '')
    dem_dt = np.datetime64(f'{dem_date[0:4]}-{dem_date[4:6]}-{dem_date[6:8]}')
    start_date = str(dem_dt - np.timedelta64(2, 'W'))
    end_date = str(dem_dt + np.timedelta64(2, 'W'))

    # -----Query GEE for imagery
    im_col = gd.MaskedCollection.from_name('COPERNICUS/S2_SR_HARMONIZED').search(start_date=start_date,
                                                                                 end_date=end_date,
                                                                                 region=region,
                                                                                 mask=True,
                                                                                 fill_portion=70)
    im_col_ids = np.array(im_col.ee_collection.aggregate_array('system:id').getInfo())
    if im_col_ids.size == 0:
        raise ValueError(f'No Sentinel-2 images found between {start_date} and {end_date} over the DEM bounds.')

    def sts_to_date(sts):
        return ee.Date(sts).format('yyyy-MM-dd')
    im_col_dts = np.array(im_col.ee_collection.aggregate_array('system:time_start').map(sts_to_date).getInfo(),
                          dtype='datetime64[D]')

    # -----Download the closest image in time
    # Use the absolute time difference so images before AND after the DEM date are considered;
    # argmin returns the first index on ties.
    Iclosest = int(np.argmin(np.abs(dem_dt - im_col_dts)))
    im_id = im_col_ids[Iclosest]
    im_dt = im_col_dts[Iclosest]
    print(f'Closest image date = {im_dt}')
    # Create new masked image from ID
    im = gd.MaskedImage.from_id(im_id, mask=True, region=region)
    # Download to file inside out_dir
    out_fn = os.path.join(out_dir, f"{site_name}_{str(im_dt).replace('-','')}_S2_SR_HARMONIZED.tif")
    refl_bands = im.refl_bands
    if not os.path.exists(out_fn):
        im.download(out_fn, region=region, scale=10, crs=crs, bands=refl_bands, dtype='int16')
        print('Sentinel-2 image saved to file:', out_fn)
    else:
        print('Sentinel-2 image file already exists in directory, loading...')

    # -----Open image and restructure data variables: one variable per band, named after the band
    im_xr = xr.open_dataset(out_fn)
    band_data = im_xr['band_data']
    im_xr_adj = xr.Dataset()
    for i, band_name in enumerate(refl_bands):
        im_xr_adj[band_name] = band_data.isel(band=i)
    im_xr_adj.attrs = im_xr.attrs
    for coord in im_xr.coords:
        im_xr_adj[coord] = im_xr[coord]
    im_xr_adj = im_xr_adj / 1e4  # account for reflectance scalar

    return im_xr_adj, crs

def create_stable_surface_mask(im_xr, dem_date, out_fn, crs, plot=True, ndsi_threshold=0.4):
    """
    Create a stable surface mask by thresholding the NDSI of a Sentinel-2 SR image.

    NDSI is computed as (B3 - B11) / (B3 + B11) and stored in ``im_xr['NDSI']``; the input
    dataset is modified in place. Pixels with NDSI <= ``ndsi_threshold`` are flagged 1 (stable).
    All other pixels, including those where NDSI is NaN, are flagged 0. The mask is saved as a
    uint8 GeoTIFF and read back from disk.

    Parameters
    ----------
    im_xr: xarray.Dataset
        input Sentinel-2 SR image with variables B3 and B11 (and B2, B4 when plot=True)
    dem_date: str
        observation date of DEM (not used in the computation)
    out_fn: str
        file name of output stable surfaces file; the figure is saved next to it as .png
    crs: str
        coordinate reference system of output file (e.g., "EPSG:4326")
    plot: bool
        whether to plot (and save a figure of) the results
    ndsi_threshold: float
        NDSI value at or below which a pixel is flagged as stable (default 0.4)

    Returns
    ----------
    stable_surfaces: xarray.DataArray
        mask read back from ``out_fn`` with rioxarray (1 = stable, 0 = not stable)
    """
    # Add NDSI band
    im_xr['NDSI'] = (im_xr['B3'] - im_xr['B11']) / (im_xr['B3'] + im_xr['B11'])

    # Threshold: comparisons with NaN are False, so NaN pixels become 0 (not stable).
    # Cast to uint8 so writing the GeoTIFF does not require 64-bit integer support in GDAL.
    ss_xr = xr.where(im_xr['NDSI'] <= ndsi_threshold, 1, 0).astype('uint8')

    # Calculate statistics
    num_pixels = ss_xr.size
    num_pixels_stable = int(np.count_nonzero(ss_xr.data == 1))
    # Pixel area from the x spacing (assumes square pixels in a metre-based CRS)
    res_m2 = (ss_xr.x.data[1] - ss_xr.x.data[0]) ** 2
    perc_stable = num_pixels_stable / num_pixels * 100
    area_stable_km2 = num_pixels_stable * res_m2 / 1e6
    print(f'Stable surfaces = {np.round(perc_stable,2)} % of image, {np.round(area_stable_km2, 2)} km2')

    # Write CRS to image
    ss_xr = ss_xr.rio.write_crs(crs)

    # Save to file
    ss_xr.rio.to_raster(out_fn)
    print('Stable surfaces mask saved to file:', out_fn)

    ss_xr = rxr.open_rasterio(out_fn)

    # Plot
    if plot:
        im_extent = (np.min(im_xr.x.data), np.max(im_xr.x.data), np.min(im_xr.y.data), np.max(im_xr.y.data))
        ss_extent = (np.min(ss_xr.x.data), np.max(ss_xr.x.data), np.min(ss_xr.y.data), np.max(ss_xr.y.data))
        fig, ax = plt.subplots(1, 3, figsize=(12,6))
        ax[0].imshow(np.dstack([im_xr.B4.data, im_xr.B3.data, im_xr.B2.data]), extent=im_extent)
        ax[0].set_title('Raw image')
        ax[1].imshow(im_xr.NDSI.data, clim=(-1,1), cmap='Blues', extent=im_extent)
        ax[1].set_title('NDSI')
        # The re-read mask has a leading band dimension; plot the first band
        ax[2].imshow(ss_xr.data[0], clim=(0,1), cmap='Greys', extent=ss_extent)
        ax[2].set_title('Stable surfaces')
        # Save before showing: some backends release the figure once it has been shown
        fig_fn = os.path.splitext(out_fn)[0] + '.png'
        fig.savefig(fig_fn, dpi=250, bbox_inches='tight')
        print('Figure saved to file:', fig_fn)
        plt.show()

    return ss_xr
    

In [None]:
# Reuse a previously saved stable surfaces mask for this DEM when available; otherwise
# build it from the Sentinel-2 image closest in time to the source DEM.
ss_fn = os.path.join(out_dir, f'{os.path.splitext(os.path.basename(sourcedem_fn))[0]}_stable_surfaces_mask.tif')
if not os.path.exists(ss_fn):
    # Query GEE for closest Sentinel-2 image in time, then threshold it into a mask
    im_xr, crs = query_gee_for_image(sourcedem_fn, sourcedem_date, site_name, out_dir)
    ss_xr = create_stable_surface_mask(im_xr, sourcedem_date, ss_fn, crs=crs, plot=True)
else:
    print('Stable surfaces mask already exists for DEM')
    ss_xr = rxr.open_rasterio(ss_fn)

## Coregister source DEM to reference DEM grid

In [None]:
def create_coreg_object(coreg_name):
    """
    Instantiate an xdem coregistration object from its class name(s).

    Parameters
    ----------
    coreg_name: str or list/tuple of str
        name of a class in ``xdem.coreg`` (e.g. 'NuthKaab'), or a sequence of such names,
        which are instantiated and chained, in the given order, into an ``xdem.coreg.CoregPipeline``

    Returns
    ----------
    coreg_obj: xdem.coreg.Coreg, xdem.coreg.CoregPipeline or None
        the instantiated object, or None if ``coreg_name`` is neither a string nor a list/tuple

    Raises
    ----------
    ValueError
        if a name is not found in ``xdem.coreg``, or if an empty list/tuple is given
    """
    def _instantiate(name):
        # Only the attribute lookup is guarded, so errors raised by the constructor itself propagate
        try:
            coreg_class = getattr(xdem.coreg, name)
        except AttributeError:
            raise ValueError(f"Coregistration method '{name}' not found.") from None
        return coreg_class()

    if isinstance(coreg_name, str):
        return _instantiate(coreg_name)
    if isinstance(coreg_name, (list, tuple)):
        if not coreg_name:
            raise ValueError('No coregistration methods given.')
        return xdem.coreg.CoregPipeline([_instantiate(name) for name in coreg_name])
    print('coreg_method format not recognized, exiting...')
    return None

def _valid_values(diff):
    """Return the unmasked, finite values of a raster as a 1-D array."""
    values = np.ma.compressed(diff.data)
    return values[np.isfinite(values)]


def _stable_surface_stats(diff, stable):
    """
    Return the median and NMAD of a difference raster over stable surfaces.

    Parameters
    ----------
    diff: geoutils.Raster
        difference raster on the same grid as ``stable``
    stable: numpy.ndarray of bool
        True where the pixel is a stable surface

    Returns
    ----------
    median, nmad: float
        statistics over valid (unmasked, finite) stable pixels; NaN when there are none
    """
    # Mask pixels that are invalid in the difference OR not stable
    masked = np.ma.masked_array(diff.data, mask=np.ma.getmaskarray(diff.data) | ~stable)
    values = np.ma.compressed(masked)
    values = values[np.isfinite(values)]
    if values.size == 0:
        return np.nan, np.nan
    return float(np.median(values)), float(xdem.spatialstats.nmad(values))


def _plot_coreg_differences(panels):
    """
    Plot each difference raster (left) next to its histogram (right) with a shared colour scale.

    Parameters
    ----------
    panels: list of (geoutils.Raster, str)
        difference rasters and the titles of their map panels, one row per entry
    """
    # Determine color limits: symmetric around 0 so the diverging colormap is centred
    vmin = np.nanmin([diff.data.min() for diff, _ in panels])
    vmax = np.nanmax([diff.data.max() for diff, _ in panels])
    vmax_abs = np.nanmax([np.abs(vmin), np.abs(vmax)])

    fig, axes = plt.subplots(len(panels), 2, figsize=(12,16), squeeze=False)
    for (diff, title), (map_ax, hist_ax) in zip(panels, axes):
        diff.plot(cmap="coolwarm_r", vmin=-vmax_abs, vmax=vmax_abs, ax=map_ax)
        map_ax.set_title(title)
        # Adjust map units to km
        map_ax.set_xticks(map_ax.get_xticks())
        map_ax.set_xticklabels(np.divide(map_ax.get_xticks(), 1e3).astype(str))
        map_ax.set_yticks(map_ax.get_yticks())
        map_ax.set_yticklabels(np.divide(map_ax.get_yticks(), 1e3).astype(str))
        map_ax.set_xlabel('Easting [km]')
        map_ax.set_ylabel('Northing [km]')
        # Histogram of valid pixels only: a 2-D array would be drawn as one histogram per column
        hist_ax.hist(_valid_values(diff), bins=50)
        hist_ax.set_xlabel('Difference [m]')
        hist_ax.set_ylabel('Counts')

    fig.tight_layout()
    plt.show()


def coregister_difference(ref_dem_fn=None, source_dem_fn=None, ss_mask_fn=None, coreg_method='NuthKaab'):
    """
    Coregister a source DEM to a reference DEM over stable surfaces and difference the two.

    The source DEM and the stable surface mask are reprojected to the reference DEM grid. The
    coregistration is fit using only stable pixels. The median source - reference difference over
    stable surfaces after coregistration is then subtracted. Stable-surface median and NMAD are
    reported before coregistration, after it, and after the median removal, and all three
    difference maps are plotted.

    Parameters
    ----------
    ref_dem_fn: str
        path to the reference DEM (first band is used)
    source_dem_fn: str
        path to the source DEM to be aligned (first band is used)
    ss_mask_fn: str
        path to the stable surface mask raster (1 = stable, as written by create_stable_surface_mask)
    coreg_method: str or list of str
        coregistration method name(s) passed to create_coreg_object

    Returns
    ----------
    diff_after: geoutils.Raster
        aligned source DEM minus reference DEM
    diff_after_adj: geoutils.Raster
        ``diff_after`` minus its median over stable surfaces
    Returns None instead if ``coreg_method`` has an unrecognized format.
    """
    # Load DEMs and stable surface mask
    ref_dem = xdem.DEM(gu.Raster(ref_dem_fn, load_data=True, bands=1))
    tba_dem = xdem.DEM(gu.Raster(source_dem_fn, load_data=True, bands=1))
    stable = (gu.Raster(ss_mask_fn, load_data=True) == 1)  # True where the pixel is stable

    # Reproject source DEM and stable surface mask to reference DEM grid.
    # Pixels outside the mask footprint are filled with 0 (False), i.e. treated as not stable.
    tba_dem = tba_dem.reproject(ref_dem)
    stable = stable.reproject(ref_dem, nodata=0)
    stable_arr = np.ma.filled(stable.data, False).astype(bool)

    # Set up the coregistration object
    coreg_obj = create_coreg_object(coreg_method)
    if coreg_obj is None:
        return None

    # Differences before coregistration and their stable surface stats
    diff_before = tba_dem - ref_dem
    diff_before_ss_median, diff_before_ss_nmad = _stable_surface_stats(diff_before, stable_arr)

    # Fit the coregistration on stable pixels only, then apply it to the source DEM
    coreg_obj.fit(ref_dem, tba_dem, inlier_mask=stable_arr)
    aligned_dem = coreg_obj.apply(tba_dem)

    # Differences after coregistration and their stable surface stats
    diff_after = aligned_dem - ref_dem
    diff_after_ss_median, diff_after_ss_nmad = _stable_surface_stats(diff_after, stable_arr)

    # Subtract the median difference over stable surfaces (skipped if no stable pixels are valid)
    if np.isfinite(diff_after_ss_median):
        diff_after_adj = diff_after - diff_after_ss_median
    else:
        print('No valid stable surface pixels, median difference not removed.')
        diff_after_adj = diff_after.copy()
    diff_after_adj_ss_median, diff_after_adj_ss_nmad = _stable_surface_stats(diff_after_adj, stable_arr)

    # Plotting
    _plot_coreg_differences([
        (diff_before, f'Difference before coreg. \nSS median = {np.round(diff_before_ss_median, 3)}, SS NMAD = {np.round(diff_before_ss_nmad, 3)}'),
        (diff_after, f'Difference after coreg. \nSS median = {np.round(diff_after_ss_median, 3)}, SS NMAD = {np.round(diff_after_ss_nmad, 3)}'),
        (diff_after_adj, f'Difference after coreg. and \nremoving median SS difference\nSS median = {np.round(diff_after_adj_ss_median, 3)}, SS NMAD = {np.round(diff_after_adj_ss_nmad, 3)}'),
    ])

    return diff_after, diff_after_adj


In [None]:
# -----Coregister source DEM and reference DEM
# Returns the raw post-coregistration difference and the stable-surface-median-adjusted version
diff, diff_adj = coregister_difference(
    ref_dem_fn=refdem_fn,        # reference DEM
    source_dem_fn=sourcedem_fn,  # DEM to be aligned
    ss_mask_fn=ss_fn,            # stable surfaces mask
    coreg_method='NuthKaab',
)


In [None]:
diff.plot(cmap='coolwarm_r', vmin=-10, vmax=10)  # post-coregistration difference, colour scale limited to +/-10

In [None]:
diff_adj.plot(cmap='coolwarm_r', vmin=-10, vmax=10)  # median-adjusted difference, colour scale limited to +/-10

In [None]:
# Step-by-step coregistration, kept separate so the fitted coreg object can be inspected
ref_dem = xdem.DEM(gu.Raster(refdem_fn, load_data=True, bands=1))
tba_dem = xdem.DEM(gu.Raster(sourcedem_fn, load_data=True, bands=1))
ss_raster = gu.Raster(ss_fn, load_data=True)
ss_mask = (ss_raster == 0)  # True where the mask value is 0 (not stable)
stable_mask = (ss_raster == 1)  # True where the mask value is 1 (stable)

# Reproject source DEM and masks to reference DEM grid (pixels outside the mask footprint become False)
tba_dem = tba_dem.reproject(ref_dem)
ss_mask = ss_mask.reproject(ref_dem, nodata=0)
stable_mask = stable_mask.reproject(ref_dem, nodata=0)

# Set up the coregistration object
coreg_obj = create_coreg_object("NuthKaab")

# Fit the coregistration using only stable pixels as inliers
coreg_obj.fit(ref_dem, tba_dem, inlier_mask=np.ma.filled(stable_mask.data, False).astype(bool))

In [None]:
# Inspect the metadata stored on the fitted coregistration object (shown as the cell output)
coreg_obj.meta

In [None]:
# Inspect the geotransform of the source DEM after its reprojection to the reference DEM grid
tba_dem.transform