In [None]:
"""
Created on Mon Mar 27 17:37 2023

Apply the ensemble of NN to Smith data and only take ensemble mean

Author: @claraburgard

"""

In [None]:
import datetime
import glob
import os,sys
import os.path
import time
from contextlib import redirect_stdout

import matplotlib as mpl
import numpy as np
import pandas as pd
import seaborn as sns
import tensorflow as tf
import xarray as xr
from tensorflow import keras
#from tqdm.notebook import trange, tqdm
from tqdm import tqdm

from multimelt.constants import *
import summer_paper.useful_functions as uf
import summer_paper.data_formatting_NN as dfmt
import summer_paper.postprocessing_functions_NN as pp


DEFINE OPTIONS

In [None]:
# Options identifying which trained network files are loaded further down:
# TS_opt, norm_method and exp_name are all pasted into the model file name.
#mod_size =  'xsmall96' #'mini', 'small', 'medium', 'large', 'extra_large'
TS_opt = 'extrap' # extrap, whole, thermocline
norm_method =  'std' # std,interquart, minmax
exp_name = 'newbasic2'#'allbutconstants' #'onlyTSdraftandslope' #'TSdraftbotandiceddandwcd' #'onlyTSisfdraft' #'TSdraftbotandiceddandwcdreldGL' #TSdraftslopereldGL

In [None]:
# Root directory under which the input and output data paths are built
home_path = '/bettik/burgardc/'

In [None]:
# Model and scenario for which the time span (yystart..yyend) is configured in this cell.
# NOTE: later cells overwrite `scenario` and loop over `mod` themselves.
mod = 'GFDL-CM4'
scenario = 'historical'

# Models for which the flag to2300 is True / False.
models_to_2300 = ['GISS-E2-1-H','ACCESS-CM2','ACCESS-ESM1-5','CESM2-WACCM','CanESM5',
                  'IPSL-CM6A-LR','MRI-ESM2-0','UKESM1-0-LL']
models_to_2100 = ['CNRM-CM6-1','CNRM-ESM2-1','MPI-ESM1-2-HR','GFDL-CM4','GFDL-ESM4','CESM2']

if mod in models_to_2300:
    to2300 = True
elif mod in models_to_2100:
    to2300 = False
else:
    # fail early instead of leaving to2300 undefined (or stale from a previous run)
    raise ValueError('Unknown model "'+str(mod)+'": add it to models_to_2300 or models_to_2100')

if scenario == 'historical':
    yystart = 1980
    yyend = 2014
    #yyend = 1981
elif scenario == 'ssp245':
    yystart = 2015
    yyend = 2100
else:
    # all other scenarios start in 2015 and end in 2300 only if the model is flagged to2300
    yystart = 2015
    yyend = 2300 if to2300 else 2100

READ IN DATA

In [None]:
# Year tag of the ElmerIce geometry: it selects the mask/plume/box/geometry input files
# and is appended to the output file names of the deep-ensemble cell.
geoyear = 2150

In [None]:
home_path = '/bettik/burgardc/'

# Input directories; masks, plumes and boxes are specific to the ElmerIce geometry year `geoyear`
inputpath_mask = home_path+'/DATA/SUMMER_PAPER/interim/ANTARCTICA_IS_MASKS/ElmerIce_'+str(geoyear)+'/'
inputpath_data=home_path+'/DATA/SUMMER_PAPER/interim/'
inputpath_CVinput = home_path+'/DATA/NN_PARAM/interim/INPUT_DATA/EXTRAPOLATED_ISFDRAFT_CHUNKS/'
inputpath_plumes = home_path+'/DATA/SUMMER_PAPER/interim/PLUMES/ElmerIce_'+str(geoyear)+'/'
inputpath_boxes = home_path+'/DATA/SUMMER_PAPER/interim/BOXES/ElmerIce_'+str(geoyear)+'/'

# make the domain a little smaller to make the computation even more efficient - file isf has already been made smaller at its creation
map_lim = [-3000000,3000000]

# Ice-shelf masks: keep the ice shelves with a finite 'front_bot_depth_max',
# then select them in the fixed order given by sorted_isf_rignot
file_isf_orig = xr.open_dataset(inputpath_mask+'ElmerIce_4km_'+str(geoyear)+'isf_masks_and_info_and_distance_oneFRIS.nc')
nonnan_Nisf = file_isf_orig['Nisf'].where(np.isfinite(file_isf_orig['front_bot_depth_max']), drop=True).astype(int)
file_isf_nonnan = file_isf_orig.sel(Nisf=nonnan_Nisf)
sorted_isf_rignot = [11,69,43,28,12,57,
                     70,44,29,13,58,71,45,30,14,
                     59,72,46,
                     31,
                     15,61,73,47,32,16,48,33,17,62,49,34,18,63,74,
                     50,35,19,64,
                     10,
                     36,20,65,51,37,
                     22,38,52,23,66,53,39,24,
                     67,40,54,75,25,41,
                     26,42,55,68,60,27]
file_isf = file_isf_nonnan.sel(Nisf=sorted_isf_rignot)
file_isf['isf_name'] = file_isf['isf_name'].astype(str)

rignot_isf = sorted_isf_rignot

# Geometry of the same year, cut to map_lim in both directions
inputpath_ElmerIce='/bettik/burgardc/DATA/SUMMER_PAPER/interim/ELMERICE_NEWGEO/'
BedMachine_orig = xr.open_dataset(inputpath_ElmerIce+'ElmerIce_4km_allvars_'+str(geoyear)+'.nc')
file_BedMachine = dfmt.cut_domain_stereo(BedMachine_orig, map_lim, map_lim)
# sign-flipped bed, and ice draft (thickness - surface) kept only where ISF_mask > 1
file_bed_goodGL = -1*file_BedMachine['bed']
file_draft = (file_BedMachine['thickness'] - file_BedMachine['surface']).where(file_isf['ISF_mask'] > 1)
file_isf_conc = file_BedMachine['isf_conc']

# True cell area on the ice-shelf grid, expressed relative to the nominal 4000 x 4000 cell
grid_cell_area_file = xr.open_dataset(inputpath_data+'gridarea_ISMIP6_AIS_4000m_grid.nc').sel(x=file_isf.x,y=file_isf.y)
true_grid_cell_area = grid_cell_area_file['cell_area'].drop_vars('lon').drop_vars('lat')
cell_area_weight = true_grid_cell_area/(4000 * 4000)

