In [None]:
"""
Created on Mon Mar 27 17:37 2023

Apply the ensemble of NN to Smith data and only take ensemble mean

Author: @claraburgard

"""

In [None]:
import datetime
import glob
import os,sys
import os.path
import time
from contextlib import redirect_stdout

import matplotlib as mpl
import numpy as np
import pandas as pd
import seaborn as sns
import tensorflow as tf
import xarray as xr
from tensorflow import keras
from tqdm.notebook import trange, tqdm
#from tqdm import tqdm

from multimelt.constants import *
import multimelt.useful_functions as uf
import summer_paper.data_formatting_NN as dfmt
import summer_paper.postprocessing_functions_NN as pp


DEFINE OPTIONS

In [None]:
# Options that select which trained network is loaded: in the visible cells
# TS_opt, norm_method and exp_name are only used to build the model file name.
# The trailing comments list alternative values.
#mod_size =  'xsmall96' #'mini', 'small', 'medium', 'large', 'extra_large'
TS_opt = 'extrap' # extrap, whole, thermocline
norm_method =  'std' # std,interquart, minmax
exp_name = 'newbasic2'#'allbutconstants' #'onlyTSdraftandslope' #'TSdraftbotandiceddandwcd' #'onlyTSisfdraft' #'TSdraftbotandiceddandwcdreldGL' #TSdraftslopereldGL

In [None]:
outputpath_info = '/bettik/burgardc/DATA/SUMMER_PAPER/interim/'

nemo_run = 'ctrl94' # 'EPM031', 'EPM034'

# info_chunks.txt is read without a header, so its columns are addressed by position.
# NOTE(review): from their use below, column 0 looks like the chunk number, column 1
# the NEMO run name and columns 2/3 the first/last year of the chunk - confirm.
file_info = pd.read_csv(outputpath_info+'info_chunks.txt', delimiter=',', header=None)
file_info = file_info.set_index(file_info[0])

# Select the rows of the requested run once instead of re-filtering on every access.
chunks_of_run = file_info[file_info[1]==nemo_run]

for chunk_nb in chunks_of_run[0].values:
    start_yy = chunks_of_run[2].loc[chunk_nb]
    end_yy = chunks_of_run[3].loc[chunk_nb]
    # NOTE: this rebinds `trange`, hiding the tqdm helper of the same name imported
    # at the top of the notebook. After the loop only the last chunk's values remain.
    trange = range(start_yy,end_yy+1)
    print(chunk_nb,start_yy,end_yy)

In [None]:
# NOTE(review): tblock_dim is not read by any cell visible in this notebook;
# the prediction cell sets its own `tblock=30`.
tblock_dim = [30]

READ IN DATA

In [None]:
# home_path ends with '/' and the pieces appended to it start with '/', so the
# resulting paths contain '//'. The strings are left as they are.
home_path = '/bettik/burgardc/'

inputpath_data_nn = '/bettik/burgardc/DATA/NN_PARAM/interim/INPUT_DATA/'
inputpath_csv = inputpath_data_nn+'EXTRAPOLATED_ISFDRAFT_CHUNKS/'
inputpath_mask = home_path+'/DATA/SUMMER_PAPER/interim/ANTARCTICA_IS_MASKS/nemo_5km_'+nemo_run+'/'
inputpath_data=home_path+'/DATA/SUMMER_PAPER/interim/'
inputpath_CVinput = home_path+'/DATA/NN_PARAM/interim/INPUT_DATA/EXTRAPOLATED_ISFDRAFT_CHUNKS/'
inputpath_plumes = home_path+'/DATA/SUMMER_PAPER/interim/PLUMES/nemo_5km_'+nemo_run+'/'
inputpath_boxes = home_path+'/DATA/SUMMER_PAPER/interim/BOXES/nemo_5km_'+nemo_run+'/'

# make the domain a little smaller to make the computation even more efficient - file isf has already been made smaller at its creation
map_lim = [-3000000,3000000]

