## Step1 初始化 Jupyter 环境 & 导入包

In [None]:
# 用于在 Jupyter 中强制刷新参数
%reset -f

# 导入相关的包
import os
import sys
import copy
import random
from collections import deque
from pathlib import Path

import torch
from torchvision.transforms import v2
import ale_py
import pygame
import imageio
import gymnasium as gym
import numpy as np

from tqdm.notebook import tqdm
from torchvision.transforms import v2
from loguru import logger

## Step2 设置相关参数

In [None]:
# Feature switches
is_training = 1                     # whether to run training
is_evaluate = 1                     # whether to run evaluation (the game screen is rendered in this mode)
need_record = 1                     # whether to record a video; only effective when is_evaluate=1, and the game screen is then not rendered

# Log level
log_level = "INFO"
logger.remove()
logger.add(sys.stderr, level=log_level)

# Environment information
env_id = "ALE/Galaxian-v5"          # name of the game environment
env_high = 210                      # height of the game frame
env_weight = 160                    # width of the game frame (despite the "weight" spelling)
max_steps = 10000                   # maximum number of steps per episode
render_mode = "rgb_array"           # render mode, e.g. "human" or "rgb_array"

# DQN algorithm parameters
gamma = 0.99                        # discount factor
max_epsilon = 0.9                   # upper bound of epsilon
min_epsilon = 0.05                  # lower bound of epsilon
decay_rate = 0.001                  # decay rate of epsilon
net_target_sync_freq = 10           # how often the target network is synchronised
replay_buffer_size = 100000         # capacity of the experience replay buffer
frame_stack = 2                     # number of stacked frames
dqn_type = "Double DQN"             # DQN variant, one of "DQN" or "Double DQN"

# Training parameters
num_train_episodes = 3000           # total number of training episodes
num_train_steps = 100               # number of training steps per episode
lr = 1e-4                           # learning rate
batch_size = 64                     # mini-batch size

# Evaluation parameters
num_eval_episodes = 10              # number of evaluation episodes
reward_threshold = 1000             # reward threshold: above it the result is logged as Success, otherwise as Warning

# Saving policy
save_dir = "./Gym_ALE_Galaxian"                      # directory where data is saved
save_freq = 100                                      # how often the model is saved
max_checkpoints = 5                                  # maximum number of model files to keep
checkpoint_perfix = "DQN_Gym_ALE_Galaxian_"          # file-name prefix of saved models
evaluate_record_perfix = "DQN_Gym_ALE_Galaxian_"     # file-name prefix of evaluation recordings
evaluate_record_fps = 30                             # frame rate of evaluation recordings
evaluate_record_quality = 10                         # quality of evaluation recordings, in the range 0 ~ 10

# Remaining initialisation
device = "cuda" if torch.cuda.is_available() else "cpu"
gym.register_envs(ale_py)                            # Arcade Learning Environment (ALE) environments must be registered beforehand

## Step3 预处理函数 & 工具

In [None]:
def get_max_checkpoint_id(save_dir=save_dir, checkpoint_perfix=checkpoint_perfix):
    """
    Find the newest checkpoint file inside ``save_dir``.

    A checkpoint is a ``.pth`` file named ``<checkpoint_perfix><id>.pth`` where
    ``<id>`` is a decimal number. Files that do not follow this pattern are
    ignored. Ids are compared numerically, so ``1000`` is newer than ``900``.

    Args:
        save_dir: Directory that holds the checkpoints. It is created when missing.
        checkpoint_perfix: File-name prefix that precedes the numeric id.

    Returns:
        ``None`` when the directory had to be created or holds no checkpoint,
        otherwise a dict with the keys ``"max_checkpoint_path"`` (absolute path)
        and ``"max_checkpoint_id"`` (the id string exactly as it appears in the
        file name).
    """
    directory = Path(save_dir)

    # A missing directory cannot contain checkpoints: create it and bail out.
    if not directory.exists():
        directory.mkdir(parents=True)
        logger.debug("The specified directory does not exist, will create this folder")
        return None

    # Collect the id part of every file that matches "<prefix><digits>.pth".
    checkpoint_ids = []
    for entry in directory.iterdir():
        if not entry.is_file() or entry.suffix != ".pth":
            continue
        stem = entry.stem
        if not stem.startswith(checkpoint_perfix):
            continue
        checkpoint_id = stem[len(checkpoint_perfix):]
        if checkpoint_id.isdecimal():
            checkpoint_ids.append(checkpoint_id)

    if not checkpoint_ids:
        logger.info("Not found any checkpoint files, will random initialization of network parameters")
        return None

    # Compare numerically; a plain string max() would rank "900" above "1000".
    max_checkpoint_id = max(checkpoint_ids, key=int)
    max_checkpoint_path = os.path.abspath(f"{save_dir}/{checkpoint_perfix}{max_checkpoint_id}.pth")
    logger.info(f"Found max checkpoints, max_checkpoint_id is {max_checkpoint_id}")
    return {"max_checkpoint_path": max_checkpoint_path, "max_checkpoint_id": max_checkpoint_id}

