In [None]:
"""
Created on Fri Sep 30 10:31 2022

Plot scatter to compare classic parameterisations - with evalmetrics1D

Author: Clara Burgard
"""

In [None]:
import glob
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import xarray as xr
from tqdm.notebook import tqdm

In [None]:
# Figures open in separate interactive Qt windows (needs a Qt-capable display;
# NOTE(review): this will fail on a headless kernel).
%matplotlib qt5

READ IN DATA

In [None]:
# Root of the user's data tree and the directory where the figures are saved.
home_path = '/bettik/burgardc/'
plot_path = home_path + 'PLOTS/NN_plots/1D_eval/'

In [None]:
# Ice-shelf masks and metadata (names, areas, front depths) on the NEMO 5 km OPM006 grid.
inputpath_mask = ('/bettik/burgardc/SCRIPTS/basal_melt_param/data/interim/'
                  'ANTARCTICA_IS_MASKS/nemo_5km_OPM006/')
file_isf_orig = xr.open_dataset(inputpath_mask + 'nemo_5km_isf_masks_and_info_and_distance_new_oneFRIS.nc')

# Step 1: keep only ice shelves with a finite maximum front bottom depth.
has_front_depth = np.isfinite(file_isf_orig['front_bot_depth_max'])
nonnan_Nisf = file_isf_orig['Nisf'].where(has_front_depth, drop=True).astype(int)
file_isf_nonnan = file_isf_orig.sel(Nisf=nonnan_Nisf)

# Step 2: of those, keep the ones with isf_area_here >= 2500 (units as stored in the file).
is_large = file_isf_nonnan['isf_area_here'] >= 2500
large_isf = file_isf_nonnan['Nisf'].where(is_large, drop=True)
file_isf = file_isf_nonnan.sel(Nisf=large_isf)

# Step 3: restore the umlaut in the Ekström name, which is stored as 'Ekstrom'.
raw_names = file_isf['isf_name']
file_isf['isf_name'] = raw_names.where(raw_names != 'Ekstrom', np.array('Ekström', dtype=object))
isf_names = file_isf['isf_name']

In [None]:
# Assign each ice shelf in file_isf to one of six regions, looked up by ice-shelf name.
# The dict replaces a long if/elif chain; every name string is unchanged.
REGION_MEMBERS = {
    'East and West Ross': ['Ross', 'Nickerson', 'Sulzberger', 'Cook'],
    'Weddell': ['Filchner', 'Ronne', 'Filchner-Ronne'],
    'Dronning Maud Land': ['Ekström', 'Nivl', 'Prince Harald', 'Riiser-Larsen', 'Fimbul',
                           'Roi Baudouin', 'Lazarev', 'Stancomb Brunt', 'Jelbart', 'Borchgrevink'],
    'Amundsen': ['Getz', 'Thwaites', 'Crosson', 'Dotson', 'Cosgrove', 'Pine Island'],
    'Peninsula and Bellinghausen': ['Venable', 'George VI', 'Abbot', 'Stange', 'Larsen C',
                                    'Bach', 'Larsen D', 'Wilkins'],
    'East Antarctica': ['Amery', 'Moscow Univ.', 'Tracy Tremenchus', 'Totten', 'West', 'Shackleton'],
}
# Inverted lookup: ice-shelf name -> region.
region_of_isf = {name: region for region, names in REGION_MEMBERS.items() for name in names}

region_list = []
unassigned_isf = []
for kisf in file_isf.Nisf:
    isf_name = file_isf['isf_name'].sel(Nisf=kisf).item()
    region = region_of_isf.get(isf_name)
    if region is None:
        unassigned_isf.append(str(isf_name))
    region_list.append(region)

# Any unassigned shelf would otherwise leave a region entry missing and break the
# assignment below; fail here with a message naming the offending shelves.
if unassigned_isf:
    raise ValueError('No region assigned for ice shelves: ' + ', '.join(unassigned_isf))

# 6 regions, one entry per Nisf, aligned explicitly on the Nisf coordinate.
file_isf['region'] = xr.DataArray(data=region_list, dims='Nisf', coords={'Nisf': file_isf['Nisf'].values})

In [None]:
# Ice-shelf IDs grouped region by region, in the display order given by `regions`.
regions = ['Weddell','Peninsula and Bellinghausen','Amundsen','East and West Ross','East Antarctica','Dronning Maud Land']
nisf_by_reg_list = np.concatenate([
    file_isf.Nisf.where(file_isf['region'] == region_name, drop=True).values
    for region_name in regions
])

NN results: load the 1D evaluation metrics of the neural-network parameterisations

In [None]:
# T/S input variant of the NN experiments: 'extrap' (selected) or 'whole'.
# NOTE(review): TS_opt is not read anywhere else in this notebook; the file globs select experiments by name.
TS_OPTIONS = ('extrap', 'whole')
TS_opt = TS_OPTIONS[0]

In [None]:
# Load the NN 1D evaluation metrics for every NEMO run and both cross-validation
# strategies: over time blocks (CV_TBLOCKS) and over ice shelves (CV_ISF).
# metrics='Gt' selects the integrated-melt series, metrics='box1' the box-1 mean melt.
run_list = ['OPM006','OPM016','OPM018','OPM021']

