In [3]:
# collector.py — intersection-v1 dataset (grayscale), multi-agent with one EGO
# -----------------------------------------------------------------------------
# Paper config:
#   Observation: Grayscale (128, 64), stack_size: 4
#   Action type: DiscreteMetaAction (LANE_LEFT, IDLE, LANE_RIGHT, FASTER, SLOWER)
#   Duration: 30s, Sim/Policy freq: 15 Hz, Vehicles: 5, Spawn prob: 0.2
#
# What this script does:
#   • Builds intersection env (v1/v0) with those settings (multi-agent capable)
#   • Ensures there are N controlled agents; agent[0] is the EGO “ambulance”
#   • Other cars “open the road” (yield/right-bias/slow near ambulance)
#   • Steps the env with balanced labels across 5 actions
#   • Saves the **last grayscale frame** in the stack as a PNG (128×64)
#   • If obs grayscale is blank, falls back to RGB render → grayscale
#   • Writes a single JSONL with rows: {image, action_id, instruction, meta…}
#
# Jupyter-friendly: no argparse. Just run the last cell that calls collect_intersection().
# -----------------------------------------------------------------------------

import os, json, random, math
from pathlib import Path
from collections import Counter, deque
from typing import Dict, Any, List

import numpy as np
from PIL import Image
from tqdm.auto import tqdm

import gymnasium as gym
import highway_env  # registers env ids

import os, pygame
print("SDL_VIDEODRIVER (env):", os.environ.get("SDL_VIDEODRIVER"))
pygame.init()
print("pygame video driver:", pygame.display.get_driver())


# -------------------- knobs you can tweak --------------------
OUTDIR                 = "/home/chettra/ITC/mvs-manus/mvs_test_4/data"  # root folder
FRAMES_TOTAL           = 1000                           # <- as requested
N_CONTROLLED           = 3                              # agent[0] is EGO
SAVE_SIZE              = (128, 64)                      # paper image size (W,H)
STACK_SIZE             = 4
SEED                   = 42

COURTESY_ON            = True   # NPCs yield to ambulance
COURTESY_RADIUS        = 60.0
COURTESY_SLOW_FACTOR   = 0.6
COURTESY_HEADWAY_MULT  = 1.5
RIGHT_LANE_BIAS        = True

# Action labels (paper order; robust mapping created automatically)
ACTION_IDS = ["SLOWER", "IDLE", "FASTER", "LANE_LEFT", "LANE_RIGHT"]

PROMPTS = {
    "SLOWER":     "Reduce speed—collision risk ahead.",
    "IDLE":       "Keep the current lane and speed.",
    "FASTER":     "Accelerate; a safe gap is ahead.",
    "LANE_LEFT":  "Change to the left lane.",
    "LANE_RIGHT": "Change to the right lane.",
}
PARAPHRASES = {
    "SLOWER": [
        "Reduce speed—caution ahead.", "Back off the throttle; traffic up ahead.",
        "Slow down to stay safe.", "Drop speed; the gap is tight.", "Decelerate—possible hazard ahead."
    ],
    "IDLE": [
        "Hold speed and lane.", "Stay steady in this lane.", "Maintain pace; no lane change.",
        "Remain in lane with current speed.", "Continue unchanged."
    ],
    "FASTER": [
        "Accelerate—path looks clear.", "Increase speed; safe gap ahead.",
        "Pick up pace—no blockers.", "Go quicker; open stretch ahead.", "Speed up to target pace."
    ],
    "LANE_LEFT": [
        "Merge left safely.", "Move to the left lane.", "Shift to left lane to pass."
    ],
    "LANE_RIGHT": [
        "Merge right safely.", "Move to the right lane.", "Shift to right lane to yield."
    ],
}

# -------------------- small utils --------------------
def set_all_seeds(s: int = 42):
    random.seed(s); np.random.seed(s)
    try:
        import torch
        torch.manual_seed(s); torch.cuda.manual_seed_all(s)
    except Exception:
        pass

def ensure_dir(p: Path): p.mkdir(parents=True, exist_ok=True)

def to_gray_uint8(x: np.ndarray) -> Image.Image:
    """x can be [0,1] or [0,255], shape (H, W) -> PIL 'L'"""
    arr = np.asarray(x)
    if arr.max() <= 1.0: arr = arr * 255.0
    arr = np.clip(arr, 0, 255).astype(np.uint8)
    return Image.fromarray(arr, mode="L")

