In [None]:
# imports
from __future__ import annotations

from typing import Any

import numpy as np
from matplotlib import pyplot as plt
from PIL import Image, ImageOps
from rl.actor_critic import Actor
from rl.network import GINNetwork
from rl.utils import to_graph

from mcerl.env import Env
from mcerl.utils import (
    delta_time_reward_standardize,
    exploration_reward_rescale,
    multi_threaded_rollout,
    reward_sum,
    single_env_rollout,  # noqa: F401
)

%env SPDLOG_LEVEL=warning
%matplotlib inline

In [None]:
# load map: read the image, convert it to grayscale and resize it to 300x200
# (PIL sizes are given as (width, height))
img = ImageOps.grayscale(Image.open("0.png")).resize((300, 200))
grid_map = np.array(img)
# binarize in place: pixels below 100 become 0, every other pixel becomes 255
grid_map[...] = np.where(grid_map < 100, 0, 255)
plt.imshow(grid_map, cmap="gray", vmin=0, vmax=255)

In [None]:
# define parameters
from math import pi

# environment settings; most of them are forwarded to Env(...) when it is created
num_agents = 3
agent_poses = None  # only referenced by the commented-out rollout examples
num_rays = 32
max_steps = 100000
max_steps_per_agent = 100
ray_range = 30  # passed to Env as sensor_range
velocity = 1
min_frontier_size = 8  # passed to Env as min_frontier_pixel
max_frontier_size = 30  # passed to Env as max_frontier_pixel
exploration_threshold = 0.98
# the array shape is unpacked as (height, width)
map_height, map_width = grid_map.shape
# NOTE(review): 1.41 presumably approximates sqrt(2); this value is only displayed
# at the end of this cell and not read by any other visible cell -- confirm it is needed
max_step_exploration_reward = ray_range * 2 * 1.41
# half the area of a circle of radius ray_range; used to scale frontier gains in
# forward_preprocess and passed as max_value to exploration_reward_rescale
max_exploration_gain = ray_range**2 * pi / 2.0
# cell output: show both derived constants
max_step_exploration_reward, max_exploration_gain

In [4]:
# create environment
# Single Env instance that is handed to single_env_rollout; every argument comes
# from the variables set in the parameters cell.
env = Env(
    num_agents=num_agents,
    max_steps=max_steps,
    max_steps_per_agent=max_steps_per_agent,
    velocity=velocity,
    sensor_range=ray_range,  # named ray_range in this notebook
    num_rays=num_rays,
    min_frontier_pixel=min_frontier_size,  # named min_frontier_size in this notebook
    max_frontier_pixel=max_frontier_size,  # named max_frontier_size in this notebook
    exploration_threshold=exploration_threshold,
)
# Example usage
# num_threads = 15
# epochs = 10
# rollouts = multi_threaded_rollout(
#     env=lambda: Env(
#         num_agents=num_agents,
#         max_steps=max_steps,
#         max_steps_per_agent=max_steps_per_agent,
#         velocity=velocity,
#         sensor_range=ray_range,
#         num_rays=num_rays,
#         min_frontier_pixel=min_frontier_size,
#         max_frontier_pixel=max_frontier_size,
#     ),
#     grid_map=grid_map,
#     agent_poses=agent_poses,
#     policy=policy,
#     num_threads=num_threads,
#     epochs=epochs,
# )
# trajectories = []
# frame_data = env.reset(grid_map, agent_poses, return_maps=True)
# trajectories.append(frame_data)
# while True:
#     agent_id = frame_data["info"]["agent_id"]
#     frame_data["action_agent_id"] = agent_id
#     frame_data = random_policy(frame_data)
#     frame_data = env.step(frame_data, return_maps=True)
#     trajectories.append(frame_data)
#     if env.done() is True:
#         break
# rollouts = split_trajectories(trajectories)
# rollouts = [pad_trajectory(rollout) for rollout in rollouts]
# rollouts = [refine_trajectory(rollout) for rollout in rollouts]

In [None]:
# define policy


from rl.actor_critic import ActorCritic, Critic


