In [None]:
"""
Created on Fri Oct 21 16:58 2022

Check whether, with training, I can match latitude/longitude (latlon) exactly

Author: Clara Burgard
"""

In [None]:
import numpy as np
import xarray as xr
import pandas as pd
from tqdm.notebook import trange, tqdm
import glob
import datetime
import time
import sys

import tensorflow as tf
from tensorflow import keras
import basal_melt_neural_networks.model_functions as modf
import basal_melt_neural_networks.prep_input_data as indat

In [None]:
######### READ IN OPTIONS

# Cross-validation fold: index of the time block / ice shelf that is left out.
# The data-loading cell expects exactly one of the two to be > 0.
tblock_out = 1
isf_out = 0

# Network size handed to modf.get_model
# choices: 'mini', 'small', 'medium', 'large', 'extra_large'
mod_size = 'medium'

# choices: extrap, whole, thermocline
TS_opt = 'extrap'

# Entry selected along the 'norm_method' dimension of the input files
# choices: std, interquart, minmax
norm_method = 'std'

In [None]:
######### READ IN DATA

inputpath_data = '/bettik/burgardc/DATA/NN_PARAM/interim/INPUT_DATA/'
outputpath_nn_models = '/bettik/burgardc/DATA/NN_PARAM/interim/NN_MODELS/'
outputpath_doc = '/bettik/burgardc/SCRIPTS/basal_melt_neural_networks/custom_doc/'

tblock_dim = range(1,14)
isf_dim = [10,11,12,13,18,22,23,24,25,30,31,33,38,39,40,42,43,44,45,47,48,51,52,53,54,55,58,61,65,66,69,70,71,73,75]

# Choose the folder the trained model is written to, depending on which
# kind of cross-validation fold (time block or ice shelf) is left out.
if (tblock_out > 0) and (isf_out == 0):
    path_model = outputpath_nn_models+'CV_TBLOCK/'
elif (isf_out > 0) and (tblock_out == 0):
    path_model = outputpath_nn_models+'CV_ISF/'
else:
    # Stop here rather than only printing: otherwise path_model stays
    # undefined (or keeps a stale value from an earlier run of this cell)
    # and the problem only shows up when the model is saved.
    raise ValueError(
        "I do not know what to do with both tblock and isf left out! "
        "(tblock_out=" + str(tblock_out) + ", isf_out=" + str(isf_out) + ")"
    )

inputpath_CVinput = inputpath_data+'EXTRAPOLATED_ISFDRAFT_CHUNKS_CV/'

# The four files of one cross-validation fold share the same name suffix.
cv_suffix = '_CV_noisf'+str(isf_out).zfill(3)+'_notblock'+str(tblock_out).zfill(3)+'.nc'

input_data_train_norm = xr.open_dataset(inputpath_CVinput + 'train_data' + cv_suffix)
input_data_val_norm = xr.open_dataset(inputpath_CVinput + 'val_data' + cv_suffix)
latlon_train_norm = xr.open_dataset(inputpath_CVinput + 'trainlatlon_data' + cv_suffix)
latlon_val_norm = xr.open_dataset(inputpath_CVinput + 'vallatlon_data' + cv_suffix)

In [None]:
## prepare input and target

def _select_norm_array(ds, vars_to_drop):
    """Drop `vars_to_drop` from `ds`, pick the configured `norm_method`
    entry and return the remaining variables stacked into one loaded DataArray."""
    reduced = ds.drop_vars(vars_to_drop)
    selected = reduced.sel(norm_method=norm_method)
    return selected.to_array().load()

# variables removed from the input files before they are used as predictors
non_predictor_vars = ['melt_m_ice_per_y','theta_in','salinity_in']
# variable removed from the latlon files before they are used as targets
non_target_vars = ['salinity_in']

x_train_norm = _select_norm_array(input_data_train_norm, non_predictor_vars)
y_train_norm = _select_norm_array(latlon_train_norm, non_target_vars)

x_val_norm = _select_norm_array(input_data_val_norm, non_predictor_vars)
y_val_norm = _select_norm_array(latlon_val_norm, non_target_vars)

In [None]:
######### TRAIN THE MODEL

