Closed
Description
MLDataDevices adapts a `Random.TaskLocalRNG` to a `Reactant.ReactantRNG`:

```julia
using Adapt, MLDataDevices, Random, Reactant

adapt(ReactantDevice(), Random.default_rng()) isa Reactant.ReactantRNG  # true
```

but `Reactant.to_rarray` currently does not convert it:

```julia
Reactant.to_rarray(Random.default_rng()) isa Random.TaskLocalRNG  # true
```

So when users have a data structure that contains both an RNG and some data, for example
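```julia
using Random

dummy_mcmc_chain = (current = zeros(Float32, 10), rng = Random.default_rng(), buffer = zeros(Float32, 10))
# Each step draws fresh normal noise into `buffer` and adds it to `current`.
dummy_mcmc_step!(chain) = chain.current .+= randn!(chain.rng, chain.buffer)
dummy_mcmc_step!(dummy_mcmc_chain)
dummy_mcmc_step!(dummy_mcmc_chain)
```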
and then move it to Reactant:
```julia
using Reactant

reactant_mcmc_chain = Reactant.to_rarray(dummy_mcmc_chain)  # `rng` stays a TaskLocalRNG
reactant_mcmc_step! = @compile dummy_mcmc_step!(reactant_mcmc_chain)
reactant_mcmc_step!(reactant_mcmc_chain)
```
then the random numbers drawn are the same on every call, because the state of `Random.default_rng()` is captured at compilation time and stored in `reactant_mcmc_step!`.
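The effect can be mimicked in plain Julia, without Reactant. This is only an analogy, not how Reactant actually traces code: the helper names `compile_with_frozen_rng` and `live_step!` are hypothetical. "Compiling" here copies the RNG state once and bakes that snapshot into the returned function, whereas a step that receives the RNG as an argument keeps advancing its state.

```julia
using Random

# Hypothetical analogy: the RNG state is snapshotted once, at "compile time",
# and every call restarts from that same snapshot.
function compile_with_frozen_rng(rng::AbstractRNG)
    snapshot = copy(rng)
    return buffer -> randn!(copy(snapshot), buffer)
end

# Hypothetical contrast: the RNG is passed in and mutated on each call.
live_step!(rng::AbstractRNG, buffer) = randn!(rng, buffer)

frozen_step! = compile_with_frozen_rng(Xoshiro(42))
a = frozen_step!(zeros(Float32, 3))
b = frozen_step!(zeros(Float32, 3))
a == b  # true: identical draws on every call

rng = Xoshiro(42)
c = live_step!(rng, zeros(Float32, 3))
d = live_step!(rng, zeros(Float32, 3))
c == d  # false: the shared RNG state has moved on
```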
If users don't know the details of the code they are compiling, this is an easy trap to fall into. Should we change `Reactant.to_rarray` to convert RNGs, at least when they are `Random.TaskLocalRNG`s?
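Until `to_rarray` handles RNGs, one possible workaround, assuming MLDataDevices and Adapt are installed, is to move the whole chain with `adapt(ReactantDevice(), ...)` instead of `Reactant.to_rarray`, relying on the RNG conversion described at the top of this issue and on Adapt recursing into the `NamedTuple`. This is an untested sketch: it has not been verified that the compiled step then advances the `ReactantRNG` state between calls.

```julia
using Adapt, MLDataDevices, Random, Reactant

dummy_mcmc_chain = (current = zeros(Float32, 10), rng = Random.default_rng(), buffer = zeros(Float32, 10))
dummy_mcmc_step!(chain) = chain.current .+= randn!(chain.rng, chain.buffer)

# Adapt recurses into the NamedTuple: the arrays are moved to Reactant and the
# TaskLocalRNG is converted to a Reactant.ReactantRNG (assumed workaround).
reactant_mcmc_chain = adapt(ReactantDevice(), dummy_mcmc_chain)

reactant_mcmc_step! = @compile dummy_mcmc_step!(reactant_mcmc_chain)
reactant_mcmc_step!(reactant_mcmc_chain)
reactant_mcmc_step!(reactant_mcmc_chain)
```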