In [None]:
"""
Created on Mon Mar 27 17:37 2023

Apply the ensemble of NNs to the Smith data and keep only the ensemble mean

Author: @claraburgard

"""

In [None]:
import datetime
import glob
import os,sys
import time
from contextlib import redirect_stdout

import matplotlib as mpl
import numpy as np
import pandas as pd
import seaborn as sns
import xarray as xr
from tqdm.notebook import trange, tqdm
#from tqdm import tqdm

import tensorflow as tf
from tensorflow import keras

from multimelt.constants import *
import summer_paper.data_formatting_NN as dfmt
import summer_paper.postprocessing_functions_NN as pp
import summer_paper.useful_functions as uf

DEFINE OPTIONS

In [None]:
# Model size used in the model file names: 'mini', 'small', 'medium', 'large' or 'extra_large'
mod_size = 'small'
# T/S input option: 'extrap', 'whole' or 'thermocline'
TS_opt = 'extrap'
# Normalisation method: 'std', 'interquart' or 'minmax'
norm_method = 'std'
# Experiment name (input-variable set). Earlier choices: 'allbutconstants', 'onlyTSdraftandslope',
# 'TSdraftbotandiceddandwcd', 'onlyTSisfdraft', 'TSdraftbotandiceddandwcdreldGL', 'TSdraftslopereldGL'
exp_name = 'newbasic2'

In [None]:
# Interim-data directory; the legacy section below reads 'info_chunks.txt' from here
outputpath_info = '/bettik/burgardc/DATA/SUMMER_PAPER/interim/'

In [None]:
# NEMO run to process (other options: 'EPM031', 'EPM034')
nemo_run = 'OPM031'

In [None]:
home_path = '/bettik/burgardc/'

# The ctrl94/isf94/isfru94 runs are stored under DATA/SUMMER_PAPER; every other run
# (eORCA025.L121) is stored under the basal_melt_param scripts tree, which also
# provides the T/S profiles.
if nemo_run in ['ctrl94','isf94','isfru94']:
    interim_root = '/bettik/burgardc/DATA/SUMMER_PAPER/interim/'
    inputpath_data = interim_root + 'NEMO_' + nemo_run + '_ANT_STEREO/'
else:
    interim_root = '/bettik/burgardc/SCRIPTS/basal_melt_param/data/interim/'
    inputpath_data = interim_root + 'NEMO_eORCA025.L121_' + nemo_run + '_ANT_STEREO/'
    inputpath_profiles = interim_root + 'T_S_PROF/nemo_5km_' + nemo_run + '/'
inputpath_mask = interim_root + 'ANTARCTICA_IS_MASKS/nemo_5km_' + nemo_run + '/'
inputpath_plumes = interim_root + 'PLUMES/nemo_5km_' + nemo_run + '/'
inputpath_boxes = interim_root + 'BOXES/nemo_5km_' + nemo_run + '/'

# Pre-computed NN input chunks (also holds the normalisation metrics file)
inputpath_CVinput = home_path + '/DATA/NN_PARAM/interim/INPUT_DATA/EXTRAPOLATED_ISFDRAFT_CHUNKS/'


In [None]:
# Load the geometry inputs (masks, draft/bathymetry, ice-draft concentration, boxes,
# plumes, slopes) and the normalisation metrics used to feed and evaluate the NN.
# make the domain a little smaller to make the computation even more efficient - file isf has already been made smaller at its creation
map_lim = [-3000000,3000000]
file_mask_orig = xr.open_dataset(inputpath_data+'other_mask_vars_Ant_stereo.nc')
file_mask_orig_cut = dfmt.cut_domain_stereo(file_mask_orig, map_lim, map_lim)
file_other = xr.open_dataset(inputpath_data+'corrected_draft_bathy_isf.nc')#, chunks={'x': chunk_size, 'y': chunk_size})
file_other_cut = dfmt.cut_domain_stereo(file_other, map_lim, map_lim)
file_conc = xr.open_dataset(inputpath_data+'isfdraft_conc_Ant_stereo.nc')
file_conc_cut = dfmt.cut_domain_stereo(file_conc, map_lim, map_lim)

