# Predicting Diatom production rate with Histogram-based Gradient Boosting Regression Tree

## Importing

In [21]:
import xarray as xr
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import xskillscore as xs

from sklearn.pipeline import make_pipeline
from sklearn.compose import TransformedTargetRegressor
from sklearn.compose import ColumnTransformer
from sklearn.preprocessing import KBinsDiscretizer
from sklearn.preprocessing import MinMaxScaler
from sklearn.preprocessing import StandardScaler

from sklearn.ensemble import HistGradientBoostingRegressor
from sklearn.ensemble import BaggingRegressor
from sklearn.feature_selection import r_regression

from sklearn.metrics import root_mean_squared_error as rmse

import os
import lzma
import dill

import random

import cmocean.cm as cm
import salishsea_tools.viz_tools as sa_vi

# Silence all warnings (originally added for the NaN-mean RuntimeWarning).
# The standard-library module is used directly: `np.warnings` was only an
# accidental alias of it and is no longer available in NumPy >= 1.24.
import warnings

warnings.filterwarnings('ignore') # For the nan mean warning


## Datasets Preparation

In [22]:
# Creation of the training - testing datasets
def datasets_preparation(dataset, dataset2, name):
    """Flatten gridded fields into a (samples, features) matrix and a target vector.

    Parameters
    ----------
    dataset : object exposing ``x``, ``y``, ``time_counter`` and ``dataset[name]``
        Source of the target field and of the grid/time coordinates.
    dataset2 : mapping
        Source of the six driver fields read below.
    name : str
        Key of the target variable inside ``dataset``.

    Returns
    -------
    tuple
        ``(inputs, targets, indx)`` where ``inputs`` has one row per retained
        sample and seven columns (six drivers followed by the day of the
        year), ``targets`` holds the matching target values and ``indx`` is
        the ``np.where`` result locating the retained samples in the
        flattened (time, y, x) order.
    """
    n_times = len(dataset.time_counter)
    n_rows = len(dataset.y)
    n_cols = len(dataset.x)

    # x / y coordinate value of every element of the flattened field
    x = np.tile(dataset.x, n_times * n_rows)
    y = np.tile(np.repeat(dataset.y, n_cols), n_times)

    driver_names = [
        'Summation_of_solar_radiation',
        'Mean_wind_speed',
        'Mean_air_temperature',
        'Mean_precipitation',
        'Latitude',
        'Longitude',
    ]
    features = [np.ravel(dataset2[key]) for key in driver_names]
    # Last feature: day of the year, repeated for every grid point of that time step
    features.append(np.repeat(dataset.time_counter.dt.dayofyear, n_cols * n_rows))
    inputs = np.stack(features)

    targets = np.ravel(dataset[name])

    # Keep finite targets inside the accepted x/y range
    # NOTE(review): the thresholds 10, 100 and 880 presumably exclude a
    # boundary part of the grid -- confirm with the author.
    keep = np.isfinite(targets) & (x > 10) & ((x > 100) | (y < 880))
    indx = np.where(keep)

    targets = targets[indx[0]]
    inputs = inputs[:, indx[0]].transpose()

    return (inputs, targets, indx)


## Datasets Preparation 2

