In [5]:
import jax.numpy as jnp
from jax import custom_jvp
from jax import custom_vjp
from jax import grad
from jax import jacrev
import numpy as np
from jax import jit
from jax import vmap

In [11]:
@custom_jvp
def f(x, y):
    """Return ``sin(x) * y``; its forward-mode derivative is supplied by ``f_jvp``."""
    return jnp.sin(x) * y


@f.defjvp
def f_jvp(primals, tangents):
    """Custom JVP rule for ``f``.

    Args:
        primals: the pair ``(x, y)`` at which ``f`` is evaluated.
        tangents: the pair ``(x_dot, y_dot)`` of input perturbations.

    Returns:
        ``(f(x, y), df/dx * x_dot + df/dy * y_dot)``.
    """
    (x, y), (x_dot, y_dot) = primals, tangents
    # Analytic partial derivatives: df/dx = y * cos(x), df/dy = sin(x).
    tangent_out = y * jnp.cos(x) * x_dot + jnp.sin(x) * y_dot
    return f(x, y), tangent_out

# Gradient of f with respect to x only (returned as a 1-tuple because argnums is a tuple).
g = grad(f, argnums=(0,))
# Gradient with respect to both arguments, mapped over axis 0 of x while y is not mapped.
gv = vmap(grad(f, argnums=(0, 1)), in_axes=(0, None), out_axes=0)

In [12]:
# Evaluate df/dx at (x, y) = (1.0, 2.0).
g(1.0, 2.0)

(DeviceArray(1.0806046, dtype=float32),)

In [13]:
# Three evenly spaced sample points on [-1, 1]; y is held at 2.0 for every point.
xv = jnp.linspace(-1, 1, num=3)
gv(xv, 2.0)

(DeviceArray([1.0806046, 2.       , 1.0806046], dtype=float32),
 DeviceArray([-0.84147096,  0.        ,  0.84147096], dtype=float32))

In [31]:
@custom_jvp
def fnp(x, y):
    """Return ``sin(x) * y`` using raw NumPy for the primal computation.

    Raw NumPy only works while ``x`` is a concrete value; it cannot consume
    JAX tracers, which is exactly what this variant is meant to demonstrate.
    """
    return np.sin(x) * y  # raw NumPy is acceptable here


@fnp.defjvp
def fnp_jvp(primals, tangents):
    """Custom JVP rule for ``fnp`` that deliberately mixes ``np`` and ``jnp``.

    Args:
        primals: the pair ``(x, y)`` at which ``fnp`` is evaluated.
        tangents: the pair ``(x_dot, y_dot)`` of input perturbations.

    Returns:
        ``(fnp(x, y), df/dx * x_dot + df/dy * y_dot)``.
    """
    x, y = primals
    x_dot, y_dot = tangents
    # np.cos is kept on purpose: it works when x is concrete (plain grad) but
    # raises once x is a tracer, e.g. under vmap.
    df1dx = y * np.cos(x)
    df1dy = jnp.sin(x)
    # The primal must come from the function this rule belongs to; the earlier
    # version called the unrelated `f` here (copy-paste slip).
    primal_out = fnp(x, y)
    tangent_out = df1dx * x_dot + df1dy * y_dot
    return primal_out, tangent_out

# Same derivative wrappers as for f, now built on the NumPy-based fnp.
gnp = grad(fnp, argnums=(0,))
gvnp = vmap(grad(fnp, argnums=(0, 1)), in_axes=(0, None), out_axes=0)

In [32]:
# Plain grad works with the NumPy-based rule at a concrete point.
gnp(1.0, 2.0)

(DeviceArray(1.0806046, dtype=float32),)

In [33]:
# With the NumPy-based version, vmap does not go through (this call raises).
xv = jnp.linspace(-1, 1, num=3)
gvnp(xv, 2.0)

Exception: The numpy.ndarray conversion method __array__() was called on the JAX Tracer object Traced<ShapedArray(float32[])>with<BatchTrace(level=1/0)>
  with val = DeviceArray([-1.,  0.,  1.], dtype=float32)
       batch_dim = 0.

This error can occur when a JAX Tracer object is passed to a raw numpy function, or a method on a numpy.ndarray object. You might want to check that you are using `jnp` together with `import jax.numpy as jnp` rather than using `np` via `import numpy as np`. If this error arises on a line that involves array indexing, like `x[idx]`, it may be that the array being indexed `x` is a raw numpy.ndarray while the indices `idx` are a JAX Tracer instance; in that case, you can instead write `jax.device_put(x)[idx]`.

In [65]:
from jax import core
from jax.interpreters import ad
import numpy as np  # I changed this name out of habit
from jax.core import Primitive
from jax.interpreters.ad import defvjp

