In [None]:
"""
Created on Fri Sep 09 14:24 2022

Convert the "raw output" of the model into melt in Gt per year, in order to compute the RMSE
from the cross-validation results obtained with different sets of input variables

Author: @claraburgard

"""

In [None]:
import datetime
import glob
import os
import sys
import time
from contextlib import redirect_stdout

import numpy as np
import pandas as pd
import xarray as xr
from tqdm.notebook import trange, tqdm
import matplotlib as mpl
import matplotlib.pyplot as plt
import seaborn as sns

import tensorflow as tf
from tensorflow import keras

from basal_melt_neural_networks.constants import *
import basal_melt_neural_networks.diagnostic_functions as diag
import basal_melt_neural_networks.data_formatting as dfmt
import basal_melt_neural_networks.postprocessing_functions as pp
from basal_melt_param.constants import *

import basal_melt_neural_networks.model_functions as modf

## Define options

In [None]:
# Model size: one of 'mini', 'small', 'medium', 'large', 'extra_large'
mod_size = 'small'
# Treatment of the T/S inputs: 'extrap', 'whole' or 'thermocline'
TS_opt = 'extrap'
# Normalisation method of the inputs: 'std', 'interquart' or 'minmax'
norm_method = 'std'
# Experiment name, used in the model/output file names. Other experiments:
# 'onlyTSdraftandslopeandconc', 'onlyTSdraftandslope', 'TSdraftbotandiceddandwcd',
# 'onlyTSisfdraft', 'TSdraftbotandiceddandwcdreldGL', 'TSTfdGLdIFwcd', 'TSdraftslopereldGL'
exp_name = 'newbasic2'

# Fail early on a typo instead of with an undefined path or missing file further down.
assert mod_size in ('mini', 'small', 'medium', 'large', 'extra_large'), mod_size
assert TS_opt in ('extrap', 'whole', 'thermocline'), TS_opt
assert norm_method in ('std', 'interquart', 'minmax'), norm_method

In [None]:
# Time blocks 1..13 and the IDs of the ice shelves looped over in the cross-validation.
tblock_dim = range(1, 14)
isf_dim = [
    10, 11, 12, 13, 18, 22, 23, 24, 25, 30,
    31, 33, 38, 39, 40, 42, 43, 44, 45, 47,
    48, 51, 52, 53, 54, 55, 58, 61, 65, 66,
    69, 70, 71, 73, 75,
]

## Read in data

In [None]:
# Root directory of the neural-network input data (CSV chunks, CV splits, normalisation metrics).
inputpath_data_nn = '/bettik/burgardc/DATA/NN_PARAM/interim/INPUT_DATA/'
# Directory of the cross-validation evaluation metrics (one netCDF file per NEMO run).
# Defined here so that the cells reading these files also work on a fresh kernel.
outputpath_melt_nn = '/bettik/burgardc/DATA/NN_PARAM/processed/MELT_RATE/CV_ISF/'


In [None]:
# Directories of the cross-validation splits (inputpath_CVinput) and of the original
# chunks (path_orig_data) for the chosen T/S treatment.
if TS_opt == 'extrap':
    inputpath_CVinput = inputpath_data_nn+'EXTRAPOLATED_ISFDRAFT_CHUNKS_CV/'
    path_orig_data = inputpath_data_nn+'EXTRAPOLATED_ISFDRAFT_CHUNKS/'
elif TS_opt == 'whole':
    inputpath_CVinput = inputpath_data_nn+'WHOLE_PROF_CHUNKS_CV/'
    path_orig_data = inputpath_data_nn+'WHOLE_PROF_CHUNKS/'
elif TS_opt == 'thermocline':
    inputpath_CVinput = inputpath_data_nn+'THERMOCLINE_CHUNKS_CV/'
    # Previously missing (later cells failed with a NameError); follows the naming
    # pattern of the other options — TODO confirm this directory exists.
    path_orig_data = inputpath_data_nn+'THERMOCLINE_CHUNKS/'
else:
    raise ValueError(f"Unknown TS_opt: {TS_opt!r}")

In [None]:
# Peek at the input table of one chunk (ice shelf isf_out, time block tblock).
isf_out = 66
tblock = 1
df_nrun = pd.read_csv(
    f'{path_orig_data}dataframe_input_isf{isf_out:03d}_{tblock:03d}.csv',
    index_col=[0, 1, 2],
)

In [None]:
# Inspect the loaded input table (indexed by its first three CSV columns).
df_nrun

