# Classify snow for an off-glacier site using the NDSI

## Define settings

In [None]:
# wxee (used to convert an ee.Image to xarray.Dataset) is not included in the default environment. 
# Install by uncommenting the line below.
# !micromamba install -c conda-forge wxee -y

import os
import ee
import sys
import wxee as wx
import numpy as np
import datetime
from tqdm import tqdm
import xarray as xr
import matplotlib.pyplot as plt
import matplotlib
import pandas as pd

# --- Define options for downloading outputs ---
out_folder = '/Users/rdcrlrka/Research/glacier_snow_mapping/LemonCreek_watershed_AOI'
download_images = True
plot_results = True

# --- Define image search and classification settings ---
# Date range: date_start is inclusive, date_end is exclusive (the last day queried is the day before date_end)
date_start = '2017-04-01' 
date_end = '2017-11-01' 
# Month range (inclusive): only days falling in months month_start..month_end are queried
month_start = 4 # April = 4
month_end = 10 # Oct = 10
# Minimum fill portion percentage of the AOI (0–100), used to remove images after mosaicking by day
min_aoi_coverage = 70
# Whether to mask clouds using the respective cloud mask via the geedim package
mask_clouds = True
# NDSI threshold used to classify snow (set to None to skip)
ndsi_threshold = 0.6
# Blue band threshold used to classify snow (set to None to skip)
blue_threshold = 0.4
# SLA percentile: the percentile of snow-covered elevations to sample, from 0 to 100
sla_percentile = 5

# --- Import pipeline utilities ---
# Assumes glasee_pipeline_utils.py is one folder above the current working directory
# (normally the folder containing this notebook)
script_path = os.getcwd()
utils_dir = os.path.abspath(os.path.join(script_path, '..'))
# Only append once so re-running this cell does not keep growing sys.path
if utils_dir not in sys.path:
    sys.path.append(utils_dir)
import glasee_pipeline_utils as utils

## Authenticate and/or Initialize Google Earth Engine (GEE)

Replace the project ID below with your own GEE project ID (by default, this has the form `ee-[GEE-username]`).

In [None]:
project_id = "ee-raineyaberle"

# Try to initialize with existing credentials. If that fails (typically because this
# machine has not been authenticated yet), authenticate and retry. Catch Exception
# rather than using a bare except, so Ctrl-C / SystemExit are not swallowed.
try:
    ee.Initialize(project=project_id)
except Exception:
    ee.Authenticate()
    ee.Initialize(project=project_id)

## Define AOI and query GEE for DEM

In [None]:
# -----Manual AOI
# Polygon vertices as [longitude, latitude] pairs. They are kept as a plain Python list
# because they are reused later (e.g. to pick the UTM zone).
aoi_coords = [
    [-134.379943, 58.417939],
    [-134.362549, 58.388341],
    [-134.361008, 58.3727101],
    [-134.369808, 58.359844],
    [-134.470524, 58.3375671],
    [-134.558482, 58.381999],
    [-134.379943, 58.41793911],
]

# Earth Engine geometry for the AOI
aoi = ee.Geometry.Polygon(aoi_coords)

# Polygon area (evaluated on the server, fetched with getInfo); reported in km2
aoi_area = aoi.area().getInfo()
print(f"AOI area = {int(aoi_area/1e6)} km2")

# -----Identify the best UTM zone for outputs
def convert_wgs_to_utm(lon: float, lat: float):
    """Return the EPSG code (as 'EPSG:326xx' / 'EPSG:327xx') of the UTM zone containing a point.

    Zones are 6-degree longitude bands numbered 1-60, zero-padded to two digits.
    Latitudes >= 0 map to the northern-hemisphere series (326xx), all others to the
    southern series (327xx). Special-case zones (e.g. around Norway/Svalbard) are not handled.
    """
    zone_number = int((np.floor((lon + 180) / 6) % 60) + 1)
    hemisphere_prefix = 'EPSG:326' if lat >= 0 else 'EPSG:327'
    return f"{hemisphere_prefix}{zone_number:02d}"

# Centre of the AOI, taken as the mean of its vertex coordinates, used to pick the UTM zone
aoi_cen_lon = float(np.nanmean(np.array(aoi_coords)[:,0]))
aoi_cen_lat = float(np.nanmean(np.array(aoi_coords)[:,1]))
utm_crs = convert_wgs_to_utm(aoi_cen_lon, aoi_cen_lat)
print("Optimal UTM zone =", utm_crs)

