In [14]:
import numpy as np
import gymnasium as gym
import random
from gymnasium import spaces
from collections import deque

In [15]:
def bfs_reachable(grid, start, targets):
    """
    Check if all target cells are reachable from start on grid (0=free, 1=obstacle).
    """
    H, W = grid.shape
    visited = np.zeros_like(grid, dtype=bool)
    queue = deque([start])
    visited[start] = True
    reached = set()
    while queue:
        i, j = queue.popleft()
        if (i, j) in targets:
            reached.add((i, j))
            if reached == set(targets):
                return True
        for di, dj in ((1,0),(-1,0),(0,1),(0,-1)):
            ni, nj = i+di, j+dj
            if 0 <= ni < H and 0 <= nj < W and not visited[ni, nj] and grid[ni, nj] == 0:
                visited[ni, nj] = True
                queue.append((ni, nj))
    return False


In [16]:
class CoverageEnv(gym.Env):
    metadata = {"render.modes": ["human"]}

    def __init__(self, curriculum_max=3, max_steps=200, seed=None):
        super().__init__()
        self.H, self.W = 8, 8
        self.curriculum_level = 0
        self.curriculum_max = curriculum_max
        self.max_steps = max_steps

        # seeding for reproducibility
        self.seed(seed)

        # Action: down, up, right, left
        self.action_space = spaces.Discrete(4)
        # Observation: flattened grid + agent pos + coverage mask
        obs_size = self.H * self.W + 2 + self.H * self.W
        self.observation_space = spaces.Box(0.0, 1.0, shape=(obs_size,), dtype=np.float32)

        # Define a fixed shape library
        self.shape_library = [
            np.array([[1]]),                # single cell
            np.ones((1,3), dtype=int),     # horizontal bar
            np.ones((2,2), dtype=int),     # 2x2 block
            np.array([[1,1,1],             # U-shape
                      [1,0,1],
                      [1,1,1]]),
            np.array([[1,1,0],             # L-shape
                      [1,0,0]])
        ]

    def seed(self, seed=None):
        """
        Seed the environment's RNGs for reproducible layouts.
        """
        np.random.seed(seed)
        random.seed(seed)
        return [seed]

    def reset(self, *, seed=None, options=None):
        """
        Reset the environment; returns obs, info
        """
        if seed is not None:
            self.seed(seed)

        self.curriculum_level = min(self.curriculum_max, self.curriculum_level + 1)
        while True:
            grid = np.zeros((self.H, self.W), dtype=int)
            allowed = self.shape_library[: self.curriculum_level + 1]
            num_shapes = np.random.randint(1, self.curriculum_level * 2 + 1)
            placed = np.zeros_like(grid)
            for _ in range(num_shapes):
                shape = random.choice(allowed)
                sh, sw = shape.shape
                i = np.random.randint(0, self.H - sh + 1)
                j = np.random.randint(0, self.W - sw + 1)
                if not np.any(placed[i:i+sh, j:j+sw] & shape):
                    placed[i:i+sh, j:j+sw] |= shape
            grid = placed

            ti = np.random.randint(0, self.H - 3 + 1)
            tj = np.random.randint(0, self.W - 3 + 1)
            full_block = [(ti+di, tj+dj) for di in range(3) for dj in range(3)]
            targets = [(i,j) for (i,j) in full_block if grid[i,j] == 0]
            if not targets:
                continue

            free_cells = list(zip(*np.where(grid == 0)))
            start = random.choice(free_cells)
            if bfs_reachable(grid, start, targets):
                break

        self.grid = grid
        self.targets = set(targets)
        self.agent_pos = start
        self.visited = set()
        self.steps = 0

        return self._get_obs(), {}

    def _get_obs(self):
        flat_grid = self.grid.flatten()
        ai, aj = self.agent_pos
        agent_vec = np.array([ai/(self.H-1), aj/(self.W-1)], dtype=np.float32)
        cover_mask = np.zeros_like(flat_grid)
        for (i,j) in self.targets:
            cover_mask[i*self.W + j] = 1
        return np.concatenate([flat_grid, agent_vec, cover_mask])

    def step(self, action):
        # Define movement vectors
        # 0 = down, 1 = up, 2 = right, 3 = left
        moves = {0: (1, 0), 1: (-1, 0), 2: (0, 1), 3: (0, -1)}
        i, j = self.agent_pos
        di, dj = moves[action]
        ni, nj = i + di, j + dj

        # Default baseline reward
        reward = 0

        # Check validity and apply penalties
        if not (0 <= ni < self.H and 0 <= nj < self.W and self.grid[ni, nj] == 0):
            # Invalid action: stay in place
            self.agent_pos = (i, j) 
        else:
            # Valid move: update position
            self.agent_pos = (ni, nj)

        # Check if on a target
        if self.agent_pos not in self.visited:
            if self.agent_pos in self.targets:
                reward = 2.0   # new target
            self.visited.add(self.agent_pos)

        # Step count
        self.steps += 1

        # Terminal bonus
        terminated = (self.visited == self.targets)
        truncated = (self.steps >= self.max_steps)
        if terminated:
            reward += 30.0

        return self._get_obs(), reward, terminated, truncated, {}

    def render(self, mode="human"):
        disp = np.full((self.H, self.W), '.', dtype=str)
        for (i,j) in self.targets:
            disp[i,j] = 'T'
        for (i,j) in zip(*np.where(self.grid == 1)):
            disp[i,j] = '#'
        ai, aj = self.agent_pos
        disp[ai, aj] = 'A'
        print("\n".join("".join(row) for row in disp))

    def close(self):
        pass


In [17]:
gym.register(
    id="Coverage-v0",
    entry_point="coverage_env:CoverageEnv",
    max_episode_steps=200,
)


  logger.warn(f"Overriding environment {new_spec.id} already in registry.")


In [18]:
from stable_baselines3 import DQN
from stable_baselines3.common.evaluation import evaluate_policy

# instantiate a single env (you can wrap VecEnv for parallelism later)
env = CoverageEnv(seed=42)

# create the DQN model
model = DQN(
    policy="MlpPolicy",   # a simple MLP
    env=env,
    learning_starts=1000,
    buffer_size=500_000,
    learning_rate=1e-3,
    batch_size=32,
    gamma=0.99,
    verbose=1,
)

# train for 50k timesteps
model.learn(total_timesteps=500_000)

# save it
model.save("dqn_coverage")


Using cpu device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 200      |
|    ep_rew_mean      | 13       |
|    exploration_rate | 0.985    |
| time/               |          |
|    episodes         | 4        |
|    fps              | 29053    |
|    time_elapsed     | 0        |
|    total_timesteps  | 800      |
----------------------------------
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 200      |
|    ep_rew_mean      | 11.8     |
|    exploration_rate | 0.97     |
| time/               |          |
|    episodes         | 8        |
|    fps              | 9550     |
|    time_elapsed     | 0        |
|    total_timesteps  | 1600     |
| train/              |          |
|    learning_rate    | 0.001    |
|    loss             | 0.185    |
|    n_updates        | 149      |
-------------------------------

In [19]:
# load (if needed)
# model = DQN.load("dqn_coverage", env=env)

mean_reward, std_reward = evaluate_policy(
    model, 
    env, 
    n_eval_episodes=20, 
    deterministic=True,
)
print(f"Mean reward: {mean_reward:.2f} ± {std_reward:.2f}")




Mean reward: 0.80 ± 1.83


In [20]:
obs, _ = env.reset(seed=42)

for i in range(env.max_steps):
    # model.predict returns e.g. array([2], dtype=int64)
    action_arr, _ = model.predict(obs, deterministic=True)
    action = int(action_arr)       # unwrap to Python int
    print("step:", i, "action:", action)

    obs, reward, terminated, truncated, info = env.step(action)
    env.render()
    print("\n")

    if terminated or truncated:
        break


step: 0 action: 0
........
........
..TTT.A#
..TTT...
..TT###.
....##..
.###....
........


step: 1 action: 0
........
........
..TTT..#
..TTT.A.
..TT###.
....##..
.###....
........


step: 2 action: 3
........
........
..TTT..#
..TTTA..
..TT###.
....##..
.###....
........


step: 3 action: 0
........
........
..TTT..#
..TTTA..
..TT###.
....##..
.###....
........


step: 4 action: 0
........
........
..TTT..#
..TTTA..
..TT###.
....##..
.###....
........


step: 5 action: 0
........
........
..TTT..#
..TTTA..
..TT###.
....##..
.###....
........


step: 6 action: 0
........
........
..TTT..#
..TTTA..
..TT###.
....##..
.###....
........


step: 7 action: 0
........
........
..TTT..#
..TTTA..
..TT###.
....##..
.###....
........


step: 8 action: 0
........
........
..TTT..#
..TTTA..
..TT###.
....##..
.###....
........


step: 9 action: 0
........
........
..TTT..#
..TTTA..
..TT###.
....##..
.###....
........


step: 10 action: 0
........
........
..TTT..#
..TTTA..
..TT###.
....##..
.###...