## Variational Autoencoder for MNIST data

### Loading the data

First we load the required packages:

In [1]:
using Flux, Flux.Data.MNIST, Statistics
using Flux: throttle, params

In [2]:
# Extend distributions slightly to have a numerically stable logpdf for `p` close to 1 or 0.
using Distributions
import Distributions: logpdf
logpdf(b::Bernoulli, y::Bool) = y * log(b.p + eps(Float32)) + (1f0 - y) * log(1 - b.p + eps(Float32))



logpdf (generic function with 62 methods)

Now we read in the data:

In [3]:
# Load data, binarise it, and partition into mini-batches of M.
X = float.(hcat(vec.(MNIST.images())...)) .> 0.5
N, M = size(X, 2), 100
data = [X[:,i] for i in Iterators.partition(1:N,M)];
(N,M)

(60000, 100)

In [6]:
X[:,1]

784-element BitArray{1}:
 false
 false
 false
 false
 false
 false
 false
 false
 false
 false
 false
 false
 false
     ⋮
 false
 false
 false
 false
 false
 false
 false
 false
 false
 false
 false
 false

In [5]:
################################# Define Model #################################

# Latent dimensionality, # hidden units.
Dz, Dh = 5, 500

# Components of recognition model / "encoder" MLP.
A, μ, logσ = Dense(28^2, Dh, tanh), Dense(Dh, Dz), Dense(Dh, Dz)
g(X) = (h = A(X); (μ(h), logσ(h)))
z(μ, logσ) = μ + exp(logσ) * randn(Float32)

# Generative model / "decoder" MLP.
f = Chain(Dense(Dz, Dh, tanh), Dense(Dh, 28^2, σ))

Chain(Dense(5, 500, tanh), Dense(500, 784, NNlib.σ))

In [6]:
####################### Define ways of doing things with the model. #######################

# KL-divergence between approximation posterior and N(0, 1) prior.
kl_q_p(μ, logσ) = 0.5f0 * sum(exp.(2f0 .* logσ) + μ.^2 .- 1f0 .+ logσ.^2)

# logp(x|z) - conditional probability of data given latents.
logp_x_z(x, z) = sum(logpdf.(Bernoulli.(f(z)), x))

