In [None]:
"""
Created on Fri Apr 15 11:16 2022

Convert the raw neural-network output (melt in m ice per yr) into ice-shelf integrated melt (Gt per yr), from which the RMSE is ultimately computed

Author: @claraburgard

"""

In [None]:
import datetime
import glob
import os,sys
import time
from contextlib import redirect_stdout

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import tensorflow as tf
import xarray as xr
from tensorflow import keras
from tqdm.notebook import trange, tqdm

from basal_melt_neural_networks.constants import *
import basal_melt_neural_networks.diagnostic_functions as diag
import basal_melt_neural_networks.data_formatting as dfmt
from basal_melt_param.constants import *

READ IN DATA

In [None]:
# Identifier of the NEMO simulation to process (used in the NN output file names).
nemo_run0 = 'OPM021'

# nemo_run is the identifier used to build the input paths: the sub-runs
# 'OPM031-1' and 'OPM031-2' are both mapped onto 'OPM031'.
nemo_run = 'OPM031' if nemo_run0 in ('OPM031-1', 'OPM031-2') else nemo_run0

In [None]:
# Data locations on the cluster. Every path except the NN output root contains a
# sub-folder that depends on the selected simulation (nemo_run).
outputpath_melt_nn = '/bettik/burgardc/DATA/NN_PARAM/processed/MELT_RATE/'
inputpath_boxes = f'/bettik/burgardc/DATA/BASAL_MELT_PARAM/interim/BOXES/nemo_5km_{nemo_run}/'
inputpath_mask = f'/bettik/burgardc/SCRIPTS/basal_melt_param/data/interim/ANTARCTICA_IS_MASKS/nemo_5km_{nemo_run}/'
inputpath_data = f'/bettik/burgardc/DATA/BASAL_MELT_PARAM/interim/NEMO_eORCA025.L121_{nemo_run}_ANT_STEREO/'
outputpath_melt = f'/bettik/burgardc/DATA/BASAL_MELT_PARAM/processed/MELT_RATE/nemo_5km_{nemo_run}/'


In [None]:
# Ice-shelf masks and per-shelf information.
file_isf_orig = xr.open_dataset(f'{inputpath_mask}nemo_5km_isf_masks_and_info_and_distance_new.nc')

# First selection: keep the ice shelves whose 'front_bot_depth_max' is finite.
has_finite_front_depth = np.isfinite(file_isf_orig['front_bot_depth_max'])
nonnan_Nisf = file_isf_orig['Nisf'].where(has_finite_front_depth, drop=True).astype(int)
file_isf_nonnan = file_isf_orig.sel(Nisf=nonnan_Nisf)

# Second selection: among those, keep the shelves with 'isf_area_here' >= 2500
# (units of the threshold are not visible here - presumably km^2, TODO confirm).
is_large_enough = file_isf_nonnan['isf_area_here'] >= 2500
large_isf = file_isf_nonnan['Nisf'].where(is_large_enough, drop=True)
file_isf = file_isf_nonnan.sel(Nisf=large_isf)

In [None]:
# Shrink the domain to speed up the computation. file_isf needs no cut: it was
# already reduced when it was created.
map_lim = [-3000000,3000000]

file_other = xr.open_dataset(f'{inputpath_data}corrected_draft_bathy_isf.nc')
file_conc = xr.open_dataset(f'{inputpath_data}isfdraft_conc_Ant_stereo.nc')

# The same limits are passed for both horizontal directions.
file_other_cut = dfmt.cut_domain_stereo(file_other, map_lim, map_lim)
file_conc_cut = dfmt.cut_domain_stereo(file_conc, map_lim, map_lim)

In [None]:
# Ice draft taken from 'corrected_isfdraft', plus its negated counterpart.
# NOTE(review): the _pos/_neg suffixes suggest depth counted positive vs. negative
# downward - confirm the sign convention against the input file.
ice_draft_pos = file_other_cut['corrected_isfdraft']
ice_draft_neg = -ice_draft_pos

In [None]:
# Box characteristics: 2D fields (e.g. 'box_location') and 1D fields (e.g. 'nD_config').
box_charac_2D = xr.open_dataset(f'{inputpath_boxes}nemo_5km_boxes_2D.nc')
box_charac_1D = xr.open_dataset(f'{inputpath_boxes}nemo_5km_boxes_1D.nc')

