In [None]:
"""
Created on Wed Apr 20 10:58 2022

Evaluating results computed with NN model (1D)

Author: @claraburgard

"""

In [None]:
import xarray as xr
import numpy as np
import pandas as pd
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt

READ IN DATA

In [None]:
# Root directory of the user's data tree; data paths below are built on it.
home_path: str = '/bettik/burgardc/'

In [None]:
# Ice-shelf masks and metadata on the NEMO 5 km grid (mask file from the
# nemo_5km_OPM006 directory). Built from home_path rather than repeating
# the hard-coded prefix.
inputpath_mask = home_path+'SCRIPTS/basal_melt_param/data/interim/ANTARCTICA_IS_MASKS/nemo_5km_OPM006/'
file_isf_orig = xr.open_dataset(inputpath_mask+'nemo_5km_isf_masks_and_info_and_distance_new.nc')

# Keep only ice shelves with a finite `front_bot_depth_max`.
nonnan_Nisf = file_isf_orig['Nisf'].where(np.isfinite(file_isf_orig['front_bot_depth_max']), drop=True).astype(int)
file_isf_nonnan = file_isf_orig.sel(Nisf=nonnan_Nisf)

# Keep only ice shelves with `isf_area_here` >= MIN_ISF_AREA (units as stored
# in the mask file -- presumably km^2; TODO confirm).
MIN_ISF_AREA = 2500
large_isf = file_isf_nonnan['Nisf'].where(file_isf_nonnan['isf_area_here'] >= MIN_ISF_AREA, drop=True)
file_isf = file_isf_nonnan.sel(Nisf=large_isf)

# Use the accented spelling "Ekström" for this ice shelf's name.
file_isf['isf_name'] = file_isf['isf_name'].where(file_isf['isf_name'] != 'Ekstrom', np.array('Ekström', dtype=object))
isf_names = file_isf['isf_name']

In [None]:
# Assign each ice shelf in `file_isf` to one of six regions, stored as
# file_isf['region'] (aligned with the Nisf dimension).
isf_by_region = {
    'East and West Ross': ['Ross', 'Nickerson', 'Sulzberger', 'Cook'],
    'Weddell': ['Filchner', 'Ronne'],
    'Dronning Maud Land': ['Ekström', 'Nivl', 'Prince Harald', 'Riiser-Larsen', 'Fimbul',
                           'Roi Baudouin', 'Lazarev', 'Stancomb Brunt', 'Jelbart', 'Borchgrevink'],
    'Amundsen': ['Getz', 'Thwaites', 'Crosson', 'Dotson', 'Cosgrove', 'Pine Island'],
    'Peninsula and Bellinghausen': ['Venable', 'George VI', 'Abbot', 'Stange', 'Larsen C',
                                    'Bach', 'Larsen D', 'Wilkins'],
    'East Antarctica': ['Amery', 'Moscow Univ.', 'Tracy Tremenchus', 'Totten', 'West', 'Shackleton'],
}
# Inverse lookup: ice-shelf name -> region.
region_of_isf = {name: reg for reg, names in isf_by_region.items() for name in names}

region_list = []
unassigned_isf = []
for kisf in file_isf.Nisf:
    isf_name = file_isf['isf_name'].sel(Nisf=kisf).item()
    if isf_name in region_of_isf:
        region_list.append(region_of_isf[isf_name])
    else:
        unassigned_isf.append(isf_name)

# Fail here with a clear message. Skipping a shelf would misalign
# region_list with Nisf and break the assignment below confusingly.
if unassigned_isf:
    raise ValueError('Argh, help me, these ice shelves have no region assigned: '
                     + ', '.join(str(name) for name in unassigned_isf))

file_isf['region'] = xr.DataArray(data=region_list, dims='Nisf')

# Plotting order of the 6 regions (must cover every region defined above).
regions = ['Weddell','Peninsula and Bellinghausen','Amundsen','East and West Ross','East Antarctica','Dronning Maud Land']
assert set(regions) == set(isf_by_region), 'regions and isf_by_region disagree'

