Skip to content

Commit

Permalink
improvements to chunk (#133)
Browse files Browse the repository at this point in the history
* make chunk accept a collection of sizes

* docstring

* use view

* tests

* cleanup

* using CUDA

* using CUDA
  • Loading branch information
CarloLucibello authored Dec 28, 2022
1 parent 08ad0b7 commit 6a86d23
Show file tree
Hide file tree
Showing 4 changed files with 65 additions and 12 deletions.
3 changes: 2 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,11 @@ julia = "1.6"

[extras]
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["ChainRulesTestUtils", "CUDA", "DataFrames", "SparseArrays", "Test", "Zygote"]
54 changes: 43 additions & 11 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,11 @@ unstack(xs; dims::Int) = [copy(selectdim(xs, dims, i)) for i in 1:size(xs, dims)
chunk(x, n; [dims])
chunk(x; [size, dims])
Split `x` into `n` parts or alternatively, if `size` is an integer, into equal chunks of size `size`.
The parts contain the same number of elements except possibly for the last one that can be smaller.
In case `size` is a collection of integers instead, the elements of `x` are split into chunks of
the given sizes.
If `x` is an array, `dims` can be used to specify along which dimension to
split (defaults to the last dimension).
Expand Down Expand Up @@ -135,31 +138,60 @@ julia> xes[2]
13 18
14 19
15 20
julia> chunk(1:6; size = [2, 4])
2-element Vector{UnitRange{Int64}}:
1:2
3:6
```
"""
# Generic fallback: lazily partition `x` into pieces of `size` elements each,
# then materialize the iterator into a vector of chunks. The final chunk is
# shorter when `length(x)` is not a multiple of `size`.
function chunk(x; size::Int)
    parts = Iterators.partition(x, size)
    return collect(parts)
end

# Split `x` into `n` near-equal parts: every chunk holds `cld(length(x), n)`
# elements, so only the last chunk can come up short.
function chunk(x, n::Int)
    chunklen = cld(length(x), n)
    return chunk(x; size = chunklen)
end

# Split array `x` along `dims` (default: last dimension) into `n` chunks,
# each of length `cld(size(x, dims), n)` except possibly the last one.
# (A dangling superseded signature line from a merge was removed here.)
chunk(x::AbstractArray, n::Int; dims::Int=ndims(x)) = chunk(x; size = cld(size(x, dims), n), dims)

# Array method backing `chunk`: split `x` along dimension `dims` into parts
# described by `size` — an `Int` for equal-length chunks (last possibly
# shorter), or a collection of integers giving each chunk's length.
# Returns a vector of non-copying views into `x`.
function chunk(x::AbstractArray; size, dims::Int=ndims(x))
    idxs = _partition_idxs(x, size, dims)
    # `_selectdim` builds a `view` for range indices, which also works around
    # GPU-array issues with `selectdim` (see MLUtils.jl issue #103).
    # NOTE: a stray duplicate comprehension using plain `selectdim` (merge
    # residue) was removed — it allocated a vector of slices and discarded it.
    return [_selectdim(x, dims, i) for i in idxs]
end

function rrule(::typeof(chunk), x::AbstractArray; size::Int, dims::Int=ndims(x))
# this is the implementation of chunk
# Work around https://github.com/JuliaML/MLUtils.jl/issues/103:
# for `UnitRange` indices, construct the `view` directly instead of calling
# `selectdim`, lifting the dimension into the type domain via `Val`.
_selectdim(x::AbstractArray, dims::Int, i) = selectdim(x, dims, i)
_selectdim(x::AbstractArray, dims::Int, i::UnitRange) = _selectdim(x, Val(dims), i)

function _selectdim(x::AbstractArray{T,N}, ::Val{d}, r::UnitRange) where {T,N,d}
    # Colons for every axis before and after `d`; `r` slices axis `d`.
    leading = ntuple(_ -> Colon(), d - 1)
    trailing = ntuple(_ -> Colon(), N - d)
    return view(x, leading..., r, trailing...)
end

# Custom reverse rule for `chunk` so AD allocates a single `dx` array instead
# of differentiating through each slice separately.
function rrule(::typeof(chunk), x::AbstractArray; size, dims::Int=ndims(x))
    # Mirror the implementation of `chunk` exactly, so the primal output and
    # the pullback agree on the index partition.
    # NOTE: a duplicate dead assignment `y = [selectdim(...)]` (merge residue)
    # was removed — it allocated slices that were immediately overwritten.
    idxs = _partition_idxs(x, size, dims)
    y = [_selectdim(x, dims, i) for i in idxs]
    valdims = Val(dims)
    # TODO avoid capturing x in the pullback: only its eltype/axes are needed
    # to allocate dx, but `∇chunk` currently takes the whole array.
    chunk_pullback(dy) = (NoTangent(), ∇chunk(unthunk(dy), x, idxs, valdims))

    return y, chunk_pullback
end

# Compute the index ranges that split `axes(x, dims)` into chunks.
# The second argument is either an `Int` (equal-length chunks, the last one
# possibly shorter) or a collection of integers (one explicit length per chunk).
# (A superseded duplicate generic method left over from a merge was removed.)

# Equal-length case: a lazy iterator of ranges.
_partition_idxs(x, sz::Int, dims::Int) = Iterators.partition(axes(x, dims), sz)

# Normalize any collection of sizes (tuple, vector, generator, ...) to a
# `Vector{Int}`. `collect(Int, sz)` — rather than plain `collect` — forces an
# integer eltype so dispatch reaches the method below; with plain `collect`,
# a non-integer collection (e.g. `Vector{Float64}`) would re-enter this
# fallback forever and crash with a StackOverflowError.
_partition_idxs(x, sz, dims::Int) = _partition_idxs(x, collect(Int, sz), dims)

function _partition_idxs(x, sz::AbstractVector{<:Integer}, dims::Int)
    n = length(axes(x, dims))
    cumsz = cumsum(sz)
    # Guard the empty-sizes case so it reports a clear error (or succeeds when
    # the dimension is empty) instead of a BoundsError on `cumsz[end]`.
    total = isempty(cumsz) ? 0 : cumsz[end]
    if total != n
        throw(ArgumentError("The sum of the sizes must be equal to $n, the length of the dimension."))
    end
    return [(i == 1 ? 1 : cumsz[i-1] + 1):cumsz[i] for i in 1:length(cumsz)]
end

@non_differentiable _partition_idxs(::Any...)

# Similar to ∇eachslice https://github.com/JuliaDiff/ChainRules.jl/blob/8108a77a96af5d4b0c460aac393e44f8943f3c5e/src/rulesets/Base/indexing.jl#L77
function ∇chunk(dys, x::AbstractArray, idxs, vd::Val{dim}) where {dim}
function ∇chunk(dys, x, idxs, vd::Val{dim}) where {dim}
i1 = findfirst(dy -> !(dy isa AbstractZero), dys)
if i1 === nothing # all slices are Zero!
return _zero_fill!(similar(x, float(eltype(x))))
Expand All @@ -168,7 +200,7 @@ function ∇chunk(dys, x::AbstractArray, idxs, vd::Val{dim}) where {dim}
# The whole point of this gradient is that we can allocate one `dx` array:
dx = similar(x, T)
for (k, i) in enumerate(idxs)
slice = selectdim(dx, dim, i)
slice = _selectdim(dx, dim, i)
if dys[k] isa AbstractZero
_zero_fill!(slice) # Avoids this: copyto!([1,2,3], ZeroTangent()) == [0,2,3]
else
Expand Down
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ using ChainRulesTestUtils: test_rrule
using Zygote: ZygoteRuleConfig
using ChainRulesCore: rrule_via_ad
using DataFrames
using CUDA

# Show `x` on `io` with the `:compact` IOContext flag set (e.g. floats are
# printed with fewer significant digits).
function showcompact(io, x)
    compact_io = IOContext(io, :compact => true)
    return show(compact_io, x)
end

Expand Down
19 changes: 19 additions & 0 deletions test/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,25 @@ end
dl = randn!.(collect.(l))
idxs = MLUtils._partition_idxs(x, cld(size(x, dims), n), dims)
test_zygote(MLUtils.∇chunk, dl, x, idxs, Val(dims), check_inferred=false)

@testset "size collection" begin
    # 5×2 matrix filled column-major; chunking along dims=1 with sizes (1, 4)
    # should produce one 1-row slice and one 4-row slice.
    a = reshape(collect(1:10), (5, 2))
    y = chunk(a; dims = 1, size = (1, 4))
    @test length(y) == 2
    @test y[1] == [1 6]
    @test y[2] == [2 7; 3 8; 4 9; 5 10]

    # Check the rrule/pullback path against Zygote for the sizes-collection case.
    test_zygote(x -> chunk(x; dims = 1, size = (1, 4)), a)
end

# Only runs when a CUDA-capable GPU is available.
if CUDA.functional()
    # https://github.com/JuliaML/MLUtils.jl/issues/103
    # chunk on a CuArray must return CuArray-backed slices (not CPU copies).
    x = rand(2, 10) |> cu
    cs = chunk(x, 2)
    @test length(cs) == 2
    @test cs[1] isa CuArray
    @test cs[1] == x[:, 1:5]
end
end

@testset "group_counts" begin
Expand Down

0 comments on commit 6a86d23

Please sign in to comment.