# Make figures

In [None]:
import os,glob
import matplotlib.pyplot as plt
import pandas as pd
import geopandas as gpd
import contextily as ctx
from shapely import wkt
from shapely.ops import unary_union
from shapely.geometry import Polygon, Point, MultiPolygon
import geoutils as gu
import numpy as np
from geopy.geocoders import Nominatim
from matplotlib.colors import LightSource
from matplotlib.gridspec import GridSpec
import string
import xdem
import seaborn as sns
import sys
import json
import rioxarray as rxr
import xarray as xr
import matplotlib
from tqdm.auto import tqdm

# Import processing functions
code_dir = '/Users/raineyaberle/Research/PhD/SnowDEMs/skysat_stereo_snow/skysat_stereo_snow'
# Add the processing-code directory only once, so re-running this cell does not
# keep appending duplicate entries to sys.path
if code_dir not in sys.path:
    sys.path.append(code_dir)
import post_process_utils as pprocess

# Default figure styling (individual cells may override the font size)
plt.rcParams.update({'font.size':12, 'font.sans-serif': 'Arial'})

In [9]:
# Root directories for the input data and the output figures
data_path = '/Volumes/LaCie/raineyaberle/Research/PhD/SkySat-Stereo/study-sites'
figures_path = '/Users/raineyaberle/Research/PhD/SnowDEMs/skysat_stereo_snow/figures'


def _site_path(site, *parts):
    """Return the path ``data_path/<site>/<parts...>``."""
    return os.path.join(data_path, site, *parts)


def _final_dem_path(site, date):
    """Return the path of the bundle-adjusted, GCP-shifted DEM for one site/date folder."""
    dem_name = f"ba+DEMuncertainty1mAll_{site}_{date}_DEM_GCPshift.tif"
    return _site_path(site, date, "post_process", dem_name)


# Site-specific inputs, keyed by the site folder name under data_path:
# reference DEM, SNOTEL time series and site-info CSVs, final SkySat DEMs,
# SkySat satellite IDs, and the name used in figure labels.
info_dict = {
    "MCS": {
        "refdem_fn": _site_path("MCS", "refdem", "MCS_REFDEM_WGS84.tif"),
        "SNOTEL_fn": _site_path("MCS", "snotel", "MCS_2020-01-01_2024-06-07_adj.csv"),
        "SNOTEL_site_fn": _site_path("MCS", "snotel", "MCS_SNOTEL_site_info.csv"),
        "DEM_fns": [_final_dem_path("MCS", d) for d in ("20240420", "20241003")],
        "satellites": ["s116", "s109"],
        "site_name_display": "Mores Creek Summit",
    },
    "JacksonPeak": {
        "refdem_fn": _site_path("JacksonPeak", "refdem",
                                "USGS_LPC_ID_FEMAHQ_2018_D18_merged_filtered_filled.tif"),
        "SNOTEL_fn": _site_path("JacksonPeak", "snotel", "JacksonPeak_2023-01-01_2024-06-07_adj.csv"),
        "SNOTEL_site_fn": _site_path("JacksonPeak", "snotel", "JacksonPeak_SNOTEL_site_info.csv"),
        "DEM_fns": [_final_dem_path("JacksonPeak", "20240420")],
        "satellites": ["s113"],
        "site_name_display": "Jackson Peak",
    },
    "Banner": {
        "refdem_fn": _site_path("Banner", "refdem", "Banner_REFDEM_WGS84.tif"),
        "SNOTEL_fn": _site_path("Banner", "snotel", "Banner_2020-01-01_2024-06-07_adj.csv"),
        "SNOTEL_site_fn": _site_path("Banner", "snotel", "Banner_SNOTEL_site_info.csv"),
        "DEM_fns": [_final_dem_path("Banner", d) for d in ("20240419-1", "20240419-2")],
        "satellites": ["s112", "s108"],
        "site_name_display": "Banner Summit",
    },
}

## Study sites map

In [None]:
# Set up the 2x2 figure: panel 0 is the regional overview, panels 1-3 the sites
crs = "EPSG:32611"
fontsize = 12
plt.rcParams.update({'font.size': fontsize, 'font.sans-serif': "Arial"})
fig, ax = plt.subplots(2, 2, figsize=(10, 10))
ax = ax.ravel()

# Draw Idaho Highway 21 on every panel as a black line cased in white
hwy21_fn = os.path.join(data_path, '..', 'ITD_Functional_Class', 'ITD_HWY_21.shp')
hwy21 = gpd.read_file(hwy21_fn).to_crs(crs)
for panel_idx, panel in enumerate(ax):
    for part_idx, part in enumerate(hwy21.geometry[0].geoms):
        xs, ys = part.coords.xy
        # only the first segment on the overview panel contributes a legend entry
        is_legend_segment = panel_idx == 0 and part_idx == 0
        panel.plot(xs, ys, color='w', linewidth=2)
        panel.plot(xs, ys, color='k', linewidth=1,
                   label='ID-21' if is_legend_segment else '_nolegend')
        
def bounds_to_polygon(bounds):
    """Build a closed shapely Polygon from an object with left/bottom/right/top attributes."""
    left, bottom, right, top = bounds.left, bounds.bottom, bounds.right, bounds.top
    ring = [[left, bottom], [right, bottom], [right, top], [left, top]]
    # repeat the first corner so the ring is explicitly closed
    ring.append(ring[0])
    return Polygon(ring)
    
