So we know for sure the problem is with alpha. Let's figure it out by creating some fake input and not doing the optimized way, but the standard back-propagating thorugh the while loop.

In [26]:
# %% Imports
from jax import numpy as jnp, random
import jax
from modax.data.kdv import doublesoliton
from modax.models import Deepmod
from modax.training.utils import create_update
from flax import optim

from modax.training import train_max_iter
from modax.training.losses.utils import precision, normal_LL
from modax.utils.forward_solver import fixed_point_solver, fwd_solver, fwd_solver_simple


%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [27]:
def update_sigma(gram, alpha, beta):
    sigma_inv = jnp.diag(alpha) + beta * gram
    L_inv = jnp.linalg.pinv(jnp.linalg.cholesky(sigma_inv))
    sigma_ = jnp.dot(L_inv.T, L_inv)
    return sigma_


def update_coeff(XT_y, beta, sigma_):
    coef_ = beta * jnp.linalg.multi_dot([sigma_, XT_y])
    return coef_


def update(prior, X, y, gram, XT_y, alpha_prior, beta_prior):
    n_samples, n_features = X.shape
    alpha, beta = prior[:-1], prior[-1]
    sigma = update_sigma(gram, alpha, beta)
    coeffs = update_coeff(XT_y, beta, sigma)

    # Update alpha and lambda
    rmse_ = jnp.sum((y - jnp.dot(X, coeffs)) ** 2)
    gamma_ = 1.0 - alpha * jnp.diag(sigma)

    # TODO: Cap alpha with some threshold.
    alpha = (gamma_ + 2.0 * alpha_prior[0]) / (
        (coeffs.squeeze() ** 2 + 2.0 * alpha_prior[1])
    )
    beta = (n_samples - gamma_.sum() + 2.0 * beta_prior[0]) / (
        rmse_ + 2.0 * beta_prior[1]
    )

    return jnp.concatenate([alpha, beta[jnp.newaxis]], axis=0)

Lets first create some fake input the neural network:

In [28]:
key = random.PRNGKey(42)
x = jnp.linspace(-10, 10, 100)
t = jnp.linspace(0.1, 1.0, 10)
t_grid, x_grid = jnp.meshgrid(t, x, indexing="ij")
u = doublesoliton(x_grid, t_grid, c=[5.0, 2.0], x0=[0.0, -5.0])

X = jnp.concatenate([t_grid.reshape(-1, 1), x_grid.reshape(-1, 1)], axis=1)
y = u.reshape(-1, 1)
y += 0.10 * jnp.std(y) * random.normal(key, y.shape)

In [29]:
# %% Building model and params
model = Deepmod([30, 30, 30, 1])
variables = model.init(key, X)

prediction, dt, theta, coeffs = model.apply(variables, X)
y = dt
X = theta / jnp.linalg.norm(theta, axis=0)

In [30]:
n_samples, n_features = theta.shape
prior_params_mse = (0.0, 0.0)
tau = precision(y, prediction, *prior_params_mse)

In [31]:
alpha_prior = (1e-6, 1e-6)
beta_prior = (n_samples / 2, n_samples / (2 * jax.lax.stop_gradient(tau)))

In [32]:
n_samples, n_features = X.shape
norm_weight = jnp.concatenate((jnp.ones((n_features,)), jnp.zeros((1,))), axis=0)
prior_init = jnp.concatenate([jnp.ones((n_features,)), (1.0 / (jnp.var(y) + 1e-7))[jnp.newaxis]], axis=0)
gram = jnp.dot(X.T, X)
XT_y = jnp.dot(X.T, y)

tol = 1e-3
max_iter = 1000 # low to keep it manageable

Let's do the non-custom diff version first:

In [33]:
update_fn = lambda z: update(z, X, y, gram, XT_y, alpha_prior, beta_prior)
prior_params, metrics = fwd_solver(update_fn, prior_init, norm_weight, tol=tol,max_iter=max_iter)

In [34]:
print(prior_params, metrics)

[1.4395750e+01 7.4044579e-01 2.8489959e-01 1.6106035e+01 1.0625963e-01
 1.0787464e-01 1.6446370e-01 9.6243591e+01 3.2280737e-01 1.6879633e-01
 7.6911402e-01 6.7515747e+03 4.7790924e+02] (DeviceArray(341, dtype=int32), DeviceArray(0.0002842, dtype=float32))


Quickly comparing with the custom version (should be the same since its a thin wrapper)

In [35]:
prior_params, metrics = fixed_point_solver(
    update,
    (X, y, gram, XT_y, alpha_prior, beta_prior),
    prior_init,
    norm_weight,
    tol=tol,
    max_iter=max_iter,
)

In [36]:
print(prior_params, metrics)

