This is an implementation of VQ-VAE with a GPT-style sampler in the JAX and Haiku ecosystem.
Instead of using a PixelCNN to sample from the latent space as in the original paper, this implementation uses a GPT-style, decoder-only transformer to generate samples.
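For orientation, here is a minimal sketch of the vector-quantization bottleneck using Haiku's built-in `hk.nets.VectorQuantizer`; the layer sizes and hyperparameters are illustrative, not the ones used by this repo's configs.

```python
# Minimal VQ bottleneck sketch (illustrative hyperparameters).
import haiku as hk
import jax
import jax.numpy as jnp

def forward(images):
    # Toy encoder: downsample the image into a grid of latent vectors.
    encoder = hk.Sequential([
        hk.Conv2D(32, kernel_shape=4, stride=2), jax.nn.relu,
        hk.Conv2D(64, kernel_shape=4, stride=2), jax.nn.relu,
        hk.Conv2D(64, kernel_shape=1),
    ])
    z_e = encoder(images)  # [B, H', W', embedding_dim]

    vq = hk.nets.VectorQuantizer(
        embedding_dim=64, num_embeddings=512, commitment_cost=0.25)
    # vq_out["quantize"] feeds the decoder; vq_out["encoding_indices"]
    # is the [B, H', W'] grid of codebook indices the GPT later models.
    return vq(z_e, is_training=True)

model = hk.transform(forward)
images = jnp.zeros((1, 28, 28, 1))  # MNIST-shaped dummy batch
params = model.init(jax.random.PRNGKey(0), images)
out = model.apply(params, None, images)
print(out["encoding_indices"].shape)  # (1, 7, 7)
```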
Generated with:

```
python -m generate -p runs/gpt/exp0 -o generated/exp0 -t 0.5 -S 5
```
You will need to install JAX separately, because the installation procedure differs depending on your platform, accelerator, and CUDA version. Please follow the official JAX installation instructions.
Then, install the project's dependencies:

```
pip install -r requirements.txt
```
Create a copy of the included sample JSON file.
```
cp static/configs/train_vqvae_config.json train_vqvae_config.json
```
Optionally, you can change the training parameters. Since this file is likely to be changed while experimenting, `/train_vqvae_config.json` is included in the `.gitignore`.
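The only config key this README relies on is `logdir`, whose value you pass to `tensorboard --logdir` below. A minimal sketch of reading it; any other keys in the real file are specific to the training script:

```python
# Hypothetical sketch: read logdir from the config file. Only
# "logdir" (default "runs/vqvae") is confirmed by this README.
import json

with open("train_vqvae_config.json") as f:
    config = json.load(f)

print(config.get("logdir", "runs/vqvae"))  # value for tensorboard --logdir
```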
Start training the VQ-VAE:

```
python -m train_vqvae -f train_vqvae_config.json
```
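Following the VQ-VAE paper, the training objective combines a reconstruction term with the codebook/commitment loss that the quantizer emits; a sketch, with illustrative weighting:

```python
# Sketch of the VQ-VAE objective: reconstruction error plus the
# codebook/commitment term returned in the quantizer's output dict.
import jax.numpy as jnp

def vqvae_loss(reconstruction, images, vq_loss):
    recon = jnp.mean((reconstruction - images) ** 2)
    return recon + vq_loss

print(vqvae_loss(jnp.zeros((1, 28, 28, 1)), jnp.ones((1, 28, 28, 1)), 0.25))
```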
In another terminal, open TensorBoard to monitor the training progress:

```
tensorboard --logdir runs/vqvae
```

Your value for `--logdir` is the value of `logdir` in `train_vqvae_config.json`. By default, it is `runs/vqvae`.
Next, encode the dataset with the trained VQ-VAE:

```
python -m vqvae_encode -p runs/vqvae/exp0/ -o datasets/exp0-encoded
```

See `python -m vqvae_encode -h` for usage details.
This goes through the MNIST dataset and adds a column containing each image's indices into the quantized codebook.
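Conceptually, "adding a column" means storing each image's grid of codebook indices as one flat token sequence. A sketch with hypothetical column names (the real script's layout may differ):

```python
# Sketch: flatten the [H', W'] index grid into a token sequence and
# attach it to a dataset row. Column names here are hypothetical.
import numpy as np

def add_codes_column(row, encoding_indices):
    row = dict(row)
    row["codes"] = np.asarray(encoding_indices).ravel()  # [H' * W'] tokens
    return row

row = {"image": np.zeros((28, 28, 1)), "label": 3}
row = add_codes_column(row, np.zeros((7, 7), dtype=np.int32))
print(row["codes"].shape)  # (49,)
```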
Create a copy of the included sample JSON file.
```
cp static/configs/train_gpt_config.json train_gpt_config.json
```
Optionally, you can change the training parameters. Since this file is likely to be changed while experimenting, `/train_gpt_config.json` is included in the `.gitignore`.
Start training the transformer:

```
python -m train_gpt -f train_gpt_config.json
```
The training script prepends the class label to the sequence of encoding indices, which allows for conditional generation afterwards.
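A sketch of one way such sequences can be formed; the exact token layout is an assumption, not necessarily what the script does:

```python
# Sketch: prepend the class label as an extra token. Giving labels
# ids above the codebook range is an assumption for illustration.
import numpy as np

NUM_CODES = 512  # illustrative codebook size

def with_label(codes, label):
    return np.concatenate([[NUM_CODES + label], codes])

codes = np.zeros(49, dtype=np.int64)  # one encoded image
seq = with_label(codes, label=3)      # [label, idx_0, ..., idx_48]
print(seq[:3])                        # [515 0 0]
```

At generation time, seeding the transformer with the desired class token is what makes the sampling conditional.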
In another terminal, open TensorBoard to monitor the training progress:

```
tensorboard --logdir runs/gpt
```

Your value for `--logdir` is the value of `logdir` in `train_gpt_config.json`. By default, it is `runs/gpt`.
Finally, generate samples:

```
python -m generate -p runs/gpt/exp0 -o generated/exp0 -t 0.4
```

See `python -m generate -h` for usage details.
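The `-t` flag sets the sampling temperature. A sketch of what temperature sampling typically does (the actual generate code may differ): dividing logits by a temperature below 1 sharpens the next-token distribution, while values above 1 flatten it.

```python
# Sketch of temperature sampling over the transformer's logits.
import jax
import jax.numpy as jnp

def sample_next_token(rng, logits, temperature=0.4):
    return jax.random.categorical(rng, logits / temperature)

rng = jax.random.PRNGKey(0)
logits = jnp.array([2.0, 1.0, 0.5, 0.1])
print(sample_next_token(rng, logits))  # index of the sampled token
```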
Voilà! View your generated samples in `generated/exp0`!
- Aaron van den Oord, Oriol Vinyals, Koray Kavukcuoglu. Neural Discrete Representation Learning. 2017. arXiv:1711.00937.
- Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz Kaiser, Illia Polosukhin. Attention Is All You Need. 2017. arXiv:1706.03762.
- Alec Radford, Karthik Narasimhan, Tim Salimans, Ilya Sutskever. Improving Language Understanding by Generative Pre-Training. 2018.