<a href="https://colab.research.google.com/github/anobody777/RL/blob/main/%E9%AC%BC%E6%8A%93%E4%BA%BA.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import time
import math

# --- 1. 環境設定 ---
GHOST_SPEED = 2.0  # 鬼的速度 (單位/秒)
HUMAN_SPEED = 1.0  # 人的速度 (單位/秒)
RECALL_TIME = 2.0  # 鬼 "回城" 所需的時間 (秒)

MAP_MIN = 0.0
MAP_MAX = 10.0
DT = 0.01          # 模擬的時間精度 (每 0.01 秒更新一次)
CAPTURE_DISTANCE = 0.1 # 判定抓到的距離

# --- 2. 模擬：獨立 RL (Greedy) 策略 ---
# 假設：所有鬼都只會「貪婪地」追逐（即，都向左追）
def simulate_greedy_rl():
    print("="*30)
    print("開始模擬: 獨立 RL (Greedy) 策略")
    print("="*30)

    # 初始位置
    b_pos = 4.0
    a1_pos = 9.0
    a2_pos = 8.0

    # 速度 (人向左逃跑，鬼都向左追)
    b_vel = -HUMAN_SPEED
    a1_vel = -GHOST_SPEED
    a2_vel = -GHOST_SPEED

    t = 0.0

    while t < 10.0: # 最多模擬 10 秒
        t += DT

        # 更新位置
        b_pos += b_vel * DT
        a1_pos += a1_vel * DT
        a2_pos += a2_vel * DT

        # 檢查地圖邊界
        b_pos = max(MAP_MIN, min(MAP_MAX, b_pos))
        a1_pos = max(MAP_MIN, min(MAP_MAX, a1_pos))
        a2_pos = max(MAP_MIN, min(MAP_MAX, a2_pos))

        # 每 0.5 秒印出一次狀態
        if t % 0.5 < DT:
            print(f"t = {t:.2f}s | 人: {b_pos:.2f}, 鬼1: {a1_pos:.2f}, 鬼2: {a2_pos:.2f}")

        # 檢查是否抓到
        if abs(b_pos - a1_pos) < CAPTURE_DISTANCE or abs(b_pos - a2_pos) < CAPTURE_DISTANCE:
            print("-" * 30)
            print(f"!!! 抓到! (RL) !!!")
            print(f"最終時間: t = {t:.2f} 秒")
            print(f"最終位置: 人: {b_pos:.2f}, 鬼1: {a1_pos:.2f}, 鬼2: {a2_pos:.2f}")
            print("="*30 + "\n")
            return t

    print("--- RL 策略在 10 秒內未能抓到 ---")
    return t

# --- 3. 模擬：合作 MARL (QMIX) 策略 ---
# 假設：a1 (遠的鬼) 犧牲自己 "回城"，a2 (近的鬼) 負責追趕
def simulate_cooperative_marl():
    print("="*30)
    print("開始模擬: 合作 MARL (QMIX) 策略")
    print("="*30)

    # 初始位置
    b_pos = 4.0
    a1_pos = 9.0
    a2_pos = 8.0

    # 速度 (人向左逃跑，鬼2向左追)
    b_vel = -HUMAN_SPEED
    a2_vel = -GHOST_SPEED

    # 鬼1 的狀態
    a1_vel = 0.0 # 鬼1 正在 "回城"，不能移動
    a1_recall_timer = RECALL_TIME

    t = 0.0

    while t < 10.0:
        t += DT

        # --- 更新 鬼1 (a1) 的狀態 ---
        if a1_recall_timer > 0:
            a1_recall_timer -= DT
            if a1_recall_timer <= 0:
                print(f"t = {t:.2f}s | 鬼1 (a1) 完成回城，出現在 x=0")
                a1_pos = 0.0          # 傳送到 x=0
                a1_vel = GHOST_SPEED  # 開始向「右」追 (形成包圍)
        else:
            # 鬼1 正常移動
            a1_pos += a1_vel * DT

        # --- 更新 鬼2 (a2) 和 人 (b) 的狀態 ---
        b_pos += b_vel * DT
        a2_pos += a2_vel * DT

        # 檢查地圖邊界
        b_pos = max(MAP_MIN, min(MAP_MAX, b_pos))
        a1_pos = max(MAP_MIN, min(MAP_MAX, a1_pos))
        a2_pos = max(MAP_MIN, min(MAP_MAX, a2_pos))

        # 每 0.5 秒印出一次狀態
        if t % 0.5 < DT:
            print(f"t = {t:.2f}s | 人: {b_pos:.2f}, 鬼1: {a1_pos:.2f}, 鬼2: {a2_pos:.2f} (鬼1回城剩餘: {a1_recall_timer:.2f}s)")

        # 檢查是否抓到
        if abs(b_pos - a1_pos) < CAPTURE_DISTANCE or abs(b_pos - a2_pos) < CAPTURE_DISTANCE:
            print("-" * 30)
            print(f"!!! 抓到! (MARL) !!!")
            print(f"最終時間: t = {t:.2f} 秒")
            print(f"最終位置: 人: {b_pos:.2f}, 鬼1: {a1_pos:.2f}, 鬼2: {a2_pos:.2f}")
            print("="*30 + "\n")
            return t

    print("--- MARL 策略在 10 秒內未能抓到 ---")
    return t

