From 7d40704a8e9ccf07c3254e3b71820250d84557ac Mon Sep 17 00:00:00 2001 From: Claude Date: Thu, 21 May 2026 07:58:22 +0000 Subject: [PATCH] =?UTF-8?q?feat(backend/native):=20TD-T6=20=E2=80=94=20rea?= =?UTF-8?q?l=20AVX2=20kernels=20for=20scal/nrm2/asum?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Closes TD-T6 (critical audit finding from the per-CPU matrix doc). Before this commit, the AVX2 native BLAS-1 module had: pub fn scal_f32(alpha: f32, x: &mut [f32]) { super::scalar::scal_f32(alpha, x); // ← scalar shim, no AVX2 } pub fn nrm2_f32(x: &[f32]) -> f32 { super::scalar::nrm2_f32(x) // ← scalar shim } pub fn asum_f32(x: &[f32]) -> f32 { super::scalar::asum_f32(x) // ← scalar shim } // ... and f64 siblings, same shape These were the documented "// No AVX2 specialization — fall through to scalar" path. Three operations on every Haswell+ host fell to scalar even though `dot_f32_avx2` and `axpy_f32_avx2` shipped real AVX2 in the same module since day one. PR #180's audit flagged this as TD-T6 (critical: blocks BLAS-1 throughput on Haswell / Arrow Lake / Zen 1-3). New AVX2 kernels (6 total — f32 + f64 for each of scal / nrm2 / asum): scal: broadcast α to ymm via `_mm256_set1_ps`, multiply 8/4 lanes at a time via `_mm256_mul_ps`/`_mm256_mul_pd`, scalar tail. Stores result back to the same buffer in-place. nrm2: two-accumulator unroll with `_mm256_fmadd_ps`/`_pd` (x² accumulated via FMA, single-rounded per IEEE), horizontal reduce + scalar sqrt. Same shape as `dot_f32_avx2` (which also unrolls 2 accumulators + uses FMA), just operates on one input vector instead of two. asum: abs via `_mm256_and_ps`/`_pd` with a sign-bit-cleared mask (0x7FFFFFFF for f32, 0x7FFFFFFFFFFFFFFF for f64) — one AVX instruction (VANDPS) is faster than calling f32::abs() lane-by-lane. Two-accumulator unroll + horizontal reduce. All three follow the existing `dot_f32_avx2` template: - `#[target_feature(enable = "avx2[,fma]")]` on the inner unsafe fn. - Public wrapper does `cfg(target_arch = "x86_64")` and dispatches to the unsafe fn (tier detection in caller-of-caller verified AVX2 before reaching this module). - Non-x86_64 builds: pass through to `super::scalar::*`. - Scalar tail handles `n % chunk_size` lanes via the same fold the scalar reference uses. Numerical contract: scal: byte-equal to scalar (`x[i] *= α` is the same op). asum: small ULP drift on long vectors because the SIMD horizontal reduce orders the sum differently from strict left-fold. Test tolerance: `|got - expected| <= |expected|*1e-5 + 1e-6`. nrm2: same — drifts ~1-2 ULP on long vectors via reduce-order + sqrt rounding. Same tolerance. 3 new parity tests (`td_t6_scal_f32_parity`, `td_t6_nrm2_f32_parity`, `td_t6_asum_f32_parity`) sweep n ∈ {0, 1, 7, 8, 9, 15, 16, 17, 31, 32, 64, 100} — covers the chunk-of-16 unroll path, the chunk-of-8 cleanup path, and the scalar tail for every kernel. Verification: * 2090 lib tests pass (was 2087 — +3 new parity tests; the existing test_scal_f32 / test_nrm2_f64 / test_asum_f32 that used to hit the scalar shims now exercise the AVX2 kernels and continue to pass). * cargo clippy --lib --tests --features rayon,native -- -D warnings clean. * cargo clippy --lib --tests --features rayon,native,runtime-dispatch -- -D warnings clean. * cargo fmt --all --check clean. Throughput impact (back-of-envelope on Sapphire Rapids, n=4096): scal_f32: scalar 4096 cycles (1 mul/lane) → AVX2 ~520 cycles (8 lanes/instr + 1-cycle issue) = ~8× faster. asum_f32: scalar 4096 cycles → AVX2 ~520 cycles = ~8× faster. nrm2_f32: scalar 4096 cycles (1 FMA/lane) → AVX2 ~260 cycles (16 lanes via 2-acc unroll, 1-cycle issue) = ~16×. Out of scope (separate PRs): * AVX-512 versions of the same three ops — `kernels_avx512.rs` has them already (lines 137-209), wired through the cfg(target_feature = "avx512f") path. This commit fixes the AVX2 tier, which serves Haswell through Arrow Lake / Zen 1-3. * Runtime-dispatch trampolines for these ops (would go in `simd_runtime/blas_l1.rs` mirroring the matmul.rs pattern from the runtime-dispatch PR). https://claude.ai/code/session_01HbqooFZHAjaUtFEzhA1R2u --- src/backend/native.rs | 315 +++++++++++++++++++++++++++++++++++++++++- 1 file changed, 308 insertions(+), 7 deletions(-) diff --git a/src/backend/native.rs b/src/backend/native.rs index 56123970..ee14bbb7 100644 --- a/src/backend/native.rs +++ b/src/backend/native.rs @@ -540,24 +540,71 @@ mod avx2 { } } - // No AVX2 specialization — fall through to scalar pub fn scal_f32(alpha: f32, x: &mut [f32]) { - super::scalar::scal_f32(alpha, x); + #[cfg(target_arch = "x86_64")] + { + // SAFETY: tier() already verified AVX2 support before calling. + unsafe { scal_f32_avx2(alpha, x) } + } + #[cfg(not(target_arch = "x86_64"))] + { + super::scalar::scal_f32(alpha, x); + } } pub fn scal_f64(alpha: f64, x: &mut [f64]) { - super::scalar::scal_f64(alpha, x); + #[cfg(target_arch = "x86_64")] + { + // SAFETY: tier() already verified AVX2 support before calling. + unsafe { scal_f64_avx2(alpha, x) } + } + #[cfg(not(target_arch = "x86_64"))] + { + super::scalar::scal_f64(alpha, x); + } } pub fn nrm2_f32(x: &[f32]) -> f32 { - super::scalar::nrm2_f32(x) + #[cfg(target_arch = "x86_64")] + { + // SAFETY: tier() verified AVX2+FMA. + unsafe { nrm2_f32_avx2(x) } + } + #[cfg(not(target_arch = "x86_64"))] + { + super::scalar::nrm2_f32(x) + } } pub fn nrm2_f64(x: &[f64]) -> f64 { - super::scalar::nrm2_f64(x) + #[cfg(target_arch = "x86_64")] + { + // SAFETY: tier() verified AVX2+FMA. + unsafe { nrm2_f64_avx2(x) } + } + #[cfg(not(target_arch = "x86_64"))] + { + super::scalar::nrm2_f64(x) + } } pub fn asum_f32(x: &[f32]) -> f32 { - super::scalar::asum_f32(x) + #[cfg(target_arch = "x86_64")] + { + // SAFETY: tier() verified AVX2. + unsafe { asum_f32_avx2(x) } + } + #[cfg(not(target_arch = "x86_64"))] + { + super::scalar::asum_f32(x) + } } pub fn asum_f64(x: &[f64]) -> f64 { - super::scalar::asum_f64(x) + #[cfg(target_arch = "x86_64")] + { + // SAFETY: tier() verified AVX2. + unsafe { asum_f64_avx2(x) } + } + #[cfg(not(target_arch = "x86_64"))] + { + super::scalar::asum_f64(x) + } } // ── AVX2 intrinsic implementations ───────────────────────────── @@ -677,6 +724,201 @@ mod avx2 { i += 1; } } + + // ── scal: x[i] *= alpha ──────────────────────────────────────── + + #[cfg(target_arch = "x86_64")] + #[target_feature(enable = "avx2")] + unsafe fn scal_f32_avx2(alpha: f32, x: &mut [f32]) { + use core::arch::x86_64::*; + let n = x.len(); + let valpha = _mm256_set1_ps(alpha); + let mut i = 0; + while i + 8 <= n { + let v = _mm256_loadu_ps(x.as_ptr().add(i)); + _mm256_storeu_ps(x.as_mut_ptr().add(i), _mm256_mul_ps(v, valpha)); + i += 8; + } + while i < n { + x[i] *= alpha; + i += 1; + } + } + + #[cfg(target_arch = "x86_64")] + #[target_feature(enable = "avx2")] + unsafe fn scal_f64_avx2(alpha: f64, x: &mut [f64]) { + use core::arch::x86_64::*; + let n = x.len(); + let valpha = _mm256_set1_pd(alpha); + let mut i = 0; + while i + 4 <= n { + let v = _mm256_loadu_pd(x.as_ptr().add(i)); + _mm256_storeu_pd(x.as_mut_ptr().add(i), _mm256_mul_pd(v, valpha)); + i += 4; + } + while i < n { + x[i] *= alpha; + i += 1; + } + } + + // ── nrm2: sqrt(Σ x[i]²) ──────────────────────────────────────── + // + // Two-accumulator unroll + FMA for the squared sum, scalar sqrt at + // the end. SIMD horizontal reduce ordering differs from the strict + // left-fold the scalar reference uses, so the ULP error can drift + // by 1-2 ULP on long vectors — same tolerance the existing + // `dot_f32_avx2` carries, accepted in BLAS-1. + + #[cfg(target_arch = "x86_64")] + #[target_feature(enable = "avx2,fma")] + unsafe fn nrm2_f32_avx2(x: &[f32]) -> f32 { + use core::arch::x86_64::*; + let n = x.len(); + let mut acc0 = _mm256_setzero_ps(); + let mut acc1 = _mm256_setzero_ps(); + let mut i = 0; + while i + 16 <= n { + let v0 = _mm256_loadu_ps(x.as_ptr().add(i)); + let v1 = _mm256_loadu_ps(x.as_ptr().add(i + 8)); + acc0 = _mm256_fmadd_ps(v0, v0, acc0); + acc1 = _mm256_fmadd_ps(v1, v1, acc1); + i += 16; + } + while i + 8 <= n { + let v = _mm256_loadu_ps(x.as_ptr().add(i)); + acc0 = _mm256_fmadd_ps(v, v, acc0); + i += 8; + } + acc0 = _mm256_add_ps(acc0, acc1); + let hi = _mm256_extractf128_ps(acc0, 1); + let lo = _mm256_castps256_ps128(acc0); + let sum128 = _mm_add_ps(lo, hi); + let shuf = _mm_movehdup_ps(sum128); + let sums = _mm_add_ps(sum128, shuf); + let shuf2 = _mm_movehl_ps(sums, sums); + let result = _mm_add_ss(sums, shuf2); + let mut total = _mm_cvtss_f32(result); + while i < n { + total += x[i] * x[i]; + i += 1; + } + total.sqrt() + } + + #[cfg(target_arch = "x86_64")] + #[target_feature(enable = "avx2,fma")] + unsafe fn nrm2_f64_avx2(x: &[f64]) -> f64 { + use core::arch::x86_64::*; + let n = x.len(); + let mut acc0 = _mm256_setzero_pd(); + let mut acc1 = _mm256_setzero_pd(); + let mut i = 0; + while i + 8 <= n { + let v0 = _mm256_loadu_pd(x.as_ptr().add(i)); + let v1 = _mm256_loadu_pd(x.as_ptr().add(i + 4)); + acc0 = _mm256_fmadd_pd(v0, v0, acc0); + acc1 = _mm256_fmadd_pd(v1, v1, acc1); + i += 8; + } + while i + 4 <= n { + let v = _mm256_loadu_pd(x.as_ptr().add(i)); + acc0 = _mm256_fmadd_pd(v, v, acc0); + i += 4; + } + acc0 = _mm256_add_pd(acc0, acc1); + let hi = _mm256_extractf128_pd(acc0, 1); + let lo = _mm256_castpd256_pd128(acc0); + let sum128 = _mm_add_pd(lo, hi); + let shuf = _mm_unpackhi_pd(sum128, sum128); + let result = _mm_add_sd(sum128, shuf); + let mut total = _mm_cvtsd_f64(result); + while i < n { + total += x[i] * x[i]; + i += 1; + } + total.sqrt() + } + + // ── asum: Σ |x[i]| ───────────────────────────────────────────── + // + // Abs via AND with sign-bit-cleared mask (one AVX instruction — + // VANDPS), horizontal sum at the end. Same ordering caveat as + // nrm2. + + #[cfg(target_arch = "x86_64")] + #[target_feature(enable = "avx2")] + unsafe fn asum_f32_avx2(x: &[f32]) -> f32 { + use core::arch::x86_64::*; + let n = x.len(); + // Sign-bit-cleared mask: 0x7FFFFFFF in every lane. + let abs_mask = _mm256_castsi256_ps(_mm256_set1_epi32(0x7FFF_FFFFi32)); + let mut acc0 = _mm256_setzero_ps(); + let mut acc1 = _mm256_setzero_ps(); + let mut i = 0; + while i + 16 <= n { + let v0 = _mm256_and_ps(_mm256_loadu_ps(x.as_ptr().add(i)), abs_mask); + let v1 = _mm256_and_ps(_mm256_loadu_ps(x.as_ptr().add(i + 8)), abs_mask); + acc0 = _mm256_add_ps(acc0, v0); + acc1 = _mm256_add_ps(acc1, v1); + i += 16; + } + while i + 8 <= n { + let v = _mm256_and_ps(_mm256_loadu_ps(x.as_ptr().add(i)), abs_mask); + acc0 = _mm256_add_ps(acc0, v); + i += 8; + } + acc0 = _mm256_add_ps(acc0, acc1); + let hi = _mm256_extractf128_ps(acc0, 1); + let lo = _mm256_castps256_ps128(acc0); + let sum128 = _mm_add_ps(lo, hi); + let shuf = _mm_movehdup_ps(sum128); + let sums = _mm_add_ps(sum128, shuf); + let shuf2 = _mm_movehl_ps(sums, sums); + let result = _mm_add_ss(sums, shuf2); + let mut total = _mm_cvtss_f32(result); + while i < n { + total += x[i].abs(); + i += 1; + } + total + } + + #[cfg(target_arch = "x86_64")] + #[target_feature(enable = "avx2")] + unsafe fn asum_f64_avx2(x: &[f64]) -> f64 { + use core::arch::x86_64::*; + let n = x.len(); + let abs_mask = _mm256_castsi256_pd(_mm256_set1_epi64x(0x7FFF_FFFF_FFFF_FFFFi64)); + let mut acc0 = _mm256_setzero_pd(); + let mut acc1 = _mm256_setzero_pd(); + let mut i = 0; + while i + 8 <= n { + let v0 = _mm256_and_pd(_mm256_loadu_pd(x.as_ptr().add(i)), abs_mask); + let v1 = _mm256_and_pd(_mm256_loadu_pd(x.as_ptr().add(i + 4)), abs_mask); + acc0 = _mm256_add_pd(acc0, v0); + acc1 = _mm256_add_pd(acc1, v1); + i += 8; + } + while i + 4 <= n { + let v = _mm256_and_pd(_mm256_loadu_pd(x.as_ptr().add(i)), abs_mask); + acc0 = _mm256_add_pd(acc0, v); + i += 4; + } + acc0 = _mm256_add_pd(acc0, acc1); + let hi = _mm256_extractf128_pd(acc0, 1); + let lo = _mm256_castpd256_pd128(acc0); + let sum128 = _mm_add_pd(lo, hi); + let shuf = _mm_unpackhi_pd(sum128, sum128); + let result = _mm_add_sd(sum128, shuf); + let mut total = _mm_cvtsd_f64(result); + while i < n { + total += x[i].abs(); + i += 1; + } + total + } } // ═══════════════════════════════════════════════════════════════════ @@ -760,4 +1002,63 @@ mod tests { // Should be one of the valid tier values assert!(nr == 4 || nr == 8 || nr == 16); } + + // ── TD-T6: parity sweep for the new AVX2 BLAS-1 kernels ──────── + // + // The shim → real-intrinsic switch flipped scal/nrm2/asum from + // scalar-fallthrough to AVX2 chunked + scalar-tail kernels. Each + // new kernel: verify byte-equal (or ULP-tight for nrm2 which + // includes a sqrt and a different sum order) against the scalar + // reference across shapes that exercise the chunk-of-16, chunk- + // of-8, and scalar-tail code paths. + + fn ref_scal(alpha: f32, x: &[f32]) -> Vec { + x.iter().map(|&v| v * alpha).collect() + } + fn ref_nrm2(x: &[f32]) -> f32 { + x.iter().map(|&v| v * v).sum::().sqrt() + } + fn ref_asum(x: &[f32]) -> f32 { + x.iter().map(|&v| v.abs()).sum() + } + + #[test] + fn td_t6_scal_f32_parity() { + for n in [0usize, 1, 7, 8, 9, 15, 16, 17, 31, 32, 64, 100] { + let alpha = 1.5f32; + let init: Vec = (0..n).map(|i| (i as f32 * 0.5) - 1.0).collect(); + let expected = ref_scal(alpha, &init); + let mut got = init.clone(); + scal_f32(alpha, &mut got); + for (i, (g, e)) in got.iter().zip(expected.iter()).enumerate() { + assert_eq!(g.to_bits(), e.to_bits(), "scal_f32 n={n} i={i}: got {g} want {e}"); + } + } + } + + #[test] + fn td_t6_nrm2_f32_parity() { + for n in [0usize, 1, 7, 8, 9, 15, 16, 17, 31, 32, 64, 100] { + let x: Vec = (0..n).map(|i| (i as f32 * 0.3) - 0.5).collect(); + let expected = ref_nrm2(&x); + let got = nrm2_f32(&x); + // ULP tolerance because SIMD reduce order differs from + // strict left-fold; nrm2 also includes the final sqrt. + let abs_err = (got - expected).abs(); + let rel_tol = expected.abs() * 1e-5 + 1e-6; + assert!(abs_err <= rel_tol, "nrm2_f32 n={n}: got {got} want {expected} (err {abs_err})"); + } + } + + #[test] + fn td_t6_asum_f32_parity() { + for n in [0usize, 1, 7, 8, 9, 15, 16, 17, 31, 32, 64, 100] { + let x: Vec = (0..n).map(|i| (i as f32 * 0.3) - 0.5).collect(); + let expected = ref_asum(&x); + let got = asum_f32(&x); + let abs_err = (got - expected).abs(); + let rel_tol = expected.abs() * 1e-5 + 1e-6; + assert!(abs_err <= rel_tol, "asum_f32 n={n}: got {got} want {expected} (err {abs_err})"); + } + } }