# Make figures

In [None]:
import os,glob
import matplotlib.pyplot as plt
import pandas as pd
import geopandas as gpd
import contextily as ctx
from shapely import wkt
from shapely.ops import unary_union
from shapely.geometry import Polygon, Point
import geoutils as gu
import numpy as np
from geopy.geocoders import Nominatim
from matplotlib.colors import LightSource
import string
import xdem
import seaborn as sns
import sys
import json
import rioxarray as rxr
import xarray as xr
import matplotlib
from tqdm.auto import tqdm
from mpl_toolkits.axes_grid1.anchored_artists import AnchoredSizeBar
import matplotlib.font_manager as fm
import rasterio as rio

# Import processing functions
# NOTE(review): machine-specific absolute path to the project code directory; it is
# appended to sys.path so that post_process_utils can be imported on the next line.
code_dir = '/Users/raineyaberle/Research/PhD/SnowDEMs/skysat_stereo_snow/skysat_stereo_snow'
sys.path.append(code_dir)
import post_process_utils as pprocess

# Default font settings for the figures below (individual cells override these)
plt.rcParams.update({'font.size':12, 'font.sans-serif': 'Arial'})

In [None]:
# Root directories: per-site input data, and the folder where figures are written
data_path = '/Volumes/LaCie/raineyaberle/Research/PhD/SkySat-Stereo/study-sites'
figures_path = '/Users/raineyaberle/Research/PhD/SnowDEMs/skysat_stereo_snow/figures'


def _site_path(site, *parts):
    """Return the path ``data_path/site/<parts...>`` for one study site."""
    return os.path.join(data_path, site, *parts)


# Site-specific inputs used by the figure cells below:
#   refdem_fn          reference DEM
#   SNOTEL_fn          SNOTEL time-series CSV
#   SNOTEL_site_fn     SNOTEL site-location CSV (its 'geometry' column holds WKT)
#   DEM_fns            SkySat DEMs
#   orthomosaic_fns    orthomosaics, listed in the same order as DEM_fns
#   satellites         satellite IDs, one entry per DEM in DEM_fns
#   site_name_display  human-readable site name used in figure labels
info_dict = {
    "MCS": {
        "refdem_fn": _site_path("MCS", "refdem", "MCS_REFDEM_WGS84.tif"),
        "SNOTEL_fn": _site_path("MCS", "snotel", "MCS_2020-01-01_2024-06-07_adj.csv"),
        "SNOTEL_site_fn": _site_path("MCS", "snotel", "MCS_SNOTEL_site_info.csv"),
        "DEM_fns": [
            _site_path("MCS", "20240420", "post_process", "coregAll_ba-u5m_MCS_20240420_DEM_GCPshift.tif"),
            _site_path("MCS", "20241003", "post_process", "coregAll_ba-u5m_MCS_20241003_DEM_GCPshift_slope-corrected.tif"),
        ],
        "orthomosaic_fns": [
            _site_path("MCS", "20240420", "coregAll_ba-u5m_MCS_20240420_orthomosaic.tif"),
            _site_path("MCS", "20241003", "coregAll_ba-u5m_MCS_20241003_orthomosaic.tif"),
        ],
        "satellites": ["s116", "s109"],
        "site_name_display": "Mores Creek Summit",
    },
    "JacksonPeak": {
        "refdem_fn": _site_path("JacksonPeak", "refdem", "USGS_LPC_ID_FEMAHQ_2018_D18_merged_filtered_UTM11_filled.tif"),
        "SNOTEL_fn": _site_path("JacksonPeak", "snotel", "JacksonPeak_2023-01-01_2024-06-07_adj.csv"),
        "SNOTEL_site_fn": _site_path("JacksonPeak", "snotel", "JacksonPeak_SNOTEL_site_info.csv"),
        "DEM_fns": [
            _site_path("JacksonPeak", "20240420", "post_process", "JacksonPeak_20240420_DEM_GCPshift.tif"),
        ],
        "orthomosaic_fns": [
            _site_path("JacksonPeak", "20240420", "JacksonPeak_20240420_orthomosaic.tif"),
        ],
        "satellites": ["s113"],
        "site_name_display": "Jackson Peak",
    },
    "Banner": {
        "refdem_fn": _site_path("Banner", "refdem", "Banner_REFDEM_WGS84.tif"),
        "SNOTEL_fn": _site_path("Banner", "snotel", "Banner_2020-01-01_2024-06-07_adj.csv"),
        "SNOTEL_site_fn": _site_path("Banner", "snotel", "Banner_SNOTEL_site_info.csv"),
        "DEM_fns": [
            _site_path("Banner", "20240419-1", "post_process", "Banner_20240419-1_DEM_GCPshift.tif"),
            _site_path("Banner", "20240419-2", "post_process", "Banner_20240419-2_DEM_GCPshift.tif"),
        ],
        "orthomosaic_fns": [
            _site_path("Banner", "20240419-1", "Banner_20240419-1_orthomosaic.tif"),
            _site_path("Banner", "20240419-2", "Banner_20240419-2_orthomosaic.tif"),
        ],
        "satellites": ["s112", "s108"],
        "site_name_display": "Banner Summit",
    },
}

# Function for clipping DEM rows and columns with no data
def clip_nodata(dem):
    """Crop a DEM to the smallest row/column window that still contains valid data.

    Parameters
    ----------
    dem : xdem.DEM
        DEM whose ``data`` array is scanned for finite values. When ``data`` is a
        masked array, masked cells count as invalid (``MaskedArray.any`` treats
        masked entries as False).

    Returns
    -------
    xdem.DEM
        New DEM holding only the bounding window of valid pixels, with its
        transform shifted to that window and the input CRS and nodata value kept.

    Raises
    ------
    ValueError
        If the DEM contains no valid (finite, unmasked) pixels.
    """
    dem_data = dem.data
    # Rows / columns that contain at least one valid pixel
    valid = np.isfinite(dem_data)
    valid_rows = np.any(valid, axis=1)
    valid_cols = np.any(valid, axis=0)
    row_idx = np.where(valid_rows)[0]
    col_idx = np.where(valid_cols)[0]
    if row_idx.size == 0 or col_idx.size == 0:
        # Without this guard the [0] / [-1] indexing below fails with an opaque IndexError
        raise ValueError('clip_nodata: DEM contains no valid (finite) data to clip to')
    # Inclusive index range of valid data
    row_min, row_max = row_idx[0], row_idx[-1]
    col_min, col_max = col_idx[0], col_idx[-1]
    # Window size from the inclusive bounds
    height = row_max - row_min + 1
    width = col_max - col_min + 1
    # Clip the DEM data
    clipped_data = dem_data[row_min:row_max + 1, col_min:col_max + 1]
    # Shift the geotransform so the clipped array keeps its map position
    window = rio.windows.Window(col_min, row_min, width, height)
    clipped_transform = rio.windows.transform(window, dem.transform)
    # Create a new xdem.DEM with the clipped data
    clipped_dem = xdem.DEM.from_array(clipped_data, transform=clipped_transform, crs=dem.crs, nodata=dem.nodata)

    return clipped_dem

## Test (Banner 2024-04-19b): full workflow vs. coregistration to stable surfaces only, without bundle adjustment

In [None]:
# Load inputs: full-workflow DEM and a DEM coregistered to stable surfaces only
dem_fn = os.path.join(data_path, 'Banner', '20240419-2', 'post_process', 'Banner_20240419-2_DEM_GCPshift.tif')
dem_mod_fn = os.path.join(data_path, 'Banner', '20240419-2', 'post_process', 'coreg-ss-individual_Banner_20240419-2_DEM_GCPshift.tif')
refdem_fn = os.path.join(data_path, 'Banner', 'refdem', 'Banner_REFDEM_WGS84.tif')
ss_mask_fn = os.path.join(data_path, 'Banner', '20240419-2', 'land_cover_masks', 'stable_surfaces_mask.tif')
snow_mask_fn = os.path.join(data_path, 'Banner', '20240419-2', 'land_cover_masks', 'snow_mask.tif')

# Everything is regridded onto the full-workflow DEM's grid
dem = xdem.DEM(dem_fn)
dem_mod = xdem.DEM(dem_mod_fn).reproject(dem)
refdem = xdem.DEM(refdem_fn).reproject(dem)
ss_mask = gu.Raster(ss_mask_fn).reproject(dem)
ss_mask = (ss_mask==1)
snow_mask = gu.Raster(snow_mask_fn).reproject(dem)
snow_mask = (snow_mask==1)

# Calculate dDEMs
ddem = dem - refdem
ddem_mod = dem_mod - refdem

# Mask dDEMs to stable surfaces and summarize the residuals there
ddem_ss = ddem.copy()
ddem_ss.set_mask(~ss_mask)
ddem_ss_median, ddem_ss_nmad = np.ma.median(ddem_ss.data), xdem.spatialstats.nmad(ddem_ss)
ddem_mod_ss = ddem_mod.copy()
ddem_mod_ss.set_mask(~ss_mask)
ddem_mod_ss_median, ddem_mod_ss_nmad = np.ma.median(ddem_mod_ss.data), xdem.spatialstats.nmad(ddem_mod_ss)

# Create snow depth maps (dDEM restricted to snow-covered pixels)
sd = ddem.copy()
sd.set_mask(~snow_mask)
sd_mod = ddem_mod.copy()
sd_mod.set_mask(~snow_mask)

# Plot: left column = full workflow, right column = stable-surface-only coregistration
fig, ax = plt.subplots(4, 2, figsize=(8, 12))
ddem.plot(cmap='coolwarm_r', vmin=-5, vmax=5, ax=ax[0,0])
ddem_mod.plot(cmap='coolwarm_r', vmin=-5, vmax=5, ax=ax[0,1])
ddem_ss.plot(cmap='coolwarm_r', vmin=-5, vmax=5, ax=ax[1,0])
ddem_mod_ss.plot(cmap='coolwarm_r', vmin=-5, vmax=5, ax=ax[1,1])
# Stable-surface residual histograms: each column bins its OWN dDEM.
# .compressed() passes only unmasked pixels, because np.histogram converts its
# input with np.asarray (dropping the mask) and would otherwise bin the values
# hidden by set_mask.
hist_bins = np.linspace(-5, 5, 50)
for hist_ax, ddem_ss_i, median_i, nmad_i in [(ax[2,0], ddem_ss, ddem_ss_median, ddem_ss_nmad),
                                             (ax[2,1], ddem_mod_ss, ddem_mod_ss_median, ddem_mod_ss_nmad)]:
    hist_ax.hist(ddem_ss_i.data.compressed(), bins=hist_bins, facecolor='gray', edgecolor='k', linewidth=0.5)
    hist_ax.text(0.05, 0.8, f"Median = {np.round(float(median_i),3)} m\nNMAD={np.round(float(nmad_i),3)} m",
                 transform=hist_ax.transAxes, ha='left')
sd.plot(cmap='YlGnBu', vmin=0, vmax=5, ax=ax[3,0])
sd_mod.plot(cmap='YlGnBu', vmin=0, vmax=5, ax=ax[3,1])

ax[0,0].set_title('Full workflow')
ax[0,1].set_title('Coregister to SS only')
ax[0,0].set_ylabel('dDEM', rotation=0, ha='right')
ax[1,0].set_ylabel('dDEM \nstable surfaces', rotation=0, ha='right')
ax[2,0].set_ylabel('dDEM \nstable surfaces', rotation=0, ha='right')
ax[3,0].set_ylabel('Snow depth map', rotation=0, ha='right')

# Hide map coordinates on the map rows (0, 1, 3); the histogram row keeps its axes
for axis in list(ax.ravel()[0:4]) + list(ax.ravel()[6:]):
    axis.set_xticks([])
    axis.set_yticks([])

# Save before plt.show(), matching the other figure cells
fig_fn = os.path.join(figures_path, 'Banner_20240419-2_tests.png')
fig.savefig(fig_fn, dpi=300, bbox_inches='tight')
print('Figure saved to file:', fig_fn)

plt.show()

## Figure 1. Study sites map

In [None]:
# Set up figure: 2x2 grid; panel a (ax[0]) = overview map, panels b-d = one study site each
crs = "EPSG:32611"
fontsize = 12
plt.rcParams.update({'font.size': fontsize, 'font.sans-serif': "Arial"})
fig, ax = plt.subplots(2, 2, figsize=(10, 10))
ax = ax.ravel()

# Idaho Highway 21, reprojected to the map CRS and drawn on every panel as a
# black line with a white halo; only the very first segment gets a legend label
hwy21_fn = os.path.join(data_path, '..', 'ITD_Functional_Class', 'ITD_HWY_21.shp')
hwy21 = gpd.read_file(hwy21_fn).to_crs(crs)
hwy21_segments = list(hwy21.geometry[0].geoms)
for panel_idx, panel in enumerate(ax):
    for seg_idx, segment in enumerate(hwy21_segments):
        seg_x, seg_y = segment.coords.xy
        is_first_segment = panel_idx == 0 and seg_idx == 0
        panel.plot(seg_x, seg_y, color='w', linewidth=2)
        panel.plot(seg_x, seg_y, color='k', linewidth=1,
                   label='ID-21' if is_first_segment else '_nolegend')
        
def bounds_to_polygon(bounds):
    """Return a rectangular Polygon for an object exposing left/bottom/right/top.

    Vertices run bottom-left -> bottom-right -> top-right -> top-left, and the
    ring is explicitly closed by repeating the first corner.
    """
    left, bottom, right, top = bounds.left, bounds.bottom, bounds.right, bounds.top
    corners = [[left, bottom], [right, bottom], [right, top], [left, top]]
    return Polygon(corners + [corners[0]])
    
