# Bootstrap (Flagellate Production Rate, boxes, spatial means)

## Importing

In [None]:
import numpy as np
import xarray as xr
import pandas as pd
import matplotlib.pyplot as plt
import xskillscore as xs

from sklearn.pipeline import make_pipeline
from sklearn.compose import TransformedTargetRegressor
from sklearn.compose import ColumnTransformer
from sklearn.preprocessing import KBinsDiscretizer
from sklearn.preprocessing import StandardScaler

from sklearn.ensemble import HistGradientBoostingRegressor

from sklearn.metrics import root_mean_squared_error as rmse

import os
import dill

from tqdm import tqdm

import cmocean.cm as cm
import salishsea_tools.viz_tools as sa_vi

from sklearn.utils import resample


## Datasets Preparation

In [None]:
# Build the per-box, spatially averaged training / testing arrays
def datasets_preparation(dataset, name, inputs_names, regions, boxes):
    """Return daily box-mean inputs and targets for one dataset.

    Parameters
    ----------
    dataset : object exposing ``time_counter`` (with ``.dt.year``), ``x``,
        ``y`` and item access by variable name (presumably an
        ``xarray.Dataset`` -- confirm against the caller).
    name : key of the target variable in ``dataset``.
    inputs_names : list of input variable keys; if the last entry is
        ``'Day_of_year'`` a synthetic within-year day index is appended
        instead of reading that key from ``dataset``.
    regions : 2-D map holding a box index per grid cell and NaN elsewhere.
    boxes : sequence of boxes; only its length is used here.

    Returns
    -------
    (inputs_mean, targets_mean, indx, regions)
        ``inputs_mean`` has shape (time, box, feature), ``targets_mean`` has
        shape (time, box), ``indx`` is the ``np.where`` tuple of retained
        flat positions and ``regions`` is the box index of each retained
        position.
    """
    n_time = len(dataset.time_counter)
    n_years = len(np.unique(dataset.time_counter.dt.year))
    n_boxes = len(boxes)
    n_features = len(inputs_names)

    # Day index restarting at 0 every year (assumes every year contributes
    # the same number of time steps)
    dayofyear = np.tile(np.arange(0, n_time // n_years), n_years)

    if inputs_names[-1] == 'Day_of_year':
        inputs = [dataset[feature].to_numpy().flatten() for feature in inputs_names[0:-1]]
        # One day index per grid cell and time step
        # (assumes time is the leading dimension of the variables -- TODO confirm)
        inputs.append(np.repeat(dayofyear, len(dataset.x) * len(dataset.y)))
    else:
        inputs = [dataset[feature].to_numpy().flatten() for feature in inputs_names]

    inputs = np.array(inputs)
    targets = np.ravel(dataset[name])

    # Repeat the 2-D region map once per time step so it lines up with the flattened fields
    regions = np.tile(np.ravel(regions), len(np.unique(dataset.time_counter)))

    # Keep only positions with a finite target that belong to some box
    indx = np.where(np.isfinite(targets) & ~np.isnan(regions))
    kept = indx[0]

    inputs = inputs[:, kept]
    targets = targets[kept]
    regions = regions[kept]

    inputs = inputs.transpose()

    targets_mean = np.zeros((n_time, n_boxes))
    inputs_mean = np.zeros((n_time, n_boxes, n_features))

    for box in range(0, n_boxes):

        in_box = np.where(regions == box)[0]
        box_inputs = inputs[in_box, :]
        box_targets = targets[in_box]

        # Rows of the reshaped arrays are time steps, columns are the cells of the box
        for feature in range(0, n_features):
            per_step = np.reshape(box_inputs[:, feature], (n_time, len(box_inputs) // n_time))
            inputs_mean[:, box, feature] = per_step.mean(axis=1)

        per_step = np.reshape(box_targets, (n_time, len(box_targets) // n_time))
        targets_mean[:, box] = per_step.mean(axis=1)

    return (inputs_mean, targets_mean, indx, regions)


## Regressor

In [None]:
def regressor (inputs, targets, drivers, day_input):
    """Fit and return a gradient-boosting model on standardised data.

    The first ``len(drivers)`` input columns are standardised and any
    remaining column is passed through unchanged; the target is standardised
    through ``TransformedTargetRegressor``. When ``day_input`` is not the
    empty list, the column at index ``len(drivers)`` (the first passed-through
    column) is declared categorical to the booster.
    """
    scale_drivers = ColumnTransformer(
        transformers=[('drivers', StandardScaler(), np.arange(0,len(drivers)))],
        remainder='passthrough')

    if day_input == []:
        booster = HistGradientBoostingRegressor()
    else:
        booster = HistGradientBoostingRegressor(categorical_features=[len(drivers)])

    model = TransformedTargetRegressor(
        regressor=make_pipeline(scale_drivers, booster),
        transformer=StandardScaler())

    return model.fit(inputs, targets)

## Plotting (regions)

In [None]:
def plot_box(ax, corn, colour):
    """Draw the closed outline of a rectangular box on ``ax``.

    ``corn[0]`` and ``corn[1]`` are used as the two vertical (y) limits and
    ``corn[2]`` and ``corn[3]`` as the two horizontal (x) limits.
    """
    y_first, y_second = corn[0], corn[1]
    x_first, x_second = corn[2], corn[3]

    # Walk the four corners and come back to the start to close the outline
    xs = [x_first, x_second, x_second, x_first, x_first]
    ys = [y_first, y_first, y_second, y_second, y_first]

    ax.plot(xs, ys, '-', color=colour)
    

## Plotting (histograms)

In [None]:
def plot_hist (variable,name,boxnames):
    """Plot one histogram per box of the columns of ``variable``.

    Parameters
    ----------
    variable : 2-D array; column ``j`` holds the values plotted for box ``j``.
    name : figure title.
    boxnames : sequence of box names, one panel per entry.
    """
    # squeeze=False keeps axs two-dimensional even when there is a single box,
    # so the indexing below does not fail on a bare Axes object
    fig, axs = plt.subplots(1,len(boxnames), figsize = (20,6), layout='constrained', squeeze=False)

    # The title belongs to the whole figure, so it is set once
    fig.suptitle(name)

    for j, boxname in enumerate(boxnames):

        panel = axs[0, j]
        panel.hist(variable[:,j])
        panel.set_title(boxname)
        panel.set_ylabel('Frequency')
        

## Training - Testing

In [None]:
def train_test(dataset,inputs,targets, inputs_test, targets_test, drivers, day_input, boxes, labels):
    """Fit one model per box and predict on the training and testing inputs.

    Returns ``(predictions, season, predictions_test)`` where ``predictions``
    and ``predictions_test`` have the shapes of ``targets`` and
    ``targets_test``, and ``season[i]`` is the mean over the training years of
    the targets of box ``i`` (shape ``(len(boxes), len(labels))``).
    """
    n_boxes = len(boxes)
    n_years = len(np.unique(dataset.time_counter.dt.year))

    predictions = np.full(targets.shape, np.nan)
    predictions_test = np.full(targets_test.shape, np.nan)
    season = np.zeros((n_boxes, len(labels)))

    for box in range(0, n_boxes):

        # Training: fit on this box only and predict the training period
        model = regressor(inputs[:, box, :], targets[:, box], drivers, day_input)
        predictions[:, box] = model.predict(inputs[:, box, :])

        # Testing: same model applied to the held-out period
        predictions_test[:, box] = model.predict(inputs_test[:, box, :])

        # Split the training targets into equal per-year chunks and average across years
        per_year = np.array(np.split(targets[:, box], n_years, axis=0))
        season[box] = np.mean(per_year, axis=0)

    return (predictions, season, predictions_test)


## Metrics (Training)

In [None]:
def metrics_train (dataset, targets, predictions, boxes, season):
    """Compute the per-box training metrics.

    Returns ``(r_train, rms_train, slope_train, r_train_season,
    slope_train_season)``, each of length ``len(boxes)``:
    correlation, RMSE as a percentage of the mean target, slope of the
    degree-1 fit, and the correlation and slope of the series after
    subtracting ``season``.

    NOTE(review): the raw slope is fitted with predictions as x and targets
    as y, whereas the seasonal slope is fitted with target anomalies as x and
    prediction anomalies as y -- confirm the asymmetry is intended.
    """
    n_boxes = len(boxes)

    r_train = np.full(n_boxes, np.nan)
    rms_train = np.full(n_boxes, np.nan)
    slope_train = np.full(n_boxes, np.nan)
    r_train_season = np.full(n_boxes, np.nan)
    slope_train_season = np.full(n_boxes, np.nan)

    # Broadcasting season to all training years
    season_train = np.tile(season, len(np.unique(dataset.time_counter.dt.year)))

    for i in range(0, n_boxes):

        observed = targets[:, i]
        modelled = predictions[:, i]

        r_train[i] = np.round(np.corrcoef(modelled, observed)[0][1], 3)
        rms_train[i] = rmse(modelled, observed) / np.mean(observed) * 100
        slope_train[i] = np.round(np.polyfit(modelled, observed, deg=1)[0], 3)

        # Same metrics on the anomalies with respect to the seasonal cycle
        observed_anom = observed - season_train[i]
        modelled_anom = modelled - season_train[i]

        r_train_season[i] = np.round(np.corrcoef(observed_anom, modelled_anom)[0][1], 3)
        slope_train_season[i] = np.round(np.polyfit(observed_anom, modelled_anom, deg=1)[0], 3)

    return (r_train,rms_train,slope_train,r_train_season,slope_train_season)


## Metrics (Testing)

In [None]:
def metrics_test (dataset_test, targets_test, predictions_test, boxes, season):
    """Compute the per-box testing metrics.

    Returns ``(r_test, rms_test, slope_test, r_test_season,
    slope_test_season, rms_test_s)``, each of length ``len(boxes)``:
    correlation, RMSE as a percentage of the mean target, slope of the
    degree-1 fit, correlation and slope of the anomalies with respect to
    ``season``, and the root-mean-square over the testing years of the
    difference between the yearly summed target and prediction anomalies.

    The yearly RMS is normalised by the actual number of testing years; the
    previous fixed division by 2 was only a root *mean* square for exactly
    four years (results are unchanged in that case).

    NOTE(review): the raw slope is fitted with predictions as x and targets
    as y, whereas the seasonal slope is fitted with target anomalies as x and
    prediction anomalies as y -- confirm the asymmetry is intended.
    """
    n_boxes = len(boxes)
    n_years = len(np.unique(dataset_test.time_counter.dt.year))

    r_test = np.full(n_boxes, np.nan)
    rms_test = np.full(n_boxes, np.nan)
    slope_test = np.full(n_boxes, np.nan)

    r_test_season = np.full(n_boxes, np.nan)
    slope_test_season = np.full(n_boxes, np.nan)

    rms_test_s = np.full(n_boxes, np.nan)

    # Broadcasting season to all testing years (identical for every box, so built once)
    season_test = np.tile(season, n_years)

    for i in range(0, n_boxes):

        observed = targets_test[:, i]
        modelled = predictions_test[:, i]

        r_test[i] = np.round(np.corrcoef(modelled, observed)[0][1], 3)
        rms_test[i] = rmse(modelled, observed) / np.mean(observed) * 100
        m, _ = np.polyfit(modelled, observed, deg=1)
        slope_test[i] = np.round(m, 3)

        # Anomalies with respect to the seasonal cycle
        observed_anom = observed - season_test[i]
        modelled_anom = modelled - season_test[i]

        r_test_season[i] = np.round(np.corrcoef(observed_anom, modelled_anom)[0][1], 3)
        m, _ = np.polyfit(observed_anom, modelled_anom, deg=1)
        slope_test_season[i] = np.round(m, 3)

        # One row per testing year, one column per day of the season
        n_days = len(season[0])
        targets_an = np.reshape(observed_anom, (n_years, n_days))
        predictions_an = np.reshape(modelled_anom, (n_years, n_days))

        # Season-integrated anomaly of each year
        targets_sum = np.array([np.sum(targets_an[j]) for j in range(0, n_years)])
        predictions_sum = np.array([np.sum(predictions_an[j]) for j in range(0, n_years)])

        squared_error = 0
        for j in range(0, n_years):
            squared_error = squared_error + (targets_sum[j] - predictions_sum[j])**2

        # Root mean square over the testing years
        rms_test_s[i] = np.sqrt(squared_error / n_years)

    return (r_test,rms_test,slope_test,r_test_season,slope_test_season,rms_test_s)


## Initiation

In [None]:
# Target variable and its descriptive metadata
name = 'Flagellate_Production_Rate'
units = '[mmol N m-2 s-1]'
category = 'Production rates'

filename = '/data/ibougoudis/MOAD/files/inputs/jan_mar.nc'

# Input features; day_input is either [] or a list holding the day-of-year feature name
drivers = ['Summation_of_solar_radiation', 'Mean_pressure', 'Mean_precipitation']
day_input = []

inputs_names = drivers + day_input

# Only recorded in the readme by the saving cell in the visible code
n_bins=255

# Period key -> (period label, period id, months covered)
periods = {
    'jan_mar': ('(16 Jan - 31 Mar)', '1', ['January', 'February', 'March']), # 75 days, 1st period
    'jan_apr': ('(01 Jan - 30 Apr)', '2', ['January', 'February', 'March', 'April']), # 120 days, 2nd period
    'feb_apr': ('(15 Feb - 30 Apr)', '3', ['February', 'March', 'April']), # 75 days, 3rd period
    'apr_jun': ('(16 Apr - 30 Jun)', '4', ['April', 'May', 'June']), # 76 days, 4th period
    'may_sep': ('(01 May - 30 Sep)', '5', ['May', 'June', 'July', 'August', 'September']), # 153 days, 5th period
}

# The period is encoded in the first seven characters of the file name; taking
# it from the basename keeps this working if the directory changes
period_key = os.path.basename(filename)[:7]

# Fail here rather than at save time: without a match `id` would silently
# remain the Python builtin
if period_key not in periods:
    raise ValueError('Unknown period ' + repr(period_key) + ' in filename ' + repr(filename))

period, id, months = periods[period_key]

ds = xr.open_dataset(filename)


## Regions

In [None]:
# Map of the first time step of the target variable, used as background for the boxes
fig, ax = plt.subplots(1, 1, figsize=(5, 9))
mycmap = cm.deep
mycmap.set_bad('grey')
ax.pcolormesh(ds[name][0], cmap=mycmap)
sa_vi.set_aspect(ax)

# Box limits as [y_start, y_stop, x_start, x_stop] in grid indices
SoG_north = [650, 730, 100, 200]
SoG_center = [450, 550, 200, 300]
Fraser_plume = [380, 460, 260, 330]
SoG_south = [320, 380, 280, 350]
Haro_Boundary = [290, 350, 210, 280]
JdF_west = [250, 425, 25, 125]
JdF_east = [200, 290, 150, 260]
PS_all = [0, 200, 80, 320]
PS_main = [20, 150, 200, 280]

boxnames = ['SoG_north','SoG_center','Fraser_plume','SoG_south', 'Haro_Boundary', 'JdF_west', 'JdF_east', 'PS_all', 'PS_main']
boxes = [SoG_north,SoG_center,Fraser_plume,SoG_south,Haro_Boundary,JdF_west,JdF_east,PS_all,PS_main]

# Outline colour of each box, in the same order as boxes
box_colours = ['g', 'b', 'm', 'k', 'm', 'c', 'w', 'm', 'r']

for corners, colour in zip(boxes, box_colours):
    plot_box(ax, corners, colour)

fig.legend(boxnames)

# Label every grid cell with the index of its box (NaN outside all boxes).
# Boxes are written in order, so a later box overwrites the cells it shares
# with an earlier one (e.g. PS_main lies inside PS_all).
regions0 = np.full((len(ds.y),len(ds.x)),np.nan)

for label, corners in enumerate(boxes):
    regions0[corners[0]:corners[1], corners[2]:corners[3]] = label

regions0 = xr.DataArray(regions0,dims = ['y','x'])

# # Low resolution

# Box limits expressed on the grid subsampled by a factor of 5
boxes = [[limit//5 for limit in corners] for corners in boxes]

## Datasets

In [None]:
# Low resolution: keep every 5th grid point in y and x
# (this cell rebinds ds and regions0, so re-running it subsamples again)

ds_y_keep = np.arange(ds.y[0], ds.y[-1], 5)
ds_x_keep = np.arange(ds.x[0], ds.x[-1], 5)
ds = ds.isel(y=ds_y_keep, x=ds_x_keep)

regions_y_keep = np.arange(regions0.y[0], regions0.y[-1], 5)
regions_x_keep = np.arange(regions0.x[0], regions0.x[-1], 5)
regions0 = regions0.isel(y=regions_y_keep, x=regions_x_keep)

# Training years
train_years = slice('2007', '2020')
dataset = ds.sel(time_counter = train_years)
inputs,targets,indx,regions = datasets_preparation(dataset, name, inputs_names, regions0, boxes)

# Distinct day-month labels present in the training period
labels = np.unique(dataset.time_counter.dt.strftime('%d %b'))

# Testing years
test_years = slice('2021', '2024')
dataset_test = ds.sel(time_counter = test_years)
inputs_test,targets_test,indx_test,regions_test = datasets_preparation(dataset_test, name, inputs_names, regions0, boxes)


## Bootstrap

In [None]:
# Number of bootstrap refits; row 0 of every metric array holds the original fit
n_resamples = 100

# Metric arrays of shape (n_resamples + 1, number of boxes), NaN until filled
r_train = np.full((n_resamples+1, len(boxes)), np.nan)
rms_train = np.full_like(r_train, np.nan)
slope_train = np.full_like(r_train, np.nan)

r_train_season = np.full_like(r_train, np.nan)
slope_train_season = np.full_like(r_train, np.nan)

r_test = np.full_like(r_train, np.nan)
rms_test = np.full_like(r_train, np.nan)
slope_test = np.full_like(r_train, np.nan)

r_test_season = np.full_like(r_train, np.nan)
slope_test_season = np.full_like(r_train, np.nan)

rms_test_s = np.full_like(r_train, np.nan)

# For the first (original) training session
predictions, season, predictions_test = train_test(dataset, inputs, targets, inputs_test, targets_test, drivers, day_input, boxes, labels)

r_train[0],rms_train[0],slope_train[0],r_train_season[0],slope_train_season[0] = metrics_train(dataset, targets, predictions, boxes, season)

r_test[0],rms_test[0],slope_test[0],r_test_season[0],slope_test_season[0],rms_test_s[0] = metrics_test(dataset_test, targets_test, predictions_test, boxes, season)

# Split the training series into one chunk per year and reorder to (day, year, box)
targets0 = np.array(np.split(targets,len(np.unique(dataset.time_counter.dt.year)),axis=0))
targets0 = np.transpose(targets0, (1,0,2))

predictions0 = np.array(np.split(predictions,len(np.unique(dataset.time_counter.dt.year)),axis=0))
predictions0 = np.transpose(predictions0, (1,0,2))

# Residuals of the original fit, same (day, year, box) layout
errors = targets0 - predictions0

for j in tqdm(range (0, n_resamples)):

    # The year axis is moved first so that resample draws whole years of
    # residuals (presumably with replacement, sklearn's default -- confirm),
    # then the (day, year, box) layout is restored
    temp = resample(errors.transpose(1,0,2))
    errors_new = temp.transpose(1,0,2)
    # Synthetic targets: original predictions plus resampled residuals
    targets_new = errors_new + predictions0

    # Fortran order makes the day index vary fastest within each year, which
    # gives back the (time, box) layout of targets
    targets_new = np.reshape(targets_new,targets.shape, order='F')
    
    # Refit on the synthetic targets; the training predictions of the refit are discarded
    _,season, predictions_test = train_test(dataset, inputs, targets_new, inputs_test, targets_test, drivers,day_input, boxes, labels)
    # NOTE(review): predictions below are those of the ORIGINAL fit, not of the
    # refitted model -- confirm this is intended for the training metrics.
    # season is overwritten with the seasonal cycle of the synthetic targets.
    r_train[j+1],rms_train[j+1],slope_train[j+1],r_train_season[j+1],slope_train_season[j+1] = metrics_train(dataset,targets_new, predictions, boxes ,season)
    r_test[j+1],rms_test[j+1],slope_test[j+1],r_test_season[j+1],slope_test_season[j+1],rms_test_s[j+1] = metrics_test(dataset_test,targets_test, predictions_test, boxes, season) 
    

## Histograms

In [None]:
# Distribution over the bootstrap members of each metric, one panel per box.
# The training panel uses the de-seasonalised correlation so that the data
# match the "no seasonality" title and the testing panel below.
plot_hist(r_train_season,'Correlation Coefficient (Training, no seasonality)', boxnames)
plot_hist(r_test_season,'Correlation Coefficient (Testing, no seasonality)', boxnames)
plot_hist(rms_test_s,'Error (Testing)', boxnames)


## Saving

In [None]:
# Output directory for this variable / period
# NOTE(review): the directory name ends in a fixed 'boot_100' regardless of n_resamples
path = ''.join([
    '/data/ibougoudis/MOAD/files/results/',
    name,
    '/bootstraps/',
    name[0:4].lower(),
    '_pr_hist',
    id,
    '_boxes_s4_boot_100/',
])

os.makedirs(path, exist_ok=True)

# Metric arrays, stored as lists in the same order as returned by the metric functions
pickles = (
    ('train_metrics.pkl', [r_train,rms_train,slope_train,r_train_season,slope_train_season]),
    ('test_metrics.pkl', [r_test,rms_test,slope_test,r_test_season,slope_test_season,rms_test_s]),
)

for pickle_name, metrics in pickles:
    with open(path + pickle_name, 'wb') as f:
        dill.dump(metrics, f)

# Plain-text summary of the run configuration, one "key: value" entry per line
readme_lines = [
    'name: ' + name,
    'period: ' + filename[35:42],
    'input_features: ' + str([i for i in inputs_names]),
    'n_bins: ' + str(n_bins),
    'n_resamples: ' + str(n_resamples),
]

with open(path + 'readme.txt', 'w') as f:
    f.write('\n'.join(readme_lines) + '\n')
