Lock integer dtype to int32 end-to-end#341
Merged
hmgaudecker merged 2 commits intofeat/simulate-aot-n-subjectsfrom May 3, 2026
Merged
Lock integer dtype to int32 end-to-end#341hmgaudecker merged 2 commits intofeat/simulate-aot-n-subjectsfrom
hmgaudecker merged 2 commits intofeat/simulate-aot-n-subjectsfrom
Conversation
Tighten Int1D/IntND/DiscreteState/DiscreteAction/ScalarInt to Int32 in typing.py, and cast searchsorted/argmax/unravel_index/where outputs to int32 at every site where their width depended on jax_enable_x64. This prevents the JIT cache from silently splitting into per-period int32/int64 variants and breaks the AOT-compiled simulate program that ships a single signature. Adds a regression test asserting discrete grids, build_initial_states discrete entries, and MISSING_CAT_CODE match int32. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Benchmark comparison (main → HEAD)Comparing
|
`jnp.searchsorted` already returns int32 even with `jax_enable_x64`, so the four `.astype(jnp.int32)` casts in `grids/coordinates.py`, `grids/piecewise.py` (×2), and `simulation/simulate.py:_compute_starting_periods` were no-ops at the dtype level — but they sat between an integer-producing op and its index-consumer inside vmap'd interpolation kernels, breaking XLA's fusion and forcing the intermediate to materialise as a top-level GPU buffer per (period, regime, state). Likewise, the `unravel_index` output in `_lookup_values_from_indices` is consumed immediately by `grid[index]`, which accepts int64 fine — the cast served no purpose. Keeps the argmax cast on the solve path (real int64→int32 narrowing), the boundary casts at error/validation paths, and the AOT-relevant casts in `pandas_utils` and the `subject_regime_ids` sentinel. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
5 tasks
hmgaudecker
added a commit
that referenced
this pull request
May 4, 2026
Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
hmgaudecker
added a commit
that referenced
this pull request
May 4, 2026
Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
Stacks on #340. Finishes the int dtype lock-in started by the pinhole fixes
on
feat/simulate-aot-n-subjects(DiscreteGrid.to_jax → int32,build_initial_statescasts to grid dtype). Now every internal integerJAX array is
int32regardless ofjax_enable_x64, and the type aliasesin
src/lcm/typing.pyadvertise that contract sotyflags any futureregression at edit time.
Why this matters: the lazy JIT cache silently retraces per dtype, so
int32(no x64) vsint64(x64) variants of the same regime compiledinto different specializations. AOT compile via
jax.jit(...).lower(**args).compile()ships a single signature andbroke at runtime with
int32[N] vs int64[N]mismatches.Changes
src/lcm/typing.py— replaceIntwithInt32forInt1D,IntND,DiscreteState,DiscreteAction,ScalarInt. ~113 internal usagesinherit the dtype constraint without further edits.
jax_enable_x64:grids/coordinates.py:191—searchsortedgrids/piecewise.py:90, 161—searchsortedsimulation/simulate.py:386—searchsorted(starting periods)simulation/simulate.py:357—unravel_indexoutputsregime_building/argmax.py:46, 68— scalar fallback +argmaxsimulation/initial_conditions.py:365, 409—whereindex extractionsutils/error_handling.py:376—whereindex extractionpandas_utils.py:155— regime-id ingestion cast toint32to matchDiscreteGrid.to_jax().simulation/simulate.py:102—subject_regime_idssentinel bufferpinned to
int32regardless of input regime dtype.Test plan
pixi run -e tests-cpu pytest tests/ -n 7— 895 passed (3 newdtype-invariant tests added)
pixi run -e type-checking ty— cleanprek run --all-files— clean/tmp/smoke8.py) —simulate OK, n=20with all int initial-condition entries reporting
int32🤖 Generated with Claude Code