# Test the pipeline for an off-glacier AOI, classify snow using NDSI

In [None]:
# wxee is not included in the default environment. Install by uncommenting the line below.
# !micromamba install -c conda-forge wxee -y

import os
from glob import glob
import ee
import sys
import wxee as wx
import matplotlib.pyplot as plt
import matplotlib
import numpy as np
from tqdm import tqdm
import xarray as xr
import rioxarray as rxr

# -----Define local folder for exports
out_folder = '/Users/rdcrlrka/Research/glacier_snow_mapping/LemonCreek_watershed_AOI'

# -----Import pipeline utilities
# Assumes glasee_pipeline_utils.py is located one folder above this notebook
script_path = os.getcwd()
utils_folder = os.path.abspath(os.path.join(script_path, '..'))
# Only add the folder once, so re-running this cell does not keep growing sys.path
if utils_folder not in sys.path:
    sys.path.append(utils_folder)
import glasee_pipeline_utils as utils

# -----Define image search settings
# Date and month ranges (inclusive)
date_start = '2019-04-01'
date_end = '2019-10-31'
month_start = 4 # April = 4
month_end = 10 # Oct = 10
# Minimum fill portion percentage of the AOI (0-100), used to remove images after mosaicking by day
min_aoi_coverage = 70
# Whether to mask clouds using the respective cloud mask via the geedim package
mask_clouds = True

## Authenticate GEE

In [None]:
project_id = "ee-raineyaberle"

# Try to initialize Earth Engine directly; if that fails (presumably because no
# usable credentials are stored yet), run the authentication flow and retry.
# Catch Exception rather than using a bare except, so Ctrl-C/SystemExit still propagate.
try:
    ee.Initialize(project=project_id)
except Exception:
    ee.Authenticate()
    ee.Initialize(project=project_id)

## Define AOI and query GEE for DEM

In [None]:
# -----Manual AOI
# save just the coordinates first for later use
aoi_coords = [
[-134.379943, 58.417939],
[-134.362549, 58.388341],
[-134.361008, 58.3727101],
[-134.369808, 58.359844],
[-134.470524, 58.3375671],
[-134.558482, 58.381999],
[-134.379943, 58.41793911]
]
# convert to ee.Geometry
aoi = ee.Geometry.Polygon(aoi_coords)

# Area in m^2; report in km^2 with one decimal so small AOIs do not print as 0
aoi_area = aoi.area().getInfo()
print(f"AOI = {aoi_area / 1e6:.1f} km2")

# -----Query GEE for DEM
dem = utils.query_gee_for_dem(aoi)
# Save the DEM to file. Create the output folder first: this cell writes to it
# before the image-export cell below creates it.
os.makedirs(out_folder, exist_ok=True)
dem_file = os.path.join(out_folder, "ArcticDEM_clipped.tif")
if not os.path.exists(dem_file):
    # Give the DEM a timestamp so wxee can build a time dimension
    # (the result is indexed with isel(time=0) below)
    dem = dem.set('system:time_start', 0)
    dem_xr = dem.wx.to_xarray(region=aoi, scale=10)
    dem_xr.isel(time=0).rio.to_raster(dem_file)
    print("Clipped DEM saved to file:", dem_file)
else:
    print("Clipped DEM already exists in file, skipping.")

# -----Identify the best UTM zone for outputs
def convert_wgs_to_utm(lon: float, lat: float):
    """
    Return the EPSG code string of the WGS 84 / UTM zone containing (lon, lat).

    Zones are 6 degrees of longitude wide, numbered 1-60 starting at -180.
    Points with lat >= 0 get a northern-hemisphere code ("EPSG:326NN"),
    all others a southern-hemisphere code ("EPSG:327NN"); NN is zero-padded.
    """
    zone_number = int((np.floor((lon + 180) / 6) % 60) + 1)
    hemisphere_prefix = '326' if lat >= 0 else '327'
    return f"EPSG:{hemisphere_prefix}{zone_number:02d}"

# Approximate the AOI center as the mean of its vertices. Drop the ring's closing
# vertex when it repeats the first one (up to ~0.1 m of rounding), so the
# starting corner is not counted twice.
aoi_vertices = np.asarray(aoi_coords, dtype=float)
if len(aoi_vertices) > 1 and np.allclose(aoi_vertices[0], aoi_vertices[-1], rtol=0, atol=1e-6):
    aoi_vertices = aoi_vertices[:-1]
aoi_cen_lon = float(np.nanmean(aoi_vertices[:, 0]))
aoi_cen_lat = float(np.nanmean(aoi_vertices[:, 1]))
utm_crs = convert_wgs_to_utm(aoi_cen_lon, aoi_cen_lat)
print("Optimal UTM zone =", utm_crs)

## Query GEE for imagery, classify, save outputs

