Yushi Huang<sup>1,2,*†</sup>, Xiangxin Zhou<sup>2,*</sup>, Ruoyu Wang<sup>2,3,*†</sup>, Chi Zhang<sup>3</sup>, Jun Zhang<sup>1</sup>, Tianyu Pang<sup>2,‡</sup>
<sup>1</sup>The Hong Kong University of Science and Technology · <sup>2</sup>Tencent Hunyuan · <sup>3</sup>Westlake University
<sup>*</sup> Equal contribution · <sup>†</sup> Work done during internship at Tencent Hunyuan · <sup>‡</sup> Corresponding author
- 📖 Abstract
- 🍭 Method Overview
- 📊 Main Results
- ✅ TODO
- 📁 Repository Layout
- 🛠️ Installation
- 🚀 Quick Start
- ⚙️ Configuration
- 🎁 Reward Scorers
- 🙌 Acknowledgements
- 📄 Citation
- ⚖️ License
We propose Reward-Tilted Distribution Matching Distillation (RTDMD), a two-stage framework that unifies distribution-matching distillation with reward-guided RL for few-step flow generators. Minimizing the KL divergence to a reward-tilted teacher distribution decomposes naturally into a distribution-matching term and a reward-maximization term, instantiated as Ambient-Consistent DMD (AC-DMD) for the cold start and as a hybrid policy gradient (SubGRPO + final-step reward back-propagation) for the RL stage. With 4 NFE, RTDMD sets a new state of the art on SD3-M, SD3.5-M, and FLUX.2 4B; the distilled FLUX.2 4B even beats the full FLUX.2 9B teacher (50 NFE) on 7 of the 9 rewards we report.
4-step samples from RTDMD-distilled FLUX.2 4B (no classifier-free guidance).
Qualitative comparison for few-step diffusion models (4 NFE).
RTDMD overview. Det. = deterministic final step, Stoc. = stochastic intermediate steps. Trajectories: teacher (blue), few-step generator (green), fake score (yellow).
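For the generator, this objective splits into a distribution-matching term (instantiated as AC-DMD) and a reward-maximization term (instantiated as the hybrid policy gradient).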
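As a sketch of why the objective splits this way, assume the reward-tilted target takes the common exponential-tilting form below, with a scalar reward $r$ and a temperature $\beta > 0$ (both symbols are our notation here, not necessarily the paper's):

$$
p^{*}(x) = \frac{p_{\text{teacher}}(x)\, e^{r(x)/\beta}}{Z}, \qquad Z = \int p_{\text{teacher}}(x)\, e^{r(x)/\beta}\, dx ,
$$

$$
\mathrm{KL}\big(p_\theta \,\|\, p^{*}\big)
= \underbrace{\mathbb{E}_{x \sim p_\theta}\!\left[\log \frac{p_\theta(x)}{p_{\text{teacher}}(x)}\right]}_{\text{distribution matching}}
\;-\; \frac{1}{\beta}\,\underbrace{\mathbb{E}_{x \sim p_\theta}\big[r(x)\big]}_{\text{reward}}
\;+\; \log Z .
$$

The first term is the reverse KL that DMD-style distillation minimizes, the second is maximized by reward optimization, and $\log Z$ does not depend on $\theta$, so it drops out of the gradient.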
The two terms map directly to the two trainers exposed by the CLI:
| Stage | Trainer | Key knobs |
|---|---|---|
| 1. AC-DMD cold start | `ACDMDTrainer` (`--trainer ac_dmd`) | sub-interval renoising, consistency weight γ, CPS sampler η = 0.9 |
| 2. RTDMD RL fine-tune | `RTDMDTrainer` (`--trainer rtdmd`) | SubGRPO + final-step BP + AC-DMD |
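The policy-gradient half of the RL stage builds on GRPO, which samples a group of images per prompt and normalizes their rewards within that group. The sketch below shows only that standard group-relative advantage; how SubGRPO chooses its sub-intervals and how its loss is combined with final-step back-propagation and AC-DMD are not modeled, and `group_relative_advantages` is a hypothetical helper name.

```python
from statistics import fmean, pstdev
from typing import List


def group_relative_advantages(rewards: List[float], eps: float = 1e-6) -> List[float]:
    """Standard GRPO advantage: (r_i - mean) / (std + eps) within one prompt's group."""
    if not rewards:
        raise ValueError("need at least one reward per group")
    mu = fmean(rewards)
    sigma = pstdev(rewards)  # population std over the group
    return [(r - mu) / (sigma + eps) for r in rewards]


# Four images sampled for the same prompt, each scored by some reward mix.
adv = group_relative_advantages([0.9, 1.1, 1.3, 0.7])
print([round(a, 3) for a in adv])
```

Because the advantages are zero-mean within a group, a prompt whose samples all receive the same reward contributes no policy-gradient signal under this normalization.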
All few-step models use 4 NFE (4 inference steps); teachers and other multi-step references use the NFE listed in each table. Bold = best; underline = second-best.
| Method | NFE | CLIPScore ↑ | Aesthetic ↑ | PickScore ↑ | HPSv2 ↑ | ImageReward ↑ |
|---|---|---|---|---|---|---|
| SD3-M teacher (w/ CFG) | 100 | <u>0.2936</u> | 5.5711 | 22.3236 | 0.2810 | 1.0759 |
| GDMD | 4 | 0.2930 | 5.8728 | 22.4614 | <u>0.3076</u> | 1.2702 |
| Rdm | 4 | <u>0.2936</u> | <u>5.8769</u> | <u>22.5783</u> | 0.2957 | <u>1.2897</u> |
| RTDMD (Ours) | 4 | **0.3161** | **5.9642** | **22.8593** | **0.3211** | **1.3024** |
RTDMD is the only 4-NFE model that surpasses the 100-NFE SD3-M teacher with CFG across all five metrics — see the paper for the full baseline table.
| Method | NFE | ImageReward ↑ | CLIPScore ↑ | Aesthetic ↑ | PickScore ↑ | HPSv2 ↑ | HPSv3 ↑ | GenEval ↑ | GenEval2 ↑ | OCR ↑ |
|---|---|---|---|---|---|---|---|---|---|---|
| FLUX.2 4B teacher | 50 | 0.8538 | 0.2834 | 5.3333 | 22.3938 | 0.2771 | 11.7025 | 0.7631 | 0.2207 | 0.6133 |
| FLUX.2 9B teacher | 50 | 1.0021 | <u>0.2962</u> | 5.2030 | 22.6382 | 0.2800 | 11.6883 | 0.7568 | 0.3557 | 0.7432 |
| Z-Image 6B | 50 | 0.7841 | 0.2841 | 5.2488 | 22.2118 | 0.2714 | 10.0857 | 0.6563 | 0.3012 | 0.7373 |
| Z-Image-Turbo 6B | 4 | 0.9696 | 0.2764 | 5.2894 | 22.7994 | 0.2954 | 12.9136 | 0.7562 | 0.3530 | 0.7539 |
| FLUX.2 4B | 4 | 1.0506 | 0.2864 | 5.2658 | 22.7370 | 0.2890 | 12.9295 | 0.7722 | 0.2403 | 0.6375 |
| FLUX.2 9B | 4 | <u>1.1998</u> | 0.2919 | <u>5.3730</u> | <u>23.0178</u> | 0.2991 | 13.2955 | <u>0.7814</u> | <u>0.3570</u> | <u>0.7566</u> |
| Z-Image 6B w/ TDM-R1 | 4 | 1.1543 | 0.2836 | 5.2450 | 22.8202 | <u>0.3064</u> | <u>13.4349</u> | 0.7737 | **0.4073** | **0.7665** |
| FLUX.2 4B w/ RTDMD (Ours) | 4 | **1.3712** | **0.3219** | **5.7746** | **23.9642** | **0.3516** | **15.5772** | **0.9046** | 0.2755 | 0.6858 |
RTDMD on FLUX.2 4B is the best 4-NFE model on 7 of 9 rewards (ImageReward / CLIPScore / Aesthetic / PickScore / HPSv2 / HPSv3 / GenEval) and beats the FLUX.2 9B teacher at 50 NFE on every one of those seven — including +0.37 ImageReward, +0.57 Aesthetic, +1.33 PickScore, +3.89 HPSv3, and +0.15 GenEval.
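The deltas quoted above can be re-derived directly from the table. The snippet below copies the FLUX.2 9B teacher (50 NFE) and RTDMD rows verbatim and counts the metrics where RTDMD scores higher; it is only an arithmetic check, not part of the repo's evaluation code.

```python
# Values copied from the FLUX.2 table above: FLUX.2 9B teacher (50 NFE) vs. RTDMD (4 NFE).
METRICS = ["ImageReward", "CLIPScore", "Aesthetic", "PickScore", "HPSv2",
           "HPSv3", "GenEval", "GenEval2", "OCR"]
TEACHER_9B = [1.0021, 0.2962, 5.2030, 22.6382, 0.2800, 11.6883, 0.7568, 0.3557, 0.7432]
RTDMD_4B = [1.3712, 0.3219, 5.7746, 23.9642, 0.3516, 15.5772, 0.9046, 0.2755, 0.6858]

# Per-metric difference (RTDMD minus teacher), rounded to the table's 4 decimals.
deltas = {m: round(r - t, 4) for m, r, t in zip(METRICS, RTDMD_4B, TEACHER_9B)}
wins = [m for m, d in deltas.items() if d > 0]

print(f"RTDMD beats the 9B teacher on {len(wins)}/{len(METRICS)} metrics: {wins}")
print(deltas)
```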
- Release more RTDMD checkpoints (FLUX.2-klein 9B and FLUX.1-dev) on the RTDMD Hugging Face collection
```text
RTDMD/
├── main.py                     # Training entry point
├── inference.py                # Inference entry point
├── configs/
│   ├── cold_start/             # AC-DMD distillation YAMLs (5 backbones)
│   ├── rtdmd/                  # RTDMD RL fine-tune YAMLs (5 backbones)
│   └── inference/              # Inference YAMLs (5 backbones)
├── rtdmd/                      # Source package: trainers/, models/, schedulers/,
│                               # rewards/, data/, parallel/, utils/, diffusers_patch/
└── scripts/
    ├── cold_start.sh           # AC-DMD launcher (single / multi-node)
    ├── rtdmd.sh                # RTDMD launcher (single / multi-node)
    ├── inference.sh            # Inference launcher
    └── merge_lora_transformer.py
```
Reference environment (what the paper numbers were produced with):
| Component | Version |
|---|---|
| Python | 3.10 |
| CUDA | 12.4 |
| PyTorch | 2.6.0 |
| GPU | NVIDIA H20 / H100 / H800 / A100-80GB |
| NCCL / IB | RoCE or InfiniBand for multi-node |
```bash
git clone https://github.com/Harahan/RTDMD.git
cd RTDMD
conda create -n rtdmd python=3.10 -y
conda activate rtdmd
pip install -r requirements.txt
pip install -e .
```
`requirements.txt` is a pinned snapshot of the paper environment (flash-attn, peft, the exact diffusers git commit, and mmcv / mmdet for the GenEval scorer). If flash-attn fails to build, drop that line; the model loaders fall back to PyTorch SDPA automatically.
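The fallback described above amounts to an import-availability check. A minimal sketch of that pattern, assuming the choice is made once at load time; `pick_attention_backend` and its return strings are hypothetical, not the repo's loader API:

```python
import importlib.util


def pick_attention_backend(prefer_flash: bool = True) -> str:
    """Return "flash_attn" when the flash-attn package is installed, else "sdpa".

    "sdpa" stands for torch.nn.functional.scaled_dot_product_attention,
    which ships with PyTorch 2.x and needs no separate build step.
    """
    if prefer_flash and importlib.util.find_spec("flash_attn") is not None:
        return "flash_attn"
    return "sdpa"


print(pick_attention_backend())
```

`find_spec` only confirms that the package is installed without importing it; it does not verify that the compiled kernels match the local GPU or CUDA version.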
pretrained_path and *_init_path accept either a local directory or a
HuggingFace Hub repo id; diffusers.from_pretrained() downloads and caches
the weights on first use. Gated repos (e.g. black-forest-labs/FLUX.1-dev)
require huggingface-cli login with an authorized token first.
Point RTDMD_REWARD_CKPT_PATH (or each config's reward_ckpt_path) at a
local directory for the reward-model weights. Most scorers auto-download
on first use: PickScore, HPSv3, ImageReward, CLIPScore, GenEval2
(Qwen3-VL Soft-TIFA), OCR (PaddleOCR), and the GenEval Mask2Former
backbone (pulled from the OpenMMLab CDN into reward_ckpts/).
Only two scorers need a one-time `wget`:
```bash
mkdir -p reward_ckpts && cd reward_ckpts
# Aesthetic predictor (LAION)
wget https://github.com/christophschuhmann/improved-aesthetic-predictor/raw/refs/heads/main/sac+logos+ava1-l14-linearMSE.pth
# HPSv2.1 (OpenCLIP backbone + HPS classifier head)
wget https://huggingface.co/laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin
wget https://huggingface.co/xswu/HPSv2/resolve/main/HPS_v2.1_compressed.pt
cd ..
export RTDMD_REWARD_CKPT_PATH=$(pwd)/reward_ckpts
```
GenEval evaluates against the COCO-80 object categories (the Mask2Former detector we use is trained on COCO); the class-name lookup ships at `rtdmd/rewards/assets/object_names.txt`, so no extra setup is needed beyond `pip install -r requirements.txt`.
Optional pre-warm so the first training step doesn't stall on Hugging Face downloads:
```bash
python - <<'PY'
from transformers import AutoModel, AutoProcessor, CLIPModel, CLIPProcessor

# PickScore: LAION CLIP-ViT-H-14 processor + PickScore_v1 weights
AutoProcessor.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K")
AutoModel.from_pretrained("yuvalkirstain/PickScore_v1")
# OpenAI CLIP ViT-L/14 model + processor
CLIPModel.from_pretrained("openai/clip-vit-large-patch14")
CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
PY
```
All examples below use FLUX.2-klein 4B. The other four supported backbones (SD3-M, SD3.5-M, FLUX.1-dev, FLUX.2-klein 9B) use exactly the same commands; only the YAML basename changes under each `configs/{cold_start,rtdmd,inference}/` directory.
All five models run the cold-start stage on 1 node × 8 GPUs:
```bash
bash scripts/cold_start.sh 8 configs/cold_start/flux2_4b.yaml
```
Recommended scale per model for the RTDMD RL fine-tune:
| Model | Nodes × GPUs/node | Total GPUs |
|---|---|---|
| SD3.5-M | 1 × 8 | 8 |
| SD3-M | 2 × 8 | 16 |
| FLUX.2-klein 4B | 2 × 8 | 16 |
| FLUX.1-dev | 4 × 8 | 32 |
| FLUX.2-klein 9B | 4 × 8 | 32 |
Single-node (e.g., SD3.5-M):
```bash
bash scripts/rtdmd.sh 8 configs/rtdmd/sd35m.yaml
```
Multi-node, FLUX.2-klein 4B on 2 × 8 GPUs (set the env vars on each node):
```bash
# Node 0
NNODES=2 NODE_RANK=0 MASTER_ADDR=<chief-ip> \
  bash scripts/rtdmd.sh 8 configs/rtdmd/flux2_4b.yaml
# Node 1
NNODES=2 NODE_RANK=1 MASTER_ADDR=<chief-ip> \
  bash scripts/rtdmd.sh 8 configs/rtdmd/flux2_4b.yaml
```
For 4-node jobs (FLUX.1-dev / FLUX.2-klein 9B), set `NNODES=4` and launch on ranks 0..3 the same way. When the scheduler exports `CHIEF_IP` / `INDEX` / `HOST_NUM` / `HOST_GPU_NUM`, these are picked up automatically.
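A minimal sketch of how such scheduler variables could map onto torchrun-style launch settings (`--nnodes`, `--node_rank`, `--master_addr`, `--nproc_per_node`). It assumes explicitly set `NNODES` / `NODE_RANK` / `MASTER_ADDR` take precedence over `HOST_NUM` / `INDEX` / `CHIEF_IP`, and that `HOST_GPU_NUM` supplies the per-node GPU count; `resolve_launch_env`, this precedence order, and the fallback defaults are illustrative, not read from `scripts/rtdmd.sh`.

```python
from collections.abc import Mapping
from typing import Dict, Optional


def resolve_launch_env(env: Mapping, gpus_per_node: Optional[int] = None) -> Dict[str, object]:
    """Map scheduler-provided variables onto torchrun-style launch settings.

    Explicit NNODES / NODE_RANK / MASTER_ADDR win; otherwise fall back to the
    scheduler's HOST_NUM / INDEX / CHIEF_IP, then to single-node defaults.
    """
    nnodes = int(env.get("NNODES", env.get("HOST_NUM", "1")))
    node_rank = int(env.get("NODE_RANK", env.get("INDEX", "0")))
    master_addr = env.get("MASTER_ADDR", env.get("CHIEF_IP", "127.0.0.1"))
    if gpus_per_node is None:
        # Assumed default of 8 when neither the CLI nor HOST_GPU_NUM provides a count.
        gpus_per_node = int(env.get("HOST_GPU_NUM", "8"))
    if not 0 <= node_rank < nnodes:
        raise ValueError(f"NODE_RANK={node_rank} out of range for NNODES={nnodes}")
    return {
        "nnodes": nnodes,
        "node_rank": node_rank,
        "master_addr": master_addr,
        "nproc_per_node": gpus_per_node,
        "world_size": nnodes * gpus_per_node,
    }


# Node 1 of the 2 x 8 FLUX.2-klein 4B job, configured only via scheduler variables.
print(resolve_launch_env({"HOST_NUM": "2", "INDEX": "1", "CHIEF_IP": "10.0.0.1", "HOST_GPU_NUM": "8"}))
```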
One YAML per model under configs/inference/. Each ships with the
distilled + RL LoRA stack enabled by default. The three LoRA regimes are
selected by the YAML's lora_paths:
- `lora_paths: []` → plain pretrained model, no LoRA
- `lora_paths: [distilled]` → distilled-only LoRA
- `lora_paths: [distilled, rl]` → distilled + RL LoRAs merged in order (YAML default)
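Merging a LoRA adds its scaled low-rank update to the frozen base weight, $W \leftarrow W + \frac{\alpha}{r} BA$. A dependency-free sketch of stacking a distilled and an RL adapter that way, assuming plain additive LoRA with a per-adapter `alpha` and rank; `merge_lora`, `matmul`, and the toy matrices are hypothetical and do not reflect the checkpoint format used by `scripts/merge_lora_transformer.py`.

```python
from typing import List, Tuple

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    """Plain-Python matrix product (rows of a times columns of b)."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def merge_lora(base: Matrix, adapters: List[Tuple[Matrix, Matrix, float]]) -> Matrix:
    """Return W + sum_k (alpha_k / r_k) * B_k @ A_k, applying adapters in list order.

    Each adapter is (B, A, alpha) with B: d_out x r and A: r x d_in.
    """
    merged = [row[:] for row in base]  # copy so the base weight stays untouched
    for b, a, alpha in adapters:
        scale = alpha / len(a)  # len(a) is the LoRA rank r
        for i, row in enumerate(matmul(b, a)):
            for j, value in enumerate(row):
                merged[i][j] += scale * value
    return merged


W = [[1.0, 0.0], [0.0, 1.0]]
distilled = ([[1.0], [0.0]], [[0.0, 2.0]], 1.0)  # rank-1 update touching W[0][1]
rl = ([[0.0], [1.0]], [[1.0, 0.0]], 0.5)          # rank-1 update touching W[1][0]
print(merge_lora(W, [distilled, rl]))
```

Because every merge in this sketch is a plain addition, the final weight is the same whichever order the adapters are applied in.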
Distilled few-step generation (FLUX.2-klein 4B) on 8 GPUs, without reward scoring:
```bash
bash scripts/inference.sh 8 configs/inference/flux2_4b.yaml \
  --override eval_reward=false --prompt "a cute cat sitting on a windowsill"
```
No LoRA (plain pretrained model) or distilled-only LoRA via a CLI override:
```bash
# No LoRA
bash scripts/inference.sh 8 configs/inference/flux2_4b.yaml --override lora_paths=
# Distilled-only LoRA
bash scripts/inference.sh 8 configs/inference/flux2_4b.yaml \
  --override lora_paths=/path/to/flux2_4b_cold_start_ckpt/checkpoint-15000/generator_ema.pt
```
For reward evaluation, use the same launcher with `eval_reward=true` (already the YAML default). It generates images for the datasets baked into the YAML and writes per-reward and weighted-mean scores to `inference_outputs/<model>/metadata.json`:
```bash
bash scripts/inference.sh 8 configs/inference/flux2_4b.yaml
```
The default eval block mirrors training: `drawbench` for most rewards, plus `hpsv3` / `geneval` / `geneval2` / `ocr` on their own sub-datasets, capped at `num_media_images: 64` prompts per dataset. See the `reward_fn` and `reward_dataset_map` sections of each inference YAML for per-reward weights and dataset routing.
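To illustrate how per-reward scores and weights could combine into a single number, here is a sketch that assumes the weighted mean is normalized by the sum of weights over the scored rewards; the formula, the `weighted_mean` helper, and the example values are hypothetical rather than read from the repo or from any `metadata.json`.

```python
from typing import Dict


def weighted_mean(scores: Dict[str, float], weights: Dict[str, float]) -> float:
    """Return sum_k w_k * s_k / sum_k w_k over rewards that have both a score and a weight."""
    shared = [name for name in weights if name in scores]
    total_weight = sum(weights[name] for name in shared)
    if total_weight <= 0:
        raise ValueError("weights of the scored rewards must sum to a positive value")
    return sum(weights[name] * scores[name] for name in shared) / total_weight


# Hypothetical per-reward means from one eval run; clipscore has no weight, so it is ignored.
scores = {"pickscore": 22.0, "imagereward": 1.2, "clipscore": 0.3}
weights = {"pickscore": 0.5, "imagereward": 1.0}
print(round(weighted_mean(scores, weights), 4))
```

Raw reward scales differ by orders of magnitude (PickScore sits around 22 while CLIPScore sits around 0.3 in the tables above), so any weights used in such a mean also implicitly rescale the rewards.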
Configuration is pure-Python dataclasses plus YAML, with dot-notation CLI overrides:
```bash
bash scripts/rtdmd.sh 8 configs/rtdmd/flux2_4b.yaml \
  --override train.seed=123 dmd.fake_update_ratio=10
```
Top-level sections of `RTDMDConfig` (see `rtdmd/config.py`):
| Section | Purpose |
|---|---|
| `model` | Pretrained path (HF Hub repo id or local dir), dtype, LoRA settings (generator / fake-score / teacher). |
| `dmd` | DMD hyperparameters: CPS sampler η, denoising step list, fake-score TTUR ratio. |
| `ac_dmd` | AC-DMD sub-interval renoising bounds and consistency-loss knobs. |
| `grpo` | GRPO sampling / PPO settings + `last_step_loss` (auxiliary AC-DMD / BP loss on the deterministic last step). |
| `solver` | Per-role AdamW configs (generator / fake_score / teacher). |
| `train` | Steps, batch size, autocast dtype, EMA, resume. |
| `distributed` | `fsdp` or `ddp`; FSDP sharding strategy (`full_shard` / `hybrid` / `shard_grad_op`); CPU offload for frozen aux models. |
| `eval` | Periodic reward-evaluation knobs. |
| `logging` | wandb project / run name / tags. |
The dataclass loader silently drops unknown keys, so old configs remain loadable across refactors.
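Both behaviours described above, dot-notation overrides and silently dropping unknown YAML keys, can be sketched with standard-library dataclasses. The `TrainCfg` / `DMDCfg` / `Config` fields and the `load` / `apply_overrides` / `parse_value` helpers are hypothetical stand-ins, not the actual classes in `rtdmd/config.py`.

```python
import ast
from dataclasses import dataclass, field, fields, is_dataclass
from typing import Any, Dict, List


@dataclass
class TrainCfg:  # hypothetical subset of the real `train` section
    seed: int = 42
    max_steps: int = 1000


@dataclass
class DMDCfg:  # hypothetical subset of the real `dmd` section
    fake_update_ratio: int = 5


@dataclass
class Config:
    train: TrainCfg = field(default_factory=TrainCfg)
    dmd: DMDCfg = field(default_factory=DMDCfg)


def load(cls, data: Dict[str, Any]):
    """Build dataclass `cls` from a nested dict, silently dropping undeclared keys."""
    kwargs = {}
    for f in fields(cls):
        if f.name not in data:
            continue  # missing key: keep the dataclass default
        value = data[f.name]
        kwargs[f.name] = load(f.type, value) if is_dataclass(f.type) else value
    return cls(**kwargs)


def parse_value(text: str) -> Any:
    """Turn "123", "0.5", "true", "[1, 2]" into Python values; otherwise keep the string."""
    if text.lower() in ("true", "false"):
        return text.lower() == "true"
    try:
        return ast.literal_eval(text)
    except (ValueError, SyntaxError):
        return text


def apply_overrides(cfg: Any, overrides: List[str]) -> Any:
    """Apply "section.key=value" strings in place, e.g. "train.seed=123"."""
    for item in overrides:
        dotted, _, raw = item.partition("=")
        *parents, leaf = dotted.split(".")
        target = cfg
        for name in parents:
            target = getattr(target, name)
        if not hasattr(target, leaf):
            raise KeyError(f"unknown override key: {dotted}")
        setattr(target, leaf, parse_value(raw))
    return cfg


# Stand-in for yaml.safe_load(...) output; "legacy_knob" and "old_section" are dropped.
yaml_dict = {"train": {"seed": 7, "legacy_knob": True}, "old_section": {}}
cfg = load(Config, yaml_dict)
apply_overrides(cfg, ["train.seed=123", "dmd.fake_update_ratio=10"])
print(cfg)
```

Rejecting unknown override keys is a choice made in this sketch so that command-line typos fail loudly; the text above only states that unknown keys in the loaded config are dropped.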
`MultiScorer` (in `rtdmd/rewards/`) wraps nine backends that can be combined as `{name: weight}` inside any `reward_fn` block:
`pickscore`, `hpsv2`, `hpsv3`, `clipscore`, `aesthetic`, `imagereward`, `ocr`, `geneval`, `geneval2`.
The differentiable subset (`pickscore`, `hpsv2`, `clipscore`, `imagereward`) can be plugged into reward back-propagation on the deterministic final step by setting `last_step_loss.bp_enabled: true` in the RTDMD YAML; the rest enter training only through the GRPO advantages.
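The two reward paths can be pictured as partitioning a `reward_fn` weight dict by differentiability. A sketch assuming exactly the four differentiable scorers listed above; `split_reward_fn` and the example weights are hypothetical.

```python
from typing import Dict, Tuple

# Differentiable subset usable for final-step reward back-propagation (as listed above).
DIFFERENTIABLE = frozenset({"pickscore", "hpsv2", "clipscore", "imagereward"})
ALL_SCORERS = DIFFERENTIABLE | {"hpsv3", "aesthetic", "ocr", "geneval", "geneval2"}


def split_reward_fn(reward_fn: Dict[str, float]) -> Tuple[Dict[str, float], Dict[str, float]]:
    """Return (bp_weights, grpo_only_weights) for a {name: weight} reward mix."""
    unknown = set(reward_fn) - ALL_SCORERS
    if unknown:
        raise KeyError(f"unsupported scorers: {sorted(unknown)}")
    bp = {k: w for k, w in reward_fn.items() if k in DIFFERENTIABLE}
    grpo_only = {k: w for k, w in reward_fn.items() if k not in DIFFERENTIABLE}
    return bp, grpo_only


bp, grpo_only = split_reward_fn({"pickscore": 1.0, "hpsv3": 0.5, "ocr": 0.2})
print("back-prop:", bp)
print("GRPO only:", grpo_only)
```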
- diffusers, transformers, and peft: base generative-model stack and LoRA.
- Flow-GRPO: the SDE-step-with-logprob routine in `rtdmd/diffusers_patch/sde_with_logprob.py` is ported from this project.
- Teacher backbones: Stable Diffusion 3 / 3.5, FLUX.1, and FLUX.2.
```bibtex
@misc{huang2026reinforcingfewstepgeneratorsrewardtilted,
  title={Reinforcing Few-step Generators via Reward-Tilted Distribution Matching},
  author={Yushi Huang and Xiangxin Zhou and Ruoyu Wang and Chi Zhang and Jun Zhang and Tianyu Pang},
  year={2026},
  eprint={2605.26108},
  archivePrefix={arXiv},
  primaryClass={cs.CV},
  url={https://arxiv.org/abs/2605.26108},
}
```
This project is licensed under the Apache License 2.0; see `LICENSE`. The supported teacher checkpoints (SD3 / SD3.5 / FLUX.1 / FLUX.2) are released under their original licenses; please comply with each upstream license when using them.