PREPARE GEOMETRICAL INFO

In [None]:
# Stacked ice-shelf mask built from 'ISF_mask' for the selected shelves; it is later
# passed to dfmt.choose_isf to extract one shelf at a time.
# NOTE(review): presumably the 'y'/'x' dims are stacked into 'mask_coord' - confirm in dfmt.
isf_stack_mask = dfmt.create_stacked_mask(file_isf['ISF_mask'], file_isf.Nisf, ['y','x'], 'mask_coord')

In [None]:
# Print progress information in the evaluation loop.
verbose = True

nisf_list = file_isf['Nisf']
file_isf_conc = file_conc_cut['isfdraft_conc']

# Grid spacing from two neighbouring coordinate values in each direction
# (this assumes a regular grid).
xx = file_isf['x']
yy = file_isf['y']
dx = (xx[2] - xx[1]).values
dy = (yy[2] - yy[1]).values

# Cell area, then scaled by the ice-shelf concentration of each cell.
grid_cell_area = abs(dx * dy)
grid_cell_area_weighted = file_isf_conc * grid_cell_area

In [None]:
# Gather the 2D geometry fields needed per ice shelf into a single dataset.
geometry_fields = [
    ice_draft_pos.rename('ice_draft_pos'),
    grid_cell_area_weighted.rename('grid_cell_area_weighted'),
    file_isf_conc,
]
geometry_info_2D = xr.merge(geometry_fields)

In [None]:
# Partial list of shuffled variables: T_002 ... T_068 followed by S_001 ... S_009.
# NOTE(review): when the notebook is run top to bottom, the next cell overwrites
# var_list with the full list, so this subset only takes effect if that cell is
# skipped (looks like a restart point for an interrupted run - confirm).
var_list = (
    ['T_{:03d}'.format(level) for level in range(2, 69)]
    + ['S_{:03d}'.format(level) for level in range(1, 10)]
)

In [None]:
# Full list of variables to shuffle, in this order: the named scalar/2D predictors,
# then T_001 ... T_068, then S_001 ... S_068, and finally 'water_column'.
named_predictors = [
    'dGL', 'dIF', 'corrected_isfdraft', 'bathy_metry',
    'slope_bed_lon', 'slope_bed_lat', 'slope_ice_lon', 'slope_ice_lat',
    'isf_area', 'entry_depth_max', 'isfdraft_conc', 'u_tide',
]
var_list = (
    named_predictors
    + ['T_{:03d}'.format(level) for level in range(1, 69)]
    + ['S_{:03d}'.format(level) for level in range(1, 69)]
    + ['water_column']
)

In [None]:
# For every shuffled input variable and every NN run (time tag):
#   1. read the 2D melt predicted by the NN (m ice per yr),
#   2. integrate it over each ice shelf (Gt per yr),
#   3. unless tuning_mode is set, also average it over box 1 of each shelf (m per yr),
#   4. write the per-shelf metrics to 'eval_metrics_<run>_shuffled<var>.nc'.
# After the loop, out_1D / nn_output_m_ice_per_y / new_path_output hold the values of
# the LAST combination processed; the cells below rely on them.

# Time tags (sub-folders of outputpath_melt_nn) of the NN runs to evaluate.
# Other sets used before:
#   ['20220427-1059','20220427-1021','20220427-1042','20220427-1051']
#   ['20220427-1058','20220427-0957','20220427-1002']
timetag_list = ['20220427-1051']

# tuning_mode=True skips the box-1 diagnostics and keeps only the integrated melt.
tuning_mode = False
nisf_list = file_isf.Nisf

# The box-1 mask only depends on the box files, so it is built once here rather
# than once per (variable, time tag) iteration.
if box_charac_2D and box_charac_1D:
    box_loc_config2 = box_charac_2D['box_location'].sel(box_nb_tot=box_charac_1D['nD_config'].sel(config=2))
    # drop_vars replaces the deprecated DataArray.drop for removing a coordinate
    box1 = box_loc_config2.where(box_loc_config2==1).isel(Nisf=0).drop_vars('Nisf')