# Geometry, box characteristics and the stacked ice-shelf mask for this NEMO run,
# all returned by one project helper.
file_isf, geometry_info_2D, box_charac_2D, box_charac_1D, isf_stack_mask = pp.read_input_evalmetrics_NN(nemo_run)


# Normalisation table; it is indexed below by the rows 'mean_vars' and 'range_vars'.
norm_metrics_file = xr.open_dataset(inputpath_CVinput + 'metrics_norm_wholedataset_origexcept26_christoph_new.nc')
norm_metrics = norm_metrics_file.to_dataframe()

# Box locations for the box count stored under config=2; keep only the cells whose
# box number is 1, taken at position 1 along Nisf.
# drop_vars replaces the deprecated DataArray.drop for removing the leftover scalar
# 'Nisf' coordinate.
box_loc_config2 = box_charac_2D['box_location'].sel(box_nb_tot=box_charac_1D['nD_config'].sel(config=2))
box1 = box_loc_config2.where(box_loc_config2==1).isel(Nisf=1).drop_vars('Nisf')

In [None]:
mod_size = 'xsmall96'

### APPLY MODEL

# Columns selected from the input table and passed to the network, in this order.
input_vars = ['dGL','dIF','corrected_isfdraft','bathy_metry','slope_bed_lon','slope_bed_lat','slope_ice_lon','slope_ice_lat',
                'theta_in','salinity_in','T_mean', 'S_mean', 'T_std', 'S_std']

tuning_sort = 'new' #'old'

### use any model from CV over time
#outputpath_melt = home_path+'/DATA/SUMMER_PAPER/processed/OCEAN_MELT_RATE_CMIP/'+mod+'/'
if tuning_sort == 'new':
    path_model = '/bettik/burgardc/DATA/SUMMER_PAPER/interim/NN_MODELS/'
elif tuning_sort == 'old':
    path_model = '/bettik/burgardc/DATA/NN_PARAM/interim/NN_MODELS/experiments/WHOLE/'
else:
    # Fail loudly: without this branch path_model would stay undefined, or keep a
    # stale value from an earlier run of the cell.
    raise ValueError("tuning_sort must be 'new' or 'old', got "+repr(tuning_sort))

def apply_NN_results_2D_1isf_1year(file_isf, norm_metrics, df_nrun, model, input_vars=None):
    """
    Compute 2D melt based on a given NN model

    The columns ``input_vars`` of ``df_nrun`` are normalised, passed through
    ``model.predict``, and the prediction is denormalised and laid out on the
    grid of ``file_isf['ISF_mask']``.

    Parameters
    ----------
    file_isf : xarray object
        Must provide ``'ISF_mask'``; it is only used as the template grid of the output.
    norm_metrics : pandas.DataFrame
        Normalisation table with rows ``'mean_vars'`` and ``'range_vars'`` and one
        column per entry of ``input_vars`` plus ``'melt_m_ice_per_y'``.
    df_nrun : pandas.DataFrame
        Input table containing the columns in ``input_vars``. Its index becomes the
        dimensions of the output and must include a level named ``'y'``.
    model : object
        Anything exposing ``predict(array, verbose=0)``.
        Assumes one output value per input row - TODO confirm.
    input_vars : list of str, optional
        Names of the predictor columns. Defaults to an empty list.

    Returns
    -------
    xarray.DataArray
        Denormalised prediction named ``'predicted_melt'``, reindexed like
        ``file_isf['ISF_mask']``.
    """
    if input_vars is None:
        input_vars = []

    x_val_norm = pp.normalise_vars(df_nrun[input_vars],
                                   norm_metrics[input_vars].loc['mean_vars'],
                                   norm_metrics[input_vars].loc['range_vars'])

    y_out_norm = model.predict(x_val_norm.values.astype('float64'), verbose=0)

    # Flatten with reshape rather than squeeze(): squeeze() would collapse a
    # single-row prediction to a 0-d array, leaving no dimension to label.
    y_out_norm_1d = np.asarray(y_out_norm).reshape(-1)

    y_out_norm_xr = xr.DataArray(data=y_out_norm_1d, dims=['index'])
    y_out_norm_xr = y_out_norm_xr.assign_coords({'index': x_val_norm.index})

    # denormalise the output
    y_out = pp.denormalise_vars(y_out_norm_xr,
                                norm_metrics['melt_m_ice_per_y'].loc['mean_vars'],
                                norm_metrics['melt_m_ice_per_y'].loc['range_vars'])

    y_out_pd_s = pd.Series(y_out.values, index=df_nrun.index, name='predicted_melt')

    # put some order in the file
    y_out_xr = y_out_pd_s.to_xarray().sortby('y')

    y_whole_grid = y_out_xr.reindex_like(file_isf['ISF_mask'])
    return y_whole_grid


