
Harahan/RTDMD


Reinforcing Few-step Generators via Reward-Tilted Distribution Matching

Reward-Tilted DMD  ·  Ambient-Consistent Distillation  ·  Hybrid Policy Gradient


Yushi Huang¹²*, Xiangxin Zhou²*, Ruoyu Wang²³*, Chi Zhang³, Jun Zhang¹, Tianyu Pang²

¹The Hong Kong University of Science and Technology · ²Tencent Hunyuan · ³Westlake University

* Equal contribution  ·  † Work done during internship at Tencent Hunyuan  ·  ‡ Corresponding author




📖 Abstract

We propose Reward-Tilted Distribution Matching Distillation (RTDMD), a two-stage framework that unifies distribution-matching distillation with reward-guided RL for few-step flow generators. Minimizing the KL divergence to a reward-tilted teacher distribution decomposes naturally into a distribution-matching term and a reward-maximization term, instantiated as Ambient-Consistent DMD (AC-DMD) for the cold start and as a hybrid policy gradient (SubGRPO + final-step reward back-propagation) for the RL stage. With 4 NFE, RTDMD achieves new state-of-the-art results on SD3-M, SD3.5-M, and FLUX.2 4B; the distilled FLUX.2 4B even beats the FLUX.2 9B teacher at 50 NFE on most rewards.

4-step samples from RTDMD-distilled FLUX.2 4B (no classifier-free guidance).
Qualitative comparison for few-step diffusion models (4 NFE).

🍭 Method Overview

RTDMD overview. Det. = deterministic final step, Stoc. = stochastic intermediate steps. Trajectories: teacher (blue), few-step generator (green), fake score (yellow).

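For the generator $G_\theta$ with sample distribution $p_\theta$, teacher distribution $p_\psi$, and reward $r$, take the reward-tilted teacher $\tilde{p}_\psi(\mathbf{x}) \propto p_\psi(\mathbf{x})\exp(\beta\, r(\mathbf{x}))$. Because the normalizing constant of $\tilde{p}_\psi$ does not depend on $\theta$, the gradient of the reward-tilted KL objective decomposes as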

$$ \nabla_\theta D_{\text{KL}}(p_\theta \,\|\, \tilde{p}_\psi) = \underbrace{\nabla_\theta D_{\text{KL}}(p_\theta \,\|\, p_\psi)}_{\text{distribution matching}} - \beta\,\underbrace{\nabla_\theta \mathbb{E}_{\hat{\mathbf{x}}_0 \sim p_\theta}[r(\hat{\mathbf{x}}_0)]}_{\text{reward maximization}}. $$
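The identity can be sanity-checked on a toy discrete space. This is a minimal sketch, assuming the exponential tilt $\tilde{p}_\psi(\mathbf{x}) = p_\psi(\mathbf{x})\,e^{\beta r(\mathbf{x})}/Z$, which is the form the $-\beta$ coefficient implies. Under that tilt, $D_{\text{KL}}(p_\theta \,\|\, \tilde{p}_\psi) = D_{\text{KL}}(p_\theta \,\|\, p_\psi) - \beta\,\mathbb{E}_{p_\theta}[r] + \log Z$, and since $\log Z$ does not depend on $\theta$ its gradient vanishes. The probabilities, rewards, and helper names below are illustrative and not taken from RTDMD.

```python
import math

def kl(p, q):
    """KL(p || q) for discrete distributions given as lists of probabilities."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)

def tilt(q, r, beta):
    """Reward-tilted distribution q(x) * exp(beta * r(x)) / Z; also returns Z."""
    w = [qi * math.exp(beta * ri) for qi, ri in zip(q, r)]
    z = sum(w)
    return [wi / z for wi in w], z

# Toy "generator" p_theta, "teacher" p_psi, and per-sample rewards r(x).
p_theta = [0.1, 0.2, 0.3, 0.4]
p_psi = [0.25, 0.25, 0.25, 0.25]
reward = [0.0, 1.0, 0.5, 2.0]
beta = 0.7

p_tilde, z = tilt(p_psi, reward, beta)
lhs = kl(p_theta, p_tilde)
expected_reward = sum(pi * ri for pi, ri in zip(p_theta, reward))
rhs = kl(p_theta, p_psi) - beta * expected_reward + math.log(z)
print(f"KL to tilted teacher: {lhs:.6f}  decomposition: {rhs:.6f}")
```

The two printed values agree; the only $\theta$-independent piece is $\log Z$, which is why it drops out of the gradient.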

The distribution-matching term alone drives the cold start, and the RL stage optimizes both terms. Each stage has its own trainer in the CLI:

| Stage | Trainer | Key knobs |
| --- | --- | --- |
| 1. AC-DMD cold start | ACDMDTrainer (--trainer ac_dmd) | sub-interval renoising, consistency weight γ, CPS sampler η = 0.9 |
| 2. RTDMD RL fine-tune | RTDMDTrainer (--trainer rtdmd) | SubGRPO + final-step BP + AC-DMD |

📊 Main Results

Student models run at 4 NFE (4 inference steps); teachers and the 50-NFE Z-Image 6B baseline use their standard multi-step settings, as listed in each table's NFE column.

SD3-M (paper Table 1)

| Method | NFE | CLIPScore ↑ | Aesthetic ↑ | PickScore ↑ | HPSv2 ↑ | ImageReward ↑ |
| --- | ---: | ---: | ---: | ---: | ---: | ---: |
| SD3-M teacher (w/ CFG) | 100 | 0.2936 | 5.5711 | 22.3236 | 0.2810 | 1.0759 |
| GDMD | 4 | 0.2930 | 5.8728 | 22.4614 | 0.3076 | 1.2702 |
| Rdm | 4 | 0.2936 | 5.8769 | 22.5783 | 0.2957 | 1.2897 |
| RTDMD (Ours) | 4 | 0.3161 | 5.9642 | 22.8593 | 0.3211 | 1.3024 |

RTDMD is the only 4-NFE model that surpasses the 100-NFE SD3-M teacher with CFG across all five metrics — see the paper for the full baseline table.

FLUX.2 4B (paper Table 2)

| Method | NFE | ImageReward ↑ | CLIPScore ↑ | Aesthetic ↑ | PickScore ↑ | HPSv2 ↑ | HPSv3 ↑ | GenEval ↑ | GenEval2 ↑ | OCR ↑ |
| --- | ---: | ---: | ---: | ---: | ---: | ---: | ---: | ---: | ---: | ---: |
| FLUX.2 4B teacher | 50 | 0.8538 | 0.2834 | 5.3333 | 22.3938 | 0.2771 | 11.7025 | 0.7631 | 0.2207 | 0.6133 |
| FLUX.2 9B teacher | 50 | 1.0021 | 0.2962 | 5.2030 | 22.6382 | 0.2800 | 11.6883 | 0.7568 | 0.3557 | 0.7432 |
| Z-Image 6B | 50 | 0.7841 | 0.2841 | 5.2488 | 22.2118 | 0.2714 | 10.0857 | 0.6563 | 0.3012 | 0.7373 |
| Z-Image-Turbo 6B | 4 | 0.9696 | 0.2764 | 5.2894 | 22.7994 | 0.2954 | 12.9136 | 0.7562 | 0.3530 | 0.7539 |
| FLUX.2 4B | 4 | 1.0506 | 0.2864 | 5.2658 | 22.7370 | 0.2890 | 12.9295 | 0.7722 | 0.2403 | 0.6375 |
| FLUX.2 9B | 4 | 1.1998 | 0.2919 | 5.3730 | 23.0178 | 0.2991 | 13.2955 | 0.7814 | 0.3570 | 0.7566 |
| Z-Image 6B w/ TDM-R1 | 4 | 1.1543 | 0.2836 | 5.2450 | 22.8202 | 0.3064 | 13.4349 | 0.7737 | 0.4073 | 0.7665 |
| FLUX.2 4B w/ RTDMD (Ours) | 4 | 1.3712 | 0.3219 | 5.7746 | 23.9642 | 0.3516 | 15.5772 | 0.9046 | 0.2755 | 0.6858 |

RTDMD on FLUX.2 4B is the best 4-NFE model on 7 of 9 rewards (ImageReward / CLIPScore / Aesthetic / PickScore / HPSv2 / HPSv3 / GenEval) and beats the FLUX.2 9B teacher at 50 NFE on every one of those seven — including +0.37 ImageReward, +0.57 Aesthetic, +1.33 PickScore, +3.89 HPSv3, and +0.15 GenEval.
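The gaps quoted above can be recomputed directly from the table. The snippet copies the FLUX.2 9B teacher (50 NFE) and RTDMD rows verbatim and reports the per-metric difference and how many metrics RTDMD wins; it is only a convenience check on the published numbers.

```python
METRICS = ["ImageReward", "CLIPScore", "Aesthetic", "PickScore", "HPSv2",
           "HPSv3", "GenEval", "GenEval2", "OCR"]

# Rows copied from the FLUX.2 4B results table above.
flux2_9b_teacher_50nfe = [1.0021, 0.2962, 5.2030, 22.6382, 0.2800,
                          11.6883, 0.7568, 0.3557, 0.7432]
rtdmd_flux2_4b_4nfe = [1.3712, 0.3219, 5.7746, 23.9642, 0.3516,
                       15.5772, 0.9046, 0.2755, 0.6858]

# Difference (ours - teacher), rounded to the table's 4 decimals.
deltas = {m: round(ours - ref, 4) for m, ours, ref in
          zip(METRICS, rtdmd_flux2_4b_4nfe, flux2_9b_teacher_50nfe)}
wins = [m for m, d in deltas.items() if d > 0]

for m, d in deltas.items():
    print(f"{m:12s} {d:+.4f}")
print(f"RTDMD (4 NFE) beats the 9B teacher (50 NFE) on {len(wins)} of {len(METRICS)} metrics")
```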


✅ TODO

  • Release more RTDMD checkpoints (FLUX.2 9B and FLUX.1-dev) on the RTDMD HF collection

📁 Repository Layout

```text
RTDMD/
├── main.py                # Training entry point
├── inference.py           # Inference entry point
├── configs/
│   ├── cold_start/        # AC-DMD distillation YAMLs (5 backbones)
│   ├── rtdmd/             # RTDMD RL fine-tune YAMLs (5 backbones)
│   └── inference/         # Inference YAMLs (5 backbones)
├── rtdmd/                 # Source package: trainers/, models/, schedulers/,
│                          # rewards/, data/, parallel/, utils/, diffusers_patch/
└── scripts/
    ├── cold_start.sh      # AC-DMD launcher (single / multi-node)
    ├── rtdmd.sh           # RTDMD launcher  (single / multi-node)
    ├── inference.sh       # Inference launcher
    └── merge_lora_transformer.py
```

🛠️ Installation

Reference environment (what the paper numbers were produced with):

| Component | Version |
| --- | --- |
| Python | 3.10 |
| CUDA | 12.4 |
| PyTorch | 2.6.0 |
| GPU | NVIDIA H20 / H100 / H800 / A100-80GB |
| NCCL / IB | RoCE or InfiniBand for multi-node |

```bash
git clone https://github.com/Harahan/RTDMD.git
cd RTDMD

conda create -n rtdmd python=3.10 -y
conda activate rtdmd

pip install -r requirements.txt
pip install -e .
```

requirements.txt is a pinned snapshot of the paper environment (flash-attn, peft, the exact diffusers git commit, and mmcv / mmdet for the GenEval scorer). If flash-attn fails to build, drop the line — the model loaders fall back to PyTorch SDPA automatically.

Pretrained models

pretrained_path and *_init_path accept either a local directory or a Hugging Face Hub repo id; diffusers' from_pretrained() downloads and caches the weights on first use. Gated repos (e.g. black-forest-labs/FLUX.1-dev) require running huggingface-cli login with an authorized token first.
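A minimal sketch of the local-directory vs Hub-repo-id distinction, assuming a simple rule: an existing directory wins, otherwise the string must look like an `org/name` repo id. The helper `classify_pretrained_path` is hypothetical and not part of the RTDMD package; the real loaders leave this resolution to diffusers.

```python
import os
import re

# Hypothetical rule: "org/name" made of word characters, dots, and dashes.
_REPO_ID = re.compile(r"^[\w.-]+/[\w.-]+$")

def classify_pretrained_path(path: str) -> str:
    """Return 'local' for an existing directory, 'hub' for an org/name repo id."""
    if os.path.isdir(path):
        return "local"
    if _REPO_ID.match(path):
        return "hub"
    raise ValueError(f"{path!r} is neither an existing directory nor an org/name repo id")

print(classify_pretrained_path("black-forest-labs/FLUX.1-dev"))
```

Checking the filesystem first mirrors the usual convention that a local checkout shadows a Hub id with the same spelling.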

Reward checkpoints

Point RTDMD_REWARD_CKPT_PATH (or each config's reward_ckpt_path) at a local directory for the reward-model weights. Most scorers auto-download on first use: PickScore, HPSv3, ImageReward, CLIPScore, GenEval2 (Qwen3-VL Soft-TIFA), OCR (PaddleOCR), and the GenEval Mask2Former backbone (pulled from the OpenMMLab CDN into reward_ckpts/).

Only two scorers need a one-time wget:

```bash
mkdir -p reward_ckpts && cd reward_ckpts
# Aesthetic predictor (LAION)
wget https://github.com/christophschuhmann/improved-aesthetic-predictor/raw/refs/heads/main/sac+logos+ava1-l14-linearMSE.pth
# HPSv2.1 (OpenCLIP backbone + HPS classifier head)
wget https://huggingface.co/laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin
wget https://huggingface.co/xswu/HPSv2/resolve/main/HPS_v2.1_compressed.pt
cd ..
export RTDMD_REWARD_CKPT_PATH=$(pwd)/reward_ckpts
```
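Before launching a run, a quick check that the reward directory holds the three manually fetched files can save a failed first step. This is a sketch: the file names come from the wget commands above, while the helper `missing_reward_ckpts` and its fallback to a relative `reward_ckpts` directory are hypothetical.

```python
import os
from pathlib import Path

# Files fetched manually by the wget commands above.
MANUAL_FILES = [
    "sac+logos+ava1-l14-linearMSE.pth",   # LAION aesthetic predictor
    "open_clip_pytorch_model.bin",        # OpenCLIP ViT-H-14 backbone for HPSv2.1
    "HPS_v2.1_compressed.pt",             # HPSv2.1 checkpoint
]

def missing_reward_ckpts(root=None):
    """Return the manually downloaded reward files absent from `root`.

    Hypothetical helper: when `root` is not given it reads
    RTDMD_REWARD_CKPT_PATH, falling back to ./reward_ckpts.
    """
    root = Path(root or os.environ.get("RTDMD_REWARD_CKPT_PATH", "reward_ckpts"))
    return [name for name in MANUAL_FILES if not (root / name).is_file()]

if __name__ == "__main__":
    missing = missing_reward_ckpts()
    print("all manual reward checkpoints present" if not missing else f"missing: {missing}")
```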

GenEval evaluates against the COCO-80 object categories (the Mask2Former detector we use is trained on COCO) — the class-name lookup ships at rtdmd/rewards/assets/object_names.txt, so no extra setup is needed beyond pip install -r requirements.txt.

Optional pre-warm so the first training step doesn't stall on HuggingFace downloads:

```bash
python - <<'PY'
from transformers import AutoModel, AutoProcessor, CLIPModel, CLIPProcessor
AutoProcessor.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K")
AutoModel.from_pretrained("yuvalkirstain/PickScore_v1")
CLIPModel.from_pretrained("openai/clip-vit-large-patch14")
CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
PY
```

🚀 Quick Start

All examples below use FLUX.2-klein 4B. The other four supported backbones (SD3-M, SD3.5-M, FLUX.1-dev, FLUX.2-klein 9B) use the exact same commands — only the YAML basename changes under each configs/{cold_start,rtdmd,inference}/ directory.

1. Cold-start distillation (AC-DMD)

All five models run cold-start on 1 node × 8 GPUs:

```bash
bash scripts/cold_start.sh 8 configs/cold_start/flux2_4b.yaml
```

2. RL fine-tune (RTDMD = GRPO + AC-DMD / BP aux)

Recommended scale per model:

| Model | Nodes × GPUs/node | Total GPUs |
| --- | --- | ---: |
| SD3.5-M | 1 × 8 | 8 |
| SD3-M | 2 × 8 | 16 |
| FLUX.2-klein 4B | 2 × 8 | 16 |
| FLUX.1-dev | 4 × 8 | 32 |
| FLUX.2-klein 9B | 4 × 8 | 32 |

Single-node (e.g., SD3.5-M):

```bash
bash scripts/rtdmd.sh 8 configs/rtdmd/sd35m.yaml
```

Multi-node — FLUX.2-klein 4B on 2 × 8 GPUs (set the env vars on each node):

```bash
# Node 0
NNODES=2 NODE_RANK=0 MASTER_ADDR=<chief-ip> \
    bash scripts/rtdmd.sh 8 configs/rtdmd/flux2_4b.yaml

# Node 1
NNODES=2 NODE_RANK=1 MASTER_ADDR=<chief-ip> \
    bash scripts/rtdmd.sh 8 configs/rtdmd/flux2_4b.yaml
```

For 4-node jobs (FLUX.1-dev / FLUX.2-klein 9B) set NNODES=4 and launch on ranks 0..3 the same way. When the scheduler exports CHIEF_IP / INDEX / HOST_NUM / HOST_GPU_NUM these are picked up automatically.
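The variable handling described above can be sketched in Python for clarity. This is an assumption-laden illustration: the alias table follows the pairing the text implies (CHIEF_IP → MASTER_ADDR, INDEX → NODE_RANK, HOST_NUM → NNODES, HOST_GPU_NUM → GPUs per node), while the precedence order, the single-node defaults, and `resolve_launch_env` itself are hypothetical. `scripts/rtdmd.sh` is a shell script and may resolve these differently.

```python
# Assumed correspondence between scheduler-exported and launcher variables.
SCHEDULER_ALIASES = {"MASTER_ADDR": "CHIEF_IP", "NODE_RANK": "INDEX", "NNODES": "HOST_NUM"}
# Assumed single-node fallbacks, for illustration only.
DEFAULTS = {"MASTER_ADDR": "127.0.0.1", "NODE_RANK": "0", "NNODES": "1"}

def resolve_launch_env(env, gpus_per_node=None):
    """Resolve rendezvous settings: explicit vars win, then scheduler aliases, then defaults.

    `gpus_per_node` plays the role of the first argument to scripts/rtdmd.sh;
    when it is omitted, HOST_GPU_NUM is used instead.
    """
    resolved = {var: env.get(var, env.get(alias, DEFAULTS[var]))
                for var, alias in SCHEDULER_ALIASES.items()}
    if gpus_per_node is None:
        gpus_per_node = int(env["HOST_GPU_NUM"])
    resolved["GPUS_PER_NODE"] = gpus_per_node
    resolved["WORLD_SIZE"] = int(resolved["NNODES"]) * gpus_per_node
    return resolved

# Rank 3 of a 4 × 8 job (FLUX.1-dev / FLUX.2-klein 9B), configured by the scheduler only.
job = resolve_launch_env({"CHIEF_IP": "10.0.0.1", "INDEX": "3",
                          "HOST_NUM": "4", "HOST_GPU_NUM": "8"})
print(job)
```

The resulting world size matches the Total GPUs column of the scale table for the 4-node models.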

3. Inference

One YAML per model under configs/inference/. Each ships with the distilled + RL LoRA stack enabled by default. The three LoRA regimes are selected by the YAML's lora_paths:

  • lora_paths: [] → plain pretrained model, no LoRA
  • lora_paths: [distilled] → distilled-only LoRA
  • lora_paths: [distilled, rl] → distilled + RL LoRAs merged in order (YAML default)
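For intuition about the three regimes, here is a NumPy sketch of folding LoRA deltas into a base weight in list order, assuming the standard LoRA update $W \leftarrow W + \tfrac{\alpha}{r} BA$. The shapes, scales, and the `merge_loras` name are illustrative; the repo's own merge path (e.g. `scripts/merge_lora_transformer.py`) may differ in details such as dtype handling and which modules are targeted.

```python
import numpy as np

def merge_loras(base, loras):
    """Fold a sequence of (A, B, alpha) LoRA factors into `base`, in order.

    Each LoRA contributes (alpha / rank) * B @ A, where A has shape (rank, in)
    and B has shape (out, rank). Returns a new array; `base` is not modified.
    """
    merged = base.copy()
    for A, B, alpha in loras:
        rank = A.shape[0]
        merged += (alpha / rank) * (B @ A)
    return merged

rng = np.random.default_rng(0)
W0 = rng.standard_normal((6, 4))                       # frozen base weight (out=6, in=4)
distilled = (rng.standard_normal((2, 4)), rng.standard_normal((6, 2)), 2.0)
rl = (rng.standard_normal((2, 4)), rng.standard_normal((6, 2)), 1.0)

W_plain = merge_loras(W0, [])                  # lora_paths: []
W_distilled = merge_loras(W0, [distilled])     # lora_paths: [distilled]
W_full = merge_loras(W0, [distilled, rl])      # lora_paths: [distilled, rl]
```

In this sketch the merges are purely additive, so swapping the two LoRAs gives the same matrix up to floating-point rounding; what distinguishes the three regimes is which deltas are present.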

Distilled few-step generation (FLUX.2-klein 4B), 8 GPUs, no reward scoring:

```bash
bash scripts/inference.sh 8 configs/inference/flux2_4b.yaml \
    --override eval_reward=false --prompt "a cute cat sitting on a windowsill"
```

No LoRA (plain pretrained) or distilled-only LoRA via CLI override:

```bash
# No LoRA
bash scripts/inference.sh 8 configs/inference/flux2_4b.yaml --override lora_paths=

# Distilled-only LoRA
bash scripts/inference.sh 8 configs/inference/flux2_4b.yaml \
    --override lora_paths=/path/to/flux2_4b_cold_start_ckpt/checkpoint-15000/generator_ema.pt
```

4. Reward evaluation

Reward evaluation uses the same launcher with eval_reward=true (already the YAML default). It generates images for the datasets configured in the YAML and writes per-reward scores plus their weighted mean to inference_outputs/<model>/metadata.json:

```bash
bash scripts/inference.sh 8 configs/inference/flux2_4b.yaml
```

The default eval block mirrors training: drawbench for most rewards plus hpsv3 / geneval / geneval2 / ocr on their own sub-datasets, capped at num_media_images: 64 prompts per dataset. See the reward_fn and reward_dataset_map sections of each inference YAML for per-reward weights and dataset routing.
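The routing described above can be sketched as follows, assuming `reward_dataset_map` maps a reward name to a dataset name and any unmapped reward falls back to drawbench. Apart from `reward_dataset_map`, drawbench, and the 64-prompt cap, the structure here (plain dicts of prompt lists, the `route_eval_prompts` helper, dataset names matching reward names) is hypothetical; check the inference YAMLs for the real schema.

```python
def route_eval_prompts(reward_weights, reward_dataset_map, datasets,
                       default_dataset="drawbench", num_media_images=64):
    """Group rewards by the dataset they are scored on, capping prompts per dataset."""
    plan = {}
    for reward in reward_weights:
        name = reward_dataset_map.get(reward, default_dataset)
        entry = plan.setdefault(name, {"prompts": datasets[name][:num_media_images],
                                       "rewards": []})
        entry["rewards"].append(reward)
    return plan

# Placeholder prompt lists; real datasets are loaded from the YAML's configuration.
datasets = {
    "drawbench": [f"drawbench prompt {i}" for i in range(200)],
    "hpsv3": [f"hpsv3 prompt {i}" for i in range(10)],
    "geneval": [f"geneval prompt {i}" for i in range(100)],
}
plan = route_eval_prompts(
    reward_weights={"pickscore": 1.0, "hpsv2": 1.0, "hpsv3": 1.0, "geneval": 1.0},
    reward_dataset_map={"hpsv3": "hpsv3", "geneval": "geneval"},
    datasets=datasets,
)
```

Grouping by dataset means each prompt set is generated once and then scored by every reward routed to it.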


⚙️ Configuration

Configuration is defined as pure-Python dataclasses, loaded from YAML and adjustable with dot-notation CLI overrides:

```bash
bash scripts/rtdmd.sh 8 configs/rtdmd/flux2_4b.yaml \
    --override train.seed=123 dmd.fake_update_ratio=10
```
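A minimal sketch of how dot-notation overrides such as `train.seed=123` can be applied to nested dataclasses. The `TrainConfig`, `DMDConfig`, and `Config` classes are a hypothetical subset (only `seed` and `fake_update_ratio` appear in the command above), and the casting rule (convert the string to the type of the current value) is an assumption, not necessarily what `rtdmd/config.py` does.

```python
from dataclasses import dataclass, field

@dataclass
class TrainConfig:                 # hypothetical subset of the `train` section
    seed: int = 42
    batch_size: int = 1

@dataclass
class DMDConfig:                   # hypothetical subset of the `dmd` section
    fake_update_ratio: int = 5

@dataclass
class Config:                      # hypothetical stand-in for RTDMDConfig
    train: TrainConfig = field(default_factory=TrainConfig)
    dmd: DMDConfig = field(default_factory=DMDConfig)

def apply_override(cfg, override):
    """Apply one 'section.key=value' override in place, casting to the current value's type."""
    dotted, _, raw = override.partition("=")
    *parents, leaf = dotted.split(".")
    node = cfg
    for name in parents:
        node = getattr(node, name)
    current = getattr(node, leaf)           # AttributeError for unknown keys
    if isinstance(current, bool):           # check bool before int (bool subclasses int)
        value = raw.strip().lower() in ("1", "true", "yes")
    elif current is None:                   # no type to cast to: keep the raw string
        value = raw
    else:
        value = type(current)(raw)
    setattr(node, leaf, value)
    return cfg

cfg = Config()
for item in ["train.seed=123", "dmd.fake_update_ratio=10"]:
    apply_override(cfg, item)
```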

Top-level sections of RTDMDConfig (see rtdmd/config.py):

| Section | Purpose |
| --- | --- |
| model | Pretrained path (HF Hub repo id or local dir), dtype, LoRA settings (generator / fake-score / teacher). |
| dmd | DMD hyperparameters: CPS sampler η, denoising step list, fake-score TTUR ratio. |
| ac_dmd | AC-DMD sub-interval renoising bounds and consistency-loss knobs. |
| grpo | GRPO sampling / PPO settings + last_step_loss (AC-DMD / BP aux on the deterministic last step). |
| solver | Per-role AdamW configs (generator / fake_score / teacher). |
| train | Steps, batch size, autocast dtype, EMA, resume. |
| distributed | fsdp or ddp; FSDP sharding strategy (full_shard / hybrid / shard_grad_op); CPU offload for frozen aux models. |
| eval | Periodic reward-evaluation knobs. |
| logging | wandb project / run name / tags. |

The dataclass loader silently drops unknown keys, so old configs remain loadable across refactors.
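Dropping unknown keys is commonly done by filtering a YAML mapping against `dataclasses.fields()` before constructing the dataclass. The sketch below shows that pattern for one flat section; `LoggingConfig`, its fields, and `from_dict_lenient` are hypothetical, and whether the real loader also logs the dropped keys is not stated.

```python
from dataclasses import dataclass, field, fields

@dataclass
class LoggingConfig:          # hypothetical stand-in for the `logging` section
    project: str = "rtdmd"
    run_name: str = ""
    tags: list = field(default_factory=list)

def from_dict_lenient(cls, data):
    """Build dataclass `cls` from a mapping, silently ignoring keys it does not define."""
    known = {f.name for f in fields(cls)}
    return cls(**{key: value for key, value in data.items() if key in known})

# "entity" stands in for a key left over from an older config; it is dropped.
cfg = from_dict_lenient(LoggingConfig, {"project": "demo", "entity": "old-team"})
```

The trade-off is that a misspelled key is dropped just as quietly, so a typo in a YAML falls back to the default value without an error.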


🎁 Reward Scorers

MultiScorer (in rtdmd/rewards/) wraps nine backends that can be combined as {name: weight} inside any reward_fn block: pickscore, hpsv2, hpsv3, clipscore, aesthetic, imagereward, ocr, geneval, geneval2.
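One way a `{name: weight}` block could turn per-scorer outputs into a single number is sketched below, assuming the combined score is the weight-normalized mean of per-reward scores. The normalization is an assumption (a plain weighted sum is equally plausible), `weighted_reward` is a hypothetical helper rather than the MultiScorer API, and the scores are made-up placeholders.

```python
# The nine backend names listed above.
SUPPORTED = {"pickscore", "hpsv2", "hpsv3", "clipscore", "aesthetic",
             "imagereward", "ocr", "geneval", "geneval2"}

def weighted_reward(scores, weights):
    """Combine per-reward scores with {name: weight}, normalized by the total weight."""
    unknown = set(weights) - SUPPORTED
    if unknown:
        raise KeyError(f"unsupported reward(s): {sorted(unknown)}")
    total = sum(weights.values())
    if total <= 0:
        raise ValueError("reward weights must sum to a positive value")
    return sum(weights[name] * scores[name] for name in weights) / total

score = weighted_reward({"pickscore": 22.0, "hpsv2": 0.30, "imagereward": 1.2},
                        {"pickscore": 0.5, "hpsv2": 1.0, "imagereward": 0.5})
```

Because raw scales differ widely (PickScore sits around 22 while HPSv2 sits around 0.3 in the tables above), the weights also act as implicit scale factors when mixing rewards.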

The differentiable subset (pickscore, hpsv2, clipscore, imagereward) can be plugged into reward back-propagation on the deterministic final step by setting last_step_loss.bp_enabled: true in the RTDMD YAML — the rest are scored offline as part of GRPO advantages.
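For rewards used only on the GRPO side, each sample just needs a scalar score. A common way to turn such scores into advantages is group-relative normalization: among images generated from the same prompt, subtract the group mean and divide by the group standard deviation. This sketch follows that generic GRPO recipe; SubGRPO's exact grouping, epsilon, and clipping are not described here, and `group_advantages` is a hypothetical name.

```python
import statistics

def group_advantages(rewards, group_size, eps=1e-6):
    """Group-relative advantages for prompt-major rewards (`group_size` samples per prompt)."""
    if len(rewards) % group_size:
        raise ValueError("number of rewards must be a multiple of group_size")
    advantages = []
    for start in range(0, len(rewards), group_size):
        group = rewards[start:start + group_size]
        mean = statistics.fmean(group)
        std = statistics.pstdev(group)
        advantages.extend((r - mean) / (std + eps) for r in group)
    return advantages

# Two prompts, three samples each; the second group has identical rewards.
adv = group_advantages([1.0, 2.0, 3.0, 0.5, 0.5, 0.5], group_size=3)
```

A group whose samples all score the same yields zero advantage, so it contributes no policy-gradient signal for that prompt.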


🙌 Acknowledgements


📄 Citation

```bibtex
@misc{huang2026reinforcingfewstepgeneratorsrewardtilted,
      title={Reinforcing Few-step Generators via Reward-Tilted Distribution Matching},
      author={Yushi Huang and Xiangxin Zhou and Ruoyu Wang and Chi Zhang and Jun Zhang and Tianyu Pang},
      year={2026},
      eprint={2605.26108},
      archivePrefix={arXiv},
      primaryClass={cs.CV},
      url={https://arxiv.org/abs/2605.26108},
}
```

⚖️ License

This project is licensed under the Apache License 2.0 — see LICENSE. The supported teacher checkpoints (SD3 / SD3.5 / FLUX.1 / FLUX.2) are released under their original licenses; please comply with each upstream license when using them.

About

[Arxiv 2026] This is the official PyTorch implementation of "RTDMD: Reinforcing Few-step Generators via Reward-Tilted Distribution Matching"
