In [None]:
"""
Created on Thu Sep 15 10:55 2022

Evaluating the results of the cross-validation (e.g. the scatter plot shown in the paper)

Author: @claraburgard

"""

In [None]:
import xarray as xr
import numpy as np
import pandas as pd
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
import glob

READ IN DATA

In [None]:
# Directory into which the figures produced by this notebook are saved.
plot_path = '/bettik/burgardc/PLOTS/NN_plots/1D_eval/'


In [None]:
# Location and name of the ice-shelf mask/info file.
inputpath_mask = '/bettik/burgardc/SCRIPTS/basal_melt_param/data/interim/ANTARCTICA_IS_MASKS/nemo_5km_OPM006/'
mask_filename = 'nemo_5km_isf_masks_and_info_and_distance_new_oneFRIS.nc'
file_isf_orig = xr.open_dataset(inputpath_mask + mask_filename)

# Step 1: discard ice shelves whose 'front_bot_depth_max' is not finite.
has_finite_front_depth = np.isfinite(file_isf_orig['front_bot_depth_max'])
nonnan_Nisf = file_isf_orig['Nisf'].where(has_finite_front_depth, drop=True).astype(int)
file_isf_nonnan = file_isf_orig.sel(Nisf=nonnan_Nisf)

# Step 2: keep only ice shelves whose 'isf_area_here' is at least 2500.
is_large_enough = file_isf_nonnan['isf_area_here'] >= 2500
large_isf = file_isf_nonnan['Nisf'].where(is_large_enough, drop=True)
file_isf = file_isf_nonnan.sel(Nisf=large_isf)

# Step 3: replace the ASCII spelling 'Ekstrom' by 'Ekström' in the name list.
is_not_ekstrom = file_isf['isf_name'] != 'Ekstrom'
file_isf['isf_name'] = file_isf['isf_name'].where(is_not_ekstrom, np.array('Ekström', dtype=object))
isf_names = file_isf['isf_name']

In [None]:
# Ice-shelf names belonging to each of the 6 regions.
region_members = {
    'East and West Ross': ['Ross','Nickerson','Sulzberger', 'Cook'],
    'Weddell': ['Filchner','Ronne','Filchner-Ronne'],
    'Dronning Maud Land': ['Ekström','Nivl','Prince Harald','Riiser-Larsen','Fimbul','Roi Baudouin','Lazarev','Stancomb Brunt','Jelbart','Borchgrevink'],
    'Amundsen': ['Getz','Thwaites','Crosson','Dotson','Cosgrove','Pine Island'],
    'Peninsula and Bellinghausen': ['Venable','George VI','Abbot','Stange','Larsen C','Bach','Larsen D','Wilkins'],
    'East Antarctica': ['Amery','Moscow Univ.','Tracy Tremenchus','Totten','West','Shackleton'],
}
# Inverted lookup: ice-shelf name -> region (each name appears in exactly one region).
region_of_isf = {
    isf_name: region
    for region, member_names in region_members.items()
    for isf_name in member_names
}

region_list = []
unassigned_isf = []
for kisf in file_isf.Nisf:
    # .item() turns the 0-d array into a plain Python object before the lookup.
    isf_name = str(file_isf['isf_name'].sel(Nisf=kisf).values.item())
    if isf_name in region_of_isf:
        region_list.append(region_of_isf[isf_name])
    else:
        unassigned_isf.append(isf_name)

# A missing region would leave region_list shorter than the Nisf dimension and
# make the assignment below fail with an unrelated size-mismatch error, so
# report every offending ice shelf explicitly instead.
if unassigned_isf:
    raise ValueError('Argh, help me, no region assigned for: ' + ', '.join(unassigned_isf))

# 6 regions
# regions = ['East and West Ross','Weddell','Dronning Maud Land','Amundsen','Peninsula and Bellinghausen','East Antarctica']
# region_list follows the order of file_isf.Nisf, so it is attached positionally.
file_isf['region'] = xr.DataArray(data=region_list,dims='Nisf')

In [None]:
# Order in which the regions are laid out; ice shelves are grouped accordingly.
regions = ['Weddell','Peninsula and Bellinghausen','Amundsen','East and West Ross','East Antarctica','Dronning Maud Land']

# Nisf identifiers of all ice shelves, concatenated region by region so that
# ice shelves of the same region are contiguous.
nisf_by_reg_list = np.concatenate([
    file_isf.Nisf.where(file_isf['region'] == reg, drop=True).values
    for reg in regions
])

In [None]:
run_list = ['OPM006','OPM016','OPM018','OPM021']

outputpath_melt = '/bettik/burgardc/DATA/NN_PARAM/processed/MELT_RATE/'


