Skip to content

Lift discrete fixed states as partition dimensions#326

Closed
hmgaudecker wants to merge 18 commits into
feature/bench-aca-baselinefrom
feature/partition-fixed-states
Closed

Lift discrete fixed states as partition dimensions#326
hmgaudecker wants to merge 18 commits into
feature/bench-aca-baselinefrom
feature/partition-fixed-states

Conversation

@hmgaudecker
Copy link
Copy Markdown
Member

Summary

state_transitions[name] = None on a DiscreteGrid state has always meant "this value never changes." Today pylcm implements that by vmap'ing an auto-generated identity transition alongside every other state, paying memory linearly in the fixed state's cardinality. Because the Bellman backward induction never couples values across such a dimension, each fixed state partitions the state space into independent sub-problems: pylcm can compile one reduced sub-model and run it once per point in the product of all partition grids.

This PR adds that machinery end-to-end. User-facing API is unchanged — None transitions still declare a fixed state; internally pylcm lifts them out of variable_info / state-action space / Q_and_F, iterates the partition product in Model.solve / Model.simulate, and restores partition axes at the API boundary.

What changed

  • src/lcm/regime_building/partitions.py (new). Shared helpers: detect_model_partitions, lift_partitions_from_regime, iterate_partition_points, inject_partition_scalars, stack_partition_V_arrays, group_subjects_by_partition, slice_V_at_partition_point, slice_initial_conditions. All handle the empty-grid case transparently (one iteration, no stacking, no grouping).
  • src/lcm/regime_building/processing.py. process_regimes detects model-level partitions up-front and passes reduced regimes to every downstream builder. InternalRegime gains a partitions field; downstream machinery never sees partition states as vmap'd axes.
  • src/lcm/model.py. Model._partition_grid aggregates partitions across regimes (validates category agreement). Model.solve wraps backward induction in a partition loop, stacks sub-V-arrays along trailing axes. Model.simulate groups subjects by their partition-value tuple, slices V-arrays per group, and concatenates PeriodRegimeSimulationData with globally-correct subject_ids.
  • src/lcm/simulation/{simulate,result}.py. PeriodRegimeSimulationData gains an explicit subject_ids field; simulate() accepts it as a kwarg so per-partition dispatch groups stay aligned after concatenation. to_dataframe emits partition names as "state" columns with correct per-regime categorical mapping.
  • src/lcm/simulation/initial_conditions.py. Partition names are required in initial_conditions — they supply each subject's fixed value.

Qualification rules for partition lifting

A discrete state qualifies as a partition iff:

  1. The grid is a DiscreteGrid.
  2. Every non-terminal regime where the state appears declares None.
  3. The state appears in at least one non-terminal regime (target-only states populated by per-target transitions at the boundary are excluded).
  4. Categories agree across regimes.

Terminal regimes (which must have empty state_transitions by validation) count the absence of an entry as the equivalent of None.

Continuous fixed states (None on a continuous grid) fall through to an identity transition — partitioning a continuous value is meaningless.

Mahler-Yum impact

Four fixed states (productivity, education, discount_type, health_type) auto-lift on this branch, giving a 32-point partition product.

  • GPU peak memory (cuda13, 100 subjects): 501 MB. Comparable to or slightly below the pre-heterogeneity baseline (~400 MB on older main, 522 MB on the h-consumes-dag-outputs branch). The win is larger for Replication_MY2024's 10K-subject / 37-period runs where memory pressure is real.
  • Execution time dominated by compile once + run N times: only one JIT compile across all partition points (shape invariant across points).

Test plan

  • pixi run --environment tests-cpu tests -n 7856 pass, 5 skipped (843 existing + 13 new partition tests).
  • pixi run ty — clean.
  • prek run --all-files — clean.
  • Mahler-Yum GPU benchmark runs end-to-end on cuda13.
  • test_dag_output_feeds_default_h_monotone_in_discount_factor updated: skip the last working-life period where V = U (deterministic transition to dead, V_dead = 0) is pref_type-independent regardless of discount factor.

Not in scope

  • Removing _IdentityTransition entirely — continuous fixed states still need it. Narrower cleanup is a follow-up.
  • Memory-optimal stacking via jax.lax.scan into a pre-allocated output buffer (current implementation holds all sub-V-arrays during stack). Peak could drop another ~30 %.

