
Bridging the Gap between Training and Inference for Diffusion Model

This is the official code for the paper "Can Diffusion Model Achieve Better Performance in Text Generation? Bridging the Gap between Training and Inference!"

Highlight

  1. You can post-train your own diffusion model with the two methods below to accelerate inference and achieve better performance!

  2. Extensive experiments show that our method can generate a full 128-token sequence in only 4 denoising steps!

Model Architecture

(figure: model architecture)

Down-Sampling Strategy

(figure: down-sampling strategy)
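For intuition, inference visits only a down-sampled subset of the training timesteps (the --decode_respacing option in the Quick Start below selects such a subset). The snippet is a generic, illustrative sketch of uniform step down-sampling in the spirit of DDIM/improved-DDPM respacing, not the paper's adaptive schedule; the function name and the 2000/4 numbers are assumptions taken from the training config and the highlight above.

import numpy as np

def downsample_timesteps(num_train_steps: int, num_infer_steps: int) -> np.ndarray:
    """Pick an evenly spaced subset of the training timesteps to visit at inference.

    Illustrative only: uniform spacing, not the paper's adaptive down-sampling.
    """
    # Evenly spaced indices from the last (noisiest) timestep down to 0.
    steps = np.linspace(num_train_steps - 1, 0, num_infer_steps)
    return steps.round().astype(int)

# Training uses 2000 diffusion steps (--diff_steps 2000); the highlight above
# reports decoding 128-token sequences in only 4 denoising steps.
print(downsample_timesteps(2000, 4))  # [1999 1333  666    0]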

Dataset & Model Preparation

Dataset

We provide download links for all the data used in our paper:

| Task | Dataset | Samples | Used in our paper |
| --- | --- | --- | --- |
| Text Simplification | WIKI AUTO | 677k | download |
| Paraphrase | Quora Question Pairs | 114k | download |
| Story Generation | ROC Story | 88k | download |
| Question Generation | Quasar-T | 117k | download |
| E2E (Semantic / Syntax) | E2E | 88k | download |

Please download the data and place it under the ./datasets folder.
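As a quick sanity check after downloading, you can inspect a split before training. This is a minimal sketch that assumes DiffuSeq-style JSONL files (one {"src": ..., "trg": ...} object per line) and a hypothetical file name datasets/QQP/train.jsonl; adjust both to match the files you actually downloaded.

import json
from pathlib import Path

# Hypothetical path; point this at the split you actually downloaded.
data_file = Path("datasets/QQP/train.jsonl")

# Assumes DiffuSeq-style JSONL: one {"src": ..., "trg": ...} pair per line.
with data_file.open() as f:
    pairs = [json.loads(line) for line in f]

print(f"{len(pairs)} source/target pairs loaded")
print("example src:", pairs[0].get("src"))
print("example trg:", pairs[0].get("trg"))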

Backbone Model

Please refer to the following repos for more details:

DiffuSeq: Sequence to Sequence Text Generation with Diffusion Models

Diffusion-LM Improves Controllable Text Generation

Note: We also provide links to the two post-trained models for a quick check.

Quick Start

We provide the code for post-training on the QQP (Paraphrase) dataset.

Environment

conda create -n diffusion python=3.9
conda activate diffusion
pip install -r requirement.txt

Training

We conduct experiments with 4 NVIDIA A100 (40GB) GPUs.

cd scripts
export CUDA_VISIBLE_DEVICES=0,1,2,3;

DISTRIBUTE_ARGS="
    --nproc_per_node=4 \
    --use_env
"

TRAIN_ARGS="
    --diff_steps 2000 \
    --microbatch 100 \
    --lr 0.0001 \
    --learning_steps 320000 \
    --save_interval 2500 \
    --seed 109 \
    --noise_schedule sqrt \
    --hidden_dim 128 \
    --bsz 100 \
    --dataset qqp \
    --data_dir datasets/QQP \
    --vocab bert \
    --seq_len 128 \
    --simi_penalty l2_noise_random \
    --simi_lambda -2 \
    --simi_step 10 \
    --simi_noise 0.05 \
    --resume_checkpoint /path/to/checkpoint \
    --schedule_sampler lossaware \
    --notes qqp
"

python -m torch.distributed.launch $DISTRIBUTE_ARGS run_train.py $TRAIN_ARGS

Inference

python sample_seq2seq.py \
    --model_path /path/to/checkpoint \
    --step 2000 \
    --batch_size 16 \
    --seed2 10 \
    --split test \
    --out_dir generation_outputs  \
    --decode_respacing "adp_20"

Acknowledgement

We appreciate the open-source contributions of the following projects:

DiffuSeq

Diffusion-LM
