added training commands, checkpoints
arash-vahdat committed Nov 19, 2020
1 parent 9a40297 commit 18350c7
Showing 4 changed files with 154 additions and 13 deletions.
README.md (148 changes: 137 additions & 11 deletions)
@@ -11,7 +11,7 @@
likelihood-based generative models on several image datasets.

<p align="center">
<img src="scripts/celebahq.png" width="800">
<img src="img/celebahq.png" width="800">
</p>

## Requirements
@@ -94,11 +94,15 @@ python create_ffhq_lmdb.py --ffhq_img_path=$DATA_DIR/ffhq/images1024x1024/ --ffh
We use the following commands for training NVAE on each dataset in Table 1 of the
[paper](https://arxiv.org/pdf/2007.03898.pdf). Normalizing flows are enabled for all datasets except MNIST.
Check Table 6 in the paper for more information on training details. Note that for multinode training
(experiments with more than 8 GPUs), we use the `mpirun` command to run the training scripts on multiple nodes.
Please adjust the commands below according to your setup.
Below, `IP_ADDR` is the IP address of the machine that will host the process with rank 0
(see [here](https://pytorch.org/tutorials/intermediate/dist_tuto.html#initialization-methods)), and
`NODE_RANK` is the index of each node among all the nodes that are running the job.
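
Under the hood, these flags typically combine into a global rank and world size for `torch.distributed`.
Here is a minimal sketch assuming a standard NCCL setup; the function, variable names, and port below are
illustrative assumptions, not necessarily this repo's exact code:

```python
# Hedged sketch: how --master_address (IP_ADDR), --node_rank (NODE_RANK),
# --num_proc_node and --num_process_per_node typically combine into a global
# rank and world size for torch.distributed. Names and port are illustrative.
import os
import torch.distributed as dist

def init_distributed(master_address, node_rank, num_proc_node,
                     num_process_per_node, local_rank):
    os.environ['MASTER_ADDR'] = master_address          # host of the rank-0 process
    os.environ['MASTER_PORT'] = '6020'                  # any free port shared by all nodes
    world_size = num_proc_node * num_process_per_node   # total processes (one per GPU)
    global_rank = node_rank * num_process_per_node + local_rank
    dist.init_process_group(backend='nccl', init_method='env://',
                            rank=global_rank, world_size=world_size)
```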

<details><summary>MNIST</summary>

Two 16-GB V100 GPUs are used for training NVAE on dynamically binarized MNIST. Training takes about 21 hours.

```shell script
export EXPR_ID=UNIQUE_EXPR_ID
```
@@ -116,7 +120,7 @@ python train.py --data $DATA_DIR/mnist --root $CHECKPOINT_DIR --save $EXPR_ID --
</details>

<details><summary>CIFAR-10</summary>

Eight 16-GB V100 GPUs are used for training NVAE on CIFAR-10. Training takes about 55 hours.

```shell script
export EXPR_ID=UNIQUE_EXPR_ID
```
@@ -134,7 +138,7 @@ python train.py --data $DATA_DIR/cifar10 --root $CHECKPOINT_DIR --save $EXPR_ID
</details>

<details><summary>CelebA 64</summary>

Eight 16-GB V100 GPUs are used for training NVAE on CelebA 64. Training takes about 92 hours.

```shell script
export EXPR_ID=UNIQUE_EXPR_ID
```
@@ -152,24 +156,110 @@ python train.py --data $DATA_DIR/celeba64_lmdb --root $CHECKPOINT_DIR --save $EXPR_ID
</details>

<details><summary>ImageNet 32x32</summary>

24 16-GB V100 GPUs are used for training NVAE on ImageNet 32x32. Training takes about 70 hours.

```shell script
export EXPR_ID=UNIQUE_EXPR_ID
export DATA_DIR=PATH_TO_DATA_DIR
export CHECKPOINT_DIR=PATH_TO_CHECKPOINT_DIR
export CODE_DIR=PATH_TO_CODE_DIR
export IP_ADDR=IP_ADDRESS
export NODE_RANK=NODE_RANK_BETWEEN_0_TO_2
cd $CODE_DIR
mpirun --allow-run-as-root -np 3 -npernode 1 bash -c \
'python train.py --data $DATA_DIR/imagenet-oord/imagenet-oord-lmdb_32 --root $CHECKPOINT_DIR --save $EXPR_ID --dataset imagenet_32 \
--num_channels_enc 192 --num_channels_dec 192 --epochs 45 --num_postprocess_cells 2 --num_preprocess_cells 2 \
--num_latent_scales 1 --num_latent_per_group 20 --num_cell_per_cond_enc 2 --num_cell_per_cond_dec 2 \
--num_preprocess_blocks 1 --num_postprocess_blocks 1 --num_groups_per_scale 28 \
--batch_size 24 --num_nf 1 --warmup_epochs 1 \
--weight_decay_norm 1e-2 --weight_decay_norm_anneal --weight_decay_norm_init 1e0 \
--num_process_per_node 8 --use_se --res_dist \
--fast_adamax --node_rank $NODE_RANK --num_proc_node 3 --master_address $IP_ADDR '
```
</details>

<details><summary>CelebA HQ 256</summary>

24 32-GB V100 GPUs are used for training NVAE on CelebA HQ 256. Training takes about 94 hours.

```shell script
export EXPR_ID=UNIQUE_EXPR_ID
export DATA_DIR=PATH_TO_DATA_DIR
export CHECKPOINT_DIR=PATH_TO_CHECKPOINT_DIR
export CODE_DIR=PATH_TO_CODE_DIR
export IP_ADDR=IP_ADDRESS
export NODE_RANK=NODE_RANK_BETWEEN_0_TO_2
cd $CODE_DIR
mpirun --allow-run-as-root -np 3 -npernode 1 bash -c \
'python train.py --data $DATA_DIR/celeba/celeba-lmdb --root $CHECKPOINT_DIR --save $EXPR_ID --dataset celeba_256 \
--num_channels_enc 30 --num_channels_dec 30 --epochs 300 --num_postprocess_cells 2 --num_preprocess_cells 2 \
--num_latent_scales 5 --num_latent_per_group 20 --num_cell_per_cond_enc 2 --num_cell_per_cond_dec 2 \
--num_preprocess_blocks 1 --num_postprocess_blocks 1 --weight_decay_norm 1e-2 --num_groups_per_scale 16 \
--batch_size 4 --num_nf 2 --ada_groups --min_groups_per_scale 4 \
--weight_decay_norm_anneal --weight_decay_norm_init 1. --num_process_per_node 8 --use_se --res_dist \
--fast_adamax --num_x_bits 5 --node_rank $NODE_RANK --num_proc_node 3 --master_address $IP_ADDR '
```

In our early experiments, a smaller model with 24 channels instead of 30 could be trained on only 8 GPUs in
the same amount of time (with a batch size of 6). The smaller model obtains a negative log-likelihood that is
only 0.01 bpd higher.
</details>

<details><summary>FFHQ 256</summary>

24 32-GB V100 GPUs are used for training NVAE on FFHQ 256. Training takes about 160 hours.

```shell script
export EXPR_ID=UNIQUE_EXPR_ID
export DATA_DIR=PATH_TO_DATA_DIR
export CHECKPOINT_DIR=PATH_TO_CHECKPOINT_DIR
export CODE_DIR=PATH_TO_CODE_DIR
export IP_ADDR=IP_ADDRESS
export NODE_RANK=NODE_RANK_BETWEEN_0_TO_2
cd $CODE_DIR
mpirun --allow-run-as-root -np 3 -npernode 1 bash -c \
'python train.py --data $DATA_DIR/ffhq/ffhq-lmdb --root $CHECKPOINT_DIR --save $EXPR_ID --dataset ffhq \
--num_channels_enc 30 --num_channels_dec 30 --epochs 200 --num_postprocess_cells 2 --num_preprocess_cells 2 \
--num_latent_scales 5 --num_latent_per_group 20 --num_cell_per_cond_enc 2 --num_cell_per_cond_dec 2 \
--num_preprocess_blocks 1 --num_postprocess_blocks 1 --weight_decay_norm 1e-1 --num_groups_per_scale 16 \
--batch_size 4 --num_nf 2 --ada_groups --min_groups_per_scale 4 \
--weight_decay_norm_anneal --weight_decay_norm_init 1. --num_process_per_node 8 --use_se --res_dist \
--fast_adamax --num_x_bits 5 --learning_rate 8e-3 --node_rank $NODE_RANK --num_proc_node 3 --master_address $IP_ADDR '
```

In our early experiments, a smaller model with 24 channels instead of 30 could be trained on only 8 GPUs in
the same amount of time (with a batch size of 6). The smaller model obtains a negative log-likelihood that is
only 0.01 bpd higher.
</details>

**If for any reason your training is stopped, rerun exactly the same command with the addition of `--cont_training`
to continue training from the last saved checkpoint. If you observe NaN, continuing the training using this flag
usually will not fix the NaN issue.**

## Known Issues
<details><summary>Cannot build CelebA 64 or training gives NaN right at the beginning on this dataset </summary>

Several users have reported issues building CelebA 64 or have encountered NaN at the beginning of training on this dataset.
If you face similar issues on this dataset, you can download the dataset manually and build the LMDBs using the instructions
in this issue: https://github.com/NVlabs/NVAE/issues/2
</details>

<details><summary>Getting NaN after a few epochs of training </summary>

One of the main challenges in training very deep hierarchical VAEs is the training instability that we discussed in the paper.
We have verified that the settings in the commands above can be trained in a stable way. If you modify the settings
above and encounter NaN after a few epochs of training, you can use these tricks to stabilize your training:
i) increase the spectral regularization coefficient `--weight_decay_norm`; ii) apply exponential decay to
`--weight_decay_norm` using `--weight_decay_norm_anneal` and `--weight_decay_norm_init` (see the sketch below);
iii) decrease the learning rate.
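
As a rough illustration of trick (ii), the flag names suggest a log-space (exponential) decay of the
spectral-regularization coefficient from `--weight_decay_norm_init` down to `--weight_decay_norm`. A minimal
sketch; using plain training progress for the interpolation is an assumption, as the actual code may tie the
decay to the KL warmup schedule instead:

```python
# Hedged sketch: exponential (log-space) decay of the spectral-regularization
# coefficient from --weight_decay_norm_init to --weight_decay_norm. Using plain
# training progress for `alpha` is an assumption; the actual code may tie the
# decay to the KL warmup schedule instead.
import numpy as np

def sr_coeff(step, total_steps, wdn_init=1e0, wdn_final=1e-2):
    alpha = min(step / float(total_steps), 1.0)  # training progress in [0, 1]
    log_coeff = (1.0 - alpha) * np.log(wdn_init) + alpha * np.log(wdn_final)
    return float(np.exp(log_coeff))              # moves smoothly from wdn_init to wdn_final
```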
</details>

<details><summary>Training freezes with no NaN </summary>

In some very rare cases, we observed that training freezes after 2-3 days. We believe the root cause
is a race condition in one of the low-level libraries. If the training freezes, kill your current run and rerun
exactly the same command with the addition of `--cont_training`
to continue training from the last saved checkpoint.
</details>

## Monitoring the training progress
While running any of the commands above, you can monitor the training progress using Tensorboard:
@@ -183,7 +273,7 @@ Above, `$CHECKPOINT_DIR` and `$EXPR_ID` are the same variables used for running the training commands.

</details>

## Post-training sampling, evaluation, and checkpoints

<details><summary>Evaluation</summary>

@@ -211,6 +301,19 @@
where `--temp` sets the temperature used for sampling and `--readjust_bn` enables readjustment of the BN statistics
as described in the paper. If you remove `--readjust_bn`, sampling will proceed with BN layers in eval mode
(i.e., BN layers will use the running means and variances extracted during training).
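
Based on the `set_bn(...)` call visible in the `evaluate.py` diff below, the readjustment amounts to running
the sampler for a number of iterations with BN layers in train mode so that their running statistics adapt to
the chosen temperature. A hedged sketch; the exact implementation lives in this repo's utilities and may differ:

```python
# Hedged sketch of BN readjustment (--readjust_bn), inferred from the
# set_bn(model, bn_eval_mode, num_samples, t, iter) call in evaluate.py below.
import torch

def set_bn(model, bn_eval_mode, num_samples=36, t=1.0, iter=500):
    if bn_eval_mode:
        model.eval()            # BN layers use stored running statistics
    else:
        model.train()           # BN layers update running statistics
        with torch.no_grad():
            for _ in range(iter):
                model.sample(num_samples, t)   # sampling passes at temperature t
        model.eval()
```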

</details>

<details><summary>Checkpoints</summary>

We provide checkpoints for MNIST, CIFAR-10, CelebA 64, CelebA HQ 256, and FFHQ in
[this Google Drive directory](https://drive.google.com/drive/folders/1KVpw12AzdVjvbfEYM_6_3sxTy93wWkbe?usp=sharing).
For CIFAR-10, we provide two checkpoints, as we observed that a multiscale NVAE gives better qualitative
results than a single-scale model on this dataset. The multiscale model is only slightly worse in terms
of log-likelihood (by 0.01 bpd). We also observe that one of our early models on CelebA HQ 256, with a likelihood
that is 0.01 bpd worse, generates much better images at low temperature on this dataset.

You can use the commands above to evaluate or sample from these checkpoints.
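
As a minimal sketch, loading one of these checkpoints and sampling from it follows the loading logic visible in
the `evaluate.py` diff below; the checkpoint path, sample count, and the final decoder-output call are
illustrative assumptions (see `evaluate.py` for the exact sampling loop):

```python
# Hedged sketch: load a downloaded checkpoint and draw samples, mirroring the
# loading logic in evaluate.py below. The path and the decoder-output handling
# are illustrative assumptions.
import torch
from torch.cuda.amp import autocast
import utils
from model import AutoEncoder

checkpoint = torch.load('checkpoints/cifar10.pt', map_location='cpu')  # illustrative path
args = checkpoint['args']
arch_instance = utils.get_arch_cells(args.arch_instance)
model = AutoEncoder(args, None, arch_instance)
# strict=False: older checkpoints may lack buffers such as weight_normalized
model.load_state_dict(checkpoint['state_dict'], strict=False)
model = model.cuda().eval()

with torch.no_grad():
    with autocast():
        logits = model.sample(16, 0.7)         # 16 samples at temperature 0.7
    output = model.decoder_output(logits)      # distribution over pixels
    images = output.sample()                   # check evaluate.py for the exact call
```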

</details>

## How to construct smaller NVAE models
@@ -243,11 +346,34 @@ We use two schemes for setting the number of groups:
the total number of groups by reducing `--num_groups_per_scale` and `--min_groups_per_scale`
when `--ada_groups` is enabled.
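
For intuition, a hedged sketch of the schedule implied by `--ada_groups`: start from `--num_groups_per_scale`
and halve the number of groups at each scale, never dropping below `--min_groups_per_scale` (the exact rule
lives in this repo's utilities and may differ):

```python
# Hedged sketch of the adaptive group schedule implied by --ada_groups.
def groups_per_scale(num_scales, num_groups_per_scale, ada_groups, min_groups_per_scale=1):
    groups = []
    n = num_groups_per_scale
    for _ in range(num_scales):
        groups.append(n)
        if ada_groups:
            n = max(min_groups_per_scale, n // 2)  # halve, floored at the minimum
    return groups

print(groups_per_scale(5, 16, True, 4))   # -> [16, 8, 4, 4, 4]
```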

## Understanding the implementation
If you are modifying the code, you can use the following figure to map the code to the paper.

<p align="center">
<img src="img/model_diagram.png" width="900">
</p>


## Traversing the latent space
We can generate images by traversing the latent space of NVAE. This sequence is generated using our model
trained on CelebA HQ by interpolating between samples generated with temperature 0.6.
Some artifacts are due to color quantization in GIFs.

<p align="center">
<img src="https://drive.google.com/uc?id=1k_s_TCdblNRI6MG_X1tji9VoOPumCzz9" width="512">
</p>

## License
Please check the LICENSE file. NVAE may be used non-commercially, meaning for research or
evaluation purposes only. For business inquiries, please contact
[researchinquiries@nvidia.com](mailto:researchinquiries@nvidia.com).

You should take into consideration that VAEs are trained to mimic the training data distribution; any
bias introduced in data collection will make VAEs generate samples with a similar bias. Additional bias could be
introduced during model design, training, or when VAEs are sampled using small temperatures. Bias correction in
generative learning is an active area of research, and we recommend that interested readers check this area before
building applications using NVAE.

## Bibtex:
Please cite our paper if you happen to use this codebase:

evaluate.py (19 changes: 17 additions & 2 deletions)
@@ -43,10 +43,25 @@ def main(eval_args):
checkpoint = torch.load(eval_args.checkpoint, map_location='cpu')
args = checkpoint['args']

if not hasattr(args, 'ada_groups'):
logging.info('old model, no ada groups was found.')
args.ada_groups = False

if not hasattr(args, 'min_groups_per_scale'):
logging.info('old model, no min_groups_per_scale was found.')
args.min_groups_per_scale = 1

if not hasattr(args, 'num_mixture_dec'):
logging.info('old model, no num_mixture_dec was found.')
args.num_mixture_dec = 10

logging.info('loaded the model at epoch %d', checkpoint['epoch'])
arch_instance = utils.get_arch_cells(args.arch_instance)
model = AutoEncoder(args, None, arch_instance)
model.load_state_dict(checkpoint['state_dict'])
# Loading is not strict because of self.weight_normalized in Conv2D class in neural_operations. This variable
# is only used for computing the spectral normalization and it is safe not to load it. Some of our earlier models
# did not have this variable.
model.load_state_dict(checkpoint['state_dict'], strict=False)
model = model.cuda()

logging.info('args = %s', args)
@@ -78,7 +93,7 @@ def main(eval_args):
with torch.no_grad():
n = int(np.floor(np.sqrt(num_samples)))
set_bn(model, bn_eval_mode, num_samples=36, t=eval_args.temp, iter=500)
for ind in range(10): # sampling is repeated.
torch.cuda.synchronize()
start = time()
with autocast():
scripts/celebahq.png → img/celebahq.png (file renamed without changes)
Binary file added img/model_diagram.png
