<a href="https://colab.research.google.com/github/a-mhamdi/jlai/blob/main/Codes/Julia/Part-3/vae/vae.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# VARIATIONAL AUTOENCODER (VAE)
---

VAE implemented in `Julia` using the `Flux.jl` library

In [1]:
versioninfo()

Julia Version 1.10.8
Commit 4c16ff44be8 (2025-01-22 10:06 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 2 × Intel(R) Xeon(R) CPU @ 2.00GHz
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-15.0.7 (ORCJIT, skylake-avx512)
Threads: 2 default, 0 interactive, 1 GC (on 2 virtual cores)
Environment:
  LD_LIBRARY_PATH = /usr/lib64-nvidia
  JULIA_NUM_THREADS = 2


In [2]:
using Pkg; Pkg.add([ "Flux", "MLDatasets", "ProgressMeter", "LuxCUDA" ])

[32m[1m    Updating[22m[39m registry at `~/.julia/registries/General.toml`
[32m[1m   Resolving[22m[39m package versions...
[32m[1m   Installed[22m[39m cuDNN ────────────────── v1.4.1
[32m[1m   Installed[22m[39m ZipFile ──────────────── v0.10.1
[32m[1m   Installed[22m[39m LLVMLoopInfo ─────────── v1.0.0
[32m[1m   Installed[22m[39m ContextVariablesX ────── v0.1.3
[32m[1m   Installed[22m[39m TimerOutputs ─────────── v0.5.28
[32m[1m   Installed[22m[39m Accessors ────────────── v0.1.42
[32m[1m   Installed[22m[39m ShowCases ────────────── v0.1.0
[32m[1m   Installed[22m[39m NNlib ────────────────── v0.9.28
[32m[1m   Installed[22m[39m NVTX_jll ─────────────── v3.1.1+0
[32m[1m   Installed[22m[39m ScopedValues ─────────── v1.3.0
[32m[1m   Installed[22m[39m BFloat16s ────────────── v0.5.0
[32m[1m   Installed[22m[39m Optimisers ───────────── v0.4.5
[32m[1m   Installed[22m[39m InitialValues ────────── v0.3.1
[32m[1m   Installed[22m[39m

In [3]:
Pkg.status()

[32m[1mStatus[22m[39m `~/.julia/environments/v1.10/Project.toml`
  [90m[336ed68f] [39mCSV v0.10.15
  [90m[a93c6f00] [39mDataFrames v1.7.0
  [90m[587475ba] [39mFlux v0.16.3
  [90m[7073ff75] [39mIJulia v1.26.0
  [90m[d0bbae9a] [39mLuxCUDA v0.3.3
  [90m[eb30cadb] [39mMLDatasets v0.7.18
  [90m[ee78f7c6] [39mMakie v0.22.2
  [90m[91a5bcdd] [39mPlots v1.40.9
  [90m[92933f4c] [39mProgressMeter v1.10.2


Import the machine learning library `Flux`

In [4]:
using Flux # v0.16.3 (see `Pkg.status()` above)
using Flux: @functor
using Flux: DataLoader
using Flux: onecold, onehotbatch

In [5]:
using ProgressMeter: Progress, next!

In [6]:
using MLDatasets
d = MNIST()

This program has requested access to the data dependency MNIST.
which is not currently installed. It can be installed automatically, and you will not see this message again.

Dataset: THE MNIST DATABASE of handwritten digits
Authors: Yann LeCun, Corinna Cortes, Christopher J.C. Burges
Website: http://yann.lecun.com/exdb/mnist/

[LeCun et al., 1998a]
    Y. LeCun, L. Bottou, Y. Bengio, and P. Haffner.
    "Gradient-based learning applied to document recognition."
    Proceedings of the IEEE, 86(11):2278-2324, November 1998

The files are available for download at the offical
website linked above. Note that using the data
responsibly and respecting copyright remains your
responsibility. The authors of MNIST aren't really
explicit about any terms of use, so please read the
website to make sure you want to download the
dataset.



Do you want to download the dataset from ["https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz", "https://ossci-datasets.s3.amazonaws.com/mn

dataset MNIST:
  metadata  =>    Dict{String, Any} with 3 entries
  split     =>    :train
  features  =>    28×28×60000 Array{Float32, 3}
  targets   =>    60000-element Vector{Int64}

In [7]:
Base.@kwdef mutable struct HyperParams
    η = 3f-3                        # Learning rate
    λ = 1f-2                        # Regularization parameter
    batchsize = 64                  # Batch size
    epochs = 16                     # Number of epochs
    split = :train                  # Dataset split to load (`:train` or `:test`)
    input_dim = 28*28               # Input dimension
    hidden_dim = 512                # Hidden dimension
    latent_dim = 2                  # Latent dimension
    # save_path = "Output"          # Results folder
end

HyperParams

Load the **MNIST** dataset

In [8]:
function get_data(; kws...)
    args = HyperParams(; kws...);
    # Load the requested split and flatten each 28×28 image into a column
    data = MNIST(split=args.split);
    X = reshape(data.features, (args.input_dim, :));
    loader = DataLoader(X; batchsize=args.batchsize, shuffle=true) |> gpu;
    return loader
end

get_data (generic function with 1 method)
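
To make the reshaping and batching in `get_data` concrete, here is a minimal sketch that uses a random array as a stand-in for `data.features`. The array `fake`, the sample count of 10, and the batch size of 4 are illustrative assumptions, not values from this notebook.

```julia
using Flux: DataLoader

# Hypothetical stand-in for `data.features`: 10 random 28×28 "images"
fake = rand(Float32, 28, 28, 10)

# Flatten each image into a column, as `get_data` does
X = reshape(fake, (28 * 28, :))
@show size(X)

# Mini-batches are slices along the last dimension;
# the final batch is smaller when the sample count is not a multiple of `batchsize`
loader = DataLoader(X; batchsize=4, shuffle=true)
batch_sizes = [size(b, 2) for b in loader]
@show batch_sizes
```

Each batch is therefore a `784×batchsize` matrix, which is the layout the `Dense` layers below expect.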

In [9]:
train_loader = get_data();
test_loader = get_data(split=:test);

Define the `encoder` network: a shared hidden layer followed by two linear heads that return the parameters of the _latent distribution_, the mean μ and the log standard deviation log σ.

In [10]:
struct Encoder
    linear
    μ
    log_σ
end

In [11]:
Flux.@layer Encoder


In [12]:
encoder(input_dim::Int, hidden_dim::Int, latent_dim::Int) = Encoder(
    Dense(input_dim, hidden_dim, tanh),   # linear
    Dense(hidden_dim, latent_dim),        # μ
    Dense(hidden_dim, latent_dim),        # log_σ
    ) |> gpu

encoder (generic function with 1 method)

In [13]:
function (encoder::Encoder)(x)
    h = encoder.linear(x)
    encoder.μ(h), encoder.log_σ(h)
end

Define the `decoder` network: it maps a latent sample `z` back to the input space and returns the reconstruction as logits (the loss function applies the sigmoid).

In [14]:
decoder(input_dim::Int, hidden_dim::Int, latent_dim::Int) = Chain(
    Dense(latent_dim, hidden_dim, tanh),
    Dense(hidden_dim, input_dim)
    ) |> gpu

decoder (generic function with 1 method)

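Define the forward pass: encode `x`, draw a latent sample `z` with the reparameterization trick, and decode `z` into a reconstruction `x̂`.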

In [15]:
function vae(x, enc, dec)
    # Encode `x` into the latent space
    μ, log_σ = enc(x)
    # `z` is a sample from the latent distribution (reparameterization trick)
    z = μ + randn(Float32, size(log_σ)) .* exp.(log_σ)
    # Decode the latent representation into a reconstruction of `x`
    x̂ = dec(z)
    # Return μ, log_σ and x̂
    μ, log_σ, x̂
end

vae (generic function with 1 method)
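
The sampling line in `vae` writes `z` as a deterministic function of `μ`, `log_σ`, and independent standard normal noise, which keeps the sample differentiable with respect to the encoder outputs. The sketch below checks this numerically; the constant values chosen for `μ` and `log_σ` are hypothetical and stand in for real encoder outputs.

```julia
using Statistics: mean, std

# Hypothetical encoder outputs: every latent coordinate has mean 1.5 and σ = 0.5
μ = fill(1.5f0, 2, 10_000)
log_σ = fill(log(0.5f0), 2, 10_000)

# Same construction as in `vae`: z = μ + ε ⊙ σ with ε ~ N(0, 1)
ε = randn(Float32, size(log_σ))
z = μ .+ ε .* exp.(log_σ)

@show size(z)
@show mean(z) std(z)   # should be close to 1.5 and 0.5
```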

In [16]:
function l(x, enc, dec, λ)
    μ, log_σ, x̂ = vae(x, enc, dec)
    len = size(x)[end]
    # Reconstruction term: Bernoulli log-likelihood of `x` given the decoder logits `x̂`, averaged over the batch
    logp_x_z = -Flux.Losses.logitbinarycrossentropy(x̂, x, agg=sum) / len
    # KL divergence between the latent distribution N(μ, σ²) and the standard normal prior N(0, 1), averaged over the batch
    kl_q_p = 5f-1 * sum(@. (-2f0 * log_σ - 1f0 + exp(2f0 * log_σ) + μ^2)) / len
    # L2 regularization of the decoder parameters
    reg = λ * sum( θ -> sum(θ.^2), Flux.params(dec) )
    # Total loss: negative log-likelihood + KL divergence + L2 penalty
    -logp_x_z + kl_q_p + reg
end

l (generic function with 1 method)
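
For reference, the quantity computed by `l` can be written out as follows. This is a reading of the code above rather than an independent derivation: $B$ denotes the batch size (`len`), $d$ the latent dimension, $s$ the sigmoid, and the expectation over $z$ is approximated with a single sample per input.

$$
\mathcal{L} = \frac{1}{B}\sum_{i=1}^{B}\Big[-\log p(x_i \mid z_i) + D_{\mathrm{KL}}\big(q(z \mid x_i)\,\|\,p(z)\big)\Big] + \lambda \sum_{\theta \in \text{dec}} \lVert\theta\rVert_2^2
$$

$$
-\log p(x \mid z) = -\sum_{k}\Big[x_k \log s(\hat{x}_k) + (1 - x_k)\log\big(1 - s(\hat{x}_k)\big)\Big]
$$

$$
D_{\mathrm{KL}}\big(\mathcal{N}(\mu, \sigma^2)\,\|\,\mathcal{N}(0, 1)\big) = \frac{1}{2}\sum_{j=1}^{d}\Big(\sigma_j^2 + \mu_j^2 - 1 - 2\log\sigma_j\Big)
$$

The last expression matches `kl_q_p` term by term, since $\sigma_j^2 = \exp(2\log\sigma_j)$.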

In [17]:
function train(; kws...)
    args = HyperParams(; kws...)

    # Initialize `encoder` and `decoder`
    enc_mdl = encoder(args.input_dim, args.hidden_dim, args.latent_dim)
    dec_mdl = decoder(args.input_dim, args.hidden_dim, args.latent_dim)

    # ADAM optimizers
    opt_enc = Flux.setup(Adam(args.η), enc_mdl)
    opt_dec = Flux.setup(Adam(args.η), dec_mdl)

    for epoch in 1:args.epochs
        printstyled("\t***\t === EPOCH $(epoch) === \t*** \n", color=:magenta, bold=true)
        progress = Progress(length(train_loader))
        for X in train_loader
            # Forward pass; `back` computes the gradients for both networks
            loss, back = Flux.pullback(enc_mdl, dec_mdl) do enc, dec
                l(X, enc, dec, args.λ)
            end
            grad_enc, grad_dec = back(1f0)
            Flux.update!(opt_enc, enc_mdl, grad_enc) # Update `encoder` params
            Flux.update!(opt_dec, dec_mdl, grad_dec) # Update `decoder` params
            next!(progress; showvalues=[(:loss, loss)])
        end
    end

    # Save the model
    #=
    using DrWatson: struct2dict
    using BSON

    mdl_path = joinpath(args.save_path, "vae.bson")
    let args=struct2dict(args)
    	BSON.@save mdl_path enc_mdl dec_mdl args
    	@info "Model saved to $(mdl_path)"
    end
    =#

    enc_mdl, dec_mdl
end

train (generic function with 1 method)

In [None]:
enc_model, dec_model = train()

[33m[1m│ [22m[39m - To prevent this behaviour, do `ProgressMeter.ijulia_behavior(:append)`. 
[33m[1m└ [22m[39m[90m@ ProgressMeter ~/.julia/packages/ProgressMeter/kVZZH/src/ProgressMeter.jl:594[39m
[32mProgress:  60%|████████████████████████████████████▉                        |  ETA: 0:04:50[39m
[34m  loss:  181.97878[39m
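
Once training has finished, the decoder alone can generate new digits from draws of the standard normal prior. The following is a minimal sketch, not part of the original notebook: `generate` is a hypothetical helper, and the block builds an untrained stand-in decoder with the same layout as `decoder(28*28, 512, 2)` so that it runs on its own. With the notebook's objects, `dec_model` would be passed instead.

```julia
using Flux

# Untrained stand-in with the same layout as the notebook's decoder
# (latent_dim = 2, hidden_dim = 512, input_dim = 28*28)
dec = Chain(Dense(2 => 512, tanh), Dense(512 => 28 * 28))

# Hypothetical helper: decode `n` draws from the standard normal prior
function generate(dec, n; latent_dim=2)
    z = randn(Float32, latent_dim, n)
    x̂ = sigmoid.(dec(z))            # the decoder outputs logits, so map them to [0, 1]
    reshape(x̂, 28, 28, n)           # back to image layout
end

imgs = generate(dec, 16)
@show size(imgs)
```

The sigmoid is needed because the loss uses `logitbinarycrossentropy`, so the decoder is trained to produce logits rather than pixel intensities. If the model lives on a GPU, `z` would have to be moved to the same device before decoding.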