# Environment

In [1]:
import matplotlib
matplotlib.use('TkAgg')

In [2]:
import torch
import copy
import numpy as np
import numpy.random as npr
import random
import matplotlib.pyplot as plt 
import time


In [3]:
device = torch.device(
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)

# tool functions
def output_agents_info(states):
    states_cp = copy.deepcopy(states)
    for agent_idx, st in states_cp.items():
        print(f"Agent {agent_idx}: {(st[0], st[1])}, {st[4]}")

# 全局变量 用于收集每 episode 的统计数据
episode_losses = []
episode_collisions = []
episode_epsilons = []

def drawing_plots(episodes, losses, collisions, epsilons):
    """
    绘制训练统计图表：每集平均 loss、每集碰撞数、epsilon 衰减曲线。
    """
    fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(15, 4))

    # Loss 折线
    ax1.plot(episodes, losses, linewidth=1)
    ax1.set_title("Per-Episode Average Loss")
    ax1.set_xlabel("Episode")
    ax1.set_ylabel("Loss")

    # Collisions 折线
    ax2.plot(episodes, collisions, linewidth=1)
    ax2.set_title("Per-Episode Collisions")
    ax2.set_xlabel("Episode")
    ax2.set_ylabel("Collisions")

    # Epsilon 衰减折线
    ax3.plot(episodes, epsilons, linewidth=1)
    ax3.set_title("Epsilon Decay")
    ax3.set_xlabel("Episode")
    ax3.set_ylabel("Epsilon")

    plt.tight_layout()
    plt.show()


