In [1]:
# %% Imports
from jax import numpy as jnp, random
from modax.data.kdv import doublesoliton
from modax.models import Deepmod
from modax.training.utils import create_update
from flax import optim
import jax
from modax.training import train_max_iter
from modax.training.losses.SBL import loss_fn_SBL

%load_ext autoreload
%autoreload 2

In [168]:
key = random.PRNGKey(42)
max_iterations = 672

# Running noise levels
x = jnp.linspace(-10, 10, 100)
t = jnp.linspace(0.1, 1.0, 10)
t_grid, x_grid = jnp.meshgrid(t, x, idexing="ij")
u = doublesoliton(x_grid, t_grid, c=[5.0, 2.0], x0=[0.0, -5.0])

X = jnp.concatenate([t_grid.reshape(-1, 1), x_grid.reshape(-1, 1)], axis=1)
y = u.reshape(-1, 1)
y_noisy = y + 0.01 * jnp.std(y) * random.normal(key, y.shape)

In [20]:
model = Deepmod([30, 30, 30, 1], (5, 4))
optimizer_def = optim.Adam(learning_rate=2e-3, beta1=0.99, beta2=0.99)
update_fn = create_update(loss_fn_SBL, (model, X, y_noisy, True))

variables = model.init(key, X)
state, params = variables.pop("params")
state = (state, {"prior_init": None})  # adding prior to state
optimizer = optimizer_def.create(params)

In [3]:
key = random.PRNGKey(42)
max_iterations = 20000

# Running noise levels
x = jnp.linspace(-10, 10, 100)
t = jnp.linspace(0.1, 1.0, 10)
t_grid, x_grid = jnp.meshgrid(t, x, idexing="ij")
u = doublesoliton(x_grid, t_grid, c=[5.0, 2.0], x0=[0.0, -5.0])

X = jnp.concatenate([t_grid.reshape(-1, 1), x_grid.reshape(-1, 1)], axis=1)
y = u.reshape(-1, 1)
y_noisy = y + 0.30 * jnp.std(y) * random.normal(key, y.shape)

In [6]:
model = Deepmod([30, 30, 30, 1], (3, 2))
optimizer_def = optim.Adam(learning_rate=2e-3, beta1=0.99, beta2=0.99)
update_fn = create_update(loss_fn_SBL, (model, X, y_noisy, True))

variables = model.init(key, X)
state, params = variables.pop("params")
state = (state, {"prior_init": None})  # adding prior to state
optimizer = optimizer_def.create(params)

In [7]:
new_optimizer, new_state = train_max_iter(
            update_fn,
            optimizer,
            state,
            max_iterations)


Loss step 0: 1421.06201171875
Loss step 500: -1819.4847412109375
Loss step 1000: -2456.889404296875
Loss step 1500: -2754.833740234375
Loss step 2000: -2779.510498046875
Loss step 2500: -2782.551513671875
Loss step 3000: -2783.498779296875
Loss step 3500: -2784.201904296875
Loss step 4000: -2784.945068359375
Loss step 4500: -2784.926513671875
Loss step 5000: -2786.097412109375
Loss step 5500: -2786.4609375
Loss step 6000: -2787.45654296875
Loss step 6500: -2788.13818359375
Loss step 7000: -2789.01904296875
Loss step 7500: -2789.646240234375
Loss step 8000: -2790.39599609375
Loss step 8500: -2790.65771484375
Loss step 9000: -2791.02685546875
Loss step 9500: -2790.949951171875
Loss step 10000: -2791.286865234375
Loss step 10500: -2792.032470703125
Loss step 11000: -2791.7353515625
Loss step 11500: -2791.941162109375
Loss step 12000: -2792.588134765625
Loss step 12500: -2792.359375
Loss step 13000: -2792.79833984375
Loss step 13500: -2792.842041015625
Loss step 14000: -2793.255615234375
L

In [160]:
max_iterations = 6920

In [161]:
new_optimizer, new_state = train_max_iter(
            update_fn,
            new_optimizer,
            new_state,
            max_iterations)


Loss step 0: -5733.99951171875
Loss step 500: -5805.4658203125
Loss step 1000: -5858.4013671875
Loss step 1500: -5963.66552734375
Loss step 2000: -5777.2314453125
Loss step 2500: -5435.1142578125
Loss step 3000: -5722.0380859375
Loss step 3500: -5765.63671875
Loss step 4000: -5847.66796875
Loss step 4500: -5744.1875
Loss step 5000: -5847.427734375
Loss step 5500: -5831.9853515625
Loss step 6000: -5525.16845703125
Loss step 6500: -5595.25390625
6920


In [162]:
model_state, loss_state = new_state
variables = {"params": new_optimizer.target, **model_state}

In [163]:
pred, dt, theta, coeffs = model.apply(variables, X)

RuntimeError: cuSolver execution failed