lon = file_isf.longitude
lat = file_isf.latitude

# Constant cell area from the grid spacing, then weighted by ice-shelf concentration and true-area ratio
xx = file_isf.x
yy = file_isf.y
dx = (xx[2] - xx[1]).values
dy = (yy[2] - yy[1]).values
grid_cell_area_const = abs(dx*dy)
grid_cell_area_weighted = file_isf_conc * grid_cell_area_const * cell_area_weight

ice_draft_pos = file_draft
ice_draft_neg = -ice_draft_pos

# Stacked ('mask_coord') version of the ice-shelf mask, used by the choose_isf helpers
isf_stack_mask = uf.create_stacked_mask(file_isf['ISF_mask'], file_isf.Nisf, ['y','x'], 'mask_coord')

# Box and plume characteristics; the 'lazero' option of the old plume file is replaced
# by the 'new_lazero' option of the new file and relabelled 'lazero'
box_charac_all_2D = xr.open_dataset(inputpath_boxes + 'ElmerIce_4km_'+str(geoyear)+'_boxes_2D_oneFRIS.nc')
box_charac_all_1D = xr.open_dataset(inputpath_boxes + 'ElmerIce_4km_'+str(geoyear)+'_boxes_1D_oneFRIS.nc')
plume_charac_old = xr.open_dataset(inputpath_plumes+'ElmerIce_'+str(geoyear)+'_plume_characteristics.nc')
plume_charac_new = xr.open_dataset(inputpath_plumes+'ElmerIce_'+str(geoyear)+'_plume_characteristics_lazero_comparison_mixedshift.nc')
plume_charac = xr.concat([plume_charac_old.drop_sel(option='lazero'),plume_charac_new.sel(option='new_lazero')], dim='option').assign_coords({'option': ['cavity','local','lazero']})

param_var_of_int_2D = file_isf[['ISF_mask', 'latitude', 'longitude', 'dGL']]
param_var_of_int_1D = file_isf[['front_bot_depth_avg', 'front_bot_depth_max','isf_name']]

# 2D geometry bundle used when integrating the melt (provides 'grid_cell_area_weighted')
geometry_info_2D = plume_charac.merge(param_var_of_int_2D).merge(ice_draft_pos.rename('ice_draft_pos')).merge(grid_cell_area_weighted.rename('grid_cell_area_weighted')).merge(file_isf_conc.rename('isfdraft_conc'))
geometry_info_1D = param_var_of_int_1D

# Normalisation metrics as a table (rows 'mean_vars' / 'range_vars' are read later)
norm_metrics_file = xr.open_dataset(inputpath_CVinput + 'metrics_norm_wholedataset_origexcept26_christoph_new.nc')
norm_metrics = norm_metrics_file.to_dataframe()

# Location of box 1 in box configuration 2 (cells where box_location == 1)
# NOTE(review): isel(Nisf=1) picks the second entry along Nisf - confirm this is intended
box_loc_config2 = box_charac_all_2D['box_location'].sel(box_nb_tot=box_charac_all_1D['nD_config'].sel(config=2))
# drop_vars replaces the deprecated .drop and matches the usage elsewhere in this cell
box1 = box_loc_config2.where(box_loc_config2==1).isel(Nisf=1).drop_vars('Nisf')

file_slope = xr.open_dataset(inputpath_mask+'ElmerIce_4km_'+str(geoyear)+'_slope_info_bedrock_draft_latlon_oneFRIS.nc')

ONLY ONE ENSEMBLE MEMBER OF THE NN DEEP ENSEMBLE

In [None]:
### APPLY MODEL

# Predictors handed to the network, in this column order
input_vars = [
    'dGL',
    'dIF',
    'corrected_isfdraft',
    'bathy_metry',
    'slope_bed_lon',
    'slope_bed_lat',
    'slope_ice_lon',
    'slope_ice_lat',
    'theta_in',
    'salinity_in',
    'T_mean',
    'S_mean',
    'T_std',
    'S_std',
]

# Which set of trained networks to read: 'new' or 'old'
tuning_sort = 'new' #'old'

### use any model from CV over time
if tuning_sort == 'old':
    path_model = '/bettik/burgardc/DATA/NN_PARAM/interim/NN_MODELS/experiments/WHOLE/'
elif tuning_sort == 'new':
    path_model = '/bettik/burgardc/DATA/SUMMER_PAPER/interim/NN_MODELS/'

# Scenario processed by the loop of this cell
scenario = 'ssp245'

### APPLY MODEL

def apply_NN_results_2D_allisf_allyears(file_isf, norm_metrics, all_info_all, model, input_vars=[], batch_size=None):
    """
    Predict the 2D melt with one NN model for all entries of ``all_info_all`` at once.

    Parameters
    ----------
    file_isf
        Dataset providing 'ISF_mask'; the result is reindexed like this mask.
    norm_metrics
        Table of normalisation metrics with rows 'mean_vars' and 'range_vars', one column per
        input variable and a column 'melt_m_ice_per_y' for the output.
    all_info_all
        Object with a ``to_dataframe()`` method holding at least the variables in ``input_vars``.
    model
        Trained model with a Keras-style ``predict(x, batch_size=..., verbose=...)`` method.
    input_vars : list of str
        Names of the predictor columns, in the order expected by ``model``.
        The default list is never modified by this function.
    batch_size : int or None
        Forwarded to ``model.predict``. ``None`` (default) lets the model use its own default,
        which is what happened before this parameter existed.

    Returns
    -------
    Denormalised prediction brought back to 2D and reindexed like ``file_isf['ISF_mask']``.
    """

    all_info_df = all_info_all.to_dataframe()
    x_val_norm = pp.normalise_vars(all_info_df[input_vars],
                                norm_metrics[input_vars].loc['mean_vars'],
                                norm_metrics[input_vars].loc['range_vars'])

    y_out_norm = model.predict(x_val_norm.values.astype('float64'),batch_size=batch_size,verbose = 1)

    # label the flat prediction vector with the index of the input table
    y_out_norm_xr = xr.DataArray(data=y_out_norm.squeeze()).rename({'dim_0': 'index'})
    y_out_norm_xr = y_out_norm_xr.assign_coords({'index': x_val_norm.index})

    y_out_norm_xr_2D = uf.bring_back_to_2D(y_out_norm_xr)

    # denormalise the output
    y_out_xr = pp.denormalise_vars(y_out_norm_xr_2D, 
                             norm_metrics['melt_m_ice_per_y'].loc['mean_vars'],
                             norm_metrics['melt_m_ice_per_y'].loc['range_vars'])

    y_whole_grid = y_out_xr.reindex_like(file_isf['ISF_mask'])
    return y_whole_grid