# Iterate over sites: panels b-d each show one site
labels = ['a', 'b', 'c', 'd']
pad = 2e3  # padding around each reference DEM extent [m]
for i, site_name in enumerate(list(info_dict.keys())):
    site_ax = ax[i+1]
    # basemap
    ctx.add_basemap(ax=site_ax, source=ctx.providers.USGS.USImagery, attribution=False, crs=crs, zoom=12)

    # Reference DEM, resampled to 10 m in the map CRS
    refdem_fn = info_dict[site_name]['refdem_fn']
    refdem = gu.Raster(refdem_fn, load_data=True).reproject(crs=crs, res=10)
    dem_bounds = refdem.bounds
    extent = (dem_bounds.left, dem_bounds.right, dem_bounds.bottom, dem_bounds.top)
    # Shaded relief under a semi-transparent elevation colormap
    ls = LightSource(azdeg=315, altdeg=45)
    hs = ls.hillshade(refdem.data, vert_exag=5)
    site_ax.imshow(hs, cmap='Greys_r', extent=extent)
    im = site_ax.imshow(refdem.data, cmap='terrain', alpha=0.7, clim=(1500, 3000), extent=extent)

    # Padded site extent, used both for the overview outline and the panel limits
    x0, x1 = dem_bounds.left - pad, dem_bounds.right + pad
    y0, y1 = dem_bounds.bottom - pad, dem_bounds.top + pad
    # Outline and panel letter on the overview map
    ax[0].plot([x0, x1, x1, x0, x0], [y0, y0, y1, y1, y0], '-w', linewidth=1, label='_nolegend')
    ax[0].text((dem_bounds.left + dem_bounds.right)/2, (dem_bounds.bottom + dem_bounds.top)/2,
               labels[i+1], fontweight='bold', color='w', fontsize=fontsize+2, ha='center', va='center')

    # Image footprints: union of the SkySat scene outlines for each DEM
    for dem_fn in info_dict[site_name]['DEM_fns']:
        im_dir = os.path.join(os.path.dirname(dem_fn), '..', 'SkySatScene')
        meta_fns = sorted(glob.glob(os.path.join(im_dir, '*_metadata.json')))
        polys = []
        for meta_fn in meta_fns:
            with open(meta_fn, 'r') as f:
                meta = json.load(f)
            # outer ring of the scene footprint (lon/lat)
            polys.append(Polygon(meta['geometry']['coordinates'][0]))
        if not polys:
            print(f'No SkySat scene metadata found in {im_dir}; skipping footprint.')
            continue
        polys_merged = unary_union(polys)
        polys_gdf = gpd.GeoDataFrame(geometry=[polys_merged], crs='EPSG:4326').to_crs(crs)
        footprint = polys_gdf.geometry[0]
        # non-overlapping scenes union into a MultiPolygon, which has no single exterior
        parts = list(footprint.geoms) if isinstance(footprint, MultiPolygon) else [footprint]
        for k, part in enumerate(parts):
            site_ax.plot(*part.exterior.coords.xy, color='#fe9929', linewidth=2,
                         label='SkySat images footprint' if k == 0 else '_nolegend')

    # SNOTEL site location, projected to the map CRS
    # (the previous `dem.crs` referenced a variable not defined in this cell)
    snotel_fn = info_dict[site_name]['SNOTEL_site_fn']
    snotel = pd.read_csv(snotel_fn)
    snotel['geometry'] = snotel['geometry'].apply(wkt.loads)
    snotel_gdf = gpd.GeoDataFrame(snotel, crs='EPSG:4326').to_crs(crs)
    site_ax.plot(*snotel_gdf['geometry'].values[0].coords.xy, '*', markersize=15,
                 markerfacecolor='b', markeredgecolor='w', linewidth=0.3, label="SNOTEL site")

    # title
    site_ax.set_title(f"{labels[i+1]}) {info_dict[site_name]['site_name_display']}")
    # adjust axes
    site_ax.set_xticks([])
    site_ax.set_yticks([])
    site_ax.set_xlim(x0, x1)
    site_ax.set_ylim(y0, y1)
    
# Shared legend: highway entry from the overview panel plus footprint/SNOTEL from panel c
overview = ax[0]
overview_handles, overview_labels = overview.get_legend_handles_labels()
site_handles, site_labels = ax[2].get_legend_handles_labels()
legend_handles = overview_handles + site_handles
legend_labels = overview_labels + site_labels
fig.legend(legend_handles, legend_labels, loc='upper center',
           ncols=len(legend_labels), frameon=False, bbox_to_anchor=[0.4, 0.75, 0.2, 0.2])

# Elevation colorbar in its own axis to the right of the panels
fig.subplots_adjust(right=0.87, hspace=0.25)
cbar_ax = fig.add_axes([0.9, 0.25, 0.03, 0.5])
fig.colorbar(im, cax=cbar_ax, shrink=0.8, label='Elevation [m]')

# Overview panel: freeze the current automatic ticks, relabel them in km,
# then zoom to the region containing all sites
overview.set_title('a) All sites')
xticks = overview.get_xticks()
overview.set_xticks(xticks)
overview.set_xticklabels((xticks / 1e3).astype(int).astype(str))
yticks = overview.get_yticks()
overview.set_yticks(yticks)
overview.set_yticklabels((yticks / 1e3).astype(int).astype(str))
overview.set_xlabel('Easting [km]')
overview.set_ylabel('Northing [km]')
xmin, xmax = 555e3, 675e3
ymin, ymax = 4818e3, 4920e3
overview.set_xlim(xmin, xmax)
overview.set_ylim(ymin, ymax)

# Idaho locator inset (top-left) with the overview extent filled in green
idaho_fn = '/Users/raineyaberle/Research/PhD/GIS_data/US_State_Boundaries/Idaho_State_Boundaries.gpkg'
idaho = gpd.read_file(idaho_fn).to_crs(crs)
ax2 = fig.add_axes([0.11, 0.73, 0.15, 0.15])
ax2.set_axis_off()
idaho.plot(facecolor='#bdbdbd', edgecolor='k', linewidth=0.5, ax=ax2)
(xmin, xmax), (ymin, ymax) = ax[0].get_xlim(), ax[0].get_ylim()
ax2.fill_between([xmin, xmax], [ymin, ymin], [ymax, ymax], color='g')

