# VARIATIONAL AUTOENCODER (VAE)
---

VAE implemented in `Julia` using the `Flux.jl` library

In [1]:
versioninfo() # -> v"1.11.2"

Julia Version 1.11.2
Commit 5e9a32e7af2 (2024-12-01 20:02 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 8 × Intel(R) Core(TM) i7-8565U CPU @ 1.80GHz
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, skylake)
Threads: 1 default, 0 interactive, 1 GC (on 8 virtual cores)
Environment:
  LD_LIBRARY_PATH = /home/mhamdi/torch/install/lib:/home/mhamdi/torch/install/lib:/home/mhamdi/torch/install/lib:
  DYLD_LIBRARY_PATH = /home/mhamdi/torch/install/lib:/home/mhamdi/torch/install/lib:/home/mhamdi/torch/install/lib:
  JULIA_NUM_THREADS = 8


In [2]:
using Pkg; pkg"activate ."; pkg"status"

[32m[1m  Activating[22m[39m project at `~/Work/git-repos/AI-ML-DL/jlai/Codes/Julia/Part-3/vae`


[32m[1mStatus[22m[39m `~/Work/git-repos/AI-ML-DL/jlai/Codes/Julia/Part-3/vae/Project.toml`
  [90m[587475ba] [39mFlux v0.16.0
  [90m[eb30cadb] [39mMLDatasets v0.7.18
  [90m[91a5bcdd] [39mPlots v1.40.9
  [90m[c3e4b0f8] [39mPluto v0.20.4
  [90m[7f904dfe] [39mPlutoUI v0.7.60
  [90m[92933f4c] [39mProgressMeter v1.10.2
  [90m[d6f4376e] [39mMarkdown v1.11.0


Import the machine learning library `Flux`

In [3]:
using Flux # v"0.16.0"
using Flux: @functor
using Flux: DataLoader
using Flux: onecold, onehotbatch

In [4]:
using ProgressMeter: Progress, next!

In [5]:
using MLDatasets
d = MNIST()

dataset MNIST:
  metadata  =>    Dict{String, Any} with 3 entries
  split     =>    :train
  features  =>    28×28×60000 Array{Float32, 3}
  targets   =>    60000-element Vector{Int64}

In [6]:
Base.@kwdef mutable struct HyperParams
    η = 3f-3                        # Learning rate
    λ = 1f-2                        # Regularization parameter
    batchsize = 64                  # Batch size
    epochs = 16                     # Number of epochs
    split = :train                  # Split data into `train` and `test`
    input_dim = 28*28               # Input dimension
    hidden_dim = 512                # Hidden dimension
    latent_dim = 2                  # Latent dimension
    # save_path = "Output"          # Results folder
end

HyperParams

Load the **MNIST** dataset

In [7]:
function get_data(; kws...)
    args = HyperParams(; kws...);
    # Split data
    data = MNIST(split=args.split);
    X = reshape(data.features, (args.input_dim, :));
    loader = DataLoader(X; batchsize=args.batchsize, shuffle=true);
    return loader
end

get_data (generic function with 1 method)

In [8]:
train_loader = get_data();
test_loader = get_data(split=:test);

Define the `encoder` network: The encoder network should return the parameters of the _latent distribution_ (μ and σ).

In [9]:
struct Encoder
    linear
    μ
    log_σ
end

In [10]:
@functor Encoder

[33m[1m│ [22m[39mMost likely, you should write `Flux.@layer MyLayer`which will add various convenience methods for your type,such as pretty-printing and use with Adapt.jl.
[33m[1m│ [22m[39mHowever, this is not required. Flux.jl v0.15 uses Functors.jl v0.5,which makes exploration of most nested `struct`s opt-out instead of opt-in...so Flux will automatically see inside any custom struct definitions.
[33m[1m│ [22m[39mIf you really want to apply the `@functor` macro to a custom struct, use `Functors.@functor` instead.
[33m[1m└ [22m[39m[90m@ Flux ~/.julia/packages/Flux/Mhg1r/src/deprecations.jl:101[39m


In [11]:
encoder(input_dim::Int, hidden_dim::Int, latent_dim::Int) = Encoder(
    Dense(input_dim, hidden_dim, tanh),   # linear
    Dense(hidden_dim, latent_dim),        # μ
    Dense(hidden_dim, latent_dim),        # log_σ
)

encoder (generic function with 1 method)

In [12]:
function (encoder::Encoder)(x)
    h = encoder.linear(x)
    encoder.μ(h), encoder.log_σ(h)
end

Define the `decoder` network: The decoder network should return the reconstruction of the input data

In [13]:
decoder(input_dim::Int, hidden_dim::Int, latent_dim::Int) = Chain(
    Dense(latent_dim, hidden_dim, tanh),
    Dense(hidden_dim, input_dim)
)

decoder (generic function with 1 method)

Reconstruction of the input data

In [14]:
function vae(x, enc, dec)
    # Encode `x` into the latent space
    μ, log_σ = enc(x)
    # `z` si a sample from the latent distribution
    z = μ + randn(Float32, size(log_σ)) .* exp.(log_σ)
    # Decode the latent representation into a reconstruction of `x`
    x̂ = dec(z)
    # Return μ, log_σ and x̂
    μ, log_σ, x̂
end

vae (generic function with 1 method)

In [15]:
function l(x, enc, dec, λ)
    μ, log_σ, x̂ = vae(x, enc, dec)
    len = size(x)[end]
    # The reconstruction loss measures how well the VAE was able to reconstruct the input data
    logp_x_z = -Flux.Losses.logitbinarycrossentropy(x̂, x, agg=sum) / len
    # The KL divergence loss measures how close the latent distribution is to the normal distribution
    kl_q_p = 5f-1 * sum(@. (-2f0 * log_σ - 1f0 + exp(2f0 * log_σ) + μ^2)) / len
    # L2 Regularization
    reg = λ * sum( θ -> sum(θ.^2), Flux.params(dec) )
    # Sum of the reconstruction loss and the KL divergence loss
    -logp_x_z + kl_q_p + reg
end

l (generic function with 1 method)

In [16]:
function train(; kws...)
    args = HyperParams(; kws...)
    
    # Initialize `encoder` and `decoder`
    enc_mdl = encoder(args.input_dim, args.hidden_dim, args.latent_dim)
    dec_mdl = decoder(args.input_dim, args.hidden_dim, args.latent_dim)
    
    # ADAM optimizers
    opt_enc = Flux.setup(Adam(args.η), enc_mdl)
    opt_dec = Flux.setup(Adam(args.η), dec_mdl)

    for epoch in 1:args.epochs
        printstyled("\t***\t === EPOCH $(epoch) === \t*** \n", color=:magenta, bold=true)
        progress = Progress(length(train_loader))
        for X in train_loader
                loss, back = Flux.pullback(enc_mdl, dec_mdl) do enc, dec
                    l(X, enc, dec, args.λ)
                end
                grad_enc, grad_dec = back(1f0)
                Flux.update!(opt_enc, enc_mdl, grad_enc) # Upd `encoder` params
                Flux.update!(opt_dec, dec_mdl, grad_dec) # Upd `decoder` params
                next!(progress; showvalues=[(:loss, loss)]) 
        end
    end
    
    # Save the model
    #=
    using DrWatson: struct2dict
    using BSON

    mdl_path = joinpath(args.save_path, "vae.bson")
    let args=struct2dict(args)
    	BSON.@save mdl_path encoder decoder args
    	@info "Model saved to $(mdl_path)"
    end
    =#
    
    enc_mdl, dec_mdl
end

train (generic function with 1 method)

In [17]:
enc_model, dec_model = train()

[33m[1m│ [22m[39m - To prevent this behaviour, do `ProgressMeter.ijulia_behavior(:append)`. 
[33m[1m└ [22m[39m[90m@ ProgressMeter ~/.julia/packages/ProgressMeter/kVZZH/src/ProgressMeter.jl:594[39m
[32mProgress:  15%|██████▎                                  |  ETA: 0:12:19[39m
[34m  loss:  192.82417[39m

LoadError: InterruptException: