From 4f68aa23bf07dce99e26ba37017e62ba468e4a0d Mon Sep 17 00:00:00 2001 From: akshay18iitg Date: Wed, 25 Mar 2026 16:44:26 -0700 Subject: [PATCH 1/6] Adding robocasa server and client for batch inference --- src/opentau/scripts/robocasa/__init__.py | 13 + src/opentau/scripts/robocasa/client.py | 741 ++++++++++++++++++++++ src/opentau/scripts/robocasa/server.py | 764 +++++++++++++++++++++++ 3 files changed, 1518 insertions(+) create mode 100644 src/opentau/scripts/robocasa/__init__.py create mode 100644 src/opentau/scripts/robocasa/client.py create mode 100644 src/opentau/scripts/robocasa/server.py diff --git a/src/opentau/scripts/robocasa/__init__.py b/src/opentau/scripts/robocasa/__init__.py new file mode 100644 index 00000000..787f750f --- /dev/null +++ b/src/opentau/scripts/robocasa/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2026 Tensor Auto Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/src/opentau/scripts/robocasa/client.py b/src/opentau/scripts/robocasa/client.py new file mode 100644 index 00000000..fb92bc84 --- /dev/null +++ b/src/opentau/scripts/robocasa/client.py @@ -0,0 +1,741 @@ +# Copyright 2026 Tensor Auto Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Threaded batched remote policy client for RoboCasa. + +Runs **n_parallel** environment threads. Each thread pulls rollouts from a shared queue +until **num_rollouts** episodes are finished. The **main** asyncio loop receives +observations from active workers, batches them into one WebSocket message per timestep, +and routes returned actions back to the corresponding threads. + +Batch protocol (MessagePack over WebSocket, binary frames) matches ``client.py`` / +``robocasa.scripts.server``: + + Client -> server: { + "batch": true, + "items": [ + { "images": { camera_name: bytes (JPEG), ... }, "state": list[float], "prompt": str }, + ... + ], + } + + Server -> client: list[list[float]] # one flat action per item, same order as ``items`` + +The number of ``items`` (and thus the batch size) is **only** the count of workers +still stepping this timestep. As workers finish their rollout queue and exit, batch +size shrinks from at most ``num_parallel`` down to 1 for the final active worker(s). +The policy server must return exactly ``len(items)`` actions, not a fixed width of +``num_parallel``. + +Rollout records and ``rollouts.json`` match ``client.py`` (``env_name``, ``seed``, +``length``, ``success`` per rollout; summary includes ``num_rollouts``, +``num_parallel_envs``, ``output_directory``). + +Requires ``websockets``, ``msgpack``, ``opencv-python`` (``cv2``). The WebSocket client +sets ``ping_timeout=None`` so MuJoCo stepping and JPEG encoding do not trip keepalive. 
+""" + +from __future__ import annotations + +import argparse +import asyncio +import json +import os +import queue +import threading +import warnings +from dataclasses import dataclass +from datetime import datetime +from typing import Any, Union + +import cv2 +import imageio +import msgpack +import numpy as np +import websockets + +import robocasa # noqa: F401 +from robocasa.utils.env_utils import convert_action_pi05, create_env + +# Same three cameras as ``create_env`` defaults / PandaOmron gym wrapper. +DEFAULT_CAMERA_NAMES: tuple[str, ...] = ( + "robot0_eye_in_hand", + "robot0_agentview_left", + "robot0_agentview_right", +) +# Resolution aligned with ``robocasa.wrappers.gym_wrapper.PandaOmronKeyConverter``. +DEFAULT_CAMERA_WIDTH = 256 +DEFAULT_CAMERA_HEIGHT = 256 + +# Flat action layout expected by env (PandaOmron); see also env_utils.convert_action. +ACTION_ORDER = ( + "end_effector_position", # 3 + "end_effector_rotation", # 3 + "gripper_close", # 1 + "base_motion", # 4 + "control_mode", # 1 +) # total 12 + +# Proprio keys aligned with PandaOmronKeyConverter / typical RoboCasa datasets. +DEFAULT_PROPRIO_KEYS = ( + "robot0_base_pos", + "robot0_base_quat", + "robot0_base_to_eef_pos", + "robot0_base_to_eef_quat", + "robot0_gripper_qpos", +) + + +def get_task_prompt(env) -> str: + """ + Natural-language instruction for the current episode (RoboCasa ``get_ep_meta()['lang']``). + """ + meta = env.get_ep_meta() + if not meta: + return "" + lang = meta.get("lang", "") + if lang is None: + return "" + if isinstance(lang, (list, tuple)): + return " ".join(str(x) for x in lang) + return str(lang) + + +def build_proprio_vector(obs: dict, keys: tuple[str, ...] = DEFAULT_PROPRIO_KEYS) -> np.ndarray: + """Concatenate low-dimensional robot state for policy input.""" + parts = [] + for k in keys: + if k not in obs: + raise KeyError( + f"Observation missing key {k!r}. Available keys (sample): " + f"{[x for x in obs if not x.endswith('_image')][:20]}..." 
+ ) + parts.append(np.asarray(obs[k], dtype=np.float64).ravel()) + return np.concatenate(parts, axis=0) + + +def flip_image_obs(obs: dict, camera_names: tuple[str, ...]) -> dict: + """Flip images vertically since MuJoCo renders upside down.""" + for name in camera_names: + key = f"{name}_image" + if key in obs: + # Copy to ensure the array is contiguous for cv2/imageio + obs[key] = obs[key][::-1].copy() + return obs + + +def encode_camera_rgb_to_jpeg( + obs: dict, + camera_name: str, + jpeg_quality: int = 80, +) -> bytes: + """Encode one camera's RGB observation as JPEG bytes (OpenCV uses BGR).""" + key = f"{camera_name}_image" + if key not in obs: + raise KeyError(f"Missing {key!r}. Ensure create_env includes camera {camera_name!r}.") + rgb = obs[key] + if rgb.ndim != 3 or rgb.shape[-1] != 3: + raise ValueError(f"Expected HxWx3 RGB image for {key}, got shape {rgb.shape}") + bgr = cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR) + ok, buf = cv2.imencode(".jpg", bgr, [int(cv2.IMWRITE_JPEG_QUALITY), jpeg_quality]) + if not ok: + raise RuntimeError("cv2.imencode failed") + return buf.tobytes() + + +def encode_all_cameras_jpeg( + obs: dict, + camera_names: tuple[str, ...], + jpeg_quality: int = 80, +) -> dict: + """JPEG-encode every listed camera; keys match ``camera_names``.""" + return {name: encode_camera_rgb_to_jpeg(obs, name, jpeg_quality=jpeg_quality) for name in camera_names} + + +@dataclass +class ObsMsg: + """Worker needs a policy action for this observation (packed client payload).""" + + payload: dict[str, Any] + + +@dataclass +class DoneMsg: + """Episode finished; no server call for this message.""" + + rollout_idx: int + length: int + success: bool + + +@dataclass +class ExitMsg: + """Worker thread has no more rollouts and is exiting.""" + + +WorkerToMain = Union[ObsMsg, DoneMsg, ExitMsg] + +_SERVER_TRUNCATED_ACTION_BATCH_WARNED = False + + +def _normalize_batched_actions_response( + actions_batch: Any, + num_expected: int, +) -> list[Any]: + """Normalize the batched 
policy response to one action list per observation. + + Args: + actions_batch: Raw decoded batch from the server (typically a ``list``). + num_expected: Number of observations in this batch (length of ``items``). + + Returns: + A list of length ``num_expected``, one action (sequence of floats) per row. + + Raises: + ValueError: If the response is not a list, length cannot be reconciled with + ``num_expected``, or a partial batch cannot be interpreted. + + Note: + If ``num_expected == 1``, some servers return a single flat ``list[float]`` + instead of ``[list[float]]``; that case is wrapped. If the server returns + more rows than ``num_expected``, excess rows are dropped (with a one-time + warning). + """ + global _SERVER_TRUNCATED_ACTION_BATCH_WARNED + if not isinstance(actions_batch, list): + raise ValueError(f"Batched server response must be a list, got {type(actions_batch).__name__}") + if len(actions_batch) == num_expected: + return actions_batch + if len(actions_batch) > num_expected: + if not _SERVER_TRUNCATED_ACTION_BATCH_WARNED: + warnings.warn( + f"Policy server returned {len(actions_batch)} actions for a batch of " + f"{num_expected}; using the first {num_expected}. Prefer fixing the server " + f"to return exactly len(items) actions.", + UserWarning, + stacklevel=2, + ) + _SERVER_TRUNCATED_ACTION_BATCH_WARNED = True + return actions_batch[:num_expected] + # Single-env batch: server may send one flat list[float] instead of [list[float]]. 
+ if num_expected == 1 and len(actions_batch) > 0: + first = actions_batch[0] + if isinstance(first, (int, float, np.floating, np.integer)): + return [actions_batch] + raise ValueError( + f"Batched actions length {len(actions_batch)} != batch size {num_expected} " + f"(partial batches must still return one action list per observation)" + ) + + +def _worker_loop( + *, + rollout_queue: queue.Queue[int | None], + to_main: queue.Queue[WorkerToMain], + from_main: queue.Queue[np.ndarray], + env_name: str, + split, + start_seed: int, + main_dir: str, + jpeg_quality: int, + max_episode_steps: int | None, + render: bool, + action_dim_holder: list[int | None], + action_dim_lock: threading.Lock, +) -> None: + """Run one worker thread: consume rollout indices and step the env until done. + + Pulls rollout IDs from ``rollout_queue``, builds observations (JPEG + state + + prompt), sends ``ObsMsg`` to the coordinator, blocks on ``from_main`` for the + action, steps the environment, and sends ``DoneMsg`` when the episode ends. + Puts ``ExitMsg`` when the queue is empty and the thread exits. + + Args: + rollout_queue: Queue of rollout indices; empty queue means this worker exits. + to_main: Queue to the asyncio coordinator (``ObsMsg``, ``DoneMsg``, ``ExitMsg``). + from_main: Queue from coordinator delivering one flat action vector per step. + env_name: Registered RoboCasa environment name. + split: Dataset split passed to ``create_env``. + start_seed: Base seed; rollout ``i`` uses ``start_seed + i``. + main_dir: Root directory for ``rollouts.json`` and per-rollout video folders. + jpeg_quality: JPEG quality for encoded camera frames (when not rendering). + max_episode_steps: Optional step cap per episode (in addition to success). + render: If True, onscreen render and no video files; else offscreen + videos. + action_dim_holder: Single-element list shared across workers for ``env.action_dim``. + action_dim_lock: Lock protecting ``action_dim_holder`` initialization. 
+ """ + while True: + try: + # get the next rollout index from the queue and its protected by a lock + rollout_idx = rollout_queue.get_nowait() + except queue.Empty: + to_main.put(ExitMsg()) + return + + seed = start_seed + rollout_idx + if not render: + # create the video subdirectory for the rollout + sub = os.path.join(main_dir, f"rollout_{rollout_idx:04d}_seed_{seed}") + os.makedirs(sub, exist_ok=True) + video_writers: dict[str, Any] | None = {} + for cam in DEFAULT_CAMERA_NAMES: + path = os.path.join(sub, f"{cam}.mp4") + video_writers[cam] = imageio.get_writer(path, fps=20) + else: + video_writers = None + + # create the environment + env = create_env( + env_name, + split=split, + seed=seed, + render_onscreen=render, + camera_names=list(DEFAULT_CAMERA_NAMES), + camera_widths=DEFAULT_CAMERA_WIDTH, + camera_heights=DEFAULT_CAMERA_HEIGHT, + has_offscreen_renderer=not render, + use_camera_obs=not render, + ) + try: + # reset the environment and get the initial observation and action dimension + obs = env.reset() + # flip the image observations as mujoco returns flipped images + obs = flip_image_obs(obs, DEFAULT_CAMERA_NAMES) + # get the action dimension and store it in the action dimension holder + with action_dim_lock: + if action_dim_holder[0] is None: + ad = env.action_dim + if ad is None: + raise RuntimeError("env.action_dim is None after reset()") + action_dim_holder[0] = ad + step_count = 0 + + while True: + if render: + images: dict[str, Any] = {} + else: + # encode the image observations as JPEG + images = encode_all_cameras_jpeg(obs, DEFAULT_CAMERA_NAMES, jpeg_quality=jpeg_quality) + # write the image observations to the video writers + if video_writers is not None: + for cam in DEFAULT_CAMERA_NAMES: + cam_key = f"{cam}_image" + if cam_key in obs: + video_writers[cam].append_data(obs[cam_key]) + + # build the state vector in desried order and get the task prompt + state = build_proprio_vector(obs).tolist() + prompt = get_task_prompt(env) + payload_obs = 
{"images": images, "state": state, "prompt": prompt} + + # send the payload to the main thread + to_main.put(ObsMsg(payload=payload_obs)) + # get the action from the main thread + action = from_main.get() + # convert the action to a numpy array and convert the action to the desired range + action = np.asarray(action, dtype=np.float64).ravel() + # build action vector in desired order + action = convert_action_pi05(action) + + # check if the action dimension is correct + ad = action_dim_holder[0] + assert ad is not None + if action.shape[0] != ad: + raise ValueError(f"Policy returned action dim {action.shape[0]}, expected {ad}") + + # step the environment and get the new observation + obs, _r, _d, _i = env.step(action) + # flip the image observations as mujoco returns flipped images + obs = flip_image_obs(obs, DEFAULT_CAMERA_NAMES) + step_count += 1 + + # check if the episode is over + episode_over = bool(env._check_success()) or ( + max_episode_steps is not None and step_count >= max_episode_steps + ) + if episode_over: + success = bool(env._check_success()) + to_main.put( + DoneMsg( + rollout_idx=rollout_idx, + length=step_count, + success=success, + ) + ) + break + finally: + if video_writers is not None: + for w in video_writers.values(): + w.close() + env.close() + + +async def _run_coordinator( + *, + ws_uri: str, + n_workers: int, + to_mains: list[queue.Queue[WorkerToMain]], + from_mains: list[queue.Queue[np.ndarray]], + results_by_rollout: dict[int, tuple[int, bool]], + results_lock: threading.Lock, +) -> None: + """Batch observations from all active workers each timestep and drive the WebSocket. + + For each timestep, concurrently reads from each active worker until each has + produced one ``ObsMsg`` (``DoneMsg`` is consumed and recorded without blocking + others) or ``ExitMsg``. Builds one MessagePack batch ``{batch: true, items: ...}``, + sends it to the policy server, and distributes returned actions back to workers + via ``from_mains``. 
This ordering avoids deadlock when one worker is between + episodes while others already have the next observation. + + Args: + ws_uri: WebSocket URI (e.g. ``ws://host:port``). + n_workers: Number of parallel worker threads. + to_mains: Per-worker queues from workers to this coordinator. + from_mains: Per-worker queues from coordinator to workers (actions). + results_by_rollout: Mutable map ``rollout_idx -> (length, success)`` for ``DoneMsg``. + results_lock: Lock protecting ``results_by_rollout``. + """ + loop = asyncio.get_event_loop() + + def _get(q: queue.Queue[WorkerToMain]) -> WorkerToMain: + """Block until ``q`` delivers the next worker-to-main message.""" + + return q.get() + + async def _drain_to_obs_or_exit(wid: int) -> tuple[int, ObsMsg | None, bool]: + """Drain a worker queue until the next ``ObsMsg`` or thread exit. + + Skips ``DoneMsg`` (records results) until an observation or ``ExitMsg``. + + Args: + wid: Worker index (0 .. ``n_workers`` - 1). + + Returns: + Tuple of ``(worker_id, observation_message_or_none, is_exit)``. If + ``is_exit`` is True, the worker has finished; ``observation_message`` + is None. Otherwise ``observation_message`` is the ``ObsMsg`` to batch. 
+ """ + while True: + msg = await loop.run_in_executor(None, _get, to_mains[wid]) + if isinstance(msg, ExitMsg): + return (wid, None, True) + if isinstance(msg, DoneMsg): + with results_lock: + results_by_rollout[msg.rollout_idx] = (msg.length, msg.success) + continue + if isinstance(msg, ObsMsg): + return (wid, msg, False) + raise TypeError(f"Unexpected message: {type(msg)}") + + async with websockets.connect( + ws_uri, + max_size=None, + ping_timeout=None, + ) as websocket: + active: set[int] = set(range(n_workers)) + + while active: + wids = sorted(active) + # gather the observations from the active workers + gathered = await asyncio.gather(*[_drain_to_obs_or_exit(wid) for wid in wids]) + + for wid, _obs, is_exit in gathered: + if is_exit: + # remove the finished worker from the active set + active.discard(wid) + + batch_pairs: list[tuple[int, ObsMsg]] = [ + (wid, om) for (wid, om, ex) in gathered if not ex and om is not None + ] + batch_pairs.sort(key=lambda x: x[0]) + + if not batch_pairs: + if not active: + break + raise RuntimeError("internal error: no observations to batch but workers are still active") + + batch_items = [om.payload for _wid, om in batch_pairs] + batch_workers = [wid for wid, _om in batch_pairs] + batch_size = len(batch_items) + # batch_size is often < n_workers as workers finish rollouts and exit. 
+ + batch_payload = {"batch": True, "items": batch_items} + await websocket.send(msgpack.packb(batch_payload, use_bin_type=True)) + raw = await websocket.recv() + actions_batch = msgpack.unpackb(raw, raw=False) + actions_batch = _normalize_batched_actions_response(actions_batch, batch_size) + + for wid, act in zip(batch_workers, actions_batch, strict=False): + a = np.asarray(act, dtype=np.float64).ravel() + from_mains[wid].put(a) + + +async def run_policy_loop_threaded( + *, + ws_uri: str, + env_name: str, + split, + start_seed: int, + num_rollouts: int, + num_parallel: int, + output_dir: str | None, + jpeg_quality: int, + max_episode_steps: int | None, + render: bool = False, +) -> None: + """Run threaded RoboCasa rollouts against a batched policy WebSocket server. + + Spawns up to ``min(num_parallel, num_rollouts)`` worker threads, coordinates + batched policy calls in ``_run_coordinator``, then writes ``rollouts.json`` + under the output directory. + + Args: + ws_uri: WebSocket URI of the policy server. + env_name: RoboCasa task name for ``create_env``. + split: Dataset split for ``create_env``. + start_seed: Seed for rollout index 0; rollout ``i`` uses ``start_seed + i``. + num_rollouts: Total number of episodes to run. + num_parallel: Maximum parallel env threads (capped by ``num_rollouts``). + output_dir: Output root; default is ``{env_name}_async_{timestamp}``. + jpeg_quality: JPEG quality for camera encodes when not rendering. + max_episode_steps: Max steps per episode (in addition to env termination). + render: If True, onscreen rendering and no saved videos. + + Raises: + ValueError: If ``num_rollouts`` or ``num_parallel`` is invalid. + RuntimeError: If a worker thread does not exit or a rollout result is missing. 
+ """ + if num_rollouts < 1: + raise ValueError("num_rollouts must be >= 1") + if num_parallel < 1: + raise ValueError("num_parallel must be >= 1") + + # number of threads to be created + n_workers = min(num_parallel, num_rollouts) + + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + main_dir = output_dir or f"{env_name}_async_{timestamp}" + os.makedirs(main_dir, exist_ok=True) + + print( + f"Output directory: {main_dir!r} — {num_rollouts} rollout(s), " + f"{n_workers} parallel worker thread(s), seeds {start_seed}..{start_seed + num_rollouts - 1}" + ) + + # queue to store the rollout indices + rollout_queue: queue.Queue[int | None] = queue.Queue() + for i in range(num_rollouts): + rollout_queue.put(i) + + # queues to send messages from the coordinator to the workers and from the workers to the coordinator + to_mains: list[queue.Queue[WorkerToMain]] = [queue.Queue() for _ in range(n_workers)] + from_mains: list[queue.Queue[np.ndarray]] = [queue.Queue() for _ in range(n_workers)] + + # dictionary to store the results by rollout index + results_by_rollout: dict[int, tuple[int, bool]] = {} + # lock to synchronize access to the results dictionary + results_lock = threading.Lock() + # holder for the action dimension + action_dim_holder: list[int | None] = [None] + # lock to synchronize access to the action dimension + action_dim_lock = threading.Lock() + # list to store the threads + threads: list[threading.Thread] = [] + for wid in range(n_workers): + t = threading.Thread( + target=_worker_loop, + kwargs={ + "rollout_queue": rollout_queue, + "to_main": to_mains[wid], + "from_main": from_mains[wid], + "env_name": env_name, + "split": split, + "start_seed": start_seed, + "main_dir": main_dir, + "jpeg_quality": jpeg_quality, + "max_episode_steps": max_episode_steps, + "render": render, + "action_dim_holder": action_dim_holder, + "action_dim_lock": action_dim_lock, + }, + name=f"robocasa-env-{wid}", + daemon=True, + ) + threads.append(t) + t.start() + + await 
_run_coordinator( + ws_uri=ws_uri, + n_workers=n_workers, + to_mains=to_mains, + from_mains=from_mains, + results_by_rollout=results_by_rollout, + results_lock=results_lock, + ) + + for t in threads: + t.join(timeout=600.0) + if t.is_alive(): + raise RuntimeError(f"Worker thread {t.name!r} did not exit in time") + + ad = action_dim_holder[0] + if ad is not None: + print( + f"RoboCasa env={env_name!r} split={split!r} action_dim={ad} " + f"cameras={list(DEFAULT_CAMERA_NAMES)} " + f"({DEFAULT_CAMERA_WIDTH}x{DEFAULT_CAMERA_HEIGHT})" + ) + + rollout_records: list[dict[str, Any]] = [] + for ridx in range(num_rollouts): + if ridx not in results_by_rollout: + raise RuntimeError(f"Missing result for rollout index {ridx}") + length, success = results_by_rollout[ridx] + seed = start_seed + ridx + rollout_records.append( + { + "env_name": env_name, + "seed": seed, + "length": length, + "success": success, + } + ) + print(f"Rollout {ridx + 1}/{num_rollouts} seed={seed} length={length} success={success}") + + summary_path = os.path.join(main_dir, "rollouts.json") + summary = { + "env_name": env_name, + "start_seed": start_seed, + "num_rollouts": num_rollouts, + "num_parallel_envs": n_workers, + "output_directory": os.path.abspath(main_dir), + "rollouts": rollout_records, + } + with open(summary_path, "w", encoding="utf-8") as f: + json.dump(summary, f, indent=2) + print(f"Wrote {summary_path!r}") + + +def parse_args(argv: list[str] | None = None) -> argparse.Namespace: + """Parse command-line arguments for the threaded async RoboCasa client. + + Args: + argv: Argument list; defaults to ``sys.argv`` when ``None``. + + Returns: + Parsed namespace with ``env_name``, ``host``, ``port``, rollout options, etc. 
+ """ + p = argparse.ArgumentParser(description=__doc__) + p.add_argument( + "env_name", + metavar="ENV_NAME", + help="RoboCasa kitchen task (registered class name), same as client.py", + ) + p.add_argument( + "--host", + default=os.environ.get("ROBOCASA_POLICY_HOST", "localhost"), + help=( + "Policy server hostname or IP (default: localhost). " + "Use a real host — not the literal word HOST from examples." + ), + ) + p.add_argument( + "--port", + type=int, + default=int(os.environ.get("ROBOCASA_POLICY_PORT", "8765")), + help="Policy server port (or set ROBOCASA_POLICY_PORT)", + ) + p.add_argument( + "--split", + default="all", + choices=[None, "all", "pretrain", "target"], + help="Dataset split passed to create_env (default: all)", + ) + p.add_argument( + "--seed", + type=int, + default=0, + help="Seed for rollout index 0; rollout i uses seed + i", + ) + p.add_argument( + "--num-rollouts", + type=int, + default=1, + help="Total number of episodes (rollouts) to run", + ) + p.add_argument( + "--num-parallel", + type=int, + default=1, + help="Number of parallel environment threads (capped at num-rollouts)", + ) + p.add_argument( + "--output-dir", + type=str, + default=None, + help="Directory for rollouts.json and per-rollout video subfolders", + ) + p.add_argument("--jpeg-quality", type=int, default=80, help="JPEG quality 0-100") + p.add_argument( + "--max-episode-steps", + type=int, + default=1500, + help="Cap steps per episode (in addition to env success)", + ) + p.add_argument("--render", action="store_true", help="Render onscreen (no videos)") + return p.parse_args(argv) + + +def main(argv=None) -> None: + """CLI entrypoint: parse args and run ``run_policy_loop_threaded``. + + Args: + argv: Optional argument list; forwarded to ``parse_args``. + + Raises: + SystemExit: On invalid ``--num-rollouts``, ``--num-parallel``, or placeholder + ``--host`` value. 
+ """ + args = parse_args(argv) + if args.num_rollouts < 1: + raise SystemExit("error: --num-rollouts must be >= 1") + if args.num_parallel < 1: + raise SystemExit("error: --num-parallel must be >= 1") + host = args.host.strip() + if host.lower() == "host": + raise SystemExit( + "error: --host must be a real hostname or IP (e.g. localhost or 127.0.0.1), " + "not the placeholder HOST." + ) + uri = f"ws://{host}:{args.port}" + asyncio.run( + run_policy_loop_threaded( + ws_uri=uri, + env_name=args.env_name, + split=args.split, + start_seed=args.seed, + num_rollouts=args.num_rollouts, + num_parallel=args.num_parallel, + output_dir=args.output_dir, + jpeg_quality=args.jpeg_quality, + max_episode_steps=args.max_episode_steps, + render=args.render, + ) + ) + + +if __name__ == "__main__": + main() diff --git a/src/opentau/scripts/robocasa/server.py b/src/opentau/scripts/robocasa/server.py new file mode 100644 index 00000000..39aefbb3 --- /dev/null +++ b/src/opentau/scripts/robocasa/server.py @@ -0,0 +1,764 @@ +# Copyright 2026 Tensor Auto Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Policy WebSocket server for RoboCasa ``client.py`` / ``client_async.py`` with OpenTau loading. + +Implements the same wire protocol as ``robocasa_server``, plus **batched** messages for +``robocasa.scripts.client_async``: + + Single (``client.py``): + + Client -> server (MessagePack): ``{ "images": {...}, "state": [...], "prompt": "..." 
}`` + Server -> client: ``list[float]`` + + Batched (``client_async.py``): + + Client -> server: + ``{ "batch": true, "items": [ { "images": {...}, "state": [...], "prompt": "..." }, ... ] }`` + Server -> client: + ``[ list[float], ... ]`` # one action per item, same order and length as ``items`` + +**OpenTau mode** (default): loads the policy from ``policy.pretrained_path`` in the config. +For single requests, each step calls ``policy.select_action`` (internal action queue). +For **batched** requests, observations are stacked and ``select_action`` runs **once** on the +full batch (same as vector-env rollouts). + +**Stub mode** (``--robocasa_use_stub=true``): small random actions. + +Run:: + + python -m opentau.scripts.robocasa_server_async \\ + --config_path /path/to/train_config.json \\ + --robocasa_action_dim 16 --robocasa_port 8765 + +Dependencies: ``websockets``, ``msgpack``, ``opencv-python`` (optional but recommended for JPEG decode). +""" + +import argparse +import asyncio +import logging +import sys +from dataclasses import asdict +from pprint import pformat +from typing import Any, Callable, Dict, List, Optional, Tuple + +import msgpack +import numpy as np +import torch +from PIL import Image + +try: + import cv2 +except ImportError: + cv2 = None # type: ignore + +import websockets + +from opentau.configs import parser +from opentau.configs.train import TrainPipelineConfig +from opentau.policies.factory import get_policy_class +from opentau.utils.random_utils import set_seed +from opentau.utils.utils import attempt_torch_compile, auto_torch_device, init_logging + +logger = logging.getLogger(__name__) + +# WebSocket server options (defaults; override with --robocasa_* before ``TrainPipelineConfig`` parse). 
ROBOCASA_HOST: str = "0.0.0.0"  # nosec B104 — default listen; use ``--robocasa_host`` to restrict
ROBOCASA_PORT: int = 8765
ROBOCASA_ACTION_DIM: int = 16
ROBOCASA_TORCH_COMPILE: bool = True
ROBOCASA_USE_STUB: bool = False


def _parse_robocasa_cli() -> None:
    """Parse ``--robocasa_*`` flags into module globals and strip them from ``sys.argv``.

    Must run before ``robocasa_async_main`` so ``argparse`` in the OpenTau config
    parser does not see RoboCasa-specific flags.

    Boolean flags (``--robocasa_torch_compile``, ``--robocasa_use_stub``) accept
    ``true/1/yes/y`` or ``false/0/no/n`` (case-insensitive). Any other value is
    rejected with a usage error instead of being silently treated as ``False``.

    Side effects:
        Updates ``ROBOCASA_HOST``, ``ROBOCASA_PORT``, ``ROBOCASA_ACTION_DIM``,
        ``ROBOCASA_TORCH_COMPILE``, ``ROBOCASA_USE_STUB`` for every flag that was
        supplied, and replaces ``sys.argv`` with a copy containing only
        non-RoboCasa arguments (original order preserved).

    Raises:
        SystemExit: Raised by ``argparse`` when a flag value is invalid (for
            example a non-integer port or an unrecognised boolean string).
    """
    global ROBOCASA_HOST, ROBOCASA_PORT, ROBOCASA_ACTION_DIM, ROBOCASA_TORCH_COMPILE, ROBOCASA_USE_STUB

    truthy = ("true", "1", "yes", "y")
    falsy = ("false", "0", "no", "n")

    def _bool_arg(value: str) -> bool:
        """Convert a CLI string to ``bool``; reject anything that is not a known token."""
        token = value.lower()
        if token in truthy:
            return True
        if token in falsy:
            return False
        # argparse turns this into a usage error that names the offending flag.
        raise argparse.ArgumentTypeError(
            f"expected a boolean (true/1/yes/y or false/0/no/n), got {value!r}"
        )

    # allow_abbrev=False: parse_known_args applies prefix matching by default, which
    # could swallow (or reject as ambiguous) a flag meant for the downstream parser.
    p = argparse.ArgumentParser(add_help=False, allow_abbrev=False)
    p.add_argument("--robocasa_host", type=str, default=None)
    p.add_argument("--robocasa_port", type=int, default=None)
    p.add_argument("--robocasa_action_dim", type=int, default=None)
    p.add_argument("--robocasa_torch_compile", type=_bool_arg, default=None)
    p.add_argument("--robocasa_use_stub", type=_bool_arg, default=None)
    args, rest = p.parse_known_args(sys.argv[1:])

    # ``None`` means "flag not given": keep the module default in that case.
    if args.robocasa_host is not None:
        ROBOCASA_HOST = args.robocasa_host
    if args.robocasa_port is not None:
        ROBOCASA_PORT = args.robocasa_port
    if args.robocasa_action_dim is not None:
        ROBOCASA_ACTION_DIM = args.robocasa_action_dim
    if args.robocasa_torch_compile is not None:
        ROBOCASA_TORCH_COMPILE = args.robocasa_torch_compile
    if args.robocasa_use_stub is not None:
        ROBOCASA_USE_STUB = args.robocasa_use_stub

    # Hand the remaining arguments to whatever parser runs next.
    sys.argv = [sys.argv[0], *rest]


# Camera keys must match ``client.DEFAULT_CAMERA_NAMES``; order maps to camera0, camera1, ...
+ROBOCASA_CAMERA_ORDER = ( + "robot0_eye_in_hand", + "robot0_agentview_left", + "robot0_agentview_right", +) + + +def jpeg_bytes_to_rgb(jpeg_bytes: bytes) -> np.ndarray: + """Decode a JPEG bytestring to an RGB image array. + + Args: + jpeg_bytes: Raw JPEG file bytes. + + Returns: + ``uint8`` array of shape ``(H, W, 3)`` in RGB order. + + Raises: + RuntimeError: If OpenCV (``cv2``) is not installed. + ValueError: If ``cv2.imdecode`` fails (invalid JPEG). + """ + if cv2 is None: + raise RuntimeError("opencv-python (cv2) is required to decode JPEG images on the server.") + arr = np.frombuffer(jpeg_bytes, dtype=np.uint8) + bgr = cv2.imdecode(arr, cv2.IMREAD_COLOR) + if bgr is None: + raise ValueError("cv2.imdecode failed (invalid JPEG?)") + return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB) + + +def decode_all_images(images: Dict[str, bytes]) -> Dict[str, np.ndarray]: + """Decode all camera JPEG blobs for ``ROBOCASA_CAMERA_ORDER``. + + Args: + images: Map from camera name to JPEG bytes (must include every key in + ``ROBOCASA_CAMERA_ORDER``). + + Returns: + Map from camera name to ``uint8`` RGB arrays ``(H, W, 3)``. + + Raises: + KeyError: If a required camera key is missing. + """ + out: Dict[str, np.ndarray] = {} + for k in ROBOCASA_CAMERA_ORDER: + if k not in images: + raise KeyError(f"Missing image key {k!r}; got {list(images.keys())}") + out[k] = jpeg_bytes_to_rgb(images[k]) + return out + + +def unpack_payload_dict(data: dict) -> Tuple[Dict[str, bytes], np.ndarray, str]: + """Parse one policy request body (single message or one batch item). + + Args: + data: Decoded MessagePack dict with ``images``, ``state``, and optional + ``prompt``. + + Returns: + Tuple of ``(images_dict, state_vector, prompt_string)``. Image values are + normalized to ``bytes``. + + Raises: + ValueError: If ``images`` is not a dict or ``state`` is not a list. 
+ """ + images = data.get("images") + state = data.get("state") + prompt = data.get("prompt", "") + if not isinstance(images, dict): + raise ValueError("Expected 'images' dict") + if not isinstance(state, list): + raise ValueError("Expected 'state' list") + images = {str(k): (v if isinstance(v, (bytes, bytearray)) else bytes(v)) for k, v in images.items()} + state_vec = np.asarray(state, dtype=np.float64) + prompt_str = str(prompt) if prompt is not None else "" + return images, state_vec, prompt_str + + +def unpack_request(message: bytes) -> Tuple[Dict[str, bytes], np.ndarray, str]: + """Decode a single (non-batched) MessagePack WebSocket frame into observation fields. + + Args: + message: Raw binary MessagePack payload. + + Returns: + Same as ``unpack_payload_dict``. + + Raises: + ValueError: If the top-level value is not a dict. + """ + data = msgpack.unpackb(message, raw=False) + if not isinstance(data, dict): + raise ValueError("Expected dict payload") + return unpack_payload_dict(data) + + +PolicyFn = Callable[ + [Dict[str, np.ndarray], np.ndarray, str, int, np.random.Generator], + np.ndarray, +] + + +def default_policy( + images_rgb: Dict[str, np.ndarray], + state: np.ndarray, + prompt: str, + action_dim: int, + rng: np.random.Generator, +) -> np.ndarray: + """Stub policy: uniform random actions in ``[-0.05, 0.05]``. + + Args: + images_rgb: Per-camera ``uint8`` RGB arrays (unused in stub). + state: Proprioceptive state vector (unused in stub). + prompt: Task text (unused in stub). + action_dim: Flat action size. + rng: NumPy random generator. + + Returns: + ``float64`` array of shape ``(action_dim,)``. 
+ """ + del images_rgb, prompt # unused in stub + _ = state # available for your model + return rng.uniform(-0.05, 0.05, size=(action_dim,)).astype(np.float64) + + +def policy_forward( + images_rgb: Dict[str, np.ndarray], + state: np.ndarray, + prompt: str, + action_dim: int, + rng: np.random.Generator, +) -> np.ndarray: + """Default policy entrypoint; replace implementation while keeping the signature. + + Currently delegates to ``default_policy``. + + Args: + images_rgb: Per-camera ``uint8`` RGB arrays. + state: Proprioceptive state vector. + prompt: Task text. + action_dim: Flat action size. + rng: NumPy random generator. + + Returns: + Flat action vector of shape ``(action_dim,)``. + """ + return default_policy(images_rgb, state, prompt, action_dim, rng) + + +def pack_action(action: np.ndarray) -> bytes: + """Serialize a single flat action as MessagePack bytes for the WebSocket reply. + + Args: + action: 1D action vector (any shape that ravel-s to the policy dimension). + + Returns: + MessagePack-encoded bytes (list of floats). + """ + a = np.asarray(action, dtype=np.float64).ravel() + return msgpack.packb(a.tolist(), use_bin_type=True) + + +def pack_actions_batch(actions: List[np.ndarray]) -> bytes: + """Serialize a list of flat actions for a batched WebSocket reply. + + Args: + actions: One numpy action per batch row, same order as request ``items``. + + Returns: + MessagePack-encoded ``list[list[float]]`` bytes. + """ + return msgpack.packb( + [np.asarray(a, dtype=np.float64).ravel().tolist() for a in actions], + use_bin_type=True, + ) + + +def _numpy_rgb_to_camera_tensor( + rgb_uint8: np.ndarray, + resolution: tuple[int, int], + device: torch.device, + dtype: torch.dtype, +) -> torch.Tensor: + """Resize RGB ``uint8`` image to policy resolution and produce a CHW batch slice. + + Args: + rgb_uint8: Image array ``(H, W, 3)`` RGB. + resolution: Target ``(height, width)`` as in config (H, W). + device: Torch device for the output tensor. 
+ dtype: Floating dtype for normalized pixels. + + Returns: + Tensor of shape ``(1, 3, H, W)`` with values in ``[0, 1]``. + """ + pil = Image.fromarray(rgb_uint8) + pil = pil.resize((resolution[1], resolution[0]), Image.Resampling.BILINEAR) + arr = np.asarray(pil, dtype=np.float32) / 255.0 + t = torch.from_numpy(arr).permute(2, 0, 1).unsqueeze(0) + return t.to(device=device, dtype=dtype) + + +def build_opentau_batch( + cfg: TrainPipelineConfig, + images_rgb: Dict[str, np.ndarray], + state_vec: np.ndarray, + prompt: str, + device: torch.device, + dtype: torch.dtype, +) -> dict[str, torch.Tensor]: + """Build a single-row OpenTau policy batch from one RoboCasa observation. + + Args: + cfg: Training pipeline config (camera count, resolution, state dim, etc.). + images_rgb: Decoded RGB arrays keyed by ``ROBOCASA_CAMERA_ORDER`` names. + state_vec: Proprio vector; padded or truncated to ``cfg.max_state_dim``. + prompt: Task string. + device: Torch device. + dtype: Floating dtype for tensors. + + Returns: + Dict of tensors including ``camera*``, ``state``, ``prompt``, ``img_is_pad``. 
+ """ + num_cams = cfg.num_cams + resolution = cfg.resolution + batch: dict[str, torch.Tensor] = {} + img_is_pad: list[bool] = [] + + for cam_idx in range(num_cams): + if cam_idx < len(ROBOCASA_CAMERA_ORDER): + key = ROBOCASA_CAMERA_ORDER[cam_idx] + rgb = images_rgb[key] + batch[f"camera{cam_idx}"] = _numpy_rgb_to_camera_tensor(rgb, resolution, device, dtype) + img_is_pad.append(False) + else: + batch[f"camera{cam_idx}"] = torch.zeros((1, 3, *resolution), dtype=dtype, device=device) + img_is_pad.append(True) + + state_list = state_vec.astype(np.float64).ravel().tolist() + if len(state_list) < cfg.max_state_dim: + state_list.extend([0.0] * (cfg.max_state_dim - len(state_list))) + state_list = state_list[: cfg.max_state_dim] + batch["state"] = torch.tensor([state_list], dtype=dtype, device=device) + raw_prompt = prompt.strip() if prompt else "" + batch["prompt"] = [str(raw_prompt) or ""] + batch["img_is_pad"] = torch.tensor([img_is_pad], dtype=torch.bool, device=device) + return batch + + +def build_opentau_batch_multi( + cfg: TrainPipelineConfig, + items: List[Tuple[Dict[str, np.ndarray], np.ndarray, str]], + device: torch.device, + dtype: torch.dtype, +) -> dict[str, torch.Tensor]: + """Stack multiple decoded observations into one OpenTau batch of batch size ``B``. + + Args: + cfg: Training pipeline config. + items: List of ``(images_rgb, state_vec, prompt)`` per environment row. + device: Torch device. + dtype: Floating dtype for tensors. + + Returns: + Batched tensor dict for ``select_action`` / ``sample_actions``. + + Raises: + ValueError: If ``items`` is empty. 
+ """ + b = len(items) + if b == 0: + raise ValueError("empty batch") + num_cams = cfg.num_cams + resolution = cfg.resolution + batch: dict[str, torch.Tensor] = {} + img_is_pad_rows: list[list[bool]] = [] + + for cam_idx in range(num_cams): + if cam_idx < len(ROBOCASA_CAMERA_ORDER): + key = ROBOCASA_CAMERA_ORDER[cam_idx] + cam_tensors: list[torch.Tensor] = [] + for images_rgb, _state_vec, _prompt in items: + rgb = images_rgb[key] + t = _numpy_rgb_to_camera_tensor(rgb, resolution, device, dtype) + cam_tensors.append(t.squeeze(0)) + batch[f"camera{cam_idx}"] = torch.stack(cam_tensors, dim=0) + img_is_pad_rows.append([False] * b) + else: + batch[f"camera{cam_idx}"] = torch.zeros((b, 3, *resolution), dtype=dtype, device=device) + img_is_pad_rows.append([True] * b) + + img_is_pad_arr = np.array(img_is_pad_rows, dtype=bool).T + batch["img_is_pad"] = torch.tensor(img_is_pad_arr, dtype=torch.bool, device=device) + + state_rows: list[list[float]] = [] + prompts: list[str] = [] + for _images_rgb, state_vec, prompt in items: + state_list = state_vec.astype(np.float64).ravel().tolist() + if len(state_list) < cfg.max_state_dim: + state_list.extend([0.0] * (cfg.max_state_dim - len(state_list))) + state_list = state_list[: cfg.max_state_dim] + state_rows.append(state_list) + raw_prompt = prompt.strip() if prompt else "" + prompts.append(str(raw_prompt) or "") + + batch["state"] = torch.tensor(state_rows, dtype=dtype, device=device) + batch["prompt"] = prompts + return batch + + +class OpenTauRoboCasaPolicy: + """Loads an OpenTau policy from ``TrainPipelineConfig`` and runs inference.""" + + def __init__( + self, + cfg: TrainPipelineConfig, + *, + compile_model: bool = True, + seed: int | None = None, + ) -> None: + """Load the policy from ``cfg.policy.pretrained_path`` and warm up inference. + + Args: + cfg: Full training pipeline config (policy type, resolution, state dim, etc.). + compile_model: If True, apply ``torch.compile`` to ``sample_actions`` when + supported. 
+ seed: Optional RNG seed applied before construction via ``set_seed``. + + Side effects: + Loads weights, moves the model to ``auto_torch_device()`` in bfloat16, + runs two dummy ``sample_actions`` calls for warmup, then ``reset()`` again. + """ + self.cfg = cfg + self.device = auto_torch_device() + self.dtype = torch.bfloat16 + if seed is not None: + set_seed(seed) + + logger.info("Loading OpenTau policy type=%s from %s", cfg.policy.type, cfg.policy.pretrained_path) + policy_class = get_policy_class(cfg.policy.type) + self.policy = policy_class.from_pretrained(cfg.policy.pretrained_path, config=cfg.policy) + self.policy.to(device=self.device, dtype=self.dtype) + self.policy.eval() + + if compile_model: + self.policy.model.sample_actions = attempt_torch_compile( + self.policy.model.sample_actions, device_hint=self.device + ) + + self.policy.reset() + + dummy = build_opentau_batch( + cfg, + { + ROBOCASA_CAMERA_ORDER[0]: np.zeros((*cfg.resolution, 3), dtype=np.uint8), + ROBOCASA_CAMERA_ORDER[1]: np.zeros((*cfg.resolution, 3), dtype=np.uint8), + ROBOCASA_CAMERA_ORDER[2]: np.zeros((*cfg.resolution, 3), dtype=np.uint8), + }, + np.zeros(cfg.max_state_dim, dtype=np.float64), + "warmup", + self.device, + self.dtype, + ) + with torch.inference_mode(): + _ = self.policy.sample_actions(dummy) + _ = self.policy.sample_actions(dummy) + self.policy.reset() + logger.info("OpenTau policy ready on %s", self.device) + + def infer( + self, + images_rgb: Dict[str, np.ndarray], + state_vec: np.ndarray, + prompt: str, + action_dim: int, + ) -> np.ndarray: + """Run a single-environment forward pass and return a fixed-length action. + + Args: + images_rgb: Decoded RGB images per camera. + state_vec: Proprio state; padded to ``cfg.max_state_dim`` in the batch. + prompt: Task string. + action_dim: Desired output length (RoboCasa / CLI); policy output is + truncated or zero-padded to this size. + + Returns: + ``float64`` vector of shape ``(action_dim,)``. 
+ """ + batch = build_opentau_batch(self.cfg, images_rgb, state_vec, prompt, self.device, self.dtype) + with torch.inference_mode(): + act = self.policy.select_action(batch) + step0 = act.squeeze(0).to("cpu", torch.float32).numpy().ravel() + policy_adim = step0.shape[0] + out = np.zeros(action_dim, dtype=np.float64) + n = min(action_dim, policy_adim) + out[:n] = step0[:n].astype(np.float64) + return out + + def infer_batch( + self, + decoded_items: List[Tuple[Dict[str, np.ndarray], np.ndarray, str]], + action_dim: int, + ) -> List[np.ndarray]: + """Run one batched ``select_action`` for multiple observations. + + Args: + decoded_items: One triple ``(images_rgb, state_vec, prompt)`` per batch row. + action_dim: Target flat size per row (truncate or pad each output). + + Returns: + List of ``action_dim``-length ``float64`` arrays, same order as ``decoded_items``. + + Raises: + ValueError: If the policy output batch size is smaller than the number of + input rows. + """ + batch = build_opentau_batch_multi(self.cfg, decoded_items, self.device, self.dtype) + b = len(decoded_items) + with torch.inference_mode(): + act = self.policy.select_action(batch) + act_np = act.to("cpu", torch.float32).numpy() + if act_np.ndim == 1: + act_np = act_np.reshape(1, -1) + # Some policies pad to a fixed max batch size; only return rows for this request. 
+ if act_np.shape[0] > b: + act_np = act_np[:b] + elif act_np.shape[0] < b: + raise ValueError(f"Policy returned batch dim {act_np.shape[0]} < input batch {b}") + outs: List[np.ndarray] = [] + for row in range(act_np.shape[0]): + step0 = act_np[row].ravel() + policy_adim = step0.shape[0] + out = np.zeros(action_dim, dtype=np.float64) + n = min(action_dim, policy_adim) + out[:n] = step0[:n].astype(np.float64) + outs.append(out) + return outs + + +def make_handler( + action_dim: int, + policy: PolicyFn, + rng: np.random.Generator, + opentau_runner: Optional[OpenTauRoboCasaPolicy] = None, +): + """Build the asyncio WebSocket handler for single and batched policy requests. + + Args: + action_dim: Expected flat action size for validation and zero-fill on errors. + policy: Callable used when ``opentau_runner`` is None (stub or custom). + rng: NumPy generator passed to ``policy`` for stochastic stubs. + opentau_runner: If set, batch paths use ``OpenTauRoboCasaPolicy.infer_batch`` + for efficiency; otherwise batch is looped with ``policy``. + + Returns: + An async function suitable for ``websockets.serve`` that reads MessagePack + frames and sends MessagePack-encoded actions. 
+ """ + + async def _handler(websocket: Any): + """Handle one WebSocket connection: MessagePack in, MessagePack actions out.""" + + async for message in websocket: + try: + data = msgpack.unpackb(message, raw=False) + if not isinstance(data, dict): + raise ValueError("Expected dict payload") + + if data.get("batch") is True: + items = data.get("items") + if not isinstance(items, list): + raise ValueError("Batch request requires 'items' list") + decoded: List[Tuple[Dict[str, np.ndarray], np.ndarray, str]] = [] + for item in items: + if not isinstance(item, dict): + raise ValueError("Each batch item must be a dict") + images_jpeg, state, prompt = unpack_payload_dict(item) + images_rgb = decode_all_images(images_jpeg) + decoded.append((images_rgb, state, prompt)) + + if opentau_runner is not None: + actions_out = opentau_runner.infer_batch(decoded, action_dim) + else: + actions_out = [] + for images_rgb, state, prompt in decoded: + action = policy(images_rgb, state, prompt, action_dim, rng) + if action.shape[0] != action_dim: + raise ValueError( + f"Policy returned shape {action.shape}, expected ({action_dim},)" + ) + actions_out.append(action) + + await websocket.send(pack_actions_batch(actions_out)) + else: + images_jpeg, state, prompt = unpack_payload_dict(data) + images_rgb = decode_all_images(images_jpeg) + action = policy(images_rgb, state, prompt, action_dim, rng) + if action.shape[0] != action_dim: + raise ValueError(f"Policy returned shape {action.shape}, expected ({action_dim},)") + await websocket.send(pack_action(action)) + except Exception as e: + logger.exception("Policy step failed: %s", e) + try: + data = msgpack.unpackb(message, raw=False) + except Exception: + data = None + if isinstance(data, dict) and data.get("batch") is True: + items = data.get("items") if isinstance(data.get("items"), list) else [] + n = len(items) + zeros = [np.zeros(action_dim, dtype=np.float64).tolist() for _ in range(n)] + await websocket.send(msgpack.packb(zeros, 
use_bin_type=True)) + else: + await websocket.send(pack_action(np.zeros(action_dim, dtype=np.float64))) + + return _handler + + +def make_opentau_handler(runner: OpenTauRoboCasaPolicy) -> PolicyFn: + """Adapt ``OpenTauRoboCasaPolicy`` to the generic ``PolicyFn`` signature. + + The returned callable ignores ``rng`` (deterministic inference). + + Args: + runner: Loaded policy wrapper. + + Returns: + A ``PolicyFn`` that forwards to ``OpenTauRoboCasaPolicy.infer``. + """ + + def _policy( + images_rgb: Dict[str, np.ndarray], + state: np.ndarray, + prompt: str, + adim: int, + rng: np.random.Generator, + ) -> np.ndarray: + """Single-env policy shim; ``rng`` is unused.""" + + del rng + return runner.infer(images_rgb, state, prompt, adim) + + return _policy + + +async def run_server( + host: str, + port: int, + action_dim: int, + policy: Optional[PolicyFn] = None, + seed: int = 0, + opentau_runner: Optional[OpenTauRoboCasaPolicy] = None, +) -> None: + """Start the WebSocket server and block until the process is interrupted. + + Args: + host: Bind address (e.g. ``0.0.0.0``). + port: TCP port. + action_dim: Flat action dimension for validation and error fallbacks. + policy: Policy callable for non-OpenTau or stub mode; defaults to + ``policy_forward``. + seed: Seed for the numpy RNG used by stub / default policy. + opentau_runner: When non-None, batched requests use ``infer_batch`` on this + object; single requests still go through ``policy`` (typically from + ``make_opentau_handler``). + + Note: + Uses ``ping_timeout=None`` and ``max_size=None`` for large payloads and + long inference times. Runs until cancelled (infinite ``asyncio.Future``). 
+ """ + rng = np.random.default_rng(seed) + pol: PolicyFn = policy if policy is not None else policy_forward + + async with websockets.serve( + make_handler(action_dim, pol, rng, opentau_runner=opentau_runner), + host, + port, + max_size=None, + ping_timeout=None, + ): + print( + f"RoboCasa policy server (async/batch) listening on ws://{host}:{port} " + f"(action_dim={action_dim}). Waiting for client…" + ) + await asyncio.Future() # run forever + + +@parser.wrap() +def robocasa_async_main(cfg: TrainPipelineConfig) -> None: + """CLI entry: parse config, optionally load OpenTau policy, and run ``run_server``. + + Honors module globals set by ``_parse_robocasa_cli`` (host, port, action + dimension, stub vs OpenTau, torch compile). When ``ROBOCASA_USE_STUB`` is True, + uses ``policy_forward``; otherwise builds ``OpenTauRoboCasaPolicy`` and + a handler from ``make_opentau_handler``. + + Args: + cfg: Parsed ``TrainPipelineConfig`` from OpenTau's argparse (includes + ``policy``, ``seed``, etc.). 
+ """ + logging.basicConfig(level=logging.INFO) + logging.info( + "%s\nRoboCasa globals: host=%s port=%s action_dim=%s torch_compile=%s use_stub=%s", + pformat(asdict(cfg)), + ROBOCASA_HOST, + ROBOCASA_PORT, + ROBOCASA_ACTION_DIM, + ROBOCASA_TORCH_COMPILE, + ROBOCASA_USE_STUB, + ) + + if cfg.seed is not None: + set_seed(cfg.seed) + + seed = int(cfg.seed) if cfg.seed is not None else 0 + + policy_fn: PolicyFn + runner: Optional[OpenTauRoboCasaPolicy] = None + if ROBOCASA_USE_STUB: + policy_fn = policy_forward + else: + runner = OpenTauRoboCasaPolicy( + cfg, + compile_model=ROBOCASA_TORCH_COMPILE, + seed=cfg.seed, + ) + policy_fn = make_opentau_handler(runner) + + asyncio.run( + run_server( + host=ROBOCASA_HOST, + port=ROBOCASA_PORT, + action_dim=ROBOCASA_ACTION_DIM, + policy=policy_fn, + seed=seed, + opentau_runner=runner, + ) + ) + + +if __name__ == "__main__": + _parse_robocasa_cli() + init_logging() + robocasa_async_main() From 54de0a7510c309598683bc306f758c8c2e20ad86 Mon Sep 17 00:00:00 2001 From: akshay18iitg Date: Wed, 25 Mar 2026 20:59:59 -0700 Subject: [PATCH 2/6] cursor review --- src/opentau/scripts/robocasa/server.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/opentau/scripts/robocasa/server.py b/src/opentau/scripts/robocasa/server.py index 39aefbb3..830be9fa 100644 --- a/src/opentau/scripts/robocasa/server.py +++ b/src/opentau/scripts/robocasa/server.py @@ -470,6 +470,7 @@ def __init__( self.policy.reset() + # batch for warmup inference dummy = build_opentau_batch( cfg, { @@ -483,6 +484,7 @@ def __init__( self.dtype, ) with torch.inference_mode(): + # two warmup calls are needed right after compiling _ = self.policy.sample_actions(dummy) _ = self.policy.sample_actions(dummy) self.policy.reset() @@ -737,8 +739,10 @@ def robocasa_async_main(cfg: TrainPipelineConfig) -> None: policy_fn: PolicyFn runner: Optional[OpenTauRoboCasaPolicy] = None if ROBOCASA_USE_STUB: + # policy to output random actions policy_fn = policy_forward else: + # 
initialize runner with loads model from config runner = OpenTauRoboCasaPolicy( cfg, compile_model=ROBOCASA_TORCH_COMPILE, From cd71771b387c2efb929fa4aef116483e8256a8de Mon Sep 17 00:00:00 2001 From: akshay18iitg Date: Thu, 26 Mar 2026 10:35:48 -0700 Subject: [PATCH 3/6] Add Documentation and trained checkpoints to readme --- README.md | 8 ++ docs/source/tutorials.rst | 1 + docs/source/tutorials/robocasa.rst | 177 +++++++++++++++++++++++++++++ 3 files changed, 186 insertions(+) create mode 100644 docs/source/tutorials/robocasa.rst diff --git a/README.md b/README.md index b5c12370..fd651944 100644 --- a/README.md +++ b/README.md @@ -64,6 +64,10 @@ We provide fully functioning $\pi_{0.5}$ checkpoints trained with high success r | Model Checkpoint | Description | Success Rate (%) | |-------------------------------|---------------------------------------------------------------------------------------------------------------|--------------------------------------------------------------------| +| [TensorAuto/Robocasa_navigatekitchen][12] | A $\pi_{0.5}$ model checkpoint trained on Navigate to Kitchen objects task on Robocasa. | 97% | +| [TensorAuto/Robocasa_Closeupdown][11] | A $\pi_{0.5}$ model checkpoint trained on Close Oven, Close Toaster and Close Dishwasher on Robocasa. | Close Oven : 90%
Close Toaster : 70%
Close Dishwasher : 90% | +| [TensorAuto/robocasa_Closesideways][10]| A $\pi_{0.5}$ model checkpoint trained on Close Microwave, Close Cabinet and Close Fridge on Robocasa. | Close Microwave : 97%
Close Cabinet : 65%
Close Fridge : 80% | +| [TensorAuto/pi05_libero_continuous_state][9] | A $\pi_{0.5}$ model checkpoint trained on Libero dataset with continuous actions. | 92% | | [TensorAuto/moka_pot_libero_sft][6]
[TensorAuto/moka_pot_RECAP_R0][7]
[TensorAuto/moka_pot_RECAP_R1][8] | A $\pi_{0}$ RECAP model checkpoint trained on moka pot task on libero. | 83%
89%
90% | | [TensorAuto/tPi0.5-libero][2] | A $\pi_{0.5}$ model checkpoint trained on the LIBERO dataset with discrete actions and knowledge insulation. | 98.4% (10)
97.6% (Goal)
100% (Object)
98% (Spatial) | | [TensorAuto/pi05_base][5] | A $\pi_{0.5}$ model checkpoint converted from the official openpi checkpoint, with language embeddings added. | N/A | @@ -81,3 +85,7 @@ This project builds on the $\pi$ series of [papers][3] and many other open-sourc [6]: https://huggingface.co/TensorAuto/moka_pot_libero_sft [7]: https://huggingface.co/TensorAuto/moka_pot_RECAP_R0 [8]: https://huggingface.co/TensorAuto/moka_pot_RECAP_R1 +[9]: https://huggingface.co/TensorAuto/pi05_libero_continuous_state +[10]: https://huggingface.co/TensorAuto/robocasa_Closesideways +[11]: https://huggingface.co/TensorAuto/Robocasa_Closeupdown +[12]: https://huggingface.co/TensorAuto/Robocasa_navigatekitchen diff --git a/docs/source/tutorials.rst b/docs/source/tutorials.rst index 68b596d9..6543151d 100644 --- a/docs/source/tutorials.rst +++ b/docs/source/tutorials.rst @@ -16,3 +16,4 @@ This section provides step-by-step guides for common tasks in OpenTau, including RL tutorials/human_demo tutorials/ros_conversion + tutorials/robocasa diff --git a/docs/source/tutorials/robocasa.rst b/docs/source/tutorials/robocasa.rst new file mode 100644 index 00000000..226e84f5 --- /dev/null +++ b/docs/source/tutorials/robocasa.rst @@ -0,0 +1,177 @@ +.. _robocasa: + +RoboCasa setup and rollout client +================================= + +This page explains how to set up **RoboCasa** (kitchen simulation) alongside **OpenTau**, run the **policy WebSocket server** that serves an OpenTau checkpoint, and run the **batched rollout client** that steps parallel MuJoCo environments and queries the policy in batches. + +.. note:: + Complete the base :doc:`/installation` steps first. RoboCasa itself is installed **outside** the OpenTau package; OpenTau provides the client and server glue only. + +Overview +-------- + +The workflow is split across machines or terminals: + +1. 
**Simulation + client** — RoboCasa environments, JPEG encoding, and episode logging run where you have the ``robocasa`` Python package and assets (often the same machine as the server for local testing). +2. **Policy server** — Loads ``policy.pretrained_path`` from an OpenTau training config and answers WebSocket requests with MessagePack-encoded actions. + +OpenTau ships: + +* ``opentau.scripts.robocasa.server`` — async WebSocket server (single-observation and **batched** requests). +* ``opentau.scripts.robocasa.client`` — threaded client that batches observations from multiple env workers per timestep. + +Dependencies used by these modules (``websockets``, ``msgpack``) are declared in OpenTau’s ``pyproject.toml``. The server also needs **OpenCV** for JPEG decode on the policy side (``opencv-python`` or ``opencv-python-headless`` is already a core OpenTau dependency). + + +Prerequisites +------------- + +**Hardware and OS** + +* Linux with an NVIDIA GPU is recommended for both RoboCasa (MuJoCo) and OpenTau inference. +* Follow GPU guidance in :doc:`/installation`. + +**Python** + +* OpenTau currently targets **Python 3.10** (see ``requires-python`` in the repo root ``pyproject.toml``). Use the same interpreter for OpenTau and for the environment where you install RoboCasa, or ensure compatibility between the two stacks. + +**RoboCasa simulation** + +RoboCasa is not installed by ``pip install opentau``. Install the simulator and assets from the **upstream project**: + +* `RoboCasa installation <https://github.com/robocasa/robocasa>`_ + +Typical steps include installing ``robosuite`` (often from source), then ``robocasa`` in editable mode, then running asset download scripts (kitchen assets can be large). Always refer to the official docs for the version you use. + +**OpenTau** + +Install OpenTau from source or PyPI as in :doc:`/installation`. Ensure your environment has the packages required by the scripts above (sync with ``uv sync`` or ``pip install -e .`` from the repo).
+ + +Policy server (OpenTau) +----------------------- + +The server listens on a WebSocket port and speaks MessagePack. It accepts either a **single** observation dict or a **batch** payload ``{ "batch": true, "items": [ ... ] }`` for parallel clients. + +**Entry point** + +.. code-block:: bash + + python -m opentau.scripts.robocasa.server \ + --config_path /path/to/train_config.json + +**RoboCasa-specific flags** (may appear anywhere on the command line; they are parsed first and stripped from ``sys.argv`` before the normal OpenTau config flags are parsed): + +.. list-table:: + :header-rows: 1 + :widths: 28 72 + + * - Flag + - Meaning + * - ``--robocasa_host`` + - Bind address (default ``0.0.0.0``). Use ``127.0.0.1`` to listen only locally. + * - ``--robocasa_port`` + - TCP port (default ``8765``). + * - ``--robocasa_action_dim`` + - Flat action size passed to the policy and validation (default ``16``; align with your RoboCasa / training setup). + * - ``--robocasa_torch_compile`` + - ``true`` / ``false`` — whether to compile ``sample_actions`` when supported (default ``true``). + * - ``--robocasa_use_stub`` + - ``true`` to use a small random policy instead of loading ``policy.pretrained_path`` (useful for wiring tests without weights). + +**Example** with explicit host and port: + +.. code-block:: bash + + python -m opentau.scripts.robocasa.server \ + --robocasa_host 0.0.0.0 \ + --robocasa_port 8765 \ + --robocasa_action_dim 16 \ + --config_path /path/to/train_config.json + +The training config must define ``policy.pretrained_path`` and compatible policy settings unless you use ``--robocasa_use_stub=true``. + + +Rollout client (RoboCasa + OpenTau) +----------------------------------- + +Copy the code from ``opentau.scripts.robocasa.client`` to ``robocasa.scripts.client`` and modify it to fit your needs. Run the client commands shown below in the RoboCasa environment. + +Add this function to ``robocasa.utils.env_utils``: + +.. 
code-block:: python + + def convert_action_pi05(action): + """ + Converts input action (np.array) to format expected by gym env (dict) + """ + action = action.copy() + output_action = { + "action.end_effector_position": action[5:8], + "action.end_effector_rotation": action[8:11], + "action.gripper_close": action[11:12], + "action.base_motion": action[0:4], + "action.control_mode": action[4:5], + } + return np.concatenate([v for k,v in output_action.items()], axis=-1) + +Run the client **after** the server is listening. It registers a RoboCasa task name, spawns one thread per parallel worker (up to ``--num-parallel``), batches observations for each timestep, and writes ``rollouts.json`` plus optional per-camera videos. + +**Entry point** + +.. code-block:: bash + + python -m opentau.scripts.robocasa.client ENV_NAME \ + --host localhost \ + --port 8765 + +Replace ``ENV_NAME`` with a registered RoboCasa kitchen task class name (same as other RoboCasa tooling). + +**Useful options** + +.. list-table:: + :header-rows: 1 + :widths: 28 72 + + * - Option + - Meaning + * - ``--num-rollouts`` + - Total episodes (default ``1``). + * - ``--num-parallel`` + - Parallel env threads (capped by ``--num-rollouts``); batch size per step is at most this value. + * - ``--seed`` + - Base seed; rollout ``i`` uses ``seed + i``. + * - ``--split`` + - Dataset split for ``create_env`` (``all``, ``pretrain``, or ``target``). + * - ``--output-dir`` + - Root for ``rollouts.json`` and ``rollout_*_seed_*`` video folders (default: auto-generated under cwd). + * - ``--max-episode-steps`` + - Step cap per episode (default ``1500``). + * - ``--render`` + - On-screen rendering; disables saved videos. + +**Environment variables** + +* ``ROBOCASA_POLICY_HOST`` — default for ``--host`` (default ``localhost``). +* ``ROBOCASA_POLICY_PORT`` — default for ``--port`` (default ``8765``). 
+ + +Protocol and outputs (short) +------------------------------ + +* **Transport:** WebSocket binary frames, MessagePack payloads. +* **Client → server (batched):** ``{ "batch": true, "items": [ { "images": { camera_name: jpeg_bytes, ... }, "state": [...], "prompt": "..." }, ... ] }``. +* **Server → client:** A list of flat action lists, one per item, same order as ``items``. +* **Client output:** A directory containing ``rollouts.json`` (summary and per-rollout ``seed``, ``length``, ``success``) and, when not using ``--render``, MP4 files per camera under ``rollout_*`` subfolders. + +For full behavioral details (variable batch size as workers finish, JPEG quality, ``ping_timeout``), see the module docstrings in ``src/opentau/scripts/robocasa/client.py`` and ``src/opentau/scripts/robocasa/server.py``. + + +Troubleshooting +--------------- + +* **Import errors for ``robocasa``** — Install and register RoboCasa per upstream docs; the client imports ``robocasa`` and ``robocasa.utils.env_utils``. +* **Server fails on JPEG decode** — Install OpenCV for Python on the server host (``cv2``); without it, JPEG decoding raises at runtime. +* **Port already in use** — Change ``--robocasa_port`` / ``--port`` or stop the conflicting process. +* **Action dimension mismatches** — Align ``--robocasa_action_dim`` with the policy and environment (e.g. PandaOmron / ``convert_action_pi05`` expectations in the client). 
From 92d72979ff02a7aa928939f45e4b1013cf3425e3 Mon Sep 17 00:00:00 2001 From: akshay18iitg Date: Thu, 26 Mar 2026 12:00:19 -0700 Subject: [PATCH 4/6] Updating server and clinet to use sample actions instead to select actions to avoid policy.reset for each new env --- src/opentau/scripts/robocasa/client.py | 278 ++++------------ src/opentau/scripts/robocasa/server.py | 431 ++++--------------------- 2 files changed, 120 insertions(+), 589 deletions(-) diff --git a/src/opentau/scripts/robocasa/client.py b/src/opentau/scripts/robocasa/client.py index fb92bc84..d8366545 100644 --- a/src/opentau/scripts/robocasa/client.py +++ b/src/opentau/scripts/robocasa/client.py @@ -17,8 +17,8 @@ Runs **n_parallel** environment threads. Each thread pulls rollouts from a shared queue until **num_rollouts** episodes are finished. The **main** asyncio loop receives -observations from active workers, batches them into one WebSocket message per timestep, -and routes returned actions back to the corresponding threads. +observations from active workers, batches them into one WebSocket message, and routes +returned action chunks back to the corresponding threads. Batch protocol (MessagePack over WebSocket, binary frames) matches ``client.py`` / ``robocasa.scripts.server``: @@ -31,10 +31,10 @@ ], } - Server -> client: list[list[float]] # one flat action per item, same order as ``items`` + Server -> client: list[list[list[float]]] # one action chunk per item, same order as ``items`` The number of ``items`` (and thus the batch size) is **only** the count of workers -still stepping this timestep. As workers finish their rollout queue and exit, batch +that need a new chunk right now. As workers finish their rollout queue and exit, batch size shrinks from at most ``num_parallel`` down to 1 for the final active worker(s). The policy server must return exactly ``len(items)`` actions, not a fixed width of ``num_parallel``. 
@@ -60,108 +60,22 @@ from datetime import datetime from typing import Any, Union -import cv2 import imageio import msgpack import numpy as np import websockets import robocasa # noqa: F401 -from robocasa.utils.env_utils import convert_action_pi05, create_env - -# Same three cameras as ``create_env`` defaults / PandaOmron gym wrapper. -DEFAULT_CAMERA_NAMES: tuple[str, ...] = ( - "robot0_eye_in_hand", - "robot0_agentview_left", - "robot0_agentview_right", -) -# Resolution aligned with ``robocasa.wrappers.gym_wrapper.PandaOmronKeyConverter``. -DEFAULT_CAMERA_WIDTH = 256 -DEFAULT_CAMERA_HEIGHT = 256 - -# Flat action layout expected by env (PandaOmron); see also env_utils.convert_action. -ACTION_ORDER = ( - "end_effector_position", # 3 - "end_effector_rotation", # 3 - "gripper_close", # 1 - "base_motion", # 4 - "control_mode", # 1 -) # total 12 - -# Proprio keys aligned with PandaOmronKeyConverter / typical RoboCasa datasets. -DEFAULT_PROPRIO_KEYS = ( - "robot0_base_pos", - "robot0_base_quat", - "robot0_base_to_eef_pos", - "robot0_base_to_eef_quat", - "robot0_gripper_qpos", +from robocasa.scripts.client import ( + DEFAULT_CAMERA_HEIGHT, + DEFAULT_CAMERA_NAMES, + DEFAULT_CAMERA_WIDTH, + build_proprio_vector, + encode_all_cameras_jpeg, + flip_image_obs, + get_task_prompt, ) - - -def get_task_prompt(env) -> str: - """ - Natural-language instruction for the current episode (RoboCasa ``get_ep_meta()['lang']``). - """ - meta = env.get_ep_meta() - if not meta: - return "" - lang = meta.get("lang", "") - if lang is None: - return "" - if isinstance(lang, (list, tuple)): - return " ".join(str(x) for x in lang) - return str(lang) - - -def build_proprio_vector(obs: dict, keys: tuple[str, ...] = DEFAULT_PROPRIO_KEYS) -> np.ndarray: - """Concatenate low-dimensional robot state for policy input.""" - parts = [] - for k in keys: - if k not in obs: - raise KeyError( - f"Observation missing key {k!r}. Available keys (sample): " - f"{[x for x in obs if not x.endswith('_image')][:20]}..." 
- ) - parts.append(np.asarray(obs[k], dtype=np.float64).ravel()) - return np.concatenate(parts, axis=0) - - -def flip_image_obs(obs: dict, camera_names: tuple[str, ...]) -> dict: - """Flip images vertically since MuJoCo renders upside down.""" - for name in camera_names: - key = f"{name}_image" - if key in obs: - # Copy to ensure the array is contiguous for cv2/imageio - obs[key] = obs[key][::-1].copy() - return obs - - -def encode_camera_rgb_to_jpeg( - obs: dict, - camera_name: str, - jpeg_quality: int = 80, -) -> bytes: - """Encode one camera's RGB observation as JPEG bytes (OpenCV uses BGR).""" - key = f"{camera_name}_image" - if key not in obs: - raise KeyError(f"Missing {key!r}. Ensure create_env includes camera {camera_name!r}.") - rgb = obs[key] - if rgb.ndim != 3 or rgb.shape[-1] != 3: - raise ValueError(f"Expected HxWx3 RGB image for {key}, got shape {rgb.shape}") - bgr = cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR) - ok, buf = cv2.imencode(".jpg", bgr, [int(cv2.IMWRITE_JPEG_QUALITY), jpeg_quality]) - if not ok: - raise RuntimeError("cv2.imencode failed") - return buf.tobytes() - - -def encode_all_cameras_jpeg( - obs: dict, - camera_names: tuple[str, ...], - jpeg_quality: int = 80, -) -> dict: - """JPEG-encode every listed camera; keys match ``camera_names``.""" - return {name: encode_camera_rgb_to_jpeg(obs, name, jpeg_quality=jpeg_quality) for name in camera_names} +from robocasa.utils.env_utils import convert_action_pi05, create_env @dataclass @@ -194,30 +108,25 @@ def _normalize_batched_actions_response( actions_batch: Any, num_expected: int, ) -> list[Any]: - """Normalize the batched policy response to one action list per observation. - - Args: - actions_batch: Raw decoded batch from the server (typically a ``list``). - num_expected: Number of observations in this batch (length of ``items``). - - Returns: - A list of length ``num_expected``, one action (sequence of floats) per row. 
+ """ + Ensure ``actions_batch`` is a list of length ``num_expected``, one action chunk per batch row. - Raises: - ValueError: If the response is not a list, length cannot be reconciled with - ``num_expected``, or a partial batch cannot be interpreted. + When ``num_expected == 1``, some servers may return one chunk directly + (``list[list[float]]``) or one flat action (``list[float]``) instead of + ``[chunk]``; wrap those cases. - Note: - If ``num_expected == 1``, some servers return a single flat ``list[float]`` - instead of ``[list[float]]``; that case is wrapped. If the server returns - more rows than ``num_expected``, excess rows are dropped (with a one-time - warning). + When the server returns *more* rows than ``num_expected`` (e.g. fixed max batch + width while the client sends a partial batch), the excess rows are dropped. """ global _SERVER_TRUNCATED_ACTION_BATCH_WARNED if not isinstance(actions_batch, list): raise ValueError(f"Batched server response must be a list, got {type(actions_batch).__name__}") if len(actions_batch) == num_expected: return actions_batch + if num_expected == 1 and len(actions_batch) > 0: + first = actions_batch[0] + if isinstance(first, (int, float, np.floating, np.integer, list, tuple, np.ndarray)): + return [actions_batch] if len(actions_batch) > num_expected: if not _SERVER_TRUNCATED_ACTION_BATCH_WARNED: warnings.warn( @@ -229,22 +138,27 @@ def _normalize_batched_actions_response( ) _SERVER_TRUNCATED_ACTION_BATCH_WARNED = True return actions_batch[:num_expected] - # Single-env batch: server may send one flat list[float] instead of [list[float]]. 
- if num_expected == 1 and len(actions_batch) > 0: - first = actions_batch[0] - if isinstance(first, (int, float, np.floating, np.integer)): - return [actions_batch] raise ValueError( f"Batched actions length {len(actions_batch)} != batch size {num_expected} " f"(partial batches must still return one action list per observation)" ) +def _normalize_action_chunk_for_worker(raw_action_chunk: Any) -> list[np.ndarray]: + """Convert one server row into a list of flat action vectors.""" + arr = np.asarray(raw_action_chunk, dtype=np.float64) + if arr.ndim == 1: + return [arr.ravel()] + if arr.ndim != 2: + raise ValueError(f"Expected action chunk rank 1 or 2, got shape {arr.shape}") + return [arr[i].ravel() for i in range(arr.shape[0])] + + def _worker_loop( *, rollout_queue: queue.Queue[int | None], to_main: queue.Queue[WorkerToMain], - from_main: queue.Queue[np.ndarray], + from_main: queue.Queue[Any], env_name: str, split, start_seed: int, @@ -255,27 +169,7 @@ def _worker_loop( action_dim_holder: list[int | None], action_dim_lock: threading.Lock, ) -> None: - """Run one worker thread: consume rollout indices and step the env until done. - - Pulls rollout IDs from ``rollout_queue``, builds observations (JPEG + state + - prompt), sends ``ObsMsg`` to the coordinator, blocks on ``from_main`` for the - action, steps the environment, and sends ``DoneMsg`` when the episode ends. - Puts ``ExitMsg`` when the queue is empty and the thread exits. - - Args: - rollout_queue: Queue of rollout indices; empty queue means this worker exits. - to_main: Queue to the asyncio coordinator (``ObsMsg``, ``DoneMsg``, ``ExitMsg``). - from_main: Queue from coordinator delivering one flat action vector per step. - env_name: Registered RoboCasa environment name. - split: Dataset split passed to ``create_env``. - start_seed: Base seed; rollout ``i`` uses ``start_seed + i``. - main_dir: Root directory for ``rollouts.json`` and per-rollout video folders. 
- jpeg_quality: JPEG quality for encoded camera frames (when not rendering). - max_episode_steps: Optional step cap per episode (in addition to success). - render: If True, onscreen render and no video files; else offscreen + videos. - action_dim_holder: Single-element list shared across workers for ``env.action_dim``. - action_dim_lock: Lock protecting ``action_dim_holder`` initialization. - """ + """One thread: sequential rollouts from ``rollout_queue`` until empty.""" while True: try: # get the next rollout index from the queue and its protected by a lock @@ -321,6 +215,7 @@ def _worker_loop( raise RuntimeError("env.action_dim is None after reset()") action_dim_holder[0] = ad step_count = 0 + pending_actions: list[np.ndarray] = [] while True: if render: @@ -340,12 +235,17 @@ def _worker_loop( prompt = get_task_prompt(env) payload_obs = {"images": images, "state": state, "prompt": prompt} - # send the payload to the main thread - to_main.put(ObsMsg(payload=payload_obs)) - # get the action from the main thread - action = from_main.get() - # convert the action to a numpy array and convert the action to the desired range - action = np.asarray(action, dtype=np.float64).ravel() + # request a new policy chunk only when local chunk is exhausted + if len(pending_actions) == 0: + # send the payload to the main thread + to_main.put(ObsMsg(payload=payload_obs)) + raw_action_chunk = from_main.get() + pending_actions = _normalize_action_chunk_for_worker(raw_action_chunk) + if len(pending_actions) == 0: + raise ValueError("Server returned an empty action chunk") + + # take one action from the local chunk + action = pending_actions.pop(0) # build action vector in desired order action = convert_action_pi05(action) @@ -387,47 +287,23 @@ async def _run_coordinator( ws_uri: str, n_workers: int, to_mains: list[queue.Queue[WorkerToMain]], - from_mains: list[queue.Queue[np.ndarray]], + from_mains: list[queue.Queue[Any]], results_by_rollout: dict[int, tuple[int, bool]], results_lock: 
threading.Lock, ) -> None: - """Batch observations from all active workers each timestep and drive the WebSocket. - - For each timestep, concurrently reads from each active worker until each has - produced one ``ObsMsg`` (``DoneMsg`` is consumed and recorded without blocking - others) or ``ExitMsg``. Builds one MessagePack batch ``{batch: true, items: ...}``, - sends it to the policy server, and distributes returned actions back to workers - via ``from_mains``. This ordering avoids deadlock when one worker is between - episodes while others already have the next observation. - - Args: - ws_uri: WebSocket URI (e.g. ``ws://host:port``). - n_workers: Number of parallel worker threads. - to_mains: Per-worker queues from workers to this coordinator. - from_mains: Per-worker queues from coordinator to workers (actions). - results_by_rollout: Mutable map ``rollout_idx -> (length, success)`` for ``DoneMsg``. - results_lock: Lock protecting ``results_by_rollout``. + """ + For each timestep, read from all active workers in parallel until each has produced + one ``ObsMsg`` (skipping ``DoneMsg``) or ``ExitMsg``. This avoids deadlock when one + worker finishes an episode and is slow to start the next rollout while others already + have the next observation ready. """ loop = asyncio.get_event_loop() def _get(q: queue.Queue[WorkerToMain]) -> WorkerToMain: - """Block until ``q`` delivers the next worker-to-main message.""" - return q.get() async def _drain_to_obs_or_exit(wid: int) -> tuple[int, ObsMsg | None, bool]: - """Drain a worker queue until the next ``ObsMsg`` or thread exit. - - Skips ``DoneMsg`` (records results) until an observation or ``ExitMsg``. - - Args: - wid: Worker index (0 .. ``n_workers`` - 1). - - Returns: - Tuple of ``(worker_id, observation_message_or_none, is_exit)``. If - ``is_exit`` is True, the worker has finished; ``observation_message`` - is None. Otherwise ``observation_message`` is the ``ObsMsg`` to batch. 
- """ + """Returns (worker_id, ObsMsg or None if exiting, is_exit).""" while True: msg = await loop.run_in_executor(None, _get, to_mains[wid]) if isinstance(msg, ExitMsg): @@ -479,8 +355,7 @@ async def _drain_to_obs_or_exit(wid: int) -> tuple[int, ObsMsg | None, bool]: actions_batch = _normalize_batched_actions_response(actions_batch, batch_size) for wid, act in zip(batch_workers, actions_batch, strict=False): - a = np.asarray(act, dtype=np.float64).ravel() - from_mains[wid].put(a) + from_mains[wid].put(act) async def run_policy_loop_threaded( @@ -496,28 +371,6 @@ async def run_policy_loop_threaded( max_episode_steps: int | None, render: bool = False, ) -> None: - """Run threaded RoboCasa rollouts against a batched policy WebSocket server. - - Spawns up to ``min(num_parallel, num_rollouts)`` worker threads, coordinates - batched policy calls in ``_run_coordinator``, then writes ``rollouts.json`` - under the output directory. - - Args: - ws_uri: WebSocket URI of the policy server. - env_name: RoboCasa task name for ``create_env``. - split: Dataset split for ``create_env``. - start_seed: Seed for rollout index 0; rollout ``i`` uses ``start_seed + i``. - num_rollouts: Total number of episodes to run. - num_parallel: Maximum parallel env threads (capped by ``num_rollouts``). - output_dir: Output root; default is ``{env_name}_async_{timestamp}``. - jpeg_quality: JPEG quality for camera encodes when not rendering. - max_episode_steps: Max steps per episode (in addition to env termination). - render: If True, onscreen rendering and no saved videos. - - Raises: - ValueError: If ``num_rollouts`` or ``num_parallel`` is invalid. - RuntimeError: If a worker thread does not exit or a rollout result is missing. 
- """ if num_rollouts < 1: raise ValueError("num_rollouts must be >= 1") if num_parallel < 1: @@ -542,7 +395,7 @@ async def run_policy_loop_threaded( # queues to send messages from the coordinator to the workers and from the workers to the coordinator to_mains: list[queue.Queue[WorkerToMain]] = [queue.Queue() for _ in range(n_workers)] - from_mains: list[queue.Queue[np.ndarray]] = [queue.Queue() for _ in range(n_workers)] + from_mains: list[queue.Queue[Any]] = [queue.Queue() for _ in range(n_workers)] # dictionary to store the results by rollout index results_by_rollout: dict[int, tuple[int, bool]] = {} @@ -630,14 +483,6 @@ async def run_policy_loop_threaded( def parse_args(argv: list[str] | None = None) -> argparse.Namespace: - """Parse command-line arguments for the threaded async RoboCasa client. - - Args: - argv: Argument list; defaults to ``sys.argv`` when ``None``. - - Returns: - Parsed namespace with ``env_name``, ``host``, ``port``, rollout options, etc. - """ p = argparse.ArgumentParser(description=__doc__) p.add_argument( "env_name", @@ -700,15 +545,6 @@ def parse_args(argv: list[str] | None = None) -> argparse.Namespace: def main(argv=None) -> None: - """CLI entrypoint: parse args and run ``run_policy_loop_threaded``. - - Args: - argv: Optional argument list; forwarded to ``parse_args``. - - Raises: - SystemExit: On invalid ``--num-rollouts``, ``--num-parallel``, or placeholder - ``--host`` value. 
- """ args = parse_args(argv) if args.num_rollouts < 1: raise SystemExit("error: --num-rollouts must be >= 1") diff --git a/src/opentau/scripts/robocasa/server.py b/src/opentau/scripts/robocasa/server.py index 830be9fa..d7460e12 100644 --- a/src/opentau/scripts/robocasa/server.py +++ b/src/opentau/scripts/robocasa/server.py @@ -17,24 +17,22 @@ Implements the same wire protocol as ``robocasa_server``, plus **batched** messages for ``robocasa.scripts.client_async``: - Single (``client.py``): + Single request: Client -> server (MessagePack): ``{ "images": {...}, "state": [...], "prompt": "..." }`` - Server -> client: ``list[float]`` + Server -> client: ``[[float, ...], ...]`` # shape ``(T, action_dim)`` Batched (``client_async.py``): Client -> server: ``{ "batch": true, "items": [ { "images": {...}, "state": [...], "prompt": "..." }, ... ] }`` Server -> client: - ``[ list[float], ... ]`` # one action per item, same order and length as ``items`` + ``[ [[float, ...], ...], ... ]`` # per item: full action chunk (``T`` steps × ``action_dim``) **OpenTau mode** (default): loads the policy from ``policy.pretrained_path`` in the config. -For single requests, each step calls ``policy.select_action`` (internal action queue). -For **batched** requests, observations are stacked and ``select_action`` runs **once** on the -full batch (same as vector-env rollouts). - -**Stub mode** (``--robocasa_use_stub=true``): small random actions. +Each request runs ``policy.sample_actions`` (no internal queue on the server). The reply is the +full predicted chunk per environment: shape ``(n_action_steps, action_dim)`` (trimmed/padded to +``--robocasa_action_dim``). **Batched** requests stack observations and call ``sample_actions`` once. 
Run:: @@ -51,7 +49,7 @@ import sys from dataclasses import asdict from pprint import pformat -from typing import Any, Callable, Dict, List, Optional, Tuple +from typing import Any, Dict, List, Tuple import msgpack import numpy as np @@ -78,25 +76,16 @@ ROBOCASA_PORT: int = 8765 ROBOCASA_ACTION_DIM: int = 16 ROBOCASA_TORCH_COMPILE: bool = True -ROBOCASA_USE_STUB: bool = False - -def _parse_robocasa_cli() -> None: - """Parse ``--robocasa_*`` flags into module globals and strip them from ``sys.argv``. +# Fallback zero chunk length when sending an error response (client should discard). +_ERROR_RESPONSE_CHUNK_STEPS: int = 8 - Must run before ``robocasa_async_main`` so ``argparse`` in the OpenTau config - parser does not see RoboCasa-specific flags. - Side effects: - Updates ``ROBOCASA_HOST``, ``ROBOCASA_PORT``, ``ROBOCASA_ACTION_DIM``, - ``ROBOCASA_TORCH_COMPILE``, ``ROBOCASA_USE_STUB``, and replaces ``sys.argv`` - with a copy containing only non-RoboCasa arguments. - """ - global ROBOCASA_HOST, ROBOCASA_PORT, ROBOCASA_ACTION_DIM, ROBOCASA_TORCH_COMPILE, ROBOCASA_USE_STUB +def _parse_robocasa_cli() -> None: + """Read ``--robocasa_*`` flags into module globals and remove them from ``sys.argv``.""" + global ROBOCASA_HOST, ROBOCASA_PORT, ROBOCASA_ACTION_DIM, ROBOCASA_TORCH_COMPILE def _bool_arg(value: str) -> bool: - """Parse a string as a boolean (true/1/yes/y, case-insensitive).""" - return value.lower() in ("true", "1", "yes", "y") p = argparse.ArgumentParser(add_help=False) @@ -104,7 +93,6 @@ def _bool_arg(value: str) -> bool: p.add_argument("--robocasa_port", type=int, default=None) p.add_argument("--robocasa_action_dim", type=int, default=None) p.add_argument("--robocasa_torch_compile", type=str, default=None) - p.add_argument("--robocasa_use_stub", type=str, default=None) args, rest = p.parse_known_args(sys.argv[1:]) if args.robocasa_host is not None: ROBOCASA_HOST = args.robocasa_host @@ -114,8 +102,6 @@ def _bool_arg(value: str) -> bool: ROBOCASA_ACTION_DIM = 
args.robocasa_action_dim if args.robocasa_torch_compile is not None: ROBOCASA_TORCH_COMPILE = _bool_arg(args.robocasa_torch_compile) - if args.robocasa_use_stub is not None: - ROBOCASA_USE_STUB = _bool_arg(args.robocasa_use_stub) sys.argv = [sys.argv[0]] + rest @@ -128,18 +114,7 @@ def _bool_arg(value: str) -> bool: def jpeg_bytes_to_rgb(jpeg_bytes: bytes) -> np.ndarray: - """Decode a JPEG bytestring to an RGB image array. - - Args: - jpeg_bytes: Raw JPEG file bytes. - - Returns: - ``uint8`` array of shape ``(H, W, 3)`` in RGB order. - - Raises: - RuntimeError: If OpenCV (``cv2``) is not installed. - ValueError: If ``cv2.imdecode`` fails (invalid JPEG). - """ + """Decode JPEG bytes to HxWx3 uint8 RGB.""" if cv2 is None: raise RuntimeError("opencv-python (cv2) is required to decode JPEG images on the server.") arr = np.frombuffer(jpeg_bytes, dtype=np.uint8) @@ -150,18 +125,7 @@ def jpeg_bytes_to_rgb(jpeg_bytes: bytes) -> np.ndarray: def decode_all_images(images: Dict[str, bytes]) -> Dict[str, np.ndarray]: - """Decode all camera JPEG blobs for ``ROBOCASA_CAMERA_ORDER``. - - Args: - images: Map from camera name to JPEG bytes (must include every key in - ``ROBOCASA_CAMERA_ORDER``). - - Returns: - Map from camera name to ``uint8`` RGB arrays ``(H, W, 3)``. - - Raises: - KeyError: If a required camera key is missing. - """ + """Decode all camera JPEGs to RGB numpy arrays (uint8, H, W, 3).""" out: Dict[str, np.ndarray] = {} for k in ROBOCASA_CAMERA_ORDER: if k not in images: @@ -171,19 +135,7 @@ def decode_all_images(images: Dict[str, bytes]) -> Dict[str, np.ndarray]: def unpack_payload_dict(data: dict) -> Tuple[Dict[str, bytes], np.ndarray, str]: - """Parse one policy request body (single message or one batch item). - - Args: - data: Decoded MessagePack dict with ``images``, ``state``, and optional - ``prompt``. - - Returns: - Tuple of ``(images_dict, state_vector, prompt_string)``. Image values are - normalized to ``bytes``. 
- - Raises: - ValueError: If ``images`` is not a dict or ``state`` is not a list. - """ + """Parse one policy request body (single message or one element of a batch).""" images = data.get("images") state = data.get("state") prompt = data.get("prompt", "") @@ -197,102 +149,18 @@ def unpack_payload_dict(data: dict) -> Tuple[Dict[str, bytes], np.ndarray, str]: return images, state_vec, prompt_str -def unpack_request(message: bytes) -> Tuple[Dict[str, bytes], np.ndarray, str]: - """Decode a single (non-batched) MessagePack WebSocket frame into observation fields. - - Args: - message: Raw binary MessagePack payload. - - Returns: - Same as ``unpack_payload_dict``. - - Raises: - ValueError: If the top-level value is not a dict. - """ - data = msgpack.unpackb(message, raw=False) - if not isinstance(data, dict): - raise ValueError("Expected dict payload") - return unpack_payload_dict(data) - - -PolicyFn = Callable[ - [Dict[str, np.ndarray], np.ndarray, str, int, np.random.Generator], - np.ndarray, -] - - -def default_policy( - images_rgb: Dict[str, np.ndarray], - state: np.ndarray, - prompt: str, - action_dim: int, - rng: np.random.Generator, -) -> np.ndarray: - """Stub policy: uniform random actions in ``[-0.05, 0.05]``. - - Args: - images_rgb: Per-camera ``uint8`` RGB arrays (unused in stub). - state: Proprioceptive state vector (unused in stub). - prompt: Task text (unused in stub). - action_dim: Flat action size. - rng: NumPy random generator. - - Returns: - ``float64`` array of shape ``(action_dim,)``. - """ - del images_rgb, prompt # unused in stub - _ = state # available for your model - return rng.uniform(-0.05, 0.05, size=(action_dim,)).astype(np.float64) - - -def policy_forward( - images_rgb: Dict[str, np.ndarray], - state: np.ndarray, - prompt: str, - action_dim: int, - rng: np.random.Generator, -) -> np.ndarray: - """Default policy entrypoint; replace implementation while keeping the signature. - - Currently delegates to ``default_policy``. 
- - Args: - images_rgb: Per-camera ``uint8`` RGB arrays. - state: Proprioceptive state vector. - prompt: Task text. - action_dim: Flat action size. - rng: NumPy random generator. - - Returns: - Flat action vector of shape ``(action_dim,)``. - """ - return default_policy(images_rgb, state, prompt, action_dim, rng) - - -def pack_action(action: np.ndarray) -> bytes: - """Serialize a single flat action as MessagePack bytes for the WebSocket reply. - - Args: - action: 1D action vector (any shape that ravel-s to the policy dimension). - - Returns: - MessagePack-encoded bytes (list of floats). - """ - a = np.asarray(action, dtype=np.float64).ravel() +def pack_action(action_chunk: np.ndarray) -> bytes: + """MessagePack-encode one env's action chunk ``(T, action_dim)`` as nested lists.""" + a = np.asarray(action_chunk, dtype=np.float64) + if a.ndim != 2: + raise ValueError(f"Expected action chunk of shape (T, action_dim), got shape {a.shape}") return msgpack.packb(a.tolist(), use_bin_type=True) -def pack_actions_batch(actions: List[np.ndarray]) -> bytes: - """Serialize a list of flat actions for a batched WebSocket reply. - - Args: - actions: One numpy action per batch row, same order as request ``items``. - - Returns: - MessagePack-encoded ``list[list[float]]`` bytes. - """ +def pack_actions_batch(chunks: List[np.ndarray]) -> bytes: + """Encode one ``(T, action_dim)`` chunk per batch row (same order as request ``items``).""" return msgpack.packb( - [np.asarray(a, dtype=np.float64).ravel().tolist() for a in actions], + [np.asarray(c, dtype=np.float64).tolist() for c in chunks], use_bin_type=True, ) @@ -303,17 +171,7 @@ def _numpy_rgb_to_camera_tensor( device: torch.device, dtype: torch.dtype, ) -> torch.Tensor: - """Resize RGB ``uint8`` image to policy resolution and produce a CHW batch slice. - - Args: - rgb_uint8: Image array ``(H, W, 3)`` RGB. - resolution: Target ``(height, width)`` as in config (H, W). - device: Torch device for the output tensor. 
- dtype: Floating dtype for normalized pixels. - - Returns: - Tensor of shape ``(1, 3, H, W)`` with values in ``[0, 1]``. - """ + """RGB uint8 (H,W,3) -> (1,3,H,W) float [0,1] on device.""" pil = Image.fromarray(rgb_uint8) pil = pil.resize((resolution[1], resolution[0]), Image.Resampling.BILINEAR) arr = np.asarray(pil, dtype=np.float32) / 255.0 @@ -329,19 +187,7 @@ def build_opentau_batch( device: torch.device, dtype: torch.dtype, ) -> dict[str, torch.Tensor]: - """Build a single-row OpenTau policy batch from one RoboCasa observation. - - Args: - cfg: Training pipeline config (camera count, resolution, state dim, etc.). - images_rgb: Decoded RGB arrays keyed by ``ROBOCASA_CAMERA_ORDER`` names. - state_vec: Proprio vector; padded or truncated to ``cfg.max_state_dim``. - prompt: Task string. - device: Torch device. - dtype: Floating dtype for tensors. - - Returns: - Dict of tensors including ``camera*``, ``state``, ``prompt``, ``img_is_pad``. - """ + """Map RoboCasa observation dict to OpenTau policy batch (batch size 1).""" num_cams = cfg.num_cams resolution = cfg.resolution batch: dict[str, torch.Tensor] = {} @@ -374,20 +220,7 @@ def build_opentau_batch_multi( device: torch.device, dtype: torch.dtype, ) -> dict[str, torch.Tensor]: - """Stack multiple decoded observations into one OpenTau batch of batch size ``B``. - - Args: - cfg: Training pipeline config. - items: List of ``(images_rgb, state_vec, prompt)`` per environment row. - device: Torch device. - dtype: Floating dtype for tensors. - - Returns: - Batched tensor dict for ``select_action`` / ``sample_actions``. - - Raises: - ValueError: If ``items`` is empty. - """ + """Stack multiple RoboCasa observations into one OpenTau batch (batch size B).""" b = len(items) if b == 0: raise ValueError("empty batch") @@ -439,18 +272,6 @@ def __init__( compile_model: bool = True, seed: int | None = None, ) -> None: - """Load the policy from ``cfg.policy.pretrained_path`` and warm up inference. 
- - Args: - cfg: Full training pipeline config (policy type, resolution, state dim, etc.). - compile_model: If True, apply ``torch.compile`` to ``sample_actions`` when - supported. - seed: Optional RNG seed applied before construction via ``set_seed``. - - Side effects: - Loads weights, moves the model to ``auto_torch_device()`` in bfloat16, - runs two dummy ``sample_actions`` calls for warmup, then ``reset()`` again. - """ self.cfg = cfg self.device = auto_torch_device() self.dtype = torch.bfloat16 @@ -470,7 +291,6 @@ def __init__( self.policy.reset() - # batch for warmup inference dummy = build_opentau_batch( cfg, { @@ -484,7 +304,6 @@ def __init__( self.dtype, ) with torch.inference_mode(): - # two warmup calls are needed right after compiling _ = self.policy.sample_actions(dummy) _ = self.policy.sample_actions(dummy) self.policy.reset() @@ -497,26 +316,18 @@ def infer( prompt: str, action_dim: int, ) -> np.ndarray: - """Run a single-environment forward pass and return a fixed-length action. - - Args: - images_rgb: Decoded RGB images per camera. - state_vec: Proprio state; padded to ``cfg.max_state_dim`` in the batch. - prompt: Task string. - action_dim: Desired output length (RoboCasa / CLI); policy output is - truncated or zero-padded to this size. - - Returns: - ``float64`` vector of shape ``(action_dim,)``. 
- """ + """Return full action chunk ``(T, action_dim)`` from ``sample_actions`` (trim/pad last dim).""" batch = build_opentau_batch(self.cfg, images_rgb, state_vec, prompt, self.device, self.dtype) with torch.inference_mode(): - act = self.policy.select_action(batch) - step0 = act.squeeze(0).to("cpu", torch.float32).numpy().ravel() - policy_adim = step0.shape[0] - out = np.zeros(action_dim, dtype=np.float64) + act = self.policy.sample_actions(batch) + # (1, T, policy_dim) + act_np = act.squeeze(0).to("cpu", torch.float32).numpy() + if act_np.ndim != 2: + raise ValueError(f"Expected policy output (T, D), got shape {act_np.shape}") + t_steps, policy_adim = act_np.shape + out = np.zeros((t_steps, action_dim), dtype=np.float64) n = min(action_dim, policy_adim) - out[:n] = step0[:n].astype(np.float64) + out[:, :n] = act_np[:, :n].astype(np.float64) return out def infer_batch( @@ -524,65 +335,32 @@ def infer_batch( decoded_items: List[Tuple[Dict[str, np.ndarray], np.ndarray, str]], action_dim: int, ) -> List[np.ndarray]: - """Run one batched ``select_action`` for multiple observations. - - Args: - decoded_items: One triple ``(images_rgb, state_vec, prompt)`` per batch row. - action_dim: Target flat size per row (truncate or pad each output). - - Returns: - List of ``action_dim``-length ``float64`` arrays, same order as ``decoded_items``. - - Raises: - ValueError: If the policy output batch size is smaller than the number of - input rows. - """ + """One ``sample_actions`` on a stacked batch; one ``(T, action_dim)`` chunk per env.""" batch = build_opentau_batch_multi(self.cfg, decoded_items, self.device, self.dtype) b = len(decoded_items) with torch.inference_mode(): - act = self.policy.select_action(batch) + act = self.policy.sample_actions(batch) act_np = act.to("cpu", torch.float32).numpy() - if act_np.ndim == 1: - act_np = act_np.reshape(1, -1) - # Some policies pad to a fixed max batch size; only return rows for this request. 
+ if act_np.ndim == 2: + act_np = act_np.reshape(1, *act_np.shape) + if act_np.ndim != 3: + raise ValueError(f"Expected policy output (B, T, D), got shape {act_np.shape}") if act_np.shape[0] > b: act_np = act_np[:b] elif act_np.shape[0] < b: raise ValueError(f"Policy returned batch dim {act_np.shape[0]} < input batch {b}") + _, t_steps, policy_adim = act_np.shape outs: List[np.ndarray] = [] - for row in range(act_np.shape[0]): - step0 = act_np[row].ravel() - policy_adim = step0.shape[0] - out = np.zeros(action_dim, dtype=np.float64) + for row in range(b): + out = np.zeros((t_steps, action_dim), dtype=np.float64) n = min(action_dim, policy_adim) - out[:n] = step0[:n].astype(np.float64) + out[:, :n] = act_np[row, :, :n].astype(np.float64) outs.append(out) return outs -def make_handler( - action_dim: int, - policy: PolicyFn, - rng: np.random.Generator, - opentau_runner: Optional[OpenTauRoboCasaPolicy] = None, -): - """Build the asyncio WebSocket handler for single and batched policy requests. - - Args: - action_dim: Expected flat action size for validation and zero-fill on errors. - policy: Callable used when ``opentau_runner`` is None (stub or custom). - rng: NumPy generator passed to ``policy`` for stochastic stubs. - opentau_runner: If set, batch paths use ``OpenTauRoboCasaPolicy.infer_batch`` - for efficiency; otherwise batch is looped with ``policy``. - - Returns: - An async function suitable for ``websockets.serve`` that reads MessagePack - frames and sends MessagePack-encoded actions. 
- """ - +def make_handler(action_dim: int, runner: OpenTauRoboCasaPolicy): async def _handler(websocket: Any): - """Handle one WebSocket connection: MessagePack in, MessagePack actions out.""" - async for message in websocket: try: data = msgpack.unpackb(message, raw=False) @@ -601,25 +379,12 @@ async def _handler(websocket: Any): images_rgb = decode_all_images(images_jpeg) decoded.append((images_rgb, state, prompt)) - if opentau_runner is not None: - actions_out = opentau_runner.infer_batch(decoded, action_dim) - else: - actions_out = [] - for images_rgb, state, prompt in decoded: - action = policy(images_rgb, state, prompt, action_dim, rng) - if action.shape[0] != action_dim: - raise ValueError( - f"Policy returned shape {action.shape}, expected ({action_dim},)" - ) - actions_out.append(action) - + actions_out = runner.infer_batch(decoded, action_dim) await websocket.send(pack_actions_batch(actions_out)) else: images_jpeg, state, prompt = unpack_payload_dict(data) images_rgb = decode_all_images(images_jpeg) - action = policy(images_rgb, state, prompt, action_dim, rng) - if action.shape[0] != action_dim: - raise ValueError(f"Policy returned shape {action.shape}, expected ({action_dim},)") + action = runner.infer(images_rgb, state, prompt, action_dim) await websocket.send(pack_action(action)) except Exception as e: logger.exception("Policy step failed: %s", e) @@ -630,71 +395,24 @@ async def _handler(websocket: Any): if isinstance(data, dict) and data.get("batch") is True: items = data.get("items") if isinstance(data.get("items"), list) else [] n = len(items) - zeros = [np.zeros(action_dim, dtype=np.float64).tolist() for _ in range(n)] - await websocket.send(msgpack.packb(zeros, use_bin_type=True)) + zero_chunk = [[0.0] * action_dim for _ in range(_ERROR_RESPONSE_CHUNK_STEPS)] + await websocket.send(msgpack.packb([zero_chunk for _ in range(n)], use_bin_type=True)) else: - await websocket.send(pack_action(np.zeros(action_dim, dtype=np.float64))) + await 
websocket.send( + pack_action(np.zeros((_ERROR_RESPONSE_CHUNK_STEPS, action_dim), dtype=np.float64)) + ) return _handler -def make_opentau_handler(runner: OpenTauRoboCasaPolicy) -> PolicyFn: - """Adapt ``OpenTauRoboCasaPolicy`` to the generic ``PolicyFn`` signature. - - The returned callable ignores ``rng`` (deterministic inference). - - Args: - runner: Loaded policy wrapper. - - Returns: - A ``PolicyFn`` that forwards to ``OpenTauRoboCasaPolicy.infer``. - """ - - def _policy( - images_rgb: Dict[str, np.ndarray], - state: np.ndarray, - prompt: str, - adim: int, - rng: np.random.Generator, - ) -> np.ndarray: - """Single-env policy shim; ``rng`` is unused.""" - - del rng - return runner.infer(images_rgb, state, prompt, adim) - - return _policy - - async def run_server( host: str, port: int, action_dim: int, - policy: Optional[PolicyFn] = None, - seed: int = 0, - opentau_runner: Optional[OpenTauRoboCasaPolicy] = None, + runner: OpenTauRoboCasaPolicy, ) -> None: - """Start the WebSocket server and block until the process is interrupted. - - Args: - host: Bind address (e.g. ``0.0.0.0``). - port: TCP port. - action_dim: Flat action dimension for validation and error fallbacks. - policy: Policy callable for non-OpenTau or stub mode; defaults to - ``policy_forward``. - seed: Seed for the numpy RNG used by stub / default policy. - opentau_runner: When non-None, batched requests use ``infer_batch`` on this - object; single requests still go through ``policy`` (typically from - ``make_opentau_handler``). - - Note: - Uses ``ping_timeout=None`` and ``max_size=None`` for large payloads and - long inference times. Runs until cancelled (infinite ``asyncio.Future``). 
- """ - rng = np.random.default_rng(seed) - pol: PolicyFn = policy if policy is not None else policy_forward - async with websockets.serve( - make_handler(action_dim, pol, rng, opentau_runner=opentau_runner), + make_handler(action_dim, runner), host, port, max_size=None, @@ -709,55 +427,32 @@ async def run_server( @parser.wrap() def robocasa_async_main(cfg: TrainPipelineConfig) -> None: - """CLI entry: parse config, optionally load OpenTau policy, and run ``run_server``. - - Honors module globals set by ``_parse_robocasa_cli`` (host, port, action - dimension, stub vs OpenTau, torch compile). When ``ROBOCASA_USE_STUB`` is True, - uses ``policy_forward``; otherwise builds ``OpenTauRoboCasaPolicy`` and - a handler from ``make_opentau_handler``. - - Args: - cfg: Parsed ``TrainPipelineConfig`` from OpenTau's argparse (includes - ``policy``, ``seed``, etc.). - """ + """Start the RoboCasa WebSocket policy server (single + batched) using OpenTau config parsing.""" logging.basicConfig(level=logging.INFO) logging.info( - "%s\nRoboCasa globals: host=%s port=%s action_dim=%s torch_compile=%s use_stub=%s", + "%s\nRoboCasa globals: host=%s port=%s action_dim=%s torch_compile=%s", pformat(asdict(cfg)), ROBOCASA_HOST, ROBOCASA_PORT, ROBOCASA_ACTION_DIM, ROBOCASA_TORCH_COMPILE, - ROBOCASA_USE_STUB, ) if cfg.seed is not None: set_seed(cfg.seed) - seed = int(cfg.seed) if cfg.seed is not None else 0 - - policy_fn: PolicyFn - runner: Optional[OpenTauRoboCasaPolicy] = None - if ROBOCASA_USE_STUB: - # policy to output random actions - policy_fn = policy_forward - else: - # initialize runner with loads model from config - runner = OpenTauRoboCasaPolicy( - cfg, - compile_model=ROBOCASA_TORCH_COMPILE, - seed=cfg.seed, - ) - policy_fn = make_opentau_handler(runner) + runner = OpenTauRoboCasaPolicy( + cfg, + compile_model=ROBOCASA_TORCH_COMPILE, + seed=cfg.seed, + ) asyncio.run( run_server( host=ROBOCASA_HOST, port=ROBOCASA_PORT, action_dim=ROBOCASA_ACTION_DIM, - policy=policy_fn, - 
seed=seed, - opentau_runner=runner, + runner=runner, ) ) From 045c4bf0035dedbcb6a45053553767f0767811db Mon Sep 17 00:00:00 2001 From: akshay18iitg Date: Thu, 26 Mar 2026 13:29:18 -0700 Subject: [PATCH 5/6] Removing client script and adding gist link to it --- docs/source/tutorials/robocasa.rst | 145 +++---- src/opentau/scripts/robocasa/client.py | 577 ------------------------- 2 files changed, 65 insertions(+), 657 deletions(-) delete mode 100644 src/opentau/scripts/robocasa/client.py diff --git a/docs/source/tutorials/robocasa.rst b/docs/source/tutorials/robocasa.rst index 226e84f5..3d87311d 100644 --- a/docs/source/tutorials/robocasa.rst +++ b/docs/source/tutorials/robocasa.rst @@ -1,27 +1,34 @@ .. _robocasa: +.. _robocasa_client_gist: https://gist.github.com/akshay18iitg/4d299c135c2d384ceb9a283b745baa01 + RoboCasa setup and rollout client ================================= -This page explains how to set up **RoboCasa** (kitchen simulation) alongside **OpenTau**, run the **policy WebSocket server** that serves an OpenTau checkpoint, and run the **batched rollout client** that steps parallel MuJoCo environments and queries the policy in batches. +This page explains how to set up **RoboCasa** (kitchen simulation) alongside **OpenTau**, run the **policy WebSocket server** that serves an OpenTau checkpoint, and run the **rollout client** against that server. + +The rollout client code **is not shipped in the OpenTau repository**. Use the reference implementation in `robocasa_client_gist`_ (RoboCasa policy client: ``client`` and ``client_async``). .. note:: - Complete the base :doc:`/installation` steps first. RoboCasa itself is installed **outside** the OpenTau package; OpenTau provides the client and server glue only. + Complete the base :doc:`/installation` steps first. RoboCasa itself is installed **outside** the OpenTau package. OpenTau provides the **policy server**; you run the **client** inside your RoboCasa install (files from the gist, or equivalent). 
Overview -------- -The workflow is split across machines or terminals: +The workflow is usually split across machines or terminals: + +1. **OpenTau host** — runs the WebSocket policy server, loads ``policy.pretrained_path`` from a training config, and returns **action chunks** via MessagePack. +2. **RoboCasa host** — runs the kitchen sim, JPEG-encodes cameras, and talks to the server. Parallel rollouts use a threaded **async** client that **batches** observations for workers that need a new chunk. -1. **Simulation + client** — RoboCasa environments, JPEG encoding, and episode logging run where you have the ``robocasa`` Python package and assets (often the same machine as the server for local testing). -2. **Policy server** — Loads ``policy.pretrained_path`` from an OpenTau training config and answers WebSocket requests with MessagePack-encoded actions. +**In this repo** -OpenTau ships: +* ``opentau.scripts.robocasa.server`` — WebSocket server (single-observation or batched requests; replies are **action chunks** per request row). -* ``opentau.scripts.robocasa.server`` — async WebSocket server (single-observation and **batched** requests). -* ``opentau.scripts.robocasa.client`` — threaded client that batches observations from multiple env workers per timestep. +**Outside this repo** -Dependencies used by these modules (``websockets``, ``msgpack``) are declared in OpenTau’s ``pyproject.toml``. The server also needs **OpenCV** for JPEG decode on the policy side (``opencv-python`` or ``opencv-python-headless`` is already a core OpenTau dependency). +* ``robocasa.scripts.client`` / ``robocasa.scripts.client_async`` — reference rollout scripts from `robocasa_client_gist`_ (place them under your ``robocasa`` package tree or run them as you prefer). + +Server dependencies (``websockets``, ``msgpack``) are in OpenTau’s ``pyproject.toml``. The server needs **OpenCV** (``cv2``) to decode JPEG camera inputs. 
Prerequisites @@ -34,25 +41,37 @@ Prerequisites **Python** -* OpenTau currently targets **Python 3.10** (see ``requires-python`` in the repo root ``pyproject.toml``). Use the same interpreter for OpenTau and for the environment where you install RoboCasa, or ensure compatibility between the two stacks. +* OpenTau targets **Python 3.10** (see ``requires-python`` in the repo root ``pyproject.toml``). Match or reconcile Python versions with your RoboCasa environment. **RoboCasa simulation** -RoboCasa is not installed by ``pip install opentau``. Install the simulator and assets from the **upstream project**: +RoboCasa is not fully installed by ``pip install opentau``. Install the simulator and assets from upstream: * `RoboCasa installation `_ -Typical steps include installing ``robosuite`` (often from source), then ``robocasa`` in editable mode, then running asset download scripts (kitchen assets can be large). Always refer to the official docs for the version you use. - **OpenTau** -Install OpenTau from source or PyPI as in :doc:`/installation`. Ensure your environment has the packages required by the scripts above (sync with ``uv sync`` or ``pip install -e .`` from the repo). +Install OpenTau as in :doc:`/installation` (e.g. ``uv sync`` or ``pip install -e .``). Policy server (OpenTau) ----------------------- -The server listens on a WebSocket port and speaks MessagePack. It accepts either a **single** observation dict or a **batch** payload ``{ "batch": true, "items": [ ... ] }`` for parallel clients. +The server listens on WebSocket and uses **MessagePack** for request and response bodies. + +**Inference** + +* Each successful call uses ``policy.sample_actions`` (not ``select_action``): the model predicts a **temporal chunk** of actions. The last dimension is trimmed or zero-padded to ``--robocasa_action_dim``. + +**Requests** + +* **Single observation:** top-level dict with ``images`` (JPEG bytes per camera name), ``state`` (list of floats), ``prompt`` (string). 
+* **Batch:** ``{ "batch": true, "items": [ { ... same fields ... }, ... ] }``. + +**Responses** + +* **Single:** one chunk as nested lists: ``[[float, ...], ...]`` — shape ``(T, action_dim)`` with ``T`` equal to the policy’s predicted horizon (e.g. ``n_action_steps``). +* **Batch:** ``[ chunk_0, chunk_1, ... ]`` — one chunk per ``items`` row, same order. **Entry point** @@ -74,13 +93,11 @@ The server listens on a WebSocket port and speaks MessagePack. It accepts either * - ``--robocasa_port`` - TCP port (default ``8765``). * - ``--robocasa_action_dim`` - - Flat action size passed to the policy and validation (default ``16``; align with your RoboCasa / training setup). + - Flat action width for reply padding/trimming (default ``16``; align with RoboCasa env and training). * - ``--robocasa_torch_compile`` - ``true`` / ``false`` — whether to compile ``sample_actions`` when supported (default ``true``). - * - ``--robocasa_use_stub`` - - ``true`` to use a small random policy instead of loading ``policy.pretrained_path`` (useful for wiring tests without weights). -**Example** with explicit host and port: +**Example** .. code-block:: bash @@ -90,88 +107,56 @@ The server listens on a WebSocket port and speaks MessagePack. It accepts either --robocasa_action_dim 16 \ --config_path /path/to/train_config.json -The training config must define ``policy.pretrained_path`` and compatible policy settings unless you use ``--robocasa_use_stub=true``. - +The training config must define ``policy.pretrained_path`` and settings compatible with your checkpoint. -Rollout client (RoboCasa + OpenTau) ------------------------------------ -Copy the code from `opentau.scripts.robocasa.client` to `robocasa.scripts.client` and modify the code to fit the needs. Run the following command in robocasa environment: +Rollout client (RoboCasa environment) +------------------------------------- -Add this function in `robocasa.utils.env_utils`: +Get the client sources from `robocasa_client_gist`_. -.. 
code-block:: python +Typical layout after copying into a RoboCasa checkout: - def convert_action_pi05(action): - """ - Converts input action (np.array) to format expected by gym env (dict) - """ - action = action.copy() - output_action = { - "action.end_effector_position": action[5:8], - "action.end_effector_rotation": action[8:11], - "action.gripper_close": action[11:12], - "action.base_motion": action[0:4], - "action.control_mode": action[4:5], - } - return np.concatenate([v for k,v in output_action.items()], axis=-1) +* ``robocasa/scripts/client.py`` — single-env style client (if provided in the gist). +* ``robocasa/scripts/client_async.py`` — threaded client that **batches** observations for workers that need a **new action chunk**, sends one WebSocket message per batch, receives one chunk per batch row, then **steps the simulator for every action in each chunk** before querying the server again. -Run the client **after** the server is listening. It registers a RoboCasa task name, spawns one thread per parallel worker (up to ``--num-parallel``), batches observations for each timestep, and writes ``rollouts.json`` plus optional per-camera videos. +If your PandaOmron-style env expects actions in a particular layout, the gist may include a ``convert_action_pi05`` helper (or equivalent); wire it to match ``create_env`` / your task. -**Entry point** +**Example (async / batched client)** .. code-block:: bash - python -m opentau.scripts.robocasa.client ENV_NAME \ + python -m robocasa.scripts.client_async ENV_NAME \ --host localhost \ --port 8765 -Replace ``ENV_NAME`` with a registered RoboCasa kitchen task class name (same as other RoboCasa tooling). +Replace ``ENV_NAME`` with a registered RoboCasa kitchen task. Common options (see the gist for the exact CLI): -**Useful options** +* ``--num-rollouts`` — total episodes. +* ``--num-parallel`` — parallel env threads (batch size is at most the count of workers requesting a chunk at once). 
+* ``--seed``, ``--split``, ``--output-dir``, ``--max-episode-steps``, ``--render``, ``--jpeg-quality``. -.. list-table:: - :header-rows: 1 - :widths: 28 72 +**Environment variables** (if supported by the gist client) - * - Option - - Meaning - * - ``--num-rollouts`` - - Total episodes (default ``1``). - * - ``--num-parallel`` - - Parallel env threads (capped by ``--num-rollouts``); batch size per step is at most this value. - * - ``--seed`` - - Base seed; rollout ``i`` uses ``seed + i``. - * - ``--split`` - - Dataset split for ``create_env`` (``all``, ``pretrain``, or ``target``). - * - ``--output-dir`` - - Root for ``rollouts.json`` and ``rollout_*_seed_*`` video folders (default: auto-generated under cwd). - * - ``--max-episode-steps`` - - Step cap per episode (default ``1500``). - * - ``--render`` - - On-screen rendering; disables saved videos. - -**Environment variables** - -* ``ROBOCASA_POLICY_HOST`` — default for ``--host`` (default ``localhost``). -* ``ROBOCASA_POLICY_PORT`` — default for ``--port`` (default ``8765``). - - -Protocol and outputs (short) +* ``ROBOCASA_POLICY_HOST`` — default host. +* ``ROBOCASA_POLICY_PORT`` — default port. + + +Protocol and outputs (summary) ------------------------------ -* **Transport:** WebSocket binary frames, MessagePack payloads. -* **Client → server (batched):** ``{ "batch": true, "items": [ { "images": { camera_name: jpeg_bytes, ... }, "state": [...], "prompt": "..." }, ... ] }``. -* **Server → client:** A list of flat action lists, one per item, same order as ``items``. -* **Client output:** A directory containing ``rollouts.json`` (summary and per-rollout ``seed``, ``length``, ``success``) and, when not using ``--render``, MP4 files per camera under ``rollout_*`` subfolders. +* **Transport:** WebSocket binary frames, MessagePack. +* **Client → server (batch):** ``{ "batch": true, "items": [ { "images": {...}, "state": [...], "prompt": "..." }, ... ] }``. 
+* **Server → client (batch):** list of action chunks; each chunk is ``(T, action_dim)`` as nested lists. +* **Rollout output:** directory with ``rollouts.json`` and, when not rendering on screen, per-rollout MP4s per camera (behavior as implemented in the gist). -For full behavioral details (variable batch size as workers finish, JPEG quality, ``ping_timeout``), see the module docstrings in ``src/opentau/scripts/robocasa/client.py`` and ``src/opentau/scripts/robocasa/server.py``. +For server implementation details, see ``src/opentau/scripts/robocasa/server.py``. For client behavior and options, see `robocasa_client_gist`_. Troubleshooting --------------- -* **Import errors for ``robocasa``** — Install and register RoboCasa per upstream docs; the client imports ``robocasa`` and ``robocasa.utils.env_utils``. -* **Server fails on JPEG decode** — Install OpenCV for Python on the server host (``cv2``); without it, JPEG decoding raises at runtime. -* **Port already in use** — Change ``--robocasa_port`` / ``--port`` or stop the conflicting process. -* **Action dimension mismatches** — Align ``--robocasa_action_dim`` with the policy and environment (e.g. PandaOmron / ``convert_action_pi05`` expectations in the client). +* **Import errors for ``robocasa``** — Install RoboCasa per upstream docs; run the client from that environment. +* **Server JPEG decode errors** — Install OpenCV for Python on the server (``cv2``). +* **Port in use** — Change ``--robocasa_port`` / client ``--port``. +* **Action shape / chunk mismatch** — Align ``--robocasa_action_dim`` with training and env; ensure the client consumes **chunks** (multiple steps per server reply) if you use chunking inference. diff --git a/src/opentau/scripts/robocasa/client.py b/src/opentau/scripts/robocasa/client.py deleted file mode 100644 index d8366545..00000000 --- a/src/opentau/scripts/robocasa/client.py +++ /dev/null @@ -1,577 +0,0 @@ -# Copyright 2026 Tensor Auto Inc. All rights reserved. 
-# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -Threaded batched remote policy client for RoboCasa. - -Runs **n_parallel** environment threads. Each thread pulls rollouts from a shared queue -until **num_rollouts** episodes are finished. The **main** asyncio loop receives -observations from active workers, batches them into one WebSocket message, and routes -returned action chunks back to the corresponding threads. - -Batch protocol (MessagePack over WebSocket, binary frames) matches ``client.py`` / -``robocasa.scripts.server``: - - Client -> server: { - "batch": true, - "items": [ - { "images": { camera_name: bytes (JPEG), ... }, "state": list[float], "prompt": str }, - ... - ], - } - - Server -> client: list[list[list[float]]] # one action chunk per item, same order as ``items`` - -The number of ``items`` (and thus the batch size) is **only** the count of workers -that need a new chunk right now. As workers finish their rollout queue and exit, batch -size shrinks from at most ``num_parallel`` down to 1 for the final active worker(s). -The policy server must return exactly ``len(items)`` actions, not a fixed width of -``num_parallel``. - -Rollout records and ``rollouts.json`` match ``client.py`` (``env_name``, ``seed``, -``length``, ``success`` per rollout; summary includes ``num_rollouts``, -``num_parallel_envs``, ``output_directory``). - -Requires ``websockets``, ``msgpack``, ``opencv-python`` (``cv2``). 
The WebSocket client -sets ``ping_timeout=None`` so MuJoCo stepping and JPEG encoding do not trip keepalive. -""" - -from __future__ import annotations - -import argparse -import asyncio -import json -import os -import queue -import threading -import warnings -from dataclasses import dataclass -from datetime import datetime -from typing import Any, Union - -import imageio -import msgpack -import numpy as np -import websockets - -import robocasa # noqa: F401 -from robocasa.scripts.client import ( - DEFAULT_CAMERA_HEIGHT, - DEFAULT_CAMERA_NAMES, - DEFAULT_CAMERA_WIDTH, - build_proprio_vector, - encode_all_cameras_jpeg, - flip_image_obs, - get_task_prompt, -) -from robocasa.utils.env_utils import convert_action_pi05, create_env - - -@dataclass -class ObsMsg: - """Worker needs a policy action for this observation (packed client payload).""" - - payload: dict[str, Any] - - -@dataclass -class DoneMsg: - """Episode finished; no server call for this message.""" - - rollout_idx: int - length: int - success: bool - - -@dataclass -class ExitMsg: - """Worker thread has no more rollouts and is exiting.""" - - -WorkerToMain = Union[ObsMsg, DoneMsg, ExitMsg] - -_SERVER_TRUNCATED_ACTION_BATCH_WARNED = False - - -def _normalize_batched_actions_response( - actions_batch: Any, - num_expected: int, -) -> list[Any]: - """ - Ensure ``actions_batch`` is a list of length ``num_expected``, one action chunk per batch row. - - When ``num_expected == 1``, some servers may return one chunk directly - (``list[list[float]]``) or one flat action (``list[float]``) instead of - ``[chunk]``; wrap those cases. - - When the server returns *more* rows than ``num_expected`` (e.g. fixed max batch - width while the client sends a partial batch), the excess rows are dropped. 
- """ - global _SERVER_TRUNCATED_ACTION_BATCH_WARNED - if not isinstance(actions_batch, list): - raise ValueError(f"Batched server response must be a list, got {type(actions_batch).__name__}") - if len(actions_batch) == num_expected: - return actions_batch - if num_expected == 1 and len(actions_batch) > 0: - first = actions_batch[0] - if isinstance(first, (int, float, np.floating, np.integer, list, tuple, np.ndarray)): - return [actions_batch] - if len(actions_batch) > num_expected: - if not _SERVER_TRUNCATED_ACTION_BATCH_WARNED: - warnings.warn( - f"Policy server returned {len(actions_batch)} actions for a batch of " - f"{num_expected}; using the first {num_expected}. Prefer fixing the server " - f"to return exactly len(items) actions.", - UserWarning, - stacklevel=2, - ) - _SERVER_TRUNCATED_ACTION_BATCH_WARNED = True - return actions_batch[:num_expected] - raise ValueError( - f"Batched actions length {len(actions_batch)} != batch size {num_expected} " - f"(partial batches must still return one action list per observation)" - ) - - -def _normalize_action_chunk_for_worker(raw_action_chunk: Any) -> list[np.ndarray]: - """Convert one server row into a list of flat action vectors.""" - arr = np.asarray(raw_action_chunk, dtype=np.float64) - if arr.ndim == 1: - return [arr.ravel()] - if arr.ndim != 2: - raise ValueError(f"Expected action chunk rank 1 or 2, got shape {arr.shape}") - return [arr[i].ravel() for i in range(arr.shape[0])] - - -def _worker_loop( - *, - rollout_queue: queue.Queue[int | None], - to_main: queue.Queue[WorkerToMain], - from_main: queue.Queue[Any], - env_name: str, - split, - start_seed: int, - main_dir: str, - jpeg_quality: int, - max_episode_steps: int | None, - render: bool, - action_dim_holder: list[int | None], - action_dim_lock: threading.Lock, -) -> None: - """One thread: sequential rollouts from ``rollout_queue`` until empty.""" - while True: - try: - # get the next rollout index from the queue and its protected by a lock - rollout_idx = 
rollout_queue.get_nowait() - except queue.Empty: - to_main.put(ExitMsg()) - return - - seed = start_seed + rollout_idx - if not render: - # create the video subdirectory for the rollout - sub = os.path.join(main_dir, f"rollout_{rollout_idx:04d}_seed_{seed}") - os.makedirs(sub, exist_ok=True) - video_writers: dict[str, Any] | None = {} - for cam in DEFAULT_CAMERA_NAMES: - path = os.path.join(sub, f"{cam}.mp4") - video_writers[cam] = imageio.get_writer(path, fps=20) - else: - video_writers = None - - # create the environment - env = create_env( - env_name, - split=split, - seed=seed, - render_onscreen=render, - camera_names=list(DEFAULT_CAMERA_NAMES), - camera_widths=DEFAULT_CAMERA_WIDTH, - camera_heights=DEFAULT_CAMERA_HEIGHT, - has_offscreen_renderer=not render, - use_camera_obs=not render, - ) - try: - # reset the environment and get the initial observation and action dimension - obs = env.reset() - # flip the image observations as mujoco returns flipped images - obs = flip_image_obs(obs, DEFAULT_CAMERA_NAMES) - # get the action dimension and store it in the action dimension holder - with action_dim_lock: - if action_dim_holder[0] is None: - ad = env.action_dim - if ad is None: - raise RuntimeError("env.action_dim is None after reset()") - action_dim_holder[0] = ad - step_count = 0 - pending_actions: list[np.ndarray] = [] - - while True: - if render: - images: dict[str, Any] = {} - else: - # encode the image observations as JPEG - images = encode_all_cameras_jpeg(obs, DEFAULT_CAMERA_NAMES, jpeg_quality=jpeg_quality) - # write the image observations to the video writers - if video_writers is not None: - for cam in DEFAULT_CAMERA_NAMES: - cam_key = f"{cam}_image" - if cam_key in obs: - video_writers[cam].append_data(obs[cam_key]) - - # build the state vector in desried order and get the task prompt - state = build_proprio_vector(obs).tolist() - prompt = get_task_prompt(env) - payload_obs = {"images": images, "state": state, "prompt": prompt} - - # request a new 
policy chunk only when local chunk is exhausted - if len(pending_actions) == 0: - # send the payload to the main thread - to_main.put(ObsMsg(payload=payload_obs)) - raw_action_chunk = from_main.get() - pending_actions = _normalize_action_chunk_for_worker(raw_action_chunk) - if len(pending_actions) == 0: - raise ValueError("Server returned an empty action chunk") - - # take one action from the local chunk - action = pending_actions.pop(0) - # build action vector in desired order - action = convert_action_pi05(action) - - # check if the action dimension is correct - ad = action_dim_holder[0] - assert ad is not None - if action.shape[0] != ad: - raise ValueError(f"Policy returned action dim {action.shape[0]}, expected {ad}") - - # step the environment and get the new observation - obs, _r, _d, _i = env.step(action) - # flip the image observations as mujoco returns flipped images - obs = flip_image_obs(obs, DEFAULT_CAMERA_NAMES) - step_count += 1 - - # check if the episode is over - episode_over = bool(env._check_success()) or ( - max_episode_steps is not None and step_count >= max_episode_steps - ) - if episode_over: - success = bool(env._check_success()) - to_main.put( - DoneMsg( - rollout_idx=rollout_idx, - length=step_count, - success=success, - ) - ) - break - finally: - if video_writers is not None: - for w in video_writers.values(): - w.close() - env.close() - - -async def _run_coordinator( - *, - ws_uri: str, - n_workers: int, - to_mains: list[queue.Queue[WorkerToMain]], - from_mains: list[queue.Queue[Any]], - results_by_rollout: dict[int, tuple[int, bool]], - results_lock: threading.Lock, -) -> None: - """ - For each timestep, read from all active workers in parallel until each has produced - one ``ObsMsg`` (skipping ``DoneMsg``) or ``ExitMsg``. This avoids deadlock when one - worker finishes an episode and is slow to start the next rollout while others already - have the next observation ready. 
- """ - loop = asyncio.get_event_loop() - - def _get(q: queue.Queue[WorkerToMain]) -> WorkerToMain: - return q.get() - - async def _drain_to_obs_or_exit(wid: int) -> tuple[int, ObsMsg | None, bool]: - """Returns (worker_id, ObsMsg or None if exiting, is_exit).""" - while True: - msg = await loop.run_in_executor(None, _get, to_mains[wid]) - if isinstance(msg, ExitMsg): - return (wid, None, True) - if isinstance(msg, DoneMsg): - with results_lock: - results_by_rollout[msg.rollout_idx] = (msg.length, msg.success) - continue - if isinstance(msg, ObsMsg): - return (wid, msg, False) - raise TypeError(f"Unexpected message: {type(msg)}") - - async with websockets.connect( - ws_uri, - max_size=None, - ping_timeout=None, - ) as websocket: - active: set[int] = set(range(n_workers)) - - while active: - wids = sorted(active) - # gather the observations from the active workers - gathered = await asyncio.gather(*[_drain_to_obs_or_exit(wid) for wid in wids]) - - for wid, _obs, is_exit in gathered: - if is_exit: - # remove the finished worker from the active set - active.discard(wid) - - batch_pairs: list[tuple[int, ObsMsg]] = [ - (wid, om) for (wid, om, ex) in gathered if not ex and om is not None - ] - batch_pairs.sort(key=lambda x: x[0]) - - if not batch_pairs: - if not active: - break - raise RuntimeError("internal error: no observations to batch but workers are still active") - - batch_items = [om.payload for _wid, om in batch_pairs] - batch_workers = [wid for wid, _om in batch_pairs] - batch_size = len(batch_items) - # batch_size is often < n_workers as workers finish rollouts and exit. 
- - batch_payload = {"batch": True, "items": batch_items} - await websocket.send(msgpack.packb(batch_payload, use_bin_type=True)) - raw = await websocket.recv() - actions_batch = msgpack.unpackb(raw, raw=False) - actions_batch = _normalize_batched_actions_response(actions_batch, batch_size) - - for wid, act in zip(batch_workers, actions_batch, strict=False): - from_mains[wid].put(act) - - -async def run_policy_loop_threaded( - *, - ws_uri: str, - env_name: str, - split, - start_seed: int, - num_rollouts: int, - num_parallel: int, - output_dir: str | None, - jpeg_quality: int, - max_episode_steps: int | None, - render: bool = False, -) -> None: - if num_rollouts < 1: - raise ValueError("num_rollouts must be >= 1") - if num_parallel < 1: - raise ValueError("num_parallel must be >= 1") - - # number of threads to be created - n_workers = min(num_parallel, num_rollouts) - - timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") - main_dir = output_dir or f"{env_name}_async_{timestamp}" - os.makedirs(main_dir, exist_ok=True) - - print( - f"Output directory: {main_dir!r} — {num_rollouts} rollout(s), " - f"{n_workers} parallel worker thread(s), seeds {start_seed}..{start_seed + num_rollouts - 1}" - ) - - # queue to store the rollout indices - rollout_queue: queue.Queue[int | None] = queue.Queue() - for i in range(num_rollouts): - rollout_queue.put(i) - - # queues to send messages from the coordinator to the workers and from the workers to the coordinator - to_mains: list[queue.Queue[WorkerToMain]] = [queue.Queue() for _ in range(n_workers)] - from_mains: list[queue.Queue[Any]] = [queue.Queue() for _ in range(n_workers)] - - # dictionary to store the results by rollout index - results_by_rollout: dict[int, tuple[int, bool]] = {} - # lock to synchronize access to the results dictionary - results_lock = threading.Lock() - # holder for the action dimension - action_dim_holder: list[int | None] = [None] - # lock to synchronize access to the action dimension - action_dim_lock = 
threading.Lock() - # list to store the threads - threads: list[threading.Thread] = [] - for wid in range(n_workers): - t = threading.Thread( - target=_worker_loop, - kwargs={ - "rollout_queue": rollout_queue, - "to_main": to_mains[wid], - "from_main": from_mains[wid], - "env_name": env_name, - "split": split, - "start_seed": start_seed, - "main_dir": main_dir, - "jpeg_quality": jpeg_quality, - "max_episode_steps": max_episode_steps, - "render": render, - "action_dim_holder": action_dim_holder, - "action_dim_lock": action_dim_lock, - }, - name=f"robocasa-env-{wid}", - daemon=True, - ) - threads.append(t) - t.start() - - await _run_coordinator( - ws_uri=ws_uri, - n_workers=n_workers, - to_mains=to_mains, - from_mains=from_mains, - results_by_rollout=results_by_rollout, - results_lock=results_lock, - ) - - for t in threads: - t.join(timeout=600.0) - if t.is_alive(): - raise RuntimeError(f"Worker thread {t.name!r} did not exit in time") - - ad = action_dim_holder[0] - if ad is not None: - print( - f"RoboCasa env={env_name!r} split={split!r} action_dim={ad} " - f"cameras={list(DEFAULT_CAMERA_NAMES)} " - f"({DEFAULT_CAMERA_WIDTH}x{DEFAULT_CAMERA_HEIGHT})" - ) - - rollout_records: list[dict[str, Any]] = [] - for ridx in range(num_rollouts): - if ridx not in results_by_rollout: - raise RuntimeError(f"Missing result for rollout index {ridx}") - length, success = results_by_rollout[ridx] - seed = start_seed + ridx - rollout_records.append( - { - "env_name": env_name, - "seed": seed, - "length": length, - "success": success, - } - ) - print(f"Rollout {ridx + 1}/{num_rollouts} seed={seed} length={length} success={success}") - - summary_path = os.path.join(main_dir, "rollouts.json") - summary = { - "env_name": env_name, - "start_seed": start_seed, - "num_rollouts": num_rollouts, - "num_parallel_envs": n_workers, - "output_directory": os.path.abspath(main_dir), - "rollouts": rollout_records, - } - with open(summary_path, "w", encoding="utf-8") as f: - json.dump(summary, f, 
indent=2) - print(f"Wrote {summary_path!r}") - - -def parse_args(argv: list[str] | None = None) -> argparse.Namespace: - p = argparse.ArgumentParser(description=__doc__) - p.add_argument( - "env_name", - metavar="ENV_NAME", - help="RoboCasa kitchen task (registered class name), same as client.py", - ) - p.add_argument( - "--host", - default=os.environ.get("ROBOCASA_POLICY_HOST", "localhost"), - help=( - "Policy server hostname or IP (default: localhost). " - "Use a real host — not the literal word HOST from examples." - ), - ) - p.add_argument( - "--port", - type=int, - default=int(os.environ.get("ROBOCASA_POLICY_PORT", "8765")), - help="Policy server port (or set ROBOCASA_POLICY_PORT)", - ) - p.add_argument( - "--split", - default="all", - choices=[None, "all", "pretrain", "target"], - help="Dataset split passed to create_env (default: all)", - ) - p.add_argument( - "--seed", - type=int, - default=0, - help="Seed for rollout index 0; rollout i uses seed + i", - ) - p.add_argument( - "--num-rollouts", - type=int, - default=1, - help="Total number of episodes (rollouts) to run", - ) - p.add_argument( - "--num-parallel", - type=int, - default=1, - help="Number of parallel environment threads (capped at num-rollouts)", - ) - p.add_argument( - "--output-dir", - type=str, - default=None, - help="Directory for rollouts.json and per-rollout video subfolders", - ) - p.add_argument("--jpeg-quality", type=int, default=80, help="JPEG quality 0-100") - p.add_argument( - "--max-episode-steps", - type=int, - default=1500, - help="Cap steps per episode (in addition to env success)", - ) - p.add_argument("--render", action="store_true", help="Render onscreen (no videos)") - return p.parse_args(argv) - - -def main(argv=None) -> None: - args = parse_args(argv) - if args.num_rollouts < 1: - raise SystemExit("error: --num-rollouts must be >= 1") - if args.num_parallel < 1: - raise SystemExit("error: --num-parallel must be >= 1") - host = args.host.strip() - if host.lower() == "host": 
- raise SystemExit( - "error: --host must be a real hostname or IP (e.g. localhost or 127.0.0.1), " - "not the placeholder HOST." - ) - uri = f"ws://{host}:{args.port}" - asyncio.run( - run_policy_loop_threaded( - ws_uri=uri, - env_name=args.env_name, - split=args.split, - start_seed=args.seed, - num_rollouts=args.num_rollouts, - num_parallel=args.num_parallel, - output_dir=args.output_dir, - jpeg_quality=args.jpeg_quality, - max_episode_steps=args.max_episode_steps, - render=args.render, - ) - ) - - -if __name__ == "__main__": - main() From 31c6e48dfbf5a22052bc1743b69b1b33765c44cd Mon Sep 17 00:00:00 2001 From: akshay18iitg Date: Thu, 26 Mar 2026 13:36:37 -0700 Subject: [PATCH 6/6] Updating Readme --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index fd651944..80b3872f 100644 --- a/README.md +++ b/README.md @@ -67,7 +67,7 @@ We provide fully functioning $\pi_{0.5}$ checkpoints trained with high success r | [TensorAuto/Robocasa_navigatekitchen][12] | A $\pi_{0.5}$ model checkpoint trained on Navigate to Kitchen objects task on Robocasa. | 97% | | [TensorAuto/Robocasa_Closeupdown][11] | A $\pi_{0.5}$ model checkpoint trained on Close Oven, Close Toaster and Close Dishwasher on Robocasa. | Close Oven : 90%
Close Toaster : 70%
Close Dishwasher : 90% | | [TensorAuto/TensorAuto/robocasa_Closesideways][10]| A $\pi_{0.5}$ model checkpoint trained on Close Microwave, Close Cabinet and Close Fridge on Robocasa. | Close Microwave : 97%
Close Cabinet : 65%
Close Fridge : 80% | -| [TensorAuto/pi05_libero_continuous_state][9] | A $\pi_{0.5}$ model checkpoint trained on Libero dataset with continuous actions. | 92% | +| [TensorAuto/pi05_libero_continuous_state][9] | A $\pi_{0.5}$ model checkpoint trained on Libero dataset with continuous states (projecting raw proprioceptive states to the model's latent dimension). | 92% | | [TensorAuto/moka_pot_libero_sft][6]
[TensorAuto/moka_pot_RECAP_R0][7]
[TensorAuto/moka_pot_RECAP_R1][8] | A $\pi_{0}$ RECAP model checkpoint trained on moka pot task on libero. | 83%
89%
90% | | [TensorAuto/tPi0.5-libero][2] | A $\pi_{0.5}$ model checkpoint trained on the LIBERO dataset with discrete actions and knowledge insulation. | 98.4% (10)
97.6% (Goal)
100% (Object)
98% (Spatial) | | [TensorAuto/pi05_base][5] | A $\pi_{0.5}$ model checkpoint converted from the official openpi checkpoint, with language embeddings added. | N/A |