# Overall comparison of the models' performance (Diatom Production Rate)

## Importing

In [None]:
import xarray as xr
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import xskillscore as xs
from sklearn.metrics import root_mean_squared_error as rmse


## Function for the 5% low mean

In [None]:
def low_mean(a, n=4):
    """Return the mean of the first ``n`` entries of ``a`` along ``time_counter``.

    In this notebook ``a`` is sorted worst-first before being grouped by year,
    so the result is the mean of the ``n`` worst days of a year. The default
    of 4 presumably approximates 5% of the ~75-day 15 Feb - 30 Apr window
    named in the plot titles -- TODO confirm.

    Parameters
    ----------
    a : xarray.DataArray
        Data with a ``time_counter`` dimension.
    n : int, optional
        Number of leading entries to average (default 4).

    Returns
    -------
    xarray.DataArray
        ``a`` reduced over ``time_counter``.
    """
    # Select by dimension name rather than by position, so the slice is taken
    # along time_counter even when it is not the first dimension of ``a``.
    return a.isel(time_counter=slice(0, n)).mean(dim="time_counter")


## Daily calculations

In [None]:
def daily(r_days, rms_days, slope_days):
    """Convert per-year metric series into plain NumPy float arrays.

    Each argument is iterated element by element and each element's
    ``.values`` is read (e.g. a 1-D xarray DataArray indexed by year). Every
    element is kept; the original version hard-coded exactly 4 entries.

    Parameters
    ----------
    r_days, rms_days, slope_days : sequence
        Per-year correlation, RMSE and slope values; each element must expose
        a scalar ``.values``.

    Returns
    -------
    tuple of numpy.ndarray
        ``(r, rms, slope)``, one float per input element. ``r`` and ``slope``
        are rounded to 3 decimals; ``rms`` is left unrounded.
    """
    def _as_array(series, decimals=None):
        # One float per element; round only when requested.
        values = np.array([item.values for item in series], dtype=float)
        return values if decimals is None else np.round(values, decimals)

    r = _as_array(r_days, 3)
    rms = _as_array(rms_days)
    slope = _as_array(slope_days, 3)
    return (r, rms, slope)


## Printing (Function)

In [None]:
def printing(metric, years, r, rms, m):
    """Print the correlation, RMSE and slope of one metric for every year.

    For each year three sentences are printed (correlation coefficient, root
    mean square error, slope of the best fitting line), followed by a blank
    separator produced by ``print('\\n')``.

    Parameters
    ----------
    metric : str
        Metric label inserted after "The ".
    years : sequence
        Year labels; ``r``, ``rms`` and ``m`` are indexed in the same order.
    r, rms, m : sequence
        Per-year correlation coefficients, RMSEs and slopes.
    """
    quantities = (
        ('correlation coefficient', r),
        ('root mean square error', rms),
        ('slope of the best fitting line', m),
    )
    for idx, year in enumerate(years):
        for label, values in quantities:
            print('The %s %s for year %s is %s' % (metric, label, year, values[idx]))
        print('\n')


## Plotting (Function)

In [None]:
def plotting(j, metric, categories, years, quantity, name):
    """Plot one metric against year, one line per model category.

    Parameters
    ----------
    j : int
        Index along the last axis of ``quantity`` (in this notebook:
        0 annual, 1 daily mean, 2 5% low mean).
    metric : str
        Metric label used at the start of the figure title.
    categories : sequence of str
        Model category names; row ``i`` of ``quantity`` is plotted with
        label ``categories[i]``.
    years : sequence
        x-axis values; must match the length of ``quantity[i, :, j]``.
    quantity : numpy.ndarray
        3-D array indexed as ``quantity[category, year, j]``.
    name : str
        Name of the plotted quantity, used in the title and as y-axis label.
    """
    fig, ax = plt.subplots()

    for i, label in enumerate(categories):
        ax.plot(years, quantity[i, :, j], marker='*', label=label)

    # Use the explicit Axes/Figure rather than the pyplot "current axes" state.
    ax.set_xlabel('Years')
    ax.set_ylabel(name)
    # One tick per year so no fractional year positions are labelled.
    ax.set_xticks(years)
    fig.suptitle(metric + ' ' + name + ' (15 Feb - 30 Apr)')
    ax.legend()
    fig.show()


## Calculations

In [None]:
def _annual_metrics(targets, predictions):
    """Return ``(r, rmse, slope)`` for one year, pooling every day and grid cell.

    Only positions where BOTH fields are finite are used, so NaN cells in the
    predictions are excluded as well as those in the targets. ``r`` and the
    slope are rounded to 3 decimals; the RMSE is left unrounded.
    """
    t = np.ravel(targets)
    p = np.ravel(predictions)
    valid = np.isfinite(t) & np.isfinite(p)
    t, p = t[valid], p[valid]
    r_value = np.round(np.corrcoef(t, p)[0][1], 3)
    rms_value = rmse(t, p)
    m, _ = np.polyfit(t, p, deg=1)
    return r_value, rms_value, np.round(m, 3)


