In [1]:
import jax, jax.nn
from jax import random
import jax.numpy as jnp
from jax.experimental import optimizers

from collections import namedtuple

In [2]:
def siren_layer_params(key, scale, m, n):
    """Draw parameters for one dense layer mapping ``m`` features to ``n``.

    Returns ``(weights, biases)``: float32 weights of shape (m, n) sampled
    uniformly from [-scale, scale), and float32 zero biases of shape (n,).
    Only the first sub-key of ``random.split(key)`` is consumed.
    """
    weight_key = random.split(key)[0]
    weights = random.uniform(weight_key, (m, n), jnp.float32, minval=-scale, maxval=scale)
    biases = jnp.zeros((n,), jnp.float32)
    return weights, biases

def init_siren_params(key, layers, c0, w0):
    """Initialise a SIREN as a list of ``(W, b)`` tuples, one per consecutive pair in ``layers``.

    Layer i samples W uniformly with bound sqrt(c0 / fan_in_i). The first
    layer's bound is additionally multiplied by ``w0``. Layer i consumes the
    i-th sub-key of ``random.split(key, len(layers))``.
    """
    subkeys = random.split(key, len(layers))
    fan_in, fan_out = layers[0], layers[1]
    params = [siren_layer_params(subkeys[0], w0*jnp.sqrt(c0/fan_in), fan_in, fan_out)]
    for idx, (m, n) in enumerate(zip(layers[1:-1], layers[2:]), start=1):
        params.append(siren_layer_params(subkeys[idx], jnp.sqrt(c0/m), m, n))
    return params

# Architecture: 2 inputs (x, t) -> 128 sine units -> 2 linear outputs (u, v).
layers = [2, 128, 2]
# Fixed PRNG seed 0 for a reproducible initialisation; c0 = w0 = 1.0.
params = init_siren_params(random.PRNGKey(0), layers, c0=1.0, w0=1.0)

In [14]:
def _siren_hidden(params, x, t):
    """Run the point(s) ``(x, t)`` through every hidden sine layer of ``params``.

    This is the trunk shared by ``scalar_u_model`` and ``scalar_v_model``. The
    output layer ``params[-1]`` is not applied here.
    """
    h = jnp.hstack([x, t])
    for w, b in params[:-1]:
        h = jnp.sin(jnp.dot(h, w) + b)
    return h


def _siren_head(params, x, t, k):
    """Apply the linear output head for output column ``k`` and reduce it to a scalar."""
    w_out, b_out = params[-1]
    h = _siren_hidden(params, x, t)
    # The k:k+1 slice keeps a 2-D weight. jnp.sum then collapses the result to a
    # 0-d scalar, so jax.grad can be applied to the models built on this helper.
    return jnp.sum(jnp.dot(h, w_out[:, k:k + 1]) + b_out[k])


@jax.jit
def scalar_u_model(params, x, t):
    """SIREN output ``u`` (output column 0) at ``(x, t)``, returned as a scalar."""
    return _siren_head(params, x, t, 0)


@jax.jit
def scalar_v_model(params, x, t):
    """SIREN output ``v`` (output column 1) at ``(x, t)``, returned as a scalar."""
    return _siren_head(params, x, t, 1)


# Batched versions: params are shared (None); x and t are mapped along axis 0.
u_model = jax.jit(jax.vmap(scalar_u_model, in_axes = (None, 0, 0)))
v_model = jax.jit(jax.vmap(scalar_v_model, in_axes = (None, 0, 0)))

@jax.jit
def uv_model(params, x):
    """Joint SIREN forward pass returning both outputs (u, v) at once.

    Each hidden layer computes sin(x . W + b). The last (W, b) in ``params``
    is a linear head with no activation. ``x`` may be a single (x, t) point
    or a batch of rows, since jnp.dot maps over the leading rows.
    """
    hidden_layers, head = params[:-1], params[-1]
    h = x
    for w_h, b_h in hidden_layers:
        h = jnp.sin(jnp.dot(h, w_h) + b_h)
    return jnp.dot(h, head[0]) + head[1]


# Per-point version: params are shared (None); xt is mapped along axis 0.
uv_m = jax.vmap(uv_model, in_axes=(None, 0))

In [15]:
# Ten identical sample points at x = 1, t = 0.1, stored as column vectors,
# plus their (10, 2) stacked form for the joint uv model.
x = jnp.ones((10, 1))
t = 0.1 * jnp.ones((10, 1))
xt = jnp.concatenate([x, t], axis=1)

