# Create stable surface masks for DEMs using near-coincident Sentinel-2 imagery and NDSI thresholding

In [None]:
import xdem
import os
import glob
import matplotlib.pyplot as plt
import geopandas as gpd
import ee
import geedim as gd
import xarray as xr
import numpy as np
from tqdm.autonotebook import tqdm
import math

## Define path to data, grab DEM file names

In [None]:
# Directory holding the input DEMs; downloaded imagery and output masks are also written here
data_path = '/Users/raineyaberle/Research/PhD/SnowDEMs/MCS/lidar/'
# Select DEM GeoTIFFs whose file names contain 'RF', in sorted (date-prefixed) order
dem_glob_pattern = os.path.join(data_path, '*RF*.tif')
dem_fns = sorted(glob.glob(dem_glob_pattern))
# Display the matched file names in the notebook
dem_fns

## Initialize Google Earth Engine

In [None]:
# Initialize the Earth Engine client. If initialization fails, run the
# interactive authentication flow and retry. Catch Exception (not a bare
# except) so Ctrl-C / SystemExit are not swallowed and misinterpreted as an
# authentication problem.
try:
    ee.Initialize()
except Exception:
    ee.Authenticate()
    ee.Initialize()

## Create stable surface masks for each DEM

In [None]:
def convert_wgs_to_utm(lon: float, lat: float):
    """
    Pick the WGS84 / UTM EPSG code for the zone containing a lon/lat point.

    Parameters
    ----------
    lon: float
        longitude coordinate (degrees, WGS84)
    lat: float
        latitude coordinate (degrees, WGS84); values >= 0 map to the northern hemisphere

    Returns
    ----------
    epsg_code: str
        EPSG code string of the UTM zone, e.g. "EPSG:32606"
    """
    # 6-degree-wide zones numbered 1..60 starting at -180; wrap so lon=180 maps to zone 1
    zone_number = math.floor((lon + 180) / 6) % 60 + 1
    # WGS84 / UTM north codes are 326xx, south codes are 327xx
    hemisphere_prefix = '326' if lat >= 0 else '327'
    return f'EPSG:{hemisphere_prefix}{zone_number:02d}'

def _closest_image_index(dem_dt, im_dts):
    """
    Return the index of the date in ``im_dts`` closest in time to ``dem_dt``, before or after it.

    Parameters
    ----------
    dem_dt: numpy.datetime64
        date of the DEM
    im_dts: array-like of numpy.datetime64
        image capture dates

    Returns
    ----------
    index: int
        position of the closest image date; ties resolve to the first occurrence

    Raises
    ----------
    ValueError
        if ``im_dts`` is empty
    """
    im_dts = np.asarray(im_dts, dtype='datetime64[D]')
    if im_dts.size == 0:
        raise ValueError('No image dates to choose from')
    # Absolute difference so images before AND after the DEM date compete on equal footing
    dt_diffs = np.abs(im_dts - np.datetime64(dem_dt, 'D'))
    return int(np.argmin(dt_diffs))