def last_gray_from_obs(obs: Any) -> np.ndarray:
    """
    Works with obs in shapes:
      • (S, H, W) or (H, W, S)  (paper grayscale stack)
      • list/tuple for multi-agent — we take ego [0]
    Returns (H, W) float in [0,1].
    """
    if isinstance(obs, (list, tuple)):
        obs = obs[0]  # ego
    a = np.asarray(obs)
    if a.ndim == 3 and a.shape[0] == STACK_SIZE:
        fr = a[-1]
    elif a.ndim == 3 and a.shape[2] == STACK_SIZE:
        fr = a[..., -1]
    elif a.ndim == 2:
        fr = a
    else:
        raise ValueError(f"Unexpected obs shape: {a.shape}")
    return fr.astype(np.float32) / (255.0 if fr.max() > 1.5 else 1.0)

# -------------------- env factory (version-robust) --------------------
def grayscale_obs_config():
    # Some highway-env versions require 'weights'; newer ones accept/ignore it.
    return {
        "type": "GrayscaleObservation",
        "observation_shape": (64, 128),  # H, W  ← swap
        "stack_size": STACK_SIZE,
        "weights": [0.2989, 0.5870, 0.1140],
    }

# --- fixes for blank grayscale frames ---

def rgb_to_gray(img_rgb: np.ndarray) -> np.ndarray:
    """RGB (H,W,3) in [0..255] or [0..1] -> grayscale (H,W) in [0..1]."""
    arr = img_rgb.astype(np.float32)
    if arr.max() > 1.5:
        arr = arr / 255.0
    r, g, b = arr[..., 0], arr[..., 1], arr[..., 2]
    gray = 0.2989 * r + 0.5870 * g + 0.1140 * b
    return np.clip(gray, 0.0, 1.0)

def safe_render(env):
    """Always produce a fresh RGB frame from the viewer."""
    frame = env.render()
    if frame is None:
        frame = env.render()
    return frame

def gray_from_obs_or_render(env, obs) -> np.ndarray:
    """
    Prefer the grayscale observation (last frame of the stack).
    If it looks blank, fall back to converting the RGB render to grayscale.
    Returns (H, W) in [0,1].
    """
    try:
        g = last_gray_from_obs(obs)
        if float(np.nanmax(g)) > 0.02:  # not blank
            return g
    except Exception:
        pass
    rgb = env.render()
    if rgb is None:
        raise RuntimeError("render() returned None; cannot build grayscale frame.")
    return rgb_to_gray(rgb)

def make_intersection_env(seed=SEED, n_controlled=N_CONTROLLED) -> gym.Env:
    cfg = dict(
        offscreen_rendering=False,
        # these affect render() only (not observation); useful for fallback:
        screen_width=700, screen_height=256,
        centering_position=[0.5, 0.5],
        scaling=3.5,
        show_trajectories=False,
        render_agent=True,

        action=dict(type="DiscreteMetaAction"),
        observation=grayscale_obs_config(),   # (128,64), stack=4 — paper
        duration=30,
        simulation_frequency=15,
        policy_frequency=15,
        vehicles_count=5,
        spawn_probability=0.2,
        controlled_vehicles=int(n_controlled),
    )
    env, last_err = None, None
    for env_id in ("intersection-multi-agent-v1", "intersection-multi-agent-v0"):
        for with_render_kw in (True, False):
            try:
                env = gym.make(env_id, render_mode=("rgb_array" if with_render_kw else None), config=cfg)
                env.reset(seed=seed)
                return env
            except Exception as e:
                last_err = e
                env = None
    raise RuntimeError(f"Could not create intersection env. Last error: {last_err!r}")

# -------------------- action mapping (robust) --------------------
def detect_action_mapping(env: gym.Env) -> Dict[str, int]:
    """
    Returns a mapping like:
      {'LANE_LEFT':0, 'IDLE':1, 'LANE_RIGHT':2, 'FASTER':3, 'SLOWER':4}
    Works even if the env exposes actions as '0','1','2','3','4'.
    """
    try:
        aty = env.unwrapped.action_type
        if hasattr(aty, "agents_action_types") and aty.agents_action_types:
            actions = list(aty.agents_action_types[0].actions)
        else:
            actions = list(aty.actions)
        names = []
        for a in actions:
            if isinstance(a, str):
                names.append(a)
            else:
                nm = getattr(a, "name", None)
                names.append(nm if isinstance(nm, str) else str(a))
    except Exception:
        names = ["LANE_LEFT", "IDLE", "LANE_RIGHT", "FASTER", "SLOWER"]

    # If we got ['0','1','2','3','4'], map by standard order:
    if all(isinstance(n, str) and n.isdigit() for n in names) and len(names) == 5:
        std = ["LANE_LEFT", "IDLE", "LANE_RIGHT", "FASTER", "SLOWER"]
        mapping = {std[i]: i for i in range(5)}
    else:
        mapping = {n: i for i, n in enumerate(names)}

    # Ensure all five labels exist (fallback to standard order where missing)
    std = {"LANE_LEFT": 0, "IDLE": 1, "LANE_RIGHT": 2, "FASTER": 3, "SLOWER": 4}
    for k, v in std.items():
        mapping.setdefault(k, v)

    print("[env] Detected action mapping:", mapping)
    return mapping

