
Kernels should have kernel(state, *parameters) signature #40

Closed
rlouf opened this issue Oct 14, 2021 · 0 comments · Fixed by #41


rlouf commented Oct 14, 2021

We currently specialize the HMC and NUTS kernels in the factory using closures. However, this is impractical, and we are moving away from this design in blackjax (see the related discussion).

The HMC kernel factory currently has the following signature:

new_kernel(srng, logprob_fn, step_size, inverse_mass_matrix, num_integration_steps, divergence_threshold)

and the HMC kernel it returns has the signature:

kernel(q, log_prob, log_prob_grad)
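To make the problem concrete, here is a minimal NumPy sketch of the current closure-based design. It is not the aehmc (Aesara) code; it assumes a NumPy `Generator` in place of `srng`, a diagonal `inverse_mass_matrix` given as a 1-D array, and a `logprob_fn` that returns `(log_prob, gradient)` because NumPy has no autodiff. The point it illustrates is that `step_size` is frozen inside the closure, so trying a new value means calling `new_kernel` again.

```python
import numpy as np


def new_kernel(rng, logprob_fn, step_size, inverse_mass_matrix,
               num_integration_steps, divergence_threshold=1000.0):
    """Closure-based HMC factory (NumPy sketch of the current design).

    Assumptions: `logprob_fn(q)` returns `(log_prob, grad)` and
    `inverse_mass_matrix` is the diagonal, passed as a 1-D array.
    """

    def kinetic_energy(p):
        return 0.5 * np.sum(inverse_mass_matrix * p**2)

    def kernel(q, log_prob, log_prob_grad):
        # step_size and num_integration_steps are captured by the closure.
        p = rng.standard_normal(q.shape) / np.sqrt(inverse_mass_matrix)
        energy = -log_prob + kinetic_energy(p)

        q_new, p_new, lp_new, grad_new = q, p, log_prob, log_prob_grad
        for _ in range(num_integration_steps):
            # One leapfrog (velocity Verlet) step.
            p_new = p_new + 0.5 * step_size * grad_new
            q_new = q_new + step_size * inverse_mass_matrix * p_new
            lp_new, grad_new = logprob_fn(q_new)
            p_new = p_new + 0.5 * step_size * grad_new

        delta = -lp_new + kinetic_energy(p_new) - energy
        if not np.isfinite(delta) or delta > divergence_threshold:
            return q, log_prob, log_prob_grad  # divergent transition: reject
        # Metropolis correction: accept with probability min(1, exp(-delta)).
        if rng.uniform() < np.exp(min(0.0, -delta)):
            return q_new, lp_new, grad_new
        return q, log_prob, log_prob_grad

    return kernel


def logprob_fn(q):
    # Standard normal target: log density (up to a constant) and gradient.
    return -0.5 * q @ q, -q


rng = np.random.default_rng(0)
q = np.zeros(2)
lp, grad = logprob_fn(q)
for step_size in (0.5, 0.25, 0.1):
    # A whole new kernel has to be built for every step size value.
    kernel = new_kernel(rng, logprob_fn, step_size, np.ones(2), 10)
    q, lp, grad = kernel(q, lp, grad)
```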

We suggest the following instead:

new_kernel(srng, logprob_fn, inverse_mass_matrix, num_integration_steps, divergence_threshold)
kernel(q, log_prob, log_prob_grad, step_size)

I bumped into this design issue while implementing step size adaptation algorithms, where a new kernel has to be created every time the value of a parameter changes.
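Under the same assumptions (NumPy instead of Aesara, diagonal inverse mass matrix, `logprob_fn` returning the value and gradient), a sketch of the proposed signature shows how an adaptation loop can reuse a single kernel. Two further choices are hypothetical: the kernel also returns its acceptance probability, and the update is a toy stochastic-approximation step on `log(step_size)` toward a target acceptance rate of 0.8, not the dual averaging scheme normally used for step size adaptation.

```python
import numpy as np


def new_kernel(rng, logprob_fn, inverse_mass_matrix, num_integration_steps,
               divergence_threshold=1000.0):
    """Proposed design (NumPy sketch): step_size is a kernel argument."""

    def kinetic_energy(p):
        return 0.5 * np.sum(inverse_mass_matrix * p**2)

    def kernel(q, log_prob, log_prob_grad, step_size):
        p = rng.standard_normal(q.shape) / np.sqrt(inverse_mass_matrix)
        energy = -log_prob + kinetic_energy(p)

        q_new, p_new, lp_new, grad_new = q, p, log_prob, log_prob_grad
        for _ in range(num_integration_steps):
            # One leapfrog step using the step size supplied at call time.
            p_new = p_new + 0.5 * step_size * grad_new
            q_new = q_new + step_size * inverse_mass_matrix * p_new
            lp_new, grad_new = logprob_fn(q_new)
            p_new = p_new + 0.5 * step_size * grad_new

        delta = -lp_new + kinetic_energy(p_new) - energy
        accept_prob = 0.0  # divergent or non-finite transitions are rejected
        if np.isfinite(delta) and delta <= divergence_threshold:
            accept_prob = float(np.exp(min(0.0, -delta)))
        if rng.uniform() < accept_prob:
            return (q_new, lp_new, grad_new), accept_prob
        return (q, log_prob, log_prob_grad), accept_prob

    return kernel


def logprob_fn(q):
    # Standard normal target: log density (up to a constant) and gradient.
    return -0.5 * q @ q, -q


rng = np.random.default_rng(0)
kernel = new_kernel(rng, logprob_fn, np.ones(2), num_integration_steps=10)

q = np.zeros(2)
state = (q, *logprob_fn(q))
log_step_size, target = np.log(1.0), 0.8
accept_probs = []
for _ in range(200):
    # The same kernel is reused; only the step_size argument changes.
    state, accept_prob = kernel(*state, np.exp(log_step_size))
    accept_probs.append(accept_prob)
    log_step_size += 0.1 * (accept_prob - target)  # toy update rule
step_size = np.exp(log_step_size)
```

Because the kernel is built once, the adaptation code only has to manage the parameter values, which is the motivation for the `kernel(state, *parameters)` signature.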

I think this issue should be addressed before we move forward with the adaptation.