In [4]:
# Define the grid world environment
class GridWorldEnvironment:
    def __init__(self, size=5, agents_num=4):
        self.size = size
        self.agents_num = agents_num
        self.agents_positions = {}  # agent position
        self.agents_reached_A = {}  # if agents get item
        self.A_position = None
        self.B_position = (size - 1, size - 1)  # fixed location of B
        self.directions = ["up", "down", "left", "right"]
        self.total_collisions = 0
        self.total_steps = 0
        self.agents_idx = list(range(agents_num))
        # 存储 agent 上一步的动作
        self.last_action = {i: None for i in range(agents_num)}
        self._reset()

    def _reset(self, A_position=None):
        """
        Reset the environment to its initial state.
        """
        if A_position is not None:
            # when testing, set A position
            self.A_position = A_position
        else:
            # initialize A position
            self.A_position = (
                npr.randint(0, self.size - 1),
                npr.randint(0, self.size - 1),
            )
            # ensure A and B are not in the same position
            while self.A_position == self.B_position:
                self.A_position = (
                    npr.randint(0, self.size - 1),
                    npr.randint(0, self.size - 1),
                )

        # initialize agents' positions and reached_A status
        self.agents_positions = {}
        self.agents_reached_A = {}
        for idx in self.agents_idx:
            if npr.rand() > 0.5:
                self.agents_positions[idx] = self.A_position
                self.agents_reached_A[idx] = True
            else:
                self.agents_positions[idx] = self.B_position
                self.agents_reached_A[idx] = False

        self.total_collisions = 0
        self.total_steps = 0

    def _get_destination(self, agent_idx):
        """
        Get the destination position(A or B)
        """
        return "B" if self.agents_reached_A[agent_idx] else "A"

    def _find_nearby_collision_agents(self, agent_id):
        """
        Find nearby agents that might collide.
        """
        y, x = self.agents_positions[agent_id]
        destination_cur = self._get_destination(agent_id)
        nearby_agents = [
            (-1, -1),
            (-1, 0),
            (-1, 1),
            (0, -1),
            (0, 1),
            (1, -1),
            (1, 0),
            (1, 1),
        ]
        collision_status = []
        for dy, dx in nearby_agents:
            new_y, new_x = y + dy, x + dx
            # Check if new position is valid
            if 0 <= new_y < self.size and 0 <= new_x < self.size:
                has_agent = 0
                for other_agent_id in self.agents_idx:
                    if (
                        other_agent_id != agent_id
                        and self.agents_positions[other_agent_id] == (new_y, new_x)
                        and self._get_destination(other_agent_id) != destination_cur
                    ):  # agents are going to the same destination would cause collision
                        has_agent = 1
                collision_status.append(has_agent)
            else:
                collision_status.append(0)
        return collision_status

    def get_state(self, agent_idx):
        """
        Get the state of the environment for a specific agent.
        """
        position = self.agents_positions[agent_idx]
        reached_A = self.agents_reached_A[agent_idx]
        collision_agents = self._find_nearby_collision_agents(agent_idx)

        return np.array(
            [
                *position,  # (x, y)
                *self.A_position,  # (A_x, A_y)
                reached_A,
                *collision_agents,
            ]
        )

    def _check_done(self, agent_idx, test_flag=False):
        """
        Check if the agent has reached its destination.
        """
        if test_flag:
            print(
                f"Agent {agent_idx} | Position: {self.agents_positions[agent_idx]} | Reached A: {self.agents_reached_A[agent_idx]}"
            )
        if (
            self.agents_positions[agent_idx] == self.B_position
            and self.agents_reached_A[agent_idx]
        ):  # already at B and has item
            self.agents_reached_A[agent_idx] = False  # reset
            return True

    def take_action(self, action_dict, test_flag=False):
        """
        Take an action in the environment and return the next state, reward and collosions.
        """
        # print(f"Action dict: {action_dict}")
        planned_actions = {}  # {action_idx: action}
        # wall_collisions = []  # number of hitting wall
        if test_flag:
            print(f"    Next Action dict: ")
            for agent_idx, action in action_dict.items():
                print(f"    Agent {agent_idx}: {self.directions[action]}")

        for idx, a in action_dict.items():
            self.last_action[idx] = a  # 更新last_action

        for agent_idx, action in action_dict.items():
            y, x = self.agents_positions[agent_idx]
            if self.directions[action] == "up":
                new_y, new_x = y - 1, x
            elif self.directions[action] == "down":
                new_y, new_x = y + 1, x
            elif self.directions[action] == "left":
                new_y, new_x = y, x - 1
            elif self.directions[action] == "right":
                new_y, new_x = y, x + 1

            # check valid
            if 0 <= new_y < self.size and 0 <= new_x < self.size:
                planned_actions[agent_idx] = (new_y, new_x)  # move
                # wall_collisions.append(False)
            else:
                planned_actions[agent_idx] = (y, x)  # not move
                # wall_collisions.append(True)

        # check collision
        next_positions = copy.deepcopy(self.agents_positions)
        collisions = 0  # number of head-on collisions
        positions_agents_dict = {}  # agents in cells {(x, y): [agent_idx]}
        for idx in sorted(self.agents_idx):
            next_positions[idx] = planned_actions[idx]

        for agent_idx, pos in next_positions.items():
            if pos not in positions_agents_dict:
                positions_agents_dict[pos] = []
            positions_agents_dict[pos].append(agent_idx)

        agents_collisions = set()  # store agents that have collisions
        for pos, agents_cur in positions_agents_dict.items():
            if pos == self.A_position or pos == self.B_position:
                continue  # ignore A or B
            if len(agents_cur) > 1:
                dirs = [self._get_destination(a) for a in agents_cur]
                if "B" in dirs and "A" in dirs:  # head-on collision in same cell
                    collisions += 1
                    agents_collisions.update(agents_cur)
        # print(
        #     f"=====\nCurrent positions: {self.agents_positions}\nPlanned positions: {planned_actions}\nAgents with collisions: {agents_collisions}, Collisions: {collisions}\n====="
        # )

        # calculate rewards
        rewards = {}
        for agent_idx in self.agents_idx:
            reward = 0
            next_location = next_positions[agent_idx]
            # 碰撞惩罚，和其他reward互斥
            if agent_idx in agents_collisions:
                reward += -20
            else:
                # 正常的 pick-up / drop-off 奖励
                if not self.agents_reached_A[agent_idx]:
                    if next_location == self.A_position:
                        reward += 5
                    else:
                        reward += -0.1
                elif (
                    self.agents_reached_A[agent_idx]
                    and next_location == self.B_position
                ):
                    reward += 10
                else:
                    reward += -0.1

            rewards[agent_idx] = reward  # store reward

        # update agents' positions
        self.agents_positions = next_positions

        # accumulate total collisions and steps
        self.total_collisions += collisions
        self.total_steps += self.agents_num

        # Update item-carrying status after moving - 这里不更新 在check_done里更新
        for agent_idx in self.agents_idx:
            if (
                self.agents_reached_A[agent_idx]
                and self.agents_positions[agent_idx] == self.B_position
            ):
                pass
                # self.agents_reached_A[agent_idx] = False  # delivered item at B
            elif (not self.agents_reached_A[agent_idx]) and self.agents_positions[
                agent_idx
            ] == self.A_position:
                self.agents_reached_A[agent_idx] = True  # picked up item at A

        # format next state
        next_states = {}
        for agent_idx in self.agents_idx:
            next_states[agent_idx] = self.get_state(agent_idx)

        # LOG
        # print(f"Agent {agent_idx} | Position: {next_location} | Reward: {reward}")

        return next_states, rewards, collisions

    def get_valid_actions(self, agent_idx):
        """
        Get the valid actions(do not hit the wall) for a specific agent.
        """
        y, x = self.agents_positions[agent_idx]
        valid = []
        for a, d in enumerate(self.directions):
            ny, nx = {
                "up": (y - 1, x),
                "down": (y + 1, x),
                "left": (y, x - 1),
                "right": (y, x + 1),
            }[d]
            if 0 <= ny < self.size and 0 <= nx < self.size:
                valid.append(a)

        # 禁止与 last_action 动作相反，避免走重复的路径
        oppsite_direction = {0: 1, 1: 0, 2: 3, 3: 2}
        last = self.last_action[agent_idx]
        if last is not None and oppsite_direction[last] in valid:
            valid.remove(oppsite_direction[last])
        return valid
    
    # def get_legal_actions_mask(self, agent_idx):
    #     """
    #      返回一个 shape=(4,) 的 bool 数组，指示这个 agent 哪些动作
    #     （0=up,1=down,2=left,3=right）会在下一步造成 head-on 碰撞，
    #      应当被屏蔽（mask=False），其余 mask=True。
    #     """
    #     # 当前 agent 的位置 & 目标方向
    #     y, x = self.agents_positions[agent_idx]
    #     dest_i = self._get_destination(agent_idx)

    #     mask = np.ones(len(self.directions), dtype=bool)
    #     for a, dir_str in enumerate(self.directions):
    #         # 预测自己执行 a 之后的位置
    #         if dir_str == "up":
    #             ny, nx = y - 1, x
    #         elif dir_str == "down":
    #             ny, nx = y + 1, x
    #         elif dir_str == "left":
    #             ny, nx = y, x - 1
    #         else:  # "right"
    #             ny, nx = y, x + 1

    #         # 越界的动作也当作非法
    #         if not (0 <= ny < self.size and 0 <= nx < self.size):
    #             mask[a] = False
    #             continue

    #         # 如果目标格上有一个逆向 agent，就视为 head-on 碰撞
    #         for other in self.agents_idx:
    #             if other == agent_idx:
    #                 continue
    #             oy, ox = self.agents_positions[other]
    #             dest_o = self._get_destination(other)
    #             # 对向：一个要去 A，一个要去 B
    #             if (ny, nx) == (oy, ox) and dest_i != dest_o:
    #                 mask[a] = False
    #                 break

    #     return mask


