Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
summmeer committed Nov 28, 2022
1 parent 8bfafcb commit bfd0ffd
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 2 deletions.
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ Arguments explanation:
- ```--resume_checkpoint```: if not none, restore this checkpoint and continue training
- ```--vocab```: the tokenizer is initialized using bert or load your own preprocessed vocab dictionary (e.g. using BPE)

It will take around 2 days to train a __*DiffuSeq*__ model on 2 NVIDIA A100 GPUs for QG and QQP, and the training steps should be increased accordingly along with the size of the training set. To reproduce the results of Table 1 in our paper, we suggest the following configuration for each dataset when training.
It will take 2 more days to train a __*DiffuSeq*__ model on 4 NVIDIA A100 80G GPUs for QG and QQP, and the training steps should be increased accordingly along with the size of the training set. To reproduce the results of Table 1 in our paper, we suggest the following configuration for each dataset when training.

```
python -m torch.distributed.launch --nproc_per_node=4 --master_port=12233 --use_env run_train.py --diff_steps 2000 --lr 0.0001 --learning_steps 50000 --save_interval 10000 --seed 102 --noise_schedule sqrt --hidden_dim 128 --bsz 2048 --dataset qqp --data_dir {datasets/QQP} --vocab bert --seq_len 128 --schedule_sampler lossaware --notes qqp
Expand Down Expand Up @@ -86,6 +86,7 @@ python eval_seq2seq.py --folder ../{your-path-to-outputs} --mbr
```
Note: if you want to use this evaluation script for output files from other models, please make sure the same line from these output files refers to the same piece of data. Otherwise the diversity score could be incorrect.

> Update 28 Nov 2022: We prepare the checkpoint and sampling results of 10 seeds for QQP dataset in this [link](https://drive.google.com/drive/folders/1vnhJIUqPQva_x_sH2h5a0moCc1NYmEpr?usp=sharing).
Welcome to discuss if you have any questions.

Expand Down
2 changes: 1 addition & 1 deletion sample_seq2seq.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ def main():
diffusion.p_sample_loop if not args.use_ddim else diffusion.ddim_sample_loop
)

sample_shape = (batch.shape[0], args.seq_len, args.hidden_dim)
sample_shape = (x_start.shape[0], args.seq_len, args.hidden_dim)

samples = sample_fn(
model,
Expand Down

0 comments on commit bfd0ffd

Please sign in to comment.