🤖 Generated with Claude Code

@read-the-docs-community
Copy link
Copy Markdown

read-the-docs-community Bot commented Apr 18, 2026

hmgaudecker and others added 4 commits April 18, 2026 19:48
`_extract_period_data` previously computed `subject_id = jnp.arange(len(result.in_regime))`
per period. That's fine today (one monolithic dispatch), but it prevents per-partition
simulation dispatch: multiple groups would emit colliding local `subject_id`s.

Add a `subject_ids` field to `PeriodRegimeSimulationData`, compute
`jnp.arange(n_subjects)` once at `simulate()` entry, and thread it through
`_simulate_regime_in_period`. `_extract_period_data` now reads from
`result.subject_ids` verbatim. Pure refactor — behaviour and test output
unchanged.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Introduce `src/lcm/regime_building/partitions.py` with four helpers
(`iterate_partition_points`, `inject_partition_scalars`,
`stack_partition_V_arrays`, `group_subjects_by_partition`) that handle
the empty-grid case transparently: single iteration with no scalar
injection, no axis stacking, no subject grouping.

Add `InternalRegime.partitions` (defaults to empty) and a derived
`Model._partition_grid` that is the validated union across regimes.
Wrap `Model.solve` in a partition loop; `Model.simulate` stays single-
call for now (will be rewritten in the next commit when the grid
actually becomes non-empty).

Zero behavioural change — every internal regime still has empty
partitions. All 843 tests pass unchanged.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
`state_transitions[name] = None` on a `DiscreteGrid` state has always
been the moral "this variable never changes" declaration; until now
the implementation auto-generated an identity transition that was
vmap'd along with every other state — doubling memory per fixed
dim with no mathematical benefit. Because the Bellman backward
induction never couples values across such a dimension, each fixed
state partitions the state space into independent sub-problems; solve
and simulate can compile one reduced sub-model and run it once per
point in the product of all partition grids.

`process_regimes` now detects model-level partition states (discrete
`None` transition present in *every* regime that declares the state,
with consistent categories) via `detect_model_partitions` and strips
them from each regime's state-action space via
`lift_partitions_from_regime` before any downstream machinery runs.
The partition scalars are injected into `internal_params[regime]` at
solve / simulate time via `inject_partition_scalars`, filtered per
regime so regimes that do not reference a partition never receive the
scalar as an unexpected kwarg.

`Model.solve` iterates the partition product, compiles once (shape
invariant across points), stacks the sub-V-arrays along trailing
partition axes. `Model.simulate` groups subjects by their
partition-value tuple from `initial_conditions`, slices V-arrays per
group, simulates each group separately, and concatenates
`PeriodRegimeSimulationData` with globally-correct `subject_ids`.

Continuous fixed states are NOT lifted — partitioning a continuous
value is meaningless; they fall through to an identity transition
(same as today). `_IdentityTransition` still exists; removal is a
follow-up.

All 843 tests pass unchanged.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Thirteen tests in `tests/regime_building/test_partitions.py` cover:

- `detect_model_partitions` qualification rules (discrete + uniform
  `None` transition across all regimes; non-None transitions,
  continuous grids, and mismatched categories are rejected).
- `lift_partitions_from_regime` removes partition states.
- `InternalRegime.partitions` exposes the lifted grid.
- `Model._partition_grid` aggregates across regimes.
- `iterate_partition_points` enumerates the full product.
- `Model.solve` appends partition axes at the tail of each V-array
  and produces values monotone in the partition scalar.
- `Model.simulate` routes subjects to the correct sub-solution, keeps
  partition values constant per subject across periods, preserves
  global subject_ids after per-partition dispatch, and is invariant
  to the input ordering of subjects.