In [None]:
    # NOTE(review): this cell is indented although it is not inside any loop or
    # function in this cell.
    print('NN size',mod_size)
    seed_nb = 1
    # Load the trained network; the file name is built from size, experiment name,
    # zero-padded seed number, TS option and normalisation method.
    model = keras.models.load_model(path_model + 'model_nn_'+mod_size+'_'+exp_name+'_wholedataset_'+str(seed_nb).zfill(2)+'_TS'+TS_opt+'_norm'+norm_method+'.h5')
    
    # Time-block length used in the input file names; hard-coded here.
    tblock=30
    # Read one input table per ice shelf in file_isf.Nisf and stack them.
    # After the loop, df_nrun still refers to the last table that was read.
    li = []
    for kisf in tqdm(file_isf.Nisf): 
        df_nrun = pd.read_csv(inputpath_csv + 'dataframe_input_isf'+str(kisf.values).zfill(3)+'_'+str(tblock).zfill(3)+'_new.csv',index_col=[0,1,2])
        li.append(df_nrun)
    df_allyy = pd.concat(li)
    
    print('Computing 2D patterns')
    # Apply the network to the stacked table of all ice shelves.
    xr_ensmean_res2D = apply_NN_results_2D_1isf_1year(file_isf, norm_metrics, df_allyy, model, input_vars)
    

In [None]:
# Quick range check: per-column maximum of the stacked input table.
df_allyy.max()

In [None]:
# Quick range check: per-column minimum of the stacked input table.
df_allyy.min()

In [None]:
# Keep only the rows of the stacked table that belong to ice shelf 66.
# The mask is built from df_allyy, so it has to be applied to df_allyy as well:
# applying it to df_nrun (the last single-shelf table read in the loading loop)
# only keeps rows that table shares with the mask.
# where() blanks the non-matching rows and dropna(how='all') removes them.
df_LarsenB = df_allyy.where(df_allyy['Nisf'] == 66).dropna(how='all')

In [None]:
# Re-run the prediction on the single-shelf subset.
# NOTE: this rebinds xr_ensmean_res2D, replacing the result computed from df_allyy.
xr_ensmean_res2D = apply_NN_results_2D_1isf_1year(file_isf, norm_metrics, df_LarsenB, model, input_vars)

In [None]:
# Quick-look plot of the prediction: non-finite values are masked and, with
# drop=True, coordinate labels left without any finite value are trimmed away.
xr_ensmean_res2D.where(np.isfinite(xr_ensmean_res2D), drop=True).plot()

In [None]:
# Put the 'theta_in' input column on the grid of file_isf['ISF_mask'], using the same
# to_xarray / sortby('y') / reindex_like chain as apply_NN_results_2D_1isf_1year.
input_var = df_LarsenB['theta_in'].to_xarray().sortby('y').reindex_like(file_isf['ISF_mask'])

In [None]:
# Plot that input field restricted to the cells where ISF_mask equals 66.
input_var.where(file_isf['ISF_mask'] == 66, drop=True).plot()

