In [None]:
from __future__ import annotations

import numpy as np

%env SPDLOG_LEVEL=trace
import mcerl
from mcerl.env import Env
from mcerl.utils import (
    multi_threaded_rollout,
    pad_trajectory,
    refine_trajectory,
    split_trajectories,
    stack_trajectory,
)


In [None]:
# Synthetic test map: start with every cell set to 255, then carve two
# overlapping rectangular blocks of 0s.
# NOTE(review): assumes mcerl treats 255 as free space and 0 as an obstacle —
# confirm against mcerl.GridMap.
map_shape = (100, 100)
grid_map = np.full(map_shape, 255, dtype=np.uint8)
grid_map[20:70, 20:30] = 0
grid_map[30:40, 10:50] = 0
test_grid_map = mcerl.GridMap(grid_map)

# One starting pose per agent. num_agents is derived from this list so the two
# can never disagree (previously both were hard-coded independently).
# NOTE(review): coordinate order of each pose (row/col vs x/y) is not visible
# here — confirm against mcerl.
agent_poses = [(10, 10), (60, 60), (70, 70), (80, 80)]
num_agents = len(agent_poses)

num_rays = 32
max_steps = 1000
max_steps_per_agent = 40
ray_range = 20
velocity = 1
min_frontier_size = 5
max_frontier_size = 20

# Environment used by the single-threaded rollout cell further down.
env = Env(
    num_agents=num_agents,
    max_steps=max_steps,
    max_steps_per_agent=max_steps_per_agent,
    velocity=velocity,
    sensor_range=ray_range,
    num_rays=num_rays,
    min_frontier_pixel=min_frontier_size,
    max_frontier_pixel=max_frontier_size,
)

In [None]:
def policy(observation, rng: np.random.Generator | None = None) -> int:
    """Pick a uniformly random frontier index from ``observation``.

    Parameters
    ----------
    observation : mapping
        Must contain a ``"frontier_points"`` entry that supports ``len()``;
        each index into it is treated as one candidate action.
    rng : numpy.random.Generator, optional
        Source of randomness. Pass a seeded generator to make a rollout
        reproducible. When omitted, a fresh unseeded generator is created on
        each call, matching the previous behaviour. Note that a single
        Generator shared across threads is not thread-safe.

    Returns
    -------
    int
        An index in ``range(len(observation["frontier_points"]))``, or ``0``
        when there are no frontier points.
    """
    num_actions = len(observation["frontier_points"])
    if num_actions == 0:
        # NOTE(review): 0 is returned even though there is no frontier to
        # index; presumably the env treats it as a fallback — confirm.
        return 0
    if rng is None:
        rng = np.random.default_rng()
    # int() converts the numpy integer scalar to a plain Python int.
    return int(rng.integers(num_actions))

In [None]:
# Parallel rollout collection with the random `policy`.
# multi_threaded_rollout receives an Env *factory* rather than an instance —
# presumably so each worker can build its own Env; confirm against mcerl.utils.
# NOTE(review): exact meaning of `epochs` is defined by multi_threaded_rollout.
num_threads = 15
epochs = 10


def _make_rollout_env():
    """Return a new Env configured like the setup cell's `env`."""
    return Env(
        num_agents=num_agents,
        max_steps=max_steps,
        max_steps_per_agent=max_steps_per_agent,
        velocity=velocity,
        sensor_range=ray_range,
        num_rays=num_rays,
        min_frontier_pixel=min_frontier_size,
        max_frontier_pixel=max_frontier_size,
    )


rollouts = multi_threaded_rollout(
    env=_make_rollout_env,
    grid_map=grid_map,
    agent_poses=agent_poses,
    policy=policy,
    num_threads=num_threads,
    epochs=epochs,
)

In [None]:
# Single-threaded rollout using the setup cell's `env`, followed by the
# split -> pad -> refine -> stack post-processing pipeline.
trajectories = []
frame_data = env.reset(grid_map, agent_poses)
maps = []  # NOTE(review): never populated in this cell; kept in case later cells read it.
trajectories.append(frame_data)
while True:
    agent_id = frame_data["info"]["agent_id"]
    action_index = policy(frame_data["observation"])
    # Record the chosen action on the frame it was chosen from. That dict is
    # already stored in `trajectories`, so the stored entry is updated in place.
    frame_data["action"] = action_index
    frame_data = env.step(agent_id, action_index)
    trajectories.append(frame_data)
    # Truthiness rather than `is True`: a numpy/bound boolean is never
    # identical to the `True` singleton, which would make this loop endless.
    if env.done():
        break
# NOTE(review): split_trajectories presumably separates frames per agent —
# confirm against mcerl.utils.
rollouts = split_trajectories(trajectories)
rollouts = [pad_trajectory(rollout) for rollout in rollouts]
rollouts = [refine_trajectory(rollout) for rollout in rollouts]
stacked_rollouts = [stack_trajectory(rollout) for rollout in rollouts]