Skip to content

Commit

Permalink
chunk by partition indexes (#134)
Browse files Browse the repository at this point in the history
* chunk by partitions

* references
  • Loading branch information
CarloLucibello committed Dec 29, 2022
1 parent 6a86d23 commit ff2fcc1
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 10 deletions.
47 changes: 44 additions & 3 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,8 @@ Base.show_function(io::IO, u::Base.Fix2{typeof(_unsqueeze)}, ::Bool) = print(io,
Unroll the given `xs` into an array of arrays along the given dimension `dims`.
See also [`stack`](@ref) and [`unbatch`](@ref).
See also [`stack`](@ref), [`unbatch`](@ref),
and [`chunk`](@ref).
# Examples
Expand Down Expand Up @@ -156,6 +157,46 @@ function chunk(x::AbstractArray; size, dims::Int=ndims(x))
return [_selectdim(x, dims, i) for i in idxs]
end


"""
    chunk(x, partition_idxs; [npartitions, dims])

Partition the array `x` along the dimension `dims` according to the indexes
in `partition_idxs`.

`partition_idxs` must be sorted and contain only positive integers
between 1 and the number of partitions.

If the number of partitions `npartitions` is not provided,
it is inferred from `partition_idxs`.
If `dims` is not provided, it defaults to the last dimension.

See also [`unbatch`](@ref).

# Examples

```jldoctest
julia> x = reshape([1:10;], 2, 5)
2×5 Matrix{Int64}:
 1  3  5  7   9
 2  4  6  8  10

julia> chunk(x, [1, 2, 2, 3, 3])
3-element Vector{SubArray{Int64, 2, Matrix{Int64}, Tuple{Base.Slice{Base.OneTo{Int64}}, UnitRange{Int64}}, true}}:
 [1; 2;;]
 [3 5; 4 6]
 [7 9; 8 10]
```
"""
function chunk(x::AbstractArray{T,N}, partition_idxs::AbstractVector;
               npartitions=nothing, dims=ndims(x)) where {T, N}
    # The partition labels must already be in non-decreasing order.
    @assert issorted(partition_idxs) "partition_idxs must be sorted"

    # Use the caller-supplied partition count when given; otherwise take the
    # largest label found in `partition_idxs`.
    if npartitions === nothing
        nparts = maximum(partition_idxs)
    else
        nparts = npartitions
    end

    # Scatter-add one unit per entry into the slot named by its label, giving
    # a length-`nparts` vector (presumably the number of entries per partition;
    # `ones_like` is defined elsewhere and assumed to return ones shaped like
    # its argument).
    unit_weights = ones_like(partition_idxs)
    part_sizes = NNlib.scatter(+, unit_weights, partition_idxs, dstsize=(nparts,))

    # Hand over to the size-based `chunk` method to do the actual slicing.
    return chunk(x; size=part_sizes, dims=dims)
end

# Internal replacement for `selectdim`, introduced as a workaround for
# https://github.com/JuliaML/MLUtils.jl/issues/103

# Generic case: defer to Base's `selectdim` unchanged.
function _selectdim(x::AbstractArray, dims::Int, i)
    return selectdim(x, dims, i)
end

# `UnitRange` selection: lift `dims` into a `Val` and forward to the
# `Val`-dispatched `_selectdim` method (defined elsewhere in this file).
function _selectdim(x::AbstractArray, dims::Int, i::UnitRange)
    return _selectdim(x, Val(dims), i)
end
Expand Down Expand Up @@ -349,13 +390,13 @@ end
Reverse of the [`batch`](@ref) operation,
unstacking the last dimension of the array `x`.
See also [`unstack`](@ref).
See also [`unstack`](@ref) and [`chunk`](@ref).
# Examples
```jldoctest
julia> unbatch([1 3 5 7;
2 4 6 8])
2 4 6 8])
4-element Vector{Vector{Int64}}:
[1, 2]
[3, 4]
Expand Down
36 changes: 29 additions & 7 deletions test/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,16 @@ end
idxs = MLUtils._partition_idxs(x, cld(size(x, dims), n), dims)
test_zygote(MLUtils.∇chunk, dl, x, idxs, Val(dims), check_inferred=false)


# GPU regression check: the body only runs when `CUDA.functional()` is true.
if CUDA.functional()
# https://github.com/JuliaML/MLUtils.jl/issues/103
# 2×10 random matrix passed through `cu` (presumably moving it to the GPU — confirm).
x = rand(2, 10) |> cu
# Split into 2 chunks; the checks below expect the split to be along the columns.
cs = chunk(x, 2)
@test length(cs) == 2
# The chunk must still be a `CuArray` ...
@test cs[1] isa CuArray
# ... and must hold the first five columns of `x`.
@test cs[1] == x[:, 1:5]
end

@testset "size collection" begin
a = reshape(collect(1:10), (5, 2))
y = chunk(a; dims = 1, size = (1, 4))
Expand All @@ -144,13 +154,25 @@ end
test_zygote(x -> chunk(x; dims = 1, size = (1, 4)), a)
end

if CUDA.functional()
# https://github.com/JuliaML/MLUtils.jl/issues/103
x = rand(2, 10) |> cu
cs = chunk(x, 2)
@test length(cs) == 2
@test cs[1] isa CuArray
@test cs[1] == x[:, 1:5]
@testset "chunk by partition_idxs" begin
    # 3×5 matrix filled column by column with 1:15.
    data = reshape(collect(1:15), (3, 5))
    labels = [1,1,3,3,4]

    # Partition count inferred from the labels: label 2 never occurs, so the
    # second chunk is expected to be empty.
    inferred = chunk(data, labels)
    @test length(inferred) == 4
    @test inferred[1] == [1 4; 2 5; 3 6]
    @test size(inferred[2]) == (3, 0)
    @test inferred[3] == [7 10; 8 11; 9 12]
    @test inferred[4] == reshape([13, 14, 15], 3, 1)

    # Explicit partition count larger than the maximum label: a trailing
    # empty chunk is expected.
    padded = chunk(data, labels; npartitions=5)
    @test length(padded) == 5
    @test size(padded[5]) == (3, 0)

    # Partitioning along the first dimension (rows) instead of the last.
    by_rows = chunk(data, [1,1,2]; dims=1)
    @test length(by_rows) == 2
    @test by_rows[1] == [1 4 7 10 13; 2 5 8 11 14]
    @test by_rows[2] == [3 6 9 12 15]
end
end

Expand Down

2 comments on commit ff2fcc1

@CarloLucibello
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/74759

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.4.0 -m "<description of version>" ff2fcc1ba9e5690c0e393fe7e5003fcfabcd2d19
git push origin v0.4.0

Please sign in to comment.