In [1]:
# %% Imports
from jax.api import value_and_grad
from jax.config import config

config.update("jax_debug_nans", True)

from jax import numpy as jnp, random
import jax
from modax.data.burgers import burgers
from modax.data.kdv import doublesoliton
from modax.models import Deepmod
from modax.training.utils import create_update
from flax import optim

from modax.training.losses.SBL import loss_fn_SBL
from modax.training import train_max_iter
from sklearn.linear_model import ARDRegression
from modax.linear_model.SBL import SBL

from flax.core import unfreeze
import flax
%load_ext autoreload
%autoreload 2

In [None]:
# %% Making data
# Noisy Burgers data on a (t, x) grid: 20 time points in [0.5, 5.0] and
# 50 spatial points in [-3, 4].
key = random.PRNGKey(42)

x = jnp.linspace(-3, 4, 50)
t = jnp.linspace(0.5, 5.0, 20)
t_grid, x_grid = jnp.meshgrid(t, x, indexing="ij")
u = burgers(x_grid, t_grid, 0.1, 1.0)

# One (t, x) row per grid point; the target gets Gaussian noise scaled to
# 10% of the standard deviation of the clean field.
# NOTE(review): the same `key` is reused for model.init below; JAX recommends
# random.split for independent random streams.
X = jnp.stack([t_grid.ravel(), x_grid.ravel()], axis=1)
clean_y = u.reshape(-1, 1)
y = clean_y + 0.10 * jnp.std(clean_y) * random.normal(key, clean_y.shape)


In [None]:
# %% Building model and params
# Fresh Deepmod network and Adam optimizer. `variables.pop("params")` splits the
# initialised variables into the remaining collections (used as model state)
# and the trainable params; a "prior_init" slot is appended to the state.
model = Deepmod([30, 30, 30, 1])
variables = model.init(key, X)

model_collections, params = variables.pop("params")
optimizer = optim.Adam(learning_rate=2e-3, beta1=0.99, beta2=0.99).create(params)
state = (model_collections, {"prior_init": None})

update_fn = create_update(loss_fn_SBL, (model, X, y, False))


In [None]:
# One value-and-gradient evaluation of the SBL loss at the current parameters.
# has_aux=True passes loss_fn_SBL's auxiliary output through untouched; it is
# unpacked into (updated_state, metrics, output).
grad_fn = jax.value_and_grad(loss_fn_SBL, has_aux=True)
(loss, aux), grad = grad_fn(optimizer.target, state, model, X, y)
updated_state, metrics, output = aux

In [None]:
# Scalar SBL loss from the value_and_grad evaluation.
print(str(loss))

In [None]:
def check_grads_for_nan(grads):
    """Return whether any leaf of a gradient pytree contains a NaN.

    Args:
        grads: pytree of arrays, e.g. the (Frozen)dict of gradients returned by
            `jax.value_and_grad`.

    Returns:
        0-d boolean jax array: True if any leaf has a NaN, False otherwise
        (including for a pytree with no leaves).
    """
    # tree_leaves walks nested dicts/FrozenDicts directly, so there is no need
    # to unfreeze/flatten or to copy every gradient into one big vector.
    leaves = jax.tree_util.tree_leaves(grads)
    if not leaves:
        # Concatenating an empty list would raise; an empty tree has no NaNs.
        return jnp.array(False)
    return jnp.any(jnp.array([jnp.any(jnp.isnan(leaf)) for leaf in leaves]))

In [None]:
# Scan every gradient leaf for NaNs (jax_debug_nans is also enabled at the top).
has_nan_grads = check_grads_for_nan(grad)
has_nan_grads

In [None]:
# Auxiliary metrics returned by loss_fn_SBL in the value_and_grad cell.
sbl_metrics = metrics
sbl_metrics

Okay, so it reached the maximum number of iterations; maybe we should fix that later. Is that a problem, though?

No, it isn't, because the differentiation is independent of the number of iterations taken; it could just as well be 10 or 100.
We can probably fix this later using the fast marginalization property. Let's see how many iterations sklearn needs:

