## Step1 初始化 Jupyter 环境 & 导入包

In [None]:
# 用于在 Jupyter 中强制刷新参数
%reset -f

# 导入相关的包
import os
import sys
from collections import deque
from pathlib import Path

import torch
import ale_py
import pygame
import imageio
import gymnasium as gym
import numpy as np

from tqdm.notebook import tqdm
from torch.distributions import Categorical
from torchvision.transforms import v2
from loguru import logger

## Step2 设置相关参数

In [None]:
# Feature switches (any truthy value enables the feature)
is_training = 1                     # Whether to run the training cell
is_evaluate = 1                     # Whether to run the evaluation cell; without recording, the game window is rendered
need_record = 1                     # Whether to record a video; only effective when is_evaluate=1, the game window is not rendered

# Log level: drop loguru's default sink and log to stderr at the chosen level
log_level = "INFO"
logger.remove()
logger.add(sys.stderr, level=log_level)

# Environment information
env_id = "ALE/Galaxian-v5"          # Gymnasium environment id
env_height = 210                    # Height of a game frame (used to reshape the network input)
env_width = 160                     # Width of a game frame (used to reshape the network input)
max_steps = 10000                   # Maximum number of steps per episode
render_mode = "rgb_array"           # Render mode of the training environment, e.g. "human" or "rgb_array"

# A2C algorithm parameters
gamma = 0.99                        # Discount factor
memory_buffer_size = 10000          # Maximum number of transitions kept in the memory buffer
frame_stack = 2                     # Number of stacked frames; an action is also chosen only every frame_stack steps

# Training parameters
num_train_episodes = 10000          # Total number of training episodes
lr = 1e-5                           # Learning rate (shared by the actor and critic optimizers)
max_same_action = 16                # Number of consecutive identical actions after which a penalty is applied
timestep_reward = 300               # Number of zero-reward steps without a change in lives that earns a survival bonus

# Evaluation parameters
num_eval_episodes = 10              # Number of evaluation episodes
reward_threshold = 2000             # Episode totals at or above this are logged as Success, otherwise as Warning
eval_sample_action = True           # True: sample the action from the policy distribution; False: take the arg-max action


# Saving strategy
save_dir = "./Gym_ALE_Galaxian_A2C"                     # Directory where checkpoints and videos are stored
save_freq = 100                                         # Save the models every save_freq episodes
max_checkpoints = 5                                     # Maximum number of checkpoints to keep
checkpoint_perfix_A = "CheckPoint_Gym_ALE_Galaxian_A_"  # File-name prefix of actor checkpoints
checkpoint_perfix_C = "CheckPoint_Gym_ALE_Galaxian_C_"  # File-name prefix of critic checkpoints
evaluate_record_perfix = "Video_Gym_ALE_Galaxian_"      # File-name prefix of the evaluation video
evaluate_record_fps = 30                                # Frame rate of the evaluation video
evaluate_record_quality = 10                            # Quality passed to imageio when saving the video (presumably 0 ~ 10)

# Remaining initialisation
device = "cuda" if torch.cuda.is_available() else "cpu"
gym.register_envs(ale_py)                               # Arcade Learning Environment (ALE) environments must be registered first

## Step3 预处理函数 & 工具

