# Classify snow-covered area (SCA) in Landsat, Sentinel-2, and MODIS surface reflectance imagery: full pipeline

Rainey Aberle

Department of Geosciences, Boise State University

2022

### Requirements:
- Area of Interest (AOI) shapefile: where snow will be classified in all available images. 
- Google Earth Engine (GEE) account: used to pull DEM over the AOI. Sign up for a free account [here](https://earthengine.google.com/new_signup/). 

### Outline:
__0. Setup__ paths in directory, AOI file location - _modify this section!_

__1. Load images__ over the AOI since 2016. 

__2. Prepare image collections__ for classification: reprojection and image quality masking. 

__3. Classify SCA__ and use the distribution of snow-covered elevations to estimate the seasonal snowline.

-------


### 0. Setup

#### Define paths in directory and desired settings. 
Modify lines located within the following:

`#### MODIFY HERE ####`  

`#####################`

In [None]:
##### MODIFY HERE #####

# -----Paths in directory
# name of the study site; every site-specific path below is built from it
site_name = 'Wolverine'
# path to snow-cover-mapping/
base_path = '/Users/raineyaberle/Research/PhD/snow_cover_mapping/snow-cover-mapping/'
# folder that holds all inputs and outputs for this study site
site_path = base_path + '../study-sites/' + site_name
# path to AOI including the name of the shapefile
# (may contain wildcards; the first match found by glob is used when the AOI is loaded)
AOI_fn = site_path + '/glacier_outlines/' + site_name + '_USGS_*.shp'
# path to DEM including the name of the tif file (wildcards allowed, as for the AOI)
# Note: set DEM_fn=None if you want to use the ASTER GDEM on Google Earth Engine
DEM_fn = site_path + '/DEMs/' + site_name + '*_DEM_filled.tif'
# path for output images
out_path = site_path + '/imagery/'
# path for output figures
figures_out_path = site_path + '/figures/'

# -----Determine settings
# = True to plot figures of results for each image where applicable
plot_results = True
# = True to skip images where bands appear "clipped", i.e. max blue SR < 0.8
skip_clipped = False
# = True to crop images to AOI before calculating SCA
crop_to_AOI = True
# = True to save SCA images to file
save_outputs = True
# = True to save SCA output figures to file
save_figures = True

#######################

# -----Import packages
import xarray as xr
import rioxarray
import wxee as wx
import os
import numpy as np
import glob
from osgeo import gdal
import matplotlib.dates as mdates
from matplotlib.dates import DateFormatter
from matplotlib.patches import Rectangle
from matplotlib import pyplot as plt, dates
import rasterio as rio
import rasterio.features
from rasterio.mask import mask
from rasterio.plot import show
from shapely.geometry import Polygon, shape
import shapely.geometry
from scipy.interpolate import interp2d
from scipy import stats
import pandas as pd
import geopandas as gpd
import geemap
import math
import sys
import ee
import fiona
import pickle
import wxee as wx
import time

# -----Add path to functions
# make the project's helper module importable before importing it
functions_path = base_path + 'functions/'
sys.path.insert(1, functions_path)
import ps_pipeline_utils as f

# -----Load dataset dictionary
# pickled dictionary of per-dataset characteristics, keyed by dataset name
dataset_dict_fn = base_path + 'inputs-outputs/datasets_characteristics.pkl'
with open(dataset_dict_fn, 'rb') as dataset_dict_file:
    dataset_dict = pickle.load(dataset_dict_file)

#### Authenticate and initialize Google Earth Engine (GEE). 

__Note:__ The first time you run the following cell, you will be asked to authenticate your GEE account for use in this notebook. This will send you to an external web page, where you will walk through the GEE authentication workflow and copy an authentication code back into this notebook when prompted.

In [None]:
# Initialize GEE; if that fails (e.g., the first run, before the account has
# been authenticated), run the authentication workflow and initialize again.
try:
    ee.Initialize()
except Exception:
    # `except Exception` instead of a bare `except` so that KeyboardInterrupt
    # and SystemExit propagate rather than launching the authentication flow
    ee.Authenticate()
    ee.Initialize()

#### Load AOI and DEM

In [None]:
# -----Load AOI as gpd.GeoDataFrame
# AOI_fn may contain wildcards: resolve it and use the first match
AOI_fns = glob.glob(AOI_fn)
if not AOI_fns:
    raise FileNotFoundError('No AOI shapefile found matching: ' + AOI_fn)
AOI_fn = AOI_fns[0]
AOI = gpd.read_file(AOI_fn)
# reproject the AOI to WGS to solve for the optimal UTM zone
AOI_WGS = AOI.to_crs(4326)
# centroid of the first AOI feature as [x, y] in WGS coordinates
AOI_centroid = AOI_WGS.geometry[0].centroid
AOI_WGS_centroid = [AOI_centroid.xy[0][0], AOI_centroid.xy[1][0]]
epsg_UTM = f.convert_wgs_to_utm(AOI_WGS_centroid[0], AOI_WGS_centroid[1])

# -----Load DEM as Xarray DataSet
if DEM_fn is None:

    # query GEE for DEM
    DEM, AOI_UTM = f.query_GEE_for_DEM(AOI)

else:

    # reproject AOI to UTM (same 'EPSG:<code>' form as used for the DEM and the images)
    AOI_UTM = AOI.to_crs('EPSG:' + epsg_UTM)
    # load DEM as xarray DataSet; DEM_fn may contain wildcards, use the first match
    DEM_fns = glob.glob(DEM_fn)
    if not DEM_fns:
        raise FileNotFoundError('No DEM file found matching: ' + DEM_fn)
    DEM_fn = DEM_fns[0]
    # open using rasterio to access the transform
    # (left open: the transform is referenced by the commented-out AOI masking code in the classification cell)
    DEM_rio = rio.open(DEM_fn)
    DEM = xr.open_dataset(DEM_fn)
    DEM = DEM.rename({'band_data': 'elevation'})
    # reproject the DEM to the optimal UTM zone
    DEM = DEM.rio.reproject('EPSG:' + epsg_UTM)

### 1. Landsat


In [None]:
dataset = 'Landsat'

# -----Load images over the AOI
# reformat AOI for clipping images: closed ring (first vertex repeated at the end)
# around the bounding box of the first AOI feature, in the CRS of AOI_WGS
bb = AOI_WGS.geometry.bounds.loc[0]
bb_ring = [[bb.minx, bb.miny],
           [bb.maxx, bb.miny],
           [bb.maxx, bb.maxy],
           [bb.minx, bb.maxy],
           [bb.minx, bb.miny]]
AOI_WGS_bb_ee = ee.Geometry.Polygon([bb_ring])

# define search filters
# dates (YYYY-MM-DD) bounding the image search
date_range_start = '2017-01-01'
date_range_end = '2022-12-01'
# calendar months bounding the image search within each year
month_start = 5
month_end = 10
# only images with cloud cover less than this value are kept
cloud_cover_max = 100

In [None]:
def clip_image(im):
    """Clip an image to the buffered AOI bounding box.

    Parameters
    ----------
    im
        Object with a ``clip`` method; mapped over an ``ee.ImageCollection``
        below, so presumably an ``ee.Image``.

    Returns
    -------
    The result of ``im.clip`` applied to ``AOI_WGS_bb_ee`` (read from the
    notebook's global scope) buffered by a distance of 2000.
    """
    clip_region = AOI_WGS_bb_ee.buffer(2000)
    return im.clip(clip_region)

print('Landsat image collection:')
# define band names to keep
L_band_names = ['SR_B2', 'SR_B3', 'SR_B4', 'SR_B5', 'SR_B6', 'SR_B7', 'QA_PIXEL']
# define RGB bands
L_RGB_bands = ['SR_B4', 'SR_B3', 'SR_B2']
# Query GEE for imagery, applying the search filters one at a time:
# cloud cover, date range, months of the year, and overlap with the AOI bounding box
L = ee.ImageCollection("LANDSAT/LC08/C02/T1_L2")
L = L.filter(ee.Filter.lt("CLOUD_COVER", cloud_cover_max))
L = L.filterDate(ee.Date(date_range_start), ee.Date(date_range_end))
L = L.filter(ee.Filter.calendarRange(month_start, month_end, 'month'))
L = L.filterBounds(AOI_WGS_bb_ee)
# clip images to AOI and select bands
L_clip = L.map(clip_image)
L_clip = L_clip.select(L_band_names)
# convert image collection to xarray Dataset
L_xr = L_clip.wx.to_xarray(scale=30, crs='EPSG:4326')

In [None]:
# -----Prepare image collection for classification

# Reproject images to UTM using the rioxarray accessor
L_xr_UTM = L_xr.rio.reproject("EPSG:"+epsg_UTM)
# replace no data values with NaN
for band in L_band_names:
    L_xr_UTM[band] = L_xr_UTM[band].where(L_xr_UTM[band] != L_xr_UTM[band]._FillValue)
# account for image scalar
# (every band except the last one in L_band_names, the QA_PIXEL band)
for band in L_band_names[0:-1]:
    L_xr_UTM[band] = L_xr_UTM[band] * dataset_dict[dataset]['SR_scalar']

# Mask cloud-covered pixels
# honor the user's plot_results setting rather than always plotting
L_xr_UTM_mask = f.Landsat_mask_clouds(L_xr_UTM, L_band_names, plot_results=plot_results)


In [None]:
# -----Classify images
import matplotlib
# load image classifier and feature columns
# (context managers close the files once the objects are unpickled)
clf_fn = base_path+'inputs-outputs/L_classifier_all_sites.sav'
with open(clf_fn, 'rb') as clf_file:
    clf = pickle.load(clf_file)
feature_cols_fn = base_path+'inputs-outputs/L_feature_cols.pkl'
with open(feature_cols_fn, 'rb') as feature_cols_file:
    feature_cols = pickle.load(feature_cols_file)

# calculate NDSI as the normalized difference of the two bands listed
# for this dataset in dataset_dict
NDSI_bands = dataset_dict[dataset]['NDSI']
L_xr_UTM_mask['NDSI'] = ((L_xr_UTM_mask[NDSI_bands[0]] - L_xr_UTM_mask[NDSI_bands[1]])
                         / (L_xr_UTM_mask[NDSI_bands[0]] + L_xr_UTM_mask[NDSI_bands[1]]))
# Classify every image in the collection and plot, for each image, the RGB image,
# the classified image, and the NDSI >= 0.4 threshold, with the AOI outline on top.

# define colors for plotting (identical for every image, so defined once)
color_snow = '#4eb3d3'
color_ice = '#084081'
color_rock = '#fdbb84'
color_water = '#bdbdbd'
# create colormaps: classes 1 and 2 are both drawn with the snow color
colors = [color_snow, color_snow, color_ice, color_rock, color_water]
cmp = matplotlib.colors.ListedColormap(colors)
cmp_snow = matplotlib.colors.ListedColormap(['white', color_snow])

# AOI outline(s) in km. AOI_UTM is used (rather than AOI in its original CRS)
# so that the outline shares the UTM coordinates of the reprojected images.
AOI_geom = AOI_UTM.geometry[0]
if AOI_geom.geom_type == 'MultiPolygon': # loop through geoms if AOI = MultiPolygon
    AOI_polys = list(AOI_geom.geoms)
else:
    AOI_polys = [AOI_geom]
AOI_outlines_km = [([x/1e3 for x in poly.exterior.coords.xy[0]],
                    [y/1e3 for y in poly.exterior.coords.xy[1]])
                   for poly in AOI_polys]

# loop through images
for t in L_xr_UTM_mask.time:

    # subset image collection to time
    im = L_xr_UTM_mask.sel(time=t)
    # first 10 characters of the time stamp, used in the figure title and file name
    date_str = str(t.data)[0:10]
    # stack the image variables into one array once instead of once per use
    im_data = im.to_array().data

    # find indices of real numbers (no NaNs allowed in classification):
    # pixels that are not NaN in any of the first four stacked variables
    I_real = np.where(~np.isnan(im_data[0:4]).any(axis=0))

    # create df of image band values
    df = pd.DataFrame({col: np.ravel(im[col].data[I_real]) for col in feature_cols})

    # -----Classify image
    if len(df)>1:
        array_classified = clf.predict(df[feature_cols])
    else:
        print("No real values found to classify, skipping...")
        continue

    # reshape from flat array to original shape
    im_classified = np.full(np.shape(im_data[0]), np.nan)
    im_classified[I_real] = array_classified

    # -----Determine snow-covered elevations
    # mask the DEM using the AOI
#     mask = rio.features.geometry_mask(AOI_UTM.geometry,
#                                       out_shape=(len(DEM.y), len(DEM.x)),
#                                       transform=DEM_rio.transform,
#                                       invert=True)
#     mask = xr.DataArray(mask , dims=("y", "x"))
#     # mask DEM values outside the AOI
#     DEM_AOI = DEM.where(mask == True)
#     # interpolate DEM to the image coordinates
#     # im_classified = im_classified.squeeze(drop=True) # drop unnecessary dimensions
#     x, y = im.indexes.values() # grab indices of image
#     DEM_AOI_interp = DEM_AOI.interp(x=x, y=y, method="nearest") # interpolate DEM to image coordinates
#     # determine snow covered elevations
#     DEM_AOI_interp_snow = DEM_AOI_interp.where(im_classified<=2) # mask pixels not classified as snow
#     snow_est_elev = DEM_AOI_interp_snow.elevation.data.flatten() # create array of snow-covered pixel elevations
#     snow_est_elev = snow_est_elev[~np.isnan(snow_est_elev)] # remove NaN values

#     # -----Determine bins to use in histogram
#     elev_min = np.fix(np.nanmin(DEM_AOI_interp.elevation.data.flatten())/10)*10
#     elev_max = np.round(np.nanmax(DEM_AOI_interp.elevation.data.flatten())/10)*10
#     bin_edges = np.linspace(elev_min, elev_max, num=int((elev_max-elev_min)/10 + 1))
#     bin_centers = (bin_edges[1:] + bin_edges[0:-1]) / 2

#     # -----Calculate elevation histograms
#     H_DEM = np.histogram(DEM_AOI_interp.elevation.data.flatten(), bins=bin_edges)[0]
#     H_snow_est_elev = np.histogram(snow_est_elev, bins=bin_edges)[0]
#     H_snow_est_elev_norm = H_snow_est_elev / H_DEM

    # -----Plot
    fig, ax = plt.subplots(1, 3, figsize=(16,8))#, gridspec_kw={'height_ratios': [3, 1]})
    ax = ax.flatten()
    # define x and y limits
    xmin, xmax = np.min(im.x.data)/1e3, np.max(im.x.data)/1e3
    ymin, ymax = np.min(im.y.data)/1e3, np.max(im.y.data)/1e3
    extent = (xmin, xmax, ymin, ymax)
    # RGB image
    ax[0].imshow(np.dstack([im['SR_B4'].data, im['SR_B3'].data, im['SR_B2'].data]),
                 extent=extent)
    ax[0].set_xlabel("Easting [km]")
    ax[0].set_ylabel("Northing [km]")
    ax[0].set_title('RGB image')
    # classified image
    ax[1].imshow(im_classified, cmap=cmp, vmin=1, vmax=5, extent=extent)
    # plot dummy points for legend
    ax[1].scatter(0, 0, color=color_snow, s=50, label='snow')
    ax[1].scatter(0, 0, color=color_ice, s=50, label='ice')
    ax[1].scatter(0, 0, color=color_rock, s=50, label='rock')
    ax[1].scatter(0, 0, color=color_water, s=50, label='water')
    ax[1].set_title('Classified image')
    ax[1].set_xlabel('Easting [km]')
    ax[1].legend(loc='best')
    # NDSI threshold
    im_ndsi_threshold = np.where(im['NDSI'].data >= 0.4, 1, 0)
    im_ndsi = ax[2].imshow(im_ndsi_threshold, cmap=cmp_snow, clim=(0,1), extent=extent)
    ax[2].set_xlabel("Easting [km]")
    ax[2].set_title('NDSI')
    # fig.colorbar(im_ndsi, ax=ax[2], shrink=0.5)
    # AOI outline on every panel; only the first panel's line carries a legend label
    for x_km, y_km in AOI_outlines_km:
        for j, axis in enumerate(ax):
            axis.plot(x_km, y_km, '-k', linewidth=1, label='AOI' if j == 0 else '_nolegend_')
    # reset x and y limits (the dummy legend points at (0, 0) would otherwise stretch the axes)
    for axis in ax:
        axis.set_xlim(xmin, xmax)
        axis.set_ylim(ymin, ymax)
    # image bands histogram
    # h_b = ax[2].hist(im['SR_B2'].data.flatten(), color='blue', histtype='step', linewidth=2, bins=100, label="blue")
    # h_g = ax[2].hist(im['SR_B3'].data.flatten(), color='green', histtype='step', linewidth=2, bins=100, label="green")
    # h_r = ax[2].hist(im['SR_B4'].data.flatten(), color='red', histtype='step', linewidth=2, bins=100, label="red")
    # h_nir = ax[2].hist(im['SR_B5'].data.flatten(), color='brown', histtype='step', linewidth=2, bins=100, label="NIR")
    # ax[2].set_xlabel("Surface reflectance")
    # ax[2].set_ylabel("Pixel counts")
    # ax[2].legend(loc='best')
    # ax[2].grid()
    # # normalized snow elevations histogram
    # ax[3].bar(bin_centers, H_snow_est_elev_norm, width=(bin_centers[1]-bin_centers[0]), color=color_snow, align='center')
    # ax[3].set_xlabel("Elevation [m]")
    # ax[3].set_ylabel("% snow-covered")
    # ax[3].grid()
    # ax[3].set_xlim(elev_min-10, elev_max+10)
    # ax[3].set_ylim(0,1)
    fig.suptitle(date_str)
    fig.tight_layout()
    plt.show()

    # save figure (only when requested in the settings)
    if save_figures:
        fig_fn = figures_out_path + 'L_'+date_str+'_SCA.png'
        fig.savefig(fig_fn, dpi=300, facecolor='w')
        print('figure saved to file: ' + fig_fn)


In [None]:
### Test NDSI method
# For each image, flag pixels with NDSI >= 0.4 and red band (SR_B4) >= 0.9, and plot
# the RGB image, the combined result, the NDSI, and the red band with their thresholds.
for t in L_xr_UTM_mask.time:

    # subset image collection to time
    im = L_xr_UTM_mask.sel(time=t)

    # threshold NDSI and red bands (1 where the threshold is met, NaN elsewhere)
    ndsi_thresh = np.where(im['NDSI'].data >=0.4, 1, np.nan)
    red_thresh = np.where(im['SR_B4'].data >=0.9, 1, np.nan)
    # pixels that pass both thresholds
    im_classified = np.where((ndsi_thresh==1) & (red_thresh==1), 1, np.nan)

    # plot RGB, NDSI threshold, and red band threshold
    fig, ax = plt.subplots(3, 2, figsize=(12,16))
    ax = ax.flatten()
    # fig.delaxes(ax[3])
    # define x and y limits
    xmin, xmax = np.min(im.x.data)/1e3, np.max(im.x.data)/1e3
    ymin, ymax = np.min(im.y.data)/1e3, np.max(im.y.data)/1e3
    extent = (xmin, xmax, ymin, ymax)
    # RGB image
    ax[0].imshow(np.dstack([im['SR_B4'].data, im['SR_B3'].data, im['SR_B2'].data]),
                 extent=extent)
    ax[0].set_xlabel("Easting [km]")
    ax[0].set_ylabel("Northing [km]")
    ax[0].set_title('RGB image')
    # classified image
    ax[1].imshow(im_classified, cmap='Blues', clim=(0,1), extent=extent)
    ax[1].set_title('Classified image')
    # NDSI
    im_ndsi = ax[2].imshow(im['NDSI'].data, cmap='RdBu', clim=(-1,1), extent=extent)
    ax[2].set_title('NDSI')
    fig.colorbar(im_ndsi, ax=ax[2], shrink=0.5)
    # NDSI threshold
    im_ndsi_thresh = ax[3].imshow(ndsi_thresh, cmap='Blues', clim=(0,1), extent=extent)
    ax[3].set_title('NDSI threshold')
    fig.colorbar(im_ndsi_thresh, ax=ax[3], shrink=0.5)
    # red band
    im_red = ax[4].imshow(im['SR_B4'].data, cmap='Reds', clim=(0,1), extent=extent)
    ax[4].set_xlabel("Easting [km]")
    ax[4].set_ylabel("Northing [km]")
    ax[4].set_title('Red band')
    fig.colorbar(im_red, ax=ax[4], shrink=0.5)
    # red band threshold
    im_red_thresh = ax[5].imshow(red_thresh, cmap='Reds', clim=(0,1), extent=extent)
    ax[5].set_title('Red threshold')
    fig.colorbar(im_red_thresh, ax=ax[5], shrink=0.5)

    fig.suptitle(str(t.data)[0:10])
    fig.tight_layout()
    plt.show()

## 2. Sentinel

In [None]:
# # -----Query GEE for Sentinel-2 and MODIS imagery
# S_col = (ee.ImageCollection("COPERNICUS/S2_SR")
#          .filter(ee.Filter.lt('CLOUDY_PIXEL_PERCENTAGE', cloud_cover_max))
#          .filterDate(ee.Date(date_range_start), ee.Date(date_range_end))
#          .filter(ee.Filter.calendarRange(month_start, month_end, 'month'))
#          .filterBounds(AOI_WGS_bb_ee))
# M_col = (ee.ImageCollection("MODIS/061/MOD09GA").merge(ee.ImageCollection("MODIS/061/MYD09GA"))
#          .filterDate(ee.Date(date_range_start), ee.Date(date_range_end))
#          .filter(ee.Filter.calendarRange(month_start, month_end, 'month'))
#          .filterBounds(AOI_WGS_bb_ee))

# # -----Clip images to AOI and select bands
# S_band_names = ['B2', 'B3', 'B4', 'B5', 'B6', 'B8', 'B11', 'B12', 'QA60']
# M_band_names = ['sur_refl_'+x for x in ['b01', 'b04', 'b03', 'b05', 'b06', 'b07', 'state_1km']]

# S_col_clip = S_col.map(clip_image).select(S_band_names)
# M_col_clip = M_col.map(clip_image).select(M_band_names)

# # -----Convert image collections to xarray Datasets
# print('Sentinel-2 image collection:')
# S_col_xr = S_col_clip.wx.to_xarray(scale=20, crs='EPSG:4326')
# print('MODIS image collection:')
# M_col_xr = M_col_clip.wx.to_xarray(scale=500, crs='EPSG:4326')

### 2. Prepare image collections for classification: reproject to UTM and apply quality masks

In [None]:
### Classify images
# Classify each Landsat image and show the first stacked band next to the classified image.

# -----Load image classifier and feature columns
# (context managers close the files once the objects are unpickled)
clf_fn = base_path+'inputs-outputs/L_classifier_all_sites.sav'
with open(clf_fn, 'rb') as clf_file:
    clf = pickle.load(clf_file)
feature_cols_fn = base_path+'inputs-outputs/L_feature_cols.pkl'
with open(feature_cols_fn, 'rb') as feature_cols_file:
    feature_cols = pickle.load(feature_cols_file)

# calculate NDSI from SR_B3 and SR_B6: (G-SWIR)/(G+SWIR)
L_xr_UTM_mask['NDSI'] = (L_xr_UTM_mask['SR_B3'] - L_xr_UTM_mask['SR_B6']) / (L_xr_UTM_mask['SR_B3'] + L_xr_UTM_mask['SR_B6'])
# loop through images
for t in L_xr_UTM_mask.time:

    # subset image collection to time
    im = L_xr_UTM_mask.sel(time=t)
    # first variable of the stacked image, computed once and reused below
    first_band = im.to_array().data[0]

    # find indices of real numbers (no NaNs allowed in classification)
    I_real = np.where(~np.isnan(first_band))

    # create df of image band values
    df = pd.DataFrame({col: np.ravel(im[col].data[I_real]) for col in feature_cols})

    # -----Classify image
    if len(df)>1:
        array_classified = clf.predict(df[feature_cols])
    else:
        print("No real values found to classify, skipping...")
        # skip this image: without this, the code below would reuse the previous
        # image's array_classified (or fail if none exists yet)
        continue

    # reshape from flat array to original shape
    im_classified = np.full(np.shape(first_band), np.nan)
    im_classified[I_real] = array_classified

    fig, ax = plt.subplots(1, 2, figsize=(14, 6))
    ax[0].imshow(first_band)
    ax[1].imshow(im_classified)
    plt.show()

In [None]:
# -----Reproject image collections to UTM
# Reproject using rasterio.reproject
# S_col_xr_UTM = S_col_xr.rio.reproject("EPSG:"+epsg_code)
# M_col_xr_UTM = M_col_xr.rio.reproject("EPSG:"+epsg_code)

In [None]:
# -----Apply image quality masking to Landsat and Sentinel imagery
# Landsat
# Pixel quality band = "QA_PIXEL"
# Bit 3 = cloud shadow, Bit 5 = cloud
# def L8_QA_mask(im):
#     cloudShadowBitMask = 1 << 3
#     cloudsBitMask = 1 << 5;
#     # Get the pixel QA band.
#     qa = im.select('QA_PIXEL')
#     # Both flags should be set to zero, indicating clear conditions.
#     mask = (qa.bitwiseAnd(cloudShadowBitMask).eq(0) & (qa.bitwiseAnd(cloudsBitMask).eq(0)))
#     # Return the masked image without the QA bands.
#     return (
#         image.updateMask(mask)
#         .select("SR_B[0-9]*")
#         .copyProperties(image, ["system:time_start"])
#     )

# L_col_clip_mask = L_col_clip.map(L8_QA_mask)


#   // Sentinel-2 image quality masking
#   function S2QAMask(image){
#     var QA60 = image.select(['QA60']);
#     return image.updateMask(QA60.lt(1));
#   }

### 3. Classify SCA

In [None]:
# start timer
# t1 = time.time()

# -----Load image classifiers and feature columns
# Each file is opened in a context manager so that it is closed after unpickling.
# Landsat
L_clf_fn = base_path+'inputs-outputs/L_classifier_all_sites.sav'
with open(L_clf_fn, 'rb') as pkl_file:
    L_clf = pickle.load(pkl_file)
L_feature_cols_fn = base_path+'inputs-outputs/L_feature_cols.pkl'
with open(L_feature_cols_fn, 'rb') as pkl_file:
    L_feature_cols = pickle.load(pkl_file)
# Sentinel-2
S_clf_fn = base_path+'inputs-outputs/S2_classifier_all_sites.sav'
with open(S_clf_fn, 'rb') as pkl_file:
    S_clf = pickle.load(pkl_file)
S_feature_cols_fn = base_path+'inputs-outputs/S2_feature_cols.pkl'
with open(S_feature_cols_fn, 'rb') as pkl_file:
    S_feature_cols = pickle.load(pkl_file)
# MODIS
M_clf_fn = base_path+'inputs-outputs/M_classifier_all_sites.sav'
with open(M_clf_fn, 'rb') as pkl_file:
    M_clf = pickle.load(pkl_file)
M_feature_cols_fn = base_path+'inputs-outputs/M_feature_cols.pkl'
with open(M_feature_cols_fn, 'rb') as pkl_file:
    M_feature_cols = pickle.load(pkl_file)

In [None]:
# -----Classify snow
# Landsat: classify every image pixel-by-pixel and plot the RGB image next to the classes.
# NOTE(review): `L_col_xr` is not defined anywhere in the visible notebook; it looks like a
# stale name for the Landsat dataset (possibly `L_xr_UTM_mask`) -- TODO confirm before running.
import matplotlib.colors

plt.rcParams.update({'font.size': 14, 'font.sans-serif': 'Arial'})
# define colors for plotting
color_snow = '#4eb3d3'
color_ice = '#084081'
color_rock = '#fdbb84'
color_water = '#bdbdbd'
# (class value, color, legend label); classes 1 and 2 are both drawn with the snow
# color, and class 2 gets no legend entry of its own
class_styles = [(1, color_snow, 'snow'),
                (2, color_snow, None),
                (3, color_ice, 'ice'),
                (4, color_rock, 'rock'),
                (5, color_water, 'water')]

for date in L_col_xr.time.data:
    # subset image collection to time
    im = L_col_xr.sel(time=date)
    # extract pixel values: one row per pixel (flattened in row-major order), one column per band
    L_df = pd.DataFrame({band: im[band].data.flatten() for band in L_band_names})
    # find indices of rows without NaN
    # (checked before NDSI is added: an empty NDSI column would mark every row as NaN)
    valid = L_df.notna().all(axis=1).to_numpy()
    I_real = np.flatnonzero(valid)
    # drop rows with NaN
    L_df = L_df[valid].reset_index(drop=True)
    if L_df.empty:
        print("No real values found to classify, skipping...")
        continue
    # calculate NDSI (G-SWIR)/(G+SWIR)
    L_df['NDSI'] = (L_df['SR_B3'] -  L_df['SR_B6']) / (L_df['SR_B3'] +  L_df['SR_B6'])

    # classify SCA
    array_classified = L_clf.predict(L_df[L_feature_cols])

    # reshape from flat array to original shape: fill the flat (per-pixel) array first,
    # because I_real holds flat pixel indices, then restore the 2D image shape
    im_classified = np.full(valid.shape, np.nan)
    im_classified[I_real] = array_classified
    im_classified = im_classified.reshape(np.shape(im['SR_B3'].data))

    # plot classified image
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12,6))
    # define x and y limits
    xmin, xmax = np.min(im.x.data)/1000, np.max(im.x.data)/1000
    ymin, ymax = np.min(im.y.data)/1000, np.max(im.y.data)/1000
    extent = (xmin, xmax, ymin, ymax)
    # RGB image
    ax1.imshow(np.dstack([im['SR_B4'].data, im['SR_B3'].data, im['SR_B2'].data]),
               extent=extent)
    ax1.set_xlabel("Easting [km]")
    ax1.set_ylabel("Northing [km]")
    # classified image: one single-color layer per class present in the image
    has_legend_entries = False
    for class_value, color, label in class_styles:
        if not np.any(im_classified == class_value):
            continue
        # 1 where the pixel belongs to this class, NaN (not drawn) elsewhere
        ax2.imshow(np.where(im_classified == class_value, 1, np.nan),
                   cmap=matplotlib.colors.ListedColormap([color]), vmin=0, vmax=1,
                   extent=extent)
        if label is not None:
            ax2.scatter(0, 0, color=color, s=50, label=label) # plot dummy point for legend
            has_legend_entries = True
    # draw the legend that the dummy points were created for
    if has_legend_entries:
        ax2.legend(loc='best')
    # reset x and y limits (the dummy points at (0, 0) would otherwise stretch the axes)
    ax2.set_xlim(xmin, xmax)
    ax2.set_ylim(ymin, ymax)
    plt.show()

In [None]:
# Inspection cell: rebuild the table of pixel values for a single image and display it.
# NOTE(review): `date` is not assigned in this cell (the only visible assignment is the loop
# variable of the Landsat classification loop), and `L_col_xr` is not defined anywhere in
# the visible notebook -- TODO confirm both before running.
# create data frame to store pixel values
L_df = pd.DataFrame(columns=L_band_names+['NDSI'])
# extract pixel values
for band in L_band_names:
    L_df[band] = L_col_xr.sel(time=date)[band].data.flatten()
# the 'NDSI' column is declared above but is not filled in this cell
L_df