In [None]:
%pip -q install gymnasium pettingzoo supersuit stable-baselines3

[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/852.5 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [91m━━━━━━━━━━━━━━━━━━[0m[91m╸[0m[90m━━━━━━━━━━━━━━━━━━━━━[0m [32m399.4/852.5 kB[0m [31m12.7 MB/s[0m eta [36m0:00:01[0m[2K   [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m [32m849.9/852.5 kB[0m [31m17.4 MB/s[0m eta [36m0:00:01[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m852.5/852.5 kB[0m [31m10.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m50.2/50.2 kB[0m [31m2.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m188.0/188.0 kB[0m [31m5.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m563.6/563.6 kB[0m [31m16.8 MB/s[0m eta [36m0:00:00[0m
[?25h

## 環境本体

1. シミュレーション環境
   - PettingZoo の `ParallelEnv` として実装された協調型マルチエージェント環境（全車が同一方策で制御、報酬は共通のチーム報酬）。
   - 本線は右向きに進む3車線（lane 0,1,2）で、右側（画面の下側）からの合流車線（lane 3）がある。
   - シミュレーション全体は 1000m（合流前 400m、合流区間 200m、合流後 400m）。
     - `merge_start = 400m`、`merge_end = 600m`、`goal_x = 1000m`
   - 車線幅は `lane_width = 4.0m`。
   - スポーン時の速度は 80〜150km/h。走行中は毎ステップ最大 ±10km/h で変更でき、0〜150km/h の範囲にクリップされる（`dv_kmh = 10`、`vmax_kmh = 150`、`vmin_kmh = 80`、クリップ下限 `vclip_min_kmh = 0`）。
   - 車の数は最大で N（`num_agents`）として可変（デフォルト 12）。すべての車が協調学習・コントロール対象。
   - シミュレーションの 1 ステップは 1 秒換算（`dt = 1.0`）。エピソード最大 500 ステップ（`episode_horizon = 500`）。
   - スポーンは `x=0` で行い、同じレーンからの連続スポーンを避けるためレーン単位のクールダウンあり（最短 5 step、`spawn_lane_cooldown_steps = 5`）。
     - スポーン時速度は 80〜150km/h の一様乱数。
     - 近すぎるスポーンは禁止（`spawn_min_dist = 20m`）。
     - 毎 step で複数回試行（`spawn_attempts_per_step = 8`）。
   - レーン変更には時間がかかり、その間は行動を変更できない（`lane_change_steps = 2`。変更中は lane 行動は keep 固定）。
   - レーン移動の制約：
     - 本線（0,1,2）→ 合流車線（3）へは移動できない。
     - 合流車線（3）→ 本線は lane2 のみに可能で、`x ∈ [merge_start, merge_end]` の合流区間内のみ許可。
   - 衝突判定あり（簡易）：同一レーンで車間が `collision_dist` 未満（設定が無い場合 2.0m）だと衝突扱い（`vehicle_length` がある場合はそれも考慮）。
   - 行き止まり（dead-end）判定あり：lane3 の車が `x >= merge_end` のまま lane3 にいると dead-end。
     - ただし「lane3→lane2 へ合流中に `merge_end` を超えたら即 lane2 に確定」する救済処理がある。
   - ゴール（`x >= goal_x`）や衝突・dead-end はイベントとして記録され、その車は待機プールへ送られて再スポーンする（環境全体の即終了ではない）。
   - 報酬は車毎ではなく共通に一つ持つ協調学習（`common_reward` を全エージェントに同じ値で返す）。

2. 観測空間（17次元）
   - 観測できる距離の最大値は 200m（`max_obs_dist = 200`）。
   - 観測できる行き止まりまでの距離は 400m（`max_deadend_obs = 400`、lane3のみ）。
   - 合流開始地点の 200m 手前（`x >= 200m`）から隣の車線の状態を観測できる（`adj_obs_unlock_x = 200`）。
     - lane2 と lane3 は、それ以前は左右車線の観測を 0 にする。
   - 距離特徴は「最大観測可能距離 - 実際の距離」とする（近いほど大きい、遠い/不在は 0）。
   - 観測ベクトル（active のとき）は以下：
     - `is_active`
     - `has_left_lane`, `has_right_lane`
     - `deadend_dist_feature`（lane3のみ：`400 - dist_to_end`）
     - `self_speed_kmh`
     - 同一車線：前方車の速度・距離、後方車の速度・距離
     - 左隣車線：前方車の速度・距離、後方車の速度・距離
     - 右隣車線：前方車の速度・距離、後方車の速度・距離
   - inactive のときは `is_active=0`、他は 0 のダミー観測。

3. 行動空間
   - `MultiDiscrete([3,3])`
     - 速度変化：-10 / 0 / +10 km/h（1秒あたり最大 10km/h 変更）
     - レーン方向：左 / キープ / 右（ただしレーン変更中はキープ固定）

4. 報酬（協調・共通のチーム報酬）
   - Goal したら +10 点（`reward_goal=10`）× goal 台数。
   - 衝突 or dead-end は -100 点（`reward_crash=-100`）×（衝突/行き止まり台数）。
   - 前の車との実際の距離 < 20m のとき（同一レーン最近接前方車）：
     - `-40 + 実際の距離 * 2`（各車で加算し合計がチーム報酬へ）。
   - 加速/減速ペナルティ：
     - `- |速度変化(km/h)| * 0.1`（全車合計、`accel_penalty_scale=0.1`）。
   - 合流以外のレーンチェンジ：
     - `-0.1` ×（その step で開始した非合流レーン変更台数、`lane_change_penalty=-0.1`）。

**今後改良した方が良さそうな点**
- 車の密度をエピソード毎に可変にした方が良い。
- 行動空間は現在加減速が3種類しかないため増やした方が現実的。
- レーン3の終点に近づいたときの段階的ペナルティを導入する。
- 学習に 500万ステップくらいかかる（報酬シェーピング/カリキュラム等の改善余地）。
- 観測空間を正規化した方が良い（速度/距離スケールを揃える）。
- 報酬の一部は該当する車に分配した方が良いかもしれない（credit assignment 改善）。
- 動画をもっときれいにする。


In [None]:
%%writefile coop_merge_env.py
from __future__ import annotations

from dataclasses import dataclass
from typing import Dict, Optional, Tuple, List

import numpy as np
import gymnasium as gym
from gymnasium import spaces
from gymnasium.utils import seeding

from pettingzoo.utils import ParallelEnv


# -----------------------------
# Config
# -----------------------------
@dataclass
class CoopMergeConfig:
    """Tunable parameters of the cooperative merge environment.

    Distances are in metres, speeds in km/h and durations in simulation
    steps unless a comment says otherwise. ``goal_x``, ``merge_start`` and
    ``merge_end`` are read-only values derived from the three road-section
    lengths.

    ``collision_dist`` and ``vehicle_length`` are declared at the end so that
    the positional order of the pre-existing fields is unchanged. The
    environment's ``step`` already looked these two names up with ``getattr``
    and fell back to ``2.0`` / ``None``; the defaults here are those same
    fallbacks.
    """

    # geometry (global x along main road)
    pre_merge: float = 400.0
    merge_length: float = 200.0
    post_merge: float = 400.0

    # derived
    @property
    def goal_x(self) -> float:
        """x position of the goal line (sum of the three road sections)."""
        return self.pre_merge + self.merge_length + self.post_merge  # 1000

    @property
    def merge_start(self) -> float:
        """x position where the merge section begins."""
        return self.pre_merge  # 400

    @property
    def merge_end(self) -> float:
        """x position where the merge section (and the ramp lane) ends."""
        return self.pre_merge + self.merge_length  # 600

    # sim
    episode_horizon: int = 500
    dt: float = 1.0  # 1 step = 1 sec

    # lanes
    lane_width: float = 4.0
    # main lanes indices: 0,1,2; ramp lane index: 3
    main_lane_count: int = 3

    # speeds
    vmin_kmh: float = 80.0
    vmax_kmh: float = 150.0
    dv_kmh: float = 10.0  # per step max change
    # clip min speed (allow 0)
    vclip_min_kmh: float = 0.0

    # perception
    max_obs_dist: float = 200.0
    max_deadend_obs: float = 400.0
    adj_obs_unlock_x: float = 200.0  # "merge_start - 200"

    # spawn / pool
    spawn_min_dist: float = 20.0
    spawn_lane_cooldown_steps: int = 5
    spawn_attempts_per_step: int = 8  # try multiple random lanes each step

    # lane change dynamics
    lane_change_steps: int = 2  # takes time; during this, lane direction is locked

    # collisions (simple geometric)
    collision_radius: float = 2.5  # meters (rough)

    # reward
    reward_goal: float = 10.0
    reward_crash: float = -100.0
    # close-front penalty: if front_dist < 20m -> -40 + dist*2
    close_dist_threshold: float = 20.0
    close_penalty_base: float = -40.0
    close_penalty_slope: float = 2.0
    accel_penalty_scale: float = 0.1
    lane_change_penalty: float = -0.1  # except merging action (ramp->lane2)

    # same-lane collision check used by step(): a gap below this counts as a hit
    collision_dist: float = 2.0
    # optional vehicle length; when set, step() uses max(collision_dist, length / 2)
    vehicle_length: Optional[float] = None


# -----------------------------
# Helpers
# -----------------------------
def kmh_to_mps(v_kmh: float) -> float:
    """Convert a speed from kilometres per hour to metres per second."""
    speed_kmh = float(v_kmh)
    return speed_kmh / 3.6

def mps_to_kmh(v_mps: float) -> float:
    """Convert a speed from metres per second to kilometres per hour."""
    speed_mps = float(v_mps)
    return speed_mps * 3.6


class CoopMergeEnv(ParallelEnv):
    """
    Cooperative multi-agent merge environment (fixed N agents, waiting pool/inactive supported).
    - All agents exist always (fixed N); inactive agents get dummy obs and their actions are forced NO-OP.
    - Common reward: returned identically for all agents.
    - Action: MultiDiscrete([3, 3]) = (speed_delta, lane_dir)
        speed_delta: 0->-10km/h, 1->0, 2->+10km/h
        lane_dir:    0->left,    1->keep, 2->right
    """

    # Class-level metadata dictionary (name, parallelizable flag, render modes).
    metadata = {
        "name": "coop_merge_env",
        "is_parallelizable": True,
        "render_modes": [None],   # added: an empty list would also work, but listing None is the safer choice
    }

    def __init__(self, num_agents: int = 12, config: Optional[CoopMergeConfig] = None,
                 seed: Optional[int] = None, render_mode: Optional[str] = None):
        """
        Allocate the spaces, the RNG and all per-agent state arrays.

        Args:
            num_agents: number of agents; fixed for the lifetime of the env.
            config: environment parameters; a default ``CoopMergeConfig`` is
                created when omitted.
            seed: seed passed to ``seeding.np_random``.
            render_mode: stored on the instance as-is.
        """
        super().__init__()

        n = int(num_agents)
        self._num_agents = n
        self.cfg = config or CoopMergeConfig()

        # kept as an attribute because SuperSuit reads it
        self.render_mode = render_mode

        agent_names = [f"agent_{k}" for k in range(n)]
        self.possible_agents = agent_names
        self.agents = agent_names[:]  # fixed list

        def per_agent(dtype):
            # one zero-initialised slot per agent
            return np.zeros((n,), dtype=dtype)

        self.agent_spawn_cd = per_agent(np.int64)

        # Observation layout (17 values):
        #   is_active,
        #   has_left_lane, has_right_lane, deadend_dist_feat, self_speed_kmh,
        #   same lane  front/back (speed, dist),
        #   left lane  front/back (speed, dist),
        #   right lane front/back (speed, dist)
        #   => 1 + 2 + 1 + 1 + 3 * 4 = 17
        self.obs_dim = 17
        self._obs_space = spaces.Box(low=-1e9, high=1e9, shape=(self.obs_dim,), dtype=np.float32)

        # (speed_action, lane_action), three choices each
        self._act_space = spaces.MultiDiscrete(np.array([3, 3], dtype=np.int64))

        self.np_random, _ = seeding.np_random(seed)

        # kinematic state
        self.t = 0
        self.active = per_agent(bool)
        self.x = per_agent(np.float64)
        self.y = per_agent(np.float64)
        self.lane = per_agent(np.int64)  # 0,1,2,3
        self.v_mps = per_agent(np.float64)

        # lane-change bookkeeping (y is interpolated while a change is running)
        self.lc_rem = per_agent(np.int64)
        self.lc_tot = per_agent(np.int64)
        self.lc_start_y = per_agent(np.float64)
        self.lc_end_y = per_agent(np.float64)
        self.lc_target_lane = per_agent(np.int64)

        # spawn cooldown per lane (lanes 0..3)
        self.spawn_cd = np.zeros((4,), dtype=np.int64)

        # crash flag cache
        self._crashed_any = False

    # -------------------------
    # PettingZoo API
    # -------------------------

    def observation_space(self, agent: str):
        return self._obs_space

    def action_space(self, agent: str):
        return self._act_space

    def reset(self, seed: Optional[int] = None, options: Optional[dict] = None):
        if seed is not None:
            self.np_random, _ = seeding.np_random(seed)

        self.t = 0
        self._crashed_any = False
        self.agent_spawn_cd[:] = 0

        self.active[:] = False
        self.x[:] = 0.0
        self.y[:] = 0.0
        self.lane[:] = 0
        self.v_mps[:] = 0.0

        self.lc_rem[:] = 0
        self.lc_tot[:] = 0
        self.lc_start_y[:] = 0.0
        self.lc_end_y[:] = 0.0
        self.lc_target_lane[:] = 0

        self.spawn_cd[:] = 0

        # 初期は全員待機プールに入れておき、stepごとに安全にスポーン
        # ただし "最初からある程度走っている車" が欲しければ、ここで数台だけ強制スポーンしてOK
        self._spawn_from_pool(max_new=min(self._num_agents, 4))

        obs = {a: self._obs(i) for i, a in enumerate(self.possible_agents)}
        infos = {a: {} for a in self.possible_agents}
        return obs, infos

    def step(self, actions: Dict[str, np.ndarray]):
        """
        Advance the simulation by one step.

        Order of operations: tick spawn cooldowns, apply each active agent's
        action, move vehicles and progress lane changes, detect dead-ends and
        same-lane collisions, mark goals, accumulate the shared reward, send
        crashed/finished vehicles to the waiting pool, try to respawn, then
        advance the clock.

        Args:
            actions: mapping from agent name to a 2-element action
                ``(speed_action, lane_action)``; a missing entry is treated
                as ``[1, 1]`` (no speed change, keep lane).

        Returns:
            ``(obs, rewards, terminations, truncations, infos)``, each a dict
            keyed by every name in ``possible_agents``. All rewards are the
            same team value, terminations are always False, and truncations
            become True once ``t >= episode_horizon``.
        """
        common_reward = 0.0

        # 1) tick down the per-lane and per-agent spawn cooldowns (floored at 0)
        self.spawn_cd = np.maximum(self.spawn_cd - 1, 0)
        self.agent_spawn_cd = np.maximum(self.agent_spawn_cd - 1, 0)


        # 2) apply actions (inactive agents are ignored)
        dv_kmh = self.cfg.dv_kmh
        applied_abs_dv_kmh = np.zeros((self._num_agents,), dtype=np.float64)
        lane_change_started = np.zeros((self._num_agents,), dtype=bool)
        lane_change_is_merge = np.zeros((self._num_agents,), dtype=bool)

        for i, agent in enumerate(self.possible_agents):
            if not self.active[i]:
                continue

            act = actions.get(agent, np.array([1, 1], dtype=np.int64))
            act = np.asarray(act)

            # while a lane change is in progress the lane action is forced to "keep"
            if self.lc_rem[i] > 0:
                lane_act = 1
            else:
                lane_act = int(act[1])
            speed_act = int(act[0])

            # speed action 0/1/2 -> -dv / 0 / +dv
            dv = {-1: -dv_kmh, 0: 0.0, 1: dv_kmh}[speed_act - 1]
            # NOTE(review): this stores the commanded delta, so the accel penalty is
            # charged even when the clip below leaves the speed unchanged - confirm intended.
            applied_abs_dv_kmh[i] = abs(dv)

            v_kmh = mps_to_kmh(self.v_mps[i]) + dv
            v_kmh = float(np.clip(v_kmh, self.cfg.vclip_min_kmh, self.cfg.vmax_kmh))
            self.v_mps[i] = kmh_to_mps(v_kmh)

            if self.lc_rem[i] == 0:
                dir_ = {0: -1, 1: 0, 2: +1}[lane_act]
                if dir_ != 0:
                    ok, new_lane, is_merge = self._can_start_lane_change(i, dir_)
                    if ok:
                        lane_change_started[i] = True
                        lane_change_is_merge[i] = is_merge
                        self._start_lane_change(i, new_lane)

        # 3) physics update
        for i in range(self._num_agents):
            if not self.active[i]:
                continue

            # move
            self.x[i] += self.v_mps[i] * self.cfg.dt

            # lane change progress (smooth y)
            if self.lc_rem[i] > 0:
                self.lc_rem[i] -= 1
                alpha = 1.0 - (self.lc_rem[i] / max(1, self.lc_tot[i]))
                self.y[i] = (1.0 - alpha) * self.lc_start_y[i] + alpha * self.lc_end_y[i]

                # Important: if a ramp (3) -> lane 2 merge is still running when the
                # vehicle passes merge_end, commit it to lane 2 immediately so that it
                # is not flagged as a dead-end below.
                if int(self.lane[i]) == 3 and int(self.lc_target_lane[i]) == 2 and float(self.x[i]) >= self.cfg.merge_end:
                    self.lane[i] = 2
                    self.lc_rem[i] = 0
                    self.lc_tot[i] = 0
                    self.y[i] = self._lane_center_y(2, float(self.x[i]))
                    self.lc_start_y[i] = self.y[i]
                    self.lc_end_y[i] = self.y[i]
                    self.lc_target_lane[i] = 2

                elif self.lc_rem[i] == 0:
                    # finalize lane (normal completion)
                    self.lane[i] = int(self.lc_target_lane[i])
                    self.y[i] = self._lane_center_y(int(self.lane[i]), float(self.x[i]))

            else:
                self.y[i] = self._lane_center_y(int(self.lane[i]), float(self.x[i]))


        # 4) dead-end mask: still on the ramp lane at or beyond merge_end
        deadend_mask = np.zeros((self._num_agents,), dtype=bool)
        for i in range(self._num_agents):
            if not self.active[i]:
                continue
            if int(self.lane[i]) == 3 and float(self.x[i]) >= self.cfg.merge_end:
                deadend_mask[i] = True

        # 5) collision mask (simplified)
        collision_mask = np.zeros((self._num_agents,), dtype=bool)
        # both settings are optional on the config; fall back to 2.0 / None when absent
        collision_dist = float(getattr(self.cfg, "collision_dist", 2.0))
        vehicle_length = getattr(self.cfg, "vehicle_length", None)
        if vehicle_length is not None:
            collision_dist = max(collision_dist, 0.5 * float(vehicle_length))

        # Per lane: sort active vehicles by x and flag both members of every
        # neighbouring pair whose gap is below collision_dist.
        # NOTE(review): only end-of-step positions are compared, so a vehicle that
        # passes another within one step without ending closer than collision_dist
        # is not flagged - confirm whether that is acceptable.
        if np.any(self.active):
            for L in np.unique(self.lane[self.active].astype(int)):
                idx = np.where(self.active & (self.lane.astype(int) == int(L)))[0]
                if idx.size <= 1:
                    continue
                order = idx[np.argsort(self.x[idx])]
                gaps = np.diff(self.x[order])
                hit = np.where(gaps < collision_dist)[0]
                if hit.size > 0:
                    collision_mask[order[hit]] = True
                    collision_mask[order[hit + 1]] = True

        crash_mask = deadend_mask | collision_mask
        n_crash = int(np.sum(crash_mask))
        if n_crash > 0:
            common_reward += float(self.cfg.reward_crash) * float(n_crash)

        # (b) goals are only marked here; deactivation happens later
        # NOTE(review): crashed vehicles are still active at this point, so one that is
        # also past goal_x is counted as a goal as well - confirm intended.
        goal_mask = np.zeros((self._num_agents,), dtype=bool)
        goal_hits = 0
        for i in range(self._num_agents):
            if not self.active[i]:
                continue
            if float(self.x[i]) >= self.cfg.goal_x:
                goal_mask[i] = True
                goal_hits += 1

        if goal_hits > 0:
            common_reward += self.cfg.reward_goal * float(goal_hits)

        # (c) close-front penalty (computed while crashed/finished vehicles are still active)
        common_reward += self._close_front_penalty_sum()

        # (d) accel penalty
        common_reward += -self.cfg.accel_penalty_scale * float(np.sum(applied_abs_dv_kmh[self.active]))

        # (e) lane change penalty except merging
        if np.any(lane_change_started):
            non_merge = lane_change_started & (~lane_change_is_merge)
            common_reward += self.cfg.lane_change_penalty * float(np.sum(non_merge))

        # --- 7) send to pool: always executed for every masked vehicle ---
        active_pre = self.active.copy()  # for infos (state before pooling)
        # keep the event positions: coordinates change after pooling / respawn
        x_pre = self.x.copy()
        y_pre = self.y.copy()
        lane_pre = self.lane.copy()

        to_pool = crash_mask | goal_mask

        if np.any(to_pool):
            for i in np.where(to_pool)[0]:
                reason = "crash" if crash_mask[i] else "goal"
                self._deactivate_to_pool(int(i), reason=reason)  # pass the reason along

        # --- 8) spawn ---
        self._spawn_from_pool(max_new=self._num_agents)

        # --- 9) time ---
        self.t += 1
        time_up = (self.t >= int(self.cfg.episode_horizon))

        obs = {a: self._obs(i) for i, a in enumerate(self.possible_agents)}
        rewards = {a: float(common_reward) for a in self.possible_agents}
        terminations = {a: False for a in self.possible_agents}
        truncations = {a: bool(time_up) for a in self.possible_agents}

        # infos carry event flags plus position data (compatible: keys are only added)
        # number of non-merge lane changes started in this step
        non_merge_lc = lane_change_started & (~lane_change_is_merge)
        non_merge_lc_count = int(np.sum(non_merge_lc))

        # number of merge lane changes started in this step
        merge_lc_count = int(np.sum(lane_change_started & lane_change_is_merge))

        # number of accel/decel commands (|dv| > 0); masked with the post-pool/post-spawn active set
        accel_count = int(np.sum((applied_abs_dv_kmh[self.active] > 0)))

        # team-wide event totals for this step
        goal_count      = int(goal_hits)
        crash_count     = int(n_crash)
        deadend_count   = int(np.sum(deadend_mask))
        collision_count = int(np.sum(collision_mask))

        infos = {}
        for i, a in enumerate(self.possible_agents):
            infos[a] = {
                # pre-existing keys
                "active_pre": bool(active_pre[i]),
                "event_goal": bool(goal_mask[i]),
                "event_crash": bool(crash_mask[i]),
                "event_deadend": bool(deadend_mask[i]),
                "event_collision": bool(collision_mask[i]),

                # x/y/lane are the state after pooling and respawn;
                # event_* are the values captured before pooling
                "x": float(self.x[i]),
                "y": float(self.y[i]),
                "lane": int(self.lane[i]),
                "event_x": float(x_pre[i]),
                "event_y": float(y_pre[i]),
                "event_lane": int(lane_pre[i]),

                # team aggregates: identical for every agent
                "team_goal_count": goal_count,
                "team_crash_count": crash_count,
                "team_deadend_count": deadend_count,
                "team_collision_count": collision_count,
                "team_nonmerge_lc_count": non_merge_lc_count,
                "team_merge_lc_count": merge_lc_count,
                "team_accel_count": accel_count,
            }
        return obs, rewards, terminations, truncations, infos




    # -------------------------
    # Core mechanics
    # -------------------------
    def _lane_center_y(self, lane: int, x: float) -> float:
        # main lanes: y = 0,4,8 (right-hand traffic doesn't matter in top-down)
        if lane in (0, 1, 2):
            return float(lane) * self.cfg.lane_width

        # ramp lane (3): approach then merge-lane
        # - x < merge_start: y decreases linearly (farther down)
        # - merge_start..merge_end: y=12 (one lane below lane2 at y=8)
        # Note: this gives a "ramp approach" while keeping merge-lane length=200m.
        y_merge = 3.0 * self.cfg.lane_width  # 12
        y_far = 4.0 * self.cfg.lane_width    # 16
        if x <= self.cfg.merge_start:
            # linear from (0,y_far) -> (merge_start, y_merge)
            if self.cfg.merge_start <= 1e-9:
                return y_merge
            a = float(np.clip(x / self.cfg.merge_start, 0.0, 1.0))
            return (1.0 - a) * y_far + a * y_merge
        else:
            return y_merge

    def _can_start_lane_change(self, i: int, dir_: int) -> Tuple[bool, int, bool]:
        cur = int(self.lane[i])
        new_lane = cur + int(dir_)

        # bounds (0..3)
        if new_lane < 0 or new_lane > 3:
            return False, cur, False

        # disallow main -> ramp
        if cur in (0, 1, 2) and new_lane == 3:
            return False, cur, False

        # ramp -> main only allowed into lane2 and only in merge zone
        if cur == 3:
            # must go left into lane2
            if new_lane != 2:
                return False, cur, False
            x = float(self.x[i])
            if x < self.cfg.merge_start or x > self.cfg.merge_end:
                return False, cur, False
            return True, new_lane, True  # merging

        # main lanes normal (0<->1<->2)
        return True, new_lane, False

    def _start_lane_change(self, i: int, new_lane: int) -> None:
        self.lc_tot[i] = int(self.cfg.lane_change_steps)
        self.lc_rem[i] = int(self.cfg.lane_change_steps)
        self.lc_target_lane[i] = int(new_lane)
        self.lc_start_y[i] = float(self.y[i])
        self.lc_end_y[i] = self._lane_center_y(int(new_lane), float(self.x[i]))

    def _check_any_collision(self) -> bool:
        # O(N^2) - OK for moderate N
        idxs = np.where(self.active)[0]
        if len(idxs) <= 1:
            return False
        r = float(self.cfg.collision_radius)
        rr = r * r
        for a in range(len(idxs)):
            i = int(idxs[a])
            xi, yi = float(self.x[i]), float(self.y[i])
            for b in range(a + 1, len(idxs)):
                j = int(idxs[b])
                dx = float(self.x[j]) - xi
                dy = float(self.y[j]) - yi
                if (dx * dx + dy * dy) <= rr:
                    return True
        return False

    def _close_front_penalty_sum(self) -> float:
        # For each active vehicle: find nearest front vehicle in same (integer) lane.
        # If dist < 20 => -40 + dist*2
        penalty = 0.0
        for i in range(self._num_agents):
            if not self.active[i]:
                continue
            li = int(self.lane[i])
            xi = float(self.x[i])
            # find front in same lane among active
            best = None
            for j in range(self._num_agents):
                if i == j or (not self.active[j]):
                    continue
                if int(self.lane[j]) != li:
                    continue
                dx = float(self.x[j]) - xi
                if dx <= 0:
                    continue
                if best is None or dx < best:
                    best = dx
            if best is not None and best < self.cfg.close_dist_threshold:
                penalty += self.cfg.close_penalty_base + self.cfg.close_penalty_slope * float(best)
        return float(penalty)

    def _deactivate_to_pool(self, i: int, reason: str = ""):
        self.active[i] = False

        # block respawn for at least 1 step (prevents lane "teleport" in same step)
        self.agent_spawn_cd[i] = 1

        # hard reset kinematics
        self.x[i] = 0.0
        self.v_mps[i] = 0.0

        # hard reset lane / lateral position
        self.lane[i] = 0
        self.y[i] = self._lane_center_y(0, 0.0)

        # hard reset lane-change state
        self.lc_rem[i] = 0
        self.lc_tot[i] = 0
        self.lc_start_y[i] = self.y[i]
        self.lc_end_y[i] = self.y[i]
        self.lc_target_lane[i] = 0

        if not hasattr(self, "_last_deactivate_reason"):
            self._last_deactivate_reason = [""] * self._num_agents
        self._last_deactivate_reason[i] = str(reason)




    def _is_spawn_position_free(self, x: float, y: float) -> bool:
        # check min distance to all active vehicles
        md = float(self.cfg.spawn_min_dist)
        md2 = md * md
        for j in range(self._num_agents):
            if not self.active[j]:
                continue
            dx = float(self.x[j]) - x
            dy = float(self.y[j]) - y
            if (dx * dx + dy * dy) < md2:
                return False
        return True

    def _spawn_one(self, agent_index: int, lane: int) -> bool:
        # spawn at x=0 fixed
        if self.spawn_cd[lane] > 0:
            return False

        x0 = 0.0
        y0 = self._lane_center_y(lane, x0)
        if not self._is_spawn_position_free(x0, y0):
            return False

        v0_kmh = float(self.np_random.uniform(self.cfg.vmin_kmh, self.cfg.vmax_kmh))
        self.active[agent_index] = True
        self.x[agent_index] = x0
        self.y[agent_index] = y0
        self.lane[agent_index] = int(lane)
        self.v_mps[agent_index] = kmh_to_mps(v0_kmh)

        self.lc_rem[agent_index] = 0
        self.lc_tot[agent_index] = 0

        # cooldown for that lane
        self.spawn_cd[lane] = int(self.cfg.spawn_lane_cooldown_steps)
        return True

    def _spawn_from_pool(self, max_new: int) -> int:
        if max_new <= 0:
            return 0

        # only spawn agents that are inactive AND not in cooldown
        inactive = np.where((~self.active) & (self.agent_spawn_cd == 0))[0]
        if len(inactive) == 0:
            return 0

        spawned = 0
        # multiple attempts per step to avoid "stuck"
        attempts = int(self.cfg.spawn_attempts_per_step)
        # candidate lanes: 0..3
        lanes = [0, 1, 2, 3]

        # shuffle inactive order
        self.np_random.shuffle(inactive)

        for idx in inactive:
            if spawned >= max_new:
                break
            ok = False
            for _ in range(attempts):
                lane = int(self.np_random.choice(lanes))
                if self._spawn_one(int(idx), lane):
                    ok = True
                    spawned += 1
                    break
            if not ok:
                # couldn't spawn this agent now
                continue
        return spawned


    # -------------------------
    # Observation
    # -------------------------
    def _adj_obs_allowed(self, i: int) -> bool:
        # restriction applies to ramp lane (3) and rightmost main lane (2)
        li = int(self.lane[i])
        if li in (2, 3):
            return float(self.x[i]) >= self.cfg.adj_obs_unlock_x
        return True

    def _has_left_lane(self, li: int) -> bool:
        return li > 0 and li != 3  # ramp has left (lane2) but treat separately

    def _has_right_lane(self, li: int, x: float) -> bool:
        # main lanes: right exists only for lane2 IF ramp exists at this x (< merge_end)
        if li == 2:
            return x < self.cfg.merge_end
        # ramp has no right
        return False

    def _deadend_dist_feature(self, i: int) -> float:
        # feature = max_deadend_obs - actual_dist (clipped)
        if int(self.lane[i]) == 3:
            # ramp ends at merge_end
            dist = max(0.0, self.cfg.merge_end - float(self.x[i]))
            dist = float(np.clip(dist, 0.0, self.cfg.max_deadend_obs))
            return float(self.cfg.max_deadend_obs - dist)
        # main lanes: no deadend within 400 => treat dist=max -> feature 0
        return 0.0

    def _nearest_front_back_in_lane(self, i: int, lane_id: int) -> Tuple[float, float, float, float]:
        """
        return (front_speed_kmh, front_dist_feat, back_speed_kmh, back_dist_feat)
        where dist_feat = max_obs_dist - actual_dist (clip)
        if none: speed=0, dist_feat=0
        """
        if not self.active[i]:
            return 0.0, 0.0, 0.0, 0.0

        xi = float(self.x[i])
        best_front = None
        best_back = None
        front_j = -1
        back_j = -1

        for j in range(self._num_agents):
            if j == i or (not self.active[j]):
                continue
            if int(self.lane[j]) != int(lane_id):
                continue
            dx = float(self.x[j]) - xi
            if dx > 0:
                if (best_front is None) or (dx < best_front):
                    best_front = dx
                    front_j = j
            elif dx < 0:
                dd = -dx
                if (best_back is None) or (dd < best_back):
                    best_back = dd
                    back_j = j

        def dist_feat(d: Optional[float]) -> float:
            if d is None:
                return 0.0
            d2 = float(np.clip(d, 0.0, self.cfg.max_obs_dist))
            return float(self.cfg.max_obs_dist - d2)

        if best_front is None:
            fs, fd = 0.0, 0.0
        else:
            fs = float(mps_to_kmh(self.v_mps[front_j]))
            fd = dist_feat(best_front)

        if best_back is None:
            bs, bd = 0.0, 0.0
        else:
            bs = float(mps_to_kmh(self.v_mps[back_j]))
            bd = dist_feat(best_back)

        return fs, fd, bs, bd

    def _obs(self, i: int) -> np.ndarray:
        # inactive: dummy obs
        if not self.active[i]:
            # is_active=0 and the rest zeros
            o = np.zeros((self.obs_dim,), dtype=np.float32)
            o[0] = 0.0
            return o

        li = int(self.lane[i])
        xi = float(self.x[i])
        vi_kmh = float(mps_to_kmh(self.v_mps[i]))

        adj_ok = self._adj_obs_allowed(i)

        # lane existence flags (also gated if adj not allowed for lane2/3)
        if not adj_ok and li in (2, 3):
            has_left = 0.0
            has_right = 0.0
        else:
            if li == 3:
                has_left = 1.0  # ramp has left (lane2)
                has_right = 0.0
            else:
                has_left = 1.0 if (li > 0) else 0.0
                has_right = 1.0 if self._has_right_lane(li, xi) else 0.0

        deadend_feat = self._deadend_dist_feature(i)

        # same lane neighbors always observable (within 200m cap)
        fs, fd, bs, bd = self._nearest_front_back_in_lane(i, li)

        # adjacent lanes (gated)
        if not adj_ok and li in (2, 3):
            lfs = lfd = lbs = lbd = 0.0
            rfs = rfd = rbs = rbd = 0.0
        else:
            # left lane id
            if li == 3:
                left_id = 2
            else:
                left_id = li - 1
            if left_id < 0 or (li == 0):
                lfs = lfd = lbs = lbd = 0.0
            else:
                lfs, lfd, lbs, lbd = self._nearest_front_back_in_lane(i, left_id)

            # right lane id
            # only lane2 has ramp on right; lane0/1 right is main, ramp none
            if li == 2 and (xi < self.cfg.merge_end):
                right_id = 3
                rfs, rfd, rbs, rbd = self._nearest_front_back_in_lane(i, right_id)
            else:
                # main lanes: right is li+1
                if li in (0, 1):
                    right_id = li + 1
                    rfs, rfd, rbs, rbd = self._nearest_front_back_in_lane(i, right_id)
                else:
                    rfs = rfd = rbs = rbd = 0.0

        o = np.array(
            [
                1.0,                    # is_active
                has_left, has_right,    # lane existence
                deadend_feat,           # dead-end distance feature
                vi_kmh,                 # self speed km/h

                fs, fd, bs, bd,         # same lane front/back
                lfs, lfd, lbs, lbd,     # left lane front/back
                rfs, rfd, rbs, rbd,     # right lane front/back
            ],
            dtype=np.float32
        )
        assert o.shape == (self.obs_dim,)
        return o


Writing coop_merge_env.py


In [None]:
# モジュールを更新した場合の再読み込み
# import coop_merge_env
# import importlib

# importlib.reload(coop_merge_env)

# # 以降は「reload後のモジュール」から取り直すのが安全
# from coop_merge_env import CoopMergeEnv, CoopMergeConfig


In [None]:
# デバッグ用

import numpy as np

def run_goal_crash_sanity_infos(env, steps=700, print_every=50, seed=0):
    """
    Drive ``env`` with random actions and print event counts taken from ``infos``.

    Runs at most ``steps`` steps, printing a status line every
    ``print_every`` steps, and stops early as soon as any agent reports a
    truncation. Totals of goal / dead-end / collision / crash events are
    printed at the end.
    """
    env.reset(seed=seed)

    event_keys = ("event_goal", "event_deadend", "event_collision", "event_crash")
    totals = dict.fromkeys(event_keys, 0)

    for step_no in range(1, steps + 1):
        actions = {a: env.action_space(a).sample() for a in env.possible_agents}
        _obs, _rew, _term, trunc, infos = env.step(actions)

        # count the events reported for this step
        hits = {
            key: sum(int(infos[a].get(key, False)) for a in env.possible_agents)
            for key in event_keys
        }
        for key in event_keys:
            totals[key] += hits[key]
        goal_hits, dead_hits, col_hits, crash_hits = (hits[key] for key in event_keys)

        if step_no % print_every == 0:
            x_max = float(np.max(env.x[env.active])) if np.any(env.active) else 0.0
            print(f"[t={step_no:4d}] x_max={x_max:8.3f} active_sum={int(np.sum(env.active))} "
                  f"goal+{goal_hits} deadend+{dead_hits} coll+{col_hits} crash+{crash_hits} trunc={any(trunc.values())}")

        if any(trunc.values()):
            print(f"[OK] time_up trunc at t={step_no}")
            break

    print("=== totals ===")
    print("goal:", totals["event_goal"], "deadend:", totals["event_deadend"],
          "collision:", totals["event_collision"], "crash:", totals["event_crash"])

# Run the sanity check.
# The parent directory is appended to sys.path before importing coop_merge_env.
import sys
sys.path.append("..")
from coop_merge_env import CoopMergeEnv, CoopMergeConfig


# 12 agents with the default config; run up to 700 random-action steps,
# printing a status line every 50 steps.
env = CoopMergeEnv(num_agents=12, config=CoopMergeConfig(), seed=0)
run_goal_crash_sanity_infos(env, steps=700, print_every=50, seed=0)


[t=  50] x_max= 926.247 active_sum=12 goal+0 deadend+0 coll+0 crash+0 trunc=False
[t= 100] x_max= 999.106 active_sum=11 goal+1 deadend+0 coll+0 crash+0 trunc=False
[t= 150] x_max= 956.614 active_sum=12 goal+0 deadend+0 coll+0 crash+0 trunc=False
[t= 200] x_max= 989.269 active_sum=12 goal+0 deadend+0 coll+0 crash+0 trunc=False
[t= 250] x_max= 871.370 active_sum=12 goal+0 deadend+0 coll+0 crash+0 trunc=False
[t= 300] x_max= 931.232 active_sum=11 goal+0 deadend+0 coll+0 crash+0 trunc=False
[t= 350] x_max= 938.889 active_sum=12 goal+0 deadend+0 coll+0 crash+0 trunc=False
[t= 400] x_max= 911.793 active_sum=12 goal+0 deadend+0 coll+0 crash+0 trunc=False
[t= 450] x_max= 828.383 active_sum=12 goal+0 deadend+0 coll+0 crash+0 trunc=False
[t= 500] x_max= 902.992 active_sum=11 goal+1 deadend+0 coll+0 crash+0 trunc=True
[OK] time_up trunc at t=500
=== totals ===
goal: 116 deadend: 3 collision: 53 crash: 56


## 実行

In [None]:
import supersuit as ss
from stable_baselines3 import PPO
from stable_baselines3.common.callbacks import CheckpointCallback
from stable_baselines3.common.vec_env import VecMonitor

base = CoopMergeEnv(num_agents=12, config=CoopMergeConfig(), seed=0)

# PettingZoo ParallelEnv -> (gym-like) vector env
venv = ss.pettingzoo_env_to_vec_env_v1(base)

# Required: turn it into a VecEnv that SB3 can consume directly
# (do this even when num_vec_envs=1).
venv = ss.concat_vec_envs_v1(
    venv,
    num_vec_envs=1,          # 1 is fine to start with; increase for more parallel copies
    num_cpus=0,              # 0 is fine on Colab (no multiprocessing)
    base_class="stable_baselines3",
)
venv = VecMonitor(venv)

model = PPO("MlpPolicy", venv, verbose=1)


# --- checkpoint: periodic saving (optional but recommended) ---
# CheckpointCallback counts calls to env.step(), and every call advances
# venv.num_envs timesteps. Dividing by num_envs makes a checkpoint land about
# every 500k timesteps instead of every 500k * num_envs timesteps.
ckpt_cb = CheckpointCallback(
    save_freq=max(500_000 // venv.num_envs, 1),
    save_path="./checkpoints_coopmerge",
    name_prefix="ppo_coopmerge_finetune"
)

# Training
model.learn(
    total_timesteps=100_000,           # in practice roughly 5,000,000 are needed
    reset_num_timesteps=False,         # keep the timestep counter running (no reset)
    callback=ckpt_cb,
    progress_bar=True
)

# --- save ---
save_path = "./ppo_trained" + ".zip"
model.save(save_path)
print("saved:", save_path)

Gym has been unmaintained since 2022 and does not support NumPy 2.0 amongst other critical functionality.
Please upgrade to Gymnasium, the maintained drop-in replacement of Gym, or contact the authors of your software and request that they upgrade.
See the migration guide at https://gymnasium.farama.org/introduction/migration_guide/ for additional information.


Using cpu device


  return datetime.utcnow().replace(tzinfo=utc)


Output()

----------------------------------
| rollout/           |           |
|    ep_len_mean     | 500       |
|    ep_rew_mean     | -1.36e+04 |
| time/              |           |
|    fps             | 3882      |
|    iterations      | 1         |
|    time_elapsed    | 6         |
|    total_timesteps | 24576     |
----------------------------------
-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 500         |
|    ep_rew_mean          | -1.38e+04   |
| time/                   |             |
|    fps                  | 1301        |
|    iterations           | 2           |
|    time_elapsed         | 37          |
|    total_timesteps      | 49152       |
| train/                  |             |
|    approx_kl            | 0.006056854 |
|    clip_fraction        | 0.0363      |
|    clip_range           | 0.2         |
|    entropy_loss         | -2.19       |
|    explained_variance   | -0.000229   |
|    learning_rate  

saved: ./ppo_trained.zip


In [None]:
import numpy as np

# -----------------------
# VecEnv (num_envs=1) : 学習時と同じ作り方
# -----------------------
base = CoopMergeEnv(num_agents=12, config=CoopMergeConfig(), seed=0)

# PettingZoo ParallelEnv -> (gym-like) vector env
venv = ss.pettingzoo_env_to_vec_env_v1(base)

venv = ss.concat_vec_envs_v1(
    venv,
    num_vec_envs=1,          # 1 is fine to start with; increase for more parallel copies
    num_cpus=0,              # 0 is fine on Colab (no multiprocessing)
    base_class="stable_baselines3",
)
venv = VecMonitor(venv)

# -----------------------
# Load -> (optional tweak) -> Learn more -> Save
# -----------------------
# Lower the learning rate for fine-tuning. It has to be supplied at load time:
# SB3 builds its lr_schedule from `learning_rate` while the model is set up, so
# assigning `model.learning_rate` afterwards would leave the old schedule in use.
model = PPO.load(
    save_path,
    env=venv,
    device="cpu",
    custom_objects={"learning_rate": 1e-4},
)

# --- "less randomness": reduce exploration (entropy) ---
# PPO exploration is mainly tuned through the entropy regularization
# coefficient ent_coef, e.g. lower it from about 0.01 to 0.001.
model.ent_coef = 0.001

# --- checkpoint: periodic saving (optional but recommended) ---
# CheckpointCallback counts calls to env.step(), and every call advances
# venv.num_envs timesteps. Dividing by num_envs makes a checkpoint land about
# every 500k timesteps instead of every 500k * num_envs timesteps.
ckpt_cb = CheckpointCallback(
    save_freq=max(500_000 // venv.num_envs, 1),
    save_path="./checkpoints_coopmerge",
    name_prefix="ppo_coopmerge_finetune"
)

# --- additional training ---
model.learn(
    total_timesteps= 100_000,           # adjust as needed
    reset_num_timesteps=False,          # continue the learning curve of the loaded model
    callback=ckpt_cb,
    progress_bar=True
)

# --- save ---
save_path2 = "./ppo_trained2" + ".zip"
model.save(save_path2)
print("saved:", save_path2)

Output()



----------------------------------
| rollout/           |           |
|    ep_len_mean     | 500       |
|    ep_rew_mean     | -1.16e+04 |
| time/              |           |
|    fps             | 3560      |
|    iterations      | 1         |
|    time_elapsed    | 6         |
|    total_timesteps | 147456    |
----------------------------------
-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 500         |
|    ep_rew_mean          | -1.18e+04   |
| time/                   |             |
|    fps                  | 1377        |
|    iterations           | 2           |
|    time_elapsed         | 35          |
|    total_timesteps      | 172032      |
| train/                  |             |
|    approx_kl            | 0.007044844 |
|    clip_fraction        | 0.0586      |
|    clip_range           | 0.2         |
|    entropy_loss         | -2.14       |
|    explained_variance   | 0           |
|    learning_rate  

saved: ./ppo_trained2.zip


### 評価テスト

In [None]:
import numpy as np
import supersuit as ss
from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import VecMonitor
from coop_merge_env import CoopMergeEnv, CoopMergeConfig

NUM_ENVS   = 12   # ← concat の複製数（= インスタンス数）
NUM_AGENTS = 12   # ← 1インスタンス内のエージェント数

def make_vec_env(num_envs=NUM_ENVS, seed=0):
    """Build a monitored SB3-compatible VecEnv from ``num_envs`` CoopMergeEnv copies.

    Args:
        num_envs: Number of environment copies passed to ``concat_vec_envs_v1``.
        seed: Seed passed to the ``CoopMergeEnv`` constructor.

    Returns:
        The concatenated vector environment wrapped in ``VecMonitor``.
    """
    parallel_env = CoopMergeEnv(num_agents=NUM_AGENTS, config=CoopMergeConfig(), seed=seed)
    stacked = ss.concat_vec_envs_v1(
        ss.pettingzoo_env_to_vec_env_v1(parallel_env),
        num_vec_envs=num_envs,
        num_cpus=0,
        base_class="stable_baselines3",
    )
    return VecMonitor(stacked)

def step_vecenv_compat(venv, action):
    """Step ``venv`` and always return a 4-tuple ``(obs, rew, done, infos)``.

    A 4-tuple result is passed through unchanged. For a 5-tuple result
    ``(obs, rew, term, trunc, infos)`` the two flags are merged with a logical
    OR into a single ``done``.

    Raises:
        RuntimeError: If ``venv.step`` returns neither 4 nor 5 items.
    """
    result = venv.step(action)
    arity = len(result)

    if arity not in (4, 5):
        raise RuntimeError("unexpected step format")

    if arity == 5:
        obs, rew, term, trunc, infos = result
        return obs, rew, np.logical_or(term, trunc), infos

    obs, rew, done, infos = result
    return obs, rew, done, infos

def eval_team_over_seeds(model_path, seeds, max_steps=500, device="cpu", check_dup=True):
    """Evaluate a saved PPO model and return one team return per env instance per seed.

    For every seed a fresh vector env with ``NUM_ENVS`` instances is built via
    ``make_vec_env``, the model is loaded against it and rolled out
    deterministically for at most ``max_steps`` steps. The flat reward vector is
    reshaped to ``(NUM_ENVS, NUM_AGENTS)`` and only column 0 is accumulated, so
    the shared team reward is counted once per instance.

    Args:
        model_path: Path of the saved model passed to ``PPO.load``.
        seeds: Iterable of seeds; each is passed to ``make_vec_env``.
        max_steps: Upper bound on rollout steps per seed.
        device: Device passed to ``PPO.load``.
        check_dup: If true, print a warning whenever the rewards inside one
            instance differ by more than 1e-9.

    Returns:
        1-D float64 array with ``len(seeds) * NUM_ENVS`` team returns.
    """
    all_team_returns = []

    for s in seeds:
        venv = make_vec_env(NUM_ENVS, seed=int(s))
        try:
            assert venv.num_envs == NUM_ENVS * NUM_AGENTS, venv.num_envs  # expected: 144

            model = PPO.load(model_path, env=venv, device=device)

            obs = venv.reset()
            # One accumulator per instance (NUM_ENVS entries), not per agent.
            ep_team = np.zeros(NUM_ENVS, dtype=np.float64)
            # Instances whose first episode already ended. Their later rewards
            # come from a new episode and must not be added to this return.
            finished = np.zeros(NUM_ENVS, dtype=bool)

            for _ in range(max_steps):
                act, _ = model.predict(obs, deterministic=True)
                obs, rew, done, infos = step_vecenv_compat(venv, act)

                rew = np.asarray(rew, dtype=np.float64).reshape(NUM_ENVS, NUM_AGENTS)

                if check_dup:
                    # Check that all agents of an instance received the same reward.
                    max_err = np.max(np.abs(rew - rew[:, [0]]))
                    if max_err > 1e-9:
                        print(f"[seed {s}] WARNING: reward not duplicated within instance. max_err={max_err:e}")

                # Add only the representative column (= team reward), and only
                # for instances that are still in their first episode.
                ep_team += np.where(finished, 0.0, rew[:, 0])

                # done is read from the representative agent of each instance.
                done2 = np.asarray(done).reshape(NUM_ENVS, NUM_AGENTS)
                finished |= done2[:, 0].astype(bool)
                if np.all(finished):
                    break
        finally:
            # Release the vector env even if loading or the rollout fails.
            venv.close()

        all_team_returns.extend(ep_team.tolist())

    return np.array(all_team_returns, dtype=np.float64)

# Run the evaluation.
# Select the saved model to evaluate.
# model_path = "./ppo_coopmerge_finetuned_more.zip" # pretrained model
model_path = save_path2

seeds = [*range(1000, 1020)]

rets = eval_team_over_seeds(model_path, seeds)

ret_mean = float(rets.mean())
ret_std = float(rets.std())
ret_min = float(rets.min())
ret_max = float(rets.max())

print("team eval over seeds:")
print("  n      =", len(rets))  # expected: 240
print("  mean   =", ret_mean)
print("  std    =", ret_std)
print("  min/max=", ret_min, ret_max)


team eval over seeds:
  n      = 240
  mean   = -6400.197013524547
  std    = 752.0056025197899
  min/max= -7474.500192955136 -4852.023220293224


### 結果の描画

In [None]:
# --- 依存（Colab等） ---
!pip -q install "imageio[ffmpeg]" >/dev/null

import numpy as np
import imageio.v2 as imageio
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt

from stable_baselines3 import PPO
from coop_merge_env import CoopMergeEnv, CoopMergeConfig

# Wrap a fresh environment into a monitored single-copy VecEnv and hand it to
# PPO.load together with the saved model.
base = CoopMergeEnv(num_agents=12, config=CoopMergeConfig(), seed=0)
venv = VecMonitor(
    ss.concat_vec_envs_v1(
        ss.pettingzoo_env_to_vec_env_v1(base),
        num_vec_envs=1,
        num_cpus=0,
        base_class="stable_baselines3",
    )
)

# Select the saved model to visualize.
# model_path = "./ppo_coopmerge_finetuned_more.zip" # pretrained model
model_path = save_path2
model = PPO.load(model_path, env=venv, device="cpu")


# -----------------------------
# 1) camera bounds（あなたの安定版に合わせた簡易版）
# -----------------------------
def compute_fixed_camera_bounds(cfg: CoopMergeConfig, mode="merge", pad=25.0, tail_x=150.0):
    """Return fixed camera limits as ``((xmin, xmax), (ymin, ymax))``.

    With ``mode == "full"`` the view extends to ``cfg.goal_x``; with any other
    mode it ends ``tail_x`` past ``cfg.merge_end``. ``pad`` is added on every
    side, and the vertical range spans four lane widths.
    """
    right_edge = cfg.goal_x if mode == "full" else cfg.merge_end + tail_x
    x_bounds = (-pad, right_edge + pad)
    y_bounds = (-pad, 4.0 * cfg.lane_width + pad)
    return x_bounds, y_bounds

# -----------------------------
# 2) ParallelEnv を直接回す（vec化しない）
# -----------------------------
# A separate, non-vectorized environment instance used for the recording
# rollout; `cfg` is the config object handed to it.
cfg = CoopMergeConfig()
env = CoopMergeEnv(num_agents=12, config=cfg, seed=0)

# Agent ordering used by the helpers below when converting between per-agent
# dicts and batched arrays.
AGENTS = env.possible_agents  # fixed ordering


def obs_dict_to_batch(obs_dict):
    """Stack the per-agent observations into one float32 array, ordered by ``AGENTS``."""
    rows = []
    for name in AGENTS:
        rows.append(np.asarray(obs_dict[name], dtype=np.float32))
    return np.stack(rows, axis=0)

def act_batch_to_dict(act_batch):
    """Split a batched action array into a per-agent dict of int64 actions.

    Row ``k`` of ``act_batch`` is assigned to the ``k``-th entry of ``AGENTS``.
    """
    batch = np.asarray(act_batch)
    per_agent = {}
    for idx, name in enumerate(AGENTS):
        per_agent[name] = batch[idx].astype(np.int64)
    return per_agent


# -----------------------------
# 3) warning 判定（描画と同一ロジックを共有）
# -----------------------------
def front_gap_same_lane(env, i: int):
    """Return the distance from vehicle ``i`` to the nearest active vehicle ahead in its lane.

    Only other active vehicles with the same lane index and a strictly larger
    ``x`` are considered. Returns ``None`` when vehicle ``i`` is inactive or no
    such vehicle exists.
    """
    if not env.active[i]:
        return None

    own_lane = int(env.lane[i])
    own_x = float(env.x[i])

    gaps_ahead = []
    for other in range(env._num_agents):
        if other == i or not env.active[other]:
            continue
        if int(env.lane[other]) != own_lane:
            continue
        gap = float(env.x[other]) - own_x
        if gap > 0:
            gaps_ahead.append(gap)

    return min(gaps_ahead) if gaps_ahead else None

def is_warn_vehicle(env, i: int, close_warn_m: float, deadend_warn_m: float) -> bool:
    """Tell whether vehicle ``i`` should be flagged as a warning.

    An active vehicle is flagged when either
    * the gap to the vehicle ahead in its lane is below ``close_warn_m``, or
    * it is in lane 3 and the remaining distance to ``env.cfg.merge_end`` is in
      ``[0, deadend_warn_m)``.
    Inactive vehicles are never flagged.
    """
    if not env.active[i]:
        return False

    close_limit = float(close_warn_m)
    deadend_limit = float(deadend_warn_m)

    lane_idx = int(env.lane[i])
    pos_x = float(env.x[i])

    # Too close to the vehicle ahead in the same lane.
    gap = front_gap_same_lane(env, i)
    if gap is not None and gap < close_limit:
        return True

    # Only lane 3 has a dead end to run into.
    if lane_idx != 3:
        return False

    remaining = float(env.cfg.merge_end) - pos_x
    return 0.0 <= remaining < deadend_limit

def count_warnings(env, close_warn_m, deadend_warn_m):
    """Count the vehicles for which ``is_warn_vehicle`` is currently true."""
    return sum(
        1
        for idx in range(env._num_agents)
        if is_warn_vehicle(env, idx, close_warn_m, deadend_warn_m)
    )


# -----------------------------
# 4) 固定カメラ描画（Matplotlib -> RGB）
#   - warning: orange（車そのもの）
#   - collision flash: 「地点」で1秒 red（flash_spotsで制御）
# -----------------------------
def render_fixed_topdown(
    env, xlim, ylim=None, width=960, height=540, dpi=100,
    y_scale=6.0, pad_y=1.0,
    road_lw=10,
    ramp_lw=12,
    car_len_scale=3.2,
    car_width_scale=1.0,
    veh_edge_lw=2.5,
    close_warn_m=None,
    deadend_warn_m=40.0,
    flash_spots=None,          # collision locations (not vehicle slots) drawn in flash_color
    flash_color="red",
    warn_color="orange"
):
    """Render the current state of ``env`` as a top-down RGB image with a fixed camera.

    Drawing order (back to front): lanes 0-2 and the ramp lane 3 (cut off at
    ``cfg.merge_end``), dashed vertical markers at ``cfg.merge_start`` and
    ``cfg.merge_end``, the collision flash rectangles, and finally the active
    vehicles. Vehicles for which ``is_warn_vehicle`` is true are filled with
    ``warn_color``, all others with white.

    Args:
        env: Environment providing ``cfg``, ``active``, ``x``, ``y``, ``lane``,
            ``lc_rem``, ``_num_agents`` and ``_lane_center_y``.
        xlim: ``(xmin, xmax)`` of the view.
        ylim: ``(ymin, ymax)`` before scaling; defaults to four lane widths
            padded by ``pad_y``.
        width, height, dpi: Figure size in pixels and resolution.
        y_scale: Factor applied to every y coordinate before drawing.
        pad_y: Vertical padding used when ``ylim`` is None.
        road_lw, ramp_lw: Line widths of the main lanes and the ramp.
        car_len_scale, car_width_scale: Vehicle length/width as multiples of
            ``cfg.lane_width``.
        veh_edge_lw: Base outline width of a vehicle.
        close_warn_m: Gap threshold for the warning colour; defaults to
            ``cfg.close_dist_threshold`` (20.0 if the config lacks it).
        deadend_warn_m: Dead-end distance threshold for the warning colour.
        flash_spots: Iterable of dicts with ``"x"`` and ``"y"`` keys.
        flash_color, warn_color: Fill colours for flash spots / warned vehicles.

    Returns:
        ``uint8`` array with the RGB channels of the rendered canvas.
    """
    cfg = env.cfg

    if close_warn_m is None:
        close_warn_m = float(getattr(cfg, "close_dist_threshold", 20.0))
    close_warn_m = float(close_warn_m)
    deadend_warn_m = float(deadend_warn_m)

    if flash_spots is None:
        flash_spots = []

    if ylim is None:
        y_min = 0.0 - pad_y
        y_max = 4.0 * cfg.lane_width + pad_y
        ylim = (y_min, y_max)

    def Y(y):
        return float(y) * float(y_scale)

    fig = plt.figure(figsize=(width / dpi, height / dpi), dpi=dpi)
    try:
        ax = fig.add_axes([0, 0, 1, 1])
        ax.set_xlim(*xlim)

        # Flip the y axis (upper limit first).
        ax.set_ylim(Y(ylim[1]), Y(ylim[0]))

        ax.set_aspect("auto")
        ax.axis("off")

        # -----------------------------
        # roads (background)
        # -----------------------------
        xs_full = np.linspace(float(xlim[0]), float(xlim[1]), 250)

        for lane in (0, 1, 2):
            y = lane * cfg.lane_width
            ax.plot(
                xs_full,
                np.full_like(xs_full, Y(y)),
                linewidth=road_lw,
                solid_capstyle="round",
                zorder=1
            )

        # The ramp lane (3) is cut off at merge_end.
        ramp_x1 = min(float(xlim[1]), float(cfg.merge_end))
        ramp_x0 = float(xlim[0])
        xs_ramp = np.linspace(ramp_x0, ramp_x1, 250)
        ys_ramp = np.array([env._lane_center_y(3, float(x)) for x in xs_ramp], dtype=float)
        ax.plot(
            xs_ramp,
            np.array([Y(y) for y in ys_ramp], dtype=float),
            linewidth=ramp_lw,
            solid_capstyle="round",
            zorder=1
        )

        # Section markers (dashed): in front of the roads, behind the vehicles.
        ax.axvline(cfg.merge_start, linewidth=1, linestyle=(0, (4, 6)), color="k", alpha=0.6, zorder=2)
        ax.axvline(cfg.merge_end,   linewidth=1, linestyle=(0, (4, 6)), color="k", alpha=0.6, zorder=2)

        # -----------------------------
        # vehicle geometry
        # -----------------------------
        veh_h = float(cfg.lane_width) * float(car_width_scale)   # width
        veh_w = float(cfg.lane_width) * float(car_len_scale)     # length
        veh_draw_h = veh_h * float(y_scale)

        # -----------------------------
        # collision flash spots: drawn slightly behind the vehicles
        # -----------------------------
        for s in flash_spots:
            x = float(s["x"])
            y = float(s["y"])
            rect = plt.Rectangle(
                (x - veh_w / 2, Y(y) - veh_draw_h / 2),
                veh_w, veh_draw_h,
                fill=True,
                facecolor=flash_color,
                alpha=0.95,
                linewidth=float(veh_edge_lw) + 1.5,
                edgecolor="k",
                zorder=9,  # behind the vehicles (zorder=10)
            )
            ax.add_patch(rect)

        # -----------------------------
        # vehicles (front-most layer)
        # -----------------------------
        for i in range(env._num_agents):
            if not env.active[i]:
                continue

            x = float(env.x[i])
            y = float(env.y[i])
            lane = int(env.lane[i])

            # Thicker outline for vehicles in lane 3 or with lc_rem > 0.
            edge_lw = float(veh_edge_lw) + (1.0 if (lane == 3 or env.lc_rem[i] > 0) else 0.0)

            # Warning state decides the fill colour.
            warn = is_warn_vehicle(env, i, close_warn_m, deadend_warn_m)

            if warn:
                fc, a = warn_color, 0.95
            else:
                fc, a = "white", 0.85

            rect = plt.Rectangle(
                (x - veh_w / 2, Y(y) - veh_draw_h / 2),
                veh_w, veh_draw_h,
                fill=True,
                facecolor=fc,
                alpha=a,
                linewidth=edge_lw,
                edgecolor="k",
                zorder=10,
            )
            ax.add_patch(rect)

        fig.canvas.draw()
        rgba = np.asarray(fig.canvas.buffer_rgba(), dtype=np.uint8)
        # Copy so the returned array does not depend on the canvas buffer.
        rgb = rgba[..., :3].copy()
        return rgb
    finally:
        # Always release the figure, even if drawing raised, so that figures
        # do not pile up across the many frames of a recording.
        plt.close(fig)


# -----------------------------
# 5) ロールアウトして動画化 + 集計を出力
# -----------------------------
def run_and_record(
    env,
    model,
    out_mp4="demo_coopmerge_fixed.mp4",
    fps=15,
    seconds=12,
    deterministic=True,
    seed=0,
    camera_mode="merge",
    pad=25.0,
    tail_x=180.0,
    width=960,
    height=540,
    # warn/flash params
    close_warn_m=None,
    deadend_warn_m=40.0,
    flash_seconds=1.0,
):
    """Roll out ``model`` in ``env``, write the frames to an MP4 and print a summary.

    One video frame is rendered per ``env.step`` call, for ``int(seconds * fps)``
    frames in total. Whenever any agent reports truncation the environment is
    reset with the same ``seed`` and the rollout continues.

    Args:
        env: Environment stepped directly with per-agent dicts.
        model: Policy providing ``predict(obs_batch, deterministic=...)``.
        out_mp4: Output path of the video.
        fps: Frames per second of the video.
        seconds: Video length in seconds.
        deterministic: Forwarded to ``model.predict``.
        seed: Seed used for every ``env.reset``.
        camera_mode, pad, tail_x: Forwarded to ``compute_fixed_camera_bounds``.
        width, height: Frame size in pixels.
        close_warn_m: Gap threshold for warnings; defaults to
            ``env.cfg.close_dist_threshold`` (20.0 if the config lacks it).
        deadend_warn_m: Dead-end distance threshold for warnings.
        flash_seconds: How long a collision location stays highlighted.

    Returns:
        Path of the written video; this is the ``_fallback`` path when the
        libx264 encoder failed and mpeg4 was used instead.
    """
    # Only xlim is handed to the renderer below; ylim is printed for reference
    # while the renderer is called with ylim=None.
    xlim, ylim = compute_fixed_camera_bounds(env.cfg, mode=camera_mode, pad=pad, tail_x=tail_x)
    print("fixed camera:", "xlim=", xlim, "ylim=", ylim)

    if close_warn_m is None:
        close_warn_m = float(getattr(env.cfg, "close_dist_threshold", 20.0))
    close_warn_m = float(close_warn_m)
    deadend_warn_m = float(deadend_warn_m)

    obs_dict, _info = env.reset(seed=seed)
    frames = []
    n_frames = int(seconds * fps)

    flash_frames = int(round(float(flash_seconds) * fps))
    flash_spots = []  # [{"x":..., "y":..., "ttl":...}]

    # --- counters (for the printed summary) ---
    C = dict(
        goal=0, crash=0, collision=0, deadend=0,
        warn_total=0, warn_max=0,
        warn_enter=0,

        # reward
        reward_team_sum=0.0,
        reward_team_min=+1e18,
        reward_team_max=-1e18,
        reward_step_count=0,

        # debug: collisions whose info carried no event location
        collision_missing_xy=0,
    )

    prev_warn = np.zeros(env._num_agents, dtype=bool)

    for _ in range(n_frames):
        # --- policy ---
        obs_batch = obs_dict_to_batch(obs_dict)
        act_batch, _ = model.predict(obs_batch, deterministic=deterministic)
        actions = act_batch_to_dict(act_batch)

        # --- step ---
        obs_dict, rew_dict, term_dict, trunc_dict, info_dict = env.step(actions)

        # --- reward (team reward) ---
        # The first agent's value is used as the team reward; summing over all
        # agents would count a duplicated common reward once per agent.
        team_rew = float(rew_dict[AGENTS[0]])
        C["reward_team_sum"] += team_rew
        C["reward_team_min"] = min(C["reward_team_min"], team_rew)
        C["reward_team_max"] = max(C["reward_team_max"], team_rew)
        C["reward_step_count"] += 1

        # --- events from infos (only agents flagged active_pre) ---
        # Counting and flashing share one loop and one `active_pre` default, so
        # a collision that is counted is also flashed or reported as missing xy.
        goals = crashes = colls = deads = 0
        for a in AGENTS:
            inf = info_dict.get(a, {})
            if not inf.get("active_pre", True):
                continue
            goals   += int(inf.get("event_goal", False))
            crashes += int(inf.get("event_crash", False))
            colls   += int(inf.get("event_collision", False))
            deads   += int(inf.get("event_deadend", False))

            # Flash the collision *location*; needs event_x/event_y in info.
            if inf.get("event_collision", False):
                if ("event_x" in inf) and ("event_y" in inf):
                    if flash_frames > 0:
                        flash_spots.append({
                            "x": float(inf["event_x"]),
                            "y": float(inf["event_y"]),
                            "ttl": flash_frames
                        })
                else:
                    # Without a location no spot is drawn; only record it.
                    C["collision_missing_xy"] += 1

        C["goal"]      += goals
        C["crash"]     += crashes
        C["collision"] += colls
        C["deadend"]   += deads

        # --- number of warned vehicles + newly warned vehicles ---
        warn_flags = np.zeros(env._num_agents, dtype=bool)
        for i in range(env._num_agents):
            warn_flags[i] = is_warn_vehicle(env, i, close_warn_m, deadend_warn_m)

        warn_now = int(np.sum(warn_flags))
        C["warn_total"] += warn_now
        C["warn_max"] = max(C["warn_max"], warn_now)

        enter = np.logical_and(warn_flags, ~prev_warn)
        C["warn_enter"] += int(np.sum(enter))
        prev_warn = warn_flags

        # --- render ---
        frame = render_fixed_topdown(
            env, xlim=xlim, ylim=None,
            width=width, height=height,
            y_scale=4.0, pad_y=10.0,
            road_lw=12, ramp_lw=14,
            car_len_scale=5.0, car_width_scale=0.3,
            veh_edge_lw=3.0,
            close_warn_m=close_warn_m,
            deadend_warn_m=deadend_warn_m,
            flash_spots=flash_spots,
            warn_color="orange",
            flash_color="red",
        )
        frames.append(frame)

        # --- TTL decay (locations) ---
        # Decay after rendering so a spot is visible for exactly flash_frames
        # frames, including the frame in which the collision was reported.
        for s in flash_spots:
            s["ttl"] -= 1
        flash_spots = [s for s in flash_spots if s["ttl"] > 0]

        # --- reset on time-up ---
        if any(trunc_dict.values()):
            obs_dict, _info = env.reset(seed=seed)
            prev_warn[:] = False
            flash_spots = []

    # --- write ---
    out_mp4 = str(out_mp4)
    try:
        imageio.mimwrite(out_mp4, frames, fps=fps, codec="libx264", pixelformat="yuv420p", macro_block_size=1)
    except Exception as e:
        # Deliberate best-effort: retry with the mpeg4 codec under a fallback name.
        print("WARN: libx264 failed:", repr(e))
        out2 = out_mp4.replace(".mp4", "_fallback.mp4")
        imageio.mimwrite(out2, frames, fps=fps, codec="mpeg4", macro_block_size=1)
        out_mp4 = out2

    # --- summary (printed, not drawn into the video) ---
    print("=== summary (counts in output window) ===")
    print(f"goal      : {C['goal']}")
    print(f"crash     : {C['crash']}")
    print(f"collision : {C['collision']}")
    print(f"deadend   : {C['deadend']}")
    print(f"warn_total (延べwarn台数/フレーム): {C['warn_total']}")
    print(f"warn_max/frame            : {C['warn_max']}")
    print(f"warn_enter (warn突入回数) : {C['warn_enter']}")

    # reward
    mean_step = (C["reward_team_sum"] / C["reward_step_count"]) if C["reward_step_count"] > 0 else 0.0
    print("--- reward (team, per-step) ---")
    print(f"team_return(sum over steps): {C['reward_team_sum']:.6f}")
    print(f"team_reward_mean_per_step  : {mean_step:.6f}")
    print(f"team_reward_min/max_step   : {C['reward_team_min']:.6f} / {C['reward_team_max']:.6f}")

    # collision XY availability
    if C["collision"] > 0 and C["collision_missing_xy"] > 0:
        print(f"[WARN] collision occurred but event_x/event_y missing in info: {C['collision_missing_xy']} times")
        print("       -> step() の info に event_x/event_y を追加してください（観測/報酬/遷移は不変で互換性OK）")

    return out_mp4


# Recording options, collected in one place and forwarded as keyword arguments.
record_options = dict(
    out_mp4="demo_coopmerge_fixed.mp4",
    fps=15,
    seconds=12,
    deterministic=True,
    seed=0,
    camera_mode="merge",
    pad=25.0,
    tail_x=180.0,
    width=960,
    height=540,
    close_warn_m=None,     # None -> use cfg.close_dist_threshold
    deadend_warn_m=40.0,
    flash_seconds=1.0,
)
mp4_path = run_and_record(env, model, **record_options)

print("saved:", mp4_path)

# Show the written video inline in the notebook.
from IPython.display import Video, display
display(Video(mp4_path, embed=True))


fixed camera: xlim= (-25.0, 805.0) ylim= (-25.0, 41.0)
=== summary (counts in output window) ===
goal      : 49
crash     : 6
collision : 6
deadend   : 0
warn_total (延べwarn台数/フレーム): 69
warn_max/frame            : 2
warn_enter (warn突入回数) : 43
--- reward (team, per-step) ---
team_return(sum over steps): -2572.957810
team_reward_mean_per_step  : -14.294210
team_reward_min/max_step   : -248.616894 / 15.900000
saved: demo_coopmerge_fixed.mp4
