diff --git a/Project.toml b/Project.toml index 547112a4b6..af28e093d0 100644 --- a/Project.toml +++ b/Project.toml @@ -105,7 +105,7 @@ PythonCall = "0.9.25" Random = "1.10" Random123 = "1.7" ReactantCore = "0.1.16" -Reactant_jll = "0.0.263" +Reactant_jll = "0.0.264" ScopedValues = "1.3.0" Scratch = "1.2" Sockets = "1.10" diff --git a/src/Overlay.jl b/src/Overlay.jl index f539ef8199..ffe4a4cad3 100644 --- a/src/Overlay.jl +++ b/src/Overlay.jl @@ -269,6 +269,22 @@ for (jlop, rop, default_pivot) in ( end end +for (jlop, rop) in ((:svd, :overloaded_svd),) + @eval begin + @reactant_overlay @noinline function LinearAlgebra.$(jlop)( + x::AbstractArray; kwargs... + ) + if use_overlayed_version(x) + return TracedLinearAlgebra.$(rop)( + factorization_copy(LinearAlgebra.$(jlop), x); kwargs... + ) + else + return Base.inferencebarrier(LinearAlgebra.$(jlop))(x; kwargs...) + end + end + end +end + @reactant_overlay @noinline function LinearAlgebra.dot(x::AbstractArray, y::AbstractArray) if use_overlayed_version(x) || use_overlayed_version(y) return TracedLinearAlgebra.overloaded_dot(x, y) diff --git a/src/stdlibs/LinearAlgebra.jl b/src/stdlibs/LinearAlgebra.jl index 9bfc2ba2ae..26dff76146 100644 --- a/src/stdlibs/LinearAlgebra.jl +++ b/src/stdlibs/LinearAlgebra.jl @@ -4,7 +4,7 @@ using ..MLIR: MLIR using ..Reactant: Reactant, Ops using ..Reactant: TracedRArray, TracedRNumber, AnyTracedRArray, AnyTracedRMatrix, AnyTracedRVector -using ..Reactant: call_with_reactant +using ..Reactant: call_with_reactant, unwrapped_eltype, promote_to using ReactantCore: ReactantCore, materialize_traced_array, @trace using Reactant_jll: Reactant_jll @@ -15,8 +15,9 @@ using LinearAlgebra: LinearAlgebra, BLAS using LinearAlgebra: Adjoint, Transpose, Factorization, RowMaximum, NoPivot using LinearAlgebra: SymTridiagonal, Symmetric, Bidiagonal, Diagonal, Tridiagonal using LinearAlgebra: LowerTriangular, UnitLowerTriangular, UpperTriangular -using LinearAlgebra: - diag, diagm, ldiv!, det, logabsdet, lu, istriu, istril, triu!, tril!, inv!, rmul! +using LinearAlgebra: I, diag, diagm, ldiv!, det, logabsdet, istriu, istril, triu!, tril! +using LinearAlgebra: inv!, rmul!, normalize +using LinearAlgebra: svd, lu using Libdl: Libdl using GPUArraysCore: @allowscalar diff --git a/src/stdlibs/factorization/Cholesky.jl b/src/stdlibs/factorization/Cholesky.jl index e0f30ad586..80bda2a5dc 100644 --- a/src/stdlibs/factorization/Cholesky.jl +++ b/src/stdlibs/factorization/Cholesky.jl @@ -1,17 +1,18 @@ -struct GeneralizedCholesky{T,S<:AbstractArray,I<:Union{AbstractArray,Number}} <: - GeneralizedFactorization{T} +struct BatchedCholesky{T,S<:AbstractArray,I<:Union{AbstractArray,Number}} <: + BatchedFactorization{T} factors::S uplo::Char info::I end -function GeneralizedCholesky(factors::S, uplo::Char, info::I) where {S,I} +function BatchedCholesky(factors::S, uplo::Char, info::I) where {S,I} @assert ndims(info) == ndims(factors) - 2 - return GeneralizedCholesky{eltype(factors),S,I}(factors, uplo, info) + return BatchedCholesky{eltype(factors),S,I}(factors, uplo, info) end -Base.size(c::GeneralizedCholesky) = size(c.factors) -Base.ndims(c::GeneralizedCholesky) = ndims(c.factors) +Base.size(c::BatchedCholesky) = size(c.factors) +Base.size(c::BatchedCholesky, i::Integer) = size(c.factors, i) +Base.ndims(c::BatchedCholesky) = ndims(c.factors) function overloaded_cholesky(A::AbstractArray, ::NoPivot; check::Bool=false) return overloaded_cholesky(Reactant.promote_to(TracedRArray, A), NoPivot(); check) @@ -41,11 +42,11 @@ function overloaded_cholesky( info = TracedRNumber{Bool}((), info.mlir_data) end - return GeneralizedCholesky(factors, 'U', info) + return BatchedCholesky(factors, 'U', info) end function LinearAlgebra.ldiv!( - F::GeneralizedCholesky{T,<:AbstractArray{T,N}}, B::AbstractArray{T,M} + F::BatchedCholesky{T,<:AbstractArray{T,N}}, B::AbstractArray{T,M} ) where {T,N,M} @assert N == M + 1 ldiv!(F, reshape(B, size(B, 1), 1, size(B)[2:end]...)) @@ -53,14 +54,14 @@ function LinearAlgebra.ldiv!( end function LinearAlgebra.ldiv!( - F::GeneralizedCholesky{T,<:AbstractArray{T,2}}, B::AbstractArray{T,2} + F::BatchedCholesky{T,<:AbstractArray{T,2}}, B::AbstractArray{T,2} ) where {T} B .= _cholesky_solve_core(F.factors, B, F.uplo) return B end function LinearAlgebra.ldiv!( - F::GeneralizedCholesky{T,<:AbstractArray{T,N}}, B::AbstractArray{T,N} + F::BatchedCholesky{T,<:AbstractArray{T,N}}, B::AbstractArray{T,N} ) where {T,N} batch_shape = size(F.factors)[3:end] @assert batch_shape == size(B)[3:end] diff --git a/src/stdlibs/factorization/Factorization.jl b/src/stdlibs/factorization/Factorization.jl index 5114a47d7b..5782f35cf5 100644 --- a/src/stdlibs/factorization/Factorization.jl +++ b/src/stdlibs/factorization/Factorization.jl @@ -1,28 +1,25 @@ -# Supports batched factorization -abstract type GeneralizedFactorization{T} <: Factorization{T} end +abstract type BatchedFactorization{T} <: Factorization{T} end -function LinearAlgebra.TransposeFactorization(f::GeneralizedFactorization) +function LinearAlgebra.TransposeFactorization(f::BatchedFactorization) return LinearAlgebra.TransposeFactorization{eltype(f),typeof(f)}(f) end -function LinearAlgebra.AdjointFactorization(f::GeneralizedFactorization) +function LinearAlgebra.AdjointFactorization(f::BatchedFactorization) return LinearAlgebra.AdjointFactorization{eltype(f),typeof(f)}(f) end -const GeneralizedTransposeFactorization{T} = - LinearAlgebra.TransposeFactorization{T,<:GeneralizedFactorization{T}} where {T} -const GeneralizedAdjointFactorization{T} = - LinearAlgebra.AdjointFactorization{T,<:GeneralizedFactorization{T}} where {T} +const BatchedTransposeFactorization{T} = + LinearAlgebra.TransposeFactorization{T,<:BatchedFactorization{T}} where {T} +const BatchedAdjointFactorization{T} = + LinearAlgebra.AdjointFactorization{T,<:BatchedFactorization{T}} where {T} include("Cholesky.jl") include("LU.jl") +include("SVD.jl") # Overload \ to support batched factorization -for FT in ( - :GeneralizedFactorization, - :GeneralizedTransposeFactorization, - :GeneralizedAdjointFactorization, -) +for FT in + (:BatchedFactorization, :BatchedTransposeFactorization, :BatchedAdjointFactorization) for aType in (:AbstractVecOrMat, :AbstractArray) @eval Base.:(\)(F::$FT, B::$aType) = _overloaded_backslash(F, B) end @@ -32,18 +29,36 @@ for FT in ( ) where {T<:Union{Float32,Float64}} = _overloaded_backslash(F, B) end -function _overloaded_backslash(F::GeneralizedFactorization, B::AbstractArray) - return ldiv!( - F, LinearAlgebra.copy_similar(B, typeof(oneunit(eltype(F)) \ oneunit(eltype(B)))) - ) +function __get_B(F::Factorization, B::AbstractArray) + m, n = size(F, 1), size(F, 2) + if m != size(B, 1) + throw(DimensionMismatch("arguments must have the same number of rows")) + end + + TFB = typeof(oneunit(eltype(F)) \ oneunit(eltype(B))) + + BB = similar(B, TFB, max(size(B, 1), n), size(B)[2:end]...) + if n > size(B, 1) + BB[1:m, ntuple(Returns(Colon()), ndims(B) - 1)...] = B + else + copyto!(BB, B) + end + + return BB +end + +function _overloaded_backslash(F::BatchedFactorization, B::AbstractArray) + BB = __get_B(F, B) + ldiv!(F, BB) + return BB[1:size(F, 2), ntuple(Returns(Colon()), ndims(B) - 1)...] end -function _overloaded_backslash(F::GeneralizedTransposeFactorization, B::AbstractArray) +function _overloaded_backslash(F::BatchedTransposeFactorization, B::AbstractArray) return conj!(adjoint(F.parent) \ conj.(B)) end -function _overloaded_backslash(F::GeneralizedAdjointFactorization, B::AbstractArray) - return ldiv!( - F, LinearAlgebra.copy_similar(B, typeof(oneunit(eltype(F)) \ oneunit(eltype(B)))) - ) +function _overloaded_backslash(F::BatchedAdjointFactorization, B::AbstractArray) + BB = __get_B(F, B) + ldiv!(F, BB) + return BB[1:size(F)[2], ntuple(Returns(Colon()), ndims(B) - 1)...] end diff --git a/src/stdlibs/factorization/LU.jl b/src/stdlibs/factorization/LU.jl index 144e7b9dcd..79eb206e4d 100644 --- a/src/stdlibs/factorization/LU.jl +++ b/src/stdlibs/factorization/LU.jl @@ -1,22 +1,22 @@ -struct GeneralizedLU{T,S<:AbstractArray,P<:AbstractArray,I<:Union{AbstractArray,Number}} <: - GeneralizedFactorization{T} +struct BatchedLU{T,S<:AbstractArray,P<:AbstractArray,I<:Union{AbstractArray,Number}} <: + BatchedFactorization{T} factors::S ipiv::P perm::P info::I end -Base.size(lu::GeneralizedLU) = size(lu.factors) -Base.size(lu::GeneralizedLU, i) = size(lu.factors, i) -Base.ndims(lu::GeneralizedLU) = ndims(lu.factors) -function Base.copy(lu::GeneralizedLU) - return GeneralizedLU(copy(lu.factors), copy(lu.ipiv), copy(lu.perm), copy(lu.info)) +Base.size(lu::BatchedLU) = size(lu.factors) +Base.size(lu::BatchedLU, i::Integer) = size(lu.factors, i) +Base.ndims(lu::BatchedLU) = ndims(lu.factors) +function Base.copy(lu::BatchedLU) + return BatchedLU(copy(lu.factors), copy(lu.ipiv), copy(lu.perm), copy(lu.info)) end -function GeneralizedLU(factors::S, ipiv::P, perm::P, info::I) where {S,P,I} +function BatchedLU(factors::S, ipiv::P, perm::P, info::I) where {S,P,I} @assert ndims(ipiv) == ndims(perm) == ndims(factors) - 1 @assert ndims(info) == ndims(factors) - 2 - return GeneralizedLU{eltype(factors),S,P,I}(factors, ipiv, perm, info) + return BatchedLU{eltype(factors),S,P,I}(factors, ipiv, perm, info) end function overloaded_lu(x::AbstractArray, args...; kwargs...) @@ -37,11 +37,11 @@ function overloaded_lu( factors = @opcall transpose(factors, invperm(permdims)) ipiv = @opcall transpose(ipiv, perm_perm) perm = @opcall transpose(perm, perm_perm) - return GeneralizedLU(factors, ipiv, perm, info) + return BatchedLU(factors, ipiv, perm, info) end function LinearAlgebra.ldiv!( - lu::GeneralizedLU{T,<:AbstractArray{T,N},P,I}, B::AbstractArray{T,M} + lu::BatchedLU{T,<:AbstractArray{T,N},P,I}, B::AbstractArray{T,M} ) where {T,P,I,N,M} @assert N == M + 1 ldiv!(lu, reshape(B, size(B, 1), 1, size(B)[2:end]...)) @@ -49,14 +49,14 @@ function LinearAlgebra.ldiv!( end function LinearAlgebra.ldiv!( - lu::GeneralizedLU{T,<:AbstractArray{T,2},P,I}, B::AbstractArray{T,2} + lu::BatchedLU{T,<:AbstractArray{T,2},P,I}, B::AbstractArray{T,2} ) where {T,P,I} B .= _lu_solve_core(lu.factors, B, lu.perm) return B end function LinearAlgebra.ldiv!( - lu::GeneralizedLU{T,<:AbstractArray{T,N},P,I}, B::AbstractArray{T,N} + lu::BatchedLU{T,<:AbstractArray{T,N},P,I}, B::AbstractArray{T,N} ) where {T,P,I,N} batch_shape = size(lu.factors)[3:end] @assert batch_shape == size(B)[3:end] @@ -83,7 +83,7 @@ function LinearAlgebra.ldiv!( return B end -function LinearAlgebra.det(lu::GeneralizedLU{T,<:AbstractMatrix}) where {T} +function LinearAlgebra.det(lu::BatchedLU{T,<:AbstractMatrix}) where {T} n = LinearAlgebra.checksquare(lu) # TODO: check for non-singular matrices @@ -91,7 +91,7 @@ function LinearAlgebra.det(lu::GeneralizedLU{T,<:AbstractMatrix}) where {T} return ifelse(isodd(sum(lu.ipiv[1:n] .!= (1:n))), -one(T), one(T)) * P end -function LinearAlgebra.logabsdet(lu::GeneralizedLU{T,<:AbstractMatrix}) where {T} +function LinearAlgebra.logabsdet(lu::BatchedLU{T,<:AbstractMatrix}) where {T} n = LinearAlgebra.checksquare(lu) Treal = real(T) # TODO: check for non-singular matrices @@ -106,7 +106,7 @@ end for f_wrapper in (LinearAlgebra.TransposeFactorization, LinearAlgebra.AdjointFactorization), aType in (:AbstractVecOrMat, :AbstractArray) - @eval function LinearAlgebra.ldiv!(lu::$(f_wrapper){<:Any,<:GeneralizedLU}, B::$aType) + @eval function LinearAlgebra.ldiv!(lu::$(f_wrapper){<:Any,<:BatchedLU}, B::$aType) # TODO: implement this error("`$(f_wrapper)` is not supported yet for LU.") return nothing @@ -116,7 +116,7 @@ end # currently we lower inverse to lu decomposition + triangular solve. we should # instead emit getri and lower that to a fallback if the backend doesn't support # it. -function LinearAlgebra.inv!(lu::GeneralizedLU) +function LinearAlgebra.inv!(lu::BatchedLU) @assert ndims(lu) == 2 "Only implemented for 2D tensors" rhs = Reactant.promote_to( TracedRArray{Reactant.unwrapped_eltype(eltype(lu)),2}, LinearAlgebra.I(size(lu, 1)) diff --git a/src/stdlibs/factorization/SVD.jl b/src/stdlibs/factorization/SVD.jl new file mode 100644 index 0000000000..d507a3feeb --- /dev/null +++ b/src/stdlibs/factorization/SVD.jl @@ -0,0 +1,168 @@ +struct BatchedSVD{T,Tr,M<:AbstractArray,C<:AbstractArray} <: BatchedFactorization{T} + U::M + S::C + Vt::M + + function BatchedSVD{T,Tr,M,C}(U::M, S::C, Vt::M) where {T,Tr,M,C} + @assert ndims(S) == ndims(U) - 1 + return new{T,Tr,M,C}(U, S, Vt) + end +end + +function Base.size(svd::BatchedSVD) + return (size(svd.U, 1), size(svd.Vt, 2), size(svd.U)[3:end]...) +end +Base.size(svd::BatchedSVD, i::Integer) = i == 2 ? size(svd.Vt, 2) : size(svd.U, i) + +function BatchedSVD(U::M, S::C, Vt::M) where {M,C} + @assert ndims(S) == ndims(U) - 1 + return BatchedSVD{eltype(U),eltype(S),M,C}(U, S, Vt) +end + +struct DefaultEnzymeXLASVDAlgorithm <: LinearAlgebra.Algorithm end +struct JacobiAlgorithm <: LinearAlgebra.Algorithm end + +_jlalg_to_enzymexla_alg(::Nothing) = "DEFAULT" +_jlalg_to_enzymexla_alg(alg::DefaultEnzymeXLASVDAlgorithm) = "DEFAULT" +_jlalg_to_enzymexla_alg(alg::LinearAlgebra.DivideAndConquer) = "DivideAndConquer" +_jlalg_to_enzymexla_alg(alg::LinearAlgebra.QRIteration) = "QRIteration" +_jlalg_to_enzymexla_alg(alg::JacobiAlgorithm) = "Jacobi" +_jlalg_to_enzymexla_alg(alg::String) = alg +_jlalg_to_enzymexla_alg(alg::Symbol) = _jlalg_to_enzymexla_alg(string(alg)) +_jlalg_to_enzymexla_alg(alg) = error("Unsupported SVD algorithm: $alg") + +# default relies on the backend to select the best algorithm +LinearAlgebra.default_svd_alg(::AnyTracedRArray) = DefaultEnzymeXLASVDAlgorithm() + +function overloaded_svd(A::AbstractArray; kwargs...) + return overloaded_svd(Reactant.promote_to(TracedRArray, A); kwargs...) +end + +function overloaded_svd( + A::AnyTracedRArray{T,N}; full::Bool=false, alg=LinearAlgebra.default_svd_alg(A) +) where {T,N} + # Batching here is in the last dimensions. `Ops.svd` expects the last dimensions + permdims = vcat(collect(Int64, 3:N), 1, 2) + A = @opcall transpose(materialize_traced_array(A), permdims) + + U, S, Vt = @opcall svd(A; full, algorithm=_jlalg_to_enzymexla_alg(alg)) + + # Permute back to the original dimensions + S_perm = vcat(N - 1, collect(Int64, 1:(N - 2))) + + U = @opcall transpose(U, invperm(permdims)) + S = @opcall transpose(S, S_perm) + Vt = @opcall transpose(Vt, invperm(permdims)) + + return BatchedSVD(U, S, Vt) +end + +struct __InnerVectorSVDDispatch{A} <: Function + full::Bool + algorithm::A +end + +struct __ZeroNormVectorSVDDispatch{A} <: Function + full::Bool + algorithm::A +end + +function overloaded_svd( + A::AnyTracedRVector; full::Bool=false, alg=LinearAlgebra.default_svd_alg(A) +) + normA = Reactant.call_with_reactant(LinearAlgebra.norm, A) + U, S, Vt = ReactantCore.traced_if( + iszero(normA), + __ZeroNormVectorSVDDispatch(full, alg), + __InnerVectorSVDDispatch(full, alg), + (A, normA), + ) + return BatchedSVD(U, S, Vt) +end + +function (fn::__ZeroNormVectorSVDDispatch)(A::AbstractVector{T}, normA) where {T} + U = promote_to( + TracedRArray, Matrix{unwrapped_eltype(A)}(I, length(A), fn.full ? length(A) : 1) + ) + return U, fill(normA, 1), ones(T, 1, 1) +end + +function (fn::__InnerVectorSVDDispatch)(A::AbstractVector{T}, normA) where {T} + if !fn.full + normalizedA = normalize(A) + U = materialize_traced_array(reshape(normalizedA, length(A), 1)) + return U, fill(normA, 1), ones(T, 1, 1) + end + (; U, S, Vt) = overloaded_svd(reshape(A, :, 1); full=true, alg=fn.algorithm) + return U, S, Vt +end + +function LinearAlgebra.svdvals(x::AnyTracedRArray{T,N}; kwargs...) where {T,N} + return overloaded_svd(x; kwargs..., full=false).S +end +function LinearAlgebra.svdvals!(x::AnyTracedRArray{T,N}; kwargs...) where {T,N} + return overloaded_svd(x; kwargs..., full=false).S +end +function LinearAlgebra.svdvals(x::AnyTracedRVector{T}; kwargs...) where {T} + return overloaded_svd(x; kwargs..., full=false).S +end +function LinearAlgebra.svdvals!(x::AnyTracedRVector{T}; kwargs...) where {T} + return overloaded_svd(x; kwargs..., full=false).S +end + +# Ideally we want to slice based on near zero singular values, but this will +# produce dynamically sized slices. Instead we zero out slices and proceed +function _svd_solve_core( + U::AbstractMatrix, S::AbstractVector{Tr}, Vt::AbstractMatrix, B::AbstractMatrix +) where {Tr} + mask = S .> eps(real(Tr)) * @allowscalar(S[1]) + m, n = size(U, 1), size(Vt, 2) + rhs = S .\ (U' * LinearAlgebra._cut_B(B, 1:m)) + rhs = ifelse.(mask, rhs, zero(eltype(rhs))) + return (Vt[1:length(S), :])' * rhs +end + +function LinearAlgebra.ldiv!( + svd::BatchedSVD{T,Tr,<:AbstractArray{T,N}}, B::AbstractArray{T,M} +) where {T,Tr,N,M} + @assert N == M + 1 + ldiv!(svd, reshape(B, size(B, 1), 1, size(B)[2:end]...)) + return B +end + +function LinearAlgebra.ldiv!( + svd::BatchedSVD{T,Tr,<:AbstractArray{T,2}}, B::AbstractArray{T,2} +) where {T,Tr} + n = size(svd, 2) + sol = _svd_solve_core(svd.U, svd.S, svd.Vt, B) + B[1:n, :] .= sol + return B +end + +function LinearAlgebra.ldiv!( + svd::BatchedSVD{T,Tr,<:AbstractArray{T,N}}, B::AbstractArray{T,N} +) where {T,Tr,N} + batch_shape = size(svd.U)[3:end] + @assert batch_shape == size(B)[3:end] + + n = size(svd, 2) + permutation = vcat(collect(Int64, 3:N), 1, 2) + S_perm = vcat(collect(Int64, 2:(N - 1)), 1) + + U = @opcall transpose(materialize_traced_array(svd.U), permutation) + S = @opcall transpose(materialize_traced_array(svd.S), S_perm) + Vt = @opcall transpose(materialize_traced_array(svd.Vt), permutation) + + B_permuted = @opcall transpose(materialize_traced_array(B), permutation) + + res = @opcall transpose( + only( + @opcall( + batch(_svd_solve_core, [U, S, Vt, B_permuted], collect(Int64, batch_shape)) + ), + ), + invperm(permutation), + ) + B[1:n, :, ntuple(Returns(Colon()), length(batch_shape))...] .= res + return B +end diff --git a/test/integration/linear_algebra.jl b/test/integration/linear_algebra.jl index cbf2cb02af..d9db88932a 100644 --- a/test/integration/linear_algebra.jl +++ b/test/integration/linear_algebra.jl @@ -511,6 +511,107 @@ end end end +using LinearAlgebra, Reactant +Reactant.set_default_backend("cpu") + +function get_svd_algorithms(backend::String, size=nothing) + backend = lowercase(backend) + algorithms = ["DEFAULT"] + if occursin("cpu", backend) + append!(algorithms, ["QRIteration", "DivideAndConquer"]) + elseif occursin("cuda", backend) + if size === nothing || size[1] ≥ size[2] + append!(algorithms, ["QRIteration", "Jacobi"]) + else + append!(algorithms, ["Jacobi"]) + end + elseif occursin("tpu", backend) + append!(algorithms, ["Jacobi"]) + end + return algorithms +end + +least_squares_with_svd(A, b, full, alg) = svd(A; full, alg) \ b + +function compute_ls_solution_error(A, sol, b, bsize) + b = reshape(b, size(b, 1), bsize, :) + A = reshape(A, size(A, 1), size(A, 2), :) + sol = reshape(sol, size(sol, 1), bsize, :) + mul = stack((A[:, :, i] * sol[:, :, i] for i in axes(A, 3)); dims=3) + return maximum(abs, mul .- b) +end + +@testset "svd factorization" begin + A = Reactant.TestUtils.construct_test_array(Float32, 4, 8) + + algs = get_svd_algorithms(string(Reactant.devices()[1]), size(A)) + + tmp = rand(Float32, 8, 5) + B = A * tmp + b = B[:, 1] + + A_ra = Reactant.to_rarray(A) + B_ra = Reactant.to_rarray(B) + b_ra = Reactant.to_rarray(b) + + # test least squares error + @testset "least squares error: $(alg) | full=$(full)" for alg in algs, + full in (true, false) + + sol1 = @jit least_squares_with_svd(A_ra, b_ra, full, alg) + err1 = maximum(abs, A * Array(sol1) .- b) + @test err1 < 1e-3 + + sol2 = @jit least_squares_with_svd(A_ra, B_ra, full, alg) + err2 = maximum(abs, A * Array(sol2) .- B) + @test err2 < 1e-3 + end + + A = Reactant.TestUtils.construct_test_array(Float32, 4, 8, 3, 2) + A_ra = Reactant.to_rarray(A) + + tmp = rand(Float32, 8, 5, 3, 2) + B = similar(A, Float32, 4, 5, 3, 2) + for i in 1:3, j in 1:2 + B[:, :, i, j] = A[:, :, i, j] * tmp[:, :, i, j] + end + B_ra = Reactant.to_rarray(B) + b = B[:, 1, :, :] + b_ra = Reactant.to_rarray(b) + + @testset "[batched] least squares error: $(alg) | full=$(full)" for alg in algs, + full in (true, false) + + sol1 = @jit least_squares_with_svd(A_ra, b_ra, full, alg) + err1 = compute_ls_solution_error(A, Array(sol1), b, 1) + @test err1 < 1e-3 + + sol2 = @jit least_squares_with_svd(A_ra, B_ra, full, alg) + err2 = compute_ls_solution_error(A, Array(sol2), B, 5) + @test err2 < 1e-3 + end +end + +@testset "svdvals" begin + algs = get_svd_algorithms(string(Reactant.devices()[1])) + + @testset "Un-batched: $(alg)" for alg in algs + A = Reactant.TestUtils.construct_test_array(Float32, 4, 4) + _svdvals = svdvals(A) + A_ra = Reactant.to_rarray(A) + _svdvals_ra = @jit svdvals(A_ra; alg=alg) + @test _svdvals_ra ≈ _svdvals + end + + @testset "Batched: $(alg)" for alg in algs + A = Reactant.TestUtils.construct_test_array(Float32, 4, 4, 3, 2) + _svdvals = reshape(mapslices(svdvals, A; dims=(1, 2)), 4, 3, 2) + A_ra = Reactant.to_rarray(A) + _svdvals_ra = @jit svdvals(A_ra; alg=alg) + @test _svdvals_ra ≈ _svdvals + end +end + @testset "structure check" begin @testset "istriu" begin x = Reactant.TestUtils.construct_test_array(Float32, 8, 8)