-
Notifications
You must be signed in to change notification settings - Fork 12
Fix forkr - Handle rng in Trace #23
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
31 commits
Select commit
Hold shift + click to select a range
8725b79
Init
FredericWantiez 81603bd
mcmc chains
FredericWantiez 2382874
rng
FredericWantiez 79d9857
Trace
FredericWantiez ddbbc3c
Rng, test
FredericWantiez 87559ef
Fix tests
FredericWantiez 460f26c
Global rng
FredericWantiez b5240eb
Format, duplicate rng
FredericWantiez f125369
Tracedrng
FredericWantiez 156afff
Merge branch 'master' of https://github.com/FredericWantiez/AdvancedP…
FredericWantiez cd12c8c
Format
FredericWantiez 5cbbe40
Track Particle container
FredericWantiez 2fb7290
Merge branch 'master' of https://github.com/FredericWantiez/AdvancedP…
FredericWantiez 22fa456
Test, format
FredericWantiez 8f05881
Reset RNG, save state, random123
FredericWantiez 5b01908
Downcast
FredericWantiez 8134751
Merge pull request #2 from FredericWantiez/feature/split
FredericWantiez 0893eac
Correct replaying mechanism
FredericWantiez 0ecdcce
Format
FredericWantiez f6df82d
Fix resampling step
FredericWantiez e76747e
Clean
FredericWantiez d77876f
Merge pull request #3 from FredericWantiez/feature/split
FredericWantiez 4cd74b0
Fix naming
FredericWantiez 3acd49a
Format
FredericWantiez fc31d02
Test
FredericWantiez d709319
Doc string
FredericWantiez 6979bbe
Merge pull request #4 from FredericWantiez/feature/split
FredericWantiez 8ed2100
Fix types, import and PR review
FredericWantiez 4f8ec74
Merge pull request #5 from FredericWantiez/feature/split
FredericWantiez 7f1b12f
Update src/rng.jl
yebai 705d1dc
Update Project.toml
yebai File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,18 +1,20 @@ | ||
name = "AdvancedPS" | ||
uuid = "576499cb-2369-40b2-a588-c64705576edc" | ||
authors = ["TuringLang"] | ||
version = "0.2.4" | ||
version = "0.3.0" | ||
|
||
[deps] | ||
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" | ||
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" | ||
Libtask = "6f1fad26-d15e-5dc8-ae53-837a1d7b8c9f" | ||
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" | ||
Random123 = "74087812-796a-5b5d-8853-05524746bad3" | ||
StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c" | ||
|
||
[compat] | ||
AbstractMCMC = "2, 3" | ||
Distributions = "0.23, 0.24, 0.25" | ||
Libtask = "0.5.3" | ||
Random123 = "1.3" | ||
StatsFuns = "0.9" | ||
julia = "1.3" |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,31 +1,43 @@ | ||
struct Trace{F} | ||
struct Trace{F,U,N,V<:Random123.AbstractR123{U}} | ||
f::F | ||
ctask::Libtask.CTask | ||
rng::TracedRNG{U,N,V} | ||
end | ||
|
||
const Particle = Trace | ||
|
||
function Trace(f) | ||
function Trace(f, rng::TracedRNG) | ||
ctask = let f = f | ||
Libtask.CTask() do | ||
res = f() | ||
res = f(rng) | ||
Libtask.produce(nothing) | ||
return res | ||
end | ||
end | ||
|
||
# add backward reference | ||
newtrace = Trace(f, ctask) | ||
newtrace = Trace(f, ctask, rng) | ||
addreference!(ctask.task, newtrace) | ||
|
||
return newtrace | ||
end | ||
|
||
Base.copy(trace::Trace) = Trace(trace.f, copy(trace.ctask)) | ||
function Trace(f, ctask::Libtask.CTask) | ||
return Trace(f, ctask, TracedRNG()) | ||
end | ||
|
||
# Copy task | ||
Base.copy(trace::Trace) = Trace(trace.f, copy(trace.ctask), deepcopy(trace.rng)) | ||
|
||
# step to the next observe statement and | ||
# return the log probability of the transition (or nothing if done) | ||
advance!(t::Trace) = Libtask.consume(t.ctask) | ||
function advance!(t::Trace, isref::Bool) | ||
isref ? load_state!(t.rng) : save_state!(t.rng) | ||
inc_counter!(t.rng) | ||
|
||
# Move to next step | ||
return Libtask.consume(t.ctask) | ||
end | ||
|
||
# reset log probability | ||
reset_logprob!(t::Trace) = nothing | ||
|
@@ -48,16 +60,18 @@ end | |
# Create new task and copy randomness | ||
function forkr(trace::Trace) | ||
newf = reset_model(trace.f) | ||
Random123.set_counter!(trace.rng, 1) | ||
|
||
ctask = let f = trace.ctask.task.code | ||
Libtask.CTask() do | ||
res = f() | ||
res = f()(trace.rng) | ||
yebai marked this conversation as resolved.
Show resolved
Hide resolved
|
||
Libtask.produce(nothing) | ||
return res | ||
end | ||
end | ||
|
||
# add backward reference | ||
newtrace = Trace(newf, ctask) | ||
newtrace = Trace(newf, ctask, trace.rng) | ||
addreference!(ctask.task, newtrace) | ||
|
||
return newtrace | ||
|
@@ -81,15 +95,21 @@ Data structure for particle filters | |
- normalise!(pc::ParticleContainer) | ||
- consume(pc::ParticleContainer): return incremental likelihood | ||
""" | ||
mutable struct ParticleContainer{T<:Particle} | ||
mutable struct ParticleContainer{T<:Particle,U,N,V<:Random123.AbstractR123{U}} | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same here, I guess we could use mutable struct ParticleContainer{T<:Particle,R<:TracedRNG}
vals::Vector{T}
logWs::Vector{Float64}
rng::R
end ? |
||
"Particles." | ||
vals::Vector{T} | ||
"Unnormalized logarithmic weights." | ||
logWs::Vector{Float64} | ||
"Traced RNG to replay the resampling step" | ||
rng::TracedRNG{U,N,V} | ||
end | ||
|
||
function ParticleContainer(particles::Vector{<:Particle}) | ||
return ParticleContainer(particles, zeros(length(particles))) | ||
return ParticleContainer(particles, zeros(length(particles)), TracedRNG()) | ||
end | ||
|
||
function ParticleContainer(particles::Vector{<:Particle}, r::TracedRNG) | ||
return ParticleContainer(particles, zeros(length(particles)), r) | ||
end | ||
|
||
Base.collect(pc::ParticleContainer) = pc.vals | ||
|
@@ -116,7 +136,10 @@ function Base.copy(pc::ParticleContainer) | |
# copy weights | ||
logWs = copy(pc.logWs) | ||
|
||
return ParticleContainer(vals, logWs) | ||
# Copy rng and states | ||
rng = copy(pc.rng) | ||
|
||
return ParticleContainer(vals, logWs, rng) | ||
end | ||
|
||
""" | ||
|
@@ -170,6 +193,22 @@ function effectiveSampleSize(pc::ParticleContainer) | |
return inv(sum(abs2, Ws)) | ||
end | ||
|
||
""" | ||
update_keys!(pc::ParticleContainer) | ||
|
||
Create new unique keys for the particles in the ParticleContainer | ||
""" | ||
function update_keys!(pc::ParticleContainer, ref::Union{Particle,Nothing}=nothing) | ||
# Update keys to new particle ids | ||
nparticles = length(pc) | ||
n = ref === nothing ? nparticles : nparticles - 1 | ||
for i in 1:n | ||
pi = pc.vals[i] | ||
k = split(pi.rng.rng.key) | ||
Random.seed!(pi.rng, k[1]) | ||
end | ||
end | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I guess we could add an empty |
||
|
||
""" | ||
resample_propagate!(rng, pc::ParticleContainer[, randcat = resample_systematic, | ||
ref = nothing; weights = getweights(pc)]) | ||
|
@@ -213,11 +252,17 @@ function resample_propagate!( | |
pi = particles[i] | ||
isref = pi === ref | ||
p = isref ? fork(pi, isref) : pi | ||
children[j += 1] = p | ||
nseeds = isref ? ni - 1 : ni | ||
|
||
seeds = split(p.rng.rng.key, nseeds) | ||
!isref && Random.seed!(p.rng, seeds[1]) | ||
|
||
children[j += 1] = p | ||
# fork additional children | ||
for _ in 2:ni | ||
children[j += 1] = fork(p, isref) | ||
for k in 2:ni | ||
part = fork(p, isref) | ||
Random.seed!(part.rng, seeds[k]) | ||
children[j += 1] = part | ||
end | ||
end | ||
end | ||
|
@@ -247,6 +292,8 @@ function resample_propagate!( | |
|
||
if ess ≤ resampler.threshold * length(pc) | ||
resample_propagate!(rng, pc, resampler.resampler, ref; weights=weights) | ||
else | ||
update_keys!(pc, ref) | ||
end | ||
|
||
return pc | ||
|
@@ -258,7 +305,7 @@ end | |
Check if the final time step is reached, and otherwise reweight the particles by | ||
considering the next observation. | ||
""" | ||
function reweight!(pc::ParticleContainer) | ||
function reweight!(pc::ParticleContainer, ref::Union{Particle,Nothing}=nothing) | ||
n = length(pc) | ||
|
||
particles = collect(pc) | ||
|
@@ -270,7 +317,8 @@ function reweight!(pc::ParticleContainer) | |
# the execution of the model is finished. | ||
# Here ``yᵢ`` are observations, ``xᵢ`` variables of the particle filter, and | ||
# ``θᵢ`` are variables of other samplers. | ||
score = advance!(p) | ||
isref = p === ref | ||
score = advance!(p, isref) | ||
|
||
if score === nothing | ||
numdone += 1 | ||
|
@@ -321,7 +369,6 @@ function sweep!( | |
ref::Union{Particle,Nothing}=nothing, | ||
) | ||
# Initial step: | ||
|
||
# Resample and propagate particles. | ||
resample_propagate!(rng, pc, resampler, ref) | ||
|
||
|
@@ -333,7 +380,7 @@ function sweep!( | |
logZ0 = logZ(pc) | ||
|
||
# Reweight the particles by including the first observation ``y₁``. | ||
isdone = reweight!(pc) | ||
isdone = reweight!(pc, ref) | ||
|
||
# Compute the normalizing constant ``Z₁`` after reweighting. | ||
logZ1 = logZ(pc) | ||
|
@@ -351,7 +398,7 @@ function sweep!( | |
logZ0 = logZ(pc) | ||
|
||
# Reweight the particles by including the next observation ``yₜ``. | ||
isdone = reweight!(pc) | ||
isdone = reweight!(pc, ref) | ||
|
||
# Compute the normalizing constant ``Z₁`` after reweighting. | ||
logZ1 = logZ(pc) | ||
|
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,85 @@ | ||
# Default RNG type for when nothing is specified | ||
const _BASE_RNG = Random123.Philox2x | ||
|
||
""" | ||
TracedRNG{R,N,T} | ||
|
||
Wrapped random number generator from Random123 to keep track of random streams during model evaluation | ||
""" | ||
mutable struct TracedRNG{R,N,T<:Random123.AbstractR123{R}} <: Random.AbstractRNG | ||
"Model step counter" | ||
count::Int | ||
"Inner RNG" | ||
rng::T | ||
"Array of keys" | ||
keys::Array{R,N} | ||
end | ||
|
||
""" | ||
TracedRNG(r::Random123.AbstractR123=AdvancedPS._BASE_RNG()) | ||
Create a `TracedRNG` with `r` as the inner RNG. | ||
""" | ||
function TracedRNG(r::Random123.AbstractR123=_BASE_RNG()) | ||
Random123.set_counter!(r, 0) | ||
return TracedRNG(1, r, typeof(r.key)[]) | ||
end | ||
|
||
# Connect to the Random API | ||
Random.rng_native_52(rng::TracedRNG) = Random.rng_native_52(rng.rng) | ||
Base.rand(rng::TracedRNG, ::Type{T}) where {T} = Base.rand(rng.rng, T) | ||
|
||
""" | ||
split(key::Integer, n::Integer=1) | ||
|
||
Split `key` into `n` new keys | ||
""" | ||
function split(key::Integer, n::Integer=1) | ||
T = typeof(key) # Make sure the type of `key` is consistent on W32 and W64 systems. | ||
return T[hash(key, i) for i in UInt(1):UInt(n)] | ||
end | ||
|
||
""" | ||
load_state!(r::TracedRNG) | ||
|
||
Load state from current model iteration. Random streams are now replayed | ||
""" | ||
function load_state!(rng::TracedRNG) | ||
key = rng.keys[rng.count] | ||
Random.seed!(rng.rng, key) | ||
return Random123.set_counter!(rng.rng, 0) | ||
end | ||
|
||
""" | ||
update_rng!(rng::TracedRNG) | ||
|
||
Set key and counter of inner rng in `rng` to `key` and the running model step to 0 | ||
""" | ||
function Random.seed!(rng::TracedRNG, key) | ||
Random.seed!(rng.rng, key) | ||
return Random123.set_counter!(rng.rng, 0) | ||
end | ||
|
||
""" | ||
save_state!(r::TracedRNG) | ||
|
||
Add current key of the inner rng in `r` to `keys`. | ||
""" | ||
function save_state!(r::TracedRNG) | ||
return push!(r.keys, r.rng.key) | ||
end | ||
|
||
Base.copy(r::TracedRNG) = TracedRNG(r.count, copy(r.rng), deepcopy(r.keys)) | ||
|
||
""" | ||
set_counter!(r::TracedRNG, n::Integer) | ||
|
||
Set the counter of the inner rng in `r`, used to keep track of the current model step | ||
""" | ||
Random123.set_counter!(r::TracedRNG, n::Integer) = r.count = n | ||
|
||
""" | ||
inc_counter!(r::TracedRNG, n::Integer=1) | ||
|
||
Increase the model step counter by `n` | ||
""" | ||
inc_counter!(r::TracedRNG, n::Integer=1) = r.count += n |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do we dispatch on
U
,N
orV
somewhere? Otherwise (and probably even in this case) this could be simplified to