In [None]:
"""
Created on Mon Mar 27 17:37 2023

Apply the ensemble of NN to Smith data and only take ensemble mean

Author: @claraburgard

"""

In [None]:
import datetime
import glob
import os,sys
import time
from contextlib import redirect_stdout

import numpy as np
import pandas as pd
import xarray as xr
from tqdm.notebook import trange, tqdm
#from tqdm import tqdm
import matplotlib as mpl
import seaborn as sns
import tensorflow as tf
from tensorflow import keras

from multimelt.constants import *
import summer_paper.data_formatting_NN as dfmt
import summer_paper.postprocessing_functions_NN as pp


DEFINE OPTIONS

In [None]:
# Run configuration. These four strings are combined into the model file name
# and the output file name further down.

# Network size; alternatives: 'mini', 'small', 'medium', 'large', 'extra_large'
mod_size = "large"

# Alternatives: extrap, whole, thermocline
TS_opt = "extrap"

# Normalisation method; alternatives: std, interquart, minmax
norm_method = "std"

# Experiment name; alternatives: 'allbutconstants', 'onlyTSdraftandslope',
# 'TSdraftbotandiceddandwcd', 'onlyTSisfdraft',
# 'TSdraftbotandiceddandwcdreldGL', TSdraftslopereldGL
exp_name = "newbasic2"

In [None]:
# Directory containing info_chunks.txt, the table of chunks per NEMO run.
outputpath_info = "/bettik/burgardc/DATA/SUMMER_PAPER/interim/"

In [None]:
# NEMO run to process; alternatives: 'EPM031', 'EPM034'
nemo_run = "ctrl94"

In [None]:
# Read the chunk table (no header row). Columns are addressed by position:
# 0 = chunk number, 1 = NEMO run name, 2 = first year, 3 = last year.
file_info = pd.read_csv(outputpath_info+'info_chunks.txt', delimiter=',', header=None)
# Index the rows by chunk number so that they can be looked up with .loc below.
file_info = file_info.set_index(file_info[0])

# Filter once for the selected run instead of repeating the mask in every lookup.
run_chunks = file_info[file_info[1] == nemo_run]
if run_chunks.empty:
    # Without this check chunk_nb/start_yy/end_yy would stay undefined and the
    # following cell would fail with an unrelated NameError.
    raise ValueError(
        "No chunk found for nemo_run={!r} in {}".format(nemo_run, outputpath_info+'info_chunks.txt')
    )

# NOTE: chunk_nb, start_yy and end_yy keep the values of the LAST chunk of the
# run once the loop has finished; the following cell relies on them.
for chunk_nb in run_chunks[0].values:
    start_yy = run_chunks[2].loc[chunk_nb]
    end_yy = run_chunks[3].loc[chunk_nb]
    # Years covered by the chunk, both ends included. Named year_range so that
    # the `trange` imported from tqdm.notebook is not overwritten.
    year_range = range(start_yy,end_yy+1)
    print(chunk_nb,start_yy,end_yy)

In [None]:
# Time blocks to process, derived from the last chunk found for the run:
# when its first and last year are nine apart only that chunk is used,
# otherwise the preceding chunk is processed as well.
tblock_dim = [chunk_nb] if (end_yy - start_yy) == 9 else [chunk_nb-1, chunk_nb]
print(tblock_dim)

READ IN DATA

In [None]:
# Root directory of the neural-network input data.
inputpath_data_nn = "/bettik/burgardc/DATA/NN_PARAM/interim/INPUT_DATA/"

In [None]:
# Input directories for the selected TS_opt. Both the normalisation metrics
# (inputpath_CVinput) and the per-ice-shelf CSV files (inputpath_csv) are read
# from the same directory.
if TS_opt == 'extrap':
    inputpath_CVinput = inputpath_data_nn+'EXTRAPOLATED_ISFDRAFT_CHUNKS/'
    inputpath_csv = inputpath_data_nn+'EXTRAPOLATED_ISFDRAFT_CHUNKS/'
else:
    # No directories are defined for the other options; stop here rather than
    # failing later with a NameError on inputpath_CVinput / inputpath_csv.
    raise ValueError("No input paths defined for TS_opt={!r}; only 'extrap' is supported here.".format(TS_opt))

APPLY MODEL

In [None]:
# Names of the input variables handed to the neural network, in this order.
input_vars = [
    "dGL",
    "dIF",
    "corrected_isfdraft",
    "bathy_metry",
    "slope_bed_lon",
    "slope_bed_lat",
    "slope_ice_lon",
    "slope_ice_lat",
    "theta_in",
    "salinity_in",
    "T_mean",
    "S_mean",
    "T_std",
    "S_std",
]

In [None]:
# Which set of trained models to use: 'new' or 'old' (selects the model directory).
tuning_sort = "new"

