Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
113 commits
Select commit Hold shift + click to select a range
4e5f1ad
Import of existing Turing code.
adscib Apr 29, 2016
87a08d3
Re-organisation.
Apr 29, 2016
51857e5
More efficient implementation of ParticleContainer.
May 10, 2016
97a1886
Improved efficiency of `io.jl`.
May 10, 2016
9598d72
Added some comments.
May 11, 2016
5c04c67
Revert all Int64 to Int.
May 16, 2016
9f924e6
Merge pull request #10 from yebai/appveyor_test
yebai May 16, 2016
a551c22
Merged with master.
May 16, 2016
8dc5640
Merge branch 'master' into benchmarks
May 16, 2016
672cd4d
Merge pull request #8 from yebai/benchmarks
yebai May 22, 2016
fea283d
Add a switch to SMC to enable replay coroutines.
Oct 24, 2016
6fb9230
Fix a bug in container
xukai92 Nov 9, 2016
fd53695
Moved some untested functions to a seprate branch.
Nov 10, 2016
434661f
Merge pull request #36 from yebai/test-coverage
xukai92 Nov 10, 2016
082d503
Fix bug in TArray.
Nov 10, 2016
9e96a94
Merge pull request #38 from yebai/master-julia-0.5
xukai92 Nov 11, 2016
f3c5d88
Merge branch 'master' into development
Nov 14, 2016
1f2ebf2
Merge branch 'development'
Nov 14, 2016
ec137fd
Fix Gibbs sampler bug
xukai92 Jan 17, 2017
62ac590
Merge pull request #73 from yebai/new-interface
yebai Feb 3, 2017
0163d26
Unify Mamba.Chain and Turing.Chain into one type.
Mar 24, 2017
a2c15f0
Merge pull request #116 from yebai/unify-chain
xukai92 Mar 24, 2017
be46b72
Fixed PG.
Apr 2, 2017
4725d9c
Remove Task.storage and use flatten names internally
xukai92 Apr 2, 2017
3416493
Merge pull request #130 from yebai/unify-rand-interface
yebai Apr 3, 2017
17751a0
Merge branch 'master' into remove-sample-macro
Apr 3, 2017
d8eafad
Merge pull request #134 from yebai/remove-sample-macro
xukai92 Apr 3, 2017
928c73f
Add statistics
xukai92 Apr 13, 2017
4a9e06d
Make accumulating of log weight more roboust (#218)
xukai92 Apr 30, 2017
f2fdefd
Merge pull request #220 from yebai/refactor-vi
yebai May 3, 2017
e891e91
Commented out an obsolete function.
May 9, 2017
90c6497
Make the use of ForwardDiff package clear
xukai92 Jul 5, 2017
1057873
Work around #171
xukai92 Jul 24, 2017
c2a828c
Merge pull request #316 from yebai/fix-extension
yebai Jul 24, 2017
793593a
avoid useless computation since weights are normalised
emilemathieu Oct 15, 2017
ca8b5a9
Merge pull request #355 from emilemathieu/issue-350
yebai Oct 17, 2017
b15650e
avoid useless computation since weights are normalised
emilemathieu Oct 15, 2017
99096f1
Merge branch 'master' into feature-334
yebai Oct 19, 2017
4f39231
Merge pull request #363 from yebai/feature-334
yebai Oct 19, 2017
94938ba
Solve compatibility with Julia 0.6 (#341, #330, #293)
yebai Nov 16, 2017
522f70b
Bugfix for Particle Gibbs (#384)
emilemathieu Nov 20, 2017
7947ba3
only erase extra randomness ofr reference particle
emilemathieu Nov 27, 2017
eb6aadd
only use forkr for when reference particle is created, otherwise alwa…
emilemathieu Nov 27, 2017
1406479
merge TraceC and TraceR
emilemathieu Nov 28, 2017
36b3fca
remove useless function and add signature
emilemathieu Nov 28, 2017
3931800
temp fixes to load Turing
xukai92 Aug 11, 2018
96b2a1c
Fix more syntax error.
yebai Aug 11, 2018
fe47322
fix comment syntax
xukai92 Aug 11, 2018
0b55421
Further changes in favor of #462
trappmartin Aug 14, 2018
94cd7e9
Fix more mutable types errors.
yebai Aug 14, 2018
933472e
Fix merge conflicts.
yebai Aug 14, 2018
1186560
Fix merge conflict.
yebai Aug 14, 2018
fa42d69
Revert some changes related to Missing.
yebai Aug 14, 2018
b51874a
Fix incorrect `take!` call.
yebai Aug 20, 2018
f1cba78
added array initializers to src
trappmartin Aug 24, 2018
3726560
further fixes to get tests pass, issue #469
trappmartin Aug 31, 2018
ffa0636
Fix an initialization bug in particle container.
yebai Aug 31, 2018
f799dad
Fixes tests #490
yebai Sep 5, 2018
656c021
Make structs parametric and type stable and take constructors out
mohdibntarek Sep 7, 2018
a5be428
Merge pull request #497 from TuringLang/hg/libtask
hessammehr Sep 11, 2018
c3e0daf
Tidy up AD interface + remove cleandual! and realpart!. (#515)
willtebbutt Sep 14, 2018
885923f
Some style improvements (#598)
willtebbutt Nov 17, 2018
7a0062a
Compiler refactor 2.0 (#613)
mohdibntarek Dec 28, 2018
a56fd0d
Reorganization - no functional change (#649)
mohdibntarek Jan 19, 2019
319afc4
Upgrade Libtask, use CTask instead of Task to create task (#685)
KDr2 Feb 17, 2019
b4200d5
Remove `Nothing` from the Sampler union type (#706)
KDr2 Mar 8, 2019
0618adb
Test refactoring (#731)
cpfiffer Mar 29, 2019
4493678
Merge PMCMC samplers into one file - no functionality change. (#773)
yebai May 6, 2019
7ff0e90
TS 2: introduce TypedVarInfo and fix spl.info[:cache_updated] (#742)
mohdibntarek May 13, 2019
aa9ab05
make AllUtils.jl inclusion copy-paste-friendly (#804)
mohdibntarek Jun 3, 2019
0b7f5be
TS 3: Hook sample to TypedVarInfo (#803)
mohdibntarek Jun 8, 2019
e55d3d8
Fix #665, #802 and #829 (#826)
mohdibntarek Jul 8, 2019
ca09b4e
Interface Changeover (#793)
cpfiffer Sep 12, 2019
d36d387
Add bound on ZygoteRules because it's not done by Zygote (#946)
andreasnoack Nov 6, 2019
505e1c1
Remove redundant computations and unneeded allocations (#959)
devmotion Nov 18, 2019
4c11223
Reduce allocations (#971)
devmotion Nov 18, 2019
e0f06da
Update checks for nothing (#964)
devmotion Nov 18, 2019
9da909c
Simplify ParticleContainer (#966)
devmotion Nov 20, 2019
c545d0a
Remove num_particles and pre-allocate children (#979)
devmotion Nov 21, 2019
8c18652
Simplify systematic resampling
devmotion Nov 22, 2019
666337d
Simplify stratified resampling (#983)
devmotion Nov 22, 2019
e906cd9
Simplify the acceptance check in MH and update the documentation (#992)
devmotion Nov 24, 2019
7e57db6
Add fallbacks for `alg_str` and `transition_type` (#995)
devmotion Nov 24, 2019
ca4bae1
Create arrays of particles instead of using `push!` in SMC and PG (#988)
devmotion Nov 25, 2019
b151f11
Compiler 3.0 (#965)
mohdibntarek Dec 9, 2019
00119da
Add explicit resampler with ESS threshold (#990)
devmotion Dec 17, 2019
9b4a67e
Fix docstrings and remove static type parameters (#1037)
devmotion Dec 24, 2019
fbb35dc
prob and logprob macros for simple model queries (#997)
mohdibntarek Dec 25, 2019
ea59d23
Remove static type parameters (#1044)
devmotion Dec 29, 2019
f018000
Use AdvancedMH (#1083)
cpfiffer Feb 7, 2020
0e677e0
Use reset_num_produce! (#1117)
devmotion Feb 18, 2020
eeb6c2b
Update to new AbstractMCMC API (#1116)
devmotion Feb 24, 2020
f4a3236
Zygote AD backend (#783)
mohdibntarek Mar 15, 2020
e9fa8e7
Move default model evaluation code to DynamicPPL (#1151)
phipsgabler Mar 24, 2020
e12b9af
Add concretize before generating Chains
cpfiffer Mar 25, 2020
b477acb
Improved particle filter error message. (#900)
yebai Apr 15, 2020
5c254fc
Update to DynamicPPL 0.6
devmotion Apr 18, 2020
65d92c1
Merge pull request #1224 from devmotion/update
cpfiffer Apr 18, 2020
785381b
Add constructors and documentation for `SMC` (#1228)
devmotion Apr 22, 2020
ef6d0cc
Use `propagate!` instead of `Libtask.consume` (#1237)
devmotion Apr 27, 2020
03b2c8f
Simplify ParticleContainer and add better tests for propagate! (#1242)
devmotion Apr 28, 2020
3fdaadc
Make Turing compatible with Libtask 0.4 (#1248)
devmotion May 3, 2020
40afdd4
Make Turing compatible with DynamicPPL 0.7.1 (#1249)
devmotion May 4, 2020
95f2ab0
Allow sampling from the prior (#1243)
devmotion May 4, 2020
62c6e46
Fix log evidence computation (#1266)
devmotion May 6, 2020
12e187a
Use non-global RNGs
devmotion May 17, 2020
1848f32
Make GibbsComponent a simple trait (fix #1306) (#1307)
phipsgabler Jun 6, 2020
b5176d1
Fixes for MCMCChains 4 (#1324)
devmotion Jun 23, 2020
ffab788
residual sampling weird weighting (#1345)
francescoalemanno Jul 2, 2020
1eb1645
Remove DEBUG (#1342)
devmotion Aug 12, 2020
bfb4f3e
Fix random test errors (#1469)
devmotion Nov 24, 2020
16fae46
Update to AbstractMCMC 2 (#1428)
devmotion Nov 27, 2020
2c978d0
Merge branch 'master' of github.com:TuringLang/AdvancedPS.jl
devmotion Nov 30, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
326 changes: 326 additions & 0 deletions Turing/src/contrib/inference/AdvancedSMCExtensions.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,326 @@

####
#### Particle marginal Metropolis-Hastings sampler.
####

"""
PMMH(n_iters::Int, smc_alg:::SMC, parameters_algs::Tuple{MH})

Particle independant Metropolis–Hastings and
Particle marginal Metropolis–Hastings samplers.

Note that this method is particle-based, and arrays of variables
must be stored in a [`TArray`](@ref) object.

Usage:

```julia
alg = PMMH(100, SMC(20, :v1), MH(1,:v2))
alg = PMMH(100, SMC(20, :v1), MH(1,(:v2, (x) -> Normal(x, 1))))
```

Arguments:

- `n_iters::Int` : Number of iterations to run.
- `smc_alg:::SMC` : An [`SMC`](@ref) algorithm to use.
- `parameters_algs::Tuple{MH}` : An [`MH`](@ref) algorithm, which includes a
sample space specification.
"""
mutable struct PMMH{space, A<:Tuple} <: InferenceAlgorithm
n_iters::Int # number of iterations
algs::A # Proposals for state & parameters
end
function PMMH(n_iters::Int, algs::Tuple, space::Tuple)
return PMMH{space, typeof(algs)}(n_iters, algs)
end
function PMMH(n_iters::Int, smc_alg::SMC, parameter_algs...)
return PMMH(n_iters, tuple(parameter_algs..., smc_alg), ())
end

PIMH(n_iters::Int, smc_alg::SMC) = PMMH(n_iters, tuple(smc_alg), ())

function Sampler(alg::PMMH, model::Model, s::Selector)
info = Dict{Symbol, Any}()
spl = Sampler(alg, info, s)

alg_str = "PMMH"
n_samplers = length(alg.algs)
samplers = Array{Sampler}(undef, n_samplers)

space = Set{Symbol}()

for i in 1:n_samplers
sub_alg = alg.algs[i]
if isa(sub_alg, Union{SMC, MH})
samplers[i] = Sampler(sub_alg, model, Selector(Symbol(typeof(sub_alg))))
else
error("[$alg_str] unsupport base sampling algorithm $alg")
end
if typeof(sub_alg) == MH && sub_alg.n_iters != 1
warn(
"[$alg_str] number of iterations greater than 1" *
"is useless for MH since it is only used for its proposal"
)
end
space = union(space, sub_alg.space)
end

info[:old_likelihood_estimate] = -Inf # Force to accept first proposal
info[:old_prior_prob] = 0.0
info[:samplers] = samplers

return spl
end

function step(model, spl::Sampler{<:PMMH}, vi::VarInfo, is_first::Bool)
violating_support = false
proposal_ratio = 0.0
new_prior_prob = 0.0
new_likelihood_estimate = 0.0
old_θ = copy(vi[spl])

@debug "Propose new parameters from proposals..."
for local_spl in spl.info[:samplers][1:end-1]
@debug "$(typeof(local_spl)) proposing $(local_spl.alg.space)..."
propose(model, local_spl, vi)
if local_spl.info[:violating_support] violating_support=true; break end
new_prior_prob += local_spl.info[:prior_prob]
proposal_ratio += local_spl.info[:proposal_ratio]
end

if violating_support
# do not run SMC if going to refuse anyway
accepted = false
else
@debug "Propose new state with SMC..."
vi, _ = step(model, spl.info[:samplers][end], vi)
new_likelihood_estimate = spl.info[:samplers][end].info[:logevidence][end]

@debug "Decide whether to accept..."
accepted = mh_accept(
spl.info[:old_likelihood_estimate] + spl.info[:old_prior_prob],
new_likelihood_estimate + new_prior_prob,
proposal_ratio,
)
end

if accepted
spl.info[:old_likelihood_estimate] = new_likelihood_estimate
spl.info[:old_prior_prob] = new_prior_prob
else # rejected
vi[spl] = old_θ
end

return vi, accepted
end

function sample( model::Model,
alg::PMMH;
save_state=false, # flag for state saving
resume_from=nothing, # chain to continue
reuse_spl_n=0 # flag for spl re-using
)

spl = Sampler(alg, model)
if resume_from !== nothing
spl.selector = resume_from.info[:spl].selector
end
alg_str = "PMMH"

# Number of samples to store
sample_n = spl.alg.n_iters

# Init samples
time_total = zero(Float64)
samples = Array{Sample}(undef, sample_n)
weight = 1 / sample_n
for i = 1:sample_n
samples[i] = Sample(weight, Dict{Symbol, Any}())
end

# Init parameters
vi = if resume_from === nothing
vi_ = VarInfo(model)
else
resume_from.info[:vi]
end
n = spl.alg.n_iters

# PMMH steps
accept_his = Bool[]
PROGRESS[] && (spl.info[:progress] = ProgressMeter.Progress(n, 1, "[$alg_str] Sampling...", 0))
for i = 1:n
@debug "$alg_str stepping..."
time_elapsed = @elapsed vi, is_accept = step(model, spl, vi, i==1)

if is_accept # accepted => store the new predcits
samples[i].value = Sample(vi, spl).value
else # rejected => store the previous predcits
samples[i] = samples[i - 1]
end

time_total += time_elapsed
push!(accept_his, is_accept)
if PROGRESS[]
haskey(spl.info, :progress) && ProgressMeter.update!(spl.info[:progress], spl.info[:progress].counter + 1)
end
end

println("[$alg_str] Finished with")
println(" Running time = $time_total;")
accept_rate = sum(accept_his) / n # calculate the accept rate
println(" Accept rate = $accept_rate;")

if resume_from !== nothing # concat samples
pushfirst!(samples, resume_from.info[:samples]...)
end
c = Chain(-Inf, samples) # wrap the result by Chain

if save_state # save state
c = save(c, spl, model, vi, samples)
end

c
end


####
#### IMCMC Sampler.
####

"""
IPMCMC(n_particles::Int, n_iters::Int, n_nodes::Int, n_csmc_nodes::Int)

Particle Gibbs sampler.

Note that this method is particle-based, and arrays of variables
must be stored in a [`TArray`](@ref) object.

Usage:

```julia
IPMCMC(100, 100, 4, 2)
```

Arguments:

- `n_particles::Int` : Number of particles to use.
- `n_iters::Int` : Number of iterations to employ.
- `n_nodes::Int` : The number of nodes running SMC and CSMC.
- `n_csmc_nodes::Int` : The number of CSMC nodes.
```

A paper on this can be found [here](https://arxiv.org/abs/1602.05128).
"""
mutable struct IPMCMC{T, F} <: InferenceAlgorithm
n_particles::Int # number of particles used
n_iters::Int # number of iterations
n_nodes::Int # number of nodes running SMC and CSMC
n_csmc_nodes::Int # number of nodes CSMC
resampler::F # function to resample
space::Set{T} # sampling space, emtpy means all
end
IPMCMC(n1::Int, n2::Int) = IPMCMC(n1, n2, 32, 16, resample_systematic, Set())
IPMCMC(n1::Int, n2::Int, n3::Int) = IPMCMC(n1, n2, n3, Int(ceil(n3/2)), resample_systematic, Set())
IPMCMC(n1::Int, n2::Int, n3::Int, n4::Int) = IPMCMC(n1, n2, n3, n4, resample_systematic, Set())
function IPMCMC(n1::Int, n2::Int, n3::Int, n4::Int, space...)
_space = isa(space, Symbol) ? Set([space]) : Set(space)
IPMCMC(n1, n2, n3, n4, resample_systematic, _space)
end

function Sampler(alg::IPMCMC, s::Selector)
info = Dict{Symbol, Any}()
spl = Sampler(alg, info, s)
# Create SMC and CSMC nodes
samplers = Array{Sampler}(undef, alg.n_nodes)
# Use resampler_threshold=1.0 for SMC since adaptive resampling is invalid in this setting
default_CSMC = CSMC(alg.n_particles, 1, alg.resampler, alg.space)
default_SMC = SMC(alg.n_particles, alg.resampler, 1.0, false, alg.space)

for i in 1:alg.n_csmc_nodes
samplers[i] = Sampler(default_CSMC, Selector(Symbol(typeof(default_CSMC))))
end
for i in (alg.n_csmc_nodes+1):alg.n_nodes
samplers[i] = Sampler(default_SMC, Selector(Symbol(typeof(default_CSMC))))
end

info[:samplers] = samplers

return spl
end

function step(model, spl::Sampler{<:IPMCMC}, VarInfos::Array{VarInfo}, is_first::Bool)
# Initialise array for marginal likelihood estimators
log_zs = zeros(spl.alg.n_nodes)

# Run SMC & CSMC nodes
for j in 1:spl.alg.n_nodes
reset_num_produce!(VarInfos[j])
VarInfos[j] = step(model, spl.info[:samplers][j], VarInfos[j])[1]
log_zs[j] = spl.info[:samplers][j].info[:logevidence][end]
end

# Resampling of CSMC nodes indices
conditonal_nodes_indices = collect(1:spl.alg.n_csmc_nodes)
unconditonal_nodes_indices = collect(spl.alg.n_csmc_nodes+1:spl.alg.n_nodes)
for j in 1:spl.alg.n_csmc_nodes
# Select a new conditional node by simulating cj
log_ksi = vcat(log_zs[unconditonal_nodes_indices], log_zs[j])
ksi = exp.(log_ksi .- maximum(log_ksi))
c_j = wsample(ksi) # sample from Categorical with unormalized weights

if c_j < length(log_ksi) # if CSMC node selects another index than itself
conditonal_nodes_indices[j] = unconditonal_nodes_indices[c_j]
unconditonal_nodes_indices[c_j] = j
end
end
nodes_permutation = vcat(conditonal_nodes_indices, unconditonal_nodes_indices)

VarInfos[nodes_permutation]
end

function sample(model::Model, alg::IPMCMC)

spl = Sampler(alg)

# Number of samples to store
sample_n = alg.n_iters * alg.n_csmc_nodes

# Init samples
time_total = zero(Float64)
samples = Array{Sample}(undef, sample_n)
weight = 1 / sample_n
for i = 1:sample_n
samples[i] = Sample(weight, Dict{Symbol, Any}())
end

# Init parameters
vi = empty!(VarInfo(model))
VarInfos = Array{VarInfo}(undef, spl.alg.n_nodes)
for j in 1:spl.alg.n_nodes
VarInfos[j] = deepcopy(vi)
end
n = spl.alg.n_iters

# IPMCMC steps
if PROGRESS[] spl.info[:progress] = ProgressMeter.Progress(n, 1, "[IPMCMC] Sampling...", 0) end
for i = 1:n
@debug "IPMCMC stepping..."
time_elapsed = @elapsed VarInfos = step(model, spl, VarInfos, i==1)

# Save each CSMS retained path as a sample
for j in 1:spl.alg.n_csmc_nodes
samples[(i-1)*alg.n_csmc_nodes+j].value = Sample(VarInfos[j], spl).value
end

time_total += time_elapsed
if PROGRESS[]
haskey(spl.info, :progress) && ProgressMeter.update!(spl.info[:progress], spl.info[:progress].counter + 1)
end
end

println("[IPMCMC] Finished with")
println(" Running time = $time_total;")

Chain(0.0, samples) # wrap the result by Chain
end
Loading