diff_Gt_CVtime_list = []
diff_box1_CVtime_list = []
diff_Gt_CVisf_list = []
diff_box1_CVisf_list = []

ref_Gt_list = []
ref_box1_list = []
Gt_CVtime_list = []
Gt_CVisf_list = []

outputpath_melt = '/bettik/burgardc/DATA/NN_PARAM/processed/MELT_RATE/'
NN_EVAL_PREFIX = 'evalmetrics_1D_CV_'


def nn_param_names(paths):
    """Return the experiment name encoded in each NN evalmetrics file name.

    '<dir>/evalmetrics_1D_CV_<name>_normstd_<run>.nc' -> '<name>'.
    Parsing the basename, rather than fixed character offsets into the full path,
    works for every directory, so CV_TBLOCKS and CV_ISF files are labelled independently.
    """
    names = []
    for path in paths:
        stem = os.path.basename(path)[len(NN_EVAL_PREFIX):].split('.')[0]
        names.append('_'.join(stem.split('_')[:-2]))
    return names


def open_nn_eval(paths):
    """Open NN evalmetrics files stacked along a new 'param' dimension, labelled by experiment name."""
    ds = xr.open_mfdataset(paths, concat_dim='param', combine='nested', coords='minimal', compat='override')
    return ds.assign_coords(param=nn_param_names(paths))


def nn_errors(ds, n):
    """Return (predicted Gt, Gt error series, time-mean box-1 error) for NEMO run index n."""
    pred, ref = ds['predicted_melt'], ds['reference_melt']
    diff_Gt = (pred - ref).sel(metrics='Gt')
    # Relabel time as n*50 + 1, n*50 + 2, ... so that runs concatenated along 'time'
    # keep distinct labels. NOTE(review): assumes each run has at most 50 time steps.
    diff_Gt = diff_Gt.assign_coords({'time': np.arange(1, len(diff_Gt.time) + 1) + n * 50})
    diff_box1 = (pred.mean('time') - ref.mean('time')).sel(metrics='box1')
    return pred.sel(metrics='Gt'), diff_Gt, diff_box1


for n, nemo_run in enumerate(run_list):

    ### CV TIME
    melt_param_files_CVtime = sorted(glob.glob(outputpath_melt + 'CV_TBLOCKS/evalmetrics_1D_CV_*newbasic*_normstd_' + nemo_run + '.nc'))
    param_list = nn_param_names(melt_param_files_CVtime)
    ds_melt_param_CVtime = open_nn_eval(melt_param_files_CVtime)
    Gt_CVtime, diff_Gt_CVtime, diff_box1_CVtime = nn_errors(ds_melt_param_CVtime, n)
    Gt_CVtime_list.append(Gt_CVtime)
    diff_Gt_CVtime_list.append(diff_Gt_CVtime)
    diff_box1_CVtime_list.append(diff_box1_CVtime)
    # The reference is the same for every experiment, so take it from the first one.
    ref_Gt_list.append(ds_melt_param_CVtime['reference_melt'].isel(param=0).sel(metrics='Gt'))
    ref_box1_list.append(ds_melt_param_CVtime['reference_melt'].isel(param=0).sel(metrics='box1'))

    ### CV ISF (labelled from its own file names instead of reusing the CV-time list)
    melt_param_files_CVisf = sorted(glob.glob(outputpath_melt + 'CV_ISF/evalmetrics_1D_CV_*newbasic*_normstd_' + nemo_run + '.nc'))
    ds_melt_param_CVisf = open_nn_eval(melt_param_files_CVisf)
    Gt_CVisf, diff_Gt_CVisf, diff_box1_CVisf = nn_errors(ds_melt_param_CVisf, n)
    Gt_CVisf_list.append(Gt_CVisf)
    diff_Gt_CVisf_list.append(diff_Gt_CVisf)
    diff_box1_CVisf_list.append(diff_box1_CVisf)

Gt_all_CVtime_NN = xr.concat(Gt_CVtime_list, dim='nemo_run')
Gt_all_CVisf_NN = xr.concat(Gt_CVisf_list, dim='nemo_run')
diff_Gt_all_CVtime_NN = xr.concat(diff_Gt_CVtime_list, dim='time')
diff_box1_all_CVtime_NN = xr.concat(diff_box1_CVtime_list, dim='nemo_run')
diff_Gt_all_CVisf_NN = xr.concat(diff_Gt_CVisf_list, dim='time')
diff_box1_all_CVisf_NN = xr.concat(diff_box1_CVisf_list, dim='nemo_run')
ref_Gt_all = xr.concat(ref_Gt_list, dim='nemo_run')
ref_box1_all = xr.concat(ref_box1_list, dim='nemo_run')

In [None]:
# Inspect the CV-over-ice-shelves NN files found for the last NEMO run in run_list.
melt_param_files_CVisf

In [None]:
# Inspect the NN predicted integrated melt ('Gt'), CV over ice shelves, stacked along nemo_run.
Gt_all_CVisf_NN

