In [None]:
# Load library imports
import os
import sys
import torch
import joblib
import random
import logging
import datetime
import subprocess
import numpy as np
import pandas as pd
import geopandas as gpd

# Load project Imports
from src.utils.config_loader import load_project_config, deep_format, expanduser_tree
from src.graph_building.graph_construction import build_mesh, \
    define_catchment_polygon, define_graph_adjacency
from src.preprocessing.data_partitioning import define_station_id_splits, \
    load_graph_tensors, build_pyg_object
from src.preprocessing.model_feature_engineering import preprocess_gwl_features, \
    preprocess_shared_features
from src.utils.run_manifest import save_run_manifest 

In [None]:
# Root logging: INFO and above, plain "LEVEL - message" lines on stdout
logging.basicConfig(
    level=logging.INFO,
    format='%(levelname)s - %(message)s',
    # format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
    handlers=[logging.StreamHandler(sys.stdout)],
)

# Module logger, project config (paths and params) and the notebook-output flag
logger = logging.getLogger(__name__)
config = load_project_config(config_path="config/project_config.yaml")
notebook = True

# Root directories used to re-format the rest of the config
global_paths = config["global"]["paths"]
raw_data_root = global_paths["raw_data_root"]
results_root = global_paths["results_root"]

# Re-format config strings with the roots (deep_format), then pass through expanduser_tree
config = expanduser_tree(
    deep_format(config, raw_data_root=raw_data_root, results_root=results_root)
)

In [None]:
# Seed every RNG used here (Python, NumPy, PyTorch CPU + all CUDA devices) from config
pipeline_settings = config["global"]["pipeline_settings"]
random_seed = pipeline_settings["random_seed"]
for seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
    seed_fn(random_seed)
# Deterministic cuDNN kernel selection for reproducibility
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# The first configured catchment is the notebook's demo catchment
catchments_to_process = pipeline_settings["catchments_to_process"]
catchment = catchments_to_process[0]
run_defra_API_calls = pipeline_settings["run_defra_api"]

logger.info(f"Show Notebook Outputs: {notebook}")
logger.info(f"Notebook Demo Catchment: {catchment.capitalize()}")

In [None]:
# Select the catchment area from the country-wide catchment GeoDataFrame;
# the resulting polygon goes to the catchment GIS directory (polygon_output_path)
catchment_paths = config[catchment]['paths']
define_catchment_polygon(
    england_catchment_gdf_path=catchment_paths['gis_catchment_boundary'],
    target_mncat=config[catchment]['target_mncat'],
    catchment=catchment,
    polygon_output_path=catchment_paths['gis_catchment_dir'],
)

# Build the catchment mesh at the configured grid resolution
mesh_grid_resolution = config[catchment]['preprocessing']['graph_construction']['grid_resolution']
mesh_outputs = build_mesh(
    shape_filepath=catchment_paths['gis_catchment_dir'],
    output_path=catchment_paths['mesh_nodes_output'],
    catchment=catchment,
    grid_resolution=mesh_grid_resolution,
)
mesh_nodes_table, mesh_nodes_gdf, mesh_cells_gdf_polygons, catchment_polygon = mesh_outputs

logger.info(f"Pipeline step 'Build Mesh' complete for {catchment} catchment.")

In [None]:
# Directional edge weights plus a positional node_id (0..n-1) to merge on
directional_edge_path = config[catchment]["paths"]["direction_edge_weights_path"]
directional_edge_weights = pd.read_csv(directional_edge_path).assign(
    node_id=lambda frame: range(len(frame))
)
directional_edge_weights

In [None]:
# Inspect the mesh cell polygons returned by build_mesh
mesh_cells_gdf_polygons

### Split Stations into Train/Val/Test Sets ###

In [None]:
# # Load tensors from file if needed
# edge_index_tensor, edge_attr_tensor = load_graph_tensors(
#     graph_output_dir=config[catchment]["paths"]["graph_data_output_dir"],
#     catchment=catchment
# )

# Load main_df_full (final_df.csv in the configured final_df_path directory).
# os.path.join works whether or not final_df_path ends with a separator, and
# matches how processed_df.parquet is located in the same directory later on.
load_path = os.path.join(config[catchment]["paths"]["final_df_path"], 'final_df.csv')
main_df_full = pd.read_csv(load_path)
main_df_full

In [None]:
# --- 6a. Define Spatial Split for Observed Stations ---

# Shortlists and split percentages come from the catchment's "data_partioning"
# config section (key spelling matches the config file).
partition_cfg = config[catchment]["model"]["data_partioning"]

train_station_ids, val_station_ids, test_station_ids = define_station_id_splits(
    main_df_full=main_df_full,
    catchment=catchment,
    test_station_shortlist=partition_cfg["test_station_shortlist"],
    val_station_shortlist=partition_cfg["val_station_shortlist"],
    random_seed=config["global"]["pipeline_settings"]["random_seed"],
    output_dir=config[catchment]["paths"]["aux_dir"],
    perc_train=partition_cfg["percentage_train"],
    perc_val=partition_cfg["percentage_val"],
    perc_test=partition_cfg["percentage_test"],
)

logger.info(f"Pipeline Step 'define station splits' complete for {catchment} catchment.")

### BUILD EDGE WEIGHTS ###

In [None]:
# Load in directional edge weights (mean elevation is supplied below as a geojson
# path); this reload is not req. in the main pipeline.
# NOTE(review): this fresh read replaces the frame from the earlier cell, so the
# positional "node_id" column added there is not passed to define_graph_adjacency
# — confirm that is intended.
directional_edge_path = config[catchment]["paths"]["direction_edge_weights_path"]
directional_edge_weights = pd.read_csv(directional_edge_path)

# All observed station nodes (train ∪ val ∪ test), sorted, as an int array
station_node_ids = np.array(
    sorted(set().union(train_station_ids, val_station_ids, test_station_ids)),
    dtype=int,
)

# NOTE(review): epsilon_path receives config["global"]["graph"]["epsilon"] — the
# same key passed as `epsilon` to the run manifest — confirm whether this is a value or a path.
graph_cfg = config["global"]["graph"]
edge_attr_tensor, edge_index_tensor = define_graph_adjacency(
    directional_edge_weights=directional_edge_weights,
    elevation_geojson_path=config[catchment]['paths']['elevation_geojson_path'],
    graph_output_dir=config[catchment]["paths"]["graph_data_output_dir"],
    mesh_cells_gdf_polygons=mesh_cells_gdf_polygons,
    epsilon_path=graph_cfg["epsilon"],
    station_node_ids=station_node_ids,
    station_radius_m=graph_cfg["graph_construction"]["station_radius_m"],
    catchment=catchment,
)

logger.info(f"Pipeline step 'Define Graph Adjacency' complete for {catchment} catchment.\n")

In [None]:
# Inspect the edge attribute tensor returned by define_graph_adjacency
edge_attr_tensor

In [None]:
# Inspect the edge index tensor returned by define_graph_adjacency
edge_index_tensor

### Preprocess Shared Features Prior to Split

In [None]:
# --- 6b. Preprocess (Standardise, one hot encode, round to 4dp) all shared features (not GWL) ---

# Returns the processed frame, the shared scaler/encoder and the GWL feature list
# (used later when building the PyG objects).
catchment_paths = config[catchment]["paths"]
processed_df, shared_scaler, shared_encoder, gwl_feats = preprocess_shared_features(
    main_df_full=main_df_full,
    catchment=catchment,
    random_seed=config["global"]["pipeline_settings"]["random_seed"],
    violin_plt_path=config[catchment]["visualisations"]["violin_plt_path"],
    scaler_dir=catchment_paths["scalers_dir"],
    aux_dir=catchment_paths["aux_dir"],
)

logger.info(f"Pipeline Step 'Preprocess Final Shared Features' complete for {catchment} catchment.")

### Split processed_df into train/val/test subsets

In [None]:
# --- 6c. Preprocess all GWL features using training data only ---

# parquet_path points at processed_df.parquet inside the final_df_path directory
catchment_paths = config[catchment]["paths"]
processed_parquet_path = os.path.join(catchment_paths["final_df_path"], 'processed_df.parquet')

processed_df, gwl_scaler, gwl_encoder = preprocess_gwl_features(
    processed_df=processed_df,
    catchment=catchment,
    train_station_ids=train_station_ids,
    val_station_ids=val_station_ids,
    test_station_ids=test_station_ids,
    sentinel_value=config["global"]["graph"]["sentinel_value"],
    scaler_dir=catchment_paths["scalers_dir"],
    parquet_path=processed_parquet_path,
)

logger.info(f"Pipeline Step 'Preprocess Final GWL Features' complete for {catchment} catchment.")

In [None]:
# Sanity check: column names after shared (6b) and GWL (6c) preprocessing
print("Columns in processed_df after preprocessing:", processed_df.columns)

In [None]:
# # Shorten df to test range to reduce computation requirements

# test_start_date_str = config["global"]["data_ingestion"]["test_start_date"]
# test_end_date_str = config["global"]["data_ingestion"]["test_end_date"]

# # Convert 'timestep' column to datetime objects
# processed_df['timestep'] = pd.to_datetime(processed_df['timestep'])

