Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion ext/ReactantRandom123Ext.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,14 @@ using Random123: Threefry4x, Threefry2x, Philox4x, Philox2x
using Reactant: TracedRandom

TracedRandom.rng_algorithm(::Threefry4x) = "THREE_FRY"
TracedRandom.rng_algorithm(::Threefry2x) = "THREE_FRY"
TracedRandom.rng_algorithm(::Philox4x) = "PHILOX"

TracedRandom.rng_algorithm(::Threefry2x) = "THREE_FRY"
TracedRandom.should_warn_if_not_natively_supported(::Threefry2x) = nothing
TracedRandom.make_seed(rng::Threefry2x) = UInt64[rng.key1, rng.key2]

TracedRandom.rng_algorithm(::Philox2x) = "PHILOX"
TracedRandom.should_warn_if_not_natively_supported(::Philox2x) = nothing
TracedRandom.make_seed(rng::Philox2x) = UInt64[rng.ctr1, rng.ctr2, rng.key]

end
65 changes: 65 additions & 0 deletions src/Tracing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,17 @@ Base.@nospecializeinfer function traced_tuple_type_inner(
return Tuple{TT...}
end

Base.@nospecializeinfer function traced_type_inner(
@nospecialize(T::Type{<:Tuple}),
seen,
mode::TraceMode,
@nospecialize(track_numbers::Type),
@nospecialize(sharding),
@nospecialize(runtime)
)
return traced_tuple_type_inner(T, seen, mode, track_numbers, sharding, runtime)
end

Base.@nospecializeinfer function traced_type_inner(
@nospecialize(T::Core.TypeofVararg),
seen,
Expand Down Expand Up @@ -610,6 +621,33 @@ Base.@nospecializeinfer function traced_type_inner(
}
end

Base.@nospecializeinfer function traced_type_inner(
@nospecialize(PT::Type{ReactantRNG{S}}),
seen,
@nospecialize(mode::TraceMode),
@nospecialize(track_numbers::Type),
@nospecialize(sharding),
@nospecialize(runtime)
) where {S}
return ReactantRNG{traced_type_inner(S, seen, mode, track_numbers, sharding, runtime)}
end

Base.@nospecializeinfer function traced_type_inner(
@nospecialize(PT::Type{<:Random.AbstractRNG}),
seen,
@nospecialize(mode::TraceMode),
@nospecialize(track_numbers::Type),
@nospecialize(sharding),
@nospecialize(runtime)
)
if mode == ArrayToConcrete
return ReactantRNG{
traced_type_inner(Array{UInt64,1}, seen, mode, track_numbers, sharding, runtime)
}
end
return PT
end

Base.@nospecializeinfer function traced_type_inner(
@nospecialize(T::Type),
seen,
Expand Down Expand Up @@ -1855,6 +1893,33 @@ Base.@nospecializeinfer function make_tracer(
return prev
end

Base.@nospecializeinfer function make_tracer(
seen, @nospecialize(prev::ReactantRNG), @nospecialize(path), mode; kwargs...
)
if mode == TracedToTypes
push!(path, Core.Typeof(prev))
return make_tracer(seen, prev.seed, path, mode; kwargs...)
end
return ReactantRNG(
make_tracer(seen, prev.seed, (path..., :seed), mode; kwargs...), prev.algorithm
)
end

Base.@nospecializeinfer function make_tracer(
seen, @nospecialize(prev::Random.AbstractRNG), @nospecialize(path), mode; kwargs...
)
if mode == ArrayToConcrete
TracedRandom.should_warn_if_not_natively_supported(prev)
return ReactantRNG(
make_tracer(
seen, TracedRandom.make_seed(prev), (path..., :seed), mode; kwargs...
),
TracedRandom.rng_algorithm(prev),
)
end
return prev
end

@inline function to_rarray(
@nospecialize(x);
runtime::Union{Nothing,Val{:IFRT},Val{:PJRT}}=nothing,
Expand Down
12 changes: 10 additions & 2 deletions src/stdlibs/Random.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,16 @@ import ..Reactant: ReactantRNG

using Random: Random, AbstractRNG

@noinline make_seed(rng::AbstractRNG=Random.RandomDevice()) =
Random.rand!(rng, Vector{UInt64}(undef, 2))
@noinline function should_warn_if_not_natively_supported(rng::AbstractRNG)
@warn "The RNG $(typeof(rng)) is not natively supported by Reactant. We will convert \
this to `ReactantRNG` which will have different seed and distribution \
characteristics." maxlog = 1
return nothing
end

@noinline function make_seed(rng::AbstractRNG=Random.RandomDevice())
return Random.rand!(rng, Vector{UInt64}(undef, 2))
end

@noinline function Random.seed!(rng::ReactantRNG, seed::Number)
if seed isa TracedRNumber
Expand Down
29 changes: 29 additions & 0 deletions test/integration/random.jl
Original file line number Diff line number Diff line change
Expand Up @@ -201,3 +201,32 @@ end
ConcreteRArray{Float64,2}
@test @jit(rand_on_device()) isa ConcreteRArray{Float32,3}
end

@testset "Tracing of Random" begin
struct RandomContainer{RNG}
rng::RNG
end

rng_st = RandomContainer(MersenneTwister(0))
rng_st_ra = Reactant.to_rarray(rng_st)
@test rng_st_ra.rng isa Reactant.ReactantRNG

fn(st) = rand(st.rng, 10000)

hlo = @code_hlo fn(rng_st_ra)
@test contains(repr(hlo), "stablehlo.rng_bit_generator")

@testset "natively supported RNGs" begin
rng = Threefry2x()
rng_ra = Reactant.to_rarray(rng)
@test rng_ra isa Reactant.ReactantRNG
@test rng_ra.seed ≈ UInt64[rng.key1, rng.key2]
@test rng_ra.algorithm == "THREE_FRY"

rng = Philox2x()
rng_ra = Reactant.to_rarray(rng)
@test rng_ra isa Reactant.ReactantRNG
@test rng_ra.seed ≈ UInt64[rng.ctr1, rng.ctr2, rng.key]
@test rng_ra.algorithm == "PHILOX"
end
end
2 changes: 0 additions & 2 deletions test/nn/lux.jl
Original file line number Diff line number Diff line change
Expand Up @@ -93,8 +93,6 @@ end
end

@testset "RNN Integration" begin
using Reactant, Lux, Enzyme, Random

model = Recurrence(RNNCell(4 => 4); ordering=BatchLastIndex())
ps, st = Reactant.to_rarray(Lux.setup(Random.default_rng(), model))

Expand Down
Loading