Extend runtime type checking to the regime-building pipeline#357
Merged
Conversation
Adds `lcm.regime_building` to the import-claw registrations in `lcm/__init__.py`, mapping type violations to `ModelInitializationError` (regime compilation is part of model construction). Wraps `dags.with_signature` / `dags.rename_arguments` via a thin helper in `lcm.utils._dags_forwarders` that defaults `forwarder=True`. Each direct caller of those dags helpers inside pylcm produces a generic `*args, **kwargs` forwarder whose annotations describe the inner function's contract, not the wrapper's own call protocol. `forwarder=True` advertises the wrapper as a permissive forwarder on its `__annotations__`, so beartype's claw treats it as universally permissive and skips per-parameter enforcement — matching the wrapper's actual runtime behaviour. dags' own `get_annotations` recovers the user-described view via its existing args/kwargs-mismatch fallback. Pins `dags` to the `feat/no-type-check-flag` branch (PR OpenSourceEconomics/dags#82) which adds the `forwarder` flag. Will be replaced with a released version once that PR lands. Annotation drift fixed alongside activation: - `collect_state_transitions(states: ...)` widened to `Mapping[StateName, Grid | None]`; test mocks pass `None` for placeholder states. - `map_coordinates(coordinates: ...)` widened to `Sequence[Array] | Array`; callers pass a 2D `jnp.array` (a single Array, not a sequence) and JAX produces a single tracer under vmap. - `_get_weights_func_for_shock.weights_func_runtime.shock_kw` typed as `dict[str, float | FloatND]`; under JIT the runtime shock params arrive as tracers. - `solve_brute.solve.running_any_nan` / `running_any_inf` typed as `BoolND` to match the underlying `jnp.zeros((), dtype=bool)`. - `diagnostics._wrap_with_reduction.reduced.**kwargs` typed as `Array | Mapping[str, Array]` since `next_regime_to_V_arr` flows through it as a mapping alongside Array-valued state/action inputs. - One test fixture (`tests/test_next_state.py::test_create_stochastic_next_func`) updated to pass an `int32`-typed `labels` array. Out of scope: extending the claw to `lcm.solution` / `lcm.simulation`, which surfaces further annotation drift the claw correctly catches but that needs its own pass. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
… feat/beartype-claw-extend
This was referenced May 14, 2026
dags#82 made the `*args, **kwargs` forwarder shape the only behaviour for `with_signature` / `rename_arguments` — there is no `forwarder` flag left to default, so the shim has nothing to do. regime_building imports the dags wrappers directly again; the dags pin moves to the current branch head. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
The subprocess imports lcm, whose beartype claw can emit diagnostics to stdout, so `int(result.stdout.strip())` blew up on a polluted stream. Mark the peak-bytes line and locate it instead of parsing stdout whole. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
# Conflicts: # src/lcm/regime_building/diagnostics.py # src/lcm/regime_building/ndimage.py # src/lcm/regime_building/processing.py
…t tests Activating the claw on `lcm.regime_building` makes jaxtyping shape checks run on cloudpickled annotation types. jaxtyping marks a `"..."` axis with a plain `object()` sentinel that loses identity across a pickle round-trip, tripping `assert type(variadic_dim) is _NamedVariadicDim`. Replace it with a `__reduce__`-backed singleton, patched in before any jaxtyping-subscripted type is created. Fix `test_ndimage_unit` cases that passed raw int64 arrays the int32-pinned primitives now reject. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
…feat/beartype-claw-extend
2 tasks
Base automatically changed from
feat/beartype-claw-cleanup
to
feat/beartype-perimeter
May 14, 2026 10:58
Member
Author
5 tasks
Benchmark comparison (main → HEAD)Comparing
|
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Stacked on #356.
Coordinated change
Part of a three-repo change adopting runtime type checking in dags-based projects:
Merge order: ttsim#99 (anytime) · dags#82 → release 0.6 · pylcm#357 (after 0.6).
Why runtime type checking
A dags-based project has type information that only fully exists at runtime. dags makes function composition data —
concatenate_functionsassembles the composed callable from regime/policy environments at execution time, so its real signature is(*args, **kwargs). That's the feature, not a quirk. The corollary: no static checker can resolve a composition determined by data — it sees(*args, **kwargs)and can say nothing about the wiring. Static and runtime checking are complementary, each native to its own domain: static analysis owns what's static (every leaf function body, all non-dags code —tyruns in CI for exactly that), runtime checking owns what's runtime (the composition).beartype checks the leaf node functions, not the composed callable — the dags wrapper is deliberately transparent to it (see dags#82). When the DAG executes and feeds node A's output into node B's parameter, beartype — having decorated node B — checks B's inputs at that moment. That is runtime verification of the wiring, the thing a static checker structurally cannot see.
What this PR does
Extends pylcm's scoped beartype claw to
lcm.regime_building, mapping type violations there toModelInitializationError(regime compilation is part of model construction). Together with #356, the claw now covers the construction-time subpackages:lcm.grids,lcm.shocks,lcm.params,lcm.regime_building.Honest annotations — the drift the claw surfaced
With the claw live, every node's annotations get enforced against real runtime values. Where an annotation was lying, it gets fixed — each of these makes the annotation truthful about what the function legitimately receives:
collect_state_transitions(states: ...)→Mapping[StateName, Grid | None]— placeholder states are genuinelyNone.map_coordinates(coordinates: ...)→Sequence[Array] | Array— callers pass a single 2-DArray, and JAX produces a single tracer under vmap._get_weights_func_for_shock.weights_func_runtime'sshock_kw→dict[str, float | FloatND]— Python floats fromfixed_params, JAX tracers from runtime params under JIT (drops a now-unnecessary# ty: ignore)._wrap_with_reduction.reduced's**kwargs→Array | Mapping[str, Array]—next_regime_to_V_arrflows through as a mapping alongside the Array-valued inputs.solve_brute.solve'srunning_any_nan/running_any_inf→BoolND(were mis-typedFloatND) — caught even with the claw offlcm.solution, viasolve()'s perimeter@beartypedecoration.tests/test_next_state.py::test_create_stochastic_next_func) passes anint32-typedlabelsarray.The complementary case — where an annotation was too wide, with Python
floatleaking into JAX-internal helpers — is handled in #356, which tightens those helpers to canonical JAX types and casts at named Python→JAX boundaries (_params_to_jax). Same principle, opposite direction: make the surface tell the truth about what actually flows through it.Status — pending dags 0.6
This branch is functionally complete and green.
lcm.regime_building.*imports the dags wrappers (with_signature,rename_arguments) directly — dags#82 made the forwarder shape the only behaviour, so no pylcm-side shim is needed. The one remaining transitional bit:pyproject.tomlpinsdagsto a pre-release rev offeat/no-type-check-flag(dags#82). Once dags 0.6 ships, that pin moves todags>=0.6— a one-line change, no code or annotation work.Out of scope — left for follow-ups
Extending the claw to
lcm.solution/lcm.simulationsurfaces ~100 further annotation-drift failures in the internal solve/simulate paths (test mocks passingint64whereDiscreteStateis expected; wrappers annotateddict[str, float]whose values are JAX tracers under JIT; vmap'drandom_idreturningInt1D-annotated values that are actually 0-D tracers; etc.). Each is a legitimate catch — the claw doing its job — but they need a dedicated cleanup pass with the same honest-surface discipline.Test plan
pixi run --environment tests-cpu tests -n 4— 979 passed, 10 skippedpixi run -e type-checking ty— cleanprek run --all-files— cleanlcm.regime_building: a deliberately bad-typed call into a regime_building helper raisesModelInitializationErrorfrom the claw-installed wrapper.