In [1]:
import tensorflow as tf
import xarray as xr
import fnmatch
import pandas as pd
import numpy as np
import os
import keras
from surgeNN import io, preprocessing
from surgeNN.io import train_predict_output_to_ds, setup_output_dirs, add_loss_to_output
from surgeNN.denseLoss import get_denseloss_weights #if starting with a clean environment, first, in terminal, do->'mamba install kdepy'
from surgeNN.evaluation import add_error_metrics_to_prediction_ds
from surgeNN.models import build_LSTM_stacked,build_LSTM_stacked_multioutput_static
from tqdm import tqdm
import itertools
import random
from functools import reduce
import gcsfs
fs = gcsfs.GCSFileSystem() #GCS filesystem handle (original note: list stores, strip zarr from filename, load)

import gc

class GC_Callback(tf.keras.callbacks.Callback):
    """Keras callback that runs Python's garbage collector after every training epoch.

    Usage: ``callbacks=[GC_Callback()]``.
    It is not certain this is strictly necessary; it is a precaution against memory piling up
    between epochs.
    """

    def on_epoch_end(self, epoch, logs=None):
        # neither the epoch index nor the logs are needed; only the collection side effect matters
        _ = gc.collect()

2025-07-10 08:24:21.742720: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  SSE4.1 SSE4.2 AVX AVX2 AVX512F AVX512_VNNI FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2025-07-10 08:24:21.824237: I tensorflow/core/util/port.cc:104] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.


In [2]:
#settings
#tide gauges to process; the second assignment overrides the full list with a quick two-site test
tgs        = ['stavanger-svg-nor-nhs.csv','wick-wic-gbr-bodc.csv','esbjerg-esb-dnk-dmi.csv',
              'immingham-imm-gbr-bodc.csv','den_helder-denhdr-nld-rws.csv', 'fishguard-fis-gbr-bodc.csv',
              'brest-822a-fra-uhslc.csv', 'vigo-vigo-esp-ieo.csv',  'alicante_i_outer_harbour-alio-esp-da_mm.csv'] #all tide gauges to process
tgs = ['den_helder-denhdr-nld-rws.csv','brest-822a-fra-uhslc.csv']
region_name = 'test' #label used in the output filenames
model_architecture = 'lstm' #label used in the output paths/filenames
predictor_degrees = 5 #extent of the predictor window around each tide gauge [degrees], converted to grid cells in the training cell

#i/o
predictor_path  = 'gs://leap-persistent/timh37/era5_predictors/3hourly'
predictand_path = '/home/jovyan/test_surge_models/input/t_tide_3h_hourly_deseasoned_predictands'
splits = 'chronological' #how to split the data: 'chronological' or 'stratified'
output_dir = os.path.join('/home/jovyan/test_surge_models/results/nns_highresmip/gesla3',splits) #results are stored per split type
store_model = 0 #set to 1 to also store the trained tensorflow/keras models
temp_freq = 3 # [hours] temporal frequency to use

#training
predictor_vars = ['msl','u10','v10'] #variables to use
n_runs = 1 #how many hyperparameter combinations to run (randomly sampled when fewer than the number of combinations)
n_iterations = 1 #how many iterations to run per hyperparameter combination
n_epochs = 1 #maximum number of training epochs
patience = 10 #early stopping patience [epochs]
loss_function = {'mse':'mse'} #{name: loss}; loss is a default tensorflow loss function string or string of custom loss function of surgeNN.losses (e.g., 'gevl({gamma})')

#splitting & stratified sampling (strat_* settings are used by the 'stratified' split)
split_fractions = [.6,.2,.2] #fractions of the data per split; NOTE(review): original comment said "train, test, val" while the code uses train/val/test -- confirm order against preprocessing.Input
strat_metric = '99pct'
strat_start_month = 7
strat_seed = 0

#hyperparameters (every combination of the values below is a candidate setting):

