Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 8 additions & 8 deletions ot/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,16 +287,16 @@ def from_numpy(self, a, type_as=None):
return jnp.array(a).astype(type_as.dtype)

def set_gradients(self, val, inputs, grads):
# no gradients for jax because it is functional
from jax.flatten_util import ravel_pytree
val, = jax.lax.stop_gradient((val,))

# does not work
# from jax import custom_jvp
# @custom_jvp
# def f(*inputs):
# return val
# f.defjvps(*grads)
# return f(*inputs)
ravelled_inputs, _ = ravel_pytree(inputs)
ravelled_grads, _ = ravel_pytree(grads)

aux = jnp.sum(ravelled_inputs * ravelled_grads) / 2
aux = aux - jax.lax.stop_gradient(aux)

val, = jax.tree_map(lambda z: z + aux, (val,))
return val

def zeros(self, shape, type_as=None):
Expand Down
15 changes: 14 additions & 1 deletion test/test_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,7 +345,8 @@ def test_gradients_backends():

rnd = np.random.RandomState(0)
v = rnd.randn(10)
c = rnd.randn(1)
c = rnd.randn()
e = rnd.randn()

if torch:

Expand All @@ -362,3 +363,15 @@ def test_gradients_backends():

assert torch.equal(v2.grad, v2)
assert torch.equal(c2.grad, c2)

if jax:
nx = ot.backend.JaxBackend()
with jax.checking_leaks():
def fun(a, b, d):
val = b * nx.sum(a ** 4) + d
return nx.set_gradients(val, (a, b, d), (a, b, 2 * d))
grad_val = jax.grad(fun, argnums=(0, 1, 2))(v, c, e)

np.testing.assert_almost_equal(fun(v, c, e), c * np.sum(v ** 4) + e, decimal=4)
np.testing.assert_allclose(grad_val[0], v, atol=1e-4)
np.testing.assert_allclose(grad_val[2], 2 * e, atol=1e-4)