In [11]:
from jax import random
import jax.numpy as jnp
from functools import partial
import jax

In [12]:
# Problem size and fixed random inputs (seeded keys, so re-runs are reproducible).
ndim = 10
# Scale the weights by 1/sqrt(ndim).
W = random.normal(random.PRNGKey(0), shape=(ndim, ndim)) / jnp.sqrt(ndim)
x = random.normal(random.PRNGKey(1), shape=(ndim,))

In [13]:
def fwd_solver(f, z_init, tol=1e-5, max_iter=None):
  """Find a fixed point of ``f`` by plain forward iteration.

  Starting from ``z_init``, repeatedly applies ``z <- f(z)`` until
  ``jnp.linalg.norm(z_prev - z)`` is no longer greater than ``tol``.

  Args:
    f: callable taking the current iterate and returning the next one.
    z_init: initial iterate.
    tol: convergence threshold on the norm of the difference between
      successive iterates. Defaults to the previously hard-coded ``1e-5``.
    max_iter: optional cap on the number of update steps performed after the
      initial application of ``f``. ``None`` (the default) keeps the original
      behaviour of iterating until the tolerance is met, which never returns
      if the iteration does not converge.

  Returns:
    The last iterate computed.
  """
  z_prev, z = z_init, f(z_init)
  n_steps = 0
  # The loop condition is evaluated in Python on the computed norm.
  while jnp.linalg.norm(z_prev - z) > tol:
    if max_iter is not None and n_steps >= max_iter:
      break
    z_prev, z = z, f(z)
    n_steps += 1
  return z

In [44]:
@partial(jax.custom_vjp, nondiff_argnums=(0, 1))
def fixed_point_layer(solver, f, params, x):
  """Return a fixed point ``z_star`` of ``z -> f(params, x, z)``.

  Args:
    solver: callable ``solver(g, z_init=...)`` returning a fixed point of ``g``.
      Non-differentiable argument.
    f: callable ``f(params, x, z)`` defining the fixed-point map.
      Non-differentiable argument.
    params: parameters passed through to ``f``.
    x: input passed through to ``f``; the solver starts from zeros shaped
      like ``x``.

  Returns:
    The fixed point found by ``solver``.
  """
  z_star = solver(lambda z: f(params, x, z), z_init=jnp.zeros_like(x))
  return z_star

def fixed_point_layer_fwd(solver, f, params, x):
  """Forward rule: compute ``z_star`` and save what the backward rule needs."""
  z_star = fixed_point_layer(solver, f, params, x)
  return z_star, (params, x, z_star)

def fixed_point_layer_bwd(solver, f, res, z_star_bar):
  """Backward rule: map the cotangent of ``z_star`` to cotangents of ``(params, x)``.

  Args:
    solver: same solver as in the forward pass, reused for the adjoint problem.
    f: the fixed-point map ``f(params, x, z)``.
    res: residuals ``(params, x, z_star)`` saved by ``fixed_point_layer_fwd``.
    z_star_bar: cotangent of the output ``z_star``.

  Returns:
    The result of ``vjp_a``: cotangents for ``params`` and ``x``.
  """
  params, x, z_star = res
  # VJPs of f at the fixed point: with respect to (params, x), and to z.
  _, vjp_a = jax.vjp(lambda params, x: f(params, x, z_star), params, x)
  _, vjp_z = jax.vjp(lambda z: f(params, x, z), z_star)
  # Solve the adjoint fixed-point equation u = vjp_z(u) + z_star_bar with the
  # same solver, then pull u back through f with respect to (params, x).
  return vjp_a(solver(lambda u: vjp_z(u)[0] + z_star_bar,
                      z_init=jnp.zeros_like(z_star)))

fixed_point_layer.defvjp(fixed_point_layer_fwd, fixed_point_layer_bwd)

In [45]:
def f(W, x, z):
  """Apply the layer map once: tanh of ``dot(W, z) + x``."""
  pre_activation = jnp.dot(W, z) + x
  return jnp.tanh(pre_activation)

In [46]:
# Gradient with respect to W of the sum of all fixed-point components.
def _fixed_point_sum(W):
  return fixed_point_layer(fwd_solver, f, W, x).sum()

g = jax.grad(_fixed_point_sum)(W)
print(g[0])

(10,)
[ 0.00752357 -0.8125729  -1.1404755  -0.04860361 -0.7125366  -0.55805445
  0.6697886   1.1068368  -0.09697603  0.97840637]


In [47]:
# Gradient with respect to W of the first fixed-point component only.
def _fixed_point_first(W):
  return fixed_point_layer(fwd_solver, f, W, x)[0]