In [92]:
tau = 1 / jnp.mean((pred - y)**2)
beta_prior = (X.shape[0] / 2, X.shape[0] / tau) 

In [93]:
from modax.linear_model.SBL import SBL

In [118]:
evidence, coeffs, prior, metrics = SBL(theta, dt, tol=1e-3, hyper_prior=((1e-6, 1e-6), beta_prior), max_iter=2000)
prior, noise = prior[:-1], prior[-1]

In [119]:
prior.shape

(30,)

In [120]:
noise

DeviceArray(35.04488, dtype=float32)

In [121]:
metrics

(DeviceArray(2001, dtype=int32),
 DeviceArray(0.33096713, dtype=float32),
 DeviceArray([ 4.9552847e-02, -3.2397120e+00,  2.9729975e-03,
              -5.3110254e-01, -1.2173046e-02, -8.9378515e-03,
              -1.3019951e-01,  2.5155051e+00,  7.9584103e-03,
              -8.3584118e-01, -8.2758162e-04,  3.6111546e-03,
               1.3356283e-03,  3.0777549e-02,  2.4774979e-04,
               9.5186532e-01, -7.4547366e-04, -1.1923624e-04,
               2.6828152e-04, -7.8906360e+00, -5.3846417e-04,
              -8.2665271e-01, -3.8805007e-04, -2.2338545e-03,
               1.9299846e-04,  3.0750492e+00, -6.7338059e-03,
               2.5646013e-01,  1.6463421e-04,  1.0001244e-03],            dtype=float32))

In [122]:
evidence

DeviceArray(774.44385, dtype=float32)

In [123]:
from sklearn.linear_model import ARDRegression

In [124]:
beta_prior

(500.0, DeviceArray(15.850436, dtype=float32))

In [132]:
#reg = ARDRegression(fit_intercept=False, n_iter=5000, tol=1e-3, verbose=True, compute_score=True, alpha_1=beta_prior[0], alpha_2=beta_prior[1], threshold_lambda=1e4)
#reg.fit(theta, dt.squeeze())

reg = ARDRegression(fit_intercept=False, n_iter=10000, tol=1e-3, verbose=True, compute_score=True)
reg.fit(theta, dt.squeeze())

Converged after 14 iterations


ARDRegression(compute_score=True, fit_intercept=False, n_iter=10000,
              verbose=True)

In [133]:
reg.scores_[-1]

2364.2839902757696

In [134]:
reg.n_iter

10000

In [135]:
reg.alpha_

310.9164734804584

In [136]:
reg.coef_[:, None]

array([[ 0.05877547],
       [-3.31372212],
       [ 0.        ],
       [-0.51134804],
       [ 0.        ],
       [ 0.        ],
       [-0.35908504],
       [ 4.09710355],
       [-0.19189519],
       [-1.1667189 ],
       [ 0.        ],
       [ 0.        ],
       [ 0.41027288],
       [-6.79879527],
       [ 0.32251991],
       [ 1.90366085],
       [ 0.        ],
       [ 0.        ],
       [-0.06397146],
       [ 0.4204566 ],
       [-0.09195078],
       [-1.09869162],
       [ 0.        ],
       [ 0.        ],
       [ 0.        ],
       [ 0.66802318],
       [ 0.        ],
       [ 0.22817827],
       [ 0.        ],
       [ 0.        ]])

In [138]:
jnp.stack([prior, reg.lambda_], axis=1)