# -------------------- courtesy (NPCs open road) --------------------
def apply_courtesy(env: gym.Env,
                   radius=COURTESY_RADIUS,
                   slow_factor=COURTESY_SLOW_FACTOR,
                   headway_mult=COURTESY_HEADWAY_MULT,
                   right_bias=RIGHT_LANE_BIAS):
    if not COURTESY_ON: return
    try:
        road = env.unwrapped.road
        cvs  = getattr(env.unwrapped, "controlled_vehicles", []) or []
        ego  = cvs[0] if cvs else env.unwrapped.vehicle
    except Exception:
        return
    if road is None or ego is None: return
    ego_xy = np.asarray(getattr(ego, "position", (0.0, 0.0)), dtype=float)

    for v in list(road.vehicles):
        if v is ego: continue
        pos = getattr(v, "position", None)
        if pos is None: continue
        d = float(np.linalg.norm(np.asarray(pos, dtype=float) - ego_xy))
        if d > radius: continue

        # Slow down / increase headway
        if hasattr(v, "target_speed"):
            try: v.target_speed = max(0.0, float(v.target_speed) * slow_factor)
            except Exception: pass
        elif hasattr(v, "speed"):
            try: v.speed = float(v.speed) * slow_factor
            except Exception: pass
        for attr in ("T", "desired_headway", "desired_time_headway", "desired_gap", "min_gap", "s0"):
            if hasattr(v, attr):
                try: setattr(v, attr, float(getattr(v, attr)) * headway_mult)
                except Exception: pass

        # Bias to right-most lane (opens middle/left for ambulance)
        if right_bias:
            try:
                li = getattr(v, "lane_index", None)
                if isinstance(li, (tuple, list)) and len(li) >= 3:
                    cur = int(li[2]); tgt = max(0, cur - 1)
                    if tgt < cur:
                        if hasattr(v, "go_to_lane"): v.go_to_lane(tgt)
                        elif hasattr(v, "target_lane_index"): v.target_lane_index = (li[0], li[1], tgt)
            except Exception:
                pass

# -------------------- simple side-agent policy --------------------
def side_agent_label(t: int) -> str:
    # Mostly IDLE; occasionally nudge lanes to create variety
    if t % 25 == 0: return random.choice(["IDLE", "LANE_LEFT", "LANE_RIGHT"])
    return "IDLE"