def query_gee_for_image(dem_bounds, dem_date, out_path):
    """
    Query Google Earth Engine for the Sentinel-2 surface reflectance (SR, harmonized) image closest in time
    to a DEM, download it clipped to the DEM bounds, and return it with one data variable per reflectance band.

    Parameters
    __________
    dem_bounds: list, numpy.array
        bounds of the DEM in EPSG:4326 (lon/lat), used for querying and clipping imagery
        (format = [xmin, ymin, xmax, ymax])
    dem_date: str
        date of the DEM (format = 'YYYYMMDD')
    out_path: str
        path in directory where image will be saved (created if missing)

    Returns
    __________
    im_xr: xarray.Dataset
        resulting image; reflectance values divided by 1e4

    Raises
    __________
    ValueError
        if no image matching the search criteria exists within two weeks of the DEM date
    """

    # -----Reformat image bounds for image querying and clipping
    region = {'type': 'Polygon',
              'coordinates': [[[dem_bounds[0], dem_bounds[1]],
                              [dem_bounds[2], dem_bounds[1]],
                              [dem_bounds[2], dem_bounds[3]],
                              [dem_bounds[0], dem_bounds[3]],
                              [dem_bounds[0], dem_bounds[1]]
                             ]]
             }

    # -----Define the start and end dates for query (two weeks before to two weeks after the DEM date)
    dem_dt = np.datetime64(f'{dem_date[0:4]}-{dem_date[4:6]}-{dem_date[6:8]}')
    start_date = str(dem_dt - np.timedelta64(2, 'W'))
    end_date = str(dem_dt + np.timedelta64(2, 'W'))

    # -----Query GEE for imagery
    im_col = gd.MaskedCollection.from_name('COPERNICUS/S2_SR_HARMONIZED').search(start_date=start_date,
                                                                                 end_date=end_date,
                                                                                 region=region,
                                                                                 mask=True,
                                                                                 fill_portion=70)
    im_col_ids = np.array(im_col.ee_collection.aggregate_array('system:id').getInfo())
    if im_col_ids.size == 0:
        raise ValueError(f'No Sentinel-2 SR images found between {start_date} and {end_date} '
                         'matching the search criteria')

    def sts_to_date(sts):
        # Convert an Earth Engine 'system:time_start' value to a 'yyyy-MM-dd' string
        return ee.Date(sts).format('yyyy-MM-dd')
    im_col_dts = np.array(im_col.ee_collection.aggregate_array('system:time_start').map(sts_to_date).getInfo(), dtype='datetime64[D]')

    # -----Download the closest image in time (either before or after the DEM date)
    Iclosest = _closest_image_index(dem_dt, im_col_dts)
    im_id = im_col_ids[Iclosest]
    im_dt = im_col_dts[Iclosest]
    print(f'Closest image date = {im_dt}')
    # Create new masked image from ID
    im = gd.MaskedImage.from_id(im_id, mask=True, region=region)
    # Download to file in out_path.
    # NOTE(review): the cached file name depends only on the image date, so an existing file
    # downloaded for a different region on the same date would be reused as-is.
    os.makedirs(out_path, exist_ok=True)
    im_fn = os.path.join(out_path, str(im_dt).replace('-','') + '_S2_SR_HARMONIZED.tif')
    refl_bands = im.refl_bands
    if not os.path.exists(im_fn):
        im.download(im_fn, region=region, scale=10, crs='EPSG:4326', bands=refl_bands, dtype='int16')
    print('Sentinel-2 image saved to file:', im_fn)

    # -----Open image and restructure data variables (one variable per band, named after the band)
    im_xr = xr.open_dataset(im_fn)
    band_data = im_xr['band_data']
    im_xr_adj = xr.Dataset()
    for i, band_name in enumerate(refl_bands):
        im_xr_adj[band_name] = band_data.isel(band=i)
    im_xr_adj.attrs = im_xr.attrs
    for coord in im_xr.coords:
        im_xr_adj[coord] = im_xr[coord]
    im_xr_adj = im_xr_adj / 1e4 # account for reflectance scalar

    return im_xr_adj

