# Classification of Zooplankton Type (Boxes)

## Importing

In [None]:
import xarray as xr
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from sklearn.base import clone
from sklearn.pipeline import make_pipeline
from sklearn.compose import ColumnTransformer
from sklearn.preprocessing import KBinsDiscretizer
from sklearn.preprocessing import MinMaxScaler
from sklearn.preprocessing import StandardScaler

from sklearn.ensemble import HistGradientBoostingClassifier
from sklearn.metrics import classification_report, precision_recall_fscore_support, confusion_matrix, ConfusionMatrixDisplay
from sklearn.inspection import permutation_importance

from tqdm import tqdm
import cmocean.cm as cm
import salishsea_tools.viz_tools as sa_vi


## Datasets Preparation

In [None]:
# Creation of the training - testing datasets
def datasets_preparation(dataset, classes, inputs_names, regions):
    """Flatten gridded inputs and class labels into per-sample arrays.

    Parameters
    ----------
    dataset
        Object exposing ``x``, ``y`` and ``time_counter`` coordinates and
        ``dataset[name].to_numpy()`` for every name in ``inputs_names``.
        Assumes each variable is laid out as (time_counter, y, x) so that
        flattening lines up with the tiled coordinates -- TODO confirm.
    classes
        Class labels; raveled to one value per (time, y, x) sample.
    inputs_names
        Names of the variables of ``dataset`` used as features, in order.
    regions
        Region index per grid cell; raveled and repeated once per time step.
        Presumably shaped (y, x) -- verify against caller.

    Returns
    -------
    tuple
        ``(inputs, targets, indx, regions)`` where ``inputs`` has shape
        (n_samples, n_features), ``targets`` and ``regions`` hold one value
        per kept sample, and ``indx`` is the ``np.where`` tuple of the kept
        positions in the flattened arrays.
    """
    # One shared count for every tiling below, so the coordinate, input,
    # target and region vectors always have the same length.
    n_times = len(dataset.time_counter)
    n_y = len(dataset.y)
    n_x = len(dataset.x)

    x = np.tile(dataset.x, n_times * n_y)
    y = np.tile(np.repeat(dataset.y, n_x), n_times)

    inputs = np.array(
        [dataset[name].to_numpy().flatten() for name in inputs_names]
    )
    targets = np.ravel(classes)
    regions = np.tile(np.ravel(regions), n_times)

    # Keep samples with a finite label that also satisfy
    # x > 10 and (x > 100 or y < 880).
    indx = np.where(np.isfinite(targets) & (x > 10) & ((x > 100) | (y < 880)))
    keep = indx[0]

    inputs = inputs[:, keep].transpose()
    targets = targets[keep]
    regions = regions[keep]

    return (inputs, targets, indx, regions)


## Plotting (Regions)

In [None]:
def plot_box(ax, corn, colour):
    """Draw the closed outline of a rectangle on ``ax``.

    ``corn[2]`` and ``corn[3]`` are used as the horizontal limits and
    ``corn[0]`` and ``corn[1]`` as the vertical limits of the rectangle,
    which is drawn as a solid line in ``colour``.
    """
    left, right = corn[2], corn[3]
    bottom, top = corn[0], corn[1]

    # Walk the four corners and return to the start to close the outline
    horizontal = [left, right, right, left, left]
    vertical = [bottom, bottom, top, top, bottom]

    ax.plot(horizontal, vertical, '-', color=colour)
    

## New Classes

In [None]:
ds = xr.open_dataset('/data/ibougoudis/MOAD/files/clustering/feb_apr_yearly_clustering.nc')

# Every cluster value below 5 is collapsed to 0; other values are kept as is
z1, z2 = (xr.where(ds[name] < 5, 0, ds[name]) for name in ('Z1', 'Z2'))

z1_high, z2_high = (z1 == 5), (z2 == 5)
z1_zero, z2_zero = (z1 == 0), (z2 == 0)

# Combine the two fields into one label, starting from z1:
#   1 -> z1 high only, 2 -> z2 high only, 3 -> both high
classes = z1
classes = xr.where(z1_high & z2_zero, 1, classes)
classes = xr.where(z1_high & z2_high, 3, classes)
classes = xr.where(z2_high & z1_zero, 2, classes)

# Low resolution

# classes = classes.isel(y=(np.arange(classes.y[0], classes.y[-1], 5)), 
#     x=(np.arange(classes.x[0], classes.x[-1], 5)))


## Initiation