# --- 4. 執行模擬 ---
if __name__ == "__main__":
    rl_time = simulate_greedy_rl()
    marl_time = simulate_cooperative_marl()

    print(f"模擬總結：")
    print(f"  獨立 RL (Greedy) 耗時: {rl_time:.2f} 秒")
    print(f"  合作 MARL (QMIX) 耗時: {marl_time:.2f} 秒")
    print(f"  合作策略節省了 {rl_time - marl_time:.2f} 秒")

開始模擬: 獨立 RL (Greedy) 策略
t = 0.50s | 人: 3.50, 鬼1: 8.00, 鬼2: 7.00
t = 1.00s | 人: 3.00, 鬼1: 7.00, 鬼2: 6.00
t = 1.50s | 人: 2.50, 鬼1: 6.00, 鬼2: 5.00
t = 2.00s | 人: 2.00, 鬼1: 5.00, 鬼2: 4.00
t = 2.51s | 人: 1.49, 鬼1: 3.98, 鬼2: 2.98
t = 3.01s | 人: 0.99, 鬼1: 2.98, 鬼2: 1.98
t = 3.51s | 人: 0.49, 鬼1: 1.98, 鬼2: 0.98
------------------------------
!!! 抓到! (RL) !!!
最終時間: t = 3.91 秒
最終位置: 人: 0.09, 鬼1: 1.18, 鬼2: 0.18

開始模擬: 合作 MARL (QMIX) 策略
t = 0.50s | 人: 3.50, 鬼1: 9.00, 鬼2: 7.00 (鬼1回城剩餘: 1.50s)
t = 1.00s | 人: 3.00, 鬼1: 9.00, 鬼2: 6.00 (鬼1回城剩餘: 1.00s)
t = 1.50s | 人: 2.50, 鬼1: 9.00, 鬼2: 5.00 (鬼1回城剩餘: 0.50s)
t = 2.00s | 鬼1 (a1) 完成回城，出現在 x=0
t = 2.00s | 人: 2.00, 鬼1: 0.00, 鬼2: 4.00 (鬼1回城剩餘: -0.00s)
t = 2.51s | 人: 1.49, 鬼1: 1.02, 鬼2: 2.98 (鬼1回城剩餘: -0.00s)
------------------------------
!!! 抓到! (MARL) !!!
最終時間: t = 2.64 秒
最終位置: 人: 1.36, 鬼1: 1.28, 鬼2: 2.72

模擬總結：
  獨立 RL (Greedy) 耗時: 3.91 秒
  合作 MARL (QMIX) 耗時: 2.64 秒
  合作策略節省了 1.27 秒


In [2]:
# --- 0. 安裝依賴 ---
# (我們需要 pettingzoo 來建立環境基礎)
!pip install pettingzoo pygame pyvirtualdisplay moviepy

# --- 1. 導入 (Imports) ---
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
import random
from collections import deque, namedtuple
import time
import math
import pyvirtualdisplay
import gym
import imageio
from pettingzoo.utils.env import ParallelEnv
from gym.spaces import Discrete, Box
from IPython.display import Image, display
import os # 為了 Drive 儲存/加載

# --- 2. 啟動虛擬螢幕 ---
try:
    _display.stop()
except:
    pass
_display = pyvirtualdisplay.Display(visible=0, size=(500, 150))
_display.start()
print("Virtual display started.")

