# Predicting the flagellate production rate with a histogram-based gradient boosting regression tree, using spatial means over oceanographic boxes

## Importing

In [None]:
import xarray as xr
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import xskillscore as xs

from sklearn.pipeline import make_pipeline
from sklearn.compose import TransformedTargetRegressor
from sklearn.compose import ColumnTransformer
from sklearn.preprocessing import KBinsDiscretizer
from sklearn.preprocessing import StandardScaler

from sklearn.ensemble import HistGradientBoostingRegressor
from sklearn.feature_selection import r_regression

from sklearn.metrics import root_mean_squared_error as rmse

import os
import lzma
import dill

from tqdm import tqdm

import random

import cmocean.cm as cm
import salishsea_tools.viz_tools as sa_vi


## Datasets Preparation

In [None]:
# Creation of the training - testing datasets
def datasets_preparation(dataset, name, inputs_names, regions, boxes):
    """Build per-box spatial means of the input features and of the target.

    The gridded fields are flattened time-major, points with a non-finite
    target or lying outside every box are dropped, and the remaining points
    are averaged per time step inside each box.

    Parameters
    ----------
    dataset : xarray.Dataset
        Must provide ``dataset[var].to_numpy()`` for ``name`` and for every
        entry of ``inputs_names`` (except ``'Day_of_year'``), plus the
        ``time_counter``, ``x`` and ``y`` coordinates. Variables are assumed
        to be laid out (time_counter, y, x) -- TODO confirm against the files.
    name : str
        Name of the target variable.
    inputs_names : list of str
        Input feature names. If the last entry is ``'Day_of_year'`` that
        feature is derived from ``time_counter`` instead of read from
        ``dataset``.
    regions : array-like, shape (y, x)
        Box index (0 .. len(boxes)-1) of each grid point, NaN outside boxes.
    boxes : sequence
        Box definitions; only ``len(boxes)`` is used here.

    Returns
    -------
    inputs_mean : ndarray, shape (n_time, n_boxes, n_features)
    targets_mean : ndarray, shape (n_time, n_boxes)
        Per-time-step, per-box means. A (time, box) cell without any valid
        point is NaN.
    indx : tuple of ndarray
        ``np.where`` result: kept positions in the flattened (time, y, x) grid.
    regions : ndarray
        Box index of each kept point.
    """
    n_time = len(dataset.time_counter)
    n_space = len(dataset.x) * len(dataset.y)

    if inputs_names and inputs_names[-1] == 'Day_of_year':
        field_names = inputs_names[:-1]
        # to_numpy() so np.repeat works on a plain array rather than a DataArray
        day_of_year = np.repeat(dataset.time_counter.dt.dayofyear.to_numpy(), n_space)
    else:
        field_names = inputs_names
        day_of_year = None

    inputs = [dataset[i].to_numpy().flatten() for i in field_names]
    if day_of_year is not None:
        inputs.append(day_of_year)
    # explicit 2-D shape also keeps the (0, N) case well-formed
    inputs = np.array(inputs).reshape(len(inputs_names), n_time * n_space)

    targets = dataset[name].to_numpy().ravel()
    regions = np.tile(np.ravel(regions), n_time)

    indx = np.where(np.isfinite(targets) & ~np.isnan(regions))
    inputs = inputs[:, indx[0]].transpose()
    targets = targets[indx[0]]
    regions = regions[indx[0]]

    # Flattening is time-major, so the time step of each kept point is its flat
    # index divided by the number of grid points per time step. Grouping by
    # this index stays correct even when the number of valid points in a box
    # differs between time steps (a plain reshape would misalign or fail).
    time_index = indx[0] // n_space

    n_boxes = len(boxes)
    targets_mean = np.zeros((n_time, n_boxes))
    inputs_mean = np.zeros((n_time, n_boxes, len(inputs_names)))

    for i in range(n_boxes):
        in_box = regions == i
        counts = np.bincount(time_index[in_box], minlength=n_time)

        targets_mean[:, i] = _per_time_mean(targets[in_box], counts)
        for j in range(len(inputs_names)):
            inputs_mean[:, i, j] = _per_time_mean(inputs[in_box, j], counts)

    return(inputs_mean, targets_mean, indx, regions)


def _per_time_mean(values, counts):
    """Mean of consecutive runs of ``values``; run k has ``counts[k]`` items (NaN if empty)."""
    runs = np.split(values, np.cumsum(counts)[:-1])
    return np.array([run.mean() if run.size else np.nan for run in runs])


## File Creation

