In [1]:
import numpy as np
import gymnasium as gym

from coverage_env import CoverageEnv

In [2]:
# Register the coverage environment under a Gymnasium id, with a 200-step
# episode limit attached to the registration.
gym.register("Coverage-v0", entry_point="coverage_env:CoverageEnv", max_episode_steps=200)


In [5]:
import torch as th
import torch.nn as nn
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor

class SmallCNN(BaseFeaturesExtractor):
    """Tiny two-layer convolutional feature extractor.

    Two unpadded 3x3 convolutions (16 then 32 filters, each followed by a
    ReLU) are flattened and projected by a linear layer + ReLU down to
    ``features_dim`` values.

    Args:
        observation_space: Space whose ``shape`` unpacks as
            ``(channels, H, W)``; it must also support ``sample()``.
        features_dim: Length of the feature vector produced by ``forward``.
    """

    def __init__(self, observation_space, features_dim=64):
        super().__init__(observation_space, features_dim)
        # Unpacking three values also rejects observation shapes that are not 3-D.
        in_channels, _height, _width = observation_space.shape

        conv_stack = [
            nn.Conv2d(in_channels, 16, kernel_size=3, stride=1),  # -> (16, H-2, W-2)
            nn.ReLU(),
            nn.Conv2d(16, 32, kernel_size=3, stride=1),  # -> (32, H-4, W-4)
            nn.ReLU(),
            nn.Flatten(),
        ]
        self.cnn = nn.Sequential(*conv_stack)

        # Push one sampled observation through the conv stack to learn how many
        # values come out of the flatten step.
        probe = th.as_tensor(observation_space.sample()[None]).float()
        with th.no_grad():
            flat_width = self.cnn(probe).shape[1]

        # Projection head that yields exactly features_dim outputs.
        self.linear = nn.Sequential(nn.Linear(flat_width, features_dim), nn.ReLU())

    def forward(self, observations: th.Tensor) -> th.Tensor:
        """Return the ``features_dim``-wide encoding of a batch of observations."""
        conv_features = self.cnn(observations)
        return self.linear(conv_features)


class MediumCNN(BaseFeaturesExtractor):
    """Two-layer padded CNN with batch normalisation, used as a feature extractor.

    Each 3x3 convolution uses ``padding=1`` so the spatial size is preserved,
    and is followed by ReLU and ``BatchNorm2d``. The flattened activations are
    projected by a linear layer + ReLU to ``features_dim`` values.

    Args:
        observation_space: Space whose ``shape`` unpacks as
            ``(channels, H, W)``; it must also support ``sample()``.
        features_dim: Length of the feature vector produced by ``forward``.
    """

    def __init__(self, observation_space, features_dim=64):
        super().__init__(observation_space, features_dim)
        # Unpacking three values also rejects observation shapes that are not 3-D.
        n_channels, _height, _width = observation_space.shape

        self.cnn = nn.Sequential(
            nn.Conv2d(n_channels, 32, kernel_size=3, stride=1, padding=1),  # (32, H, W)
            nn.ReLU(),
            nn.BatchNorm2d(32),
            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),          # (64, H, W)
            nn.ReLU(),
            nn.BatchNorm2d(64),
            nn.Flatten(),
        )

        # Compute the flattened feature size with one dummy forward pass.
        # The pass runs in eval mode: in training mode the BatchNorm layers
        # would fold the statistics of this random sample into their running
        # mean/variance before any real observation has been seen.
        was_training = self.cnn.training
        self.cnn.eval()
        with th.no_grad():
            sample = th.as_tensor(observation_space.sample()[None]).float()
            n_flatten = self.cnn(sample).shape[1]
        self.cnn.train(was_training)

        self.linear = nn.Sequential(
            nn.Linear(n_flatten, features_dim),
            nn.ReLU(),
        )

    def forward(self, observations: th.Tensor) -> th.Tensor:
        """Return the ``features_dim``-wide encoding of a batch of observations."""
        return self.linear(self.cnn(observations))



In [7]:
from stable_baselines3 import DQN
from stable_baselines3.common.evaluation import evaluate_policy

# Run configuration, named once so the comments cannot drift from the values.
DQN_SEED = 42
DQN_TOTAL_TIMESTEPS = 200_000

# instantiate a single env (you can wrap VecEnv for parallelism later)
env = CoverageEnv(seed=DQN_SEED, curriculum=0)

