Conversation
Introduce strongly-typed dataclasses for model configuration:
- Dimensions, Labels, Anchoring, EstimationOptions, TransitionInfo
- FactorEndogenousInfo, EndogenousFactorsInfo

This improves type safety and enables IDE autocompletion while keeping the user-facing model_dict as a plain dictionary.

🤖 Generated with [Claude Code](https://claude.com/claude-code)
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Replace dict fields with frozendict in frozen dataclasses to ensure true immutability:
- Labels.aug_periods_to_periods
- Labels.aug_stages_to_stages
- Anchoring.outcomes
- TransitionInfo.param_names, individual_functions, function_names
- EndogenousFactorsInfo.aug_periods_to_aug_period_meas_types, factor_info
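The idea in this commit can be sketched as follows. Note this is an illustrative example only: the PR uses the third-party `frozendict` package, while the sketch substitutes the stdlib `MappingProxyType`, which gives the same read-only mapping behaviour; the field names are taken from the commit message, but their exact types in skillmodels may differ.

```python
from dataclasses import dataclass, FrozenInstanceError
from types import MappingProxyType

# Illustrative stand-in for the frozendict-based Labels dataclass.
# MappingProxyType is a stdlib substitute for frozendict here.


@dataclass(frozen=True)
class Labels:
    latent_factors: tuple[str, ...]
    aug_periods_to_periods: MappingProxyType


labels = Labels(
    latent_factors=("fac1", "fac2"),
    aug_periods_to_periods=MappingProxyType({0: 0, 1: 0, 2: 1}),
)

# Reads work like a normal dict; writes are rejected at both levels.
assert labels.aug_periods_to_periods[2] == 1
try:
    labels.latent_factors = ("fac3",)  # blocked by frozen=True
except FrozenInstanceError:
    pass
try:
    labels.aug_periods_to_periods[0] = 9  # blocked by the read-only mapping
except TypeError:
    pass
```

A plain `dict` field in a frozen dataclass would still be mutable in place, which is why the mapping itself must be immutable too.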
Update process_model() to return a ProcessedModel frozen dataclass and update all consumers to use attribute access instead of dict access. This provides:
- Better type safety with explicit typed fields
- Immutability via frozen dataclass
- IDE autocomplete support
- Clear documentation of the model structure
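The consumer-facing change can be sketched like this. The field names below are hypothetical illustrations, not the actual attributes defined in skillmodels:

```python
from dataclasses import dataclass

# Hypothetical shape of the ProcessedModel container; the real class
# in skillmodels has more fields (labels, anchoring, transition info, ...).


@dataclass(frozen=True)
class Dimensions:
    n_latent_factors: int
    n_periods: int


@dataclass(frozen=True)
class ProcessedModel:
    dimensions: Dimensions


model = ProcessedModel(dimensions=Dimensions(n_latent_factors=4, n_periods=8))

# Consumers switch from model["dimensions"]["n_latent_factors"] to:
n_factors = model.dimensions.n_latent_factors
```

Attribute access lets type checkers and IDEs catch typos that dict-key access would only surface at runtime.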
…c so that config.TEST_DATA_DIR is valid also for the skillmodels package (as opposed to the project).
The filtered_states DataFrame and params index both use aug_period as the period identifier, not period. This fixes a KeyError when calling decompose_measurement_variance. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
…ting values on FixedConstraints.
Remove list from loc type union, convert callers to tuple(). Update anchoring test expectations from list to tuple. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
The viz code assumed states DataFrames always have `aug_period` as a column, but pre-computed states (e.g. from health-cognition) may carry `period` in the index instead. Add `_normalize_states_columns` to promote index levels and rename `period` → `aug_period` when needed. Also document the period vs aug_period convention in CLAUDE.md. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
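A minimal sketch of what such a normalizer could look like, using pandas. This is a hypothetical reconstruction from the description above, not the actual `_normalize_states_columns` in the viz code:

```python
import pandas as pd

# Hypothetical sketch: promote index levels to columns and rename
# `period` to `aug_period` when only `period` is present.


def _normalize_states_columns(states: pd.DataFrame) -> pd.DataFrame:
    # Promote named index levels (e.g. `period`) to regular columns.
    if states.index.nlevels > 1 or states.index.name is not None:
        out = states.reset_index()
    else:
        out = states.copy()
    # Pre-computed states may carry `period` instead of `aug_period`.
    if "aug_period" not in out.columns and "period" in out.columns:
        out = out.rename(columns={"period": "aug_period"})
    return out


states = pd.DataFrame(
    {"fac1": [0.1, 0.2]},
    index=pd.Index([0, 1], name="period"),
)
norm = _normalize_states_columns(states)
# norm now has an `aug_period` column alongside `fac1`
```

DataFrames that already have an `aug_period` column pass through unchanged, so the helper is safe to call unconditionally.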
Codecov Report

❌ Patch coverage is

Additional details and impacted files

```diff
@@            Coverage Diff             @@
##             main      #87      +/-   ##
==========================================
+ Coverage   96.86%   96.91%   +0.05%
==========================================
  Files          57       57
  Lines        4809     4952     +143
==========================================
+ Hits         4658     4799     +141
- Misses        151      153       +2
==========================================
```
…as in-place operations.
janosg left a comment
Maybe we should compare the speed of just the update step to see if the linear one is implemented efficiently. Without a detailed analysis I would expect the linear update step to be at least twice as fast as the unscented one. Of course, in a model with few factors or many measurements per factor, the unscented predict might not have been the bottleneck anyways.
src/skillmodels/kalman_filters.py (outdated)

```python
for i, factor in enumerate(latent_factors):
    if i in constant_factor_indices:
        row = jnp.zeros(n_all_factors).at[i].set(1.0)
        f_rows.append(row)
        c_vals.append(0.0)
    else:
        coeffs = trans_coeffs[factor]
        f_rows.append(coeffs[:-1])
        c_vals.append(coeffs[-1])

f_mat = jnp.stack(f_rows)  # (n_latent, n_all)
c_vec = jnp.array(c_vals)  # (n_latent,)
```
Looks suboptimal but maybe Jax is smart enough at compiling the small array creation away. Have you tried different implementations?
Confirmed that Jax is smart enough. But I kept a more idiomatic version from the experiments and added a note.
Re: the question about a linear update step: the measurement model in skillmodels is always linear (…). The QR decomposition in the update operates on an (…).

(This is Claude, obviously, but it does appear plausible without checking deeply.)
Force-pushed from a5c2381 to d7ba962.
Add linear Kalman predict fast path
Fixes #36.
Summary
- `linear_kalman_predict` that uses direct matrix algebra (`F @ x + c`) instead of the unscented sigma-point transform, for models where all factors use `linear` or `constant` transition functions
- `maximization_inputs.py` auto-selects the fast path via `is_all_linear()`; no API changes needed
- `predict_func` callable instead of hardcoded `kalman_predict` + `transition_func`
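The fast path amounts to the textbook linear Kalman prediction. A minimal sketch under the standard state-space form (function and variable names here are illustrative, not the actual skillmodels API):

```python
import jax.numpy as jnp

# Hedged sketch of a linear Kalman predict step:
#   x' = F x + c          (state mean)
#   P' = F P F^T + Q      (state covariance)


def linear_kalman_predict(x, p, f_mat, c_vec, q_mat):
    x_new = f_mat @ x + c_vec
    p_new = f_mat @ p @ f_mat.T + q_mat
    return x_new, p_new


x = jnp.array([1.0, 2.0])
p = jnp.eye(2)
f_mat = jnp.array([[1.0, 0.5], [0.0, 1.0]])
c_vec = jnp.array([0.1, 0.0])
q_mat = 0.01 * jnp.eye(2)

x_new, p_new = linear_kalman_predict(x, p, f_mat, c_vec, q_mat)
```

Compared with propagating 2n+1 sigma points, this is a single matrix multiply per moment, which is what drives the memory savings reported below.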
Tested on
Tested on `health-cognition` (`no_feedback_to_investments_linear`, 4 latent factors, GPU 8 GiB):

[benchmark table: `om.Constraints`, unscented vs linear-predict]

The main benefit is reduced GPU memory usage: the unscented transform generates 2n+1 sigma points which are expensive to differentiate through, while the linear path uses a single matrix multiply. On a small model (4 factors), the speed gain is modest (~6% GPU), but the memory reduction is the difference between fitting on GPU vs OOMing when memory is constrained.
Test plan
- `linear_kalman_predict` and `is_all_linear` in `test_kalman_filters.py`
- `om.Constraints` on real estimation task
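The parity property the test plan relies on can be checked directly: for a purely linear transition, the unscented mean collapses exactly to the linear predict `F @ mu + c`. A hedged sketch of such a check (not the actual tests; `sigma_points` and `unscented_mean` are simplified stand-ins using a kappa-parameterized scheme):

```python
import numpy as np

# For linear f, the weighted mean of propagated sigma points equals
# the linear prediction exactly, since symmetric deviations cancel.


def sigma_points(mu, p, kappa=1.0):
    n = mu.size
    sqrt_mat = np.linalg.cholesky((n + kappa) * p)
    points = [mu]
    for j in range(n):
        points.append(mu + sqrt_mat[:, j])
        points.append(mu - sqrt_mat[:, j])
    return np.array(points)  # shape (2n + 1, n)


def unscented_mean(points, kappa=1.0):
    n = (points.shape[0] - 1) // 2
    w0 = kappa / (n + kappa)
    wi = 1.0 / (2 * (n + kappa))
    weights = np.array([w0] + [wi] * (2 * n))  # weights sum to 1
    return weights @ points


f_mat = np.array([[0.9, 0.1], [0.0, 1.0]])
c_vec = np.array([0.2, -0.1])
mu = np.array([1.0, 2.0])
p = np.array([[0.5, 0.1], [0.1, 0.3]])

# Propagate each sigma point through the linear transition f(x) = F x + c.
propagated = sigma_points(mu, p) @ f_mat.T + c_vec
assert np.allclose(unscented_mean(propagated), f_mat @ mu + c_vec)
```

This is also why the unscented and linear paths must agree (up to floating-point noise) on all-linear models, which is the invariant a parity test on a real estimation task would exercise.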