def create_stable_surface_mask(im_xr, dem_date, out_fn, crs, plot=True, ndsi_threshold=0.4):
    """
    Create a stable surface mask by thresholding the NDSI of the input Sentinel-2 SR image.

    Pixels with NDSI <= ``ndsi_threshold`` are flagged stable (1); all other pixels, including
    those where NDSI is NaN (e.g., masked pixels), are flagged 0. The mask is reprojected to
    ``crs`` and saved as a uint8 GeoTIFF with nodata = 255.

    Side effect: an ``NDSI`` variable is added to ``im_xr`` in place.

    Parameters
    ----------
    im_xr: xarray.Dataset
        input Sentinel-2 SR image in EPSG:4326, with reflectance variables 'B2', 'B3', 'B4', 'B11'
    dem_date: str
        observation date of DEM (used in the plot title)
    out_fn: str
        file name of output stable surfaces file
    crs: str
        coordinate reference system of output file (e.g., "EPSG:32606")
    plot: bool
        whether to plot results
    ndsi_threshold: float
        maximum NDSI value considered snow-free / stable (default 0.4)

    Returns
    ----------
    stable_surfaces: xarray.DataArray
        resulting stable surfaces mask, reprojected to ``crs``
    """
    # Add NDSI band from B3 (green) and B11 (SWIR-1)
    im_xr['NDSI'] = (im_xr['B3'] - im_xr['B11']) / (im_xr['B3'] + im_xr['B11'])

    # Threshold; NaN comparisons evaluate False, so NaN NDSI pixels end up as 0 (not stable).
    # uint8 keeps the 0/1 mask compact and leaves 255 free as a nodata value for reprojection.
    ss_xr = xr.where(im_xr['NDSI'] <= ndsi_threshold, 1, 0).astype('uint8')

    # Plot
    if plot:
        def _extent(da):
            # Map extent (xmin, xmax, ymin, ymax) for imshow
            return (np.min(da.x.data), np.max(da.x.data), np.min(da.y.data), np.max(da.y.data))

        fig, ax = plt.subplots(1, 3, figsize=(12,6))
        # Clip reflectance to [0, 1] so imshow does not warn about / clip out-of-range RGB values
        rgb = np.clip(np.dstack([im_xr.B4.data, im_xr.B3.data, im_xr.B2.data]), 0, 1)
        ax[0].imshow(rgb, extent=_extent(im_xr))
        ax[0].set_title('Raw image')
        ax[1].imshow(im_xr.NDSI.data, clim=(-1,1), cmap='Blues', extent=_extent(im_xr))
        ax[1].set_title('NDSI')
        ax[2].imshow(ss_xr.data, clim=(0,1), cmap='Greys', extent=_extent(ss_xr))
        ax[2].set_title('Stable surfaces (NDSI threshold)')
        fig.suptitle(f'DEM date: {dem_date}')
        plt.show()

    # Reproject from the image CRS (EPSG:4326) to the requested output CRS.
    # rioxarray's default resampling is nearest-neighbour, which keeps the mask binary.
    ss_xr = ss_xr.rio.write_crs('EPSG:4326')
    ss_xr = ss_xr.rio.write_nodata(255)
    ss_xr_reproj = ss_xr.rio.reproject(crs)

    # Save the reprojected mask to file (previously the unreprojected mask was written)
    ss_xr_reproj.rio.to_raster(out_fn)
    print('Stable surfaces mask saved to file:', out_fn)

    return ss_xr_reproj
    

In [None]:
# Iterate over DEM file names
for dem_fn in tqdm(dem_fns):
    # Grab DEM date from file name (assumes file names start with 'YYYYMMDD')
    dem_date = os.path.basename(dem_fn)[0:8]
    print(f'\n{dem_date}')

    # Check if stable surfaces mask already exists in file, before doing any expensive work
    ss_fn = os.path.join(data_path, dem_date + '_stable_surfaces.tif')
    if os.path.exists(ss_fn):
        print('Stable surfaces mask already exists for DEM, skipping...')
        continue

    # Grab lat lon DEM bounds ([xmin, ymin, xmax, ymax]). Transforming only the bounds avoids
    # reprojecting the whole DEM raster; the context manager closes the file afterwards.
    with xr.open_dataset(dem_fn) as dem:
        dem_bounds = dem.rio.transform_bounds('EPSG:4326')

    # Solve for best UTM zone
    centroid_lon = (dem_bounds[0] + dem_bounds[2]) / 2
    centroid_lat = (dem_bounds[1] + dem_bounds[3]) / 2
    crs = convert_wgs_to_utm(centroid_lon, centroid_lat)
    print(f'Best UTM zone = {crs}')

    # Query GEE for closest Sentinel-2 image in time; skip this DEM (rather than aborting the
    # whole batch) if no image could be selected
    try:
        im_xr = query_gee_for_image(dem_bounds, dem_date, data_path)
    except ValueError as err:
        print(f'Could not retrieve a Sentinel-2 image for DEM, skipping... ({err})')
        continue

    # Create stable surfaces mask
    ss_xr = create_stable_surface_mask(im_xr, dem_date, ss_fn, crs=crs, plot=True)