In [1]:
using MLDatasets: MNIST
using Flux.Data: DataLoader
using Flux: onehotbatch

# Load the MNIST train/test splits with Float32 pixel values. The printed
# output of this cell shows images as a (28, 28, N) array and labels as a
# length-N vector.
# NOTE(review): newer MLDatasets releases deprecate `traindata`/`testdata` in
# favour of `MNIST(split=:train)` / `MNIST(split=:test)`; confirm the installed
# version before upgrading.
train_x, train_y = MNIST.traindata(Float32)
test_x, test_y = MNIST.testdata(Float32)

# Report the shape of every loaded array, one line each.
for (caption, arr) in (
    ("Training data shape: ", train_x),
    ("Training labels shape: ", train_y),
    ("Test data shape: ", test_x),
    ("Test labels shape: ", test_y),
)
    println(caption, size(arr))
end

Training data shape: (28, 28, 60000)
Training labels shape: (60000,)
Test data shape: (28, 28, 10000)
Test labels shape: (10000,)


In [2]:
# Insert a singleton channel axis so each image array becomes (28, 28, 1, N),
# the (width, height, channels, batch) layout Flux's convolution layers take.
# `reshape` shares memory with its input, so no pixel data is copied.
train_x, test_x = map(imgs -> reshape(imgs, 28, 28, 1, :), (train_x, test_x))

println("Training data shape: $(size(train_x))")
println("Training labels shape: $(size(train_y))")

Training data shape: (28, 28, 1, 60000)
Training labels shape: (60000,)


In [3]:
# One-hot encode the digit labels over the classes 0:9: a length-N label
# vector becomes a 10×N one-hot matrix (row k corresponds to digit k - 1).
train_y, test_y = map(labels -> onehotbatch(labels, 0:9), (train_y, test_y))

println("Training label shape: $(size(train_y))")
println("Test label shape: $(size(test_y))")

Training label shape: (10, 60000)
Test label shape: (10, 10000)


In [12]:
# Peek at the first training sample's one-hot column; in the recorded output
# the single `1` sits in row 6, i.e. digit 5 given the 0:9 encoding.
train_y[:, begin]

10-element OneHotVector(::UInt32) with eltype Bool:
 ⋅
 ⋅
 ⋅
 ⋅
 ⋅
 1
 ⋅
 ⋅
 ⋅
 ⋅

In [5]:
# Serve (images, labels) pairs in reshuffled mini-batches of 128 samples;
# by default the final, smaller batch of leftover samples is kept.
data_loader = DataLoader((train_x, train_y), batchsize = 128, shuffle = true)

DataLoader{Tuple{Array{Float32, 4}, Flux.OneHotArray{UInt32, 10, 1, 2, Vector{UInt32}}}, Random._GLOBAL_RNG}(([0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0;;;; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0;;;; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0;;;; … ;;;; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0;;;; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0;;;; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0], Bool[0 1 … 0 0; 0 0 … 0 0; … ; 0 0 … 0 1; 0 0 … 0 0]), 128, 60000, true, true, Random._GLOBAL_RNG())

In [7]:
# Sanity-check one pass over `data_loader`:
#   * every image batch is (28, 28, 1, n) and its label batch is (10, n) for the
#     SAME n, so images and labels stay paired;
#   * n is the full batch size or the remainder batch size, derived from the
#     dataset size rather than hard-coded (60000 samples / 128 -> remainder 96);
#   * the batches together cover every training sample exactly once.
# `batchsize` must match the value passed to `DataLoader` when it was built.
let batchsize = 128,
    nsamples = size(train_x, 4),
    # mod1 yields `batchsize` (not 0) when the dataset divides evenly.
    lastsize = mod1(nsamples, batchsize),
    seen = 0

    for (x, y) in data_loader
        n = size(x, 4)
        @assert size(x) == (28, 28, 1, n) "unexpected image batch shape $(size(x))"
        @assert size(y) == (10, n) "label batch $(size(y)) does not match $n images"
        @assert n == batchsize || n == lastsize "unexpected batch size $n"
        seen += n
    end
    @assert seen == nsamples "loader yielded $seen samples, expected $nsamples"
end