Fix forkr - Handle rng in Trace #23
Conversation
Maybe it would be good to merge #20 first?

Agree, we need to build on #20 for this fix.
Codecov Report

@@            Coverage Diff             @@
##           master      #23      +/-   ##
==========================================
+ Coverage   55.70%   59.85%   +4.14%
==========================================
  Files           5        6       +1
  Lines         368      411      +43
==========================================
+ Hits          205      246      +41
- Misses        163      165       +2
Pull Request Test Coverage Report for Build 1234943898

Warning: This coverage report may be inaccurate. This pull request's base commit is no longer the HEAD commit of its target branch. This means it includes changes from outside the original pull request, including, potentially, unrelated coverage changes.

💛 - Coveralls
@yebai I think there's an issue with using fork and passing the RNG to Trace. When we fork the particles during the resampling step, the RNG is set to the one used in the original Trace definition, so all children particles will share it and we won't be able to reconstruct the random stream. Also, as a side question: when do we populate the reference particle when we resample in PG?
We'll need a mechanism to split the random stream during the resampling step. The splitting operation
- allows the children particles to share the random history before the splitting point,
- makes the random streams produce independent random numbers after the splitting point.
More concretely, we need to split the random stream when copying a trace, i.e. fork(trace):

AdvancedPS.jl/src/container.jl, line 33 in 156afff:

Base.copy(trace::Trace) = Trace(trace.f, copy(trace.ctask))
We need to replay the random stream when replaying a trace, i.e. forkr(trace):

AdvancedPS.jl/src/container.jl, line 58 in 156afff:

function forkr(trace::Trace)
Also, on the side, when do we populate the reference particle when we resample in PG ?
I'm not sure I understand the question. The reference particle is incrementally replayed during the forward propagation of a conditional SMC step. If the reference particle has multiple children at a resampling step, one child will inherit the random stream and keep replaying, the other children will split the random stream and start producing new random numbers.
Also, happy to discuss more when we meet.
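To make the split semantics above concrete, here is a minimal pure-Julia sketch (assumptions: the helper `split_stream`, the use of the stdlib `Xoshiro` generator, and hash-derived child seeds are all illustrative stand-ins, not the AdvancedPS.jl implementation). Children can replay the shared history by reseeding with the parent's seed, and diverge after the split point because each gets an independent hash-derived seed:

```julia
using Random

# Toy split: derive child seeds from a parent seed and a split index by
# hashing. Before the split point, children replay the parent's history
# (same seed); after it, each child has an independent stream.
split_stream(seed::UInt, n::Integer) = [hash(seed, UInt(i)) for i in 1:n]

parent_seed = UInt(42)
child_seeds = split_stream(parent_seed, 2)

# Shared history before the split: any child can replay it from the parent seed.
history = rand(Xoshiro(parent_seed), 3)

# Independent streams after the split.
a = rand(Xoshiro(child_seeds[1]), 3)
b = rand(Xoshiro(child_seeds[2]), 3)
```

The key property is that the split is deterministic: given the parent seed and the split index, the child streams can always be reconstructed.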
src/rng.jl
Outdated
# Set seed manually, for init ?
Random.seed!(rng::TracedRNG, seed) = Random.seed!(rng.rng, seed)
Do we store the seed somewhere in TracedRNG when calling this function?
Isn't this exactly what splittable RNGs are designed for (e.g. https://github.com/google/jax/blob/master/design_notes/prng.md#design, https://dl.acm.org/doi/abs/10.1145/2660193.2660195, and https://publications.lib.chalmers.se/records/fulltext/183348/local_183348.pdf)? The last reference also discusses why naive approaches to splitting linear PRNGs fail or make it hard to assess the quality of the returned random numbers (I am not sufficiently familiar with this area to determine whether all of these concerns are valid).
Yes, we are planning to use splittable RNGs.
Looked at the random module in TensorFlow recently; they also use a counter-based RNG and the split API: https://github.com/tensorflow/tensorflow/blob/v2.5.0/tensorflow/python/ops/stateful_random_ops.py#L933-L981

We could do something similar:

mutable struct TracedRNG{R,T<:Random123.AbstractR123{R}}
    count::Int
    rng::T
    keys::Array{R}
end

TracedRNG(r::Random123.AbstractR123) = TracedRNG(0, r, [r.key])

function split(r::TracedRNG, n::Integer)
    return map(1:n) do i
        new_seed = hash(r.rng.key, convert(UInt, r.rng.ctr1 + i))
        return TracedRNG(r.count, Philox2x(new_seed), [r.keys..., new_seed])
    end
end

rng = TracedRNG(Philox2x())
a, b = split(rng, 2)

To replay the random numbers we should be able to cycle through the keys and reset the counters.
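The replay idea (cycle through the stored keys and reset the generator before each segment) can be sketched with a toy, dependency-free version of the design; `ToyTracedRNG`, `split!`, and `replay` are hypothetical names, and the stdlib `Xoshiro` stands in for the counter-based Random123 generator:

```julia
using Random

# Toy model of the replay idea (illustrative only): record every key
# used along the way, then replay the stream by reseeding with each
# stored key in order.
mutable struct ToyTracedRNG
    rng::Xoshiro
    keys::Vector{UInt}
end
ToyTracedRNG(key::UInt) = ToyTracedRNG(Xoshiro(key), [key])

# Split: derive a fresh key by hashing, reseed, and record the key.
function split!(r::ToyTracedRNG)
    key = hash(r.keys[end], UInt(length(r.keys)))
    r.rng = Xoshiro(key)
    push!(r.keys, key)
    return r
end

# Replay: reseed with each stored key in turn and redo the draws of
# that segment, reproducing the original stream exactly.
function replay(keys::Vector{UInt}, counts::Vector{Int})
    out = Float64[]
    for (key, n) in zip(keys, counts)
        append!(out, rand(Xoshiro(key), n))
    end
    return out
end

r = ToyTracedRNG(UInt(1))
x1 = rand(r.rng)   # one draw on the original segment
split!(r)
x2 = rand(r.rng)   # one draw on the segment after the split
```

With the per-segment draw counts known, `replay(r.keys, [1, 1])` reproduces `[x1, x2]`, which is exactly what replaying a reference trajectory needs.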
@yebai I simplified the API slightly: we don't need to track both the model counter and the rng counters. We can use the model counter if every particle has a unique key/id. Fixed the naming and the resampling step, which was wrong. Also added a few tests; I can add some more, as I don't think we catch all the edge cases yet.
Thanks @FredericWantiez - excellent work. I have gone through the rng.jl file and made some comments along the way. I'll take a look at other parts slightly later.

In particular, I find the function pair split and update_rng! slightly confusing while reading the code. Given that split and update_rng! are always used in pairs, I wonder whether we can consider merging them into split!(rng::TracedRNG)::TracedRNG for better clarity. This split! function should always split a given RNG once, mutate it in place and return the mutated RNG. If we need to split N times, we simply call split! N times repetitively. Below is an example implementation of the split! function:
function split!(rng::TracedRNG)
    key = hash(rng.rng.key, convert(UInt, rng.rng.ctr1))
    seed!(rng, key)
    return set_counter!(rng.rng, rng.count + 1) # rng.count + 1 is used instead of rng.count
end
In order to ensure that calling split! repetitively with the same RNG will return new RNGs with different keys, we increase the internal counter by one every time a split is performed. This way, the next time split! is called on the same RNG, it will get a different key because r.rng.ctr1 will be different.
A side question: do we actually need to set rng.rng.ctr1 to rng.count when performing a split? If so, is there a reference on this?
@devmotion @yebai Fixed the last few things we discussed last week, let me know what you think.
src/rng.jl
Outdated
import Random123: set_counter!

# Default RNG type for when nothing is specified
_BASE_RNG = Philox2x
Suggested change:
- _BASE_RNG = Philox2x
+ const _BASE_RNG = Philox2x
src/rng.jl
Outdated
Set key and counter of inner RNG to key and the running model step
"""
function seed!(rng::TracedRNG{T}, key) where {T}
Suggested change:
- function seed!(rng::TracedRNG{T}, key) where {T}
+ function seed!(rng::TracedRNG, key)
src/rng.jl
Outdated
Track current key of the inner RNG
"""
function save_state!(r::TracedRNG{T}) where {T}
Suggested change:
- function save_state!(r::TracedRNG{T}) where {T}
+ function save_state!(r::TracedRNG)
src/rng.jl
Outdated
return push!(r.keys, r.rng.key)
end

Base.copy(r::TracedRNG{T}) where {T} = TracedRNG(r.count, copy(r.rng), copy(r.keys))
Suggested change:
- Base.copy(r::TracedRNG{T}) where {T} = TracedRNG(r.count, copy(r.rng), copy(r.keys))
+ Base.copy(r::TracedRNG) = TracedRNG(r.count, copy(r.rng), copy(r.keys))
I am also worried that copy(r.rng) is not sufficient to ensure that the copy is decoupled from the original RNG. Maybe it would be safer (albeit more expensive) to use deepcopy?
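A small sketch of the concern being raised (the `Holder` type is a hypothetical stand-in for a trace-like container, and `Xoshiro` stands in for the inner generator): if the copy ends up aliasing the same mutable rng object, advancing one copy silently advances the other, while `deepcopy` fully decouples the state.

```julia
using Random

# Sketch: aliasing the mutable rng vs. deep-copying it.
mutable struct Holder
    rng::Xoshiro
end

h = Holder(Xoshiro(1))
shallow = Holder(h.rng)   # field-wise "copy": shares the same rng object
deep = deepcopy(h)        # fully decoupled state

rand(h.rng)               # advance the original stream
```

After the draw, `shallow.rng` has advanced too (it is the same object), while `deep.rng` still produces the original stream from the start.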
I think David's already added most of the things I would add -- in general this seems to be fine with a few minor changes needed.
Fix types, import and PR review
Excellent work - many thanks @FredericWantiez @devmotion and @cpfiffer! Perhaps we can wait for a few days to see whether we missed anything, then make a new release.
@@ -1,31 +1,43 @@
- struct Trace{F}
+ struct Trace{F,U,N,V<:Random123.AbstractR123{U}}
Do we dispatch on U, N or V somewhere? Otherwise (and probably even in this case) this could be simplified to

struct Trace{F,R<:TracedRNG}
    f::F
    ctask::Libtask.CTask
    rng::R
end
- consume(pc::ParticleContainer): return incremental likelihood
"""
- mutable struct ParticleContainer{T<:Particle}
+ mutable struct ParticleContainer{T<:Particle,U,N,V<:Random123.AbstractR123{U}}
Same here, I guess we could use

mutable struct ParticleContainer{T<:Particle,R<:TracedRNG}
    vals::Vector{T}
    logWs::Vector{Float64}
    rng::R
end

?
k = split(pi.rng.rng.key)
Random.seed!(pi.rng, k[1])
end
end
I guess we could add an empty return or nothing at the end to avoid that the last seed or rng is returned.
function fpc(logp)
    f = let logp = logp
-       () -> begin
+       (rng) -> begin
The brackets could be omitted.
Still very early, just opening it to track changes/ideas.