# Agent

In [5]:
# deep q-learning agent
class Agent:
    def __init__(
        self,
        statespace_size,
        action_size,
        gamma=0.99,
        epsilon=1.0,
        epsilon_decay=0.9999,
        min_epsilon=0.1,
        batch_size=256,
        replay_buffer_size=50000,  # 5w
        lr=1e-3,  # 1e-3 → 5e-4 → 1e-4 → 5e-5
        copy_frequency=50,
    ):
        self.statespace_size = statespace_size
        self.action_size = action_size
        self.gamma = gamma
        self.epsilon = epsilon
        self.epsilon_decay = epsilon_decay
        self.min_epsilon = min_epsilon
        self.batch_size = batch_size
        self.replay_buffer_size = replay_buffer_size
        self.lr = lr
        self.copy_frequency = copy_frequency

        self.steps = 0  # count agent's steps
        self.replay_buffer = []  # memory

        # initialize the DQN
        self.model, self.model2, self.optimizer, self.loss_fn = self.prepare_torch()

        # set the device
        self.model.to(device)
        self.model2.to(device)

    def prepare_torch(self):
        l1, l2, l3, l4 = self.statespace_size, 128, 128, self.action_size
        model = torch.nn.Sequential(
            torch.nn.Linear(l1, l2),
            torch.nn.ReLU(),
            torch.nn.Linear(l2, l3),
            torch.nn.ReLU(),
            torch.nn.Linear(l3, l4),
        )
        model2 = copy.deepcopy(model)
        model2.load_state_dict(model.state_dict())
        loss_fn = torch.nn.MSELoss()
        optimizer = torch.optim.Adam(model.parameters(), lr=self.lr)
        return model, model2, optimizer, loss_fn

    def update_target(self):
        self.model2.load_state_dict(self.model.state_dict())

    def get_qvals(self, state):
        state_tensor = torch.from_numpy(state).float().to(device)
        qvals_torch = self.model(state_tensor)
        qvals = qvals_torch.detach().numpy()
        return qvals

    def get_maxQ(self, s):
        # return torch.max(self.model2(torch.from_numpy(s).float())).detach().numpy()
        s_t = torch.from_numpy(s).float().to(device)
        return torch.max(self.model2(s_t)).detach().cpu().numpy()

    def get_action(self, state):
        if npr.uniform() < self.epsilon:
            action = npr.choice(self.action_size)
        else:
            qvals = self.get_qvals(state)
            action = np.argmax(qvals)
        return action

    # def get_action(self, state, legal_actions_mask):
    #     if np.random.rand() < self.epsilon:
    #         # only pick actions without collision
    #         return np.random.choice(np.where(legal_actions_mask)[0])
    #     q = self.get_qvals(state)
    #     q[~legal_actions_mask] = -np.inf
    #     return np.argmax(q)

    def get_greedy_action(self, state, valid_actions):
        qvals = self.get_qvals(state)
        masked = np.full_like(qvals, -np.inf)
        masked[valid_actions] = qvals[valid_actions]
        return int(np.argmax(masked))

    def store_transition(self, state, action, reward, next_state):
        """
        Store the transition in the replay buffer.
        """
        if len(self.replay_buffer) >= self.replay_buffer_size:
            # random remove sample
            remove_idx = npr.randint(0, len(self.replay_buffer))
            self.replay_buffer.pop(remove_idx)
        self.replay_buffer.append((state, action, reward, next_state))

    def train(self):
        """
        Train the agent using the replay buffer.
        """
        if len(self.replay_buffer) < self.batch_size:
            return  # samples not enough

        # sample a batch from the replay buffer
        minibatch = random.sample(
            self.replay_buffer,
            self.batch_size,
        )
        states, actions, rewards, next_states = zip(*minibatch)

        # TD targets
        targets = []
        for i in range(len(minibatch)):
            next_maxQ = self.get_maxQ(next_states[i])
            action_target = rewards[i] + self.gamma * next_maxQ
            targets.append(action_target)

        # train the model
        loss = self.train_one_step(states, actions, targets, self.gamma)

        # update network periodically
        self.steps += 1
        if self.steps % self.copy_frequency == 0:
            self.update_target()

        return loss

    def train_one_step(self, states, actions, targets, gamma):
        state1_batch = torch.tensor(np.array(states), dtype=torch.float32)
        action_batch = torch.tensor(np.array(actions), dtype=torch.float32)
        Q1 = self.model(state1_batch)
        X = Q1.gather(dim=1, index=action_batch.long().unsqueeze(dim=1)).squeeze()
        Y = torch.tensor(np.array(targets), dtype=torch.float32)
        loss = self.loss_fn(X, Y)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        return loss.item()

    # decay epsilon
    def decay_epsilon(self):
        self.epsilon = max(self.min_epsilon, self.epsilon * self.epsilon_decay)