In [None]:
inputpath_profiles='/bettik/burgardc/DATA/SUMMER_PAPER/interim/T_S_PROF/CMIP/'

# NOTE(review): `mod` and `scenario` are not assigned in any cell visible in this
# notebook; they have to exist in the kernel before this cell can run.
S_prof = xr.open_dataset(inputpath_profiles + mod + '/S_mean_prof_50km_contshelf_'+mod+'_'+scenario+'_1980.nc')

In [None]:
# Plot variable 'so' of the profile file for Nisf=66.
S_prof['so'].sel(Nisf=66).plot()

In [None]:
# NOTE(review): `mod` and `scenario` are not assigned in any cell visible in this notebook.
inputpath_data_orig='/bettik/jourdai1/OCEAN_DATA_CMIP6_STEREO/'+mod+'/'
# NOTE(review): the variable is called T_ocean_files but the files matched are
# 'so_Oyr_*' and the variable read from it afterwards is 'so' - confirm which
# field is intended.
T_ocean_files = xr.open_mfdataset(inputpath_data_orig+'so_Oyr_'+mod+'_'+scenario+'_r1i1p1f2_1950*.nc')


In [None]:
# Map of 'so' at one time stamp, at position 6 (zero-based) along z.
T_ocean_files['so'].sel(time='1980-07-01T14:00:00.000000000').isel(z=6).plot()

In [None]:
### APPLY MODEL

# Columns selected from the input table and passed to the network, in this order.
input_vars = ['dGL','dIF','corrected_isfdraft','bathy_metry','slope_bed_lon','slope_bed_lat','slope_ice_lon','slope_ice_lat',
                'theta_in','salinity_in','T_mean', 'S_mean', 'T_std', 'S_std']

tuning_sort = 'new' #'old'

### use any model from CV over time
# NOTE(review): `mod` is not assigned in any cell visible in this notebook.
outputpath_melt = home_path+'/DATA/SUMMER_PAPER/processed/OCEAN_MELT_RATE_CMIP/'+mod+'/'
if tuning_sort == 'new':
    path_model = '/bettik/burgardc/DATA/SUMMER_PAPER/interim/NN_MODELS/'
elif tuning_sort == 'old':
    path_model = '/bettik/burgardc/DATA/NN_PARAM/interim/NN_MODELS/experiments/WHOLE/'
else:
    # Fail loudly: without this branch path_model would silently keep whatever
    # value an earlier cell left behind.
    raise ValueError("tuning_sort must be 'new' or 'old', got "+repr(tuning_sort))

# NOTE(review): a function with this name is already defined in an earlier cell of
# this notebook; this definition replaces it once the cell runs. Keep a single copy.
def apply_NN_results_2D_1isf_1year(file_isf, norm_metrics, df_nrun, model, input_vars=None):
    """
    Compute 2D melt based on a given NN model

    The columns ``input_vars`` of ``df_nrun`` are normalised, passed through
    ``model.predict``, and the prediction is denormalised and laid out on the
    grid of ``file_isf['ISF_mask']``.

    Parameters
    ----------
    file_isf : xarray object
        Must provide ``'ISF_mask'``; it is only used as the template grid of the output.
    norm_metrics : pandas.DataFrame
        Normalisation table with rows ``'mean_vars'`` and ``'range_vars'`` and one
        column per entry of ``input_vars`` plus ``'melt_m_ice_per_y'``.
    df_nrun : pandas.DataFrame
        Input table containing the columns in ``input_vars``. Its index becomes the
        dimensions of the output and must include a level named ``'y'``.
    model : object
        Anything exposing ``predict(array, verbose=0)``.
        Assumes one output value per input row - TODO confirm.
    input_vars : list of str, optional
        Names of the predictor columns. Defaults to an empty list.

    Returns
    -------
    xarray.DataArray
        Denormalised prediction named ``'predicted_melt'``, reindexed like
        ``file_isf['ISF_mask']``.
    """
    if input_vars is None:
        input_vars = []

    x_val_norm = pp.normalise_vars(df_nrun[input_vars],
                                   norm_metrics[input_vars].loc['mean_vars'],
                                   norm_metrics[input_vars].loc['range_vars'])

    y_out_norm = model.predict(x_val_norm.values.astype('float64'), verbose=0)

    # Flatten with reshape rather than squeeze(): squeeze() would collapse a
    # single-row prediction to a 0-d array, leaving no dimension to label.
    y_out_norm_1d = np.asarray(y_out_norm).reshape(-1)

    y_out_norm_xr = xr.DataArray(data=y_out_norm_1d, dims=['index'])
    y_out_norm_xr = y_out_norm_xr.assign_coords({'index': x_val_norm.index})

    # denormalise the output
    y_out = pp.denormalise_vars(y_out_norm_xr,
                                norm_metrics['melt_m_ice_per_y'].loc['mean_vars'],
                                norm_metrics['melt_m_ice_per_y'].loc['range_vars'])

    y_out_pd_s = pd.Series(y_out.values, index=df_nrun.index, name='predicted_melt')

    # put some order in the file
    y_out_xr = y_out_pd_s.to_xarray().sortby('y')

    y_whole_grid = y_out_xr.reindex_like(file_isf['ISF_mask'])
    return y_whole_grid

