refactor(policies): remove init_strategy #188
The "expert_only_he_init" path loaded google/paligemma-3b-pt-224 inside the model constructor (random expert + pretrained backbone), but every shipped config that named it also set pretrained_path, which silently overrode init_strategy to "no_init" before construction. The path was dead config in practice. Users who want a pretrained backbone now go through pretrained_path exclusively.

- Drop expert_only_he_init from PI0/PI05/PI05Mem/PI05ContinuousState configs and the corresponding _init_model branches.
- Drop load_pretrained_paligemma from PaliGemmaWithExpertConfig/Model in pi0 and pi05; PaliGemma is always built from config now.
- Make init_strategy default to None and resolve in __post_init__: pretrained_path set -> "no_init", else "full_he_init". Warn (not info) when a user-set non-no_init value is overridden.
- Strip init_strategy from saved config JSON and from the wandb upload dict alongside the existing deprecated-latency stripping; it's a construction-time-only field with no meaning in a saved record.
- Remove now-dead "init_strategy" entries from all 9 shipped configs.

https://claude.ai/code/session_01KaAzgu1xy7io6Etf6MXYnJ
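The `__post_init__` resolution rule described in this commit can be sketched roughly as follows. This is an illustration only: `SketchPolicyConfig` is a hypothetical stand-in, not the real `PI0Config`, and the warning text is invented for the example.

```python
import warnings
from dataclasses import dataclass
from typing import Optional


@dataclass
class SketchPolicyConfig:
    """Hypothetical stand-in for a policy config; not the real PI0Config."""

    pretrained_path: Optional[str] = None
    init_strategy: Optional[str] = None  # None means "resolve automatically"

    def __post_init__(self) -> None:
        if self.pretrained_path is not None:
            # A checkpoint will overwrite the weights, so custom init is skipped.
            if self.init_strategy not in (None, "no_init"):
                # Assumed wording; the commit only says a warning is emitted.
                warnings.warn(
                    f"init_strategy={self.init_strategy!r} is overridden to "
                    "'no_init' because pretrained_path is set.",
                    stacklevel=2,
                )
            self.init_strategy = "no_init"
        elif self.init_strategy is None:
            # No checkpoint and no explicit choice: He-init everything.
            self.init_strategy = "full_he_init"
```

Under this sketch, a user-set value survives only when `pretrained_path` is unset, which is why the field carried no information in configs that always set `pretrained_path`.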
Continue the cleanup begun in the previous commit by removing the init_strategy field, the _init_model/_init_weights methods, and the post_init resolution logic across all four policy variants. Without the expert_only_he_init path, init_strategy was just a knob to choose between He-init everything (full_he_init) or skip init (no_init), and in every shipped config the post_init override was driving it to no_init anyway because pretrained_path was set. With from_pretrained loading state dict on top, custom He-init was either redundant (overwritten) or unused (no fresh-from-scratch configs ship).

- Drop init_strategy field, docstring entries, and post_init resolution in PI0Config / PI05Config / PI05MemConfig / PI05ContinuousStateConfig. Drop now-unused logging and Literal imports where applicable.
- Remove _init_model() call and _init_model / _init_weights method bodies from the four modeling_*.py constructors.
- In policies.py: rename _TRANSIENT_POLICY_FIELDS to _REMOVED_POLICY_FIELDS and call strip_deprecated_fields_from_json on the config_file before draccus.parse in PreTrainedConfig and TrainPipelineConfig from_pretrained, so old saved configs that still contain "init_strategy" load without erroring on draccus's strict unknown-field check. Drop strip_transient_fields_in_place and revert TrainPipelineConfig.to_dict() to its pre-PR shape: the field no longer exists in the dataclass, so draccus.encode never emits it.

https://claude.ai/code/session_01KaAzgu1xy7io6Etf6MXYnJ
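A minimal sketch of the strip-before-parse step described above, using only the standard library. The names `REMOVED_POLICY_FIELDS`, `strip_keys`, and `strip_removed_fields_from_json` are hypothetical and chosen for this example; the actual helpers in `policies.py` may differ in signature and behavior.

```python
import json
from pathlib import Path

# Assumed contents; the PR only names "init_strategy" as a removed field.
REMOVED_POLICY_FIELDS = ("init_strategy",)


def strip_keys(config: dict, keys=REMOVED_POLICY_FIELDS) -> bool:
    """Remove `keys` at the top level and under "policy". Return True if anything changed."""
    changed = False
    nested = config.get("policy")
    for scope in (config, nested if isinstance(nested, dict) else {}):
        for key in keys:
            if key in scope:
                del scope[key]
                changed = True
    return changed


def strip_removed_fields_from_json(path: Path) -> bool:
    """Rewrite the JSON file in place only when a removed key was present."""
    config = json.loads(path.read_text())
    changed = strip_keys(config)
    if changed:
        path.write_text(json.dumps(config, indent=2))
    return changed
```

Running the strip before a strict parser sees the file means an old config with `"init_strategy"` no longer trips the unknown-field check, at the cost of rewriting the file on disk.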
Reconcile with #161, which deprecates the pi05_continuous_state policy and merges its functionality into PI05Config via state_type.

- src/opentau/policies/pi05/configuration_pi05.py: kept main's state_type field, PI05ContinuousStateConfig deprecated wrapper, and warnings/Literal imports; dropped main's init_strategy validation block (the field is gone on this branch). Removed the now-unused logging import.
- src/opentau/policies/pi05_continuous_state/configuration_pi05.py and modeling_pi05.py: accepted main's deletion (modify/delete conflict).
- pi05/modeling_pi05.py and the rest auto-merged cleanly: this branch's removal of _init_model / _init_weights survived intact.
@claude fix the merge conflicts
- addresses @shuheng-liu (fix merge conflicts): re-merge latest main into the branch and resolve two docstring conflicts. In pi05/paligemma_with_expert.py keep main's improved attention_implementation docstring (eager/sdpa/fa2) but drop the load_pretrained_paligemma line, since that param was removed by this PR. In pi05_mem/configuration_pi05.py keep main's improved attention_implementation + freeze_vision_encoder docstrings but drop init_strategy from the args list, since that field was removed by this PR.
- addresses @shuheng-liu (fix merge conflicts, runtime fallout): main added tests/policies/test_pi05_mem_gpu.py which constructed PI05MemConfig with init_strategy="no_init"; that's now an unknown kwarg on a dataclass config and crashes at construction. Drop the kwarg and update the comment to reference pretrained_path=None as the equivalent skip-loads signal.

tests: passed -- pytest -m "not gpu" -n auto --ignore=tests/policies/test_pi07_paligemma_low_level_planner.py tests/policies/ tests/configs/ tests/scripts/test_train.py (159 passed, 2 skipped)

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
[claude-fix] Merged latest origin/main (
CPU tests: passed locally — Two follow-up items I deliberately did NOT touch — they may warrant separate PRs:
Happy to address either or both of those if you'd like; let me know.
@claude fix the first one in pi06. Leave pi07_paligemma alone. Also make sure the 2 breaking CPU tests skipped by main are also skipped here.
- addresses @shuheng-liu (pi06 init_strategy regression): drop init_strategy field, post_init resolution, _init_model/_init_weights, and the load_pretrained_gemma3 plumbing in pi06's Gemma3WithExpertConfig/Model. Also strip "init_strategy" from configs/examples/pi06_training_config.json so _REMOVED_POLICY_FIELDS no longer silently drops a user choice that used to be load-bearing.
- addresses @shuheng-liu (CPU test skips): verified the merge-in cpu_test.yml matches main's skip list verbatim: same --ignore for test_planner.py, test_libero_utils.py, test_pi07_paligemma_low_level_planner.py, and the two --deselect entries for test_make_env_{async,sync}_vector_env. Re-ran tests locally with that exact skip list (vs. the previous run which only ignored the pi07 file) and confirmed the suite passes.

tests: passed -- pytest -m "not gpu" -n auto --ignore=tests/planner/test_planner.py --ignore=tests/utils/test_libero_utils.py --ignore=tests/policies/test_pi07_paligemma_low_level_planner.py --deselect=tests/envs/test_factory.py::TestMakeEnv::test_make_env_async_vector_env --deselect=tests/envs/test_factory.py::TestMakeEnv::test_make_env_sync_vector_env tests/policies/ tests/configs/ tests/scripts/test_train.py (159 passed, 2 skipped)

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
[claude-fix] Done — pushed in a595c8e. 1) pi06 — drop
Per your instruction, pi07_paligemma is left alone — its 2) CPU test skips. Confirmed Tests: passed locally — |
[claude-review] summary for commit c6befca Scope: removal of Findings:
Non-blocking notes:
@claude fix the suggestions and nits if they make sense
- addresses claude bot (in-place HF cache mutation): replace `strip_deprecated_fields_from_json(Path(config_file))` at the load sites in `from_pretrained` (`policies.py`, `train.py`) with new `load_stripped_config_to_tempfile()` that strips in memory and writes to a `tempfile.mkstemp` file. The HF cache symlink → blob is no longer mutated; user-supplied paths are no longer reformatted as a side effect of load. The on-write strip in `_save_pretrained` is retained (caller owns the directory).
- addresses claude bot (silent strip of removed fields): add `warn_removed_policy_fields()` sibling to `warn_deprecated_latency_fields()` and call it before the strip at both load sites, so users loading old configs containing `init_strategy=...` see a `DeprecationWarning` rather than having their choice silently dropped. Factored the `top-level + policy.<key>` lookup into `_find_present_keys()` to share between the two warn helpers.
- addresses claude bot (stale docstring): drop the `init_strategy:` entry from `PI05Config`'s class docstring (`pi05/configuration_pi05.py:62-63`).
- addresses claude bot (misleading print): change `init=no_init` → `pretrained_path=None` in the e2e debug print in `tests/policies/test_pi05_mem_gpu.py`.
- addresses claude bot (no test for strip-on-load): add unit tests in `tests/configs/test_utils_policies.py` covering `_strip_keys` (top-level + nested-under-policy + no-op), `load_stripped_config_to_tempfile` (does not mutate source, does not follow symlinks → cache-safe), `strip_deprecated_fields_from_json` (still in-place for owned files), and `warn_removed_policy_fields` (warns when present, silent when clean).

tests: passed -- pytest -m "not gpu" -n auto --ignore=tests/planner/test_planner.py --ignore=tests/utils/test_libero_utils.py --ignore=tests/policies/test_pi07_paligemma_low_level_planner.py --deselect=tests/envs/test_factory.py::TestMakeEnv::test_make_env_async_vector_env --deselect=tests/envs/test_factory.py::TestMakeEnv::test_make_env_sync_vector_env tests/ → 720 passed, 16 skipped (3 unrelated test_hub.py errors require HUGGINGFACE_HUB_TOKEN, which is provided by CI, not the sandbox).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
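The non-mutating load path described in this commit (warn, strip in memory, write to a `tempfile.mkstemp` file) might look roughly like the sketch below. The function names, the `REMOVED_FIELDS` tuple, and the warning message are assumptions made for the example; they are not copied from the PR's implementation.

```python
import json
import os
import tempfile
import warnings
from pathlib import Path

REMOVED_FIELDS = ("init_strategy",)  # assumed; the only removed field the PR names


def find_present_keys(config: dict, keys=REMOVED_FIELDS) -> list:
    """Return dotted names of `keys` found at the top level or under "policy"."""
    found = [k for k in keys if k in config]
    policy = config.get("policy")
    if isinstance(policy, dict):
        found += [f"policy.{k}" for k in keys if k in policy]
    return found


def load_stripped_to_tempfile(source: Path) -> Path:
    """Write a stripped copy of `source` to a temp file; `source` is never modified."""
    config = json.loads(Path(source).read_text())
    present = find_present_keys(config)
    if present:
        # Hypothetical message; the PR only states that a DeprecationWarning is raised.
        warnings.warn(
            f"Ignoring removed config fields: {', '.join(present)}",
            DeprecationWarning,
            stacklevel=2,
        )
    for key in REMOVED_FIELDS:
        config.pop(key, None)
        if isinstance(config.get("policy"), dict):
            config["policy"].pop(key, None)
    fd, tmp_name = tempfile.mkstemp(suffix=".json")
    with os.fdopen(fd, "w") as handle:
        json.dump(config, handle)
    return Path(tmp_name)
```

Because the source is only read, a config that resolves through a cache symlink is left byte-for-byte unchanged. In this sketch the caller is responsible for deleting the temporary file once parsing is done.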
[claude-fix] Done — pushed in 4505fac. Walked through the two suggestions and two nits from the review summary in comment 4348009825; replied inline on the two suggestions and addressed the two nits + the test gap here:
Tests:
Merge origin/main into claude/sdpa-ckpt-pi0-pi06.

- addresses @shuheng-liu (merge conflicts): merged main (PR #188 removed init_strategy and load_pretrained_paligemma / load_pretrained_gemma3). Three docstring conflicts resolved in src/opentau/policies/pi0/paligemma_with_expert.py, src/opentau/policies/pi06/configuration_pi06.py, and src/opentau/policies/pi06/gemma3_with_expert.py: kept the SDPA additions from this PR while dropping the docstring lines for fields that main removed (init_strategy, load_pretrained_paligemma, load_pretrained_gemma3). The PR's __init__ signatures were already free of those parameters.

tests: passed -- pytest tests/policies tests/configs -m "not gpu" --ignore=tests/policies/test_pi07_paligemma_low_level_planner.py -n auto

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
What this does

Removes the `init_strategy` config knob entirely from `PI0Config`, `PI05Config`, `PI05MemConfig`, and `PI05ContinuousStateConfig`, along with the `_init_model` / `_init_weights` methods that consumed it.

Why

The knob had three modes (`"no_init"`, `"full_he_init"`, `"expert_only_he_init"`), but in practice none of them did meaningful work in any shipped config:

- Every shipped config sets `pretrained_path`. The post_init logic silently overrode `init_strategy` to `"no_init"` whenever `pretrained_path` was set, so `"expert_only_he_init"` (which appeared in 5 configs including the LIBERO repro) was always dead config.
- `"expert_only_he_init"` was the only mode that did something the rest of the system couldn't already do: load `google/paligemma-3b-pt-224` inside the model constructor for the random-expert + pretrained-backbone use case. We've decided to drop that use case; users who want a pretrained backbone go through `pretrained_path` exclusively. If someone wants random-expert + pretrained-backbone they can build that checkpoint and upload it themselves.
- That leaves `"no_init"` (skip init) and `"full_he_init"` (He-init everything). When `from_pretrained` is used, He-init is wasted work because the state dict overwrites it. The fresh-from-scratch path is not exercised by any shipped config. So the knob is dead either way.

Changes

- Drop the `init_strategy` field, docstring entry, and `__post_init__` resolution from all four policy configs. Drop now-unused `logging` / `Literal` imports.
- Remove the `_init_model()` call and the `_init_model` / `_init_weights` method bodies from the four `modeling_*.py` constructors.
- Drop `load_pretrained_paligemma` from `PaliGemmaWithExpertConfig` / `PaliGemmaWithExpertModel` in pi0 and pi05; PaliGemma is always built from config.
- In `policies.py`: rename `_TRANSIENT_POLICY_FIELDS` to `_REMOVED_POLICY_FIELDS` and call `strip_deprecated_fields_from_json(Path(config_file))` before `draccus.parse` in both `PreTrainedConfig.from_pretrained` and `TrainPipelineConfig.from_pretrained`. This is needed because draccus errors on unknown fields, and old saved configs (e.g. `TensorAuto/pi05_base/config.json`) still contain `"init_strategy"`.
- Remove `init_strategy` entries from all 9 shipped JSON configs.

Backward-compat note

Old saved configs (HF Hub or local) that still contain `"init_strategy"` continue to load: the load-time strip drops the key in place before draccus sees it. The next save writes a clean config.

How it was tested

- Constructed configs with `pretrained_path=None` and `pretrained_path="ckpt"`, and confirmed no `init_strategy` attribute exists post-construction.
- A config containing `"init_strategy": "expert_only_he_init"` raises `DecodingError` on direct `draccus.parse`, but parses cleanly after `strip_deprecated_fields_from_json` runs first.
- `_strip_keys` / `strip_deprecated_fields_from_json` exercised at top-level and nested-under-policy.
- Shipped JSON configs `json.load` cleanly.
- Modified files compile (`python -m py_compile`).

How to checkout & try? (for the reviewer)

Checklist

https://claude.ai/code/session_01KaAzgu1xy7io6Etf6MXYnJ