In [1]:
using Pkg; Pkg.activate("."); Pkg.instantiate();

using Flux, Flux.Data.MNIST, Statistics
using Flux: throttle, params
using Juno: @progress

  Updating registry at `~/.julia/registries/General`
  Updating git-repo `https://github.com/JuliaRegistries/General.git`
[?25l    Fetching: [>                                        ]  0.0 %[2K[?25h

Extend distributions slightly to have a numerically stable logpdf for `p` close to 1 or 0.

In [2]:
using Distributions
import Distributions: logpdf
logpdf(b::Bernoulli, y::Bool) = y * log(b.p + eps()) + (1 - y) * log(1 - b.p + eps())



logpdf (generic function with 62 methods)

Load data, binarise it, and partition into mini-batches of M.

In [3]:
X = float.(hcat(vec.(MNIST.images())...)) .> 0.5
N, M = size(X, 2), 100
data = [X[:,i] for i in Iterators.partition(1:N,M)]


################################# Define Model #################################

600-element Array{BitArray{2},1}:
 [false false … false false; false false … false false; … ; false false … false false; false false … false false]
 [false false … false false; false false … false false; … ; false false … false false; false false … false false]
 [false false … false false; false false … false false; … ; false false … false false; false false … false false]
 [false false … false false; false false … false false; … ; false false … false false; false false … false false]
 [false false … false false; false false … false false; … ; false false … false false; false false … false false]
 [false false … false false; false false … false false; … ; false false … false false; false false … false false]
 [false false … false false; false false … false false; … ; false false … false false; false false … false false]
 [false false … false false; false false … false false; … ; false false … false false; false false … false false]
 [false false … false false; false false … false false

Latent dimensionality, # hidden units.

In [4]:
Dz, Dh = 5, 500

(5, 500)

Components of recognition model / "encoder" MLP.

In [5]:
A, μ, logσ = Dense(28^2, Dh, tanh), Dense(Dh, Dz), Dense(Dh, Dz)
g(X) = (h = A(X); (μ(h), logσ(h)))
z(μ, logσ) = μ + exp(logσ) * randn()

z (generic function with 1 method)

Generative model / "decoder" MLP.

In [6]:
f = Chain(Dense(Dz, Dh, tanh), Dense(Dh, 28^2, σ))


####################### Define ways of doing things with the model. #######################

Chain(Dense(5, 500, tanh), Dense(500, 784, NNlib.σ))

KL-divergence between approximation posterior and N(0, 1) prior.

In [7]:
kl_q_p(μ, logσ) = 0.5 * sum(exp.(2 .* logσ) + μ.^2 .- 1 .+ logσ.^2)

kl_q_p (generic function with 1 method)

logp(x|z) - conditional probability of data given latents.

In [8]:
logp_x_z(x, z) = sum(logpdf.(Bernoulli.(f(z)), x))

logp_x_z (generic function with 1 method)

Monte Carlo estimator of mean ELBO using M samples.

In [9]:
L̄(X) = ((μ̂, logσ̂) = g(X); (logp_x_z(X, z.(μ̂, logσ̂)) - kl_q_p(μ̂, logσ̂)) / M)

loss(X) = -L̄(X) + 0.01 * sum(x->sum(x.^2), params(f))

loss (generic function with 1 method)

Sample from the learned model.

In [10]:
modelsample() = rand.(Bernoulli.(f(z.(zeros(Dz), zeros(Dz)))))


################################# Learn Parameters ##############################

evalcb = throttle(() -> @show(-L̄(X[:, rand(1:N, M)])), 30)
opt = ADAM(params(A, μ, logσ, f))
@progress for i = 1:10
  @info "Epoch $i"
  Flux.train!(loss, zip(data), opt, cb=evalcb)
end


################################# Sample Output ##############################

#using Images

#img(x) = Gray.(reshape(x, 28, 28))

#cd(@__DIR__)
#sample = hcat(img.([modelsample() for i = 1:10])...)
#save("sample.png", sample)

┌ Info: Epoch 1
└ @ Main.##358 string:9
-(L̄(X[:, rand(1:N, M)])) = 543.1589636480705 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 208.7558564799027 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 188.01054312972693 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 177.620175332416 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 181.41586825699696 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 177.64052726653924 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 162.65382475999098 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 170.40689285327417 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 165.3394833974157 (tracked)
┌ Info: Epoch 2
└ @ Main.##358 string:9
-(L̄(X[:, rand(1:N, M)])) = 169.30139679989122 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 170.56885601617753 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 165.88201025958278 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 161.77510983855007 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 159.3024751808574 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 163.6701361629169 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 160.26310338500895 (tracked)
-(L̄(X[:, rand(