Classic parameterisations: load the 1D evaluation metrics and the NEMO reference melt

In [None]:
# Load the classic parameterisations' 1D evaluation metrics and the NEMO reference melt
# for every NEMO run and both cross-validation strategies: over time (*_CVtime.nc) and
# over ice shelves (*_CVshelves.nc).
run_list = ['OPM006','OPM016','OPM018','OPM021']

diff_Gt_CVtime_list = []
diff_box1_CVtime_list = []
diff_Gt_CVisf_list = []
diff_box1_CVisf_list = []

ref_Gt_list = []
ref_box1_list = []
Gt_CVtime_list = []
Gt_CVisf_list = []

CLASSIC_EVAL_PREFIX = 'eval_metrics_1D_'


def classic_param_names(paths):
    """Return the parameterisation name encoded in each classic eval_metrics file name.

    Takes the basename after 'eval_metrics_1D_' and before the extension, then drops its
    last two '_'-separated tokens.
    Parsing the basename avoids fixed character offsets into the full path.
    """
    names = []
    for path in paths:
        stem = os.path.basename(path)[len(CLASSIC_EVAL_PREFIX):].split('.')[0]
        names.append('_'.join(stem.split('_')[:-2]))
    return names


def open_classic_eval(paths):
    """Open classic eval files stacked along 'param', labelled by the names parsed from the files."""
    ds = xr.open_mfdataset(paths, concat_dim='new_param', combine='nested', coords='minimal', compat='override')
    # The per-file 'param' (and, when present, 'option') coordinates are replaced by the parsed names.
    to_drop = ['param'] + (['option'] if 'option' in ds.coords else [])
    ds = ds.drop_vars(to_drop).rename({'new_param': 'param'})
    return ds.assign_coords(param=classic_param_names(paths))


def classic_errors(ds, ref_Gt, ref_box1, n):
    """Return (predicted Gt, Gt error series, time-mean box-1 error) for NEMO run index n."""
    Gt = ds['melt_1D_Gt_per_y']
    diff_Gt = Gt - ref_Gt
    # Relabel time as n*50 + 1, n*50 + 2, ... so that runs concatenated along 'time'
    # keep distinct labels. NOTE(review): assumes each run has at most 50 time steps.
    diff_Gt = diff_Gt.assign_coords({'time': np.arange(1, len(diff_Gt.time) + 1) + n * 50})
    diff_box1 = ds['melt_1D_mean_myr_box1'].mean('time') - ref_box1.mean('time')
    return Gt, diff_Gt, diff_box1


for n, nemo_run in enumerate(run_list):

    outputpath_melt = home_path + 'DATA/BASAL_MELT_PARAM/processed/MELT_RATE/nemo_5km_' + nemo_run + '/'

    ### READ IN THE REFERENCE
    NEMO_melt_rates_1D = xr.open_dataset(outputpath_melt + 'melt_rates_1D_NEMO_oneFRIS.nc')
    ref_Gt = NEMO_melt_rates_1D['melt_Gt_per_y_tot']
    ref_Gt_list.append(ref_Gt)
    NEMO_box1_myr = xr.open_dataset(outputpath_melt + 'melt_rates_box1_NEMO_oneFRIS.nc')
    ref_box1 = NEMO_box1_myr['mean_melt_box1_myr']
    ref_box1_list.append(ref_box1)

    ### READ IN THE PARAM FILES - CV TIME
    melt_param_files_CVtime = sorted(glob.glob(outputpath_melt + 'eval_metrics_1D_*_CVtime.nc'))
    param_list = classic_param_names(melt_param_files_CVtime)
    ds_melt_param_CVtime = open_classic_eval(melt_param_files_CVtime)
    Gt_CVtime, diff_Gt_CVtime, diff_box1_CVtime = classic_errors(ds_melt_param_CVtime, ref_Gt, ref_box1, n)
    Gt_CVtime_list.append(Gt_CVtime)
    diff_Gt_CVtime_list.append(diff_Gt_CVtime)
    diff_box1_CVtime_list.append(diff_box1_CVtime)

    ### READ IN THE PARAM FILES - CV ICE SHELVES
    # Labelled from their own file names rather than reusing the CV-time names, which
    # would silently mislabel parameters if the two file sets differed.
    melt_param_files_CVisf = sorted(glob.glob(outputpath_melt + 'eval_metrics_1D_*_CVshelves.nc'))
    ds_melt_param_CVisf = open_classic_eval(melt_param_files_CVisf)
    Gt_CVisf, diff_Gt_CVisf, diff_box1_CVisf = classic_errors(ds_melt_param_CVisf, ref_Gt, ref_box1, n)
    Gt_CVisf_list.append(Gt_CVisf)
    diff_Gt_CVisf_list.append(diff_Gt_CVisf)
    diff_box1_CVisf_list.append(diff_box1_CVisf)

