# Classification of Zooplankton Type

## Importing

In [None]:
import xarray as xr
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from sklearn.pipeline import make_pipeline
from sklearn.compose import ColumnTransformer
from sklearn.preprocessing import KBinsDiscretizer
from sklearn.preprocessing import MinMaxScaler
from sklearn.preprocessing import StandardScaler

from sklearn.ensemble import HistGradientBoostingClassifier
from sklearn.metrics import classification_report, precision_recall_fscore_support, confusion_matrix, ConfusionMatrixDisplay
from sklearn.inspection import permutation_importance

from tqdm import tqdm
import cmocean.cm as cm
import salishsea_tools.viz_tools as sa_vi


## Datasets Preparation

In [None]:
# Creation of the training - testing datasets
def datasets_preparation(dataset, classes, inputs_names):
    """Flatten the selected variables and class labels into model arrays.

    Parameters
    ----------
    dataset
        Object exposing ``x``, ``y`` and ``time_counter`` (only their lengths
        and the values of ``x``/``y`` are used) and returning, for each name
        in ``inputs_names``, something with a ``to_numpy()`` method.
        Presumably an ``xarray.Dataset`` laid out as (time_counter, y, x) --
        TODO confirm against the input files.
    classes
        Array-like of class labels; it is ravelled and must line up
        element-for-element with the flattened input variables.
    inputs_names
        Names of the variables to use as features, in column order.

    Returns
    -------
    tuple
        ``(inputs, targets, indx)`` where ``inputs`` has one row per kept
        sample and one column per name in ``inputs_names``, ``targets`` holds
        the matching labels, and ``indx`` is the ``np.where`` result (a tuple
        of index arrays) locating the kept samples in the flattened data.
    """
    n_time = len(dataset.time_counter)
    n_y = len(dataset.y)
    n_x = len(dataset.x)

    # Per-sample x and y coordinate values, laid out with x varying fastest,
    # then y, then time.
    x = np.tile(dataset.x, n_time * n_y)
    y = np.tile(np.repeat(dataset.y, n_x), n_time)

    # One flattened row per requested variable.
    features = np.array(
        [dataset[name].to_numpy().flatten() for name in inputs_names]
    )
    labels = np.ravel(classes)

    # Keep samples with a finite label and x > 10, additionally requiring
    # either x > 100 or y < 880.
    keep = np.isfinite(labels) & (x > 10) & ((x > 100) | (y < 880))
    indx = np.where(keep)

    inputs = features[:, indx[0]].transpose()
    targets = labels[indx[0]]

    return (inputs, targets, indx)


## Plotting (Regions)

In [None]:
def plot_box(ax, corn, colour):
    """Draw the outline of a rectangle on ``ax``.

    ``corn`` is read as ``[y_start, y_end, x_start, x_end]``: entries 0 and 1
    give the two y values and entries 2 and 3 the two x values. The outline is
    drawn as a solid line in ``colour`` using ``ax.plot``.
    """
    y_start, y_end = corn[0], corn[1]
    x_start, x_end = corn[2], corn[3]

    # Visit the four corners and come back to the first to close the outline.
    outline_x = [x_start, x_end, x_end, x_start, x_start]
    outline_y = [y_start, y_start, y_end, y_end, y_start]

    ax.plot(outline_x, outline_y, '-', color=colour)
    

## New Classes

In [None]:
# Clustering output; its Z1 and Z2 variables are combined into one label below.
ds = xr.open_dataset('/data/ibougoudis/MOAD/files/clustering/feb_apr_yearly_clustering.nc')

# Values below 5 are collapsed to 0; every other value is kept as it is.
z1 = xr.where(ds['Z1'] < 5, 0, ds['Z1'])
z2 = xr.where(ds['Z2'] < 5, 0, ds['Z2'])

z1_high, z1_none = (z1 == 5), (z1 == 0)
z2_high, z2_none = (z2 == 5), (z2 == 0)

