From ff2fcc1ba9e5690c0e393fe7e5003fcfabcd2d19 Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Thu, 29 Dec 2022 09:15:37 +0100 Subject: [PATCH] chunk by partition indexes (#134) * chunk by partitions * references --- src/utils.jl | 47 ++++++++++++++++++++++++++++++++++++++++++++--- test/utils.jl | 36 +++++++++++++++++++++++++++++------- 2 files changed, 73 insertions(+), 10 deletions(-) diff --git a/src/utils.jl b/src/utils.jl index b4553ba..2198f92 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -62,7 +62,8 @@ Base.show_function(io::IO, u::Base.Fix2{typeof(_unsqueeze)}, ::Bool) = print(io, Unroll the given `xs` into an array of arrays along the given dimension `dims`. -See also [`stack`](@ref) and [`unbatch`](@ref). +See also [`stack`](@ref), [`unbatch`](@ref), +and [`chunk`](@ref). # Examples @@ -156,6 +157,46 @@ function chunk(x::AbstractArray; size, dims::Int=ndims(x)) return [_selectdim(x, dims, i) for i in idxs] end + +""" + chunk(x, partition_idxs; [npartitions, dims]) + +Partition the array `x` along the dimension `dims` according to the indexes +in `partition_idxs`. + +`partition_idxs` must be sorted and contain only positive integers +between 1 and the number of partitions. + +If the number of partitions `npartitions` is not provided, +it is inferred from `partition_idxs`. + +If `dims` is not provided, it defaults to the last dimension. + +See also [`unbatch`](@ref). + +# Examples + +```jldoctest +julia> x = reshape([1:10;], 2, 5) +2×5 Matrix{Int64}: + 1 3 5 7 9 + 2 4 6 8 10 + +julia> chunk(x, [1, 2, 2, 3, 3]) +3-element Vector{SubArray{Int64, 2, Matrix{Int64}, Tuple{Base.Slice{Base.OneTo{Int64}}, UnitRange{Int64}}, true}}: + [1; 2;;] + [3 5; 4 6] + [7 9; 8 10] +``` +""" +function chunk(x::AbstractArray{T,N}, partition_idxs::AbstractVector; + npartitions=nothing, dims=ndims(x)) where {T, N} + @assert issorted(partition_idxs) "partition_idxs must be sorted" + m = npartitions === nothing ? 
maximum(partition_idxs) : npartitions + degrees = NNlib.scatter(+, ones_like(partition_idxs), partition_idxs, dstsize=(m,)) + return chunk(x; size=degrees, dims) +end + # work around https://github.com/JuliaML/MLUtils.jl/issues/103 _selectdim(x::AbstractArray, dims::Int, i) = selectdim(x, dims, i) _selectdim(x::AbstractArray, dims::Int, i::UnitRange) = _selectdim(x, Val(dims), i) @@ -349,13 +390,13 @@ end Reverse of the [`batch`](@ref) operation, unstacking the last dimension of the array `x`. -See also [`unstack`](@ref). +See also [`unstack`](@ref) and [`chunk`](@ref). # Examples ```jldoctest julia> unbatch([1 3 5 7; - 2 4 6 8]) + 2 4 6 8]) 4-element Vector{Vector{Int64}}: [1, 2] [3, 4] diff --git a/test/utils.jl b/test/utils.jl index cb901e4..10edfcd 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -134,6 +134,16 @@ end idxs = MLUtils._partition_idxs(x, cld(size(x, dims), n), dims) test_zygote(MLUtils.∇chunk, dl, x, idxs, Val(dims), check_inferred=false) + + if CUDA.functional() + # https://github.com/JuliaML/MLUtils.jl/issues/103 + x = rand(2, 10) |> cu + cs = chunk(x, 2) + @test length(cs) == 2 + @test cs[1] isa CuArray + @test cs[1] == x[:, 1:5] + end + @testset "size collection" begin a = reshape(collect(1:10), (5, 2)) y = chunk(a; dims = 1, size = (1, 4)) @@ -144,13 +154,25 @@ end test_zygote(x -> chunk(x; dims = 1, size = (1, 4)), a) end - if CUDA.functional() - # https://github.com/JuliaML/MLUtils.jl/issues/103 - x = rand(2, 10) |> cu - cs = chunk(x, 2) - @test length(cs) == 2 - @test cs[1] isa CuArray - @test cs[1] == x[:, 1:5] + @testset "chunk by partition_idxs" begin + x = reshape(collect(1:15), (3, 5)) + partition_idxs = [1,1,3,3,4] + + y = chunk(x, partition_idxs) + @test length(y) == 4 + @test y[1] == [1 4; 2 5; 3 6] + @test size(y[2]) == (3, 0) + @test y[3] == [7 10; 8 11; 9 12] + @test y[4] == reshape([13, 14, 15], 3, 1) + + y = chunk(x, partition_idxs; npartitions=5) + @test length(y) == 5 + @test size(y[5]) == (3, 0) + + y = chunk(x, [1,1,2]; 
dims=1) + @test length(y) == 2 + @test y[1] == [1 4 7 10 13; 2 5 8 11 14] + @test y[2] == [3 6 9 12 15] end end