In [None]:
using Flux, Flux.Data.MNIST, Images
using Flux: onehotbatch, argmax, mse, throttle
using Base.Iterators: partition
using Images
#using CuArrays

For this model we'll train an autoencoder that learns a compression scheme: each MNIST digit is encoded as a small vector (here of length `N = 32`) that can then be decoded back into an approximation of the original image.

In [None]:
imgs = MNIST.images()
# Flatten each image into a vector and convert its pixels to floating point.
vecs = float.(vec.(imgs))

# Partition into batches of size 1000. Each batch is a matrix holding one
# flattened image per column. `reduce(hcat, ...)` avoids splatting 1000
# arguments into `hcat`, and iterating `eachindex(vecs)` derives the dataset
# size from the data instead of hard-coding it.
data = [gpu(reduce(hcat, vecs[idx])) for idx in partition(eachindex(vecs), 1000)];
vecs = gpu.(vecs);

In [None]:
# Display the first ten raw MNIST images for a quick look at the data.
imgs[1:10]

In [None]:
# Width of the bottleneck: the length of each compressed code vector.
N = 32

# Encoder: flattened 28×28 image -> N-dimensional code.
encoder = Dense(28 * 28, N, relu)
# Decoder: N-dimensional code -> flattened 28×28 reconstruction.
decoder = Dense(N, 28 * 28, relu)

# The full autoencoder. `gpu` only moves it to a GPU when a GPU backend
# (e.g. CuArrays) is loaded; otherwise the model stays on the CPU.
m = Chain(encoder, decoder) |> gpu

# Reconstruction loss: mean squared error between the model's output and its
# own input, so the network is trained to reproduce what it is given.
function loss(x)
    reconstruction = m(x)
    return mse(reconstruction, x)
end

It's useful to be able to visualise what's happening in the network, so we define helpers that turn a flattened pixel vector back into a 28×28 grayscale image.

In [None]:
# Wrap a single value as a grayscale pixel.
gray(x) = Gray(x)

"""
    img(x::AbstractVector)

Convert a flattened image vector into a 28×28 matrix of `Gray` pixels,
clamping every value to the range [0, 1] first.
"""
function img(x::AbstractVector)
    pixels = reshape(clamp.(x, 0, 1), 28, 28)
    return collect(gray.(pixels))
end

First image:

In [None]:
# Show the first MNIST image exactly as loaded.
first(imgs)

The same image after a round trip through the (not yet trained) model:

In [None]:
# Encode and decode the first digit, then render the result. `.data` unwraps
# the model output into a plain array (presumably Tracker's tracked array).
let reconstruction = m(first(vecs)).data
    img(reconstruction)
end

A random selection of 20 digits, each original shown above its reconstruction:

In [None]:
"""
    sample()

Return a single image showing 20 randomly chosen digits side by side. In each
column the original digit is on top and the model's reconstruction is below it.
"""
function sample()
    picks = rand(1:length(imgs), 20)
    originals = [img(vecs[k]) for k in picks]
    reconstructions = [img(m(vecs[k]).data) for k in picks]
    # Stack each original above its reconstruction, then lay the pairs out left to right.
    columns = [vcat(top, bottom) for (top, bottom) in zip(originals, reconstructions)]
    return hcat(columns...)
end
sample()

In [None]:
# Run after every training epoch: print the reconstruction loss on the first
# batch, then display a fresh grid of random originals and reconstructions.
evalcb = function ()
    # `data[1]` is the whole first batch (one flattened digit per column).
    # Indexing it again (`data[1][1]`) would select a single pixel, and the
    # model would be evaluated on a scalar instead of on images.
    # NOTE: `print_with_color` was removed in Julia 1.0; on 1.x use
    # `printstyled(msg; color = :blue)` instead.
    print_with_color(:blue, "Loss is $(loss(data[1]))\n")
    display(sample())
end

# ADAM optimiser over all trainable parameters of the autoencoder.
opt = ADAM(params(m))

# Train for 10 epochs. `zip(data)` wraps each batch in a 1-tuple, which
# `train!` splats into `loss`, so every step calls `loss(batch)`.
for epoch in 1:10
    Flux.train!(loss, zip(data), opt)
    evalcb()
end