In [1]:
# imports
from __future__ import annotations

import datetime
import pathlib as pl
import random
import shutil
import string
import time
from math import pi
from typing import Any

import cv2
import numpy as np
import torch
from matplotlib import pyplot as plt
from PIL import Image, ImageOps
from rl.actor_critic import GAE, Actor, ActorCritic, Critic
from rl.network import GINPolicyNetwork, GINValueNetwork
from rl.utils import Sampler, to_graph
from torch.optim.adam import Adam
from torch.utils.tensorboard.writer import SummaryWriter

from mcerl.env import Env
from mcerl.utils import (
    delta_time_reward_standardize,
    exploration_reward_rescale,
    reward_sum,
)

%matplotlib inline

In [None]:
# load map
# Read the map image, convert it to a 200x200 grayscale array and binarize it:
# pixels below 100 become 0, all other pixels become 255.
map_path = "map/0.png"
img = ImageOps.grayscale(Image.open(map_path)).resize((200, 200))
grid_map = np.array(img)
grid_map = np.where(grid_map < 100, 0, 255).astype(grid_map.dtype)
plt.imshow(grid_map, cmap="gray", vmin=0, vmax=255)

In [None]:
# define parameters
#########################
# logging / output paths
#########################

# workspace
workspace_dir = pl.Path.cwd()
# output
output_dir = workspace_dir / "output"
# unique session name: timestamp plus a random 6-character suffix
current_date = datetime.datetime.now().strftime("%Y_%m_%d_%H-%M-%S")
random_string = "".join(random.choices(string.ascii_letters + string.digits, k=6))
session_name = f"{current_date}_{random_string}"
# experiment
experiment_dir = output_dir / session_name
log_dir = experiment_dir / "log"
model_dir = experiment_dir / "model"
output_images_dir = experiment_dir / "images"
tensorboard_dir = experiment_dir / "tensorboard"

for _directory in (
    experiment_dir,
    log_dir,
    model_dir,
    output_images_dir,
    tensorboard_dir,
):
    _directory.mkdir(parents=True, exist_ok=True)


#########################
# environment parameters
#########################

# Everything defined from here on is collected into `parameters` at the end of
# the cell by filtering out the names that already exist at this point.
# Names that a previous run of this cell recorded as hyperparameters are NOT
# excluded; otherwise re-running the cell would produce an empty `parameters`.
previously_recorded = set(locals().get("parameters", {}))
exclude_parameters = [
    name for name in list(locals()) if name not in previously_recorded
]

# log level
log_level = "warning"

# number of agents
num_agents = 3

# initial agent poses (None is passed through to the environment unchanged)
agent_poses = None

# number of sensor rays (forwarded to the environment as `num_rays`)
num_rays = 32

# maximum number of steps of one environment
max_steps = 10000

# maximum number of steps of one agent
# (this used to be assigned twice, 100 here and 50 further down; 50 was the
# value that actually took effect, so it is the single value kept)
max_steps_per_agent = 50

# sensor range
ray_range = 30

# velocity (pixel per step)
velocity = 1

# minimum number of frontier pixels
min_frontier_size = 8

# maximum number of frontier pixels
max_frontier_size = 30

# exploration threshold
exploration_threshold = 0.95

# map height and width
map_height, map_width = grid_map.shape

# maximum information gain a single frontier can yield
max_exploration_gain = ray_range**2 * pi / 2.0


#########################
# PPO parameters
#########################

# GAE lambda
lmbda = 0.95

# discount factor
gamma = 0.99

# clip range
clip_coefficient = 0.2

# maximum gradient norm
max_grad_norm = 0.5

# weight of the entropy term
entropy_weight = 0.001

# weight of the policy loss
policy_loss_weight = 1.0

# weight of the value loss
value_loss_weight = 0.5

#########################
# training parameters
#########################

# device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# total number of training iterations
n_iters = 1000

# number of frames per training iteration
n_frames_per_iter = 10000

# number of parallel environments (workers) per iteration
n_parallel_envs = 15

# number of epochs per iteration, i.e. how many passes over the collected data
n_epochs_per_iter = 5

# mini-batch size
mini_batch_size = n_frames_per_iter // 10

