Skip to content
64 changes: 35 additions & 29 deletions src/Overlay.jl
Original file line number Diff line number Diff line change
Expand Up @@ -230,36 +230,42 @@ end
end

# LinearAlgebra
@reactant_overlay @noinline function LinearAlgebra.lu(x::AbstractArray; kwargs...)
if use_overlayed_version(x)
return TracedLinearAlgebra.overloaded_lu(x, RowMaximum(); kwargs...)
else
return Base.inferencebarrier(LinearAlgebra.lu)(x; kwargs...)
end
end
@reactant_overlay @noinline function LinearAlgebra.lu(
x::AbstractArray, pivot::RowMaximum; kwargs...
)
if use_overlayed_version(x)
return TracedLinearAlgebra.overloaded_lu(x, pivot; kwargs...)
else
return Base.inferencebarrier(LinearAlgebra.lu)(x, pivot; kwargs...)
end
end
@reactant_overlay @noinline function LinearAlgebra.lu!(x::AbstractArray; kwargs...)
if use_overlayed_version(x)
return TracedLinearAlgebra.overloaded_lu(x, RowMaximum(); kwargs...)
else
return Base.inferencebarrier(LinearAlgebra.lu!)(x; kwargs...)
end
end
@reactant_overlay @noinline function LinearAlgebra.lu!(
x::AbstractArray, pivot::RowMaximum; kwargs...
## Various factorizations
## TODO: specialize for `cholesky!` --> cholcopy
factorization_copy(f::F, x, pivot) where {F} = x
factorization_copy(f::F, x) where {F} = x

for (jlop, rop, default_pivot) in (
(:lu, :overloaded_lu, RowMaximum),
(:lu!, :overloaded_lu, RowMaximum),
(:cholesky, :overloaded_cholesky, NoPivot),
(:cholesky!, :overloaded_cholesky, NoPivot),
)
if use_overlayed_version(x)
return TracedLinearAlgebra.overloaded_lu(x, pivot; kwargs...)
else
return Base.inferencebarrier(LinearAlgebra.lu!)(x, pivot; kwargs...)
@eval begin
@reactant_overlay @noinline function LinearAlgebra.$(jlop)(
x::AbstractArray; kwargs...
)
if use_overlayed_version(x)
pivot = $(default_pivot)()
return TracedLinearAlgebra.$(rop)(
factorization_copy(LinearAlgebra.$(jlop), x, pivot), pivot; kwargs...
)
else
return Base.inferencebarrier(LinearAlgebra.$(jlop))(x; kwargs...)
end
end

@reactant_overlay @noinline function LinearAlgebra.$(jlop)(
x::AbstractArray, pivot::$(default_pivot); kwargs...
)
if use_overlayed_version(x)
return TracedLinearAlgebra.$(rop)(
factorization_copy(LinearAlgebra.$(jlop), x, pivot), pivot; kwargs...
)
else
return Base.inferencebarrier(LinearAlgebra.$(jlop))(x, pivot; kwargs...)
end
end
end
end

Expand Down
2 changes: 1 addition & 1 deletion src/Reactant.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ module Reactant
using ReactantCore:
ReactantCore, @trace, within_compile, MissingTracedValue, materialize_traced_array

using LinearAlgebra: LinearAlgebra, RowMaximum
using LinearAlgebra: LinearAlgebra, RowMaximum, NoPivot
using Random: Random, AbstractRNG
using EnumX: @enumx
using Functors: Functors, @leaf
Expand Down
17 changes: 13 additions & 4 deletions src/TracedRArray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -730,11 +730,20 @@ end

# stack
function overloaded_stack(dims::Union{Integer,Colon}, xs)
@assert allequal([ndims(x) for x in xs]) "All arrays must have the same number of \
dimensions..."
dims = dims isa Colon ? ndims(first(xs)) + 1 : dims
dims = dims isa Colon ? nothing : dims
res = []
for x in xs
prev_dims = nothing
for x in unwrapped_broadcast(identity, xs)
cur_dims = ndims(x)
if prev_dims === nothing
prev_dims = cur_dims
else
@assert prev_dims == cur_dims "All arrays must have the same number of \
dimensions..."
end

dims === nothing && (dims = cur_dims + 1)

new_shape = ntuple(
i -> i == dims ? 1 : (i < dims ? size(x, i) : size(x, i - 1)), ndims(x) + 1
)
Expand Down
12 changes: 12 additions & 0 deletions src/TracedRNumber.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,25 @@ Base.copy(x::TracedRNumber{T}) where {T} = TracedRNumber{T}((), x.mlir_data)
function Base.eps(::Type{TracedRNumber{T}}) where {T}
return Reactant.promote_to(TracedRNumber{T}, eps(T))
end
Base.eps(x::TracedRNumber{T}) where {T} = eps(typeof(x))

function Base.typemin(::Type{TracedRNumber{T}}) where {T}
return Reactant.promote_to(TracedRNumber{T}, typemin(T))
end
Base.typemin(x::TracedRNumber{T}) where {T} = typemin(typeof(x))

function Base.typemax(::Type{TracedRNumber{T}}) where {T}
return Reactant.promote_to(TracedRNumber{T}, typemax(T))
end
Base.typemax(x::TracedRNumber{T}) where {T} = typemax(typeof(x))

function Base.nextfloat(x::TracedRNumber{T}) where {T<:AbstractFloat}
return @opcall next_after(x, typemax(x))
end

function Base.prevfloat(x::TracedRNumber{T}) where {T<:AbstractFloat}
return @opcall next_after(x, typemin(x))
end

function Base.rtoldefault(T::Type{<:TracedRNumber})
return T(Base.rtoldefault(unwrapped_eltype(T)))
Expand Down
Loading
Loading