In [None]:
# Compare our SBL hyperparameters with sklearn's ARDRegression on the same library.
# A separately named dict is used for apply() so the `variables` returned by
# model.init is not clobbered; later cells pop "params" from it. The current
# parameters (optimizer.target) are used, matching the value_and_grad cell.
model_state, loss_state = state
apply_variables = {"params": optimizer.target, **model_state}
(prediction, dt, theta, coeffs), updated_model_state = model.apply(
    apply_variables, X, mutable=list(model_state.keys())
)

# Normalise each library column to unit L2 norm before regressing.
theta_normed = theta / jnp.linalg.norm(theta, axis=0)
n_samples, n_features = theta.shape
# Noise precision estimated from the network's residuals. It is cast to a Python
# float because sklearn's parameter validation expects real numbers, not jax arrays.
tau = float(1 / jnp.mean((y - prediction) ** 2))
# Gamma (shape, rate) hyperprior over the noise precision; its mean is tau.
hyper_prior = (n_samples / 2, n_samples / 2 * 1 / tau)

reg = ARDRegression(
    fit_intercept=False,
    compute_score=True,
    tol=1e-3 * tau,
    alpha_1=hyper_prior[0],
    alpha_2=hyper_prior[1],
    threshold_lambda=1e6,
)
reg.fit(theta_normed, dt.squeeze())

reg.scores_

Far fewer, but that's because of sklearn's convergence criterion; let's check the difference in alpha and beta:

In [None]:
# Side by side: sklearn's lambda_ in column 0, our SBL alpha in column 1.
lambda_vs_alpha = jnp.stack([reg.lambda_, metrics["alpha"]], axis=1)
print(lambda_vs_alpha)

In [None]:
# Relative difference of our alpha with respect to sklearn's lambda_, per term.
rel_diff = jnp.abs(reg.lambda_ - metrics["alpha"]) / reg.lambda_
rel_diff

That is fairly small, save for one or two terms. So everything seems alright here; let's run it for a while:

In [None]:
# Train with an iteration budget of 10,000 (train_max_iter's fourth argument).
MAX_ITERATIONS = 10_000
optimizer, state = train_max_iter(update_fn, optimizer, state, MAX_ITERATIONS)

Okay, so we're getting NaNs; the problem seems to be in the custom backprop. Let's increase the number of iterations:

In [None]:
# Reset optimizer and state from freshly initialised variables.
# `variables` may have been rebound to a plain dict ({"params": ..., **model_state})
# in the sklearn-comparison cell. On a plain dict, `.pop("params")` returns only
# the params, not a (rest, params) pair, so the unpacking below would go wrong.
# Re-initialising also makes this cell independent of which cells ran before it.
variables = model.init(key, X)

optimizer = optim.Adam(learning_rate=2e-3, beta1=0.99, beta2=0.99)
state, params = variables.pop("params")
optimizer = optimizer.create(params)

state = (state, {"prior_init": None})  # adding prior to state

In [None]:
# Train with an iteration budget of 10,000 (train_max_iter's fourth argument).
MAX_ITERATIONS = 10_000
optimizer, state = train_max_iter(update_fn, optimizer, state, MAX_ITERATIONS)

So it's not an issue with that, and also not with the jit. What about using the standard hyperprior instead; would that fix it?

In [None]:
# %% Building model and params
# Fresh Deepmod network and Adam optimizer; the remaining collections form the
# model state, to which a "prior_init" slot is appended.
# NOTE(review): identical to the earlier build cell. The hyperprior change
# described above presumably lives in library code reloaded via %autoreload
# (confirm).
model = Deepmod([30, 30, 30, 1])
variables = model.init(key, X)

model_collections, params = variables.pop("params")
optimizer = optim.Adam(learning_rate=2e-3, beta1=0.99, beta2=0.99).create(params)
state = (model_collections, {"prior_init": None})

update_fn = create_update(loss_fn_SBL, (model, X, y, False))


