diff --git a/Project.toml b/Project.toml index c763a8824..1b597b7c9 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "KernelFunctions" uuid = "ec8451be-7e33-11e9-00cf-bbf324bd1392" -version = "0.10.55" +version = "0.10.56" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" diff --git a/src/utils.jl b/src/utils.jl index 0942fa5f7..11372cfe6 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -94,6 +94,8 @@ Base.setindex!(D::ColVecs, v::AbstractVector, i) = setindex!(D.X, v, :, i) Base.vcat(a::ColVecs, b::ColVecs) = ColVecs(hcat(a.X, b.X)) Base.zero(x::ColVecs) = ColVecs(zero(x.X)) +Base.reduce(::typeof(hcat), a::ColVecs) = copy(a.X) +Base.reduce(::typeof(vcat), a::ColVecs) = copy(vec(a.X)) dim(x::ColVecs) = size(x.X, 1) @@ -167,6 +169,8 @@ Base.setindex!(D::RowVecs, v::AbstractVector, i) = setindex!(D.X, v, i, :) Base.vcat(a::RowVecs, b::RowVecs) = RowVecs(vcat(a.X, b.X)) Base.zero(x::RowVecs) = RowVecs(zero(x.X)) +Base.reduce(::typeof(hcat), a::RowVecs) = permutedims(a.X) +Base.reduce(::typeof(vcat), a::RowVecs) = vec(permutedims(a.X)) dim(x::RowVecs) = size(x.X, 2) diff --git a/test/utils.jl b/test/utils.jl index 42784548a..6b7869b26 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -40,6 +40,9 @@ Y = randn(rng, D, N + 1) DY = ColVecs(Y) + + @test reduce(vcat, DY) == vcat(DY...) + @test reduce(hcat, DY) == hcat(DY...) @test KernelFunctions.pairwise(SqEuclidean(), DX) ≈ pairwise(SqEuclidean(), X; dims=2) @test KernelFunctions.pairwise(SqEuclidean(), DX, DY) ≈ @@ -98,6 +101,9 @@ Y = randn(rng, D + 1, N) DY = RowVecs(Y) + + @test reduce(vcat, DY) == vcat(DY...) + @test reduce(hcat, DY) == hcat(DY...) @test KernelFunctions.pairwise(SqEuclidean(), DX) ≈ pairwise(SqEuclidean(), X; dims=1) @test KernelFunctions.pairwise(SqEuclidean(), DX, DY) ≈