-
Notifications
You must be signed in to change notification settings - Fork 231
Allow more AdvancedHMC options in HMC types #1818
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
|
Some things I don't like about the current interface:
|
|
@sethaxen apologies for the slow response, I'll take a look later this week. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Many thanks, @sethaxen - very nice improvements. I like your design in general. Only some (very) minor comments below. I am happy to merge as-is if you like.
EDIT: seems that we need to fix some tests before merging.
| Δ_max::Float64 | ||
| ϵ::Float64 # (initial) step size | ||
| metric::metricT | ||
| integrator::I |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
A minor comment: integrator seems a bit redundant with ϵ here. Maybe consider removing the ϵ field in this type, and storing the information in DefaultIntegrator when a user specifies ϵ in the constructor?
| NUTS{AD}(-1, 0.65; kwargs...) | ||
| end | ||
|
|
||
| function NUTS{AD}( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
thanks for cleaning up the NUTS constructor.
| ### | ||
| ### Default options | ||
| ### |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Wouldn't it make sense to add the default options to AdvancedHMC?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah, I think that would make for a cleaner design in the end.
| TS<:AHMC.AbstractTrajectorySampler, | ||
| TC<:AHMC.AbstractTerminationCriterion, | ||
| I<:AHMC.AbstractIntegrator, | ||
| A<:AHMC.AbstractAdaptor |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is it sufficient to make these type parameters without corresponding fields (apart from I)? I.e., do the types contain all information or do we need values?
| Δ_max::Float64, | ||
| ϵ::Float64, | ||
| ::Type{metricT}, | ||
| metricT::Type, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why did you change this?
| end | ||
|
|
||
| function NUTS{AD}(kwargs...) where AD | ||
| function NUTS{AD}(; kwargs...) where AD |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It seems the current code did not work? Otherwise this would be a breaking change it seems (positional arguments replaced with keyword arguments).
| function NUTS{AD}(δ::Float64; kwargs...) where AD | ||
| NUTS{AD}(-1, δ; kwargs...) | ||
| end | ||
|
|
||
| function NUTS{AD}(kwargs...) where AD | ||
| function NUTS{AD}(; kwargs...) where AD | ||
| NUTS{AD}(-1, 0.65; kwargs...) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It seems this can be simplified to (can't make a suggestion)
function NUTS{AD}(δ::Float64=0.65; kwargs...) where AD
NUTS{AD}(-1, δ; kwargs...)| if spl.alg.integrator isa DefaultIntegrator | ||
| if iszero(spl.alg.ϵ) | ||
| ϵ = AHMC.find_good_stepsize(h, θ_init) | ||
| @info "Found initial step size" ϵ | ||
| else | ||
| ϵ = spl.alg.ϵ | ||
| end | ||
| else | ||
| ϵ = spl.alg.ϵ | ||
| ϵ = AHMC.step_size(spl.alg.integrator) | ||
| end |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It seems this could be refactored into a separate function since it's the same as above?
|
|
||
| return HMCState(vi, 0, 0, kernel.τ, h, AHMCAdaptor(spl.alg, metric; ϵ=ϵ), t.z) | ||
| adaptor = as_concrete(spl.alg.adaptor, metric; n_adapts=spl.alg.n_adapts, ϵ=ϵ, δ=spl.alg.δ) | ||
| return HMCState(vi, 0, 0, kernel.τ, h, spl.alg.adaptor, t.z) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I guess this should be
| return HMCState(vi, 0, 0, kernel.τ, h, spl.alg.adaptor, t.z) | |
| return HMCState(vi, 0, 0, kernel.τ, h, adaptor, t.z) |
|
Closed in favour of #1997 |
As discussed on Slack, it would be very useful to expose more AdvancedHMC configuration options to the
NUTS,HMC, andHMCDAconvenience constructors here. In particular, it would be nice to be able to customize the integrator, metric (values), adaptor, or termination criterion.As far as I can tell, the main challenge here is that the initial step size and metric are model-dependent, while
NUTSetc should be model-independent. The approach this PR takes is to introduce default metric types that, if not replaced by a specific user-provided object, are internally constructed using the model.Currently this is only set up for
NUTS. If this seems like a useful general approach to other devs, I'll continue with theHMCandHMCDAsamplers.example
cc @cpfiffer @yebai