CrucibleComputingGroup/scmp_diffusion

Q-DiT: Accurate Post-Training Quantization for Diffusion Transformers

Setup

First, download and set up the repo:

git clone https://github.com/Juanerx/Q-DiT.git
cd Q-DiT

Then create the environment and install required packages:

conda create -n qdit python=3.8
conda activate qdit
pip install -r requirements.txt
pip install .

Usage

Standard Quantization

To use GPTQ or static quantization, first generate the calibration data:

cd scripts
python collect_cali_data.py

Then quantize the model:

bash quant_main.sh --image-size 256 --num-sampling-steps 50 --cfg-scale 1.5 --use_gptq

Stochastic Computing (SC) Mode

SC mode replaces selected floating-point matrix multiplications with stochastic computing (SC), trading precision for hardware efficiency. It is run via scripts/quant_sc_main.py.

Basic Usage (Legacy CLI Args)

python scripts/quant_sc_main.py \
    --wbits 8 --abits 8 --w_sym --a_sym \
    --timewise 0.5 --qklayerwise 1.0 --avlayerwise 0.2 \
    --sc_prec 8 --sc_enable \
    --image-size 256 --num-sampling-steps 100 --batch-size 16 \
    --ckpt pretrained_models/DiT-XL-2-256x256.pt

SC CLI Arguments

Argument Type Default Description
--timewise float 0.0 Fraction of timesteps to use SC (0-1). SC is applied to the first timewise * total_timesteps diffusion steps (the noisiest ones).
--qklayerwise float 0.0 Fraction of blocks to use SC for Q@K^T attention (0-1).
--avlayerwise float 0.0 Fraction of blocks to use SC for Attn@V attention (0-1).
--projlayerwise float 0.0 Fraction of blocks to use SC for output projection (0-1).
--mlplayerwise float 0.0 Fraction of blocks to use SC for MLP fc1 and fc2 (0-1).
--inputprojlayerwise float 0.0 Fraction of blocks to use SC for QKV input projection (0-1).
--sc_prec int 8 SC precision in bits. Sets stoc_len = 2^sc_prec (e.g., 8 → stoc_len=256).
--sc_enable flag false Use enable-signal SC multiplication (compact kernel) instead of standard XNOR/AND.
--sc_noise_model flag false Replace real SC kernels with a fast analytical noise surrogate. See SC Noise Model below.
--sc_noise_local_correction float 0.15 Variance correction for per-row/per-batch scaled matmuls (MLP, AV, QK, input proj). Recommended: 0.10 - 0.15.
--sc_noise_global_correction float 0.60 Variance correction for per-tensor scaled matmuls (output proj). Recommended: 0.10 - 0.20.
--reverse_layerwise flag false Apply SC to the last N blocks instead of the first N.
--sc_skip_blocks str "" Comma-separated block indices to skip SC. E.g., "0,27".
--debug_sc flag false Debug mode: run both FP and SC, log per-operator error, use FP result.
--sc_config str None Path to JSON config for per-block/per-operator/per-group precision (overrides layerwise args).
--save_sc_config str None Save the effective SC config to a JSON file after setup.
--mp flag false Enable dynamic mixed precision for QK (per head) and AV (per token row). Assigns different stoc_len levels based on runtime importance (see Dynamic Mixed Precision).
--mp_levels str 256,128,64,32 Comma-separated stoc_len levels for MP, sorted descending.
--mp_fractions str None Comma-separated fractions of rows per MP level (must sum to 1). Default: equal fractions (1/N each).
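The fractional arguments above can be read as simple selection rules. The sketch below shows one plausible interpretation. The helper names (`select_sc_blocks`, `sc_active_at_step`) are hypothetical, and the rounding rule (truncating `fraction * total` with `int`) is an assumption. Check `scripts/quant_sc_main.py` for the exact behaviour.

```python
def select_sc_blocks(total_blocks, fraction, reverse=False, skip=""):
    """Return the block indices that use SC for one operator (hypothetical helper).

    Assumes N = int(fraction * total_blocks). It picks the first N blocks, or the
    last N when reverse=True (mirroring --reverse_layerwise). Indices listed in
    `skip` are then removed (mirroring --sc_skip_blocks, e.g. "0,27").
    """
    n = int(fraction * total_blocks)
    blocks = range(total_blocks - n, total_blocks) if reverse else range(n)
    skipped = {int(s) for s in skip.split(",") if s.strip()}
    return [b for b in blocks if b not in skipped]


def sc_active_at_step(step_idx, total_steps, timewise):
    """True if SC is used at sampling step `step_idx` (0 = first, noisiest step)."""
    return step_idx < int(timewise * total_steps)


# --mlplayerwise 0.8 --reverse_layerwise --sc_skip_blocks "27" on DiT-XL/2:
print(select_sc_blocks(28, 0.8, reverse=True, skip="27"))
```

Under this reading, `--timewise 0.5 --num-sampling-steps 100` would run SC on sampling steps 0 through 49 and full precision on the rest.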

Examples

SC on Q@K^T only, 50% of timesteps, all blocks:

python scripts/quant_sc_main.py \
    --wbits 8 --abits 8 --w_sym --a_sym \
    --timewise 0.5 --qklayerwise 1.0 \
    --sc_prec 8 --sc_enable \
    --image-size 256 --num-sampling-steps 100 --batch-size 16 \
    --ckpt pretrained_models/DiT-XL-2-256x256.pt

SC on Q@K^T + Attn@V + MLP, all timesteps, 80% of blocks:

python scripts/quant_sc_main.py \
    --wbits 8 --abits 8 --w_sym --a_sym \
    --timewise 1.0 --qklayerwise 0.8 --avlayerwise 0.8 --mlplayerwise 0.8 \
    --sc_prec 8 --sc_enable \
    --image-size 256 --num-sampling-steps 100 --batch-size 16 \
    --ckpt pretrained_models/DiT-XL-2-256x256.pt

Save config for reuse:

python scripts/quant_sc_main.py \
    --wbits 8 --abits 8 --w_sym --a_sym \
    --timewise 0.5 --qklayerwise 1.0 --avlayerwise 0.5 \
    --sc_prec 8 --sc_enable \
    --save_sc_config my_config.json \
    --image-size 256 --num-sampling-steps 100 --batch-size 16 \
    --ckpt pretrained_models/DiT-XL-2-256x256.pt

Dynamic Mixed Precision

The --mp flag enables adaptive precision assignment for QK and AV matmuls, integrated directly into the existing _sc_qk/_sc_av paths. Granularity matches the quantization granularity of each operator:

  • QK (per-head): Heads are ranked by ||Q||_inf across tokens. Heads with larger Q magnitude get higher stoc_len. Uses the existing batched kernel path — no speed regression.
  • AV (per-token-row): Within each (batch, head), rows are ranked by max(attn_row). Rows with more concentrated attention get higher stoc_len.

Heads (QK) or rows (AV) are bucketed into N discrete stoc_len levels, one per --mp_levels entry. The quantile boundaries come from --mp_fractions. The assignment is recomputed on every forward pass.
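The ranking and quantile bucketing described above can be sketched in pure Python. This is an illustration of the stated rules, not the repository's `_sc_qk`/`_sc_av` code. The function names are hypothetical, and the exact boundary rounding (`round(cumulative_fraction * n)`) is an assumption.

```python
def head_importance(q_head):
    """QK score for one head: ||Q||_inf, the max |q| over all tokens and channels."""
    return max(abs(x) for token in q_head for x in token)


def row_importance(attn_row):
    """AV score for one attention row: its largest attention weight."""
    return max(attn_row)


def assign_stoc_lens(scores, levels, fractions=None):
    """Bucket items (heads or rows) into stoc_len levels by importance rank.

    levels must be sorted descending (like --mp_levels). fractions gives the
    share of items per level (like --mp_fractions) and defaults to 1/N each.
    """
    if list(levels) != sorted(levels, reverse=True):
        raise ValueError("levels must be sorted in descending order")
    k = len(levels)
    fractions = fractions or [1.0 / k] * k
    if len(fractions) != k or abs(sum(fractions) - 1.0) > 1e-6:
        raise ValueError("need one fraction per level, summing to 1")
    # Most important first; Python's sort keeps ties in original index order.
    order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    out = [levels[-1]] * len(scores)
    cum, start = 0.0, 0
    for level, frac in zip(levels, fractions):
        cum += frac
        end = round(cum * len(scores))  # quantile boundary in rank order
        for idx in order[start:end]:
            out[idx] = level
        start = end
    return out


# 4 heads, 4 equal levels: the highest-scoring head gets stoc_len 256.
print(assign_stoc_lens([0.9, 0.1, 0.5, 0.7], [256, 128, 64, 32]))
```

In the real path the scores would come from `head_importance` (QK) or `row_importance` (AV) on the current tensors. That is why the assignment has to be recomputed on each forward pass.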

Quick test (4 levels, equal fractions):

python scripts/quant_sc_main.py --wbits 8 --abits 8 --w_sym --a_sym --timewise 0.5 --qklayerwise 0.5 --avlayerwise 0.5 --sc_prec 8 --sc_enable --mp --mp_levels 256,128,64,32

Custom fractions (aggressive — only 10% at full precision):

python scripts/quant_sc_main.py \
    --wbits 8 --abits 8 --w_sym --a_sym \
    --timewise 0.5 --qklayerwise 0.5 --avlayerwise 0.5 \
    --sc_prec 8 --sc_enable \
    --mp --mp_levels 256,128,64,32 --mp_fractions 0.1,0.2,0.3,0.4

2 levels (simple high/low split):

python scripts/quant_sc_main.py \
    --wbits 8 --abits 8 --w_sym --a_sym \
    --timewise 0.5 --qklayerwise 0.5 --avlayerwise 0.5 \
    --sc_prec 8 --sc_enable \
    --mp --mp_levels 256,64 --mp_fractions 0.3,0.7

Note: --mp only activates for blocks/timesteps where QK/AV SC is already enabled via --qklayerwise/--avlayerwise/--timewise. It is handled inside the existing _sc_qk/_sc_av methods — no separate code path.

SC Noise Model (Fast Surrogate)

The --sc_noise_model flag replaces the real SC Triton kernels with an analytical Gaussian noise surrogate. This is useful for large-scale evaluation runs (e.g. 50k FID samples) where the full bitstream SC simulation is too slow.

How it works

Instead of generating bitstreams, AND/XNOR-gating them, and counting bits, the surrogate computes:

out[i,j] = (a @ b^T)[i,j]  +  noise[i,j]

where noise[i,j] ~ N(0, sigma^2) with:

sigma = sqrt(D / L^2 * correction) * a_scale[i] * b_scale[j]
  • D = reduction dimension (e.g. 1152 for DiT-XL hidden, 72 for head_dim)
  • L = stoc_len (from --sc_prec or per-row MP assignment)
  • correction = --sc_noise_local_correction or --sc_noise_global_correction
  • a_scale, b_scale = per-row or per-tensor max of inputs (matches real SC kernel normalization)

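The 1/L^2 variance scaling (rather than 1/L) reflects the Sobol low-discrepancy RNG used by the real SC kernels. With that RNG, each product term has an error of O(1/L) rather than the O(1/sqrt(L)) of i.i.d. Bernoulli streams. Treating the D per-product errors as independent, their variances add up to D/L^2, before the correction and scale factors are applied.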
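The surrogate formula can be sketched in pure Python as follows. This is a readability sketch, not the repository's torch.compile implementation, and the helper names are hypothetical. It also assumes that "per-row or per-tensor max" means the maximum absolute value of each row (local correction) or of the whole tensor (global correction).

```python
import math
import random


def noise_sigma(d, stoc_len, correction, a_scale, b_scale):
    """sigma = sqrt(D / L^2 * correction) * a_scale * b_scale."""
    return math.sqrt(d / stoc_len**2 * correction) * a_scale * b_scale


def sc_noise_matmul(a, b, stoc_len, correction, per_row=True, seed=0):
    """Return (a @ b^T) + Gaussian noise, where a is M x D and b is N x D lists."""
    rng = random.Random(seed)
    d = len(a[0])  # reduction dimension D
    if per_row:  # local scaling: one scale per row of a and per row of b
        a_scale = [max(abs(x) for x in row) for row in a]
        b_scale = [max(abs(x) for x in row) for row in b]
    else:  # global scaling: one scale for each whole tensor
        a_max = max(abs(x) for row in a for x in row)
        b_max = max(abs(x) for row in b for x in row)
        a_scale, b_scale = [a_max] * len(a), [b_max] * len(b)
    out = []
    for i, row_a in enumerate(a):
        out_row = []
        for j, row_b in enumerate(b):
            exact = sum(x * y for x, y in zip(row_a, row_b))
            sigma = noise_sigma(d, stoc_len, correction, a_scale[i], b_scale[j])
            out_row.append(exact + rng.gauss(0.0, sigma))
        out.append(out_row)
    return out
```

For a QK matmul (D = 72) at stoc_len = 256 with the default local correction of 0.15, sigma is about 0.0128 * a_scale[i] * b_scale[j]. Halving stoc_len doubles sigma, and a correction of 0.0 returns the exact product.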

Basic usage

# Uniform SC noise on all operators, using defaults (local=0.15, global=0.60)
python scripts/quant_sc_main.py \
    --wbits 8 --abits 8 --w_sym --a_sym \
    --timewise 1 --qklayerwise 1.0 --avlayerwise 1.0 \
    --projlayerwise 1.0 --mlplayerwise 1.0 --inputprojlayerwise 1.0 \
    --sc_prec 8 --sc_enable --sc_noise_model \
    --image-size 256 --num-sampling-steps 100 --batch-size 16
# With adaptive MP (same as real SC, but fast)
python scripts/quant_sc_main.py \
    --wbits 8 --abits 8 --w_sym --a_sym \
    --timewise 1 --qklayerwise 1.0 --avlayerwise 1.0 \
    --projlayerwise 1.0 --mlplayerwise 1.0 --inputprojlayerwise 1.0 \
    --sc_prec 8 --sc_enable --sc_noise_model \
    --adaptive_mp --mp_levels 256,128,64,32,16 \
    --mp_alpha 0.3 --mp_beta 0.1 \
    --image-size 256 --num-sampling-steps 100 --batch-size 16
# Custom noise correction (tune visually against real SC output)
python scripts/quant_sc_main.py \
    ... --sc_noise_model \
    --sc_noise_local_correction 0.12 \
    --sc_noise_global_correction 0.15

Tuning the correction factors

Parameter Affects Recommended
--sc_noise_local_correction MLP fc1/fc2, AV, QK, input proj (~85% of matmuls) 0.10 - 0.15
--sc_noise_global_correction Output proj (~15% of matmuls) 0.10 - 0.20
  • Larger values = more noise = images look more like real SC with short stoc_len
  • Smaller values = less noise = cleaner images, closer to FP baseline
  • 0.0 = no noise at all (pure quantized matmul, useful for debugging)
  • To match real SC visually, start at local=0.15, global=0.15 and compare side-by-side

Behavior

  • Orthogonal to --sc_enable (both can be on simultaneously)
  • All SC dispatch paths are supported: uniform, adaptive MP, range MP, per-head mixed
  • The real SC path is completely untouched when --sc_noise_model is not set
  • Uses torch.compile for kernel fusion. The first ~5 iterations are slow (JIT warmup), after which throughput reaches a steady ~12 it/s at batch size 8 on an RTX PRO 6000

Mixed-Precision SC via JSON Config

The --sc_config flag enables fine-grained control over SC precision per block, per operator, and per group. This allows mixed-precision SC where different operators use different stochastic stream lengths (stoc_len).

Key idea: early termination. A shorter stoc_len means fewer loop iterations in the SC kernel, giving a roughly proportional speedup. For example, stoc_len=128 (int7) is ~2x faster than stoc_len=256 (int8).
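To see why stream length sets the loop count, here is a generic unipolar AND-gate SC multiplier with a low-discrepancy bitstream generator. It is an illustration only and differs from the real kernels in two ways:

  • It uses a 2-D Halton sequence (van der Corput in bases 2 and 3) rather than the Sobol generator.
  • It uses plain AND gating rather than the enable-signal kernel selected by --sc_enable.

The function names are hypothetical.

```python
def van_der_corput(k, base):
    """Radical inverse of k in `base`: a 1-D low-discrepancy sequence in [0, 1)."""
    q, denom = 0.0, 1.0
    while k:
        k, digit = divmod(k, base)
        denom *= base
        q += digit / denom
    return q


def sc_multiply(a, b, stoc_len):
    """Unipolar AND-gate SC estimate of a * b for a, b in [0, 1].

    Stream A uses base 2 and stream B uses base 3 (a 2-D Halton sequence), so
    the two bitstreams are decorrelated. The loop runs exactly stoc_len times,
    so terminating the stream early trades accuracy for proportionally less work.
    """
    ones = 0
    for k in range(stoc_len):
        bit_a = van_der_corput(k, 2) < a
        bit_b = van_der_corput(k, 3) < b
        ones += bit_a and bit_b  # AND gate, then count the 1s
    return ones / stoc_len


for L in (256, 64, 16):
    print(L, sc_multiply(0.75, 0.4, L))  # exact product is 0.3
```

With i.i.d. random bits the estimation error shrinks like 1/sqrt(stoc_len). A low-discrepancy source shrinks it roughly like 1/stoc_len (up to log factors for Halton), which is the behaviour the noise model assumes.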

JSON Config Format

{
  "total_blocks": 28,
  "default_stoc_len": 256,
  "default_timewise": 0.5,
  "blocks": [
    {
      "qk":        {"enabled": true,  "stoc_len": 256, "timewise": 0.5, "group_stoc_lens": null},
      "av":        {"enabled": true,  "stoc_len": 128, "timewise": 1.0, "group_stoc_lens": null},
      "proj":      {"enabled": false, "stoc_len": 256, "timewise": 0.5, "group_stoc_lens": null},
      "mlp_fc1":   {"enabled": true,  "stoc_len": 256, "timewise": 0.5, "group_stoc_lens": null},
      "mlp_fc2":   {"enabled": false, "stoc_len": 256, "timewise": 0.5, "group_stoc_lens": null},
      "input_proj": {"enabled": false, "stoc_len": 256, "timewise": 0.5, "group_stoc_lens": null}
    }
  ]
}

The blocks array must have exactly total_blocks entries (28 for DiT-XL/2). The examples in this section show a single block entry for brevity. Each block contains a config for each of the six operators:

Operator Description
qk Q @ K^T attention BMM
av Attn @ V attention BMM
proj Output projection linear layer
mlp_fc1 MLP first linear layer
mlp_fc2 MLP second linear layer
input_proj QKV input projection linear layer

Each operator has four fields:

Field Type Description
enabled bool Whether SC is used for this operator in this block.
stoc_len int Stochastic stream length. Controls precision/speed tradeoff. Must be ≤ 2^sc_prec.
timewise float Per-operator timewise fraction (0-1). SC is only used during the first timewise * total_timesteps diffusion steps.
group_stoc_lens list or null Per-group stoc_lens for mixed precision within a single operator. For linear ops: one per weight quantization group. For BMM ops (qk, av): one per attention head. null means uniform (all groups use stoc_len).
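A hand-edited config is easy to get subtly wrong, so a quick pre-flight check can help. The sketch below mirrors the rules stated above:

  • the block count matches total_blocks
  • all six operators are present
  • stoc_len ≤ 2^sc_prec
  • timewise is in [0, 1]
  • qk and av have one group entry per head

The head count of 16 is specific to DiT-XL/2. Group counts for linear operators are not checked because they depend on the weight quantization group size. The function names are hypothetical, and the repository may perform its own validation when loading --sc_config.

```python
import json

OPERATORS = ("qk", "av", "proj", "mlp_fc1", "mlp_fc2", "input_proj")
NUM_HEADS = 16  # DiT-XL/2; BMM group_stoc_lens need one entry per head


def validate_sc_config(cfg, sc_prec=8):
    """Return a list of human-readable problems (empty if the config looks valid)."""
    problems = []
    blocks = cfg.get("blocks", [])
    if len(blocks) != cfg.get("total_blocks"):
        problems.append(f"expected {cfg.get('total_blocks')} blocks, got {len(blocks)}")
    max_len = 2 ** sc_prec
    for b, block in enumerate(blocks):
        for op in OPERATORS:
            spec = block.get(op)
            if spec is None:
                problems.append(f"block {b}: missing operator {op!r}")
                continue
            groups = spec.get("group_stoc_lens")
            stoc_lens = [spec["stoc_len"]] + list(groups or [])
            if max(stoc_lens) > max_len:
                problems.append(f"block {b} {op}: stoc_len exceeds 2^sc_prec = {max_len}")
            if not 0.0 <= spec["timewise"] <= 1.0:
                problems.append(f"block {b} {op}: timewise must be in [0, 1]")
            if groups is not None and op in ("qk", "av") and len(groups) != NUM_HEADS:
                problems.append(f"block {b} {op}: expected {NUM_HEADS} group_stoc_lens")
    return problems


def validate_sc_config_file(path, sc_prec=8):
    """Load a JSON config from disk and validate it."""
    with open(path) as f:
        return validate_sc_config(json.load(f), sc_prec)
```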

stoc_len Reference

Precision stoc_len Relative Speed
int8 256 1x (baseline)
int7.5 181 ~1.4x
int7 128 ~2x
int6 64 ~4x
int5 32 ~8x
int4 16 ~16x

Non-power-of-2 values (e.g., 181 for int7.5) are supported. The quantization grid is determined by sc_prec = ceil(log2(stoc_len)).
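The table follows directly from stoc_len = 2^bits:

  • the "effective bits" are log2(stoc_len)
  • the grid precision is ceil(log2(stoc_len))
  • the speed column assumes run time proportional to stoc_len, as in the early-termination argument above

A small helper (hypothetical, not part of the repository) reproduces the columns:

```python
import math

BASELINE_STOC_LEN = 256  # int8


def stoc_len_info(stoc_len):
    """Effective bits, quantization-grid bits, and estimated speedup vs. int8."""
    return {
        "effective_bits": math.log2(stoc_len),        # 181 -> ~7.5 ("int7.5")
        "sc_prec": math.ceil(math.log2(stoc_len)),    # grid used for quantization
        "speedup": BASELINE_STOC_LEN / stoc_len,      # assumes cost proportional to stoc_len
    }


for L in (256, 181, 128, 64, 32, 16):
    info = stoc_len_info(L)
    print(f"stoc_len={L:3d}  bits={info['effective_bits']:.2f}  "
          f"sc_prec={info['sc_prec']}  speedup~{info['speedup']:.1f}x")
```

Note that stoc_len=181 still uses the 8-bit quantization grid (sc_prec = 8), even though its stream is ~1.4x shorter than int8's.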

Mixed-Precision Examples

QK at int8, AV at int7 (128), all blocks, all timesteps:

{
  "total_blocks": 28,
  "default_stoc_len": 256,
  "default_timewise": 1.0,
  "blocks": [
    {
      "qk":        {"enabled": true,  "stoc_len": 256, "timewise": 1.0, "group_stoc_lens": null},
      "av":        {"enabled": true,  "stoc_len": 128, "timewise": 1.0, "group_stoc_lens": null},
      "proj":      {"enabled": false, "stoc_len": 256, "timewise": 1.0, "group_stoc_lens": null},
      "mlp_fc1":   {"enabled": false, "stoc_len": 256, "timewise": 1.0, "group_stoc_lens": null},
      "mlp_fc2":   {"enabled": false, "stoc_len": 256, "timewise": 1.0, "group_stoc_lens": null},
      "input_proj": {"enabled": false, "stoc_len": 256, "timewise": 1.0, "group_stoc_lens": null}
    }
  ]
}

(Repeat the block entry 28 times, or vary per block.)
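Rather than pasting the block entry 28 times by hand, a short standard-library script can expand one template. This is a sketch with hypothetical helper names. The repository's own SCPrecisionMap (see Generating Configs Programmatically) remains the supported route.

```python
import copy
import json

OPERATORS = ("qk", "av", "proj", "mlp_fc1", "mlp_fc2", "input_proj")


def op_spec(enabled, stoc_len=256, timewise=1.0):
    """One operator entry in the JSON format shown above."""
    return {"enabled": enabled, "stoc_len": stoc_len, "timewise": timewise,
            "group_stoc_lens": None}


def make_uniform_config(block_template, total_blocks=28, default_stoc_len=256,
                        default_timewise=1.0):
    """Build a full config by repeating one block entry total_blocks times."""
    missing = [op for op in OPERATORS if op not in block_template]
    if missing:
        raise ValueError(f"block template is missing operators: {missing}")
    return {
        "total_blocks": total_blocks,
        "default_stoc_len": default_stoc_len,
        "default_timewise": default_timewise,
        # deepcopy so that editing one block later does not change every block
        "blocks": [copy.deepcopy(block_template) for _ in range(total_blocks)],
    }


def save_config(cfg, path):
    with open(path, "w") as f:
        json.dump(cfg, f, indent=2)


# QK at int8, AV at int7 (128) in every block, as in the example above.
template = {"qk": op_spec(True, 256), "av": op_spec(True, 128), "proj": op_spec(False),
            "mlp_fc1": op_spec(False), "mlp_fc2": op_spec(False), "input_proj": op_spec(False)}
cfg = make_uniform_config(template)
cfg["blocks"][27]["av"]["enabled"] = False  # vary one block without touching the others
```

Calling save_config(cfg, "qk8_av7.json") writes a file that can be passed to --sc_config.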

Independent fc1/fc2 — SC on fc1 only, fc2 uses FP:

{
  "total_blocks": 28,
  "default_stoc_len": 256,
  "default_timewise": 0.5,
  "blocks": [
    {
      "qk":        {"enabled": true,  "stoc_len": 256, "timewise": 0.5, "group_stoc_lens": null},
      "av":        {"enabled": false, "stoc_len": 256, "timewise": 0.5, "group_stoc_lens": null},
      "proj":      {"enabled": false, "stoc_len": 256, "timewise": 0.5, "group_stoc_lens": null},
      "mlp_fc1":   {"enabled": true,  "stoc_len": 256, "timewise": 0.5, "group_stoc_lens": null},
      "mlp_fc2":   {"enabled": false, "stoc_len": 256, "timewise": 0.5, "group_stoc_lens": null},
      "input_proj": {"enabled": false, "stoc_len": 256, "timewise": 0.5, "group_stoc_lens": null}
    }
  ]
}

Per-head mixed precision for AV (heads 0-7 at int8, heads 8-15 at int6). This is a single-operator fragment to place inside a block entry:

{
  "av": {
    "enabled": true,
    "stoc_len": 256,
    "timewise": 1.0,
    "group_stoc_lens": [256, 256, 256, 256, 256, 256, 256, 256, 64, 64, 64, 64, 64, 64, 64, 64]
  }
}
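A back-of-the-envelope estimate of what this per-head split saves can be computed directly. It assumes every head does the same amount of work and that SC cost scales linearly with stoc_len, as in the early-termination model above. The helper is hypothetical and is not the repository's total_stoc_budget().

```python
def relative_cost(group_stoc_lens, baseline=256):
    """Mean stoc_len over equally sized groups, relative to a uniform baseline."""
    return sum(group_stoc_lens) / (len(group_stoc_lens) * baseline)


av_heads = [256] * 8 + [64] * 8
print(relative_cost(av_heads))      # 0.625
print(1 / relative_cost(av_heads))  # 1.6
```

Under these assumptions, the AV operator in this block runs at 62.5% of the uniform int8 cost, an estimated ~1.6x speedup for that operator alone.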

Running with a JSON Config

python scripts/quant_sc_main.py \
    --wbits 8 --abits 8 --w_sym --a_sym \
    --timewise 1.0 --qklayerwise 1.0 --avlayerwise 1.0 \
    --sc_prec 8 --sc_enable \
    --sc_config path/to/config.json \
    --image-size 256 --num-sampling-steps 100 --batch-size 16 \
    --ckpt pretrained_models/DiT-XL-2-256x256.pt

When --sc_config is provided, the JSON config overrides the layerwise CLI args for block/operator selection and precision. The --timewise and --*layerwise args are still required for backward compatibility but are ignored in favor of the JSON config's per-operator timewise and enabled fields.
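The precedence rule can be expressed as a small lookup. With a JSON config, only the operator's own enabled and timewise fields decide whether SC runs at a given block and step. This is a hypothetical sketch: the function name is invented, and truncating timewise * total_steps with int is an assumption.

```python
def effective_stoc_len(cfg, block, op, step_idx, total_steps):
    """Return the stoc_len (or per-group list) used at this block/op/step, or None for FP.

    Only the JSON config's per-operator "enabled" and "timewise" fields are
    consulted; CLI --*layerwise/--timewise values play no part here.
    """
    spec = cfg["blocks"][block][op]
    if not spec["enabled"]:
        return None
    if step_idx >= int(spec["timewise"] * total_steps):  # assumed truncation
        return None
    return spec["group_stoc_lens"] or spec["stoc_len"]
```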

Generating Configs Programmatically

from qdit.sc_integration import SCPrecisionMap

# Create a config: QK@int8 in all blocks, AV@int7 in first 14 blocks
pm = SCPrecisionMap(total_blocks=28, default_stoc_len=256, default_timewise=1.0)
pm.enable_operator_fraction("qk", fraction=1.0, stoc_len=256, timewise=1.0)
pm.enable_operator_fraction("av", fraction=0.5, stoc_len=128, timewise=1.0)

# Set per-head precision for AV in block 0
pm.set_group_stoc_lens(0, "av", [256]*8 + [64]*8)

# Save to JSON
pm.to_json("my_config.json")

# Load and inspect
pm2 = SCPrecisionMap.from_json("my_config.json")
print(pm2.summary())
# SCPrecisionMap(total_blocks=28, default_stoc_len=256)
#   qk: enabled in 28/28 blocks, stoc_lens=[256], mixed_groups=no
#   av: enabled in 14/28 blocks, stoc_lens=[256, 64, 128], mixed_groups=yes
#   ...

# Compare total compute budget vs uniform int8
print(f"Total stoc budget: {pm2.total_stoc_budget()}")

Output

Results are saved to ../results/<NNN>-qdit_sc_<params>/:

  • sample_sc.png — generated image grid
  • log.txt — run log
  • debug_sc_mlp.csv — MLP SC debug stats (if debug mode enabled)

Notes

  • Do NOT combine --static with --a_sym — the static quantizer only supports asymmetric activation quantization. Use --a_sym without --static, or --static without --a_sym.
  • DiT-XL/2 has 28 blocks and 16 attention heads with head_dim=72.
  • The --sc_enable flag selects the compact enable-signal kernel path, which is required for mixed-precision early termination.

BibTeX

@misc{chen2024QDiT,
      title={Q-DiT: Accurate Post-Training Quantization for Diffusion Transformers},
      author={Lei Chen and Yuan Meng and Chen Tang and Xinzhu Ma and Jingyan Jiang and Xin Wang and Zhi Wang and Wenwu Zhu},
      year={2024},
      eprint={2406.17343},
      archivePrefix={arXiv},
      primaryClass={cs.CV},
      url={https://arxiv.org/abs/2406.17343},
}

Acknowledgments

This codebase borrows from GPTQ, Atom and ADM. Thanks to the authors for releasing their codebases!
