Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 4 additions & 5 deletions tests/configs/test_default.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,8 +205,8 @@ class TestPretrainedConfigCodec:
Without these handlers, ``PI07HighLevelPlannerConfig._save_pretrained``
raises ``Exception("No parser for object ...")`` when draccus tries to
serialise the nested ``vlm_config: Gemma3WithExpertConfig`` field β€”
which is the bug surfaced when bootstrapping the
``TensorAuto/pi07-{high,low}-untrained`` checkpoints.
which is the bug originally surfaced when bootstrapping a Ο€0.7
high-level / low-level planner checkpoint from public Gemma 3 weights.
"""

def test_encode_dispatches_via_to_dict_for_subclass(self):
Expand Down Expand Up @@ -242,9 +242,8 @@ def test_pi07_high_level_config_save_load_round_trip(self, tmp_path: Path):
``_save_pretrained`` -> ``PreTrainedConfig.from_pretrained``, with the
``vlm_config`` subtree preserved.

This is the actual failure mode that broke the
``TensorAuto/pi07-high-untrained`` build before the codec was
registered.
This is the actual failure mode that broke a Ο€0.7 high-level
planner checkpoint bootstrap before the codec was registered.
"""
from opentau.configs.policies import PreTrainedConfig
from opentau.policies.pi07.high_level_planner.configuration_pi07_high_level import (
Expand Down
174 changes: 0 additions & 174 deletions tests/policies/test_pi06.py
Original file line number Diff line number Diff line change
Expand Up @@ -901,177 +901,3 @@ def test_pi06_loc_tokens_extend_vocab_and_resize_embeddings(lerobot_dataset_meta
# process don't OOM on a single-GPU dev box.
del policy
torch.cuda.empty_cache()


# Published `TensorAuto/pi06-untrained` checkpoint β€” load + weight equivalence.
#
# These tests verify the published init checkpoint (issue #239): that it loads
# cleanly via `PI06Policy.from_pretrained("TensorAuto/pi06-untrained")` with no
# missing/unexpected keys, and that the VLM / SigLIP submodules are bit-for-bit
# identical to `google/gemma-3-4b-pt`. Heavy: each fixture loads ~10 GB on CPU,
# so they're gated behind `gpu` + `slow` to stay out of CPU CI. They do NOT
# actually require CUDA β€” comparison is on CPU in bf16 β€” but the markers double
# as a "needs heavy infra + network" gate matching the surrounding pi06 tests.


@pytest.fixture(scope="module")
def pi06_untrained_policy():
"""Load `TensorAuto/pi06-untrained` once on CPU; reused across the module."""
from opentau.policies.pi06.modeling_pi06 import PI06Policy

return PI06Policy.from_pretrained("TensorAuto/pi06-untrained")


_SIGLIP_POS_EMBED_KEY = "model.vision_tower.vision_model.embeddings.position_embedding.weight"


@pytest.fixture(scope="module")
def gemma3_4b_pt_aligned_state_dict(pi06_untrained_policy):
"""`google/gemma-3-4b-pt`'s state_dict, aligned to the policy:

- `ensure_loc_tokens` extends the embedding/LM head with the same 1024
`<locNNNN>` rows the policy carries (deterministic via fixed-seed RNG fork).
- The SigLIP `position_embedding` is bilinearly resampled from the published
4096 patches (896Γ—896) down to the policy's 1024 patches (448Γ—448) β€” the
same operation the build script applies before saving the checkpoint.

Returns a `dict[str, Tensor]` rather than the model itself because the
resampled position embedding has a different shape than the original
`nn.Embedding` parameter β€” so we can't mutate the model in place."""
from transformers import (
AutoTokenizer,
Gemma3ForConditionalGeneration, # noqa: I100
)

from opentau.datasets.grounding.tokenizer_utils import ensure_loc_tokens
from opentau.utils.vision_utils import bilinear_resample_pos_embed

model = Gemma3ForConditionalGeneration.from_pretrained("google/gemma-3-4b-pt", torch_dtype=torch.bfloat16)
tok = AutoTokenizer.from_pretrained("google/gemma-3-4b-pt")
ensure_loc_tokens(tok, model=model)

state = model.state_dict()
pol_n_patches = pi06_untrained_policy.model.gemma3_with_expert.gemma3.state_dict()[
_SIGLIP_POS_EMBED_KEY
].shape[0]
state[_SIGLIP_POS_EMBED_KEY] = bilinear_resample_pos_embed(
state[_SIGLIP_POS_EMBED_KEY], target_num_patches=pol_n_patches
)
del model
return state


def _diff_state_dicts_against_reference(
policy_state: dict[str, torch.Tensor],
reference_state: dict[str, torch.Tensor],
*,
name_filter,
) -> tuple[int, list[str]]:
"""Compare every entry of `reference_state` whose name passes `name_filter`
against the matching entry in `policy_state`. Returns `(checked, mismatches)`.

`torch.equal` is used (not `allclose`) because the policy weights came from
the same source bf16 tensors with no precision-changing transform β€” they
must be byte-identical, and any drift is a real bug worth surfacing.
"""
mismatches: list[str] = []
checked = 0
for name, ref in reference_state.items():
if not name_filter(name):
continue
pol = policy_state.get(name)
if pol is None:
mismatches.append(f"{name}: missing in policy state_dict")
continue
if pol.shape != ref.shape:
mismatches.append(f"{name}: shape mismatch {tuple(pol.shape)} vs {tuple(ref.shape)}")
continue
if not torch.equal(pol, ref):
max_abs = (pol.float() - ref.float()).abs().max().item()
mismatches.append(f"{name}: tensor not equal (max abs diff {max_abs:g})")
continue
checked += 1
return checked, mismatches


@pytest.mark.gpu
@pytest.mark.slow
def test_pi06_untrained_loads_with_no_missing_or_unexpected_keys(pi06_untrained_policy):
"""Every key in the loaded policy's `state_dict()` is either present in the
published `model.safetensors`, or shares storage with a key that is. The
second case covers Gemma 3's tied embed_tokens/lm_head, which
`save_model_as_safetensor` de-duplicates on save (only one name on disk;
both are populated at load time via the model's `_tie_weights()` hook).
No safetensors key may be unexpected.

Catches the regression we actually care about: a save flow that silently
drops a non-tied parameter (e.g. an action-expert layer) would surface
here because that parameter has its own data_ptr and so won't be in any
tied group."""
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file

ckpt_path = hf_hub_download("TensorAuto/pi06-untrained", "model.safetensors")
saved_keys = set(load_file(ckpt_path).keys())
loaded_state = pi06_untrained_policy.state_dict()
expected_keys = set(loaded_state.keys())

# Build groups of state_dict keys that share storage (tied weights).
ptr_to_names: dict[int, list[str]] = {}
for name, tensor in loaded_state.items():
ptr_to_names.setdefault(tensor.data_ptr(), []).append(name)
tied_groups = [names for names in ptr_to_names.values() if len(names) > 1]

# Each tied group must have exactly one representative on disk.
tied_group_violations = []
allowed_missing: set[str] = set()
for group in tied_groups:
on_disk = [n for n in group if n in saved_keys]
if len(on_disk) != 1:
tied_group_violations.append(
f"tied group {sorted(group)}: expected exactly 1 on disk, got {len(on_disk)}"
)
for n in group:
if n not in saved_keys:
allowed_missing.add(n)

missing = expected_keys - saved_keys - allowed_missing
unexpected = saved_keys - expected_keys
assert not tied_group_violations, "\n".join(tied_group_violations)
assert not missing and not unexpected, (
f"missing in safetensors ({len(missing)}): {sorted(missing)[:10]}\n"
f"unexpected in safetensors ({len(unexpected)}): {sorted(unexpected)[:10]}"
)


@pytest.mark.skip(reason="Requires too much memory, does not fit on RTX 3090 24GB")
@pytest.mark.gpu
@pytest.mark.slow
def test_pi06_untrained_vlm_matches_gemma3_4b_pt(pi06_untrained_policy, gemma3_4b_pt_aligned_state_dict):
"""Gemma 3 text tower + multimodal projector inside the published checkpoint
are byte-identical to `google/gemma-3-4b-pt` (vision tower checked separately
in `test_pi06_untrained_siglip_matches_gemma3_4b_pt`)."""
pol_state = pi06_untrained_policy.model.gemma3_with_expert.gemma3.state_dict()

checked, mismatches = _diff_state_dicts_against_reference(
pol_state, gemma3_4b_pt_aligned_state_dict, name_filter=lambda name: "vision_tower" not in name
)
assert checked > 0, "Sanity: should have compared at least one VLM (non-vision) param"
assert not mismatches, "VLM (Gemma 3 text + projector) mismatches:\n" + "\n".join(mismatches[:20])


@pytest.mark.gpu
@pytest.mark.slow
def test_pi06_untrained_siglip_matches_gemma3_4b_pt(pi06_untrained_policy, gemma3_4b_pt_aligned_state_dict):
"""SigLIP vision tower inside the published checkpoint is byte-identical to
the vision tower bundled in `google/gemma-3-4b-pt`, modulo a deterministic
bilinear resample of `position_embedding` from 4096 (896Γ—896 published) to
1024 (448Γ—448 Ο€0.6) patches β€” the build script applies the same resample,
so byte-equality holds."""
pol_state = pi06_untrained_policy.model.gemma3_with_expert.gemma3.state_dict()

checked, mismatches = _diff_state_dicts_against_reference(
pol_state, gemma3_4b_pt_aligned_state_dict, name_filter=lambda name: "vision_tower" in name
)
assert checked > 0, "Sanity: should have compared at least one SigLIP param"
assert not mismatches, "SigLIP vision tower mismatches:\n" + "\n".join(mismatches[:20])
Loading
Loading