# The three conditions are mutually exclusive, so the order in which they are
# applied does not matter. Cells matching none of them keep their z1 value.
classes = xr.where(z1_high & z2_none, 1, z1)        # z1 high
classes = xr.where(z2_high & z1_none, 2, classes)   # z2 high
classes = xr.where(z1_high & z2_high, 3, classes)   # z1 and z2 high
# classes = xr.where(classes==0, np.nan, classes)

# Low resolution

# classes = classes.isel(y=(np.arange(classes.y[0], classes.y[-1], 5)), 
#     x=(np.arange(classes.x[0], classes.x[-1], 5)))


## Initiation

In [None]:
filename = '/data/ibougoudis/MOAD/files/inputs/feb_apr.nc'

# drivers =  ['Summation_of_solar_radiation', 'Summation_of_longwave_radiation', 'Mean_precipitation', 'Mean_pressure', 'Mean_air_temperature', 'Mean_specific_humidity', 'Mean_wind_speed']

drivers =  ['Summation_of_solar_radiation', 'Summation_of_longwave_radiation', 'Mean_precipitation', 'Mean_air_temperature']
spatial = ['Latitude', 'Longitude']
day_input = ['Day_of_year']

# Column order of the model inputs: drivers first, then spatial, then day.
inputs_names = drivers + spatial + day_input

n_bins = 255

# Period metadata, keyed by the first seven characters of the file's base name:
# (period text, period id, month names).
period_info = {
    'jan_mar': ('(16 Jan - 31 Mar)', '1', ['January', 'February', 'March']),  # 75 days, 1st period
    'jan_apr': ('(01 Jan - 30 Apr)', '2', ['January', 'February', 'March', 'April']),  # 120 days, 2nd period
    'feb_apr': ('(15 Feb - 30 Apr)', '3', ['February', 'March', 'April']),  # 75 days, 3rd period
    'apr_jun': ('(16 Apr - 30 Jun)', '4', ['April', 'May', 'June']),  # 76 days, 4th period
    'may_sep': ('(01 May - 30 Sep)', '5', ['May', 'June', 'July', 'August', 'September']),  # 153 days, 5th period
}

# Take the key from the base name instead of fixed character offsets, so the
# lookup still works when the file lives under a different directory.
period_key = filename.rsplit('/', 1)[-1][:7]

# As before, an unrecognised file leaves period / id / months undefined.
# NOTE(review): `id` shadows the builtin; the name is kept because other cells
# may read it.
if period_key in period_info:
    period, id, months = period_info[period_key]

ds = xr.open_dataset(filename)
ds0 = ds # For the regional plot

# Low resolution

# ds = ds.isel(y=(np.arange(ds.y[0], ds.y[-1], 5)), 
#     x=(np.arange(ds.x[0], ds.x[-1], 5)))


## Regions

In [None]:
bathy = xr.open_dataset('/home/sallen/MEOPAR/grid/bathymetry_202108.nc')

# Bathymetry map used as the background for the region boxes.
fig, ax = plt.subplots(1, 1, figsize=(5, 9))
mycmap = cm.deep
mycmap.set_bad('grey')
ax.pcolormesh(bathy.Bathymetry, cmap=mycmap)
sa_vi.set_aspect(ax)

# Each box is [y_start, y_end, x_start, x_end]: plot_box reads it that way and
# the same four numbers slice regions0 (rows = y, columns = x) below.
SoG_north = [650, 730, 100, 200]
SoG_center = [450, 550, 200, 300]
Fraser_plume = [380, 460, 260, 330]
SoG_south = [320, 380, 280, 350]
Haro_Boundary = [290, 350, 210, 280]
JdF_west = [250, 425, 25, 125]
JdF_east = [200, 290, 150, 260]
PS_all = [0, 200, 80, 320]
PS_main = [20, 150, 200, 280]

boxes = [SoG_north,SoG_center,Fraser_plume,SoG_south,Haro_Boundary,JdF_west,JdF_east,PS_all,PS_main]
box_colours = ['g', 'b', 'm', 'k', 'm', 'c', 'w', 'm', 'r']

