In [1]:
# %% Imports
from jax.api import value_and_grad
from jax.config import config

config.update("jax_debug_nans", True)

from jax import numpy as jnp, random
import jax
from modax.data.burgers import burgers
from modax.data.kdv import doublesoliton
from modax.models import Deepmod
from modax.training.utils import create_update
from flax import optim

from modax.training.losses.SBL import loss_fn_SBL
from modax.training import train_max_iter
from sklearn.linear_model import ARDRegression
from modax.linear_model.SBL import SBL

from modax.training.losses.utils import precision, normal_LL
from flax.core import unfreeze
from jax.test_util import check_grads
import flax
%load_ext autoreload
%autoreload 2

In [2]:
# %% Making data
key = random.PRNGKey(42)

x = jnp.linspace(-3, 4, 50)
t = jnp.linspace(0.5, 5.0, 20)
t_grid, x_grid = jnp.meshgrid(t, x, indexing="ij")
u = burgers(x_grid, t_grid, 0.1, 1.0)

X = jnp.concatenate([t_grid.reshape(-1, 1), x_grid.reshape(-1, 1)], axis=1)
y = u.reshape(-1, 1)
y += 0.10 * jnp.std(y) * random.normal(key, y.shape)


In [16]:
# %% Building model and params
model = Deepmod([30, 30, 30, 1])
variables = model.init(key, X)

optimizer = optim.Adam(learning_rate=2e-3, beta1=0.99, beta2=0.99)
state, params = variables.pop("params")
optimizer = optimizer.create(params)

state = (state, {"prior_init": None})  # adding prior to state
update_fn = create_update(loss_fn_SBL, (model, X, y, False))


In [4]:
# Neural network
model_state, loss_state = state
variables = {"params": params, **model_state}
(prediction, dt, theta, coeffs), updated_model_state = model.apply(
    variables, X, mutable=list(model_state.keys())
)

In [5]:
# Calculating MSE
n_samples, n_features = theta.shape
prior_params_mse = (0.0, 0.0)
tau = precision(y, prediction, *prior_params_mse)
p_mse, MSE = normal_LL(prediction, y, tau)

print(tau, MSE, p_mse)

22.46149 0.044520646 136.96254


In [6]:
# Forward pass 
hyper_prior_params = (
    n_samples / 2,
    n_samples / (2 * jax.lax.stop_gradient(tau)),
)
theta_normed = theta / jnp.linalg.norm(theta, axis=0)
prior_init = jnp.concatenate(
    [jnp.ones((n_features,)), 1.0 / jnp.var(dt)[jnp.newaxis]]
)

In [7]:
p_reg, mn, prior, fwd_metric = SBL(
    theta_normed,
    dt,
    prior_init=prior_init,
    beta_prior=hyper_prior_params,
    tol=1e-3,
    max_iter=1000,
    )

Reg = jnp.mean((dt - theta_normed @ mn) ** 2)
print(fwd_metric)
print(p_reg)

(DeviceArray(1000, dtype=int32), DeviceArray(1.4309582, dtype=float32))
2775.8103


In [8]:
reg = ARDRegression(
    fit_intercept=False,
    compute_score=True,
    tol=1e-3,
    alpha_1=hyper_prior_params[0],
    alpha_2=hyper_prior_params[1],
    threshold_lambda=1e6,
)
reg.fit(theta_normed, dt.squeeze())
print(reg.lambda_)
print(reg.scores_[-1])

[7.95559698e-02 2.05021217e+03 1.91450442e+03 2.14286879e+03
 1.40382828e-01 2.94175074e+00 1.96672140e+00 2.91738600e+03
 6.45218792e+02 7.32835003e+02 1.14048039e+03 2.36206652e+03]
2775.5527


So forwards pass is the same!

In [None]:
jax.grad(lambda theta: SBL(theta, dt, prior_init, hyper_prior_params, tol=1e-3, max_iter=1000)[0])(theta_normed)
jax.grad(lambda dt: SBL(theta_normed, dt, prior_init, hyper_prior_params, tol=1e-3, max_iter=1000)[0])(dt)

We can calculate the gradient, so that's not an issue... should we check if its correct?