In [None]:
filename = '/data/ibougoudis/MOAD/files/inputs/feb_apr.nc'

# drivers =  ['Summation_of_solar_radiation', 'Summation_of_longwave_radiation', 'Mean_precipitation', 'Mean_pressure', 'Mean_air_temperature', 'Mean_specific_humidity', 'Mean_wind_speed']

drivers =  ['Summation_of_solar_radiation', 'Summation_of_longwave_radiation', 'Mean_precipitation', 'Mean_pressure', 'Mean_air_temperature', 'Mean_specific_humidity', 'Mean_wind_speed']
spatial = ['Latitude', 'Longitude']
day_input = ['Day_of_year']

# Feature order matters: drivers first, then spatial, then the day input
inputs_names = drivers + spatial + day_input

n_bins = 255

# Period key -> (period description, id, months covered)
period_table = {
    'jan_mar': ('(16 Jan - 31 Mar)', '1', ['January', 'February', 'March']),  # 75 days, 1st period
    'jan_apr': ('(01 Jan - 30 Apr)', '2', ['January', 'February', 'March', 'April']),  # 120 days, 2nd period
    'feb_apr': ('(15 Feb - 30 Apr)', '3', ['February', 'March', 'April']),  # 75 days, 3rd period
    'apr_jun': ('(16 Apr - 30 Jun)', '4', ['April', 'May', 'June']),  # 76 days, 4th period
    'may_sep': ('(01 May - 30 Sep)', '5', ['May', 'June', 'July', 'August', 'September']),  # 153 days, 5th period
}

# Read the period key from the start of the file name itself rather than
# from fixed character offsets, so the lookup does not depend on the length
# of the directory part of the path.
period_key = filename.rsplit('/', 1)[-1][:7]

if period_key not in period_table:
    # Fail loudly instead of leaving period/id/months undefined (or stale
    # from an earlier run of this cell in the same session).
    raise ValueError(
        'Unrecognised period ' + repr(period_key) + ' in ' + repr(filename)
        + '; expected one of ' + ', '.join(period_table))

# NOTE: `id` shadows the builtin of the same name; the name is kept in case
# other cells refer to it.
period, id, months = period_table[period_key]

ds = xr.open_dataset(filename)
ds0 = ds # For the regional plot

# Low resolution

# ds = ds.isel(y=(np.arange(ds.y[0], ds.y[-1], 5)), 
#     x=(np.arange(ds.x[0], ds.x[-1], 5)))


## Regions

In [None]:
bathy = xr.open_dataset('/home/sallen/MEOPAR/grid/bathymetry_202108.nc')

fig, ax = plt.subplots(1, 1, figsize=(5, 9))
# set_bad() modifies the colormap in place, so work on a copy instead of
# altering the shared cm.deep object for every later use.
mycmap = cm.deep.copy()
mycmap.set_bad('grey')
ax.pcolormesh(bathy.Bathymetry, cmap=mycmap)
sa_vi.set_aspect(ax)

# Each box is [y start, y end, x start, x end] in grid indices
SoG_north = [650, 730, 100, 200]
SoG_center = [450, 550, 200, 300]
Fraser_plume = [380, 460, 260, 330]
SoG_south = [320, 380, 280, 350]
Haro_Boundary = [290, 350, 210, 280]
JdF_west = [250, 425, 25, 125]
JdF_east = [200, 290, 150, 260]
PS_all = [0, 200, 80, 320]
PS_main = [20, 150, 200, 280]

boxes = [SoG_north,SoG_center,Fraser_plume,SoG_south,Haro_Boundary,JdF_west,JdF_east,PS_all,PS_main]
boxnames = ['GN','GC','FP','GS', 'HB', 'JdFW', 'JdFE', 'PSA', 'PSM']
box_colours = ['g', 'b', 'm', 'k', 'm', 'c', 'w', 'm', 'r']

# Outlines are drawn in the same order as boxnames so that the legend
# entries line up with the boxes.
for box, colour in zip(boxes, box_colours):
    plot_box(ax, box, colour)

fig.legend(boxnames)

# Region index of every grid cell (NaN outside all boxes).
# NOTE(review): several boxes overlap (e.g. PS_main lies inside PS_all), and
# a later box overwrites an earlier one here, so each cell ends up in only
# the last box that contains it -- confirm this is intended.
regions0 = np.full((len(ds0.y),len(ds0.x)),np.nan)
for i, box in enumerate(boxes):
    regions0[box[0]:box[1], box[2]:box[3]] = i

regions0 = xr.DataArray(regions0,dims = ['y','x'])