# Ice-shelf masks/info: the file name differs between the two run families.
if nemo_run in ['ctrl94','isf94','isfru94']:
    file_isf_orig = xr.open_dataset(inputpath_mask+'nemo_5km_isf_masks_and_info_and_distance_oneFRIS.nc')
else:
    file_isf_orig = xr.open_dataset(inputpath_mask+'nemo_5km_isf_masks_and_info_and_distance_new_oneFRIS.nc')
# Keep only ice shelves with a finite front_bot_depth_max ...
nonnan_Nisf = file_isf_orig['Nisf'].where(np.isfinite(file_isf_orig['front_bot_depth_max']), drop=True).astype(int)
file_isf_nonnan = file_isf_orig.sel(Nisf=nonnan_Nisf)
# ... and, among those, only the ones with isf_area_here >= 2500 (units not shown here - presumably km^2, TODO confirm)
large_isf = file_isf_nonnan['Nisf'].where(file_isf_nonnan['isf_area_here'] >= 2500, drop=True)
file_isf = file_isf_nonnan.sel(Nisf=large_isf)

# bathymetry, ice draft, concentration
file_bed_orig = file_mask_orig_cut['bathy_metry']
file_bed_corr = file_other_cut['corrected_isf_bathy']
file_draft = file_other_cut['corrected_isfdraft'] 
# Keep the original bathymetry where the draft is smaller than it; elsewhere use the corrected bathymetry.
file_bed_goodGL = file_bed_orig.where(file_draft < file_bed_orig,file_bed_corr)
file_isf_conc = file_conc_cut['isfdraft_conc']

lon = file_isf.longitude
lat = file_isf.latitude

# Grid spacing from two neighbouring coordinates (assumes a regular grid).
xx = file_isf.x
yy = file_isf.y
dx = (xx[2] - xx[1]).values
dy = (yy[2] - yy[1]).values
grid_cell_area_const = abs(dx*dy)  
# Cell area scaled by the ice-draft concentration of each cell.
grid_cell_area_weighted = file_isf_conc * grid_cell_area_const

ice_draft_pos = file_draft
ice_draft_neg = -ice_draft_pos

# Stacked (y, x) -> 'mask_coord' representation of the ice-shelf masks, used by choose_isf below.
isf_stack_mask = dfmt.create_stacked_mask(file_isf['ISF_mask'], file_isf.Nisf, ['y','x'], 'mask_coord')

box_charac_all_2D = xr.open_dataset(inputpath_boxes + 'nemo_5km_boxes_2D_oneFRIS.nc')
box_charac_all_1D = xr.open_dataset(inputpath_boxes + 'nemo_5km_boxes_1D_oneFRIS.nc')
plume_charac = xr.open_dataset(inputpath_plumes+'nemo_5km_plume_characteristics_oneFRIS_corrected.nc').squeeze().drop('Nisf')

param_var_of_int_2D = file_isf[['ISF_mask', 'latitude', 'longitude', 'dGL']]
param_var_of_int_1D = file_isf[['front_bot_depth_avg', 'front_bot_depth_max','isf_name']]

# All 2D geometry needed for the evaluation metrics in one dataset.
geometry_info_2D = plume_charac.merge(param_var_of_int_2D).merge(ice_draft_pos.rename('ice_draft_pos')).merge(grid_cell_area_weighted.rename('grid_cell_area_weighted')).merge(file_isf_conc.rename('isfdraft_conc'))
geometry_info_1D = param_var_of_int_1D

# Normalisation metrics as a DataFrame; rows 'mean_vars'/'range_vars' are read by the NN application function.
norm_metrics_file = xr.open_dataset(inputpath_CVinput + 'metrics_norm_wholedataset_origexcept26_christoph_new.nc')
norm_metrics = norm_metrics_file.to_dataframe()

