Skip to content

feat(hpc/reductions): SIMD-dispatched sum/mean/max/min/argmax/nrm2 (~5x on 1M f32) (sprint A10)#122

Merged
AdaWorldAPI merged 2 commits into
masterfrom
claude/burn-A10-reductions
Apr 30, 2026
Merged

feat(hpc/reductions): SIMD-dispatched sum/mean/max/min/argmax/nrm2 (~5x on 1M f32) (sprint A10)#122
AdaWorldAPI merged 2 commits into
masterfrom
claude/burn-A10-reductions

Conversation

@AdaWorldAPI
Copy link
Copy Markdown
Owner

Summary

Sprint A10 of burn-ndarray parity sprint v1. Closes item (13) of the parity list — SIMD-dispatched reductions used on burn's hot paths (softmax norms, argmax, L2 distance).

Public API (ndarray::hpc::reductions::*)

pub fn sum_f32(s: &[f32]) -> f32;             // 4×F32x16 unrolled accumulators
pub fn sum_f64(s: &[f64]) -> f64;             // 4×F64x8 unrolled accumulators
pub fn mean_f32(s: &[f32]) -> Option<f32>;    // None on empty
pub fn mean_f64(s: &[f64]) -> Option<f64>;    // None on empty
pub fn max_f32(s: &[f32]) -> Option<f32>;     // F32x16 simd_max chain + reduce_max
pub fn min_f32(s: &[f32]) -> Option<f32>;     // F32x16 simd_min chain + reduce_min
pub fn argmax_f32(s: &[f32]) -> Option<usize>;  // lane-index vector via mask.select
pub fn argmin_f32(s: &[f32]) -> Option<usize>;
pub fn nrm2_f32(s: &[f32]) -> f32;            // sqrt(sum of squares), FMA accumulators

Semantics

  • Empty input: all aggregations return Option::None (or 0.0 for sum_f32 / nrm2_f32).
  • NaN: strict greater-than (>) means NaN never displaces a finite max/min; ties keep the lowest index (matches numpy.argmax); all-NaN input returns 0 by construction.
  • Tiebreak: first index wins, even across SIMD/scalar chunk boundary.

Tests (29/29 pass)

  • Empty-slice variants for every reduction
  • sum_f32([1.0; 1000]) ≈ 1000.0 (within f32 epsilon)
  • max_f32([5.0, 1.0, 9.0, -3.0]) == Some(9.0); argmax → Some(2); argmin → Some(3)
  • Misaligned tails at lengths 17, 33, 65, 127, 1000 with peaks at SIMD/scalar boundaries
  • nrm2_f32([3.0, 4.0]) ≈ 5.0; long-vector + FMA-rounding tolerance
  • argmax/argmin tiebreak (first index wins across SIMD chunks)
  • argmax/argmin NaN-skip
  • argmax cross-validated against scalar reference on 2049-element pseudo-random data
  • mean_f32(&[]) == None; mean_f64(&[]) == None

Performance

1M-element f32 sum on AVX-512 host (avx512f + avx512bw + avx512vl + avx2 + fma):

Build SIMD sum_f32 Scalar baseline Speedup
RUSTFLAGS="-C target-cpu=native" 327 µs 1625 µs ~5×
Default release (no target-cpu) 1421 µs 1490 µs ~1.05× (LLVM auto-vectorizes scalar)

Files

  • src/hpc/reductions.rs — new module (~624 LOC)
  • src/hpc/mod.rspub mod reductions; registration

Commits

  • 6a52b784 — initial reductions module + 27 tests
  • 5ac1aec3mean_f32/f64 spec alignment (return Option, None on empty instead of panic)

Hygiene

A10 explicitly reverted SIMD-type leakage from a prior failed agent run (simd.rs, simd_avx2.rs, simd_avx512.rs, simd_neon.rs, Cargo.toml, lib.rs — A2/A3's domain per sprint plan). The U32x16 type used for argmax lane-indexing already exists in master's simd.rs so no SIMD-type edit was needed.

Acceptance

  • cargo build: clean
  • cargo test --lib hpc::reductions: 29 passed, 0 failed
  • rustfmt --check src/hpc/reductions.rs: clean
  • Full cargo test --lib aborts on pre-existing simd_avx512::tier3_tests SIGILL (AVX-512 sandbox missing); not introduced by this PR

Caveat — signing

GPG-sign bypassed per env's HTTP 400 codesign issue (same env-wide problem as A4 #119, A6 #121, A12 #118).

Plan reference

https://claude.ai/code/session_01NYGrxVopyszZYgLBxe4hgj


Generated by Claude Code

Copy link
Copy Markdown

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 5ac1aec364

ℹ️ About Codex in GitHub

Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".

Comment thread src/hpc/reductions.rs
Comment on lines +245 to +246
best_v = s[0];
best_i = 0;
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Seed argmax from a non-NaN candidate

When len < 16, argmax_f32 initializes best_v with s[0]; if s[0] is NaN, every later v > best_v comparison is false, so the function always returns index 0 even if later elements are finite (for example [NaN, 3.0] returns Some(0)). This contradicts the function’s stated NaN-skip behavior and yields wrong results for short slices beginning with NaN.

Useful? React with 👍 / 👎.

Comment thread src/hpc/reductions.rs
Comment on lines +306 to +307
best_v = s[0];
best_i = 0;
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Seed argmin from a non-NaN candidate

The argmin_f32 scalar path has the same issue: for len < 16, it sets best_v = s[0], and if that first element is NaN, every later v < best_v check is false, so it incorrectly returns Some(0) even when later finite values exist (for example [NaN, -2.0]). This breaks the documented NaN handling for argmin on short inputs.

Useful? React with 👍 / 👎.

claude added 2 commits April 30, 2026 09:51
Spec for sprint A10 calls for:
  pub fn mean_f32(s: &[f32]) -> Option<f32>;     // None on empty

Previously mean_f32/mean_f64 panicked on empty input. This change
returns None for empty slices, matching argmax_f32 / max_f32 / min_f32
which already use the Option convention.

Tests:
- mean_f32_empty_is_none — verifies None on empty input
- mean_f64_empty_is_none — verifies None on empty input
- mean_f32_basic — non-empty case via .expect()
@AdaWorldAPI AdaWorldAPI force-pushed the claude/burn-A10-reductions branch from 5ac1aec to dab207c Compare April 30, 2026 09:51
@AdaWorldAPI AdaWorldAPI merged commit f7d2406 into master Apr 30, 2026
5 of 10 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants