# Importing

In [None]:
import xarray as xr
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import salishsea_tools.viz_tools as sa_vi

from sklearn.neural_network import MLPRegressor
from sklearn.model_selection import train_test_split
from sklearn import preprocessing

from sklearn.metrics import r2_score
from sklearn.metrics import mean_squared_error as mse


# Datasets Preparation

In [None]:
def datasets_preparation():
    """Return the masked model fields used throughout the notebook.

    Reads the module-level ``ds``, ``ds_bio`` and ``mask`` objects and
    passes every field through ``.where(mask == 1)``, so only cells whose
    mask value equals 1 keep their data.

    Returns
    -------
    tuple
        ``(temp, saline, sil, nitr, ammo, diat, flag, micro, meso)``,
        in that order.
    """
    # Same condition for every field, so evaluate it once
    wet = mask == 1

    fields = (
        # Driver variables
        ds.votemper,
        ds.vosaline,
        # Biological variables
        ds_bio.silicon,
        ds_bio.nitrate,
        ds_bio.ammonium,
        # Phytoplankton variables
        ds_bio.diatoms,
        ds_bio.flagellates,
        # Zooplankton variables
        ds_bio.microzooplankton,
        ds_bio.mesozooplankton,
    )

    return tuple(field.where(wet) for field in fields)


# MLPRegressor

In [None]:
def regressor(inputs, targets, variable_name):
    """Train an MLP that maps the input fields to one target field.

    Grid points whose inputs contain NaN are dropped, the remaining inputs
    are standardised, and an ``MLPRegressor`` is fitted on the training
    part of a train/test split. Scatter plots (test subset and full set)
    and maps (via ``plotting``) are produced as a side effect.

    Parameters
    ----------
    inputs : numpy.ndarray
        Array of shape ``(n_features, n_points)``; each row is one
        flattened input field.
    targets : xarray.DataArray
        2-D field with dimensions ``y`` and ``x`` whose flattened size is
        ``n_points``.
    variable_name : str
        Name of the target variable, used in titles and labels.

    Returns
    -------
    MLPRegressor
        The fitted network. The scaler fitted on the inputs is stored on
        it as ``input_scaler_`` so that the identical standardisation can
        be applied to new inputs before calling ``predict``.
    """
    # Pre processing: keep only the grid points without NaN in any input
    valid = np.where(~np.isnan(inputs).any(axis=0))[0]
    features = inputs[:, valid].transpose()  # -> (n_valid, n_features)
    observed = np.ravel(targets)[valid]

    # Regressor
    scale = preprocessing.StandardScaler()
    features = scale.fit_transform(features)
    X_train, X_test, y_train, y_test = train_test_split(features, observed, random_state=1)
    regr = MLPRegressor(hidden_layer_sizes=100, activation='relu', solver='adam',
                        max_iter=1500, random_state=1,
                        learning_rate_init=0.01).fit(X_train, y_train)
    # The network is only meaningful together with the scaling it was
    # trained with, so keep both on the same object.
    regr.input_scaler_ = scale

    outputs_test = regr.predict(X_test)
    outputs = regr.predict(features)

    scatter_plot(y_test, outputs_test, variable_name + ' (Testing dataset)')
    scatter_plot(observed, outputs, variable_name)

    # Post processing: put the predictions back on the full grid, leaving
    # NaN where the inputs were missing. The grid shape comes from the
    # target field rather than a hard-coded size.
    flat = np.full(inputs.shape[1], np.nan)
    flat[valid] = outputs
    model = np.reshape(flat, targets.shape)

    # Preparation of the dataarray (coordinates taken from the target so
    # that both fields are defined on exactly the same grid)
    model = xr.DataArray(model,
                    coords = {'y': targets.y, 'x': targets.x},
                    dims = ['y','x'],
                    attrs=dict( long_name = variable_name + " Concentration",
                                units="mmol m-3"),
                        )

    plotting(observed, model, targets, variable_name)

    return regr

# Printing

In [None]:
def printing(targets, outputs, m):
    """Print summary statistics comparing ``outputs`` against ``targets``.

    Parameters
    ----------
    targets : array-like
        Reference values.
    outputs : array-like
        Predicted values, same length as ``targets``.
    m : float
        Slope of the best fitting line, computed by the caller.
    """
    print ('The slope of the best fitting line is ', np.round(m,3))
    # r2_score is the coefficient of determination, not a correlation
    # coefficient (it can be negative), so label it accordingly.
    print ('The coefficient of determination (R^2) is:', np.round(r2_score(targets, outputs),3))
    print (' The mean square error is:', np.round(mse(targets,outputs),5))


## Scatter Plot