# -----Query GEE for DEM
dem = utils.query_gee_for_dem(aoi)
print("Downloading DEM as xarray.Dataset")
# set an arbitrary time for DEM, otherwise wxee will get angry
dem = dem.set('system:time_start', 0)
dem_xr = dem.wx.to_xarray(region=aoi, scale=10, crs=utm_crs)
# remove unnecessary (length-1) dimensions
dem_xr = dem_xr.squeeze(drop=True)
# Reduce the Dataset to a single elevation DataArray: the plot below and the SLA
# calculations (reproject_match, nanpercentile, quantile) all expect a DataArray.
if 'elevation' in dem_xr.data_vars:
    dem_xr = dem_xr['elevation']
elif len(dem_xr.data_vars) == 1:
    # band is named differently (e.g. another DEM source): use the only variable present
    dem_xr = dem_xr[next(iter(dem_xr.data_vars))]
else:
    raise ValueError(f"Cannot identify the elevation band in the DEM: {list(dem_xr.data_vars)}")

# -----Plot
fig, ax = plt.subplots(figsize=(6,6))
dem_xr.plot(ax=ax, cmap='terrain')
plt.show()

## Define some helper functions

In [None]:
def split_date_range_daily(
        dataset: str = None, 
        date_start: str = None, 
        date_end: str = None, 
        month_start: int = None, 
        month_end: int = None
        ):
    """Split a date range into consecutive one-day query windows, limited to a month window.

    Parameters
    ----------
    dataset : str
        One of "Sentinel-2_TOA", "Sentinel-2_SR", or "Landsat". If ``date_start`` is
        earlier than 1 January of the dataset's listed start year, it is moved forward
        to that date.
    date_start, date_end : str
        Dates formatted "YYYY-MM-DD". ``date_start`` is inclusive, ``date_end`` exclusive.
    month_start, month_end : int or None
        Inclusive month window (1-12). If ``month_start > month_end`` the window wraps
        across the new year (e.g. 11 and 3 keep November through March). ``None`` means
        no limit on that side (January / December respectively).

    Returns
    -------
    list of tuple of str
        ``(day, next_day)`` ISO-date string pairs in chronological order.

    Raises
    ------
    ValueError
        If a date string does not match "YYYY-MM-DD" or ``dataset`` is not supported.
    """
    start = datetime.datetime.strptime(date_start, "%Y-%m-%d").date()
    end = datetime.datetime.strptime(date_end, "%Y-%m-%d").date()

    # First year each dataset is treated as available
    dataset_start_years = {
        "Sentinel-2_TOA": 2016,
        "Sentinel-2_SR": 2019,
        "Landsat": 2013,
    }
    if dataset not in dataset_start_years:
        raise ValueError(f"Unsupported dataset: {dataset}")

    # Clamp the start to dataset availability
    start = max(start, datetime.date(dataset_start_years[dataset], 1, 1))

    first_month = 1 if month_start is None else month_start
    last_month = 12 if month_end is None else month_end

    def _in_month_window(month: int) -> bool:
        # Normal window (e.g. Apr-Oct) or one that wraps across the new year (e.g. Nov-Mar)
        if first_month <= last_month:
            return first_month <= month <= last_month
        return month >= first_month or month <= last_month

    one_day = datetime.timedelta(days=1)
    date_ranges = []
    day = start
    while day < end:
        if _in_month_window(day.month):
            date_ranges.append((day.isoformat(), (day + one_day).isoformat()))
        day += one_day

    print(f"Number of dates to query = {len(date_ranges)}")

    return date_ranges