In [None]:
# Train with an iteration budget of 10,000 (train_max_iter's fourth argument).
MAX_ITERATIONS = 10_000
optimizer, state = train_max_iter(update_fn, optimizer, state, MAX_ITERATIONS)

So that ran without errors, good. That would imply it's an issue with the hyperprior. Now, what if we run KdV?

In [3]:
# KdV double-soliton data: 10 time points in [0.1, 1.0] and 100 spatial points
# in [-10, 10].
key = random.PRNGKey(42)
x = jnp.linspace(-10, 10, 100)
t = jnp.linspace(0.1, 1.0, 10)
t_grid, x_grid = jnp.meshgrid(t, x, indexing="ij")
u = doublesoliton(x_grid, t_grid, c=[5.0, 2.0], x0=[0.0, -5.0])

# One (t, x) row per grid point; the target gets Gaussian noise scaled to 10%
# of the standard deviation of the clean field.
# NOTE(review): the same `key` is reused for model.init below; JAX recommends
# random.split for independent random streams.
X = jnp.stack([t_grid.ravel(), x_grid.ravel()], axis=1)
clean_y = u.reshape(-1, 1)
y = clean_y + 0.10 * jnp.std(clean_y) * random.normal(key, clean_y.shape)

In [None]:
# %% Building model and params
# Fresh Deepmod network and Adam optimizer for the KdV data; the remaining
# collections form the model state, to which a "prior_init" slot is appended.
model = Deepmod([30, 30, 30, 1])
variables = model.init(key, X)

model_collections, params = variables.pop("params")
optimizer = optim.Adam(learning_rate=2e-3, beta1=0.99, beta2=0.99).create(params)
state = (model_collections, {"prior_init": None})

update_fn = create_update(loss_fn_SBL, (model, X, y, False))


In [None]:
# Train with an iteration budget of 10,000 (train_max_iter's fourth argument).
MAX_ITERATIONS = 10_000
optimizer, state = train_max_iter(update_fn, optimizer, state, MAX_ITERATIONS)

Okay, so that gets stuck; weird. Maybe remove the maximum number of iterations?

In [None]:
# %% Building model and params
# Fresh Deepmod network and Adam optimizer; the remaining collections form the
# model state, to which a "prior_init" slot is appended.
# NOTE(review): identical to the earlier build cell. The max-iteration change
# described above presumably lives in library code reloaded via %autoreload
# (confirm).
model = Deepmod([30, 30, 30, 1])
variables = model.init(key, X)

model_collections, params = variables.pop("params")
optimizer = optim.Adam(learning_rate=2e-3, beta1=0.99, beta2=0.99).create(params)
state = (model_collections, {"prior_init": None})

update_fn = create_update(loss_fn_SBL, (model, X, y, False))


In [None]:
# Train with an iteration budget of 10,000 (train_max_iter's fourth argument).
MAX_ITERATIONS = 10_000
optimizer, state = train_max_iter(update_fn, optimizer, state, MAX_ITERATIONS)

Okay, so it's not only that... Maybe let's use a really basic fixed-point solver for the backward pass?

In [None]:
# %% Building model and params
# Fresh Deepmod network and Adam optimizer; the remaining collections form the
# model state, to which a "prior_init" slot is appended.
# NOTE(review): identical to the earlier build cell. The backward-pass solver
# change described above presumably lives in library code reloaded via
# %autoreload (confirm).
model = Deepmod([30, 30, 30, 1])
variables = model.init(key, X)

model_collections, params = variables.pop("params")
optimizer = optim.Adam(learning_rate=2e-3, beta1=0.99, beta2=0.99).create(params)
state = (model_collections, {"prior_init": None})

update_fn = create_update(loss_fn_SBL, (model, X, y, False))


In [None]:
# Train with an iteration budget of 10,000 (train_max_iter's fourth argument).
MAX_ITERATIONS = 10_000
optimizer, state = train_max_iter(update_fn, optimizer, state, MAX_ITERATIONS)

Same issue... What if we use the simple solver in the forward pass as well?