### If no ensemble of models
# Models for which the flag to2300 is True / False
models_to_2300 = ['GISS-E2-1-H','ACCESS-CM2','ACCESS-ESM1-5','CESM2-WACCM','CanESM5',
                  'IPSL-CM6A-LR','MRI-ESM2-0','UKESM1-0-LL']
models_to_2100 = ['CNRM-CM6-1','CNRM-ESM2-1','MPI-ESM1-2-HR','GFDL-CM4','GFDL-ESM4','CESM2']

for mod in ['ACCESS-CM2','ACCESS-ESM1-5','CESM2','CESM2-WACCM','CNRM-CM6-1','CNRM-ESM2-1',
           'CanESM5','GFDL-CM4','GFDL-ESM4','GISS-E2-1-H','IPSL-CM6A-LR','MPI-ESM1-2-HR',
           'MRI-ESM2-0','UKESM1-0-LL']: #,,
    print(mod)

    if mod in models_to_2300:
        to2300 = True
    elif mod in models_to_2100:
        to2300 = False
    else:
        # fail early instead of silently reusing to2300 from the previous model
        raise ValueError('Unknown model "'+str(mod)+'": add it to models_to_2300 or models_to_2100')

    if scenario == 'historical':
        yystart = 1980 #1850
        yyend = 2014
    elif scenario == 'ssp245':
        yystart = 2015
        yyend = 2100
    else:
        yystart = 2015
        yyend = 2300 if to2300 else 2100

    inputpath_profiles='/bettik/burgardc/DATA/SUMMER_PAPER/interim/T_S_PROF/CMIP/'+mod+'/'
    outputpath_melt = home_path+'/DATA/SUMMER_PAPER/processed/OCEAN_MELT_RATE_CMIP/'+mod+'/'

    # PREPARE VARIABLES
    # one temperature and one salinity profile file per year, merged and stacked along 'time'
    file_TS_list = []
    for tt in range(yystart,yyend+1):
        file_T_orig = xr.open_dataset(inputpath_profiles+'T_mean_prof_50km_contshelf_'+mod+'_'+scenario+'_'+str(tt)+'.nc')
        file_S_orig = xr.open_dataset(inputpath_profiles+'S_mean_prof_50km_contshelf_'+mod+'_'+scenario+'_'+str(tt)+'.nc')
        file_TS_orig = xr.merge([file_T_orig.rename({'thetao':'theta_ocean'}), file_S_orig.rename({'so':'salinity_ocean'})]).sel(Nisf=rignot_isf).assign_coords({'time': tt})
        file_TS_list.append(file_TS_orig)
    file_TS = xr.concat(file_TS_list, dim='time').rename({'z':'depth'})
    # flip the sign of the depth axis, add a level at depth 0 and fill gaps along depth
    file_TS['depth'] = -1*file_TS['depth']
    depth_axis_old = file_TS.depth.values
    depth_axis_new = np.concatenate((np.zeros(1),depth_axis_old))
    file_TS_with_shallow = file_TS.interp({'depth': depth_axis_new})
    filled_TS = file_TS_with_shallow.ffill(dim='depth').bfill(dim='depth').sel(Nisf=file_isf.Nisf) #, 'profile_domain': 1})

    print('Prepare input variables')
    n = 0
    for kisf in tqdm(file_isf.Nisf):
        # ice draft of this ice shelf, limited to the range
        # [front_ice_depth_min, front_bot_depth_max] before the profiles are sampled
        depth_kisf = uf.choose_isf(ice_draft_pos,isf_stack_mask, kisf)
        depth_of_int0 = depth_kisf.where(depth_kisf < file_isf['front_bot_depth_max'].sel(Nisf=kisf), 
                                       file_isf['front_bot_depth_max'].sel(Nisf=kisf))
        depth_of_int = depth_of_int0.where(depth_kisf > file_isf['front_ice_depth_min'].sel(Nisf=kisf), 
                                       file_isf['front_ice_depth_min'].sel(Nisf=kisf))

        # drop_vars replaces the deprecated .drop (same call as in the deep-ensemble cell)
        T_isf = filled_TS['theta_ocean'].sel(Nisf=kisf).interp({'depth': depth_of_int}).drop_vars('depth')
        S_isf = filled_TS['salinity_ocean'].sel(Nisf=kisf).interp({'depth': depth_of_int}).drop_vars('depth')

        # weights for the cavity statistics: true-area ratio times ice-shelf concentration
        cell_area_kisf = uf.choose_isf(cell_area_weight ,isf_stack_mask, kisf) 
        isf_conc_kisf = uf.choose_isf(file_isf_conc,isf_stack_mask, kisf) 
        weight_kisf = cell_area_kisf * isf_conc_kisf

        T_mean_cav = uf.weighted_mean(T_isf, 'mask_coord', weight_kisf).to_dataset(name='T_mean')
        S_mean_cav = uf.weighted_mean(S_isf, 'mask_coord', weight_kisf).to_dataset(name='S_mean')
        T_std_cav = uf.weighted_std(T_isf, 'mask_coord', weight_kisf).to_dataset(name='T_std')
        S_std_cav = uf.weighted_std(S_isf, 'mask_coord', weight_kisf).to_dataset(name='S_std')
        T_S_2D_meanstd_kisf = xr.merge([T_mean_cav,S_mean_cav,T_std_cav,S_std_cav])

        T_S_info_br, TSmean_br = xr.broadcast(xr.merge([T_isf.rename('theta_in'),S_isf.rename('salinity_in')]),T_S_2D_meanstd_kisf)
        TS_info_all = xr.merge([T_S_info_br, TSmean_br])

        # geometric predictors of this ice shelf
        file_isf_kisf = uf.choose_isf(file_isf[['dGL', 'dIF']], isf_stack_mask, kisf)
        bathy_kisf = uf.choose_isf(file_bed_goodGL, isf_stack_mask, kisf)
        slope_kisf = uf.choose_isf(file_slope, isf_stack_mask, kisf)
        geometry_kisf = xr.merge([file_isf_kisf,
                             depth_kisf.rename('corrected_isfdraft'),
                             bathy_kisf.rename('bathy_metry'),
                             slope_kisf])

        geometry_2D_br, time_dpdt_in_br = xr.broadcast(geometry_kisf,TS_info_all)
        all_info = xr.merge([geometry_2D_br, time_dpdt_in_br])

        # accumulate the inputs of all ice shelves in one dataset
        if n == 0:
            all_info_all = all_info.squeeze().drop_vars('Nisf')
        else:
            all_info_all =  all_info_all.combine_first(all_info).squeeze().drop_vars('Nisf')
        n = n+1

    for mod_size in ['xsmall96','small','large']: #,

        print('NN size',mod_size)
        # only the first member (seed 01) of each network size is used here
        seed_nb = 1
        model = keras.models.load_model(path_model + 'model_nn_'+mod_size+'_'+exp_name+'_wholedataset_'+str(seed_nb).zfill(2)+'_TS'+TS_opt+'_norm'+norm_method+'.h5')

        print('Computing 2D patterns')
        xr_ensmean_res2D = apply_NN_results_2D_allisf_allyears(file_isf, norm_metrics, all_info_all, model, input_vars)

        print('Compute the 1D evalmetrics')
        res_1D_list = []
        for kisf in tqdm(file_isf.Nisf.values): 

            geometry_isf_2D = dfmt.choose_isf(geometry_info_2D,isf_stack_mask, kisf)
            melt_rate_2D_isf_m_per_y = dfmt.choose_isf(xr_ensmean_res2D,isf_stack_mask, kisf)

            # metric 'Gt': melt times weighted cell area, summed over the ice shelf, times rho_i / 10**12
            melt_rate_1D_isf_Gt_per_y = (melt_rate_2D_isf_m_per_y * geometry_isf_2D['grid_cell_area_weighted']).sum(dim=['mask_coord']) * rho_i / 10**12

            # metric 'box1': area-weighted mean melt over the cells where box1 is finite
            box_loc_config_stacked = dfmt.choose_isf(box1, isf_stack_mask, kisf)
            param_melt_2D_box1_isf = melt_rate_2D_isf_m_per_y.where(np.isfinite(box_loc_config_stacked))

            melt_rate_1D_isf_myr_box1_mean = dfmt.weighted_mean(param_melt_2D_box1_isf,['mask_coord'], geometry_isf_2D['grid_cell_area_weighted'])

            out_1D = xr.concat([melt_rate_1D_isf_Gt_per_y, melt_rate_1D_isf_myr_box1_mean], dim='metrics').assign_coords({'metrics': ['Gt','box1']})
            res_1D_list.append(out_1D) 

        res_1D_tt = xr.concat(res_1D_list, dim='Nisf')

        res_1D_tt.rename('predicted_melt').to_netcdf(outputpath_melt + 'evalmetrics_1D_NN'+mod_size+'_'+scenario+'.nc')