def _yearly_mean(daily_metric):
    """Average a daily metric over the days of each year."""
    return daily_metric.groupby('time_counter.year').mean('time_counter')


def _yearly_low_mean(daily_metric, worst_is_high=False):
    """Average the worst days of each year, as selected by ``low_mean``.

    Days are sorted worst-first (ascending, or descending when
    ``worst_is_high``) and then grouped by year. NaN days are dropped first:
    a descending ``sortby`` reverses the ascending order and would otherwise
    put NaNs at the front, inside the slice that ``low_mean`` averages.
    """
    finite = daily_metric.dropna('time_counter')
    ordered = finite.sortby(finite, ascending=not worst_is_high)
    return ordered.groupby('time_counter.year').map(low_mean)


def calculations(species, category):
    """Compute correlation, RMSE and slope of every model category per year.

    For each category, ``species + category[i] + '/targets_predictions.nc'``
    is opened and its ``Targets`` and ``Predictions`` variables are compared
    in three ways (third axis of the returned arrays):

    0. annual: all finite days and grid cells of a year pooled together;
    1. daily mean: metric over ``x``/``y`` for each day, averaged per year;
    2. 5% low mean: mean of each year's worst days as picked by ``low_mean``
       (lowest correlation / slope, highest RMSE).

    Parameters
    ----------
    species : str
        Base results directory (with trailing slash), one sub-directory per
        category.
    category : sequence of str
        Model category names (sub-directory names).

    Returns
    -------
    years : list
        Year labels of the last processed category.
    r, rms, slope : numpy.ndarray
        Arrays of shape ``(len(category), n_years, 3)``: category x year x
        metric.
    """
    years = []
    r_rows, rms_rows, slope_rows = [], [], []  # one (n_years, 3) array per category

    for name in category:
        # Load the two variables into memory, then close the file handle.
        with xr.open_dataset(species + name + '/targets_predictions.nc') as ds:
            targets_all = ds.Targets.load()
            predictions_all = ds.Predictions.load()

        # Annual: both variables share time_counter, so their year groups align.
        years = []
        annual = []
        for (year, t_year), (_, p_year) in zip(
            targets_all.groupby('time_counter.year'),
            predictions_all.groupby('time_counter.year'),
        ):
            years.append(year)
            annual.append(_annual_metrics(t_year, p_year))
        annual = np.array(annual, dtype=float).reshape(-1, 3)  # columns: r, rms, slope

        # Daily metrics over the x/y plane, computed once and reused below.
        daily_r = xr.corr(targets_all, predictions_all, dim=['x', 'y'])
        daily_rms = xs.rmse(targets_all, predictions_all, dim=['x', 'y'], skipna=True)
        daily_slope = xs.linslope(targets_all, predictions_all, dim=['x', 'y'], skipna=True)

        # Daily means
        mean_r, mean_rms, mean_slope = daily(
            _yearly_mean(daily_r), _yearly_mean(daily_rms), _yearly_mean(daily_slope))

        # 5% low: the worst correlation/slope is the lowest, the worst RMSE the highest.
        low_r, low_rms, low_slope = daily(
            _yearly_low_mean(daily_r),
            _yearly_low_mean(daily_rms, worst_is_high=True),
            _yearly_low_mean(daily_slope))

        r_rows.append(np.column_stack([annual[:, 0], mean_r, low_r]))
        rms_rows.append(np.column_stack([annual[:, 1], mean_rms, low_rms]))
        slope_rows.append(np.column_stack([annual[:, 2], mean_slope, low_slope]))

    # Categories with different numbers of years fail loudly here instead of
    # being silently truncated or zero-padded.
    shape = (len(category), len(years), 3)
    return (years,
            np.array(r_rows, dtype=float).reshape(shape),
            np.array(rms_rows, dtype=float).reshape(shape),
            np.array(slope_rows, dtype=float).reshape(shape))


## Main Body

In [None]:
# Paths
diat = '/data/ibougoudis/MOAD/files/results/Diatom/'
diat_pr = '/data/ibougoudis/MOAD/files/results/Diatom_Production_Rate/'
flag = '/data/ibougoudis/MOAD/files/results/Flagellate/'
flag_pr = '/data/ibougoudis/MOAD/files/results/Flagellate_Production_Rate/'

# Alternative category set (kept for reference):
# categories = ['hist','hist_xy','hist_nitr','hist_xy_nitr','hist_func_cl','hist_func_cl_nitr','func_reg']
categories = ['hist','hist_xy','hist_nitr','hist_xy_nitr','hist_ext', 'hist_xy_ext']

# Labels of the third axis of r / rms / slope, in the order calculations() fills it.
metrics = ['Annual', 'Daily mean', '5% low mean']

