Skip to content

Commit

Permalink
cat: ensure vararg is more inferrable
Browse files Browse the repository at this point in the history
Ensures that Union{} is not part of the output possibilities after type-piracy of Base.cat methods.

Refs JuliaLang/julia#50550
  • Loading branch information
vtjnash committed Aug 1, 2023
1 parent 2c4f870 commit b5c8fda
Showing 1 changed file with 32 additions and 23 deletions.
55 changes: 32 additions & 23 deletions src/sparsevector.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1194,13 +1194,14 @@ anysparse() = false
anysparse(X) = X isa AbstractArray && issparse(X)
anysparse(X, Xs...) = anysparse(X) || anysparse(Xs...)

function hcat(X::Union{Vector, AbstractSparseVector}...)
const _SparseVecConcatGroup = Union{Vector, AbstractSparseVector}
function hcat(X::_SparseVecConcatGroup...)
if anysparse(X...)
X = map(sparse, X)
end
return cat(X...; dims=Val(2))
end
function vcat(X::Union{Vector, AbstractSparseVector}...)
function vcat(X::_SparseVecConcatGroup...)
if anysparse(X...)
X = map(sparse, X)
end
Expand All @@ -1213,47 +1214,55 @@ end
const _SparseConcatGroup = Union{AbstractVecOrMat{<:Number},Number}

# `@constprop :aggressive` allows `dims` to be propagated as constant improving return type inference
Base.@constprop :aggressive function Base._cat(dims, X::_SparseConcatGroup...)
T = promote_eltype(X...)
if anysparse(X...)
X = (_sparse(first(X)), map(_makesparse, Base.tail(X))...)
Base.@constprop :aggressive function Base._cat(dims, X1::_SparseConcatGroup, X::_SparseConcatGroup...)
T = promote_eltype(X1, X...)
if anysparse(X1) || anysparse(X...)
X1, X = _sparse(X1), map(_makesparse, X)
end
return Base._cat_t(dims, T, X...)
return Base._cat_t(dims, T, X1, X...)
end
function hcat(X::_SparseConcatGroup...)
if anysparse(X...)
X = (_sparse(first(X)), map(_makesparse, Base.tail(X))...)
function hcat(X1::_SparseConcatGroup, X::_SparseConcatGroup...)
if anysparse(X1) || anysparse(X...)
X1, X = _sparse(X1), map(_makesparse, X)
end
return cat(X..., dims=Val(2))
return cat(X1, X..., dims=Val(2))
end
function vcat(X::_SparseConcatGroup...)
if anysparse(X...)
X = (_sparse(first(X)), map(_makesparse, Base.tail(X))...)
function vcat(X1::_SparseConcatGroup, X::_SparseConcatGroup...)
if anysparse(X1) || anysparse(X...)
X1, X = _sparse(X1), map(_makesparse, X)
end
return cat(X..., dims=Val(1))
return cat(X1, X..., dims=Val(1))
end
function hvcat(rows::Tuple{Vararg{Int}}, X::_SparseConcatGroup...)
if anysparse(X...)
vcat(_hvcat_rows(rows, X...)...)
function hvcat(rows::Tuple{Vararg{Int}}, X1::_SparseConcatGroup, X::_SparseConcatGroup...)
if anysparse(X1) || anysparse(X...)
vcat(_hvcat_rows(rows, X1, X...)...)
else
Base.typed_hvcat(Base.promote_eltypeof(X...), rows, X...)
Base.typed_hvcat(Base.promote_eltypeof(X1, X...), rows, X1, X...)
end
end
function _hvcat_rows((row1, rows...)::Tuple{Vararg{Int}}, X::_SparseConcatGroup...)
function _hvcat_rows((row1, rows...)::Tuple{Vararg{Int}}, X1::_SparseConcatGroup, X::_SparseConcatGroup...)
if row1 0
throw(ArgumentError("length of block row must be positive, got $row1"))
end
# assert `X` is non-empty so that inference of `eltype` won't include `Type{Union{}}`
T = eltype(X::Tuple{Any,Vararg{Any}})
T = eltype(X)
# inference of `getindex` may be imprecise in case `row1` is not const-propagated up
# to here, so help inference with the following type-assertions
return (
hcat(X[1 : row1]::Tuple{typeof(X[1]),Vararg{T}}...),
hcat(X1, X[1 : row1]::Tuple{Vararg{T}}...),
_hvcat_rows(rows, X[row1+1:end]::Tuple{Vararg{T}}...)...
)
end
_hvcat_rows(::Tuple{}, X::_SparseConcatGroup...) = ()

# disambiguation for type-piracy problems created above
hcat(n1::Number, ns::Vararg{Number}) = invoke(hcat, Tuple{Vararg{Number}}, n1, ns...)
vcat(n1::Number, ns::Vararg{Number}) = invoke(vcat, Tuple{Vararg{Number}}, n1, ns...)
hcat(n1::Type{N}, ns::Vararg{N}) where {N<:Number} = invoke(hcat, Tuple{Vararg{Number}}, n1, ns...)
vcat(n1::Type{N}, ns::Vararg{N}) where {N<:Number} = invoke(vcat, Tuple{Vararg{Number}}, n1, ns...)
hvcat(rows::Tuple{Vararg{Int}}, n1::Number, ns::Vararg{Number}) = invoke(hvcat, Tuple{Vararg{Number}}, n1, ns...)
hvcat(rows::Tuple{Vararg{Int}}, n1::N, ns::Vararg{N}) where {N<:Number} = invoke(hvcat, Tuple{Vararg{N}}, n1, ns...)


# make sure UniformScaling objects are converted to sparse matrices for concatenation
promote_to_array_type(A::Tuple{Vararg{Union{_SparseConcatGroup,UniformScaling}}}) = anysparse(A...) ? SparseMatrixCSC : Matrix
promote_to_arrays_(n::Int, ::Type{SparseMatrixCSC}, J::UniformScaling) = sparse(J, n, n)
Expand Down

0 comments on commit b5c8fda

Please sign in to comment.