In [None]:
"""
Created on Wed Apr 20 10:58 2022

Evaluating results computed with NN model (1D)

Author: @claraburgard

"""

In [None]:
import xarray as xr
import numpy as np
import pandas as pd
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt

READ IN DATA

In [None]:
# Absolute root directory; the per-run NEMO melt-rate paths are built from it below.
home_path = '/bettik/burgardc/'

In [None]:
# Directory from which the evaluation metrics are read further down:
# <melt-rate output root>/<timetag>/
timetag = '20220414-1706'
outputpath_melt_nn = '/bettik/burgardc/DATA/NN_PARAM/processed/MELT_RATE/'
new_path_output = f'{outputpath_melt_nn}{timetag}/'

In [None]:
# Ice-shelf masks and per-shelf information (file from the nemo_5km_OPM006 directory)
inputpath_mask = '/bettik/burgardc/SCRIPTS/basal_melt_param/data/interim/ANTARCTICA_IS_MASKS/nemo_5km_OPM006/'
file_isf_orig = xr.open_dataset(inputpath_mask + 'nemo_5km_isf_masks_and_info_and_distance_new.nc')

# Step 1: keep only the ice shelves whose 'front_bot_depth_max' is finite
has_front_depth = np.isfinite(file_isf_orig['front_bot_depth_max'])
nonnan_Nisf = file_isf_orig['Nisf'].where(has_front_depth, drop=True).astype(int)
file_isf_nonnan = file_isf_orig.sel(Nisf=nonnan_Nisf)

# Step 2: among those, keep only the shelves with 'isf_area_here' >= 2500
is_large_enough = file_isf_nonnan['isf_area_here'] >= 2500
large_isf = file_isf_nonnan['Nisf'].where(is_large_enough, drop=True)
file_isf = file_isf_nonnan.sel(Nisf=large_isf)

# Step 3: replace the name 'Ekstrom' by 'Ekström' (the spelling used in the region lookup)
name_is_unchanged = file_isf['isf_name'] != 'Ekstrom'
file_isf['isf_name'] = file_isf['isf_name'].where(name_is_unchanged, np.array('Ekström', dtype=object))
isf_names = file_isf['isf_name']

In [None]:
# Assign each retained ice shelf to one of six regions, based on its name.
region_members = {
    'East and West Ross': ['Ross','Nickerson','Sulzberger', 'Cook'],
    'Weddell': ['Filchner','Ronne'],
    'Dronning Maud Land': ['Ekström','Nivl','Prince Harald','Riiser-Larsen','Fimbul','Roi Baudouin','Lazarev','Stancomb Brunt','Jelbart','Borchgrevink'],
    'Amundsen': ['Getz','Thwaites','Crosson','Dotson','Cosgrove','Pine Island'],
    'Peninsula and Bellinghausen': ['Venable','George VI','Abbot','Stange','Larsen C','Bach','Larsen D','Wilkins'],
    'East Antarctica': ['Amery','Moscow Univ.','Tracy Tremenchus','Totten','West','Shackleton'],
}
# Inverted lookup: ice-shelf name -> region
region_by_isf_name = {
    shelf_name: region_name
    for region_name, shelf_names in region_members.items()
    for shelf_name in shelf_names
}

region_list = []
for kisf in file_isf.Nisf:
    # Extract the name once, as a plain Python string, so that the lookup and
    # the warning message do not depend on the dtype of the underlying array.
    shelf_name = str(file_isf['isf_name'].sel(Nisf=kisf).values.item())
    region_name = region_by_isf_name.get(shelf_name)
    if region_name is None:
        print('Argh, help me, '+shelf_name+' has no region assigned!')
        # Still record an entry: region_list must stay aligned with the Nisf
        # dimension, otherwise the DataArray assignment below fails on size.
        region_name = 'Unassigned'
    region_list.append(region_name)

file_isf['region'] = xr.DataArray(data=region_list,dims='Nisf')

# Order in which the regions are processed afterwards ('Unassigned' is
# deliberately absent, so shelves without a region are left out).
regions = ['Weddell','Peninsula and Bellinghausen','Amundsen','East and West Ross','East Antarctica','Dronning Maud Land']

