Skip to content

Commit 0052a42

Browse files
committed
Address review: docstrings, error-message fix, boundary tests
- `_transition_checks.py` module docstring no longer references a non-existent `regime_building/static_checks.py`; it describes the actual distinction (runtime numerical checks vs. construction-time regime-spec validators). - `regime_building/transitions.py`: the `collect_state_transitions` docstring said ShockGrid states get a `lambda: None` stub; the code skips them entirely. Docstring corrected. - `solution/validate_V.py`: `Dict` → `dict` in a Returns section to match the `dict[str, Any]` annotation. - `_transition_checks.py` `_validate_no_reachable_incomplete_targets`: drop the override that, when a target regime was absent from the source's `state_transitions`, listed every state of the target as "missing" — including non-stochastic states that need no explicit entry. The preceding line already computes the correct stochastic- only set for that case. - `user_regime.py`: add a scope-boundary note to `_validate_function_output_grid_indexing` — the AST check is deliberately best-effort and should be deleted rather than hardened if it ever produces false positives. - Replace the two contentless happy-path validator tests with boundary inputs: values at the inclusive [0, 1] bounds and row sums just inside the sum-to-1 tolerance, so "does not raise" pins the tolerance/bound logic instead of being a bare smoke check.
1 parent 2b7f838 commit 0052a42

6 files changed

Lines changed: 36 additions & 13 deletions

File tree

src/lcm/_transition_checks.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,10 @@
77
regimes, and no positive probability to a target with incomplete stochastic
88
transitions.
99
10-
Distinct from `regime_building/static_checks.py`, which fires at canonical-
11-
regime construction time on Python source (AST). The checks here need a
12-
fully-built `Regime` plus user `flat_params` to run.
10+
These are runtime checks: they need a fully-built `Regime` plus user
11+
`flat_params` and evaluate the transition functions numerically. The
12+
construction-time regime-spec validators (`Regime.__post_init__`, which
13+
inspect grids, signatures, and Python source) are a separate concern.
1314
1415
"""
1516

@@ -332,8 +333,6 @@ def _validate_no_reachable_incomplete_targets(
332333
if not jnp.any(regime_transition_probs[target_regime_name] > 0):
333334
continue
334335
missing = sorted(needs - set(transitions.get(target_regime_name, {})))
335-
if target_regime_name not in transitions:
336-
missing = sorted(f"next_{s}" for s in target_regime.variables.state_names)
337336
raise InvalidRegimeTransitionProbabilitiesError(
338337
f"Regime '{regime_name}' at age {age} has positive transition "
339338
f"probability to '{target_regime_name}', but '{regime_name}' "

src/lcm/regime_building/transitions.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,8 @@ def collect_state_transitions(
2929
"""Collect state transition functions from `state_transitions`.
3030
3131
For each state, produces entries keyed as `f"next_{name}"`:
32-
- ShockGrid -> stub `lambda: None`
32+
- ShockGrid -> skipped (shock transitions are built directly in
33+
`_process_regime_core`)
3334
- `None` -> auto-generated identity transition
3435
- Callable -> used directly
3536
- `MarkovTransition` -> used directly (callable via `__call__`)

src/lcm/solution/validate_V.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -202,7 +202,7 @@ def _summarize_diagnostics(
202202
age: Age for the summary header.
203203
204204
Returns:
205-
Dict with per-metric `"overall"` and `"by_dim"` entries plus a
205+
dict with per-metric `"overall"` and `"by_dim"` entries plus a
206206
`"regime_probs"` mapping, suitable for `_format_diagnostic_summary`.
207207
208208
"""

src/lcm/user_regime.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -399,6 +399,12 @@ def _validate_function_output_grid_indexing(regime: Regime) -> list[str]:
399399
consumer is indexing a 0-d array by a scalar integer, which raises
400400
`IndexError` at trace time. The fix is to drop the redundant `[g]`
401401
in the consumer (or refactor `f` not to take `g`).
402+
403+
This check is deliberately best-effort: it catches the common
404+
`func_output[discrete_grid]` subscript form and nothing else. It is not
405+
meant to grow into a general correctness checker for user functions —
406+
if it ever produces false positives, prefer deleting it over hardening
407+
it to chase every way the pattern can hide.
402408
"""
403409
function_output_names = set(regime.functions)
404410
discrete_grid_names = (

tests/test_regime.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -439,9 +439,21 @@ def _make_partner_probs_array():
439439
)
440440

441441

442-
def test_validate_transition_probs_valid():
442+
def test_validate_transition_probs_accepts_boundary_inputs():
443+
"""Inclusive [0, 1] bounds and row sums within the 1e-6 tolerance pass.
444+
445+
The first row is exactly `[0.0, 1.0]` — values at the inclusive bounds.
446+
The last row sums to `1 - 5e-7`, just inside the `atol=1e-6` row-sum
447+
tolerance. The validator must accept both without raising.
448+
"""
443449
model = get_stochastic_model(3)
444-
arr = _make_partner_probs_array()
450+
arr = jnp.array(
451+
[
452+
[[[0.0, 1.0], [1.0, 0.0]], [[0.3, 0.7], [0.6, 0.4]]],
453+
[[[0.4, 0.6], [0.8, 0.2]], [[0.2, 0.8], [0.7, 0.3]]],
454+
[[[0.5, 0.5], [0.9, 0.1]], [[0.3, 0.7], [0.5, 0.4999995]]],
455+
]
456+
)
445457
validate_transition_probs(
446458
probs=arr, model=model, regime_name="working_life", state_name="partner"
447459
)

tests/test_validate_regime_transition_probs.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,12 +22,17 @@
2222
from lcm_examples.mortality import get_model, get_params
2323

2424

25-
def test_valid_probs_all_active():
26-
"""Valid probabilities with all regimes active pass validation."""
25+
def test_valid_probs_accept_boundary_inputs():
26+
"""Inclusive [0, 1] bounds and a sum within tolerance pass validation.
27+
28+
Subject 0 splits probability exactly `[1.0, 0.0]` — values at the
29+
inclusive bounds. Subject 1 sums to `1 - 2.5e-6`, just inside the
30+
`jnp.allclose` default tolerance. The validator must accept both.
31+
"""
2732
probs = MappingProxyType(
2833
{
29-
"working_life": jnp.array([0.7, 0.6]),
30-
"retirement": jnp.array([0.3, 0.4]),
34+
"working_life": jnp.array([1.0, 0.4999975]),
35+
"retirement": jnp.array([0.0, 0.5]),
3136
}
3237
)
3338
_validate_regime_transition_probs(

0 commit comments

Comments
 (0)