In [None]:
from __future__ import annotations

import numpy as np

%env SPDLOG_LEVEL=trace


from matplotlib import pyplot as plt
from PIL import Image, ImageOps
from rl.utils import build_graphs

from mcerl.env import Env
from mcerl.utils import (
    delta_time_reward_standardize,
    exploration_reward_rescale,
    reward_sum,
    single_env_rollout,
)

%matplotlib inline

In [None]:
# Load the map image, reduce it to a single grayscale channel and resize it.
# Pillow's resize takes (width, height), so the resulting array is 200 rows x 300 columns.
img = ImageOps.grayscale(Image.open("0.png")).resize((300, 200))

# Occupancy-style grid as a NumPy array; this is what the environment rollout consumes.
grid_map = np.array(img)

# Fixed 0-255 range so the displayed brightness is not rescaled to the data range.
plt.imshow(grid_map, cmap="gray", vmin=0, vmax=255)

In [3]:
# --- Exploration environment configuration ---------------------------------
# Team setup
num_agents = 3
agent_poses = None  # not passed to Env(...) in this cell

# Step budgets
max_steps = 100_000
max_steps_per_agent = 100

# Motion and sensing (ray_range is handed to Env as `sensor_range`)
velocity = 1
ray_range = 30
num_rays = 32

# Frontier size bounds (forwarded as min/max_frontier_pixel) and the
# exploration threshold
min_frontier_size = 8
max_frontier_size = 30
exploration_threshold = 0.98

# Collect the keyword arguments in one place so the Env call stays a one-liner.
env_config = {
    "num_agents": num_agents,
    "max_steps": max_steps,
    "max_steps_per_agent": max_steps_per_agent,
    "velocity": velocity,
    "sensor_range": ray_range,
    "num_rays": num_rays,
    "min_frontier_pixel": min_frontier_size,
    "max_frontier_pixel": max_frontier_size,
    "exploration_threshold": exploration_threshold,
}
env = Env(**env_config)
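# Example usage (disabled reference code; nothing below runs in this cell):
# a multi-threaded rollout, followed by a manual single-environment rollout loop.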
# num_threads = 15
# epochs = 10
# rollouts = multi_threaded_rollout(
#     env=lambda: Env(
#         num_agents=num_agents,
#         max_steps=max_steps,
#         max_steps_per_agent=max_steps_per_agent,
#         velocity=velocity,
#         sensor_range=ray_range,
#         num_rays=num_rays,
#         min_frontier_pixel=min_frontier_size,
#         max_frontier_pixel=max_frontier_size,
#     ),
#     grid_map=grid_map,
#     agent_poses=agent_poses,
#     policy=policy,
#     num_threads=num_threads,
#     epochs=epochs,
# )
# trajectories = []
# frame_data = env.reset(grid_map, agent_poses, return_maps=True)
# trajectories.append(frame_data)
# while True:
#     agent_id = frame_data["info"]["agent_id"]
#     frame_data["action_agent_id"] = agent_id
#     frame_data = random_policy(frame_data)
#     frame_data = env.step(frame_data, return_maps=True)
#     trajectories.append(frame_data)
#     if env.done() is True:
#         break
# rollouts = split_trajectories(trajectories)
# rollouts = [pad_trajectory(rollout) for rollout in rollouts]
# rollouts = [refine_trajectory(rollout) for rollout in rollouts]

In [None]:
import torch
from rl.actor_critic import GINPolicyNetwork
from torch.nn import functional as F

# Policy network; the rollout cell passes it as `policy=network`.
network = GINPolicyNetwork(dim_h=32)

# Optional smoke test: one forward pass on the first frame of an existing
# rollout, followed by sampling an action and taking its log-probability.
# `rollout` is only assigned in a later cell, so on a fresh kernel
# (Restart & Run All) there is nothing to test yet. Skip in that case instead
# of raising NameError and halting the run before any rollout is produced.
if "rollout" in globals():
    graph = rollout[0]["observation"]["graph"]
    pred = network(graph.x, graph.edge_index, graph.batch)
    # Softmax over dim 0 — presumably one score per candidate node; TODO confirm.
    probabilities = F.softmax(pred, dim=0)
    # Sample a single index from the (squeezed) probability vector.
    action_index = torch.multinomial(probabilities.squeeze(), 1)
    log_prob = torch.log(probabilities[action_index])
else:
    print("Skipping policy smoke test: no `rollout` available yet.")

In [4]:
# Run the environment on the loaded grid map with the network as the policy.
# The result is indexed as a sequence of rollouts, each an iterable of
# frame dictionaries (see the cells below).
rollouts = single_env_rollout(
    env,
    grid_map,
    policy=network,
)

In [5]:
# Convert the global map of every frame in the first rollout into a PIL image.
first_rollout = rollouts[0]
imgs = [Image.fromarray(step["observation"]["global_map"]) for step in first_rollout]

# Write the frames as an animated GIF: the first image is the base frame and
# the remaining ones are appended. Per Pillow's GIF writer, `duration` is the
# per-frame display time in milliseconds and `loop=0` repeats forever.
imgs[0].save(
    "array0.gif",
    save_all=True,
    append_images=imgs[1:],
    duration=50,
    loop=0,
)

In [6]:
# Upper bound used as `max_value` when rescaling exploration rewards.
# NOTE(review): 1.41 presumably approximates sqrt(2) — confirm the intended geometry.
max_step_exploration_reward = (ray_range * 2) * 1.41

In [7]:
# Post-processing stages, applied in order. Each stage is run over ALL rollouts
# before the next stage starts.
# NOTE(review): build_graphs appears both first and last — confirm the second
# pass is intentional.
# NOTE(review): this cell reassigns `rollouts` from itself, so re-running it
# re-applies every stage to already processed rollouts.
reward_pipeline = (
    build_graphs,
    lambda rollout: exploration_reward_rescale(
        rollout, max_value=max_step_exploration_reward
    ),
    delta_time_reward_standardize,
    reward_sum,
    build_graphs,
)
for stage in reward_pipeline:
    rollouts = [stage(rollout) for rollout in rollouts]

In [None]:
# Inspect the rollout at index 2 and display the "reward_to_go" entry of each
# of its frames (the bare list is the cell output).
rollout = rollouts[2]
[step["reward_to_go"] for step in rollout]