In [10]:
grad_fn = jax.value_and_grad(loss_fn_SBL, has_aux=True)
(loss, (updated_state, metrics, output)), grad = grad_fn(
    optimizer.target, state, model, X, y
)

In [11]:
def check_grads_for_nan(grads):
    flattened_grads = flax.traverse_util.flatten_dict(unfreeze(grads))
    flattened_grads = jnp.concatenate([data.flatten() for key, data in flattened_grads.items()])
    
    return jnp.any(jnp.isnan(flattened_grads)).item()

In [14]:
check_grads_for_nan(grad)

False

In [21]:
# %% Building model and params
model = Deepmod([30, 30, 30, 1])
variables = model.init(key, X)

optimizer = optim.Adam(learning_rate=1e-3)
state, params = variables.pop("params")
optimizer = optimizer.create(params)

state = (state, {"prior_init": None})  # adding prior to state
update_fn = create_update(loss_fn_SBL, (model, X, y, False))


optimizer, state = train_max_iter(update_fn, optimizer, state, 5000)

Loss step 0: -2912.77294921875
Invalid value encountered in the output of a jit function. Calling the de-optimized version.
Invalid value encountered in the output of a jit function. Calling the de-optimized version.
Invalid value encountered in the output of a jit function. Calling the de-optimized version.


FloatingPointError: invalid value (nan) encountered in while

Now let's run a loop and stop when the gradients go to nan

In [65]:
model = Deepmod([30, 30, 30, 1])
variables = model.init(key, X)

optimizer = optim.Adam(learning_rate=2e-3, beta1=0.99, beta2=0.99)
state, params = variables.pop("params")
optimizer = optimizer.create(params)

state = (state, {"prior_init": None})  # adding prior to state

In [None]:
for epoch in jnp.arange(1e4):
    grad_fn = jax.value_and_grad(loss_fn_SBL, has_aux=True)
    (loss, (updated_state, metrics, output)), grad = grad_fn(optimizer.target, state, model, X, y)
    
    if check_grads_for_nan(grad) is False:
        optimizer = optimizer.apply_gradient(grad)
        state = updated_state
    else:
        break
    
    if epoch % 10 == 0:
        print(f"Loss step {epoch}: {metrics['loss']}")

In [67]:
state

(FrozenDict({
     vars: {
         LeastSquares_0: {
             mask: DeviceArray([ True,  True,  True,  True,  True,  True,  True,  True,
                           True,  True,  True,  True], dtype=bool),
         },
     },
 }),
 {'prior_init': DeviceArray([1.5142787e-03, 5.0820215e-03, 1.0640540e-02, 3.9893231e-01,
               5.9549212e-02, 1.6179248e+02, 2.8611942e+02, 4.9036827e+01,
               3.8905190e-03, 1.2776897e-02, 2.8615201e+02, 3.0173154e+02,
               1.0062928e+00], dtype=float32)})

In [68]:
metrics

{'alpha': DeviceArray([1.5142787e-03, 5.0820215e-03, 1.0640540e-02, 3.9893231e-01,
              5.9549212e-02, 1.6179248e+02, 2.8611942e+02, 4.9036827e+01,
              3.8905190e-03, 1.2776897e-02, 2.8615201e+02, 3.0173154e+02],            dtype=float32),
 'bayes_coeffs': DeviceArray([[-2.5412722e+01],
              [ 1.3900923e+01],
              [-9.5993853e+00],
              [ 1.2207934e+00],
              [ 3.9642036e+00],
              [ 2.0861239e-03],
              [ 1.8932729e-05],
              [-9.0139527e-03],
              [ 1.5618009e+01],
              [-8.7008715e+00],
              [-3.8609485e-04],
              [ 5.8121525e-04]], dtype=float32),
 'beta': DeviceArray(1.0062928, dtype=float32),
 'coeffs': DeviceArray([[-0.69913083],
              [ 2.682685  ],
              [-3.0587525 ],
              [ 0.14649728],
              [ 0.13522597],
              [ 2.527585  ],
              [-3.324237  ],
              [-0.00548166],
              [ 0.3055557 ],
     