Surface partition names in `SimulationResult` metadata: treat them as
"state" columns in `regime_to_states` and include their categories in
`regime_discrete_categories` so `to_dataframe` emits them and maps
integer codes to labels correctly.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@hmgaudecker hmgaudecker force-pushed the feature/partition-fixed-states branch from 3626a8f to 54a7ca2 Compare April 18, 2026 17:50
A state that is `None`-transitioned in every non-terminal regime but
appears in a terminal regime (which validation requires to have
empty `state_transitions`) is still a fixed state. Without this,
models like Mahler-Yum — whose `discount_type` / `health_type` /
`education` / `productivity` appear on both alive and dead regimes —
were silently held back from partitioning because the terminal
regime looked like it had a non-None transition.

Target-only states (declared only in a terminal regime and populated
by per-target transitions at the boundary) are excluded from
partitioning: they have no value at initial_conditions.

Update `test_dag_output_feeds_default_h_monotone_in_discount_factor`
to only check periods before `FINAL_AGE_ALIVE`: at the last
working-life age the deterministic transition goes to dead (V=0),
so V = U is pref_type-independent regardless of the discount factor.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@hmgaudecker hmgaudecker force-pushed the feature/partition-fixed-states branch from 54a7ca2 to 85e4d06 Compare April 18, 2026 17:53
@hmgaudecker
Copy link
Copy Markdown
Member Author

Code review

User asked for ALL issues, not just those ≥80. Scoring is from a follow-up Haiku pass against the rubric in the code-review skill (0 = false positive, 100 = certain & frequent).

Confidence 80+

  1. Missing kw-only separator on detect_model_partitions (CLAUDE.md/MEMORY.md: "all function parameters should be keyword-only (*,) unless JAX tracing requires positional args") — score 100.
    https://github.com/OpenSourceEconomics/pylcm/blob/85e4d06568c5d10a4cd8d9fae58f44717ff2c2bf/src/lcm/regime_building/partitions.py#L36-L40

  2. Missing kw-only separator on iterate_partition_points (same rule) — score 100.
    https://github.com/OpenSourceEconomics/pylcm/blob/85e4d06568c5d10a4cd8d9fae58f44717ff2c2bf/src/lcm/regime_building/partitions.py#L139-L143

  3. Missing kw-only separator on _build_partition_grid (same rule) — score 100.
    https://github.com/OpenSourceEconomics/pylcm/blob/85e4d06568c5d10a4cd8d9fae58f44717ff2c2bf/src/lcm/model.py#L578-L581

Confidence 75 (highly likely real)

  1. Same RNG seed passed to every partition group in Model.simulate — each sub-simulate calls jax.random.key(seed=seed) with the same integer. Stochastic state transitions apply identical draws across groups. Silent correctness bug when seed is not None and there are ≥2 partition groups with stochastic transitions. Fix: split key per group (jax.random.split(jax.random.key(seed), n_groups)) or hash (seed, partition_point) before passing.
    https://github.com/OpenSourceEconomics/pylcm/blob/85e4d06568c5d10a4cd8d9fae58f44717ff2c2bf/src/lcm/model.py#L381-L394

  2. Test docstring non-imperative mood_utility starts with "Utility scales…" (descriptive). AGENTS.md: "Imperative mood for docstring summary lines."
    https://github.com/OpenSourceEconomics/pylcm/blob/85e4d06568c5d10a4cd8d9fae58f44717ff2c2bf/tests/regime_building/test_partitions.py#L53-L62

