# Temporary script for predicting Diatom concentration

## Importing

In [None]:
import xarray as xr
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from sklearn.pipeline import make_pipeline
from sklearn.compose import TransformedTargetRegressor
from sklearn.compose import ColumnTransformer
from sklearn.preprocessing import OneHotEncoder
from sklearn.preprocessing import MinMaxScaler

import xgboost as xgb

from sklearn.ensemble import BaggingRegressor

from sklearn.metrics import root_mean_squared_error as rmse

from tqdm import tqdm

import dill
import random

import salishsea_tools.viz_tools as sa_vi

## Datasets Preparation

In [None]:
def datasets_preparation(dataset, dataset2):
    """Flatten the gridded fields into a (samples, features) driver matrix.

    Ten feature rows are stacked in this order: the two temperature and the
    two salinity layers read from ``dataset``; solar radiation, wind speed
    and air temperature read from ``dataset2``; then the ``y`` coordinate,
    the ``x`` coordinate and the day of year of every flattened sample.

    A sample is kept only when none of its features is NaN, ``x > 10`` and
    either ``x > 100`` or ``y < 880``.

    Returns
    -------
    tuple
        ``(drivers, diat, indx)``: ``drivers`` has one row per kept sample
        and one column per feature, ``diat`` holds the matching flattened
        ``'Diatom'`` values and ``indx`` is the ``np.where`` tuple with the
        flat positions of the kept samples.
    """
    n_time = len(dataset.time_counter)
    n_y = len(dataset.y)
    n_x = len(dataset.x)

    ocean_names = ('Temperature_(0m-15m)', 'Temperature_(15m-100m)',
                   'Salinity_(0m-15m)', 'Salinity_(15m-100m)')
    atmos_names = ('Summation_of_solar_radiation', 'Mean_wind_speed',
                   'Mean_air_temperature')

    rows = [np.ravel(dataset[name]) for name in ocean_names]
    rows.extend(np.ravel(dataset2[name]) for name in atmos_names)

    # Coordinate rows are laid out to match a (time_counter, y, x) flattening.
    # NOTE(review): assumes the fields are stored in that dimension order -- confirm.
    rows.append(np.tile(np.repeat(dataset.y, n_x), n_time))
    rows.append(np.tile(dataset.x, n_time * n_y))
    rows.append(np.repeat(dataset.time_counter.dt.dayofyear, n_x * n_y))

    drivers = np.stack(rows)
    y_row, x_row = drivers[7], drivers[8]

    # NOTE(review): the x/y thresholds presumably crop part of the domain -- confirm.
    has_all_values = ~np.isnan(drivers).any(axis=0)
    inside_region = (x_row > 10) & ((x_row > 100) | (y_row < 880))
    indx = np.where(has_all_values & inside_region)

    diat = np.ravel(dataset['Diatom'])[indx[0]]
    drivers = drivers[:, indx[0]].transpose()

    return (drivers, diat, indx)

## Regressor

In [None]:
def regressor(inputs, targets):
    """Fit a bagged ensemble of XGBoost pipelines and return it.

    Columns 7 and 8 (spatial) and column 9 (temporal) of ``inputs`` are
    one-hot encoded; every remaining column is min-max scaled. The targets
    are min-max scaled through ``TransformedTargetRegressor``. Ten copies of
    that model are bagged and fitted with ``n_jobs=-1``.

    Parameters
    ----------
    inputs : array of shape (samples, features)
    targets : array of shape (samples,)

    Returns
    -------
    The fitted ``BaggingRegressor``.
    """
    encoders = [
        ('spatial', OneHotEncoder(), [7,8]),
        ('temporal', OneHotEncoder(), [9]),
    ]
    preprocessing = ColumnTransformer(transformers=encoders, remainder=MinMaxScaler())

    booster = xgb.XGBRegressor(
        n_estimators=1000,
        max_depth=7,
        eta=0.1,
        subsample=0.7,
        colsample_bytree=0.8,
    )

    model = TransformedTargetRegressor(
        regressor=make_pipeline(preprocessing, booster),
        transformer=MinMaxScaler(),
    )

    ensemble = BaggingRegressor(model, n_estimators=10, n_jobs=-1)
    return ensemble.fit(inputs, targets)

## Scatter Plot