g = jax.grad(_fixed_point_first)(W)
print(g[0])

(10,)
[ 0.00373198 -0.40306717 -0.5657194  -0.02410925 -0.35344535 -0.2768163
  0.33224073  0.54903334 -0.04810381  0.48532692]


In [49]:
# Forward pass only: solve z = f(W, x, z) with the iterative solver
# (fixed_point_layer starts the iteration from zeros shaped like x).
z_star = fixed_point_layer(fwd_solver, f, W, x)
z_star

DeviceArray([ 0.00649604, -0.7015957 , -0.98471504, -0.04196557,
             -0.61522186, -0.48183814,  0.5783122 ,  0.95567054,
             -0.08373152,  0.8447805 ], dtype=float32)

Let's see if we can deconstruct the backward function by reproducing its steps by hand:

In [55]:
# One-hot cotangent that selects component 0 of the fixed point.
z_star_bar = jnp.zeros((10, ))
# jax.ops.index_update has been removed from JAX; the functional `.at[...].set`
# update returns a new array in the same way.
z_star_bar = z_star_bar.at[0].set(1)

In [56]:
# VJPs of f at the fixed point: first with respect to (W, x) with z held at
# z_star, then with respect to z with (W, x) held fixed.
_, vjp_a = jax.vjp(lambda W_, x_: f(W_, x_, z_star), W, x)
_, vjp_z = jax.vjp(lambda z_: f(W, x, z_), z_star)

In [65]:
# Adjoint fixed-point map: u -> vjp_z(u) + z_star_bar.
def _adjoint_step(u):
  return vjp_z(u)[0] + z_star_bar

# Solve for u from zeros, then pull it back through f w.r.t. (W, x).
u = fwd_solver(_adjoint_step, z_init=jnp.zeros_like(z_star))
g = vjp_a(u)

In [67]:
# g is the (W, x) cotangent pair returned by vjp_a; show the first row of the W part.
g[0][0]

DeviceArray([ 0.00373198, -0.40306717, -0.5657194 , -0.02410925,
             -0.35344535, -0.2768163 ,  0.33224073,  0.54903334,
             -0.04810381,  0.48532692], dtype=float32)

In [70]:
# Solution of the adjoint fixed-point iteration u = vjp_z(u) + z_star_bar.
u

DeviceArray([ 0.5745249 ,  0.23003381,  0.25346488, -0.03544246,
             -0.3168217 ,  0.12168527, -0.00903781,  0.01450401,
             -0.022818  , -0.15082248], dtype=float32)

In [73]:
# Shape of the fixed point; its first entry sizes the identity matrix used below.
z_star.shape

(10,)

In [77]:
# Direct computation: inverse of (I - df/dz) evaluated at the fixed point.
jac_z = jax.jacobian(lambda z: f(W, x, z))(z_star)
u_diff = jnp.linalg.inv(jnp.eye(z_star.shape[0]) - jac_z)

In [80]:
# z_star_bar is one-hot at index 0, so this picks out the first row of u_diff;
# compare it with the iteratively solved u.
jnp.dot(z_star_bar, u_diff)

DeviceArray([ 0.5745274 ,  0.23003279,  0.25346458, -0.03544248,
             -0.31682113,  0.12168445, -0.00903799,  0.01450502,
             -0.02281865, -0.15082209], dtype=float32)

In [79]:
# Display the full inverse matrix (I - df/dz)^{-1}.
u_diff

DeviceArray([[ 5.74527383e-01,  2.30032787e-01,  2.53464580e-01,
              -3.54424752e-02, -3.16821128e-01,  1.21684454e-01,
              -9.03799478e-03,  1.45050213e-02, -2.28186548e-02,
              -1.50822088e-01],
             [-9.71050113e-02,  7.36825943e-01,  1.18140623e-01,
               1.05652541e-01, -7.06935301e-02,  6.67412132e-02,
              -6.45479113e-02, -1.28294125e-01, -1.40098333e-01,
              -6.30238801e-02],
             [-1.20770314e-03,  1.17540210e-02,  1.00734746e+00,
               7.36990338e-03, -4.77466546e-03,  5.61290048e-03,
              -1.75211253e-03, -1.20259523e-02,  3.17129446e-03,
              -2.04873439e-02],
             [ 3.46762180e-01, -3.00977588e-01,  7.59065747e-02,
               1.07281601e+00, -3.88183624e-01,  1.43041372e-01,
               1.48075059e-01, -3.79792005e-01,  3.34723443e-01,
              -3.19274127e-01],
             [ 1.71978116e-01, -2.26782054e-01,  1.48526236e-01,
               3.96907628e-