In [None]:
# import torch

# # Define test path
# model_path = 'data/04_model/eden/model/best_model.pt' # Adjust this path as needed

# # Load the state dictionary
# try:
#     state_dict = torch.load(model_path)
#     print("Model state dictionary loaded successfully!")
# except Exception as e:
#     print(f"Error loading model: {e}")
#     exit() # Exit if loading fails

# # --- Inspect the contents ---

# # Print all keys (layer names) in the state dictionary
# print("\nKeys in the state dictionary:")
# for key in state_dict.keys():
#     print(key)

# # Inspect the shape and device of a few params
# print("\nExample parameters from the state dictionary (first few keys):\n")
# for i, (key, value) in enumerate(state_dict.items()):
#     if i >= 5: # Limit to first 5
#         break
#     print(f"  Key: {key}")
#     print(f"  Shape: {value.shape}")
#     print(f"  Device: {value.device}")
#     print(f"      Value (first 5 elements): {value.flatten()[:5].tolist()}\n")

# if 'gat_layers.0.lin_src.weight' in state_dict:
#     print(f"\nShape of first GAT layer weights: {state_dict['gat_layers.0.lin_src.weight'].shape}")

In [None]:
# Load library imports
import os
import sys
import torch
import random
import logging
import datetime
import numpy as np
import pandas as pd
import geopandas as gpd
from collections import Counter

# Load project Imports
from src.utils.config_loader import load_project_config
from src.model.model_building import build_data_loader, instantiate_model_and_associated
from src.utils.config_loader import load_project_config
from src.graph_building.graph_construction import build_mesh, \
    define_catchment_polygon, define_graph_adjacency
from src.preprocessing.data_partitioning import define_station_id_splits, \
    load_graph_tensors, build_pyg_object
from src.preprocessing.model_feature_engineering import preprocess_gwl_features, \
    preprocess_shared_features

In [None]:
# Configure root logging: INFO level and above, written to stdout.
# A more verbose alternative format is:
#   '%(asctime)s - %(levelname)s - %(name)s - %(message)s'
stdout_handler = logging.StreamHandler(sys.stdout)
logging.basicConfig(
    level=logging.INFO,
    format='%(levelname)s - %(message)s',
    handlers=[stdout_handler],
)

# Logger for this notebook, plus the project config (paths and params) loaded
# from the config file path given below
logger = logging.getLogger(__name__)
config = load_project_config(config_path="config/project_config.yaml")
notebook = True

# Seed each random number generator (Python, NumPy, PyTorch, PyTorch CUDA)
# with the single seed value held in the project config
random_seed = config["global"]["pipeline_settings"]["random_seed"]
for apply_seed in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
    apply_seed(random_seed)

# cuDNN flags: request deterministic behaviour and disable benchmark mode
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Read the pipeline-wide settings once; the first catchment in the configured
# list is the one this notebook works through
pipeline_settings = config["global"]["pipeline_settings"]
catchments_to_process = pipeline_settings["catchments_to_process"]
catchment = catchments_to_process[0]
run_defra_API_calls = pipeline_settings["run_defra_api"]

# Report the active notebook settings
logger.info(f"Show Notebook Outputs: {notebook}")
logger.info(f"Notebook Demo Catchment: {catchment.capitalize()}")

In [None]:
# Select this catchment's area from the country-wide catchment boundary file
catchment_cfg = config[catchment]
catchment_paths = catchment_cfg['paths']

define_catchment_polygon(
    england_catchment_gdf_path=catchment_paths['gis_catchment_boundary'],
    target_mncat=catchment_cfg['target_mncat'],
    catchment=catchment,
    polygon_output_path=catchment_paths['gis_catchment_dir'],
)