# Add city names, geocoded with Nominatim (requires network access)
geolocator = Nominatim(user_agent="My-Map")
cities = ["Boise", "Idaho City", "Stanley, Custer County, Idaho"]
# Collect records first and build the frame once: concatenating onto an empty
# DataFrame in a loop is quadratic and triggers a pandas FutureWarning.
records = []
for city in cities:
    location = geolocator.geocode(city, addressdetails=True)
    if location is None:
        print(f'Could not geocode "{city}"; skipping its label.')
        continue
    records.append({'city': city, 'geometry': Point(location.longitude, location.latitude)})
df = pd.DataFrame(records, columns=['city', 'geometry'])
gdf = gpd.GeoDataFrame(df, geometry='geometry', crs='EPSG:4326')
gdf = gdf.to_crs(crs).reset_index(drop=True)
for city, point in zip(gdf['city'], gdf.geometry):
    # label with the part before the first comma (e.g. "Stanley")
    ax[0].text(point.x, point.y, city.split(',')[0], fontsize=fontsize, color='w',
               fontweight='bold', style='italic', ha='center')

# Basemap for the overview panel, in the same CRS as every other layer
# (`dem` is not defined in this cell, so `dem.crs` raised a NameError)
ctx.add_basemap(ax=ax[0], source=ctx.providers.USGS.USImagery, attribution=False, crs=crs)

plt.show()

# Save figure
fig_fn = os.path.join(figures_path, "study_sites.png")
fig.savefig(fig_fn, dpi=300, bbox_inches='tight')
print('Figure saved to file:', fig_fn)
    

In [None]:
# Plot elevation, slope, and aspect histograms of each reference DEM (resampled to 2 m)
slope_bins = np.arange(0, 61, 2)
aspect_bins = np.arange(0, 361, 30)
for site_name in list(info_dict.keys()):
    refdem_fn = info_dict[site_name]['refdem_fn']
    refdem = xdem.DEM(refdem_fn).reproject(res=2)
    slope = xdem.terrain.slope(refdem)
    aspect = xdem.terrain.aspect(refdem)

    fig, ax = plt.subplots(1, 3, figsize=(10,5))
    # .compressed() keeps only unmasked pixels; np.histogram (used by hist)
    # discards the mask of a MaskedArray, so masked nodata could be counted
    ax[0].hist(refdem.data.compressed(), bins=30)
    ax[0].set_xlabel('Elevation [m]')
    ax[1].hist(slope.data.compressed(), bins=slope_bins)
    ax[1].set_xlabel('Slope [degrees]')
    ax[2].hist(aspect.data.compressed(), bins=aspect_bins)
    ax[2].set_xlabel('Aspect [degrees]')
    fig.suptitle(site_name)
    fig.tight_layout()
    plt.show()

## Simple error distribution plot for the workflow figure

In [None]:
# Synthetic, normally distributed errors used only to illustrate the workflow
np.random.seed(42)
errors = np.random.normal(loc=0, scale=1, size=1000)

# Summary statistics
median = np.median(errors)
# NMAD: 1.4826 * median absolute deviation from the median
nmad = 1.4826 * np.median(np.abs(errors - median))
quantile_25, quantile_75 = np.percentile(errors, [25, 75])

# Large fonts and thick lines for a small schematic panel
plt.rcParams.update({'font.size': 20, 'font.sans-serif': "Arial"})
lw = 3
fig, ax = plt.subplots(figsize=(7, 5))
ax.hist(errors, bins=30, alpha=0.7, color='gray')
# solid line at the median, dashed lines at the 25th/75th percentiles
ax.axvline(median, color='k', linestyle='-', linewidth=lw, label='Median')
for q_value, q_label in ((quantile_25, '_nolegend'), (quantile_75, 'Quantiles')):
    ax.axvline(q_value, color='k', linestyle='--', linewidth=lw, label=q_label)
# zero-width line: adds an "NMAD" legend entry without drawing anything
ax.axvline(median, color='w', linestyle='-', linewidth=0, label='NMAD')
ax.set_xlabel('Error')
ax.set_ylabel('Frequency')
ax.legend(loc='best', frameon=False)
# minimal axes: no top/right spines, thick left/bottom spines, one tick at zero
ax.spines[['top', 'right']].set_visible(False)
ax.spines[['left', 'bottom']].set_linewidth(lw)
ax.set_xticks([0])
ax.set_yticks([])
fig.tight_layout()
plt.show()

# Save figure
fig_fn = os.path.join(figures_path, "error_distribution_example.png")
fig.savefig(fig_fn, dpi=300, bbox_inches='tight')
print('Figure saved to file:', fig_fn)

## Land cover masks

In [None]:
# Define colors for land cover types (keys are also the initial legend labels)
colors_dict = {'stable_surfaces': '#d9d9d9',
               'trees': '#006d2c',
               'snow': '#4292c6'}


def load_raster(fn):
    """Open a raster with rioxarray, squeeze it, and set 0 and |value| > 1e10 pixels to NaN."""
    raster = rxr.open_rasterio(fn).squeeze()
    raster = xr.where((raster==0) | (np.abs(raster) > 1e10), np.nan, raster)
    return raster


# Set up figure: one row per DEM across all sites; columns = RGB mosaic, land cover masks
n_dems = sum(len(site_info['DEM_fns']) for site_info in info_dict.values())
fig, ax = plt.subplots(n_dems, 2, figsize=(8, 2.8*n_dems), squeeze=False)

