# Predicting Diatom concentration with functional regression based on the oceanographic boxes (grid points)

## Importing

In [None]:
import xarray as xr
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import xskillscore as xs

from sklearn.compose import make_column_transformer
from sklearn.preprocessing import MinMaxScaler
from sklearn.preprocessing import StandardScaler
from sklearn.feature_selection import r_regression

from skfda.representation.grid import FDataGrid
from skfda.ml.clustering import KMeans

from skfda.ml.regression import HistoricalLinearRegression

from skfda.misc.hat_matrix import NadarayaWatsonHatMatrix, LocalLinearRegressionHatMatrix, KNeighborsHatMatrix
from skfda.preprocessing.smoothing import KernelSmoother

from sklearn.metrics import root_mean_squared_error as rmse

import os
import lzma
import dill

from tqdm import tqdm

import random

import cmocean.cm as cm
import salishsea_tools.viz_tools as sa_vi


## Datasets Preparation (Function)

In [None]:
# Creation of the training - testing datasets
def datasets_preparation(dataset, dataset2, regions0, name, inputs_names):
    """Build the (inputs, targets) arrays used for training and testing.

    Parameters
    ----------
    dataset : object holding the target field ``dataset[name]`` together with
        ``time_counter``, ``x`` and ``y``.
    dataset2 : object holding one field per entry of ``inputs_names``.
    regions0 : 2-D (y, x) array of region labels; it is flattened with
        ``np.ravel``.
    name : key of the target field inside ``dataset``.
    inputs_names : iterable with the keys of the input fields in ``dataset2``.

    Returns
    -------
    inputs : array (n_inputs, days_per_year, n_selected)
    targets : array (days_per_year, n_selected)
    indx : tuple returned by ``np.where`` with the selected columns of the
        year-stacked grid (callers use ``indx[0]``).
    regions : region label of every selected column.

    Every 29th of February is dropped so that all the years have the same
    number of days, and the years are then stacked along the grid-point axis.
    """
    time = dataset.time_counter
    n_years = len(np.unique(time.dt.year))

    # Positions (along time) of every 29th of February
    leap_days = np.where((time.dt.month == 2) & (time.dt.day == 29))[0]

    def _flatten_grid(field):
        # (time, ...) -> (time, all remaining dimensions flattened)
        values = field.to_numpy()
        return values.reshape(values.shape[0], -1)

    targets = _flatten_grid(dataset[name])
    inputs = np.array([_flatten_grid(dataset2[i]) for i in inputs_names])

    # Deleting 29 of February
    inputs = np.delete(inputs, leap_days, axis=1)
    targets = np.delete(targets, leap_days, axis=0)

    # Splitting in years and grouping them side by side
    # (amount of days for one year * amount of grid boxes)
    inputs = np.concatenate(np.split(inputs, n_years, axis=1), axis=2)
    targets = np.concatenate(np.split(targets, n_years, axis=0), axis=1)

    # x / y coordinate of every column of the year-stacked arrays
    x = np.tile(dataset.x, n_years * len(dataset.y))
    y = np.tile(np.repeat(dataset.y, len(dataset.x)), n_years)

    # Keep the columns without any nan target that also pass the hard-coded
    # coordinate mask
    indx = np.where((~np.isnan(targets).any(axis=0)) & (x > 10) & ((x > 100) | (y < 880)))
    inputs = inputs[:, :, indx[0]]
    targets = targets[:, indx[0]]

    # One copy of the region map per year is enough: indx only addresses the
    # n_years * n_grid_points stacked columns
    regions = np.tile(np.ravel(regions0), n_years)[indx[0]]

    return (inputs, targets, indx, regions)


## Datasets Preparation 2 (Function)

In [None]:
# Creation of the data arrays
def datasets_preparation2(variable, name, units, dataset, indx):
    """Put a (days, selected points) array back on the (time, y, x) grid.

    The grid points that are not listed in ``indx[0]`` are filled with nan.
    The time coordinate is the one of ``dataset`` without the 29th of
    February. Returns an ``xr.DataArray`` whose attributes carry ``name``
    (as description) and ``units``.
    """
    time = dataset.time_counter
    not_leap_day = ~((time.dt.month == 2) & (time.dt.day == 29))
    dates = pd.DatetimeIndex(dataset['time_counter'].values)[not_leap_day]

    n_y, n_x = len(dataset.y), len(dataset.x)

    # Flat maps full of nans; only the selected grid points receive values
    flat_maps = np.full((len(variable), n_y * n_x), np.nan)
    flat_maps[:, indx[0]] = variable
    maps = flat_maps.reshape((len(dates), n_y, n_x))

    return xr.DataArray(
        maps,
        coords={'time_counter': dates, 'y': dataset.y, 'x': dataset.x},
        dims=['time_counter', 'y', 'x'],
        attrs={'description': name, 'units': units},
    )