In [23]:
# Creation of the data arrays
def datasets_preparation2(variable, name, units, dataset):
    """Scatter per-day sample values back onto the (time_counter, y, x) grid.

    Parameters
    ----------
    variable : array-like
        Values assigned to the retained grid points of every time step
        (second axis must match the number of retained points).
    name : str
        Stored as the ``description`` attribute of the result.
    units : str
        Stored as the ``units`` attribute of the result.
    dataset : xarray dataset
        Provides the coordinates and the ``'Temperature_(15m-100m)'`` field
        used to locate the retained grid points.

    Returns
    -------
    xr.DataArray
        Array with dims ``('time_counter', 'y', 'x')``; grid points that are
        not retained are NaN.
    """
    n_times = len(dataset.time_counter)
    n_rows = len(dataset.y)
    n_cols = len(dataset.x)
    n_cells = n_rows * n_cols

    # Obtaining the daily indexes
    temp = np.reshape(np.ravel(dataset['Temperature_(15m-100m)']), (n_times, n_cells))
    x = np.tile(dataset.x, n_rows)
    y = np.tile(np.repeat(dataset.y, n_cols), 1)

    # A grid point is retained when its temperature is never NaN and it lies
    # inside the accepted x/y range
    retained = (~np.isnan(temp).any(axis=0)) & (x > 10) & ((x > 100) | (y < 880))
    indx = np.where(retained)

    variable_all = np.full((n_times, n_cells), np.nan)
    variable_all[:, indx[0]] = variable
    variable_all = np.reshape(variable_all, (n_times, n_rows, n_cols))

    # Preparation of the dataarray
    coordinates = {'time_counter': dataset.time_counter, 'y': dataset.y, 'x': dataset.x}
    attributes = dict(description=name, units=units)
    array = xr.DataArray(
        variable_all,
        coords=coordinates,
        dims=['time_counter', 'y', 'x'],
        attrs=attributes)

    return (array)


## File Creation

In [24]:
def file_creation(path, variable, name):
    """Write ``variable`` into ``<path>targets_predictions.nc`` under ``name``.

    The file is opened in append mode and the variable is stored with zlib
    compression at level 9. ``path`` is concatenated as-is, so it must end
    with a path separator.
    """
    compression = {"zlib": True, "complevel": 9}
    target_file = path + 'targets_predictions.nc'

    as_dataset = variable.to_dataset(name=name)
    as_dataset.to_netcdf(path=target_file, mode='a', encoding={name: compression})
    

## Regressor

In [25]:
def regressor(inputs, targets, name):
    """Fit a bagged histogram gradient-boosting model and report input correlations.

    Parameters
    ----------
    inputs : array-like of shape (n_samples, 7)
        Columns 0-3 are standardised, columns 4-5 are quantile-binned into
        255 ordinal bins, column 6 is passed through unchanged.
    targets : array-like of shape (n_samples,)
        Target values; standardised internally by the target transformer.
    name : str
        Name of the target, used only in the printed message.

    Returns
    -------
    BaggingRegressor
        The fitted ensemble (12 estimators, 4 parallel jobs).
    """
    column_preparation = ColumnTransformer(
        transformers=[
            ('drivers', StandardScaler(), [0,1,2,3]),
            ('spatial', KBinsDiscretizer(n_bins=255,encode='ordinal',strategy='quantile'), [4,5]),
        ],
        remainder='passthrough')
    # Columns 4-6 (binned columns plus the passed-through one) are treated as categorical
    booster = HistGradientBoostingRegressor(categorical_features=[4,5,6])

    model = TransformedTargetRegressor(
        regressor=make_pipeline(column_preparation, booster),
        transformer=StandardScaler())
    regr = BaggingRegressor(model, n_estimators=12, n_jobs=4).fit(inputs,targets)

    # Printing of the correlation coefficients
    r = np.round(r_regression(inputs,targets),2)
    labels = ['Summation_of_solar_radiation', 'Mean_wind_speed', 'Mean_air_temperature',
        'Mean_precipitation', 'Latitude', 'Longitude', 'Day_of_the_year']
    correlations = {label: r[position] for position, label in enumerate(labels)}

    print('The correlation coefficients between each input and ' + name +  ' are: ' +str(correlations))

    return(regr)


## Scatter Plot

