In [None]:
"""
Created on Mon Mar 27 17:37 2023

Apply the ensemble of NN to Smith data and only take ensemble mean

Author: @claraburgard

"""

In [None]:
import datetime
import glob
import os
import sys
import time
from contextlib import redirect_stdout

import matplotlib as mpl
import numpy as np
import pandas as pd
import seaborn as sns
import tensorflow as tf
import xarray as xr
from tensorflow import keras
from tqdm.notebook import trange, tqdm
#from tqdm import tqdm

from multimelt.constants import *
import multimelt.useful_functions as uf
import summer_paper.data_formatting_NN as dfmt
import summer_paper.postprocessing_functions_NN as pp


DEFINE OPTIONS

In [None]:
# Options that enter the file name of the trained network loaded further down.
# mod_size stays commented out in this cell; sizes listed next to it:
# 'mini', 'small', 'medium', 'large', 'extra_large'
#mod_size = 'xsmall96'

# choices: extrap, whole, thermocline
TS_opt = 'extrap'

# choices: std, interquart, minmax
norm_method = 'std'

# other experiment names listed alongside: 'allbutconstants', 'onlyTSdraftandslope',
# 'TSdraftbotandiceddandwcd', 'onlyTSisfdraft', 'TSdraftbotandiceddandwcdreldGL',
# TSdraftslopereldGL
exp_name = 'newbasic2'

In [None]:
# Model and scenario whose input dataframes are processed
mod = 'CESM2-WACCM'
scenario = 'historical'

# Only consulted for non-historical scenarios: run until 2300 instead of 2100
to2300 = False

# First and last year of the processed period (both are included by the year loops below)
if scenario != 'historical':
    yystart = 2015
    yyend = 2300 if to2300 else 2100
else:
    yystart, yyend = 1980, 2014

READ IN DATA

In [None]:
# Set up the input directories, open the ice-shelf mask, geometry, grid-area, box and plume files,
# and assemble the geometry datasets and normalisation metrics used when applying the network.

# root directory; the paths below are either built from it or spelled out in full
home_path = '/bettik/burgardc/'

inputpath_data_nn = '/bettik/burgardc/DATA/NN_PARAM/interim/INPUT_DATA/'
# directory of the per-model input dataframes (csv); depends on `mod`
inputpath_csv = '/bettik/burgardc/DATA/SUMMER_PAPER/interim/INPUT_DATA/CMIP/'+mod+'/'
inputpath_mask = home_path+'/DATA/SUMMER_PAPER/interim/ANTARCTICA_IS_MASKS/BedMachine_4km/'
inputpath_data=home_path+'/DATA/SUMMER_PAPER/interim/'
inputpath_CVinput = home_path+'/DATA/NN_PARAM/interim/INPUT_DATA/EXTRAPOLATED_ISFDRAFT_CHUNKS/'
inputpath_plumes = home_path+'/DATA/SUMMER_PAPER/interim/PLUMES/BedMachine_4km/'
inputpath_boxes = home_path+'/DATA/SUMMER_PAPER/interim/BOXES/BedMachine_4km/'

# make the domain a little smaller to make the computation even more efficient - file isf has already been made smaller at its creation
map_lim = [-3000000,3000000]

# ice-shelf masks and info: keep only the Nisf entries with a finite 'front_bot_depth_max' ...
file_isf_orig = xr.open_dataset(inputpath_mask+'BedMachinev2_4km_isf_masks_and_info_and_distance_oneFRIS.nc')
nonnan_Nisf = file_isf_orig['Nisf'].where(np.isfinite(file_isf_orig['front_bot_depth_max']), drop=True).astype(int)
file_isf_nonnan = file_isf_orig.sel(Nisf=nonnan_Nisf)
# ... and, among those, only the ones with a finite 'isf_area_rignot'
rignot_isf = file_isf_nonnan.Nisf.where(np.isfinite(file_isf_nonnan['isf_area_rignot']), drop=True)
file_isf = file_isf_nonnan.sel(Nisf=rignot_isf)
# cast the name and region variables to str
file_isf['isf_name'] = file_isf['isf_name'].astype(str)
file_isf['region'] = file_isf['region'].astype(str)

# BedMachine fields, cut with map_lim in both directions by dfmt.cut_domain_stereo
BedMachine_orig = xr.open_dataset(inputpath_data+'BedMachine_v2_aggregated4km_allvars.nc')
BedMachine_orig_cut = dfmt.cut_domain_stereo(BedMachine_orig, map_lim, map_lim)
# sign-flipped 'bed' field
file_bed_goodGL = -1*BedMachine_orig_cut['bed']
# thickness minus surface, kept only where ISF_mask > 1 (NaN elsewhere)
file_draft = (BedMachine_orig_cut['thickness'] - BedMachine_orig_cut['surface']).where(file_isf['ISF_mask'] > 1)
file_isf_conc = BedMachine_orig_cut['isf_conc']