# create the DQN model
model = DQN(
    policy="CnnPolicy",
    env=env,
    learning_starts=1000,
    buffer_size=500_000,
    learning_rate=1e-3,
    batch_size=32,
    gamma=0.99,
    exploration_fraction=0.3,
    exploration_final_eps=0.02,
    policy_kwargs=dict(
        features_extractor_class=MediumCNN,
        features_extractor_kwargs=dict(features_dim=64),
        normalize_images=False,    # since you already output float32 [0,1]
    ),
    # Seed the algorithm's own RNGs (exploration, replay sampling, weight init)
    # so a Restart & Run All reproduces the same training run.
    seed=DQN_SEED,
    verbose=1,
)

# train for DQN_TOTAL_TIMESTEPS environment steps
model.learn(total_timesteps=DQN_TOTAL_TIMESTEPS)

# save it
model.save("dqn_coverage")

Using cuda device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 182      |
|    ep_rew_mean      | -148     |
|    exploration_rate | 0.988    |
| time/               |          |
|    episodes         | 4        |
|    fps              | 4655     |
|    time_elapsed     | 0        |
|    total_timesteps  | 726      |
----------------------------------
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 175      |
|    ep_rew_mean      | -136     |
|    exploration_rate | 0.977    |
| time/               |          |
|    episodes         | 8        |
|    fps              | 1117     |
|    time_elapsed     | 1        |
|    total_timesteps  | 1399     |
| train/              |          |
|    learning_rate    | 0.001    |
|    loss             | 0.0532   |
|    n_updates        | 99       |
------------------------------

KeyboardInterrupt: 

In [8]:
# load (if needed)
# model = DQN.load("dqn_coverage", env=env)

# Score the current model with greedy (deterministic) actions over 20 episodes
# and report the mean and standard deviation of the episode returns.
eval_stats = evaluate_policy(model, env, n_eval_episodes=20, deterministic=True)
mean_reward, std_reward = eval_stats
print(f"Mean reward: {mean_reward:.2f} ± {std_reward:.2f}")




Mean reward: -182.00 ± 4.65


In [26]:
# Roll out one episode with greedy actions, rendering the grid after each step.
obs, _ = env.reset()

for i in range(env.max_steps):
    # model.predict returns the action as a NumPy array. .item() extracts the
    # single element whether the array is 0-d or shape (1,), instead of relying
    # on int(ndarray), which NumPy deprecates for arrays with ndim > 0.
    action_arr, _ = model.predict(obs, deterministic=True)
    action = int(action_arr.item())
    print("step:", i, "action:", action)

    obs, reward, terminated, truncated, info = env.step(action)
    env.render()
    print("\n")

    # Stop as soon as the environment reports the episode is over.
    if terminated or truncated:
        break


step: 0 action: 0
........
........
A...TTT.
....#TT.
....TTT.
........
........
........


step: 1 action: 2
........
........
.A..TTT.
....#TT.
....TTT.
........
........
........


step: 2 action: 2
........
........
..A.TTT.
....#TT.
....TTT.
........
........
........


step: 3 action: 2
........
........
...ATTT.
....#TT.
....TTT.
........
........
........


step: 4 action: 2
........
........
....ATT.
....#TT.
....TTT.
........
........
........


step: 5 action: 2
........
........
....TAT.
....#TT.
....TTT.
........
........
........


step: 6 action: 2
........
........
....TTA.
....#TT.
....TTT.
........
........
........


step: 7 action: 3
........
........
....TAT.
....#TT.
....TTT.
........
........
........


step: 8 action: 0
........
........
....TTT.
....#AT.
....TTT.
........
........
........


step: 9 action: 0
........
........
....TTT.
....#TT.
....TAT.
........
........
........


step: 10 action: 3
........
........
....TTT.
....#TT.
....ATT.
........
.......

# **PPO**

In [10]:
class SimplerCNN(BaseFeaturesExtractor):
    """Single-convolution feature extractor.

    One unpadded 3x3 convolution with 32 filters and a ReLU is flattened and
    projected by a linear layer + ReLU to ``features_dim`` values.

    Args:
        observation_space: Space whose ``shape[0]`` is the number of input
            channels; it must also support ``sample()``.
        features_dim: Length of the feature vector produced by ``forward``.
    """

    def __init__(self, observation_space, features_dim=64):
        super().__init__(observation_space, features_dim)
        in_channels = observation_space.shape[0]

        conv_stack = [
            nn.Conv2d(in_channels, 32, kernel_size=3, stride=1),
            nn.ReLU(),
            nn.Flatten(),
        ]
        self.cnn = nn.Sequential(*conv_stack)

        # Push one sampled observation through the conv stack to learn how many
        # values come out of the flatten step.
        probe = th.as_tensor(observation_space.sample()[None]).float()
        with th.no_grad():
            flat_width = self.cnn(probe).shape[1]

        self.linear = nn.Sequential(nn.Linear(flat_width, features_dim), nn.ReLU())

    def forward(self, observations: th.Tensor) -> th.Tensor:
        """Return the ``features_dim``-wide encoding of a batch of observations."""
        conv_features = self.cnn(observations)
        return self.linear(conv_features)


In [11]:
from stable_baselines3 import PPO
from stable_baselines3.common.evaluation import evaluate_policy

# Run configuration, named once so the comments cannot drift from the values.
PPO_SEED = 42
PPO_TOTAL_TIMESTEPS = 1_500_000

# instantiate your env as before
env = CoverageEnv(seed=PPO_SEED)

# create the PPO model
model = PPO(
    policy="CnnPolicy",
    env=env,
    learning_rate=1e-4,         # often lower LR for PPO
    n_steps=256,               # timesteps per rollout
    batch_size=32,              # minibatch size
    n_epochs=10,                # how many times to reuse each rollout
    gamma=0.99,
    ent_coef=0.01,  # or try 0.005 if too random
    policy_kwargs=dict(
        features_extractor_class=SimplerCNN,
        features_extractor_kwargs=dict(features_dim=64),
        normalize_images=False, # your env already outputs float32 [0,1]
    ),
    # Seed the algorithm's own RNGs (action sampling, minibatch shuffling,
    # weight init) so a Restart & Run All reproduces the same training run.
    seed=PPO_SEED,
    verbose=1,
    device="cuda",              # send to GPU if available
)

# train for PPO_TOTAL_TIMESTEPS environment steps
model.learn(total_timesteps=PPO_TOTAL_TIMESTEPS)

# save it
model.save("ppo_coverage")


Using cuda device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
---------------------------------
| rollout/           |          |
|    ep_len_mean     | 95       |
|    ep_rew_mean     | -41      |
| time/              |          |
|    fps             | 631      |
|    iterations      | 1        |
|    time_elapsed    | 0        |
|    total_timesteps | 256      |
---------------------------------
------------------------------------------
| rollout/                |              |
|    ep_len_mean          | 165          |
|    ep_rew_mean          | -137         |
| time/                   |              |
|    fps                  | 337          |
|    iterations           | 2            |
|    time_elapsed         | 1            |
|    total_timesteps      | 512          |
| train/                  |              |
|    approx_kl            | 0.0042865165 |
|    clip_fraction        | 0            |
|    clip_range           | 0.2          |
|    e

KeyboardInterrupt: 

In [23]:
# Score the current model with greedy (deterministic) actions over 20 episodes
# and report the mean and standard deviation of the episode returns.
eval_stats = evaluate_policy(model, env, n_eval_episodes=20, deterministic=True)
mean_reward, std_reward = eval_stats
print(f"Mean reward: {mean_reward:.2f} ± {std_reward:.2f}")


Mean reward: -194.00 ± 0.00


In [17]:
# Roll out one seeded episode with greedy actions, rendering the grid after each step.
obs, _ = env.reset(seed=42)

for i in range(env.max_steps):
    # model.predict returns the action as a NumPy array. .item() extracts the
    # single element whether the array is 0-d or shape (1,), instead of relying
    # on int(ndarray), which NumPy deprecates for arrays with ndim > 0.
    action_arr, _ = model.predict(obs, deterministic=True)
    action = int(action_arr.item())
    print("step:", i, "action:", action)

    obs, reward, terminated, truncated, info = env.step(action)
    env.render()
    print("\n")

    # Stop as soon as the environment reports the episode is over.
    if terminated or truncated:
        break


step: 0 action: 0
........
........
..TTT.A#
..TTT...
..TT###.
....##..
.###....
........


step: 1 action: 0
........
........
..TTT..#
..TTT.A.
..TT###.
....##..
.###....
........


step: 2 action: 0
........
........
..TTT..#
..TTT.A.
..TT###.
....##..
.###....
........


step: 3 action: 0
........
........
..TTT..#
..TTT.A.
..TT###.
....##..
.###....
........


step: 4 action: 0
........
........
..TTT..#
..TTT.A.
..TT###.
....##..
.###....
........


step: 5 action: 0
........
........
..TTT..#
..TTT.A.
..TT###.
....##..
.###....
........


step: 6 action: 0
........
........
..TTT..#
..TTT.A.
..TT###.
....##..
.###....
........


step: 7 action: 0
........
........
..TTT..#
..TTT.A.
..TT###.
....##..
.###....
........


step: 8 action: 0
........
........
..TTT..#
..TTT.A.
..TT###.
....##..
.###....
........


step: 9 action: 0
........
........
..TTT..#
..TTT.A.
..TT###.
....##..
.###....
........


step: 10 action: 0
........
........
..TTT..#
..TTT.A.
..TT###.
....##..
.###...