In [None]:
def scatter_plot(targets, outputs, variable_name):
    """Scatter ``outputs`` against ``targets`` with fit and 1:1 lines.

    Prints the fit statistics through ``printing`` and shows a square
    figure containing the points, the least-squares line (red) and the
    identity line (dashed black). The title is built from the
    module-level ``date``, ``ds`` and ``depth``.

    Parameters
    ----------
    targets : array-like
        Values for the horizontal axis.
    outputs : array-like
        Values for the vertical axis.
    variable_name : str
        Appended to the figure title.
    """
    # Degree-1 polynomial fit: slope first, intercept second
    slope, intercept = np.polyfit(targets, outputs, deg=1)
    printing(targets, outputs, slope)

    fig, ax = plt.subplots()

    plt.scatter(targets, outputs, alpha=0.2, s=10)
    plt.xlabel('targets')
    plt.ylabel('outputs')

    # One common range for both axes, taken from the autoscaled limits
    x_range = ax.get_xlim()
    y_range = ax.get_ylim()
    lims = [np.min([x_range, y_range]), np.max([x_range, y_range])]

    # Fitted line y = slope * x + intercept
    plt.axline(xy1=(0, intercept), slope=slope, color='r')

    ax.set_aspect('equal')
    ax.set_xlim(lims)
    ax.set_ylim(lims)

    # Identity line for reference
    ax.plot(lims, lims, linestyle='--', color='k')

    depth_label = str(np.round(ds['deptht'][depth].values, 2))
    title = str(date.date[0]) + ', Depth: ' + depth_label + ' meters' + ', ' + variable_name
    fig.suptitle(title)

    plt.show()


# Plotting

In [None]:
def plotting(targets, model, variable, variable_name):
    """Show maps of the target field, the model field and their difference.

    Three panels of a 2x2 figure are used: the target, the model output
    (both with the colour range of ``targets``) and the percentage
    difference ``(variable - model) / variable * 100``. The fourth panel
    is switched off. The title is built from the module-level ``date``,
    ``ds`` and ``depth``.

    Parameters
    ----------
    targets : array-like
        Values whose ``min()`` and ``max()`` set the shared colour range.
    model : xarray.DataArray
        Predicted field.
    variable : xarray.DataArray
        Reference field on the same grid as ``model``.
    variable_name : str
        Used in panel titles and colourbar labels.
    """
    import copy

    fig, ax = plt.subplots(2,2, figsize = (10,15))

    # Work on a private copy: depending on the matplotlib version,
    # get_cmap may hand back the shared registered colormap, and set_bad
    # would then change it for every other plot in the session.
    cmap = copy.copy(plt.get_cmap('cubehelix'))
    cmap.set_bad('gray')

    # Shared colour range for the two concentration panels
    vmin = targets.min()
    vmax = targets.max()

    variable.plot(ax=ax[0,0], cmap=cmap, vmin = vmin, vmax = vmax, cbar_kwargs={'label': variable_name + ' Concentration  [mmol m-3]'})
    model.plot(ax=ax[0,1], cmap=cmap, vmin = vmin, vmax = vmax, cbar_kwargs={'label': variable_name + ' Concentration  [mmol m-3]'})
    ((variable-model)/variable*100).plot(ax=ax[1,0], cmap=cmap, cbar_kwargs={'label': variable_name + ' Concentration  [percentage]'})

    plt.subplots_adjust(left=0.1,
        bottom=0.1, 
        right=0.95, 
        top=0.95, 
        wspace=0.35, 
        hspace=0.35)

    # Aspect ratio handled by the salishsea_tools helper
    sa_vi.set_aspect(ax[0,0])
    sa_vi.set_aspect(ax[0,1])
    sa_vi.set_aspect(ax[1,0])

    ax[0,0].title.set_text(variable_name + ' (targets)')
    ax[0,1].title.set_text(variable_name + ' (outputs)')
    ax[1,0].title.set_text('targets - outputs')
    ax[1,1].axis('off')

    fig.suptitle(str(date.date[0]) + ', Depth: ' + str(np.round(ds['deptht'][depth].values,2)) + ' meters')

    plt.show()
    

# Regressor for Other Depths

