In [None]:
# Load library imports
import os
import sys
import torch
import joblib
import random
import logging
import datetime
import numpy as np
import pandas as pd
import seaborn as sns
from tqdm import tqdm
import matplotlib.pyplot as plt
import matplotlib.dates as mdates
from permetrics.regression import RegressionMetric
from sklearn.metrics import mean_absolute_error, mean_squared_error

# Load project Imports
from src.utils.config_loader import load_project_config, deep_format, expanduser_tree
from src.model.model_building import build_data_loader, instantiate_model_and_associated
from src.utils.config_loader import load_project_config

In [None]:
# --- Logging ---
# force=True replaces any handlers already attached to the root logger. Without it,
# basicConfig is a silent no-op whenever logging has already been configured (by an
# imported library, or an earlier run of this cell in the same kernel), and the format
# / stdout handler below would never take effect.
logging.basicConfig(
    level=logging.INFO,
    format='%(levelname)s - %(message)s',
    # Verbose alternative: '%(asctime)s - %(levelname)s - %(name)s - %(message)s'
    handlers=[logging.StreamHandler(sys.stdout)],
    force=True,
)

# Module logger + project config (paths and parameters)
logger = logging.getLogger(__name__)
config = load_project_config(config_path="config/project_config.yaml")
notebook = True  # Flag reported in the demo log output below

# Root directories the config refers to
raw_data_root = config["global"]["paths"]["raw_data_root"]
results_root = config["global"]["paths"]["results_root"]

# Fill root placeholders throughout the config, then expand user-home paths.
# NOTE(review): behaviour of deep_format / expanduser_tree assumed from their names.
config = deep_format(
    config,
    raw_data_root=raw_data_root,
    results_root=results_root
)
config = expanduser_tree(config)

In [None]:
# --- Reproducibility ---
# One seed drives Python's `random`, NumPy's global RNG and PyTorch (CPU and every
# CUDA device); cuDNN autotuning is switched off so kernel choice is repeatable too.
_pipeline_settings = config["global"]["pipeline_settings"]
random_seed = _pipeline_settings["random_seed"]

