In [None]:
# Load library imports
import os
import sys
import shap
import torch
import random
import logging
import datetime
import numpy as np
import pandas as pd
from tqdm import tqdm
import torch.nn as nn
import geopandas as gpd
from collections import Counter
import matplotlib.pyplot as plt
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader

# Load project Imports
from src.utils.config_loader import load_project_config, deep_format, expanduser_tree
from src.model.GAT_LSTM_class import GAT_LSTM_Model

In [None]:
# Configure root logging: INFO level, compact "<LEVEL> - <message>" records written to stdout.
# (A more verbose alternative format is '%(asctime)s - %(levelname)s - %(name)s - %(message)s'.)
logging.basicConfig(
    handlers=[logging.StreamHandler(sys.stdout)],
    format='%(levelname)s - %(message)s',
    level=logging.INFO,
)

# Module-level logger plus the flag that is reported later as "Show Notebook Outputs"
logger = logging.getLogger(__name__)
notebook = True

# Read the project configuration (paths and parameters) from YAML
config = load_project_config(config_path="config/project_config.yaml")

# Pull the two root directories out of the raw config ...
raw_data_root = config["global"]["paths"]["raw_data_root"]
results_root = config["global"]["paths"]["results_root"]

# ... pass them to deep_format as keyword values, then run expanduser_tree over the result
config = expanduser_tree(
    deep_format(config, raw_data_root=raw_data_root, results_root=results_root)
)

In [None]:
# Pipeline-wide settings shared by the seeding and catchment selection below
pipeline_settings = config["global"]["pipeline_settings"]

# Seed every random number generator in use (Python, NumPy, PyTorch CPU and all CUDA devices)
random_seed = pipeline_settings["random_seed"]
for seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
    seed_fn(random_seed)

# Ask cuDNN for deterministic algorithms and switch off its benchmark mode
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

# The notebook demonstrates the first catchment listed in the config
catchments_to_process = pipeline_settings["catchments_to_process"]
catchment = catchments_to_process[0]
run_defra_API_calls = pipeline_settings["run_defra_api"]

logger.info(f"Show Notebook Outputs: {notebook}")
logger.info(f"Notebook Demo Catchment: {catchment.capitalize()}")

In [None]:
# Load the list of per-timestep PyG graph objects for the demo catchment.
#
# - weights_only=False: the file holds pickled Python objects rather than a plain state dict.
#   From PyTorch 2.6 torch.load defaults to weights_only=True, which can reject such objects.
#   SECURITY: this runs the full pickle machinery, so only load files produced by this project.
# - map_location="cpu": later cells call `.numpy()` on `data.x`, which requires CPU tensors.
all_timesteps_list = torch.load(
    config[catchment]["paths"]["pyg_object_path"],
    map_location="cpu",
    weights_only=False,
)

# Preview only the first few timesteps rather than dumping the whole list into the cell output
all_timesteps_list[:3]

In [None]:
# --- Load GAT-LSTM Model and Data Loaders ---

# Define model checkpoint path
# NOTE(review): this path is hard-coded to the "eden" catchment while `catchment` comes from the
# config; confirm they match, or move the checkpoint path into the config.
checkpoint_path = "data/04_model/eden/model/pt_model/model_20250812-092352_GATTrue_LSTMFalse_GATH12_GATD0-4_GATHC64_GATOC64_GATNL2_LSTHC32_LSTNL1_OUTD1_LR0-001_WD0-001_SM0-1_E200_ESP30_LRSF0-5_LRSP8_MINLR1e-06_LD0-0001_GCMN1-0.pt"

# Load model hyperparameters from config
model_config = config[catchment]["model"]["architecture"]
training_config = config[catchment]["training"]

# Feature dimensions: the per-node feature count is read from the first timestep, and every
# feature not listed as temporal is counted as static.
temporal_features = model_config["temporal_features"]
temporal_features_dim = len(temporal_features)
in_channels = all_timesteps_list[0].x.shape[1]

