# Pipeline for generating snow depth maps from a source DEM and a reference DEM

In [None]:
import os, glob
import geedim as gd
import ee
import matplotlib.pyplot as plt
import xarray as xr
import rioxarray as rxr
import numpy as np
from tqdm.auto import tqdm
import xdem
import geoutils as gu
import pyproj
from scipy.stats import median_abs_deviation as MAD
import math

## Define paths to DEMs

In [None]:
# Define site name and source DEM date for convenience
site_name = 'JacksonCreek'
sourcedem_date = '20240420'
data_path = f'/Volumes/LaCie/raineyaberle/Research/PhD/SkySat-Stereo/study-sites/{site_name}/'
# NOTE(review): `refdem_fn` is used by the coregistration cell below but is commented out here.
# The path shown refers to the 'Banner' site; set it for the current site before running that cell.
# refdem_fn = os.path.join(data_path, 'refdem', 'Banner_NASADEM_clip_buffer_2km_EPSG32611+5773.tif')

# Locate the source DEM for the given date. glob only returns existing files, so an empty
# result means the DEM is missing. Sort so the choice is deterministic if several files match.
sourcedem_pattern = os.path.join(data_path, sourcedem_date, '*DEM.tif')
sourcedem_fns = sorted(glob.glob(sourcedem_pattern))
if not sourcedem_fns:
    raise FileNotFoundError(f'Source DEM not found ({sourcedem_pattern}), check path to file before continuing.')
if len(sourcedem_fns) > 1:
    print('Multiple source DEMs found, using:', sourcedem_fns[0])
sourcedem_fn = sourcedem_fns[0]

# Define path for output snow depth images
out_dir = '/Volumes/LaCie/raineyaberle/Research/PhD/SkySat-Stereo/snow_depth_maps/'

# Create the output directory (including any missing parent directories) if needed
if not os.path.exists(out_dir):
    os.makedirs(out_dir)
    print('Created directory for output files:', out_dir)

## Construct stable surface mask for source DEM

In [None]:
# Authenticate and initialize GEE.
# If initialization fails (presumably because no stored credentials are available),
# run the authentication flow and initialize again.
try:
    ee.Initialize()
except Exception:  # not a bare except: lets KeyboardInterrupt / SystemExit propagate
    ee.Authenticate()
    ee.Initialize()

In [None]:
def convert_wgs_to_utm(lon: float, lat: float):
    """
    Pick the UTM projection whose zone contains a WGS84 (lon, lat) point.

    Parameters
    ----------
    lon: float
        longitude of the point in decimal degrees
    lat: float
        latitude of the point in decimal degrees; values >= 0 select the
        northern-hemisphere code (EPSG:326xx), otherwise the southern one (EPSG:327xx)

    Returns
    ----------
    epsg_code: str
        EPSG code of the UTM zone as a string, e.g. "EPSG:32606"
    """
    # Zones are 6 degrees wide and numbered 1-60 eastward from -180; the modulo wraps lon=180 back to zone 1
    zone_number = math.floor((lon + 180) / 6) % 60 + 1
    hemisphere_prefix = '326' if lat >= 0 else '327'
    # Zone numbers are always written with two digits (e.g. 6 -> "06")
    return f'EPSG:{hemisphere_prefix}{zone_number:02d}'

