# Notebook to develop supervised classification algorithms for identifying snow in PlanetScope 4-band, Landsat 8/9, Sentinel-2, and MODIS imagery
Rainey Aberle

Adapted from the [SciKit Learn Classifier comparison tutorial](https://scikit-learn.org/stable/auto_examples/classification/plot_classifier_comparison.html)

## Outline:
1. Set up training data
- PlanetScope
- Landsat 8/9
- Sentinel-2 SR
- Sentinel-2 TOA
    
2. Develop supervised classifiers for EACH site and ALL sites

3. *Optional*
- Test how the number of training points and samples impact model accuracies (calculate learning curves)
- MODIS (not recommended)
- Sentinel-1 (not recommended)

## 0. Initial Setup

Import packages, define paths in directory, authenticate Google Earth Engine (GEE), define classification settings.

In [None]:
# -----Import packages
import os
import glob
import numpy as np
import rasterio as rio
import geopandas as gpd
import pandas as pd
import ee
import scipy
import wxee as wx
import xarray as xr
import rioxarray as rxr
from shapely.geometry import Polygon
from sklearn.tree import DecisionTreeClassifier 
from sklearn.model_selection import train_test_split
from sklearn.svm import SVC
from sklearn.neural_network import MLPClassifier
from sklearn.neighbors import KNeighborsClassifier
from sklearn.tree import DecisionTreeClassifier
from sklearn.ensemble import RandomForestClassifier, AdaBoostClassifier
from sklearn.naive_bayes import GaussianNB
from sklearn.discriminant_analysis import QuadraticDiscriminantAnalysis
from sklearn.linear_model import LogisticRegression
from sklearn.cluster import KMeans
from sklearn import metrics 
from sklearn.model_selection import KFold
import matplotlib.pyplot as plt
import matplotlib
import sys
import pickle
import seaborn as sns
from shapely.geometry import Point, MultiPoint
from joblib import dump, load
import json
import datetime
from shapely import wkt

In [None]:
# -----Determine whether to save outputs to file
save_outputs = False # = True to save output figures and best classifier 

# -----Define paths in directory
# base directory (path to the repository root)
# NOTE(review): the original comment named "glacier-snow-cover-mapping/" while the path below
# ends in "snow-cover-mapping/" -- confirm which one is current.
base_path = '/Users/raineyaberle/Research/PhD/snow_cover_mapping/snow-cover-mapping/'
# output folder for best classifier
# The trailing '' makes os.path.join append a path separator, so that both
# `out_path + file_name` and `os.path.join(out_path, file_name)` (the notebook uses both forms)
# resolve to a file INSIDE the folder rather than to "inputs-outputs<file_name>".
out_path = os.path.join(base_path, 'inputs-outputs', '')
# output folder for figures (trailing separator for the same reason as out_path)
figures_out_path = os.path.join(base_path, 'figures', '')
# path to classified points used to train and test classifiers
# NOTE(review): unlike base_path, this path has no 'PhD/' component -- confirm it is intended.
data_pts_path = '/Users/raineyaberle/Research/snow_cover_mapping/classified-points/'

# -----Determine settings
terrain_parameters = False # whether to use terrain parameters (elevation, slope, aspect) in classification
save_figures = True # whether to save output figures

# -----Add path to functions
sys.path.insert(1, os.path.join(base_path, 'functions'))
import pipeline_utils as f

### Create dictionary of dataset-specific parameters

In [None]:
def _sentinel2_parameters():
    """Return a fresh parameter dictionary for a Sentinel-2 product.

    The surface reflectance (SR) and top-of-atmosphere (TOA) products use
    identical settings here; a new dictionary is built on every call so the
    two entries of ``dataset_dict`` never share mutable state.
    """
    band_ids = ['B1', 'B2', 'B3', 'B4', 'B5', 'B6', 'B7', 'B8', 'B8A', 'B9', 'B11', 'B12']
    band_descriptions = ['Aerosols', 'Blue', 'Green', 'Red', 'Red Edge 1', 'Red Edge 2',
                         'Red Edge 3', 'NIR', 'Red Edge 4', 'Water vapor', 'SWIR1', 'SWIR2']
    return {
        'image_scalar': 1e4,
        'refl_bands': dict(zip(band_ids, band_descriptions)),
        'no_data_value': 0,
        'resolution_m': 10,
        'NDSI_bands': ['B3', 'B11'],
        'RGB_bands': ['B4', 'B3', 'B2'],
    }


# Per-dataset parameters. For the image datasets:
#   image_scalar   -- divisor applied to raw pixel values when images are loaded
#   refl_bands     -- ordered mapping of band keys; the key order defines the band order
#   no_data_value  -- pixel value treated as missing
#   resolution_m   -- resolution in metres (per the key name)
#   NDSI_bands     -- the two band keys combined as (first - second) / (first + second)
#   RGB_bands      -- band keys used for RGB plotting
dataset_dict = {}

# PlanetScope: keys are colour names, values are the band positions written as strings
dataset_dict['PlanetScope'] = {
    'image_scalar': 1e4,
    'refl_bands': {colour: str(position)
                   for position, colour in enumerate(['Blue', 'Green', 'Red', 'NIR'])},
    'no_data_value': -9999,
    'resolution_m': 3,
    'NDSI_bands': ['Green', 'NIR'],
    'RGB_bands': ['Red', 'Green', 'Blue'],
}

# Landsat: keys are band IDs, values are band descriptions
# NOTE(review): the last band (ST_B10) is described as 'Bit mask' -- confirm the description.
dataset_dict['Landsat'] = {
    'image_scalar': 1/2.75e-05,
    'refl_bands': {
        'SR_B1': 'Ultra blue',
        'SR_B2': 'Blue',
        'SR_B3': 'Green',
        'SR_B4': 'Red',
        'SR_B5': 'NIR',
        'SR_B6': 'SWIR1',
        'SR_B7': 'SWIR2',
        'ST_B10': 'Bit mask',
    },
    'no_data_value': 0,
    'resolution_m': 30,
    'NDSI_bands': ['SR_B3', 'SR_B6'],
    'RGB_bands': ['SR_B4', 'SR_B3', 'SR_B2'],
}

# Sentinel-2 surface reflectance and top-of-atmosphere products
dataset_dict['Sentinel-2_SR'] = _sentinel2_parameters()
dataset_dict['Sentinel-2_TOA'] = _sentinel2_parameters()

# Classified images: band descriptions, class IDs (as strings) and plotting colours
# NOTE: 'classification_values' spells the class 'Shadowed snow' while 'class_colors'
# uses the key 'Shadowed_snow'.
dataset_dict['classified_image'] = {
    'image_scalar': 0,
    'bands': {
        'classified': 'classified image',
        'elevation': 'interpolated elevation',
    },
    'no_data_value': -9999,
    'classification_values': {
        '1': 'Snow',
        '2': 'Shadowed snow',
        '3': 'Ice',
        '4': 'Rock',
        '5': 'Water',
    },
    'class_colors': {
        'Snow': '#4eb3d3',
        'Shadowed_snow': '#6a51a3',
        'Ice': '#084081',
        'Rock': '#fe9929',
        'Water': '#252525',
    },
}

# save dictionary as json file
# dataset_dict_fn = 'datasets_characteristics.json'
# json.dump(dataset_dict, open(os.path.join(out_path, dataset_dict_fn), 'w'))
# print('dictionary saved to file: ' + os.path.join(out_path, dataset_dict_fn))
# # open dictionary and display
# dataset_dict = json.load(open(os.path.join(out_path, dataset_dict_fn)))
# dataset_dict

## 1. Construct or load training data

In [None]:
# Build (or load from cache) one training table per dataset.
#
# For every dataset the cell either loads `<dataset>_training_data.csv` and
# `<dataset>_feature_columns.json` from out_path, or constructs them by sampling image band
# values at the manually classified points of every site and image date, plotting the points
# on each image, and writing the table, the feature-column list and a spectral pairplot to file.
# Results are exposed to later cells as data_pts_full_<suffix> / feature_cols_<suffix>.

# -----Define datasets
datasets = ['PlanetScope', 'Landsat', 'Sentinel-2_SR', 'Sentinel-2_TOA']

# -----Define site names
site_names = ['Gulkana', 'SouthCascade', 'Sperry', 'Wolverine']

# -----Grab colormap for classified image
colors = list(dataset_dict['classified_image']['class_colors'].values())

# -----Map classified-point shapefile suffixes to (class ID, plot label)
# The order matches the order of `colors`.
class_files = {
    'snow.shp': (1, 'Snow'),
    'snow-shadowed.shp': (2, 'Shadowed snow'),
    'ice.shp': (3, 'Ice'),
    'rock.shp': (4, 'Rock'),
    'water.shp': (5, 'Water'),
}
class_labels = {class_id: label for class_id, label in class_files.values()}

# -----Loop through datasets
for dataset in datasets:

    print('----------')
    print(dataset)
    print('----------')

    # Check if training data exist in file
    training_data_fn = dataset + '_training_data.csv'
    feature_cols_fn = dataset + '_feature_columns.json'
    # os.path.join is used for every path so the result does not depend on whether
    # out_path ends with a separator
    training_data_path = os.path.join(out_path, training_data_fn)
    feature_cols_path = os.path.join(out_path, feature_cols_fn)
    if os.path.exists(training_data_path) and os.path.exists(feature_cols_path):

        data_pts_full = pd.read_csv(training_data_path)
        with open(feature_cols_path) as feature_cols_file:
            feature_cols = json.load(feature_cols_file)
        print('Training data and feature columns already exist... loaded from file.')

    else:

        # Define band names feature columns used to classify
        band_names = list(dataset_dict[dataset]['refl_bands'].keys())
        if dataset == 'PlanetScope':
            feature_cols = band_names + ['NDSI']
        else:
            # NOTE(review): the last band is always excluded from the features. For Landsat that
            # is ST_B10; for Sentinel-2 it is B12 (SWIR2) -- confirm that dropping B12 is intended.
            feature_cols = band_names[0:-1] + ['NDSI']
        # Bands sampled at the data points (QA bands excluded). Kept separate from band_names,
        # which must stay complete because it is used to rename the image bands by position.
        sample_bands = [band_name for band_name in band_names if 'QA' not in band_name]

        # Load data points file names
        if 'Sentinel-2' in dataset:
            data_pts_path_dataset = os.path.join(data_pts_path, 'Sentinel-2')
        else:
            data_pts_path_dataset = os.path.join(data_pts_path, dataset)
        # NOTE: this changes the working directory for the rest of the session
        os.chdir(data_pts_path_dataset)
        data_pts_fns = sorted(glob.glob('*.shp'))

        # Load image file names
        im_fns = sorted(glob.glob('*.tif'))
        if dataset == 'Sentinel-2_SR':
            im_fns = [im_fn for im_fn in im_fns if 'SR' in im_fn]
        elif dataset == 'Sentinel-2_TOA':
            im_fns = [im_fn for im_fn in im_fns if 'TOA' in im_fn]

        # Collect the data points of every site and image date (concatenated once at the end)
        data_pts_frames = []

        # Loop through sites
        for site_name in site_names:

            # Grab image file names, dates, and data point file names
            im_fns_site = [im_fn for im_fn in im_fns if site_name in im_fn]
            print('image file names: ', im_fns_site)
            if not im_fns_site:
                continue
            if dataset == 'Landsat':
                im_dates = [x[-12:-4] for x in im_fns_site]
            else:
                im_dates = [x.split(site_name + '_')[1][0:8] for x in im_fns_site]
            data_pts_site_fns = [data_pts_fn for data_pts_fn in data_pts_fns if site_name in data_pts_fn]

            # Load AOI (depends only on the site, so it is loaded once per site)
            AOI_path = base_path + '../study-sites/' + site_name + '/AOIs/' + site_name + '_USGS_*.shp'
            AOI_fn = glob.glob(AOI_path)[0]
            AOI = gpd.read_file(AOI_fn)
            # reproject AOI to WGS 84 for compatibility with images
            AOI_WGS = AOI.to_crs('EPSG:4326')
            # Determine optimal UTM zone EPSG code from the CENTRE of the first AOI feature
            # (the previous expression (max - min) + min evaluated to the max corner)
            minx, miny, maxx, maxy = AOI_WGS.geometry.bounds.iloc[0]
            epsg_UTM = f.convert_wgs_to_utm((maxx - minx) / 2 + minx,
                                            (maxy - miny) / 2 + miny)
            # str() so the CRS string is valid whether the helper returns a str or an int
            utm_crs = 'EPSG:' + str(epsg_UTM)
            AOI_UTM = AOI.to_crs(utm_crs)

            # loop through images and their dates (paired, so each date uses this site's image)
            for im_fn_date, im_date in zip(im_fns_site, im_dates):

                # compile classified points for every class that exists for this site and date
                class_frames = []
                for suffix, (class_id, _) in class_files.items():
                    matching_fns = [s for s in data_pts_site_fns if (suffix in s) and (im_date in s)]
                    if not matching_fns:
                        continue
                    class_pts = gpd.read_file(os.path.join(data_pts_path_dataset, matching_fns[0])) # read file
                    class_pts['Class'] = class_id # determine class ID
                    class_frames.append(class_pts)
                    print(matching_fns[0])
                if not class_frames:
                    # nothing to sample for this image; avoid failing on an empty table
                    print('No classified points found for ' + site_name + ' ' + im_date + '... skipping.')
                    continue
                data_pts = pd.concat(class_frames, ignore_index=True)
                # remove 'id' column (if present)
                data_pts = data_pts.drop(columns='id', errors='ignore')
                # remove rows with missing values (including empty geometries)
                data_pts = data_pts.dropna()
                # reformat MultiPoint objects as Point objects (Points are kept as they are)
                data_pts['geometry'] = [geom.geoms[0] if hasattr(geom, 'geoms') else geom
                                        for geom in data_pts['geometry']]

                # Load image
                im_path = os.path.join(data_pts_path_dataset, im_fn_date)
                # Adjust image radiometry for PlanetScope images
                if dataset == 'PlanetScope':
                    # load image
                    im_ds = xr.open_dataset(im_path)
                    # load DEM
                    DEM_path = base_path + '../study-sites/' + site_name + '/DEMs/'
                    DEM_fn = glob.glob(DEM_path + site_name + '*_DEM*.tif')[0]
                    DEM = xr.open_dataset(DEM_fn)
                    DEM = DEM.rename({'band_data': 'elevation'})
                    # remove unnecessary data (possible extra bands from ArcticDEM or other DEM)
                    if len(np.shape(DEM.elevation.data)) > 2:
                        DEM['elevation'] = DEM.elevation[0]
                    DEM.rio.write_crs('EPSG:' + str(rxr.open_rasterio(DEM_fn).rio.crs.to_epsg()), inplace=True)
                    DEM = DEM.rio.reproject(utm_crs)
                    DEM.rio.write_crs(utm_crs, inplace=True)
                    # create polygons of top and bottom 20th percentile elevations
                    polygon_top, polygon_bottom = f.create_aoi_elev_polys(AOI_UTM, DEM)
                    # adjust image radiometry
                    im_ds = f.planetscope_adjust_image_radiometry(im_ds, np.datetime64(im_date[0:4] + '-' + im_date[4:6] + '-' + im_date[6:8]),
                                                                  polygon_top, polygon_bottom, dataset_dict, skip_clipped=False)[0]
                else:
                    # read in xarray.DataArray
                    im_da = rxr.open_rasterio(im_path)
                    # reproject to optimal UTM zone (if necessary)
                    im_da = im_da.rio.reproject(utm_crs)
                    # convert to xarray.DataSet
                    im_ds = im_da.to_dataset('band')
                    # account for image scalar and no data values
                    im_ds = xr.where(im_ds != dataset_dict[dataset]['no_data_value'],
                                     im_ds / dataset_dict[dataset]['image_scalar'], np.nan)
                    # rename bands (band numbers start at 1)
                    im_ds = im_ds.rename({band_number + 1: name for band_number, name in enumerate(band_names)})
                    # set CRS
                    im_ds.rio.write_crs('EPSG:' + str(im_da.rio.crs.to_epsg()), inplace=True)

                # reproject data points to optimal UTM zone
                data_pts = data_pts.to_crs(utm_crs)

                # sample band values at data points (nearest pixel)
                for band_name in sample_bands:
                    data_pts[band_name] = [im_ds.sel(x=pt.x, y=pt.y, method="nearest")[band_name].data
                                           for pt in data_pts['geometry'].values]

                # plot images and data points (coordinates converted from m to km)
                fig1, ax1 = f.plot_xr_rgb_image(im_ds, dataset_dict[dataset]['RGB_bands'])
                for (class_id, label), color in zip(class_files.values(), colors):
                    class_geoms = data_pts['geometry'].loc[data_pts['Class'] == class_id].values
                    ax1.scatter([pt.x / 1e3 for pt in class_geoms],
                                [pt.y / 1e3 for pt in class_geoms], c=color, s=1, label=label)
                ax1.set_xlabel('Easting [km]')
                ax1.set_ylabel('Northing [km]')
                ax1.legend(loc='best')
                ax1.set_title(im_date)
                plt.show()

                # add data points to full data points data frame
                ndsi_band_1, ndsi_band_2 = dataset_dict[dataset]['NDSI_bands'][0:2]
                data_pts['NDSI'] = ((data_pts[ndsi_band_1] - data_pts[ndsi_band_2])
                                    / (data_pts[ndsi_band_1] + data_pts[ndsi_band_2])) # add NDSI column
                data_pts['image_date'] = im_date # add image date column
                data_pts['site_name'] = site_name # add site name column
                data_pts = data_pts.to_crs('EPSG:4326') # reproject back to WGS84 for compatibility
                data_pts_frames.append(data_pts)

        # Combine all data points into one dataframe with a fresh index
        if not data_pts_frames:
            raise RuntimeError('No training data could be constructed for ' + dataset
                               + ' from ' + data_pts_path_dataset)
        data_pts_full = pd.concat(data_pts_frames, ignore_index=True)
        data_pts_full['Class'] = data_pts_full['Class'].astype(int)

        # Plot spectral pairplot for training data
        # (sort_values returns a copy, so the labels below do not alter data_pts_full)
        df = data_pts_full.sort_values(by='Class')
        # Assign labels to each class
        df['Class'] = df['Class'].map(class_labels)
        df[feature_cols] = df[feature_cols].astype(float)
        # plot
        fig = sns.pairplot(df, vars=feature_cols, corner=True, diag_kind='kde', hue='Class',
                           palette=colors)
        plt.show()

        # Save figure to file
        if save_figures:
            fig_fn = os.path.join(figures_out_path, 'spectral_pairplot_' + dataset + '_training_data.png')
            fig.savefig(fig_fn, facecolor='w', dpi=300)
            print('Spectral pairplot figure saved to file: ' + fig_fn)

        # Save dataframe to file (always written: it is the cache checked at the top of the loop)
        data_pts_full = f.reduce_memory_usage(data_pts_full) # Reduce memory usage in df
        data_pts_full.to_csv(training_data_path, index=False)
        print(dataset + ' training data saved to file:' + training_data_path)

        # Save feature columns
        with open(feature_cols_path, "w") as feature_cols_file:
            json.dump(feature_cols, feature_cols_file)
        print('Feature columns saved to file: ', feature_cols_path)

        # Reload from file so the table has the same form (plain DataFrame, geometry stored as
        # WKT strings) whether it was just built or loaded from the cache
        data_pts_full = pd.read_csv(training_data_path)

    # Rename data points dataframe according to dataset
    if dataset == 'PlanetScope':
        data_pts_full_PS = data_pts_full
        feature_cols_PS = feature_cols
    elif dataset == 'Landsat':
        data_pts_full_L = data_pts_full
        feature_cols_L = feature_cols
    elif dataset == 'Sentinel-2_SR':
        data_pts_full_S2_SR = data_pts_full
        feature_cols_S2_SR = feature_cols
    elif dataset == 'Sentinel-2_TOA':
        data_pts_full_S2_TOA = data_pts_full
        feature_cols_S2_TOA = feature_cols


## 2. Develop supervised classifiers for each dataset

### Define supervised classification algorithms to test 

In [None]:
# -----Candidate classifiers, keyed by display name
# Keeping each name next to its estimator guarantees that `names` and `classifiers`
# (used in parallel by later cells) stay aligned.
classifier_lookup = {
    "Decision Tree": DecisionTreeClassifier(max_depth=5),
    "Random Forest": RandomForestClassifier(max_depth=5, n_estimators=10, max_features=1),
    "Neural Network": MLPClassifier(alpha=1, max_iter=1000),
    "Ada Boost": AdaBoostClassifier(),
    "Naive Bayes": GaussianNB(),
    "QDA": QuadraticDiscriminantAnalysis(),
    "Logistic Regression": LogisticRegression(random_state = 0, max_iter = 1000),
    "Nearest Neighbors": KNeighborsClassifier(3),
    "Support Vector Machine": SVC(gamma=2, C=1),
}

# -----Classifier names
names = list(classifier_lookup.keys())

# -----Classifiers (same order as names)
classifiers = list(classifier_lookup.values())

# -----Define number of folds to use in K-folds cross-validation
num_folds = 10

### Use classified points at each site to determine the best classifiers for EACH site and for ALL sites using K-folds cross-validation. 

In [None]:
# Compare all candidate classifiers on each dataset (one classifier for ALL sites) using
# K-folds cross-validation, report mean accuracy and Cohen's Kappa per classifier, plot the
# Random Forest feature importances, and select the classifier with the highest mean Kappa.
from sklearn.base import clone

save_outputs = False

# The same shuffled splits (fixed random_state) are used for every classifier and dataset.
# num_folds is defined together with the classifiers.
kfold = KFold(n_splits=num_folds, shuffle=True, random_state=1)

# -----Loop through datasets
for dataset in datasets:

    print('----------')
    print(dataset)
    print('----------')

    # Define variables and dataset prefix to use in file names based on dataset
    if dataset == 'PlanetScope':
        data_pts_full = data_pts_full_PS.dropna().reset_index(drop=True)
        feature_cols = feature_cols_PS
    elif dataset == 'Landsat':
        data_pts_full = data_pts_full_L.dropna().reset_index(drop=True)
        feature_cols = feature_cols_L
    elif dataset == 'Sentinel-2_SR':
        data_pts_full = data_pts_full_S2_SR.dropna().reset_index(drop=True)
        feature_cols = feature_cols_S2_SR
    elif dataset == 'Sentinel-2_TOA':
        data_pts_full = data_pts_full_S2_TOA.dropna().reset_index(drop=True)
        feature_cols = feature_cols_S2_TOA

    # Reformat string datatypes in dataframes
    # Geometries are WKT strings when the table was read from CSV; values that are already
    # geometry objects are left untouched (wkt.loads only accepts strings).
    data_pts_full['geometry'] = data_pts_full['geometry'].apply(
        lambda geom: wkt.loads(geom) if isinstance(geom, str) else geom)

    # -----Test one classifier for all sites
    print('Testing one classifier for ALL sites...')
    X = data_pts_full[feature_cols].astype(float) # features
    y = data_pts_full['Class'].astype(int) # labels

    # Iterate over classifiers
    accuracy = np.zeros(len(classifiers)) # mean accuracy
    K = np.zeros(len(classifiers)) # mean Kappa score
    for j, (name, clf) in enumerate(zip(names, classifiers)):

        print(name)

        # Conduct K-Fold cross-validation
        accuracy_folds = np.zeros(num_folds) # accuracy for all folds
        K_folds = np.zeros(num_folds) # kappa score for all folds
        for k, (train_ix, test_ix) in enumerate(kfold.split(X)):

            # select rows (split indices are positional)
            X_train, X_test = X.iloc[train_ix], X.iloc[test_ix]
            y_train, y_test = y.iloc[train_ix], y.iloc[test_ix]

            # Train classifier
            clf.fit(X_train, y_train)

            # Predict class values using trained classifier
            y_pred = clf.predict(X_test)

            # Calculate overall accuracy
            accuracy_folds[k] = metrics.accuracy_score(y_test, y_pred)
            # Calculate Kappa score
            K_folds[k] = metrics.cohen_kappa_score(y_test, y_pred)

        # Calculate mean accuracy and Kappa score
        accuracy[j] = np.nanmean(accuracy_folds)
        K[j] = np.nanmean(K_folds)

        # Determine feature importance using Random Forest model
        # (importances come from the model fitted on the last fold's training split)
        if name == "Random Forest":
            importances = clf.feature_importances_
            fig, ax = plt.subplots(1, 1, figsize=(8, 6))
            bars = ax.bar(np.arange(len(feature_cols)), importances)
            ax.bar_label(bars)
            ax.set_xticks(np.arange(len(feature_cols)))
            ax.set_xticklabels(feature_cols)
            ax.set_xlabel('Features')
            ax.set_ylabel('Importance')
            ax.set_title(dataset + ' Random Forest model feature importances')
            plt.show()
            if save_figures:
                fig_fn = os.path.join(figures_out_path, 'RandomForest_feature_importances_' + dataset + '.png')
                fig.savefig(fig_fn, dpi=300, facecolor='w')
                print('    figure saved to file')

    # -----Determine best classifier based on Kappa score
    results = pd.DataFrame()
    results['Classifier'], results['Accuracy'], results['Kappa_score'] = names, accuracy, K
    best_ix = int(np.nanargmax(K)) # first classifier with the highest Kappa score, ignoring NaN
    clf_best_name = names[best_ix]
    # Refit a fresh copy on ALL points of this dataset. The instance in `classifiers` was last
    # fitted on a single fold's training subset and is refitted for every following dataset.
    clf_best = clone(classifiers[best_ix]).fit(X, y)
    print(results)
    print('')
    # NOTE: the classifier is selected by Kappa score, not by accuracy
    print('Best accuracy classifier: ' + clf_best_name)

    # -----Save most accurate classifier
    if save_outputs:
        clf_fn = os.path.join(out_path, dataset + '_classifier_all_sites.joblib')
        dump(clf_best, clf_fn)
        print('Most accurate classifier saved to file: ', clf_fn)

    # TODO(review): a per-class confusion matrix is not computed. An earlier draft allocated an
    # unused 4 x 4 array although five classes are defined; if a confusion matrix is needed,
    # accumulate metrics.confusion_matrix(y_test, y_pred, labels=[1, 2, 3, 4, 5]) over the folds.

## 3. *Optional*

### a. Test how the number of training points impacts model accuracies (i.e., calculate learning curves)

In [None]:
from sklearn.model_selection import learning_curve, ShuffleSplit
from sklearn.metrics import make_scorer

# Define scoring function
# Wraps metrics.accuracy_score as a scorer object; it is passed as `scoring` to
# learning_curve inside plot_learning_curve below.
scorer = make_scorer(metrics.accuracy_score)  

# Create a function to plot learning curves
def plot_learning_curve(estimator, estimator_name, X, y, cv=None, axis=None, color='k',
                        label='_nolegend_', n_jobs=None, train_sizes=np.linspace(500, 6.5e3, 7).astype(int)):

    train_sizes, train_scores, test_scores = learning_curve(
        estimator, X, y, cv=cv, scoring=scorer, n_jobs=n_jobs, train_sizes=train_sizes)
    
    train_scores_mean = np.mean(train_scores, axis=1)
    train_scores_std = np.std(train_scores, axis=1)
    test_scores_mean = np.mean(test_scores, axis=1)
    test_scores_std = np.std(test_scores, axis=1)
    axis.plot(train_sizes, test_scores_mean, 'o-', color=color, label=label)
    axis.grid()

    return plt

# -----Set up figure
plt.rcParams.update({'font.size':12, 'font.sans-serif':'Arial'})
fig, ax = plt.subplots(2, 2, figsize=(12,8))
plt.subplots_adjust(right=0.78)
ax = ax.flatten()
# define colors for plotting results for each model
colors = ['#a6cee3', '#1f78b4', '#b2df8a', '#33a02c', '#fb9a99', 
          '#e31a1c', '#fdbf6f', '#ff7f00', '#6a3d9a']
# set up titles
text_labels = ['(a) ', '(b) ', '(c) ', '(d) ']
# fraction of samples held out for testing in each cross-validation split
cv_test_size = 0.2

# -----Iterate over datasets
for i, dataset in enumerate(datasets):
    
    print(dataset)
    print('----------')
    
    # Define training data, feature columns, and training-set sizes based on dataset
    if dataset=='PlanetScope':
        data_pts_full = data_pts_full_PS.dropna().reset_index(drop=True)
        feature_cols = feature_cols_PS
        train_sizes = np.linspace(1e3, 7e3, 4).astype(int)
    elif dataset=='Landsat':
        data_pts_full = data_pts_full_L.dropna().reset_index(drop=True)
        feature_cols = feature_cols_L
        train_sizes = np.linspace(1e3, 6e3, 6).astype(int)
    elif dataset=='Sentinel-2_SR':
        data_pts_full = data_pts_full_S2_SR.dropna().reset_index(drop=True)
        feature_cols = feature_cols_S2_SR
        train_sizes = np.linspace(1e3, 9e3, 5).astype(int)
    elif dataset=='Sentinel-2_TOA':
        data_pts_full = data_pts_full_S2_TOA.dropna().reset_index(drop=True)
        feature_cols = feature_cols_S2_TOA
        train_sizes = np.linspace(1e3, 9e3, 5).astype(int)

    # Split training data into X and y (identical for every model, so done once per dataset)
    X = data_pts_full[feature_cols].astype(float) # features
    y = data_pts_full['Class'].astype(int) # labels

    # Cross-validation splitter (fixed random_state => same splits for every model)
    cv = ShuffleSplit(n_splits=10, test_size=cv_test_size, random_state=0)

    # Keep only training sizes that fit in the training portion of each split,
    # so the dataset-specific sizes cannot trigger a size error in learning_curve
    n_train_max = len(X) - int(np.ceil(cv_test_size * len(X)))
    train_sizes = train_sizes[train_sizes <= n_train_max]
    if train_sizes.size == 0:
        train_sizes = np.array([n_train_max])

    # Iterate over models
    for name, clf, color in zip(names, classifiers, colors):

        # Plot learning curves using the dataset-specific training sizes
        # (these were previously computed but never passed, so the function default was used)
        plot_learning_curve(clf, name, X, y, cv=cv, axis=ax[i], color=color, label=name, n_jobs=-1,
                            train_sizes=train_sizes)
        
    # adjust axes
    if i >=2:
        ax[i].set_xlabel('Number of training samples')
    if (i==0) or (i==2):
        ax[i].set_ylabel('Accuracy score')
    ax[i].set_title(text_labels[i] + dataset.replace('_',' '))
    ax[i].set_yticks([0.6, 0.7, 0.8, 0.9, 1.0])
    ax[i].set_ylim(0.6, 1)

    print(' ')

# one shared legend, built from the first panel's curves
handles, labels = ax[0].get_legend_handles_labels()
fig.legend(handles, labels, loc='center right')        
plt.show()

# save figure
fig_fn = os.path.join(base_path, 'figures', 'figS2_classifiers_learning_curves.png')
fig.savefig(fig_fn, dpi=300, bbox_inches='tight')
print('figure saved to file: ' + fig_fn)

### b. MODIS (not recommended)

In [None]:
dataset='MODIS'
ds_dict = dataset_dict[dataset]

# -----Define bands and feature columns (predictors used in classification)
# MODIS bands: 
# sur_refl_b01=Red, sur_refl_b04=Green, sur_refl_b03=Blue, sur_refl_b05=SWIR1, 
# sur_refl_b06=SWIR2, sur_refl_b07=SWIR3
band_names = [band for band in ds_dict['bands'] if 'qc' not in band]
feature_cols_M = list(band_names) + ['NDSI']
# buffer used for clipping images
buffer = 3000 # [m]

# -----Check if training data exist in file
M_training_data_fn = 'MODIS_training_data.pkl'
if os.path.exists(out_path + M_training_data_fn):
    
    data_pts_full_M = pd.read_pickle(out_path + M_training_data_fn)
    print('MODIS training data already exist... loaded from file.')
    
else: 
    
    # Initialize full data points dataframe from the PlanetScope points,
    # replacing the PlanetScope band columns with (empty) MODIS band columns
    data_pts_full_M = data_pts_full_PS.copy(deep=True)
    data_pts_full_M = data_pts_full_M.drop(columns=['blue', 'red', 'green', 'NIR', 'NDSI'])
    data_pts_full_M[band_names] = " " # initialize band columns

    # Loop through sites
    for i, site_name in enumerate(site_names):
    
        print('----------')
        print(site_name)
        print('----------')

        # Extract image dates from data point file names
        im_dates = [s[len(site_name)+1:len(site_name)+9] for s in data_pts_fns if (site_name in s) and ('snow.shp' in s)]

        # Load AOI
        AOI_fn = AOI_path + site_name + '/glacier_outlines/' + site_name + '_USGS_*.shp'
        AOI_fn = glob.glob(AOI_fn)[0]
        AOI = gpd.read_file(AOI_fn)
        # reproject AOI to WGS 84 for compatibility with images
        AOI_WGS = AOI.to_crs(4326)
        # bounding box of the first AOI feature
        AOI_bounds = AOI_WGS.geometry.bounds
        minx, miny, maxx, maxy = (AOI_bounds[col][0] for col in ('minx', 'miny', 'maxx', 'maxy'))
        # reformat AOI_WGS bounding box as ee.Geometry for clipping images
        AOI_WGS_bb_ee = ee.Geometry.Polygon(
                                [[[minx, miny],
                                  [maxx, miny],
                                  [maxx, maxy],
                                  [minx, maxy],
                                  [minx, miny]]
                                ])

        # Determine optimal UTM zone EPSG code from the CENTRE of the bounding box
        # (the previous expression (max - min) + min evaluated to the bbox maximum)
        epsg_UTM = f.convert_wgs_to_utm((maxx - minx) / 2 + minx,
                                        (maxy - miny) / 2 + miny)

        # Define MODIS image IDs (dates) for the site
        if site_name=='Gulkana':
            im1_fn = '2021_06_15'
            im2_fn = '2021_08_06'
        elif site_name=='SouthCascade':
            im1_fn= '2021_07_03'
            im2_fn = '2021_08_28'
        elif site_name=='Sperry':
            im1_fn = '2021_06_27'
            im2_fn = '2021_08_01'
        elif site_name=='Wolverine':
            im1_fn = '2021_06_15'
            im2_fn = '2021_08_15'

        # Load and pre-process each image; im_list is ordered to match im_dates
        im_list = []
        for im_fn in [im1_fn, im2_fn]:
            # Load image from Earth Engine
            im_ee = ee.Image('MODIS/061/MOD09GA/'+im_fn)
            # Clip image and select bands
            im_clip = im_ee.clip(AOI_WGS_bb_ee.buffer(buffer)).select(band_names)
            # Convert image to xarray Dataset
            im_xr = im_clip.wx.to_xarray(scale=ds_dict['resolution_m'], crs='EPSG:4326')
            # Reproject to UTM
            im_xr = im_xr.rio.reproject('EPSG:'+epsg_UTM)
            # Replace no data values with NaN, account for image scalar
            im_xr = xr.where((im_xr!=ds_dict['no_data_value']) & (im_xr > 0), 
                             im_xr / ds_dict['SR_scalar'], 
                             np.nan)
            im_list.append(im_xr)
        
        # Loop through image dates
        for j, im_date in enumerate(im_dates):

            im = im_list[j]
            
            # select df rows for study site and image date (YYYYMMDD -> YYYY-MM-DD)
            im_date_str = im_date[0:4]+'-'+im_date[4:6]+'-'+im_date[6:8]
            pts_mask = ((data_pts_full_M['site_name']==site_name) 
                        & (data_pts_full_M['PS_im_date']==im_date_str))
            data_pts = data_pts_full_M.loc[pts_mask]
            # reproject to UTM
            data_pts = data_pts.to_crs(epsg_UTM)
            
            # grab x and y coordinates (first point of each geometry) for data points at the site
            data_pts_x = [geom.geoms[0].x for geom in data_pts['geometry']]
            data_pts_y = [geom.geoms[0].y for geom in data_pts['geometry']]
            
            # extract band values at data points 
            for band_name in band_names:
                data_pts[band_name] = [im.sel(x=x, y=y, method="nearest")[band_name].data[0] 
                                       for x, y in zip(data_pts_x, data_pts_y)]

            # plot images and data points
            fig1, ax1 = plt.subplots(1, 1, figsize=(6,6))
            ax1.imshow(np.dstack([im[ds_dict['RGB_bands'][0]].data[0],
                                  im[ds_dict['RGB_bands'][1]].data[0],
                                  im[ds_dict['RGB_bands'][2]].data[0]]),
                      extent=(np.min(im.x.data)/1e3, np.max(im.x.data)/1e3, 
                              np.min(im.y.data)/1e3, np.max(im.y.data)/1e3))
            ax1.plot([x/1e3 for x in data_pts_x], [y/1e3 for y in data_pts_y], '.m', markersize=5)
            ax1.set_xlabel('Easting [km]')
            ax1.set_ylabel('Northing [km]')
            plt.show()

            # add data_pts back to full df
            data_pts_full_M.loc[pts_mask] = data_pts

    # Add NDSI column
    data_pts_full_M['NDSI'] = ((data_pts_full_M[ds_dict['NDSI'][0]] - data_pts_full_M[ds_dict['NDSI'][1]]) / 
                                (data_pts_full_M[ds_dict['NDSI'][0]] + data_pts_full_M[ds_dict['NDSI'][1]]))
    
    # Remove no data points
    data_pts_full_M = data_pts_full_M.dropna().reset_index(drop=True)
    
    # don't use shadowed snow (where class==shadowed snow, make it snow)
    data_pts_full_M.loc[data_pts_full_M['class']==2, 'class'] = 1
    
    # Reduce memory usage in df
    data_pts_full_M = f.reduce_memory_usage(data_pts_full_M)
    
    # Save training data to file
    data_pts_full_M.to_pickle(out_path + M_training_data_fn)
    print('MODIS training data saved to file:' + out_path + M_training_data_fn)
    
    # Save feature columns (context manager so the file handle is closed)
    feature_cols_fn = out_path + 'MODIS_feature_columns.pkl'
    with open(feature_cols_fn, 'wb') as feature_cols_file:
        pickle.dump(feature_cols_M, feature_cols_file)
    print('Feature columns saved to file: ', feature_cols_fn)
    
    # Plot spectral pairplot for training data
    df = data_pts_full_M
    df = df.sort_values(by='class')
    df['Class'] = df['class'].astype(object)
    # Assign labels to each class
    df.loc[df['Class']==1, 'Class'] = 'Snow'
    df.loc[df['Class']==2, 'Class'] = 'Snow'
    df.loc[df['Class']==3, 'Class'] = 'Ice'
    df.loc[df['Class']==4, 'Class'] = 'Bare rock'
    df.loc[df['Class']==5, 'Class'] = 'Water'
    # Create colormap
    color_snow = '#4eb3d3'
    color_ice = '#084081'
    color_rock = '#fdbb84'
    color_water = '#bdbdbd'
    color_contour = '#f768a1'
    colors = [color_snow, color_ice, color_rock, color_water]
    # plot
    fig = sns.pairplot(df[['Class'] + feature_cols_M], corner=True, diag_kind='kde', hue='Class', palette=colors)
    plt.show()
    # save figure (gated by the notebook-wide save_outputs flag)
    if save_outputs:
        fig_fn = base_path + 'figures/spectral_pairplot_M_training_data.png'
        fig.savefig(fig_fn, facecolor='w', dpi=300)
        print('figure saved to file:' + fig_fn)

### c. Sentinel-1 (not recommended)

#### Set-up training data

In [None]:
dataset = 'Sentinel-1'

# -----Add path to preprocessing functions
# path to gee_s1_ard/python-api/
ard_path = base_path + '../gee_s1_ard/python-api/'
# add functions to path
sys.path.insert(1, ard_path)
import wrapper as w

# -----Define bands and feature columns (predictors used in classification)
band_names = ['VV', 'VH', 'angle']
feature_cols_S1 = band_names + ['VHVV', 'VV-VH']
# buffer used for clipping images
# NOTE(review): this value is not used below; the ROI is buffered by a hard-coded
# 3000 m instead - confirm which distance is intended
buffer = 2000 # [m]

# -----Load classified points
os.chdir(data_pts_path)
data_pts_fns = sorted(glob.glob('*.shp'))

# -----Check if training data exist in file
S1_training_data_fn = 'S1_training_data.pkl'
if os.path.exists(out_path + S1_training_data_fn):
    
    # load into the Sentinel-1 dataframe (this previously overwrote the Landsat
    # dataframe and left data_pts_full_S1 undefined for the cells below)
    data_pts_full_S1 = pd.read_pickle(out_path + S1_training_data_fn)
    print('Sentinel-1 training data already exist... loaded from file.')
    
else: 
    
    # Initialize full data points dataframe (for use in next step)
    data_pts_full_S1 = data_pts_full_PS.copy(deep=True)
    # remove PS bands
    data_pts_full_S1 = data_pts_full_S1.drop(columns=['blue', 'green', 'red', 'NIR', 'NDSI'])
    # initialize band columns
    data_pts_full_S1['S1_im_date'] = ' '
    data_pts_full_S1[feature_cols_S1] = 0

    # Only one site is processed (the loop over site_names is disabled)
    # for i, site_name in enumerate(site_names):
    site_name = 'SouthCascade'
    
    print('----------')
    print(site_name)
    print('----------')

    # Extract image dates from data point file names
    im_dates = [s[len(site_name)+1:len(site_name)+9] for s in data_pts_fns if (site_name in s) and ('snow.shp' in s)]

    # Load AOI
    AOI_fn = AOI_path + site_name + '/glacier_outlines/' + site_name + '_USGS_*.shp'
    AOI_fn = glob.glob(AOI_fn)[0]
    AOI = gpd.read_file(AOI_fn)
    # reproject AOI to WGS 84 for compatibility with images
    AOI_WGS = AOI.to_crs(4326)
    # bounding box of the first AOI feature
    AOI_bounds = AOI_WGS.geometry.bounds
    minx, miny, maxx, maxy = (AOI_bounds[col][0] for col in ('minx', 'miny', 'maxx', 'maxy'))
    # reformat AOI_WGS bounding box as ee.Geometry for clipping images
    AOI_WGS_bb_ee = ee.Geometry.Polygon(
                            [[[minx, miny],
                              [maxx, miny],
                              [maxx, maxy],
                              [minx, maxy],
                              [minx, miny]]
                            ]).buffer(3000)
    # solve for optimal UTM zone for reprojection
    AOI_WGS_centroid = [AOI_WGS.geometry[0].centroid.xy[0][0],
                        AOI_WGS.geometry[0].centroid.xy[1][0]]
    epsg_UTM = f.convert_wgs_to_utm(AOI_WGS_centroid[0], AOI_WGS_centroid[1])
        
    # Define image date ranges (only defined for SouthCascade)
    if site_name=='SouthCascade':
        im1_date_range = ['2021-07-01', '2021-07-03']
        im2_date_range = ['2021-08-27', '2021-08-29']
    date_ranges = [im1_date_range, im2_date_range]

    # loop through images
    for j in [0, 1]:

        # Define dictionary of parameters for the gee_s1_ard wrapper
        # (the duplicated 'APPLY_TERRAIN_FLATTENING' key has been removed)
        params = {'APPLY_BORDER_NOISE_CORRECTION': True,
                  'APPLY_TERRAIN_FLATTENING': True,
                  'APPLY_SPECKLE_FILTERING': True, 
                  'POLARIZATION': 'VVVH',
                  'PLATFORM_NUMBER': None,
                  'ORBIT': None, 
                  'ORBIT_NUM': None, 
                  'SPECKLE_FILTER_FRAMEWORK': 'MULTI',
                  'SPECKLE_FILTER': 'LEE',
                  'SPECKLE_FILTER_KERNEL_SIZE': 9,
                  'SPECKLE_FILTER_NR_OF_IMAGES': 10,
                  'DEM': ee.Image("NASA/ASTER_GED/AG100_003"),
                  'TERRAIN_FLATTENING_MODEL': 'VOLUME',
                  'TERRAIN_FLATTENING_ADDITIONAL_LAYOVER_SHADOW_BUFFER': 0,
                  'FORMAT' : 'DB',
                  'CLIP_TO_ROI': True,
                  'SAVE_ASSET': False,
                  'ASSET_ID': None,
                  'START_DATE': date_ranges[j][0],
                  'STOP_DATE': date_ranges[j][1],
                  'ROI': AOI_WGS_bb_ee
                 }

        # run the gee_s1_ard wrapper
        im = w.s1_preproc(params)
        
        im_date = im_dates[j]
            
        # Convert ee.ImageCollection to xarray.Dataset
        im_ds = im.wx.to_xarray(scale=10, crs='EPSG:4326')
        # reproject to UTM
        im_ds = im_ds.rio.reproject('EPSG:'+epsg_UTM)
        # replace no data values with NaN
        im_ds = im_ds.where(im_ds!=-32768)
        
        # select df rows for study site and image date (YYYYMMDD -> YYYY-MM-DD)
        im_date_str = im_date[0:4]+'-'+im_date[4:6]+'-'+im_date[6:8]
        pts_mask = ((data_pts_full_S1['site_name']==site_name) 
                    & (data_pts_full_S1['PS_im_date']==im_date_str))
        # explicit copy so the column assignments below act on an independent frame
        data_pts = data_pts_full_S1.loc[pts_mask].copy()
        # add image date
        data_pts['S1_im_date'] = str(im_ds.time.data[0])[0:10]
        # reproject to UTM
        data_pts = data_pts.to_crs(epsg_UTM)
            
        # grab x and y coordinates (first point of each geometry) for data points at the site
        data_pts_x = [geom.geoms[0].x for geom in data_pts['geometry']]
        data_pts_y = [geom.geoms[0].y for geom in data_pts['geometry']]
        # extract band values at data points 
        for band_name in band_names:
            data_pts[band_name] = [im_ds.sel(x=x, y=y, method="nearest")[band_name].data[0] 
                                   for x, y in zip(data_pts_x, data_pts_y)]

        # add data_pts back to full df
        data_pts_full_S1.loc[pts_mask] = data_pts

        # plot image and data points, one color per class value
        fig1, ax1 = plt.subplots(1, 1, figsize=(10,10))
        ax1.imshow(im_ds['VV'].data[0], cmap='Greys',
                   extent=(np.min(im_ds.x.data)/1e3, np.max(im_ds.x.data)/1e3, 
                           np.min(im_ds.y.data)/1e3, np.max(im_ds.y.data)/1e3))
        for class_value, class_color in [(1, 'cyan'), (3, 'blue'), (4, 'orange'), (5, 'grey')]:
            class_geoms = data_pts['geometry'].loc[data_pts['class']==class_value]
            ax1.scatter([geom.geoms[0].x/1e3 for geom in class_geoms], 
                        [geom.geoms[0].y/1e3 for geom in class_geoms], c=class_color, s=1)
        ax1.set_xlabel('Easting [km]')
        ax1.set_ylabel('Northing [km]')
        plt.show()

    # Remove no data points
    data_pts_full_S1 = data_pts_full_S1.dropna().reset_index(drop=True)
    
    # Add VH/VV and VV-VH columns
    data_pts_full_S1['VHVV'] = data_pts_full_S1['VH'] / data_pts_full_S1['VV']
    data_pts_full_S1['VV-VH'] = data_pts_full_S1['VV'] - data_pts_full_S1['VH']

    # Reduce memory usage in df
    data_pts_full_S1 = f.reduce_memory_usage(data_pts_full_S1)

    # Save training data to file
    data_pts_full_S1.to_pickle(out_path + S1_training_data_fn)
    print('Sentinel-1 training data saved to file:' + out_path + S1_training_data_fn)

    # Save feature columns (context manager so the file handle is closed)
    feature_cols_fn = out_path + 'Sentinel-1_feature_columns.pkl'
    with open(feature_cols_fn, 'wb') as feature_cols_file:
        pickle.dump(feature_cols_S1, feature_cols_file)
    print('Feature columns saved to file: '+ feature_cols_fn)
    

#### Plot pairplot of training data spectral characteristics

In [None]:
# Cast class to object so seaborn treats it as a categorical hue
data_pts_full_S1['class'] = data_pts_full_S1['class'].astype(object)
fig = sns.pairplot(data_pts_full_S1[['class'] + feature_cols_S1], markers='.',  
                   corner=True, diag_kind='kde', hue='class', palette="colorblind")
# save figure (gated by the notebook-wide save_outputs flag)
if save_outputs:
    fig_fn = base_path + 'figures/spectral_pairplot_Sentinel1.png'
    fig.savefig(fig_fn, facecolor='w', dpi=300)
    print('figure saved to file:' + fig_fn)

#### Test one classifier for _each_ site

In [None]:
# -----Test supervised classification algorithms
# Only one site is tested (the loop over site_names is disabled)
# for i, site_name in enumerate(site_names):
site_name = 'SouthCascade'

print('----------')
print(site_name)
print('----------')

# Select all rows in data_pts_full_S1 for site
data_pts = data_pts_full_S1.loc[data_pts_full_S1['site_name']==site_name]
data_pts = data_pts.reset_index(drop=True)

# Split data points into features (band values) and target variable (class)
X = data_pts[feature_cols_S1] # features
y = data_pts['class'].astype(int) # target variable

# K-Fold splitter (fixed random_state => identical folds for every classifier)
kfold = KFold(n_splits=num_folds, shuffle=True, random_state=1)

# Iterate over classifiers
accuracy = np.zeros(len(classifiers)) # mean accuracy
K = np.zeros(len(classifiers)) # mean Kappa score
for j, (name, clf) in enumerate(zip(names, classifiers)):

    print(name)

    # Conduct K-Fold cross-validation
    accuracy_folds = np.zeros(num_folds) # accuracy for each fold
    K_folds = np.zeros(num_folds) # kappa score for each fold
    for k, (train_ix, test_ix) in enumerate(kfold.split(X)):

        # select rows by position (the fold indices are positional)
        X_train, X_test = X.iloc[train_ix], X.iloc[test_ix]
        y_train, y_test = y.iloc[train_ix], y.iloc[test_ix]

        # Train classifier
        clf.fit(X_train, y_train)

        # Predict class values using trained classifier
        y_pred = clf.predict(X_test)

        # Calculate overall accuracy
        accuracy_folds[k] = metrics.accuracy_score(y_test, y_pred)
        # Calculate Kappa score
        K_folds[k] = metrics.cohen_kappa_score(y_test, y_pred)

    # Calculate mean accuracy and Kappa score
    accuracy[j] = np.nanmean(accuracy_folds)
    K[j] = np.nanmean(K_folds)

# Determine best classifier based on accuracy (first one in case of ties)
results = pd.DataFrame()
results['Classifier'], results['Accuracy'], results['Kappa_score'] = names, accuracy, K
best_ix = int(np.argmax(accuracy))
clf_best_name = names[best_ix]
clf_best = classifiers[best_ix]
print(results)
print('')
print('Best accuracy classifier: ' + clf_best_name)

# -----Save most accurate classifier
# NOTE: clf_best is saved as last fitted above, i.e. on the final fold's training subset
if save_outputs:
    # name the file after the site tested here (previously used site_names[i]
    # with a stale index left over from an unrelated loop)
    clf_fn = out_path + 'S1_classifier_'+site_name+'.sav'
    with open(clf_fn, 'wb') as clf_file:
        pickle.dump(clf_best, clf_file)
    print('Most accurate classifier saved to file: ',clf_fn)   

#### Classify a Sentinel-1 image with the trained classifier and plot the result

In [None]:
# -----Define image date
im_date = str(im_ds.time.data[0])[0:10]
print(im_date)
        
im_AOI = im_ds

# add VHVV and VV-VH columns
im_ds['VHVV'] = im_ds['VH'] / im_ds['VV']
im_ds['VV-VH'] = im_ds['VV'] - im_ds['VH']

# -----Define image bands
# (listed AFTER the derived bands are added so that the valid-pixel mask below also covers them)
bands = [x for x in im_ds.data_vars]

# find indices of finite values in every band (no NaN or inf allowed in classification)
I_real = np.full(np.shape(im_AOI[bands[0]].data), True)
for band in bands:
    I_real = I_real & np.isfinite(im_AOI[band].data)
            
# create df of image band values at the valid pixels
df = pd.DataFrame({col: np.ravel(im_AOI[col].data[I_real]) for col in feature_cols_S1},
                  columns=feature_cols_S1)
df = df.reset_index(drop=True)

# -----Classify image
# NOTE(review): `clf` is whichever classifier was fitted last in the preceding cell,
# not necessarily `clf_best` - confirm which one is intended
if len(df)>1:
    array_classified = clf.predict(df[feature_cols_S1])
else:
    print("No real values found to classify, skipping...")
    # leave the (at most one) valid pixel unclassified so the code below still runs
    array_classified = np.full(len(df), np.nan)

# reshape from flat array to original shape (NaN where no valid data)
im_classified = np.full(I_real.shape, np.nan)
im_classified[I_real] = array_classified
            
# -----Plot results
fig, ax = plt.subplots(1, 2, figsize=(10,6))
ax = ax.flatten()
# define x and y limits
xmin, xmax = np.min(im_ds.x.data)/1e3, np.max(im_ds.x.data)/1e3
ymin, ymax = np.min(im_ds.y.data)/1e3, np.max(im_ds.y.data)/1e3
# define colors for plotting
color_snow = '#4eb3d3'
color_ice = '#084081'
color_rock = '#fdbb84'
color_water = '#bdbdbd'
color_contour = '#f768a1'
# create colormap (class values 1 and 2 both map to the snow color)
colors = [color_snow, color_snow, color_ice, color_rock, color_water]
cmp = matplotlib.colors.ListedColormap(colors)
# VV image (first time step)
ax[0].imshow(im_ds['VV'].data[0], cmap='Greys',
             extent=(xmin, xmax, ymin, ymax))
ax[0].set_xlabel("Easting [km]")
ax[0].set_ylabel("Northing [km]")
ax[0].set_title('VV backscatter')
# classified image (first time step)
ax[1].imshow(im_classified[0], cmap=cmp, vmin=1, vmax=5,
             extent=(np.min(im_AOI.x.data)/1e3, np.max(im_AOI.x.data)/1e3,
                     np.min(im_AOI.y.data)/1e3, np.max(im_AOI.y.data)/1e3))
# plot dummy points for legend
ax[1].scatter(0, 0, color=color_snow, s=50, label='snow')
ax[1].scatter(0, 0, color=color_ice, s=50, label='ice')
ax[1].scatter(0, 0, color=color_rock, s=50, label='rock')
ax[1].scatter(0, 0, color=color_water, s=50, label='water')
ax[1].set_title('Classified image')
ax[1].set_xlabel('Easting [km]')
ax[1].legend(loc='best')
# AOI outline: loop through geoms if AOI = MultiPolygon, otherwise plot the single polygon
AOI_geom = AOI.geometry[0]
if AOI_geom.geom_type=='MultiPolygon':
    AOI_polys = list(AOI_geom.geoms)
else:
    AOI_polys = [AOI_geom]
for j, poly in enumerate(AOI_polys):
    poly_x = [x/1e3 for x in poly.exterior.coords.xy[0]]
    poly_y = [y/1e3 for y in poly.exterior.coords.xy[1]]
    # only include legend label for first geom
    ax[0].plot(poly_x, poly_y, '-k', linewidth=1, label='AOI' if j==0 else '_nolegend_')
    ax[1].plot(poly_x, poly_y, '-k', linewidth=1, label='_nolegend_')
# reset x and y limits (the dummy legend points at (0, 0) would otherwise stretch the axes)
ax[0].set_xlim(xmin, xmax)
ax[0].set_ylim(ymin, ymax)
ax[1].set_xlim(xmin, xmax)
ax[1].set_ylim(ymin, ymax)
fig.suptitle(im_date)
fig.tight_layout()
plt.show()