Accept pd.Series in params and pd.DataFrame in initial conditions #289
hmgaudecker merged 146 commits into main
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
…_states_for_subjects Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Move function-level imports to top-level, add noqa for print statements, fix line length in docstring, use keyword arg for jax.config.update. Exclude tests/data/ from name-tests-test hook. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Move jax.config.update("jax_enable_x64") back to module level (before
lcm_examples imports) to ensure 64-bit precision during model construction.
Regenerate all benchmark pickle files with correct precision.
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
…al_states Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
…e_state_space_info Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Replace the fragile `expected_levels` tuple and inferred-mode fallbacks with a `param_path` parameter that resolves the parameter's position in the model template. The function now inspects the owning function's signature to determine indexing dimensions (states, actions, period) automatically.

Key changes:
- `array_from_series`: new API with `sr`, `model` (required), `param_path` (required 1-3 tuple). No more `data`, `ages`, `expected_levels`, or inferred mode. Always strict validation.
- `_LevelMapping` dataclass + `_scatter_series`: replace `_OutcomeMapping`, `_compute_shape`, `_map_labels_to_codes` with a clean abstraction. Both `array_from_series` and `_build_probs_array` (transition probs) use it.
- `_resolve_param_indexing`: resolves 1/2/3-part param paths by scanning model regimes, verifying consistent indexing across matching functions.
- NaN-fill in `_build_probs_array`: `np.full(shape, np.nan)` instead of `np.zeros`, making data gaps visible.
- `_resolve_categoricals`: now raises on conflicts instead of silently overriding model grids.

Removed: `_OutcomeMapping`, `_multiindex_series_to_array`, `_age_series_to_array`, `_categorical_series_to_array`, `_build_level_mappings_from_grids`, `_compute_shape`, `_map_labels_to_codes`.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
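The effect of filling with NaN rather than zeros can be illustrated with a minimal sketch. It is not the code from this PR: `scatter_series` and the `grids` dictionary are hypothetical stand-ins for `_scatter_series` and the model's grids, and plain NumPy is used in place of JAX.

```python
import numpy as np
import pandas as pd


def scatter_series(sr, level_labels):
    """Scatter a labeled Series into a dense array, leaving gaps as NaN.

    `level_labels` (hypothetical) maps each index level name to its
    ordered list of labels.
    """
    names = list(sr.index.names)
    shape = tuple(len(level_labels[name]) for name in names)
    out = np.full(shape, np.nan)  # NaN, not zeros: missing entries stay visible
    # Translate every label into its integer position along its axis.
    codes = tuple(
        np.array(
            [level_labels[name].index(label)
             for label in sr.index.get_level_values(name)]
        )
        for name in names
    )
    out[codes] = sr.to_numpy()
    return out


grids = {"age": [50, 60], "health": ["bad", "good"]}
index = pd.MultiIndex.from_tuples(
    [(50, "bad"), (50, "good"), (60, "good")], names=["age", "health"]
)
sr = pd.Series([0.2, 0.8, 1.0], index=index)  # (60, "bad") is missing
arr = scatter_series(sr, grids)
```

With `np.zeros`, the missing `(60, "bad")` entry would look like a legitimate probability of zero; with NaN it propagates and is easy to detect.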
…idspecs Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
…idspecs Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
…ctions Introduce two explicit, phase-specific containers that replace the fragile override pattern. Each consumer now knows which phase it operates in. Eliminates simulate_overrides, with_simulate_overrides(), and internal use of PhaseVariant for regime_transition_probs. PhaseVariant is retained in the user-facing API (Regime.functions) for PR 6. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
- _get_func_indexing_params: drop regime parameter, use AST inspection to determine indexing params from subscript patterns in the function source. Raises TypeError for lambdas, ValueError for computed indices (with recipe to extract into a DAG function).
- _validate_probs_array_indexing → _validate_array_param_indexing: accept array_param_name and indexing_params instead of hardcoding "probs_array".
- _collect_probs_array_subscripts → _collect_subscripts: parameterized.
- Remove _get_indexing_params (replaced by AST-based _get_func_indexing_params).
- Remove regime param from _build_probs_array (no longer needed).
- Make all private functions keyword-only throughout pandas_utils.py and error_handling.py.
- Rename test file: test_validate_probs_indexing → test_validate_array_indexing.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
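A rough sketch of the AST-based idea, using only the standard library, might look as follows. It is an illustration under assumptions, not the PR's `_get_func_indexing_params`: the helper name `indexing_params_from_source` is hypothetical, and it takes source text directly (a real implementation would presumably obtain it from the function object, e.g. via `inspect.getsource`).

```python
import ast


def indexing_params_from_source(source, array_param_name):
    """Return the names used to subscript `array_param_name` in `source`."""
    found = []
    for node in ast.walk(ast.parse(source)):
        is_target = (
            isinstance(node, ast.Subscript)
            and isinstance(node.value, ast.Name)
            and node.value.id == array_param_name
        )
        if not is_target:
            continue
        index = node.slice
        elements = index.elts if isinstance(index, ast.Tuple) else [index]
        for element in elements:
            # Only bare names are accepted; anything computed is rejected.
            if not isinstance(element, ast.Name):
                raise ValueError(
                    f"computed index {ast.unparse(element)!r}; "
                    "move the computation into its own function"
                )
            found.append(element.id)
    return tuple(dict.fromkeys(found))  # de-duplicate, keep first-seen order


SOURCE = """
def next_health_probs(age, health, probs_array):
    return probs_array[age, health]
"""
params = indexing_params_from_source(SOURCE, "probs_array")
```

Requiring the array parameter name explicitly means subscripts on other arrays in the same function are ignored, which is how false positives are avoided in this sketch.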
…ot comparable + precision diffs.
…nctions Enforce consistent ordering: flat_param_names, age, period, functions, constraints, transitions, stochastic_transition_names, then remaining args. Also fix notebook cell 9 source format (string → list of strings). Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
timmens left a comment
Very nice! I prefer this UI 😄
Only one comment:
The pandas_interop docs explain the format and the "why labels" well, but I think there's a gap for users trying to build a mental model of what's happening. The function code uses JAX integer indexing (probs_array[age, health]), but the user provides string-labeled Series, and nothing explains how those two worlds connect.
A short paragraph somewhere before or after the "Series format" section could help. Something along the lines of: your model functions work with plain JAX arrays and integer indexing as usual; the Series is purely an input convenience. Before any model code runs, the conversion inspects your function to figure out which dimensions the array is indexed over, maps labels to integer positions using the model's grids, and hands the function a normal JAX array. Your function never sees pandas.
Not blocking, but I think it'd save users some head-scratching.
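The bridge described in this comment could be pictured with a one-dimensional toy case. Everything here is an assumption made for illustration: `HEALTH_GRID` stands in for a model grid, `survival_prob` for a user function, and a NumPy array is used where the library would hand over a JAX array.

```python
import pandas as pd

HEALTH_GRID = ["bad", "good"]  # hypothetical ordered categories from a grid


def survival_prob(health, probs_array):
    # Model code: plain integer indexing, no pandas involved.
    return probs_array[health]


# User input: a Series labeled with category names, in arbitrary order.
user_series = pd.Series(
    [0.95, 0.80], index=pd.Index(["good", "bad"], name="health")
)

# Conversion step: reorder by the grid so position i holds label i's value.
probs_array = user_series.reindex(HEALTH_GRID).to_numpy()

good_code = HEALTH_GRID.index("good")  # integer position of "good" in the grid
result = survival_prob(good_code, probs_array)
```

The function only ever receives the integer code and the dense array; the labels exist solely on the input side.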
…us-grid-initial-conditions
…r/remove-label-translator
808c7aa to 5e18a9c
8dd49b6 to ca5601c
Addresses review feedback on PR #289: users needed a bridge between the labeled Series they provide and the integer-indexed JAX arrays their functions receive. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
…d-initial-conditions
…r/remove-label-translator
…dataframe

- Add solve+simulate tests for models with different discrete categories per regime (e.g. health with/without "disabled")
- Fix to_dataframe: remap per-regime codes to labels using each regime's own category ordering before building the merged Categorical
- Fix initial_conditions_from_dataframe: cast discrete states to int32 (was float32, rejected by JAX as indexer)
- Fix validate_initial_conditions: union valid codes across regimes instead of overwriting

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
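Why the per-regime remapping matters can be shown with a small sketch. The regime names, category orderings, and the helper `codes_to_labels` are hypothetical and only illustrate the idea; this is not the `to_dataframe` implementation.

```python
import pandas as pd

# Hypothetical per-regime orderings: code 1 means "good" in one regime
# and "bad" in the other.
REGIME_CATEGORIES = {
    "working": ["bad", "good"],
    "retired": ["disabled", "bad", "good"],
}


def codes_to_labels(regime, codes):
    """Translate integer codes to labels using the regime's own ordering."""
    categories = REGIME_CATEGORIES[regime]
    return [categories[code] for code in codes]


# Union of labels across regimes, in first-seen order.
all_labels = list(
    dict.fromkeys(
        label for cats in REGIME_CATEGORIES.values() for label in cats
    )
)

# Decode each regime separately, then build a single merged Categorical.
labels = codes_to_labels("working", [0, 1]) + codes_to_labels("retired", [0, 1])
merged = pd.Categorical(labels, categories=all_labels)
```

Decoding the raw codes against a single shared ordering would mislabel at least one regime, since the same integer refers to different categories.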
…r/remove-label-translator
…series-converter
# Conflicts:
#	src/lcm/pandas_utils.py
#	tests/test_pandas_utils.py
Summary
@timmens was right in his comments on #272: requiring a separate preprocessor call before `solve()`/`simulate()` was too clunky. Took me a while to come around, but here we are: pandas objects are now accepted directly.

- `model.solve()` and `model.simulate()` accept `pd.Series` values in `params`: labeled Series with a named MultiIndex are converted to correctly shaped JAX arrays automatically.
- `model.simulate()` accepts a `pd.DataFrame` as `initial_conditions`: regime name mapping and discrete label encoding happen transparently.
- `derived_categoricals` parameter to `solve()`/`simulate()` for DAG function outputs not in the model's state/action grids.
- `broadcast_to_template` as the single params tree traversal primitive: `process_params` and `_resolve_fixed_params` become one-line wrappers.
- `_get_func_indexing_params`: inspects function source for subscript patterns to determine array dimensions, with required `array_param_name` to avoid false positives.

Test plan

- `pixi run -e tests-cpu tests`
- `prek run --all-files`
- `pixi run -e type-checking ty`
- `model.py` (`_maybe_convert_series`, `_maybe_convert_dataframe`)
- `convert_series_in_params` and `array_from_series` in `pandas_utils.py`
- `broadcast_to_template` in `params_processing.py`

🤖 Generated with Claude Code