# Build the catchment mesh. The shape file path is the same `gis_catchment_dir`
# config entry that was passed as the polygon output path above.
mesh_kwargs = dict(
    shape_filepath=catchment_paths['gis_catchment_dir'],
    output_path=catchment_paths['mesh_nodes_output'],
    catchment=catchment,
    grid_resolution=catchment_cfg['preprocessing']['graph_construction']['grid_resolution'],
)
mesh_nodes_table, mesh_nodes_gdf, mesh_cells_gdf_polygons, catchment_polygon = build_mesh(**mesh_kwargs)

logger.info(f"Pipeline step 'Build Mesh' complete for {catchment} catchment.")

In [None]:
# Read the directional edge weights CSV configured for this catchment
directional_edge_path = config[catchment]["paths"]["direction_edge_weights_path"]
directional_edge_weights = pd.read_csv(directional_edge_path)

# Add a positional node_id column (0 .. number of rows - 1) to merge on
n_weight_rows = len(directional_edge_weights)
directional_edge_weights["node_id"] = range(n_weight_rows)
directional_edge_weights

In [None]:
# Load in directional edge weights and mean elevation (not req. in main pipeline).
# The CSV is re-read here, so this frame does not carry any columns added to an
# earlier in-memory copy.
catchment_paths = config[catchment]["paths"]
directional_edge_path = catchment_paths["direction_edge_weights_path"]
directional_edge_weights = pd.read_csv(directional_edge_path)

# NOTE(review): the result is unpacked as (edge_attr, edge_index), while the
# commented-out load_graph_tensors call in a later cell unpacks
# (edge_index, edge_attr) — confirm each matches its function's return order.
adjacency_kwargs = {
    "directional_edge_weights": directional_edge_weights,
    "elevation_geojson_path": catchment_paths['elevation_geojson_path'],
    "graph_output_dir": catchment_paths["graph_data_output_dir"],
    "mesh_cells_gdf_polygons": mesh_cells_gdf_polygons,
    "epsilon_path": config["global"]["graph"]["epsilon"],
    "catchment": catchment,
}
edge_attr_tensor, edge_index_tensor = define_graph_adjacency(**adjacency_kwargs)

logger.info(f"Pipeline step 'Define Graph Adjacency' complete for {catchment} catchment.\n")

In [None]:
# # Load tensors from file if needed
# edge_index_tensor, edge_attr_tensor = load_graph_tensors(
#     graph_output_dir=config[catchment]["paths"]["graph_data_output_dir"],
#     catchment=catchment
# )

# Load main_df_full from file if needed.
# NOTE(review): plain string concatenation assumes the configured
# `final_df_path` already ends with a path separator — confirm.
final_df_location = config[catchment]["paths"]["final_df_path"]
load_path = final_df_location + 'final_df.csv'
main_df_full = pd.read_csv(load_path)
main_df_full

In [None]:
# --- 6a. Define Spatial Split for Observed Stations ---

# Every partitioning setting lives in one config section; look it up once
partition_cfg = config[catchment]["model"]["data_partioning"]

train_station_ids, val_station_ids, test_station_ids = define_station_id_splits(
    main_df_full=main_df_full,
    catchment=catchment,
    test_station_shortlist=partition_cfg["test_station_shortlist"],
    val_station_shortlist=partition_cfg["val_station_shortlist"],
    random_seed=config["global"]["pipeline_settings"]["random_seed"],
    perc_train=partition_cfg["percentage_train"],
    perc_val=partition_cfg["percentage_val"],
    perc_test=partition_cfg["percentage_test"],
)

logger.info(f"Pipeline Step 'define station splits' complete for {catchment} catchment.")

In [None]:
# --- 6b. Preprocess (Standardise, one hot encode, round to 4dp) all shared features (not GWL) ---

# Call the shared-feature preprocessor, then unpack its four results
shared_feature_outputs = preprocess_shared_features(
    main_df_full=main_df_full,
    catchment=catchment,
    random_seed=config["global"]["pipeline_settings"]["random_seed"],
    violin_plt_path=config[catchment]["visualisations"]["violin_plt_path"],
    scaler_dir=config[catchment]["paths"]["scalers_dir"],
)
processed_df, shared_scaler, shared_encoder, gwl_feats = shared_feature_outputs

