In [21]:
import jax.numpy as jnp
from jax import custom_jvp
from jax import custom_vjp
from jax import grad
import numpy as np

In [16]:
@custom_jvp
def f(x, y):
    """Return sin(x) * y; JAX differentiates it via the custom JVP rule below.

    Uses jax.numpy rather than numpy so the function also accepts traced
    values (e.g. under jax.jit, jax.vmap, or jit(grad(f))).
    """
    return jnp.sin(x) * y

@f.defjvp
def f_jvp(primals, tangents):
    """Forward-mode rule for f: d(sin(x) * y) = cos(x) * y * dx + sin(x) * dy.

    Args:
        primals: tuple (x, y) of input values.
        tangents: tuple (x_dot, y_dot) of input tangents.

    Returns:
        (primal_out, tangent_out), as required by jax.custom_jvp.
    """
    x, y = primals
    x_dot, y_dot = tangents
    # Compute sin(x) once and reuse it for both the primal and the tangent.
    sin_x = jnp.sin(x)
    primal_out = sin_x * y
    tangent_out = jnp.cos(x) * x_dot * y + sin_x * y_dot
    return primal_out, tangent_out

In [17]:
# Reverse-mode derivative of f with respect to its first argument (x).
g = grad(f, argnums=0)

In [20]:
%timeit g(1.0,2.0)

2.97 ms ± 89.6 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [10]:
def ff(x, y):
    """Return sin(x) * y using plain jax.numpy ops (no custom derivative rule)."""
    sin_x = jnp.sin(x)
    return sin_x * y

# Baseline gradient: JAX derives d/dx automatically from the jnp primitives.
gg = grad(ff, argnums=0)

In [26]:
%timeit g(1.0,2.0)

3.14 ms ± 107 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [23]:
@custom_vjp
def h(x, y):
    """Return sin(x) * y; JAX differentiates it via the custom VJP rule below.

    Uses jax.numpy rather than numpy so the function also accepts traced
    values (e.g. under jax.jit or jit(grad(h))).
    """
    return jnp.sin(x) * y

def h_fwd(x, y):
    """Forward pass for h: return the primal output and residuals for h_bwd.

    Residuals are (cos(x), sin(x), y): cos(x) * y is dh/dx and sin(x) is dh/dy.
    """
    # Compute sin(x) once; it is needed both for the output and as a residual.
    sin_x = jnp.sin(x)
    return sin_x * y, (jnp.cos(x), sin_x, y)

def h_bwd(res, g):
    """Backward pass for h: map the output cotangent to input cotangents.

    Args:
        res: residuals (cos(x), sin(x), y) produced by h_fwd.
        g: cotangent of h's output (this local name shadows the
            notebook-level `g` only inside this function).

    Returns:
        (x_bar, y_bar) = (cos(x) * y * g, sin(x) * g).
    """
    cos_x, sin_x, y = res
    return (cos_x * g * y, sin_x * g)

h.defvjp(h_fwd, h_bwd)


In [24]:
# Reverse-mode derivative of h with respect to x, routed through h_fwd/h_bwd.
gh = grad(h, argnums=0)

In [27]:
%timeit gh(1.0,2.0)

3.6 ms ± 211 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
