In [1]:
import jax
from jax import random
import jax.numpy as np
from functools import partial

# JVPs & VJPs

### VJPs
`jax.vjp(fun, *primals)` --> `(fun(*primals), vjpfun)`

vjpfun(v) --> vjp

In [2]:
def f(x):
    """Evaluate sin(x) * x**2 at ``x`` using jax.numpy."""
    squared = x ** 2
    return np.sin(x) * squared

# Evaluate f at a single scalar point; x is reused by the VJP and JVP cells below.
x = 2.
y = f(x)
print(y)

3.6371899


In [3]:
# w is the cotangent passed to the pullback, i.e. the weight applied to f's output.
w = 1.
# jax.vjp evaluates f at x and returns that value together with the pullback function.
y, f_vjp = jax.vjp(f, x)
# The pullback result is unpacked as a 1-tuple (one primal, x, was given), hence the trailing comma.
lmbda, = f_vjp(w)
print(y)
print(lmbda)

3.6371899
1.9726026


### JVP
jax.jvp(fun, primals, tangents) --> (fun(*primals), jvp)

In [4]:
# Push the tangent delta_x through f at x; a single call returns both f(x) and the JVP.
delta_x = 1.
y, delta_y = jax.jvp(f, (x,), (delta_x,))
print(y)
# With a unit tangent this prints the same value the VJP cell printed for a unit cotangent.
print(delta_y)

3.6371899
1.9726026


### Function Composition

In [5]:
# Inner function of the composition.
h = np.sin


def g(x):
    """Outer function of the composition: cube the input."""
    return x ** 3


def f(x):
    """Composition g(h(x)), i.e. sin(x) ** 3."""
    return g(h(x))

In [6]:
# JVP of the composed f at 1.0 with a unit tangent.
# NOTE(review): the primal output is bound to `x`, so from here on `x` holds f(1.0), not an input.
x, delta_x = jax.jvp(f, (1.,), (1.,))
print(x, delta_x)

0.59582335 1.147721


In [7]:
def f_jvp(x, delta_x):
    """Forward-mode derivative of g(h(x)), assembled stage by stage.

    The tangent ``delta_x`` is pushed through ``h`` first; the resulting
    (value, tangent) pair is then pushed through ``g``.

    Returns:
        The pair ``(g(h(x)), tangent of g(h(x)))``.
    """
    h_out, h_tangent = jax.jvp(h, (x,), (delta_x,))
    g_out, g_tangent = jax.jvp(g, (h_out,), (h_tangent,))
    return g_out, g_tangent
# Should reproduce the values jax.jvp printed for the composed f at 1.0.
z, delta_z = f_jvp(1., 1.)
print(z, delta_z)

0.59582335 1.147721


In [8]:
def f_vjp(x, w):
    """Reverse-mode derivative of g(h(x)), assembled stage by stage.

    Both stages are evaluated forward to obtain their pullbacks; the
    cotangent ``w`` is then pulled back through ``g`` and afterwards
    through ``h``.

    Returns:
        The pair ``(g(h(x)), cotangent with respect to x)``.
    """
    h_out, pull_h = jax.vjp(h, x)
    g_out, pull_g = jax.vjp(g, h_out)

    # Cotangents travel in the opposite order to evaluation: g first, then h.
    (ct_h_out,) = pull_g(w)
    (ct_x,) = pull_h(ct_h_out)
    return g_out, ct_x
# Unit cotangent at x = 1.0; prints the same pair as the forward-mode cells above.
z, lmda = f_vjp(1., 1.)
print(z, lmda)

0.59582335 1.147721


# Optimization problem

In [9]:
def fwd_solver(f, z_init, tol=1e-5, max_iter=None):
    """Solve ``z = f(z)`` by plain fixed-point (forward) iteration.

    Args:
        f: map whose fixed point is sought; applied repeatedly to the iterate.
        z_init: starting iterate.
        tol: the loop stops once the norm of the difference between two
            successive iterates is no longer greater than this value.
            Defaults to the previously hard-coded ``1e-5``.
        max_iter: optional cap on the number of ``f`` evaluations performed
            inside the loop. ``None`` (the default) keeps the original
            unbounded behavior, which never returns when the iteration does
            not converge.

    Returns:
        The last iterate computed. When ``max_iter`` is hit this iterate has
        not met the tolerance.
    """
    z_prev, z = z_init, f(z_init)
    n_iter = 0
    while np.linalg.norm(z_prev - z) > tol:
        # Bail out instead of spinning forever on a non-contractive map.
        if max_iter is not None and n_iter >= max_iter:
            break
        z_prev, z = z, f(z)
        n_iter += 1
    return z