# -------------------- main collector --------------------
def collect_intersection(
    outdir: str = OUTDIR,
    frames_total: int = FRAMES_TOTAL,
    n_controlled: int = N_CONTROLLED,
    seed: int = SEED,
):
    set_all_seeds(seed)
    out = Path(outdir); img_dir = out / "images"
    ensure_dir(img_dir)

    env = make_intersection_env(seed=seed, n_controlled=n_controlled)
    name2idx = detect_action_mapping(env)

    label_counter = Counter()
    schedule = deque()  # (action_id, instruction)

    # build a balanced schedule across 5 actions, round-robin with paraphrases
    per = math.ceil(frames_total / len(ACTION_IDS))
    for a in ACTION_IDS:
        phrases = [PROMPTS[a]] + PARAPHRASES.get(a, [])
        for i in range(per):
            schedule.append((a, phrases[i % len(phrases)]))
    while len(schedule) > frames_total: schedule.pop()

    rows = []
    frame_id = 0
    ep = 0
    pbar = tqdm(total=frames_total, desc="Collecting (intersection-v1)")

    while schedule:
        obs, info = env.reset(seed=seed + ep)
        print("reset: obs type/shape =", type(obs), (np.asarray(obs[0]).shape if isinstance(obs,(list,tuple)) else np.asarray(obs).shape))
        done = trunc = False
        t = 0

        while not (done or trunc) and schedule:
            apply_courtesy(env)

            # --- choose ego action (balanced) ---
            action_id, instruction = schedule[0]
            ego_idx = int(name2idx.get(action_id, name2idx.get("IDLE", 1)))
            label_counter[action_id] += 1

            # --- build multi-agent action vector (ego + side agents) ---
            acts = [ego_idx]
            for k in range(1, n_controlled):
                nm = side_agent_label(t)
                acts.append(int(name2idx.get(nm, name2idx.get("IDLE", 1))))

            # step (tuple for multi-agent; fallback to int for single-agent builds)
            try:
                obs, reward, done, trunc, info = env.step(tuple(acts))
                
            except Exception:
                obs, reward, done, trunc, info = env.step(acts[0])

            # --- diagnostics (do this e.g. when t == 0 or frame_id < 3) ---
            if t == 0:  # or use: if frame_id < 3:
                # show obs shape
                obs_shape = (np.asarray(obs[0]).shape if isinstance(obs, (list, tuple))
                            else np.asarray(obs).shape)
                print("obs type/shape:", type(obs), obs_shape)

                # try grayscale-from-obs
                try:
                    g = last_gray_from_obs(obs)
                    print("obs gray max:", float(np.nanmax(g)))
                except Exception as e:
                    print("last_gray_from_obs failed:", e)

                # try render (can be None/black in headless)
                try:
                    r = env.render()
                    if r is None:
                        print("render() returned None")
                    else:
                        r_arr = np.asarray(r)
                        print("render shape:", r_arr.shape, "render max:", float(np.nanmax(r_arr)))
                except Exception as e:
                    print("render() failed:", e)

            # --- SAVE FULL RENDER FRAME INSTEAD OF OBS CROP ---
            rgb = env.render()  # get full simulator screen
            if rgb is None:
                rgb = env.render()  # safety retry

            # convert to grayscale or keep color
            pil = Image.fromarray(rgb)  # full view, same as human render
            pil = pil.resize(SAVE_SIZE, Image.BILINEAR)
            # pil = Image.fromarray(rgb).convert("L")  # uncomment if you want grayscale

            fname = f"f_{frame_id:06d}.png"
            pil.save(img_dir / fname)
            # -------------------------------------------------

            # write a row (ego label); include a bit of env state
            try:
                ego_speed = float(getattr(env.unwrapped.controlled_vehicles[0], "speed", 0.0))
            except Exception:
                ego_speed = float(info.get("speed", 0.0))

            rows.append({
                "index": frame_id,
                "image": f"images/{fname}",
                "action_id": action_id,
                "instruction": instruction,
                "episode": ep,
                "t": t,
                "ego_speed": ego_speed,
                "n_controlled": n_controlled,
                "courtesy": bool(COURTESY_ON),
            })

            frame_id += 1
            t += 1
            schedule.popleft()
            pbar.update(1)
            if frame_id >= frames_total: break

        ep += 1
        if frame_id >= frames_total: break

    pbar.close()
    env.close()

    # JSONL (no splits; you can split later)
    ensure_dir(Path(outdir))
    with open(Path(outdir) / "frames.jsonl", "w", encoding="utf-8") as f:
        for r in rows:
            f.write(json.dumps(r, ensure_ascii=False) + "\n")

    print(f"✅ Done. Saved {frame_id} frames to {outdir}")
    print("Label counts:", dict(label_counter))





SDL_VIDEODRIVER (env): None
pygame video driver: x11


In [5]:
# ---- run in Jupyter: just execute this cell ----
collect_intersection(
    outdir=OUTDIR,
    frames_total=FRAMES_TOTAL,
    n_controlled=N_CONTROLLED,
    seed=SEED,
)

[env] Detected action mapping: {'LANE_LEFT': 0, 'IDLE': 1, 'LANE_RIGHT': 2, 'FASTER': 3, 'SLOWER': 4}


Collecting (episode-style):  78%|███████▊  | 777/1000 [00:13<00:03, 65.68it/s]

AttributeError: 'NoneType' object has no attribute 'get_image'

Collecting (episode-style):  78%|███████▊  | 780/1000 [00:31<00:03, 65.68it/s]

In [4]:
# --- Override collector with episode-style loop from clip-rl-2 (simplified, no Ollama) ---
# This cell replaces the previous step-loop collector with an episode-based collector
# that uses the PROMPTS / PARAPHRASES defined earlier in the notebook for text labels.

