dtype errors with adaptive integrator #1480

Open
jwnys opened this issue May 23, 2023 · 2 comments

Comments

@jwnys
Collaborator

jwnys commented May 23, 2023

When using an adaptive integrator with dtype=jnp.float32 for the parameters, samples, etc., the integrator raises errors. They arise for the following reasons:

  • the dt values returned by the two branches of the jax.lax.cond in the accepted case can have different dtypes (next_dt vs rk_state.dt);
  • the error norm inherits its dtype from the variational state, so in the accepted case the replacement last_norm and last_scaled_error can have a different dtype from the stored ones;
  • the latter happens because we initialize last_norm (for example) with the Python float 0. in the adaptive case, whereas the norm computed afterwards can be float32.
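The first cause can be reproduced in isolation. Below is a minimal sketch that assumes JAX's 64-bit mode is enabled (without it every float defaults to float32 and the mismatch cannot occur). `choose_dt` is a hypothetical stand-in for the integrator's step-acceptance logic, not its actual code.

```python
import jax
import jax.numpy as jnp

# Assumption: 64-bit mode is on, so float64 and float32 values can coexist.
jax.config.update("jax_enable_x64", True)


@jax.jit
def choose_dt(accepted, dt, next_dt):
    # Hypothetical helper: lax.cond requires both branches to return
    # values with identical shapes AND dtypes.
    return jax.lax.cond(accepted, lambda a, b: b, lambda a, b: a, dt, next_dt)


dt = jnp.asarray(0.01, dtype=jnp.float64)       # dtype stored in the integrator state
next_dt = jnp.asarray(0.02, dtype=jnp.float32)  # dtype inherited from a float32 error norm

try:
    choose_dt(True, dt, next_dt)
    mismatch_raised = False
except TypeError:
    # Tracing fails because the branch outputs are float64 vs float32.
    mismatch_raised = True

# Casting the proposed step to the stored dtype makes both branches agree.
fixed = choose_dt(True, dt, next_dt.astype(dt.dtype))
print(mismatch_raised, fixed.dtype)
```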

I'm not sure what the best solution is (I wasn't even expecting errors of this kind tbh), but some possibilities are:

  • make sure all replaced dt's have the same dtype as rk_state.dt
  • initialize last_norm and last_scaled_error fields with a jnp.array with a predefined fixed dtype (not clear how to determine this in general, but float64 would make sense)
  • initialize and convert everything to float64 in the RKState.
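The first two options can be combined at the level of the whole state. This is a sketch only: `StepState`, `cast_like` and `update` are hypothetical names holding just the fields discussed here, not the real RKState.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


class StepState(NamedTuple):
    # Hypothetical stand-in for RKState with only the fields discussed above.
    dt: jax.Array
    last_norm: jax.Array
    last_scaled_error: jax.Array


def cast_like(new, old):
    # Option 1: a replacement value keeps the dtype of the field it replaces.
    return jnp.asarray(new, dtype=old.dtype)


@jax.jit
def update(state, accepted, next_dt, norm, scaled_error):
    accepted_state = StepState(
        dt=cast_like(next_dt, state.dt),
        last_norm=cast_like(norm, state.last_norm),
        last_scaled_error=cast_like(scaled_error, state.last_scaled_error),
    )
    # Both branches now return pytrees with identical dtypes.
    return jax.lax.cond(accepted, lambda new, old: new, lambda new, old: old,
                        accepted_state, state)


state = StepState(
    dt=jnp.asarray(0.01, dtype=jnp.float64),
    # Option 2: explicit-dtype arrays instead of the bare Python float 0.
    last_norm=jnp.zeros((), dtype=jnp.float64),
    last_scaled_error=jnp.zeros((), dtype=jnp.float64),
)
new_state = update(state, True, jnp.float32(0.02), jnp.float32(0.5), jnp.float32(0.3))
rejected = update(state, False, jnp.float32(0.02), jnp.float32(0.5), jnp.float32(0.3))
print([x.dtype for x in new_state])
```

Which fixed dtype to pick (float64 here) is exactly the open question raised above; the sketch hardcodes it only for illustration.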
@PhilipVinc
Member

Thanks, it all makes sense.
Getting things to work with non-standard dtypes is always a mess...

In particular

make sure all replaced dt's have the same dtype as rk_state.dt

I agree.

initialize last_norm and last_scaled_error fields with a jnp.array with a predefined fixed dtype (not clear how to determine this in general, but float64 would make sense)

I would initialise it with the dtype of the output of error_norm.
You can get it through abstract interpretation, i.e. jax.eval_shape.
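A minimal sketch of that suggestion, assuming a hypothetical root-mean-square `error_norm` over a pytree of float32 parameters (the integrator's real norm may differ). `jax.eval_shape` traces the function abstractly, so no arithmetic is performed; only the output shape and dtype are computed.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def error_norm(x):
    # Hypothetical error norm: root-mean-square over all leaves of a pytree.
    leaves = jax.tree_util.tree_leaves(x)
    squared = sum(jnp.sum(jnp.abs(leaf) ** 2) for leaf in leaves)
    count = sum(leaf.size for leaf in leaves)
    return jnp.sqrt(squared / count)


params = {"w": jnp.ones((4, 3), dtype=jnp.float32), "b": jnp.zeros(3, dtype=jnp.float32)}

# Abstract evaluation: returns a ShapeDtypeStruct describing the output.
norm_struct = jax.eval_shape(error_norm, params)

# Initialize the stored norm with exactly the dtype error_norm will produce.
last_norm = jnp.zeros(norm_struct.shape, dtype=norm_struct.dtype)
print(norm_struct.dtype)
```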

Those two should be enough, I think? I don't think that there are other fields that might change... We already forcibly enforce that the dtype of the parameters does not change (for the same reason).
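The parameter-dtype guard mentioned here can be sketched generically with `jax.tree_util.tree_map`. `preserve_dtypes` is a hypothetical helper for illustration, not the library's actual implementation.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def preserve_dtypes(new_tree, reference_tree):
    # Cast every leaf of new_tree back to the dtype of the matching leaf in reference_tree.
    return jax.tree_util.tree_map(
        lambda new, ref: jnp.asarray(new, dtype=ref.dtype), new_tree, reference_tree
    )


params = {"w": jnp.ones(3, dtype=jnp.float32)}
dt = jnp.asarray(0.1, dtype=jnp.float64)

# Mixing float32 parameters with a float64 step size promotes the update to float64...
promoted = jax.tree_util.tree_map(lambda p: p + dt * p, params)
# ...so it is cast back before being stored.
updated = preserve_dtypes(promoted, params)
print(promoted["w"].dtype, updated["w"].dtype)
```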

@PhilipVinc
Member

Care to make a PR?
