From 61e8df32091f737de512e96f3e18baf73a72b306 Mon Sep 17 00:00:00 2001 From: "Steven G. Johnson" Date: Wed, 17 Sep 2025 15:19:54 -0400 Subject: [PATCH] ldiv for vector-of-vectors RHS --- src/LinearAlgebra.jl | 11 +++++++++-- src/factorization.jl | 4 ++-- src/hessenberg.jl | 4 ++-- src/special.jl | 4 ++-- test/lu.jl | 2 +- 5 files changed, 16 insertions(+), 9 deletions(-) diff --git a/src/LinearAlgebra.jl b/src/LinearAlgebra.jl index d97c308b..f4b8f972 100644 --- a/src/LinearAlgebra.jl +++ b/src/LinearAlgebra.jl @@ -700,6 +700,12 @@ const LAPACKFactorizations{T,S} = Union{ (\)(F::AdjointFactorization{<:Any,<:LAPACKFactorizations}, B::AbstractVecOrMat) = ldiv(F, B) (\)(F::TransposeFactorization{<:Any,<:LU}, B::AbstractVecOrMat) = ldiv(F, B) +# return the "scalar" type for vector fields, if possible +_scalartype(::Type{T}) where {T<:Number} = T +_scalartype(::Type{T}) where T = _scalartype(T, Base.IteratorEltype(T)) +_scalartype(::Type{T}, ::Base.HasEltype) where T = _scalartype(eltype(T)) +_scalartype(::Type{T}, ::Base.EltypeUnknown) where T = T + function ldiv(F::Factorization, B::AbstractVecOrMat) require_one_based_indexing(B) m, n = size(F) @@ -707,12 +713,13 @@ function ldiv(F::Factorization, B::AbstractVecOrMat) throw(DimensionMismatch("arguments must have the same number of rows")) end - TFB = typeof(oneunit(eltype(B)) / oneunit(eltype(F))) + TFB = typeof(zero(_scalartype(eltype(B))) / oneunit(eltype(F))) FF = Factorization{TFB}(F) # For wide problem we (often) compute a minimum norm solution. The solution # is larger than the right hand side so we use size(F, 2). - BB = _zeros(TFB, B, n) + TBB = typeof(zero(eltype(B)) / oneunit(eltype(F))) + BB = _zeros(TBB, B, n) if n > size(B, 1) # Underdetermined diff --git a/src/factorization.jl b/src/factorization.jl index b035f642..372d4640 100644 --- a/src/factorization.jl +++ b/src/factorization.jl @@ -135,7 +135,7 @@ end function (\)(F::Factorization, B::AbstractVecOrMat) require_one_based_indexing(B) - TFB = typeof(oneunit(eltype(F)) \ oneunit(eltype(B))) + TFB = typeof(oneunit(eltype(F)) \ zero(eltype(B))) ldiv!(F, copy_similar(B, TFB)) end (\)(F::TransposeFactorization, B::AbstractVecOrMat) = conj!(adjoint(F.parent) \ conj.(B)) @@ -179,7 +179,7 @@ end function (/)(B::AbstractMatrix, F::Factorization) require_one_based_indexing(B) - TFB = typeof(oneunit(eltype(B)) / oneunit(eltype(F))) + TFB = typeof(zero(eltype(B)) / oneunit(eltype(F))) rdiv!(copy_similar(B, TFB), F) end # reinterpretation trick for complex lhs and real factorization diff --git a/src/hessenberg.jl b/src/hessenberg.jl index 36450160..4b0a709d 100644 --- a/src/hessenberg.jl +++ b/src/hessenberg.jl @@ -196,7 +196,7 @@ TransUpperHessenberg{T,S<:UpperHessenberg{T}} = Transpose{T, S} AdjOrTransUpperHessenberg{T,S<:UpperHessenberg{T}} = AdjOrTrans{T, S} function (\)(H::Union{UpperHessenberg,AdjOrTransUpperHessenberg}, B::AbstractVecOrMat) - TFB = typeof(oneunit(eltype(H)) \ oneunit(eltype(B))) + TFB = typeof(oneunit(eltype(H)) \ zero(eltype(B))) return ldiv!(H, copy_similar(B, TFB)) end @@ -204,7 +204,7 @@ end (/)(B::AbstractMatrix, H::AdjUpperHessenberg) = _rdiv(B, H) (/)(B::AbstractMatrix, H::TransUpperHessenberg) = _rdiv(B, H) function _rdiv(B, H) - TFB = typeof(oneunit(eltype(B)) / oneunit(eltype(H))) + TFB = typeof(zero(eltype(B)) / oneunit(eltype(H))) return rdiv!(copy_similar(B, TFB), H) end diff --git a/src/special.jl b/src/special.jl index 83880ca5..9186d833 100644 --- a/src/special.jl +++ b/src/special.jl @@ -132,13 +132,13 @@ function mul(B::Bidiagonal, H::UpperHessenberg) end function /(H::UpperHessenberg, B::Bidiagonal) - T = typeof(oneunit(eltype(H))/oneunit(eltype(B))) + T = typeof(oneunit(eltype(H))/zero(eltype(B))) A = _rdiv!(similar(H, T, size(H)), H, B) return B.uplo == 'U' ? UpperHessenberg(A) : A end function \(B::Bidiagonal, H::UpperHessenberg) - T = typeof(oneunit(eltype(B))\oneunit(eltype(H))) + T = typeof(zero(eltype(B))\oneunit(eltype(H))) A = ldiv!(similar(H, T, size(H)), B, H) return B.uplo == 'U' ? UpperHessenberg(A) : A end diff --git a/test/lu.jl b/test/lu.jl index ba869032..abc0aa3a 100644 --- a/test/lu.jl +++ b/test/lu.jl @@ -127,7 +127,7 @@ dimg = randn(n)/2 end # Test whether Ax_ldiv_B!(y, LU, x) indeed overwrites y - resultT = typeof(oneunit(eltyb) / oneunit(eltya)) + resultT = typeof(zero(eltyb) / oneunit(eltya)) b_dest = similar(b, resultT) c_dest = similar(c, resultT)