# Training

In [6]:
def train_agents(
    agent, env, max_steps=1500000, max_collisions=4000, max_walltime=600, verbose=True
):
    print("开始训练，配置参数：")
    print(f"最大步数: {max_steps}")
    print(f"最大碰撞数: {max_collisions}")
    print(f"最大训练时间: {max_walltime}秒")
    """Training each agent in the environment."""
    # start time
    start_time = time.time()

    # global variables
    total_collisions = 0
    total_steps = 0
    episode = 0

    while total_collisions <= max_collisions and total_steps <= max_steps:
        if time.time() - start_time > max_walltime:
            print("===== Time limit exceeded. =====")
            break

        # *** variables for statistics  ***
        collisions_before = total_collisions  # store the collisions of the last episode
        loss_in_episode = []  # store the loss of the current episode
        # *** variables for statistics  ***

        # initialize the environment
        env._reset()
        # print(
        #     f"===== 主训练循环开始 episode: {episode}, A 点位置: {env.A_position} ====="
        # )
        # initial states of four agents
        states = {agent_idx: env.get_state(agent_idx) for agent_idx in env.agents_idx}

        # episode finish flag
        done = False
        max_steps_episode = 250
        episode_steps = 0
        loss = None
        while not done:
            # print(f"===== 主训练内循环开始 =====")
            # print(f"当前Agents位置\n")
            # output_agents_info(states)
            if episode_steps >= max_steps_episode:
                print("===== 触发了max_steps_episode退出 =====")
                done = True
                break
            actions_dict = {}
            for agent_idx in sorted(env.agents_idx):  # central clock - fix order
                action = agent.get_action(states[agent_idx])
                actions_dict[agent_idx] = action
                if len(agent.replay_buffer) >= agent.batch_size:
                    agent.decay_epsilon()

            # take action in the environment
            next_states, rewards, collisions = env.take_action(actions_dict)
            # print(f"[Step {total_steps + len(env.agents_idx)}] rewards: {rewards}")
            # print(f"下一Agents位置\n")
            # output_agents_info(next_states)

            # store transition in replay buffer
            for agent_idx in sorted(env.agents_idx):
                state = states[agent_idx]
                action = actions_dict[agent_idx]
                reward = rewards[agent_idx]
                next_state = next_states[agent_idx]
                agent.store_transition(state, action, reward, next_state)

            # train the agent
            if len(agent.replay_buffer) >= agent.batch_size:
                loss = agent.train()
                loss_in_episode.append(loss)

            # update the total collisions and steps
            total_collisions += collisions
            total_steps += len(env.agents_idx)

            # LOG
            if verbose and total_steps % 5000 == 0:
                elapsed = time.time() - start_time
                print(
                    f"Steps: {total_steps}/{max_steps}, "
                    f"Collisions: {total_collisions}/{max_collisions}, "
                    f"Epsilon: {agent.epsilon:.3f}, "
                    f"Time Elapsed: {elapsed:.1f}s, "
                    f"Loss: {loss}"
                )

            # check if any agent has done the task
            # for i in sorted(env.agents_idx):
            #     if env._check_done(i):
            #         # print("===== 触发了check_done退出 =====")
            #         done = True
            #         break

            # check if the training should stop
            if (
                total_steps >= max_steps
                or total_collisions >= max_collisions
                or time.time() - start_time > max_walltime
            ):
                print("===== 触发了步数、碰撞、超时退出 =====")
                done = True
                break

            # update the states
            states = next_states
            episode_steps += 1

        # agent.decay_epsilon()
        episode += 1

        # record episode statistics data
        avg_loss = float(np.mean(loss_in_episode)) if loss_in_episode else 0.0
        episode_losses.append(avg_loss)
        episode_collisions.append(total_collisions - collisions_before)
        episode_epsilons.append(agent.epsilon)

        # log
        # if verbose and episode % 10 == 0:
        #     elapsed = time.time() - start_time
        #     print(
        #         f"Episode: {episode}, "
        #         f"TotalSteps: {total_steps}, "
        #         f"TotalCollisions: {total_collisions}, "
        #         f"Epsilon: {agent.epsilon:.3f}, "
        #         f"Elapsed: {elapsed:.1f}s"
        #     )

    print("Training completed.")
    print(f"Total steps: {total_steps}")
    print(f"Total collisions: {total_collisions}")
    print(f"Final epsilon: {agent.epsilon:.3f}")

    # return the training results statistics to plot
    return {
        "episodes": list(range(1, episode + 1)),
        "losses": episode_losses,
        "collisions": episode_collisions,
        "epsilons": episode_epsilons,
    }