def classify_snow(
        image_xr: xr.Dataset = None, 
        ndsi_threshold: float = None, 
        blue_threshold: float = None,
        out_file: str = None
        ):
    """Classify snow (1) vs. no snow (0) using NDSI and/or blue-band thresholds.

    Parameters
    ----------
    image_xr : xr.Dataset
        Image with an ``NDSI`` variable and a blue band named ``B2`` or ``SR_B2``.
    ndsi_threshold : float or None
        Pixels with ``NDSI >= ndsi_threshold`` qualify as snow; ``None`` skips this test.
    blue_threshold : float or None
        Pixels with blue band ``>= blue_threshold`` qualify as snow; ``None`` skips this test.
        When both thresholds are given, a pixel must pass both.
    out_file : str or None
        If a string, the classification is also written to this raster path as uint16,
        with 9999 marking no-data.

    Returns
    -------
    xr.DataArray
        1 = snow, 0 = no snow, NaN wherever NDSI is NaN; CRS copied from ``image_xr``.

    Raises
    ------
    ValueError
        If no blue band is found, or if both thresholds are ``None``.
    """
    # Get the name of the blue band
    if "B2" in image_xr.data_vars:
        blue_band = "B2"
    elif "SR_B2" in image_xr.data_vars:
        blue_band = "SR_B2"
    else:
        raise ValueError("Cannot determine blue band for the image.")

    # Decide which tests apply by comparing against None (the documented "skip" value).
    # The previous `type(x) == float` check silently ignored ints and NumPy scalars
    # such as np.float64.
    use_ndsi = ndsi_threshold is not None
    use_blue = blue_threshold is not None
    if use_ndsi and use_blue:
        is_snow = (image_xr.NDSI >= ndsi_threshold) & (image_xr[blue_band] >= blue_threshold)
    elif use_ndsi:
        is_snow = image_xr.NDSI >= ndsi_threshold
    elif use_blue:
        is_snow = image_xr[blue_band] >= blue_threshold
    else:
        raise ValueError("No thresholds provided to classify snow, cannot proceed.")
    image_classified = xr.where(is_snow, 1, 0)

    # Put nodata values back in (comparisons with NaN evaluate to False, i.e. 0)
    image_classified = xr.where(np.isnan(image_xr.NDSI), np.nan, image_classified)

    # Set raster attributes
    image_classified = (
        image_classified
            .rio.write_crs(image_xr.rio.crs)
            .rio.write_nodata(np.nan)
            )

    # Save to file
    if type(out_file)==str:
        # Integer rasters cannot hold NaN, so use 9999 as the nodata value
        image_classified_int = xr.where(np.isnan(image_classified), 9999, image_classified).astype(int)
        image_classified_int = (
            image_classified_int
                .rio.write_crs(image_xr.rio.crs)
                .rio.write_nodata(9999)
                )
        image_classified_int.rio.to_raster(out_file, dtype=np.uint16)
        print(f"Classified image saved to: {out_file}")

    return image_classified