## Apply model

In [None]:
# Input table and the additional-variables table ('addvar1') of the same chunk.
tblock = 1
isf_out = 66
df_nrun_orig = pd.read_csv(
    f'{path_orig_data}dataframe_input_isf{isf_out:03d}_{tblock:03d}.csv',
    index_col=[0, 1, 2],
)
df_nrun_addvar1 = pd.read_csv(
    f'{path_orig_data}dataframe_addvar1_isf{isf_out:03d}_{tblock:03d}.csv',
    index_col=[0, 1, 2],
)

In [None]:
# Normalisation metrics of the additional variables for the CV split that leaves out
# ice shelf isf_out. In the cross-validation over shelves no time block is left out
# (tblock_out = 0, as in the CV loop further down); set it here so this cell does not
# depend on that later cell having been run first.
tblock_out = 0
norm_metrics_file_addvar1 = xr.open_dataset(inputpath_CVinput + 'metrics_norm_addvar1_CV_noisf'+str(isf_out).zfill(3)+'_notblock'+str(tblock_out).zfill(3)+'.nc')
# Drop 'salinity_in' from these metrics (Dataset.drop is deprecated in favour of drop_vars).
norm_metrics_file_addvar1 = norm_metrics_file_addvar1.drop_vars('salinity_in')

In [None]:
# Inspect the normalisation metrics of the additional variables.
norm_metrics_file_addvar1

In [None]:
tblock_dim = range(1,14)
isf_dim = [10,11,12,13,18,22,23,24,25,30,31,33,38,39,40,42,43,44,45,47,48,51,52,53,54,55,58,61,65,66,69,70,71,73,75]
#drop_variables = ['rel_dGL','water_col_depth','theta_bot','salinity_bot','dGL', 'dIF', 'bathy_metry', 'slope_bed_lon',
#       'slope_bed_lat', 'isf_area',
#       'entry_depth_max', 'isfdraft_conc', 'u_tide',
#       'melt_m_ice_per_y','slope_ice_lon', 'slope_ice_lat']

# Names of the input variables passed to the cross-validation evaluation.
if TS_opt == 'whole':

    data_val_orig_norm = xr.open_dataset(inputpath_CVinput + 'val_data_CV_norm'+norm_method+'_noisf000_notblock001.nc')

    # All temperature ('T_...') and salinity ('S_...') variables of the validation file.
    T_list = [kk for kk in data_val_orig_norm.keys() if kk.startswith('T_')]
    S_list = [kk for kk in data_val_orig_norm.keys() if kk.startswith('S_')]

    # Order: draft, lon ice slope, all T variables, all S variables, lat ice slope.
    # NOTE(review): the order presumably has to match the one used in training — do not reorder.
    var_list = ['corrected_isfdraft', 'slope_ice_lon'] + T_list + S_list + ['slope_ice_lat']
    input_list = var_list

elif TS_opt == 'extrap':
    input_list = ['dGL','dIF','corrected_isfdraft','bathy_metry','slope_bed_lon','slope_bed_lat','slope_ice_lon','slope_ice_lat','theta_in','salinity_in','T_mean', 'S_mean', 'T_std', 'S_std']

else:
    # Without this, later cells fail with a NameError on input_list.
    raise ValueError(f"No input variable list defined for TS_opt={TS_opt!r}")

In [None]:
# Number of input variables (passed as input_vars to the CV evaluation below).
len(input_list)

In [None]:
# Load the cross-validation metrics saved per NEMO run, in run order.
res_list = [
    xr.open_dataset(
        f'{outputpath_melt_nn}evalmetrics_1D_CV_{mod_size}_{exp_name}_{TS_opt}_norm{norm_method}_{nemo}.nc'
    )
    for nemo in ('OPM006', 'OPM016', 'OPM018', 'OPM021')
]

In [None]:
# Stack the per-run results along a new 'nrun' dimension (no coordinate labels, so positional).
res_all = xr.concat(objs=res_list, dim='nrun')

In [None]:
# Prediction error (predicted minus reference) of the 'Gt' metric.
res_diff = (res_all['predicted_melt'] - res_all['reference_melt']).sel(metrics='Gt')

