# Train a neural network to predict total latent heat flux given only forcing data

In [None]:
%pylab inline
%load_ext autoreload
%autoreload 2
%reload_ext autoreload
import os
import pandas as pd
from glob import glob
import xarray as xr
import seaborn as sns
from IPython.display import SVG
from tqdm.keras import TqdmCallback
import tensorflow as tf
import tensorflow.keras.backend as K
from tensorflow import keras
from tensorflow.keras.utils import plot_model, model_to_dot
from tensorflow.keras import layers
import dask.dataframe as dd
from tensorflow.keras.callbacks import Callback, EarlyStopping

from sklearn.preprocessing import FunctionTransformer
from sklearn.preprocessing import StandardScaler
from sklearn.preprocessing import MinMaxScaler
from sklearn.pipeline import Pipeline

sns.set_context('talk')
# matplotlib 3.6 renamed its bundled seaborn styles to 'seaborn-v0_8-*' and
# 3.8 removed the old names, so pick whichever name this install provides.
_bright_style = ('seaborn-v0_8-bright'
                 if 'seaborn-v0_8-bright' in mpl.style.available
                 else 'seaborn-bright')
mpl.style.use(_bright_style)
mpl.rcParams['figure.figsize'] = (18, 12)
def cube(x):
    """Return ``x`` raised element-wise to the third power (via ``np.power``)."""
    exponent = 3
    return np.power(x, exponent)

# Multi-GPU data parallelism is switched off; uncomment to re-enable it.
# strategy = tf.distribute.MirroredStrategy()
# print('Number of devices: {}'.format(strategy.num_replicas_in_sync))

# Keras default float type, set once for the whole notebook.
dtype = "float32"
K.set_floatx(dtype)

In [None]:
# os.listdir returns entries in arbitrary (filesystem-dependent) order. Sort it
# so the seeded shuffle below yields the same site order, and therefore the
# same CV folds, on every machine.
sites = sorted(os.listdir('../sites'))
bad_sites = []
sim_sites = [s for s in sites if s not in bad_sites]

seed = 50334
# Seeds NumPy's global RNG; the row shuffle in the fold-assembly cell also
# draws from it.
np.random.seed(seed)
np.random.shuffle(sim_sites)
len(sim_sites)

In [None]:
# Forcing dataset per site: first HRU selected (and the hru dim dropped),
# loaded into memory, then the file is closed by the context manager.
site_dict = {}
for site in sim_sites:
    with xr.open_dataset(f'../sites/{site}/forcings/{site}.nc') as ds:
        site_dict[site] = ds.isel(hru=0, drop=True).load()

In [None]:
# Local-attributes dataset per site: first HRU selected (and the hru dim
# dropped), loaded into memory, then the file is closed by the context manager.
site_attr = {}
for site in sim_sites:
    with xr.open_dataset(f'../sites/{site}/params/local_attributes.nc') as ds:
        site_attr[site] = ds.isel(hru=0, drop=True).load()

In [None]:
# Per-site timestep output dataset: first HRU selected (and the hru dim
# dropped), loaded into memory, then the file is closed by the context manager.
# NOTE(review): despite the name, these are read from each site's
# output/template_output_*_timestep.nc, not from its forcing files.
site_forc = {}
for site in sim_sites:
    with xr.open_dataset(f'../sites/{site}/output/template_output_{site}_timestep.nc') as ds:
        site_forc[site] = ds.isel(hru=0, drop=True).load()

In [None]:
# Trial-parameter dataset per site: first HRU selected (and the hru dim
# dropped), loaded into memory, then the file is closed by the context manager.
site_parm = {}
for site in sim_sites:
    with xr.open_dataset(f'../sites/{site}/params/parameter_trial.nc') as ds:
        site_parm[site] = ds.isel(hru=0, drop=True).load()

In [None]:
nfold = 4
if len(sim_sites) % nfold != 0:
    raise ValueError(
        f'{len(sim_sites)} sites cannot be split evenly into {nfold} folds'
    )

# Row k holds the held-out test sites of fold k (sim_sites is already shuffled).
kfold_test_sites = np.array(sim_sites).reshape(nfold, -1)

# Training sites of each fold: every site not held out, kept in sim_sites
# order. A set difference would iterate in string-hash order, which Python
# randomizes per process. That would change the row order fed to the seeded
# shuffle, and hence the train/validation split, after every kernel restart.
kfold_train_sites = []
for test_sites in kfold_test_sites:
    held_out = set(test_sites)
    kfold_train_sites.append([s for s in sim_sites if s not in held_out])
kfold_train_sites = np.vstack(kfold_train_sites)