def calculate_snowline_altitude(
        date: str = None,
        dataset: str = None,
        classified_image: xr.Dataset = None, 
        dem: xr.DataArray = None, 
        sla_percentile: float = 5,
        out_file: str = None
    ):
    """Compute snow-cover areas and the snowline altitude (SLA) with bounds for one image.

    The SLA is the ``sla_percentile`` percentile of DEM elevations under snow-classified
    pixels. It is then expressed as a percentile of the whole DEM. The upper bound moves
    that percentile up by the snow-free area above the SLA; the lower bound moves it down
    by the snow-covered area below the SLA.

    Parameters
    ----------
    date, dataset : str
        Copied into the output table.
    classified_image : xr.DataArray
        1 = snow, 0 = no snow, NaN = no data, with an ``x`` coordinate. Pixel size is
        taken from the ``x`` spacing, assuming square pixels.
    dem : xr.DataArray
        Elevations; reprojected onto the grid of ``classified_image``.
    sla_percentile : float
        Percentile (0-100) of snow-covered elevations used as the SLA.
    out_file : str or None
        If a string, the result table is also written to this CSV path.

    Returns
    -------
    (pd.DataFrame, xr.DataArray)
        A one-row table of areas and SLA values, and the reprojected DEM masked to
        observed pixels. Areas are in CRS units squared; the "_m2" column names assume
        a metre-based CRS. The SLA columns are NaN when no snow pixel has a valid elevation.
    """
    # --- Pixel size from the x-coordinate spacing (sign is irrelevant once squared) ---
    scale = np.nanmean(np.diff(classified_image.x.data))
    pixel_area = float(scale) ** 2

    # --- Snow, snow-free, and no-data masks and their areas ---
    snow_mask = (classified_image == 1)
    snow_free_mask = (classified_image == 0)
    nodata_mask = np.isnan(classified_image)
    snow_covered_area = float(snow_mask.sum()) * pixel_area
    snow_free_area = float(snow_free_mask.sum()) * pixel_area
    nodata_area = float(nodata_mask.sum()) * pixel_area

    # --- Put the DEM on the image grid and mask it to observed pixels ---
    dem = dem.rio.reproject_match(classified_image)
    dem_masked = dem.where(~nodata_mask)

    # Total area of valid DEM pixels. dem.quantile() below samples exactly this population,
    # so area fractions are taken relative to it (not the AOI polygon area, whose footprint
    # differs from the raster grid's).
    dem_area = float(dem.notnull().sum()) * pixel_area

    snow_dem = dem_masked.where(snow_mask)
    n_snow_elevations = int(np.isfinite(snow_dem.data).sum())

    if n_snow_elevations == 0 or dem_area <= 0:
        # No snow with a valid elevation: the SLA is undefined. Report NaN rather than
        # crashing later on int(NaN).
        print("No snow-covered pixels with valid elevations; SLA set to NaN.")
        sla = sla_lower = sla_upper = np.nan
    else:
        # --- Estimate Snowline Altitude (SLA) ---
        sla = float(np.nanpercentile(snow_dem.data, sla_percentile))

        # "Reference system switch": the DEM percentile (0-100) at which the SLA sits
        below_sla_mask = dem < sla
        sla_percentile_dem = float(below_sla_mask.sum()) * pixel_area / dem_area * 100

        # Upper bound: shift up the DEM distribution by the snow-free area above the SLA
        sla_upper_mask_area = float((snow_free_mask & (dem > sla)).sum()) * pixel_area
        if sla_upper_mask_area > 0:
            q_upper = np.clip((sla_percentile_dem + sla_upper_mask_area / dem_area * 100) / 100, 0, 1)
            sla_upper = float(dem.quantile(q_upper, skipna=True))
        else:
            sla_upper = sla

        # Lower bound: shift down the DEM distribution by the snow-covered area below the SLA
        sla_lower_mask_area = float((snow_mask & below_sla_mask).sum()) * pixel_area
        if sla_lower_mask_area > 0:
            q_lower = np.clip((sla_percentile_dem - sla_lower_mask_area / dem_area * 100) / 100, 0, 1)
            sla_lower = float(dem.quantile(q_lower, skipna=True))
        else:
            sla_lower = sla

    def _int_or_nan(value):
        # Truncate to whole units ("we're not that precise"), but keep NaN as NaN
        return int(value) if np.isfinite(value) else np.nan

    # Compile results in a pd.DataFrame
    df = pd.DataFrame({
        "date": [date],
        "dataset": [dataset],
        "snow_covered_area_m2": [int(snow_covered_area)],
        "snow_free_area_m2": [int(snow_free_area)],
        "masked_area_m2": [int(nodata_area)],
        "SLA_m": [_int_or_nan(sla)],
        "SLA_lower_bound_m": [_int_or_nan(sla_lower)],
        "SLA_upper_bound_m": [_int_or_nan(sla_upper)]
    }, index=[0])

    # Save to file
    if type(out_file)==str:
        df.to_csv(out_file, index=False)
        print(f"Snow cover stats saved to: {out_file}")

    return df, dem_masked

