Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
315 changes: 308 additions & 7 deletions src/backend/native.rs
Original file line number Diff line number Diff line change
Expand Up @@ -540,24 +540,71 @@ mod avx2 {
}
}

// No AVX2 specialization — fall through to scalar
pub fn scal_f32(alpha: f32, x: &mut [f32]) {
super::scalar::scal_f32(alpha, x);
#[cfg(target_arch = "x86_64")]
{
// SAFETY: tier() already verified AVX2 support before calling.
unsafe { scal_f32_avx2(alpha, x) }
}
#[cfg(not(target_arch = "x86_64"))]
{
super::scalar::scal_f32(alpha, x);
}
}
pub fn scal_f64(alpha: f64, x: &mut [f64]) {
super::scalar::scal_f64(alpha, x);
#[cfg(target_arch = "x86_64")]
{
// SAFETY: tier() already verified AVX2 support before calling.
unsafe { scal_f64_avx2(alpha, x) }
}
#[cfg(not(target_arch = "x86_64"))]
{
super::scalar::scal_f64(alpha, x);
}
}
pub fn nrm2_f32(x: &[f32]) -> f32 {
super::scalar::nrm2_f32(x)
#[cfg(target_arch = "x86_64")]
{
// SAFETY: tier() verified AVX2+FMA.
unsafe { nrm2_f32_avx2(x) }
}
#[cfg(not(target_arch = "x86_64"))]
{
super::scalar::nrm2_f32(x)
}
}
pub fn nrm2_f64(x: &[f64]) -> f64 {
super::scalar::nrm2_f64(x)
#[cfg(target_arch = "x86_64")]
{
// SAFETY: tier() verified AVX2+FMA.
unsafe { nrm2_f64_avx2(x) }
}
#[cfg(not(target_arch = "x86_64"))]
{
super::scalar::nrm2_f64(x)
}
}
pub fn asum_f32(x: &[f32]) -> f32 {
super::scalar::asum_f32(x)
#[cfg(target_arch = "x86_64")]
{
// SAFETY: tier() verified AVX2.
unsafe { asum_f32_avx2(x) }
}
#[cfg(not(target_arch = "x86_64"))]
{
super::scalar::asum_f32(x)
}
}
pub fn asum_f64(x: &[f64]) -> f64 {
super::scalar::asum_f64(x)
#[cfg(target_arch = "x86_64")]
{
// SAFETY: tier() verified AVX2.
unsafe { asum_f64_avx2(x) }
}
#[cfg(not(target_arch = "x86_64"))]
{
super::scalar::asum_f64(x)
}
}

// ── AVX2 intrinsic implementations ─────────────────────────────
Expand Down Expand Up @@ -677,6 +724,201 @@ mod avx2 {
i += 1;
}
}

// ── scal: x[i] *= alpha ────────────────────────────────────────

#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn scal_f32_avx2(alpha: f32, x: &mut [f32]) {
use core::arch::x86_64::*;
let n = x.len();
let valpha = _mm256_set1_ps(alpha);
let mut i = 0;
while i + 8 <= n {
let v = _mm256_loadu_ps(x.as_ptr().add(i));
_mm256_storeu_ps(x.as_mut_ptr().add(i), _mm256_mul_ps(v, valpha));
i += 8;
}
while i < n {
x[i] *= alpha;
i += 1;
}
}

#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn scal_f64_avx2(alpha: f64, x: &mut [f64]) {
use core::arch::x86_64::*;
let n = x.len();
let valpha = _mm256_set1_pd(alpha);
let mut i = 0;
while i + 4 <= n {
let v = _mm256_loadu_pd(x.as_ptr().add(i));
_mm256_storeu_pd(x.as_mut_ptr().add(i), _mm256_mul_pd(v, valpha));
i += 4;
}
while i < n {
x[i] *= alpha;
i += 1;
}
}

// ── nrm2: sqrt(Σ x[i]²) ────────────────────────────────────────
//
// Two-accumulator unroll + FMA for the squared sum, scalar sqrt at
// the end. SIMD horizontal reduce ordering differs from the strict
// left-fold the scalar reference uses, so the ULP error can drift
// by 1-2 ULP on long vectors — same tolerance the existing
// `dot_f32_avx2` carries, accepted in BLAS-1.

