<a href="https://colab.research.google.com/github/a-mhamdi/jlai/blob/main/Codes/Julia/Part-3/vae/vae.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# VARIATIONAL AUTOENCODER (VAE)
---

**VAE** implemented in `Julia` using the `Flux.jl` library

In [None]:
# Print the Julia version and platform details of the running session.
versioninfo()

Manage project dependencies.

In [None]:
# Text of the `Project.toml` manifest: one `name = "uuid"` entry per dependency
# of this notebook.
pkgs = """[deps]
BSON = "fbb218c0-5317-5bc6-957e-2ee96dd4b1f0"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
CairoMakie = "13f3f980-e62b-5c42-98c6-ff1f3baf88f0"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
ImageInTerminal = "d8c32880-2388-543b-8c61-d9f865259254"
ImageShow = "4e3cecfd-b093-5904-9786-8bbb286a6a31"
JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819"
MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458"
Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a"
ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"
"""

# Create (or overwrite) `Project.toml` in the working directory with that text.
write("Project.toml", pkgs)

Activate the project environment and instantiate the listed packages.

In [None]:
import Pkg

# Bind the result to `_` so the cell does not echo a value.
_ = begin
    Pkg.activate(".")     # switch to the environment in the current directory
    Pkg.instantiate()     # install the dependencies declared for that environment
end

Display the status of the packages in the active project environment.

In [None]:
# List the packages of the currently active environment.
Pkg.status()

Import the Flux library and bring a few of its utilities (`@functor`, `DataLoader`, `onecold`, `onehotbatch`) into scope explicitly.

In [None]:
using Flux # v"0.16.0"
using Flux: @functor
using Flux: DataLoader
using Flux: onecold, onehotbatch

Import the `CUDA` library for **GPU** acceleration.

In [None]:
using CUDA

Import the `ProgressMeter` library for displaying progress during training.

In [None]:
using ProgressMeter: Progress, next!

Load the **MNIST** dataset.

In [None]:
using MLDatasets
# Instantiate the MNIST dataset with its default arguments and bind it to `d`;
# as the last expression of the cell, its summary is displayed by the notebook.
d = MNIST()

Let's define a mutable struct `HyperParams` to hold hyperparameters for the **VAE** model.

In [None]:
# Hyperparameters of the VAE experiment.
#
# Every field has a default and can be overridden by keyword, e.g.
# `HyperParams(epochs=4, split=:test)`. Fields are concretely typed, so an
# override is converted on construction (`η=1e-3` is stored as `Float32`),
# which keeps the arithmetic in the loss and the optimizer in single precision.
Base.@kwdef mutable struct HyperParams
    η::Float32 = 3f-3               # Learning rate
    λ::Float32 = 1f-2               # Regularization parameter
    batchsize::Int = 64             # Batch size
    epochs::Int = 16                # Number of epochs
    split::Symbol = :train          # Dataset split to load (`:train` or `:test`)
    input_dim::Int = 28*28          # Input dimension
    hidden_dim::Int = 512           # Hidden dimension
    latent_dim::Int = 2             # Latent dimension
    # save_path = "Output"          # Results folder
end

Load and prepare the **MNIST** data for training and testing.

In [None]:
"""
    get_data(; kws...)

Build a shuffled mini-batch loader over one MNIST split.

The keyword arguments are forwarded to `HyperParams`; its `split`, `input_dim`
and `batchsize` fields select the split, the number of rows each image is
flattened to, and the batch size. The loader is passed through `gpu` before
being returned.
"""
function get_data(; kws...)
    hp = HyperParams(; kws...)
    features = MNIST(split=hp.split).features
    # One column per image: collapse the leading dimensions into `input_dim` rows.
    flat = reshape(features, (hp.input_dim, :))
    batches = DataLoader(flat; batchsize=hp.batchsize, shuffle=true)
    return gpu(batches)
end

Load the training and testing data using the `get_data` function.

In [None]:
# Loader built with the default hyperparameters (`split = :train`).
train_loader = get_data();
# Loader over the `:test` split; all other hyperparameters keep their defaults.
test_loader = get_data(split=:test);

This struct defines the encoder network for the **VAE**.

In [None]:
"""
    Encoder(linear, μ, log_σ)

Container for the three layers of the encoder network. The forward pass defined
for this type applies `linear` to the input first, then feeds the result to both
`μ` and `log_σ`.

The fields are type parameters rather than untyped (`Any`) fields, so each
`Encoder` instance has a concrete type and calls on its layers can be
specialized by the compiler.
"""
struct Encoder{L,M,S}
    linear::L   # shared first layer
    μ::M        # head producing the mean of the latent distribution
    log_σ::S    # head producing `log_σ` of the latent distribution
end

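This macro registers the `Encoder` struct with Flux so that its fields are traversed as model components, which is what allows Flux to collect its parameters, move it to the GPU, and set up optimizer state for it.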

In [None]:
# Register `Encoder` as a Flux layer so that its fields are traversed by `gpu`,
# `Flux.setup` and friends. `Flux.@layer` is the supported replacement for the
# legacy `@functor` macro in the Flux version this notebook targets (v0.16).
Flux.@layer Encoder

Construct an `Encoder` model with specified input, hidden, and latent dimensions.

In [None]:
"""
    encoder(input_dim::Int, hidden_dim::Int, latent_dim::Int)

Build an `Encoder` made of a `tanh` layer `input_dim => hidden_dim` followed by
two parallel linear heads `hidden_dim => latent_dim` (one for `μ`, one for
`log_σ`), and pass it through `gpu`.
"""
function encoder(input_dim::Int, hidden_dim::Int, latent_dim::Int)
    shared = Dense(input_dim => hidden_dim, tanh)
    μ_head = Dense(hidden_dim => latent_dim)
    log_σ_head = Dense(hidden_dim => latent_dim)
    return gpu(Encoder(shared, μ_head, log_σ_head))
end

Define the forward pass for the encoder, returning the mean `μ` and the log standard deviation `log_σ` of the latent distribution.

In [None]:
# Forward pass of the encoder: run `x` through the shared layer, then through
# both heads, and return the pair `(μ, log_σ)`.
function (encoder::Encoder)(x)
    hidden = encoder.linear(x)
    μ = encoder.μ(hidden)
    log_σ = encoder.log_σ(hidden)
    return μ, log_σ
end

Construct a decoder model with specified input, hidden, and latent dimensions; it maps a latent vector back to a vector of the input dimension.

In [None]:
"""
    decoder(input_dim::Int, hidden_dim::Int, latent_dim::Int)

Build the decoder: a `Chain` of a `tanh` layer `latent_dim => hidden_dim` and a
linear layer `hidden_dim => input_dim`, passed through `gpu`.
"""
function decoder(input_dim::Int, hidden_dim::Int, latent_dim::Int)
    layers = Chain(
        Dense(latent_dim => hidden_dim, tanh),
        Dense(hidden_dim => input_dim),
    )
    return gpu(layers)
end

Define the forward pass for the **VAE**, encoding the input, sampling from the latent space, and decoding to reconstruct the input.

In [None]:
"""
    vae(x, enc, dec)

Run one stochastic forward pass of the VAE and return `(μ, log_σ, x̂)`.

`enc(x)` yields `μ` and `log_σ`; a latent sample is drawn as
`z = μ + ε * exp(log_σ)` with `ε` standard-normal noise, and `x̂ = dec(z)` is the
decoder output for that sample.
"""
function vae(x, enc, dec)
    # Encode `x` into the parameters of the latent distribution
    μ, log_σ = enc(x)
    # `randn` always allocates a CPU `Array`; send the noise through `gpu`, like
    # the models and the data loader in this file, so that the broadcast below
    # does not mix CPU and GPU arrays. `gpu` is a no-op without a usable GPU.
    ε = randn(Float32, size(log_σ)) |> gpu
    # `z` is a sample from the latent distribution
    z = μ .+ ε .* exp.(log_σ)
    # Decode the latent representation into a reconstruction of `x`
    x̂ = dec(z)
    return μ, log_σ, x̂
end

Calculate the loss for the **VAE**, including the reconstruction loss, **KL** divergence loss, and **L2** regularization.

In [None]:
"""
    l(x, enc, dec, λ)

Return the scalar training loss of the VAE on the batch `x`.

The loss is the sum of three terms:

- the reconstruction term: `logitbinarycrossentropy` between the decoder output
  and `x`, summed over all entries and divided by the size of the last
  dimension of `x`;
- the KL divergence between `N(μ, exp(log_σ)^2)` and the standard normal,
  `0.5 * sum(exp(2log_σ) + μ^2 - 1 - 2log_σ)`, divided by the same size;
- an L2 penalty `λ * Σ‖θ‖²` over the trainable arrays of `dec` only.
"""
function l(x, enc, dec, λ)
    μ, log_σ, x̂ = vae(x, enc, dec)
    len = size(x)[end]
    # The reconstruction loss measures how well the VAE was able to reconstruct the input data
    recon = Flux.Losses.logitbinarycrossentropy(x̂, x, agg=sum) / len
    # The KL divergence loss measures how close the latent distribution is to the normal distribution
    kl_q_p = 5f-1 * sum(@. (-2f0 * log_σ - 1f0 + exp(2f0 * log_σ) + μ^2)) / len
    # L2 regularization. `Flux.trainables` is the explicit-mode replacement for
    # the deprecated implicit `Flux.params` and can be differentiated through.
    reg = λ * sum(θ -> sum(abs2, θ), Flux.trainables(dec))
    # Reconstruction loss + KL divergence + regularization
    return recon + kl_q_p + reg
end

Train the **VAE** model for a specified number of epochs using the `Adam` optimizer.

In [None]:
"""
    train(; loader=train_loader, kws...)

Train a fresh encoder/decoder pair and return `(enc_mdl, dec_mdl)`.

# Keywords
- `loader`: iterable of input batches. Defaults to the notebook-level
  `train_loader`, which is what this function previously always used; pass
  another loader (e.g. one built with a different `batchsize`) to train on it.
- `kws...`: forwarded to `HyperParams`. The fields read here are `input_dim`,
  `hidden_dim`, `latent_dim`, `η`, `λ` and `epochs`.

Each model gets its own `Adam(η)` optimizer state. For every batch, the loss
`l` is differentiated with respect to both models and both are updated.
"""
function train(; loader=train_loader, kws...)
    args = HyperParams(; kws...)

    # Initialize `encoder` and `decoder`
    enc_mdl = encoder(args.input_dim, args.hidden_dim, args.latent_dim)
    dec_mdl = decoder(args.input_dim, args.hidden_dim, args.latent_dim)

    # ADAM optimizers
    opt_enc = Flux.setup(Adam(args.η), enc_mdl)
    opt_dec = Flux.setup(Adam(args.η), dec_mdl)

    for epoch in 1:args.epochs
        printstyled("\t***\t === EPOCH $(epoch) === \t*** \n", color=:magenta, bold=true)
        progress = Progress(length(loader))
        for X in loader
            # Loss value and gradients w.r.t. both models in a single pass
            loss, (grad_enc, grad_dec) = Flux.withgradient(enc_mdl, dec_mdl) do enc, dec
                l(X, enc, dec, args.λ)
            end
            Flux.update!(opt_enc, enc_mdl, grad_enc) # Upd `encoder` params
            Flux.update!(opt_dec, dec_mdl, grad_dec) # Upd `decoder` params
            next!(progress; showvalues=[(:loss, loss)])
        end
    end

    # Save the model (disabled sketch).
    # To enable it: move the `using` lines to top level, restore the `save_path`
    # field of `HyperParams`, and save the trained models (`enc_mdl`, `dec_mdl`)
    # rather than the `encoder`/`decoder` constructor functions.
    #=
    using DrWatson: struct2dict
    using BSON

    mdl_path = joinpath(args.save_path, "vae.bson")
    let args=struct2dict(args)
        BSON.@save mdl_path enc_mdl dec_mdl args
        @info "Model saved to $(mdl_path)"
    end
    =#

    return enc_mdl, dec_mdl
end

Initiate the training process for the **VAE** model.

In [None]:
# Train with the default hyperparameters and keep the two returned models.
enc_model, dec_model = train()