logger.info(f"Pipeline Step 'Preprocess Final Shared Features' complete for {catchment} catchment.")

In [None]:
# --- 6c. Preprocess all GWL features using training data only ---

# NOTE(review): `processed_df` is both the input and the rebound output, so
# re-running this cell passes the already-processed frame back in — confirm
# that is safe, or re-run the previous cell first.
gwl_feature_kwargs = dict(
    processed_df=processed_df,
    catchment=catchment,
    train_station_ids=train_station_ids,
    val_station_ids=val_station_ids,
    test_station_ids=test_station_ids,
    sentinel_value=config["global"]["graph"]["sentinel_value"],
    scaler_dir=config[catchment]["paths"]["scalers_dir"],
)
processed_df, gwl_scaler, gwl_encoder = preprocess_gwl_features(**gwl_feature_kwargs)

logger.info(f"Pipeline Step 'Preprocess Final GWL Features' complete for {catchment} catchment.")

In [None]:
# Load the saved object stored at the configured `pyg_object_path`; later
# cells iterate over it one element at a time.
# NOTE(review): newer PyTorch releases default `torch.load` to
# weights_only=True, which can reject pickled non-tensor objects — confirm
# against the installed version.
pyg_object_path = config[catchment]["paths"]["pyg_object_path"]
all_timesteps_list = torch.load(pyg_object_path)
all_timesteps_list

In [None]:
# --- 7a. Build Data Loaders by Timestep ---

# Batch size and shuffle flag both come from the global model config section
loader_cfg = config["global"]["model"]

full_dataset_loader = build_data_loader(
    all_timesteps_list=all_timesteps_list,
    batch_size=loader_cfg["data_loader_batch_size"],
    shuffle=loader_cfg["data_loader_shuffle"],
    catchment=catchment,
)

logger.info(f"Pipeline Step 'Create PyG DataLoaders' complete for {catchment} catchment.")

In [None]:
# --- 7b. Define Graph Neural Network Architecture ---

# Create the model together with its device, optimizer and loss criterion
model_components = instantiate_model_and_associated(
    all_timesteps_list=all_timesteps_list,
    config=config,
    catchment=catchment,
)
model, device, optimizer, criterion = model_components

logger.info(f"Pipeline Step 'Instantiate GAT-LSTM Model' complete for {catchment} catchment.")

In [None]:
import torch
import joblib
import logging
import numpy as np
from tqdm import tqdm
from sklearn.metrics import mean_absolute_error, mean_squared_error

# Checkpoint to evaluate, and the run number used to prefix saved figure names
path = "data/04_model/eden/model/pt_model/model_20250727-225058_GATTrue_LSTMFalse_GATH12_GATD0-4_GATHC64_GATOC64_GATNL2_LSTHC32_LSTNL1_OUTD1_LR0-001_WD0-0005_SM0-0_E200_ESP25_LRSF0-5_LRSP8_MINLR1e-06_LD0-0001_GCMN1-0.pt"
iteration = 5

# `best_model` is the same object as `model` (an alias, not a copy), so loading
# the weights below overwrites the parameters of `model` itself.
best_model = model  # Assume model object already defined and moved to correct device

# map_location remaps the stored tensors onto the active device, so a
# checkpoint saved on a CUDA device still loads on a CPU-only (or different-GPU)
# machine instead of failing during deserialisation.
state_dict = torch.load(path, map_location=device)
best_model.load_state_dict(state_dict)
best_model.eval()  # evaluation mode (e.g. disables dropout)
logger.info(f"Loaded best model from {path}")

# Load target scaler.
# joblib.load unpickles the file, so only point this at trusted artefacts.
scaler_path = "data/03_graph/eden/scalers/target_scaler.pkl"
target_scaler = joblib.load(scaler_path)
logger.info(f"Loaded target scaler from: {scaler_path}")