#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2,fma")]
unsafe fn nrm2_f32_avx2(x: &[f32]) -> f32 {
use core::arch::x86_64::*;
let n = x.len();
let mut acc0 = _mm256_setzero_ps();
let mut acc1 = _mm256_setzero_ps();
let mut i = 0;
while i + 16 <= n {
let v0 = _mm256_loadu_ps(x.as_ptr().add(i));
let v1 = _mm256_loadu_ps(x.as_ptr().add(i + 8));
acc0 = _mm256_fmadd_ps(v0, v0, acc0);
acc1 = _mm256_fmadd_ps(v1, v1, acc1);
i += 16;
}
while i + 8 <= n {
let v = _mm256_loadu_ps(x.as_ptr().add(i));
acc0 = _mm256_fmadd_ps(v, v, acc0);
i += 8;
}
acc0 = _mm256_add_ps(acc0, acc1);
let hi = _mm256_extractf128_ps(acc0, 1);
let lo = _mm256_castps256_ps128(acc0);
let sum128 = _mm_add_ps(lo, hi);
let shuf = _mm_movehdup_ps(sum128);
let sums = _mm_add_ps(sum128, shuf);
let shuf2 = _mm_movehl_ps(sums, sums);
let result = _mm_add_ss(sums, shuf2);
let mut total = _mm_cvtss_f32(result);
while i < n {
total += x[i] * x[i];
i += 1;
}
total.sqrt()
}

#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2,fma")]
unsafe fn nrm2_f64_avx2(x: &[f64]) -> f64 {
use core::arch::x86_64::*;
let n = x.len();
let mut acc0 = _mm256_setzero_pd();
let mut acc1 = _mm256_setzero_pd();
let mut i = 0;
while i + 8 <= n {
let v0 = _mm256_loadu_pd(x.as_ptr().add(i));
let v1 = _mm256_loadu_pd(x.as_ptr().add(i + 4));
acc0 = _mm256_fmadd_pd(v0, v0, acc0);
acc1 = _mm256_fmadd_pd(v1, v1, acc1);
i += 8;
}
while i + 4 <= n {
let v = _mm256_loadu_pd(x.as_ptr().add(i));
acc0 = _mm256_fmadd_pd(v, v, acc0);
i += 4;
}
acc0 = _mm256_add_pd(acc0, acc1);
let hi = _mm256_extractf128_pd(acc0, 1);
let lo = _mm256_castpd256_pd128(acc0);
let sum128 = _mm_add_pd(lo, hi);
let shuf = _mm_unpackhi_pd(sum128, sum128);
let result = _mm_add_sd(sum128, shuf);
let mut total = _mm_cvtsd_f64(result);
while i < n {
total += x[i] * x[i];
i += 1;
}
total.sqrt()
}

// ── asum: Σ |x[i]| ─────────────────────────────────────────────
//
// Abs via AND with sign-bit-cleared mask (one AVX instruction —
// VANDPS), horizontal sum at the end. Same ordering caveat as
// nrm2.

#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn asum_f32_avx2(x: &[f32]) -> f32 {
use core::arch::x86_64::*;
let n = x.len();
// Sign-bit-cleared mask: 0x7FFFFFFF in every lane.
let abs_mask = _mm256_castsi256_ps(_mm256_set1_epi32(0x7FFF_FFFFi32));
let mut acc0 = _mm256_setzero_ps();
let mut acc1 = _mm256_setzero_ps();
let mut i = 0;
while i + 16 <= n {
let v0 = _mm256_and_ps(_mm256_loadu_ps(x.as_ptr().add(i)), abs_mask);
let v1 = _mm256_and_ps(_mm256_loadu_ps(x.as_ptr().add(i + 8)), abs_mask);
acc0 = _mm256_add_ps(acc0, v0);
acc1 = _mm256_add_ps(acc1, v1);
i += 16;
}
while i + 8 <= n {
let v = _mm256_and_ps(_mm256_loadu_ps(x.as_ptr().add(i)), abs_mask);
acc0 = _mm256_add_ps(acc0, v);
i += 8;
}
acc0 = _mm256_add_ps(acc0, acc1);
let hi = _mm256_extractf128_ps(acc0, 1);
let lo = _mm256_castps256_ps128(acc0);
let sum128 = _mm_add_ps(lo, hi);
let shuf = _mm_movehdup_ps(sum128);
let sums = _mm_add_ps(sum128, shuf);
let shuf2 = _mm_movehl_ps(sums, sums);
let result = _mm_add_ss(sums, shuf2);
let mut total = _mm_cvtss_f32(result);
while i < n {
total += x[i].abs();
i += 1;
}
total
}

#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn asum_f64_avx2(x: &[f64]) -> f64 {
use core::arch::x86_64::*;
let n = x.len();
let abs_mask = _mm256_castsi256_pd(_mm256_set1_epi64x(0x7FFF_FFFF_FFFF_FFFFi64));
let mut acc0 = _mm256_setzero_pd();
let mut acc1 = _mm256_setzero_pd();
let mut i = 0;
while i + 8 <= n {
let v0 = _mm256_and_pd(_mm256_loadu_pd(x.as_ptr().add(i)), abs_mask);
let v1 = _mm256_and_pd(_mm256_loadu_pd(x.as_ptr().add(i + 4)), abs_mask);
acc0 = _mm256_add_pd(acc0, v0);
acc1 = _mm256_add_pd(acc1, v1);
i += 8;
}
while i + 4 <= n {
let v = _mm256_and_pd(_mm256_loadu_pd(x.as_ptr().add(i)), abs_mask);
acc0 = _mm256_add_pd(acc0, v);
i += 4;
}
acc0 = _mm256_add_pd(acc0, acc1);
let hi = _mm256_extractf128_pd(acc0, 1);
let lo = _mm256_castpd256_pd128(acc0);
let sum128 = _mm_add_pd(lo, hi);
let shuf = _mm_unpackhi_pd(sum128, sum128);
let result = _mm_add_sd(sum128, shuf);
let mut total = _mm_cvtsd_f64(result);
while i < n {
total += x[i].abs();
i += 1;
}
total
}
}

// ═══════════════════════════════════════════════════════════════════
Expand Down Expand Up @@ -760,4 +1002,63 @@ mod tests {
// Should be one of the valid tier values
assert!(nr == 4 || nr == 8 || nr == 16);
}

// ── TD-T6: parity sweep for the new AVX2 BLAS-1 kernels ────────
//
// The shim → real-intrinsic switch flipped scal/nrm2/asum from
// scalar-fallthrough to AVX2 chunked + scalar-tail kernels. Each
// new kernel: verify byte-equal (or ULP-tight for nrm2 which
// includes a sqrt and a different sum order) against the scalar
// reference across shapes that exercise the chunk-of-16, chunk-
// of-8, and scalar-tail code paths.

fn ref_scal(alpha: f32, x: &[f32]) -> Vec<f32> {
x.iter().map(|&v| v * alpha).collect()
}
fn ref_nrm2(x: &[f32]) -> f32 {
x.iter().map(|&v| v * v).sum::<f32>().sqrt()
}
fn ref_asum(x: &[f32]) -> f32 {
x.iter().map(|&v| v.abs()).sum()
}

#[test]
fn td_t6_scal_f32_parity() {
for n in [0usize, 1, 7, 8, 9, 15, 16, 17, 31, 32, 64, 100] {
let alpha = 1.5f32;
let init: Vec<f32> = (0..n).map(|i| (i as f32 * 0.5) - 1.0).collect();
let expected = ref_scal(alpha, &init);
let mut got = init.clone();
scal_f32(alpha, &mut got);
for (i, (g, e)) in got.iter().zip(expected.iter()).enumerate() {
assert_eq!(g.to_bits(), e.to_bits(), "scal_f32 n={n} i={i}: got {g} want {e}");
}
}
}

#[test]
fn td_t6_nrm2_f32_parity() {
for n in [0usize, 1, 7, 8, 9, 15, 16, 17, 31, 32, 64, 100] {
let x: Vec<f32> = (0..n).map(|i| (i as f32 * 0.3) - 0.5).collect();
let expected = ref_nrm2(&x);
let got = nrm2_f32(&x);
// ULP tolerance because SIMD reduce order differs from
// strict left-fold; nrm2 also includes the final sqrt.
let abs_err = (got - expected).abs();
let rel_tol = expected.abs() * 1e-5 + 1e-6;
assert!(abs_err <= rel_tol, "nrm2_f32 n={n}: got {got} want {expected} (err {abs_err})");
}
}

#[test]
fn td_t6_asum_f32_parity() {
for n in [0usize, 1, 7, 8, 9, 15, 16, 17, 31, 32, 64, 100] {
let x: Vec<f32> = (0..n).map(|i| (i as f32 * 0.3) - 0.5).collect();
let expected = ref_asum(&x);
let got = asum_f32(&x);
let abs_err = (got - expected).abs();
let rel_tol = expected.abs() * 1e-5 + 1e-6;
assert!(abs_err <= rel_tol, "asum_f32 n={n}: got {got} want {expected} (err {abs_err})");
}
}
}
Loading