From 579e2f4d8007bd2c610c6207ed5cf9de7314620c Mon Sep 17 00:00:00 2001 From: Shuheng Liu Date: Sun, 3 May 2026 19:53:44 -0700 Subject: [PATCH 1/5] feat(grounding): add codec + ensure_loc_tokens for pi05/pi06 MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Land the infrastructure for PaliGemma-style location-token grounding data without yet shipping a concrete grounding dataset. - src/opentau/datasets/grounding/loc_codec.py — pure functions to convert pixel coordinates to/from `<locNNNN>` strings (xyxy/xywh boxes, points, tolerant inverses). y-then-x order, 1024-bin quantization, computed against original image dims. - src/opentau/datasets/grounding/tokenizer_utils.py — `ensure_loc_tokens` uses `AddedToken(special=True, normalized=False)` to promote the loc strings to single-token match mode. Idempotent: 0 new IDs on PaliGemma (the strings already live at IDs 256000..257023 but the bare HF tokenizer otherwise BPE-fragments them); 1024 new IDs on a fresh Gemma 3 tokenizer, with `model.resize_token_embeddings` updating the embedding table and tied LM head. - pi05 / pi06 wire `ensure_loc_tokens` into their `__init__`s. pi06 also passes the Gemma 3 model handle so the embedding/LM-head resize fires after the public `google/gemma-3-4b-pt` weights have loaded — the new rows are random-init. - Delete the broken `vqa/pixmo.py`. Its JSON-encoded points-as-ASCII responses fragmented through BPE; the replacement (a configurable grounding dataset) is tracked as a follow-up. - Tests: 10 codec round-trip / clamping / order tests; 7 PaliGemma tokenizer tests; 7 Gemma 3 tokenizer tests (including a fake-model resize check). All pass under `pytest -m "not gpu"`. 
--- src/opentau/datasets/__init__.py | 6 +- src/opentau/datasets/factory.py | 3 +- src/opentau/datasets/grounding/__init__.py | 35 ++++ src/opentau/datasets/grounding/loc_codec.py | 158 ++++++++++++++++ .../datasets/grounding/tokenizer_utils.py | 94 ++++++++++ .../datasets/standard_data_format_mapping.py | 7 - src/opentau/datasets/vqa/__init__.py | 5 +- src/opentau/datasets/vqa/pixmo.py | 177 ------------------ src/opentau/policies/pi05/modeling_pi05.py | 10 + src/opentau/policies/pi06/modeling_pi06.py | 15 ++ tests/datasets/test_loc_codec.py | 155 +++++++++++++++ tests/datasets/test_loc_tokens_gemma3.py | 137 ++++++++++++++ tests/datasets/test_loc_tokens_paligemma.py | 102 ++++++++++ 13 files changed, 711 insertions(+), 193 deletions(-) create mode 100644 src/opentau/datasets/grounding/__init__.py create mode 100644 src/opentau/datasets/grounding/loc_codec.py create mode 100644 src/opentau/datasets/grounding/tokenizer_utils.py delete mode 100644 src/opentau/datasets/vqa/pixmo.py create mode 100644 tests/datasets/test_loc_codec.py create mode 100644 tests/datasets/test_loc_tokens_gemma3.py create mode 100644 tests/datasets/test_loc_tokens_paligemma.py diff --git a/src/opentau/datasets/__init__.py b/src/opentau/datasets/__init__.py index 0947a75d..20a4bb64 100644 --- a/src/opentau/datasets/__init__.py +++ b/src/opentau/datasets/__init__.py @@ -22,7 +22,7 @@ - **Core Datasets**: LeRobotDataset for robot learning data with support for temporal alignment, multi-modal data, and version compatibility. - - **VQA Datasets**: Vision-language datasets (CLEVR, COCO-QA, PIXMO, VSR) + - **VQA Datasets**: Vision-language datasets (CLEVR, COCO-QA, VSR) for training visual understanding without robot actions. - **Dataset Mixtures**: WeightedDatasetMixture for combining multiple datasets with controlled sampling proportions. @@ -53,7 +53,7 @@ Main Modules: - **lerobot_dataset**: Core dataset implementation for robot learning data. 
- - **vqa**: Vision-language vqa datasets (CLEVR, COCO-QA, PIXMO, VSR). + - **vqa**: Vision-language vqa datasets (CLEVR, COCO-QA, VSR). - **dataset_mixture**: Weighted combination of multiple datasets. - **factory**: Factory functions for creating datasets from configurations. - **utils**: Utility functions for I/O, metadata management, and validation. @@ -80,5 +80,5 @@ >>> from opentau import available_vqa_datasets >>> print(list(available_vqa_datasets.keys())) - ['clevr', 'cocoqa', 'dummy', 'pixmo', 'vsr'] + ['clevr', 'cocoqa', 'dummy', 'vsr'] """ diff --git a/src/opentau/datasets/factory.py b/src/opentau/datasets/factory.py index 2626b702..2ea42c82 100644 --- a/src/opentau/datasets/factory.py +++ b/src/opentau/datasets/factory.py @@ -26,7 +26,7 @@ 1. LeRobot datasets: Standard robot learning datasets loaded from HuggingFace repositories with configurable delta timestamps for temporal alignment. 2. VQA datasets: Vision-language vqa datasets (CLEVR, COCO-QA, - PIXMO, VSR, etc.) for multimodal learning tasks. + VSR, etc.) for multimodal learning tasks. Key Features: - Delta timestamp resolution: Automatically configures temporal offsets @@ -71,7 +71,6 @@ import opentau.datasets.vqa.clevr # noqa: F401 import opentau.datasets.vqa.cocoqa # noqa: F401 import opentau.datasets.vqa.dummy # noqa: F401 -import opentau.datasets.vqa.pixmo # noqa: F401 import opentau.datasets.vqa.vsr # noqa: F401 from opentau import available_vqa_datasets from opentau.configs.default import DatasetConfig diff --git a/src/opentau/datasets/grounding/__init__.py b/src/opentau/datasets/grounding/__init__.py new file mode 100644 index 00000000..b367327c --- /dev/null +++ b/src/opentau/datasets/grounding/__init__.py @@ -0,0 +1,35 @@ +# Copyright 2026 Tensor Auto Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Utilities for encoding spatial outputs as PaliGemma-style location tokens. + +PaliGemma reserves 1024 single-token IDs ``<loc0000>``..``<loc1023>`` that +quantize a coordinate axis into 1024 bins. Bounding boxes and points are +emitted as plain strings (`<loc0256><loc0128><loc0768><loc0896> label`), +which the standard tokenizer turns into a single integer per `<locNNNN>`. + +Two helpers live here: + +- ``loc_codec``: pure functions to convert pixel coordinates to/from + the loc-token string format. No torch dependency. +- ``tokenizer_utils.ensure_loc_tokens``: makes the loc strings available + on any HuggingFace tokenizer. No new IDs for PaliGemma (already in its vocab). + For Gemma 3 (and any other base tokenizer) it appends them as special + tokens and, when given a model handle, resizes the embedding table to + match. + +Concrete grounding datasets (PixMo-points, RefCOCO, …) are NOT yet shipped +under this package — see the follow-up tracking the configurable response +formatter that will make them config-driven rather than one class per source. +""" diff --git a/src/opentau/datasets/grounding/loc_codec.py b/src/opentau/datasets/grounding/loc_codec.py new file mode 100644 index 00000000..b9b60db0 --- /dev/null +++ b/src/opentau/datasets/grounding/loc_codec.py @@ -0,0 +1,158 @@ +# Copyright 2026 Tensor Auto Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Codec between pixel coordinates and PaliGemma `` strings. + +PaliGemma's 1024-bin grounding format quantizes a coordinate axis into +10 bits and emits each bin as a single `` token (zero-padded to +four digits — `` does not match the tokenizer). A bounding box is +encoded as four loc tokens in **(y_min, x_min, y_max, x_max)** order +(y-then-x; do not swap), then a space, then the label, and `; ` separates +multiple boxes: + + " dog ; cat" + +A point is two loc tokens in **(y, x)** order: + + " spout" + +The 1024 grid is **abstract** — it is not the input image resolution. +Coordinates are normalized using the original image dimensions, then +quantized as ``int(round(coord_norm * 1023))`` and clamped to ``[0, 1023]``. +Pass the original image's `(width, height)` from the dataset (e.g. +``Image.open(...).size``), NOT the post-resize tensor shape that the +policy actually consumes. + +TODO: eval-side decoding will use ``loc_tokens_to_xyxy`` / +``loc_tokens_to_points`` against decoded response strings to recover +bounding boxes for IoU/mAP. Tracked as a follow-up to the configurable +response-formatter work. +""" + +from __future__ import annotations + +import re + +NUM_BINS = 1024 +MAX_BIN = NUM_BINS - 1 + +_LOC_TOKEN_RE = re.compile(r"") + + +def _quantize(coord: float, extent: float) -> int: + """Map a single pixel coordinate to a `[0, 1023]` bin index. + + Args: + coord: Pixel coordinate (e.g. an x or y in original-image space). + extent: The image dimension along this axis (width for x, height for y). 
+ + Returns: + An integer bin index in ``[0, 1023]``. + """ + if extent <= 0: + return 0 + bin_idx = int(round((coord / extent) * MAX_BIN)) + if bin_idx < 0: + return 0 + if bin_idx > MAX_BIN: + return MAX_BIN + return bin_idx + + +def _dequantize(bin_idx: int, extent: float) -> float: + """Inverse of `_quantize`: map a bin index back to a pixel coordinate.""" + return (bin_idx / MAX_BIN) * extent + + +def _loc(bin_idx: int) -> str: + return f"" + + +def xyxy_to_loc_tokens(box_xyxy: tuple[float, float, float, float], img_w: int, img_h: int) -> str: + """Encode an `(x_min, y_min, x_max, y_max)` box as four loc tokens. + + The output order is `` + (y-then-x), matching PaliGemma's convention. + + Args: + box_xyxy: ``(x_min, y_min, x_max, y_max)`` in pixel coordinates of the + **original** image. + img_w: Original image width in pixels. + img_h: Original image height in pixels. + + Returns: + A four-token string with no separators. + """ + x_min, y_min, x_max, y_max = box_xyxy + return ( + _loc(_quantize(y_min, img_h)) + + _loc(_quantize(x_min, img_w)) + + _loc(_quantize(y_max, img_h)) + + _loc(_quantize(x_max, img_w)) + ) + + +def xywh_to_loc_tokens(box_xywh: tuple[float, float, float, float], img_w: int, img_h: int) -> str: + """Same as `xyxy_to_loc_tokens` but accepts COCO-style ``(x, y, w, h)``.""" + x, y, w, h = box_xywh + return xyxy_to_loc_tokens((x, y, x + w, y + h), img_w, img_h) + + +def point_to_loc_tokens(x: float, y: float, img_w: int, img_h: int) -> str: + """Encode an `(x, y)` point as two loc tokens in y-then-x order.""" + return _loc(_quantize(y, img_h)) + _loc(_quantize(x, img_w)) + + +def loc_tokens_to_xyxy(s: str, img_w: int, img_h: int) -> list[tuple[float, float, float, float]]: + """Parse a string of loc tokens into `(x_min, y_min, x_max, y_max)` pixel boxes. + + Tolerant: any segment that does not contain a multiple of four loc tokens + is dropped silently. Garbage strings or partial decodes return ``[]``. 
+ + Args: + s: A string that may contain `` tokens, e.g. a decoded + response. Non-loc text is ignored. + img_w: Original image width in pixels. + img_h: Original image height in pixels. + + Returns: + A list of `(x_min, y_min, x_max, y_max)` tuples in pixel coordinates. + """ + bins = [int(m) for m in _LOC_TOKEN_RE.findall(s)] + boxes: list[tuple[float, float, float, float]] = [] + for i in range(0, len(bins) - 3, 4): + y_min_b, x_min_b, y_max_b, x_max_b = bins[i : i + 4] + boxes.append( + ( + _dequantize(x_min_b, img_w), + _dequantize(y_min_b, img_h), + _dequantize(x_max_b, img_w), + _dequantize(y_max_b, img_h), + ) + ) + return boxes + + +def loc_tokens_to_points(s: str, img_w: int, img_h: int) -> list[tuple[float, float]]: + """Parse a string of loc tokens into `(x, y)` pixel points. + + Pairs of loc tokens are decoded as `(y, x)` per the PaliGemma convention + and returned as `(x, y)`. Lone trailing tokens are dropped. + """ + bins = [int(m) for m in _LOC_TOKEN_RE.findall(s)] + points: list[tuple[float, float]] = [] + for i in range(0, len(bins) - 1, 2): + y_b, x_b = bins[i : i + 2] + points.append((_dequantize(x_b, img_w), _dequantize(y_b, img_h))) + return points diff --git a/src/opentau/datasets/grounding/tokenizer_utils.py b/src/opentau/datasets/grounding/tokenizer_utils.py new file mode 100644 index 00000000..de5425f8 --- /dev/null +++ b/src/opentau/datasets/grounding/tokenizer_utils.py @@ -0,0 +1,94 @@ +# Copyright 2026 Tensor Auto Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tokenizer-side support for PaliGemma ``..`` tokens. + +Two cases must be handled, both via the same call: + +1. **PaliGemma (`google/paligemma-3b-pt-224`).** The 1024 loc strings are + already in the SentencePiece vocab at IDs ``256000``..``257023``. They + are NOT, however, registered as **added tokens**, so the bare HF tokenizer + BPE-fragments any ``-shaped string into seven pieces + (``['<', 'loc', '0', '0', '0', '0', '>']``) instead of matching it as one + unit at ID 256000. Calling ``add_tokens`` with an ``AddedToken`` whose + string already exists in the vocab is the documented HF mechanism to + *promote* an existing entry to single-token-match status without + reassigning its ID. No new vocab slots are created and no embedding + resize is needed. +2. **Gemma 3 (`google/gemma-3-4b-pt`).** The strings are absent. The same + ``add_tokens`` call appends 1024 new IDs at the end of the vocab; the + model's embedding table and tied LM head must be resized to match. The + new rows are random-init — they learn from the grounding data on first + use. There is no PaliGemma loc-embedding transfer. + +The utility below covers both cases idempotently, so it can be wired into +every policy ``__init__`` defensively. +""" + +from __future__ import annotations + +import logging + +from transformers.tokenization_utils_base import AddedToken + +LOC_TOKENS: tuple[str, ...] = tuple(f"" for i in range(1024)) + +_logger = logging.getLogger(__name__) + + +def ensure_loc_tokens(tokenizer, model=None) -> int: + """Idempotently make ``..`` available as single tokens. + + Always promotes the 1024 loc strings to added/special tokens via + ``tokenizer.add_tokens``. For PaliGemma this is a no-op vocab-size-wise: + the strings already live at the reserved IDs ``256000``..``257023``, and + the call only flips them into single-token match mode. 
For Gemma 3 the + 1024 strings are appended as new IDs, and the model's embedding table + and tied LM head are resized via ``model.resize_token_embeddings`` when + a model handle is supplied. + + Safe to call multiple times — once the strings are registered as added + tokens, subsequent calls neither grow the vocab nor resize. + + Args: + tokenizer: A HuggingFace `PreTrainedTokenizer` / `PreTrainedTokenizerFast`. + model: Optional `PreTrainedModel` whose embeddings should be resized + when new IDs are assigned. Pass the top-level VLM (e.g. the + ``Gemma3ForConditionalGeneration`` / + ``PaliGemmaForConditionalGeneration`` instance) — HF's + ``resize_token_embeddings`` dispatches through + ``get_input_embeddings`` / ``set_input_embeddings`` to the + language model and updates the tied LM head as well. + + Returns: + The number of NEW IDs appended to the tokenizer vocab. Always 0 for + PaliGemma; 1024 on the first call against a fresh Gemma 3 tokenizer; + 0 for any subsequent call. + """ + initial_len = len(tokenizer) + added_tokens = [AddedToken(t, special=True, normalized=False) for t in LOC_TOKENS] + tokenizer.add_tokens(added_tokens, special_tokens=True) + n_new_ids = len(tokenizer) - initial_len + + if n_new_ids > 0 and model is not None: + model.resize_token_embeddings(len(tokenizer)) + + if n_new_ids > 0: + _logger.info( + "ensure_loc_tokens: appended %d token IDs (new vocab size %d); embeddings %sresized.", + n_new_ids, + len(tokenizer), + "" if model is not None else "NOT ", + ) + return n_new_ids diff --git a/src/opentau/datasets/standard_data_format_mapping.py b/src/opentau/datasets/standard_data_format_mapping.py index 49a8b263..5750ca28 100644 --- a/src/opentau/datasets/standard_data_format_mapping.py +++ b/src/opentau/datasets/standard_data_format_mapping.py @@ -188,13 +188,6 @@ "prompt": "task", "response": "response", }, - "pixmo": { - "camera0": "image", - "state": "state", - "actions": "actions", - "prompt": "prompt", - "response": "postfix", - }, 
"dummy": { "camera0": "image", "state": "state", diff --git a/src/opentau/datasets/vqa/__init__.py b/src/opentau/datasets/vqa/__init__.py index 1c24a7dd..17eb76b2 100644 --- a/src/opentau/datasets/vqa/__init__.py +++ b/src/opentau/datasets/vqa/__init__.py @@ -34,8 +34,6 @@ for visual question answering with synthetic scenes. - COCO-QA: Visual question answering dataset based on COCO images, filtered for spatial reasoning tasks. - - PIXMO: Pixel-level manipulation vqa dataset for object - localization and manipulation tasks. - VSR: Visual Spatial Reasoning dataset for true/false statement vqa about spatial relationships in images. - dummy: Synthetic test dataset with simple black, white, and gray @@ -51,7 +49,6 @@ clevr: CLEVR dataset implementation. cocoqa: COCO-QA dataset implementation. dummy: Dummy test dataset implementation. - pixmo: PIXMO dataset implementation. vsr: VSR dataset implementation. Example: @@ -63,5 +60,5 @@ Access available vqa datasets: >>> from opentau import available_vqa_datasets >>> print(list(available_vqa_datasets.keys())) - ['clevr', 'cocoqa', 'dummy', 'pixmo', 'vsr'] + ['clevr', 'cocoqa', 'dummy', 'vsr'] """ diff --git a/src/opentau/datasets/vqa/pixmo.py b/src/opentau/datasets/vqa/pixmo.py deleted file mode 100644 index 40dd5066..00000000 --- a/src/opentau/datasets/vqa/pixmo.py +++ /dev/null @@ -1,177 +0,0 @@ -# Copyright 2026 Tensor Auto Inc. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-"""Datasets for Image-Text Point Set vqa tasks. - -This module provides the PIXMO (Pixel-level Manipulation) dataset implementation -for training vision-language models on part localization and object vqa tasks. -""" - -import json -import logging -import random -import warnings -from io import BytesIO - -import numpy as np -import requests -import torch -from datasets import load_dataset -from PIL import Image -from requests.adapters import HTTPAdapter -from urllib3.util.retry import Retry - -from opentau import register_vqa_dataset -from opentau.configs.train import TrainPipelineConfig -from opentau.datasets.vqa.base import VQADataset - -# TODO: add a config to filter the warnings -logging.getLogger("urllib3.connectionpool").setLevel(logging.ERROR) -warnings.filterwarnings( - "ignore", - message=r"Palette images with Transparency expressed in bytes should be converted to RGBA images", - category=UserWarning, - module=r"PIL\.Image", -) -warnings.filterwarnings( - "ignore", - message=r"image file could not be identified because AVIF support not installed", - category=UserWarning, - module=r"PIL\.Image", -) - -IMG_SIZE = 224 -POINT_GRID = 255 -MAX_RETRIES = 1 -HTTP_TIMEOUT = 1 -LOG_EVERY_N_BAD = 1000 - -_session = requests.Session() -_session.mount( - "https://", - HTTPAdapter( - max_retries=Retry( - total=MAX_RETRIES, - backoff_factor=0.5, - status_forcelist=[500, 502, 503, 504], - ) - ), -) - - -def _pil_from_url(url: str) -> Image.Image | None: - """Download, decode, and resize an image using its URL. Returns None in case of failure.""" - try: - r = _session.get(url, timeout=HTTP_TIMEOUT) - r.raise_for_status() - # TODO: Check against the hash in case the image somehow changed. - return Image.open(BytesIO(r.content)).convert("RGB") - except Exception: - return None - - -def _get_post_fix(label: str, points: list, orig_w: int, orig_h: int, max_points: int = 16) -> str: - """Map points from pixel space to grid space and return a JSON postfix string. 
- - Converts pixel coordinates to a 255x255 grid, deduplicates points, and - limits to max_points. Returns a JSON string with point coordinates and labels. - - Args: - label: Label for the points (e.g., object class name). - points: List of point dictionaries with 'x' and 'y' keys. - orig_w: Original image width. - orig_h: Original image height. - max_points: Maximum number of points to include. Defaults to 16. - - Returns: - JSON string containing point coordinates and labels. - """ - # use `dict` to deduplicate as `set` is not guaranteed to preserve order - deduplicated = { - (int(p["x"] * POINT_GRID / orig_w), int(p["y"] * POINT_GRID / orig_h)): None for p in points - } - if len(deduplicated) > max_points: - deduplicated = random.choices(list(deduplicated), k=max_points) - rows = [{"in_frame": True, "point": pair, "label": label} for pair in deduplicated] - return json.dumps(rows) - - -def _img_to_normalized_tensor(img: Image.Image) -> torch.Tensor: - """Convert a PIL Image to a normalized torch tensor. - - Resizes the image to IMG_SIZE and converts it from (H, W, C) to (C, H, W) - format, normalizing pixel values to [0, 1]. - - Args: - img: PIL Image to convert. - - Returns: - Normalized tensor of shape (C, IMG_SIZE, IMG_SIZE) with values in [0, 1]. 
- """ - img = img.resize((IMG_SIZE, IMG_SIZE), Image.BILINEAR) - # pytorch uses (C, H, W) while PIL uses (H, W, C) - return torch.from_numpy(np.array(img)).permute(2, 0, 1).float() / 255.0 - - -@register_vqa_dataset("pixmo") -class PixmoDataset(VQADataset): - r"""Dataset for the iterable PixMo dataset implementation, recommended to be used together with PrefetchWrapper""" - - def __init__(self, cfg: TrainPipelineConfig, consecutive_bad_tolerance=100): - # Self.ds is needed for metadata, which is computed in parent constructor - self.ds = load_dataset("allenai/pixmo-points", split="train") - super().__init__(cfg) - self.bad_ids = set() - self.consecutive_bad_tolerance = consecutive_bad_tolerance - - def __len__(self): - return len(self.ds) - - def _get_feature_mapping_key(self) -> str: - return "pixmo" - - def __getitem_helper__(self, item) -> dict: - """Get a PixMo dataset item. - - Downloads the image from URL and formats it for part localization tasks. - Retries with random indices if image download fails. - - Args: - item: Index of the item to retrieve. - - Returns: - Dictionary with image, task, postfix, task_type, and prompt. - - Raises: - RuntimeError: If too many consecutive items fail to load. - """ - for _ in range(self.consecutive_bad_tolerance): - if item in self.bad_ids: - item = np.random.randint(0, len(self.ds)) - continue - ex = self.ds[item] - img = _pil_from_url(ex["image_url"]) - if img is None: - self.bad_ids.add(item) - item = np.random.randint(0, len(self.ds)) - continue - - return { - "image": _img_to_normalized_tensor(img), - "task": ex["label"], - "postfix": _get_post_fix(ex["label"], ex["points"], *img.size), - "task_type": "part", - "prompt": f'{{"task": "part", "description": "Find {ex["label"]} in the image"}}', - } - - raise RuntimeError("Too many consecutive bad items. 
Please check dataset or increase the tolerance.") diff --git a/src/opentau/policies/pi05/modeling_pi05.py b/src/opentau/policies/pi05/modeling_pi05.py index 6c893eeb..41f83536 100644 --- a/src/opentau/policies/pi05/modeling_pi05.py +++ b/src/opentau/policies/pi05/modeling_pi05.py @@ -35,6 +35,7 @@ from opentau.configs.policies import PreTrainedConfig from opentau.configs.types import NormalizationMode +from opentau.datasets.grounding.tokenizer_utils import ensure_loc_tokens from opentau.policies.normalize import Normalize, Unnormalize from opentau.policies.pi05.configuration_pi05 import PI05Config from opentau.policies.pi05.paligemma_with_expert import ( @@ -265,6 +266,12 @@ def __init__( ) self.language_tokenizer = AutoTokenizer.from_pretrained("google/paligemma-3b-pt-224") + # PaliGemma reserves <loc0000>..<loc1023> at IDs 256000..257023, but + # the bare HF tokenizer does not register them as added/special + # tokens — a string "<loc0000>" otherwise BPE-fragments into seven + # pieces. This call promotes the reserved <locNNNN> entries to single-token + # match mode (no new IDs, no embedding resize on PaliGemma). + ensure_loc_tokens(self.language_tokenizer) self.discrete_action_processor = AutoProcessor.from_pretrained( "physical-intelligence/fast", trust_remote_code=True @@ -944,6 +951,9 @@ def __init__(self, config: PI05Config, discrete_action_vocab_size: int | None = self.time_mlp_out = nn.Linear(self.config.proj_width, self.config.proj_width) self.language_tokenizer = AutoTokenizer.from_pretrained("google/paligemma-3b-pt-224") + # See PI05Policy.__init__ — promotes the reserved <locNNNN> entries + # to single-token match mode on this tokenizer instance as well. + ensure_loc_tokens(self.language_tokenizer) def sample_noise(self, shape: tuple[int, ...], device: torch.device | str) -> Tensor: """Samples Gaussian noise. 
diff --git a/src/opentau/policies/pi06/modeling_pi06.py b/src/opentau/policies/pi06/modeling_pi06.py index 6855b86d..0b4939bb 100644 --- a/src/opentau/policies/pi06/modeling_pi06.py +++ b/src/opentau/policies/pi06/modeling_pi06.py @@ -42,6 +42,7 @@ from opentau.configs.policies import PreTrainedConfig from opentau.configs.types import NormalizationMode +from opentau.datasets.grounding.tokenizer_utils import ensure_loc_tokens from opentau.policies.normalize import Normalize, Unnormalize from opentau.policies.pi06.configuration_pi06 import PI06Config from opentau.policies.pi06.gemma3_with_expert import ( @@ -249,6 +250,9 @@ def __init__( # only if the Hub download fails at module import time in an offline CI, # so users still get a useful error rather than silent drift. self.language_tokenizer = AutoTokenizer.from_pretrained("google/gemma-3-4b-pt") + # Tokenizer-side extension; the matching embedding/LM-head resize fires + # inside PI06FlowMatching.__init__, where the Gemma 3 model handle exists. + ensure_loc_tokens(self.language_tokenizer) self.discrete_action_processor = AutoProcessor.from_pretrained( "physical-intelligence/fast", trust_remote_code=True @@ -737,6 +741,17 @@ def __init__(self, config: PI06Config, discrete_action_vocab_size: int | None = self.time_mlp_out = nn.Linear(self.config.proj_width, self.config.proj_width) self.language_tokenizer = AutoTokenizer.from_pretrained("google/gemma-3-4b-pt") + # π0.6 uses Gemma 3, whose stock tokenizer does NOT carry the 1024 + # .. grounding tokens that PaliGemma reserves. We + # unconditionally extend the vocab here so any grounding/VQA training + # data containing loc tokens flows through the same response_ce_loss + # path as on PaliGemma backbones. The new embedding rows are random- + # init — they learn from grounding data on first use; there is NO + # PaliGemma loc-embedding transfer. 
The resize must happen after + # `Gemma3WithExpertModel(...)` has already loaded the public Gemma 3 + # weights (line above), so the original 256K rows survive and only + # the 1024 new rows are freshly initialized. + ensure_loc_tokens(self.language_tokenizer, model=self.gemma3_with_expert.gemma3) def sample_noise(self, shape: tuple[int, ...], device: torch.device | str) -> Tensor: """Standard Gaussian noise (float32).""" diff --git a/tests/datasets/test_loc_codec.py b/tests/datasets/test_loc_codec.py new file mode 100644 index 00000000..2ce147da --- /dev/null +++ b/tests/datasets/test_loc_codec.py @@ -0,0 +1,155 @@ +# Copyright 2026 Tensor Auto Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for the loc-token codec — pure functions, no network.""" + +from __future__ import annotations + +from opentau.datasets.grounding.loc_codec import ( + NUM_BINS, + loc_tokens_to_points, + loc_tokens_to_xyxy, + point_to_loc_tokens, + xywh_to_loc_tokens, + xyxy_to_loc_tokens, +) + + +# Slack equal to one bin step in pixel units, plus a hair for the round-trip +# `int(round(...))` quantization. (img_dim / 1023) * 1.0 is one quantum; +# we allow 0.55 of that on either side. +def _tol(extent: int) -> float: + return extent / (NUM_BINS - 1) * 0.55 + + +def test_xyxy_round_trip_integer_aligned() -> None: + img_w, img_h = 1024, 1024 # one pixel per bin: round-trip is exact. 
+ box = (100.0, 200.0, 800.0, 900.0) + s = xyxy_to_loc_tokens(box, img_w, img_h) + assert s.count(" None: + img_w, img_h = 640, 480 + s_xywh = xywh_to_loc_tokens((50.0, 60.0, 100.0, 120.0), img_w, img_h) + s_xyxy = xyxy_to_loc_tokens((50.0, 60.0, 150.0, 180.0), img_w, img_h) + assert s_xywh == s_xyxy + + +def test_point_round_trip() -> None: + img_w, img_h = 1920, 1080 + s = point_to_loc_tokens(960.0, 540.0, img_w, img_h) + assert s.count(" None: + """Bin order matches PaliGemma's y-then-x convention. + + For an asymmetric image (width != height), swapping x and y in the input + must produce a different bin output — proving we are not silently treating + the two coordinates as interchangeable. + """ + img_w, img_h = 1000, 500 + a = point_to_loc_tokens(100.0, 200.0, img_w, img_h) + b = point_to_loc_tokens(200.0, 100.0, img_w, img_h) + assert a != b + + +def test_clamping_negative_and_overflow() -> None: + img_w, img_h = 100, 100 + s = xyxy_to_loc_tokens((-50.0, -10.0, 5_000.0, 999.0), img_w, img_h) + # Lowest bin is ; highest is . The clamped string + # should start with two tokens and end with two 's. 
+ assert s.startswith("") + assert s.endswith("") + + +def test_multi_box_concat() -> None: + img_w, img_h = 1024, 1024 + box1 = xyxy_to_loc_tokens((10.0, 20.0, 30.0, 40.0), img_w, img_h) + box2 = xyxy_to_loc_tokens((100.0, 200.0, 300.0, 400.0), img_w, img_h) + response = f"{box1} dog ; {box2} cat" + + decoded = loc_tokens_to_xyxy(response, img_w, img_h) + assert len(decoded) == 2 + # First box + x_min, y_min, x_max, y_max = decoded[0] + assert abs(x_min - 10.0) <= _tol(img_w) + assert abs(y_min - 20.0) <= _tol(img_h) + # Second box + x_min, y_min, x_max, y_max = decoded[1] + assert abs(x_min - 100.0) <= _tol(img_w) + assert abs(y_min - 200.0) <= _tol(img_h) + + +def test_garbage_input_returns_empty() -> None: + assert loc_tokens_to_xyxy("garbage with no tokens", 100, 100) == [] + assert loc_tokens_to_points("garbage with no tokens", 100, 100) == [] + + +def test_partial_token_count_drops_orphan_pairs() -> None: + """A response containing 6 loc tokens decodes to 1 box (4 tokens consumed). + + The trailing pair is silently dropped — this matches the codec's tolerant + contract documented in `loc_tokens_to_xyxy`'s docstring. + """ + img_w, img_h = 1024, 1024 + # Six loc tokens — one full box plus two orphans. + response = "" + boxes = loc_tokens_to_xyxy(response, img_w, img_h) + assert len(boxes) == 1 + + +def test_codec_uses_original_image_dims_not_post_resize() -> None: + """Regression: bins must be computed against the original image dims. + + A common mistake is to pass the post-resize tensor shape (e.g. 224, 224) + that the policy actually consumes. The codec must use the original + `(img_w, img_h)` from the dataset so loc tokens carry the same spatial + meaning regardless of input pipeline resizing. + """ + # Same pixel coordinate, different "original" extents. 
+ bin_at_1920 = xyxy_to_loc_tokens((960.0, 540.0, 960.0, 540.0), 1920, 1080) + bin_at_224 = xyxy_to_loc_tokens((960.0, 540.0, 960.0, 540.0), 224, 224) + # If the codec ignored extent, both would tokenize to the same string; + # they must not. + assert bin_at_1920 != bin_at_224 + + +def test_format_string_is_zero_padded_to_four_digits() -> None: + """`` does not match the PaliGemma tokenizer; only `` does.""" + img_w, img_h = 102_400, 102_400 # 100 px / bin -> bin 0 for coord 23 + s = xyxy_to_loc_tokens((23.0, 23.0, 23.0, 23.0), img_w, img_h) + # We don't care about the exact bin index — only that every emitted token + # is exactly 9 characters long: "". + for tok in s.replace(">", ">|").split("|"): + if not tok: + continue + assert tok.startswith("") + assert len(tok) == len("") diff --git a/tests/datasets/test_loc_tokens_gemma3.py b/tests/datasets/test_loc_tokens_gemma3.py new file mode 100644 index 00000000..3649bc7b --- /dev/null +++ b/tests/datasets/test_loc_tokens_gemma3.py @@ -0,0 +1,137 @@ +# Copyright 2026 Tensor Auto Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for `` token extension on the Gemma 3 4B tokenizer. + +Gemma 3 (used by π0.6) does NOT ship ``..``. The +`ensure_loc_tokens` utility appends them as 1024 special tokens and, +when given a model handle, resizes the model's embedding/LM head to +match. + +Marked `slow` because the tokenizer is fetched from the HF Hub on first +run. 
The Gemma 3 *model* is NOT loaded — instead a fake mini-model is +used to verify that `resize_token_embeddings` is invoked when tokens are +added. +""" + +from __future__ import annotations + +import pytest +import torch +from torch import nn +from transformers import AutoTokenizer + +from opentau.datasets.grounding.loc_codec import xyxy_to_loc_tokens +from opentau.datasets.grounding.tokenizer_utils import LOC_TOKENS, ensure_loc_tokens + + +def _fresh_gemma3_tokenizer(): + """Always return a freshly loaded tokenizer so tests don't share state.""" + return AutoTokenizer.from_pretrained("google/gemma-3-4b-pt") + + +class _FakeResizableModel: + """Minimal stand-in for an HF model: tracks `resize_token_embeddings` calls. + + The real `Gemma3ForConditionalGeneration` is not loaded in CI because + its weights are several GB. We only need to confirm that the resize + call fires with the right vocab size. + """ + + def __init__(self, initial_vocab: int, hidden: int = 8) -> None: + self.embed = nn.Embedding(initial_vocab, hidden) + self.resize_calls: list[int] = [] + + def resize_token_embeddings(self, new_size: int) -> nn.Embedding: + self.resize_calls.append(new_size) + old = self.embed + self.embed = nn.Embedding(new_size, old.embedding_dim) + with torch.no_grad(): + n = min(old.num_embeddings, new_size) + self.embed.weight[:n] = old.weight[:n] + return self.embed + + +@pytest.mark.slow +def test_gemma3_lacks_loc_tokens_initially() -> None: + tok = _fresh_gemma3_tokenizer() + vocab = tok.get_vocab() + assert "" not in vocab + assert "" not in vocab + + +@pytest.mark.slow +def test_ensure_loc_tokens_adds_1024() -> None: + tok = _fresh_gemma3_tokenizer() + n_added = ensure_loc_tokens(tok) + assert n_added == 1024 + + +@pytest.mark.slow +def test_ensure_loc_tokens_is_idempotent() -> None: + tok = _fresh_gemma3_tokenizer() + assert ensure_loc_tokens(tok) == 1024 + # Second call must be a no-op now that the strings are in the vocab. 
+ assert ensure_loc_tokens(tok) == 0 + + +@pytest.mark.slow +def test_loc_tokens_become_single_token_after_extension() -> None: + tok = _fresh_gemma3_tokenizer() + ensure_loc_tokens(tok) + for sample in ("", "", ""): + ids = tok.encode(sample, add_special_tokens=False) + assert len(ids) == 1, f"{sample} did not tokenize to a single id" + + +@pytest.mark.slow +def test_bbox_postfix_round_trips_after_extension() -> None: + tok = _fresh_gemma3_tokenizer() + ensure_loc_tokens(tok) + + img_w, img_h = 1024, 1024 + postfix = xyxy_to_loc_tokens((10.0, 20.0, 30.0, 40.0), img_w, img_h) + " dog" + decoded = tok.decode(tok.encode(postfix, add_special_tokens=False)) + + # Each loc string from the input must survive the round-trip. + for tok_str in LOC_TOKENS: + if tok_str in postfix: + assert tok_str in decoded + + +@pytest.mark.slow +def test_resize_fires_when_tokens_added() -> None: + """When tokens are added AND a model is provided, embeddings must resize.""" + tok = _fresh_gemma3_tokenizer() + initial_vocab = len(tok) + fake = _FakeResizableModel(initial_vocab=initial_vocab) + + n_added = ensure_loc_tokens(tok, model=fake) + + assert n_added == 1024 + assert fake.resize_calls == [initial_vocab + 1024] + assert fake.embed.num_embeddings == initial_vocab + 1024 + + +@pytest.mark.slow +def test_resize_does_not_fire_on_idempotent_call() -> None: + """Second call (tokens already present) must not call resize.""" + tok = _fresh_gemma3_tokenizer() + ensure_loc_tokens(tok) # first call adds tokens + + fake = _FakeResizableModel(initial_vocab=len(tok)) + n_added = ensure_loc_tokens(tok, model=fake) + + assert n_added == 0 + assert fake.resize_calls == [] diff --git a/tests/datasets/test_loc_tokens_paligemma.py b/tests/datasets/test_loc_tokens_paligemma.py new file mode 100644 index 00000000..0a991690 --- /dev/null +++ b/tests/datasets/test_loc_tokens_paligemma.py @@ -0,0 +1,102 @@ +# Copyright 2026 Tensor Auto Inc. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for `` token handling on the PaliGemma tokenizer. + +PaliGemma reserves the 1024 loc tokens at IDs 256000..257023 in its base +vocab, but the bare HF tokenizer **does not** register them as added/ +special tokens, so a string like ``""`` BPE-fragments into seven +pieces during ``encode``. Calling ``ensure_loc_tokens`` on the tokenizer +promotes the existing entries to single-token match mode without +reassigning their IDs (no vocab growth, no model-side resize). Every test +below exercises the post-promotion behavior — the public contract that +`PI05Policy.__init__` relies on. + +Marked `slow` because the tokenizer is fetched from the HF Hub on first +run. The model itself is NOT downloaded — only `AutoTokenizer` files. 
+""" + +from __future__ import annotations + +import pytest +from transformers import AutoTokenizer + +from opentau.datasets.grounding.loc_codec import xyxy_to_loc_tokens +from opentau.datasets.grounding.tokenizer_utils import LOC_TOKENS, ensure_loc_tokens + + +@pytest.fixture(scope="module") +def paligemma_tokenizer(): + """Module-scoped tokenizer with `` tokens already promoted.""" + tok = AutoTokenizer.from_pretrained("google/paligemma-3b-pt-224") + ensure_loc_tokens(tok) + return tok + + +@pytest.mark.slow +def test_ensure_loc_tokens_does_not_grow_vocab() -> None: + """No new IDs assigned: the strings already live at 256000..257023.""" + tok = AutoTokenizer.from_pretrained("google/paligemma-3b-pt-224") + assert ensure_loc_tokens(tok) == 0 + + +@pytest.mark.slow +def test_loc0000_is_single_token_after_promotion(paligemma_tokenizer) -> None: + ids = paligemma_tokenizer.encode("", add_special_tokens=False) + assert len(ids) == 1 + + +@pytest.mark.slow +def test_loc_token_ids_match_reserved_block(paligemma_tokenizer) -> None: + """Promotion preserves the reserved IDs at 256000..257023.""" + assert paligemma_tokenizer.get_vocab()[""] == 256000 + assert paligemma_tokenizer.get_vocab()[""] == 257023 + + +@pytest.mark.slow +def test_loc_token_ids_are_contiguous(paligemma_tokenizer) -> None: + """`` sits one ID after `` — sanity on the block.""" + id0 = paligemma_tokenizer.encode("", add_special_tokens=False)[0] + id1 = paligemma_tokenizer.encode("", add_special_tokens=False)[0] + assert id1 == id0 + 1 + + +@pytest.mark.slow +def test_bbox_postfix_round_trips(paligemma_tokenizer) -> None: + """A 4-loc bbox postfix encodes + decodes without losing any loc token.""" + img_w, img_h = 1024, 1024 + postfix = xyxy_to_loc_tokens((10.0, 20.0, 30.0, 40.0), img_w, img_h) + " dog" + + ids = paligemma_tokenizer.encode(postfix, add_special_tokens=False) + decoded = paligemma_tokenizer.decode(ids) + + for tok in LOC_TOKENS: + if tok in postfix: + assert tok in decoded + + 
+@pytest.mark.slow +def test_ensure_loc_tokens_is_idempotent(paligemma_tokenizer) -> None: + """Calling promotion a second time on an already-promoted tokenizer + does not grow the vocab or change behavior.""" + before_size = len(paligemma_tokenizer) + assert ensure_loc_tokens(paligemma_tokenizer) == 0 + assert len(paligemma_tokenizer) == before_size + + +@pytest.mark.slow +def test_all_1024_loc_tokens_are_present(paligemma_tokenizer) -> None: + vocab = paligemma_tokenizer.get_vocab() + for tok in LOC_TOKENS: + assert tok in vocab, f"{tok} missing from PaliGemma vocab" From 5fc77c41ba54e7120a7058017bc29ed26241f34f Mon Sep 17 00:00:00 2001 From: Shuheng Liu Date: Sun, 3 May 2026 20:41:40 -0700 Subject: [PATCH 2/5] test(pi05, pi06): GPU regression for tokens in response MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Confirms the loc-token wiring added to PI05Policy / PI06Policy / their inner FlowMatching modules works end-to-end on GPU. - pi05: builds the policy, asserts both tokenizer instances encode `` and `` to a single ID each, then runs one forward pass with a four-loc bbox postfix and asserts MSE / CE are finite. - pi06: bare `google/gemma-3-4b-pt` tokenizer has no loc strings; after policy construction the inner tokenizer has grown by exactly 1024, loc tokens encode as single IDs, the Gemma 3 input-embedding row count matches the new tokenizer length, and the LM head output dim has been resized in lockstep. Closes with a forward pass on a loc-token-bearing response and asserts finite loss. Both tests are `@pytest.mark.gpu` + `@pytest.mark.slow` — they run on g6.12xlarge nightly and on the worktree's GPU box. 
--- tests/policies/test_pi05.py | 49 +++++++++++++++++++++++ tests/policies/test_pi06.py | 78 +++++++++++++++++++++++++++++++++++++ 2 files changed, 127 insertions(+) diff --git a/tests/policies/test_pi05.py b/tests/policies/test_pi05.py index a0e9583c..86f7e93c 100644 --- a/tests/policies/test_pi05.py +++ b/tests/policies/test_pi05.py @@ -571,3 +571,52 @@ def capture_embed_suffix_select_action(*args, **kwargs): ) assert action.shape == (1, policy.config.max_action_dim) + + +# Loc-token regression — confirm the `` strings flow through pi05's +# tokenizer + embedding + response_ce_loss path without shape errors and +# produce a finite loss. Guarded by GPU because instantiating the full +# PaliGemma backbone is heavy. + + +@pytest.mark.gpu +@pytest.mark.slow +def test_pi05_loc_tokens_in_response_produce_finite_loss(pi05_training_config, lerobot_dataset_metadata): + """π0.5: a response containing `` should encode each loc string + as a single token (not BPE-fragment), get embedded by the existing + PaliGemma `embed_language_tokens`, and produce a finite `response_ce_loss` + on a one-batch forward pass. Regression for the `ensure_loc_tokens` + promotion wired in `PI05Policy.__init__` / `PI05FlowMatching.__init__`.""" + + config = pi05_training_config.policy + policy = PI05Policy(config, dataset_stats=lerobot_dataset_metadata.stats) + + # The promotion makes a single-token match on both tokenizer + # instances — the policy-level one and the inner FlowMatching one. 
+ for tok in (policy.language_tokenizer, policy.model.language_tokenizer): + assert len(tok.encode("", add_special_tokens=False)) == 1 + assert len(tok.encode("", add_special_tokens=False)) == 1 + + batch_size = 1 + batch = { + "camera0": torch.randn(batch_size, 3, 224, 224), + "camera1": torch.randn(batch_size, 3, 224, 224), + "state": torch.randn(batch_size, config.max_state_dim), + "actions": torch.randn(batch_size, config.chunk_size, config.max_action_dim), + "prompt": ["detect red block"], + "response": [" red block"], + "img_is_pad": torch.zeros(batch_size, 2, dtype=torch.bool), + "action_is_pad": torch.zeros(batch_size, config.chunk_size, dtype=torch.bool), + } + + policy.to(dtype=torch.bfloat16, device="cuda") + batch_cuda = { + k: v.to("cuda", non_blocking=True, dtype=torch.bfloat16) if isinstance(v, torch.Tensor) else v + for k, v in batch.items() + } + batch_cuda["action_is_pad"] = batch_cuda["action_is_pad"].to(dtype=torch.bool) + + loss = policy.forward(batch_cuda) + assert isinstance(loss, dict) + assert "MSE" in loss and "CE" in loss + assert all(v.isfinite() for v in loss.values()), f"Non-finite loss with loc tokens: {loss}" diff --git a/tests/policies/test_pi06.py b/tests/policies/test_pi06.py index 5bf32a57..68231c0c 100644 --- a/tests/policies/test_pi06.py +++ b/tests/policies/test_pi06.py @@ -531,3 +531,81 @@ def test_complete_pi06_pipeline_integration_smoke(lerobot_dataset_metadata): assert isinstance(loss, dict) assert "MSE" in loss and "CE" in loss assert all(v.isfinite() for v in loss.values()) + + +@pytest.mark.gpu +@pytest.mark.slow +def test_pi06_loc_tokens_extend_vocab_and_resize_embeddings(lerobot_dataset_metadata): + """π0.6: confirms the unconditional Gemma 3 vocab extension wired into + `PI06FlowMatching.__init__`. The bare `google/gemma-3-4b-pt` tokenizer + has no `` tokens; after policy construction: + + * The tokenizer vocab grows by exactly 1024. + * The Gemma 3 input embedding row count matches the new tokenizer length. 
+ * The LM head output dim grows by 1024. + * `` encodes to a single token id. + * A forward pass with a loc-token-bearing response produces finite loss. + """ + from transformers import AutoTokenizer + + from opentau.policies.pi06.modeling_pi06 import PI06Policy + + config = PI06Config( + max_state_dim=32, + max_action_dim=32, + chunk_size=10, + n_action_steps=10, + discrete_action_max_length=32, + predict_response=True, + ) + + from opentau.configs.types import FeatureType + from opentau.datasets.utils import dataset_to_policy_features + + features = dataset_to_policy_features( + { + "state": {"shape": (32,), "dtype": "float32"}, + "actions": {"shape": (10, 32), "dtype": "float32"}, + "camera0": {"shape": (3, 448, 448), "dtype": "image"}, + "camera1": {"shape": (3, 448, 448), "dtype": "image"}, + } + ) + config.output_features = {k: ft for k, ft in features.items() if ft.type is FeatureType.ACTION} + config.input_features = {k: ft for k, ft in features.items() if k not in config.output_features} + + bare_tok_size = len(AutoTokenizer.from_pretrained("google/gemma-3-4b-pt")) + + policy = PI06Policy(config, dataset_stats=lerobot_dataset_metadata.stats) + + inner_tok = policy.model.language_tokenizer + assert len(inner_tok) == bare_tok_size + 1024, ( + f"Expected vocab size {bare_tok_size + 1024} after extension, got {len(inner_tok)}" + ) + assert len(inner_tok.encode("", add_special_tokens=False)) == 1 + assert len(inner_tok.encode("", add_special_tokens=False)) == 1 + + gemma3 = policy.model.gemma3_with_expert.gemma3 + assert gemma3.get_input_embeddings().num_embeddings == len(inner_tok) + # The LM head output dim should also have grown (tied weights with the + # input embedding in Gemma 3 by default). + output_emb = gemma3.get_output_embeddings() + if output_emb is not None: + # Linear layer: out_features matches the resized vocab. 
+ assert output_emb.out_features == len(inner_tok) + + policy.to(dtype=torch.bfloat16, device="cuda") + batch = { + "camera0": torch.randn(1, 3, 448, 448, dtype=torch.bfloat16, device="cuda"), + "camera1": torch.randn(1, 3, 448, 448, dtype=torch.bfloat16, device="cuda"), + "state": torch.randn(1, 32, dtype=torch.bfloat16, device="cuda"), + "actions": torch.randn(1, 10, 32, dtype=torch.bfloat16, device="cuda"), + "prompt": ["detect red block"], + "response": [" red block"], + "img_is_pad": torch.zeros(1, 2, dtype=torch.bool, device="cuda"), + "action_is_pad": torch.zeros(1, 10, dtype=torch.bool, device="cuda"), + } + + loss = policy.forward(batch) + assert isinstance(loss, dict) + assert "MSE" in loss and "CE" in loss + assert all(v.isfinite() for v in loss.values()), f"Non-finite loss with loc tokens: {loss}" From e2f30fecfe0d8101580b60a155a96a3739359b5b Mon Sep 17 00:00:00 2001 From: Shuheng Liu Date: Sun, 3 May 2026 20:46:21 -0700 Subject: [PATCH 3/5] test(pi06): override actions-stats shape in loc-tokens GPU test MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The shared `lerobot_dataset_metadata` fixture carries actions stats sized (50, 32) for the default `chunk_size`. The new loc-tokens GPU regression runs at `chunk_size=10` to stay small, so the Normalize buffer is built from (50, 32) stats while the live actions tensor is (B, 10, 32) — and `(actions - min) / (max - min + EPS)` errors at dim=1. Inline the actions-stats override before calling `PI06Policy(config, ...)` so the buffer matches the test's `chunk_size`. Same pattern that the existing pi06 smoke test uses on `fix/pi06-paper-alignment`; keeping it inline here so this PR stays orthogonal to that one. 
--- tests/policies/test_pi06.py | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/tests/policies/test_pi06.py b/tests/policies/test_pi06.py index 68231c0c..7001bfaa 100644 --- a/tests/policies/test_pi06.py +++ b/tests/policies/test_pi06.py @@ -575,7 +575,24 @@ def test_pi06_loc_tokens_extend_vocab_and_resize_embeddings(lerobot_dataset_meta bare_tok_size = len(AutoTokenizer.from_pretrained("google/gemma-3-4b-pt")) - policy = PI06Policy(config, dataset_stats=lerobot_dataset_metadata.stats) + # `lerobot_dataset_metadata` carries actions stats shaped (50, 32) — the + # default chunk_size. This test runs at chunk_size=10 to stay small, so + # override the action-stats arrays to (10, 32) before Normalize buffers + # are constructed. Otherwise `(actions - min) / (max - min + EPS)` errors + # with a shape mismatch (actions is (B, 10, 32); buffer would be (50, 32)). + import copy + + import numpy as np + + dataset_stats = copy.deepcopy(lerobot_dataset_metadata.stats) + for k in ("max", "mean", "min", "std"): + dataset_stats["actions"][k] = np.full( + (config.chunk_size, 32), + float(dataset_stats["actions"][k].flatten()[0]), + dtype=np.float32, + ) + + policy = PI06Policy(config, dataset_stats=dataset_stats) inner_tok = policy.model.language_tokenizer assert len(inner_tok) == bare_tok_size + 1024, ( From c68a488d5a3790302873400630452fcf66e1f819 Mon Sep 17 00:00:00 2001 From: Shuheng Liu Date: Sun, 3 May 2026 20:51:06 -0700 Subject: [PATCH 4/5] test(pi05, pi06): release GPU memory at end of loc-token tests Loading PaliGemma 3B (~6 GB) and Gemma 3 4B (~8 GB) onto a single 32 GB GPU and leaving them resident across tests OOMs the next allocation. Wrap the forward pass in try/finally and `del policy; empty_cache()` at the end so the loc-tokens regressions can run alongside the broader GPU suite on a single-card dev box. 
--- tests/policies/test_pi05.py | 14 ++++++++++---- tests/policies/test_pi06.py | 14 ++++++++++---- 2 files changed, 20 insertions(+), 8 deletions(-) diff --git a/tests/policies/test_pi05.py b/tests/policies/test_pi05.py index 86f7e93c..fc6eb94c 100644 --- a/tests/policies/test_pi05.py +++ b/tests/policies/test_pi05.py @@ -616,7 +616,13 @@ def test_pi05_loc_tokens_in_response_produce_finite_loss(pi05_training_config, l } batch_cuda["action_is_pad"] = batch_cuda["action_is_pad"].to(dtype=torch.bool) - loss = policy.forward(batch_cuda) - assert isinstance(loss, dict) - assert "MSE" in loss and "CE" in loss - assert all(v.isfinite() for v in loss.values()), f"Non-finite loss with loc tokens: {loss}" + try: + loss = policy.forward(batch_cuda) + assert isinstance(loss, dict) + assert "MSE" in loss and "CE" in loss + assert all(v.isfinite() for v in loss.values()), f"Non-finite loss with loc tokens: {loss}" + finally: + # Free ~6 GB of PaliGemma weights so adjacent GPU tests in the same + # process don't OOM on a single-GPU dev box. + del policy + torch.cuda.empty_cache() diff --git a/tests/policies/test_pi06.py b/tests/policies/test_pi06.py index 7001bfaa..06dc46d0 100644 --- a/tests/policies/test_pi06.py +++ b/tests/policies/test_pi06.py @@ -622,7 +622,13 @@ def test_pi06_loc_tokens_extend_vocab_and_resize_embeddings(lerobot_dataset_meta "action_is_pad": torch.zeros(1, 10, dtype=torch.bool, device="cuda"), } - loss = policy.forward(batch) - assert isinstance(loss, dict) - assert "MSE" in loss and "CE" in loss - assert all(v.isfinite() for v in loss.values()), f"Non-finite loss with loc tokens: {loss}" + try: + loss = policy.forward(batch) + assert isinstance(loss, dict) + assert "MSE" in loss and "CE" in loss + assert all(v.isfinite() for v in loss.values()), f"Non-finite loss with loc tokens: {loss}" + finally: + # Free ~8 GB of Gemma 3 weights so adjacent GPU tests in the same + # process don't OOM on a single-GPU dev box. 
+ del policy + torch.cuda.empty_cache() From f126271737419d64c89a6c28735b9e8dcdebc8ce Mon Sep 17 00:00:00 2001 From: "claude[bot]" <41898282+claude[bot]@users.noreply.github.com> Date: Mon, 4 May 2026 04:49:28 +0000 Subject: [PATCH 5/5] [claude-fix] address review feedback on #237 MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - addresses @claude (loc_codec parser): segment-aware loc_tokens_to_xyxy / loc_tokens_to_points — split on `;` so a malformed segment cannot misalign every subsequent box. Update test_partial_token_count_drops_orphan_pairs to the new contract and add regressions for malformed-segment isolation. - addresses @claude (RNG hazard in ensure_loc_tokens): wrap the resize_token_embeddings call in torch.random.fork_rng with a fixed internal seed so embedding init is reproducible and does not consume entropy from the caller's stream. Add Gemma 3 tests asserting RNG isolation and bit-identical new rows across outer seeds. - addresses @claude (duplicate Gemma 3 / PaliGemma tokenizer load): PI06Policy / PI05Policy now share a single tokenizer instance with their inner FlowMatching; ensure_loc_tokens runs once. Existing pi05/pi06 GPU tests assert the shared identity. 
tests: passed — pytest -m "not gpu" -n auto tests/datasets/ tests/policies/test_pi05.py tests/policies/test_pi06.py tests/configs/ (16 HF-gated tests skipped locally for lack of HF auth; CI runs them with auth) Co-Authored-By: Claude Opus 4.7 (1M context) --- src/opentau/datasets/grounding/loc_codec.py | 38 +++++++++---- .../datasets/grounding/tokenizer_utils.py | 29 +++++++++- src/opentau/policies/pi05/modeling_pi05.py | 40 ++++++++++---- src/opentau/policies/pi06/modeling_pi06.py | 43 ++++++++++----- tests/datasets/test_loc_codec.py | 53 ++++++++++++++++--- tests/datasets/test_loc_tokens_gemma3.py | 49 +++++++++++++++++ tests/policies/test_pi05.py | 13 +++-- tests/policies/test_pi06.py | 4 ++ 8 files changed, 221 insertions(+), 48 deletions(-) diff --git a/src/opentau/datasets/grounding/loc_codec.py b/src/opentau/datasets/grounding/loc_codec.py index b9b60db0..72b84d35 100644 --- a/src/opentau/datasets/grounding/loc_codec.py +++ b/src/opentau/datasets/grounding/loc_codec.py @@ -49,6 +49,11 @@ _LOC_TOKEN_RE = re.compile(r"") +# Segment separator the encoder emits between adjacent box/point entries +# (e.g. ``"<...> dog ; <...> cat"``). Decoders split on this so that a +# malformed segment cannot misalign every subsequent one. +SEGMENT_SEPARATOR = ";" + def _quantize(coord: float, extent: float) -> int: """Map a single pixel coordinate to a `[0, 1023]` bin index. @@ -117,22 +122,28 @@ def point_to_loc_tokens(x: float, y: float, img_w: int, img_h: int) -> str: def loc_tokens_to_xyxy(s: str, img_w: int, img_h: int) -> list[tuple[float, float, float, float]]: """Parse a string of loc tokens into `(x_min, y_min, x_max, y_max)` pixel boxes. - Tolerant: any segment that does not contain a multiple of four loc tokens - is dropped silently. Garbage strings or partial decodes return ``[]``. + Tolerant and segment-aware: the input is split on the encoder's segment + separator (``;``), and each segment must contribute exactly four loc + tokens to yield a box. 
A segment with any other count (0, 1, 2, 3, 5, + ...) is dropped silently — its tokens do NOT spill into the next + segment, so a single malformed box cannot misalign every subsequent one. + Garbage strings or partial decodes return ``[]``. Args: s: A string that may contain `` tokens, e.g. a decoded - response. Non-loc text is ignored. + response. Non-loc text within a segment is ignored. img_w: Original image width in pixels. img_h: Original image height in pixels. Returns: A list of `(x_min, y_min, x_max, y_max)` tuples in pixel coordinates. """ - bins = [int(m) for m in _LOC_TOKEN_RE.findall(s)] boxes: list[tuple[float, float, float, float]] = [] - for i in range(0, len(bins) - 3, 4): - y_min_b, x_min_b, y_max_b, x_max_b = bins[i : i + 4] + for segment in s.split(SEGMENT_SEPARATOR): + bins = [int(m) for m in _LOC_TOKEN_RE.findall(segment)] + if len(bins) != 4: + continue + y_min_b, x_min_b, y_max_b, x_max_b = bins boxes.append( ( _dequantize(x_min_b, img_w), @@ -147,12 +158,17 @@ def loc_tokens_to_xyxy(s: str, img_w: int, img_h: int) -> list[tuple[float, floa def loc_tokens_to_points(s: str, img_w: int, img_h: int) -> list[tuple[float, float]]: """Parse a string of loc tokens into `(x, y)` pixel points. - Pairs of loc tokens are decoded as `(y, x)` per the PaliGemma convention - and returned as `(x, y)`. Lone trailing tokens are dropped. + Tolerant and segment-aware in the same sense as `loc_tokens_to_xyxy`: + the input is split on ``;``, and each segment must contribute exactly + two loc tokens (in `(y, x)` order per the PaliGemma convention) to + yield a point. Segments with any other count are dropped — a malformed + segment cannot shift later ones. 
""" - bins = [int(m) for m in _LOC_TOKEN_RE.findall(s)] points: list[tuple[float, float]] = [] - for i in range(0, len(bins) - 1, 2): - y_b, x_b = bins[i : i + 2] + for segment in s.split(SEGMENT_SEPARATOR): + bins = [int(m) for m in _LOC_TOKEN_RE.findall(segment)] + if len(bins) != 2: + continue + y_b, x_b = bins points.append((_dequantize(x_b, img_w), _dequantize(y_b, img_h))) return points diff --git a/src/opentau/datasets/grounding/tokenizer_utils.py b/src/opentau/datasets/grounding/tokenizer_utils.py index de5425f8..f21d5fbc 100644 --- a/src/opentau/datasets/grounding/tokenizer_utils.py +++ b/src/opentau/datasets/grounding/tokenizer_utils.py @@ -40,10 +40,20 @@ import logging +import torch from transformers.tokenization_utils_base import AddedToken LOC_TOKENS: tuple[str, ...] = tuple(f"" for i in range(1024)) +# Fixed seed used to initialize the new `` embedding rows on +# Gemma 3. Hardcoded (not a tunable) so policy construction is bit-stable +# regardless of when in setup `ensure_loc_tokens` fires. CLAUDE.md hard +# rule #3 (deterministic seeded reruns of the training loop) depends on +# this — if the resize were to consume the active RNG, two seeded runs +# could diverge purely from where the helper is called, even though the +# loop seed is identical. +_LOC_EMBEDDING_INIT_SEED: int = 0x10CC0DE + _logger = logging.getLogger(__name__) @@ -58,6 +68,15 @@ def ensure_loc_tokens(tokenizer, model=None) -> int: and tied LM head are resized via ``model.resize_token_embeddings`` when a model handle is supplied. + The embedding resize is wrapped in a snapshot/restore of the global torch + RNG (CPU + all visible CUDA devices) and re-seeded with a fixed constant + inside that block. This guarantees the 1024 new rows are bit-identical + across runs regardless of when in construction the helper is called, and + leaves the outer RNG state untouched so downstream consumers (the loop + seed, dataset shuffler, dropout, etc.) are not perturbed. 
Without this, + construction-time embedding init would couple to the active RNG and + silently violate CLAUDE.md hard rule #3 (deterministic seeded reruns). + Safe to call multiple times — once the strings are registered as added tokens, subsequent calls neither grow the vocab nor resize. @@ -82,7 +101,15 @@ def ensure_loc_tokens(tokenizer, model=None) -> int: n_new_ids = len(tokenizer) - initial_len if n_new_ids > 0 and model is not None: - model.resize_token_embeddings(len(tokenizer)) + # Fork RNG so the resize's random init does not consume entropy from + # the caller's RNG stream and is reproducible across runs. We fork + # CPU + every visible CUDA device; the seed inside the fork is fixed. + cuda_devices = list(range(torch.cuda.device_count())) if torch.cuda.is_available() else [] + with torch.random.fork_rng(devices=cuda_devices, enabled=True): + torch.manual_seed(_LOC_EMBEDDING_INIT_SEED) + if cuda_devices: + torch.cuda.manual_seed_all(_LOC_EMBEDDING_INIT_SEED) + model.resize_token_embeddings(len(tokenizer)) if n_new_ids > 0: _logger.info( diff --git a/src/opentau/policies/pi05/modeling_pi05.py b/src/opentau/policies/pi05/modeling_pi05.py index 41f83536..e340f88a 100644 --- a/src/opentau/policies/pi05/modeling_pi05.py +++ b/src/opentau/policies/pi05/modeling_pi05.py @@ -265,20 +265,22 @@ def __init__( config.output_features, config.normalization_mapping, dataset_stats ) + # The same PaliGemma tokenizer instance is shared with the inner + # `PI05FlowMatching`. The single `ensure_loc_tokens` call inside the + # inner ctor promotes the reserved .. entries on + # both layers at once — no second load, no risk of revision drift. self.language_tokenizer = AutoTokenizer.from_pretrained("google/paligemma-3b-pt-224") - # PaliGemma reserves .. at IDs 256000..257023, but - # the bare HF tokenizer does not register them as added/special - # tokens — a string "" otherwise BPE-fragments into seven - # pieces. 
This call promotes the reserved entries to single-token - # match mode (no new IDs, no embedding resize on PaliGemma). - ensure_loc_tokens(self.language_tokenizer) self.discrete_action_processor = AutoProcessor.from_pretrained( "physical-intelligence/fast", trust_remote_code=True ) # Get vocab size from processor discrete_action_vocab_size = getattr(self.discrete_action_processor, "vocab_size", None) - self.model = PI05FlowMatching(config, discrete_action_vocab_size=discrete_action_vocab_size) + self.model = PI05FlowMatching( + config, + discrete_action_vocab_size=discrete_action_vocab_size, + language_tokenizer=self.language_tokenizer, + ) self.reset() @@ -919,12 +921,22 @@ class PI05FlowMatching(nn.Module): └──────────────────────────────────────────┘ """ - def __init__(self, config: PI05Config, discrete_action_vocab_size: int | None = None): + def __init__( + self, + config: PI05Config, + discrete_action_vocab_size: int | None = None, + language_tokenizer: AutoTokenizer | None = None, + ): """Initializes the PI05FlowMatching model. Args: config: Model configuration. discrete_action_vocab_size: Size of the discrete action vocabulary. + language_tokenizer: Optional pre-loaded PaliGemma tokenizer to share + with the enclosing `PI05Policy`. When ``None`` (e.g. unit tests + that construct the inner module directly) the tokenizer is + loaded here. Either way, the same instance is used by both + layers, and `ensure_loc_tokens` runs once. """ super().__init__() self.config = config @@ -950,9 +962,15 @@ def __init__(self, config: PI05Config, discrete_action_vocab_size: int | None = self.time_mlp_in = nn.Linear(self.config.proj_width, self.config.proj_width) self.time_mlp_out = nn.Linear(self.config.proj_width, self.config.proj_width) - self.language_tokenizer = AutoTokenizer.from_pretrained("google/paligemma-3b-pt-224") - # See PI05Policy.__init__ — promotes the reserved entries - # to single-token match mode on this tokenizer instance as well. 
+ if language_tokenizer is None: + language_tokenizer = AutoTokenizer.from_pretrained("google/paligemma-3b-pt-224") + self.language_tokenizer = language_tokenizer + # PaliGemma reserves <loc0000>..<loc1023> at IDs 256000..257023, but + # the bare HF tokenizer does not register them as added/special + # tokens — a string "<locNNNN>" otherwise BPE-fragments into seven + # pieces. This call promotes the reserved entries to single-token + # match mode (no new IDs, no embedding resize on PaliGemma) and + # mutates the shared tokenizer instance for `PI05Policy` too. ensure_loc_tokens(self.language_tokenizer) def sample_noise(self, shape: tuple[int, ...], device: torch.device | str) -> Tensor: diff --git a/src/opentau/policies/pi06/modeling_pi06.py b/src/opentau/policies/pi06/modeling_pi06.py index 0b4939bb..ad75ceb4 100644 --- a/src/opentau/policies/pi06/modeling_pi06.py +++ b/src/opentau/policies/pi06/modeling_pi06.py @@ -246,19 +246,24 @@ def __init__( config.output_features, config.normalization_mapping, dataset_stats ) - # π0.6 uses Gemma 3's tokenizer. We fall back to the paligemma tokenizer - # only if the Hub download fails at module import time in an offline CI, - # so users still get a useful error rather than silent drift. + # π0.6 uses Gemma 3's tokenizer. The same instance is shared with the + # inner `PI06FlowMatching` so vocab extension happens exactly once and + # token IDs cannot drift between the two layers (e.g. if anyone + # introduces a non-deterministic adder, two independent loads at + # different revisions, or reorders the calls). The single + # `ensure_loc_tokens` call inside the inner ctor both extends this + # tokenizer and resizes the model embeddings together. self.language_tokenizer = AutoTokenizer.from_pretrained("google/gemma-3-4b-pt") - # Tokenizer-side extension; the matching embedding/LM-head resize fires - # inside PI06FlowMatching.__init__, where the Gemma 3 model handle exists. 
- ensure_loc_tokens(self.language_tokenizer) self.discrete_action_processor = AutoProcessor.from_pretrained( "physical-intelligence/fast", trust_remote_code=True ) discrete_action_vocab_size = getattr(self.discrete_action_processor, "vocab_size", None) - self.model = PI06FlowMatching(config, discrete_action_vocab_size=discrete_action_vocab_size) + self.model = PI06FlowMatching( + config, + discrete_action_vocab_size=discrete_action_vocab_size, + language_tokenizer=self.language_tokenizer, + ) self.reset() @@ -714,12 +719,22 @@ class PI06FlowMatching(nn.Module): └──────────────────────────────────────────┘ """ - def __init__(self, config: PI06Config, discrete_action_vocab_size: int | None = None): + def __init__( + self, + config: PI06Config, + discrete_action_vocab_size: int | None = None, + language_tokenizer: AutoTokenizer | None = None, + ): """Initializes the PI06FlowMatching model. Args: config: `PI06Config` instance. discrete_action_vocab_size: FAST tokenizer vocabulary size. + language_tokenizer: Optional pre-loaded Gemma 3 tokenizer to share + with the enclosing `PI06Policy`. When ``None`` (e.g. unit tests + that construct the inner module directly) the tokenizer is + loaded here. Either way, the same instance is used by both + layers — there is no second copy to fall out of sync. """ super().__init__() self.config = config @@ -740,17 +755,19 @@ def __init__(self, config: PI06Config, discrete_action_vocab_size: int | None = self.time_mlp_in = nn.Linear(self.config.proj_width, self.config.proj_width) self.time_mlp_out = nn.Linear(self.config.proj_width, self.config.proj_width) - self.language_tokenizer = AutoTokenizer.from_pretrained("google/gemma-3-4b-pt") + if language_tokenizer is None: + language_tokenizer = AutoTokenizer.from_pretrained("google/gemma-3-4b-pt") + self.language_tokenizer = language_tokenizer # π0.6 uses Gemma 3, whose stock tokenizer does NOT carry the 1024 # <loc0000>..<loc1023> grounding tokens that PaliGemma reserves. 
We # unconditionally extend the vocab here so any grounding/VQA training # data containing loc tokens flows through the same response_ce_loss # path as on PaliGemma backbones. The new embedding rows are random- - # init — they learn from grounding data on first use; there is NO - # PaliGemma loc-embedding transfer. The resize must happen after + # init under a forked, fixed-seed RNG (see `ensure_loc_tokens`); there + # is NO PaliGemma loc-embedding transfer. The resize must happen after # `Gemma3WithExpertModel(...)` has already loaded the public Gemma 3 - # weights (line above), so the original 256K rows survive and only - # the 1024 new rows are freshly initialized. + # weights (above), so the original 256K rows survive and only the 1024 + # new rows are freshly initialized. ensure_loc_tokens(self.language_tokenizer, model=self.gemma3_with_expert.gemma3) def sample_noise(self, shape: tuple[int, ...], device: torch.device | str) -> Tensor: diff --git a/tests/datasets/test_loc_codec.py b/tests/datasets/test_loc_codec.py index 2ce147da..e25a8093 100644 --- a/tests/datasets/test_loc_codec.py +++ b/tests/datasets/test_loc_codec.py @@ -112,17 +112,56 @@ def test_garbage_input_returns_empty() -> None: assert loc_tokens_to_points("garbage with no tokens", 100, 100) == [] -def test_partial_token_count_drops_orphan_pairs() -> None: - """A response containing 6 loc tokens decodes to 1 box (4 tokens consumed). - - The trailing pair is silently dropped — this matches the codec's tolerant - contract documented in `loc_tokens_to_xyxy`'s docstring. +def test_segment_with_non_four_token_count_is_dropped() -> None: + """A single segment with anything other than 4 loc tokens yields 0 boxes. + + The decoder is segment-aware (splits on ``;``); a malformed segment is + dropped silently rather than spilling its tokens into a neighbour. 
Six + loc tokens in one segment is malformed and produces 0 boxes — not 1 + box plus two orphans, which would be the buggy behaviour of a global + pairing scheme. """ img_w, img_h = 1024, 1024 - # Six loc tokens — one full box plus two orphans. response = "" boxes = loc_tokens_to_xyxy(response, img_w, img_h) - assert len(boxes) == 1 + assert boxes == [] + + +def test_malformed_segment_does_not_misalign_following_segments() -> None: + """Regression: a 5-loc-token segment must NOT shift the boundary of the next one. + + If the parser collected loc tokens globally and grouped them in fours, + the second box would absorb the orphan from the first segment and end + up encoding completely wrong coordinates — silent box-to-label + misattribution at eval time. Segment-aware parsing drops the bad + segment and decodes the next one cleanly. + """ + img_w, img_h = 1024, 1024 + good_box = xyxy_to_loc_tokens((100.0, 200.0, 300.0, 400.0), img_w, img_h) + # 5 loc tokens in the first segment — malformed. 
+ bad_segment = " dog" + response = f"{bad_segment} ; {good_box} cat" + + decoded = loc_tokens_to_xyxy(response, img_w, img_h) + assert len(decoded) == 1 + x_min, y_min, x_max, y_max = decoded[0] + assert abs(x_min - 100.0) <= _tol(img_w) + assert abs(y_min - 200.0) <= _tol(img_h) + assert abs(x_max - 300.0) <= _tol(img_w) + assert abs(y_max - 400.0) <= _tol(img_h) + + +def test_points_segment_aware() -> None: + """`loc_tokens_to_points` is segment-aware — a 3-token bad segment doesn't shift the next.""" + img_w, img_h = 1024, 1024 + good_point = point_to_loc_tokens(500.0, 500.0, img_w, img_h) + response = f" noise ; {good_point} target" + + decoded = loc_tokens_to_points(response, img_w, img_h) + assert len(decoded) == 1 + x, y = decoded[0] + assert abs(x - 500.0) <= _tol(img_w) + assert abs(y - 500.0) <= _tol(img_h) def test_codec_uses_original_image_dims_not_post_resize() -> None: diff --git a/tests/datasets/test_loc_tokens_gemma3.py b/tests/datasets/test_loc_tokens_gemma3.py index 3649bc7b..fb651020 100644 --- a/tests/datasets/test_loc_tokens_gemma3.py +++ b/tests/datasets/test_loc_tokens_gemma3.py @@ -135,3 +135,52 @@ def test_resize_does_not_fire_on_idempotent_call() -> None: assert n_added == 0 assert fake.resize_calls == [] + + +@pytest.mark.slow +def test_ensure_loc_tokens_does_not_perturb_caller_rng() -> None: + """The resize must not consume entropy from the caller's RNG stream. + + Two `torch.randn` draws bracketing `ensure_loc_tokens(..., model=...)` + must match what the same outer RNG produces without the helper running + in between. Otherwise the `model.resize_token_embeddings` random init + would couple construction order to RNG state and silently violate + CLAUDE.md hard rule #3 (deterministic seeded reruns). 
+ """ + tok = _fresh_gemma3_tokenizer() + fake = _FakeResizableModel(initial_vocab=len(tok)) + + torch.manual_seed(123) + expected = torch.randn(8) + + torch.manual_seed(123) + ensure_loc_tokens(tok, model=fake) + actual = torch.randn(8) + + assert torch.equal(expected, actual), ( + "ensure_loc_tokens leaked the resize's RNG draws into the caller's stream" + ) + + +@pytest.mark.slow +def test_ensure_loc_tokens_resize_is_seed_independent() -> None: + """The resize seeds deterministically, so two calls under different + outer RNG states must produce bit-identical new embedding rows. + """ + initial_vocab = len(_fresh_gemma3_tokenizer()) + + tok_a = _fresh_gemma3_tokenizer() + fake_a = _FakeResizableModel(initial_vocab=initial_vocab, hidden=16) + torch.manual_seed(7) + ensure_loc_tokens(tok_a, model=fake_a) + + tok_b = _fresh_gemma3_tokenizer() + fake_b = _FakeResizableModel(initial_vocab=initial_vocab, hidden=16) + torch.manual_seed(99) + ensure_loc_tokens(tok_b, model=fake_b) + + new_rows_a = fake_a.embed.weight[initial_vocab:] + new_rows_b = fake_b.embed.weight[initial_vocab:] + assert torch.equal(new_rows_a, new_rows_b), ( + "new embedding rows differ between runs — RNG snapshot/restore failed" + ) diff --git a/tests/policies/test_pi05.py b/tests/policies/test_pi05.py index fc6eb94c..9e196952 100644 --- a/tests/policies/test_pi05.py +++ b/tests/policies/test_pi05.py @@ -591,11 +591,14 @@ def test_pi05_loc_tokens_in_response_produce_finite_loss(pi05_training_config, l config = pi05_training_config.policy policy = PI05Policy(config, dataset_stats=lerobot_dataset_metadata.stats) - # The promotion makes a single-token match on both tokenizer - # instances — the policy-level one and the inner FlowMatching one. 
- for tok in (policy.language_tokenizer, policy.model.language_tokenizer): - assert len(tok.encode("", add_special_tokens=False)) == 1 - assert len(tok.encode("", add_special_tokens=False)) == 1 + # PI05Policy and PI05FlowMatching share a single tokenizer instance so + # token IDs cannot drift between the two layers. + assert policy.language_tokenizer is policy.model.language_tokenizer + # The promotion makes a single-token match on the (shared) + # tokenizer. + tok = policy.language_tokenizer + assert len(tok.encode("", add_special_tokens=False)) == 1 + assert len(tok.encode("", add_special_tokens=False)) == 1 batch_size = 1 batch = { diff --git a/tests/policies/test_pi06.py b/tests/policies/test_pi06.py index 06dc46d0..884a14ef 100644 --- a/tests/policies/test_pi06.py +++ b/tests/policies/test_pi06.py @@ -594,6 +594,10 @@ def test_pi06_loc_tokens_extend_vocab_and_resize_embeddings(lerobot_dataset_meta policy = PI06Policy(config, dataset_stats=dataset_stats) + # PI06Policy and PI06FlowMatching share a single tokenizer instance so + # the outer encodes index the inner model's resized embedding correctly + # by construction (no risk of revision drift between two loads). + assert policy.language_tokenizer is policy.model.language_tokenizer inner_tok = policy.model.language_tokenizer assert len(inner_tok) == bare_tok_size + 1024, ( f"Expected vocab size {bare_tok_size + 1024} after extension, got {len(inner_tok)}"