diff --git a/src/dataloader.jl b/src/dataloader.jl index 2c7ab42..ebb67e8 100644 --- a/src/dataloader.jl +++ b/src/dataloader.jl @@ -65,7 +65,7 @@ The original data is preserved in the `data` field of the DataLoader. (depending on the `collate` and `batchsize` options, could be `getobs!(buffer, data, idxs)` or `getobs!(buffer[i], data, idx)`). Default `false`. - **`collate`**: Defines the batching behavior. Default `nothing`. - - If `nothing` , a batch is `getobs(data, indices)`. + - If `nothing`, a batch is `getobs(data, indices)`. - If `false`, each batch is `[getobs(data, i) for i in indices]`. - If `true`, applies `MLUtils.batch` to the vector of observations in a batch, recursively collating arrays in the last dimensions. See [`MLUtils.batch`](@ref) for more information @@ -235,7 +235,7 @@ _create_buffer(x) = getobs(x, 1) function _create_buffer(x::BatchView) obsindices = _batchrange(x, 1) - return [getobs(A.data, idx) for idx in enumerate(obsindices)] + return [getobs(x.data, i) for i in obsindices] end function _create_buffer(x::BatchView{TElem,TData,Val{nothing}}) where {TElem,TData} @@ -322,18 +322,24 @@ end # Base uses this function for composable array printing, e.g. adjoint(view(::Matrix))) function Base.showarg(io::IO, d::DataLoader, toplevel) - print(io, "DataLoader(") + print(io, "DataLoader(data") Base.showarg(io, d.data, false) - d.buffer == false || print(io, ", buffer=", d.buffer) + if d.buffer != false + print(io, ", buffer") + Base.showarg(io, d.buffer, false) + end d.parallel == false || print(io, ", parallel=", d.parallel) d.shuffle == false || print(io, ", shuffle=", d.shuffle) d.batchsize == 1 || print(io, ", batchsize=", d.batchsize) d.partial == true || print(io, ", partial=", d.partial) - d.collate === Val(nothing) || print(io, ", collate=", d.collate) + d.collate === Val(nothing) || print(io, ", collate=", _valstr(d.collate)) d.rng == Random.default_rng() || print(io, ", rng=", d.rng) print(io, ")") end +_valstr(::Val{T}) where T = string(T) +_valstr(x) = string(x) + Base.show(io::IO, e::DataLoader) = Base.showarg(io, e, false) function Base.show(io::IO, m::MIME"text/plain", d::DataLoader) diff --git a/test/dataloader.jl b/test/dataloader.jl index 26107a8..d087c68 100644 --- a/test/dataloader.jl +++ b/test/dataloader.jl @@ -271,7 +271,7 @@ d = DataLoader((X2, Y2), batchsize=3) - @test contains(repr(d), "DataLoader(::Tuple{Matrix") + @test contains(repr(d), "DataLoader(data::Tuple{Matrix") @test contains(repr(d), "batchsize=3") @test contains(repr(MIME"text/plain"(), d), "2-element DataLoader") @@ -279,13 +279,42 @@ d2 = DataLoader((x = X2, y = Y2), batchsize=2, partial=false) - @test contains(repr(d2), "DataLoader(::@NamedTuple") + @test contains(repr(d2), "DataLoader(data::@NamedTuple") @test contains(repr(d2), "partial=false") - @test contains(repr(MIME"text/plain"(), d2), "2-element DataLoader(::@NamedTuple") + @test contains(repr(MIME"text/plain"(), d2), "2-element DataLoader(data::@NamedTuple") @test contains(repr(MIME"text/plain"(), d2), "x = 2×2 Matrix{Float32}, y = 2-element Vector") end end + + @testset "buffer issue 205" begin + + function shift_pair(X) + inputs = map(X) do x + T = size(x, 4) + return selectdim(x, 4, 1:(T-1)) + end + targets = map(X) do x + T = size(x, 4) + return selectdim(x, 4, 2:T) + end + return (stack(inputs), stack(targets)) + end + + trajectory = randn(Float32, 32, 32, 4, 3, 5); + + loader = DataLoader( + trajectory; + batchsize=2, + partial=false, + buffer=true, + collate = shift_pair, + shuffle = false, + ) + + @test first(loader)[1] == trajectory[:, :, :, 1:2, 1:2] + @test first(loader)[2] == trajectory[:, :, :, 2:3, 1:2] + end end @testset "eachobs" begin