# --- 3. 【全新】1D 鬼抓人環境 (GhostHumanEnv) ---
# (這模擬了您描述的 1D 遊戲)
class GhostHumanEnv(ParallelEnv):
    metadata = {'render.modes': ['human', 'rgb_array'], 'name': 'ghost_human_v1'}

    def __init__(self, num_ghosts=2):
        self.MAP_MIN, self.MAP_MAX = 0.0, 10.0
        self.GHOST_SPEED = 2.0  # 單位/秒
        self.HUMAN_SPEED = 1.0  # 單位/秒
        self.RECALL_TIME = 2.0    # 2 秒
        self.TIME_STEP = 0.5      # 模擬的每一步是 0.5 秒
        self.RECALL_STEPS = int(self.RECALL_TIME / self.TIME_STEP) # = 4 步
        self.CAPTURE_DISTANCE = 0.25 # 判定抓到的距離
        self.MAX_STEPS = 50       # 每回合最多 50 步 (25 秒)

        self.num_ghosts = num_ghosts
        self.ghost_agent_ids = [f"ghost_{i}" for i in range(num_ghosts)]
        self.human_agent_id = "human_0"
        self.possible_agents = self.ghost_agent_ids + [self.human_agent_id]

        # 狀態空間 (觀測) for Ghost: [human_pos, my_pos, other_ghost_pos, my_recall_timer]
        self.observation_space_ghost = Box(low=-1.0, high=self.MAP_MAX, shape=(4,))
        # 狀態空間 for Human: [my_pos, ghost_0_pos, ghost_1_pos]
        self.observation_space_human = Box(low=-1.0, high=self.MAP_MAX, shape=(3,))

        # 動作空間 for Ghost: 0=向左, 1=向右, 2=回城, 3=不動
        self.action_space_ghost = Discrete(4)
        # 動作空間 for Human: 0=向左, 1=向右 (註：在此環境中由 AI 固定控制)
        self.action_space_human = Discrete(2)

        # PettingZoo API
        self.observation_spaces = {name: self.observation_space_ghost for name in self.ghost_agent_ids}
        self.observation_spaces[self.human_agent_id] = self.observation_space_human
        self.action_spaces = {name: self.action_space_ghost for name in self.ghost_agent_ids}
        self.action_spaces[self.human_agent_id] = self.action_space_human

        self.screen = None

    def reset(self, seed=None, hard_code_scenario=False):
        # hard_code_scenario 用於測試您指定的初始狀態
        if hard_code_scenario:
            self.human_pos = 4.0
            self.ghost_pos = [9.0, 8.0]
        else: # 隨機初始位置 (用於訓練)
            self.human_pos = random.uniform(3.0, 7.0)
            self.ghost_pos = [random.uniform(7.0, 10.0) for _ in range(self.num_ghosts)]

        self.ghost_recall_timers = [0] * self.num_ghosts # 0 = 不在回城
        self.steps = 0

        # PettingZoo API: agents 列表必須是當前活躍的
        self.agents = self.possible_agents[:] # 複製列表

        return self._get_observations()

    def _get_observations(self):
        obs = {}
        for i in range(self.num_ghosts):
            other_ghost_pos = self.ghost_pos[1-i] # (假設只有 2 個鬼)
            obs[f"ghost_{i}"] = np.array([self.human_pos, self.ghost_pos[i], other_ghost_pos, self.ghost_recall_timers[i]])
        obs["human_0"] = np.array([self.human_pos, self.ghost_pos[0], self.ghost_pos[1]])
        return obs

    def step(self, actions):
        if not self.agents: return {}, {}, {}, {}, {} # 如果回合已結束

        rewards = {name: 0.0 for name in self.agents}
        terminations = {name: False for name in self.agents}

        # --- 1. 人類 AI (固定策略：往左逃跑) ---
        if self.human_pos > self.MAP_MIN:
            self.human_pos += (-self.HUMAN_SPEED * self.TIME_STEP)
        self.human_pos = max(self.MAP_MIN, self.human_pos)

        # --- 2. 鬼的行動 (掠食者) ---
        for i in range(self.num_ghosts):
            ghost_name = f"ghost_{i}"
            action = actions.get(ghost_name, 3) # 默認為 3 (不動)

            # 檢查是否正在回城
            if self.ghost_recall_timers[i] > 0:
                self.ghost_recall_timers[i] -= 1
                if self.ghost_recall_timers[i] == 0:
                    self.ghost_pos[i] = self.MAP_MIN # 完成回城，瞬移到 0

            # 如果不在回城，則執行動作
            elif action == 0: # 向左
                self.ghost_pos[i] -= self.GHOST_SPEED * self.TIME_STEP
            elif action == 1: # 向右
                self.ghost_pos[i] += self.GHOST_SPEED * self.TIME_STEP
            elif action == 2: # 啟動回城
                self.ghost_recall_timers[i] = self.RECALL_STEPS
            # action == 3 (不動)

            self.ghost_pos[i] = max(self.MAP_MIN, min(self.MAP_MAX, self.ghost_pos[i]))

        # --- 3. 計算獎勵和終止 ---
        self.steps += 1

        captured = False
        for i in range(self.num_ghosts):
            if abs(self.ghost_pos[i] - self.human_pos) < self.CAPTURE_DISTANCE:
                captured = True; break

        # 為鬼 (ghost) 和人 (human) 計算不同的獎勵
        team_reward_ghost = 0.0  # 這是給 QMIX 和 IQL 的獎勵
        human_reward = 0.0       # 這是給人類的獎勵 (雖然我們沒訓練它)

        if captured:
            team_reward_ghost = 100.0  # 鬼抓到人，獲得巨大獎勵
            human_reward = -100.0
            terminations = {name: True for name in self.agents}
            self.agents = []
        elif self.steps >= self.MAX_STEPS:
            team_reward_ghost = -100.0 # 鬼超時，獲得巨大懲罰
            human_reward = 100.0  # 人類成功逃脫
            terminations = {name: True for name in self.agents}
            self.agents = []
        else:
            # 獎勵塑形：負的距離總和 (鼓勵靠近)
            team_reward_ghost = -sum(abs(self.ghost_pos[i] - self.human_pos) for i in range(self.num_ghosts))
            human_reward = -team_reward_ghost # 鼓勵遠離

        # 分配獎勵
        for name in self.agents:
            if "ghost" in name:
                rewards[name] = team_reward_ghost
            else:
                rewards[name] = human_reward

        return self._get_observations(), rewards, terminations, {name: False for name in self.agents}, {}

    def render(self, mode='rgb_array'):
        import pygame
        if self.screen is None: pygame.init(); self.screen = pygame.display.set_mode((self.screen_width, self.screen_height))
        self.screen.fill((255, 255, 255))
        map_y = 75
        map_start_x = int(self.MAP_MIN / self.MAP_MAX * self.screen_width * 0.9 + 0.05 * self.screen_width)
        map_end_x = int(self.MAP_MAX / self.MAP_MAX * self.screen_width * 0.9 + 0.05 * self.screen_width)
        pygame.draw.line(self.screen, (0, 0, 0), (map_start_x, map_y), (map_end_x, map_y), 2)
        def get_x_pos(pos): return int(pos / self.MAP_MAX * (map_end_x - map_start_x) + map_start_x)
        pygame.draw.circle(self.screen, (0, 255, 0), (get_x_pos(self.human_pos), map_y), 10)
        for i in range(self.num_ghosts):
            color = (255, 0, 0) if self.ghost_recall_timers[i] == 0 else (255, 180, 180)
            pygame.draw.circle(self.screen, color, (get_x_pos(self.ghost_pos[i]), map_y), 12)
        if mode == 'rgb_array': return np.transpose(np.array(pygame.surfarray.pixels3d(self.screen)), axes=(1, 0, 2))
        else: pygame.display.flip()

    def close(self):
        if self.screen is not None:
            import pygame; pygame.display.quit(); pygame.quit(); self.screen = None