# Iterate over sites
dem_count = 0
for site_name in list(info_dict.keys()):
    dem_fns = info_dict[site_name]['DEM_fns']
    site_name_display = info_dict[site_name]["site_name_display"]

    # Iterate over DEMs
    for dem_fn in dem_fns:
        # Grab date from file name
        date = os.path.basename(dem_fn).split(site_name)[1].split('_')[1]
        if date=='20240419-1':
            date_display = '2024-04-19a'
        elif date=='20240419-2':
            date_display = '2024-04-19b'
        else:
            date_display = f"{date[0:4]}-{date[4:6]}-{date[6:]}"

        # Reflectance range (low, high) stretched to the full display range
        if (site_name=='MCS') and (date=='20240420'):
            clim = (0,0.5)
        else:
            clim = (0,0.9)

        # Load input files (stored one level above the DEM's post_process folder)
        date_dir = os.path.join(os.path.dirname(dem_fn), '..')
        rgb_fn = glob.glob(os.path.join(date_dir, '*4band*.tif'))[0]
        rgb = rxr.open_rasterio(rgb_fn)
        # scale to reflectance; non-positive values are treated as nodata
        rgb = xr.where(rgb <= 0, np.nan, rgb / 1e4)
        mask_dir = os.path.join(date_dir, 'stable_surfaces')
        snow_mask = load_raster(os.path.join(mask_dir, 'snow_mask.tif'))
        trees_mask = load_raster(os.path.join(mask_dir, 'trees_mask.tif'))
        ss_mask = load_raster(os.path.join(mask_dir, 'stable_surfaces_mask.tif'))

        # Plot
        # RGB image from bands 2, 1, 0 (presumably SkySat B, G, R, NIR order — confirm).
        # imshow ignores clim/vmin/vmax for RGB(A) input, so stretch explicitly
        # and clip into the valid [0, 1] float range.
        rgb_stack = np.dstack([rgb.data[2], rgb.data[1], rgb.data[0]])
        rgb_stretched = np.clip((rgb_stack - clim[0]) / (clim[1] - clim[0]), 0, 1)
        ax[dem_count,0].imshow(rgb_stretched,
                               extent=(np.min(rgb.x)/1e3, np.max(rgb.x)/1e3, np.min(rgb.y)/1e3, np.max(rgb.y)/1e3))
        ax[dem_count,0].set_ylabel(f"{site_name_display}\n{date_display}", rotation=0, ha='right')
        ax[dem_count,0].set_xticks([])
        ax[dem_count,0].set_yticks([])
        # Land cover masks: value 1 drawn in the class color, everything else transparent
        for mask, label in zip([ss_mask, trees_mask, snow_mask], list(colors_dict.keys())):
            cmap = matplotlib.colors.ListedColormap([(0,0,0,0), matplotlib.colors.to_rgb(colors_dict[label])])
            ax[dem_count,1].imshow(mask, cmap=cmap, clim=(0,1),
                                   extent=(np.min(mask.x)/1e3, np.max(mask.x)/1e3, np.min(mask.y)/1e3, np.max(mask.y)/1e3))
            # dummy point for legend (hidden once the limits are set below)
            ax[dem_count,1].plot(0, 0, 's', color=colors_dict[label], markersize=10, label=label)
        ax[dem_count,1].set_xlim(ax[dem_count,0].get_xlim())
        ax[dem_count,1].set_ylim(ax[dem_count,0].get_ylim())
        ax[dem_count,1].set_yticks([])
        ax[dem_count,1].set_xticks([])

        dem_count += 1

# Add titles
ax[0,0].set_title('RGB mosaic')
ax[0,1].set_title('Land cover masks')
handles, labels = ax[-1,1].get_legend_handles_labels()
labels = ["Stable surfaces", "Trees", "Snow"]

# Add panel labels
text_labels = list(string.ascii_lowercase)
for i, axis in enumerate(ax.ravel()):
    axis.text(0.05, 0.85, text_labels[i], transform=axis.transAxes,
              fontweight='bold', bbox=dict(facecolor='white', edgecolor='None'))
fig.tight_layout()

# Add legend (after tight_layout so it sits above the first row)
ax[0,1].legend(handles, labels, loc='upper center', frameon=False, ncols=3,
               bbox_to_anchor=[0.35, 1.1, 0.2, 0.2], handletextpad=0.3, columnspacing=1)

# close without displaying; the Figure object can still be saved
plt.close()

# Save figure
fig_fn = os.path.join(figures_path, 'land_cover_masks.png')
fig.savefig(fig_fn, dpi=300, bbox_inches='tight')
print('Figure saved to file:', fig_fn)

## Parameter testing: reference DEMs for co-registration and bundle adjustment (MCS, snow-free date only)

In [None]:
dem_fns = glob.glob(os.path.join(data_path, 'MCS', '20241003', '*_DEM.tif'))

# Load reference DEM
refdem_fn = os.path.join(data_path, 'MCS', 'refdem', 'MCS_REFDEM_WGS84.tif')
refdem = xdem.DEM(refdem_fn, load_data=True)

# Load stable surfaces mask
ss_mask_fn = os.path.join(data_path, 'MCS', '20241003', 'stable_surfaces', 'stable_surfaces_mask.tif')
ss_mask = gu.Raster(ss_mask_fn, load_data=True)

# Load GCP: dDEM pixels under the ID-21 highway shapefile set the vertical shift below
gcp_fn = os.path.join(data_path, '..', 'ITD_Functional_Class', 'ITD_HWY_21.shp')
gcp_elev = 0
gcp = gu.Vector(gcp_fn)

# Define all options tested (columns: reference DEM option; rows: bundle-adjust option)
refdem_opts = ['COPDEM+AllLidar', 'COPDEM+SSLidar', 'COPDEM+roads']
refdem_opts_display = ['All surfaces', 'Stable surfaces', 'Roads']
ba_opts = ['noRefDEM', 'u0.1m', 'u1m', 'u10m']
# raw strings: '\s' is an invalid escape sequence in a normal string literal
ba_opts_display = ['No DEM', r'DEM $\sigma$ = 0.1 m', r'DEM $\sigma$ = 1 m', r'DEM $\sigma$ = 10 m']
text_labels = list(string.ascii_lowercase)

