In [None]:
# Load library imports
import os
import sys
import torch
import joblib
import random
import logging
import datetime
import numpy as np
import pandas as pd
from tqdm import tqdm
from permetrics.regression import RegressionMetric
from sklearn.metrics import mean_absolute_error, mean_squared_error

# Load project Imports
from src.utils.config_loader import load_project_config, deep_format, expanduser_tree
from src.model.model_building import build_data_loader, instantiate_model_and_associated
from src.utils.config_loader import load_project_config

In [None]:
# Route INFO-level log records to stdout using a compact "LEVEL - message" layout.
# For a verbose layout swap in: '%(asctime)s - %(levelname)s - %(name)s - %(message)s'
logging.basicConfig(
    handlers=[logging.StreamHandler(sys.stdout)],
    format='%(levelname)s - %(message)s',
    level=logging.INFO,
)

# Module-level logger plus the project YAML config (paths and parameters)
logger = logging.getLogger(__name__)
config = load_project_config(config_path="config/project_config.yaml")
notebook = True

# Root directories that the rest of the config is formatted against
raw_data_root = config["global"]["paths"]["raw_data_root"]
results_root = config["global"]["paths"]["results_root"]

# Re-format the config with those roots (deep_format), then expand user paths
# (expanduser_tree; presumably "~" expansion — see src.utils.config_loader)
config = expanduser_tree(
    deep_format(config, raw_data_root=raw_data_root, results_root=results_root)
)

