# Bootstrap (Diatom Production Rate)

## Importing

In [None]:
import xarray as xr
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import xskillscore as xs

from sklearn.pipeline import make_pipeline
from sklearn.compose import TransformedTargetRegressor
from sklearn.compose import ColumnTransformer
from sklearn.preprocessing import KBinsDiscretizer
from sklearn.preprocessing import StandardScaler

from sklearn.ensemble import HistGradientBoostingRegressor

from sklearn.metrics import root_mean_squared_error as rmse

import os
import dill

from tqdm import tqdm

import cmocean.cm as cm
import salishsea_tools.viz_tools as sa_vi

from sklearn.utils import resample


## Datasets Preparation

In [None]:
# Creation of the training - testing datasets
def datasets_preparation(dataset, name, inputs_names):
    """Flatten a gridded dataset into a (samples, features) matrix for regression.

    Every (time_counter, y, x) cell becomes one sample. A cell is kept only when
    its target value is finite and it passes the spatial mask:
    ``x > 10`` and (``x > 100`` or ``y < 880``). The mask uses the coordinate
    values of ``dataset.x`` / ``dataset.y``.

    Parameters
    ----------
    dataset : dataset-like
        Needs ``x``, ``y``, ``time_counter`` and item access returning objects
        with ``to_numpy()``.
    name : str
        Name of the target variable.
    inputs_names : list of str
        Feature variable names. If the last entry is ``'Day_of_year'``, that
        feature is built from ``time_counter.dt.dayofyear`` instead of being
        read from the dataset.

    Returns
    -------
    inputs : ndarray, shape (n_kept, n_features)
    targets : ndarray, shape (n_kept,)
    indx : tuple
        Output of ``np.where`` with the kept flat (time, y, x) indices in
        ``indx[0]``.
    """
    n_time = len(dataset.time_counter)
    n_y = len(dataset.y)
    n_x = len(dataset.x)

    # Coordinates of every flattened (time, y, x) cell; x varies fastest
    x_flat = np.tile(dataset.x, n_time * n_y)
    y_flat = np.tile(np.repeat(dataset.y, n_x), n_time)

    uses_doy = inputs_names[-1] == 'Day_of_year'
    gridded_names = inputs_names[:-1] if uses_doy else inputs_names

    columns = [dataset[var].to_numpy().flatten() for var in gridded_names]
    if uses_doy:
        # One day-of-year value per time step, repeated over all grid cells
        columns.append(np.repeat(dataset.time_counter.dt.dayofyear, n_x * n_y))
    features = np.array(columns)

    target_flat = np.ravel(dataset[name])

    keep = np.where(np.isfinite(target_flat) & (x_flat > 10) & ((x_flat > 100) | (y_flat < 880)))

    return (features[:, keep[0]].transpose(), target_flat[keep[0]], keep)


## Datasets Preparation 2 (Regridding Flat Vectors)

In [None]:
def datasets_preparation2(dataset,variable,indx):
    """Scatter a flat vector of kept-cell values back onto the full grid.

    Parameters
    ----------
    dataset : xarray.Dataset
        Supplies the ``time_counter``, ``y`` and ``x`` coordinates.
    variable : array-like
        One value per kept cell, in the order of ``indx[0]``.
    indx : tuple
        ``np.where`` result from ``datasets_preparation``; ``indx[0]`` holds flat
        (time, y, x) indices.

    Returns
    -------
    xarray.DataArray
        Dims (time_counter, y, x); cells not in ``indx[0]`` are NaN.
    """
    n_time, n_y, n_x = len(dataset.time_counter), len(dataset.y), len(dataset.x)

    gridded = np.full(n_time * n_y * n_x, np.nan)
    gridded[indx[0]] = variable

    return xr.DataArray(
        gridded.reshape(n_time, n_y, n_x),
        coords={'time_counter': dataset.time_counter, 'y': dataset.y, 'x': dataset.x},
        dims=['time_counter', 'y', 'x'],
    )


## Regressor

In [None]:
def regressor (inputs, targets, n_bins, drivers, spatial, inputs_names):
    """Build and fit the gradient-boosting model used for every training run.

    Pipeline:

    * the first ``len(drivers)`` columns are standardised;
    * if ``spatial`` is non-empty, its columns (the contiguous span of
      ``inputs_names`` from ``spatial[0]`` to ``spatial[-1]``) are discretised
      into ``n_bins`` quantile bins;
    * remaining columns are passed through;
    * a ``HistGradientBoostingRegressor`` is fitted on the result, and the target
      is standardised via ``TransformedTargetRegressor``.

    ``ColumnTransformer`` outputs transformer columns first (drivers, then
    spatial), then the passthrough remainder. The categorical indices below
    therefore assume ``inputs_names`` is laid out as drivers, then spatial,
    then any extra columns.

    Parameters
    ----------
    inputs : array-like, shape (n_samples, n_features)
    targets : array-like, shape (n_samples,)
    n_bins : int
        Number of quantile bins for each spatial column.
    drivers : list of str
        Driver feature names (the leading columns of ``inputs``).
    spatial : list of str or empty/None
        Spatial feature names; falsy means "no spatial features".
    inputs_names : list of str
        Names of all columns of ``inputs``, in order.

    Returns
    -------
    TransformedTargetRegressor
        The fitted model.
    """
    driver_cols = np.arange(0, len(drivers))

    if not spatial:
        column_transformer = ColumnTransformer(
            transformers=[('drivers', StandardScaler(), driver_cols)],
            remainder='passthrough')
        # The first column after the drivers (e.g. Day_of_year) is treated as
        # categorical; with no such column, nothing is categorical.
        categorical = [len(drivers)] if len(inputs_names) > len(drivers) else None
    else:
        first_spatial = inputs_names.index(spatial[0])
        last_spatial = inputs_names.index(spatial[-1])
        column_transformer = ColumnTransformer(
            transformers=[
                ('drivers', StandardScaler(), driver_cols),
                ('spatial',
                 KBinsDiscretizer(n_bins=n_bins, encode='ordinal', strategy='quantile'),
                 np.arange(first_spatial, last_spatial + 1))],
            remainder='passthrough')
        # Binned spatial columns and everything after them are categorical
        categorical = np.arange(first_spatial, len(inputs_names))

    model = TransformedTargetRegressor(
        regressor=make_pipeline(
            column_transformer,
            HistGradientBoostingRegressor(categorical_features=categorical)),
        transformer=StandardScaler())

    regr = model.fit(inputs, targets)

    return(regr)


## Plotting (regions)

In [None]:
def plot_box(ax, corn, colour):
    """Outline a rectangular region on ``ax`` with a solid line.

    ``corn`` is ``[y_start, y_end, x_start, x_end]``; the outline is drawn as a
    closed five-point path in (x, y) axes coordinates.
    """
    y_lo, y_hi, x_lo, x_hi = corn[0], corn[1], corn[2], corn[3]
    edge_x = [x_lo, x_hi, x_hi, x_lo, x_lo]
    edge_y = [y_lo, y_lo, y_hi, y_hi, y_lo]
    ax.plot(edge_x, edge_y, '-', color=colour)
    

## Plotting (histograms)

In [None]:
def plot_hist (variable,name,boxnames):
    """Plot one histogram per region box, side by side.

    Parameters
    ----------
    variable : ndarray, shape (n_samples, n_boxes)
        Column ``j`` holds the values for ``boxnames[j]``. Non-finite values
        are dropped before histogramming.
    name : str
        Figure title.
    boxnames : list of str
        Subplot titles, one per column of ``variable``.
    """
    # squeeze=False keeps axs 2-D, so a single box still indexes as axs[0, 0]
    fig, axs = plt.subplots(1, len(boxnames), figsize=(20, 6), layout='constrained', squeeze=False)
    fig.suptitle(name)

    for j, boxname in enumerate(boxnames):
        ax = axs[0, j]
        values = np.asarray(variable[:, j])
        # Histogramming NaN values makes numpy's automatic binning raise.
        ax.hist(values[np.isfinite(values)])
        ax.set_title(boxname)
        ax.set_ylabel('Frequency')
        

## Training - Testing