In [None]:
def del_old_checkpoint(save_dir=save_dir, checkpoint_perfix=checkpoint_perfix, max_checkpoints=max_checkpoints):
    """
    Delete old checkpoint files so that at most ``max_checkpoints`` remain.

    Only files named ``<checkpoint_perfix><digits>.pth`` are considered; the
    ones with the smallest numeric ids are removed first. Nothing happens when
    ``save_dir`` does not exist.

    Args:
        save_dir: Directory that holds the checkpoints.
        checkpoint_perfix: File-name prefix that precedes the numeric id.
        max_checkpoints: Number of newest checkpoints to keep.
    """
    directory = Path(save_dir)
    if not directory.exists():
        return

    # (numeric id, path) for every file that looks like one of our checkpoints.
    found = []
    for entry in directory.iterdir():
        if not entry.is_file() or entry.suffix != ".pth":
            continue
        stem = entry.stem
        if not stem.startswith(checkpoint_perfix):
            continue
        checkpoint_id = stem[len(checkpoint_perfix):]
        if checkpoint_id.isdecimal():
            found.append((int(checkpoint_id), entry))

    # Oldest first; everything beyond the allowed count gets removed.
    found.sort(key=lambda item: item[0])
    excess = max(len(found) - max_checkpoints, 0)
    for _, entry in found[:excess]:
        min_checkpoint_path = os.path.abspath(entry)
        os.remove(min_checkpoint_path)
        logger.warning(f"Delete old checkpoint file {min_checkpoint_path}")

## Step4 定义智能体