dl_alpha = np.array([0,3]).astype('int') #value(s) passed to compute_denseloss_weights (DenseLoss alpha)
batch_size = np.array([128]).astype('int')
n_steps = np.array([9]).astype('int') #number of input timesteps per sample
n_convlstm = np.array([1]).astype('int') #number of LSTM layers
n_convlstm_units = np.array([32]).astype('int') #units per LSTM layer
n_dense = np.array([2]).astype('int') #number of dense layers
n_dense_units = np.array([32]).astype('int') #units per dense layer
dropout = np.array([0.2])#np.array([0.1,0.2])
lrs = np.array([5e-5])#np.array([1e-5,5e-5,1e-4]) #learning rates
l2s = np.array([0.02]) #regularization strength(s), passed to the model builder as l2=
l1s = l2s #backward-compatible alias: this option used to be named l1s although it is used as L2 regularization

hyperparam_options = [batch_size, n_steps, n_convlstm, n_convlstm_units,
                n_dense, n_dense_units, dropout, lrs, l2s, dl_alpha] #order must match the unpacking in the training loop

n_static=2 #test; NOTE(review): not referenced by the active code in the cells shown here

In [4]:
### Train & evaluate one multi-site (multi-output) model per hyperparameter set, using the settings cell
setup_output_dirs(output_dir,store_model,model_architecture)

lf_name = list(loss_function.keys())[0] #name of the loss, used in the output metadata and filenames
lf = list(loss_function.values())[0]

# Strings naming custom losses (e.g. 'gevl(...)') are turned into callables with eval();
# Keras built-in loss strings such as 'mse' are not Python names, raise NameError, and are passed to Keras unchanged.
# eval() is acceptable here only because the string comes from the settings cell, never from external input.
# NOTE(review): a custom loss only resolves if its name is in scope (e.g. imported from surgeNN.losses);
# otherwise the raw string reaches Keras, which will reject it.
if isinstance(lf, str):
    try:
        lf = eval(lf)
    except (NameError, SyntaxError):
        pass

cells_per_degree = 4 #NOTE(review): presumably the predictors are on a 0.25-degree grid; confirm
n_cells = int(predictor_degrees * cells_per_degree)

### (1) Load & prepare input data
# predictors: open the window around each tide gauge, stack lat/lon into a single 'c' dimension,
# concatenate all windows and drop (lat,lon) cells that appear in more than one window
site_predictors = []
for tg in tgs:
    predictor = io.Predictor(predictor_path)
    predictor.open_dataset(tg,predictor_vars,n_cells)
    site_predictors.append(predictor.data.swap_dims({'lat_around_tg':'latitude','lon_around_tg':'longitude'}).stack(c=('latitude','longitude')))

merged_predictors = xr.concat(site_predictors,dim='c').drop_duplicates(dim='c')
merged_predictors = merged_predictors.assign_coords({'c':np.arange(len(merged_predictors.c))}) #replace the (lat,lon) multi-index by an integer index

predictors = io.Predictor(predictor_path)
predictors.data = merged_predictors
predictors.trim_years(1979,2017)
predictors.subtract_annual_means()
predictors.deseasonalize()

# predictands: load each tide gauge over the predictor period and suffix its columns with the site number
predictands = []
for t,tg in enumerate(tgs):

    predictand = io.Predictand(predictand_path)
    predictand.open_dataset(tg)
    predictand.trim_dates(predictors.data.time.isel(time=0).values,predictors.data.time.isel(time=-1).values)
    predictand.deseasonalize()
    predictand.resample_fillna(str(temp_freq)+'h')

    if t==0:
        # NOTE: predictand0 is the same object as predictand, so the rename below also applies to predictand0.data
        predictand0 = predictand

    predictand.data = predictand.data.rename(columns={"surge":"surge_"+str(t), "lon": "lon_"+str(t),"lat":"lat_"+str(t)})
    predictands.append(predictand.data)

predictands_merged = reduce(lambda left,right: pd.merge(left,right,on=['date'],how='inner'), predictands) #keep only dates present at every site

# a row counts as missing for every site if at least one column has no value there;
# the mask is computed once because setting surge columns to NaN on already-incomplete rows does not change it
surge_columns = ['surge_'+str(k) for k in range(len(tgs))]
rows_with_gaps = predictands_merged.isnull().any(axis=1)
predictands_merged.loc[rows_with_gaps, surge_columns] = np.nan

# The splitting below runs on predictand0, and its split indices are reused positionally on predictands_merged,
# so both must contain exactly the same rows; otherwise targets of different sites would silently be misaligned.
if len(predictand0.data) != len(predictands_merged):
    raise ValueError(f'{tgs[0]} has {len(predictand0.data)} rows but the merged predictands have {len(predictands_merged)}; '
                     'the inner merge dropped dates, so split indices computed on predictand0 would not match predictands_merged')

predictand0.data['surge'] = predictands_merged['surge_0']

### (2) Configure sets of hyperparameters to run with
all_settings = list(itertools.product(*hyperparam_options))
n_settings = len(all_settings)

# random search: draw n_runs distinct combinations, or run all of them when there are no more than n_runs
if n_runs<n_settings:
    selected_settings = random.sample(all_settings, n_runs)
else:
    selected_settings = all_settings


def _stack_site_targets(merged, split_idx, y_site0):
    """Return the surge targets of all sites for one split as a (n_rows, n_sites) array.

    Rows of `merged` are selected with the split's positional indices `split_idx`, column 0 is
    replaced by `y_site0` (the site-0 targets prepared by preprocessing.Input), and every row that
    is missing for at least one site is set to NaN for all sites.
    """
    y_all = merged.iloc[split_idx].filter(like='surge').to_numpy(copy=True) #own, writable copy (also under pandas Copy-on-Write)
    y_all[:,0] = y_site0
    y_all[np.isnan(y_all).any(axis=1),:] = np.nan
    return y_all