In [None]:
# Reference vs predicted 'Gt' melt of ice shelf 11 for the first entry along 'nrun'
# (presumably OPM006, the first run loaded above — TODO confirm), on one labelled axes.
fig, ax = plt.subplots()
res_all['reference_melt'].sel(metrics='Gt').sel(Nisf=11, nrun=0).plot(ax=ax, label='reference')
res_all['predicted_melt'].sel(metrics='Gt').sel(Nisf=11, nrun=0).plot(ax=ax, label='predicted')
ax.legend();

In [None]:
# Inspect the predicted melt for the 'Gt' metric.
res_all.predicted_melt.sel({'metrics': 'Gt'})

In [None]:
# Root-mean-square error of the 'Gt' melt, averaged over all dimensions of res_diff.
np.sqrt(np.square(res_diff).mean())

In [None]:
#### CV over shelves
# Cross-validation over ice shelves: the time blocks are grouped by NEMO run and, for
# each group, every shelf of isf_dim is left out in turn; the per-shelf metrics are
# concatenated along 'Nisf' and saved as one netCDF file per NEMO run.
outputpath_melt_nn = '/bettik/burgardc/DATA/NN_PARAM/processed/MELT_RATE/CV_ISF/'

# No time block is left out in the CV over shelves.
tblock_out = 0


def _crossval_over_shelves_and_save(t_list, nemo_run):
    """Evaluate the leave-one-shelf-out CV on the time blocks t_list and save it under nemo_run.

    Returns the list of per-shelf results and their concatenation along 'Nisf'.
    Local names are used so the notebook-level res_all / isf_out are not overwritten.
    """
    print(t_list)
    res_per_shelf = []
    for isf_left_out in tqdm(isf_dim):
        res_per_shelf.append(
            pp.compute_crossval_metric_1D_for_1CV(
                tblock_out, isf_left_out, t_list, isf_dim, inputpath_CVinput, path_orig_data,
                norm_method, TS_opt, mod_size+'_'+exp_name, 'experiments/',
                input_vars=input_list, verbose=False))
    res_cv = xr.concat(res_per_shelf, dim='Nisf')
    print('I AM SAVING RESULTS FOR NEMO RUN '+nemo_run)
    res_cv.to_netcdf(outputpath_melt_nn + 'evalmetrics_1D_CV_'+mod_size+'_'+exp_name+'_'+TS_opt+'_norm'+norm_method+'_'+nemo_run+'.nc')
    return res_per_shelf, res_cv


t_list = []
res_all_list = []
nemo_run_old = 'OPM006'

# Subset of the time blocks; iterate over tblock_dim to process every NEMO run.
for tt in range(1, 6):
    nemo_run = pp.identify_nemo_run_from_tblock(tt)
    print(nemo_run)

    # A new NEMO run starts: evaluate and save the blocks collected for the previous one.
    if nemo_run_old != nemo_run:
        _, res_all_CV = _crossval_over_shelves_and_save(t_list, nemo_run_old)
        res_all_list = []
        t_list = []

    t_list.append(tt)
    nemo_run_old = nemo_run

    # NOTE(review): the last group is only flushed when tt reaches max(tblock_dim); with a
    # truncated range such as range(1, 6) the final (possibly incomplete) group is neither
    # evaluated nor saved.
    if tt == max(tblock_dim):
        res_all_list, res_all_CV = _crossval_over_shelves_and_save(t_list, nemo_run_old)

            

In [None]:
# res_all_list is reset to [] after every run-change flush of the CV loop and only holds
# results after the final flush; concatenating an empty list raises, so guard it.
# Otherwise res_all_CV keeps the value last set (and saved) by the CV loop.
if res_all_list:
    res_all_CV = xr.concat(res_all_list, dim='Nisf')

In [None]:
# Reference vs predicted 'Gt' melt (Gt per year) of ice shelf 66, with a 1:1 line that
# spans the plotted data instead of a hard-coded range.
# NOTE(review): uses whatever res_all currently holds; the per-run concatenation loaded
# above is assumed — TODO confirm.
ref_gt = res_all['reference_melt'].sel(metrics='Gt', Nisf=66)
pred_gt = res_all['predicted_melt'].sel(metrics='Gt', Nisf=66)

fig, ax = plt.subplots()
ax.scatter(ref_gt, pred_gt)
lo = min(float(ref_gt.min()), float(pred_gt.min()))
hi = max(float(ref_gt.max()), float(pred_gt.max()))
ax.plot([lo, hi], [lo, hi], 'k-')
ax.set(xlabel='reference melt [Gt/yr]', ylabel='predicted melt [Gt/yr]', title='Ice shelf 66');