In [None]:
def file_creation(path, variable, name):
    """Append ``variable`` as data variable ``name`` to ``<path>targets_predictions.nc``.

    The variable is stored zlib-compressed at level 9. ``path`` is joined by
    plain string concatenation, so it must end with a path separator.
    """
    compression = {name: {"zlib": True, "complevel": 9}}
    target_file = path + 'targets_predictions.nc'

    single_var_ds = variable.to_dataset(name=name)
    single_var_ds.to_netcdf(path=target_file, mode='a', encoding=compression)
    single_var_ds.close()
    

## Regressor

In [None]:
def regressor(inputs, targets, drivers, day_input, i, r_inputs):
    """Fit the gradient-boosting model of one box and record feature correlations.

    Parameters
    ----------
    inputs : ndarray, shape (n_samples, n_features)
        Driver columns first, followed by the day-of-year column when
        ``day_input`` is non-empty.
    targets : ndarray, shape (n_samples,)
    drivers : list of str
        Driver names; only their count is used (the leading columns that get
        standardised).
    day_input : list
        Empty (or any falsy value) when no day-of-year column is present.
    i : int
        Row of ``r_inputs`` to fill.
    r_inputs : ndarray
        Updated in place with the Pearson r of each input vs the target
        (rounded to 2 decimals) and also returned.

    Returns
    -------
    (regr, r_inputs) : fitted TransformedTargetRegressor and the updated array.
    """
    # Standardise the driver columns; any remaining column (day of year) is
    # passed through unchanged and ends up right after them, at len(drivers).
    preprocess = ColumnTransformer(
        transformers=[('drivers', StandardScaler(), np.arange(0, len(drivers)))],
        remainder='passthrough')

    # Only declare a categorical feature when a day-of-year column exists;
    # otherwise keep the estimator defaults untouched.
    hgb_kwargs = {'categorical_features': [len(drivers)]} if day_input else {}

    # The target is standardised too (inverse-transformed on predict).
    model = TransformedTargetRegressor(
        regressor=make_pipeline(preprocess, HistGradientBoostingRegressor(**hgb_kwargs)),
        transformer=StandardScaler())

    regr = model.fit(inputs, targets)

    r_inputs[i] = np.round(r_regression(inputs, targets), 2)

    return(regr, r_inputs)


## Scatter Plot

In [None]:
def scatter_plot(targets, predictions, name):
    """Plot predictions against targets as a scatter and as a 2-D histogram.

    Top panel: scatter plot. Bottom panel: 100x100-bin 2-D histogram with a
    log colour scale (bins with count below ``cmin=0.1``, i.e. empty bins,
    are not drawn). Both panels show the 1:1 line (dashed black) and the
    least-squares line predictions = m * targets + b (red).

    Parameters
    ----------
    targets, predictions : array-like
        Paired values, assumed 1-D and of equal length (required by
        np.polyfit and hist2d).
    name : str
        Figure title.

    Returns
    -------
    m : float
        Slope of the least-squares fit of predictions on targets.
    """

    # compute slope m and intercept b
    m, b = np.polyfit(targets, predictions, deg=1)

    fig, ax = plt.subplots(2, figsize=(5,10), layout='constrained')

    ax[0].scatter(targets,predictions, alpha = 0.2, s = 10)

    # Common square limits covering both autoscaled axes; taken after the
    # scatter so they reflect the data range. Reused for the histogram range.
    lims = [np.min([ax[0].get_xlim(), ax[0].get_ylim()]),
        np.max([ax[0].get_xlim(), ax[0].get_ylim()])]

    # plot fitted y = m*x + b
    ax[0].axline(xy1=(0, b), slope=m, color='r')

    ax[0].set_xlabel('targets')
    ax[0].set_ylabel('predictions')
    ax[0].set_xlim(lims)
    ax[0].set_ylim(lims)
    ax[0].set_aspect('equal')

    # 1:1 reference line
    ax[0].plot(lims, lims,linestyle = '--',color = 'k')

    h = ax[1].hist2d(targets,predictions, bins=100, cmap='jet', 
        range=[lims,lims], cmin=0.1, norm='log')
    
    # 1:1 reference line
    ax[1].plot(lims, lims,linestyle = '--',color = 'k')

    # plot fitted y = m*x + b
    ax[1].axline(xy1=(0, b), slope=m, color='r')

    ax[1].set_xlabel('targets')
    ax[1].set_ylabel('predictions')
    ax[1].set_aspect('equal')

    # h[3] is the image (QuadMesh) returned by hist2d
    fig.colorbar(h[3],ax=ax[1], location='bottom')

    fig.suptitle(name)

    plt.show()

    return(m)