In [None]:
def get_max_checkpoint_id(checkpoint_perfix, save_dir=save_dir):
    """
    Find the newest checkpoint file for the given prefix.

    A checkpoint file is a regular ``.pth`` file inside ``save_dir`` whose name
    is ``<checkpoint_perfix><numeric id>.pth``. Files whose suffix after the
    prefix is not a plain decimal number are ignored, because callers convert
    the id with ``int()``.

    :param checkpoint_perfix: file-name prefix identifying the checkpoint family
    :param save_dir: directory to scan; it is created when it does not exist
    :return: ``None`` when the directory had to be created or holds no matching
             checkpoint, otherwise a dict with ``"max_checkpoint_path"``
             (absolute path) and ``"max_checkpoint_id"`` (the id as it appears
             in the file name, i.e. a string)
    """
    directory = Path(save_dir)

    # A missing directory means there is nothing to resume from: create it and stop.
    if not directory.exists():
        directory.mkdir(parents=True, exist_ok=True)
        logger.debug("The specified directory does not exist, will create this folder")
        return None

    # Collect the numeric ids of every matching checkpoint file.
    checkpoint_ids = []
    for entry in directory.iterdir():
        if entry.is_file() and entry.suffix == ".pth" and entry.name.startswith(checkpoint_perfix):
            checkpoint_id = entry.stem[len(checkpoint_perfix):]
            if checkpoint_id.isdecimal():
                checkpoint_ids.append(checkpoint_id)

    if not checkpoint_ids:
        logger.info(f"Not found any {checkpoint_perfix} files, will random initialization of network parameters")
        return None

    # Compare numerically: as strings "900" would sort above "1000".
    # The id is kept as it appears on disk so the rebuilt path matches the file.
    max_checkpoint_id = max(checkpoint_ids, key=int)
    max_checkpoint_path = os.path.abspath(f"{save_dir}/{checkpoint_perfix}{max_checkpoint_id}.pth")
    logger.info(f"Found max checkpoints, max_checkpoint_id is {max_checkpoint_id}")
    return {"max_checkpoint_path": max_checkpoint_path, "max_checkpoint_id": max_checkpoint_id}

In [None]:
def del_old_checkpoint(checkpoint_perfix, save_dir=save_dir, max_checkpoints=max_checkpoints):
    """
    Delete old checkpoint files so that only the newest ``max_checkpoints`` remain.

    A checkpoint file is a regular ``.pth`` file inside ``save_dir`` whose name
    is ``<checkpoint_perfix><numeric id>.pth``; a larger id means a newer file.
    Files whose suffix after the prefix is not a decimal number are left alone.

    :param checkpoint_perfix: file-name prefix identifying the checkpoint family
    :param save_dir: directory to clean; nothing happens when it does not exist
    :param max_checkpoints: number of newest checkpoints to keep
    """
    directory = Path(save_dir)
    if not directory.exists():
        return

    # Remember the real path of each file so that the file we delete is exactly
    # the one we found (rebuilding the name from int(id) would lose zero padding).
    found = []
    for entry in directory.iterdir():
        if entry.is_file() and entry.suffix == ".pth" and entry.name.startswith(checkpoint_perfix):
            checkpoint_id = entry.stem[len(checkpoint_perfix):]
            if checkpoint_id.isdecimal():
                found.append((int(checkpoint_id), entry))

    # Oldest first; everything beyond the allowed count is removed.
    found.sort(key=lambda item: item[0])
    excess = max(len(found) - max_checkpoints, 0)
    for _, entry in found[:excess]:
        old_checkpoint_path = os.path.abspath(entry)
        os.remove(old_checkpoint_path)
        logger.warning(f"Delete old checkpoint file {old_checkpoint_path}")

## Step4 定义智能体