random.seed(random_seed)
np.random.seed(random_seed)
torch.manual_seed(random_seed)
torch.cuda.manual_seed_all(random_seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# --- Demo catchment ---
# The notebook evaluates only the first catchment listed in the pipeline settings.
catchments_to_process = _pipeline_settings["catchments_to_process"]
catchment = catchments_to_process[0]
run_defra_API_calls = _pipeline_settings["run_defra_api"]

logger.info(f"Show Notebook Outputs: {notebook}")
logger.info(f"Notebook Demo Catchment: {catchment.capitalize()}")

# TESTING #

In [None]:
# --- DEFINE STATION RUNS ---

iteration = ""  # Ablation label used in output file names (leave as "" for standard runs)
start_slice = 0  # Placeholder only: overwritten per station from station_model_map[...]["start"]
end_slice = 0  # Placeholder only: overwritten per station from station_model_map[...]["end"]

# Station keys listed here are skipped by the evaluation loop
exclude = []

# Every checkpoint lives in one directory; all but croglin share one hyperparameter tag.
_PT_MODEL_DIR = "data/04_model/eden/model/pt_model"
_TAG_DEFAULT = (
    "GATTrue_LSTMTrue_GATH8_GATD0-22_GATHC48_GATOC48_GATNL2_LSTHC32_LSTNL1_OUTD1_"
    "LR0-001_WD0-0002_SM0-15_E250_ESP35_LRSF0-5_LRSP8_MINLR1e-06_LD1e-05_GCMN1-0"
)
_TAG_CROGLIN = (
    "GATTrue_LSTMTrue_GATH8_GATD0-22_GATHC48_GATOC48_GATNL2_LSTHC32_LSTNL1_OUTD1_"
    "LR0-00091_WD2e-05_SM0-13_E250_ESP35_LRSF0-5_LRSP8_MINLR1e-06_LD1e-05_GCMN1-0"
)


def _station_run(trained_at, start=0, end=0, tag=_TAG_DEFAULT):
    """Build one station entry: checkpoint path plus start/end trim lengths (in timesteps)."""
    return {
        "model_path": f"{_PT_MODEL_DIR}/model_{trained_at}_{tag}.pt",
        "start": start,
        "end": end,
    }


# Keys are "<station>_<YYYYMMDD_HHMMSS>"; the evaluation loop strips the trailing
# 16 characters to recover the station name.
station_model_map = {
    "ainstable_20250901_133721": _station_run("20250901-184132"),
    "baronwood_20250901_115622": _station_run("20250901-122212"),
    "bgs_ev2_20250901_120534": _station_run("20250901-122918", start=730),
    "castle_carrock_20250901_121422": _station_run("20250901-123011"),
    "cliburn_town_20250901_122225": _station_run("20250901-123952"),
    "coupland_20250901_123028": _station_run("20250901-134427"),
    "croglin_20250827_150320": _station_run("20250828-100815", end=1095, tag=_TAG_CROGLIN),
    "east_brownrigg_20250901_123714": _station_run("20250901-143709"),
    "great_musgrave_20250901_124500": _station_run("20250901-154212"),
    "hilton_20250901_125340": _station_run("20250901-155047"),
    "longtown_20250901_130138": _station_run("20250901-160358"),
    "renwick_20250901_131054": _station_run("20250901-165608"),
    "scaleby_20250901_131941": _station_run("20250901-170916"),
    "skirwith_20250901_132823": _station_run("20250901-182247"),
}

In [None]:
"""
CHANGE NOTHING PAST HERE FOR EACH STATION RUN!
"""

# Get testing count
num_stations = len(station_model_map)

for i, station in enumerate(station_model_map.keys()):
    test_number = i
    station_append = station
    
    station_input = station_model_map[station]
    path = station_input["model_path"]
    
    # Skip if no path yet
    if not path:
        logger.info(f"({test_number} / {num_stations}) - {station} has no associated model path, skipping...\n")
        continue

    if station in exclude:
        logger.info(f"({test_number} / {num_stations}) - {station} marked for exclusion, skipping...\n")
        continue
    
    # Define necessary dirs and load model
    test_station = station_append[:-16]
    all_timesteps_list = torch.load(f"data/03_graph/eden/PyG/all_timesteps_list_{station_append}.pt")
    scalers_dir = f"data/03_graph/eden/scalers/{station_append}/"

    # Initialise (or reset in loop) dict to store metrics
    metrics = {}

    # --- 7a. Build Data Loaders by Timestep ---

    full_dataset_loader = build_data_loader(
        all_timesteps_list=all_timesteps_list,
        batch_size = config["global"]["model"]["data_loader_batch_size"],
        shuffle = config["global"]["model"]["data_loader_shuffle"],
        catchment=catchment
    )

    logger.info(f"Pipeline Step 'Create PyG DataLoaders' complete for {catchment} catchment.")

    # --- 7b. Define Graph Neural Network Architecture ---

    model, device, optimizer, criterion = instantiate_model_and_associated(
        all_timesteps_list=all_timesteps_list,
        config=config,
        catchment=catchment
    )

    logger.info(f"Pipeline Step 'Instantiate GAT-LSTM Model' complete for {catchment} catchment.")

    mean_gwl_map = {
        "ainstable": 84.6333698214874,
        "baronwood": 85.8373720963633,
        "bgs_ev2": 87.2166125260539,
        "castle_carrock": 133.19521880854,
        "cliburn_town": 110.805906037388,
        "coupland": 135.670365012452,
        "croglin": 167.758299820582,
        "east_brownrigg": 106.74319765862,
        "great_musgrave": 152.209015790055,
        "hilton": 214.739017912584,
        "longtown": 18.1315500711501,
        "renwick": 177.683627274689,
        "scaleby": 41.1093269995661,
        "skirwith": 130.796279748829
    }

    mean_gwl = mean_gwl_map[test_station]

    best_model = model  # Assume model object already defined and moved to correct device
    best_model.load_state_dict(torch.load(path, map_location=torch.device('cpu')))
    best_model.eval()
    logger.info(f"Loaded best model from {path}")

    # Load target scaler
    target_scaler_path = os.path.join(scalers_dir, "target_scaler.pkl")
    target_scaler = joblib.load(target_scaler_path)
    logger.info(f"Loaded target scaler from: {target_scaler_path}")

    # Load gwl scaler
    gwl_scaler_path = os.path.join(scalers_dir, "gwl_scaler.pkl")
    gwl_scaler = joblib.load(gwl_scaler_path)
    logger.info(f"Loaded gwl scaler from: {gwl_scaler_path}\n")

    target_scale = float(target_scaler.scale_[0])  # [0] as only gwl_value processed in this scaler
    target_mean = float(target_scaler.mean_[0])

    # Initialise global LSTM state
    if best_model.run_LSTM:
        lstm_state_store = {
            'h': torch.zeros(best_model.num_layers_lstm, best_model.num_nodes, best_model.hidden_channels_lstm).to(device),
            'c': torch.zeros(best_model.num_layers_lstm, best_model.num_nodes, best_model.hidden_channels_lstm).to(device)
        }
    else:
        lstm_state_store = None

    # Prepare lists for evaluation
    test_predictions_unscaled = []
    test_actuals_unscaled = []
    fusion_alphas = [] 
    
    # Collect per-timestep interpretability signals for this station
    alpha_series = []
    gamma_series = []
    beta_series = []

    logger.info(f"--- Starting Model Evaluation on Test Set ---\n")
    test_loop = tqdm(all_timesteps_list, desc="Evaluating on Test Set", leave=False)

    logger.info(f"\n\nSTATION TEST: {station_append}")
    logger.info(f"    Running Station {test_number} / {num_stations}..\n\n")

    # Run brief all_timesteps_list assertions to ensure no critical errors
    assert len(all_timesteps_list) > 0, "Empty timesteps list."
    first = all_timesteps_list[0]
    # assert first.x.size(1) == 70, "Feature dimension different than expected — check column order."

    dip_col = 53
    mask_col = 61
    drift_lim = 365

    # Get burn in period (ensureing int dtype and appropriate length)
    burn_in = int(config["global"]["pipeline_settings"]["burn_in"])
    burn_in = max(burn_in, 7)  # ensure it's at least 7 days

    # Start testing loop
    with torch.no_grad():
        
        for i, data in enumerate(test_loop):
            data = data.to(device)
            test_mask = data.test_mask
            known_data_mask = (data.train_mask | data.val_mask | data.test_mask)
            
            # Skip timesteps with no nodes with known ground truth
            if known_data_mask.sum() == 0:
                continue
            
            # Confirm assertation that there is exactly 1 test node
            n_test = int(data.test_mask.sum().item())
            assert n_test == 1, f"Expected exactly 1 test node, got {n_test}."

            # Model forward pass on full node set
            predictions_all, (h_new, c_new), returned_node_ids = best_model(
                x=data.x,  # data.x updated with warmed up / autoregressive lags
                edge_index=data.edge_index,
                edge_attr=data.edge_attr,
                current_timestep_node_ids=data.node_id,
                lstm_state_store=lstm_state_store
            )
            
            # Update LSTM memory for current nodes
            if best_model.run_LSTM:
                lstm_state_store['h'][:, returned_node_ids, :] = h_new.detach()
                lstm_state_store['c'][:, returned_node_ids, :] = c_new.detach()

            # Filter predictions/targets to test nodes
            preds_std = predictions_all[test_mask]
            targets_std = data.y[test_mask]

            # Inverse transform to original scale
            preds_np = preds_std.cpu().numpy()
            targets_np = targets_std.cpu().numpy()

            preds_unscaled = target_scaler.inverse_transform(preds_np)
            targets_unscaled = target_scaler.inverse_transform(targets_np)

            test_predictions_unscaled.extend(preds_unscaled.flatten())
            test_actuals_unscaled.extend(targets_unscaled.flatten())
            
            # Capture residual contribution relative to baseline (for interpretability)
            if best_model.run_GAT and best_model.run_LSTM:
                dbg = getattr(best_model, "last_debug", None)
                if dbg is not None:
                    residual = dbg.get("residual", None)
                    baseline = dbg.get("baseline", None)
                    if isinstance(residual, torch.Tensor) and isinstance(baseline, torch.Tensor):
                        res_abs = torch.abs(residual[test_mask]).sum().item()
                        base_abs = torch.abs(baseline[test_mask]).sum().item()
                        if base_abs > 0:
                            fusion_alphas.append(res_abs / base_abs)  # store rel contribution ratio
                            
            # capture alpha / gamma / beta for the single TEST node at this step
            dbg = getattr(best_model, "last_debug", None)
            if dbg is not None:
                # alpha: (N,1) -> scalar for the test node
                if dbg.get("alpha", None) is not None:
                    a = dbg["alpha"][test_mask].squeeze().item()
                    alpha_series.append(float(a))

                # gamma/beta: (N, d_h) -> reduce over hidden dim (mean) for the test node
                if dbg.get("gamma", None) is not None:
                    g = dbg["gamma"][test_mask].mean(dim=1).item()
                    gamma_series.append(float(g))
                if dbg.get("beta", None) is not None:
                    b = dbg["beta"][test_mask].mean(dim=1).item()
                    beta_series.append(float(b))

            if burn_in <= i < burn_in + 5:  # Show first few predictions
                print("Sample predictions (m AOD):", preds_unscaled[:5].flatten())
                print("Sample actuals     (m AOD):", targets_unscaled[:5].flatten())
                
    # --- DEFINED HELPER FUNCS FOR DIAGNOSTICS ---
    
    def _best_shift(y, yhat, max_lag=30):
        # returns lag* (positive = prediction lags observation)
        lags = range(-max_lag, max_lag+1)
        corr = [np.corrcoef(y[max(0,l):len(yhat)+min(0,l)],
                            yhat[max(0,-l):len(y)-max(0,l)])[0,1] for l in lags]
        return lags[int(np.nanargmax(corr))], np.nanmax(corr)

    def align_by_lag(y, yhat, lag):
        """
        Trim actuals and predictions so they overlap under an integer shift.

        Index ``t`` of the returned actuals pairs with index ``t`` of the returned
        predictions such that:
          * lag > 0 : y[t + lag] pairs with yhat[t]
                      (first ``lag`` actuals and last ``lag`` predictions dropped)
          * lag < 0 : y[t] pairs with yhat[t + |lag|]
                      (last ``|lag|`` actuals and first ``|lag|`` predictions dropped)
          * lag == 0: both returned untrimmed

        Returns:
            (actuals, predictions) as float NumPy arrays; both are empty when
            ``|lag|`` is at least the series length.
        """
        actual = np.asarray(y, dtype=float)
        predicted = np.asarray(yhat, dtype=float)

        if lag == 0:
            return actual, predicted

        n = abs(lag)
        if lag > 0:
            return actual[n:], predicted[:-n]
        return actual[:-n], predicted[n:]
            
    # --- CLIP AND APPLY DRIFTS ---

    # Clip burn-in (before applying)
    preds_full = np.asarray(test_predictions_unscaled, dtype=float)
    acts_full = np.asarray(test_actuals_unscaled, dtype=float)

    if len(preds_full) <= burn_in:
        raise ValueError(f"Series too short ({len(preds_full)}) for burn-in {burn_in}.")

    burn_in = 0  # 730
    preds_full = preds_full[burn_in:]
    acts_full = acts_full[burn_in:]

    # Determine lag (capped to ±max_lag in best_shift)
    drift, r_star = _best_shift(acts_full, preds_full, max_lag=300)
    logger.info(f"Best shift = {drift} days; r = {r_star:.3f}")

    # Align safely
    acts_aln, preds_aln = align_by_lag(acts_full, preds_full, drift)

    if len(acts_aln) == 0 or len(preds_aln) == 0:
        raise ValueError(f"Empty arrays after alignment: len(acts)={len(acts_aln)}, len(preds)={len(preds_aln)}")

    test_actuals_unscaled = acts_aln
    test_predictions_unscaled = preds_aln

    # Add metrics to dict
    metrics['best_shift'] = drift
    metrics['burn_in'] = burn_in
    
    # --- GET GAT AND LSTM CONTRIBUTIONS ---
    
    # --- Clipping if needed (e.g. masked missingness) ---
    
    start_slice = station_input["start"] - burn_in
    end_slice = station_input["end"]

    if start_slice != 0 and end_slice != 0:
        logger.info(f"Trimming {start_slice} timesteps from start and {end_slice} "
                    f"timesteps from end (Current length: {len(test_predictions_unscaled)})")
        test_actuals_unscaled = test_actuals_unscaled[start_slice:-end_slice]
        test_predictions_unscaled = test_predictions_unscaled[start_slice:-end_slice]
        logger.info(f"Trimmed length: {len(test_predictions_unscaled)}\n")
    elif start_slice != 0:
        logger.info(f"Trimming {start_slice} timesteps from start (Current length: {len(test_predictions_unscaled)})")
        test_actuals_unscaled = test_actuals_unscaled[start_slice:]
        test_predictions_unscaled = test_predictions_unscaled[start_slice:]
        logger.info(f"Trimmed length: {len(test_predictions_unscaled)}\n")
    elif end_slice != 0:
        logger.info(f"Trimming {end_slice} timesteps from end (Current length: {len(test_predictions_unscaled)})")
        test_actuals_unscaled = test_actuals_unscaled[:-end_slice]
        test_predictions_unscaled = test_predictions_unscaled[:-end_slice]
        logger.info(f"Trimmed length: {len(test_predictions_unscaled)}\n")

    # --- Final model prediction evaluation ---

    if len(test_actuals_unscaled) > 0:
        loss_type = config[catchment]["training"]["loss"]

        if loss_type == "MAE":
            final_test_metric = mean_absolute_error(test_actuals_unscaled, test_predictions_unscaled)
            logger.info(f"--- Final Test Set MAE (m AOD): {final_test_metric:.4f} ---\n")

        elif loss_type == "MSE":
            final_test_metric = mean_squared_error(test_actuals_unscaled, test_predictions_unscaled)
            logger.info(f"--- Final Test Set MSE (m AOD²): {final_test_metric:.4f} ---\n")

        else:
            logger.warning(f"Unrecognized loss type '{loss_type}' in config — skipping final metric calculation.\n")
    else:
        logger.warning("No test data found — check 'data.test_mask'.\n")

    logger.info("--- Model Evaluation on Test Set Complete ---\n")

    # Calculate and display the global average residual contribution.
    # Each entry of fusion_alphas is sum|residual| / sum|baseline| on the test node for
    # one timestep. NOTE(review): that ratio can exceed 1, in which case the LSTM
    # complement reported below goes negative — confirm this interpretation.
    if fusion_alphas:
        avg_rel_contrib = np.mean(fusion_alphas) * 100
        logger.info("--- Residual Contribution (on test node) ---")
        logger.info(f"Average GAT Residual Contribution: {avg_rel_contrib:.2f}%")
        logger.info(f"Average LSTM Contribution: {100 - avg_rel_contrib:.2f}%")
        logger.info("-------------------------------------------\n")

        # Add metrics to dict. The LSTM share is the complement of the GAT share,
        # matching the log line above (previously the GAT value was stored twice).
        metrics['GAT_contribution'] = avg_rel_contrib
        metrics['LSTM_contribution'] = 100 - avg_rel_contrib

    # --- GET SAMPLE AND PREDICTION RANGES ---
    
    mean_error = np.mean(np.array(test_predictions_unscaled) - np.array(test_actuals_unscaled))
    diff = np.mean(test_predictions_unscaled) - mean_gwl
    logger.info(f"Final Test Set Mean Error (Bias): {mean_error:.4f} m AOD [{diff}]\n")

    # Add metrics to dict
    metrics['mae_unadjusted'] = mean_error
    
    # Convert both to np array
    test_predictions_np = np.array(test_predictions_unscaled).reshape(-1, 1)
    test_actuals_np = np.array(test_actuals_unscaled).reshape(-1, 1)

    # Confirm range (sanity checker)
    logger.info(f"Sample prediction range: {test_predictions_np.min():.2f} to {test_predictions_np.max():.2f}")
    logger.info(f"Sample actual range:     {test_actuals_np.min():.2f} to {test_actuals_np.max():.2f}\n")

    # Assign reshaped vals to final arrs for plotting and metrics
    test_predictions_final = test_predictions_np
    test_actuals_final = test_actuals_np
    
    # --- GET MAIN METRICS ---
    
    # diff = np.mean(test_predictions_final) - mean_gwl  # predictions - ["gwl_mean"]

    final_test_mae = mean_absolute_error(test_actuals_final, test_predictions_final)
    adjusted_mae = mean_absolute_error(test_actuals_final, test_predictions_final-diff)
    unit_label = "mAOD" if target_scaler else "standard units"
    evaluator = RegressionMetric(test_actuals_final, test_predictions_final)  # Before offset correction
    adj_evaluator = RegressionMetric(test_actuals_final, test_predictions_final-diff)  # After offset correction

    # Mean Absolute Error (MAE)
    logger.info(f"Final Test Set MAE: {final_test_mae:.4f} {unit_label} [Target 0.2 to 0.5 mAOD]")
    logger.info(f"Final Adjusted MAE: {adjusted_mae:.4f} {unit_label} [Target 0.2 to 0.5 mAOD]\n")

    # Root Mean Square Error (RMSE)
    baseline_rmse = evaluator.root_mean_squared_error()
    adjusted_rmse = adj_evaluator.root_mean_squared_error()
    logger.info(f"Baseline RMSE: {baseline_rmse:.4f} {unit_label} [Target 0.25 to 0.6 mAOD]")
    logger.info(f"Adjusted RMSE: {adjusted_rmse:.4f} {unit_label} [Target 0.25 to 0.6 mAOD]\n")

    # Coefficient of Determination (R^2)
    logger.info(f"Baseline R^2: {evaluator.coefficient_of_determination():.4f} [Target 0.80 or higher]")
    logger.info(f"Adjusted R^2: {adj_evaluator.coefficient_of_determination():.4f} [Target 0.80 or higher]\n")

    # Nash-Sutcliffe Efficiency (NSE)
    baseline_nse = evaluator.nash_sutcliffe_efficiency()
    adjusted_nse = adj_evaluator.nash_sutcliffe_efficiency()
    logger.info(f"Baseline NSE: {baseline_nse:.4f} [Target 0.75 or higher]")
    logger.info(f"Adjusted NSE: {adjusted_nse:.4f} [Target 0.75 or higher]\n")

    # Kling Gupta Efficiency (KGE)
    baseline_kge = evaluator.kling_gupta_efficiency()
    adjusted_kge = adj_evaluator.kling_gupta_efficiency()
    logger.info(f"Baseline KGE: {baseline_kge:.4f} [Target 0.75 or higher]")
    logger.info(f"Adjusted KGE: {adjusted_kge:.4f} [Target 0.75 or higher]\n")
    
    # Add metrics to dict
    metrics['final_metrics_baseline_mae'] = final_test_mae
    metrics['final_metrics_baseline_rmse'] = baseline_rmse
    metrics['final_metrics_baseline_nse'] = baseline_nse
    metrics['final_metrics_baseline_kge'] = baseline_kge

    metrics['final_metrics_adjusted_mae'] = adjusted_mae
    metrics['final_metrics_adjusted_rmse'] = adjusted_rmse
    metrics['final_metrics_adjusted_nse'] = adjusted_nse
    metrics['final_metrics_adjusted_kge'] = adjusted_kge
    
    # --- GET COMPONENTS ---
    
    def calculate_kge_components(actuals, predictions):
        """
        Return the three Kling-Gupta Efficiency components of two series.

        Both inputs are NumPy arrays and are flattened first, so (n, 1) column
        vectors are accepted.

        Returns:
            (r, beta, gamma): Pearson correlation, ratio of means (pred / obs) and
            ratio of population standard deviations (pred / obs).
        """
        obs = actuals.flatten()
        sim = predictions.flatten()

        correlation = np.corrcoef(obs, sim)[0, 1]
        mean_ratio = np.mean(sim) / np.mean(obs)
        spread_ratio = np.std(sim) / np.std(obs)
        return correlation, mean_ratio, spread_ratio

    # Using precalc'd adjusted predictions
    adjusted_predictions = test_predictions_final - diff
    r_actual, beta_actual, gamma_actual = calculate_kge_components(test_actuals_final, test_predictions_final)
    r_adjusted, beta_adjusted, gamma_adjusted = calculate_kge_components(test_actuals_final, adjusted_predictions)

    # Log component results (baseline)
    print(f"Baseline KGE Components:")
    print(f"  Correlation (r): {r_actual:.4f};")
    print(f"  Bias (beta): {beta_actual:.4f};")
    print(f"  Variability (gamma): {gamma_actual:.4f}\n")

    # Log component results (adjusted)
    print(f"Adjusted KGE Components:")
    print(f"  Correlation (r): {r_adjusted:.4f};")
    print(f"  Bias (beta): {beta_adjusted:.4f};")
    print(f"  Variability (gamma): {gamma_adjusted:.4f}\n")

    # Add metrics to dict
    metrics['kge_components_baseline_corr'] = r_actual
    metrics['kge_components_baseline_bias'] = beta_actual
    metrics['kge_components_baseline_var'] = gamma_actual
    metrics['kge_components_adjusted_corr'] = r_adjusted
    metrics['kge_components_adjusted_bias'] = beta_adjusted
    metrics['kge_components_adjusted_var'] = gamma_adjusted
    
    # Summarise alpha/gamma/beta over the test period for this station
    def _median_iqr(arr):
        arr = np.asarray(arr, float)
        if arr.size == 0:
            return np.nan, np.nan, np.nan, np.nan
        q25, med, q75 = np.percentile(arr, [25, 50, 75])
        return float(med), float(q75 - q25), float(arr.min()), float(arr.max())

    alpha_med, alpha_iqr, alpha_min, alpha_max = _median_iqr(alpha_series)

    # Keep γ and β unchanged (no min/max since they’re unbounded and noisier)
    def _median_iqr_only(arr):
        arr = np.asarray(arr, float)
        if arr.size == 0:
            return np.nan, np.nan
        q25, med, q75 = np.percentile(arr, [25, 50, 75])
        return float(med), float(q75 - q25)

    gamma_med, gamma_iqr = _median_iqr_only(gamma_series)
    beta_med,  beta_iqr  = _median_iqr_only(beta_series)

    logger.info(f"Fusion α — median: {alpha_med:.3f}, IQR: {alpha_iqr:.3f}, "
                f"min: {alpha_min:.3f}, max: {alpha_max:.3f}")
    logger.info(f"FiLM γ (scale) — median: {gamma_med:.3f}, IQR: {gamma_iqr:.3f}")
    logger.info(f"FiLM β (shift) — median: {beta_med:.3f}, IQR: {beta_iqr:.3f}\n")

    metrics['alpha_median'] = alpha_med
    metrics['alpha_IQR'] = alpha_iqr
    metrics['alpha_min'] = alpha_min
    metrics['alpha_max'] = alpha_max
    metrics['gamma_median'] = gamma_med
    metrics['gamma_IQR'] = gamma_iqr
    metrics['beta_median'] = beta_med
    metrics['beta_IQR'] = beta_iqr
        
    # --- RUNNING DAILY RES PLOTS ---

    diff = np.mean(test_predictions_final) - mean_gwl  # predictions - ["gwl_mean"]

    # for x axis as date
    start_date = config["global"]["data_ingestion"]["model_start_date"]
    test_start_date = pd.to_datetime(start_date) + pd.Timedelta(days=burn_in)  # don't plot burn in period
    date_range = pd.date_range(start=test_start_date, periods=len(test_actuals_unscaled))

    plt.figure(figsize=(15, 5))

    plt.plot(date_range, test_actuals_final, label='Actual GWL Value', color='blue', alpha=0.7, linewidth=1)
    plt.plot(date_range, test_predictions_final-diff, label='Predicted GWL Value', color='red', alpha=0.7, linewidth=1)  # , linestyle='--'

    # Format the x-axis to show years
    plt.gca().xaxis.set_major_formatter(mdates.DateFormatter('%Y'))
    plt.gca().xaxis.set_major_locator(mdates.YearLocator())

    plt.title(f'Actual vs. Predicted Groundwater Levels at {test_station}')
    plt.xlabel('Date')
    plt.ylabel('Groundwater Level (m AOD)' if target_scaler else 'Standardised GWL')
    plt.legend()
    plt.grid(True)
    plt.tight_layout()

    # Give slightly more room above and below than automatic
    all_vals = np.concatenate([test_actuals_final, test_predictions_final-diff])
    y_min = all_vals.min()
    y_max = all_vals.max()
    y_range = y_max - y_min
    plt.ylim(y_min - y_range/3, y_max + y_range/3)

    # Save plot
    base_name = os.path.basename(path)
    filename_no_ext, extension = os.path.splitext(base_name)
    save_path = "results/trained_models/eden/FINAL_daily/" + f"{test_station}_{iteration}_" + filename_no_ext
    plt.savefig(save_path, dpi=300)

    plt.show()
    
    # --- SAVE PREDICTION RESULTS TO CSV ---
    
    logger.info(f"Converting {test_station} results to dataframe for reference...\n")

    results_df = pd.DataFrame({
        "ground_truth_values": test_actuals_final.flatten(),
        "baseline_predictions": test_predictions_final.flatten(),
        "drift_adjusted_predictions": (test_predictions_final - diff).flatten()
    })

    # Ensure csv dir exists
    csv_dir = os.path.join(config[catchment]["paths"]["model_dir"], "test_results")
    os.makedirs(csv_dir, exist_ok=True)
    csv_path = os.path.join(csv_dir, f"{test_station}_{iteration}.csv")
    results_df.to_csv(csv_path)

    logger.info(f"{test_station} results saved to: {csv_path}\n")
    
    # --- SAVE METRICS DICT TO CSV ---
    
    # (Save raw alpha/gamma/beta time series)
    diag_dir = os.path.join("data/04_model/eden/metrics", "diagnostics")
    os.makedirs(diag_dir, exist_ok=True)
    pd.DataFrame({
        "alpha": alpha_series,
        "gamma_mean_over_dh": gamma_series,
        "beta_mean_over_dh":  beta_series
    }).to_csv(os.path.join(diag_dir, f"{station_append}_alphagammabeta.csv"), index=False)
    
    # Main Save
    output_dir = "data/04_model/eden/metrics/"
    os.makedirs(output_dir, exist_ok=True)
    output_file = os.path.join(output_dir, f"{station_append}_metrics.csv")

    if iteration != "":
        output_file = os.path.join(output_dir, f"{station_append}_metrics_{iteration}.csv")

    # Convert dictionary to pd df
    metrics_df = pd.DataFrame([metrics])

    # If file doesn't exist then write the header
    if os.path.exists(output_file):
        metrics_df.to_csv(output_file, mode='a', header=False, index=False)
    else:
        metrics_df.to_csv(output_file, index=False)

    logger.info(f"Metrics saved to {output_file}\n")
    
    # --- RUNNING WEEKLY RES PLOTS ---

    pred_daily = np.asarray(test_predictions_final).reshape(-1)
    act_daily = np.asarray(test_actuals_final).reshape(-1)

    # Calc means
    k = 7
    n_full = (len(act_daily) // k) * k
    if n_full < len(act_daily):
        logger.info(f"Trimming {len(act_daily) - n_full} trailing day(s).")
    act_weekly = act_daily[:n_full].reshape(-1, k).mean(axis=1)
    pred_weekly = pred_daily[:n_full].reshape(-1, k).mean(axis=1)

    week0 = pd.to_datetime(config["global"]["data_ingestion"]["model_start_date"]) + pd.Timedelta(days=burn_in)
    week_dates = pd.date_range(start=week0, periods=len(act_weekly), freq="7D")

    plt.figure(figsize=(15, 7))
    plt.plot(act_weekly,  label='Actual GWL', alpha=0.8)
    plt.plot(pred_weekly - diff, label='Predicted GWL', alpha=0.8)  # linestyle='--',
    # plt.plot(pred_weekly + 0.3, label='Predicted GWL', alpha=0.8)  # linestyle='--',
    plt.title('Actual vs. Predicted Groundwater Levels (Weekly Means)')
    plt.xlabel('Week Index')
    plt.ylabel('Groundwater Level (m AOD)' if target_scaler else 'Standardised GWL')
    plt.legend()
    plt.grid(True)
    plt.tight_layout()

    # Set plot limits
    all_vals = np.concatenate([act_weekly, pred_weekly - diff])
    y_min = all_vals.min()
    y_max = all_vals.max()
    y_range = y_max - y_min
    plt.ylim(y_min - y_range/2, y_max + y_range/2)
    
    weekly_base_name = os.path.basename(path)
    weekly_filename_no_ext, extension = os.path.splitext(base_name)
    weekly_save_path = "results/trained_models/eden/FINAL_weekly/" + f"{test_station}_{iteration}_weekly_" + filename_no_ext
    plt.savefig(weekly_save_path, dpi=300)
    
    # --- VARIOUS RESOLUTIONS ---
    
    def _block_mean(arr: np.ndarray, k: int) -> np.ndarray:
        """Non-overlapping block means of length k, trimming any tail."""
        n = (len(arr) // k) * k
        return arr[:n].reshape(-1, k).mean(axis=1)

    def _metrics_row(y: np.ndarray, yhat: np.ndarray, mean_gwl_val: float):
        """Score one series and return the same fields as the daily metrics CSV.

        Args:
            y: Observed values, flattened to a column vector.
            yhat: Predicted values, flattened to a column vector.
            mean_gwl_val: Target mean. The "adjusted" scores use ``yhat`` shifted so
                that its mean equals this value.

        Returns:
            dict containing run metadata copied from the enclosing ``metrics`` dict
            (NaN when a key is absent), ``mae_unadjusted``, and MAE/RMSE/NSE/KGE for the
            raw and the mean-shifted predictions.
        """
        observed = np.asarray(y, float).reshape(-1, 1)
        predicted = np.asarray(yhat, float).reshape(-1, 1)

        # Mean-offset correction: shift predictions so their mean equals mean_gwl_val.
        offset = float(predicted.mean() - mean_gwl_val)
        predicted_shifted = predicted - offset

        def _score(sim):
            # (MAE, RMSE, NSE, KGE) of `sim` against the observed series.
            evaluator = RegressionMetric(observed, sim)
            return (
                float(mean_absolute_error(observed, sim)),
                float(evaluator.root_mean_squared_error()),
                float(evaluator.nash_sutcliffe_efficiency()),
                float(evaluator.kling_gupta_efficiency()),
            )

        base_mae, base_rmse, base_nse, base_kge = _score(predicted)
        adj_mae, adj_rmse, adj_nse, adj_kge = _score(predicted_shifted)

        # Key order fixes the CSV column order, so keep it stable.
        return {
            "best_shift": metrics.get("best_shift", np.nan),
            "burn_in": metrics.get("burn_in", np.nan),
            "GAT_contribution": metrics.get("GAT_contribution", np.nan),
            "LSTM_contribution": metrics.get("LSTM_contribution", np.nan),
            # NOTE(review): despite the name this is the signed mean error (bias), not an
            # absolute error. It is kept as-is for compatibility with the daily CSV schema.
            "mae_unadjusted": float((predicted - observed).mean()),
            "final_metrics_baseline_mae": base_mae,
            "final_metrics_baseline_rmse": base_rmse,
            "final_metrics_baseline_nse": base_nse,
            "final_metrics_baseline_kge": base_kge,
            "final_metrics_adjusted_mae": adj_mae,
            "final_metrics_adjusted_rmse": adj_rmse,
            "final_metrics_adjusted_nse": adj_nse,
            "final_metrics_adjusted_kge": adj_kge,
        }

    # --- Build daily series you already computed ---
    pred_daily = np.asarray(test_predictions_final, dtype=float).reshape(-1)
    act_daily  = np.asarray(test_actuals_final, dtype=float).reshape(-1)

    def _aggregate_metrics_file(base_file: str, tag: str) -> str:
        """Return the CSV path for metrics aggregated at resolution ``tag``.

        Only the file name is rewritten, never the directory. The last ``_metrics`` in
        the stem becomes ``_metrics_<tag>`` (e.g. ``X_metrics.csv`` -> ``X_metrics_7day.csv``).
        If the stem has no ``_metrics`` marker, ``_metrics_<tag>`` is appended instead.
        The result therefore never equals ``base_file``, so aggregated rows cannot be
        appended to the daily CSV.
        """
        folder, file_name = os.path.split(base_file)
        stem, ext = os.path.splitext(file_name)
        marker = "_metrics"
        pos = stem.rfind(marker)
        if pos == -1:
            stem = f"{stem}{marker}_{tag}"
        else:
            stem = f"{stem[:pos]}{marker}_{tag}{stem[pos + len(marker):]}"
        return os.path.join(folder, stem + ext)

    def _append_metrics_row(row: dict, csv_path: str) -> None:
        """Append ``row`` to ``csv_path``, writing the header only when the file is new."""
        is_new = not os.path.exists(csv_path)
        pd.DataFrame([row]).to_csv(
            csv_path,
            mode="w" if is_new else "a",
            header=is_new,
            index=False,
        )

    # Score the test period at coarser resolutions using non-overlapping block means.
    # Trailing partial blocks are dropped. NSE and KGE normalise by the spread of the
    # observed series, which is zero for a single block, so at least two blocks are needed.
    min_blocks = 2
    n_days = min(len(act_daily), len(pred_daily))
    for block_days, tag in ((7, "7day"), (30, "30day")):
        act_blocks = _block_mean(act_daily[:n_days], block_days)
        pred_blocks = _block_mean(pred_daily[:n_days], block_days)
        if len(act_blocks) < min_blocks:
            logger.warning(
                f"Only {len(act_blocks)} full {block_days}-day block(s) for {station_append}; "
                f"skipping {tag} metrics."
            )
            continue
        row = _metrics_row(act_blocks, pred_blocks, mean_gwl)
        _append_metrics_row(row, _aggregate_metrics_file(output_file, tag))

    logger.info(f"\n\nSTATION TEST: {station_append} COMPLETE. MOVING TO NEXT...\n\n")
    logger.info(f"Station {test_number} / {num_stations} complete.\n\n")
    
# Runs once, after the per-station loop above has finished.
logger.info("\n\nAll station tests complete.")

Ray Tune results analysis (the code cell below is currently commented out)

In [None]:
# import pandas as pd
# from ray.tune import ExperimentAnalysis
# import os
# import logging
# import sys
# import glob
# import numpy as np

# logging.basicConfig(
#     level=logging.INFO,
#     format='%(levelname)s - %(message)s',
#     handlers=[logging.StreamHandler(sys.stdout)]
# )
# logger = logging.getLogger(__name__)

# ray_tune_dir = "data/04_model/eden/model/ray_tune_gwl"

# if not os.path.isdir(ray_tune_dir):
#     logger.error(f"Error: The specified directory does not exist or is not a directory: {ray_tune_dir}")
# else:
#     try:
#         abs_dir = os.path.abspath(ray_tune_dir)
#         uri = f"file://{abs_dir}"

#         analysis = ExperimentAnalysis(uri)
#         df = analysis.dataframe()

#         # Coerce numeric metrics (in case anything logged as strings)
#         for col in ['val_loss', 'train_loss', 'epoch']:
#             if col in df.columns:
#                 df[col] = pd.to_numeric(df[col], errors='coerce')

#         # Keep only rows with a valid trial_id
#         if 'trial_id' not in df.columns:
#             raise KeyError("Expected 'trial_id' column not found in Ray Tune results DataFrame.")
#         df = df[df['trial_id'].notna()].copy()

#         # filter out errored trials (errored due to timeouts)
#         if 'error' in df.columns:
#             df = df[df['error'].isna()].copy()

#         logger.info(f"Total rows (iterations) loaded: {len(df)}")
#         n_trials = df['trial_id'].nunique()
#         logger.info(f"Trials represented: {n_trials}")

#         # Compute per-trial aggregates
#         config_cols = [c for c in df.columns if c.startswith('config/')]
#         keep_cols = ['trial_id', 'logdir'] + config_cols
#         meta_first = (df[keep_cols]
#                       .sort_values(['trial_id']) 
#                       .groupby('trial_id', as_index=False)
#                       .first())
        
#         # Trials that never logged val_loss/train_loss (e.g. timed out) get NaN aggregates and are dropped before ranking
#         agg = (df.groupby('trial_id')
#                  .agg(min_val_loss=('val_loss', 'min'),
#                       mean_val_loss=('val_loss', 'mean'),
#                       min_train_loss=('train_loss', 'min'),
#                       mean_train_loss=('train_loss', 'mean'),
#                       last_epoch=('epoch', 'max'),
#                       last_iter=('training_iteration', 'max'))
#                  .reset_index())

#         # Merge configs back in
#         per_trial = meta_first.merge(agg, on='trial_id', how='left')

#         # Keep only trials with at least some val_loss signal
#         ranked = per_trial[per_trial['min_val_loss'].notna()].copy()
#         if ranked.empty:
#             raise RuntimeError("No trials have non-NaN 'val_loss'. Cannot rank.")

#         ranked = ranked.sort_values(['min_val_loss', 'mean_val_loss'], ascending=[True, True]).reset_index(drop=True)

#         best = ranked.iloc[0]
#         trial_id = best['trial_id']

#         # Pack nice dicts for logging
#         best_config = {k: best[k] for k in config_cols if k in ranked.columns}
#         best_metrics = {k: best[k] for k in ['min_val_loss', 'mean_val_loss', 'min_train_loss', 'mean_train_loss',
#                                              'last_epoch', 'last_iter'] if k in ranked.columns}

#         logger.info("\n--- Best Trial (by min_val_loss, then mean_val_loss) ---")
#         logger.info(f"Trial ID: {trial_id}")
#         logger.info(f"Best metrics: {best_metrics}")
#         logger.info(f"Best hyperparameters: {best_config}")

#         # Locate the corresponding PT model for this trial_id
#         pt_root = os.path.abspath("data/04_model/eden/model/pt_model")
#         pattern = os.path.join(pt_root, f"trial_{trial_id}", "pt_model", "*.pt")
#         pt_candidates = sorted(glob.glob(pattern))
#         if pt_candidates:
#             best_pt_model_path = pt_candidates[-1]  # choose last
#             logger.info(f"Best trial PT model: {best_pt_model_path}")
#         else:
#             logger.warning(f"No .pt model files found for trial_id={trial_id} with pattern: {pattern}")

#         # Show a compact table of the top 10 trials
#         display_cols = (['trial_id'] + config_cols +
#                         ['min_val_loss', 'mean_val_loss', 'min_train_loss', 'mean_train_loss', 'last_epoch'])
#         display_cols = [c for c in display_cols if c in ranked.columns]
#         top10 = ranked[display_cols].head(10)
#         pd.set_option('display.max_columns', None)
#         pd.set_option('display.width', 160)
#         print("\nTop 10 trials by validation loss:")
#         print(top10.to_string(index=False))

#     except Exception as e:
#         logger.error(f"An error occurred during analysis: {e}")