In [None]:
# Ice-shelf IDs concatenated region by region, in the order of `regions`;
# used to group the heatmap columns by region.
nisf_by_reg_list = np.concatenate([
    file_isf.Nisf.where(file_isf['region'] == reg, drop=True).values
    for reg in regions
])

In [None]:
# NEMO runs to evaluate. 'OPM031-1' / 'OPM031-2' both use the OPM031 NEMO
# reference but have their own NN evaluation files ('OPM031-1' currently excluded).
run_list = ['OPM006','OPM016','OPM018','OPM021','OPM026','OPM027','OPM031-2']
# Time tags naming the NN model output directories to compare.
timetag_list = ['20220427-0957','20220427-1002',
                '20220427-1052','20220427-1021',
                '20220427-1058','20220427-1042',
                '20220427-1059','20220427-1051']

outputpath_melt_nn = home_path+'DATA/NN_PARAM/processed/MELT_RATE/'

diff_Gt_list = []
diff_box1_list = []

ref_Gt_list = []
ref_box1_list = []

for nemo_run0 in run_list:

    # Both OPM031 variants share the OPM031 NEMO reference output.
    nemo_run = 'OPM031' if nemo_run0 in ['OPM031-1','OPM031-2'] else nemo_run0
    outputpath_melt = home_path+'DATA/BASAL_MELT_PARAM/processed/MELT_RATE/nemo_5km_'+nemo_run+'/'

    ### READ IN THE REFERENCE (NEMO)
    # Load into memory and close each file so handles do not pile up across
    # the loop; lazily-backed arrays would be unusable after closing.
    with xr.open_dataset(outputpath_melt+'melt_rates_1D_NEMO.nc') as NEMO_melt_rates_1D:
        ref_Gt = NEMO_melt_rates_1D['melt_Gt_per_y_tot'].load()
    with xr.open_dataset(outputpath_melt+'melt_rates_box1_NEMO.nc') as NEMO_box1_myr:
        ref_box1 = NEMO_box1_myr['mean_melt_box1_myr'].load()
    ref_Gt_list.append(ref_Gt)
    ref_box1_list.append(ref_box1)
    # Same for every NN model -> compute once per run.
    ref_box1_tmean = ref_box1.mean('time')

    ### READ IN THE NN PARAM FILES (NON BOOTSTRAP): model minus reference
    diff_Gt_sub_list = []
    diff_box1_sub_list = []

    for timetag in timetag_list:
        new_path_output = outputpath_melt_nn+timetag+'/'
        with xr.open_dataset(new_path_output+'eval_metrics_'+nemo_run0+'.nc') as ds_melt_param:
            diff_Gt = ds_melt_param['melt_1D_Gt_per_y'].load() - ref_Gt
            diff_box1 = ds_melt_param['melt_1D_mean_myr_box1'].load().mean('time') - ref_box1_tmean
        diff_Gt_sub_list.append(diff_Gt)
        diff_box1_sub_list.append(diff_box1)

    diff_Gt_sub = xr.concat(diff_Gt_sub_list, dim='nn_model')
    diff_Gt_sub = diff_Gt_sub.assign_coords(nn_model=timetag_list)
    diff_box1_sub = xr.concat(diff_box1_sub_list, dim='nn_model')
    diff_box1_sub = diff_box1_sub.assign_coords(nn_model=timetag_list)

    diff_Gt_list.append(diff_Gt_sub)
    diff_box1_list.append(diff_box1_sub)

# Stack everything along a labelled `nemo_run` dimension.
diff_Gt_all = xr.concat(diff_Gt_list, dim='nemo_run')
diff_Gt_all = diff_Gt_all.assign_coords(nemo_run=run_list)
diff_box1_all = xr.concat(diff_box1_list, dim='nemo_run')
diff_box1_all = diff_box1_all.assign_coords(nemo_run=run_list)
ref_Gt_all = xr.concat(ref_Gt_list, dim='nemo_run')
ref_Gt_all = ref_Gt_all.assign_coords(nemo_run=run_list)
ref_box1_all = xr.concat(ref_box1_list, dim='nemo_run')
ref_box1_all = ref_box1_all.assign_coords(nemo_run=run_list)

