
Allow usage of AbstractSampler #2008

Merged · 22 commits · Jul 5, 2023

Conversation

@torfjelde (Member) commented Jun 13, 2023

This is a very rough idea of how we could go about supporting AbstractMCMC.AbstractSampler to reduce the necessary glue code.

With this, the following works:

julia> using Revise, Turing

julia> @model demo() = x ~ Normal()
demo (generic function with 2 methods)

julia> model = demo()
DynamicPPL.Model{typeof(demo), (), (), (), Tuple{}, Tuple{}, DynamicPPL.DefaultContext}(demo, NamedTuple(), NamedTuple(), DynamicPPL.DefaultContext())

julia> sampler = Turing.Inference.initialize_nuts(model);  # NOTE: This has been moved to the tests for now.

julia> chain = sample(model, externalsampler(sampler), 1000; nadapts=500, discard_initial=500)
Sampling 100%|██████████████████████████████████████████████████████████████████████████████████████| Time: 0:00:00
Chains MCMC chain (1000×14×1 Array{Float64, 3}):

Iterations        = 501:1:1500
Number of chains  = 1
Samples per chain = 1000
Wall duration     = 0.16 seconds
Compute duration  = 0.16 seconds
parameters        = x
internals         = lp, n_steps, is_accept, acceptance_rate, log_density, hamiltonian_energy, hamiltonian_energy_error, max_hamiltonian_energy_error, tree_depth, numerical_error, step_size, nom_step_size, is_adapt

Summary Statistics
  parameters      mean       std      mcse   ess_bulk   ess_tail      rhat   ess_per_sec 
      Symbol   Float64   Float64   Float64    Float64    Float64   Float64       Float64 

           x    0.0692    0.9664    0.0528   339.0005   512.5947    0.9991     2105.5929

Quantiles
  parameters      2.5%     25.0%     50.0%     75.0%     97.5% 
      Symbol   Float64   Float64   Float64   Float64   Float64 

           x   -1.7927   -0.5420    0.0395    0.7099    1.9460

@torfjelde torfjelde marked this pull request as draft June 13, 2023 11:24
@torfjelde (Member, Author):
As demonstrated in #2011, supporting this properly is somewhat non-trivial, so for the moment I've decided to just introduce a simple SamplerWrapper <: InferenceAlgorithm which wraps an AbstractMCMC.AbstractSampler but then inherits everything else that comes with being an InferenceAlgorithm.

The annoying thing (besides being ugly) is that the user has to wrap the AbstractMCMC.AbstractSampler themselves, since otherwise we run into the same ambiguity issues as before. The upside: any sampler using LogDensityFunction is now usable and integrates "nicely" with Turing.
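A minimal self-contained sketch of the wrapper idea described above. The two abstract types below are stand-ins for AbstractMCMC.AbstractSampler and Turing's InferenceAlgorithm so the snippet runs on its own; they are illustrative only, not Turing's actual code.

```julia
# Stand-ins for AbstractMCMC.AbstractSampler and Turing's InferenceAlgorithm,
# so this sketch is self-contained (illustrative, not the real definitions).
abstract type AbstractSampler end
abstract type InferenceAlgorithm end

# Wrapping an external sampler lets it inherit everything that comes with
# being an InferenceAlgorithm (chain construction, sampling entry points, ...).
struct SamplerWrapper{S<:AbstractSampler} <: InferenceAlgorithm
    sampler::S
end

# The user calls this explicitly; wrapping automatically would reintroduce
# the method-ambiguity issues mentioned above.
externalsampler(spl::AbstractSampler) = SamplerWrapper(spl)
```

Because the wrapper is itself an InferenceAlgorithm, all existing dispatch on InferenceAlgorithm applies to it unchanged, which is the whole point of the design.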

@coveralls commented Jun 18, 2023

Pull Request Test Coverage Report for Build 5457548017

  • 0 of 32 (0.0%) changed or added relevant lines in 2 files are covered.
  • 1 unchanged line in 1 file lost coverage.
  • Overall coverage remained the same at 0.0%

Changes missing coverage (covered / changed or added lines, %):
  src/inference/Inference.jl — 0 / 1 (0.0%)
  src/contrib/inference/abstractmcmc.jl — 0 / 31 (0.0%)

Files with coverage reduction (new missed lines, %):
  src/inference/Inference.jl — 1 (0%)

Totals:
  Change from base Build 5455388390: 0.0%
  Covered lines: 0
  Relevant lines: 1458


@codecov bot commented Jun 18, 2023

Codecov Report

Patch and project coverage have no change.

Comparison is base (0dee0f8) 0.00% compared to head (b0503e5) 0.00%.

Additional details and impacted files
@@          Coverage Diff           @@
##           master   #2008   +/-   ##
======================================
  Coverage    0.00%   0.00%           
======================================
  Files          21      22    +1     
  Lines        1426    1458   +32     
======================================
- Misses       1426    1458   +32     
Impacted files (coverage Δ):
  src/Turing.jl — 0.00% <ø> (ø)
  src/contrib/inference/abstractmcmc.jl — 0.00% <0.00%> (ø)
  src/inference/Inference.jl — 0.00% <0.00%> (ø)


@torfjelde torfjelde marked this pull request as ready for review June 20, 2023 15:10
@torfjelde torfjelde changed the title [WIP] Allow usage of AbstractSampler Allow usage of AbstractSampler Jun 20, 2023
@torfjelde (Member, Author):
I believe the tests are sometimes failing simply because of numerical issues.

Some of the test models, e.g. demo_dot_assume_observe_index_literal, have one variance parameter for every observation. This means that the MLE estimate for the variance parameter s is 0. Most of the time this works just fine, i.e. LBFGS finds something close to 0, but sometimes it fails.
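A standalone numerical illustration (plain Julia, no Turing required) of this failure mode: with the mean fixed at the observation, the Gaussian log-density grows without bound as the variance shrinks, so the MLE sits on the boundary s = 0 and optimizers like LBFGS are pushed into numerically unstable territory.

```julia
# Gaussian log-density with mean m and standard deviation s.
logpdf_normal(x, m, s) = -0.5 * log(2π * s^2) - (x - m)^2 / (2 * s^2)

# With one (mean, variance) pair per observation, the likelihood-maximizing
# mean equals the observation, and then the log-density increases
# monotonically as s → 0.
x = 1.3
for s in (1.0, 0.1, 0.01, 0.001)
    println((s = s, logp = logpdf_normal(x, x, s)))
end
```

The printed log-densities increase as s shrinks, which is why "close to 0" is the best a bounded optimizer can do here, and why it occasionally fails.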

@yebai yebai merged commit d89dae3 into master Jul 5, 2023
13 checks passed
@yebai yebai deleted the torfjelde/allow-abstractsampler-draft branch July 5, 2023 08:25
Comment on lines +6 to +37
struct TuringTransition{T,NT<:NamedTuple,F<:AbstractFloat}
    θ::T
    lp::F
    stat::NT
end

function TuringTransition(vi::AbstractVarInfo, t)
    theta = tonamedtuple(vi)
    lp = getlogp(vi)
    return TuringTransition(theta, lp, getstats(t))
end

metadata(t::TuringTransition) = merge((lp = t.lp,), t.stat)
DynamicPPL.getlogp(t::TuringTransition) = t.lp

state_to_turing(f::DynamicPPL.LogDensityFunction, state) = TuringState(state, f)
function transition_to_turing(f::DynamicPPL.LogDensityFunction, transition)
    θ = getparams(transition)
    varinfo = DynamicPPL.unflatten(f.varinfo, θ)
    # TODO: `deepcopy` is overkill; make more efficient.
    varinfo = DynamicPPL.invlink!!(deepcopy(varinfo), f.model)
    return TuringTransition(varinfo, transition)
end

# NOTE: Only thing that depends on the underlying sampler.
# Something similar should be part of AbstractMCMC at some point:
# https://github.com/TuringLang/AbstractMCMC.jl/pull/86
getparams(transition::AdvancedHMC.Transition) = transition.z.θ
getstats(transition::AdvancedHMC.Transition) = transition.stat

getparams(transition::AdvancedMH.Transition) = transition.params
getstats(transition) = NamedTuple()
Member:
@torfjelde Isn't this obsolete now that #2026 was merged?

Member:
I think so. @JaimeRZP, can you do a follow-up PR to unify these Transition types?

Member (Author):
Yes, it is; I was aware of that. Jaime and I wanted to wait until that had been merged before merging this PR, and I was planning to incorporate those changes into this PR first. We're also missing a version bump.

I'd appreciate it if we left the merging of a PR to the person who opened it, unless otherwise explicitly stated, especially now that I'm only days away from being back at full development capacity. This has happened quite a few times now :/

Also, it's not like this PR needs to be merged to be able to develop other functionality; it's easy enough to just depend on the branch directly.

@yebai (Member) commented Jul 5, 2023:
I was under the impression that @JaimeRZP needs this to be released. Sorry for the rush -- hopefully, it didn't break anything! We can always add more changes in a follow-up PR.

Member (Author):
Ah gotcha. @JaimeRZP you can always just do ]add Turing#torfjelde/allow-abstractsampler-draft if you want to try out recent developments. And if you want to develop features based on this branch, just create a branch based on this PR and then continue from there, as you did with #2028 :)
