In [None]:
"""
Created on Fri Sep 09 14:24 2022

Convert "raw output" from the model to 2D maps

Author: @claraburgard

"""

In [None]:
import datetime
import glob
import os,sys
import time
from contextlib import redirect_stdout

import matplotlib as mpl
import numpy as np
import pandas as pd
import seaborn as sns
import tensorflow as tf
import xarray as xr
from tensorflow import keras
from tqdm.notebook import trange, tqdm

# Project-local imports: relative order kept because the star imports may shadow names.
from basal_melt_neural_networks.constants import *
import basal_melt_neural_networks.diagnostic_functions as diag
import basal_melt_neural_networks.data_formatting as dfmt
import basal_melt_neural_networks.postprocessing_functions as pp
from basal_melt_param.constants import *

DEFINE OPTIONS

In [None]:
# Model size label used in the model and output file names
# (labels listed previously: 'mini', 'small', 'medium', 'large', 'extra_large').
mod_size = 'xsmall96'

# Temperature/salinity input option: 'extrap', 'whole' or 'thermocline'.
TS_opt = 'extrap'

# Normalisation method: 'std', 'interquart' or 'minmax'.
norm_method = 'std'

# Experiment name used in the model and output file names. Names used previously:
# 'onlyTSdraftandslope', 'TSdraftbotandiceddandwcd', 'onlyTSisfdraft',
# 'TSdraftbotandiceddandwcdreldGL', 'TSdraftslopereldGL'.
exp_name = 'newbasic2'

In [None]:
# Indices of the time blocks (1 to 13 inclusive).
tblock_dim = range(1, 14)

# Identifiers of the ice shelves taken into account.
isf_dim = [
    10, 11, 12, 13, 18,
    22, 23, 24, 25,
    30, 31, 33, 38, 39,
    40, 42, 43, 44, 45, 47, 48,
    51, 52, 53, 54, 55, 58,
    61, 65, 66, 69,
    70, 71, 73, 75,
]

READ IN DATA

In [None]:
# Root directory of the interim input data; the TS_opt-specific sub-directories
# selected in the next cell are appended to this path.
inputpath_data_nn = '/bettik/burgardc/DATA/NN_PARAM/interim/INPUT_DATA/'


In [None]:
# Select the input directories matching the chosen TS_opt.
# Each entry maps TS_opt -> (sub-directory of the cross-validation input,
#                            sub-directory of the original chunked data),
# both relative to inputpath_data_nn.
ts_opt_subdirs = {
    'extrap': ('EXTRAPOLATED_ISFDRAFT_CHUNKS_CV/', 'EXTRAPOLATED_ISFDRAFT_CHUNKS/'),
    'whole': ('WHOLE_PROF_CHUNKS_CV/', 'WHOLE_PROF_CHUNKS/'),
    # NOTE(review): the original cell did not define path_orig_data for 'thermocline',
    # which left it undefined (or stale) for the cells below. 'THERMOCLINE_CHUNKS/' follows
    # the naming pattern of the other options -- confirm that this directory exists.
    'thermocline': ('THERMOCLINE_CHUNKS_CV/', 'THERMOCLINE_CHUNKS/'),
}

# Fail early on a typo instead of hitting a NameError (or stale paths) further down.
if TS_opt not in ts_opt_subdirs:
    raise ValueError("Unknown TS_opt '"+str(TS_opt)+"', expected one of "+str(sorted(ts_opt_subdirs)))

cv_subdir, orig_subdir = ts_opt_subdirs[TS_opt]
inputpath_CVinput = inputpath_data_nn + cv_subdir
path_orig_data = inputpath_data_nn + orig_subdir

APPLY MODEL

In [None]:
# Names of the variables handed to the evaluation functions as `input_vars`.
input_list = [
    'dGL',
    'dIF',
    'corrected_isfdraft',
    'bathy_metry',
    'slope_bed_lon',
    'slope_bed_lat',
    'slope_ice_lon',
    'slope_ice_lat',
    'theta_in',
    'salinity_in',
    'T_mean',
    'S_mean',
    'T_std',
    'S_std',
]