# Set up figures (rows = ba_opts, columns = refdem_opts)
nrows, ncols = len(ba_opts), len(refdem_opts)
# Map view
fig1, ax1 = plt.subplots(nrows, ncols, figsize=(ncols*4, nrows*4))
# Histograms
fig2, ax2 = plt.subplots(nrows, ncols, figsize=(ncols*4, nrows*4))

# Stable-surface NMAD per (ba_opt, refdem_opt), shaped like ax1 rather than a
# stale `ax` from another cell. Untested combinations stay +inf so they are
# never selected as the best combination.
nmads = np.full((nrows, ncols), np.inf)

# Iterate over options
dem_count = 0  # panel-letter index; also advances for missing combinations
for i, refdem_opt in enumerate(refdem_opts):
    dem_refdem_fns = [x for x in dem_fns if refdem_opt in os.path.basename(x)]
    # wider color/histogram range for the roads-only reference DEM
    if i==2:
        vmin, vmax = -200,200
    else:
        vmin, vmax=-5,5
    for j, ba_opt in enumerate(ba_opts):
        dem_refdem_ba_fns = [x for x in dem_refdem_fns if ba_opt in os.path.basename(x)]
        if not dem_refdem_ba_fns:
            dem_count += 1
            continue

        # Add row and column titles
        if j==0:
            ax1[j,i].set_title(refdem_opts_display[i])
            ax2[j,i].set_title(refdem_opts_display[i])
        if i==0:
            ax1[j,i].set_ylabel(ba_opts_display[j], rotation=0, ha='right', fontsize=14)
            ax2[j,i].set_ylabel(ba_opts_display[j], rotation=0, ha='right', fontsize=14)

        dem_fn = dem_refdem_ba_fns[0]
        print(os.path.basename(dem_fn))

        # Load DEMs
        dem = xdem.DEM(dem_fn, load_data=True)
        refdem_reproj = refdem.reproject(dem)

        # Load stable surfaces mask
        ss_mask_reproj = ss_mask.reproject(dem)
        ss_mask_reproj = (ss_mask_reproj == 1)

        # Load NMAD mosaic and mask DEM pixels with NMAD >= 10
        nmad_fn = dem_fn.replace('_DEM.tif', '_nmad_mos.tif')
        nmad = gu.Raster(nmad_fn, load_data=True).reproject(dem)
        nmad = (nmad >= 10)
        dem.set_mask(nmad)

        # Calculate dDEM
        ddem = dem - refdem_reproj

        # Apply vertical adjustment: subtract the median dDEM over the GCP mask
        gcp_reproj = gcp.reproject(dem)
        gcp_reproj = gcp_reproj.create_mask(dem)
        ddem_gcp = ddem[gcp_reproj]
        ddem_gcp_median = np.ma.median(ddem_gcp)
        print(ddem_gcp_median)
        dem -= ddem_gcp_median
        ddem -= ddem_gcp_median

        # Calculate stable surface errors
        ddem_ss = ddem[ss_mask_reproj]
        ddem_ss_median = np.ma.median(ddem_ss.data)
        nmads[j,i] = xdem.spatialstats.nmad(ddem_ss)

        # Plot
        im = ax1[j,i].imshow(ddem.data, cmap='coolwarm_r', vmin=vmin, vmax=vmax,
                             extent=(ddem.bounds.left, ddem.bounds.right, ddem.bounds.bottom, ddem.bounds.top))
        ax1[j,i].set_xticks([])
        ax1[j,i].set_yticks([])
        ax1[j,i].set_xlim(602000, 609650)
        ax1[j,i].set_ylim(4863200, 4871300)
        ax2[j,i].hist(ddem_ss.data.ravel(), bins=np.linspace(vmin, vmax, 20),
                      color='gray', edgecolor='k', linewidth=0.5)
        ax2[j,i].set_yticks([])
        if i==2:
            ax2[j,i].set_xlim(-200,100)
            ylim = ax2[j,i].get_ylim()
            ax2[j,i].set_ylim(0, ylim[1]*1.2)
        # add colorbar on last row
        if j==len(ba_opts)-1:
            cbar_ax = fig1.add_axes([ax1[j,i].get_position().x0, 0.07,
                                     ax1[j,i].get_position().width, 0.015])
            fig1.colorbar(im, cax=cbar_ax, orientation='horizontal', label='Elevation residual [m]')
        # add stats
        for axis in [ax1[j,i], ax2[j,i]]:
            axis.text(0.01, 0.93, f'Median: {np.round(float(ddem_ss_median),2)} m',
                      transform=axis.transAxes)
            axis.text(0.01, 0.85, f'NMAD: {np.round(float(nmads[j,i]),2)} m',
                      transform=axis.transAxes)
            axis.text(0.9, 0.9, text_labels[dem_count], transform=axis.transAxes,
                      fontweight='bold', fontsize=14, bbox=dict(facecolor='white', edgecolor='None'))

        dem_count += 1

# thicker frame for the lowest NMAD combo
ibest = np.argwhere(nmads==np.min(nmads))[0]
ax1[ibest[0], ibest[1]].spines[['top', 'bottom', 'right', 'left']].set_linewidth(3)
ax2[ibest[0], ibest[1]].spines[['top', 'bottom', 'right', 'left']].set_linewidth(3)

plt.show()

# Save to file
fig1_fn = os.path.join(figures_path, 'param_tests_maps.png')
fig1.savefig(fig1_fn, dpi=300, bbox_inches='tight')
print('Figure 1 saved to file:', fig1_fn)
# NOTE(review): "histgrams" is misspelled; kept so existing references to the file still resolve
fig2_fn = os.path.join(figures_path, 'param_tests_histgrams.png')
fig2.savefig(fig2_fn, dpi=300, bbox_inches='tight')
print('Figure 2 saved to file:', fig2_fn)

## Stable surface errors: maps and histograms