# total number of environments needed per iteration
n_envs_per_iter = round(n_frames_per_iter / num_agents / max_steps_per_agent)

# number of frames per epoch: one epoch is a single pass over the frames of
# the iteration (previously multiplied by n_epochs_per_iter, which gave the
# total over all epochs and contradicted the name)
n_frames_per_epoch = n_frames_per_iter

# number of mini-batches per epoch
n_minibatches_per_epoch = n_frames_per_epoch // mini_batch_size


parameters = {
    name: value
    for name, value in list(locals().items())
    if name not in {*exclude_parameters, "exclude_parameters"}
}
parameters

In [4]:
# transform function
# env_transform
def env_transform(frame_data: dict[str, Any]) -> dict[str, Any]:
    """
    Normalize the observation of one frame in place and convert it to a graph.

    Coordinates are divided by the map width / height, the exploration gain by
    ``max_exploration_gain`` and the distance by ``width + height``. The
    updated frame is then passed through ``to_graph`` and its result returned.
    """
    observation = frame_data["observation"]
    w, h = float(map_width), float(map_height)
    distance_scale = w + h

    def _normalize_xy(points):
        # scale (x, y) pairs by the map size
        return [(float(px) / w, float(py) / h) for px, py in points]

    # frontier entries are (x, y, gain, distance)
    normalized_frontiers = []
    for fx, fy, gain, distance in observation["frontier_points"]:
        normalized_frontiers.append(
            (
                float(fx) / w,
                float(fy) / h,
                float(gain) / max_exploration_gain,
                float(distance) / distance_scale,
            )
        )
    observation["frontier_points"] = normalized_frontiers

    # agent positions and target positions
    observation["pos"] = _normalize_xy(observation["pos"])
    observation["target_pos"] = _normalize_xy(observation["target_pos"])

    # build graph
    return to_graph(frame_data)