In [None]:
class RLAgent:
    """
    A2C agent: owns the actor and critic networks, their optimizers, the frame
    pre-processing pipeline and the per-episode memory buffer.
    """
    def __init__(self, action_size):
        """
        Build both networks, restore the newest checkpoints when present and
        create the optimizers.

        :param action_size: number of discrete actions (size of the actor output)
        """
        # Global Args
        self.max_checkpoint_a = get_max_checkpoint_id(checkpoint_perfix=checkpoint_perfix_A)
        self.max_checkpoint_c = get_max_checkpoint_id(checkpoint_perfix=checkpoint_perfix_C)
        self.memory_buffer = deque(maxlen=memory_buffer_size)

        # Init Actor / Critic networks (same trunk, different output size)
        self.actor_network = self._build_network(out_features=action_size)
        self.critic_network = self._build_network(out_features=1)

        # Move to designated device
        self.actor_network.to(device)
        self.critic_network.to(device)

        # Dry run so that every LazyLinear gets its real shape before a state
        # dict is loaded or an optimizer is attached.
        with torch.no_grad():
            dummy_input = torch.zeros(1, frame_stack, env_height, env_width, device=device)
            self.actor_network(dummy_input)
            self.critic_network(dummy_input)

        # Restore checkpoints; map_location lets a checkpoint written on another
        # device (e.g. CUDA) be loaded on the current one.
        if self.max_checkpoint_a is not None:
            self.actor_network.load_state_dict(
                torch.load(self.max_checkpoint_a["max_checkpoint_path"], map_location=device, weights_only=True)
            )
        if self.max_checkpoint_c is not None:
            self.critic_network.load_state_dict(
                torch.load(self.max_checkpoint_c["max_checkpoint_path"], map_location=device, weights_only=True)
            )

        # Transforms: to image tensor, float32 scaled by the transform, single-channel grayscale
        self.transform = v2.Compose([
            v2.ToImage(),
            v2.ToDtype(torch.float32, scale=True),
            v2.Grayscale(1),
        ])

        # optimizer
        self.a_optimizer = torch.optim.AdamW(self.actor_network.parameters(), lr=lr)
        self.c_optimizer = torch.optim.AdamW(self.critic_network.parameters(), lr=lr)

    @staticmethod
    def _build_network(out_features):
        """Build the shared conv + MLP architecture with the given output size."""
        return torch.nn.Sequential(
            torch.nn.Conv2d(in_channels=frame_stack, out_channels=frame_stack * 2, kernel_size=3, stride=1, padding=1),
            torch.nn.LeakyReLU(),
            torch.nn.Conv2d(in_channels=frame_stack * 2, out_channels=frame_stack * 4, kernel_size=3, stride=2, padding=1),
            torch.nn.Flatten(),
            torch.nn.LazyLinear(out_features=1024),
            torch.nn.LeakyReLU(),
            torch.nn.LazyLinear(out_features=128),
            torch.nn.LeakyReLU(),
            torch.nn.LazyLinear(out_features=out_features),
        )

    @staticmethod
    def _compute_losses(log_probs, values, td_target):
        """
        Compute the A2C actor and critic losses.

        :param log_probs: log-probabilities of the taken actions, one per transition
        :param values: critic estimates for the current states
        :param td_target: TD targets, same shape as ``values``
        :return: ``(actor_loss, critic_loss)`` as scalar tensors
        """
        td_target = td_target.detach()
        # Flatten both factors: multiplying an [N] tensor by an [N, 1] tensor would
        # broadcast to [N, N] and pair every log-prob with every advantage.
        advantage = (td_target - values).detach().reshape(-1)
        actor_loss = torch.mean(-log_probs.reshape(-1) * advantage)
        critic_loss = torch.nn.functional.mse_loss(values, td_target)
        return actor_loss, critic_loss

    @staticmethod
    def _checkpoint_path(checkpoint_perfix, previous_checkpoint, episodes):
        """Return the absolute path of the checkpoint to write after ``episodes`` episodes."""
        # When training resumed from a checkpoint, keep counting from its id.
        offset = 0 if previous_checkpoint is None else int(previous_checkpoint["max_checkpoint_id"])
        return os.path.abspath(f"{save_dir}/{checkpoint_perfix}{episodes + offset}.pth")

    def processing_states(self, frame_buffer):
        """
        Pre-process the frames in ``frame_buffer`` into a tensor the networks accept.

        :param frame_buffer: collection holding ``frame_stack`` raw frames
        :return: tensor of shape [1, frame_stack, env_height, env_width]
        """
        # Each transformed frame is single-channel; stack them and fold the
        # frames into the channel dimension of a batch of size 1.
        states = torch.stack(tuple(self.transform(frame_buffer)), dim=0)
        states = states.reshape(1, frame_stack, env_height, env_width)
        logger.debug(f"Processing states shape: {states.shape}")
        return states

    def select_action(self, state, sample=True):
        """
        Choose an action for ``state``.

        :param state:  network input of shape [batch_size, color_channel * stack_size, height, width]
        :param sample: True samples from the policy distribution, False takes the
                       action with the highest logit
        :return: ``{"action": tensor, "log_prob": tensor}`` when sampling,
                 ``{"action": int}`` otherwise
        """
        state = state.to(device)
        if sample:
            # https://pytorch.ac.cn/docs/stable/distributions.html#categorical
            # Building the distribution from logits keeps sampling and the
            # log-probability numerically stable.
            logits = self.actor_network(state)
            action_dist = Categorical(logits=logits)
            action = action_dist.sample()
            log_prob = action_dist.log_prob(action)
            return {"action": action, "log_prob": log_prob}
        else:
            # The greedy action is returned as a plain int, so no graph is needed.
            with torch.no_grad():
                action_logits = self.actor_network(state)
            action = action_logits.argmax(dim=1).item()
            return {"action": action}

    def update(self):
        """
        Run one A2C update on every transition in the memory buffer, then clear it.

        Each buffer entry is a dict with the keys "St", "At", "Rt", "St+1" and "Done".
        """
        num_mems = len(self.memory_buffer)
        logger.debug(f"memory buffer size: {num_mems}")
        if num_mems == 0:
            # Nothing was collected; torch.cat would fail on an empty list.
            return

        # Extract the data. Shapes are fixed up front so later arithmetic does
        # not broadcast unexpectedly.
        state = torch.cat([data["St"] for data in self.memory_buffer], dim=0).to(device)
        action = torch.tensor([data["At"] for data in self.memory_buffer]).to(device)
        reward = torch.tensor([data["Rt"] for data in self.memory_buffer], dtype=torch.float32).unsqueeze(1).to(device)
        next_state = torch.cat([data["St+1"] for data in self.memory_buffer], dim=0).to(device)
        done = torch.tensor([data["Done"] for data in self.memory_buffer]).float().unsqueeze(1).to(device)
        logger.debug(f"state shape: {state.shape}, action shape: {action.shape}, reward shape: {reward.shape}, next_state shape: {next_state.shape}, done shape: {done.shape}")

        # Critic: the TD target is a constant for both losses, so no graph is needed for it.
        with torch.no_grad():
            td_tgt = reward + gamma * self.critic_network(next_state) * (1 - done)
        value = self.critic_network(state)

        # Actor: log-probabilities of the taken actions via a Categorical distribution
        # (instead of taking torch.log of gathered probabilities).
        logits = self.actor_network(state)
        action_dist = Categorical(logits=logits)
        log_probs = action_dist.log_prob(action)

        # Actor and critic losses
        actor_loss, critic_loss = self._compute_losses(log_probs, value, td_tgt)
        logger.info(f"actor_loss: {actor_loss.item():.4f}, critic_loss: {critic_loss.item():.4f}")

        self.a_optimizer.zero_grad()
        self.c_optimizer.zero_grad()
        actor_loss.backward()
        critic_loss.backward()
        self.a_optimizer.step()
        self.c_optimizer.step()

        # The data is on-policy: drop it once it has been used.
        self.memory_buffer.clear()

    def save_model(self, episodes):
        """
        Save both networks to ``save_dir`` and prune old checkpoints.

        :param episodes: number of episodes trained in this run; it is added to
                         the id of the checkpoint the run started from, if any
        """
        max_checkpoint_path_a = self._checkpoint_path(checkpoint_perfix_A, self.max_checkpoint_a, episodes)
        max_checkpoint_path_c = self._checkpoint_path(checkpoint_perfix_C, self.max_checkpoint_c, episodes)

        # Save the model parameters
        torch.save(self.actor_network.state_dict(), max_checkpoint_path_a)
        torch.save(self.critic_network.state_dict(), max_checkpoint_path_c)
        logger.info(f"Actor Model saved to {max_checkpoint_path_a}")
        logger.info(f"Critic Model saved to {max_checkpoint_path_c}")

        # Remove old checkpoints
        del_old_checkpoint(checkpoint_perfix=checkpoint_perfix_A)
        del_old_checkpoint(checkpoint_perfix=checkpoint_perfix_C)

