From c402d09cf05492179fad2def5632e354a81f5b30 Mon Sep 17 00:00:00 2001 From: Jameson Nash Date: Tue, 1 Aug 2023 14:34:18 -0400 Subject: [PATCH] cat: ensure vararg is more inferrable Ensures that Union{} is not part of the output possibilities after type-piracy of Base.cat methods. Refs https://github.com/JuliaLang/julia/issues/50550 --- src/sparsevector.jl | 48 +++++++++++++++++++++++++++------------------ 1 file changed, 29 insertions(+), 19 deletions(-) diff --git a/src/sparsevector.jl b/src/sparsevector.jl index 22b1badb..cd75e5b1 100644 --- a/src/sparsevector.jl +++ b/src/sparsevector.jl @@ -1194,13 +1194,14 @@ anysparse() = false anysparse(X) = X isa AbstractArray && issparse(X) anysparse(X, Xs...) = anysparse(X) || anysparse(Xs...) -function hcat(X::Union{Vector, AbstractSparseVector}...) +const _SparseVecConcatGroup = Union{Vector, AbstractSparseVector} +function hcat(X::_SparseVecConcatGroup...) if anysparse(X...) X = map(sparse, X) end return cat(X...; dims=Val(2)) end -function vcat(X::Union{Vector, AbstractSparseVector}...) +function vcat(X::_SparseVecConcatGroup...) if anysparse(X...) X = map(sparse, X) end @@ -1213,30 +1214,30 @@ end const _SparseConcatGroup = Union{AbstractVecOrMat{<:Number},Number} # `@constprop :aggressive` allows `dims` to be propagated as constant improving return type inference -Base.@constprop :aggressive function Base._cat(dims, X::_SparseConcatGroup...) - T = promote_eltype(X...) - if anysparse(X...) - X = (_sparse(first(X)), map(_makesparse, Base.tail(X))...) +Base.@constprop :aggressive function Base._cat(dims, X1::_SparseConcatGroup, X::_SparseConcatGroup...) + T = promote_eltype(X1, X...) + if anysparse(X1) || anysparse(X...) + X1, X = _sparse(X1), map(_makesparse, X) end - return Base._cat_t(dims, T, X...) + return Base._cat_t(dims, T, X1, X...) end -function hcat(X::_SparseConcatGroup...) - if anysparse(X...) - X = (_sparse(first(X)), map(_makesparse, Base.tail(X))...) +function hcat(X1::_SparseConcatGroup, X::_SparseConcatGroup...) + if anysparse(X1) || anysparse(X...) + X1, X = _sparse(X1), map(_makesparse, X) end - return cat(X..., dims=Val(2)) + return cat(X1, X..., dims=Val(2)) end -function vcat(X::_SparseConcatGroup...) - if anysparse(X...) - X = (_sparse(first(X)), map(_makesparse, Base.tail(X))...) +function vcat(X1::_SparseConcatGroup, X::_SparseConcatGroup...) + if anysparse(X1) || anysparse(X...) + X1, X = _sparse(X1), map(_makesparse, X) end - return cat(X..., dims=Val(1)) + return cat(X1, X..., dims=Val(1)) end -function hvcat(rows::Tuple{Vararg{Int}}, X::_SparseConcatGroup...) - if anysparse(X...) - vcat(_hvcat_rows(rows, X...)...) +function hvcat(rows::Tuple{Vararg{Int}}, X1::_SparseConcatGroup, X::_SparseConcatGroup...) + if anysparse(X1) || anysparse(X...) + vcat(_hvcat_rows(rows, X1, X...)...) else - Base.typed_hvcat(Base.promote_eltypeof(X...), rows, X...) + Base.typed_hvcat(Base.promote_eltypeof(X1, X...), rows, X1, X...) end end function _hvcat_rows((row1, rows...)::Tuple{Vararg{Int}}, X::_SparseConcatGroup...) @@ -1254,6 +1255,15 @@ function _hvcat_rows((row1, rows...)::Tuple{Vararg{Int}}, X::_SparseConcatGroup. end _hvcat_rows(::Tuple{}, X::_SparseConcatGroup...) = () +# disambiguation for type-piracy problems created above +hcat(n1::Number, ns::Vararg{Number}) = invoke(hcat, Tuple{Vararg{Number}}, n1, ns...) +vcat(n1::Number, ns::Vararg{Number}) = invoke(vcat, Tuple{Vararg{Number}}, n1, ns...) +hcat(n1::Type{N}, ns::Vararg{N}) where {N<:Number} = invoke(hcat, Tuple{Vararg{Number}}, n1, ns...) +vcat(n1::Type{N}, ns::Vararg{N}) where {N<:Number} = invoke(vcat, Tuple{Vararg{Number}}, n1, ns...) +hvcat(rows::Tuple{Vararg{Int}}, n1::Number, ns::Vararg{Number}) = invoke(hvcat, Tuple{typeof(rows), Vararg{Number}}, rows, n1, ns...) +hvcat(rows::Tuple{Vararg{Int}}, n1::N, ns::Vararg{N}) where {N<:Number} = invoke(hvcat, Tuple{typeof(rows), Vararg{N}}, rows, n1, ns...) + + # make sure UniformScaling objects are converted to sparse matrices for concatenation promote_to_array_type(A::Tuple{Vararg{Union{_SparseConcatGroup,UniformScaling}}}) = anysparse(A...) ? SparseMatrixCSC : Matrix promote_to_arrays_(n::Int, ::Type{SparseMatrixCSC}, J::UniformScaling) = sparse(J, n, n)