Implements a Variational Auto-Encoder (VAE) [Kingma and Welling, "Auto-Encoding Variational Bayes"] with training code for CelebA, plus inference code for latent-space sampling. Due to download restrictions imposed by Google, you need to get the CelebA data separately and unzip it under `./data/`.
Set up the environment (using miniconda)
```sh
$ conda create --name vae python=3.6.12
$ conda activate vae
$ pip install -r requirements.txt
```
If you don't specify any flags, training runs with `batch=128` and `epochs=50`. During training, generated reconstructions and randomly sampled results are saved under the `results/` folder, and the model weights after each epoch are saved under the `results.model/` folder.
```sh
$ python train_vae.py
```
Reconstruction of input pictures (top row: original, bottom row: reconstructed) after 30 epochs. A smaller `kld_weight` makes the reconstruction a more accurate representation of the original (see below); see Notes on the Loss Function (Training related).
| kld_weight = 0.00025 (small) | kld_weight = 0.025 (large) |
| --- | --- |
However, a larger `kld_weight` (weighting the regularization term) forces the latent space to follow the Normal distribution more strictly and reduces its sparsity. The following table shows that as the latent space is better approximated by the Normal distribution (i.e. larger `kld_weight`), random samples become more accurate (i.e. fewer random colors). 64 pictures generated from randomly sampled latent codes are shown below.
| kld_weight = 0.00025 (small) | kld_weight = 0.025 (large) |
| --- | --- |
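The random-sampling step itself is simple: draw latent codes from the standard Normal prior and decode them. A minimal sketch, where the latent size of 128 and the linear stand-in decoder are assumptions only so the snippet runs on its own (the real decoder is the one trained in `vae.py`):

```python
import torch

# Draw 64 latent codes from the prior N(0, I) and decode them into images.
# The linear "decoder" here is a stand-in so the sketch is self-contained;
# in the repo, the trained decoder from vae.py would be used instead.
torch.manual_seed(0)
latent_dim = 128                                      # assumed latent size
toy_decoder = torch.nn.Linear(latent_dim, 3 * 64 * 64)
decode = lambda z: toy_decoder(z).view(-1, 3, 64, 64)

z = torch.randn(64, latent_dim)                       # z ~ N(0, I)
pics = decode(z)
print(tuple(pics.shape))                              # (64, 3, 64, 64)
```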
Interpolation results between two random CelebA pictures are saved as `rndpics_interpolate.{gif,png}`:

```sh
$ python test_vae.py
```
Below are some examples of interpolation between two pictures, obtained by sampling from the latent space and generating an image for each latent code.
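The interpolation between two pictures boils down to a straight line between their latent codes; each intermediate code is then decoded into one frame. A sketch of the code-blending step (the latent size of 128 is an assumption for illustration):

```python
import torch

def lerp_codes(z1, z2, steps=8):
    """Linearly blend two latent codes into `steps` intermediate codes."""
    ts = torch.linspace(0.0, 1.0, steps).view(-1, 1)  # blend factors in [0, 1]
    return (1 - ts) * z1 + ts * z2                    # shape: (steps, latent_dim)

# Stand-ins for the encoded latent codes of two pictures.
z1, z2 = torch.randn(1, 128), torch.randn(1, 128)
codes = lerp_codes(z1, z2)
print(tuple(codes.shape))  # (8, 128)
```

Decoding each row of `codes` produces the frames that are assembled into the GIF.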
Another way to experiment with the latent space is, instead of interpolating between two pictures, to interpolate each latent entry individually and look for one that controls a meaningful feature, such as "smile". For example, with the default checkpoint in `test_vae.py`, interpolating latent vector entry 3 controls the smile on the person's face.
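Sweeping a single entry can be sketched as follows; the function name and the linear stand-in decoder are illustrative assumptions (in the repo, the decoder loaded from the checkpoint in `test_vae.py` would be used):

```python
import torch

def sweep_entry(decode, z, entry=3, lo=-3.0, hi=3.0, steps=8):
    """Vary a single latent entry while keeping the rest of z fixed."""
    frames = []
    for v in torch.linspace(lo, hi, steps):
        z_mod = z.clone()
        z_mod[0, entry] = v            # overwrite only the chosen entry
        frames.append(decode(z_mod))
    return torch.cat(frames, dim=0)    # (steps, C, H, W)

# Stand-in decoder so the sketch runs on its own.
torch.manual_seed(0)
toy_decoder = torch.nn.Linear(128, 3 * 64 * 64)
decode = lambda z: toy_decoder(z).view(-1, 3, 64, 64)

grid = sweep_entry(decode, torch.randn(1, 128), entry=3)
print(tuple(grid.shape))  # (8, 3, 64, 64)
```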
In Kingma and Welling's paper, the ELBO (Evidence Lower Bound) is the objective function to be maximized, given in Eqn. (3) of the paper as follows:

$$\mathcal{L}(\theta, \phi; x) = -D_{KL}\big(q_\phi(z|x) \,\|\, p_\theta(z)\big) + \mathbb{E}_{q_\phi(z|x)}\big[\log p_\theta(x|z)\big]$$
The loss can be calculated analytically by assuming:

- Reconstruction Loss (First Term): The distribution $p_\theta(x|z)$ represents the generator under a noisy observation model, where $G_\theta(z)$ is the mean and $\eta I$ the variance of the normal distribution $\mathcal{N}(x; G_\theta(z), \eta I)$. The first term then reduces to the reconstruction error.
- Regularization Loss, $D_{KL}$: The analytical expression of $D_{KL}$ between two Normal distributions is given as:

$$D_{KL}(\mathcal{N}_1 \,\|\, \mathcal{N}_2) = \log \Big(\frac{\sigma_2}{\sigma_1}\Big) + \frac{\sigma_1^2 + (\mu_1 - \mu_2)^2}{2\sigma_2^2} - \frac{1}{2}$$

  We can obtain the closed-form expression for $D_{KL}(q_\phi(z|x) \,\|\, p_\theta(z))$ in the ELBO by assuming Normal distributions, where the prior $p_\theta(z)$ is $\mathcal{N}(z; 0, I)$ and the estimated posterior $q_\phi(z|x)$ is $\mathcal{N}(z; \mu, \sigma^2)$. Hence, for $(\sigma_2, \mu_2) = (1, 0)$ and $(\sigma_1, \mu_1) = (\sigma, \mu)$ the expression becomes:

$$D_{KL}(q_\phi(z|x) \,\|\, p_\theta(z)) = -\frac{1}{2} \big(1 + \log(\sigma^2) - \mu^2 - \sigma^2\big)$$
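The closed form can be sanity-checked numerically against `torch.distributions` (a quick verification sketch, not part of the repo; the sample values of `mu` and `log_var` are arbitrary):

```python
import torch
from torch.distributions import Normal, kl_divergence

# Compare the closed-form KL expression with PyTorch's reference
# implementation for a few arbitrary (mu, sigma) pairs.
mu = torch.tensor([0.5, -1.0, 2.0])
log_var = torch.tensor([0.1, -0.3, 0.7])
sigma = (0.5 * log_var).exp()

closed_form = -0.5 * (1 + log_var - mu.pow(2) - log_var.exp())
reference = kl_divergence(Normal(mu, sigma), Normal(0.0, 1.0))

print(torch.allclose(closed_form, reference, atol=1e-6))  # True
```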
Note that in the above equations, the regularization loss (KL divergence) is summed across the latent dimensions of the vectors.
```python
MSE = F.mse_loss(recon_x, x)                                      # reconstruction term
KLD = -0.5 * torch.mean(1 + log_var - mu.pow(2) - log_var.exp())  # closed-form KL term
kld_weight = 0.00025  # 0.025
loss = MSE + kld_weight * KLD
```
Note that a smaller `kld_weight` (0.00025) reduces the effect of the regularization term (i.e. the latent space does not strictly follow a Normal distribution and is sparse), causing the sampling space to be more colorful (see Table: Random Sampling). However, it helps the reconstructions to be closer to the originals (see Table: Reconstruction). For a larger `kld_weight` (0.025), sampled images look better (less colorful), but reconstructions become much more different from the input.
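The `mu` and `log_var` fed into the loss come from the encoder via the reparameterization trick, which the snippet above leaves implicit. A minimal sketch of that step together with the loss wrapped as a function (the function names are assumptions; the repo's versions live in `vae.py` / `train_vae.py`):

```python
import torch
import torch.nn.functional as F

def reparameterize(mu, log_var):
    """z = mu + sigma * eps with eps ~ N(0, I), keeping sampling differentiable."""
    std = (0.5 * log_var).exp()
    eps = torch.randn_like(std)
    return mu + eps * std

def vae_loss(recon_x, x, mu, log_var, kld_weight=0.00025):
    mse = F.mse_loss(recon_x, x)
    kld = -0.5 * torch.mean(1 + log_var - mu.pow(2) - log_var.exp())
    return mse + kld_weight * kld

# Shape check with dummy tensors standing in for a real batch.
x = torch.rand(4, 3, 64, 64)
mu, log_var = torch.zeros(4, 128), torch.zeros(4, 128)
z = reparameterize(mu, log_var)
loss = vae_loss(x, x, mu, log_var)  # perfect "reconstruction": MSE term is 0
print(z.shape, float(loss))         # torch.Size([4, 128]) 0.0
```

With `mu = 0` and `log_var = 0` the KL term is also exactly zero, so the dummy loss comes out as 0.0.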
[1] The training code (`train_vae.py`) is mainly from the official PyTorch examples/vae.
[2] AntixK's GitHub page is a nice resource for various VAE algorithms. I mainly borrowed the code in `vae.py` from `class VanillaVAE` in `vanilla_vae.py`. The code is written in a very organized and modular way using pytorch-lightning, which automatically uses DDP for multi-GPU training. However, the modularity makes it too complex to debug/understand for someone who is new, and there is no inference script for quick testing.
[3] Got some inspiration for the inference code (`test_vae.py`) from Moshe Sipper's Medium post and GitHub repo.
[4] Used some ideas/code about sampling from the latent space, plotting with matplotlib, generating GIFs and using PCA from Tingsong Ou's Medium post and Alexander van de Kleut's GitHub repo.