diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml new file mode 100644 index 0000000..59b42c9 --- /dev/null +++ b/.github/workflows/test.yml @@ -0,0 +1,109 @@ +name: test + +on: + push: + +env: + RUST_BACKTRACE: full + CARGO_TERM_COLOR: always + +jobs: + test: + name: test platform builds and run tests + strategy: + fail-fast: false + matrix: + os: + - macos-latest + - ubuntu-latest + - windows-latest + rust: + - beta + - stable + - nightly + + runs-on: ${{ matrix.os }} + + steps: + - name: checkout + uses: actions/checkout@v5 + + - name: install rust + uses: actions-rs/toolchain@v1 + with: + toolchain: ${{ matrix.rust }} + default: true + profile: minimal + + - name: restore cache + uses: Swatinem/rust-cache@v2 + + - name: test + env: + OS: ${{ matrix.os }} + RUST_VERSION: ${{ matrix.rust }} + RUSTFLAGS: -D warnings + run: cargo test + + rustfmt: + name: check formatting + runs-on: ubuntu-latest + steps: + - name: checkout + uses: actions/checkout@v5 + + - name: install rust + uses: actions-rs/toolchain@v1 + with: + default: true + profile: minimal + toolchain: stable + components: rustfmt + + - name: formatter test + run: cargo fmt --all -- --check + + clippy: + name: run linter + runs-on: ubuntu-latest + steps: + - name: checkout repository + uses: actions/checkout@v5 + + - name: install rust + uses: actions-rs/toolchain@v1 + with: + default: true + profile: minimal + toolchain: stable + components: clippy + + - name: restore cache + uses: Swatinem/rust-cache@v2 + + - name: clippy test + run: cargo clippy --all --tests -- -D clippy::all -D warnings + + miri: + name: sanitize unsafe + runs-on: ubuntu-latest + steps: + - name: checkout repository + uses: actions/checkout@v5 + + - name: install rust + uses: actions-rs/toolchain@v1 + with: + default: true + profile: minimal + toolchain: nightly + components: "miri" + + - name: restore cache + uses: Swatinem/rust-cache@v2 + + - name: miri test + env: + PROPTEST_CASES: "10" + MIRIFLAGS: "-Zmiri-disable-isolation -Zmiri-permissive-provenance" + run: cargo miri test diff --git a/auto-core/src/lib.rs b/auto-core/src/lib.rs index e9079e9..e65b8f4 100644 --- a/auto-core/src/lib.rs +++ b/auto-core/src/lib.rs @@ -448,10 +448,7 @@ where Var { value, - inner: VarInner { - index, - tape, - }, + inner: VarInner { index, tape }, phantom: PhantomData, } } diff --git a/auto-scalar/src/lib.rs b/auto-scalar/src/lib.rs index 0f2ef61..c68fac7 100644 --- a/auto-scalar/src/lib.rs +++ b/auto-scalar/src/lib.rs @@ -606,6 +606,28 @@ mod tests { use super::*; use lib_auto_core::{Gradient, Tape}; + // tolerant to rounding differences + fn float_cmp(a: f64, b: f64, epsilon: f64) -> bool { + (a - b).abs() < epsilon + } + + macro_rules! assert_float_eq { + ($left:expr, $right:expr) => { + assert_float_eq!($left, $right, 1e-14) + }; + ($left:expr, $right:expr, $epsilon:expr) => { + let left = $left; + let right = $right; + assert!( + float_cmp(left, right, $epsilon), + "assertion `left ≃ right` failed\n left: {}\n right: {}\n diff: {}", + left, + right, + (left - right).abs() + ); + }; + } + mod var { use super::*; @@ -614,7 +636,7 @@ mod tests { let mut tape: Tape = Tape::new(); tape.scope(|guard| { let a = guard.var(1.3); - assert_eq!(*a.value(), 1.3); + assert_float_eq!(*a.value(), 1.3); }); } @@ -625,12 +647,12 @@ mod tests { let a = guard.var(3.0); let b = guard.var(4.0); let c = a.add(&b); - assert_eq!(*c.value(), 7.0); + assert_float_eq!(*c.value(), 7.0); let (_, grads) = guard.lock().collapse(); let dc = c.deltas(&grads); // df/da = 1, df/db = 1 - assert_eq!(dc[&a], 1.0); - assert_eq!(dc[&b], 1.0); + assert_float_eq!(dc[&a], 1.0); + assert_float_eq!(dc[&b], 1.0); }); } @@ -640,11 +662,11 @@ mod tests { tape.scope(|guard| { let a = guard.var(3.0); let c = a.add_f64(5.0); - assert_eq!(*c.value(), 8.0); + assert_float_eq!(*c.value(), 8.0); let (_, grads) = guard.lock().collapse(); let dc = c.deltas(&grads); // df/da = 1 - assert_eq!(dc[&a], 1.0); + assert_float_eq!(dc[&a], 1.0); }); } @@ -655,12 +677,12 @@ mod tests { let a = guard.var(7.0); let b = guard.var(4.0); let c = a.sub(&b); - assert_eq!(*c.value(), 3.0); + assert_float_eq!(*c.value(), 3.0); let (_, grads) = guard.lock().collapse(); let dc = c.deltas(&grads); // df/da = 1, df/db = -1 - assert_eq!(dc[&a], 1.0); - assert_eq!(dc[&b], -1.0); + assert_float_eq!(dc[&a], 1.0); + assert_float_eq!(dc[&b], -1.0); }); } @@ -670,11 +692,11 @@ mod tests { tape.scope(|guard| { let a = guard.var(7.0); let c = a.sub_f64(3.0); - assert_eq!(*c.value(), 4.0); + assert_float_eq!(*c.value(), 4.0); let (_, grads) = guard.lock().collapse(); let dc = c.deltas(&grads); // df/da = 1 - assert_eq!(dc[&a], 1.0); + assert_float_eq!(dc[&a], 1.0); }); } @@ -685,12 +707,12 @@ mod tests { let a = guard.var(3.0); let b = guard.var(4.0); let c = a.mul(&b); - assert_eq!(*c.value(), 12.0); + assert_float_eq!(*c.value(), 12.0); let (_, grads) = guard.lock().collapse(); let dc = c.deltas(&grads); // df/da = b, df/db = a - assert_eq!(dc[&a], 4.0); - assert_eq!(dc[&b], 3.0); + assert_float_eq!(dc[&a], 4.0); + assert_float_eq!(dc[&b], 3.0); }); } @@ -700,11 +722,11 @@ mod tests { tape.scope(|guard| { let a = guard.var(3.0); let c = a.mul_f64(5.0); - assert_eq!(*c.value(), 15.0); + assert_float_eq!(*c.value(), 15.0); let (_, grads) = guard.lock().collapse(); let dc = c.deltas(&grads); // df/da = 5 - assert_eq!(dc[&a], 5.0); + assert_float_eq!(dc[&a], 5.0); }); } @@ -715,12 +737,12 @@ mod tests { let a = guard.var(6.0); let b = guard.var(3.0); let c = a.div(&b); - assert_eq!(*c.value(), 2.0); + assert_float_eq!(*c.value(), 2.0); let (_, grads) = guard.lock().collapse(); let dc = c.deltas(&grads); // df/da = 1/b, df/db = -a/b^2 - assert_eq!(dc[&a], 1.0 / 3.0); - assert_eq!(dc[&b], -6.0 / 9.0); + assert_float_eq!(dc[&a], 1.0 / 3.0); + assert_float_eq!(dc[&b], -6.0 / 9.0); }); } @@ -730,11 +752,11 @@ mod tests { tape.scope(|guard| { let a = guard.var(6.0); let c = a.div_f64(2.0); - assert_eq!(*c.value(), 3.0); + assert_float_eq!(*c.value(), 3.0); let (_, grads) = guard.lock().collapse(); let dc = c.deltas(&grads); // df/da = 1/2 - assert_eq!(dc[&a], 0.5); + assert_float_eq!(dc[&a], 0.5); }); } @@ -745,13 +767,13 @@ mod tests { let a = guard.var(2.0); let b = guard.var(3.0); let c = a.pow(&b); - assert_eq!(*c.value(), 8.0); + assert_float_eq!(*c.value(), 8.0); let (_, grads) = guard.lock().collapse(); let dc = c.deltas(&grads); // df/da = b * a^(b-1) // df/db = a^b * ln(a) - assert_eq!(dc[&a], 3.0 * 4.0); - assert_eq!(dc[&b], 8.0 * 2.0f64.ln()); + assert_float_eq!(dc[&a], 3.0 * 4.0); + assert_float_eq!(dc[&b], 8.0 * 2.0f64.ln()); }); } @@ -761,11 +783,11 @@ mod tests { tape.scope(|guard| { let a = guard.var(2.0); let c = a.powf(3.0); - assert_eq!(*c.value(), 8.0); + assert_float_eq!(*c.value(), 8.0); let (_, grads) = guard.lock().collapse(); let dc = c.deltas(&grads); // df/da = 3 * a^(3-1) - assert_eq!(dc[&a], 12.0); + assert_float_eq!(dc[&a], 12.0); }); } @@ -775,11 +797,11 @@ mod tests { tape.scope(|guard| { let a = guard.var(2.0); let neg_a = a.neg(); - assert_eq!(*neg_a.value(), -2.0); + assert_float_eq!(*neg_a.value(), -2.0); let (_, grads) = guard.lock().collapse(); let dneg_a = neg_a.deltas(&grads); // df/da = -1 - assert_eq!(dneg_a[&a], -1.0); + assert_float_eq!(dneg_a[&a], -1.0); }); } @@ -789,11 +811,11 @@ mod tests { tape.scope(|guard| { let a = guard.var(1.3); let b = a.reciprocal(); - assert_eq!(*b.value(), 1.0 / 1.3); + assert_float_eq!(*b.value(), 1.0 / 1.3); let (_, grads) = guard.lock().collapse(); let db = b.deltas(&grads); // df/da = -1/a^2 - assert_eq!(db[&a], -1.0 / (1.3 * 1.3)); + assert_float_eq!(db[&a], -1.0 / (1.3 * 1.3)); }); } @@ -803,11 +825,11 @@ mod tests { tape.scope(|guard| { let a = guard.var(1.3); let b = a.sin(); - assert_eq!(*b.value(), 1.3f64.sin()); + assert_float_eq!(*b.value(), 1.3f64.sin()); let (_, grads) = guard.lock().collapse(); let db = b.deltas(&grads); // df/da = cos(a) - assert_eq!(db[&a], 1.3f64.cos()); + assert_float_eq!(db[&a], 1.3f64.cos()); }); } @@ -817,11 +839,11 @@ mod tests { tape.scope(|guard| { let a = guard.var(3.1); let b = a.cos(); - assert_eq!(*b.value(), 3.1f64.cos()); + assert_float_eq!(*b.value(), 3.1f64.cos()); let (_, grads) = guard.lock().collapse(); let db = b.deltas(&grads); // df/da = -sin(a) - assert_eq!(db[&a], -3.1f64.sin()); + assert_float_eq!(db[&a], -3.1f64.sin()); }); } @@ -831,11 +853,11 @@ mod tests { tape.scope(|guard| { let a = guard.var(5.6); let b = a.tan(); - assert_eq!(*b.value(), 5.6f64.tan()); + assert_float_eq!(*b.value(), 5.6f64.tan()); let (_, grads) = guard.lock().collapse(); let db = b.deltas(&grads); // df/da = sec^2(a) - assert_eq!(db[&a], 1.0 / (5.6f64.cos().powi(2))); + assert_float_eq!(db[&a], 1.0 / (5.6f64.cos().powi(2))); }); } @@ -845,11 +867,11 @@ mod tests { tape.scope(|guard| { let a = guard.var(5.6); let b = a.ln(); - assert_eq!(*b.value(), 5.6f64.ln()); + assert_float_eq!(*b.value(), 5.6f64.ln()); let (_, grads) = guard.lock().collapse(); let db = b.deltas(&grads); // df/da = 1/a - assert_eq!(db[&a], 1.0 / 5.6); + assert_float_eq!(db[&a], 1.0 / 5.6); }); } @@ -860,11 +882,11 @@ mod tests { let a = guard.var(5.6); let base = 3.0; let b = a.log(base); - assert_eq!(*b.value(), 5.6f64.log(base)); + assert_float_eq!(*b.value(), 5.6f64.log(base)); let (_, grads) = guard.lock().collapse(); let db = b.deltas(&grads); // df/da = 1/(a ln(b)) - assert_eq!(db[&a], 1.0 / (5.6 * base.ln())); + assert_float_eq!(db[&a], 1.0 / (5.6 * base.ln())); }); } @@ -874,11 +896,11 @@ mod tests { tape.scope(|guard| { let a = guard.var(1.0); let b = a.log10(); - assert_eq!(*b.value(), 0.0); + assert_float_eq!(*b.value(), 0.0); let (_, grads) = guard.lock().collapse(); let db = b.deltas(&grads); // df/da = 1/(a ln(10)) - assert_eq!(db[&a], 1.0 / 10.0f64.ln()); + assert_float_eq!(db[&a], 1.0 / 10.0f64.ln()); }); } @@ -888,11 +910,11 @@ mod tests { tape.scope(|guard| { let a = guard.var(2.0); let b = a.log2(); - assert_eq!(*b.value(), 1.0); + assert_float_eq!(*b.value(), 1.0); let (_, grads) = guard.lock().collapse(); let db = b.deltas(&grads); // df/da = 1/(a ln(2)) - assert_eq!(db[&a], 1.0 / (2.0 * 2.0f64.ln())); + assert_float_eq!(db[&a], 1.0 / (2.0 * 2.0f64.ln())); }); } @@ -902,11 +924,11 @@ mod tests { tape.scope(|guard| { let a = guard.var(0.5); let b = a.asin(); - assert_eq!(*b.value(), 0.5f64.asin()); + assert_float_eq!(*b.value(), 0.5f64.asin()); let (_, grads) = guard.lock().collapse(); let db = b.deltas(&grads); // df/da = 1/sqrt(1-a^2) - assert_eq!(db[&a], 1.0 / f64::sqrt(1.0 - 0.25)); + assert_float_eq!(db[&a], 1.0 / f64::sqrt(1.0 - 0.25)); }); } @@ -916,11 +938,11 @@ mod tests { tape.scope(|guard| { let a = guard.var(0.5); let b = a.acos(); - assert_eq!(*b.value(), 0.5f64.acos()); + assert_float_eq!(*b.value(), 0.5f64.acos()); let (_, grads) = guard.lock().collapse(); let db = b.deltas(&grads); // df/da = -1/sqrt(1-a^2) - assert_eq!(db[&a], -1.0 / f64::sqrt(1.0 - 0.25)); + assert_float_eq!(db[&a], -1.0 / f64::sqrt(1.0 - 0.25)); }); } @@ -930,11 +952,11 @@ mod tests { tape.scope(|guard| { let a = guard.var(1.2); let b = a.atan(); - assert_eq!(*b.value(), 1.2f64.atan()); + assert_float_eq!(*b.value(), 1.2f64.atan()); let (_, grads) = guard.lock().collapse(); let db = b.deltas(&grads); // df/da = 1/(1+a^2) - assert_eq!(db[&a], 1.0 / (1.0 + 1.44)); + assert_float_eq!(db[&a], 1.0 / (1.0 + 1.44)); }); } @@ -944,11 +966,11 @@ mod tests { tape.scope(|guard| { let a = guard.var(1.3); let b = a.sinh(); - assert_eq!(*b.value(), 1.3f64.sinh()); + assert_float_eq!(*b.value(), 1.3f64.sinh()); let (_, grads) = guard.lock().collapse(); let db = b.deltas(&grads); // df/da = cosh(a) - assert_eq!(db[&a], 1.3f64.cosh()); + assert_float_eq!(db[&a], 1.3f64.cosh()); }); } @@ -958,11 +980,11 @@ mod tests { tape.scope(|guard| { let a = guard.var(1.3); let b = a.cosh(); - assert_eq!(*b.value(), 1.3f64.cosh()); + assert_float_eq!(*b.value(), 1.3f64.cosh()); let (_, grads) = guard.lock().collapse(); let db = b.deltas(&grads); // df/da = sinh(a) - assert_eq!(db[&a], 1.3f64.sinh()); + assert_float_eq!(db[&a], 1.3f64.sinh()); }); } @@ -972,11 +994,11 @@ mod tests { tape.scope(|guard| { let a = guard.var(0.8); let b = a.tanh(); - assert_eq!(*b.value(), 0.8f64.tanh()); + assert_float_eq!(*b.value(), 0.8f64.tanh()); let (_, grads) = guard.lock().collapse(); let db = b.deltas(&grads); // df/da = 1/cosh^2(a) - assert_eq!(db[&a], 1.0 / (0.8f64.cosh() * 0.8f64.cosh())); + assert_float_eq!(db[&a], 1.0 / (0.8f64.cosh() * 0.8f64.cosh())); }); } @@ -986,11 +1008,11 @@ mod tests { tape.scope(|guard| { let a = guard.var(1.5); let b = a.asinh(); - assert_eq!(*b.value(), 1.5f64.asinh()); + assert_float_eq!(*b.value(), 1.5f64.asinh()); let (_, grads) = guard.lock().collapse(); let db = b.deltas(&grads); // df/da = 1/sqrt(1+a^2) - assert_eq!(db[&a], 1.0 / f64::sqrt(1.0 + 2.25)); + assert_float_eq!(db[&a], 1.0 / f64::sqrt(1.0 + 2.25)); }); } @@ -1000,11 +1022,11 @@ mod tests { tape.scope(|guard| { let a = guard.var(2.0); let b = a.acosh(); - assert_eq!(*b.value(), 2.0f64.acosh()); + assert_float_eq!(*b.value(), 2.0f64.acosh()); let (_, grads) = guard.lock().collapse(); let db = b.deltas(&grads); // df/da = 1/sqrt(a^2-1) - assert_eq!(db[&a], 1.0 / f64::sqrt(4.0 - 1.0)); + assert_float_eq!(db[&a], 1.0 / f64::sqrt(4.0 - 1.0)); }); } @@ -1014,11 +1036,11 @@ mod tests { tape.scope(|guard| { let a = guard.var(0.0); let b = a.atanh(); - assert_eq!(*b.value(), 0.0); + assert_float_eq!(*b.value(), 0.0); let (_, grads) = guard.lock().collapse(); let db = b.deltas(&grads); // df/da = 1/(1-a^2) - assert_eq!(db[&a], 1.0); + assert_float_eq!(db[&a], 1.0); }); } @@ -1028,11 +1050,11 @@ mod tests { tape.scope(|guard| { let a = guard.var(1.3); let b = a.exp(); - assert_eq!(*b.value(), 1.3f64.exp()); + assert_float_eq!(*b.value(), 1.3f64.exp()); let (_, grads) = guard.lock().collapse(); let db = b.deltas(&grads); // df/da = exp(a) - assert_eq!(db[&a], 1.3f64.exp()); + assert_float_eq!(db[&a], 1.3f64.exp()); }); } @@ -1042,11 +1064,11 @@ mod tests { tape.scope(|guard| { let a = guard.var(1.3); let b = a.exp2(); - assert_eq!(*b.value(), 1.3f64.exp2()); + assert_float_eq!(*b.value(), 1.3f64.exp2()); let (_, grads) = guard.lock().collapse(); let db = b.deltas(&grads); // df/da = exp2(a) * ln(2) - assert_eq!(db[&a], 1.3f64.exp2() * 2.0f64.ln()); + assert_float_eq!(db[&a], 1.3f64.exp2() * 2.0f64.ln()); }); } @@ -1056,11 +1078,11 @@ mod tests { tape.scope(|guard| { let a = guard.var(4.0); let b = a.sqrt(); - assert_eq!(*b.value(), 2.0); + assert_float_eq!(*b.value(), 2.0); let (_, grads) = guard.lock().collapse(); let db = b.deltas(&grads); // df/da = 1/(2*sqrt(a)) - assert_eq!(db[&a], 0.25); + assert_float_eq!(db[&a], 0.25); }); } @@ -1070,11 +1092,11 @@ mod tests { tape.scope(|guard| { let a = guard.var(1.0); let b = a.cbrt(); - assert_eq!(*b.value(), 1.0); + assert_float_eq!(*b.value(), 1.0); let (_, grads) = guard.lock().collapse(); let db = b.deltas(&grads); // df/da = 1/(3*cbrt(a^2)) - assert_eq!(db[&a], 1.0 / 3.0); + assert_float_eq!(db[&a], 1.0 / 3.0); }); } @@ -1084,11 +1106,11 @@ mod tests { tape.scope(|guard| { let a = guard.var(-3.0); let b = a.abs(); - assert_eq!(*b.value(), 3.0); + assert_float_eq!(*b.value(), 3.0); let (_, grads) = guard.lock().collapse(); let db = b.deltas(&grads); // df/da = a/|a| = -1 for a < 0 - assert_eq!(db[&a], -1.0); + assert_float_eq!(db[&a], -1.0); }); } } @@ -1105,15 +1127,16 @@ mod tests { let c = guard.var(1.0); let res = a.pow(&b).sub(&c.asinh().div_f64(2.0)).add_f64(1.0f64.sin()); let expected = 5.0f64.powf(2.0) - 1.0f64.asinh() / 2.0 + 1.0f64.sin(); - assert_eq!(*res.value(), expected); + // relax equality due to accumulated floating point errors... + assert_float_eq!(*res.value(), expected, 1e-9); let (_, grads) = guard.lock().collapse(); let dres = res.deltas(&grads); let ga = dres[&a]; // df/da let gb = dres[&b]; // df/db let gc = dres[&c]; // df/dc - assert_eq!(ga, 2.0 * 5.0); - assert_eq!(gb, 25.0 * 5.0f64.ln()); - assert_eq!(gc, -1.0 / (2.0 * 2.0f64.sqrt())); + assert_float_eq!(ga, 2.0 * 5.0, 1e-9); + assert_float_eq!(gb, 25.0 * 5.0f64.ln(), 1e-9); + assert_float_eq!(gc, -1.0 / (2.0 * 2.0f64.sqrt()), 1e-9); }); } } diff --git a/examples/simple.rs b/examples/simple.rs index 8b9fdf3..8097e14 100644 --- a/examples/simple.rs +++ b/examples/simple.rs @@ -14,4 +14,4 @@ fn main() { let wrt_y = y.deltas(&grads); println!("Value: {}, dy/dx: {}", y.value(), wrt_y[&x]); }); -} \ No newline at end of file +}