def _open_cv_metrics(cv_subdir, nemo_run):
    """Open all 1D evaluation-metric files of one cross-validation type and run.

    Parameters
    ----------
    cv_subdir : str
        Sub-directory of ``outputpath_melt`` holding the files
        ('CV_TBLOCKS' or 'CV_ISF').
    nemo_run : str
        Run identifier appearing at the end of the file names.

    Returns
    -------
    xarray.Dataset
        The files concatenated along a new 'param' dimension whose coordinate
        is the parameterisation name taken from each file name.
    """
    prefix = outputpath_melt + cv_subdir + '/evalmetrics_1D_'
    suffix = '_TSextrap_normstd_' + nemo_run + '.nc'
    melt_param_files = sorted(glob.glob(prefix + '*' + suffix))

    # The parameterisation name is exactly what the wildcard matched. Cutting
    # the known prefix/suffix keeps this independent of the path length.
    param_list = [mfilename[len(prefix):-len(suffix)] for mfilename in melt_param_files]

    ds_melt_param = xr.open_mfdataset(melt_param_files, concat_dim='param', combine='nested', coords='minimal',compat='override')#, chunks={'x': chunksize, 'y': chunksize})
    return ds_melt_param.assign_coords(param=param_list)


def _melt_differences(ds_melt_param):
    """Return (predicted - reference) for the 'Gt' metric and, time-averaged, for 'box1'."""
    predicted = ds_melt_param['predicted_melt']
    reference = ds_melt_param['reference_melt']
    diff_Gt = (predicted - reference).sel(metrics='Gt')
    diff_box1 = (predicted.mean('time') - reference.mean('time')).sel(metrics='box1')
    return diff_Gt, diff_box1


diff_Gt_CVtime_list = []
diff_box1_CVtime_list = []
diff_Gt_CVisf_list = []
diff_box1_CVisf_list = []

ref_Gt_list = []
ref_box1_list = []
Gt_CVtime_list = []
Gt_CVisf_list = []

for nemo_run in run_list:

    ### CV TIME
    ds_melt_param_CVtime = _open_cv_metrics('CV_TBLOCKS', nemo_run)
    Gt_CVtime_list.append(ds_melt_param_CVtime['predicted_melt'].sel(metrics='Gt'))
    diff_Gt_CVtime, diff_box1_CVtime = _melt_differences(ds_melt_param_CVtime)
    diff_Gt_CVtime_list.append(diff_Gt_CVtime)
    diff_box1_CVtime_list.append(diff_box1_CVtime)
    # The reference is read from the first parameterisation of the CV-time files.
    ref_Gt_list.append(ds_melt_param_CVtime['reference_melt'].isel(param=0).sel(metrics='Gt'))
    ref_box1_list.append(ds_melt_param_CVtime['reference_melt'].isel(param=0).sel(metrics='box1'))

    ### CV ISF
    ds_melt_param_CVisf = _open_cv_metrics('CV_ISF', nemo_run)
    Gt_CVisf_list.append(ds_melt_param_CVisf['predicted_melt'].sel(metrics='Gt'))
    diff_Gt_CVisf, diff_box1_CVisf = _melt_differences(ds_melt_param_CVisf)
    diff_Gt_CVisf_list.append(diff_Gt_CVisf)
    diff_box1_CVisf_list.append(diff_box1_CVisf)

Gt_all_CVtime = xr.concat(Gt_CVtime_list, dim='nemo_run')
Gt_all_CVisf = xr.concat(Gt_CVisf_list, dim='nemo_run')
# The 'Gt' differences of all runs are appended along 'time' (not 'nemo_run'),
# so a mean over 'time' covers the time steps of every run.
diff_Gt_all_CVtime = xr.concat(diff_Gt_CVtime_list, dim='time')
diff_box1_all_CVtime = xr.concat(diff_box1_CVtime_list, dim='nemo_run')
diff_Gt_all_CVisf = xr.concat(diff_Gt_CVisf_list, dim='time')
diff_box1_all_CVisf = xr.concat(diff_box1_CVisf_list, dim='nemo_run')
ref_Gt_all = xr.concat(ref_Gt_list, dim='nemo_run')
ref_box1_all = xr.concat(ref_box1_list, dim='nemo_run')

In [None]:
# Mean of the reference 'box1' metric over all of its dimensions.
mean_box1_all = ref_box1_all.mean()
# Load the value into memory; as the last expression of the cell it is also displayed.
mean_box1_all.load()

In [None]:
# Reference 'Gt' metric averaged successively over time, runs and ice shelves.
# NOTE(review): the name suggests a total, but the last step is a mean over
# 'Nisf', not a sum — confirm this is intended.
tot_Gt = ref_Gt_all.mean('time').mean('nemo_run').mean('Nisf')
# Load the value into memory; as the last expression of the cell it is also displayed.
tot_Gt.load()

In [None]:
# Values of the 'param' coordinate (parameterisation names) selected for the plots below.
param_list_of_int = ['mini', 'small', 'medium', 'large', 'extra_large']


