Bayesian regression works with the old code. Now let's make a nice new function a la sklearn.

In [1]:
# %% Imports
import jax
from jax import jit, numpy as jnp, lax, random
from functools import partial
from modax.utils.forward_solver import fixed_point_solver
from modax.linear_model.bayesian_regression import bayesianregression, evidence

from sklearn.linear_model import BayesianRidge


%load_ext autoreload 
%autoreload 2

In [2]:
data = jnp.load('test_data.npy', allow_pickle=True).item()
y, X = data['y'], data['X']

X_normed = X / jnp.linalg.norm(X, axis=0)

In [3]:
prior, metric = bayesianregression(X_normed, y)
print(prior)
print(metric)

[3.5665381e-01 2.7395762e+05]
(DeviceArray(4, dtype=int32), DeviceArray(6.0796738e-06, dtype=float32))


In [4]:
bayesianregression(X, y)

(DeviceArray([1.8114342e+01, 2.7395122e+05], dtype=float32),
 (DeviceArray(300, dtype=int32), DeviceArray(0.00158882, dtype=float32)))

In [5]:
reg = BayesianRidge(fit_intercept=False, compute_score=True)
reg.fit(X_normed, y.squeeze())

BayesianRidge(compute_score=True, fit_intercept=False)

In [6]:
reg.alpha_

273806.14266750304

In [18]:
reg.lambda_

0.3566588567031124

In [19]:
reg.coef_

array([ 0.02766591, -0.42259253,  3.77226037,  0.43094458, -0.1104252 ,
       -4.15133131,  0.2873178 , -0.61194118,  0.10678613, -1.11750869,
       -0.28417433, -0.04745628])

In [20]:
reg.scores_

array([ 761.94079555, 3578.52057138, 4772.05472899, 4772.05489891,
       4772.05489891])

In [21]:
prior

DeviceArray([3.566560e-01, 2.739576e+05], dtype=float32)

In [22]:
evidence(X_normed, y, prior, hyper_prior_params=(0.0, 0.0))

(DeviceArray(4772.328, dtype=float32),
 DeviceArray([[ 0.02767511],
              [-0.4225819 ],
              [ 3.7722263 ],
              [ 0.43099666],
              [-0.11044621],
              [-4.151463  ],
              [ 0.28747463],
              [-0.6120944 ],
              [ 0.10681057],
              [-1.117393  ],
              [-0.28420353],
              [-0.04737961]], dtype=float32))

# Implementing as a loss function

In [10]:
def loss_fn_bayesian_ridge(params, state, model, X, y):
    variables = {"params": params, **state}
    (prediction, dt, theta, coeffs), updated_state = model.apply(
        variables, X, mutable=list(state.keys())
    )
    
    n_samples = theta.shape[0]
    prior_params_mse = (0.0, 0.0)
    
    # MSE stuff
    tau = precision(y, prediction, *prior_params_mse)
    p_mse, MSE = normal_LL(prediction, y, tau)
    
    # Regression stuff
    hyper_prior_params = (n_samples/2, n_samples / (2 * jax.lax.stop_gradient(tau))) # we dont want the gradient
    theta_normed = theta / jnp.linalg.norm(theta, axis=0)
    prior, fwd_metric = bayesianregression(theta_normed, dt, hyper_prior_params=hyper_prior_params)
    p_reg, mn = evidence(theta_normed, dt, prior, hyper_prior_params=hyper_prior_params)
    Reg = jnp.mean((dt - theta_normed @ mn)**2)
    
    
    loss = -(p_mse + p_reg)
    metrics = {"loss": loss, 
               "p_mse": p_mse,
               "mse": MSE, 
               "p_reg": p_reg,
               "reg": Reg, 
               "bayes_coeffs": mn, 
               "coeffs": coeffs, 
               "alpha": prior[:-1], 
               "beta": prior[-1],
               "tau": tau,
              "its": fwd_metric[0],
              "gap": fwd_metric[1]}

    return loss, (updated_state, metrics, (prediction, dt, theta, mn))

In [11]:
from modax.data.burgers import burgers
from modax.data.kdv import doublesoliton
from modax.training import train_max_iter
from modax.models import Deepmod
from modax.training.utils import create_update
from modax.training.losses.utils import precision, normal_LL
from modax.training.losses import loss_fn_pinn

from flax import optim

In [14]:
# %% Making data
key = random.PRNGKey(42)

x = jnp.linspace(-3, 4, 50)
t = jnp.linspace(0.5, 5.0, 20)
t_grid, x_grid = jnp.meshgrid(t, x, indexing="ij")
u = burgers(x_grid, t_grid, 0.1, 1.0)

