
SpeeDiT: Accelerating DiTs and General Diffusion Models via Principle Timestep Adjustment Training

If you like SpeeDiT, please give us a star ⭐ on GitHub for the latest updates.


Elevator pitch of SpeeDiT

We propose SpeeDiT, a general diffusion training acceleration algorithm that samples timesteps asymmetrically. It can speed up DiT training by 3.3 times with no drop in FID. Ongoing experiments demonstrate that SpeeDiT applies to multiple diffusion-based visual generation tasks and is compatible with other acceleration methods. We therefore believe SpeeDiT can significantly reduce the cost of diffusion training, allowing more people to benefit from this exciting technological advancement!

TODO list sorted by priority

If you encounter any inconvenience with the code or have suggestions for improvements, please feel free to contact us via email at ykzhou8981389@gmail.com or kai.wang@comp.nus.edu.sg.

  • Releasing SpeeDiT-XL/2 400K, 1000K, ..., 7000K checkpoints and publishing the technical report.

  • Upgrading the components of SpeeDiT

  • Applying SpeeDiT to text2image

    [Stable Diffusion]

    [Latent Diffusion]

    [Imagen]

  • Applying SpeeDiT to text2video

    [Open-Sora]

    [Latte]

  • SpeeDiT + MDT

  • More tasks (Image inpainting, 3D Generation)

😮 Highlights

Our method is easily compatible with other approaches and can accelerate the training of diffusion models.


✒️ Motivation

Our method is inspired by the uphill and downhill diffusion processes in physics. The following GIF illustrates the commonalities between image diffusion and electron diffusion. The left figure, of electron diffusion, is simulated with PhET/diffusion. The right figure is downloaded from the OpenAI website.

Visualization of the different phases of the reverse process and uphill diffusion. For ease of understanding, we assume that the direction of electron velocity has only two cases: ⬅️ and ➡️.


🔆 Method

We achieve the acceleration with a sampling and weighting strategy that is simple and easily compatible. The core code is in SpeeDiT/speedit/diffusion/iddpm/speed.py:

import numpy as np
import torch
import torch.nn.functional as F

class SpeeDiffusion(SpacedDiffusion):
    def __init__(self, faster, **kwargs):
        super().__init__(**kwargs)
        self.faster = faster
        if faster:
            grad = np.gradient(self.sqrt_one_minus_alphas_cumprod)

            # the meaningful steps are those before the noise schedule saturates
            # (gradient of sqrt(1 - alpha_bar) drops below 1e-4); they matter more in inference
            self.meaningful_steps = np.argmax(grad < 1e-4) + 1

            # P2 weighting from: Perception Prioritized Training of Diffusion Models
            self.p2_gamma = 1
            self.p2_k = 1
            self.snr = 1.0 / (1 - self.alphas_cumprod) - 1
            sqrt_one_minus_alphas_bar = torch.from_numpy(self.sqrt_one_minus_alphas_cumprod)
            # sample the meaningful steps more often: the tanh gate gives probability
            # ~2.5 before the schedule saturates and ~0.5 after, normalized to sum to 1
            p = torch.tanh(1e6 * (torch.gradient(sqrt_one_minus_alphas_bar)[0] - 1e-4)) + 1.5
            self.p = F.normalize(p, p=1, dim=0)
            self.weights = self._weights()
        else:
            self.meaningful_steps = self.num_timesteps

    def _weights(self):
        # the phase where pure noise becomes a noisy image with content
        # gets more weight in training; the weights act on the MSE loss
        weights = 1 / (self.p2_k + self.snr) ** self.p2_gamma
        return weights

    # sample timesteps t and their loss weights for training the diffusion model
    def t_sample(self, n, device):
        if self.faster:
            t = torch.multinomial(self.p, n // 2 + 1, replacement=True).to(device)
            # dual sampling: pair each t with its mirror around meaningful_steps,
            # which balances training across the different step regimes
            dual_t = torch.where(t < self.meaningful_steps, self.meaningful_steps - t, t - self.meaningful_steps)
            t = torch.cat([t, dual_t], dim=0)[:n]
            weights = self.weights
        else:
            t = torch.randint(0, self.num_timesteps, (n,), device=device)
            weights = None

        return t, weights
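
For context, here is a minimal sketch (not taken from the repository; train_step, model, x0, and y are hypothetical names) of how t_sample plugs into a training step. q_sample is the standard forward-diffusion helper that SpacedDiffusion inherits from GaussianDiffusion; the returned weights simply rescale the per-sample MSE loss.

import torch

def train_step(diffusion, model, x0, y, device):
    n = x0.shape[0]
    t, weights = diffusion.t_sample(n, device)        # asymmetric timestep sampling
    noise = torch.randn_like(x0)
    x_t = diffusion.q_sample(x0, t, noise=noise)      # diffuse x0 to step t
    pred = model(x_t, t, y)                           # predict the injected noise
    loss = (pred - noise).pow(2).mean(dim=[1, 2, 3])  # per-sample MSE over image dims
    if weights is not None:
        # P2 weighting: index the per-step weights by the sampled timesteps
        loss = loss * torch.as_tensor(weights, device=device)[t].float()
    return loss.mean()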

You can enable our acceleration module with diffusion.faster=True.

# config file
diffusion:
    timestep_respacing: '250'
    faster: true  # enable the module for training acceleration

🛠️ Requirements and Installation

This codebase does not rely on special hardware acceleration technology, so the experimental environment is not complicated to set up.

You can create a new conda environment:

conda env create -f environment.yml
conda activate speedit

or install the necessary package by:

pip install -r requirements.txt

If necessary, we will provide more options (e.g., Docker) to facilitate configuring the experimental environment.

🗝️ Implementation

We provide a complete pipeline for generation tasks, including training, inference, and testing. The current code only supports class-conditional image generation; we will support more diffusion-based generation tasks in the future.

We refactored the facebookresearch/DiT code and load configs with OmegaConf. Config files are loaded recursively for easier argument modification: simply put, settings in files later in the path override the earlier settings from base.yaml.

You can modify the experiment setting by modifying the config file and the command line. More details about the reading of config are written in configs/README.md.
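
As an illustration of the recursive loading (the file paths below are assumptions for the example, not the repo's actual loader code), the behavior is equivalent to an OmegaConf merge in which later sources override earlier ones:

# Sketch of the recursive config merge; paths and merge order are illustrative.
from omegaconf import OmegaConf

base = OmegaConf.load("configs/image/base.yaml")              # shared defaults
exp = OmegaConf.load("configs/image/imagenet_256/base.yaml")  # experiment overrides
cli = OmegaConf.from_cli()                                    # e.g. guidance_scale=1.5
cfg = OmegaConf.merge(base, exp, cli)                         # later sources win
print(OmegaConf.to_yaml(cfg))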

For each experiment, you must provide two command-line arguments:

-c: config path;
-p: phase, one of ['train', 'inference', 'sample'].

Train & inference

For example, to run the class-conditional image generation task on the 256x256 ImageNet dataset with the DiT-XL/2 model:

# Training: train the diffusion model and save checkpoints
torchrun --nproc_per_node=8 main.py -c configs/image/imagenet_256/base.yaml -p train
# Inference: generate samples for testing
torchrun --nproc_per_node=8 main.py -c configs/image/imagenet_256/base.yaml -p inference
# Sample: generate a few images for visualization
python main.py -c configs/image/imagenet_256/base.yaml -p sample

How to do ablation?

Ablations use the same mechanism: modify the config file or override any setting directly on the command line (see configs/README.md for details).

For example, change the classifier-free guidance scale in sampling by command line:

python main.py -c configs/image/imagenet_256/base.yaml -p sample guidance_scale=1.5

Test

Testing the generation tasks requires the results of inference. More details about testing are given in evaluations.
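
If the evaluation follows the ADM-style protocol used by DiT (an assumption here; see evaluations for the actual procedure), inference writes an .npz batch of generated images that is compared against a reference batch, for example:

# samples.npz is a hypothetical name for the inference output
python evaluator.py VIRTUAL_imagenet256_labeled.npz samples.npz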

👍 Acknowledgement

We are grateful for the following exceptional works and their generous contributions to open source.

  • DiT: Scalable Diffusion Models with Transformers.
  • Open-Sora: Democratizing Efficient Video Production for All.
  • OpenDiT: an acceleration framework for DiT training. We adopt valuable training acceleration strategies from OpenDiT.

🔒 License

The majority of this project is released under the Apache 2.0 license as found in the LICENSE file.

✏️Citation

If you find our code useful in your research, please consider giving us a star ⭐ and a citation 📝.

@software{speedit,
  author = {Yukun Zhou and Kai Wang and Hanwang Zhang and Yang You and Xiaojiang Peng},
  title = {SpeeDiT: Accelerating DiTs and General Diffusion Models via Principle Timestep Adjustment Training},
  month = {March},
  year = {2024},
  url = {https://github.com/1zeryu/SpeeDiT}
}
