In this notebook we verify the custom backprop we've implemented. 

In [49]:
# %% Imports
import jax
from jax import random, numpy as jnp
from functools import partial
from jax import lax
import matplotlib.pyplot as plt
import seaborn as sns
sns.set()

from code import fwd_solver, bayes_ridge_update#, #fixed_point_layer
from sklearn.linear_model import BayesianRidge
from jax.test_util import check_grads

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [50]:
# Load the test fixture: the .npy file holds a single pickled mapping, which
# .item() extracts from the 0-d object array.
# NOTE(review): allow_pickle=True executes arbitrary pickle payloads on load;
# only use this with files you created or otherwise trust.
data = jnp.load('test_data.npy', allow_pickle=True).item()
y = data['y']
X = data['X']

In [17]:
@partial(jax.custom_vjp, nondiff_argnums=(0, ))
def fixed_point_layer(f, params, x):
    """Return a fixed point ``z*`` of ``z -> f(params, x, z)``.

    Args:
        f: callable ``f(params, x, z)``; argument 0 is non-differentiable.
        params, x: differentiable inputs forwarded unchanged to ``f``.

    Returns:
        Whatever ``fwd_solver`` returns for the iteration started at
        ``[1, 1 / var(params)]`` with ``tol=1e-5``.
    """
    # Initial guess is hard-coded from the variance of `params`.
    start = jnp.stack([1, 1 / jnp.var(params)], axis=0)

    def step(z):
        return f(params, x, z)

    return fwd_solver(step, z_init=start, tol=1e-5)

def fixed_point_layer_fwd(f, params, x):
    """Forward rule of the custom VJP.

    Returns the fixed point together with the residuals
    ``(params, x, z_star)`` that the backward rule unpacks.
    """
    solution = fixed_point_layer(f, params, x)
    residuals = (params, x, solution)
    return solution, residuals

def fixed_point_layer_bwd(f, res, z_star_bar):
    """Backward rule of the custom VJP.

    Args:
        f: the non-differentiable update function.
        res: residuals ``(params, x, z_star)`` saved by the forward rule.
        z_star_bar: cotangent of the fixed point.

    Returns:
        Cotangents for ``(params, x)``.
    """
    params, x, z_star = res

    # Pullback of f w.r.t. (params, x) with z frozen at the fixed point.
    vjp_inputs = jax.vjp(lambda params, x: f(params, x, z_star), params, x)[1]
    # Pullback of f w.r.t. z at the fixed point.
    vjp_state = jax.vjp(lambda z: f(params, x, z), z_star)[1]

    def adjoint_step(u):
        return vjp_state(u)[0] + z_star_bar

    # The adjoint vector is itself a fixed point, found with the same solver.
    u_star = fwd_solver(adjoint_step, z_init=jnp.zeros_like(z_star), tol=1e-5)
    return vjp_inputs(u_star)

# Attach the forward/backward rules as the custom VJP of fixed_point_layer.
fixed_point_layer.defvjp(fixed_point_layer_fwd, fixed_point_layer_bwd)

In [4]:
# JIT-compiled version of the layer; argument 0 (the function f) is static.
layer = jax.jit(fixed_point_layer, static_argnums=0)

In [5]:
# Two hyper-prior values, both derived from the number of samples.
n_samples, n_terms = X.shape
hyper_prior = jnp.stack(
    [n_samples / 2, 1 / (n_samples / 2 * 1e-4)],
    axis=0,
)


def f(y, X, prior):
    """Update map for the solver: bayes_ridge_update with hyper_prior bound."""
    return bayes_ridge_update(y, X, prior, hyper_prior)


# Solve for the fixed point with the jitted layer.
z_star = layer(f, y, X)



In [6]:
# Fixed point returned by the jitted layer.
z_star

DeviceArray([ 0.3674249, 49.686718 ], dtype=float32)

In [7]:
# Build the pullback once: every jax.vjp call re-runs the forward fixed-point
# solve, and the returned pullback can be applied to several cotangents.
_, layer_vjp = jax.vjp(lambda y, X: layer(f, y, X), y, X)

# Cotangents (w.r.t. y and X) of the first and second component of z_star.
da = layer_vjp(jnp.array([1., 0.]))
db = layer_vjp(jnp.array([0., 1.]))

In [8]:
# Gradient of the summed fixed point w.r.t. X (jitted layer); [0] shows the first row.
jax.grad(lambda X: layer(f, y, X).sum())(X)[0]

DeviceArray([-6.0546863e-06,  4.5321195e-04, -3.9698742e-04,
             -6.0164703e-06, -4.3055079e-05,  2.2004379e-03,
             -2.6065094e-05, -2.2817476e-06,  2.1683940e-04,
              1.8109565e-03,  1.9534102e-05,  1.2577546e-05],            dtype=float32)

In [9]:
def fwd_solver(f, z_init, tol, max_iter=None):
    """Plain-Python fixed-point iteration ``z <- f(z)``.

    Iterates until the norm of the difference between two successive iterates
    is at most ``tol``. The stopping test is evaluated in a Python ``while``
    loop, so the iterates must be concrete arrays.

    Args:
        f: update function mapping an iterate to the next iterate.
        z_init: starting point.
        tol: stop once ``norm(z_prev - z) <= tol``.
        max_iter: optional cap on the number of evaluations of ``f``. ``None``
            (the default) keeps the original unbounded behaviour, which never
            returns if the iteration does not converge. ``f`` is always
            evaluated at least once.

    Returns:
        The last iterate computed.
    """
    z_prev, z = z_init, f(z_init)
    n_evals = 1
    while jnp.linalg.norm(z_prev - z) > tol:
        # Guard against a non-contractive f looping forever.
        if max_iter is not None and n_evals >= max_iter:
            break
        z_prev, z = z, f(z)
        n_evals += 1
    return z

In [10]:
# Same setup as for the jitted run.
n_samples, n_terms = X.shape
hyper_prior = jnp.stack(
    [n_samples / 2, 1 / (n_samples / 2 * 1e-4)],
    axis=0,
)


def f(y, X, prior):
    """Update map for the solver: bayes_ridge_update with hyper_prior bound."""
    return bayes_ridge_update(y, X, prior, hyper_prior)


# This time call the un-jitted layer directly.
z_star = fixed_point_layer(f, y, X)

In [11]:
# Fixed point from the un-jitted call, for comparison with the jitted result.
z_star

DeviceArray([ 0.3674248, 49.686718 ], dtype=float32)

In [12]:
# Same gradient w.r.t. X, now through the un-jitted fixed_point_layer; [0] shows the first row.
jax.grad(lambda X: fixed_point_layer(f, y, X).sum())(X)[0]

DeviceArray([-6.0547063e-06,  4.5321378e-04, -3.9698795e-04,
             -6.0165371e-06, -4.3054675e-05,  2.2004391e-03,
             -2.6065547e-05, -2.2816866e-06,  2.1683845e-04,
              1.8109567e-03,  1.9534278e-05,  1.2577598e-05],            dtype=float32)

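The gradients from the jitted and un-jitted runs agree to roughly five significant digits (float32), so our custom backprop looks correct :-)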

Now let's refactor it a little to make it clearer:

In [15]:
@partial(jax.custom_vjp, nondiff_argnums=(0, ))
def fixed_point_solver(f, params, x, z_init, tol=1e-5):
    """Return a fixed point of ``z -> f(params, x, z)``.

    Args:
        f: callable ``f(params, x, z)``; argument 0 is non-differentiable.
        params, x: differentiable inputs forwarded unchanged to ``f``.
        z_init: starting point handed to ``fwd_solver``.
        tol: tolerance handed to ``fwd_solver``.
    """
    def step(z):
        return f(params, x, z)

    return fwd_solver(step, z_init=z_init, tol=tol)

def fixed_point_solver_fwd(f, params, x, z_init, tol):
    """Forward rule: solve, and save ``(params, x, z_star, tol)`` for the backward pass."""
    solution = fixed_point_solver(f, params, x, z_init, tol)
    saved = (params, x, solution, tol)
    return solution, saved

def fixed_point_solver_bwd(f, res, z_star_bar):
    """Backward rule of the custom VJP.

    Args:
        f: the non-differentiable update function.
        res: residuals ``(params, x, z_star, tol)`` from the forward rule.
        z_star_bar: cotangent of the fixed point.

    Returns:
        ``(params_bar, x_bar, None, None)``; ``z_init`` and ``tol`` get ``None``.
    """
    params, x, z_star, tol = res

    # Pullbacks of f at the fixed point: w.r.t. the inputs and w.r.t. z.
    vjp_inputs = jax.vjp(lambda params, x: f(params, x, z_star), params, x)[1]
    vjp_state = jax.vjp(lambda z: f(params, x, z), z_star)[1]

    def adjoint_step(u):
        return vjp_state(u)[0] + z_star_bar

    # Solve the adjoint fixed-point problem with the same tolerance.
    u_star = fwd_solver(adjoint_step, z_init=jnp.zeros_like(z_star), tol=tol)
    params_bar, x_bar = vjp_inputs(u_star)
    return (params_bar, x_bar, None, None)

# Attach the forward/backward rules as the custom VJP of fixed_point_solver.
fixed_point_solver.defvjp(fixed_point_solver_fwd, fixed_point_solver_bwd)
# JIT-compile with f (argument 0) static.
layer = jax.jit(fixed_point_solver, static_argnums=(0, )) # static argnums should match non_diff argnums

In [None]:
n_samples, n_terms = X.shape
hyper_prior = jnp.stack(
    [n_samples / 2, 1 / (n_samples / 2 * 1e-4)],
    axis=0,
)
# The initial guess is now built by the caller instead of inside the layer.
z_init = jnp.stack([1, 1 / jnp.var(y)], axis=0)


def f(y, X, prior):
    """Update map for the solver: bayes_ridge_update with hyper_prior bound."""
    return bayes_ridge_update(y, X, prior, hyper_prior)


# NOTE: z_star is solved with tol=1e-4, while the gradient call below falls
# back to the solver's default tol.
z_star = layer(f, y, X, z_init, tol=1e-4)
jax.grad(
    lambda X: layer(f, y, X, z_init).sum()
)(X)[0]

In [None]:
# Gradient w.r.t. the initial guess; the backward rule returns None for z_init,
# so this is presumably zero. [0] shows the first component.
jax.grad(lambda z: layer(f, y, X, z).sum())(z_init)[0]

Great, that works. Now let's generalise it so the differentiable inputs are passed as a single tuple (`*args`).

In [3]:
def bayes_ridge_update(prior_params, y, X, hyper_prior_params):
    """Compute one update of the pair ``(alpha, beta)``.

    Args:
        prior_params: current ``(alpha, beta)`` pair.
        y: target values.
        X: 2-D matrix of shape ``(n_samples, n_terms)``; its columns are
            scaled to unit norm before use.
        hyper_prior_params: pair ``(a, b)`` entering the ``beta`` update.

    Returns:
        Array ``[alpha, beta]`` with the updated values.
    """
    alpha_old, beta_old = prior_params
    a, b = hyper_prior_params
    n_samples, n_terms = X.shape

    # Column-normalised matrix, its Gram matrix and that matrix's eigenvalues.
    Xn = X / jnp.linalg.norm(X, axis=0)
    gram = Xn.T @ Xn
    lam = jnp.linalg.eigvalsh(gram)

    # Intermediate quantities built from the current (alpha, beta).
    scaled_lam = beta_old * lam
    gamma_ = jnp.sum(scaled_lam / (alpha_old + scaled_lam))
    inv_term = jnp.linalg.inv(beta_old * gram + alpha_old * jnp.eye(n_terms))
    mn_vec = beta_old * inv_term @ Xn.T @ y

    # New estimates.
    residual = y - Xn @ mn_vec
    alpha = gamma_ / jnp.sum(mn_vec ** 2)
    beta = (n_samples - gamma_ + 2 * (a - 1)) / (jnp.sum(residual ** 2) + 2 * b)

    return jnp.stack([alpha, beta], axis=0)

In [51]:
@partial(jax.custom_vjp, nondiff_argnums=(0, ))
def fixed_point_solver(f, args, z_init, tol=1e-5):
    """Return a fixed point of ``z -> f(z, *args)``.

    Args:
        f: callable ``f(z, *args)``; argument 0 is non-differentiable.
        args: tuple of differentiable inputs unpacked into ``f`` after ``z``.
        z_init: starting point handed to ``fwd_solver``.
        tol: tolerance handed to ``fwd_solver``.
    """
    def step(z):
        return f(z, *args)

    return fwd_solver(step, z_init=z_init, tol=tol)

def fixed_point_solver_fwd(f, args, z_init, tol):
    """Forward rule: solve, and save ``(z_star, tol, args)`` for the backward pass."""
    solution = fixed_point_solver(f, args, z_init, tol)
    saved = (solution, tol, args)
    return solution, saved

def fixed_point_solver_bwd(f, res, z_star_bar):
    """Backward rule of the custom VJP.

    Args:
        f: the non-differentiable update function.
        res: residuals ``(z_star, tol, args)`` from the forward rule.
        z_star_bar: cotangent of the fixed point.

    Returns:
        ``(args_bar, None, None)``; ``z_init`` and ``tol`` get ``None``.
    """
    z_star, tol, args = res

    # Pullbacks of f at the fixed point: w.r.t. the args tuple and w.r.t. z.
    vjp_inputs = jax.vjp(lambda args: f(z_star, *args), args)[1]
    vjp_state = jax.vjp(lambda z: f(z, *args), z_star)[1]

    def adjoint_step(u):
        return vjp_state(u)[0] + z_star_bar

    # Solve the adjoint fixed-point problem with the same tolerance.
    u_star = fwd_solver(adjoint_step, z_init=jnp.zeros_like(z_star), tol=tol)
    (args_bar,) = vjp_inputs(u_star)
    return (args_bar, None, None)  # None for init and tol

# Attach the forward/backward rules as the custom VJP of the *args-based solver.
fixed_point_solver.defvjp(fixed_point_solver_fwd, fixed_point_solver_bwd)
# JIT-compile with f (argument 0) static.
layer = jax.jit(fixed_point_solver, static_argnums=(0, )) # static argnums should match non_diff argnums

In [52]:
n_samples, n_terms = X.shape
# (a, b) pair consumed by bayes_ridge_update.
hyper_prior = jnp.stack(
    [n_samples / 2, 1 / (n_samples / 2 * 1e-4)],
    axis=0,
)
z_init = jnp.stack([1, 1 / jnp.var(y)], axis=0)


def f(prior, y, X):
    """Update map for the solver: the iterate comes first, then the data."""
    return bayes_ridge_update(prior, y, X, hyper_prior)


# The differentiable inputs are passed as one tuple.
z_star = layer(f, (y, X), z_init, 1e-5)

In [53]:
# Fixed point [alpha, beta] from the *args-based solver.
z_star

DeviceArray([ 0.3674249, 49.686718 ], dtype=float32)

In [54]:
# Gradient of the summed fixed point w.r.t. X; [0] shows the first row.
jax.grad(lambda X: layer(f, (y, X), z_init).sum())(X)[0]

DeviceArray([-6.0546863e-06,  4.5321195e-04, -3.9698742e-04,
             -6.0164703e-06, -4.3055079e-05,  2.2004379e-03,
             -2.6065094e-05, -2.2817476e-06,  2.1683940e-04,
              1.8109565e-03,  1.9534102e-05,  1.2577546e-05],            dtype=float32)

In [55]:
# Gradient of the summed fixed point w.r.t. y; [0] shows the first entry along axis 0.
jax.grad(lambda y: layer(f, (y, X), z_init).sum())(y)[0]

DeviceArray([0.00433664], dtype=float32)

In [56]:
# The backward rule returns None for tol, so the gradient w.r.t. tol is zero.
jax.grad(lambda tol: layer(f, (y, X), z_init, tol=tol).sum())(1e-4)

DeviceArray(0., dtype=float32)

In [57]:
# The backward rule returns None for z_init, so the gradient w.r.t. z_init is zero.
jax.grad(lambda z_init: layer(f, (y, X), z_init, tol=1e-4).sum())(z_init)

DeviceArray([0., 0.], dtype=float32)