Gt_all_CVtime = xr.concat(Gt_CVtime_list, dim='nemo_run')
Gt_all_CVisf = xr.concat(Gt_CVisf_list, dim='nemo_run')
diff_Gt_all_CVtime = xr.concat(diff_Gt_CVtime_list, dim='time')
diff_box1_all_CVtime = xr.concat(diff_box1_CVtime_list, dim='nemo_run')
diff_Gt_all_CVisf = xr.concat(diff_Gt_CVisf_list, dim='time')
diff_box1_all_CVisf = xr.concat(diff_box1_CVisf_list, dim='nemo_run')
ref_Gt_all = xr.concat(ref_Gt_list, dim='nemo_run')
ref_box1_all = xr.concat(ref_box1_list, dim='nemo_run')

In [None]:
# Fixed display order of ice-shelf IDs (ordering criterion not stated here — TODO document),
# keeping only the IDs present in nisf_by_reg_list.
sorted_isf_all = [11,69,43,12,70,44,29,13,58,71,45,30,31,61,73,47,32,48,33,17,62,49,34,18,10,65,51,22,38,52,23,66,53,39,24,40,54,75,25,26,42,55]
special_nisf_list = [nisf_id for nisf_id in sorted_isf_all if nisf_id in nisf_by_reg_list]

In [None]:
# Inspect the ice-shelf IDs used for all RMSE computations below.
special_nisf_list

In [None]:
# Integrated-melt errors of the classic parameterisations (profile_domain=50) and of the
# NNs, stacked on one 'param' axis, for each cross-validation strategy.
diff_Gt_all_param_together_CVtime, diff_Gt_all_param_together_CVisf = (
    xr.concat([classic_diff.sel(profile_domain=50), nn_diff], dim='param')
    for classic_diff, nn_diff in ((diff_Gt_all_CVtime, diff_Gt_all_CVtime_NN),
                                  (diff_Gt_all_CVisf, diff_Gt_all_CVisf_NN))
)

In [None]:
# Box-1 mean-melt errors: classic parameterisations (profile_domain=50) and NNs
# stacked on one 'param' axis, for each cross-validation strategy.
diff_box1_all_param_together_CVtime, diff_box1_all_param_together_CVisf = (
    xr.concat([classic_diff.sel(profile_domain=50), nn_diff], dim='param')
    for classic_diff, nn_diff in ((diff_box1_all_CVtime, diff_box1_all_CVtime_NN),
                                  (diff_box1_all_CVisf, diff_box1_all_CVisf_NN))
)

In [None]:
# RMSE per ice shelf (restricted to special_nisf_list): the squared error is averaged over
# time for the integrated melt and over NEMO runs for the box-1 mean melt.
(RMSE_Gt_all_CVtime_Nisf, RMSE_box1_all_CVtime_Nisf,
 RMSE_Gt_all_CVisf_Nisf, RMSE_box1_all_CVisf_Nisf) = (
    np.sqrt((diff ** 2).sel(Nisf=special_nisf_list).mean([avg_dim]))
    for diff, avg_dim in ((diff_Gt_all_param_together_CVtime, 'time'),
                          (diff_box1_all_param_together_CVtime, 'nemo_run'),
                          (diff_Gt_all_param_together_CVisf, 'time'),
                          (diff_box1_all_param_together_CVisf, 'nemo_run'))
)

In [None]:
# Inspect the per-ice-shelf integrated-melt RMSE, CV over time.
RMSE_Gt_all_CVtime_Nisf

In [None]:
# Per-ice-shelf RMSE of integrated melt (Gt/yr), CV over ice shelves:
# one row per parameterisation, one column per ice shelf.
# param_list_of_int is defined in a later cell; if it does not exist yet (fresh kernel,
# Run All), fall back to every available parameterisation instead of raising NameError.
params_to_plot = globals().get('param_list_of_int', list(RMSE_Gt_all_CVisf_Nisf.param.values))
plotted_var = RMSE_Gt_all_CVisf_Nisf.sel(param=params_to_plot).transpose('param', 'Nisf')

x = np.arange(len(plotted_var.param))
y = np.arange(len(plotted_var.Nisf))

fig, ax = plt.subplots(1, 1, figsize=((len(y)+1)/4, (len(x)+2)/2.75))
# RMSE is non-negative, so a sequential colour map starting at 0 uses the full colour range.
im = ax.imshow(plotted_var.values, cmap=plt.cm.Reds, vmin=0, vmax=100)
ax.yaxis.tick_right()
ax.set_yticks(x)
ax.set_yticklabels(plotted_var.param.values)
ax.xaxis.tick_top()
ax.set_xticks(y)
ax.set_xticklabels(labels=isf_names.sel(Nisf=plotted_var.Nisf).values, rotation=90)
cbar = fig.colorbar(im, ax=ax, label='RMSE integrated melt (Gt/yr)')


In [None]:
# Overall RMSE per parameterisation: the squared error is averaged over all ice shelves
# in special_nisf_list and over time (integrated melt) or NEMO runs (box-1 mean melt).
(RMSE_Gt_all_CVtime, RMSE_box1_all_CVtime,
 RMSE_Gt_all_CVisf, RMSE_box1_all_CVisf) = (
    np.sqrt((diff ** 2).sel(Nisf=special_nisf_list).mean([avg_dim, 'Nisf']))
    for diff, avg_dim in ((diff_Gt_all_param_together_CVtime, 'time'),
                          (diff_box1_all_param_together_CVtime, 'nemo_run'),
                          (diff_Gt_all_param_together_CVisf, 'time'),
                          (diff_box1_all_param_together_CVisf, 'nemo_run'))
)

