diff --git a/Manifest.toml b/Manifest.toml new file mode 100644 index 00000000..e3c5b3b3 --- /dev/null +++ b/Manifest.toml @@ -0,0 +1,311 @@ +# This file is machine-generated - editing it directly is not advised + +[[AbstractMCMC]] +deps = ["ConsoleProgressMonitor", "Distributed", "Logging", "LoggingExtras", "ProgressLogging", "Random", "StatsBase", "TerminalLoggers"] +git-tree-sha1 = "31a0a7b957525748e05599488ca6eef476fef12b" +uuid = "80f14c24-f653-4e6a-9b94-39d6b0f70001" +version = "1.0.1" + +[[AbstractTrees]] +deps = ["Markdown"] +git-tree-sha1 = "33e450545eaf7699da1a6e755f9ea65f14077a45" +uuid = "1520ce14-60c1-5f80-bbc7-55ef81b5835c" +version = "0.3.3" + +[[ArgCheck]] +git-tree-sha1 = "dedbbb2ddb876f899585c4ec4433265e3017215a" +uuid = "dce04be8-c92d-5529-be00-80e4d2c0e197" +version = "2.1.0" + +[[Base64]] +uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f" + +[[Bijectors]] +deps = ["ArgCheck", "Compat", "Distributions", "LinearAlgebra", "MappedArrays", "NNlib", "Random", "Reexport", "Requires", "Roots", "SparseArrays", "Statistics", "StatsFuns"] +git-tree-sha1 = "7049f8682dab97b87c30759057058cfec63e4fc6" +uuid = "76274a88-744f-5084-9051-94815aaf08c4" +version = "0.8.2" + +[[BinaryProvider]] +deps = ["Libdl", "Logging", "SHA"] +git-tree-sha1 = "ecdec412a9abc8db54c0efc5548c64dfce072058" +uuid = "b99e7846-7c00-51b0-8f62-c81ae34c0232" +version = "0.5.10" + +[[Compat]] +deps = ["Base64", "Dates", "DelimitedFiles", "Distributed", "InteractiveUtils", "LibGit2", "Libdl", "LinearAlgebra", "Markdown", "Mmap", "Pkg", "Printf", "REPL", "Random", "SHA", "Serialization", "SharedArrays", "Sockets", "SparseArrays", "Statistics", "Test", "UUIDs", "Unicode"] +git-tree-sha1 = "083e7e5ec3ef443e9dcb6dd3fbcb815879823bfa" +uuid = "34da2185-b29b-5c13-b0c7-acf172513d20" +version = "3.14.0" + +[[CompilerSupportLibraries_jll]] +deps = ["Libdl", "Pkg"] +git-tree-sha1 = "7c4f882c41faa72118841185afc58a2eb00ef612" +uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae" +version = "0.3.3+0" + 
+[[ConsoleProgressMonitor]] +deps = ["Logging", "ProgressMeter"] +git-tree-sha1 = "3ab7b2136722890b9af903859afcf457fa3059e8" +uuid = "88cd18e8-d9cc-4ea6-8889-5259c0d15c8b" +version = "0.1.2" + +[[DataAPI]] +git-tree-sha1 = "176e23402d80e7743fc26c19c681bfb11246af32" +uuid = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a" +version = "1.3.0" + +[[DataStructures]] +deps = ["InteractiveUtils", "OrderedCollections"] +git-tree-sha1 = "88d48e133e6d3dd68183309877eac74393daa7eb" +uuid = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" +version = "0.17.20" + +[[Dates]] +deps = ["Printf"] +uuid = "ade2ca70-3891-5945-98fb-dc099432e06a" + +[[DelimitedFiles]] +deps = ["Mmap"] +uuid = "8bb1440f-4735-579b-a4ab-409b98df4dab" + +[[Distributed]] +deps = ["Random", "Serialization", "Sockets"] +uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b" + +[[Distributions]] +deps = ["FillArrays", "LinearAlgebra", "PDMats", "Printf", "QuadGK", "Random", "SpecialFunctions", "Statistics", "StatsBase", "StatsFuns"] +git-tree-sha1 = "dec9607adfa6a82399cce0bd9b8557f8cc3b7bcd" +uuid = "31c24e10-a181-5473-b8eb-7969acd0382f" +version = "0.23.9" + +[[DynamicPPL]] +deps = ["AbstractMCMC", "Bijectors", "Distributions", "MacroTools", "Random", "ZygoteRules"] +git-tree-sha1 = "b46046c78801149cc9f2b80e806fbd5a2891b652" +uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8" +version = "0.8.2" + +[[FillArrays]] +deps = ["LinearAlgebra", "Random", "SparseArrays"] +git-tree-sha1 = "9a457808000939be5f052291cbed7de409c2839d" +uuid = "1a297f60-69ca-5386-bcde-b61e274b549b" +version = "0.9.4" + +[[InteractiveUtils]] +deps = ["Markdown"] +uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240" + +[[LeftChildRightSiblingTrees]] +deps = ["AbstractTrees"] +git-tree-sha1 = "71be1eb5ad19cb4f61fa8c73395c0338fd092ae0" +uuid = "1d6d02ad-be62-4b6b-8a6d-2f90e265016e" +version = "0.1.2" + +[[LibGit2]] +deps = ["Printf"] +uuid = "76f85450-5226-5b5a-8eaa-529ad045b433" + +[[Libdl]] +uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb" + +[[Libtask]] +deps = ["BinaryProvider", 
"Libdl", "Pkg"] +git-tree-sha1 = "68a658db4792dfc468ea2aabcf06f3f74f153f23" +uuid = "6f1fad26-d15e-5dc8-ae53-837a1d7b8c9f" +version = "0.4.1" + +[[LinearAlgebra]] +deps = ["Libdl"] +uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" + +[[Logging]] +uuid = "56ddb016-857b-54e1-b83d-db4d58db5568" + +[[LoggingExtras]] +deps = ["Dates"] +git-tree-sha1 = "03289aba73c0abc25ff0229bed60f2a4129cd15c" +uuid = "e6f89c97-d47a-5376-807f-9c37f3926c36" +version = "0.4.2" + +[[MacroTools]] +deps = ["Markdown", "Random"] +git-tree-sha1 = "f7d2e3f654af75f01ec49be82c231c382214223a" +uuid = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" +version = "0.5.5" + +[[MappedArrays]] +git-tree-sha1 = "e2a02fe7ee86a10c707ff1756ab1650b40b140bb" +uuid = "dbb5928d-eab1-5f90-85c2-b9b0edb7c900" +version = "0.2.2" + +[[Markdown]] +deps = ["Base64"] +uuid = "d6f4376e-aef5-505a-96c1-9c027394607a" + +[[Missings]] +deps = ["DataAPI"] +git-tree-sha1 = "de0a5ce9e5289f27df672ffabef4d1e5861247d5" +uuid = "e1d29d7a-bbdc-5cf2-9ac0-f12de2c33e28" +version = "0.4.3" + +[[Mmap]] +uuid = "a63ad114-7e13-5084-954f-fe012c677804" + +[[NNlib]] +deps = ["Libdl", "LinearAlgebra", "Pkg", "Requires", "Statistics"] +git-tree-sha1 = "8ec4693a5422f0b064ce324f59351f24aa474893" +uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" +version = "0.7.4" + +[[OpenSpecFun_jll]] +deps = ["CompilerSupportLibraries_jll", "Libdl", "Pkg"] +git-tree-sha1 = "d51c416559217d974a1113522d5919235ae67a87" +uuid = "efe28fd5-8261-553b-a9e1-b2916fc3738e" +version = "0.5.3+3" + +[[OrderedCollections]] +git-tree-sha1 = "293b70ac1780f9584c89268a6e2a560d938a7065" +uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" +version = "1.3.0" + +[[PDMats]] +deps = ["LinearAlgebra", "SparseArrays", "SuiteSparse", "Test"] +git-tree-sha1 = "b3405086eb6a974eba1958923d46bc0e1c2d2d63" +uuid = "90014a1f-27ba-587c-ab20-58faa44d9150" +version = "0.10.0" + +[[Pkg]] +deps = ["Dates", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "UUIDs"] +uuid = 
"44cfe95a-1eb2-52ea-b672-e2afdf69b78f" + +[[Printf]] +deps = ["Unicode"] +uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7" + +[[ProgressLogging]] +deps = ["Logging", "SHA", "UUIDs"] +git-tree-sha1 = "59398022b661b6fd569f25de6b18fde39843196a" +uuid = "33c8b6b6-d38a-422a-b730-caa89a2f386c" +version = "0.1.3" + +[[ProgressMeter]] +deps = ["Distributed", "Printf"] +git-tree-sha1 = "2de4cddc0ceeddafb6b143b5b6cd9c659b64507c" +uuid = "92933f4c-e287-5a05-a399-4b506db050ca" +version = "1.3.2" + +[[QuadGK]] +deps = ["DataStructures", "LinearAlgebra"] +git-tree-sha1 = "0ab8a09d4478ebeb99a706ecbf8634a65077ccdc" +uuid = "1fd47b50-473d-5c70-9696-f719f8f3bcdc" +version = "2.4.0" + +[[REPL]] +deps = ["InteractiveUtils", "Markdown", "Sockets"] +uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb" + +[[Random]] +deps = ["Serialization"] +uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" + +[[Reexport]] +deps = ["Pkg"] +git-tree-sha1 = "7b1d07f411bc8ddb7977ec7f377b97b158514fe0" +uuid = "189a3867-3050-52da-a836-e630ba90ab69" +version = "0.2.0" + +[[Requires]] +deps = ["UUIDs"] +git-tree-sha1 = "d37400976e98018ee840e0ca4f9d20baa231dc6b" +uuid = "ae029012-a4dd-5104-9daa-d747884805df" +version = "1.0.1" + +[[Rmath]] +deps = ["Random", "Rmath_jll"] +git-tree-sha1 = "86c5647b565873641538d8f812c04e4c9dbeb370" +uuid = "79098fc4-a85e-5d69-aa6a-4863f24498fa" +version = "0.6.1" + +[[Rmath_jll]] +deps = ["Libdl", "Pkg"] +git-tree-sha1 = "d76185aa1f421306dec73c057aa384bad74188f0" +uuid = "f50d1b31-88e8-58de-be2c-1cc44531875f" +version = "0.2.2+1" + +[[Roots]] +deps = ["Printf"] +git-tree-sha1 = "1211c7c1928c1ed29cdcef65979b7a791e3b9fbe" +uuid = "f2b01f46-fcfa-551c-844a-d8ac1e96c665" +version = "1.0.5" + +[[SHA]] +uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce" + +[[Serialization]] +uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b" + +[[SharedArrays]] +deps = ["Distributed", "Mmap", "Random", "Serialization"] +uuid = "1a1011a3-84de-559e-8e89-a11a2f7dc383" + +[[Sockets]] +uuid = "6462fe0b-24de-5631-8697-dd941f90decc" 
+ +[[SortingAlgorithms]] +deps = ["DataStructures", "Random", "Test"] +git-tree-sha1 = "03f5898c9959f8115e30bc7226ada7d0df554ddd" +uuid = "a2af1166-a08f-5f64-846c-94a0d3cef48c" +version = "0.3.1" + +[[SparseArrays]] +deps = ["LinearAlgebra", "Random"] +uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" + +[[SpecialFunctions]] +deps = ["OpenSpecFun_jll"] +git-tree-sha1 = "d8d8b8a9f4119829410ecd706da4cc8594a1e020" +uuid = "276daf66-3868-5448-9aa4-cd146d93841b" +version = "0.10.3" + +[[Statistics]] +deps = ["LinearAlgebra", "SparseArrays"] +uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" + +[[StatsBase]] +deps = ["DataAPI", "DataStructures", "LinearAlgebra", "Missings", "Printf", "Random", "SortingAlgorithms", "SparseArrays", "Statistics"] +git-tree-sha1 = "a6102b1f364befdb05746f386b67c6b7e3262c45" +uuid = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" +version = "0.33.0" + +[[StatsFuns]] +deps = ["Rmath", "SpecialFunctions"] +git-tree-sha1 = "04a5a8e6ab87966b43f247920eab053fd5fdc925" +uuid = "4c63d2b9-4356-54db-8cca-17b64c39e42c" +version = "0.9.5" + +[[SuiteSparse]] +deps = ["Libdl", "LinearAlgebra", "Serialization", "SparseArrays"] +uuid = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9" + +[[TerminalLoggers]] +deps = ["LeftChildRightSiblingTrees", "Logging", "Markdown", "Printf", "ProgressLogging", "UUIDs"] +git-tree-sha1 = "cbea752b5eef52a3e1188fb31580c3e4fa0cbc35" +uuid = "5d786b92-1e48-4d6f-9151-6b4477ca9bed" +version = "0.1.2" + +[[Test]] +deps = ["Distributed", "InteractiveUtils", "Logging", "Random"] +uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40" + +[[UUIDs]] +deps = ["Random", "SHA"] +uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4" + +[[Unicode]] +uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5" + +[[ZygoteRules]] +deps = ["MacroTools"] +git-tree-sha1 = "b3b4882cc9accf6731a08cc39543fbc6b669dca8" +uuid = "700de1a5-db45-46bc-99cf-38207098b444" +version = "0.2.0" diff --git a/Project.toml b/Project.toml new file mode 100644 index 00000000..d27a4212 --- /dev/null +++ b/Project.toml @@ -0,0 
+1,18 @@ +name = "AdvancedPS" +uuid = "9c00393c-822d-4781-8df5-4c3c33f9866d" +authors = ["JS Denain "] +version = "0.1.0" + +[deps] +AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" +Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" +DynamicPPL = "366bfd00-2699-11ea-058f-f148b4cae6d8" +Libtask = "6f1fad26-d15e-5dc8-ae53-837a1d7b8c9f" +Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c" + +[extras] +Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" + +[targets] +test = ["Test"] diff --git a/README.md b/README.md index 49974209..0b241dae 100644 --- a/README.md +++ b/README.md @@ -1,23 +1,17 @@ -# AdvancedParticleSamplers.jl +# AdvancedPS.jl +## References +* Doucet, Arnaud, and Adam M. Johansen. "A tutorial on particle filtering and smoothing: Fifteen years later." Handbook of nonlinear filtering 12, no. 656-704 (2009): 3. -### Reference +* Andrieu, Christophe, Arnaud Doucet, and Roman Holenstein. "Particle Markov chain Monte Carlo methods." Journal of the Royal Statistical Society: Series B (Statistical Methodology) 72, no. 3 (2010): 269-342. -1. Doucet, Arnaud, and Adam M. Johansen. "A tutorial on particle filtering and smoothing: Fifteen years later." Handbook of nonlinear filtering 12, no. 656-704 (2009): 3. - -2. Andrieu, Christophe, Arnaud Doucet, and Roman Holenstein. "Particle Markov chain Monte Carlo methods." Journal of the Royal Statistical Society: Series B (Statistical Methodology) 72, no. 3 (2010): 269-342. - -3. Tripuraneni, Nilesh, Shixiang Shane Gu, Hong Ge, and Zoubin Ghahramani. "Particle gibbs for infinite hidden Markov models." In Advances in Neural Information Processing Systems, pp. 2395-2403. 2015. - -4. Lindsten, Fredrik, Michael I. Jordan, and Thomas B. Schön. "Particle Gibbs with ancestor sampling." The Journal of Machine Learning Research 15, no. 1 (2014): 2145-2184. - -5. Pitt, Michael K., and Neil Shephard. "Filtering via simulation: Auxiliary particle filters." 
Journal of the American statistical association 94, no. 446 (1999): 590-599. - -6. Doucet, Arnaud, Nando de Freitas, and Neil Gordon. "Sequential Monte Carlo Methods in Practice." - -7. Del Moral, Pierre, Arnaud Doucet, and Ajay Jasra. "Sequential Monte Carlo samplers." Journal of the Royal Statistical Society: Series B (Statistical Methodology) 68, no. 3 (2006): 411-436. +* Tripuraneni, Nilesh, Shixiang Shane Gu, Hong Ge, and Zoubin Ghahramani. "Particle gibbs for infinite hidden Markov models." In Advances in Neural Information Processing Systems, pp. 2395-2403. 2015. +* Lindsten, Fredrik, Michael I. Jordan, and Thomas B. Schön. "Particle Gibbs with ancestor sampling." The Journal of Machine Learning Research 15, no. 1 (2014): 2145-2184. +* Pitt, Michael K., and Neil Shephard. "Filtering via simulation: Auxiliary particle filters." Journal of the American statistical association 94, no. 446 (1999): 590-599. +* Doucet, Arnaud, Nando de Freitas, and Neil Gordon. "Sequential Monte Carlo Methods in Practice." +* Del Moral, Pierre, Arnaud Doucet, and Ajay Jasra. "Sequential Monte Carlo samplers." Journal of the Royal Statistical Society: Series B (Statistical Methodology) 68, no. 3 (2006): 411-436. diff --git a/src/AdvancedPS.jl b/src/AdvancedPS.jl new file mode 100644 index 00000000..25b412e3 --- /dev/null +++ b/src/AdvancedPS.jl @@ -0,0 +1,37 @@ +module AdvancedPS +using Libtask +using Random +using AbstractMCMC: AbstractSampler +using DynamicPPL: AbstractVarInfo, Model, SampleFromPrior, Sampler, reset_num_produce!, set_retained_vns_del_by_spl!, increment_num_produce!, getlogp, resetlogp! 
+using Distributions +using StatsFuns: softmax, logsumexp + +include("trace.jl") +export Trace, + fork, + forkr, + current_trace + +include("particlecontainer.jl") +export ParticleContainer, + reset_logweights!, + increase_logweight!, + getweights, + getweight, + logZ, + effectiveSampleSize + +include("resampling.jl") +export ResampleWithESSThreshold, + randcat, + resample_multinomial, + resample_residual, + resample_stratified, + resample_systematic + +include("sweep.jl") +export resample_propagate!, + reweight!, + sweep! + +end diff --git a/src/particlecontainer.jl b/src/particlecontainer.jl new file mode 100644 index 00000000..05577616 --- /dev/null +++ b/src/particlecontainer.jl @@ -0,0 +1,91 @@ +const Particle = Trace + +""" +Data structure for particle filters +- effectiveSampleSize(pc :: ParticleContainer) +- normalise!(pc::ParticleContainer) +- consume(pc::ParticleContainer): return incremental likelihood +""" +mutable struct ParticleContainer{T<:Particle} + "Particles." + vals::Vector{T} + "Unnormalized logarithmic weights." + logWs::Vector{Float64} +end + +function ParticleContainer(particles::Vector{<:Particle}) + return ParticleContainer(particles, zeros(length(particles))) +end + +Base.collect(pc::ParticleContainer) = pc.vals +Base.length(pc::ParticleContainer) = length(pc.vals) +Base.@propagate_inbounds Base.getindex(pc::ParticleContainer, i::Int) = pc.vals[i] + +# registers a new x-particle in the container +function Base.push!(pc::ParticleContainer, p::Particle) + push!(pc.vals, p) + push!(pc.logWs, 0.0) + pc +end + +# clones a theta-particle +function Base.copy(pc::ParticleContainer) + # fork particles + vals = eltype(pc.vals)[fork(p) for p in pc.vals] + + # copy weights + logWs = copy(pc.logWs) + + ParticleContainer(vals, logWs) +end + +""" + reset_logweights!(pc::ParticleContainer) + +Reset all unnormalized logarithmic weights to zero. 
+""" +function reset_logweights!(pc::ParticleContainer) + fill!(pc.logWs, 0.0) + return pc +end + +""" + increase_logweight!(pc::ParticleContainer, i::Int, x) + +Increase the unnormalized logarithmic weight of the `i`th particle with `x`. +""" +function increase_logweight!(pc::ParticleContainer, i, logw) + pc.logWs[i] += logw + return pc +end + +""" + getweights(pc::ParticleContainer) + +Compute the normalized weights of the particles. +""" +getweights(pc::ParticleContainer) = softmax(pc.logWs) + +""" + getweight(pc::ParticleContainer, i) + +Compute the normalized weight of the `i`th particle. +""" +getweight(pc::ParticleContainer, i) = exp(pc.logWs[i] - logZ(pc)) + +""" + logZ(pc::ParticleContainer) + +Return the logarithm of the normalizing constant of the unnormalized logarithmic weights. +""" +logZ(pc::ParticleContainer) = logsumexp(pc.logWs) + +""" + effectiveSampleSize(pc::ParticleContainer) + +Compute the effective sample size ``1 / ∑ wᵢ²``, where ``wᵢ`` are the normalized weights. +""" +function effectiveSampleSize(pc::ParticleContainer) + Ws = getweights(pc) + return inv(sum(abs2, Ws)) +end diff --git a/src/resampling.jl b/src/resampling.jl new file mode 100644 index 00000000..d0a7d29c --- /dev/null +++ b/src/resampling.jl @@ -0,0 +1,174 @@ +# 2 dichotomies between resamplers: +# - whether or not to use ESS thresholds +# - whether to use systematic, residual, multinomial, or stratified resampling schemes + +############################################# +## Resample only when ESS ≤ a preset value ## +############################################# + + +# modifies dispatch in resample_propagate! 
+struct ResampleWithESSThreshold{R, T<:Real} + scheme::R + threshold::T +end + +function ResampleWithESSThreshold(scheme = resample_systematic) + ResampleWithESSThreshold(scheme, 0.5) +end + +############################################# +## Resampling schemes for particle filters ## +############################################# + +# Some references +# - http://arxiv.org/pdf/1301.4019.pdf +# - http://people.isy.liu.se/rt/schon/Publications/HolSG2006.pdf +# Code adapted from: http://uk.mathworks.com/matlabcentral/fileexchange/24968-resampling-methods-for-particle-filtering + +# More stable, faster version of rand(Categorical) +function randcat(p::AbstractVector{<:Real}) + T = eltype(p) + r = rand(T) + s = 1 + for j in eachindex(p) + r -= p[j] + if r <= zero(T) + s = j + break + end + end + return s +end + +""" + resample_multinomial(w, num_particles) + +Multinomial resampling scheme +""" +function resample_multinomial(w::AbstractVector{<:Real}, num_particles::Integer) + return rand(Distributions.sampler(Categorical(w)), num_particles) +end + +""" + resample_residual(w, num_particles) + +Residual resampling scheme +""" +function resample_residual(w::AbstractVector{<:Real}, num_particles::Integer) + + M = length(w) + + # "Repetition counts" (plus the random part, later on): + Ns = floor.(length(w) .* w) + + # The "remainder" or "residual" count: + R = Int(sum(Ns)) + + # The number of particles which will be drawn stocastically: + M_rdn = num_particles - R + + # The modified weights: + Ws = (M .* w - floor.(M .* w)) / M_rdn + + # Draw the deterministic part: + indx1, i = Array{Int}(undef, R), 1 + for j in 1:M + for k in 1:Ns[j] + indx1[i] = j + i += 1 + end + end + + # And now draw the stochastic (Multinomial) part: + return append!(indx1, rand(Distributions.sampler(Categorical(w)), M_rdn)) +end + +""" + resample_stratified(weights, n) + +Return a vector of `n` samples `x₁`, ..., `xₙ` from the numbers 1, ..., `length(weights)`, +generated by stratified resampling. 
+In stratified resampling `n` ordered random numbers `u₁`, ..., `uₙ` are generated, where +``uₖ \\sim U[(k - 1) / n, k / n)``. Based on these numbers the samples `x₁`, ..., `xₙ` +are selected according to the multinomial distribution defined by the normalized `weights`, +i.e., `xᵢ = j` if and only if +``uᵢ \\in [\\sum_{s=1}^{j-1} weights_{s}, \\sum_{s=1}^{j} weights_{s})``. +""" +function resample_stratified(weights::AbstractVector{<:Real}, n::Integer) + # check input + m = length(weights) + m > 0 || error("weight vector is empty") + + # pre-calculations + @inbounds v = n * weights[1] + + # generate all samples + samples = Array{Int}(undef, n) + sample = 1 + @inbounds for i in 1:n + # sample next `u` (scaled by `n`) + u = oftype(v, i - 1 + rand()) + + # as long as we have not found the next sample + while v < u + # increase and check the sample + sample += 1 + sample > m && + error("sample could not be selected (are the weights normalized?)") + + # update the cumulative sum of weights (scaled by `n`) + v += n * weights[sample] + end + + # save the next sample + samples[i] = sample + end + + return samples +end + +""" + resample_systematic(weights, n) + +Return a vector of `n` samples `x₁`, ..., `xₙ` from the numbers 1, ..., `length(weights)`, +generated by systematic resampling. +In systematic resampling a random number ``u \\sim U[0, 1)`` is used to generate `n` ordered +numbers `u₁`, ..., `uₙ` where ``uₖ = (u + k − 1) / n``. Based on these numbers the samples +`x₁`, ..., `xₙ` are selected according to the multinomial distribution defined by the +normalized `weights`, i.e., `xᵢ = j` if and only if +``uᵢ \\in [\\sum_{s=1}^{j-1} weights_{s}, \\sum_{s=1}^{j} weights_{s})``. 
+""" +function resample_systematic(weights::AbstractVector{<:Real}, n::Integer) + # check input + m = length(weights) + m > 0 || error("weight vector is empty") + + # pre-calculations + @inbounds v = n * weights[1] + u = oftype(v, rand()) + + # find all samples + samples = Array{Int}(undef, n) + sample = 1 + @inbounds for i in 1:n + # as long as we have not found the next sample + while v < u + # increase and check the sample + sample += 1 + sample > m && + error("sample could not be selected (are the weights normalized?)") + + # update the cumulative sum of weights (scaled by `n`) + v += n * weights[sample] + end + + # save the next sample + samples[i] = sample + + # update `u` + u += one(u) + end + + return samples +end diff --git a/src/sweep.jl b/src/sweep.jl new file mode 100644 index 00000000..e0a3890f --- /dev/null +++ b/src/sweep.jl @@ -0,0 +1,184 @@ +""" + resample_propagate!(pc::ParticleContainer[, scheme, ref = nothing; weights = getweights(pc)]) + +Resample and propagate the particles in `pc`, without ESS thresholding. +Function `scheme` is the scheme used to resample ancestor indices based on the particle weights. +For Particle Gibbs sampling, one can provide a reference particle `ref` that is ensured to survive the resampling step. +""" +function resample_propagate!( +pc::ParticleContainer, +scheme, +ref::Union{Particle, Nothing} = nothing; +weights = getweights(pc) +) + # check that weights are not NaN + @assert !any(isnan, weights) + + # sample ancestor indices + n = length(pc) + nresamples = ref === nothing ? n : n - 1 + indx = scheme(weights, nresamples) + + # count number of children for each particle + num_children = zeros(Int, n) + @inbounds for i in indx + num_children[i] += 1 + end + + # fork particles + particles = collect(pc) + children = similar(particles) + j = 0 + @inbounds for i in 1:n + ni = num_children[i] + + if ni > 0 + # fork first child + pi = particles[i] + isref = pi === ref + p = isref ? 
fork(pi, isref) : pi + children[j += 1] = p + + # fork additional children + for _ in 2:ni + children[j += 1] = fork(p, isref) + end + end + end + + if ref !== nothing + # Insert the retained particle. This is based on the replaying trick for efficiency + # reasons. If we implement PG using task copying, we need to store Nx * T particles! + @inbounds children[n] = ref + end + + # replace particles and log weights in the container with new particles and weights + pc.vals = children + reset_logweights!(pc) + + return pc +end + + +""" + resample_propagate!(pc::ParticleContainer, resampler::ResampleWithESSThreshold[, + ref = nothing; weights = getweights(pc)]) + +Resample and propagate the particles in `pc`, with ESS thresholding. +For Particle Gibbs sampling, one can provide a reference particle `ref` that is ensured to survive the resampling step. +""" +function resample_propagate!( +pc::ParticleContainer, +resampler::ResampleWithESSThreshold, +ref::Union{Particle,Nothing} = nothing; +weights = getweights(pc) +) + # Compute the effective sample size ``1 / ∑ wᵢ²`` with normalized weights ``wᵢ`` + ess = inv(sum(abs2, weights)) + + if ess ≤ resampler.threshold * length(pc) + resample_propagate!(pc, resampler.scheme, ref; weights = weights) + end + + return pc +end + + +""" + reweight!(pc::ParticleContainer) + +Check if the final time step is reached, and otherwise reweight the particles by +considering the next observation. +""" +function reweight!(pc::ParticleContainer) + n = length(pc) + + particles = collect(pc) + numdone = 0 + for i in 1:n + p = particles[i] + + # Obtain ``\\log p(yₜ | y₁, …, yₜ₋₁, x₁, …, xₜ, θ₁, …, θₜ)``, or `nothing` if the + # the execution of the model is finished. + # Here ``yᵢ`` are observations, ``xᵢ`` variables of the particle filter, and + # ``θᵢ`` are variables of other samplers. 
+ score = Libtask.consume(p) + + if score === nothing + numdone += 1 + else + # Increase the unnormalized logarithmic weights, accounting for the variables + # of other samplers. + increase_logweight!(pc, i, score + getlogp(p.vi)) + + # Reset the accumulator of the log probability in the model so that we can + # accumulate log probabilities of variables of other samplers until the next + # observation. + resetlogp!(p.vi) + end + end + + # Check if all particles are propagated to the final time point. + numdone == n && return true + + # The posterior for models with random number of observations is not well-defined. + if numdone != 0 + error("mis-aligned execution traces: # particles = ", n, + " # completed trajectories = ", numdone, + ". Please make sure the number of observations is NOT random.") + end + + return false +end + +""" + sweep!(pc::ParticleContainer, resampler) + +Perform a particle sweep and return an unbiased estimate of the log evidence. +The resampling steps use the given `resampler`. +# Reference: Del Moral, P., Doucet, A., & Jasra, A. (2006). Sequential monte carlo samplers. +Journal of the Royal Statistical Society: Series B (Statistical Methodology), 68(3), 411-436. +""" +function sweep!(pc::ParticleContainer, resampler) + # Initial step: + + # Resample and propagate particles. + resample_propagate!(pc, resampler) + + # Compute the current normalizing constant ``Z₀`` of the unnormalized logarithmic + # weights. + # Usually it is equal to the number of particles in the beginning but this + # implementation covers also the unlikely case of a particle container that is + # initialized with non-zero logarithmic weights. + logZ0 = logZ(pc) + + # Reweight the particles by including the first observation ``y₁``. + isdone = reweight!(pc) + + # Compute the normalizing constant ``Z₁`` after reweighting. + logZ1 = logZ(pc) + + # Compute the estimate of the log evidence ``\\log p(y₁)``. 
+ logevidence = logZ1 - logZ0 + + # For observations ``y₂, …, yₜ``: + while !isdone + # Resample and propagate particles. + resample_propagate!(pc, resampler) + + # Compute the current normalizing constant ``Z₀`` of the unnormalized logarithmic + # weights. + logZ0 = logZ(pc) + + # Reweight the particles by including the next observation ``yₜ``. + isdone = reweight!(pc) + + # Compute the normalizing constant ``Z₁`` after reweighting. + logZ1 = logZ(pc) + + # Compute the estimate of the log evidence ``\\log p(y₁, …, yₜ)``. + logevidence += logZ1 - logZ0 + end + + return logevidence +end diff --git a/src/trace.jl b/src/trace.jl new file mode 100644 index 00000000..2b72fcf4 --- /dev/null +++ b/src/trace.jl @@ -0,0 +1,76 @@ +mutable struct Trace{Tspl<:AbstractSampler, Tvi<:AbstractVarInfo, Tmodel<:Model} + model::Tmodel + spl::Tspl + vi::Tvi + ctask::CTask + + function Trace{SampleFromPrior}(model::Model, spl::AbstractSampler, vi::AbstractVarInfo) + return new{SampleFromPrior,typeof(vi),typeof(model)}(model, SampleFromPrior(), vi) + end + function Trace{S}(model::Model, spl::S, vi::AbstractVarInfo) where S<:Sampler + return new{S,typeof(vi),typeof(model)}(model, spl, vi) + end +end + +function Base.copy(trace::Trace) + vi = deepcopy(trace.vi) + res = Trace{typeof(trace.spl)}(trace.model, trace.spl, vi) + res.ctask = copy(trace.ctask) + return res +end + +# NOTE: this function is called by `forkr` +function Trace(f, m::Model, spl::AbstractSampler, vi::AbstractVarInfo) + res = Trace{typeof(spl)}(m, spl, deepcopy(vi)) + ctask = CTask() do + res = f() + produce(nothing) + return res + end + task = ctask.task + if task.storage === nothing + task.storage = IdDict() + end + task.storage[:turing_trace] = res # create a backward reference in task_local_storage + res.ctask = ctask + return res +end + +function Trace(m::Model, spl::AbstractSampler, vi::AbstractVarInfo) + res = Trace{typeof(spl)}(m, spl, deepcopy(vi)) + reset_num_produce!(res.vi) + ctask = CTask() do + res = m(vi, 
spl) + produce(nothing) + return res + end + task = ctask.task + if task.storage === nothing + task.storage = IdDict() + end + task.storage[:turing_trace] = res # create a backward reference in task_local_storage + res.ctask = ctask + return res +end + +# step to the next observe statement, return log likelihood +Libtask.consume(t::Trace) = (increment_num_produce!(t.vi); consume(t.ctask)) + +# Task copying version of fork for Trace. +function fork(trace :: Trace, is_ref :: Bool = false) + newtrace = copy(trace) + is_ref && set_retained_vns_del_by_spl!(newtrace.vi, newtrace.spl) + newtrace.ctask.task.storage[:turing_trace] = newtrace + return newtrace +end + +# PG requires keeping all randomness for the reference particle +# Create new task and copy randomness +function forkr(trace::Trace) + newtrace = Trace(trace.ctask.task.code, trace.model, trace.spl, deepcopy(trace.vi)) + newtrace.spl = trace.spl + reset_num_produce!(newtrace.vi) + return newtrace +end + +current_trace() = current_task().storage[:turing_trace] diff --git a/test/particlecontainer.jl b/test/particlecontainer.jl new file mode 100644 index 00000000..54dc1ce6 --- /dev/null +++ b/test/particlecontainer.jl @@ -0,0 +1,12 @@ +using AdvancedPS +using Test + +@testset "particlecontainer.jl" begin + @testset "copy particle container" begin + pc = ParticleContainer(Trace[]) + newpc = copy(pc) + + @test newpc.logWs == pc.logWs + @test typeof(pc) === typeof(newpc) + end +end \ No newline at end of file diff --git a/test/resampling.jl b/test/resampling.jl new file mode 100644 index 00000000..61ea7b56 --- /dev/null +++ b/test/resampling.jl @@ -0,0 +1,16 @@ +using Test +using AdvancedPS + +@testset "resampling.jl" begin + D = [0.3, 0.4, 0.3] + num_samples = Int(1e6) + resSystematic = resample_systematic(D, num_samples ) + resStratified = resample_stratified(D, num_samples ) + resMultinomial= resample_multinomial(D, num_samples ) + resResidual = resample_residual(D, num_samples ) + + @test count(==(2), 
resSystematic) ≈ 0.4 * num_samples atol=1e-3*num_samples + @test count(==(2), resStratified) ≈ 0.4 * num_samples atol=1e-3*num_samples + @test count(==(2), resMultinomial) ≈ 0.4 * num_samples atol=1e-2*num_samples + @test count(==(2), resResidual) ≈ 0.4 * num_samples atol=1e-2*num_samples +end diff --git a/test/runtests.jl b/test/runtests.jl new file mode 100644 index 00000000..2dccc019 --- /dev/null +++ b/test/runtests.jl @@ -0,0 +1,6 @@ +using Test +using AdvancedPS + +include("resampling.jl") +include("particlecontainer.jl") +# more tests for structs and functions from particlecontainer.jl and sweep.jl require smc.jl and are tested with it \ No newline at end of file