# Instantiate the model class with the architecture settings from the config
model = GAT_LSTM_Model(
    in_channels=in_channels,
    temporal_features_dim=temporal_features_dim,
    static_features_dim=in_channels - temporal_features_dim,
    hidden_channels_gat=model_config["hidden_channels_gat"],
    out_channels_gat=model_config["out_channels_gat"],
    heads_gat=model_config["heads_gat"],
    dropout_gat=model_config["dropout_gat"],
    hidden_channels_lstm=model_config["hidden_channels_lstm"],
    num_layers_lstm=model_config["num_layers_lstm"],
    num_layers_gat=model_config["num_layers_gat"],
    num_nodes=len(all_timesteps_list[0].x),
    output_dim=model_config["output_dim"],
    run_GAT=model_config["run_GAT"],
    run_LSTM=model_config["run_LSTM"],
    random_seed=random_seed,
    catchment=catchment,
    run_node_conditioner=model_config["run_node_conditioner"],
    fusion_mode=model_config["fusion_mode"]
)

# Load trained model (checkpoint) state.
# map_location="cpu" remaps the stored tensors onto the CPU so that a checkpoint saved from a CUDA
# device still loads on a CPU-only host; the model is never moved off the CPU in this notebook.
model.load_state_dict(torch.load(checkpoint_path, map_location="cpu"))
logger.info(f"Successfully loaded model from {checkpoint_path}")
model.eval()

# Run data loader (one timestep per batch, original order preserved)
data_loader = DataLoader(all_timesteps_list, batch_size=1, shuffle=False)

In [None]:
# --- Define custom SHAP prediction wrapper ---

# Create PyG model from numpy array
def predict_func_shap(data_as_numpy):
    """
    Prediction wrapper that lets SHAP call the GAT-LSTM model on plain NumPy input.

    Each row of `data_as_numpy` is one flattened set of node features. The row is reshaped to
    (model.num_nodes, -1), wrapped in a PyG Data object that reuses the graph structure of the
    first timestep, passed through the model, and the prediction is flattened back into one row.

    NOTE: It assumes edge_index, edge_attr and node_id are the same for all timesteps (currently
    the case, but this will need adjusting in future if that changes).

    Args:
        data_as_numpy: 2D NumPy array of shape (n_samples, n_flat_features). Each row must be
            reshapeable to (model.num_nodes, n_features_per_node).

    Returns:
        np.ndarray with one row per input sample holding that sample's flattened predictions.
    """
    # Ensure the model is in evaluation mode
    model.eval()

    # The forward pass always receives an LSTM state argument, so the name must always be defined:
    # zero-filled tensors when the LSTM branch is enabled, otherwise None.
    # NOTE(review): assumes the model accepts None for the state when run_LSTM is False - confirm.
    # NOTE(review): the same state object is passed for every sample; if the model mutates it in
    # place, state would carry over between SHAP samples - confirm against the model class.
    lstm_state_store = None
    if model.run_LSTM:
        state_shape = (model.num_layers_lstm, model.num_nodes, model.hidden_channels_lstm)
        lstm_state_store = {
            'h': torch.zeros(state_shape),
            'c': torch.zeros(state_shape)
        }

    # Graph structure is shared by every reconstructed sample, so look it up once
    reference_graph = all_timesteps_list[0]

    # Initialise predictions list to build array
    predictions = []

    # Ensure no gradient and iterate over each sample provided by SHAP
    with torch.no_grad():
        for sample in data_as_numpy:

            # Reconstruct the PyG Data object from the NumPy feature array
            data_obj = Data(x=torch.tensor(sample.reshape(model.num_nodes, -1), dtype=torch.float32),
                            edge_index=reference_graph.edge_index,
                            edge_attr=reference_graph.edge_attr,
                            node_id=reference_graph.node_id)

            # Define vars required for fwd pass
            x_full = data_obj.x
            node_ids_in_current_timestep = data_obj.node_id

            # Run forward pass using reconstructed data object. The model must actually be called
            # here: unpacking the bare 5-element argument tuple into 3 names raises ValueError.
            # Only the first returned value (the predictions) is used.
            predictions_all_nodes, _, _ = model(
                x_full, data_obj.edge_index, data_obj.edge_attr, node_ids_in_current_timestep, lstm_state_store
            )

            # Flatten the output so each input row maps to exactly one output row. With more than
            # one predicted value per sample, SHAP treats the model as multi-output.
            predictions.append(predictions_all_nodes.cpu().numpy().flatten())

    return np.array(predictions)