elif not tuning_mode:
    # Fail early and explicitly: without this, the loop would stop much later with a
    # NameError on box1 (or silently reuse a stale box1 left in the kernel).
    raise ValueError('You have not given me the 2D and 1D box characteristics! :( ')

for shuff_var in var_list:
    print(shuff_var)

    for timetag in timetag_list:
        print(timetag)

        new_path_output = outputpath_melt_nn+timetag+'/'
        nn_output_m_ice_per_y = xr.open_dataset(new_path_output+'NN_melt_predicted_reference_m_ice_per_yr_'+nemo_run0+'_shuffled'+shuff_var+'.nc')

        # Align the prediction with the ice-shelf mask grid once per file; this does
        # not depend on the ice shelf, so it does not belong in the inner loop.
        predicted_melt_2D = nn_output_m_ice_per_y['predicted_melt'].reindex_like(file_isf['ISF_mask'])

        if verbose:
            time_start = time.time()
            print('WELCOME! AS YOU WISH, I WILL COMPUTE THE EVALUATION METRICS FOR '+str(len(nisf_list))+' ICE SHELVES')
            list_loop = tqdm(nisf_list)
        else:
            list_loop = nisf_list

        melt1D_Gt_per_yr_list = []
        if not tuning_mode:
            melt1D_myr_box1_list = []

        for kisf in list_loop:
            # Restrict geometry and predicted melt to the points of ice shelf kisf
            geometry_isf_2D = dfmt.choose_isf(geometry_info_2D,isf_stack_mask, kisf)
            melt_rate_2D_isf_m_per_y = dfmt.choose_isf(predicted_melt_2D, isf_stack_mask, kisf)

            # Integrated melt: melt rate * weighted cell area, summed over the shelf,
            # times rho_i / 10**12 (assumes rho_i is the ice density in kg m-3 from the
            # constants module, giving Gt per yr - TODO confirm)
            melt_rate_1D_isf_Gt_per_y = (melt_rate_2D_isf_m_per_y * geometry_isf_2D['grid_cell_area_weighted']).sum(dim=['mask_coord']) * rho_i / 10**12
            melt1D_Gt_per_yr_list.append(melt_rate_1D_isf_Gt_per_y)

            if not tuning_mode:
                # Mean melt over box 1 only, weighted by the ice-shelf concentration
                box_loc_config_stacked = dfmt.choose_isf(box1, isf_stack_mask, kisf)
                param_melt_2D_box1_isf = melt_rate_2D_isf_m_per_y.where(np.isfinite(box_loc_config_stacked))

                melt_rate_1D_isf_myr_box1_mean = dfmt.weighted_mean(param_melt_2D_box1_isf,['mask_coord'], geometry_isf_2D['isfdraft_conc'])
                melt1D_myr_box1_list.append(melt_rate_1D_isf_myr_box1_mean)

        # Assemble the per-shelf results along 'Nisf'
        melt1D_Gt_per_yr = xr.concat(melt1D_Gt_per_yr_list, dim='Nisf')
        melt1D_Gt_per_yr_ds = melt1D_Gt_per_yr.to_dataset(name='melt_1D_Gt_per_y')
        if not tuning_mode:
            melt1D_myr_box1 = xr.concat(melt1D_myr_box1_list, dim='Nisf')
            melt1D_myr_box1_ds = melt1D_myr_box1.to_dataset(name='melt_1D_mean_myr_box1')
            out_1D = xr.merge([melt1D_Gt_per_yr_ds, melt1D_myr_box1_ds])
        else:
            out_1D = melt1D_Gt_per_yr_ds

        if verbose:
            timelength = time.time() - time_start
            print("I AM DONE! IT TOOK: "+str(round(timelength,2))+" seconds.")

        out_1D.to_netcdf(new_path_output+'eval_metrics_'+nemo_run0+'_shuffled'+shuff_var+'.nc')

In [None]:
# Metrics of the last (shuffled variable, time tag) combination processed by the loop above.
out_1D

