In [1]:
import jax.numpy as jnp
from jax import grad, jacfwd, jit, vmap
from jax import random
from jax.config import config
from jax.tree_util import tree_map
# Explicitly turn off JAX's 64-bit mode; every array below is created as float32.
config.update("jax_enable_x64", False)

In [4]:
def random_layer_params(m, n, key, scale=1e-2):
  """Draw random weights and biases for one dense layer.

  Args:
    m: number of inputs to the layer (second dimension of the weight matrix).
    n: number of outputs of the layer (first dimension of the weight matrix).
    key: JAX PRNG key; it is split once, one half per array.
    scale: factor applied to the standard-normal samples.

  Returns:
    A ``(weights, biases)`` tuple of float32 arrays with shapes ``(n, m)``
    and ``(n,)``.
  """
  weight_key, bias_key = random.split(key)
  weights = scale * random.normal(weight_key, (n, m), dtype=jnp.float32)
  biases = scale * random.normal(bias_key, (n,), dtype=jnp.float32)
  return weights, biases

def init_network_params(sizes, key):
  """Create parameters for every layer of a fully-connected network.

  Args:
    sizes: layer widths, input first and output last; each consecutive
      pair ``(sizes[i], sizes[i + 1])`` defines one dense layer.
    key: JAX PRNG key that is split into per-layer keys.

  Returns:
    A list with one ``(weights, biases)`` tuple per layer, i.e.
    ``len(sizes) - 1`` entries.
  """
  # len(sizes) keys are generated although only len(sizes) - 1 layers exist;
  # zip() simply ignores the surplus key.
  layer_keys = random.split(key, len(sizes))
  layers = []
  for fan_in, fan_out, layer_key in zip(sizes[:-1], sizes[1:], layer_keys):
    layers.append(random_layer_params(fan_in, fan_out, layer_key))
  return layers

# Network layout: one input, four hidden layers of 128 units each, one output.
layer_sizes = [1, 128, 128, 128, 128, 1]
# Parameters are generated from a fixed seed (PRNGKey(0)).
params = init_network_params(layer_sizes, random.PRNGKey(0))

@jit
def predict(params, x):
  """Evaluate the network on a single input and return a scalar.

  Args:
    params: sequence of ``(weights, biases)`` pairs, one per layer.
    x: input array that is multiplied by the first layer's weights.

  Returns:
    The sum of the last layer's outputs. Every layer except the last is
    followed by ``tanh``; the last layer is purely affine.
  """
  *hidden_layers, (out_w, out_b) = params

  hidden = x
  for weight, bias in hidden_layers:
    hidden = jnp.tanh(jnp.dot(weight, hidden) + bias)

  # Summing collapses the output vector to a scalar.
  return jnp.sum(jnp.dot(out_w, hidden) + out_b)

# 1000 scalar inputs drawn uniformly from [0, 1).
# They use their own PRNG key: PRNGKey(0) is already consumed by the parameter
# initialisation, and JAX keys should never be reused for a second draw.
x = random.uniform(random.PRNGKey(1), shape=(1000, 1), minval=0.0, maxval=1.0, dtype=jnp.float32)

## 1. First-order gradients with respect to the parameters and `x`

In [10]:
from jax import value_and_grad

@jit
def first_order_grad_1(params, x):
  """Variant 1: value and both gradients from one ``value_and_grad`` call.

  Returns:
    ``(y, d_params, d_x)`` - the prediction, its gradient with respect to
    ``params`` and its gradient with respect to ``x``.
  """
  value, (d_params, d_x) = value_and_grad(predict, argnums=(0, 1))(params, x)
  return value, d_params, d_x

# Map over the leading axis of x while sharing the same params.
batched_first_order_grad_1 = vmap(first_order_grad_1, in_axes=(None, 0))


@jit
def first_order_grad_2(params, x):
  """Variant 2: separate ``predict`` call plus one joint ``grad`` call.

  Returns:
    ``(y, d_params, d_x)`` - the prediction, its gradient with respect to
    ``params`` and its gradient with respect to ``x``.
  """
  value = predict(params, x)
  d_params, d_x = grad(predict, argnums=(0, 1))(params, x)
  return value, d_params, d_x

# Map over the leading axis of x while sharing the same params.
batched_first_order_grad_2 = vmap(first_order_grad_2, in_axes=(None, 0))


@jit
def first_order_grad_3(params, x):
  """Variant 3: separate ``predict`` call plus one ``grad`` call per argument.

  Returns:
    ``(y, d_params, d_x)`` - the prediction, its gradient with respect to
    ``params`` and its gradient with respect to ``x``.
  """
  value = predict(params, x)
  d_params = grad(predict, argnums=0)(params, x)
  d_x = grad(predict, argnums=1)(params, x)
  return value, d_params, d_x

# Map over the leading axis of x while sharing the same params.
batched_first_order_grad_3 = vmap(first_order_grad_3, in_axes=(None, 0))


@jit
def first_order_grad_4(params, x):
  """Variant 4: gradients only, from one joint ``grad`` call.

  Returns:
    ``(d_params, d_x)`` - the gradient of the prediction with respect to
    ``params`` and with respect to ``x``; the prediction is not returned.
  """
  d_params, d_x = grad(predict, argnums=(0, 1))(params, x)
  return d_params, d_x

# Map over the leading axis of x while sharing the same params.
batched_first_order_grad_4 = vmap(first_order_grad_4, in_axes=(None, 0))


@jit
def first_order_grad_5(params, x):
  """Variant 5: gradients only, from one ``grad`` call per argument.

  Returns:
    ``(d_params, d_x)`` - the gradient of the prediction with respect to
    ``params`` and with respect to ``x``; the prediction is not returned.
  """
  d_params = grad(predict, argnums=0)(params, x)
  d_x = grad(predict, argnums=1)(params, x)
  return d_params, d_x

# Map over the leading axis of x while sharing the same params.
batched_first_order_grad_5 = vmap(first_order_grad_5, in_axes=(None, 0))


%timeit batched_first_order_grad_1(params, x)
%timeit batched_first_order_grad_2(params, x)
%timeit batched_first_order_grad_3(params, x)
%timeit batched_first_order_grad_4(params, x)
%timeit batched_first_order_grad_5(params, x)

640 µs ± 640 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
641 µs ± 402 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
638 µs ± 660 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
618 µs ± 532 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
619 µs ± 529 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)


## 2. Second derivative with respect to `x` and first derivative with respect to the parameters

In [None]:
@jit
def second_order_grad(params, x):
  """Value, first-order gradients and the second derivative with respect to x.

  Args:
    params: network parameters accepted by ``predict``.
    x: a single (unbatched) input accepted by ``predict``.

  Returns:
    ``(y, d_params, dx, ddx)`` where ``y`` is the prediction, ``d_params``
    and ``dx`` are its gradients with respect to ``params`` and ``x``, and
    ``ddx`` is the Hessian of the prediction with respect to ``x``, with
    shape ``x.shape + x.shape``.
  """
  y, (d_params, dx) = value_and_grad(predict, (0, 1))(params, x)
  # Forward-over-reverse: differentiate the reverse-mode gradient with
  # respect to x once more, in forward mode.
  ddx = jacfwd(grad(predict, 1), 1)(params, x)
  return y, d_params, dx, ddx

# Map over the leading axis of x while sharing the same params.
batched_second_order_grad = vmap(second_order_grad, in_axes=(None, 0))

## 3. multiple outputs

## 4. multiple inputs

## 5. multiple outputs and inputs