In [None]:
def _rmse(diff, dims):
    """Root-mean-square of ``diff`` taken over the dimensions listed in ``dims``."""
    return np.sqrt((diff ** 2).mean(dims))


# 'Gt' differences: RMSE over the (run-concatenated) time axis and the ice shelves.
gt_dims = ['time', 'Nisf']
# 'box1' differences (already time-averaged): RMSE over runs and ice shelves.
box1_dims = ['nemo_run', 'Nisf']

RMSE_Gt_all_CVtime = _rmse(diff_Gt_all_CVtime, gt_dims)
RMSE_box1_all_CVtime = _rmse(diff_box1_all_CVtime, box1_dims)
RMSE_Gt_all_CVisf = _rmse(diff_Gt_all_CVisf, gt_dims)
RMSE_box1_all_CVisf = _rmse(diff_box1_all_CVisf, box1_dims)

In [None]:
# Load the 'Gt' RMSE of the ice-shelf cross-validation into memory; as the
# last expression of the cell it is also displayed.
RMSE_Gt_all_CVisf.load()

In [None]:
# Two panels sharing the y axis (one row per parameterisation):
# left = RMSE of the 'Gt' metric, right = RMSE of the 'box1' metric.
fig, axs = plt.subplots(1, 2,figsize=(8.24*1.25/1.5,8.24/2),sharey=True)

ccolors = ['aqua', 'aquamarine', 'deepskyblue', 'blue', 'darkblue']

# One entry per cross-validation type:
# (RMSE for the left panel, RMSE for the right panel, black marker, marker size)
rmse_by_cv = [
    (RMSE_Gt_all_CVtime, RMSE_box1_all_CVtime, 'x', 50),   # CV over time blocks
    (RMSE_Gt_all_CVisf, RMSE_box1_all_CVisf, '+', 120),    # CV over ice shelves
]

for k,mparam in enumerate(param_list_of_int[::-1]):

    ccolor = ccolors[k]

    for rmse_Gt, rmse_box1, cv_marker, cv_size in rmse_by_cv:
        for ax, rmse in zip(axs, (rmse_Gt, rmse_box1)):
            # Select once and reuse for both markers of this point.
            rmse_value = rmse.sel(param=mparam)
            # Black marker identifying the CV type, overlaid with a coloured dot
            # identifying the parameterisation.
            ax.scatter(rmse_value,mparam,marker=cv_marker,c='k', s=cv_size)
            ax.scatter(rmse_value,mparam,marker='o',c=ccolor, s=20)

axs[0].set_xlim(0,105)
axs[1].set_xlim(0,1)

# Hide the top and right spines of every panel. This is what seaborn's
# despine() does by default; seaborn is never imported in this notebook, so
# the call is done with matplotlib directly.
for ax in fig.axes:
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
fig.savefig(plot_path+'RMSE_scatter_box1_CV_NN.pdf')

In [None]:
# Mean over 'time' of the predicted-minus-reference 'Gt' difference for the
# ice-shelf cross-validation (the 'time' axis holds all runs concatenated).
mean_diff_Gt_CVisf = diff_Gt_all_CVisf.mean(['time'])

In [None]:
# Heatmap of the mean 'Gt' difference: one row per parameterisation, one
# column per ice shelf (ordered by region). The explicit transpose makes the
# row/column layout independent of the dimension order of the input.
plotted_var = mean_diff_Gt_CVisf.sel(param=param_list_of_int,Nisf=nisf_by_reg_list).transpose('param', 'Nisf', ...)


x = np.arange(len(plotted_var.param))  # row positions (parameterisations)
y = np.arange(len(plotted_var.Nisf))   # column positions (ice shelves)

fig, axs = plt.subplots(1, 1,figsize=((len(y)+1)/4,(len(x)+2)/2.75))
ax0 = axs.imshow(plotted_var.values, cmap=plt.cm.coolwarm, vmin=-20, vmax=20)

# Rows are labelled on the right-hand side with the 'Gt' RMSE of each
# parameterisation. Tick positions are fixed before the labels are set so
# that each label is bound to its own row.
axs.yaxis.tick_right()
axs.set_yticks(x)
axs.set_yticklabels(labels=np.round(RMSE_Gt_all_CVisf.sel(param=param_list_of_int).values,2))

# Columns are labelled at the top with the ice-shelf names.
axs.xaxis.tick_top()
axs.set_xticks(y)
axs.set_xticklabels(labels=isf_names.sel(Nisf=plotted_var.Nisf).values, rotation=90)


# Horizontal colourbar in its own axes below the heatmap.
cb_ax = fig.add_axes([0.15, 0.08, 0.7, 0.02])
#cb_ax = fig.add_axes([0.01, 0.2, 0.02, 0.7])
cbar = fig.colorbar(ax0, cax=cb_ax, extend='both',orientation='horizontal')
plt.tight_layout()
#fig.savefig(plot_path+'heatmap_RMSE_box1_CVisf_allruns.pdf')