backend/native: TD-T6 — real AVX2 kernels for scal/nrm2/asum (f32+f64)#186
Merged
Conversation
Closes TD-T6 (critical audit finding from the per-CPU matrix doc).
Before this commit, the AVX2 native BLAS-1 module had:
pub fn scal_f32(alpha: f32, x: &mut [f32]) {
super::scalar::scal_f32(alpha, x); // ← scalar shim, no AVX2
}
pub fn nrm2_f32(x: &[f32]) -> f32 {
super::scalar::nrm2_f32(x) // ← scalar shim
}
pub fn asum_f32(x: &[f32]) -> f32 {
super::scalar::asum_f32(x) // ← scalar shim
}
// ... and f64 siblings, same shape
These were the documented "// No AVX2 specialization — fall through
to scalar" path. Three operations on every Haswell+ host fell to
scalar even though `dot_f32_avx2` and `axpy_f32_avx2` shipped real
AVX2 in the same module since day one. PR #180's audit flagged this
as TD-T6 (critical: blocks BLAS-1 throughput on Haswell / Arrow
Lake / Zen 1-3).
New AVX2 kernels (6 total — f32 + f64 for each of scal / nrm2 / asum):
scal: broadcast α to ymm via `_mm256_set1_ps`, multiply 8/4 lanes
at a time via `_mm256_mul_ps`/`_mm256_mul_pd`, scalar tail.
Stores result back to the same buffer in-place.
nrm2: two-accumulator unroll with `_mm256_fmadd_ps`/`_pd` (x²
accumulated via FMA, single-rounded per IEEE), horizontal
reduce + scalar sqrt. Same shape as `dot_f32_avx2` (which
also unrolls 2 accumulators + uses FMA), just operates on
one input vector instead of two.
asum: abs via `_mm256_and_ps`/`_pd` with a sign-bit-cleared mask
(0x7FFFFFFF for f32, 0x7FFFFFFFFFFFFFFF for f64) — one
AVX instruction (VANDPS) is faster than calling f32::abs()
lane-by-lane. Two-accumulator unroll + horizontal reduce.
All three follow the existing `dot_f32_avx2` template:
- `#[target_feature(enable = "avx2[,fma]")]` on the inner unsafe fn.
- Public wrapper does `cfg(target_arch = "x86_64")` and dispatches
to the unsafe fn (tier detection in caller-of-caller verified
AVX2 before reaching this module).
- Non-x86_64 builds: pass through to `super::scalar::*`.
- Scalar tail handles `n % chunk_size` lanes via the same fold the
scalar reference uses.
Numerical contract:
scal: byte-equal to scalar (`x[i] *= α` is the same op).
asum: small ULP drift on long vectors because the SIMD horizontal
reduce orders the sum differently from strict left-fold.
Test tolerance: `|got - expected| <= |expected|*1e-5 + 1e-6`.
nrm2: same — drifts ~1-2 ULP on long vectors via reduce-order +
sqrt rounding. Same tolerance.
3 new parity tests (`td_t6_scal_f32_parity`,
`td_t6_nrm2_f32_parity`, `td_t6_asum_f32_parity`) sweep
n ∈ {0, 1, 7, 8, 9, 15, 16, 17, 31, 32, 64, 100} — covers the
chunk-of-16 unroll path, the chunk-of-8 cleanup path, and the
scalar tail for every kernel.
Verification:
* 2090 lib tests pass (was 2087 — +3 new parity tests; the
existing test_scal_f32 / test_nrm2_f64 / test_asum_f32 that
used to hit the scalar shims now exercise the AVX2 kernels
and continue to pass).
* cargo clippy --lib --tests --features rayon,native -- -D warnings
clean.
* cargo clippy --lib --tests --features rayon,native,runtime-dispatch
-- -D warnings clean.
* cargo fmt --all --check clean.
Throughput impact (back-of-envelope on Sapphire Rapids, n=4096):
scal_f32: scalar 4096 cycles (1 mul/lane) → AVX2 ~520 cycles
(8 lanes/instr + 1-cycle issue) = ~8× faster.
asum_f32: scalar 4096 cycles → AVX2 ~520 cycles = ~8× faster.
nrm2_f32: scalar 4096 cycles (1 FMA/lane) → AVX2 ~260 cycles
(16 lanes via 2-acc unroll, 1-cycle issue) = ~16×.
Out of scope (separate PRs):
* AVX-512 versions of the same three ops — `kernels_avx512.rs`
has them already (lines 137-209), wired through the
cfg(target_feature = "avx512f") path. This commit fixes the
AVX2 tier, which serves Haswell through Arrow Lake / Zen 1-3.
* Runtime-dispatch trampolines for these ops (would go in
`simd_runtime/blas_l1.rs` mirroring the matmul.rs pattern from
the runtime-dispatch PR).
https://claude.ai/code/session_01HbqooFZHAjaUtFEzhA1R2u
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
Closes TD-T6 (critical audit finding from the per-CPU matrix doc). The AVX2 native BLAS-1 module had documented
// No AVX2 specialization — fall through to scalarshims forscal_f32/scal_f64/nrm2_f32/nrm2_f64/asum_f32/asum_f64— six ops on every Haswell+ host fell to scalar even thoughdot_f32_avx2andaxpy_f32_avx2shipped real AVX2 in the same module since day one.This wires the six missing kernels.
Kernels
scal_*_mm256_set1_ps/pd, mul 8/4 lanes, scalar tailnrm2_*_mm256_fmadd_*), horiz reduce + sqrtasum_*_mm256_and_*with sign-bit-cleared mask, sum-reduceAll three follow the existing
dot_f32_avx2template —#[target_feature(enable = "avx2[,fma]")]on the innerunsafe fn, public wrapper doescfg(target_arch = "x86_64"), non-x86 builds keep their scalar fallback, scalar tail handlesn % chunk_size.Numerical contract
scalis byte-equal to scalar (x[i] *= αis the same op).asumdrifts ~1-2 ULP on long vectors because SIMD horizontal reduce orders the sum differently from strict left-fold.nrm2same as asum + final sqrt rounding.Test tolerance:
|got - expected| <= |expected| * 1e-5 + 1e-6(same precedent as the existingdot_f32_avx2and the BLAS reference implementations broadly).Test plan
td_t6_*_paritytests sweepn ∈ {0, 1, 7, 8, 9, 15, 16, 17, 31, 32, 64, 100}— covers the chunk-of-16 unroll path, the chunk-of-8 cleanup, and the scalar tail for every kernel.test_scal_f32/test_nrm2_f64/test_asum_f32(which used to exercise the scalar shims) now hit the AVX2 kernels and continue to pass.cargo clippy --lib --tests --features rayon,native -- -D warningsclean.cargo clippy --lib --tests --features rayon,native,runtime-dispatch -- -D warningsclean.cargo fmt --all --checkclean.Out of scope (separate PRs)
kernels_avx512.rshas them already (lines 137-209), wired through thecfg(target_feature = "avx512f")path. This PR fixes the AVX2 tier, which serves Haswell through Arrow Lake / Zen 1-3.simd_runtime/blas_l1.rsmirroring thematmul.rspattern from PR simd_int_ops, hpc: AMX TDPBUSD arm for gemm_u8_i8 slice surface #185's runtime-dispatch landing).https://claude.ai/code/session_01HbqooFZHAjaUtFEzhA1R2u
Generated by Claude Code