# --- 4. 超參數 (Hyperparameters) ---
BUFFER_SIZE, BATCH_SIZE, GAMMA, LR, TAU = int(1e5), 64, 0.95, 1e-3, 0.005
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu"); print(f"Using device: {DEVICE}")
NUM_EPISODES, MAX_T = 1000, 50
EPS_START, EPS_END, EPS_DECAY = 1.0, 0.05, 0.995
MODEL_CHECKPOINT = 100
MIXER_EMBED_DIM = 32
VISUALIZATION_SEED = 5
LOAD_EPISODE = 0
DRIVE_PATH = '/content/gdrive/MyDrive/MARL_GhostGame_Training/' # (可以設為您的 Drive 路徑)


# --- 5. 演算法組件 (IQL 和 QMIX) ---
Experience = namedtuple("Experience", field_names=["states", "actions", "reward", "next_states", "done", "global_state", "next_global_state"])
class SharedReplayBuffer:
    def __init__(self, buffer_size, batch_size): self.memory, self.batch_size = deque(maxlen=buffer_size), batch_size
    def push(self, states, actions, reward, next_states, done, global_state, next_global_state): self.memory.append(Experience(states, actions, reward, next_states, done, global_state, next_global_state))
    def sample(self, num_agents):
        experiences = random.sample(self.memory, k=self.batch_size)
        states = [torch.from_numpy(np.vstack([e.states[i] for e in experiences])).float().to(DEVICE) for i in range(num_agents)]
        actions = [torch.from_numpy(np.vstack([e.actions[i] for e in experiences])).long().to(DEVICE) for i in range(num_agents)]
        next_states = [torch.from_numpy(np.vstack([e.next_states[i] for e in experiences])).float().to(DEVICE) for i in range(num_agents)]
        rewards = torch.from_numpy(np.vstack([e.reward for e in experiences])).float().to(DEVICE)
        dones = torch.from_numpy(np.vstack([e.done for e in experiences]).astype(np.uint8)).float().to(DEVICE)
        global_states = torch.from_numpy(np.vstack([e.global_state for e in experiences])).float().to(DEVICE)
        next_global_states = torch.from_numpy(np.vstack([e.next_global_state for e in experiences])).float().to(DEVICE)
        return (states, actions, rewards, next_states, dones, global_states, next_global_states)
    def __len__(self): return len(self.memory)