# Test

In [7]:
def test_agents(agent, env, max_steps=25, step_verbose=True):
    """
    Test the trained agent in the environment.
    """
    # initialize the parmeters
    agent.epsilon = 0  # no exploration
    A_positions_num = 24  # number of A positions

    def test_24_scenarios():
        """测试A的24种情况 B点固定在右下角 轮流初始化某个agent(agent_B)在B 其他三个在A 观察agent_B是否可以成功送达"""
        # for each scrneario, all of 4 agents could be delivered successfully, then count as 1
        success_times = 0
        # sum up steps for every successful delivery of all test scenarios
        success_steps_used = 0
        # sum up collisions for every successful delivery of all test scenarios
        total_collisions = 0
        # 固定B点在右下角，A点位置有24个可能
        all_positions = [(i, j) for i in range(env.size) for j in range(env.size)]
        A_positions = [pos for pos in all_positions if pos != env.B_position]
        for i, A_pos in enumerate(A_positions):
            # 强制设定 A 点位置
            env._reset(A_pos)
            # define a dic store agents who has delivery successfully and its steps
            hero_agents = {}  # {agent_idx: steps}
            collisions_scenerio = 0  # 记录当前场景发生的碰撞数

            print("==============================================================")
            print(
                f"===== 24 Scenarios 测试循环 Scenario: {i + 1}, A 点位置: {env.A_position} ====="
            )
            print("==============================================================")

            # 当前场景 - 设置每个agent初始化在B点去执行任务
            for agent_idx_B in env.agents_idx:
                # 初始设定所有 agent 都在A点
                for idx in env.agents_idx:
                    env.agents_positions[idx] = env.A_position
                    env.agents_reached_A[idx] = True
                # 设置当前 agent 在B点
                env.agents_positions[agent_idx_B] = env.B_position
                env.agents_reached_A[agent_idx_B] = False

                # initial states for four agents
                states = {
                    agent_idx: env.get_state(agent_idx) for agent_idx in env.agents_idx
                }

                # print(f"Agent {agent_idx_B} 开始送货")
                for step in range(max_steps):
                    # LOG
                    # print(f"\n[TEST] Step {step+1}/{max_steps} Q-values:")
                    # for agent_idx in sorted(env.agents_idx):
                    #     st = states[agent_idx]
                    #     qvals = agent.get_qvals(st)
                    #     print(f"  Agent {agent_idx} state={st} → qvals={qvals}")

                    actions_dict = {}  # 获取动作
                    for agent_idx in sorted(
                        env.agents_idx
                    ):  # central clock - fix order
                        valid = env.get_valid_actions(agent_idx)
                        actions_dict[agent_idx] = agent.get_greedy_action(
                            states[agent_idx], valid
                        )

                    # take action in the environment
                    next_states, _, collisions = env.take_action(actions_dict)

                    # print(f"[TEST] Step {step+1}/{max_steps}")
                    # print(f"  当前Agents位置")
                    # output_agents_info(states)
                    # print(f"  下一Agents位置")
                    # output_agents_info(next_states)

                    # 发生碰撞，当前agent此场景失败
                    if collisions > 0:
                        collisions_scenerio += collisions
                        # print(
                        #     f"===== 发生碰撞，此场景 Agent {agent_idx_B} 送货失败 ====="
                        # )
                        break

                    if env._check_done(agent_idx_B):
                        # 送达成功，当前agent此场景成功，退出测试下一个agent
                        if str(agent_idx_B) not in hero_agents:
                            # 记录成功送达的agent和步数
                            hero_agents[agent_idx_B] = step + 1
                            print(
                                f"上帝保佑! 当前场景 Agent {agent_idx_B} 送货成功 {len(hero_agents)}"
                            )
                            break
                        else:
                            print(f"重大错误! 当前场景重复测试了 Agent {agent_idx_B}")

                    # 更新states
                    states = copy.deepcopy(next_states)

            print(
                f"[Results] Scenario: {i + 1}, A 点位置: {env.A_position}, 成功次数: {len(hero_agents)}, 碰撞数: {collisions_scenerio}, 步数: {sum(hero_agents.values())}"
            )

            # 累加当前场景成功送达的步数
            success_steps_used += sum(hero_agents.values())
            # 累加当前场景碰撞数
            total_collisions += collisions_scenerio

            # 判断当前场景是否4个agents都完成送达，是的话算一次成功，退出当前循环，进入下一场景
            if len(hero_agents) == len(env.agents_idx):
                success_times += 1
                print(f"上帝保佑! Scenario: {i + 1} 4个Agents 全都送货成功")

        return success_times, total_collisions, success_steps_used

    # success_times, total_collisions, success_steps_used = test_24_scenarios()

    # 测试100次取各个指标的平均值
    TEST_TIMES = 100
    success_times_aggr, total_collisions_aggr, success_steps_used_aggr = 0, 0, 0
    for i in range(TEST_TIMES):
        success_times, total_collisions, success_steps_used = test_24_scenarios()
        print(
            "success_times, total_collisions, success_steps_used:",
            success_times,
            total_collisions,
            success_steps_used,
        )
        success_times_aggr += success_times
        total_collisions_aggr += total_collisions
        success_steps_used_aggr += success_steps_used
        print(
            f"第 {i+1} 次测试成功率: {(success_times / A_positions_num) *100:.2f}%, 碰撞数: {total_collisions}, 步数: {success_steps_used}"
        )
    print(
        "-----> ",
        success_times_aggr,
        total_collisions_aggr,
        success_steps_used_aggr,
        " <-----",
    )
    # test indicators
    total_possible_deliveries = (
        len(env.agents_idx) * A_positions_num * TEST_TIMES
    )  # 所有可能场景的送达数
    success_rate = (success_times_aggr / (A_positions_num * TEST_TIMES)) * 100
    avg_steps = (
        success_steps_used_aggr / (success_times_aggr * len(env.agents_idx))
        if success_times_aggr > 0
        else 0
    )
    total_collisions = total_collisions_aggr / TEST_TIMES
    collisions_rate = (total_collisions / (len(env.agents_idx) * A_positions_num)) * 100

    # test summary
    print("\n===== Test Summary =====")
    print(f"     平均成功率: {success_rate:.2f}%")
    print(f"送达成功平均步数: {avg_steps:.2f}")
    print(f"     平均总碰撞: {total_collisions}")

    return {
        "success_rate": success_rate,
        "avg_steps": avg_steps,
        "total_collisions": total_collisions,
        "collisions_rate": collisions_rate,
    }