### (3) Execute training & evaluation (n_iterations * n_runs times):
for it in range(n_iterations): #for each iteration
    tg_datasets = [] #outputs of every hyperparameter set in this iteration

    for i,these_settings in enumerate(selected_settings): #for each set of hyperparameters

        (this_batch_size,this_n_steps,this_n_convlstm,this_n_convlstm_units,this_n_dense,
         this_n_dense_units,this_dropout,this_lr,this_l2,this_dl_alpha) = these_settings #same order as hyperparam_options

        # generate train, validation and test splits on site 0 (to-do: think about stratification for multiple sites)
        model_input = preprocessing.Input(predictors,predictand0) #not ideal: other sites may have a different distribution
        model_input.stack_predictor_coords()

        if splits=='stratified':
            model_input.split_stratified(split_fractions,this_n_steps,strat_start_month,strat_seed,strat_metric)
        elif splits=='chronological':
            model_input.split_chronological(split_fractions,this_n_steps)
        else:
            raise ValueError(f"unknown splits option {splits!r}; expected 'stratified' or 'chronological'")

        # replace the site-0 targets by the targets of all sites, indexed with the same split indices
        # (to-do?: set the first few timesteps to np.nan)
        model_input.y_train = _stack_site_targets(predictands_merged,model_input.idx_train,model_input.y_train)
        model_input.y_val = _stack_site_targets(predictands_merged,model_input.idx_val,model_input.y_val)
        model_input.y_test = _stack_site_targets(predictands_merged,model_input.idx_test,model_input.y_test)

        y_train_mean,y_train_sd = model_input.standardize()
        model_input.compute_denseloss_weights(this_dl_alpha)

        x_train,y_train,w_train = model_input.get_windowed_filtered_np_input('train',this_n_steps) #generate input for neural network model
        x_val,y_val,w_val       = model_input.get_windowed_filtered_np_input('val',this_n_steps)
        x_test,y_test,w_test    = model_input.get_windowed_filtered_np_input('test',this_n_steps)

        o_train,o_val,o_test = [y_train_sd * k + y_train_mean for k in [y_train,y_val,y_test]] #back-transform observations

        model = build_LSTM_stacked(this_n_convlstm, this_n_dense,np.full(this_n_convlstm,this_n_convlstm_units,dtype=int),
                                   np.full(this_n_dense,this_n_dense_units,dtype=int),(this_n_steps,len(predictors.data.c) * len(predictor_vars)),
                                   len(tgs),'lstm0',this_dropout, this_lr, lf,l2=this_l2)

        # one output (and one sample-weight vector) per site
        train_history = model.fit(x=x_train,y=list(np.transpose(y_train)),epochs=n_epochs,batch_size=this_batch_size,
                                  sample_weight=list(np.transpose(w_train)),validation_data=(x_val,list(np.transpose(y_val)),list(np.transpose(w_val))),
                                  callbacks=[tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=patience,restore_best_weights=True),GC_Callback()],verbose=2) #with numpy arrays input

        #make predictions & back-transform
        yhat_train = np.hstack(model.predict(x_train,verbose=0))*y_train_sd + y_train_mean
        yhat_val = np.hstack(model.predict(x_val,verbose=0))*y_train_sd + y_train_mean
        yhat_test = np.hstack(model.predict(x_test,verbose=0))*y_train_sd + y_train_mean

        # inactive draft of an 'lstm_static' variant, kept as a string literal (never executed)
        ''' #to be further developed
        elif model_architecture == 'lstm_static':
            
            train_conditions = [np.transpose(np.tile(np.array([predictands[k].iloc[0,2],predictands[k].iloc[0,3]])[::,np.newaxis],y_train.shape[0])) for k in np.arange(len(tgs))]
            val_conditions = [np.transpose(np.tile(np.array([predictands[k].iloc[0,2],predictands[k].iloc[0,3]])[::,np.newaxis],y_val.shape[0])) for k in np.arange(len(tgs))]
            test_conditions = [np.transpose(np.tile(np.array([predictands[k].iloc[0,2],predictands[k].iloc[0,3]])[::,np.newaxis],y_test.shape[0])) for k in np.arange(len(tgs))]
            
            model = build_LSTM_stacked_multioutput_static(this_n_convlstm, this_n_dense,(np.ones(this_n_convlstm)*this_n_convlstm_units).astype(int), 
                                  (np.ones(this_n_dense)*this_n_dense_units).astype(int),(this_n_steps,np.prod([predictors.dims[k] for k in predictors.dims if k!='time']) * len(predictor_vars)),
                                                   len(tgs),train_conditions[0].shape[-1],'lstm0',this_dropout, this_lr, lf,l2=this_l2)

            train_history = model.fit(x=[x_train]+train_conditions,y=list(np.transpose(y_train)),epochs=n_epochs,batch_size=this_batch_size,sample_weight=list(np.transpose(w_train)),validation_data=([x_val]+val_conditions,list(np.transpose(y_val)),list(np.transpose(w_val))),
                                          callbacks=[tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=patience,restore_best_weights=True),GC_Callback()],verbose=2) #with numpy arrays input
            
            yhat_train = np.hstack(model.predict([x_train]+train_conditions,verbose=0))*y_train_sd + y_train_mean
            yhat_val = np.hstack(model.predict([x_val]+val_conditions,verbose=0))*y_train_sd + y_train_mean
            yhat_test = np.hstack(model.predict([x_test]+test_conditions,verbose=0))*y_train_sd + y_train_mean
        
        #test multi-output model
        '''

        #store results into xr dataset for current settings and iteration
        ds_train = train_predict_output_to_ds(o_train,yhat_train,model_input.t_train,these_settings,tgs,model_architecture,lf_name)
        ds_val = train_predict_output_to_ds(o_val,yhat_val,model_input.t_val,these_settings,tgs,model_architecture,lf_name)
        ds_test = train_predict_output_to_ds(o_test,yhat_test,model_input.t_test,these_settings,tgs,model_architecture,lf_name)

        ds_i = xr.concat((ds_train,ds_val,ds_test),dim='split',coords='different') #concatenate results for each split
        ds_i = ds_i.assign_coords(split = ['train','val','test'])

        ds_i = add_loss_to_output(ds_i,train_history,n_epochs)
        tg_datasets.append(ds_i) #append output of current settings to list of all outputs

        if store_model:
            my_path = os.path.join(output_dir,'keras_models',model_architecture)
            my_fn = model_architecture+'_'+str(temp_freq)+'h_'+region_name+'_'+lf_name+'_hp1_i'+str(i)+'_it'

            # the trailing number counts files already saved under this prefix, so earlier models are not overwritten
            model.save(os.path.join(my_path,
             my_fn+str(len(fnmatch.filter(os.listdir(my_path),my_fn+'*')))+'.keras'))

        # free the model and Keras global state before building the next model
        del model, train_history, ds_i
        tf.keras.backend.clear_session()
        gc.collect()

    #concatenate across runs & compute statistics
    out_ds = xr.concat(tg_datasets,dim='i',coords='different')
    out_ds = add_error_metrics_to_prediction_ds(out_ds,[.95,.98,.99,.995],3) #optional third argument 'max_numT_between_isolated_extremes' to exclude extremes isolated by more than n timesteps from another extreme from evaluation (to avoid including extremes mainly due to semi-diurnal tides, see manuscript for more explanation)

    # site coordinates from the first row; bfill guards against a leading timestep without values
    out_ds = out_ds.assign_coords(lon = ('tg',np.array(predictands_merged.filter(like='lon').bfill().iloc[0])))
    out_ds = out_ds.assign_coords(lat = ('tg',np.array(predictands_merged.filter(like='lat').bfill().iloc[0])))

    if len(n_steps) == 1: #if n_steps is constant across i, obs doesn't need to have i as a dimension. Saves storage.
        out_ds['o'] = out_ds['o'].isel(i=0,drop=True)

    out_ds.attrs['temp_freq'] = temp_freq
    out_ds.attrs['n_cells'] = n_cells
    out_ds.attrs['n_epochs'] = n_epochs
    out_ds.attrs['patience'] = patience
    out_ds.attrs['loss_function'] = lf_name
    out_ds.attrs['split_fractions'] = split_fractions
    out_ds.attrs['splits'] = splits #records whether the stratification settings below were actually used
    out_ds.attrs['stratification'] = strat_metric+'_'+str(strat_start_month)+'_'+str(strat_seed)

    my_path = os.path.join(output_dir,'performance',model_architecture)
    my_fn = model_architecture+'_'+str(temp_freq)+'h_'+region_name+'_'+lf_name+'_hp1_ndeg'+str(predictor_degrees)+'_it'

    # NOTE(review): saving is disabled, so out_ds is currently not written to disk; re-enable for real runs
    #out_ds.to_netcdf(os.path.join(my_path,my_fn+str(len(fnmatch.filter(os.listdir(my_path),my_fn+'*')))+'.nc'),mode='w')
    

  data.coords.update(results)


Epoch 1/2
526/526 - 8s - loss: 4.6793 - dense_6_loss: 2.2865 - dense_7_loss: 2.1680 - val_loss: 1.8614 - val_dense_6_loss: 1.0736 - val_dense_7_loss: 0.5186 - 8s/epoch - 15ms/step
Epoch 2/2
526/526 - 6s - loss: 2.9247 - dense_6_loss: 1.3792 - dense_7_loss: 1.2451 - val_loss: 1.4564 - val_dense_6_loss: 0.7685 - val_dense_7_loss: 0.4148 - 6s/epoch - 12ms/step
