Lift discrete fixed states as partition dimensions#326
Conversation
`_extract_period_data` previously computed `subject_id = jnp.arange(len(result.in_regime))` per period. That's fine today (one monolithic dispatch), but it prevents per-partition simulation dispatch: multiple groups would emit colliding local `subject_id`s. Add a `subject_ids` field to `PeriodRegimeSimulationData`, compute `jnp.arange(n_subjects)` once at `simulate()` entry, and thread it through `_simulate_regime_in_period`. `_extract_period_data` now reads from `result.subject_ids` verbatim. Pure refactor — behaviour and test output unchanged. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Introduce `src/lcm/regime_building/partitions.py` with four helpers (`iterate_partition_points`, `inject_partition_scalars`, `stack_partition_V_arrays`, `group_subjects_by_partition`) that handle the empty-grid case transparently: single iteration with no scalar injection, no axis stacking, no subject grouping. Add `InternalRegime.partitions` (defaults to empty) and a derived `Model._partition_grid` that is the validated union across regimes. Wrap `Model.solve` in a partition loop; `Model.simulate` stays single- call for now (will be rewritten in the next commit when the grid actually becomes non-empty). Zero behavioural change — every internal regime still has empty partitions. All 843 tests pass unchanged. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
`state_transitions[name] = None` on a `DiscreteGrid` state has always been the moral "this variable never changes" declaration; until now the implementation auto-generated an identity transition that was vmap'd along with every other state — doubling memory per fixed dim with no mathematical benefit. Because the Bellman backward induction never couples values across such a dimension, each fixed state partitions the state space into independent sub-problems; solve and simulate can compile one reduced sub-model and run it once per point in the product of all partition grids. `process_regimes` now detects model-level partition states (discrete `None` transition present in *every* regime that declares the state, with consistent categories) via `detect_model_partitions` and strips them from each regime's state-action space via `lift_partitions_from_regime` before any downstream machinery runs. The partition scalars are injected into `internal_params[regime]` at solve / simulate time via `inject_partition_scalars`, filtered per regime so regimes that do not reference a partition never receive the scalar as an unexpected kwarg. `Model.solve` iterates the partition product, compiles once (shape invariant across points), stacks the sub-V-arrays along trailing partition axes. `Model.simulate` groups subjects by their partition-value tuple from `initial_conditions`, slices V-arrays per group, simulates each group separately, and concatenates `PeriodRegimeSimulationData` with globally-correct `subject_ids`. Continuous fixed states are NOT lifted — partitioning a continuous value is meaningless; they fall through to an identity transition (same as today). `_IdentityTransition` still exists; removal is a follow-up. All 843 tests pass unchanged. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Thirteen tests in `tests/regime_building/test_partitions.py` cover: - `detect_model_partitions` qualification rules (discrete + uniform `None` transition across all regimes; non-None transitions, continuous grids, and mismatched categories are rejected). - `lift_partitions_from_regime` removes partition states. - `InternalRegime.partitions` exposes the lifted grid. - `Model._partition_grid` aggregates across regimes. - `iterate_partition_points` enumerates the full product. - `Model.solve` appends partition axes at the tail of each V-array and produces values monotone in the partition scalar. - `Model.simulate` routes subjects to the correct sub-solution, keeps partition values constant per subject across periods, preserves global subject_ids after per-partition dispatch, and is invariant to the input ordering of subjects. Surface partition names in `SimulationResult` metadata: treat them as "state" columns in `regime_to_states` and include their categories in `regime_discrete_categories` so `to_dataframe` emits them and maps integer codes to labels correctly. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
3626a8f to
54a7ca2
Compare
A state that is `None`-transitioned in every non-terminal regime but appears in a terminal regime (which validation requires to have empty `state_transitions`) is still a fixed state. Without this, models like Mahler-Yum — whose `discount_type` / `health_type` / `education` / `productivity` appear on both alive and dead regimes — were silently held back from partitioning because the terminal regime looked like it had a non-None transition. Target-only states (declared only in a terminal regime and populated by per-target transitions at the boundary) are excluded from partitioning: they have no value at initial_conditions. Update `test_dag_output_feeds_default_h_monotone_in_discount_factor` to only check periods before `FINAL_AGE_ALIVE`: at the last working-life age the deterministic transition goes to dead (V=0), so V = U is pref_type-independent regardless of the discount factor. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
54a7ca2 to
85e4d06
Compare
Code reviewUser asked for ALL issues, not just those ≥80. Scoring is from a follow-up Haiku pass against the rubric in the code-review skill (0 = false positive, 100 = certain & frequent). Confidence 80+
Confidence 75 (highly likely real)
Confidence 50 (real but edge-case or stylistic)
Confidence 25 (plausible, unverified)
Confidence 0 (false positive / pre-existing, documented for completeness)
🤖 Generated with Claude Code - If this code review was useful, please react with 👍. Otherwise, react with 👎. |
Issues fixed: - Kw-only separators on helper functions (`detect_model_partitions`, `iterate_partition_points`, `_build_partition_grid`) per the project's helper-function convention. - Derive a distinct seed per partition group in `Model.simulate`. Previously every sub-call hit `jax.random.key(seed)` with the same integer, producing identical stochastic draws across groups. - Rename outer `partition_names` in `process_regimes` to `regime_partition_names` so the per-regime loop no longer shadows the model-level frozenset. - Drop the single-result fast-path in `_merge_sub_simulation_results` (shape mismatch when `partition_grid` had one point) and add an assertion that state/action keys agree across sub-result slices. - Fix test `_utility` docstring to imperative mood. - Document the V-array axis-order invariant on `slice_V_at_partition_point`. - Clarify module docstring in `partitions.py` to list both the model-building helpers and the solve/simulate-runtime helpers. - Document the new global-subject_id semantics on `_extract_period_data`. No behavioural change for models without stochastic transitions or with empty partition grids. 856 tests pass. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
The Mahler-Yum GPU regression failed with `InvalidFunctionArgumentsError: missing required arguments: education, productivity` because `_check_regime_feasibility` drew its state set only from `variable_info` — partition states were lifted out of there by the partition refactor. The user's `initial_conditions` carries the partition values just like any other state, so merging the partition names alongside the non-partition state names fixes the path. Also drop the `dead_states[pref_type]` workaround in `tests/solution/test_custom_aggregator.py`: with the partition lift, pref_type no longer flows through dead's state-action space, so dead no longer needs to declare it. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
- Include partition states in `_validate_discrete_state_values` so a user supplying an invalid partition code (e.g. 99 for a 3-category grid) fails validation instead of silently dispatching into a wrong sub-solution. - Include partitions in `_raise_feasibility_type_error`'s `discrete_names` set so the dtype-hint correctly identifies float-typed partition codes. - Replace inline `variable_info | partitions` walk in `_check_regime_feasibility` with a call to the existing `_get_regime_state_names` helper. - Strengthen `test_simulate_feasibility_validation_sees_partition_states` with a DataFrame-level assertion (each subject keeps their pref_type across periods, all three types appear). - Add `test_invalid_partition_code_raises` guarding the new partition-code validation path. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
The Phase-3 merge of main (post-#317 squash) into the 318→326 stack brought in a duplicate call site that uses the pre-refactor signature (`regime=` kwarg). The 318+ refactor moved the call inside the non-terminal branch and dropped `regime` from the API; the duplicate was appended after `max_Q_over_a` without a conflict marker. Delete it here so 326 tests pass. Same fix will cascade down-stack. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Mahler-Yum benchmark on #326 was 42s (11x main) with 1.99m compile — because `Model.solve` called `solve()` (hence `_compile_all_functions`) inside the partition loop. For a 32-point product that is ~32 full AOT compile passes. Split `solve()` into `compile_solve` (AOT + V template) and `run_compiled_solve` (backward induction). `Model.solve` now calls `compile_solve` once outside the loop, then `run_compiled_solve` per point. Partition scalars flow in via `internal_params` at dispatch time; JAX reuses the cached kernel across points. Lowering needs partition scalars in the pytree, so we inject scalars from the first partition point before compiling. Every other point has identical shape/dtype, so the same compiled kernel serves all of them. Regression test: `_compile_all_functions` called exactly once regardless of partition cardinality. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
|
Closing to reopen as a clean single-commit stack; replaced by the new PR in the same slot. No content change. |
|
Reopened as #329 (clean single-commit stack). |
The fixture-based test broke once partition lifting landed on #326: subjects grouped by partition-value tuple now draw per-group key streams, so per-subject mortality draws no longer align with the pre-partition fixture recorded in dff30e8. The failure was masked on 32-bit because the test was explicitly skipped there (XLA fusion variance); on 64-bit it surfaced as a ~20 % mismatch on the regime column even at N=4. Fix: - Compare per-period averages with meaningful but loose tolerance (atol=0.05, rtol=0.005) so real model regressions (utility sign flip, wrong Bellman accumulation, mis-indexed transition — all O(10 %)+) still trip the test, while stochastic survival flips at the mortality boundary do not. - Bump n_subjects from 4 to 128 so fraction-alive noise per period stays below 1/128 ≈ 0.008, well inside the tolerance. - Generate the fixture from pytest itself with `LCM_UPDATE_FIXTURES=1`, so fixture generation and the compare use the exact same process and import order. Regeneration via a standalone generator script drifts `value` by ~3 % between processes. - Pin PYTHONHASHSEED=0 on the pixi tests/tests-32bit/tests-with-cov tasks. Without the pin, Python hash-based ordering propagates into the simulation random draws and makes per-process output nondeterministic (~70 % productivity-shock drift between runs). With the pin, f64 is byte-reproducible and f32 drifts only by ~1e-4. - Re-enable the test at f32 — the skipif-not-X64_ENABLED decorator is gone now that tolerance and hash pinning together cover both precisions. - Drop the Mahler-Yum generator from generate_benchmark_data.py; its docstring now points callers at the pytest-driven regen path. The small benchmark models (precautionary_savings, mortality) remain reproducible at per-subject granularity via the script.
Summary
state_transitions[name] = Noneon aDiscreteGridstate has always meant "this value never changes." Today pylcm implements that by vmap'ing an auto-generated identity transition alongside every other state, paying memory linearly in the fixed state's cardinality. Because the Bellman backward induction never couples values across such a dimension, each fixed state partitions the state space into independent sub-problems: pylcm can compile one reduced sub-model and run it once per point in the product of all partition grids.This PR adds that machinery end-to-end. User-facing API is unchanged —
Nonetransitions still declare a fixed state; internally pylcm lifts them out ofvariable_info/ state-action space / Q_and_F, iterates the partition product inModel.solve/Model.simulate, and restores partition axes at the API boundary.What changed
src/lcm/regime_building/partitions.py(new). Shared helpers:detect_model_partitions,lift_partitions_from_regime,iterate_partition_points,inject_partition_scalars,stack_partition_V_arrays,group_subjects_by_partition,slice_V_at_partition_point,slice_initial_conditions. All handle the empty-grid case transparently (one iteration, no stacking, no grouping).src/lcm/regime_building/processing.py.process_regimesdetects model-level partitions up-front and passes reduced regimes to every downstream builder.InternalRegimegains apartitionsfield; downstream machinery never sees partition states as vmap'd axes.src/lcm/model.py.Model._partition_gridaggregates partitions across regimes (validates category agreement).Model.solvewraps backward induction in a partition loop, stacks sub-V-arrays along trailing axes.Model.simulategroups subjects by their partition-value tuple, slices V-arrays per group, and concatenatesPeriodRegimeSimulationDatawith globally-correctsubject_ids.src/lcm/simulation/{simulate,result}.py.PeriodRegimeSimulationDatagains an explicitsubject_idsfield;simulate()accepts it as a kwarg so per-partition dispatch groups stay aligned after concatenation.to_dataframeemits partition names as "state" columns with correct per-regime categorical mapping.src/lcm/simulation/initial_conditions.py. Partition names are required ininitial_conditions— they supply each subject's fixed value.Qualification rules for partition lifting
A discrete state qualifies as a partition iff:
DiscreteGrid.None.Terminal regimes (which must have empty
state_transitionsby validation) count the absence of an entry as the equivalent ofNone.Continuous fixed states (
Noneon a continuous grid) fall through to an identity transition — partitioning a continuous value is meaningless.Mahler-Yum impact
Four fixed states (
productivity,education,discount_type,health_type) auto-lift on this branch, giving a 32-point partition product.Test plan
pixi run --environment tests-cpu tests -n 7— 856 pass, 5 skipped (843 existing + 13 new partition tests).pixi run ty— clean.prek run --all-files— clean.test_dag_output_feeds_default_h_monotone_in_discount_factorupdated: skip the last working-life period where V = U (deterministic transition todead, V_dead = 0) is pref_type-independent regardless of discount factor.Not in scope
_IdentityTransitionentirely — continuous fixed states still need it. Narrower cleanup is a follow-up.jax.lax.scaninto a pre-allocated output buffer (current implementation holds all sub-V-arrays during stack). Peak could drop another ~30 %.🤖 Generated with Claude Code