[1.4395750e+01 7.4044579e-01 2.8489959e-01 1.6106035e+01 1.0625963e-01
 1.0787464e-01 1.6446370e-01 9.6243591e+01 3.2280737e-01 1.6879633e-01
 7.6911402e-01 6.7515747e+03 4.7790924e+02] (DeviceArray(341, dtype=int32), DeviceArray(0.0002842, dtype=float32))


Okay that's fine, as expected. Now to check the grad

In [37]:
direct_grad = jax.grad(lambda X: fwd_solver(lambda z: update(z, X, y, gram, XT_y, alpha_prior, beta_prior), prior_init, norm_weight, tol=tol,max_iter=max_iter)[0][0])

In [38]:
direct_grad(X)

ValueError: Reverse-mode differentiation does not work for lax.while_loop or lax.fori_loop. Try using lax.scan instead.

So we cant backpropagate through the jax while loop and scan doesn't work for us; maybe a nice first thing to fix for me.

In [39]:
update_fn = lambda z: update(z, X, y, gram, XT_y, alpha_prior, beta_prior)
prior_params, metrics = fwd_solver_simple(update_fn, prior_init, norm_weight, tol=tol,max_iter=max_iter)

In [40]:
prior_params

Buffer([1.43955908e+01, 7.40445554e-01, 2.84899831e-01, 1.61059723e+01,
        1.06259637e-01, 1.07874796e-01, 1.64463893e-01, 9.62457809e+01,
        3.22807193e-01, 1.68796569e-01, 7.69117355e-01, 6.75160400e+03,
        4.77909241e+02], dtype=float32)

In [41]:
metrics

(Buffer(606, dtype=int32), Buffer(0.00091333, dtype=float32))

Result from forward pass is the same, now let's try to backpropagate:

In [42]:
direct_grad = jax.grad(lambda X: fwd_solver_simple(lambda z: update(z, X, y, gram, XT_y, alpha_prior, beta_prior), prior_init, norm_weight, tol=tol,max_iter=max_iter)[0][0])

In [43]:
direct_grad(X)

DeviceArray([[ 1.9318270e-04,  8.6529669e-04,  1.3994672e-03, ...,
               1.8171556e-03,  8.4207469e-04,  1.3879843e-07],
             [ 1.3219929e-04,  5.9214211e-04,  9.5768663e-04, ...,
               1.2435211e-03,  5.7625090e-04,  9.4982909e-08],
             [ 6.4624714e-05,  2.8946475e-04,  4.6815886e-04, ...,
               6.0788682e-04,  2.8169638e-04,  4.6431762e-08],
             ...,
             [-2.8131932e-03, -1.2600745e-02, -2.0379514e-02, ...,
              -2.6462026e-02, -1.2262574e-02, -2.0212287e-06],
             [-2.8625680e-03, -1.2821916e-02, -2.0737207e-02, ...,
              -2.6926521e-02, -1.2477812e-02, -2.0567063e-06],
             [-2.9104475e-03, -1.3036371e-02, -2.1084065e-02, ...,
              -2.7376872e-02, -1.2686521e-02, -2.0911075e-06]],            dtype=float32)

That didnt even take that long... Now to figure out why the custom doesn't work. First let's do the math again; maybe we can find a singularity?

That's not super easy due to derivatives of matrices. I'll make a script so I can debug the code. The error has to be in the three lines of the custom backprop, so we're pretty close to the origin.

After I made the script it turned out it actually worked? Let's check that here too:

In [44]:
prior_params, metrics = fixed_point_solver(
    update,
    (X, y, gram, XT_y, alpha_prior, beta_prior),
    prior_init,
    norm_weight,
    tol=tol,
    max_iter=max_iter,
)

In [45]:
jax.grad(lambda X: fixed_point_solver(
    update,
    (X, y, gram, XT_y, alpha_prior, beta_prior),
    prior_init,
    norm_weight,
    tol=tol,
    max_iter=max_iter,
)[0][0])(X)

DeviceArray([[ 1.93200714e-04,  8.65380571e-04,  1.39960204e-03, ...,
               1.81733165e-03,  8.42159556e-04,  1.38852613e-07],
             [ 1.32212779e-04,  5.92204742e-04,  9.57787794e-04, ...,
               1.24365208e-03,  5.76313876e-04,  9.50208232e-08],
             [ 6.46337357e-05,  2.89506075e-04,  4.68225538e-04, ...,
               6.07973605e-04,  2.81737652e-04,  4.64520191e-08],
             ...,
             [-2.81337928e-03, -1.26016289e-02, -2.03809347e-02, ...,
              -2.64638923e-02, -1.22634852e-02, -2.02196475e-06],
             [-2.86276219e-03, -1.28228245e-02, -2.07386799e-02, ...,
              -2.69284099e-02, -1.24787455e-02, -2.05745619e-06],
             [-2.91064475e-03, -1.30372988e-02, -2.10855547e-02, ...,
              -2.73788143e-02, -1.26874642e-02, -2.09186919e-06]],            dtype=float32)