def forward_preprocess(frame_data: dict[str, Any]) -> dict[str, Any]:
    """
    Rescale the observation in ``frame_data`` and convert the frame to a graph.

    x coordinates are divided by the map width and y coordinates by the map
    height for frontier points, agent positions and target positions; each
    frontier gain is divided by ``max_exploration_gain``.  The ``"observation"``
    entries are overwritten in place, so calling this twice on the same frame
    rescales it twice.

    Args:
        frame_data: frame whose ``"observation"`` holds ``"frontier_points"``
            (x, y, gain triples) plus ``"pos"`` and ``"target_pos"`` (x, y pairs).

    Returns:
        The result of ``to_graph`` applied to the rescaled frame.
    """
    w, h = float(map_width), float(map_height)
    obs = frame_data["observation"]

    # frontier points carry a third component: the exploration gain
    obs["frontier_points"] = [
        (float(fx) / w, float(fy) / h, float(gain) / max_exploration_gain)
        for fx, fy, gain in obs["frontier_points"]
    ]

    # agent positions and target positions are plain (x, y) pairs
    for key in ("pos", "target_pos"):
        obs[key] = [(float(px) / w, float(py) / h) for px, py in obs[key]]

    # build graph
    return to_graph(frame_data)


# two separate GINNetwork instances: one is given to the actor, the other to the critic
policy_network, value_network = GINNetwork(dim_h=32), GINNetwork(dim_h=32)

actor, critic = (
    Actor(policy_network=policy_network),
    Critic(value_network=value_network),
)
# forward_preprocess is handed to ActorCritic together with the actor and the critic
wrapped_actor_critic = ActorCritic(
    forward_preprocess=forward_preprocess,
    actor=actor,
    critic=critic,
)

In [6]:
# get data
# single-threaded rollout; both names are bound to the same result object
rollouts = single_rollouts = single_env_rollout(
    env, grid_map, policy=wrapped_actor_critic
)

# multi-threaded rollout
# num_threads = 20
# epochs = 100
# rollouts = multi_threaded_rollout(
#     env=lambda: Env(
#         num_agents=num_agents,
#         max_steps=max_steps,
#         max_steps_per_agent=max_steps_per_agent,
#         velocity=velocity,
#         sensor_range=ray_range,
#         num_rays=num_rays,
#         min_frontier_pixel=min_frontier_size,
#         max_frontier_pixel=max_frontier_size,
#     ),
#     grid_map=grid_map,
#     agent_poses=agent_poses,
#     policy=actor,
#     num_threads=num_threads,
#     epochs=epochs,
#     return_maps=False,
# )

In [None]:
# draw raw rewards for debugging
# one figure per reward component, one line per rollout
for heading, reward_key in (
    ("exploration reward:", "exploration_reward"),
    ("time_step_reward :", "time_step_reward"),
):
    print(heading)
    for rollout in rollouts:
        plt.plot([frame["next"]["reward"][reward_key] for frame in rollout])
    plt.show()

In [48]:
# normalize rewards
# Three full passes over all rollouts, in this order: exploration_reward_rescale,
# delta_time_reward_standardize, reward_sum.
# Note: `rollouts` is rebound to the processed result, so re-running this cell
# applies the three functions again to already-processed rollouts.
rollouts = list(
    map(
        lambda rollout: exploration_reward_rescale(
            rollout, max_value=max_exploration_gain
        ),
        rollouts,
    )
)
rollouts = list(map(delta_time_reward_standardize, rollouts))
rollouts = list(map(reward_sum, rollouts))

In [None]:
# draw normalized rewards for debugging


def _plot_per_rollout(curves: list[list[Any]], label: str) -> None:
    """
    Plot one line per rollout and mark the mean over all plotted values.

    Args:
        curves: one list of per-frame values for each rollout.
        label: text placed before the mean in the figure title.
    """
    for curve in curves:
        plt.plot(curve)
    # Flatten rollout-first so every frame of every rollout is counted exactly once.
    # The earlier `for frame_data in rollout for rollout in rollouts` ordering read
    # `rollout` from the leftover loop variable, so it averaged only the last rollout.
    mean_val = np.mean([value for curve in curves for value in curve])
    plt.title(f"{label}:{mean_val}")
    plt.axhline(mean_val, color="r", linestyle="--")
    plt.show()


# total_reward, exploration_reward, time_step_reward
for reward_key, label in (
    ("total_reward", "total reward"),
    ("exploration_reward", "exploration reward"),
    ("time_step_reward", "time step reward"),
):
    _plot_per_rollout(
        [
            [frame_data["next"]["reward"][reward_key] for frame_data in rollout]
            for rollout in rollouts
        ],
        label,
    )

# reward_to_go
_plot_per_rollout(
    [[frame_data["reward_to_go"] for frame_data in rollout] for rollout in rollouts],
    "reward to go",
)