## File Creation (Function)

In [None]:
def file_creation(path, variable, name):
    """Append ``variable`` (stored as ``name``) to ``targets_predictions.nc``.

    ``path`` is concatenated as-is with the file name, so it has to end with
    a path separator. The variable is written zlib-compressed (level 9).
    """
    compression = {name: {"zlib": True, "complevel": 9}}
    target_file = path + 'targets_predictions.nc'

    variable.to_dataset(name=name).to_netcdf(path=target_file, mode='a', encoding=compression)
    

## Regressor (Function)

In [None]:
def regressor (inputs, targets, j, r_inputs):
    """Fit the functional regressor of one region.

    ``inputs`` is indexed as (feature, day, sample) and ``targets`` as
    (day, sample). Row ``j`` of ``r_inputs`` is overwritten in place with the
    rounded correlation between every input feature and the target.

    Returns the fitted regressor, the fitted input scaler, the fitted target
    scaler, the fitted smoother and ``r_inputs``.
    """
    n_features, n_days, n_samples = len(inputs), inputs.shape[1], inputs.shape[2]

    # One row per (day, sample) pair, one column per feature
    flat_inputs = inputs.reshape((n_features, n_days * n_samples)).T
    flat_targets = np.ravel(targets)

    # Correlation coefficient of every feature with the target
    r_inputs[j] = np.round(r_regression(flat_inputs, flat_targets), 2)

    # Standardising every input feature, then back to (sample, day, feature)
    scaler_inputs = make_column_transformer((StandardScaler(), np.arange(0, n_features)))
    scaled_inputs = scaler_inputs.fit_transform(flat_inputs).T
    scaled_inputs = scaled_inputs.reshape((n_features, n_days, n_samples)).transpose(2, 1, 0)

    # Standardising the targets, then to (sample, day)
    scaler_targets = StandardScaler()
    scaled_targets = scaler_targets.fit_transform(flat_targets[:, np.newaxis])
    scaled_targets = scaled_targets.reshape(targets.shape).T

    # Functional representation over the day axis
    day_grid = np.arange(0, len(scaled_targets[0]))
    fd_inputs = FDataGrid(data_matrix=scaled_inputs, grid_points=day_grid)
    fd_targets = FDataGrid(data_matrix=scaled_targets, grid_points=np.arange(0, len(scaled_targets[0])))

    # Smoothing of the inputs only
    smoother = KernelSmoother(kernel_estimator=LocalLinearRegressionHatMatrix(bandwidth=1))
    fd_inputs = smoother.fit_transform(fd_inputs)

    regr = HistoricalLinearRegression(n_intervals=3, lag=74).fit(fd_inputs, fd_targets)

    return (regr, scaler_inputs, scaler_targets, smoother, r_inputs)


## Scaling (Function)

In [None]:
def scaling(regr,inputs,scaler_inputs,targets,scaler_targets,smoother):
    """Predict with an already fitted regressor, in the original target units.

    ``inputs`` is indexed as (feature, day, sample). ``targets`` is only used
    for its length, which gives the number of days of the functional grid.
    The result is indexed as (day, sample).
    """
    n_features, n_days, n_samples = len(inputs), inputs.shape[1], inputs.shape[2]
    day_grid = np.arange(0, len(targets))

    # Scaling the inputs with the scaler fitted during training
    flat = inputs.reshape((n_features, n_days * n_samples)).T
    flat = scaler_inputs.transform(flat).T
    scaled = flat.reshape((n_features, n_days, n_samples)).transpose(2, 1, 0)

    fd_inputs = smoother.transform(FDataGrid(data_matrix=scaled, grid_points=day_grid))

    # Evaluating the functional predictions on the day grid
    fd_predictions = regr.predict(fd_inputs)
    predictions = np.array(fd_predictions.to_grid(np.arange(0, len(targets))).data_matrix)
    predictions = np.squeeze(predictions, 2)

    # Undoing the scaling of the targets
    column = np.ravel(predictions)[:, np.newaxis]
    column = scaler_targets.inverse_transform(column)
    predictions = column.reshape(predictions.shape).T

    return (predictions)


