Skip to content

Commit

Permalink
add test cases
Browse files Browse the repository at this point in the history
  • Loading branch information
sangyx committed Aug 24, 2020
1 parent 84d35ff commit 9f7e3b5
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 10 deletions.
40 changes: 30 additions & 10 deletions src/tests/test_diff.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,12 @@ def backend(request):
return backend


@pytest.fixture(scope="session", params=["vjp", "jvp"])
def mode(request):
mode = request.param
return mode


def generate_test_data(n_elements=12, a=None, b=None):
if a is None:
a = -10
Expand Down Expand Up @@ -143,14 +149,14 @@ def grad_check_sparse(f, x, analytic_grad, num_checks=10, h=1e-5):
(np.deg2rad, lambda x: pi / 180.0, None),
],
)
def test_unary_function(backend, func, y_d, domain):
def test_unary_function(backend, mode, func, y_d, domain):
if domain is None:
x_arr = generate_test_data()
else:
x_arr = generate_test_data(a=domain[0], b=domain[1])
expect_diff = [y_d(xa) for xa in x_arr]
try:
with ua.set_backend(udiff.DiffArrayBackend(backend), coerce=True):
with ua.set_backend(udiff.DiffArrayBackend(backend, mode=mode), coerce=True):
x = np.asarray(x_arr)
y = func(x)
x_diff = y.to(x)
Expand All @@ -162,7 +168,10 @@ def test_unary_function(backend, func, y_d, domain):
if isinstance(y, da.Array):
y.compute()

assert_allclose(x_diff.value, expect_diff)
if mode == "vjp":
assert_allclose(x_diff.value, expect_diff)
elif mode == "jvp":
assert_allclose(x_diff, expect_diff)


@pytest.mark.parametrize(
Expand Down Expand Up @@ -231,7 +240,7 @@ def test_unary_function(backend, func, y_d, domain):
),
],
)
def test_binary_function(backend, func, u_d, v_d, u_domain, v_domain):
def test_binary_function(backend, mode, func, u_d, v_d, u_domain, v_domain):
if u_domain is None:
u_arr = generate_test_data()
else:
Expand All @@ -244,7 +253,7 @@ def test_binary_function(backend, func, u_d, v_d, u_domain, v_domain):
expect_u_diff = [u_d(ua, va) for ua, va in zip(u_arr, v_arr)]
expect_v_diff = [v_d(ua, va) for ua, va in zip(u_arr, v_arr)]
try:
with ua.set_backend(udiff.DiffArrayBackend(backend), coerce=True):
with ua.set_backend(udiff.DiffArrayBackend(backend, mode=mode), coerce=True):
u = np.asarray(u_arr)
v = np.asarray(v_arr)
y = func(u, v)
Expand All @@ -254,12 +263,20 @@ def test_binary_function(backend, func, u_d, v_d, u_domain, v_domain):
if backend in FULLY_TESTED_BACKENDS:
raise
pytest.xfail(reason="The backend has no implementation for this ufunc.")
except NotImplementedError:
pytest.xfail(
reason="The func has no implementation in the {} mode.".format(mode)
)

if isinstance(y, da.Array):
y.compute()

assert_allclose(u_diff.value, expect_u_diff)
assert_allclose(v_diff.value, expect_v_diff)
if mode == "vjp":
assert_allclose(u_diff.value, expect_u_diff)
assert_allclose(v_diff.value, expect_v_diff)
elif mode == "jvp":
assert_allclose(u_diff, expect_u_diff)
assert_allclose(v_diff, expect_v_diff)


@pytest.mark.parametrize(
Expand Down Expand Up @@ -293,14 +310,14 @@ def test_binary_function(backend, func, u_d, v_d, u_domain, v_domain):
),
],
)
def test_arbitrary_function(backend, func, y_d, domain):
def test_arbitrary_function(backend, mode, func, y_d, domain):
if domain is None:
x_arr = generate_test_data()
else:
x_arr = generate_test_data(a=domain[0], b=domain[1])
expect_diff = [y_d(xa) for xa in x_arr]
try:
with ua.set_backend(udiff.DiffArrayBackend(backend), coerce=True):
with ua.set_backend(udiff.DiffArrayBackend(backend, mode=mode), coerce=True):
x = np.asarray(x_arr)
y = func(x)
x_diff = y.to(x)
Expand All @@ -312,7 +329,10 @@ def test_arbitrary_function(backend, func, y_d, domain):
if isinstance(y, da.Array):
y.compute()

assert_allclose(x_diff.value, expect_diff)
if mode == "vjp":
assert_allclose(x_diff.value, expect_diff)
elif mode == "jvp":
assert_allclose(x_diff, expect_diff)


# @pytest.mark.skip
Expand Down
2 changes: 2 additions & 0 deletions src/udiff/_jvp_diffs.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,8 @@
def_linear(np.cross)

# ----- Simple grads -----
defjvp(np.positive, lambda g, ans, x: np.ones_like(x) * g)
defjvp(np.negative, lambda g, ans, x: -np.ones_like(x) * g)
defjvp(np.fabs, lambda g, ans, x: np.sign(x) * g) # fabs doesn't take complex numbers.
defjvp(np.absolute, lambda g, ans, x: np.real(g * np.conj(x)) / ans)
defjvp(np.reciprocal, lambda g, ans, x: -g / x ** 2)
Expand Down

0 comments on commit 9f7e3b5

Please sign in to comment.