In [None]:
# Set up figure: one row per DEM (hard-coded to 5 rows), columns =
# DEM with shaded relief, dDEM map, stable-surface error histogram
plt.rcParams.update({'font.size': 12, 'font.sans-serif':'Arial'})
fig, ax = plt.subplots(5, 3, figsize=(10,12), gridspec_kw={'width_ratios': [1, 1, 1.5]})
# leave room at the bottom for the horizontal colorbars added after the loop;
# their positions are read from the axes after this call, so keep the order
fig.subplots_adjust(bottom=0.1)
# histogram bins for elevation residuals: -5 to 5 m in 0.2 m steps
bins = np.arange(-5, 5.1, step=0.2)

# Iterate over sites
dem_count = 0  # row index into ax; incremented once per DEM
for site_name in list(info_dict.keys()):
    site_name_display = info_dict[site_name]['site_name_display']
    print(site_name_display)
    dem_fns = info_dict[site_name]['DEM_fns']
    # Load reference DEM
    refdem_fn = info_dict[site_name]['refdem_fn']
    refdem = xdem.DEM(refdem_fn)

    # Iterate over DEMs
    for dem_fn in dem_fns:
        # Grab date from file name (the token after the site name); the two
        # Banner acquisitions on 2024-04-19 are distinguished as "a" and "b"
        date = os.path.basename(dem_fn).split(site_name)[1].split('_')[1]
        if date=='20240419-1':
            date_disp = '2024-04-19a'
        elif date=='20240419-2':
            date_disp = '2024-04-19b'
        else:
            date_disp = f"{date[0:4]}-{date[4:6]}-{date[6:]}"

        ### Load input files
        # Load DEM and difference it against the reference DEM on the DEM's grid
        dem = xdem.DEM(dem_fn)
        refdem_reproj = refdem.reproject(dem)
        ddem = dem - refdem_reproj

        # NOTE(review): disabled SNOTEL loading below; nothing in this figure uses it
        # Load SNOTEL site location and snow depth
        # snotel_site_fn = info_dict[site_name]["SNOTEL_site_fn"]
        # snotel_site = pd.read_csv(snotel_site_fn)
        # snotel_site['geometry'] = wkt.loads(snotel_site['geometry'])
        # snotel_gdf = gpd.GeoDataFrame(snotel_site, crs="EPSG:4326")
        # snotel_gdf = snotel_gdf.to_crs(dem.crs)
        # if (site_name=='MCS') & (date=='20241003'):
        #     sd_snotel = 0
        # else:
        #     snotel_fn = info_dict[site_name]["SNOTEL_fn"]
        #     snotel = pd.read_csv(snotel_fn)
        #     snotel['datetime'] = np.array(snotel['datetime']).astype('datetime64[D]')
        #     sd_snotel = snotel.loc[snotel['datetime']==np.datetime64(f"{date[0:4]}-{date[4:6]}-{date[6:8]}"), "SNWD_m"].values[0]

        # Load orthomosaic
        # NOTE(review): `ortho` is not used anywhere below in this cell
        ortho_fn = glob.glob(os.path.join(os.path.dirname(dem_fn), '..', '*orthomosaic.tif'))[0]
        ortho = gu.Raster(ortho_fn)

        # Load stable surfaces mask (pixels equal to 1) on the DEM grid
        ss_mask_fn = os.path.join(os.path.dirname(dem_fn), '..', 'stable_surfaces', 'stable_surfaces_mask.tif')
        ss_mask = gu.Raster(ss_mask_fn).reproject(dem)
        ss_mask = (ss_mask == 1)
        # Keep only dDEM values on stable surfaces
        ddem_ss = ddem[ss_mask]

        # Calculate stable surface stats
        # NOTE(review): these statistics are computed but not drawn on the figure
        ddem_ss_median, ddem_ss_nmad = np.nanmedian(ddem_ss.data), xdem.spatialstats.nmad(ddem_ss)
        ddem_ss_p25, ddem_ss_p75 = np.nanpercentile(ddem_ss.data, 25), np.nanpercentile(ddem_ss.data, 75)

        ### Plot
        # Shaded relief with a semi-transparent elevation colormap on top
        ls = LightSource(azdeg=315, altdeg=45)
        hs = ls.hillshade(dem.data, vert_exag=5)
        ax[dem_count,0].imshow(hs, cmap='Greys_r',
                            extent=(dem.bounds.left, dem.bounds.right, dem.bounds.bottom, dem.bounds.top))
        hs_im = ax[dem_count,0].imshow(dem.data, cmap='terrain', alpha=0.7, clim=(1500, 3000),
                                       extent=(dem.bounds.left, dem.bounds.right, dem.bounds.bottom, dem.bounds.top))
        ax[dem_count,0].set_ylabel(f"{site_name_display}\n{date_disp}", rotation=0, ha='right')
        # dDEM, symmetric color range of +/- 5 m
        ddem_im = ax[dem_count,1].imshow(ddem.data, cmap='coolwarm_r', vmin=-5, vmax=5,
                                         extent=(ddem.bounds.left, ddem.bounds.right, ddem.bounds.bottom, ddem.bounds.top))
        # Histogram of stable-surface elevation residuals with a reference line at 0
        ax[dem_count,2].hist(ddem_ss.ravel(), color='gray', alpha=0.5, edgecolor='k', linewidth=0.5, bins=bins)
        ax[dem_count,2].axvline(0, color='k', linewidth=1)
        ax[dem_count,2].set_xlim(np.min(bins),np.max(bins))
        ax[dem_count,2].set_yticks([])
        ax[dem_count,2].set_ylabel('Frequency')

        dem_count += 1

