
Conversation

@FredericWantiez (Member)

Still very early, just opening it to track changes/ideas.

@devmotion (Member)

Maybe it would be good to merge #20 first?

@FredericWantiez (Member, Author)

Agree, we need to build on #20 for this fix.

@yebai yebai mentioned this pull request Apr 26, 2021
@yebai (Member)

yebai commented Apr 26, 2021

> Agree, we need to build on #20 for this fix.

I've merged #20.

@codecov

codecov bot commented May 17, 2021

Codecov Report

Merging #23 (7f1b12f) into master (350e2c1) will increase coverage by 4.14%.
The diff coverage is 96.77%.

❗ Current head 7f1b12f differs from pull request most recent head 705d1dc. Consider uploading reports for the commit 705d1dc to get more accurate results

@@            Coverage Diff             @@
##           master      #23      +/-   ##
==========================================
+ Coverage   55.70%   59.85%   +4.14%     
==========================================
  Files           5        6       +1     
  Lines         368      411      +43     
==========================================
+ Hits          205      246      +41     
- Misses        163      165       +2     
Impacted Files      Coverage Δ
src/container.jl     92.36% <94.59%> (-0.27%) ⬇️
src/rng.jl          100.00% <100.00%> (ø)
src/smc.jl           97.22% <100.00%> (ø)

Continue to review full report at Codecov.

Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update 350e2c1...705d1dc.

@coveralls

coveralls commented May 17, 2021

Pull Request Test Coverage Report for Build 1234943898

Warning: This coverage report may be inaccurate.

This pull request's base commit is no longer the HEAD commit of its target branch. This means it includes changes from outside the original pull request, including, potentially, unrelated coverage changes.

Details

  • 60 of 62 (96.77%) changed or added relevant lines in 3 files are covered.
  • 1 unchanged line in 1 file lost coverage.
  • Overall coverage increased (+3.9%) to 59.611%

Changes Missing Coverage   Covered Lines   Changed/Added Lines   %
src/container.jl           35              37                    94.59%

Files with Coverage Reduction   New Missed Lines   %
src/container.jl                1                  91.67%
Totals Coverage Status
Change from base Build 946932953: 3.9%
Covered Lines: 245
Relevant Lines: 411

💛 - Coveralls

@FredericWantiez (Member, Author)

@yebai I think there's an issue with using fork and passing the RNG to Trace. When we fork the particles during the resampling step, the RNG is set to the one we used in the original Trace definition, so all children particles will use it and we won't be able to reconstruct the stream.

Also, as a side question: when do we populate the reference particle when we resample in PG?
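
(As a toy illustration of the issue, not the PR's code: if forked children end up sharing one RNG object, their draws interleave and neither stream can be reconstructed on its own.)

using Random

rng = MersenneTwister(1)
child_a, child_b = rng, rng  # "forked" particles aliasing the same RNG
rand(child_a)                # advances the shared state...
rand(child_b)                # ...so child_b's draw depends on child_a's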

@yebai (Member) left a comment

We'll need a mechanism to split the random stream during the resampling step. The splitting operation

  1. allows the children particles to share the random history before the splitting point,
  2. makes the random streams produce independent random numbers after the splitting point.

More concretely, we need to split the random stream when copying a trace, i.e. fork(trace):

Base.copy(trace::Trace) = Trace(trace.f, copy(trace.ctask))

We need to replay the random stream when replaying a trace, i.e. forkr(trace):

function forkr(trace::Trace)

> Also, as a side question: when do we populate the reference particle when we resample in PG?

I'm not sure I understand the question. The reference particle is incrementally replayed during the forward propagation of a conditional SMC step. If the reference particle has multiple children at a resampling step, one child will inherit the random stream and keep replaying, the other children will split the random stream and start producing new random numbers.
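
As a rough sketch of that rule (purely illustrative, using the fork/forkr names from above rather than the PR's actual resampling code):

# Sketch only: when the reference particle has n children at a resampling
# step, the first child keeps replaying the recorded stream (forkr) and
# the others split off fresh streams (fork).
function children_of_reference(ref, n)
    inherit = forkr(ref)               # inherits the stream, keeps replaying
    splits = [fork(ref) for _ in 2:n]  # new random numbers from here on
    return [inherit; splits]
end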

Also, happy to discuss more when we meet.

src/rng.jl Outdated


# Set seed manually, for init ?
Random.seed!(rng::TracedRNG, seed) = Random.seed!(rng.rng, seed)
Member

Do we store the seed somewhere in TracedRNG when calling this function?

@devmotion (Member)

> We'll need a mechanism to split the random stream during the resampling step. The splitting operation
>
> 1. allows the children particles to share the random history before the splitting point,
> 2. makes the random streams produce independent random numbers after the splitting point.

Isn't this exactly what splittable RNGs are designed for (e.g. https://github.com/google/jax/blob/master/design_notes/prng.md#design, https://dl.acm.org/doi/abs/10.1145/2660193.2660195, and https://publications.lib.chalmers.se/records/fulltext/183348/local_183348.pdf)? The last reference also discusses why naive approaches for splitting linear PRNGs fail or make it hard to assess the quality of the returned random numbers (I am not sufficiently familiar with this area to determine whether all of these concerns are valid).
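
For a concrete counter-based example (a minimal sketch using Random123.jl, which this PR builds on; the replay-by-rewinding behaviour shown is my assumption about the API):

using Random
using Random123: Philox2x, set_counter!

rng = Philox2x(0x1234)  # counter-based RNG: output is a function of (key, counter)
x = rand(rng, 4)

set_counter!(rng, 0)    # rewinding the counter replays the same stream
y = rand(rng, 4)
@assert x == y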

@yebai (Member)

yebai commented Jun 8, 2021

yes, we are planning to use splittable RNGs.

@FredericWantiez (Member, Author)

I looked at the random module in TensorFlow recently; they also use a counter-based RNG and a split API: https://github.com/tensorflow/tensorflow/blob/v2.5.0/tensorflow/python/ops/stateful_random_ops.py#L933-L981

We could do something similar:

using Random123

mutable struct TracedRNG{R,T<:Random123.AbstractR123{R}}
    count::Int      # current model step
    rng::T          # inner counter-based RNG
    keys::Array{R}  # history of keys, kept so the stream can be replayed
end

TracedRNG(r::Random123.AbstractR123) = TracedRNG(0, r, [r.key])

function split(r::TracedRNG, n::Integer)
    return map(1:n) do i
        # derive a fresh key for each child from the parent's key and counter
        new_seed = hash(r.rng.key, r.rng.ctr1 + i)
        TracedRNG(r.count, Philox2x(new_seed), [r.keys..., new_seed])
    end
end

rng = TracedRNG(Philox2x())
a, b = split(rng, 2)

To replay the random numbers, we should be able to cycle through the keys and reset the counters.
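
For instance (a hypothetical helper on top of the sketch above; replay_segment is an illustrative name, not part of the proposal):

# Re-seed the inner RNG with the i-th stored key and rewind its counter,
# then re-emit n draws from that segment of the particle's history.
function replay_segment(r::TracedRNG, i::Integer, n::Integer)
    Random.seed!(r.rng, r.keys[i])  # jump back to the i-th key
    set_counter!(r.rng, 0)          # rewind the counter for that key
    return rand(r.rng, n)           # replay n draws
end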

@FredericWantiez (Member, Author)

@yebai I simplified the API slightly: we don't need to track both the model counter and the RNG counters. We can use the model counter as long as every particle has a unique key/id. I also fixed the naming and the resampling step, which was wrong, and added a few tests. I can add some more; I don't think we catch all the edge cases yet.

@yebai yebai changed the title [WIP] Fix forkr - Handle rng in Trace Fix forkr - Handle rng in Trace Aug 30, 2021
@yebai (Member) left a comment

Thanks @FredericWantiez - excellent work. I have gone through the rng.jl file and made some comments along the way. I'll take a look at the other parts slightly later.

In particular, I find the function pair split and update_rng! slightly confusing while reading the code. Given that the split and update_rng! functions are always used in pairs, I wonder whether we can merge them into split!(rng::TracedRNG)::TracedRNG for better clarity. This split! function should always split a given RNG once, mutate it in place, and return the mutated RNG. If we need to split N times, we simply call split! N times repeatedly. Below is an example implementation of the split! function:

function split!(rng::TracedRNG)
    key = hash(rng.rng.key, convert(UInt, rng.rng.ctr1))
    seed!(rng, key)
    return set_counter!(rng.rng, rng.count + 1) # rng.count + 1 is used instead of rng.count
end

In order to ensure that calling split! repetitively with the same RNG will return new RNGs with different keys, we increase the internal counter by one every time a split is performed. This way, the next time split! is called on the same RNG, it will get a different key because rng.rng.ctr1 will be different.

A side question: do we actually need to set rng.rng.ctr1 to rng.count when performing a split? If so, is there a reference on this?

@FredericWantiez (Member, Author)

@devmotion @yebai Fixed the last few things we discussed last week, let me know what you think.

src/rng.jl Outdated
import Random123: set_counter!

# Default RNG type for when nothing is specified
_BASE_RNG = Philox2x
Member

Suggested change
_BASE_RNG = Philox2x
const _BASE_RNG = Philox2x

src/rng.jl Outdated
Set key and counter of inner RNG to key and the running model step
"""
function seed!(rng::TracedRNG{T}, key) where {T}
Member

Suggested change
function seed!(rng::TracedRNG{T}, key) where {T}
function seed!(rng::TracedRNG, key)

src/rng.jl Outdated
Track current key of the inner RNG
"""
function save_state!(r::TracedRNG{T}) where {T}
Member

Suggested change
function save_state!(r::TracedRNG{T}) where {T}
function save_state!(r::TracedRNG)

src/rng.jl Outdated
return push!(r.keys, r.rng.key)
end

Base.copy(r::TracedRNG{T}) where {T} = TracedRNG(r.count, copy(r.rng), copy(r.keys))
Member

Suggested change
Base.copy(r::TracedRNG{T}) where {T} = TracedRNG(r.count, copy(r.rng), copy(r.keys))
Base.copy(r::TracedRNG) = TracedRNG(r.count, copy(r.rng), copy(r.keys))

I am also worried that copy(r.rng) is not sufficient to ensure that the copy is decoupled from the original RNG. Maybe it would be safer (albeit more expensive) to use deepcopy?
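
For instance (just a sketch of the safer variant):

# deepcopy the inner RNG so the copy cannot alias any mutable state
# of the original.
Base.copy(r::TracedRNG) = TracedRNG(r.count, deepcopy(r.rng), copy(r.keys))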

@yebai yebai requested a review from cpfiffer September 10, 2021 09:09
@cpfiffer (Member) left a comment

I think David's already added most of the things I would add -- in general this seems to be fine with a few minor changes needed.

@yebai yebai merged commit 51c18d4 into TuringLang:master Sep 14, 2021
@yebai (Member)

yebai commented Sep 14, 2021

Excellent work - many thanks @FredericWantiez @devmotion and @cpfiffer!

Perhaps we can wait for a few days to see whether we missed anything, then make a new release.

@@ -1,31 +1,43 @@
-struct Trace{F}
+struct Trace{F,U,N,V<:Random123.AbstractR123{U}}
Member

Do we dispatch on U, N or V somewhere? Otherwise (and probably even in this case) this could be simplified to

struct Trace{F,R<:TracedRNG}
    f::F
    ctask::Libtask.CTask
    rng::R
end

- consume(pc::ParticleContainer): return incremental likelihood
"""
-mutable struct ParticleContainer{T<:Particle}
+mutable struct ParticleContainer{T<:Particle,U,N,V<:Random123.AbstractR123{U}}
Member

Same here, I guess we could use

mutable struct ParticleContainer{T<:Particle,R<:TracedRNG}
    vals::Vector{T}
    logWs::Vector{Float64}
    rng::R
end

?

k = split(pi.rng.rng.key)
Random.seed!(pi.rng, k[1])
end
end
Member

I guess we could add an empty return or nothing at the end to avoid returning the last seed or rng.

function fpc(logp)
    f = let logp = logp
-       () -> begin
+       (rng) -> begin
Member

The brackets could be omitted.

@FredericWantiez FredericWantiez mentioned this pull request Sep 25, 2021