# Define the function to be differentiated as a user-defined primitive.
foo_p = Primitive('udsin')
foo_p.def_impl(np.sin)  # concrete evaluation delegates to NumPy's sin


def udsin(x):
    """Apply the ``udsin`` primitive to ``x``."""
    return foo_p.bind(x)

def dudsin(g, x):
    """VJP rule for ``udsin``: scale the cotangent ``g`` by ``udcos(x)``."""
    cos_x = udcos(x)
    return g * cos_x
# Register dudsin as the VJP of the udsin primitive.
defvjp(foo_p, dudsin)

bar_p = Primitive('udcos')
bar_p.def_impl(np.cos)  # concrete evaluation delegates to NumPy's cos


def udcos(x):
    """Apply the ``udcos`` primitive to ``x``."""
    return bar_p.bind(x)

def dudcos(g, x):
    """VJP rule for ``udcos``: scale the negated cotangent ``-g`` by ``udsin(x)``."""
    neg_g = -g
    return neg_g * udsin(x)
# Register dudcos as the VJP of the udcos primitive.
defvjp(bar_p, dudcos)


# Sanity check: evaluate the primitive at a concrete point.
udsin(0.0)

0.0

In [112]:
#see https://jax.readthedocs.io/en/latest/notebooks/How_JAX_primitives_work.html
from jax import core
from jax.interpreters import ad
import numpy as np  
from jax.core import Primitive
from jax.interpreters.ad import defvjp
from jax.interpreters import batching
from jax import abstract_arrays
from jax.lib import xla_client
from jax.interpreters import xla

# User-defined sine primitive; concrete evaluation delegates to NumPy's sin.
foo_p = Primitive('udsin')
foo_p.def_impl(np.sin)


def udsin(x):
    """Apply the ``udsin`` primitive to ``x``."""
    return foo_p.bind(x)

def dudsin(g, x):
    """VJP rule for ``udsin``: scale the cotangent ``g`` by ``udcos(x)``."""
    cos_x = udcos(x)
    return g * cos_x
# Register dudsin as the VJP of the udsin primitive.
defvjp(foo_p, dudsin)

# User-defined cosine primitive; concrete evaluation delegates to NumPy's cos.
bar_p = Primitive('udcos')
bar_p.def_impl(np.cos)


def udcos(x):
    """Apply the ``udcos`` primitive to ``x``."""
    return bar_p.bind(x)

def dudcos(g, x):
    """VJP rule for ``udcos``: scale the negated cotangent ``-g`` by ``udsin(x)``."""
    neg_g = -g
    return neg_g * udsin(x)
# Register dudcos as the VJP of the udcos primitive.
defvjp(bar_p, dudcos)

# A batching rule is required before vmap can be applied to the primitive.
def foo_batch(vector_arg_values, batch_axes):
    """Batching rule for ``udsin``.

    Applies ``udsin`` to the batched operand(s) unchanged and reports the
    output's batch axis as the first input's batch axis.
    """
    batched_out = udsin(*vector_arg_values)
    out_axis = batch_axes[0]
    return batched_out, out_axis

# Register the rule in JAX's table of primitive batchers.
batching.primitive_batchers[foo_p] = foo_batch

def bar_batch(vector_arg_values, batch_axes):
    """Batching rule for ``udcos``.

    Applies ``udcos`` to the batched operand(s) unchanged and reports the
    output's batch axis as the first input's batch axis.
    """
    batched_out = udcos(*vector_arg_values)
    out_axis = batch_axes[0]
    return batched_out, out_axis

# Register the rule in JAX's table of primitive batchers.
batching.primitive_batchers[bar_p] = bar_batch

# An abstract evaluation rule is required before the primitives can be used under jit.
def foobar_abstract_eval(xs):
    """Abstract evaluation rule: the output has the operand's shape and dtype.

    Args:
        xs: abstract value of the operand; only ``xs.shape`` and ``xs.dtype`` are read.

    Returns:
        A ``ShapedArray`` with the same shape and dtype as ``xs``.
    """
    # ShapedArray is taken from jax.core (imported as `core` in this cell).
    # jax.abstract_arrays only re-exported this class and is deprecated,
    # so it is unavailable in newer JAX releases.
    return core.ShapedArray(xs.shape, xs.dtype)

# Register the same abstract evaluation rule on both primitives.
foo_p.def_abstract_eval(foobar_abstract_eval)
bar_p.def_abstract_eval(foobar_abstract_eval)

# An XLA compilation rule is also required before jit can be used.
# (xla_client.py is the one that comes from TensorFlow.)
def foo_xla_translation(c, xc):
    """XLA translation rule for ``udsin``: emit the XLA ``Sin`` op on ``xc``.

    The first argument ``c`` is accepted but not used.
    """
    sin_op = xla_client.ops.Sin(xc)
    return sin_op