class QNetwork(nn.Module):
    def __init__(self, state_dim, action_dim, seed):
        super(QNetwork, self).__init__(); self.seed = torch.manual_seed(seed)
        self.fc1, self.fc2, self.fc3 = nn.Linear(state_dim, 32), nn.Linear(32, 32), nn.Linear(32, action_dim)
    def forward(self, state): return self.fc3(F.relu(self.fc2(F.relu(self.fc1(state)))))

class QMixer(nn.Module):
    def __init__(self, num_agents, global_state_dim, embed_dim):
        super(QMixer, self).__init__(); self.num_agents, self.global_state_dim, self.embed_dim = num_agents, global_state_dim, embed_dim
        self.hyper_w1, self.hyper_b1 = nn.Linear(global_state_dim, embed_dim * num_agents), nn.Linear(global_state_dim, embed_dim)
        self.hyper_w2 = nn.Linear(global_state_dim, embed_dim)
        self.hyper_b2 = nn.Sequential(nn.Linear(global_state_dim, embed_dim), nn.ReLU(), nn.Linear(embed_dim, 1))
    def forward(self, agent_q_values, global_state):
        w1, b1 = torch.abs(self.hyper_w1(global_state)).view(-1, self.num_agents, self.embed_dim), self.hyper_b1(global_state).view(-1, 1, self.embed_dim)
        w2, b2 = torch.abs(self.hyper_w2(global_state)).view(-1, self.embed_dim, 1), self.hyper_b2(global_state).view(-1, 1, 1)
        agent_q_values = agent_q_values.view(-1, 1, self.num_agents)
        q_total = torch.bmm(F.elu(torch.bmm(agent_q_values, w1) + b1), w2) + b2
        return q_total.view(-1, 1)

# IQL Manager (對照組)
class IQLManager:
    def __init__(self, state_dims, action_dims, agent_ids, seed):
        self.agent_ids, self.num_agents = agent_ids, len(agent_ids)
        self.action_dims_dict = action_dims
        self.agents = {}
        for i, agent_id in enumerate(self.agent_ids):
            self.agents[agent_id] = DQNAgent(state_dims[agent_id], action_dims[agent_id], seed + i)
    def select_actions(self, obs, epsilon=0.):
        actions = {}
        for id, agent in self.agents.items():
            actions[id] = agent.select_action(obs[id], epsilon)
        return actions
    def push_and_learn(self, s, a, r, ns, d):
        for id, agent in self.agents.items():
            # IQL 學習：只使用「個人」獎勵
            individual_reward = r[id] # (環境已經計算好 team_reward_ghost)
            agent.push_and_learn(s[id], a[id], individual_reward, ns[id], d[id])
    def get_manager(self): return self
    def save_model(self, drive_path, episode_num): # IQL 的儲存
        model_dir = os.path.join(drive_path, f'IQL_episode_{episode_num}'); os.makedirs(model_dir, exist_ok=True)
        for id, agent in self.agents.items():
            torch.save(agent.qnetwork_local.state_dict(), os.path.join(model_dir, f'qnet_{id}.pth'))
        print(f"\n--- IQL Model saved to {model_dir} ---")
    def load_model(self, drive_path, episode_num): # IQL 的加載
        model_dir = os.path.join(drive_path, f'IQL_episode_{episode_num}')
        if not os.path.exists(model_dir): return False
        print(f"--- Loading IQL model from {model_dir} ---")
        for id, agent in self.agents.items():
            agent.qnetwork_local.load_state_dict(torch.load(os.path.join(model_dir, f'qnet_{id}.pth'), map_location=DEVICE))
            agent.soft_update_target_networks(1.0)
        return True