RUN DEEP ENSEMBLE

In [None]:
### APPLY MODEL

# Predictors handed to the network, in this column order
input_vars = [
    'dGL',
    'dIF',
    'corrected_isfdraft',
    'bathy_metry',
    'slope_bed_lon',
    'slope_bed_lat',
    'slope_ice_lon',
    'slope_ice_lat',
    'theta_in',
    'salinity_in',
    'T_mean',
    'S_mean',
    'T_std',
    'S_std',
]

# Which set of trained networks to read: 'new' or 'old'
tuning_sort = 'new' #'old'

### use any model from CV over time
if tuning_sort == 'old':
    path_model = '/bettik/burgardc/DATA/NN_PARAM/interim/NN_MODELS/experiments/WHOLE/'
elif tuning_sort == 'new':
    path_model = '/bettik/burgardc/DATA/SUMMER_PAPER/interim/NN_MODELS/'

# Scenario processed by the loop of this cell
scenario = 'ssp585'
# Choice of reference profiles in the loop below; ano_ISMIP is tested first, then ano_NEMO
ano_ISMIP, ano_NEMO = True, True

### APPLY MODEL

def apply_NN_results_2D_allisf_allyears(file_isf, norm_metrics, all_info_all, model, input_vars=[], batch_size=4096):
    """
    Predict the 2D melt with one NN model for all entries of ``all_info_all`` at once.

    Parameters
    ----------
    file_isf
        Dataset providing 'ISF_mask'; the result is reindexed like this mask.
    norm_metrics
        Table of normalisation metrics with rows 'mean_vars' and 'range_vars', one column per
        input variable and a column 'melt_m_ice_per_y' for the output.
    all_info_all
        Object with a ``to_dataframe()`` method holding at least the variables in ``input_vars``.
    model
        Trained model with a Keras-style ``predict(x, batch_size=..., verbose=...)`` method.
    input_vars : list of str
        Names of the predictor columns, in the order expected by ``model``.
        The default list is never modified by this function.
    batch_size : int
        Number of samples per prediction batch, forwarded to ``model.predict``.
        Defaults to 4096, the value that used to be hard-coded here.

    Returns
    -------
    Denormalised prediction brought back to 2D and reindexed like ``file_isf['ISF_mask']``.
    """

    all_info_df = all_info_all.to_dataframe()
    x_val_norm = pp.normalise_vars(all_info_df[input_vars],
                                norm_metrics[input_vars].loc['mean_vars'],
                                norm_metrics[input_vars].loc['range_vars'])

    y_out_norm = model.predict(x_val_norm.values.astype('float64'),batch_size=batch_size,verbose = 1)

    # label the flat prediction vector with the index of the input table
    y_out_norm_xr = xr.DataArray(data=y_out_norm.squeeze()).rename({'dim_0': 'index'})
    y_out_norm_xr = y_out_norm_xr.assign_coords({'index': x_val_norm.index})

    y_out_norm_xr_2D = uf.bring_back_to_2D(y_out_norm_xr)

    # denormalise the output
    y_out_xr = pp.denormalise_vars(y_out_norm_xr_2D, 
                             norm_metrics['melt_m_ice_per_y'].loc['mean_vars'],
                             norm_metrics['melt_m_ice_per_y'].loc['range_vars'])

    y_whole_grid = y_out_xr.reindex_like(file_isf['ISF_mask'])
    return y_whole_grid