def newton_solver(f, z_init):
    """Find a fixed point of ``f`` with Newton's method.

    The condition ``f(z) = z`` is recast as a root of ``f(z) - z``. The
    Newton update for that root problem is itself a map on ``z``, and it is
    iterated to convergence by ``fwd_solver``.
    """
    def residual(z):
        return f(z) - z

    def newton_step(z):
        # Solve J(z) d = residual(z) rather than forming an explicit inverse.
        jac = jax.jacobian(residual)(z)
        return z - np.linalg.solve(jac, residual(z))

    return fwd_solver(newton_step, z_init)

def fixed_point_layer(solver, f, params, x):
    """Return what ``solver`` finds as the fixed point of ``z -> f(params, x, z)``.

    ``solver`` is called with that map and a zero starting point shaped like
    ``x`` (passed as the keyword ``z_init``); its result is returned as-is.
    """
    def layer_map(z):
        return f(params, x, z)

    return solver(layer_map, z_init=np.zeros_like(x))

In [10]:
def f(W, x, z):
    """Layer map f(W, x, z) = tanh(W z + x); its fixed point in z is what the solvers look for."""
    return np.tanh(np.dot(W, z) + x)


ndim = 10
# Entries of W are scaled by 1/sqrt(ndim) (presumably so the iteration contracts; TODO confirm).
W = random.normal(random.PRNGKey(0), shape=(ndim, ndim)) / np.sqrt(ndim)
x = random.normal(random.PRNGKey(1), shape=(ndim,))

In [None]:
# Solve for the fixed point of z -> f(W, x, z) using the Newton-based solver.
z_star = fixed_point_layer(newton_solver, f, W, x)
print(z_star)
# Optional fixed-point check (f takes W, x and z): print(np.linalg.norm(f(W, x, z_star) - z_star))

### VJP expression for the fixed-point solution

In [65]:
@partial(jax.custom_vjp, nondiff_argnums=(0, 1))
def fixed_point_layer(solver, f, params, x):
    """Fixed point of ``z -> f(params, x, z)`` with a user-supplied VJP rule.

    ``solver`` and ``f`` (argument positions 0 and 1) are declared
    non-differentiable. ``solver`` is called with the map above and a zero
    starting point shaped like ``x``; its result is returned as-is.
    """
    def layer_map(z):
        return f(params, x, z)

    return solver(layer_map, z_init=np.zeros_like(x))

def fixed_point_layer_fwd(solver, f, params, x):
    """Forward rule for ``fixed_point_layer``.

    Returns:
        ``(z_star, residuals)`` where ``residuals = (params, x, z_star)`` is
        what the backward rule unpacks.
    """
    fixed_point = fixed_point_layer(solver, f, params, x)
    residuals = (params, x, fixed_point)
    return fixed_point, residuals

def fixed_point_layer_bwd(solver, f, res, z_star_bar):
    """Backward rule for ``fixed_point_layer``.

    Args:
        solver: the same kind of solver used in the forward pass; here it is
            handed the map ``u -> vjp_z(u) + z_star_bar``.
        f: the layer function ``f(params, x, z)``.
        res: residuals ``(params, x, z_star)`` saved by the forward rule.
        z_star_bar: cotangent of the layer output.

    Returns:
        The result of pulling the solver's output back through ``f`` with
        respect to ``(params, x)`` at fixed ``z_star``, as returned by the
        ``jax.vjp`` pullback.
    """
    params, x, z_star = res

    def f_of_inputs(p, inp):
        return f(p, inp, z_star)

    def f_of_state(z):
        return f(params, x, z)

    _, vjp_a = jax.vjp(f_of_inputs, params, x)
    _, vjp_z = jax.vjp(f_of_state, z_star)

    def adjoint_map(u):
        # vjp_z returns a 1-tuple; its single entry is the pullback w.r.t. z.
        return vjp_z(u)[0] + z_star_bar

    u_star = solver(adjoint_map, z_init=np.zeros_like(z_star))
    return vjp_a(u_star)
# Register the forward and backward rules for the custom_vjp-decorated fixed_point_layer.
fixed_point_layer.defvjp(fixed_point_layer_fwd, fixed_point_layer_bwd)

In [66]:
# Gradient w.r.t. W of the summed fixed point, using the plain forward-iteration solver.
# NOTE(review): this rebinds `g`, which an earlier cell defined as the cube function read by f_jvp / f_vjp.
g = jax.grad(lambda W: fixed_point_layer(fwd_solver, f, W, x).sum())(W)
# Show only the first row of the gradient.
print(g[0])

[ 0.00752357 -0.8125729  -1.1404755  -0.04860361 -0.7125366  -0.55805445
  0.6697886   1.1068368  -0.09697603  0.97840637]


In [67]:
# Same gradient, but with the Newton solver used for the fixed-point solves.
# The recorded first row agrees with the forward-iteration result to about five decimals.
g = jax.grad(lambda W: fixed_point_layer(newton_solver, f, W, x).sum())(W)
print(g[0])

[ 0.00752136 -0.8125742  -1.1404786  -0.04860303 -0.7125376  -0.55805624
  0.6697907   1.1068398  -0.09697363  0.9784083 ]
