diff --git a/Project.toml b/Project.toml index 9666236..82ec98f 100644 --- a/Project.toml +++ b/Project.toml @@ -9,6 +9,7 @@ DataAPI = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a" DelimitedFiles = "8bb1440f-4735-579b-a4ab-409b98df4dab" FLoops = "cc61a311-1640-44b5-9fba-1b764f453329" FoldsThreads = "9c68100b-dfe1-47cf-94c8-95104e173443" +NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" ShowCases = "605ecd9f-84a6-4c9e-81e2-4798472b76a3" SimpleTraits = "699a6c99-e7fa-54fc-8d76-47d257e15c1d" @@ -23,8 +24,9 @@ DataAPI = "1.0" DelimitedFiles = "1.0" FLoops = "0.2" FoldsThreads = "0.1" -SimpleTraits = "0.9" +NNlib = "0.8" ShowCases = "0.1" +SimpleTraits = "0.9" StatsBase = "0.33" Tables = "1.10" Transducers = "0.4" diff --git a/docs/src/api.md b/docs/src/api.md index b20d9bb..36d6a9b 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -36,9 +36,8 @@ obsview ObsView ones_like oversample -MLUtils.rpad randobs -rpad(::AbstractVector, ::Integer, ::Any) +rpad_constant shuffleobs splitobs stack diff --git a/src/MLUtils.jl b/src/MLUtils.jl index 40c6aa5..c4d0dec 100644 --- a/src/MLUtils.jl +++ b/src/MLUtils.jl @@ -17,6 +17,7 @@ using ChainRulesCore: @non_differentiable, unthunk, AbstractZero, NoTangent, ZeroTangent, ProjectTo using SimpleTraits +import NNlib @traitdef IsTable{X} @traitimpl IsTable{X} <- Tables.istable(X) @@ -73,12 +74,12 @@ export batch, ones_like, rand_like, randn_like, + rpad_constant, stack, unbatch, unsqueeze, unstack, zeros_like - # rpad include("Datasets/Datasets.jl") using .Datasets diff --git a/src/deprecations.jl b/src/deprecations.jl index f9456e9..7545002 100644 --- a/src/deprecations.jl +++ b/src/deprecations.jl @@ -1,4 +1,4 @@ -# Deprecations v0.1 +# Deprecated in v0.2 @deprecate stack(x, dims) stack(x; dims=dims) @deprecate unstack(x, dims) unstack(x; dims=dims) @deprecate unsqueeze(x::AbstractArray, dims::Int) unsqueeze(x; dims=dims) @@ -7,3 +7,6 @@ @deprecate frequencies(x) group_counts(x) @deprecate eachbatch(data, batchsize; kws...) eachobs(data; batchsize, kws...) @deprecate eachbatch(data; size=1, kws...) eachobs(data; batchsize=size, kws...) + +# Deprecated in v0.3 +@deprecate rpad(v::AbstractVector, n::Integer, p) rpad_constant(v, n, p) diff --git a/src/utils.jl b/src/utils.jl index b534292..fd254b7 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -388,52 +388,78 @@ unbatch(x::AbstractArray) = [getobs(x, i) for i in 1:numobs(x)] unbatch(x::AbstractVector) = x """ - rpad(v::AbstractVector, n::Integer, p) + batchseq(seqs, val = 0) -Return the given sequence padded with `p` up to a maximum length of `n`. +Take a list of `N` sequences, and turn them into a single sequence where each +item is a batch of `N`. Short sequences will be padded by `val`. # Examples ```jldoctest -julia> rpad([1, 2], 4, 0) +julia> batchseq([[1, 2, 3], [4, 5]], 0) +3-element Vector{Vector{Int64}}: + [1, 4] + [2, 5] + [3, 0] +``` +""" +function batchseq(xs, val = 0, n = nothing) + n = n === nothing ? maximum(x -> size(x, ndims(x)), xs) : n + xs_ = [rpad_constant(x, n, val; dims=ndims(x)) for x in xs] + [batch([obsview(xs_[j], i) for j = 1:length(xs_)]) for i = 1:n] +end + +""" + rpad_constant(v::AbstractArray, n::Union{Integer, Tuple}, val = 0; dims=:) + +Return the given sequence padded with `val` along the dimensions `dims` +up to a maximum length in each direction specified by `n`. + +# Examples +```jldoctest +julia> rpad_constant([1, 2], 4, -1) # passing with -1 up to size 4 4-element Vector{Int64}: 1 2 - 0 - 0 + -1 + -1 -julia> rpad([1, 2, 3], 2, 0) +julia> rpad_constant([1, 2, 3], 2) # no padding if length is already greater than n 3-element Vector{Int64}: 1 2 3 -``` -""" -Base.rpad(v::AbstractVector, n::Integer, p) = [v; fill(p, max(n - length(v), 0))] -# TODO Piracy - -""" - batchseq(seqs, pad) - -Take a list of `N` sequences, and turn them into a single sequence where each -item is a batch of `N`. Short sequences will be padded by `pad`. - -# Examples +julia> rpad_constant([1 2; 3 4], 4; dims=1) # padding along the first dimension +4×2 Matrix{Int64}: + 1 2 + 3 4 + 0 0 + 0 0 -```jldoctest -julia> batchseq([[1, 2, 3], [4, 5]], 0) -3-element Vector{Vector{Int64}}: - [1, 4] - [2, 5] - [3, 0] +julia> rpad_constant([1 2; 3 4], 4) # padding along all dimensions by default +4×2 Matrix{Int64}: + 1 2 + 3 4 + 0 0 + 0 0 ``` """ -function batchseq(xs, pad = nothing, n = maximum(length(x) for x in xs)) - xs_ = [rpad(x, n, pad) for x in xs] - [batch([xs_[j][i] for j = 1:length(xs_)]) for i = 1:n] +function rpad_constant(x::AbstractArray, n::Union{Integer, Tuple}, val=0; dims=:) + ns = _rpad_pads(x, n, dims) + return NNlib.pad_constant(x, ns, val; dims) +end + +function _rpad_pads(x, n, dims) + _dims = dims === Colon() ? (1:ndims(x)) : dims + _n = n isa Integer ? ntuple(i -> n, length(_dims)) : n + @assert length(_dims) == length(_n) + ns = ntuple(i -> isodd(i) ? 0 : max(_n[i÷2] - size(x, _dims[i÷2]), 0), 2*length(_n)) + return ns end +@non_differentiable _rpad_pads(::Any...) + """ flatten(x::AbstractArray) diff --git a/test/utils.jl b/test/utils.jl index 85ec41c..d5b80bb 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -129,11 +129,6 @@ end @test d == Dict('a' => 1, 'b' => 2) end -@testset "rpad" begin - @test rpad([1, 2], 4, 0) == [1, 2, 0, 0] - @test rpad([1, 2, 3], 2, 0) == [1,2,3] -end - @testset "batchseq" begin bs = batchseq([[1, 2, 3], [4, 5]], 0) @test bs[1] == [1, 4] @@ -144,6 +139,11 @@ end @test bs[1] == [1, 4] @test bs[2] == [2, 5] @test bs[3] == [3, -1] + + batchseq([ones(2,4), zeros(2, 3), ones(2,2)]) ==[[1.0 0.0 1.0; 1.0 0.0 1.0] + [1.0 0.0 1.0; 1.0 0.0 1.0] + [1.0 0.0 0.0; 1.0 0.0 0.0] + [1.0 0.0 0.0; 1.0 0.0 0.0]] end @testset "ones_like" begin @@ -188,3 +188,11 @@ end test_zygote(fill_like, rand(5), rand(), (2, 4, 2)) end + +@testset "rpad_constant" begin + @test rpad_constant([1, 2], 4, -1) == [1, 2, -1, -1] + @test rpad_constant([1, 2, 3], 2) == [1, 2, 3] + @test rpad_constant([1 2; 3 4], 4; dims=1) == [1 2; 3 4; 0 0; 0 0] + @test rpad_constant([1 2; 3 4], 4) == [1 2 0 0; 3 4 0 0; 0 0 0 0; 0 0 0 0] + @test rpad_constant([1 2; 3 4], (3, 4)) == [1 2 0 0; 3 4 0 0; 0 0 0 0] +end \ No newline at end of file