# Draw the outlines in list order; the legend labels below rely on that order.
for box, box_colour in zip(boxes, box_colours):
    plot_box(ax, box, box_colour)

boxnames = ['GN','GC','FP','GS', 'HB', 'JdFW', 'JdFE', 'PSA', 'PSM']
fig.legend(boxnames)

# Grid of region indices: NaN outside every box, otherwise the position of the
# box in `boxes`. Where boxes overlap, the one later in the list wins.
regions0 = np.full((len(ds0.y), len(ds0.x)), np.nan)
for i, box in enumerate(boxes):
    regions0[box[0]:box[1], box[2]:box[3]] = i

regions0 = xr.DataArray(regions0, dims=['y', 'x'])

# Low resolution

# temp = []
# for i in boxes:

#     temp.append([x//5 for x in i])

# boxes = temp

# regions0 = regions0.isel(y=(np.arange(regions0.y[0], regions0.y[-1], 5)), 
#     x=(np.arange(regions0.x[0], regions0.x[-1], 5)))


## Training

In [None]:
# Years used to fit the classifier.
dataset = ds.sel(time_counter = slice('2007', '2020'))
classes2 = classes.sel(time_counter = slice('2007', '2020'))

# Unique calendar-day labels (e.g. '15 Feb') sorted chronologically.
labels = np.unique(dataset.time_counter.dt.strftime('%d %b'))
# Parse against a leap reference year: with no year in the format the parser
# defaults to 1900, where '29 Feb' does not exist and parsing would fail.
# The reference year does not change the ordering of the other days.
indx_labels = np.argsort(pd.to_datetime([f'{label} 2000' for label in labels], format='%d %b %Y'))
labels = labels[indx_labels]

inputs, targets, indx = datasets_preparation(dataset, classes2, inputs_names)

# Drivers are standardised; any remaining columns are passed through unchanged.
driver_columns = np.arange(0, len(drivers))

if not spatial:
    # No spatial inputs: the column right after the drivers is treated as
    # categorical by the classifier.
    model = make_pipeline(
        ColumnTransformer(
            transformers=[('drivers', StandardScaler(), driver_columns)],
            remainder='passthrough'),
        HistGradientBoostingClassifier(categorical_features=[len(drivers)]))

else:
    # Spatial inputs are quantile-binned into ordinal codes; they and every
    # column after them are treated as categorical by the classifier.
    first_spatial = inputs_names.index(spatial[0])
    last_spatial = inputs_names.index(spatial[-1])
    model = make_pipeline(
        ColumnTransformer(
            transformers=[
                ('drivers', StandardScaler(), driver_columns),
                ('spatial', KBinsDiscretizer(n_bins=n_bins,encode='ordinal',strategy='quantile'), np.arange(first_spatial, last_spatial + 1))],
            remainder='passthrough'),
        HistGradientBoostingClassifier(categorical_features=np.arange(first_spatial, len(inputs_names))))

clf = model.fit(inputs,targets)
# Predictions on the training samples themselves, evaluated in the next cell.
predictions = clf.predict(inputs)


## Training Years

In [None]:
# Display names for class labels 0, 1, 2 and 3, in that order.
targets_names = ['no zooplankton', 'high Z1', 'high Z2', 'high Z1 & Z2']

print('Classification Report (Training)')
print(classification_report(targets, predictions, target_names=targets_names))
metrics = precision_recall_fscore_support(targets, predictions, labels = [0,1,2,3])

# Stored as conf_matrix rather than cm: `cm` is the alias of cmocean.cm used
# for colormaps in the Regions cell, and overwriting it breaks that cell on
# re-run.
conf_matrix = confusion_matrix(y_true=targets, y_pred=predictions, labels=[0,1,2,3])
ConfusionMatrixDisplay(confusion_matrix=conf_matrix, display_labels=targets_names).plot(colorbar=False)
plt.title('Confusion Matrix for training')
plt.show()