Confidence 50 (real but edge-case or stylistic)

  1. Shadowed partition_names in process_regimes — outer model-level frozenset is re-bound inside the per-regime loop. Functionally correct today (outer binding isn't read after the loop) but a refactoring trap.
    https://github.com/OpenSourceEconomics/pylcm/blob/85e4d06568c5d10a4cd8d9fae58f44717ff2c2bf/src/lcm/regime_building/processing.py#L100-L152

  2. _merge_sub_simulation_results fast-path shape mismatch for single-point partitions — when len(sub_results) == 1 AND partition_grid is non-empty, the returned period_to_regime_to_V_arr has the partition axis appended but raw_results V_arrays do not. Mismatch visible if downstream jointly reads both.
    https://github.com/OpenSourceEconomics/pylcm/blob/85e4d06568c5d10a4cd8d9fae58f44717ff2c2bf/src/lcm/model.py#L497-L520

  3. Merge does not assert matching state_names across slicesstate_names = tuple(slices[0].states) silently drops keys present in later slices but missing from the first.
    https://github.com/OpenSourceEconomics/pylcm/blob/85e4d06568c5d10a4cd8d9fae58f44717ff2c2bf/src/lcm/model.py#L540-L563

Confidence 25 (plausible, unverified)

  1. Module layout: detect_model_partitions / lift_partitions_from_regime precede the primary-API helpers named in the module docstring. .ai-instructions/AGENTS.md: "Write 'deep' modules: important public function(s) at the top, private helpers below."
    https://github.com/OpenSourceEconomics/pylcm/blob/85e4d06568c5d10a4cd8d9fae58f44717ff2c2bf/src/lcm/regime_building/partitions.py#L36-L165

  2. subject_id semantic change — previously jnp.arange(len(result.in_regime)) (dense local row index), now result.subject_ids (global ID). Design is cleaner but a silent break for downstream consumers assuming 0..n-1 density. No in-repo consumer relies on density.
    https://github.com/OpenSourceEconomics/pylcm/blob/85e4d06568c5d10a4cd8d9fae58f44717ff2c2bf/src/lcm/simulation/result.py#L526-L539

  3. Partition-state reachability gap in _validate_all_variables_used — a partition state consumed only as an index into a params-dict/Series (not as a direct function arg) may not be reachable through the get_ancestors walk. No test covers this.
    https://github.com/OpenSourceEconomics/pylcm/blob/85e4d06568c5d10a4cd8d9fae58f44717ff2c2bf/src/lcm/model_processing.py#L195-L260

Confidence 0 (false positive / pre-existing, documented for completeness)

  1. Axis-order invariant concern (stacked V has continuous axes non-last): Model.simulate always calls slice_V_at_partition_point before handing V to the interpolator, so the invariant is preserved at the consumption point. (Score 0, but worth a comment at slice_V_at_partition_point documenting the invariant.)
  2. PR Support heterogeneous state sets in initial conditions #315 NaN→int32 cast: not reintroduced.
  3. PR Support heterogeneous state sets in initial conditions #315 wrong exception type in _raise_feasibility_type_error: not touched.
  4. PR Skip unreachable target regimes in continuation value loop #316 io_callback per scalar: pre-existing, Q_and_F not modified.
  5. PR Auto-convert pd.Series in fixed_params #308 helper-inlining readability: partition logic is in dedicated partitions.py; only a thin wrapper stays in model.py.

🤖 Generated with Claude Code

- If this code review was useful, please react with 👍. Otherwise, react with 👎.

hmgaudecker and others added 13 commits April 18, 2026 20:05
Issues fixed:

- Kw-only separators on helper functions (`detect_model_partitions`,
  `iterate_partition_points`, `_build_partition_grid`) per the project's
  helper-function convention.
- Derive a distinct seed per partition group in `Model.simulate`.
  Previously every sub-call hit `jax.random.key(seed)` with the same
  integer, producing identical stochastic draws across groups.
- Rename outer `partition_names` in `process_regimes` to
  `regime_partition_names` so the per-regime loop no longer shadows
  the model-level frozenset.
- Drop the single-result fast-path in `_merge_sub_simulation_results`
  (shape mismatch when `partition_grid` had one point) and add an
  assertion that state/action keys agree across sub-result slices.
- Fix test `_utility` docstring to imperative mood.
- Document the V-array axis-order invariant on `slice_V_at_partition_point`.
- Clarify module docstring in `partitions.py` to list both the
  model-building helpers and the solve/simulate-runtime helpers.
- Document the new global-subject_id semantics on `_extract_period_data`.

No behavioural change for models without stochastic transitions or with
empty partition grids. 856 tests pass.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
The Mahler-Yum GPU regression failed with
`InvalidFunctionArgumentsError: missing required arguments: education,
productivity` because `_check_regime_feasibility` drew its state set
only from `variable_info` — partition states were lifted out of there
by the partition refactor. The user's `initial_conditions` carries the
partition values just like any other state, so merging the partition
names alongside the non-partition state names fixes the path.

Also drop the `dead_states[pref_type]` workaround in
`tests/solution/test_custom_aggregator.py`: with the partition lift,
pref_type no longer flows through dead's state-action space, so dead
no longer needs to declare it.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
- Include partition states in `_validate_discrete_state_values` so a
  user supplying an invalid partition code (e.g. 99 for a 3-category
  grid) fails validation instead of silently dispatching into a wrong
  sub-solution.
- Include partitions in `_raise_feasibility_type_error`'s
  `discrete_names` set so the dtype-hint correctly identifies
  float-typed partition codes.
- Replace inline `variable_info | partitions` walk in
  `_check_regime_feasibility` with a call to the existing
  `_get_regime_state_names` helper.
- Strengthen `test_simulate_feasibility_validation_sees_partition_states`
  with a DataFrame-level assertion (each subject keeps their
  pref_type across periods, all three types appear).
- Add `test_invalid_partition_code_raises` guarding the new
  partition-code validation path.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
The Phase-3 merge of main (post-#317 squash) into the 318→326 stack
brought in a duplicate call site that uses the pre-refactor signature
(`regime=` kwarg). The 318+ refactor moved the call inside the
non-terminal branch and dropped `regime` from the API; the duplicate
was appended after `max_Q_over_a` without a conflict marker. Delete
it here so 326 tests pass. Same fix will cascade down-stack.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Mahler-Yum benchmark on #326 was 42s (11x main) with 1.99m compile —
because `Model.solve` called `solve()` (hence `_compile_all_functions`)
inside the partition loop. For a 32-point product that is ~32 full AOT
compile passes.

Split `solve()` into `compile_solve` (AOT + V template) and
`run_compiled_solve` (backward induction). `Model.solve` now calls
`compile_solve` once outside the loop, then `run_compiled_solve` per
point. Partition scalars flow in via `internal_params` at dispatch
time; JAX reuses the cached kernel across points.

Lowering needs partition scalars in the pytree, so we inject scalars
from the first partition point before compiling. Every other point
has identical shape/dtype, so the same compiled kernel serves all of
them.

Regression test: `_compile_all_functions` called exactly once
regardless of partition cardinality.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@hmgaudecker
Copy link
Copy Markdown
Member Author

Closing to reopen as a clean single-commit stack; replaced by the new PR in the same slot. No content change.

@hmgaudecker
Copy link
Copy Markdown
Member Author

Reopened as #329 (clean single-commit stack).

hmgaudecker added a commit that referenced this pull request Apr 19, 2026
The fixture-based test broke once partition lifting landed on #326: subjects
grouped by partition-value tuple now draw per-group key streams, so per-subject
mortality draws no longer align with the pre-partition fixture recorded in
dff30e8. The failure was masked on 32-bit because the test was explicitly
skipped there (XLA fusion variance); on 64-bit it surfaced as a ~20 % mismatch
on the regime column even at N=4.

Fix:

- Compare per-period averages with meaningful but loose tolerance
  (atol=0.05, rtol=0.005) so real model regressions (utility sign flip, wrong
  Bellman accumulation, mis-indexed transition — all O(10 %)+) still trip the
  test, while stochastic survival flips at the mortality boundary do not.
- Bump n_subjects from 4 to 128 so fraction-alive noise per period stays
  below 1/128 ≈ 0.008, well inside the tolerance.
- Generate the fixture from pytest itself with `LCM_UPDATE_FIXTURES=1`, so
  fixture generation and the compare use the exact same process and import
  order. Regeneration via a standalone generator script drifts `value` by
  ~3 % between processes.
- Pin PYTHONHASHSEED=0 on the pixi tests/tests-32bit/tests-with-cov tasks.
  Without the pin, Python hash-based ordering propagates into the simulation
  random draws and makes per-process output nondeterministic (~70 %
  productivity-shock drift between runs). With the pin, f64 is
  byte-reproducible and f32 drifts only by ~1e-4.
- Re-enable the test at f32 — the skipif-not-X64_ENABLED decorator is gone
  now that tolerance and hash pinning together cover both precisions.
- Drop the Mahler-Yum generator from generate_benchmark_data.py; its
  docstring now points callers at the pytest-driven regen path. The small
  benchmark models (precautionary_savings, mortality) remain reproducible
  at per-subject granularity via the script.
@hmgaudecker hmgaudecker deleted the feature/partition-fixed-states branch April 20, 2026 08:51
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant