# Script to classify years based on a model trained with random points

## Importing

In [None]:
import xarray as xr
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from sklearn.model_selection import train_test_split

from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler

from sklearn.tree import DecisionTreeClassifier
from sklearn.ensemble import BaggingClassifier

from sklearn.metrics import confusion_matrix, accuracy_score, ConfusionMatrixDisplay

from tqdm import tqdm

import random

import salishsea_tools.viz_tools as sa_vi


## Datasets Preparation

In [None]:
def datasets_preparation(dataset, dataset2, dataset3):
    """Flatten the driver variables and cluster labels, dropping NaN points.

    Seven driver fields are flattened and stacked row-wise: four read from
    ``dataset`` (two temperature and two salinity layers) followed by three
    read from ``dataset2`` (solar radiation, wind speed, air temperature).
    Any flattened position where at least one driver is NaN is discarded.

    Parameters
    ----------
    dataset : mapping
        Provides the 'Temperature_*' and 'Salinity_*' fields.
    dataset2 : mapping
        Provides 'Summation_of_solar_radiation', 'Mean_wind_speed' and
        'Mean_air_temperature'.
    dataset3 : mapping
        Provides 'Clusters_Diatom_Sorted'.

    Returns
    -------
    tuple
        ``(drivers, diat, indx)`` where ``drivers`` has one row per driver
        and one column per retained point, ``diat`` holds the cluster labels
        at the same points, and ``indx`` is the ``np.where`` tuple whose
        first element lists the retained flat positions.
    """
    ocean_fields = (
        'Temperature_(0m-15m)',
        'Temperature_(15m-100m)',
        'Salinity_(0m-15m)',
        'Salinity_(15m-100m)',
    )
    forcing_fields = (
        'Summation_of_solar_radiation',
        'Mean_wind_speed',
        'Mean_air_temperature',
    )

    # One flattened row per driver, ocean fields first.
    rows = [np.ravel(dataset[field]) for field in ocean_fields]
    rows += [np.ravel(dataset2[field]) for field in forcing_fields]
    stacked = np.stack(rows)

    # Keep only the positions where every driver has a value.
    any_missing = np.isnan(stacked).any(axis=0)
    indx = np.where(~any_missing)
    keep = indx[0]

    drivers = stacked[:, keep]
    diat = np.ravel(dataset3['Clusters_Diatom_Sorted'])[keep]

    return (drivers, diat, indx)

## Classifier

In [None]:
def classifier(inputs, targets, *, n_estimators=10, n_jobs=10, random_state=None):
    """Fit a bagged ensemble of scaled decision trees.

    Each base estimator is a pipeline of ``StandardScaler`` followed by
    ``DecisionTreeClassifier``; the pipelines are combined with
    ``BaggingClassifier`` and fitted on the supplied data.

    Parameters
    ----------
    inputs : array-like
        Driver matrix laid out as (n_features, n_samples), i.e. the layout
        returned by ``datasets_preparation``; it is transposed before
        fitting.
    targets : array-like
        Class label for each sample.
    n_estimators : int, optional
        Number of bagged pipelines. Defaults to 10, the value that was
        previously hard-coded.
    n_jobs : int, optional
        Number of parallel jobs used by the ensemble. Defaults to 10, the
        value that was previously hard-coded.
    random_state : int, RandomState instance or None, optional
        Seed forwarded to ``BaggingClassifier`` so that a training run can
        be reproduced. The default ``None`` keeps the original unseeded
        behaviour.

    Returns
    -------
    BaggingClassifier
        The fitted ensemble.
    """
    # scikit-learn estimators expect samples along the first axis.
    samples = inputs.transpose()

    base_model = make_pipeline(StandardScaler(), DecisionTreeClassifier())

    ensemble = BaggingClassifier(
        base_model,
        n_estimators=n_estimators,
        n_jobs=n_jobs,
        random_state=random_state,
    )

    return ensemble.fit(samples, targets)

## Classifier 2