In [None]:
def scatter_plot(targets, outputs, variable_name):
    """Draw a scatter panel and a 2-D histogram panel of outputs vs targets.

    Both panels show the dashed 1:1 line and, in red, the degree-1
    least-squares fit of ``outputs`` on ``targets``. The figure title is
    ``variable_name``.

    Returns
    -------
    The slope of the fitted line.
    """
    # Least-squares line: outputs ~ m * targets + b.
    m, b = np.polyfit(targets, outputs, deg=1)

    fig, axes = plt.subplots(2, figsize=(5,10), layout='constrained')
    points_ax, density_ax = axes

    # --- Upper panel: raw scatter ---
    points_ax.scatter(targets, outputs, alpha=0.2, s=10)

    # One common square extent, taken from the autoscaled scatter panel
    # before any line is added to it.
    extents = [points_ax.get_xlim(), points_ax.get_ylim()]
    lims = [np.min(extents), np.max(extents)]

    points_ax.axline(xy1=(0, b), slope=m, color='r')

    points_ax.set_xlabel('targets')
    points_ax.set_ylabel('outputs')
    points_ax.set_xlim(lims)
    points_ax.set_ylim(lims)
    points_ax.set_aspect('equal')

    points_ax.plot(lims, lims, linestyle='--', color='k')

    # --- Lower panel: log-scaled point density over the same extent ---
    # cmin=0.1 leaves bins without any sample blank.
    density = density_ax.hist2d(targets, outputs, bins=100, cmap='jet',
        range=[lims, lims], cmin=0.1, norm='log')

    density_ax.plot(lims, lims, linestyle='--', color='k')
    density_ax.axline(xy1=(0, b), slope=m, color='r')

    density_ax.set_xlabel('targets')
    density_ax.set_ylabel('outputs')
    density_ax.set_aspect('equal')

    # Element 3 of the hist2d result is the image used for the colorbar.
    fig.colorbar(density[3], ax=density_ax, location='bottom')

    fig.suptitle(variable_name)

    plt.show()

    return m

## Plotting (Years)

In [None]:
def plotting_years(years, variable, name):
    """Plot one marker per year of ``variable``, labelling the y axis ``name``."""
    fig, ax = plt.subplots()

    ax.plot(years, variable, linestyle='', marker='.')
    ax.set_xlabel('Years')
    ax.set_ylabel(name)

    fig.show()

## Plotting (Days)

In [None]:
def plotting_days(variable,title):
    """Scatter a daily statistic against the dates, coloured by month.

    Reads the module-level ``dates``; ``variable`` must have one value per
    entry of ``dates``. The legend labels are fixed to February, March and
    April.
    """
    fig, ax = plt.subplots()

    month_of_sample = pd.DatetimeIndex(dates).month
    points = ax.scatter(dates, variable, marker='.', c=month_of_sample)

    # legend_elements() returns (handles, labels); only the handles are used.
    month_handles = points.legend_elements()[0]
    ax.legend(handles=month_handles, labels=['February','March','April'])

    fig.suptitle('Daily ' + title + ' (15 Feb - 30 Apr)')

    fig.show()


## Plotting (Mean Values)

In [None]:
def plotting_mean_values(mean_targets, mean_predictions):
    """Plot the daily mean targets and predictions against the dates.

    Reads the module-level ``dates``; both arguments must have one value
    per entry of ``dates``.
    """
    fig, ax = plt.subplots()

    series_to_draw = (
        (mean_targets, 'targets'),
        (mean_predictions, 'predictions'),
    )
    for values, label in series_to_draw:
        ax.plot(dates, values, marker='.', linestyle='', label=label)

    ax.set_xlabel('Years')
    ax.set_ylabel('Diatom Concentration  [mmol m-2]')
    fig.suptitle('Daily Time-series')
    ax.legend()

    fig.show()

## Evaluation