In [None]:
#### CV over time
# For each time block, compute the 2D cross-validation results with that block passed
# as `tblock_out`, and write one NetCDF file per NEMO run (results concatenated on 'time').
outputpath_melt_nn = '/bettik/burgardc/DATA/NN_PARAM/processed/MELT_RATE/CV_TBLOCKS/'

# Time blocks handled in this execution. Use `tblock_dim` to process all of them.
tblocks_to_process = [11,12,13]


def save_tblock_cv_results(res_list, nemo_run_name, outputpath):
    """Concatenate per-time-block results along 'time' and write them to NetCDF.

    Parameters
    ----------
    res_list : list
        Results returned by pp.compute_crossval_metric_2D_for_1CV, one per time block.
    nemo_run_name : str
        Name of the NEMO run, used in the log message and in the file name.
    outputpath : str
        Directory (with trailing slash) in which the file is written.

    Returns
    -------
    The concatenated results that were written to disk.
    """
    print('I AM SAVING RESULTS FOR NEMO RUN '+nemo_run_name)
    res_concat = xr.concat(res_list, dim='time')
    res_concat.to_netcdf(outputpath + 'evalmetrics_2D_CV_'+mod_size+'_'+exp_name+'_'+TS_opt+'_norm'+norm_method+'_'+nemo_run_name+'.nc')
    return res_concat


res_all_list = []
# No run has been seen yet; this avoids hard-coding the run of the first processed block.
nemo_run_old = None

for tblock_out in tqdm(tblocks_to_process):

    # presumably 0 means "no ice shelf left out" -- confirm against
    # pp.compute_crossval_metric_2D_for_1CV
    isf_out = 0

    nemo_run = pp.identify_nemo_run_from_tblock(tblock_out)

    # A new NEMO run starts: flush what was accumulated for the previous one.
    # The emptiness check prevents xr.concat from being called on an empty list.
    if res_all_list and (nemo_run_old != nemo_run):
        res_all_CV = save_tblock_cv_results(res_all_list, nemo_run_old, outputpath_melt_nn)
        res_all_list = []

    nemo_run_old = nemo_run

    res_all = pp.compute_crossval_metric_2D_for_1CV(tblock_out,isf_out,tblock_dim,isf_dim,inputpath_CVinput,path_orig_data,norm_method,TS_opt,mod_size+'_'+exp_name,'experiments/',input_vars=input_list,verbose=False)
    res_all_list.append(res_all)

# Flush the last run after the loop, whichever block the loop ended on
# (the save used to happen only when the block was max(tblock_dim)).
if res_all_list:
    res_all_CV = save_tblock_cv_results(res_all_list, nemo_run_old, outputpath_melt_nn)

In [None]:
#### CV over shelves
# For each NEMO run, compute the 2D cross-validation results with every ice shelf passed
# in turn as `isf_out`, merge them and write one NetCDF file per NEMO run.
outputpath_melt_nn = '/bettik/burgardc/DATA/NN_PARAM/processed/MELT_RATE/CV_ISF/'

# Time blocks handled in this execution. Use `tblock_dim` to process all of them.
tblocks_to_process = [11,12,13]


def compute_and_save_isf_cv_results(t_list, nemo_run_name, outputpath):
    """Compute the shelf-wise cross-validation results for one NEMO run and save them.

    Parameters
    ----------
    t_list : list of int
        Time blocks belonging to the NEMO run.
    nemo_run_name : str
        Name of the NEMO run, used in the log message and in the file name.
    outputpath : str
        Directory (with trailing slash) in which the file is written.

    Returns
    -------
    The results of all ice shelves in `isf_dim`, merged with `combine_first`.
    """
    print(t_list)

    res_combined = None
    for isf_out in tqdm(isf_dim):

        # presumably 0 means "no time block left out" -- confirm against
        # pp.compute_crossval_metric_2D_for_1CV
        tblock_out = 0

        res_all = pp.compute_crossval_metric_2D_for_1CV(tblock_out,isf_out,t_list,isf_dim,inputpath_CVinput,path_orig_data,norm_method,TS_opt,mod_size+'_'+exp_name,'experiments/',input_vars=input_list,verbose=False)
        if res_combined is None:
            res_combined = res_all
        else:
            res_combined = res_combined.combine_first(res_all)

    print('I AM SAVING RESULTS FOR NEMO RUN '+nemo_run_name)
    res_combined.to_netcdf(outputpath + 'evalmetrics_2D_CV_'+mod_size+'_'+exp_name+'_'+TS_opt+'_norm'+norm_method+'_'+nemo_run_name+'.nc')
    return res_combined


