$$
\frac{d^2 y}{dx^2} = 1
$$

In [1]:
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random
# Import the config object from the top-level `jax` package: the old
# `jax.config` submodule path was deprecated and later removed, so
# `from jax.config import config` fails on recent JAX releases.
from jax import config

# Explicitly keep JAX's 64-bit mode switched off for this notebook.
config.update("jax_enable_x64", False)

In [32]:
@jit
def f(x, y):
	"""Return the scalar sum over all elements of ``2*x**2 + y**2``.

	``x`` and ``y`` are combined elementwise (with the usual broadcasting)
	before the reduction, so the result is always a single scalar.
	"""
	x_term = 2 * x**2
	y_term = y**2
	return jnp.sum(x_term + y_term)

@jit
def dx(x, y):
	"""Return the summed gradient of ``f`` with respect to ``x`` at ``(x, y)``.

	The gradient has the same shape as ``x``; it is reduced to a scalar so
	that ``dx`` can itself be differentiated with ``grad``.
	"""
	df_dx = grad(f, argnums=0)
	return jnp.sum(df_dx(x, y))

@jit
def ddx(x, y):
	"""Return the summed gradient of ``dx`` with respect to ``x`` at ``(x, y)``.

	Because ``dx`` is already the summed x-gradient of ``f``, this is the
	second-order derivative of ``f`` in ``x``, reduced to a scalar.
	"""
	d_dx_of_dx = grad(dx, argnums=0)
	return jnp.sum(d_dx_of_dx(x, y))

# Vectorise dx and ddx over the leading axis of both arguments, producing one
# output value per leading-axis entry.
batched_dx, batched_ddx = (
	vmap(derivative, in_axes=(0, 0), out_axes=0) for derivative in (dx, ddx)
)

In [33]:
# Three sample points laid out as a (3, 1) column: one point per row, so the
# batched derivatives return one value per point.
x = jnp.array([[1.0], [2.0], [3.0]])
y = x  # evaluate at y == x

# Printed in order: f over the whole batch, then per-point dx and ddx.
print(f(x, y), batched_dx(x, y), batched_ddx(x, y), sep="\n")

42.0
[ 4.  8. 12.]
[4. 4. 4.]