def collect_intersection_episode(
    outdir: str = OUTDIR,
    frames_total: int = None,            # backward-compatible arg used in original run cell
    n_episodes: int = 50,
    frames_per_episode: int = None,
    n_controlled: int = N_CONTROLLED,
    seed: int = SEED,
):
    """Episode-style data collector inspired by clip-rl-2 patterns.
    - Runs multiple episodes; in each episode, selects actions per step from the
      balanced schedule derived from the notebook PROMPTS/PARAPHRASES.
    - No Ollama or external LLM used; labels come from PROMPTS/PARAPHRASES.
    - Saves PNG renders and a single JSONL dataset at outdir/frames.jsonl

    Backwards compatibility:
      If `frames_total` is provided (old API), compute `frames_per_episode = ceil(frames_total / n_episodes)`.
    """
    set_all_seeds(seed)
    out = Path(outdir); img_dir = out / "images"
    ensure_dir(img_dir)

    # Backwards compatibility handling
    if frames_total is not None and frames_per_episode is None:
        frames_per_episode = math.ceil(frames_total / max(1, n_episodes))
    if frames_per_episode is None:
        frames_per_episode = 20

    env = make_intersection_env(seed=seed, n_controlled=n_controlled)
    name2idx = detect_action_mapping(env)

    # Build a per-episode balanced schedule (reused across episodes with shuffling)
    base_schedule = []
    per = math.ceil((n_episodes * frames_per_episode) / len(ACTION_IDS))
    for a in ACTION_IDS:
        phrases = [PROMPTS[a]] + PARAPHRASES.get(a, [])
        for i in range(per):
            base_schedule.append((a, phrases[i % len(phrases)]))

    # Shuffle for variety but keep deterministic via seed
    random.Random(seed).shuffle(base_schedule)

    rows = []
    frame_id = 0
    pbar = tqdm(total=n_episodes * frames_per_episode, desc="Collecting (episode-style)")

    schedule_iter = iter(base_schedule)

    for ep in range(n_episodes):
        obs, info = env.reset(seed=seed + ep)
        done = trunc = False
        t = 0

        while t < frames_per_episode:
            apply_courtesy(env)

            # Get next label from schedule (wrap if needed)
            try:
                action_id, instruction = next(schedule_iter)
            except StopIteration:
                # Recreate iterator if we exhausted base_schedule
                schedule_iter = iter(base_schedule)
                action_id, instruction = next(schedule_iter)

            ego_idx = int(name2idx.get(action_id, name2idx.get("IDLE", 1)))

            # Build multi-agent actions
            acts = [ego_idx]
            for k in range(1, n_controlled):
                nm = side_agent_label(t)
                acts.append(int(name2idx.get(nm, name2idx.get("IDLE", 1))))

            # Step environment
            try:
                obs, reward, done, trunc, info = env.step(tuple(acts))
            except Exception:
                obs, reward, done, trunc, info = env.step(acts[0])

            # Render full RGB frame
            rgb = env.render()
            if rgb is None:
                rgb = env.render()
            if rgb is None:
                # fallback to grayscale obs -> to RGB-like for saving
                try:
                    g = gray_from_obs_or_render(env, obs)
                    # convert to uint8
                    arr = (g * 255.0).astype(np.uint8)
                    pil = Image.fromarray(arr).convert("RGB")
                except Exception:
                    # create a blank image
                    pil = Image.new("RGB", SAVE_SIZE, color=(0, 0, 0))
            else:
                pil = Image.fromarray(rgb).resize(SAVE_SIZE, Image.BILINEAR)

            fname = f"f_{frame_id:06d}.png"
            pil.save(img_dir / fname)

            # gather ego speed if possible
            try:
                ego_speed = float(getattr(env.unwrapped.controlled_vehicles[0], "speed", 0.0))
            except Exception:
                ego_speed = float(info.get("speed", 0.0))

            rows.append({
                "index": frame_id,
                "image": f"images/{fname}",
                "action_id": action_id,
                "instruction": instruction,
                "episode": ep,
                "t": t,
                "ego_speed": ego_speed,
                "n_controlled": n_controlled,
                "courtesy": bool(COURTESY_ON),
            })

            frame_id += 1
            t += 1
            pbar.update(1)

        if frame_id >= n_episodes * frames_per_episode:
            break

    pbar.close()
    env.close()

    # Write JSONL
    ensure_dir(Path(outdir))
    with open(Path(outdir) / "frames.jsonl", "w", encoding="utf-8") as f:
        for r in rows:
            f.write(json.dumps(r, ensure_ascii=False) + "\n")

    print(f"✅ Done. Saved {frame_id} frames to {outdir}")
    return rows


# Replace the notebook's collect_intersection reference with the episode variant
collect_intersection = collect_intersection_episode

# Note: to run, execute the existing run cell which calls collect_intersection()
# Example (in notebook):
# collect_intersection(outdir=OUTDIR, frames_total=FRAMES_TOTAL, n_controlled=N_CONTROLLED, seed=SEED)