# DQNAgent (IQL 的組件)
class DQNAgent:
    def __init__(self, state_dim, action_dim, seed):
        self.action_dim = action_dim
        self.qnetwork_local = QNetwork(state_dim, action_dim, seed).to(DEVICE)
        self.qnetwork_target = QNetwork(state_dim, action_dim, seed).to(DEVICE)
        self.optimizer = optim.Adam(self.qnetwork_local.parameters(), lr=LR)
        self.memory = SharedReplayBuffer(BUFFER_SIZE, BATCH_SIZE)
        self.qnetwork_target.load_state_dict(self.qnetwork_local.state_dict())
    def select_action(self, state, epsilon=0.):
        state = torch.from_numpy(state).float().unsqueeze(0).to(DEVICE)
        self.qnetwork_local.eval();
        with torch.no_grad(): action_values = self.qnetwork_local(state)
        self.qnetwork_local.train()
        if random.random() > epsilon: return np.argmax(action_values.cpu().data.numpy())
        else: return random.choice(np.arange(self.action_dim))
    def push_and_learn(self, s, a, r, ns, d):
        self.memory.push((s,), (a,), r, (ns,), d, (s,), (ns,))
        if len(self.memory) > BATCH_SIZE: self.learn(self.memory.sample(1), GAMMA)
    def learn(self, experiences, gamma):
        states, actions, rewards, next_states, dones, _, _ = experiences
        s, a, ns = states[0], actions[0], next_states[0]
        Q_targets_next = self.qnetwork_target(ns).detach().max(1)[0].unsqueeze(1)
        Q_targets = rewards + (gamma * Q_targets_next * (1 - dones))
        Q_expected = self.qnetwork_local(s).gather(1, a)
        loss = F.mse_loss(Q_expected, Q_targets)
        self.optimizer.zero_grad(); loss.backward(); self.optimizer.step()
        self.soft_update_target_networks(TAU)
    def soft_update_target_networks(self, tau):
        for target_param, local_param in zip(self.qnetwork_target.parameters(), self.qnetwork_local.parameters()):
            target_param.data.copy_(tau*local_param.data + (1.0-tau)*target_param.data)

# QMIX Manager (實驗組)
class QMIXManager:
    def __init__(self, state_dims, action_dims, global_state_dim, agent_ids, seed):
        self.agent_ids, self.num_agents, self.action_dims_dict, self.global_state_dim = agent_ids, len(agent_ids), action_dims, global_state_dim
        self.qnetworks_local = [QNetwork(state_dims[id], action_dims[id], seed + i).to(DEVICE) for i, id in enumerate(agent_ids)]
        self.qnetworks_target = [QNetwork(state_dims[id], action_dims[id], seed + i).to(DEVICE) for i, id in enumerate(agent_ids)]
        self.mixer_local = QMixer(self.num_agents, global_state_dim, MIXER_EMBED_DIM).to(DEVICE)
        self.mixer_target = QMixer(self.num_agents, global_state_dim, MIXER_EMBED_DIM).to(DEVICE)
        self.optimizer = optim.Adam([p for net in self.qnetworks_local for p in net.parameters()] + list(self.mixer_local.parameters()), lr=LR)
        self.memory = SharedReplayBuffer(BUFFER_SIZE, BATCH_SIZE)
        for local, target in zip(self.qnetworks_local, self.qnetworks_target): target.load_state_dict(local.state_dict())
        self.mixer_target.load_state_dict(self.mixer_local.state_dict())
    def get_manager(self): return self
    def select_actions(self, obs, epsilon=0.):
        actions = {}
        for i, id in enumerate(self.agent_ids):
            state = torch.from_numpy(obs[id]).float().unsqueeze(0).to(DEVICE); self.qnetworks_local[i].eval()
            with torch.no_grad(): actions[id] = np.argmax(self.qnetworks_local[i](state).cpu().data.numpy())
            self.qnetworks_local[i].train()
            if random.random() <= epsilon: actions[id] = random.choice(np.arange(self.action_dims_dict[id]))
        return actions
    def push_and_learn(self, s, a, r, ns, d):
        s_tuple = tuple(s[id] for id in self.agent_ids); a_tuple = tuple(a[id] for id in self.agent_ids)
        ns_tuple = tuple(ns[id] for id in self.agent_ids)
        team_reward = r[self.agent_ids[0]] # 所有鬼的獎勵都相同
        team_done = any(d[id] for id in self.agent_ids)
        # global_state: [a1_pos, a2_pos, b_pos, a1_timer, a2_timer]
        global_state = np.array([s["ghost_0"][1], s["ghost_1"][1], s["human_0"][0], s["ghost_0"][3], s["ghost_1"][3]])
        next_global_state = np.array([ns["ghost_0"][1], ns["ghost_1"][1], ns["human_0"][0], ns["ghost_0"][3], ns["ghost_1"][3]])
        self.memory.push(s_tuple, a_tuple, team_reward, ns_tuple, team_done, global_state, next_global_state)
        if len(self.memory) > BATCH_SIZE: self.learn(self.memory.sample(self.num_agents), GAMMA)
    def learn(self, experiences, gamma):
        states, actions, rewards, next_states, dones, global_states, next_global_states = experiences
        Q_targets_next_total = self.mixer_target(torch.cat([self.qnetworks_target[i](next_states[i]).detach().max(1)[0].unsqueeze(1) for i in range(self.num_agents)], dim=1), next_global_states)
        Q_targets_total = rewards + (gamma * Q_targets_next_total * (1 - dones))
        Q_expected_total = self.mixer_local(torch.cat([self.qnetworks_local[i](states[i]).gather(1, actions[i]) for i in range(self.num_agents)], dim=1), global_states)
        loss = F.mse_loss(Q_expected_total, Q_targets_total)
        self.optimizer.zero_grad(); loss.backward(); self.optimizer.step()
        self.soft_update_target_networks(TAU)
    def soft_update_target_networks(self, tau):
        for local, target in zip(self.qnetworks_local, self.qnetworks_target):
            for target_param, local_param in zip(target.parameters(), local.parameters()): target_param.data.copy_(tau*local_param.data + (1.0-tau)*target_param.data)
        for target_param, local_param in zip(self.mixer_target.parameters(), self.mixer_local.parameters()): target_param.data.copy_(tau*local_param.data + (1.0-tau)*target_param.data)
    def save_model(self, drive_path, episode_num): # QMIX 的儲存
        model_dir = os.path.join(drive_path, f'QMIX_episode_{episode_num}'); os.makedirs(model_dir, exist_ok=True)
        for i, net in enumerate(self.qnetworks_local):
            torch.save(net.state_dict(), os.path.join(model_dir, f'qnet_agent_{i}.pth'))
        torch.save(self.mixer_local.state_dict(), os.path.join(model_dir, 'mixer.pth')); print(f"\n--- QMIX Model saved to {model_dir} ---")
    def load_model(self, drive_path, episode_num): # QMIX 的加載
        model_dir = os.path.join(drive_path, f'QMIX_episode_{episode_num}')
        if not os.path.exists(model_dir): return False
        print(f"--- Loading QMIX model from {model_dir} ---")
        for i, net in enumerate(self.qnetworks_local):
            net.load_state_dict(torch.load(os.path.join(model_dir, f'qnet_agent_{i}.pth'), map_location=DEVICE))
        self.mixer_local.load_state_dict(torch.load(os.path.join(model_dir, 'mixer.pth'), map_location=DEVICE))
        self.soft_update_target_networks(1.0); return True