In [26]:
def scatter_plot(targets, predictions, name):
    """Show targets against predictions as a scatter plot and a 2-D histogram.

    Both panels carry the 1:1 line (dashed black) and the least-squares line
    fitted to the data (red). ``name`` is used as the figure title.

    Returns
    -------
    float
        Slope of the fitted line ``predictions = slope * targets + intercept``.
    """
    # Least-squares straight line through the points
    slope, intercept = np.polyfit(targets, predictions, deg=1)

    fig, axes = plt.subplots(2, figsize=(5,10), layout='constrained')
    scatter_ax, hist_ax = axes[0], axes[1]

    scatter_ax.scatter(targets,predictions, alpha = 0.2, s = 10)

    # One common range for both axes, taken from the autoscaled scatter plot
    auto_limits = [scatter_ax.get_xlim(), scatter_ax.get_ylim()]
    lims = [np.min(auto_limits), np.max(auto_limits)]

    # Fitted line
    scatter_ax.axline(xy1=(0, intercept), slope=slope, color='r')

    scatter_ax.set_xlabel('targets')
    scatter_ax.set_ylabel('predictions')
    scatter_ax.set_xlim(lims)
    scatter_ax.set_ylim(lims)
    scatter_ax.set_aspect('equal')

    # 1:1 line
    scatter_ax.plot(lims, lims,linestyle = '--',color = 'k')

    # Point density on a logarithmic colour scale
    density = hist_ax.hist2d(targets,predictions, bins=100, cmap='jet', 
        range=[lims,lims], cmin=0.1, norm='log')

    hist_ax.plot(lims, lims,linestyle = '--',color = 'k')

    # Fitted line
    hist_ax.axline(xy1=(0, intercept), slope=slope, color='r')

    hist_ax.set_xlabel('targets')
    hist_ax.set_ylabel('predictions')
    hist_ax.set_aspect('equal')

    fig.colorbar(density[3],ax=hist_ax, location='bottom')

    fig.suptitle(name)

    plt.show()

    return(slope)


## Seasonality

In [27]:
def seasonality(dates, targets):
    """Return the mean seasonal cycle of ``targets``, expanded to every date.

    The climatology is computed per calendar day (month, day): every sample
    is replaced by the mean of all samples that share its month and day,
    NaN values being ignored. Grouping by calendar day rather than by day of
    the year keeps leap and non-leap years aligned (1 March is always
    averaged with 1 March), and 29 February is handled like any other day,
    so the function works for any date range and any number of leap years.

    Parameters
    ----------
    dates : pd.DatetimeIndex or array-like of datetimes
        Date of every sample.
    targets : array-like of shape (len(dates),)
        Time series to average (NumPy array or one-dimensional DataArray).

    Returns
    -------
    np.ndarray
        Float array of shape ``(len(dates),)`` holding the climatological
        value of every date.
    """
    dates = pd.DatetimeIndex(dates)
    values = np.asarray(targets, dtype=float)

    # month * 100 + day is a unique integer label per calendar day
    calendar_day = np.asarray(dates.month) * 100 + np.asarray(dates.day)

    # transform('mean') broadcasts each group mean back to the original order
    season = pd.Series(values).groupby(calendar_day).transform('mean').to_numpy()

    return(season)


## Plotting (Criteria)

In [28]:
def plotting_criteria(dates, variable, year_variable, title):
    """Plot a daily metric coloured by month, with the yearly metric on top.

    ``variable`` is drawn against ``dates`` as points coloured by month;
    ``year_variable`` (one value per year) is drawn in red at the dates that
    fall on 15 March. ``title`` is prefixed to the fixed period label.
    """
    date_index = pd.DatetimeIndex(dates)
    # Dates at which the yearly values are anchored
    mid_march = (date_index.month == 3) & (date_index.day == 15)

    fig, ax = plt.subplots()

    points = ax.scatter(dates,variable, marker='.', c=date_index.month)
    plt.xticks(rotation=70)

    month_handles = points.legend_elements()[0]
    ax.legend(handles=month_handles, labels=['February','March','April'])

    ax.plot(dates[mid_march], year_variable,color='red',marker='*')
    fig.suptitle(title + ' (15 Feb - 30 Apr)')

    fig.show()


## Plotting (Mean Values)