## Test process

In [8]:
# setup the environment
test_env = GridWorldEnvironment()
test_state = test_env.get_state(0)
statespace_size = test_state.shape[0]
action_size = len(test_env.directions)

# Test
print("===== TEST =====")
env = GridWorldEnvironment()
agent = Agent(
    statespace_size,
    action_size,
)
stats = train_agents(
    agent,
    env,
    verbose=True,
)

metrics = test_agents(agent, env, max_steps=25, step_verbose=False)

# print("\n===== Final Test Metrics =====")
# print(f"Success rate: {metrics['success_rate']:.2f}%")
# print(f"Average steps: {metrics['avg_steps']:.2f}")
# print(f"Total collisions: {metrics['total_collisions']}")

# plot the training statistics
drawing_plots(
    stats["episodes"], stats["losses"], stats["collisions"], stats["epsilons"]
)

===== TEST =====
开始训练，配置参数：
最大步数: 1500000
最大碰撞数: 4000
最大训练时间: 600秒
===== 触发了max_steps_episode退出 =====
===== 触发了max_steps_episode退出 =====
===== 触发了max_steps_episode退出 =====
===== 触发了max_steps_episode退出 =====
Steps: 5000/1500000, Collisions: 97/4000, Epsilon: 0.622, Time Elapsed: 15.6s, Loss: 45.96787643432617
===== 触发了max_steps_episode退出 =====
===== 触发了max_steps_episode退出 =====
===== 触发了max_steps_episode退出 =====
===== 触发了max_steps_episode退出 =====
===== 触发了max_steps_episode退出 =====
Steps: 10000/1500000, Collisions: 128/4000, Epsilon: 0.377, Time Elapsed: 31.8s, Loss: 18.840438842773438
===== 触发了max_steps_episode退出 =====
===== 触发了max_steps_episode退出 =====
===== 触发了max_steps_episode退出 =====
===== 触发了max_steps_episode退出 =====
===== 触发了max_steps_episode退出 =====
Steps: 15000/1500000, Collisions: 139/4000, Epsilon: 0.229, Time Elapsed: 48.1s, Loss: 10.798181533813477
===== 触发了max_steps_episode退出 =====
===== 触发了max_steps_episode退出 =====
===== 触发了max_steps_episode退出 =====
===== 触发了max_steps_epis