# Low resolution

# temp = []
# for i in boxes:

#     temp.append([x//5 for x in i])

# boxes = temp

# regions0 = regions0.isel(y=(np.arange(regions0.y[0], regions0.y[-1], 5)), 
#     x=(np.arange(regions0.x[0], regions0.x[-1], 5)))


## Training

In [None]:
# Training years
dataset = ds.sel(time_counter = slice('2007', '2020'))
classes2 = classes.sel(time_counter = slice('2007', '2020'))

# Unique 'day month' labels of the period, sorted chronologically
labels = np.unique(dataset.time_counter.dt.strftime('%d %b'))
indx_labels = np.argsort(pd.to_datetime(labels, format='%d %b'))
labels = labels[indx_labels]

inputs, targets, indx, regions = datasets_preparation(dataset, classes2, inputs_names, regions0)

clf_all = []
predictions = np.full(targets.shape,np.nan) # size of targets without nans

# Unfitted template pipeline; a fresh copy of it is fitted for every box.
# After the ColumnTransformer the columns are ordered: scaled drivers,
# (binned spatial inputs,) then the passed-through remainder.
if not spatial:
    model = make_pipeline(ColumnTransformer(
        transformers=[('drivers', StandardScaler(), np.arange(0,len(drivers)))], remainder='passthrough'),
        HistGradientBoostingClassifier(categorical_features=[len(drivers)]))

else:
    model = make_pipeline(ColumnTransformer(
    transformers=[('drivers', StandardScaler(), np.arange(0,len(drivers))), 
        ('spatial', KBinsDiscretizer(n_bins=n_bins,encode='ordinal',strategy='quantile'), np.arange(inputs_names.index(spatial[0]),inputs_names.index(spatial[-1])+1))],
        remainder='passthrough'),
    HistGradientBoostingClassifier(categorical_features=np.arange(inputs_names.index(spatial[0]),len(inputs_names))))

for i in tqdm(range(len(boxes))):

    indx2 = np.where(regions==i)
    inputs2 = inputs[indx2[0],:]
    targets2 = targets[indx2[0]]

    # fit() returns the estimator it was called on, so fitting `model`
    # directly would make every entry of clf_all the same object, holding
    # only the last box's fit. Clone the template to get one independent
    # classifier per box.
    clf = clone(model).fit(inputs2,targets2)
    predictions[indx2[0]] = clf.predict(inputs2)

    clf_all.append(clf)


## Training Years

In [None]:
targets_names = ['no zooplankton', 'high Z1', 'high Z2', 'high Z1 & Z2']
# Class values matching targets_names, passed explicitly to every metric so
# a box in which one class never occurs is still reported with four classes
# instead of failing on a names/classes mismatch.
class_labels = [0, 1, 2, 3]

print('Classification Report (Training)')
print('\n')

metrics_all = []
cm_all = []

for i, boxname in enumerate(boxnames):

    print (boxname)

    indx2 = np.where(regions==i)
    predictions2 = predictions[indx2[0]]
    targets2 = targets[indx2[0]]

    print(classification_report(targets2, predictions2, labels=class_labels, target_names=targets_names))
    metrics = precision_recall_fscore_support(targets2, predictions2, labels=class_labels)
    metrics_all.append(np.array(metrics))

    # Not named `cm`: that name is the cmocean colormap module imported at
    # the top of the file and is still needed by the regions plot.
    conf_mat = confusion_matrix(targets2, predictions2, labels=class_labels)
    cm_all.append(conf_mat)
    ConfusionMatrixDisplay.from_predictions(targets2, predictions2, labels=class_labels, display_labels=targets_names, colorbar=False)
    plt.title('Confusion Matrix for ' + boxname)
    plt.show()

# Stack the per-box results, box index first
metrics_all = np.array(metrics_all)
cm_all = np.array(cm_all)


## Feature Importance (Training)

In [None]:
# accuracy_importance_all = []
# recall_importance_all = []

# inputs_names2 = ['SWR', 'LWR', 'TP', 'AP', 'AT', 'SH', 'WS', 'Lat', 'Lon', 'Day']

# for i in tqdm(range (0, len(boxes))):

#     indx2 = np.where(regions==i)
#     inputs2 = inputs[indx2[0],:]
#     targets2 = targets[indx2[0]]

#     importances_all = permutation_importance(clf_all[i], inputs2, targets2, n_repeats=5, scoring=('accuracy','recall_macro'))

