From 015b4deaee6c57cad4a93adad6fd74d071711e7a Mon Sep 17 00:00:00 2001 From: ethan <32364921+sangyx@users.noreply.github.com> Date: Wed, 20 May 2020 18:20:46 +0800 Subject: [PATCH] Add some simple test cases (#15) --- src/tests/test_builtin_diffs.py | 52 +++++++++++++++++++++++++++++++++ 1 file changed, 52 insertions(+) diff --git a/src/tests/test_builtin_diffs.py b/src/tests/test_builtin_diffs.py index 41f80f4..f9efe26 100644 --- a/src/tests/test_builtin_diffs.py +++ b/src/tests/test_builtin_diffs.py @@ -171,3 +171,55 @@ def test_arbitrary_function(backend, func, y_d, domain): ret.compute() assert_allclose(ret.diffs[x].arr, y_d_arr.arr) + +@pytest.mark.xfail +@pytest.mark.parametrize( + "u, diff_ndim, func, diff_u", + [ + (onp.arange(24).reshape(2, 3, 4, 1), 2, lambda x: x, onp.tile(onp.eye(4), (2, 3, 1)).reshape(2, 3, 4, 1, 4, 1)), + (onp.arange(12).reshape(2, 3, 2), 2, lambda x: np.sum(x, axis=1), onp.tile([[1, 0], [0, 1]], (2, 3)).reshape(2, 1, 2, 3, 2)), + (onp.arange(12).reshape((2, 3, 2)), 2, lambda x: x, onp.stack([onp.eye(6).reshape(3, 2, 3, 2), onp.eye(6).reshape(3, 2, 3, 2)])) + ], +) +def test_separation_unary(backend, u, diff_ndim, func, diff_u): + try: + with ua.set_backend(backend), ua.set_backend(udiff, coerce=True): + u = np.asarray(u) + u.var = udiff.Variable('u', diff_ndim=diff_ndim) + ret = func(u) + except ua.BackendNotImplementedError: + if backend in FULLY_TESTED_BACKENDS: + raise + pytest.xfail(reason="The backend has no implementation for this ufunc.") + + if isinstance(ret, da.Array): + ret.compute() + + assert_allclose(ret.diffs[u].arr, diff_u.tolist()) + +@pytest.mark.xfail +@pytest.mark.parametrize( + "u, v, u_diff_ndim, v_diff_ndim, func, diff_u, diff_v", + [ + (onp.arange(2).reshape(1, 2, 1), onp.arange(2).reshape(1, 1, 2), 2, 2, lambda x, y: np.matmul(x, y), onp.array([[0, 0, 1, 0], [0, 0, 0, 1]]).reshape(1, 2, 2, 2, 1), onp.array([[0, 0, 0, 0], [1, 0, 0, 1]]).reshape(1, 2, 2, 1, 2)) + ], +) +def test_separation_binary(backend, u, v, u_diff_ndim, v_diff_ndim, func, diff_u, diff_v): + try: + with ua.set_backend(backend), ua.set_backend(udiff, coerce=True): + u = np.asarray(u) + u.var = udiff.Variable('u', diff_ndim=u_diff_ndim) + v = np.asarray(v) + v.var = udiff.Variable('v', diff_ndim=v_diff_ndim) + + ret = func(u, v) + except ua.BackendNotImplementedError: + if backend in FULLY_TESTED_BACKENDS: + raise + pytest.xfail(reason="The backend has no implementation for this ufunc.") + + if isinstance(ret, da.Array): + ret.compute() + + assert_allclose(ret.diffs[u].arr, diff_u.tolist()) + assert_allclose(ret.diffs[v].arr, diff_v.tolist()) \ No newline at end of file