In [1]:
import Pkg; Pkg.add("Parameters")
using Base.Iterators: partition
using BSON
using DrWatson: struct2dict
using Flux
using Flux: logitbinarycrossentropy, chunk, binarycrossentropy
using Flux.Data: DataLoader
using Images
using Logging: with_logger
using MLDatasets
using Parameters: @with_kw
using ProgressMeter: Progress, next!
using TensorBoardLogger: TBLogger, tb_overwrite
using Random
import Pkg; Pkg.add("ImageMagick")
import Pkg; Pkg.add("ImageIO")

[32m[1m   Updating[22m[39m registry at `~/.julia/registries/General`


[?25l    

[32m[1m   Updating[22m[39m git-repo `https://github.com/JuliaRegistries/General.git`




[32m[1m  Resolving[22m[39m package versions...
[32m[1m  Installed[22m[39m Bzip2_jll ───── v1.0.6+4
[32m[1m  Installed[22m[39m FFMPEG_jll ──── v4.3.1+2
[32m[1m  Installed[22m[39m Zlib_jll ────── v1.2.11+16
[32m[1m  Installed[22m[39m FriBidi_jll ─── v1.0.5+5
[32m[1m  Installed[22m[39m JpegTurbo_jll ─ v2.0.1+2
[32m[1m  Installed[22m[39m libpng_jll ──── v1.6.37+5
[32m[1m  Installed[22m[39m Zstd_jll ────── v1.4.5+1
[32m[1m  Installed[22m[39m FreeType2_jll ─ v2.10.1+4
[32m[1m  Installed[22m[39m Libtiff_jll ─── v4.1.0+1
######################################################################### 100,0%##O=#  #                                                                       
######################################################################### 100,0%##O=#  #                                                                                                     33,3%
######################################################################### 100,0%#=#=-#

In [2]:
# Network sizes shared by the cells below.
# Number of inputs per flattened image (the data cells reshape 28x28 images to 784 values).
obs_dim = 784
# Size of the latent code produced by the encoder.
latent_dim = 32
# Widths of the two hidden layers; index 1 is the layer next to the observation side.
hidden_dims = [600, 400]

2-element Array{Int64,1}:
 600
 400

In [3]:
# Load the MNIST images and labels (train and test splits).
trainImg = Flux.Data.MNIST.images()
testImg = Flux.Data.MNIST.images(:test)
trainLabel = Flux.Data.MNIST.labels()
testLabel = Flux.Data.MNIST.labels(:test)

# Flatten every training image to an (obs_dim, 1) Float64 column.
# A new vector is built on purpose: assigning into `trainImg[i]` would convert each
# value back to the element type of the original vector and undo the Float64 conversion.
trainImg = [Float64.(reshape(Gray.(img), obs_dim, 1)) for img in trainImg]

train_loader = Flux.Data.DataLoader((trainImg, trainLabel), batchsize=12, shuffle = true);

# Sanity check: print the size of every image in the first mini-batch only.
for (ix, iy) in train_loader
    for img in ix
        print(size(img))
    end
    break
end

(784, 1)(784, 1)(784, 1)(784, 1)(784, 1)(784, 1)(784, 1)(784, 1)(784, 1)(784, 1)(784, 1)(784, 1)

In [None]:
# Batched Data:
batch_size = 128
train_batched = []

# Stack each group of images along a third axis. Every partition holds `batch_size`
# images except possibly the last one, so the partition's own length is the batch depth.
for imgs in partition(trainImg, batch_size)
    depth = length(imgs)
    stacked = cat(imgs..., dims = 3)
    push!(train_batched, reshape(stacked, 28, 28, depth))
end

In [None]:
println(size(train_batched[2]))
# Convert the pixels of each batch to Float64. The conversion has to be applied inside
# every batch array: broadcasting `Float64` over the outer vector would call
# `Float64` on a whole array, for which no method exists.
train_batched = [Float64.(batch) for batch in train_batched]

In [None]:
# Batched Labels:
batch_size = 128
label_train_batched = []

# Turn each group of labels into a 1-row matrix. Every partition holds `batch_size`
# labels except possibly the last one, so the partition's own length is the row width.
for labels in partition(trainLabel, batch_size)
    width = length(labels)
    row = cat(labels..., dims = 2)
    push!(label_train_batched, reshape(row, 1, width))
end

In [None]:
println(size(label_train_batched[1]))
# Convert the labels of each batch to Float64. `convert(Array{Float64}, ...)` on the
# outer vector would try to turn every whole label matrix into a single Float64.
label_train_batched = [Float64.(batch) for batch in label_train_batched]

In [3]:
# define layers
"""
    Encoder(obs_dim=784, latent_dim=32, hidden_dim=[600, 400], device="cpu")

Encoder with a shared two-layer `relu` trunk (`linear`) and two linear heads, `mu` and
`logsigma`, each mapping the trunk output to `latent_dim` values.

# Arguments
- `obs_dim`: number of inputs to the first layer.
- `latent_dim`: output size of both heads.
- `hidden_dim`: two-element collection with the trunk's hidden layer widths.
- `device`: accepted for signature compatibility; not used by the constructor.
"""
struct Encoder
    linear
    mu
    logsigma
    # Layer sizes come from the `hidden_dim` argument, not from the `hidden_dims` global,
    # so a caller-supplied architecture is honoured.
    Encoder(obs_dim = 784, latent_dim = 32, hidden_dim = [600, 400], device = "cpu") = new(
        Chain(
            Dense(obs_dim, hidden_dim[1], relu),
            Dense(hidden_dim[1], hidden_dim[2], relu),
        ),                                       # linear
        Dense(hidden_dim[2], latent_dim),        # μ
        Dense(hidden_dim[2], latent_dim),        # logσ
    )
end

"""
    (encoder::Encoder)(x)

Run `x` through the shared trunk and return the tuple `(mu, logsigma)` produced by the
two heads.
"""
function (encoder::Encoder)(x)
    hidden = encoder.linear(x)
    mu = encoder.mu(hidden)
    logsigma = encoder.logsigma(hidden)
    return mu, logsigma
end

#function encode(input, obs_dim=784, latent_dim=32, hidden_dims=[600, 400])
 #   encoder = Chain(Dense(obs_dim, hidden_dims[1], relu), Dense(hidden_dims[1], hidden_dims[2], relu))
  #  output = encoder(input)
   # mu = Dense(size(output)[1], latent_dim)(output)
    #logsigma = Dense(size(output)[1], latent_dim)(output)
    #return mu, logsigma
#end

#function decode(z)
#    decoder = Chain(Dense(latent_dim, hidden_dims[2], relu), Dense(hidden_dims[2], hidden_dims[1], relu), Dense(hidden_dims[1], obs_dim, sigmoid))
#    return decoder(z)
#end

"""
    Decoder(obs_dim=784, latent_dim=32, hidden_dim=[600, 400], device="cpu")

Build the decoder `Chain`: two `relu` layers that mirror the encoder widths in reverse
order, followed by a `sigmoid` layer with `obs_dim` outputs.

# Arguments
- `obs_dim`: number of outputs of the last layer.
- `latent_dim`: number of inputs to the first layer.
- `hidden_dim`: two-element collection with the hidden layer widths.
- `device`: accepted for signature compatibility; not used here.
"""
Decoder(obs_dim = 784, latent_dim = 32, hidden_dim = [600, 400], device = "cpu") = Chain(
    Dense(latent_dim, hidden_dim[2], relu),
    Dense(hidden_dim[2], hidden_dim[1], relu),
    # Use the `hidden_dim` argument here as well (not the `hidden_dims` global), otherwise
    # a non-default architecture yields mismatched layer sizes.
    Dense(hidden_dim[1], obs_dim, sigmoid),
)

Decoder (generic function with 5 methods)

In [4]:
# define loss function

"""
    kl_divergence(mu, logsigma, len)

Return `0.5 * sum(exp(2 * logsigma) + mu^2 - 1 - 2 * logsigma) / len`, where the sum runs
over all elements of `mu` and `logsigma`.
"""
function kl_divergence(mu, logsigma, len)
    per_element = @. exp(2f0 * logsigma) + mu^2 - 1f0 - 2f0 * logsigma
    return 0.5f0 * sum(per_element) / len
end

"""
    sample_with_reparam(mu, logsigma, device = "cpu")

Return `mu + ε .* exp.(logsigma)` where `ε` is standard-normal `Float32` noise with the
same size as `logsigma`. `device` is accepted but not used.
"""
function sample_with_reparam(mu, logsigma, device = "cpu")
    noise = randn(Float32, size(logsigma))
    sigma = exp.(logsigma)
    return mu + noise .* sigma
end

"""
    sample(num_samples, decode)

Draw `num_samples` latent vectors from a standard normal, decode them, and return one
Bernoulli draw per decoded entry (1 with probability `theta`, otherwise 0) using the
element type of the decoder output.

The latent matrix has size `(latent_dim, num_samples)`, i.e. one sample per column, which
is the layout Flux `Dense` layers consume. `latent_dim` is read from the global of that
name. `decode` must return probabilities in `[0, 1]`.
"""
function sample(num_samples, decode)
    z = randn(Float32, latent_dim, num_samples)
    theta = decode(z)
    # Bernoulli(theta): a uniform draw falls below theta with probability theta.
    draws = rand(eltype(theta), size(theta)) .< theta
    return convert.(eltype(theta), draws)
end

"""
    elbo(encode, decode, input, device = "cpu")

Single-sample estimate of the evidence lower bound, averaged over the last dimension of
`input`: the reconstruction log-likelihood minus the KL term from `kl_divergence`.

`encode(input)` must return `(mu, logsigma)`. `decode(z)` must return logits, because the
reconstruction term uses `logitbinarycrossentropy`. `device` is accepted but not used.
"""
function elbo(encode, decode, input, device = "cpu")
    mu, logsigma = encode(input)
    len = size(input)[end]
    z = sample_with_reparam(mu, logsigma)
    logits = decode(z)
    # `agg = sum` gives the summed cross-entropy. The default aggregation is the mean over
    # all elements, which would put this term on a different scale from the summed KL term.
    log_obs_prob = -logitbinarycrossentropy(logits, input; agg = sum) / len
    kl = kl_divergence(mu, logsigma, len)
    return log_obs_prob - kl
end

"""
    convert_to_image(x)

Arrange a batch of decoded images into a single grayscale strip.

`x` is split into 12 chunks along its last dimension, each chunk is reshaped to 28 rows,
the chunks are stacked vertically, and the result is transposed and wrapped in `Gray`.

`x` is expected to hold pixel intensities already in `[0, 1]`. The `Decoder` in this
notebook ends in a `sigmoid` layer, so no further squashing is applied here; applying
`sigmoid` a second time would compress every pixel into roughly `[0.5, 0.73]`.
"""
function convert_to_image(x)
    tiles = reshape.(chunk(x, 12), 28, :)
    return Gray.(permutedims(vcat(tiles...), (2, 1)))
end

convert_to_image (generic function with 1 method)

In [5]:
"""
    model_loss(encoder, decoder, x)

Return the training loss for batch `x`: the summed binary cross-entropy between the
reconstruction and `x` plus the KL term, both divided by the size of the last dimension
of `x`.
"""
function model_loss(encoder, decoder, x)
    mu, logsigma, x_hat = reconstuct(encoder, decoder, x)
    batch = size(x)[end]

    # KL term of the approximate posterior
    kl_term = 0.5f0 * sum(@. (exp(2f0 * logsigma) + mu^2 - 1f0 - 2f0 * logsigma)) / batch

    # Reconstruction term: negative log-likelihood of `x` under the decoder output
    recon_term = sum(Flux.Losses.binarycrossentropy.(x_hat, x)) / batch

    return recon_term + kl_term
end

model_loss (generic function with 1 method)

In [6]:
"""
    reconstuct(encoder, decoder, x, device="cpu")

Encode `x`, draw one latent sample `z = mu + ε .* exp.(logsigma)` with standard-normal
`Float32` noise `ε`, and return `(mu, logsigma, decoder(z))`. `device` is accepted but
not used.
"""
function reconstuct(encoder, decoder, x, device="cpu")
    mu, logsigma = encoder(x)
    noise = randn(Float32, size(logsigma))
    z = mu + noise .* exp.(logsigma)
    return mu, logsigma, decoder(z)
end

reconstuct (generic function with 2 methods)

In [None]:
# ---- data ----
batch_size = 12
xtrain, ytrain = MLDatasets.MNIST.traindata(Float32)
xtrain = reshape(xtrain, 28^2, :)   # one flattened image per column
# The arrays are passed as a single tuple, the same form used for the loader built
# earlier in this notebook.
train_loader = Flux.Data.DataLoader((xtrain, ytrain), batchsize=batch_size, shuffle=true)

# ---- hyper-parameters, model and optimiser ----
learning_rate = 1e-4
max_epochs = 10
display_step = 200   # log to TensorBoard every `display_step` training steps
encoder = Encoder()
decoder = Decoder()
opt = ADAM(learning_rate)
device = "cpu"
ps = Flux.params(encoder.linear, encoder.mu, encoder.logsigma, decoder)

# ---- outputs ----
# Single place for the output directory used by the logger and the per-epoch images.
# NOTE(review): `tb_overwrite` presumably clears existing content of this directory —
# confirm before pointing it at a folder that holds earlier results.
save_dir = "/home/cyrine/VAE_Fluxjl/"
tblogger = TBLogger(save_dir, tb_overwrite)
# Fixed batch that is reconstructed after every epoch to track progress visually.
original, _ = first(train_loader)

# training
train_steps = 0
@info "Start Training, total $(max_epochs) epochs"
for epoch = 1:max_epochs
    @info "Epoch $(epoch)"
    progress = Progress(length(train_loader))

    for (x, _) in train_loader
        # Forward pass and pullback with respect to all trainable parameters.
        loss, back = Flux.pullback(ps) do
            model_loss(encoder, decoder, x)
        end
        grad = back(1f0)
        Flux.Optimise.update!(opt, ps, grad)
        # progress meter
        next!(progress; showvalues=[(:loss, loss)])

        # logging with TensorBoard
        if train_steps % display_step == 0
            with_logger(tblogger) do
                @info "train" loss=loss
            end
        end

        train_steps += 1
    end
    # save image
    _, _, rec_original = reconstuct(encoder, decoder, original, device)
    image = convert_to_image(rec_original)
    image_path = joinpath(save_dir, "epoch_$(epoch).png")
    save(image_path, image)
    @info "Image saved: $(image_path)"
end



            

│  - To prevent this behaviour, do `ProgressMeter.ijulia_behavior(:append)`. 
└ @ ProgressMeter /home/cyrine/.julia/packages/ProgressMeter/OUQkp/src/ProgressMeter.jl:441
[32mProgress:   3%|█▎                                       |  ETA: 0:52:33[39m
[34m  loss:  222.77444[39m

In [None]:
# save model
# Destination file for the trained networks.
model_path = joinpath("/home/cyrine/VAE_Fluxjl/", "model.bson") 
# Rebind `encoder` and `decoder` locally to the results of `cpu(...)` so that the saved
# objects are the CPU versions; `BSON.@save` stores them under these variable names.
let encoder = cpu(encoder), decoder = cpu(decoder)
    BSON.@save model_path encoder decoder 
    @info "Model saved: $(model_path)"
end

In [None]:
## compute Loss after fitting
# Evaluate the trained networks on the first 100 training images. This uses the objects
# defined in this notebook (`xtrain`, `encoder`, `decoder`, `model_loss`) instead of the
# `vae` session object and `next_batch` helper, which this notebook does not define.
x_sample = xtrain[:, 1:100]
# Scale every image (one per column) by its own maximum pixel value.
x_sample = x_sample ./ maximum(x_sample, dims=1)
cur_loss = model_loss(encoder, decoder, x_sample)

In [None]:
## plot some reconstructed samples
# NOTE(review): `vae`, `next_batch`, `reconstruct` and the plotting functions used below
# (`figure`, `subplot`, `imshow`, `title`, `colorbar`, `tight_layout`) are not defined or
# imported in the visible part of this notebook — confirm where they come from before
# running this cell. The indexing `x_sample[i,:]` treats each row as one sample.
x_sample = next_batch(vae.loader, 100)[1]
# Divide every row by its maximum.
# NOTE(review): `maximum(x_sample,2)` passes the dimension positionally — verify that the
# Julia version in use accepts this form (newer releases take `dims=2`).
x_sample = broadcast(/,x_sample,maximum(x_sample,2));
x_reconstruct = reconstruct(vae, x_sample)

figure(figsize=(8, 12))

# Five rows of two panels: the input on the odd panel index, its reconstruction on the
# even one.
for i in 1:5

    subplot(5, 2, 2*i-1)
    imshow(reshape(x_sample[i,:], 28, 28), vmin=0, vmax=1, cmap="gray")
    title("Test input")
    colorbar()
    
    subplot(5, 2, 2*i)
    imshow(reshape(x_reconstruct[i,:], 28, 28), vmin=0, vmax=1, cmap="gray")
    title("Reconstruction")
    colorbar()
end

tight_layout()