In [None]:
"""
Created on Fri Sep 09 14:24 2022

Convert the "raw output" of the model to melt in Gt per year, in order to ultimately compute the RMSE from the cross-validation results.

Author: @claraburgard

"""

In [None]:
import datetime
import glob
import itertools
import os, sys
import time
from contextlib import redirect_stdout

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import tensorflow as tf
import xarray as xr
from tensorflow import keras
from tqdm.notebook import trange, tqdm

from basal_melt_neural_networks.constants import *
import basal_melt_neural_networks.diagnostic_functions as diag
import basal_melt_neural_networks.data_formatting as dfmt
import basal_melt_neural_networks.postprocessing_functions as pp
from basal_melt_param.constants import *

DEFINE OPTIONS

In [None]:
# Options that select which model / input variant the cross-validation
# metrics below are computed for (all three are passed to
# pp.compute_crossval_metric_1D_for_1CV).
mod_size = 'medium'    # choices: 'mini', 'small', 'medium', 'large', 'extra_large'
TS_opt = 'extrap'      # choices: 'extrap', 'whole', 'thermocline'
norm_method = 'std'    # choices: 'std', 'interquart', 'minmax'

In [None]:
# Time blocks 1..13 and the ice-shelf IDs iterated over by the
# cross-validation loops below.
tblock_dim = range(1, 13 + 1)
isf_dim = [
    10, 11, 12, 13, 18, 22, 23, 24, 25, 30,
    31, 33, 38, 39, 40, 42, 43, 44, 45, 47,
    48, 51, 52, 53, 54, 55, 58, 61, 65, 66,
    69, 70, 71, 73, 75,
]

DEFINE INPUT DATA PATHS

In [None]:
# Root directory of the neural-network input data; the chunked CV and
# original-data directories selected below are sub-directories of it.
inputpath_data_nn = ('/bettik/burgardc/DATA/NN_PARAM/'
                     'interim/INPUT_DATA/')


In [None]:
# Select the chunked input directories for the chosen T/S option. Both paths
# share one base name per option:
#   inputpath_CVinput -> '<base>_CV/'  (cross-validation inputs)
#   path_orig_data    -> '<base>/'     (original data)
# Deriving both from one table guarantees every option defines BOTH paths.
ts_dir_base = {
    'extrap': 'EXTRAPOLATED_ISFDRAFT_CHUNKS',
    'whole': 'WHOLE_PROF_CHUNKS',
    # NOTE(review): 'THERMOCLINE_CHUNKS/' as the original-data directory is
    # inferred from the naming pattern of the other options — confirm it exists.
    'thermocline': 'THERMOCLINE_CHUNKS',
}

if TS_opt not in ts_dir_base:
    # Fail here rather than with a NameError (or a stale path) in a later cell.
    raise ValueError(f"Unknown TS_opt {TS_opt!r}; expected one of {sorted(ts_dir_base)}")

inputpath_CVinput = inputpath_data_nn + ts_dir_base[TS_opt] + '_CV/'
path_orig_data = inputpath_data_nn + ts_dir_base[TS_opt] + '/'

APPLY MODEL AND COMPUTE CROSS-VALIDATION METRICS

In [None]:
#### CV over time
# For every time block in `tblock_dim`, compute the 1D cross-validation metrics
# with that block as `tblock_out` and isf_out = 0 (presumably "no ice shelf held
# out" — confirm in pp.compute_crossval_metric_1D_for_1CV). The results of all
# consecutive time blocks belonging to the same NEMO run are concatenated along
# 'time' and written to one netCDF file per run.
# Uses `tblock_dim`, `isf_dim`, the input paths and the options defined above.
outputpath_melt_nn = '/bettik/burgardc/DATA/NN_PARAM/processed/MELT_RATE/CV_TBLOCKS/'

