quant: add SmoothQuant pre-quantization transform #15

Merged. heroarmor merged 2 commits into main from feat/smoothquant on May 20, 2026.

@Allenjin123
Contributor

Summary

Adds the SmoothQuant (Xiao et al., ICML 2023) pre-quantization transform as a new module under scmp_kernels/quant/. The transform is a mathematically equivalent diagonal rescale along the shared inner dimension D. It migrates per-channel activation outliers into the weight so that both operands quantize better:

Y = A @ B.T = (A / s) @ (B * s).T,    s_j = act_max[j]^α / w_max[j]^(1-α)

The actual int-quant kernels are unchanged — they just see better-conditioned operands.
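
The identity above can be checked numerically. The following is a minimal sketch that uses NumPy in float64 as a stand-in for the PR's torch tensors; the shapes, the seed, and the random positive vector `s` are illustrative assumptions, not values from the PR.

```python
import numpy as np

rng = np.random.default_rng(0)
M, N, D, BH = 4, 3, 8, 2  # illustrative sizes (assumed)

s = rng.uniform(0.5, 2.0, size=D)  # any strictly positive (D,) vector

# 2D case: A is (M, D), B is (N, D), contraction over D.
A = rng.standard_normal((M, D))
B = rng.standard_normal((N, D))
err_2d = np.abs(A @ B.T - (A / s) @ (B * s).T).max()

# 3D case: leading batch/head dim; s broadcasts over the last axis.
A3 = rng.standard_normal((BH, M, D))
B3 = rng.standard_normal((BH, N, D))
ref_3d = A3 @ np.swapaxes(B3, -1, -2)
out_3d = (A3 / s) @ np.swapaxes(B3 * s, -1, -2)
err_3d = np.abs(ref_3d - out_3d).max()

print(f"max abs error: 2D={err_2d:.2e}, 3D={err_3d:.2e}")
```

Each term of the contraction picks up a factor of `1 / s[j]` from the activation and `s[j]` from the weight, so the product is unchanged up to floating-point rounding.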

API

New module scmp_kernels/quant/smoothquant.py:

| Function | Purpose |
| --- | --- |
| accumulate_act_scales(x, running=None) | Per-channel max-abs aggregator for the calibration pass |
| compute_smooth_scales(act_scales, weight, alpha=0.5) | Build the (D,) smoothing vector from calibrated stats |
| apply_smoothing(a, b, smooth_scales) | Diagonal rescale (2D and 3D inputs) |
| apply_smoothing_offline(weight, smooth_scales) | Bake s into a weight once |
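
To make the intended call order concrete, here is a rough NumPy sketch of the calibrate, compute, apply flow. The `_np` functions are hypothetical stand-ins written for this illustration, not the PR's torch implementations. In particular, taking `w_max[j]` as the max absolute weight over all rows of channel j is an assumption, and there is no guard against zero-valued channels.

```python
import numpy as np

def accumulate_act_scales_np(x, running=None):
    """Per-channel max-abs over every leading dim, merged into a running max."""
    cur = np.abs(x).reshape(-1, x.shape[-1]).max(axis=0)
    return cur if running is None else np.maximum(running, cur)

def compute_smooth_scales_np(act_scales, weight, alpha=0.5):
    """s_j = act_max[j]**alpha / w_max[j]**(1 - alpha), shape (D,)."""
    if not 0.0 <= alpha <= 1.0:
        raise ValueError(f"alpha must be in [0, 1], got {alpha}")
    if act_scales.shape != (weight.shape[-1],):
        raise ValueError("D mismatch between act_scales and weight")
    # Assumption: w_max[j] is the max |w| over all rows for channel j.
    w_max = np.abs(weight).reshape(-1, weight.shape[-1]).max(axis=0)
    return act_scales ** alpha / w_max ** (1.0 - alpha)

rng = np.random.default_rng(0)
D = 16
weight = rng.standard_normal((32, D))
batches = [rng.standard_normal((2, 5, D)) for _ in range(3)]

# Calibration pass: fold each batch into the running per-channel max.
running = None
for x in batches:
    running = accumulate_act_scales_np(x, running)

s = compute_smooth_scales_np(running, weight, alpha=0.5)

# Online form rescales both operands; the offline form would store weight * s once.
a_smooth, w_smooth = batches[0] / s, weight * s
```

With alpha=0 the formula reduces to `1 / w_max`, and with alpha=1 it reduces to `act_max`, which are the two closed forms the test suite checks.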

Wired into sc_matmul as one new kwarg:

sc_matmul(a, b, ..., smooth_scales=s)   # equivalent to sc_matmul(a/s, b*s)

Default None preserves byte-for-byte legacy behavior.
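
The opt-in wiring can be pictured with a small stand-in. This is only a sketch: `matmul_with_smoothing` is a hypothetical function, a plain NumPy matmul replaces the int-quant kernels, and the real `sc_matmul` signature has other arguments that are elided in the PR text.

```python
import numpy as np

def matmul_with_smoothing(a, b, smooth_scales=None):
    """Hypothetical stand-in for the kwarg wiring; not the real sc_matmul."""
    if smooth_scales is not None:
        s = np.asarray(smooth_scales)
        if s.ndim != 1:
            raise ValueError(f"smooth_scales must be 1D (D,), got shape {s.shape}")
        if a.shape[-1] != s.shape[0] or b.shape[-1] != s.shape[0]:
            raise ValueError(
                f"D mismatch: a.shape[-1]={a.shape[-1]}, "
                f"b.shape[-1]={b.shape[-1]}, s.shape[0]={s.shape[0]}")
        a, b = a / s, b * s  # pre-transform, then fall through unchanged
    # With smooth_scales=None nothing above runs, so this is the legacy path.
    return a @ np.swapaxes(b, -1, -2)

rng = np.random.default_rng(0)
M, N, D = 4, 3, 8
a = rng.standard_normal((M, D))
b = rng.standard_normal((N, D))
s = rng.uniform(0.5, 2.0, size=D)

y_default = matmul_with_smoothing(a, b)
y_kwarg = matmul_with_smoothing(a, b, smooth_scales=s)
y_manual = matmul_with_smoothing(a / s, b * s)
```

Keeping the transform in a single early branch is what lets the default path stay untouched: when the kwarg is omitted, the operands reach the kernel dispatch exactly as before.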

Test results

tests/test_smoothquant.py — 10 tests, all pass on gl1810 (RTX PRO 6000 Blackwell):

  • Calibration aggregator matches torch.maximum(...).amax(0) across batches.

  • compute_smooth_scales matches closed forms at α=0 and α=1.

  • (a/s) @ (b*s).T equals a @ b.T to <1e-10 in fp64 (2D and 3D).

  • Outlier-MSE improvement under simulated int8 quant (60× outlier channels):

    | Granularity | MSE (plain) | MSE (smooth) | Improvement |
    | --- | --- | --- | --- |
    | per_tensor | 3.58e+00 | 2.43e-01 | 14.7× |
    | per_row | 8.87e-01 | 5.35e-02 | 16.6× |
    | per_head | 2.96e+00 | 2.25e-01 | 13.1× |

  • sc_matmul(a, b, smooth_scales=s) is torch.equal to sc_matmul(a/s, b*s) on CUDA.

  • Invalid args (α∉[0,1], D mismatch, non-1D smooth_scales) raise ValueError.
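
The outlier-MSE experiment can be approximated in a few lines. The sketch below is not the PR's test: it uses NumPy, a hypothetical `fake_quant_int8` helper with symmetric per-tensor scaling only, and assumed sizes (four outlier channels out of 128). It will not reproduce the table's exact numbers.

```python
import numpy as np

def fake_quant_int8(x):
    """Symmetric per-tensor int8 fake quantization: quantize, then dequantize."""
    scale = np.abs(x).max() / 127.0
    return np.clip(np.round(x / scale), -127, 127) * scale

rng = np.random.default_rng(0)
M, N, D = 64, 64, 128            # assumed sizes
A = rng.standard_normal((M, D))
A[:, :4] *= 60.0                 # four activation channels are 60x larger
B = rng.standard_normal((N, D))

# Smoothing vector at alpha = 0.5 from per-channel max-abs statistics.
act_max = np.abs(A).max(axis=0)
w_max = np.abs(B).max(axis=0)
s = np.sqrt(act_max / w_max)

Y = A @ B.T                      # full-precision reference
Y_plain = fake_quant_int8(A) @ fake_quant_int8(B).T
Y_smooth = fake_quant_int8(A / s) @ fake_quant_int8(B * s).T

mse_plain = np.mean((Y_plain - Y) ** 2)
mse_smooth = np.mean((Y_smooth - Y) ** 2)
print(f"plain={mse_plain:.3e} smooth={mse_smooth:.3e}")
```

Without smoothing, the outlier channels set the per-tensor scale of A, so the ordinary channels are rounded on a very coarse grid. Dividing by `s` shrinks the outliers and gives those channels a finer grid, at the cost of a somewhat coarser grid for B.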

Base branch

This PR is stacked on refactor/extract-quant-module because the new module lives at scmp_kernels/quant/smoothquant.py — that sub-package only exists on that branch. Please merge the refactor PR first; this one will then auto-retarget to main.

Test plan

  • CPU pytest: python -m pytest tests/test_smoothquant.py → 9 pass, 1 skipped (CUDA test self-skips without GPU).
  • GPU pytest on gl1810: 10/10 pass, including test_sc_matmul_smooth_scales_kwarg_equivalence.
  • Reviewer to confirm calibration aggregator API matches their existing observer pattern (none today, this is the first one).

Out of scope

  • No automatic α search — caller picks (upstream recommends 0.5 for OPT, 0.85 for Llama).
  • No model-walker like upstream's smooth_lm; scmp_kernels is model-agnostic.
  • smooth_scales is shape (D,) only; per-head smoothing (BH, D) could be added later.

🤖 Generated with Claude Code


Copilot AI left a comment

Pull request overview

Adds a SmoothQuant-style pre-quantization diagonal rescaling to improve int-quant conditioning for SC matmul inputs, while preserving the mathematical result of a @ b.T when applied consistently.

Changes:

  • Introduces scmp_kernels/quant/smoothquant.py with calibration, scale computation, and smoothing application helpers.
  • Adds an optional smooth_scales: (D,) kwarg to sc_matmul to apply (a / s, b * s) before dispatching to existing kernels.
  • Adds a dedicated test suite validating helper identities, closed forms, MSE improvement under simulated int8 quant, and CUDA wiring equivalence.

Reviewed changes

Copilot reviewed 4 out of 4 changed files in this pull request and generated 4 comments.

| File | Description |
| --- | --- |
| scmp_kernels/quant/smoothquant.py | New SmoothQuant helper functions (calibration, scale computation, online/offline application). |
| scmp_kernels/sc/matmul.py | Adds smooth_scales kwarg and applies the pre-transform before kernel dispatch. |
| scmp_kernels/quant/__init__.py | Re-exports SmoothQuant helpers from the quant subpackage. |
| tests/test_smoothquant.py | Adds unit tests for helpers plus a CUDA-only sc_matmul(..., smooth_scales=...) equivalence test. |

flat = x.detach().abs().reshape(-1, x.shape[-1])
cur = flat.amax(dim=0)
if running is None:
    return cur.clone()
Comment on lines +115 to +126
if smooth_scales.dim() != 1:
    raise ValueError(
        f"smooth_scales must be 1D (D,), got shape "
        f"{tuple(smooth_scales.shape)}")
if a.shape[-1] != smooth_scales.shape[0] or b.shape[-1] != smooth_scales.shape[0]:
    raise ValueError(
        f"D mismatch: a.shape[-1]={a.shape[-1]}, b.shape[-1]={b.shape[-1]}, "
        f"smooth_scales.shape[0]={smooth_scales.shape[0]}")

sa = smooth_scales.to(device=a.device, dtype=a.dtype)
sb = smooth_scales.to(device=b.device, dtype=b.dtype)
return a / sa, b * sb

if weight.shape[-1] != smooth_scales.shape[0]:
    raise ValueError(
        f"D mismatch: weight.shape[-1]={weight.shape[-1]} vs "
        f"smooth_scales.shape[0]={smooth_scales.shape[0]}")
Comment thread tests/test_smoothquant.py
Comment on lines +135 to +136


Collaborator

@heroarmor heroarmor left a comment

Reviewed — good. Note: stacked on #13, merge that first.

Base is refactor/extract-quant-module, not main (the module lives there), so #13 needs to land first; this will then auto-retarget to main.

Math checks out: (a/s) @ (b·s)ᵀ cancels along the contracted D dim, and the fp64 identity tests confirm it to <1e-10 (2D + 3D). α=0/α=1 closed forms are correct. Opt-in smooth_scales=None preserves byte-for-byte legacy behavior, and the ValueError validation is thorough.

Two things worth considering:

  1. compute_smooth_scales casts to weight.dtype before pow(α)/pow(1-α). For fp16/bf16 weights this loses precision — upstream computes scales in fp32. Cheap fix: do the scale math in fp32, return in weight dtype.
  2. test_sc_matmul_smooth_scales_kwarg_equivalence is essentially tautological (both sides call apply_smoothing), so it validates wiring but not the quantization benefit on the SC path — that's only shown via CPU fake-quant. Fine as a smoke test, just scope-aware.
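
The first suggestion (scale math in fp32, result in the weight dtype) could look roughly like the following. This is a NumPy sketch with a hypothetical function name, not the code in smoothquant.py, and the per-channel `w_max` reduction over rows is an assumption. An all-float16 path and a float64 reference are computed alongside for comparison only.

```python
import numpy as np

def smooth_scales_fp32(act_scales, weight, alpha=0.5):
    """Compute the scale math in float32 and return it in weight.dtype."""
    act32 = act_scales.astype(np.float32)
    w32 = np.abs(weight.astype(np.float32))
    w_max32 = w32.reshape(-1, weight.shape[-1]).max(axis=0)
    s32 = act32 ** alpha / w_max32 ** (1.0 - alpha)
    return s32.astype(weight.dtype)  # single rounding step at the end

rng = np.random.default_rng(0)
D = 64
weight = rng.standard_normal((32, D)).astype(np.float16)
act_scales = rng.uniform(1.0, 200.0, size=D).astype(np.float32)

s = smooth_scales_fp32(act_scales, weight)

# float64 reference built from the same inputs.
w_max64 = np.abs(weight.astype(np.float64)).max(axis=0)
ref = act_scales.astype(np.float64) ** 0.5 / w_max64 ** 0.5

# All-float16 path: cast first, then pow and divide in half precision.
act16 = act_scales.astype(np.float16)
w_max16 = np.abs(weight).max(axis=0)
s16 = act16 ** np.float16(0.5) / w_max16 ** np.float16(0.5)

rel_fp32_path = np.abs(s.astype(np.float64) - ref) / ref
rel_fp16_path = np.abs(s16.astype(np.float64) - ref) / ref
print(f"max rel err: fp32 path={rel_fp32_path.max():.2e}, "
      f"fp16 path={rel_fp16_path.max():.2e}")
```

The fp32 path rounds to half precision once, on the final result, whereas the all-float16 path rounds after the input cast, after each pow, and after the divide.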

@heroarmor heroarmor changed the base branch from refactor/extract-quant-module to main May 20, 2026 06:36
@heroarmor
Collaborator

@Allenjin123 #13 has merged to main, and I've retargeted this PR to main. Heads up: #13 was squash-merged, so its refactor commits landed on main under a new SHA (fa7387d). This branch still carries the original refactor commits, so git now sees them as conflicting re-additions of quant/fused.py, grouped.py, and the kernels.py edits — the PR shows as CONFLICTING and the diff includes all 7 files instead of just the SmoothQuant ones.

Could you rebase feat/smoothquant onto the latest main? That'll drop the redundant refactor commits and leave only the four SmoothQuant-specific changes (quant/__init__.py exports, quant/smoothquant.py, the sc_matmul kwarg in matmul.py, and tests/test_smoothquant.py).

git fetch origin
git rebase origin/main
# resolve by keeping main's versions of fused.py/grouped.py/kernels.py; your smoothquant files stay
git push --force-with-lease

Once it's clean I'll approve. The earlier review notes still stand (fp32 for the pow scale math; the kwarg test being a wiring smoke test) — optional, non-blocking.

Mathematically equivalent diagonal rescale along D that migrates per-channel
activation outliers into the weight:

    Y = A @ B.T = (A / s) @ (B * s).T,   s_j = act_max[j]^a / w_max[j]^(1-a)

New helpers in scmp_kernels/quant/smoothquant.py:

  accumulate_act_scales  - per-channel max-abs aggregator for calibration
  compute_smooth_scales  - build s from calibrated stats + weight
  apply_smoothing        - apply diagonal rescale (2D and 3D)
  apply_smoothing_offline- bake s into the weight once

Wired into sc_matmul as an optional smooth_scales kwarg; default None
preserves byte-for-byte legacy behavior. Works for all three granularities
(per_tensor, per_row, per_head).

Tests in tests/test_smoothquant.py cover the math identity, calibration
aggregator, alpha=0/1 closed forms, MSE improvement under simulated int8
quant for all three granularities (13-17x on synthetic outliers), and the
sc_matmul kwarg-vs-manual equivalence (CUDA-only).

Reference: Xiao et al., "SmoothQuant: Accurate and Efficient PTQ for LLMs",
ICML 2023.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

The quant sub-package extraction described by this plan landed in #13.
The plan is now historical context only — remove from the working tree to
avoid drift between the doc and the actual layout.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

Copilot AI left a comment

Pull request overview

Copilot reviewed 5 out of 5 changed files in this pull request and generated 3 comments.

flat = x.detach().abs().reshape(-1, x.shape[-1])
cur = flat.amax(dim=0)
if running is None:
    return cur.clone()
Comment on lines +115 to +126
if smooth_scales.dim() != 1:
    raise ValueError(
        f"smooth_scales must be 1D (D,), got shape "
        f"{tuple(smooth_scales.shape)}")
if a.shape[-1] != smooth_scales.shape[0] or b.shape[-1] != smooth_scales.shape[0]:
    raise ValueError(
        f"D mismatch: a.shape[-1]={a.shape[-1]}, b.shape[-1]={b.shape[-1]}, "
        f"smooth_scales.shape[0]={smooth_scales.shape[0]}")

sa = smooth_scales.to(device=a.device, dtype=a.dtype)
sb = smooth_scales.to(device=b.device, dtype=b.dtype)
return a / sa, b * sb
Comment on lines +139 to +147
if smooth_scales.dim() != 1:
    raise ValueError(
        f"smooth_scales must be 1D (D,), got shape "
        f"{tuple(smooth_scales.shape)}")
if weight.shape[-1] != smooth_scales.shape[0]:
    raise ValueError(
        f"D mismatch: weight.shape[-1]={weight.shape[-1]} vs "
        f"smooth_scales.shape[0]={smooth_scales.shape[0]}")
return weight * smooth_scales.to(device=weight.device, dtype=weight.dtype)
Collaborator

@heroarmor heroarmor left a comment

Rebase looks clean — diff is now scoped to the four SmoothQuant files (plus the stale MIGRATION_PLAN.md removal, which is fine to fold in here). Math identity holds, opt-in kwarg preserves legacy behavior, validation is solid. Approving. The two earlier notes (fp32 for the pow scale math; the kwarg test being a wiring smoke test) remain optional/non-blocking.

@heroarmor heroarmor merged commit c422fce into main May 20, 2026
2 checks passed
3 participants