# Monte Carlo estimator of mean ELBO using M samples.
L̄(X) = ((μ̂, logσ̂) = g(X); (logp_x_z(X, z.(μ̂, logσ̂)) - kl_q_p(μ̂, logσ̂)) * 1 // M)

loss(X) = -L̄(X) + 0.01f0 * sum(x->sum(x.^2), params(f))

# Sample from the learned model.
modelsample() = rand.(Bernoulli.(f(z.(zeros(Dz), zeros(Dz)))))

modelsample (generic function with 1 method)

In [7]:
################################# Learn Parameters ##############################

evalcb = throttle(() -> @show(-L̄(X[:, rand(1:N, M)])), 30)
opt = ADAM()
ps = params(A, μ, logσ, f)

for i = 1:20
  @info "Epoch $i"
  Flux.train!(loss, ps, zip(data), opt, cb=evalcb)
end

┌ Info: Epoch 1
└ @ Main In[7]:8


-(L̄(X[:, rand(1:N, M)])) = 544.9813f0 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 175.18198f0 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 164.75838f0 (tracked)


┌ Info: Epoch 2
└ @ Main In[7]:8


-(L̄(X[:, rand(1:N, M)])) = 166.55515f0 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 169.11412f0 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 153.44893f0 (tracked)


┌ Info: Epoch 3
└ @ Main In[7]:8


-(L̄(X[:, rand(1:N, M)])) = 153.92361f0 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 153.722f0 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 152.04886f0 (tracked)


┌ Info: Epoch 4
└ @ Main In[7]:8


-(L̄(X[:, rand(1:N, M)])) = 142.13672f0 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 159.439f0 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 152.61314f0 (tracked)


┌ Info: Epoch 5
└ @ Main In[7]:8


-(L̄(X[:, rand(1:N, M)])) = 158.70023f0 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 151.49704f0 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 142.91176f0 (tracked)


┌ Info: Epoch 6
└ @ Main In[7]:8


-(L̄(X[:, rand(1:N, M)])) = 143.63165f0 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 153.68814f0 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 145.31458f0 (tracked)


┌ Info: Epoch 7
└ @ Main In[7]:8


-(L̄(X[:, rand(1:N, M)])) = 138.28984f0 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 144.12935f0 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 143.97943f0 (tracked)


┌ Info: Epoch 8
└ @ Main In[7]:8


-(L̄(X[:, rand(1:N, M)])) = 137.43669f0 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 143.21892f0 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 136.34074f0 (tracked)


┌ Info: Epoch 9
└ @ Main In[7]:8


-(L̄(X[:, rand(1:N, M)])) = 140.1035f0 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 143.19398f0 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 136.70186f0 (tracked)


┌ Info: Epoch 10
└ @ Main In[7]:8


-(L̄(X[:, rand(1:N, M)])) = 138.35522f0 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 141.93678f0 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 135.47774f0 (tracked)


┌ Info: Epoch 11
└ @ Main In[7]:8


-(L̄(X[:, rand(1:N, M)])) = 137.62286f0 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 138.34444f0 (tracked)


┌ Info: Epoch 12
└ @ Main In[7]:8


-(L̄(X[:, rand(1:N, M)])) = 135.87799f0 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 147.59515f0 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 143.06065f0 (tracked)


┌ Info: Epoch 13
└ @ Main In[7]:8


-(L̄(X[:, rand(1:N, M)])) = 138.60023f0 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 146.99283f0 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 135.73558f0 (tracked)


┌ Info: Epoch 14
└ @ Main In[7]:8


-(L̄(X[:, rand(1:N, M)])) = 148.5738f0 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 139.14633f0 (tracked)


┌ Info: Epoch 15
└ @ Main In[7]:8


-(L̄(X[:, rand(1:N, M)])) = 141.02647f0 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 143.58876f0 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 136.5529f0 (tracked)


┌ Info: Epoch 16
└ @ Main In[7]:8


-(L̄(X[:, rand(1:N, M)])) = 138.88948f0 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 142.06839f0 (tracked)


┌ Info: Epoch 17
└ @ Main In[7]:8


-(L̄(X[:, rand(1:N, M)])) = 133.72638f0 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 142.40042f0 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 141.37666f0 (tracked)


┌ Info: Epoch 18
└ @ Main In[7]:8


-(L̄(X[:, rand(1:N, M)])) = 151.34944f0 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 136.0661f0 (tracked)


┌ Info: Epoch 19
└ @ Main In[7]:8


-(L̄(X[:, rand(1:N, M)])) = 141.76482f0 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 143.24782f0 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 139.72148f0 (tracked)


┌ Info: Epoch 20
└ @ Main In[7]:8


-(L̄(X[:, rand(1:N, M)])) = 136.00517f0 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 133.91176f0 (tracked)


In [8]:
################################# Sample Output ##############################

using Images

img(x) = Gray.(reshape(x, 28, 28))

sample = hcat(img.([modelsample() for i = 1:10])...)
save("sample.png", sample)

In [9]:
for i = 1:20
  @info "Epoch $i"
  Flux.train!(loss, ps, zip(data), opt, cb=evalcb)
end

┌ Info: Epoch 1
└ @ Main In[9]:2


-(L̄(X[:, rand(1:N, M)])) = 137.48178f0 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 140.76006f0 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 142.49525f0 (tracked)


┌ Info: Epoch 2
└ @ Main In[9]:2


-(L̄(X[:, rand(1:N, M)])) = 140.76175f0 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 136.34175f0 (tracked)


┌ Info: Epoch 3
└ @ Main In[9]:2


-(L̄(X[:, rand(1:N, M)])) = 135.53745f0 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 143.63853f0 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 141.9314f0 (tracked)


┌ Info: Epoch 4
└ @ Main In[9]:2


-(L̄(X[:, rand(1:N, M)])) = 139.90031f0 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 134.64522f0 (tracked)


┌ Info: Epoch 5
└ @ Main In[9]:2


-(L̄(X[:, rand(1:N, M)])) = 143.83553f0 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 133.73615f0 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 137.242f0 (tracked)


┌ Info: Epoch 6
└ @ Main In[9]:2


-(L̄(X[:, rand(1:N, M)])) = 137.95589f0 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 143.19841f0 (tracked)


┌ Info: Epoch 7
└ @ Main In[9]:2


-(L̄(X[:, rand(1:N, M)])) = 133.44397f0 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 136.67905f0 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 141.81439f0 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 142.8838f0 (tracked)


┌ Info: Epoch 8
└ @ Main In[9]:2


-(L̄(X[:, rand(1:N, M)])) = 146.24826f0 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 142.12029f0 (tracked)


┌ Info: Epoch 9
└ @ Main In[9]:2


-(L̄(X[:, rand(1:N, M)])) = 142.2862f0 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 137.81175f0 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 140.80772f0 (tracked)


┌ Info: Epoch 10
└ @ Main In[9]:2


-(L̄(X[:, rand(1:N, M)])) = 144.76172f0 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 137.30658f0 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 134.41515f0 (tracked)


┌ Info: Epoch 11
└ @ Main In[9]:2


-(L̄(X[:, rand(1:N, M)])) = 142.57329f0 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 143.72849f0 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 141.55705f0 (tracked)


┌ Info: Epoch 12
└ @ Main In[9]:2


-(L̄(X[:, rand(1:N, M)])) = 142.71115f0 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 138.45088f0 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 132.97826f0 (tracked)


┌ Info: Epoch 13
└ @ Main In[9]:2


-(L̄(X[:, rand(1:N, M)])) = 136.85297f0 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 141.00037f0 (tracked)


┌ Info: Epoch 14
└ @ Main In[9]:2


-(L̄(X[:, rand(1:N, M)])) = 142.27643f0 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 142.3596f0 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 143.75272f0 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 145.95453f0 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 145.97285f0 (tracked)


┌ Info: Epoch 15
└ @ Main In[9]:2


-(L̄(X[:, rand(1:N, M)])) = 136.47821f0 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 141.70793f0 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 131.40163f0 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 134.59027f0 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 138.84686f0 (tracked)


┌ Info: Epoch 16
└ @ Main In[9]:2


-(L̄(X[:, rand(1:N, M)])) = 142.29745f0 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 140.22194f0 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 134.98273f0 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 144.98996f0 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 141.38715f0 (tracked)


┌ Info: Epoch 17
└ @ Main In[9]:2


-(L̄(X[:, rand(1:N, M)])) = 132.53378f0 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 148.12596f0 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 141.03625f0 (tracked)


┌ Info: Epoch 18
└ @ Main In[9]:2


-(L̄(X[:, rand(1:N, M)])) = 143.42809f0 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 145.40973f0 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 150.96144f0 (tracked)


┌ Info: Epoch 19
└ @ Main In[9]:2


-(L̄(X[:, rand(1:N, M)])) = 135.95183f0 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 139.17155f0 (tracked)


┌ Info: Epoch 20
└ @ Main In[9]:2


-(L̄(X[:, rand(1:N, M)])) = 140.36661f0 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 145.36574f0 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 139.43881f0 (tracked)


In [10]:
sample = hcat(img.([modelsample() for i = 1:10])...)
save("sample2.png", sample)