In [None]:
# %% Building model and params
# Fresh Deepmod network and Adam optimizer; the remaining collections form the
# model state, to which a "prior_init" slot is appended.
# NOTE(review): identical to the earlier build cell. The forward-pass solver
# change described above presumably lives in library code reloaded via
# %autoreload (confirm).
model = Deepmod([30, 30, 30, 1])
variables = model.init(key, X)

model_collections, params = variables.pop("params")
optimizer = optim.Adam(learning_rate=2e-3, beta1=0.99, beta2=0.99).create(params)
state = (model_collections, {"prior_init": None})

update_fn = create_update(loss_fn_SBL, (model, X, y, False))


In [None]:
# Train with an iteration budget of 10,000 (train_max_iter's fourth argument).
MAX_ITERATIONS = 10_000
optimizer, state = train_max_iter(update_fn, optimizer, state, MAX_ITERATIONS)

Okay, so it gets stuck; does it get stuck in the forward or the backward pass? Let's limit the forward pass.

In [4]:
# %% Building model and params
# Fresh Deepmod network and Adam optimizer; the remaining collections form the
# model state, to which a "prior_init" slot is appended.
# NOTE(review): identical to the earlier build cell. The forward-pass iteration
# limit described above presumably lives in library code reloaded via
# %autoreload (confirm).
model = Deepmod([30, 30, 30, 1])
variables = model.init(key, X)

model_collections, params = variables.pop("params")
optimizer = optim.Adam(learning_rate=2e-3, beta1=0.99, beta2=0.99).create(params)
state = (model_collections, {"prior_init": None})

update_fn = create_update(loss_fn_SBL, (model, X, y, False))


In [None]:
# Train with an iteration budget of 10,000 (train_max_iter's fourth argument).
MAX_ITERATIONS = 10_000
optimizer, state = train_max_iter(update_fn, optimizer, state, MAX_ITERATIONS)

In [10]:
# %% Building model and params
# Fresh Deepmod network and Adam optimizer; the remaining collections form the
# model state, to which a "prior_init" slot is appended.
# NOTE(review): identical to the earlier build cell; any experimental change
# presumably lives in library code reloaded via %autoreload (confirm).
model = Deepmod([30, 30, 30, 1])
variables = model.init(key, X)

model_collections, params = variables.pop("params")
optimizer = optim.Adam(learning_rate=2e-3, beta1=0.99, beta2=0.99).create(params)
state = (model_collections, {"prior_init": None})

update_fn = create_update(loss_fn_SBL, (model, X, y, False))


In [11]:
# Train with an iteration budget of 10,000 (train_max_iter's fourth argument).
MAX_ITERATIONS = 10_000
optimizer, state = train_max_iter(update_fn, optimizer, state, MAX_ITERATIONS)

Loss step 0: -1080.5452880859375
Loss step 1: -1500.6864013671875
Loss step 2: -1901.683837890625
Loss step 3: -1863.78271484375
Loss step 4: -1811.1790771484375
Loss step 5: -2057.734130859375
Loss step 6: -2330.02587890625
Loss step 7: -2495.3330078125
Loss step 8: -2549.15234375
Loss step 9: -2404.401611328125
Loss step 10: -2514.052490234375
Loss step 11: -2837.67724609375
Loss step 12: -2988.36962890625
Loss step 13: -3150.870849609375
Loss step 14: -3011.189697265625
Loss step 15: -3049.209228515625
Loss step 16: -3191.8525390625
Loss step 17: -3220.287109375
Loss step 18: -3403.50341796875
Loss step 19: -3307.265625
Loss step 20: -3272.236572265625
Loss step 21: -3322.76611328125
Loss step 22: -3396.041015625
Loss step 23: -3734.49072265625
Loss step 24: -3723.11083984375
Loss step 25: -3538.55810546875
Loss step 26: -3426.380615234375
Loss step 27: -3476.582763671875
Loss step 28: -3828.815185546875
Loss step 29: -4077.4521484375
Loss step 30: -3904.801513671875
Loss step 31: -

StoreException: Store empty