In [None]:
def etl_single_site(in_ds, attr_ds, parm_ds, use_mask=True):
    """Build scaled model inputs and targets for one site.

    Parameters
    ----------
    in_ds : dataset-like
        Indexed by variable name, each entry exposing ``.values``. Must hold
        the forcings airtemp, spechum, SWRadAtm, LWRadAtm, pptrate, airpres
        and windspd, plus the ``gap_filled`` flag and the ``Qle_cor`` /
        ``Qh_cor`` fluxes.
    attr_ds : dataset-like
        Provides the scalar ``soilTypeIndex`` and ``vegTypeIndex``.
    parm_ds : dataset-like
        Provides the scalars heightCanopyTop, vcmax_Kn, canopyWettingFactor,
        theta_sat, theta_res, laiScaleParam and rootingDepth.
    use_mask : bool, default True
        When True, keep only timesteps whose ``gap_filled`` value equals 0.

    Returns
    -------
    tuple of (inputs, targets), both float32 arrays
        ``inputs`` has 16 columns: 7 rescaled forcings, then 9 scalar site
        attributes/parameters repeated per timestep. ``targets`` has 3
        columns: Qle/500, Qh/500 and -(Qle + Qh)/500.
    """
    mask = in_ds['gap_filled'].values

    def forcing(name):
        return in_ds[name].values

    def per_timestep(ds, name, scale=None):
        # Broadcast a scalar site value to one entry per element of ``mask``,
        # optionally dividing by a fixed scale afterwards.
        column = ds[name].values[()] * np.ones_like(mask)
        return column if scale is None else column / scale

    feature_columns = [
        # Forcings, rescaled with fixed constants.
        (((forcing('airtemp') / 27.315) - 10) / 2) + 0.5,
        forcing('spechum') * 50,
        np.cbrt(forcing('SWRadAtm') / 1000),
        forcing('LWRadAtm') / (2 * 273.16),
        10 * np.cbrt(forcing('pptrate')),
        (10 - (forcing('airpres') / 10132.5)) / 2,
        forcing('windspd') / 10,
        # Time-constant site attributes and parameters (vegetation before soil).
        per_timestep(attr_ds, 'vegTypeIndex', 12),
        per_timestep(attr_ds, 'soilTypeIndex', 12),
        per_timestep(parm_ds, 'heightCanopyTop', 30),
        per_timestep(parm_ds, 'vcmax_Kn'),
        per_timestep(parm_ds, 'canopyWettingFactor'),
        per_timestep(parm_ds, 'theta_sat'),
        per_timestep(parm_ds, 'theta_res'),
        per_timestep(parm_ds, 'laiScaleParam', 3),
        per_timestep(parm_ds, 'rootingDepth', 5),
    ]

    latent = in_ds['Qle_cor'].values
    sensible = in_ds['Qh_cor'].values
    # Negated sum of the two fluxes; the loss compares it with the predicted sum.
    balance = -(latent + sensible)
    target_columns = [latent / 500, sensible / 500, balance / 500]

    features = np.column_stack(feature_columns)
    targets = np.column_stack(target_columns)

    if use_mask:
        keep = mask == 0
        features, targets = features[keep], targets[keep]
    return features.astype(np.float32), targets.astype(np.float32)

In [None]:
# Fraction of each fold's shuffled training rows held back for validation.
validation_frac = 0.2

all_train_input = []
all_valid_input = []
all_train_output = []
all_valid_output = []

for i, train_set in enumerate(kfold_train_sites):
    print(i)
    # -----------------------------------------------
    # load in data, transform, and split for training
    # -----------------------------------------------
    train_data = [etl_single_site(site_dict[s], site_attr[s], site_parm[s]) for s in train_set]
    train_input = np.vstack([td[0] for td in train_data])
    train_output = np.vstack([td[1] for td in train_data])

    # Check every row before splitting so NaNs that would land in the
    # validation part are caught as well. Raise explicitly (not assert) so the
    # check also runs under `python -O`.
    n_nan = int(np.isnan(train_input).sum() + np.isnan(train_output).sum())
    if n_nan:
        raise ValueError(f'fold {i}: {n_nan} NaN values in the ETL inputs/outputs')

    # Shuffle rows with NumPy's globally seeded RNG.
    index_shuffle = np.arange(train_output.shape[0])
    np.random.shuffle(index_shuffle)
    train_input = train_input[index_shuffle]
    train_output = train_output[index_shuffle]

    # NOTE(review): rows are split at random, so timesteps from the same sites
    # appear in both train and validation; val_loss is not site-independent.
    validation_start_idx = int(train_output.shape[0] * (1 - validation_frac))
    train_input, valid_input = train_input[:validation_start_idx], train_input[validation_start_idx:]
    train_output, valid_output = train_output[:validation_start_idx], train_output[validation_start_idx:]

    all_train_input.append(train_input)
    all_valid_input.append(valid_input)
    all_train_output.append(train_output)
    all_valid_output.append(valid_output)
    

