# Make figures

In [None]:
import os
import matplotlib.pyplot as plt
import pandas as pd
import geopandas as gpd
import contextily as ctx
from shapely import wkt
from shapely.geometry import Polygon, Point
import geoutils as gu
import numpy as np
from geopy.geocoders import Nominatim
from matplotlib.colors import LightSource
from matplotlib.gridspec import GridSpec
import string

In [None]:
# Root directories: study-site inputs and the folder figures are written to
data_path = '/Volumes/LaCie/raineyaberle/Research/PhD/SkySat-Stereo/study-sites'
figures_path = '/Users/raineyaberle/Research/PhD/SnowDEMs/skysat-snow/figures'


def _site_info(site, refdem_name, dates, satellites, display_name):
    """Collect the input file paths and display labels for one study site.

    All files live under <data_path>/<site>/; each entry of `dates` names a
    sub-folder holding one DEM called <site>_<date>_DEM.tif.
    """
    site_dir = os.path.join(data_path, site)
    return {
        "refdem_fn": os.path.join(site_dir, 'refdem', refdem_name),
        "SNOTEL_fn": os.path.join(site_dir, 'snotel', f'{site}_SNOTEL_site_info.csv'),
        "DEM_fns": [os.path.join(site_dir, date, f"{site}_{date}_DEM.tif") for date in dates],
        "satellites": list(satellites),
        "site_name_display": display_name,
    }


# Site-specific details, keyed by the site's folder name
info_dict = {
    "MCS": _site_info(
        "MCS",
        "MCS_REFDEM_WGS84.tif",
        dates=["20240420", "20241003"],
        satellites=["s116", "s109"],
        display_name="Mores Creek Summit",
    ),
    "JacksonCreek": _site_info(
        "JacksonCreek",
        "USGS_LPC_ID_FEMAHQ_2018_D18_merged_filtered.tif",
        dates=["20240420"],
        satellites=["s113"],
        display_name="Jackson Peak",
    ),
    "Banner": _site_info(
        "Banner",
        "Banner_REFDEM_WGS84.tif",
        dates=["20240419-1", "20240419-2"],
        satellites=["s112", "s108"],
        display_name="Banner Summit",
    ),
}

## Study sites map

In [None]:
# Figure configuration: CRS shared by all map layers, font settings, 2x2 panel grid
crs = "EPSG:32611"
fontsize = 12
plt.rcParams.update({'font.size': fontsize, 'font.sans-serif': "Arial"})
fig, ax = plt.subplots(2, 2, figsize=(10, 10))
ax = ax.ravel()  # flatten so panels can be addressed as ax[0] .. ax[3]

# Read the HWY 21 shapefile and draw it on every panel
hwy21_fn = os.path.join(data_path, '..', 'ITD_Functional_Class', 'ITD_HWY_21.shp')
hwy21 = gpd.read_file(hwy21_fn).to_crs(crs)
hwy21_segments = hwy21.geometry[0].geoms
for i, axis in enumerate(ax):
    for j, geom in enumerate(hwy21_segments):
        # Only the very first segment on the first panel gets a legend entry
        first_segment = (i == 0) and (j == 0)
        label = 'HWY 21' if first_segment else '_nolegend'
        seg_x, seg_y = geom.coords.xy
        # Wide white line underneath a thinner black line gives the road a white casing
        axis.plot(seg_x, seg_y, color='w', linewidth=2)
        axis.plot(seg_x, seg_y, color='k', linewidth=1, label=label)
        
def bounds_to_polygon(bounds):
    """Return the rectangle described by `bounds` as a shapely Polygon.

    `bounds` must expose `left`, `right`, `bottom` and `top` attributes.
    The ring starts at the lower-left corner, visits lower-right, upper-right
    and upper-left, and is closed explicitly by repeating the first vertex.
    """
    left, right = bounds.left, bounds.right
    bottom, top = bounds.bottom, bounds.top
    ring = [
        [left, bottom],
        [right, bottom],
        [right, top],
        [left, top],
        [left, bottom],
    ]
    return Polygon(ring)
    