def query_gee_for_image(dem_fn, dem_date, site_name, out_dir):
    """
    Download the Sentinel-2 SR image closest in time to a DEM and load its reflectance bands.

    Queries Google Earth Engine (through geedim) for COPERNICUS/S2_SR_HARMONIZED images within
    two weeks of the DEM date over the DEM footprint (masked, fill_portion=70). It picks the image
    with the smallest absolute time difference to the DEM date and downloads it to `out_dir`,
    unless a file with the same name already exists there. It then returns the reflectance bands
    as an xarray.Dataset.

    Parameters
    __________
    dem_fn: str
        path to the DEM file; its lat/lon bounds define the query and clipping region
    dem_date: str
        date of the DEM (format = 'YYYYMMDD' or 'YYYY-MM-DD')
    site_name: str
        site name, used as the prefix of the downloaded image file name
    out_dir: str
        directory in which the image will be saved

    Returns
    __________
    im_xr: xarray.Dataset
        resulting image, one data variable per reflectance band, divided by 1e4
    crs: str
        EPSG code of the UTM zone used for the download, e.g. "EPSG:32611"

    Raises
    __________
    ValueError
        if no image matches the query
    """

    # -----Load DEM and grab its lat/lon bounds (context manager closes the file afterwards)
    with xr.open_dataset(dem_fn) as dem:
        dem_bounds = dem.rio.reproject('EPSG:4326').rio.bounds()

    # -----Estimate best UTM zone from the centroid of the bounds
    centroid_lon = (dem_bounds[0] + dem_bounds[2]) / 2
    centroid_lat = (dem_bounds[1] + dem_bounds[3]) / 2
    crs = convert_wgs_to_utm(centroid_lon, centroid_lat)
    print(f'Best UTM zone = {crs}')

    # -----Reformat image bounds as a GeoJSON polygon for image querying and clipping
    xmin, ymin, xmax, ymax = dem_bounds
    region = {'type': 'Polygon',
              'coordinates': [[[xmin, ymin],
                               [xmax, ymin],
                               [xmax, ymax],
                               [xmin, ymax],
                               [xmin, ymin]
                               ]]
              }

    # -----Define the start and end dates for query (within two weeks of DEM date)
    dem_date = dem_date.replace('-', '')
    dem_dt = np.datetime64(f'{dem_date[0:4]}-{dem_date[4:6]}-{dem_date[6:8]}')
    start_date = str(dem_dt - np.timedelta64(2, 'W'))
    end_date = str(dem_dt + np.timedelta64(2, 'W'))

    # -----Query GEE for imagery
    im_col = gd.MaskedCollection.from_name('COPERNICUS/S2_SR_HARMONIZED').search(start_date=start_date,
                                                                                 end_date=end_date,
                                                                                 region=region,
                                                                                 mask=True,
                                                                                 fill_portion=70)
    im_col_ids = np.array(im_col.ee_collection.aggregate_array('system:id').getInfo())
    if im_col_ids.size == 0:
        raise ValueError(f'No Sentinel-2 images found between {start_date} and {end_date} '
                         f'for the footprint of {dem_fn}')

    def sts_to_date(sts):
        # Format a 'system:time_start' value as 'yyyy-MM-dd' on the Earth Engine side
        return ee.Date(sts).format('yyyy-MM-dd')
    im_col_dts = np.array(im_col.ee_collection.aggregate_array('system:time_start').map(sts_to_date).getInfo(),
                          dtype='datetime64[D]')

    # -----Download the closest image in time
    # Use the absolute difference so images before AND after the DEM date are considered;
    # minimising the signed difference would always select the latest image.
    Iclosest = int(np.argmin(np.abs(im_col_dts - dem_dt)))
    im_id = im_col_ids[Iclosest]
    im_dt = im_col_dts[Iclosest]
    print(f'Closest image date = {im_dt}')
    # Create new masked image from ID
    im = gd.MaskedImage.from_id(im_id, mask=True, region=region)
    # Download to file inside out_dir (independent of whether out_dir has a trailing separator)
    out_fn = os.path.join(out_dir, f"{site_name}_{str(im_dt).replace('-','')}_S2_SR_HARMONIZED.tif")
    refl_bands = im.refl_bands
    if not os.path.exists(out_fn):
        im.download(out_fn, region=region, scale=10, crs=crs, bands=refl_bands, dtype='int16')
        print('Sentinel-2 image saved to file:', out_fn)
    else:
        print('Sentinel-2 image file already exists in directory, loading...')

    # -----Open image and split the 'band' dimension into one data variable per reflectance band.
    # Data are loaded into memory so the file can be closed before returning.
    with xr.open_dataset(out_fn) as im_ds:
        band_data = im_ds['band_data'].load()
        attrs = dict(im_ds.attrs)
    # drop=True removes the scalar 'band' coordinate; x/y/spatial_ref coordinates are kept as coordinates
    im_xr_adj = xr.Dataset({band_name: band_data.isel(band=i, drop=True) / 1e4  # account for reflectance scalar
                            for i, band_name in enumerate(refl_bands)})
    im_xr_adj.attrs = attrs

    return im_xr_adj, crs