## Step5 调整环境

In [None]:
# Customised environment
class AleCustomEnv(gym.Wrapper):
    """
    ALE environment wrapper that reshapes the reward:

    * a change in the number of lives replaces the reward with -100;
    * ``timestep_reward`` zero-reward steps without a change in lives earn +100;
    * once the same action has been repeated ``max_same_action`` times in a
      row, the reward is replaced with -25 until the action changes.
    """
    def __init__(self, env):
        super().__init__(env)
        self.current_lives = 4            # Remaining lives (refreshed from `info` on reset when available)
        self.live_time = 0                # Zero-reward steps survived since the last bonus or life change
        self.previous_action = None       # Action executed on the previous step
        self.same_action_count = 0        # Number of consecutive repetitions of the previous action
        self.same_action_display = False  # Whether the "too many repeats" message was already logged

    def reset(self, *, seed=None, options=None):
        """
        Reset the wrapped environment and the wrapper's own counters.

        :param seed: forwarded to the wrapped environment's ``reset``
        :param options: forwarded to the wrapped environment's ``reset``
        :return: ``(observation, info)`` as returned by the wrapped environment
        """
        observation, info = self.env.reset(seed=seed, options=options)

        # Start from the lives the game actually reports, so that the first
        # step is not mistaken for a lost life; fall back to 4 otherwise.
        self.current_lives = info.get("lives", 4)
        self.live_time = 0
        self.previous_action = None
        self.same_action_count = 0
        self.same_action_display = False

        return observation, info

    def step(self, action):
        """
        Execute ``action`` in the wrapped environment and apply the reward shaping.

        :param action: action forwarded to the wrapped environment
        :return: ``(observation, reward, terminated, truncated, info)`` with the
                 possibly replaced reward
        """
        # Call the step method of the wrapped environment
        observation, reward, terminated, truncated, info = self.env.step(action)

        # Penalise any change in the number of lives
        if info['lives'] != self.current_lives:
            self.current_lives = info['lives']
            self.live_time = 0
            reward = -100
            logger.debug(f"lives -1, current lives: {self.current_lives}")

        # Encourage survival: only steps without a game reward count
        elif reward == 0:
            self.live_time += 1
            if self.live_time == timestep_reward:
                self.live_time = 0
                reward = 100
                logger.debug("live_time +100")

        # Penalise repeating the same action too many times
        if self.previous_action == action:
            self.same_action_count += 1
            if self.same_action_count >= max_same_action:
                # Overrides whatever reward was set above
                reward = -25
                if self.same_action_display is False:
                    # Log only once per streak
                    self.same_action_display = True
                    logger.error(f"same action too many times, same_action_count = {self.same_action_count}")
        else:
            same_action = self.same_action_count
            self.same_action_count = 0
            self.previous_action = action
            if self.same_action_display is True:
                self.same_action_display = False
                logger.error(f"same action it's over, total {same_action}")

        # Final result: observation, reward, terminated, truncated, info
        return observation, reward, terminated, truncated, info


## Step6 训练智能体

In [None]:
if is_training:
    # Main environment used for training
    env = gym.make(env_id, render_mode=render_mode)
    env = AleCustomEnv(env)

    try:
        # Instantiate the agent (the action space must be discrete)
        if isinstance(env.action_space, gym.spaces.Discrete):
            action_size = env.action_space.n
            Agent = RLAgent(action_size=action_size)
        else:
            logger.error("Action space is not Discrete!")
            raise ValueError("Action space is not Discrete!")

        # Loop over the episodes
        for episode in tqdm(range(num_train_episodes)):
            # Reset the environment and fill the frame buffer with the first frame
            state, info = env.reset()
            total_reward = 0
            frame_buffer = deque([state] * frame_stack, maxlen=frame_stack)
            current_action = None

            # Every step of the episode
            for step in range(max_steps):
                # Process the current state
                current_states = Agent.processing_states(frame_buffer)

                # A new action is chosen only on decision frames (every frame_stack
                # steps); in between, the previous action is repeated.
                is_decision_frame = step % frame_stack == 0
                if is_decision_frame:
                    current_action = Agent.select_action(current_states)["action"].item()
                    logger.debug(f"Selected action: {current_action}")

                # Execute the current action and update the frame buffer
                observation, reward, terminated, truncated, info = env.step(current_action)
                total_reward += reward
                frame_buffer.append(observation)
                next_states = Agent.processing_states(frame_buffer)
                logger.debug(f"Step {step + 1} | Reward: {reward} | Total Reward: {total_reward} | Terminated: {terminated} | Truncated: {truncated} | Info: {info}")

                # Decision frame: open a new transition in the memory buffer
                if is_decision_frame:
                    Agent.memory_buffer.append({"St": current_states, "At": current_action, "Rt": reward, "St+1": next_states, "Done": terminated})
                # Otherwise: extend the transition opened on the last decision frame
                else:
                    # Accumulate the reward
                    Agent.memory_buffer[-1]["Rt"] += reward
                    # Replace St+1 with the newest state
                    Agent.memory_buffer[-1]["St+1"] = next_states

                # Check whether the episode is over
                if terminated or truncated:
                    if total_reward >= reward_threshold:
                        logger.success(f"Episode finish, total step {step + 1} | Total Reward: {total_reward}")
                    else:
                        logger.warning(f"Episode finish, total step {step + 1} | Total Reward: {total_reward}")
                    # The episode may end on a non-decision frame: record it on the open transition
                    Agent.memory_buffer[-1]["Done"] = terminated
                    break

            # Update the model with the transitions of this episode
            Agent.update()

            # Save the model every save_freq episodes
            if (episode + 1) % save_freq == 0:
                episodes = episode + 1
                Agent.save_model(episodes)
    finally:
        # Release the emulator even when training is interrupted or fails
        env.close()

