In [2]:
import Pkg; Pkg.add("Parameters")
using Base.Iterators: partition
using BSON
using DrWatson: struct2dict
using Flux
using Flux: logitbinarycrossentropy, chunk, binarycrossentropy
using Flux.Data: DataLoader
using Images
using Logging: with_logger
using MLDatasets
using Parameters: @with_kw
using ProgressMeter: Progress, next!
using TensorBoardLogger: TBLogger, tb_overwrite
using Random
import Pkg; Pkg.add("ImageMagick")
import Pkg; Pkg.add("ImageIO")

[32m[1m   Updating[22m[39m registry at `~/.julia/registries/General`


[?25l    

[32m[1m   Updating[22m[39m git-repo `https://github.com/JuliaRegistries/General.git`








[32m[1m  Resolving[22m[39m package versions...
[32m[1m   Updating[22m[39m `~/.julia/environments/v1.4/Project.toml`
[90m [no changes][39m
[32m[1m   Updating[22m[39m `~/.julia/environments/v1.4/Manifest.toml`
[90m [no changes][39m
[32m[1m  Resolving[22m[39m package versions...
[32m[1m   Updating[22m[39m `~/.julia/environments/v1.4/Project.toml`
[90m [no changes][39m
[32m[1m   Updating[22m[39m `~/.julia/environments/v1.4/Manifest.toml`
[90m [no changes][39m
[32m[1m  Resolving[22m[39m package versions...
[32m[1m   Updating[22m[39m `~/.julia/environments/v1.4/Project.toml`
[90m [no changes][39m
[32m[1m   Updating[22m[39m `~/.julia/environments/v1.4/Manifest.toml`
[90m [no changes][39m


In [3]:
# Model dimensions: flattened 28×28 image size, latent code size, and the widths
# of the two hidden layers.
latent_dim = 32
obs_dim = 28 * 28
hidden_dims = [600, 400]

2-element Array{Int64,1}:
 600
 400

In [4]:
# define layers
struct Encoder
    linear
    mu
    logsigma
    Encoder(obs_dim = 784, latent_dim= 32, hidden_dim=[600, 400], device= "cpu") = new(
        Chain(Dense(obs_dim, hidden_dims[1], relu), Dense(hidden_dims[1], hidden_dims[2], relu)),   # linear
        Dense(hidden_dims[2], latent_dim),        # μ
        Dense(hidden_dims[2], latent_dim),        # logσ
    )
end

function (encoder::Encoder)(x)
    h = encoder.linear(x)
    encoder.mu(h), encoder.logsigma(h)
end


"""
    Decoder(obs_dim = 784, latent_dim = 32, hidden_dim = [600, 400], device = "cpu")

MLP decoder mirroring the encoder widths:
`latent_dim → hidden_dim[2] → hidden_dim[1] → obs_dim`.
The final `sigmoid` means the output is per-element probabilities in (0, 1), not
logits. `device` is accepted for API compatibility but is currently unused.
"""
Decoder(obs_dim = 784, latent_dim = 32, hidden_dim = [600, 400], device = "cpu") = Chain(
    Dense(latent_dim, hidden_dim[2], relu),
    Dense(hidden_dim[2], hidden_dim[1], relu),
    # Uses the argument; this previously read the global `hidden_dims[1]`.
    Dense(hidden_dim[1], obs_dim, sigmoid),
)

Decoder (generic function with 5 methods)

In [5]:
# define loss function

"""
    kl_divergence(mu, logsigma, len)

KL divergence between N(mu, σ²) and N(0, 1) with σ = exp(logsigma), summed over
every element and then divided by `len` (e.g. the number of examples in the batch).
"""
function kl_divergence(mu, logsigma, len)
    # Per-element term: σ² + μ² - 1 - 2·log σ
    terms = exp.(2f0 .* logsigma) .+ mu .^ 2 .- 1f0 .- 2f0 .* logsigma
    total = sum(terms)
    return 0.5f0 * total / len
end

"""
    sample_with_reparam(mu, logsigma, device = "cpu")

Reparameterised Gaussian sample: `mu + ε .* exp.(logsigma)` with
ε ~ N(0, 1) drawn as Float32 with the shape of `logsigma`.
`device` is accepted for API compatibility but is unused.
"""
function sample_with_reparam(mu, logsigma, device = "cpu")
    epsilon = randn(Float32, size(logsigma))
    sigma = exp.(logsigma)
    return mu + epsilon .* sigma
end

"""
    sample(num_samples, decode, zdim = latent_dim)

Generate `num_samples` binary observations from the model's prior.

1. Draw z ~ N(0, I) as a `zdim × num_samples` Float32 matrix. This is the
   features × batch layout Flux layers expect.
2. Decode z with `decode`.
3. Draw an elementwise Bernoulli sample.

Returns a Float32 array of 0s and 1s with the shape of `decode(z)`.

Assumes `decode` returns probabilities in [0, 1], as the sigmoid-terminated
`Decoder` does. `zdim` defaults to the global `latent_dim`.
"""
function sample(num_samples, decode, zdim = latent_dim)
    # Latent dimension first (features × batch); Float32 matches the model weights.
    z = randn(Float32, zdim, num_samples)
    theta = decode(z)
    # Bernoulli(θ): u ~ U[0, 1) falls below θ with probability θ.
    # (Replaces `torch.bernoulli`, which does not exist in Julia.)
    return Float32.(rand(Float32, size(theta)) .< theta)
end

"""
    elbo(encode, decode, input, device = "cpu")

Single-sample estimate of the evidence lower bound, averaged over the batch
(the last dimension of `input`):
reconstruction log-likelihood minus the KL term.

Assumes `decode` outputs probabilities, as the sigmoid-terminated `Decoder` does.
`device` is accepted for API compatibility but is unused.
"""
function elbo(encode, decode, input, device = "cpu")
    mu, logsigma = encode(input)
    len = size(input)[end]
    z = sample_with_reparam(mu, logsigma)
    theta = decode(z)
    # `theta` is already a probability, so use plain binary cross-entropy. The logit
    # variant would apply a second sigmoid. The per-element broadcast matches `model_loss`.
    log_obs_prob = -sum(Flux.Losses.binarycrossentropy.(theta, input)) / len
    kl = kl_divergence(mu, logsigma, len)
    return log_obs_prob - kl
end

"""
    convert_to_image(x; from_logits = false)

Tile the columns of `x` side by side as one grayscale strip.
Each column is a flattened 28×28 image, and the result has size
`28 × (28 * size(x, 2))`. Each image is transposed, as in the previous layout.

`x` is expected to hold pixel probabilities in [0, 1], e.g. output of the
sigmoid-terminated decoder. Pass `from_logits = true` to apply `sigmoid` first.
"""
function convert_to_image(x; from_logits = false)
    # Decoder output is already a probability; applying sigmoid again would squash
    # every pixel into [0.5, 0.73] and wash out the image.
    pixels = from_logits ? sigmoid.(x) : x
    # Reshape each column into a 28×28 tile and transpose it (dims 1 and 2). Then
    # lay the tiles out horizontally. This works for any batch size, unlike a fixed
    # `chunk(x, 12)`.
    tiles = permutedims(reshape(pixels, 28, 28, :), (2, 1, 3))
    return Gray.(reshape(tiles, 28, :))
end

convert_to_image (generic function with 1 method)

In [6]:
"""
    model_loss(encoder, decoder, x)

Negative ELBO for batch `x`, averaged per example (the last dimension of `x`).
It is the summed binary cross-entropy reconstruction term plus the KL term.

Assumes `decoder` outputs probabilities, as the sigmoid-terminated `Decoder` does.
"""
function model_loss(encoder, decoder, x)
    mu, logsigma, decoder_z = reconstuct(encoder, decoder, x)
    len = size(x)[end]
    # KL-divergence via the shared helper (previously duplicated inline).
    kl_q_p = kl_divergence(mu, logsigma, len)
    # Reconstruction log-likelihood estimated with per-element binary cross-entropy.
    logp_x_z = -sum(Flux.Losses.binarycrossentropy.(decoder_z, x)) / len
    return -logp_x_z + kl_q_p
end

model_loss (generic function with 1 method)

In [7]:
"""
    reconstuct(encoder, decoder, x, device = "cpu")

Encode `x` into `(mu, logsigma)`. Then draw one latent sample
`mu + ε .* exp.(logsigma)` with Float32 ε ~ N(0, 1), and decode it.

Returns `(mu, logsigma, decoded)`. `device` is unused. The misspelled name is kept
because callers use it.
"""
function reconstuct(encoder, decoder, x, device="cpu")
    mu, logsigma = encoder(x)
    noise = randn(Float32, size(logsigma))
    latent = mu + noise .* exp.(logsigma)
    return (mu, logsigma, decoder(latent))
end

reconstuct (generic function with 2 methods)

In [9]:
# Data: MNIST training images as Float32, flattened to one 784-element column each,
# served in shuffled mini-batches together with their labels.
batch_size = 12
xtrain, ytrain = MLDatasets.MNIST.traindata(Float32)
xtrain = reshape(xtrain, 28 * 28, :)
train_loader = DataLoader(xtrain, ytrain; batchsize = batch_size, shuffle = true)

# Optimisation and logging hyperparameters.
learning_rate = 1e-4
max_epochs = 10
display_step = 200

5000

200

In [8]:
# Model, optimiser and logging setup.
encoder = Encoder()
decoder = Decoder()
opt = ADAM(learning_rate)
device = "cpu"
# `Encoder` is a plain struct, so its layers' parameters are collected explicitly.
ps = Flux.params(encoder.linear, encoder.mu, encoder.logsigma, decoder)
# Single place for the output location (TensorBoard logs and per-epoch images).
# NOTE(review): `tb_overwrite` may clear this directory when the logger is created,
# which would also remove images/models saved here earlier — confirm.
output_dir = "/home/cyrine/VAE_Fluxjl/"
tblogger = TBLogger(output_dir, tb_overwrite)
# Fixed batch that is reconstructed after every epoch to visualise progress.
original, _ = first(train_loader)

# training
train_steps = 0
@info "Start Training, total $(max_epochs) epochs"
for epoch = 1:max_epochs
    @info "Epoch $(epoch)"
    progress = Progress(length(train_loader))

    for (x, _) in train_loader
        # `train_steps` is a global counter. Declaring it makes the read and the
        # `+=` below refer to the global in scripts too, not only under the
        # notebook's soft global scope.
        global train_steps
        loss, back = Flux.pullback(ps) do
            model_loss(encoder, decoder, x)
        end
        grad = back(1f0)
        Flux.Optimise.update!(opt, ps, grad)
        # progress meter
        next!(progress; showvalues=[(:loss, loss)])

        # Log to TensorBoard every `display_step` steps, starting at step 0.
        if train_steps % display_step == 0
            with_logger(tblogger) do
                @info "train" loss=loss
            end
        end

        train_steps += 1
    end
    # Save a reconstruction of the fixed batch for this epoch.
    _, _, rec_original = reconstuct(encoder, decoder, original, device)
    image = convert_to_image(rec_original)
    image_path = joinpath(output_dir, "epoch_$(epoch).png")
    save(image_path, image)
    @info "Image saved: $(image_path)"
end



            

│   caller = ip:0x0
└ @ Core :-1


UndefVarError: UndefVarError: len not defined

In [8]:
# save model
# Write the trained encoder and decoder to `model.bson` in the output directory.
model_path = joinpath("/home/cyrine/VAE_Fluxjl/", "model.bson") 
# The `let` rebinds `encoder`/`decoder` to their `cpu(...)` results under the same
# names. Presumably `BSON.@save` stores each variable under its own name, so keeping
# the names keeps the saved keys `encoder`/`decoder`.
# NOTE(review): `Encoder` is a plain struct here; `cpu` may return it unchanged — confirm.
let encoder = cpu(encoder), decoder = cpu(decoder)
    BSON.@save model_path encoder decoder 
    @info "Model saved: $(model_path)"
end

┌ Info: Model saved: /home/cyrine/VAE_Fluxjl/model.bson
└ @ Main In[8]:5
