In [None]:
from __future__ import annotations

import numpy as np

%env SPDLOG_LEVEL=trace
from matplotlib import pyplot as plt
from PIL import Image, ImageOps

import mcerl
from mcerl.env import Env
from mcerl.utils import (
    pad_trajectory,
    refine_trajectory,
    split_trajectories,
    stack_trajectory,
)


In [None]:
# Load the map image, convert it to 8-bit grayscale and downsample it.
# The context manager closes the underlying file as soon as the pixel data has
# been converted, instead of leaving the handle to be reclaimed later.
with Image.open("0.png") as source_image:
    img = ImageOps.grayscale(source_image)
# PIL sizes are (width, height): the array below has 200 rows and 300 columns.
img = img.resize((300, 200))
grid_map = np.array(img)
# Fix the display range to the full 8-bit scale so the rendering does not
# depend on the min/max values that happen to occur in this particular map.
plt.imshow(grid_map, cmap="gray", vmin=0, vmax=255)

In [None]:
# --- Team / episode settings -------------------------------------------------
num_agents = 3
agent_poses = None  # handed unchanged to env.reset() further down
max_steps = 100000
max_steps_per_agent = 100
velocity = 1

# --- Sensor settings ---------------------------------------------------------
num_rays = 32
ray_range = 30  # forwarded to Env as `sensor_range`

# --- Frontier / termination settings -----------------------------------------
min_frontier_size = 8  # forwarded to Env as `min_frontier_pixel`
max_frontier_size = 30  # forwarded to Env as `max_frontier_pixel`
exploration_threshold = 0.95

# Build the environment from the settings above (keyword-only call, so the
# argument order here is purely for readability).
env = Env(
    num_agents=num_agents,
    velocity=velocity,
    max_steps=max_steps,
    max_steps_per_agent=max_steps_per_agent,
    num_rays=num_rays,
    sensor_range=ray_range,
    min_frontier_pixel=min_frontier_size,
    max_frontier_pixel=max_frontier_size,
    exploration_threshold=exploration_threshold,
)

In [None]:
# Probe the binding's `test_xy_cv_mat` helper at (0, 0), (1, 0) and (0, 1).
# A fresh GridMap wrapper is built for every probe; the resulting 3-tuple is
# the cell's displayed output.
tuple(
    env._env.test_xy_cv_mat(mcerl.GridMap(grid_map), probe_point)
    for probe_point in ((0, 0), (1, 0), (0, 1))
)

In [None]:
# Probe the binding's `test_xy_coord` helper at (0, 0), (1, 0) and (0, 1).
# A fresh GridMap wrapper is built for every probe; the resulting 3-tuple is
# the cell's displayed output.
tuple(
    env._env.test_xy_coord(mcerl.GridMap(grid_map), probe_point)
    for probe_point in ((0, 0), (1, 0), (0, 1))
)

In [None]:
# One generator shared by all policy() calls, so a new Generator is not
# constructed and seeded from OS entropy on every single step.
_POLICY_RNG = np.random.default_rng()


def policy(observation, rng=None):
    """Pick a uniformly random frontier index for the given observation.

    Parameters
    ----------
    observation : mapping
        Must provide ``"frontier_points"``; only its length is used.
    rng : numpy.random.Generator, optional
        Generator to draw from. Pass a seeded generator to make a rollout
        reproducible. When omitted, a shared unseeded generator is used.

    Returns
    -------
    int
        A plain Python ``int`` in ``[0, len(frontier_points))``, or ``0`` when
        there are no frontier points to choose from.
    """
    num_choices = len(observation["frontier_points"])
    if num_choices == 0:
        # Nothing to choose from: fall back to index 0.
        return 0
    generator = _POLICY_RNG if rng is None else rng
    # .item() converts the NumPy integer into a built-in int.
    return generator.integers(num_choices).item()  # type: ignore  # noqa: PGH003

In [None]:
# Example usage (kept for reference, not executed): `multi_threaded_rollout` is not imported in this notebook.
# num_threads = 15
# epochs = 10
# rollouts = multi_threaded_rollout(
#     env=lambda: Env(
#         num_agents=num_agents,
#         max_steps=max_steps,
#         max_steps_per_agent=max_steps_per_agent,
#         velocity=velocity,
#         sensor_range=ray_range,
#         num_rays=num_rays,
#         min_frontier_pixel=min_frontier_size,
#         max_frontier_pixel=max_frontier_size,
#     ),
#     grid_map=grid_map,
#     agent_poses=agent_poses,
#     policy=policy,
#     num_threads=num_threads,
#     epochs=epochs,
# )

In [None]:
# Select matplotlib's inline backend so figures are rendered in the cell output.
%matplotlib inline

In [None]:
# Roll out one episode with the random policy. Every frame returned by the
# environment is recorded, together with a snapshot of the global map and of
# each agent's map after every step (used for the GIF export).
trajectories = []
maps = []
maps_1 = []
maps_2 = []
maps_3 = []

frame_data = env.reset(grid_map, agent_poses)
trajectories.append(frame_data)
while True:
    agent_id = frame_data["info"]["agent_id"]
    action_index = policy(frame_data["observation"])
    # Attach the chosen action to the frame it was computed from. That frame
    # object is already stored in `trajectories`, so this write updates the
    # recorded entry as well.
    frame_data["action"] = action_index
    frame_data = env.step(agent_id, action_index)
    maps.append(env.global_map())
    # Agent indices are hard-coded to match the three agents configured above.
    maps_1.append(env.agent_map(0))
    maps_2.append(env.agent_map(1))
    maps_3.append(env.agent_map(2))
    trajectories.append(frame_data)
    # Truthiness check rather than `is True`: an identity comparison is never
    # satisfied by a truthy non-`bool` value (e.g. numpy.bool_), which would
    # keep this loop running forever.
    if env.done():
        break

# Post-process the flat frame list: split it into separate rollouts, then pad,
# refine and stack each one using the helpers from mcerl.utils.
rollouts = split_trajectories(trajectories)
rollouts = [pad_trajectory(rollout) for rollout in rollouts]
rollouts = [refine_trajectory(rollout) for rollout in rollouts]
stacked_rollouts = [stack_trajectory(rollout) for rollout in rollouts]

In [None]:
# Export each recorded map sequence as an endlessly looping GIF
# (duration=50 per frame): array0.gif holds the global map, array1-3.gif the
# three per-agent maps.
for gif_path, frame_arrays in (
    ("array0.gif", maps),
    ("array1.gif", maps_1),
    ("array2.gif", maps_2),
    ("array3.gif", maps_3),
):
    imgs = [Image.fromarray(frame) for frame in frame_arrays]
    # The first frame owns the file; the remaining frames are appended to it.
    imgs[0].save(
        gif_path,
        save_all=True,
        append_images=imgs[1:],
        duration=50,
        loop=0,
    )