quant: add SmoothQuant pre-quantization transform #15
Pull request overview
Adds a SmoothQuant-style pre-quantization diagonal rescaling to improve int-quant conditioning for SC matmul inputs, while preserving the mathematical result of a @ b.T when applied consistently.
Changes:
- Introduces `scmp_kernels/quant/smoothquant.py` with calibration, scale computation, and smoothing application helpers.
- Adds an optional `smooth_scales: (D,)` kwarg to `sc_matmul` to apply `(a / s, b * s)` before dispatching to existing kernels.
- Adds a dedicated test suite validating helper identities, closed forms, MSE improvement under simulated int8 quant, and CUDA wiring equivalence.
Reviewed changes
Copilot reviewed 4 out of 4 changed files in this pull request and generated 4 comments.
| File | Description |
|---|---|
| `scmp_kernels/quant/smoothquant.py` | New SmoothQuant helper functions (calibration, scale computation, online/offline application). |
| `scmp_kernels/sc/matmul.py` | Adds `smooth_scales` kwarg and applies the pre-transform before kernel dispatch. |
| `scmp_kernels/quant/__init__.py` | Re-exports SmoothQuant helpers from the quant subpackage. |
| `tests/test_smoothquant.py` | Adds unit tests for helpers plus a CUDA-only `sc_matmul(..., smooth_scales=...)` equivalence test. |
    flat = x.detach().abs().reshape(-1, x.shape[-1])
    cur = flat.amax(dim=0)
    if running is None:
        return cur.clone()

    if smooth_scales.dim() != 1:
        raise ValueError(
            f"smooth_scales must be 1D (D,), got shape "
            f"{tuple(smooth_scales.shape)}")
    if a.shape[-1] != smooth_scales.shape[0] or b.shape[-1] != smooth_scales.shape[0]:
        raise ValueError(
            f"D mismatch: a.shape[-1]={a.shape[-1]}, b.shape[-1]={b.shape[-1]}, "
            f"smooth_scales.shape[0]={smooth_scales.shape[0]}")

    sa = smooth_scales.to(device=a.device, dtype=a.dtype)
    sb = smooth_scales.to(device=b.device, dtype=b.dtype)
    return a / sa, b * sb

    if weight.shape[-1] != smooth_scales.shape[0]:
        raise ValueError(
            f"D mismatch: weight.shape[-1]={weight.shape[-1]} vs "
            f"smooth_scales.shape[0]={smooth_scales.shape[0]}")
heroarmor
left a comment
Reviewed — good. Note: stacked on #13, merge that first. ✅
Base is refactor/extract-quant-module, not main (the module lives there), so #13 needs to land first; this will then auto-retarget to main.
Math checks out: (a/s) @ (b·s)ᵀ cancels along the contracted D dim, and the fp64 identity tests confirm it to <1e-10 (2D + 3D). α=0/α=1 closed forms are correct. Opt-in smooth_scales=None preserves byte-for-byte legacy behavior, and the ValueError validation is thorough.
Two things worth considering:
- `compute_smooth_scales` casts to `weight.dtype` before `pow(α)` / `pow(1-α)`. For fp16/bf16 weights this loses precision; upstream computes scales in fp32. Cheap fix: do the scale math in fp32, return in weight dtype.
- `test_sc_matmul_smooth_scales_kwarg_equivalence` is essentially tautological (both sides call `apply_smoothing`), so it validates wiring but not the quantization benefit on the SC path; that's only shown via CPU fake-quant. Fine as a smoke test, just scope-aware.
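To make the first note concrete, here is a minimal, dependency-free sketch of why rounding once at the end is never worse than rounding after every operation. It is not the PR's code: `to_fp16`, `scale_low_precision`, and `scale_wide_then_cast` are hypothetical names, half precision is emulated with the standard library's `struct` `"e"` format, and Python's fp64 floats stand in for the fp32 path the review suggests.

```python
import struct

def to_fp16(x):
    """Round a Python float to the nearest IEEE half-precision value."""
    return struct.unpack("<e", struct.pack("<e", x))[0]

def scale_low_precision(act_max, w_max, alpha):
    """Round after every op, mimicking pow/div carried out on fp16 tensors."""
    num = to_fp16(to_fp16(act_max) ** alpha)
    den = to_fp16(to_fp16(w_max) ** (1.0 - alpha))
    return to_fp16(num / den)

def scale_wide_then_cast(act_max, w_max, alpha):
    """Do the math in a wider type, then cast to fp16 once at the end."""
    wide = to_fp16(act_max) ** alpha / to_fp16(w_max) ** (1.0 - alpha)
    return to_fp16(wide)

alpha = 0.5
# Illustrative per-channel statistics (assumed values, not from the PR).
act_maxes = [0.5 + 0.37 * k for k in range(64)]
w_maxes = [0.05 + 0.011 * k for k in range(64)]

worst_low = 0.0
worst_wide = 0.0
for a, w in zip(act_maxes, w_maxes):
    # Reference: same fp16 inputs, all arithmetic in fp64.
    ref = to_fp16(a) ** alpha / to_fp16(w) ** (1.0 - alpha)
    worst_low = max(worst_low, abs(scale_low_precision(a, w, alpha) - ref))
    worst_wide = max(worst_wide, abs(scale_wide_then_cast(a, w, alpha) - ref))

print(f"worst abs error, fp16 intermediates: {worst_low:.3e}")
print(f"worst abs error, wide math + 1 cast: {worst_wide:.3e}")
```

The wide path returns the half-precision value nearest to the reference, so its error is a lower bound for any result that must be stored in fp16; the per-op path can only match or exceed it.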
@Allenjin123 #13 has merged to […]. Could you rebase […]? Once it's clean I'll approve. The earlier review notes still stand (fp32 for the […]
Mathematically equivalent diagonal rescale along D that migrates per-channel
activation outliers into the weight:
Y = A @ B.T = (A / s) @ (B * s).T, s_j = act_max[j]^a / w_max[j]^(1-a)
New helpers in scmp_kernels/quant/smoothquant.py:
  accumulate_act_scales    - per-channel max-abs aggregator for calibration
  compute_smooth_scales    - build s from calibrated stats + weight
  apply_smoothing          - apply diagonal rescale (2D and 3D)
  apply_smoothing_offline  - bake s into the weight once
Wired into sc_matmul as an optional smooth_scales kwarg; default None
preserves byte-for-byte legacy behavior. Works for all three granularities
(per_tensor, per_row, per_head).
Tests in tests/test_smoothquant.py cover the math identity, calibration
aggregator, alpha=0/1 closed forms, MSE improvement under simulated int8
quant for all three granularities (13-17x on synthetic outliers), and the
sc_matmul kwarg-vs-manual equivalence (CUDA-only).
Reference: Xiao et al., "SmoothQuant: Accurate and Efficient PTQ for LLMs",
ICML 2023.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
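The calibration step and the scale formula from the commit message, `s_j = act_max[j]^a / w_max[j]^(1-a)`, can be sketched in plain Python as below. This is an illustration under stated assumptions, not the module's implementation: the function names are hypothetical, nested lists stand in for tensors, and the `eps` floor is an assumption added here to avoid dividing by zero on dead channels.

```python
def accumulate_abs_max(batch, running=None):
    """Per-channel max-abs over the rows of `batch`, merged with `running`."""
    cur = [max(abs(v) for v in col) for col in zip(*batch)]
    if running is None:
        return cur
    return [max(r, c) for r, c in zip(running, cur)]

def sketch_smooth_scales(act_max, weight, alpha=0.5, eps=1e-5):
    """s_j = act_max[j]**alpha / w_max[j]**(1 - alpha); w_max is per column of weight."""
    if not 0.0 <= alpha <= 1.0:
        raise ValueError(f"alpha must be in [0, 1], got {alpha}")
    if len(act_max) != len(weight[0]):
        raise ValueError("D mismatch between act_max and weight")
    w_max = [max(abs(v) for v in col) for col in zip(*weight)]
    # eps is an assumption of this sketch, not something the PR states.
    return [max(a, eps) ** alpha / max(w, eps) ** (1.0 - alpha)
            for a, w in zip(act_max, w_max)]

# Two calibration batches of shape (2, D) with D = 2; channel 1 has outliers.
batches = [
    [[1.0, -40.0], [0.5, 12.0]],
    [[-2.0, 64.0], [0.25, -3.0]],
]
running = None
for batch in batches:
    running = accumulate_abs_max(batch, running)
# running is now [2.0, 64.0]

W = [[0.5, -0.25], [-0.125, 0.0625]]              # per-column max-abs: [0.5, 0.25]
s_half = sketch_smooth_scales(running, W, alpha=0.5)
s_zero = sketch_smooth_scales(running, W, alpha=0.0)  # reduces to 1 / w_max
s_one = sketch_smooth_scales(running, W, alpha=1.0)   # reduces to act_max
```

The two edge cases are the closed forms the tests check: at `alpha=0` the scale depends only on the weight, and at `alpha=1` only on the activation statistics.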
Force-pushed from 0ae3c47 to 03786c7
The quant sub-package extraction described by this plan landed in #13. The plan is now historical context only; remove it from the working tree to avoid drift between the doc and the actual layout.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
    if smooth_scales.dim() != 1:
        raise ValueError(
            f"smooth_scales must be 1D (D,), got shape "
            f"{tuple(smooth_scales.shape)}")
    if weight.shape[-1] != smooth_scales.shape[0]:
        raise ValueError(
            f"D mismatch: weight.shape[-1]={weight.shape[-1]} vs "
            f"smooth_scales.shape[0]={smooth_scales.shape[0]}")
    return weight * smooth_scales.to(device=weight.device, dtype=weight.dtype)
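The excerpt above multiplies the weight by the scales. A rough sketch of how that offline step might pair with a per-call activation rescale follows; the helper names are hypothetical, nested lists stand in for tensors, and the scales are powers of two so the comparison is exact in floating point.

```python
def matmul_bt(a, b):
    """Return a @ b.T for nested lists a (M x D) and b (N x D)."""
    return [[sum(x * y for x, y in zip(ra, rb)) for rb in b] for ra in a]

def bake_weight(weight, s):
    """Offline step: multiply weight column j by s[j] once and store the result."""
    if len(weight[0]) != len(s):
        raise ValueError("D mismatch between weight and scales")
    return [[w * sj for w, sj in zip(row, s)] for row in weight]

def scale_activation(x, s):
    """Online step: divide activation column j by s[j] on every call."""
    return [[v / sj for v, sj in zip(row, s)] for row in x]

s = [4.0, 0.5]                       # illustrative smoothing vector
W = [[0.5, -2.0], [1.5, 4.0]]
W_baked = bake_weight(W, s)          # computed once, reused for every input

results = []
for x in ([[8.0, 1.0]], [[-4.0, 3.0]]):
    y_ref = matmul_bt(x, W)
    y = matmul_bt(scale_activation(x, s), W_baked)
    results.append((y_ref, y))
```

Only the activation side is touched per call; the weight-side multiply is paid once.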
heroarmor
left a comment
Rebase looks clean — diff is now scoped to the four SmoothQuant files (plus the stale MIGRATION_PLAN.md removal, which is fine to fold in here). Math identity holds, opt-in kwarg preserves legacy behavior, validation is solid. Approving. The two earlier notes (fp32 for the pow scale math; the kwarg test being a wiring smoke test) remain optional/non-blocking.
Summary
Adds the SmoothQuant (Xiao et al., ICML 2023) pre-quantization transform as a new module under `scmp_kernels/quant/`. It is a mathematically equivalent diagonal rescale along the shared inner dim `D` that migrates per-channel activation outliers into the weight so both operands quantize better:

    Y = A @ B.T = (A / s) @ (B * s).T,   s_j = act_max[j]^α / w_max[j]^(1-α)

The actual int-quant kernels are unchanged; they just see better-conditioned operands.
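The equivalence claimed in the summary can be checked numerically with a few lines of plain Python. This is a sketch with made-up values and hypothetical helper names (`matmul_bt`, `smooth`), not the module's `apply_smoothing`; the PR's own identity tests use fp64 tensors in 2D and 3D.

```python
def matmul_bt(a, b):
    """Return a @ b.T for nested lists a (M x D) and b (N x D)."""
    return [[sum(x * y for x, y in zip(ra, rb)) for rb in b] for ra in a]

def smooth(a, b, s):
    """Diagonal rescale along D: divide a's columns by s, multiply b's columns by s."""
    a_s = [[x / sj for x, sj in zip(row, s)] for row in a]
    b_s = [[y * sj for y, sj in zip(row, s)] for row in b]
    return a_s, b_s

A = [[1.0, 80.0, -2.0], [0.5, -64.0, 3.0]]   # channel 1 carries the outliers
B = [[0.2, 0.1, -0.3], [0.4, -0.2, 0.1]]
s = [1.0, 8.0, 2.0]                           # illustrative smoothing vector

ref = matmul_bt(A, B)
A_s, B_s = smooth(A, B, s)
out = matmul_bt(A_s, B_s)

# Each term (a_ij / s_j) * (b_kj * s_j) equals a_ij * b_kj, so the sums agree.
max_err = max(abs(r - o) for rr, ro in zip(ref, out) for r, o in zip(rr, ro))
print("max abs difference:", max_err)
```

After smoothing, the largest activation magnitude drops from 80 to 10 while the product is unchanged, which is what makes the operands easier to quantize.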
API
New module `scmp_kernels/quant/smoothquant.py`:

| Function | Purpose |
|---|---|
| `accumulate_act_scales(x, running=None)` | Per-channel max-abs aggregator for calibration |
| `compute_smooth_scales(act_scales, weight, alpha=0.5)` | Build the `(D,)` smoothing vector from calibrated stats |
| `apply_smoothing(a, b, smooth_scales)` | Apply the diagonal rescale (2D and 3D) |
| `apply_smoothing_offline(weight, smooth_scales)` | Bake `s` into a weight once |

Wired into `sc_matmul` as one new kwarg, `smooth_scales`. Default `None` preserves byte-for-byte legacy behavior.

Test results
`tests/test_smoothquant.py`: 10 tests, all pass on `gl1810` (RTX PRO 6000 Blackwell):
- Calibration aggregator matches `torch.maximum(...).amax(0)` across batches.
- `compute_smooth_scales` matches closed forms at α=0 and α=1.
- `(a/s) @ (b*s).T` equals `a @ b.T` to <1e-10 in fp64 (2D and 3D).
- Outlier-MSE improvement under simulated int8 quant (60× outlier channels) for `per_tensor`, `per_row`, and `per_head`.
- `sc_matmul(a, b, smooth_scales=s)` is `torch.equal` to `sc_matmul(a/s, b*s)` on CUDA.
- Invalid args (α∉[0,1], D mismatch, non-1D `smooth_scales`) raise `ValueError`.

Base branch
This PR is stacked on `refactor/extract-quant-module` because the new module lives at `scmp_kernels/quant/smoothquant.py`, and that sub-package only exists on that branch. Please merge the refactor PR first; this one will then auto-retarget to `main`.

Test plan
- `python -m pytest tests/test_smoothquant.py` → 9 pass, 1 skipped (CUDA test self-skips without GPU).
- `gl1810`: 10/10 pass, including `test_sc_matmul_smooth_scales_kwarg_equivalence`.

Out of scope
- `smooth_lm`: `scmp_kernels` is model-agnostic.
- `smooth_scales` is shape `(D,)` only; per-head smoothing `(BH, D)` could be added later.

🤖 Generated with Claude Code
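For readers who want to see the outlier-MSE effect without a GPU, the following is a small stand-alone sketch of simulated int8 quantization with and without smoothing. It is not the test from `tests/test_smoothquant.py`: the sizes, the seed, the helper names, and the choice of α=0.5 with per-tensor symmetric quantization are all assumptions made here for illustration.

```python
import random

def matmul_bt(a, b):
    """Return a @ b.T for nested lists a (M x D) and b (N x D)."""
    return [[sum(x * y for x, y in zip(ra, rb)) for rb in b] for ra in a]

def fake_quant_int8(m):
    """Symmetric per-tensor int8 fake-quant: round to [-127, 127], then dequantize."""
    amax = max(abs(v) for row in m for v in row)
    step = amax / 127.0 if amax > 0 else 1.0
    return [[max(-127, min(127, round(v / step))) * step for v in row] for row in m]

def col_abs_max(m):
    """Per-channel (column) max-abs."""
    return [max(abs(v) for v in col) for col in zip(*m)]

def mse(x, y):
    diffs = [(p - q) ** 2 for rx, ry in zip(x, y) for p, q in zip(rx, ry)]
    return sum(diffs) / len(diffs)

rng = random.Random(0)
M, N, D = 8, 8, 16
A = [[rng.uniform(-1, 1) for _ in range(D)] for _ in range(M)]
B = [[rng.uniform(-1, 1) for _ in range(D)] for _ in range(N)]
for row in A:
    row[0] *= 60.0  # one synthetic 60x outlier channel in the activation

ref = matmul_bt(A, B)

# Smoothing vector with alpha = 0.5: s_j = sqrt(act_max[j] / w_max[j]).
s = [(a / w) ** 0.5 for a, w in zip(col_abs_max(A), col_abs_max(B))]
A_s = [[v / sj for v, sj in zip(row, s)] for row in A]
B_s = [[v * sj for v, sj in zip(row, s)] for row in B]

mse_plain = mse(matmul_bt(fake_quant_int8(A), fake_quant_int8(B)), ref)
mse_smooth = mse(matmul_bt(fake_quant_int8(A_s), fake_quant_int8(B_s)), ref)
print(f"MSE without smoothing: {mse_plain:.4e}")
print(f"MSE with smoothing:    {mse_smooth:.4e}")
```

Without smoothing, the outlier channel sets the quantization step for the whole activation tensor, so the ordinary channels are resolved very coarsely; splitting the outlier between both operands shrinks that step. The exact ratio depends on the data and granularity, so this sketch should not be read as reproducing the PR's reported figures.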