# Scaler statistics as tensors on the model's device
scale = torch.tensor(target_scaler.scale_, device=device)
mean = torch.tensor(target_scaler.mean_, device=device)

# Initialise the global LSTM state: zero-filled hidden ('h') and cell ('c')
# tensors of shape (lstm layers, nodes, lstm hidden channels) on the model's
# device. Models that do not run the LSTM get None instead.
lstm_state_store = None
if best_model.run_LSTM:
    lstm_state_shape = (
        best_model.num_layers_lstm,
        best_model.num_nodes,
        best_model.hidden_channels_lstm,
    )
    lstm_state_store = {
        state_key: torch.zeros(*lstm_state_shape).to(device)
        for state_key in ('h', 'c')
    }

# Accumulators for test-node predictions and targets, in original (unscaled) units
test_predictions_unscaled = []
test_actuals_unscaled = []

logger.info("--- Starting Model Evaluation on Test Set ---")
test_loop = tqdm(all_timesteps_list, desc="Evaluating on Test Set", leave=False)

# Gradients are not needed for evaluation
with torch.no_grad():
    for i, data in enumerate(test_loop):
        data = data.to(device)
        test_mask = data.test_mask
        has_test_nodes = bool(test_mask.sum() > 0)

        # With no LSTM there is no state to carry between timesteps, so a
        # timestep without test nodes can be skipped before the forward pass.
        if not has_test_nodes and not best_model.run_LSTM:
            continue

        # Model forward pass on full node set
        predictions_all, (h_new, c_new), returned_node_ids = best_model(
            x=data.x,
            edge_index=data.edge_index,
            edge_attr=data.edge_attr,
            current_timestep_node_ids=data.node_id,
            lstm_state_store=lstm_state_store
        )

        # Update LSTM memory for current nodes. This runs for every timestep,
        # including those with no test nodes, so the recurrent state passed to
        # later timesteps is not left stale across skipped steps.
        if best_model.run_LSTM:
            lstm_state_store['h'][:, returned_node_ids, :] = h_new.detach()
            lstm_state_store['c'][:, returned_node_ids, :] = c_new.detach()

        # Nothing to score at this timestep
        if not has_test_nodes:
            continue

        # Filter predictions/targets to test nodes
        preds_std = predictions_all[test_mask]
        targets_std = data.y[test_mask]

        # Inverse transform to original scale (the scaler works on NumPy arrays)
        preds_np = preds_std.cpu().numpy()
        targets_np = targets_std.cpu().numpy()

        preds_unscaled = target_scaler.inverse_transform(preds_np)
        targets_unscaled = target_scaler.inverse_transform(targets_np)

        test_predictions_unscaled.extend(preds_unscaled.flatten())
        test_actuals_unscaled.extend(targets_unscaled.flatten())

        if i < 5:  # Show first few predictions
            print("Sample predictions (m AOD):", preds_unscaled[:5].flatten())
            print("Sample actuals     (m AOD):", targets_unscaled[:5].flatten())

# Final evaluation: score the collected test values with the metric named in
# the config ("MAE" or "MSE"); warn when nothing was collected or the name is
# not recognised.
if not test_actuals_unscaled:
    logger.warning("No test data found — check 'data.test_mask'.")
else:
    loss_type = config[catchment]["training"]["loss"]

    if loss_type == "MAE":
        final_test_metric = mean_absolute_error(test_actuals_unscaled, test_predictions_unscaled)
        logger.info(f"\n--- Final Test Set MAE (m AOD): {final_test_metric:.4f} ---\n")
    elif loss_type == "MSE":
        final_test_metric = mean_squared_error(test_actuals_unscaled, test_predictions_unscaled)
        logger.info(f"\n--- Final Test Set MSE (m AOD²): {final_test_metric:.4f} ---\n")
    else:
        logger.warning(f"Unrecognized loss type '{loss_type}' in config — skipping final metric calculation.")

logger.info("--- Model Evaluation on Test Set Complete ---")

