diff --git a/src/Tracing.jl b/src/Tracing.jl index ed9f075a75..d614fc7c3b 100644 --- a/src/Tracing.jl +++ b/src/Tracing.jl @@ -1901,7 +1901,7 @@ Base.@nospecializeinfer function make_tracer( return make_tracer(seen, prev.seed, path, mode; kwargs...) end return ReactantRNG( - make_tracer(seen, prev.seed, (path..., :seed), mode; kwargs...), prev.algorithm + make_tracer(seen, prev.seed, (path..., 1), mode; kwargs...), prev.algorithm ) end @@ -1911,9 +1911,7 @@ Base.@nospecializeinfer function make_tracer( if mode == ArrayToConcrete TracedRandom.should_warn_if_not_natively_supported(prev) return ReactantRNG( - make_tracer( - seen, TracedRandom.make_seed(prev), (path..., :seed), mode; kwargs... - ), + make_tracer(seen, TracedRandom.make_seed(prev), (path..., 1), mode; kwargs...), TracedRandom.rng_algorithm(prev), ) end diff --git a/test/nn/lux.jl b/test/nn/lux.jl index 3c847ca75c..62450521c9 100644 --- a/test/nn/lux.jl +++ b/test/nn/lux.jl @@ -102,3 +102,13 @@ end res, ∂ps = @jit gradient_loss_function(model, x, ps, st) @test res isa Reactant.ConcreteRNumber end + +@testset "RNG stored in state" begin + model = Dropout(0.5f0) + ps, st = Reactant.to_rarray(Lux.setup(Random.default_rng(), model)) + + x = Reactant.to_rarray(randn(Float32, 10, 10)) + + res, st_new = @jit model(x, ps, st) + @test st_new.rng isa Reactant.ReactantRNG +end