# Iterate over site names: each site gets its own panel (b-d) and an outline on the overview panel (a)
labels = ['a', 'b', 'c', 'd']
pad = 2e3  # margin around each reference DEM footprint, in units of `crs`
ls = LightSource(azdeg=315, altdeg=45)  # identical for every site, so create it once
for i, site_name in enumerate(info_dict):
    site_info = info_dict[site_name]
    site_ax = ax[i+1]

    # Reference DEM, reprojected to the figure CRS
    refdem_fn = site_info['refdem_fn']
    refdem = gu.Raster(refdem_fn, load_data=True).reproject(crs=crs, res=10)
    bounds = refdem.bounds
    img_extent = (bounds.left, bounds.right, bounds.bottom, bounds.top)
    x0, x1 = bounds.left - pad, bounds.right + pad
    y0, y1 = bounds.bottom - pad, bounds.top + pad

    # Basemap: add_basemap requests tiles for the axis extent at call time, so
    # narrow the view to this site's window first instead of fetching tiles
    # for whatever extent the highway lines left on the axis.
    site_ax.set_xlim(x0, x1)
    site_ax.set_ylim(y0, y1)
    ctx.add_basemap(ax=site_ax, source=ctx.providers.USGS.USImagery, attribution=False, crs=crs, zoom=12)

    # Shaded relief, with semi-transparent elevation colours on top
    hs = ls.hillshade(refdem.data, vert_exag=5)
    site_ax.imshow(hs, cmap='Greys_r', extent=img_extent)
    im = site_ax.imshow(refdem.data, cmap='terrain', alpha=0.7, clim=(1500, 3000), extent=img_extent)

    # Outline and panel letter on the overview map
    ax[0].plot([x0, x1, x1, x0, x0], [y0, y0, y1, y1, y0],
               '-w', linewidth=1, label='_nolegend')
    ax[0].text((bounds.left + bounds.right)/2, (bounds.bottom + bounds.top)/2,
               labels[i+1], fontweight='bold', color='w', fontsize=fontsize+2, ha='center', va='center')

    # DEM footprints
    for dem_fn in site_info['DEM_fns']:
        dem = gu.Raster(dem_fn)
        # dem.bounds is expressed in the DEM's own CRS; bring the footprint into
        # the figure CRS so it lines up with the other layers on the panel.
        dem_poly = gpd.GeoSeries([bounds_to_polygon(dem.bounds)], crs=dem.crs).to_crs(crs).iloc[0]
        site_ax.plot(*dem_poly.exterior.coords.xy, color='#fe9929', linewidth=2, label='SkySat DEM')

    # SNOTEL site info (geometry stored as WKT in the CSV, read as EPSG:4326)
    snotel_fn = site_info['SNOTEL_fn']
    snotel = pd.read_csv(snotel_fn)
    snotel['geometry'] = snotel['geometry'].apply(wkt.loads)
    snotel_gdf = gpd.GeoDataFrame(snotel, geometry='geometry', crs='EPSG:4326')
    # Project to the figure CRS rather than to the CRS of whichever DEM was opened last
    snotel_gdf = snotel_gdf.to_crs(crs)
    site_ax.plot(*snotel_gdf['geometry'].values[0].coords.xy, '*', markersize=15,
                 markerfacecolor='b', markeredgecolor='w', linewidth=0.3, label="SNOTEL site")

    # title
    site_ax.set_title(f"{labels[i+1]}) {site_info['site_name_display']}")
    # adjust axes: no ticks, and re-apply the site window so it is the final view
    site_ax.set_xticks([])
    site_ax.set_yticks([])
    site_ax.set_xlim(x0, x1)
    site_ax.set_ylim(y0, y1)
    
# legend: collect labelled artists from every panel and keep the first handle
# seen for each label, so the legend does not rely on one particular panel
# happening to contain exactly one DEM outline.
legend_entries = {}
for axis in ax:
    for handle, label in zip(*axis.get_legend_handles_labels()):
        legend_entries.setdefault(label, handle)
fig.legend(list(legend_entries.values()), list(legend_entries.keys()), loc='upper center',
           ncols=max(len(legend_entries), 1), frameon=False, bbox_to_anchor=[0.4, 0.75, 0.2, 0.2])

# colorbar: free up space on the right and draw it in a dedicated axes;
# `im` is the elevation image of the last site panel plotted
fig.subplots_adjust(right=0.87, hspace=0.25)
cbar_ax = fig.add_axes([0.9, 0.25, 0.03, 0.5])
fig.colorbar(im, cax=cbar_ax, shrink=0.8, label='Elevation [m]')

# Overview panel (a): title, tick labels in km, axis labels, fixed map window
overview_ax = ax[0]
overview_ax.set_title('a) All sites')
# Pin the current tick positions, then relabel them as coordinate / 1e3 truncated to integers
overview_ax.set_xticks(overview_ax.get_xticks())
xtick_km = np.divide(overview_ax.get_xticks(), 1e3).astype(int).astype(str)
overview_ax.set_xticklabels(xtick_km)
overview_ax.set_yticks(overview_ax.get_yticks())
ytick_km = np.divide(overview_ax.get_yticks(), 1e3).astype(int).astype(str)
overview_ax.set_yticklabels(ytick_km)
overview_ax.set_xlabel('Easting [km]')
overview_ax.set_ylabel('Northing [km]')
# Map window, in the same coordinates as the plotted layers
xmin, xmax = 555e3, 675e3
ymin, ymax = 4818e3, 4920e3
overview_ax.set_xlim(xmin, xmax)
overview_ax.set_ylim(ymin, ymax)