DeviceArray([[3.98927795e+02, 2.88755066e+02],
             [9.52427238e-02, 9.10642445e-02],
             [2.33041094e+04, 2.80085312e+05],
             [3.53271675e+00, 3.82210374e+00],
             [6.28030029e+03, 1.88773066e+04],
             [1.15827148e+04, 1.11728891e+05],
             [5.75298615e+01, 7.71865511e+00],
             [1.56551600e-01, 5.94956987e-02],
             [5.23519336e+03, 2.69252777e+01],
             [1.28710139e+00, 7.32866466e-01],
             [7.30539375e+04, 4.06037500e+04],
             [3.75362812e+04, 8.98617812e+04],
             [9.10812793e+03, 5.82656193e+00],
             [1.15768995e+01, 2.15826016e-02],
             [3.05374414e+04, 9.56430817e+00],
             [7.79449999e-01, 2.75549799e-01],
             [1.24610547e+05, 4.92433516e+04],
             [1.29742672e+05, 4.72120281e+05],
             [1.92278672e+04, 2.07681198e+02],
             [1.60088576e-02, 3.77857089e+00],
             [3.06472617e+04, 1.17738655e+02],
             

In [139]:
jnp.concatenate([coeffs, reg.coef_[:, None]], axis=1)

DeviceArray([[ 4.8696157e-02,  5.8775466e-02],
             [-3.2245805e+00, -3.3137221e+00],
             [ 2.1546360e-03,  0.0000000e+00],
             [-5.4744256e-01, -5.1134807e-01],
             [-1.2097111e-02,  0.0000000e+00],
             [-9.4791455e-03,  0.0000000e+00],
             [-1.2721595e-01, -3.5908505e-01],
             [ 2.4134932e+00,  4.0971036e+00],
             [ 7.0957816e-03, -1.9189519e-01],
             [-7.1535170e-01, -1.1667188e+00],
             [-1.0213110e-03,  0.0000000e+00],
             [ 3.8789716e-03,  0.0000000e+00],
             [ 1.1479468e-03,  4.1027287e-01],
             [ 2.4756903e-02, -6.7987952e+00],
             [ 4.1476483e-04,  3.2251993e-01],
             [ 7.4088562e-01,  1.9036609e+00],
             [-8.0181990e-04,  0.0000000e+00],
             [-3.3212517e-04,  0.0000000e+00],
             [ 1.2976903e-04, -6.3971460e-02],
             [-7.7441735e+00,  4.2045659e-01],
             [-7.9772383e-04, -9.1950782e-02],
             

In [109]:
@jax.jit
def update_posterior(gram, XT_y, prior):
    alpha, beta = prior
    L = jnp.linalg.cholesky(jnp.diag(alpha) + beta * gram)
    R = jax.scipy.linalg.solve_triangular(L, jnp.eye(alpha.shape[0]), check_finite=False, lower=True)
    sigma = jnp.dot(R.T, R)
    mean = beta * jnp.dot(sigma, XT_y)

    return mean, sigma

In [108]:
coeffs, _ = update_posterior(gram, XT_y, (prior, noise))

In [114]:
coeffs

DeviceArray([[-5.7919014e-02],
             [-3.0032797e+00],
             [-2.2747514e-01],
             [-8.2916158e-01],
             [-4.8000738e-02],
             [ 1.7009124e-03],
             [-2.5027883e-01],
             [-4.2495742e+00],
             [-5.1446104e-01],
             [ 1.1387956e+00],
             [-9.3527526e-02],
             [ 7.0195116e-02],
             [ 3.1325640e-04],
             [-4.7878356e+00],
             [ 2.8211581e-02],
             [-2.1874134e-03],
             [ 1.5482834e-01],
             [-1.3086452e-01],
             [-5.6049746e-01],
             [ 1.3943349e+01],
             [ 1.6021583e-03],
             [ 1.3293099e-04],
             [-5.8732081e-02],
             [ 7.3281229e-02],
             [-1.0674131e-03],
             [-6.7824774e+00],
             [ 6.9751003e-04],
             [-2.5307190e-01],
             [-9.2636282e-04],
             [-1.2664698e-02]], dtype=float32)

In [125]:
def update_posterior(gram, XT_y, prior):
    alpha, beta = prior
    L = jnp.linalg.cholesky(jnp.diag(alpha) + beta * gram)
    print(L)
    R = jax.scipy.linalg.solve_triangular(L, jnp.eye(alpha.shape[0]), check_finite=False, lower=True)
    sigma = jnp.dot(R.T, R)
    mean = beta * jnp.dot(sigma, XT_y)

    return mean, sigma

In [126]:
coeffs, _ = update_posterior(gram, XT_y, (prior, noise))

[[ 2.25791779e+02  0.00000000e+00  0.00000000e+00  0.00000000e+00
   0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
   0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
   0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
   0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
   0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
   0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
   0.00000000e+00  0.00000000e+00]
 [ 6.57620525e+00  8.84318237e+01  0.00000000e+00  0.00000000e+00
   0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
   0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
   0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
   0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
   0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
   0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
   0.00000000e+00  0.00000000e+00]
 [-6.7

In [127]:
coeffs

DeviceArray([[-5.7919014e-02],
             [-3.0032797e+00],
             [-2.2748420e-01],
             [-8.2918572e-01],
             [-4.8003003e-02],
             [ 1.6994019e-03],
             [-2.5027883e-01],
             [-4.2495742e+00],
             [-5.1450938e-01],
             [ 1.1388923e+00],
             [-9.3521483e-02],
             [ 7.0201159e-02],
             [ 3.1212348e-04],
             [-4.7882223e+00],
             [ 2.8223665e-02],
             [-2.1843922e-03],
             [ 1.5482230e-01],
             [-1.3080409e-01],
             [-5.6048536e-01],
             [ 1.3940255e+01],
             [ 1.6021583e-03],
             [ 1.2386752e-04],
             [-5.8736615e-02],
             [ 7.3305398e-02],
             [-1.0674131e-03],
             [-6.7824774e+00],
             [ 6.9751003e-04],
             [-2.5307190e-01],
             [-9.2636282e-04],
             [-1.2664698e-02]], dtype=float32)