Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 10 additions & 32 deletions src/Overlay.jl
Original file line number Diff line number Diff line change
Expand Up @@ -48,13 +48,9 @@ for randfun in (:rand, :randn, :randexp)
if T <: ReactantPrimitive
return TracedRandom.$(overload_randfun)(rng, T, dims)
end
return error(
"Reactant doesn't support sampling of $(T) with the current interpreter."
)
# XXX: The following will lead to illegal instruction
# @warn "Reactant doesn't support sampling of $(T) with the current \
# interpreter. Falling back to native interpreter." maxlog = 1
# return Random.$(randfun)(rng, T, dims)
@warn "Reactant doesn't support sampling of $(T) with the current \
interpreter. Falling back to native interpreter." maxlog = 1
return Base.inferencebarrier(Random.$(randfun))(rng, T, dims)
end

@reactant_overlay @noinline function Random.$(randfun)(
Expand All @@ -69,13 +65,9 @@ for randfun in (:rand, :randn, :randexp)
if T <: ReactantPrimitive
return TracedRandom.$(overload_randfun)(rng, T, dim1, dims...)
end
return error(
"Reactant doesn't support sampling of $(T) with the current interpreter."
)
# XXX: The following will lead to illegal instruction
# @warn "Reactant doesn't support sampling of $(T) with the current \
# interpreter. Falling back to native interpreter." maxlog = 1
# return Random.$(randfun)(rng, T, dim1, dims...)
@warn "Reactant doesn't support sampling of $(T) with the current \
interpreter. Falling back to native interpreter." maxlog = 1
return Base.inferencebarrier(Random.$(randfun))(rng, T, dim1, dims...)
end

# scalars
Expand All @@ -85,13 +77,9 @@ for randfun in (:rand, :randn, :randexp)
if T <: ReactantPrimitive
return TracedRandom.$(overload_randfun)(rng, T)
end
return error(
"Reactant doesn't support sampling of $(T) with the current interpreter."
)
# XXX: The following will lead to illegal instruction
# @warn "Reactant doesn't support sampling of $(T) with the current \
# interpreter. Falling back to native interpreter." maxlog = 1
# return Random.$(randfun)(rng, T)
@warn "Reactant doesn't support sampling of $(T) with the current \
interpreter. Falling back to native interpreter." maxlog = 1
return Base.inferencebarrier(Random.$(randfun))(rng, T)
end

# inplace
Expand All @@ -100,21 +88,11 @@ for randfun in (:rand, :randn, :randexp)
)
return TracedRandom.$(overload_randfun!)(rng, A)
end

# XXX: Uncomment once AbsInt issues with recursive calls are resolved
# @reactant_overlay @noinline function Random.$(randfun!)(
# rng::AbstractRNG, A::AbstractArray
# )
# @warn "Directly writing to an array using Random.jl functions inside \
# ReactantInterpreter will generate a constant array in the IR. Use with \
# caution." maxlog = 1
# return Random.$(randfun!)(rng, A)
# end
end
end

# LinearAlgebra.jl overloads
## `_mul!` goes through too many layers of abstractions and we aren't able to overload
## `mul!` goes through too many layers of abstractions and we aren't able to overload
## without specializing on every possible combination of types
for (cT, aT, bT) in (
(:AbstractVector, :AbstractMatrix, :AbstractVector),
Expand Down
10 changes: 2 additions & 8 deletions src/stdlibs/Random.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,8 @@ using ..Reactant:
unwrapped_eltype
using Random: Random, AbstractRNG

@noinline function make_seed(rng::AbstractRNG=Random.RandomDevice())
# XXX: We should really be able to call this here. But with our AbsInt it leads to a
# segfault. So we'll just call it in the rand! method.
# return rand(rng, UInt64, 2)
seed = Array{UInt64}(undef, 2)
Random.rand!(rng, seed)
return seed
end
@noinline make_seed(rng::AbstractRNG=Random.RandomDevice()) =
Random.rand!(rng, Vector{UInt64}(undef, 2))

@noinline function Random.seed!(rng::TracedRNG, seed::Number)
if seed isa TracedRNumber
Expand Down
Loading