diff --git a/Project.toml b/Project.toml index ea76060..8accba0 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "MLUtils" uuid = "f1d291b0-491e-4a28-83b9-f70985020b54" authors = ["Carlo Lucibello and contributors"] -version = "0.2.9" +version = "0.2.10" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" diff --git a/src/batchview.jl b/src/batchview.jl index 4b5f252..d5a1467 100644 --- a/src/batchview.jl +++ b/src/batchview.jl @@ -100,7 +100,7 @@ function BatchView(data::T; batchsize::Int=1, partial::Bool=true, collate=Val(no throw(ArgumentError("`collate` must be one of `nothing`, `true` or `false`.")) end E = _batchviewelemtype(data, collate) - count = partial ? ceil(Int, n / batchsize) : floor(Int, n / batchsize) + count = partial ? cld(n, batchsize) : fld(n, batchsize) BatchView{E,T,typeof(collate)}(data, batchsize, count, partial) end diff --git a/src/utils.jl b/src/utils.jl index f64cbdd..b534292 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -122,9 +122,10 @@ unstack(xs; dims::Int) = [copy(selectdim(xs, dims, i)) for i in 1:size(xs, dims) """ chunk(x, n; [dims]) + chunk(x; [size, dims]) -Split `x` into `n` parts. The parts contain the same number of elements -except possibly for the last one that can be smaller. +Split `x` into `n` parts or alternatively, into equal chunks of size `size`. The parts contain +the same number of elements except possibly for the last one that can be smaller. If `x` is an array, `dims` can be used to specify along which dimension to split (defaults to the last dimension). @@ -138,6 +139,14 @@ julia> chunk(1:10, 3) 5:8 9:10 +julia> chunk(1:10; size = 2) +5-element Vector{UnitRange{Int64}}: + 1:2 + 3:4 + 5:6 + 7:8 + 9:10 + julia> x = reshape(collect(1:20), (5, 4)) 5×4 Matrix{Int64}: 1 6 11 16 @@ -156,30 +165,42 @@ julia> xs[1] 1 6 11 16 2 7 12 17 3 8 13 18 + +julia> xes = chunk(x; size = 2, dims = 2) +2-element Vector{SubArray{Int64, 2, Matrix{Int64}, Tuple{Base.Slice{Base.OneTo{Int64}}, UnitRange{Int64}}, true}}: + [1 6; 2 7; … ; 4 9; 5 10] + [11 16; 12 17; … ; 14 19; 15 20] + +julia> xes[2] +5×2 view(::Matrix{Int64}, :, 3:4) with eltype Int64: + 11 16 + 12 17 + 13 18 + 14 19 + 15 20 ``` """ -chunk(x, n::Int) = collect(Iterators.partition(x, ceil(Int, length(x) / n))) +chunk(x; size::Int) = collect(Iterators.partition(x, size)) +chunk(x, n::Int) = chunk(x; size = cld(length(x), n)) -function chunk(x::AbstractArray, n::Int; dims::Int=ndims(x)) - idxs = _partition_idxs(x, n, dims) +function chunk(x::AbstractArray; size::Int, dims::Int=ndims(x)) + idxs = _partition_idxs(x, size, dims) [selectdim(x, dims, i) for i in idxs] end +chunk(x::AbstractArray, n::Int; dims::Int=ndims(x)) = chunk(x; size = cld(size(x, dims), n), dims) -function _partition_idxs(x, n, dims) - bs = ceil(Int, size(x, dims) / n) - Iterators.partition(axes(x, dims), bs) -end - -function rrule(::typeof(chunk), x::AbstractArray, n::Int; dims::Int=ndims(x)) +function rrule(::typeof(chunk), x::AbstractArray; size::Int, dims::Int=ndims(x)) # this is the implementation of chunk - idxs = _partition_idxs(x, n, dims) + idxs = _partition_idxs(x, size, dims) y = [selectdim(x, dims, i) for i in idxs] valdims = Val(dims) - chunk_pullback(dy) = (NoTangent(), ∇chunk(unthunk(dy), x, idxs, valdims), NoTangent()) - + chunk_pullback(dy) = (NoTangent(), ∇chunk(unthunk(dy), x, idxs, valdims)) + return y, chunk_pullback end +_partition_idxs(x, size, dims) = Iterators.partition(axes(x, dims), size) + # Similar to ∇eachslice https://github.com/JuliaDiff/ChainRules.jl/blob/8108a77a96af5d4b0c460aac393e44f8943f3c5e/src/rulesets/Base/indexing.jl#L77 function ∇chunk(dys, x::AbstractArray, idxs, vd::Val{dim}) where {dim} i1 = findfirst(dy -> !(dy isa AbstractZero), dys) diff --git a/test/utils.jl b/test/utils.jl index cbc3888..85ec41c 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -101,9 +101,16 @@ end x = reshape(collect(1:20), (5, 4)) cs = chunk(x, 2) @test length(cs) == 2 - cs[1] == [1 6; 2 7; 3 8; 4 9; 5 10] - cs[2] == [11 16; 12 17; 13 18; 14 19; 15 20] - + @test cs[1] == [1 6; 2 7; 3 8; 4 9; 5 10] + @test cs[2] == [11 16; 12 17; 13 18; 14 19; 15 20] + + x = permutedims(reshape(collect(1:10), (2, 5))) + cs = chunk(x; size = 2, dims = 1) + @test length(cs) == 3 + @test cs[1] == [1 2; 3 4] + @test cs[2] == [5 6; 7 8] + @test cs[3] == [9 10] + # test gradient test_zygote(chunk, rand(10), 3, check_inferred=false) @@ -111,10 +118,10 @@ end n = 2 dims = 2 x = rand(4, 5) - y = chunk(x, 2) - dy = randn!.(collect.(y)) - idxs = MLUtils._partition_idxs(x, n, dims) - test_zygote(MLUtils.∇chunk, dy, x, idxs, Val(dims), check_inferred=false) + l = chunk(x, 2) + dl = randn!.(collect.(l)) + idxs = MLUtils._partition_idxs(x, cld(size(x, dims), n), dims) + test_zygote(MLUtils.∇chunk, dl, x, idxs, Val(dims), check_inferred=false) end @testset "group_counts" begin