### if we want the "real" profiles and not ISMIP + anomalies
# Reference temperature/salinity profiles that the loop below subtracts from / adds to
# the yearly CMIP profiles, depending on ano_ISMIP and ano_NEMO.
inputpath_profiles='/bettik/burgardc/DATA/SUMMER_PAPER/interim/T_S_PROF/CMIP/'
T_ISMIP = xr.open_dataset(inputpath_profiles + 'T_mean_prof_50km_contshelf_ISMIP.nc')
S_ISMIP = xr.open_dataset(inputpath_profiles + 'S_mean_prof_50km_contshelf_ISMIP.nc')
TS_ISMIP = xr.merge([T_ISMIP.rename({'thetao':'theta_ocean'}),S_ISMIP.rename({'so':'salinity_ocean'})])
    
# NEMO profiles: profile_domain 50 only, averaged over time, depth sign flipped and
# interpolated onto the vertical axis 'z' of the ISMIP profiles
inputpath_profiles_NEMO = '/bettik/burgardc/DATA/BASAL_MELT_PARAM/interim/T_S_PROF/nemo_5km_OPM026/'
file_TS_orig = xr.open_dataset(inputpath_profiles_NEMO+'T_S_mean_prof_corrected_km_contshelf_1980-2018_rignotisf.nc')
TS_NEMO = file_TS_orig.sel(profile_domain=[50]).squeeze().drop_vars('profile_domain').mean('time')
TS_NEMO['depth'] = -1*TS_NEMO['depth']
TS_NEMO_rightaxis = TS_NEMO.rename({'depth':'z'}).interp({'z': TS_ISMIP.z})
# The ISMIP profiles of ice shelves 36 and 62 are appended along Nisf
# NOTE(review): presumably because these two are missing from the NEMO file - confirm
TS_NEMO_withISMIP = xr.concat([TS_NEMO_rightaxis,TS_ISMIP.sel(Nisf=[36,62])], dim='Nisf')

### If no ensemble of models
# Models for which the flag to2300 is True / False
models_to_2300 = ['GISS-E2-1-H','ACCESS-CM2','ACCESS-ESM1-5','CESM2-WACCM','CanESM5',
                  'IPSL-CM6A-LR','MRI-ESM2-0','UKESM1-0-LL']
models_to_2100 = ['CNRM-CM6-1','CNRM-ESM2-1','MPI-ESM1-2-HR','GFDL-CM4','GFDL-ESM4','CESM2']

# Seeds of the deep-ensemble members; the ensemble mean below is taken over exactly these
seed_numbers = range(1,11)