In [None]:
def classifier2 (inputs, targets, variable_name):
    """Predict clusters for one map, regrid the result and plot it.

    Relies on module-level names set by the calling cell: ``clsf`` (the
    fitted classifier), ``diat_i`` (supplies the ``y``/``x`` coordinates of
    the grid) and ``indx`` (flat positions of the points that were kept by
    ``datasets_preparation``).

    Parameters
    ----------
    inputs : array-like
        Driver matrix laid out as (n_features, n_samples); transposed
        before prediction.
    targets : array-like
        Reference cluster labels for the same samples.
    variable_name : str
        Stored as the ``long_name`` of the predicted map and forwarded to
        ``plotting3``.

    Returns
    -------
    None
        The function only produces a figure through ``plotting3``.
    """
    predicted = clsf.predict(inputs.transpose())
    cm = confusion_matrix(targets, predicted)

    # Scatter the predictions back onto the full grid; positions that were
    # dropped during preparation stay NaN.
    y_coord, x_coord = diat_i.y, diat_i.x
    n_rows, n_cols = len(y_coord), len(x_coord)
    flat_map = np.full(n_rows * n_cols, np.nan)
    flat_map[indx[0]] = predicted
    gridded = flat_map.reshape((n_rows, n_cols))

    # NOTE(review): the units attribute is a concentration unit although the
    # values are cluster labels - confirm whether this is intended.
    model = xr.DataArray(
        gridded,
        coords={'y': y_coord, 'x': x_coord},
        dims=['y', 'x'],
        attrs={'long_name': variable_name, 'units': "mmol m-2"},
    )

    plotting3(targets, model, diat_i, variable_name, cm)

## Plotting

In [None]:
def plotting(variable, name):
    """Plot ``variable`` against the module-level ``years`` as dots only.

    Parameters
    ----------
    variable : array-like
        Values drawn on the y axis, one per entry of ``years``.
    name : str
        Label for the y axis.

    Notes
    -----
    ``years`` is read from module scope; it is not defined in the visible
    part of this file, so it must exist before this function is called.
    """
    dots_only = {'marker': '.', 'linestyle': ''}
    plt.plot(years, variable, **dots_only)

    plt.xlabel('Years')
    plt.ylabel(name)

    plt.show()

## Plotting 2

In [None]:
def plotting2(variable,title):
    """Scatter a daily series against the module-level ``dates``.

    Points are coloured by calendar month and the legend labels the colour
    groups as February, March and April.

    Parameters
    ----------
    variable : array-like
        Values drawn on the y axis, one per entry of ``dates``.
    title : str
        Inserted into the figure title between 'Daily ' and the date range.
    """
    fig, ax = plt.subplots()

    # Colour each point by the month of its date.
    month_of_point = pd.DatetimeIndex(dates).month
    points = ax.scatter(dates, variable, marker='.', c=month_of_point)

    month_handles, _ = points.legend_elements()
    ax.legend(handles=month_handles, labels=['February','March','April'])

    fig.suptitle('Daily ' + title + ' (15 Feb - 30 Apr)')

    fig.show()

## Plotting 3

In [None]:
def plotting3(targets, model, variable, variable_name, cm):
    """Draw the 2x2 summary figure for one daily map.

    Panels: reference clustering (top left), predicted classification (top
    right), the module-level ``diat_i2`` field (bottom left) and the
    confusion matrix (bottom right). The figure title is the date
    ``dates[i]``, where ``dates`` and ``i`` are read from module scope.

    Parameters
    ----------
    targets : array-like
        Accepted for call compatibility; not used inside this function.
    model : xarray.DataArray
        Predicted cluster map.
    variable : xarray.DataArray
        Reference cluster map.
    variable_name : str
        Prefix used in the panel titles and in the colourbar label of the
        bottom-left panel.
    cm : array-like
        Confusion matrix to display.
    """
    disp = ConfusionMatrixDisplay(confusion_matrix=cm)

    fig, ax = plt.subplots(2,2, figsize = (10,15))

    # Discrete palette for the cluster maps; missing cells are drawn gray.
    n_levels = np.max(np.arange(0,5))+1
    cluster_cmap = plt.get_cmap('Paired', n_levels)
    cluster_cmap.set_bad('gray')

    for panel, field in ((ax[0,0], variable), (ax[0,1], model)):
        field.plot(
            ax=panel,
            cmap=cluster_cmap,
            vmax=n_levels,
            cbar_kwargs={'label': 'Cluster [count]', 'ticks': np.arange(0,5)},
        )

    # Continuous palette for the concentration map.
    conc_cmap = plt.get_cmap('cubehelix')
    conc_cmap.set_bad('gray')

    conc_label = variable_name + ' Concentration  [mmol m-2]'
    diat_i2.plot(ax=ax[1,0], cmap=conc_cmap, cbar_kwargs={'label': conc_label})
    disp.plot(ax=ax[1,1])

    plt.subplots_adjust(left=0.1, bottom=0.1, right=0.95, top=0.95,
        wspace=0.35, hspace=0.35)

    # Only the three map panels get the map aspect ratio.
    for panel in (ax[0,0], ax[0,1], ax[1,0]):
        sa_vi.set_aspect(panel)

    panel_titles = {
        (0, 0): variable_name + 'Clustering',
        (0, 1): variable_name + 'Classification',
        (1, 0): 'Diatom Concentration',
        (1, 1): 'Confusion Matrix',
    }
    for position, text in panel_titles.items():
        ax[position].title.set_text(text)

    fig.suptitle(str(dates[i].date()))

    plt.show()