In [None]:
# Parameterisations shown in the summary figures, in plotting order (the scatter cells
# iterate over these lists in reverse). The first five are classic parameterisations;
# the rest are NN experiments whose names end in '_newbasic2_extrap'.
# (A dead duplicate assignment of this list and the commented-out alternative lists were removed.)
param_list_of_int = ['quadratic_local',
                     'quadratic_local_locslope',
                     'lazero19_2',
                     'boxes_4_pismyes_picopno',
                     'boxes_4_pismno_picopyes',
                     'mini_newbasic2_extrap',
                     'xsmall96_newbasic2_extrap',
                     'small_newbasic2_extrap',
                     'medium_newbasic2_extrap',
                     'large_newbasic2_extrap',
                     'extra_large_newbasic2_extrap',
                     ]

# NN experiments compared among themselves (different input-variable sets).
param_list_of_nn = ['medium_newbasic_extrap',
                    'medium_onlyTSisfdraft_extrap',
                    'medium_onlyTSdraftandslope_extrap',
                    'medium_TSTfdGLdIFwcd_extrap',
                    'medium_TSdraftbotandiceddandwcd_extrap',
                    'medium_TSdraftbotandiceddandwcdreldGL_extrap',
                    'medium_TSdraftslopereldGL_extrap',
                    'extra_large_TSdraftslopereldGL_extrap'
                    ]

# Families of classic parameterisations; the plotting cells use them to pick colours.
param_simple_list = ['linear_local', 'quadratic_local', 'quadratic_local_cavslope',
                     'quadratic_local_locslope', 'quadratic_mixed_cavslope',
                     'quadratic_mixed_locslope', 'quadratic_mixed_mean']
param_plume_list = ['lazero19_2','lazero19_modif2']
param_box_list_hetero = ['boxes_1_pismyes_picopno','boxes_2_pismyes_picopno','boxes_3_pismyes_picopno','boxes_4_pismyes_picopno']
param_picop_list = ['boxes_3_pismyes_picopyes', 'boxes_4_pismno_picopyes']

param_list_box_whiskers = ['linear_local', 'quadratic_local', 'quadratic_local_cavslope',
                           'quadratic_local_locslope', 'quadratic_mixed_mean','quadratic_mixed_cavslope',
                           'quadratic_mixed_locslope',
                           'lazero19_2',
                           'boxes_3_pismyes_picopno','boxes_4_pismyes_picopno',
                           'boxes_3_pismyes_picopyes', 'boxes_4_pismno_picopyes']

In [None]:
# Overall RMSE per parameterisation: integrated melt (left, Gt/yr) and box-1 mean melt
# (right, m/yr). Classic parameterisations in grey, NN experiments coloured.
# Markers: 'x' = CV over time blocks, '+' = CV over ice shelves.
fig, axs = plt.subplots(1, 2, figsize=(8.24*1.25/1.5, 8.24/2), sharey=True)

ccolors = ['steelblue','darkcyan','c','deepskyblue','skyblue','aquamarine','green','lightsteelblue','cyan','orange','magenta']
classic_params = param_plume_list + param_box_list_hetero + param_picop_list + param_simple_list

k = 0
for mparam in param_list_of_int[::-1]:
    if mparam in classic_params:
        ccolor = 'darkgrey'
    else:
        # Next NN colour; cycle instead of raising IndexError if NNs outnumber colours.
        ccolor = ccolors[k % len(ccolors)]
        k = k + 1

    for panel, rmse_CVtime, rmse_CVisf in ((axs[0], RMSE_Gt_all_CVtime, RMSE_Gt_all_CVisf),
                                           (axs[1], RMSE_box1_all_CVtime, RMSE_box1_all_CVisf)):
        panel.scatter(rmse_CVtime.sel(param=mparam), mparam, marker='x', c='k', s=50)
        panel.scatter(rmse_CVtime.sel(param=mparam), mparam, marker='o', c=ccolor, s=20)
        panel.scatter(rmse_CVisf.sel(param=mparam), mparam, marker='+', c='k', s=120)
        panel.scatter(rmse_CVisf.sel(param=mparam), mparam, marker='o', c=ccolor, s=20)

axs[0].set_xlim(0, 110)
axs[1].set_xlim(0, 2.15)
axs[0].set_xlabel('RMSE integrated melt (Gt/yr)')
axs[1].set_xlabel('RMSE mean melt box 1 (m/yr)')

# Hide the top and right spines (the default of seaborn's despine(); seaborn is not imported here).
for panel in axs:
    panel.spines['top'].set_visible(False)
    panel.spines['right'].set_visible(False)
#plt.savefig(plot_path+'RMSE_scatter_box1_CV_compare_classic_NN_newbasic.pdf')

Percentage change in RMSE relative to the small NN (small_newbasic2_extrap)