In [None]:
# RMSE of each NN model for each NEMO run, pooled over ice shelves
# (and over time for the integrated melt).
RMSE_Gt_all = np.sqrt(np.square(diff_Gt_all).mean(dim=('time', 'Nisf')))
RMSE_box1_all = np.sqrt(np.square(diff_box1_all).mean(dim=('Nisf',)))

In [None]:
# Mean NEMO reference melt per run, averaged over time and ice shelves.
mean_Gt, mean_box1 = (ref.mean(dim=('time', 'Nisf')) for ref in (ref_Gt_all, ref_box1_all))

FIGURE: SCATTER OF THE RMSE OF EACH NEMO RUN, FOR EACH NN MODEL

In [None]:
# One row per NN model, one dot per NEMO run (coloured by run): RMSE of the
# integrated melt (left) and of the box-1 mean melt (right).
from matplotlib.lines import Line2D

RUN_COLORS = {
    'OPM006': 'magenta',
    'OPM016': 'orange',
    'OPM018': 'brown',
    'OPM021': 'red',
    'OPM026': 'yellowgreen',
    'OPM027': 'deepskyblue',
    'OPM031-1': 'blue',
    'OPM031-2': 'purple',
}
FALLBACK_COLOR = 'grey'  # any run missing from RUN_COLORS

fig, axs = plt.subplots(1, 2, figsize=(8.24*1.25/1.5, 8.24*1.25/2), sharey=True)

for mparam in timetag_list[::-1]:
    for nrun in run_list:
        ccolor = RUN_COLORS.get(nrun, FALLBACK_COLOR)
        # Always select the NN model; also pick domain 50 when that coordinate exists.
        sel_kw = dict(nn_model=mparam, nemo_run=nrun)
        if 'profile_domain' in RMSE_Gt_all.coords:
            sel_kw['profile_domain'] = 50
        axs[0].scatter(RMSE_Gt_all.sel(**sel_kw), mparam, marker='o', c=ccolor)
        axs[1].scatter(RMSE_box1_all.sel(**sel_kw), mparam, marker='o', c=ccolor)

axs[0].set_xlim(0, 20)
axs[0].set_xlabel('RMSE integrated melt (Gt/yr)')
axs[1].set_xlabel('RMSE box-1 mean melt (m/yr)')
axs[1].legend(handles=[Line2D([], [], marker='o', linestyle='', color=RUN_COLORS.get(r, FALLBACK_COLOR), label=r)
                       for r in run_list],
              frameon=False, loc='best')
# Equivalent of seaborn.despine() without requiring seaborn.
for ax in axs:
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)


FIGURE: SCATTER COMPARING THE RMSE ON TRAINING AND NON-TRAINING RUNS

In [None]:
# NEMO runs used to train the NN models vs. runs held out from training.
training_runs = ['OPM006', 'OPM016', 'OPM018', 'OPM031-2']
other_runs = ['OPM021', 'OPM026', 'OPM027']  # 'OPM031-1' excluded


def _pooled_rmse(diff, runs, dims):
    """RMSE of `diff` pooled over `dims` and over the NEMO runs listed in `runs`."""
    return np.sqrt((diff**2).sel(nemo_run=runs).mean(dims + ['nemo_run']))


RMSE_Gt_all_train = _pooled_rmse(diff_Gt_all, training_runs, ['time', 'Nisf'])
RMSE_box1_all_train = _pooled_rmse(diff_box1_all, training_runs, ['Nisf'])
RMSE_Gt_all_other = _pooled_rmse(diff_Gt_all, other_runs, ['time', 'Nisf'])
RMSE_box1_all_other = _pooled_rmse(diff_box1_all, other_runs, ['Nisf'])

In [None]:
# Display: integrated-melt RMSE per NN model, pooled over the training runs.
RMSE_Gt_all_train

In [None]:
# Display: integrated-melt RMSE per NN model, pooled over the non-training runs.
RMSE_Gt_all_other

