Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 15 additions & 7 deletions test/mooncake.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ make_mooncake_tangent(T::Tuple) = Mooncake.build_tangent(typeof(T), T...)
make_mooncake_fdata(x) = make_mooncake_tangent(x)
make_mooncake_fdata(x::Diagonal) = Mooncake.FData((diag = make_mooncake_tangent(x.diag),))

ETs = (Float64, Float32, ComplexF64, ComplexF32)
ETs = (Float32, ComplexF64)

# no `alg` argument
function _get_copying_derivative(f_c, rrule, A, ΔA, args, Δargs, ::Nothing, rdata)
Expand Down Expand Up @@ -251,7 +251,10 @@ end
dV = make_mooncake_tangent(ΔV)
dDV = Mooncake.build_tangent(typeof((ΔD2, ΔV)), dD, dV)
# compute the dA corresponding to the above dD, dV
@testset for alg in (LAPACK_Simple(), LAPACK_Expert())
@testset for alg in (
LAPACK_Simple(),
#LAPACK_Expert(), # expensive on CI
)
@testset "eig_full" begin
Mooncake.TestUtils.test_rule(rng, eig_full, A, alg; mode = Mooncake.ReverseMode, output_tangent = dDV, is_primitive = false, atol = atol, rtol = rtol)
test_pullbacks_match(rng, eig_full!, eig_full, A, (D, V), (ΔD2, ΔV), alg)
Expand Down Expand Up @@ -340,9 +343,9 @@ MatrixAlgebraKit.copy_input(::typeof(copy_eigh_trunc), A) = MatrixAlgebraKit.cop
Ddiag = diagview(D)
@testset for alg in (
LAPACK_QRIteration(),
LAPACK_DivideAndConquer(),
LAPACK_Bisection(),
LAPACK_MultipleRelativelyRobustRepresentations(),
#LAPACK_DivideAndConquer(),
#LAPACK_Bisection(),
#LAPACK_MultipleRelativelyRobustRepresentations(), # expensive on CI
)
@testset "eigh_full" begin
Mooncake.TestUtils.test_rule(rng, copy_eigh_full, A, alg; mode = Mooncake.ReverseMode, output_tangent = dDV, is_primitive = false, atol = atol, rtol = rtol)
Expand Down Expand Up @@ -390,7 +393,7 @@ end
minmn = min(m, n)
@testset for alg in (
LAPACK_QRIteration(),
LAPACK_DivideAndConquer(),
#LAPACK_DivideAndConquer(), # expensive on CI
)
@testset "svd_compact" begin
ΔU = randn(rng, T, m, minmn)
Expand Down Expand Up @@ -490,7 +493,12 @@ end
@testset "size ($m, $n)" for n in (17, m, 23)
atol = rtol = m * n * precision(T)
A = randn(rng, T, m, n)
@testset for alg in PolarViaSVD.((LAPACK_QRIteration(), LAPACK_DivideAndConquer()))
@testset for alg in PolarViaSVD.(
(
LAPACK_QRIteration(),
#LAPACK_DivideAndConquer(), # expensive on CI
)
)
if m >= n
WP = left_polar(A, alg)
Mooncake.TestUtils.test_rule(rng, left_polar, A, alg; mode = Mooncake.ReverseMode, is_primitive = false, atol = atol, rtol = rtol)
Expand Down