In [None]:
# For each parameterisation: absolute and relative RMSE difference to the reference NN,
#   perc = (RMSE_param - RMSE_ref) / RMSE_param * 100,
# for the box-1 mean melt, CV over ice shelves. Each row now starts with the parameter name.
reference_param = 'small_newbasic2_extrap'
rmse_of_int = RMSE_box1_all_CVisf
rmse_reference = rmse_of_int.sel(param=reference_param)
for mparam in ['quadratic_local',
               'quadratic_local_locslope',
               'lazero19_2',
               'boxes_4_pismyes_picopno',
               'boxes_4_pismno_picopyes',
               reference_param]:
    diff_to_eval = rmse_of_int.sel(param=mparam) - rmse_reference
    perc = diff_to_eval / rmse_of_int.sel(param=mparam) * 100
    print(mparam, diff_to_eval.load().values, np.round(perc.values, 2))

In [None]:
# diff_to_eval holds the value from the last loop iteration of the previous cell
# (small_newbasic2_extrap minus itself); .load() computes it in memory.
diff_to_eval.load()

In [None]:
# Overall RMSE per parameterisation: integrated melt (left, Gt/yr) and box-1 mean melt
# (right, m/yr), one colour per classic family, NNs in shades of blue/green.
# Markers: 'x' = CV over time blocks, '+' = CV over ice shelves.
fig, axs = plt.subplots(1, 2, figsize=(8.24*1.25/1.5, 8.24/2), sharey=True)

ccolors = ['steelblue','darkcyan','c','deepskyblue','skyblue','aquamarine','green','lightsteelblue','cyan','orange','magenta']
# Checked in this order; the first family containing the parameter wins.
family_colors = ((param_plume_list, 'darkorange'),
                 (param_box_list_hetero, 'purple'),
                 (param_picop_list, 'maroon'),
                 (param_simple_list, 'gold'))

k = 0
for mparam in param_list_of_int[::-1]:
    ccolor = next((colour for members, colour in family_colors if mparam in members), None)
    if ccolor is None:
        # Next NN colour; cycle instead of raising IndexError if NNs outnumber colours.
        ccolor = ccolors[k % len(ccolors)]
        k = k + 1

    for panel, rmse_CVtime, rmse_CVisf in ((axs[0], RMSE_Gt_all_CVtime, RMSE_Gt_all_CVisf),
                                           (axs[1], RMSE_box1_all_CVtime, RMSE_box1_all_CVisf)):
        panel.scatter(rmse_CVtime.sel(param=mparam), mparam, marker='x', c='k', s=50)
        panel.scatter(rmse_CVtime.sel(param=mparam), mparam, marker='o', c=ccolor, s=20)
        panel.scatter(rmse_CVisf.sel(param=mparam), mparam, marker='+', c='k', s=120)
        panel.scatter(rmse_CVisf.sel(param=mparam), mparam, marker='o', c=ccolor, s=20)

axs[0].set_xlim(0, 70)
axs[1].set_xlim(0, 2.15)
axs[0].set_xlabel('RMSE integrated melt (Gt/yr)')
axs[1].set_xlabel('RMSE mean melt box 1 (m/yr)')

# Hide the top and right spines (the default of seaborn's despine(); seaborn is not imported here).
for panel in axs:
    panel.spines['top'].set_visible(False)
    panel.spines['right'].set_visible(False)
plt.savefig(plot_path+'RMSE_scatter_box1_CV_compare_classic_NN_newbasic_forOPENworkshop.pdf')

In [None]:
# Overall RMSE as in the all-shelves computation, but excluding Nisf 10 and 11
# (presumably the largest ice shelves, going by the '_wolargeones' name — TODO confirm).
largest_isf = [10, 11]
(RMSE_Gt_all_CVtime_wolargeones, RMSE_box1_all_CVtime_wolargeones,
 RMSE_Gt_all_CVisf_wolargeones, RMSE_box1_all_CVisf_wolargeones) = (
    np.sqrt((diff ** 2).sel(Nisf=special_nisf_list).drop_sel(Nisf=largest_isf).mean([avg_dim, 'Nisf']))
    for diff, avg_dim in ((diff_Gt_all_param_together_CVtime, 'time'),
                          (diff_box1_all_param_together_CVtime, 'nemo_run'),
                          (diff_Gt_all_param_together_CVisf, 'time'),
                          (diff_box1_all_param_together_CVisf, 'nemo_run'))
)

In [None]:
# Same scatter as the grey/coloured overview, but with RMSEs computed without Nisf 10 and 11.
# Integrated melt (left, Gt/yr) and box-1 mean melt (right, m/yr); classic in grey, NNs coloured.
# Markers: 'x' = CV over time blocks, '+' = CV over ice shelves.
fig, axs = plt.subplots(1, 2, figsize=(8.24*1.25/1.5, 8.24/2), sharey=True)

ccolors = ['steelblue','darkcyan','c','deepskyblue','skyblue','aquamarine','green','lightsteelblue','cyan','orange','magenta']
classic_params = param_plume_list + param_box_list_hetero + param_picop_list + param_simple_list