# --- 7. 訓練主迴圈 ---
def train_marl(manager_class, load_episode=0):
    print(f"\n--- 開始訓練 {manager_class.__name__} ---")

    env = GhostHumanEnv(num_ghosts=2)
    predator_ids = [agent for agent in env.agents if 'ghost' in agent]
    state_dims = {id: env.observation_space(id).shape[0] for id in predator_ids}
    action_dims = {id: env.action_space(id).n for id in predator_ids}

    if manager_class == QMIXManager:
        global_state_dim = 5
        marl_manager = QMIXManager(state_dims, action_dims, global_state_dim, predator_ids, seed=42)
    else: # IQLManager
        marl_manager = IQLManager(state_dims, action_dims, predator_ids, seed=42)

    model_loaded = False
    if load_episode > 0:
        # (確保 Colab 儲存格 1 已經執行)
        try:
            model_loaded = marl_manager.load_model(DRIVE_PATH, load_episode)
        except Exception as e:
            print(f"Error loading model, Google Drive might not be connected: {e}")

    if model_loaded:
        print(f"--- Model loaded successfully! Resuming training. ---")
        epsilon = EPS_END # 假設加載的模型已訓練好
    else:
        print(f"--- No model loaded. Starting training from scratch. ---")
        epsilon = EPS_START

    print(f"Starting Training from Episode {load_episode + 1}...")
    start_time = time.time()
    scores_window = deque(maxlen=100)
    capture_window = deque(maxlen=100)

    for i_episode in range(load_episode + 1, NUM_EPISODES + 1):
        observations = env.reset()
        captured_this_episode = False

        for t in range(MAX_T):
            actions = marl_manager.select_actions(observations, epsilon)
            next_observations, rewards, terminations, truncations, dones = env.step(actions)

            team_done = any(terminations.values())
            if team_done and t < (MAX_T - 1): # 如果不是因為超時
                captured_this_episode = True

            marl_manager.push_and_learn(observations, actions, rewards, next_observations, terminations)
            observations = next_observations

            if team_done or all(truncations.values()):
                break

        capture_window.append(1 if captured_this_episode else 0)
        epsilon = max(EPS_END, EPS_DECAY * epsilon)
        # 只記錄鬼的（團隊）獎勵
        scores_window.append(rewards[predator_ids[0]] if predator_ids[0] in rewards else 0)

        if i_episode % MODEL_CHECKPOINT == 0:
            avg_score = np.mean(scores_window)
            capture_rate = np.mean(capture_window) * 100
            print(f'\rEpisode {i_episode}\tAvg Score: {avg_score:.2f}\tCapture Rate: {capture_rate:.1f}%\tEpsilon: {epsilon:.3f}')
            # 儲存檢查點 (確保 Colab 儲存格 1 已經執行)
            try:
                marl_manager.save_model(DRIVE_PATH, i_episode)
            except Exception as e:
                print(f"Error saving model, Google Drive might not be connected: {e}")

    print(f"\n{manager_class.__name__} 訓練完成 in {(time.time() - start_time) / 60:.2f} minutes.")
    env.close()
    return marl_manager.get_manager()