t_list = []
# No run has been seen yet; this avoids hard-coding the run of the first processed block.
nemo_run_old = None

for tt in tblocks_to_process:
    nemo_run = pp.identify_nemo_run_from_tblock(tt)
    print(nemo_run)

    # A new NEMO run starts: process the blocks collected for the previous one.
    # The emptiness check avoids running the evaluation on an empty list of blocks.
    if t_list and (nemo_run_old != nemo_run):
        res_all_CV = compute_and_save_isf_cv_results(t_list, nemo_run_old, outputpath_melt_nn)
        t_list = []

    t_list.append(tt)
    nemo_run_old = nemo_run

# Process the last run after the loop, whichever block the loop ended on
# (this used to happen only when the block was max(tblock_dim)).
if t_list:
    res_all_CV = compute_and_save_isf_cv_results(t_list, nemo_run_old, outputpath_melt_nn)
            

In [None]:
# Map of the time-mean reference melt from the most recently computed CV results
reference_melt_time_mean = res_all_CV['reference_melt'].mean('time')
reference_melt_time_mean.plot()

In [None]:
# Apply the shelf-wise cross-validation models to the original data of every time block
# and ice shelf, put predicted and reference melt back on the 2D grid, and write one
# NetCDF file per NEMO run.
CV_type = 'shelves'
outputpath_melt_nn = '/bettik/burgardc/DATA/NN_PARAM/processed/MELT_RATE/CV_ISF/'
path_model = '/bettik/burgardc/DATA/NN_PARAM/interim/NN_MODELS/CV_ISF/'

# Used in the file names below as 'notblock000'
# (presumably: no time block left out -- confirm).
tblock_out = 0


def save_melt_2D(melt_ds, nemo_run_name, outputpath):
    """Write the accumulated 2D melt dataset of one NEMO run to NetCDF."""
    print('I AM SAVING RESULTS FOR NEMO RUN '+nemo_run_name)
    melt_ds.to_netcdf(outputpath + 'melt_2D_'+mod_size+'_TS'+TS_opt+'_norm'+norm_method+'_'+nemo_run_name+'.nc')


y_all_isf = None
# No run has been seen yet; this avoids hard-coding the run of the first time block.
nemo_run_old = None

# The model file name depends on the ice shelf only (tblock_out is fixed), so each
# model is loaded once and reused for all time blocks.
models_by_isf = {}

