PixArt-alpha/PixArt-sigma

👉 PixArt-Σ: Weak-to-Strong Training of Diffusion Transformer for 4K Text-to-Image Generation


This repo contains PyTorch model definitions, pre-trained weights and inference/sampling code for our paper exploring Weak-to-Strong Training of Diffusion Transformer for 4K Text-to-Image Generation. You can find more visualizations on our project page.

PixArt-Σ: Weak-to-Strong Training of Diffusion Transformer for 4K Text-to-Image Generation
Junsong Chen*, Chongjian Ge*, Enze Xie*†, Yue Wu*, Lewei Yao, Xiaozhe Ren, Zhongdao Wang, Ping Luo, Huchuan Lu, Zhenguo Li
Huawei Noah’s Ark Lab, DLUT, HKU, HKUST


Contributions from everyone are welcome 🔥🔥!

Building on lessons from the previous PixArt-α project, we aim to keep this repo as simple as possible so that everyone in the PixArt community can use it.




🆚 Compare with PixArt-α

| Model    | T5 token length | VAE   | 2K/4K |
|----------|-----------------|-------|-------|
| PixArt-Σ | 300             | SDXL  | ✅    |
| PixArt-α | 120             | SD1.5 | ❌    |
Sample comparison, PixArt-Σ vs. PixArt-α (images not reproduced here). Prompts:

  • Sample-1: Close-up, gray-haired, bearded man in 60s, observing passersby, in wool coat and brown beret, glasses, cinematic.
  • Sample-2: Body shot, a French woman, Photography, French Streets background, backlight, rim light, Fujifilm.
  • Sample-3: Photorealistic closeup video of two pirate ships battling each other as they sail inside a cup of coffee.

Sample-1 full prompt: An extreme close-up of an gray-haired man with a beard in his 60s, he is deep in thought pondering the history of the universe as he sits at a cafe in Paris, his eyes focus on people offscreen as they walk as he sits mostly motionless, he is dressed in a wool coat suit coat with a button-down shirt , he wears a **brown beret** and glasses and has a very professorial appearance, and the end he offers a subtle closed-mouth smile as if he found the answer to the mystery of life, the lighting is very cinematic with the golden light and the Parisian streets and city in the background, depth of field, cinematic 35mm film.

🔧 Dependencies and Installation

conda create -n pixart python==3.9.0
conda activate pixart
conda install pytorch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2 pytorch-cuda=11.7 -c pytorch -c nvidia

git clone https://github.com/PixArt-alpha/PixArt-sigma.git
cd PixArt-sigma
pip install -r requirements.txt

🔥 How to Train

1. PixArt Training

First of all, we started a new repo to build a more user-friendly and more compatible codebase. The main model structure is the same as PixArt-α, so you can still develop your features based on the original repo. Also, this repo will support PixArt-α in the future.

Tip

You can now train your model without extracting features in advance. We restructured the data format of the PixArt-α codebase so that everyone can start training, running inference, and visualizing results right from the beginning without any pain.

1.1 Downloading the toy dataset

Download the toy dataset first. The dataset structure for training is:

cd ./pixart-sigma-toy-dataset

Dataset Structure
├──InternImgs/  (images are saved here)
│  ├──000000000000.png
│  ├──000000000001.png
│  ├──......
├──InternData/
│  ├──data_info.json    (meta data)
Optional (👇)
│  ├──img_sdxl_vae_features_1024resolution_ms_new    (SDXL-VAE image features, same name as images except .npy extension)
│  │  ├──000000000000.npy
│  │  ├──000000000001.npy
│  │  ├──......
│  ├──caption_features_new
│  │  ├──000000000000.npz
│  │  ├──000000000001.npz
│  │  ├──......
│  ├──sharegpt4v_caption_features_new    (run tools/extract_caption_feature.py to generate caption T5 features, same name as images except .npz extension)
│  │  ├──000000000000.npz
│  │  ├──000000000001.npz
│  │  ├──......
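Because the feature folders are optional, a loader has to decide per image whether to read a cached file or extract features on the fly. The sketch below resolves those paths from the naming convention in the tree above (same stem as the image; .npy for SDXL-VAE features, .npz for T5 caption features). The helper names and the "None means extract on the fly" convention are hypothetical and not the repo's actual data loader.

```python
from pathlib import Path
from typing import Dict, Optional

# Sub-folders copied from the toy-dataset tree; the helpers below are hypothetical.
VAE_DIR = Path("InternData") / "img_sdxl_vae_features_1024resolution_ms_new"
T5_DIR = Path("InternData") / "caption_features_new"


def cached_feature_paths(root: Path, image_name: str) -> Dict[str, Path]:
    """Where the optional cached features for one image would live.

    Feature files reuse the image's stem: SDXL-VAE features end in .npy,
    T5 caption features end in .npz.
    """
    stem = Path(image_name).stem
    return {
        "vae": root / VAE_DIR / f"{stem}.npy",
        "t5": root / T5_DIR / f"{stem}.npz",
    }


def pick_feature_source(root: Path, image_name: str) -> Dict[str, Optional[Path]]:
    """Return the cached file if it exists, else None (extract on the fly)."""
    return {
        kind: (path if path.is_file() else None)
        for kind, path in cached_feature_paths(root, image_name).items()
    }


if __name__ == "__main__":
    # Assumes the toy dataset was downloaded to ./pixart-sigma-toy-dataset.
    root = Path("pixart-sigma-toy-dataset")
    print(pick_feature_source(root, "000000000000.png"))
```

Returning None instead of raising keeps the "no prior feature extraction" path as the default, with cached files acting purely as a speed-up.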

1.2 Downloading pretrained checkpoints

# SDXL-VAE, T5 checkpoints
git lfs install
git clone https://huggingface.co/PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers output/pretrained_models/pixart_sigma_sdxlvae_T5_diffusers

# PixArt-Sigma checkpoints
python tools/download.py  # optionally set HF_ENDPOINT=https://hf-mirror.com to download through a Hugging Face mirror

1.3 You are ready to train!

Select your desired config file from the config files directory.

python -m torch.distributed.launch --nproc_per_node=1 --master_port=12345 \
          train_scripts/train.py \
          configs/pixart_sigma_config/PixArt_sigma_xl2_img512_internalms.py \
          --load-from output/pretrained_models/PixArt-Sigma-XL-2-512-MS.pth \
          --work-dir output/your_first_pixart-exp \
          --debug
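The command above launches a single process on one GPU. As a sketch, assuming a single node where scaling out only means raising --nproc_per_node, the hypothetical helper below assembles the same command as an argv list (for example, to pass to subprocess.run). It keeps --debug optional because this README does not describe what the flag changes.

```python
import shlex
from typing import List


def build_train_cmd(
    config: str,
    load_from: str,
    work_dir: str,
    nproc_per_node: int = 1,
    master_port: int = 12345,
    debug: bool = False,
) -> List[str]:
    """Rebuild the torch.distributed.launch command from this section as argv."""
    if nproc_per_node < 1:
        raise ValueError("nproc_per_node must be >= 1")
    cmd = [
        "python", "-m", "torch.distributed.launch",
        f"--nproc_per_node={nproc_per_node}",
        f"--master_port={master_port}",
        "train_scripts/train.py",
        config,
        "--load-from", load_from,
        "--work-dir", work_dir,
    ]
    if debug:
        cmd.append("--debug")
    return cmd


# Example: the same 512px run, spread over 8 GPUs on one machine (assumption).
cmd = build_train_cmd(
    "configs/pixart_sigma_config/PixArt_sigma_xl2_img512_internalms.py",
    "output/pretrained_models/PixArt-Sigma-XL-2-512-MS.pth",
    "output/your_first_pixart-exp",
    nproc_per_node=8,
)
print(shlex.join(cmd))
```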

💻 How to Test

1. Quick start with Gradio

To get started, first install the required dependencies. Make sure you have downloaded the checkpoint files from models (coming soon) to the output/pretrained_models folder, and then run on your local machine:

# SDXL-VAE, T5 checkpoints
git lfs install
git clone https://huggingface.co/PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers output/pretrained_models/pixart_sigma_sdxlvae_T5_diffusers

# PixArt-Sigma checkpoints
python tools/download.py

# demo launch
python scripts/interface.py --model_path output/pretrained_models/PixArt-Sigma-XL-2-512-MS.pth --image_size 512 --port 11223

2. Integration in diffusers

Important

Upgrade your diffusers to make the PixArtSigmaPipeline available!

pip install git+https://github.com/huggingface/diffusers

For diffusers<0.28.0, check this script for help.

import torch
from diffusers import Transformer2DModel, PixArtSigmaPipeline

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
weight_dtype = torch.float16

transformer = Transformer2DModel.from_pretrained(
    "PixArt-alpha/PixArt-Sigma-XL-2-1024-MS", 
    subfolder='transformer', 
    torch_dtype=weight_dtype,
    use_safetensors=True,
)
pipe = PixArtSigmaPipeline.from_pretrained(
    "PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers",
    transformer=transformer,
    torch_dtype=weight_dtype,
    use_safetensors=True,
)
pipe.to(device)

# Enable memory optimizations (use this instead of pipe.to(device)).
# pipe.enable_model_cpu_offload()

prompt = "A small cactus with a happy face in the Sahara desert."
image = pipe(prompt).images[0]
image.save("./cactus.png")
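The upgrade note above hinges on a version threshold (diffusers 0.28.0). A minimal, hypothetical helper to check it programmatically might look like this. It assumes version strings start with numeric major.minor[.patch] and deliberately ignores pre-release suffixes such as .dev0, so a source install reporting 0.28.0.dev0 is treated as 0.28.0.

```python
import re
from typing import Tuple

# Threshold from the note above: diffusers < 0.28.0 needs the helper script.
MIN_VERSION: Tuple[int, int, int] = (0, 28, 0)


def parse_release(version: str) -> Tuple[int, int, int]:
    """Extract (major, minor, patch) from strings like '0.28.0' or '0.28.0.dev0'."""
    match = re.match(r"(\d+)\.(\d+)(?:\.(\d+))?", version)
    if match is None:
        raise ValueError(f"unrecognized version string: {version!r}")
    major, minor, patch = match.groups()
    return int(major), int(minor), int(patch or 0)


def needs_legacy_script(version: str) -> bool:
    """True if this diffusers version is older than 0.28.0 (suffixes ignored)."""
    return parse_release(version) < MIN_VERSION


if __name__ == "__main__":
    try:
        import diffusers
    except ImportError:
        print("diffusers is not installed")
    else:
        print(diffusers.__version__, "needs legacy script:",
              needs_legacy_script(diffusers.__version__))
```

Comparing integer tuples avoids the classic string-comparison bug where "0.9" would sort after "0.28".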

3. PixArt Demo

pip install git+https://github.com/huggingface/diffusers

# PixArt-Sigma 1024px
DEMO_PORT=12345 python app/app_pixart_sigma.py

# PixArt-Sigma one-step sampler (DMD)
DEMO_PORT=12345 python app/app_pixart_dmd.py

Then open http://your-server-ip:12345 in your browser to try a simple example.

4. Convert a .pth checkpoint into the diffusers format

You can download the diffusers version directly from Hugging Face, or convert it yourself:

pip install git+https://github.com/huggingface/diffusers

python tools/convert_pixart_to_diffusers.py --orig_ckpt_path output/pretrained_models/PixArt-Sigma-XL-2-1024-MS.pth --dump_path output/pretrained_models/PixArt-Sigma-XL-2-1024-MS --only_transformer=True --image_size=1024 --version sigma
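When converting several .pth files, --image_size has to match each checkpoint. The hypothetical helper below infers it from the filename suffixes used in the model table and rebuilds the command above. The mapping of "2K" to 2048 and the default output directory are assumptions, so check them against tools/convert_pixart_to_diffusers.py before relying on them.

```python
from pathlib import Path
from typing import List

# Hypothetical mapping from the filename suffix to --image_size ("2K" -> 2048 is assumed).
_SIZE_TAGS = {"256x256": 256, "512-MS": 512, "1024-MS": 1024, "2K-MS": 2048}


def convert_cmd(orig_ckpt_path: str, out_dir: str = "output/pretrained_models") -> List[str]:
    """Build the conversion command for one of the .pth files in the model table."""
    stem = Path(orig_ckpt_path).stem  # e.g. "PixArt-Sigma-XL-2-1024-MS"
    for tag, size in _SIZE_TAGS.items():
        if stem.endswith(tag):
            break
    else:
        raise ValueError(f"cannot infer image size from {stem!r}")
    return [
        "python", "tools/convert_pixart_to_diffusers.py",
        "--orig_ckpt_path", orig_ckpt_path,
        "--dump_path", str(Path(out_dir) / stem),
        "--only_transformer=True",
        f"--image_size={size}",
        "--version", "sigma",
    ]


print(" ".join(convert_cmd("output/pretrained_models/PixArt-Sigma-XL-2-1024-MS.pth")))
```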

⏬ Available Models

All models will be downloaded automatically. You can also download them manually from this URL.

| Model            | #Params | Checkpoint path                                                          | Download in OpenXLab |
|------------------|---------|--------------------------------------------------------------------------|----------------------|
| T5 & SDXL-VAE    | 4.5B    | Diffusers: pixart_sigma_sdxlvae_T5_diffusers                             | coming soon          |
| PixArt-Σ-256     | 0.6B    | pth: PixArt-Sigma-XL-2-256x256.pth; Diffusers: PixArt-Sigma-XL-2-256x256 | coming soon          |
| PixArt-Σ-512     | 0.6B    | pth: PixArt-Sigma-XL-2-512-MS.pth; Diffusers: PixArt-Sigma-XL-2-512-MS   | coming soon          |
| PixArt-α-512-DMD | 0.6B    | Diffusers: PixArt-Alpha-DMD-XL-2-512x512                                 | coming soon          |
| PixArt-Σ-1024    | 0.6B    | pth: PixArt-Sigma-XL-2-1024-MS.pth; Diffusers: PixArt-Sigma-XL-2-1024-MS | coming soon          |
| PixArt-Σ-2K      | 0.6B    | pth: PixArt-Sigma-XL-2-2K-MS.pth; Diffusers: PixArt-Sigma-XL-2-2K-MS     | coming soon          |
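To see why the 2K model is much more demanding than the 512 or 1024 models despite the same 0.6B parameters, it helps to count the tokens the transformer processes. This back-of-the-envelope sketch assumes the SDXL VAE's 8x spatial downsampling and reads "XL-2" as a 2x2 patch size (the DiT naming convention); it only covers square images and ignores the multi-aspect-ratio ("MS") buckets.

```python
def latent_tokens(image_size: int, vae_downsample: int = 8, patch_size: int = 2) -> int:
    """Number of transformer tokens for a square image.

    image -> VAE latent (image_size / vae_downsample per side)
          -> non-overlapping patch_size x patch_size patches, one token each.
    """
    if image_size % (vae_downsample * patch_size):
        raise ValueError("image_size must be divisible by vae_downsample * patch_size")
    side = image_size // vae_downsample // patch_size
    return side * side


for size in (256, 512, 1024, 2048):
    print(size, latent_tokens(size))
```

Doubling the side length quadruples the token count, and standard self-attention cost grows with the square of the token count, so each resolution step up is roughly 16x more attention work.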

💪To-Do List

We will try our best to release:

  • Training code
  • Inference code
  • Inference code of One Step Sampling with DMD
  • Model zoo (256/512/1024/2K)
  • Diffusers (for fast experience)
  • Training code of One Step Sampling with DMD
  • Diffusers (stable official version: huggingface/diffusers#7654)
  • LoRA training & inference code
  • Model zoo (KV Compress...)
  • ControlNet training & inference code

🤗Acknowledgements

  • Thanks to PixArt-α, DiT and OpenDMD for their wonderful work and codebase!
  • Thanks to Diffusers for their wonderful technical support and awesome collaboration!
  • Thanks to Hugging Face for sponsoring the nice demo!

📖BibTeX

@misc{chen2024pixartsigma,
  title={PixArt-$\Sigma$: Weak-to-Strong Training of Diffusion Transformer for 4K Text-to-Image Generation},
  author={Junsong Chen and Chongjian Ge and Enze Xie and Yue Wu and Lewei Yao and Xiaozhe Ren and Zhongdao Wang and Ping Luo and Huchuan Lu and Zhenguo Li},
  year={2024},
  eprint={2403.04692},
  archivePrefix={arXiv},
  primaryClass={cs.CV}
}