# iterate over site names: panel labels b-d go to the sites, 'a' is the overview map
labels = ['a', 'b', 'c', 'd']
for i, site_name in enumerate(list(info_dict.keys())):
    # basemap
    ctx.add_basemap(ax=ax[i+1], source=ctx.providers.USGS.USImagery, attribution=False, crs=crs, zoom=12)

    # Reference DEM, regridded to 10 m in the map CRS
    refdem_fn = info_dict[site_name]['refdem_fn']
    refdem = gu.Raster(refdem_fn, load_data=True).reproject(crs=crs, res=10)
    refdem_bounds = refdem.bounds
    refdem_extent = (refdem_bounds.left, refdem_bounds.right, refdem_bounds.bottom, refdem_bounds.top)
    # initialize axes bounds for later (grown below to include the image footprints)
    xmin, ymin, xmax, ymax = refdem_bounds.left, refdem_bounds.bottom, refdem_bounds.right, refdem_bounds.top
    # Shaded relief with semi-transparent elevation colors on top
    ls = LightSource(azdeg=315, altdeg=45)
    hs = ls.hillshade(refdem.data, vert_exag=5)
    ax[i+1].imshow(hs, cmap='Greys_r', extent=refdem_extent)
    im = ax[i+1].imshow(refdem.data, cmap='terrain', alpha=0.7, clim=(1500, 3000), extent=refdem_extent)
    # outline and label of this site on the overview map, padded by 2 km
    outline_pad = 2e3
    ax[0].plot([refdem_bounds.left-outline_pad, refdem_bounds.right+outline_pad, refdem_bounds.right+outline_pad,
                refdem_bounds.left-outline_pad, refdem_bounds.left-outline_pad],
               [refdem_bounds.bottom-outline_pad, refdem_bounds.bottom-outline_pad, refdem_bounds.top+outline_pad,
                refdem_bounds.top+outline_pad, refdem_bounds.bottom-outline_pad],
               '-w', linewidth=1, label='_nolegend')
    ax[0].text((refdem_bounds.left + refdem_bounds.right)/2, (refdem_bounds.bottom + refdem_bounds.top)/2,
               labels[i+1], fontweight='bold', color='w', fontsize=fontsize+2, ha='center', va='center')

    # Image footprints: union of the SkySat scene footprints behind each DEM
    dem_fns = info_dict[site_name]['DEM_fns']
    for dem_fn in dem_fns:
        im_dir = os.path.join(dem_fn.split('post_process')[0], 'SkySatScene')
        meta_fns = sorted(glob.glob(os.path.join(im_dir, '*_metadata.json')))
        polys = []
        for meta_fn in meta_fns:
            with open(meta_fn, 'r') as f:
                meta = json.load(f)
            # Append the Polygon itself (not a one-element list) so unary_union
            # receives a flat sequence of geometries
            polys.append(Polygon(meta['geometry']['coordinates'][0]))
        if not polys:
            print(f'No *_metadata.json files found in {im_dir}, skipping footprint')
            continue
        polys_merged = unary_union(polys)
        polys_gdf = gpd.GeoDataFrame(geometry=[polys_merged], crs='EPSG:4326')
        polys_gdf = polys_gdf.to_crs(crs)
        footprint = polys_gdf.geometry[0]
        # Non-overlapping scenes union to a MultiPolygon, which has no .exterior:
        # outline every part, labelling only the first for the legend
        footprint_parts = list(footprint.geoms) if footprint.geom_type == 'MultiPolygon' else [footprint]
        for k, part in enumerate(footprint_parts):
            ax[i+1].plot(*part.exterior.coords.xy, color='#fe9929', linewidth=2,
                         label='SkySat images footprint' if k == 0 else '_nolegend')
        # adjust axes bounds so the whole footprint is visible
        fp_xmin, fp_ymin, fp_xmax, fp_ymax = footprint.bounds
        xmin, ymin = min(xmin, fp_xmin), min(ymin, fp_ymin)
        xmax, ymax = max(xmax, fp_xmax), max(ymax, fp_ymax)

    # SNOTEL site location (WKT geometry in lon/lat)
    snotel_fn = info_dict[site_name]['SNOTEL_site_fn']
    snotel = pd.read_csv(snotel_fn)
    snotel['geometry'] = snotel['geometry'].apply(wkt.loads)
    snotel_gdf = gpd.GeoDataFrame(snotel, crs='EPSG:4326')
    snotel_gdf = snotel_gdf.to_crs(refdem.crs)
    ax[i+1].plot(*snotel_gdf['geometry'].values[0].coords.xy, '*', markersize=15,
                 markerfacecolor='b', markeredgecolor='w', linewidth=0.3, label="SNOTEL site")

    # scalebar
    scalebar = AnchoredSizeBar(ax[i+1].transData,
                               2e3, '2 km', 'lower right',
                               pad=0.2,
                               color='white',
                               sep=7,
                               frameon=False,
                               fill_bar=True,
                               fontproperties=fm.FontProperties(size=14, weight='bold'))
    ax[i+1].add_artist(scalebar)

    # title
    ax[i+1].set_title(f"{labels[i+1]}) {info_dict[site_name]['site_name_display']}")
    # adjust axes: no ticks, 1 km margin around DEM + footprints
    ax[i+1].set_xticks([])
    ax[i+1].set_yticks([])
    ax[i+1].set_xlim(xmin-1e3, xmax+1e3)
    ax[i+1].set_ylim(ymin-1e3, ymax+1e3)
    # adjust line width of scale bar relative to the final y-range
    yrange = ax[i+1].get_ylim()[1] - ax[i+1].get_ylim()[0]
    size_vertical = 0.0075 * yrange
    scalebar.size_bar.get_children()[0].set_height(size_vertical)

# legend: highway entry from ax[0], footprint/SNOTEL entries from ax[2]
handles0, labels0 = ax[0].get_legend_handles_labels()
handles2, labels2 = ax[2].get_legend_handles_labels()
fig.legend(handles0+handles2, labels0+labels2, loc='upper center',
           ncols=len(labels0+labels2), frameon=False, bbox_to_anchor=[0.4, 0.75, 0.2, 0.2])

# colorbar (uses the elevation image from the last site; all share clim 1500-3000)
fig.subplots_adjust(right=0.87, hspace=0.25)
cbar_ax = fig.add_axes([0.9, 0.25, 0.03, 0.5])
fig.colorbar(im, cax=cbar_ax, shrink=0.8, label='Elevation [m]')

# ax[0] adjustments: tick labels in km
ax[0].set_title('a) All sites')
ax[0].set_xticks(ax[0].get_xticks())
ax[0].set_xticklabels(np.divide(ax[0].get_xticks(), 1e3).astype(int).astype(str))
ax[0].set_yticks(ax[0].get_yticks())
ax[0].set_yticklabels(np.divide(ax[0].get_yticks(), 1e3).astype(int).astype(str))
ax[0].set_xlabel('Easting [km]')
ax[0].set_ylabel('Northing [km]')
xmin, xmax = 555e3, 675e3
ymin, ymax = 4818e3, 4920e3
ax[0].set_xlim(xmin, xmax)
ax[0].set_ylim(ymin, ymax)

# Plot Idaho locator map with the overview extent highlighted
idaho_fn = '/Users/raineyaberle/Research/PhD/GIS_data/US_State_Boundaries/Idaho_State_Boundaries.gpkg'
idaho = gpd.read_file(idaho_fn)
idaho = idaho.to_crs(crs)
ax2 = fig.add_axes([0.11, 0.73, 0.15, 0.15])
ax2.set_axis_off()
idaho.plot(facecolor='#bdbdbd', edgecolor='k', linewidth=0.5, ax=ax2)
xmin, xmax = ax[0].get_xlim()
ymin, ymax = ax[0].get_ylim()
ax2.fill_between([xmin, xmax], [ymin, ymin], [ymax, ymax], color='g')

# Add city names (geocoded online with Nominatim). Records are collected in a list
# instead of concatenating onto an empty DataFrame, and a failed lookup (geocode()
# returns None) skips that label instead of crashing.
geolocator = Nominatim(user_agent="My-Map")
cities = ["Boise", "Idaho City", "Stanley, Custer County, Idaho"]
city_records = []
for city in cities:
    location = geolocator.geocode(city, addressdetails=True)
    if location is None:
        print(f'Could not geocode "{city}", skipping its label')
        continue
    city_records.append({'city': city, 'geometry': Point(location.longitude, location.latitude)})
if city_records:
    gdf = gpd.GeoDataFrame(city_records, geometry='geometry', crs='EPSG:4326')
    gdf = gdf.to_crs(crs).reset_index(drop=True)
    for city, point in zip(gdf['city'], gdf.geometry):
        ax[0].text(point.x, point.y, city.split(',')[0], fontsize=fontsize, color='w',
                   fontweight='bold', style='italic', ha='center')

# basemap (use the figure CRS explicitly rather than the last loop's refdem.crs)
ctx.add_basemap(ax=ax[0], source=ctx.providers.USGS.USImagery, attribution=False, crs=crs)

# Save figure before plt.show(), matching the other figure cells
fig_fn = os.path.join(figures_path, "study_sites.png")
fig.savefig(fig_fn, dpi=300, bbox_inches='tight')
print('Figure saved to file:', fig_fn)

plt.show()


In [None]:
# Plot elevation, slope, and aspect histograms of each site's reference DEM (regridded to 2 m)
for site_name in list(info_dict.keys()):
    refdem_fn = info_dict[site_name]['refdem_fn']
    refdem = xdem.DEM(refdem_fn).reproject(res=2)
    slope = xdem.terrain.slope(refdem)
    aspect = xdem.terrain.aspect(refdem)

    fig, ax = plt.subplots(1, 3, figsize=(10,5))
    # .compressed() passes only unmasked (valid) pixels: np.histogram converts its
    # input with np.asarray, which drops the mask and would bin masked fill values
    ax[0].hist(refdem.data.compressed(), bins=30)
    ax[0].set_xlabel('Elevation [m]')
    ax[1].hist(slope.data.compressed(), bins=np.arange(0,61,2))
    ax[1].set_xlabel('Slope [degrees]')
    ax[2].hist(aspect.data.compressed(), bins=np.arange(0, 361, 30))
    ax[2].set_xlabel('Aspect [degrees]')
    fig.suptitle(site_name)
    fig.tight_layout()
    plt.show()

## Figure 2. Simple error distribution plot for the workflow figure

In [None]:
# Synthetic standard-normal "errors", used only to illustrate the summary statistics
np.random.seed(42)
errors = np.random.normal(loc=0, scale=1, size=1000)

# Summary statistics shown on the plot
median = np.median(errors)
abs_dev = np.abs(errors - median)
nmad = 1.4826 * np.median(abs_dev)  # normalized median absolute deviation
quantile_25, quantile_75 = np.percentile(errors, 25), np.percentile(errors, 75)

# Histogram with large fonts; note this rcParams change persists into later cells
plt.rcParams.update({'font.size': 20, 'font.sans-serif': "Arial"})
fig, ax = plt.subplots(figsize=(7, 5))
ax.hist(errors, bins=30, alpha=0.7, color='gray')

lw = 3
# Median (solid) and 25th/75th percentiles (dashed, one shared legend entry)
ax.axvline(median, color='k', linestyle='-', linewidth=lw, label='Median')
for q_value, q_label in ((quantile_25, '_nolegend'), (quantile_75, 'Quantiles')):
    ax.axvline(q_value, color='k', linestyle='--', linewidth=lw, label=q_label)
# Zero-width line whose only purpose is a text-only "NMAD" legend entry
ax.axvline(median, color='w', linestyle='-', linewidth=0, label='NMAD')

# Minimal axes: labels, legend, thick left/bottom spines, a single x tick at 0
ax.set_xlabel('Error')
ax.set_ylabel('Frequency')
ax.legend(loc='best', frameon=False)
ax.spines[['top', 'right']].set_visible(False)
ax.spines[['left', 'bottom']].set_linewidth(lw)
ax.set_xticks([0])
ax.set_yticks([])
fig.tight_layout()
plt.show()

# Save figure
fig_fn = os.path.join(figures_path, "error_distribution_example.png")
fig.savefig(fig_fn, dpi=300, bbox_inches='tight')
print('Figure saved to file:', fig_fn)

## Figure 3. Land cover masks

In [None]:
# Define colors for land cover types; the key order is also the plotting/legend order
colors_dict = {'stable_surfaces': '#C3C3C3',
               'trees': '#167700',
               'snow': '#55F5FF'}


def load_raster(fn):
    """Open a raster as a squeezed DataArray with 0 and |value| > 1e10 replaced by NaN.

    NaN cells are drawn with the colormap's transparent "bad" color, so only the
    nonzero mask pixels show up. NOTE(review): the 1e10 threshold presumably
    catches float nodata fill values -- confirm against the mask files.
    """
    raster = rxr.open_rasterio(fn).squeeze()
    raster = xr.where((raster==0) | (np.abs(raster) > 1e10), np.nan, raster)
    return raster


# Set up figure: one row per DEM, columns = RGB mosaic | land cover masks
fig, ax = plt.subplots(5,2, figsize=(8,14))