In [None]:
# --- Prepare data for SHAP Explainer ---

# KernelExplainer works on a 2D NumPy array (samples x features), so every timestep's node-feature
# matrix is flattened into a single row.

# Small, fixed windows of timesteps: the first 10 form the background set, the next 10 are explained
background_timesteps = all_timesteps_list[:10]
explained_timesteps = all_timesteps_list[10:20]

# One flattened feature vector per timestep
background_data_list = [timestep.x.numpy().flatten() for timestep in background_timesteps]
samples_to_explain_list = [timestep.x.numpy().flatten() for timestep in explained_timesteps]

# Stack the vectors row-wise into (n_timesteps, n_flat_features) arrays
background_data_array = np.vstack(background_data_list)
samples_to_explain_array = np.vstack(samples_to_explain_list)

# Build the explainer around the PyG reconstruction wrapper, then compute the SHAP values
# (computationally intensive)
explainer = shap.KernelExplainer(predict_func_shap, background_data_array)
logger.info("Calculating SHAP values... This may take a while.")
shap_values = explainer.shap_values(samples_to_explain_array)

In [None]:
# --- Visualise the SHAP results ---

# Per-node feature names and node ids are taken from the first timestep
first_timestep = all_timesteps_list[0]
feature_names = first_timestep.features
node_ids = first_timestep.node_id.tolist()

# One label per flattened input column, node-major to match the row-wise flattening of `x`
flat_feature_names = [
    f"Node {node_id}: {feat_name}"
    for node_id in node_ids
    for feat_name in feature_names
]

# Generate a summary plot of the impact of each flattened feature on the model output
logger.info("\nGenerating SHAP summary plot...")
shap.summary_plot(
    shap_values,
    samples_to_explain_array,
    feature_names=flat_feature_names,
)
plt.show()

In [None]:
# Visualise a single prediction: one explained sample and one model output
sample_idx = 0  # row of samples_to_explain_array to visualise
output_idx = 0  # column of the flattened prediction returned by predict_func_shap

# Load SHAP's JavaScript so the interactive force plot renders in the notebook
shap.initjs()

# A force plot needs ONE base value and ONE attribution vector. For a multi-output model
# KernelExplainer returns per-output results, and the container depends on the SHAP version:
#   - list of (n_samples, n_features) arrays, one per output, or
#   - a single (n_samples, n_features, n_outputs) array.
# A single-output model yields a plain (n_samples, n_features) array.
if isinstance(shap_values, list):
    sample_shap_values = shap_values[output_idx][sample_idx]
else:
    shap_array = np.asarray(shap_values)
    if shap_array.ndim == 3:
        sample_shap_values = shap_array[sample_idx, :, output_idx]
    else:
        sample_shap_values = shap_array[sample_idx]

# expected_value is a scalar for single-output models and a vector for multi-output ones
base_values = np.atleast_1d(explainer.expected_value)
base_value = base_values[output_idx] if base_values.size > 1 else base_values[0]

logger.info("\nGenerating SHAP force plot for a single instance...")
shap.force_plot(
    base_value,
    sample_shap_values,
    samples_to_explain_array[sample_idx],
    feature_names=flat_feature_names,
)