# processed_df_test = processed_df.loc[
#     (processed_df['timestep'] >= test_start_date_str) &
#     (processed_df['timestep'] <= test_end_date_str)
# ].copy()

# print(f"Original processed_df shape (full data): {processed_df.shape}")
# print(f"Test processed_df_test shape (sliced from {test_start_date_str} to {test_end_date_str}): {processed_df_test.shape}")
# processed_df_test = processed_df.drop(columns=['streamflow_total_m3', 'HOST_soil_class_freely_draining_soils', 'HOST_soil_class_high_runoff_(impermeable)', 
#                                                'HOST_soil_class_impeded_saturated_subsurface_flow', 'HOST_soil_class_peat_soils', 'aquifer_productivity_High',
#                                                'aquifer_productivity_Low', 'aquifer_productivity_Mixed', 'aquifer_productivity_Moderate',
#                                                'aquifer_productivity_nan']).copy()
# processed_df_test = processed_df.copy()

# Ablation: build model inputs WITHOUT the autoregressive GWL lag features
# (gwl_lag1 .. gwl_lag7) to understand how much performance depends on them.
processed_df_test = processed_df.drop(columns=[f"gwl_lag{lag}" for lag in range(1, 8)])

In [None]:
# Display the model-input frame: processed_df with the gwl_lag1-7 columns dropped
processed_df_test

In [None]:
# Inspect positional columns 54-60, the span skipped by column_indices when the
# static station features are selected further down (check the layout still matches)
print(processed_df_test.columns[54:61])

In [None]:
# Inspect positional column 62, the last column kept by column_indices further down
print(processed_df_test.columns[62:63])

In [None]:
# Full column list (positions matter for the positional column_indices selection below)
processed_df_test.columns

In [None]:
# processed_df_test.describe()

Assign validation stations by environmental (Mahalanobis) similarity, excluding candidates within a geographic buffer of each test station

In [None]:
# Split to final preprocessed static columns (selected by position).
# NOTE(review): these positional ranges depend on the current column order of
# processed_df_test — re-check them (see the column inspections above) whenever
# features are added or removed.
column_indices = list(range(0, 9)) + list(range(25, 54)) + list(range(61, 63))
split_df = processed_df_test.iloc[:, column_indices]

# Aggregate to node_id: first non-null value of each selected column per node
aggregated_df = split_df.groupby('node_id').first().reset_index()

# Keep only the observed station nodes and drop the non-static timestep column.
# NOTE(review): hard-coded node ids — presumably the Eden station nodes; confirm.
station_nodes = [430, 902, 1254, 1326, 1335, 1420, 1556, 1648, 1772, 1858, 1983, 2388, 2487, 2594]
station_dfs = aggregated_df[aggregated_df['node_id'].isin(station_nodes)]
station_dfs = station_dfs.drop(columns='timestep')

# Load reference df: station name and coordinates per snapped node
station_metadata = pd.read_csv("data/02_processed/eden/gwl_station_data/snapped_station_node_mapping.csv")

# Merge in required data (left join keeps every station row)
merged_station_df = station_dfs.merge(
    station_metadata[['node_id', 'station_name', 'easting', 'northing', 'geometry']],
    on='node_id',
    how='left'
)

# Clean station_name (lower-case, spaces -> underscores) and drop unneeded stations.
# A boolean filter replaces drop(..., inplace=True): same surviving rows and index
# labels, without mutating the frame in place.
STATIONS_TO_DROP = ['cliburn_town_bridge_1']
merged_station_df['station_name'] = (
    merged_station_df['station_name'].astype(str).str.lower().str.replace(" ", "_")
)
keep_rows = ~merged_station_df['station_name'].isin(STATIONS_TO_DROP)
merged_station_df = merged_station_df.loc[keep_rows].copy()

merged_station_df

In [None]:
import numpy as np
import pandas as pd
from scipy.stats import spearmanr

# Working copy of the per-station static table
DF = merged_station_df.copy()

# Validation-assignment settings
BUFFER_KM = 5.0   # km; candidates closer than this to the test station are excluded
USE_BUFFER = True
RIDGE = 1e-3      # added to the covariance diagonal before whitening

# Identifier/location columns that are not environmental features
EXCLUDE = {'node_id', 'station_name', 'easting', 'northing', 'geometry'}

# Get pairwise geographic distances
def pairwise_geo_km(xy_m):
    """Return the (n, n) matrix of straight-line distances in km.

    xy_m : (n, >=2) array of planar coordinates in metres; only the first two
        columns (x, y) are used.
    """
    xs = xy_m[:, 0]
    ys = xy_m[:, 1]
    dx = xs[:, None] - xs[None, :]
    dy = ys[:, None] - ys[None, :]
    return np.hypot(dx, dy) / 1000.0

