Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
8725b79
Init
FredericWantiez Apr 2, 2021
81603bd
mcmc chains
FredericWantiez Apr 30, 2021
2382874
rng
FredericWantiez May 4, 2021
79d9857
Trace
FredericWantiez May 4, 2021
ddbbc3c
Rng, test
FredericWantiez May 12, 2021
87559ef
Fix tests
FredericWantiez May 17, 2021
460f26c
Global rng
FredericWantiez May 20, 2021
b5240eb
Format, duplicate rng
FredericWantiez May 31, 2021
f125369
Tracedrng
FredericWantiez May 31, 2021
156afff
Merge branch 'master' of https://github.com/FredericWantiez/AdvancedP…
FredericWantiez Jun 4, 2021
cd12c8c
Format
FredericWantiez Jun 13, 2021
5cbbe40
Track Particle container
FredericWantiez Jul 11, 2021
2fb7290
Merge branch 'master' of https://github.com/FredericWantiez/AdvancedP…
FredericWantiez Jul 11, 2021
22fa456
Test, format
FredericWantiez Jul 11, 2021
8f05881
Reset RNG, save state, random123
FredericWantiez Aug 17, 2021
5b01908
Downcast
FredericWantiez Aug 17, 2021
8134751
Merge pull request #2 from FredericWantiez/feature/split
FredericWantiez Aug 17, 2021
0893eac
Correct replaying mechanism
FredericWantiez Aug 28, 2021
0ecdcce
Format
FredericWantiez Aug 28, 2021
f6df82d
Fix resampling step
FredericWantiez Aug 29, 2021
e76747e
Clean
FredericWantiez Aug 29, 2021
d77876f
Merge pull request #3 from FredericWantiez/feature/split
FredericWantiez Aug 29, 2021
4cd74b0
Fix naming
FredericWantiez Sep 4, 2021
3acd49a
Format
FredericWantiez Sep 4, 2021
fc31d02
Test
FredericWantiez Sep 4, 2021
d709319
Doc string
FredericWantiez Sep 8, 2021
6979bbe
Merge pull request #4 from FredericWantiez/feature/split
FredericWantiez Sep 8, 2021
8ed2100
Fix types, import and PR review
FredericWantiez Sep 11, 2021
4f8ec74
Merge pull request #5 from FredericWantiez/feature/split
FredericWantiez Sep 11, 2021
7f1b12f
Update src/rng.jl
yebai Sep 14, 2021
705d1dc
Update Project.toml
yebai Sep 14, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,18 +1,20 @@
name = "AdvancedPS"
uuid = "576499cb-2369-40b2-a588-c64705576edc"
authors = ["TuringLang"]
version = "0.2.4"
version = "0.3.0"

[deps]
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Libtask = "6f1fad26-d15e-5dc8-ae53-837a1d7b8c9f"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Random123 = "74087812-796a-5b5d-8853-05524746bad3"
StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"

[compat]
AbstractMCMC = "2, 3"
Distributions = "0.23, 0.24, 0.25"
Libtask = "0.5.3"
Random123 = "1.3"
StatsFuns = "0.9"
julia = "1.3"
2 changes: 2 additions & 0 deletions src/AdvancedPS.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,10 @@ using Distributions: Distributions
using Libtask: Libtask
using Random: Random
using StatsFuns: StatsFuns
using Random123: Random123

include("resampling.jl")
include("rng.jl")
include("container.jl")
include("smc.jl")
include("model.jl")
Expand Down
85 changes: 66 additions & 19 deletions src/container.jl
Original file line number Diff line number Diff line change
@@ -1,31 +1,43 @@
struct Trace{F}
struct Trace{F,U,N,V<:Random123.AbstractR123{U}}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we dispatch on U, N or V somewhere? Otherwise (and probably even in this case) this could be simplified to

struct Trace{F,R<:TracedRNG}
    f::F
    ctask::Libtask.CTask
    rng::R
end

f::F
ctask::Libtask.CTask
rng::TracedRNG{U,N,V}
end

const Particle = Trace

function Trace(f)
function Trace(f, rng::TracedRNG)
ctask = let f = f
Libtask.CTask() do
res = f()
res = f(rng)
Libtask.produce(nothing)
return res
end
end

# add backward reference
newtrace = Trace(f, ctask)
newtrace = Trace(f, ctask, rng)
addreference!(ctask.task, newtrace)

return newtrace
end

Base.copy(trace::Trace) = Trace(trace.f, copy(trace.ctask))
function Trace(f, ctask::Libtask.CTask)
return Trace(f, ctask, TracedRNG())
end

# Copy task
Base.copy(trace::Trace) = Trace(trace.f, copy(trace.ctask), deepcopy(trace.rng))

# step to the next observe statement and
# return the log probability of the transition (or nothing if done)
advance!(t::Trace) = Libtask.consume(t.ctask)
function advance!(t::Trace, isref::Bool)
isref ? load_state!(t.rng) : save_state!(t.rng)
inc_counter!(t.rng)

# Move to next step
return Libtask.consume(t.ctask)
end

# reset log probability
reset_logprob!(t::Trace) = nothing
Expand All @@ -48,16 +60,18 @@ end
# Create new task and copy randomness
function forkr(trace::Trace)
newf = reset_model(trace.f)
Random123.set_counter!(trace.rng, 1)

ctask = let f = trace.ctask.task.code
Libtask.CTask() do
res = f()
res = f()(trace.rng)
Libtask.produce(nothing)
return res
end
end

# add backward reference
newtrace = Trace(newf, ctask)
newtrace = Trace(newf, ctask, trace.rng)
addreference!(ctask.task, newtrace)

return newtrace
Expand All @@ -81,15 +95,21 @@ Data structure for particle filters
- normalise!(pc::ParticleContainer)
- consume(pc::ParticleContainer): return incremental likelihood
"""
mutable struct ParticleContainer{T<:Particle}
mutable struct ParticleContainer{T<:Particle,U,N,V<:Random123.AbstractR123{U}}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here, I guess we could use

mutable struct ParticleContainer{T<:Particle,R<:TracedRNG}
    vals::Vector{T}
    logWs::Vector{Float64}
    rng::R
end

?

"Particles."
vals::Vector{T}
"Unnormalized logarithmic weights."
logWs::Vector{Float64}
"Traced RNG to replay the resampling step"
rng::TracedRNG{U,N,V}
end

function ParticleContainer(particles::Vector{<:Particle})
return ParticleContainer(particles, zeros(length(particles)))
return ParticleContainer(particles, zeros(length(particles)), TracedRNG())
end

function ParticleContainer(particles::Vector{<:Particle}, r::TracedRNG)
return ParticleContainer(particles, zeros(length(particles)), r)
end

Base.collect(pc::ParticleContainer) = pc.vals
Expand All @@ -116,7 +136,10 @@ function Base.copy(pc::ParticleContainer)
# copy weights
logWs = copy(pc.logWs)

return ParticleContainer(vals, logWs)
# Copy rng and states
rng = copy(pc.rng)

return ParticleContainer(vals, logWs, rng)
end

"""
Expand Down Expand Up @@ -170,6 +193,22 @@ function effectiveSampleSize(pc::ParticleContainer)
return inv(sum(abs2, Ws))
end

"""
update_keys!(pc::ParticleContainer)

Create new unique keys for the particles in the ParticleContainer
"""
function update_keys!(pc::ParticleContainer, ref::Union{Particle,Nothing}=nothing)
# Update keys to new particle ids
nparticles = length(pc)
n = ref === nothing ? nparticles : nparticles - 1
for i in 1:n
pi = pc.vals[i]
k = split(pi.rng.rng.key)
Random.seed!(pi.rng, k[1])
end
end
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess we could add an empty return or nothing at the end to avoid that the last seed or rng is returned.


"""
resample_propagate!(rng, pc::ParticleContainer[, randcat = resample_systematic,
ref = nothing; weights = getweights(pc)])
Expand Down Expand Up @@ -213,11 +252,17 @@ function resample_propagate!(
pi = particles[i]
isref = pi === ref
p = isref ? fork(pi, isref) : pi
children[j += 1] = p
nseeds = isref ? ni - 1 : ni

seeds = split(p.rng.rng.key, nseeds)
!isref && Random.seed!(p.rng, seeds[1])

children[j += 1] = p
# fork additional children
for _ in 2:ni
children[j += 1] = fork(p, isref)
for k in 2:ni
part = fork(p, isref)
Random.seed!(part.rng, seeds[k])
children[j += 1] = part
end
end
end
Expand Down Expand Up @@ -247,6 +292,8 @@ function resample_propagate!(

if ess ≤ resampler.threshold * length(pc)
resample_propagate!(rng, pc, resampler.resampler, ref; weights=weights)
else
update_keys!(pc, ref)
end

return pc
Expand All @@ -258,7 +305,7 @@ end
Check if the final time step is reached, and otherwise reweight the particles by
considering the next observation.
"""
function reweight!(pc::ParticleContainer)
function reweight!(pc::ParticleContainer, ref::Union{Particle,Nothing}=nothing)
n = length(pc)

particles = collect(pc)
Expand All @@ -270,7 +317,8 @@ function reweight!(pc::ParticleContainer)
# the execution of the model is finished.
# Here ``yᵢ`` are observations, ``xᵢ`` variables of the particle filter, and
# ``θᵢ`` are variables of other samplers.
score = advance!(p)
isref = p === ref
score = advance!(p, isref)

if score === nothing
numdone += 1
Expand Down Expand Up @@ -321,7 +369,6 @@ function sweep!(
ref::Union{Particle,Nothing}=nothing,
)
# Initial step:

# Resample and propagate particles.
resample_propagate!(rng, pc, resampler, ref)

Expand All @@ -333,7 +380,7 @@ function sweep!(
logZ0 = logZ(pc)

# Reweight the particles by including the first observation ``y₁``.
isdone = reweight!(pc)
isdone = reweight!(pc, ref)

# Compute the normalizing constant ``Z₁`` after reweighting.
logZ1 = logZ(pc)
Expand All @@ -351,7 +398,7 @@ function sweep!(
logZ0 = logZ(pc)

# Reweight the particles by including the next observation ``yₜ``.
isdone = reweight!(pc)
isdone = reweight!(pc, ref)

# Compute the normalizing constant ``Z₁`` after reweighting.
logZ1 = logZ(pc)
Expand Down
85 changes: 85 additions & 0 deletions src/rng.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
# Default RNG type for when nothing is specified
const _BASE_RNG = Random123.Philox2x

"""
TracedRNG{R,N,T}

Wrapped random number generator from Random123 to keep track of random streams during model evaluation
"""
mutable struct TracedRNG{R,N,T<:Random123.AbstractR123{R}} <: Random.AbstractRNG
"Model step counter"
count::Int
"Inner RNG"
rng::T
"Array of keys"
keys::Array{R,N}
end

"""
TracedRNG(r::Random123.AbstractR123=AdvancedPS._BASE_RNG())
Create a `TracedRNG` with `r` as the inner RNG.
"""
function TracedRNG(r::Random123.AbstractR123=_BASE_RNG())
Random123.set_counter!(r, 0)
return TracedRNG(1, r, typeof(r.key)[])
end

# Connect to the Random API
Random.rng_native_52(rng::TracedRNG) = Random.rng_native_52(rng.rng)
Base.rand(rng::TracedRNG, ::Type{T}) where {T} = Base.rand(rng.rng, T)

"""
split(key::Integer, n::Integer=1)

Split `key` into `n` new keys
"""
function split(key::Integer, n::Integer=1)
T = typeof(key) # Make sure the type of `key` is consistent on W32 and W64 systems.
return T[hash(key, i) for i in UInt(1):UInt(n)]
end

"""
load_state!(r::TracedRNG)

Load state from current model iteration. Random streams are now replayed
"""
function load_state!(rng::TracedRNG)
key = rng.keys[rng.count]
Random.seed!(rng.rng, key)
return Random123.set_counter!(rng.rng, 0)
end

"""
update_rng!(rng::TracedRNG)

Set key and counter of inner rng in `rng` to `key` and the running model step to 0
"""
function Random.seed!(rng::TracedRNG, key)
Random.seed!(rng.rng, key)
return Random123.set_counter!(rng.rng, 0)
end

"""
save_state!(r::TracedRNG)

Add current key of the inner rng in `r` to `keys`.
"""
function save_state!(r::TracedRNG)
return push!(r.keys, r.rng.key)
end

Base.copy(r::TracedRNG) = TracedRNG(r.count, copy(r.rng), deepcopy(r.keys))

"""
set_counter!(r::TracedRNG, n::Integer)

Set the counter of the inner rng in `r`, used to keep track of the current model step
"""
Random123.set_counter!(r::TracedRNG, n::Integer) = r.count = n

"""
inc_counter!(r::TracedRNG, n::Integer=1)

Increase the model step counter by `n`
"""
inc_counter!(r::TracedRNG, n::Integer=1) = r.count += n
12 changes: 8 additions & 4 deletions src/smc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,9 @@ function AbstractMCMC.sample(
end

# Create a set of particles.
particles = ParticleContainer([Trace(model) for _ in 1:(sampler.nparticles)])
particles = ParticleContainer(
[Trace(model, TracedRNG()) for _ in 1:(sampler.nparticles)], TracedRNG()
)

# Perform particle sweep.
logevidence = sweep!(rng, particles, sampler.resampler)
Expand Down Expand Up @@ -83,7 +85,9 @@ function AbstractMCMC.step(
rng::Random.AbstractRNG, model::AbstractMCMC.AbstractModel, sampler::PG; kwargs...
)
# Create a new set of particles.
particles = ParticleContainer([Trace(model) for _ in 1:(sampler.nparticles)])
particles = ParticleContainer(
[Trace(model, TracedRNG()) for _ in 1:(sampler.nparticles)], TracedRNG()
)

# Perform a particle sweep.
logevidence = sweep!(rng, particles, sampler.resampler)
Expand All @@ -108,10 +112,10 @@ function AbstractMCMC.step(
# Create reference trajectory.
forkr(state.trajectory)
else
Trace(model)
Trace(model, TracedRNG())
end
end
particles = ParticleContainer(x)
particles = ParticleContainer(x, TracedRNG())

# Perform a particle sweep.
logevidence = sweep!(rng, particles, sampler.resampler, particles.vals[nparticles])
Expand Down
2 changes: 2 additions & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@ AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Libtask = "6f1fad26-d15e-5dc8-ae53-837a1d7b8c9f"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Random123 = "74087812-796a-5b5d-8853-05524746bad3"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[compat]
AbstractMCMC = "2, 3"
Distributions = "0.24, 0.25"
Libtask = "0.5"
julia = "1.3"
Random123 = "1.3"
Loading