In [None]:
"""
Created on Fri Sep 09 14:24 2022

Convert "raw output" from the model to 2D melt maps and 1D evaluation metrics

Author: @claraburgard

"""

In [None]:
import datetime
import glob
import os
import sys
import time
from contextlib import redirect_stdout

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import tensorflow as tf
import xarray as xr
from tensorflow import keras
from tqdm.notebook import trange, tqdm

from basal_melt_neural_networks.constants import *
import basal_melt_neural_networks.diagnostic_functions as diag
import basal_melt_neural_networks.data_formatting as dfmt
import basal_melt_neural_networks.postprocessing_functions as pp
from basal_melt_param.constants import *

DEFINE OPTIONS

In [None]:
# Network / preprocessing configuration; these strings are baked into the
# model and output file names used further down.
mod_size = 'medium'           # choices: 'mini', 'small', 'medium', 'large', 'extra_large'
TS_opt = 'extrap'             # choices: 'extrap', 'whole', 'thermocline'
norm_method = 'std'           # choices: 'std', 'interquart', 'minmax'
exp_name = 'allbutconstants'  # others tried: 'onlyTSdraftandslope', 'TSdraftbotandiceddandwcd',
                              # 'onlyTSisfdraft', 'TSdraftbotandiceddandwcdreldGL', 'TSdraftslopereldGL'

In [None]:
# Evaluation years: 1970 through 2009 inclusive (40 years)
tblock_dim = range(1970, 2010)

# Ice-shelf IDs (Nisf) to evaluate -- the subset used for bi646.
# Full list previously used:
# [10,11,12,13,18,22,23,24,25,30,31,33,38,39,40,42,43,44,45,47,48,51,52,53,54,55,58,61,65,66,69,70,71,73,75]
isf_dim = [
    10, 11, 12, 13, 18, 22, 23, 25, 30, 31, 33, 39, 40, 42, 43,
    44, 45, 47, 51, 55, 58, 61, 65, 66, 69, 70, 71, 73, 75,
]

In [None]:
# NEMO simulation to evaluate (alternative: "bf663")
nemo_run = "bi646"

READ IN DATA

In [None]:
# Root directory of the NN input data; sub-directories are chosen from TS_opt below
inputpath_data_nn = '/bettik/burgardc/DATA/NN_PARAM/interim/INPUT_DATA/'

In [None]:
# Select the input directories matching the temperature/salinity option.
# NOTE: only the 'extrap' branch defines `inputpath_csv`, which the evaluation
# loops below read from -- those loops therefore only work with TS_opt='extrap'.
if TS_opt == 'extrap':
    inputpath_CVinput = inputpath_data_nn + 'EXTRAPOLATED_ISFDRAFT_CHUNKS_CV/'
    inputpath_csv = inputpath_data_nn + 'SMITH_' + nemo_run + '_EXTRAPDRAFT_CHUNKS/'
elif TS_opt == 'whole':
    inputpath_CVinput = inputpath_data_nn + 'WHOLE_PROF_CHUNKS_CV/'
    path_orig_data = inputpath_data_nn + 'WHOLE_PROF_CHUNKS/'
elif TS_opt == 'thermocline':
    inputpath_CVinput = inputpath_data_nn + 'THERMOCLINE_CHUNKS_CV/'
else:
    # Fail here rather than with a confusing NameError several cells later
    raise ValueError(
        "Unknown TS_opt %r; expected 'extrap', 'whole' or 'thermocline'" % (TS_opt,)
    )

APPLY MODEL

In [None]:
# Predictor variables handed to the pp.* evaluation helpers below
input_vars = [
    'dGL', 'dIF', 'corrected_isfdraft', 'bathy_metry',
    'slope_bed_lon', 'slope_bed_lat', 'slope_ice_lon', 'slope_ice_lat',
    'isfdraft_conc', 'theta_in', 'salinity_in', 'u_tide',
]