# training settings
activ_fct = 'relu' #LeakyReLU
epoch_nb = 100
batch_siz = 512

# length of the leading axis of the predictor array
# (presumably one entry per input variable -- confirm)
input_size = x_train_norm.values.shape[0]

# last argument is 2: presumably the number of outputs (latitude, longitude) -- confirm in modf.get_model
model = modf.get_model(mod_size, input_size, activ_fct, 2)

# halve the learning rate when the validation loss stops improving
reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(
    monitor='val_loss',
    factor=0.5,
    patience=3,
    min_lr=1e-7,
    min_delta=5e-4,
)

# stop after 10 epochs without improvement and go back to the best weights
early_stop = tf.keras.callbacks.EarlyStopping(
    monitor="val_loss", patience=10, verbose=0, mode="auto",
    baseline=None, restore_best_weights=True,
)

# wall-clock bookkeeping around the fit
time_start = time.time()
time_start0 = datetime.datetime.now()
print(time_start0)

# arrays are transposed before being handed to Keras
history = model.fit(
    x_train_norm.T.values,
    y_train_norm.T.values,
    epochs=epoch_nb,
    batch_size=batch_siz,
    validation_data=(x_val_norm.T.values, y_val_norm.T.values),
    callbacks=[reduce_lr, early_stop],
)

time_end = time.time()
timelength = time_end - time_start

time_end0 = datetime.datetime.now()
print(time_end0)

In [None]:
# Save the trained network. The file name is built from mod_size so it always
# matches the size that was actually trained (for mod_size = 'medium' this is
# the same 'model_medium_latlon.h5' as before).
model.save(path_model + 'model_' + mod_size + '_latlon.h5')


In [None]:
# Predict on the validation predictors and wrap the result in a DataArray
# whose first dimension carries the same 'index' coordinate as the input.
y_out_norm = model.predict(x_val_norm.T.values)
y_out_norm_xr = xr.DataArray(data=y_out_norm.squeeze()).rename({'dim_0': 'index'})
y_out_norm_xr = y_out_norm_xr.assign_coords({'index': x_val_norm.index})

# Normalisation metrics of the latlon targets for the selected method.
# drop_vars replaces the deprecated Dataset.drop and matches the other cells.
norm_metrics_file = xr.open_dataset(inputpath_CVinput + 'metricslatlon_norm_CV_noisf'+str(isf_out).zfill(3)+'_notblock'+str(tblock_out).zfill(3)+'.nc')
norm_metrics = norm_metrics_file.sel(norm_method=norm_method).drop_vars('norm_method').to_dataframe()

# Select both columns with a list: a bare 'latitude','longitude' key is a
# tuple, which pandas looks up as one single column label (KeyError).
latlon_cols = ['latitude','longitude']

# denormalise the output
# NOTE(review): `pp` is not imported in the cells visible here -- the module
# providing denormalise_vars has to be imported in the import cell, otherwise
# this call raises NameError on a fresh kernel.
y_out = pp.denormalise_vars(y_out_norm_xr,
                         norm_metrics[latlon_cols].loc['mean_vars'],
                         norm_metrics[latlon_cols].loc['range_vars'])

In [None]:
# Build pandas Series for prediction and reference on the index of df_nrun,
# convert both to xarray and merge them into one dataset sorted along 'y'.
# NOTE(review): `df_nrun` is not defined in any cell visible here -- on a
# fresh kernel this cell raises NameError unless it is defined elsewhere.
# NOTE(review): the names refer to melt ('predicted_melt', 'melt_m_ice_per_y')
# while the model above is trained on the latlon targets -- presumably carried
# over from another notebook; confirm the intended comparison.
y_out_pd_s = pd.Series(y_out.values,index=df_nrun.index,name='predicted_melt') 
y_target_pd_s = pd.Series(df_nrun['melt_m_ice_per_y'].values,index=df_nrun.index,name='reference_melt') 

# put some order in the file
y_out_xr = y_out_pd_s.to_xarray()
y_target_xr = y_target_pd_s.to_xarray()
# merge prediction and reference into one dataset, sorted by the 'y' coordinate
y_to_compare = xr.merge([y_out_xr.T, y_target_xr.T]).sortby('y')