And that's the same as the inefficient version - so the backprop works...? So that means the issue is with the SBL function, and not the backprop.

Let's continue checking with the sbl function now

In [46]:
def SBL(
    X,
    y,
    prior_init=None,
    alpha_prior=(1e-6, 1e-6),
    beta_prior=(1e-6, 1e-6),
    tol=1e-3,
    max_iter=300,
    non_diff=True
):
    n_samples, n_features = X.shape
    norm_weight = jnp.concatenate((jnp.ones((n_features,)), jnp.zeros((1,))), axis=0)
    if prior_init is None:
        prior_init = jnp.concatenate(
            [jnp.ones((n_features,)), (1.0 / (jnp.var(y) + 1e-7))[jnp.newaxis]], axis=0
        )
    # adding zeros to z for coeffs
    gram = jnp.dot(X.T, X)
    XT_y = jnp.dot(X.T, y)

    prior_params, metrics = fixed_point_solver(
        update,
        (X, y, gram, XT_y, alpha_prior, beta_prior),
        prior_init,
        norm_weight,
        tol=tol,
        max_iter=max_iter,
    )
    
    if non_diff:
        prior = jax.lax.stop_gradient(prior_params) # no it doesnt backprop through the prior
    else:
        prior = prior_params
    loss, mn = evidence(X, y, prior, gram, XT_y, alpha_prior, beta_prior)

    return loss, mn, prior, metrics

def evidence(X, y, prior, gram, XT_y, alpha_prior, beta_prior):
    n_samples, n_features = X.shape
    alpha, beta = prior[:-1], prior[-1]

    sigma = update_sigma(gram, alpha, beta)
    coeffs = update_coeff(XT_y, beta, sigma)
    rmse_ = jnp.sum((y - jnp.dot(X, coeffs)) ** 2)

    score = jnp.sum(alpha_prior[0] * jnp.log(alpha) - alpha_prior[1] * alpha)
    score += beta_prior[0] * jnp.log(beta) - beta_prior[1] * beta
    score += 0.5 * (
        jnp.linalg.slogdet(sigma)[1]
        + n_samples * jnp.log(beta)
        + jnp.sum(jnp.log(alpha))
    )
    score -= 0.5 * (beta * rmse_ + jnp.sum(alpha * coeffs.squeeze() ** 2))

    return score.squeeze(), coeffs

In [47]:
non_grad_fn = jax.grad(lambda X: SBL(X, y, None, alpha_prior, beta_prior, tol, max_iter, non_diff=True)[2][0])
print(non_grad_fn(X))

[[0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 ...
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]]


In [53]:
grad_fn = jax.grad(lambda X: SBL(X, y, None, alpha_prior, beta_prior, tol, max_iter=300, non_diff=False)[2][0])
print(grad_fn(X))

[[ 1.2180909e+00  5.9950333e+00  9.6226645e+00 ...  1.2477463e+01
   5.9092836e+00  1.7023655e-03]
 [ 1.2702627e+00  5.9797325e+00  9.6257296e+00 ...  1.2490816e+01
   5.8725410e+00  1.5124257e-03]
 [ 1.3276619e+00  5.9609957e+00  9.6262436e+00 ...  1.2501865e+01
   5.8300838e+00  1.3015452e-03]
 ...
 [ 3.2141671e+00  2.7458637e+00  5.7366452e+00 ...  7.8860006e+00
   1.6516140e+00 -8.2875732e-03]
 [ 3.2597971e+00  2.7483113e+00  5.7631011e+00 ...  7.9280229e+00
   1.6366050e+00 -8.4358566e-03]
 [ 3.3040833e+00  2.7508843e+00  5.7890759e+00 ...  7.9691882e+00
   1.6222491e+00 -8.5795633e-03]]


In [49]:
X

DeviceArray([[ 3.1622779e-02, -1.3295893e-03, -8.8769535e-05, ...,
              -1.7141082e-06, -8.8170779e-08,  2.2782169e-09],
             [ 3.1622779e-02, -1.3617877e-03, -8.7602079e-05, ...,
              -1.4097492e-06, -6.9869365e-08,  3.9560764e-09],
             [ 3.1622779e-02, -1.3933505e-03, -8.5470499e-05, ...,
              -1.1204592e-06, -5.2953165e-08,  5.0470441e-09],
             ...,
             [ 3.1622779e-02, -1.6162228e-03,  5.6422185e-05, ...,
              -2.0657141e-04,  5.5559572e-06,  1.4853565e-06],
             [ 3.1622779e-02, -1.5948149e-03,  6.1155952e-05, ...,
              -1.9879382e-04,  5.8731534e-06,  1.3662623e-06],
             [ 3.1622779e-02, -1.5717234e-03,  6.5599801e-05, ...,
              -1.9107455e-04,  6.1442561e-06,  1.2423093e-06]],            dtype=float32)