## Scatter Plot (Function)

In [None]:
def scatter_plot(targets, predictions, name):
    """Scatter plot and 2-D histogram of predictions against targets.

    Both panels show the 1:1 line (dashed) and the least-squares line (red).
    Returns the slope of the least-squares line.
    """
    # least-squares line: predictions = slope * targets + intercept
    slope, intercept = np.polyfit(targets, predictions, deg=1)

    fig, (top, bottom) = plt.subplots(2, figsize=(5,10), layout='constrained')

    # --- scatter panel ---
    top.scatter(targets, predictions, alpha=0.2, s=10)

    # common limits for both axes, taken before any other artist is added
    auto_limits = [top.get_xlim(), top.get_ylim()]
    lims = [np.min(auto_limits), np.max(auto_limits)]

    top.axline(xy1=(0, intercept), slope=slope, color='r')

    top.set_xlabel('targets')
    top.set_ylabel('predictions')
    top.set_xlim(lims)
    top.set_ylim(lims)
    top.set_aspect('equal')

    top.plot(lims, lims, linestyle='--', color='k')

    # --- density panel ---
    density = bottom.hist2d(targets, predictions, bins=100, cmap='jet',
        range=[lims, lims], cmin=0.1, norm='log')

    bottom.plot(lims, lims, linestyle='--', color='k')
    bottom.axline(xy1=(0, intercept), slope=slope, color='r')

    bottom.set_xlabel('targets')
    bottom.set_ylabel('predictions')
    bottom.set_aspect('equal')

    fig.colorbar(density[3], ax=bottom, location='bottom')

    fig.suptitle(name)

    plt.show()

    return (slope)


## Plotting (Regions)

In [None]:
def plot_box(ax, corn, colour):
    """Draw the outline of a box on ``ax``.

    ``corn`` holds the corners as [y_min, y_max, x_min, x_max].
    """
    y_min, y_max, x_min, x_max = corn[0], corn[1], corn[2], corn[3]

    # Closed path around the four corners
    xs = [x_min, x_max, x_max, x_min, x_min]
    ys = [y_min, y_min, y_max, y_max, y_min]

    ax.plot(xs, ys, '-', color=colour)

## Plotting Criteria (Function)

In [None]:
def plotting_criteria(dates, variable, year_variable, title):
    """Plot a daily metric coloured by month, plus the yearly metric.

    The yearly values are drawn (red stars) on the 15th of March of every
    year. The legend labels are fixed to February, March and April.
    """
    date_index = pd.DatetimeIndex(dates)
    fig, ax = plt.subplots()

    # Daily values, one colour per month
    points = ax.scatter(dates, variable, marker='.', c=date_index.month)
    plt.xticks(rotation=70)
    month_handles = points.legend_elements()[0]
    ax.legend(handles=month_handles, labels=['February','March','April'])

    # Yearly values
    mid_march = (date_index.month == 3) & (date_index.day == 15)
    ax.plot(dates[mid_march], year_variable, color='red', marker='*')

    fig.suptitle(title + ' (15 Feb - 30 Apr)')

    fig.show()


## Plotting Mean Values (Function)