# Registered for the 'cpu' backend only.
xla.backend_specific_translations['cpu'][foo_p] = foo_xla_translation

def bar_xla_translation(c, xc):
    """XLA translation rule for ``udcos``: emit the XLA ``Cos`` op on ``xc``.

    The first argument ``c`` is accepted but not used.
    """
    cos_op = xla_client.ops.Cos(xc)
    return cos_op
# Registered for the 'cpu' backend only.
xla.backend_specific_translations['cpu'][bar_p] = bar_xla_translation

In [113]:
# Gradient of the primitive-backed sine.
grad_udsin = grad(udsin)

In [114]:
# Vectorized and jit-compiled gradient of the primitive-backed sine.
grad_udsinv = jit(vmap(grad(udsin)))

In [117]:

# Time the jit-compiled, vmapped gradient of udsin on xv.
%timeit grad_udsinv(xv)

155 µs ± 3.78 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)


In [62]:
@custom_jvp
def fnp(x, y):
    """Return ``sin(x) * y`` using raw NumPy for the primal computation."""
    return np.sin(x) * y  # raw NumPy is acceptable here


@fnp.defjvp
def fnp_jvp(primals, tangents):
    """Custom JVP rule for ``fnp`` built on the user-defined primitives.

    Args:
        primals: the pair ``(x, y)`` at which ``fnp`` is evaluated.
        tangents: the pair ``(x_dot, y_dot)`` of input perturbations.

    Returns:
        ``(fnp(x, y), df/dx * x_dot + df/dy * y_dot)``.
    """
    x, y = primals
    x_dot, y_dot = tangents
    # Partials expressed with the udcos/udsin primitives instead of np/jnp.
    df1dx = y * udcos(x)
    df1dy = udsin(x)
    # The primal must come from the function this rule belongs to; the earlier
    # version called the unrelated `f` here (copy-paste slip).
    primal_out = fnp(x, y)
    tangent_out = df1dx * x_dot + df1dy * y_dot
    return primal_out, tangent_out

# Derivative wrappers for the primitive-backed fnp: x only, and (x, y) mapped over x.
gnp = grad(fnp, argnums=(0,))
gvnp = vmap(grad(fnp, argnums=(0, 1)), in_axes=(0, None), out_axes=0)

In [64]:
# With the np version, vmap does not go through.
# NOTE(review): the recorded error is a missing batching rule for 'udcos';
# presumably this cell (In [64]) ran before the batching rules were registered
# in In [112] - re-run in order to confirm.
xv = np.linspace(-1, 1, num=3)
gvnp(xv, 2.0)

NotImplementedError: Batching rule for 'udcos' not implemented

In [58]:
from jax.tree_util import tree_structure
# Show the pytree structure of fv's output.
# NOTE(review): `fv` is not defined anywhere in the visible notebook -
# presumably it came from a deleted cell; confirm before re-running.
print(tree_structure(fv(1.0, 2.0)))

PyTreeDef(tuple, [*,*])


In [61]:
# Pytree structure of a plain 2-tuple of floats, for comparison.
print(tree_structure((1.0, 2.0)))

PyTreeDef(tuple, [*,*])


In [23]:
@custom_vjp
def h(x, y):
    """Return ``sin(x) * y``; its reverse-mode rule is given by ``h_fwd``/``h_bwd``.

    ``jnp.sin`` is used rather than ``np.sin`` so the function also accepts
    JAX tracers (jit, vmap); raw NumPy cannot convert a tracer to an array.
    """
    return jnp.sin(x) * y


def h_fwd(x, y):
    """Forward pass of the custom VJP.

    Returns:
        The primal output ``h(x, y)`` and the residuals
        ``(cos(x), sin(x), y)`` consumed by ``h_bwd``.
    """
    return h(x, y), (jnp.cos(x), jnp.sin(x), y)


def h_bwd(res, g):
    """Backward pass of the custom VJP.

    Args:
        res: residuals ``(cos(x), sin(x), y)`` saved by ``h_fwd``.
        g: cotangent of the output.

    Returns:
        Cotangents for ``(x, y)``: ``(cos(x) * g * y, sin(x) * g)``.
    """
    cos_x, sin_x, y = res
    # dh/dx = cos(x) * y and dh/dy = sin(x), each scaled by the incoming cotangent.
    return (cos_x * g * y, sin_x * g)


h.defvjp(h_fwd, h_bwd)


In [24]:
# Gradient of h via its custom VJP rule.
gh = grad(h)

In [27]:
# Time the gradient of h (no jit applied) at (1.0, 2.0).
%timeit gh(1.0,2.0)

3.6 ms ± 211 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