# Idaho locator map: state outline in a small inset, with the overview window shaded
idaho_fn = '/Users/raineyaberle/Research/PhD/GIS_data/US_State_Boundaries/Idaho_State_Boundaries.gpkg'
idaho = gpd.read_file(idaho_fn).to_crs(crs)
# Inset axes positioned in figure coordinates [left, bottom, width, height]; frame and ticks hidden
ax2 = fig.add_axes([0.11, 0.73, 0.15, 0.15])
ax2.set_axis_off()
idaho.plot(ax=ax2, facecolor='#bdbdbd', edgecolor='k', linewidth=0.5)
# Shade the extent currently displayed on the overview panel
(xmin, xmax), (ymin, ymax) = ax[0].get_xlim(), ax[0].get_ylim()
ax2.fill_between([xmin, xmax], [ymin, ymin], [ymax, ymax], color='g')

# Add city names: geocode each city, project to the figure CRS, and label it on the overview panel
geolocator = Nominatim(user_agent="My-Map")
cities = ["Boise", "Idaho City", "Stanley, Custer County, Idaho"]
city_names, city_points = [], []
for city in cities:
    location = geolocator.geocode(city, addressdetails=True)
    if location is None:
        # geocode() returns None when the query has no match; skip the label
        # instead of failing on `location.longitude`
        print('Could not geocode city, skipping label:', city)
        continue
    city_names.append(city)
    city_points.append(Point(location.longitude, location.latitude))
# Build the table once instead of growing a DataFrame inside the loop
gdf = gpd.GeoDataFrame({'city': city_names}, geometry=city_points, crs='EPSG:4326')
gdf = gdf.to_crs(crs).reset_index(drop=True)
for city, point in zip(gdf['city'], gdf.geometry):
    # Show only the part of the query before the first comma (e.g. "Stanley")
    ax[0].text(point.x, point.y,
               city.split(',')[0], fontsize=fontsize, color='w', fontweight='bold', style='italic', ha='center')

# basemap for the overview panel
# Use the figure CRS: every layer on this panel is plotted in `crs`, whereas
# `dem` is only whichever raster the site loop happened to open last.
ctx.add_basemap(ax=ax[0], source=ctx.providers.USGS.USImagery, attribution=False, crs=crs)

# fig.tight_layout(pad=3)
plt.show()
    
# Save figure
fig_fn = os.path.join(figures_path, "study_sites.png")
fig.savefig(fig_fn, dpi=300, bbox_inches='tight')
print('Figure saved to file:', fig_fn)
    

## Stable surface errors

In [None]:
# # Stable surfaces and GCP differences
        # # Load files
        # ddem_fn = os.path.join(data_path, site_name, date, 'skysat_snow', 'corr_coreg_diff', 'final_ddem.tif')
        # ddem = gu.Raster(ddem_fn)
        # ss_mask_fn = os.path.join(data_path, site_name, date, 'skysat_snow', 'land_cover_masks', 'stable_surfaces_mask.tif')
        # ss_mask = gu.Raster(ss_mask_fn).reproject(ddem)
        # ss_mask = (ss_mask==1)
        # gcp_mask_fn = os.path.join(data_path, site_name, date, 'skysat_snow', 'land_cover_masks', 'gcp_mask.tif')
        # gcp_mask = gu.Raster(gcp_mask_fn).reproject(ddem)
        # gcp_mask = (gcp_mask==1)
        # # Mask dDEM to stable surfaces and GCP
        # ddem_ss = ddem[ss_mask]
        # ddem_gcp = ddem[gcp_mask]
        # # Plot histograms
        # ax[2,i].hist(ddem_ss.data.ravel(), bins=bins, color=ss_color, alpha=0.8, label='Stable surfaces')
        # ax2 = ax[2,i].twinx()
        # hist = ax2.hist(ddem_gcp.data.ravel(), bins=bins, color=gcp_color, alpha=0.8, label='GCP')
        # # adjust axes
        # ax2.set_ylim(0, 1.4*np.nanmax(hist[0]))
        # ax[2,i].set_xlim(vmin, vmax)
        # ax[2,i].set_xlabel('dDEM [m]')
        # ax[2,i].axvline(0, color='k', linewidth=1)
        # ax[2,i].tick_params(axis='y', colors=ss_color, labelsize=fontsize-2)
        # ax[2,i].spines['left'].set_color(ss_color)
        # ax2.tick_params(axis='y', colors=gcp_color, labelsize=fontsize-2)
        # ax2.spines['right'].set_color(gcp_color)