X = jnp.concatenate([t_grid.reshape(-1, 1), x_grid.reshape(-1, 1)], axis=1)
y = u.reshape(-1, 1)
y += 0.10 * jnp.std(y) * random.normal(key, y.shape)

In [13]:
# %% Building model and params
model = Deepmod([30, 30, 30, 1])
variables = model.init(key, X)

optimizer = optim.Adam(learning_rate=2e-3, beta1=0.99, beta2=0.99)
state, params = variables.pop("params")
optimizer = optimizer.create(params)
update_fn = create_update(loss_fn_bayesian_ridge, (model, X, y))

In [14]:
optimizer, state = train_max_iter(update_fn, optimizer, state, 10000)

Loss step 0: -1985.470947265625
Loss step 1000: -9142.236328125
Loss step 2000: -9149.2734375
Loss step 3000: -9151.642578125
Loss step 4000: -9152.2529296875
Loss step 5000: -9151.6845703125
Loss step 6000: -9152.771484375
Loss step 7000: -9152.865234375
Loss step 8000: -9153.7734375
Loss step 9000: -9154.736328125


In [70]:
# %% Building model and params
model = Deepmod([30, 30, 30, 1])
variables = model.init(key, X)

optimizer = optim.Adam(learning_rate=2e-3, beta1=0.99, beta2=0.99)
state, params = variables.pop("params")
optimizer = optimizer.create(params)
update_fn = create_update(loss_fn_pinn, (model, X, y, 1.0))

In [71]:
optimizer, state = train_max_iter(update_fn, optimizer, state, 10000)

Loss step 0: 0.04526979848742485
Loss step 1000: 0.00039324039244093
Loss step 2000: 0.0003718193038366735
Loss step 3000: 0.0003707118739839643
Loss step 4000: 0.0003705784911289811
Loss step 5000: 0.00037028806400485337
Loss step 6000: 0.00037023628829047084
Loss step 7000: 0.0003697912907227874
Loss step 8000: 0.00036968590575270355
Loss step 9000: 0.00036963008460588753


In [30]:
# %% Making data
key = random.PRNGKey(42)

x = jnp.linspace(-10, 10, 100)
t = jnp.linspace(0.1, 1.0, 10)
t_grid, x_grid = jnp.meshgrid(t, x, indexing="ij")

u = doublesoliton(x_grid, t_grid, c=[5.0, 2.0], x0=[0.0, -5.0])

X = jnp.concatenate([t_grid.reshape(-1, 1), x_grid.reshape(-1, 1)], axis=1)
y = u.reshape(-1, 1)
y += 0.1 * jnp.std(y) * random.normal(key, y.shape)


In [76]:
# %% Building model and params
model = Deepmod([30, 30, 30, 1])
variables = model.init(key, X)

optimizer = optim.Adam(learning_rate=2e-3, beta1=0.99, beta2=0.99)
state, params = variables.pop("params")
optimizer = optimizer.create(params)
update_fn = create_update(loss_fn_pinn, (model, X, y, 1.0))

In [77]:
optimizer, state = train_max_iter(update_fn, optimizer, state, 10000)

Loss step 0: 0.4641355574131012
Loss step 1000: 0.007112619932740927
Loss step 2000: 0.003383260453119874
Loss step 3000: 0.003378728637471795
Loss step 4000: 0.0033790224697440863
Loss step 5000: 0.0033775491174310446
Loss step 6000: 0.0033782690297812223
Loss step 7000: 0.0033780799712985754
Loss step 8000: 0.0033773169852793217
Loss step 9000: 0.0033768159337341785


In [78]:
# %% Building model and params
model = Deepmod([30, 30, 30, 1])
variables = model.init(key, X)

optimizer = optim.Adam(learning_rate=2e-3, beta1=0.99, beta2=0.99)
state, params = variables.pop("params")
optimizer = optimizer.create(params)
update_fn = create_update(loss_fn_bayesian_ridge, (model, X, y))

In [79]:
optimizer, state = train_max_iter(update_fn, optimizer, state, 10000)

Loss step 0: 1499.5491943359375
Loss step 1000: -5780.04150390625
Loss step 2000: -5813.61083984375
Loss step 3000: -5814.09814453125
Loss step 4000: -5814.857421875
Loss step 5000: -5815.029296875
Loss step 6000: -5814.96337890625
Loss step 7000: -5814.95068359375
Loss step 8000: -5814.1875
Loss step 9000: -5814.4736328125


# Warm restart

In [7]:
from dataclasses import dataclass

