Skip to content

Commit

Permalink
generalize batchseq to sequence of generic arrays (#126)
Browse files Browse the repository at this point in the history
generalize batchseq to sequence of generic arrays
  • Loading branch information
CarloLucibello committed Oct 28, 2022
2 parents a85c098 + bd3bef8 commit 855f95b
Show file tree
Hide file tree
Showing 6 changed files with 76 additions and 37 deletions.
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ DataAPI = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a"
DelimitedFiles = "8bb1440f-4735-579b-a4ab-409b98df4dab"
FLoops = "cc61a311-1640-44b5-9fba-1b764f453329"
FoldsThreads = "9c68100b-dfe1-47cf-94c8-95104e173443"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
ShowCases = "605ecd9f-84a6-4c9e-81e2-4798472b76a3"
SimpleTraits = "699a6c99-e7fa-54fc-8d76-47d257e15c1d"
Expand All @@ -23,8 +24,9 @@ DataAPI = "1.0"
DelimitedFiles = "1.0"
FLoops = "0.2"
FoldsThreads = "0.1"
SimpleTraits = "0.9"
NNlib = "0.8"
ShowCases = "0.1"
SimpleTraits = "0.9"
StatsBase = "0.33"
Tables = "1.10"
Transducers = "0.4"
Expand Down
3 changes: 1 addition & 2 deletions docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,8 @@ obsview
ObsView
ones_like
oversample
MLUtils.rpad
randobs
rpad(::AbstractVector, ::Integer, ::Any)
rpad_constant
shuffleobs
splitobs
stack
Expand Down
3 changes: 2 additions & 1 deletion src/MLUtils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ using ChainRulesCore: @non_differentiable, unthunk, AbstractZero,
NoTangent, ZeroTangent, ProjectTo

using SimpleTraits
import NNlib

@traitdef IsTable{X}
@traitimpl IsTable{X} <- Tables.istable(X)
Expand Down Expand Up @@ -73,12 +74,12 @@ export batch,
ones_like,
rand_like,
randn_like,
rpad_constant,
stack,
unbatch,
unsqueeze,
unstack,
zeros_like
# rpad

include("Datasets/Datasets.jl")
using .Datasets
Expand Down
5 changes: 4 additions & 1 deletion src/deprecations.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Deprecations v0.1
# Deprecated in v0.2
@deprecate stack(x, dims) stack(x; dims=dims)
@deprecate unstack(x, dims) unstack(x; dims=dims)
@deprecate unsqueeze(x::AbstractArray, dims::Int) unsqueeze(x; dims=dims)
Expand All @@ -7,3 +7,6 @@
@deprecate frequencies(x) group_counts(x)
@deprecate eachbatch(data, batchsize; kws...) eachobs(data; batchsize, kws...)
@deprecate eachbatch(data; size=1, kws...) eachobs(data; batchsize=size, kws...)

# Deprecated in v0.3
@deprecate rpad(v::AbstractVector, n::Integer, p) rpad_constant(v, n, p)
80 changes: 53 additions & 27 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -388,52 +388,78 @@ unbatch(x::AbstractArray) = [getobs(x, i) for i in 1:numobs(x)]
unbatch(x::AbstractVector) = x

"""
rpad(v::AbstractVector, n::Integer, p)
batchseq(seqs, val = 0)
Return the given sequence padded with `p` up to a maximum length of `n`.
Take a list of `N` sequences, and turn them into a single sequence where each
item is a batch of `N`. Short sequences will be padded by `val`.
# Examples
```jldoctest
julia> rpad([1, 2], 4, 0)
julia> batchseq([[1, 2, 3], [4, 5]], 0)
3-element Vector{Vector{Int64}}:
[1, 4]
[2, 5]
[3, 0]
```
"""
function batchseq(xs, val = 0, n = nothing)
    # Target length: the caller-supplied `n`, or else the largest size found
    # along the last dimension of any sequence in `xs`.
    len = if n === nothing
        maximum(x -> size(x, ndims(x)), xs)
    else
        n
    end
    # Right-pad every sequence with `val` along its last dimension.
    padded = [rpad_constant(x, len, val; dims=ndims(x)) for x in xs]
    # For each step `t`, take observation `t` of every padded sequence and
    # combine them into one batch.
    return [batch([obsview(padded[j], t) for j in 1:length(padded)]) for t in 1:len]
end

"""
rpad_constant(v::AbstractArray, n::Union{Integer, Tuple}, val = 0; dims=:)
Return the given sequence padded with `val` along the dimensions `dims`
up to a maximum length in each direction specified by `n`.
# Examples
```jldoctest
julia> rpad_constant([1, 2], 4, -1) # padding with -1 up to size 4
4-element Vector{Int64}:
1
2
0
0
-1
-1
julia> rpad([1, 2, 3], 2, 0)
julia> rpad_constant([1, 2, 3], 2) # no padding if length is already greater than n
3-element Vector{Int64}:
1
2
3
```
"""
# Right-pad the vector `v` with copies of `p` until it has at least `n` elements;
# a vector that is already that long (or longer) gets nothing appended.
function Base.rpad(v::AbstractVector, n::Integer, p)
    missing_count = max(n - length(v), 0)
    return vcat(v, fill(p, missing_count))
end
# TODO Piracy

"""
batchseq(seqs, pad)
Take a list of `N` sequences, and turn them into a single sequence where each
item is a batch of `N`. Short sequences will be padded by `pad`.
# Examples
julia> rpad_constant([1 2; 3 4], 4; dims=1) # padding along the first dimension
4×2 Matrix{Int64}:
1 2
3 4
0 0
0 0
```jldoctest
julia> batchseq([[1, 2, 3], [4, 5]], 0)
3-element Vector{Vector{Int64}}:
[1, 4]
[2, 5]
[3, 0]
julia> rpad_constant([1 2; 3 4], 4) # padding along all dimensions by default
4×4 Matrix{Int64}:
1 2 0 0
3 4 0 0
0 0 0 0
0 0 0 0
```
"""
function batchseq(xs, pad = nothing, n = maximum(length(x) for x in xs))
xs_ = [rpad(x, n, pad) for x in xs]
[batch([xs_[j][i] for j = 1:length(xs_)]) for i = 1:n]
function rpad_constant(x::AbstractArray, n::Union{Integer, Tuple}, val=0; dims=:)
    # `_rpad_pads` turns the requested target size(s) `n` into the tuple of pad
    # amounts handed to `NNlib.pad_constant`, which does the padding with `val`.
    return NNlib.pad_constant(x, _rpad_pads(x, n, dims), val; dims=dims)
end

# Compute the tuple of pad amounts needed to extend `x` up to size `n` along `dims`.
#
# - `x`: the array to be padded; only its `size`/`ndims` are read.
# - `n`: a single integer target size applied to every padded dimension, or a
#   collection holding one target size per padded dimension.
# - `dims`: the dimensions to pad, or `:` for every dimension of `x`.
#
# Returns a tuple with two entries per padded dimension: the odd entries are
# always 0 and the even entries are the amount by which the target exceeds the
# current size (0 when `x` is already at least that large).
#
# Throws an `ArgumentError` when the number of target sizes does not match the
# number of padded dimensions.
function _rpad_pads(x, n, dims)
    _dims = dims === Colon() ? (1:ndims(x)) : dims
    _n = n isa Integer ? ntuple(i -> n, length(_dims)) : n
    # Explicit check rather than `@assert`: this validates caller input, and
    # assertions are not guaranteed to stay enabled.
    if length(_dims) != length(_n)
        throw(ArgumentError(
            "expected one target size per padded dimension, got $(length(_n)) " *
            "target size(s) for $(length(_dims)) dimension(s)"))
    end
    # Entry 2k carries the pad amount for the k-th padded dimension; entry 2k-1 stays 0.
    return ntuple(2 * length(_n)) do i
        isodd(i) ? 0 : max(_n[i ÷ 2] - size(x, _dims[i ÷ 2]), 0)
    end
end

@non_differentiable _rpad_pads(::Any...)

"""
flatten(x::AbstractArray)
Expand Down
18 changes: 13 additions & 5 deletions test/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -129,11 +129,6 @@ end
@test d == Dict('a' => 1, 'b' => 2)
end

@testset "rpad" begin
    # A vector shorter than the target length is extended with the pad value.
    short = [1, 2]
    @test rpad(short, 4, 0) == [1, 2, 0, 0]

    # A vector already longer than the target length comes back unchanged.
    long = [1, 2, 3]
    @test rpad(long, 2, 0) == [1,2,3]
end

@testset "batchseq" begin
bs = batchseq([[1, 2, 3], [4, 5]], 0)
@test bs[1] == [1, 4]
Expand All @@ -144,6 +139,11 @@ end
@test bs[1] == [1, 4]
@test bs[2] == [2, 5]
@test bs[3] == [3, -1]

# Sequences of matrices: steps run along the last dimension, so each of the 4
# steps yields one 2×3 batch, and shorter sequences are padded with the default 0.
# The expected value is a 4-element vector of matrices, so its entries are
# comma-separated; newline-separated matrices would be vertically concatenated
# into a single 8×3 matrix.
@test batchseq([ones(2,4), zeros(2, 3), ones(2,2)]) == [
    [1.0 0.0 1.0; 1.0 0.0 1.0],
    [1.0 0.0 1.0; 1.0 0.0 1.0],
    [1.0 0.0 0.0; 1.0 0.0 0.0],
    [1.0 0.0 0.0; 1.0 0.0 0.0],
]
end

@testset "ones_like" begin
Expand Down Expand Up @@ -188,3 +188,11 @@ end

test_zygote(fill_like, rand(5), rand(), (2, 4, 2))
end

@testset "rpad_constant" begin
    vec2 = [1, 2]
    mat = [1 2; 3 4]

    # Vectors: pad with an explicit value; no padding when already long enough.
    @test rpad_constant(vec2, 4, -1) == [1, 2, -1, -1]
    @test rpad_constant([1, 2, 3], 2) == [1, 2, 3]

    # Matrices: pad a single dimension, every dimension, or per-dimension targets.
    @test rpad_constant(mat, 4; dims=1) == [1 2; 3 4; 0 0; 0 0]
    @test rpad_constant(mat, 4) == [1 2 0 0; 3 4 0 0; 0 0 0 0; 0 0 0 0]
    @test rpad_constant(mat, (3, 4)) == [1 2 0 0; 3 4 0 0; 0 0 0 0]
end

0 comments on commit 855f95b

Please sign in to comment.