## Training (Random Points)

In [None]:
# Open the three input files. ds and ds2 supply the driver variables and ds3
# supplies the diatom cluster labels (see datasets_preparation).
ds = xr.open_dataset('/data/ibougoudis/MOAD/files/integrated_model_var_old.nc')
ds2 = xr.open_dataset('/data/ibougoudis/MOAD/files/external_inputs.nc')
ds3 = xr.open_dataset('/data/ibougoudis/MOAD/files/clustering_diatom_sorted.nc')

# Disabled sub-sampling (every 2nd time step, every 5th grid point), kept for
# reference.
# ds = ds.isel(time_counter = (np.arange(0, len(ds.time_counter),2)), 
#     y=(np.arange(ds.y[0], ds.y[-1], 5)), 
#     x=(np.arange(ds.x[0], ds.x[-1], 5)))

# ds2 = ds2.isel(time_counter = (np.arange(0, len(ds2.time_counter),2)), 
#     y=(np.arange(ds2.y[0], ds2.y[-1], 5)), 
#     x=(np.arange(ds2.x[0], ds2.x[-1], 5)))

# ds3 = ds3.isel(time_counter = (np.arange(0, len(ds3.time_counter),2)), 
#     y=(np.arange(ds3.y[0], ds3.y[-1], 5)), 
#     x=(np.arange(ds3.x[0], ds3.x[-1], 5)))

# Training period: 2007 through 2020. ds, ds2 and ds3 are left untouched
# because the next cell selects a different period from them.
dataset = ds.sel(time_counter = slice('2007', '2020'))
dataset2 = ds2.sel(time_counter = slice('2007', '2020'))
dataset3 = ds3.sel(time_counter = slice('2007', '2020'))

# The returned index tuple is not needed for training, so it is discarded.
drivers, diat, _ = datasets_preparation(dataset, dataset2, dataset3)

# clsf is read as a module-level name by classifier2.
clsf = classifier(drivers, diat)

## Daily Maps

In [None]:
# Restrict the three datasets to 2021 through 2023, a period outside the
# 2007-2020 training selection.
ds = ds.sel(time_counter = slice('2021', '2023'))
ds2 = ds2.sel(time_counter = slice('2021', '2023'))
ds3 = ds3.sel(time_counter = slice('2021', '2023'))
# dates is read as a module-level name by plotting3 for the figure title.
dates = pd.DatetimeIndex(ds['time_counter'].values)

# Pick 10 distinct time indices at random. No seed is set, so the selection
# changes from run to run.
maps = random.sample(range(0,len(ds.time_counter)),10)

# The loop variable must stay named i: plotting3 reads it from module scope
# to look up dates[i].
for i in tqdm(maps):

    dataset = ds.isel(time_counter=i)
    dataset2 = ds2.isel(time_counter=i)
    dataset3 = ds3.isel(time_counter=i)

    # indx is read as a module-level name by classifier2 to place the
    # predictions back on the grid.
    drivers, diat, indx = datasets_preparation(dataset, dataset2, dataset3)

    # diat_i (read by classifier2) and diat_i2 (read by plotting3) are also
    # picked up from module scope rather than passed as arguments.
    diat_i = dataset3['Clusters_Diatom_Sorted']
    diat_i2 = dataset['Diatom']

    classifier2(drivers, diat, 'Diatom ')