for tblock in tqdm(tblock_dim):

    nemo_run = pp.identify_nemo_run_from_tblock(tblock)

    # A new NEMO run starts: flush what was accumulated for the previous one.
    # The None check prevents calling .to_netcdf on an empty accumulator.
    if (y_all_isf is not None) and (nemo_run_old != nemo_run):
        save_melt_2D(y_all_isf, nemo_run_old, outputpath_melt_nn)
        y_all_isf = None

    nemo_run_old = nemo_run
    inputpath_mask = '/bettik/burgardc/SCRIPTS/basal_melt_param/data/interim/ANTARCTICA_IS_MASKS/nemo_5km_'+nemo_run+'/'

    # The mask file does not depend on the ice shelf: read the target grid once per
    # time block and close the file straight away.
    with xr.open_dataset(inputpath_mask+'nemo_5km_isf_masks_and_info_and_distance_new_oneFRIS.nc') as file_isf:
        isf_mask_grid = file_isf['ISF_mask'].load()

    for isf_out in isf_dim:

        # original index
        df_nrun = pd.read_csv(path_orig_data + 'dataframe_input_isf'+str(isf_out).zfill(3)+'_'+str(tblock).zfill(3)+'.csv',index_col=[0,1,2])

        # to_dataframe() materialises the values, so the file can be closed right after
        with xr.open_dataset(inputpath_CVinput + 'metrics_norm_CV_noisf'+str(isf_out).zfill(3)+'_notblock'+str(tblock_out).zfill(3)+'.nc') as norm_metrics_file:
            norm_metrics = norm_metrics_file.sel(norm_method=norm_method).drop_vars('norm_method').to_dataframe()

        val_norm = pp.normalise_vars(df_nrun,
                                    norm_metrics.loc['mean_vars'],
                                    norm_metrics.loc['range_vars'])

        x_val_norm = val_norm.drop(['melt_m_ice_per_y'], axis=1)
        # not used below; kept so it stays available for inspection
        y_val_norm = val_norm['melt_m_ice_per_y']

        if isf_out not in models_by_isf:
            models_by_isf[isf_out] = keras.models.load_model(path_model + 'model_nn_'+mod_size+'_noisf'+str(isf_out).zfill(3)+'_notblock'+str(tblock_out).zfill(3)+'_TS'+TS_opt+'_norm'+norm_method+'.h5')
        model = models_by_isf[isf_out]

        y_out_norm = model.predict(x_val_norm.values)

        y_out_norm_xr = xr.DataArray(data=y_out_norm.squeeze()).rename({'dim_0': 'index'})
        y_out_norm_xr = y_out_norm_xr.assign_coords({'index': x_val_norm.index})

        # denormalise the output
        y_out = pp.denormalise_vars(y_out_norm_xr,
                                 norm_metrics['melt_m_ice_per_y'].loc['mean_vars'],
                                 norm_metrics['melt_m_ice_per_y'].loc['range_vars'])

        y_out_pd_s = pd.Series(y_out.values,index=df_nrun.index,name='predicted_melt')
        y_target_pd_s = pd.Series(df_nrun['melt_m_ice_per_y'].values,index=df_nrun.index,name='reference_melt')

        # put some order in the file
        y_out_xr = y_out_pd_s.to_xarray()
        y_target_xr = y_target_pd_s.to_xarray()
        y_to_compare = xr.merge([y_out_xr.T, y_target_xr.T]).sortby('y')

        # expand to the full grid of the mask file, then merge into the accumulator
        y_whole_grid = y_to_compare.reindex_like(isf_mask_grid)
        if y_all_isf is None:
            y_all_isf = y_whole_grid
        else:
            y_all_isf = y_all_isf.combine_first(y_whole_grid)

# Flush the last NEMO run once all time blocks are done.
if y_all_isf is not None:
    save_melt_2D(y_all_isf, nemo_run_old, outputpath_melt_nn)


INTEGRATED METRICS

In [None]:
# Compute, per NEMO run and per ice shelf, the integrated melt ('Gt') and -- unless
# tuning_mode is set -- the mean melt over box 1 ('box1'), from the saved 2D melt files.

# Switch between the two cross-validation types by (un)commenting the matching pair.
#CV_type = 'shelves'
#outputpath_melt_nn = '/bettik/burgardc/DATA/NN_PARAM/processed/MELT_RATE/CV_ISF/'

CV_type = 'time'
outputpath_melt_nn = '/bettik/burgardc/DATA/NN_PARAM/processed/MELT_RATE/CV_TBLOCKS/'

# Flags are the same for every run, so they are set once.
verbose = True
# When True, the box 1 metric is skipped and only the 'Gt' metric is written.
tuning_mode = False