# Choose the results directory once; the computation and the printed header
# below both use `species`, so switching datasets only requires this line.
species = diat_pr
years,r,rms,slope = calculations(species,categories)

## Plotting

In [None]:
# One figure per (metric, quantity) pair: r, RMSE and slope for each metric.
for idx, metric_name in enumerate(metrics):
    for values, label in ((r, 'correlation coefficient'),
                          (rms, 'root mean square error'),
                          (slope, 'slope of the best fitting line')):
        plotting(idx, metric_name, categories, years, values, label)


## Printing

In [None]:
# Header: last component of the results path (e.g. 'Diatom_Production_Rate'),
# independent of how long the parent directory path is.
print(species.rstrip('/').split('/')[-1])
print('\n')

for i, category_name in enumerate(categories):

    print(category_name)
    # Third axis of r / rms / slope is ordered like `metrics`.
    for k, metric_name in enumerate(metrics):
        printing(metric_name, years, r[i, :, k], rms[i, :, k], slope[i, :, k])


## Tables

### Correlation coefficients

| Diatom Production Rate | Training | 2021 (annual) | 2021 (mean daily) | 2021 (mean 5% low) | 2022 (annual) | 2022 (mean daily) | 2022 (mean 5% low) | 2023 (annual) | 2023 (mean daily) | 2023 (mean 5% low) | 2024 (annual) | 2024 (mean daily) | 2024 (mean 5% low) |
| :-: | :-: | :-: | :-: | :-: | :-: | :-: | :-: | :-: | :-: | :-: | :-: | :-: | :-: |
| hist | .9 |  .8 |  .843 | | .7 | 0.8 | | | | | | | | 
| hist + xy | .9 |  .8 | .843 | | .7 | 0.8 | | | | | | | | 
| hist + nitr | .9 |  .8 | .843 | | .7 | 0.8 | | | | | | | | 
| hist + xy + nitr | .9 |  .8 | .843 | | .7 | 0.8 | | | | | | | | 
| hist + func_cl | .9 |  .8 | .843 | | .7 | 0.8 | | | | | | | | 
| hist + func_cl + nitr | .9 | .8 | .843 | | .7 | 0.8 | | | | | | | | 
| func_reg | .9 |.8| .843 | | .7 | 0.8  | | | | | | | |

### Root mean square errors

| Diatom Production Rate | Training | 2021 (annual) | 2021 (mean daily) | 2021 (mean 5% low) | 2022 (annual) | 2022 (mean daily) | 2022 (mean 5% low) | 2023 (annual) | 2023 (mean daily) | 2023 (mean 5% low) | 2024 (annual) | 2024 (mean daily) | 2024 (mean 5% low) |
| :-: | :-: | :-: | :-: | :-: | :-: | :-: | :-: | :-: | :-: | :-: | :-: | :-: | :-: |
| hist | .9 |  .8 |  .843 | | .7 | 0.8 | | | | | | | | 
| hist + xy | .9 |  .8 | .843 | | .7 | 0.8 | | | | | | | | 
| hist + nitr | .9 |  .8 | .843 | | .7 | 0.8 | | | | | | | | 
| hist + xy + nitr | .9 |  .8 | .843 | | .7 | 0.8 | | | | | | | | 
| hist + func_cl | .9 |  .8 | .843 | | .7 | 0.8 | | | | | | | | 
| hist + func_cl + nitr | .9 | .8 | .843 | | .7 | 0.8 | | | | | | | | 
| func_reg | .9 |.8| .843 | | .7 | 0.8  | | | | | | | |

### Slope of the best fitting line

| Diatom Production Rate | Training | 2021 (annual) | 2021 (mean daily) | 2021 (mean 5% low) | 2022 (annual) | 2022 (mean daily) | 2022 (mean 5% low) | 2023 (annual) | 2023 (mean daily) | 2023 (mean 5% low) | 2024 (annual) | 2024 (mean daily) | 2024 (mean 5% low) |
| :-: | :-: | :-: | :-: | :-: | :-: | :-: | :-: | :-: | :-: | :-: | :-: | :-: | :-: |
| hist | .9 |  .8 |  .843 | | .7 | 0.8 | | | | | | | | 
| hist + xy | .9 |  .8 | .843 | | .7 | 0.8 | | | | | | | | 
| hist + nitr | .9 |  .8 | .843 | | .7 | 0.8 | | | | | | | | 
| hist + xy + nitr | .9 |  .8 | .843 | | .7 | 0.8 | | | | | | | | 
| hist + func_cl | .9 |  .8 | .843 | | .7 | 0.8 | | | | | | | | 
| hist + func_cl + nitr | .9 | .8 | .843 | | .7 | 0.8 | | | | | | | | 
| func_reg | .9 |.8| .843 | | .7 | 0.8  | | | | | | | |