# grid-cell area selected on the x/y coordinates of file_isf, without its 'lon' and 'lat' variables
grid_cell_area_file = xr.open_dataset(inputpath_data+'gridarea_ISMIP6_AIS_4000m_grid.nc').sel(x=file_isf.x,y=file_isf.y)
true_grid_cell_area = grid_cell_area_file['cell_area'].drop('lon').drop('lat')
# ratio of the cell area from the file to 4000 * 4000
# (presumably the nominal cell area of the 4000 m grid in m^2 - TODO confirm units)
cell_area_weight = true_grid_cell_area/(4000 * 4000)

lon = file_isf.longitude
lat = file_isf.latitude

xx = file_isf.x
yy = file_isf.y
# grid spacing taken from the entries at index 1 and 2 of each coordinate
dx = (xx[2] - xx[1]).values
dy = (yy[2] - yy[1]).values
grid_cell_area_const = abs(dx*dy)  
# per-cell area: 'isf_conc' times the constant cell area times the area weight
grid_cell_area_weighted = file_isf_conc * grid_cell_area_const * cell_area_weight

# the draft as computed above and its sign-flipped counterpart
ice_draft_pos = file_draft
ice_draft_neg = -ice_draft_pos

# stacked version of ISF_mask over ['y','x'] with the new coordinate 'mask_coord' (built by uf.create_stacked_mask)
isf_stack_mask = uf.create_stacked_mask(file_isf['ISF_mask'], file_isf.Nisf, ['y','x'], 'mask_coord')

box_charac_all_2D = xr.open_dataset(inputpath_boxes + 'BedMachine_4km_boxes_2D_oneFRIS.nc')
box_charac_all_1D = xr.open_dataset(inputpath_boxes + 'BedMachine_4km_boxes_1D_oneFRIS.nc')
plume_charac = xr.open_dataset(inputpath_plumes+'BedMachine_4km_plume_characteristics.nc')

# subsets of file_isf variables that go into the geometry datasets
param_var_of_int_2D = file_isf[['ISF_mask', 'latitude', 'longitude', 'dGL']]
param_var_of_int_1D = file_isf[['front_bot_depth_avg', 'front_bot_depth_max','isf_name']]

# 2D geometry: plume characteristics merged with the mask variables, the draft ('ice_draft_pos'),
# the weighted cell area ('grid_cell_area_weighted') and the ice-shelf concentration ('isfdraft_conc')
geometry_info_2D = plume_charac.merge(param_var_of_int_2D).merge(ice_draft_pos.rename('ice_draft_pos')).merge(grid_cell_area_weighted.rename('grid_cell_area_weighted')).merge(file_isf_conc.rename('isfdraft_conc'))
geometry_info_1D = param_var_of_int_1D

# normalisation metrics, converted to a pandas DataFrame
norm_metrics_file = xr.open_dataset(inputpath_CVinput + 'metrics_norm_wholedataset_origexcept26_christoph_new.nc')
norm_metrics = norm_metrics_file.to_dataframe()

# box locations for the number of boxes given by 'nD_config' at config=2
box_loc_config2 = box_charac_all_2D['box_location'].sel(box_nb_tot=box_charac_all_1D['nD_config'].sel(config=2))
# cells where box_location == 1 (NaN elsewhere), taken at positional index 1 along Nisf, Nisf coordinate dropped
# NOTE(review): isel(Nisf=1) picks a single entry along Nisf - presumably the field is the same for every entry; confirm
box1 = box_loc_config2.where(box_loc_config2==1).isel(Nisf=1).drop('Nisf')

In [None]:
### APPLY MODEL

# Predictor columns: selected from the input dataframes and from norm_metrics
# when the network is applied (see apply_NN_results_2D_1isf_1year)
input_vars = ['dGL','dIF','corrected_isfdraft','bathy_metry','slope_bed_lon','slope_bed_lat','slope_ice_lon','slope_ice_lat',
                'theta_in','salinity_in','T_mean', 'S_mean', 'T_std', 'S_std']

# which set of trained models to use: 'new' or 'old'
tuning_sort = 'new' #'old'

### use any model from CV over time
# directory the 1D evaluation metrics are written to
outputpath_melt = home_path+'/DATA/SUMMER_PAPER/processed/OCEAN_MELT_RATE_CMIP/'+mod+'/'
if tuning_sort == 'new':
    path_model = '/bettik/burgardc/DATA/SUMMER_PAPER/interim/NN_MODELS/'
elif tuning_sort == 'old':
    path_model = '/bettik/burgardc/DATA/NN_PARAM/interim/NN_MODELS/experiments/WHOLE/'