In [None]:
def train_test(dataset,inputs,targets,indx, dataset_test,inputs_test,indx_test, n_bins, drivers, spatial, inputs_names):
    """Fit the model on the training set and predict on training and test sets.

    Returns
    -------
    predictions : ndarray
        Flat in-sample predictions (same order as ``targets``).
    predictions_all : xarray.DataArray
        In-sample predictions regridded to (time_counter, y, x).
    predictions_tests_all : xarray.DataArray
        Test-period predictions regridded to (time_counter, y, x).
    season : ndarray
        Domain-mean target per day of the period, averaged over training years.
        Assumes the same number of kept cells at every time step and the same
        number of time steps in every year; the reshape and split fail otherwise.
    """
    regr = regressor(inputs, targets, n_bins, drivers, spatial, inputs_names)

    # In-sample predictions, flat and regridded
    predictions = regr.predict(inputs)
    predictions_all = datasets_preparation2(dataset, predictions, indx)

    # Out-of-sample predictions for the test period, regridded
    predictions_tests_all = datasets_preparation2(dataset_test, regr.predict(inputs_test), indx_test)

    n_times = len(dataset.time_counter)
    n_years = len(np.unique(dataset.time_counter.dt.year))

    # Spatial mean of the targets at each time step (rows = time, cols = kept cells)
    domain_mean = np.reshape(targets, (n_times, len(indx[0]) // n_times)).mean(axis=1)

    # Split the series into equal per-year chunks and average them day by day
    season = np.array(np.split(domain_mean, n_years, axis=0)).mean(axis=0)

    return(predictions,predictions_all,predictions_tests_all,season)


## Metrics (Training)

In [None]:
def metrics_train (dataset,targets_all,predictions_all,boxes,regions0,season):
    """Skill metrics of the training-period fit, one value per region box.

    For box ``i`` the target and prediction fields are averaged over the cells
    where ``regions0 == i``, giving two time series. On those series the
    function computes:

    * Pearson correlation (``xr.corr``);
    * RMSE as a percentage of the mean target (``xs.rmse``);
    * slope of predictions vs targets (``xs.linslope``);
    * correlation and slope after removing a seasonal cycle, rounded to 3 decimals.

    The seasonal cycle is the per-day mean of ``targets_all`` over all training
    years and all cells of the rectangular box slice, repeated once per year.
    The series use only cells labelled ``i``, while the cycle uses the whole box
    slice. For a box partly overwritten by a later box in ``regions0``, the two
    footprints therefore differ.

    Parameters
    ----------
    dataset : xarray.Dataset
        Training dataset; only the years of ``time_counter`` are used.
    targets_all, predictions_all : xarray.DataArray
        (time_counter, y, x) target and predicted fields.
    boxes : list of [y_start, y_end, x_start, x_end]
        Box corners in grid indices of ``targets_all``.
    regions0 : xarray.DataArray
        (y, x) region id per cell.
    season : sequence
        Only its length (time steps per year) is used.

    Returns
    -------
    tuple of five 1-D arrays, one value per box:
        r, rms (%), slope, r without seasonality, slope without seasonality.
    """
    n_boxes = len(boxes)
    n_years = len(np.unique(dataset.time_counter.dt.year))
    n_days = len(season)

    r_train, rms_train, slope_train, r_train_season, slope_train_season = (
        np.full(n_boxes, np.nan) for _ in range(5))

    for k, box in enumerate(boxes):
        in_region = regions0 == k
        obs = targets_all.where(in_region).mean(['y', 'x'])
        mod = predictions_all.where(in_region).mean(['y', 'x'])

        r_train[k] = xr.corr(obs, mod)
        rms_train[k] = xs.rmse(obs, mod, skipna=True) / np.mean(obs) * 100
        slope_train[k] = xs.linslope(obs, mod, skipna=True)

        # Fortran-order reshape maps year-major time onto (day, year, cell)
        box_field = targets_all[:, box[0]:box[1], box[2]:box[3]]
        per_day = np.reshape(box_field.to_numpy(),
                             (n_days, n_years, box_field.shape[1] * box_field.shape[2]),
                             order='F')
        # Broadcasting season to all training years
        clim = np.tile(np.nanmean(per_day, axis=(1, 2)), n_years)

        obs_anom = obs - clim
        mod_anom = mod - clim
        r_train_season[k] = np.round(np.corrcoef(obs_anom, mod_anom)[0][1], 3)
        slope, _ = np.polyfit(obs_anom, mod_anom, deg=1)
        slope_train_season[k] = np.round(slope, 3)

    return (r_train,rms_train,slope_train,r_train_season,slope_train_season)


## Metrics (Testing)

In [None]:
def metrics_test (dataset_test,targets_test_all,predictions_test_all, boxes,regions0,season, dataset,targets_all):
    """Skill metrics of the test-period predictions, one value per region box.

    For box ``i`` the test target and prediction fields are averaged over the
    cells where ``regions0 == i``, giving two time series. On those series the
    function computes:

    * Pearson correlation, RMSE as a percentage of the mean target, and slope;
    * correlation and slope after removing the *training* seasonal cycle,
      rounded to 3 decimals. The cycle is the per-day mean of ``targets_all``
      over the training years and the rectangular box slice, tiled over the
      test years;
    * ``rms_test_s``: the root-mean-square, over test years, of the difference
      between annually summed target and prediction anomalies.

    Parameters
    ----------
    dataset_test : xarray.Dataset
        Test dataset; only the years of ``time_counter`` are used.
    targets_test_all, predictions_test_all : xarray.DataArray
        (time_counter, y, x) test target and predicted fields.
    boxes : list of [y_start, y_end, x_start, x_end]
        Box corners in grid indices.
    regions0 : xarray.DataArray
        (y, x) region id per cell.
    season : sequence
        Only its length (time steps per year) is used.
    dataset : xarray.Dataset
        Training dataset; only the years of ``time_counter`` are used.
    targets_all : xarray.DataArray
        (time_counter, y, x) training targets, used for the seasonal cycle.

    Returns
    -------
    tuple of six 1-D arrays, one value per box:
        r, rms (%), slope, r without seasonality, slope without seasonality,
        RMS of annual anomaly sums.
    """
    n_boxes = len(boxes)
    n_train_years = len(np.unique(dataset.time_counter.dt.year))
    n_test_years = len(np.unique(dataset_test.time_counter.dt.year))

    r_test = np.full(n_boxes, np.nan)
    rms_test = np.full(n_boxes, np.nan)
    slope_test = np.full(n_boxes, np.nan)

    r_test_season = np.full(n_boxes, np.nan)
    slope_test_season = np.full(n_boxes, np.nan)

    targets_sum = np.full((n_boxes, n_test_years), np.nan)
    predictions_sum = np.full((n_boxes, n_test_years), np.nan)

    rms_test_s = np.full(n_boxes, np.nan)

    for i, box in enumerate(boxes):

        targets = targets_test_all.where(regions0==i).mean(['y','x'])
        predictions = predictions_test_all.where(regions0==i).mean(['y','x'])

        r_test[i] = xr.corr(targets,predictions)
        rms_test[i] = xs.rmse(targets,predictions,skipna=True) / np.mean(targets) * 100
        slope_test[i] = xs.linslope(targets,predictions,skipna=True)

        # Seasonal cycle from the TRAINING targets. The Fortran-order reshape
        # maps year-major time onto (day, year, cell).
        climatology = targets_all[:, box[0]:box[1], box[2]:box[3]]
        season_test = np.reshape(climatology.to_numpy(),
                                 (len(season), n_train_years, climatology.shape[1]*climatology.shape[2]),
                                 order='F')
        # Broadcasting season to all testing years
        season_test = np.tile(np.nanmean(season_test, axis=(1, 2)), n_test_years)

        targets_anom = targets - season_test
        predictions_anom = predictions - season_test

        r_test_season[i] = np.round(np.corrcoef(targets_anom, predictions_anom)[0][1], 3)
        m, _ = np.polyfit(targets_anom, predictions_anom, deg=1)
        slope_test_season[i] = np.round(m, 3)

        targets_sum[i] = targets_anom.groupby(targets.time_counter.dt.year).sum().values
        predictions_sum[i] = predictions_anom.groupby(predictions.time_counter.dt.year).sum().values

        # RMS over the actual number of test years. The previous sqrt(sum)/2 was
        # only equivalent for exactly 4 test years.
        rms_test_s[i] = np.sqrt(np.mean((targets_sum[i] - predictions_sum[i])**2))

    return (r_test,rms_test,slope_test,r_test_season,slope_test_season,rms_test_s)


## Initiation

In [None]:
name = 'Diatom_Production_Rate'
units = '[mmol N m-2 s-1]'
category = 'Production rates'

filename = '/data/ibougoudis/MOAD/files/inputs/jan_mar.nc'

drivers = ['Summation_of_solar_radiation', 'Mean_air_temperature', 'Mean_wind_speed', 'Mean_precipitation']
spatial = ['Latitude', 'Longitude']

inputs_names = drivers + spatial

n_bins=255

# Analysis periods, keyed by the input file's stem (e.g. 'jan_mar' for .../jan_mar.nc):
# key -> (period label, period id, months covered)
periods = {
    'jan_mar': ('(16 Jan - 31 Mar)', '1', ['January', 'February', 'March']),  # 75 days, 1st period
    'jan_apr': ('(01 Jan - 30 Apr)', '2', ['January', 'February', 'March', 'April']),  # 120 days, 2nd period
    'feb_apr': ('(15 Feb - 30 Apr)', '3', ['February', 'March', 'April']),  # 75 days, 3rd period
    'apr_jun': ('(16 Apr - 30 Jun)', '4', ['April', 'May', 'June']),  # 76 days, 4th period
    'may_sep': ('(01 May - 30 Sep)', '5', ['May', 'June', 'July', 'August', 'September']),  # 153 days, 5th period
}

period_key = os.path.splitext(os.path.basename(filename))[0]
if period_key not in periods:
    # Fail now rather than after the bootstrap, when `id` is first needed for the output path
    raise ValueError(f'Unrecognised input period {period_key!r} in {filename!r}; '
                     f'expected one of {sorted(periods)}')

# NOTE: `id` shadows the builtin; the name is kept because the saving cell builds
# the output path from it.
period, id, months = periods[period_key]
months = list(months)

ds = xr.open_dataset(filename)


## Regions

In [None]:
fig, ax = plt.subplots(1, 1, figsize=(5, 9))
# Work on a copy so set_bad does not modify cmocean's shared cm.deep colormap
mycmap = cm.deep.copy()
mycmap.set_bad('grey')
ax.pcolormesh(ds[name][0], cmap=mycmap)
sa_vi.set_aspect(ax)

# Region boxes as [y_start, y_end, x_start, x_end] in full-resolution grid indices
SoG_north = [650, 730, 100, 200]
SoG_center = [450, 550, 200, 300]
Fraser_plume = [380, 460, 260, 330]
SoG_south = [320, 380, 280, 350]
Haro_Boundary = [290, 350, 210, 280]
JdF_west = [250, 425, 25, 125]
JdF_east = [200, 290, 150, 260]
PS_all = [0, 200, 80, 320]
PS_main = [20, 150, 200, 280]

boxnames = ['SoG_north','SoG_center','Fraser_plume','SoG_south', 'Haro_Boundary', 'JdF_west', 'JdF_east', 'PS_all', 'PS_main']
boxes = [SoG_north,SoG_center,Fraser_plume,SoG_south,Haro_Boundary,JdF_west,JdF_east,PS_all,PS_main]
box_colours = ['g', 'b', 'm', 'k', 'm', 'c', 'w', 'm', 'r']

# Outlines are drawn in `boxes` order so the legend labels line up with them
for box, colour in zip(boxes, box_colours):
    plot_box(ax, box, colour)

fig.legend(boxnames)

# Region id of every full-resolution cell (NaN outside all boxes). Boxes are
# painted in list order, so where they overlap (PS_main lies inside PS_all)
# the later box wins.
regions0 = np.full((len(ds.y),len(ds.x)),np.nan)

for region_id, box in enumerate(boxes):
    regions0[box[0]:box[1], box[2]:box[3]] = region_id

regions0 = xr.DataArray(regions0,dims = ['y','x'])

# Low resolution: rescale box corners to the grid subsampled every 5 cells
boxes = [[corner // 5 for corner in box] for box in boxes]

## Datasets

In [None]:
# Low resolution: keep every 5th grid row/column, by position. The same stride
# is applied to the region mask, so mask cell k matches full-resolution cell
# 5*k, which is also how the region boxes were rescaled.
ds = ds.isel(y=slice(None, None, 5), x=slice(None, None, 5))

regions0 = regions0.isel(y=slice(None, None, 5), x=slice(None, None, 5))

# Training period
dataset = ds.sel(time_counter = slice('2007', '2020'))
inputs,targets,indx = datasets_preparation(dataset, name, inputs_names)
targets_all = datasets_preparation2(dataset,targets,indx)

# Testing period
dataset_test = ds.sel(time_counter = slice('2021', '2024'))
inputs_test,targets_test,indx_test = datasets_preparation(dataset_test, name, inputs_names)
targets_test_all = datasets_preparation2(dataset_test,targets_test,indx_test)


## Bootstrap

In [None]:
n_resamples = 100

# Row 0 = original fit, rows 1..n_resamples = bootstrap refits; one column per region box
r_train = np.full((n_resamples+1, len(boxes)), np.nan)
rms_train = np.full_like(r_train, np.nan)
slope_train = np.full_like(r_train, np.nan)

r_train_season = np.full_like(r_train, np.nan)
slope_train_season = np.full_like(r_train, np.nan)

r_test = np.full_like(r_train, np.nan)
rms_test = np.full_like(r_train, np.nan)
slope_test = np.full_like(r_train, np.nan)

r_test_season = np.full_like(r_train, np.nan)
slope_test_season = np.full_like(r_train, np.nan)

rms_test_s = np.full_like(r_train, np.nan)

# For the first (original) training session

predictions,predictions_all,predictions_test_all,season = train_test(dataset,inputs,targets,indx, dataset_test,inputs_test,indx_test, n_bins,drivers,spatial,inputs_names)
r_train[0],rms_train[0],slope_train[0],r_train_season[0],slope_train_season[0] = metrics_train(dataset,targets_all,predictions_all, boxes,regions0,season)
r_test[0],rms_test[0],slope_test[0],r_test_season[0],slope_test_season[0],rms_test_s[0] = metrics_test(dataset_test,targets_test_all,predictions_test_all, 
    boxes,regions0,season, dataset,targets_all)

n_times = len(dataset.time_counter)
n_years = len(np.unique(dataset.time_counter.dt.year))
n_points = len(indx[0]) // n_times

# Training residuals as (year, day, point). This is the same C order as the flat
# targets/predictions vectors (time-major with years consecutive, then point),
# so np.ravel maps resampled residuals back onto the right time steps.
# The previous version ravelled a (day, year, point) layout, which scrambled
# the residual times.
errors = np.reshape(targets - predictions, (n_years, n_times // n_years, n_points))

# Seeded generator so the bootstrap is reproducible
rng = np.random.RandomState(0)

for j in tqdm(range (0, n_resamples)):

    # Resample whole years (with replacement), keeping each year's day-by-day
    # and spatial residual structure intact
    errors_new = resample(errors, random_state=rng)
    targets_new = np.ravel(errors_new) + predictions

    targets_new_all = datasets_preparation2(dataset,targets_new,indx)
    
    _,predictions_all,predictions_test_all,season = train_test(dataset,inputs,targets_new,indx, dataset_test,inputs_test,indx_test, n_bins,drivers,spatial,inputs_names)
    r_train[j+1],rms_train[j+1],slope_train[j+1],r_train_season[j+1],slope_train_season[j+1] = metrics_train(dataset,targets_new_all,predictions_all, boxes,regions0,season)
    r_test[j+1],rms_test[j+1],slope_test[j+1],r_test_season[j+1],slope_test_season[j+1],rms_test_s[j+1] = metrics_test(dataset_test,targets_test_all,predictions_test_all, 
        boxes,regions0,season, dataset,targets_all)
    

## Histograms

In [None]:
# The "no seasonality" correlations are the *_season metrics, for both
# training and testing (r_train still includes the seasonal cycle)
plot_hist(r_train_season,'Correlation Coefficient (Training, no seasonality)', boxnames)
plot_hist(r_test_season,'Correlation Coefficient (Testing, no seasonality)', boxnames)
plot_hist(rms_test_s,'Error (Testing)', boxnames)


## Saving

In [None]:
# Output folder encodes the variable, the period id and the actual number of resamples
path = ('/data/ibougoudis/MOAD/files/results/' + name + '/bootstraps/' + name[0:4].lower()
        + '_pr_hist' + id + '_boxes_s10_boot_' + str(n_resamples) + '/')

os.makedirs(path, exist_ok=True)

with open(path + 'train_metrics.pkl', 'wb') as f:
    dill.dump([r_train,rms_train,slope_train,r_train_season,slope_train_season], f)

with open(path + 'test_metrics.pkl', 'wb') as f:
    dill.dump([r_test,rms_test,slope_test,r_test_season,slope_test_season,rms_test_s], f)

# Period key from the input file's stem (e.g. 'jan_mar' for .../jan_mar.nc),
# independent of the directory the file lives in
period_key = os.path.splitext(os.path.basename(filename))[0]

with open(path + 'readme.txt', 'w', encoding='utf-8') as f:
    f.write('name: ' + name + '\n'
            + 'period: ' + period_key + '\n'
            + 'input_features: ' + str(list(inputs_names)) + '\n'
            + 'n_bins: ' + str(n_bins) + '\n'
            + 'n_resamples: ' + str(n_resamples) + '\n')
