Unifying trajectories #214

Closed
wants to merge 27 commits into from

@xukai92 xukai92 commented Aug 5, 2020

This PR replaces #213.

This PR solves #103 by explicitly separating Trajectory from HMCKernel,
without introducing any functionality change.
I left some comments on how to unify the notation as discussed in #48,
in which I use τ to refer to a (numerical) Hamiltonian trajectory and κ to refer to an MCMC kernel.
The remaining work will be done in a separate PR to make the review easier.

As a showcase, below is how to construct NUTS now:

τ = Trajectory(Leapfrog(1e-3), NoUTurn())
κ = HMCKernel(τ, MultinomialTS)

The old syntax is still supported, e.g. as below:

NUTS{TS, TC}(int::AbstractIntegrator) where {TS, TC} = HMCKernel(Trajectory(int, TC()), TS)

As suggested by #103 (comment),
fixed simulation steps or length are unified as termination criteria.
A nice benefit of this is that the parameters regarding trajectory simulation are clear:
e.g. FixedNSteps(10), NoUTurn(5, 1_000) or NoUTurn(max_depth=5).
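
As a rough illustration of termination criteria that only carry parameters, here is a minimal sketch. The `Sketch`-prefixed types are hypothetical stand-ins, not the definitions in this PR, and the field names `max_depth` and `Δ_max` and their defaults are assumptions.

```julia
# Hypothetical stand-ins for the termination criteria discussed above.
# They only hold algorithm parameters and carry no simulation state.
abstract type SketchTerminationCriterion end

# Stop after a fixed number of integrator steps.
struct SketchFixedNSteps <: SketchTerminationCriterion
    n_steps::Int
end

# Stop on a U-turn; field names and default values are assumptions.
Base.@kwdef struct SketchNoUTurn <: SketchTerminationCriterion
    max_depth::Int = 10
    Δ_max::Float64 = 1000.0
end

fixed = SketchFixedNSteps(10)
nuts_positional = SketchNoUTurn(5, 1_000)     # positional form
nuts_keyword = SketchNoUTurn(max_depth = 5)   # keyword form, Δ_max keeps its default
```

Both construction forms yield the same parameter set here because the assumed default for `Δ_max` equals the value passed positionally.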

As a result, the main transition function looks like this.

function transition(rng, h, κ::HMCKernel, z)
    @unpack τ, TS = κ
    τ = reconstruct(τ, integrator=jitter(rng, τ.integrator))
    z = refresh(rng, z, h)
    return transition(rng, h, τ, TS, z)
end

As a by-product, this new interface also unifies how integrators are jittered: always before simulating the trajectory, in the same place where the momentum variables are refreshed.
With this design, this is the only place where jitter is called.
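
The order of operations described above (jitter first, then momentum refresh, then simulation) can be sketched in a self-contained toy. All names below (`SketchLeapfrog`, `SketchJitteredLeapfrog`, `leapfrog_step`, `sketch_transition`) are hypothetical and simplified; they are not the types or signatures in this PR, and the acceptance step is omitted.

```julia
using Random

# Hypothetical plain integrator with a fixed step size.
struct SketchLeapfrog
    ϵ::Float64
end

# Hypothetical jittered integrator.
struct SketchJitteredLeapfrog
    ϵ0::Float64      # nominal step size
    jitter::Float64  # relative jitter amplitude, assumed to lie in [0, 1)
    ϵ::Float64       # step size actually used
end

# No-op for a plain integrator; redraw the step size for a jittered one.
jitter(rng, lf::SketchLeapfrog) = lf
function jitter(rng, lf::SketchJitteredLeapfrog)
    ϵ = lf.ϵ0 * (1 + lf.jitter * (2 * rand(rng) - 1))
    return SketchJitteredLeapfrog(lf.ϵ0, lf.jitter, ϵ)
end

# Draw a fresh standard-normal momentum matching the position θ.
refresh(rng, θ) = randn(rng, length(θ))

# One leapfrog step for a potential with gradient ∇U (unit mass matrix).
function leapfrog_step(lf, ∇U, θ, r)
    r = r - lf.ϵ / 2 * ∇U(θ)
    θ = θ + lf.ϵ * r
    r = r - lf.ϵ / 2 * ∇U(θ)
    return θ, r
end

function sketch_transition(rng, ∇U, lf, n_steps, θ)
    lf = jitter(rng, lf)   # the only place jitter is called
    r = refresh(rng, θ)    # momentum is refreshed in the same place
    for _ in 1:n_steps
        θ, r = leapfrog_step(lf, ∇U, θ, r)
    end
    return θ, r            # acceptance step omitted for brevity
end

# Usage on a standard normal target, whose potential has gradient θ.
rng = MersenneTwister(1)
θ_new, r_new = sketch_transition(rng, θ -> θ, SketchJitteredLeapfrog(0.1, 0.2, 0.1), 10, [1.0, -1.0])
```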

The changes are:

  1. The new interface
  2. GeneralisedNoUTurn -> NoUTurn and StrictGeneralisedNoUTurn -> StrictNoUTurn
  3. Store the termination statistics directly in the binary tree
    • Why? Because termination criteria like NoUTurn now play the same role as FixedNSteps. To unify the design, they are only intended to store the algorithm parameters, e.g. number of steps, maximum depth, etc.
  4. All tests are updated accordingly
  5. A regression test is added to make sure there is no functionality change compared to ef6de39
    • It can only run locally for now, using Julia 1.5 on macOS. This is because different Julia versions and hardware generate different chains, and I only pre-generated the reference chains for my own machine.
    • Making this fully supported on CI is left as future work.

xukai92 commented Aug 5, 2020

As mentioned in #213 (comment):

I added regression tests to compare the exact simulated chains and statistics in daade0a for the four variants.

  • Static HMC with MH
  • Static HMC with multinomial
  • NUTS with slice
  • NUTS with multinomial

These should be a good enough indicator that the functionality hasn't changed.

This ensures that this PR doesn't change the functionality at all.

The regression test in this PR is added in 31c6517.

codecov bot commented Aug 5, 2020

Codecov Report

Merging #214 (242dc4e) into master (692e646) will decrease coverage by 1.26%.
The diff coverage is 86.11%.

@@            Coverage Diff             @@
##           master     #214      +/-   ##
==========================================
- Coverage   89.13%   87.86%   -1.27%     
==========================================
  Files          16       16              
  Lines         672      676       +4     
==========================================
- Hits          599      594       -5     
- Misses         73       82       +9     
Impacted Files        Coverage Δ
src/AdvancedHMC.jl    75.00% <44.44%> (-25.00%) ⬇️
src/sampler.jl        84.48% <91.66%> (-4.81%) ⬇️
src/trajectory.jl     95.41% <92.15%> (-0.56%) ⬇️
src/integrator.jl     93.61% <0.00%> (-0.27%) ⬇️

Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Last update e2ce003...3c06f2d.

@yebai yebai left a comment

I reviewed roughly half of this PR tonight. I'll review the rest tomorrow.

@@ -0,0 +1,63 @@
using Test, Random, BSON, AdvancedHMC, Distributions, ForwardDiff

is_ef6de39 = isdefined(AdvancedHMC, :EndPointTS)
Member

A better way of performing this regression test is to check out a pinned version of AHMC and run the tests on CI machines with both the pinned version and the current version. Then we can compare results without worrying about machine-specific details.

Member Author

Yes, I agree. I don't know how to do that though. Can you point me to some example CI code that does this?

Member

BenchmarkTools allows users to write tests, and run regression tests automatically. Our use here is a bit special, but the principles are the same I think.

Member Author

Can you point me to an example? I only found examples that look at regressions in run time.

Member

In my view, no single version of the code should be considered the ground truth. Rather, we can only be confident that features covered by the test suite work correctly. Hence instead of checking out a previous version to test against, it's better to add any missing tests, ensuring they pass on both the old version and the new version.

Member

For performance regressions, you can use PkgBenchmark (see e.g. https://github.com/JuliaFolds/FLoops.jl/blob/master/.github/workflows/benchmark.yml), but I don't think that checks the outputs.

Member

Ideally, we can keep PRs small, i.e. one PR for a change/feature. For big changes like this PR, we should try to do it in multiple steps. It would make reviewing code much easier, which leads to much quicker merging too.

yebai commented Aug 8, 2020

Thanks, @xukai92. I've done one full pass of this PR; it looks overall good to me. I left many comments above.

@yebai yebai left a comment

@xukai92 Pls ping me when this is ready for another review.

xukai92 commented Aug 31, 2020

@yebai This is ready for another look.

Two things I didn't address are:

  1. Replacing rho by TurnStatistics
    • I feel it's better to do this in another PR as it involves changes in many places.
  2. Finding a better way to handle TS
    • I did rename it to trajectory_sampler_type which is more descriptive
    • Ideally, I feel we should be able to avoid storing types in HMCKernel once we introduce TurnStatistics
      • As you suggested, we can parameterize TurnStatistics by the termination criterion, and do the same for the trajectory sampling methods
      • If we do this, we don't need to store turn statistics (for trajectory sampling) inside XXXTS anymore.
      • Then MultinomialTS() (instance) can be used over MultinomialTS (type) during construction, which I think is better.
    • Open for discussion though. What's your opinion?

Also, do you think we should introduce a dev branch to AHMC?

@sethaxen sethaxen left a comment

This is an excellent refactor! I had a few minor questions/suggestions. Some of them are not necessarily on changes in this PR and don't need to be handled here, but they came up while I was reading.

end

const NoUTurn = GeneralisedNoUTurn
Member

What's the purpose of this alias?

Member Author

I want to transition to calling the old ones ClassicNoUTurn and the new ones just NoUTurn and StrictNoUTurn, i.e. by default we refer to the generalised ones unless specified otherwise.

Member

Perhaps it would make sense to go the other way around, then? Define NoUTurn and then make GeneralisedNoUTurn the alias so that users' legacy code works?

Member Author

There was some debate between me and Hong about whether we want to make this transition, and we ended up with this compromise. So with this PR, the recommended way is still the previous version. And we want to use the shorter version privately for a while before making the transition.


Detect U turn for two phase points (`zleft` and `zright`) under given Hamiltonian `h`
using the (original) no-U-turn criterion.

Ref: https://arxiv.org/abs/1111.4246, https://arxiv.org/abs/1701.02434
"""
function isterminated(h::Hamiltonian, t::BinaryTree{<:ClassicNoUTurn})
function isterminated(::ClassicNoUTurn, h::Hamiltonian, t::BinaryTree)
Member

I agree that this should probably be in another PR.

@@ -587,79 +531,67 @@ function isterminated(h::Hamiltonian, t::BinaryTree{<:ClassicNoUTurn})
end

"""
isterminated(h::Hamiltonian, t::BinaryTree{<:GeneralisedNoUTurn})
$(SIGNATURES)

Detect U turn for two phase points (`zleft` and `zright`) under given Hamiltonian `h`
Member

It seems like this docstring is outdated, since zleft and zright appear nowhere in the signature or body of either method.

"""
function isterminated(tc::StrictGeneralisedNoUTurn, h, t, tleft, tright)
# Step 0: original generalised U-turn check
s1 = isterminated(tc, h, t)
Member

Does it make sense to return s1 if terminated, else s2 if terminated, else s3? It seems that if s1 already indicates termination, we potentially do three times the work.

Member Author

That's a good idea! Can you do that in your NUTS PR?

Member

sure!
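
The short-circuit idea from this thread can be sketched as follows. This is a hypothetical illustration with boolean stand-ins: `check_all`, `check_left` and `check_right` are made-up closures, not the `isterminated` methods in this PR, and the real checks may return a richer termination object than a `Bool`.

```julia
# Record which checks actually ran, so the short-circuit is observable.
calls = Symbol[]

# Hypothetical stand-ins for the three checks (s1, s2, s3 in the thread).
check_all() = (push!(calls, :all); true)      # pretend the first check already terminates
check_left() = (push!(calls, :left); false)
check_right() = (push!(calls, :right); false)

# `||` evaluates its right-hand side only if the left-hand side is false,
# so the later checks are skipped once an earlier one signals termination.
terminated = check_all() || check_left() || check_right()
```

In this toy run only the first check is evaluated, which is the saving the suggestion is after.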

) where {T<:BinaryTree{<:StrictGeneralisedNoUTurn}}
rho = tleft.c.rho + tright.zleft.r
function check_left_subtree(h::Hamiltonian, t::T, tleft::T, tright::T) where {T<:BinaryTree}
rho = tleft.rho + tright.zleft.r
Member

It'd be nice if we could avoid this allocation. Performing two extra dots in generalised_uturn_criterion should be cheaper than performing this allocation.

Member Author

I'm happy to make the proposed change if some benchmark shows it improves the performance.

Member

Here's a benchmark of one with the same time complexity:

julia> using LinearAlgebra, BenchmarkTools, Plots

julia> function foo(a, b, c, d)
           e = a + b
           f = c + d
           dot(e, f)
       end;

julia> function foo2(a, b, c, d)
           dot(a, c) + dot(a, d) + dot(b, c) + dot(b, d)
       end;

julia> ns = 10 .^ (0:6);

julia> times = map(ns) do n
           a, b, c, d = ntuple(_ -> randn(n), 4)
           (@belapsed($foo($a,$b,$c,$d)), @belapsed($foo2($a,$b,$c,$d)))
       end;

julia> plot(ns, [first.(times) last.(times)]; xscale=:log10, yscale=:log10, labels=["allocating" "non-allocating"])

[Plot: run time of the allocating and non-allocating versions against n, on log-log axes]

I can include it in a more general allocation-reducing PR though.
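
For reference, the two benchmarked functions compute the same quantity, since the dot product is bilinear: dot(a + b, c + d) = dot(a, c) + dot(a, d) + dot(b, c) + dot(b, d). A small check with hand-picked inputs (the function names here are illustrative renamings of `foo` and `foo2`):

```julia
using LinearAlgebra

# Allocates two temporary vectors before taking a single dot product.
allocating(a, b, c, d) = dot(a + b, c + d)

# Expands the product by bilinearity: four dot products, no temporaries.
non_allocating(a, b, c, d) = dot(a, c) + dot(a, d) + dot(b, c) + dot(b, d)

a = [1.0, 2.0, 3.0]
b = [0.5, -1.0, 2.0]
c = [2.0, 0.0, 1.0]
d = [-1.0, 4.0, 0.5]

allocating(a, b, c, d)       # 13.0
non_allocating(a, b, c, d)   # 13.0
```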

) where {T<:BinaryTree{<:StrictGeneralisedNoUTurn}}
rho = tleft.zright.r + tright.c.rho
function check_right_subtree(h::Hamiltonian, t::T, tleft::T, tright::T) where {T<:BinaryTree}
rho = tleft.zright.r + tright.rho
Member

same point here as above.

xukai92 and others added 2 commits October 27, 2020 12:08
Co-authored-by: Seth Axen <seth.axen@gmail.com>
Co-authored-by: Seth Axen <seth.axen@gmail.com>
@yebai yebai left a comment

Hi @xukai92, I've done another careful pass of this PR and left some comments. There are now some merge conflicts against the master branch. Maybe consider fixing that too.

Overall, I find reviewing this PR a bit painful ; ) I hope we can keep PRs small and incremental in the future given that the stakes of introducing bugs are high. For a concrete example, the refactoring of Trajectories, TerminationCriteria, and the introduction of HMCKernel each would deserve its own PR...


struct HMC{TS} end
HMC{TS}(int::AbstractIntegrator, L) where {TS} = HMCKernel(Trajectory(int, FixedNSteps(L)), TS)
HMC(int::AbstractIntegrator, L) = HMC{MetropolisTS}(int, L)
Member

Maybe consider the following for clarity and performance?

Suggested change
HMC(int::AbstractIntegrator, L) = HMC{MetropolisTS}(int, L)
HMC(int::AbstractIntegrator, L) = HMCKernel(Trajectory(int, FixedNSteps(L)), MetropolisTS)

struct HMC{TS} end
HMC{TS}(int::AbstractIntegrator, L) where {TS} = HMCKernel(Trajectory(int, FixedNSteps(L)), TS)
HMC(int::AbstractIntegrator, L) = HMC{MetropolisTS}(int, L)
HMC(ϵ::AbstractScalarOrVec{<:Real}, L) = HMC{MetropolisTS}(Leapfrog(ϵ), L)
Member

Suggested change
HMC(ϵ::AbstractScalarOrVec{<:Real}, L) = HMC{MetropolisTS}(Leapfrog(ϵ), L)
HMC(ϵ::AbstractScalarOrVec{<:Real}, L) = HMCKernel(Trajectory(Leapfrog(ϵ), FixedNSteps(L)), MetropolisTS)


struct StaticTrajectory{TS} end
@deprecate StaticTrajectory{TS}(args...) where {TS} HMC{TS}(args...)
@deprecate StaticTrajectory(args...) HMC(args...)
Member

Similar here, consider calling HMCKernel for clarity.

@@ -30,12 +30,48 @@ stat(t::Transition) = t.stat
"""
Abstract Markov chain Monte Carlo proposal.
"""
abstract type AbstractProposal end
abstract type AbstractKernel end
Member

Maybe consider the following for clarity?

Suggested change
abstract type AbstractKernel end
abstract type AbstractMCMCKernel end

H′ = energy(z′)
ΔH = H′ - H0
α′ = exp(min(0, -ΔH))
sampler′ = S(sampler, H0, z′)
return BinaryTree(z′, z′, C(z′), α′, 1, ΔH), sampler′, Termination(sampler′, nt, H0, H′)
sampler′ = TS(sampler, H0, z′)
Member

Consider using reconstruct here for consistency and clarity.

h::Hamiltonian,
τ::Trajectory{I, C},
::Type{TS},
Member

Ideally, we should avoid passing this ::Type{TS} around.

yebai commented Jan 6, 2021

PS. I now feel the introduction of HMCKernel may not be necessary and is potentially confusing. It is no longer necessary if we remove the TS field. It is inaccurate because, in HMC variants based on dynamic trajectories, the chosen trajectory sampler and the trajectory termination criterion (e.g. no-U-turn) are often coupled. That is, the step of simulating a trajectory and the step of picking a candidate phase point are coupled. We often iterate these two steps to get a final proposal. It is not easy to disentangle these two steps without introducing a significant performance regression and more memory usage.

xukai92 commented Jan 6, 2021

> Hi @xukai92, I've done another careful pass of this PR and left some comments. There are now some merge conflicts against the master branch. Maybe consider fixing that too.

Thanks for taking another pass. I will address the comments soon.

> Overall, I find reviewing this PR a bit painful ; ) I hope we can keep PRs small and incremental in the future given that the stakes of introducing bugs are high. For a concrete example, the refactoring of Trajectories, TerminationCriteria, and the introduction of HMCKernel each would deserve its own PR...

Sorry about this. I'm actually leaning towards separating it. I will work on extracting the first PR after incorporating your suggestions.

> PS. I now feel the introduction of HMCKernel may not be necessary and potentially confusing. It is no longer necessary if we remove the TS field.

I completely agree that if we remove the TS field, we can probably merge HMCKernel and Trajectory. But just as a reminder, keeping HMCKernel has a few potential benefits in the future:

  1. Compatibility: I don't think we will be happy if we were to compose an HMC Trajectory with an MH kernel.
  2. Modularity of momentum behaviour: partial momentum refreshment and coupling can be made modular, and I believe it's better to consider this step as part of the MCMC kernel.
  3. This last point is questionable: I kind of feel we should try to make the interface work on instances rather than types, if possible. For example, NUTS{MultinomialTS, GeneralisedNoUTurn}(Leapfrog(1e-3)) -> HMCKernel(Trajectory(Leapfrog(1e-3), NoUTurn()), MultinomialTS()). But anyway, these are just my very initial thoughts on this.
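
To make the instance-versus-type distinction in the last point concrete, here is a minimal sketch. Every name in it (`SketchTrajectorySampler`, `SketchMultinomial`, `KernelByType`, `KernelByInstance`, `sampler_type`) is hypothetical and not part of AdvancedHMC; it assumes the sampler is stateless, which is only the case once the sampling state is stored elsewhere.

```julia
# Hypothetical, stateless trajectory sampler.
abstract type SketchTrajectorySampler end
struct SketchMultinomial <: SketchTrajectorySampler end

# Type-based interface: the kernel is constructed from the sampler *type*.
struct KernelByType{TS<:SketchTrajectorySampler} end
KernelByType(::Type{TS}) where {TS<:SketchTrajectorySampler} = KernelByType{TS}()

# Instance-based interface: the kernel stores a sampler *instance*.
struct KernelByInstance{TS<:SketchTrajectorySampler}
    sampler::TS
end

# Both forms expose the same information for dispatch.
sampler_type(::KernelByType{TS}) where {TS} = TS
sampler_type(κ::KernelByInstance) = typeof(κ.sampler)

κ_type = KernelByType(SketchMultinomial)
κ_instance = KernelByInstance(SketchMultinomial())
```

Either form lets downstream methods dispatch on the sampler; the instance form additionally leaves room for samplers that carry their own parameters.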

But maybe I could just introduce it later when introducing the modularity of momentum, as we are agreeing on making separate PRs now.

> It is inaccurate because, in HMC variants based on dynamic trajectories, the chosen trajectory sampler and the trajectory termination criterion (e.g. no-U-turn) are often coupled. That is, the step of simulating a trajectory and the step of picking a candidate phase point are coupled. We often iterate these two steps to get a final proposal. It is not easy to disentangle these two steps without introducing a significant performance regression and more memory usage.

Yes, I see that these two steps are coupled for dynamic trajectories; that is exactly the reason we are dispatching on the termination criterion and the trajectory sampler together, right? I'm not sure what you mean by "inaccurate" here. I still think they (the trajectory and its sampler) are conceptually two separate things; it's just that:

  • For static trajectories, even the implementation can be (but not necessarily) separated.
  • For dynamic trajectories, it turns out to be the case that we need to dispatch on them together as the implementation is coupled.

xukai92 and others added 2 commits January 8, 2021 14:00
Co-authored-by: Hong Ge <hg344@cam.ac.uk>
@xukai92 xukai92 closed this Jan 12, 2021
@yebai yebai deleted the kx/unify-trajectory-1 branch March 11, 2022 14:18