# --- 8. 視覺化測試 ---
def run_test(manager, gif_filename):
    print(f"\n[Test] 測試 {manager.__class__.__name__} (b=4, a1=9, a2=8)...")
    vis_env = GhostHumanEnv(num_ghosts=2)
    observations = vis_env.reset(hard_code_scenario=True)
    frames, captured = [], False
    predator_ids = [agent for agent in vis_env.agents if 'ghost' in agent]

    for t in range(MAX_T):
        frames.append(vis_env.render())
        actions = manager.select_actions(observations, epsilon=0.0)
        next_observations, rewards, terminations, truncations, infos = vis_env.step(actions)
        observations = next_observations

        if any(terminations.values()):
            if t < (MAX_T - 1):
                print(f"  > 成功! 在第 {t+1} 步 ({(t+1)*vis_env.TIME_STEP:.2f} 秒) 抓到!")
                captured = True
            else:
                print(f"  > 失敗. 在第 {t+1} 步 ({(t+1)*vis_env.TIME_STEP:.2f} 秒) 超時.")
            break

    vis_env.close()
    imageio.mimsave(gif_filename, frames, fps=10)
    duration_sec = len(frames) / 10.0
    print(f"  > 測試 GIF 已儲存: {gif_filename} ({duration_sec:.2f} 秒)")
    return captured, duration_sec

# --- 9. 主執行緒 ---
if __name__ == "__main__":

    # 決定是否加載
    # LOAD_EPISODE = 0 # (從頭訓練)
    LOAD_EPISODE = 0 # (假設您想從 1000 回合的檢查點加載)

    # 訓練 IQL
    iql_model = train_marl(IQLManager, LOAD_EPISODE)

    # 訓練 QMIX
    qmix_model = train_marl(QMIXManager, LOAD_EPISODE)

    print("\n" + "=" * 40)
    print(" 最終視覺化比較 ".center(40, "="))
    print("=" * 40)

    # 測試 IQL
    iql_result = run_test(iql_model, "ghost_game_IQL.gif")

    # 測試 QMIX
    qmix_result = run_test(qmix_model, "ghost_game_QMIX.gif")

    print("\n" + "=" * 40)
    print(" 最終結果 ".center(40, "="))
    print(f"  獨立 RL (IQL):   {'成功' if iql_result[0] else '失敗'} (耗時 {iql_result[1]:.2f} 秒)")
    print(f"  合作 MARL (QMIX): {'成功' if qmix_result[0] else '失敗'} (耗時 {qmix_result[1]:.2f} 秒)")
    print("=" * 40)

Collecting pettingzoo
  Downloading pettingzoo-1.25.0-py3-none-any.whl.metadata (8.9 kB)
Collecting pyvirtualdisplay
  Downloading PyVirtualDisplay-3.0-py3-none-any.whl.metadata (943 bytes)
Downloading pettingzoo-1.25.0-py3-none-any.whl (852 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m852.5/852.5 kB[0m [31m36.1 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading PyVirtualDisplay-3.0-py3-none-any.whl (15 kB)
Installing collected packages: pyvirtualdisplay, pettingzoo
Successfully installed pettingzoo-1.25.0 pyvirtualdisplay-3.0


Gym has been unmaintained since 2022 and does not support NumPy 2.0 amongst other critical functionality.
Please upgrade to Gymnasium, the maintained drop-in replacement of Gym, or contact the authors of your software and request that they upgrade.
See the migration guide at https://gymnasium.farama.org/introduction/migration_guide/ for additional information.
  return datetime.utcnow().replace(tzinfo=utc)
  return datetime.utcnow().replace(tzinfo=utc)
  return datetime.utcnow().replace(tzinfo=utc)


Virtual display started.
Using device: cpu

--- 開始訓練 IQLManager ---


AttributeError: 'GhostHumanEnv' object has no attribute 'agents'