In [29]:
def plotting_mean_values(dates, mean_targets, mean_predictions, units, category, region):
    """Plot the mean target and prediction time series with one tick per year.

    Parameters
    ----------
    dates : pd.DatetimeIndex
        Date of every sample; only its ``year`` attribute is used.
    mean_targets, mean_predictions : array-like
        Series to plot against the sample position.
    units, category, region : str
        Text fragments assembled into the figure title.
    """
    years = np.unique(dates.year)

    fig, _ = plt.subplots(figsize=(19,5))

    mean_targets = np.ma.array(mean_targets)
    mean_predictions = np.ma.array(mean_predictions)

    # Tick at the first sample of every year. The first sample of each year
    # after the first is masked so that the drawn line is interrupted
    # between consecutive years (that sample itself is therefore not drawn).
    ticks = [0]
    for year in years[:-1]:
        next_year_start = np.where(dates.year==year)[0][-1]+1
        ticks.append(next_year_start)
        mean_targets[next_year_start] = np.ma.masked
        mean_predictions[next_year_start] = np.ma.masked

    plt.plot(mean_targets, label = 'targets')
    plt.plot(mean_predictions, label = 'predictions')
    plt.xlabel('Years')
    plt.xticks(ticks,years)
    plt.suptitle('Mean '+category + ' ' +units + ' (15 Feb - 30 Apr) ' + region)
    plt.legend()
    fig.show()
    

## Plotting (Maps)

In [30]:
def plotting_maps(targets, predictions, name, units):
    """Draw maps of the targets, the predictions and their difference.

    Targets and predictions share the colour range of the targets; the
    difference panel uses its own range. The fourth panel of the 2x2 grid is
    left empty. ``name`` and ``units`` label the colour bars and ``name``
    plus the date(s) of ``targets`` form the figure title.
    """
    fig, ax = plt.subplots(2,2, figsize = (10,15), layout='tight')
    targets_ax, predictions_ax, difference_ax = ax[0,0], ax[0,1], ax[1,0]

    cmap = plt.get_cmap('cubehelix')
    cmap.set_bad('gray')

    colourbar_label = name + ' ' + units
    low, high = targets.min(), targets.max()

    targets.plot(ax=targets_ax, cmap=cmap, vmin = low, vmax = high, cbar_kwargs={'label': colourbar_label})
    predictions.plot(ax=predictions_ax, cmap=cmap, vmin = low, vmax = high, cbar_kwargs={'label': colourbar_label})
    (targets-predictions).plot(ax=difference_ax, cmap=cmap, cbar_kwargs={'label': colourbar_label})

    plt.subplots_adjust(left=0.1,
        bottom=0.1, 
        right=0.95, 
        top=0.95, 
        wspace=0.35, 
        hspace=0.35)

    panels = [
        (targets_ax, 'Targets'),
        (predictions_ax, 'Predictions'),
        (difference_ax, 'Targets-Predictions'),
    ]
    for panel, _ in panels:
        sa_vi.set_aspect(panel)
    for panel, panel_title in panels:
        panel.title.set_text(panel_title)
    ax[1,1].axis('off')

    fig.suptitle(name + ' '+ str(targets.time_counter.dt.date.values))

    plt.show()
    

## Plotting (Regions)

In [31]:
def plot_box(ax, corn, colour):
    """Draw the closed outline of a rectangle on ``ax``.

    ``corn`` holds the rectangle limits: ``corn[0]`` and ``corn[1]`` are the
    two limits along the vertical axis, ``corn[2]`` and ``corn[3]`` the two
    limits along the horizontal axis. ``colour`` is the line colour.
    """
    bottom, top = corn[0], corn[1]
    left, right = corn[2], corn[3]

    # Five points: the four corners and back to the first one
    horizontal = [left, right, right, left, left]
    vertical = [bottom, bottom, top, top, bottom]

    ax.plot(horizontal, vertical, '-', color=colour)
    

## Plotting (Regional analysis)

In [32]:
def plotting_regional(metric,box,years,category):
    """Plot one yearly curve of ``metric`` per region.

    ``metric`` is indexed as ``metric[:, k]`` for the k-th entry of ``box``
    (the region labels shown in the legend); ``years`` gives the horizontal
    positions and ``category`` the leading part of the title.
    """
    fig, ax = plt.subplots()

    for column, region in enumerate(box):
        ax.plot(years, metric[:, column], marker='*', label=region)

    plt.suptitle(category+ ' (Regional analysis)')
    plt.legend()
    fig.show()


## Evaluation