In [None]:
# Seed Python, NumPy and PyTorch (CPU and every CUDA device) RNGs, then force deterministic cuDNN
random_seed = config["global"]["pipeline_settings"]["random_seed"]
for seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
    seed_fn(random_seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Demo catchment for this notebook: the first catchment listed in the config
catchments_to_process = config["global"]["pipeline_settings"]["catchments_to_process"]
catchment = catchments_to_process[0]
run_defra_API_calls = config["global"]["pipeline_settings"]["run_defra_api"]

logger.info(f"Show Notebook Outputs: {notebook}")
logger.info(f"Notebook Demo Catchment: {catchment.capitalize()}")

Find Station / Validation Groups

# TESTING #

In [None]:
# # all_timesteps_list = torch.load(config[catchment]["paths"]["pyg_object_path"])
# all_timesteps_list = torch.load("data/03_graph/eden/PyG/all_timesteps_list_great_musgrave_20250818_190654.pt")
# all_timesteps_list

In [None]:
# # --- 7a. Build Data Loaders by Timestep ---

# full_dataset_loader = build_data_loader(
#     all_timesteps_list=all_timesteps_list,
#     batch_size = config["global"]["model"]["data_loader_batch_size"],
#     shuffle = config["global"]["model"]["data_loader_shuffle"],
#     catchment=catchment
# )

# logger.info(f"Pipeline Step 'Create PyG DataLoaders' complete for {catchment} catchment.")

In [None]:
# # --- 7b. Define Graph Neural Network Architecture ---

# model, device, optimizer, criterion = instantiate_model_and_associated(
#     all_timesteps_list=all_timesteps_list,
#     config=config,
#     catchment=catchment
# )

# logger.info(f"Pipeline Step 'Instantiate GAT-LSTM Model' complete for {catchment} catchment.")

In [None]:
import torch
import joblib
import logging
import numpy as np
from tqdm import tqdm
from sklearn.metrics import mean_absolute_error, mean_squared_error

# Testing and plotting checklist (update before each run):
#    - Transfer the .pt model across from NCC using rsync
#    - Transfer the timestep PyG object from the hard drive and update its filepath below
#    - Update `path` below
#    - Update `test_station` below
#    - Update `scalers_dir` below
#    - Update test and val station lists in config

# --- Run-specific inputs ---
# NOTE(review): torch>=2.6 defaults torch.load to weights_only=True, which may refuse pickled
# PyG objects — confirm the installed torch version if this load fails.
all_timesteps_list = torch.load("data/03_graph/eden/PyG/all_timesteps_list_bgs_ev2_20250818_212241.pt")
path = "data/04_model/eden/model/pt_model/model_20250819-010222_GATTrue_LSTMTrue_GATH12_GATD0-4_GATHC64_GATOC64_GATNL2_LSTHC32_LSTNL1_OUTD1_LR0-001_WD0-001_SM0-1_E250_ESP35_LRSF0-5_LRSP8_MINLR1e-06_LD0-0001_GCMN1-0.pt"
scalers_dir = "data/03_graph/eden/scalers/bgs_ev2_20250818_212241/"
iteration = 6  # feature-layout id selecting the gwl lag column slice; 1: Lags (52, 59); 2: (51, 58); 3: (53, 60)
test_station = "bgs_ev2"

# --- 7a. Build Data Loaders by Timestep ---
loader_batch_size = config["global"]["model"]["data_loader_batch_size"]
loader_shuffle = config["global"]["model"]["data_loader_shuffle"]
full_dataset_loader = build_data_loader(
    all_timesteps_list=all_timesteps_list,
    batch_size=loader_batch_size,
    shuffle=loader_shuffle,
    catchment=catchment,
)
logger.info(f"Pipeline Step 'Create PyG DataLoaders' complete for {catchment} catchment.")

# --- 7b. Define Graph Neural Network Architecture ---
model, device, optimizer, criterion = instantiate_model_and_associated(
    all_timesteps_list=all_timesteps_list,
    config=config,
    catchment=catchment,
)
logger.info(f"Pipeline Step 'Instantiate GAT-LSTM Model' complete for {catchment} catchment.")

# Per-station mean GWL (m AOD); mean_gwl is later subtracted from the mean prediction to build the
# offset correction (`diff`).
# NOTE(review): if these means were computed over records that include the test period, the
# "adjusted" metrics use test information — confirm their provenance.
mean_gwl_map = {
    "ainstable": 84.6333698214874,
    "baronwood": 85.8373720963633,
    "bgs_ev2": 87.2166125260539,
    "castle_carrock": 133.19521880854,
    "cliburn_town_bridge_2": 110.805906037388,
    "coupland": 135.670365012452,
    "croglin": 167.758299820582,
    "east_brownrigg": 106.74319765862,
    "great_musgrave": 152.209015790055,
    "hilton": 214.739017912584,
    "longtown": 18.1315500711501,
    "renwick": 177.683627274689,
    "scaleby": 41.1093269995661,
    "skirwith": 130.796279748829
}

# Fail with an actionable message (still a KeyError) when the station has no recorded mean
if test_station not in mean_gwl_map:
    raise KeyError(f"No mean GWL recorded for test_station '{test_station}'; add it to mean_gwl_map.")
mean_gwl = mean_gwl_map[test_station]

# Load the trained weights into the freshly instantiated model and switch to inference mode.
# Weights are mapped to CPU first; load_state_dict copies them into the model's existing parameters.
# NOTE(review): torch.load unpickles the checkpoint — only load files from trusted sources.
best_model = model
best_model.load_state_dict(torch.load(path, map_location=torch.device('cpu')))
best_model.eval()
logger.info(f"Loaded best model from {path}")

# Load target scaler
target_scaler_path = os.path.join(scalers_dir, "target_scaler.pkl")
target_scaler = joblib.load(target_scaler_path)
logger.info(f"Loaded target scaler from: {target_scaler_path}")

# Load gwl scaler
gwl_scaler_path = os.path.join(scalers_dir, "gwl_scaler.pkl")
gwl_scaler = joblib.load(gwl_scaler_path)
logger.info(f"Loaded gwl scaler from: {gwl_scaler_path}")

# Target scaler column 0: gwl_value is the only column in this scaler
target_scale = float(target_scaler.scale_[0])
target_mean = float(target_scaler.mean_[0])

# gwl scaler column 0: gwl_lag1 was processed first. The evaluation loop encodes EVERY slot of the
# autoregressive lag buffer with these lag1 parameters (older predictions are rolled into
# lag2..lagk), so any separately-fitted lag2..lagk parameters are not used.
lag1_scale = float(gwl_scaler.scale_[0])
lag1_mean = float(gwl_scaler.mean_[0])

# Per-node LSTM memory ('h' hidden, 'c' cell), zero-initialised on `device` and carried across
# timesteps by the evaluation loop; None when the model runs without its LSTM branch.
if not best_model.run_LSTM:
    lstm_state_store = None
else:
    lstm_state_shape = (best_model.num_layers_lstm, best_model.num_nodes, best_model.hidden_channels_lstm)
    lstm_state_store = {
        'h': torch.zeros(*lstm_state_shape).to(device),
        'c': torch.zeros(*lstm_state_shape).to(device),
    }

# Accumulators filled by the evaluation loop
test_predictions_unscaled, test_actuals_unscaled, fusion_alphas = [], [], []

logger.info("--- Starting Model Evaluation on Test Set ---\n")
test_loop = tqdm(all_timesteps_list, desc="Evaluating on Test Set", leave=False)

# Quick sanity checks on the loaded timestep list before evaluating
assert len(all_timesteps_list) > 0, "Empty timesteps list."
first = all_timesteps_list[0]
# assert first.x.size(1) == 70, "Feature dimension different than expected — check column order."

# Column span of the autoregressive gwl lag features in data.x (starting at gwl_lag1) for each
# feature-layout iteration. Widths differ (iteration 5 spans 1 column, iteration 6 spans 4), so the
# buffer length k below is derived from the slice rather than assumed to be 7.
LAG_SLICES = {
    1: slice(52, 59),
    2: slice(51, 58),
    3: slice(53, 60),
    4: slice(54, 61),
    5: slice(51, 52),
    6: slice(56, 60),
}

# An unknown iteration previously only printed a message, leaving lag_slice undefined (NameError
# below) or silently reusing a stale value from an earlier run; fail loudly instead.
if iteration not in LAG_SLICES:
    raise ValueError(
        f"Unknown iteration {iteration!r}; expected one of {sorted(LAG_SLICES)}. CHECK ITERATION AND SLICES"
    )
lag_slice = LAG_SLICES[iteration]

# dip_col: standardised gwl_dip feature used to warm up the test node's lag buffer.
# NOTE(review): dip_col is fixed while the lag slices move between iterations, and iteration 4's
# slice (54-60) includes column 60 — confirm the feature layout for that iteration.
dip_col = 60
mask_col = 68     # not used in the visible evaluation cells
drift_lim = 365   # not used in the visible evaluation cells

# Burn-in length from config, coerced to int and floored at 7 days
burn_in = max(int(config["global"]["pipeline_settings"]["burn_in"]), 7)

# The lag buffer is initialised once, at the first evaluated timestep
initialised = False
k = lag_slice.stop - lag_slice.start  # number of lag columns in the autoregressive buffer

# Autoregressive evaluation loop: step through every timestep in order, overwrite the test node's
# gwl lag features with the model's own previous predictions (so observed test values cannot leak in
# through the lags), carry the per-node LSTM state forward, and collect unscaled predictions and
# targets for the single test node.
with torch.no_grad():
    
    for i, data in enumerate(test_loop):
        data = data.to(device)
        test_mask = data.test_mask
        known_data_mask = (data.train_mask | data.val_mask | data.test_mask)
        
        # Skip timesteps with no nodes with known ground truth.
        # NOTE(review): a skipped timestep emits no prediction and does not advance the lag buffer or
        # LSTM state, so list positions stop matching calendar days (later cells build plot dates from
        # list positions) — confirm this never triggers within the evaluated period.
        if known_data_mask.sum() == 0:
            continue
        
        # Evaluation assumes exactly one test node per timestep
        n_test = int(data.test_mask.sum().item())
        assert n_test == 1, f"Expected exactly 1 test node, got {n_test}."

        # On the first evaluated timestep, warm up the test station's lag buffer with its gwl_dip
        # feature (already standardised), repeated across all k lag slots
        if not initialised:
            lag_init_std = float(data.x[test_mask, dip_col].mean().item())  # mean used defensively, should only be 1
            # lag_state_std = np.full(7, lag_init_std, dtype=np.float32)  # already in same std space (using gwl scaler)     
            lag_state_std = np.full(k, lag_init_std, dtype=np.float32)  
            initialised = True

        # Write the current lag buffer into x for the test node (overwriting real values that would cause leakage)
        x = data.x.clone()
        x[test_mask, lag_slice] = torch.from_numpy(lag_state_std).to(x.dtype).to(x.device)

        # Model forward pass on full node set
        predictions_all, (h_new, c_new), returned_node_ids = best_model(
            x=x,  # data.x updated with warmed up / autoregressive lags
            edge_index=data.edge_index,
            edge_attr=data.edge_attr,
            current_timestep_node_ids=data.node_id,
            lstm_state_store=lstm_state_store
        )
        
        # Persist the updated LSTM memory for the nodes returned by the model
        if best_model.run_LSTM:
            lstm_state_store['h'][:, returned_node_ids, :] = h_new.detach()
            lstm_state_store['c'][:, returned_node_ids, :] = c_new.detach()

        # Filter predictions/targets to the test node (still in target-scaler standardised space)
        preds_std = predictions_all[test_mask]
        targets_std = data.y[test_mask]

        # Inverse transform to original scale (m AOD)
        preds_np = preds_std.cpu().numpy()
        targets_np = targets_std.cpu().numpy()

        preds_unscaled = target_scaler.inverse_transform(preds_np)
        targets_unscaled = target_scaler.inverse_transform(targets_np)

        test_predictions_unscaled.extend(preds_unscaled.flatten())
        test_actuals_unscaled.extend(targets_unscaled.flatten())
        
        # Record the |residual| / |baseline| magnitude ratio on the test node (for interpretability);
        # only available when both GAT and LSTM branches run and the model exposes `last_debug`
        if best_model.run_GAT and best_model.run_LSTM:
            dbg = getattr(best_model, "last_debug", None)
            if dbg is not None:
                residual = dbg.get("residual", None)
                baseline = dbg.get("baseline", None)
                if isinstance(residual, torch.Tensor) and isinstance(baseline, torch.Tensor):
                    res_abs = torch.abs(residual[test_mask]).sum().item()
                    base_abs = torch.abs(baseline[test_mask]).sum().item()
                    if base_abs > 0:
                        fusion_alphas.append(res_abs / base_abs)  # store rel contribution ratio
                        
        # if i < 5:  # Show first few predictions
        #     print("Sample predictions (m AOD):", preds_unscaled[:5].flatten())
        #     print("Sample actuals     (m AOD):", targets_unscaled[:5].flatten())

        # Print a few samples just after the burn-in window (i counts every timestep, including skipped ones)
        if burn_in <= i < burn_in + 5:
            print("Sample predictions (m AOD):", preds_unscaled[:5].flatten())
            print("Sample actuals     (m AOD):", targets_unscaled[:5].flatten())
        
        # Update the lag buffer with the prediction made this timestep (first test-node value)
        y_std = float(preds_std.view(-1)[0].cpu().item())
        
        # Convert target-std -> raw -> lag-std (as target and gwl have different scalers).
        # NOTE(review): every buffer slot is encoded with the gwl_lag1 column's mean/scale — confirm
        # the gwl scaler's lag2..lagk parameters match, otherwise older lags are slightly mis-scaled.
        y_raw = (y_std * target_scale) + target_mean  # back to mAOD
        y_lag_std = (y_raw - lag1_mean) / lag1_scale  # rescaled to lag scaler
        
        # Shift the buffer one slot (lag1 -> lag2, ...). np.roll wraps the oldest value around to
        # index 0, which is then overwritten with the newest prediction as lag1.
        lag_state_std = np.roll(lag_state_std, 1)  # roll by 1
        lag_state_std[0] = float(y_lag_std)  # defensively ensure dtype

In [None]:
# def compute_drift_from_integer_hit(preds, gwl_dip, direction='down'):
#     """
#     Find the earliest index i (0-based) where int(preds[i]) == int(gwl_dip).
#     If not found, fall back to the index with minimum |pred - gwl_dip|.
#     """
#     preds = np.asarray(preds, dtype=float)
#     target_int = int(gwl_dip)

#     # indices where integer part matches
#     match_idx = np.nonzero(preds.astype(int) == target_int)[0]

#     if match_idx.size == 0:
#         # fallback: closest point
#         return int(np.argmin(np.abs(preds - gwl_dip)))

#     if direction == 'down':
#         for i in match_idx:
#             if i == 0 or int(preds[i-1]) != target_int:
#                 return int(i)
#         # if none satisfied the guard, just return the first match
#         return int(match_idx[0])
#     else:
#         return int(match_idx[0])

# TODO: THIS CORRECTION IS CURRENTLY LEAKAGE - adjust to use inputs
def _best_shift(y, yhat, max_lag=30):
    # returns lag* (positive = prediction lags observation)
    lags = range(-max_lag, max_lag+1)
    corr = [np.corrcoef(y[max(0,l):len(yhat)+min(0,l)],
                        yhat[max(0,-l):len(y)-max(0,l)])[0,1] for l in lags]
    return lags[int(np.nanargmax(corr))], np.nanmax(corr)


In [None]:
def align_by_lag(y, yhat, lag):
    """
    Align y (actuals) and yhat (predictions) given an integer lag.

    Uses the same pairing as ``_best_shift``: observation ``y[t + lag]`` is paired with
    prediction ``yhat[t]``.

    lag > 0 : predictions LEAD observations by ``lag`` steps; drop the first ``lag``
              observations and the last ``lag`` predictions.
    lag < 0 : predictions LAG observations by ``|lag|`` steps (e.g. ``yhat[t] == y[t - 1]`` is
              undone by ``lag == -1``); drop the last ``|lag|`` observations and the first
              ``|lag|`` predictions.
    lag == 0: both inputs are returned unchanged (as float arrays).

    Returns:
        (y_aligned, yhat_aligned): float arrays of equal length; both are empty when
        ``|lag| >= len(y)``.

    Raises:
        ValueError: if ``y`` and ``yhat`` differ in length (the outputs would not pair up).
    """
    y = np.asarray(y, dtype=float)
    yh = np.asarray(yhat, dtype=float)
    if len(y) != len(yh):
        raise ValueError(f"y and yhat must have the same length, got {len(y)} and {len(yh)}.")

    if lag > 0:
        # drop the first 'lag' observations and the last 'lag' predictions
        return y[lag:], yh[:-lag]
    elif lag < 0:
        shift = -lag
        # drop the last 'shift' observations and the first 'shift' predictions
        return y[:-shift], yh[shift:]
    return y, yh

# Clip the burn-in period, then realign predictions to observations by the best lag.
# This cell overwrites test_predictions_unscaled / test_actuals_unscaled with clipped, aligned arrays.
# Re-running it would clip and shift them a second time, so it only proceeds when the evaluation
# cell has just produced fresh outputs (plain lists).
if not isinstance(test_predictions_unscaled, list) or not isinstance(test_actuals_unscaled, list):
    raise RuntimeError(
        "Predictions/actuals are already clipped and aligned; re-run the evaluation cell before this one."
    )

# Burn-in (days) discarded before alignment and metrics. NOTE: deliberately overrides the
# config-derived burn_in used during evaluation; later cells use this value to offset plot dates.
burn_in = 730

preds_full = np.asarray(test_predictions_unscaled, dtype=float)
acts_full = np.asarray(test_actuals_unscaled, dtype=float)

# Validate against the burn-in that is actually applied (the check used to run before the override)
if len(preds_full) <= burn_in:
    raise ValueError(f"Series too short ({len(preds_full)}) for burn-in {burn_in}.")

preds_full = preds_full[burn_in:]
acts_full = acts_full[burn_in:]

# Determine lag (|lag| <= max_lag). NOTE: the lag is chosen using the test observations themselves.
drift, r_star = _best_shift(acts_full, preds_full, max_lag=300)
logger.info(f"Best shift = {drift} days; r = {r_star:.3f}")

# Align safely
acts_aln, preds_aln = align_by_lag(acts_full, preds_full, drift)

if len(acts_aln) == 0 or len(preds_aln) == 0:
    raise ValueError(f"Empty arrays after alignment: len(acts)={len(acts_aln)}, len(preds)={len(preds_aln)}")

test_actuals_unscaled = acts_aln
test_predictions_unscaled = preds_aln

In [None]:
# --- Clip out burn-in period before running metrics and plotting ---

# test_predictions_unscaled_original = test_predictions_unscaled.copy()
# test_actuals_unscaled_original = test_actuals_unscaled.copy()

# preds = np.asarray(test_predictions_unscaled, dtype=float)
# acts = np.asarray(test_actuals_unscaled, dtype=float)

# drift, r_star = best_shift(test_actuals_unscaled, test_predictions_unscaled, max_lag=30)
# print(f"Best shift = {drift} days; r = {r_star:.3f}")

# func_return = compute_drift_from_integer_hit(preds, acts[0], direction='down')
# print(f"Drift from integar hit: {func_return}")

# if drift < burn_in:
#     # Drop burn in period and save as canonical arrays
#     test_predictions_unscaled = preds[burn_in:]
#     test_actuals_unscaled = acts[burn_in-drift:-drift]
# else:
#     test_predictions_unscaled = preds[burn_in:]
#     test_actuals_unscaled = acts[burn_in:]
#     test_predictions_unscaled = preds[drift:]
#     test_actuals_unscaled = acts[:-drift]

# --- Final model prediction evaluation ---

if len(test_actuals_unscaled) > 0:
    loss_type = config[catchment]["training"]["loss"]
    # Compare case-insensitively so "mae"/"mse" in the config are also recognised
    loss_key = str(loss_type).upper()

    if loss_key == "MAE":
        final_test_metric = mean_absolute_error(test_actuals_unscaled, test_predictions_unscaled)
        logger.info(f"--- Final Test Set MAE (m AOD): {final_test_metric:.4f} ---\n")

    elif loss_key == "MSE":
        final_test_metric = mean_squared_error(test_actuals_unscaled, test_predictions_unscaled)
        logger.info(f"--- Final Test Set MSE (m AOD²): {final_test_metric:.4f} ---\n")

    else:
        logger.warning(f"Unrecognized loss type '{loss_type}' in config — skipping final metric calculation.\n")
else:
    logger.warning("No test data found — check 'data.test_mask'.\n")

logger.info("--- Model Evaluation on Test Set Complete ---\n")

# Average relative contribution of the residual branch on the test node.
# Each fusion_alphas entry is sum|residual| / sum|baseline|, a ratio that can exceed 1. Convert it to
# the residual's share of the combined magnitude, |res| / (|res| + |base|) = a / (1 + a), before
# averaging, so the two percentages below sum to 100% and neither can go negative.
if fusion_alphas:
    alphas = np.asarray(fusion_alphas, dtype=float)
    avg_rel_contrib = float(np.mean(alphas / (1.0 + alphas))) * 100
    logger.info("--- Residual Contribution (on test node) ---")
    logger.info(f"Average GAT Residual Contribution: {avg_rel_contrib:.2f}%")
    logger.info(f"Average LSTM Contribution: {100 - avg_rel_contrib:.2f}%")
    logger.info("-------------------------------------------\n")
    
# TODO: TEST METRICS MUST NOT INCLUDE MASKED VALUES IN THE TEST SET (e.g. BGS_EV2)

In [None]:
# Mean signed error (bias) of the aligned predictions, and the gap between the mean prediction
# and the station's mean GWL
signed_errors = np.array(test_predictions_unscaled) - np.array(test_actuals_unscaled)
mean_error = np.mean(signed_errors)
diff = np.mean(test_predictions_unscaled) - mean_gwl
logger.info(f"Final Test Set Mean Error (Bias): {mean_error:.4f} m AOD [{diff}]")

In [None]:
# Load the target scaler
from joblib import load

# # Load target scaler (y, 'gwl_value') in
# scalers_dir = config[catchment]["paths"]["scalers_dir"]
# target_scaler_path = os.path.join(scalers_dir, "target_scaler.pkl")
# try:
#     target_scaler = load(target_scaler_path)
#     logger.info(f"Successfully loaded target scaler from: {target_scaler_path}")
# except Exception as e:
#     logger.warning(f"No target scaler found or error loading it: {e}")
#     target_scaler = None

# Reshape both series into (n, 1) column arrays
test_predictions_np = np.reshape(np.array(test_predictions_unscaled), (-1, 1))
test_actuals_np = np.reshape(np.array(test_actuals_unscaled), (-1, 1))

# Sanity-check the value ranges
logger.info(f"Sample prediction range: {test_predictions_np.min():.2f} to {test_predictions_np.max():.2f}")
logger.info(f"Sample actual range:     {test_actuals_np.min():.2f} to {test_actuals_np.max():.2f}")

# Final arrays consumed by the metric, plotting and CSV cells
test_predictions_final, test_actuals_final = test_predictions_np, test_actuals_np

In [None]:
from sklearn.metrics import mean_absolute_error
from permetrics.regression import RegressionMetric
# Offset between mean prediction and the station's mean GWL; subtracting it yields the
# "adjusted" (offset-corrected) predictions reported next to the baseline ones
diff = np.mean(test_predictions_final) - mean_gwl  # predictions - ["gwl_mean"]
adjusted_predictions_final = test_predictions_final - diff

# Explicit None check: estimator truthiness is not a reliable "scaler loaded" test
unit_label = "mAOD" if target_scaler is not None else "standard units"

final_test_mae = mean_absolute_error(test_actuals_final, test_predictions_final)
adjusted_mae = mean_absolute_error(test_actuals_final, adjusted_predictions_final)
evaluator = RegressionMetric(test_actuals_final, test_predictions_final)  # Before offset correction
adj_evaluator = RegressionMetric(test_actuals_final, adjusted_predictions_final)  # After offset correction

# Mean Absolute Error (MAE)
logger.info(f"Final Test Set MAE: {final_test_mae:.4f} {unit_label} [Target 0.2 to 0.5 mAOD]")
logger.info(f"Final Adjusted MAE: {adjusted_mae:.4f} {unit_label} [Target 0.2 to 0.5 mAOD]\n")

# RMSE, R^2, NSE and KGE for baseline vs adjusted predictions:
# (label, RegressionMetric method name, unit suffix, target band)
metric_specs = [
    ("RMSE", "root_mean_squared_error", f" {unit_label}", "[Target 0.25 to 0.6 mAOD]"),
    ("R^2", "coefficient_of_determination", "", "[Target 0.80 or higher]"),
    ("NSE", "nash_sutcliffe_efficiency", "", "[Target 0.75 or higher]"),
    ("KGE", "kling_gupta_efficiency", "", "[Target 0.75 or higher]"),
]
for label, method_name, unit_suffix, target_band in metric_specs:
    baseline_value = getattr(evaluator, method_name)()
    adjusted_value = getattr(adj_evaluator, method_name)()
    logger.info(f"Baseline {label}: {baseline_value:.4f}{unit_suffix} {target_band}")
    logger.info(f"Adjusted {label}: {adjusted_value:.4f}{unit_suffix} {target_band}\n")

In [None]:
import numpy as np

def calculate_kge_components(actuals, predictions):
    """Return the three Kling-Gupta Efficiency components as separate values.

    Args:
        actuals: observed values; any array-like (e.g. an (n, 1) column), flattened before use.
        predictions: predicted values with the same number of elements as ``actuals``.

    Returns:
        (r, beta, gamma):
            r     - Pearson correlation between actuals and predictions;
            beta  - ratio of means, mean(predictions) / mean(actuals);
            gamma - ratio of population (ddof=0) standard deviations, std(predictions) / std(actuals).

    NOTE(review): gamma here is a plain std ratio (the "alpha" term of the 2009 KGE); the 2012 KGE'
    uses a ratio of coefficients of variation instead — confirm which form the reported KGE uses
    before comparing these components against it.
    """
    obs = np.asarray(actuals, dtype=float).ravel()
    sim = np.asarray(predictions, dtype=float).ravel()
    r = np.corrcoef(obs, sim)[0, 1]
    beta = np.mean(sim) / np.mean(obs)
    gamma = np.std(sim) / np.std(obs)
    return r, beta, gamma

# KGE components before and after the mean offset correction (diff from the metrics cell)
adjusted_predictions = test_predictions_final - diff
r_actual, beta_actual, gamma_actual = calculate_kge_components(test_actuals_final, test_predictions_final)
r_adjusted, beta_adjusted, gamma_adjusted = calculate_kge_components(test_actuals_final, adjusted_predictions)

component_rows = (
    ("Baseline", r_actual, beta_actual, gamma_actual),
    ("Adjusted", r_adjusted, beta_adjusted, gamma_adjusted),
)
for heading, r_val, beta_val, gamma_val in component_rows:
    print(f"{heading} KGE Components:")
    print(f"  Correlation (r): {r_val:.4f};")
    print(f"  Bias (beta): {beta_val:.4f};")
    print(f"  Variability (gamma): {gamma_val:.4f}\n")

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib.dates as mdates

# Offset between mean prediction and the station's mean GWL (same definition as the metrics cell);
# the plotted predictions are offset-corrected by subtracting it
diff = np.mean(test_predictions_final) - mean_gwl  # predictions - ["gwl_mean"]
adjusted_plot_preds = test_predictions_final - diff

# Date axis for the clipped/aligned series. The evaluation starts at model_start_date and the first
# `burn_in` values were clipped. A positive alignment lag (`drift`) also made align_by_lag drop that
# many leading observations, so the plotted actuals start `drift` days later.
# NOTE(review): assumes one evaluated value per calendar day — timesteps skipped in the evaluation
# loop (no known ground truth) would shift these dates; confirm none are skipped.
start_date = config["global"]["data_ingestion"]["model_start_date"]
first_plotted_day = burn_in + max(int(drift), 0)
test_start_date = pd.to_datetime(start_date) + pd.Timedelta(days=first_plotted_day)  # don't plot burn in period
date_range = pd.date_range(start=test_start_date, periods=len(test_actuals_unscaled))

fig, ax = plt.subplots(figsize=(15, 5))
ax.plot(date_range, test_actuals_final, label='Actual GWL Value', color='blue', alpha=0.7, linewidth=1)
ax.plot(date_range, adjusted_plot_preds, label='Predicted GWL Value', color='red', alpha=0.7, linewidth=1)

# Format the x-axis to show years
ax.xaxis.set_major_formatter(mdates.DateFormatter('%Y'))
ax.xaxis.set_major_locator(mdates.YearLocator())

ax.set_title(f'Actual vs. Predicted Groundwater Levels at {test_station}')
ax.set_xlabel('Date')
ax.set_ylabel('Groundwater Level (m AOD)' if target_scaler is not None else 'Standardised GWL')
ax.legend()
ax.grid(True)
fig.tight_layout()

# Pad the y-range by a third of the data range above and below
all_vals = np.concatenate([test_actuals_final, adjusted_plot_preds])
y_min = all_vals.min()
y_max = all_vals.max()
y_range = y_max - y_min
ax.set_ylim(y_min - y_range / 3, y_max + y_range / 3)

# Save plot, creating the output directory if it does not exist yet
base_name = os.path.basename(path)
filename_no_ext, extension = os.path.splitext(base_name)
save_dir = "results/trained_models/eden"
os.makedirs(save_dir, exist_ok=True)
save_path = os.path.join(save_dir, f"{test_station}_{iteration}_" + filename_no_ext)
fig.savefig(save_path, dpi=300)

plt.show()

In [None]:
# --- Persist per-day ground truth and predictions as CSV for later reference ---

logger.info(f"Converting {test_station} results to dataframe for reference...\n")

results_df = pd.DataFrame(
    {
        column: values.flatten()
        for column, values in (
            ("ground_truth_values", test_actuals_final),
            ("baseline_predictions", test_predictions_final),
            ("drift_adjusted_predictions", test_predictions_final - diff),
        )
    }
)

# Write to <model_dir>/test_results/<station>_<iteration>.csv, creating the folder if needed
csv_dir = os.path.join(config[catchment]["paths"]["model_dir"], "test_results")
os.makedirs(csv_dir, exist_ok=True)
csv_path = os.path.join(csv_dir, f"{test_station}_{iteration}.csv")
results_df.to_csv(csv_path)

logger.info(f"{test_station} results saved to: {csv_path}\n")

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt

# Set a style for plots (optional)
sns.set_style("whitegrid")

# Offset-corrected predictions as a float array. np.asarray also covers the case where
# test_predictions_unscaled is still the plain list built by the evaluation loop (alignment cell
# not run), where `list - float` would raise TypeError.
adjusted_predictions_only = np.asarray(test_predictions_unscaled, dtype=float) - diff

fig, ax = plt.subplots(figsize=(15, 5))

# Plot the predicted values
ax.plot(adjusted_predictions_only, label='Predicted GWL Value', color='red', alpha=0.7)

# Add titles and labels
ax.set_title('Predicted Groundwater Levels on Test Set')
ax.set_xlabel('Data Point Index (Sequential Timesteps/Stations)')
ax.set_ylabel('Groundwater Level (m AOD)')
ax.legend()
ax.grid(True)
fig.tight_layout()

# Use the same y-limits as the daily comparison plot (actuals and adjusted predictions, padded by a third)
all_vals = np.concatenate([test_actuals_final, test_predictions_final - diff])
y_min = all_vals.min()
y_max = all_vals.max()
y_range = y_max - y_min
ax.set_ylim(y_min - y_range / 3, y_max + y_range / 3)

plt.show()

logger.info("Pipeline step 'Generate plot of predicted values' complete.")

In [None]:
# Weekly (7-day) mean plot from daily arrays

pred_daily = np.asarray(test_predictions_final).reshape(-1)
act_daily = np.asarray(test_actuals_final).reshape(-1)

# Average consecutive, non-overlapping 7-day blocks; trailing days that do not fill a week are dropped.
# (A dedicated name avoids clobbering the global `k` holding the lag-buffer width.)
days_per_week = 7
n_full = (len(act_daily) // days_per_week) * days_per_week
if n_full == 0:
    raise ValueError(f"Need at least {days_per_week} daily values for a weekly plot, got {len(act_daily)}.")
if n_full < len(act_daily):
    logger.info(f"Trimming {len(act_daily) - n_full} trailing day(s).")
act_weekly = act_daily[:n_full].reshape(-1, days_per_week).mean(axis=1)
pred_weekly = pred_daily[:n_full].reshape(-1, days_per_week).mean(axis=1)

# Week start dates share the daily plot's origin: model start + burn-in + any positive alignment lag
# (align_by_lag drops `drift` leading observations when drift > 0).
# NOTE(review): assumes one evaluated value per calendar day — confirm no timesteps were skipped.
week0 = pd.to_datetime(config["global"]["data_ingestion"]["model_start_date"]) + pd.Timedelta(
    days=burn_in + max(int(drift), 0)
)
week_dates = pd.date_range(start=week0, periods=len(act_weekly), freq="7D")

fig, ax = plt.subplots(figsize=(15, 7))
ax.plot(week_dates, act_weekly, label='Actual GWL', alpha=0.8)
ax.plot(week_dates, pred_weekly - diff, label='Predicted GWL', alpha=0.8)
ax.xaxis.set_major_formatter(mdates.DateFormatter('%Y'))
ax.xaxis.set_major_locator(mdates.YearLocator())
ax.set_title('Actual vs. Predicted Groundwater Levels (Weekly Means)')
ax.set_xlabel('Week Starting')
ax.set_ylabel('Groundwater Level (m AOD)' if target_scaler is not None else 'Standardised GWL')
ax.legend()
ax.grid(True)
fig.tight_layout()

# Set plot limits (pad by half the data range above and below)
all_vals = np.concatenate([act_weekly, pred_weekly - diff])
y_min = all_vals.min()
y_max = all_vals.max()
y_range = y_max - y_min
ax.set_ylim(y_min - y_range / 2, y_max + y_range / 2)

# Save next to the daily figure with _weekly
# weekly_save_path = save_path + "_weekly"
# plt.savefig(weekly_save_path, dpi=300)

plt.show()


In [None]:
# import numpy as np, pandas as pd
# from sklearn.metrics import mean_absolute_error, mean_squared_error

# # Reload if needed
# scaler_path = "data/03_graph/eden/scalers/croglin_20250818_180550/target_scaler.pkl"
# target_scaler = joblib.load(scaler_path)
# logger.info(f"Loaded target scaler from: {scaler_path}")
# scale = torch.tensor(target_scaler.scale_, device=device)
# mean = torch.tensor(target_scaler.mean_, device=device)


# def inv_std(arr_std, scaler):
#     arr_std = np.asarray(arr_std).reshape(-1,1)
#     if scaler is None:
#         return arr_std.ravel()
#     return scaler.inverse_transform(arr_std).ravel()

# def nse(y, yhat):
#     y = np.asarray(y); yhat = np.asarray(yhat)
#     den = np.sum((y - y.mean())**2)
#     return 1.0 - (np.sum((y - yhat)**2) / den if den > 0 else np.inf)

# def kge(y, yhat):
#     y = np.asarray(y); yhat = np.asarray(yhat)
#     r = np.corrcoef(y, yhat)[0,1] if y.size > 1 else np.nan
#     beta = yhat.mean()/y.mean() if y.mean()!=0 else np.nan
#     gamma = (yhat.std(ddof=1)/y.std(ddof=1)) if y.std(ddof=1)>0 else np.nan
#     return 1 - np.sqrt((r-1)**2 + (beta-1)**2 + (gamma-1)**2), r, beta, gamma

# # ---------- Build climatology from TRAINING ONLY ----------
# from collections import defaultdict
# clim_sum = defaultdict(lambda: np.zeros(366, dtype=float))
# clim_cnt = defaultdict(lambda: np.zeros(366, dtype=int))

# for data in all_timesteps_list:
#     train_mask = data.train_mask
#     if train_mask.sum() == 0:
#         continue
#     doy = int(pd.to_datetime(str(data.timestep)).dayofyear) - 1
#     y_std = data.y[train_mask].detach().cpu().numpy().ravel()
#     y_raw = inv_std(y_std, target_scaler)
#     nids  = data.node_id[train_mask].detach().cpu().numpy().ravel().astype(int)
#     for nid, val in zip(nids, y_raw):
#         clim_sum[nid][doy] += val
#         clim_cnt[nid][doy] += 1

# climatology = {}
# node_annual = {}
# for nid in clim_sum.keys():
#     s, c = clim_sum[nid], clim_cnt[nid]
#     with np.errstate(divide='ignore', invalid='ignore'):
#         m = np.divide(s, np.where(c==0, 1, c))
#     annual = s.sum() / max(c.sum(), 1) if c.sum()>0 else 0.0
#     m[c==0] = annual
#     climatology[nid] = m            # per-node DoY mean (366)
#     node_annual[nid] = annual       # per-node annual mean

# # Global fallbacks (across all training observations)
# all_train_vals = []
# for nid in clim_sum.keys():
#     s, c = clim_sum[nid], clim_cnt[nid]
#     if c.sum() > 0:
#         # expand s/c into list of observed values is overkill; use annual means as proxy
#         all_train_vals.append(node_annual[nid])
# global_mean = float(np.mean(all_train_vals)) if len(all_train_vals) else 0.0
# global_doy = np.zeros(366, dtype=float)
# # simple global DoY baseline = global mean (you could compute true global DoY means if needed)
# global_doy[:] = global_mean

# # ---------- Evaluate on TEST timesteps ----------
# y_true, y_persist, y_clim = [], [], []
# prev_obs = {}  # nid -> last RAW y

# for data in all_timesteps_list:
#     doy = int(pd.to_datetime(str(data.timestep)).dayofyear) - 1

#     # 1) make predictions FIRST using previous observations
#     test_mask = data.test_mask
#     if test_mask.any():
#         y_std = data.y[test_mask].detach().cpu().numpy().ravel()
#         y_raw = inv_std(y_std, target_scaler)
#         nids  = data.node_id[test_mask].detach().cpu().numpy().ravel().astype(int)

#         for nid, yt in zip(nids, y_raw):
#             # persistence (t-1), fallback to node climatology, then global day-of-year, then global mean
#             yp = prev_obs.get(nid, None)
#             if yp is None:
#                 yp = climatology.get(nid, global_doy)[doy] if nid in climatology else global_doy[doy]
#             yc = climatology.get(nid, global_doy)[doy] if nid in climatology else global_doy[doy]

#             y_true.append(yt)
#             y_persist.append(yp)
#             y_clim.append(yc)

#     # 2) AFTER predicting, update prev_obs with *current* observed values
#     known_mask = (data.train_mask | data.val_mask | data.test_mask)
#     if known_mask.any():
#         ys_std = data.y[known_mask].detach().cpu().numpy().ravel()
#         ys_raw = inv_std(ys_std, target_scaler)
#         nids   = data.node_id[known_mask].detach().cpu().numpy().ravel().astype(int)
#         for nid, val in zip(nids, ys_raw):
#             prev_obs[nid] = val

# # ---------- Metrics ----------
# y_true = np.asarray(y_true); y_persist = np.asarray(y_persist); y_clim = np.asarray(y_clim)

# # ensure no NaNs remain
# for name, arr in [("y_true", y_true), ("y_persist", y_persist), ("y_clim", y_clim)]:
#     if np.isnan(arr).any():
#         n = int(np.isnan(arr).sum())
#         raise ValueError(f"{name} contains {n} NaNs after baseline construction.")

# def report(name, yhat):
#     mae  = mean_absolute_error(y_true, yhat)
#     rmse = np.sqrt(mean_squared_error(y_true, yhat))
#     _nse = nse(y_true, yhat)
#     _kge, r, beta, gamma = kge(y_true, yhat)
#     print(f"{name}: MAE={mae:.3f} m, RMSE={rmse:.3f} m, NSE={_nse:.3f}, KGE={_kge:.3f} (r={r:.3f}, β={beta:.3f}, γ={gamma:.3f})")

# if y_true.size == 0:
#     print("No test data found — check test masks.")
# else:
#     report("Persistence", y_persist)
#     report("Seasonal climatology (DoY mean)", y_clim)


In [None]:
# import matplotlib.pyplot as plt
# import numpy as np
# import pandas as pd

# # Assume you already have:
# # y_true = np.array([...])   # test targets (unscaled, m AOD)
# # y_pred = np.array([...])   # model predictions (unscaled, m AOD)

# # --- 1. Hydrograph with persistence baseline ---
# y_persist = np.roll(y_true, 1)   # simple lag-1 baseline
# y_persist[0] = y_true[0]         # first value no lag

# plt.figure(figsize=(12,4))
# plt.plot(y_true, label="Observed", lw=1)
# plt.plot(test_predictions_final, label="Model", alpha=0.7, lw=1)
# plt.plot(y_persist, label="Persistence", alpha=0.7, lw=1)
# plt.legend(); plt.title("Hydrograph Comparison")
# plt.show()

# # --- 2. Autocorrelation (ACF) of observed series ---
# from statsmodels.graphics.tsaplots import plot_acf
# plot_acf(y_true, lags=50)
# plt.title("Autocorrelation of Observed GWL")
# plt.show()

# # --- 3. Lag plot: y(t) vs y(t-1) ---
# plt.figure(figsize=(4,4))
# plt.scatter(y_true[:-1], y_true[1:], alpha=0.5)
# plt.xlabel("y(t-1)")
# plt.ylabel("y(t)")
# plt.title("Lag-1 Plot (Observed)")
# plt.show()


Plot at a weekly resolution (using daily predictions)

In [None]:
import pandas as pd
from ray.tune import ExperimentAnalysis
import os
import logging
import sys
import glob
import numpy as np

# Re-declared so this cell can also be run on its own; basicConfig is a no-op
# when the root logger already has handlers (e.g. from the setup cell above).
logging.basicConfig(
    level=logging.INFO,
    format='%(levelname)s - %(message)s',
    handlers=[logging.StreamHandler(sys.stdout)]
)
logger = logging.getLogger(__name__)

# --- Locations of the tuning artefacts -------------------------------------
# Ray Tune results and the per-trial saved PyTorch models share one model root,
# so both paths are derived from it instead of repeating the literal.
tune_catchment = "eden"  # NOTE(review): hard-coded; consider the config-driven `catchment` instead
model_root = os.path.join("data", "04_model", tune_catchment, "model")
ray_tune_dir = os.path.join(model_root, "ray_tune_gwl")
pt_root = os.path.abspath(os.path.join(model_root, "pt_model"))

# Reset the cell's outputs up-front: in a long-lived kernel a failed or partial
# re-run must not leave a stale model path / config from a previous execution.
best_pt_model_path = None
best_config = None
best_metrics = None

if not os.path.isdir(ray_tune_dir):
    logger.error(f"Error: The specified directory does not exist or is not a directory: {ray_tune_dir}")
else:
    try:
        abs_dir = os.path.abspath(ray_tune_dir)
        uri = f"file://{abs_dir}"

        analysis = ExperimentAnalysis(uri)
        df = analysis.dataframe()

        # Coerce numeric metrics (in case anything logged as strings); this
        # includes training_iteration, which is aggregated with `max` below.
        for col in ['val_loss', 'train_loss', 'epoch', 'training_iteration']:
            if col in df.columns:
                df[col] = pd.to_numeric(df[col], errors='coerce')

        # trial_id is needed for grouping and val_loss for ranking; fail with a
        # clear message instead of an opaque KeyError deeper in the pipeline.
        if 'trial_id' not in df.columns:
            raise KeyError("Expected 'trial_id' column not found in Ray Tune results DataFrame.")
        if 'val_loss' not in df.columns:
            raise KeyError("Expected 'val_loss' column not found in Ray Tune results DataFrame.")

        # Keep only rows with a valid trial_id
        df = df[df['trial_id'].notna()].copy()

        # filter out errored trials (errored due to timeouts)
        if 'error' in df.columns:
            df = df[df['error'].isna()].copy()

        logger.info(f"Total rows (iterations) loaded: {len(df)}")
        n_trials = df['trial_id'].nunique()
        logger.info(f"Trials represented: {n_trials}")

        # Per-trial metadata: take the first non-null value per trial for the
        # config columns (and logdir, when Ray logged it). groupby sorts the
        # trial_id keys itself, so no explicit sort is needed.
        config_cols = [c for c in df.columns if c.startswith('config/')]
        keep_cols = [c for c in ['trial_id', 'logdir'] + config_cols if c in df.columns]
        meta_first = (df[keep_cols]
                      .groupby('trial_id', as_index=False)
                      .first())

        # Per-trial metric aggregates. Only metrics that were actually logged
        # are aggregated, so a missing optional column does not abort the run.
        # Trials that never logged val_loss (timed out) end up NaN -> dropped below.
        agg_spec = {
            'min_val_loss': ('val_loss', 'min'),
            'mean_val_loss': ('val_loss', 'mean'),
            'min_train_loss': ('train_loss', 'min'),
            'mean_train_loss': ('train_loss', 'mean'),
            'last_epoch': ('epoch', 'max'),
            'last_iter': ('training_iteration', 'max'),
        }
        agg_spec = {out_col: spec for out_col, spec in agg_spec.items() if spec[0] in df.columns}
        agg = (df.groupby('trial_id')
                 .agg(**agg_spec)
                 .reset_index())

        # Merge configs back in
        per_trial = meta_first.merge(agg, on='trial_id', how='left')

        # Keep only trials with at least some val_loss signal
        ranked = per_trial[per_trial['min_val_loss'].notna()].copy()
        if ranked.empty:
            raise RuntimeError("No trials have non-NaN 'val_loss'. Cannot rank.")

        # Rank by best-ever validation loss, breaking ties by the mean.
        ranked = ranked.sort_values(['min_val_loss', 'mean_val_loss'], ascending=[True, True]).reset_index(drop=True)

        best = ranked.iloc[0]
        trial_id = best['trial_id']

        # Pack nice dicts for logging
        best_config = {k: best[k] for k in config_cols if k in ranked.columns}
        best_metrics = {k: best[k] for k in ['min_val_loss', 'mean_val_loss', 'min_train_loss', 'mean_train_loss',
                                             'last_epoch', 'last_iter'] if k in ranked.columns}

        logger.info("\n--- Best Trial (by min_val_loss, then mean_val_loss) ---")
        logger.info(f"Trial ID: {trial_id}")
        logger.info(f"Best metrics: {best_metrics}")
        logger.info(f"Best hyperparameters: {best_config}")

        # Locate the corresponding PT model for this trial_id
        pattern = os.path.join(pt_root, f"trial_{trial_id}", "pt_model", "*.pt")
        pt_candidates = sorted(glob.glob(pattern))
        if pt_candidates:
            # NOTE(review): ordering is lexicographic, so "last" is only the newest
            # checkpoint if filenames sort chronologically (e.g. zero-padded epochs).
            best_pt_model_path = pt_candidates[-1]  # choose last
            if len(pt_candidates) > 1:
                logger.info(f"{len(pt_candidates)} .pt files matched; using the last in sorted order.")
            logger.info(f"Best trial PT model: {best_pt_model_path}")
        else:
            logger.warning(f"No .pt model files found for trial_id={trial_id} with pattern: {pattern}")

        # Show a compact table of the top 10 trials
        display_cols = (['trial_id'] + config_cols +
                        ['min_val_loss', 'mean_val_loss', 'min_train_loss', 'mean_train_loss', 'last_epoch'])
        display_cols = [c for c in display_cols if c in ranked.columns]
        top10 = ranked[display_cols].head(10)
        pd.set_option('display.max_columns', None)
        pd.set_option('display.width', 160)
        print("\nTop 10 trials by validation loss:")
        print(top10.to_string(index=False))

    except Exception as e:
        # logger.exception logs at ERROR level *and* keeps the traceback,
        # which plain logger.error(f"...{e}") discarded.
        logger.exception(f"An error occurred during analysis: {e}")
