Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
name = "AdvancedPS"
uuid = "576499cb-2369-40b2-a588-c64705576edc"
authors = ["TuringLang"]
version = "0.1.0"
version = "0.2.0"
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Officially, this is not breaking since currently nothing is exported. However, I know that this will break Turing since we call some of the internal methods without RNG, so I thought it would be safer to declare it as breaking.


[deps]
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Libtask = "6f1fad26-d15e-5dc8-ae53-837a1d7b8c9f"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"

[compat]
Expand Down
1 change: 1 addition & 0 deletions src/AdvancedPS.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ module AdvancedPS

import Distributions
import Libtask
import Random
import StatsFuns

include("resampling.jl")
Expand Down
24 changes: 13 additions & 11 deletions src/container.jl
Original file line number Diff line number Diff line change
Expand Up @@ -166,8 +166,8 @@ function effectiveSampleSize(pc::ParticleContainer)
end

"""
resample_propagate!(pc::ParticleContainer[, randcat = resample, ref = nothing;
weights = getweights(pc)])
resample_propagate!(rng, pc::ParticleContainer[, randcat = resample_systematic,
ref = nothing; weights = getweights(pc)])

Resample and propagate the particles in `pc`.

Expand All @@ -176,8 +176,9 @@ of the particle `weights`. For Particle Gibbs sampling, one can provide a refere
`ref` that is ensured to survive the resampling step.
"""
function resample_propagate!(
rng::Random.AbstractRNG,
pc::ParticleContainer,
randcat = resample,
randcat = resample_systematic,
ref::Union{Particle, Nothing} = nothing;
weights = getweights(pc)
)
Expand All @@ -187,7 +188,7 @@ function resample_propagate!(
# sample ancestor indices
n = length(pc)
nresamples = ref === nothing ? n : n - 1
indx = randcat(weights, nresamples)
indx = randcat(rng, weights, nresamples)

# count number of children for each particle
num_children = zeros(Int, n)
Expand Down Expand Up @@ -230,6 +231,7 @@ function resample_propagate!(
end

function resample_propagate!(
rng::Random.AbstractRNG,
pc::ParticleContainer,
resampler::ResampleWithESSThreshold,
ref::Union{Particle,Nothing} = nothing;
Expand All @@ -239,7 +241,7 @@ function resample_propagate!(
ess = inv(sum(abs2, weights))

if ess ≤ resampler.threshold * length(pc)
resample_propagate!(pc, resampler.resampler, ref; weights = weights)
resample_propagate!(rng, pc, resampler.resampler, ref; weights = weights)
end

pc
Expand Down Expand Up @@ -292,7 +294,7 @@ function reweight!(pc::ParticleContainer)
end

"""
sweep!(pc::ParticleContainer, resampler)
sweep!(rng, pc::ParticleContainer, resampler)

Perform a particle sweep and return an unbiased estimate of the log evidence.

Expand All @@ -303,11 +305,11 @@ The resampling steps use the given `resampler`.
Del Moral, P., Doucet, A., & Jasra, A. (2006). Sequential monte carlo samplers.
Journal of the Royal Statistical Society: Series B (Statistical Methodology), 68(3), 411-436.
"""
function sweep!(pc::ParticleContainer, resampler)
function sweep!(rng::Random.AbstractRNG, pc::ParticleContainer, resampler)
# Initial step:

# Resample and propagate particles.
resample_propagate!(pc, resampler)
resample_propagate!(rng, pc, resampler)

# Compute the current normalizing constant ``Z₀`` of the unnormalized logarithmic
# weights.
Expand All @@ -317,7 +319,7 @@ function sweep!(pc::ParticleContainer, resampler)
logZ0 = logZ(pc)

# Reweight the particles by including the first observation ``y₁``.
isdone = reweight!(pc)
isdone = reweight!(rng, pc)

# Compute the normalizing constant ``Z₁`` after reweighting.
logZ1 = logZ(pc)
Expand All @@ -328,14 +330,14 @@ function sweep!(pc::ParticleContainer, resampler)
# For observations ``y₂, …, yₜ``:
while !isdone
# Resample and propagate particles.
resample_propagate!(pc, resampler)
resample_propagate!(rng, pc, resampler)

# Compute the current normalizing constant ``Z₀`` of the unnormalized logarithmic
# weights.
logZ0 = logZ(pc)

# Reweight the particles by including the next observation ``yₜ``.
isdone = reweight!(pc)
isdone = reweight!(rng, pc)

# Compute the normalizing constant ``Z₁`` after reweighting.
logZ1 = logZ(pc)
Expand Down
49 changes: 27 additions & 22 deletions src/resampling.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,33 +12,33 @@ struct ResampleWithESSThreshold{R, T<:Real}
threshold::T
end

function ResampleWithESSThreshold(resampler = resample)
function ResampleWithESSThreshold(resampler = resample_systematic)
ResampleWithESSThreshold(resampler, 0.5)
end

# More stable, faster version of rand(Categorical)
function randcat(p::AbstractVector{<:Real})
function randcat(rng::Random.AbstractRNG, p::AbstractVector{<:Real})
T = eltype(p)
r = rand(T)
r = rand(rng, T)
cp = p[1]
s = 1
for j in eachindex(p)
r -= p[j]
if r <= zero(T)
s = j
break
end
n = length(p)
while cp <= r && s < n
@inbounds cp += p[s += 1]
end
return s
end

function resample_multinomial(
rng::Random.AbstractRNG,
w::AbstractVector{<:Real},
num_particles::Integer = length(w),
)
return rand(Distributions.sampler(Distributions.Categorical(w)), num_particles)
return rand(rng, Distributions.sampler(Distributions.Categorical(w)), num_particles)
end

function resample_residual(
rng::Random.AbstractRNG,
w::AbstractVector{<:Real},
num_particles::Integer = length(weights),
)
Expand All @@ -57,19 +57,19 @@ function resample_residual(
end
residuals[j] = x - floor_x
end

# sampling from residuals
if i <= num_particles
residuals ./= sum(residuals)
rand!(Distributions.Categorical(residuals), view(indices, i:num_particles))
rand!(rng, Distributions.Categorical(residuals), view(indices, i:num_particles))
end

return indices
end


"""
resample_stratified(weights, n)
resample_stratified(rng, weights, n)

Return a vector of `n` samples `x₁`, ..., `xₙ` from the numbers 1, ..., `length(weights)`,
generated by stratified resampling.
Expand All @@ -80,7 +80,11 @@ are selected according to the multinomial distribution defined by the normalized
i.e., `xᵢ = j` if and only if
``uᵢ \\in [\\sum_{s=1}^{j-1} weights_{s}, \\sum_{s=1}^{j} weights_{s})``.
"""
function resample_stratified(weights::AbstractVector{<:Real}, n::Integer = length(weights))
function resample_stratified(
rng::Random.AbstractRNG,
weights::AbstractVector{<:Real},
n::Integer = length(weights),
)
# check input
m = length(weights)
m > 0 || error("weight vector is empty")
Expand All @@ -93,7 +97,7 @@ function resample_stratified(weights::AbstractVector{<:Real}, n::Integer = lengt
sample = 1
@inbounds for i in 1:n
# sample next `u` (scaled by `n`)
u = oftype(v, i - 1 + rand())
u = oftype(v, i - 1 + rand(rng))

# as long as we have not found the next sample
while v < u
Expand All @@ -114,7 +118,7 @@ function resample_stratified(weights::AbstractVector{<:Real}, n::Integer = lengt
end

"""
resample_systematic(weights, n)
resample_systematic(rng, weights, n)

Return a vector of `n` samples `x₁`, ..., `xₙ` from the numbers 1, ..., `length(weights)`,
generated by systematic resampling.
Expand All @@ -125,14 +129,18 @@ numbers `u₁`, ..., `uₙ` where ``uₖ = (u + k − 1) / n``. Based on these n
normalized `weights`, i.e., `xᵢ = j` if and only if
``uᵢ \\in [\\sum_{s=1}^{j-1} weights_{s}, \\sum_{s=1}^{j} weights_{s})``.
"""
function resample_systematic(weights::AbstractVector{<:Real}, n::Integer = length(weights))
function resample_systematic(
rng::Random.AbstractRNG,
weights::AbstractVector{<:Real},
n::Integer = length(weights),
)
# check input
m = length(weights)
m > 0 || error("weight vector is empty")

# pre-calculations
@inbounds v = n * weights[1]
u = oftype(v, rand())
u = oftype(v, rand(rng))

# find all samples
samples = Array{Int}(undef, n)
Expand All @@ -158,6 +166,3 @@ function resample_systematic(weights::AbstractVector{<:Real}, n::Integer = lengt

return samples
end

# Default resampling scheme
const resample = resample_systematic
1 change: 1 addition & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
[deps]
Libtask = "6f1fad26-d15e-5dc8-ae53-837a1d7b8c9f"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[compat]
Expand Down
2 changes: 1 addition & 1 deletion test/container.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@
@test AdvancedPS.logZ(pc) ≈ log(sum(exp, 2 .* logps))

# Resample and propagate particles.
AdvancedPS.resample_propagate!(pc)
AdvancedPS.resample_propagate!(Random.GLOBAL_RNG, pc)
@test pc.logWs == zeros(3)
@test AdvancedPS.getweights(pc) == fill(1/3, 3)
@test all(AdvancedPS.getweight(pc, i) == 1/3 for i in 1:3)
Expand Down
15 changes: 7 additions & 8 deletions test/resampling.jl
Original file line number Diff line number Diff line change
@@ -1,17 +1,16 @@
@testset "resampling.jl" begin
D = [0.3, 0.4, 0.3]
num_samples = Int(1e6)
rng = Random.GLOBAL_RNG

resSystematic = AdvancedPS.resample_systematic(D, num_samples )
resStratified = AdvancedPS.resample_stratified(D, num_samples )
resMultinomial= AdvancedPS.resample_multinomial(D, num_samples )
resResidual = AdvancedPS.resample_residual(D, num_samples )
AdvancedPS.resample(D)
resSystematic2= AdvancedPS.resample(D, num_samples )
resSystematic = AdvancedPS.resample_systematic(rng, D, num_samples)
resStratified = AdvancedPS.resample_stratified(rng, D, num_samples)
resMultinomial= AdvancedPS.resample_multinomial(rng, D, num_samples)
resResidual = AdvancedPS.resample_residual(rng, D, num_samples)
AdvancedPS.resample_systematic(rng, D)

@test sum(resSystematic .== 2) ≈ (num_samples * 0.4) atol=1e-3*num_samples
@test sum(resSystematic2 .== 2) ≈ (num_samples * 0.4) atol=1e-3*num_samples
@test sum(resStratified .== 2) ≈ (num_samples * 0.4) atol=1e-3*num_samples
@test sum(resMultinomial .== 2) ≈ (num_samples * 0.4) atol=1e-2*num_samples
@test sum(resResidual .== 2) ≈ (num_samples * 0.4) atol=1e-2*num_samples
end
end
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using AdvancedPS
using Libtask
using Random
using Test

@testset "AdvancedPS.jl" begin
Expand Down