In [None]:
# Prediction obtained without shuffling any input. It has to be loaded before the
# comparison, otherwise this cell fails on a fresh kernel.
nonshuffled = xr.open_dataset(new_path_output+'NN_melt_predicted_reference_m_ice_per_yr_'+nemo_run0+'.nc')

# Minimum of (shuffled - non-shuffled) predicted melt, for the last shuffled variable
# processed by the evaluation loop.
(nn_output_m_ice_per_y['predicted_melt'] - nonshuffled['predicted_melt']).min()


In [None]:
# Evaluation metrics of the non-shuffled prediction, read from the same run folder.
nonshuffled_metrics_file = f'{new_path_output}eval_metrics_{nemo_run0}.nc'
nonshuffled_melt = xr.open_dataset(nonshuffled_metrics_file)

In [None]:
# Display the non-shuffled evaluation metrics.
nonshuffled_melt

In [None]:
##############################

In [None]:
# Reference (NEMO) integrated melt. It is loaded here because the RMSE needs it and
# the cell that defines ref_Gt only comes further down in the notebook.
NEMO_melt_rates_1D = xr.open_dataset(outputpath_melt+'melt_rates_1D_NEMO.nc')
ref_Gt = NEMO_melt_rates_1D['melt_Gt_per_y_tot']

# RMSE = sqrt(mean(error**2)): the errors must be squared BEFORE averaging.
# Squaring the mean instead yields the absolute mean bias, in which errors of
# opposite sign cancel out.
np.sqrt(((out_1D['melt_1D_Gt_per_y'] - ref_Gt)**2).mean())

In [None]:
# Inspect the NN integrated melt (Gt per yr) per ice shelf.
out_1D['melt_1D_Gt_per_y']

In [None]:
# Inspect the reference integrated melt with its dimensions transposed.
# NOTE(review): the transpose presumably matches the dimension order of out_1D - confirm.
ref_Gt.T

In [None]:
# NN against reference integrated melt. Explicit axes and axis labels are used so that
# the figure can be read on its own; the trailing ';' hides the matplotlib repr.
fig, ax = plt.subplots()
ax.scatter(ref_Gt.T, out_1D['melt_1D_Gt_per_y'], alpha=0.1)
ax.set_xlabel('Reference melt [Gt per yr]')
ax.set_ylabel('NN melt [Gt per yr]');

In [None]:
# Compare NN and reference integrated melt for a single ice shelf (Nisf = 10).
# Note: this reuses the name kisf, which is also the loop variable of the evaluation loop.
kisf = 10
out_1D['melt_1D_Gt_per_y'].sel(Nisf=kisf).plot(label='NN')
ref_Gt.sel(Nisf=kisf).plot(label='ref')
plt.legend()

In [None]:
# Reference (NEMO) diagnostics: total integrated melt per ice shelf ...
NEMO_melt_rates_1D = xr.open_dataset(f'{outputpath_melt}melt_rates_1D_NEMO.nc')
ref_Gt = NEMO_melt_rates_1D['melt_Gt_per_y_tot']

# ... and mean melt in box 1.
NEMO_box1_myr = xr.open_dataset(f'{outputpath_melt}melt_rates_box1_NEMO.nc')
ref_box1 = NEMO_box1_myr['mean_melt_box1_myr']


In [None]:
# List the ice-shelf indices (Nisf) present in the NN metrics.
out_1D['melt_1D_Gt_per_y'].Nisf

In [None]:
# Plot the NN-minus-reference integrated melt, with the plotted range capped at vmax=40.
(out_1D['melt_1D_Gt_per_y'] - ref_Gt).plot(vmax=40)

In [None]:
# NN integrated melt for the ice shelf with Nisf = 66.
out_1D['melt_1D_Gt_per_y'].sel(Nisf=66).plot()

In [None]:
# Time-mean predicted melt, restricted to the grid cells where ISF_mask equals 10.
nn_output_m_ice_per_y['predicted_melt'].where(file_isf['ISF_mask']==10, drop=True).mean('time').plot()

In [None]:
# Time-mean reference melt stored in the same file, for the same ice shelf (ISF_mask == 10).
nn_output_m_ice_per_y['reference_melt'].where(file_isf['ISF_mask']==10, drop=True).mean('time').plot()