### Mapping invasive species using supervised machine learning and AVIRIS-NG 
----

### Overview 
----

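In this notebook, we use existing, verified land-cover and alien (invasive) species locations to extract spectra from AVIRIS-NG surface reflectance data, train a supervised classifier on those spectra, and apply it to AVIRIS-NG scenes to map alien species.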

----
### Learning Objectives
1. Understand how to inspect and prepare data for machine learning models
2. Train and interpret a machine learning model
3. Apply a trained model to AVIRIS imagery to create alien species maps

### Requirements
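This tutorial requires the following Python modules: `geopandas`, `xarray`, `rioxarray`, `shapely`, `numpy`, `pandas`, `hvplot`, `holoviews`, `xvec`, `matplotlib`, `dask`, `xgboost`, `scikit-learn` and `shap`. Reading the AVIRIS-NG scenes with `engine='kerchunk'` from S3 additionally requires `kerchunk` and `s3fs`.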

----

### Load python modules

In [None]:
import geopandas as gpd
import xarray as xr
from shapely.geometry import box, mapping
import rioxarray as riox
import numpy as np
import hvplot.xarray
import holoviews as hv
import xvec
import matplotlib.pyplot as plt
import pandas as pd
from dask.diagnostics import ProgressBar

#### Custom functions

In [None]:
def get_first_xr(ds_in):
    """Keep, for every geometry, the reflectance from the first file with valid data.

    A labelled point can be extracted from several files (scenes), so ``ds_in``
    holds one entry per (file, geometry). For each geometry this finds the first
    file whose ``index`` value is not NaN and returns that file's reflectance.

    Parameters
    ----------
    ds_in : xarray.Dataset
        Result of concatenating ``extract_points`` outputs along ``file``. Must
        contain an ``index`` variable whose data is a 2-D (file, geometry) array,
        and a ``reflectance`` variable indexable as ``[file, geometry, wavelength]``
        (assumed dimension order - TODO confirm).

    Returns
    -------
    xarray.Dataset
        A dataset with a single ``reflectance`` variable along a new ``index``
        dimension (one entry per geometry) plus ``wavelength``. The ``index``
        coordinate is cast to int.
    """
    # Rows are files, columns are geometries; NaN where a geometry was not
    # extracted from that file.
    arr = np.asarray(ds_in['index'].data)

    valid = ~np.isnan(arr)
    # argmax of a boolean column gives the position of its first True value
    file_ind = valid.argmax(axis=0)
    # Geometries with no valid entry keep the -1 placeholder, as before.
    # NOTE(review): -1 presumably indexes the LAST file, so such geometries get
    # that file's (probably NaN) reflectance - confirm this is acceptable.
    file_ind = np.where(valid.any(axis=0), file_ind, -1)

    geom_ind = np.arange(arr.shape[1])

    # Pointwise (vectorised) indexing: both indexers share the new 'index' dim,
    # so element k picks reflectance[file_ind[k], geom_ind[k], :]
    file_ind = xr.DataArray(file_ind, dims=["index"])
    geom_ind = xr.DataArray(geom_ind, dims=["index"])

    ds = ds_in['reflectance'][file_ind, geom_ind, :]
    ds['index'] = ds['index'].astype(int)

    return ds.to_dataset(name='reflectance')

In [None]:
# Notebook display configuration.
# Silence all Python warnings to keep the tutorial output clean.
# NOTE(review): this hides every warning (including geopandas CRS/buffer
# warnings later on); consider a narrower filter.
import warnings
warnings.filterwarnings('ignore')
# Interactive bokeh backend for hvplot/holoviews plots
hvplot.extension('bokeh')
# Alternative static backend:
#hvplot.extension('matplotlib')
# Show matplotlib figures inline in the notebook
%matplotlib inline

### 1. Open and explore land cover labels 

In [None]:
# Human-readable names of the land-cover classes; position i is class code i.
text_lab = ['Bare ground/Rock','Mature Fynbos','Recently burnt Fynbos','Wetland','Forest','Pine','Eucalyptus','Wattle','Water']
# Class codes as strings ('0'..'8') so they can be merged with the 'class'
# column of the labelled points below.
label = [str(code) for code in range(len(text_lab))]

lab_df = pd.DataFrame(list(zip(label, text_lab)), columns=['class', 'text_lab'])
lab_df

In [None]:
raw_data = gpd.read_file('/home/gmoncrieff/ct_invasive.gpkg')
raw_data_utm = (raw_data
                .to_crs("EPSG:32734")
                .merge(lab_df, on='class', how='left')
               )
raw_data_utm.head()

In [None]:
#explore data in interactive map. color by class. use google sattelite basemap
(raw_data_utm[['text_lab','geometry']]
 .explore('text_lab',tiles='https://mt1.google.com/vt/lyrs=s&x={x}&y={y}&z={z}', attr='Google'))

### 2. Extract AVIRIS data at label locations

In [None]:
# AVIRIS-NG flight-line footprints (one row per scene)
AVNG_Coverage = gpd.read_file('ANG_Coverage.geojson')

# Keep scenes whose end_time falls on 2023-11-09 (00:00:00 to 23:59:59, inclusive)
flown_on_day = AVNG_Coverage['end_time'].between('2023-11-09 00:00:00', '2023-11-09 23:59:59')
AVNG_CP = AVNG_Coverage[flown_on_day]
# ...and that overlap at least one labelled point
AVNG_CP = AVNG_CP[AVNG_CP.intersects(raw_data.union_all())]

# Reflectance file locations for the selected scenes
files = AVNG_CP['RFL s3'].tolist()
# NOTE(review): removes the scene at list position 70 (presumably an unusable
# file); position-based, so fragile if the coverage file changes - confirm.
files.pop(70)

# Map the selected scene footprints, coloured by fid
AVNG_CP[['fid', 'geometry']].explore('fid')



In [None]:
test = xr.open_dataset(files[30], engine='kerchunk', chunks='auto')
test = test.where(test>0)
test

In [None]:
test.sel(wavelength=[660, 570, 480], method="nearest").hvplot.rgb('x', 'y',
                                                                  rasterize=True,data_aspect=1,robust=True,
                                                                  bands='wavelength',frame_width=400
)

In [None]:
test.sel({'wavelength': 660},method='nearest').hvplot('x', 'y',
                                                      rasterize=True, data_aspect=1,robust=True,
                                                      cmap='magma',frame_width=400
)

In [None]:
def extract_points(file,points):
    """Sample one AVIRIS-NG scene at the labelled points that fall inside it.

    Parameters
    ----------
    file : str
        Scene location, opened lazily with the kerchunk engine.
    points : geopandas.GeoDataFrame
        Labelled points; assumed to be in the scene's CRS (EPSG:32734 in this
        notebook - TODO confirm for other inputs).

    Returns
    -------
    xarray.Dataset
        Scene variables sampled at the points inside the scene's bounding box,
        keeping the points' original index (``index=True``).
    """
    scene = xr.open_dataset(file, engine='kerchunk', chunks='auto')

    # Restrict the points to the scene's bounding box
    scene_extent = box(*scene.rio.bounds())
    points_in_scene = points.clip(scene_extent)
    print(f'got {points_in_scene.shape[0]} point from {file}')

    return scene.xvec.extract_points(
        points_in_scene['geometry'],
        x_coords="x",
        y_coords="y",
        index=True,
    )

In [None]:
ds_all = [extract_points(file,raw_data_utm) for file in files]

In [None]:
ds_all  = xr.concat(ds_all, dim='file')
ds_all

A labelled point can fall inside several scenes. For each point, keep the reflectance from the first scene (file) that has valid data for it.

In [None]:

# One spectrum per labelled point: the first scene with valid data wins
ds = get_first_xr(ds_in=ds_all)
ds

Merge with the point data to add each point's class label and train/test group.

In [None]:
class_xr =raw_data_utm[['class','group']].to_xarray()
ds = ds.merge(class_xr.astype(int),join='left')
ds

In [None]:
with ProgressBar():
 dsp = ds.persist()

In [None]:
dsp

### 3. Inspect AVIRIS spectra

In [None]:
# Raw spectra of every class-0 point (Bare ground/Rock), one line per point
dsp_plot = dsp.where(dsp['class'] == 0, drop=True)
dsp_plot['reflectance'].hvplot.line(
    x='wavelength', by='index',
    color='green', alpha=0.5,
    ylim=(0, 0.5),
    legend=False,
)

> **TODO:** edit data

### 4. Prep data for ML model

In [None]:
# Remove bands below 420, above 2400 and inside 1340-1450 and 1800-1980
# (presumably the noisy atmospheric water-absorption regions).
# The bands are looked up on dsp itself (previously on ds), so re-running this
# cell finds nothing to drop. Before, drop_sel raised a KeyError on labels
# that had already been removed.
wl = dsp.wavelength
wavelengths_to_drop = wl.where(
    (wl < 420)
    | ((wl >= 1340) & (wl <= 1450))
    | ((wl >= 1800) & (wl <= 1980))
    | (wl > 2400),
    drop=True,
)

# Use drop_sel() to remove those specific wavelength ranges
dsp = dsp.drop_sel(wavelength=wavelengths_to_drop)

In [None]:
# Scale each spectrum to unit L2 norm across wavelengths (stays lazy / dask-aware)
refl = dsp['reflectance']
l2_norm = np.sqrt((refl ** 2).sum(dim='wavelength'))
dsp['reflectance'] = refl / l2_norm

In [None]:
# Normalised spectra of every class-4 point (Forest), one line per point
dsp_norm_plot = dsp.where(dsp['class'] == 4, drop=True)
dsp_norm_plot['reflectance'].hvplot.line(
    x='wavelength', by='index',
    color='green', alpha=0.5,
    ylim=(0, 0.2),
    legend=False,
)

### 5. Train and evaluate ML model

In [None]:
import xgboost as xgb
from sklearn.model_selection import GridSearchCV
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import KFold
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score, confusion_matrix, ConfusionMatrixDisplay

In [None]:
# Split the labelled spectra into training (group 1) and test (group 2) sets
dtrain = dsp.where(dsp['group'] == 1, drop=True)
dtest = dsp.where(dsp['group'] == 2, drop=True)

# Encode class codes as consecutive integers 0..K-1 for XGBoost. The encoder
# is fitted on the training classes only; le.transform raises if the test set
# contains a class that never appears in training.
le = LabelEncoder()
y_train = le.fit_transform(dtrain['class'].values)
y_test = le.transform(dtest['class'].values)

# Feature matrices: one row per point, one column per retained wavelength
X_train = dtrain['reflectance'].values
X_test = dtest['reflectance'].values

assert X_train.shape[0] == y_train.shape[0]
assert X_test.shape[0] == y_test.shape[0]


In [None]:
# Full hyperparameter grid, kept for reference (expensive to search).
# Its duplicated learning rate 0.1 was replaced by 0.01.
param_grid_full = {
    'n_estimators': [50, 250, 500],
    'max_depth': [3, 5, 9],
    'learning_rate': [0.1, 0.01, 0.001],
    'subsample': [0.5, 1],
    'min_child_weight': [1, 3, 5],
    'gamma': [0, 0.1, 0.3]
}

# Small grid that is actually searched, to keep the tutorial fast
param_grid = {
    'max_depth': [5],
    'learning_rate': [0.1],
    'subsample': [0.5, 1],
    'n_estimators': [50, 100]
}

# XGBoost classifier using the histogram tree method
xgb_model = xgb.XGBClassifier(tree_method='hist')

# 5-fold cross-validated search scored by accuracy (for a classifier, an integer
# cv uses stratified folds)
grid_search = GridSearchCV(xgb_model, param_grid, cv=5, scoring='accuracy')
grid_search.fit(X_train, y_train)

print("Best set of hyperparameters: ", grid_search.best_params_)
print("Best score: ", grid_search.best_score_)
best_model = grid_search.best_estimator_

In [None]:
# Evaluate the tuned model on the held-out test set
y_pred = best_model.predict(X_test)

accuracy = accuracy_score(y_test, y_pred)
print(f"Accuracy: {accuracy}")

# Weighted F1: per-class F1 averaged by class support (accounts for class imbalance)
f1 = f1_score(y_test, y_pred, average='weighted')
print(f"F1 Score (weighted): {f1}")

# Precision and recall for encoded class 3 (original code le.classes_[3];
# 'Wetland' when all nine classes are present in the training data)
precision_class_3 = precision_score(y_test, y_pred, labels=[3], average='macro', zero_division=0)
recall_class_3 = recall_score(y_test, y_pred, labels=[3], average='macro', zero_division=0)

print(f"Precision for Class 3: {precision_class_3}")
print(f"Recall for Class 3: {recall_class_3}")

# Confusion matrix over ALL encoded classes, so its size always matches the
# tick labels even if a class is absent from both y_test and y_pred.
# le.classes_ maps encoded label -> original class code; text_lab is indexed by code.
conf_matrix = confusion_matrix(y_test, y_pred, labels=np.arange(len(le.classes_)))
class_names = [text_lab[int(code)] for code in le.classes_]

fig, ax = plt.subplots(figsize=(8, 8))
ConfusionMatrixDisplay(confusion_matrix=conf_matrix, display_labels=class_names).plot(
    ax=ax, xticks_rotation='vertical'
)
ax.set_title('Confusion matrix (test set)');

> **TODO:** conformal

### 6. Interpret and understand ML model

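SHAP (SHapley Additive exPlanations) values: how much each wavelength contributed to the model's prediction for individual samples.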

In [None]:
import shap
#shap.initjs()

# Use the integer wavelengths as feature names so the plots label bands readably
feature_names = [str(int(w)) for w in dsp.wavelength.values]

# Tree explainer for the fitted XGBoost model, evaluated on the test set
explainer = shap.TreeExplainer(best_model, feature_names=feature_names)
shap_values = explainer(X_test)

In [None]:
sel_in = 5
shap.plots.waterfall(shap_values[sel_in,:,y_test[sel_in]])


#lables for wavelength in plot
#feats =  dsp.wavelength.values.astype(int)
#feats = map(str,feats)
#shap.force_plot(explainer.expected_value[y_test[sel_in]], shap_values.values[sel_ind,:,y_test[sel_in]], pd.DataFrame(X_test).iloc[sel_ind, :],link='logit',feature_names=list(feats))

In [None]:
shap.plots.beeswarm(shap_values[:,:,y_test[sel_in]].abs, color="shap_red")

In [None]:
# Plot the normalised spectrum of test sample sel_in, shaded by the absolute
# SHAP importance of each wavelength for that sample's true class.
importance = np.abs(shap_values[sel_in, :, y_test[sel_in]].values)
wavelength = dsp['wavelength'].values
# importance[i] must belong to wavelength[i]; both follow the feature order of X_test
assert importance.shape == wavelength.shape

fig, ax = plt.subplots(figsize=(12, 6))

# Reversed 'hot' colormap: light = low importance, dark = high importance
cmap = plt.get_cmap('hot').reversed()
# Map importance values to [0, 1] for the colormap
norm = plt.Normalize(importance.min(), importance.max())

# Shade each interval between consecutive bands by the lower band's importance
for i in range(len(wavelength) - 1):
    ax.fill_between([wavelength[i], wavelength[i + 1]], 0, 1,
                    color=cmap(norm(importance[i])), alpha=0.3)

# Colorbar showing the importance scale
sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
sm.set_array([])
fig.colorbar(sm, ax=ax, label='Importance')

# White-out the band ranges that were removed before training.
# NOTE(review): 1458 (vs the 1450 cut-off used when dropping bands) presumably
# also hides the shading that bridges the gap to the next kept band - confirm.
for lo, hi in [(0, 420), (1340, 1458), (1800, 1980), (2400, 2500)]:
    ax.fill_between([lo, hi], 0, 1, color='white')
ax.set_xlim(420, 2400)

plot_xr = xr.DataArray(X_test[sel_in], coords=[wavelength], dims=["wavelength"])
plot_xr.plot.line(x='wavelength', color='green', ylim=(0, 0.2), ax=ax, zorder=0)

# Set labels on the explicit axes (after xarray's own labelling) rather than
# through pyplot's "current axes", which the colorbar could change
ax.set_title('Reflectance with Importance Shading')
ax.set_xlabel('Wavelength')
ax.set_ylabel('Reflectance (normalized)')

fig.tight_layout()
plt.show()

### 7. Prepare AVIRIS scenes for prediction

The prediction workflow in section 8 is built from two functions.

**`predict_proba_on_chunk(chunk, model)`**

- Applies the fitted model's `predict_proba` to one chunk of pixels (rows = pixels, columns = wavelengths).
- Returns one probability per class for every pixel.

**`predict_xr(file, geometries)`**

- Opens an AVIRIS-NG scene lazily (`engine='kerchunk'`, `chunks='auto'`).
- Records which pixels contain data (any band > 0) so the output can be masked later.
- Stacks `x` and `y` into a single `sample` dimension.
- Drops the same wavelength ranges that were removed from the training data and L2-normalises each spectrum, exactly as in section 4.
- Uses `xr.apply_ufunc` with `dask="parallelized"` to run `predict_proba_on_chunk` chunk by chunk:
  - `input_core_dims=[['wavelength']]` passes each pixel's spectrum to the model.
  - `output_core_dims=[['class']]` with `output_sizes={'class': n_classes}` returns one probability per class.
- Unstacks back to a `(class, y, x)` raster and sets its CRS to EPSG:32734.
- Clips the raster to the protected-area geometries and masks pixels without data.

**Usage**

Call `predict_xr` on one scene file (or on each file in a list). The result is backed by dask, so most of the work is deferred until the result is plotted or written to disk (section 9).

In [None]:
SAPAD = (gpd.read_file('SAPAD_2024.gpkg')
         .query("SITE_TYPE!='Marine Protected Area'")
        )
# Get the bounding box of the first GeoDataFrame
bbox = raw_data.total_bounds  # (minx, miny, maxx, maxy)
gdf_bbox = gpd.GeoDataFrame({'geometry': [box(*bbox)]}, crs=raw_data.crs)  # Specify the CRS
gdf_bbox['geometry'] = gdf_bbox.buffer(0.02)

# Filter the second GeoDataFrame to keep only the rows that intersect with the buffered bbox
SAPAD_CT = SAPAD.overlay(gdf_bbox,how='intersection')

In [None]:
SAPAD_CT.explore()

In [None]:
#keep only AVNG_CP that intersects with raw_data
AVNG_sapad = AVNG_CP[AVNG_CP.intersects(SAPAD_CT.union_all())]
files_sapad = AVNG_sapad['RFL s3'].tolist()
geometries_sapad = SAPAD_CT.to_crs("EPSG:32734").geometry.apply(mapping)

### 8. Predict over multiple AVIRIS scenes

In [None]:
def predict_proba_on_chunk(chunk, model):
    """Return class probabilities for every pixel in a chunk.

    Parameters
    ----------
    chunk : array-like
        Features along the LAST axis (one value per wavelength), with any number
        of leading sample axes. A 2-D ``(n_samples, n_features)`` chunk, as
        passed by ``xr.apply_ufunc``, is handled exactly as before.
    model : object
        Fitted classifier exposing a scikit-learn style ``predict_proba``.

    Returns
    -------
    numpy.ndarray
        Probabilities with shape ``chunk.shape[:-1] + (n_classes,)``.
    """
    chunk = np.asarray(chunk)
    lead_shape = chunk.shape[:-1]
    # predict_proba expects a 2-D (n_samples, n_features) matrix; flatten any
    # leading axes (C order keeps sample order) and restore them afterwards
    flat = chunk.reshape(int(np.prod(lead_shape)), chunk.shape[-1])
    probabilities = np.asarray(model.predict_proba(flat))
    return probabilities.reshape(lead_shape + (probabilities.shape[-1],))

In [None]:
# Scenes to predict on. Taking a copy means the pops below no longer also
# modify files_sapad, and re-running this cell starts again from the full list
# (previously every re-run of the pop cell removed two more scenes).
# For a quick test on a small subset instead:
# files_s = list(files_sapad[62:70])
files_s = list(files_sapad)
# NOTE(review): scenes removed by position (presumably unusable files). The
# second pop runs on the already-shortened list, so it removes the scene
# originally at index 88 - confirm this is intended.
files_s.pop(85)
files_s.pop(87)

# Number of probability columns the model outputs = number of distinct
# training labels seen by the label encoder (was hard-coded to 9)
n_classes = len(le.classes_)

In [None]:
def predict_xr(file, geometries, model=None):
    """Predict per-class probabilities for one AVIRIS-NG scene.

    The scene is preprocessed like the training spectra (same bands dropped,
    each pixel L2-normalised) and the model is applied chunk by chunk via
    ``xr.apply_ufunc`` with dask.

    Parameters
    ----------
    file : str
        Scene location, opened lazily with the kerchunk engine.
    geometries : iterable of GeoJSON-like dicts
        Polygons in EPSG:32734 to clip the prediction to.
    model : fitted classifier with ``predict_proba``, optional
        Defaults to the notebook's ``best_model``.

    Returns
    -------
    xarray.DataArray
        Probabilities with dims ``(class, y, x)``, clipped to ``geometries``
        and NaN where the scene has no data. Uses the global ``n_classes`` for
        the size of the ``class`` dimension.
    """
    if model is None:
        model = best_model
    print(f'file: {file}')
    ds = xr.open_dataset(file, engine='kerchunk', chunks='auto')

    # A pixel has data if any band is positive; used to mask the result later
    condition = (ds['reflectance'] > 0).any(dim='wavelength')
    # Flatten the image into a 1-D 'sample' dimension of pixels
    ds = ds.stack(sample=('x', 'y'))

    # Drop the same band ranges that were removed from the training data
    wl = ds.wavelength
    wavelengths_to_drop = wl.where(
        (wl < 420)
        | ((wl >= 1340) & (wl <= 1450))
        | ((wl >= 1800) & (wl <= 1980))
        | (wl > 2400),
        drop=True,
    )
    ds = ds.drop_sel(wavelength=wavelengths_to_drop)

    # Normalise each pixel's spectrum to unit L2 norm (dask-aware)
    l2_norm = np.sqrt((ds['reflectance'] ** 2).sum(dim='wavelength'))
    reflectance = ds['reflectance'] / l2_norm

    # dask="parallelized" requires the core dimension to be a single chunk;
    # chunks='auto' may split 'wavelength', so merge it (no-op if already one chunk)
    reflectance = reflectance.chunk({'wavelength': -1})

    # Apply the classifier chunk by chunk over the stacked pixels
    result = xr.apply_ufunc(
        predict_proba_on_chunk,
        reflectance,
        input_core_dims=[['wavelength']],   # features passed to the model
        output_core_dims=[['class']],       # one probability per class
        exclude_dims={'wavelength'},        # wavelength does not appear in the output
        output_sizes={'class': n_classes},
        output_dtypes=[np.float32],
        dask="parallelized",
        kwargs={'model': model},
    )
    # Back to a georeferenced raster, clipped to the polygons and masked to valid pixels
    result = result.unstack('sample')
    result = result.rio.set_spatial_dims(x_dim='x', y_dim='y')
    result = result.rio.write_crs("EPSG:32734")
    result = result.rio.clip(geometries).where(condition)
    return result.transpose('class', 'y', 'x')

In [None]:
test  = predict_xr(files_s[53],geometries_sapad)
test

In [None]:
test = test.rio.reproject("EPSG:4326")

In [None]:
test.isel({'class':0}).hvplot(tiles=hv.element.tiles.EsriImagery(), 
                              project=True,rasterize=True,robust=True,
                              cmap='magma',frame_width=400,data_aspect=1,alpha=0.5)

In [None]:
grid_pred = [predict_xr(fi,geometries_sapad) for fi in files_s]

### 9. Merge and mosaic results

In [None]:
from rioxarray.merge import merge_arrays

In [None]:
merged = merge_arrays(grid_pred)
merged = merged.rio.reproject("EPSG:4326")
merged.rio.to_raster('/home/gmoncrieff/ct_invasive.tiff',driver="COG")

In [None]:
merged = xr.open_dataset('/home/gmoncrieff/ct_invasive.tiff', engine='rasterio', chunks='auto')

In [None]:
merged

In [None]:
merged.isel({'band':6}).hvplot(x='x',y='y',tiles=hv.element.tiles.EsriImagery(),
                               geo=True,
                                project=True,rasterize=True,robust=True,
                                cmap='magma',clim=(0,1), frame_width=400,data_aspect=1,alpha=0.5)

> return to step 1