In [None]:
def evaluation(regr, ds, ds2, years=None):
    """Score a fitted regressor per year and per day, then plot the results.

    For each year the drivers are rebuilt with ``datasets_preparation``,
    predictions are made with ``regr.predict`` and the correlation
    coefficient, root mean squared error and fitted slope are computed, both
    over the whole year and for every individual day. Daily means of the
    targets and predictions are collected as well.

    Parameters
    ----------
    regr : fitted estimator exposing ``predict``.
    ds, ds2 : datasets that support ``sel(time_counter=str(year))`` and are
        accepted by ``datasets_preparation``.
    years : iterable of int, optional
        Years to evaluate, in plotting order. Defaults to 2021-2024.

    Notes
    -----
    The daily plots rely on the module-level ``dates`` read by
    ``plotting_days`` and ``plotting_mean_values``.
    # NOTE(review): ``dates`` presumably has to cover exactly the days of
    # ``years`` so that the lengths match -- confirm at the call site.
    """
    if years is None:
        years = np.arange(2021, 2025)

    # Daily statistics, accumulated over all years in chronological order.
    r_days, rms_days, slope_days = [], [], []
    mean_targets, mean_predictions = [], []

    # Yearly statistics, one entry per element of ``years``.
    r_years, rms_years, slope_years = [], [], []

    for year in tqdm(years):

        dataset = ds.sel(time_counter=str(year))
        dataset2 = ds2.sel(time_counter=str(year))

        inputs, targets, _ = datasets_preparation(dataset, dataset2)
        predictions = regr.predict(inputs)

        # Appended (not prepended) so that entry k belongs to years[k].
        slope_years.append(scatter_plot(targets, predictions, 'Diatom for '+ str(year)))
        r_years.append(np.round(np.corrcoef(targets, predictions)[0][1],3))
        rms_years.append(rmse(targets, predictions))

        # One row per day; the number of points per day is inferred instead
        # of being hard-coded. This still requires every day to keep the same
        # number of valid samples, otherwise reshape raises or misaligns.
        n_days = len(dataset.time_counter)
        targets = np.reshape(targets, (n_days, -1))
        predictions = np.reshape(predictions, (n_days, -1))

        for day_targets, day_predictions in zip(targets, predictions):
            r_days.append(np.round(np.corrcoef(day_targets, day_predictions)[0][1],3))
            rms_days.append(rmse(day_targets, day_predictions))
            m_day, _ = np.polyfit(day_targets, day_predictions, deg=1)
            slope_days.append(m_day)

            mean_targets.append(np.mean(day_targets))
            mean_predictions.append(np.mean(day_predictions))

    plotting_days(np.array(r_days), 'Correlation Coefficients')
    plotting_days(np.array(rms_days), 'Root Mean Square Errors')
    plotting_days(np.array(slope_days), 'Slopes of the best fitting line')

    plotting_years(years, np.array(r_years), 'Correlation Coefficients')
    plotting_years(years, np.array(rms_years), 'Root Mean Square Errors')
    plotting_years(years, np.array(slope_years), 'Slopes of the best fitting line')

    plotting_mean_values(np.array(mean_targets), np.array(mean_predictions))
    

## Regressor 4

In [None]:
def regressor4 (inputs, targets, variable_name):
    """Predict one day, rebuild the 2-D map and plot it against the targets.

    Parameters
    ----------
    inputs : array of shape (samples, features), as returned by
        ``datasets_preparation``.
    targets : array of shape (samples,) with the matching target values.
    variable_name : str
        Prefix used for the plot titles and the DataArray ``long_name``.

    Notes
    -----
    Reads these module-level names, which the calling loop must set first:
    ``regr`` (fitted regressor), ``indx`` (kept flat positions returned by
    ``datasets_preparation``), ``diat_i`` (the day's ``'Diatom'`` field,
    used for the grid coordinates) and ``i`` (the day's time stamp).
    """
    # ``inputs`` is already (samples, features); transposing it again would
    # hand ``predict`` a (features, samples) matrix.
    outputs = regr.predict(inputs)

    # Post processing: scatter the predictions back onto the full grid,
    # leaving NaN where a point was filtered out.
    n_y, n_x = len(diat_i.y), len(diat_i.x)
    flat_map = np.full(n_y * n_x, np.nan)
    flat_map[indx[0]] = outputs
    model = np.reshape(flat_map, (n_y, n_x))

    # Called for its plot only; the returned slope is not needed here.
    scatter_plot(targets, outputs, variable_name + str(i.dt.date.values))

    # Preparation of the dataarray
    model = xr.DataArray(model,
        coords = {'y': diat_i.y, 'x': diat_i.x},
        dims = ['y','x'],
        attrs=dict( long_name = variable_name + "Concentration",
        units="mmol m-2"),)

    plotting_maps(targets, model, diat_i, variable_name)

