Activate scoped beartype claw on lcm.grids/shocks/params + drift fixes #356

Merged
hmgaudecker merged 15 commits into feat/beartype-perimeter from feat/beartype-claw-cleanup
May 14, 2026
@hmgaudecker hmgaudecker commented May 13, 2026

Stacked on #355.

Summary

Activates beartype's AST-rewriting claw on lcm.grids, lcm.shocks,
and lcm.params and fixes every annotation lie those three subpackages
were carrying. Perimeter @beartype / @beartype_init decorators on
user-facing constructors (from #355) are unchanged.

What's live

The claw registrations sit at the top of lcm/__init__.py, before
the bulk-import of submodules, so the AST-rewriting import hook gets
installed before any matching submodule loads. Per-subpackage
BeartypeConf maps type violations to the project exception most
natural to that scope:

  • lcm.grids / lcm.shocks → GridInitializationError
  • lcm.params → InvalidParamsError

A probe call into a deliberately-bad-typed helper (e.g.
get_linspace_coordinate(value="not a float", ...)) raises
GridInitializationError end-to-end.

The earlier tests/conftest.py registration ran after from lcm._beartype_conf import ... triggered lcm/__init__.py, which
eagerly imports every targeted submodule — the claw was a no-op on
every prior run. Moving the registration into lcm/__init__.py
itself fixed that.
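The ordering problem is a general property of Python import hooks, and can be reproduced with the standard library alone. The sketch below is an analogy, not beartype's mechanism: `RecordingFinder` is a hypothetical hook that only records which modules it is asked about, and `json` / `colorsys` are arbitrary stand-ins for an already-imported and a not-yet-imported module.

```python
import importlib
import sys
from importlib.abc import MetaPathFinder

import json  # imported BEFORE the hook is installed


class RecordingFinder(MetaPathFinder):
    """Hypothetical import hook: records lookups and loads nothing itself."""

    def __init__(self):
        self.seen = []

    def find_spec(self, fullname, path, target=None):
        self.seen.append(fullname)
        return None  # defer to the regular finders


finder = RecordingFinder()
sys.meta_path.insert(0, finder)
try:
    # Already in sys.modules: served from the cache, the hook is never asked.
    importlib.import_module("json")
    # Force a fresh import so the hook is consulted for this module.
    sys.modules.pop("colorsys", None)
    importlib.import_module("colorsys")
finally:
    sys.meta_path.remove(finder)

hook_saw_json = "json" in finder.seen
hook_saw_colorsys = "colorsys" in finder.seen
```

A hook installed after a module has been imported gets no chance to transform it, which is consistent with the registration having to precede the submodule imports.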

Drift fixed alongside activation

  • Piece cumulative n_points: explicit dtype=jnp.int32 on the
    _piece_n_points array and sum(dtype=jnp.int32) so the property
    honours its ScalarInt annotation under jax_enable_x64=True
    (jnp.sum otherwise promotes to int64).
  • _ShockGrid.compute_gridpoints / compute_transition_probs /
    Tauchen._innovation_variance: kwargs widened from
    float | FloatND to float | FloatND | IntND. Runtime params like
    n_std=2 arrive as 0-d int arrays after pylcm's canonical-dtype cast.
  • get_coordinate on every continuous grid + get_irreg_coordinate:
    accept Python float / int alongside ScalarFloat | FloatND.
  • process_params(params_template=...) widened to
    Mapping[RegimeName, Mapping[str, object]] since real callers and
    test fixtures both use plain dicts at the inner levels.
  • _ParamsLeaf adds np.ndarray so numpy-array params satisfy
    the perimeter check before pylcm casts them.
  • create_regime_params_template(regime: ...) and helpers,
    create_params_template(internal_regimes: ...) typed as Any
    so duck-typed test mocks (RegimeMock, MockRegime) satisfy the
    signature — the functions themselves only touch a small structural
    interface.
  • PRNG keys: every draw_shock's key: parameter uses
    KeyArray (new alias in lcm.typing); JAX PRNG keys have
    dtype=key<fry>, not float, so they don't match FloatND.
  • MappingProxyType → Mapping at internal entry points where
    callers pass plain dict (e.g. solve_brute.solve, perimeter of
    process_params).
  • JAX scalar/array dtype drift in grids/coordinates.py:
    get_linspace_coordinate / get_logspace_coordinate widen
    start: ScalarFloat / stop: ScalarFloat to ScalarFloat | FloatND (recursive log → linear path passes 0-d arrays).
  • Never → NoReturn in simulation/initial_conditions.py
    (typing.Never is unsupported by beartype 0.22).
  • persistence.py defines _ModelOrNone / _SimulationResultOrNone
    as TYPE_CHECKING-conditional aliases that resolve to Any at
    runtime — the bare-union string forward reference 'Model | None'
    was unparseable by beartype's forward-ref machinery, and importing
    Model at the top causes a circular import.
  • Two corrections found in passing:
    • dtypes.py:canonical_float_dtype return type was jnp.dtype but
      actually returns _ScalarMeta (a class). Changed to type.
    • utils/containers.py:get_field_names_and_values parameter was
      dc: type but the function passes through dataclasses.fields(),
      which accepts both classes and instances. Changed to dc: object.

Naming

v_array → V_arr rename on _RegimeSharding.V_arr_sharding and
_build_zero_V_arr to match the rest of the solve path (which
already uses V_arr / next_regime_to_V_arr throughout). Came in
through the merge from distributed (#346).

Test plan

  • pixi run -e tests-cpu pytest -n 4 — 979 passed, 10 skipped
  • pixi run ty — clean
  • prek run --all-files — clean
  • Probe call confirms the claw is genuinely live: a wrong-typed
    get_linspace_coordinate(value="not a float", ...) raises
    GridInitializationError from the claw-installed wrapper, not
    TypeError from the bare Python operator inside the function.
  • aca-dev workspace tests (414 passed) green against this branch's
    submodule pointer.

Performance

Full pylcm test suite: ~88-90 s with claw live vs ~86-87 s without —
~3% overhead at suite scale. Inside jax.jit, beartype's checks run
during tracing only, so steady-state execution is unaffected.
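The tracing claim can be checked with a small JAX sketch. A Python-level counter stands in for any Python-side work such as a runtime type check; the function name `scaled` and the counter are illustrative only.

```python
import jax
import jax.numpy as jnp

trace_count = 0


@jax.jit
def scaled(x):
    # Python-level code (where a runtime type check would live) runs only
    # while JAX traces the function, not on cached executions.
    global trace_count
    trace_count += 1
    return 2.0 * x


x = jnp.ones(3)
for _ in range(3):
    scaled(x)  # same shape and dtype each time, so the trace is reused
```

A call with a new shape or dtype would trigger another trace, so the Python-side cost recurs per compiled signature rather than per call.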

Out of scope

Extending the claw to lcm.regime_building / lcm.solution /
lcm.simulation is blocked on a structural conflict:
lcm.regime_building.processing._wrap_regime_transition_probs uses
functools.wraps(user_func) + dags' with_signature to propagate the
user function's annotations onto an internal wrapper. The claw then
decorates that wrapper with @beartype, and beartype enforces user
annotations like age: float / period: int against JIT tracers
inside jax.jit. Either annotation stripping at the wrap site, or a
beartype hint-override mapping float/int to Array | float | int,
is needed before that scope can go live.
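How annotations end up on the internal wrapper can be seen with `functools.wraps` alone. This sketch leaves out dags' `with_signature` and uses a hypothetical `wrap` factory and `user_func`; it only shows that the wrapper inherits the user's annotations, which is what a runtime checker would then enforce.

```python
import functools


def user_func(age: float, period: int) -> float:
    """Stand-in for a user-supplied function with scalar annotations."""
    return age + period


def wrap(func):
    """Hypothetical internal wrapper factory using functools.wraps."""

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        return func(*args, **kwargs)

    return wrapper


wrapped = wrap(user_func)
# The wrapper itself declares no annotations, yet it now reports the user's.
propagated = dict(wrapped.__annotations__)
```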

🤖 Generated with Claude Code

Targeted widenings + corrections that align type annotations with the
values actually flowing through pylcm at runtime. None of these change
behaviour today; together they cover what beartype's `On` strategy
would flag if `beartype.claw.beartype_package("lcm", ...)` were
enabled in `tests/conftest.py`.

Categories addressed (one site per bullet unless noted):

- `MappingProxyType` annotations on internal entry points widened to
  `Mapping`: `solve_brute.py:solve` (8 sites) and
  `params/processing.py:process_params`. Callers pass plain `dict` in
  several spots; the immutability discipline isn't load-bearing on
  these boundaries.
- JAX scalar/array dtype drift on `grids/coordinates.py`:
  `get_linspace_coordinate` / `get_logspace_coordinate` widen
  `start: ScalarFloat` and `stop: ScalarFloat` to `ScalarFloat | Array`
  (the recursive log → linear path passes 0-d JAX arrays).
- `shocks/_base.py:_gauss_hermite_normal`: add `ScalarFloat` to the
  `mu` / `sigma` unions to admit 0-d JAX scalars.
- PRNG keys: `draw_shock` methods on every `Tauchen`, `Rouwenhorst`,
  `TauchenNormalMixture`, `Uniform`, `Normal`, `LogNormal` retype
  `key: FloatND` to `key: Array` — PRNG keys have `dtype=key<fry>`,
  not float.
- `utils/dispatchers.py:batched_vmap`: retype `kwargs: FloatND` and
  `Float1D` to `Array`; the dispatcher routes int arrays as well as
  float arrays through the vmap chain.
- `params/{mapping_leaf,sequence_leaf}.py:_unflatten`:
  `values: list[Any]` → `Sequence[Any]`; pytree unflatten receives the
  tuple from `_flatten`.
- Beartype-unsupported PEP forms:
  - `simulation/initial_conditions.py:_raise_feasibility_type_error`:
    `Never` → `NoReturn` (beartype 0.22 rejects `typing.Never`).
  - `persistence.py`: define `_ModelOrNone` and
    `_SimulationResultOrNone` as `TYPE_CHECKING`-conditional aliases
    that resolve to `Any` at runtime. The bare-union string forward
    reference `'Model | None'` was unparseable by beartype's
    forward-ref machinery, but importing `Model` at the top causes a
    circular import.
- Corrections discovered in passing:
  - `dtypes.py:canonical_float_dtype` return annotation `jnp.dtype` →
    `type` (the function returns `_ScalarMeta`, a class, not a numpy
    `dtype` instance).
  - `utils/containers.py:get_field_names_and_values` parameter
    `dc: type` → `dc: object` (the function accepts both dataclass
    classes and instances).
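The last correction rests on `dataclasses.fields()` accepting both a dataclass and an instance of one. A minimal standard-library sketch, with a hypothetical `Point` dataclass and a simplified `field_names` helper in place of `get_field_names_and_values`:

```python
import dataclasses


@dataclasses.dataclass(frozen=True)
class Point:
    """Hypothetical dataclass used only for this illustration."""

    x: float = 0.0
    y: float = 1.0


def field_names(dc: object) -> list[str]:
    """Works for a dataclass class or an instance, hence `dc: object`."""
    return [f.name for f in dataclasses.fields(dc)]


from_class = field_names(Point)
from_instance = field_names(Point(x=2.0))
```

With `dc: type`, a runtime checker would reject the instance call even though the function body handles it.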
read-the-docs-community Bot commented May 13, 2026

github-actions Bot commented May 13, 2026

Benchmark comparison (main → HEAD)

Comparing 49b1408d (main) → bfc7c35a (HEAD)

| Benchmark | Statistic | before | after | Ratio | Alert |
|---|---|---|---|---|---|
| aca-baseline | execution time | 28.009 s | 25.345 s | 0.90 | |
| | peak GPU mem | 579 MB | 2.90 GB | 5.00 | |
| | compilation time | 300.18 s | 304.03 s | 1.01 | |
| | peak CPU mem | 7.57 GB | 7.57 GB | 1.00 | |
| Mahler-Yum | execution time | 4.678 s | 4.432 s | 0.95 | |
| | peak GPU mem | 529 MB | 529 MB | 1.00 | |
| | compilation time | 14.08 s | 14.14 s | 1.00 | |
| | peak CPU mem | 1.68 GB | 1.70 GB | 1.01 | |
| Precautionary Savings - Solve | execution time | 52.0 ms | 41.0 ms | 0.79 | |
| | peak GPU mem | 101 MB | 101 MB | 1.00 | |
| | compilation time | 2.64 s | 2.78 s | 1.05 | |
| | peak CPU mem | 1.13 GB | 1.13 GB | 1.00 | |
| Precautionary Savings - Simulate | execution time | 122.6 ms | 98.9 ms | 0.81 | |
| | peak GPU mem | 344 MB | 344 MB | 1.00 | |
| | compilation time | 4.68 s | 4.84 s | 1.03 | |
| | peak CPU mem | 1.30 GB | 1.31 GB | 1.01 | |
| Precautionary Savings - Solve & Simulate | execution time | 148.3 ms | 127.3 ms | 0.86 | |
| | peak GPU mem | 578 MB | 578 MB | 1.00 | |
| | compilation time | 7.41 s | 7.16 s | 0.97 | |
| | peak CPU mem | 1.29 GB | 1.28 GB | 1.00 | |
| Precautionary Savings - Solve & Simulate (irreg) | execution time | 283.2 ms | 260.4 ms | 0.92 | |
| | peak GPU mem | 2.19 GB | 2.19 GB | 1.00 | |
| | compilation time | 7.15 s | 7.64 s | 1.07 | |
| | peak CPU mem | 1.34 GB | 1.34 GB | 1.00 | |

hmgaudecker and others added 5 commits May 13, 2026 18:25
Activate beartype's import-time claw via tests/conftest.py for the
`lcm.grids`, `lcm.shocks`, and `lcm.params` packages, configured to
raise the existing project exceptions (`GridInitializationError` /
`InvalidParamsError`) on parameter-type violations. The claw catches
real annotation drift that the previous decorator-only perimeter did
not police — every shock-grid `draw_shock`, `compute_gridpoints`, and
`compute_transition_probs` plus every grid `get_coordinate` now has
runtime type checks.

The drift fixes:

- iid `draw_shock` returns `ScalarFloat` (not `Float1D`) — each call
  produces a 0-d sample.
- ar1 `draw_shock` takes / returns `ScalarFloat` per element (matches
  what `next_state.py` feeds it under vmap).
- shock-grid `compute_gridpoints` / `compute_transition_probs` accept
  `**kwargs: float | FloatND` — runtime kwargs from the params wrapper
  arrive as JAX tracers, not Python floats.
- `_mixture_cdf` and `TauchenNormalMixture._innovation_variance` widen
  scalar params to `float | FloatND` so JAX-array shock params flow
  through.
- `get_linspace_coordinate` / `get_logspace_coordinate` accept
  `n_points: ScalarInt | IntND` for the piecewise dispatch where
  `_piece_n_points[piece_idx]` is an `IntND` under vectorized lookup.
- Piecewise + linspace/logspace coordinate methods accept
  `value: float | ScalarFloat | FloatND` (Python float was already
  used in tests; the contract just lacked the annotation).
- Continuous-grid + shock-grid `get_coordinate` overloads use
  `FloatND` instead of bare `Array`.

`KeyArray = Array` makes the `key:` parameter on every `draw_shock`
explicit about expecting a PRNG key rather than a float array. PRNGKey
arrays carry dtype `key<fry>` and don't match `FloatND`, so using
`KeyArray` is the natural type slot for those signatures.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
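The dtype mismatch behind the `KeyArray` alias can be observed directly in JAX. This is a small sketch assuming new-style typed keys created with `jax.random.key`; the variable names are illustrative.

```python
import jax
import jax.numpy as jnp

key = jax.random.key(0)  # new-style typed PRNG key
values = jnp.zeros(3)    # ordinary float array

# A typed key has an extended PRNG-key dtype, which a float-array
# annotation cannot match.
key_is_prng = bool(jnp.issubdtype(key.dtype, jax.dtypes.prng_key))
values_is_float = bool(jnp.issubdtype(values.dtype, jnp.floating))
values_is_prng = bool(jnp.issubdtype(values.dtype, jax.dtypes.prng_key))
```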
Moves the `beartype_package(...)` registrations from `tests/conftest.py`
into `lcm/__init__.py` so the claw actually instruments `lcm.grids`,
`lcm.shocks`, and `lcm.params` at first import. The conftest version
ran AFTER `from lcm._beartype_conf import ...` triggered
`lcm/__init__.py`, which eagerly imports every targeted submodule —
the claw was a no-op on every prior run.

Live-claw drift fixed:

- `Piece` `n_points` cumulative sum: explicit `dtype=jnp.int32` on the
  `_piece_n_points` array + `sum(dtype=jnp.int32)` so the property
  honours its `ScalarInt` annotation under `jax_enable_x64=True`.
- `_ShockGrid.compute_gridpoints` / `compute_transition_probs` /
  `Tauchen._innovation_variance`: kwargs accept `float | FloatND |
  IntND`; runtime params (e.g. `n_std=2`) flow through as 0-d int
  arrays after pylcm's canonical-dtype cast.
- `UniformContinuousGrid` / `LinSpacedGrid` / `LogSpacedGrid` /
  `_ShockGrid` `get_coordinate`: accept Python `float` / `int`
  alongside `ScalarFloat | FloatND`. Same widening on
  `get_irreg_coordinate`.
- `process_params(params_template=...)`: widen to
  `Mapping[RegimeName, Mapping[str, object]]` since real callers and
  test mocks both use plain dicts at the inner levels (the strict
  `MappingProxyType[...]`-only alias rejected them).
- `_ParamsLeaf`: add `np.ndarray` so numpy-array params (used in
  dtype-invariant tests) pass the perimeter check before pylcm casts
  them.
- `create_regime_params_template(regime: ...)` and helpers,
  `create_params_template(internal_regimes: ...)`: typed as `Any` so
  duck-typed test mocks satisfy the signature (the functions
  themselves only touch a small structural interface — `states`,
  `actions`, `functions`, `regime_params_template`).
- `test_as_leaf_rejects_int` now expects `InvalidParamsError` (the
  configured `PARAMS_CONF` exception), matching the claw-routed
  violation; the in-function `TypeError` raise no longer fires
  because beartype rejects the argument first.

Drops the now-redundant claw setup from `tests/conftest.py` and
strips the `# ty: ignore[invalid-argument-type]` / `[no-matching-
overload]` markers the widenings make unnecessary.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
hmgaudecker changed the title from "Annotation drift sweep — prep for whole-package beartype claw" to "Activate scoped beartype claw on lcm.grids/shocks/params + drift fixes" May 13, 2026
… boundary

Reverts the widenings introduced when activating the claw on
`lcm.grids` / `lcm.shocks` / `lcm.params`. The principle: user-facing
constructors keep `float | ScalarFloat` etc. as a convenience, but
pylcm-internal helpers take only the canonical JAX-side types and
callers cast at the boundary.

Internal-API signatures restored to JAX-only:

- `get_coordinate` on every continuous + shock grid (and the abstract
  base in `_ShockGrid` / `ContinuousGrid`) drops Python `float` from
  the `value:` parameter; `get_linspace_coordinate` /
  `get_logspace_coordinate` / `get_irreg_coordinate` in
  `lcm.grids.coordinates` do the same.
- `_gauss_hermite_normal(mu, sigma)` drops `float`; the one Python
  literal call site in `Tauchen.compute_transition_probs` now wraps
  with `jnp.asarray(0.0)`.
- `_mixture_cdf(p1, mu1, sigma1, mu2, sigma2)` drops `float` and uses
  `ScalarFloat | FloatND`.
- `Tauchen._innovation_variance(...)` drops `float`.
- `_ShockGrid.compute_gridpoints` / `compute_transition_probs` kwargs
  drop `float`; the per-grid concrete overrides match.
- `draw_shock(params: MappingProxyType[str, float | FloatND])` →
  `MappingProxyType[str, FloatND | IntND]`.
- `lcm.params.processing.create_params_template(internal_regimes: ...)`
  restored to `Mapping[RegimeName, InternalRegime]` (was `Any` for
  duck-typed mocks); `solve_brute.solve` and the other internal
  helpers in `solve_brute` use `MappingProxyType[RegimeName,
  InternalRegime]` (was `Mapping[...]`).
- `create_regime_params_template(regime: Regime)` and its helpers
  restored to the strict `Regime` type.

Boundary casts via a new private `_params_to_jax` helper in
`lcm.shocks._base`:

- `_ShockGrid.get_gridpoints` / `get_transition_probs` cast
  `self.params` (mix of Python literals + runtime arrays).
- `interfaces.StateActionSpace.state_action_space` casts `shock_kw`
  before passing to `spec.compute_gridpoints`.
- `regime_building.processing.weights_func_runtime` casts before
  invoking the shock grid's compute helpers.
- `regime_building.next_state._create_*_next_func` casts before
  calling `_draw_shock`.
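The body of `_params_to_jax` is not shown in this PR. The sketch below is one plausible shape for such a boundary cast, written under assumptions: the function name `params_to_jax_sketch`, the int32 policy for Python ints, and the example parameter names are all hypothetical.

```python
from collections.abc import Mapping
from types import MappingProxyType

import jax.numpy as jnp


def params_to_jax_sketch(params: Mapping[str, object]) -> MappingProxyType:
    """Cast every value to a JAX array at the boundary (assumed policy)."""
    out = {}
    for name, value in params.items():
        if isinstance(value, int) and not isinstance(value, bool):
            # Python ints become 0-d int32 arrays.
            out[name] = jnp.asarray(value, dtype=jnp.int32)
        else:
            # Floats and existing arrays go through jnp.asarray unchanged
            # apart from becoming JAX arrays.
            out[name] = jnp.asarray(value)
    return MappingProxyType(out)


shock_kw = params_to_jax_sketch({"rho": 0.9, "sigma": 0.1, "n_std": 2})
```

Casting once at the call site lets the internal helpers keep JAX-only annotations while user-facing constructors still accept Python literals.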

`Piece._piece_n_points` cumulative `.sum()` explicit `dtype=jnp.int32`
so the property returns `ScalarInt` (int32) under
`jax_enable_x64=True` instead of the int64 JAX produces by default.
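The effect of pinning the reduction dtype can be sketched as follows. The array values are illustrative, and enabling x64 through `jax.config.update` is assumed to happen before any arrays are created.

```python
import jax

jax.config.update("jax_enable_x64", True)  # set before creating arrays

import jax.numpy as jnp

piece_n_points = jnp.array([3, 4, 5], dtype=jnp.int32)  # illustrative values

# Without an explicit dtype the reduction may promote the integer type;
# passing dtype pins the result to int32.
default_total = piece_n_points.sum()
pinned_total = piece_n_points.sum(dtype=jnp.int32)
```

Comparing `default_total.dtype` with `pinned_total.dtype` in a session with x64 enabled shows whether promotion occurred.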

Test mocks inherit the production class so `isinstance(x, Regime)` /
`isinstance(x, InternalRegime)` holds at the beartype-checked
perimeter:

- `tests/regime_mock.py:RegimeMock` inherits `Regime`, overrides
  `__init__` to bypass the parent's `@beartype`-decorated constructor
  validation via `object.__setattr__`. Mirrors `Regime.__post_init__`'s
  default-H injection.
- `tests/regime_building/test_process_params.py:MockRegime` inherits
  `InternalRegime`, sets only `regime_params_template` directly via
  `object.__setattr__`.
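The mock pattern described above can be sketched with a frozen dataclass from the standard library. The `Regime` class here is a stand-in with made-up fields and validation, not pylcm's class; only the technique (inherit, skip the parent constructor, set attributes through `object.__setattr__`) mirrors the description.

```python
import dataclasses


@dataclasses.dataclass(frozen=True)
class Regime:
    """Stand-in for a production class with a validating constructor."""

    name: str
    n_periods: int

    def __post_init__(self):
        if self.n_periods <= 0:
            raise ValueError("n_periods must be positive")


class RegimeMock(Regime):
    """Test double: is-a Regime, but skips the parent's constructor."""

    def __init__(self, **attrs):
        # Frozen dataclasses block normal attribute assignment, so set
        # fields directly and never call Regime.__init__.
        for key, value in attrs.items():
            object.__setattr__(self, key, value)


mock = RegimeMock(name="working")  # n_periods deliberately left unset
```

Because the mock passes `isinstance(mock, Regime)`, a strict `Regime` annotation at a checked boundary accepts it without widening the signature to `Any`.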

Tests updated to pass JAX scalars at the perimeter:

- `tests/test_grids.py`, `tests/test_function_representation.py`,
  `tests/test_fgp_discretization.py` wrap Python literals with
  `jnp.asarray(...)` at sites that flow into `get_coordinate` /
  `_mixture_cdf`.
- `test_function_evaluator_performs_linear_extrapolation`: the int
  `10` in `wealth_outside_of_grid` becomes `10.0`.

Drops the `# ty: ignore[invalid-argument-type]` markers in test files
where the inheritance fix makes them unnecessary.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
hmgaudecker and others added 8 commits May 14, 2026 09:45
…loads, bare Array

ty erases jaxtyping shapes, so `ScalarX | XND` unions and scalar/array
`@overload` pairs add zero static precision; beartype treats a 0-d array
as satisfying both, so the union is redundant at runtime too. Collapse
both patterns, eliminate bare `Array` annotations in favour of the
narrowest `lcm.typing` alias, and tighten `process_params` /
`_params_to_jax` boundaries to their strict canonical types. Rewrite the
AGENTS.md guidance that mandated the overload pattern.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>

Its sole producer, process_regimes, returns MappingProxyType; no caller
passes a plain dict, so Mapping only loses precision.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>

Replace every bare `Array` annotation with the narrowest jaxtyping alias
or union (`FloatND`, `IntND`, `BoolND`, `KeyArray`). Add a
`UserInitialConditions` alias whose int slot is dtype-generic
`Int[Array, "..."]` for the `Model.simulate` boundary, where users pass
int64 arrays that `build_initial_states` downcasts — parallel to
`_ParamsLeaf`. Remaining `Array` references are runtime `isinstance`
checks and jaxtyping bracket parameters.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>

Revert the int64-accepting widening of `_ParamsLeaf` and the
`UserInitialConditions` alias: pylcm pins integers to int32 everywhere,
so the `process_params` / `Model.simulate` boundaries take `IntND` and
reject raw int64 fail-fast. Add `RealND` (jaxtyping `Real`, any-width
float or int) for the genuinely dtype-polymorphic `ndimage` / `argmax`
primitives, which are exercised across dtypes below the canonical-dtype
invariant. Pin internal `jnp.arange` index arrays to int32 so pylcm
never produces int64 under x64. Fix tests that passed raw int64 initial
conditions / params to use the correct dtype.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>

RealND only existed to let the vendored ndimage/argmax primitives accept
int64 from `test_ndimage` / `test_argmax` parametrizations — but pylcm
never produces int64, so those int64 cases tested a path that cannot
occur. Narrow the primitives to `FloatND` / `IntND` / `FloatND | IntND`
and switch the test dtype parametrizations to int32.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
…b keys

`KeyArray` -> `PRNGKeyND` to fit the `*ND` alias family. The stochastic
next-state wrappers annotated their `key_<qname>` parameter as
`dict[str, PRNGKeyND]`, but dags hands the function a single (batched)
PRNG key — `generate_simulation_keys` is the thing that returns the dict,
keyed by `key_<qname>`. Correct those three sites to `PRNGKeyND`. Key
`regime_transition_probs` mappings by `RegimeName` rather than bare `str`.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
…ping Key

`RegimeTransitionFunction` / `VmappedRegimeTransitionFunction` are only
ever the processed, dict-returning wrappers — never the user's raw
`next_regime`. Annotate their `__call__` as
`MappingProxyType[RegimeName, FloatND]`, which drops the
`# ty: ignore[invalid-assignment]` at every call site. Type `PRNGKeyND`
as `Key[Array, "..."]` — jaxtyping's `Key` matches the `key<fry>` dtype,
so it no longer falls back to bare `Array`.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
…ds/shocks

The `beartype_package` claws on `lcm.grids` and `lcm.shocks` apply
`@beartype` to every class — every method, not just `__init__` — with
`GRID_CONF`. `@beartype_init` only ever wrapped `__init__` with the same
conf, so all 13 decorators were a strict subset of what the claw does.
Its docstring rationale ("don't police every method") is moot: the claw
polices every method and the suite is green. Remove the decorators and
the `beartype_init` helper.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
@hmgaudecker hmgaudecker merged commit 273ef76 into feat/beartype-perimeter May 14, 2026
10 checks passed
@hmgaudecker hmgaudecker deleted the feat/beartype-claw-cleanup branch May 14, 2026 10:58
@hmgaudecker

Collapsed into #355 — the #355/#356/#357 stack is now a single PR. feat/beartype-perimeter was fast-forwarded to this branch's tip, so GitHub flags this as merged; all commits are preserved intact. Review continues on #355.
