
VQ-VAE + GPT on JAX (and Haiku 📜)

This is an implementation of VQ-VAE with a GPT-style sampler in the JAX and Haiku ecosystem.

Instead of using a PixelCNN to sample from the latent space like the original paper, this implementation uses a GPT-style, decoder-only transformer to generate samples.
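The core of this setup is the vector-quantization step at the heart of a VQ-VAE: each continuous encoder output is snapped to its nearest codebook entry, yielding a discrete code the transformer can model. A minimal NumPy sketch (names and shapes are illustrative; the repository itself uses Haiku/JAX modules):

```python
import numpy as np

def quantize(z_e, codebook):
    """z_e: (..., d) continuous encoder outputs; codebook: (K, d) embeddings.
    Returns (indices, z_q): discrete codes and the quantized vectors."""
    # Squared L2 distance from every latent vector to every codebook entry.
    dists = ((z_e[..., None, :] - codebook) ** 2).sum(axis=-1)  # (..., K)
    indices = dists.argmin(axis=-1)                             # (...,) codes
    return indices, codebook[indices]                           # (..., d) quantized

codebook = np.array([[0.0, 0.0], [1.0, 1.0]])  # toy codebook: K=2 codes, d=2
z_e = np.array([[0.1, -0.1], [0.9, 1.2]])
indices, z_q = quantize(z_e, codebook)
print(indices)  # [0 1]
```

The GPT then only ever sees these integer indices, never the images themselves.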

🌟 Generated samples

*(ten generated MNIST digit samples: generated_0 through generated_9)*

Generated with

python -m generate -p runs/gpt/exp0 -o generated/exp0 -t 0.5 -S 5

🔩 Run it yourself!

Step 0: (Optional, but recommended) Create a virtual environment

Step 1: Install requirements

You will need to install JAX separately, because the installation procedure differs depending on your platform, accelerator, and CUDA version.

Please follow the official JAX installation instructions.

Then, install the project's dependencies:

pip install -r requirements.txt

Step 2: Create & modify train_vqvae_config.json

Create a copy of the included sample JSON file.

cp static/configs/train_vqvae_config.json train_vqvae_config.json

Optionally, you can change the training parameters.

Since this file is likely to change while experimenting, /train_vqvae_config.json is listed in the .gitignore.
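The authoritative set of keys is whatever the sample file in static/configs/ contains. As a purely illustrative sketch (only logdir is confirmed by this walkthrough; the other field names are hypothetical), a config might look like:

```json
{
  "logdir": "runs/vqvae",
  "batch_size": 32,
  "learning_rate": 3e-4,
  "train_steps": 20000
}
```

Check the sample file for the real keys before editing.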

Step 3: Train VQ-VAE!

python -m train_vqvae -f train_vqvae_config.json

In another terminal, open TensorBoard to monitor the training progress.

tensorboard --logdir runs/vqvae

The value you pass to --logdir should match logdir in train_vqvae_config.json. By default, it is runs/vqvae.

Step 4: Encode dataset to prepare for training the GPT

python -m vqvae_encode -p runs/vqvae/exp0/ -o datasets/exp0-encoded

See python -m vqvae_encode -h for usage details.

This iterates over the MNIST dataset and, for each image, adds a column with its indices into the quantized codebook.
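Conceptually, the encoding pass flattens each image's grid of codebook indices into one token sequence and stores it as a new dataset column. A sketch of that transformation (the 3×3 grid and field names are toy assumptions, not this repository's actual shapes):

```python
import numpy as np

def add_indices_column(example, index_grid):
    """example: one dataset row (e.g. {"image": ..., "label": ...});
    index_grid: (H, W) codebook indices from the VQ-VAE encoder."""
    example = dict(example)
    example["indices"] = index_grid.reshape(-1)  # (H*W,) flat token sequence
    return example

row = add_indices_column({"label": 3}, np.arange(9).reshape(3, 3))
print(row["indices"])  # [0 1 2 3 4 5 6 7 8]
```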

Step 5: Create & modify train_gpt_config.json

Create a copy of the included sample JSON file.

cp static/configs/train_gpt_config.json train_gpt_config.json

Optionally, you can change the training parameters.

Since this file is likely to change while experimenting, /train_gpt_config.json is listed in the .gitignore.

Step 6: Train the GPT!

python -m train_gpt -f train_gpt_config.json

The training script prepends the class label on the sequence of encoding indices which allows for conditional generation afterwards.
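One common way to implement this kind of label-conditioning is to give each class a token id just past the codebook vocabulary and prepend it to the sequence; the sketch below follows that convention, though the repository's exact scheme may differ:

```python
import numpy as np

CODEBOOK_SIZE = 8  # illustrative codebook size, not this repo's value

def with_label(indices, label):
    """Prepend the class label as the first token of an index sequence."""
    label_token = CODEBOOK_SIZE + label  # label tokens live after code tokens
    return np.concatenate([[label_token], indices])

seq = with_label(np.array([3, 1, 4, 1, 5]), label=7)
print(seq)  # [15  3  1  4  1  5]
```

At sampling time, feeding the desired label token as the prompt then steers the GPT toward that class.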

In another terminal, open TensorBoard to monitor the training progress.

tensorboard --logdir runs/gpt

The value you pass to --logdir should match logdir in train_gpt_config.json. By default, it is runs/gpt.

Step 7: Generate samples!

python -m generate -p runs/gpt/exp0 -o generated/exp0 -t 0.4 

See python -m generate -h for usage details.
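A sampling temperature like the one set by -t typically works by dividing the logits by t before the softmax, so t < 1 sharpens the distribution toward the most likely token and t > 1 flattens it. An illustrative sketch (not this repository's code):

```python
import numpy as np

def sample_token(logits, t, rng):
    """Draw one token id from temperature-scaled logits."""
    scaled = logits / t
    probs = np.exp(scaled - scaled.max())  # numerically stable softmax
    probs /= probs.sum()
    return rng.choice(len(logits), p=probs)

logits = np.array([2.0, 1.0, 0.1])
rng = np.random.default_rng(0)
draws = [sample_token(logits, t=0.05, rng=rng) for _ in range(5)]
print(draws)  # at a very low temperature the argmax token (0) dominates
```

This is why the commands above use moderate temperatures (0.4–0.5): low enough for clean digits, high enough for variety.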

Voilà! View your generated samples in generated/exp0!

References

  1. Aaron van den Oord, Oriol Vinyals, Koray Kavukcuoglu. Neural Discrete Representation Learning. 2017. arXiv:1711.00937.
  2. Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz Kaiser, Illia Polosukhin. Attention Is All You Need. 2017. arXiv:1706.03762.
  3. Alec Radford, Karthik Narasimhan. Improving Language Understanding by Generative Pre-Training. 2018.

Code references

  1. DeepMind Haiku examples