# Calc per-fold whitening (Mahalanobis)
def whiten_per_fold(X, train_idx, ridge=1e-3):
    """
    Whiten every row of X using mean/covariance estimated on X[train_idx] only.

    Parameters
    ----------
    X : (n, p) feature array.
    train_idx : index or boolean mask selecting the training rows.
    ridge : amount added to the covariance diagonal before inversion.

    Returns
    -------
    (n, p) array (X - mu) @ Sigma^{-1/2}, with mu and Sigma (+ ridge) taken from
    the training rows, so Euclidean distances are Mahalanobis distances.
    """
    X_train = X[train_idx]
    mu = X_train.mean(axis=0)

    # training covariance, ridge-regularised on the diagonal
    cov = np.cov(X_train - mu, rowvar=False)
    n_feat = cov.shape[0]
    cov.flat[::n_feat + 1] += ridge

    # Sigma^{-1/2} = U diag(S^{-1/2}) U^T from the SVD of the symmetric covariance
    U, S, _ = np.linalg.svd(cov, full_matrices=False)
    inv_sqrt = U @ np.diag(1.0 / np.sqrt(S)) @ U.T

    centred = X - mu
    return centred @ inv_sqrt

# Arrays for the analysis: node ids, planar coordinates, environmental features
ids = DF['node_id'].to_numpy()
XY = DF[['easting', 'northing']].to_numpy(float)  # metres (BNG)
feat_cols = [col for col in DF.columns if col not in EXCLUDE]
X = DF[feat_cols].to_numpy(float)

# Station-to-station geographic distances (km)
D_geo = pairwise_geo_km(XY)

# proximity vs environmental similarity diagnostic
def proximity_env_report(X, D_geo, use_whitening=True, ridge=1e-3, perms=20000, seed=42):
    """
    Report how strongly environmental similarity tracks geographic proximity.

    Prints the Spearman rank correlation and a one-sided Mantel permutation test
    between the environmental and geographic distance matrices (each unordered
    pair counted once), and returns the environmental distance matrix.

    Parameters
    ----------
    X : (n, p) station feature matrix.
    D_geo : (n, n) geographic distance matrix, rows/cols in the same order as X.
    use_whitening : if True, whiten X globally (all rows) first, so Euclidean
        distances are Mahalanobis distances. Diagnostic only, not used for selection.
    ridge : value added to the covariance diagonal before whitening.
    perms : number of Mantel permutations (default keeps the previous 20000).
    seed : seed for the permutation RNG (default keeps the previous 42).

    Returns
    -------
    D_env : (n, n) Euclidean distances in (optionally whitened) feature space.
    """
    n = X.shape[0]

    # global whitening for diagnostic only (not selection)
    if use_whitening:
        Xc = X - X.mean(0)
        C = np.cov(Xc, rowvar=False)
        C.flat[::C.shape[0] + 1] += ridge
        U, S, _ = np.linalg.svd(C, full_matrices=False)
        W = U @ np.diag(1.0 / np.sqrt(S)) @ U.T
        Xw = Xc @ W
    else:
        Xw = X

    diff = Xw[:, None, :] - Xw[None, :, :]
    D_env = np.sqrt((diff ** 2).sum(-1))

    # strict upper triangle: each unordered pair once, diagonal excluded
    mask = np.triu(np.ones_like(D_env, dtype=bool), 1)
    rho, p_spear = spearmanr(D_env[mask], D_geo[mask])

    def mantel(A, B, perms=20000, seed=42):
        """One-sided Mantel test: Pearson r(A, B) vs. joint row/col permutations of B."""
        rng = np.random.default_rng(seed)
        a = A[mask]
        obs = np.corrcoef(a, B[mask])[0, 1]
        cnt = 0
        for _ in range(perms):
            p = rng.permutation(n)
            bb = B[np.ix_(p, p)][mask]
            if np.corrcoef(a, bb)[0, 1] >= obs:
                cnt += 1
        # +1 in numerator and denominator counts the observed arrangement itself
        p_val = (cnt + 1) / (perms + 1)
        return float(obs), float(p_val)

    r_mantel, p_mantel = mantel(D_env, D_geo, perms=perms, seed=seed)

    print(f"Spearman ρ(env, geo) = {rho:.3f}  (p={p_spear:.3f})")
    print(f"Mantel r = {r_mantel:.3f}  (p={p_mantel:.3f})")
    return D_env

# Run diagnostic report (global whitening; prints Spearman and Mantel statistics).
# The returned distance matrix is kept as D_env_diag for the MDS embedding below.
D_env_diag = proximity_env_report(X=X, D_geo=D_geo, use_whitening=True)

# Get assignments: two val stations by environmental similarity
def assign_two_validations_env_only(ids, X, XY, buffer_km=5.0, use_buffer=True, ridge=1e-3,
                                    min_relaxed_km=3.0, relax_factor=0.6):
    """
    For each station i (treated as the test station), choose two validation
    stations by environmental similarity, excluding geographically close ones.

    Per station i:
      - whiten features with mean/covariance from all stations except i
        (Mahalanobis distances in static feature space);
      - keep candidates at least ``buffer_km`` from i (when ``use_buffer``);
      - if fewer than two remain, relax once to
        ``min(buffer_km, max(min_relaxed_km, relax_factor * buffer_km))`` — the
        outer ``min`` ensures the "relaxed" radius is never stricter than the
        requested one (previously a buffer < 3 km was tightened to 3 km);
      - if no candidate remains at all, warn and fall back to every other station
        (previously this raised IndexError);
      - rank the remaining candidates by environmental distance only.

    Parameters
    ----------
    ids : (n,) node ids aligned with the rows of X and XY.
    X : (n, p) static feature matrix.
    XY : (n, 2) planar coordinates in metres.
    buffer_km, use_buffer : geographic exclusion radius around each test station.
    ridge : covariance ridge passed to whiten_per_fold.
    min_relaxed_km, relax_factor : control the single relaxation step; the
        defaults reproduce the previous hard-coded 3 km floor and 60% factor.

    Returns
    -------
    DataFrame with one row per station: test_node, val1_node, val2_node,
    env_d_val1, env_d_val2, geo_km_val1, geo_km_val2 (val2 fields are NaN when
    only one candidate exists).

    Raises
    ------
    ValueError : if fewer than two stations are given.
    """
    import warnings

    n = len(ids)
    if n < 2:
        raise ValueError("need at least two stations to assign validation stations")

    D_geo = pairwise_geo_km(XY)
    rows = []

    for i in range(n):
        others = np.arange(n) != i
        Xw = whiten_per_fold(X, others, ridge=ridge)

        # environmental distances from test i to everyone
        d_env = np.linalg.norm(Xw - Xw[i], axis=1)
        d_env[i] = np.inf

        # candidate mask
        if use_buffer and buffer_km > 0:
            cand = others & (D_geo[i] >= buffer_km)
            if cand.sum() < 2:
                # relax once, but never to a radius larger than the one requested
                relaxed_km = min(buffer_km, max(min_relaxed_km, relax_factor * buffer_km))
                cand = others & (D_geo[i] >= relaxed_km)
            if not cand.any():
                warnings.warn(
                    f"No candidate for test node {ids[i]} outside the relaxed buffer; "
                    "ignoring the buffer for this station."
                )
                cand = others
        else:
            cand = others

        idxs = np.where(cand)[0]
        order = idxs[np.argsort(d_env[idxs])]

        v1 = order[0]
        v2 = order[1] if len(order) > 1 else None

        rows.append({
            "test_node": ids[i],
            "val1_node": ids[v1],
            "val2_node": (ids[v2] if v2 is not None else np.nan),
            "env_d_val1": float(d_env[v1]),
            "env_d_val2": (float(d_env[v2]) if v2 is not None else np.nan),
            "geo_km_val1": float(D_geo[i, v1]),
            "geo_km_val2": (float(D_geo[i, v2]) if v2 is not None else np.nan)
        })
    return pd.DataFrame(rows)

# Two environmentally closest validation stations per test station (geo-buffered)
assignments = assign_two_validations_env_only(
    ids=ids,
    X=X,
    XY=XY,
    buffer_km=BUFFER_KM,
    use_buffer=USE_BUFFER,
    ridge=RIDGE,
)

assignments.sort_values("test_node").reset_index(drop=True)

In [None]:
from sklearn.manifold import MDS
import numpy as np
import matplotlib.pyplot as plt
from scipy.spatial.distance import cdist

# 2-D MDS embedding of the diagnostic (globally whitened) environmental distances
mds_model = MDS(
    n_components=2,
    dissimilarity="precomputed",
    max_iter=3000,
    n_init=8,
    normalized_stress="auto",
    random_state=42,
)
mds = mds_model.fit_transform(D_env_diag)  # one 2-D point per row of D_env_diag
print("MDS stress:", mds_model.stress_)

In [None]:
import numpy as np
import pandas as pd
from sklearn.covariance import OAS, LedoitWolf
from sklearn.decomposition import PCA
from scipy.spatial.distance import cdist

# ---------- distances in environmental space ----------

def whiten_global(X):
    """
    Whiten all rows of X with a Ledoit-Wolf shrinkage covariance.

    Centres X on its column means, estimates the covariance with LedoitWolf and
    returns the centred data multiplied by Sigma^{-1/2}, built from the SVD of
    that covariance.
    """
    centred = X - X.mean(0)
    shrunk_cov = LedoitWolf().fit(centred).covariance_
    U, S, _ = np.linalg.svd(shrunk_cov, full_matrices=False)
    scale = 1 / np.sqrt(S)
    inv_sqrt = U @ np.diag(scale) @ U.T
    return centred @ inv_sqrt

def whiten_global_shrink(X, method="oas", center=True, eps=1e-12):
    """
    Global whitening so Euclidean ≡ Mahalanobis under a shrinkage covariance.

    Parameters
    ----------
    X : (n, p) feature matrix.
    method : "oas" (OAS estimator) or "lw" (Ledoit-Wolf estimator).
    center : subtract the column means before whitening.
    eps : floor applied to the covariance eigenvalues before Σ^{-1/2}.

    Returns
    -------
    X_w : (n, p) whitened matrix, equal to (X - mu) @ W.
    mu : (p,) vector that was subtracted — the column means of X when ``center``
        is True, zeros otherwise. (Previously this returned the mean of the
        already-centred data, i.e. ~0, which cannot whiten new rows.)
    W : (p, p) whitening matrix Σ^{-1/2}.

    Raises
    ------
    ValueError : if ``method`` is not "oas" or "lw".
    """
    if method == "oas":
        est = OAS()
    elif method == "lw":
        est = LedoitWolf()
    else:
        raise ValueError("method must be 'oas' or 'lw'")

    mu = X.mean(0) if center else np.zeros(X.shape[1])
    Xc = X - mu
    C = est.fit(Xc).covariance_
    # symmetric eigendecomposition for Σ^{-1/2}; clip guards against ~0 eigenvalues
    eigvals, eigvecs = np.linalg.eigh(C)
    eigvals = np.clip(eigvals, eps, None)
    W = eigvecs @ np.diag(1.0 / np.sqrt(eigvals)) @ eigvecs.T
    Xw = Xc @ W
    return Xw, mu, W

def reduce_dimensionality(Xw, var_keep=0.95, min_dim=3, max_dim=12):
    """
    PCA on whitened features; keeps enough components to explain var_keep.

    Parameters
    ----------
    Xw : (n, p) whitened feature matrix.
    var_keep : target cumulative explained-variance ratio.
    min_dim, max_dim : bounds on the number of components kept.

    Returns
    -------
    Z : (n, r) component scores.
    r : number of components kept. It is clamped to [min_dim, max_dim] but never
        exceeds the components PCA could fit (min(n_samples, n_features, max_dim)),
        so r always equals Z.shape[1]. Previously r could exceed Z.shape[1], and PCA
        raised when asked for more components than samples.
    pca : the fitted PCA object.
    """
    n_samples, n_features = Xw.shape
    n_fit = min(max_dim, n_features, n_samples)
    pca = PCA(n_components=n_fit, svd_solver="full")
    pca.fit(Xw)
    cum = np.cumsum(pca.explained_variance_ratio_)
    # first index where cumulative variance reaches var_keep, as a component count
    r = int(np.searchsorted(cum, var_keep)) + 1
    r = max(min_dim, min(r, max_dim))
    r = min(r, len(cum))  # cannot keep more components than were fitted
    Z = pca.transform(Xw)[:, :r]
    return Z, r, pca

# Build environmental distance matrix:
#   Ledoit-Wolf whitening -> PCA keeping 95% of variance -> Euclidean distances.
# Xw, mu_w, W = whiten_global_shrink(X, method="oas")       # stable Σ^{-1/2} (OAS alternative)
Xw = whiten_global(X)
Z, r, pca = reduce_dimensionality(Xw, var_keep=0.95)
D_env = cdist(Z, Z, metric="euclidean")  # ≡ Mahalanobis on the reduced space

# ---------- optional geographic matrix ----------

# NOTE(review): identical redefinition of pairwise_geo_km from the earlier
# validation-assignment cell; this one silently shadows it.
def pairwise_geo_km(xy_m):
    """Pairwise straight-line distances in km between (x, y) points given in metres."""
    dx = np.subtract.outer(xy_m[:, 0], xy_m[:, 0])
    dy = np.subtract.outer(xy_m[:, 1], xy_m[:, 1])
    return np.hypot(dx, dy) / 1000.0

D_geo = pairwise_geo_km(XY)  # recomputed from the same station coordinates (XY) as the earlier cell

