A TensorFlow implementation of a vanilla Variational AutoEncoder (VAE) [1,2] with support for several standard datasets:
- mnist
- fashion-mnist
- cifar10 (convolutional network only)
- celeb-a (convolutional network only)

Supported likelihoods:
- Bernoulli
- Gaussian

Supported encoder/decoder network types:
- fully connected
- convolutional

Additional features:
- The VAE can be trained on the subset of a dataset with a specific label by setting the 'class_label' flag.
- The VAE can be trained without the 'V', i.e. as a plain AutoEncoder, by setting the 'AE' flag.
- Several regularization options are available for the fully connected network (dropout, L2 regularization of the network weights).
- Data augmentation is possible by adding noise or applying image rotations.
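The Bernoulli and Gaussian likelihoods, the 'AE' option and the L2 weight penalty all enter through the training loss. The following is a minimal, framework-free sketch of a standard VAE objective (negative ELBO), not the repository's code: the function and argument names (`negative_elbo`, `ae`, `l2_scale`, ...) are hypothetical, the Gaussian likelihood assumes a fixed, shared standard deviation, and the 'AE' case is assumed to amount to dropping the KL term.

```python
import math


def bernoulli_log_likelihood(x, probs, eps=1e-7):
    """Sum over pixels of log Bernoulli(x_i | p_i); x in [0, 1], probs from the decoder."""
    total = 0.0
    for xi, pi in zip(x, probs):
        pi = min(max(pi, eps), 1.0 - eps)  # clip so log() never sees 0
        total += xi * math.log(pi) + (1.0 - xi) * math.log(1.0 - pi)
    return total


def gaussian_log_likelihood(x, means, sigma=1.0):
    """Sum over pixels of log N(x_i | mu_i, sigma^2); a fixed, shared sigma is an assumption."""
    var = sigma ** 2
    return sum(
        -0.5 * math.log(2.0 * math.pi * var) - (xi - mi) ** 2 / (2.0 * var)
        for xi, mi in zip(x, means)
    )


def kl_to_standard_normal(z_mean, z_log_var):
    """Closed-form KL(N(mean, diag(exp(log_var))) || N(0, I)), summed over latent dimensions."""
    return 0.5 * sum(math.exp(lv) + m ** 2 - 1.0 - lv for m, lv in zip(z_mean, z_log_var))


def negative_elbo(x, decoder_out, z_mean, z_log_var, likelihood="bernoulli",
                  ae=False, weights=(), l2_scale=0.0):
    """Per-example loss to minimize: -(log p(x|z) - KL) plus an optional L2 weight penalty.

    With ae=True the KL term is dropped, leaving a plain autoencoder
    reconstruction loss. All argument names here are hypothetical.
    """
    if likelihood == "bernoulli":
        log_lik = bernoulli_log_likelihood(x, decoder_out)
    elif likelihood == "gaussian":
        log_lik = gaussian_log_likelihood(x, decoder_out)
    else:
        raise ValueError(f"unknown likelihood: {likelihood!r}")
    kl = 0.0 if ae else kl_to_standard_normal(z_mean, z_log_var)
    l2_penalty = l2_scale * sum(w * w for w in weights)  # L2 regularization of network weights
    return -(log_lik - kl) + l2_penalty
```

The Bernoulli likelihood suits (near-)binary pixels such as mnist, while the Gaussian one treats pixel intensities as continuous; in a real TensorFlow model these sums would be vectorized over batches and pixels.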
To install, either download the repo and run
pip install -e .
or run
pip install git+https://github.com/VMBoehm/vae
To see all available settings, simply run
python main.py --helpfull
For example,
python main.py --data_set='mnist' --latent_size=8 --network_type='conv'
trains a VAE on the mnist dataset with an 8-dimensional latent space and (de)convolutional networks as encoder and decoder (generator).
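The latent space dimensionality sets how many Gaussian latent variables the encoder parameterizes per image. A minimal sketch of the standard reparameterization trick and prior sampling for such a latent space follows; it uses plain Python instead of TensorFlow, and the names `reparameterize`, `sample_prior` and `LATENT_SIZE` are hypothetical, not taken from the repository.

```python
import math
import random

LATENT_SIZE = 8  # mirrors --latent_size=8 in the example command


def reparameterize(z_mean, z_log_var, rng):
    """Sample z ~ N(mean, diag(exp(log_var))) as z = mean + exp(0.5 * log_var) * eps, eps ~ N(0, I).

    Writing the sample this way keeps z a differentiable function of mean and log_var,
    which is what lets the encoder be trained by backpropagation.
    """
    return [m + math.exp(0.5 * lv) * rng.gauss(0.0, 1.0) for m, lv in zip(z_mean, z_log_var)]


def sample_prior(latent_size, rng):
    """Draw z ~ N(0, I); decoding such draws gives unconditional samples in a standard VAE."""
    return [rng.gauss(0.0, 1.0) for _ in range(latent_size)]


if __name__ == "__main__":
    rng = random.Random(0)
    # Hypothetical encoder output for one image: a mean and log-variance per latent dimension.
    z_mean = [0.1 * i for i in range(LATENT_SIZE)]
    z_log_var = [-1.0] * LATENT_SIZE
    print(reparameterize(z_mean, z_log_var, rng))
    print(sample_prior(LATENT_SIZE, rng))
```

Predicting the log-variance rather than the variance keeps the standard deviation positive without constraining the encoder's output.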
The code automatically saves checkpoints and exports the trained model. During training, several summaries can be visualized with TensorBoard:
tensorboard --logdir='./model/'
Example results (figures): mnist reconstructions with a Gaussian likelihood, a latent space dimensionality of 8 and a fully connected network, and the corresponding samples.
A TensorFlow 2.0 compatible version is available in the tf2 branch.

