Convert Random.TaskLocalRNG in to_rarray? #1583

@oschulz


MLDataDevices adapts Random.TaskLocalRNG to Reactant.ReactantRNG

adapt(ReactantDevice(), Random.default_rng()) isa Reactant.ReactantRNG

but to_rarray

Reactant.to_rarray(Random.default_rng()) isa Random.TaskLocalRNG

currently doesn't. So when users have a data structure that contains both an RNG and some data, for example

dummy_mcmc_chain = (current = zeros(Float32, 10), rng = Random.default_rng(), buffer = zeros(Float32, 10))
dummy_mcmc_step!(chain) = chain.current .+= randn!(chain.rng, chain.buffer)
dummy_mcmc_step!(chain)
dummy_mcmc_step!(chain)
...

and move it to Reactant

reactant_mcmc_chain = Reactant.to_rarray(dummy_mcmc_chain)
reactant_mcmc_step! = @compile dummy_mcmc_step!(reactant_mcmc_chain)
reactant_mcmc_step!(reactant_mcmc_chain)

the random numbers drawn are always the same, because the state of Random.default_rng() is captured at compilation time and stored in reactant_mcmc_step!.
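To make the effect concrete without involving Reactant, here is a minimal plain-Julia analogy. It is only a sketch: `compile_with_frozen_rng` is a hypothetical helper made up for illustration, not part of Reactant, and it mimics the observed behavior (state frozen once, then replayed) rather than Reactant's actual mechanism.

```julia
using Random

# Hypothetical stand-in for a compiled function (NOT a Reactant API):
# it snapshots the RNG state once, at "compile" time, and every call
# starts again from that same snapshot.
function compile_with_frozen_rng(rng::AbstractRNG)
    snapshot = copy(rng)                              # state captured once
    return buffer -> randn!(copy(snapshot), buffer)   # each call replays the snapshot
end

frozen_step! = compile_with_frozen_rng(Xoshiro(42))
a = frozen_step!(zeros(Float32, 10))
b = frozen_step!(zeros(Float32, 10))
frozen_repeats = (a == b)    # both calls draw identical numbers

# For comparison, an RNG whose state advances between calls.
live_rng = Xoshiro(42)
c = randn!(live_rng, zeros(Float32, 10))
d = randn!(live_rng, zeros(Float32, 10))
live_differs = (c != d)      # successive draws differ
```

The frozen variant is what the compiled reactant_mcmc_step! appears to do with a Random.TaskLocalRNG that to_rarray left unconverted.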

If users don't know the details of the code they are compiling, this may be an easy trap to fall into. Should we change Reactant.to_rarray to convert RNGs, at least when it's a Random.TaskLocalRNG?
