Skip to content

Commit

Permalink
hvncat: Stronger argument checks (#41196)
Browse files Browse the repository at this point in the history
fixes #41047
  • Loading branch information
BioTurboNick committed Jul 15, 2021
1 parent 41ee0fa commit e6aca89
Show file tree
Hide file tree
Showing 2 changed files with 141 additions and 33 deletions.
111 changes: 78 additions & 33 deletions base/abstractarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2136,6 +2136,7 @@ _hvncat(dimsshape::Union{Tuple, Int}, row_first::Bool, xs::Number...) = _typed_h
_hvncat(dimsshape::Union{Tuple, Int}, row_first::Bool, xs::AbstractArray...) = _typed_hvncat(promote_eltype(xs...), dimsshape, row_first, xs...)
_hvncat(dimsshape::Union{Tuple, Int}, row_first::Bool, xs::AbstractArray{T}...) where T = _typed_hvncat(T, dimsshape, row_first, xs...)


typed_hvncat(T::Type, dimsshape::Tuple, row_first::Bool, xs...) = _typed_hvncat(T, dimsshape, row_first, xs...)
typed_hvncat(T::Type, dim::Int, xs...) = _typed_hvncat(T, Val(dim), xs...)

Expand All @@ -2152,9 +2153,9 @@ _typed_hvncat(::Type, ::Val{0}, ::AbstractArray...) = _typed_hvncat_0d_only_one(
_typed_hvncat_0d_only_one() =
throw(ArgumentError("a 0-dimensional array may only contain exactly one element"))

_typed_hvncat(::Type{T}, ::Val{N}) where {T, N} = Array{T, N}(undef, ntuple(x -> 0, Val(N)))

function _typed_hvncat(::Type{T}, dims::Tuple{Vararg{Int, N}}, row_first::Bool, xs::Number...) where {T, N}
function _typed_hvncat(::Type{T}, dims::NTuple{N, Int}, row_first::Bool, xs::Number...) where {T, N}
all(>(0), dims) ||
throw(ArgumentError("`dims` argument must contain positive integers"))
A = Array{T, N}(undef, dims...)
lengtha = length(A) # Necessary to store result because throw blocks are being deoptimized right now, which leads to excessive allocations
lengthx = length(xs) # Cuts from 3 allocations to 1.
Expand Down Expand Up @@ -2191,9 +2192,28 @@ function hvncat_fill!(A::Array, row_first::Bool, xs::Tuple)
end

_typed_hvncat(T::Type, dim::Int, ::Bool, xs...) = _typed_hvncat(T, Val(dim), xs...) # catches from _hvncat type promoters

function _typed_hvncat(::Type{T}, ::Val{N}) where {T, N}
N < 0 &&
throw(ArgumentError("concatenation dimension must be nonnegative"))
return Array{T, N}(undef, ntuple(x -> 0, Val(N)))
end

function _typed_hvncat(T::Type, ::Val{N}, xs::Number...) where N
N < 0 &&
throw(ArgumentError("concatenation dimension must be nonnegative"))
A = cat_similar(xs[1], T, (ntuple(x -> 1, Val(N - 1))..., length(xs)))
hvncat_fill!(A, false, xs)
return A
end

function _typed_hvncat(::Type{T}, ::Val{N}, as::AbstractArray...) where {T, N}
# optimization for arrays that can be concatenated by copying them linearly into the destination
# conditions: the elements must all have 1- or 0-length dimensions above N
# conditions: the elements must all have 1-length dimensions above N
length(as) > 0 ||
throw(ArgumentError("must have at least one element"))
N < 0 &&
throw(ArgumentError("concatenation dimension must be nonnegative"))
for a as
ndims(a) <= N || all(x -> size(a, x) == 1, (N + 1):ndims(a)) ||
return _typed_hvncat(T, (ntuple(x -> 1, N - 1)..., length(as), 1), false, as...)
Expand All @@ -2203,10 +2223,13 @@ function _typed_hvncat(::Type{T}, ::Val{N}, as::AbstractArray...) where {T, N}
nd = max(N, ndims(as[1]))

Ndim = 0
for i 1:lastindex(as)
Ndim += cat_size(as[i], N)
for d 1:N - 1
cat_size(as[1], d) == cat_size(as[i], d) || throw(ArgumentError("mismatched size along axis $d in element $i"))
for i eachindex(as)
a = as[i]
Ndim += size(a, N)
nd = max(nd, ndims(a))
for d 1:N-1
size(a, d) == size(as[1], d) ||
throw(ArgumentError("all dimensions of element $i other than $N must be of length 1"))
end
end

Expand All @@ -2222,17 +2245,20 @@ function _typed_hvncat(::Type{T}, ::Val{N}, as::AbstractArray...) where {T, N}
end

function _typed_hvncat(::Type{T}, ::Val{N}, as...) where {T, N}
# optimization for scalars and 1-length arrays that can be concatenated by copying them linearly
# into the destination
length(as) > 0 ||
throw(ArgumentError("must have at least one element"))
N < 0 &&
throw(ArgumentError("concatenation dimension must be nonnegative"))
nd = N
Ndim = 0
for a as
if a isa AbstractArray
cat_size(a, N) == length(a) ||
throw(ArgumentError("all dimensions of elements other than $N must be of length 1"))
nd = max(nd, cat_ndims(a))
end
for i eachindex(as)
a = as[i]
Ndim += cat_size(a, N)
nd = max(nd, cat_ndims(a))
for d 1:N-1
cat_size(a, d) == 1 ||
throw(ArgumentError("all dimensions of element $i other than $N must be of length 1"))
end
end

A = Array{T, nd}(undef, ntuple(x -> 1, N - 1)..., Ndim, ntuple(x -> 1, nd - N)...)
Expand Down Expand Up @@ -2276,7 +2302,12 @@ function _typed_hvncat_1d(::Type{T}, ds::Int, ::Val{row_first}, as...) where {T,
end
end

function _typed_hvncat(::Type{T}, dims::Tuple{Vararg{Int, N}}, row_first::Bool, as...) where {T, N}
function _typed_hvncat(::Type{T}, dims::NTuple{N, Int}, row_first::Bool, as...) where {T, N}
length(as) > 0 ||
throw(ArgumentError("must have at least one element"))
all(>(0), dims) ||
throw(ArgumentError("`dims` argument must contain positive integers"))

d1 = row_first ? 2 : 1
d2 = row_first ? 1 : 2

Expand All @@ -2291,7 +2322,9 @@ function _typed_hvncat(::Type{T}, dims::Tuple{Vararg{Int, N}}, row_first::Bool,

currentdims = zeros(Int, nd)
blockcount = 0
elementcount = 0
for i eachindex(as)
elementcount += cat_length(as[i])
currentdims[d1] += cat_size(as[i], d1)
if currentdims[d1] == outdims[d1]
currentdims[d1] = 0
Expand Down Expand Up @@ -2321,14 +2354,9 @@ function _typed_hvncat(::Type{T}, dims::Tuple{Vararg{Int, N}}, row_first::Bool,
end
end

# calling sum() leads to 3 extra allocations
len = 0
for a as
len += cat_length(a)
end
outlen = prod(outdims)
outlen == 0 && throw(ArgumentError("too few elements in arguments, unable to infer dimensions"))
len == outlen || throw(ArgumentError("too many elements in arguments; expected $(outlen), got $(len)"))
elementcount == outlen ||
throw(ArgumentError("mismatched number of elements; expected $(outlen), got $(elementcount)"))

# copy into final array
A = cat_similar(as[1], T, outdims)
Expand All @@ -2347,22 +2375,32 @@ function _typed_hvncat(T::Type, shape::Tuple{Tuple}, row_first::Bool, xs...)
return _typed_hvncat_1d(T, shape[1][1], Val(row_first), xs...)
end

function _typed_hvncat(T::Type, shape::NTuple{N, Tuple}, row_first::Bool, as...) where {N}
function _typed_hvncat(::Type{T}, shape::NTuple{N, Tuple}, row_first::Bool, as...) where {T, N}
length(as) > 0 ||
throw(ArgumentError("must have at least one element"))
all(>(0), tuple((shape...)...)) ||
throw(ArgumentError("`shape` argument must consist of positive integers"))

d1 = row_first ? 2 : 1
d2 = row_first ? 1 : 2
shape = collect(shape) # saves allocations later
shapelength = shape[end][1]
shapev = collect(shape) # saves allocations later
all(!isempty, shapev) ||
throw(ArgumentError("each level of `shape` argument must have at least one value"))
length(shapev[end]) == 1 ||
throw(ArgumentError("last level of shape must contain only one integer"))
shapelength = shapev[end][1]
lengthas = length(as)
shapelength == lengthas || throw(ArgumentError("number of elements does not match shape; expected $(shapelength), got $lengthas)"))

# discover dimensions
nd = max(N, cat_ndims(as[1]))
outdims = zeros(Int, nd)
currentdims = zeros(Int, nd)
blockcounts = zeros(Int, nd)
shapepos = ones(Int, nd)

elementcount = 0
for i eachindex(as)
elementcount += cat_length(as[i])
wasstartblock = false
for d 1:N
ad = (d < 3 && row_first) ? (d == 1 ? 2 : 1) : d
Expand All @@ -2372,27 +2410,34 @@ function _typed_hvncat(T::Type, shape::NTuple{N, Tuple}, row_first::Bool, as...)
if d == 1 || i == 1 || wasstartblock
currentdims[d] += dsize
elseif dsize != cat_size(as[i - 1], ad)
throw(ArgumentError("""argument $i has a mismatched number of elements along axis $ad; \
expected $(cat_size(as[i - 1], ad)), got $dsize"""))
throw(ArgumentError("argument $i has a mismatched number of elements along axis $ad; \
expected $(cat_size(as[i - 1], ad)), got $dsize"))
end

wasstartblock = blockcounts[d] == 1 # remember for next dimension

isendblock = blockcounts[d] == shape[d][shapepos[d]]
isendblock = blockcounts[d] == shapev[d][shapepos[d]]
if isendblock
if outdims[d] == 0
outdims[d] = currentdims[d]
elseif outdims[d] != currentdims[d]
throw(ArgumentError("""argument $i has a mismatched number of elements along axis $ad; \
expected $(abs(outdims[d] - (currentdims[d] - dsize))), got $dsize"""))
throw(ArgumentError("argument $i has a mismatched number of elements along axis $ad; \
expected $(abs(outdims[d] - (currentdims[d] - dsize))), got $dsize"))
end
currentdims[d] = 0
blockcounts[d] = 0
shapepos[d] += 1
d > 1 && (blockcounts[d - 1] == 0 ||
throw(ArgumentError("shape in level $d is inconsistent; level counts must nest \
evenly into each other")))
end
end
end

outlen = prod(outdims)
elementcount == outlen ||
throw(ArgumentError("mismatched number of elements; expected $(outlen), got $(elementcount)"))

if row_first
outdims[1], outdims[2] = outdims[2], outdims[1]
end
Expand Down
63 changes: 63 additions & 0 deletions test/abstractarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1399,6 +1399,69 @@ using Base: typed_hvncat
@test [v v;;; fill(v, 1, 2)] == fill(v, 1, 2, 2)
end

# dims form
for v ((), (1,), ([1],), (1, [1]), ([1], 1), ([1], [1]))
# reject dimension < 0
@test_throws ArgumentError hvncat(-1, v...)

# reject shape tuple with no elements
@test_throws ArgumentError hvncat(((),), true, v...)
end

# reject dims or shape with negative or zero values
for v1 (-1, 0, 1)
for v2 (-1, 0, 1)
v1 == v2 == 1 && continue
for v3 ((), (1,), ([1],), (1, [1]), ([1], 1), ([1], [1]))
@test_throws ArgumentError hvncat((v1, v2), true, v3...)
@test_throws ArgumentError hvncat(((v1,), (v2,)), true, v3...)
end
end
end

for v ((1, [1]), ([1], 1), ([1], [1]))
# reject shape with more than one end value
@test_throws ArgumentError hvncat(((1, 1),), true, v...)
end

for v ((1, 2, 3), (1, 2, [3]), ([1], [2], [3]))
# reject shape with more values in later level
@test_throws ArgumentError hvncat(((2, 1), (1, 1, 1)), true, v...)
end

# reject shapes that don't nest evenly between levels (e.g. 1 + 2 does not fit into 2)
@test_throws ArgumentError hvncat(((1, 2, 1), (2, 2), (4,)), true, [1 2], [3], [4], [1 2; 3 4])

# zero-length arrays are handled appropriately
@test [zeros(Int, 1, 2, 0) ;;; 1 3] == [1 3;;;]
@test [[] ;;; [] ;;; []] == Array{Any}(undef, 0, 1, 3)
@test [[] ; 1 ;;; 2 ; []] == [1 ;;; 2]
@test [[] ; [] ;;; [] ; []] == Array{Any}(undef, 0, 1, 2)
@test [[] ; 1 ;;; 2] == [1 ;;; 2]
@test [[] ; [] ;;; [] ;;; []] == Array{Any}(undef, 0, 1, 3)
z = zeros(Int, 0, 0, 0)
[z z ; z ;;; z ;;; z] == Array{Int}(undef, 0, 0, 0)

for v1 (zeros(Int, 0, 0), zeros(Int, 0, 0, 0, 0), zeros(Int, 0, 0, 0, 0, 0, 0, 0))
for v2 (1, [1])
for v3 (2, [2])
@test_throws ArgumentError [v1 ;;; v2]
@test_throws ArgumentError [v1 ;;; v2 v3]
@test_throws ArgumentError [v1 v1 ;;; v2 v3]
end
end
end
v1 = zeros(Int, 0, 0, 0)
for v2 (1, [1])
for v3 (2, [2])
# current behavior, not potentially dangerous.
# should throw error like above loop
@test [v1 ;;; v2 v3] == [v2 v3;;;]
@test_throws ArgumentError [v1 ;;; v2]
@test_throws ArgumentError [v1 v1 ;;; v2 v3]
end
end

# 0-dimension behaviors
# exactly one argument, placed in an array
# if already an array, copy, with type conversion as necessary
Expand Down

0 comments on commit e6aca89

Please sign in to comment.