def plot_snow_cover_stats(
        date: str = None, 
        dataset: str = None,
        image: xr.Dataset = None, 
        classified_image: xr.DataArray = None,
        dem: xr.DataArray = None,
        snow_stats_df: pd.DataFrame = None,
        out_file: str = None
    ):
    """Plot one image's RGB composite, snow classification, and elevation histograms.

    Panels:
    - top left: the RGB image;
    - top right: the snow classification with the SLA contour;
    - bottom: histograms of all elevations and of snow-covered elevations, with
      vertical lines at the SLA and its bounds.

    Parameters
    ----------
    date, dataset : str
        Used in the figure title. ``dataset`` also selects the RGB band names
        (``B4/B3/B2`` if it contains "Sentinel-2", otherwise ``SR_B4/SR_B3/SR_B2``).
    image : xr.Dataset
        Image containing the RGB bands, with ``x``/``y`` coordinates (plotted in km).
    classified_image : xr.DataArray
        1 = snow, 0 = no snow, NaN = no data.
    dem : xr.DataArray
        Elevations on the same grid as ``classified_image``.
    snow_stats_df : pd.DataFrame
        Table with ``SLA_m``, ``SLA_lower_bound_m`` and ``SLA_upper_bound_m``; the
        first row is used. NaN values are skipped when plotting.
    out_file : str or None
        Path to save the figure (300 dpi). If ``None``, the figure is shown instead.
    """
    plt.rcParams.update({"font.size": 12, "font.sans-serif": "Verdana"})

    def _extent_km(da):
        # imshow extent (left, right, bottom, top) in km, from the x/y coordinate ranges
        return (
            float(da.x.min()) / 1e3, float(da.x.max()) / 1e3,
            float(da.y.min()) / 1e3, float(da.y.max()) / 1e3,
        )

    # Grab info from df
    sla = snow_stats_df['SLA_m'].values[0]
    sla_lower = snow_stats_df['SLA_lower_bound_m'].values[0]
    sla_upper = snow_stats_df['SLA_upper_bound_m'].values[0]
    has_sla = bool(np.isfinite(sla))

    # Set up figure
    gs = matplotlib.gridspec.GridSpec(2, 2, height_ratios=[2, 1])
    fig = plt.figure(figsize=(12, 8))
    ax = [
        fig.add_subplot(gs[0, 0]),
        fig.add_subplot(gs[0, 1]),
        fig.add_subplot(gs[1, :])
    ]

    # RGB image
    if "Sentinel-2" in dataset:
        rgb_bands = ['B4', 'B3', 'B2']
    else:
        rgb_bands = ['SR_B4', 'SR_B3', 'SR_B2']
    ax[0].imshow(np.dstack([image[band] for band in rgb_bands]), extent=_extent_km(image))
    ax[0].set_xlabel("Easting [km]")
    ax[0].set_ylabel("Northing [km]")

    # Classified image + SLA contour.
    # Bin edges centred on the class values map 0 -> colors[0] and 1 -> colors[1].
    colors = ["w", "#2C98CA"]
    bounds = [-0.5, 0.5, 1.5]
    cmap = matplotlib.colors.ListedColormap(colors)
    norm = matplotlib.colors.BoundaryNorm(bounds, cmap.N)
    ax[1].imshow(classified_image, cmap=cmap, norm=norm, extent=_extent_km(classified_image))
    if has_sla:
        X, Y = np.meshgrid(dem.x.data / 1e3, dem.y.data / 1e3)
        ax[1].contour(X, Y, dem.data, levels=[sla], colors='k', linestyles='solid')

    # Dummy artists for the legend, drawn then hidden by restoring the axis limits
    xmin, xmax = ax[1].get_xlim()
    ymin, ymax = ax[1].get_ylim()
    ax[1].plot(0, 0, 's', markersize=10, markerfacecolor=colors[0], markeredgecolor='k', markeredgewidth=0.5, label="No snow")
    ax[1].plot(0, 0, 's', markersize=10, markerfacecolor=colors[1], markeredgecolor='k', markeredgewidth=0.5, label="Snow")
    if has_sla:
        ax[1].plot([0, 1], [0, 1], '-k', linewidth=1.5, label="SLA")
    ax[1].set_xlim(xmin, xmax)
    ax[1].set_ylim(ymin, ymax)
    ax[1].legend(loc='upper left')
    ax[1].set_xlabel("Easting [km]")

    # Histograms of all elevations and snow elevations with lines for SLA metrics
    snow_dem = dem.where(classified_image == 1)
    bins = np.linspace(np.nanmin(dem.data), np.nanmax(dem.data), num=100)
    ax[2].hist(dem.data.ravel(), bins=bins, color='gray', alpha=1, label="All elevations")
    ax[2].hist(snow_dem.data.ravel(), bins=bins, color=colors[1], alpha=1, label="Snow elevations")
    sla_lines = [
        (sla_upper, 'dashed', "SLA$_{upper}$"),
        (sla, 'solid', "SLA"),
        (sla_lower, 'dotted', "SLA$_{lower}$"),
    ]
    for value, linestyle, name in sla_lines:
        # NaN means the SLA could not be estimated (e.g. no snow observed)
        if np.isfinite(value):
            ax[2].axvline(value, color='k', linestyle=linestyle, label=f"{name} = {int(value)} m")
    ax[2].legend(loc='upper right')
    ax[2].set_xlabel("Elevation [m]")
    ax[2].set_ylabel("Counts")

    fig.suptitle(f"{date} {dataset}")
    fig.tight_layout()

    # Save to file, or display if no path was given
    if out_file is not None:
        fig.savefig(out_file, dpi=300, bbox_inches='tight')
        print("Figure saved to:", out_file)
        plt.close(fig)
    else:
        plt.show()

## Run the workflow: query GEE for imagery, classify snow, calculate stats, plot

In [None]:
os.makedirs(out_folder, exist_ok=True)