In [33]:
def evaluation(regr, ds, ds2, name, units):
    """Evaluate the fitted model year by year.

    For every year found in ``ds`` the inputs are rebuilt, predictions are
    made, a scatter plot is shown and the correlation coefficient, RMSE and
    fitted slope are recorded. Targets and predictions are also put back on
    the grid and concatenated along ``time_counter``.

    Returns
    -------
    tuple
        ``(r_years, rms_years, slope_years, targets_all, predictions_all)``:
        three float arrays with one value per year, followed by the gridded
        targets and predictions for the whole period.
    """
    years = np.unique(ds.time_counter.dt.year)

    # One entry per year
    r_per_year = []
    rms_per_year = []
    slope_per_year = []

    # The data arrays
    targets_all = []
    predictions_all = []

    for year in years:

        year_label = str(year)
        dataset = ds.sel(time_counter=year_label)
        dataset2 = ds2.sel(time_counter=year_label)

        inputs, targets, indx = datasets_preparation(dataset, dataset2, name)

        predictions = regr.predict(inputs)

        # Calculating the annual time-series
        slope_per_year.append(scatter_plot(targets, predictions, name + ' for '+ year_label))
        r_per_year.append(np.corrcoef(targets, predictions)[0][1])
        rms_per_year.append(rmse(targets, predictions))

        # Daily arrays: one row per day, which assumes that the same number
        # of samples was retained for every day of the year
        n_days = len(dataset.time_counter)
        daily_shape = (n_days, len(indx[0]) // n_days)
        targets = np.reshape(targets, daily_shape)
        predictions = np.reshape(predictions, daily_shape)
        targets_all.append(datasets_preparation2(targets, name + ' _targets', units, dataset))
        predictions_all.append(datasets_preparation2(predictions, name + ' _predictions', units, dataset))

    r_years = np.array(r_per_year, dtype=float)
    rms_years = np.array(rms_per_year, dtype=float)
    slope_years = np.array(slope_per_year, dtype=float)

    # Daily arrays
    targets_all = xr.concat(targets_all, dim='time_counter')
    predictions_all = xr.concat(predictions_all, dim='time_counter')

    return(r_years, rms_years, slope_years, targets_all, predictions_all)


## Training

In [None]:
# Target variable and the labels used in messages and figure titles
name = 'Diatom_Production_Rate'
units = '[mmol N m-2 s-1]'
category = 'Production rates'

ds = xr.open_dataset('/data/ibougoudis/MOAD/files/integrated_original.nc')
ds2 = xr.open_dataset('/data/ibougoudis/MOAD/files/external_inputs.nc')

# ds = ds.isel(
#     y=(np.arange(ds.y[0], ds.y[-1], 5)), 
#     x=(np.arange(ds.x[0], ds.x[-1], 5)))

# ds2 = ds2.isel(
#     y=(np.arange(ds2.y[0], ds2.y[-1], 5)), 
#     x=(np.arange(ds2.x[0], ds2.x[-1], 5)))

# Training period
training_period = slice('2007', '2020')
dataset = ds.sel(time_counter = training_period)
dataset2 = ds2.sel(time_counter = training_period)

inputs, targets, indx = datasets_preparation(dataset, dataset2, name)

regr = regressor(inputs, targets, name)

predictions = regr.predict(inputs)

# Skill of the model on the data it was trained with
r_train = np.round(np.corrcoef(predictions,targets)[0][1],3)
rms_train = rmse(predictions,targets)
m = np.polyfit(targets, predictions, deg=1)[0]
slope_train = np.round(m,3)

print ('The correlation coefficient during training is: ' + str(r_train))
print ('The rmse during training is: ' + str(rms_train))
print('The slope of the best fitting line during training is: '+str(slope_train))

# Daily mean plot: one row per day, then the mean over the retained points
n_days = len(dataset.time_counter)
daily_shape = (n_days, len(indx[0]) // n_days)
targets = np.mean(np.reshape(targets, daily_shape), axis=1)
predictions = np.mean(np.reshape(predictions, daily_shape), axis=1)

dates = pd.DatetimeIndex(dataset['time_counter'].values)

plotting_mean_values(dates, targets, predictions, units, category, 'Salish Sea')


In [None]:
# Seasonal cycle of the daily-mean training targets
season = seasonality(dates,targets)

# Show one 75-sample stretch of the cycle
plt.plot(season[75:150])
plt.suptitle('Long-term seasonality (2007-2020)')

# Same daily means with the seasonal cycle subtracted
targets_anomaly = targets-season
predictions_anomaly = predictions-season
plotting_mean_values(dates, targets_anomaly, predictions_anomaly, units, category, 'Salish Sea (removed seasonality)')

# Keeping it for the regional seasonalities
quant_train = dataset[name]
dates_season = pd.DatetimeIndex(quant_train['time_counter'].values)


## Other Years

In [None]:
# Years not used for training
evaluation_period = slice('2021', '2024')
ds = ds.sel(time_counter = evaluation_period)
ds2 = ds2.sel(time_counter = evaluation_period)

dates = pd.DatetimeIndex(ds['time_counter'].values)

r_years, rms_years, slope_years, targets_all, predictions_all = evaluation(regr, ds, ds2, name, units)

# Daily metrics, computed over the horizontal dimensions
horizontal = ['x','y']
r_days = xr.corr(targets_all,predictions_all, dim=horizontal)
rms_days = xs.rmse(targets_all,predictions_all, dim=horizontal, skipna=True)
slope_days = xs.linslope(targets_all,predictions_all, dim=horizontal, skipna=True)


## Plotting (Results)

In [None]:
# Daily and yearly metrics, one figure per metric
for daily_metric, yearly_metric, metric_title in (
        (r_days, r_years, 'Correlation Coefficients'),
        (rms_days, rms_years, 'Root Mean Square Errors'),
        (slope_days, slope_years, 'Slopes of the best fitting line')):
    plotting_criteria(dates, daily_metric, yearly_metric, metric_title)

# Daily maps for ten time steps drawn at random (without replacement)
maps = random.sample(range(len(targets_all.time_counter)),10)
for i in maps:

    day_targets = targets_all[i]
    day_predictions = predictions_all[i]

    # Only the grid points with a finite target enter the scatter plot
    idx = np.isfinite(np.ravel(day_targets))
    scatter_plot(np.ravel(day_targets)[idx], np.ravel(day_predictions)[idx], name + ' '+ str(day_targets.time_counter.dt.date.values))

    plotting_maps(day_targets, day_predictions, name, units)
    

## Regional analysis

In [None]:
bathy = xr.open_dataset('/home/sallen/MEOPAR/grid/bathymetry_202108.nc')

fig, ax = plt.subplots(1, 1, figsize=(5, 9))
mycmap = cm.deep
mycmap.set_bad('grey')
ax.pcolormesh(bathy['Bathymetry'], cmap=mycmap)
sa_vi.set_aspect(ax)

# Every region is [y_start, y_stop, x_start, x_stop] in grid indices: the
# first pair slices the y axis and the second pair the x axis of the gridded
# arrays in the regional analysis below.
SoG_north = [650, 730, 100, 200]
SoG_center = [450, 550, 200, 300]
Fraser_plume = [380, 460, 260, 330]
SoG_south = [320, 380, 280, 350]
Haro_Boundary = [290, 350, 210, 280]
JdF_west = [250, 425, 25, 125]
JdF_east = [200, 290, 150, 260]
PS_main = [20, 150, 200, 280]
PS_all = [0, 200, 80, 320]

boxnames = ['SoG_north','SoG_center','Fraser_plume','SoG_south', 'Haro_Boundary', 'JdF_west', 'JdF_east', 'PS_main', 'PS_all']
regional_boxes = [SoG_north, SoG_center, Fraser_plume, SoG_south, Haro_Boundary, JdF_west, JdF_east, PS_main, PS_all]

# One distinct colour per region so that every legend entry can be matched
# to its outline (three regions previously shared magenta).
box_colours = ['g', 'b', 'm', 'k', 'y', 'c', 'w', 'r', 'orange']

# Outlines are drawn in the order of `boxnames`, which is the order in which
# the legend labels are assigned to the lines
for corners, colour in zip(regional_boxes, box_colours):
    plot_box(ax, corners, colour)

fig.legend(boxnames)

# The whole domain goes first in the lists used by the regional analysis; it
# is not drawn on the map.
SS_all = [0, 898, 0, 398]
boxes = [SS_all] + regional_boxes
boxnames.insert(0,'SS_all')

# # Low resolution
# temp = []
# for i in boxes:

#     temp.append([x//5 for x in i])

# boxes = temp


In [None]:

# Years of the evaluation period; the metric tables hold one row per year
# and one column per region. The years are selected with boolean masks on
# `dates` instead of flattening a GroupBy object with np.ravel, which relied
# on NumPy building a ragged object array (an error in NumPy >= 1.24) and on
# there being exactly four years.
years = np.unique(dates.year)

r = np.zeros((len(years),len(boxnames)))
rms = np.zeros((len(years),len(boxnames)))

r_season = np.zeros((len(years),len(boxnames)))
rms_season = np.zeros((len(years),len(boxnames)))

# The seasonal cycle is computed on the training dates; its part from 2017
# onwards is subtracted from the evaluation series.
# NOTE(review): this assumes that the training years from 2017 onwards have
# the same number and layout of samples as the evaluation years -- confirm.
season_start = np.where(dates_season.year==2017)[0][0]

for i, (box, boxname) in enumerate(zip(boxes, boxnames)):

    y_start, y_stop, x_start, x_stop = box

    targets = targets_all[:, y_start:y_stop, x_start:x_stop]
    predictions = predictions_all[:, y_start:y_stop, x_start:x_stop]

    # Regional daily means
    mean_targets = targets.mean(dim=['x','y'], skipna=True)
    mean_predictions = predictions.mean(dim=['x','y'], skipna=True)
    plotting_mean_values(dates, mean_targets, mean_predictions, units, category, boxname)

    # Regional seasonal cycle from the training targets
    climatology = quant_train[:, y_start:y_stop, x_start:x_stop]
    mean_targets_clim = climatology.mean(dim=['x','y'], skipna=True)
    season = seasonality(dates_season,mean_targets_clim)
    season_evaluation = season[season_start:]

    targets_anomaly = mean_targets-season_evaluation
    predictions_anomaly = mean_predictions-season_evaluation

    plotting_mean_values(dates, targets_anomaly, predictions_anomaly, units, category, boxname+' (removed seasonality)')

    # Plain arrays so that each year can be selected with a boolean mask
    targets_values = np.asarray(mean_targets)
    predictions_values = np.asarray(mean_predictions)
    targets_anomaly_values = np.asarray(targets_anomaly)
    predictions_anomaly_values = np.asarray(predictions_anomaly)

    for k, year in enumerate(years):

        in_year = np.asarray(dates.year == year)

        r[k,i] = np.round(np.corrcoef(targets_values[in_year],predictions_values[in_year])[0][1],3)
        rms[k,i] = rmse(targets_values[in_year],predictions_values[in_year])

        r_season[k,i] = np.round(np.corrcoef(targets_anomaly_values[in_year],predictions_anomaly_values[in_year])[0][1],3)
        rms_season[k,i] = rmse(targets_anomaly_values[in_year],predictions_anomaly_values[in_year])

plotting_regional(r,boxnames,years, 'Correlation coefficients')
plotting_regional(rms,boxnames,years, 'Root mean square errors')

plotting_regional(r_season,boxnames,years, 'Correlation coefficients (removed seasonality)')
plotting_regional(rms_season,boxnames,years, 'Root mean square errors (removed seasonality)')


## Saving

In [40]:
# Output directory for this target variable
path = '/data/ibougoudis/MOAD/files/results/' + name + '/hist_normal_res/'
os.makedirs(path, exist_ok=True)

# Fitted model, LZMA-compressed
with lzma.open(path + 'regr.xz', 'wb') as model_file:
    dill.dump(regr, model_file)

# Training metrics together with the seasonal cycle and its dates
training_summary = [r_train,rms_train,slope_train,dates_season,season]
with open(path + 'metrics.pkl', 'wb') as metrics_file:
    dill.dump(training_summary, metrics_file)

# Gridded targets, predictions and their difference
for variable_name, field in (
        ('Targets', targets_all),
        ('Predictions', predictions_all),
        ('Targets - Predictions', (targets_all-predictions_all))):
    file_creation(path, field, variable_name)
