-
Notifications
You must be signed in to change notification settings - Fork 17
/
stepper.jl
81 lines (67 loc) · 2.19 KB
/
stepper.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
"""
    Stepper{A,M,S,K}

Iterator object created by `steps`: it bundles the random number generator, the model,
the sampler and the keyword arguments that are forwarded to every `step` call made while
iterating.

# Fields
- `rng::A`: random number generator passed to each `step` call.
- `model::M`: model passed to each `step` call.
- `sampler::S`: sampler passed to each `step` call.
- `kwargs::K`: keyword arguments splatted into every `step` call; iteration additionally
  reads the `discard_initial` and `thinning` entries from them.
"""
struct Stepper{A<:Random.AbstractRNG,M<:AbstractModel,S<:AbstractSampler,K}
    rng::A
    model::M
    sampler::S
    kwargs::K
end
# Initial sample: start the chain, then run `discard_initial` extra transitions.
function Base.iterate(stp::Stepper)
    kw = stp.kwargs
    # Read (and type-check) the number of discarded transitions before stepping at all.
    nburn = get(kw, :discard_initial, 0)::Int
    rng, model, sampler = stp.rng, stp.model, stp.sampler

    # The very first transition has no previous state. Each further transition replaces
    # the (sample, state) pair, so only the last one is handed back to the caller.
    transition, st = step(rng, model, sampler; kw...)
    remaining = nburn
    while remaining > 0
        transition, st = step(rng, model, sampler, st; kw...)
        remaining -= 1
    end
    return transition, st
end
# Subsequent samples: advance `thinning - 1` transitions silently, then return the next one.
function Base.iterate(stp::Stepper, state)
    kw = stp.kwargs
    thin = get(kw, :thinning, 1)::Int
    rng, model, sampler = stp.rng, stp.model, stp.sampler

    # Transitions 2 to `thin` only move the chain forward; their samples are dropped.
    # When `thin <= 1` the range is empty and no extra transition is computed.
    st = state
    for _ in 2:thin
        _, st = step(rng, model, sampler, st; kw...)
    end
    # `step` already yields the `(sample, state)` pair that `iterate` must return.
    return step(rng, model, sampler, st; kw...)
end
# A `Stepper` never returns `nothing` from `iterate`, so it is an infinite iterator.
function Base.IteratorSize(::Type{<:Stepper})
    return Base.IsInfinite()
end

# Elements are whatever `step` returns as a sample, so no element type is declared.
function Base.IteratorEltype(::Type{<:Stepper})
    return Base.EltypeUnknown()
end
# Convenience method: use `Random.default_rng()` when the caller supplies no RNG.
steps(model_or_logdensity, sampler::AbstractSampler; kwargs...) =
    steps(Random.default_rng(), model_or_logdensity, sampler; kwargs...)
"""
    steps(
        rng::Random.AbstractRNG=Random.default_rng(),
        model::AbstractModel,
        sampler::AbstractSampler;
        kwargs...,
    )

Return an infinite iterator over samples drawn from `model` with the Markov chain Monte
Carlo `sampler`.

Every sample is produced by a call to `AbstractMCMC.step`, and all `kwargs` are forwarded
to each of those calls. The iterator itself also reads two keyword arguments:

- `discard_initial` (an `Int`, default `0`): the first returned sample is the result of
  `1 + discard_initial` transitions; the earlier ones are dropped.
- `thinning` (an `Int`, default `1`): each subsequent returned sample is preceded by
  `thinning - 1` dropped transitions.

Because the iterator never terminates, combine it with e.g. `Iterators.take` to obtain a
finite number of samples.

# Examples

```jldoctest; setup=:(using AbstractMCMC: steps)
julia> struct MyModel <: AbstractMCMC.AbstractModel end

julia> struct MySampler <: AbstractMCMC.AbstractSampler end

julia> function AbstractMCMC.step(rng, ::MyModel, ::MySampler, state=nothing; kwargs...)
           # all samples are zero
           return 0.0, state
       end

julia> iterator = steps(MyModel(), MySampler());

julia> collect(Iterators.take(iterator, 10)) == zeros(10)
true
```
"""
function steps(
    rng::Random.AbstractRNG,
    model::AbstractModel,
    sampler::AbstractSampler;
    kwargs...,
)
    # `kwargs` is stored unsplatted; iteration splats it into every `step` call.
    return Stepper(rng, model, sampler, kwargs)
end