def create_stable_surface_mask(im_xr, dem_date, out_fn, crs, plot=True, ndsi_threshold=0.4):
    """
    Create a stable surface mask by thresholding the NDSI of a Sentinel-2 SR image.

    NDSI is computed as (B3 - B11) / (B3 + B11). Pixels with NDSI <= `ndsi_threshold` are labelled
    stable (1). All other pixels are labelled 0, including pixels where NDSI is NaN, because
    comparisons with NaN are False. The mask is written to `out_fn`. Optionally, the image, the
    NDSI and the mask are plotted and the figure is saved next to `out_fn` as a PNG.

    Note: an 'NDSI' data variable is added to `im_xr` in place.

    Parameters
    ----------
    im_xr: xarray.Dataset
        input Sentinel-2 SR image with (at least) bands B2, B3, B4 and B11 and x/y coordinates
    dem_date: str
        observation date of DEM (currently unused)
    out_fn: str
        file name of output stable surfaces file; the figure uses the same name with a .png extension
    crs: str
        coordinate reference system of output file (e.g., "EPSG:4326")
    plot: bool
        whether to plot results
    ndsi_threshold: float
        maximum NDSI for a pixel to be considered stable (default 0.4)

    Returns
    ----------
    ss_xr: xarray.DataArray
        resulting stable surfaces mask (1 = stable, 0 = not stable)
    """
    # Add NDSI band
    im_xr['NDSI'] = (im_xr['B3'] - im_xr['B11']) / (im_xr['B3'] + im_xr['B11'])

    # Threshold: low NDSI -> stable
    ss_xr = xr.where(im_xr['NDSI'] <= ndsi_threshold, 1, 0)

    # Calculate statistics
    # (pixel area assumes square pixels with x coordinates in metres, e.g. a UTM grid)
    num_pixels = ss_xr.size
    num_pixels_stable = int((ss_xr.data == 1).sum())
    res_m2 = (ss_xr.x.data[1] - ss_xr.x.data[0]) ** 2
    perc_stable = num_pixels_stable / num_pixels * 100
    area_stable_km2 = num_pixels_stable * res_m2 / 1e6
    print(f'Stable surfaces = {np.round(perc_stable,2)} % of image, {np.round(area_stable_km2, 2)} km2')

    # Write CRS to image
    ss_xr = ss_xr.rio.write_crs(crs)

    # Save to file
    ss_xr.rio.to_raster(out_fn)
    print('Stable surfaces mask saved to file:', out_fn)

    # Plot
    if plot:
        im_extent = (np.min(im_xr.x.data), np.max(im_xr.x.data), np.min(im_xr.y.data), np.max(im_xr.y.data))
        ss_extent = (np.min(ss_xr.x.data), np.max(ss_xr.x.data), np.min(ss_xr.y.data), np.max(ss_xr.y.data))
        fig, ax = plt.subplots(1, 3, figsize=(12,6))
        ax[0].imshow(np.dstack([im_xr.B4.data, im_xr.B3.data, im_xr.B2.data]), extent=im_extent)
        ax[0].set_title('Raw image')
        ax[1].imshow(im_xr.NDSI.data, clim=(-1,1), cmap='Blues', extent=im_extent)
        ax[1].set_title('NDSI')
        ax[2].imshow(ss_xr.data, clim=(0,1), cmap='Greys', extent=ss_extent)
        ax[2].set_title('Stable surfaces')
        # Save before showing: with some backends plt.show() closes the figure, leaving a blank file
        fig_fn = os.path.splitext(out_fn)[0] + '.png'
        fig.savefig(fig_fn, dpi=250, bbox_inches='tight')
        print('Figure saved to file:', fig_fn)
        plt.show()

    return ss_xr
    

In [None]:
# Stable surfaces mask file for this source DEM: <DEM file stem>_stable_surfaces_mask.tif in out_dir.
# The GEE query and mask creation are only run when that file does not exist yet.
ss_fn = os.path.join(out_dir, os.path.splitext(os.path.basename(sourcedem_fn))[0] + '_stable_surfaces_mask.tif')
if not os.path.exists(ss_fn):
    # Download (or reuse) the Sentinel-2 image closest in time to the source DEM
    im_xr, crs = query_gee_for_image(sourcedem_fn, sourcedem_date, site_name, out_dir)
    # Threshold it into a stable surfaces mask and save it to ss_fn
    ss_xr = create_stable_surface_mask(im_xr, sourcedem_date, ss_fn, crs=crs, plot=True)
else:
    print('Stable surfaces mask already exists for DEM, skipping...')

## Coregister source DEM to reference DEM grid

In [None]:
def _stable_surface_stats(diff, ss_mask):
    """
    Compute the median and NMAD of a difference raster over stable surfaces.

    Parameters
    ----------
    diff: geoutils.Raster
        elevation difference raster
    ss_mask: geoutils.Mask
        stable surfaces mask (True = stable); it is reprojected onto the grid of `diff`

    Returns
    ----------
    ss_mask: geoutils.Mask
        the mask reprojected onto the grid of `diff`
    med: float
        median of `diff` over stable surfaces, ignoring masked and non-finite values (NaN if none remain)
    nmad: float
        NMAD of `diff` over stable surfaces (xdem.spatialstats.nmad)
    """
    ss_mask = ss_mask.reproject(diff)
    diff_ss = diff[ss_mask]
    # np.ma.median skips masked (nodata) values, which np.median would silently include
    med = np.ma.median(np.ma.masked_invalid(diff_ss))
    med = np.nan if med is np.ma.masked else float(med)
    nmad = xdem.spatialstats.nmad(diff_ss)
    return ss_mask, med, nmad


def _valid_values(raster):
    """Return the unmasked, finite values of a raster as a 1-D array (used for histograms)."""
    return np.ma.compressed(np.ma.masked_invalid(raster.data))