In [None]:
# Load the target scaler
from joblib import load

# Load target scaler (y, 'gwl_value') from the configured scalers directory.
# If loading fails, `target_scaler` is set to None and a warning is logged.
scalers_dir = config[catchment]["paths"]["scalers_dir"]
target_scaler_path = os.path.join(scalers_dir, "target_scaler.pkl")
try:
    target_scaler = load(target_scaler_path)
except Exception as e:
    logger.warning(f"No target scaler found or error loading it: {e}")
    target_scaler = None
else:
    logger.info(f"Successfully loaded target scaler from: {target_scaler_path}")

# Convert both result lists to single-column NumPy arrays
test_predictions_np, test_actuals_np = (
    np.array(values).reshape(-1, 1)
    for values in (test_predictions_unscaled, test_actuals_unscaled)
)

# Log min/max of each array as a sanity check.
# NOTE(review): .min()/.max() raise on empty arrays, so this cell fails when no
# test values were collected.
logger.info(f"Sample prediction range: {test_predictions_np.min():.2f} to {test_predictions_np.max():.2f}")
logger.info(f"Sample actual range:     {test_actuals_np.min():.2f} to {test_actuals_np.max():.2f}")

# Names used by the plotting/metric cells
test_predictions_final, test_actuals_final = test_predictions_np, test_actuals_np

In [None]:
from sklearn.metrics import mean_absolute_error

# The unit label in the log line depends on whether a target scaler is loaded
if target_scaler:
    unit_label = "mAOD"
else:
    unit_label = "standard units"

# Mean absolute error between actual and predicted test values
final_test_mae = mean_absolute_error(test_actuals_final, test_predictions_final)
logger.info(f"Final Test Set MAE: {final_test_mae:.4f} {unit_label}")


In [None]:
import seaborn as sns
import matplotlib.pyplot as plt

# Overlay actual and predicted test values on a single figure
fig, ax = plt.subplots(figsize=(15, 7))

ax.plot(test_actuals_final, label='Actual GWL Value', color='blue', alpha=0.7)
ax.plot(test_predictions_final, label='Predicted GWL Value', color='red', alpha=0.7, linestyle='--')

ax.set_title('Actual vs. Predicted Groundwater Levels on Test Set')
ax.set_xlabel('Data Point Index (Sequential Timesteps/Stations)')
ax.set_ylabel('Groundwater Level (m AOD)' if target_scaler else 'Standardised GWL')
ax.legend()
ax.grid(True)
fig.tight_layout()

# Take the model file's base name and drop its extension so the saved figure
# name matches the model file, prefixed with the iteration number
base_name = os.path.basename(path)
filename_no_ext, extension = os.path.splitext(base_name)
save_dir = "results/trained_models/eden/"
save_path = save_dir + f"{iteration}_" + filename_no_ext
# save_path = os.path.join("results/trained_models/eden/2_", filename_no_ext)

# savefig does not create missing directories, so make sure the target exists;
# otherwise the cell fails with FileNotFoundError on a fresh checkout.
os.makedirs(save_dir, exist_ok=True)
fig.savefig(save_path, dpi=300)

plt.show()


In [None]:
import seaborn as sns
import matplotlib.pyplot as plt

# Apply seaborn's "whitegrid" style to the plots (optional)
sns.set_style("whitegrid")

# One wide, short axes holding only the predicted series
fig, ax = plt.subplots(figsize=(15, 3))
ax.plot(test_predictions_unscaled, label='Predicted GWL Value', color='red', alpha=0.7, linestyle='--')

# Title, axis labels, legend and grid
ax.set_title('Predicted Groundwater Levels on Test Set')
ax.set_xlabel('Data Point Index (Sequential Timesteps/Stations)')
ax.set_ylabel('Groundwater Level (m AOD)')
ax.legend()
ax.grid(True)
fig.tight_layout()

plt.show()

logger.info("Generated plot of predicted values.")