
Refactoring HMC API #436

Merged: 46 commits into master from refactor-hmc-api, Jul 3, 2018

Conversation

@xukai92 (Member) commented Apr 16, 2018

Addressing #431

Still work in progress.

@xukai92 (Member, Author) commented Apr 16, 2018

2a4dfb3 solves #434 (comment)

@xukai92 (Member, Author) commented Apr 17, 2018

@yebai I've refactored the core HMC code. Basically I added three functions, _leapfrog, _find_H and _sample_momentum, and made the original corresponding functions call them by wrapping the Turing.jl internals. Can you have a look at the current changes? Let's also discuss the best way to write the unit tests, e.g. do we simply write another HMC implementation that doesn't depend on Turing.jl and test against it, or something else?
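For illustration, here is a minimal sketch of the wrapping pattern described above, not the PR's actual code: the core function is Turing-free and takes a user-supplied gradient function, while the Turing-facing wrapper (not shown) builds that gradient function from vi, spl and model. Here grad_func(θ) is assumed to return the gradient of the log-joint at θ.

function _leapfrog_sketch(θ::AbstractVector, r::AbstractVector, ϵ::Real, grad_func)
    r = r .+ 0.5 .* ϵ .* grad_func(θ)   # half step for the momentum
    θ = θ .+ ϵ .* r                     # full step for the position
    r = r .+ 0.5 .* ϵ .* grad_func(θ)   # half step for the momentum
    return θ, r
end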

@yebai (Member) commented Apr 17, 2018

The code looks good to me. For unit tests, perhaps we can write 2-3 models without constrained parameters (e.g. gdemo, Bayesian logistic regression, and the stochastic volatility model from the NUTS paper), then:

  • implement them in both Turing and plain Julia (using AD to compute gradients for the plain Julia version of the models)
  • run Turing's HMC sampler on both the Turing version and Julia version

This would allow us to check whether there are errors in the compiler, transformation or gradients. After this step is done, we can focus on debugging the HMC sampler itself by running it on the plain Julia versions of the models.
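A hedged sketch of that pairing (model and function names are illustrative, not the PR's test code): the same unconstrained model written once in Turing and once as a plain-Julia log-joint whose gradient comes from ForwardDiff.

using Turing, ForwardDiff

# Turing version: gdemo-like Gaussian with a prior on the mean only.
@model simple_gauss(x) = begin
    m ~ Normal(0, 1)
    for i in 1:length(x)
        x[i] ~ Normal(m, 1)
    end
end

# Plain-Julia version: log-joint (up to an additive constant) and its AD gradient.
lj(m, x) = -0.5 * m^2 + sum(-0.5 .* (x .- m) .^ 2)
grad_lj(m, x) = ForwardDiff.derivative(m -> lj(m, x), m)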

@xukai92 (Member, Author) commented Apr 26, 2018

Just a small update:

I'm working on this PR today. It turns out that rather than only having Turing-free HMC core functions, we should actually have a complete Turing-free HMC step first, and then handle vi, spl and model outside this level of abstraction (at the moment I think it's hard to do that at the level of the whole HMC sampler). I'm working on this now.
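As a rough illustration (names and signature are assumptions, not the PR's actual _hmc_step), a complete Turing-free step could look like the following, reusing the _leapfrog_sketch above; everything model-specific enters only through lj_func and grad_func, so vi, spl and model can be handled by a thin wrapper one level up.

function _hmc_step_sketch(θ, ϵ, n_steps, lj_func, grad_func)
    r0 = randn(length(θ))
    H0 = -lj_func(θ) + 0.5 * sum(abs2, r0)       # initial Hamiltonian
    θ_new, r = copy(θ), copy(r0)
    for _ in 1:n_steps
        θ_new, r = _leapfrog_sketch(θ_new, r, ϵ, grad_func)
    end
    H_new = -lj_func(θ_new) + 0.5 * sum(abs2, r)
    is_accept = log(rand()) < H0 - H_new         # MH accept/reject on the energy difference
    return is_accept ? θ_new : θ, is_accept
end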

@xukai92 (Member, Author) commented Apr 28, 2018

  • Simple Gaussian (with prior on the mean only)
  • Bayesian logistic regression (a plain-Julia log-joint is sketched below)
  • Stochastic volatility
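A hypothetical plain-Julia log-joint for the Bayesian logistic regression model above (illustrative only, not the PR's test code), with a standard normal prior on the weights and a Bernoulli likelihood through the logistic link:

logistic(z) = 1 / (1 + exp(-z))

function lj_blr(β::AbstractVector, X::AbstractMatrix, y::AbstractVector)
    lp = -0.5 * sum(abs2, β)       # standard normal prior on the weights
    η = X * β                      # linear predictor
    for i in eachindex(y)
        p = logistic(η[i])
        lp += y[i] * log(p) + (1 - y[i]) * log(1 - p)   # Bernoulli log-likelihood
    end
    return lp
end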

@xukai92 (Member, Author) commented Apr 30, 2018

@yebai So now we have a simple Gaussian and a simple Bayesian linear regression. I'm going to implement a stochastic volatility model today or tomorrow. But for the existing two models without constrained variables, how should we actually debug them?

@xukai92 (Member, Author) commented Apr 30, 2018

We actually want to implement NUTS in the same way as _hmc_step(), don't we?

@yebai (Member) commented May 8, 2018

> But for the existing two models without constrained variables, how should we actually debug them?

@xukai92 It's fine to first test HMC and NUTS on models without constrained variables, just to verify that dual averaging and the leapfrog integrator are correct. We can change the prior to a truncated Gaussian once the initial tests pass.
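A minimal sketch of that follow-up step, assuming a gdemo-like model (illustrative, not the PR's test code): swap the prior for a truncated Gaussian so that the constrained-variable transformation is exercised as well.

using Turing, Distributions

@model simple_gauss_trunc(x) = begin
    m ~ Truncated(Normal(0, 1), 0, Inf)   # constrained (positive) prior on the mean
    for i in 1:length(x)
        x[i] ~ Normal(m, 1)
    end
end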

@yebai (Member) commented May 8, 2018

> We actually want to implement NUTS in the same way as _hmc_step(), don't we?

Yes, we would like to be able to unit test NUTS using the same set of models as HMC.

@xukai92 (Member, Author) commented May 22, 2018

Agreed work plan

  • implement a Turing.jl-free gdemo
  • implement a Turing.jl-free LDA

The following are based on these two models. "Check" means writing unit tests.

  • check gradient (see the sketch after this list)
  • check leapfrog
  • check dual averaging
    • Didn't use Turing-free DA in the end
  • Turing.jl-free NUTS implementation
  • check NUTS w/o DA
  • check NUTS w/ DA
  • check pre-conditioning adaptation
    • Didn't use Turing-free pre-conditioning adaptation in the end
    • Online update of the pre-conditioner already has its own unit test
  • check initialization
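A hypothetical example of the kind of gradient check referred to above, not the PR's actual test: compare the ForwardDiff gradient of the plain-Julia mean-only log-joint against a hand-derived gradient.

using ForwardDiff, Test

x = [1.5, 2.0, 0.3]                                # observed data
lj(m) = -0.5 * m^2 + sum(-0.5 .* (x .- m) .^ 2)    # log-joint up to a constant

m0 = 0.7
@test ForwardDiff.derivative(lj, m0) ≈ -m0 + sum(x .- m0)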

@xukai92 (Member, Author) commented Jun 21, 2018

Thanks @willtebbutt and @wesselb; I'll address them soon.

@xukai92 (Member, Author) commented Jun 25, 2018

Previous comments are resolved.

@yebai (Member) commented Jun 27, 2018

@willtebbutt could you take a look before I merge this PR? Thanks!

@yebai (Member) commented Jun 27, 2018

> Previous comments are resolved.

Thanks, Kai!

@willtebbutt (Member) commented Jun 27, 2018

Thanks for addressing our concerns @xukai92; just one remaining point. Are the @gen_local_grad_func (and related) macros really necessary? Would closures not suffice, i.e. something like:

function gen_grad_func(vi, spl, model)
    return function (θ::AbstractVector)
        if ADBACKEND == :forward_diff
            vi[spl] = θ
            grad = gradient(vi, model, spl)
        elseif ADBACKEND == :reverse_diff
            grad = gradient_r(θ, vi, model, spl)
        else
            error("An appropriate error.")
        end
        return getlogp(vi), grad
    end
end

It's really a matter of taste, but my impression is that a good general rule is to use metaprogramming only where it's really necessary, essentially because it's tricky to read.

@xukai92 (Member, Author) commented Jun 28, 2018

I've replaced the macros with closures. Thanks for pointing this out; I agree we'd better use closures here.

@xukai92 (Member, Author) commented Jun 28, 2018

@willtebbutt Any other comments?

@willtebbutt (Member) commented Jun 29, 2018

Apologies for the delay. Will review the technical details of the NUTS implementation tomorrow morning and, assuming that everything looks fine, I'm happy for this to be merged. (Wessel and I had to refresh our memories regarding NUTS, hence the delay)

@willtebbutt (Member) left a review:

Overall this looks good. We couldn't find any obvious technical errors in the NUTS implementation, other than a concern about j being capped at 5 (see the comment below for details). There are a few stylistic things that it would be good to have addressed.

- ϵ : leapfrog step size
- H0 : initial H
- lj_func : function for log-joint
- grad_func : function for the gradient of log-joint
Member:

This documentation should be outside of the function. See here for examples: https://docs.julialang.org/en/stable/manual/documentation/

In this particular case, we need something like:

"""
    build_tree(θ::T, r::Vector, logu::Float64, v::Int, j::Int, ϵ::Float64, H0::Float64,
               lj_func::Function, grad_func::Function, stds::Vector) where {T<:Union{Vector,SubArray}}

Recursively build balanced tree.

Ref: Algorithm 6 on http://www.stat.columbia.edu/~gelman/research/published/nuts.pdf

# Arguments:
- θ: an argument
- r: another argument
"""

to be consistent with the standard Julia conventions. If we then call ?_build_tree we will get the appropriate information displayed correctly.

Member Author:

I know the format; that must have been a mistake during copy-and-paste. I'll fix it.

Ref: Algorithm 6 on http://www.stat.columbia.edu/~gelman/research/published/nuts.pdf
"""
function _build_tree(θ::T, r::Vector, logu::Float64, v::Int, j::Int, ϵ::Float64, H0::Float64,
lj_func::Function, grad_func::Function, stds::Vector; Δ_max=1000) where {T<:Union{Vector,SubArray}}
Member:

Is there a particular reason that we're using Union{Vector, SubArray} rather than just AbstractArray here?

Recursively build balanced tree.

Ref: Algorithm 6 on http://www.stat.columbia.edu/~gelman/research/published/nuts.pdf
"""
Member:

Please add documentation for stds and Δ_max


Ref: Algorithm 6 on http://www.stat.columbia.edu/~gelman/research/published/nuts.pdf
"""
function _build_tree(θ::T, r::Vector, logu::Float64, v::Int, j::Int, ϵ::Float64, H0::Float64,
Member:

Lots of arguments' types appear to be overly constrained. Is this necessary? For example, it's not clear to me that logu can't be constrained to be an AbstractFloat as opposed to Float64 -- similarly for ϵ and H0. Conversely, the type of Δ_max is not constrained at all. Should this be constrained to be an AbstractFloat also?

The same comment applies to r: could it be an AbstractVector instead of just a Vector? The same goes for stds.
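For illustration, the relaxed signature could look like the following sketch (not necessarily the exact signature the PR settled on), with the body unchanged:

function _build_tree(θ::AbstractVector, r::AbstractVector, logu::AbstractFloat,
                     v::Int, j::Int, ϵ::AbstractFloat, H0::AbstractFloat,
                     lj_func::Function, grad_func::Function, stds::AbstractVector;
                     Δ_max::AbstractFloat=1000.0)
    # ... body as before ...
end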

Member Author:

Makes sense! Thanks for pointing this out!

end
end

function _nuts_step(θ, ϵ, lj_func, grad_func, stds)
Member:

It would be good to document this function properly, despite the fact that it's not exposed publicly.

θm = θ; θp = θ; rm = r0; rp = r0; j = 0; θ_new = θ; n = 1; s = 1
local da_stat

while s == 1 && j <= 5
Member:

Why is j arbitrarily capped at 5? Surely this could potentially re-introduce the random-walk behaviour that we are trying to avoid?

Member Author:

In practice we have to set such a cap on j, otherwise the sampler can become extremely slow if the step size happens to be very small. Stan uses 10 (maximum 2^10 evaluations) as the default, and 5 (maximum 2^5 evaluations) was originally chosen by Hong because we were a bit slow back then.

Let me make this number an optional argument of the function. We probably need another interface change to support a user-specified maximum j.
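A sketch of that interface change, assuming a keyword argument named j_max (the name actually used in the PR may differ):

function _nuts_step(θ, ϵ, lj_func, grad_func, stds; j_max=5)
    # ... initialisation as before ...
    # while s == 1 && j <= j_max
    #     ...
    # end
end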

v = rand([-1, 1])

if v == -1

Member:

Inconsistent addition of blank lines here. Maybe remove these?

@yebai (Member) commented Jul 3, 2018

@willtebbutt @xukai92 @wesselb many thanks for the hard work!

@yebai merged commit b67e9e1 into master on Jul 3, 2018
@yebai deleted the refactor-hmc-api branch on July 9, 2018 at 12:56
yebai pushed a commit that referenced this pull request Sep 18, 2018
* move gmm overwrite out core source code

* refactor find_H

* refactor sample momentum

* refactor lf step

* hmc step abstraction v1.0 done

* make hmcda using _hmc_step

* add a note

* add bayes lr

* add things to runtests

* add sv turing

* add Turing-free nuts

* bug free Turing-free nuts

* update reference

* add grad check unit test

* add gdemo

* add gdemo nuts

* restructure hmc_core tests

* NUTS works

* Change rand seed for test

* fix adapation condition bug

* add test REQUIRE

* Remove all benchmarks

* change test nuts file name

* rearrange hmc codes

* add unit test for leapfrog

* clean nuts test

* Remove obsolete dependence on deps.jl

* fix typo

* add new lines to the end of files

* rename file

* add new lines to the end of files

* use macros to gen functions with the same pattern

* resolve indentation

* add explict return

* resolve indentation

* more stable mh accept

* Remove unrelated notebook

* Unify the use of mh_accept

* replace macro by closure to gen local funcs

* improve doc