In [None]:
class RLAgent:
    """
    DQN agent: bundles the online/target networks, the replay buffer,
    frame preprocessing and checkpoint handling.
    """
    def __init__(self, action_size):
        """
        Build the networks and restore the newest checkpoint if one exists.

        Args:
            action_size: Number of discrete actions, i.e. the width of the
                network's output layer.
        """
        # Global Args
        self.max_checkpoint = get_max_checkpoint_id()
        self.current_epsilon = 1.0
        self.replay_buffer = deque(maxlen=replay_buffer_size)

        # Init DQN Network
        self.network = torch.nn.Sequential(
            torch.nn.Conv2d(in_channels=frame_stack, out_channels=frame_stack * 2, kernel_size=3, stride=1, padding=1),
            torch.nn.LeakyReLU(),
            torch.nn.Conv2d(in_channels=frame_stack * 2, out_channels=frame_stack * 4, kernel_size=3, stride=2, padding=1),
            torch.nn.LeakyReLU(),
            torch.nn.Flatten(),
            torch.nn.LazyLinear(out_features=1024),
            torch.nn.LeakyReLU(),
            torch.nn.LazyLinear(out_features=1024),
            torch.nn.LeakyReLU(),
            torch.nn.LazyLinear(out_features=action_size),
        )

        # Move to designated device before the lazy layers are materialised
        self.network.to(device)

        # Dry run: materialise the LazyLinear layers so that the target network
        # is an exact copy and the optimizer receives real parameters.
        with torch.no_grad():
            self.network(torch.zeros(1, frame_stack, env_high, env_weight, device=device))

        if self.max_checkpoint is not None:
            # map_location lets a checkpoint saved on GPU be loaded on a CPU-only host.
            state_dict = torch.load(self.max_checkpoint["max_checkpoint_path"], map_location=device)
            self.network.load_state_dict(state_dict)

        # Init DQN Target Network (starts with the same weights as the online network)
        self.target_network = copy.deepcopy(self.network)
        self.target_network.to(device)

        # Transforms: to image tensor -> float32 scaled to [0, 1] -> single grayscale channel
        self.transform = v2.Compose([
            v2.ToImage(),
            v2.ToDtype(torch.float32, scale=True),
            v2.Grayscale(1),
        ])

        # Loss function and optimizer
        self.loss_fn = torch.nn.MSELoss()
        self.optimizer = torch.optim.AdamW(self.network.parameters(), lr=lr)

    def sync_target_network(self):
        """
        Copy the online network's weights into the target network.
        """
        self.target_network.load_state_dict(self.network.state_dict())
        logger.info("The target network has been synchronized")

    def select_action(self, state):
        """
        Return the greedy action (index of the largest Q value) for ``state``.

        Args:
            state: Tensor produced by ``processing_states`` (batch of one).

        Returns:
            The selected action as a Python int.
        """
        state = state.to(device)
        # Inference only: no autograd graph is needed here.
        with torch.no_grad():
            action = self.network(state).argmax(dim=1).item()
        return action

    def processing_states(self, frame_buffer):
        """
        Preprocess ``frame_buffer`` and return a tensor the network can consume.

        Args:
            frame_buffer: Collection of ``frame_stack`` raw frames.

        Returns:
            Tensor reshaped to ``[1, frame_stack, env_high, env_weight]``.
        """
        # Shape becomes [batch_size, color_channel * stack_size, height, width]
        states = torch.stack(tuple(self.transform(frame_buffer)), dim=0)
        states = states.reshape(1, frame_stack, env_high, env_weight)
        logger.debug(f"Processing states shape: {states.shape}")
        return states

    def save_model(self, episodes):
        """
        Save the online network to ``save_dir`` and prune old checkpoints.

        Args:
            episodes: Number of episodes finished in the current run. When
                training resumed from a checkpoint, its id is added so that
                checkpoint ids keep increasing.
        """
        if self.max_checkpoint is None:
            # No checkpoint existed at start-up, i.e. training started from scratch.
            checkpoint_id = episodes
        else:
            # Continue numbering after the checkpoint this agent was restored from.
            checkpoint_id = episodes + int(self.max_checkpoint["max_checkpoint_id"])
        max_checkpoint_path = os.path.abspath(f"{save_dir}/{checkpoint_perfix}{checkpoint_id}.pth")

        # Make sure the target directory is still there before writing.
        Path(save_dir).mkdir(parents=True, exist_ok=True)
        torch.save(self.network.state_dict(), max_checkpoint_path)
        logger.info(f"Model saved to {max_checkpoint_path}")
        # Remove old models
        del_old_checkpoint()

## Step5 调整环境

In [None]:
# Customised environment
class AleCustomEnv(gym.Wrapper):
    """
    Thin wrapper around an ALE environment, derived from gym.Wrapper.

    It currently forwards everything unchanged and exists as the single
    place to adjust the environment's behaviour or reward scheme.
    """
    def __init__(self, env):
        super().__init__(env)

    def step(self, action):
        # Delegate to the wrapped environment; tweak the transition here
        # if the behaviour or the reward needs to be customised.
        transition = self.env.step(action)
        obs, rew, is_terminated, is_truncated, extra = transition

        return obs, rew, is_terminated, is_truncated, extra


## Step6 训练智能体

In [None]:
if is_training:
    # Main environment used for training
    env = gym.make(env_id, render_mode=render_mode)
    env = AleCustomEnv(env)

    # Instantiate the agent (the action space must be discrete)
    if isinstance(env.action_space, gym.spaces.Discrete):
        action_size = env.action_space.n
        Agent = RLAgent(action_size=action_size)
    else:
        logger.error("Action space is not Discrete!")
        raise ValueError("Action space is not Discrete!")

    # Loop over episodes
    for episode in tqdm(range(num_train_episodes)):
        # Reset the environment
        state, info = env.reset()
        total_reward = 0
        frame_buffer = deque(maxlen=frame_stack)
        current_action = None

        # Fill the frame buffer with the initial frame
        for _ in range(frame_stack):
            frame_buffer.append(state)

        # Epsilon decays exponentially with the episode index
        epsilon = min_epsilon + (max_epsilon - min_epsilon) * np.exp(-decay_rate * episode)
        logger.info(f"Episode {episode + 1}, Current Epsilon: {epsilon:.4f}")

        # Sampling phase: store every step of the episode in the replay buffer.
        # The preprocessed state is carried over between steps so each frame
        # stack is preprocessed only once.
        current_states = Agent.processing_states(frame_buffer)
        for step in range(max_steps):
            # A new action is chosen every `frame_stack` steps and repeated in between
            if step % frame_stack == 0:
                random_num = random.uniform(0, 1)
                if random_num > epsilon:
                    action = Agent.select_action(current_states)
                    current_action = action
                    logger.debug(f"Selected action: {action}")
                else:
                    action = env.action_space.sample()
                    current_action = action
                    logger.debug(f"Random action: {action}")

            # Execute the action
            observation, reward, terminated, truncated, info = env.step(current_action)
            total_reward += reward
            logger.debug(f"Step {step + 1} | Reward: {reward} | Total Reward: {total_reward} | Terminated: {terminated} | Truncated: {truncated} | Info: {info}")

            # Update the frame buffer and preprocess the resulting state
            frame_buffer.append(observation)
            next_states = Agent.processing_states(frame_buffer)

            # Store the transition in the replay buffer
            Agent.replay_buffer.append({"St": current_states, "At": current_action, "Rt": reward, "St+1": next_states, "Done": terminated or truncated})
            current_states = next_states

            # Stop when the episode is over
            if terminated or truncated:
                if total_reward > reward_threshold:
                    logger.success(f"Step {step + 1} | Total Reward: {total_reward}")
                else:
                    logger.warning(f"Step {step + 1} | Total Reward: {total_reward}")
                total_reward = 0
                break

        # Training phase. The loss accumulators live outside the loop so the
        # logged value really is the average over all optimisation steps.
        loss_avg = 0
        loss_num = 0
        for step in range(num_train_steps):
            # Load data
            if len(Agent.replay_buffer) > batch_size:
                mini_batch = random.sample(Agent.replay_buffer, batch_size)
            else:
                mini_batch = Agent.replay_buffer
            # size: [bs, frame_stack, height, width]
            batch_state = torch.cat([data["St"] for data in mini_batch], dim=0).to(device)
            # size: [bs, 1] -> shaped this way to be compatible with gather
            batch_action = torch.tensor([data["At"] for data in mini_batch]).unsqueeze(1).to(device)
            # size: [bs]; float32 keeps the target dtype equal to the network output dtype
            batch_reward = torch.tensor([data["Rt"] for data in mini_batch], dtype=torch.float32).to(device)
            # size: [bs, frame_stack, height, width]
            batch_next_s = torch.cat([data["St+1"] for data in mini_batch], dim=0).to(device)
            # size: [bs]
            batch_done = torch.tensor([data["Done"] for data in mini_batch]).float().to(device)

            # Q value of the action actually taken.
            # gather needs batch_action shaped [bs, 1]; squeeze(1) drops the extra dim afterwards.
            q_values = Agent.network(batch_state).gather(1, batch_action).squeeze(1)

            # The bootstrap target is a constant for the update, so no gradients are tracked.
            with torch.no_grad():
                if dqn_type == "DQN":
                    # The target network estimates the max Q value of the next state
                    next_q_values = Agent.target_network(batch_next_s).max(1)[0]
                elif dqn_type == "Double DQN":
                    # Double DQN: the online network selects the next action and the
                    # target network evaluates it, which reduces Q over-estimation.
                    next_actions = Agent.network(batch_next_s).argmax(1)
                    next_q_values = Agent.target_network(batch_next_s).gather(1, next_actions.unsqueeze(1)).squeeze(1)
                else:
                    raise ValueError(f"Unsupported dqn_type: {dqn_type}")
                # When the episode ended, the target is the reward alone
                target_q_values = batch_reward + gamma * next_q_values * (1 - batch_done)

            loss = Agent.loss_fn(q_values, target_q_values)
            loss_value = loss.item()
            loss_avg += loss_value
            loss_num += 1
            loss.backward()
            Agent.optimizer.step()
            Agent.optimizer.zero_grad()

        # max(..., 1) avoids a division by zero when num_train_steps is 0
        logger.info(f"Reploy_Buffer size: {len(Agent.replay_buffer)}; Loss Avg: {loss_avg/max(loss_num, 1):.4f}")

        # Synchronise the target network
        if (episode + 1) % net_target_sync_freq == 0 and episode != 0:
            Agent.sync_target_network()

        # Save the model
        if (episode + 1) % save_freq == 0 and episode != 0:
            episodes = episode + 1
            Agent.save_model(episodes)

    # Release the training environment
    env.close()

