In [None]:
"""
Created on Wed Apr 06 13:54 2022

Evaluate model coming out of 17 inputs on one other run

Author: @claraburgard

"""

In [None]:
import glob
import time

import numpy as np
import pandas as pd
import xarray as xr
import matplotlib as mpl
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm.notebook import trange, tqdm

import tensorflow as tf
from tensorflow import keras

import cartopy
import cartopy.crs as ccrs

from basal_melt_neural_networks.constants import *
import basal_melt_neural_networks.diagnostic_functions as diag
import basal_melt_neural_networks.data_formatting as dfmt

In [None]:
# Open figures in separate interactive Qt windows (figures stay open between cells).
# NOTE(review): requires a Qt-capable display; use `%matplotlib inline` for headless/shared runs.
%matplotlib qt5

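READ IN DATA: ice-shelf masks and info, T/S profiles, geometry (bathymetry, ice draft, concentration, slope, orientation), tidal velocity and reference NEMO melt rates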

In [None]:
# NEMO run to evaluate; every run-specific input/output directory below is derived from it.
nemo_run = 'OPM021'

# Interim inputs produced by the basal_melt_param project
inputpath_data = f'/bettik/burgardc/SCRIPTS/basal_melt_param/data/interim/NEMO_eORCA025.L121_{nemo_run}_ANT_STEREO/'
inputpath_mask = f'/bettik/burgardc/SCRIPTS/basal_melt_param/data/interim/ANTARCTICA_IS_MASKS/nemo_5km_{nemo_run}/'
inputpath_profiles = f'/bettik/burgardc/SCRIPTS/basal_melt_param/data/interim/T_S_PROF/nemo_5km_{nemo_run}/'
inputpath_plumes = f'/bettik/burgardc/SCRIPTS/basal_melt_param/data/interim/PLUMES/nemo_5km_{nemo_run}/'
inputpath_boxes = f'/bettik/burgardc/SCRIPTS/basal_melt_param/data/interim/BOXES/nemo_5km_{nemo_run}/'

# Processed reference melt rates for the same run
outputpath_melt = f'/bettik/burgardc/SCRIPTS/basal_melt_param/data/processed/MELT_RATE/nemo_5km_{nemo_run}/'

# Run-independent locations: trained networks, documentation, tides
outputpath_nn = '/bettik/burgardc/SCRIPTS/basal_melt_neural_networks/data/interim/'
outputpath_doc = '/bettik/burgardc/SCRIPTS/basal_melt_neural_networks/custom_doc/'
inputpath_tides = '/bettik/burgardc/DATA/BASAL_MELT_PARAM/interim/TIDES/'

In [None]:
T_S_2D_isfdraft = xr.open_mfdataset(inputpath_profiles+'T_S_2D_fields_isf_draft.nc').sel(profile_domain=50)

In [None]:
# dIF, dGL, longitude, latitude
file_isf_orig = xr.open_dataset(inputpath_mask+'nemo_5km_isf_masks_and_info_and_distance_new.nc')
nonnan_Nisf = file_isf_orig['Nisf'].where(np.isfinite(file_isf_orig['front_bot_depth_max']), drop=True).astype(int)
file_isf_nonnan = file_isf_orig.sel(Nisf=nonnan_Nisf)
large_isf = file_isf_nonnan['Nisf'].where(file_isf_nonnan['isf_area_here'] >= 2500, drop=True)
file_isf = file_isf_nonnan.sel(Nisf=large_isf)

In [None]:
# T and S profiles
file_TS_orig = xr.open_dataset(inputpath_profiles+'T_S_mean_prof_corrected_km_contshelf_and_offshore_1980-2018.nc')
file_TS = file_TS_orig.sel(Nisf=file_isf.Nisf)
file_TS_dom = file_TS.sel(profile_domain=50)

In [None]:
plume_charac = xr.open_dataset(inputpath_plumes+'nemo_5km_plume_characteristics.nc')

In [None]:
map_lim = [-3000000,3000000]
file_mask_orig = xr.open_dataset(inputpath_data+'other_mask_vars_Ant_stereo.nc')
file_mask_orig_cut = dfmt.cut_domain_stereo(file_mask_orig, map_lim, map_lim)
file_other = xr.open_dataset(inputpath_data+'corrected_draft_bathy_isf.nc')#, chunks={'x': chunk_size, 'y': chunk_size})
file_other_cut = dfmt.cut_domain_stereo(file_other, map_lim, map_lim)
file_conc = xr.open_dataset(inputpath_data+'isfdraft_conc_Ant_stereo.nc')
file_conc_cut = dfmt.cut_domain_stereo(file_conc, map_lim, map_lim)

In [None]:
# bathymetry, ice draft, concentration
file_bed_orig = file_mask_orig_cut['bathy_metry']
file_draft = file_other_cut['corrected_isfdraft'] 
file_isf_conc = file_conc_cut['isfdraft_conc']

In [None]:
file_slope = xr.open_dataset(inputpath_mask+'nemo_5km_slope_info_bedrock_draft.nc')
file_orientation = xr.open_dataset(inputpath_mask+'nemo_5km_orientation_info.nc')

In [None]:
utide_file = xr.open_dataset(inputpath_tides + 'tidal_velocity_nemo_Ant_stereo.nc')
u_tide = dfmt.cut_domain_stereo(utide_file['ttv'], map_lim, map_lim)

In [None]:
NEMO_melt_rates_2D = xr.open_mfdataset(outputpath_melt+'melt_rates_2D_NEMO.nc')
melt_rate = NEMO_melt_rates_2D['melt_m_ice_per_y']

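MERGE ALL INFO: combine geometry, T/S at the ice draft and the reference melt into one dataset, flatten it to a table, drop points with missing values, then normalise the inputs and run the network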

In [None]:
geometry_2D = file_isf[['dGL', 'dIF']].merge(file_draft).merge(file_bed_orig).merge(file_slope).merge(file_orientation).merge(u_tide)
geometry_2D['dIF'] = geometry_2D['dIF'].where(np.isfinite(geometry_2D['dIF']), np.nan)
time_dpdt_in = T_S_2D_isfdraft[['theta_in','salinity_in']].merge(melt_rate)

In [None]:
geometry_2D_br, time_dpdt_in_br = xr.broadcast(geometry_2D,time_dpdt_in)

In [None]:
final_input_xr = xr.merge([geometry_2D_br, time_dpdt_in_br]).transpose('y','x','time').drop('profile_domain')

In [None]:
final_input_xr

In [None]:
merged_df = final_input_xr.drop('longitude').drop('latitude').to_dataframe()

In [None]:
# remove rows where there are nans
clean_df_yy = merged_df.dropna()

In [None]:
clean_df_yy.shape

In [None]:
timetag = '20220407-1601'
normalisation_coeff = pd.read_csv(outputpath_nn + 'dataframe_norm_training_data_'+timetag+'.csv').set_index('Unnamed: 0')#.drop('Unnamed: 0', 1)
normalisation_coeff.index.name = None
normalisation_coeff_input = normalisation_coeff.drop(['melt_m_ice_per_y'], axis=1)

In [None]:
input_var = clean_df_yy.drop(['melt_m_ice_per_y'], axis=1)
ref_melt = clean_df_yy['melt_m_ice_per_y']

In [None]:
normalisation_coeff_input.loc['x_mean']

In [None]:
normalised_input_var = (input_var - normalisation_coeff_input.loc['x_mean'])/normalisation_coeff_input.loc['x_range']

In [None]:
x_val_arr = np.array(normalised_input_var)
y_val_arr = np.array(ref_melt)

In [None]:
model = keras.models.load_model(outputpath_nn + 'model_nn_'+timetag+'.h5')

In [None]:
y_out_norm = model.predict(x_val_arr)

In [None]:
y_out = (y_out_norm * normalisation_coeff['melt_m_ice_per_y'].loc['x_range']) + normalisation_coeff['melt_m_ice_per_y'].loc['x_mean']

In [None]:
y_out_pd_s = pd.Series(y_out[:,0],index=clean_df_yy.index,name='predicted_melt') 
y_target_pd_s = pd.Series(y_val_arr,index=clean_df_yy.index,name='reference_melt') 

In [None]:
y_out_xr = y_out_pd_s.to_xarray()
y_target_xr = y_target_pd_s.to_xarray()
y_to_compare = xr.merge([y_out_xr.T, y_target_xr.T]).sortby('y')

In [None]:
xx = range(0,80)
plt.figure()
plt.scatter(y_to_compare['predicted_melt'].values.flatten(),y_to_compare['reference_melt'].values.flatten(), s=10, edgecolors='None',alpha=0.01)
plt.plot(xx,xx,'k')

In [None]:
computed_melt = y_to_compare['predicted_melt']#.isel(time=0)
ref_melt = y_to_compare['reference_melt']#.isel(time=0)

In [None]:
min_m = min(computed_melt.min(), ref_melt.min())
max_m = max(computed_melt.max(), ref_melt.max())
lim = max(abs(min_m),abs(max_m))

if min_m < 0:
    cmap = mpl.cm.coolwarm
    minlim = -lim
else:
    cmap = mpl.cm.viridis
    minlim = 0

f = plt.figure(figsize=(15, 5))

ax1 = plt.subplot(1, 3, 1)
computed_melt.plot(ax=ax1, vmin=minlim,vmax=lim, cmap=cmap)
ax1.set_title('Neural Network [m ice/y]')

ax2 = plt.subplot(1, 3, 2, sharex = ax1, sharey = ax1)
ref_melt.plot(ax=ax2, vmin=minlim,vmax=lim, cmap=cmap)
ax2.set_title('Reference [m ice/y]')

ax3 = plt.subplot(1, 3, 3, sharex = ax1, sharey = ax1)
(computed_melt - ref_melt).plot(ax=ax3)
ax3.set_xticklabels('')
ax3.set_yticklabels('')
ax3.set_title('NN - Ref [m ice/y]')

f.tight_layout()

In [None]:
isf_stack_mask = dfmt.create_stacked_mask(file_isf['ISF_mask'], file_isf.Nisf, ['y','x'], 'mask_coord')

In [None]:
verbose=True
nisf_list = file_isf.Nisf

xx = file_isf.x
yy = file_isf.y
dx = (xx[2] - xx[1]).values
dy = (yy[2] - yy[1]).values
grid_cell_area = abs(dx*dy)  
grid_cell_area_weighted = file_isf_conc * grid_cell_area

In [None]:
def compute_Gt_per_y_from_m_per_y(nisf_list, melt_m_per_y, isf_stack_mask,
                                  grid_cell_area_weighted, verbose):
    """Integrate a melt-rate field (m ice per year) to Gt per year for each ice shelf.

    For every shelf ID in ``nisf_list``, the melt field and the cell areas are restricted
    to that shelf with ``dfmt.choose_isf``. Their product is summed over ``mask_coord`` and
    multiplied by ``rho_i / 10**12`` (``rho_i`` presumably being the ice density in
    kg m^-3 — confirm in ``basal_melt_neural_networks.constants``).

    Parameters
    ----------
    nisf_list : iterable
        Ice-shelf IDs (``Nisf`` values) to process, in output order.
    melt_m_per_y : xarray.DataArray
        Melt rate in m of ice per year.
    isf_stack_mask : object
        Stacked mask passed unchanged to ``dfmt.choose_isf``.
    grid_cell_area_weighted : xarray.DataArray
        Area of each grid cell (here weighted by the ice-draft concentration).
    verbose : bool
        If True, print start/end messages with the elapsed time and show a tqdm progress bar.

    Returns
    -------
    xarray.Dataset
        Variable ``melt_1D_Gt_per_y``, the per-shelf results concatenated along ``Nisf``.
    """
    if verbose:
        t0 = time.time()
        print('WELCOME! AS YOU WISH, I WILL CONVERT M ICE PER Y TO GT PER Y')
    shelves = tqdm(nisf_list) if verbose else nisf_list

    per_shelf = []
    for shelf in shelves:
        area_isf = dfmt.choose_isf(grid_cell_area_weighted, isf_stack_mask, shelf)
        melt_isf = dfmt.choose_isf(melt_m_per_y, isf_stack_mask, shelf)
        # m/y * area * density, scaled by 1e-12 to reach Gt/y
        mass_flux = (melt_isf * area_isf).sum(dim=['mask_coord']) * rho_i / 10**12
        per_shelf.append(mass_flux)

    out_1D = xr.concat(per_shelf, dim='Nisf').to_dataset(name='melt_1D_Gt_per_y')

    if verbose:
        elapsed = time.time() - t0
        print("I AM DONE! IT TOOK: " + str(round(elapsed, 2)) + " seconds.")

    return out_1D

In [None]:
computed_melt_Gt_per_y = compute_Gt_per_y_from_m_per_y(nisf_list, computed_melt, isf_stack_mask, 
                                  grid_cell_area_weighted, verbose)

In [None]:
ref_melt_Gt_per_y = compute_Gt_per_y_from_m_per_y(nisf_list, ref_melt, isf_stack_mask, 
                                  grid_cell_area_weighted, verbose)

In [None]:
computed_melt_Gt_per_y['melt_1D_Gt_per_y'].plot()

In [None]:
ref_melt_Gt_per_y['melt_1D_Gt_per_y'].plot()

In [None]:
# Ice-shelf name -> region (6 regions); used to group and colour the per-shelf scatter plots
isf_names_per_region = {
    'East and West Ross': ['Ross','Nickerson','Sulzberger', 'Cook'],
    'Weddell': ['Filchner','Ronne'],
    'Dronning Maud Land': ['Ekstrom','Nivl','Prince Harald','Riiser-Larsen','Fimbul','Roi Baudouin','Lazarev','Stancomb Brunt','Jelbart','Borchgrevink'],
    'Amundsen': ['Getz','Thwaites','Crosson','Dotson','Cosgrove','Pine Island'],
    'Peninsula and Bellinghausen': ['Venable','George VI','Abbot','Stange','Larsen C','Bach','Larsen D','Wilkins'],
    'East Antarctica': ['Amery','Moscow Univ.','Tracy Tremenchus','Totten','West','Shackleton'],
}
region_of_isf = {name: region for region, names in isf_names_per_region.items() for name in names}

region_list = []
for kisf in file_isf.Nisf:
    # str() of the 0-d array gives a plain Python string (safe to concatenate below)
    isf_name = str(file_isf['isf_name'].sel(Nisf=kisf).values)
    if isf_name not in region_of_isf:
        # Warn, but still append one entry per Nisf so the DataArray below stays aligned;
        # such shelves belong to no plotted region.
        print('Argh, help me, '+isf_name+' has no region assigned!')
    region_list.append(region_of_isf.get(isf_name, 'Unassigned'))
file_isf['region'] = xr.DataArray(data=region_list,dims='Nisf')

In [None]:
def plot_scatter_all_isf(param_melt_tuned, target_Gt_yr, file_isf, pad=5):
    """Scatter one melt estimate against a target, one panel per ice shelf.

    Ice shelves are grouped by ``file_isf['region']``. Each region gets its own colour,
    and the shelves within a region cycle through marker symbols. Each panel shows the
    points of one shelf plus a black 1:1 line.

    Parameters
    ----------
    param_melt_tuned : xarray.DataArray
        Values for the x-axis, selectable by ``Nisf`` (here the neural-network melt).
    target_Gt_yr : xarray.DataArray
        Values for the y-axis, selectable by ``Nisf`` (here the reference melt).
    file_isf : xarray.Dataset
        Ice-shelf information; must provide ``region`` and ``isf_name`` along ``Nisf``.
    pad : float, optional
        Margin added on both sides of each panel's x and y limits (data units). Default 5.

    Returns
    -------
    matplotlib.figure.Figure
        The figure holding all panels.
    """
    regions = ['Weddell','Peninsula and Bellinghausen','Amundsen','East and West Ross','East Antarctica','Dronning Maud Land']
    colors = ['deepskyblue','brown','red','orange','limegreen','seagreen']
    symbol = ['o','v','>','p','*','s','<','^','X','d']

    # Shelves of each region, in the order of `regions`
    subsets = [file_isf.Nisf.where(file_isf['region'] == reg, drop=True) for reg in regions]

    # Keep the 6 x 6 grid, adding rows only if there are more than 36 shelves
    # (add_subplot fails beyond the last grid cell)
    ncols = 6
    n_panels = sum(subset.size for subset in subsets)
    nrows = max(6, int(np.ceil(n_panels / ncols)))

    f = plt.figure()
    f.set_size_inches(8.25*1.5, 8.25*1.5 * nrows / 6)

    ax = {}
    i = 0
    for rr, subset_isf in enumerate(subsets):
        marker_color = colors[rr]
        for k, kisf in enumerate(tqdm(subset_isf.Nisf)):
            # cycle through the symbols so a large region cannot run out of markers
            marker_type = symbol[k % len(symbol)]

            x_axis = param_melt_tuned.sel(Nisf=kisf)
            y_axis = target_Gt_yr.sel(Nisf=kisf)

            ax[i] = f.add_subplot(nrows, ncols, i+1)
            ax[i].scatter(x_axis, y_axis,
                          s=10, c=marker_color, marker=marker_type, edgecolors='None',
                          rasterized=True)

            min_xy = min(float(x_axis.min()), float(y_axis.min()))
            max_xy = max(float(x_axis.max()), float(y_axis.max()))
            # all-NaN data would make set_xlim raise; leave default limits in that case
            if np.isfinite(min_xy) and np.isfinite(max_xy):
                ax[i].set_xlim(min_xy - pad, max_xy + pad)
                ax[i].set_ylim(min_xy - pad, max_xy + pad)
                # Two-point 1:1 line over the full data range; np.arange(min, max) with a unit
                # step stopped short of max and drew nothing for ranges below 1
                ax[i].plot([min_xy, max_xy], [min_xy, max_xy], 'k-')

            ax[i].set_title(str(file_isf['isf_name'].sel(Nisf=kisf).values))
            i += 1

    f.tight_layout()
    sns.despine(fig=f)

    return f

In [None]:
# One panel per ice shelf: neural-network (x) vs reference (y) melt [Gt/y]
plot_scatter_all_isf(
    computed_melt_Gt_per_y['melt_1D_Gt_per_y'],
    ref_melt_Gt_per_y['melt_1D_Gt_per_y'],
    file_isf,
)