# Median NSE of an ensemble of LSTMs

For example, suppose the model has been trained 10 times, once per random seed. We then:

(1) compute, for each basin, the median simulated discharge across the 10 runs at every time step;

(2) compute the NSE of each basin from this median simulated discharge;

(3) take the median of these per-basin NSE values across all basins.

In [1]:
# Import necessary packages
import pickle
import sys
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

sys.path.append("..")
from hy2dl.aux_functions.functions_evaluation import nse

In [19]:
# Read the results generated using Hy2DL.
# The ten ensemble members differ only in their random seed, so the result folders
# are built from the seed list instead of repeating the absolute path ten times.
# Each path keeps the trailing separator that the hand-written paths had.
RESULTS_ROOT = "/hkfs/home/haicore/iwu/qa8171/Project/Hy2DL/results/US_exp2_correct"
RUN_PREFIX = "28_days"
SEEDS = [110, 111, 222, 333, 444, 555, 666, 777, 888, 999]

path_results_LSTM = [
    os.path.join(RESULTS_ROOT, f"{RUN_PREFIX}_seed_{seed}", "test_results_best_epoch", "")
    for seed in SEEDS
]

# Output location for the ensemble NSE table.
test_result_save_path = os.path.join(RESULTS_ROOT, "Ensemble", "")
csv_name = "28_day.csv"

# exist_ok=True replaces the exists()/makedirs() pair, which is racy and more verbose.
os.makedirs(test_result_save_path, exist_ok=True)

    
# Read the information produced by every ensemble member and collect it into one
# DataFrame per basin: an observed column "y_obs" plus one simulated column
# "y_sim_ens_<k>" per ensemble member (k = 1 .. number of members).
# NOTE(review): pickle.load can execute arbitrary code; only point these paths at
# result files produced by your own Hy2DL runs.
lstm_results = {}
for member_idx, ensemble_member in enumerate(path_results_LSTM, start=1):
    # os.path.join avoids the double separator that string concatenation produced.
    with open(os.path.join(ensemble_member, "test_results.pickle"), "rb") as f:
        info_lstm = pickle.load(f)
    sim_column = f"y_sim_ens_{member_idx}"
    for basin, basin_info in info_lstm.items():
        y_sim = basin_info["y_sim"]
        if basin not in lstm_results:
            # First member that reports this basin: build its frame around y_obs.
            # Keying on the basin (rather than on "first member") avoids a KeyError
            # when a basin is absent from the first member's results.
            y_obs = basin_info["y_obs"]
            lstm_results[basin] = pd.DataFrame(
                data={"y_obs": y_obs, sim_column: y_sim}, index=y_obs.index
            )
        else:
            # Subsequent members only contribute their simulated series.
            lstm_results[basin][sim_column] = y_sim
            
# Collapse each basin's ensemble into one simulated series: the row-wise (per time
# step) median over all "y_sim_ens_*" member columns, stored in a new "y_sim" column.
for basin, basin_df in lstm_results.items():
    member_sims = basin_df.filter(regex=r"^y_sim_ens_")
    basin_df["y_sim"] = member_sims.median(axis=1)
    
# Per-basin NSE of the ensemble-median simulation ("y_sim"), rounded to 3 decimals
# and indexed by basin id.
# NOTE(review): assumes nse(..., average=False) returns one value per basin in the
# iteration order of lstm_results — confirm against hy2dl.
basin_ids = list(lstm_results.keys())
basin_nse = np.round(nse(df_results=lstm_results, average=False), 3)
df_NSE_lstm_CAMELS_US_hourly = (
    pd.DataFrame({"basin_id": basin_ids, "Median_NSE_by_ensemble_of_LSTMs": basin_nse})
    .set_index("basin_id")
)

In [20]:
# Write the per-basin ensemble NSE table as a CSV file into the ensemble output folder.
csv_path = os.path.join(test_result_save_path, csv_name)
df_NSE_lstm_CAMELS_US_hourly.to_csv(csv_path)
print(f"Results by ensemble of LSTMs has been saved to {test_result_save_path}")

Results by ensemble of LSTMs has been saved to /hkfs/home/haicore/iwu/qa8171/Project/Hy2DL/results/US_exp2_correct/Ensemble/


In [21]:
# Summarize ensemble performance with a single number: the median of the per-basin
# NSE values (pandas' Series.median skips NaN by default).
basin_nse_values = df_NSE_lstm_CAMELS_US_hourly["Median_NSE_by_ensemble_of_LSTMs"]
median_NSE = basin_nse_values.median()
print(f"Median NSE across all basins: {median_NSE:.3f}")

Median NSE across all basins: 0.752