# Iterate over sites
dem_count = 0
for site_name in list(info_dict.keys()):
    dem_fns = info_dict[site_name]['DEM_fns']
    site_name_display = info_dict[site_name]["site_name_display"]

    # Iterate over DEMs
    for dem_fn in dem_fns:
        # Grab date from file name (e.g. "..._MCS_20240420_DEM..." -> "20240420")
        date = os.path.basename(dem_fn).split(site_name)[1].split('_')[1]
        if date=='20240419-1':
            date_display = '2024-04-19a'
        elif date=='20240419-2':
            date_display = '2024-04-19b'
        else:
            date_display = f"{date[0:4]}-{date[4:6]}-{date[6:]}"

        # Load input files
        rgb_dir = os.path.join(data_path, site_name, date)
        rgb_fns = glob.glob(os.path.join(rgb_dir, '*4band*.tif'))
        if not rgb_fns:
            raise FileNotFoundError(f"No '*4band*.tif' image found in {rgb_dir}")
        rgb_fn = rgb_fns[0]
        rgb = rxr.open_rasterio(rgb_fn)
        crs = rgb.rio.crs
        # rgb = rgb.rio.reproject(resolution=30, dst_crs=crs)
        # Non-positive values -> NaN; the rest divided by 1e4
        # (NOTE(review): presumably scaled reflectance -- confirm)
        rgb = xr.where(rgb <= 0, np.nan, rgb / 1e4)
        snow_mask_fn = os.path.join(data_path, site_name, date, 'land_cover_masks', 'snow_mask.tif')
        snow_mask = load_raster(snow_mask_fn)
        trees_mask_fn = os.path.join(data_path, site_name, date, 'land_cover_masks', 'trees_mask.tif')
        trees_mask = load_raster(trees_mask_fn)
        ss_mask_fn = os.path.join(data_path, site_name, date, 'land_cover_masks', 'stable_surfaces_mask.tif')
        ss_mask = load_raster(ss_mask_fn)

        # Plot
        # RGB image: bands 3-2-1 stacked as R-G-B and brightened 2.5x
        # (NOTE(review): assumes blue, green, red, NIR band order -- confirm)
        ax[dem_count,0].imshow(np.dstack([rgb.data[2]*2.5, rgb.data[1]*2.5, rgb.data[0]*2.5]),
                               extent=(np.min(rgb.x), np.max(rgb.x), np.min(rgb.y), np.max(rgb.y)))
        ax[dem_count,0].set_ylabel(f"{site_name_display}\n{date_display}", rotation=0, ha='right')
        ax[dem_count,0].set_xticks([])
        ax[dem_count,0].set_yticks([])
        # Land cover masks: each drawn with a transparent/solid two-color colormap
        for mask, label in zip([ss_mask, trees_mask, snow_mask], list(colors_dict.keys())):
            cmap = matplotlib.colors.ListedColormap([(0,0,0,0), matplotlib.colors.to_rgb(colors_dict[label])])
            ax[dem_count,1].imshow(mask, cmap=cmap, clim=(0,1),
                                   extent=(np.min(mask.x), np.max(mask.x), np.min(mask.y), np.max(mask.y)))
            # dummy point for legend (axis limits are reset just below)
            ax[dem_count,1].plot(0, 0, 's', color=colors_dict[label], markersize=10, label=label)
        ax[dem_count,1].set_xlim(ax[dem_count,0].get_xlim())
        ax[dem_count,1].set_ylim(ax[dem_count,0].get_ylim())
        ax[dem_count,1].set_yticks([])
        ax[dem_count,1].set_xticks([])

        dem_count += 1

# Add titles
ax[0,0].set_title('RGB mosaic')
ax[0,1].set_title('Land cover masks')
handles, labels = ax[-1,1].get_legend_handles_labels()
labels = ["Stable surfaces", "Trees", "Snow"]

# Add scalebars
for axis in ax.ravel():
    scalebar = AnchoredSizeBar(axis.transData,
                               2e3, '2 km', 'lower right',
                               pad=0.2,
                               color='k',
                               sep=7,
                               frameon=True,
                               fill_bar=True,
                               fontproperties=fm.FontProperties(size=14, weight='bold'))
    axis.add_artist(scalebar)
    # adjust line width of scale bar relative to the y-range
    yrange = axis.get_ylim()[1] - axis.get_ylim()[0]
    size_vertical = 0.0075 * yrange
    scalebar.size_bar.get_children()[0].set_height(size_vertical)

# Add legend
ax[0,1].legend(handles, labels, loc='upper center', frameon=False, ncols=3,
               bbox_to_anchor=[0.35, 1.1, 0.2, 0.2], handletextpad=0.3, columnspacing=1)

fig.tight_layout()

# Save figure without panel labels
# fig_fn = os.path.join(figures_path, 'land_cover_masks_no-panel-labels.png')
# fig.savefig(fig_fn, dpi=300, bbox_inches='tight')
# print('Figure saved to file:', fig_fn)

# Save figure with panel labels
text_labels = list(string.ascii_lowercase)
for i, axis in enumerate(ax.ravel()):
    axis.text(0.05, 0.85, text_labels[i], transform=axis.transAxes,
              fontweight='bold', bbox=dict(facecolor='white', edgecolor='None'))
fig_fn = os.path.join(figures_path, 'land_cover_masks.png')
fig.savefig(fig_fn, dpi=300, bbox_inches='tight')
print('Figure saved to file:', fig_fn)

plt.show()


## Figure 4. Slope bias correction

In [None]:
site_name = "MCS"
date = "20241003"

# Load input files. The slope-corrected DEM is selected by `date` from
# info_dict[site_name] (instead of a hard-coded site key and list index), and the
# uncorrected DEM is the same file name without the '_slope-corrected' suffix.
refdem_fn = info_dict[site_name]['refdem_fn']
dem_slopecorr_fns = [fn for fn in info_dict[site_name]['DEM_fns'] if f"_{date}_" in os.path.basename(fn)]
if not dem_slopecorr_fns:
    raise ValueError(f"No DEM dated {date} listed in info_dict['{site_name}']['DEM_fns']")
dem_slopecorr_fn = dem_slopecorr_fns[0]
dem_fn = dem_slopecorr_fn.replace('_slope-corrected','')
ss_mask_fn = os.path.join(data_path, site_name, date, 'land_cover_masks', 'stable_surfaces_mask.tif')

# Everything is regridded onto the reference-DEM slope raster
slope = xdem.DEM(refdem_fn.replace('.tif', '_SLOPE.tif'))
refdem = xdem.DEM(refdem_fn).reproject(slope)
dem = xdem.DEM(dem_fn).reproject(slope)
dem_slopecorr = xdem.DEM(dem_slopecorr_fn).reproject(slope)
ss_mask = gu.Raster(ss_mask_fn, load_data=True).reproject(slope)
ss_mask = (ss_mask==1)

# Calculate dDEMs
ddem = dem - refdem
ddem_slopecorr = dem_slopecorr - refdem

# Calculate stable surface errors
ddem_ss = ddem.copy()
ddem_ss.set_mask(~ss_mask)
ddem_slopecorr_ss = ddem_slopecorr.copy()
ddem_slopecorr_ss.set_mask(~ss_mask)

# Save in dataframe; masked cells come through as NaN and are dropped below,
# together with slopes outside the 0-40 degree bins
df = pd.DataFrame({'slope': slope.data.ravel(),
                   'dDEM': ddem_ss.data.ravel(),
                   'dDEM_corr': ddem_slopecorr_ss.data.ravel()})
bin_edges = np.linspace(0, 40, 21)
# include_lowest=True keeps perfectly flat pixels (slope == 0) in the first bin
df['slope_bin'] = pd.cut(df['slope'], bins=bin_edges, include_lowest=True)
df.dropna(inplace=True)
df.reset_index(drop=True, inplace=True)

# Plot results: dDEM maps on top, residual-vs-slope boxplots below
fig, ax = plt.subplots(2,2,figsize=(8,8), gridspec_kw={'height_ratios':[2,1]})
ax = ax.flatten()
ddem.plot(cmap='coolwarm_r', vmin=-5, vmax=5, add_cbar=False, ax=ax[0])
ddem_slopecorr.plot(cmap='coolwarm_r', vmin=-5, vmax=5, cbar_title='Elevation residual [m]', ax=ax[1])
for axis in [ax[0], ax[1]]:
    axis.set_xticks([])
    axis.set_yticks([])
# histograms of slope counts (twin axis, scaled to the lower quarter of the panel)
bin_counts = df['slope_bin'].value_counts(sort=False)
for axis, col in zip([ax[2], ax[3]], ['dDEM', 'dDEM_corr']):
    ax2 = axis.twinx()
    ax2.bar(list(bin_counts.index.astype(str)), bin_counts.values, color='k', alpha=0.5, width=1.0)
    ax2.set_ylim(0, np.nanmax(bin_counts.values)*4)
    ax2.set_yticks([])
    # boxplots
    sns.boxplot(data=df, x='slope_bin', y=col, color='gray', showfliers=False,
                boxprops={'edgecolor':'k'}, medianprops={'color':'k', 'linewidth':1},
                whiskerprops={'color': 'k', 'linewidth':1}, ax=axis)
    axis.set_ylim(-7,7)
    axis.axhline(0, color='k', linewidth=1)
    axis.set_xlabel('Slope [degrees]')
    axis.set_xticks(axis.get_xticks())
    # One label per bin (20 bins of 2 degrees); only every 10 degrees is annotated
    axis.set_xticklabels(['0'] + ['']*4 + ['10'] + ['']*4 + ['20'] + ['']*4 + ['30'] + ['']*3 + ['  40'])

# add panel labels
text_labels = list(string.ascii_lowercase)
for i, axis in enumerate(ax.ravel()):
    axis.text(0.05, 0.9, text_labels[i], transform=axis.transAxes,
              fontweight='bold', bbox=dict(facecolor='white', edgecolor='None'))

# add titles
ax[0].set_title('Before slope bias correction')
ax[1].set_title('After slope bias correction')
ax[2].set_ylabel('Elevation residual [m]')
ax[3].set_ylabel('')

fig.tight_layout()

# save figure before plt.show(), matching the other figure cells
fig_fn = os.path.join(figures_path, 'slope_bias_correction.png')
fig.savefig(fig_fn, dpi=300, bbox_inches='tight')
print('Figure saved to file:', fig_fn)

plt.show()


## Figure 5. Stable surface errors: maps and histograms

In [None]:
# Set up figure: one row per DEM; columns = DEM | dDEM | stable-surface residual histogram
plt.rcParams.update({'font.size': 12, 'font.sans-serif':'Arial'})
fig, ax = plt.subplots(5, 3, figsize=(10,12), gridspec_kw={'width_ratios': [1, 1, 1.5]})
fig.subplots_adjust(bottom=0.1)
bins = np.arange(-5, 5.1, step=0.2)

# Iterate over sites
dem_count = 0
for site_name in list(info_dict.keys()):
    site_name_display = info_dict[site_name]['site_name_display']
    print(site_name_display)
    dem_fns = info_dict[site_name]['DEM_fns']
    # Load reference DEM
    refdem_fn = info_dict[site_name]['refdem_fn']
    refdem = xdem.DEM(refdem_fn)

    # Iterate over DEMs
    for dem_fn in dem_fns:
        # Grab date from file name (e.g. "..._MCS_20240420_DEM..." -> "20240420")
        date = os.path.basename(dem_fn).split(site_name)[1].split('_')[1]
        if date=='20240419-1':
            date_disp = '2024-04-19a'
        elif date=='20240419-2':
            date_disp = '2024-04-19b'
        else:
            date_disp = f"{date[0:4]}-{date[4:6]}-{date[6:]}"

        ### Load input files
        # DEM and its residuals against the reference DEM on the DEM's grid
        dem = xdem.DEM(dem_fn)
        refdem_reproj = refdem.reproject(dem)
        ddem = dem - refdem_reproj

        # Load stable surfaces mask
        ss_mask_fn = os.path.join(os.path.dirname(dem_fn), '..', 'land_cover_masks', 'stable_surfaces_mask.tif')
        ss_mask = gu.Raster(ss_mask_fn).reproject(dem)
        ss_mask = (ss_mask == 1)
        # Stable-surface residuals as a plain 1-D array of valid values.
        # np.ma.compressed drops masked (nodata) pixels; taking `.data` of a masked
        # array would keep their underlying fill values. NaNs are removed as well.
        ddem_ss = np.ma.compressed(ddem[ss_mask])
        ddem_ss = ddem_ss[np.isfinite(ddem_ss)]

        # Calculate stable surface stats (NaN when no stable-surface pixels remain)
        if ddem_ss.size:
            ddem_ss_median, ddem_ss_nmad = np.median(ddem_ss), xdem.spatialstats.nmad(ddem_ss)
            ddem_ss_p25, ddem_ss_p75 = np.percentile(ddem_ss, 25), np.percentile(ddem_ss, 75)
        else:
            print(f'Warning: no valid stable-surface pixels for {site_name} {date}')
            ddem_ss_median = ddem_ss_nmad = ddem_ss_p25 = ddem_ss_p75 = np.nan

        ### Plot
        # Shaded relief with semi-transparent elevation colors on top
        ls = LightSource(azdeg=315, altdeg=45)
        hs = ls.hillshade(dem.data, vert_exag=5)
        ax[dem_count,0].imshow(hs, cmap='Greys_r',
                               extent=(dem.bounds.left, dem.bounds.right, dem.bounds.bottom, dem.bounds.top))
        hs_im = ax[dem_count,0].imshow(dem.data, cmap='terrain', alpha=0.7, clim=(1500, 3000),
                                       extent=(dem.bounds.left, dem.bounds.right, dem.bounds.bottom, dem.bounds.top))
        ax[dem_count,0].set_ylabel(f"{site_name_display}\n{date_disp}", rotation=0, ha='right')
        # dDEM
        ddem_im = ax[dem_count,1].imshow(ddem.data, cmap='coolwarm_r', vmin=-5, vmax=5,
                                         extent=(ddem.bounds.left, ddem.bounds.right, ddem.bounds.bottom, ddem.bounds.top))
        # Elevation residuals on stable surfaces
        ax[dem_count,2].hist(ddem_ss, color='gray', alpha=0.5, edgecolor='k', linewidth=0.5, bins=bins)
        ax[dem_count,2].axvline(0, color='k', linewidth=1)
        ax[dem_count,2].set_xlim(np.min(bins),np.max(bins))
        ax[dem_count,2].set_yticks([])
        ax[dem_count,2].set_ylabel('Frequency')
        # ax[dem_count,2].text(0.02, 0.8, f"Median: {np.round(float(ddem_ss_median),2)} m\nNMAD: {np.round(float(ddem_ss_nmad),2)} m",
        #                      transform=ax[dem_count,2].transAxes)

        dem_count += 1