### If no ensemble of models

def _compute_1D_evalmetrics(melt_2D, show_progress=False):
    """
    Reduce a 2D melt field to two 1D metrics per ice shelf.

    For every ice shelf in ``file_isf.Nisf`` two quantities are computed:

    - ``'Gt'``: melt times ``grid_cell_area_weighted``, summed over ``mask_coord``
      and scaled by ``rho_i / 10**12``;
    - ``'box1'``: ``dfmt.weighted_mean`` of the melt over the cells where ``box1``
      is finite, weighted by ``grid_cell_area_weighted``.

    Parameters
    ----------
    melt_2D : xarray.DataArray
        Melt field as returned by ``apply_NN_results_2D_1isf_1year``.
    show_progress : bool, optional
        Wrap the loop over ice shelves in a tqdm progress bar.

    Returns
    -------
    xarray.DataArray
        Per-shelf results concatenated along ``'Nisf'``, with a ``'metrics'``
        coordinate holding ``['Gt', 'box1']``.
    """
    isf_ids = file_isf.Nisf.values
    if show_progress:
        isf_ids = tqdm(isf_ids)

    res_1D_list = []
    for kisf in isf_ids:

        geometry_isf_2D = dfmt.choose_isf(geometry_info_2D,isf_stack_mask, kisf)
        melt_rate_2D_isf_m_per_y = dfmt.choose_isf(melt_2D,isf_stack_mask, kisf)

        melt_rate_1D_isf_Gt_per_y = (melt_rate_2D_isf_m_per_y * geometry_isf_2D['grid_cell_area_weighted']).sum(dim=['mask_coord']) * rho_i / 10**12

        box_loc_config_stacked = dfmt.choose_isf(box1, isf_stack_mask, kisf)
        param_melt_2D_box1_isf = melt_rate_2D_isf_m_per_y.where(np.isfinite(box_loc_config_stacked))

        melt_rate_1D_isf_myr_box1_mean = dfmt.weighted_mean(param_melt_2D_box1_isf,['mask_coord'], geometry_isf_2D['grid_cell_area_weighted'])

        out_1D = xr.concat([melt_rate_1D_isf_Gt_per_y, melt_rate_1D_isf_myr_box1_mean], dim='metrics').assign_coords({'metrics': ['Gt','box1']})
        res_1D_list.append(out_1D)

    return xr.concat(res_1D_list, dim='Nisf')