# policy transform
def device_cast(
    frame_data: dict[str, Any], device: torch.device | None = None
) -> dict[str, Any]:
    """
    Move the observation graph of ``frame_data`` to ``device``.

    :param frame_data: frame dictionary; ``frame_data["observation"]["graph"]``
        must provide a ``.to(device)`` method. The dictionary is updated in place.
    :param device: target device. When ``None``, CUDA is used if it is
        available and the CPU otherwise, so the default also works on machines
        without a GPU.
    :return: the same ``frame_data`` object.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    observation = frame_data["observation"]
    observation["graph"] = observation["graph"].to(device)

    return frame_data

In [5]:
# define policy
# Policy and value networks are built with the same feature / hidden sizes,
# wrapped into an actor-critic module, and moved to the training device.
policy_network, value_network = (
    GINPolicyNetwork(dim_feature=6, dim_h=32),
    GINValueNetwork(dim_feature=6, dim_h=32),
)
actor, critic = (
    Actor(policy_network=policy_network),
    Critic(value_network=value_network),
)
wrapped_actor_critic = ActorCritic(
    actor=actor,
    critic=critic,
    forward_preprocess=device_cast,
).to(device)

# optimizer over all actor-critic parameters, and the advantage estimator
optimizer = Adam(wrapped_actor_critic.parameters())
value_estimator = GAE(gamma=gamma, lmbda=lmbda)

In [6]:
# wrapped_actor_critic.load_state_dict(torch.load("ppo_parallel_100.pt"))

In [7]:
# parameters for env rollout
# Keyword arguments that are unpacked into `env.rollout(...)` by the rollout
# cells of this notebook.
rollout_parameters = dict(
    grid_map=grid_map,
    policy=wrapped_actor_critic,
    env_transform=env_transform,
    agent_poses=agent_poses,
    num_agents=num_agents,
    max_steps=max_steps,
    max_steps_per_agent=max_steps_per_agent,
    velocity=velocity,
    sensor_range=ray_range,
    num_rays=num_rays,
    min_frontier_pixel=min_frontier_size,
    max_frontier_pixel=max_frontier_size,
    exploration_threshold=exploration_threshold,
    return_maps=True,
    requires_grad=False,
)

In [8]:
# single rollout
# One environment writing to <log_dir>/output.log, rolled out without gradients.
env = Env(log_level, (log_dir / "output.log").as_posix())
with torch.no_grad():
    rollouts = single_rollouts = env.rollout(**rollout_parameters)

In [9]:
# parallel rollout
import concurrent.futures


def rollout_env(env, params):
    """Run a rollout on ``env``, forwarding ``params`` as keyword arguments."""
    run_rollout = env.rollout
    return run_rollout(**params)


def parallel_rollout(envs, num_parallel_envs, rollout_params):
    """
    Run rollouts of several environments concurrently on a thread pool.

    :param envs: iterable of environment objects, each providing a ``rollout`` method.
    :param num_parallel_envs: int - maximum number of worker threads.
    :param rollout_params: dict - keyword arguments passed to every ``rollout`` call.
    :return: list - the results of all environments concatenated, in completion
        order. A rollout that raises is reported with ``print`` and skipped.
    """
    collected = []
    pool = concurrent.futures.ThreadPoolExecutor(max_workers=num_parallel_envs)
    with pool as executor:
        # submit one rollout task per environment and remember which env it belongs to
        futures = {}
        for env in envs:
            futures[executor.submit(rollout_env, env, rollout_params)] = env

        # gather the results as the tasks finish
        for future in concurrent.futures.as_completed(futures):
            try:
                collected.extend(future.result())
            except Exception as e:
                print(f"环境 {futures[future]} 的rollout执行时发生错误: {e}")

    return collected


# 20 environments that all log to <log_dir>/output.log, rolled out on at most
# 10 worker threads without tracking gradients.
envs = [
    Env(log_level, log_path=(log_dir / "output.log").as_posix())
    for _ in range(20)
]
with torch.no_grad():
    rollouts = parallel_rollout(
        envs,
        num_parallel_envs=10,
        rollout_params=rollout_parameters,
    )

In [7]:
# Single rollout under inference mode. The keyword arguments must come from
# `rollout_parameters`; `parameters` is the hyperparameter dictionary collected
# for logging and contains entries (device, n_iters, ...) that are not part of
# the rollout argument set used elsewhere in this notebook.
with torch.inference_mode():
    single_rollouts = env.rollout(**rollout_parameters)
rollouts = single_rollouts

In [12]:
# for logging
# TensorBoard run directory is "runs/ppo<local timestamp>"; the collected
# hyperparameters are stored once as text.
timestamp_str = time.strftime("%Y-%m-%d-%H-%M-%S")
writer = SummaryWriter("runs/ppo" + timestamp_str)
writer.add_text("parameters", repr(parameters))

In [None]:
# train loop
import tqdm
from torch_geometric.loader import DataLoader

# Defined before the loop so the final checkpoint below still has a step index
# when n_iters == 0 and the loop body never runs.
global_step = 0
for _n_iter in tqdm.tqdm(range(n_iters)):
    global_step = _n_iter

    # ------------------------------------------------------------------
    # 1. Collect experience with the current policy (no gradients needed)
    # ------------------------------------------------------------------
    with torch.no_grad():
        # NOTE(review): this step used to call `multi_threaded_rollout`, which is
        # neither imported nor defined in this notebook, and built `Env` with a
        # keyword signature that no other cell uses. It now goes through the
        # `parallel_rollout` helper and `rollout_parameters` defined in the
        # rollout cells above — confirm this matches the current Env API.
        train_envs = [
            Env(log_level, log_path=(log_dir / "output.log").as_posix())
            for _ in range(n_envs_per_iter)
        ]
        rollouts = parallel_rollout(
            train_envs,
            n_parallel_envs,
            # maps are only used by the visualisation cells, not for training
            {**rollout_parameters, "return_maps": False},
        )

        # normalize rewards
        rollouts = [
            exploration_reward_rescale(rollout, max_value=max_exploration_gain)
            for rollout in rollouts
        ]
        rollouts = [delta_time_reward_standardize(rollout) for rollout in rollouts]
        rollouts = [reward_sum(rollout, gamma=gamma) for rollout in rollouts]

        # compute GAE, then flatten the per-rollout frame lists into one list
        rollouts = [value_estimator(rollout) for rollout in rollouts]
        flattened_rollouts = [
            frame_data for rollout in rollouts for frame_data in rollout
        ]

        # gather the per-frame quantities the sampler draws mini-batches from
        graphs = [fd["observation"]["graph"] for fd in flattened_rollouts]
        rewards = torch.tensor(
            [fd["next"]["reward"]["total_reward"] for fd in flattened_rollouts]
        ).to(device)
        values = torch.tensor([fd["value"] for fd in flattened_rollouts]).to(device)
        advantages = torch.tensor(
            [fd["advantage"] for fd in flattened_rollouts]
        ).to(device)
        log_probs = torch.tensor(
            [fd["log_prob"] for fd in flattened_rollouts]
        ).to(device)
        returns = torch.tensor([fd["return"] for fd in flattened_rollouts]).to(device)
        total_times = torch.tensor(
            [fd["info"]["total_time_step"] for fd in flattened_rollouts]
        ).to(device)

        sampler = Sampler(
            batch_size=mini_batch_size,
            length=len(flattened_rollouts),
            graphs=graphs,
            rewards=rewards,
            values=values,
            advantages=advantages,
            log_probs=log_probs,
            returns=returns,
            total_times=total_times,
        )

        # Iteration-level statistics over ALL collected frames. They are computed
        # here, before the mini-batch loop, because they used to be taken from
        # whatever the last sampled mini-batch happened to contain.
        average_epsiode_time = torch.mean(total_times.to(torch.float)).item()
        episode_reward_mean = torch.mean(returns.to(torch.float)).item()

    # ------------------------------------------------------------------
    # 2. PPO update: several epochs of mini-batch gradient steps
    # ------------------------------------------------------------------
    clip_fraction = []
    for _n_epoch in range(n_epochs_per_iter):
        for _n_mini_batch in range(n_frames_per_iter // mini_batch_size):
            # sample data (mb_ prefix keeps the full-iteration tensors intact)
            mini_batch_data = sampler.random_sample()
            mb_graphs = mini_batch_data["graphs"].to(device)
            mb_advantages = mini_batch_data["advantages"].to(device).flatten()
            prev_log_prob = mini_batch_data["log_probs"].to(device).flatten()
            mb_returns = mini_batch_data["returns"].to(device).flatten()
            frame_indices = mini_batch_data["frame_indices"].to(device)

            # re-evaluate the sampled frames with the current network weights
            _, new_log_probs, new_values, entropy = (
                wrapped_actor_critic.forward_parallel(mb_graphs, frame_indices)
            )
            new_log_probs = new_log_probs.to(device).flatten()
            new_values = new_values.to(device).flatten()

            # probability ratio between the current and the data-collecting policy
            log_ratio = new_log_probs - prev_log_prob
            ratio = log_ratio.exp()

            # diagnostics only, excluded from the autograd graph
            with torch.no_grad():
                old_approx_kl = (-log_ratio).mean()
                approx_kl = ((ratio - 1) - log_ratio).mean()
                clip_fraction.append(
                    ((ratio - 1.0).abs() > clip_coefficient).float().mean().item()
                )

            # clipped surrogate policy loss
            pg_loss1 = -mb_advantages * ratio
            pg_loss2 = -mb_advantages * torch.clamp(
                ratio, 1 - clip_coefficient, 1 + clip_coefficient
            )
            policy_loss = torch.max(pg_loss1, pg_loss2).mean()

            value_loss = 0.5 * ((new_values - mb_returns) ** 2).mean()
            ess_loss = entropy.mean()
            # the entropy term is subtracted, so higher entropy lowers the loss
            loss = (
                policy_loss_weight * policy_loss
                + value_loss_weight * value_loss
                - entropy_weight * ess_loss
            )

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(
                wrapped_actor_critic.parameters(), max_grad_norm
            )
            optimizer.step()

    # ------------------------------------------------------------------
    # 3. Checkpoint every 10 iterations and log
    #    (loss / KL scalars are those of the last mini-batch of the iteration)
    # ------------------------------------------------------------------
    if global_step % 10 == 0 and global_step > 0:
        torch.save(wrapped_actor_critic.state_dict(), f"ppo_parallel_{global_step}.pt")
    writer.add_scalar("losses/value_loss", value_loss.item(), global_step)
    writer.add_scalar("losses/policy_loss", policy_loss.item(), global_step)
    writer.add_scalar("losses/ESS", ess_loss.item(), global_step)
    writer.add_scalar("losses/old_approx_kl", old_approx_kl.item(), global_step)
    writer.add_scalar("losses/approx_kl", approx_kl.item(), global_step)
    writer.add_scalar("losses/clip_fraction", np.mean(clip_fraction), global_step)
    writer.add_scalar("info/episode_reward_mean", episode_reward_mean, global_step)
    writer.add_scalar("info/average_epsiode_time", average_epsiode_time, global_step)

# final checkpoint
torch.save(wrapped_actor_critic.state_dict(), f"ppo_parallel_{global_step}.pt")

In [None]:
# draw raw rewards for debugging
# One figure per reward component, one curve per rollout.
for reward_key in ("exploration_reward", "time_step_reward"):
    for rollout in rollouts:
        curve = [frame_data["next"]["reward"][reward_key] for frame_data in rollout]
        plt.plot(curve)
    plt.show()

In [20]:
lmbda, gamma = 0.95, 0.99
# normalize rewards: each helper is applied to every rollout before the next
# helper runs, and `rollouts` is replaced by the processed list each time
rollouts = [
    exploration_reward_rescale(one_rollout, max_value=max_exploration_gain)
    for one_rollout in rollouts
]
rollouts = list(map(delta_time_reward_standardize, rollouts))
rollouts = [reward_sum(one_rollout, gamma=gamma) for one_rollout in rollouts]

In [None]:
# compute GAE
from rl.actor_critic import GAE

# Apply the estimator to every rollout and keep the processed rollouts.
gae = GAE(gamma=gamma, lmbda=lmbda)
rollouts = list(map(gae, rollouts))

In [None]:
# draw normalized rewards for debugging
def _plot_per_rollout(per_rollout_values, label):
    """Plot one curve per rollout and a dashed line at the mean over all frames.

    :param per_rollout_values: list holding one list of scalar values per rollout.
    :param label: text placed in front of the mean in the figure title.
    """
    for series in per_rollout_values:
        plt.plot(series)
    # The mean is taken over every frame of every rollout. The previous nested
    # comprehensions had their `for` clauses in the wrong order, so they only
    # averaged the last rollout left over from the plotting loop.
    mean_val = np.mean([value for series in per_rollout_values for value in series])
    plt.title(f"{label}:{mean_val}")
    plt.axhline(mean_val, color="r", linestyle="--")
    plt.show()


# total_reward, exploration_reward and time_step_reward
for reward_key, reward_label in (
    ("total_reward", "total reward"),
    ("exploration_reward", "exploration reward"),
    ("time_step_reward", "time step reward"),
):
    _plot_per_rollout(
        [
            [frame_data["next"]["reward"][reward_key] for frame_data in rollout]
            for rollout in rollouts
        ],
        reward_label,
    )

# reward_to_go
reward_to_go_per_rollout = [
    [frame_data["reward_to_go"] for frame_data in rollout] for rollout in rollouts
]
_plot_per_rollout(reward_to_go_per_rollout, "reward to go")


# reward_mean: mean reward-to-go of each rollout
reward_mean = [np.mean(series) for series in reward_to_go_per_rollout]
plt.plot(reward_mean)
plt.title(f"reward mean:{np.mean(reward_mean)}")
plt.axhline(y=np.mean(reward_mean), color="r", linestyle="--")
plt.show()

# gae: advantages of all frames, rollout after rollout, as a single curve
gaes = [
    frame_data["advantage"].item() for rollout in rollouts for frame_data in rollout
]
plt.plot(gaes)
plt.title(f"gae:{np.mean(gaes)}")
plt.axhline(y=np.mean(gaes), color="r", linestyle="--")
plt.show()

In [11]:
# NOTE(review): this cell used to call `single_env_rollout`, which is neither
# imported nor defined in this notebook. It now performs the same kind of
# single-environment rollout as the "single rollout" cell — confirm that this
# is the intended replacement.
with torch.no_grad():
    rollouts = env.rollout(**rollout_parameters)


In [None]:
# store maps to gif
def _save_gif(frames, target):
    """Save ``frames`` (list of PIL images) to ``target`` as a looping GIF, 1000 ms per frame."""
    frames[0].save(
        target,
        save_all=True,
        append_images=frames[1:],
        duration=1000,
        loop=0,
    )


# one GIF per rollout, named after the agent id stored in its first frame
for rollout in rollouts:
    agent_id = rollout[0]["info"]["agent_id"]
    imgs = [Image.fromarray(fd["observation"]["agent_map"]) for fd in rollout]
    _save_gif(
        imgs, output_images_dir / f"agent_{agent_id}_{map_path.split('/')[-1]}.gif"
    )

# one GIF of the global map, with the frames of all rollouts ordered by step_cnt
flattened_rollouts = [fd for one_rollout in rollouts for fd in one_rollout]
sorted_rollouts = sorted(flattened_rollouts, key=lambda fd: fd["info"]["step_cnt"])
imgs = [Image.fromarray(fd["observation"]["global_map"]) for fd in sorted_rollouts]
_save_gif(imgs, output_images_dir / f"global_map_{map_path.split('/')[-1]}.gif")

# total time step recorded in the last frame of the third rollout
rollouts[2][-1]["info"]["total_time_step"]

In [None]:
# Known (!= 127) pixels in the final agent maps of the first three rollouts,
# summed and divided by the known pixels of the final global map.
np.sum(
    [
        np.count_nonzero(rollouts[agent_idx][-1]["observation"]["agent_map"] != 127)
        for agent_idx in range(3)
    ]
) / np.count_nonzero(rollouts[0][-1]["observation"]["global_map"] != 127)

In [6]:
# force clean
# Deletes the whole output directory of this session (everything under
# experiment_dir). The existence check keeps the cell re-runnable: a bare
# rmtree raises FileNotFoundError once the directory is already gone.
if experiment_dir.exists():
    shutil.rmtree(experiment_dir)

In [None]:
# for test
# Synthetic map: white background with two horizontal and two vertical black
# bands (rows / columns 50-74 and 125-149). This overwrites the loaded grid_map.
grid_map = np.full_like(grid_map, 255)
for band in (slice(50, 75), slice(125, 150)):
    grid_map[band, :] = 0
    grid_map[:, band] = 0
plt.imshow(grid_map, cmap="gray", vmin=0, vmax=255)

# tests
env = Env()

from mcerl import GridMap

# ground-truth map plus a map to update that starts out filled with 127
test_env_grid_map = GridMap(grid_map)
map_to_update = np.full_like(grid_map, 127)
test_map_tp_update = GridMap(map_to_update)

# NOTE(review): the positional arguments presumably are the sensor position,
# the sensor range, the number of rays and one further setting — confirm
# against the signature of test_map_update.
(
    basic_end_points,
    circle_end_points,
    circle_end_points_with_polygon,
    map_to_update,
    roi_map,
) = env._env.test_map_update(
    test_env_grid_map, test_map_tp_update, (100, 100), 30, 32, 5
)
a, b = np.array(map_to_update), np.array(roi_map)
# NOTE(review): the end points are divided by 10000, presumably to undo a
# fixed-point scale — confirm.
basic_end_points = [(px / 10000, py / 10000) for px, py in basic_end_points]
plt.imshow(b, cmap="gray", vmin=0, vmax=255)

from matplotlib.patches import Rectangle, Circle

# draw the map, the polygon end points and a yellow reference rectangle
plt.imshow(grid_map, cmap="gray", vmin=0, vmax=255)
plt.scatter(*zip(*circle_end_points_with_polygon), c="r", s=2, marker="x")
rect = Rectangle(
    (70, 70),
    131 - 70,
    131 - 70,
    linewidth=1,
    edgecolor="y",
    facecolor="none",
)
plt.gca().add_patch(rect)

# end points relative to the top-left corner of their bounding box
bbx = cv2.boundingRect(np.array(circle_end_points_with_polygon))
np.array(circle_end_points_with_polygon) - bbx[0:2]