# Box locations of box configuration 2, masked to box 1.
# NOTE(review): isel(Nisf=1) takes the second Nisf entry - confirm this is the intended selection.
box_loc_config2 = box_charac_all_2D['box_location'].sel(box_nb_tot=box_charac_all_1D['nD_config'].sel(config=2))
box1 = box_loc_config2.where(box_loc_config2==1).isel(Nisf=1).drop('Nisf')

file_slope = xr.open_dataset(inputpath_mask+'nemo_5km_slope_info_bedrock_draft_latlon_oneFRIS.nc').squeeze().drop('Nisf')

In [None]:
### APPLY MODEL

# Model input columns. Order matters: it fixes the column order handed to the network.
geometry_input_vars = ['dGL', 'dIF', 'corrected_isfdraft', 'bathy_metry',
                       'slope_bed_lon', 'slope_bed_lat', 'slope_ice_lon', 'slope_ice_lat']
ocean_input_vars = ['theta_in', 'salinity_in', 'T_mean', 'S_mean', 'T_std', 'S_std']
input_vars = geometry_input_vars + ocean_input_vars

# 'new' or 'old': which set of trained models to load
tuning_sort = 'new'

### use any model from CV over time
if tuning_sort == 'old':
    path_model = '/bettik/burgardc/DATA/NN_PARAM/interim/NN_MODELS/experiments/WHOLE/'
elif tuning_sort == 'new':
    path_model = '/bettik/burgardc/DATA/SUMMER_PAPER/interim/NN_MODELS/'


### APPLY MODEL

def apply_NN_results_2D_allisf_allyears(file_isf, norm_metrics, all_info_all, model, input_vars=None,
                                        batch_size=4096, verbose=1):
    """Apply a trained NN to every row of ``all_info_all`` and return the denormalised 2D melt field.

    Parameters
    ----------
    file_isf : xr.Dataset
        Ice-shelf information; the output is reindexed like ``file_isf['ISF_mask']``.
    norm_metrics : pd.DataFrame
        Normalisation metrics with rows 'mean_vars' and 'range_vars' and one column per
        variable, including every entry of ``input_vars`` and 'melt_m_ice_per_y'.
    all_info_all : xr.Dataset
        Model inputs; converted with ``to_dataframe()`` so that each row is one sample.
    model : keras model
        Trained network exposing ``predict``.
    input_vars : sequence of str
        Input column names, in the order the network expects them. Required.
    batch_size : int, optional
        Batch size passed to ``model.predict`` (default 4096).
    verbose : int, optional
        Verbosity passed to ``model.predict`` (default 1).

    Returns
    -------
    xr.DataArray
        Predicted melt rate, denormalised with the 'melt_m_ice_per_y' metrics and
        reindexed onto the ``file_isf['ISF_mask']`` grid.

    Raises
    ------
    ValueError
        If ``input_vars`` is missing or empty. The previous default of ``[]`` selected no
        columns and only failed later, inside the model.
    """
    if input_vars is None or len(input_vars) == 0:
        raise ValueError('input_vars must list the input columns expected by the model')
    # A list is required: pandas would treat a tuple as a single (MultiIndex) key.
    input_vars = list(input_vars)

    all_info_df = all_info_all.to_dataframe()
    x_val_norm = pp.normalise_vars(all_info_df[input_vars],
                                   norm_metrics[input_vars].loc['mean_vars'],
                                   norm_metrics[input_vars].loc['range_vars'])

    y_out_norm = model.predict(x_val_norm.values.astype('float64'), batch_size=batch_size, verbose=verbose)

    # Re-attach the dataframe index to the flat predictions, then restore the 2D layout.
    y_out_norm_xr = xr.DataArray(data=y_out_norm.squeeze()).rename({'dim_0': 'index'})
    y_out_norm_xr = y_out_norm_xr.assign_coords({'index': x_val_norm.index})
    y_out_norm_xr_2D = uf.bring_back_to_2D(y_out_norm_xr)

    # denormalise the output
    y_out_xr = pp.denormalise_vars(y_out_norm_xr_2D,
                                   norm_metrics['melt_m_ice_per_y'].loc['mean_vars'],
                                   norm_metrics['melt_m_ice_per_y'].loc['range_vars'])

    return y_out_xr.reindex_like(file_isf['ISF_mask'])