## Plotting (Criteria)

In [None]:
def plotting_criteria(dates, variable, year_variable, months, period, title):
    """Scatter a daily criterion coloured by month and mark one value per year.

    Parameters
    ----------
    dates : sequence of datetimes
        Anything ``pd.DatetimeIndex`` accepts (DatetimeIndex, list, array).
    variable : array-like
        Daily values, aligned with ``dates``.
    year_variable : array-like
        One value per year. It is drawn as a red star at the mid-month date of
        the second month in the data (day of month = number of distinct days // 2).
    months : list of str
        Legend labels, one per month colour.
    period, title : str
        Joined into the figure title.
    """
    indx = pd.DatetimeIndex(dates)
    fig, ax = plt.subplots()

    scatter = ax.scatter(dates, variable, marker='.', c=indx.month)
    plt.xticks(rotation=70)
    ax.legend(handles=scatter.legend_elements()[0], labels=months)

    # Use the converted index throughout so that plain lists/arrays of dates
    # work too (they have no .day attribute and no boolean-mask indexing).
    mid_day = len(np.unique(indx.day)) // 2
    second_month = np.unique(indx.month)[1]
    anchor_dates = indx[(indx.month == second_month) & (indx.day == mid_day)]
    ax.plot(anchor_dates, year_variable, color='red', marker='*')
    fig.suptitle(title + ' ' + period)

    plt.show()


## Plotting (Mean Values)

In [None]:
def plotting_mean_values(dates, targets, predictions, mean, units, category, region, period, labels):
    """Show fit metrics and plot the target and prediction time series.

    Parameters
    ----------
    dates : pandas.DatetimeIndex
        Time stamps aligned with ``targets`` / ``predictions``.
    targets, predictions : 1-D array-like
    mean : float
        Normaliser for the RMSE (reported as a percentage of it).
    units, category, region, period : str
        Joined into the figure title as 'Mean <category> <units> <period> <region>'.
    labels : ndarray of str
        Day/month labels for one year of the series (len(labels) steps per year).

    Returns
    -------
    (r, rms, slope)
        Pearson r (3 decimals), RMSE in % of ``mean``, and the slope of the
        least-squares fit of predictions on targets (3 decimals).
    """
    r = np.round(np.corrcoef(predictions, targets)[0][1], 3)
    rms = rmse(predictions, targets) / mean * 100
    m, _ = np.polyfit(targets, predictions, deg=1)
    slope = np.round(m, 3)

    temp = pd.DataFrame(np.vstack((r, rms, slope)).transpose(), columns=['r', 'rms [%]', 'slope'])
    # NOTE: `display` is not imported here; presumably the Jupyter builtin.
    display(temp)

    years = np.unique(dates.year)

    fig, ax = plt.subplots(figsize=(19, 5))

    mean_targets = np.ma.array(targets)
    mean_predictions = np.ma.array(predictions)

    # Mask the last time step of each year so the lines break between years.
    for year in years:
        last = np.where(dates.year == year)[0][-1]
        mean_targets[last] = np.ma.masked
        mean_predictions[last] = np.ma.masked

    ax.plot(mean_targets, label='targets')
    ax.plot(mean_predictions, label='predictions')

    # Two day/month ticks per year. Cast with astype(int): int16 would silently
    # wrap around for series longer than 32767 steps.
    ticks = np.arange(0, len(years) * len(labels), len(labels) / 2).astype(int)
    labels2 = np.tile(labels, len(years))

    ax.set_xticks(ticks, labels2[ticks])

    # Second label row: the year at the start of each year.
    ax2 = ax.secondary_xaxis('bottom')
    ax2.set_xticks(ticks=np.arange(0, len(years) * len(labels), len(labels)), labels=years)

    ax2.tick_params(length=0, pad=30)

    plt.suptitle('Mean '+category + ' ' +units + ' ' + period + ' ' + region)
    ax.legend()
    plt.show()

    return(r, rms, slope)


## Plotting (Maps)

