<a href="https://colab.research.google.com/github/a-mhamdi/jlai/blob/main/Codes/Julia/Part-3/vae/vae.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# VARIATIONAL AUTOENCODER (VAE)
---

**VAE** implemented in `Julia` using the `Flux.jl` library

In [1]:
versioninfo()

Julia Version 1.11.5
Commit 760b2e5b739 (2025-04-14 06:53 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 2 × Intel(R) Xeon(R) CPU @ 2.00GHz
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, skylake-avx512)
Threads: 2 default, 0 interactive, 1 GC (on 2 virtual cores)
Environment:
  LD_LIBRARY_PATH = /usr/lib64-nvidia
  JULIA_NUM_THREADS = auto


Write a `Project.toml` file listing the project dependencies.

In [2]:
pkgs = """[deps]
BSON = "fbb218c0-5317-5bc6-957e-2ee96dd4b1f0"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
CairoMakie = "13f3f980-e62b-5c42-98c6-ff1f3baf88f0"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
ImageInTerminal = "d8c32880-2388-543b-8c61-d9f865259254"
ImageShow = "4e3cecfd-b093-5904-9786-8bbb286a6a31"
JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819"
MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458"
Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a"
ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"
"""

open("Project.toml", "w") do file
    write(file, pkgs)
end

607

Activate the project environment and instantiate the listed packages.

In [3]:
_ = begin
  import Pkg;
  Pkg.activate(".");
  Pkg.instantiate();
end

[32m[1m  Activating[22m[39m project at `/content`
[92m[1mPrecompiling[22m[39m project...
  15086.0 ms  ✓ CUDA → SpecialFunctionsExt
  14899.5 ms  ✓ Atomix → AtomixCUDAExt
  15229.6 ms  ✓ CUDA → ChainRulesCoreExt
  14451.8 ms  ✓ CUDA → EnzymeCoreExt
  14928.3 ms  ✓ MLDataDevices → MLDataDevicesCUDAExt
  16803.4 ms  ✓ StridedViews → StridedViewsCUDAExt
  11294.7 ms  ✓ NNlib → NNlibCUDAExt
  14711.2 ms  ✓ cuDNN
  14986.5 ms  ✓ MLDataDevices → MLDataDevicescuDNNExt
  15400.3 ms  ✓ NNlib → NNlibCUDACUDNNExt
  16882.5 ms  ✓ Flux → FluxCUDAcuDNNExt
  11 dependencies successfully precompiled in 64 seconds. 446 already precompiled.


Display the status of the packages in the active project environment.

In [4]:
Pkg.status()

[32m[1mStatus[22m[39m `/content/Project.toml`
  [90m[fbb218c0] [39mBSON v0.3.9
[33m⌅[39m [90m[052768ef] [39mCUDA v5.8.5
  [90m[13f3f980] [39mCairoMakie v0.15.6
  [90m[587475ba] [39mFlux v0.16.5
  [90m[d8c32880] [39mImageInTerminal v0.5.4
  [90m[4e3cecfd] [39mImageShow v0.3.8
[33m⌅[39m [90m[033835bb] [39mJLD2 v0.5.15
  [90m[eb30cadb] [39mMLDatasets v0.7.18
  [90m[92933f4c] [39mProgressMeter v1.11.0
  [90m[10745b16] [39mStatistics v1.11.1
  [90m[02a925ec] [39mcuDNN v1.4.5
  [90m[d6f4376e] [39mMarkdown v1.11.0
[36m[1mInfo[22m[39m Packages marked with [33m⌅[39m have new versions available but compatibility constraints restrict them from upgrading. To see why use `status --outdated`


Import the `Flux` library and a few of the names it provides.

In [5]:
using Flux # v"0.16.0"
using Flux: @functor
using Flux: DataLoader
using Flux: onecold, onehotbatch

Import the `CUDA` library for **GPU** acceleration.

In [6]:
using CUDA

Import the `ProgressMeter` library for displaying progress during training.

In [7]:
using ProgressMeter: Progress, next!

Load the **MNIST** dataset.

In [8]:
using MLDatasets
d = MNIST()

dataset MNIST:
  metadata  =>    Dict{String, Any} with 3 entries
  split     =>    :train
  features  =>    28×28×60000 Array{Float32, 3}
  targets   =>    60000-element Vector{Int64}

Let's define a mutable struct `HyperParams` to hold hyperparameters for the **VAE** model.

In [9]:
Base.@kwdef mutable struct HyperParams
    η = 3f-3                        # Learning rate
    λ = 1f-2                        # Regularization parameter
    batchsize = 64                  # Batch size
    epochs = 16                     # Number of epochs
    split = :train                  # Dataset split to load (`:train` or `:test`)
    input_dim = 28*28               # Input dimension
    hidden_dim = 512                # Hidden dimension
    latent_dim = 2                  # Latent dimension
    # save_path = "Output"          # Results folder
end

HyperParams
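
Because `Base.@kwdef` generates a keyword constructor, any default can be overridden at construction time, and because the struct is mutable, fields can also be changed afterwards. A minimal sketch, using a hypothetical two-field `DemoParams` in place of the full `HyperParams`:

```julia
# Hypothetical, trimmed-down stand-in for `HyperParams`
Base.@kwdef mutable struct DemoParams
    η = 3f-3          # Learning rate
    latent_dim = 2    # Latent dimension
end

defaults = DemoParams()                 # every field takes its default
custom = DemoParams(latent_dim = 8)     # override one field by keyword
custom.η = 1f-3                         # mutate a field after construction

println((defaults.latent_dim, custom.latent_dim, custom.η))
```

This is the same mechanism that lets a call such as `get_data(split=:test)` forward its keyword straight to `HyperParams`.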

Load and prepare the **MNIST** data for training and testing.

In [10]:
function get_data(; kws...)
    args = HyperParams(; kws...);
    # Load the requested split and flatten each 28×28 image into a column
    data = MNIST(split=args.split);
    X = reshape(data.features, (args.input_dim, :));
    loader = DataLoader(X; batchsize=args.batchsize, shuffle=true) |> gpu;
    return loader
end

get_data (generic function with 1 method)
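
The `reshape` call turns each 28×28 image into one column, so the loader iterates over 784×batch matrices. The sketch below reproduces that shape logic on random data, with `Iterators.partition` as a plain-Julia stand-in for `DataLoader` (no shuffling, no GPU transfer). The array `images` and its size are made up for illustration.

```julia
# Ten fake "images" with the MNIST layout: height × width × samples
images = rand(Float32, 28, 28, 10)

# Flatten every image into a column: (28*28) × 10
X = reshape(images, (28 * 28, :))

# Column ranges of length ≤ 4, mimicking `batchsize = 4`
batches = [X[:, idx] for idx in Iterators.partition(1:size(X, 2), 4)]

println(size(X))           # (784, 10)
println(size.(batches))    # the last batch holds the leftover columns
```

Since Julia arrays are column-major, column `i` of `X` is exactly `vec(images[:, :, i])`.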

Load the training and testing data using the `get_data` function.

In [11]:
train_loader = get_data();
test_loader = get_data(split=:test);

This struct defines the encoder network for the **VAE**.

In [12]:
struct Encoder
    linear
    μ
    log_σ
end

`Flux.@layer` registers `Encoder` as a Flux layer, adding conveniences such as pretty-printing and Adapt.jl support. Since Flux v0.15 it supersedes `Flux.@functor`, which is deprecated; Flux can already traverse the fields of custom structs without it.

In [13]:
Flux.@layer Encoder


Construct an `Encoder` model with specified input, hidden, and latent dimensions.

In [14]:
encoder(input_dim::Int, hidden_dim::Int, latent_dim::Int) = Encoder(
    Dense(input_dim, hidden_dim, tanh),   # linear
    Dense(hidden_dim, latent_dim),        # μ
    Dense(hidden_dim, latent_dim),        # log_σ
    ) |> gpu

encoder (generic function with 1 method)

Define the forward pass for the encoder, returning the mean and the log standard deviation of the latent distribution.

In [15]:
function (encoder::Encoder)(x)
    h = encoder.linear(x)
    encoder.μ(h), encoder.log_σ(h)
end
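
The encoder is a shared hidden layer followed by two independent linear heads. The second head is read here as a log standard deviation, since `vae` later uses `exp.(log_σ)` as the scale of the noise. A plain-Julia sketch of the same data flow, with hypothetical sizes and random weights standing in for the `Dense` layers:

```julia
# Hypothetical sizes and random weights, for shape-checking only
input_dim, hidden_dim, latent_dim, batch = 784, 16, 2, 5

W,  b  = randn(Float32, hidden_dim, input_dim) .* 0.01f0, zeros(Float32, hidden_dim)
Wμ, bμ = randn(Float32, latent_dim, hidden_dim), zeros(Float32, latent_dim)
Wσ, bσ = randn(Float32, latent_dim, hidden_dim), zeros(Float32, latent_dim)

function encode(x)
    h = tanh.(W * x .+ b)       # shared hidden representation
    μ = Wμ * h .+ bμ            # mean head
    log_σ = Wσ * h .+ bσ        # log standard deviation head
    μ, log_σ
end

x = rand(Float32, input_dim, batch)
μ, log_σ = encode(x)
println(size(μ), " ", size(log_σ))
```

Both outputs have one column per input sample and one row per latent dimension.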

Construct a decoder model that maps a latent vector back to the input dimension through one hidden layer.

In [16]:
decoder(input_dim::Int, hidden_dim::Int, latent_dim::Int) = Chain(
    Dense(latent_dim, hidden_dim, tanh),
    Dense(hidden_dim, input_dim)
    ) |> gpu

decoder (generic function with 1 method)

Define the forward pass for the **VAE**, encoding the input, sampling from the latent space, and decoding to reconstruct the input.

In [17]:
function vae(x, enc, dec)
    # Encode `x` into the latent space
    μ, log_σ = enc(x)
    # `z` is a sample from the latent distribution (reparameterization trick);
    # the noise is moved to the same device as the model via `gpu`
    z = μ + gpu(randn(Float32, size(log_σ))) .* exp.(log_σ)
    # Decode the latent representation into a reconstruction of `x`
    x̂ = dec(z)
    # Return μ, log_σ and x̂
    μ, log_σ, x̂
end

vae (generic function with 1 method)
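
The sampling line is the reparameterization trick: instead of drawing `z` directly, it draws standard normal noise and shifts and scales it, which keeps `z` differentiable with respect to `μ` and `log_σ`. A quick empirical check on scalars, assuming illustrative values for `μ` and `log_σ` that are not taken from the model:

```julia
using Random, Statistics

Random.seed!(0)

μ, log_σ = 1.5f0, -0.5f0          # illustrative values only
ε = randn(Float32, 100_000)       # noise drawn independently of μ and log_σ
z = μ .+ ε .* exp(log_σ)          # reparameterized samples

println((mean(z), std(z)))        # should be close to (μ, exp(log_σ))
```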

Calculate the loss for the **VAE**, including the reconstruction loss, **KL** divergence loss, and **L2** regularization.

In [18]:
function l(x, enc, dec, λ)
    μ, log_σ, x̂ = vae(x, enc, dec)
    len = size(x)[end]
    # Reconstruction term: average log-likelihood of `x` under the decoder (the negative of the reconstruction loss)
    logp_x_z = -Flux.Losses.logitbinarycrossentropy(x̂, x, agg=sum) / len
    # The KL divergence loss measures how close the latent distribution is to the standard normal distribution
    kl_q_p = 5f-1 * sum(@. (-2f0 * log_σ - 1f0 + exp(2f0 * log_σ) + μ^2)) / len
    # L2 regularization on the decoder parameters
    reg = λ * sum( θ -> sum(abs2, θ), Flux.trainables(dec) )
    # Negative ELBO (reconstruction loss + KL divergence) plus the L2 penalty
    -logp_x_z + kl_q_p + reg
end

l (generic function with 1 method)
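
The `kl_q_p` line is the closed-form KL divergence between a diagonal Gaussian with standard deviation `exp(log_σ)` and the standard normal, summed over latent dimensions and averaged over the batch. The sketch below checks the per-dimension formula against a numerical integral; the helper names `kl`, `normal_pdf`, and `kl_numeric` are introduced here for illustration only.

```julia
# Per-dimension KL(N(μ, σ²) ‖ N(0, 1)) with σ = exp(log_σ), as in `kl_q_p`
kl(μ, log_σ) = 0.5 * (-2 * log_σ - 1 + exp(2 * log_σ) + μ^2)

# Density of a univariate normal with mean m and standard deviation s
normal_pdf(z, m, s) = exp(-0.5 * ((z - m) / s)^2) / (s * sqrt(2π))

# Riemann sum of q(z) * log(q(z) / p(z)) on a fine grid
function kl_numeric(μ, log_σ; dz = 1e-3)
    σ = exp(log_σ)
    sum(-12:dz:12) do z
        q = normal_pdf(z, μ, σ)
        p = normal_pdf(z, 0.0, 1.0)
        q > 0 ? q * log(q / p) * dz : 0.0
    end
end

println(kl(0.0, 0.0))                              # identical distributions
println((kl(1.0, -0.3), kl_numeric(1.0, -0.3)))    # the two should agree closely
```

The divergence vanishes only when `μ = 0` and `log_σ = 0`, which is what pulls the encoder's outputs toward the prior.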

Train the **VAE** model for a specified number of epochs using the `Adam` optimizer.

In [19]:
function train(; kws...)
    args = HyperParams(; kws...)

    # Initialize `encoder` and `decoder`
    enc_mdl = encoder(args.input_dim, args.hidden_dim, args.latent_dim)
    dec_mdl = decoder(args.input_dim, args.hidden_dim, args.latent_dim)

    # ADAM optimizers
    opt_enc = Flux.setup(Adam(args.η), enc_mdl)
    opt_dec = Flux.setup(Adam(args.η), dec_mdl)

    for epoch in 1:args.epochs
        printstyled("\t***\t === EPOCH $(epoch) === \t*** \n", color=:magenta, bold=true)
        progress = Progress(length(train_loader))
        for X in train_loader
            loss, back = Flux.pullback(enc_mdl, dec_mdl) do enc, dec
                l(X, enc, dec, args.λ)
            end
            grad_enc, grad_dec = back(1f0)
            Flux.update!(opt_enc, enc_mdl, grad_enc) # Update `encoder` params
            Flux.update!(opt_dec, dec_mdl, grad_dec) # Update `decoder` params
            next!(progress; showvalues=[(:loss, loss)])
        end
    end

    # Save the model
    #=
    using DrWatson: struct2dict
    using BSON

    mdl_path = joinpath(args.save_path, "vae.bson")
    let args=struct2dict(args)
        BSON.@save mdl_path enc_mdl dec_mdl args
        @info "Model saved to $(mdl_path)"
    end
    =#

    enc_mdl, dec_mdl
end

train (generic function with 1 method)
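
The loop follows Flux's explicit-gradient pattern: build optimiser state with `Flux.setup`, compute the loss and its gradients, then apply them with `Flux.update!`. The same three steps on a toy linear regression are sketched below. The model, data, learning rate, and step count are invented for illustration, and `Flux.withgradient` is used here instead of `Flux.pullback` to obtain the loss and the gradient in one call.

```julia
using Flux

# Toy problem: recover y = 2x₁ - x₂ from four samples (one per column)
x = Float32[1 2 3 4; 0 1 0 1]
y = Float32[2 3 6 7]

model = Dense(2 => 1)
opt_state = Flux.setup(Adam(1f-1), model)

loss_fn(m) = Flux.Losses.mse(m(x), y)
loss_before = loss_fn(model)

for step in 1:300
    # Loss value and gradient with respect to the model, in one call
    loss, grads = Flux.withgradient(loss_fn, model)
    Flux.update!(opt_state, model, grads[1])
end

println((loss_before, loss_fn(model)))
```

In `train`, the same pattern runs with two models and two optimiser states, so the encoder and decoder are updated from the same loss value.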

Initiate the training process for the **VAE** model.

In [None]:
enc_model, dec_model = train()
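
Since the decoder's last layer has no activation and the loss uses `logitbinarycrossentropy`, the decoder emits logits rather than pixel intensities. To generate images, latent codes drawn from the prior are decoded and passed through a sigmoid. The sketch below uses a hypothetical random linear map `toy_decoder` as a stand-in for the trained decoder, so it only demonstrates shapes and value ranges.

```julia
# Stand-in for the trained decoder: any map from latent_dim × n to 784 × n logits.
# With the notebook's model this role would be played by `dec_model`.
latent_dim, n = 2, 4
W, b = randn(Float32, 28 * 28, latent_dim), zeros(Float32, 28 * 28)
toy_decoder(z) = W * z .+ b

sigmoid(t) = 1 / (1 + exp(-t))

z = randn(Float32, latent_dim, n)      # latent codes drawn from the prior N(0, I)
logits = toy_decoder(z)                # 784 × n, unbounded
pixels = sigmoid.(logits)              # squashed into [0, 1]
images = reshape(pixels, 28, 28, n)    # back to the MNIST layout

println(size(images))
```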