In [None]:
def regressor2(inputs, targets, variable_name, reg_variable):
    """Apply an already fitted regressor to another set of input fields.

    Grid points whose inputs contain NaN are dropped, the remaining inputs
    are standardised and passed to ``reg_variable.predict``. A scatter
    plot and maps (via ``plotting``) are produced as a side effect.

    Parameters
    ----------
    inputs : numpy.ndarray
        Array of shape ``(n_features, n_points)``; each row is one
        flattened input field.
    targets : xarray.DataArray
        2-D field with dimensions ``y`` and ``x`` whose flattened size is
        ``n_points``; used for comparison only.
    variable_name : str
        Name of the target variable, used in titles and labels.
    reg_variable : fitted estimator
        Object with a ``predict`` method. If it has an ``input_scaler_``
        attribute, that fitted scaler is used to standardise ``inputs``.

    Returns
    -------
    tuple
        Always the empty tuple.
    """
    # Pre processing: keep only the grid points without NaN in any input
    valid = np.where(~np.isnan(inputs).any(axis=0))[0]
    features = inputs[:, valid].transpose()  # -> (n_valid, n_features)
    observed = np.ravel(targets)[valid]

    # Regressor: the inputs must be standardised with the statistics the
    # network was trained on. Fitting a new scaler on this data would feed
    # the network differently scaled values, so that is only the fallback
    # for estimators that do not carry their training scaler.
    scale = getattr(reg_variable, 'input_scaler_', None)
    if scale is None:
        features = preprocessing.StandardScaler().fit_transform(features)
    else:
        features = scale.transform(features)
    outputs = reg_variable.predict(features)

    # Post processing: put the predictions back on the full grid, leaving
    # NaN where the inputs were missing. The grid shape comes from the
    # target field rather than a hard-coded size.
    flat = np.full(inputs.shape[1], np.nan)
    flat[valid] = outputs
    model = np.reshape(flat, targets.shape)

    scatter_plot(observed, outputs, variable_name)

    # Preparation of the dataarray (coordinates taken from the target so
    # that both fields are defined on exactly the same grid)
    model = xr.DataArray(model,
                    coords = {'y': targets.y, 'x': targets.x},
                    dims = ['y','x'],
                    attrs=dict( long_name = variable_name + " Concentration",
                                units="mmol m-3"),
                        )

    plotting(observed, model, targets, variable_name)

    return ()

# Main Body

In [None]:
# Dataset and date: ds is read from the grid_T file, ds_bio from the
# biol_T file; date is built from the time_counter values of ds
ds = xr.open_dataset(
    '/results2/SalishSea/nowcast-green.202111/20mar22/SalishSea_1d_20220320_20220320_grid_T.nc'
)
ds_bio = xr.open_dataset(
    '/results2/SalishSea/nowcast-green.202111/20mar22/SalishSea_1d_20220320_20220320_biol_T.nc'
)
date = pd.DatetimeIndex(ds['time_counter'].values)

# Open the mesh mask and keep its tmask as a plain numpy array
mesh = xr.open_dataset('/home/sallen/MEOPAR/grid/mesh_mask202108.nc')
mask = mesh.tmask.to_numpy()

temp, saline, sil, nitr, ammo, diat, flag, micro, meso = datasets_preparation()


# Train one regressor per variable at a single level (depth index 11).
# Fields are indexed as [0, depth] -- presumably (time, depth level);
# confirm against the dataset dimensions.
for depth in range(11, 12):

    # Potential input variables: one flattened field per row
    drivers = np.stack([np.ravel(field[0, depth]) for field in (temp, saline)])
    nutrients = np.stack([np.ravel(field[0, depth]) for field in (sil, nitr, ammo)])
    phyto = np.stack([np.ravel(field[0, depth]) for field in (diat, flag)])
    zoo = np.stack([np.ravel(field[0, depth]) for field in (micro, meso)])

    # Nutrients
    reg_sil = regressor(drivers, sil[0, depth], 'Silicon')
    reg_nitr = regressor(drivers, nitr[0, depth], 'Nitrate')
    reg_ammo = regressor(drivers, ammo[0, depth], 'Ammonium')

    # Phytoplankton
    reg_diat = regressor(drivers, diat[0, depth], 'Diatom')
    reg_flag = regressor(drivers, flag[0, depth], 'Flagellate')

    # Zooplankton
    reg_micro = regressor(drivers, micro[0, depth], 'Microzooplankton')
    reg_meso = regressor(drivers, meso[0, depth], 'Mesozooplankton')


# Next Depth

In [None]:
# Apply the regressors fitted above to depth indices 13, 14 and 15
for depth in range(13, 16):

    # Potential input variables: one flattened field per row
    drivers = np.stack([np.ravel(field[0, depth]) for field in (temp, saline)])
    nutrients = np.stack([np.ravel(field[0, depth]) for field in (sil, nitr, ammo)])
    phyto = np.stack([np.ravel(field[0, depth]) for field in (diat, flag)])
    zoo = np.stack([np.ravel(field[0, depth]) for field in (micro, meso)])

    # Nutrients
    regressor2(drivers, sil[0, depth], 'Silicon', reg_sil)
    regressor2(drivers, nitr[0, depth], 'Nitrate', reg_nitr)
    regressor2(drivers, ammo[0, depth], 'Ammonium', reg_ammo)

    # Phytoplankton
    regressor2(drivers, diat[0, depth], 'Diatom', reg_diat)
    regressor2(drivers, flag[0, depth], 'Flagellate', reg_flag)

    # Zooplankton
    regressor2(drivers, micro[0, depth], 'Microzooplankton', reg_micro)
    regressor2(drivers, meso[0, depth], 'Mesozooplankton', reg_meso)