k = 0
for mparam in param_list_of_int[::-1]:
    if mparam in classic_params:
        ccolor = 'darkgrey'
    else:
        # Next NN colour; cycle instead of raising IndexError if NNs outnumber colours.
        ccolor = ccolors[k % len(ccolors)]
        k = k + 1

    for panel, rmse_CVtime, rmse_CVisf in ((axs[0], RMSE_Gt_all_CVtime_wolargeones, RMSE_Gt_all_CVisf_wolargeones),
                                           (axs[1], RMSE_box1_all_CVtime_wolargeones, RMSE_box1_all_CVisf_wolargeones)):
        panel.scatter(rmse_CVtime.sel(param=mparam), mparam, marker='x', c='k', s=50)
        panel.scatter(rmse_CVtime.sel(param=mparam), mparam, marker='o', c=ccolor, s=20)
        panel.scatter(rmse_CVisf.sel(param=mparam), mparam, marker='+', c='k', s=120)
        panel.scatter(rmse_CVisf.sel(param=mparam), mparam, marker='o', c=ccolor, s=20)

axs[0].set_xlim(0, 60)
axs[1].set_xlim(0, 2.15)
axs[0].set_xlabel('RMSE integrated melt (Gt/yr)')
axs[1].set_xlabel('RMSE mean melt box 1 (m/yr)')

# Hide the top and right spines (the default of seaborn's despine(); seaborn is not imported here).
for panel in axs:
    panel.spines['top'].set_visible(False)
    panel.spines['right'].set_visible(False)
plt.savefig(plot_path+'RMSE_scatter_box1_CV_compare_classic_NN_newbasic_wolargeones.pdf')

In [None]:
# Compute and inspect the integrated-melt RMSE (CV over time) without Nisf 10 and 11.
RMSE_Gt_all_CVtime_wolargeones.load()

In [None]:
# Overall integrated-melt RMSE (CV over time blocks) for each parameterisation of interest.
for param_name in param_list_of_int:
    rmse_value = RMSE_Gt_all_CVtime.sel(param=param_name).values
    print(param_name, rmse_value)

In [None]:
# Overall integrated-melt RMSE (CV over ice shelves) for each parameterisation of interest.
for param_name in param_list_of_int:
    rmse_value = RMSE_Gt_all_CVisf.sel(param=param_name).values
    print(param_name, rmse_value)

In [None]:
# Overall box-1 mean-melt RMSE (CV over time blocks) for each parameterisation of interest.
for param_name in param_list_of_int:
    rmse_value = RMSE_box1_all_CVtime.sel(param=param_name).values
    print(param_name, rmse_value)

In [None]:
# Overall box-1 mean-melt RMSE (CV over ice shelves) for each parameterisation of interest.
for param_name in param_list_of_int:
    rmse_value = RMSE_box1_all_CVisf.sel(param=param_name).values
    print(param_name, rmse_value)

In [None]:
# Overall RMSE of the NN input-variable experiments (param_list_of_nn): integrated melt
# (left, Gt/yr) and box-1 mean melt (right, m/yr).
# Markers: 'x' = CV over time blocks, '+' = CV over ice shelves.
fig, axs = plt.subplots(1, 2, figsize=(8.24*1.25/1.5, 8.24/2), sharey=True)

ccolors = ['aqua', 'aquamarine', 'deepskyblue', 'blue', 'darkblue', 'royalblue','green','lightsteelblue','cyan','orange','magenta']
# Checked in this order; the first family containing the parameter wins.
family_colors = ((param_plume_list, 'darkorange'),
                 (param_box_list_hetero, 'purple'),
                 (param_picop_list, 'maroon'),
                 (param_simple_list, 'gold'))

k = 0
for mparam in param_list_of_nn[::-1]:
    ccolor = next((colour for members, colour in family_colors if mparam in members), None)
    if ccolor is None:
        # Next NN colour; cycle instead of raising IndexError if NNs outnumber colours.
        ccolor = ccolors[k % len(ccolors)]
        k = k + 1

    for panel, rmse_CVtime, rmse_CVisf in ((axs[0], RMSE_Gt_all_CVtime, RMSE_Gt_all_CVisf),
                                           (axs[1], RMSE_box1_all_CVtime, RMSE_box1_all_CVisf)):
        panel.scatter(rmse_CVtime.sel(param=mparam), mparam, marker='x', c='k', s=50)
        panel.scatter(rmse_CVtime.sel(param=mparam), mparam, marker='o', c=ccolor, s=20)
        panel.scatter(rmse_CVisf.sel(param=mparam), mparam, marker='+', c='k', s=120)
        panel.scatter(rmse_CVisf.sel(param=mparam), mparam, marker='o', c=ccolor, s=20)

axs[0].set_xlim(0, 100)
axs[1].set_xlim(0, 2.05)
axs[0].set_xlabel('RMSE integrated melt (Gt/yr)')
axs[1].set_xlabel('RMSE mean melt box 1 (m/yr)')

# Hide the top and right spines (the default of seaborn's despine(); seaborn is not imported here).
for panel in axs:
    panel.spines['top'].set_visible(False)
    panel.spines['right'].set_visible(False)
plt.savefig(plot_path+'RMSE_scatter_box1_CV_compare_classic_NN_experiments.pdf')