# add titles
ax[0,0].set_title('DEM')
ax[0,1].set_title('dDEM')
ax[0,2].set_title('Stable surface errors')
ax[-1,2].set_xlabel('Elevation residual [m]')
# add colorbars below the first two columns, aligned with the bottom-row axes
# (uses the last hs_im / ddem_im images drawn in the loop)
cax = fig.add_axes([ax[-1,0].get_position().x0, 0.06, ax[-1,0].get_position().width, 0.02])
fig.colorbar(hs_im, cax=cax, orientation='horizontal', label='Elevation [m]')
cax = fig.add_axes([ax[-1,1].get_position().x0, 0.06, ax[-1,1].get_position().width, 0.02])
fig.colorbar(ddem_im, cax=cax, orientation='horizontal', label='Elevation residual [m]')
# remove coordinates from maps
for axis in ax[:, 0:2].ravel():
    axis.set_xticks([])
    axis.set_yticks([])
# add panel labels (a, b, c, ... in row-major order)
text_labels = list(string.ascii_lowercase)
for i, axis in enumerate(ax.ravel()):
    axis.text(0.05, 0.85, text_labels[i], transform=axis.transAxes,
              fontweight='bold', bbox=dict(facecolor='white', edgecolor='None'))

plt.show()

# Save figure
fig_fn = os.path.join(figures_path, 'dDEMs.png')
fig.savefig(fig_fn, dpi=300, bbox_inches='tight')
print('Figure saved to file:', fig_fn)
    

## Stable surface errors vs. terrain parameters

In [None]:
# Set up figure: one row per DEM; columns = elevation, slope, aspect
fig, ax = plt.subplots(5, 3, figsize=(12,12))

# Iterate over sites
dem_count = 0
for site_name in list(info_dict.keys()):
    print(site_name)
    refdem_fn = info_dict[site_name]['refdem_fn']
    refdem = xdem.DEM(refdem_fn)
    # Sampling resolutions for the terrain parameters, cast to float
    terrain_dict_fn = os.path.join(os.path.dirname(refdem_fn), 'refdem_terrain_sampling_resolutions.json')
    with open(terrain_dict_fn, 'r') as f:
        terrain_dict = {key: float(value) for key, value in json.load(f).items()}

    # Calculate terrain parameters on the reference DEM grid (tpi and sx are not plotted here)
    elev, slope, aspect, tpi, sx = pprocess.calculate_terrain_params(refdem, terrain_dict)

    dem_fns = info_dict[site_name]['DEM_fns']
    # Iterate over DEMs
    for dem_fn in dem_fns:
        print(os.path.basename(dem_fn))
        date = os.path.basename(dem_fn).split(site_name)[1].split('_')[1]
        if date=='20240419-1':
            date_disp = '2024-04-19a'
        elif date=='20240419-2':
            date_disp = '2024-04-19b'
        else:
            date_disp = f"{date[0:4]}-{date[4:6]}-{date[6:]}"
        # Load DEMs and stable surface masks
        dem = xdem.DEM(dem_fn)
        refdem_reproj = refdem.reproject(dem)
        ss_mask_fn = os.path.join(data_path, site_name, date, 'stable_surfaces', 'stable_surfaces_mask.tif')
        ss_mask = gu.Raster(ss_mask_fn, load_data=True).reproject(dem)
        ss_mask = (ss_mask==1)

        # Calculate dDEM and mask everything except stable surfaces
        ddem = dem - refdem_reproj
        ddem_masked = ddem.copy()
        ddem_masked.set_mask(~ss_mask)

        # Regrid terrain parameters onto this DEM's grid. Use new names so the
        # reference-grid rasters are not overwritten: otherwise a site's second
        # DEM would reproject an already cropped/resampled grid.
        elev_reproj = elev.reproject(dem)
        slope_reproj = slope.reproject(dem)
        aspect_reproj = aspect.reproject(dem)

        # Compile in dataframe; rows without valid values are dropped
        df = pd.DataFrame({
                'elevation': elev_reproj.data.ravel(),
                'slope': slope_reproj.data.ravel(),
                'aspect': aspect_reproj.data.ravel(),
                'dDEM': ddem_masked.data.ravel()
                })
        df.dropna(inplace=True)

        # Create bins for boxplots (tick labels below assume nbins=20 and 15 slope bins)
        nbins = 20
        elev_min = np.floor(np.nanmin(df['elevation']) / 100) * 100
        elev_max = np.ceil(np.nanmax(df['elevation']) / 100) * 100
        df['elevation_bin'] = pd.cut(df['elevation'], bins=np.linspace(elev_min, elev_max, nbins + 1))
        df['slope_bin'] = pd.cut(df['slope'], bins=np.linspace(0, 90, 16))
        df['aspect_bin'] = pd.cut(df['aspect'], bins=np.linspace(0, 360, nbins + 1))

        # Plot: per-bin residual boxplots over a faint pixel-count bar chart
        for i, column in enumerate(['elevation', 'slope', 'aspect']):
            box_ax = ax[dem_count, i]
            # pixel counts per bin on a twin y-axis
            count_ax = box_ax.twinx()
            bin_counts = df[column+'_bin'].value_counts(sort=False)
            count_ax.bar(bin_counts.index.astype(str), bin_counts.values, color='k', alpha=0.5, width=1.0)
            count_ax.set_yticks([])
            # bars fill only the bottom third of the panel
            count_ax.set_ylim(0, np.nanmax(bin_counts.values)*3)
            # boxplot of residuals per bin
            sns.boxplot(data=df, x=column+'_bin', y='dDEM', color='#9ecae1', showfliers=False, ax=box_ax)
            box_ax.set_xlabel('')
            box_ax.set_xticks(box_ax.get_xticks())
            if column=='elevation':
                box_ax.set_xticklabels([str(int(elev_min))] + ['']*9 + [str(int((elev_max+elev_min)/2))] + ['']*8 + [str(int(elev_max))])
                box_ax.set_xlabel('meters')
            elif column=='slope':
                box_ax.set_xticklabels(['0'] + ['']*4 + ['30'] + ['']*4 + ['60'] + ['']*4)
                box_ax.set_xlabel('degrees')
            elif column=='aspect':
                box_ax.set_xticklabels(['N'] + ['']*4 + ['E'] + ['']*4 + ['S'] + ['']*4 + ['W'] + ['']*4)

            # zero-residual reference line; axhline's xmin/xmax are axes fractions,
            # so the defaults (0, 1) already span the full width
            box_ax.axhline(y=0, color='k', linewidth=1)
        ax[dem_count, 0].set_ylabel(f"{info_dict[site_name]['site_name_display']}\n{date_disp}", rotation=0, ha='right')

        dem_count+=1

