From 5006025033cf5559de75e2b69100f724a44566dc Mon Sep 17 00:00:00 2001
From: Avik Pal
Date: Wed, 19 Nov 2025 10:06:00 -0500
Subject: [PATCH 1/7] refactor: rename to batched to avoid ambiguity

---
 src/Overlay.jl                             | 16 ++++++++++
 src/stdlibs/factorization/Cholesky.jl      | 20 ++++++-------
 src/stdlibs/factorization/Factorization.jl | 16 +++++-----
 src/stdlibs/factorization/LU.jl            | 34 +++++++++++-----------
 src/stdlibs/factorization/SVD.jl           | 10 +++++++
 5 files changed, 61 insertions(+), 35 deletions(-)
 create mode 100644 src/stdlibs/factorization/SVD.jl

diff --git a/src/Overlay.jl b/src/Overlay.jl
index f539ef8199..ffe4a4cad3 100644
--- a/src/Overlay.jl
+++ b/src/Overlay.jl
@@ -269,6 +269,22 @@ for (jlop, rop, default_pivot) in (
     end
 end
 
+for (jlop, rop) in ((:svd, :overloaded_svd),)
+    @eval begin
+        @reactant_overlay @noinline function LinearAlgebra.$(jlop)(
+            x::AbstractArray; kwargs...
+        )
+            if use_overlayed_version(x)
+                return TracedLinearAlgebra.$(rop)(
+                    factorization_copy(LinearAlgebra.$(jlop), x); kwargs...
+                )
+            else
+                return Base.inferencebarrier(LinearAlgebra.$(jlop))(x; kwargs...)
+            end
+        end
+    end
+end
+
 @reactant_overlay @noinline function LinearAlgebra.dot(x::AbstractArray, y::AbstractArray)
     if use_overlayed_version(x) || use_overlayed_version(y)
         return TracedLinearAlgebra.overloaded_dot(x, y)
diff --git a/src/stdlibs/factorization/Cholesky.jl b/src/stdlibs/factorization/Cholesky.jl
index e0f30ad586..efe9188903 100644
--- a/src/stdlibs/factorization/Cholesky.jl
+++ b/src/stdlibs/factorization/Cholesky.jl
@@ -1,17 +1,17 @@
-struct GeneralizedCholesky{T,S<:AbstractArray,I<:Union{AbstractArray,Number}} <:
-       GeneralizedFactorization{T}
+struct BatchedCholesky{T,S<:AbstractArray,I<:Union{AbstractArray,Number}} <:
+       BatchedFactorization{T}
     factors::S
     uplo::Char
     info::I
 end
 
-function GeneralizedCholesky(factors::S, uplo::Char, info::I) where {S,I}
+function BatchedCholesky(factors::S, uplo::Char, info::I) where {S,I}
     @assert ndims(info) == ndims(factors) - 2
-    return GeneralizedCholesky{eltype(factors),S,I}(factors, uplo, info)
+    return BatchedCholesky{eltype(factors),S,I}(factors, uplo, info)
 end
 
-Base.size(c::GeneralizedCholesky) = size(c.factors)
-Base.ndims(c::GeneralizedCholesky) = ndims(c.factors)
+Base.size(c::BatchedCholesky) = size(c.factors)
+Base.ndims(c::BatchedCholesky) = ndims(c.factors)
 
 function overloaded_cholesky(A::AbstractArray, ::NoPivot; check::Bool=false)
     return overloaded_cholesky(Reactant.promote_to(TracedRArray, A), NoPivot(); check)
@@ -41,11 +41,11 @@ function overloaded_cholesky(
         info = TracedRNumber{Bool}((), info.mlir_data)
     end
 
-    return GeneralizedCholesky(factors, 'U', info)
+    return BatchedCholesky(factors, 'U', info)
 end
 
 function LinearAlgebra.ldiv!(
-    F::GeneralizedCholesky{T,<:AbstractArray{T,N}}, B::AbstractArray{T,M}
+    F::BatchedCholesky{T,<:AbstractArray{T,N}}, B::AbstractArray{T,M}
 ) where {T,N,M}
     @assert N == M + 1
     ldiv!(F, reshape(B, size(B, 1), 1, size(B)[2:end]...))
@@ -53,14 +53,14 @@ function LinearAlgebra.ldiv!(
 end
 
 function LinearAlgebra.ldiv!(
-    F::GeneralizedCholesky{T,<:AbstractArray{T,2}}, B::AbstractArray{T,2}
+    F::BatchedCholesky{T,<:AbstractArray{T,2}}, B::AbstractArray{T,2}
 ) where {T}
     B .= _cholesky_solve_core(F.factors, B, F.uplo)
     return B
 end
 
 function LinearAlgebra.ldiv!(
-    F::GeneralizedCholesky{T,<:AbstractArray{T,N}}, B::AbstractArray{T,N}
+    F::BatchedCholesky{T,<:AbstractArray{T,N}}, B::AbstractArray{T,N}
 ) where {T,N}
     batch_shape = size(F.factors)[3:end]
     @assert batch_shape == size(B)[3:end]
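
For reference, the contract of the batched `ldiv!` methods above is a per-slice solve over the trailing batch dimensions. A plain-Julia sketch of that contract (illustrative only, not part of the patch; `stack` needs Julia >= 1.9 and the sizes are made up):

```julia
using LinearAlgebra

# Each trailing slice A[:, :, k] is factorized and solved independently;
# the traced implementation fuses this into one batched kernel.
A = randn(4, 4, 3)
B = randn(4, 2, 3)
for k in 1:3
    A[:, :, k] = A[:, :, k]' * A[:, :, k] + 4I   # make every slice SPD
end
expected = stack((cholesky(Symmetric(A[:, :, k])) \ B[:, :, k] for k in 1:3); dims=3)
```
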
diff --git a/src/stdlibs/factorization/Factorization.jl b/src/stdlibs/factorization/Factorization.jl
index 5114a47d7b..bd51b65828 100644
--- a/src/stdlibs/factorization/Factorization.jl
+++ b/src/stdlibs/factorization/Factorization.jl
@@ -1,25 +1,25 @@
-# Supports batched factorization
-abstract type GeneralizedFactorization{T} <: Factorization{T} end
+abstract type BatchedFactorization{T} <: Factorization{T} end
 
-function LinearAlgebra.TransposeFactorization(f::GeneralizedFactorization)
+function LinearAlgebra.TransposeFactorization(f::BatchedFactorization)
     return LinearAlgebra.TransposeFactorization{eltype(f),typeof(f)}(f)
 end
-function LinearAlgebra.AdjointFactorization(f::GeneralizedFactorization)
+function LinearAlgebra.AdjointFactorization(f::BatchedFactorization)
     return LinearAlgebra.AdjointFactorization{eltype(f),typeof(f)}(f)
 end
 
 const GeneralizedTransposeFactorization{T} =
-    LinearAlgebra.TransposeFactorization{T,<:GeneralizedFactorization{T}} where {T}
+    LinearAlgebra.TransposeFactorization{T,<:BatchedFactorization{T}} where {T}
 const GeneralizedAdjointFactorization{T} =
-    LinearAlgebra.AdjointFactorization{T,<:GeneralizedFactorization{T}} where {T}
+    LinearAlgebra.AdjointFactorization{T,<:BatchedFactorization{T}} where {T}
 
 include("Cholesky.jl")
 include("LU.jl")
+include("SVD.jl")
 
 # Overload \ to support batched factorization
 for FT in (
-    :GeneralizedFactorization,
+    :BatchedFactorization,
     :GeneralizedTransposeFactorization,
     :GeneralizedAdjointFactorization,
 )
@@ -32,7 +32,7 @@
     ) where {T<:Union{Float32,Float64}} = _overloaded_backslash(F, B)
 end
 
-function _overloaded_backslash(F::GeneralizedFactorization, B::AbstractArray)
+function _overloaded_backslash(F::BatchedFactorization, B::AbstractArray)
     return ldiv!(
         F, LinearAlgebra.copy_similar(B, typeof(oneunit(eltype(F)) \ oneunit(eltype(B))))
     )
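
With the `\` overloads above in place, a batched factorization of a traced array solves every trailing slice in one traced call. A hedged usage sketch (shapes are made up; `lu` has to be called inside the jitted function so the overlay dispatches it):

```julia
using LinearAlgebra, Reactant

solve_batched(A, B) = lu(A) \ B   # routed through _overloaded_backslash

A = randn(Float32, 4, 4, 3)
B = randn(Float32, 4, 2, 3)
A_ra = Reactant.to_rarray(A)
B_ra = Reactant.to_rarray(B)
sol = @jit solve_batched(A_ra, B_ra)  # solves each 4x4 slice against B[:, :, k]
```
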
diff --git a/src/stdlibs/factorization/LU.jl b/src/stdlibs/factorization/LU.jl
index 144e7b9dcd..439ed8ad33 100644
--- a/src/stdlibs/factorization/LU.jl
+++ b/src/stdlibs/factorization/LU.jl
@@ -1,22 +1,22 @@
-struct GeneralizedLU{T,S<:AbstractArray,P<:AbstractArray,I<:Union{AbstractArray,Number}} <:
-       GeneralizedFactorization{T}
+struct BatchedLU{T,S<:AbstractArray,P<:AbstractArray,I<:Union{AbstractArray,Number}} <:
+       BatchedFactorization{T}
     factors::S
     ipiv::P
     perm::P
     info::I
 end
 
-Base.size(lu::GeneralizedLU) = size(lu.factors)
-Base.size(lu::GeneralizedLU, i) = size(lu.factors, i)
-Base.ndims(lu::GeneralizedLU) = ndims(lu.factors)
-function Base.copy(lu::GeneralizedLU)
-    return GeneralizedLU(copy(lu.factors), copy(lu.ipiv), copy(lu.perm), copy(lu.info))
+Base.size(lu::BatchedLU) = size(lu.factors)
+Base.size(lu::BatchedLU, i) = size(lu.factors, i)
+Base.ndims(lu::BatchedLU) = ndims(lu.factors)
+function Base.copy(lu::BatchedLU)
+    return BatchedLU(copy(lu.factors), copy(lu.ipiv), copy(lu.perm), copy(lu.info))
 end
 
-function GeneralizedLU(factors::S, ipiv::P, perm::P, info::I) where {S,P,I}
+function BatchedLU(factors::S, ipiv::P, perm::P, info::I) where {S,P,I}
     @assert ndims(ipiv) == ndims(perm) == ndims(factors) - 1
     @assert ndims(info) == ndims(factors) - 2
-    return GeneralizedLU{eltype(factors),S,P,I}(factors, ipiv, perm, info)
+    return BatchedLU{eltype(factors),S,P,I}(factors, ipiv, perm, info)
 end
 
 function overloaded_lu(x::AbstractArray, args...; kwargs...)
@@ -37,11 +37,11 @@ function overloaded_lu(
     factors = @opcall transpose(factors, invperm(permdims))
     ipiv = @opcall transpose(ipiv, perm_perm)
     perm = @opcall transpose(perm, perm_perm)
-    return GeneralizedLU(factors, ipiv, perm, info)
+    return BatchedLU(factors, ipiv, perm, info)
 end
 
 function LinearAlgebra.ldiv!(
-    lu::GeneralizedLU{T,<:AbstractArray{T,N},P,I}, B::AbstractArray{T,M}
+    lu::BatchedLU{T,<:AbstractArray{T,N},P,I}, B::AbstractArray{T,M}
 ) where {T,P,I,N,M}
     @assert N == M + 1
     ldiv!(lu, reshape(B, size(B, 1), 1, size(B)[2:end]...))
@@ -49,14 +49,14 @@ function LinearAlgebra.ldiv!(
 end
 
 function LinearAlgebra.ldiv!(
-    lu::GeneralizedLU{T,<:AbstractArray{T,2},P,I}, B::AbstractArray{T,2}
+    lu::BatchedLU{T,<:AbstractArray{T,2},P,I}, B::AbstractArray{T,2}
 ) where {T,P,I}
     B .= _lu_solve_core(lu.factors, B, lu.perm)
     return B
 end
 
 function LinearAlgebra.ldiv!(
-    lu::GeneralizedLU{T,<:AbstractArray{T,N},P,I}, B::AbstractArray{T,N}
+    lu::BatchedLU{T,<:AbstractArray{T,N},P,I}, B::AbstractArray{T,N}
 ) where {T,P,I,N}
     batch_shape = size(lu.factors)[3:end]
     @assert batch_shape == size(B)[3:end]
@@ -83,7 +83,7 @@ function LinearAlgebra.ldiv!(
     return B
 end
 
-function LinearAlgebra.det(lu::GeneralizedLU{T,<:AbstractMatrix}) where {T}
+function LinearAlgebra.det(lu::BatchedLU{T,<:AbstractMatrix}) where {T}
     n = LinearAlgebra.checksquare(lu)
 
     # TODO: check for non-singular matrices
@@ -91,7 +91,7 @@
     return ifelse(isodd(sum(lu.ipiv[1:n] .!= (1:n))), -one(T), one(T)) * P
 end
 
-function LinearAlgebra.logabsdet(lu::GeneralizedLU{T,<:AbstractMatrix}) where {T}
+function LinearAlgebra.logabsdet(lu::BatchedLU{T,<:AbstractMatrix}) where {T}
     n = LinearAlgebra.checksquare(lu)
     Treal = real(T)
     # TODO: check for non-singular matrices
@@ -106,7 +106,7 @@ end
 
 for f_wrapper in (LinearAlgebra.TransposeFactorization, LinearAlgebra.AdjointFactorization),
     aType in (:AbstractVecOrMat, :AbstractArray)
 
-    @eval function LinearAlgebra.ldiv!(lu::$(f_wrapper){<:Any,<:GeneralizedLU}, B::$aType)
+    @eval function LinearAlgebra.ldiv!(lu::$(f_wrapper){<:Any,<:BatchedLU}, B::$aType)
         # TODO: implement this
         error("`$(f_wrapper)` is not supported yet for LU.")
         return nothing
@@ -116,7 +116,7 @@ end
 # currently we lower inverse to lu decomposition + triangular solve. we should
 # instead emit getri and lower that to a fallback if the backend doesn't support
 # it.
-function LinearAlgebra.inv!(lu::GeneralizedLU)
+function LinearAlgebra.inv!(lu::BatchedLU)
     @assert ndims(lu) == 2 "Only implemented for 2D tensors"
     rhs = Reactant.promote_to(
        TracedRArray{Reactant.unwrapped_eltype(eltype(lu)),2}, LinearAlgebra.I(size(lu, 1))
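
The `det`/`logabsdet` methods above read everything off the stored factors: the sign comes from the parity of the LAPACK-style pivot vector, the magnitude from the diagonal. The same identity with stock `LinearAlgebra` (illustrative; `F.factors`/`F.ipiv` are the standard `LU` fields):

```julia
using LinearAlgebra

A = randn(4, 4)
F = lu(A)
# one row swap per entry where ipiv[i] != i, so the sign is the parity of
# that count
s = isodd(sum(F.ipiv .!= 1:4)) ? -1.0 : 1.0
@assert det(A) ≈ s * prod(diag(F.factors))           # L has a unit diagonal
@assert logabsdet(A)[1] ≈ sum(log ∘ abs, diag(F.factors))
```
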
diff --git a/src/stdlibs/factorization/SVD.jl b/src/stdlibs/factorization/SVD.jl
new file mode 100644
index 0000000000..cf5e57cbef
--- /dev/null
+++ b/src/stdlibs/factorization/SVD.jl
@@ -0,0 +1,10 @@
+struct BatchedSVD{T,Tr,M<:AbstractArray,C<:AbstractArray} <: Factorization{T}
+    U::M
+    S::C
+    Vt::M
+
+    function BatchedSVD{T,Tr,M,C}(U::M, S::C, Vt::M) where {T,Tr,M,C}
+        @assert ndims(S) == ndims(U) - 1
+        return new{T,Tr,M,C}(U, S, Vt)
+    end
+end

From 3d1caadb52cf8579d68df8cc6fcc2dfb466bb7ed Mon Sep 17 00:00:00 2001
From: Avik Pal
Date: Wed, 19 Nov 2025 11:24:46 -0500
Subject: [PATCH 2/7] feat: svd/svdvals lowering

---
 src/Ops.jl                       |  5 +-
 src/stdlibs/LinearAlgebra.jl     |  7 +--
 src/stdlibs/factorization/SVD.jl | 85 ++++++++++++++++++++++++++++++++
 3 files changed, 92 insertions(+), 5 deletions(-)

diff --git a/src/Ops.jl b/src/Ops.jl
index ed162501ba..3c6518895f 100644
--- a/src/Ops.jl
+++ b/src/Ops.jl
@@ -3325,9 +3325,10 @@ end
     m, n = size(x)[(end - 1):end]
     r = min(m, n)
 
-    U_size = (batch_sizes..., m, full ? m : r)
+    # note that sizes are transposed here
+    U_size = (batch_sizes..., full ? m : r, m)
     S_size = (batch_sizes..., r)
-    Vt_size = (batch_sizes..., full ? n : r, n)
+    Vt_size = (batch_sizes..., n, full ? n : r)
     info_size = batch_sizes
 
     if algorithm == "DEFAULT"
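
The result-buffer sizes being set up here should ultimately agree with the stock thin/full SVD shapes; a quick reference check with plain `LinearAlgebra` (illustrative, sizes made up):

```julia
using LinearAlgebra

A = randn(Float32, 4, 8)   # m = 4, n = 8, r = min(m, n) = 4
thin = svd(A)
fullF = svd(A; full=true)
@assert size(thin.U) == (4, 4) && size(thin.S) == (4,) && size(thin.Vt) == (4, 8)
@assert size(fullF.U) == (4, 4) && size(fullF.Vt) == (8, 8)
```
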
diff --git a/src/stdlibs/LinearAlgebra.jl b/src/stdlibs/LinearAlgebra.jl
index 9bfc2ba2ae..26dff76146 100644
--- a/src/stdlibs/LinearAlgebra.jl
+++ b/src/stdlibs/LinearAlgebra.jl
@@ -4,7 +4,7 @@ using ..MLIR: MLIR
 using ..Reactant: Reactant, Ops
 using ..Reactant:
     TracedRArray, TracedRNumber, AnyTracedRArray, AnyTracedRMatrix, AnyTracedRVector
-using ..Reactant: call_with_reactant
+using ..Reactant: call_with_reactant, unwrapped_eltype, promote_to
 using ReactantCore: ReactantCore, materialize_traced_array, @trace
 using Reactant_jll: Reactant_jll
 
@@ -15,8 +15,9 @@ using LinearAlgebra: LinearAlgebra, BLAS
 using LinearAlgebra: Adjoint, Transpose, Factorization, RowMaximum, NoPivot
 using LinearAlgebra: SymTridiagonal, Symmetric, Bidiagonal, Diagonal, Tridiagonal
 using LinearAlgebra: LowerTriangular, UnitLowerTriangular, UpperTriangular
-using LinearAlgebra:
-    diag, diagm, ldiv!, det, logabsdet, lu, istriu, istril, triu!, tril!, inv!, rmul!
+using LinearAlgebra: I, diag, diagm, ldiv!, det, logabsdet, istriu, istril, triu!, tril!
+using LinearAlgebra: inv!, rmul!, normalize
+using LinearAlgebra: svd, lu
 using Libdl: Libdl
 using GPUArraysCore: @allowscalar
 
diff --git a/src/stdlibs/factorization/SVD.jl b/src/stdlibs/factorization/SVD.jl
index cf5e57cbef..8875271004 100644
--- a/src/stdlibs/factorization/SVD.jl
+++ b/src/stdlibs/factorization/SVD.jl
@@ -8,3 +8,88 @@ struct BatchedSVD{T,Tr,M<:AbstractArray,C<:AbstractArray} <: Factorization{T}
         return new{T,Tr,M,C}(U, S, Vt)
     end
 end
+
+function BatchedSVD(U::M, S::C, Vt::M) where {M,C}
+    @assert ndims(S) == ndims(U) - 1
+    return BatchedSVD{eltype(U),eltype(S),M,C}(U, S, Vt)
+end
+
+struct DefaultEnzymeXLASVDAlgorithm <: LinearAlgebra.Algorithm end
+struct JacobiAlgorithm <: LinearAlgebra.Algorithm end
+
+_jlalg_to_enzymexla_alg(::Nothing) = "DEFAULT"
+_jlalg_to_enzymexla_alg(alg::DefaultEnzymeXLASVDAlgorithm) = "DEFAULT"
+_jlalg_to_enzymexla_alg(alg::LinearAlgebra.DivideAndConquer) = "DivideAndConquer"
+_jlalg_to_enzymexla_alg(alg::LinearAlgebra.QRIteration) = "QRIteration"
+_jlalg_to_enzymexla_alg(alg::JacobiAlgorithm) = "Jacobi"
+_jlalg_to_enzymexla_alg(alg::String) = alg
+_jlalg_to_enzymexla_alg(alg::Symbol) = _jlalg_to_enzymexla_alg(string(alg))
+_jlalg_to_enzymexla_alg(alg) = error("Unsupported SVD algorithm: $alg")
+
+# default relies on the backend to select the best algorithm
+LinearAlgebra.default_svd_alg(::AnyTracedRArray) = DefaultEnzymeXLASVDAlgorithm()
+
+function overloaded_svd(A::AbstractArray; kwargs...)
+    return overloaded_svd(Reactant.promote_to(TracedRArray, A); kwargs...)
+end
+
+function overloaded_svd(
+    A::AnyTracedRArray{T,N}; full::Bool=false, algorithm=LinearAlgebra.default_svd_alg(A)
+) where {T,N}
+    U, S, Vt = @opcall svd(A; full, algorithm=_jlalg_to_enzymexla_alg(algorithm))
+    return BatchedSVD(U, S, Vt)
+end
+
+struct __InnerVectorSVDDispatch{A} <: Function
+    full::Bool
+    algorithm::A
+end
+
+struct __ZeroNormVectorSVDDispatch{A} <: Function
+    full::Bool
+    algorithm::A
+end
+
+function overloaded_svd(
+    A::AnyTracedRVector; full::Bool=false, algorithm=LinearAlgebra.default_svd_alg(A)
+)
+    normA = Reactant.call_with_reactant(LinearAlgebra.norm, A)
+    U, S, Vt = ReactantCore.traced_if(
+        iszero(normA),
+        __ZeroNormVectorSVDDispatch(full, algorithm),
+        __InnerVectorSVDDispatch(full, algorithm),
+        (A, normA),
+    )
+    return BatchedSVD(U, S, Vt)
+end
+
+function (fn::__ZeroNormVectorSVDDispatch)(A::AbstractVector{T}, normA) where {T}
+    U = promote_to(
+        TracedRArray, Matrix{unwrapped_eltype(A)}(I, length(A), fn.full ? length(A) : 1)
+    )
+    return U, fill(normA, 1), ones(T, 1, 1)
+end
+
+function (fn::__InnerVectorSVDDispatch)(A::AbstractVector{T}, normA) where {T}
+    if !fn.full
+        normalizedA = normalize(A)
+        U = materialize_traced_array(reshape(normalizedA, length(A), 1))
+        return U, fill(normA, 1), ones(T, 1, 1)
+    end
+    (; U, S, Vt) = overloaded_svd(reshape(A, :, 1); full=true, algorithm=fn.algorithm)
+    return U, S, Vt
+end
+
+# TODO: not yet performant. See https://github.com/EnzymeAD/Enzyme-JAX/issues/1623
+function LinearAlgebra.svdvals(x::AnyTracedRArray{T,N}; kwargs...) where {T,N}
+    return overloaded_svd(x; kwargs..., full=false).S
+end
+function LinearAlgebra.svdvals!(x::AnyTracedRArray{T,N}; kwargs...) where {T,N}
+    return overloaded_svd(x; kwargs..., full=false).S
+end
+function LinearAlgebra.svdvals(x::AnyTracedRVector{T}; kwargs...) where {T}
+    return overloaded_svd(x; kwargs..., full=false).S
+end
+function LinearAlgebra.svdvals!(x::AnyTracedRVector{T}; kwargs...) where {T}
+    return overloaded_svd(x; kwargs..., full=false).S
+end
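
A usage sketch for the new entry points, mirroring the style of the integration tests (the input is made up; with the `DEFAULT` algorithm the backend picks the kernel):

```julia
using LinearAlgebra, Reactant

A = rand(Float32, 4, 4)
A_ra = Reactant.to_rarray(A)

vals_ra = @jit svdvals(A_ra)      # routed through overloaded_svd(...).S
@assert Array(vals_ra) ≈ svdvals(A)
```
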
From dc2cfd01089f94b35df4afd5a534b3278b9cdb76 Mon Sep 17 00:00:00 2001
From: Avik Pal
Date: Wed, 19 Nov 2025 13:28:43 -0500
Subject: [PATCH 3/7] test: svdvals + batching fix

---
 src/stdlibs/factorization/SVD.jl   | 17 +++++++++++--
 test/integration/linear_algebra.jl | 34 ++++++++++++++++++++++++++++++
 2 files changed, 49 insertions(+), 2 deletions(-)

diff --git a/src/stdlibs/factorization/SVD.jl b/src/stdlibs/factorization/SVD.jl
index 8875271004..e30b7eb933 100644
--- a/src/stdlibs/factorization/SVD.jl
+++ b/src/stdlibs/factorization/SVD.jl
@@ -36,7 +36,21 @@ end
 function overloaded_svd(
     A::AnyTracedRArray{T,N}; full::Bool=false, algorithm=LinearAlgebra.default_svd_alg(A)
 ) where {T,N}
-    U, S, Vt = @opcall svd(A; full, algorithm=_jlalg_to_enzymexla_alg(algorithm))
+    # Batching here is in the last dimensions. `Ops.svd` expects the last dimensions
+    permdims = vcat(collect(Int64, 3:N), 1, 2)
+    A = @opcall transpose(materialize_traced_array(A), permdims)
+
+    U, S, Vt = @opcall svd(
+        A; full, algorithm=_jlalg_to_enzymexla_alg(algorithm)
+    )
+
+    # Permute back to the original dimensions
+    S_perm = vcat(N - 1, collect(Int64, 1:(N - 2)))
+
+    U = @opcall transpose(U, invperm(permdims))
+    S = @opcall transpose(S, S_perm)
+    Vt = @opcall transpose(Vt, invperm(permdims))
+
     return BatchedSVD(U, S, Vt)
 end
 
@@ -80,7 +94,6 @@ function (fn::__InnerVectorSVDDispatch)(A::AbstractVector{T}, normA) where {T}
     return U, S, Vt
 end
 
-# TODO: not yet performant. See https://github.com/EnzymeAD/Enzyme-JAX/issues/1623
 function LinearAlgebra.svdvals(x::AnyTracedRArray{T,N}; kwargs...) where {T,N}
     return overloaded_svd(x; kwargs..., full=false).S
 end
diff --git a/test/integration/linear_algebra.jl b/test/integration/linear_algebra.jl
index cbf2cb02af..cd37a0d70c 100644
--- a/test/integration/linear_algebra.jl
+++ b/test/integration/linear_algebra.jl
@@ -511,6 +511,40 @@ end
     end
 end
 
+function get_svd_algorithms(backend::String)
+    algorithms = ["DEFAULT"]
+    if occursin("cpu", backend)
+        append!(algorithms, ["QRIteration", "DivideAndConquer"])
+    elseif occursin("cuda", backend)
+        append!(algorithms, ["QRIteration", "Jacobi"])
+    elseif occursin("tpu", backend)
+        append!(algorithms, ["Jacobi"])
+    end
+    return algorithms
+end
+
+@testset "svd factorization" begin end
+
+@testset "svdvals" begin
+    algs = get_svd_algorithms(string(Reactant.devices()[1]))
+
+    @testset "Un-batched: $(alg)" for alg in algs
+        A = Reactant.TestUtils.construct_test_array(Float32, 4, 4)
+        _svdvals = svdvals(A)
+        A_ra = Reactant.to_rarray(A)
+        _svdvals_ra = @jit svdvals(A_ra; algorithm=alg)
+        @test _svdvals_ra ≈ _svdvals
+    end
+
+    @testset "Batched: $(alg)" for alg in algs
+        A = Reactant.TestUtils.construct_test_array(Float32, 4, 4, 3, 2)
+        _svdvals = reshape(mapslices(svdvals, A; dims=(1, 2)), 4, 3, 2)
+        A_ra = Reactant.to_rarray(A)
+        _svdvals_ra = @jit svdvals(A_ra; algorithm=alg)
+        @test _svdvals_ra ≈ _svdvals
+    end
+end
+
 @testset "structure check" begin
     @testset "istriu" begin
         x = Reactant.TestUtils.construct_test_array(Float32, 8, 8)
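
The transpose bookkeeping added above can be sanity-checked in plain Julia, with `permutedims` standing in for `@opcall transpose` (shapes made up; the lowering wants batch dimensions leading while Julia keeps them trailing):

```julia
A = randn(4, 4, 3, 2)                             # two trailing batch dims
N = ndims(A)
permdims = vcat(collect(Int64, 3:N), 1, 2)        # [3, 4, 1, 2]
@assert size(permutedims(A, permdims)) == (3, 2, 4, 4)

# singular values come back batch-leading, e.g. (3, 2, 4); S_perm restores
# the Julia layout (4, 3, 2)
S_perm = vcat(N - 1, collect(Int64, 1:(N - 2)))   # [3, 1, 2]
S = zeros(3, 2, 4)
@assert size(permutedims(S, S_perm)) == (4, 3, 2)
```
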
From 1cc9644cf13f0e0a22d27e3d74e3473689cf1354 Mon Sep 17 00:00:00 2001
From: Avik Pal
Date: Wed, 19 Nov 2025 13:46:23 -0500
Subject: [PATCH 4/7] chore: run fmt

---
 src/stdlibs/factorization/SVD.jl | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/src/stdlibs/factorization/SVD.jl b/src/stdlibs/factorization/SVD.jl
index e30b7eb933..04fa8e3709 100644
--- a/src/stdlibs/factorization/SVD.jl
+++ b/src/stdlibs/factorization/SVD.jl
@@ -40,9 +40,7 @@ function overloaded_svd(
     permdims = vcat(collect(Int64, 3:N), 1, 2)
     A = @opcall transpose(materialize_traced_array(A), permdims)
 
-    U, S, Vt = @opcall svd(
-        A; full, algorithm=_jlalg_to_enzymexla_alg(algorithm)
-    )
+    U, S, Vt = @opcall svd(A; full, algorithm=_jlalg_to_enzymexla_alg(algorithm))
 
     # Permute back to the original dimensions
     S_perm = vcat(N - 1, collect(Int64, 1:(N - 2)))

From 5262596dcbc3a6dc95cb59b1f08ca4cde4e2d8ff Mon Sep 17 00:00:00 2001
From: Avik Pal
Date: Wed, 19 Nov 2025 16:32:39 -0500
Subject: [PATCH 5/7] feat: more coverage

---
 src/Ops.jl                                 |  5 +-
 src/stdlibs/factorization/Cholesky.jl      |  1 +
 src/stdlibs/factorization/Factorization.jl | 45 ++++++++-----
 src/stdlibs/factorization/LU.jl            |  2 +-
 src/stdlibs/factorization/SVD.jl           | 76 ++++++++++++++++++++--
 test/integration/linear_algebra.jl         | 42 +++++++++++-
 6 files changed, 142 insertions(+), 29 deletions(-)

diff --git a/src/Ops.jl b/src/Ops.jl
index 3c6518895f..ed162501ba 100644
--- a/src/Ops.jl
+++ b/src/Ops.jl
@@ -3325,10 +3325,9 @@ end
     m, n = size(x)[(end - 1):end]
     r = min(m, n)
 
-    # note that sizes are transposed here
-    U_size = (batch_sizes..., full ? m : r, m)
+    U_size = (batch_sizes..., m, full ? m : r)
     S_size = (batch_sizes..., r)
-    Vt_size = (batch_sizes..., n, full ? n : r)
+    Vt_size = (batch_sizes..., full ? n : r, n)
     info_size = batch_sizes
 
     if algorithm == "DEFAULT"
diff --git a/src/stdlibs/factorization/Cholesky.jl b/src/stdlibs/factorization/Cholesky.jl
index efe9188903..80bda2a5dc 100644
--- a/src/stdlibs/factorization/Cholesky.jl
+++ b/src/stdlibs/factorization/Cholesky.jl
@@ -11,6 +11,7 @@ function BatchedCholesky(factors::S, uplo::Char, info::I) where {S,I}
 end
 
 Base.size(c::BatchedCholesky) = size(c.factors)
+Base.size(c::BatchedCholesky, i::Integer) = size(c.factors, i)
 Base.ndims(c::BatchedCholesky) = ndims(c.factors)
 
 function overloaded_cholesky(A::AbstractArray, ::NoPivot; check::Bool=false)
diff --git a/src/stdlibs/factorization/Factorization.jl b/src/stdlibs/factorization/Factorization.jl
index bd51b65828..5782f35cf5 100644
--- a/src/stdlibs/factorization/Factorization.jl
+++ b/src/stdlibs/factorization/Factorization.jl
@@ -8,9 +8,9 @@ function LinearAlgebra.AdjointFactorization(f::BatchedFactorization)
     return LinearAlgebra.AdjointFactorization{eltype(f),typeof(f)}(f)
 end
 
-const GeneralizedTransposeFactorization{T} =
+const BatchedTransposeFactorization{T} =
     LinearAlgebra.TransposeFactorization{T,<:BatchedFactorization{T}} where {T}
-const GeneralizedAdjointFactorization{T} =
+const BatchedAdjointFactorization{T} =
     LinearAlgebra.AdjointFactorization{T,<:BatchedFactorization{T}} where {T}
 
 include("Cholesky.jl")
@@ -18,11 +18,8 @@ include("LU.jl")
 include("SVD.jl")
 
 # Overload \ to support batched factorization
-for FT in (
-    :BatchedFactorization,
-    :GeneralizedTransposeFactorization,
-    :GeneralizedAdjointFactorization,
-)
+for FT in
+    (:BatchedFactorization, :BatchedTransposeFactorization, :BatchedAdjointFactorization)
     for aType in (:AbstractVecOrMat, :AbstractArray)
         @eval Base.:(\)(F::$FT, B::$aType) = _overloaded_backslash(F, B)
     end
@@ -32,18 +29,36 @@
     ) where {T<:Union{Float32,Float64}} = _overloaded_backslash(F, B)
 end
 
+function __get_B(F::Factorization, B::AbstractArray)
+    m, n = size(F, 1), size(F, 2)
+    if m != size(B, 1)
+        throw(DimensionMismatch("arguments must have the same number of rows"))
+    end
+
+    TFB = typeof(oneunit(eltype(F)) \ oneunit(eltype(B)))
+
+    BB = similar(B, TFB, max(size(B, 1), n), size(B)[2:end]...)
+    if n > size(B, 1)
+        BB[1:m, ntuple(Returns(Colon()), ndims(B) - 1)...] = B
+    else
+        copyto!(BB, B)
+    end
+
+    return BB
+end
+
 function _overloaded_backslash(F::BatchedFactorization, B::AbstractArray)
-    return ldiv!(
-        F, LinearAlgebra.copy_similar(B, typeof(oneunit(eltype(F)) \ oneunit(eltype(B))))
-    )
+    BB = __get_B(F, B)
+    ldiv!(F, BB)
+    return BB[1:size(F, 2), ntuple(Returns(Colon()), ndims(B) - 1)...]
 end
 
-function _overloaded_backslash(F::GeneralizedTransposeFactorization, B::AbstractArray)
+function _overloaded_backslash(F::BatchedTransposeFactorization, B::AbstractArray)
     return conj!(adjoint(F.parent) \ conj.(B))
 end
 
-function _overloaded_backslash(F::GeneralizedAdjointFactorization, B::AbstractArray)
-    return ldiv!(
-        F, LinearAlgebra.copy_similar(B, typeof(oneunit(eltype(F)) \ oneunit(eltype(B))))
-    )
+function _overloaded_backslash(F::BatchedAdjointFactorization, B::AbstractArray)
+    BB = __get_B(F, B)
+    ldiv!(F, BB)
+    return BB[1:size(F)[2], ntuple(Returns(Colon()), ndims(B) - 1)...]
 end
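
`__get_B` exists because an underdetermined solve returns more rows than the input RHS has; the buffer is padded up front so `ldiv!` can work in place. The shape logic in isolation (illustrative, plain Julia):

```julia
m, n = 4, 8                       # wide system: the solution has n = 8 rows
B = randn(Float32, m, 5)
BB = zeros(Float32, max(m, n), size(B, 2))
BB[1:m, :] = B                    # the n > m branch: RHS goes in the top block
@assert size(BB) == (8, 5)        # `_overloaded_backslash` later returns BB[1:n, :]
```
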
diff --git a/src/stdlibs/factorization/LU.jl b/src/stdlibs/factorization/LU.jl
index 439ed8ad33..79eb206e4d 100644
--- a/src/stdlibs/factorization/LU.jl
+++ b/src/stdlibs/factorization/LU.jl
@@ -7,7 +7,7 @@ struct BatchedLU{T,S<:AbstractArray,P<:AbstractArray,I<:Union{AbstractArray,Numb
 end
 
 Base.size(lu::BatchedLU) = size(lu.factors)
-Base.size(lu::BatchedLU, i) = size(lu.factors, i)
+Base.size(lu::BatchedLU, i::Integer) = size(lu.factors, i)
 Base.ndims(lu::BatchedLU) = ndims(lu.factors)
 function Base.copy(lu::BatchedLU)
     return BatchedLU(copy(lu.factors), copy(lu.ipiv), copy(lu.perm), copy(lu.info))
diff --git a/src/stdlibs/factorization/SVD.jl b/src/stdlibs/factorization/SVD.jl
index 04fa8e3709..d507a3feeb 100644
--- a/src/stdlibs/factorization/SVD.jl
+++ b/src/stdlibs/factorization/SVD.jl
@@ -1,4 +1,4 @@
-struct BatchedSVD{T,Tr,M<:AbstractArray,C<:AbstractArray} <: Factorization{T}
+struct BatchedSVD{T,Tr,M<:AbstractArray,C<:AbstractArray} <: BatchedFactorization{T}
     U::M
     S::C
     Vt::M
@@ -9,6 +9,11 @@ struct BatchedSVD{T,Tr,M<:AbstractArray,C<:AbstractArray} <: Factorization{T}
     end
 end
 
+function Base.size(svd::BatchedSVD)
+    return (size(svd.U, 1), size(svd.Vt, 2), size(svd.U)[3:end]...)
+end
+Base.size(svd::BatchedSVD, i::Integer) = i == 2 ? size(svd.Vt, 2) : size(svd.U, i)
+
 function BatchedSVD(U::M, S::C, Vt::M) where {M,C}
     @assert ndims(S) == ndims(U) - 1
     return BatchedSVD{eltype(U),eltype(S),M,C}(U, S, Vt)
@@ -34,13 +39,13 @@ function overloaded_svd(A::AbstractArray; kwargs...)
 end
 
 function overloaded_svd(
-    A::AnyTracedRArray{T,N}; full::Bool=false, algorithm=LinearAlgebra.default_svd_alg(A)
+    A::AnyTracedRArray{T,N}; full::Bool=false, alg=LinearAlgebra.default_svd_alg(A)
 ) where {T,N}
     # Batching here is in the last dimensions. `Ops.svd` expects the last dimensions
     permdims = vcat(collect(Int64, 3:N), 1, 2)
     A = @opcall transpose(materialize_traced_array(A), permdims)
 
-    U, S, Vt = @opcall svd(A; full, algorithm=_jlalg_to_enzymexla_alg(algorithm))
+    U, S, Vt = @opcall svd(A; full, algorithm=_jlalg_to_enzymexla_alg(alg))
 
     # Permute back to the original dimensions
     S_perm = vcat(N - 1, collect(Int64, 1:(N - 2)))
@@ -63,13 +68,13 @@ struct __ZeroNormVectorSVDDispatch{A} <: Function
 end
 
 function overloaded_svd(
-    A::AnyTracedRVector; full::Bool=false, algorithm=LinearAlgebra.default_svd_alg(A)
+    A::AnyTracedRVector; full::Bool=false, alg=LinearAlgebra.default_svd_alg(A)
 )
     normA = Reactant.call_with_reactant(LinearAlgebra.norm, A)
     U, S, Vt = ReactantCore.traced_if(
         iszero(normA),
-        __ZeroNormVectorSVDDispatch(full, algorithm),
-        __InnerVectorSVDDispatch(full, algorithm),
+        __ZeroNormVectorSVDDispatch(full, alg),
+        __InnerVectorSVDDispatch(full, alg),
         (A, normA),
     )
     return BatchedSVD(U, S, Vt)
@@ -88,7 +93,7 @@ function (fn::__InnerVectorSVDDispatch)(A::AbstractVector{T}, normA) where {T}
         U = materialize_traced_array(reshape(normalizedA, length(A), 1))
         return U, fill(normA, 1), ones(T, 1, 1)
     end
-    (; U, S, Vt) = overloaded_svd(reshape(A, :, 1); full=true, algorithm=fn.algorithm)
+    (; U, S, Vt) = overloaded_svd(reshape(A, :, 1); full=true, alg=fn.algorithm)
     return U, S, Vt
 end
 
@@ -104,3 +109,60 @@ end
 function LinearAlgebra.svdvals!(x::AnyTracedRVector{T}; kwargs...) where {T}
     return overloaded_svd(x; kwargs..., full=false).S
 end
+
+# Ideally we want to slice based on near zero singular values, but this will
+# produce dynamically sized slices. Instead we zero out slices and proceed
+function _svd_solve_core(
+    U::AbstractMatrix, S::AbstractVector{Tr}, Vt::AbstractMatrix, B::AbstractMatrix
+) where {Tr}
+    mask = S .> eps(real(Tr)) * @allowscalar(S[1])
+    m, n = size(U, 1), size(Vt, 2)
+    rhs = S .\ (U' * LinearAlgebra._cut_B(B, 1:m))
+    rhs = ifelse.(mask, rhs, zero(eltype(rhs)))
+    return (Vt[1:length(S), :])' * rhs
+end
+
+function LinearAlgebra.ldiv!(
+    svd::BatchedSVD{T,Tr,<:AbstractArray{T,N}}, B::AbstractArray{T,M}
+) where {T,Tr,N,M}
+    @assert N == M + 1
+    ldiv!(svd, reshape(B, size(B, 1), 1, size(B)[2:end]...))
+    return B
+end
+
+function LinearAlgebra.ldiv!(
+    svd::BatchedSVD{T,Tr,<:AbstractArray{T,2}}, B::AbstractArray{T,2}
+) where {T,Tr}
+    n = size(svd, 2)
+    sol = _svd_solve_core(svd.U, svd.S, svd.Vt, B)
+    B[1:n, :] .= sol
+    return B
+end
+
+function LinearAlgebra.ldiv!(
+    svd::BatchedSVD{T,Tr,<:AbstractArray{T,N}}, B::AbstractArray{T,N}
+) where {T,Tr,N}
+    batch_shape = size(svd.U)[3:end]
+    @assert batch_shape == size(B)[3:end]
+
+    n = size(svd, 2)
+    permutation = vcat(collect(Int64, 3:N), 1, 2)
+    S_perm = vcat(collect(Int64, 2:(N - 1)), 1)
+
+    U = @opcall transpose(materialize_traced_array(svd.U), permutation)
+    S = @opcall transpose(materialize_traced_array(svd.S), S_perm)
+    Vt = @opcall transpose(materialize_traced_array(svd.Vt), permutation)
+
+    B_permuted = @opcall transpose(materialize_traced_array(B), permutation)
+
+    res = @opcall transpose(
+        only(
+            @opcall(
+                batch(_svd_solve_core, [U, S, Vt, B_permuted], collect(Int64, batch_shape))
+            ),
+        ),
+        invperm(permutation),
+    )
+    B[1:n, :, ntuple(Returns(Colon()), length(batch_shape))...] .= res
+    return B
+end
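
`_svd_solve_core` is a pseudoinverse apply with a relative cutoff rather than a dynamic rank-based slice. The same computation in plain Julia (illustrative; the final assert holds because a random 4×8 system is consistent almost surely):

```julia
using LinearAlgebra

A = randn(4, 8)
b = randn(4)
F = svd(A)                                 # thin SVD: U 4×4, S 4, Vt 4×8
mask = F.S .> eps(eltype(F.S)) * F.S[1]    # zero out near-singular directions
rhs = F.S .\ (F.U' * b)
rhs = ifelse.(mask, rhs, zero(eltype(rhs)))
x = F.Vt' * rhs                            # minimum-norm least-squares solution
@assert A * x ≈ b
```
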
diff --git a/test/integration/linear_algebra.jl b/test/integration/linear_algebra.jl
index cd37a0d70c..e5d29722b4 100644
--- a/test/integration/linear_algebra.jl
+++ b/test/integration/linear_algebra.jl
@@ -523,7 +523,43 @@ function get_svd_algorithms(backend::String)
     return algorithms
 end
 
-@testset "svd factorization" begin end
+least_squares_with_svd(A, b, full, alg) = svd(A; full, alg) \ b
+
+@testset "svd factorization" begin
+    algs = get_svd_algorithms(string(Reactant.devices()[1]))
+
+    A = Reactant.TestUtils.construct_test_array(Float32, 4, 8)
+    tmp = rand(Float32, 8, 5)
+    B = A * tmp
+    b = B[:, 1]
+
+    A_ra = Reactant.to_rarray(A)
+    B_ra = Reactant.to_rarray(B)
+    b_ra = Reactant.to_rarray(b)
+
+    # test least squares error
+    @testset "least squares error: $(alg) | full=$(full)" for alg in algs,
+        full in (true, false)
+        # FIXME: svd is mutating the input buffers
+
+        sol1 = @jit least_squares_with_svd(A_ra, b_ra, full, alg)
+        err1 = maximum(abs, A * Array(sol1) .- b)
+        @show err1
+
+        sol2 = @jit least_squares_with_svd(A_ra, B_ra, full, alg)
+        err2 = maximum(abs, A * Array(sol2) .- B)
+        @show err2
+    end
+
+    @testset "[batched] least squares error: $(alg) | full=$(full)" for alg in algs,
+        full in (true, false)
+
+        # TODO
+    end
+
+    # TODO: when A is a vector
+end
 
 @testset "svdvals" begin
     algs = get_svd_algorithms(string(Reactant.devices()[1]))
@@ -532,7 +568,7 @@ end
     @testset "Un-batched: $(alg)" for alg in algs
         A = Reactant.TestUtils.construct_test_array(Float32, 4, 4)
         _svdvals = svdvals(A)
         A_ra = Reactant.to_rarray(A)
-        _svdvals_ra = @jit svdvals(A_ra; algorithm=alg)
+        _svdvals_ra = @jit svdvals(A_ra; alg=alg)
         @test _svdvals_ra ≈ _svdvals
     end
 
@@ -540,7 +576,7 @@ end
     @testset "Batched: $(alg)" for alg in algs
         A = Reactant.TestUtils.construct_test_array(Float32, 4, 4, 3, 2)
         _svdvals = reshape(mapslices(svdvals, A; dims=(1, 2)), 4, 3, 2)
         A_ra = Reactant.to_rarray(A)
-        _svdvals_ra = @jit svdvals(A_ra; algorithm=alg)
+        _svdvals_ra = @jit svdvals(A_ra; alg=alg)
         @test _svdvals_ra ≈ _svdvals
     end
 end
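
The property these tests pin down, restated as a standalone check (hedged sketch; tolerance and shapes follow the tests above):

```julia
using LinearAlgebra, Reactant

ls(A, B) = svd(A) \ B

A = rand(Float32, 4, 8)
B = A * rand(Float32, 8, 5)        # guaranteed to be in range(A)
A_ra = Reactant.to_rarray(A)
B_ra = Reactant.to_rarray(B)
sol = @jit ls(A_ra, B_ra)
@assert maximum(abs, A * Array(sol) .- B) < 1e-3
```
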
From d795085aa889a1c6d8e0a662e26e7892eb8a318c Mon Sep 17 00:00:00 2001
From: Avik Pal
Date: Wed, 19 Nov 2025 18:13:41 -0600
Subject: [PATCH 6/7] test: use updated commit

---
 test/integration/linear_algebra.jl | 51 ++++++++++++++++++++++++------
 1 file changed, 41 insertions(+), 10 deletions(-)

diff --git a/test/integration/linear_algebra.jl b/test/integration/linear_algebra.jl
index e5d29722b4..d9db88932a 100644
--- a/test/integration/linear_algebra.jl
+++ b/test/integration/linear_algebra.jl
@@ -511,12 +511,20 @@ end
     end
 end
 
-function get_svd_algorithms(backend::String)
+using LinearAlgebra, Reactant
+Reactant.set_default_backend("cpu")
+
+function get_svd_algorithms(backend::String, size=nothing)
+    backend = lowercase(backend)
     algorithms = ["DEFAULT"]
     if occursin("cpu", backend)
         append!(algorithms, ["QRIteration", "DivideAndConquer"])
     elseif occursin("cuda", backend)
-        append!(algorithms, ["QRIteration", "Jacobi"])
+        if size === nothing || size[1] ≥ size[2]
+            append!(algorithms, ["QRIteration", "Jacobi"])
+        else
+            append!(algorithms, ["Jacobi"])
+        end
     elseif occursin("tpu", backend)
         append!(algorithms, ["Jacobi"])
     end
@@ -525,10 +533,19 @@ end
 
 least_squares_with_svd(A, b, full, alg) = svd(A; full, alg) \ b
 
+function compute_ls_solution_error(A, sol, b, bsize)
+    b = reshape(b, size(b, 1), bsize, :)
+    A = reshape(A, size(A, 1), size(A, 2), :)
+    sol = reshape(sol, size(sol, 1), bsize, :)
+    mul = stack((A[:, :, i] * sol[:, :, i] for i in axes(A, 3)); dims=3)
+    return maximum(abs, mul .- b)
+end
+
 @testset "svd factorization" begin
-    algs = get_svd_algorithms(string(Reactant.devices()[1]))
-
     A = Reactant.TestUtils.construct_test_array(Float32, 4, 8)
+
+    algs = get_svd_algorithms(string(Reactant.devices()[1]), size(A))
+
     tmp = rand(Float32, 8, 5)
     B = A * tmp
     b = B[:, 1]
@@ -540,25 +557,39 @@ least_squares_with_svd(A, b, full, alg) = svd(A; full, alg) \ b
 
     A_ra = Reactant.to_rarray(A)
     B_ra = Reactant.to_rarray(B)
    b_ra = Reactant.to_rarray(b)
 
     # test least squares error
     @testset "least squares error: $(alg) | full=$(full)" for alg in algs,
         full in (true, false)
-        # FIXME: svd is mutating the input buffers
 
         sol1 = @jit least_squares_with_svd(A_ra, b_ra, full, alg)
         err1 = maximum(abs, A * Array(sol1) .- b)
-        @show err1
+        @test err1 < 1e-3
 
         sol2 = @jit least_squares_with_svd(A_ra, B_ra, full, alg)
         err2 = maximum(abs, A * Array(sol2) .- B)
-        @show err2
+        @test err2 < 1e-3
     end
 
+    A = Reactant.TestUtils.construct_test_array(Float32, 4, 8, 3, 2)
+    A_ra = Reactant.to_rarray(A)
+
+    tmp = rand(Float32, 8, 5, 3, 2)
+    B = similar(A, Float32, 4, 5, 3, 2)
+    for i in 1:3, j in 1:2
+        B[:, :, i, j] = A[:, :, i, j] * tmp[:, :, i, j]
+    end
+    B_ra = Reactant.to_rarray(B)
+    b = B[:, 1, :, :]
+    b_ra = Reactant.to_rarray(b)
+
     @testset "[batched] least squares error: $(alg) | full=$(full)" for alg in algs,
         full in (true, false)
 
-
-        # TODO
+        sol1 = @jit least_squares_with_svd(A_ra, b_ra, full, alg)
+        err1 = compute_ls_solution_error(A, Array(sol1), b, 1)
+        @test err1 < 1e-3
+
+        sol2 = @jit least_squares_with_svd(A_ra, B_ra, full, alg)
+        err2 = compute_ls_solution_error(A, Array(sol2), B, 5)
+        @test err2 < 1e-3
     end
-
-    # TODO: when A is a vector
 end

From 937641cebdc3965c8079d6cf0e074c0079110627 Mon Sep 17 00:00:00 2001
From: Avik Pal
Date: Fri, 21 Nov 2025 09:19:25 -0500
Subject: [PATCH 7/7] chore: bump reactant_jll version

---
 Project.toml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/Project.toml b/Project.toml
index 547112a4b6..af28e093d0 100644
--- a/Project.toml
+++ b/Project.toml
@@ -105,7 +105,7 @@ PythonCall = "0.9.25"
 Random = "1.10"
 Random123 = "1.7"
 ReactantCore = "0.1.16"
-Reactant_jll = "0.0.263"
+Reactant_jll = "0.0.264"
 ScopedValues = "1.3.0"
 Scratch = "1.2"
 Sockets = "1.10"