In [None]:
def plotting_maps(targets, predictions, name, units):
    """Map targets, predictions and their difference side by side.

    Parameters
    ----------
    targets, predictions : xarray.DataArray
        Fields plotted with DataArray.plot; presumably 2-D (y, x) for a single
        time step, since the title reads ``targets.time_counter`` -- TODO confirm.
    name, units : str
        Used for the colour-bar labels and the figure title.
    """

    fig, ax = plt.subplots(2,2, figsize = (10,15), layout='tight')

    # NaN (land) cells are drawn grey
    cmap = plt.get_cmap('cubehelix')
    cmap.set_bad('gray')

    # Predictions share the targets' colour range so the two maps are comparable
    targets.plot(ax=ax[0,0], cmap=cmap, vmin = targets.min(), vmax = targets.max(), cbar_kwargs={'label': name + ' ' + units})
    predictions.plot(ax=ax[0,1], cmap=cmap, vmin = targets.min(), vmax = targets.max(), cbar_kwargs={'label': name + ' ' + units})
    (targets-predictions).plot(ax=ax[1,0], cmap=cmap, cbar_kwargs={'label': name + ' ' + units})

    # NOTE(review): the figure uses layout='tight', which recomputes the subplot
    # spacing at draw time, so these values are probably overridden -- confirm
    # and keep only one of the two mechanisms.
    plt.subplots_adjust(left=0.1,
        bottom=0.1, 
        right=0.95, 
        top=0.95, 
        wspace=0.35, 
        hspace=0.35)

    sa_vi.set_aspect(ax[0,0])
    sa_vi.set_aspect(ax[0,1])
    sa_vi.set_aspect(ax[1,0])

    ax[0,0].title.set_text('Targets')
    ax[0,1].title.set_text('Predictions')
    ax[1,0].title.set_text('Targets-Predictions')
    # fourth panel is unused
    ax[1,1].axis('off')

    fig.suptitle(name + ' '+ str(targets.time_counter.dt.date.values))

    plt.show()
    

## Plotting (Regions)

In [None]:
def plot_box(ax, corn, colour):
    """Draw the outline of a grid box on ``ax``.

    ``corn`` is [y_start, y_end, x_start, x_end] in grid-index coordinates;
    the closed rectangle is drawn as a solid line in ``colour``.
    """
    y_lo, y_hi, x_lo, x_hi = corn[0], corn[1], corn[2], corn[3]
    outline_x = [x_lo, x_hi, x_hi, x_lo, x_lo]
    outline_y = [y_lo, y_lo, y_hi, y_hi, y_lo]
    ax.plot(outline_x, outline_y, '-', color=colour)
    

## Plotting (Mean Peaks)

In [None]:
def plotting_mean_peaks(dates, targets_mean, predictions_mean, category, units, region, boxname, labels,
                        period='(15 Feb - 30 Apr)'):
    """Plot the target and prediction series (e.g. above-threshold peaks) over several years.

    Parameters
    ----------
    dates : pandas.DatetimeIndex
        Time stamps aligned with the series (used to find year boundaries).
    targets_mean, predictions_mean : 1-D array-like
        Series of length len(years) * len(labels); NaNs are left as gaps.
    category, units, region, boxname : str
        Title components.
    labels : ndarray of str
        Day/month labels for one year of the series.
    period : str, optional
        Period shown in the title; the default reproduces the previously
        hard-coded text.
    """
    years = np.unique(dates.year)
    n_years, n_labels = len(years), len(labels)

    fig, ax = plt.subplots(figsize=(19, 5))

    targets_mean = np.ma.array(targets_mean)
    predictions_mean = np.ma.array(predictions_mean)

    # Mask the last time step of each year so the lines break between years.
    for year in years:
        last = np.where(dates.year == year)[0][-1]
        targets_mean[last] = np.ma.masked
        predictions_mean[last] = np.ma.masked

    ax.plot(targets_mean, label='targets')
    ax.plot(predictions_mean, label='predictions')

    # Same day/month offsets in every year, shifted by whole years (n_labels
    # steps). A single continuous stride would drift off the labels from the
    # second year on, and its tick count could differ from the label count.
    step = n_labels // n_years + 1
    offsets = np.arange(0, n_labels, step)
    ticks = (np.arange(n_years)[:, None] * n_labels + offsets).ravel()
    ax.set_xticks(ticks=ticks, labels=np.tile(labels[offsets], n_years))

    # Year labels at the start of each year (stride n_labels, not n_labels + 1).
    ax2 = ax.secondary_xaxis('bottom')
    ax2.set_xticks(ticks=np.arange(0, n_years * n_labels, n_labels), labels=years)

    ax2.tick_params(length=0, pad=30)

    plt.suptitle('Mean ' + category + ' ' + units + ' ' + period + ' ' + region + ' ' + boxname)
    ax.legend()
    plt.show()
    

## Plotting (Regional analysis)

