First, download and set up the repo:

```bash
git clone https://github.com/Juanerx/Q-DiT.git
cd Q-DiT
```

Then create the environment and install required packages:

```bash
conda create -n qdit python=3.8
conda activate qdit
pip install -r requirements.txt
pip install .
```

If you want to use GPTQ or static quantization, generate the calibration data first:

```bash
cd scripts
python collect_cali_data.py
```

Then quantize the model:

```bash
bash quant_main.sh --image-size 256 --num-sampling-steps 50 --cfg-scale 1.5 --use_gptq
```

SC mode replaces floating-point matrix multiplications with stochastic computing (SC), trading precision for hardware efficiency. Run it via `scripts/quant_sc_main.py`:
```bash
python scripts/quant_sc_main.py \
    --wbits 8 --abits 8 --w_sym --a_sym \
    --timewise 0.5 --qklayerwise 1.0 --avlayerwise 0.2 \
    --sc_prec 8 --sc_enable \
    --image-size 256 --num-sampling-steps 100 --batch-size 16 \
    --ckpt pretrained_models/DiT-XL-2-256x256.pt
```

| Argument | Type | Default | Description |
|---|---|---|---|
| `--timewise` | float | 0.0 | Fraction of timesteps to use SC (0-1). SC is applied to the first `timewise * total_timesteps` diffusion steps (the noisiest ones). |
| `--qklayerwise` | float | 0.0 | Fraction of blocks to use SC for Q@K^T attention (0-1). |
| `--avlayerwise` | float | 0.0 | Fraction of blocks to use SC for Attn@V attention (0-1). |
| `--projlayerwise` | float | 0.0 | Fraction of blocks to use SC for output projection (0-1). |
| `--mlplayerwise` | float | 0.0 | Fraction of blocks to use SC for MLP fc1 and fc2 (0-1). |
| `--inputprojlayerwise` | float | 0.0 | Fraction of blocks to use SC for QKV input projection (0-1). |
| `--sc_prec` | int | 8 | SC precision in bits. Sets `stoc_len = 2^sc_prec` (e.g., 8 → `stoc_len=256`). |
| `--sc_enable` | flag | false | Use enable-signal SC multiplication (compact kernel) instead of standard XNOR/AND. |
| `--sc_noise_model` | flag | false | Replace real SC kernels with a fast analytical noise surrogate. See SC Noise Model below. |
| `--sc_noise_local_correction` | float | 0.15 | Variance correction for per-row/per-batch scaled matmuls (MLP, AV, QK, input proj). Recommended: 0.10 - 0.15. |
| `--sc_noise_global_correction` | float | 0.60 | Variance correction for per-tensor scaled matmuls (output proj). Recommended: 0.10 - 0.20. |
| `--reverse_layerwise` | flag | false | Apply SC to the last N blocks instead of the first N. |
| `--sc_skip_blocks` | str | "" | Comma-separated block indices to skip SC. E.g., "0,27". |
| `--debug_sc` | flag | false | Debug mode: run both FP and SC, log per-operator error, use FP result. |
| `--sc_config` | str | None | Path to JSON config for per-block/per-operator/per-group precision (overrides layerwise args). |
| `--save_sc_config` | str | None | Save the effective SC config to a JSON file after setup. |
| `--mp` | flag | false | Enable per-token-row mixed precision for QK/AV. Assigns different `stoc_len` levels to different token rows based on runtime importance. |
| `--mp_levels` | str | 256,128,64,32 | Comma-separated `stoc_len` levels for MP, sorted descending. |
| `--mp_fractions` | str | None | Comma-separated fractions of rows per MP level (must sum to 1). Default: equal fractions (1/N each). |
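As a rough illustration of how the fraction arguments could map to concrete blocks and sampling steps, the sketch below picks block indices from a `--*layerwise` fraction (honoring `--reverse_layerwise` and `--sc_skip_blocks`) and checks whether a step falls inside the `--timewise` window. The helper names and the floor rounding rule `int(fraction * total)` are assumptions made for illustration; the repo's actual selection logic may differ.

```python
def sc_block_indices(total_blocks, fraction, reverse=False, skip=()):
    """Hypothetical helper: blocks that would run SC for one operator."""
    n = int(fraction * total_blocks)  # assumed rounding rule (floor)
    if reverse:
        chosen = range(total_blocks - n, total_blocks)  # last n blocks
    else:
        chosen = range(n)  # first n blocks
    skipped = set(skip)
    return [b for b in chosen if b not in skipped]


def sc_step_active(step_idx, total_steps, timewise):
    """Hypothetical helper: True for the first timewise * total_steps steps."""
    return step_idx < int(timewise * total_steps)


# Example: half of DiT-XL/2's 28 blocks, first vs. last, skipping block 27.
first_half = sc_block_indices(28, 0.5)
last_half = sc_block_indices(28, 0.5, reverse=True, skip=(27,))
```

Under these assumptions, `--qklayerwise 0.5` on a 28-block model selects blocks 0 through 13, and `--timewise 0.5` with 100 sampling steps covers step indices 0 through 49.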
SC on Q@K^T only, 50% of timesteps, all blocks:

```bash
python scripts/quant_sc_main.py \
    --wbits 8 --abits 8 --w_sym --a_sym \
    --timewise 0.5 --qklayerwise 1.0 \
    --sc_prec 8 --sc_enable \
    --image-size 256 --num-sampling-steps 100 --batch-size 16 \
    --ckpt pretrained_models/DiT-XL-2-256x256.pt
```

SC on Q@K^T + Attn@V + MLP, all timesteps, 80% of blocks:

```bash
python scripts/quant_sc_main.py \
    --wbits 8 --abits 8 --w_sym --a_sym \
    --timewise 1.0 --qklayerwise 0.8 --avlayerwise 0.8 --mlplayerwise 0.8 \
    --sc_prec 8 --sc_enable \
    --image-size 256 --num-sampling-steps 100 --batch-size 16 \
    --ckpt pretrained_models/DiT-XL-2-256x256.pt
```

Save config for reuse:
```bash
python scripts/quant_sc_main.py \
    --wbits 8 --abits 8 --w_sym --a_sym \
    --timewise 0.5 --qklayerwise 1.0 --avlayerwise 0.5 \
    --sc_prec 8 --sc_enable \
    --save_sc_config my_config.json \
    --image-size 256 --num-sampling-steps 100 --batch-size 16 \
    --ckpt pretrained_models/DiT-XL-2-256x256.pt
```

The `--mp` flag enables adaptive precision assignment for QK and AV matmuls, integrated directly into the existing `_sc_qk`/`_sc_av` paths. Granularity matches the quantization granularity of each operator:

- QK (per-head): Heads are ranked by `||Q||_inf` across tokens. Heads with larger Q magnitude get a higher `stoc_len`. Uses the existing batched kernel path, so there is no speed regression.
- AV (per-token-row): Within each (batch, head), rows are ranked by `max(attn_row)`. Rows with more concentrated attention get a higher `stoc_len`.

Rows/heads are bucketed into N discrete `stoc_len` levels using quantile boundaries. Assignment is recomputed every forward pass.
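The bucketing step can be sketched in plain Python as follows. This is a minimal illustration under stated assumptions, not the repo's kernel-side implementation: the function name `assign_stoc_lens` is hypothetical, the importance scores (for example `||Q||_inf` per head or `max(attn_row)` per row) are assumed to be precomputed, and quantile boundaries are turned into item counts with `round`.

```python
def assign_stoc_lens(scores, levels, fractions=None):
    """Bucket items (rows or heads) into stoc_len levels by descending score.

    levels: stoc_len values sorted descending, e.g. [256, 128, 64, 32].
    fractions: share of items per level (must sum to 1); equal shares if None.
    """
    if fractions is None:
        fractions = [1.0 / len(levels)] * len(levels)
    if len(fractions) != len(levels) or abs(sum(fractions) - 1.0) > 1e-6:
        raise ValueError("fractions must match levels and sum to 1")

    n = len(scores)
    # Rank item indices from most to least important.
    order = sorted(range(n), key=lambda i: scores[i], reverse=True)

    # Walk the cumulative quantile boundaries, converted to item counts.
    assigned = [levels[-1]] * n
    start, cum = 0, 0.0
    for level, frac in zip(levels, fractions):
        cum += frac
        end = min(n, round(cum * n))  # assumed rounding of the boundary
        for i in order[start:end]:
            assigned[i] = level
        start = end
    return assigned


# Four rows, four levels, equal fractions: one row per level.
row_levels = assign_stoc_lens([0.9, 0.1, 0.5, 0.3], [256, 128, 64, 32])
```

Because the assignment depends only on the current scores, calling the function again on each forward pass reproduces the "recomputed every forward pass" behavior described above.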
Quick test (4 levels, equal fractions):

```bash
python scripts/quant_sc_main.py --wbits 8 --abits 8 --w_sym --a_sym --timewise 0.5 --qklayerwise 0.5 --avlayerwise 0.5 --sc_prec 8 --sc_enable --mp --mp_levels 256,128,64,32
```

Custom fractions (aggressive: only 10% at full precision):

```bash
python scripts/quant_sc_main.py \
    --wbits 8 --abits 8 --w_sym --a_sym \
    --timewise 0.5 --qklayerwise 0.5 --avlayerwise 0.5 \
    --sc_prec 8 --sc_enable \
    --mp --mp_levels 256,128,64,32 --mp_fractions 0.1,0.2,0.3,0.4
```

2 levels (simple high/low split):

```bash
python scripts/quant_sc_main.py \
    --wbits 8 --abits 8 --w_sym --a_sym \
    --timewise 0.5 --qklayerwise 0.5 --avlayerwise 0.5 \
    --sc_prec 8 --sc_enable \
    --mp --mp_levels 256,64 --mp_fractions 0.3,0.7
```

Note: `--mp` only activates for blocks/timesteps where QK/AV SC is already enabled via `--qklayerwise`/`--avlayerwise`/`--timewise`. It is handled inside the existing `_sc_qk`/`_sc_av` methods, with no separate code path.
The --sc_noise_model flag replaces the real SC Triton kernels with an analytical Gaussian noise surrogate. This is useful for large-scale evaluation runs (e.g. 50k FID samples) where the full bitstream SC simulation is too slow.
Instead of generating bitstreams, AND/XNOR-gating them, and counting bits, the surrogate computes:
```
out[i,j] = (a @ b^T)[i,j] + noise[i,j]
```

where `noise[i,j] ~ N(0, sigma^2)` with:

```
sigma = sqrt(D / L^2 * correction) * a_scale[i] * b_scale[j]
```

- `D` = reduction dimension (e.g. 1152 for DiT-XL hidden, 72 for head_dim)
- `L` = stoc_len (from `--sc_prec` or per-row MP assignment)
- `correction` = `--sc_noise_local_correction` or `--sc_noise_global_correction`
- `a_scale`, `b_scale` = per-row or per-tensor max of inputs (matches real SC kernel normalization)

The 1/L^2 scaling of the noise variance (instead of 1/L) reflects the Sobol low-discrepancy RNG used by the real SC kernels, which achieves O(1/L) error instead of the O(1/sqrt(L)) error of iid Bernoulli sampling.
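To make the formula concrete, here is a small pure-Python sketch of the surrogate for a single output element. It is an illustration only: the function names are hypothetical, the scales are assumed to be the per-row maximum absolute value, and the real implementation operates on batched tensors with `torch.compile` rather than Python lists.

```python
import math
import random


def sc_noise_sigma(D, L, correction, a_scale=1.0, b_scale=1.0):
    """sigma = sqrt(D / L^2 * correction) * a_scale * b_scale."""
    return math.sqrt(D / (L * L) * correction) * a_scale * b_scale


def noisy_dot(a_row, b_row, L, correction, rng=random):
    """Exact dot product plus one Gaussian noise sample (SC surrogate)."""
    exact = sum(x * y for x, y in zip(a_row, b_row))
    a_scale = max(abs(x) for x in a_row)  # assumed: per-row max of |a|
    b_scale = max(abs(y) for y in b_row)  # assumed: per-row max of |b|
    sigma = sc_noise_sigma(len(a_row), L, correction, a_scale, b_scale)
    return exact + rng.gauss(0.0, sigma)


# Halving stoc_len doubles sigma, since sigma is proportional to 1/L.
sigma_256 = sc_noise_sigma(D=72, L=256, correction=0.15)
sigma_128 = sc_noise_sigma(D=72, L=128, correction=0.15)
```

With `correction=0.0` the noise term vanishes and `noisy_dot` returns the exact product, which matches the debugging use of a zero correction value.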
```bash
# Uniform SC noise on all operators, using defaults (local=0.15, global=0.60)
python scripts/quant_sc_main.py \
    --wbits 8 --abits 8 --w_sym --a_sym \
    --timewise 1 --qklayerwise 1.0 --avlayerwise 1.0 \
    --projlayerwise 1.0 --mlplayerwise 1.0 --inputprojlayerwise 1.0 \
    --sc_prec 8 --sc_enable --sc_noise_model \
    --image-size 256 --num-sampling-steps 100 --batch-size 16

# With adaptive MP (same as real SC, but fast)
python scripts/quant_sc_main.py \
    --wbits 8 --abits 8 --w_sym --a_sym \
    --timewise 1 --qklayerwise 1.0 --avlayerwise 1.0 \
    --projlayerwise 1.0 --mlplayerwise 1.0 --inputprojlayerwise 1.0 \
    --sc_prec 8 --sc_enable --sc_noise_model \
    --adaptive_mp --mp_levels 256,128,64,32,16 \
    --mp_alpha 0.3 --mp_beta 0.1 \
    --image-size 256 --num-sampling-steps 100 --batch-size 16

# Custom noise correction (tune visually against real SC output)
python scripts/quant_sc_main.py \
    ... --sc_noise_model \
    --sc_noise_local_correction 0.12 \
    --sc_noise_global_correction 0.15
```

| Parameter | Affects | Recommended |
|---|---|---|
| `--sc_noise_local_correction` | MLP fc1/fc2, AV, QK, input proj (~85% of matmuls) | 0.10 - 0.15 |
| `--sc_noise_global_correction` | Output proj (~15% of matmuls) | 0.10 - 0.20 |
- Larger values = more noise = images look more like real SC with short stoc_len
- Smaller values = less noise = cleaner images, closer to FP baseline
- 0.0 = no noise at all (pure quantized matmul, useful for debugging)
- To match real SC visually, start at `local=0.15, global=0.15` and compare side-by-side
- Orthogonal to `--sc_enable` (both can be on simultaneously)
- All SC dispatch paths are supported: uniform, adaptive MP, range MP, per-head mixed
- The real SC path is completely untouched when `--sc_noise_model` is not set
- Uses `torch.compile` for kernel fusion; first ~5 iterations are slow (JIT warmup), then steady state ~12 it/s at batch=8 on RTX PRO 6000
The --sc_config flag enables fine-grained control over SC precision per block, per operator, and per group. This allows mixed-precision SC where different operators use different stochastic stream lengths (stoc_len).
Key idea: early termination. A shorter stoc_len means fewer loop iterations in the SC kernel, giving proportional speedup. For example, stoc_len=128 (int7) is ~2x faster than stoc_len=256 (int8).
```json
{
  "total_blocks": 28,
  "default_stoc_len": 256,
  "default_timewise": 0.5,
  "blocks": [
    {
      "qk": {"enabled": true, "stoc_len": 256, "timewise": 0.5, "group_stoc_lens": null},
      "av": {"enabled": true, "stoc_len": 128, "timewise": 1.0, "group_stoc_lens": null},
      "proj": {"enabled": false, "stoc_len": 256, "timewise": 0.5, "group_stoc_lens": null},
      "mlp_fc1": {"enabled": true, "stoc_len": 256, "timewise": 0.5, "group_stoc_lens": null},
      "mlp_fc2": {"enabled": false, "stoc_len": 256, "timewise": 0.5, "group_stoc_lens": null},
      "input_proj": {"enabled": false, "stoc_len": 256, "timewise": 0.5, "group_stoc_lens": null}
    }
  ]
}
```

The `blocks` array must have exactly `total_blocks` entries (28 for DiT-XL/2). Each block contains config for the six operators:
| Operator | Description |
|---|---|
| `qk` | Q @ K^T attention BMM |
| `av` | Attn @ V attention BMM |
| `proj` | Output projection linear layer |
| `mlp_fc1` | MLP first linear layer |
| `mlp_fc2` | MLP second linear layer |
| `input_proj` | QKV input projection linear layer |
Each operator has four fields:
| Field | Type | Description |
|---|---|---|
| `enabled` | bool | Whether SC is used for this operator in this block. |
| `stoc_len` | int | Stochastic stream length. Controls precision/speed tradeoff. Must be ≤ `2^sc_prec`. |
| `timewise` | float | Per-operator timewise fraction (0-1). SC is only used during the first `timewise * total_timesteps` diffusion steps. |
| `group_stoc_lens` | list or null | Per-group stoc_lens for mixed precision within a single operator. For linear ops: one per weight quantization group. For BMM ops (`qk`, `av`): one per attention head. `null` means uniform (all groups use `stoc_len`). |
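The constraints listed above (one entry per block, six operators per block, four fields per operator, `stoc_len` at most `2^sc_prec`, `timewise` in the 0-1 range) can be checked before launching a run. The checker below is a hypothetical helper written for this README, not part of the repo; applying the `stoc_len` bound to `group_stoc_lens` entries as well is an assumption.

```python
OPERATORS = ("qk", "av", "proj", "mlp_fc1", "mlp_fc2", "input_proj")
FIELDS = ("enabled", "stoc_len", "timewise", "group_stoc_lens")


def check_sc_config(config, sc_prec=8):
    """Return a list of human-readable problems found in an SC config dict."""
    problems = []
    blocks = config.get("blocks", [])
    expected = config.get("total_blocks")
    if len(blocks) != expected:
        problems.append(f"expected {expected} blocks, got {len(blocks)}")
    for b, block in enumerate(blocks):
        for op in OPERATORS:
            entry = block.get(op)
            if entry is None:
                problems.append(f"block {b}: missing operator {op}")
                continue
            missing = [f for f in FIELDS if f not in entry]
            if missing:
                problems.append(f"block {b} {op}: missing fields {missing}")
                continue
            # Assumption: per-group lengths obey the same upper bound.
            lens = [entry["stoc_len"]] + (entry["group_stoc_lens"] or [])
            if any(n > 2 ** sc_prec for n in lens):
                problems.append(f"block {b} {op}: stoc_len exceeds 2^{sc_prec}")
            if not 0.0 <= entry["timewise"] <= 1.0:
                problems.append(f"block {b} {op}: timewise outside [0, 1]")
    return problems
```

A typical use would be `check_sc_config(json.load(open(path)))` with `path` pointing at the file passed to `--sc_config`; an empty list means none of these checks failed.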
| Precision | stoc_len | Relative Speed |
|---|---|---|
| int8 | 256 | 1x (baseline) |
| int7.5 | 181 | ~1.4x |
| int7 | 128 | ~2x |
| int6 | 64 | ~4x |
| int5 | 32 | ~8x |
| int4 | 16 | ~16x |
Non-power-of-2 values (e.g., 181 for int7.5) are supported. The quantization grid is determined by sc_prec = ceil(log2(stoc_len)).
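The table rows follow from two simple relations, sketched below: the quantization grid `sc_prec = ceil(log2(stoc_len))` and a speedup that is assumed to be inversely proportional to `stoc_len`. The function names are hypothetical, and the proportional model is only the approximation the table itself uses, not a measured benchmark.

```python
def sc_prec_for(stoc_len):
    """ceil(log2(stoc_len)) using integer arithmetic (no float rounding)."""
    return (stoc_len - 1).bit_length()


def relative_speed(stoc_len, baseline=256):
    """Estimated speedup vs. the baseline stream length (proportional model)."""
    return baseline / stoc_len


for n in (256, 181, 128, 64, 32, 16):
    print(f"stoc_len={n:3d}  sc_prec={sc_prec_for(n)}  ~{relative_speed(n):.1f}x")
```

Note that `stoc_len=181` still maps to `sc_prec=8`, so it uses the same quantization grid as `stoc_len=256` while running fewer loop iterations.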
QK at int8, AV at int7 (128), all blocks, all timesteps:

```json
{
  "total_blocks": 28,
  "default_stoc_len": 256,
  "default_timewise": 1.0,
  "blocks": [
    {
      "qk": {"enabled": true, "stoc_len": 256, "timewise": 1.0, "group_stoc_lens": null},
      "av": {"enabled": true, "stoc_len": 128, "timewise": 1.0, "group_stoc_lens": null},
      "proj": {"enabled": false, "stoc_len": 256, "timewise": 1.0, "group_stoc_lens": null},
      "mlp_fc1": {"enabled": false, "stoc_len": 256, "timewise": 1.0, "group_stoc_lens": null},
      "mlp_fc2": {"enabled": false, "stoc_len": 256, "timewise": 1.0, "group_stoc_lens": null},
      "input_proj": {"enabled": false, "stoc_len": 256, "timewise": 1.0, "group_stoc_lens": null}
    }
  ]
}
```

(Repeat the block entry 28 times, or vary per block.)
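Writing 28 identical entries by hand is error-prone, so a short script can generate the file instead. The sketch below builds the "QK at int8, AV at int7" config using only the field names shown in the JSON above; the function names and the output file name in the usage note are placeholders, not part of the repo.

```python
import copy
import json

OPERATORS = ("qk", "av", "proj", "mlp_fc1", "mlp_fc2", "input_proj")
OFF = {"enabled": False, "stoc_len": 256, "timewise": 1.0, "group_stoc_lens": None}


def make_block():
    """One block entry: QK at stoc_len 256, AV at 128, everything else off."""
    block = {op: copy.deepcopy(OFF) for op in OPERATORS}
    block["qk"]["enabled"] = True
    block["av"].update(enabled=True, stoc_len=128)
    return block


def make_config(total_blocks=28):
    """Full config dict with one independent entry per block."""
    return {
        "total_blocks": total_blocks,
        "default_stoc_len": 256,
        "default_timewise": 1.0,
        "blocks": [make_block() for _ in range(total_blocks)],
    }


def write_config(path, total_blocks=28):
    """Serialize the config; Python None becomes JSON null."""
    with open(path, "w") as f:
        json.dump(make_config(total_blocks), f, indent=2)
```

Calling `write_config("qk8_av7.json")` would produce a file that can be passed to `--sc_config`. Each block is a separate dict, so individual blocks can be edited afterwards without affecting the others.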
Independent fc1/fc2 (SC on fc1 only, fc2 uses FP):

```json
{
  "total_blocks": 28,
  "default_stoc_len": 256,
  "default_timewise": 0.5,
  "blocks": [
    {
      "qk": {"enabled": true, "stoc_len": 256, "timewise": 0.5, "group_stoc_lens": null},
      "av": {"enabled": false, "stoc_len": 256, "timewise": 0.5, "group_stoc_lens": null},
      "proj": {"enabled": false, "stoc_len": 256, "timewise": 0.5, "group_stoc_lens": null},
      "mlp_fc1": {"enabled": true, "stoc_len": 256, "timewise": 0.5, "group_stoc_lens": null},
      "mlp_fc2": {"enabled": false, "stoc_len": 256, "timewise": 0.5, "group_stoc_lens": null},
      "input_proj": {"enabled": false, "stoc_len": 256, "timewise": 0.5, "group_stoc_lens": null}
    }
  ]
}
```

Per-head mixed precision for AV (heads 0-7 at int8, heads 8-15 at int6):
```json
{
  "av": {
    "enabled": true,
    "stoc_len": 256,
    "timewise": 1.0,
    "group_stoc_lens": [256, 256, 256, 256, 256, 256, 256, 256, 64, 64, 64, 64, 64, 64, 64, 64]
  }
}
```

Pass the config file via `--sc_config`:

```bash
python scripts/quant_sc_main.py \
    --wbits 8 --abits 8 --w_sym --a_sym \
    --timewise 1.0 --qklayerwise 1.0 --avlayerwise 1.0 \
    --sc_prec 8 --sc_enable \
    --sc_config path/to/config.json \
    --image-size 256 --num-sampling-steps 100 --batch-size 16 \
    --ckpt pretrained_models/DiT-XL-2-256x256.pt
```

When `--sc_config` is provided, the JSON config overrides the layerwise CLI args for block/operator selection and precision. The `--timewise` and `--*layerwise` args are still required for backward compatibility, but they are ignored in favor of the JSON config's per-operator `timewise` and `enabled` fields.
Configs can also be built and inspected from Python with `SCPrecisionMap`:

```python
from qdit.sc_integration import SCPrecisionMap

# Create a config: QK@int8 in all blocks, AV@int7 in first 14 blocks
pm = SCPrecisionMap(total_blocks=28, default_stoc_len=256, default_timewise=1.0)
pm.enable_operator_fraction("qk", fraction=1.0, stoc_len=256, timewise=1.0)
pm.enable_operator_fraction("av", fraction=0.5, stoc_len=128, timewise=1.0)

# Set per-head precision for AV in block 0
pm.set_group_stoc_lens(0, "av", [256]*8 + [64]*8)

# Save to JSON
pm.to_json("my_config.json")

# Load and inspect
pm2 = SCPrecisionMap.from_json("my_config.json")
print(pm2.summary())
# SCPrecisionMap(total_blocks=28, default_stoc_len=256)
#   qk: enabled in 28/28 blocks, stoc_lens=[256], mixed_groups=no
#   av: enabled in 14/28 blocks, stoc_lens=[256, 64, 128], mixed_groups=yes
#   ...

# Compare total compute budget vs uniform int8
print(f"Total stoc budget: {pm2.total_stoc_budget()}")
```

Results are saved to `../results/<NNN>-qdit_sc_<params>/`:

- `sample_sc.png`: generated image grid
- `log.txt`: run log
- `debug_sc_mlp.csv`: MLP SC debug stats (if debug mode enabled)
- Do NOT combine `--static` with `--a_sym`: the static quantizer only supports asymmetric activation quantization. Use `--a_sym` without `--static`, or `--static` without `--a_sym`.
- DiT-XL/2 has 28 blocks and 16 attention heads with head_dim=72.
- The `--sc_enable` flag selects the compact enable-signal kernel path, which is required for mixed-precision early termination.
```bibtex
@misc{chen2024QDiT,
      title={Q-DiT: Accurate Post-Training Quantization for Diffusion Transformers},
      author={Lei Chen and Yuan Meng and Chen Tang and Xinzhu Ma and Jingyan Jiang and Xin Wang and Zhi Wang and Wenwu Zhu},
      year={2024},
      eprint={2406.17343},
      archivePrefix={arXiv},
      primaryClass={cs.CV},
      url={https://arxiv.org/abs/2406.17343},
}
```

This codebase borrows from GPTQ, Atom and ADM. Thanks to the authors for releasing their codebases!