From e44fbecbd6278f417caaf1c295a585b4885de7c0 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 17 Oct 2025 21:08:01 -0400 Subject: [PATCH 1/6] feat: convert Random.AbstractRNG to ReactantRNG in make_tracer --- src/Tracing.jl | 61 ++++++++++++++++++++++++++++++++++++++ test/integration/random.jl | 15 ++++++++++ 2 files changed, 76 insertions(+) diff --git a/src/Tracing.jl b/src/Tracing.jl index 04fdc97255..70457f5f1b 100644 --- a/src/Tracing.jl +++ b/src/Tracing.jl @@ -610,6 +610,41 @@ Base.@nospecializeinfer function traced_type_inner( } end +Base.@nospecializeinfer function traced_type_inner( + @nospecialize(PT::Type{ReactantRNG{S}}), + seen, + @nospecialize(mode::TraceMode), + @nospecialize(track_numbers::Type), + @nospecialize(sharding), + @nospecialize(runtime) +) where {S} + return ReactantRNG{traced_type_inner(S, seen, mode, track_numbers, sharding, runtime)} +end + +Base.@nospecializeinfer function traced_type_inner( + @nospecialize(PT::Type{<:Random.AbstractRNG}), + seen, + @nospecialize(mode::TraceMode), + @nospecialize(track_numbers::Type), + @nospecialize(sharding), + @nospecialize(runtime) +) + mode == ConcreteToTraced && throw("$(typeof(prev)) is not a concrete type") + if mode in ( + TracedTrack, + TracedToConcrete, + TracedSetPath, + TracedToTypes, + NoStopTracedTrack, + TracedToJAX, + ) + throw("$(typeof(prev)) is not a traced type") + end + return ReactantRNG{ + traced_type_inner(Array{UInt64,1}, seen, mode, track_numbers, sharding, runtime) + } +end + Base.@nospecializeinfer function traced_type_inner( @nospecialize(T::Type), seen, @@ -1855,6 +1890,32 @@ Base.@nospecializeinfer function make_tracer( return prev end +Base.@nospecializeinfer function make_tracer( + seen, @nospecialize(prev::ReactantRNG), @nospecialize(path), mode; kwargs... +) + return ReactantRNG(make_tracer(seen, prev.seed, path, mode; kwargs...), prev.algorithm) +end + +Base.@nospecializeinfer function make_tracer( + seen, @nospecialize(prev::Random.AbstractRNG), @nospecialize(path), mode; kwargs... +) + mode == ConcreteToTraced && throw("$(typeof(prev)) is not a concrete type") + if mode in ( + TracedTrack, + TracedToConcrete, + TracedSetPath, + TracedToTypes, + NoStopTracedTrack, + TracedToJAX, + ) + throw("$(typeof(prev)) is not a traced type") + end + return ReactantRNG( + make_tracer(seen, TracedRandom.make_seed(prev), path, mode; kwargs...), + TracedRandom.rng_algorithm(prev), + ) +end + @inline function to_rarray( @nospecialize(x); runtime::Union{Nothing,Val{:IFRT},Val{:PJRT}}=nothing, diff --git a/test/integration/random.jl b/test/integration/random.jl index c0860d1b24..08d1e595fb 100644 --- a/test/integration/random.jl +++ b/test/integration/random.jl @@ -201,3 +201,18 @@ end ConcreteRArray{Float64,2} @test @jit(rand_on_device()) isa ConcreteRArray{Float32,3} end + +@testset "Tracing of Random" begin + struct RandomContainer{RNG} + rng::RNG + end + + rng_st = RandomContainer(MersenneTwister(0)) + rng_st_ra = Reactant.to_rarray(rng_st) + @test rng_st_ra.rng isa Reactant.ReactantRNG + + fn(st) = rand(st.rng, 10000) + + hlo = @code_hlo fn(rng_st_ra) + @test contains(repr(hlo), "stablehlo.rng_bit_generator") +end From 1d0b98f3f743753c44fff0e602835145bc5a5624 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 17 Oct 2025 22:11:23 -0400 Subject: [PATCH 2/6] fix: path --- src/Tracing.jl | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/Tracing.jl b/src/Tracing.jl index 70457f5f1b..e81396ff94 100644 --- a/src/Tracing.jl +++ b/src/Tracing.jl @@ -1893,7 +1893,9 @@ end Base.@nospecializeinfer function make_tracer( seen, @nospecialize(prev::ReactantRNG), @nospecialize(path), mode; kwargs... ) - return ReactantRNG(make_tracer(seen, prev.seed, path, mode; kwargs...), prev.algorithm) + return ReactantRNG( + make_tracer(seen, prev.seed, (path..., :seed), mode; kwargs...), prev.algorithm + ) end Base.@nospecializeinfer function make_tracer( @@ -1911,7 +1913,7 @@ Base.@nospecializeinfer function make_tracer( throw("$(typeof(prev)) is not a traced type") end return ReactantRNG( - make_tracer(seen, TracedRandom.make_seed(prev), path, mode; kwargs...), + make_tracer(seen, TracedRandom.make_seed(prev), (path..., :seed), mode; kwargs...), TracedRandom.rng_algorithm(prev), ) end From d6386d3f4fc0a61bd05e30e3d11798b5f857caa2 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 17 Oct 2025 22:51:08 -0400 Subject: [PATCH 3/6] fix: more cases --- src/Tracing.jl | 38 +++++++++++--------------------------- 1 file changed, 11 insertions(+), 27 deletions(-) diff --git a/src/Tracing.jl b/src/Tracing.jl index e81396ff94..74c9d678ab 100644 --- a/src/Tracing.jl +++ b/src/Tracing.jl @@ -629,20 +629,12 @@ Base.@nospecializeinfer function traced_type_inner( @nospecialize(sharding), @nospecialize(runtime) ) - mode == ConcreteToTraced && throw("$(typeof(prev)) is not a concrete type") - if mode in ( - TracedTrack, - TracedToConcrete, - TracedSetPath, - TracedToTypes, - NoStopTracedTrack, - TracedToJAX, - ) - throw("$(typeof(prev)) is not a traced type") + if mode == ArrayToConcrete + return ReactantRNG{ + traced_type_inner(Array{UInt64,1}, seen, mode, track_numbers, sharding, runtime) + } end - return ReactantRNG{ - traced_type_inner(Array{UInt64,1}, seen, mode, track_numbers, sharding, runtime) - } + return PT end Base.@nospecializeinfer function traced_type_inner( @@ -1901,21 +1893,13 @@ end Base.@nospecializeinfer function make_tracer( seen, @nospecialize(prev::Random.AbstractRNG), @nospecialize(path), mode; kwargs... ) - mode == ConcreteToTraced && throw("$(typeof(prev)) is not a concrete type") - if mode in ( - TracedTrack, - TracedToConcrete, - TracedSetPath, - TracedToTypes, - NoStopTracedTrack, - TracedToJAX, - ) - throw("$(typeof(prev)) is not a traced type") + if mode == ArrayToConcrete + return ReactantRNG( + make_tracer(seen, TracedRandom.make_seed(prev), (path..., :seed), mode; kwargs...), + TracedRandom.rng_algorithm(prev), + ) end - return ReactantRNG( - make_tracer(seen, TracedRandom.make_seed(prev), (path..., :seed), mode; kwargs...), - TracedRandom.rng_algorithm(prev), - ) + return prev end @inline function to_rarray( From 31ba1d2cc16858da2ebd427a3f2730d42482280b Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 17 Oct 2025 23:10:49 -0400 Subject: [PATCH 4/6] Update src/Tracing.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/Tracing.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/Tracing.jl b/src/Tracing.jl index 74c9d678ab..9baf8bb286 100644 --- a/src/Tracing.jl +++ b/src/Tracing.jl @@ -1895,7 +1895,9 @@ Base.@nospecializeinfer function make_tracer( ) if mode == ArrayToConcrete return ReactantRNG( - make_tracer(seen, TracedRandom.make_seed(prev), (path..., :seed), mode; kwargs...), + make_tracer( + seen, TracedRandom.make_seed(prev), (path..., :seed), mode; kwargs... + ), TracedRandom.rng_algorithm(prev), ) end From ccb9af40ac44b1023b57c5f5f96fbebaeeb77107 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 18 Oct 2025 09:55:03 -0400 Subject: [PATCH 5/6] fix: tuple tracing --- src/Tracing.jl | 15 +++++++++++++++ test/nn/lux.jl | 2 -- 2 files changed, 15 insertions(+), 2 deletions(-) diff --git a/src/Tracing.jl b/src/Tracing.jl index 9baf8bb286..31ed05495b 100644 --- a/src/Tracing.jl +++ b/src/Tracing.jl @@ -155,6 +155,17 @@ Base.@nospecializeinfer function traced_tuple_type_inner( return Tuple{TT...} end +Base.@nospecializeinfer function traced_type_inner( + @nospecialize(T::Type{<:Tuple}), + seen, + mode::TraceMode, + @nospecialize(track_numbers::Type), + @nospecialize(sharding), + @nospecialize(runtime) +) + return traced_tuple_type_inner(T, seen, mode, track_numbers, sharding, runtime) +end + Base.@nospecializeinfer function traced_type_inner( @nospecialize(T::Core.TypeofVararg), seen, @@ -1885,6 +1896,10 @@ end Base.@nospecializeinfer function make_tracer( seen, @nospecialize(prev::ReactantRNG), @nospecialize(path), mode; kwargs... ) + if mode == TracedToTypes + push!(path, Core.Typeof(prev)) + return make_tracer(seen, prev.seed, path, mode; kwargs...) + end return ReactantRNG( make_tracer(seen, prev.seed, (path..., :seed), mode; kwargs...), prev.algorithm ) diff --git a/test/nn/lux.jl b/test/nn/lux.jl index f8e9817ae3..3c847ca75c 100644 --- a/test/nn/lux.jl +++ b/test/nn/lux.jl @@ -93,8 +93,6 @@ end end @testset "RNN Integration" begin - using Reactant, Lux, Enzyme, Random - model = Recurrence(RNNCell(4 => 4); ordering=BatchLastIndex()) ps, st = Reactant.to_rarray(Lux.setup(Random.default_rng(), model)) From b66b621a3945451a7a1fd5535f7c2baa72144a10 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 18 Oct 2025 12:10:30 -0400 Subject: [PATCH 6/6] fix: warn if making a semantically different conversion --- ext/ReactantRandom123Ext.jl | 8 +++++++- src/Tracing.jl | 1 + src/stdlibs/Random.jl | 12 ++++++++++-- test/integration/random.jl | 14 ++++++++++++++ 4 files changed, 32 insertions(+), 3 deletions(-) diff --git a/ext/ReactantRandom123Ext.jl b/ext/ReactantRandom123Ext.jl index d701fdc7e4..c6e1da89e4 100644 --- a/ext/ReactantRandom123Ext.jl +++ b/ext/ReactantRandom123Ext.jl @@ -4,8 +4,14 @@ using Random123: Threefry4x, Threefry2x, Philox4x, Philox2x using Reactant: TracedRandom TracedRandom.rng_algorithm(::Threefry4x) = "THREE_FRY" -TracedRandom.rng_algorithm(::Threefry2x) = "THREE_FRY" TracedRandom.rng_algorithm(::Philox4x) = "PHILOX" + +TracedRandom.rng_algorithm(::Threefry2x) = "THREE_FRY" +TracedRandom.should_warn_if_not_natively_supported(::Threefry2x) = nothing +TracedRandom.make_seed(rng::Threefry2x) = UInt64[rng.key1, rng.key2] + TracedRandom.rng_algorithm(::Philox2x) = "PHILOX" +TracedRandom.should_warn_if_not_natively_supported(::Philox2x) = nothing +TracedRandom.make_seed(rng::Philox2x) = UInt64[rng.ctr1, rng.ctr2, rng.key] end diff --git a/src/Tracing.jl b/src/Tracing.jl index 31ed05495b..ed9f075a75 100644 --- a/src/Tracing.jl +++ b/src/Tracing.jl @@ -1909,6 +1909,7 @@ Base.@nospecializeinfer function make_tracer( seen, @nospecialize(prev::Random.AbstractRNG), @nospecialize(path), mode; kwargs... ) if mode == ArrayToConcrete + TracedRandom.should_warn_if_not_natively_supported(prev) return ReactantRNG( make_tracer( seen, TracedRandom.make_seed(prev), (path..., :seed), mode; kwargs... diff --git a/src/stdlibs/Random.jl b/src/stdlibs/Random.jl index 2bda495973..de9d24b0b7 100644 --- a/src/stdlibs/Random.jl +++ b/src/stdlibs/Random.jl @@ -10,8 +10,16 @@ import ..Reactant: ReactantRNG using Random: Random, AbstractRNG -@noinline make_seed(rng::AbstractRNG=Random.RandomDevice()) = - Random.rand!(rng, Vector{UInt64}(undef, 2)) +@noinline function should_warn_if_not_natively_supported(rng::AbstractRNG) + @warn "The RNG $(typeof(rng)) is not natively supported by Reactant. We will convert \ + this to `ReactantRNG` which will have different seed and distribution \ + characteristics." maxlog = 1 + return nothing +end + +@noinline function make_seed(rng::AbstractRNG=Random.RandomDevice()) + return Random.rand!(rng, Vector{UInt64}(undef, 2)) +end @noinline function Random.seed!(rng::ReactantRNG, seed::Number) if seed isa TracedRNumber diff --git a/test/integration/random.jl b/test/integration/random.jl index 08d1e595fb..ef5eb7a698 100644 --- a/test/integration/random.jl +++ b/test/integration/random.jl @@ -215,4 +215,18 @@ end hlo = @code_hlo fn(rng_st_ra) @test contains(repr(hlo), "stablehlo.rng_bit_generator") + + @testset "natively supported RNGs" begin + rng = Threefry2x() + rng_ra = Reactant.to_rarray(rng) + @test rng_ra isa Reactant.ReactantRNG + @test rng_ra.seed ≈ UInt64[rng.key1, rng.key2] + @test rng_ra.algorithm == "THREE_FRY" + + rng = Philox2x() + rng_ra = Reactant.to_rarray(rng) + @test rng_ra isa Reactant.ReactantRNG + @test rng_ra.seed ≈ UInt64[rng.ctr1, rng.ctr2, rng.key] + @test rng_ra.algorithm == "PHILOX" + end end