
Extend runtime type checking to the regime-building pipeline #357

Merged: hmgaudecker merged 11 commits into feat/beartype-perimeter from feat/beartype-claw-extend on May 14, 2026

@hmgaudecker (Member) commented May 13, 2026

Stacked on #356.

Coordinated change

Part of a three-repo change adopting runtime type checking in dags-based projects:

Merge order: ttsim#99 (anytime) · dags#82 → release 0.6 · pylcm#357 (after 0.6).

Why runtime type checking

A dags-based project has type information that only fully exists at runtime. dags makes function composition data: concatenate_functions assembles the composed callable from regime/policy environments at execution time, so its real signature is (*args, **kwargs). That's the feature, not a quirk. The corollary: no static checker can resolve a composition determined by data. It sees (*args, **kwargs) and can say nothing about the wiring. Static and runtime checking are complementary, each native to its own domain: static analysis owns what's static (every leaf function body and all non-dags code; ty runs in CI for exactly that), runtime checking owns what's runtime (the composition).

beartype checks the leaf node functions, not the composed callable — the dags wrapper is deliberately transparent to it (see dags#82). When the DAG executes and feeds node A's output into node B's parameter, beartype — having decorated node B — checks B's inputs at that moment. That is runtime verification of the wiring, the thing a static checker structurally cannot see.
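The wiring check described above can be illustrated with a minimal, stdlib-only sketch. Everything in it is an assumption made for illustration: `checked` is a toy stand-in for beartype (it only handles plain-class annotations), `compose` is a toy stand-in for dags' composition, and the node functions are hypothetical. It is not how either library is implemented.

```python
import inspect
from functools import wraps
from typing import get_type_hints


def checked(func):
    """Toy stand-in for beartype: isinstance-check plain-class annotations."""
    hints = get_type_hints(func)
    sig = inspect.signature(func)

    @wraps(func)
    def wrapper(*args, **kwargs):
        for name, value in sig.bind(*args, **kwargs).arguments.items():
            expected = hints.get(name)
            if expected is not None and not isinstance(value, expected):
                raise TypeError(
                    f"{func.__name__}: parameter {name!r} expected "
                    f"{expected.__name__}, got {type(value).__name__}"
                )
        return func(*args, **kwargs)

    return wrapper


def compose(nodes, target):
    """Toy stand-in for dags composition: wire nodes by parameter name."""

    def composed(**inputs):
        cache = dict(inputs)

        def evaluate(name):
            if name not in cache:
                params = inspect.signature(nodes[name]).parameters
                cache[name] = nodes[name](**{p: evaluate(p) for p in params})
            return cache[name]

        return evaluate(target)

    return composed


@checked
def consumption(wealth: float) -> float:
    return 0.25 * wealth


def consumption_as_text(wealth: float) -> str:
    # Mis-wired producer: returns text where the consumer expects a float.
    return f"{0.25 * wealth:.2f}"


@checked
def utility(consumption: float) -> float:
    return consumption**0.5


good = compose({"consumption": consumption, "utility": utility}, "utility")
bad = compose({"consumption": consumption_as_text, "utility": utility}, "utility")
```

Both `good` and `bad` expose only `(**inputs)` as their signature, so nothing about the wiring is visible from outside. Calling `bad(wealth=4.0)` raises a `TypeError` at the moment `utility` receives the string, which is the leaf-level check doing the wiring verification.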

What this PR does

Extends pylcm's scoped beartype claw to lcm.regime_building, mapping type violations there to ModelInitializationError (regime compilation is part of model construction). Together with #356, the claw now covers the construction-time subpackages: lcm.grids, lcm.shocks, lcm.params, lcm.regime_building.
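As a rough illustration of mapping type violations to a domain error, here is a stdlib-only sketch. It does not use beartype or pylcm: `checked` is a toy checker, `ModelInitializationError` is a local stand-in for pylcm's exception of the same name, and `build_regime` is a hypothetical construction-time helper.

```python
import inspect
from functools import wraps
from typing import get_type_hints


class ModelInitializationError(Exception):
    """Local stand-in for pylcm's exception of the same name."""


def checked(violation_type=TypeError):
    """Toy checker factory: raise `violation_type` on a badly typed argument."""

    def decorate(func):
        hints = get_type_hints(func)
        sig = inspect.signature(func)

        @wraps(func)
        def wrapper(*args, **kwargs):
            for name, value in sig.bind(*args, **kwargs).arguments.items():
                expected = hints.get(name)
                if expected is not None and not isinstance(value, expected):
                    raise violation_type(
                        f"{func.__name__}: {name!r} must be "
                        f"{expected.__name__}, got {type(value).__name__}"
                    )
            return func(*args, **kwargs)

        return wrapper

    return decorate


@checked(violation_type=ModelInitializationError)
def build_regime(name: str, n_periods: int) -> dict:
    """Hypothetical construction-time helper."""
    return {"name": name, "n_periods": n_periods}
```

A call such as `build_regime("working", "3")` then surfaces as `ModelInitializationError` rather than a generic type error, which mirrors the intent of treating regime compilation as part of model construction.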

Honest annotations — the drift the claw surfaced

With the claw live, every node's annotations get enforced against real runtime values. Where an annotation was lying, it gets fixed — each of these makes the annotation truthful about what the function legitimately receives:

  • collect_state_transitions(states: ...) → Mapping[StateName, Grid | None]: placeholder states are genuinely None.
  • map_coordinates(coordinates: ...) → Sequence[Array] | Array: callers pass a single 2-D Array, and JAX produces a single tracer under vmap.
  • _get_weights_func_for_shock.weights_func_runtime's shock_kw → dict[str, float | FloatND]: Python floats from fixed_params, JAX tracers from runtime params under JIT (drops a now-unnecessary # ty: ignore).
  • _wrap_with_reduction.reduced's **kwargs → Array | Mapping[str, Array]: next_regime_to_V_arr flows through as a mapping alongside the Array-valued inputs.
  • solve_brute.solve's running_any_nan / running_any_inf → BoolND (were mis-typed FloatND): caught even with the claw off lcm.solution, via solve()'s perimeter @beartype decoration.
  • One test fixture (tests/test_next_state.py::test_create_stochastic_next_func) now passes an int32-typed labels array.

The complementary case — where an annotation was too wide, with Python float leaking into JAX-internal helpers — is handled in #356, which tightens those helpers to canonical JAX types and casts at named Python→JAX boundaries (_params_to_jax). Same principle, opposite direction: make the surface tell the truth about what actually flows through it.

Status — pending dags 0.6

This branch is functionally complete and green. lcm.regime_building.* imports the dags wrappers (with_signature, rename_arguments) directly — dags#82 made the forwarder shape the only behaviour, so no pylcm-side shim is needed. The one remaining transitional bit: pyproject.toml pins dags to a pre-release rev of feat/no-type-check-flag (dags#82). Once dags 0.6 ships, that pin moves to dags>=0.6 — a one-line change, no code or annotation work.

Out of scope — left for follow-ups

Extending the claw to lcm.solution / lcm.simulation surfaces ~100 further annotation-drift failures in the internal solve/simulate paths (test mocks passing int64 where DiscreteState is expected; wrappers annotated dict[str, float] whose values are JAX tracers under JIT; vmap'd random_id returning Int1D-annotated values that are actually 0-D tracers; etc.). Each is a legitimate catch — the claw doing its job — but they need a dedicated cleanup pass with the same honest-surface discipline.

Test plan

  • pixi run --environment tests-cpu tests -n 4 — 979 passed, 10 skipped
  • pixi run -e type-checking ty — clean
  • prek run --all-files — clean
  • Verified the claw is genuinely live on lcm.regime_building: a deliberately bad-typed call into a regime_building helper raises ModelInitializationError from the claw-installed wrapper.

Adds `lcm.regime_building` to the import-claw registrations in
`lcm/__init__.py`, mapping type violations to `ModelInitializationError`
(regime compilation is part of model construction).

Wraps `dags.with_signature` / `dags.rename_arguments` via a thin
helper in `lcm.utils._dags_forwarders` that defaults
`forwarder=True`. Each direct caller of those dags helpers inside
pylcm produces a generic `*args, **kwargs` forwarder whose annotations
describe the inner function's contract, not the wrapper's own call
protocol. `forwarder=True` advertises the wrapper as a permissive
forwarder on its `__annotations__`, so beartype's claw treats it as
universally permissive and skips per-parameter enforcement — matching
the wrapper's actual runtime behaviour. dags' own `get_annotations`
recovers the user-described view via its existing args/kwargs-mismatch
fallback.

Pins `dags` to the `feat/no-type-check-flag` branch (PR
OpenSourceEconomics/dags#82) which adds the `forwarder` flag. Will be
replaced with a released version once that PR lands.

Annotation drift fixed alongside activation:

- `collect_state_transitions(states: ...)` widened to
  `Mapping[StateName, Grid | None]`; test mocks pass `None` for
  placeholder states.
- `map_coordinates(coordinates: ...)` widened to
  `Sequence[Array] | Array`; callers pass a 2D `jnp.array` (a single
  Array, not a sequence) and JAX produces a single tracer under vmap.
- `_get_weights_func_for_shock.weights_func_runtime.shock_kw` typed as
  `dict[str, float | FloatND]`; under JIT the runtime shock params
  arrive as tracers.
- `solve_brute.solve.running_any_nan` / `running_any_inf` typed as
  `BoolND` to match the underlying `jnp.zeros((), dtype=bool)`.
- `diagnostics._wrap_with_reduction.reduced.**kwargs` typed as
  `Array | Mapping[str, Array]` since `next_regime_to_V_arr` flows
  through it as a mapping alongside Array-valued state/action inputs.
- One test fixture (`tests/test_next_state.py::test_create_stochastic_next_func`)
  updated to pass an `int32`-typed `labels` array.

Out of scope: extending the claw to `lcm.solution` / `lcm.simulation`,
which surfaces further annotation drift the claw correctly catches but
that needs its own pass.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>

@hmgaudecker changed the title from "Activate beartype claw on lcm.regime_building" to "Extend runtime type checking to the regime-building pipeline" on May 14, 2026
hmgaudecker and others added 9 commits May 14, 2026 09:24
dags#82 made the `*args, **kwargs` forwarder shape the only behaviour
for `with_signature` / `rename_arguments` — there is no `forwarder`
flag left to default, so the shim has nothing to do. regime_building
imports the dags wrappers directly again; the dags pin moves to the
current branch head.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
The subprocess imports lcm, whose beartype claw can emit diagnostics to
stdout, so `int(result.stdout.strip())` blew up on a polluted stream.
Mark the peak-bytes line and locate it instead of parsing stdout whole.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
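The marker-line approach from the commit above could look roughly like the following sketch. The marker string `PEAK_BYTES=`, the child script, and `parse_marked_int` are all assumptions for illustration; the commit message does not say what marker or helper the actual test uses.

```python
import subprocess
import sys

MARKER = "PEAK_BYTES="  # hypothetical marker, not taken from the PR

# Child process that mixes unrelated diagnostics with the value of interest.
CHILD_CODE = f"""
print("diagnostic noise emitted during import")
print("{MARKER}123456")
print("more noise")
"""


def parse_marked_int(stdout: str, marker: str = MARKER) -> int:
    """Return the integer on the last line that starts with `marker`."""
    for line in reversed(stdout.splitlines()):
        if line.startswith(marker):
            return int(line[len(marker):].strip())
    raise ValueError(f"no line starting with {marker!r} in subprocess output")


result = subprocess.run(
    [sys.executable, "-c", CHILD_CODE],
    capture_output=True,
    text=True,
    check=True,
)
peak_bytes = parse_marked_int(result.stdout)
```

Parsing `int(result.stdout.strip())` directly would fail here because stdout holds three lines; locating the marked line keeps the test robust to whatever else the import prints.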
# Conflicts:
#	src/lcm/regime_building/diagnostics.py
#	src/lcm/regime_building/ndimage.py
#	src/lcm/regime_building/processing.py
…t tests

Activating the claw on `lcm.regime_building` makes jaxtyping shape checks
run on cloudpickled annotation types. jaxtyping marks a `"..."` axis with
a plain `object()` sentinel that loses identity across a pickle
round-trip, tripping `assert type(variadic_dim) is _NamedVariadicDim`.
Replace it with a `__reduce__`-backed singleton, patched in before any
jaxtyping-subscripted type is created. Fix `test_ndimage_unit` cases that
passed raw int64 arrays the int32-pinned primitives now reject.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
@hmgaudecker linked an issue on May 14, 2026 that may be closed by this pull request (2 tasks)
Base automatically changed from feat/beartype-claw-cleanup to feat/beartype-perimeter on May 14, 2026 10:58
@hmgaudecker merged commit 273ef76 into feat/beartype-perimeter on May 14, 2026 (10 checks passed)
@hmgaudecker deleted the feat/beartype-claw-extend branch on May 14, 2026 10:58
@hmgaudecker (Member, Author) commented:

Collapsed into #355 — the #355/#356/#357 stack is now a single PR. feat/beartype-perimeter was fast-forwarded to this branch's tip, so GitHub flags this as merged; all commits are preserved intact. Review continues on #355.

@github-actions commented:

Benchmark comparison (main → HEAD)

Comparing 49b1408d (main) → 273ef76e (HEAD)

| Benchmark | Statistic | before | after | Ratio | Alert |
|---|---|---|---|---|---|
| aca-baseline | execution time | 28.009 s | 25.138 s | 0.90 | |
| | peak GPU mem | 579 MB | 823 MB | 1.42 | |
| | compilation time | 300.18 s | 296.31 s | 0.99 | |
| | peak CPU mem | 7.57 GB | 7.38 GB | 0.97 | |
| Mahler-Yum | execution time | 4.678 s | 4.413 s | 0.94 | |
| | peak GPU mem | 529 MB | 529 MB | 1.00 | |
| | compilation time | 14.08 s | 14.00 s | 0.99 | |
| | peak CPU mem | 1.68 GB | 1.69 GB | 1.01 | |
| Precautionary Savings - Solve | execution time | 52.0 ms | 42.2 ms | 0.81 | |
| | peak GPU mem | 101 MB | 101 MB | 1.00 | |
| | compilation time | 2.64 s | 2.75 s | 1.04 | |
| | peak CPU mem | 1.13 GB | 1.14 GB | 1.00 | |
| Precautionary Savings - Simulate | execution time | 122.6 ms | 114.5 ms | 0.93 | |
| | peak GPU mem | 344 MB | 344 MB | 1.00 | |
| | compilation time | 4.68 s | 4.95 s | 1.06 | |
| | peak CPU mem | 1.30 GB | 1.30 GB | 1.00 | |
| Precautionary Savings - Solve & Simulate | execution time | 148.3 ms | 131.4 ms | 0.89 | |
| | peak GPU mem | 578 MB | 578 MB | 1.00 | |
| | compilation time | 7.41 s | 7.35 s | 0.99 | |
| | peak CPU mem | 1.29 GB | 1.29 GB | 1.00 | |
| Precautionary Savings - Solve & Simulate (irreg) | execution time | 283.2 ms | 271.4 ms | 0.96 | |
| | peak GPU mem | 2.19 GB | 2.19 GB | 1.00 | |
| | compilation time | 7.15 s | 7.38 s | 1.03 | |
| | peak CPU mem | 1.34 GB | 1.35 GB | 1.01 | |

Development

Successfully merging this pull request may close these issues.

ENH: Improve type annotations and checking
