Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 11 additions & 5 deletions src/dataloader.jl
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ The original data is preserved in the `data` field of the DataLoader.
(depending on the `collate` and `batchsize` options, could be `getobs!(buffer, data, idxs)` or `getobs!(buffer[i], data, idx)`).
Default `false`.
- **`collate`**: Defines the batching behavior. Default `nothing`.
- If `nothing`, a batch is `getobs(data, indices)`.
- If `false`, each batch is `[getobs(data, i) for i in indices]`.
- If `true`, applies `MLUtils.batch` to the vector of observations in a batch,
recursively collating arrays in the last dimensions. See [`MLUtils.batch`](@ref) for more information
Expand Down Expand Up @@ -235,7 +235,7 @@ _create_buffer(x) = getobs(x, 1)

# Build the initial buffer for a generic `BatchView`: one `getobs` result per
# index returned by `_batchrange(x, 1)` (presumably the indices of the first
# batch — confirm against `_batchrange`).
function _create_buffer(x::BatchView)
    obsindices = _batchrange(x, 1)
    # Iterate the indices themselves rather than `enumerate(obsindices)`:
    # `getobs` is given an observation index, not a `(position, index)` tuple,
    # and the data lives in `x.data`.
    return [getobs(x.data, i) for i in obsindices]
end

function _create_buffer(x::BatchView{TElem,TData,Val{nothing}}) where {TElem,TData}
Expand Down Expand Up @@ -322,18 +322,24 @@ end

# Base uses this function for composable array printing, e.g. adjoint(view(::Matrix))
#
# Prints a compact, single-line summary of a `DataLoader` to `io`:
# `DataLoader(data<summary of d.data>[, buffer<summary>][, option=value ...])`.
# Each option is printed only when it differs from the value it is compared
# against below. `toplevel` is accepted for the `Base.showarg` interface but
# is not used here.
function Base.showarg(io::IO, d::DataLoader, toplevel)
    # The opening tag is written exactly once; `showarg` on the data then
    # appends its own summary right after the `data` label.
    print(io, "DataLoader(data")
    Base.showarg(io, d.data, false)
    # Summarise the buffer through `showarg` instead of printing the buffer
    # value itself.
    if d.buffer != false
        print(io, ", buffer")
        Base.showarg(io, d.buffer, false)
    end
    d.parallel == false || print(io, ", parallel=", d.parallel)
    d.shuffle == false || print(io, ", shuffle=", d.shuffle)
    d.batchsize == 1 || print(io, ", batchsize=", d.batchsize)
    d.partial == true || print(io, ", partial=", d.partial)
    # `_valstr` renders a `Val{T}` as `T`, so the output reads `collate=true`
    # rather than `collate=Val{true}()`.
    d.collate === Val(nothing) || print(io, ", collate=", _valstr(d.collate))
    d.rng == Random.default_rng() || print(io, ", rng=", d.rng)
    print(io, ")")
    return nothing
end

# String form of an option value for printing.
# A `Val{T}` is rendered as its type parameter `T`; any other value is
# rendered with `string` directly.
function _valstr(::Val{T}) where {T}
    return string(T)
end

function _valstr(x)
    return string(x)
end

# Two-argument `show` reuses the compact `showarg` rendering, passing
# `false` for `toplevel`.
function Base.show(io::IO, e::DataLoader)
    return Base.showarg(io, e, false)
end

function Base.show(io::IO, m::MIME"text/plain", d::DataLoader)
Expand Down
35 changes: 32 additions & 3 deletions test/dataloader.jl
Original file line number Diff line number Diff line change
Expand Up @@ -271,21 +271,50 @@

d = DataLoader((X2, Y2), batchsize=3)

# The compact repr labels the data argument (`data::...`) and lists the
# options that were changed from their defaults.
@test contains(repr(d), "DataLoader(data::Tuple{Matrix")
@test contains(repr(d), "batchsize=3")

# The text/plain repr additionally reports the number of batches and a
# summary of the first element.
@test contains(repr(MIME"text/plain"(), d), "2-element DataLoader")
@test contains(repr(MIME"text/plain"(), d), "2×3 Matrix{Float32}, 3-element Vector")

d2 = DataLoader((x = X2, y = Y2), batchsize=2, partial=false)

@test contains(repr(d2), "DataLoader(data::@NamedTuple")
@test contains(repr(d2), "partial=false")

@test contains(repr(MIME"text/plain"(), d2), "2-element DataLoader(data::@NamedTuple")
@test contains(repr(MIME"text/plain"(), d2), "x = 2×2 Matrix{Float32}, y = 2-element Vector")
end
end

@testset "buffer issue 205" begin
    # Custom collate function: for every observation in the batch, take the
    # slices along dimension 4 without the last step (inputs) and without the
    # first step (targets), then stack each group into a single array.
    function shift_pair(X)
        inputs = map(obs -> selectdim(obs, 4, 1:(size(obs, 4) - 1)), X)
        targets = map(obs -> selectdim(obs, 4, 2:size(obs, 4)), X)
        return (stack(inputs), stack(targets))
    end

    trajectory = randn(Float32, 32, 32, 4, 3, 5)

    # Buffered loading combined with a user-supplied collate function.
    loader = DataLoader(trajectory; batchsize=2, partial=false, buffer=true,
                        collate=shift_pair, shuffle=false)

    # With shuffling off, the first batch holds observations 1:2; the inputs
    # cover steps 1:2 and the targets steps 2:3 of dimension 4.
    @test first(loader)[1] == trajectory[:, :, :, 1:2, 1:2]
    @test first(loader)[2] == trajectory[:, :, :, 2:3, 1:2]
end
end

@testset "eachobs" begin
Expand Down
Loading