else:
    # fail here rather than with a NameError on path_model when the model is loaded
    raise ValueError("tuning_sort must be 'new' or 'old', got " + repr(tuning_sort))

def apply_NN_results_2D_1isf_1year(file_isf, norm_metrics, df_nrun, model, input_vars=[]):
    """
    Compute 2D melt based on a given NN model

    The predictors are normalised, passed through the model, and the prediction
    is denormalised and put back on the grid of the ice-shelf mask.

    Parameters
    ----------
    file_isf : xarray Dataset
        Must contain 'ISF_mask'; the result is reindexed like this variable.
    norm_metrics : pandas DataFrame
        Must have the rows 'mean_vars' and 'range_vars', and a column for each
        entry of ``input_vars`` as well as for 'melt_m_ice_per_y'.
    df_nrun : pandas DataFrame
        Input rows containing the ``input_vars`` columns. Its index becomes the
        coordinates of the result and must contain a level called 'y'.
    model : object
        Trained model; only ``model.predict(array, verbose=0)`` is used.
    input_vars : list of str
        Names of the predictor columns, in the order expected by the model.

    Returns
    -------
    xarray DataArray
        Predicted melt, named 'predicted_melt', sorted by 'y' and reindexed
        like ``file_isf['ISF_mask']``.
    """

    norm_mean = norm_metrics[input_vars].loc['mean_vars']
    norm_range = norm_metrics[input_vars].loc['range_vars']
    x_val_norm = pp.normalise_vars(df_nrun[input_vars], norm_mean, norm_range)

    y_out_norm = model.predict(x_val_norm.values.astype('float64'),verbose = 0)

    # Flatten with reshape instead of squeeze: squeeze turns a one-row prediction
    # into a 0-d array, which has no dimension left to carry the index.
    y_out_norm_1d = np.reshape(y_out_norm, -1)
    y_out_norm_xr = xr.DataArray(data=y_out_norm_1d, dims=['index'])
    y_out_norm_xr = y_out_norm_xr.assign_coords({'index': x_val_norm.index})

    # denormalise the output
    y_out = pp.denormalise_vars(y_out_norm_xr,
                                norm_metrics['melt_m_ice_per_y'].loc['mean_vars'],
                                norm_metrics['melt_m_ice_per_y'].loc['range_vars'])

    # back to a pandas Series carrying the index of the input rows
    y_out_pd_s = pd.Series(y_out.values,index=df_nrun.index,name='predicted_melt')

    # put some order in the file
    y_out_xr = y_out_pd_s.to_xarray().sortby('y')

    # put the prediction on the grid of the ice-shelf mask
    y_whole_grid = y_out_xr.reindex_like(file_isf['ISF_mask'])
    return y_whole_grid

### If no ensemble of models

# The input dataframes do not depend on the network size, so they are read and
# concatenated once, before looping over the sizes.
print('prepare 2D input')
li = []
for tt in tqdm(range(yystart,yyend+1)):
    df_nrun = pd.read_csv(inputpath_csv + 'dataframe_input_allisf_'+mod+'_'+scenario+'_'+str(tt)+'.csv',index_col=[0,1,2])
    li.append(df_nrun)

print('Concatenating input')
df_allyy = pd.concat(li)

# only a single trained model (seed 1) is used per network size
seed_nb = 1

for mod_size in ['xsmall96','small','large']:

    print('NN size',mod_size)
    model = keras.models.load_model(path_model + 'model_nn_'+mod_size+'_'+exp_name+'_wholedataset_'+str(seed_nb).zfill(2)+'_TS'+TS_opt+'_norm'+norm_method+'.h5')

    print('Computing 2D patterns')
    xr_ensmean_res2D = apply_NN_results_2D_1isf_1year(file_isf, norm_metrics, df_allyy, model, input_vars)

    print('Compute the 1D evalmetrics')
    res_1D_list = []
    for kisf in tqdm(file_isf.Nisf.values):

        # geometry and predicted melt restricted to the current ice shelf
        geometry_isf_2D = dfmt.choose_isf(geometry_info_2D,isf_stack_mask, kisf)
        melt_rate_2D_isf_m_per_y = dfmt.choose_isf(xr_ensmean_res2D,isf_stack_mask, kisf)

        # integrated melt: melt rate times weighted cell area, summed over the shelf, times rho_i / 10**12
        melt_rate_1D_isf_Gt_per_y = (melt_rate_2D_isf_m_per_y * geometry_isf_2D['grid_cell_area_weighted']).sum(dim=['mask_coord']) * rho_i / 10**12

        # weighted mean of the melt over the cells where box1 is finite
        box_loc_config_stacked = dfmt.choose_isf(box1, isf_stack_mask, kisf)
        param_melt_2D_box1_isf = melt_rate_2D_isf_m_per_y.where(np.isfinite(box_loc_config_stacked))

        melt_rate_1D_isf_myr_box1_mean = dfmt.weighted_mean(param_melt_2D_box1_isf,['mask_coord'], geometry_isf_2D['grid_cell_area_weighted'])

        out_1D = xr.concat([melt_rate_1D_isf_Gt_per_y, melt_rate_1D_isf_myr_box1_mean], dim='metrics').assign_coords({'metrics': ['Gt','box1']})
        res_1D_list.append(out_1D)

    # one file per network size
    res_1D_all = xr.concat(res_1D_list, dim='Nisf')
    res_1D_all.rename('predicted_melt').to_netcdf(outputpath_melt + 'evalmetrics_1D_NN'+mod_size+'_'+scenario+'.nc')

