diff --git a/test/mooncake/polar.jl b/test/mooncake/polar.jl index 6442b7b33..ea7509426 100644 --- a/test/mooncake/polar.jl +++ b/test/mooncake/polar.jl @@ -13,14 +13,23 @@ is_buildkite = get(ENV, "BUILDKITE", "false") == "true" m = 19 for T in (BLASFloats..., GenericFloats...), n in (17, m, 23) TestSuite.seed_rng!(1234) + atol = rtol = m * n * TestSuite.precision(T) if !is_buildkite - atol = rtol = m * n * TestSuite.precision(T) m >= n && TestSuite.test_mooncake_left_polar(T, (m, n); atol, rtol) n >= m && TestSuite.test_mooncake_right_polar(T, (m, n); atol, rtol) #=if m == n AT = Diagonal{T, Vector{T}} - TestSuite.test_mooncake_left_polar(AT, m; atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T)) - TestSuite.test_mooncake_right_polar(AT, m; atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T)) - end=# # broken due to pullback + TestSuite.test_mooncake_left_polar(AT, m; atol, rtol) + TestSuite.test_mooncake_right_polar(AT, m; atol, rtol) + end=# + end + if T ∈ BLASFloats && CUDA.functional() + m >= n && TestSuite.test_mooncake_left_polar(CuMatrix{T}, (m, n); atol, rtol) + n >= m && TestSuite.test_mooncake_right_polar(CuMatrix{T}, (m, n); atol, rtol) + #=if m == n + AT = Diagonal{T, CuVector{T}} + TestSuite.test_mooncake_left_polar(AT, m; atol, rtol) + TestSuite.test_mooncake_right_polar(AT, m; atol, rtol) + end=# end end