# Profile input and melt output locations, plus the T/S profile file, depend on the run family.
if nemo_run in ['ctrl94','isf94','isfru94']:
    inputpath_profiles = '/bettik/burgardc/DATA/SUMMER_PAPER/interim/T_S_PROF/nemo_5km_'+nemo_run+'/'
    outputpath_melt = '/bettik/burgardc/DATA/SUMMER_PAPER/processed/OCEAN_MELT_RATE_TUNING/nemo_5km_'+nemo_run+'/'
    file_TS_orig = xr.open_dataset(inputpath_profiles+'T_S_mean_prof_corrected_km_contshelf_allyy.nc')
else:
    inputpath_profiles = '/bettik/burgardc/SCRIPTS/basal_melt_param/data/interim/T_S_PROF/nemo_5km_'+nemo_run+'/'
    outputpath_melt = '/bettik/burgardc/SCRIPTS/basal_melt_param/data/processed/MELT_RATE/nemo_5km_'+nemo_run+'/'
    # Keep the large ice shelves and the profile_domain == 50 profiles only.
    file_TS_orig = (xr.open_dataset(inputpath_profiles+'T_S_mean_prof_corrected_km_contshelf_and_offshore_1980-2018_oneFRIS.nc')
                    .sel(Nisf=large_isf)
                    .sel(profile_domain=50)
                    .squeeze()
                    .drop('profile_domain'))

file_TS = file_TS_orig
# Forward-fill NaNs along 'depth', then restrict to the ice shelves kept in file_isf.
filled_TS = file_TS.ffill(dim='depth').sel(Nisf=file_isf.Nisf)

# Build the NN input dataset all_info_all: for each ice shelf, the geometry fields plus
# the local T/S at the (clamped) ice draft and the cavity-weighted T/S mean and std,
# broadcast against each other and accumulated with combine_first.
print('Prepare input variables')
n = 0
for kisf in tqdm(file_isf.Nisf):
    depth_kisf = uf.choose_isf(ice_draft_pos,isf_stack_mask, kisf)
    # Clamp the draft to [front_ice_depth_min, front_bot_depth_max] before interpolating the profiles:
    # values not below the max become the max; where depth_kisf is not above the min, the min is used.
    depth_of_int0 = depth_kisf.where(depth_kisf < file_isf['front_bot_depth_max'].sel(Nisf=kisf), 
                                   file_isf['front_bot_depth_max'].sel(Nisf=kisf))
    depth_of_int = depth_of_int0.where(depth_kisf > file_isf['front_ice_depth_min'].sel(Nisf=kisf), 
                                   file_isf['front_ice_depth_min'].sel(Nisf=kisf))

    # Ocean temperature/salinity profiles of this ice shelf, interpolated to the clamped draft.
    T_isf = filled_TS['theta_ocean'].sel(Nisf=kisf).interp({'depth': depth_of_int}).drop('depth')
    S_isf = filled_TS['salinity_ocean'].sel(Nisf=kisf).interp({'depth': depth_of_int}).drop('depth')

    isf_conc_kisf = uf.choose_isf(file_isf_conc,isf_stack_mask, kisf) 

    # Cavity-wide mean and std of T and S over mask_coord, weighted by the ice-draft concentration.
    T_mean_cav = uf.weighted_mean(T_isf, 'mask_coord', isf_conc_kisf).to_dataset(name='T_mean')
    S_mean_cav = uf.weighted_mean(S_isf, 'mask_coord', isf_conc_kisf).to_dataset(name='S_mean')
    T_std_cav = uf.weighted_std(T_isf, 'mask_coord', isf_conc_kisf).to_dataset(name='T_std')
    S_std_cav = uf.weighted_std(S_isf, 'mask_coord', isf_conc_kisf).to_dataset(name='S_std')
    T_S_2D_meanstd_kisf = xr.merge([T_mean_cav,S_mean_cav,T_std_cav,S_std_cav])

    # Broadcast the cavity statistics onto every point of the local T/S fields.
    T_S_info_br, TSmean_br = xr.broadcast(xr.merge([T_isf.rename('theta_in'),S_isf.rename('salinity_in')]),T_S_2D_meanstd_kisf)
    TS_info_all = xr.merge([T_S_info_br, TSmean_br])

    # Geometry inputs of this ice shelf, renamed to the input_vars column names.
    file_isf_kisf = uf.choose_isf(file_isf[['dGL', 'dIF']], isf_stack_mask, kisf)
    bathy_kisf = uf.choose_isf(file_bed_goodGL, isf_stack_mask, kisf)
    slope_kisf = uf.choose_isf(file_slope, isf_stack_mask, kisf)
    geometry_kisf = xr.merge([file_isf_kisf,
                         depth_kisf.rename('corrected_isfdraft'),
                         bathy_kisf.rename('bathy_metry'),
                         slope_kisf])


    geometry_2D_br, time_dpdt_in_br = xr.broadcast(geometry_kisf,TS_info_all)
    all_info = xr.merge([geometry_2D_br, time_dpdt_in_br])

    # First ice shelf initialises the accumulator; later ones are merged in with combine_first.
    # NOTE(review): repeated combine_first grows the dataset each iteration - may be slow for many ice shelves.
    if n == 0:
        all_info_all = all_info.squeeze().drop('Nisf')
    else:
        all_info_all =  all_info_all.combine_first(all_info).squeeze().drop('Nisf')
    n = n+1        