In [None]:
# Ice-shelf IDs grouped region by region (following the order of `regions`),
# then flattened into a single 1D array.
isf_ids_per_region = [
    file_isf.Nisf.where(file_isf['region'] == region_name, drop=True).values
    for region_name in regions
]
nisf_by_reg_list = np.concatenate(isf_ids_per_region)

In [None]:
# NEMO runs to evaluate
run_list = ['OPM006','OPM016','OPM018','OPM021','OPM026','OPM027','OPM031']

# Per-run accumulators: parameterisation-minus-reference differences ...
diff_Gt_list, diff_box1_list = [], []
# ... the same for anomalies (diff_box1_anom_list is created but never filled here) ...
diff_Gt_anom_list, diff_box1_anom_list = [], []
# ... and the reference values themselves
ref_Gt_list, ref_box1_list = [], []


def _concat_over_runs(per_run_arrays, run_names):
    """Stack one array per run along a new 'nemo_run' dimension labelled with run_names."""
    stacked = xr.concat(per_run_arrays, dim='nemo_run')
    return stacked.assign_coords(nemo_run=run_names)


for nemo_run in run_list:

    outputpath_melt = home_path+'DATA/BASAL_MELT_PARAM/processed/MELT_RATE/nemo_5km_'+nemo_run+'/'

    # --- Reference: melt from the NEMO run itself ---
    NEMO_melt_rates_1D = xr.open_dataset(outputpath_melt+'melt_rates_1D_NEMO.nc')
    NEMO_box1_myr = xr.open_dataset(outputpath_melt+'melt_rates_box1_NEMO.nc')
    ref_Gt = NEMO_melt_rates_1D['melt_Gt_per_y_tot']
    ref_box1 = NEMO_box1_myr['mean_melt_box1_myr']
    ref_Gt_list.append(ref_Gt)
    ref_box1_list.append(ref_box1)

    # --- Parameterised melt (non-bootstrap evaluation file) ---
    ds_melt_param = xr.open_dataset(new_path_output+'eval_metrics_'+nemo_run+'.nc')
    param_Gt = ds_melt_param['melt_1D_Gt_per_y']
    param_box1 = ds_melt_param['melt_1D_mean_myr_box1']

    # Difference of the integrated melt, time step by time step
    diff_Gt = param_Gt - ref_Gt
    diff_Gt_list.append(diff_Gt)

    # Same difference after removing each series' own time mean
    param_Gt_anom = param_Gt - param_Gt.mean('time')
    ref_Gt_anom = ref_Gt - ref_Gt.mean('time')
    diff_Gt_anom = param_Gt_anom - ref_Gt_anom
    diff_Gt_anom_list.append(diff_Gt_anom)

    # Difference of the time-mean box-1 melt
    diff_box1 = param_box1.mean('time') - ref_box1.mean('time')
    diff_box1_list.append(diff_box1)


# Combine the runs along 'nemo_run'
diff_Gt_all = _concat_over_runs(diff_Gt_list, run_list)
diff_Gt_anom_all = _concat_over_runs(diff_Gt_anom_list, run_list)
diff_box1_all = _concat_over_runs(diff_box1_list, run_list)
ref_Gt_all = _concat_over_runs(ref_Gt_list, run_list)
ref_box1_all = _concat_over_runs(ref_box1_list, run_list)

In [None]:
# Root-mean-square error for each run:
#  - integrated melt: averaged over time and ice shelves
#  - box-1 melt (already a time mean): averaged over ice shelves
squared_err_Gt = diff_Gt_all**2
squared_err_box1 = diff_box1_all**2
RMSE_Gt_all = np.sqrt(squared_err_Gt.mean(dim=['time','Nisf']))
RMSE_box1_all = np.sqrt(squared_err_box1.mean(dim=['Nisf']))

In [None]:
# Reference values averaged over time and ice shelves (one value per run)
averaging_dims = ['time','Nisf']
mean_Gt = ref_Gt_all.mean(dim=averaging_dims)
mean_box1 = ref_box1_all.mean(dim=averaging_dims)

In [None]:
# RMSE of every NEMO run: integrated melt (left panel) and box-1 melt (right panel)
fig, axs = plt.subplots(1, 2,figsize=(8.24*1.25/1.5,8.24*1.25/2),sharey=True)