for mod in ['ACCESS-CM2','ACCESS-ESM1-5','CESM2','CESM2-WACCM','CNRM-CM6-1','CNRM-ESM2-1',
           'CanESM5','GFDL-CM4','GFDL-ESM4','GISS-E2-1-H','IPSL-CM6A-LR','MPI-ESM1-2-HR','MRI-ESM2-0',
           'UKESM1-0-LL']: #,,
    print(mod)

    if mod in models_to_2300:
        to2300 = True
    elif mod in models_to_2100:
        to2300 = False
    else:
        # fail early instead of silently reusing to2300 from the previous model
        raise ValueError('Unknown model "'+str(mod)+'": add it to models_to_2300 or models_to_2100')

    # a geometry year beyond 2100 is only processed for models flagged to2300
    if (geoyear > 2100) and (not to2300):
        continue

    if scenario == 'historical':
        yystart = 1850 #1980 #1850
        yyend = 2014
    elif scenario == 'ssp245':
        yystart = 2015
        yyend = 2100
    else:
        if to2300:
            yystart = 2015
            yyend = 2300
            # for a geometry year beyond 2100, start at the geometry year
            if (geoyear > 2100):
                yystart = geoyear
        else:
            yystart = 2015
            yyend = 2100

    inputpath_profiles='/bettik/burgardc/DATA/SUMMER_PAPER/interim/T_S_PROF/CMIP/'+mod+'/'
    outputpath_melt = home_path+'/DATA/SUMMER_PAPER/processed/OCEAN_MELT_RATE_CMIP/'+mod+'/'

    # climatology profiles of this model
    T_Clim = xr.open_dataset(inputpath_profiles + 'T_mean_prof_50km_contshelf_'+mod+'_clim.nc')
    S_Clim = xr.open_dataset(inputpath_profiles + 'S_mean_prof_50km_contshelf_'+mod+'_clim.nc')
    TS_Clim = xr.merge([T_Clim.rename({'thetao':'theta_ocean'}),S_Clim.rename({'so':'salinity_ocean'})])

    # PREPARE VARIABLES
    # one temperature and one salinity profile file per year, merged and stacked along 'time'
    file_TS_list = []
    for tt in range(yystart,yyend+1):
        file_T_orig = xr.open_dataset(inputpath_profiles+'T_mean_prof_50km_contshelf_'+mod+'_'+scenario+'_'+str(tt)+'.nc')
        file_S_orig = xr.open_dataset(inputpath_profiles+'S_mean_prof_50km_contshelf_'+mod+'_'+scenario+'_'+str(tt)+'.nc')
        file_TS_orig = xr.merge([file_T_orig.rename({'thetao':'theta_ocean'}), file_S_orig.rename({'so':'salinity_ocean'})]).sel(Nisf=rignot_isf).assign_coords({'time': tt})
        file_TS_list.append(file_TS_orig)

    file_TS_0 = xr.concat(file_TS_list, dim='time')
    if ano_ISMIP:
        # the two "*0" terms add zero to the values
        # NOTE(review): presumably kept so that alignment / missing values match the other branches - confirm
        file_TS = (file_TS_0 - TS_ISMIP*0 + TS_Clim*0).rename({'z':'depth'})
    elif ano_NEMO:
        file_TS = (file_TS_0 - TS_ISMIP + TS_NEMO_withISMIP).rename({'z':'depth'})
    else:
        ### if we want the "real" profiles and not ISMIP + anomalies
        file_TS = (file_TS_0 - TS_ISMIP + TS_Clim).rename({'z':'depth'})
        ###

    # flip the sign of the depth axis, add a level at depth 0 and fill gaps along depth
    file_TS['depth'] = -1*file_TS['depth']
    depth_axis_old = file_TS.depth.values
    depth_axis_new = np.concatenate((np.zeros(1),depth_axis_old))
    file_TS_with_shallow = file_TS.interp({'depth': depth_axis_new})
    filled_TS = file_TS_with_shallow.ffill(dim='depth').bfill(dim='depth').sel(Nisf=file_isf.Nisf) #, 'profile_domain': 1})

    print('Prepare input variables')
    n = 0
    for kisf in tqdm(file_isf.Nisf):
        # ice draft of this ice shelf, limited to the range
        # [front_ice_depth_min, front_bot_depth_max] before the profiles are sampled
        depth_kisf = uf.choose_isf(ice_draft_pos,isf_stack_mask, kisf)
        depth_of_int0 = depth_kisf.where(depth_kisf < file_isf['front_bot_depth_max'].sel(Nisf=kisf), 
                                       file_isf['front_bot_depth_max'].sel(Nisf=kisf))
        depth_of_int = depth_of_int0.where(depth_kisf > file_isf['front_ice_depth_min'].sel(Nisf=kisf), 
                                       file_isf['front_ice_depth_min'].sel(Nisf=kisf))

        T_isf = filled_TS['theta_ocean'].sel(Nisf=kisf).interp({'depth': depth_of_int}).drop_vars('depth')
        S_isf = filled_TS['salinity_ocean'].sel(Nisf=kisf).interp({'depth': depth_of_int}).drop_vars('depth')

        # weights for the cavity statistics: true-area ratio times ice-shelf concentration
        cell_area_kisf = uf.choose_isf(cell_area_weight ,isf_stack_mask, kisf) 
        isf_conc_kisf = uf.choose_isf(file_isf_conc,isf_stack_mask, kisf) 
        weight_kisf = cell_area_kisf * isf_conc_kisf

        T_mean_cav = uf.weighted_mean(T_isf, 'mask_coord', weight_kisf).to_dataset(name='T_mean')
        S_mean_cav = uf.weighted_mean(S_isf, 'mask_coord', weight_kisf).to_dataset(name='S_mean')
        T_std_cav = uf.weighted_std(T_isf, 'mask_coord', weight_kisf).to_dataset(name='T_std')
        S_std_cav = uf.weighted_std(S_isf, 'mask_coord', weight_kisf).to_dataset(name='S_std')
        T_S_2D_meanstd_kisf = xr.merge([T_mean_cav,S_mean_cav,T_std_cav,S_std_cav])

        T_S_info_br, TSmean_br = xr.broadcast(xr.merge([T_isf.rename('theta_in'),S_isf.rename('salinity_in')]),T_S_2D_meanstd_kisf)
        TS_info_all = xr.merge([T_S_info_br, TSmean_br])

        # geometric predictors of this ice shelf
        file_isf_kisf = uf.choose_isf(file_isf[['dGL', 'dIF']], isf_stack_mask, kisf)
        bathy_kisf = uf.choose_isf(file_bed_goodGL, isf_stack_mask, kisf)
        slope_kisf = uf.choose_isf(file_slope, isf_stack_mask, kisf)
        geometry_kisf = xr.merge([file_isf_kisf,
                             depth_kisf.rename('corrected_isfdraft'),
                             bathy_kisf.rename('bathy_metry'),
                             slope_kisf])

        geometry_2D_br, time_dpdt_in_br = xr.broadcast(geometry_kisf,TS_info_all)
        all_info = xr.merge([geometry_2D_br, time_dpdt_in_br])

        # accumulate the inputs of all ice shelves in one dataset
        if n == 0:
            all_info_all = all_info.squeeze().drop_vars('Nisf')
        else:
            all_info_all =  all_info_all.combine_first(all_info).squeeze().drop_vars('Nisf')
        n = n+1

    mod_size = 'small'

    # Ensemble mean as a running sum over the members, divided by the number of members
    # (the members are not kept individually)
    res2D_sum = 0
    for seed_nb in seed_numbers:
        print('seed_nb',seed_nb)

        # load without the stored training configuration, then compile explicitly
        model = keras.models.load_model(path_model + 'model_nn_'+mod_size+'_'+exp_name+'_wholedataset_'+str(seed_nb).zfill(2)+'_TS'+TS_opt+'_norm'+norm_method+'.h5', 
                                        compile=False)
        model.compile(optimizer="adam",loss=tf.keras.losses.MeanSquaredError(),metrics=[tf.keras.metrics.MeanSquaredError()])

        print('Computing 2D patterns')
        res2D = apply_NN_results_2D_allisf_allyears(file_isf, norm_metrics, all_info_all, model, input_vars)
        res2D_sum = res2D_sum + res2D

    xr_ensmean_res2D = res2D_sum / len(seed_numbers)

    print('Compute the 1D evalmetrics')
    res_1D_list = []
    for kisf in tqdm(file_isf.Nisf.values): 

        geometry_isf_2D = dfmt.choose_isf(geometry_info_2D,isf_stack_mask, kisf)
        melt_rate_2D_isf_m_per_y = dfmt.choose_isf(xr_ensmean_res2D,isf_stack_mask, kisf)

        # metric 'Gt': melt times weighted cell area, summed over the ice shelf, times rho_i / 10**12
        melt_rate_1D_isf_Gt_per_y = (melt_rate_2D_isf_m_per_y * geometry_isf_2D['grid_cell_area_weighted']).sum(dim=['mask_coord']) * rho_i / 10**12

        # metric 'box1': area-weighted mean melt over the cells where box1 is finite
        box_loc_config_stacked = dfmt.choose_isf(box1, isf_stack_mask, kisf)
        param_melt_2D_box1_isf = melt_rate_2D_isf_m_per_y.where(np.isfinite(box_loc_config_stacked))

        melt_rate_1D_isf_myr_box1_mean = dfmt.weighted_mean(param_melt_2D_box1_isf,['mask_coord'], geometry_isf_2D['grid_cell_area_weighted'])

        out_1D = xr.concat([melt_rate_1D_isf_Gt_per_y, melt_rate_1D_isf_myr_box1_mean], dim='metrics').assign_coords({'metrics': ['Gt','box1']})
        res_1D_list.append(out_1D) 

    res_1D_tt = xr.concat(res_1D_list, dim='Nisf')

    # the file name records which reference profiles were used (same precedence as above)
    if ano_ISMIP:
        profile_tag = '_anoISMIP'
    elif ano_NEMO:
        profile_tag = '_anoNEMO'
    else:
        profile_tag = '_notISMIP'
    res_1D_tt.rename('predicted_melt').to_netcdf(outputpath_melt + 'evalmetrics_1D_NN'+mod_size+'_'+scenario+profile_tag+'_ElmerIcegeo'+str(geoyear)+'.nc')