def coregister(coreg_obj, ref_dem, source_dem, ss_mask, vmin=None, vmax=None, fit_on_stable_surfaces=True):
    """
    Coregister a source DEM to a reference DEM and plot elevation differences before and after.

    Parameters
    ----------
    coreg_obj: xdem.coreg.Coreg
        coregistration method, e.g. xdem.coreg.NuthKaab()
    ref_dem: xdem.DEM
        reference DEM
    source_dem: xdem.DEM
        DEM to be aligned, on the same grid as `ref_dem`
    ss_mask: geoutils.Mask
        stable surfaces mask (True = stable)
    vmin, vmax: float or None
        color limits for the difference maps (None = automatic)
    fit_on_stable_surfaces: bool
        if True, pass `ss_mask` as the inlier mask when fitting, so only stable terrain is used
        to estimate the coregistration; if False, fit on all valid pixels

    Returns
    ----------
    diff_after: geoutils.Raster
        aligned source DEM minus reference DEM
    diff_after_adj: geoutils.Raster
        `diff_after` minus its median over stable surfaces
    """
    # Set up figure: top row = difference maps, bottom row = histograms (ax[3] is removed)
    fig, ax = plt.subplots(2, 3, figsize=(12,8))
    ax = ax.flatten()

    # Calculate difference before registration
    diff_before = source_dem - ref_dem
    # NOTE(review): exact-zero differences are treated as nodata - presumably fill values; confirm
    diff_before.set_nodata(0)
    ss_mask, med_before_ss, nmad_before_ss = _stable_surface_stats(diff_before, ss_mask)
    diff_before.plot(cmap="coolwarm_r", vmin=vmin, vmax=vmax, ax=ax[0])
    ax[0].set_title(f'Difference before coreg. \nSS median = {np.round(med_before_ss,3)}, SS NMAD = {np.round(nmad_before_ss,3)}')

    # Apply coregistration; by default fit only on stable surfaces (the mask is now on the
    # grid of diff_before, i.e. the shared DEM grid), so snow-covered terrain does not drive the fit
    inlier_mask = ss_mask if fit_on_stable_surfaces else None
    coreg_obj.fit(ref_dem, source_dem, inlier_mask=inlier_mask)
    aligned_dem = coreg_obj.apply(source_dem)
    diff_after = aligned_dem - ref_dem
    ss_mask, med_after_ss, nmad_after_ss = _stable_surface_stats(diff_after, ss_mask)
    diff_after.plot(cmap="coolwarm_r", vmin=vmin, vmax=vmax, ax=ax[1])
    ax[1].set_title(f'Difference after coreg. \nSS median = {np.round(med_after_ss,3)}, SS NMAD = {np.round(nmad_after_ss,3)}')
    ax[1].set_yticklabels([])
    ax[4].hist(_valid_values(diff_after), bins=50)
    ax[4].set_xlabel('Difference [m]')
    ax[4].set_ylabel('Counts')

    # Subtract the median difference over stable surfaces
    diff_after_adj = diff_after - med_after_ss
    ss_mask, med_after_adj_ss, nmad_after_adj_ss = _stable_surface_stats(diff_after_adj, ss_mask)
    diff_after_adj.plot(cmap="coolwarm_r", vmin=vmin, vmax=vmax, ax=ax[2])
    ax[2].set_title(f'Difference after coreg. and \nremoving median SS difference'
                    f'\nSS median = {np.round(med_after_adj_ss,3)}, SS NMAD = {np.round(nmad_after_adj_ss, 3)}')
    ax[2].set_yticklabels([])
    ax[5].hist(_valid_values(diff_after_adj), bins=50)
    ax[5].set_xlabel('Difference [m]')
    ax[5].set_ylabel('Counts')

    ax[3].remove()
    fig.subplots_adjust(wspace=0.2)
    plt.show()

    return diff_after, diff_after_adj

In [None]:
# Load reference DEM
# NOTE(review): `refdem_fn` must be defined before running this cell; its assignment in the
# path-definition cell is commented out.
ref_dem = xdem.DEM(gu.Raster(refdem_fn, load_data=True, bands=1))
# refdem.set_nodata(0, update_mask=False, update_array=False)

# Load source DEM
tba_dem = xdem.DEM(gu.Raster(sourcedem_fn, load_data=True, bands=1))
# tba_dem.set_nodata(0, update_mask=False, update_array=False)

# Reproject to the reference DEM coordinates, etc.
tba_dem = tba_dem.reproject(ref_dem)

# Optional stable surfaces mask for the reference DEM (1 = stable, same convention as the
# source DEM mask). When None, only the source DEM mask is used.
refdem_ss_fn = None

# Load the stable surfaces mask created for the source DEM (ss_fn, from the mask-creation cell)
# and reproject it onto the reference DEM grid
tba_ss = gu.Raster(ss_fn, load_data=True)
tba_ss = tba_ss.reproject(ref_dem, nodata=0)

# Combine masks: a pixel is stable only where every available mask equals 1;
# masked (nodata) pixels are treated as not stable
combined_mask_array = np.ma.filled(tba_ss.data == 1, False)
if refdem_ss_fn is not None:
    ref_ss = gu.Raster(refdem_ss_fn, load_data=True).reproject(ref_dem, nodata=0)
    combined_mask_array &= np.ma.filled(ref_ss.data == 1, False)
combined_mask = gu.Mask.from_array(combined_mask_array,
                                   transform=tba_ss.transform,
                                   crs=tba_ss.crs,
                                   nodata=0)
combined_mask.set_nodata(0)

# -----Apply Nuth and Kaab registration
print('\nNuth and Kaab coregistration...')
coreg_obj = xdem.coreg.NuthKaab()
diff_nk, diff_adj_nk = coregister(coreg_obj, ref_dem, tba_dem, combined_mask)


## Calculate snow depth: $h_{snow} = h_{source} - h_{ref}$