In [None]:
### use any model from CV over time
outputpath_melt = '/bettik/burgardc/DATA/SUMMER_PAPER/processed/OCEAN_MELT_RATE_CHECK_TUNING/nemo_5km_'+nemo_run+'/'
if tuning_sort == 'new':
    path_model = '/bettik/burgardc/DATA/SUMMER_PAPER/interim/NN_MODELS/'
elif tuning_sort == 'old':
    path_model = '/bettik/burgardc/DATA/NN_PARAM/interim/NN_MODELS/experiments/WHOLE/'

file_isf, geometry_info_2D, box_charac_2D, box_charac_1D, isf_stack_mask = pp.read_input_evalmetrics_NN(nemo_run)

norm_metrics_file = xr.open_dataset(inputpath_CVinput + 'metrics_norm_wholedataset_origexcept26_christoph_new.nc')
norm_metrics = norm_metrics_file.to_dataframe()

box_loc_config2 = box_charac_2D['box_location'].sel(box_nb_tot=box_charac_1D['nD_config'].sel(config=2))
box1 = box_loc_config2.where(box_loc_config2==1).isel(Nisf=1).drop('Nisf')

In [None]:
# Ensemble members (seeds) to apply. Only seed 1 is active; the commented-out
# line is the alternative covering seeds 1 to 10.
#for seed_nb in range(1,11):
seed_numbers = range(1,2)

# Load every ensemble member once. The model file name depends neither on the
# ice shelf nor on the time block, so loading inside the loops below would
# re-read the same files for every ice shelf.
ensemble_models = {
    seed_nb: keras.models.load_model(path_model + 'model_nn_'+mod_size+'_'+exp_name+'_wholedataset_'+str(seed_nb).zfill(2)+'_TS'+TS_opt+'_norm'+norm_method+'.h5')
    for seed_nb in seed_numbers
}

# One output file is written per time block.
for tblock in tblock_dim:

    # Collects one (metrics, ...) array per ice shelf.
    res_1D_list = []
    for kisf in tqdm(file_isf.Nisf.values):

        # Input table of this ice shelf and time block.
        df_nrun = pd.read_csv(inputpath_csv + 'dataframe_input_isf'+str(kisf).zfill(3)+'_'+str(tblock).zfill(3)+'_new.csv',index_col=[0,1,2])

        # Apply each ensemble member and label its result with the seed number.
        ens_res2D_list = []
        for seed_nb, model in ensemble_models.items():
            res_2D = pp.apply_NN_results_2D_1isf_1tblock(file_isf, norm_metrics, df_nrun, model, input_vars)

            ens_res2D_list.append(res_2D.assign_coords({'seed_nb': seed_nb}))

        # Ensemble mean over the seeds.
        xr_ens_res2D = xr.concat(ens_res2D_list, dim='seed_nb')
        xr_ensmean_res2D = xr_ens_res2D.mean('seed_nb')

        # Restrict geometry and predicted melt to the current ice shelf.
        geometry_isf_2D = dfmt.choose_isf(geometry_info_2D,isf_stack_mask, kisf)
        melt_rate_2D_isf_m_per_y = dfmt.choose_isf(xr_ensmean_res2D,isf_stack_mask, kisf)

        # Area-weighted sum over the ice shelf, scaled by rho_i / 10**12
        # (presumably converting to Gt per year, as the variable name says).
        melt_rate_1D_isf_Gt_per_y = (melt_rate_2D_isf_m_per_y * geometry_isf_2D['grid_cell_area_weighted']).sum(dim=['mask_coord']) * rho_i / 10**12

        # Keep the melt only where the box-1 mask is finite, then take the
        # mean weighted by 'isfdraft_conc'.
        box_loc_config_stacked = dfmt.choose_isf(box1, isf_stack_mask, kisf)
        param_melt_2D_box1_isf = melt_rate_2D_isf_m_per_y.where(np.isfinite(box_loc_config_stacked))

        melt_rate_1D_isf_myr_box1_mean = dfmt.weighted_mean(param_melt_2D_box1_isf,['mask_coord'], geometry_isf_2D['isfdraft_conc'])

        # Stack both diagnostics along a new 'metrics' dimension.
        out_1D = xr.concat([melt_rate_1D_isf_Gt_per_y, melt_rate_1D_isf_myr_box1_mean], dim='metrics').assign_coords({'metrics': ['Gt','box1']})
        res_1D_list.append(out_1D)

    # Combine all ice shelves and write the result for this time block.
    res_1D_all = xr.concat(res_1D_list, dim='Nisf')
    res_1D_all.to_netcdf(outputpath_melt + 'evalmetrics_1D_'+mod_size+'_'+exp_name+'_ensmean_'+TS_opt+'_norm'+norm_method+'_'+str(tblock).zfill(3)+'_'+tuning_sort+'tuning.nc')