In [None]:
def plotting_mean_values(dates, targets_mean, predictions_mean, i, units, category, region, boxnames):
    """Display the metrics and plot the mean time-series of box ``i``.

    Returns the correlation coefficient (rounded to 3 decimals), the root
    mean square error and the slope of the best fitting line (rounded to 3
    decimals) between ``targets_mean`` and ``predictions_mean``.
    """
    flat_targets = np.ravel(targets_mean)
    flat_predictions = np.ravel(predictions_mean)

    r_train = np.round(np.corrcoef(flat_targets, flat_predictions)[0][1], 3)
    rms_train = rmse(flat_targets, flat_predictions)
    m, _ = np.polyfit(flat_targets, flat_predictions, deg=1)
    slope_train = np.round(m, 3)

    summary = pd.DataFrame(np.vstack((r_train, rms_train, slope_train)).transpose(),
        index=[boxnames[i]], columns=['r','rms','slope'])
    display(summary)

    years = np.unique(dates.year)

    fig, _ = plt.subplots(figsize=(19,5))

    targets_mean = np.ma.array(targets_mean)
    predictions_mean = np.ma.array(predictions_mean)

    # The first sample of every following year is masked, which breaks the
    # plotted line between consecutive years; the same position gets a tick
    ticks = [0]
    for year in years[:-1]:
        next_year_start = np.where(dates.year == year)[0][-1] + 1
        ticks.append(next_year_start)
        targets_mean[next_year_start] = np.ma.masked
        predictions_mean[next_year_start] = np.ma.masked

    plt.plot(targets_mean, label='targets')
    plt.plot(predictions_mean, label='predictions')
    plt.xlabel('Years')
    plt.xticks(ticks, years)
    plt.suptitle('Mean '+category + ' ' +units + ' (15 Feb - 30 Apr) ' + region + ' ' + boxnames[i])
    plt.legend()
    plt.show()

    return (r_train, rms_train, slope_train)
    

## Plotting Maps (Function)

In [None]:
def plotting_maps(targets, predictions, name, units):
    """Plot the maps of targets, predictions and their difference.

    The targets and predictions panels share the colour range of the
    targets. The fourth panel of the 2 x 2 figure is switched off.
    """
    fig, ax = plt.subplots(2,2, figsize = (10,15), layout='tight')

    cmap = plt.get_cmap('cubehelix')
    cmap.set_bad('gray')

    colourbar_label = {'label': name + ' ' + units}
    panels = [ax[0,0], ax[0,1], ax[1,0]]

    targets.plot(ax=panels[0], cmap=cmap, vmin=targets.min(), vmax=targets.max(), cbar_kwargs=dict(colourbar_label))
    predictions.plot(ax=panels[1], cmap=cmap, vmin=targets.min(), vmax=targets.max(), cbar_kwargs=dict(colourbar_label))
    (targets-predictions).plot(ax=panels[2], cmap=cmap, cbar_kwargs=dict(colourbar_label))

    plt.subplots_adjust(left=0.1, bottom=0.1, right=0.95, top=0.95,
        wspace=0.35, hspace=0.35)

    for panel in panels:
        sa_vi.set_aspect(panel)

    for panel, panel_title in zip(panels, ['Targets', 'Predictions', 'Targets-Predictions']):
        panel.title.set_text(panel_title)
    ax[1,1].axis('off')

    fig.suptitle(name + ' '+ str(targets.time_counter.dt.date.values))

    plt.show()
    

## Post Processing (Function)