# NOTE(review): `mod`, `scenario`, `yystart` and `yyend` are not assigned in any cell
# visible in this notebook; they have to exist in the kernel before this cell runs.
for mod_size in ['xsmall96','small','large']: #

    print('NN size',mod_size)
    seed_nb = 1
    model = keras.models.load_model(path_model + 'model_nn_'+mod_size+'_'+exp_name+'_wholedataset_'+str(seed_nb).zfill(2)+'_TS'+TS_opt+'_norm'+norm_method+'.h5')

    if scenario == 'historical':
        # Historical period: all years are read and processed in one go.
        print('prepare 2D input')
        li = []
        for tt in tqdm(range(yystart,yyend+1)):
            #print('Combining all dataframes')

            df_nrun = pd.read_csv(inputpath_csv + 'dataframe_input_allisf_'+mod+'_'+scenario+'_'+str(tt)+'.csv',index_col=[0,1,2])
            li.append(df_nrun)

        print('Concatenating input')
        df_allyy = pd.concat(li)

        print('Computing 2D patterns')
        xr_ensmean_res2D = apply_NN_results_2D_1isf_1year(file_isf, norm_metrics, df_allyy, model, input_vars)

        print('Compute the 1D evalmetrics')
        res_1D_all = _compute_1D_evalmetrics(xr_ensmean_res2D, show_progress=True)
        res_1D_all.rename('predicted_melt').to_netcdf(outputpath_melt + 'evalmetrics_1D_NN'+mod_size+'_'+scenario+'.nc')

    else:
        # Other scenarios: process the years in chunks of 50.
        res_1Dtt_list = []

        for tt in tqdm(range(yystart,yyend+1,50)):
            #print('Combining all dataframes')

            li = []
            years_loaded = []
            for tt0 in range(tt,tt+50):

                csv_path = inputpath_csv + 'dataframe_input_allisf_'+mod+'_'+scenario+'_'+str(tt0)+'.csv'
                if os.path.exists(csv_path):
                    df_nrun = pd.read_csv(csv_path,index_col=[0,1,2])
                    li.append(df_nrun)
                    years_loaded.append(tt0)
                else:
                    print(str(tt0)+' file does not exist')

            if not li:
                # Nothing to concatenate for this chunk: skip it instead of failing
                # in pd.concat or reusing the end year of a previous chunk.
                print('no input files for '+str(tt)+'-'+str(tt+49)+', skipping')
                continue

            df_allyy = pd.concat(li)

            #print('Computing 2D patterns')
            xr_ensmean_res2D = apply_NN_results_2D_1isf_1year(file_isf, norm_metrics, df_allyy, model, input_vars)

            #print('Compute the 1D evalmetrics')
            # Label 'time' with the years that were actually read, so a missing year in
            # the middle of a chunk cannot shift or mismatch the coordinate.
            # Assumes each yearly file contributes exactly one entry along 'time', in
            # ascending order - TODO confirm.
            res_1D_tt = _compute_1D_evalmetrics(xr_ensmean_res2D).assign_coords({'time': years_loaded})
            res_1Dtt_list.append(res_1D_tt)

        res_1D_all = xr.concat(res_1Dtt_list, dim='time')
        res_1D_all.rename('predicted_melt').to_netcdf(outputpath_melt + 'evalmetrics_1D_NN'+mod_size+'_'+scenario+'.nc')

In [None]:
# Debug: show the input CSV path for `tt`, the loop variable left over from the cell above.
inputpath_csv + 'dataframe_input_allisf_'+mod+'_'+scenario+'_'+str(tt)+'.csv'

In [None]:
# Debug: report whether the input CSV for the leftover loop variable `tt0` exists.
print('True' if os.path.exists(inputpath_csv + 'dataframe_input_allisf_'+mod+'_'+scenario+'_'+str(tt0)+'.csv') else 'not')

In [None]:
# Debug: show the input CSV path for `tt0`, the inner loop variable left over from the cell above.
inputpath_csv + 'dataframe_input_allisf_'+mod+'_'+scenario+'_'+str(tt0)+'.csv'

###############