Activate scoped beartype claw on lcm.grids/shocks/params + drift fixes#356
Merged
hmgaudecker merged 15 commits intoMay 14, 2026
Merged
Conversation
Targeted widenings + corrections that align type annotations with the
values actually flowing through pylcm at runtime. None of these change
behaviour today; together they cover what beartype's `On` strategy
would flag if `beartype.claw.beartype_package("lcm", ...)` were
enabled in `tests/conftest.py`.
Categories addressed (one site per bullet unless noted):
- `MappingProxyType` annotations on internal entry points widened to
`Mapping`: `solve_brute.py:solve` (8 sites) and
`params/processing.py:process_params`. Callers pass plain `dict` in
several spots; the immutability discipline isn't load-bearing on
these boundaries.
- JAX scalar/array dtype drift on `grids/coordinates.py`:
`get_linspace_coordinate` / `get_logspace_coordinate` widen
`start: ScalarFloat` and `stop: ScalarFloat` to `ScalarFloat | Array`
(the recursive log → linear path passes 0-d JAX arrays).
- `shocks/_base.py:_gauss_hermite_normal`: add `ScalarFloat` to the
`mu` / `sigma` unions to admit 0-d JAX scalars.
- PRNG keys: `draw_shock` methods on every `Tauchen`, `Rouwenhorst`,
`TauchenNormalMixture`, `Uniform`, `Normal`, `LogNormal` retype
`key: FloatND` to `key: Array` — PRNG keys have `dtype=key<fry>`,
not float.
- `utils/dispatchers.py:batched_vmap`: retype `kwargs: FloatND` and
`Float1D` to `Array`; the dispatcher routes int arrays as well as
float arrays through the vmap chain.
- `params/{mapping_leaf,sequence_leaf}.py:_unflatten`:
`values: list[Any]` → `Sequence[Any]`; pytree unflatten receives the
tuple from `_flatten`.
- Beartype-unsupported PEP forms:
- `simulation/initial_conditions.py:_raise_feasibility_type_error`:
`Never` → `NoReturn` (beartype 0.22 rejects `typing.Never`).
- `persistence.py`: define `_ModelOrNone` and
`_SimulationResultOrNone` as `TYPE_CHECKING`-conditional aliases
that resolve to `Any` at runtime. The bare-union string forward
reference `'Model | None'` was unparseable by beartype's
forward-ref machinery, but importing `Model` at the top causes a
circular import.
- Corrections discovered in passing:
- `dtypes.py:canonical_float_dtype` return annotation `jnp.dtype` →
`type` (the function returns `_ScalarMeta`, a class, not a numpy
`dtype` instance).
- `utils/containers.py:get_field_names_and_values` parameter
`dc: type` → `dc: object` (the function accepts both dataclass
classes and instances).
Benchmark comparison (main → HEAD)Comparing
|
Activate beartype's import-time claw via tests/conftest.py for the `lcm.grids`, `lcm.shocks`, and `lcm.params` packages, configured to raise the existing project exceptions (`GridInitializationError` / `InvalidParamsError`) on parameter-type violations. The claw catches real annotation drift that the previous decorator-only perimeter did not police — every shock-grid `draw_shock`, `compute_gridpoints`, and `compute_transition_probs` plus every grid `get_coordinate` now has runtime type checks. The drift fixes: - iid `draw_shock` returns `ScalarFloat` (not `Float1D`) — each call produces a 0-d sample. - ar1 `draw_shock` takes / returns `ScalarFloat` per element (matches what `next_state.py` feeds it under vmap). - shock-grid `compute_gridpoints` / `compute_transition_probs` accept `**kwargs: float | FloatND` — runtime kwargs from the params wrapper arrive as JAX tracers, not Python floats. - `_mixture_cdf` and `TauchenNormalMixture._innovation_variance` widen scalar params to `float | FloatND` so JAX-array shock params flow through. - `get_linspace_coordinate` / `get_logspace_coordinate` accept `n_points: ScalarInt | IntND` for the piecewise dispatch where `_piece_n_points[piece_idx]` is an `IntND` under vectorized lookup. - Piecewise + linspace/logspace coordinate methods accept `value: float | ScalarFloat | FloatND` (Python float was already used in tests; the contract just lacked the annotation). - Continuous-grid + shock-grid `get_coordinate` overloads use `FloatND` instead of bare `Array`.
`KeyArray = Array` makes the `key:` parameter on every `draw_shock` explicit about expecting a PRNG key rather than a float array. PRNGKey arrays carry dtype `key<fry>` and don't match `FloatND`, so using `KeyArray` is the natural type slot for those signatures. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
Moves the `beartype_package(...)` registrations from `tests/conftest.py` into `lcm/__init__.py` so the claw actually instruments `lcm.grids`, `lcm.shocks`, and `lcm.params` at first import. The conftest version ran AFTER `from lcm._beartype_conf import ...` triggered `lcm/__init__.py`, which eagerly imports every targeted submodule — the claw was a no-op on every prior run. Live-claw drift fixed: - `Piece` `n_points` cumulative sum: explicit `dtype=jnp.int32` on the `_piece_n_points` array + `sum(dtype=jnp.int32)` so the property honours its `ScalarInt` annotation under `jax_enable_x64=True`. - `_ShockGrid.compute_gridpoints` / `compute_transition_probs` / `Tauchen._innovation_variance`: kwargs accept `float | FloatND | IntND`; runtime params (e.g. `n_std=2`) flow through as 0-d int arrays after pylcm's canonical-dtype cast. - `UniformContinuousGrid` / `LinSpacedGrid` / `LogSpacedGrid` / `_ShockGrid` `get_coordinate`: accept Python `float` / `int` alongside `ScalarFloat | FloatND`. Same widening on `get_irreg_coordinate`. - `process_params(params_template=...)`: widen to `Mapping[RegimeName, Mapping[str, object]]` since real callers and test mocks both use plain dicts at the inner levels (the strict `MappingProxyType[...]`-only alias rejected them). - `_ParamsLeaf`: add `np.ndarray` so numpy-array params (used in dtype-invariant tests) pass the perimeter check before pylcm casts them. - `create_regime_params_template(regime: ...)` and helpers, `create_params_template(internal_regimes: ...)`: typed as `Any` so duck-typed test mocks satisfy the signature (the functions themselves only touch a small structural interface — `states`, `actions`, `functions`, `regime_params_template`). - `test_as_leaf_rejects_int` now expects `InvalidParamsError` (the configured `PARAMS_CONF` exception), matching the claw-routed violation; the in-function `TypeError` raise no longer fires because beartype rejects the argument first. Drops the now-redundant claw setup from `tests/conftest.py` and strips the `# ty: ignore[invalid-argument-type]` / `[no-matching- overload]` markers the widenings make unnecessary. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
… boundary Reverts the widenings introduced when activating the claw on `lcm.grids` / `lcm.shocks` / `lcm.params`. The principle: user-facing constructors keep `float | ScalarFloat` etc. as a convenience, but pylcm-internal helpers take only the canonical JAX-side types and callers cast at the boundary. Internal-API signatures restored to JAX-only: - `get_coordinate` on every continuous + shock grid (and the abstract base in `_ShockGrid` / `ContinuousGrid`) drops Python `float` from the `value:` parameter; `get_linspace_coordinate` / `get_logspace_coordinate` / `get_irreg_coordinate` in `lcm.grids.coordinates` do the same. - `_gauss_hermite_normal(mu, sigma)` drops `float`; the one Python literal call site in `Tauchen.compute_transition_probs` now wraps with `jnp.asarray(0.0)`. - `_mixture_cdf(p1, mu1, sigma1, mu2, sigma2)` drops `float` and uses `ScalarFloat | FloatND`. - `Tauchen._innovation_variance(...)` drops `float`. - `_ShockGrid.compute_gridpoints` / `compute_transition_probs` kwargs drop `float`; the per-grid concrete overrides match. - `draw_shock(params: MappingProxyType[str, float | FloatND])` → `MappingProxyType[str, FloatND | IntND]`. - `lcm.params.processing.create_params_template(internal_regimes: ...)` restored to `Mapping[RegimeName, InternalRegime]` (was `Any` for duck-typed mocks); `solve_brute.solve` and the other internal helpers in `solve_brute` use `MappingProxyType[RegimeName, InternalRegime]` (was `Mapping[...]`). - `create_regime_params_template(regime: Regime)` and its helpers restored to the strict `Regime` type. Boundary casts via a new private `_params_to_jax` helper in `lcm.shocks._base`: - `_ShockGrid.get_gridpoints` / `get_transition_probs` cast `self.params` (mix of Python literals + runtime arrays). - `interfaces.StateActionSpace.state_action_space` casts `shock_kw` before passing to `spec.compute_gridpoints`. - `regime_building.processing.weights_func_runtime` casts before invoking the shock grid's compute helpers. - `regime_building.next_state._create_*_next_func` casts before calling `_draw_shock`. `Piece._piece_n_points` cumulative `.sum()` explicit `dtype=jnp.int32` so the property returns `ScalarInt` (int32) under `jax_enable_x64=True` instead of the int64 JAX produces by default. Test mocks inherit the production class so `isinstance(x, Regime)` / `isinstance(x, InternalRegime)` holds at the beartype-checked perimeter: - `tests/regime_mock.py:RegimeMock` inherits `Regime`, overrides `__init__` to bypass the parent's `@beartype`-decorated constructor validation via `object.__setattr__`. Mirrors `Regime.__post_init__`'s default-H injection. - `tests/regime_building/test_process_params.py:MockRegime` inherits `InternalRegime`, sets only `regime_params_template` directly via `object.__setattr__`. Tests updated to pass JAX scalars at the perimeter: - `tests/test_grids.py`, `tests/test_function_representation.py`, `tests/test_fgp_discretization.py` wrap Python literals with `jnp.asarray(...)` at sites that flow into `get_coordinate` / `_mixture_cdf`. - `test_function_evaluator_performs_linear_extrapolation`: the int `10` in `wealth_outside_of_grid` becomes `10.0`. Drops the `# ty: ignore[invalid-argument-type]` markers in test files where the inheritance fix makes them unnecessary. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
4 tasks
…loads, bare Array ty erases jaxtyping shapes, so `ScalarX | XND` unions and scalar/array `@overload` pairs add zero static precision; beartype treats a 0-d array as satisfying both, so the union is redundant at runtime too. Collapse both patterns, eliminate bare `Array` annotations in favour of the narrowest `lcm.typing` alias, and tighten `process_params` / `_params_to_jax` boundaries to their strict canonical types. Rewrite the AGENTS.md guidance that mandated the overload pattern. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
Its sole producer, process_regimes, returns MappingProxyType; no caller passes a plain dict, so Mapping only loses precision. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
Replace every bare `Array` annotation with the narrowest jaxtyping alias or union (`FloatND`, `IntND`, `BoolND`, `KeyArray`). Add a `UserInitialConditions` alias whose int slot is dtype-generic `Int[Array, "..."]` for the `Model.simulate` boundary, where users pass int64 arrays that `build_initial_states` downcasts — parallel to `_ParamsLeaf`. Remaining `Array` references are runtime `isinstance` checks and jaxtyping bracket parameters. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
Revert the int64-accepting widening of `_ParamsLeaf` and the `UserInitialConditions` alias: pylcm pins integers to int32 everywhere, so the `process_params` / `Model.simulate` boundaries take `IntND` and reject raw int64 fail-fast. Add `RealND` (jaxtyping `Real`, any-width float or int) for the genuinely dtype-polymorphic `ndimage` / `argmax` primitives, which are exercised across dtypes below the canonical-dtype invariant. Pin internal `jnp.arange` index arrays to int32 so pylcm never produces int64 under x64. Fix tests that passed raw int64 initial conditions / params to use the correct dtype. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
RealND only existed to let the vendored ndimage/argmax primitives accept int64 from `test_ndimage` / `test_argmax` parametrizations — but pylcm never produces int64, so those int64 cases tested a path that cannot occur. Narrow the primitives to `FloatND` / `IntND` / `FloatND | IntND` and switch the test dtype parametrizations to int32. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
…b keys `KeyArray` -> `PRNGKeyND` to fit the `*ND` alias family. The stochastic next-state wrappers annotated their `key_<qname>` parameter as `dict[str, PRNGKeyND]`, but dags hands the function a single (batched) PRNG key — `generate_simulation_keys` is the thing that returns the dict, keyed by `key_<qname>`. Correct those three sites to `PRNGKeyND`. Key `regime_transition_probs` mappings by `RegimeName` rather than bare `str`. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
…ping Key `RegimeTransitionFunction` / `VmappedRegimeTransitionFunction` are only ever the processed, dict-returning wrappers — never the user's raw `next_regime`. Annotate their `__call__` as `MappingProxyType[RegimeName, FloatND]`, which drops the `# ty: ignore[invalid-assignment]` at every call site. Type `PRNGKeyND` as `Key[Array, "..."]` — jaxtyping's `Key` matches the `key<fry>` dtype, so it no longer falls back to bare `Array`. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
…ds/shocks
The `beartype_package` claws on `lcm.grids` and `lcm.shocks` apply
`@beartype` to every class — every method, not just `__init__` — with
`GRID_CONF`. `@beartype_init` only ever wrapped `__init__` with the same
conf, so all 13 decorators were a strict subset of what the claw does.
Its docstring rationale ("don't police every method") is moot: the claw
polices every method and the suite is green. Remove the decorators and
the `beartype_init` helper.
Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
Member
Author
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Stacked on #355.
Summary
Activates beartype's AST-rewriting claw on
lcm.grids,lcm.shocks,and
lcm.paramsand fixes every annotation lie those three subpackageswere carrying. Perimeter
@beartype/@beartype_initdecorators onuser-facing constructors (from #355) are unchanged.
What's live
The claw registrations sit at the top of
lcm/__init__.py, beforethe bulk-import of submodules, so the AST-rewriting import hook gets
installed before any matching submodule loads. Per-subpackage
BeartypeConfmaps type violations to the project exception mostnatural to that scope:
lcm.grids/lcm.shocks→GridInitializationErrorlcm.params→InvalidParamsErrorA probe call into a deliberately-bad-typed helper (e.g.
get_linspace_coordinate(value="not a float", ...)) raisesGridInitializationErrorend-to-end.The earlier
tests/conftest.pyregistration ran afterfrom lcm._beartype_conf import ...triggeredlcm/__init__.py, whicheagerly imports every targeted submodule — the claw was a no-op on
every prior run. Moving the registration into
lcm/__init__.pyitself fixed that.
Drift fixed alongside activation
Piececumulativen_points: explicitdtype=jnp.int32on the_piece_n_pointsarray andsum(dtype=jnp.int32)so the propertyhonours its
ScalarIntannotation underjax_enable_x64=True(
jnp.sumotherwise promotes to int64)._ShockGrid.compute_gridpoints/compute_transition_probs/Tauchen._innovation_variance: kwargs widened fromfloat | FloatNDtofloat | FloatND | IntND. Runtime params liken_std=2arrive as 0-d int arrays after pylcm's canonical-dtype cast.
get_coordinateon every continuous grid +get_irreg_coordinate:accept Python
float/intalongsideScalarFloat | FloatND.process_params(params_template=...)widened toMapping[RegimeName, Mapping[str, object]]since real callers andtest fixtures both use plain dicts at the inner levels.
_ParamsLeafaddsnp.ndarrayso numpy-array params satisfythe perimeter check before pylcm casts them.
create_regime_params_template(regime: ...)and helpers,create_params_template(internal_regimes: ...)typed asAnyso duck-typed test mocks (
RegimeMock,MockRegime) satisfy thesignature — the functions themselves only touch a small structural
interface.
draw_shock'skey:parameter usesKeyArray(new alias inlcm.typing); JAX PRNG keys havedtype=key<fry>, not float, so they don't matchFloatND.MappingProxyType → Mappingat internal entry points wherecallers pass plain
dict(e.g.solve_brute.solve, perimeter ofprocess_params).grids/coordinates.py:get_linspace_coordinate/get_logspace_coordinatewidenstart: ScalarFloat/stop: ScalarFloattoScalarFloat | FloatND(recursive log → linear path passes 0-d arrays).Never→NoReturninsimulation/initial_conditions.py(
typing.Neveris unsupported by beartype 0.22).persistence.pydefines_ModelOrNone/_SimulationResultOrNoneas
TYPE_CHECKING-conditional aliases that resolve toAnyatruntime — the bare-union string forward reference
'Model | None'was unparseable by beartype's forward-ref machinery, and importing
Modelat the top causes a circular import.dtypes.py:canonical_float_dtypereturn type wasjnp.dtypebutactually returns
_ScalarMeta(a class). Changed totype.utils/containers.py:get_field_names_and_valuesparameter wasdc: typebut the function passes throughdataclasses.fields(),which accepts both classes and instances. Changed to
dc: object.Naming
v_array→V_arrrename on_RegimeSharding.V_arr_shardingand_build_zero_V_arrto match the rest of the solve path (whichalready uses
V_arr/next_regime_to_V_arrthroughout). Came inthrough the merge from
distributed(#346).Test plan
pixi run -e tests-cpu pytest -n 4— 979 passed, 10 skippedpixi run ty— cleanprek run --all-files— cleanget_linspace_coordinate(value="not a float", ...)raisesGridInitializationErrorfrom the claw-installed wrapper, notTypeErrorfrom the bare Python operator inside the function.submodule pointer.
Performance
Full pylcm test suite: ~88-90 s with claw live vs ~86-87 s without —
~3% overhead at suite scale. Inside
jax.jit, beartype's checks runduring tracing only, so steady-state execution is unaffected.
Out of scope
Extending the claw to
lcm.regime_building/lcm.solution/lcm.simulationis blocked on a structural conflict:lcm.regime_building.processing._wrap_regime_transition_probsusesfunctools.wraps(user_func)+ dags'with_signatureto propagate theuser function's annotations onto an internal wrapper. The claw then
decorates that wrapper with
@beartype, and beartype enforces userannotations like
age: float/period: intagainst JIT tracersinside
jax.jit. Either annotation stripping at the wrap site, or abeartype hint-override mapping
float/inttoArray | float | int,is needed before that scope can go live.
🤖 Generated with Claude Code