# add titles
ax[0,0].set_title('DEM')
ax[0,1].set_title('dDEM')
ax[0,2].set_title('Stable surface errors')
ax[-1,2].set_xlabel('Elevation residual [m]')
# add colorbars below the first two columns
cax = fig.add_axes([ax[-1,0].get_position().x0, 0.06, ax[-1,0].get_position().width, 0.02])
fig.colorbar(hs_im, cax=cax, orientation='horizontal', label='Elevation [m]')
cax = fig.add_axes([ax[-1,1].get_position().x0, 0.06, ax[-1,1].get_position().width, 0.02])
fig.colorbar(ddem_im, cax=cax, orientation='horizontal', label='Elevation residual [m]')
# remove coordinates from maps
for axis in ax[:, 0:2].ravel():
    axis.set_xticks([])
    axis.set_yticks([])

# Save figure without panel labels
fig_fn = os.path.join(figures_path, 'dDEMs_no-panel-labels.png')
fig.savefig(fig_fn, dpi=300, bbox_inches='tight')
print('Figure saved to file:', fig_fn)

# Save figure with panel labels
text_labels = list(string.ascii_lowercase)
for i, axis in enumerate(ax.ravel()):
    axis.text(0.05, 0.85, text_labels[i], transform=axis.transAxes,
              fontweight='bold', bbox=dict(facecolor='white', edgecolor='None'))
fig_fn = os.path.join(figures_path, 'dDEMs.png')
fig.savefig(fig_fn, dpi=300, bbox_inches='tight')
print('Figure saved to file:', fig_fn)

plt.show()
    

## Figure 6. Stable surface errors vs. terrain parameters

In [None]:
# Set up figure: one row per DEM; columns = residuals vs elevation | slope | aspect
fig, ax = plt.subplots(5, 3, figsize=(12,12))

# Iterate over sites
dem_count = 0
for site_name in list(info_dict.keys()):
    print(site_name)
    refdem_fn = info_dict[site_name]['refdem_fn']
    refdem = xdem.DEM(refdem_fn)

    # Load terrain parameters (precomputed rasters next to the reference DEM)
    elev = xdem.DEM(refdem_fn.replace('.tif', '_ELEVATION.tif'))
    slope = xdem.DEM(refdem_fn.replace('.tif', '_SLOPE.tif'))
    aspect = xdem.DEM(refdem_fn.replace('.tif', '_ASPECT.tif'))

    dem_fns = info_dict[site_name]['DEM_fns']
    # Iterate over DEMs
    for dem_fn in dem_fns:
        print(os.path.basename(dem_fn))
        # Grab date from file name (e.g. "..._MCS_20240420_DEM..." -> "20240420")
        date = os.path.basename(dem_fn).split(site_name)[1].split('_')[1]
        if date=='20240419-1':
            date_disp = '2024-04-19a'
        elif date=='20240419-2':
            date_disp = '2024-04-19b'
        else:
            date_disp = f"{date[0:4]}-{date[4:6]}-{date[6:]}"
        # Load DEMs and stable surface masks
        dem = xdem.DEM(dem_fn)
        refdem_reproj = refdem.reproject(dem)
        ss_mask_fn = os.path.join(data_path, site_name, date, 'land_cover_masks', 'stable_surfaces_mask.tif')
        ss_mask = gu.Raster(ss_mask_fn, load_data=True).reproject(dem)
        ss_mask = (ss_mask==1)

        # Calculate dDEM
        ddem = dem - refdem_reproj

        # Mask stable surfaces
        ddem_masked = ddem.copy()
        ddem_masked.set_mask(~ss_mask)

        # Regrid terrain parameters
        elev_reproj = elev.reproject(dem)
        slope_reproj = slope.reproject(dem)
        aspect_reproj = aspect.reproject(dem)

        # Compile in dataframe (masked cells come through as NaN and are dropped)
        df = pd.DataFrame({
                'elevation': elev_reproj.data.ravel(),
                'slope': slope_reproj.data.ravel(),
                'aspect': aspect_reproj.data.ravel(),
                'dDEM': ddem_masked.data.ravel()
                })
        df.dropna(inplace=True)

        # Create bins for boxplots. include_lowest=True keeps values equal to the
        # first edge (e.g. slope == 0, aspect == 0, the rounded-down minimum elevation).
        nbins = 20
        elev_min = np.floor(np.nanmin(df['elevation']) / 100) * 100
        elev_max = np.ceil(np.nanmax(df['elevation']) / 100) * 100
        df['elevation_bin'] = pd.cut(df['elevation'], bins=np.linspace(elev_min, elev_max, nbins + 1), include_lowest=True)
        df['slope_bin'] = pd.cut(df['slope'], bins=np.linspace(0, 40, 21), include_lowest=True)
        df['aspect_bin'] = pd.cut(df['aspect'], bins=np.linspace(0, 360, nbins + 1), include_lowest=True)

        # Tick labels, one per bin: elevation labels min / midpoint / max and aspect
        # labels N/E/S/W at the bins starting at 0/90/180/270 degrees (nbins divisible by 4)
        elev_ticklabels = [''] * nbins
        elev_ticklabels[0] = str(int(elev_min))
        elev_ticklabels[nbins // 2] = str(int((elev_max + elev_min) / 2))
        elev_ticklabels[-1] = str(int(elev_max))
        aspect_ticklabels = [''] * nbins
        for quarter, direction in enumerate('NESW'):
            aspect_ticklabels[quarter * nbins // 4] = direction

        # Plot
        for i, column in enumerate(['elevation', 'slope', 'aspect']):
            # histogram of pixel counts per bin (twin axis, lower third of the panel)
            ax2 = ax[dem_count, i].twinx()
            bin_counts = df[column+'_bin'].value_counts(sort=False)
            ax2.bar(bin_counts.index.astype(str), bin_counts.values, color='k', alpha=0.5, width=1.0)
            ax2.set_yticks([])
            ax2.set_ylim(0, np.nanmax(bin_counts.values)*3)
            # boxplot
            sns.boxplot(data=df, x=column+'_bin', y='dDEM', color='#fee8c8', showfliers=False,
                        boxprops={'edgecolor':'k'}, medianprops={'color':'k', 'linewidth':1}, whiskerprops={'color': 'k', 'linewidth':1}, ax=ax[dem_count, i])
            ax[dem_count, i].set_xlabel('')
            ax[dem_count, i].set_xticks(ax[dem_count, i].get_xticks())
            if column=='elevation':
                ax[dem_count,i].set_xticklabels(elev_ticklabels)
                ax[dem_count,i].set_xlabel('meters')
            elif column=='slope':
                # 20 bins of 2 degrees from the fixed 0-40 degree edges above
                ax[dem_count,i].set_xticklabels(['0 '] + ['']*4 + ['10'] + ['']*4 + ['20 '] + ['']*4 + ['30 '] + ['']*3 + ['  40'])
                ax[dem_count,i].set_xlabel('degrees')
            elif column=='aspect':
                ax[dem_count, i].set_xticklabels(aspect_ticklabels)

            # Zero line across the full panel (axhline's xmin/xmax are axes
            # fractions, so the 0-1 defaults are what is wanted here)
            ax[dem_count, i].axhline(y=0, color='k', linewidth=1)
            ax[dem_count,i].set_ylabel('')
            ax[dem_count,i].set_ylim(-10,10)
        ax[dem_count,0].set_ylabel('Elevation residual [m]')
        ax[dem_count, 0].text(-0.3, 0.5, f"{info_dict[site_name]['site_name_display']}\n{date_disp}", transform=ax[dem_count,0].transAxes, rotation=0, ha='right')

        dem_count+=1

# add panel labels
text_labels = list(string.ascii_lowercase)
for i, axis in enumerate(ax.ravel()):
    axis.text(0.03, 0.87, text_labels[i], transform=axis.transAxes,
              fontweight='bold', bbox=dict(facecolor='None', edgecolor='None'))

ax[0,0].set_title('Elevation')
ax[0,1].set_title('Slope')
ax[0,2].set_title('Aspect')
fig.tight_layout()

# Save figure before plt.show(), matching the other figure cells
fig_fn = os.path.join(figures_path, 'stable_surface_errors_vs_terrain.png')
fig.savefig(fig_fn, dpi=300, bbox_inches='tight')
print('Figure saved to file:', fig_fn)

plt.show()
                

## Figure 7. Snow depth observations comparison

In [None]:
# Plot colors for SkySat- and lidar-derived snow depths
skysat_color, lidar_color = '#33a02c', '#1f78b4'
# Date tick format such as "Apr 20". NOTE(review): the "%-d" (no zero padding)
# directive is a glibc/BSD strftime extension and is not supported on Windows.
date_format = matplotlib.dates.DateFormatter(fmt="%b %-d")

# function to plot range of snow depths and snow depth at SNOTEL
def plot_sd_snotel(axis, sd_map, sd_map_date, sd_map_color, sd_map_label, snotel_vector, linestyle='-'):
    """Draw a snow-depth map's 1st-99th percentile range and its mean value at the SNOTEL site.

    Parameters
    ----------
    axis : matplotlib Axes to draw on.
    sd_map : snow-depth raster (xdem.DEM / geoutils Raster). Percentiles are taken from
        ``sd_map.data.data`` with NaNs ignored, so no-data cells are expected to be NaN
        there (NOTE(review): callers set nodata to NaN with update_array=True - confirm).
    sd_map_date : x position (date) for both the SNOTEL marker and the range bar.
    sd_map_color : color of the marker and the range bar.
    sd_map_label : prefix of the legend labels "<label> at SNOTEL" and "<label> range".
    snotel_vector : geoutils Vector with the SNOTEL site location.
    linestyle : line style applied to the vertical range-bar line.
    """
    # 1st and 99th percentiles of all non-NaN pixels
    pixel_values = sd_map.data.data.ravel()
    low, high = np.nanpercentile(pixel_values, [1, 99])

    # Mean snow depth within a 4-unit buffer around the site; the unit is that of sd_map's
    # CRS after reprojection (presumably meters - confirm)
    site_mask = snotel_vector.reproject(sd_map).buffer(4).create_mask(sd_map)
    site_values = sd_map[site_mask]

    axis.plot(sd_map_date, site_values.mean(), 'o', markersize=10, color=sd_map_color,
              label=f'{sd_map_label} at SNOTEL')
    # Vertical bar centred on the midpoint of the percentile range, spanning low..high
    errbar = axis.errorbar(x=sd_map_date, y=(high + low) / 2, yerr=(high - low) / 2,
                           color=sd_map_color, linewidth=2, capsize=5,
                           label=f"{sd_map_label} range")
    # errbar is matplotlib's ErrorbarContainer (data_line, caplines, barlinecols);
    # restyle the vertical bar line collection
    errbar[-1][0].set_linestyle(linestyle)
    
# Figure layout: one row per SkySat DEM; columns are SkySat snow depth, lidar snow depth,
# SkySat-vs-lidar 2-D histogram, and the SNOTEL time series with the map ranges
fig, ax = plt.subplots(3, 4, figsize=(12, 10))
dem_count = 0
for site_name in ['MCS', 'Banner']:
    site_name_display = info_dict[site_name]['site_name_display']
    print('\n', site_name_display)

    # Lidar snow depth map compared pixel-by-pixel with SkySat at this site
    if site_name == 'MCS':
        sd_lidar_fn = os.path.join(data_path, site_name, 'SNEX_QSI_SD', 'SNEX_MCS_Lidar_20240418_SD_V01.0.tif')
    elif site_name == 'Banner':
        sd_lidar_fn = os.path.join(data_path, site_name, 'SNEX_QSI_SD', 'SNEX20_QSI_SD_0.5M_USIDBS_20210315_20210315.tif')

    # Load SNOTEL
    # snow depth time series for 2024, with a +/-10% shaded band
    snotel_fn = info_dict[site_name]['SNOTEL_fn']
    snotel = pd.read_csv(snotel_fn)
    snotel['datetime'] = pd.to_datetime(snotel['datetime'])
    # .copy(): the columns added below go into an independent frame rather than a slice of
    # the full record (avoids pandas' SettingWithCopyWarning)
    snotel = snotel.loc[snotel['datetime'].dt.year == 2024].copy()
    snotel['SNWD_90%_m'] = snotel['SNWD_m'] * 0.9
    snotel['SNWD_110%_m'] = snotel['SNWD_m'] * 1.1
    ax[dem_count,3].fill_between(snotel['datetime'], snotel['SNWD_90%_m'], snotel['SNWD_110%_m'], color='k', alpha=0.2)
    ax[dem_count,3].plot(snotel['datetime'], snotel['SNWD_m'], '-', color='k', label='SNOTEL')
    # site location
    snotel_site_fn = info_dict[site_name]['SNOTEL_site_fn']
    snotel_site = pd.read_csv(snotel_site_fn)
    snotel_site['geometry'] = wkt.loads(snotel_site['geometry'])
    snotel_site = gpd.GeoDataFrame(snotel_site, crs="EPSG:4326")
    snotel_site = gu.Vector(snotel_site)

    # Iterate over DEMs
    for dem_fn in info_dict[site_name]['DEM_fns']:
        date = os.path.basename(dem_fn).split(site_name)[1].split('_')[1]
        if site_name == 'MCS' and date == '20241003':
            continue  # this DEM is not shown in this figure
        elif date == '20240419-1':
            date_display = '2024-04-19a'
        elif date == '20240419-2':
            date_display = '2024-04-19b'
        else:
            date_display = f"{date[0:4]}-{date[4:6]}-{date[6:]}"
        print(date_display)

        # Load SkySat snow depth
        sd_skysat_fn = glob.glob(os.path.join(data_path, site_name, date, 'post_process', '*snow_depth.tif'))[0]
        sd_skysat = xdem.DEM(sd_skysat_fn)
        sd_skysat.set_nodata(np.nan, update_array=True)
        sd_skysat_date = np.datetime64(date_display[0:10])
        # Load and reproject lidar snow depth onto the SkySat grid
        sd_lidar = xdem.DEM(sd_lidar_fn).reproject(sd_skysat)
        sd_lidar.set_nodata(np.nan, update_array=True)
        # first underscore-separated file-name token containing '202' is the YYYYMMDD date
        sd_lidar_date = [x for x in os.path.basename(sd_lidar_fn).split('_') if '202' in x][0]
        # Combine snow depths into pandas dataframe (masked pixels become NaN)
        sd_df = pd.DataFrame({'sd_skysat': sd_skysat.data.ravel(),
                              'sd_lidar': sd_lidar.data.ravel()})
        # Plot
        # snow depth maps
        sd_skysat.plot(cmap='YlGnBu', vmin=0, vmax=4, ax=ax[dem_count,0], add_cbar=False)
        sd_lidar.plot(cmap='YlGnBu', vmin=0, vmax=4, ax=ax[dem_count,1], add_cbar=False)
        # histogram heat map (NaN pairs fall outside the bins and are not counted)
        hist_data, x_edges, y_edges = np.histogram2d(sd_df['sd_skysat'], sd_df['sd_lidar'], bins=np.linspace(-1,5,40))
        vmin = np.nanmin(hist_data)
        vmax = np.nanmax(hist_data) * 0.9
        ax[dem_count,2].imshow(hist_data.T, origin='lower', cmap='YlOrBr', extent=[-1,5,-1,5], clim=(vmin, vmax))
        ax[dem_count,2].plot([-1, 5], [-1, 5], '-k', linewidth=2, label="1:1")
        ax[dem_count,2].set_xlabel('SkySat snow depth [m]')
        ax[dem_count,2].set_ylabel('Lidar snow depth [m]')
        ax[dem_count,2].set_xlim(-1,3)
        ax[dem_count,2].set_ylim(-1,3)
        # add titles
        ax[dem_count,0].set_title(f"SkySat\n{site_name_display}\n{date_display}", fontsize=12)
        ax[dem_count,1].set_title(f"Lidar\n{site_name_display}\n{sd_lidar_date[0:4]}-{sd_lidar_date[4:6]}-{sd_lidar_date[6:]}", rotation=0, fontsize=12)

        # Calculate lidar range in snow depths and sample at SNOTEL
        if site_name == 'MCS':
            mcs_lidar_fns = sorted(glob.glob(os.path.join(data_path, site_name, 'SNEX_QSI_SD', '*_SD_*.tif')))
            for lidar_count, mcs_lidar_fn in enumerate(mcs_lidar_fns):
                # Separate names so the comparison lidar map/date loaded above are not overwritten
                mcs_lidar = xdem.DEM(mcs_lidar_fn).reproject(sd_skysat)
                mcs_lidar.set_nodata(np.nan, update_array=True)
                mcs_lidar_date = [x for x in os.path.basename(mcs_lidar_fn).split('_') if '202' in x][0]
                mcs_lidar_date = np.datetime64(f"{mcs_lidar_date[0:4]}-{mcs_lidar_date[4:6]}-{mcs_lidar_date[6:]}")
                # Only the first survey gets a legend entry; labels starting with "_" are hidden
                label = 'Lidar' if lidar_count == 0 else '__nolegend'
                plot_sd_snotel(ax[dem_count,3], mcs_lidar, mcs_lidar_date, lidar_color, label, snotel_site)
        elif site_name == 'Banner' and date == '20240419-1':
            # The Banner lidar survey is from 2021 (see its file name); draw its range dashed at the
            # same month/day in 2024 so it sits on the 2024 SNOTEL axis. The figure legend is built
            # from the first row only, so this entry is hidden explicitly.
            banner_lidar_date = np.datetime64(f"2024-{sd_lidar_date[4:6]}-{sd_lidar_date[6:]}")
            plot_sd_snotel(ax[dem_count,3], sd_lidar, banner_lidar_date, lidar_color, '__nolegend', snotel_site, linestyle='--')

        # Calculate SkySat range in snow depths and sample at SNOTEL
        if date != '20240419-2':
            plot_sd_snotel(ax[dem_count,3], sd_skysat, sd_skysat_date, skysat_color, 'SkySat', snotel_site)
            # adjust axes
            ax[dem_count,3].set_ylabel('Snow depth [m]')
            ax[dem_count,3].set_xlim(np.datetime64('2024-01-01'), np.datetime64('2024-06-01'))
            ax[dem_count,3].xaxis.set_major_formatter(date_format)
            ax[dem_count,3].set_xticks(ax[dem_count,3].get_xticks()[::2])
            ax[dem_count,3].set_yticks([0,1,2])
            ax[dem_count,3].grid()
        else:
            # no time-series panel for the second 2024-04-19 acquisition
            ax[dem_count,3].remove()

        dem_count += 1

# adjust axes
for axis in ax[:,0:2].ravel():
    axis.set_xticks([])
    axis.set_yticks([])
handles, labels = ax[0,-1].get_legend_handles_labels()
fig.legend(handles, labels, loc='lower right', frameon=True, bbox_to_anchor=(0.97, 0.15))

# add colorbars
# snow depth
cax = fig.add_axes([ax[-1,0].get_position().x0+0.07, -0.02, ax[-1,0].get_position().width*1.5, 0.02])
fig.colorbar(matplotlib.cm.ScalarMappable(norm=matplotlib.colors.Normalize(vmin=0, vmax=4), cmap='YlGnBu'),
             cax=cax, orientation='horizontal', label='Snow depth [m]')
# histogram heatmap
# NOTE(review): each row's heatmap is scaled by its own vmax (90% of that row's peak count),
# but this single colorbar uses the vmax of the last row plotted - confirm this is intended.
cax = fig.add_axes([ax[0,2].get_position().x0, -0.02, ax[0,2].get_position().width, 0.02])
cbar = fig.colorbar(matplotlib.cm.ScalarMappable(norm=matplotlib.colors.Normalize(vmin=0, vmax=vmax), cmap='YlOrBr'),
             cax=cax, orientation='horizontal', label='Counts')

# add scalebars to maps
for axis in ax[:,0:2].ravel():
    scalebar = AnchoredSizeBar(axis.transData,
                                2e3, '2 km', 'lower left',
                                pad=0.2,
                                color='k',
                                sep=7,
                                frameon=False,
                                fill_bar=True,
                                bbox_transform=axis.transAxes,
                                fontproperties=fm.FontProperties(size=12, weight='bold'))
    axis.add_artist(scalebar)
    # adjust line width of scale bar (0.75% of the axis' y-range)
    yrange = axis.get_ylim()[1] - axis.get_ylim()[0]
    size_vertical = 0.0075 * yrange
    scalebar.size_bar.get_children()[0].set_height(size_vertical)

# add panels labels
text_labels = list(string.ascii_lowercase)
for i, axis in enumerate(ax.ravel()):
    axis.text(0.05, 0.9, text_labels[i], transform=axis.transAxes,
              fontweight='bold', fontsize=14, bbox=dict(facecolor='None', edgecolor='None'))

plt.show()

# Save figure
fig_fn = os.path.join(figures_path, 'snow_depth_comparison.png')
fig.savefig(fig_fn, dpi=300, bbox_inches='tight')
print('Figure saved to file:', fig_fn)


## Figure 8. Snow depth models at Mores Creek

In [None]:
# Load data for the MCS snow-depth model comparison
site_name = "MCS"
date = '20240420'
# Select the DEM by the date above (rather than by list position) so changing `date`
# actually changes which acquisition is loaded
dem_fn = [x for x in info_dict[site_name]['DEM_fns'] if date in os.path.basename(x)][0]
site_name_display = info_dict[site_name]['site_name_display']

date_display = f"{date[0:4]}-{date[4:6]}-{date[6:]}"
print(date_display)

# Load SkySat snow depth
sd_skysat_fn = glob.glob(os.path.join(data_path, site_name, date, 'post_process', '*snow_depth*.tif'))[0]
sd_skysat = xdem.DEM(sd_skysat_fn)
sd_skysat.set_nodata(np.nan, update_array=True)

# Load SkySat modeled snow depth on the SkySat grid
sd_skysat_mod_fn = glob.glob(os.path.join(data_path, site_name, date, 'snow_depth_modeling', 'modeled_snow_depth*.tif'))[0]
sd_skysat_mod = xdem.DEM(sd_skysat_mod_fn).reproject(sd_skysat)
sd_skysat_mod.set_nodata(np.nan, update_array=True)
# Clip no data rows and columns
# NOTE(review): clip_nodata is defined outside this cell. If it crops the raster, the
# subtraction below and the pixel-wise comparison in the plotting cell need the cropped
# grid to still match sd_lidar's grid - confirm.
sd_skysat_mod = clip_nodata(sd_skysat_mod)

# Load lidar snow depth on the SkySat grid
sd_lidar_fn = os.path.join(data_path, site_name, 'SNEX_QSI_SD', 'SNEX_MCS_Lidar_20240418_SD_V01.0.tif')
sd_lidar = xdem.DEM(sd_lidar_fn).reproject(sd_skysat)
sd_lidar.set_nodata(np.nan, update_array=True)

# Load lidar modeled snow depth on the SkySat grid
sd_lidar_mod_fn = glob.glob(os.path.join(os.path.dirname(sd_lidar_fn), 'snow_depth_modeling', '*snow_depth*.tif'))[0]
sd_lidar_mod = xdem.DEM(sd_lidar_mod_fn).reproject(sd_skysat)
sd_lidar_mod.set_nodata(np.nan, update_array=True)

# Calculate modeled - lidar
sd_skysat_mod_minus_lidar = sd_skysat_mod - sd_lidar
sd_lidar_mod_minus_lidar = sd_lidar_mod - sd_lidar

        

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.projections import PolarAxes
import mpl_toolkits.axisartist.floating_axes as FA
import mpl_toolkits.axisartist.grid_finder as GF
import itertools

class TaylorDiagram:
    """
    Taylor diagram adapted from Yannick Copin: <yannick.copin@laposte.net>
    Source: https://gist.github.com/ycopin/3342888

    Points are drawn in polar coordinates: radius = standard deviation,
    angle = arccos(correlation with the reference).
    """
    def __init__(self, refstd, fig=None, rect=111, label='_', srange=(0, 1.5), crange=(0, 1)):
        """
        Initialize the Taylor diagram.

        Parameters:
        * refstd: Reference standard deviation
        * fig: Matplotlib figure object (a new figure is created when None)
        * rect: Subplot definition
        * label: Reference label
        * srange: Standard deviation axis extension (relative to refstd)
        * crange: Correlation range (cmin, cmax), e.g. (-0.4, 1) to slightly extend to
          negative correlations. Both ends bound the angular extent of the diagram.

        Raises:
        * ValueError: if crange does not satisfy -1 <= cmin < cmax <= 1
        """
        cmin, cmax = crange
        if not -1 <= cmin < cmax <= 1:
            raise ValueError(f"crange must satisfy -1 <= cmin < cmax <= 1, got {crange}")

        self.refstd = refstd
        tr = PolarAxes.PolarTransform()

        # Correlation gridline positions, keeping only those inside crange
        rlocs = [r for r in (-0.2, 0, 0.2, 0.4, 0.6, 0.8, 0.9) if cmin <= r <= cmax]
        tlocs = np.arccos(rlocs)  # Convert correlations to polar angles
        gl1 = GF.FixedLocator(tlocs)
        tf1 = GF.DictFormatter(dict(zip(tlocs, map(lambda x: f"{round(x, 2):.1f}", rlocs))))

        # Define std limits
        self.smin = srange[0] * self.refstd
        self.smax = srange[1] * self.refstd

        # Angular range: correlation cmax -> tmin (0 when cmax == 1), cmin -> tmax
        self.tmin = np.arccos(cmax)
        self.tmax = np.arccos(cmin)

        ghelper = FA.GridHelperCurveLinear(
            tr,
            extremes=(self.tmin, self.tmax, self.smin, self.smax),
            grid_locator1=gl1, tick_formatter1=tf1)

        if fig is None:
            fig = plt.figure()

        ax = FA.FloatingSubplot(fig, rect, grid_helper=ghelper)
        fig.add_subplot(ax)

        # Angular (correlation) axis along the arc
        ax.axis["top"].set_axis_direction("bottom")
        ax.axis["top"].toggle(ticklabels=True, label=True)
        ax.axis["top"].major_ticklabels.set_axis_direction("top")
        ax.axis["top"].label.set_axis_direction("top")
        ax.axis["top"].label.set_text("Correlation")

        ax.axis["left"].set_axis_direction("bottom")
        ax.axis["left"].label.set_text("Standard deviation")

        ax.axis["right"].set_axis_direction("top")
        ax.axis["right"].toggle(ticklabels=True)
        ax.axis["right"].major_ticklabels.set_axis_direction("bottom")

        if self.smin:
            ax.axis["bottom"].toggle(ticklabels=False, label=False)
        else:
            ax.axis["bottom"].set_visible(False)

        self._ax = ax  # Save main axes
        self.ax = ax.get_aux_axes(tr)  # Get polar coordinate system

        # Reference point: correlation 1 (angle 0) at refstd; lies outside the drawn
        # sector when cmax < 1
        l, = self.ax.plot([0], self.refstd, 'k*', ls='', ms=10, label=label)

        # Dashed arc of constant std = refstd across the angular range
        t = np.linspace(self.tmin, self.tmax)
        r = np.zeros_like(t) + self.refstd
        self.ax.plot(t, r, 'k--', label='_')

        self.samplePoints = [l]

    def add_sample(self, stddev, corrcoef, *args, **kwargs):
        """
        Add a sample point (stddev, corrcoef) to the Taylor diagram.

        Extra arguments are passed to matplotlib ``plot``; returns the Line2D.
        """
        l, = self.ax.plot(np.arccos(corrcoef), stddev, *args, **kwargs)
        self.samplePoints.append(l)
        return l

    def add_grid(self, *args, **kwargs):
        """Add a grid."""
        self._ax.grid(*args, **kwargs)

    def add_contours(self, levels=5, **kwargs):
        """
        Add constant centered RMS difference contours.

        The centered RMS difference of a point (r, theta) from the reference is
        sqrt(refstd**2 + r**2 - 2*refstd*r*cos(theta)) (law of cosines).
        Returns the matplotlib contour set.
        """
        rs, ts = np.meshgrid(np.linspace(self.smin, self.smax),
                             np.linspace(self.tmin, self.tmax))
        # Compute centered RMS difference
        rms = np.sqrt(self.refstd**2 + rs**2 - 2*self.refstd*rs*np.cos(ts))

        contours = self.ax.contour(ts, rs, rms, levels, **kwargs)

        return contours


In [None]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.projections import PolarAxes
import mpl_toolkits.axisartist.floating_axes as FA
import mpl_toolkits.axisartist.grid_finder as GF
import itertools

class TaylorDiagram:
    """
    Taylor diagram adapted from Yannick Copin (https://gist.github.com/ycopin/3342888),
    with fixed standard-deviation ticks labelled as fractions of the reference std.

    Points are drawn in polar coordinates: radius = standard deviation,
    angle = arccos(correlation with the reference).
    """
    def __init__(self, refstd, fig=None, rect=111, label='_', srange=(0, 1.5), crange=(0, 1)):
        """
        Initialize the Taylor diagram.

        Parameters:
        * refstd: Reference standard deviation
        * fig: Matplotlib figure object (a new figure is created when None)
        * rect: Subplot definition
        * label: Reference label
        * srange: Standard deviation axis extension (relative to refstd)
        * crange: Correlation range (cmin, cmax); both ends bound the angular extent

        Raises:
        * ValueError: if crange does not satisfy -1 <= cmin < cmax <= 1
        """
        cmin, cmax = crange
        if not -1 <= cmin < cmax <= 1:
            raise ValueError(f"crange must satisfy -1 <= cmin < cmax <= 1, got {crange}")

        self.refstd = refstd
        tr = PolarAxes.PolarTransform()

        # Correlation gridline positions, keeping only those inside crange
        rlocs = [r for r in (-0.2, 0, 0.2, 0.4, 0.6, 0.8, 0.9) if cmin <= r <= cmax]
        tlocs = np.arccos(rlocs)  # Convert correlations to polar angles
        gl1 = GF.FixedLocator(tlocs)
        tf1 = GF.DictFormatter(dict(zip(tlocs, map(lambda x: f"{round(x, 2):.1f}", rlocs))))

        # Define std limits
        self.smin = srange[0] * self.refstd
        self.smax = srange[1] * self.refstd

        # Angular range: correlation cmax -> tmin (0 when cmax == 1), cmin -> tmax
        self.tmin = np.arccos(cmax)
        self.tmax = np.arccos(cmin)

        # Std ticks every 0.2*refstd starting at 0.1*refstd (below refstd), labelled as std/refstd.
        # NOTE(review): labels are normalized by refstd while the axis label below reads
        # "Standard deviation" - confirm the intended units.
        std_ticks = np.arange(self.refstd*0.1, self.refstd, self.refstd*0.2)
        std_tick_labels = [f"{tick/self.refstd:.1f}" for tick in std_ticks]
        gl2 = GF.FixedLocator(std_ticks)
        tf2 = GF.DictFormatter(dict(zip(std_ticks, std_tick_labels)))

        ghelper = FA.GridHelperCurveLinear(
            tr,
            extremes=(self.tmin, self.tmax, self.smin, self.smax),
            grid_locator1=gl1, tick_formatter1=tf1,
            grid_locator2=gl2, tick_formatter2=tf2)  # Add second locator and formatter for std

        if fig is None:
            fig = plt.figure()

        ax = FA.FloatingSubplot(fig, rect, grid_helper=ghelper)
        fig.add_subplot(ax)

        # Angular (correlation) axis along the arc
        ax.axis["top"].set_axis_direction("bottom")
        ax.axis["top"].toggle(ticklabels=True, label=True)
        ax.axis["top"].major_ticklabels.set_axis_direction("top")
        ax.axis["top"].label.set_axis_direction("top")
        ax.axis["top"].label.set_text("Correlation")

        ax.axis["left"].set_axis_direction("bottom")
        ax.axis["left"].label.set_text("Standard deviation")

        ax.axis["right"].set_axis_direction("top")
        ax.axis["right"].toggle(ticklabels=True)
        ax.axis["right"].major_ticklabels.set_axis_direction("bottom")

        if self.smin:
            ax.axis["bottom"].toggle(ticklabels=False, label=False)
        else:
            ax.axis["bottom"].set_visible(False)

        self._ax = ax  # main floating axes
        self.ax = ax.get_aux_axes(tr)  # auxiliary axes in polar (theta, r) coordinates

        # Reference point: correlation 1 (angle 0) at refstd; lies outside the drawn
        # sector when cmax < 1
        l, = self.ax.plot([0], self.refstd, 'k*', ls='', ms=10, label=label)

        # Dashed arc of constant std = refstd across the angular range
        t = np.linspace(self.tmin, self.tmax)
        r = np.zeros_like(t) + self.refstd
        self.ax.plot(t, r, 'k--', label='_')

        self.samplePoints = [l]

    def add_sample(self, stddev, corrcoef, *args, **kwargs):
        """Add a sample point (stddev, corrcoef); extra args go to ``plot``. Returns the Line2D."""
        l, = self.ax.plot(np.arccos(corrcoef), stddev, *args, **kwargs)
        self.samplePoints.append(l)
        return l

    def add_grid(self, *args, **kwargs):
        """Add a grid to the main axes."""
        self._ax.grid(*args, **kwargs)

    def add_contours(self, levels=5, **kwargs):
        """
        Add constant centered RMS difference contours and return the contour set.

        RMS(r, theta) = sqrt(refstd**2 + r**2 - 2*refstd*r*cos(theta)) (law of cosines).
        """
        rs, ts = np.meshgrid(np.linspace(self.smin, self.smax),
                             np.linspace(self.tmin, self.tmax))
        rms = np.sqrt(self.refstd**2 + rs**2 - 2*self.refstd*rs*np.cos(ts))
        contours = self.ax.contour(ts, rs, rms, levels, **kwargs)
        return contours


In [None]:
# Set up figure: 2x2 map panels on the left; the right third holds a placeholder axes
# (ax[0][2], kept only for its panel label) beneath the Taylor diagram drawn at rect=133
gs = matplotlib.gridspec.GridSpec(2, 3)
fig = plt.figure(figsize=(10,6))
ax = [[fig.add_subplot(gs[0,0]), fig.add_subplot(gs[0,1]), fig.add_subplot(133)],
      [fig.add_subplot(gs[1,0]), fig.add_subplot(gs[1,1])],]

# Plot modeled snow depth (left column) and modeled minus lidar (middle column)
sd_skysat_mod.plot(cmap='YlGnBu', vmin=0, vmax=4, ax=ax[0][0], add_cbar=False)
ax[0][0].set_title('SkySat modeled')
sd_skysat_mod_minus_lidar.plot(cmap='coolwarm_r', vmin=-4, vmax=4, ax=ax[0][1], add_cbar=False)
ax[0][1].set_title('SkySat modeled $-$ Lidar')
sd_lidar_mod.plot(cmap='YlGnBu', vmin=0, vmax=4, ax=ax[1][0], add_cbar=False)
ax[1][0].set_title('Lidar modeled')
sd_lidar_mod_minus_lidar.plot(cmap='coolwarm_r', vmin=-4, vmax=4, ax=ax[1][1], add_cbar=False)
ax[1][1].set_title('Lidar modeled $-$ Lidar')

# Plot Taylor diagram with lidar snow depth as the reference
# NOTE(review): refstd uses every valid lidar pixel and each sample's stddev uses all of its own
# valid pixels, while correlations use only pixels valid in both rasters. Taylor-diagram geometry
# (the RMS contours) assumes all three come from the same paired pixels - confirm this is intended.
refstd = sd_lidar.data.std()
dia = TaylorDiagram(refstd, fig=fig, rect=133, label="Lidar",
                    srange=(0, 1.15), crange=(-0.2,1))
# Add the models to Taylor diagram
colors = [lidar_color, skysat_color, skysat_color]
markers = ['^', 'o', '^']
labels = ['Lidar modeled', 'SkySat', 'SkySat modeled']
# Masked cells are filled with NaN explicitly (on a float64 copy, the precision np.corrcoef
# uses anyway) so pixel validity does not depend on what the raster stores beneath its mask.
# The lidar values are the same for every sample, so they are extracted once.
lidar_data = np.ma.filled(sd_lidar.data.astype(float), np.nan).ravel()
for i, sample in enumerate([sd_lidar_mod, sd_skysat, sd_skysat_mod]):
    stddev = sample.data.std()
    sample_data = np.ma.filled(sample.data.astype(float), np.nan).ravel()
    # correlate only pixels valid in both rasters
    valid = ~np.isnan(sample_data) & ~np.isnan(lidar_data)
    corrcoef = np.corrcoef(sample_data[valid], lidar_data[valid])[0][1]
    dia.add_sample(stddev, corrcoef,
                   marker=markers[i], ms=10, ls='',
                   mfc=colors[i], mec=colors[i], label=labels[i])
dia.add_grid()
# Add RMS contours, and label them (on the contour set's own axes, not pyplot's current axes)
contours = dia.add_contours(colors='0.5')
contours.clabel(inline=1, fontsize=10, fmt='%.2f')
fig.legend(dia.samplePoints,
           [p.get_label() for p in dia.samplePoints],
           numpoints=1, ncols=2, columnspacing=0.8, handletextpad=0.5,
           loc='lower right', bbox_to_anchor=[0.8, 0.05, 0.2, 0.2])

# add scalebars to maps
for axis in [ax[0][0], ax[0][1], ax[1][0], ax[1][1]]:
    scalebar = AnchoredSizeBar(axis.transData,
                                2e3, '2 km', 'lower left',
                                pad=0.2,
                                color='k',
                                sep=7,
                                frameon=False,
                                fill_bar=True,
                                bbox_transform=axis.transAxes,
                                fontproperties=fm.FontProperties(size=12, weight='bold'))
    axis.add_artist(scalebar)
    # adjust line width of scale bar (0.75% of the axis' y-range)
    yrange = axis.get_ylim()[1] - axis.get_ylim()[0]
    size_vertical = 0.0075 * yrange
    scalebar.size_bar.get_children()[0].set_height(size_vertical)

# add colorbars
cax = fig.add_axes([ax[1][0].get_position().x0, 0.0, ax[1][0].get_position().width, 0.02])
fig.colorbar(matplotlib.cm.ScalarMappable(norm=matplotlib.colors.Normalize(vmin=0, vmax=4), cmap='YlGnBu'),
             cax=cax, orientation='horizontal', label='Snow depth [m]')
cax = fig.add_axes([ax[1][1].get_position().x0, 0.0, ax[1][1].get_position().width, 0.02])
fig.colorbar(matplotlib.cm.ScalarMappable(norm=matplotlib.colors.Normalize(vmin=-4, vmax=4), cmap='coolwarm_r'),
             cax=cax, orientation='horizontal', label='Difference [m]')
# add panels labels and remove axes ticks from maps
text_labels = ['a', 'b', 'c', 'd', 'e']
for i, axis in enumerate(list(itertools.chain.from_iterable(ax))):
    if i==2:
        xscale, yscale = 0.05, 0.85
    else:
        xscale, yscale = 0.05, 0.9
    axis.text(xscale, yscale, text_labels[i], transform=axis.transAxes,
              fontweight='bold', fontsize=14, bbox=dict(facecolor='None', edgecolor='None'))
    axis.set_xticks([])
    axis.set_yticks([])
ax[0][2].spines[['top', 'right', 'bottom', 'left']].set_visible(False)

plt.show()

# Save figure
fig_fn = os.path.join(figures_path, 'snow_depth_models_comparison.png')
fig.savefig(fig_fn, dpi=300, bbox_inches='tight')
print('Figure saved to file:', fig_fn)


## Default vs. modified pipeline DEMs

In [None]:
# Set up figure: one row per DEM across all sites (orthomosaic | default dDEM | modified dDEM).
# squeeze=False keeps ax 2-D even if there is only one DEM.
n_dems = sum(len(site_info['DEM_fns']) for site_info in info_dict.values())
fig, ax = plt.subplots(n_dems, 3, figsize=(10, 3.2*n_dems), squeeze=False)

# Iterate over site names
dem_count = 0
for site_name in list(info_dict.keys()):
    print(site_name)
    dem_fns = info_dict[site_name]['DEM_fns']
    site_name_display = info_dict[site_name]['site_name_display']

    # Load reference elevations
    refdem_fn = info_dict[site_name]['refdem_fn']
    refdem = xdem.DEM(refdem_fn)

    # Iterate over DEMs
    for dem_fn in dem_fns:
        print(os.path.basename(dem_fn))
        # Get date from file name
        date = os.path.basename(dem_fn).split(site_name)[1].split('_')[1]
        if date == '20240419-1':
            date_display = '2024-04-19a'
        elif date == '20240419-2':
            date_display = '2024-04-19b'
        else:
            date_display = f"{date[0:4]}-{date[4:6]}-{date[6:]}"

        # Load DEMs; the default-pipeline DEM and reference DEM are put on the modified DEM's grid
        dem = xdem.DEM(dem_fn)
        dem_default_fn = os.path.join(data_path, site_name, date, f"default_{site_name}_{date}_DEM.tif")
        dem_default = xdem.DEM(dem_default_fn).reproject(dem)
        refdem_reproj = refdem.reproject(dem)

        # Load orthomosaic
        ortho_fns = info_dict[site_name]['orthomosaic_fns']
        ortho_fn = [x for x in ortho_fns if date in x][0]
        ortho = gu.Raster(ortho_fn)

        # Calculate dDEM
        ddem = dem - refdem_reproj
        ddem_default = dem_default - refdem_reproj

        # Plot: column 0 orthomosaic, column 1 default dDEM, column 2 modified dDEM
        ortho.plot(cmap='Grays_r', add_cbar=False, ax=ax[dem_count,0])
        ddem_default.plot(cmap='coolwarm_r', vmin=-5, vmax=5, ax=ax[dem_count,1])
        ddem.plot(cmap='coolwarm_r', vmin=-5, vmax=5, ax=ax[dem_count,2])

        # adjust axes
        ax[dem_count,0].set_ylabel(f"{site_name_display}\n{date_display}", rotation=0, ha='right')
        for axis in ax[dem_count,:]:
            axis.set_xticks([])
            axis.set_yticks([])

        dem_count += 1

# Titles go on the dDEM columns (column 0 is the orthomosaic)
ax[0,1].set_title('Default pipeline')
ax[0,2].set_title('Modified pipeline')
fig.suptitle('SkySat - Reference DEM')

fig.tight_layout()
plt.show()

# Save figure
fig_fn = os.path.join(figures_path, 'default_vs_modified_pipeline_dDEMs.png')
fig.savefig(fig_fn, dpi=300, bbox_inches='tight')
print('Figure saved to file:', fig_fn)

## Figure S1. Reference DEMs for coregistration and bundle adjustment (MCS only)

In [None]:
# Load reference DEM for MCS (elevations are loaded into memory immediately)
refdem_fn = os.path.join(data_path, 'MCS', 'refdem', 'MCS_REFDEM_WGS84.tif')
refdem = xdem.DEM(refdem_fn, load_data=True)

# Load GCP: a highway-21 road vector from the ITD functional-class data; the loop below
# rasterizes it as a mask and subtracts the median dDEM over it as a vertical adjustment
gcp_fn = os.path.join(data_path, '..', 'ITD_Functional_Class', 'ITD_HWY_21.shp')
# NOTE(review): gcp_elev is not used in the code visible in this section of the notebook.
gcp_elev = 0
gcp = gu.Vector(gcp_fn)

# Parameter combinations tested (the same for every date)
# columns: surfaces used when coregistering to the reference DEM
refdem_opts = ['coregAll', 'coregStable', 'coregRoads']
refdem_opts_display = ['All surfaces', 'Stable surfaces', 'Roads']
# rows: bundle-adjustment option - no DEM, or DEM with sigma = 10 / 5 / 1 m
ba_opts = ['noDEM', 'u10m', 'u5m', 'u1m']
# Raw strings: "\s" is not a valid Python escape sequence (SyntaxWarning on Python 3.12+)
ba_opts_display = ['No DEM', r'DEM $\sigma$ = 10 m', r'DEM $\sigma$ = 5 m', r'DEM $\sigma$ = 1 m']
text_labels = list(string.ascii_lowercase)

# Iterate over dates
dates = ['20241003', '20240420']
for date in dates:
    print('\n', date)

    # Load DEM file names (files whose names contain 'camweight' are excluded)
    dem_fns = sorted(glob.glob(os.path.join(data_path, 'MCS', date, '*_DEM.tif')))
    dem_fns = [x for x in dem_fns if 'camweight' not in x]

    # Load stable surfaces mask
    ss_mask_fn = os.path.join(data_path, 'MCS', date, 'land_cover_masks', 'stable_surfaces_mask.tif')
    ss_mask = gu.Raster(ss_mask_fn, load_data=True)

    # Set up figures
    # Map view
    fig1, ax1 = plt.subplots(len(ba_opts), len(refdem_opts), figsize=(len(refdem_opts)*4, len(ba_opts)*4))
    # Histograms
    fig2, ax2 = plt.subplots(len(ba_opts), len(refdem_opts), figsize=(len(refdem_opts)*4, len(ba_opts)*4))

    # Per-combination stats; 1e4 marks combinations with no DEM file so they never win
    # the lowest-NMAD highlight below
    nmads = 1e4 * np.ones(np.shape(ax1))
    medians = 1e4 * np.ones(np.shape(ax1))
    for i, refdem_opt in enumerate(refdem_opts):
        dem_refdem_fns = [x for x in dem_fns if refdem_opt in os.path.basename(x)]
        # wider color/histogram range for the road-coregistered DEMs (column 2)
        if i == 2:
            vmin, vmax = -50, 50
        else:
            vmin, vmax = -5, 5
        for j, ba_opt in enumerate(ba_opts):
            dem_refdem_ba_fns = [x for x in dem_refdem_fns if ba_opt in os.path.basename(x)]
            if not dem_refdem_ba_fns:
                continue  # combination not available for this date; panel left empty

            # Add row and column titles
            if j == 0:
                ax1[j,i].set_title(refdem_opts_display[i])
                ax2[j,i].set_title(refdem_opts_display[i])
            if i == 0:
                ax1[j,i].set_ylabel(ba_opts_display[j], rotation=0, ha='right', fontsize=14)
                ax2[j,i].set_ylabel(ba_opts_display[j], rotation=0, ha='right', fontsize=14)

            dem_fn = dem_refdem_ba_fns[0]
            print(os.path.basename(dem_fn))

            # Load DEMs
            dem = xdem.DEM(dem_fn, load_data=True)
            refdem_reproj = refdem.reproject(dem)

            # Load stable surfaces mask
            # NOTE(review): if reproject's default resampling is not nearest-neighbour, edge
            # pixels of this 0/1 mask may get fractional values and be dropped by `== 1`.
            ss_mask_reproj = ss_mask.reproject(dem)
            ss_mask_reproj = (ss_mask_reproj == 1)

            # Calculate dDEM
            ddem = dem - refdem_reproj

            # Apply vertical adjustment with GCP: remove the median dDEM over the GCP mask
            gcp_reproj = gcp.reproject(dem)
            gcp_reproj = gcp_reproj.create_mask(dem)
            ddem_gcp = ddem[gcp_reproj]
            ddem_gcp_median = np.ma.median(ddem_gcp)
            print(ddem_gcp_median)
            dem -= ddem_gcp_median
            ddem -= ddem_gcp_median

            if date == '20241003':
                # Apply slope correction fitted on stable surfaces
                corr = xdem.coreg.TerrainBias(terrain_attribute='slope').fit(refdem_reproj, dem, ss_mask_reproj)
                dem_corr = corr.apply(dem)

                # Re-apply vertical adjustment with GCP
                ddem = dem_corr - refdem_reproj
                ddem_gcp = ddem[gcp_reproj]
                ddem_gcp_median = np.ma.median(ddem_gcp)
                print(ddem_gcp_median)
                dem_corr -= ddem_gcp_median
                ddem -= ddem_gcp_median

            # Calculate stable surface errors
            ddem_ss = ddem[ss_mask_reproj]
            ddem_ss_median = np.ma.median(ddem_ss)
            medians[j,i] = ddem_ss_median
            nmads[j,i] = xdem.spatialstats.nmad(ddem_ss)

            # Plot
            im = ax1[j,i].imshow(ddem.data, cmap='coolwarm_r', vmin=vmin, vmax=vmax,
                                 extent=(ddem.bounds.left, ddem.bounds.right, ddem.bounds.bottom, ddem.bounds.top))
            ax1[j,i].set_xticks([])
            ax1[j,i].set_yticks([])
            ax1[j,i].set_xlim(602000, 609650)
            ax1[j,i].set_ylim(4863200, 4871300)
            ax2[j,i].hist(ddem_ss.data.ravel(), bins=np.linspace(vmin, vmax, 20),
                          color='gray', edgecolor='k', linewidth=0.5)
            ax2[j,i].set_yticks([])
            ax2[j,i].set_xlim(vmin,vmax)
            if i == 2:
                # headroom above the tallest bar for the stats text
                ylim = ax2[j,i].get_ylim()
                ax2[j,i].set_ylim(0, ylim[1]*1.2)
            # add colorbar on last row
            if j == len(ba_opts)-1:
                cbar_ax = fig1.add_axes([ax1[j,i].get_position().x0, 0.07,
                                         ax1[j,i].get_position().width, 0.015])
                fig1.colorbar(im, cax=cbar_ax, orientation='horizontal', label='Elevation residual [m]')
            # add stats
            for axis in [ax1[j,i], ax2[j,i]]:
                axis.text(0.01, 0.93, f'Median: {np.round(float(ddem_ss_median),2)} m',
                          transform=axis.transAxes)
                axis.text(0.01, 0.85, f'NMAD: {np.round(float(nmads[j,i]),2)} m',
                          transform=axis.transAxes)

    # thicker frame for the lowest NMAD combo
    ibest = np.argwhere(np.abs(nmads)==np.min(np.abs(nmads)))[0]
    ax1[ibest[0]][ibest[1]].spines[['top', 'bottom', 'right', 'left']].set_linewidth(3)
    ax2[ibest[0]][ibest[1]].spines[['top', 'bottom', 'right', 'left']].set_linewidth(3)

    # panel labels
    for i, axis in enumerate(ax1.ravel()):
        axis.text(0.9, 0.9, text_labels[i], transform=axis.transAxes,
                  fontweight='bold', fontsize=14, bbox=dict(facecolor='white', edgecolor='None'))
    for i, axis in enumerate(ax2.ravel()):
        axis.text(0.9, 0.9, text_labels[i], transform=axis.transAxes,
                  fontweight='bold', fontsize=14, bbox=dict(facecolor='white', edgecolor='None'))

    plt.show()

    # Save to file
    fig1_fn = os.path.join(figures_path, f'param_tests_maps_MCS_{date}.png')
    fig1.savefig(fig1_fn, dpi=300, bbox_inches='tight')
    print('Figure 1 saved to file:', fig1_fn)
    # NOTE(review): file name spells "histgrams"; kept unchanged in case other documents reference it.
    fig2_fn = os.path.join(figures_path, f'param_tests_histgrams_MCS_{date}.png')
    fig2.savefig(fig2_fn, dpi=300, bbox_inches='tight')
    print('Figure 2 saved to file:', fig2_fn)

## Figure S3. Model feature importances

In [None]:
# Collect the DEMs shown in this figure first (the MCS 20241003 DEM is excluded) so the
# figure gets exactly one row per DEM instead of a hard-coded row count
fi_rows = []
for site_name in list(info_dict.keys()):
    site_name_display = info_dict[site_name]['site_name_display']
    print(site_name)
    # Iterate over DEM file names
    for dem_fn in info_dict[site_name]['DEM_fns']:
        print(os.path.basename(dem_fn))
        # Get date from file name
        date = os.path.basename(dem_fn).split(site_name)[1].split('_')[1]
        if site_name == 'MCS' and date == '20241003':
            continue
        if date == '20240419-1':
            date_display = '2024-04-19a'
        elif date == '20240419-2':
            date_display = '2024-04-19b'
        else:
            date_display = f"{date[0:4]}-{date[4:6]}-{date[6:]}"
        fi_rows.append((dem_fn, site_name_display, date_display))

# Set up figure; squeeze=False keeps ax 2-D even with a single row
fig, ax = plt.subplots(len(fi_rows), 2, figsize=(10, 2.5*len(fi_rows)), squeeze=False)

for dem_count, (dem_fn, site_name_display, date_display) in enumerate(fi_rows):
    # Load feature importances from the snow_depth_modeling folder next to the DEM's folder
    fi_fn = os.path.join(os.path.dirname(dem_fn), '..', 'snow_depth_modeling', 'feature_importances.csv')
    fi = pd.read_csv(fi_fn)
    # the CSV's unnamed first column holds the feature names
    fi.rename(columns={'Unnamed: 0': 'feature_column'}, inplace=True)
    fi['feature_column_display'] = [x.replace('topographic_position_index', 'TPI').replace('Sx', '$S_x$') for x in fi['feature_column']]

    # Plot MDI (left) and permutation (right) importances with their std as error bars
    ax[dem_count,0].bar(fi['feature_column_display'], fi['MDI'], yerr=fi['MDI_std'], color='#bf812d')
    ax[dem_count,1].bar(fi['feature_column_display'], fi['permutation'], yerr=fi['permutation_std'], color='#35978f')

    ax[dem_count,0].set_ylabel(f"{site_name_display}\n{date_display}", rotation=0, ha='right')

# add titles
ax[0,0].set_title('Mean decrease in impurity')
ax[0,1].set_title('Permutation')

# add panel labels
text_labels = list(string.ascii_lowercase)
for i, axis in enumerate(ax.ravel()):
    axis.text(0.9, 0.9, text_labels[i], transform=axis.transAxes,
              fontweight='bold', bbox=dict(facecolor='None', edgecolor='None'))

fig.tight_layout()
plt.show()

# Save figure
fig_fn = os.path.join(figures_path, 'figS3_model_feature_importances.png')
fig.savefig(fig_fn, dpi=300, bbox_inches='tight')
print('Figure saved to file:', fig_fn)

## Presentation figures

### Parallax example

In [None]:
# Save a left/right view pair of MCS 2024-04-20 orthomosaics (files matching
# '*first_orthomosaic.tif') zoomed to the same small window, to illustrate parallax.
out_folder = os.path.join(figures_path, 'multiview_gif')
# makedirs tolerates an existing folder and also creates missing parents
os.makedirs(out_folder, exist_ok=True)

fns = sorted(glob.glob(os.path.join(data_path, 'MCS', '20240420', '*first_orthomosaic.tif')))
titles = ['Left view', 'Right view']
# zip stops at the shorter list, so at most the first two orthomosaics are used
for title, fn in zip(titles, fns):
    fig = plt.figure(figsize=(6,6))
    raster = gu.Raster(fn)
    # 'Greys_r' is the long-standing matplotlib colormap name (as used elsewhere in this notebook)
    raster.plot(cmap='Greys_r', add_cbar=False)
    plt.title(title, fontsize=18)
    # Fixed zoom window in the raster's map coordinates
    plt.xlim(607.25e3, 607.5e3)
    plt.ylim(4866.05e3, 4866.3e3)
    plt.xticks([])
    plt.yticks([])
    plt.show()
    fig.savefig(os.path.join(out_folder, title.lower().replace(' ', '_') + '.png'), dpi=300, bbox_inches='tight')

### Image footprints to DEM

In [None]:
# Presentation figure: SkySat frame footprints (left) and the resulting DEM as
# colour-shaded relief (right) for MCS 2024-04-20.
image_fns = sorted(glob.glob(os.path.join(data_path, 'MCS', '20240420', 'SkySatScene', '20*_panchromatic.tif')))
meta_fns = sorted(glob.glob(os.path.join(data_path, 'MCS', '20240420', 'SkySatScene', '*_metadata.json')))
dem_fn = os.path.join(data_path, 'MCS', '20240420', 'MCS_20240420_DEM.tif')

# Image footprints
fig, ax = plt.subplots(1, 2, figsize=(10,5))
for meta_fn in meta_fns:
    # Close each metadata file promptly instead of leaking the handle
    with open(meta_fn, encoding='utf-8') as meta_file:
        meta = json.load(meta_file)
    # First coordinate ring of the footprint geometry, tagged as EPSG:4326 below
    # (assumes a single-Polygon footprint — TODO confirm no MultiPolygons)
    bounds = np.array(meta['geometry']['coordinates'])[0]
    bounds_poly = Polygon(bounds)
    # Reproject from WGS84 lon/lat to UTM zone 11N (EPSG:32611)
    bounds_gdf = gpd.GeoDataFrame(geometry=[bounds_poly], crs="EPSG:4326")
    bounds_gdf = bounds_gdf.to_crs("EPSG:32611")
    # Each footprint is drawn separately and semi-transparent, so overlaps stack up darker
    bounds_gdf.plot(facecolor='gray', alpha=0.3, ax=ax[0])

# DEM, resampled to a 5-unit grid (presumably metres) for plotting
dem = gu.Raster(dem_fn, load_data=True).reproject(res=5)
dem_bounds = dem.bounds
# DEM bounds, used as the imshow extent for both layers below
xmin, ymin, xmax, ymax = dem_bounds.left, dem_bounds.bottom, dem_bounds.right, dem_bounds.top
# Shaded relief: hillshade underneath a semi-transparent elevation colormap
ls = LightSource(azdeg=315, altdeg=45)
hs = ls.hillshade(dem.data, vert_exag=5)
ax[1].imshow(hs, cmap='Greys_r',
             extent=(xmin, xmax, ymin, ymax))
im = ax[1].imshow(dem.data, cmap='terrain', alpha=0.7, clim=(1500, 3000),
                  extent=(xmin, xmax, ymin, ymax))

# Strip ticks and frames from both panels
for axis in ax:
    axis.set_xticks([])
    axis.set_yticks([])
    axis.spines[['top', 'right', 'left', 'bottom']].set_visible(False)

fig.tight_layout()
plt.show()

# Save figure
fig_fn = os.path.join(figures_path, 'frames_to_DEM.png')
fig.savefig(fig_fn, dpi=300, bbox_inches='tight')
print('Figure saved to file:', fig_fn)
    

### Snow depth maps

In [None]:
# Presentation figure: orthomosaic with snow depth overlay (top row) and a
# snow depth histogram (bottom row), one column per site/date.
fig, ax = plt.subplots(2,4, figsize=(12,6), gridspec_kw={'height_ratios': [2,1]})

# Date strings that need a custom display label (two acquisitions on one day)
date_display_overrides = {'20240419-1': '2024-04-19a', '20240419-2': '2024-04-19b'}
# Only the right-most map column gets the snow depth colorbar
cbar_col = ax.shape[1] - 1

# Iterate over sites
dem_count = 0
for site_name in info_dict:
    dem_fns = info_dict[site_name]['DEM_fns']
    site_name_display = info_dict[site_name]["site_name_display"]

    # Load reference elevation (only used by the disabled elevation boxplots below)
    refdem_fn = info_dict[site_name]['refdem_fn']
    elev_fn = refdem_fn.replace('.tif', '_ELEVATION.tif')
    # elev = xdem.DEM(elev_fn)

    # Iterate over DEMs
    for dem_fn in dem_fns:
        # Grab date from file name: the '_'-delimited token right after the site name
        date = os.path.basename(dem_fn).split(site_name)[1].split('_')[1]
        if site_name == 'MCS' and date == '20241003':
            continue
        date_display = date_display_overrides.get(date, f"{date[0:4]}-{date[4:6]}-{date[6:]}")

        # Load orthomosaic (its file name shares the DEM's prefix up to '_DEM')
        ortho_fn = os.path.join(data_path, site_name, date, os.path.basename(dem_fn).split('_DEM')[0] + '_orthomosaic.tif')
        ortho = gu.Raster(ortho_fn)
        ortho = clip_nodata(ortho)  # helper defined elsewhere in this notebook

        # Load snow depth map and resample it onto the orthomosaic grid
        sd_fn = glob.glob(os.path.join(data_path, site_name, date, 'post_process', '*_snow_depth.tif'))[0]
        sd = xdem.DEM(sd_fn).reproject(ortho)

        # Reproject elevations
        # elev_reproj = elev.reproject(ortho)

        # # Create dataframe of snow depths and elevations
        # nbins = 10
        # df = pd.DataFrame({'elevation': elev_reproj.data.ravel(),
        #                    'snow_depth': sd.data.ravel()})
        # elev_min = np.floor(np.nanmin(df['elevation']) / 100) * 100
        # elev_max = np.ceil(np.nanmax(df['elevation']) / 100) * 100
        # df['elevation_bin'] = pd.cut(df['elevation'], bins=np.linspace(elev_min, elev_max, nbins + 1))

        map_ax = ax[0, dem_count]
        hist_ax = ax[1, dem_count]

        # snow depth maps: semi-transparent snow depth over the greyscale orthomosaic
        ortho.plot(ax=map_ax, cmap='Greys_r', add_cbar=False)
        if dem_count == cbar_col:
            sd.plot(ax=map_ax, cmap='YlGnBu', vmin=0, vmax=4, alpha=0.5, add_cbar=True, cbar_title='Snow depth [m]')
        else:
            sd.plot(ax=map_ax, cmap='YlGnBu', vmin=0, vmax=4, alpha=0.5, add_cbar=False)
        map_ax.set_xticks([])
        map_ax.set_yticks([])
        map_ax.set_title(f"{site_name_display}\n{date_display}")
        # 1 km scalebar (data units assumed to be metres)
        scalebar = AnchoredSizeBar(map_ax.transData,
                                   1e3, '1 km', 'lower right',
                                   pad=0.2,
                                   color='k',
                                   sep=5,
                                   frameon=True,
                                   fill_bar=True,
                                   fontproperties=fm.FontProperties(size=10, weight='bold'))
        map_ax.add_artist(scalebar)
        # adjust line width of scale bar to a fixed fraction of the axis height
        yrange = map_ax.get_ylim()[1] - map_ax.get_ylim()[0]
        size_vertical = 0.0075 * yrange
        scalebar.size_bar.get_children()[0].set_height(size_vertical)

        # snow depth histograms. np.histogram (used by Axes.hist) does not honour
        # masked-array masks, so drop masked (nodata) pixels explicitly first.
        hist_ax.hist(np.ma.compressed(sd.data), bins=np.linspace(-1,5,40), edgecolor='k', linewidth=0.5,
                     facecolor='#9ecae1')
        hist_ax.set_xlabel('Snow depth [m]')
        hist_ax.axhline(0, color='k', linewidth=1)
        hist_ax.set_yticks([])
        hist_ax.spines[['top', 'right', 'left']].set_visible(False)

        # elevation vs. snow depth boxplots
        # sns.boxplot(df, x='elevation_bin', y='snow_depth', legend=False, color='#9ecae1',
        #             showfliers=False, ax=ax[1,dem_count])
        # ax[1,dem_count].set_ylim(-2,7)
        # ax[1,dem_count].set_ylim(-2,5)
        # ax[1,dem_count].axhline(0, color='k', linewidth=1)
        # ax[1,dem_count].set_xlabel('Elevation [m]')
        # ax[1,dem_count].set_ylabel('')
        # ax[1,dem_count].set_xticks(np.arange(0,10))
        # ax[1,dem_count].set_xticklabels([str(int(elev_min))] + ['']*3
        #                                 + [str(int((elev_min+elev_max)/2))] + ['']*4
        #                                 + [str(int(elev_max))])

        dem_count += 1

# Save figure
fig_fn = os.path.join(figures_path, 'snow_depth_maps.png')
fig.savefig(fig_fn, dpi=300, bbox_inches='tight')
print('Figure saved to file:', fig_fn)

plt.close()



### Lidar vs. SkySat snow depth

In [None]:
# Side-by-side comparison of lidar (2024-04-18) and SkySat (2024-04-20) snow
# depth maps at MCS, with the SkySat map resampled onto the lidar grid.
site_name = "MCS"
date = "20240420"

lidar_sd_fn = os.path.join(data_path, site_name, 'SNEX_MCS_Lidar', 'SNEX_MCS_Lidar_20240418_SD_V01.0.tif')
lidar_sd = xdem.DEM(lidar_sd_fn)

skysat_sd_fn = glob.glob(os.path.join(data_path, site_name, date, 'post_process', '*_snow_depth.tif'))[0]
skysat_sd = xdem.DEM(skysat_sd_fn).reproject(lidar_sd)

fig, ax = plt.subplots(1, 2, figsize=(10,5))

# Same colour scale on both panels; a single colorbar on the right panel
lidar_sd.plot(cmap='YlGnBu', vmin=0, vmax=4, ax=ax[0], add_cbar=False)
ax[0].set_title('Lidar 2024-04-18')
skysat_sd.plot(cmap='YlGnBu', vmin=0, vmax=4, ax=ax[1], add_cbar=True, cbar_title='Snow depth [m]')
ax[1].set_title('SkySat 2024-04-20')
for axis in ax:
    axis.set_xticks([])
    axis.set_yticks([])

fig.tight_layout()

fig_fn = os.path.join(figures_path, 'snow_depth_maps_skysat_vs_lidar_MCS.png')
# Match the resolution and cropping used for every other saved figure in this notebook
fig.savefig(fig_fn, dpi=300, bbox_inches='tight')
print('Figure saved to file:', fig_fn)

plt.close()

In [None]:
# Disabled diagnostic: for each site/date (skipping MCS 2024-10-03, like the
# snow depth map figure), print the median of the SkySat snow depth map.
# Uncomment the lines below to run it.
# for site_name in list(info_dict.keys()):
#     dem_fns = info_dict[site_name]['DEM_fns']
#     site_name_display = info_dict[site_name]["site_name_display"]
    
#     # Load reference elevation
#     refdem_fn = info_dict[site_name]['refdem_fn']
#     elev_fn = refdem_fn.replace('.tif', '_ELEVATION.tif')
#     # elev = xdem.DEM(elev_fn)
    
#     # Iterate over DEMs
#     for dem_fn in dem_fns:
#         # Grab date from file name
#         date = os.path.basename(dem_fn).split(site_name)[1].split('_')[1]
        
#         if (site_name=='MCS') & (date=='20241003'):
#             continue
    
#         # Load snow depth map
#         sd_fn = glob.glob(os.path.join(data_path, site_name, date, 'post_process', '*_snow_depth.tif'))[0]
#         sd = xdem.DEM(sd_fn)
#         print(site_name, date, np.ma.median(sd.data.ravel()))


### Model input examples

In [None]:
# Enlarge all text for slide legibility. NOTE: this updates the global
# rcParams, so it also applies to every figure made afterwards in this session.
plt.rcParams.update({'font.size': 24})

# Example site/date shown as model inputs/outputs
site_name = "Banner"
date = "20240419-1"

# Reference elevation, SkySat snow depth, and modeled snow depth rasters
refdem_fn = info_dict[site_name]['refdem_fn']
elev_fn = refdem_fn.replace('.tif', '_ELEVATION.tif')
elev = xdem.DEM(elev_fn)
site_date_dir = os.path.join(data_path, site_name, date)
sd_fn = glob.glob(os.path.join(site_date_dir, 'post_process', '*_snow_depth.tif'))[0]
sd = xdem.DEM(sd_fn)
sd_mod_fn = glob.glob(os.path.join(site_date_dir, 'snow_depth_modeling', 'modeled_snow_depth*.tif'))[0]
sd_mod = xdem.DEM(sd_mod_fn)

# Three stacked panels: SkySat snow depth, reference terrain, modeled snow depth
fig, ax = plt.subplots(3, 1, figsize=(6,18))
snow_depth_style = {'cmap': 'YlGnBu', 'vmin': 0, 'vmax': 4, 'add_cbar': False}

# Panel 1: SkySat snow depth; keep its bounds to outline it on panel 2
sd.plot(ax=ax[0], **snow_depth_style)
sd_bounds = sd.bounds
xmin, xmax = sd_bounds.left, sd_bounds.right
ymin, ymax = sd_bounds.bottom, sd_bounds.top
# ax[0].set_title('SkySat snow depths')
ax[0].set_ylabel('~7 km')

# Panel 2: hillshade underneath a semi-transparent elevation colormap
elev_extent = (elev.bounds.left, elev.bounds.right, elev.bounds.bottom, elev.bounds.top)
ls = LightSource(azdeg=315, altdeg=45)
hs = ls.hillshade(elev.data, vert_exag=5)
ax[1].imshow(hs, cmap='Greys_r', extent=elev_extent)
hs_im = ax[1].imshow(elev.data, cmap='terrain', alpha=0.7, clim=(1500, 3000), extent=elev_extent)
# Black rectangle marking the SkySat snow depth footprint
outline_x = [xmin, xmax, xmax, xmin, xmin]
outline_y = [ymin, ymin, ymax, ymax, ymin]
ax[1].plot(outline_x, outline_y, '-k', linewidth=2)
# ax[1].set_title('Reference terrain characteristics')

# Panel 3: modeled snow depth
sd_mod.plot(ax=ax[2], **snow_depth_style)
# ax[2].set_title('Modeled snow depths')

# Panels 2 and 3 share the same size label and open-frame look
for axis in ax[1:]:
    axis.set_ylabel('16.5 km')
    axis.spines[['bottom', 'right', 'top']].set_visible(False)

# No tick marks or tick labels on any panel
for axis in ax:
    axis.set_xticks([])
    axis.set_yticks([])

plt.show()

fig_fn = os.path.join(figures_path, 'model_inputs.png')
fig.savefig(fig_fn, dpi=300, bbox_inches='tight')
print('Figure saved to file:', fig_fn)