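Train a neural network $y(x)$ on inputs $x \in [0, 1)$ so that it satisfies the ODE

$$
\frac{d^2 y}{dx^2} = 1
$$

Only this residual is penalised (no boundary conditions are imposed), so the learned solution is determined only up to a linear term: $y = \tfrac{1}{2}x^2 + ax + b$.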

In [1]:
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random
from jax.config import config
config.update("jax_enable_x64", False)

In [2]:
def random_layer_params(m, n, key, scale=1e-2):
  """Randomly initialise the weights and bias of one dense layer.

  Args:
    m: number of input features.
    n: number of output units.
    key: JAX PRNG key; split internally into a weight key and a bias key.
    scale: multiplier applied to the standard-normal draws.

  Returns:
    (w, b) where w has shape (n, m) and b has shape (n,), both float32.
  """
  weight_key, bias_key = random.split(key)
  weights = random.normal(weight_key, (n, m), dtype=jnp.float32)
  biases = random.normal(bias_key, (n,), dtype=jnp.float32)
  return scale * weights, scale * biases

def init_network_params(sizes, key):
  """Initialise every dense layer of a fully-connected network.

  Args:
    sizes: layer widths, input width first and output width last.
    key: JAX PRNG key; split into len(sizes) subkeys. Only the first
      len(sizes) - 1 are used, one per layer.

  Returns:
    A list of (w, b) tuples, one per consecutive pair of widths.
  """
  subkeys = random.split(key, len(sizes))
  layers = []
  for idx in range(len(sizes) - 1):
    layers.append(random_layer_params(sizes[idx], sizes[idx + 1], subkeys[idx]))
  return layers

# Architecture: 1 input, four hidden layers of 128 units, 1 output.
layer_sizes = [1] + [128] * 4 + [1]
params = init_network_params(layer_sizes, random.PRNGKey(0))

In [3]:
@jit
def predict(params, x):
  """Forward pass of the tanh MLP for a single input.

  Args:
    params: list of (w, b) layer tuples; every layer except the last applies tanh.
    x: one input sample. When called through the vmapped wrappers on the
      (N, 1) data, this has shape (1,).

  Returns:
    A scalar: the sum of the final (linear) layer's outputs, so `grad` can be
    taken directly.
  """
  hidden_layers, (w_out, b_out) = params[:-1], params[-1]
  h = x
  for w, b in hidden_layers:
    h = jnp.tanh(jnp.dot(w, h) + b)
  return jnp.sum(jnp.dot(w_out, h) + b_out)

from jax import value_and_grad

@jit
def predict_and_grad(params, x):
  """Evaluate `predict` together with its gradients w.r.t. params and x.

  Returns:
    (y, dy_dparams, dy_dx): the scalar prediction, a pytree shaped like
    `params`, and an array shaped like `x`.
  """
  value, (dparams, dinput) = value_and_grad(predict, argnums=(0, 1))(params, x)
  return value, dparams, dinput

@jit
def predict_and_second_order_grad(params, x):
  """Second derivative of `predict` with respect to its input `x`.

  Despite the name, only d^2y/dx^2 is returned, not the prediction itself.
  The first derivative is summed to a scalar so that `grad` can be applied a
  second time. For a single-feature input (shape (1,), as produced by the
  vmapped wrapper over the (N, 1) data), this is exactly d^2y/dx^2. For a
  multi-feature input, it would be the row sums of the Hessian, not its
  diagonal.

  Returns:
    An array with the same shape as `x`.
  """
  def summed_first_derivative(z):
    dy_dz = grad(predict, argnums=1)(params, z)
    return jnp.sum(dy_dz)

  return grad(summed_first_derivative)(x)
	
# Map over the leading (batch) axis of x; params are shared, not mapped.
batched_predict_and_grad = vmap(predict_and_grad, in_axes=(None, 0))
batched_predict_and_second_order_grad = vmap(
    predict_and_second_order_grad,
    in_axes=(None, 0),
)

In [4]:
# import jax
# NOTE: jax.ops.index_update has been removed from JAX; the functional
# equivalent x.at[i].set(xi) is used below.
# def diag_grad(f, x):
#     def partial_grad_f_index(i):
#         def partial_grad_f_x(xi):
#             return f(x.at[i].set(xi))[i]
#         return jax.grad(partial_grad_f_x)(x[i])
#     return jax.vmap(partial_grad_f_index)(jax.numpy.arange(x.shape[0]))

# def jacdiag(f):
#     def _jacdiag(x):
#         def partial_grad_f_index(i):
#             def partial_grad_f_x(xi):
#                 return f(x.at[i].set(xi))[i]
#             return jax.grad(partial_grad_f_x)(x[i])
#         return jax.vmap(partial_grad_f_index)(jax.numpy.arange(x.shape[0]))
#     return _jacdiag

# # @jit
# # def vmap_jacdiag(f):
# # 	return jax.vmap(f)

In [7]:
def mse(x, y):
  """Mean squared error between `x` and `y`.

  Args:
    x: predictions (any shape).
    y: targets; anything that broadcasts against `x` (e.g. a scalar).

  Returns:
    The scalar mean of (x - y)**2 over all elements.
  """
  # Reduce with a single mean. Summing first and then averaging the resulting
  # scalar would return the *sum* of squared errors, making the loss (and its
  # gradients) scale with the number of samples.
  return jnp.mean(jnp.square(x - y))

from jax import value_and_grad, vmap

@jit
def loss(params, x):
  """Residual loss for the ODE y'' = 1 over a batch of inputs.

  Args:
    params: network parameters (list of (w, b) tuples).
    x: batch of inputs, mapped over its leading axis (shape (N, 1) here).

  Returns:
    A scalar: `mse` between the network's d^2y/dx^2 at each input and 1.0.

  NOTE(review): only the ODE residual is penalised, with no boundary or
  initial conditions, so any y = x**2/2 + a*x + b minimises this loss.
  """
  target = 1.0
  second_derivs = batched_predict_and_second_order_grad(params, x)
  return mse(second_derivs, target)

@jit
def update(params, x):
  """Take one plain gradient-descent step on `loss`.

  Returns a new list of (w, b) tuples; `params` itself is not modified.

  NOTE(review): `step_size` is a module-level global read while tracing.
  Under @jit its value is baked in when `update` is first compiled, so
  reassigning it later has no effect unless `update` is re-traced.
  """
  param_grads = grad(loss)(params, x)
  stepped = []
  for (w, b), (dw, db) in zip(params, param_grads):
    stepped.append((w - step_size * dw, b - step_size * db))
  return stepped

In [8]:
# Training inputs: 1000 points drawn uniformly from [0, 1), one feature each.
# Use a key distinct from PRNGKey(0), which was already split to initialise the
# network. JAX advises never reusing a key, because reused keys can correlate
# the data with the initial weights.
x = random.uniform(
    random.PRNGKey(1),
    shape=(1000, 1),
    minval=0.0,
    maxval=1.0,
    dtype=jnp.float32,
)

In [9]:
epochs = 5000
step_size = 0.01

import time

# NOTE: the timed region includes the one-off JIT compilation of `update` on
# the first iteration, so the average overstates the steady-state cost.
start_time = time.time()
for _ in range(epochs):
  params = update(params, x)
# JAX dispatches work asynchronously. Block on the final parameters so the
# timer covers the actual computation, not just the dispatch of each call.
for layer_w, layer_b in params:
  layer_w.block_until_ready()
  layer_b.block_until_ready()
print("AVG epoch time: ", (time.time() - start_time)/epochs)

AVG epoch time:  0.0007730528354644775
