diff --git a/src/Overlay.jl b/src/Overlay.jl index e837966cb7..f539ef8199 100644 --- a/src/Overlay.jl +++ b/src/Overlay.jl @@ -230,36 +230,42 @@ end end # LinearAlgebra -@reactant_overlay @noinline function LinearAlgebra.lu(x::AbstractArray; kwargs...) - if use_overlayed_version(x) - return TracedLinearAlgebra.overloaded_lu(x, RowMaximum(); kwargs...) - else - return Base.inferencebarrier(LinearAlgebra.lu)(x; kwargs...) - end -end -@reactant_overlay @noinline function LinearAlgebra.lu( - x::AbstractArray, pivot::RowMaximum; kwargs... -) - if use_overlayed_version(x) - return TracedLinearAlgebra.overloaded_lu(x, pivot; kwargs...) - else - return Base.inferencebarrier(LinearAlgebra.lu)(x, pivot; kwargs...) - end -end -@reactant_overlay @noinline function LinearAlgebra.lu!(x::AbstractArray; kwargs...) - if use_overlayed_version(x) - return TracedLinearAlgebra.overloaded_lu(x, RowMaximum(); kwargs...) - else - return Base.inferencebarrier(LinearAlgebra.lu!)(x; kwargs...) - end -end -@reactant_overlay @noinline function LinearAlgebra.lu!( - x::AbstractArray, pivot::RowMaximum; kwargs... +## Various factorizations +## TODO: specialize for `cholesky!` --> cholcopy +factorization_copy(f::F, x, pivot) where {F} = x +factorization_copy(f::F, x) where {F} = x + +for (jlop, rop, default_pivot) in ( + (:lu, :overloaded_lu, RowMaximum), + (:lu!, :overloaded_lu, RowMaximum), + (:cholesky, :overloaded_cholesky, NoPivot), + (:cholesky!, :overloaded_cholesky, NoPivot), ) - if use_overlayed_version(x) - return TracedLinearAlgebra.overloaded_lu(x, pivot; kwargs...) - else - return Base.inferencebarrier(LinearAlgebra.lu!)(x, pivot; kwargs...) + @eval begin + @reactant_overlay @noinline function LinearAlgebra.$(jlop)( + x::AbstractArray; kwargs... + ) + if use_overlayed_version(x) + pivot = $(default_pivot)() + return TracedLinearAlgebra.$(rop)( + factorization_copy(LinearAlgebra.$(jlop), x, pivot), pivot; kwargs... + ) + else + return Base.inferencebarrier(LinearAlgebra.$(jlop))(x; kwargs...) + end + end + + @reactant_overlay @noinline function LinearAlgebra.$(jlop)( + x::AbstractArray, pivot::$(default_pivot); kwargs... + ) + if use_overlayed_version(x) + return TracedLinearAlgebra.$(rop)( + factorization_copy(LinearAlgebra.$(jlop), x, pivot), pivot; kwargs... + ) + else + return Base.inferencebarrier(LinearAlgebra.$(jlop))(x, pivot; kwargs...) + end + end end end diff --git a/src/Reactant.jl b/src/Reactant.jl index 69cee2b3b0..e164886fb9 100644 --- a/src/Reactant.jl +++ b/src/Reactant.jl @@ -3,7 +3,7 @@ module Reactant using ReactantCore: ReactantCore, @trace, within_compile, MissingTracedValue, materialize_traced_array -using LinearAlgebra: LinearAlgebra, RowMaximum +using LinearAlgebra: LinearAlgebra, RowMaximum, NoPivot using Random: Random, AbstractRNG using EnumX: @enumx using Functors: Functors, @leaf diff --git a/src/TracedRArray.jl b/src/TracedRArray.jl index 73dacadfb2..f01c57231b 100644 --- a/src/TracedRArray.jl +++ b/src/TracedRArray.jl @@ -730,11 +730,20 @@ end # stack function overloaded_stack(dims::Union{Integer,Colon}, xs) - @assert allequal([ndims(x) for x in xs]) "All arrays must have the same number of \ - dimensions..." - dims = dims isa Colon ? ndims(first(xs)) + 1 : dims + dims = dims isa Colon ? nothing : dims res = [] - for x in xs + prev_dims = nothing + for x in unwrapped_broadcast(identity, xs) + cur_dims = ndims(x) + if prev_dims === nothing + prev_dims = cur_dims + else + @assert prev_dims == cur_dims "All arrays must have the same number of \ + dimensions..." + end + + dims === nothing && (dims = cur_dims + 1) + new_shape = ntuple( i -> i == dims ? 1 : (i < dims ? size(x, i) : size(x, i - 1)), ndims(x) + 1 ) diff --git a/src/stdlibs/LinearAlgebra.jl b/src/stdlibs/LinearAlgebra.jl index a76d337ad8..ac23ae964f 100644 --- a/src/stdlibs/LinearAlgebra.jl +++ b/src/stdlibs/LinearAlgebra.jl @@ -12,7 +12,7 @@ using ..TracedUtils: TracedUtils, get_mlir_data, set_mlir_data! using ..Ops: @opcall using LinearAlgebra: LinearAlgebra, BLAS -using LinearAlgebra: Adjoint, Transpose, Factorization, RowMaximum +using LinearAlgebra: Adjoint, Transpose, Factorization, RowMaximum, NoPivot using LinearAlgebra: SymTridiagonal, Symmetric, Bidiagonal, Diagonal, Tridiagonal using LinearAlgebra: LowerTriangular, UnitLowerTriangular, UpperTriangular using LinearAlgebra: @@ -40,6 +40,8 @@ function __init__() return nothing end +include("factorization/Factorization.jl") + # Various Wrapper Arrays defined in LinearAlgebra function ReactantCore.materialize_traced_array( x::Transpose{TracedRNumber{T},<:AnyTracedRArray} @@ -633,186 +635,6 @@ LinearAlgebra.transpose!(B::AnyTracedRMatrix, A::AnyTracedRMatrix) = copy!(B, tr LinearAlgebra.adjoint!(B::AnyTracedRMatrix, A::AnyTracedRMatrix) = copy!(B, adjoint(A)) -# Supports batched factorization -abstract type GeneralizedFactorization{T} <: Factorization{T} end - -function LinearAlgebra.TransposeFactorization(f::GeneralizedFactorization) - return LinearAlgebra.TransposeFactorization{eltype(f),typeof(f)}(f) -end - -function LinearAlgebra.AdjointFactorization(f::GeneralizedFactorization) - return LinearAlgebra.AdjointFactorization{eltype(f),typeof(f)}(f) -end - -const GeneralizedTransposeFactorization{T} = - LinearAlgebra.TransposeFactorization{T,<:GeneralizedFactorization{T}} where {T} -const GeneralizedAdjointFactorization{T} = - LinearAlgebra.AdjointFactorization{T,<:GeneralizedFactorization{T}} where {T} - -# LU Factorization -struct GeneralizedLU{T,S<:AbstractArray,P<:AbstractArray,I<:Union{AbstractArray,Number}} <: - GeneralizedFactorization{T} - factors::S - ipiv::P - perm::P - info::I -end - -Base.size(lu::GeneralizedLU) = size(lu.factors) -Base.size(lu::GeneralizedLU, i) = size(lu.factors, i) -Base.ndims(lu::GeneralizedLU) = ndims(lu.factors) -function Base.copy(lu::GeneralizedLU) - return GeneralizedLU(copy(lu.factors), copy(lu.ipiv), copy(lu.perm), copy(lu.info)) -end - -function GeneralizedLU(factors::S, ipiv::P, perm::P, info::I) where {S,P,I} - @assert ndims(ipiv) == ndims(perm) == ndims(factors) - 1 - @assert ndims(info) == ndims(factors) - 2 - return GeneralizedLU{eltype(factors),S,P,I}(factors, ipiv, perm, info) -end - -function overloaded_lu(x::AbstractArray, args...; kwargs...) - return overloaded_lu(Reactant.promote_to(TracedRArray, x), args...; kwargs...) -end - -function overloaded_lu( - A::AnyTracedRArray{T,N}, ::RowMaximum; check::Bool=false, allowsingular::Bool=false -) where {T,N} - # TODO: don't ignore the check and allowsingular flags - # Batching here is in the last dimensions. `Ops.lu` expects the last dimensions - permdims = vcat(collect(Int64, 3:N), 1, 2) - A = @opcall transpose(materialize_traced_array(A), permdims) - factors, ipiv, perm, info = @opcall lu(A) - - # Permute back to the original dimensions - perm_perm = vcat(N - 1, collect(Int64, 1:(N - 2))) - factors = @opcall transpose(factors, invperm(permdims)) - ipiv = @opcall transpose(ipiv, perm_perm) - perm = @opcall transpose(perm, perm_perm) - return GeneralizedLU(factors, ipiv, perm, info) -end - -function LinearAlgebra.ldiv!( - lu::GeneralizedLU{T,<:AbstractArray{T,N},P,I}, B::AbstractArray{T,M} -) where {T,P,I,N,M} - @assert N == M + 1 - ldiv!(lu, reshape(B, size(B, 1), 1, size(B)[2:end]...)) - return B -end - -function LinearAlgebra.ldiv!( - lu::GeneralizedLU{T,<:AbstractArray{T,2},P,I}, B::AbstractArray{T,2} -) where {T,P,I} - B .= _lu_solve_core(lu.factors, B, lu.perm) - return B -end - -function LinearAlgebra.ldiv!( - lu::GeneralizedLU{T,<:AbstractArray{T,N},P,I}, B::AbstractArray{T,N} -) where {T,P,I,N} - batch_shape = size(lu.factors)[3:end] - @assert batch_shape == size(B)[3:end] - - permutation = vcat(collect(Int64, 3:N), 1, 2) - - factors = @opcall transpose(materialize_traced_array(lu.factors), permutation) - B_permuted = @opcall transpose(materialize_traced_array(B), permutation) - perm = @opcall transpose( - materialize_traced_array(lu.perm), vcat(collect(Int64, 2:(N - 1)), 1) - ) - - res = @opcall transpose( - only( - @opcall( - batch( - _lu_solve_core, [factors, B_permuted, perm], collect(Int64, batch_shape) - ) - ), - ), - invperm(permutation), - ) - B .= res - return B -end - -function LinearAlgebra.det(lu::GeneralizedLU{T,<:AbstractMatrix}) where {T} - n = LinearAlgebra.checksquare(lu) - # TODO: check for non-singular matrices - - P = prod(LinearAlgebra.diag(lu.factors)) - return ifelse(isodd(sum(lu.ipiv[1:n] .!= (1:n))), -one(T), one(T)) * P -end - -function LinearAlgebra.logabsdet(lu::GeneralizedLU{T,<:AbstractMatrix}) where {T} - n = LinearAlgebra.checksquare(lu) - Treal = real(T) - # TODO: check for non-singular matrices - - d = LinearAlgebra.diag(lu.factors) - absdet = sum(log ∘ abs, d) - P = prod(sign, d) - s = ifelse(isodd(sum(lu.ipiv[1:n] .!= (1:n))), -one(Treal), one(Treal)) * P - return absdet, s -end - -for f_wrapper in (LinearAlgebra.TransposeFactorization, LinearAlgebra.AdjointFactorization), - aType in (:AbstractVecOrMat, :AbstractArray) - - @eval function LinearAlgebra.ldiv!(lu::$(f_wrapper){<:Any,<:GeneralizedLU}, B::$aType) - # TODO: implement this - error("`$(f_wrapper)` is not supported yet for LU.") - return nothing - end -end - -# currently we lower inverse to lu decomposition + triangular solve. we should -# instead emit getri and lower that to a fallback if the backend doesn't support -# it. -function LinearAlgebra.inv!(lu::GeneralizedLU) - @assert ndims(lu) == 2 "Only implemented for 2D tensors" - rhs = Reactant.promote_to( - TracedRArray{Reactant.unwrapped_eltype(eltype(lu)),2}, LinearAlgebra.I(size(lu, 1)) - ) - ldiv!(lu, rhs) - return rhs -end - -function _lu_solve_core(factors::AbstractMatrix, B::AbstractMatrix, perm::AbstractVector) - permuted_B = B[Int64.(perm), :] - return UpperTriangular(factors) \ (UnitLowerTriangular(factors) \ permuted_B) -end - -# Overload \ to support batched factorization -for FT in ( - :GeneralizedFactorization, - :GeneralizedTransposeFactorization, - :GeneralizedAdjointFactorization, -) - for aType in (:AbstractVecOrMat, :AbstractArray) - @eval Base.:(\)(F::$FT, B::$aType) = _overloaded_backslash(F, B) - end - - @eval Base.:(\)( - F::$FT{T}, B::Union{Array{Complex{T},1},Array{Complex{T},2}} - ) where {T<:Union{Float32,Float64}} = _overloaded_backslash(F, B) -end - -function _overloaded_backslash(F::GeneralizedFactorization, B::AbstractArray) - return ldiv!( - F, LinearAlgebra.copy_similar(B, typeof(oneunit(eltype(F)) \ oneunit(eltype(B)))) - ) -end - -function _overloaded_backslash(F::GeneralizedTransposeFactorization, B::AbstractArray) - return conj!(adjoint(F.parent) \ conj.(B)) -end - -function _overloaded_backslash(F::GeneralizedAdjointFactorization, B::AbstractArray) - return ldiv!( - F, LinearAlgebra.copy_similar(B, typeof(oneunit(eltype(F)) \ oneunit(eltype(B)))) - ) -end - # indexing into specific wrapepd array types # TODO: specialize these ones. We don't need to make the arrays dense (though our passes # should be able to optimize them out) diff --git a/src/stdlibs/factorization/Cholesky.jl b/src/stdlibs/factorization/Cholesky.jl new file mode 100644 index 0000000000..e0f30ad586 --- /dev/null +++ b/src/stdlibs/factorization/Cholesky.jl @@ -0,0 +1,96 @@ +struct GeneralizedCholesky{T,S<:AbstractArray,I<:Union{AbstractArray,Number}} <: + GeneralizedFactorization{T} + factors::S + uplo::Char + info::I +end + +function GeneralizedCholesky(factors::S, uplo::Char, info::I) where {S,I} + @assert ndims(info) == ndims(factors) - 2 + return GeneralizedCholesky{eltype(factors),S,I}(factors, uplo, info) +end + +Base.size(c::GeneralizedCholesky) = size(c.factors) +Base.ndims(c::GeneralizedCholesky) = ndims(c.factors) + +function overloaded_cholesky(A::AbstractArray, ::NoPivot; check::Bool=false) + return overloaded_cholesky(Reactant.promote_to(TracedRArray, A), NoPivot(); check) +end + +function overloaded_cholesky( + A::AnyTracedRArray{T,N}, ::NoPivot; check::Bool=false +) where {T,N} + # TODO: dont ignore check + # move the batching dims to the front + permdims = vcat(collect(Int64, 3:N), 1, 2) + A = @opcall transpose(materialize_traced_array(A), permdims) + + factors = @opcall cholesky(A; lower=false) + factors = @opcall transpose(factors, invperm(permdims)) + + # stablehlo doesn't return the info + info = materialize_traced_array( + dropdims( + Reactant.CallWithReactant(mapreduce)( + isfinite, &, mapslices(LinearAlgebra.triu, factors; dims=(1, 2)); dims=1:2 + ); + dims=(1, 2), + ), + ) + if N == 2 + info = TracedRNumber{Bool}((), info.mlir_data) + end + + return GeneralizedCholesky(factors, 'U', info) +end + +function LinearAlgebra.ldiv!( + F::GeneralizedCholesky{T,<:AbstractArray{T,N}}, B::AbstractArray{T,M} +) where {T,N,M} + @assert N == M + 1 + ldiv!(F, reshape(B, size(B, 1), 1, size(B)[2:end]...)) + return B +end + +function LinearAlgebra.ldiv!( + F::GeneralizedCholesky{T,<:AbstractArray{T,2}}, B::AbstractArray{T,2} +) where {T} + B .= _cholesky_solve_core(F.factors, B, F.uplo) + return B +end + +function LinearAlgebra.ldiv!( + F::GeneralizedCholesky{T,<:AbstractArray{T,N}}, B::AbstractArray{T,N} +) where {T,N} + batch_shape = size(F.factors)[3:end] + @assert batch_shape == size(B)[3:end] + + base_fn = F.uplo == 'U' ? _cholesky_solve_core_upper : _cholesky_solve_core_lower + + permutation = vcat(collect(Int64, 3:N), 1, 2) + + factors = @opcall transpose(materialize_traced_array(F.factors), permutation) + B_permuted = @opcall transpose(materialize_traced_array(B), permutation) + + res = @opcall transpose( + only(@opcall(batch(base_fn, [factors, B_permuted], collect(Int64, batch_shape)))), + invperm(permutation), + ) + B .= res + return B +end + +function _cholesky_solve_core(factors::AbstractMatrix, B::AbstractMatrix, uplo::Char) + if uplo == 'U' + return _cholesky_solve_core_upper(factors, B) + else + return _cholesky_solve_core_lower(factors, B) + end +end + +function _cholesky_solve_core_lower(factors::AbstractMatrix, B::AbstractMatrix) + return adjoint(LowerTriangular(factors)) \ (LowerTriangular(factors) \ B) +end +function _cholesky_solve_core_upper(factors::AbstractMatrix, B::AbstractMatrix) + return UpperTriangular(factors) \ (adjoint(UpperTriangular(factors)) \ B) +end diff --git a/src/stdlibs/factorization/Factorization.jl b/src/stdlibs/factorization/Factorization.jl new file mode 100644 index 0000000000..5114a47d7b --- /dev/null +++ b/src/stdlibs/factorization/Factorization.jl @@ -0,0 +1,49 @@ +# Supports batched factorization +abstract type GeneralizedFactorization{T} <: Factorization{T} end + +function LinearAlgebra.TransposeFactorization(f::GeneralizedFactorization) + return LinearAlgebra.TransposeFactorization{eltype(f),typeof(f)}(f) +end + +function LinearAlgebra.AdjointFactorization(f::GeneralizedFactorization) + return LinearAlgebra.AdjointFactorization{eltype(f),typeof(f)}(f) +end + +const GeneralizedTransposeFactorization{T} = + LinearAlgebra.TransposeFactorization{T,<:GeneralizedFactorization{T}} where {T} +const GeneralizedAdjointFactorization{T} = + LinearAlgebra.AdjointFactorization{T,<:GeneralizedFactorization{T}} where {T} + +include("Cholesky.jl") +include("LU.jl") + +# Overload \ to support batched factorization +for FT in ( + :GeneralizedFactorization, + :GeneralizedTransposeFactorization, + :GeneralizedAdjointFactorization, +) + for aType in (:AbstractVecOrMat, :AbstractArray) + @eval Base.:(\)(F::$FT, B::$aType) = _overloaded_backslash(F, B) + end + + @eval Base.:(\)( + F::$FT{T}, B::Union{Array{Complex{T},1},Array{Complex{T},2}} + ) where {T<:Union{Float32,Float64}} = _overloaded_backslash(F, B) +end + +function _overloaded_backslash(F::GeneralizedFactorization, B::AbstractArray) + return ldiv!( + F, LinearAlgebra.copy_similar(B, typeof(oneunit(eltype(F)) \ oneunit(eltype(B)))) + ) +end + +function _overloaded_backslash(F::GeneralizedTransposeFactorization, B::AbstractArray) + return conj!(adjoint(F.parent) \ conj.(B)) +end + +function _overloaded_backslash(F::GeneralizedAdjointFactorization, B::AbstractArray) + return ldiv!( + F, LinearAlgebra.copy_similar(B, typeof(oneunit(eltype(F)) \ oneunit(eltype(B)))) + ) +end diff --git a/src/stdlibs/factorization/LU.jl b/src/stdlibs/factorization/LU.jl new file mode 100644 index 0000000000..144e7b9dcd --- /dev/null +++ b/src/stdlibs/factorization/LU.jl @@ -0,0 +1,131 @@ +struct GeneralizedLU{T,S<:AbstractArray,P<:AbstractArray,I<:Union{AbstractArray,Number}} <: + GeneralizedFactorization{T} + factors::S + ipiv::P + perm::P + info::I +end + +Base.size(lu::GeneralizedLU) = size(lu.factors) +Base.size(lu::GeneralizedLU, i) = size(lu.factors, i) +Base.ndims(lu::GeneralizedLU) = ndims(lu.factors) +function Base.copy(lu::GeneralizedLU) + return GeneralizedLU(copy(lu.factors), copy(lu.ipiv), copy(lu.perm), copy(lu.info)) +end + +function GeneralizedLU(factors::S, ipiv::P, perm::P, info::I) where {S,P,I} + @assert ndims(ipiv) == ndims(perm) == ndims(factors) - 1 + @assert ndims(info) == ndims(factors) - 2 + return GeneralizedLU{eltype(factors),S,P,I}(factors, ipiv, perm, info) +end + +function overloaded_lu(x::AbstractArray, args...; kwargs...) + return overloaded_lu(Reactant.promote_to(TracedRArray, x), args...; kwargs...) +end + +function overloaded_lu( + A::AnyTracedRArray{T,N}, ::RowMaximum; check::Bool=false, allowsingular::Bool=false +) where {T,N} + # TODO: don't ignore the check and allowsingular flags + # Batching here is in the last dimensions. `Ops.lu` expects the last dimensions + permdims = vcat(collect(Int64, 3:N), 1, 2) + A = @opcall transpose(materialize_traced_array(A), permdims) + factors, ipiv, perm, info = @opcall lu(A) + + # Permute back to the original dimensions + perm_perm = vcat(N - 1, collect(Int64, 1:(N - 2))) + factors = @opcall transpose(factors, invperm(permdims)) + ipiv = @opcall transpose(ipiv, perm_perm) + perm = @opcall transpose(perm, perm_perm) + return GeneralizedLU(factors, ipiv, perm, info) +end + +function LinearAlgebra.ldiv!( + lu::GeneralizedLU{T,<:AbstractArray{T,N},P,I}, B::AbstractArray{T,M} +) where {T,P,I,N,M} + @assert N == M + 1 + ldiv!(lu, reshape(B, size(B, 1), 1, size(B)[2:end]...)) + return B +end + +function LinearAlgebra.ldiv!( + lu::GeneralizedLU{T,<:AbstractArray{T,2},P,I}, B::AbstractArray{T,2} +) where {T,P,I} + B .= _lu_solve_core(lu.factors, B, lu.perm) + return B +end + +function LinearAlgebra.ldiv!( + lu::GeneralizedLU{T,<:AbstractArray{T,N},P,I}, B::AbstractArray{T,N} +) where {T,P,I,N} + batch_shape = size(lu.factors)[3:end] + @assert batch_shape == size(B)[3:end] + + permutation = vcat(collect(Int64, 3:N), 1, 2) + + factors = @opcall transpose(materialize_traced_array(lu.factors), permutation) + B_permuted = @opcall transpose(materialize_traced_array(B), permutation) + perm = @opcall transpose( + materialize_traced_array(lu.perm), vcat(collect(Int64, 2:(N - 1)), 1) + ) + + res = @opcall transpose( + only( + @opcall( + batch( + _lu_solve_core, [factors, B_permuted, perm], collect(Int64, batch_shape) + ) + ), + ), + invperm(permutation), + ) + B .= res + return B +end + +function LinearAlgebra.det(lu::GeneralizedLU{T,<:AbstractMatrix}) where {T} + n = LinearAlgebra.checksquare(lu) + # TODO: check for non-singular matrices + + P = prod(LinearAlgebra.diag(lu.factors)) + return ifelse(isodd(sum(lu.ipiv[1:n] .!= (1:n))), -one(T), one(T)) * P +end + +function LinearAlgebra.logabsdet(lu::GeneralizedLU{T,<:AbstractMatrix}) where {T} + n = LinearAlgebra.checksquare(lu) + Treal = real(T) + # TODO: check for non-singular matrices + + d = LinearAlgebra.diag(lu.factors) + absdet = sum(log ∘ abs, d) + P = prod(sign, d) + s = ifelse(isodd(sum(lu.ipiv[1:n] .!= (1:n))), -one(Treal), one(Treal)) * P + return absdet, s +end + +for f_wrapper in (LinearAlgebra.TransposeFactorization, LinearAlgebra.AdjointFactorization), + aType in (:AbstractVecOrMat, :AbstractArray) + + @eval function LinearAlgebra.ldiv!(lu::$(f_wrapper){<:Any,<:GeneralizedLU}, B::$aType) + # TODO: implement this + error("`$(f_wrapper)` is not supported yet for LU.") + return nothing + end +end + +# currently we lower inverse to lu decomposition + triangular solve. we should +# instead emit getri and lower that to a fallback if the backend doesn't support +# it. +function LinearAlgebra.inv!(lu::GeneralizedLU) + @assert ndims(lu) == 2 "Only implemented for 2D tensors" + rhs = Reactant.promote_to( + TracedRArray{Reactant.unwrapped_eltype(eltype(lu)),2}, LinearAlgebra.I(size(lu, 1)) + ) + ldiv!(lu, rhs) + return rhs +end + +function _lu_solve_core(factors::AbstractMatrix, B::AbstractMatrix, perm::AbstractVector) + permuted_B = B[Int64.(perm), :] + return UpperTriangular(factors) \ (UnitLowerTriangular(factors) \ permuted_B) +end diff --git a/test/integration/linear_algebra.jl b/test/integration/linear_algebra.jl index 7fa390b3a3..08199af5a4 100644 --- a/test/integration/linear_algebra.jl +++ b/test/integration/linear_algebra.jl @@ -351,24 +351,31 @@ end end end -solve_with_lu(A, b) = lu(A) \ b -function solve_with_lu_batched(A::AbstractArray{T,N}, B::AbstractArray{T,N}) where {T,N} +solve_with_fact(f::F, A, b) where {F} = f(A) \ b +function solve_with_fact_batched( + f::F, A::AbstractArray{T,N}, B::AbstractArray{T,N} +) where {F,T,N} A2 = reshape(A, size(A, 1), size(A, 2), prod(size(A)[3:end])) B2 = reshape(B, size(B, 1), size(B, 2), prod(size(B)[3:end])) @assert size(A2, 3) == size(B2, 3) return reshape( - stack(lu(view(A2, :, :, i)) \ view(B2, :, :, i) for i in axes(A2, 3)), + stack(f(view(A2, :, :, i)) \ view(B2, :, :, i) for i in axes(A2, 3)), size(A2, 1), size(B2, 2), size(A)[3:end]..., ) end -function solve_with_lu_batched(A::AbstractArray{T,N}, b::AbstractArray{T,M}) where {T,N,M} +function solve_with_fact_batched( + f::F, A::AbstractArray{T,N}, b::AbstractArray{T,M} +) where {F,T,N,M} @assert N == M + 1 B = reshape(b, size(b, 1), 1, size(b)[2:end]...) - return dropdims(solve_with_lu_batched(A, B); dims=2) + return dropdims(solve_with_fact_batched(f, A, B); dims=2) end +solve_with_lu(A, b) = solve_with_fact(lu, A, b) +solve_with_lu_batched(A, b) = solve_with_fact_batched(lu, A, b) + @testset "LU Factorization" begin @testset "Un-batched" begin @testset for T in (Float32, Float64, ComplexF32, ComplexF64) @@ -433,6 +440,53 @@ end end end +solve_with_cholesky(A, b) = solve_with_fact(cholesky, A, b) +solve_with_cholesky_batched(A, b) = solve_with_fact_batched(cholesky, A, b) + +@testset "Cholesky Factorization" begin + @testset "Un-batched" begin + @testset for T in (Float32, Float64, ComplexF32, ComplexF64) + (T == ComplexF64 || T == Float64) && RunningOnTPU && continue + + A = rand(T, 4, 4) + A = A * A' + A_ra = Reactant.to_rarray(A) + + b = rand(T, 4) + b_ra = Reactant.to_rarray(b) + + B = rand(T, 4, 3) + B_ra = Reactant.to_rarray(B) + + @test @jit(solve_with_cholesky(A_ra, b_ra)) ≈ solve_with_cholesky(A, b) atol = + 1e-4 rtol = 1e-2 + @test @jit(solve_with_cholesky(A_ra, B_ra)) ≈ solve_with_cholesky(A, B) atol = + 1e-4 rtol = 1e-2 + end + end + + @testset "Batched" begin + @testset for T in (Float32, Float64, ComplexF32, ComplexF64) + (T == ComplexF64 || T == Float64) && RunningOnTPU && continue + + A = rand(T, 4, 4, 6) + A = reshape(stack((r * r' for r in eachslice(A; dims=3))), 4, 4, 3, 2) + A_ra = Reactant.to_rarray(A) + + b = rand(T, 4, 3, 2) + b_ra = Reactant.to_rarray(b) + + B = rand(T, 4, 5, 3, 2) + B_ra = Reactant.to_rarray(B) + + @test @jit(solve_with_cholesky(A_ra, b_ra)) ≈ solve_with_cholesky_batched(A, b) atol = + 1e-4 rtol = 1e-2 + @test @jit(solve_with_cholesky(A_ra, B_ra)) ≈ solve_with_cholesky_batched(A, B) atol = + 1e-4 rtol = 1e-2 + end + end +end + @testset "structure check" begin @testset "istriu" begin x = Reactant.TestUtils.construct_test_array(Float32, 8, 8)