# Overall comparison of the models' performance (Diatom Production Rate)

## Importing

In [None]:
import xarray as xr
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import xskillscore as xs
from sklearn.metrics import root_mean_squared_error as rmse
import dill


## Function for 5% mean

In [None]:
def low_mean(a):
    """Average the first four entries of ``a`` along ``time_counter``.

    ``daily2`` passes this to ``groupby(...).map`` after sorting the daily
    series, so the first four entries of each yearly group are the ones
    that sort first.
    NOTE(review): four entries presumably stand for ~5% of the days in the
    analysed period -- confirm if the period length changes.
    """
    leading = a[:4]
    return leading.mean(dim="time_counter")


## Annual calculations

In [None]:
def annual(targets, predictions):
    """Score predictions against targets over one pooled sample.

    Both inputs are flattened before use.

    Returns a tuple ``(r, rms, slope)``:
      * r     -- Pearson correlation coefficient, rounded to 3 decimals
      * rms   -- root mean squared error (not rounded)
      * slope -- slope of the degree-1 least-squares fit of predictions
                 on targets, rounded to 3 decimals
    """
    flat_targets = np.ravel(targets)
    flat_predictions = np.ravel(predictions)

    correlation = np.corrcoef(flat_targets, flat_predictions)[0, 1]
    error = rmse(flat_targets, flat_predictions)
    # polyfit returns the coefficients highest degree first: [slope, intercept]
    gradient = np.polyfit(flat_targets, flat_predictions, deg=1)[0]

    return (np.round(correlation, 3), error, np.round(gradient, 3))


## Daily calculations

In [None]:
def daily(targets_all, predictions_all):
    """Yearly means of the per-day spatial metrics.

    For every time step the correlation, RMSE and regression slope are
    computed over the ``x`` and ``y`` dimensions; the daily values are then
    averaged within each calendar year.

    Returns a tuple ``(r, rms, slope)`` of length-4 float arrays holding the
    first four yearly values (r and slope rounded to 3 decimals).
    """
    spatial = ['x', 'y']

    def yearly_mean(per_day):
        # Collapse the daily series to one value per calendar year
        return per_day.groupby('time_counter.year').mean('time_counter')

    def first_four(per_year, decimals=None):
        # Copy the first four yearly values into a float array,
        # rounding only when a number of decimals is requested
        out = np.zeros(4)
        for k in range(out.size):
            value = per_year[k].values
            out[k] = value if decimals is None else np.round(value, decimals)
        return out

    r_by_year = yearly_mean(
        xr.corr(targets_all, predictions_all, dim=spatial))
    rms_by_year = yearly_mean(
        xs.rmse(targets_all, predictions_all, dim=spatial, skipna=True))
    slope_by_year = yearly_mean(
        xs.linslope(targets_all, predictions_all, dim=spatial, skipna=True))

    return (first_four(r_by_year, 3),
            first_four(rms_by_year),
            first_four(slope_by_year, 3))


## 5% calculations

In [None]:
def daily2(targets_all,predictions_all):
    """Yearly "5% low" values of the per-day spatial metrics.

    For every time step the correlation, RMSE and regression slope are
    computed over the ``x`` and ``y`` dimensions. Each daily series is then
    sorted (correlation and slope ascending, RMSE descending), grouped by
    calendar year, and reduced with ``low_mean``, which averages the first
    four entries of each group.

    Returns a tuple ``(r, rms, slope)`` of length-4 float arrays holding the
    first four yearly values (r and slope rounded to 3 decimals).
    """
    spatial = ['x', 'y']
    by_year = 'time_counter.year'

    # Compute each daily series once and reuse it as its own sort key;
    # recomputing it inside sortby() doubles the cost for no benefit.
    r_daily = xr.corr(targets_all, predictions_all, dim=spatial)
    rms_daily = xs.rmse(targets_all, predictions_all, dim=spatial, skipna=True)
    slope_daily = xs.linslope(targets_all, predictions_all, dim=spatial, skipna=True)

    r_days = r_daily.sortby(r_daily).groupby(by_year).map(low_mean)
    # Descending order so the largest errors come first
    rms_days = rms_daily.sortby(rms_daily, ascending=False).groupby(by_year).map(low_mean)
    slope_days = slope_daily.sortby(slope_daily).groupby(by_year).map(low_mean)

    r = np.zeros(4)
    rms = np.zeros(4)
    slope = np.zeros(4)

    for j in range(4):
        r[j] = np.round(r_days[j].values, 3)
        rms[j] = rms_days[j].values
        slope[j] = np.round(slope_days[j].values, 3)

    return (r, rms, slope)


