#### Physics GNN + RL System
- Refining the architecture
- Simple Physics element
- CAUTION: Still in development — some code blocks may not make much sense yet

In [1]:
import os
import torch
import random
import numpy as np
import pandas as pd
import torch.nn as nn
import networkx as nx
import torch.nn.functional as F
from torch_geometric.data import Data
from torch_geometric.nn import GCNConv
from torch_geometric.data import Batch
from torch_geometric.data import Dataset
from torch_geometric.loader import DataLoader
from torch_geometric.utils import from_networkx
from torch.optim.lr_scheduler import ReduceLROnPlateau

# Synchronous CUDA kernel launches: slower, but errors are reported at the call that caused them.
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f'Device: {device}')

Device: cuda


In [None]:
# NOTE Visuals...NOT IMPORTANT
# import matplotlib.pyplot as plt
# # Draw edges with thickness based on flow capacity
# DG = nx.DiGraph()
# # Add edges with attributes
# for index, row in df_edges.iterrows():
#     DG.add_edge(row['source_node'], row['target_node'], flow_capacity=row['flow_capacity'])

# largest_component = max(nx.connected_components(G), key=len)
# H = G.subgraph(largest_component).copy()

# pos = nx.kamada_kawai_layout(H)

# # Edge widths based on the flow capacity attribute
# edge_widths = [H[u][v]['flow_capacity']*1.5 for u, v in H.edges()]

# plt.figure(figsize=(12, 12))
# nx.draw_networkx_nodes(H, pos, node_size=700, node_color='skyblue', alpha=0.6)
# edges = nx.draw_networkx_edges(H, pos, edge_color='blue', width=edge_widths, alpha=0.7, arrows=True)
# nx.draw_networkx_labels(H, pos, font_size=14, font_color='darkblue')

# plt.title('Synthetic Network with Flow Capacities (Thickness)')
# plt.axis('off')  # Turn off the axis
# plt.show()


##### Simple GNN Architecture
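$$
h_v^{(k+1)} = \sigma \left( \sum_{u \in \mathcal{N}(v) \cup \{v\}} \frac{1}{\sqrt{\hat{d}_u \hat{d}_v}} W^{(k)} h_u^{(k)} + b^{(k)} \right)
$$

where $ h_v^{(k)} $ represents the node features at layer $ k $, $ W^{(k)} $ and $ b^{(k)} $ are the layer's weight matrix and bias, $ \mathcal{N}(v) $ is the set of nodes with an edge into $ v $, and $ \hat{d}_v $ is the in-degree of $ v $ plus one for the self-loop that `GCNConv` adds by default. In the model below, $ \sigma $ is ReLU (followed by dropout) for the first two layers; the output layer has no activation.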

In [10]:
# NOTE GNN Simple Architecture
class GCN(nn.Module):
    """Three-layer graph convolutional network over a PyG ``Data`` batch.

    The first two ``GCNConv`` layers are each followed by ReLU and dropout
    (p=0.3, active only in training mode); the last layer's output is
    returned without an activation.

    Args:
        in_channels: number of input features per node.
        hidden_channels: width of the two hidden layers.
        out_channels: number of output features per node.
    """

    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        # Attribute names are part of the state_dict keys used by saved checkpoints.
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)
        self.conv3 = GCNConv(hidden_channels, out_channels)

    def forward(self, data):
        """Return per-node outputs for ``data.x`` propagated along ``data.edge_index``."""
        h, edges = data.x, data.edge_index
        for hidden_layer in (self.conv1, self.conv2):
            h = F.dropout(F.relu(hidden_layer(h, edges)), p=0.3, training=self.training)
        return self.conv3(h, edges)

##### Reinforcement Learning (Policy Network)
The policy network in the RL model maps the state to an action:

$$
\text{Action} = \text{PolicyNetwork}(State)
$$

Optimization Methods
- Switching Optimizer
- Hierarchical RL
- Aggregator Method (could try, but theoretical estimates suggest it will not perform well)

Adding Physics
- Conservation of Mass
- Darcy-Weisbach (Flow Rate) | https://en.wikipedia.org/wiki/Darcy%E2%80%93Weisbach_equation
- Junctions and Flow Splitting
- Pressure and Velocity

Exploration Strategies
- Entropy Regularization
- Epsilon-Greedy Policies

Reward Shaping
- Overflow Penalty: Penalize overflows to prioritize their prevention
- Stability Reward: Encourage maintaining water levels close to an ideal value
- Energy Efficiency: Adding Penalties for excessive valve adjustments

In [36]:
class PolicyNetwork(nn.Module):
    """Two-hidden-layer MLP (width 128) mapping node embeddings to valve actions.

    Outputs pass through a sigmoid, so every action lies in (0, 1). The
    trailing dimension is squeezed, so ``output_dim=1`` yields one scalar
    per input row.
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        hidden = 128
        # Keep fc1/fc2/fc3 names: they are the state_dict keys of saved models.
        self.fc1 = nn.Linear(input_dim, hidden)
        self.fc2 = nn.Linear(hidden, hidden)
        self.fc3 = nn.Linear(hidden, output_dim)

    def forward(self, x):
        for layer in (self.fc1, self.fc2):
            x = F.relu(layer(x))
        logits = self.fc3(x)
        return logits.sigmoid().squeeze(-1)  # Constrain actions between 0 and 1

# Local control: Adjust valves (non-junction nodes)
def local_control(node_representations, policy_network, junction_nodes, last_adjustments, current_time, adjustment_interval=15):
    """Compute valve actions for non-junction nodes whose cooldown has expired.

    Args:
        node_representations: [num_nodes, dim] per-node embeddings.
        policy_network: callable mapping an [m, dim] batch to [m] actions.
        junction_nodes: indices of junction nodes (list or index tensor);
            these always receive action 0.
        last_adjustments: per-node time of the last valve change (tensor,
            ndarray or list with one entry per node; -inf means "never").
        current_time: current time step.
        adjustment_interval: minimum elapsed time before a node may be
            adjusted again.

    Returns:
        A [num_nodes] tensor of actions. Junction nodes and nodes still in
        cooldown get 0.
        NOTE(review): cooldown nodes get 0 rather than keeping their current
        valve position — confirm this is intended.
    """
    # Size everything from the embeddings themselves rather than a notebook
    # global, so the function works for any graph size.
    n_graph_nodes = node_representations.shape[0]
    dev = node_representations.device

    non_junction_mask = torch.ones(n_graph_nodes, dtype=torch.bool, device=dev)
    non_junction_mask[junction_nodes] = False
    non_junction_representations = node_representations[non_junction_mask]

    # Vectorised cooldown check (one device transfer instead of one per node).
    last_change = torch.as_tensor(last_adjustments, device=dev)
    adjustable = ((current_time - last_change) >= adjustment_interval)[non_junction_mask]

    actions = torch.zeros(non_junction_representations.shape[0], device=dev)
    if adjustable.any():
        actions[adjustable] = policy_network(non_junction_representations[adjustable])

    # Scatter back into a full-length action vector.
    full_actions = torch.zeros(n_graph_nodes, device=dev)
    full_actions[non_junction_mask] = actions
    return full_actions  # Shape [num_nodes]

# Global control: Adjust all valves based on global state
def global_control(node_representations, policy_network):
    """Apply one policy decision, computed from the mean node embedding, to every node.

    Args:
        node_representations: [num_nodes, dim] node embeddings.
        policy_network: callable mapping a [1, dim] batch to per-row actions.

    Returns:
        The graph-level action broadcast to all nodes, in the same layout
        local_control produces: [num_nodes] when the policy emits one value
        per row (PolicyNetwork with output_dim=1 squeezes to [1]), or
        [num_nodes, k] when it emits k values per row. The result is an
        expanded view; clone it before writing into it.
    """
    graph_representation = torch.mean(node_representations, dim=0)
    global_action = policy_network(graph_representation.unsqueeze(0))
    # Drop the batch-of-one dimension so a single-output policy yields a
    # scalar, then broadcast to [num_nodes, *action_shape]. (Expanding the
    # [1] tensor directly would give [num_nodes, 1].)
    per_graph_action = global_action.squeeze(0)
    return per_graph_action.expand(node_representations.shape[0], *per_graph_action.shape)

# Hybrid control: Switch between local and global control based on water levels
def hybrid_control(node_representations, policy_network, global_threshold, current_water_levels, junction_nodes, last_adjustments, current_time, adjustment_interval=15, verbose=True):
    """Choose global or local valve control for one time step.

    Global control is used when the highest entry of
    ``current_water_levels`` strictly exceeds ``global_threshold``; otherwise
    local control (per-node, junction- and cooldown-aware) is used.

    Args:
        node_representations: per-node embeddings passed to the controllers.
        policy_network: policy used by both controllers.
        global_threshold: water level above which global control kicks in.
        current_water_levels: tensor of current per-node water levels.
        junction_nodes, last_adjustments, current_time, adjustment_interval:
            forwarded to local_control.
        verbose: print which mode was chosen (default True, matching the
            previous behaviour). This is called every time step, so pass
            False to avoid flooding notebook output during training.

    Returns:
        The actions produced by global_control or local_control.
    """
    if bool(torch.max(current_water_levels) > global_threshold):
        if verbose:
            print("Switching to Global Control!")
        return global_control(node_representations, policy_network)
    if verbose:
        print("Using Local Control.")
    return local_control(node_representations, policy_network, junction_nodes, last_adjustments, current_time, adjustment_interval)

# Darcy-Weisbach flow calculation
def darcy_weisbach_flow(src_level, tgt_level, distance, friction_factor):
    """Flow driven by the head difference between two nodes (downhill only).

    Computes ``max(src_level - tgt_level, 0) / (friction_factor * distance)``
    elementwise, so it works on scalars or on broadcastable edge tensors
    without a device-to-host sync.

    NOTE(review): despite the name, this is a linear-resistance model. The
    Darcy–Weisbach relation makes head loss quadratic in velocity, which gives
    flow proportional to sqrt(head difference) — confirm which is intended.

    Args:
        src_level, tgt_level: water levels at the edge's source and target.
        distance: edge length.
        friction_factor: edge friction factor.

    Returns:
        Tensor of non-negative flow rates (0 where src_level <= tgt_level).
    """
    head_difference = torch.as_tensor(src_level - tgt_level)
    resistance = friction_factor * distance
    # torch.where keeps "no uphill flow" exact (0 even if resistance were 0)
    # and avoids the Python-level `if` on a tensor value.
    return torch.where(
        head_difference > 0,
        head_difference / resistance,
        torch.zeros_like(head_difference),
    )

# Update state based on actions and compute rewards
def environment_step(state, action, edge_index, flow_capacities, junction_nodes, splitting_ratios, friction_factors, last_adjustments, current_time, adjustment_interval=15, distances=None):
    """Advance the water network by one time step and score the result.

    Flow model:
      * Every edge whose source is NOT a junction carries head-driven flow
        (darcy_weisbach_flow) scaled by the source node's valve position.
        This includes edges *into* junctions, so junctions actually receive
        water.
      * Each junction distributes its current inflow rate to its downstream
        nodes exactly once per step, using ``splitting_ratios`` and the
        friction factor of that specific junction->downstream edge as a
        proportional head loss. (Previously this ran once per incident edge,
        which multiplied junction outflow by the junction's degree. It also
        used the upstream node's inflow for edges into a junction.)

    Args:
        state: [num_nodes, 4] tensor with columns
            (water_level, inflow_rate, outflow_rate, valve_position).
        action: per-node valve positions, shape [num_nodes]; a trailing
            singleton dimension ([num_nodes, 1]) is also accepted.
        edge_index: [2, num_edges] tensor of (source, target) node ids.
        flow_capacities: per-edge capacities. NOTE(review): not used yet.
        junction_nodes: iterable of junction node ids.
        splitting_ratios: {junction: {downstream_node: ratio}}; every
            (junction, downstream_node) pair must be an edge in edge_index.
        friction_factors: per-edge friction factors aligned with edge_index.
        last_adjustments: per-node time of the last valve change (ndarray,
            tensor or list). Mutated in place for nodes whose valve changed.
        current_time: time stamp written into last_adjustments.
        adjustment_interval: NOTE(review): not used here; the cooldown is
            applied by local_control.
        distances: per-edge lengths aligned with edge_index. Defaults to the
            notebook-level ``distances_tensor`` for backward compatibility.

    Returns:
        (new_state, reward): the [num_nodes, 4] next state, and the scalar
        overflow_penalty + stability_reward of that state.
    """
    if distances is None:
        distances = distances_tensor  # legacy fallback to the notebook global

    water_levels = state[:, 0].clone()
    inflow_rate = state[:, 1].clone()
    outflow_rate = state[:, 2].clone()
    valve_position = action.reshape(action.shape[0]) if action.dim() > 1 else action
    inflow_rate_new = torch.zeros_like(inflow_rate)

    # One host transfer for the whole edge list instead of .item() per edge.
    edge_pairs = [(int(s), int(t)) for s, t in edge_index.t().tolist()]
    edge_lookup = {pair: idx for idx, pair in enumerate(edge_pairs)}
    # Deduplicate while preserving the caller's order (deterministic summation).
    junction_order = list(dict.fromkeys(int(j) for j in junction_nodes))
    junction_set = set(junction_order)

    # Head-driven flow on every edge leaving a non-junction node.
    for edge_idx, (src, tgt) in enumerate(edge_pairs):
        if src in junction_set:
            continue  # junction outflow is distributed by the splitting pass below
        flow_rate = darcy_weisbach_flow(water_levels[src], water_levels[tgt], distances[edge_idx], friction_factors[edge_idx])
        flow_rate = (flow_rate * valve_position[src]).squeeze()
        assert flow_rate.dim() == 0, f"flow_rate tensor at edge {edge_idx} is not scalar: {flow_rate.shape}"
        inflow_rate_new[tgt] += flow_rate

    # Split each junction's inflow across its downstream nodes, once per junction.
    for junction in junction_order:
        if junction not in splitting_ratios:
            print(f"Junction node {junction} has no splitting ratios assigned.")
            continue
        for downstream_node, ratio in splitting_ratios[junction].items():
            edge_idx = edge_lookup.get((junction, int(downstream_node)))
            if edge_idx is None:
                raise KeyError(f"Splitting ratio for {junction}->{downstream_node} has no matching edge in edge_index.")
            split_flow = inflow_rate[junction] * float(ratio)
            head_loss = friction_factors[edge_idx] * split_flow
            flow = torch.clamp(split_flow - head_loss, min=0.0) # Non Negative!!
            assert flow.dim() == 0, f"Flow tensor at edge {edge_idx} is not scalar: {flow.shape}"
            inflow_rate_new[downstream_node] += flow
    assert inflow_rate_new.dim() == 1, f"inflow_rate_new has incorrect shape: {inflow_rate_new.shape}"

    new_water_level = torch.clamp(water_levels + inflow_rate_new - outflow_rate, 0, None)
    new_inflow_rate = torch.clamp(inflow_rate_new, 0, None)
    new_outflow_rate = outflow_rate
    new_valve_position = valve_position.clone()

    # Record the time of every valve change. NOTE(review): exact float
    # comparison, so any change at all counts as an adjustment.
    adjusted_nodes = (valve_position != state[:, 3]).nonzero(as_tuple=True)[0]
    for node in adjusted_nodes.tolist():
        last_adjustments[node] = current_time

    new_state = torch.stack((new_water_level, new_inflow_rate, new_outflow_rate, new_valve_position), dim=1)
    reward = overflow_penalty(new_state) + stability_reward(new_state)
    return new_state, reward

# Reward functions
def overflow_penalty(state, overflow_level=1.5, weight=10.0):
    """Negative reward proportional to total water above the overflow level.

    Args:
        state: [num_nodes, F] tensor whose column 0 is the water level.
        overflow_level: level above which a node counts as overflowing
            (default 1.5, the original hard-coded value).
        weight: multiplier on the summed excess (default 10, so overflow
            dominates the stability term).

    Returns:
        0-d tensor equal to -weight * sum(max(level - overflow_level, 0)).
    """
    excess = torch.clamp(state[:, 0] - overflow_level, min=0)
    return -torch.sum(excess) * weight

def stability_reward(state, ideal_level=1.0):
    """Negative mean squared deviation of water levels from an ideal level.

    Args:
        state: [num_nodes, F] tensor whose column 0 is the water level.
        ideal_level: target water level (default 1.0, the original value).

    Returns:
        0-d tensor equal to -mean((level - ideal_level) ** 2); 0 is the best score.
    """
    deviation = state[:, 0] - ideal_level
    return -torch.mean(deviation ** 2)  # Minimize deviation from the ideal level

#### Next Steps (Might be bottlenecked by computational resources)
- MARL
- Advanced Reward System
- Testing different learning strategies (PPO vs. Q-Learning vs. A2C)

#### Custom Synthetic Data
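- <b>water_level</b>: The current water level at each node at each timestep. It is meant to be updated based on inflows (both natural and from neighboring nodes) and outflows (controlled by the valve); note that the generation cell below currently only sets the initial time step, so later rows are still zero.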
- <b>inflow_rate</b>: The natural inflow rate into the node, modulated by a seasonal/cyclic pattern.
- <b>outflow_rate</b>: The outflow rate is a function of the valve position and the node's current water level.
- <b>valve_position</b>: The valve position for each node, controlling how much water flows out. It can be controlled by your policy network.
- Cross Section Area
- Gradient
- Junctions - Multi Inflow / One out flow
- Catchment Runoff

In [42]:
# --- Synthetic network configuration ---
num_nodes = 50
num_edges = 60  # Average total (in + out) degree: 2 * 60 / 50 = 2.4 per node
num_junctions = 10
time_steps = 150
seasonal_cycle_length = 1440  # One day in minutes (assuming 1 time step = 1 min)
# NOTE(review): 150 steps cover only ~10% of one 1440-step cycle — confirm intended.

# Seed numpy, Python `random` and torch so the generated network and the
# random draws in the following cells are reproducible.
np.random.seed(42)
random.seed(42)
torch.manual_seed(42)

def generate_directed_graph(n_nodes, n_edges, seed=None, max_attempts=10_000):
    """Sample a weakly connected directed G(n, m) random graph by rejection.

    The first attempt uses ``seed`` exactly, so a seed whose first graph is
    already connected gives the same result as before. Later attempts use a
    deterministic stream of seeds derived from ``seed``. Re-using ``seed``
    itself would regenerate the identical graph, so an unconnected first
    draw would loop forever. With ``seed=None``, NetworkX draws from the
    global RNG on every attempt.

    Args:
        n_nodes: number of nodes.
        n_edges: number of directed edges.
        seed: optional integer seed for reproducibility.
        max_attempts: maximum number of graphs to sample before giving up.

    Returns:
        A weakly connected ``nx.DiGraph``.

    Raises:
        ValueError: if n_edges is too small for any weakly connected graph.
        RuntimeError: if no weakly connected graph was found in max_attempts.
    """
    if n_nodes > 0 and n_edges < n_nodes - 1:
        raise ValueError(f"{n_edges} edges cannot weakly connect {n_nodes} nodes (need at least {n_nodes - 1}).")
    derived_seeds = random.Random(seed)  # only consulted when seed is not None
    attempt_seed = seed
    for _ in range(max_attempts):
        G = nx.gnm_random_graph(n_nodes, n_edges, seed=attempt_seed, directed=True)
        if nx.is_weakly_connected(G):
            return G
        if seed is not None:
            attempt_seed = derived_seeds.randrange(2**32)
    raise RuntimeError(f"No weakly connected graph with {n_nodes} nodes and {n_edges} edges found in {max_attempts} attempts.")
G = generate_directed_graph(num_nodes, num_edges, seed=42)
print(f"Generated directed graph with {num_nodes} nodes and {G.number_of_edges()} edges.")

# Edges: Flow capacities, distances, and friction factors
# One value per edge, stored in G.edges() order (presumably the same order
# from_networkx uses for edge_index — the later asserts only check lengths).
flow_capacities = np.random.uniform(low=1.0, high=3.0, size=(G.number_of_edges(),))
distances = np.random.uniform(low=0.5, high=5.0, size=(G.number_of_edges(),))
friction_factors = np.random.uniform(low=0.01, high=0.1, size=(G.number_of_edges(),))
for idx, (u, v) in enumerate(G.edges()):
    G[u][v]['flow_capacity'] = flow_capacities[idx]
    G[u][v]['distance'] = distances[idx]
    G[u][v]['friction_factor'] = friction_factors[idx]
print("Assigned edge attributes to all edges.")

# Junction candidates need at least one outgoing edge so they can split flow.
possible_junction_nodes = [node for node in G.nodes() if G.out_degree(node) > 0]
if len(possible_junction_nodes) < num_junctions:
    raise ValueError(f"Not enough nodes with outgoing edges to select {num_junctions} junction nodes.")
junction_nodes = random.sample(possible_junction_nodes, num_junctions)
print(f"Selected junction nodes: {junction_nodes}")

# Define splitting ratios for junction nodes: {junction: {downstream_node: ratio}}
splitting_ratios = {}
head_loss_factor = 0.1  # Energy loss when splitting
# NOTE(review): head_loss_factor is not referenced by environment_step, which
# applies the per-edge friction factor to split flows instead — confirm intent.

for node in junction_nodes:
    # Downstream (successor) nodes of the junction
    downstream_nodes = list(G.successors(node))
    num_downstream = len(downstream_nodes)
    if num_downstream > 0:
        # Random positive weights normalised to sum to 1 across downstream nodes.
        raw_ratios = np.random.uniform(0.1, 1.0, size=num_downstream)
        normalized_ratios = raw_ratios / np.sum(raw_ratios)
        splitting_ratios[node] = dict(zip(downstream_nodes, normalized_ratios))
        print(f"Splitting ratios for node {node}: {splitting_ratios[node]}")
    else:
        # Unreachable as written: junctions are drawn only from nodes with out_degree > 0.
        print(f"Junction node {node} has no downstream nodes assigned.")
        # Optionally, assign default splitting ratios or handle accordingly

# Simulating rainfall with a sinusoidal pattern, oscillating in [0.1, 0.5]
# NOTE(review): seasonal_pattern is not used by any cell shown here.
time = np.arange(time_steps)
seasonal_pattern = np.sin(2 * np.pi * time / seasonal_cycle_length) * 0.2 + 0.3

# Initialize water levels, inflow/outflow rates, and valve positions
# (each array is [time_steps, num_nodes]).
initial_water_levels = np.random.uniform(low=0.5, high=2.0, size=num_nodes)
water_levels = np.zeros((time_steps, num_nodes))
inflow_rates = np.zeros((time_steps, num_nodes))
outflow_rates = np.zeros((time_steps, num_nodes))
valve_positions = np.zeros((time_steps, num_nodes))

# Set initial conditions
# NOTE(review): only row t=0 is populated. No simulation fills rows t >= 1,
# so the node features (and the train/val/test datasets) built from these
# arrays are all zeros after the first step.
water_levels[0, :] = initial_water_levels
valve_positions[0, :] = np.random.uniform(low=0.0, high=1.0, size=num_nodes)
print(f"Initial water levels: {initial_water_levels}")
print(f"Initial valve positions: {valve_positions[0]}")

# Physics parameters
# NOTE(review): friction_coefficients and hydrostatic_heads are not
# referenced by any cell shown here.
friction_coefficients = 10 * np.ones(G.number_of_edges())  # Adjust as per system
hydrostatic_heads = distances.copy()  # Pressure of flow

# Convert NetworkX directed graph to PyTorch Geometric Data object
water_data = from_networkx(G)
print("Converted directed NetworkX graph to PyTorch Geometric Data object.")

# Assign edge attributes (float32, on the compute device)
flow_capacities_tensor = torch.tensor(flow_capacities, dtype=torch.float).to(device)
distances_tensor = torch.tensor(distances, dtype=torch.float).to(device)
friction_factors_tensor = torch.tensor(friction_factors, dtype=torch.float).to(device)

water_data.flow_capacity = flow_capacities_tensor
water_data.distance = distances_tensor
water_data.friction_factor = friction_factors_tensor

print("Assigned edge attributes to water_data.")

# Verify that edge attributes are assigned
assert hasattr(water_data, 'flow_capacity'), "flow_capacity not found in water_data."
assert hasattr(water_data, 'distance'), "distance not found in water_data."
assert hasattr(water_data, 'friction_factor'), "friction_factor not found in water_data."
print("All required edge attributes are present in water_data.")

# Verify that the number of edge attributes matches edge_index
# (lengths only; the edge order itself is not checked).
num_edges_data = water_data.edge_index.shape[1]
print(f"Number of edges in water_data.edge_index: {num_edges_data}")
print(f"Length of flow_capacities_tensor: {flow_capacities_tensor.shape[0]}")
print(f"Length of distances_tensor: {distances_tensor.shape[0]}")
print(f"Length of friction_factors_tensor: {friction_factors_tensor.shape[0]}")

assert num_edges_data == flow_capacities_tensor.shape[0], "Mismatch between edge_index and flow_capacities_tensor"
assert num_edges_data == distances_tensor.shape[0], "Mismatch between edge_index and distances_tensor"
assert num_edges_data == friction_factors_tensor.shape[0], "Mismatch between edge_index and friction_factors_tensor"
print("All edge attribute tensors match the number of edges in edge_index.")

# Move edge_index to device
edge_index = water_data.edge_index.to(device)

# Split data into training, validation, and test sets
# (chronological 70% / 15% / remainder split over time steps; no shuffling)
train_steps = int(0.7 * time_steps)
val_steps = int(0.15 * time_steps)
test_steps = time_steps - train_steps - val_steps

# Prepare node features as tensors
# NOTE(review): the same construction is repeated in a later cell.
gnn_in_chan = 4  # water_level, inflow_rate, outflow_rate, valve_position
node_features = np.stack([
    water_levels,
    inflow_rates,
    outflow_rates,
    valve_positions
], axis=2)  # Shape: [time_steps, num_nodes, 4]

node_features_tensor = torch.tensor(node_features, dtype=torch.float).to(device)

# Split node features
train_node_features = node_features_tensor[:train_steps, :, :]
val_node_features = node_features_tensor[train_steps:train_steps+val_steps, :, :]
test_node_features = node_features_tensor[train_steps+val_steps:, :, :]

# Create PyTorch Geometric Data objects for each time step (one snapshot graph
# per step, all sharing the same edge_index)
train_data = [Data(x=train_node_features[t], edge_index=edge_index) for t in range(train_steps)]
val_data = [Data(x=val_node_features[t], edge_index=edge_index) for t in range(val_steps)]
test_data = [Data(x=test_node_features[t], edge_index=edge_index) for t in range(test_steps)]
print("Prepared training, validation, and test datasets.")

In [5]:
# # NOTE Simple but effective way to add X features (Could improve in later phases)
# gnn_in_chan = 4  # aka. Num Input Feat (e.g, water_level, inflow_rate, outflow_rate, valve_position)
# node_features = np.zeros((time_steps, num_nodes, gnn_in_chan))
# for t in range(time_steps):
#     for n in range(num_nodes):
#         node_features[t, n, 0] = df_nodes.loc[(df_nodes['time_step'] == t) & (df_nodes['node'] == n), 'water_level'].values[0]
#         node_features[t, n, 1] = df_nodes.loc[(df_nodes['time_step'] == t) & (df_nodes['node'] == n), 'inflow_rate'].values[0]
#         node_features[t, n, 2] = df_nodes.loc[(df_nodes['time_step'] == t) & (df_nodes['node'] == n), 'outflow_rate'].values[0]
#         node_features[t, n, 3] = df_nodes.loc[(df_nodes['time_step'] == t) & (df_nodes['node'] == n), 'valve_position'].values[0]
# node_features_tensor = torch.tensor(node_features, dtype=torch.float)

In [37]:
# Stack the four per-node signals into the GNN input tensor.
gnn_in_chan = 4  # water_level, inflow_rate, outflow_rate, valve_position
node_features = np.stack(
    (water_levels, inflow_rates, outflow_rates, valve_positions),
    axis=-1,
)  # [time_steps, num_nodes, gnn_in_chan]
node_features_tensor = torch.as_tensor(node_features, dtype=torch.float, device=device)

In [6]:
# NOTE Will need an LSTM for multi-step prediction to handle sequential data: node_features_tensor.to(device)
# single_immediate_features = node_features[0] # NOTE for immediate condition (Snapshot)
# single_immediate_features_tensor = torch.tensor(single_immediate_features, dtype=torch.float).to(device)
# print(f"Node Features: {node_features.shape}")
# print(f"Single Node Features: {single_immediate_features.shape}")

# water_data = from_networkx(G)
# water_data.x = single_immediate_features_tensor
# water_data.edge_index = water_data.edge_index.to(device)

In [41]:
# Re-read the per-edge attributes straight from the NetworkX graph, in G.edges() order.
flow_capacities = [attrs['flow_capacity'] for _, _, attrs in G.edges(data=True)]
distances = [attrs['distance'] for _, _, attrs in G.edges(data=True)]
friction_factors = [attrs['friction_factor'] for _, _, attrs in G.edges(data=True)]

# float32 tensors created directly on the compute device
flow_capacities_tensor = torch.tensor(flow_capacities, dtype=torch.float, device=device)
distances_tensor = torch.tensor(distances, dtype=torch.float, device=device)
friction_factors_tensor = torch.tensor(friction_factors, dtype=torch.float, device=device)

# Attach them to the PyG Data object as extra edge-level fields.
water_data.flow_capacity, water_data.distance, water_data.friction_factor = (
    flow_capacities_tensor,
    distances_tensor,
    friction_factors_tensor,
)

assert hasattr(water_data, 'flow_capacity'), "flow_capacity not found in water_data."
assert hasattr(water_data, 'distance'), "distance not found in water_data."
assert hasattr(water_data, 'friction_factor'), "friction_factor not found in water_data."
print("Edge attributes successfully assigned to water_data.")

# Sanity-check that every per-edge tensor lines up with edge_index before training.
edge_index = water_data.edge_index.to(device)
print(f"Number of edges in edge_index: {edge_index.size(1)}")
print(f"Length of flow_capacities_tensor: {len(flow_capacities_tensor)}")
print(f"Length of distances_tensor: {len(distances_tensor)}")
print(f"Length of friction_factors_tensor: {len(friction_factors_tensor)}")

assert flow_capacities_tensor.shape[0] == edge_index.size(1), "Mismatch between edge_index and flow_capacities_tensor"
assert distances_tensor.shape[0] == edge_index.size(1), "Mismatch between edge_index and distances_tensor"
assert friction_factors_tensor.shape[0] == edge_index.size(1), "Mismatch between edge_index and friction_factors_tensor"
print("All edge attribute tensors match the number of edges in edge_index.")


Edge attributes successfully assigned to water_data.
Number of edges in edge_index: 120
Length of flow_capacities_tensor: 60
Length of distances_tensor: 60
Length of friction_factors_tensor: 60


AssertionError: Mismatch between edge_index and flow_capacities_tensor

In [39]:
# Split data into training, validation, and test sets (chronological)
train_steps = int(0.7 * time_steps)
val_steps = int(0.15 * time_steps)
test_steps = time_steps - train_steps - val_steps

train_node_features = node_features_tensor[:train_steps, :, :]
val_node_features = node_features_tensor[train_steps:train_steps+val_steps, :, :]
test_node_features = node_features_tensor[train_steps+val_steps:, :, :]

train_data = [Data(x=train_node_features[t], edge_index=edge_index) for t in range(train_steps)]
val_data = [Data(x=val_node_features[t], edge_index=edge_index) for t in range(val_steps)]
test_data = [Data(x=test_node_features[t], edge_index=edge_index) for t in range(test_steps)]

# Every per-edge physics tensor must line up with edge_index. This cell used
# to re-sample friction_factors here. That silently discarded the values
# stored on G / water_data, and it hid length mismatches: distances_tensor
# kept its old length, so the training loop later failed with
# "IndexError: index 60 is out of bounds". Fail fast instead.
actual_num_edges = edge_index.shape[1]
for attr_name, attr_tensor in (
    ("flow_capacities_tensor", flow_capacities_tensor),
    ("distances_tensor", distances_tensor),
    ("friction_factors_tensor", friction_factors_tensor),
):
    if attr_tensor.shape[0] != actual_num_edges:
        raise ValueError(
            f"{attr_name} has {attr_tensor.shape[0]} entries but edge_index has {actual_num_edges} edges; "
            "re-run the graph-generation cell so G, edge_index and the edge attributes come from the same graph."
        )

In [40]:
# NOTE New Training Method
gnn_out_chann = gnn_in_chan # Num of features in = out
rl_out_channels = 1 # Controlling valve positions (one output per node)
hidden_channels = 128
gnn = GCN(in_channels=gnn_in_chan, hidden_channels=hidden_channels, out_channels=gnn_out_chann).to(device)
policy_network = PolicyNetwork(input_dim=gnn_out_chann, output_dim=rl_out_channels).to(device)
optimizer = torch.optim.Adam(list(gnn.parameters()) + list(policy_network.parameters()), lr=1e-3)
# schedule = ReduceLROnPlateau

num_episodes = 1_000
best_val_loss = float("inf")
train_losses = []
val_losses = []
es_threshold = 15
# Patience counter for early stopping. (It was previously initialised under a
# misspelt name, so the name the loop reads only existed after the first
# improvement.)
early_stopping_counter = 0

# Junction ids are fixed for the whole run: build the index tensor once.
junction_nodes_tensor = torch.tensor(junction_nodes, device=device)

# Anomaly detection helps locate in-place autograd errors but slows every
# backward pass noticeably; turn it off once training runs cleanly.
torch.autograd.set_detect_anomaly(True) # Inplace Error??
# Training Loop
for episode in range(num_episodes):
    gnn.train()
    policy_network.train()
    total_train_loss = 0

    # Initialize episode-specific state
    current_water_levels = initial_water_levels.copy()
    current_inflow_rates = np.zeros(num_nodes)
    current_outflow_rates = np.zeros(num_nodes)
    current_valve_positions = valve_positions[0].copy()

    # Reset last valve adjustment tracker (-inf = never adjusted)
    last_valve_adjustment = np.full(num_nodes, -np.inf, dtype=np.float32)

    # Iterate through time steps within the episode
    for t in range(time_steps):
        # Current state as [num_nodes, 4]:
        # (water_level, inflow_rate, outflow_rate, valve_position).
        # Stacking in numpy first avoids torch's slow list-of-ndarrays path.
        state_tensor = torch.tensor(
            np.stack([current_water_levels, current_inflow_rates, current_outflow_rates, current_valve_positions], axis=1),
            dtype=torch.float,
            device=device,
        )

        data = Data(x=state_tensor, edge_index=edge_index)
        data = data.to(device)

        # GNN forward pass
        node_representations = gnn(data)

        # Determine global threshold (95th percentile of current water levels)
        # NOTE(review): max(levels) > quantile(levels, 0.95) holds whenever the
        # top ~5% of levels are not all tied, so hybrid_control will almost
        # always choose global control. Consider a fixed physical threshold
        # (e.g. the 1.5 overflow level used by overflow_penalty).
        train_current_water_levels = torch.tensor(current_water_levels, dtype=torch.float).to(device)
        train_global_threshold = torch.quantile(train_current_water_levels, 0.95)

        # Hybrid control: decide between local and global control
        actions = hybrid_control(
            node_representations, 
            policy_network, 
            train_global_threshold, 
            train_current_water_levels, 
            junction_nodes_tensor, 
            torch.tensor(last_valve_adjustment, device=device), 
            t, 
            adjustment_interval=15
        )

        # Environment step: simulate water dynamics and calculate reward
        new_state, reward = environment_step(
            state_tensor, 
            actions, 
            edge_index,
            flow_capacities_tensor, 
            junction_nodes, 
            splitting_ratios,
            friction_factors_tensor,
            last_valve_adjustment,
            t,
            adjustment_interval=15
        )

        # Update current state (detached: gradients do not flow across steps)
        current_water_levels = new_state[:, 0].detach().cpu().numpy()
        current_inflow_rates = new_state[:, 1].detach().cpu().numpy()
        current_outflow_rates = new_state[:, 2].detach().cpu().numpy()
        current_valve_positions = new_state[:, 3].detach().cpu().numpy()

        # Compute loss (maximize reward -> minimize negative reward)
        train_loss = -reward.mean()
        optimizer.zero_grad()
        train_loss.backward()
        optimizer.step()

        total_train_loss += train_loss.item()

    # Average training loss for the episode
    average_train_loss = total_train_loss / time_steps
    train_losses.append(average_train_loss)

    # Validation Phase
    # NOTE(review): current_time (train_steps + t) is smaller than the times
    # the training episode just wrote into last_valve_adjustment, so the
    # elapsed time is negative and local control treats those nodes as not
    # adjustable. Validation also mutates the same last_valve_adjustment array.
    gnn.eval()
    policy_network.eval()
    total_val_loss = 0
    with torch.no_grad():
        for t, batch in enumerate(val_data):
            batch = batch.to(device)
            current_time = train_steps + t

            # Extract current water levels for threshold
            val_current_water_levels = batch.x[:, 0]
            val_global_threshold = torch.quantile(val_current_water_levels, 0.95)

            # GNN forward pass
            val_node_representations = gnn(batch)

            # Hybrid control
            val_actions = hybrid_control(
                val_node_representations, 
                policy_network, 
                val_global_threshold, 
                val_current_water_levels, 
                junction_nodes_tensor, 
                torch.tensor(last_valve_adjustment, device=device), 
                current_time, 
                adjustment_interval=15
            )

            # Environment step
            val_new_state, val_reward = environment_step(
                batch.x, 
                val_actions, 
                edge_index, 
                flow_capacities_tensor, 
                junction_nodes, 
                splitting_ratios, 
                friction_factors_tensor,
                last_valve_adjustment,
                current_time,
                adjustment_interval=15
            )

            # Compute validation loss
            val_loss = -val_reward.mean()
            total_val_loss += val_loss.item()
    average_val_loss = total_val_loss / len(val_data)
    val_losses.append(average_val_loss)

    # Early Stopping and Model Saving
    if average_val_loss < best_val_loss:
        best_val_loss = average_val_loss
        early_stopping_counter = 0  # Reset counter
        torch.save(gnn.state_dict(), 'best_gnn_model.pth')
        torch.save(policy_network.state_dict(), 'best_policy_network.pth')
        print(f"Episode {episode}: New best validation loss {best_val_loss:.4f}. Models saved.")
    else:
        early_stopping_counter += 1
        print(f"Episode {episode}: Validation loss did not improve. Counter: {early_stopping_counter}/{es_threshold}")
        if early_stopping_counter >= es_threshold:
            print(f"Early Stopping triggered at Episode {episode}")
            break

    print(f"Episode {episode}: Train Loss = {average_train_loss:.4f}, Val Loss = {average_val_loss:.4f}")

# Load the best models after training
gnn.load_state_dict(torch.load('best_gnn_model.pth'))
policy_network.load_state_dict(torch.load('best_policy_network.pth'))
print("Loaded the best-performing models.")


Switching to Global Control!


IndexError: index 60 is out of bounds for dimension 0 with size 60

#### GAT
In the GAT layer, the node representations are updated using attention mechanisms:

$$
h_v' = \sum_{u \in \mathcal{N}(v)} \alpha_{vu} W h_u
$$

where $ \alpha_{vu} $ is the attention coefficient between node $ v $ and node $ u $.