USD3

Official PyTorch source code for

Improving and Unifying Discrete-time and Continuous-time Discrete Denoising Diffusion
Lingxiao Zhao*, Xueying Ding*, Lijun Yu, Leman Akoglu

About

Discrete diffusion models are less explored than continuous-state diffusion models, yet much real-world data is discrete, such as language and graphs. We unify and simplify the diffusion framework (forward diffusion process, backward denoising process, and loss computation) for discrete-time and continuous-time denoising diffusion. Note that the framework currently targets nominal data, where categorical classes are NOT ordered.

Supported features:

  • Unified code: you only need to switch the loss function to choose between continuous time and discrete time. The forward and backward processes share the same code.
  • Fast and memory efficient: the forward and backward processes do NOT store the costly $C\times C$ transition matrices, thanks to the nominal data assumption. We provide both an efficient exact VLB loss and a simplified VLB loss. The backward process easily supports jump steps.
  • Any dimension input: our code supports any multi-element object $X$ with dimensionality $(B, N_1,...,N_k, C)$ without any modification, where $k$ can be any positive integer and $B$ is the batch size. If samples have different numbers of elements, you can pass a padding mask to the loss function, which will ignore the padded elements.
  • Conditional diffusion: you can provide a mask for the conditional part; these elements won't change during the conditional diffusion process.
  • Element-dependent noise: we support two types of noise. 1) All elements share the same noise, with the categorical noise distribution having shape $(C)$. 2) Element-dependent noise, with a noise distribution of shape $(B, N_1,...,N_k, C)$. This is particularly useful in conditional diffusion, where you can define element-dependent noise. See the shape sketch after this list.
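
For concreteness, below is a minimal shape sketch for a batch of sequences (i.e. $k=1$) of the tensors mentioned above. The exact dtypes and the boolean mask convention are assumptions; see the code doc in discrete_diffusion.py for the authoritative definitions.

    import torch

    B, N, C = 32, 128, 50                          # batch size, elements per sample, num classes
    x_0 = torch.randint(0, C, (B, N))              # clean data as integer class labels, shape (B, N)

    # Shared noise: one categorical distribution over the C classes, shape (C)
    m_shared = torch.full((C,), 1.0 / C)

    # Element-dependent noise: one distribution per element, shape (B, N, C)
    m_elementwise = torch.softmax(torch.randn(B, N, C), dim=-1)

    # Mask marking elements that are conditioned on (or are padding) and should stay
    # fixed / be ignored by the loss -- the boolean convention here is an assumption
    conditional_mask = torch.zeros(B, N, dtype=torch.bool)
    conditional_mask[:, :10] = True                # e.g. condition on the first 10 elements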

Installation

If you only use discrete_diffusion.py for your own project

As long as you have PyTorch (>= 1.13) installed, you can use it directly :)

Run experiment in the paper

We follow the experimental setup and code base from TauLDR.

Usage

If you only use discrete_diffusion.py for your own project

  • Train
  1. Create a UnifiedDiscreteDiffusion object

    from discrete_diffusion import UnifiedDiscreteDiffusion
    diffusion = UnifiedDiscreteDiffusion(num_steps, # 0 means use continuous time
                                         num_classes, 
                                         noise_schedule_type, 
                                         noise_schedule_args)
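    • For example, a concrete instantiation sketch; the schedule name and arguments below are illustrative placeholders, not necessarily schedules implemented in discrete_diffusion.py.
    # num_steps=0 selects continuous time; a positive value (e.g. 1000) selects discrete time.
    # 'cosine' and the empty args dict are hypothetical placeholders -- check the class
    # docstring for the schedules actually supported.
    diffusion = UnifiedDiscreteDiffusion(num_steps=0,
                                         num_classes=50,
                                         noise_schedule_type='cosine',
                                         noise_schedule_args={})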
  2. Sampling x_t from t and x_0 (every batch)

    • For the continuous-time case (num_steps=0), t should be in the range 0 ~ 1.0
    • For the discrete-time case, t should be an integer in the range 0 ~ num_steps
    • m is the noise distribution; see the code doc
    • conditional_mask is used to keep certain parts unchanged or conditioned on
    x_t = diffusion.qt_0_sample(x_0, t, m, conditional_mask)
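    • A hedged sketch of this step, reusing the shapes from the sketch above and assuming the continuous-time case with one time value per sample:
    # Continuous time: draw t uniformly in (0, 1), one value per sample (assumed shape (B,));
    # for discrete time, draw integers in 1..num_steps instead.
    t = torch.rand(B)

    # m can be the shared (C,) noise distribution or an element-dependent (B, N, C) one
    x_t = diffusion.qt_0_sample(x_0, t, m_shared, conditional_mask)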
  3. Compute loss with input the noisy x_t and original x_0 (every batch)

    • Assume you have a model (network) mapping (B, N1, ..., Nk), t -> (B, N1, ..., Nk, C), where C is num_classes
    • The model takes x_t and t as input, and outputs a prediction of the x_0 distribution
    logits_t = model(x_t, t)
    
    # loss = coeff_ce * ce + coeff_vlb * vlb
    loss = diffusion.compute_loss(logits_t,
                                  x_t, 
                                  x_0, 
                                  t, 
                                  m, 
                                  coeff_ce=0.1,
                                  coeff_vlb=1.0, 
                                  conditional_mask=conditional_mask,
                                  denoising_fn=model,
                                  simplified_vlb=False)
    • There are three parameters to play with (coeff_ce, coeff_vlb, simplified_vlb); see the paper for details.
  4. Update the model with loss['loss'].backward() (every batch)
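    • Putting steps 2-4 together, here is a minimal per-batch training sketch under the assumptions above (the model, optimizer, dataloader, and time-sampling details are placeholders):
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

    for x_0 in dataloader:                         # x_0: (B, N) integer class labels
        t = torch.rand(x_0.shape[0])               # continuous time; use integers for discrete time
        x_t = diffusion.qt_0_sample(x_0, t, m_shared, conditional_mask)
        logits_t = model(x_t, t)                   # (B, N, C) prediction of the x_0 distribution

        loss = diffusion.compute_loss(logits_t, x_t, x_0, t, m_shared,
                                      coeff_ce=0.1, coeff_vlb=1.0,
                                      conditional_mask=conditional_mask,
                                      denoising_fn=model,
                                      simplified_vlb=False)

        optimizer.zero_grad()
        loss['loss'].backward()                    # loss is a dict; 'loss' holds the total objective
        optimizer.step()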

  • Generation

    • After training, you can use the trained model to generate samples.
    • In the discrete-time case, one may want num_backward_steps to be smaller than the number of training steps num_steps for good performance.
    diffusion.sample(model,
                     num_backward_steps, 
                     m, 
                     conditional_mask=None,
                     conditional_input=None)
    • One can also use MCMC refinement in sampling; see the code doc for parameters.
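    • A hedged generation sketch (unconditional, with an illustrative number of backward steps; MCMC-refinement arguments are omitted, and the return value is assumed to be the generated samples):
    num_backward_steps = 100     # illustrative; for discrete time this can be smaller than num_steps
    samples = diffusion.sample(model,
                               num_backward_steps,
                               m_shared,
                               conditional_mask=None,
                               conditional_input=None)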

Run experiment in the paper

TODO

  • Add "Run experiment in the paper"
    • Installation
    • Usage
    • Code

Citation

If you use this codebase, or otherwise find our work valuable, please cite:

@article{zhao2024improving,
  title={Improving and Unifying Discrete\&Continuous-time Discrete Denoising Diffusion},
  author={Zhao, Lingxiao and Ding, Xueying and Yu, Lijun and Akoglu, Leman},
  journal={arXiv preprint arXiv:2402.03701},
  year={2024}
}