In [None]:
# Debug cell: displays the geometry year currently set
geoyear

################## DO NOT USE ANYMORE!!!!

In [None]:
# Network size used by the legacy cell below
mod_size = 'xsmall96'

### APPLY MODEL

# Predictors handed to the network, in this column order
input_vars = [
    'dGL',
    'dIF',
    'corrected_isfdraft',
    'bathy_metry',
    'slope_bed_lon',
    'slope_bed_lat',
    'slope_ice_lon',
    'slope_ice_lat',
    'theta_in',
    'salinity_in',
    'T_mean',
    'S_mean',
    'T_std',
    'S_std',
]

# Which set of trained networks to read: 'new' or 'old'
tuning_sort = 'new' #'old'

### use any model from CV over time
# output directory of the model currently held in `mod`
outputpath_melt = home_path+'/DATA/SUMMER_PAPER/processed/OCEAN_MELT_RATE_CMIP/'+mod+'/'
if tuning_sort == 'old':
    path_model = '/bettik/burgardc/DATA/NN_PARAM/interim/NN_MODELS/experiments/WHOLE/'
elif tuning_sort == 'new':
    path_model = '/bettik/burgardc/DATA/SUMMER_PAPER/interim/NN_MODELS/'

def apply_NN_results_2D_1isf_1year(file_isf, norm_metrics, df_nrun, model, input_vars=[]):
    """
    Run a trained NN on a table of inputs and return the predicted melt on the ice-shelf grid.

    The columns ``input_vars`` of ``df_nrun`` are normalised with the 'mean_vars' and
    'range_vars' rows of ``norm_metrics`` and passed to ``model.predict``. The prediction is
    denormalised with the 'melt_m_ice_per_y' metrics, converted to an xarray object named
    'predicted_melt' via the index of ``df_nrun``, sorted by 'y' and reindexed like
    ``file_isf['ISF_mask']``.
    """
    predictors_norm = pp.normalise_vars(
        df_nrun[input_vars],
        norm_metrics[input_vars].loc['mean_vars'],
        norm_metrics[input_vars].loc['range_vars'],
    )

    prediction_norm = model.predict(predictors_norm.values.astype('float64'), verbose=0)

    # wrap the prediction and label it with the index of the normalised input table
    prediction_norm_xr = (
        xr.DataArray(data=prediction_norm.squeeze())
        .rename({'dim_0': 'index'})
        .assign_coords({'index': predictors_norm.index})
    )

    # back to physical values
    prediction = pp.denormalise_vars(
        prediction_norm_xr,
        norm_metrics['melt_m_ice_per_y'].loc['mean_vars'],
        norm_metrics['melt_m_ice_per_y'].loc['range_vars'],
    )

    melt_series = pd.Series(prediction.values, index=df_nrun.index, name='predicted_melt')

    # ordered along 'y', then placed on the full mask grid
    return melt_series.to_xarray().sortby('y').reindex_like(file_isf['ISF_mask'])


In [None]:
    # Legacy fragment: applies the first member (seed 01) of the network size `mod_size`
    # to yearly CSV input tables. Only the 'historical' scenario is handled here.
    # NOTE(review): `inputpath_csv` is not defined in the visible part of this notebook.
    print('NN size',mod_size)
    seed_nb = 1
    model = keras.models.load_model(path_model + 'model_nn_'+mod_size+'_'+exp_name+'_wholedataset_'+str(seed_nb).zfill(2)+'_TS'+TS_opt+'_norm'+norm_method+'.h5')
    
    if scenario == 'historical':
        print('prepare 2D input')
        li = []
        # one input table per year; the first three columns form the index
        for tt in tqdm(range(yystart,yyend+1)):    
            #print('Combining all dataframes')
        
            df_nrun = pd.read_csv(inputpath_csv + 'dataframe_input_allisf_'+mod+'_'+scenario+'_'+str(tt)+'.csv',index_col=[0,1,2])
            li.append(df_nrun)

        print('Concatenating input')
        df_allyy = pd.concat(li)
        
        print('Computing 2D patterns')
        xr_ensmean_res2D = apply_NN_results_2D_1isf_1year(file_isf, norm_metrics, df_allyy, model, input_vars)

In [None]:
### APPLY MODEL

# Predictors handed to the network, in this column order
input_vars = [
    'dGL',
    'dIF',
    'corrected_isfdraft',
    'bathy_metry',
    'slope_bed_lon',
    'slope_bed_lat',
    'slope_ice_lon',
    'slope_ice_lat',
    'theta_in',
    'salinity_in',
    'T_mean',
    'S_mean',
    'T_std',
    'S_std',
]

# Which set of trained networks to read: 'new' or 'old'
tuning_sort = 'new' #'old'

### use any model from CV over time
# output directory of the model currently held in `mod`
outputpath_melt = home_path+'/DATA/SUMMER_PAPER/processed/OCEAN_MELT_RATE_CMIP/'+mod+'/'
if tuning_sort == 'old':
    path_model = '/bettik/burgardc/DATA/NN_PARAM/interim/NN_MODELS/experiments/WHOLE/'
elif tuning_sort == 'new':
    path_model = '/bettik/burgardc/DATA/SUMMER_PAPER/interim/NN_MODELS/'