In [None]:
class LRHistory(keras.callbacks.Callback):
    """Keras callback that records the optimizer's learning rate once per epoch.

    After each epoch the value returned by the optimizer's ``_decayed_lr`` is
    appended to ``self.lr``. The list is reset at the start of every ``fit``.
    """

    def on_train_begin(self, logs=None):
        # Fresh record for each training run (``None`` default avoids a shared
        # mutable ``{}`` across calls).
        self.lr = []

    def on_epoch_end(self, batch, logs=None):
        # The first argument receives the epoch index (name kept for
        # compatibility). Read the optimizer from self.model, the model this
        # callback is attached to, instead of whatever a global ``model``
        # variable happens to hold.
        # NOTE(review): ``_decayed_lr`` is private optimizer API (leading
        # underscore) and may not exist on newer TF optimizer classes; confirm
        # against the installed TF version.
        self.lr.append(self.model.optimizer._decayed_lr(np.float32).numpy())
        
def mse_eb(y_true, y_pred):
    """MSE on the first two outputs plus an energy-balance penalty.

    Column 2 of ``y_true`` holds the negated sum of the two true fluxes, so
    ``sum(pred[:, :2]) + y_true[:, 2]`` is the gap between the predicted and
    true flux sums. Its squared batch mean, divided by 10, is added to the
    per-sample MSE. The third output column is not used by this loss.
    """
    pred_fluxes = y_pred[:, :2]
    per_sample_mse = K.mean(K.square(y_true[:, :2] - pred_fluxes), axis=-1)

    imbalance = K.sum(pred_fluxes, axis=-1) + y_true[:, 2]
    balance_penalty = K.mean(K.square(imbalance)) / 10

    return per_sample_mse + balance_penalty

In [None]:
# Weight converter (third-party). Imported once, before training, so a missing
# package fails fast instead of after the first fold has trained.
from KerasWeightsProcessing.convert_weights import h5_to_txt

# -----------------------------------------------
# Model hyperparameters (identical for every fold)
# -----------------------------------------------
loss = mse_eb
activation = 'tanh'
width = 48
dropout_rate = 0.1
epochs = 200
batch_size = 48 * 7
learning_rate = 1.25e-2
decay_rate = learning_rate / (epochs * epochs)

# Make sure the output directory exists before the first save.
os.makedirs('../models', exist_ok=True)

all_hist = []

fold_data = zip(all_train_input, all_train_output, all_valid_input, all_valid_output)
for i, (train_input, train_output, valid_input, valid_output) in enumerate(fold_data):
    # Fresh optimizer per fold so no optimizer state (e.g. momentum) carries
    # over from the previous fold.
    optimizer = keras.optimizers.SGD(momentum=0.8, learning_rate=learning_rate, decay=decay_rate)

    # -----------------------------------------------
    # Define model structure
    # -----------------------------------------------
    model = keras.Sequential([
            layers.Dense(width, activation=activation, input_shape=train_input[0].shape),
            layers.Dropout(dropout_rate),
            layers.Dense(width, activation=activation),
            layers.Dense(width, activation=activation),
            layers.Dense(width, activation=activation),
            layers.Dense(width, activation=activation),
            layers.Dropout(dropout_rate),
            layers.Dense(width, activation=activation),
            layers.Dense(3, activation='linear')
        ])
    model.compile(loss=loss, optimizer=optimizer)

    # -----------------------------------------------
    # Train the model
    # -----------------------------------------------
    # NOTE(review): EarlyStopping is used without restore_best_weights, so the
    # model saved below is the final-epoch one (up to `patience` epochs past the
    # best val_loss). Confirm that is intended.
    history = model.fit(
        train_input, train_output,
        validation_data=(valid_input, valid_output),
        batch_size=batch_size, epochs=epochs, shuffle=True, verbose=0,
        callbacks=[TqdmCallback(verbose=1), EarlyStopping(monitor='val_loss', patience=5), LRHistory()])
    all_hist.append(history)

    # -----------------------------------------------
    # Save the history, the model, and its text-format weights
    # -----------------------------------------------
    hist_df = pd.DataFrame(history.history)
    hist_csv_file = f'../models/history_fluxes_qle_{i}.csv'
    # Pass the path so pandas opens the file itself (avoids newline
    # translation issues of a text-mode handle on some platforms).
    hist_df.to_csv(hist_csv_file)

    h5_file = f'../models/train_fluxes_qle_set_{i}.h5'
    model.save(h5_file)
    h5_to_txt(
        weights_file_name=h5_file,
        output_file_name=f'../models/train_fluxes_qle_set_{i}.txt'
    )