In [None]:
def read_input_evalmetrics_NN_yy(nemo_run, tt, isf_list,
                                 inputpath_interim='/bettik/burgardc/DATA/NN_PARAM/interim/'):
    """Read the mask, geometry and box inputs needed to evaluate the NN for one year.

    Parameters
    ----------
    nemo_run : str
        NEMO run name (e.g. 'bi646'); selects the 'SMITH_<nemo_run>' sub-directories.
    tt : int
        Calendar year to read. Per-year files carry the suffix 'yyNN' with
        NN = tt - 1970 (zero-padded to 2 digits); in multi-year files the time
        steps are relabelled as consecutive years starting at 1970 and year `tt`
        is selected.
    isf_list : list of int
        Ice-shelf IDs (values of the 'Nisf' coordinate) to keep.
    inputpath_interim : str, optional
        Root of the interim data tree. Defaults to the original hard-coded path.

    Returns
    -------
    file_isf : xr.Dataset
        Ice-shelf masks/info restricted to `isf_list`, with 'time' dropped.
    geometry_info_2D : xr.Dataset
        Merge of 'ice_draft_pos', 'grid_cell_area_weighted' and 'isfdraft_conc'
        for year `tt`.
    box_charac_2D, box_charac_1D : xr.Dataset
        Box files for year `tt`, returned as opened (lazily, not closed).
    isf_stack_mask
        Result of `dfmt.create_stacked_mask` on file_isf['ISF_mask'].
    """
    inputpath_boxes = inputpath_interim + 'BOXES/SMITH_' + nemo_run + '/'
    inputpath_data = inputpath_interim + 'SMITH_' + nemo_run + '/'
    inputpath_mask = inputpath_interim + 'ANTARCTICA_IS_MASKS/SMITH_' + nemo_run + '/'
    yy_suffix = 'yy' + str(tt - 1970).zfill(2) + '.nc'
    map_lim = [-3000000, 3000000]

    def _read_cut_year(filename):
        # Open a multi-year file, cut it to the map domain, label its time steps
        # as years from 1970 and keep only year `tt`.
        ds = dfmt.cut_domain_stereo(xr.open_dataset(inputpath_data + filename), map_lim, map_lim)
        ds = ds.assign_coords({'time': range(1970, 1970 + len(ds.time))})
        return ds.sel(time=tt).drop_vars('time')

    file_isf_orig = xr.open_dataset(
        inputpath_mask + 'nemo_5km_isf_masks_and_info_and_distance_oneFRIS_' + yy_suffix
    ).drop_vars('time')
    file_isf = file_isf_orig.sel(Nisf=isf_list)

    file_other_cut = _read_cut_year('corrected_draft_bathy_isf.nc')
    file_conc_cut = _read_cut_year('isfdraft_conc_Ant_stereo.nc')

    box_charac_2D = xr.open_dataset(inputpath_boxes + 'nemo_5km_boxes_2D_oneFRIS_' + yy_suffix)
    box_charac_1D = xr.open_dataset(inputpath_boxes + 'nemo_5km_boxes_1D_oneFRIS_' + yy_suffix)

    isf_stack_mask = dfmt.create_stacked_mask(file_isf['ISF_mask'], file_isf.Nisf, ['y', 'x'], 'mask_coord')

    ice_draft_pos = file_other_cut['corrected_isfdraft']
    file_isf_conc = file_conc_cut['isfdraft_conc']

    # Cell spacing from neighbouring coordinates -- presumably a regular grid; TODO confirm
    dx = (file_isf.x[2] - file_isf.x[1]).values
    dy = (file_isf.y[2] - file_isf.y[1]).values
    grid_cell_area = abs(dx * dy)
    # Cell area scaled by the ice-draft concentration field
    grid_cell_area_weighted = file_isf_conc * grid_cell_area

    geometry_info_2D = xr.merge([ice_draft_pos.rename('ice_draft_pos'),
                                 grid_cell_area_weighted.rename('grid_cell_area_weighted'),
                                 file_isf_conc])

    return file_isf, geometry_info_2D, box_charac_2D, box_charac_1D, isf_stack_mask

