In [23]:
import os
import sys
import copy
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# Jupyter Magic to make plots show up inline
%matplotlib inline

import gymnasium as gym
import gym4real
from gymnasium import spaces
from gym4real.envs.wds.utils import parameter_generator
from stable_baselines3 import DQN, A2C
from stable_baselines3.common.vec_env import DummyVecEnv, VecMonitor
from sb3_contrib import ARS

import wntr
import wntr.sim

# Force the Python simulator to prevent M3 Kernel Crashes.
# This rebinds the `EpanetSimulator` attribute on the `wntr.sim` package so that
# any later lookup of `wntr.sim.EpanetSimulator` resolves to the pure-Python
# `WNTRSimulator` class instead.
# NOTE(review): gym4real is imported before this line runs. If its WDS code binds
# the class directly (e.g. `from wntr.sim import EpanetSimulator`) at import time,
# this patch will not affect it -- confirm against the gym4real sources.
print("🔧 Patching WNTR Simulator for M3 Compatibility...")
wntr.sim.EpanetSimulator = wntr.sim.WNTRSimulator

🔧 Patching WNTR Simulator for M3 Compatibility...


In [24]:
# --- THE "ARS COMPATIBILITY" WRAPPER ---
# ARS crashes on Discrete environments (like pumps). 
# This wrapper tricks ARS by pretending the actions are continuous (-1 to +1),
# then converts them to Discrete (0, 1, 2...) before the simulator sees them.
class DiscreteToBoxWrapper(gym.ActionWrapper):
    """Expose a discrete action space as a continuous ``Box`` of per-choice scores.

    The agent sees a ``Box(-1, 1)`` whose entries are scores, one per discrete
    choice. Before the action reaches the wrapped environment the scores are
    decoded back into the original space by taking the arg-max:

    * ``Discrete(n)``          -> ``Box`` of shape ``(n,)``; decoded to a single int.
    * ``MultiDiscrete(nvec)``  -> ``Box`` of shape ``(sum(nvec),)``; the flat vector
      is the concatenation of one score segment per sub-action (segment ``i`` has
      length ``nvec[i]``), and each segment is decoded by its own arg-max.
    * any other space          -> left untouched, actions pass straight through.

    Attributes:
        original_action_space: the wrapped environment's action space, kept so
            ``action()`` knows how to decode.
    """

    def __init__(self, env):
        super().__init__(env)
        # Keep the real (discrete) space; action() needs it to decode scores.
        self.original_action_space = env.action_space

        # One continuous score per discrete choice.
        if isinstance(env.action_space, spaces.Discrete):
            n_scores = int(env.action_space.n)
        elif isinstance(env.action_space, spaces.MultiDiscrete):
            # np.sum (not builtin sum) so multi-dimensional nvec also yields a scalar.
            n_scores = int(np.sum(env.action_space.nvec))
        else:
            # Already continuous (or unsupported): expose the env's own space.
            return

        self.action_space = spaces.Box(low=-1, high=1, shape=(n_scores,), dtype=np.float32)

    def action(self, action):
        """Decode a score vector into an action of the original space.

        Args:
            action: array-like of scores shaped like ``self.action_space`` when the
                original space is ``Discrete``/``MultiDiscrete``; otherwise any
                action valid for the wrapped environment.

        Returns:
            ``int`` for ``Discrete``, an integer ``numpy.ndarray`` shaped like
            ``nvec`` for ``MultiDiscrete``, or ``action`` unchanged for any other
            original space.
        """
        space = self.original_action_space

        if isinstance(space, spaces.Discrete):
            # `start` offsets the valid range; it is 0 unless the env set it.
            offset = int(getattr(space, "start", 0))
            return int(np.argmax(action)) + offset

        if isinstance(space, spaces.MultiDiscrete):
            nvec = np.asarray(space.nvec)
            scores = np.asarray(action).ravel()
            # Cut the flat score vector into one segment per sub-action.
            boundaries = np.cumsum(nvec.ravel())[:-1]
            choices = np.array(
                [int(np.argmax(segment)) for segment in np.split(scores, boundaries)],
                dtype=np.int64,
            )
            # Older gymnasium versions have no `start` on MultiDiscrete.
            start = getattr(space, "start", None)
            if start is not None:
                choices = choices + np.asarray(start).ravel()
            return choices.reshape(nvec.shape).astype(space.dtype)

        return action

In [25]:
# World/network configuration file, resolved relative to the directory the
# kernel was started in (the notebook must be launched from the repo root).
config_path = os.path.join(os.getcwd(), "gym4real", "envs", "wds", "world_anytown.yaml")

# Simulation horizon, expressed once so the step size and the total duration
# cannot drift apart.
HYDRAULIC_STEP_SECONDS = 3600  # 1 hour steps
# 7 days * 24 hydraulic steps per day.
# NOTE(review): the training logs in this notebook report ep_len_mean ~1.2e3, so
# the number of agent steps per episode is not equal to this value -- confirm how
# the environment advances time before using it as an episode length.
DURATION_STEPS = 168

base_params = parameter_generator(
    world_options=config_path,
    hydraulic_step=HYDRAULIC_STEP_SECONDS,
    duration=HYDRAULIC_STEP_SECONDS * DURATION_STEPS,  # 604800 s == 7 days
    seed=42)

def make_env(for_ars=False):
    """Build a fresh 'gym4real/wds-v0' environment from a private copy of base_params.

    Args:
        for_ars: when True, wrap the environment in DiscreteToBoxWrapper so that
            ARS is given a continuous action space.

    Returns:
        The environment returned by gym.make, wrapped in DiscreteToBoxWrapper
        when for_ars is True.
    """
    # Deep-copy so each environment gets its own settings object.
    wds_env = gym.make('gym4real/wds-v0', settings=copy.deepcopy(base_params))

    # CRITICAL: ARS needs the Box-action wrapper, otherwise it crashes.
    return DiscreteToBoxWrapper(wds_env) if for_ars else wds_env

In [26]:
# CREATE TRAINING ENVIRONMENT
# Single-env vectorized wrapper; VecMonitor records the episode reward/length
# statistics that show up as rollout/ep_rew_mean and rollout/ep_len_mean.
train_env = DummyVecEnv([lambda: make_env(for_ars=False)])
train_env = VecMonitor(train_env)

MODEL_NAME = "DQN_WDSEnv"
TOTAL_TIMESTEPS = 200_000  # Increase this for better results
TRAIN_SEED = 42            # Fixed so a Restart & Run All reproduces the same run

print("🧠 Initializing DQN Agent...")
model = DQN(
    "MlpPolicy",
    train_env,
    seed=TRAIN_SEED,
    device='cpu',
    verbose=1)

print(f"Training {MODEL_NAME}...")
model.learn(total_timesteps=TOTAL_TIMESTEPS)
print("✅ Training Complete.")

# Save the model (written as "<MODEL_NAME>.zip" in the working directory)
model.save(MODEL_NAME)

🧠 Initializing DQN Agent...
Using cpu device
Training DQN_WDSEnv...
Resetting the environment...


  gym.logger.warn(
  gym.logger.warn(


Resetting the environment...
Resetting the environment...
Resetting the environment...
Resetting the environment...
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 1.18e+03 |
|    ep_rew_mean      | 961      |
|    exploration_rate | 0.776    |
| time/               |          |
|    episodes         | 4        |
|    fps              | 31       |
|    time_elapsed     | 149      |
|    total_timesteps  | 4710     |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 0.0327   |
|    n_updates        | 1152     |
----------------------------------




Resetting the environment...
Resetting the environment...
Resetting the environment...
Resetting the environment...
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 1.21e+03 |
|    ep_rew_mean      | 970      |
|    exploration_rate | 0.54     |
| time/               |          |
|    episodes         | 8        |
|    fps              | 31       |
|    time_elapsed     | 305      |
|    total_timesteps  | 9693     |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 0.0306   |
|    n_updates        | 2398     |
----------------------------------
Resetting the environment...
Resetting the environment...
Resetting the environment...
Resetting the environment...
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 1.17e+03 |
|    ep_rew_mean      | 988      |
|    exploration_rate | 0.332    |
| time/               |          |
|    episodes         | 12       

In [27]:
# CREATE TRAINING ENVIRONMENT
# Single-env vectorized wrapper; VecMonitor records the episode reward/length
# statistics that show up as rollout/ep_rew_mean and rollout/ep_len_mean.
train_env = DummyVecEnv([lambda: make_env(for_ars=False)])
train_env = VecMonitor(train_env)

MODEL_NAME = "A2C_WDSEnv"
TOTAL_TIMESTEPS = 200_000  # Increase this for better results
TRAIN_SEED = 42            # Fixed so a Restart & Run All reproduces the same run

print("🧠 Initializing A2C Agent...")
model = A2C(
    "MlpPolicy",
    train_env,
    learning_rate=1e-4,
    seed=TRAIN_SEED,
    device="cpu",
    verbose=1)

print(f"Training {MODEL_NAME}...")
model.learn(total_timesteps=TOTAL_TIMESTEPS)
print("✅ Training Complete.")

# Save the model (written as "<MODEL_NAME>.zip" in the working directory)
model.save(MODEL_NAME)

🧠 Initializing A2C Agent...
Using cpu device
Training A2C_WDSEnv...
Resetting the environment...


  gym.logger.warn(
  gym.logger.warn(


-------------------------------------
| time/                 |           |
|    fps                | 33        |
|    iterations         | 100       |
|    time_elapsed       | 14        |
|    total_timesteps    | 500       |
| train/                |           |
|    entropy_loss       | -1.38     |
|    explained_variance | -0.000771 |
|    learning_rate      | 0.0001    |
|    n_updates          | 99        |
|    policy_loss        | 3.86      |
|    value_loss         | 9.53      |
-------------------------------------
-------------------------------------
| time/                 |           |
|    fps                | 32        |
|    iterations         | 200       |
|    time_elapsed       | 30        |
|    total_timesteps    | 1000      |
| train/                |           |
|    entropy_loss       | -1.38     |
|    explained_variance | -0.000291 |
|    learning_rate      | 0.0001    |
|    n_updates          | 199       |
|    policy_loss        | 3.84      |
|    value_l



-------------------------------------
| rollout/              |           |
|    ep_len_mean        | 1.14e+03  |
|    ep_rew_mean        | 1.01e+03  |
| time/                 |           |
|    fps                | 32        |
|    iterations         | 1700      |
|    time_elapsed       | 264       |
|    total_timesteps    | 8500      |
| train/                |           |
|    entropy_loss       | -1.35     |
|    explained_variance | -1.72e-05 |
|    learning_rate      | 0.0001    |
|    n_updates          | 1699      |
|    policy_loss        | 2.12      |
|    value_loss         | 2.62      |
-------------------------------------
------------------------------------
| rollout/              |          |
|    ep_len_mean        | 1.14e+03 |
|    ep_rew_mean        | 1.01e+03 |
| time/                 |          |
|    fps                | 32       |
|    iterations         | 1800     |
|    time_elapsed       | 279      |
|    total_timesteps    | 9000     |
| train/             



------------------------------------
| rollout/              |          |
|    ep_len_mean        | 1.17e+03 |
|    ep_rew_mean        | 1e+03    |
| time/                 |          |
|    fps                | 32       |
|    iterations         | 5700     |
|    time_elapsed       | 887      |
|    total_timesteps    | 28500    |
| train/                |          |
|    entropy_loss       | -1.35    |
|    explained_variance | 0        |
|    learning_rate      | 0.0001   |
|    n_updates          | 5699     |
|    policy_loss        | 0.706    |
|    value_loss         | 0.287    |
------------------------------------
------------------------------------
| rollout/              |          |
|    ep_len_mean        | 1.17e+03 |
|    ep_rew_mean        | 1e+03    |
| time/                 |          |
|    fps                | 32       |
|    iterations         | 5800     |
|    time_elapsed       | 903      |
|    total_timesteps    | 29000    |
| train/                |          |
|



------------------------------------
| rollout/              |          |
|    ep_len_mean        | 1.16e+03 |
|    ep_rew_mean        | 985      |
| time/                 |          |
|    fps                | 32       |
|    iterations         | 8500     |
|    time_elapsed       | 1321     |
|    total_timesteps    | 42500    |
| train/                |          |
|    entropy_loss       | -1.24    |
|    explained_variance | 0        |
|    learning_rate      | 0.0001   |
|    n_updates          | 8499     |
|    policy_loss        | 2.11     |
|    value_loss         | 2.14     |
------------------------------------
Resetting the environment...
------------------------------------
| rollout/              |          |
|    ep_len_mean        | 1.16e+03 |
|    ep_rew_mean        | 983      |
| time/                 |          |
|    fps                | 32       |
|    iterations         | 8600     |
|    time_elapsed       | 1337     |
|    total_timesteps    | 43000    |
| train/ 



Resetting the environment...
------------------------------------
| rollout/              |          |
|    ep_len_mean        | 1.06e+03 |
|    ep_rew_mean        | 1.01e+03 |
| time/                 |          |
|    fps                | 32       |
|    iterations         | 36800    |
|    time_elapsed       | 5710     |
|    total_timesteps    | 184000   |
| train/                |          |
|    entropy_loss       | -0.788   |
|    explained_variance | 1.79e-07 |
|    learning_rate      | 0.0001   |
|    n_updates          | 36799    |
|    policy_loss        | 0.00244  |
|    value_loss         | 0.000119 |
------------------------------------
------------------------------------
| rollout/              |          |
|    ep_len_mean        | 1.06e+03 |
|    ep_rew_mean        | 1.01e+03 |
| time/                 |          |
|    fps                | 32       |
|    iterations         | 36900    |
|    time_elapsed       | 5725     |
|    total_timesteps    | 184500   |
| train/ 

In [30]:
# # CREATE TRAINING ENVIRONMENT
# train_env = DummyVecEnv([lambda: make_env(for_ars=True)])
# train_env = VecMonitor(train_env)

# MODEL_NAME = "ARS_WDSEnv"

# print("🧠 Initializing ARS Agent...")
# model = ARS(
#     "LinearPolicy", 
#     train_env, 
#     device="cpu", 
#     verbose=1)

# print(f"Training {MODEL_NAME}...")
# model.learn(total_timesteps=100) # Increase this for better results
# print("✅ Training Complete.")

# # Save the model
# model.save(MODEL_NAME)