diff --git a/ext/ReactantRandom123Ext.jl b/ext/ReactantRandom123Ext.jl index d701fdc7e4..c6e1da89e4 100644 --- a/ext/ReactantRandom123Ext.jl +++ b/ext/ReactantRandom123Ext.jl @@ -4,8 +4,14 @@ using Random123: Threefry4x, Threefry2x, Philox4x, Philox2x using Reactant: TracedRandom TracedRandom.rng_algorithm(::Threefry4x) = "THREE_FRY" -TracedRandom.rng_algorithm(::Threefry2x) = "THREE_FRY" TracedRandom.rng_algorithm(::Philox4x) = "PHILOX" + +TracedRandom.rng_algorithm(::Threefry2x) = "THREE_FRY" +TracedRandom.should_warn_if_not_natively_supported(::Threefry2x) = nothing +TracedRandom.make_seed(rng::Threefry2x) = UInt64[rng.key1, rng.key2] + TracedRandom.rng_algorithm(::Philox2x) = "PHILOX" +TracedRandom.should_warn_if_not_natively_supported(::Philox2x) = nothing +TracedRandom.make_seed(rng::Philox2x) = UInt64[rng.ctr1, rng.ctr2, rng.key] end diff --git a/src/Tracing.jl b/src/Tracing.jl index 04fdc97255..ed9f075a75 100644 --- a/src/Tracing.jl +++ b/src/Tracing.jl @@ -155,6 +155,17 @@ Base.@nospecializeinfer function traced_tuple_type_inner( return Tuple{TT...} end +Base.@nospecializeinfer function traced_type_inner( + @nospecialize(T::Type{<:Tuple}), + seen, + mode::TraceMode, + @nospecialize(track_numbers::Type), + @nospecialize(sharding), + @nospecialize(runtime) +) + return traced_tuple_type_inner(T, seen, mode, track_numbers, sharding, runtime) +end + Base.@nospecializeinfer function traced_type_inner( @nospecialize(T::Core.TypeofVararg), seen, @@ -610,6 +621,33 @@ Base.@nospecializeinfer function traced_type_inner( } end +Base.@nospecializeinfer function traced_type_inner( + @nospecialize(PT::Type{ReactantRNG{S}}), + seen, + @nospecialize(mode::TraceMode), + @nospecialize(track_numbers::Type), + @nospecialize(sharding), + @nospecialize(runtime) +) where {S} + return ReactantRNG{traced_type_inner(S, seen, mode, track_numbers, sharding, runtime)} +end + +Base.@nospecializeinfer function traced_type_inner( + @nospecialize(PT::Type{<:Random.AbstractRNG}), + seen, + @nospecialize(mode::TraceMode), + @nospecialize(track_numbers::Type), + @nospecialize(sharding), + @nospecialize(runtime) +) + if mode == ArrayToConcrete + return ReactantRNG{ + traced_type_inner(Array{UInt64,1}, seen, mode, track_numbers, sharding, runtime) + } + end + return PT +end + Base.@nospecializeinfer function traced_type_inner( @nospecialize(T::Type), seen, @@ -1855,6 +1893,33 @@ Base.@nospecializeinfer function make_tracer( return prev end +Base.@nospecializeinfer function make_tracer( + seen, @nospecialize(prev::ReactantRNG), @nospecialize(path), mode; kwargs... +) + if mode == TracedToTypes + push!(path, Core.Typeof(prev)) + return make_tracer(seen, prev.seed, path, mode; kwargs...) + end + return ReactantRNG( + make_tracer(seen, prev.seed, (path..., :seed), mode; kwargs...), prev.algorithm + ) +end + +Base.@nospecializeinfer function make_tracer( + seen, @nospecialize(prev::Random.AbstractRNG), @nospecialize(path), mode; kwargs... +) + if mode == ArrayToConcrete + TracedRandom.should_warn_if_not_natively_supported(prev) + return ReactantRNG( + make_tracer( + seen, TracedRandom.make_seed(prev), (path..., :seed), mode; kwargs... + ), + TracedRandom.rng_algorithm(prev), + ) + end + return prev +end + @inline function to_rarray( @nospecialize(x); runtime::Union{Nothing,Val{:IFRT},Val{:PJRT}}=nothing, diff --git a/src/stdlibs/Random.jl b/src/stdlibs/Random.jl index 2bda495973..de9d24b0b7 100644 --- a/src/stdlibs/Random.jl +++ b/src/stdlibs/Random.jl @@ -10,8 +10,16 @@ import ..Reactant: ReactantRNG using Random: Random, AbstractRNG -@noinline make_seed(rng::AbstractRNG=Random.RandomDevice()) = - Random.rand!(rng, Vector{UInt64}(undef, 2)) +@noinline function should_warn_if_not_natively_supported(rng::AbstractRNG) + @warn "The RNG $(typeof(rng)) is not natively supported by Reactant. We will convert \ + this to `ReactantRNG` which will have different seed and distribution \ + characteristics." maxlog = 1 + return nothing +end + +@noinline function make_seed(rng::AbstractRNG=Random.RandomDevice()) + return Random.rand!(rng, Vector{UInt64}(undef, 2)) +end @noinline function Random.seed!(rng::ReactantRNG, seed::Number) if seed isa TracedRNumber diff --git a/test/integration/random.jl b/test/integration/random.jl index c0860d1b24..ef5eb7a698 100644 --- a/test/integration/random.jl +++ b/test/integration/random.jl @@ -201,3 +201,32 @@ end ConcreteRArray{Float64,2} @test @jit(rand_on_device()) isa ConcreteRArray{Float32,3} end + +@testset "Tracing of Random" begin + struct RandomContainer{RNG} + rng::RNG + end + + rng_st = RandomContainer(MersenneTwister(0)) + rng_st_ra = Reactant.to_rarray(rng_st) + @test rng_st_ra.rng isa Reactant.ReactantRNG + + fn(st) = rand(st.rng, 10000) + + hlo = @code_hlo fn(rng_st_ra) + @test contains(repr(hlo), "stablehlo.rng_bit_generator") + + @testset "natively supported RNGs" begin + rng = Threefry2x() + rng_ra = Reactant.to_rarray(rng) + @test rng_ra isa Reactant.ReactantRNG + @test rng_ra.seed ≈ UInt64[rng.key1, rng.key2] + @test rng_ra.algorithm == "THREE_FRY" + + rng = Philox2x() + rng_ra = Reactant.to_rarray(rng) + @test rng_ra isa Reactant.ReactantRNG + @test rng_ra.seed ≈ UInt64[rng.ctr1, rng.ctr2, rng.key] + @test rng_ra.algorithm == "PHILOX" + end +end diff --git a/test/nn/lux.jl b/test/nn/lux.jl index f8e9817ae3..3c847ca75c 100644 --- a/test/nn/lux.jl +++ b/test/nn/lux.jl @@ -93,8 +93,6 @@ end end @testset "RNN Integration" begin - using Reactant, Lux, Enzyme, Random - model = Recurrence(RNNCell(4 => 4); ordering=BatchLastIndex()) ps, st = Reactant.to_rarray(Lux.setup(Random.default_rng(), model))