Importance Weighted Hierarchical Variational Inference

This repo contains source code for the Importance Weighted Hierarchical Variational Inference paper (NeurIPS 2019).

To train a VAE with a 32-dimensional latent space on MNIST, use the following command:

python ./iwhvae_run.py                                              \
    --annealing_factor                   0.95                       \
    --annealing_speed                    100                        \
    --dataset                            dynamic_mnist              \
    --z_dim                              32                         \
    --noise_dim                          32                         \
    --decoder_arch                       300 300                    \
    --encoder_arch                       300 300                    \
    --tau_arch                           300 300                    \
    --tau_use_dreg                                                  \
    --train_batch_size                   256                        \
    --learning_rate                      0.001                      \
    --epochs                             10000                      \
    --evaluate_every                     25                         \
    --q_kl_warm_up_cycles                1                          \
    --q_kl_warm_up_fraction              0.03                       \
    --save_every                         100000                     \
    --train_iwhvi_samples_step_schedule  250,5 500,25 1000,50       \
    --val_batch_size                     50000                      \
    --val_iwae_samples                   1000                       \
    --val_iwhvi_samples                  100                        \
    --save_path                          ./exp_dynamic_mnist_d32

To evaluate a trained model, use the following command:

python ./iwhvae_eval.py ./exp_dynamic_mnist_d32/model-weights-final \
    --n_repeats                          10                         \
    --dataset                            dynamic_mnist              \
    --z_dim                              32                         \
    --noise_dim                          32                         \
    --decoder_arch                       300 300                    \
    --encoder_arch                       300 300                    \
    --tau_arch                           300 300

If you're getting "out of memory" exceptions, add --test_iwae_batch_size N with N smaller than 5000.
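
As a rough illustration of the note above, the following sketch assembles the evaluation command with the extra flag appended. The helper name `build_eval_command` and the value 1000 are hypothetical choices made for this example (any N below 5000 follows the advice); all other arguments are copied from the evaluation command shown earlier. This is not part of the repository, just one way to script the retry.

```python
import shlex


def build_eval_command(weights_path, test_iwae_batch_size=None):
    """Assemble the README's evaluation command as an argv list.

    `build_eval_command` is a hypothetical helper written for this example;
    it is not provided by the repository.
    """
    cmd = [
        "python", "./iwhvae_eval.py", weights_path,
        "--n_repeats", "10",
        "--dataset", "dynamic_mnist",
        "--z_dim", "32",
        "--noise_dim", "32",
        "--decoder_arch", "300", "300",
        "--encoder_arch", "300", "300",
        "--tau_arch", "300", "300",
    ]
    if test_iwae_batch_size is not None:
        # The README suggests an N smaller than 5000 when memory runs out.
        cmd += ["--test_iwae_batch_size", str(test_iwae_batch_size)]
    return cmd


if __name__ == "__main__":
    # 1000 is an arbitrary example value below 5000.
    cmd = build_eval_command(
        "./exp_dynamic_mnist_d32/model-weights-final",
        test_iwae_batch_size=1000,
    )
    # Print the command so it can be inspected or pasted into a shell.
    print(shlex.join(cmd))
    # To run it from the repository root instead of printing:
    # import subprocess; subprocess.run(cmd, check=True)
```

If the error persists, the same call can be repeated with a smaller value.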

Related Links

[ Paper | Blogpost | Talk | Toy Example in Colab ]