#     accuracy_importance = pd.Series(importances_all['accuracy'].importances_mean)
#     accuracy_importance_all.append(accuracy_importance)

#     recall_importance = pd.Series(importances_all['recall_macro'].importances_mean)
#     recall_importance_all.append(recall_importance)

# accuracy_importance_all = np.array(accuracy_importance_all)

# plt.pcolormesh(np.transpose(accuracy_importance_all), cmap='cividis')
# plt.yticks(ticks=np.arange(0.5,len(inputs_names2)), labels=inputs_names2)
# plt.xticks(ticks=np.arange(0.5,len(boxnames)), labels=boxnames)
# plt.show()

# recall_importance_all = np.array(recall_importance_all)
# plt.pcolormesh(np.transpose(recall_importance_all), cmap='cividis')
# plt.yticks(ticks=np.arange(0.5,len(inputs_names2)), labels=inputs_names2)
# plt.xticks(ticks=np.arange(0.5,len(boxnames)), labels=boxnames)
# plt.show()


## Testing Years

In [None]:
# Testing years
dataset = ds.sel(time_counter = slice('2021', '2024'))
classes2 = classes.sel(time_counter = slice('2021', '2024'))

# NOTE: `regions` is rebound here to the regions of the testing samples, so
# it no longer matches the training `targets`/`predictions` after this cell.
inputs_test, targets_test, indx_test, regions = datasets_preparation(dataset, classes2, inputs_names, regions0)

predictions_test = np.full(targets_test.shape,np.nan) # size of targets without nans

print('Classification Report (Testing)')
print('\n')

# Created once, before the loop, so the results of every box are kept
# (resetting them per box would leave only the last box).
metrics_all_test = []
cm_all_test = []

for i in range(len(boxes)):

    indx2 = np.where(regions==i)
    inputs2 = inputs_test[indx2[0],:]
    targets2 = targets_test[indx2[0]]

    predictions2 = clf_all[i].predict(inputs2)
    predictions_test[indx2[0]] = predictions2

    print (boxnames[i])

    # Explicit labels keep the four target names aligned with the classes
    # even when a class is absent from a box.
    print(classification_report(targets2, predictions2, labels = [0,1,2,3], target_names=targets_names))
    metrics = precision_recall_fscore_support(targets2, predictions2, labels = [0,1,2,3])
    metrics_all_test.append(np.array(metrics))

    # Not named `cm`: that name is the cmocean colormap module imported at
    # the top of the file.
    conf_mat = confusion_matrix(targets2, predictions2, labels = [0,1,2,3])
    cm_all_test.append(conf_mat)
    ConfusionMatrixDisplay.from_predictions(targets2, predictions2, labels = [0,1,2,3], display_labels=targets_names, colorbar=False)
    plt.title('Confusion Matrix for ' + boxnames[i])
    plt.show()

# Stack the per-box results, box index first
metrics_all_test = np.array(metrics_all_test)
cm_all_test = np.array(cm_all_test)


## Feature Importance (Testing)

In [None]:
# Short feature labels used on the vertical axis of the plots
inputs_names2 = ['SWR', 'LWR', 'TP', 'AP', 'AT', 'SH', 'WS', 'Lat', 'Lon', 'Day']

accuracy_importance_all = []
recall_importance_all = []

for i in tqdm(range(len(boxes))):

    # Testing samples that belong to box i
    indx2 = np.where(regions==i)
    inputs2 = inputs_test[indx2[0],:]
    targets2 = targets_test[indx2[0]]

    # Permutation importance of every feature under both scorers
    importances_all = permutation_importance(
        clf_all[i], inputs2, targets2,
        n_repeats=5, scoring=('accuracy','recall_macro'))

    accuracy_importance = pd.Series(importances_all['accuracy'].importances_mean)
    recall_importance = pd.Series(importances_all['recall_macro'].importances_mean)

    accuracy_importance_all.append(accuracy_importance)
    recall_importance_all.append(recall_importance)

accuracy_importance_all = np.array(accuracy_importance_all)
recall_importance_all = np.array(recall_importance_all)

# One heat map per scorer (accuracy first, then macro recall):
# features on the vertical axis, boxes on the horizontal axis
for importance_matrix in (accuracy_importance_all, recall_importance_all):
    plt.pcolormesh(importance_matrix.T, cmap='cividis')
    plt.yticks(ticks=np.arange(0.5,len(inputs_names2)), labels=inputs_names2)
    plt.xticks(ticks=np.arange(0.5,len(boxnames)), labels=boxnames)
    plt.show()