# ---------- k-center with buffer and graceful relaxation ----------

def kcenter_greedy(D_env, k, D_geo=None, min_geo_km=0.0, relax_to=0.0):
    """
    Farthest-first traversal (greedy k-center) in environmental space.

    The first center is the point with the largest mean environmental distance
    to all points. Each further center is the not-yet-chosen point farthest from
    its nearest chosen center. When ``D_geo`` is given and ``min_geo_km > 0``,
    candidates must be at least the buffer distance from ALL chosen centers; if
    none qualifies, the buffer shrinks by 20% per step down to ``relax_to``. If
    even the fully relaxed buffer excludes every point, the buffer is ignored for
    that pick. (Previously argmax over an all ``-inf`` array silently re-selected
    index 0, which could duplicate an existing center.)

    Parameters
    ----------
    D_env : (n, n) environmental distance matrix.
    k : number of centers, 1 <= k <= n.
    D_geo : optional (n, n) geographic distance matrix (km).
    min_geo_km : initial geographic buffer (km); 0 disables the buffer.
    relax_to : lower bound (km) for the relaxed buffer.

    Returns
    -------
    centers : list of k distinct int indices.
    cover : (n,) distance from each point to its nearest center.
    R_max, R_95 : max and 95th percentile of ``cover``.

    Raises
    ------
    ValueError : if k is not between 1 and n.
    """
    n = D_env.shape[0]
    if not 1 <= k <= n:
        raise ValueError(f"k must be between 1 and {n} (number of points), got {k}")

    # first center: largest mean distance to all points
    centers = [int(np.argmax(D_env.mean(1)))]

    while len(centers) < k:
        d_to_S = D_env[:, centers].min(1)
        unchosen = np.ones(n, dtype=bool)
        unchosen[centers] = False
        ok = unchosen

        if D_geo is not None and min_geo_km > 0:
            buf = float(min_geo_km)
            while True:
                ok = unchosen & (D_geo[:, centers] >= buf).all(axis=1)
                if ok.any() or buf <= relax_to:
                    break
                buf = max(relax_to, 0.8 * buf)  # relax by 20%
            if not ok.any():
                # buffer unattainable even at relax_to: pick among all unchosen points
                ok = unchosen

        cand = np.argmax(np.where(ok, d_to_S, -np.inf))
        centers.append(int(cand))

    cover = D_env[:, centers].min(1)
    R_max = float(cover.max())
    R_95 = float(np.quantile(cover, 0.95))
    return centers, cover, R_max, R_95

# Pick k environmental prototypes (5 km geo buffer, relaxable to 2 km) and report coverage
k = 3
centers, cover, R_max, R_95 = kcenter_greedy(
    D_env, k, D_geo=D_geo, min_geo_km=5.0, relax_to=2.0
)
labels = D_env[:, centers].argmin(axis=1)  # station -> position of its nearest prototype
medoid_idx = centers                       # prototypes are the medoids
print("prototypes:", centers, f" | R_max={R_max:.3f}, R_95={R_95:.3f}")

# ---------- utility to scan k for an elbow ----------
def scan_k(D_env, ks, D_geo=None, min_geo_km=5.0, relax_to=2.0):
    """
    Run kcenter_greedy for every k in ``ks`` and tabulate the coverage radii.

    Returns a DataFrame with columns k, R_max, R_95 (one row per k, in the order
    given) for spotting an elbow as k grows.
    """
    records = [
        (kk, *kcenter_greedy(D_env, kk, D_geo, min_geo_km, relax_to)[2:])
        for kk in ks
    ]
    return pd.DataFrame(records, columns=["k", "R_max", "R_95"])


In [None]:
from sklearn.manifold import MDS
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np

# Same environmental distance matrix the k-center prototypes were chosen on
D_for_plot = D_env

mds_model = MDS(
    n_components=2, dissimilarity="precomputed",
    random_state=42, n_init=8, max_iter=3000, normalized_stress="auto"
)
mds = mds_model.fit_transform(D_for_plot)

# Colour value: distance from each station to its nearest prototype
min_to_center = D_for_plot[:, centers].min(1)

fig_path = "results/figures/eden/other/mds_prototypes.png"