def apply_NN_results_2D_1isf_1year(file_isf, norm_metrics, df_nrun, model, input_vars=[]):
    """
    Run a trained NN on a table of inputs and return the predicted melt on the ice-shelf grid.

    The columns ``input_vars`` of ``df_nrun`` are normalised with the 'mean_vars' and
    'range_vars' rows of ``norm_metrics`` and passed to ``model.predict``. The prediction is
    denormalised with the 'melt_m_ice_per_y' metrics, converted to an xarray object named
    'predicted_melt' via the index of ``df_nrun``, sorted by 'y' and reindexed like
    ``file_isf['ISF_mask']``.
    """
    predictors_norm = pp.normalise_vars(
        df_nrun[input_vars],
        norm_metrics[input_vars].loc['mean_vars'],
        norm_metrics[input_vars].loc['range_vars'],
    )

    prediction_norm = model.predict(predictors_norm.values.astype('float64'), verbose=0)

    # wrap the prediction and label it with the index of the normalised input table
    prediction_norm_xr = (
        xr.DataArray(data=prediction_norm.squeeze())
        .rename({'dim_0': 'index'})
        .assign_coords({'index': predictors_norm.index})
    )

    # back to physical values
    prediction = pp.denormalise_vars(
        prediction_norm_xr,
        norm_metrics['melt_m_ice_per_y'].loc['mean_vars'],
        norm_metrics['melt_m_ice_per_y'].loc['range_vars'],
    )

    melt_series = pd.Series(prediction.values, index=df_nrun.index, name='predicted_melt')

    # ordered along 'y', then placed on the full mask grid
    return melt_series.to_xarray().sortby('y').reindex_like(file_isf['ISF_mask'])

### If no ensemble of models

def _evalmetrics_1D_from_2D(melt_2D, show_progress=False):
    """
    Reduce a 2D melt field to the 1D metrics 'Gt' and 'box1' for every ice shelf in file_isf.

    'Gt' is the melt times the weighted cell area, summed over the ice shelf, times rho_i / 10**12.
    'box1' is the area-weighted mean melt over the cells where box1 is finite.
    Returns the per-ice-shelf results concatenated along 'Nisf'.
    """
    kisf_values = file_isf.Nisf.values
    if show_progress:
        kisf_values = tqdm(kisf_values)

    res_1D_list = []
    for kisf in kisf_values:
        geometry_isf_2D = dfmt.choose_isf(geometry_info_2D,isf_stack_mask, kisf)
        melt_rate_2D_isf_m_per_y = dfmt.choose_isf(melt_2D,isf_stack_mask, kisf)

        melt_rate_1D_isf_Gt_per_y = (melt_rate_2D_isf_m_per_y * geometry_isf_2D['grid_cell_area_weighted']).sum(dim=['mask_coord']) * rho_i / 10**12

        box_loc_config_stacked = dfmt.choose_isf(box1, isf_stack_mask, kisf)
        param_melt_2D_box1_isf = melt_rate_2D_isf_m_per_y.where(np.isfinite(box_loc_config_stacked))

        melt_rate_1D_isf_myr_box1_mean = dfmt.weighted_mean(param_melt_2D_box1_isf,['mask_coord'], geometry_isf_2D['grid_cell_area_weighted'])

        out_1D = xr.concat([melt_rate_1D_isf_Gt_per_y, melt_rate_1D_isf_myr_box1_mean], dim='metrics').assign_coords({'metrics': ['Gt','box1']})
        res_1D_list.append(out_1D)

    return xr.concat(res_1D_list, dim='Nisf')


# NOTE(review): `inputpath_csv` must be defined before running this cell; it is not set
# anywhere in the visible part of this notebook.
for mod_size in ['xsmall96','small','large']: #

    print('NN size',mod_size)
    # only the first member (seed 01) of each network size is used here
    seed_nb = 1
    model = keras.models.load_model(path_model + 'model_nn_'+mod_size+'_'+exp_name+'_wholedataset_'+str(seed_nb).zfill(2)+'_TS'+TS_opt+'_norm'+norm_method+'.h5')

    if scenario == 'historical':
        print('prepare 2D input')
        li = []
        for tt in tqdm(range(yystart,yyend+1)):
            df_nrun = pd.read_csv(inputpath_csv + 'dataframe_input_allisf_'+mod+'_'+scenario+'_'+str(tt)+'.csv',index_col=[0,1,2])
            li.append(df_nrun)

        print('Concatenating input')
        df_allyy = pd.concat(li)

        print('Computing 2D patterns')
        xr_ensmean_res2D = apply_NN_results_2D_1isf_1year(file_isf, norm_metrics, df_allyy, model, input_vars)

        print('Compute the 1D evalmetrics')
        res_1D_all = _evalmetrics_1D_from_2D(xr_ensmean_res2D, show_progress=True)

    else:

        res_1Dtt_list = []

        # process the period in chunks of 50 years
        for tt in tqdm(range(yystart,yyend+1,50)):

            li = []
            # years of this chunk for which an input file was actually found
            years_found = []
            for tt0 in range(tt,tt+50):
                csv_file = inputpath_csv + 'dataframe_input_allisf_'+mod+'_'+scenario+'_'+str(tt0)+'.csv'
                if os.path.exists(csv_file):
                    li.append(pd.read_csv(csv_file,index_col=[0,1,2]))
                    years_found.append(tt0)
                else:
                    print(str(tt0)+' file does not exist')

            if not li:
                # nothing to concatenate for this chunk; previously this either crashed
                # or reused the last year of the previous chunk
                print('no input files for chunk starting in '+str(tt))
                continue

            df_allyy = pd.concat(li)

            xr_ensmean_res2D = apply_NN_results_2D_1isf_1year(file_isf, norm_metrics, df_allyy, model, input_vars)

            # label 'time' with the years that were read (identical to the former
            # range(tt, last_year+1) whenever no year is missing in between)
            res_1D_tt = _evalmetrics_1D_from_2D(xr_ensmean_res2D).assign_coords({'time': years_found})
            res_1Dtt_list.append(res_1D_tt)

        res_1D_all = xr.concat(res_1Dtt_list, dim='time')

    res_1D_all.rename('predicted_melt').to_netcdf(outputpath_melt + 'evalmetrics_1D_NN'+mod_size+'_'+scenario+'.nc')

In [None]:
# Debug cell: displays the CSV input path for the current values of mod, scenario and tt
inputpath_csv + 'dataframe_input_allisf_'+mod+'_'+scenario+'_'+str(tt)+'.csv'

In [None]:
# Debug cell: checks whether the CSV input file for year tt0 exists
if os.path.exists(inputpath_csv + 'dataframe_input_allisf_'+mod+'_'+scenario+'_'+str(tt0)+'.csv'):
    print('True')
else:
    print('not')

In [None]:
# Debug cell: displays the CSV input path for the current values of mod, scenario and tt0
inputpath_csv + 'dataframe_input_allisf_'+mod+'_'+scenario+'_'+str(tt0)+'.csv'

###############