Note: This is a work in progress.
This repo's purpose is to serve as a cleaner, more feature-rich implementation of VQGAN - Taming Transformers for High-Resolution Image Synthesis - written in PyTorch from scratch, building on the initial work of dome272's repo. There's also a great video explaining VQGAN by dome272.
I created this repo to better understand VQGAN myself, and to provide scripts for faster training and experimentation with toy datasets like MNIST. I also tried to make it as clean as possible, with comments, logging, testing & coverage, custom datasets & visualizations, etc.
VQGAN stands for Vector Quantised Generative Adversarial Network. The main idea behind the paper is to use a CNN to learn the visual parts of an image and build a codebook of context-rich visual parts, and then use Transformers to learn the long-range/global interactions between those visual parts embedded in the codebook. Combining the two, we can generate very high-resolution images.
Learning these short-range and long-range interactions to generate high-resolution images is done in two stages.
- The first stage uses VQGAN to learn a codebook of context-rich visual representations of the images. In terms of architecture, it is very similar to VQVAE in that it consists of an encoder, a decoder and the codebook. We will learn more about this in the next section.
- The second stage uses a transformer to learn the global interactions between the codebook vectors by predicting the next code from the previous ones, which is then used to generate high-resolution images.
The architecture of VQGAN consists of three major parts, the encoder, the decoder and the codebook, similar to the VQVAE paper.
- The encoder (`encoder.py`) learns to represent the images in a much lower dimension, called embeddings or latents, and consists of Convolution, Downsample, Residual blocks and special attention blocks (Non-Local blocks), around 30 million parameters in the default settings.
- The embeddings are then quantized using the codebook, and the quantized embeddings are used as input to the decoder (`decoder.py`).
- The decoder takes the "quantized" embeddings and reconstructs the image. The architecture is similar to the encoder but reversed, around 40 million parameters in the default settings, slightly more than the encoder due to a larger number of residual blocks.
The main idea behind the codebook and quantization is to convert the continuous latent representation into a discrete one. The codebook is simply a list of n latent vectors (which are learned during training); the latents produced by the encoder are replaced with the closest vector (in terms of distance) from the codebook. This is where the VQ part of the name comes from.
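A minimal sketch of this nearest-neighbour lookup (illustrative only, not the repo's exact codebook implementation):

```python
import torch
import torch.nn as nn

num_vectors, latent_dim = 512, 64
codebook = nn.Embedding(num_vectors, latent_dim)  # the n learnable latent vectors

def quantize(z):
    # z: encoder output of shape (batch, latent_dim, height, width)
    b, c, h, w = z.shape
    z_flat = z.permute(0, 2, 3, 1).reshape(-1, c)      # (b*h*w, latent_dim)
    distances = torch.cdist(z_flat, codebook.weight)   # pairwise distances to codebook entries
    indices = distances.argmin(dim=1)                  # index of the closest codebook vector
    z_q = codebook(indices).reshape(b, h, w, c).permute(0, 3, 1, 2)
    z_q = z + (z_q - z).detach()                       # straight-through gradient estimator
    return z_q, indices.reshape(b, h, w)

z_q, idx = quantize(torch.randn(2, latent_dim, 7, 7))  # z_q keeps the encoder output's shape
```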
Training involves sending a batch of images through the encoder, quantizing the embeddings with the codebook, and then sending the quantized embeddings through the decoder to reconstruct the images. The loss function is computed as follows:
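(Written in the standard VQ-VAE-style notation, where $\text{sg}$ is the stop-gradient operator, $E$ the encoder, $\hat{x}$ the reconstruction, $z_q$ the quantized latent and $\beta$ the commitment weight:)

$\mathcal{L}_{\text{VQ}}(E, G, \mathcal{Z}) = \|x-\hat{x}\|^{2} + \|\text{sg}[E(x)] - z_q\|_2^2 + \beta\,\|\text{sg}[z_q] - E(x)\|_2^2$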
The above equation is the sum of the reconstruction, alignment and commitment losses:
- Reconstruction loss

  Apparently, there is some confusion about whether the reconstruction loss was replaced with a perceptual loss or whether it is a combination of the two; we go with what was implemented in the official code (CompVis/taming-transformers#40), which is l1 + perceptual loss.

  The reconstruction loss is the sum of the l1 loss and the perceptual loss. $\text{L1 Loss}=\sum_{i=1}^{n}\left|y_{\text{true},i}-y_{\text{predicted},i}\right|$. The perceptual loss is calculated as the l2 distance between the last-layer outputs of a pre-trained model (like VGG) for the generated and the original image.
- The alignment and commitment losses come from the quantization step; they compare the distance between the latent vectors from the encoder output and the closest vectors from the codebook. $\text{sg}$ here means the stop-gradient function.
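The adversarial part of the objective, roughly in the paper's notation, is:

$\mathcal{L}_{\text{GAN}}(\{E, G, \mathcal{Z}\}, D) = \log D(x) + \log\left(1 - D(\hat{x})\right)$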
The above loss is for the discriminator, which takes in real and generated images and learns to classify which one is real and which is fake. The GAN in VQGAN comes from here :)
The discriminator here is a bit different from conventional discriminators: instead of classifying whole images as real or fake, it converts the images into patches using convolutions and then predicts whether each patch is real or fake.
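A minimal sketch of such a patch-based discriminator (a generic PatchGAN-style convolution stack; the layer sizes are illustrative and not necessarily the exact architecture used in this repo):

```python
import torch
import torch.nn as nn

class PatchDiscriminator(nn.Module):
    def __init__(self, in_channels=3, base_channels=64):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(in_channels, base_channels, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(base_channels, base_channels * 2, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(base_channels * 2),
            nn.LeakyReLU(0.2),
            # 1-channel output map: one real/fake logit per image patch
            nn.Conv2d(base_channels * 2, 1, kernel_size=4, stride=1, padding=1),
        )

    def forward(self, x):
        return self.net(x)

patch_logits = PatchDiscriminator()(torch.randn(1, 3, 64, 64))  # e.g. shape (1, 1, 15, 15)
```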
We calculate lambda as the ratio between the gradient of the reconstruction loss and the gradient of the GAN loss, both taken with respect to the last layer of the decoder (see `calculate_lambda` in `vqgan.py`).
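In the paper's notation this is, roughly, $\lambda = \frac{\nabla_{G_L}[\mathcal{L}_{\text{rec}}]}{\nabla_{G_L}[\mathcal{L}_{\text{GAN}}] + \delta}$, where $G_L$ denotes the last layer of the decoder and $\delta$ is a small constant for numerical stability. A minimal sketch of how such a gradient ratio can be computed with autograd (a hypothetical helper, not the repo's exact `calculate_lambda`):

```python
import torch

def calculate_lambda_sketch(rec_loss, gan_loss, last_layer_weight, delta=1e-6):
    # Gradients of each loss with respect to the decoder's last layer weights.
    rec_grads = torch.autograd.grad(rec_loss, last_layer_weight, retain_graph=True)[0]
    gan_grads = torch.autograd.grad(gan_loss, last_layer_weight, retain_graph=True)[0]
    # Ratio of the gradient norms, clamped and detached so it acts as a fixed weight.
    lam = torch.norm(rec_grads) / (torch.norm(gan_grads) + delta)
    return torch.clamp(lam, 0, 1e4).detach()
```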
The final loss is then the combination of the reconstruction loss, the alignment and commitment losses, and the discriminator loss multiplied by lambda.
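For reference, the complete stage-1 objective as written in the paper is roughly:

$Q^{*} = \arg\min_{E, G, \mathcal{Z}} \max_{D} \; \mathbb{E}_{x \sim p(x)}\left[\mathcal{L}_{\text{VQ}}(E, G, \mathcal{Z}) + \lambda\, \mathcal{L}_{\text{GAN}}(\{E, G, \mathcal{Z}\}, D)\right]$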
To generate images with VQGAN, we take the quantized vectors produced in Stage 2 and pass them through the decoder to reconstruct the image.
This stage contains Transformers 🤖 which are trained to predict the next latent vector (codebook index) from the sequence of previous latent vectors in the quantized encoder output. The implementation uses `mingpt.py` from Andrej Karpathy's karpathy/minGPT repo.

Due to the computational cost of generating high-resolution images, a sliding attention window is also used, so the next latent vector is predicted from its neighbouring vectors in the quantized encoder output.
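As a rough sketch of what this autoregressive sampling over codebook indices looks like (assuming a trained GPT-style `transformer` that returns logits over the codebook; names and shapes are illustrative, not this repo's exact API):

```python
import torch
import torch.nn.functional as F

@torch.no_grad()
def sample_indices(transformer, start_indices, num_steps, window=256, temperature=1.0):
    # start_indices: (batch, t) codebook indices to condition on, e.g. a start-of-sequence token.
    indices = start_indices
    for _ in range(num_steps):
        context = indices[:, -window:]           # sliding attention window
        logits = transformer(context)            # (batch, t, codebook_size)
        logits = logits[:, -1, :] / temperature  # logits for the next position
        probs = F.softmax(logits, dim=-1)
        next_index = torch.multinomial(probs, num_samples=1)
        indices = torch.cat([indices, next_index], dim=1)
    return indices  # map these through the codebook and decoder to get an image
```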
- Clone the repo: https://github.com/Shubhamai/pytorch-vqgan
- Create a new conda environment using `conda env create --prefix env python=3.7.13 --file=environment.yml`
- Activate the conda environment using `conda activate ./env`
- You can start the training by running `python train.py`. It reads the default config file from `configs/default.yml`. To change the config path, run `python train.py --config_path configs/default.yaml`. Here's what the script mostly does:
  - Downloads the MNIST dataset automatically and saves it in the data directory (specified in the config).
  - Trains the VQGAN and the transformer model on the MNIST train set with the parameters passed from the config file.
  - Saves the training metrics, visualizations and models in the `experiments/` directory, at the path specified in the config file.
- Run `aim up` to open the experiment tracker and see the metrics and the reconstructed & generated images.
To generate images, simply run `python generate.py`; the models will be loaded from `experiments/checkpoints` and the outputs will be saved in the `experiments` directory.
I have also just started getting my feet wet with testing and automated testing with GitHub CI/CD, so the tests here might not follow best practices.
To run the tests, run `pytest --cov-config=.coveragerc --cov=. test`
The hardware I tried the model on, with the default settings:
- Ryzen 5 4600H
- NVIDIA GeForce GTX 1660Ti - 6 GB VRAM
- 12 GB RAM
It took around 2-3 minutes to get good reconstruction results. Since Google Colab has similar hardware in terms of compute power, from what I understand, it should run just fine on Colab :)
The list below contains some helpful blogs and videos that helped me a lot in understanding VQGAN.
- The Illustrated VQGAN by Lj Miranda
- VQGAN: Taming Transformers for High-Resolution Image Synthesis [Paper Explained] by Gradient Dude
- VQ-GAN: Taming Transformers for High-Resolution Image Synthesis | Paper Explained by The AI Epiphany
- VQ-GAN | Paper Explanation and VQ-GAN | PyTorch Implementation by Outlier
- TL#006 Robin Rombach Taming Transformers for High Resolution Image Synthesis, by one of the paper's authors, Robin Rombach. Thanks for the talk :)
@misc{esser2020taming,
title={Taming Transformers for High-Resolution Image Synthesis},
author={Patrick Esser and Robin Rombach and Björn Ommer},
year={2020},
eprint={2012.09841},
archivePrefix={arXiv},
primaryClass={cs.CV}
}