# One fixed colour per NEMO run
run_colors = {
    'OPM006': 'magenta',
    'OPM016': 'orange',
    'OPM018': 'brown',
    'OPM021': 'red',
    'OPM026': 'yellowgreen',
    'OPM027': 'deepskyblue',
    'OPM031': 'blue',
}

for mparam in [timetag]:
    for nrun in run_list:

        # A run without a dedicated colour is drawn in grey instead of silently
        # reusing the colour of the previous run (or failing on the first one).
        ccolor = run_colors.get(nrun, 'grey')

        rmse_Gt = RMSE_Gt_all.sel(nemo_run=nrun)
        rmse_box1 = RMSE_box1_all.sel(nemo_run=nrun)
        if 'profile_domain' in RMSE_Gt_all.coords:
            # When a 'profile_domain' coordinate is present, plot the entry labelled 50
            rmse_Gt = rmse_Gt.sel(profile_domain=50)
            rmse_box1 = rmse_box1.sel(profile_domain=50)

        axs[0].scatter(rmse_Gt,mparam,marker='o',c=ccolor)
        axs[1].scatter(rmse_box1,mparam,marker='o',c=ccolor)

# Units inferred from the variable names ('..._Gt_per_y', '..._myr') - confirm
axs[0].set_xlabel('RMSE integrated melt [Gt/yr]')
axs[1].set_xlabel('RMSE box 1 mean melt [m/yr]')

# Hide the top and right spines of both panels. This is what the default
# sns.despine() call did, but seaborn is not imported in this notebook, so the
# call raised a NameError on a fresh kernel.
for ax in axs:
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)

In [None]:
def _pooled_rmse(diff, runs, dims):
    """RMSE of `diff` restricted to `runs`, with the squared error averaged over `dims`."""
    return np.sqrt((diff**2).sel(nemo_run=runs).mean(dims))


# NOTE(review): 'OPM031' is listed in both groups below - confirm that this overlap is intended.
training_runs = ['OPM006', 'OPM016', 'OPM018', 'OPM031']
other_runs = ['OPM021', 'OPM026', 'OPM027','OPM031']

# RMSE pooled over all runs of a group (the run dimension is averaged as well)
RMSE_Gt_all_train = _pooled_rmse(diff_Gt_all, training_runs, ['time','Nisf','nemo_run'])
RMSE_box1_all_train = _pooled_rmse(diff_box1_all, training_runs, ['Nisf','nemo_run'])
RMSE_Gt_all_other = _pooled_rmse(diff_Gt_all, other_runs, ['time','Nisf','nemo_run'])
RMSE_box1_all_other = _pooled_rmse(diff_box1_all, other_runs, ['Nisf','nemo_run'])

In [None]:
# Pooled RMSE of the 'training' and 'other' run groups:
# integrated melt (left panel) and box-1 melt (right panel)
fig, axs = plt.subplots(1, 2,figsize=(8.24*1.25/1.5,8.24*1.25/2),sharey=True)

# Explicit marker colour. The cell used to read `ccolor`, a variable left over
# from the loop of the per-run figure (its last value there is 'blue'), so it
# only worked if that cell had been executed beforehand.
summary_color = 'blue'

pooled_rmse_by_group = {
    'training': (RMSE_Gt_all_train, RMSE_box1_all_train),
    'other': (RMSE_Gt_all_other, RMSE_box1_all_other),
}

for mparam, (rmse_Gt, rmse_box1) in pooled_rmse_by_group.items():
    axs[0].scatter(rmse_Gt,mparam,marker='o',c=summary_color)
    axs[1].scatter(rmse_box1,mparam,marker='o',c=summary_color)

# Units inferred from the variable names ('..._Gt_per_y', '..._myr') - confirm
axs[0].set_xlabel('RMSE integrated melt [Gt/yr]')
axs[1].set_xlabel('RMSE box 1 mean melt [m/yr]')

# Hide the top and right spines of both panels. This is what the default
# sns.despine() call did, but seaborn is not imported in this notebook, so the
# call raised a NameError on a fresh kernel.
for ax in axs:
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)