## Step7 评估智能体

In [None]:
# The whole evaluation is skipped when is_evaluate is off
if is_evaluate:
    # Recording needs raw frames ("rgb_array"); otherwise render to a window ("human")
    eval_render_mode = "rgb_array" if need_record else "human"
    eval_env = gym.make(env_id, render_mode=eval_render_mode)
    eval_env = AleCustomEnv(eval_env)

    # Best episode seen so far (None until an episode has been recorded)
    max_reward = None
    np_frame_record = None

    # Instantiate the agent used for evaluation
    Agent = RLAgent(action_size=eval_env.action_space.n)

    # Each episode
    for episode in tqdm(range(num_eval_episodes)):
        # Reset the environment
        state, info = eval_env.reset()
        total_reward = 0
        steps_taken = 0
        frame_buffer = deque(maxlen=frame_stack)
        current_action = None
        # Frames are collected per episode so a recording never mixes episodes
        frame_record = []
        # Fill the frame buffer with the initial frame
        for _ in range(frame_stack):
            frame_buffer.append(state)

        # Each step of the episode
        for step in range(max_steps):
            steps_taken = step + 1
            # Preprocess the current state
            current_states = Agent.processing_states(frame_buffer)
            # A new action is chosen every `frame_stack` steps and repeated in between
            if step % frame_stack == 0:
                action = Agent.select_action(current_states)
                current_action = action
            # Execute the action
            observation, reward, terminated, truncated, info = eval_env.step(current_action)
            total_reward += reward
            # Update the frame buffer
            frame_buffer.append(observation)
            # Cache the frame when recording, otherwise render the screen
            if need_record:
                frame_record.append(observation)
            else:
                eval_env.render()

            if terminated or truncated:
                break

        # End of episode (also reached when max_steps ran out): keep the best recording
        if need_record and frame_record and (max_reward is None or total_reward > max_reward):
            np_frame_record = np.array(frame_record)
            max_reward = total_reward
        # Report the episode reward
        if total_reward > reward_threshold:
            logger.success(f"Step {steps_taken} | Total Reward: {total_reward}")
        else:
            logger.warning(f"Step {steps_taken} | Total Reward: {total_reward}")

    # Save the evaluation result (only the episode with the best reward)
    if need_record and np_frame_record is not None:
        record_file = f"{os.path.abspath(os.path.join(save_dir, evaluate_record_perfix))}{int(max_reward)}.mp4"
        imageio.mimsave(record_file, np_frame_record, fps=evaluate_record_fps, quality=evaluate_record_quality)
        logger.info(f"The best evaluation record is: {record_file}")

    # Close the environment
    eval_env.close()
    pygame.quit()