In [None]:
def post_processing(dates,regions,boxes,targets,predictions,units,category,boxnames):
    """Daily box means and metrics for every box.

    ``targets`` / ``predictions`` are indexed as (day of year, year-stacked
    grid points). For every box the years are put one after the other and
    the points of the box are averaged, giving one value per entry of
    ``dates``. The metrics of every box are obtained (and plotted) with
    ``plotting_mean_values``.

    Returns (r, rms, slope, targets_mean, predictions_mean).
    """
    n_boxes = len(boxes)
    n_days = len(dates)
    n_years = len(np.unique(dates.year))

    targets_mean = np.zeros((n_boxes, n_days))
    predictions_mean = np.zeros((n_boxes, n_days))

    r_train = np.zeros(n_boxes)
    rms_train = np.zeros(n_boxes)
    slope_train = np.zeros(n_boxes)

    def _daily_mean(values, n_columns):
        # (days, years * points) -> (years * days, points) -> mean over points
        per_year = np.split(values, n_years, axis=1)
        stacked = np.reshape(np.ravel(per_year), (n_days, n_columns // n_years))
        return np.mean(stacked, axis=1)

    for i in range(n_boxes):

        columns = np.where(regions == i)[0]  # columns that belong to box i

        # for the daily mean plot
        targets_mean[i] = _daily_mean(targets[:, columns], len(columns))
        predictions_mean[i] = _daily_mean(predictions[:, columns], len(columns))

        r_train[i], rms_train[i], slope_train[i] = plotting_mean_values(
            dates, targets_mean[i], predictions_mean[i], i, units, category, '', boxnames)

    return (r_train, rms_train, slope_train, targets_mean, predictions_mean)


## Evaluation (Function)

In [None]:
def evaluation (years, regr_all, boxes, regions0, dataset, dataset2, name, units, scaler_inputs_all, scaler_targets_all, smoother_all,inputs_names):
    """Evaluate the fitted regressors year by year.

    For every entry of ``years`` the data of that year are prepared with
    ``datasets_preparation``, every box is predicted with its own regressor,
    scalers and smoother (``scaling``), and the yearly correlation, root
    mean square error and slope are computed over all the predicted points
    (``scatter_plot`` also draws the yearly scatter plot).

    Returns
    -------
    r_years, rms_years, slope_years : float arrays, one value per year.
    targets_all, predictions_all : data arrays with the daily maps of all
        the years, concatenated along ``time_counter``.
    """
    # Yearly metrics
    r_years, rms_years, slope_years = [], [], []

    # The data arrays
    targets_all, predictions_all = [], []

    for year in years:

        dataset_temp = dataset.sel(time_counter=str(year))
        dataset2_temp = dataset2.sel(time_counter=str(year))

        inputs, targets, indx, regions = datasets_preparation(dataset_temp, dataset2_temp, regions0, name, inputs_names)

        # Predictions for each regressor
        predictions = np.full(targets.shape, np.nan)  # size of targets without nans

        for i in range(len(boxes)):

            indx2 = np.where(regions == i)  # indexes of the i box
            inputs2 = inputs[:, :, indx2[0]]  # inputs of the i box

            # putting the predictions of the box in the right place
            predictions[:, indx2[0]] = scaling(regr_all[i], inputs2, scaler_inputs_all[i], targets,
                scaler_targets_all[i], smoother_all[i])

        # Yearly metrics over every predicted point
        flat_targets = np.ravel(targets)
        flat_predictions = np.ravel(predictions)

        slope_years.append(scatter_plot(flat_targets, flat_predictions, name + ' for '+ str(year)))
        r_years.append(np.corrcoef(flat_targets, flat_predictions)[0][1])
        rms_years.append(rmse(flat_targets, flat_predictions))

        # Daily arrays
        targets_all.append(datasets_preparation2(targets, name + ' _targets', units, dataset_temp, indx))
        predictions_all.append(datasets_preparation2(predictions, name + ' _predictions', units, dataset_temp, indx))

    # Daily arrays
    targets_all = xr.concat(targets_all, dim='time_counter')
    predictions_all = xr.concat(predictions_all, dim='time_counter')

    return (np.array(r_years, dtype=float), np.array(rms_years, dtype=float),
        np.array(slope_years, dtype=float), targets_all, predictions_all)


## Initiation

In [None]:
# Target variable and the labels used in the titles of the plots
name, units, category = 'Diatom', '[mmol m-2]', 'Concentrations'

# Input features: one set for the Diatom, another for any other target
inputs_names = (
    ['Summation_of_solar_radiation','Mean_wind_speed','Mean_air_temperature']
    if name == 'Diatom' else
    ['Summation_of_solar_radiation','Mean_air_temperature','Mean_pressure', 'Mean_precipitation', 'Mean_specific_humidity']
)

# ds holds the target fields, ds2 the input features
ds = xr.open_dataset('/data/ibougoudis/MOAD/files/integrated_original.nc')
ds2 = xr.open_dataset('/data/ibougoudis/MOAD/files/external_inputs.nc')


## Regions

In [None]:
# Background map: the first time step of the Diatom field
fig, ax = plt.subplots(1, 1, figsize=(5, 9))
mycmap = cm.deep
mycmap.set_bad('grey')
ax.pcolormesh(ds['Diatom'][0], cmap=mycmap)
sa_vi.set_aspect(ax)

# Corners of every box as [y_min, y_max, x_min, x_max]
SoG_north = [650, 730, 100, 200]
SoG_center = [450, 550, 200, 300]
Fraser_plume = [380, 460, 260, 330]
SoG_south = [320, 380, 280, 350]
Haro_Boundary = [290, 350, 210, 280]
JdF_west = [250, 425, 25, 125]
JdF_east = [200, 290, 150, 260]
PS_all = [0, 200, 80, 320]
PS_main = [20, 150, 200, 280]

boxnames = ['SoG_north','SoG_center','Fraser_plume','SoG_south', 'Haro_Boundary', 'JdF_west', 'JdF_east', 'PS_all', 'PS_main']
boxes = [SoG_north,SoG_center,Fraser_plume,SoG_south,Haro_Boundary,JdF_west,JdF_east,PS_all,PS_main]
box_colours = ['g', 'b', 'm', 'k', 'm', 'c', 'w', 'm', 'r']

# The outlines are drawn in the order of boxnames, which the legend relies on
for corners, colour in zip(boxes, box_colours):
    plot_box(ax, corners, colour)

fig.legend(boxnames)

# Map with the index of the box of every grid point (nan outside the boxes);
# where boxes overlap, the box that comes later in the list wins
regions0 = np.full((len(ds.y),len(ds.x)),np.nan)

for i, (y_min, y_max, x_min, x_max) in enumerate(boxes):
    regions0[y_min:y_max, x_min:x_max] = i

regions0 = xr.DataArray(regions0,dims = ['y','x'])

# Low resolution
boxes = [[corner//5 for corner in corners] for corners in boxes]


## Training

In [None]:
# Low resolution: keeping one grid point out of five along y and x

ds = ds.isel(y=(np.arange(ds.y[0], ds.y[-1], 5)), 
    x=(np.arange(ds.x[0], ds.x[-1], 5)))

ds2 = ds2.isel(y=(np.arange(ds2.y[0], ds2.y[-1], 5)), 
    x=(np.arange(ds2.x[0], ds2.x[-1], 5)))

regions0 = regions0.isel(y=(np.arange(regions0.y[0], regions0.y[-1], 5)), 
    x=(np.arange(regions0.x[0], regions0.x[-1], 5)))

# Keeping only the grid points that belong to a box (nan elsewhere)
ds = ds.where(regions0>-1)
ds2 = ds2.where(regions0>-1)

# Correlation of every input feature with the target, one row per box
r_inputs = np.zeros((len(boxnames), len(inputs_names)))

# One regressor, pair of scalers and smoother per box
regr_all = []
scaler_inputs_all = []
scaler_targets_all = []
smoother_all = []

# Training period
dataset = ds.sel(time_counter = slice('2007', '2020'))
dataset2 = ds2.sel(time_counter = slice('2007', '2020'))

inputs, targets, indx, regions = datasets_preparation(dataset, dataset2, regions0, name, inputs_names)

predictions = np.full(targets.shape,np.nan) # size of targets without nans

for i in tqdm(range (0, len(boxes))):

    indx2 = np.where(regions==i) # indexes of the i region
    inputs2 = inputs[:,:,indx2[0]] # inputs of the i region
    targets2 = targets[:,indx2[0]] # targets of the i region
    
    regr,scaler_inputs,scaler_targets,smoother,r_inputs = regressor(inputs2, targets2, i, r_inputs)

    scaler_inputs_all.append(scaler_inputs)
    scaler_targets_all.append(scaler_targets)
    smoother_all.append(smoother)

    regr_all.append(regr)
    
    # Predictions for the training period, put in the right place
    predictions[:,indx2[0]] = scaling(regr_all[i],inputs2,scaler_inputs_all[i],targets,scaler_targets_all[i],smoother_all[i])

print('Metrics between input features and '+name)
# The column labels follow inputs_names, so that the table has the right
# width for any target (not only for the three Diatom inputs)
temp = pd.DataFrame(r_inputs,index=boxnames, columns=[label.replace('_', ' ') for label in inputs_names])
display(temp)


## Heatmaps

In [None]:
# One figure per input feature, one panel per box, showing the non-zero
# regression coefficients of that feature
for i, input_name in enumerate(inputs_names):

    fig, axs = plt.subplots(1,len(boxes), figsize = (28,6), layout='constrained')

    for j, box_name in enumerate(boxnames[:len(boxes)]):

        coeff = regr_all[j].coef_.data_matrix
        coeff = np.where(coeff==0,np.nan,coeff)
        panel = coeff[0,:,:,i]

        if j==0: #first time for this input feature

            # the colour range of all the panels comes from the first box
            vmin = np.nanmin(panel)
            vmax = np.nanmax(panel)

        # symmetric colour range around zero
        limit = np.maximum(np.abs(vmin),vmax)

        h = axs[j].imshow(panel, cmap='bwr',aspect='auto', vmin=-limit, vmax=limit)
        axs[j].set_ylim(axs[j].get_ylim()[::-1])
        cbar = fig.colorbar(h)
        axs[j].set_title(box_name)
        axs[j].set_xlabel('Day')
        axs[j].set_ylabel('Day')

    fig.suptitle(input_name)


## Time-series (Training)

In [None]:
# Dates of the training period, without the 29th of February
indx2 = ~((dataset.time_counter.dt.month==2) & (dataset.time_counter.dt.day==29))
dates = pd.DatetimeIndex(dataset['time_counter'].values)[indx2]

n_train_years = len(np.unique(dates.year))

(r_train, rms_train, slope_train,
    targets_mean, predictions_mean) = post_processing(dates,regions,boxes,targets,predictions,units,category, boxnames)

# Seasonality of every box: mean over the years of every day of the period
season = np.reshape(targets_mean, (len(boxes), len(targets), n_train_years), order = 'F').mean(axis=2)

plt.plot(season.transpose())
plt.legend(boxnames)
plt.suptitle('Long-term seasonalities (2007-2020)')
plt.show()

season_train = np.tile(season, n_train_years) # Broadcasting season to all training years

r_train_season = np.zeros(len(boxes))
slope_train_season = np.zeros(len(boxes))

# Metrics of every box after removing the seasonality
for i in range(len(boxes)):

    anomaly_targets = targets_mean[i]-season_train[i]
    anomaly_predictions = predictions_mean[i]-season_train[i]

    r_train_season[i], _, slope_train_season[i] = plotting_mean_values(dates, anomaly_targets, anomaly_predictions,
        i, units, category, '(removed seasonality) ', boxnames)


## Other Years

In [None]:
# Testing period
dataset = ds.sel(time_counter = slice('2021', '2024'))
dataset2 = ds2.sel(time_counter = slice('2021', '2024'))

# Dates of the testing period, without the 29th of February
indx = ~((dataset.time_counter.dt.month==2) & (dataset.time_counter.dt.day==29))
dates = pd.DatetimeIndex(dataset['time_counter'].values)[indx]

years = np.unique(dates.year)

(r_years, rms_years, slope_years,
    targets_all, predictions_all) = evaluation(years,regr_all,boxes,regions0,dataset,dataset2,name,units,scaler_inputs_all,scaler_targets_all,smoother_all,inputs_names)

# Daily metrics, computed over the grid points of every map
r_days = xr.corr(targets_all,predictions_all, dim=['x','y'])
rms_days = xs.rmse(targets_all,predictions_all, dim=['x','y'], skipna=True)
slope_days = xs.linslope(targets_all,predictions_all, dim=['x','y'], skipna=True)

for daily, yearly, title in (
        (r_days, r_years, 'Correlation Coefficients'),
        (rms_days, rms_years, 'Root Mean Square Errors'),
        (slope_days, slope_years, 'Slopes of the best fitting line')):
    plotting_criteria(dates, daily, yearly, title)

# # Daily maps
# maps = random.sample(sorted(np.arange(0,len(targets_all.time_counter))),10)
# for i in maps:

#     idx = np.isfinite(np.ravel(targets_all[i]))
#     scatter_plot(np.ravel(targets_all[i])[idx], np.ravel(predictions_all[i])[idx], name + ' '+ str(targets_all[i].time_counter.dt.date.values))

#     plotting_maps(targets_all[i], predictions_all[i], name, units)


## Variable Initialization

In [None]:
n_boxes = len(boxes)
n_test_years = len(np.unique(dates.year))

# Metrics of every box
r_test, rms_test, slope_test = np.zeros(n_boxes), np.zeros(n_boxes), np.zeros(n_boxes)

# Metrics of every box after removing the seasonality
r_test_season, slope_test_season = np.zeros(n_boxes), np.zeros(n_boxes)

# Yearly sums and means of every box
targets_sum = np.zeros((n_boxes, n_test_years))
predictions_sum = np.zeros((n_boxes, n_test_years))

targets_mean = np.zeros((n_boxes, n_test_years))
predictions_mean = np.zeros((n_boxes, n_test_years))

# One value per testing day and box
targets_diff = np.zeros((n_boxes, len(targets_all)))
predictions_diff = np.zeros((n_boxes, len(targets_all)))

rss = np.zeros(n_boxes)

# Indexed as (day of the period, year, box); the amount of days is taken
# from targets2, the targets of the last box handled by the training loop
targets_s = np.zeros((len(targets2), n_test_years, n_boxes))
predictions_s = np.zeros((len(targets2), n_test_years, n_boxes))

season_test = np.tile(season, n_test_years) # Broadcasting season to all testing years


## Time-series (Testing)

In [None]:
# Metrics and time-series of every box for the testing period
for i in range (0,len(boxes)):

    # Daily means over the grid points that belong to box i
    targets = targets_all.where(regions0==i).mean(['y','x'])
    predictions = predictions_all.where(regions0==i).mean(['y','x'])

    # Metrics between the two daily time-series
    r_test[i] = xr.corr(targets,predictions)
    rms_test[i] = xs.rmse(targets,predictions,skipna=True)
    slope_test[i] = xs.linslope(targets,predictions,skipna=True)

    # Yearly sums of the time-series without the training seasonality
    targets_sum[i] = (targets-season_test[i]).groupby(targets.time_counter.dt.year).sum().values
    predictions_sum[i] =  (predictions-season_test[i]).groupby(predictions.time_counter.dt.year).sum().values

    # Yearly means of the time-series without the training seasonality
    targets_mean[i] = (targets-season_test[i]).groupby(targets.time_counter.dt.year).mean().values
    predictions_mean[i] =  (predictions-season_test[i]).groupby(predictions.time_counter.dt.year).mean().values

    # Residual sum of squares; the same seasonality is subtracted from both
    # series elsewhere, so it cancels in their difference
    rss[i] = ((targets-predictions)**2).sum().values # Similar to rms, is not affected by the seasonality

    r_test_season[i], _, slope_test_season[i] = plotting_mean_values(dates, targets-season_test[i], predictions-season_test[i], i, 
        units, category, '(removed seasonality)', boxnames)

    # Yearly mean and standard deviation of the targets
    mean = (targets).groupby(targets.time_counter.dt.year).mean().values
    std = (targets).groupby(targets.time_counter.dt.year).std().values

    # Yearly threshold: mean plus one standard deviation of the targets
    diff = mean + 1*std
    diff = xr.DataArray(diff, coords = {'year': np.unique(targets.time_counter.dt.year)}, dims = ['year'])

    # Values above the threshold of their year; note that the threshold of the
    # targets is also the one applied to the predictions
    targets_diff[i] = (targets).groupby(targets.time_counter.dt.year).where((targets).groupby(targets.time_counter.dt.year)>diff)
    predictions_diff[i] = (predictions).groupby(predictions.time_counter.dt.year).where((predictions).groupby(predictions.time_counter.dt.year)>diff)

    # Saving them to the same format as the spatial means notebook
    # (rows: days of the period, columns: years); len(targets2) relies on
    # targets2 as left by the training cell
    targets, predictions = np.array(targets), np.array(predictions)
    targets_s[:,:,i] = np.reshape(targets, (len(targets2),len(np.unique(targets_all.time_counter.dt.year))), order='F')
    predictions_s[:,:,i] = np.reshape(predictions, (len(targets2),len(np.unique(targets_all.time_counter.dt.year))), order='F')

    plt.plot(targets_diff[i])
    plt.plot(predictions_diff[i])
    plt.show()


## Saving

In [None]:
# path = '/data/ibougoudis/MOAD/files/results/' + name + '/func_reg_box2/'

# os.makedirs(path, exist_ok=True)
# with lzma.open(path + 'regr_all.xz', 'wb') as f:
    
#     dill.dump(regr_all, f)

# with open(path + 'train_metrics.pkl', 'wb') as f:
#     dill.dump([r_train,rms_train,slope_train,r_train_season,slope_train_season,season], f)

# with open(path + 'test_metrics.pkl', 'wb') as f:
#     dill.dump([r_test,rms_test,slope_test,r_test_season,slope_test_season,targets_sum,predictions_sum,targets_mean,predictions_mean,targets_diff,predictions_diff,rss], f)

# with open(path + 'targets-predictions.pkl', 'wb') as f:
#     dill.dump([targets_s,predictions_s], f)

# file_creation(path, targets_all, 'Targets')
# file_creation(path, predictions_all, 'Predictions')
# file_creation(path, (targets_all-predictions_all), 'Targets - Predictions')