## Step7 评估智能体

In [None]:
# If evaluation is enabled
if is_evaluate == 1:
    # When recording, frames are grabbed as arrays and no window is shown;
    # otherwise the game window is rendered.
    eval_env = gym.make(env_id, render_mode="rgb_array" if need_record else "human")
    eval_env = AleCustomEnv(eval_env)

    try:
        # Best episode seen so far. Shaped rewards can be negative, so the
        # starting point must be below every possible total.
        best_frames = None
        max_reward = float("-inf")

        # Instantiate the agent used for evaluation
        Agent = RLAgent(action_size=eval_env.action_space.n)

        # Every episode
        for episode in tqdm(range(num_eval_episodes)):
            # Reset the environment and fill the frame buffer with the first frame
            state, info = eval_env.reset()
            total_reward = 0
            frame_buffer = deque([state] * frame_stack, maxlen=frame_stack)
            current_action = None
            # Frames of this episode only, so a recording never mixes episodes
            frame_record = []

            # Every step of the episode
            for step in range(max_steps):
                # Process the current state
                current_states = Agent.processing_states(frame_buffer)
                # Choose an action on decision frames only
                if step % frame_stack == 0:
                    with torch.no_grad():
                        output = Agent.select_action(current_states, sample=eval_sample_action)
                    # Sampling yields a one-element tensor, greedy selection a plain int
                    current_action = int(output["action"])
                # Execute the action
                observation, reward, terminated, truncated, info = eval_env.step(current_action)
                total_reward += reward
                # Update the frame buffer
                frame_buffer.append(observation)
                # When recording, render the frame and keep it
                if need_record:
                    frame_record.append(eval_env.render())
                # Check whether the episode is over
                if terminated or truncated:
                    break

            # Keep the recording of the best episode (also when max_steps was reached)
            if need_record and total_reward > max_reward:
                best_frames = frame_record
                max_reward = total_reward

            # Report the episode reward
            if total_reward >= reward_threshold:
                logger.success(f"Step {step + 1} | Total Reward: {total_reward}")
            else:
                logger.warning(f"Step {step + 1} | Total Reward: {total_reward}")

        # Save the evaluation result (only the episode with the best reward)
        if need_record:
            if best_frames:
                Path(save_dir).mkdir(parents=True, exist_ok=True)
                record_file = f"{os.path.abspath(os.path.join(save_dir, evaluate_record_perfix))}{int(max_reward)}.mp4"
                imageio.mimsave(record_file, best_frames, fps=evaluate_record_fps, quality=evaluate_record_quality)
                logger.info(f"The best evaluation record is: {record_file}")
            else:
                logger.warning("No frames were recorded, skip saving the evaluation video")
    finally:
        # Close the environment
        eval_env.close()
        pygame.quit()