In [None]:
# Test with one dataset
dataset = "Sentinel-2_SR"
# Export resolution [m] passed to wxee for the image and classified image
scale = 30

# Make sure out_folder exists
os.makedirs(out_folder, exist_ok=True)

# Create a colormap for plotting classified images: one color per class value 1-5
colors = ["#4eb3d3", "#6a51a3", "#084081", "#fe9929", "#252525"]
bounds = [1, 2, 3, 4, 5, 6]
cmap = matplotlib.colors.ListedColormap(colors)
norm = matplotlib.colors.BoundaryNorm(bounds, cmap.N)
plt.rcParams.update({"font.size": 12, "font.sans-serif": "Verdana"})

# Band names used for the RGB panel (constant for the chosen dataset)
if dataset == "Landsat":
    rgb_bands = ["SR_B4", "SR_B3", "SR_B2"]
else:
    rgb_bands = ["B4", "B3", "B2"]


def _extent_km(da):
    """Return the (left, right, bottom, top) extent of da's x/y coordinates in km."""
    return (
        np.min(da.x.data) / 1e3, np.max(da.x.data) / 1e3,
        np.min(da.y.data) / 1e3, np.max(da.y.data) / 1e3,
    )


# Create daily windows [d_start, d_end) to iterate over. np.arange excludes its
# stop value, so extend it by one day to make date_end inclusive as documented.
one_day = np.timedelta64(1, 'D')
date_start_arr = np.arange(np.datetime64(date_start), np.datetime64(date_end) + one_day)
# Keep only days whose month lies in [month_start, month_end] (inclusive);
# a range such as month_start=11, month_end=3 wraps around the new year.
months = date_start_arr.astype('datetime64[M]').astype(int) % 12 + 1
if month_start <= month_end:
    in_season = (months >= month_start) & (months <= month_end)
else:
    in_season = (months >= month_start) | (months <= month_end)
date_start_arr = date_start_arr[in_season]
date_end_arr = date_start_arr + one_day

for d_start, d_end in tqdm(list(zip(date_start_arr, date_end_arr))):
    # Define outputs
    image_file = os.path.join(out_folder, f"{d_start}_{dataset}.tif")
    classified_image_file = os.path.join(out_folder, f"{d_start}_{dataset}_classified.tif")
    figure_file = os.path.join(out_folder, f"{d_start}_{dataset}_results.png")

    # Continue if outputs exist
    if all(os.path.exists(f) for f in (image_file, classified_image_file, figure_file)):
        print(f"Outputs already exist, skipping {d_start} to {d_end}")
        continue

    # Query GEE for images
    image_col = utils.query_gee_for_imagery(
        dataset = dataset,
        aoi = aoi,
        date_start = str(d_start),
        date_end = str(d_end),
        fill_portion = min_aoi_coverage,
        mask_clouds = mask_clouds,
        scale = None,
        verbose = False
    )

    # Check that there were any results
    if image_col.size().getInfo() < 1:
        print(f"No images found for {d_start} to {d_end}")
        continue

    classified_image_col = utils.classify_image_collection(
        collection = image_col,
        dataset = dataset,
        verbose = False
    )

    # Convert the first image of each collection to an xarray.Dataset and save to file
    image_xr = image_col.first().wx.to_xarray(region=aoi, scale=scale, crs=utm_crs).isel(time=0)
    image_xr.rio.to_raster(image_file)
    print("Image saved to:", image_file)

    classified_image_xr = classified_image_col.first().wx.to_xarray(region=aoi, scale=scale, crs=utm_crs).isel(time=0)
    classified_image_xr.rio.to_raster(classified_image_file)
    print("Classified image saved to:", classified_image_file)

    fig, ax = plt.subplots(2, 1, figsize=(10, 10))
    fig.subplots_adjust(right=1.0)
    cax = fig.add_axes([0.83, 0.14, 0.02, 0.3])
    # RGB
    ax[0].imshow(
        np.dstack([image_xr[band] for band in rgb_bands]),
        extent=_extent_km(image_xr),
    )
    # classified image
    im = ax[1].imshow(
        classified_image_xr.classification,
        cmap=cmap,
        norm=norm,
        extent=_extent_km(classified_image_xr),
    )
    cbar = fig.colorbar(im, cax=cax)
    # Tick at the center of each class bin
    cbar.set_ticks([1.5, 2.5, 3.5, 4.5, 5.5])
    cbar.set_ticklabels([
        'Snow', 'Shadowed snow', 'Ice', 'Rock/debris', 'Water'
    ])
    cbar.ax.minorticks_off()
    ax[0].set_ylabel("Northing [km]")
    ax[1].set_ylabel("Northing [km]")
    ax[1].set_xlabel("Easting [km]")

    # Save figure
    fig.savefig(figure_file, dpi=300, bbox_inches='tight')
    print("Figure saved to:", figure_file)
    plt.close(fig)
    