isf_out = 0
res_all_list = []
# groupby() visits the consecutive time blocks of one NEMO run together, so each
# run is saved exactly once, after its last block, whatever run block 1 is in.
for nemo_run, tblocks_of_run in itertools.groupby(tblock_dim, key=pp.identify_nemo_run_from_tblock):
    res_all_list = []
    for tblock_out in tqdm(list(tblocks_of_run), desc=nemo_run):
        res_all = pp.compute_crossval_metric_1D_for_1CV(tblock_out,isf_out,tblock_dim,isf_dim,inputpath_CVinput,path_orig_data,norm_method,TS_opt,mod_size,drop_vars=['melt_m_ice_per_y','isf_area','entry_depth_max'],verbose=False)
        res_all_list.append(res_all)

    print('I AM SAVING RESULTS FOR NEMO RUN '+nemo_run)
    res_all_CV = xr.concat(res_all_list, dim='time')
    res_all_CV.to_netcdf(outputpath_melt_nn + 'evalmetrics_1D_CV_'+mod_size+'testwoconstants_TS'+TS_opt+'_norm'+norm_method+'_'+nemo_run+'.nc')


In [None]:
#### CV over shelves
# For every NEMO run, compute the 1D cross-validation metrics with each ice
# shelf in `isf_dim` as `isf_out` and tblock_out = 0, passing the time blocks of
# that run (`t_list`) to pp.compute_crossval_metric_1D_for_1CV. The per-shelf
# results are concatenated along 'Nisf' and written to one netCDF file per run.
# After the loop, `res_all_list`, `res_all_CV` and `res_all` hold the results of
# the last NEMO run processed (later cells read them).
outputpath_melt_nn = '/bettik/burgardc/DATA/NN_PARAM/processed/MELT_RATE/CV_ISF/'

tblock_out = 0
res_all_list = []
# groupby() collects the consecutive time blocks of each NEMO run, so every run
# is processed and saved exactly once, with its complete list of time blocks.
for nemo_run, tblocks_of_run in itertools.groupby(tblock_dim, key=pp.identify_nemo_run_from_tblock):
    t_list = list(tblocks_of_run)
    print(nemo_run, t_list)

    res_all_list = []
    for isf_out in tqdm(isf_dim, desc=nemo_run):
        res_all = pp.compute_crossval_metric_1D_for_1CV(tblock_out,isf_out,t_list,isf_dim,inputpath_CVinput,path_orig_data,norm_method,TS_opt,mod_size,drop_vars=['melt_m_ice_per_y','isf_area','entry_depth_max'],verbose=False)
        res_all_list.append(res_all)

    res_all_CV = xr.concat(res_all_list, dim='Nisf')
    print('I AM SAVING RESULTS FOR NEMO RUN '+nemo_run)
    res_all_CV.to_netcdf(outputpath_melt_nn + 'evalmetrics_1D_CV_'+mod_size+'testwoconstants_TS'+TS_opt+'_norm'+norm_method+'_'+nemo_run+'.nc')

            

In [None]:
# Re-assemble, along 'Nisf', whatever results are currently held in
# `res_all_list` — after the shelf-CV cell this is the last NEMO run it saved.
# NOTE(review): depends on execution order; if another cell rebinds
# `res_all_list` afterwards, this produces something different.
res_all_CV = xr.concat(objs=res_all_list, dim="Nisf")

In [None]:
# Scatter of predicted vs. reference melt (metric 'Gt') for one ice shelf, with
# the 1:1 line for comparison.
# NOTE(review): reads `res_all`, i.e. the result of the last call made by the
# most recently run cross-validation loop. Whether it contains Nisf=66 depends
# on pp.compute_crossval_metric_1D_for_1CV; `res_all_CV` may be the intended
# object — confirm.
isf_plot = 66
ref_melt = res_all['reference_melt'].sel(metrics='Gt', Nisf=isf_plot)
pred_melt = res_all['predicted_melt'].sel(metrics='Gt', Nisf=isf_plot)

# Span the 1:1 line over the plotted data instead of a hard-coded interval.
lo = min(float(ref_melt.min()), float(pred_melt.min()))
hi = max(float(ref_melt.max()), float(pred_melt.max()))

fig, ax = plt.subplots()
ax.scatter(ref_melt, pred_melt)
ax.plot([lo, hi], [lo, hi], 'k-')
ax.set(xlabel='Reference melt [Gt/yr]', ylabel='Predicted melt [Gt/yr]', title=f'Ice shelf {isf_plot}');