Refactoring HMC API #436
Conversation
2a4dfb3 solves #434 (comment)
@yebai I've refactored the core HMC code. Basically I added three functions
The code looks good to me. For unit tests, perhaps we can write 2-3 models without constrained parameters (e.g. gdemo, Bayesian logistic regression, and stochastic volatility from the NUTS paper). This would allow us to check whether there are errors in the compiler, transformation, or gradients. After this step is done, we can focus on debugging the HMC sampler itself by running it on the plain Julia version of the models.
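For instance, the plain Julia version of the simplest such model might look like the following (a hypothetical sketch for illustration, not code from this PR; the function names are made up). A hand-derived gradient gives a ground truth to compare the compiler and AD output against:

```julia
# Hypothetical plain-Julia version of a simple Gaussian model:
# m ~ Normal(0, 1), x[i] ~ Normal(m, 1), with θ = [m].
function lj_simple_gaussian(θ, x)
    m = θ[1]
    logprior = -0.5 * m^2 - 0.5 * log(2π)
    loglik = sum(-0.5 .* (x .- m).^2 .- 0.5 * log(2π))
    return logprior + loglik
end

# Hand-derived gradient of the log-joint w.r.t. θ,
# to check compiler/AD gradients against.
function grad_simple_gaussian(θ, x)
    m = θ[1]
    return [-m + sum(x .- m)]
end
```

A finite-difference check of the hand-derived gradient against the log-joint then serves as a quick sanity test, independent of any AD backend.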
Just a small update: I'm working on this PR today. It turns out that rather than only having Turing-free HMC core functions, we should actually have a complete Turing-free HMC step first. Then handling with
@yebai So now we have a simple Gaussian and a simple Bayesian linear regression. I'm going to implement a stochastic volatility model today or tomorrow. But for the existing two without constrained variables, how should we actually debug them?
We actually want to implement NUTS in the same way as
@xukai92 It's fine to first test HMC and NUTS on models without constrained variables, just to verify that the dual averaging and the leapfrog integrator are correct. We can change the prior to a truncated Gaussian once the initial tests pass.
Yes, we would like to be able to unit test NUTS using the same set of models as HMC. |
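As an illustration of what verifying the leapfrog integrator can mean in practice, here is a minimal sketch (assuming a standard Gaussian target; none of these names come from the PR): with a small step size, the Hamiltonian should be approximately conserved along the trajectory.

```julia
# Minimal leapfrog sketch for H(θ, r) = -log p(θ) + r'r/2
# on a standard Gaussian target, where log p(θ) = -θ'θ/2 + const.
grad_logp(θ) = -θ                          # ∇ log p for N(0, I)
H(θ, r) = 0.5 * (θ' * θ) + 0.5 * (r' * r)  # Hamiltonian (up to a constant)

function leapfrog(θ, r, ϵ, n_steps)
    r = r .+ (0.5ϵ) .* grad_logp(θ)        # initial half step for momentum
    for _ in 1:(n_steps - 1)
        θ = θ .+ ϵ .* r                    # full step for position
        r = r .+ ϵ .* grad_logp(θ)         # full step for momentum
    end
    θ = θ .+ ϵ .* r
    r = r .+ (0.5ϵ) .* grad_logp(θ)        # final half step for momentum
    return θ, r
end
```

A unit test can then assert that the energy error after many steps stays small (it should scale like ϵ^2 for the leapfrog scheme).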
Agreed work plan
The following is based on these two models. "Check" means writing unit tests.
Thanks @willtebbutt and @wesselb, will address them soon.
Previous comments are resolved.
@willtebbutt could you take a look before I merge this PR? Thanks!
Thanks, Kai!
Thanks for addressing our concerns @xukai92, just one remaining point, about functions like gen_grad_func:

```julia
function gen_grad_func(vi, spl, model)
    return function (θ::AbstractVector)
        if ADBACKEND == :forward_diff
            vi[spl] = θ
            grad = gradient(vi, model, spl)
        elseif ADBACKEND == :reverse_diff
            grad = gradient_r(θ, vi, model, spl)
        else
            error("An appropriate error.")
        end
        return getlogp(vi), grad
    end
end
```

It's really a matter of taste, but my impression is that a good general rule is to use metaprogramming only where it's really necessary, essentially because it's tricky to read.
I've replaced macros with closures. Thanks for pointing this out; I agree we'd better use closures here.
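The general pattern is roughly the following (a schematic sketch with hypothetical names, not code from this PR): the part that varies per backend becomes an ordinary runtime branch inside a returned closure, instead of code spliced in by a macro.

```julia
# Schematic closure pattern: a higher-order function returns a function
# that captures the backend choice. All names here are illustrative.
neg_half_norm(θ) = -0.5 * sum(abs2, θ)   # example target: log density of N(0, I)

function gen_grad_func_sketch(backend::Symbol)
    return function (θ::AbstractVector)
        if backend == :analytic
            grad = -θ                    # exact gradient of neg_half_norm
        elseif backend == :finite_diff
            h = 1e-6
            grad = similar(θ, Float64)
            for i in eachindex(θ)
                e = zeros(length(θ)); e[i] = h
                grad[i] = (neg_half_norm(θ .+ e) - neg_half_norm(θ .- e)) / (2h)
            end
        else
            error("Unknown backend: $backend")
        end
        return neg_half_norm(θ), grad
    end
end
```

Because the closure is an ordinary function, it is easy to read, test, and dispatch on, which is harder with macro-generated code.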
@willtebbutt Any other comments?
Apologies for the delay. I will review the technical details of the NUTS implementation tomorrow morning and, assuming that everything looks fine, I'm happy for this to be merged. (Wessel and I had to refresh our memories regarding NUTS, hence the delay.)
Overall looks good. We couldn't find any obvious technical errors in the NUTS implementation, other than concerns about the capping of j to 5 (see comment for details). There are a few stylistic things that it would be good to have addressed.
src/samplers/nuts.jl (outdated)
```julia
- ϵ : leapfrog step size
- H0 : initial H
- lj_func : function for log-joint
- grad_func : function for the gradient of log-joint
```
This documentation should be outside of the function. See here for examples: https://docs.julialang.org/en/stable/manual/documentation/
In this particular case, we need something like:

```julia
"""
    build_tree(θ::T, r::Vector, logu::Float64, v::Int, j::Int, ϵ::Float64, H0::Float64,
               lj_func::Function, grad_func::Function, stds::Vector) where {T<:Union{Vector,SubArray}}

Recursively build a balanced tree.

Ref: Algorithm 6 in http://www.stat.columbia.edu/~gelman/research/published/nuts.pdf

# Arguments:
- θ: an argument
- r: another argument
"""
```

to be consistent with the standard Julia conventions. If we then call ?_build_tree we will get the appropriate information displayed correctly.
I know the format; that must be a mistake from copying and pasting. Will fix it.
src/samplers/nuts.jl (outdated)
```julia
Ref: Algorithm 6 on http://www.stat.columbia.edu/~gelman/research/published/nuts.pdf
"""
function _build_tree(θ::T, r::Vector, logu::Float64, v::Int, j::Int, ϵ::Float64, H0::Float64,
                     lj_func::Function, grad_func::Function, stds::Vector; Δ_max=1000) where {T<:Union{Vector,SubArray}}
```
Is there a particular reason that we're using Union{Vector, SubArray} rather than just AbstractArray here?
src/samplers/nuts.jl (outdated)
```julia
Recursively build balanced tree.

Ref: Algorithm 6 on http://www.stat.columbia.edu/~gelman/research/published/nuts.pdf
"""
```
Please add documentation for stds and Δ_max.
src/samplers/nuts.jl (outdated)
```julia
Ref: Algorithm 6 on http://www.stat.columbia.edu/~gelman/research/published/nuts.pdf
"""
function _build_tree(θ::T, r::Vector, logu::Float64, v::Int, j::Int, ϵ::Float64, H0::Float64,
```
Many of the arguments' types appear to be overly constrained. Is this necessary? For example, it's not clear to me that logu can't be constrained to be an AbstractFloat as opposed to Float64; similarly for ϵ and H0. Conversely, the type of Δ_max is not constrained at all. Should this also be constrained to be an AbstractFloat? The same comment applies to r: could it be an AbstractVector instead of just a Vector? Same for stds.
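Concretely, the loosened signature might look like this (a sketch of the suggestion with a stubbed-out body, not the PR's final code):

```julia
# Sketch of the suggested widening: concrete Float64/Vector constraints
# relaxed to AbstractFloat/AbstractVector, and Δ_max given a type as well.
# The body is elided here; only the signature is the point.
function _build_tree_sketch(θ::AbstractVector, r::AbstractVector, logu::AbstractFloat,
                            v::Int, j::Int, ϵ::AbstractFloat, H0::AbstractFloat,
                            lj_func::Function, grad_func::Function, stds::AbstractVector;
                            Δ_max::AbstractFloat=1000.0)
    # ... recursion body as in the PR ...
    return nothing
end
```

With this signature, e.g. Float32 scalars and SubArray views (as produced by view) dispatch without any Union type.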
Makes sense! Thanks for pointing this out!
src/samplers/nuts.jl (outdated)
```julia
    end
end

function _nuts_step(θ, ϵ, lj_func, grad_func, stds)
```
It would be good to document this function properly, despite the fact that it's not exposed publicly.
src/samplers/nuts.jl (outdated)
```julia
θm = θ; θp = θ; rm = r0; rp = r0; j = 0; θ_new = θ; n = 1; s = 1
local da_stat

while s == 1 && j <= 5
```
Why is j arbitrarily capped at 5? Surely this could potentially re-introduce the random-walk behaviour that we are trying to avoid?
In practice we have to set such a condition on j, otherwise the sampler can go super slow if the step size happens to be very small. Stan uses 10 (maximum 2^10 evaluations) as the default, and 5 (maximum 2^5 evaluations) was originally chosen by Hong because we were a bit slow back then.
Let me make this number an optional argument in the function. We probably need another interface change to support a user-specified maximum j.
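Roughly, the doubling loop with a configurable cap would look like this (a sketch only; j_max is a hypothetical keyword name, not the final interface):

```julia
# Sketch: NUTS doubling loop with a user-configurable maximum tree depth.
# Each doubling at depth j costs 2^j leapfrog evaluations, so depth j_max
# bounds the work per NUTS step at roughly 2^(j_max+1) evaluations.
function doubling_loop_sketch(; j_max::Int = 5)
    j = 0
    s = 1          # continue flag; set to 0 by the U-turn/divergence criteria
    n_evals = 0    # total leapfrog evaluations so far
    while s == 1 && j <= j_max
        n_evals += 2^j   # this doubling's leapfrog cost
        j += 1
        # in the real sampler, s is updated from _build_tree's stop criteria here
    end
    return j, n_evals
end
```

Exposing j_max as a keyword keeps the old default (5) while letting users opt into Stan-like behaviour with j_max = 10.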
src/samplers/nuts.jl (outdated)
```julia
v = rand([-1, 1])

if v == -1
```
Inconsistent addition of blank lines here. Maybe remove these?
@willtebbutt @xukai92 @wesselb many thanks for the hard work!
* move gmm overwrite out core source code
* refactor find_H
* refactor sample momentum
* refactor lf step
* hmc step abstraction v1.0 done
* make hmcda using _hmc_step
* add a note
* add bayes lr
* add things to runtests
* add sv turing
* add Turing-free nuts
* bug free Turing-free nuts
* update reference
* add grad check unit test
* add gdemo
* add gdemo nuts
* restructure hmc_core tests
* NUTS works
* Change rand seed for test
* fix adapation condition bug
* add test REQUIRE
* Remove all benchmarks
* change test nuts file name
* rearrange hmc codes
* add unit test for leapfrog
* clean nuts test
* Remove obsolete dependence on deps.jl
* fix typo
* add new lines to the end of files
* rename file
* add new lines to the end of files
* use macros to gen functions with the same pattern
* resolve indentation
* add explict return
* resolve indentation
* more stable mh accept
* Remove unrelated notebook
* Unify the use of mh_accept
* replace macro by closure to gen local funcs
* improve doc
Addressing #431. Still a work in progress.