with mpl.rc_context({
    "figure.dpi": 120,
    "axes.spines.top": False, "axes.spines.right": False,
    "axes.grid": True, "grid.linestyle": ":", "grid.alpha": 0.6, "grid.linewidth": 0.6,
    "axes.titlesize": 12, "axes.labelsize": 11, "xtick.labelsize": 9, "ytick.labelsize": 9,
}):
    fig, ax = plt.subplots(figsize=(6.6, 4.6))

    # Sequential colormap capped at the 97.5th percentile so one outlier
    # does not compress the colour scale for everyone else
    vmax = np.percentile(min_to_center, 97.5)
    norm = mpl.colors.Normalize(vmin=0, vmax=vmax)
    sc = ax.scatter(
        mds[:, 0], mds[:, 1],
        c=min_to_center, norm=norm, cmap='viridis',
        s=50, linewidths=0.4, edgecolors="white", alpha=0.95, zorder=2
    )

    # Dashed rings mark the prototypes
    for c in centers:
        ax.scatter(
            mds[c, 0], mds[c, 1],
            s=220, facecolors="none", edgecolors="black",
            linewidths=0.8, linestyle="--", zorder=4
        )

    # Label points with node_id (rows of mds follow the row order of DF)
    for i, node in enumerate(DF["node_id"].astype(int)):
        ax.annotate(
            str(node), (mds[i, 0], mds[i, 1]), xytext=(6, 6),
            textcoords="offset points", fontsize=8, color="#222",
            bbox=dict(boxstyle="round,pad=0.15", fc="white", ec="none", alpha=0.7),
            zorder=5
        )

    cbar = fig.colorbar(sc, ax=ax, shrink=0.86, pad=0.05)
    cbar.set_label("Mahalanobis distance to nearest prototype")

    ax.set_title("Environmental space (MDS of Mahalanobis distances)")
    ax.set_xlabel("MDS-1"); ax.set_ylabel("MDS-2")
    ax.set_aspect("equal")
    fig.tight_layout()

    # savefig does not create missing directories
    os.makedirs(os.path.dirname(fig_path), exist_ok=True)
    fig.savefig(fig_path, bbox_inches="tight", dpi=300)


### Create Train/Val/Test PyG Objects for Model Input

In [None]:
# --- 6d. Create PyG data objects using partioned station IDs (from 6a) ---

# Run time approx. 13.5 mins to build 4018 timesteps of objects (0.201s per timestep)
scalers_dir = config[catchment]["paths"]["scalers_dir"]

# One-hot GWL column names, presumably saved to scalers_dir during preprocessing.
# joblib.load unpickles: only load files produced by this pipeline.
gwl_ohe_cols = joblib.load(os.path.join(scalers_dir, "gwl_ohe_cols.pkl"))

# Feature frame is processed_df_test (GWL lag columns removed), not processed_df
all_timesteps_list = build_pyg_object(
    processed_df=processed_df_test,
    sentinel_value=config["global"]["graph"]["sentinel_value"],
    train_station_ids=train_station_ids,
    val_station_ids=val_station_ids,
    test_station_ids=test_station_ids,
    gwl_feats=gwl_feats,
    gwl_ohe_cols=gwl_ohe_cols,
    edge_index_tensor=edge_index_tensor,
    edge_attr_tensor=edge_attr_tensor,
    scalers_dir=scalers_dir,
    catchment=catchment,
)

# Save all_timesteps_list to file
torch.save(all_timesteps_list, config[catchment]["paths"]["pyg_object_path"])
logger.info(f"Pipeline Step 'Build PyG Data Objects' complete for {catchment} catchment.")

In [None]:
# Get current git commit hash to help with reproducibility if run performance is lost.
# If git is missing or this is not a git work tree, record "unknown" and warn
# instead of raising after the (long) PyG build has already finished.
try:
    git_commit = subprocess.check_output(
        ["git", "rev-parse", "HEAD"], stderr=subprocess.DEVNULL
    ).decode().strip()
except (subprocess.CalledProcessError, OSError) as exc:
    logger.warning(f"Could not determine git commit ({exc}); recording 'unknown' in the manifest.")
    git_commit = "unknown"

# Timestamped run directory for this manifest
run_dir = os.path.join("runs", datetime.datetime.now().strftime("%Y%m%d_%H%M%S"))
graph_dir = config[catchment]["paths"]["graph_data_output_dir"]
save_run_manifest(
    run_dir=run_dir,
    config=config,
    git_commit=git_commit,
    all_timesteps_list=all_timesteps_list,
    temporal_features=config[catchment]["model"]["architecture"]["temporal_features"],
    scalers_dir=config[catchment]["paths"]["scalers_dir"],
    train_station_ids=train_station_ids,
    val_station_ids=val_station_ids,
    test_station_ids=test_station_ids,
    edge_index_path=os.path.join(graph_dir, "edge_index_tensor.pt"),
    edge_attr_path=os.path.join(graph_dir, "edge_attr_tensor.pt"),
    sentinel_value=config["global"]["graph"]["sentinel_value"],
    epsilon=config["global"]["graph"]["epsilon"]
)