可视化

In [None]:
import os
import warnings
import gymnasium as gym
import torch
import pprint
import numpy as np
import matplotlib.pyplot as plt
from torch import nn
from torch.optim import Adam
from torch.utils.tensorboard import SummaryWriter
from tianshou.env import SubprocVectorEnv
from tianshou.data import Collector, VectorReplayBuffer, Batch
from tianshou.policy import PPOPolicy
from tianshou.trainer import OnpolicyTrainer
from tianshou.utils.net.common import Net, ActorCritic
from tianshou.utils.net.discrete import Actor, Critic
from tianshou.utils import TensorboardLogger
from IPython.display import display, clear_output
import ipywidgets as widgets
import time
from PIL import Image, ImageDraw, ImageFont

# Hide every DeprecationWarning so the notebook output stays readable.
warnings.simplefilter("ignore", category=DeprecationWarning)


def make_env():
    """Build one MsPacman environment that emits Atari RAM observations.

    Called directly for single-episode use and also handed, uncalled, to
    ``SubprocVectorEnv`` as a zero-argument environment factory.
    """
    env_id = "ALE/MsPacman-ram-v5"
    return gym.make(env_id)


# Prefer the first CUDA device but fall back to CPU, so the cell still runs
# (slowly) on a machine without a GPU instead of failing at model creation.
DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"

# A single environment, used here to read the observation and action spaces.
env = make_env()
state_shape = env.observation_space.shape or env.observation_space.n
action_shape = env.action_space.shape or env.action_space.n
# NOTE(review): in this cell the 10 worker processes are only consulted for
# `action_space[0]` below and are never closed; if no other cell uses
# `train_envs`, passing `env.action_space` would avoid spawning them.
train_envs = SubprocVectorEnv([make_env for _ in range(10)])

# Actor and critic are both built on the same `net` backbone
# (3 hidden layers of width 1024).
net = Net(state_shape, hidden_sizes=[1024] * 3, device=DEVICE)
actor = Actor(net, action_shape, device=DEVICE).to(DEVICE)
critic = Critic(net, device=DEVICE).to(DEVICE)

actor_critic = ActorCritic(actor, critic).to(DEVICE)
# The optimizer is passed to PPOPolicy below; nothing visible in this cell
# performs a training step with it.
optim = torch.optim.Adam(actor_critic.parameters(), lr=3e-4)
dist_fn = torch.distributions.Categorical
policy = PPOPolicy(
    actor=actor,
    critic=critic,
    optim=optim,
    dist_fn=dist_fn,
    discount_factor=0.99,
    max_grad_norm=0.5,
    eps_clip=0.2,
    vf_coef=0.5,
    ent_coef=0.01,
    reward_normalization=True,
    action_space=train_envs.action_space[0],
    action_scaling=False,
    deterministic_eval=True,
).to(DEVICE)

# Directory holding the trained checkpoint; `ckpt` and `save_path` name the
# same path (both kept for any other cell that refers to either).
ckpt = "/root/gym/rl_compexp/save/MsPacMan-PPO1024"
save_path = ckpt
# map_location remaps tensors that were saved on a GPU onto DEVICE, so the
# checkpoint also loads when DEVICE fell back to CPU.
policy.load_state_dict(
    torch.load(os.path.join(save_path, "ppo_best.pth"), map_location=DEVICE)
)
# Switch to evaluation mode (presumably what makes deterministic_eval=True take
# effect in tianshou -- confirm against the installed version).
policy.eval()
# RAM byte offsets of game state variables, read as `obs[offset]` when a frame
# is rendered. Insertion order is the order in which the values are displayed.
# NOTE(review): the names appear to follow the AtariARI MsPacman RAM
# annotations -- confirm against that source before relying on them.
obs_mapping = dict(
    # enemy_* entries: all x offsets first, then all y offsets.
    enemy_sue_x=6,
    enemy_inky_x=7,
    enemy_pinky_x=8,
    enemy_blinky_x=9,
    enemy_sue_y=12,
    enemy_inky_y=13,
    enemy_pinky_y=14,
    enemy_blinky_y=15,
    # player_* and fruit_* coordinates.
    player_x=10,
    player_y=16,
    fruit_x=11,
    fruit_y=17,
    # Remaining scalar state.
    ghosts_count=19,
    player_direction=56,
    dots_eaten_count=119,
    num_lives=123,
    power_mode=0x4E,
)


