In [1]:
using Pkg; Pkg.activate("/home/dhairyagandhi96/temp/model-zoo/script/.."); Pkg.status();

using Flux, Flux.Data.MNIST, Statistics
using Flux: throttle, params
using Juno: @progress

    Status `~/temp/model-zoo/Project.toml`
  [1520ce14]   AbstractTrees v0.2.1
  [fbb218c0] ↑ BSON v0.2.3 ⇒ v0.2.4
  [54eefc05]   Cascadia v0.4.0
  [8f4d0f93]   Conda v1.3.0
  [864edb3b] ↑ DataStructures v0.17.0 ⇒ v0.17.5
  [31c24e10] ↑ Distributions v0.21.3 ⇒ v0.21.5
  [587475ba]   Flux v0.9.0
  [708ec375]   Gumbo v0.5.1
  [b0807396]   Gym v1.1.3
  [cd3eb016] ↑ HTTP v0.8.6 ⇒ v0.8.7
  [6218d12a]   ImageMagick v0.7.5
  [916415d5]   Images v0.18.0
  [e5e0dc1b]   Juno v0.7.2
  [ca7b5df7]   MFCC v0.3.1
  [dbeba491] + Metalhead v0.4.0 #c4d1eba (https://github.com/FluxML/Metalhead.jl.git)
  [91a5bcdd] ↑ Plots v0.26.3 ⇒ v0.27.0
  [2913bbd2]   StatsBase v0.32.0
  [98b73d46]   Trebuchet v0.1.0
  [8149f6b0] ↑ WAV v1.0.2 ⇒ v1.0.3
  [10745b16]   Statistics 
  [4ec0a83e]   Unicode 


Extend Distributions.jl slightly so that the Bernoulli `logpdf` stays numerically stable when `p` is close to 0 or 1.

In [2]:
using Distributions
import Distributions: logpdf
# Add eps(Float32) inside each log so that p = 0 or p = 1 never evaluates log(0) = -Inf.
logpdf(b::Bernoulli, y::Bool) = y * log(b.p + eps(Float32)) + (1f0 - y) * log(1 - b.p + eps(Float32))



logpdf (generic function with 66 methods)
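
To see why the `eps(Float32)` offset matters, here is a minimal sketch that mirrors the formula above with plain functions, so it runs without Distributions.jl. The names `naive_logpdf` and `stable_logpdf` are hypothetical helpers introduced only for this illustration; they are not part of the notebook or of any library.

```julia
# Hypothetical helpers for illustration only (not part of Distributions.jl).

# Textbook Bernoulli log-density: breaks down when p is exactly 0 or 1.
naive_logpdf(p::Real, y::Bool) = y ? log(p) : log(1 - p)

# Same expression as the method defined above, with eps(Float32) inside each log.
stable_logpdf(p::Real, y::Bool) =
    y * log(p + eps(Float32)) + (1f0 - y) * log(1 - p + eps(Float32))

p = 1f0                          # a fully saturated sigmoid output
naive  = naive_logpdf(p, false)  # log(0) = -Inf
stable = stable_logpdf(p, false) # finite, equal to log(eps(Float32))
```

A single `-Inf` in the summed log-likelihood would make the whole loss infinite, which is presumably why the offset is added before training on saturating sigmoid outputs.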

Load the data, binarise it, and partition it into mini-batches of size M.

In [3]:
X = float.(hcat(vec.(MNIST.images())...)) .> 0.5
N, M = size(X, 2), 100
data = [X[:,i] for i in Iterators.partition(1:N,M)]


################################# Define Model #################################

600-element Array{BitArray{2},1}:
 [false false … false false; false false … false false; … ; false false … false false; false false … false false]
 [false false … false false; false false … false false; … ; false false … false false; false false … false false]
 [false false … false false; false false … false false; … ; false false … false false; false false … false false]
 [false false … false false; false false … false false; … ; false false … false false; false false … false false]
 [false false … false false; false false … false false; … ; false false … false false; false false … false false]
 [false false … false false; false false … false false; … ; false false … false false; false false … false false]
 [false false … false false; false false … false false; … ; false false … false false; false false … false false]
 [false false … false false; false false … false false; … ; false false … false false; false false … false false]
 [false false … false false; false false … false false
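
The batching step can be tried on a toy matrix. The sketch below uses a hypothetical 2×6 array `Xtoy` and a batch size of 4 (both invented for illustration) to show that `Iterators.partition` yields index ranges, and that the last batch is simply shorter when the batch size does not divide the number of columns.

```julia
# Toy stand-in for the binarised data: 2 "pixels", 6 "images" (hypothetical values).
Xtoy = reshape(1:12, 2, 6) .> 6

# Same comprehension as above, with batch size 4 instead of M = 100.
batches = [Xtoy[:, i] for i in Iterators.partition(1:size(Xtoy, 2), 4)]

length(batches)   # 2 batches: columns 1:4 and columns 5:6
size.(batches)    # [(2, 4), (2, 2)]
```

In the notebook itself the 600-element output suggests N = 60000, which M = 100 divides evenly, so every batch there has exactly M columns.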

Latent dimensionality and number of hidden units.

In [4]:
Dz, Dh = 5, 500

(5, 500)

Components of recognition model / "encoder" MLP.

In [5]:
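# Shared hidden layer A, followed by separate linear heads for the mean and the log standard deviation.
A, μ, logσ = Dense(28^2, Dh, tanh), Dense(Dh, Dz), Dense(Dh, Dz)
g(X) = (h = A(X); (μ(h), logσ(h)))
# Reparameterised sample; called element-wise as z.(μ̂, logσ̂), so every entry draws its own noise.
z(μ, logσ) = μ + exp(logσ) * randn(Float32)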

z (generic function with 1 method)
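
The function `z` is an instance of what is usually called the reparameterisation trick. As a sketch in standard notation (the symbol \(\varepsilon\) is introduced here and does not appear in the code), each latent coordinate is written as a deterministic function of the encoder outputs plus independent standard normal noise:

```latex
z_j = \mu_j + \sigma_j \, \varepsilon_j,
\qquad \sigma_j = \exp(\log \sigma_j),
\qquad \varepsilon_j \sim \mathcal{N}(0, 1)
```

Because the randomness enters only through \(\varepsilon_j\), gradients can flow through \(\mu_j\) and \(\log \sigma_j\) back into the encoder parameters.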

Generative model / "decoder" MLP.

In [6]:
f = Chain(Dense(Dz, Dh, tanh), Dense(Dh, 28^2, σ))


####################### Define ways of doing things with the model. #######################

Chain(Dense(5, 500, tanh), Dense(500, 784, NNlib.σ))

KL divergence between the approximate posterior and the N(0, 1) prior.

In [7]:
kl_q_p(μ, logσ) = 0.5f0 * sum(exp.(2f0 .* logσ) + μ.^2 .- 1f0 .- (2 .* logσ))

kl_q_p (generic function with 1 method)
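
For reference, `kl_q_p` matches the standard closed form for the KL divergence between a diagonal Gaussian and a standard normal prior. Writing \(\sigma^2 = \exp(2 \log \sigma)\), as the code does, the per-coordinate term is:

```latex
\mathrm{KL}\big(\mathcal{N}(\mu, \sigma^2) \,\|\, \mathcal{N}(0, 1)\big)
  = \tfrac{1}{2} \left( \sigma^2 + \mu^2 - 1 - 2 \log \sigma \right)
```

The code sums this term over every latent dimension and every column of the mini-batch. As a sanity check, it is zero when \(\mu = 0\) and \(\log \sigma = 0\), that is, when the posterior equals the prior.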

log p(x|z): conditional log-probability of the data given the latents.

In [8]:
logp_x_z(x, z) = sum(logpdf.(Bernoulli.(f(z)), x))

logp_x_z (generic function with 1 method)
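
Spelled out, and ignoring the small `eps(Float32)` offset, `logp_x_z` computes a Bernoulli log-likelihood in which the decoder output \(f_j(z)\) is the probability that pixel \(j\) is on. The index \(j\) is introduced here for illustration:

```latex
\log p(x \mid z) = \sum_{j=1}^{784} \Big[ x_j \log f_j(z) + (1 - x_j) \log\big(1 - f_j(z)\big) \Big]
```

When `x` and `z` are matrices holding a mini-batch, the `sum` in the code also runs over the columns, so the result is the total over all M examples.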

Monte Carlo estimator of the mean ELBO over a mini-batch of M samples, using one latent draw per data point.

In [9]:
L̄(X) = ((μ̂, logσ̂) = g(X); (logp_x_z(X, z.(μ̂, logσ̂)) - kl_q_p(μ̂, logσ̂)) * 1 // M)

loss(X) = -L̄(X) + 0.01f0 * sum(x->sum(x.^2), params(f))

loss (generic function with 1 method)
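
In standard notation, the two definitions above can be sketched as follows. The indices \(i\) (over the M columns of the mini-batch) and \(\theta\) (over the parameter arrays of the decoder `f`) are introduced here for illustration:

```latex
\bar{\mathcal{L}}(X) = \frac{1}{M} \left[ \sum_{i=1}^{M} \log p(x_i \mid z_i)
  - \sum_{i=1}^{M} \mathrm{KL}\big( q(z \mid x_i) \,\|\, p(z) \big) \right],
\qquad z_i \sim q(z \mid x_i)

\mathrm{loss}(X) = -\bar{\mathcal{L}}(X) + 0.01 \sum_{\theta \in \mathrm{params}(f)} \lVert \theta \rVert_2^2
```

The second term of the loss is an L2 penalty applied to the decoder parameters only; the encoder layers `A`, `μ`, and `logσ` are not regularised.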

Sample from the learned model.

In [10]:
modelsample() = rand.(Bernoulli.(f(z.(zeros(Dz), zeros(Dz)))))


################################# Learn Parameters ##############################

evalcb = throttle(() -> @show(-L̄(X[:, rand(1:N, M)])), 30)
opt = ADAM()
ps = params(A, μ, logσ, f)

@progress for i = 1:20
  @info "Epoch $i"
  Flux.train!(loss, ps, zip(data), opt, cb=evalcb)
end


################################# Sample Output ##############################

using Images

img(x) = Gray.(reshape(x, 28, 28))

cd(@__DIR__)
sample = hcat(img.([modelsample() for i = 1:10])...)
save("sample.png", sample)

┌ Info: Epoch 1
└ @ Main.##1175 string:11
-(L̄(X[:, rand(1:N, M)])) = 557.39294f0 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 204.6174f0 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 202.80965f0 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 178.3148f0 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 192.90036f0 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 170.90718f0 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 176.84448f0 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 179.97359f0 (tracked)
┌ Info: Epoch 2
└ @ Main.##1175 string:11
-(L̄(X[:, rand(1:N, M)])) = 176.46544f0 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 159.39726f0 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 174.87422f0 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 173.93828f0 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 164.67564f0 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 163.99402f0 (tracked)
┌ Info: Epoch 3
└ @ Main.##1175 string:11
-(L̄(X[:, rand(1:N, M)])) = 161.5583f0 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 164.52103f0 (tracked)
-(L̄(X[:, rand(1:N, M)])) = 164.51729f0 (tracked)
-(L̄(X[:, rand(1:N, M)])) =