In [16]:
# Consistency check: the separate u/v models must reproduce the joint uv model.
u_ = u_model(params, x, t).reshape((-1, 1))
v_ = v_model(params, x, t).reshape((-1, 1))
uv_ = uv_m(params, xt)
# uv_model also accepts a whole batch directly; it must agree with the per-point vmap.
assert jnp.allclose(uv_model(params, xt), uv_, atol=1e-5), "batched uv_model disagrees with vmapped uv_m"
jnp.sum(jnp.square(jnp.hstack([u_, v_]) - uv_))

DeviceArray(5.551115e-16, dtype=float32)

In [71]:
@jax.jit
def duv_dxt1(params, xt):
    """Per-point Jacobian d(u, v)/d(x, t) of uv_model using forward-mode AD.

    Returns shape (n_points, n_outputs, n_inputs) for an (n_points, n_inputs) ``xt``.
    """
    forward_jac = jax.jacfwd(uv_model, argnums=1)
    return jax.vmap(forward_jac, in_axes=(None, 0))(params, xt)


@jax.jit
def duv_dxt2(params, xt):
    """Same Jacobian as duv_dxt1, computed with reverse-mode AD (jax.jacrev)."""
    reverse_jac = jax.jacrev(uv_model, argnums=1)
    return jax.vmap(reverse_jac, in_axes=(None, 0))(params, xt)


# Batched second derivatives with shape (n_points, n_outputs, n_inputs, n_inputs).
hessian_uv = jax.jit(jax.vmap(jax.hessian(uv_model, argnums=1), in_axes=(None, 0)))
# jax.jacobian is an alias of jax.jacrev, so this computes the same thing as duv_dxt2.
jacobian_uv = jax.jit(jax.vmap(jax.jacobian(uv_model, argnums=1), in_axes=(None, 0)))

In [67]:
# Batched Hessian shape: (n_points, n_outputs, n_inputs, n_inputs).
hessian_uv(params, xt).shape

(10, 2, 2, 2)

In [66]:
# block_until_ready() waits for JAX's asynchronous dispatch, so %timeit measures the computation itself.
%timeit hessian_uv(params, xt).block_until_ready()

196 µs ± 260 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [72]:
# Jitted per-point Jacobians: forward mode (jacfwd), reverse mode (jacrev), and jax.jacobian (alias of jacrev).
%timeit duv_dxt1(params, xt).block_until_ready()
%timeit duv_dxt2(params, xt).block_until_ready()
%timeit jacobian_uv(params, xt).block_until_ready()

193 µs ± 309 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
188 µs ± 214 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
189 µs ± 193 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)


In [26]:
# Un-jitted counterpart of duv_dxt1, for comparison. It is not wrapped in jax.jit and measured
# roughly 1.5 ms here, versus roughly 0.19 ms for the jitted version.
%timeit jax.vmap(jax.jacfwd(uv_model, 1), (None, 0))(params, xt).block_until_ready()

1.53 ms ± 31.8 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [27]:
# Un-jitted counterpart of duv_dxt2 (reverse mode), for comparison with the jitted timing above.
%timeit jax.vmap(jax.jacrev(uv_model, 1), (None, 0))(params, xt).block_until_ready()

1.9 ms ± 22.5 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [68]:
def _summed_grad(model, argnum):
    """Build a jitted ``f(params, x, t)`` returning the summed gradient of ``model`` w.r.t. argument ``argnum``.

    ``jax.grad`` returns an array shaped like the differentiated argument.
    ``jnp.sum`` collapses it to a scalar, so the result can be vmapped to one
    value per point and differentiated again with ``jax.grad``.
    """
    def summed_partial(params, x, t):
        return jnp.sum(jax.grad(model, argnum)(params, x, t))
    return jax.jit(summed_partial)


# Argument 1 of the scalar models is x and argument 2 is t, so the "_dy"
# functions below are derivatives with respect to t.
scalar_du_dx = _summed_grad(scalar_u_model, 1)
scalar_du_dy = _summed_grad(scalar_u_model, 2)
scalar_dv_dx = _summed_grad(scalar_v_model, 1)
scalar_dv_dy = _summed_grad(scalar_v_model, 2)

# Batched first derivatives: params are shared; x and t are mapped along axis 0.
du_dx = jax.jit(jax.vmap(scalar_du_dx, in_axes = (None, 0, 0)))
du_dy = jax.jit(jax.vmap(scalar_du_dy, in_axes = (None, 0, 0)))
dv_dx = jax.jit(jax.vmap(scalar_dv_dx, in_axes = (None, 0, 0)))
dv_dy = jax.jit(jax.vmap(scalar_dv_dy, in_axes = (None, 0, 0)))