In [None]:
# One row per NN model: pooled RMSE on the training runs (orange) vs. the
# non-training runs (blue), for integrated melt (left) and box-1 melt (right).
from matplotlib.lines import Line2D

TRAIN_COLOR = 'orange'
OTHER_COLOR = 'deepskyblue'

fig, axs = plt.subplots(1, 2, figsize=(8.24*1.25/1.5, 8.24*1.25/2), sharey=True)

for mparam in timetag_list[::-1]:
    axs[0].scatter(RMSE_Gt_all_train.sel(nn_model=mparam), mparam, marker='o', c=TRAIN_COLOR)
    axs[1].scatter(RMSE_box1_all_train.sel(nn_model=mparam), mparam, marker='o', c=TRAIN_COLOR)

    axs[0].scatter(RMSE_Gt_all_other.sel(nn_model=mparam), mparam, marker='o', c=OTHER_COLOR)
    axs[1].scatter(RMSE_box1_all_other.sel(nn_model=mparam), mparam, marker='o', c=OTHER_COLOR)

axs[0].set_xlim(0, 80)
axs[0].set_xlabel('RMSE integrated melt (Gt/yr)')
axs[1].set_xlabel('RMSE box-1 mean melt (m/yr)')
axs[1].legend(handles=[Line2D([], [], marker='o', linestyle='', color=TRAIN_COLOR, label='training runs'),
                       Line2D([], [], marker='o', linestyle='', color=OTHER_COLOR, label='non-training runs')],
              frameon=False, loc='best')
# Equivalent of seaborn.despine() without requiring seaborn.
for ax in axs:
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)

HEATMAP: RMSE PER ICE SHELF AND NN MODEL (NON-TRAINING RUNS)

In [None]:
# Per-ice-shelf RMSE of each NN model, pooled over the non-training runs
# (and over time for the integrated melt).
RMSE_Gt_all_isf, RMSE_box1_all_isf = (
    np.sqrt(np.square(diff.sel(nemo_run=other_runs)).mean(dim=dims))
    for diff, dims in ((diff_Gt_all, ['time', 'nemo_run']),
                       (diff_box1_all, ['nemo_run']))
)

In [None]:
# Rows: NN models (timetag_list order); columns: ice shelves grouped by region.
# Transpose explicitly so the image layout does not depend on the dimension
# order produced by the arithmetic upstream.
plotted_var = (RMSE_Gt_all_isf
               .sel(nn_model=timetag_list, Nisf=nisf_by_reg_list)
               .transpose('nn_model', 'Nisf'))

row_pos = np.arange(plotted_var.sizes['nn_model'])
col_pos = np.arange(plotted_var.sizes['Nisf'])

fig, axs = plt.subplots(1, 1, figsize=(8.25, 8.25))
# Colour scale capped at 100 (same units as RMSE_Gt_all_isf).
im = axs.imshow(plotted_var.values, cmap=plt.cm.Reds, vmin=0, vmax=100)

# Right-hand row labels: overall RMSE of each NN model on the non-training runs.
# Tick positions are set BEFORE labels so each label is pinned to its row.
axs.yaxis.tick_right()
axs.set_yticks(row_pos)
axs.set_yticklabels(labels=np.round(RMSE_Gt_all_other.sel(nn_model=timetag_list).values, 2))

# Top column labels: ice-shelf names.
axs.xaxis.tick_top()
axs.set_xticks(col_pos)
axs.set_xticklabels(labels=isf_names.sel(Nisf=plotted_var.Nisf).values, rotation=90)

cb_ax = fig.add_axes([0.15, 0.35, 0.7, 0.02])
cbar = fig.colorbar(im, cax=cb_ax, extend='max', orientation='horizontal')
plt.tight_layout()


In [None]:
# Display the NN model time tags (also the row order of the heatmap).
timetag_list

In [None]:
# Display the NEMO runs treated as training runs.
training_runs

In [None]:
# Display the NEMO runs treated as non-training runs (used for the heatmap).
other_runs