## Feature Importance (Training)

In [None]:
# NOTE(review): this cell is disabled. Before re-enabling it, update
# inputs_names2: it lists 10 short names, while the active `drivers` list gives
# the model 7 input columns (4 drivers, 2 spatial, day of year). The loops over
# len(inputs_names2) would therefore not line up with the importances.

# importances_all = permutation_importance(clf, inputs, targets, n_repeats=5, scoring=('accuracy','recall_macro'))

# inputs_names2 = ['SWR', 'LWR', 'TP', 'AP', 'AT', 'SH', 'WS', 'Lat', 'Lon', 'Day']

# sorted_importances_idx = importances_all['accuracy'].importances_mean.argsort()

# x = []
# for i in range (0, len(inputs_names2)):
#     x.append(inputs_names2[sorted_importances_idx[i]])

# accuracy_importance = pd.Series(importances_all['accuracy'].importances_mean[sorted_importances_idx], index=x)

# sorted_importances_idx = importances_all['recall_macro'].importances_mean.argsort()

# x = []
# for i in range (0, len(inputs_names2)):
#     x.append(inputs_names2[sorted_importances_idx[i]])

# recall_importance = pd.Series(importances_all['recall_macro'].importances_mean[sorted_importances_idx], index=x)

# accuracy_importance.plot.bar()
# plt.show()

# recall_importance.plot.bar()
# plt.show()


## Testing Years

In [None]:
# Held-out years used only for evaluation.
dataset = ds.sel(time_counter = slice('2021', '2024'))
classes2 = classes.sel(time_counter = slice('2021', '2024'))

inputs_test, targets_test, indx_test = datasets_preparation(dataset, classes2, inputs_names)

predictions_test = clf.predict(inputs_test)

print('Classification Report (Testing)')
print(classification_report(targets_test, predictions_test, target_names=targets_names))
# Metrics and confusion matrix are computed from the test targets and
# predictions (the training arrays were used here before).
metrics = precision_recall_fscore_support(targets_test, predictions_test, labels = [0,1,2,3])

# Stored as conf_matrix rather than cm: `cm` is the alias of cmocean.cm used
# for colormaps in the Regions cell, and overwriting it breaks that cell on
# re-run.
conf_matrix = confusion_matrix(y_true=targets_test, y_pred=predictions_test, labels=[0,1,2,3])
ConfusionMatrixDisplay(confusion_matrix=conf_matrix, display_labels=targets_names).plot(colorbar=False)
plt.title('Confusion Matrix for testing')
plt.show()


## Feature Importance (Testing)

In [None]:
# NOTE(review): this cell is disabled. Before re-enabling it, update
# inputs_names2: it lists 10 short names, while the active `drivers` list gives
# the model 7 input columns (4 drivers, 2 spatial, day of year). The loops over
# len(inputs_names2) would therefore not line up with the importances.

# importances_all = permutation_importance(clf, inputs_test, targets_test, n_repeats=5, scoring=('accuracy','recall_macro'))

# inputs_names2 = ['SWR', 'LWR', 'TP', 'AP', 'AT', 'SH', 'WS', 'Lat', 'Lon', 'Day']

# sorted_importances_idx = importances_all['accuracy'].importances_mean.argsort()

# x = []
# for i in range (0, len(inputs_names2)):
#     x.append(inputs_names2[sorted_importances_idx[i]])

# accuracy_importance = pd.Series(importances_all['accuracy'].importances_mean[sorted_importances_idx], index=x)

# sorted_importances_idx = importances_all['recall_macro'].importances_mean.argsort()

# x = []
# for i in range (0, len(inputs_names2)):
#     x.append(inputs_names2[sorted_importances_idx[i]])

# recall_importance = pd.Series(importances_all['recall_macro'].importances_mean[sorted_importances_idx], index=x)

# accuracy_importance.plot.bar()
# plt.show()

# recall_importance.plot.bar()
# plt.show()