def calculate_score_from_ram(byte_120, byte_121):
    """
    Decode the score from two BCD-encoded RAM bytes.

    Each byte packs two decimal digits, one per 4-bit nibble, with the high
    nibble holding the more significant digit. ``byte_121`` supplies the
    thousands and hundreds digits; ``byte_120`` supplies the tens and ones
    digits. With valid BCD nibbles the result is therefore at most 9999.

    Parameters:
    - byte_120 (int): The value at RAM address 120 (tens and ones digits).
      NumPy integer scalars are accepted.
    - byte_121 (int): The value at RAM address 121 (thousands and hundreds
      digits). NumPy integer scalars are accepted.

    Returns:
    - int: The decoded score as a plain Python int.

    NOTE(review): only these two bytes are decoded, so any higher-order score
    digits stored elsewhere in RAM are not included -- confirm whether scores
    above 9999 need another byte.
    """
    # Callers pass elements of the RAM observation array (presumably NumPy
    # uint8). Convert to Python ints first: with uint8 operands, `digit * 1000`
    # raises OverflowError and `digit * 100` wraps around under NumPy 2's
    # scalar promotion rules.
    byte_120 = int(byte_120)
    byte_121 = int(byte_121)

    # Decode BCD for byte_120: high nibble = tens, low nibble = ones.
    tens = (byte_120 >> 4) & 0x0F
    ones = byte_120 & 0x0F

    # Decode BCD for byte_121: high nibble = thousands, low nibble = hundreds.
    thousands = (byte_121 >> 4) & 0x0F
    hundreds = byte_121 & 0x0F

    return thousands * 1000 + hundreds * 100 + tens * 10 + ones


def _render_frame(rgb_array, obs, step, fps, font):
    """Compose one display frame: game screen on the left, text panel on the right.

    Parameters:
    - rgb_array: Screen pixels as returned by the ALE's getScreenRGB().
    - obs: The current RAM observation; indexed with the offsets in obs_mapping
      and at 120/121 for the score.
    - step (int): Number of steps taken so far in the episode.
    - fps (float): Measured steps per second.
    - font: PIL font used for all text.

    Returns:
    - PIL.Image.Image: The composed RGB frame.
    """
    game_img = Image.fromarray(rgb_array)
    game_width, game_height = game_img.size
    # Twice the game width leaves room for the text panel on the right; at
    # least 600 px tall so every text line fits.
    img = Image.new("RGB", (game_width * 2, max(game_height, 600)), color="black")
    img.paste(game_img, (0, 0))

    draw = ImageDraw.Draw(img)
    text_x = game_width + 10
    white = (255, 255, 255)

    draw.text((text_x, 10), f"Step: {step} | FPS: {fps:.2f}", font=font, fill=white)

    obs_info = {key: obs[value] for key, value in obs_mapping.items()}
    obs_text = "Obs values and their meanings:\n"
    obs_text += "\n".join([f"{key}: {value}" for key, value in obs_info.items()])
    obs_text += "\n" + f"{calculate_score_from_ram(int(obs[120]), int(obs[121]))}"

    # Start below the Step/FPS header line (drawn at y=10) so the two don't overlap.
    y_offset = 30
    for line in obs_text.split("\n"):
        draw.text((text_x, y_offset), line, font=font, fill=white)
        y_offset += 20
    return img


# Improved version of the watch (monitoring) function.
def watch_agent_play(frame_delay=0.0):
    """Play one episode with the loaded policy and render it live in the notebook.

    Each frame shows the emulator screen on the left and, on the right, the
    step counter, the measured FPS, the RAM values listed in ``obs_mapping``
    and the decoded score.

    Parameters:
    - frame_delay (float): Optional pause in seconds after each rendered frame,
      to slow playback down for viewing. Defaults to 0 (no pause).

    The environment is closed even if playback is interrupted (for example by
    stopping the notebook cell).
    """
    env = make_env()
    # Hoisted out of the loop: the font is identical for every frame.
    font = ImageFont.load_default()
    output_area = widgets.Output()
    display(output_area)

    try:
        obs, info = env.reset()
        done = False
        step = 0
        start_time = time.perf_counter()

        while not done:
            step += 1
            # Inference only: no autograd graph is needed to pick an action.
            with torch.no_grad():
                action = policy.forward(
                    Batch(obs=obs.reshape(1, -1), info=info), deterministic=True
                ).act[0]
            obs, _reward, terminated, truncated, info = env.step(action)
            # End on truncation as well; otherwise a time-limited episode would
            # keep stepping an environment that has already finished.
            done = terminated or truncated

            # `ale` lives on the base env; reaching it through the wrappers is
            # deprecated (gymnasium 0.29) and removed (gymnasium 1.0).
            rgb_array = env.unwrapped.ale.getScreenRGB()

            elapsed_time = time.perf_counter() - start_time
            # Guard against a zero interval on coarse clocks.
            fps = step / elapsed_time if elapsed_time > 0 else 0.0

            img = _render_frame(rgb_array, obs, step, fps, font)
            with output_area:
                clear_output(wait=True)
                display(img)

            if frame_delay > 0:
                time.sleep(frame_delay)
    finally:
        env.close()


# Play one episode with the loaded policy and render it inline in the notebook.
watch_agent_play()

Output()