Skip to content

Predict Times#135

Open
DanWaxman wants to merge 33 commits intomainfrom
dw-ml-predict-times
Open

Predict Times#135
DanWaxman wants to merge 33 commits intomainfrom
dw-ml-predict-times

Conversation

@DanWaxman
Copy link
Collaborator

Addresses #50 #116

@mattlevine22 mattlevine22 marked this pull request as ready for review March 16, 2026 21:25
@mattlevine22
Copy link
Collaborator

@DanWaxman I THINK this PR is more-or-less ready to be looked at.

@DanWaxman
Copy link
Collaborator Author

This looks great to a first read! I'll read in more depth tomorrow. In the mean time, we should write up & update the PR description. This should be squash & merged, so the description should be fairly complete as to sufficiently describe the progress & design choices on its own.

@DanWaxman DanWaxman requested a review from mattlevine22 March 17, 2026 13:33
Comment on lines +90 to +96
if ctrl_values is None:
inputs = jnp.zeros((T1, control_dim))
elif ctrl_values.shape[0] > T1:
inds = jnp.searchsorted(ctrl_times, obs_times, side="left")
inputs = ctrl_values[inds]
else:
inputs = ctrl_values
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if ctrl_values is None:
inputs = jnp.zeros((T1, control_dim))
elif ctrl_values.shape[0] > T1:
inds = jnp.searchsorted(ctrl_times, obs_times, side="left")
inputs = ctrl_values[inds]
else:
inputs = ctrl_values
if ctrl_values is None:
inputs = jnp.zeros((T1, control_dim))
elif ctrl_values.shape[0] > T1:
# Find controls aligned to obs_times
inds = jnp.searchsorted(ctrl_times, obs_times, side="left")
inputs = ctrl_values[inds]
else:
# Controls should align exactly with obs_times
inputs = ctrl_values

Comment on lines +132 to +139
if ctrl_values is None:
inputs = jnp.zeros((T1, control_dim))
elif ctrl_values.shape[0] > T1:
# ctrl spans union of obs_times and predict_times; filter needs ctrl at obs_times only
inds = jnp.searchsorted(ctrl_times, obs_times, side="left")
inputs = ctrl_values[inds]
else:
inputs = ctrl_values
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if ctrl_values is None:
inputs = jnp.zeros((T1, control_dim))
elif ctrl_values.shape[0] > T1:
# ctrl spans union of obs_times and predict_times; filter needs ctrl at obs_times only
inds = jnp.searchsorted(ctrl_times, obs_times, side="left")
inputs = ctrl_values[inds]
else:
inputs = ctrl_values
if ctrl_values is None:
inputs = jnp.zeros((T1, control_dim))
elif ctrl_values.shape[0] > T1:
# Find controls aligned to obs_times
inds = jnp.searchsorted(ctrl_times, obs_times, side="left")
inputs = ctrl_values[inds]
else:
# Controls should align exactly with obs_times
inputs = ctrl_values

Comment on lines -250 to -257
Args:
name: Name of the factor.
dynamics: Dynamical model to filter.
filter_config: KFConfig, EKFConfig, or UKFConfig.
obs_times: Observation times.
obs_values: Observed values.
ctrl_times: Control times (optional).
ctrl_values: Control values (optional).
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why removed?

record_kwargs = _config_to_record_kwargs(filter_config)

ys = obs_values
t1 = int(ys.shape[0]) # this is T+1 in cuthbert's convention
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
t1 = int(ys.shape[0]) # this is T+1 in cuthbert's convention
T1 = int(ys.shape[0]) # this is T+1 in cuthbert's convention

Let's use T1 to mirror the cddynamax conventions

if ctrl_values is None:
control_dim = dynamics.control_dim
ctrl_values = jnp.zeros((t1, control_dim), dtype=ys.dtype)
elif ctrl_values.shape[0] > t1:
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
elif ctrl_values.shape[0] > t1:
elif ctrl_values.shape[0] > T1:


return jax.vmap(_obs_step)(jnp.arange(T))
else:

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change

Comment on lines +530 to +533
# TODO: Handle this case.
raise NotImplementedError(
"this is to-be-implemented. Should pass forward whatever is from previous operator in **kwargs."
)
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

:)

"Construct dynamics via DynamicalModel before simulation."
t0 = dynamics.t0 if dynamics.t0 is not None else times[0]

def _run_one_from_x0(key: Array, x0: Array) -> tuple[Array, Array]:
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Run one what?

f"y_{t_idx + 1}",
dynamics.observation_model(x=x_t, u=u_next, t=t_next),
obs=_get_val_or_None(obs_values, t_idx + 1),
# n_simulations > 1: vmap over scan with dist.sample (no numpyro.sample in body)
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should exaplain this more -- is it because numpyro doesn't place nice with the vmap?

Comment on lines 92 to 97
"""
Validate that ctrl_times/ctrl_values align with obs_times if provided.

Raises:
ValueError: If control times length doesn't match observation times length.
"""
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Explain all the rules that are being tested

@DanWaxman
Copy link
Collaborator Author

@mattlevine22 left feedback

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Support non-filtered predict_times for trajectory-based evaluation Forecast evaluations will need a posterior "rollout" mode

2 participants