diff --git a/test/mooncake.jl b/test/mooncake.jl index 7aed68cf..3e19e44d 100644 --- a/test/mooncake.jl +++ b/test/mooncake.jl @@ -22,7 +22,7 @@ make_mooncake_tangent(T::Tuple) = Mooncake.build_tangent(typeof(T), T...) make_mooncake_fdata(x) = make_mooncake_tangent(x) make_mooncake_fdata(x::Diagonal) = Mooncake.FData((diag = make_mooncake_tangent(x.diag),)) -ETs = (Float64, Float32, ComplexF64, ComplexF32) +ETs = (Float32, ComplexF64) # no `alg` argument function _get_copying_derivative(f_c, rrule, A, ΔA, args, Δargs, ::Nothing, rdata) @@ -251,7 +251,10 @@ end dV = make_mooncake_tangent(ΔV) dDV = Mooncake.build_tangent(typeof((ΔD2, ΔV)), dD, dV) # compute the dA corresponding to the above dD, dV - @testset for alg in (LAPACK_Simple(), LAPACK_Expert()) + @testset for alg in ( + LAPACK_Simple(), + #LAPACK_Expert(), # expensive on CI + ) @testset "eig_full" begin Mooncake.TestUtils.test_rule(rng, eig_full, A, alg; mode = Mooncake.ReverseMode, output_tangent = dDV, is_primitive = false, atol = atol, rtol = rtol) test_pullbacks_match(rng, eig_full!, eig_full, A, (D, V), (ΔD2, ΔV), alg) @@ -340,9 +343,9 @@ MatrixAlgebraKit.copy_input(::typeof(copy_eigh_trunc), A) = MatrixAlgebraKit.cop Ddiag = diagview(D) @testset for alg in ( LAPACK_QRIteration(), - LAPACK_DivideAndConquer(), - LAPACK_Bisection(), - LAPACK_MultipleRelativelyRobustRepresentations(), + #LAPACK_DivideAndConquer(), + #LAPACK_Bisection(), + #LAPACK_MultipleRelativelyRobustRepresentations(), # expensive on CI ) @testset "eigh_full" begin Mooncake.TestUtils.test_rule(rng, copy_eigh_full, A, alg; mode = Mooncake.ReverseMode, output_tangent = dDV, is_primitive = false, atol = atol, rtol = rtol) @@ -390,7 +393,7 @@ end minmn = min(m, n) @testset for alg in ( LAPACK_QRIteration(), - LAPACK_DivideAndConquer(), + #LAPACK_DivideAndConquer(), # expensive on CI ) @testset "svd_compact" begin ΔU = randn(rng, T, m, minmn) @@ -490,7 +493,12 @@ end @testset "size ($m, $n)" for n in (17, m, 23) atol = rtol = m * n * precision(T) A = randn(rng, T, m, n) - @testset for alg in PolarViaSVD.((LAPACK_QRIteration(), LAPACK_DivideAndConquer())) + @testset for alg in PolarViaSVD.( + ( + LAPACK_QRIteration(), + #LAPACK_DivideAndConquer(), # expensive on CI + ) + ) if m >= n WP = left_polar(A, alg) Mooncake.TestUtils.test_rule(rng, left_polar, A, alg; mode = Mooncake.ReverseMode, is_primitive = false, atol = atol, rtol = rtol)