In [None]:
# Integrated-melt RMSE (CV over time blocks), listed in reverse order of param_list_of_int.
for param_name in reversed(param_list_of_int):
    rmse_value = RMSE_Gt_all_CVtime.sel(param=param_name).values
    print(param_name, rmse_value)

In [None]:
# Inspect which parameterisation labels are available on the 'param' axis.
RMSE_Gt_all_CVtime.param

In [None]:
# Time-mean integrated-melt error per ice shelf (CV over ice shelves) for the
# parameterisations of interest; columns are ice shelves grouped by region.
# Right-hand tick labels give each parameterisation's overall RMSE (Gt/yr).
plotted_var = (diff_Gt_all_param_together_CVisf.mean('time')
               .sel(param=param_list_of_int, Nisf=nisf_by_reg_list)
               .transpose('param', 'Nisf'))

x = np.arange(len(plotted_var.param))
y = np.arange(len(plotted_var.Nisf))

fig, axs = plt.subplots(1, 1, figsize=((len(y)+1)/4, (len(x)+2)/2.75))
ax0 = axs.imshow(plotted_var.values, cmap=plt.cm.coolwarm, vmin=-100, vmax=100)

# Tick positions must be fixed before their labels, otherwise the labels are attached
# to whatever ticks the automatic locator produced.
axs.yaxis.tick_right()
axs.set_yticks(x)
axs.set_yticklabels(labels=np.round(RMSE_Gt_all_CVisf.sel(param=param_list_of_int).values, 2))
axs.xaxis.tick_top()
axs.set_xticks(y)
axs.set_xticklabels(labels=isf_names.sel(Nisf=plotted_var.Nisf).values, rotation=90)

cb_ax = fig.add_axes([0.15, 0.08, 0.7, 0.02])
cbar = fig.colorbar(ax0, cax=cb_ax, extend='both', orientation='horizontal')
cbar.set_label('Mean integrated melt error (Gt/yr)')
plt.tight_layout()


Heat map for the NN experiments only

In [None]:
# NN experiments only: time-mean integrated-melt error per ice shelf (CV over ice shelves);
# columns are ice shelves grouped by region.
# Right-hand tick labels give each experiment's overall RMSE (Gt/yr).
plotted_var = (diff_Gt_all_param_together_CVisf.mean('time')
               .sel(param=param_list_of_nn, Nisf=nisf_by_reg_list)
               .transpose('param', 'Nisf'))

x = np.arange(len(plotted_var.param))
y = np.arange(len(plotted_var.Nisf))

fig, axs = plt.subplots(1, 1, figsize=((len(y)+1)/4, (len(x)+2)/2.75))
ax0 = axs.imshow(plotted_var.values, cmap=plt.cm.coolwarm, vmin=-100, vmax=100)

# Tick positions before labels; the RMSE labels must come from the same list as the rows
# (param_list_of_nn, not param_list_of_int), otherwise they are attached to the wrong rows.
axs.yaxis.tick_right()
axs.set_yticks(x)
axs.set_yticklabels(labels=np.round(RMSE_Gt_all_CVisf.sel(param=param_list_of_nn).values, 2))
axs.xaxis.tick_top()
axs.set_xticks(y)
axs.set_xticklabels(labels=isf_names.sel(Nisf=plotted_var.Nisf).values, rotation=90)

cb_ax = fig.add_axes([0.15, 0.08, 0.7, 0.02])
cbar = fig.colorbar(ax0, cax=cb_ax, extend='both', orientation='horizontal')
cbar.set_label('Mean integrated melt error (Gt/yr)')
plt.tight_layout()

In [None]:
# Run-mean box-1 mean-melt error per ice shelf (CV over ice shelves) for the
# parameterisations of interest; columns are ice shelves grouped by region.
# Right-hand tick labels give each parameterisation's overall box-1 RMSE (m/yr).
plotted_var = (diff_box1_all_param_together_CVisf.mean('nemo_run')
               .sel(param=param_list_of_int, Nisf=nisf_by_reg_list)
               .transpose('param', 'Nisf'))

x = np.arange(len(plotted_var.param))
y = np.arange(len(plotted_var.Nisf))

fig, axs = plt.subplots(1, 1, figsize=((len(y)+1)/4, (len(x)+2)/2.75))
ax0 = axs.imshow(plotted_var.values, cmap=plt.cm.coolwarm, vmin=-3, vmax=3)

# Tick positions must be fixed before their labels so each label lands on its row.
axs.yaxis.tick_right()
axs.set_yticks(x)
axs.set_yticklabels(labels=np.round(RMSE_box1_all_CVisf.sel(param=param_list_of_int).values, 2))
axs.xaxis.tick_top()
axs.set_xticks(y)
axs.set_xticklabels(labels=isf_names.sel(Nisf=plotted_var.Nisf).values, rotation=90)

cb_ax = fig.add_axes([0.15, 0.08, 0.7, 0.02])
cbar = fig.colorbar(ax0, cax=cb_ax, extend='both', orientation='horizontal')
cbar.set_label('Mean melt error in box 1 (m/yr)')
plt.tight_layout()