In [None]:
def plotting_regional(metric,box,years,category):
    """Plot one metric against years, one line per box.

    Parameters
    ----------
    metric : ndarray, shape (len(years), len(box))
    box : sequence of str
        Box names, used as legend labels.
    years : array-like
    category : str
        Title prefix.
    """
    fig, ax = plt.subplots()

    for column, box_name in enumerate(box):
        ax.plot(years, metric[:, column], marker='*', label=box_name)

    fig.suptitle(category + ' (Regional analysis)')
    ax.legend()
    # plt.show() like the other plotting helpers; fig.show() warns on
    # non-interactive (e.g. inline) backends.
    plt.show()


## Post Processing

In [None]:
def post_processing(dates,dataset,targets,predictions,units,category,period,labels,boxname):
    """Evaluate one box: metrics on the raw series and on the deseasonalised series.

    Parameters
    ----------
    dates : pandas.DatetimeIndex
        Time stamps of ``dataset``.
    dataset : xarray.Dataset
        Only ``len(dataset.time_counter)`` is used.
    targets, predictions : 1-D ndarray
        Values ordered time-major; averaged per time step (a no-op when they
        already hold one value per time step).
    units, category, period, labels, boxname
        Forwarded to plotting_mean_values for titles and ticks.

    Returns
    -------
    (r, rms, slope, season, r_train_season, slope_train_season)
        ``season`` is the mean seasonal cycle over the years in ``dates``; it
        assumes every year has the same number of time steps.
    """
    n_time = len(dataset.time_counter)

    targets_mean = np.reshape(targets, (n_time, len(targets) // n_time)).mean(axis=1)
    predictions_mean = np.reshape(predictions, (n_time, len(predictions) // n_time)).mean(axis=1)

    # Pass region/period by keyword: positionally they were swapped relative
    # to plotting_mean_values' signature.
    r, rms, slope = plotting_mean_values(dates, targets_mean, predictions_mean, np.mean(targets_mean),
        units, category, region=boxname, period=period, labels=labels)

    n_years = len(np.unique(dates.year))

    # Long-term seasonal cycle: average the same time step across all years.
    season = np.mean(np.array(np.split(targets_mean, n_years, axis=0)), axis=0)

    season_train = np.tile(season, n_years) # Broadcasting season to all training years

    r_train_season, _, slope_train_season = plotting_mean_values(dates, targets_mean - season_train,
        predictions_mean - season_train, np.mean(targets_mean), units, category,
        region=boxname + ' (removed seasonality)', period=period, labels=labels)

    return(r, rms, slope, season, r_train_season, slope_train_season)


## Initialization

In [None]:
name = 'Flagellate_Production_Rate'
units = '[mmol N m-2 s-1]'
category = 'Production rates'

filename = '/data/ibougoudis/MOAD/files/inputs/jan_mar.nc'

drivers = ['Summation_of_solar_radiation', 'Mean_precipitation']
day_input = []  # e.g. ['Day_of_year'] to add the day of year as an extra (categorical) input

inputs_names = drivers + day_input

n_bins = 255

# Period key (input file stem) -> (title label, run id, months covered).
period_table = {
    'jan_mar': ('(16 Jan - 31 Mar)', '1', ['January', 'February', 'March']),  # 75 days, 1st period
    'jan_apr': ('(01 Jan - 30 Apr)', '2', ['January', 'February', 'March', 'April']),  # 120 days, 2nd period
    'feb_apr': ('(15 Feb - 30 Apr)', '3', ['February', 'March', 'April']),  # 75 days, 3rd period
    'apr_jun': ('(16 Apr - 30 Jun)', '4', ['April', 'May', 'June']),  # 76 days, 4th period
    'may_sep': ('(01 May - 30 Sep)', '5', ['May', 'June', 'July', 'August', 'September']),  # 153 days, 5th period
}

# Key from the file name itself rather than the fixed slice filename[35:42],
# which only worked for one particular directory.
period_key = os.path.splitext(os.path.basename(filename))[0]
if period_key not in period_table:
    raise ValueError('Unknown period file ' + repr(filename) + '; expected one of ' + str(list(period_table)))

# NOTE: `id` shadows the builtin but is kept because later cells refer to it.
period, id, months = period_table[period_key]

ds = xr.open_dataset(filename)


## Regions

In [None]:
# NOTE: bathy is not used in the visible cells.
bathy = xr.open_dataset('/home/sallen/MEOPAR/grid/bathymetry_202108.nc')

# Map of the first time step of the target, with the analysis boxes on top.
fig, ax = plt.subplots(1, 1, figsize=(5, 9))
mycmap = cm.deep
mycmap.set_bad('grey')
ax.pcolormesh(ds[name][0], cmap=mycmap)
sa_vi.set_aspect(ax)

# Boxes are [y_start, y_end, x_start, x_end] in full-resolution grid indices.
SoG_north = [650, 730, 100, 200]
plot_box(ax, SoG_north, 'g')
SoG_center = [450, 550, 200, 300]
plot_box(ax, SoG_center, 'b')
Fraser_plume = [380, 460, 260, 330]
plot_box(ax, Fraser_plume, 'm')
SoG_south = [320, 380, 280, 350]
plot_box(ax, SoG_south, 'k')
Haro_Boundary = [290, 350, 210, 280]
plot_box(ax, Haro_Boundary, 'm')
JdF_west = [250, 425, 25, 125]
plot_box(ax, JdF_west, 'c')
JdF_east = [200, 290, 150, 260]
plot_box(ax, JdF_east, 'w')
PS_all = [0, 200, 80, 320]
plot_box(ax, PS_all, 'm')
PS_main = [20, 150, 200, 280]
plot_box(ax, PS_main, 'r')

boxnames = ['SoG_north','SoG_center','Fraser_plume','SoG_south', 'Haro_Boundary', 'JdF_west', 'JdF_east', 'PS_all', 'PS_main']
# Pass the box outlines as explicit handles. With labels only, the legend
# collects legend-able artists automatically (which can include the
# pcolormesh), so the names could land on the wrong boxes.
fig.legend(ax.get_lines(), boxnames)

boxes = [SoG_north,SoG_center,Fraser_plume,SoG_south,Haro_Boundary,JdF_west,JdF_east,PS_all,PS_main]

# Label every grid cell with the index of the box it belongs to (NaN outside).
# Later boxes overwrite earlier ones where they overlap.
# NOTE(review): PS_main lies entirely inside PS_all and Fraser_plume overlaps
# SoG_center, so the overlapping cells count only towards the later box --
# confirm this is intended.
regions0 = np.full((len(ds.y),len(ds.x)),np.nan)

for i in range (0, len(boxes)):
    regions0[boxes[i][0]:boxes[i][1], boxes[i][2]:boxes[i][3]] = i

regions0 = xr.DataArray(regions0,dims = ['y','x'])

# Low resolution (box corners divided by the subsampling factor 5; downstream
# code only uses len(boxes))

temp = []

for i in boxes:
    temp.append([x//5 for x in i])

boxes = temp


## Training

In [None]:
# Low resolution: keep every 5th grid point in y and x (data and region labels alike)

ds = ds.isel(y=(np.arange(ds.y[0], ds.y[-1], 5)), 
    x=(np.arange(ds.x[0], ds.x[-1], 5)))

regions0 = regions0.isel(y=(np.arange(regions0.y[0], regions0.y[-1], 5)), 
    x=(np.arange(regions0.x[0], regions0.x[-1], 5)))

# Training years
dataset = ds.sel(time_counter = slice('2007', '2020'))

# Day/month tick labels for one year, in calendar order. Parse them with an
# explicit leap year: with the default year (1900, not a leap year) a
# '29 Feb' label would make to_datetime fail.
labels = np.unique(dataset.time_counter.dt.strftime('%d %b'))
indx_labels = np.argsort(pd.to_datetime(['2000 ' + label for label in labels], format='%Y %d %b'))
labels = labels[indx_labels]

inputs, targets, indx, regions = datasets_preparation(dataset, name, inputs_names, regions0, boxes)

# One model per box, fitted on that box's mean time series
r_inputs = np.zeros((len(boxnames), len(inputs_names)))
regr_all = []
predictions = np.full(targets.shape,np.nan) # size of targets without nans

for i in range(0, len(boxes)):

    regr, r_inputs = regressor(inputs[:,i,:], targets[:,i], drivers, day_input, i, r_inputs)

    regr_all.append(regr)
    
    predictions[:,i] = regr.predict(inputs[:,i,:])

print('Metrics between input features and '+name)
temp = pd.DataFrame(r_inputs, index=boxnames, columns=inputs_names)
display(temp)


## Time-series (Training)

In [None]:
# Time axis of the training set (used for year boundaries and seasonality)
dates = pd.DatetimeIndex(dataset['time_counter'].values)

# Per-box metrics on the raw training series
r_train = np.zeros(len(boxes))
rms_train = np.zeros(len(boxes))
slope_train = np.zeros(len(boxes))

# Long-term seasonal cycle of each box: one value per day/month label
season = np.zeros((len(boxes),len(labels)))

# Per-box metrics after removing the seasonal cycle
r_train_season = np.zeros(len(boxes))
slope_train_season = np.zeros(len(boxes))

# Peak threshold and standard deviations per box (used by the testing cells)
peak = np.zeros(len(boxes))
std_targets = np.zeros(len(boxes))
std_season = np.zeros(len(boxes))
std_predictions = np.zeros(len(boxes))

for i in range(0, len(boxes)):

    r_train[i], rms_train[i], slope_train[i], season[i], r_train_season[i], slope_train_season[i] = post_processing(
        dates,dataset,targets[:,i],predictions[:,i],units,category,period,labels,boxnames[i])

    mean = np.mean(targets[:,i])
    std_targets[i] = np.std(targets[:,i])
    # Peak threshold = mean + k * std with k = 0, i.e. just the training mean
    # (k presumably a tunable factor -- confirm)
    peak[i] = mean + 0*std_targets[i]
    
    std_season[i] = np.std(season[i])
    std_predictions[i] = np.std(predictions[:,i])

# Seasonal cycles of all boxes on one figure
plt.plot(season.transpose())
plt.legend(boxnames)
plt.suptitle('Long-term seasonalities (2007-2020)')
plt.show()


## Testing Years

In [None]:
# Testing years, evaluated with the per-box models fitted on 2007-2020
dataset = ds.sel(time_counter = slice('2021', '2024'))

dates = pd.DatetimeIndex(dataset['time_counter'].values)
years = np.unique(dataset.time_counter.dt.year)

# Repeat each box's training seasonal cycle once per testing year along the
# last axis -> shape (n_boxes, len(labels) * n_years)
season_test = np.tile(season,len(np.unique(dates.year))) # Broadcasting season to all testing years

inputs, targets, indx, regions  = datasets_preparation(dataset, name, inputs_names, regions0, boxes)

predictions = np.full(targets.shape,np.nan) # size of targets without nans

std_targets_test = np.zeros(len(boxes))
std_predictions_test = np.zeros(len(boxes))

for i in range (0, len(boxes)):

    # regr_all[i] is the model trained on box i
    predictions[:,i] = regr_all[i].predict(inputs[:,i,:]) 

    std_targets_test[i] = np.nanstd(targets[:,i])
    std_predictions_test[i] = np.nanstd(predictions[:,i])


## Standard deviations

In [None]:
# Per-box standard deviations of the training / testing targets and predictions,
# of the seasonal cycle, and their differences with the seasonal cycle
plt.plot(std_targets)
plt.plot(std_season)
plt.plot(std_targets_test)
plt.plot(std_targets-std_season)
plt.plot(std_targets_test-std_season)

plt.plot(std_predictions)
plt.plot(std_predictions_test)
plt.plot(std_predictions-std_season)
plt.plot(std_predictions_test-std_season)

plt.xticks(ticks=np.arange(0,len(boxes)), labels=boxnames)
plt.xticks(rotation=45)
plt.legend(('targets','season','targets_test','targets-season','targets_test-season','predictions','predictions_test','predictions-season','predictions_test-season'))
plt.suptitle('Standard Deviations for training-seasonality-testing')
plt.show()

# (std of testing targets - std of seasonal cycle) as % of the testing-target std
plt.plot((std_targets_test-std_season)*100/std_targets_test)
plt.xticks(ticks=np.arange(0,len(boxes)), labels=boxnames)
plt.xticks(rotation=45)
plt.suptitle('Percentage of difference between testing and seasonal standard deviations')
plt.show()


## Time-series (Testing)

In [None]:
# Per-box testing metrics.
# NOTE: r_test_season / slope_test_season are only filled by the
# plotting_mean_values call commented out below, so they stay zero.
r_test, rms_test, slope_test = np.zeros(len(boxes)),  np.zeros(len(boxes)), np.zeros(len(boxes))

r_test_season, slope_test_season = np.zeros(len(boxes)), np.zeros(len(boxes))

# Per box and testing year: sum / mean of the anomalies (series minus seasonal cycle)
targets_sum, predictions_sum  = np.zeros((len(boxes),len(np.unique(dates.year)))), np.zeros((len(boxes),len(np.unique(dates.year))))

targets_mean, predictions_mean = np.zeros((len(boxes),len(np.unique(dates.year)))), np.zeros((len(boxes),len(np.unique(dates.year))))

# Per box, time step within the period and testing year: values above the
# box's peak threshold, NaN elsewhere
targets_diff, predictions_diff = np.zeros((len(boxes),len(season[0]),len(np.unique(dates.year)))), np.zeros((len(boxes),len(season[0]),len(np.unique(dates.year))))

rss = np.zeros(len(boxes))

for i in range (0,len(boxes)):

    # r (3 decimals), RMSE as % of the mean target, slope of predictions vs targets
    r_test[i] = np.round(np.corrcoef(predictions[:,i],targets[:,i])[0][1],3)
    rms_test[i] = rmse(targets[:,i],predictions[:,i]) / np.mean(targets[:,i]) * 100
    m,_ = np.polyfit(targets[:,i], predictions[:,i], deg=1)
    slope_test[i]  = np.round(m,3)

    # Sum of squared residuals as % of the mean target (author's note: similar to rms, not affected by the seasonality)
    rss[i] = np.sum((np.ravel(targets[:,i])-np.ravel(predictions[:,i]))**2) / np.mean(np.ravel(targets[:,i])) *100

    # Reshape to (year, time step); assumes every testing year has len(season[0]) time steps
    targets_an = np.reshape(targets[:,i]-season_test[i],(len(years),len(season[0]))) # For targets_sum, targets_mean
    predictions_an = np.reshape(predictions[:,i]-season_test[i],(len(years),len(season[0]))) # For predictions_sum, predictions_mean

    targets_an2 = np.reshape(targets[:,i],(len(years),len(season[0]))) # For targets_diff
    predictions_an2 = np.reshape(predictions[:,i],(len(years),len(season[0]))) # For predictions_diff

    for j in range (0, len(years)):

        targets_sum[i,j] = np.sum(targets_an[j])
        predictions_sum[i,j] = np.sum(predictions_an[j])

        targets_mean[i,j] = np.mean(targets_an[j])
        predictions_mean[i,j] = np.mean(predictions_an[j])

        # Keep only values above the training-derived peak threshold
        targets_diff[i,:,j] = np.where(targets_an2[j]>peak[i],targets_an2[j], np.nan)
        predictions_diff[i,:,j] = np.where(predictions_an2[j]>peak[i],predictions_an2[j], np.nan)

    # r_test_season[i], _, slope_test_season[i] = plotting_mean_values(dates, targets[:,i]-season_test[i], predictions[:,i]-season_test[i], targets[:,i].mean(),
    #     units, category, boxnames[i] +' (removed seasonality)', period, labels)

    # Column-major flattening of (time step, year) -> years concatenated in chronological order
    targets_diff2 = np.reshape(targets_diff[i],(len(season[0])*len(np.unique(dates.year))), order='F')
    predictions_diff2 = np.reshape(predictions_diff[i],(len(season[0])*len(np.unique(dates.year))), order='F')

    plotting_mean_peaks(dates,targets_diff2,predictions_diff2,category,units,'Peaks',boxnames[i],labels)
        

## Saving

In [None]:
# Saving of the fitted models, metrics and a readme (currently disabled).
# NOTE(review): 'regr_all.xz' receives `regr`, i.e. only the last box's model;
# `regr_all` holds the models of all boxes -- probably what was intended.
# NOTE(review): targets_all / predictions_all are not defined in the visible
# cells, so the file_creation calls would fail if re-enabled as is.
# path = '/data/ibougoudis/MOAD/files/results/' + name + '/single_runs/' + name[0:4].lower() + '_pr_hist' + id + '_boxes_s4/'

# os.makedirs(path, exist_ok=True)
# with lzma.open(path + 'regr_all.xz', 'wb') as f:   
#     dill.dump(regr, f)

# with open(path + 'r_inputs.pkl', 'wb') as f:
#     dill.dump(r_inputs, f)

# with open(path + 'train_metrics.pkl', 'wb') as f:
#     dill.dump([r_train,rms_train,slope_train,r_train_season,slope_train_season,season.transpose()], f)

# with open(path + 'test_metrics.pkl', 'wb') as f:
#     dill.dump([r_test,rms_test,slope_test,r_test_season,slope_test_season,targets_sum,predictions_sum,targets_mean,predictions_mean,targets_diff,predictions_diff,rss], f)

# file_creation(path, targets_all, 'Targets')
# file_creation(path, predictions_all, 'Predictions')
# file_creation(path, (targets_all-predictions_all), 'Targets - Predictions')

# with open(path + 'readme.txt', 'w') as f:
#     f.write ('name: ' + name)
#     f.write('\n')
#     f.write('period: ' + filename[35:42])
#     f.write ('\n')
#     f.write ('input_features: ')
#     f.write (str([i for i in inputs_names]))
#     f.write ('\n')
#     f.write('n_bins: ' + str(n_bins))
#     f.write ('\n')