# reward_mean: mean reward_to_go of each rollout, plus the mean of those means
reward_mean = [
    np.mean([frame_data["reward_to_go"] for frame_data in rollout])
    for rollout in rollouts
]
overall_mean = np.mean(reward_mean)
plt.plot(reward_mean)
plt.title(f"reward mean:{overall_mean}")
plt.axhline(y=overall_mean, color="r", linestyle="--")
plt.show()

In [10]:
# store maps to gif
# rollouts=single_rollouts
# for rollout in rollouts:
#     agent_id = rollout[0]["info"]["agent_id"]
#     imgs = [
#         Image.fromarray(frame_data["observation"]["agent_map"])
#         for frame_data in rollout
#     ]
#     imgs[0].save(
#         f"agent_{agent_id}.gif",
#         save_all=True,
#         append_images=imgs[1:],
#         duration=100,
#         loop=0,
#     )
# flattened_rollouts = []
# for rollout in rollouts:
#     flattened_rollouts.extend(rollout)
# sorted_rollouts = sorted(flattened_rollouts, key=lambda x: x["info"]["step_cnt"])
# imgs = [
#     Image.fromarray(frame_data["observation"]["global_map"])
#     for frame_data in sorted_rollouts
# ]
# imgs[0].save(
#     "global_map.gif",
#     save_all=True,
#     append_images=imgs[1:],
#     duration=100,
#     loop=0,
# )

In [None]:
def check_dict_struct(d: dict[str, Any], prefix: str = "") -> list:
    """
    List the dotted paths of all non-dict values in a nested dict.

    Nested dicts are walked depth-first in insertion order; each level adds
    its key and a "." to the path.  A dict value contributes only the paths
    found inside it, so an empty nested dict contributes nothing.

    Args:
        d: (possibly nested) dict with string keys.
        prefix: text prepended to every path produced at this level.

    Returns:
        The dotted paths, in traversal order.
    """
    paths: list = []
    for key, value in d.items():
        path = prefix + key
        if not isinstance(value, dict):
            paths.append(path)
            continue
        # descend, extending the prefix with this key
        paths += check_dict_struct(value, path + ".")
    return paths

# Distinct leaf-key counts over every frame of every rollout; a single value in
# the set means all frames have the same number of leaf keys.
# The rollout loop must come first: with the clauses reversed, `rollout` was read
# from a leftover global, so only that one rollout was inspected (and a fresh
# kernel without that global raised NameError).
{len(check_dict_struct(frame_data)) for rollout in rollouts for frame_data in rollout}

In [7]:
import tensordict

# keep only the first rollout; the TensorDict conversion cell reads `rollout`
rollout = rollouts[0]


In [56]:
import torch
import torch_geometric
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader

# Copy the graph tensors of each frame into its observation and into its "next"
# observation as plain entries.
# NOTE(review): the "next" observation receives the *current* frame's graph, not the
# graph of the following frame -- confirm this is intended.
for frame in rollout:
    graph = frame["observation"]["graph"]
    graph_fields = {
        "x": graph.x,
        "edge_index": graph.edge_index,
        "batch_index": graph.batch,
    }
    frame["observation"].update(graph_fields)
    frame["next"]["observation"].update(graph_fields)

# one TensorDict per frame, combined with maybe_dense_stack
tds = tensordict.LazyStackedTensorDict.maybe_dense_stack(
    list(map(tensordict.TensorDict.from_dict, rollout))
)
# replace the total reward with zeros of the same shape and dtype
tds["next", "reward", "total_reward"] = torch.zeros_like(
    tds["next", "reward", "total_reward"]
)

TensorDict(
    fields={
        action: Tensor(shape=torch.Size([35, 1]), device=cpu, dtype=torch.int64, is_shared=False),
        done: Tensor(shape=torch.Size([35]), device=cpu, dtype=torch.bool, is_shared=False),
        info: TensorDict(
            fields={
                agent_exploration_rate: Tensor(shape=torch.Size([35]), device=cpu, dtype=torch.float32, is_shared=False),
                agent_explored_pixels: Tensor(shape=torch.Size([35]), device=cpu, dtype=torch.int64, is_shared=False),
                agent_id: Tensor(shape=torch.Size([35]), device=cpu, dtype=torch.int64, is_shared=False),
                agent_step_cnt: Tensor(shape=torch.Size([35]), device=cpu, dtype=torch.int64, is_shared=False),
                delta_time: Tensor(shape=torch.Size([35]), device=cpu, dtype=torch.int64, is_shared=False),
                global_exploration_rate: Tensor(shape=torch.Size([35]), device=cpu, dtype=torch.float32, is_shared=False),
                step_cnt: Tensor(shape=torch.S