# Output-name suffix recording the classification thresholds. It is identical for every
# dataset and date, so it is built once (format kept so existing outputs are still found).
settings_string = f"NDSI{ndsi_threshold}_Blue{blue_threshold if type(blue_threshold)==float else 0}"

# Run each dataset separately
# NOTE: Could also run for Sentinel-2_SR, but only available after 2019
for dataset in ["Sentinel-2_TOA", "Landsat"]:
    print(f"\n\n{dataset}\n--------------------")

    # Dataset-specific bands (RGB for plotting + NDSI for classifying) and pixel size in CRS units
    if "Sentinel" in dataset:
        download_bands = ["B4", "B3", "B2", "NDSI"]
        scale = 10
    else:
        download_bands = ["SR_B4", "SR_B3", "SR_B2", "NDSI"]
        scale = 30

    # Get daily date ranges
    date_ranges = split_date_range_daily(dataset, date_start, date_end, month_start, month_end)

    # Iterate over date ranges
    for d_start, d_end in tqdm(date_ranges):
        # Define outputs
        image_file = os.path.join(out_folder, f"{d_start}_{dataset}.tif")
        classified_image_file = os.path.join(out_folder, f"{d_start}_{dataset}_classified_{settings_string}.tif")
        snow_stats_file = os.path.join(out_folder, f"{d_start}_{dataset}_snow_stats_{settings_string}.csv")
        fig_file = os.path.join(out_folder, f"{d_start}_{dataset}_results_{settings_string}.png")

        # Skip dates whose results already exist (the figure is only required when plotting is enabled)
        if os.path.exists(snow_stats_file) and (not plot_results or os.path.exists(fig_file)):
            print(f"Files for {d_start} already exist. Skipping.")
            continue

        # Query GEE for imagery
        image_col = utils.query_gee_for_imagery(
            dataset = dataset,
            aoi = aoi,
            date_start = str(d_start), 
            date_end = str(d_end), 
            fill_portion = min_aoi_coverage, 
            mask_clouds = mask_clouds, 
            scale = None, 
            verbose = False
        )

        # Check if any images were found
        if image_col.size().getInfo() < 1:
            continue

        print(f"\n{d_start}")

        # Convert image to xarray.Dataset
        image = image_col.first().select(download_bands)
        image_xr = image.wx.to_xarray(region=aoi, scale=scale, crs=utm_crs)
        image_xr = image_xr.squeeze(drop=True)
        if download_images:
            image_xr.rio.to_raster(image_file)
            print(f"Image saved to: {image_file}")

        # Classify snow
        classified_image_xr = classify_snow(
            image_xr = image_xr, 
            ndsi_threshold = ndsi_threshold,
            blue_threshold = blue_threshold,
            out_file = classified_image_file if download_images else None
            )

        # Estimate snowline elevation
        snow_stats_df, dem_masked_xr = calculate_snowline_altitude(
            date = d_start, 
            dataset = dataset, 
            classified_image = classified_image_xr, 
            dem = dem_xr, 
            sla_percentile = sla_percentile,
            out_file = snow_stats_file
        )

        # Plot results (controlled by the plot_results setting)
        if plot_results:
            plot_snow_cover_stats(
                date = d_start,
                dataset = dataset,
                image = image_xr,
                classified_image = classified_image_xr,
                dem = dem_masked_xr,
                snow_stats_df = snow_stats_df,
                out_file = fig_file
            )

print("\nDone!")

## Optional: compile all snow cover stats into one CSV

In [None]:
from glob import glob

# Per-date stats files written by the workflow are named "<date>_<dataset>_snow_stats_<settings>.csv".
# Match that pattern instead of "*.csv" so the compiled output of a previous run
# ("compiled_snow_stats.csv") is not read back in and duplicated.
stats_files = sorted(glob(os.path.join(out_folder, "*_snow_stats_*.csv")))
print(f"Located {len(stats_files)} snow cover stats files")

compiled_stats_file = os.path.join(out_folder, "compiled_snow_stats.csv")
if stats_files:
    # compile in dataframe
    compiled_stats_df = pd.concat([pd.read_csv(file) for file in stats_files], ignore_index=True)

    # save to file
    compiled_stats_df.to_csv(compiled_stats_file, index=False)
    print(f"Compiled snow cover stats saved to: {compiled_stats_file}")
else:
    # pd.concat([]) would raise; nothing to compile yet
    print("No snow cover stats files to compile.")

In [None]:
# optional: remove individual files
# commented out to prevent accidents!

# for file in stats_files:
#     os.remove(file)