@jax.jit
def du_dxx(params, x, t):
    """Second derivative d^2u/dx^2, returned with the same shape as ``x``.

    This is not vmapped. Called on a batch, scalar_du_dx sums du/dx over all
    points. The model processes each row of (x, t) independently, so the
    gradient of that sum w.r.t. x gives each point's own d^2u/dx^2.
    """
    return jax.grad(scalar_du_dx, 1)(params, x, t)


@jax.jit
def du_dyy(params, x, t):
    """Second derivative of u w.r.t. argument 2 (t), returned with the same shape as ``t``; see du_dxx."""
    return jax.grad(scalar_du_dy, 2)(params, x, t)

In [40]:
# Batched first derivatives of u and v; the "_dy" functions differentiate w.r.t. argument 2 (t).
%timeit du_dx(params, x, t).block_until_ready()
%timeit du_dy(params, x, t).block_until_ready()
%timeit dv_dx(params, x, t).block_until_ready()
%timeit dv_dy(params, x, t).block_until_ready()

230 µs ± 225 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
229 µs ± 284 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
228 µs ± 211 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
230 µs ± 470 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [70]:
# Second derivative for the whole batch; du_dxx is applied to the batch directly, not vmapped.
%timeit du_dxx(params, x, t).block_until_ready()

238 µs ± 348 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [141]:
# Toy check of the batched jacobian/hessian helpers on a model with hand-computable derivatives.
# NOTE(review): this cell rebinds `params`, `x`, `t` and `xt`, replacing the SIREN parameters
# and sample points defined above. Re-running an earlier SIREN cell after this one will use these toy values.
params = jnp.array([[1., 2.], [3., 4.]])
x = jnp.ones((10, 1))
t = x*2
xt = jnp.hstack([x, t])

def model(params, xt):
    """Toy model: elementwise square of ``xt . params``, whose derivatives are easy to verify by hand."""
    return jnp.square(jnp.dot(xt, params))

# Build each jitted function once. Calling jax.jit on a freshly created vmap'd
# function inside a lambda misses jit's cache, so it retraces and recompiles on every call.
jacobian = jax.jit(jax.vmap(jax.jacobian(model, 1), (None, 0)))
hessian = jax.jit(jax.vmap(jax.hessian(model, 1), (None, 0)))

In [138]:
# Every row of xt is (1, 2), so each output row is ((1, 2) . params)^2 = (7^2, 10^2) = (49, 100).
jax.vmap(model, (None, 0))(params, xt)

DeviceArray([[ 49., 100.],
             [ 49., 100.],
             [ 49., 100.],
             [ 49., 100.],
             [ 49., 100.],
             [ 49., 100.],
             [ 49., 100.],
             [ 49., 100.],
             [ 49., 100.],
             [ 49., 100.]], dtype=float32)

In [139]:
# Per point: d(xt . P)_i^2 / d xt_j = 2 (xt . P)_i P[j, i], giving [[14, 42], [40, 80]] for xt = (1, 2).
jacobian(params, xt)

DeviceArray([[[14., 42.],
              [40., 80.]],

             [[14., 42.],
              [40., 80.]],

             [[14., 42.],
              [40., 80.]],

             [[14., 42.],
              [40., 80.]],

             [[14., 42.],
              [40., 80.]],

             [[14., 42.],
              [40., 80.]],

             [[14., 42.],
              [40., 80.]],

             [[14., 42.],
              [40., 80.]],

             [[14., 42.],
              [40., 80.]],

             [[14., 42.],
              [40., 80.]]], dtype=float32)

In [144]:
# Per point: d^2 (xt . P)_i^2 / d xt_j d xt_k = 2 P[j, i] P[k, i], which does not depend on xt:
# [[[2, 6], [6, 18]], [[8, 16], [16, 32]]].
hessian(params, xt)

DeviceArray([[[[ 2.,  6.],
               [ 6., 18.]],

              [[ 8., 16.],
               [16., 32.]]],


             [[[ 2.,  6.],
               [ 6., 18.]],

              [[ 8., 16.],
               [16., 32.]]],


             [[[ 2.,  6.],
               [ 6., 18.]],

              [[ 8., 16.],
               [16., 32.]]],


             [[[ 2.,  6.],
               [ 6., 18.]],

              [[ 8., 16.],
               [16., 32.]]],


             [[[ 2.,  6.],
               [ 6., 18.]],

              [[ 8., 16.],
               [16., 32.]]],


             [[[ 2.,  6.],
               [ 6., 18.]],

              [[ 8., 16.],
               [16., 32.]]],


             [[[ 2.,  6.],
               [ 6., 18.]],

              [[ 8., 16.],
               [16., 32.]]],


             [[[ 2.,  6.],
               [ 6., 18.]],

              [[ 8., 16.],
               [16., 32.]]],


             [[[ 2.,  6.],
               [ 6., 18.]],

              [[ 8., 16.],
     