# Ensemble of NN models, one per training seed. Each member predicts the 2D melt
# field; only a running sum is kept, so the whole ensemble is never held in memory.
# mod_size comes from the options cell at the top (it used to be silently reset to
# 'small' here, which ignored that option).
seed_nbs = range(1, 11)

res2D_sum = 0
for seed_nb in seed_nbs:
    print('seed_nb',seed_nb)

    model = keras.models.load_model(path_model + 'model_nn_'+mod_size+'_'+exp_name+'_wholedataset_'+str(seed_nb).zfill(2)+'_TS'+TS_opt+'_norm'+norm_method+'.h5')

    print('Computing 2D patterns')
    res2D = apply_NN_results_2D_allisf_allyears(file_isf, norm_metrics, all_info_all, model, input_vars)
    res2D_sum = res2D_sum + res2D

# Divide by the actual number of members (previously a hard-coded 10), so the mean
# stays correct if seed_nbs is changed.
xr_ensmean_res2D = res2D_sum / len(seed_nbs)
#del res2D_all

# Reduce the ensemble-mean 2D melt rate to two 1D metrics per ice shelf:
#  - 'Gt'  : sum(melt * grid_cell_area_weighted) * rho_i / 1e12 (Gt per year, assuming the melt
#            is in m ice/yr, the areas are in m^2 and rho_i is in kg/m^3)
#  - 'box1': mean melt over box 1 of box configuration 2, weighted by grid_cell_area_weighted
print('Compute the 1D evalmetrics')
res_1D_list = []
for kisf in tqdm(file_isf.Nisf.values):

    geometry_isf_2D = dfmt.choose_isf(geometry_info_2D,isf_stack_mask, kisf)
    melt_rate_2D_isf_m_per_y = dfmt.choose_isf(xr_ensmean_res2D,isf_stack_mask, kisf)

    melt_rate_1D_isf_Gt_per_y = (melt_rate_2D_isf_m_per_y * geometry_isf_2D['grid_cell_area_weighted']).sum(dim=['mask_coord']) * rho_i / 10**12

    box_loc_config_stacked = dfmt.choose_isf(box1, isf_stack_mask, kisf)
    param_melt_2D_box1_isf = melt_rate_2D_isf_m_per_y.where(np.isfinite(box_loc_config_stacked))

    melt_rate_1D_isf_myr_box1_mean = dfmt.weighted_mean(param_melt_2D_box1_isf,['mask_coord'], geometry_isf_2D['grid_cell_area_weighted'])

    out_1D = xr.concat([melt_rate_1D_isf_Gt_per_y, melt_rate_1D_isf_myr_box1_mean], dim='metrics').assign_coords({'metrics': ['Gt','box1']})
    res_1D_list.append(out_1D)

