Conversation
…crete-time); new demos in tutorials
…ick likelihood profile to quickstart
|
@DanWaxman I THINK this PR is more-or-less ready to be looked at. |
|
This looks great to a first read! I'll read in more depth tomorrow. In the mean time, we should write up & update the PR description. This should be squash & merged, so the description should be fairly complete as to sufficiently describe the progress & design choices on its own. |
| if ctrl_values is None: | ||
| inputs = jnp.zeros((T1, control_dim)) | ||
| elif ctrl_values.shape[0] > T1: | ||
| inds = jnp.searchsorted(ctrl_times, obs_times, side="left") | ||
| inputs = ctrl_values[inds] | ||
| else: | ||
| inputs = ctrl_values |
There was a problem hiding this comment.
| if ctrl_values is None: | |
| inputs = jnp.zeros((T1, control_dim)) | |
| elif ctrl_values.shape[0] > T1: | |
| inds = jnp.searchsorted(ctrl_times, obs_times, side="left") | |
| inputs = ctrl_values[inds] | |
| else: | |
| inputs = ctrl_values | |
| if ctrl_values is None: | |
| inputs = jnp.zeros((T1, control_dim)) | |
| elif ctrl_values.shape[0] > T1: | |
| # Find controls aligned to obs_times | |
| inds = jnp.searchsorted(ctrl_times, obs_times, side="left") | |
| inputs = ctrl_values[inds] | |
| else: | |
| # Controls should align exactly with obs_times | |
| inputs = ctrl_values |
| if ctrl_values is None: | ||
| inputs = jnp.zeros((T1, control_dim)) | ||
| elif ctrl_values.shape[0] > T1: | ||
| # ctrl spans union of obs_times and predict_times; filter needs ctrl at obs_times only | ||
| inds = jnp.searchsorted(ctrl_times, obs_times, side="left") | ||
| inputs = ctrl_values[inds] | ||
| else: | ||
| inputs = ctrl_values |
There was a problem hiding this comment.
| if ctrl_values is None: | |
| inputs = jnp.zeros((T1, control_dim)) | |
| elif ctrl_values.shape[0] > T1: | |
| # ctrl spans union of obs_times and predict_times; filter needs ctrl at obs_times only | |
| inds = jnp.searchsorted(ctrl_times, obs_times, side="left") | |
| inputs = ctrl_values[inds] | |
| else: | |
| inputs = ctrl_values | |
| if ctrl_values is None: | |
| inputs = jnp.zeros((T1, control_dim)) | |
| elif ctrl_values.shape[0] > T1: | |
| # Find controls aligned to obs_times | |
| inds = jnp.searchsorted(ctrl_times, obs_times, side="left") | |
| inputs = ctrl_values[inds] | |
| else: | |
| # Controls should align exactly with obs_times | |
| inputs = ctrl_values |
| Args: | ||
| name: Name of the factor. | ||
| dynamics: Dynamical model to filter. | ||
| filter_config: KFConfig, EKFConfig, or UKFConfig. | ||
| obs_times: Observation times. | ||
| obs_values: Observed values. | ||
| ctrl_times: Control times (optional). | ||
| ctrl_values: Control values (optional). |
| record_kwargs = _config_to_record_kwargs(filter_config) | ||
|
|
||
| ys = obs_values | ||
| t1 = int(ys.shape[0]) # this is T+1 in cuthbert's convention |
There was a problem hiding this comment.
| t1 = int(ys.shape[0]) # this is T+1 in cuthbert's convention | |
| T1 = int(ys.shape[0]) # this is T+1 in cuthbert's convention |
Let's use T1 to mirror the cddynamax conventions
| if ctrl_values is None: | ||
| control_dim = dynamics.control_dim | ||
| ctrl_values = jnp.zeros((t1, control_dim), dtype=ys.dtype) | ||
| elif ctrl_values.shape[0] > t1: |
There was a problem hiding this comment.
| elif ctrl_values.shape[0] > t1: | |
| elif ctrl_values.shape[0] > T1: |
|
|
||
| return jax.vmap(_obs_step)(jnp.arange(T)) | ||
| else: | ||
|
|
There was a problem hiding this comment.
| # TODO: Handle this case. | ||
| raise NotImplementedError( | ||
| "this is to-be-implemented. Should pass forward whatever is from previous operator in **kwargs." | ||
| ) |
| "Construct dynamics via DynamicalModel before simulation." | ||
| t0 = dynamics.t0 if dynamics.t0 is not None else times[0] | ||
|
|
||
| def _run_one_from_x0(key: Array, x0: Array) -> tuple[Array, Array]: |
| f"y_{t_idx + 1}", | ||
| dynamics.observation_model(x=x_t, u=u_next, t=t_next), | ||
| obs=_get_val_or_None(obs_values, t_idx + 1), | ||
| # n_simulations > 1: vmap over scan with dist.sample (no numpyro.sample in body) |
There was a problem hiding this comment.
Should exaplain this more -- is it because numpyro doesn't place nice with the vmap?
| """ | ||
| Validate that ctrl_times/ctrl_values align with obs_times if provided. | ||
|
|
||
| Raises: | ||
| ValueError: If control times length doesn't match observation times length. | ||
| """ |
There was a problem hiding this comment.
Explain all the rules that are being tested
|
@mattlevine22 left feedback |
Addresses #50 #116