diff --git a/AGENTS.md b/AGENTS.md index a42d0af..bd3bf1f 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -111,6 +111,9 @@ invariant over the convenient edit. `git --no-pager log`, `git --no-pager show`, `git --no-pager blame`) to inspect changes/history - **ALWAYS** use `git --no-pager` when reading git output - Suggest git commands that modify version control state for the user to run manually +- When suggesting branch names, prefer `{type}/{issue}-descriptor-or-two`, e.g. `fix/307-topology-validation`, + `perf/315-bench-profile`, or `doc/329-branch-guidance`. If an environment requires an owner/tool prefix, + keep this structure after the prefix, e.g. `codex/fix/307-topology-validation`. ### Commit Messages @@ -240,6 +243,9 @@ just examples # Run all examples - Python setup: `uv sync --group dev` (or `just python-sync`) - Python tests: `just test-python` - Run a single test (by name filter): `cargo test solve_2x2_basic` (or the full path: `cargo test lu::tests::solve_2x2_basic`) + Cargo accepts only one positional test filter. To run multiple focused + filters, run separate `cargo test ` commands rather than passing + multiple filter arguments. - Run exact-feature tests: `cargo test --features exact --verbose` (or `just test-exact`) - Run examples: `just examples` (or `cargo run --example det_5x5` / `cargo run --example solve_5x5` / `cargo run --example ldlt_solve_3x3` / `cargo run --example const_det_4x4` / diff --git a/CHANGELOG.md b/CHANGELOG.md index 2e63c36..31b845a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +### Added + +- Feat!(matrix): enforce fallible matrix invariants [`e26c283`](https://github.com/acgetchell/la-stack/commit/e26c28358b2358100353b2895441b68892e92cd7) + ### Changed - Remove redundant cache restore keys for cargo-llvm-cov [`f75a01c`](https://github.com/acgetchell/la-stack/commit/f75a01c99c8dbcc8b6ffc36ae9f94ba968a2f111) @@ -37,6 +41,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Documentation - Document feature requirement for exact APIs [`19b10d5`](https://github.com/acgetchell/la-stack/commit/19b10d552e83b6a7f9e91695b4850b8fab3f4550) +- Document scalar scope and release roadmap [`bfb0393`](https://github.com/acgetchell/la-stack/commit/bfb039386588f94b95561c610181ca6d486acd6e) + + - Clarify that la-stack intentionally supports f64 floating-point APIs plus optional exact rationals, not alternate scalar families. + - Add a roadmap covering the v0.4.x stable-Rust issue sequence and the v0.5.0 generic_const_exprs anchor. + - Refresh generated changelog entries and archived changelog grouping. ### Maintenance diff --git a/Cargo.toml b/Cargo.toml index 3ff1452..dd725b2 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -28,8 +28,8 @@ proptest = "1.11.0" [features] default = [ ] -bench = [ "criterion", "faer", "nalgebra" ] -exact = [ "num-bigint", "num-rational", "num-traits" ] +bench = [ "dep:criterion", "dep:faer", "dep:nalgebra" ] +exact = [ "dep:num-bigint", "dep:num-rational", "dep:num-traits" ] [[example]] name = "exact_det_3x3" diff --git a/README.md b/README.md index b89c732..e266d8d 100644 --- a/README.md +++ b/README.md @@ -76,25 +76,29 @@ Solve a 5×5 system via LU: ```rust use la_stack::prelude::*; -// This system requires pivoting (a[0][0] = 0), so it's a good LU demo. -// A = J - I: zeros on diagonal, ones elsewhere. -let a = Matrix::<5>::from_rows([ - [0.0, 1.0, 1.0, 1.0, 1.0], - [1.0, 0.0, 1.0, 1.0, 1.0], - [1.0, 1.0, 0.0, 1.0, 1.0], - [1.0, 1.0, 1.0, 0.0, 1.0], - [1.0, 1.0, 1.0, 1.0, 0.0], -]); - -let b = Vector::<5>::new([14.0, 13.0, 12.0, 11.0, 10.0]); - -let lu = a.lu(DEFAULT_PIVOT_TOL).unwrap(); -let x = lu.solve_vec(b).unwrap().into_array(); - -// Floating-point rounding is expected; compare with a tolerance. -let expected = [1.0, 2.0, 3.0, 4.0, 5.0]; -for (x_i, e_i) in x.iter().zip(expected.iter()) { - assert!((*x_i - *e_i).abs() <= 1e-12); +fn main() -> Result<(), LaError> { + // This system requires pivoting (a[0][0] = 0), so it's a good LU demo. + // A = J - I: zeros on diagonal, ones elsewhere. + let a = Matrix::<5>::from_rows([ + [0.0, 1.0, 1.0, 1.0, 1.0], + [1.0, 0.0, 1.0, 1.0, 1.0], + [1.0, 1.0, 0.0, 1.0, 1.0], + [1.0, 1.0, 1.0, 0.0, 1.0], + [1.0, 1.0, 1.0, 1.0, 0.0], + ]); + + let b = Vector::<5>::new([14.0, 13.0, 12.0, 11.0, 10.0]); + + let lu = a.lu(DEFAULT_PIVOT_TOL)?; + let x = lu.solve_vec(b)?.into_array(); + + // Floating-point rounding is expected; compare with a tolerance. + let expected = [1.0, 2.0, 3.0, 4.0, 5.0]; + for (x_i, e_i) in x.iter().zip(expected.iter()) { + assert!((*x_i - *e_i).abs() <= 1e-12); + } + + Ok(()) } ``` @@ -106,17 +110,21 @@ For symmetric positive-definite matrices, `LDL^T` is essentially a square-root-f ```rust use la_stack::prelude::*; -// This matrix is symmetric positive-definite (A = L*L^T) so LDLT works without pivoting. -let a = Matrix::<5>::from_rows([ - [1.0, 1.0, 0.0, 0.0, 0.0], - [1.0, 2.0, 1.0, 0.0, 0.0], - [0.0, 1.0, 2.0, 1.0, 0.0], - [0.0, 0.0, 1.0, 2.0, 1.0], - [0.0, 0.0, 0.0, 1.0, 2.0], -]); - -let det = a.ldlt(DEFAULT_SINGULAR_TOL).unwrap().det(); -assert!((det - 1.0).abs() <= 1e-12); +fn main() -> Result<(), LaError> { + // This matrix is symmetric positive-definite (A = L*L^T) so LDLT works without pivoting. + let a = Matrix::<5>::from_rows([ + [1.0, 1.0, 0.0, 0.0, 0.0], + [1.0, 2.0, 1.0, 0.0, 0.0], + [0.0, 1.0, 2.0, 1.0, 0.0], + [0.0, 0.0, 1.0, 2.0, 1.0], + [0.0, 0.0, 0.0, 1.0, 2.0], + ]); + + let det = a.ldlt(DEFAULT_SINGULAR_TOL)?.det()?; + assert!((det - 1.0).abs() <= 1e-12); + + Ok(()) +} ``` > ⚠️ **LDLT invariant:** The input matrix must be **symmetric**. Asymmetric @@ -133,7 +141,7 @@ assert!((det - 1.0).abs() <= 1e-12); `det_direct()` is a `const fn` providing closed-form determinants for D=0–4, using fused multiply-add where applicable. `Matrix::<0>::zero().det_direct()` -returns `Some(1.0)` (the empty-product convention). For D=1–4, cofactor +returns `Ok(Some(1.0))` (the empty-product convention). For D=1–4, cofactor expansion bypasses LU factorization entirely. This enables compile-time evaluation when inputs are known: @@ -141,7 +149,7 @@ evaluation when inputs are known: use la_stack::prelude::*; // Evaluated entirely at compile time — no runtime cost. -const DET: Option = { +const DET: Result, LaError> = { let m = Matrix::<3>::from_rows([ [2.0, 0.0, 0.0], [0.0, 3.0, 0.0], @@ -149,7 +157,7 @@ const DET: Option = { ]); m.det_direct() }; -assert_eq!(DET, Some(30.0)); +assert_eq!(DET, Ok(Some(30.0))); ``` The public `det()` method automatically dispatches through the closed-form path @@ -181,23 +189,27 @@ la-stack = { version = "0.4.1", features = ["exact"] } ```rust,ignore use la_stack::prelude::*; -// Exact determinant -let m = Matrix::<3>::from_rows([ - [1.0, 2.0, 3.0], - [4.0, 5.0, 6.0], - [7.0, 8.0, 9.0], -]); -assert_eq!(m.det_sign_exact().unwrap(), 0); // exactly singular - -let det = m.det_exact().unwrap(); -assert_eq!(det, BigRational::from_integer(0.into())); // exact zero - -// Exact linear system solve -let a = Matrix::<2>::from_rows([[1.0, 2.0], [3.0, 4.0]]); -let b = Vector::<2>::new([5.0, 11.0]); -let x = a.solve_exact_f64(b).unwrap().into_array(); -assert!((x[0] - 1.0).abs() <= f64::EPSILON); -assert!((x[1] - 2.0).abs() <= f64::EPSILON); +fn main() -> Result<(), LaError> { + // Exact determinant + let m = Matrix::<3>::from_rows([ + [1.0, 2.0, 3.0], + [4.0, 5.0, 6.0], + [7.0, 8.0, 9.0], + ]); + assert_eq!(m.det_sign_exact()?, 0); // exactly singular + + let det = m.det_exact()?; + assert_eq!(det, BigRational::from_integer(0.into())); // exact zero + + // Exact linear system solve + let a = Matrix::<2>::from_rows([[1.0, 2.0], [3.0, 4.0]]); + let b = Vector::<2>::new([5.0, 11.0]); + let x = a.solve_exact_f64(b)?.into_array(); + assert!((x[0] - 1.0).abs() <= f64::EPSILON); + assert!((x[1] - 2.0).abs() <= f64::EPSILON); + + Ok(()) +} ``` With the `exact` feature enabled, `BigInt` and `BigRational` are re-exported @@ -222,19 +234,24 @@ adaptive-precision logic for geometric predicates: ```rust,ignore use la_stack::prelude::*; -let m = Matrix::<3>::identity(); -if let Some(bound) = m.det_errbound() { - let det = m.det_direct().unwrap(); - if det.abs() > bound { - // f64 sign is guaranteed correct - let sign = det.signum() as i8; +fn main() -> Result<(), LaError> { + let m = Matrix::<3>::identity(); + if let Some(bound) = m.det_errbound()? { + if let Some(det) = m.det_direct()? { + if det.abs() > bound { + // f64 sign is guaranteed correct + let sign = det.signum() as i8; + } else { + // Fall back to exact arithmetic (requires `exact` feature) + let sign = m.det_sign_exact()?; + } + } } else { - // Fall back to exact arithmetic (requires `exact` feature) - let sign = m.det_sign_exact().unwrap(); + // D ≥ 5: no fast filter, use exact directly (requires `exact` feature) + let sign = m.det_sign_exact()?; } -} else { - // D ≥ 5: no fast filter, use exact directly (requires `exact` feature) - let sign = m.det_sign_exact().unwrap(); + + Ok(()) } ``` diff --git a/benches/exact.rs b/benches/exact.rs index 228cb45..b7d60f1 100644 --- a/benches/exact.rs +++ b/benches/exact.rs @@ -17,7 +17,7 @@ //! empirical evidence for `docs/PERFORMANCE.md`. use criterion::{BenchmarkGroup, Criterion, measurement::WallTime}; -use la_stack::{Matrix, Vector}; +use la_stack::{DEFAULT_PIVOT_TOL, Matrix, Vector}; use pastey::paste; use std::hint::black_box; @@ -179,7 +179,7 @@ macro_rules! gen_exact_benches_for_dim { [].bench_function("det", |bencher| { bencher.iter(|| { let det = black_box(a) - .det(la_stack::DEFAULT_PIVOT_TOL) + .det(DEFAULT_PIVOT_TOL) .expect("diagonally dominant matrix is non-singular"); black_box(det); }); diff --git a/benches/vs_linalg.rs b/benches/vs_linalg.rs index ffbd806..9365b79 100644 --- a/benches/vs_linalg.rs +++ b/benches/vs_linalg.rs @@ -8,8 +8,9 @@ //! - Matrix infinity norm is the maximum absolute row sum on all sides. use criterion::Criterion; -use faer::linalg::solvers::Solve; +use faer::linalg::solvers::{PartialPivLu, Solve}; use faer::perm::PermRef; +use la_stack::{DEFAULT_PIVOT_TOL, Matrix, Vector}; use pastey::paste; use std::hint::black_box; @@ -43,7 +44,7 @@ fn faer_perm_sign(p: PermRef<'_, usize>) -> f64 { } } -fn faer_det_from_partial_piv_lu(lu: &faer::linalg::solvers::PartialPivLu) -> f64 { +fn faer_det_from_partial_piv_lu(lu: &PartialPivLu) -> f64 { // For PA = LU with unit-lower L, det(A) = det(P) * det(U). let u = lu.U(); let mut det = 1.0; @@ -128,10 +129,10 @@ macro_rules! gen_vs_linalg_benches_for_dim { paste! {{ // Isolate each dimension's inputs to keep types and captures clean. { - let a = la_stack::Matrix::<$d>::from_rows(make_matrix_rows::<$d>()); - let rhs = la_stack::Vector::<$d>::new(make_vector_array::<$d>(0.0)); - let v1 = la_stack::Vector::<$d>::new(make_vector_array::<$d>(0.0)); - let v2 = la_stack::Vector::<$d>::new(make_vector_array::<$d>(1.0)); + let a = Matrix::<$d>::from_rows(make_matrix_rows::<$d>()); + let rhs = Vector::<$d>::new(make_vector_array::<$d>(0.0)); + let v1 = Vector::<$d>::new(make_vector_array::<$d>(0.0)); + let v2 = Vector::<$d>::new(make_vector_array::<$d>(1.0)); let na = nalgebra::SMatrix::::from_fn(|r, c| matrix_entry::<$d>(r, c)); let nrhs = nalgebra::SVector::::from_fn(|i, _| vector_entry(i, 0.0)); @@ -145,7 +146,7 @@ macro_rules! gen_vs_linalg_benches_for_dim { // Precompute LU once for solve-only / det-only benchmarks. let a_lu = a - .lu(la_stack::DEFAULT_PIVOT_TOL) + .lu(DEFAULT_PIVOT_TOL) .expect("matrix should be non-singular"); let na_lu = na.clone().lu(); let fa_lu = fa.partial_piv_lu(); @@ -156,9 +157,12 @@ macro_rules! gen_vs_linalg_benches_for_dim { [].bench_function("la_stack_det_via_lu", |bencher| { bencher.iter(|| { let lu = black_box(a) - .lu(la_stack::DEFAULT_PIVOT_TOL) + .lu(DEFAULT_PIVOT_TOL) .expect("matrix should be non-singular"); - let det = lu.det(); + let det = match lu.det() { + Ok(det) => det, + Err(err) => panic!("finite benchmark matrix determinant failed: {err}"), + }; black_box(det); }); }); @@ -183,7 +187,7 @@ macro_rules! gen_vs_linalg_benches_for_dim { [].bench_function("la_stack_det", |bencher| { bencher.iter(|| { let det = black_box(a) - .det(la_stack::DEFAULT_PIVOT_TOL) + .det(DEFAULT_PIVOT_TOL) .expect("matrix should be non-singular"); black_box(det); }); @@ -193,7 +197,7 @@ macro_rules! gen_vs_linalg_benches_for_dim { [].bench_function("la_stack_lu", |bencher| { bencher.iter(|| { let lu = black_box(a) - .lu(la_stack::DEFAULT_PIVOT_TOL) + .lu(DEFAULT_PIVOT_TOL) .expect("matrix should be non-singular"); let _ = black_box(lu); }); @@ -217,7 +221,7 @@ macro_rules! gen_vs_linalg_benches_for_dim { [].bench_function("la_stack_lu_solve", |bencher| { bencher.iter(|| { let lu = black_box(a) - .lu(la_stack::DEFAULT_PIVOT_TOL) + .lu(DEFAULT_PIVOT_TOL) .expect("matrix should be non-singular"); let x = lu .solve_vec(black_box(rhs)) @@ -273,7 +277,10 @@ macro_rules! gen_vs_linalg_benches_for_dim { // === Determinant from a precomputed LU === [].bench_function("la_stack_det_from_lu", |bencher| { bencher.iter(|| { - let det = a_lu.det(); + let det = match a_lu.det() { + Ok(det) => det, + Err(err) => panic!("finite benchmark matrix determinant failed: {err}"), + }; black_box(det); }); }); @@ -295,7 +302,7 @@ macro_rules! gen_vs_linalg_benches_for_dim { // === Vector dot product === [].bench_function("la_stack_dot", |bencher| { bencher.iter(|| { - let result = black_box(v1).dot(black_box(v2)); + let result = black_box(v1).dot(black_box(v2)).unwrap(); black_box(result); }); }); @@ -322,7 +329,7 @@ macro_rules! gen_vs_linalg_benches_for_dim { // === Vector norm squared === [].bench_function("la_stack_norm2_sq", |bencher| { bencher.iter(|| { - let result = black_box(v1).norm2_sq(); + let result = black_box(v1).norm2_sq().unwrap(); black_box(result); }); }); @@ -349,7 +356,7 @@ macro_rules! gen_vs_linalg_benches_for_dim { // === Matrix infinity norm (max absolute row sum) === [].bench_function("la_stack_inf_norm", |bencher| { bencher.iter(|| { - let result = black_box(a).inf_norm(); + let result = black_box(a).inf_norm().unwrap(); black_box(result); }); }); diff --git a/docs/roadmap.md b/docs/roadmap.md index c19ac5d..f98ec70 100644 --- a/docs/roadmap.md +++ b/docs/roadmap.md @@ -47,12 +47,17 @@ native Blocking / Is blocked by graph mirrors this order: error behavior across all dimensions as far as stable Rust allows. - [#120](https://github.com/acgetchell/la-stack/issues/120) - Run the parse-don't-validate `NonZero*` audit. +- [#126](https://github.com/acgetchell/la-stack/issues/126) - Add finite + `Matrix` and `Vector` proof types. +- [#125](https://github.com/acgetchell/la-stack/issues/125) - Add a Semgrep + guardrail against `unwrap` / `expect` in examples, benches, and doctests. - [#98](https://github.com/acgetchell/la-stack/issues/98) - Add random-input percentile benchmarks to the exact arithmetic suite. The broad shape is: document scope first, add downstream dispatch ergonomics, -clean up small API contracts, tighten validation, then finish with broader -invariant and benchmark work. +clean up small API contracts, tighten validation, make reusable invariants +explicit in proof-carrying types, lock the public examples and benchmarks into +proper error handling, then finish with broader benchmark work. ### v0.5.0 Generic Const Expressions diff --git a/examples/const_det_4x4.rs b/examples/const_det_4x4.rs index 4b3bb75..9348ab4 100644 --- a/examples/const_det_4x4.rs +++ b/examples/const_det_4x4.rs @@ -15,8 +15,9 @@ const MAT: Matrix<4> = Matrix::<4>::from_rows([ /// Determinant computed at compile time. const DET: f64 = match MAT.det_direct() { - Some(d) => d, - None => panic!("det_direct only supports D <= 4"), + Ok(Some(d)) => d, + Ok(None) => panic!("det_direct only supports D <= 4"), + Err(_) => panic!("matrix entries must be finite"), }; fn main() { diff --git a/examples/det_5x5.rs b/examples/det_5x5.rs index 899151d..5b82b78 100644 --- a/examples/det_5x5.rs +++ b/examples/det_5x5.rs @@ -15,7 +15,7 @@ fn main() -> Result<(), LaError> { // Compute via explicit LU factorization. let lu = a.lu(DEFAULT_PIVOT_TOL)?; - let det = lu.det(); + let det = lu.det()?; println!("det = {det}"); Ok(()) diff --git a/examples/exact_det_3x3.rs b/examples/exact_det_3x3.rs index 8ea1ab1..1bfc832 100644 --- a/examples/exact_det_3x3.rs +++ b/examples/exact_det_3x3.rs @@ -24,7 +24,7 @@ fn main() { [7.0, 8.0, 9.0], ]); - let det_f64_approx = m.det_direct().unwrap(); + let det_f64_approx = m.det_direct().unwrap().unwrap(); let det_exact = m.det_exact().unwrap(); let det_exact_as_f64 = m.det_exact_f64().unwrap(); diff --git a/examples/exact_solve_3x3.rs b/examples/exact_solve_3x3.rs index 045b059..88cc1ee 100644 --- a/examples/exact_solve_3x3.rs +++ b/examples/exact_solve_3x3.rs @@ -26,7 +26,12 @@ fn main() { // f64 LU solve (using zero pivot tolerance since the matrix is nearly singular // and would be rejected by DEFAULT_PIVOT_TOL). - let lu_x = a.lu(0.0).unwrap().solve_vec(b).unwrap().into_array(); + let lu_x = a + .lu(Tolerance::new(0.0).unwrap()) + .unwrap() + .solve_vec(b) + .unwrap() + .into_array(); // Exact solve. let exact_x = a.solve_exact(b).unwrap(); diff --git a/examples/ldlt_solve_3x3.rs b/examples/ldlt_solve_3x3.rs index 6fc5073..1976529 100644 --- a/examples/ldlt_solve_3x3.rs +++ b/examples/ldlt_solve_3x3.rs @@ -17,7 +17,7 @@ fn main() -> Result<(), LaError> { let ldlt = a.ldlt(DEFAULT_SINGULAR_TOL)?; let x = ldlt.solve_vec(b)?.into_array(); - let det = ldlt.det(); + let det = ldlt.det()?; println!("A (3×3 SPD tridiagonal):"); for r in 0..3 { diff --git a/justfile b/justfile index e1459e6..a284312 100644 --- a/justfile +++ b/justfile @@ -320,7 +320,6 @@ doc-check: examples: #!/usr/bin/env bash set -euo pipefail - cargo build --examples cargo build --features exact --examples exe_suffix="" @@ -328,7 +327,11 @@ examples: exe_suffix=".exe" fi - for example in det_5x5 solve_5x5 ldlt_solve_3x3 const_det_4x4 exact_det_3x3 exact_sign_3x3 exact_solve_3x3; do + shopt -s nullglob + for example_path in examples/*.rs; do + [[ -f "${example_path}" ]] || continue + example="${example_path##*/}" + example="${example%.rs}" "target/debug/examples/${example}${exe_suffix}" done diff --git a/src/exact.rs b/src/exact.rs index 1146467..ef7fdd2 100644 --- a/src/exact.rs +++ b/src/exact.rs @@ -489,10 +489,13 @@ impl Matrix { /// ``` /// use la_stack::prelude::*; /// + /// # fn main() -> Result<(), LaError> { /// let m = Matrix::<2>::from_rows([[1.0, 2.0], [3.0, 4.0]]); - /// let det = m.det_exact().unwrap(); + /// let det = m.det_exact()?; /// // det = 1*4 - 2*3 = -2 (exact) /// assert_eq!(det, BigRational::from_integer((-2).into())); + /// # Ok(()) + /// # } /// ``` /// /// # Errors @@ -515,9 +518,12 @@ impl Matrix { /// ``` /// use la_stack::prelude::*; /// + /// # fn main() -> Result<(), LaError> { /// let m = Matrix::<2>::from_rows([[1.0, 2.0], [3.0, 4.0]]); - /// let det = m.det_exact_f64().unwrap(); + /// let det = m.det_exact_f64()?; /// assert!((det - (-2.0)).abs() <= f64::EPSILON); + /// # Ok(()) + /// # } /// ``` /// /// # Errors @@ -570,12 +576,15 @@ impl Matrix { /// ``` /// use la_stack::prelude::*; /// + /// # fn main() -> Result<(), LaError> { /// // A x = b where A = [[1,2],[3,4]], b = [5, 11] → x = [1, 2] /// let a = Matrix::<2>::from_rows([[1.0, 2.0], [3.0, 4.0]]); /// let b = Vector::<2>::new([5.0, 11.0]); - /// let x = a.solve_exact(b).unwrap(); + /// let x = a.solve_exact(b)?; /// assert_eq!(x[0], BigRational::from_integer(1.into())); /// assert_eq!(x[1], BigRational::from_integer(2.into())); + /// # Ok(()) + /// # } /// ``` /// /// # Errors @@ -600,11 +609,14 @@ impl Matrix { /// ``` /// use la_stack::prelude::*; /// + /// # fn main() -> Result<(), LaError> { /// let a = Matrix::<2>::from_rows([[1.0, 2.0], [3.0, 4.0]]); /// let b = Vector::<2>::new([5.0, 11.0]); - /// let x = a.solve_exact_f64(b).unwrap().into_array(); + /// let x = a.solve_exact_f64(b)?.into_array(); /// assert!((x[0] - 1.0).abs() <= f64::EPSILON); /// assert!((x[1] - 2.0).abs() <= f64::EPSILON); + /// # Ok(()) + /// # } /// ``` /// /// # Errors @@ -661,9 +673,10 @@ impl Matrix { /// [7.0, 8.0, 9.0], /// ]); /// // This matrix is singular (row 3 = row 1 + row 2 in exact arithmetic). - /// assert_eq!(m.det_sign_exact().unwrap(), 0); + /// assert_eq!(m.det_sign_exact()?, 0); /// - /// assert_eq!(Matrix::<3>::identity().det_sign_exact().unwrap(), 1); + /// assert_eq!(Matrix::<3>::identity().det_sign_exact()?, 1); + /// # Ok::<(), LaError>(()) /// ``` /// /// # Errors @@ -672,18 +685,13 @@ impl Matrix { pub fn det_sign_exact(&self) -> Result { // Stage 1: f64 fast filter for D ≤ 4. // - // When entries are large (e.g. near f64::MAX) the determinant can - // overflow to infinity even though every individual entry is finite. - // In that case the fast filter is inconclusive; fall through to the - // exact Bareiss path. For NaN/±∞ entries IEEE 754 propagates - // non-finite through `det_direct()`, the `det_f64.is_finite()` - // guard fails, and we also fall through — validation then happens - // inside `bareiss_det_int` via `decompose_matrix`. - match self.det_direct() { - Some(det_f64) - if let Some(err) = self.det_errbound() - && det_f64.is_finite() => - { + // When entries are large (e.g. near f64::MAX), the f64 determinant or + // its error bound can overflow even though every individual entry is + // finite. In that case the fast filter is inconclusive; fall through to + // the exact Bareiss path. Stored NaN/±∞ entries are still rejected + // immediately with their source coordinates. + match (self.det_direct(), self.det_errbound()) { + (Ok(Some(det_f64)), Ok(Some(err))) => { if det_f64 > err { return Ok(1); } @@ -691,6 +699,11 @@ impl Matrix { return Ok(-1); } } + (Err(err @ LaError::NonFinite { row: Some(_), .. }), _) + | (_, Err(err @ LaError::NonFinite { row: Some(_), .. })) => return Err(err), + (Err(LaError::NonFinite { row: None, .. }), _) + | (_, Err(LaError::NonFinite { row: None, .. })) => {} + (Err(err), _) | (_, Err(err)) => return Err(err), _ => {} } @@ -715,6 +728,7 @@ mod tests { use super::*; use crate::DEFAULT_PIVOT_TOL; + use core::assert_matches; use num_traits::Signed; use pastey::paste; use std::array::from_fn; @@ -772,14 +786,14 @@ mod tests { #[test] fn []() { let mut m = Matrix::<$d>::identity(); - m.set(0, 0, f64::NAN); + assert_eq!(m.set(0, 0, f64::NAN), Some(())); assert_eq!(m.det_exact(), Err(LaError::NonFinite { row: Some(0), col: 0 })); } #[test] fn []() { let mut m = Matrix::<$d>::identity(); - m.set(0, 0, f64::INFINITY); + assert_eq!(m.set(0, 0, f64::INFINITY), Some(())); assert_eq!(m.det_exact(), Err(LaError::NonFinite { row: Some(0), col: 0 })); } } @@ -803,7 +817,7 @@ mod tests { #[test] fn []() { let mut m = Matrix::<$d>::identity(); - m.set(0, 0, f64::NAN); + assert_eq!(m.set(0, 0, f64::NAN), Some(())); assert_eq!(m.det_exact_f64(), Err(LaError::NonFinite { row: Some(0), col: 0 })); } } @@ -836,7 +850,7 @@ mod tests { } let m = Matrix::<$d>::from_rows(rows); let exact = m.det_exact_f64().unwrap(); - let direct = m.det_direct().unwrap(); + let direct = m.det_direct().unwrap().unwrap(); let eps = direct.abs().mul_add(1e-12, 1e-12); assert!((exact - direct).abs() <= eps); } @@ -1014,7 +1028,7 @@ mod tests { fn det_sign_exact_returns_err_on_nan_5x5() { // D ≥ 5 bypasses the fast filter, exercising the bareiss_det path. let mut m = Matrix::<5>::identity(); - m.set(2, 3, f64::NAN); + assert_eq!(m.set(2, 3, f64::NAN), Some(())); assert_eq!( m.det_sign_exact(), Err(LaError::NonFinite { @@ -1027,7 +1041,7 @@ mod tests { #[test] fn det_sign_exact_returns_err_on_infinity_5x5() { let mut m = Matrix::<5>::identity(); - m.set(0, 0, f64::INFINITY); + assert_eq!(m.set(0, 0, f64::INFINITY), Some(())); assert_eq!( m.det_sign_exact(), Err(LaError::NonFinite { @@ -1071,18 +1085,21 @@ mod tests { #[test] fn det_errbound_d0_is_zero() { - assert_eq!(Matrix::<0>::zero().det_errbound(), Some(0.0)); + assert_eq!(Matrix::<0>::zero().det_errbound(), Ok(Some(0.0))); } #[test] fn det_errbound_d1_is_zero() { - assert_eq!(Matrix::<1>::from_rows([[42.0]]).det_errbound(), Some(0.0)); + assert_eq!( + Matrix::<1>::from_rows([[42.0]]).det_errbound(), + Ok(Some(0.0)) + ); } #[test] fn det_errbound_d2_positive() { let m = Matrix::<2>::from_rows([[1.0, 2.0], [3.0, 4.0]]); - let bound = m.det_errbound().unwrap(); + let bound = m.det_errbound().unwrap().unwrap(); assert!(bound > 0.0); // bound = ERR_COEFF_2 * (|1*4| + |2*3|) = ERR_COEFF_2 * 10 assert!(crate::ERR_COEFF_2.mul_add(-10.0, bound).abs() < 1e-30); @@ -1091,7 +1108,7 @@ mod tests { #[test] fn det_errbound_d3_positive() { let m = Matrix::<3>::identity(); - let bound = m.det_errbound().unwrap(); + let bound = m.det_errbound().unwrap().unwrap(); assert!(bound > 0.0); } @@ -1099,14 +1116,14 @@ mod tests { fn det_errbound_d3_non_identity() { // Non-identity matrix to exercise all code paths in D=3 case let m = Matrix::<3>::from_rows([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 10.0]]); - let bound = m.det_errbound().unwrap(); + let bound = m.det_errbound().unwrap().unwrap(); assert!(bound > 0.0); } #[test] fn det_errbound_d4_positive() { let m = Matrix::<4>::identity(); - let bound = m.det_errbound().unwrap(); + let bound = m.det_errbound().unwrap().unwrap(); assert!(bound > 0.0); } @@ -1119,13 +1136,13 @@ mod tests { [0.0, 0.0, 3.0, 0.0], [0.0, 0.0, 0.0, 4.0], ]); - let bound = m.det_errbound().unwrap(); + let bound = m.det_errbound().unwrap().unwrap(); assert!(bound > 0.0); } #[test] fn det_errbound_d5_is_none() { - assert_eq!(Matrix::<5>::identity().det_errbound(), None); + assert_eq!(Matrix::<5>::identity().det_errbound(), Ok(None)); } // ----------------------------------------------------------------------- @@ -1263,7 +1280,7 @@ mod tests { #[test] fn bareiss_det_int_errs_on_nan() { let mut m = Matrix::<3>::identity(); - m.set(1, 2, f64::NAN); + assert_eq!(m.set(1, 2, f64::NAN), Some(())); assert_eq!( bareiss_det_int(&m), Err(LaError::NonFinite { @@ -1276,7 +1293,7 @@ mod tests { #[test] fn bareiss_det_int_errs_on_inf() { let mut m = Matrix::<2>::identity(); - m.set(0, 0, f64::INFINITY); + assert_eq!(m.set(0, 0, f64::INFINITY), Some(())); assert_eq!( bareiss_det_int(&m), Err(LaError::NonFinite { @@ -1504,7 +1521,7 @@ mod tests { #[test] fn []() { let mut a = Matrix::<$d>::identity(); - a.set(0, 0, f64::NAN); + assert_eq!(a.set(0, 0, f64::NAN), Some(())); let b = arbitrary_rhs::<$d>(); assert_eq!(a.solve_exact(b), Err(LaError::NonFinite { row: Some(0), col: 0 })); } @@ -1512,7 +1529,7 @@ mod tests { #[test] fn []() { let mut a = Matrix::<$d>::identity(); - a.set(0, 0, f64::INFINITY); + assert_eq!(a.set(0, 0, f64::INFINITY), Some(())); let b = arbitrary_rhs::<$d>(); assert_eq!(a.solve_exact(b), Err(LaError::NonFinite { row: Some(0), col: 0 })); } @@ -1567,7 +1584,7 @@ mod tests { #[test] fn []() { let mut a = Matrix::<$d>::identity(); - a.set(0, 0, f64::NAN); + assert_eq!(a.set(0, 0, f64::NAN), Some(())); let b = arbitrary_rhs::<$d>(); assert_eq!(a.solve_exact_f64(b), Err(LaError::NonFinite { row: Some(0), col: 0 })); } @@ -1723,7 +1740,7 @@ mod tests { fn solve_exact_singular_duplicate_rows() { let a = Matrix::<3>::from_rows([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [1.0, 2.0, 3.0]]); let b = Vector::<3>::new([1.0, 2.0, 3.0]); - assert!(matches!(a.solve_exact(b), Err(LaError::Singular { .. }))); + assert_matches!(a.solve_exact(b), Err(LaError::Singular { .. })); } #[test] @@ -2090,17 +2107,17 @@ mod tests { } /// Determinant of the large-entry 3×3 is roughly `big^3`, which - /// overflows `f64`. `det_direct()` therefore returns `±∞`, the fast - /// filter inside `det_sign_exact` falls through on the `is_finite()` - /// guard, and the Bareiss fallback resolves the positive sign - /// correctly. `det_exact_f64` must report `Overflow`. + /// overflows `f64`. `det_direct()` therefore reports a computed + /// [`LaError::NonFinite`], the fast filter inside `det_sign_exact` + /// treats that as inconclusive, and the Bareiss fallback resolves the + /// positive sign correctly. `det_exact_f64` must report `Overflow`. #[test] fn det_sign_exact_large_entries_3x3_positive() { let big = f64::MAX / 2.0; let a = Matrix::<3>::from_rows([[big, 1.0, 1.0], [1.0, big, 1.0], [1.0, 1.0, big]]); // Fast filter is inconclusive (big^3 overflows f64 to +∞), so // this exercises the Bareiss cold path. - assert!(!a.det_direct().is_some_and(f64::is_finite)); + assert_matches!(a.det_direct(), Err(LaError::NonFinite { row: None, .. })); assert_eq!(a.det_sign_exact().unwrap(), 1); // Cross-validate: the exact `BigRational` determinant must agree // on sign with `det_sign_exact`, and `det_exact_f64` must overflow diff --git a/src/ldlt.rs b/src/ldlt.rs index ff110d0..7f59f60 100644 --- a/src/ldlt.rs +++ b/src/ldlt.rs @@ -13,12 +13,12 @@ use core::hint::cold_path; -use crate::LaError; use crate::matrix::Matrix; use crate::vector::Vector; +use crate::{LaError, Tolerance}; /// Relative tolerance used by LDLT's mandatory symmetry validation. -const LDLT_SYMMETRY_REL_TOL: f64 = 1e-12; +const LDLT_SYMMETRY_REL_TOL: Tolerance = Tolerance::new_unchecked(1e-12); /// LDLT factorization (`A = L D Lᵀ`) for symmetric positive (semi)definite matrices. /// @@ -40,7 +40,7 @@ const LDLT_SYMMETRY_REL_TOL: f64 = 1e-12; #[derive(Clone, Copy, Debug, PartialEq)] pub struct Ldlt { factors: Matrix, - tol: f64, + tol: Tolerance, } impl Ldlt { @@ -50,8 +50,7 @@ impl Ldlt { /// invalid tolerances, asymmetric inputs, non-finite values, and degenerate /// diagonals before callers can observe an [`Ldlt`] value. #[inline] - pub(crate) fn factor(a: Matrix, tol: f64) -> Result { - let tol = LaError::validate_tolerance(tol)?; + pub(crate) fn factor(a: Matrix, tol: Tolerance) -> Result { reject_asymmetric(&a)?; let mut f = a; @@ -63,7 +62,7 @@ impl Ldlt { cold_path(); return Err(LaError::non_finite_cell(j, j)); } - if d <= tol { + if d <= tol.get() { cold_path(); return Err(LaError::Singular { pivot_col: j }); } @@ -106,22 +105,32 @@ impl Ldlt { /// ``` /// use la_stack::prelude::*; /// + /// # fn main() -> Result<(), LaError> { /// // Symmetric SPD matrix. /// let a = Matrix::<2>::from_rows([[4.0, 2.0], [2.0, 3.0]]); - /// let ldlt = a.ldlt(DEFAULT_SINGULAR_TOL).unwrap(); + /// let ldlt = a.ldlt(DEFAULT_SINGULAR_TOL)?; /// - /// assert!((ldlt.det() - 8.0).abs() <= 1e-12); + /// assert!((ldlt.det()? - 8.0).abs() <= 1e-12); + /// # Ok(()) + /// # } /// ``` + /// + /// # Errors + /// Returns [`LaError::NonFinite`] if the determinant product overflows to + /// NaN or infinity. #[inline] - #[must_use] - pub const fn det(&self) -> f64 { + pub const fn det(&self) -> Result { let mut det = 1.0; let mut i = 0; while i < D { det *= self.factors.rows[i][i]; + if !det.is_finite() { + cold_path(); + return Err(LaError::non_finite_at(i)); + } i += 1; } - det + Ok(det) } /// Solve `A x = b` using this LDLT factorization. @@ -189,7 +198,7 @@ impl Ldlt { cold_path(); return Err(LaError::non_finite_cell(i, i)); } - if diag <= self.tol { + if diag <= self.tol.get() { cold_path(); return Err(LaError::Singular { pivot_col: i }); } @@ -245,6 +254,7 @@ mod tests { use crate::DEFAULT_SINGULAR_TOL; + use core::assert_matches; use core::hint::black_box; use approx::assert_abs_diff_eq; @@ -258,7 +268,7 @@ mod tests { let a = Matrix::<$d>::identity(); let ldlt = a.ldlt(DEFAULT_SINGULAR_TOL).unwrap(); - assert_abs_diff_eq!(ldlt.det(), 1.0, epsilon = 1e-12); + assert_abs_diff_eq!(ldlt.det().unwrap(), 1.0, epsilon = 1e-12); let b_arr = { let mut arr = [0.0f64; $d]; @@ -313,7 +323,7 @@ mod tests { } acc }; - assert_abs_diff_eq!(ldlt.det(), expected_det, epsilon = 1e-12); + assert_abs_diff_eq!(ldlt.det().unwrap(), expected_det, epsilon = 1e-12); let b_arr = { let mut arr = [0.0f64; $d]; @@ -350,7 +360,7 @@ mod tests { assert_abs_diff_eq!(x[0], -0.125, epsilon = 1e-12); assert_abs_diff_eq!(x[1], 0.75, epsilon = 1e-12); - assert_abs_diff_eq!(ldlt.det(), 8.0, epsilon = 1e-12); + assert_abs_diff_eq!(ldlt.det().unwrap(), 8.0, epsilon = 1e-12); } #[test] @@ -479,6 +489,19 @@ mod tests { assert_eq!(err, LaError::NonFinite { row: None, col: 1 }); } + #[test] + fn det_rejects_product_overflow() { + let a = Matrix::<5>::from_rows([ + [1.0e100, 0.0, 0.0, 0.0, 0.0], + [0.0, 1.0e100, 0.0, 0.0, 0.0], + [0.0, 0.0, 1.0e100, 0.0, 0.0], + [0.0, 0.0, 0.0, 1.0e100, 0.0], + [0.0, 0.0, 0.0, 0.0, 1.0e100], + ]); + let ldlt = a.ldlt(DEFAULT_SINGULAR_TOL).unwrap(); + assert_eq!(ldlt.det(), Err(LaError::NonFinite { row: None, col: 3 })); + } + #[test] fn asymmetric_input_returns_typed_error() { // a[0][1] = 2.0 but a[1][0] = -2.0 → clearly asymmetric. @@ -495,15 +518,17 @@ mod tests { #[test] fn invalid_tolerance_rejected() { - let a = Matrix::<2>::identity(); - assert_eq!(a.ldlt(-1.0), Err(LaError::InvalidTolerance { value: -1.0 })); + assert_eq!( + Tolerance::new(-1.0), + Err(LaError::InvalidTolerance { value: -1.0 }) + ); - assert!(matches!( - a.ldlt(f64::NAN), + assert_matches!( + Tolerance::new(f64::NAN), Err(LaError::InvalidTolerance { value }) if value.is_nan() - )); + ); assert_eq!( - a.ldlt(f64::INFINITY), + Tolerance::new(f64::INFINITY), Err(LaError::InvalidTolerance { value: f64::INFINITY, }) @@ -593,7 +618,7 @@ mod tests { /// exercising the multiply-accumulate loop at each dimension. #[test] fn []() { - const DET: f64 = { + const DET: Result = { let mut factors = Matrix::<$d>::identity(); factors.rows[0][0] = 2.0; let ldlt = Ldlt::<$d> { @@ -602,7 +627,7 @@ mod tests { }; ldlt.det() }; - assert!((DET - 2.0).abs() <= 1e-12); + assert_eq!(DET, Ok(2.0)); } /// `Ldlt::solve_vec` must be fully const-evaluable. Identity diff --git a/src/lib.rs b/src/lib.rs index a2dac93..cca1fbc 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -8,6 +8,7 @@ mod readme_doctests { /// ```rust /// use la_stack::prelude::*; /// + /// # fn main() -> Result<(), LaError> { /// // This system requires pivoting (a[0][0] = 0), so it's a good LU demo. /// let a = Matrix::<5>::from_rows([ /// [0.0, 1.0, 1.0, 1.0, 1.0], @@ -19,20 +20,23 @@ mod readme_doctests { /// /// let b = Vector::<5>::new([14.0, 13.0, 12.0, 11.0, 10.0]); /// - /// let lu = a.lu(DEFAULT_PIVOT_TOL).unwrap(); - /// let x = lu.solve_vec(b).unwrap().into_array(); + /// let lu = a.lu(DEFAULT_PIVOT_TOL)?; + /// let x = lu.solve_vec(b)?.into_array(); /// /// // Floating-point rounding is expected; compare with a tolerance. /// let expected = [1.0, 2.0, 3.0, 4.0, 5.0]; /// for (x_i, e_i) in x.iter().zip(expected.iter()) { /// assert!((*x_i - *e_i).abs() <= 1e-12); /// } + /// # Ok(()) + /// # } /// ``` fn solve_5x5_example() {} /// ```rust /// use la_stack::prelude::*; /// + /// # fn main() -> Result<(), LaError> { /// // This matrix is symmetric positive-definite (A = L*L^T) so LDLT works without pivoting. /// let a = Matrix::<5>::from_rows([ /// [1.0, 1.0, 0.0, 0.0, 0.0], @@ -42,8 +46,10 @@ mod readme_doctests { /// [0.0, 0.0, 0.0, 1.0, 2.0], /// ]); /// - /// let det = a.ldlt(DEFAULT_SINGULAR_TOL).unwrap().det(); + /// let det = a.ldlt(DEFAULT_SINGULAR_TOL)?.det()?; /// assert!((det - 1.0).abs() <= 1e-12); + /// # Ok(()) + /// # } /// ``` fn det_5x5_ldlt_example() {} } @@ -124,15 +130,21 @@ const EPS: f64 = f64::EPSILON; // 2^-52 /// ``` /// use la_stack::prelude::*; /// +/// # fn main() -> Result<(), LaError> { /// let m = Matrix::<2>::from_rows([[1.0, 2.0], [3.0, 4.0]]); -/// let det = m.det_direct().unwrap(); +/// let Some(det) = m.det_direct()? else { +/// return Ok(()); +/// }; +/// assert_eq!(det, -2.0); /// // Compute the bound from the raw constant for illustration; most -/// // callers would just use `m.det_errbound().unwrap()` instead. +/// // callers would match on `m.det_errbound()?` instead. /// let p = (1.0_f64 * 4.0).abs() + (2.0_f64 * 3.0).abs(); /// let bound = ERR_COEFF_2 * p; /// if det.abs() > bound { /// // The f64 sign is provably correct without exact arithmetic. /// } +/// # Ok(()) +/// # } /// ``` pub const ERR_COEFF_2: f64 = 3.0 * EPS + 16.0 * EPS * EPS; @@ -172,27 +184,93 @@ pub const ERR_COEFF_3: f64 = 8.0 * EPS + 64.0 * EPS * EPS; /// constant for typical use; see [`ERR_COEFF_2`] for a worked example. pub const ERR_COEFF_4: f64 = 12.0 * EPS + 128.0 * EPS * EPS; +/// Largest dimension supported by [`try_with_stack_matrix!`]. +/// +/// The crate can represent `Matrix` for any compile-time `D`, but runtime +/// dispatch must enumerate a finite set of concrete stack types. Dimensions +/// `0..=7` cover downstream geometric predicate matrices while keeping the +/// dispatch surface explicit. +pub const MAX_STACK_MATRIX_DISPATCH_DIM: usize = 7; + +/// Finite, non-negative tolerance used by numerical predicates and factorizations. +/// +/// Construct with [`Tolerance::new`] when accepting raw caller input. Once +/// constructed, the stored value is guaranteed to be finite and `>= 0`, so +/// downstream algorithms do not need to revalidate the tolerance. +#[must_use] +#[derive(Clone, Copy, Debug, PartialEq)] +pub struct Tolerance { + value: f64, +} + +impl Tolerance { + /// Construct a tolerance without checking the raw value. + /// + /// This crate-internal escape hatch is only for constants whose finite, + /// non-negative value is visible at the call site. Public callers should + /// use [`Tolerance::new`] so the returned value carries the validation + /// proof. + pub(crate) const fn new_unchecked(value: f64) -> Self { + Self { value } + } + + /// Construct a finite, non-negative tolerance. + /// + /// # Examples + /// ``` + /// use la_stack::prelude::*; + /// + /// # fn main() -> Result<(), LaError> { + /// let tol = Tolerance::new(1e-12)?; + /// assert_eq!(tol.get(), 1e-12); + /// # Ok(()) + /// # } + /// ``` + /// + /// # Errors + /// Returns [`LaError::InvalidTolerance`] when `value` is NaN, infinite, or + /// negative. + #[inline] + pub const fn new(value: f64) -> Result { + if value >= 0.0 && value.is_finite() { + Ok(Self::new_unchecked(value)) + } else { + Err(LaError::invalid_tolerance(value)) + } + } + + /// Return the raw finite, non-negative tolerance value. + /// + /// # Examples + /// ``` + /// use la_stack::prelude::*; + /// + /// # fn main() -> Result<(), LaError> { + /// let tol = Tolerance::new(0.0)?; + /// assert_eq!(tol.get(), 0.0); + /// # Ok(()) + /// # } + /// ``` + #[inline] + #[must_use] + pub const fn get(self) -> f64 { + self.value + } +} + /// Default absolute threshold used for singularity/degeneracy detection. /// /// This is intentionally conservative for geometric predicates and small systems. /// /// Conceptually, this is an absolute bound for deciding when a scalar should be treated /// as "numerically zero" (e.g. LU pivots, LDLT diagonal entries). -pub const DEFAULT_SINGULAR_TOL: f64 = 1e-12; +pub const DEFAULT_SINGULAR_TOL: Tolerance = Tolerance::new_unchecked(1e-12); /// Default absolute pivot magnitude threshold used for LU pivot selection / singularity detection. /// /// This name is kept for backwards compatibility; prefer [`DEFAULT_SINGULAR_TOL`] when the /// tolerance is not specifically about pivot selection. -pub const DEFAULT_PIVOT_TOL: f64 = DEFAULT_SINGULAR_TOL; - -/// Largest dimension supported by [`try_with_stack_matrix!`]. -/// -/// The crate can represent `Matrix` for any compile-time `D`, but runtime -/// dispatch must enumerate a finite set of concrete stack types. Dimensions -/// `0..=7` cover downstream geometric predicate matrices while keeping the -/// dispatch surface explicit. -pub const MAX_STACK_MATRIX_DISPATCH_DIM: usize = 7; +pub const DEFAULT_PIVOT_TOL: Tolerance = DEFAULT_SINGULAR_TOL; /// Linear algebra errors. /// @@ -216,8 +294,9 @@ pub enum LaError { /// paths when they detect a corrupt stored factor (only reachable via /// direct struct construction; `factor` itself rejects such inputs). /// - `row: None, col: c` — the non-finite value is either a *vector input* - /// entry at index `c`, or a *computed intermediate* at step `c` - /// (e.g. an accumulator that overflowed during forward/back substitution). + /// entry at index `c`, or a *computed scalar/intermediate* at slot or + /// step `c` (e.g. an accumulator that overflowed during determinant + /// evaluation or forward/back substitution). NonFinite { /// Row of the non-finite entry for a stored matrix cell, or `None` for /// a vector-input entry or a computed intermediate. See the variant @@ -277,6 +356,19 @@ impl LaError { /// matching the stored-cell convention documented on /// [`NonFinite`](Self::NonFinite). For vector-input entries or computed /// intermediates, use [`non_finite_at`](Self::non_finite_at). + /// + /// # Examples + /// ``` + /// use la_stack::prelude::*; + /// + /// assert_eq!( + /// LaError::non_finite_cell(1, 2), + /// LaError::NonFinite { + /// row: Some(1), + /// col: 2, + /// } + /// ); + /// ``` #[inline] #[must_use] pub const fn non_finite_cell(row: usize, col: usize) -> Self { @@ -287,13 +379,23 @@ impl LaError { } /// Construct a [`LaError::NonFinite`] pinpointing a vector-input entry or - /// computed-intermediate step at index `col`. + /// computed scalar/intermediate at index `col`. + /// + /// Use this for non-finite values in a `Vector` input, determinant scalar, + /// or accumulator that overflowed during forward/back substitution. The + /// resulting error has `row: None, col`, matching the vector/intermediate + /// convention documented on [`NonFinite`](Self::NonFinite). For stored + /// matrix cells, use [`non_finite_cell`](Self::non_finite_cell). + /// + /// # Examples + /// ``` + /// use la_stack::prelude::*; /// - /// Use this for non-finite values in a `Vector` input or an accumulator - /// that overflowed during forward/back substitution. The resulting error - /// has `row: None, col`, matching the vector/intermediate convention - /// documented on [`NonFinite`](Self::NonFinite). For stored matrix cells, - /// use [`non_finite_cell`](Self::non_finite_cell). + /// assert_eq!( + /// LaError::non_finite_at(2), + /// LaError::NonFinite { row: None, col: 2 } + /// ); + /// ``` #[inline] #[must_use] pub const fn non_finite_at(col: usize) -> Self { @@ -379,13 +481,13 @@ impl LaError { Self::Asymmetric { row, col, dim } } - /// Validate that a tolerance is finite and non-negative. + /// Parse a raw tolerance into a finite, non-negative [`Tolerance`]. /// /// # Examples /// ``` /// use la_stack::prelude::*; /// - /// assert_eq!(LaError::validate_tolerance(1e-12)?, 1e-12); + /// assert_eq!(LaError::validate_tolerance(1e-12)?.get(), 1e-12); /// assert_eq!( /// LaError::validate_tolerance(-1.0), /// Err(LaError::InvalidTolerance { value: -1.0 }) @@ -397,12 +499,8 @@ impl LaError { /// Returns [`LaError::InvalidTolerance`] when `value` is NaN, infinite, or /// negative. #[inline] - pub const fn validate_tolerance(value: f64) -> Result { - if value >= 0.0 && value.is_finite() { - Ok(value) - } else { - Err(Self::invalid_tolerance(value)) - } + pub const fn validate_tolerance(value: f64) -> Result { + Tolerance::new(value) } } @@ -543,9 +641,10 @@ macro_rules! try_with_stack_matrix { /// Common imports for ergonomic usage. /// /// This prelude re-exports the primary types and constants: [`Matrix`], [`Vector`], [`Lu`], -/// [`Ldlt`], [`LaError`], [`DEFAULT_PIVOT_TOL`], [`DEFAULT_SINGULAR_TOL`], and the determinant -/// error bound coefficients [`ERR_COEFF_2`], [`ERR_COEFF_3`], and [`ERR_COEFF_4`]. -/// It also re-exports [`MAX_STACK_MATRIX_DISPATCH_DIM`] and +/// [`Ldlt`], [`Tolerance`], [`LaError`], [`DEFAULT_PIVOT_TOL`], +/// [`DEFAULT_SINGULAR_TOL`], and the determinant error bound coefficients +/// [`ERR_COEFF_2`], [`ERR_COEFF_3`], and [`ERR_COEFF_4`]. It also re-exports +/// [`MAX_STACK_MATRIX_DISPATCH_DIM`] and /// [`try_with_stack_matrix!`] for runtime-to-const matrix dispatch. /// /// When the `exact` feature is enabled, [`BigInt`] and [`BigRational`] @@ -559,7 +658,7 @@ macro_rules! try_with_stack_matrix { pub mod prelude { pub use crate::{ DEFAULT_PIVOT_TOL, DEFAULT_SINGULAR_TOL, ERR_COEFF_2, ERR_COEFF_3, ERR_COEFF_4, LaError, - Ldlt, Lu, MAX_STACK_MATRIX_DISPATCH_DIM, Matrix, Vector, try_with_stack_matrix, + Ldlt, Lu, MAX_STACK_MATRIX_DISPATCH_DIM, Matrix, Tolerance, Vector, try_with_stack_matrix, }; #[cfg(feature = "exact")] @@ -574,8 +673,12 @@ mod tests { #[test] fn default_singular_tol_is_expected() { - assert_abs_diff_eq!(DEFAULT_SINGULAR_TOL, 1e-12, epsilon = 0.0); - assert_abs_diff_eq!(DEFAULT_PIVOT_TOL, DEFAULT_SINGULAR_TOL, epsilon = 0.0); + assert_abs_diff_eq!(DEFAULT_SINGULAR_TOL.get(), 1e-12, epsilon = 0.0); + assert_abs_diff_eq!( + DEFAULT_PIVOT_TOL.get(), + DEFAULT_SINGULAR_TOL.get(), + epsilon = 0.0 + ); } #[test] @@ -711,12 +814,13 @@ mod tests { gen_stack_matrix_dispatch_tests!(3); gen_stack_matrix_dispatch_tests!(4); gen_stack_matrix_dispatch_tests!(5); + gen_stack_matrix_dispatch_tests!(6); gen_stack_matrix_dispatch_tests!(7); #[test] fn try_with_stack_matrix_supports_zero_dimension() { let got = try_with_stack_matrix!(0usize, |m| -> Result, LaError> { - Ok(m.det_direct()) + m.det_direct() }); assert_eq!(got, Ok(Some(1.0))); @@ -749,7 +853,7 @@ mod tests { #[test] fn try_with_stack_matrix_converts_unsupported_dimension_error() { let got = try_with_stack_matrix!(9usize, |m| -> Result { - assert_abs_diff_eq!(m.inf_norm(), 0.0, epsilon = 0.0); + assert_abs_diff_eq!(m.inf_norm().unwrap(), 0.0, epsilon = 0.0); Ok(0) }); diff --git a/src/lu.rs b/src/lu.rs index 4e3f887..1920a10 100644 --- a/src/lu.rs +++ b/src/lu.rs @@ -2,9 +2,9 @@ use core::hint::cold_path; -use crate::LaError; use crate::matrix::Matrix; use crate::vector::Vector; +use crate::{LaError, Tolerance}; /// LU decomposition (PA = LU) with partial pivoting. #[must_use] @@ -13,7 +13,7 @@ pub struct Lu { factors: Matrix, piv: [usize; D], piv_sign: f64, - tol: f64, + tol: Tolerance, } impl Lu { @@ -23,8 +23,7 @@ impl Lu { /// invalid tolerances, non-finite pivot candidates, and numerically singular /// pivots before callers can observe a [`Lu`] value. #[inline] - pub(crate) fn factor(a: Matrix, tol: f64) -> Result { - let tol = LaError::validate_tolerance(tol)?; + pub(crate) fn factor(a: Matrix, tol: Tolerance) -> Result { let mut lu = a; let mut piv = [0usize; D]; @@ -55,7 +54,7 @@ impl Lu { } } - if pivot_abs <= tol { + if pivot_abs <= tol.get() { cold_path(); return Err(LaError::Singular { pivot_col: k }); } @@ -177,7 +176,7 @@ impl Lu { cold_path(); return Err(LaError::non_finite_at(i)); } - if diag.abs() <= self.tol { + if diag.abs() <= self.tol.get() { cold_path(); return Err(LaError::Singular { pivot_col: i }); } @@ -204,21 +203,28 @@ impl Lu { /// let a = Matrix::<2>::from_rows([[1.0, 2.0], [3.0, 4.0]]); /// let lu = a.lu(DEFAULT_PIVOT_TOL)?; /// - /// let det = lu.det(); + /// let det = lu.det()?; /// assert!((det - (-2.0)).abs() <= 1e-12); /// # Ok(()) /// # } /// ``` + /// + /// # Errors + /// Returns [`LaError::NonFinite`] if the determinant product overflows to + /// NaN or infinity. #[inline] - #[must_use] - pub const fn det(&self) -> f64 { + pub const fn det(&self) -> Result { let mut det = self.piv_sign; let mut i = 0; while i < D { det *= self.factors.rows[i][i]; + if !det.is_finite() { + cold_path(); + return Err(LaError::non_finite_at(i)); + } i += 1; } - det + Ok(det) } } @@ -227,6 +233,7 @@ mod tests { use super::*; use crate::DEFAULT_PIVOT_TOL; + use core::assert_matches; use core::hint::black_box; use approx::assert_abs_diff_eq; @@ -249,7 +256,7 @@ mod tests { rows.swap(0, 1); let a = Matrix::<$d>::from_rows(black_box(rows)); - let lu_fn: fn(Matrix<$d>, f64) -> Result, LaError> = + let lu_fn: fn(Matrix<$d>, Tolerance) -> Result, LaError> = black_box(Matrix::<$d>::lu); let lu = lu_fn(a, DEFAULT_PIVOT_TOL).unwrap(); @@ -289,13 +296,14 @@ mod tests { rows.swap(0, 1); let a = Matrix::<$d>::from_rows(black_box(rows)); - let lu_fn: fn(Matrix<$d>, f64) -> Result, LaError> = + let lu_fn: fn(Matrix<$d>, Tolerance) -> Result, LaError> = black_box(Matrix::<$d>::lu); let lu = lu_fn(a, DEFAULT_PIVOT_TOL).unwrap(); // Row swap ⇒ determinant sign flip. - let det_fn: fn(&Lu<$d>) -> f64 = black_box(Lu::<$d>::det); - assert_abs_diff_eq!(det_fn(&lu), -1.0, epsilon = 1e-12); + let det_fn: fn(&Lu<$d>) -> Result = + black_box(Lu::<$d>::det); + assert_abs_diff_eq!(det_fn(&lu).unwrap(), -1.0, epsilon = 1e-12); } } }; @@ -328,7 +336,7 @@ mod tests { } let a = Matrix::<$d>::from_rows(black_box(rows)); - let lu_fn: fn(Matrix<$d>, f64) -> Result, LaError> = + let lu_fn: fn(Matrix<$d>, Tolerance) -> Result, LaError> = black_box(Matrix::<$d>::lu); let lu = lu_fn(a, DEFAULT_PIVOT_TOL).unwrap(); @@ -367,12 +375,13 @@ mod tests { } let a = Matrix::<$d>::from_rows(black_box(rows)); - let lu_fn: fn(Matrix<$d>, f64) -> Result, LaError> = + let lu_fn: fn(Matrix<$d>, Tolerance) -> Result, LaError> = black_box(Matrix::<$d>::lu); let lu = lu_fn(a, DEFAULT_PIVOT_TOL).unwrap(); - let det_fn: fn(&Lu<$d>) -> f64 = black_box(Lu::<$d>::det); - assert_abs_diff_eq!(det_fn(&lu), f64::from($d) + 1.0, epsilon = 1e-8); + let det_fn: fn(&Lu<$d>) -> Result = + black_box(Lu::<$d>::det); + assert_abs_diff_eq!(det_fn(&lu).unwrap(), f64::from($d) + 1.0, epsilon = 1e-8); } } }; @@ -393,8 +402,8 @@ mod tests { let x = solve_fn(&lu, b).unwrap().into_array(); assert_abs_diff_eq!(x[0], 3.0, epsilon = 1e-12); - let det_fn: fn(&Lu<1>) -> f64 = black_box(Lu::<1>::det); - assert_abs_diff_eq!(det_fn(&lu), 2.0, epsilon = 0.0); + let det_fn: fn(&Lu<1>) -> Result = black_box(Lu::<1>::det); + assert_abs_diff_eq!(det_fn(&lu).unwrap(), 2.0, epsilon = 0.0); } #[test] @@ -416,8 +425,8 @@ mod tests { let a = Matrix::<2>::from_rows(black_box([[1.0, 2.0], [3.0, 4.0]])); let lu = a.lu(DEFAULT_PIVOT_TOL).unwrap(); - let det_fn: fn(&Lu<2>) -> f64 = black_box(Lu::<2>::det); - assert_abs_diff_eq!(det_fn(&lu), -2.0, epsilon = 1e-12); + let det_fn: fn(&Lu<2>) -> Result = black_box(Lu::<2>::det); + assert_abs_diff_eq!(det_fn(&lu).unwrap(), -2.0, epsilon = 1e-12); } #[test] @@ -426,8 +435,8 @@ mod tests { let a = Matrix::<2>::from_rows(black_box([[0.0, 1.0], [1.0, 0.0]])); let lu = a.lu(DEFAULT_PIVOT_TOL).unwrap(); - let det_fn: fn(&Lu<2>) -> f64 = black_box(Lu::<2>::det); - assert_abs_diff_eq!(det_fn(&lu), -1.0, epsilon = 0.0); + let det_fn: fn(&Lu<2>) -> Result = black_box(Lu::<2>::det); + assert_abs_diff_eq!(det_fn(&lu).unwrap(), -1.0, epsilon = 0.0); } #[test] @@ -461,15 +470,17 @@ mod tests { #[test] fn invalid_tolerance_rejected() { - let a = Matrix::<2>::identity(); - assert_eq!(a.lu(-1.0), Err(LaError::InvalidTolerance { value: -1.0 })); + assert_eq!( + Tolerance::new(-1.0), + Err(LaError::InvalidTolerance { value: -1.0 }) + ); - assert!(matches!( - a.lu(f64::NAN), + assert_matches!( + Tolerance::new(f64::NAN), Err(LaError::InvalidTolerance { value }) if value.is_nan() - )); + ); assert_eq!( - a.lu(f64::INFINITY), + Tolerance::new(f64::INFINITY), Err(LaError::InvalidTolerance { value: f64::INFINITY, }) @@ -540,6 +551,19 @@ mod tests { assert_eq!(err, LaError::NonFinite { row: None, col: 1 }); } + #[test] + fn det_rejects_product_overflow() { + let a = Matrix::<5>::from_rows([ + [1.0e100, 0.0, 0.0, 0.0, 0.0], + [0.0, 1.0e100, 0.0, 0.0, 0.0], + [0.0, 0.0, 1.0e100, 0.0, 0.0], + [0.0, 0.0, 0.0, 1.0e100, 0.0], + [0.0, 0.0, 0.0, 0.0, 1.0e100], + ]); + let lu = a.lu(DEFAULT_PIVOT_TOL).unwrap(); + assert_eq!(lu.det(), Err(LaError::NonFinite { row: None, col: 3 })); + } + // ----------------------------------------------------------------------- // Defensive-path coverage for `solve_vec`. // @@ -631,7 +655,7 @@ mod tests { #[test] fn lu_det_const_eval_d2() { - const DET: f64 = { + const DET: Result = { // Triangular factors with diag [2.0, 3.0] and no row swaps. let mut factors = Matrix::<2>::identity(); factors.rows[0][0] = 2.0; @@ -644,12 +668,12 @@ mod tests { }; lu.det() }; - assert!((DET - 6.0).abs() <= 1e-12); + assert_eq!(DET, Ok(6.0)); } #[test] fn lu_det_const_eval_d3_row_swap() { - const DET: f64 = { + const DET: Result = { // Identity factors but `piv_sign = -1.0` encoding a single row swap; // the determinant magnitude is 1 but the sign flips. let lu = Lu::<3> { @@ -660,7 +684,7 @@ mod tests { }; lu.det() }; - assert!((DET - (-1.0)).abs() <= 1e-12); + assert_eq!(DET, Ok(-1.0)); } #[test] diff --git a/src/matrix.rs b/src/matrix.rs index b1021d4..2dae220 100644 --- a/src/matrix.rs +++ b/src/matrix.rs @@ -2,10 +2,10 @@ use core::hint::cold_path; -use crate::LaError; use crate::ldlt::Ldlt; use crate::lu::Lu; use crate::{ERR_COEFF_2, ERR_COEFF_3, ERR_COEFF_4}; +use crate::{LaError, Tolerance}; /// Fixed-size square matrix `D×D`, stored inline. #[must_use] @@ -127,31 +127,32 @@ impl Matrix { /// Set an element with bounds checking. /// - /// Returns `true` if the index was in-bounds. + /// Returns `Some(())` if the index was in bounds, or `None` otherwise. /// /// # Examples /// ``` /// use la_stack::prelude::*; /// /// let mut m = Matrix::<2>::zero(); - /// assert!(m.set(0, 1, 2.5)); + /// assert_eq!(m.set(0, 1, 2.5), Some(())); /// assert_eq!(m.get(0, 1), Some(2.5)); - /// assert!(!m.set(10, 0, 1.0)); + /// assert_eq!(m.set(10, 0, 1.0), None); /// ``` #[inline] - pub const fn set(&mut self, r: usize, c: usize, value: f64) -> bool { + #[must_use] + pub const fn set(&mut self, r: usize, c: usize, value: f64) -> Option<()> { if r < D && c < D { self.rows[r][c] = value; - true + Some(()) } else { - false + None } } /// Set an element, preserving index context on failure. /// /// The matrix is mutated only when `(row, col)` is in bounds. Prefer - /// [`set`](Self::set) for const or hot paths that only need a boolean; + /// [`set`](Self::set) for const or hot paths that only need `Option`-style absence; /// use this method at public runtime boundaries where failed mutation /// should return a typed, contextual error. /// @@ -191,25 +192,35 @@ impl Matrix { /// Infinity norm (maximum absolute row sum). /// /// # Non-finite handling - /// If any entry is NaN, the result is NaN. NaN is detected explicitly - /// because a naive `row_sum > max_row_sum` comparison silently skips NaN - /// rows (every ordered comparison against NaN is `false`). If any entry - /// is infinite (and no entry is NaN), the result is `+∞`. + /// Non-finite entries are rejected with source coordinates instead of + /// silently propagating NaN or infinity through the norm. /// /// # Examples /// ``` /// use la_stack::prelude::*; /// + /// # fn main() -> Result<(), LaError> { /// let m = Matrix::<2>::from_rows([[1.0, -2.0], [3.0, 4.0]]); - /// assert!((m.inf_norm() - 7.0).abs() <= 1e-12); + /// assert!((m.inf_norm()? - 7.0).abs() <= 1e-12); /// - /// // NaN entries propagate to the norm. + /// // NaN entries are rejected with coordinates. /// let nan = Matrix::<2>::from_rows([[f64::NAN, 1.0], [2.0, 3.0]]); - /// assert!(nan.inf_norm().is_nan()); + /// assert_eq!( + /// nan.inf_norm(), + /// Err(LaError::NonFinite { + /// row: Some(0), + /// col: 0, + /// }) + /// ); + /// # Ok(()) + /// # } /// ``` + /// + /// # Errors + /// Returns [`LaError::NonFinite`] when any entry is NaN or infinity, or when + /// a row sum overflows to NaN or infinity. #[inline] - #[must_use] - pub const fn inf_norm(&self) -> f64 { + pub const fn inf_norm(&self) -> Result { let mut max_row_sum: f64 = 0.0; let mut r = 0; @@ -221,15 +232,16 @@ impl Matrix { let mut row_sum: f64 = 0.0; let mut c = 0; while c < D { + if !row[c].is_finite() { + cold_path(); + return Err(LaError::non_finite_cell(r, c)); + } row_sum += row[c].abs(); c += 1; } - // Propagate NaN explicitly: `f64::max` drops NaN (IEEE 754 `maxNum`) - // and `f64::maximum` (IEEE 754-2019 `maximum`) is still unstable, - // so we short-circuit on NaN instead. - if row_sum.is_nan() { + if !row_sum.is_finite() { cold_path(); - return f64::NAN; + return Err(LaError::non_finite_at(r)); } if row_sum > max_row_sum { max_row_sum = row_sum; @@ -237,14 +249,14 @@ impl Matrix { r += 1; } - max_row_sum + Ok(max_row_sum) } /// Returns `true` if the matrix is symmetric within a relative tolerance. /// /// Two entries `self[r][c]` and `self[c][r]` are considered equal (for the /// purposes of symmetry) when - /// `|self[r][c] - self[c][r]| <= rel_tol * max(1.0, self.inf_norm())`. + /// `|self[r][c] - self[c][r]| <= rel_tol * max(1.0, inf_norm(self))`. /// This mirrors the predicate used internally by [`ldlt`](Self::ldlt), so /// callers can pre-validate matrices that may come from untrusted sources. /// @@ -262,20 +274,19 @@ impl Matrix { /// /// # fn main() -> Result<(), LaError> { /// let a = Matrix::<2>::from_rows([[4.0, 2.0], [2.0, 3.0]]); - /// assert!(a.is_symmetric(1e-12)?); + /// let tol = Tolerance::new(1e-12)?; + /// assert!(a.is_symmetric(tol)?); /// /// let b = Matrix::<2>::from_rows([[4.0, 2.0], [3.0, 3.0]]); - /// assert!(!b.is_symmetric(1e-12)?); + /// assert!(!b.is_symmetric(tol)?); /// # Ok(()) /// # } /// ``` /// /// # Errors - /// Returns [`LaError::InvalidTolerance`] when `rel_tol` is NaN, infinite, - /// or negative. /// Returns [`LaError::NonFinite`] when any matrix entry is NaN or infinite. #[inline] - pub fn is_symmetric(&self, rel_tol: f64) -> Result { + pub fn is_symmetric(&self, rel_tol: Tolerance) -> Result { Ok(self.first_asymmetry(rel_tol)?.is_none()) } @@ -286,7 +297,7 @@ impl Matrix { /// Iteration order is row-major over the strict upper triangle, so the /// returned indices are the lexicographically smallest such pair. The /// predicate is the same as [`is_symmetric`](Self::is_symmetric): - /// `|self[r][c] - self[c][r]| <= rel_tol * max(1.0, self.inf_norm())`. + /// `|self[r][c] - self[c][r]| <= rel_tol * max(1.0, inf_norm(self))`. /// /// # Examples /// ``` @@ -298,19 +309,17 @@ impl Matrix { /// [2.0, 4.0, 5.0], /// [0.0, 6.0, 9.0], // 6.0 breaks symmetry with a[1][2] = 5.0 /// ]); - /// assert_eq!(a.first_asymmetry(1e-12)?, Some((1, 2))); - /// assert_eq!(Matrix::<3>::identity().first_asymmetry(1e-12)?, None); + /// let tol = Tolerance::new(1e-12)?; + /// assert_eq!(a.first_asymmetry(tol)?, Some((1, 2))); + /// assert_eq!(Matrix::<3>::identity().first_asymmetry(tol)?, None); /// # Ok(()) /// # } /// ``` /// /// # Errors - /// Returns [`LaError::InvalidTolerance`] when `rel_tol` is NaN, infinite, - /// or negative. /// Returns [`LaError::NonFinite`] when any matrix entry is NaN or infinite. #[inline] - pub fn first_asymmetry(&self, rel_tol: f64) -> Result, LaError> { - let rel_tol = LaError::validate_tolerance(rel_tol)?; + pub fn first_asymmetry(&self, rel_tol: Tolerance) -> Result, LaError> { let eps = self.symmetry_epsilon(rel_tol)?; for r in 0..D { for c in (r + 1)..D { @@ -343,7 +352,8 @@ impl Matrix { /// /// # Errors /// Returns [`LaError::NonFinite`] when any matrix entry is NaN or infinite. - fn symmetry_epsilon(&self, rel_tol: f64) -> Result { + fn symmetry_epsilon(&self, rel_tol: Tolerance) -> Result { + let rel_tol = rel_tol.get(); let mut eps = rel_tol; for r in 0..D { @@ -387,9 +397,8 @@ impl Matrix { /// Returns [`LaError::Singular`] if, for some column `k`, the largest-magnitude candidate pivot /// in that column satisfies `|pivot| <= tol` (so no numerically usable pivot exists). /// Returns [`LaError::NonFinite`] if NaN/∞ is detected during factorization. - /// Returns [`LaError::InvalidTolerance`] if `tol` is NaN, infinite, or negative. #[inline] - pub fn lu(self, tol: f64) -> Result, LaError> { + pub fn lu(self, tol: Tolerance) -> Result, LaError> { Lu::factor(self, tol) } @@ -417,7 +426,7 @@ impl Matrix { /// let ldlt = a.ldlt(DEFAULT_SINGULAR_TOL)?; /// /// // det(A) = 8 - /// assert!((ldlt.det() - 8.0).abs() <= 1e-12); + /// assert!((ldlt.det()? - 8.0).abs() <= 1e-12); /// /// // Solve A x = b /// let b = Vector::<2>::new([1.0, 2.0]); @@ -433,17 +442,45 @@ impl Matrix { /// is `<= tol` (non-positive or too small). This treats PSD degeneracy (and indefinite inputs) /// as singular/degenerate. /// Returns [`LaError::NonFinite`] if NaN/∞ is detected during factorization. - /// Returns [`LaError::InvalidTolerance`] if `tol` is NaN, infinite, or negative. /// Returns [`LaError::Asymmetric`] if the input matrix is not symmetric. #[inline] - pub fn ldlt(self, tol: f64) -> Result, LaError> { + pub fn ldlt(self, tol: Tolerance) -> Result, LaError> { Ldlt::factor(self, tol) } + /// Return the first non-finite stored cell in row-major order. + const fn first_non_finite_cell(&self) -> Option<(usize, usize)> { + let mut r = 0; + while r < D { + let mut c = 0; + while c < D { + if !self.rows[r][c].is_finite() { + return Some((r, c)); + } + c += 1; + } + r += 1; + } + None + } + + /// Return a computed scalar result, preserving non-finite diagnostics. + const fn computed_scalar_result(&self, value: Option) -> Result, LaError> { + if let Some((row, col)) = self.first_non_finite_cell() { + Err(LaError::non_finite_cell(row, col)) + } else { + match value { + Some(value) if value.is_finite() => Ok(Some(value)), + Some(_) => Err(LaError::non_finite_at(0)), + None => Ok(None), + } + } + } + /// Closed-form determinant for dimensions 0–4, bypassing LU factorization. /// - /// Returns `Some(det)` for `D` ∈ {0, 1, 2, 3, 4}, `None` for D ≥ 5. - /// `D = 0` returns `Some(1.0)` (empty product). + /// Returns `Ok(Some(det))` for `D` ∈ {0, 1, 2, 3, 4}, `Ok(None)` for D ≥ 5. + /// `D = 0` returns `Ok(Some(1.0))` (empty product). /// This is a `const fn` (Rust 1.94+) and uses fused multiply-add (`mul_add`) /// for improved accuracy and performance. /// @@ -454,24 +491,32 @@ impl Matrix { /// ``` /// use la_stack::prelude::*; /// + /// # fn main() -> Result<(), LaError> { /// let m = Matrix::<2>::from_rows([[1.0, 2.0], [3.0, 4.0]]); - /// assert!((m.det_direct().unwrap() - (-2.0)).abs() <= 1e-12); + /// assert_eq!(m.det_direct()?, Some(-2.0)); /// /// // D = 0 is the empty product. - /// assert_eq!(Matrix::<0>::zero().det_direct(), Some(1.0)); + /// assert_eq!(Matrix::<0>::zero().det_direct()?, Some(1.0)); /// /// // D ≥ 5 returns None. - /// assert!(Matrix::<5>::identity().det_direct().is_none()); + /// assert!(Matrix::<5>::identity().det_direct()?.is_none()); + /// # Ok(()) + /// # } /// ``` + /// + /// # Errors + /// Returns [`LaError::NonFinite`] when any entry is NaN or infinity, or when + /// the closed-form determinant overflows to NaN or infinity. #[inline] - #[must_use] - pub const fn det_direct(&self) -> Option { + pub const fn det_direct(&self) -> Result, LaError> { match D { - 0 => Some(1.0), - 1 => Some(self.rows[0][0]), + 0 => Ok(Some(1.0)), + 1 => self.computed_scalar_result(Some(self.rows[0][0])), 2 => { // ad - bc - Some(self.rows[0][0].mul_add(self.rows[1][1], -(self.rows[0][1] * self.rows[1][0]))) + self.computed_scalar_result(Some( + self.rows[0][0].mul_add(self.rows[1][1], -(self.rows[0][1] * self.rows[1][0])), + )) } 3 => { // Cofactor expansion on first row. @@ -481,10 +526,10 @@ impl Matrix { self.rows[1][0].mul_add(self.rows[2][2], -(self.rows[1][2] * self.rows[2][0])); let m02 = self.rows[1][0].mul_add(self.rows[2][1], -(self.rows[1][1] * self.rows[2][0])); - Some( + self.computed_scalar_result(Some( self.rows[0][0] .mul_add(m00, (-self.rows[0][1]).mul_add(m01, self.rows[0][2] * m02)), - ) + )) } 4 => { // Cofactor expansion on first row → four 3×3 sub-determinants. @@ -505,15 +550,15 @@ impl Matrix { let c02 = r[1][0].mul_add(s13, (-r[1][1]).mul_add(s03, r[1][3] * s01)); let c03 = r[1][0].mul_add(s12, (-r[1][1]).mul_add(s02, r[1][2] * s01)); - Some(r[0][0].mul_add( + self.computed_scalar_result(Some(r[0][0].mul_add( c00, (-r[0][1]).mul_add(c01, r[0][2].mul_add(c02, -(r[0][3] * c03))), - )) + ))) } _ => { // Cold in the common D ≤ 4 case; callers fall back to LU for D ≥ 5. cold_path(); - None + self.computed_scalar_result(None) } } } @@ -539,31 +584,17 @@ impl Matrix { /// Returns [`LaError::NonFinite`] if the result contains NaN or infinity. /// For D ≥ 5, propagates LU factorization errors (e.g. [`LaError::Singular`]). #[inline] - pub fn det(self, tol: f64) -> Result { - if let Some(d) = self.det_direct() { - return if d.is_finite() { - Ok(d) - } else { - cold_path(); - // Scan for the first non-finite entry to preserve coordinates. - for r in 0..D { - for c in 0..D { - if !self.rows[r][c].is_finite() { - return Err(LaError::non_finite_cell(r, c)); - } - } - } - // All entries are finite but the determinant overflowed. - Err(LaError::non_finite_at(0)) - }; + pub fn det(self, tol: Tolerance) -> Result { + if let Some(d) = self.det_direct()? { + return Ok(d); } - self.lu(tol).map(|lu| lu.det()) + self.lu(tol)?.det() } /// Conservative absolute error bound for `det_direct()`. /// - /// Returns `Some(bound)` such that `|det_direct() - det_exact| ≤ bound`, - /// or `None` for D ≥ 5 where no fast bound is available. + /// Returns `Ok(Some(bound))` such that `|det_direct() - det_exact| ≤ bound`, + /// or `Ok(None)` for D ≥ 5 where no fast bound is available. /// /// For D ≤ 4, the bound is derived from the absolute Leibniz sum using /// Shewchuk-style error analysis (see `REFERENCES.md` \[8\] and the @@ -585,14 +616,19 @@ impl Matrix { /// ``` /// use la_stack::prelude::*; /// + /// # fn main() -> Result<(), LaError> { /// let m = Matrix::<3>::from_rows([ /// [1.0, 2.0, 3.0], /// [4.0, 5.0, 6.0], /// [7.0, 8.0, 9.0], /// ]); - /// let bound = m.det_errbound().unwrap(); - /// let det_approx = m.det_direct().unwrap(); - /// // If |det_approx| > bound, the sign is guaranteed correct. + /// if let (Some(bound), Some(det_approx)) = (m.det_errbound()?, m.det_direct()?) { + /// // If |det_approx| > bound, the sign is guaranteed correct. + /// let sign_is_certified = det_approx.abs() > bound; + /// assert!(!sign_is_certified); + /// } + /// # Ok(()) + /// # } /// ``` /// /// # Adaptive precision pattern (requires `exact` feature) @@ -600,29 +636,33 @@ impl Matrix { /// use la_stack::prelude::*; /// /// let m = Matrix::<3>::identity(); - /// if let Some(bound) = m.det_errbound() { - /// let det = m.det_direct().unwrap(); - /// if det.abs() > bound { - /// // f64 sign is guaranteed correct - /// let sign = det.signum() as i8; - /// } else { - /// // Fall back to exact arithmetic (requires `exact` feature) - /// let sign = m.det_sign_exact().unwrap(); + /// if let Some(bound) = m.det_errbound()? { + /// if let Some(det) = m.det_direct()? { + /// if det.abs() > bound { + /// // f64 sign is guaranteed correct + /// let sign = det.signum() as i8; + /// } else { + /// // Fall back to exact arithmetic (requires `exact` feature) + /// let sign = m.det_sign_exact()?; + /// } /// } /// } else { /// // D ≥ 5: no fast filter, use exact directly - /// let sign = m.det_sign_exact().unwrap(); + /// let sign = m.det_sign_exact()?; /// } /// ``` - #[must_use] + /// + /// # Errors + /// Returns [`LaError::NonFinite`] when any entry is NaN or infinity, or when + /// the bound computation overflows to NaN or infinity. #[inline] - pub const fn det_errbound(&self) -> Option { + pub const fn det_errbound(&self) -> Result, LaError> { match D { - 0 | 1 => Some(0.0), // No arithmetic — result is exact. + 0 | 1 => self.computed_scalar_result(Some(0.0)), // No arithmetic — result is exact. 2 => { let r = &self.rows; let permanent = (r[0][0] * r[1][1]).abs() + (r[0][1] * r[1][0]).abs(); - Some(ERR_COEFF_2 * permanent) + self.computed_scalar_result(Some(ERR_COEFF_2 * permanent)) } 3 => { let r = &self.rows; @@ -632,7 +672,7 @@ impl Matrix { let permanent = r[0][2] .abs() .mul_add(pm02, r[0][1].abs().mul_add(pm01, r[0][0].abs() * pm00)); - Some(ERR_COEFF_3 * permanent) + self.computed_scalar_result(Some(ERR_COEFF_3 * permanent)) } 4 => { let r = &self.rows; @@ -662,9 +702,12 @@ impl Matrix { .abs() .mul_add(pc2, r[0][1].abs().mul_add(pc1, r[0][0].abs() * pc0)), ); - Some(ERR_COEFF_4 * permanent) + self.computed_scalar_result(Some(ERR_COEFF_4 * permanent)) + } + _ => { + cold_path(); + self.computed_scalar_result(None) } - _ => None, } } } @@ -682,6 +725,7 @@ mod tests { use crate::{DEFAULT_PIVOT_TOL, Vector}; use approx::assert_abs_diff_eq; + use core::assert_matches; use pastey::paste; use std::hint::black_box; @@ -713,7 +757,9 @@ mod tests { ); // Out-of-bounds set fails. - assert!(!m.set($d, 0, 3.0)); + let before_failed_set = m; + assert_eq!(m.set($d, 0, 3.0), None); + assert_eq!(m, before_failed_set); assert_eq!( m.set_checked($d, 0, 3.0), Err(LaError::IndexOutOfBounds { @@ -722,6 +768,7 @@ mod tests { dim: $d, }) ); + assert_eq!(m, before_failed_set); assert_eq!( m.set_checked(0, $d, 3.0), Err(LaError::IndexOutOfBounds { @@ -730,10 +777,11 @@ mod tests { dim: $d, }) ); + assert_eq!(m, before_failed_set); assert_eq!(m.get(0, 0), Some(1.0)); // In-bounds set works. - assert!(m.set(0, $d - 1, 3.0)); + assert_eq!(m.set(0, $d - 1, 3.0), Some(())); assert_eq!(m.get(0, $d - 1), Some(3.0)); assert_eq!(m.set_checked($d - 1, 0, 4.0), Ok(())); assert_eq!(m.get_checked($d - 1, 0), Ok(4.0)); @@ -742,10 +790,10 @@ mod tests { #[test] fn []() { let z = Matrix::<$d>::zero(); - assert_abs_diff_eq!(z.inf_norm(), 0.0, epsilon = 0.0); + assert_abs_diff_eq!(z.inf_norm().unwrap(), 0.0, epsilon = 0.0); let d = Matrix::<$d>::default(); - assert_abs_diff_eq!(d.inf_norm(), 0.0, epsilon = 0.0); + assert_abs_diff_eq!(d.inf_norm().unwrap(), 0.0, epsilon = 0.0); } #[test] @@ -763,7 +811,7 @@ mod tests { } let m = Matrix::<$d>::from_rows(rows); - assert_abs_diff_eq!(m.inf_norm(), f64::from($d), epsilon = 0.0); + assert_abs_diff_eq!(m.inf_norm().unwrap(), f64::from($d), epsilon = 0.0); } #[test] @@ -815,13 +863,13 @@ mod tests { #[test] fn det_direct_d0_is_one() { - assert_eq!(Matrix::<0>::zero().det_direct(), Some(1.0)); + assert_eq!(Matrix::<0>::zero().det_direct(), Ok(Some(1.0))); } #[test] fn det_direct_d1_returns_element() { let m = Matrix::<1>::from_rows([[42.0]]); - assert_eq!(m.det_direct(), Some(42.0)); + assert_eq!(m.det_direct(), Ok(Some(42.0))); } #[test] @@ -829,7 +877,7 @@ mod tests { // [[1,2],[3,4]] → det = 1*4 - 2*3 = -2 // black_box prevents compile-time constant folding of the const fn. let m = black_box(Matrix::<2>::from_rows([[1.0, 2.0], [3.0, 4.0]])); - assert_abs_diff_eq!(m.det_direct().unwrap(), -2.0, epsilon = 1e-15); + assert_abs_diff_eq!(m.det_direct().unwrap().unwrap(), -2.0, epsilon = 1e-15); } #[test] @@ -840,7 +888,7 @@ mod tests { [4.0, 5.0, 6.0], [7.0, 8.0, 9.0], ])); - assert_abs_diff_eq!(m.det_direct().unwrap(), 0.0, epsilon = 1e-12); + assert_abs_diff_eq!(m.det_direct().unwrap().unwrap(), 0.0, epsilon = 1e-12); } #[test] @@ -851,13 +899,13 @@ mod tests { [0.0, 3.0, 1.0], [1.0, 0.0, 2.0], ])); - assert_abs_diff_eq!(m.det_direct().unwrap(), 13.0, epsilon = 1e-12); + assert_abs_diff_eq!(m.det_direct().unwrap().unwrap(), 13.0, epsilon = 1e-12); } #[test] fn det_direct_d4_identity() { let m = black_box(Matrix::<4>::identity()); - assert_abs_diff_eq!(m.det_direct().unwrap(), 1.0, epsilon = 1e-15); + assert_abs_diff_eq!(m.det_direct().unwrap().unwrap(), 1.0, epsilon = 1e-15); } #[test] @@ -869,17 +917,66 @@ mod tests { rows[2][2] = 5.0; rows[3][3] = 7.0; let m = black_box(Matrix::<4>::from_rows(rows)); - assert_abs_diff_eq!(m.det_direct().unwrap(), 210.0, epsilon = 1e-12); + assert_abs_diff_eq!(m.det_direct().unwrap().unwrap(), 210.0, epsilon = 1e-12); } #[test] fn det_direct_d5_returns_none() { - assert_eq!(Matrix::<5>::identity().det_direct(), None); + assert_eq!(Matrix::<5>::identity().det_direct(), Ok(None)); + } + + #[test] + fn det_direct_d5_rejects_nonfinite_before_returning_none() { + let mut m = Matrix::<5>::identity(); + assert_eq!(m.set(3, 4, f64::NAN), Some(())); + assert_eq!( + m.det_direct(), + Err(LaError::NonFinite { + row: Some(3), + col: 4, + }) + ); } #[test] fn det_direct_d8_returns_none() { - assert_eq!(Matrix::<8>::zero().det_direct(), None); + assert_eq!(Matrix::<8>::zero().det_direct(), Ok(None)); + } + + #[test] + fn det_direct_rejects_nonfinite_entry_with_coordinates() { + let m = Matrix::<3>::from_rows([[1.0, 0.0, 0.0], [0.0, f64::NAN, 0.0], [0.0, 0.0, 1.0]]); + assert_eq!( + m.det_direct(), + Err(LaError::NonFinite { + row: Some(1), + col: 1, + }) + ); + } + + #[test] + fn det_direct_rejects_computed_overflow() { + let m = Matrix::<2>::from_rows([[1e300, 0.0], [0.0, 1e300]]); + assert_eq!( + m.det_direct(), + Err(LaError::NonFinite { row: None, col: 0 }) + ); + } + + #[test] + fn det_d5_rejects_lu_product_overflow() { + let m = Matrix::<5>::from_rows([ + [1.0e100, 0.0, 0.0, 0.0, 0.0], + [0.0, 1.0e100, 0.0, 0.0, 0.0], + [0.0, 0.0, 1.0e100, 0.0, 0.0], + [0.0, 0.0, 0.0, 1.0e100, 0.0], + [0.0, 0.0, 0.0, 0.0, 1.0e100], + ]); + assert_eq!( + m.det(DEFAULT_PIVOT_TOL), + Err(LaError::NonFinite { row: None, col: 3 }) + ); } macro_rules! gen_det_direct_agrees_with_lu { @@ -900,8 +997,8 @@ mod tests { } } let m = Matrix::<$d>::from_rows(rows); - let direct = m.det_direct().unwrap(); - let lu_det = m.lu(DEFAULT_PIVOT_TOL).unwrap().det(); + let direct = m.det_direct().unwrap().unwrap(); + let lu_det = m.lu(DEFAULT_PIVOT_TOL).unwrap().det().unwrap(); let eps = lu_det.abs().mul_add(1e-12, 1e-12); assert_abs_diff_eq!(direct, lu_det, epsilon = eps); } @@ -917,22 +1014,22 @@ mod tests { #[test] fn det_direct_identity_all_dims() { assert_abs_diff_eq!( - Matrix::<1>::identity().det_direct().unwrap(), + Matrix::<1>::identity().det_direct().unwrap().unwrap(), 1.0, epsilon = 0.0 ); assert_abs_diff_eq!( - Matrix::<2>::identity().det_direct().unwrap(), + Matrix::<2>::identity().det_direct().unwrap().unwrap(), 1.0, epsilon = 0.0 ); assert_abs_diff_eq!( - Matrix::<3>::identity().det_direct().unwrap(), + Matrix::<3>::identity().det_direct().unwrap().unwrap(), 1.0, epsilon = 0.0 ); assert_abs_diff_eq!( - Matrix::<4>::identity().det_direct().unwrap(), + Matrix::<4>::identity().det_direct().unwrap().unwrap(), 1.0, epsilon = 0.0 ); @@ -941,17 +1038,17 @@ mod tests { #[test] fn det_direct_zero_matrix() { assert_abs_diff_eq!( - Matrix::<2>::zero().det_direct().unwrap(), + Matrix::<2>::zero().det_direct().unwrap().unwrap(), 0.0, epsilon = 0.0 ); assert_abs_diff_eq!( - Matrix::<3>::zero().det_direct().unwrap(), + Matrix::<3>::zero().det_direct().unwrap().unwrap(), 0.0, epsilon = 0.0 ); assert_abs_diff_eq!( - Matrix::<4>::zero().det_direct().unwrap(), + Matrix::<4>::zero().det_direct().unwrap().unwrap(), 0.0, epsilon = 0.0 ); @@ -1005,11 +1102,11 @@ mod tests { ($d:literal) => { paste! { /// `Matrix::::det_direct()` on the identity must const-evaluate - /// to `Some(1.0)` for every closed-form dimension `D ∈ {1, 2, 3, 4}`. + /// to `Ok(Some(1.0))` for every closed-form dimension `D ∈ {1, 2, 3, 4}`. #[test] fn []() { - const DET: Option = Matrix::<$d>::identity().det_direct(); - assert_eq!(DET, Some(1.0)); + const DET: Result, LaError> = Matrix::<$d>::identity().det_direct(); + assert_eq!(DET, Ok(Some(1.0))); } } }; @@ -1021,10 +1118,10 @@ mod tests { #[test] fn det_direct_const_eval_d5_is_none() { - // D ≥ 5 has no closed-form arm; `det_direct` returns `None`. Verify + // D ≥ 5 has no closed-form arm; `det_direct` returns `Ok(None)`. Verify // that the wildcard arm is reachable in a `const { … }` context. - const DET: Option = Matrix::<5>::identity().det_direct(); - assert_eq!(DET, None); + const DET: Result, LaError> = Matrix::<5>::identity().det_direct(); + assert_eq!(DET, Ok(None)); } // === det_errbound tests (no `exact` feature required) === @@ -1033,7 +1130,7 @@ mod tests { fn det_errbound_available_without_exact_feature() { // Verify det_errbound is accessible without exact feature let m = Matrix::<3>::identity(); - let bound = m.det_errbound(); + let bound = m.det_errbound().unwrap(); assert!(bound.is_some()); assert!(bound.unwrap() > 0.0); } @@ -1041,7 +1138,53 @@ mod tests { #[test] fn det_errbound_d5_returns_none() { // D=5 has no fast filter - assert_eq!(Matrix::<5>::identity().det_errbound(), None); + assert_eq!(Matrix::<5>::identity().det_errbound(), Ok(None)); + } + + #[test] + fn det_errbound_d1_rejects_nonfinite_even_with_zero_bound() { + let m = Matrix::<1>::from_rows([[f64::INFINITY]]); + assert_eq!( + m.det_errbound(), + Err(LaError::NonFinite { + row: Some(0), + col: 0, + }) + ); + } + + #[test] + fn det_errbound_d5_rejects_nonfinite_before_returning_none() { + let mut m = Matrix::<5>::identity(); + assert_eq!(m.set(4, 1, f64::NAN), Some(())); + assert_eq!( + m.det_errbound(), + Err(LaError::NonFinite { + row: Some(4), + col: 1, + }) + ); + } + + #[test] + fn det_errbound_rejects_nonfinite_entry_with_coordinates() { + let m = Matrix::<2>::from_rows([[1.0, f64::INFINITY], [0.0, 1.0]]); + assert_eq!( + m.det_errbound(), + Err(LaError::NonFinite { + row: Some(0), + col: 1, + }) + ); + } + + #[test] + fn det_errbound_rejects_computed_overflow() { + let m = Matrix::<2>::from_rows([[1e300, 0.0], [0.0, 1e300]]); + assert_eq!( + m.det_errbound(), + Err(LaError::NonFinite { row: None, col: 0 }) + ); } // === det_errbound const-evaluability tests (D = 2..=5) === @@ -1050,14 +1193,13 @@ mod tests { ($d:literal) => { paste! { /// `Matrix::::det_errbound()` on the identity must const-evaluate - /// to `Some(bound)` with `bound > 0` for every closed-form dimension + /// to `Ok(Some(bound))` with `bound > 0` for every closed-form dimension /// `D ∈ {2, 3, 4}`. Each dimension hits a distinct arm of /// `det_errbound` with a dimension-specific permanent computation. #[test] fn []() { - const BOUND: Option = Matrix::<$d>::identity().det_errbound(); - assert!(BOUND.is_some()); - assert!(BOUND.unwrap() > 0.0); + const BOUND: Result, LaError> = Matrix::<$d>::identity().det_errbound(); + assert!(BOUND.unwrap().unwrap() > 0.0); } } }; @@ -1069,9 +1211,9 @@ mod tests { #[test] fn det_errbound_const_eval_d5_is_none() { - // D ≥ 5 has no fast-filter bound; `det_errbound` returns `None`. - const BOUND: Option = Matrix::<5>::identity().det_errbound(); - assert_eq!(BOUND, None); + // D ≥ 5 has no fast-filter bound; `det_errbound` returns `Ok(None)`. + const BOUND: Result, LaError> = Matrix::<5>::identity().det_errbound(); + assert_eq!(BOUND, Ok(None)); } // === inf_norm const-evaluability tests (D = 2..=5) === @@ -1084,8 +1226,8 @@ mod tests { /// entry, so the max absolute row sum is exactly `1.0`. #[test] fn []() { - const NORM: f64 = Matrix::<$d>::identity().inf_norm(); - assert!((NORM - 1.0).abs() <= 1e-12); + const NORM: Result = Matrix::<$d>::identity().inf_norm(); + assert!((NORM.unwrap() - 1.0).abs() <= 1e-12); } } }; @@ -1096,40 +1238,54 @@ mod tests { gen_inf_norm_const_eval_tests!(4); gen_inf_norm_const_eval_tests!(5); - // === inf_norm NaN / Inf propagation (regression tests for #85) === + // === inf_norm NaN / Inf rejection (regression tests for #85) === macro_rules! gen_inf_norm_nonfinite_tests { ($d:literal) => { paste! { #[test] - fn []() { + fn []() { // Before the fix, `NaN > max_row_sum` was always false, so a // matrix full of NaN silently produced inf_norm == 0.0. let m = Matrix::<$d>::from_rows([[f64::NAN; $d]; $d]); - assert!(m.inf_norm().is_nan()); + assert_eq!( + m.inf_norm(), + Err(LaError::NonFinite { + row: Some(0), + col: 0, + }) + ); } #[test] - fn []() { - // A single NaN entry must contaminate its row sum and - // propagate through `f64::maximum` to the final result. + fn []() { + // A single NaN entry must surface with its source coordinates. let mut rows = [[0.0f64; $d]; $d]; rows[0][0] = f64::NAN; rows[$d - 1][$d - 1] = 1.0; let m = Matrix::<$d>::from_rows(rows); - assert!(m.inf_norm().is_nan()); + assert_eq!( + m.inf_norm(), + Err(LaError::NonFinite { + row: Some(0), + col: 0, + }) + ); } #[test] - fn []() { - // Infinity entries should propagate to +∞ via the row sum, - // not be silently dropped. The norm is a sum of absolute - // values, so any infinite result is necessarily +∞. + fn []() { + // Infinity entries should be rejected with their source coordinates. let mut rows = [[0.0f64; $d]; $d]; rows[0][0] = f64::INFINITY; let m = Matrix::<$d>::from_rows(rows); - let norm = m.inf_norm(); - assert!(norm.is_infinite() && norm.is_sign_positive()); + assert_eq!( + m.inf_norm(), + Err(LaError::NonFinite { + row: Some(0), + col: 0, + }) + ); } } }; @@ -1148,15 +1304,15 @@ mod tests { #[test] fn []() { let m = Matrix::<$d>::identity(); - assert!(m.is_symmetric(1e-12).unwrap()); - assert_eq!(m.first_asymmetry(1e-12).unwrap(), None); + assert!(m.is_symmetric(Tolerance::new(1e-12).unwrap()).unwrap()); + assert_eq!(m.first_asymmetry(Tolerance::new(1e-12).unwrap()).unwrap(), None); } #[test] fn []() { let m = Matrix::<$d>::zero(); - assert!(m.is_symmetric(1e-12).unwrap()); - assert_eq!(m.first_asymmetry(1e-12).unwrap(), None); + assert!(m.is_symmetric(Tolerance::new(1e-12).unwrap()).unwrap()); + assert_eq!(m.first_asymmetry(Tolerance::new(1e-12).unwrap()).unwrap(), None); } #[test] @@ -1178,8 +1334,8 @@ mod tests { } } let a = Matrix::<$d>::from_rows(sym); - assert!(a.is_symmetric(1e-12).unwrap()); - assert_eq!(a.first_asymmetry(1e-12).unwrap(), None); + assert!(a.is_symmetric(Tolerance::new(1e-12).unwrap()).unwrap()); + assert_eq!(a.first_asymmetry(Tolerance::new(1e-12).unwrap()).unwrap(), None); } #[test] @@ -1192,8 +1348,11 @@ mod tests { rows[0][$d - 1] = 1.0; rows[$d - 1][0] = -1.0; // breaks symmetry let a = Matrix::<$d>::from_rows(rows); - assert!(!a.is_symmetric(1e-12).unwrap()); - assert_eq!(a.first_asymmetry(1e-12).unwrap(), Some((0, $d - 1))); + assert!(!a.is_symmetric(Tolerance::new(1e-12).unwrap()).unwrap()); + assert_eq!( + a.first_asymmetry(Tolerance::new(1e-12).unwrap()).unwrap(), + Some((0, $d - 1)) + ); } #[test] @@ -1208,14 +1367,14 @@ mod tests { rows[1][0] = f64::NAN; let a = Matrix::<$d>::from_rows(rows); assert_eq!( - a.is_symmetric(1e-12), + a.is_symmetric(Tolerance::new(1e-12).unwrap()), Err(LaError::NonFinite { row: Some(0), col: 1, }) ); assert_eq!( - a.first_asymmetry(1e-12), + a.first_asymmetry(Tolerance::new(1e-12).unwrap()), Err(LaError::NonFinite { row: Some(0), col: 1, @@ -1237,29 +1396,32 @@ mod tests { // relative tolerance 1e-12 yields eps ≈ 2e-6, which accepts the gap; // a stricter tol of 1e-15 rejects it. let a = Matrix::<2>::from_rows([[1.0e6, 1.0e6 + 1.0e-6], [1.0e6, 1.0e6]]); - assert!(a.is_symmetric(1e-12).unwrap()); - assert!(!a.is_symmetric(1e-15).unwrap()); + assert!(a.is_symmetric(Tolerance::new(1e-12).unwrap()).unwrap()); + assert!(!a.is_symmetric(Tolerance::new(1e-15).unwrap()).unwrap()); } #[test] fn first_asymmetry_returns_lexicographically_first_pair() { // Two asymmetric pairs: (0, 2) and (1, 2). We must get (0, 2) first. let a = Matrix::<3>::from_rows([[1.0, 0.0, 2.0], [0.0, 1.0, 3.0], [-2.0, -3.0, 1.0]]); - assert_eq!(a.first_asymmetry(1e-12).unwrap(), Some((0, 2))); + assert_eq!( + a.first_asymmetry(Tolerance::new(1e-12).unwrap()).unwrap(), + Some((0, 2)) + ); } #[test] fn first_asymmetry_rejects_infinite_offdiagonal() { let a = Matrix::<2>::from_rows([[1.0, f64::INFINITY], [0.0, 1.0]]); assert_eq!( - a.first_asymmetry(1e-12), + a.first_asymmetry(Tolerance::new(1e-12).unwrap()), Err(LaError::NonFinite { row: Some(0), col: 1, }) ); assert_eq!( - a.is_symmetric(1e-12), + a.is_symmetric(Tolerance::new(1e-12).unwrap()), Err(LaError::NonFinite { row: Some(0), col: 1, @@ -1271,14 +1433,14 @@ mod tests { fn first_asymmetry_rejects_nan_diagonal() { let a = Matrix::<2>::from_rows([[f64::NAN, 1.0], [1.0, 1.0]]); assert_eq!( - a.first_asymmetry(1e-12), + a.first_asymmetry(Tolerance::new(1e-12).unwrap()), Err(LaError::NonFinite { row: Some(0), col: 0, }) ); assert_eq!( - a.is_symmetric(1e-12), + a.is_symmetric(Tolerance::new(1e-12).unwrap()), Err(LaError::NonFinite { row: Some(0), col: 0, @@ -1294,30 +1456,36 @@ mod tests { [f64::MAX, 0.0, f64::MAX], ]); - assert!(a.inf_norm().is_infinite()); - assert_eq!(a.first_asymmetry(0.0).unwrap(), Some((0, 1))); - assert!(!a.is_symmetric(0.0).unwrap()); + assert_eq!(a.inf_norm(), Err(LaError::NonFinite { row: None, col: 0 })); + assert_eq!( + a.first_asymmetry(Tolerance::new(0.0).unwrap()).unwrap(), + Some((0, 1)) + ); + assert!(!a.is_symmetric(Tolerance::new(0.0).unwrap()).unwrap()); } #[test] fn first_asymmetry_flags_overflowed_finite_difference() { let a = Matrix::<2>::from_rows([[1.0, f64::MAX], [-f64::MAX, 1.0]]); - assert_eq!(a.first_asymmetry(1e-12).unwrap(), Some((0, 1))); - assert!(!a.is_symmetric(1e-12).unwrap()); + assert_eq!( + a.first_asymmetry(Tolerance::new(1e-12).unwrap()).unwrap(), + Some((0, 1)) + ); + assert!(!a.is_symmetric(Tolerance::new(1e-12).unwrap()).unwrap()); } #[test] fn is_symmetric_rejects_invalid_tol() { assert_eq!( - Matrix::<2>::identity().is_symmetric(-1.0), + Tolerance::new(-1.0), Err(LaError::InvalidTolerance { value: -1.0 }) ); - assert!(matches!( - Matrix::<2>::identity().is_symmetric(f64::NAN), + assert_matches!( + Tolerance::new(f64::NAN), Err(LaError::InvalidTolerance { value }) if value.is_nan() - )); + ); assert_eq!( - Matrix::<2>::identity().is_symmetric(f64::INFINITY), + Tolerance::new(f64::INFINITY), Err(LaError::InvalidTolerance { value: f64::INFINITY, }) @@ -1327,19 +1495,19 @@ mod tests { #[test] fn first_asymmetry_rejects_negative_tol() { assert_eq!( - Matrix::<2>::identity().first_asymmetry(-1.0), + Tolerance::new(-1.0), Err(LaError::InvalidTolerance { value: -1.0 }) ); } #[test] fn first_asymmetry_rejects_nonfinite_tol() { - assert!(matches!( - Matrix::<2>::identity().first_asymmetry(f64::NAN), + assert_matches!( + Tolerance::new(f64::NAN), Err(LaError::InvalidTolerance { value }) if value.is_nan() - )); + ); assert_eq!( - Matrix::<2>::identity().first_asymmetry(f64::INFINITY), + Tolerance::new(f64::INFINITY), Err(LaError::InvalidTolerance { value: f64::INFINITY, }) diff --git a/src/vector.rs b/src/vector.rs index e9abce6..3f14c34 100644 --- a/src/vector.rs +++ b/src/vector.rs @@ -1,5 +1,7 @@ //! Fixed-size, stack-allocated vectors. +use crate::LaError; + /// Fixed-size vector of length `D`, stored inline. #[must_use] #[derive(Clone, Copy, Debug, PartialEq)] @@ -73,20 +75,35 @@ impl Vector { /// ``` /// use la_stack::prelude::*; /// + /// # fn main() -> Result<(), LaError> { /// let a = Vector::<3>::new([1.0, 2.0, 3.0]); /// let b = Vector::<3>::new([-2.0, 0.5, 4.0]); - /// assert!((a.dot(b) - 11.0).abs() <= 1e-12); + /// assert!((a.dot(b)? - 11.0).abs() <= 1e-12); + /// # Ok(()) + /// # } /// ``` + /// + /// # Errors + /// Returns [`LaError::NonFinite`] when either input contains NaN or infinity, + /// or when the accumulated dot product overflows to NaN or infinity. #[inline] - #[must_use] - pub const fn dot(self, other: Self) -> f64 { + pub const fn dot(self, other: Self) -> Result { let mut acc = 0.0; let mut i = 0; while i < D { + if !self.data[i].is_finite() { + return Err(LaError::non_finite_at(i)); + } + if !other.data[i].is_finite() { + return Err(LaError::non_finite_at(i)); + } acc = self.data[i].mul_add(other.data[i], acc); + if !acc.is_finite() { + return Err(LaError::non_finite_at(i)); + } i += 1; } - acc + Ok(acc) } /// Squared Euclidean norm. @@ -95,12 +112,18 @@ impl Vector { /// ``` /// use la_stack::prelude::*; /// + /// # fn main() -> Result<(), LaError> { /// let v = Vector::<3>::new([1.0, 2.0, 3.0]); - /// assert!((v.norm2_sq() - 14.0).abs() <= 1e-12); + /// assert!((v.norm2_sq()? - 14.0).abs() <= 1e-12); + /// # Ok(()) + /// # } /// ``` + /// + /// # Errors + /// Returns [`LaError::NonFinite`] when the input contains NaN or infinity, + /// or when the accumulated norm overflows to NaN or infinity. #[inline] - #[must_use] - pub const fn norm2_sq(self) -> f64 { + pub const fn norm2_sq(self) -> Result { self.dot(self) } } @@ -209,11 +232,44 @@ mod tests { // Call via (black_boxed) fn pointers to discourage inlining, improving line-level coverage // attribution for the loop body. - let dot_fn: fn(Vector<$d>, Vector<$d>) -> f64 = black_box(Vector::<$d>::dot); - let norm2_sq_fn: fn(Vector<$d>) -> f64 = black_box(Vector::<$d>::norm2_sq); + let dot_fn: fn(Vector<$d>, Vector<$d>) -> Result = + black_box(Vector::<$d>::dot); + let norm2_sq_fn: fn(Vector<$d>) -> Result = + black_box(Vector::<$d>::norm2_sq); + + assert_abs_diff_eq!( + dot_fn(black_box(a), black_box(b)).unwrap(), + expected_dot, + epsilon = 1e-14 + ); + assert_abs_diff_eq!( + norm2_sq_fn(black_box(a)).unwrap(), + expected_norm2_sq, + epsilon = 1e-14 + ); + } + + #[test] + fn []() { + let mut a_arr = [1.0f64; $d]; + a_arr[$d - 1] = f64::NAN; + let a = Vector::<$d>::new(a_arr); + let b = Vector::<$d>::new([1.0; $d]); - assert_abs_diff_eq!(dot_fn(black_box(a), black_box(b)), expected_dot, epsilon = 1e-14); - assert_abs_diff_eq!(norm2_sq_fn(black_box(a)), expected_norm2_sq, epsilon = 1e-14); + assert_eq!( + a.dot(b), + Err(LaError::NonFinite { + row: None, + col: $d - 1, + }) + ); + assert_eq!( + a.norm2_sq(), + Err(LaError::NonFinite { + row: None, + col: $d - 1, + }) + ); } } }; diff --git a/tests/proptest_exact.rs b/tests/proptest_exact.rs index 84f59b9..7166f6a 100644 --- a/tests/proptest_exact.rs +++ b/tests/proptest_exact.rs @@ -300,8 +300,14 @@ macro_rules! gen_det_sign_fast_filter_boundary_proptests { ), ) { let m = Matrix::<$d>::from_rows(entries); - let det = m.det_direct().expect("D<=4 has closed-form det_direct"); - let bound = m.det_errbound().expect("D<=4 has a det_errbound"); + let det = m + .det_direct() + .unwrap() + .expect("D<=4 has closed-form det_direct"); + let bound = m + .det_errbound() + .unwrap() + .expect("D<=4 has a det_errbound"); let sign = m.det_sign_exact().unwrap(); // Only assert when the filter is conclusive. When diff --git a/tests/proptest_factorizations.rs b/tests/proptest_factorizations.rs index e26d292..c4ef6a0 100644 --- a/tests/proptest_factorizations.rs +++ b/tests/proptest_factorizations.rs @@ -89,7 +89,7 @@ macro_rules! gen_factorization_proptests { let a = Matrix::<$d>::from_rows(a_rows); let ldlt = a.ldlt(DEFAULT_SINGULAR_TOL).unwrap(); - assert_abs_diff_eq!(ldlt.det(), expected_det, epsilon = 1e-8); + assert_abs_diff_eq!(ldlt.det().unwrap(), expected_det, epsilon = 1e-8); let b = Vector::<$d>::new(b_arr); let x = ldlt.solve_vec(b).unwrap().into_array(); @@ -169,7 +169,7 @@ macro_rules! gen_factorization_proptests { let a = Matrix::<$d>::from_rows(a_rows); let lu = a.lu(DEFAULT_PIVOT_TOL).unwrap(); - assert_abs_diff_eq!(lu.det(), expected_det, epsilon = 1e-8); + assert_abs_diff_eq!(lu.det().unwrap(), expected_det, epsilon = 1e-8); let b = Vector::<$d>::new(b_arr); let x = lu.solve_vec(b).unwrap().into_array(); @@ -255,7 +255,7 @@ macro_rules! gen_factorization_proptests { let a = Matrix::<$d>::from_rows(a_rows); let lu = a.lu(DEFAULT_PIVOT_TOL).unwrap(); - assert_abs_diff_eq!(lu.det(), expected_det, epsilon = 1e-8); + assert_abs_diff_eq!(lu.det().unwrap(), expected_det, epsilon = 1e-8); let b = Vector::<$d>::new(b_arr); let x = lu.solve_vec(b).unwrap().into_array(); diff --git a/tests/proptest_matrix.rs b/tests/proptest_matrix.rs index b77ccc9..c79adb3 100644 --- a/tests/proptest_matrix.rs +++ b/tests/proptest_matrix.rs @@ -73,12 +73,44 @@ macro_rules! gen_public_api_matrix_proptests { v in small_f64(), ) { let mut m = Matrix::<$d>::zero(); - prop_assert!(m.set(r, c, v)); + prop_assert_eq!(m.set(r, c, v), Some(())); assert_abs_diff_eq!(m.get(r, c).unwrap(), v, epsilon = 0.0); prop_assert_eq!(m.set_checked(r, c, -v), Ok(())); assert_abs_diff_eq!(m.get_checked(r, c).unwrap(), -v, epsilon = 0.0); } + #[test] + fn []( + rows in proptest::array::[]( + proptest::array::[](small_f64()), + ), + v in small_f64(), + ) { + let mut m = Matrix::<$d>::from_rows(rows); + let original = m; + + prop_assert_eq!(m.set($d, 0, v), None); + prop_assert_eq!(m, original); + prop_assert_eq!( + m.set_checked($d, 0, v), + Err(LaError::IndexOutOfBounds { + row: $d, + col: 0, + dim: $d, + }) + ); + prop_assert_eq!(m, original); + prop_assert_eq!( + m.set_checked(0, $d, v), + Err(LaError::IndexOutOfBounds { + row: 0, + col: $d, + dim: $d, + }) + ); + prop_assert_eq!(m, original); + } + #[test] fn []( rows in proptest::array::[]( @@ -92,8 +124,8 @@ macro_rules! gen_public_api_matrix_proptests { .map(|row| row.iter().map(|&x| x.abs()).sum::()) .fold(0.0f64, f64::max); - assert_abs_diff_eq!(m.inf_norm(), expected, epsilon = 0.0); - prop_assert!(m.inf_norm() >= 0.0); + assert_abs_diff_eq!(m.inf_norm().unwrap(), expected, epsilon = 0.0); + prop_assert!(m.inf_norm().unwrap() >= 0.0); } #[test] @@ -179,7 +211,7 @@ macro_rules! gen_public_api_matrix_proptests { let a = Matrix::<$d>::from_rows(a_rows); let ldlt = a.ldlt(DEFAULT_SINGULAR_TOL).unwrap(); - let det_ldlt = ldlt.det(); + let det_ldlt = ldlt.det().unwrap(); let det_lu = a.det(DEFAULT_PIVOT_TOL).unwrap(); assert_abs_diff_eq!(det_ldlt, det_lu, epsilon = 1e-8); diff --git a/tests/proptest_vector.rs b/tests/proptest_vector.rs index 11d754b..563a12b 100644 --- a/tests/proptest_vector.rs +++ b/tests/proptest_vector.rs @@ -40,19 +40,19 @@ macro_rules! gen_public_api_vector_proptests { let a = Vector::<$d>::new(a_arr); let b = Vector::<$d>::new(b_arr); - let dot_ab = a.dot(b); - let dot_ba = b.dot(a); - assert_abs_diff_eq!(dot_ab, dot_ba, epsilon = 1e-14); + let dot_ab = a.dot(b).unwrap(); + let dot_reversed = b.dot(a).unwrap(); + assert_abs_diff_eq!(dot_ab, dot_reversed, epsilon = 1e-14); - let dot_aa = a.dot(a); - assert_abs_diff_eq!(a.norm2_sq(), dot_aa, epsilon = 0.0); + let dot_aa = a.dot(a).unwrap(); + assert_abs_diff_eq!(a.norm2_sq().unwrap(), dot_aa, epsilon = 0.0); // Squared norm is always non-negative for finite inputs. - prop_assert!(a.norm2_sq() >= 0.0); + prop_assert!(a.norm2_sq().unwrap() >= 0.0); // Dot with zero vector is zero. let z = Vector::<$d>::zero(); - assert_abs_diff_eq!(a.dot(z), 0.0, epsilon = 1e-14); + assert_abs_diff_eq!(a.dot(z).unwrap(), 0.0, epsilon = 1e-14); } } }