# add panel labels
text_labels = list(string.ascii_lowercase)
for i, axis in enumerate(ax.ravel()):
    axis.text(0.05, 0.85, text_labels[i], transform=axis.transAxes,
              fontweight='bold', bbox=dict(facecolor='None', edgecolor='None'))

ax[0,0].set_title('Elevation')
ax[0,1].set_title('Slope')
ax[0,2].set_title('Aspect')
plt.show()

# Save figure
fig_fn = os.path.join(figures_path, 'stable_surface_errors_vs_terrain.png')
fig.savefig(fig_fn, dpi=300, bbox_inches='tight')
print('Figure saved to file:', fig_fn)
                

In [None]:
# Re-save the terrain figure (optionally call fig.tight_layout() first)
save_kwargs = dict(dpi=300, bbox_inches='tight')
fig_fn = os.path.join(figures_path, 'stable_surface_errors_vs_terrain.png')
fig.savefig(fig_fn, **save_kwargs)
print('Figure saved to file:', fig_fn)

## Snow depth estimates and models: Maps and scatterplots

In [None]:
# Map views: one row per snow-on DEM; columns = SkySat, modeled at site, modeled at MCS (lidar for MCS)
fig1, ax1 = plt.subplots(4, 3, figsize=(10,12))
# Scatterplots
fig2, ax2 = plt.subplots(1, 2, figsize=(8,6))

# Load lidar snow depths at MCS, resampled to 1 m
sd_lidar_fn = os.path.join(data_path, 'MCS', 'SNEX_MCS_Lidar', 'SNEX_MCS_Lidar_20240418_SD_V01.0.tif')
sd_lidar = xdem.DEM(sd_lidar_fn).reproject(res=[1,1])

# Initialize data for scatter plots
diffs_site = []
diffs_mcs = []

# Iterate over site names
dem_count=0
for site_name in list(info_dict.keys()):
    print(site_name)
    site_name_display = info_dict[site_name]['site_name_display']
    dem_fns = info_dict[site_name]['DEM_fns']

    # Iterate over DEMs
    for dem_fn in dem_fns:
        print(os.path.basename(dem_fn))
        # Get date from file name
        date = os.path.basename(dem_fn).split(site_name)[1].split('_')[1]
        if (site_name=='MCS') and (date=='20240420'):
            continue
        if date=='20240419-1':
            date_disp = '2024-04-19a'
        elif date=='20240419-2':
            date_disp = '2024-04-19b'
        else:
            date_disp = f"{date[0:4]}-{date[4:6]}-{date[6:]}"

        # Load SkySat snow depth. dem_fn already lives in post_process/, so the
        # search must not append another 'post_process' level.
        sd_skysat_fn = glob.glob(os.path.join(os.path.dirname(dem_fn), '*_snow_depth.tif'))[0]
        sd_skysat = xdem.DEM(sd_skysat_fn)

        # Load modeled snow depth at site
        # NOTE(review): confirm snow_depth_modeling/ sits inside post_process/ (other inputs live one level up)
        sd_mod_site_fn = os.path.join(os.path.dirname(dem_fn), 'snow_depth_modeling', f'moded_snow_depth_{site_name}_{date}.tif')
        sd_mod_site = xdem.DEM(sd_mod_site_fn).reproject(sd_skysat)

        # Load modeled snow depth at MCS (model trained at this site, extrapolated to MCS)
        if site_name=='MCS':
            sd_mod_mcs = None
        else:
            sd_mod_mcs_fn = os.path.join(os.path.dirname(dem_fn), 'snow_depth_modeling', 'moded_snow_depth_extrapMCS.tif')
            sd_mod_mcs = xdem.DEM(sd_mod_mcs_fn).reproject(sd_skysat)
            # Calculate differences at MCS against the lidar snow depths
            sd_lidar_reproj = sd_lidar.reproject(sd_mod_mcs)
            sd_diff_mcs = (sd_mod_mcs - sd_lidar_reproj).data.ravel()
            diffs_mcs.append([sd_diff_mcs])

        # Calculate differences at site
        sd_diff_site = (sd_mod_site - sd_skysat).data.ravel()
        diffs_site.append([sd_diff_site])

        # Plot into this cell's map figure (ax1), one column per product
        # SkySat
        sd_skysat.plot(cmap='Blues', vmin=0, vmax=5, add_cbar=False, ax=ax1[dem_count,0])
        ax1[dem_count,0].set_ylabel(f"{site_name_display}\n{date_disp}", rotation=0, ha='right')
        # Modeled site
        sd_mod_site.plot(cmap='Blues', vmin=0, vmax=5, add_cbar=False, ax=ax1[dem_count,1])
        # Modeled MCS (lidar reference when the site itself is MCS)
        if site_name=='MCS':
            sd_lidar.plot(cmap='Blues', vmin=0, vmax=5, add_cbar=False, ax=ax1[dem_count,2])
        else:
            sd_mod_mcs.plot(cmap='Blues', vmin=0, vmax=5, add_cbar=False, ax=ax1[dem_count,2])

        for axis in ax1[dem_count,:]:
            axis.set_xticks([])
            axis.set_yticks([])

        dem_count+=1

# TODO: Add panel labels

# TODO: Add colorbar