Skip to content

Commit

Permalink
Add some simple test cases (#15)
Browse files Browse the repository at this point in the history
  • Loading branch information
sangyx committed May 20, 2020
1 parent 57594f1 commit 015b4de
Showing 1 changed file with 52 additions and 0 deletions.
52 changes: 52 additions & 0 deletions src/tests/test_builtin_diffs.py
Expand Up @@ -171,3 +171,55 @@ def test_arbitrary_function(backend, func, y_d, domain):
ret.compute()

assert_allclose(ret.diffs[x].arr, y_d_arr.arr)

@pytest.mark.xfail
@pytest.mark.parametrize(
"u, diff_ndim, func, diff_u",
[
(onp.arange(24).reshape(2, 3, 4, 1), 2, lambda x: x, onp.tile(onp.eye(4), (2, 3, 1)).reshape(2, 3, 4, 1, 4, 1)),
(onp.arange(12).reshape(2, 3, 2), 2, lambda x: np.sum(x, axis=1), onp.tile([[1, 0], [0, 1]], (2, 3)).reshape(2, 1, 2, 3, 2)),
(onp.arange(12).reshape((2, 3, 2)), 2, lambda x: x, onp.stack([onp.eye(6).reshape(3, 2, 3, 2), onp.eye(6).reshape(3, 2, 3, 2)]))
],
)
def test_separation_unary(backend, u, diff_ndim, func, diff_u):
try:
with ua.set_backend(backend), ua.set_backend(udiff, coerce=True):
u = np.asarray(u)
u.var = udiff.Variable('u', diff_ndim=diff_ndim)
ret = func(u)
except ua.BackendNotImplementedError:
if backend in FULLY_TESTED_BACKENDS:
raise
pytest.xfail(reason="The backend has no implementation for this ufunc.")

if isinstance(ret, da.Array):
ret.compute()

assert_allclose(ret.diffs[u].arr, diff_u.tolist())

@pytest.mark.xfail
@pytest.mark.parametrize(
"u, v, u_diff_ndim, v_diff_ndim, func, diff_u, diff_v",
[
(onp.arange(2).reshape(1, 2, 1), onp.arange(2).reshape(1, 1, 2), 2, 2, lambda x, y: np.matmul(x, y), onp.array([[0, 0, 1, 0], [0, 0, 0, 1]]).reshape(1, 2, 2, 2, 1), onp.array([[0, 0, 0, 0], [1, 0, 0, 1]]).reshape(1, 2, 2, 1, 2))
],
)
def test_separation_binary(backend, u, v, u_diff_ndim, v_diff_ndim, func, diff_u, diff_v):
try:
with ua.set_backend(backend), ua.set_backend(udiff, coerce=True):
u = np.asarray(u)
u.var = udiff.Variable('u', diff_ndim=u_diff_ndim)
v = np.asarray(v)
v.var = udiff.Variable('v', diff_ndim=v_diff_ndim)

ret = func(u, v)
except ua.BackendNotImplementedError:
if backend in FULLY_TESTED_BACKENDS:
raise
pytest.xfail(reason="The backend has no implementation for this ufunc.")

if isinstance(ret, da.Array):
ret.compute()

assert_allclose(ret.diffs[u].arr, diff_u.tolist())
assert_allclose(ret.diffs[v].arr, diff_v.tolist())

0 comments on commit 015b4de

Please sign in to comment.