## Plotting (Maps)

In [None]:
def plotting_maps(targets, model, variable, variable_name):
    """Plot target map, predicted map and their relative difference in percent.

    The first two panels share the colour range ``targets.min()`` to
    ``targets.max()``. The third panel shows
    ``(variable - model) / variable * 100``; the fourth is switched off.
    The figure title is the date of the module-level ``i``.
    """
    fig, ax = plt.subplots(2,2, figsize = (10,15))
    target_ax, output_ax, diff_ax = ax[0,0], ax[0,1], ax[1,0]

    cmap = plt.get_cmap('cubehelix')
    cmap.set_bad('gray')

    colour_min = targets.min()
    colour_max = targets.max()

    for panel, field in ((target_ax, variable), (output_ax, model)):
        field.plot(ax=panel, cmap=cmap, vmin=colour_min, vmax=colour_max,
            cbar_kwargs={'label': variable_name + ' Concentration  [mmol m-2]'})

    relative_difference = (variable-model) / variable * 100
    relative_difference.plot(ax=diff_ax, cmap=cmap,
        cbar_kwargs={'label': variable_name + ' Concentration  [percentage]'})

    fig.subplots_adjust(left=0.1, bottom=0.1, right=0.95, top=0.95,
        wspace=0.35, hspace=0.35)

    for panel in (target_ax, output_ax, diff_ax):
        sa_vi.set_aspect(panel)

    target_ax.title.set_text(variable_name + ' (targets)')
    output_ax.title.set_text(variable_name + ' (outputs)')
    diff_ax.title.set_text('targets - outputs')
    ax[1,1].axis('off')

    fig.suptitle(str(i.dt.date.values))

    plt.show()
    

## Training (Random Points)

In [None]:
# Open the integrated model fields and the external (atmospheric) inputs.
ds = xr.open_dataset('/data/ibougoudis/MOAD/files/integrated_original.nc')
ds2 = xr.open_dataset('/data/ibougoudis/MOAD/files/external_inputs.nc')

# Thin both datasets: every 2nd time step and every 5th point along y and x.
# NOTE(review): isel() takes integer positions, while the y/x ranges are built
# from coordinate values -- this presumably relies on 0-based integer
# coordinates; confirm.
ds = ds.isel({
    'time_counter': np.arange(0, len(ds.time_counter), 2),
    'y': np.arange(ds.y[0], ds.y[-1], 5),
    'x': np.arange(ds.x[0], ds.x[-1], 5),
})

ds2 = ds2.isel({
    'time_counter': np.arange(0, len(ds2.time_counter), 2),
    'y': np.arange(ds2.y[0], ds2.y[-1], 5),
    'x': np.arange(ds2.x[0], ds2.x[-1], 5),
})

# Training period: 2007 through 2010.
dataset = ds.sel(time_counter=slice('2007', '2010'))
dataset2 = ds2.sel(time_counter=slice('2007', '2010'))

drivers, diat, _ = datasets_preparation(dataset, dataset2)

# ``regr`` is read as a global by the later cells.
regr = regressor(drivers, diat)


## Other Years

In [None]:

# Keep only the evaluation period (2021 through 2024) in both datasets.
ds, ds2 = (subset.sel(time_counter=slice('2021', '2024')) for subset in (ds, ds2))

# ``dates`` is read as a global by plotting_days and plotting_mean_values.
dates = pd.DatetimeIndex(ds.time_counter.values)

evaluation(regr, ds, ds2)


## Daily Maps

In [None]:


# Draw 10 distinct time stamps at random from the dataset.
maps = random.sample(sorted(ds.time_counter),10)

# NOTE: the names ``i``, ``indx`` and ``diat_i`` must not be renamed --
# regressor4 reads all three as globals and plotting_maps reads ``i``.
for i in tqdm(maps):

    # slice(i, i) keeps time_counter as a length-1 dimension.
    dataset = ds.sel(time_counter=slice(i,i))
    dataset2 = ds2.sel(time_counter=slice(i,i))
    drivers, diat, indx = datasets_preparation(dataset, dataset2)

    # The day's Diatom field; regressor4 takes the grid coordinates from it.
    diat_i = dataset['Diatom']

    regressor4(drivers, diat, 'Diatom ')