## All calculations

In [None]:
def calculations(species,category):
    """Compute the evaluation metrics of every model category of one species.

    Parameters
    ----------
    species : str
        Path prefix of the results directory; each category name is
        appended to it directly, so it must end with a path separator.
    category : sequence of str
        Sub-directory names. Each must contain ``targets_predictions.nc``
        (variables ``Targets`` and ``Predictions``) and ``metrics.pkl``.

    Returns
    -------
    tuple
        ``(years, r, rms, slope, r_train, rms_train, slope_train)``.
        ``years`` lists the group labels of the last category processed.
        ``r``, ``rms`` and ``slope`` have shape ``(len(category), 4, 5)``:
        axis 0 is the category, axis 1 the year, axis 2 the metric flavour
        (0 pooled annual, 1 daily mean, 2 5% low, 3 spatial-mean series,
        4 spatial-mean series minus the stored seasonal signal).
        The ``*_train`` arrays hold the first three entries of each
        ``metrics.pkl``.
    """
    n_years = 4  # daily() and daily2() return exactly four yearly values
    by_year = 'time_counter.year'

    def yearly_groups(data):
        # One sub-array per calendar year, in group order, capped at n_years
        return [group for _, group in data.groupby(by_year)][:n_years]

    r = np.zeros((len(category), n_years, 5))
    rms = np.zeros((len(category), n_years, 5))
    slope = np.zeros((len(category), n_years, 5))

    r_train = np.zeros(len(category))
    rms_train = np.zeros(len(category))
    slope_train = np.zeros(len(category))

    years = []  # stays empty (instead of unbound) when no category is given

    for i, name in enumerate(category):

        # The context manager closes the NetCDF file once this category's
        # metrics are computed; every read of the data happens inside it.
        with xr.open_dataset(species + name + '/targets_predictions.nc') as ds:
            targets_all = ds.Targets
            predictions_all = ds.Predictions
            mean_targets = targets_all.mean(dim=['x', 'y'], skipna=True)
            mean_predictions = predictions_all.mean(dim=['x', 'y'], skipna=True)

            # NOTE: dill.load can execute arbitrary code while unpickling;
            # only load metrics files from a trusted source.
            with open(species + name + '/metrics.pkl', 'rb') as f:
                metrics = dill.load(f)

            r_train[i], rms_train[i], slope_train[i] = metrics[0:3]

            # Slice of metrics[4] starting at the first entry of metrics[3]
            # whose year is 2017.
            # NOTE(review): presumably the seasonal signal aligned with the
            # evaluation period -- confirm against the code writing metrics.pkl.
            season = metrics[4][np.where(metrics[3].year == 2017)[0][0]:]

            # Annual. The groups are iterated explicitly: coercing a GroupBy
            # object through np.ravel depends on NumPy building a ragged
            # object array, which is an error in NumPy >= 1.24.
            years = [year for year, _ in targets_all.groupby(by_year)][:n_years]
            per_year = zip(
                yearly_groups(targets_all),
                yearly_groups(predictions_all),
                yearly_groups(mean_targets),
                yearly_groups(mean_predictions),
                yearly_groups(mean_targets - season),
                yearly_groups(mean_predictions - season),
            )

            for k, (tgt, prd, tgt_mean, prd_mean, tgt_anom, prd_anom) in enumerate(per_year):

                # Pool all grid points of the year, keeping only those
                # where the target is finite
                flat_targets = np.ravel(tgt)
                valid = np.isfinite(flat_targets)
                r[i, k, 0], rms[i, k, 0], slope[i, k, 0] = annual(
                    flat_targets[valid], np.ravel(prd)[valid])
                r[i, k, 3], rms[i, k, 3], slope[i, k, 3] = annual(
                    tgt_mean, prd_mean)
                r[i, k, 4], rms[i, k, 4], slope[i, k, 4] = annual(
                    tgt_anom, prd_anom)

            # Daily means
            r[i, :, 1], rms[i, :, 1], slope[i, :, 1] = daily(targets_all, predictions_all)

            # 5% low
            r[i, :, 2], rms[i, :, 2], slope[i, :, 2] = daily2(targets_all, predictions_all)

    return (years, r, rms, slope, r_train, rms_train, slope_train)