In [33]:
class loss_fn_bayesian_ridge(params, state, model, X, y, warm_restart=True):
    network_state, loss_state = state
    variables = {"params": params, **network_state}
    (prediction, dt, theta, coeffs), updated_state = model.apply(
        variables, X, mutable=list(state.keys())
    )

    n_samples = theta.shape[0]
    prior_params_mse = (0.0, 0.0)

    # MSE stuff
    tau = precision(y, prediction, *prior_params_mse)
    p_mse, MSE = normal_LL(prediction, y, tau)

    # Regression stuff
    hyper_prior_params = (n_samples/2, n_samples / (2 * jax.lax.stop_gradient(tau))) # we dont want the gradient
    theta_normed = theta / jnp.linalg.norm(theta, axis=0)
    
    if (network_state['prior_init'] is None) or (warm_restart is False):
        prior_init = jnp.stack([1.0, 1.0 / jnp.var(dt)])
    else:
        prior_init = network_state['prior']

    prior, fwd_metric = bayesianregression(theta_normed, dt, prior_params_init=prior_init, hyper_prior_params=hyper_prior_params)
    p_reg, mn = evidence(theta_normed, dt, prior, hyper_prior_params=hyper_prior_params)
    Reg = jnp.mean((dt - theta_normed @ mn)**2)
    
    network_state['prior_init'] = prior
    loss = -(p_mse + p_reg)
    metrics = {"loss": loss, 
               "p_mse": p_mse,
               "mse": MSE, 
               "p_reg": p_reg,
               "reg": Reg, 
               "bayes_coeffs": mn, 
               "coeffs": coeffs, 
               "alpha": prior[:-1], 
               "beta": prior[-1],
               "tau": tau,
              "its": fwd_metric[0],
              "gap": fwd_metric[1]}

    return loss, ((updated_state, network_state), metrics, (prediction, dt, theta, mn))

SyntaxError: 'return' outside function (<ipython-input-33-d5d56a8db427>, line 43)

In [15]:
# %% Building model and params
# with warm restart
model = Deepmod([30, 30, 30, 1])
variables = model.init(key, X)

optimizer = optim.Adam(learning_rate=2e-3, beta1=0.99, beta2=0.99)
state, params = variables.pop("params")
optimizer = optimizer.create(params)
update_fn = create_update(loss_fn_bayesian_ridge(warm_restart=False), (model, X, y))

In [16]:
%%time
optimizer, state = train_max_iter(update_fn, optimizer, state, 10000)

Loss step 0: -1985.470947265625
Loss step 1000: -9141.5234375
Loss step 2000: -9149.9833984375
Loss step 3000: -9150.4365234375
Loss step 4000: -9151.61328125
Loss step 5000: -9152.45703125
Loss step 6000: -9151.58984375
Loss step 7000: -9153.0888671875
Loss step 8000: -9153.326171875
Loss step 9000: -9154.28515625
CPU times: user 59.9 s, sys: 970 ms, total: 1min
Wall time: 59.9 s


In [29]:
# %% Building model and params
# without warm restart
model = Deepmod([30, 30, 30, 1])
variables = model.init(key, X)

optimizer = optim.Adam(learning_rate=2e-3, beta1=0.99, beta2=0.99)
state, params = variables.pop("params")
optimizer = optimizer.create(params)
update_fn = create_update(loss_fn_bayesian_ridge(warm_restart=True), (model, X, y))

In [20]:
%%time
optimizer, state = train_max_iter(update_fn, optimizer, state, 10000)

Loss step 0: -1985.470947265625
Loss step 1000: -9141.5400390625
Loss step 2000: -9149.931640625
Loss step 3000: -9150.6376953125
Loss step 4000: -9151.5419921875
Loss step 5000: -9152.056640625
Loss step 6000: -9151.7451171875
Loss step 7000: -9152.3037109375
Loss step 8000: -9152.853515625
Loss step 9000: -9154.81640625
CPU times: user 1min, sys: 907 ms, total: 1min 1s
Wall time: 1min


In [21]:
state

FrozenDict({
    vars: {
        LeastSquares_0: {
            mask: DeviceArray([ True,  True,  True,  True,  True,  True,  True,  True,
                          True,  True,  True,  True], dtype=bool),
        },
    },
})

In [22]:
from flax.core import freeze, unfreeze

In [23]:
state = freeze({**unfreeze(state), "diff vars": 1})

In [30]:
variables = {"params": params, **state}
(prediction, dt, theta, coeffs), updated_state = model.apply(
    variables, X, mutable=list(state.keys())
)

In [31]:
updated_state

FrozenDict({
    vars: {
        LeastSquares_0: {
            mask: DeviceArray([ True,  True,  True,  True,  True,  True,  True,  True,
                          True,  True,  True,  True], dtype=bool),
        },
    },
})