res_1D_tt = xr.concat(res_1D_list, dim='Nisf')

# Build the file name from the run options instead of hard-coding
# 'newbasic2_extrap_normstd_newtuning', so the name matches the experiment that produced
# it. With the default options the name is unchanged.
outfile_1D = ('evalmetrics_1D_' + mod_size + '_' + exp_name + '_' + TS_opt
              + '_norm' + norm_method + '_' + tuning_sort + 'tuning.nc')
res_1D_tt.rename('predicted_melt').to_netcdf(outputpath_melt + outfile_1D)
                                             

In [None]:
###### DO NOT USE ANYMORE - legacy per-chunk (time-block) workflow, kept for reference only

In [None]:
# Chunk table: column 0 = chunk number, 1 = NEMO run, 2 = first year, 3 = last year.
file_info = pd.read_csv(outputpath_info+'info_chunks.txt', delimiter=',', header=None)
file_info = file_info.set_index(file_info[0])

# Rows of the current run (filtered once instead of three times per iteration).
# The loop leaves chunk_nb, start_yy and end_yy set to the LAST chunk; the next cell relies on that.
file_info_run = file_info[file_info[1]==nemo_run]
for chunk_nb in file_info_run[0].values:
    start_yy = file_info_run[2].loc[chunk_nb]
    end_yy = file_info_run[3].loc[chunk_nb]
    # Years of this chunk. Named yy_range: it used to be called `trange`, which
    # shadowed tqdm's trange imported at the top of the notebook.
    yy_range = range(start_yy,end_yy+1)
    print(chunk_nb,start_yy,end_yy)

In [None]:
# Time blocks to evaluate for the last chunk: a chunk spanning 10 years (end_yy - start_yy == 9)
# is a single block; otherwise the preceding block is included as well.
tblock_dim = [chunk_nb] if (end_yy - start_yy) == 9 else [chunk_nb-1, chunk_nb]
print(tblock_dim)

READ IN DATA

In [None]:
# Root directory of the pre-computed NN input data
inputpath_data_nn = '/bettik/burgardc/DATA/NN_PARAM/interim/INPUT_DATA/'

In [None]:
# Location of the pre-computed NN input chunks for the chosen T/S option.
if TS_opt == 'extrap':
    inputpath_CVinput = inputpath_data_nn+'EXTRAPOLATED_ISFDRAFT_CHUNKS/'
    inputpath_csv = inputpath_data_nn+'EXTRAPOLATED_ISFDRAFT_CHUNKS/'
else:
    # Only the extrapolated inputs are wired up. Fail here with a clear message instead
    # of a NameError on inputpath_csv in the evaluation loop further down.
    raise NotImplementedError('NN input paths are only defined for TS_opt == "extrap", got ' + repr(TS_opt))

APPLY MODEL

In [None]:
# Model input columns; the order fixes the column order handed to the network.
input_vars = [
    # geometry
    'dGL', 'dIF', 'corrected_isfdraft', 'bathy_metry',
    'slope_bed_lon', 'slope_bed_lat', 'slope_ice_lon', 'slope_ice_lat',
    # ocean temperature / salinity
    'theta_in', 'salinity_in', 'T_mean', 'S_mean', 'T_std', 'S_std',
]

In [None]:
# 'new' or 'old': which set of trained models the next cell loads
tuning_sort = 'new'

In [None]:
### use any model from CV over time
outputpath_melt = '/bettik/burgardc/DATA/SUMMER_PAPER/processed/OCEAN_MELT_RATE_CHECK_TUNING/nemo_5km_'+nemo_run+'/'
if tuning_sort == 'new':
    path_model = '/bettik/burgardc/DATA/SUMMER_PAPER/interim/NN_MODELS/'