for nemo_run in ['OPM006', 'OPM016', 'OPM018', 'OPM021']:

    file_isf, geometry_info_2D, box_charac_2D, box_charac_1D, isf_stack_mask = pp.read_input_evalmetrics_NN(nemo_run)
    nisf_list = file_isf.Nisf

    if verbose:
        time_start = time.time()
        print('WELCOME! AS YOU WISH, I WILL COMPUTE THE EVALUATION METRICS FOR '+str(len(nisf_list))+' ICE SHELVES')

    list_loop = tqdm(nisf_list) if verbose else nisf_list

    if not tuning_mode:
        # Without this check a missing box description would either raise a NameError
        # below or silently reuse `box1` from the previous NEMO run.
        if not (box_charac_2D and box_charac_1D):
            raise ValueError('Box characteristics are needed for the box1 metric but are missing for NEMO run '+nemo_run)
        box_loc_config2 = box_charac_2D['box_location'].sel(box_nb_tot=box_charac_1D['nD_config'].sel(config=2))
        box1 = box_loc_config2.where(box_loc_config2==1).isel(Nisf=0).drop_vars('Nisf')

    melt1D_Gt_per_yr_list = []
    melt1D_myr_box1_list = []

    # Keep the 2D melt file open only while it is needed.
    with xr.open_dataset(outputpath_melt_nn + 'melt_2D_'+mod_size+'_TS'+TS_opt+'_norm'+norm_method+'_'+nemo_run+'.nc') as melt2D:

        for kisf in list_loop:
            geometry_isf_2D = dfmt.choose_isf(geometry_info_2D,isf_stack_mask, kisf)
            melt_rate_2D_isf_m_per_y = dfmt.choose_isf(melt2D,isf_stack_mask, kisf)

            # Area-weighted sum over the shelf, scaled by rho_i / 10**12
            # (presumably m ice per yr -> Gt per yr -- confirm the units of rho_i).
            melt_rate_1D_isf_Gt_per_y = (melt_rate_2D_isf_m_per_y * geometry_isf_2D['grid_cell_area_weighted']).sum(dim=['mask_coord']) * rho_i / 10**12
            melt1D_Gt_per_yr_list.append(melt_rate_1D_isf_Gt_per_y)

            if not tuning_mode:
                # Keep only the cells located in box 1, then average them
                # with the 'isfdraft_conc' weights.
                box_loc_config_stacked = dfmt.choose_isf(box1, isf_stack_mask, kisf)
                param_melt_2D_box1_isf = melt_rate_2D_isf_m_per_y.where(np.isfinite(box_loc_config_stacked))

                melt_rate_1D_isf_myr_box1_mean = dfmt.weighted_mean(param_melt_2D_box1_isf,['mask_coord'], geometry_isf_2D['isfdraft_conc'])
                melt1D_myr_box1_list.append(melt_rate_1D_isf_myr_box1_mean)

        melt1D_Gt_per_yr = xr.concat(melt1D_Gt_per_yr_list, dim='Nisf')
        if tuning_mode:
            out_1D = melt1D_Gt_per_yr.expand_dims({'metrics': ['Gt']})
        else:
            melt1D_myr_box1 = xr.concat(melt1D_myr_box1_list, dim='Nisf')
            out_1D = xr.concat([melt1D_Gt_per_yr, melt1D_myr_box1], dim='metrics').assign_coords({'metrics': ['Gt','box1']})

        if verbose:
            timelength = time.time() - time_start
            print("I AM DONE! IT TOOK: "+str(round(timelength,2))+" seconds.")

        # 'config' and 'box_nb_tot' are dropped from the written file only;
        # errors='ignore' covers the tuning_mode case, where they may be absent.
        out_1D.drop_vars(['config', 'box_nb_tot'], errors='ignore').to_netcdf(outputpath_melt_nn + 'evalmetrics_1D_'+mod_size+'_TS'+TS_opt+'_norm'+norm_method+'_'+nemo_run+'.nc')

In [None]:
# Root-mean-square difference between predicted and reference melt for the 'box1' metric
box1_melt_diff = (out_1D['predicted_melt'] - out_1D['reference_melt']).sel(metrics='box1')
np.sqrt((box1_melt_diff**2).mean())

In [None]:
# Map of the time-mean predicted melt from the most recently accumulated 2D dataset
predicted_melt_time_mean = y_all_isf['predicted_melt'].mean('time')
predicted_melt_time_mean.plot()