## Plotting

In [None]:
def plotting(j, metric, categories, years, quantity, name):
    """Plot one metric flavour against the years, one line per category.

    ``quantity`` is indexed as ``[category, year, flavour]``; ``j`` selects
    the flavour. ``metric`` and ``name`` only build the figure title.
    """
    fig, ax = plt.subplots()

    for row, label in enumerate(categories):
        ax.plot(years, quantity[row, :, j], marker='*', label=label)

    ax.set_xlabel('Years')
    fig.suptitle(''.join((metric, ' ', name, ' (15 Feb - 30 Apr)')))
    ax.legend()
    plt.show()


## Printing

In [None]:
def printing(metric, years, data, categories, criteria):
    """Print a heading and display ``data`` as a categories-by-years table.

    The heading is ``metric`` and ``criteria`` separated by one space.
    """
    table = pd.DataFrame(data, index=categories, columns=years)
    print(' '.join((metric, criteria)))
    display(table)
     

## Summary

In [None]:
def summary(quantity,categories,metrics):
    """Print and plot every metric for all model categories of one species.

    Parameters
    ----------
    quantity : str
        Path prefix of the species' results directory, passed unchanged
        to ``calculations``.
    categories : sequence of str
        Model categories (sub-directories of ``quantity``).
    metrics : sequence of str
        Labels of the metric flavours; entry ``i`` labels index ``i`` of
        the last axis of the arrays returned by ``calculations``.
    """
    import os  # local import keeps this cell independent of the import cell

    years, r, rms, slope, r_train, rms_train, slope_train = calculations(quantity, categories)

    # Use the last path component as the label instead of a fixed character
    # offset, so the heading stays correct if the results root moves.
    print(os.path.basename(os.path.normpath(quantity)))
    print('\n')

    # One row per category, one column per training score
    training = pd.DataFrame(
        np.column_stack((r_train, rms_train, slope_train)),
        columns=['r', 'rms', 'slope'],
        index=categories,
    )
    print('Training')
    display(training)
    print('\n')

    # (values, table heading, plot-title wording) for each score
    panels = (
        (r, 'Correlation coefficient', 'correlation coefficient'),
        (rms, 'Root mean square error', 'root mean square error'),
        (slope, 'Slope of the best fitting line', 'slope of the best fitting line'),
    )

    for i, label in enumerate(metrics):
        for values, heading, title in panels:
            printing(label, years, values[:, :, i], categories, heading)
            plotting(i, label, categories, years, values, title)
    

## Main Body

In [None]:
import warnings

# Silence NumPy's VisibleDeprecationWarning. The stdlib ``warnings`` module is
# used directly because ``np.warnings`` no longer exists in NumPy >= 1.24, and
# the class is looked up in ``np.exceptions`` when available because NumPy 2.0
# removed it from the top-level namespace.
warnings.filterwarnings(
    'ignore',
    category=getattr(np, 'exceptions', np).VisibleDeprecationWarning,
)

# Paths
diat = '/data/ibougoudis/MOAD/files/results/Diatom/'
diat_pr = '/data/ibougoudis/MOAD/files/results/Diatom_Production_Rate/'
flag = '/data/ibougoudis/MOAD/files/results/Flagellate/'
flag_pr = '/data/ibougoudis/MOAD/files/results/Flagellate_Production_Rate/'

# Model categories to compare (sub-directories of the chosen results path)
categories = ['hist_xy_ext','func_cl_target_ext','func_cl_drivers_ext']

# Labels of the metric flavours, in the order of the last axis of the arrays
# produced by calculations()
metrics = ['Annual', 'Daily mean', '5% low mean', 'Temporal','Temporal (removed seasonality)']

# NOTE(review): the notebook title mentions "Diatom Production Rate" but this
# call analyses the Flagellate results -- confirm which one is intended.
summary(flag,categories,metrics)