elif tuning_sort == 'old':
    path_model = '/bettik/burgardc/DATA/NN_PARAM/interim/NN_MODELS/experiments/WHOLE/'

file_isf, geometry_info_2D, box_charac_2D, box_charac_1D, isf_stack_mask = pp.read_input_evalmetrics_NN(nemo_run)

norm_metrics_file = xr.open_dataset(inputpath_CVinput + 'metrics_norm_wholedataset_origexcept26_christoph_new.nc')
norm_metrics = norm_metrics_file.to_dataframe()

box_loc_config2 = box_charac_2D['box_location'].sel(box_nb_tot=box_charac_1D['nD_config'].sel(config=2))
box1 = box_loc_config2.where(box_loc_config2==1).isel(Nisf=1).drop('Nisf')

In [None]:
# Legacy per-chunk evaluation. For every time block and ice shelf: predict the 2D melt
# with each ensemble member, average over members, and reduce the result to the 'Gt'
# and 'box1' 1D metrics.
#
# The models depend only on the seed, so each one is loaded once here instead of
# once per (time block, ice shelf) pair.
seed_nbs = range(1,2)  # NOTE(review): only seed 01 is used; range(1,11) gives the full ensemble
models_by_seed = {
    seed_nb: keras.models.load_model(path_model + 'model_nn_'+mod_size+'_'+exp_name+'_wholedataset_'+str(seed_nb).zfill(2)+'_TS'+TS_opt+'_norm'+norm_method+'.h5')
    for seed_nb in seed_nbs
}

for tblock in tblock_dim:

    res_1D_list = []
    for kisf in tqdm(file_isf.Nisf.values):

        df_nrun = pd.read_csv(inputpath_csv + 'dataframe_input_isf'+str(kisf).zfill(3)+'_'+str(tblock).zfill(3)+'_new.csv',index_col=[0,1,2])

        ens_res2D_list = [
            pp.apply_NN_results_2D_1isf_1tblock(file_isf, norm_metrics, df_nrun, model, input_vars).assign_coords({'seed_nb': seed_nb})
            for seed_nb, model in models_by_seed.items()
        ]
        xr_ens_res2D = xr.concat(ens_res2D_list, dim='seed_nb')
        xr_ensmean_res2D = xr_ens_res2D.mean('seed_nb')

        geometry_isf_2D = dfmt.choose_isf(geometry_info_2D,isf_stack_mask, kisf)
        melt_rate_2D_isf_m_per_y = dfmt.choose_isf(xr_ensmean_res2D,isf_stack_mask, kisf)

        melt_rate_1D_isf_Gt_per_y = (melt_rate_2D_isf_m_per_y * geometry_isf_2D['grid_cell_area_weighted']).sum(dim=['mask_coord']) * rho_i / 10**12

        box_loc_config_stacked = dfmt.choose_isf(box1, isf_stack_mask, kisf)
        param_melt_2D_box1_isf = melt_rate_2D_isf_m_per_y.where(np.isfinite(box_loc_config_stacked))

        melt_rate_1D_isf_myr_box1_mean = dfmt.weighted_mean(param_melt_2D_box1_isf,['mask_coord'], geometry_isf_2D['isfdraft_conc'])

        out_1D = xr.concat([melt_rate_1D_isf_Gt_per_y, melt_rate_1D_isf_myr_box1_mean], dim='metrics').assign_coords({'metrics': ['Gt','box1']})
        res_1D_list.append(out_1D)

    res_1D_all = xr.concat(res_1D_list, dim='Nisf')
    res_1D_all.to_netcdf(outputpath_melt + 'evalmetrics_1D_'+mod_size+'_'+exp_name+'_ensmean_'+TS_opt+'_norm'+norm_method+'_'+str(tblock).zfill(3)+'_'+tuning_sort+'tuning.nc')