Skip to content

Make Discretizer use GaussianStateEvolution for increased filter compatibility#199

Open
mattlevine22 wants to merge 14 commits intomainfrom
ml-issue-115
Open

Make Discretizer use GaussianStateEvolution for increased filter compatibility#199
mattlevine22 wants to merge 14 commits intomainfrom
ml-issue-115

Conversation

@mattlevine22
Copy link
Copy Markdown
Collaborator

@mattlevine22 mattlevine22 commented Apr 14, 2026

Addresses issue #115

  • Gives GaussianStateEvolution a time-varying covariance field (or jax.array)
  • CD_Dynamax/Dynamax integration with GaussianStateEvolution errors if covariance is a callable
  • Refactor discretizers.py to build a class EulerMaruyamaGaussianStateEvolution(GaussianStateEvolution)
  • Added tests_discretizers.py which runs Cuthbert EKF with a discretized SDE model.

N.B. CD_Dynamax/Dynamax discrete-time integrations don't support time-varying F(, .. t) anyway

N.B.This PR only adds support for time-varying GaussianStateEvolution to enable discretized + discrete-time-filter patterns. It does NOT modify: [LinearGaussianStateEvolution, GaussianObservation, LinearGaussianObservation],

mattlevine22 and others added 6 commits March 26, 2026 19:30
Introduce solver modules for ODE and SDE paths, add an Euler-Maruyama scan source for SDESimulator, and update filter+simulator tests and continuous-time tutorial outputs to cover the new backend behavior.

Made-with: Cursor
Copy link
Copy Markdown
Collaborator

@DanWaxman DanWaxman left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this looks good in principle, maybe with some style comments. This should've run the hierarchical discretized smoke tests, but it would also be nice to run the 08_hierarchical_inference.ipynb tutorial notebook and ensure there's no qualitative change there (I don't expect there is).

Comment thread dynestyx/inference/integrations/cd_dynamax/utils.py
Comment thread dynestyx/discretizers.py Outdated
Comment on lines +44 to +45
Supports batched time (and optional control) matching the previous
`EulerMaruyamaGaussianStateEvolution` behavior.
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

de-AI-ify pls :)

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

Comment thread dynestyx/discretizers.py Outdated
Comment on lines +49 to +52
"loc": ndarray
Mean of next state(s): shape (dim_state,) or (num_timepoints, dim_state)
"cov": ndarray
Covariance of next state(s): shape (dim_state, dim_state) or (num_timepoints, dim_state, dim_state)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if this allows batched shapes, this should be reflected in the docstring array shapes

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

Comment thread dynestyx/discretizers.py Outdated
Comment on lines +131 to +134
The parent `GaussianStateEvolution.__call__` would evaluate the
Euler–Maruyama drift and diffusion twice. Under `jax.vmap` (e.g.
plate-batched cuthbert EKF), that split can change tracing/shapes. This
override matches the original one-step implementation.
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same here

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

Comment thread dynestyx/discretizers.py Outdated
Comment on lines +136 to +139
em_result = _euler_maruyama_loc_cov(self.cte, x, u, t_now, t_next)
return dist.MultivariateNormal(
loc=em_result["loc"], covariance_matrix=em_result["cov"]
)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not so clear on the benefit of making this a helper function, to be honest. It seems like it can be reasonably in-lined here.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We want it to construct F and cov in the __init__ method, as well as use it in the __call__ method

Comment thread dynestyx/discretizers.py Outdated
Comment on lines +99 to +102
Holds ``cte`` as an explicit field so `DynamicalModel` pytrees under
`numpyro.plate` still expose batched continuous-time parameters for
simulator slicing (closures over ``cte`` alone would hide those arrays).

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seems a bit in-the-weeds. Would prefer to just write that we hold cte for pytree compatibility.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

@mattlevine22 mattlevine22 changed the base branch from main to ml-fast-sde-solver2 April 15, 2026 19:11
@mattlevine22 mattlevine22 changed the base branch from ml-fast-sde-solver2 to main April 15, 2026 19:13
@mattlevine22 mattlevine22 changed the base branch from main to ml-fast-sde-solver2 April 15, 2026 19:14
@DanWaxman DanWaxman changed the base branch from ml-fast-sde-solver2 to main April 22, 2026 14:15
@mattlevine22 mattlevine22 requested a review from DanWaxman April 22, 2026 19:13
@mattlevine22
Copy link
Copy Markdown
Collaborator Author

This PR only adds support for time-varying GaussianStateEvolution to enable discretized + discrete-time-filter patterns. It does not modify things like: [LinearGaussianStateEvolution, GaussianObservation, LinearGaussianObservation],

Comment thread dynestyx/solvers/sde.py
Comment on lines +50 to +59
def _bm_dim_or_default(state_evolution: ContinuousTimeStateEvolution) -> int:
"""Return Brownian dimension, defaulting to 1 when unspecified.

Args:
state_evolution: Continuous-time state evolution.

Returns:
Brownian motion dimension used by EM sampling.
"""
return int(state_evolution.bm_dim) if state_evolution.bm_dim is not None else 1
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why don't we just set the corresponding state_evolution.bm_dim? Seems like needless indirection.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also I don't really see how 1 is a reasonable default. If anything, we should raise or set up a probe state pattern to get the right shape.

Comment thread dynestyx/solvers/sde.py
Comment on lines +150 to +157
def euler_maruyama_step_loc_cov(
state_evolution: ContinuousTimeStateEvolution,
x: Array,
u: Array | None,
t_now: Array,
dt: Array,
) -> tuple[Array, Array]:
"""Compute one EM moment step over a fixed `dt`.
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This feels like a particularly simple function.

Comment thread dynestyx/inference/integrations/cd_dynamax/utils.py
Comment thread dynestyx/solvers/sde.py
Batched mode maps across the time axis, pairing
`x[:, i], u[:, i], t_now[i], t_next[i]` for each `i`.
Scalar inputs are promoted to a batch of size 1 internally and squeezed
back to single-transition outputs.
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Isn't this counter to other places?

Comment thread dynestyx/utils.py
f"got {type(dt)}."
)


Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we got rid of this in #178 ?

Comment thread dynestyx/discretizers.py
Comment on lines +40 to +41
``F`` and ``cov`` are optional constructor args so Equinox/dataclass-style
but we don't use them.
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants