src/Overlay.jl: 64 changes (35 additions, 29 deletions)
@@ -230,36 +230,42 @@ end
 end
 
 # LinearAlgebra
-@reactant_overlay @noinline function LinearAlgebra.lu(x::AbstractArray; kwargs...)
-    if use_overlayed_version(x)
-        return TracedLinearAlgebra.overloaded_lu(x, RowMaximum(); kwargs...)
-    else
-        return Base.inferencebarrier(LinearAlgebra.lu)(x; kwargs...)
-    end
-end
-@reactant_overlay @noinline function LinearAlgebra.lu(
-    x::AbstractArray, pivot::RowMaximum; kwargs...
-)
-    if use_overlayed_version(x)
-        return TracedLinearAlgebra.overloaded_lu(x, pivot; kwargs...)
-    else
-        return Base.inferencebarrier(LinearAlgebra.lu)(x, pivot; kwargs...)
-    end
-end
-@reactant_overlay @noinline function LinearAlgebra.lu!(x::AbstractArray; kwargs...)
-    if use_overlayed_version(x)
-        return TracedLinearAlgebra.overloaded_lu(x, RowMaximum(); kwargs...)
-    else
-        return Base.inferencebarrier(LinearAlgebra.lu!)(x; kwargs...)
-    end
-end
-@reactant_overlay @noinline function LinearAlgebra.lu!(
-    x::AbstractArray, pivot::RowMaximum; kwargs...
+## Various factorizations
+## TODO: specialize for `cholesky!` --> cholcopy
+factorization_copy(f::F, x, pivot) where {F} = x
+factorization_copy(f::F, x) where {F} = x
+
+for (jlop, rop, default_pivot) in (
+    (:lu, :overloaded_lu, RowMaximum),
+    (:lu!, :overloaded_lu, RowMaximum),
+    (:cholesky, :overloaded_cholesky, NoPivot),
+    (:cholesky!, :overloaded_cholesky, NoPivot),
 )
-    if use_overlayed_version(x)
-        return TracedLinearAlgebra.overloaded_lu(x, pivot; kwargs...)
-    else
-        return Base.inferencebarrier(LinearAlgebra.lu!)(x, pivot; kwargs...)
+    @eval begin
+        @reactant_overlay @noinline function LinearAlgebra.$(jlop)(
+            x::AbstractArray; kwargs...
+        )
+            if use_overlayed_version(x)
+                pivot = $(default_pivot)()
+                return TracedLinearAlgebra.$(rop)(
+                    factorization_copy(LinearAlgebra.$(jlop), x, pivot), pivot; kwargs...
+                )
+            else
+                return Base.inferencebarrier(LinearAlgebra.$(jlop))(x; kwargs...)
+            end
+        end
+
+        @reactant_overlay @noinline function LinearAlgebra.$(jlop)(
+            x::AbstractArray, pivot::$(default_pivot); kwargs...
+        )
+            if use_overlayed_version(x)
+                return TracedLinearAlgebra.$(rop)(
+                    factorization_copy(LinearAlgebra.$(jlop), x, pivot), pivot; kwargs...
+                )
+            else
+                return Base.inferencebarrier(LinearAlgebra.$(jlop))(x, pivot; kwargs...)
+            end
+        end
     end
 end
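
Note: the `for`/`@eval` loop above generates one pair of `@reactant_overlay` methods per factorization. As a reading aid, here is roughly what the iteration for `(:cholesky, :overloaded_cholesky, NoPivot)` expands to. This is a hand-expanded sketch using the names from the diff (`@reactant_overlay`, `use_overlayed_version`, `factorization_copy`, `TracedLinearAlgebra.overloaded_cholesky`) and is not itself part of the change.

```julia
# Approximate expansion of one loop iteration with jlop = :cholesky,
# rop = :overloaded_cholesky, default_pivot = NoPivot (illustrative only).
@reactant_overlay @noinline function LinearAlgebra.cholesky(x::AbstractArray; kwargs...)
    if use_overlayed_version(x)
        pivot = NoPivot()
        # factorization_copy currently returns `x` unchanged; the TODO above notes that
        # `cholesky!` should eventually copy via `cholcopy`.
        return TracedLinearAlgebra.overloaded_cholesky(
            factorization_copy(LinearAlgebra.cholesky, x, pivot), pivot; kwargs...
        )
    else
        # Non-traced inputs fall back to the stock LinearAlgebra method.
        return Base.inferencebarrier(LinearAlgebra.cholesky)(x; kwargs...)
    end
end
```

The second generated method is identical except that it accepts an explicit `pivot::NoPivot` argument and skips the default-pivot construction.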

src/Reactant.jl: 2 changes (1 addition, 1 deletion)
@@ -3,7 +3,7 @@ module Reactant
 using ReactantCore:
     ReactantCore, @trace, within_compile, MissingTracedValue, materialize_traced_array
 
-using LinearAlgebra: LinearAlgebra, RowMaximum
+using LinearAlgebra: LinearAlgebra, RowMaximum, NoPivot
 using Random: Random, AbstractRNG
 using EnumX: @enumx
 using Functors: Functors, @leaf
src/TracedRArray.jl: 17 changes (13 additions, 4 deletions)
@@ -730,11 +730,20 @@ end
 
 # stack
 function overloaded_stack(dims::Union{Integer,Colon}, xs)
-    @assert allequal([ndims(x) for x in xs]) "All arrays must have the same number of \
-                                              dimensions..."
-    dims = dims isa Colon ? ndims(first(xs)) + 1 : dims
+    dims = dims isa Colon ? nothing : dims
     res = []
-    for x in xs
+    prev_dims = nothing
+    for x in unwrapped_broadcast(identity, xs)
+        cur_dims = ndims(x)
+        if prev_dims === nothing
+            prev_dims = cur_dims
+        else
+            @assert prev_dims == cur_dims "All arrays must have the same number of \
+                                           dimensions..."
+        end
+
+        dims === nothing && (dims = cur_dims + 1)
+
         new_shape = ntuple(
             i -> i == dims ? 1 : (i < dims ? size(x, i) : size(x, i - 1)), ndims(x) + 1
         )
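
Note: the rewritten `overloaded_stack` picks the stacking dimension lazily (for `dims = :` it becomes the rank of the first element plus one) and checks rank agreement element by element instead of materializing `[ndims(x) for x in xs]`. The reshape applied to each element is unchanged; the sketch below models just that shape logic in plain Julia. The `stacked_shape` helper is hypothetical and only for illustration; the `Base.stack` comparison line needs Julia 1.9 or newer.

```julia
using Test

# Model of the per-element reshape in overloaded_stack: insert a singleton axis at `dims`.
stacked_shape(x::AbstractArray, dims::Integer) = ntuple(
    i -> i == dims ? 1 : (i < dims ? size(x, i) : size(x, i - 1)), ndims(x) + 1
)

x = rand(3, 4)
@test stacked_shape(x, 3) == (3, 4, 1)  # dims == ndims(x) + 1, i.e. the `dims = :` default
@test stacked_shape(x, 1) == (1, 3, 4)  # stacking along a leading axis
@test size(stack([rand(3, 4) for _ in 1:5])) == (3, 4, 5)  # Base.stack for comparison
```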
src/stdlibs/LinearAlgebra.jl: 184 changes (3 additions, 181 deletions)
@@ -12,7 +12,7 @@ using ..TracedUtils: TracedUtils, get_mlir_data, set_mlir_data!
 using ..Ops: @opcall
 
 using LinearAlgebra: LinearAlgebra, BLAS
-using LinearAlgebra: Adjoint, Transpose, Factorization, RowMaximum
+using LinearAlgebra: Adjoint, Transpose, Factorization, RowMaximum, NoPivot
 using LinearAlgebra: SymTridiagonal, Symmetric, Bidiagonal, Diagonal, Tridiagonal
 using LinearAlgebra: LowerTriangular, UnitLowerTriangular, UpperTriangular
 using LinearAlgebra:
@@ -40,6 +40,8 @@ function __init__()
     return nothing
 end
 
+include("factorization/Factorization.jl")
+
 # Various Wrapper Arrays defined in LinearAlgebra
 function ReactantCore.materialize_traced_array(
     x::Transpose{TracedRNumber{T},<:AnyTracedRArray}
@@ -633,186 +635,6 @@ LinearAlgebra.transpose!(B::AnyTracedRMatrix, A::AnyTracedRMatrix) = copy!(B, transpose(A))
 
 LinearAlgebra.adjoint!(B::AnyTracedRMatrix, A::AnyTracedRMatrix) = copy!(B, adjoint(A))
 
-# Supports batched factorization
-abstract type GeneralizedFactorization{T} <: Factorization{T} end
-
-function LinearAlgebra.TransposeFactorization(f::GeneralizedFactorization)
-    return LinearAlgebra.TransposeFactorization{eltype(f),typeof(f)}(f)
-end
-
-function LinearAlgebra.AdjointFactorization(f::GeneralizedFactorization)
-    return LinearAlgebra.AdjointFactorization{eltype(f),typeof(f)}(f)
-end
-
-const GeneralizedTransposeFactorization{T} =
-    LinearAlgebra.TransposeFactorization{T,<:GeneralizedFactorization{T}} where {T}
-const GeneralizedAdjointFactorization{T} =
-    LinearAlgebra.AdjointFactorization{T,<:GeneralizedFactorization{T}} where {T}
-
-# LU Factorization
-struct GeneralizedLU{T,S<:AbstractArray,P<:AbstractArray,I<:Union{AbstractArray,Number}} <:
-       GeneralizedFactorization{T}
-    factors::S
-    ipiv::P
-    perm::P
-    info::I
-end
-
-Base.size(lu::GeneralizedLU) = size(lu.factors)
-Base.size(lu::GeneralizedLU, i) = size(lu.factors, i)
-Base.ndims(lu::GeneralizedLU) = ndims(lu.factors)
-function Base.copy(lu::GeneralizedLU)
-    return GeneralizedLU(copy(lu.factors), copy(lu.ipiv), copy(lu.perm), copy(lu.info))
-end
-
-function GeneralizedLU(factors::S, ipiv::P, perm::P, info::I) where {S,P,I}
-    @assert ndims(ipiv) == ndims(perm) == ndims(factors) - 1
-    @assert ndims(info) == ndims(factors) - 2
-    return GeneralizedLU{eltype(factors),S,P,I}(factors, ipiv, perm, info)
-end
-
-function overloaded_lu(x::AbstractArray, args...; kwargs...)
-    return overloaded_lu(Reactant.promote_to(TracedRArray, x), args...; kwargs...)
-end
-
-function overloaded_lu(
-    A::AnyTracedRArray{T,N}, ::RowMaximum; check::Bool=false, allowsingular::Bool=false
-) where {T,N}
-    # TODO: don't ignore the check and allowsingular flags
-    # Batching here is in the last dimensions. `Ops.lu` expects the last dimensions
-    permdims = vcat(collect(Int64, 3:N), 1, 2)
-    A = @opcall transpose(materialize_traced_array(A), permdims)
-    factors, ipiv, perm, info = @opcall lu(A)
-
-    # Permute back to the original dimensions
-    perm_perm = vcat(N - 1, collect(Int64, 1:(N - 2)))
-    factors = @opcall transpose(factors, invperm(permdims))
-    ipiv = @opcall transpose(ipiv, perm_perm)
-    perm = @opcall transpose(perm, perm_perm)
-    return GeneralizedLU(factors, ipiv, perm, info)
-end
-
-function LinearAlgebra.ldiv!(
-    lu::GeneralizedLU{T,<:AbstractArray{T,N},P,I}, B::AbstractArray{T,M}
-) where {T,P,I,N,M}
-    @assert N == M + 1
-    ldiv!(lu, reshape(B, size(B, 1), 1, size(B)[2:end]...))
-    return B
-end
-
-function LinearAlgebra.ldiv!(
-    lu::GeneralizedLU{T,<:AbstractArray{T,2},P,I}, B::AbstractArray{T,2}
-) where {T,P,I}
-    B .= _lu_solve_core(lu.factors, B, lu.perm)
-    return B
-end
-
-function LinearAlgebra.ldiv!(
-    lu::GeneralizedLU{T,<:AbstractArray{T,N},P,I}, B::AbstractArray{T,N}
-) where {T,P,I,N}
-    batch_shape = size(lu.factors)[3:end]
-    @assert batch_shape == size(B)[3:end]
-
-    permutation = vcat(collect(Int64, 3:N), 1, 2)
-
-    factors = @opcall transpose(materialize_traced_array(lu.factors), permutation)
-    B_permuted = @opcall transpose(materialize_traced_array(B), permutation)
-    perm = @opcall transpose(
-        materialize_traced_array(lu.perm), vcat(collect(Int64, 2:(N - 1)), 1)
-    )
-
-    res = @opcall transpose(
-        only(
-            @opcall(
-                batch(
-                    _lu_solve_core, [factors, B_permuted, perm], collect(Int64, batch_shape)
-                )
-            ),
-        ),
-        invperm(permutation),
-    )
-    B .= res
-    return B
-end
-
-function LinearAlgebra.det(lu::GeneralizedLU{T,<:AbstractMatrix}) where {T}
-    n = LinearAlgebra.checksquare(lu)
-    # TODO: check for non-singular matrices
-
-    P = prod(LinearAlgebra.diag(lu.factors))
-    return ifelse(isodd(sum(lu.ipiv[1:n] .!= (1:n))), -one(T), one(T)) * P
-end
-
-function LinearAlgebra.logabsdet(lu::GeneralizedLU{T,<:AbstractMatrix}) where {T}
-    n = LinearAlgebra.checksquare(lu)
-    Treal = real(T)
-    # TODO: check for non-singular matrices
-
-    d = LinearAlgebra.diag(lu.factors)
-    absdet = sum(log ∘ abs, d)
-    P = prod(sign, d)
-    s = ifelse(isodd(sum(lu.ipiv[1:n] .!= (1:n))), -one(Treal), one(Treal)) * P
-    return absdet, s
-end
-
-for f_wrapper in (LinearAlgebra.TransposeFactorization, LinearAlgebra.AdjointFactorization),
-    aType in (:AbstractVecOrMat, :AbstractArray)
-
-    @eval function LinearAlgebra.ldiv!(lu::$(f_wrapper){<:Any,<:GeneralizedLU}, B::$aType)
-        # TODO: implement this
-        error("`$(f_wrapper)` is not supported yet for LU.")
-        return nothing
-    end
-end
-
-# currently we lower inverse to lu decomposition + triangular solve. we should
-# instead emit getri and lower that to a fallback if the backend doesn't support
-# it.
-function LinearAlgebra.inv!(lu::GeneralizedLU)
-    @assert ndims(lu) == 2 "Only implemented for 2D tensors"
-    rhs = Reactant.promote_to(
-        TracedRArray{Reactant.unwrapped_eltype(eltype(lu)),2}, LinearAlgebra.I(size(lu, 1))
-    )
-    ldiv!(lu, rhs)
-    return rhs
-end
-
-function _lu_solve_core(factors::AbstractMatrix, B::AbstractMatrix, perm::AbstractVector)
-    permuted_B = B[Int64.(perm), :]
-    return UpperTriangular(factors) \ (UnitLowerTriangular(factors) \ permuted_B)
-end
-
-# Overload \ to support batched factorization
-for FT in (
-    :GeneralizedFactorization,
-    :GeneralizedTransposeFactorization,
-    :GeneralizedAdjointFactorization,
-)
-    for aType in (:AbstractVecOrMat, :AbstractArray)
-        @eval Base.:(\)(F::$FT, B::$aType) = _overloaded_backslash(F, B)
-    end
-
-    @eval Base.:(\)(
-        F::$FT{T}, B::Union{Array{Complex{T},1},Array{Complex{T},2}}
-    ) where {T<:Union{Float32,Float64}} = _overloaded_backslash(F, B)
-end
-
-function _overloaded_backslash(F::GeneralizedFactorization, B::AbstractArray)
-    return ldiv!(
-        F, LinearAlgebra.copy_similar(B, typeof(oneunit(eltype(F)) \ oneunit(eltype(B))))
-    )
-end
-
-function _overloaded_backslash(F::GeneralizedTransposeFactorization, B::AbstractArray)
-    return conj!(adjoint(F.parent) \ conj.(B))
-end
-
-function _overloaded_backslash(F::GeneralizedAdjointFactorization, B::AbstractArray)
-    return ldiv!(
-        F, LinearAlgebra.copy_similar(B, typeof(oneunit(eltype(F)) \ oneunit(eltype(B))))
-    )
-end
-
 # indexing into specific wrapped array types
 # TODO: specialize these ones. We don't need to make the arrays dense (though our passes
 # should be able to optimize them out)
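
Note: the block deleted above (now living behind the new `include("factorization/Factorization.jl")`) solves `A x = B` from LU factors by applying the row permutation and then two triangular solves, and derives `det`/`logabsdet` from the diagonal of the factors plus the pivot parity. The snippet below restates that kernel with only the LinearAlgebra stdlib, checked against `LinearAlgebra.lu`; `lu_solve_core` is a local name for this example, not part of the PR.

```julia
using LinearAlgebra, Test

# Stdlib-only restatement of the removed `_lu_solve_core`: permute the rows of B,
# then do a unit-lower followed by an upper triangular solve.
function lu_solve_core(factors::AbstractMatrix, B::AbstractMatrix, perm::AbstractVector)
    permuted_B = B[Int64.(perm), :]
    return UpperTriangular(factors) \ (UnitLowerTriangular(factors) \ permuted_B)
end

A = rand(4, 4) + 4I   # well-conditioned test matrix
B = rand(4, 2)
F = lu(A)             # A[F.p, :] == L * U, with unit-diagonal L stored in F.factors
@test lu_solve_core(F.factors, B, F.p) ≈ A \ B

# Same determinant recipe as the removed det method: pivot parity times prod(diag(U)).
sign_from_pivots = isodd(sum(F.ipiv[1:4] .!= 1:4)) ? -1.0 : 1.0
@test det(A) ≈ sign_from_pivots * prod(diag(F.factors))
```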