###############

In [None]:

# Apply the ensemble of models year by year and write the 1D metrics of all years.
#
# NOTE(review): mod_size is not assigned in this cell; it is whatever value an earlier
# cell left behind ('large' after the loop over network sizes has run), so the file
# written at the end has the same name as the 'large' file of that loop - confirm intent.

# The trained models do not change from year to year: load each of them once.
ens_models = []
#for seed_nb in range(1,11):
for seed_nb in range(1,2):
    model = keras.models.load_model(path_model + 'model_nn_'+mod_size+'_'+exp_name+'_wholedataset_'+str(seed_nb).zfill(2)+'_TS'+TS_opt+'_norm'+norm_method+'.h5')
    ens_models.append((seed_nb, model))

res_1D_all_list = []
for tt in tqdm(range(yystart,yyend+1)):

    # combine the input dataframes of all ice shelves for this year
    li = []
    for kisf in file_isf.Nisf.values:

        df_nrun = pd.read_csv(inputpath_csv + 'dataframe_input_isf'+
                              str(kisf).zfill(3)+'_'+mod+'_'+scenario+'_'+
                              str(tt)+'.csv',index_col=[0,1])
        li.append(df_nrun)

    df_allyy = pd.concat(li)

    # 2D melt for every ensemble member, then the mean over the members
    ens_res2D_list = []
    for seed_nb, model in ens_models:
        res_2D = apply_NN_results_2D_1isf_1year(file_isf, norm_metrics, df_allyy, model, input_vars)
        ens_res2D_list.append(res_2D.assign_coords({'seed_nb': seed_nb}))

    xr_ens_res2D = xr.concat(ens_res2D_list, dim='seed_nb')
    xr_ensmean_res2D = xr_ens_res2D.mean('seed_nb')

    # compute the 1D evalmetrics
    res_1D_list = []
    for kisf in file_isf.Nisf.values:

        geometry_isf_2D = dfmt.choose_isf(geometry_info_2D,isf_stack_mask, kisf)
        melt_rate_2D_isf_m_per_y = dfmt.choose_isf(xr_ensmean_res2D,isf_stack_mask, kisf)

        # integrated melt: melt rate times weighted cell area, summed over the shelf, times rho_i / 10**12
        melt_rate_1D_isf_Gt_per_y = (melt_rate_2D_isf_m_per_y * geometry_isf_2D['grid_cell_area_weighted']).sum(dim=['mask_coord']) * rho_i / 10**12

        box_loc_config_stacked = dfmt.choose_isf(box1, isf_stack_mask, kisf)
        param_melt_2D_box1_isf = melt_rate_2D_isf_m_per_y.where(np.isfinite(box_loc_config_stacked))

        # NOTE(review): the weights here are 'isfdraft_conc', whereas the loop over network sizes
        # earlier in this notebook weights the same mean with 'grid_cell_area_weighted' - confirm which is intended
        melt_rate_1D_isf_myr_box1_mean = dfmt.weighted_mean(param_melt_2D_box1_isf,['mask_coord'], geometry_isf_2D['isfdraft_conc'])

        out_1D = xr.concat([melt_rate_1D_isf_Gt_per_y, melt_rate_1D_isf_myr_box1_mean], dim='metrics').assign_coords({'metrics': ['Gt','box1']})
        res_1D_list.append(out_1D)

    res_1D_all = xr.concat(res_1D_list, dim='Nisf')
    res_1D_all_list.append(res_1D_all.assign_coords({'time': tt}))

# Write all years, not only the last one: stack the yearly results along 'time'.
# The variable is named 'predicted_melt', as in the file written by the loop over network sizes.
res_1D_all_years = xr.concat(res_1D_all_list, dim='time')
res_1D_all_years.rename('predicted_melt').to_netcdf(outputpath_melt + 'evalmetrics_1D_NN'+mod_size+'_'+scenario+'.nc')