## Ramps and Tilts

In [None]:
# Set up figure
fontsize = 12
plt.rcParams.update({'font.size': fontsize, 'font.sans-serif': "Arial"})
# One column per SkySat DEM listed in info_dict, plus a narrow trailing column
# for the colorbars (derived from the config rather than hard-coded)
n_dems = sum(len(site_info['DEM_fns']) for site_info in info_dict.values())
fig, ax = plt.subplots(2, n_dems + 1, figsize=(12, 6), constrained_layout=True,
                       gridspec_kw=dict(width_ratios=[1] * n_dems + [0.1]))

# Grab list of satellites, flattened across sites in info_dict order
satellites = []
for site_info in info_dict.values():
    satellites.extend(site_info['satellites'])

# Define histogram settings (in this cell only the commented-out histogram code refers to these)
vmin, vmax = -5, 5
bins = np.linspace(vmin, vmax, 50)
ss_color = 'gray'
gcp_color = 'b'

# Row labels on the first column
ax[0,0].set_ylabel('Ramp', fontsize=fontsize)
ax[1,0].set_ylabel('Tilt', fontsize=fontsize)
# ax[2,0].set_ylabel('Pixel counts', fontsize=fontsize)

# Flatten (site, DEM file) pairs so the column index comes straight from enumerate
site_dems = [(site_name, dem_fn)
             for site_name in info_dict
             for dem_fn in info_dict[site_name]['DEM_fns']]

# Iterate over DEMs: column i holds the ramp (top row) and tilt (bottom row) for one DEM
for i, (site_name, dem_fn) in enumerate(site_dems):
    # Date token is the second "_"-separated field of the file name
    # (<site>_<date>_DEM.tif). Parse the basename only, so underscores in
    # parent directories cannot shift the field index.
    date = os.path.basename(dem_fn).split('_')[1]
    diff_dir = os.path.join(data_path, site_name, date, 'skysat_snow', 'corr_coreg_diff')

    # Deramp surface
    deramp_diff_fn = os.path.join(diff_dir, 'fit1_deramp_difference.tif')
    deramp_diff = gu.Raster(deramp_diff_fn)
    # Plot the surface
    deramp_im = ax[0,i].imshow(deramp_diff.data, cmap='PuOr', vmin=-15, vmax=15)
    ax[0,i].set_xticks([])
    ax[0,i].set_yticks([])

    # Tilt fit
    tilt_diff_fn = os.path.join(diff_dir, 'fit2_tilt_difference.tif')
    tilt_diff = gu.Raster(tilt_diff_fn)
    # Plot the surface
    tilt_im = ax[1,i].imshow(tilt_diff.data, cmap='copper_r', vmin=-0.01, vmax=0.01)
    ax[1,i].set_xticks([])
    ax[1,i].set_yticks([])

    # Column title: site, acquisition date, satellite ID
    ax[0,i].set_title(f"{info_dict[site_name]['site_name_display']}\n{date}\n{satellites[i]}", fontsize=fontsize)

# colorbars: drawn in the last column, using the images from the last DEM plotted
# (all columns share the same vmin/vmax per row)
fig.colorbar(deramp_im, cax=ax[0,-1])
fig.colorbar(tilt_im, cax=ax[1,-1])

# legend
# ax[-1,-1].set_axis_off()
# handles1, labels1 = ax[2,4].get_legend_handles_labels()
# handles2, labels2 = ax2.get_legend_handles_labels()
# fig.legend(handles1+handles2, labels1+labels2, loc='lower right', frameon=False, bbox_to_anchor=[0.9, 0.15, 0.2, 0.2])

# panel labels: letter every panel except the last column, going row by row
labels = list(string.ascii_lowercase)
for count, panel in enumerate(ax[:, :-1].ravel()):
    panel.text(
        0.1, 0.9, labels[count],
        transform=panel.transAxes,  # position given as a fraction of the panel
        ha='center', va='center',
        fontsize=fontsize+2, fontweight='bold', color='k',
        bbox=dict(facecolor='w', edgecolor='None', alpha=0.6),
    )

plt.show()

# Write the figure to the figures folder
fig_fn = os.path.join(figures_path, 'surface_corrections.png')
fig.savefig(fig_fn, dpi=300, bbox_inches='tight')
print('Figure saved to file:', fig_fn)


## Canopy height comparison

## Snow depth comparison with SNOTEL and Lidar

## Snow depth prediction