## Calculate snow cover statistics

Let's see how the original statistics calculation performs for this case. Below is a Python version that we can use as a proof of concept. For reference, the original function is `calculate_snow_cover_statistics` in `glasee_pipeline_utils.py`.

In [None]:
def _mask_fill_value(da):
    """Return a float copy of `da` with its "_FillValue" attribute value (if any) masked to NaN."""
    fill_value = da.attrs.get("_FillValue")
    if fill_value is None:
        return da.astype(float)
    return da.where(da != fill_value)


def _dem_quantile(dem, percentile):
    """Return the DEM elevation at `percentile` (0-100), ignoring NaNs.

    The percentile is clipped to [0, 100]: it is built from sums of area
    fractions, and floating-point rounding can push it just outside that
    range, which DataArray.quantile rejects.
    """
    q = float(np.clip(percentile / 100, 0, 1))
    return float(dem.quantile(q, skipna=True))


def calculate_snow_cover_statistics_python(
        classified_image_file: str = None,
        dem_file: str = None,
        scale: int = None,
        ):
    """
    Calculate snow cover statistics for one classified image and plot them.

    Computes snow (classes 1 and 2), ice (3), rock (4), water (5) and glacier
    (snow + ice) areas, the snowline altitude (SLA; 5th percentile of
    snow-covered elevations), and SLA upper and lower bounds. The classified
    image with SLA contours and elevation histograms is saved next to the input
    as "<date>_<dataset>_SLA.png".

    Parameters
    ----------
    classified_image_file : str
        File path to the classified image (e.g., from the classify_image_collection function).
        The file name is expected to follow "<date>_<dataset>_classified.tif".
    dem_file : str
        File path to the DEM (e.g., clipped ArcticDEM). It is reprojected onto the
        classified image grid.
    scale : int
        Pixel size in meters used for area calculations. If None, the pixel size is
        read from the classified image's resolution.

    Returns
    ----------
    dict or None
        Areas in m^2 ("aoi_area", "snow_area", "ice_area", "rock_area",
        "water_area", "glacier_area") and elevations in DEM units ("SLA",
        "SLA_upper", "SLA_lower"; NaN when no snow-covered pixel has a valid
        elevation, in which case no figure is made). None if the output figure
        already exists or the DEM has no valid pixels on the image grid.
    """
    # --- Check if output figure already exists ---
    fig_file = classified_image_file.replace('_classified.tif', '_SLA.png')
    if os.path.exists(fig_file):
        return None

    # --- Load inputs ---
    basename = os.path.basename(classified_image_file)
    # Split only on the first "_" so dataset names like "Sentinel-2_SR" stay whole
    date, dataset = basename.replace('_classified.tif', '').split('_', 1)
    print(date, dataset)

    classified_image = _mask_fill_value(rxr.open_rasterio(classified_image_file).squeeze())

    dem = rxr.open_rasterio(dem_file).squeeze()
    dem = _mask_fill_value(dem.rio.reproject_match(classified_image))

    # --- Calculate areas for each class ---
    if scale:
        pixel_area = float(scale) ** 2
    else:
        # Use the actual grid spacing rather than a per-dataset guess
        x_res, y_res = classified_image.rio.resolution()
        pixel_area = abs(float(x_res) * float(y_res))
    aoi_area = float(np.count_nonzero(~np.isnan(dem.data))) * pixel_area
    if aoi_area == 0:
        print(f"No valid DEM pixels on the grid of {basename}, skipping.")
        return None

    snow_mask = (classified_image == 1) | (classified_image == 2)
    ice_mask = classified_image == 3
    rock_mask = classified_image == 4
    water_mask = classified_image == 5

    snow_area = float(snow_mask.sum()) * pixel_area
    ice_area = float(ice_mask.sum()) * pixel_area
    stats = {
        "date": date,
        "dataset": dataset,
        "aoi_area": aoi_area,
        "snow_area": snow_area,
        "ice_area": ice_area,
        "rock_area": float(rock_mask.sum()) * pixel_area,
        "water_area": float(water_mask.sum()) * pixel_area,
        "glacier_area": snow_area + ice_area,
    }

    # --- Estimate Snowline Altitude (SLA) ---
    snow_dem = dem.where(snow_mask)
    sla = float(snow_dem.quantile(0.05, skipna=True))
    if np.isnan(sla):
        # No snow with valid elevations: SLA is undefined and cannot be plotted
        print(f"No snow-covered pixels with valid elevations in {basename}; SLA undefined, skipping figure.")
        stats.update(SLA=np.nan, SLA_upper=np.nan, SLA_lower=np.nan)
        return stats

    # --- Estimate SLA upper and lower bounds ---
    # "Reference system switch": find the DEM percentile corresponding to the SLA
    below_sla_mask = dem < sla
    below_sla_mask_area = float(below_sla_mask.sum()) * pixel_area
    sla_percentile_dem = (below_sla_mask_area / aoi_area) * 100

    # Upper bound
    snow_free_mask = (classified_image >= 3)
    above_sla_mask = dem > sla
    sla_upper_mask = snow_free_mask & above_sla_mask
    sla_upper_mask_area = float(sla_upper_mask.sum()) * pixel_area
    # DEM percentile to sample = (SLA percentile) + (Area snow-free above SLA / Total AOI Area)
    sla_upper = _dem_quantile(dem, sla_percentile_dem + (sla_upper_mask_area / aoi_area) * 100)

    # Lower bound
    sla_lower_mask = snow_mask & below_sla_mask
    sla_lower_mask_area = float(sla_lower_mask.sum()) * pixel_area
    # DEM percentile to sample = (SLA percentile) - (Area snow-covered below SLA / Total AOI Area)
    sla_lower = _dem_quantile(dem, sla_percentile_dem - (sla_lower_mask_area / aoi_area) * 100)

    stats.update(SLA=sla, SLA_upper=sla_upper, SLA_lower=sla_lower)

    # --- Plot results ---
    plt.rcParams.update({"font.size": 12, "font.sans-serif": "Verdana"})
    fig, ax = plt.subplots(1, 2, figsize=(12,6), gridspec_kw=dict(width_ratios=[1, 1.5]))

    # Classified image + SLA contours
    colors = ["#4eb3d3", "#6a51a3", "#084081", "#fe9929", "#252525"]
    bounds = [1, 2, 3, 4, 5, 6]
    cmap = matplotlib.colors.ListedColormap(colors)
    norm = matplotlib.colors.BoundaryNorm(bounds, cmap.N)
    im = ax[0].imshow(
        classified_image,
        cmap=cmap,
        norm=norm,
        extent=(
            np.min(classified_image.x.data)/1e3, np.max(classified_image.x.data)/1e3,
            np.min(classified_image.y.data)/1e3, np.max(classified_image.y.data)/1e3
        )
    )
    cbar = fig.colorbar(im, ax=ax[0], orientation='horizontal', shrink=0.8)
    cbar.set_ticks([1.5, 2.5, 3.5, 4.5, 5.5])
    cbar.set_ticklabels([
        'Snow', 'Shadowed\nsnow', 'Ice', 'Rock/debris', 'Water'
    ], fontsize=8)
    cbar.ax.minorticks_off()
    ax[0].set_ylabel("Northing [km]")
    ax[0].set_xlabel("Easting [km]")

    x_km, y_km = np.meshgrid(dem.x.data / 1e3, dem.y.data / 1e3)
    for level, style in ((sla, 'solid'), (sla_upper, 'dashed'), (sla_lower, 'dotted')):
        ax[0].contour(x_km, y_km, dem.data, levels=[level], colors='k', linestyles=style)

    # Histograms of all elevations and snow elevations with lines for SLA metrics
    dem_values = dem.data[~np.isnan(dem.data)]
    snow_values = snow_dem.data[~np.isnan(snow_dem.data)]
    bins = np.linspace(dem_values.min(), dem_values.max(), num=100)
    ax[1].hist(dem_values, bins=bins, color='gray', alpha=1, label="All elevations")
    ax[1].hist(snow_values, bins=bins, color=colors[0], alpha=1, label="Snow elevations")
    ax[1].axvline(sla_upper, color='k', linestyle='dashed', label=f"SLA$_{{upper}}$ = {int(sla_upper)} m")
    ax[1].axvline(sla, color='k', linestyle='solid', label=f"SLA = {int(sla)} m")
    ax[1].axvline(sla_lower, color='k', linestyle='dotted', label=f"SLA$_{{lower}}$ = {int(sla_lower)} m")
    ax[1].legend(loc='upper right')
    ax[1].set_xlabel("Elevation [m]")
    ax[1].set_ylabel("Counts")

    fig.suptitle(f"{date} {dataset}")

    # Save to file
    fig.savefig(fig_file, dpi=300, bbox_inches='tight')
    print("Figure saved to:", fig_file)

    plt.close(fig)

    return stats

# Gather every classified image in out_folder (file names start with the date,
# so sorting orders them by date) and compute statistics for each one
pattern = os.path.join(out_folder, "*_classified.tif")
classified_image_files = sorted(glob(pattern))
n_files = len(classified_image_files)
print(f"Located {n_files} classified image files.")

for classified_path in tqdm(classified_image_files):
    calculate_snow_cover_statistics_python(classified_image_file=classified_path, dem_file=dem_file)