diff --git a/src/Overlay.jl b/src/Overlay.jl index e837966cb7..f539ef8199 100644 --- a/src/Overlay.jl +++ b/src/Overlay.jl @@ -230,36 +230,42 @@ end end # LinearAlgebra -@reactant_overlay @noinline function LinearAlgebra.lu(x::AbstractArray; kwargs...) - if use_overlayed_version(x) - return TracedLinearAlgebra.overloaded_lu(x, RowMaximum(); kwargs...) - else - return Base.inferencebarrier(LinearAlgebra.lu)(x; kwargs...) - end -end -@reactant_overlay @noinline function LinearAlgebra.lu( - x::AbstractArray, pivot::RowMaximum; kwargs... -) - if use_overlayed_version(x) - return TracedLinearAlgebra.overloaded_lu(x, pivot; kwargs...) - else - return Base.inferencebarrier(LinearAlgebra.lu)(x, pivot; kwargs...) - end -end -@reactant_overlay @noinline function LinearAlgebra.lu!(x::AbstractArray; kwargs...) - if use_overlayed_version(x) - return TracedLinearAlgebra.overloaded_lu(x, RowMaximum(); kwargs...) - else - return Base.inferencebarrier(LinearAlgebra.lu!)(x; kwargs...) - end -end -@reactant_overlay @noinline function LinearAlgebra.lu!( - x::AbstractArray, pivot::RowMaximum; kwargs... +## Various factorizations +## TODO: specialize for `cholesky!` --> cholcopy +factorization_copy(f::F, x, pivot) where {F} = x +factorization_copy(f::F, x) where {F} = x + +for (jlop, rop, default_pivot) in ( + (:lu, :overloaded_lu, RowMaximum), + (:lu!, :overloaded_lu, RowMaximum), + (:cholesky, :overloaded_cholesky, NoPivot), + (:cholesky!, :overloaded_cholesky, NoPivot), ) - if use_overlayed_version(x) - return TracedLinearAlgebra.overloaded_lu(x, pivot; kwargs...) - else - return Base.inferencebarrier(LinearAlgebra.lu!)(x, pivot; kwargs...) + @eval begin + @reactant_overlay @noinline function LinearAlgebra.$(jlop)( + x::AbstractArray; kwargs... + ) + if use_overlayed_version(x) + pivot = $(default_pivot)() + return TracedLinearAlgebra.$(rop)( + factorization_copy(LinearAlgebra.$(jlop), x, pivot), pivot; kwargs... 
+ ) + else + return Base.inferencebarrier(LinearAlgebra.$(jlop))(x; kwargs...) + end + end + + @reactant_overlay @noinline function LinearAlgebra.$(jlop)( + x::AbstractArray, pivot::$(default_pivot); kwargs... + ) + if use_overlayed_version(x) + return TracedLinearAlgebra.$(rop)( + factorization_copy(LinearAlgebra.$(jlop), x, pivot), pivot; kwargs... + ) + else + return Base.inferencebarrier(LinearAlgebra.$(jlop))(x, pivot; kwargs...) + end + end end end diff --git a/src/Reactant.jl b/src/Reactant.jl index 69cee2b3b0..e164886fb9 100644 --- a/src/Reactant.jl +++ b/src/Reactant.jl @@ -3,7 +3,7 @@ module Reactant using ReactantCore: ReactantCore, @trace, within_compile, MissingTracedValue, materialize_traced_array -using LinearAlgebra: LinearAlgebra, RowMaximum +using LinearAlgebra: LinearAlgebra, RowMaximum, NoPivot using Random: Random, AbstractRNG using EnumX: @enumx using Functors: Functors, @leaf diff --git a/src/TracedRArray.jl b/src/TracedRArray.jl index 73dacadfb2..f01c57231b 100644 --- a/src/TracedRArray.jl +++ b/src/TracedRArray.jl @@ -730,11 +730,20 @@ end # stack function overloaded_stack(dims::Union{Integer,Colon}, xs) - @assert allequal([ndims(x) for x in xs]) "All arrays must have the same number of \ - dimensions..." - dims = dims isa Colon ? ndims(first(xs)) + 1 : dims + dims = dims isa Colon ? nothing : dims res = [] - for x in xs + prev_dims = nothing + for x in unwrapped_broadcast(identity, xs) + cur_dims = ndims(x) + if prev_dims === nothing + prev_dims = cur_dims + else + @assert prev_dims == cur_dims "All arrays must have the same number of \ + dimensions..." + end + + dims === nothing && (dims = cur_dims + 1) + new_shape = ntuple( i -> i == dims ? 1 : (i < dims ? 
size(x, i) : size(x, i - 1)), ndims(x) + 1 ) diff --git a/src/TracedRNumber.jl b/src/TracedRNumber.jl index 575c53c7eb..8b53c5237e 100644 --- a/src/TracedRNumber.jl +++ b/src/TracedRNumber.jl @@ -26,13 +26,25 @@ Base.copy(x::TracedRNumber{T}) where {T} = TracedRNumber{T}((), x.mlir_data) function Base.eps(::Type{TracedRNumber{T}}) where {T} return Reactant.promote_to(TracedRNumber{T}, eps(T)) end +Base.eps(x::TracedRNumber{T}) where {T} = eps(typeof(x)) function Base.typemin(::Type{TracedRNumber{T}}) where {T} return Reactant.promote_to(TracedRNumber{T}, typemin(T)) end +Base.typemin(x::TracedRNumber{T}) where {T} = typemin(typeof(x)) + function Base.typemax(::Type{TracedRNumber{T}}) where {T} return Reactant.promote_to(TracedRNumber{T}, typemax(T)) end +Base.typemax(x::TracedRNumber{T}) where {T} = typemax(typeof(x)) + +function Base.nextfloat(x::TracedRNumber{T}) where {T<:AbstractFloat} + return @opcall next_after(x, typemax(x)) +end + +function Base.prevfloat(x::TracedRNumber{T}) where {T<:AbstractFloat} + return @opcall next_after(x, typemin(x)) +end function Base.rtoldefault(T::Type{<:TracedRNumber}) return T(Base.rtoldefault(unwrapped_eltype(T))) diff --git a/src/stdlibs/LinearAlgebra.jl b/src/stdlibs/LinearAlgebra.jl index b92a5d1177..ff67859c5e 100644 --- a/src/stdlibs/LinearAlgebra.jl +++ b/src/stdlibs/LinearAlgebra.jl @@ -5,18 +5,20 @@ using ..Reactant: Reactant, Ops using ..Reactant: TracedRArray, TracedRNumber, AnyTracedRArray, AnyTracedRMatrix, AnyTracedRVector using ..Reactant: call_with_reactant -using ReactantCore: ReactantCore, materialize_traced_array +using ReactantCore: ReactantCore, materialize_traced_array, @trace using Reactant_jll: Reactant_jll using ..TracedUtils: TracedUtils, get_mlir_data, set_mlir_data! 
using ..Ops: @opcall using LinearAlgebra: LinearAlgebra, BLAS -using LinearAlgebra: Adjoint, Transpose, Factorization, RowMaximum +using LinearAlgebra: Adjoint, Transpose, Factorization, RowMaximum, NoPivot using LinearAlgebra: SymTridiagonal, Symmetric, Bidiagonal, Diagonal, Tridiagonal using LinearAlgebra: LowerTriangular, UnitLowerTriangular, UpperTriangular -using LinearAlgebra: diag, diagm, ldiv! +using LinearAlgebra: + diag, diagm, ldiv!, det, logabsdet, lu, istriu, istril, triu!, tril!, inv!, rmul! using Libdl: Libdl +using GPUArraysCore: @allowscalar function __init__() if Reactant_jll.is_available() @@ -38,6 +40,8 @@ function __init__() return nothing end +include("factorization/Factorization.jl") + # Various Wrapper Arrays defined in LinearAlgebra function ReactantCore.materialize_traced_array( x::Transpose{TracedRNumber{T},<:AnyTracedRArray} @@ -327,7 +331,7 @@ end # LinearAlgebra defines norm with some conditionals which cannot be traced directly function LinearAlgebra.norm(x::TracedRArray{T,N}, p::Real=2) where {T,N} isinf(p) && return maximum(abs, x) - return mapreduce(Base.Fix2(^, p), +, x)^(1 / p) + return mapreduce(Base.Fix2(^, p) ∘ abs, +, x)^(float(real(T))(1 / p)) end function LinearAlgebra._diagm(shape, kv::Pair{<:Integer,<:AnyTracedRVector}...)
@@ -631,203 +635,163 @@ LinearAlgebra.transpose!(B::AnyTracedRMatrix, A::AnyTracedRMatrix) = copy!(B, tr LinearAlgebra.adjoint!(B::AnyTracedRMatrix, A::AnyTracedRMatrix) = copy!(B, adjoint(A)) -# Supports batched factorization -abstract type GeneralizedFactorization{T} <: Factorization{T} end - -function LinearAlgebra.TransposeFactorization(f::GeneralizedFactorization) - return LinearAlgebra.TransposeFactorization{eltype(f),typeof(f)}(f) -end - -function LinearAlgebra.AdjointFactorization(f::GeneralizedFactorization) - return LinearAlgebra.AdjointFactorization{eltype(f),typeof(f)}(f) -end - -const GeneralizedTransposeFactorization{T} = - LinearAlgebra.TransposeFactorization{T,<:GeneralizedFactorization{T}} where {T} -const GeneralizedAdjointFactorization{T} = - LinearAlgebra.AdjointFactorization{T,<:GeneralizedFactorization{T}} where {T} - -# LU Factorization -struct GeneralizedLU{T,S<:AbstractArray,P<:AbstractArray,I<:Union{AbstractArray,Number}} <: - GeneralizedFactorization{T} - factors::S - ipiv::P - perm::P - info::I +# indexing into specific wrapped array types +# TODO: specialize these ones.
We don't need to make the arrays dense (though our passes +# should be able to optimize them out) +for AT in ( + Bidiagonal{<:TracedRNumber}, + LowerTriangular{<:TracedRNumber}, + UpperTriangular{<:TracedRNumber}, + LinearAlgebra.Hermitian{<:TracedRNumber}, + SymTridiagonal{<:TracedRNumber}, + Tridiagonal{<:TracedRNumber}, + Symmetric{<:TracedRNumber}, + LinearAlgebra.UnitUpperTriangular{<:TracedRNumber}, + LinearAlgebra.UnitLowerTriangular{<:TracedRNumber}, + LinearAlgebra.UpperHessenberg{<:TracedRNumber}, +) + @eval function Base.getindex(A::$AT, i::Int, j::Int) + return getindex(materialize_traced_array(A), i, j) + end end -Base.size(lu::GeneralizedLU) = size(lu.factors) - -Base.ndims(lu::GeneralizedLU) = ndims(lu.factors) +LinearAlgebra._istriu(A::AnyTracedRMatrix, k) = all(iszero, overloaded_tril(A, k - 1)) +LinearAlgebra._istril(A::AnyTracedRMatrix, k) = all(iszero, overloaded_triu(A, k + 1)) -function GeneralizedLU(factors::S, ipiv::P, perm::P, info::I) where {S,P,I} - @assert ndims(ipiv) == ndims(perm) == ndims(factors) - 1 - @assert ndims(info) == ndims(factors) - 2 - return GeneralizedLU{eltype(factors),S,P,I}(factors, ipiv, perm, info) +# Only needed because we lack automatic if tracing +function LinearAlgebra.det(A::AnyTracedRMatrix) + @trace if istriu(A) || istril(A) + _det = det(UpperTriangular(A)) + else + _det = det(lu(A; check=false)) + end + return _det end -function overloaded_lu(x::AbstractArray, args...; kwargs...) - return overloaded_lu(Reactant.promote_to(TracedRArray, x), args...; kwargs...) +function LinearAlgebra.logabsdet(A::AnyTracedRMatrix) + @trace if istriu(A) || istril(A) + _logabsdet = logabsdet(UpperTriangular(A)) + else + _logabsdet = logabsdet(lu(A; check=false)) + end + return _logabsdet end -function overloaded_lu( - A::AnyTracedRArray{T,N}, ::RowMaximum; check::Bool=false, allowsingular::Bool=false -) where {T,N} - # TODO: don't ignore the check and allowsingular flags - # Batching here is in the last dimensions. 
`Ops.lu` expects the last dimensions - permdims = vcat(collect(Int64, 3:N), 1, 2) - A = @opcall transpose(materialize_traced_array(A), permdims) - factors, ipiv, perm, info = @opcall lu(A) - - # Permute back to the original dimensions - perm_perm = vcat(N - 1, collect(Int64, 1:(N - 2))) - factors = @opcall transpose(factors, invperm(permdims)) - ipiv = @opcall transpose(ipiv, perm_perm) - perm = @opcall transpose(perm, perm_perm) - return GeneralizedLU(factors, ipiv, perm, info) +function LinearAlgebra.logabsdet( + A::Union{UpperTriangular{T,<:AnyTracedRMatrix},LowerTriangular{T,<:AnyTracedRMatrix}} +) where {T} + d = LinearAlgebra.diag(A) + sgn = prod(sign, d) + abs_det = sum(log ∘ abs, d) + return abs_det, sgn +end + +function Base.inv(A::TracedRArray{T,2}) where {T} # don't overload Any* here + LinearAlgebra.checksquare(A) + @trace if istriu(A) + Ai = triu!(parent(inv(UpperTriangular(A)))) + elseif istril(A) + Ai = tril!(parent(inv(LowerTriangular(A)))) + else + Ai = inv!(lu(A; check=false)) + end + return Ai end -function LinearAlgebra.ldiv!( - lu::GeneralizedLU{T,<:AbstractArray{T,N},P,I}, B::AbstractArray{T,M} -) where {T,P,I,N,M} - @assert N == M + 1 - ldiv!(lu, reshape(B, size(B, 1), 1, size(B)[2:end]...)) - return B +for (wT, lower, ud) in ( + (:UpperTriangular, false, false), + (:LowerTriangular, true, false), + (:UnitUpperTriangular, false, true), + (:UnitLowerTriangular, true, true), +) + @eval function Base.inv(A::LinearAlgebra.$(wT){T,<:AnyTracedRMatrix}) where {T} + S = typeof(inv(oneunit(Reactant.unwrapped_eltype(T)))) + rhs = Reactant.promote_to(TracedRArray{S,2}, LinearAlgebra.I(size(A, 1))) + return @opcall triangular_solve( + parent(A), + rhs; + left_side=false, + lower=$(lower), + transpose_a='N', + unit_diagonal=$(ud), + ) + end end -function LinearAlgebra.ldiv!( - lu::GeneralizedLU{T,<:AbstractArray{T,2},P,I}, B::AbstractArray{T,2} -) where {T,P,I} - B .= _lu_solve_core(lu.factors, B, lu.perm) - return B +function 
LinearAlgebra.cross(x::AnyTracedRVector, y::AbstractVector) + return LinearAlgebra.cross(x, Reactant.promote_to(TracedRArray{eltype(y),1}, y)) end -function LinearAlgebra.ldiv!( - lu::GeneralizedLU{T,<:AbstractArray{T,N},P,I}, B::AbstractArray{T,N} -) where {T,P,I,N} - batch_shape = size(lu.factors)[3:end] - @assert batch_shape == size(B)[3:end] - - permutation = vcat(collect(Int64, 3:N), 1, 2) - - factors = @opcall transpose(materialize_traced_array(lu.factors), permutation) - B_permuted = @opcall transpose(materialize_traced_array(B), permutation) - perm = @opcall transpose( - materialize_traced_array(lu.perm), vcat(collect(Int64, 2:(N - 1)), 1) - ) - - res = @opcall transpose( - only( - @opcall( - batch( - _lu_solve_core, [factors, B_permuted, perm], collect(Int64, batch_shape) - ) - ), - ), - invperm(permutation), - ) - B .= res - return B +function LinearAlgebra.cross(x::AbstractVector, y::AnyTracedRVector) + return LinearAlgebra.cross(Reactant.promote_to(TracedRArray{eltype(x),1}, x), y) end -function LinearAlgebra.det(lu::GeneralizedLU{T,<:AbstractMatrix}) where {T} - n = LinearAlgebra.checksquare(lu) - # TODO: check for non-singular matrices - - P = prod(LinearAlgebra.diag(lu.factors)) - return ifelse(isodd(sum(lu.ipiv[1:n] .!= (1:n))), -one(T), one(T)) * P +function LinearAlgebra.cross(x::AnyTracedRVector, y::AnyTracedRVector) + length(x) == length(y) == 3 || throw(DimensionMismatch("cross product is only defined for vectors of length 3")) + x_, y_ = materialize_traced_array(x), materialize_traced_array(y) + @allowscalar a1, a2, a3 = x_ + @allowscalar b1, b2, b3 = y_ + return Reactant.aos_to_soa([a2 * b3 - a3 * b2, a3 * b1 - a1 * b3, a1 * b2 - a2 * b1]) end -function LinearAlgebra.logabsdet(lu::GeneralizedLU{T,<:AbstractMatrix}) where {T} - n = LinearAlgebra.checksquare(lu) - Treal = real(T) - # TODO: check for non-singular matrices - - d = LinearAlgebra.diag(lu.factors) - absdet = sum(log ∘ abs, d) - P = prod(sign, d) - s = ifelse(isodd(sum(lu.ipiv[1:n] .!= (1:n))), -one(Treal), one(Treal)) * P - return absdet, s +function
LinearAlgebra.issymmetric(A::AnyTracedRMatrix) + axes(A, 1) == axes(A, 2) || return false + return all(A .== transpose(A)) end -for f_wrapper in (LinearAlgebra.TransposeFactorization, LinearAlgebra.AdjointFactorization), - aType in (:AbstractVecOrMat, :AbstractArray) - - @eval function LinearAlgebra.ldiv!(lu::$(f_wrapper){<:Any,<:GeneralizedLU}, B::$aType) - # TODO: implement this - error("`$(f_wrapper)` is not supported yet for LU.") - return nothing - end +function LinearAlgebra.ishermitian(A::AnyTracedRMatrix) + axes(A, 1) == axes(A, 2) || return false + return all(A .== adjoint(A)) end -function _lu_solve_core(factors::AbstractMatrix, B::AbstractMatrix, perm::AbstractVector) - permuted_B = B[Int64.(perm), :] - return UpperTriangular(factors) \ (UnitLowerTriangular(factors) \ permuted_B) +function LinearAlgebra.isbanded(A::AnyTracedRMatrix, kl::Integer, ku::Integer) + return istriu(A, kl) & istril(A, ku) end -# Overload \ to support batched factorization -for FT in ( - :GeneralizedFactorization, - :GeneralizedTransposeFactorization, - :GeneralizedAdjointFactorization, -) - for aType in (:AbstractVecOrMat, :AbstractArray) - @eval Base.:(\)(F::$FT, B::$aType) = _overloaded_backslash(F, B) +function LinearAlgebra.normalize(a::AnyTracedRArray{T}, p::Real=2) where {T} + nrm = LinearAlgebra.norm(a, p) + if !isempty(a) + aa = LinearAlgebra.copymutable_oftype(a, typeof(zero(T) / nrm)) + return LinearAlgebra.__normalize!(aa, nrm) + else + return typeof(zero(T) / nrm)[] end - - @eval Base.:(\)( - F::$FT{T}, B::Union{Array{Complex{T},1},Array{Complex{T},2}} - ) where {T<:Union{Float32,Float64}} = _overloaded_backslash(F, B) end -function _overloaded_backslash(F::GeneralizedFactorization, B::AbstractArray) - return ldiv!( - F, LinearAlgebra.copy_similar(B, typeof(oneunit(eltype(F)) \ oneunit(eltype(B)))) - ) -end - -function _overloaded_backslash(F::GeneralizedTransposeFactorization, B::AbstractArray) - return conj!(adjoint(F.parent) \ conj.(B)) +@static if 
isdefined(LinearAlgebra, :__normalize!) + function LinearAlgebra.__normalize!(a::AnyTracedRArray, nrm) + # The largest positive floating point number whose inverse is less than infinity + δ = inv(prevfloat(typemax(nrm))) + @trace if nrm ≥ δ # Safe to multiply with inverse + invnrm = inv(nrm) + rmul!(a, invnrm) + else # scale elements to avoid overflow + εδ = eps(one(nrm)) / δ + rmul!(a, εδ) + rmul!(a, inv(nrm * εδ)) + end + return a + end end -function _overloaded_backslash(F::GeneralizedAdjointFactorization, B::AbstractArray) - return ldiv!( - F, LinearAlgebra.copy_similar(B, typeof(oneunit(eltype(F)) \ oneunit(eltype(B)))) - ) +function LinearAlgebra.rmul!(A::AnyTracedRArray, b::Number) + @. A *= b + return A end -# indexing into specific wrapepd array types -# TODO: specialize these ones. We don't need to make the arrays dense (though our passes -# should be able to optimize them out) -for AT in ( - Bidiagonal{<:TracedRNumber}, - LowerTriangular{<:TracedRNumber}, - UpperTriangular{<:TracedRNumber}, - LinearAlgebra.Hermitian{<:TracedRNumber}, - SymTridiagonal{<:TracedRNumber}, - Tridiagonal{<:TracedRNumber}, - Symmetric{<:TracedRNumber}, - LinearAlgebra.UnitUpperTriangular{<:TracedRNumber}, - LinearAlgebra.UnitLowerTriangular{<:TracedRNumber}, - LinearAlgebra.UpperHessenberg{<:TracedRNumber}, -) - @eval function Base.getindex(A::$AT, i::Int, j::Int) - return getindex(materialize_traced_array(A), i, j) - end +function LinearAlgebra.lmul!(b::Number, A::AnyTracedRArray) + @. A = b * A + return A end -LinearAlgebra._istriu(A::AnyTracedRMatrix, k) = all(iszero, overloaded_tril(A, k - 1)) -LinearAlgebra._istril(A::AnyTracedRMatrix, k) = all(iszero, overloaded_triu(A, k + 1)) - -# Only needed because we lack automatic if tracing -function LinearAlgebra.det(A::AnyTracedRMatrix) - # FIXME: using @trace here produces the cryptic UndefVarError - return LinearAlgebra.det(LinearAlgebra.lu(A; check=false)) +function LinearAlgebra.rdiv!(A::AnyTracedRArray, b::Number) + @. 
A /= b + return A end -function LinearAlgebra.logabsdet(A::AnyTracedRMatrix) - # FIXME: using @trace here produces the cryptic UndefVarError - return LinearAlgebra.logabsdet(LinearAlgebra.lu(A; check=false)) +function LinearAlgebra.ldiv!(b::Number, A::AnyTracedRArray) + @. A = b \ A + return A end end diff --git a/src/stdlibs/factorization/Cholesky.jl b/src/stdlibs/factorization/Cholesky.jl new file mode 100644 index 0000000000..e0f30ad586 --- /dev/null +++ b/src/stdlibs/factorization/Cholesky.jl @@ -0,0 +1,96 @@ +struct GeneralizedCholesky{T,S<:AbstractArray,I<:Union{AbstractArray,Number}} <: + GeneralizedFactorization{T} + factors::S + uplo::Char + info::I +end + +function GeneralizedCholesky(factors::S, uplo::Char, info::I) where {S,I} + @assert ndims(info) == ndims(factors) - 2 + return GeneralizedCholesky{eltype(factors),S,I}(factors, uplo, info) +end + +Base.size(c::GeneralizedCholesky) = size(c.factors) +Base.ndims(c::GeneralizedCholesky) = ndims(c.factors) + +function overloaded_cholesky(A::AbstractArray, ::NoPivot; check::Bool=false) + return overloaded_cholesky(Reactant.promote_to(TracedRArray, A), NoPivot(); check) +end + +function overloaded_cholesky( + A::AnyTracedRArray{T,N}, ::NoPivot; check::Bool=false +) where {T,N} + # TODO: don't ignore the `check` flag + # permute so the batch dims (3:N) come first and the matrix dims (1, 2) come last + permdims = vcat(collect(Int64, 3:N), 1, 2) + A = @opcall transpose(materialize_traced_array(A), permdims) + + factors = @opcall cholesky(A; lower=false) + factors = @opcall transpose(factors, invperm(permdims)) + + # stablehlo doesn't return the info, so derive a flag per matrix that is true iff every entry of the upper triangle of `factors` is finite + info = materialize_traced_array( + dropdims( + Reactant.CallWithReactant(mapreduce)( + isfinite, &, mapslices(LinearAlgebra.triu, factors; dims=(1, 2)); dims=1:2 + ); + dims=(1, 2), + ), + ) + if N == 2 + info = TracedRNumber{Bool}((), info.mlir_data) + end + + return GeneralizedCholesky(factors, 'U', info) +end + +function LinearAlgebra.ldiv!( + F::GeneralizedCholesky{T,<:AbstractArray{T,N}}, B::AbstractArray{T,M} +)
where {T,N,M} + @assert N == M + 1 + ldiv!(F, reshape(B, size(B, 1), 1, size(B)[2:end]...)) + return B +end + +function LinearAlgebra.ldiv!( + F::GeneralizedCholesky{T,<:AbstractArray{T,2}}, B::AbstractArray{T,2} +) where {T} + B .= _cholesky_solve_core(F.factors, B, F.uplo) + return B +end + +function LinearAlgebra.ldiv!( + F::GeneralizedCholesky{T,<:AbstractArray{T,N}}, B::AbstractArray{T,N} +) where {T,N} + batch_shape = size(F.factors)[3:end] + @assert batch_shape == size(B)[3:end] + + base_fn = F.uplo == 'U' ? _cholesky_solve_core_upper : _cholesky_solve_core_lower + + permutation = vcat(collect(Int64, 3:N), 1, 2) + + factors = @opcall transpose(materialize_traced_array(F.factors), permutation) + B_permuted = @opcall transpose(materialize_traced_array(B), permutation) + + res = @opcall transpose( + only(@opcall(batch(base_fn, [factors, B_permuted], collect(Int64, batch_shape)))), + invperm(permutation), + ) + B .= res + return B +end + +function _cholesky_solve_core(factors::AbstractMatrix, B::AbstractMatrix, uplo::Char) + if uplo == 'U' + return _cholesky_solve_core_upper(factors, B) + else + return _cholesky_solve_core_lower(factors, B) + end +end + +function _cholesky_solve_core_lower(factors::AbstractMatrix, B::AbstractMatrix) + return adjoint(LowerTriangular(factors)) \ (LowerTriangular(factors) \ B) +end +function _cholesky_solve_core_upper(factors::AbstractMatrix, B::AbstractMatrix) + return UpperTriangular(factors) \ (adjoint(UpperTriangular(factors)) \ B) +end diff --git a/src/stdlibs/factorization/Factorization.jl b/src/stdlibs/factorization/Factorization.jl new file mode 100644 index 0000000000..5114a47d7b --- /dev/null +++ b/src/stdlibs/factorization/Factorization.jl @@ -0,0 +1,49 @@ +# Supports batched factorization +abstract type GeneralizedFactorization{T} <: Factorization{T} end + +function LinearAlgebra.TransposeFactorization(f::GeneralizedFactorization) + return LinearAlgebra.TransposeFactorization{eltype(f),typeof(f)}(f) +end + +function 
LinearAlgebra.AdjointFactorization(f::GeneralizedFactorization) + return LinearAlgebra.AdjointFactorization{eltype(f),typeof(f)}(f) +end + +const GeneralizedTransposeFactorization{T} = + LinearAlgebra.TransposeFactorization{T,<:GeneralizedFactorization{T}} where {T} +const GeneralizedAdjointFactorization{T} = + LinearAlgebra.AdjointFactorization{T,<:GeneralizedFactorization{T}} where {T} + +include("Cholesky.jl") +include("LU.jl") + +# Overload \ to support batched factorization +for FT in ( + :GeneralizedFactorization, + :GeneralizedTransposeFactorization, + :GeneralizedAdjointFactorization, +) + for aType in (:AbstractVecOrMat, :AbstractArray) + @eval Base.:(\)(F::$FT, B::$aType) = _overloaded_backslash(F, B) + end + + @eval Base.:(\)( + F::$FT{T}, B::Union{Array{Complex{T},1},Array{Complex{T},2}} + ) where {T<:Union{Float32,Float64}} = _overloaded_backslash(F, B) +end + +function _overloaded_backslash(F::GeneralizedFactorization, B::AbstractArray) + return ldiv!( + F, LinearAlgebra.copy_similar(B, typeof(oneunit(eltype(F)) \ oneunit(eltype(B)))) + ) +end + +function _overloaded_backslash(F::GeneralizedTransposeFactorization, B::AbstractArray) + return conj!(adjoint(F.parent) \ conj.(B)) +end + +function _overloaded_backslash(F::GeneralizedAdjointFactorization, B::AbstractArray) + return ldiv!( + F, LinearAlgebra.copy_similar(B, typeof(oneunit(eltype(F)) \ oneunit(eltype(B)))) + ) +end diff --git a/src/stdlibs/factorization/LU.jl b/src/stdlibs/factorization/LU.jl new file mode 100644 index 0000000000..144e7b9dcd --- /dev/null +++ b/src/stdlibs/factorization/LU.jl @@ -0,0 +1,131 @@ +struct GeneralizedLU{T,S<:AbstractArray,P<:AbstractArray,I<:Union{AbstractArray,Number}} <: + GeneralizedFactorization{T} + factors::S + ipiv::P + perm::P + info::I +end + +Base.size(lu::GeneralizedLU) = size(lu.factors) +Base.size(lu::GeneralizedLU, i) = size(lu.factors, i) +Base.ndims(lu::GeneralizedLU) = ndims(lu.factors) +function Base.copy(lu::GeneralizedLU) + return 
GeneralizedLU(copy(lu.factors), copy(lu.ipiv), copy(lu.perm), copy(lu.info)) +end + +function GeneralizedLU(factors::S, ipiv::P, perm::P, info::I) where {S,P,I} + @assert ndims(ipiv) == ndims(perm) == ndims(factors) - 1 + @assert ndims(info) == ndims(factors) - 2 + return GeneralizedLU{eltype(factors),S,P,I}(factors, ipiv, perm, info) +end + +function overloaded_lu(x::AbstractArray, args...; kwargs...) + return overloaded_lu(Reactant.promote_to(TracedRArray, x), args...; kwargs...) +end + +function overloaded_lu( + A::AnyTracedRArray{T,N}, ::RowMaximum; check::Bool=false, allowsingular::Bool=false +) where {T,N} + # TODO: don't ignore the `check` and `allowsingular` flags + # Batch dims are the trailing dims (3:N) of `A`. Permute so they come first and the matrix dims (1, 2) come last before calling `Ops.lu` + permdims = vcat(collect(Int64, 3:N), 1, 2) + A = @opcall transpose(materialize_traced_array(A), permdims) + factors, ipiv, perm, info = @opcall lu(A) + + # Permute back to the original dimensions (`ipiv` and `perm` have one dimension fewer than `factors`) + perm_perm = vcat(N - 1, collect(Int64, 1:(N - 2))) + factors = @opcall transpose(factors, invperm(permdims)) + ipiv = @opcall transpose(ipiv, perm_perm) + perm = @opcall transpose(perm, perm_perm) + return GeneralizedLU(factors, ipiv, perm, info) +end + +function LinearAlgebra.ldiv!( + lu::GeneralizedLU{T,<:AbstractArray{T,N},P,I}, B::AbstractArray{T,M} +) where {T,P,I,N,M} + @assert N == M + 1 + ldiv!(lu, reshape(B, size(B, 1), 1, size(B)[2:end]...)) + return B +end + +function LinearAlgebra.ldiv!( + lu::GeneralizedLU{T,<:AbstractArray{T,2},P,I}, B::AbstractArray{T,2} +) where {T,P,I} + B .= _lu_solve_core(lu.factors, B, lu.perm) + return B +end + +function LinearAlgebra.ldiv!( + lu::GeneralizedLU{T,<:AbstractArray{T,N},P,I}, B::AbstractArray{T,N} +) where {T,P,I,N} + batch_shape = size(lu.factors)[3:end] + @assert batch_shape == size(B)[3:end] + + permutation = vcat(collect(Int64, 3:N), 1, 2) + + factors = @opcall transpose(materialize_traced_array(lu.factors), permutation) + B_permuted = @opcall
transpose(materialize_traced_array(B), permutation) + perm = @opcall transpose( + materialize_traced_array(lu.perm), vcat(collect(Int64, 2:(N - 1)), 1) + ) + + res = @opcall transpose( + only( + @opcall( + batch( + _lu_solve_core, [factors, B_permuted, perm], collect(Int64, batch_shape) + ) + ), + ), + invperm(permutation), + ) + B .= res + return B +end + +function LinearAlgebra.det(lu::GeneralizedLU{T,<:AbstractMatrix}) where {T} + n = LinearAlgebra.checksquare(lu) + # TODO: check for non-singular matrices + + P = prod(LinearAlgebra.diag(lu.factors)) + return ifelse(isodd(sum(lu.ipiv[1:n] .!= (1:n))), -one(T), one(T)) * P +end + +function LinearAlgebra.logabsdet(lu::GeneralizedLU{T,<:AbstractMatrix}) where {T} + n = LinearAlgebra.checksquare(lu) + Treal = real(T) + # TODO: check for non-singular matrices + + d = LinearAlgebra.diag(lu.factors) + absdet = sum(log ∘ abs, d) + P = prod(sign, d) + s = ifelse(isodd(sum(lu.ipiv[1:n] .!= (1:n))), -one(Treal), one(Treal)) * P + return absdet, s +end + +for f_wrapper in (LinearAlgebra.TransposeFactorization, LinearAlgebra.AdjointFactorization), + aType in (:AbstractVecOrMat, :AbstractArray) + + @eval function LinearAlgebra.ldiv!(lu::$(f_wrapper){<:Any,<:GeneralizedLU}, B::$aType) + # TODO: implement this + error("`$(f_wrapper)` is not supported yet for LU.") + return nothing + end +end + +# currently we lower inverse to lu decomposition + triangular solve. we should +# instead emit getri and lower that to a fallback if the backend doesn't support +# it. 
+function LinearAlgebra.inv!(lu::GeneralizedLU) + @assert ndims(lu) == 2 "Only implemented for 2D tensors" + rhs = Reactant.promote_to( + TracedRArray{Reactant.unwrapped_eltype(eltype(lu)),2}, LinearAlgebra.I(size(lu, 1)) + ) + ldiv!(lu, rhs) + return rhs +end + +function _lu_solve_core(factors::AbstractMatrix, B::AbstractMatrix, perm::AbstractVector) + permuted_B = B[Int64.(perm), :] + return UpperTriangular(factors) \ (UnitLowerTriangular(factors) \ permuted_B) +end diff --git a/test/integration/linear_algebra.jl b/test/integration/linear_algebra.jl index ea05cfbf0f..cbf2cb02af 100644 --- a/test/integration/linear_algebra.jl +++ b/test/integration/linear_algebra.jl @@ -351,24 +351,31 @@ end end end -solve_with_lu(A, b) = lu(A) \ b -function solve_with_lu_batched(A::AbstractArray{T,N}, B::AbstractArray{T,N}) where {T,N} +solve_with_fact(f::F, A, b) where {F} = f(A) \ b +function solve_with_fact_batched( + f::F, A::AbstractArray{T,N}, B::AbstractArray{T,N} +) where {F,T,N} A2 = reshape(A, size(A, 1), size(A, 2), prod(size(A)[3:end])) B2 = reshape(B, size(B, 1), size(B, 2), prod(size(B)[3:end])) @assert size(A2, 3) == size(B2, 3) return reshape( - stack(lu(view(A2, :, :, i)) \ view(B2, :, :, i) for i in axes(A2, 3)), + stack(f(view(A2, :, :, i)) \ view(B2, :, :, i) for i in axes(A2, 3)), size(A2, 1), size(B2, 2), size(A)[3:end]..., ) end -function solve_with_lu_batched(A::AbstractArray{T,N}, b::AbstractArray{T,M}) where {T,N,M} +function solve_with_fact_batched( + f::F, A::AbstractArray{T,N}, b::AbstractArray{T,M} +) where {F,T,N,M} @assert N == M + 1 B = reshape(b, size(b, 1), 1, size(b)[2:end]...) 
- return dropdims(solve_with_lu_batched(A, B); dims=2) + return dropdims(solve_with_fact_batched(f, A, B); dims=2) end +solve_with_lu(A, b) = solve_with_fact(lu, A, b) +solve_with_lu_batched(A, b) = solve_with_fact_batched(lu, A, b) + @testset "LU Factorization" begin @testset "Un-batched" begin @testset for T in (Float32, Float64, ComplexF32, ComplexF64) @@ -433,34 +440,185 @@ end end end -@testset "istriu" begin - x = Reactant.TestUtils.construct_test_array(Float32, 8, 8) - x_triu = triu(x, 4) - x_triu_ra = Reactant.to_rarray(x_triu) - @test Bool(@jit(LinearAlgebra.istriu(x_triu_ra, 4))) - @test Bool(@jit(LinearAlgebra.istriu(x_triu_ra, 3))) - @test !Bool(@jit(LinearAlgebra.istriu(x_triu_ra, 5))) +solve_with_cholesky(A, b) = solve_with_fact(cholesky, A, b) +solve_with_cholesky_batched(A, b) = solve_with_fact_batched(cholesky, A, b) + +function random_matrix_with_cond( + ::Type{T}, rows::Int, cols::Int, cond_number::Float64 +) where {T} + # Generate random orthogonal matrices U and V + U = ( + LinearAlgebra.qr(randn(rows, rows)).Q * + Diagonal(sign.(diag(LinearAlgebra.qr(randn(rows, rows)).R))) + ) + V = ( + LinearAlgebra.qr(randn(cols, cols)).Q * + Diagonal(sign.(diag(LinearAlgebra.qr(randn(cols, cols)).R))) + ) + + min_dim = min(rows, cols) + singular_values = exp.(range(log(1.0), log(1.0 / cond_number); length=min_dim)) + + S = zeros(Float64, rows, cols) + @inbounds for i in 1:min_dim + S[i, i] = singular_values[i] + end + + return T.(U * S * V') +end + +@testset "Cholesky Factorization" begin + @testset "Un-batched" begin + @testset for T in (Float32, Float64, ComplexF32, ComplexF64) + (T == ComplexF64 || T == Float64) && RunningOnTPU && continue + + A = random_matrix_with_cond(T, 4, 4, 1.001) # avoid ill conditioned + A = A * A' + A_ra = Reactant.to_rarray(A) + + b = rand(T, 4) + b_ra = Reactant.to_rarray(b) + + B = rand(T, 4, 3) + B_ra = Reactant.to_rarray(B) + + @test @jit(solve_with_cholesky(A_ra, b_ra)) ≈ solve_with_cholesky(A, b) atol = + 1e-4 rtol = 
1e-2 + @test @jit(solve_with_cholesky(A_ra, B_ra)) ≈ solve_with_cholesky(A, B) atol = + 1e-4 rtol = 1e-2 + end + end + + @testset "Batched" begin + @testset for T in (Float32, Float64, ComplexF32, ComplexF64) + (T == ComplexF64 || T == Float64) && RunningOnTPU && continue + + A = stack(random_matrix_with_cond(T, 4, 4, 1.001) for _ in 1:6) + A = reshape(stack((r * r' for r in eachslice(A; dims=3))), 4, 4, 3, 2) + A_ra = Reactant.to_rarray(A) + + b = rand(T, 4, 3, 2) + b_ra = Reactant.to_rarray(b) + + B = rand(T, 4, 5, 3, 2) + B_ra = Reactant.to_rarray(B) + + @test @jit(solve_with_cholesky(A_ra, b_ra)) ≈ solve_with_cholesky_batched(A, b) atol = + 1e-4 rtol = 1e-2 + @test @jit(solve_with_cholesky(A_ra, B_ra)) ≈ solve_with_cholesky_batched(A, B) atol = + 1e-4 rtol = 1e-2 + end + end end -@testset "istril" begin - x = Reactant.TestUtils.construct_test_array(Float32, 8, 8) - x_tril = tril(x, -4) - x_tril_ra = Reactant.to_rarray(x_tril) - @test Bool(@jit(LinearAlgebra.istril(x_tril_ra, -4))) - @test Bool(@jit(LinearAlgebra.istril(x_tril_ra, -3))) - @test !Bool(@jit(LinearAlgebra.istril(x_tril_ra, -5))) +@testset "structure check" begin + @testset "istriu" begin + x = Reactant.TestUtils.construct_test_array(Float32, 8, 8) + x_triu = triu(x, 4) + x_triu_ra = Reactant.to_rarray(x_triu) + @test Bool(@jit(LinearAlgebra.istriu(x_triu_ra, 4))) + @test Bool(@jit(LinearAlgebra.istriu(x_triu_ra, 3))) + @test !Bool(@jit(LinearAlgebra.istriu(x_triu_ra, 5))) + end + + @testset "istril" begin + x = Reactant.TestUtils.construct_test_array(Float32, 8, 8) + x_tril = tril(x, -4) + x_tril_ra = Reactant.to_rarray(x_tril) + @test Bool(@jit(LinearAlgebra.istril(x_tril_ra, -4))) + @test Bool(@jit(LinearAlgebra.istril(x_tril_ra, -3))) + @test !Bool(@jit(LinearAlgebra.istril(x_tril_ra, -5))) + end + + @testset "banded" begin + x = Reactant.TestUtils.construct_test_array(Float32, 8, 8) + x = tril(triu(x, -3), 4) + x_ra = Reactant.to_rarray(x) + + @test Bool(@jit(LinearAlgebra.isbanded(x_ra, -3, 
4))) + @test Bool(@jit(LinearAlgebra.isbanded(x_ra, -3, 5))) + @test Bool(@jit(LinearAlgebra.isbanded(x_ra, -4, 4))) + @test !Bool(@jit(LinearAlgebra.isbanded(x_ra, -2, 4))) + @test !Bool(@jit(LinearAlgebra.isbanded(x_ra, -3, 3))) + end + + @testset "issymmetric/ishermitian" begin + x = Reactant.TestUtils.construct_test_array(Float32, 8, 8) + x2 = x .+ x' + x_ra = Reactant.to_rarray(x) + x2_ra = Reactant.to_rarray(x2) + @test Bool(@jit(LinearAlgebra.issymmetric(x2_ra))) + @test Bool(@jit(LinearAlgebra.ishermitian(x2_ra))) + @test !Bool(@jit(LinearAlgebra.issymmetric(x_ra))) + @test !Bool(@jit(LinearAlgebra.ishermitian(x_ra))) + + x = Reactant.TestUtils.construct_test_array(ComplexF32, 8, 8) + x2 = x .+ x' + x_ra = Reactant.to_rarray(x) + x2_ra = Reactant.to_rarray(x2) + @test !Bool(@jit(LinearAlgebra.issymmetric(x2_ra))) + @test Bool(@jit(LinearAlgebra.ishermitian(x2_ra))) + @test !Bool(@jit(LinearAlgebra.issymmetric(x_ra))) + @test !Bool(@jit(LinearAlgebra.ishermitian(x_ra))) + end end @testset "det" begin - x = Reactant.TestUtils.construct_test_array(Float32, 8, 8) + x_lowtri = Float32[1 0; 2 2] + x_reg = Float32[1 -1; 2 2] + x_uptri = Float32[1 2; 0 2] + + for x in (x_lowtri, x_reg, x_uptri) + x_ra = Reactant.to_rarray(x) + + res_ra = @jit LinearAlgebra.logabsdet(x_ra) + res = LinearAlgebra.logabsdet(x) + @test res_ra[1] ≈ res[1] + @test res_ra[2] ≈ res[2] + + res_ra = @jit LinearAlgebra.det(x_ra) + res = LinearAlgebra.det(x) + @test res_ra ≈ res + end +end + +@testset "inv" begin + x_lowtri = Float32[1 0; 2 2] + x_reg = Float32[1 -1; 2 2] + x_uptri = Float32[1 2; 0 2] + + for x in (x_lowtri, x_reg, x_uptri) + x_ra = Reactant.to_rarray(x) + + res_ra = @jit inv(x_ra) + res = inv(x) + @test res_ra ≈ res + end +end + +@testset "norm accidental promotion" begin + x_ra = Reactant.to_rarray(rand(Float32, 4, 4)) + @test @jit(norm(x_ra)) isa ConcreteRNumber{Float32} +end + +@testset "cross" begin + x = Float32[0; 1; 0] + x_ra = Reactant.to_rarray(x) + y = Float32[0; 0; 
1] + y_ra = Reactant.to_rarray(y) + + @test @jit(LinearAlgebra.cross(x_ra, y_ra)) ≈ LinearAlgebra.cross(x, y) + @test @jit(LinearAlgebra.cross(x_ra, y)) ≈ LinearAlgebra.cross(x, y) + @test @jit(LinearAlgebra.cross(x, y_ra)) ≈ LinearAlgebra.cross(x, y) +end + +@testset "normalize/normalize!" begin + x = Reactant.TestUtils.construct_test_array(Float32, 4, 4) x_ra = Reactant.to_rarray(x) - res_ra = @jit LinearAlgebra.logabsdet(x_ra) - res = LinearAlgebra.logabsdet(x) - @test res_ra[1] ≈ res[1] - @test res_ra[2] ≈ res[2] + @test @jit(LinearAlgebra.normalize(x_ra)) ≈ LinearAlgebra.normalize(x) - res_ra = @jit LinearAlgebra.det(x_ra) - res = LinearAlgebra.det(x) - @test res_ra ≈ res + LinearAlgebra.normalize!(x) + @jit LinearAlgebra.normalize!(x_ra) + @test x_ra ≈ x end