/
abstractmcmc.jl
71 lines (58 loc) · 2.52 KB
/
abstractmcmc.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
# State threaded between `AbstractMCMC.step` calls when Turing drives an
# external sampler.
# - `state`: the external sampler's own state, passed back to its next `step`.
# - `logdensity`: the log-density function the sampler was stepped with; it is
#   reused unchanged on the following step.
struct TuringState{TState,TLogDensity}
    state::TState
    logdensity::TLogDensity
end
# Bundle the external sampler's state with the log-density function it was
# stepped with, so the next `step` call can reuse that same function.
state_to_turing(f::DynamicPPL.LogDensityFunction, state) = TuringState(state, f)
# Keep the gradient wrapper itself (not its parent) so later steps still have
# access to the gradient implementation.
state_to_turing(f::LogDensityProblemsAD.ADGradientWrapper, state) = TuringState(state, f)

# Convert an external-sampler transition into a Turing `Transition`.
# The parameters returned by `getparams` are written back into `f`'s varinfo via
# `DynamicPPL.unflatten`, and the result is invlinked with `f.model` before wrapping.
function transition_to_turing(f::DynamicPPL.LogDensityFunction, transition)
    θ = getparams(transition)
    varinfo = DynamicPPL.unflatten(f.varinfo, θ)
    # TODO: `deepcopy` is overkill; make more efficient.
    varinfo = DynamicPPL.invlink!!(deepcopy(varinfo), f.model)
    return Transition(varinfo, transition)
end
# The varinfo and model live in the wrapped `LogDensityFunction`; unwrap and forward.
function transition_to_turing(f::LogDensityProblemsAD.ADGradientWrapper, transition)
    return transition_to_turing(parent(f), transition)
end

# NOTE: Only thing that depends on the underlying sampler.
# Something similar should be part of AbstractMCMC at some point:
# https://github.com/TuringLang/AbstractMCMC.jl/pull/86
getparams(transition::AdvancedHMC.Transition) = transition.z.θ
getstats(transition::AdvancedHMC.Transition) = transition.stat
getparams(transition::AdvancedMH.Transition) = transition.params

# Read the varinfo, looking through an AD gradient wrapper if present.
getvarinfo(f::DynamicPPL.LogDensityFunction) = f.varinfo
getvarinfo(f::LogDensityProblemsAD.ADGradientWrapper) = getvarinfo(parent(f))

# Return a copy of `f` whose varinfo is replaced by `varinfo`.
setvarinfo(f::DynamicPPL.LogDensityFunction, varinfo) = Setfield.@set f.varinfo = varinfo
# Replace the varinfo of the wrapped function and re-wrap it.
# Returning the bare inner function (as `setvarinfo(parent(f), varinfo)` alone
# would) silently drops the gradient implementation that `f` provided.
# NOTE(review): re-wrapping uses the single-argument `ADgradient` (the same call
# used when the wrapper is first built in the initial `step`); a wrapper created
# with an explicit AD backend elsewhere would be rebuilt with that default.
function setvarinfo(f::LogDensityProblemsAD.ADGradientWrapper, varinfo)
    return LogDensityProblemsAD.ADgradient(setvarinfo(parent(f), varinfo))
end
# TODO: Do we also support `resume`, etc?
function AbstractMCMC.step(
    rng::Random.AbstractRNG,
    model::DynamicPPL.Model,
    sampler_wrapper::Sampler{<:ExternalSampler};
    kwargs...
)
    # Initial step for an external sampler: build a log-density function from
    # `model`, link its varinfo, then let the external sampler take its first step.
    inner_sampler = sampler_wrapper.alg.sampler

    # Wrap the log density with a gradient implementation so the external
    # sampler uses the same AD backend as Turing.
    ℓ = LogDensityProblemsAD.ADgradient(DynamicPPL.LogDensityFunction(model))

    # Swap in the linked varinfo.
    # NOTE(review): whether `ℓ` is still a gradient wrapper after this depends on
    # the `setvarinfo` method for `ADGradientWrapper`.
    linked = DynamicPPL.link!!(getvarinfo(ℓ), model)
    ℓ = setvarinfo(ℓ, linked)

    # Delegate to the external sampler's own `AbstractMCMC.step` (AdvancedHMC,
    # AdvancedMH, ...).
    inner_transition, inner_state = AbstractMCMC.step(
        rng, AbstractMCMC.LogDensityModel(ℓ), inner_sampler; kwargs...
    )

    # Map the sampler's results back to Turing's transition and state types.
    turing_transition = transition_to_turing(ℓ, inner_transition)
    turing_state = state_to_turing(ℓ, inner_state)
    return turing_transition, turing_state
end
function AbstractMCMC.step(
    rng::Random.AbstractRNG,
    model::DynamicPPL.Model,
    sampler_wrapper::Sampler{<:ExternalSampler},
    state::TuringState;
    kwargs...
)
    # Subsequent step for an external sampler: reuse the log-density function
    # stored in `state` (the `model` argument is not consulted here) and advance
    # the external sampler from its previous state.
    inner_sampler = sampler_wrapper.alg.sampler
    ℓ = state.logdensity
    previous_inner_state = state.state

    # Delegate to the external sampler's own `AbstractMCMC.step`.
    inner_transition, inner_state = AbstractMCMC.step(
        rng, AbstractMCMC.LogDensityModel(ℓ), inner_sampler, previous_inner_state;
        kwargs...
    )

    # Map the sampler's results back to Turing's transition and state types.
    turing_transition = transition_to_turing(ℓ, inner_transition)
    turing_state = state_to_turing(ℓ, inner_state)
    return turing_transition, turing_state
end