In [None]:
# Display the list of predictor variables passed to the pp.* helpers below
input_vars

In [None]:
### Apply one CV model (trained without ice shelf `isf_out` and time block `tblock_out`)
### to every year in tblock_dim and every ice shelf in isf_dim; write one 2D file per year.
outputpath_melt_nn = '/bettik/burgardc/DATA/NN_PARAM/processed/MELT_RATE/SMITH_'+nemo_run+'/'
path_model = '/bettik/burgardc/DATA/NN_PARAM/interim/NN_MODELS/experiments/CV_TBLOCK/'
os.makedirs(outputpath_melt_nn, exist_ok=True)

if not isf_dim:
    raise ValueError('isf_dim is empty: no ice shelf to evaluate')

isf_out = 0
tblock_out = 5
cv_suffix = '_noisf'+str(isf_out).zfill(3)+'_notblock'+str(tblock_out).zfill(3)

# Normalisation metrics and model depend only on (isf_out, tblock_out, options),
# so load them once instead of once per year and ice shelf.
with xr.open_dataset(inputpath_CVinput + 'metrics_norm_CV' + cv_suffix + '.nc') as norm_metrics_file:
    norm_metrics = norm_metrics_file.sel(norm_method=norm_method).drop_vars('norm_method').to_dataframe()
model = keras.models.load_model(path_model + 'model_nn_'+mod_size+'_'+exp_name+cv_suffix+'_TS'+TS_opt+'_norm'+norm_method+'.h5')

for tt in tqdm(tblock_dim):
    file_isf, geometry_info_2D, box_charac_2D, box_charac_1D, isf_stack_mask = read_input_evalmetrics_NN_yy(nemo_run, tt, isf_dim)

    # Merge the per-ice-shelf 2D results into one field for this year
    res_2D_all = None
    for kisf in isf_dim:
        df_nrun = pd.read_csv(inputpath_csv + 'dataframe_input_isf'+str(kisf).zfill(3)+'_'+str(tt)+'_'+nemo_run+'.csv', index_col=[0, 1])
        res_2D = pp.apply_NN_results_2D_1isf_1tblock(file_isf, norm_metrics, df_nrun, model, input_vars)
        res_2D_all = res_2D if res_2D_all is None else res_2D_all.combine_first(res_2D)

    res_2D_all.to_netcdf(outputpath_melt_nn + 'evalmetrics_2D_'+mod_size+'_'+exp_name+'_'+TS_opt+'_norm'+norm_method+'_'+str(tt)+'_'+nemo_run+'.nc')

In [None]:
#### Apply the same CV model to every year and ice shelf and collect 1D evaluation
#### metrics into a single file over all years.
outputpath_melt_nn = '/bettik/burgardc/DATA/NN_PARAM/processed/MELT_RATE/SMITH_'+nemo_run+'/'
path_model = '/bettik/burgardc/DATA/NN_PARAM/interim/NN_MODELS/experiments/CV_TBLOCK/'
os.makedirs(outputpath_melt_nn, exist_ok=True)

if not isf_dim:
    raise ValueError('isf_dim is empty: no ice shelf to evaluate')

isf_out = 0
tblock_out = 5
cv_suffix = '_noisf'+str(isf_out).zfill(3)+'_notblock'+str(tblock_out).zfill(3)

# Loop-invariant: load the normalisation metrics and the model only once.
with xr.open_dataset(inputpath_CVinput + 'metrics_norm_CV' + cv_suffix + '.nc') as norm_metrics_file:
    norm_metrics = norm_metrics_file.sel(norm_method=norm_method).drop_vars('norm_method').to_dataframe()
model = keras.models.load_model(path_model + 'model_nn_'+mod_size+'_'+exp_name+cv_suffix+'_TS'+TS_opt+'_norm'+norm_method+'.h5')

res_1D_allyy_list = []
for tt in tqdm(tblock_dim):
    file_isf, geometry_info_2D, box_charac_2D, box_charac_1D, isf_stack_mask = read_input_evalmetrics_NN_yy(nemo_run, tt, isf_dim)

    res_1D_list = []
    for kisf in isf_dim:
        df_nrun = pd.read_csv(inputpath_csv + 'dataframe_input_isf'+str(kisf).zfill(3)+'_'+str(tt)+'_'+nemo_run+'.csv', index_col=[0, 1])
        res_1D = pp.evalmetrics_1D_NN(kisf, norm_metrics, df_nrun, model, file_isf, geometry_info_2D, box_charac_2D, box_charac_1D, isf_stack_mask, input_vars)
        res_1D_list.append(res_1D)

    res_1D_all = xr.concat(res_1D_list, dim='Nisf')
    res_1D_allyy_list.append(res_1D_all)

# NOTE(review): if the per-year results carry no 'time' coordinate, the new 'time'
# dimension is unlabelled (0..N-1) rather than years -- confirm against pp.evalmetrics_1D_NN.
res_1D_allyy = xr.concat(res_1D_allyy_list, dim='time')
res_1D_allyy.to_netcdf(outputpath_melt_nn + 'evalmetrics_1D_'+mod_size+'_'+exp_name+'_'+TS_opt+'_norm'+norm_method+'_allyy_'+nemo_run+'.nc')
  


In [None]:
# Inspect the 1D evaluation results for ice shelf 66
res_1D_allyy.sel({'Nisf': 66})

In [None]:
# Predicted vs reference melt (metric 'Gt') for ice shelf 66.
# Both lines are labelled so the legend shows both (the reference line had no label).
fig, ax = plt.subplots()
res_1D_allyy['predicted_melt'].sel(Nisf=66, metrics='Gt').plot(ax=ax, label='predicted')
res_1D_allyy['reference_melt'].sel(Nisf=66, metrics='Gt').plot(ax=ax, label='reference')
ax.legend();

In [None]:
# Root-mean-square difference between predicted and reference melt ('Gt' metric),
# averaged over every remaining dimension (ice shelves and years)
melt_error_gt = (res_1D_allyy['predicted_melt'] - res_1D_allyy['reference_melt']).sel(metrics='Gt')
np.sqrt((melt_error_gt ** 2).mean())

In [None]:
# FIXME(review): the original cell plotted `y_all_isf`, which is never defined in this
# notebook (leftover kernel state), so it raised NameError on Restart & Run All.
# Plot the time-mean predicted melt per ice shelf ('Gt' metric) from the results above instead.
res_1D_allyy['predicted_melt'].sel(metrics='Gt').mean('time').plot()

In [None]:
# Re-open the saved 2D results for the last evaluated year. The year is set explicitly
# instead of relying on the loop variable `tt` leaking out of an earlier cell.
# The evaluation loops leave file_isf / geometry_info_2D from this same (last) year.
year_2D = tblock_dim[-1]
res_2D_all = xr.open_dataset(outputpath_melt_nn + 'evalmetrics_2D_'+mod_size+'_'+exp_name+'_'+TS_opt+'_norm'+norm_method+'_'+str(year_2D)+'_'+nemo_run+'.nc')

In [None]:
# Map of the prediction error (predicted minus reference), colour scale capped at 10
melt_error_2D = res_2D_all['predicted_melt'] - res_2D_all['reference_melt']
melt_error_2D.plot(vmax=10)

In [None]:
# Ice draft restricted to ice shelf 66; drop=True removes rows/columns entirely outside it
in_isf66 = file_isf['ISF_mask'] == 66
geometry_info_2D['ice_draft_pos'].where(in_isf66, drop=True).plot()

In [None]:
# Reference melt over ice shelf 66 only, colour scale capped at 1
in_isf66 = file_isf['ISF_mask'] == 66
res_2D_all['reference_melt'].where(in_isf66, drop=True).plot(vmax=1)