In [178]:
# %% Imports
import jax
from jax import random, numpy as jnp
from flax import optim
from modax.models import Deepmod
from modax.training import create_update
from modax.losses import loss_fn_pinn
from modax.logging import Logger

from sklearn.linear_model import BayesianRidge
from jax.scipy.stats import gamma
from modax.data.burgers import burgers
from time import time

from functools import partial
from jax import lax

In [180]:
# Making dataset
x = jnp.linspace(-3, 4, 5|0)
t = jnp.linspace(0.5, 5.0, 20)

t_grid, x_grid = jnp.meshgrid(t, x, indexing="ij")
u = burgers(x_grid, t_grid, 0.1, 1.0)

X_train = jnp.concatenate([t_grid.reshape(-1, 1), x_grid.reshape(-1, 1)], axis=1)
y_train = u.reshape(-1, 1)
y_train += 0.1 * jnp.std(y_train) * jax.random.normal(key, y_train.shape)

In [181]:
# Instantiating model and optimizers
model = Deepmod(features=[50, 50, 1])
key = random.PRNGKey(42)
params = model.init(key, X_train)
optimizer = optim.Adam(learning_rate=2e-3, beta1=0.99, beta2=0.99)
optimizer = optimizer.create(params)

# Compiling train step
update = create_update(loss_fn_pinn, model=model, x=X_train, y=y_train)
_ = update(optimizer)  # triggering compilation

In [182]:
# Running to convergence
max_epochs = 20001
logger = Logger()
for epoch in jnp.arange(max_epochs):
    optimizer, metrics = update(optimizer)
    if epoch % 1000 == 0:
        print(f"Loss step {epoch}: {metrics['loss']}")
    if epoch % 100 == 0:
        logger.write(metrics, epoch)
logger.close()

Loss step 0: 0.07853322476148605
Loss step 1000: 0.00017794837185647339
Loss step 2000: 0.0001510621514171362
Loss step 3000: 0.00012943477486260235
Loss step 4000: 0.00012168406101409346
Loss step 5000: 0.00011918714153580368
Loss step 6000: 0.00011606322368606925
Loss step 7000: 0.00011324064689688385
Loss step 8000: 0.00011104914301540703
Loss step 9000: 0.00010823373304447159
Loss step 10000: 0.0001053838204825297
Loss step 11000: 0.00010257066605845466
Loss step 12000: 9.865898755379021e-05
Loss step 13000: 9.564652282278985e-05
Loss step 14000: 9.192790457746014e-05
Loss step 15000: 8.85783665580675e-05
Loss step 16000: 8.627793431514874e-05
Loss step 17000: 8.401690138271078e-05
Loss step 18000: 7.963005919009447e-05
Loss step 19000: 7.57902380428277e-05
Loss step 20000: 7.397824811050668e-05


In [6]:
prediction, dt, theta, coeffs = model.apply(optimizer.target, X_train)

In [167]:
def fwd_solver(f, z_init):
    def cond_fun(carry):
        z_prev, z = carry
        return jnp.linalg.norm(z_prev - z) > 1e-4

    def body_fun(carry):
        _, z = carry
        return z, f(z)

    init_carry = (z_init, f(z_init))
    _, z_star = lax.while_loop(cond_fun, body_fun, init_carry)
    return z_star

In [15]:
@partial(jax.custom_vjp, nondiff_argnums=(0, ))
def fixed_point_layer(f, params, x):
    z_star = fwd_solver(lambda z: f(params, x, z), z_init=jnp.zeros_like(x))
    return z_star

def fixed_point_layer_fwd(f, params, x):
    z_star = fixed_point_layer(f, params, x)
    return z_star, (params, x, z_star)

def fixed_point_layer_bwd(f, res, z_star_bar):
    params, x, z_star = res
    _, vjp_a = jax.vjp(lambda params, x: f(params, x, z_star), params, x)
    _, vjp_z = jax.vjp(lambda z: f(params, x, z), z_star)
    return vjp_a(fwd_solver(lambda u: vjp_z(u)[0] + z_star_bar,
                      z_init=jnp.zeros_like(z_star)))

fixed_point_layer.defvjp(fixed_point_layer_fwd, fixed_point_layer_bwd)

In [21]:
f = lambda W, x, z: jnp.tanh(jnp.dot(W, z) + x)

ndim = 10
W = random.normal(random.PRNGKey(0), (ndim, ndim)) / jnp.sqrt(ndim)
x = random.normal(random.PRNGKey(1), (ndim,))

layer = jax.jit(fixed_point_layer, static_argnums=0)

In [22]:
z_star = layer(f, W, x)
print(z_star)

[ 0.00632886 -0.70152855 -0.9847213  -0.0419194  -0.6151645  -0.48185453
  0.5783277   0.9556748  -0.08354193  0.8447265 ]


In [27]:
g = jax.grad(lambda W: layer(f, W, x).sum())(W)
print(g[0])

[ 0.00733157 -0.81267565 -1.1407362  -0.04856092 -0.7126285  -0.55819744
  0.66995543  1.1070877  -0.09677795  0.9785612 ]


In [25]:
g = jax.grad(lambda x: layer(f, W, x).sum())(x)
print(g)

[1.1584356  0.38825384 0.06453395 1.792349   0.23890097 1.4024842
 0.7339767  0.09179301 0.9173474  0.18508697]


Now lets try this with bayesian ridge. To do so, we need to 
1. params -> dt
2. x -> theta
3. write a function taking in those and giving back alpha and beta

In [43]:
def bayes_ridge_update(y, X, z):
    # Unpacking parameters
    alpha_prev, beta_prev = z
    n_samples, n_terms  = X.shape
    a, b = n_samples/2, 1 /(n_samples/2 * 1e-4) # prior params
    
    # Preparing some matrices
    X_normed = X / jnp.linalg.norm(X, axis=0)
    gram = X_normed.T @ X_normed
    eigvals = jnp.linalg.eigvalsh(gram) 
    
    # Calculating intermediat matrices
    gamma_ = jnp.sum((beta_prev * eigvals) / (alpha_prev + beta_prev * eigvals))
    S = jnp.linalg.inv(beta_prev * gram + alpha_prev * jnp.eye(n_terms))
    mn = beta_prev * S @ X_normed.T @ y
    
    # Update estimate
    alpha = gamma_ / jnp.sum(mn**2)
    beta = (n_samples - gamma_ + 2 * (a - 1)) / (jnp.sum((y - X_normed @ mn)**2) + 2 * b)
    
    return jnp.stack([alpha, beta], axis=0)

In [44]:
bayes_ridge_update(dt, theta, jnp.ones((2, )))

DeviceArray([  0.32389224, 144.18567   ], dtype=float32)

So a single update works. Now to try with the forward pass.

In [45]:
z_init = jnp.stack([1, 1 / jnp.var(dt)], axis=0)

In [48]:
%%timeit
fwd_solver(lambda z: bayes_ridge_update(dt, theta, z), z_init=z_init)

260 ms ± 1.41 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


That works as well... Now to try with the layer.

In [66]:
def bayes_ridge_update(y, X, z):
    # Unpacking parameters
    alpha_prev, beta_prev = z
    n_samples, n_terms  = X.shape
    a, b = n_samples/2, 1 /(n_samples/2 * 1e-4) # prior params
    
    # Preparing some matrices
    X_normed = X / jnp.linalg.norm(X, axis=0)
    gram = X_normed.T @ X_normed
    eigvals = jnp.linalg.eigvalsh(gram) 
    
    # Calculating intermediate matrices
    gamma_ = jnp.sum((beta_prev * eigvals) / (alpha_prev + beta_prev * eigvals))
    S = jnp.linalg.inv(beta_prev * gram + alpha_prev * jnp.eye(n_terms))
    mn = beta_prev * S @ X_normed.T @ y
    
    # Update estimate
    alpha = gamma_ / jnp.sum(mn**2)
    beta = (n_samples - gamma_ + 2 * (a - 1)) / (jnp.sum((y - X_normed @ mn)**2) + 2 * b)
    
    return jnp.stack([alpha, beta], axis=0)

@partial(jax.custom_vjp, nondiff_argnums=(0, ))
def fixed_point_layer(f, params, x):
    z_star = fwd_solver(lambda z: f(params, x, z), z_init=jnp.stack([1, 1 / jnp.var(params)], axis=0))
    return z_star

def fixed_point_layer_fwd(f, params, x):
    z_star = fixed_point_layer(f, params, x)
    return z_star, (params, x, z_star)

def fixed_point_layer_bwd(f, res, z_star_bar):
    params, x, z_star = res
    _, vjp_a = jax.vjp(lambda params, x: f(params, x, z_star), params, x)
    _, vjp_z = jax.vjp(lambda z: f(params, x, z), z_star)
    return vjp_a(fwd_solver(lambda u: vjp_z(u)[0] + z_star_bar,
                      z_init=jnp.zeros_like(z_star)))

fixed_point_layer.defvjp(fixed_point_layer_fwd, fixed_point_layer_bwd)
layer = jax.jit(fixed_point_layer, static_argnums=0)

In [67]:
z_star = layer(bayes_ridge_update, dt, theta)

In [68]:
z_star

DeviceArray([1.6304933e-01, 1.9929942e+02], dtype=float32)

Lets see if we can calculate the derivatives:

In [127]:
jax.vjp(lambda y, X: layer(bayes_ridge_update, y, X), dt, theta)[1](jnp.array([0., 1.]))

(DeviceArray([[ 0.00169743],
              [ 0.00176666],
              [ 0.00163769],
              ...,
              [-0.00159182],
              [-0.00220567],
              [-0.00181732]], dtype=float32),
 DeviceArray([[-1.7207964e-06, -9.6736410e-05, -1.5373771e-04, ...,
               -4.2493947e-04,  8.8823639e-05, -6.7857636e-06],
              [-1.7349913e-06, -9.6280361e-05, -1.6067649e-04, ...,
               -4.2243715e-04,  8.8712732e-05, -6.7922888e-06],
              [-1.7202834e-06, -9.6421179e-05, -1.4779282e-04, ...,
               -4.2329155e-04,  8.8510293e-05, -6.7520523e-06],
              ...,
              [ 2.0118662e-06, -1.4835734e-04,  1.5040957e-04, ...,
               -6.2490714e-04,  1.7513241e-05, -5.6543859e-06],
              [ 1.3946650e-06, -1.4123453e-04,  2.1717543e-04, ...,
               -6.0367957e-04,  3.2208281e-05, -5.5440623e-06],
              [ 7.6851694e-07, -1.3259590e-04,  1.8263116e-04, ...,
               -5.7059061e-04,  4.4317359e-

In [126]:
jax.vjp(lambda y, X: layer(bayes_ridge_update, y, X), dt, theta)[1](jnp.array([1., 0.]))

(DeviceArray([[0.00191845],
              [0.00191423],
              [0.00190841],
              ...,
              [0.00326243],
              [0.00296539],
              [0.00272752]], dtype=float32),
 DeviceArray([[-6.3239258e-07,  7.4459640e-05, -1.9322045e-04, ...,
                3.4397098e-04, -2.8167053e-05,  3.3908968e-06],
              [-6.3568882e-07,  7.4632335e-05, -1.9278172e-04, ...,
                3.4469625e-04, -2.8151417e-05,  3.4021436e-06],
              [-6.1973941e-07,  7.3841387e-05, -1.9226574e-04, ...,
                3.4098790e-04, -2.8048093e-05,  3.3649455e-06],
              ...,
              [-1.8955561e-06,  9.7045580e-05, -3.2402133e-04, ...,
                4.3808299e-04, -7.9861502e-06,  3.2982689e-06],
              [-1.4947502e-06,  8.5814012e-05, -2.9566450e-04, ...,
                3.8835127e-04, -1.1735467e-05,  2.9846397e-06],
              [-1.2548866e-06,  8.0879137e-05, -2.7259887e-04, ...,
                3.6721979e-04, -1.5101070e-05,  2

YES

Now let's try to put it in a nice flax layer.

In [130]:
from flax import linen as nn

In [132]:
class BayesianRidge(nn.Module):
    @nn.compact
    def __call__(self, inputs):
        dt, theta = inputs
        z_star = layer(bayes_ridge_update, dt, theta)
        return z_star

In [133]:
model = BayesianRidge()

In [135]:
params = model.init(key, (dt, theta))

In [136]:
model.apply(params, (dt, theta))

DeviceArray([1.6304933e-01, 1.9929942e+02], dtype=float32)

In [157]:
from typing import Sequence, Callable
from modax.feature_generators import library_backward, library_forward
from modax.layers import LeastSquares, LeastSquaresMT
from modax.networks import MLP, MultiTaskMLP
from flax import linen as nn
from modax.losses import neg_LL


class Deepmod(nn.Module):
    """Simple feed-forward NN.
    """

    features: Sequence[int]  # this is dataclass, so we dont use __init__

    @nn.compact  # this function decorator lazily intializes the model, so it makes the layers the first time we call it
    def __call__(self, inputs):
        prediction, dt, theta = library_backward(MLP(self.features), inputs)
        z = BayesianRidge()((dt, theta))
        return prediction, dt, theta, z

In [138]:
model = Deepmod([30, 30, 1])

In [140]:
params = model.init(key, X_train)

In [143]:
model.apply(params, X_train)[-1]

DeviceArray([1.2429158e-02, 8.7499802e+01], dtype=float32)

In [158]:
def evidence(y, X, z):
    n_samples, n_terms  = X.shape
    alpha, beta = z
    a, b = n_samples/2, 1 /(n_samples/2 * 1e-4) # prior params
   
    A = alpha * jnp.eye(n_terms) + beta * X.T @ X 
    mn = beta * jnp.linalg.inv(A) @ X.T @ y
    
    E = beta / 2 * jnp.sum((y - X @ mn)**2) + alpha / 2 * jnp.sum(mn**2)
    loss = n_terms / 2 * jnp.log(alpha) + n_samples / 2 * jnp.log(beta) - E - 1/2 * jnp.linalg.slogdet(A)[1] - n_samples / 2 * jnp.log(2 * jnp.pi)
    loss += jnp.sum(gamma.logpdf(beta, a=a, scale=b))
    return -loss, mn

In [159]:
z = model.apply(params, X_train)[-1]
evidence(dt, theta, z)

(DeviceArray(1198.8901, dtype=float32),
 DeviceArray([[ 1.4184997e-04],
              [-1.5999153e-03],
              [ 9.9819690e-02],
              [ 1.7427839e-05],
              [ 7.0923567e-04],
              [-9.8933864e-01],
              [ 1.9797683e-04],
              [-1.6318355e-04],
              [-3.6606789e-03],
              [-7.3438287e-03],
              [-1.8105358e-03],
              [ 4.2999390e-04]], dtype=float32))

In [162]:
def loss_fn_pinn_bayes_regression(params, model, x, y):
    """ first argument should always be params!
    """
    prediction, dt, theta, z = model.apply(params, x)

    # MSE
    sigma_ml = jnp.mean((prediction - y) ** 2)
    tau = 1 / sigma_ml
    MSE = neg_LL(prediction, y, tau)
    
    # Reg
    Reg, mn = evidence(dt, theta, z)
    loss = MSE + Reg
    metrics = {
        "loss": loss,
        "mse": MSE,
        "reg": Reg,
        "coeff": mn,
        "tau": tau, 
        "beta": z[1],
        "alpha": z[0]
    }
    return loss, metrics

In [164]:
loss_fn_pinn_bayes_regression(params, model, X_train, y_train)

(DeviceArray(3920.4146, dtype=float32),
 {'loss': DeviceArray(3920.4146, dtype=float32),
  'mse': DeviceArray(1625.2769, dtype=float32),
  'reg': DeviceArray(2295.1377, dtype=float32),
  'coeff': DeviceArray([[-0.08180069],
               [-2.2042918 ],
               [-0.39007962],
               [-0.19036376],
               [-0.4221705 ],
               [-6.298418  ],
               [ 2.7376833 ],
               [-0.07010293],
               [-0.8428339 ],
               [-0.74586403],
               [ 3.294509  ],
               [ 1.1794865 ]], dtype=float32),
  'tau': DeviceArray(3.362216, dtype=float32),
  'beta': DeviceArray(87.4998, dtype=float32),
  'alpha': DeviceArray(0.01242916, dtype=float32)})

Okay so that works... maybe now let's try to do a full run

In [172]:
# Making dataset
x = jnp.linspace(-3, 4, 50)
t = jnp.linspace(0.5, 5.0, 20)

t_grid, x_grid = jnp.meshgrid(t, x, indexing="ij")
u = burgers(x_grid, t_grid, 0.1, 1.0)

X_train = jnp.concatenate([t_grid.reshape(-1, 1), x_grid.reshape(-1, 1)], axis=1)
y_train = u.reshape(-1, 1)
y_train += 0.1 * jnp.std(y_train) * jax.random.normal(key, y_train.shape)

In [173]:
# Instantiating model and optimizers
model = Deepmod(features=[50, 50, 1])
key = random.PRNGKey(42)
params = model.init(key, X_train)
optimizer = optim.Adam(learning_rate=2e-3, beta1=0.99, beta2=0.99)
optimizer = optimizer.create(params)

# Compiling train step
update = create_update(loss_fn_pinn_bayes_regression, model=model, x=X_train, y=y_train)
_ = update(optimizer)  # triggering compilation

In [174]:
# Running to convergence
max_epochs = 20001
logger = Logger()
for epoch in jnp.arange(max_epochs):
    optimizer, metrics = update(optimizer)
    if epoch % 1000 == 0:
        print(f"Loss step {epoch}: {metrics['loss']}")
    if epoch % 100 == 0:
        logger.write(metrics, epoch)
logger.close()

Loss step 0: 2028.1748046875
Loss step 1000: -1361.991943359375
Loss step 2000: -1371.446533203125
Loss step 3000: -1373.916748046875
Loss step 4000: -1377.374755859375
Loss step 5000: -1380.127685546875
Loss step 6000: -1383.46728515625
Loss step 7000: -1385.64208984375
Loss step 8000: -1388.329833984375
Loss step 9000: -1391.118896484375
Loss step 10000: -1393.1318359375
Loss step 11000: -1395.61376953125
Loss step 12000: -1397.76318359375
Loss step 13000: -1400.376220703125
Loss step 14000: -1403.441162109375
Loss step 15000: -1405.583984375
Loss step 16000: -1407.463134765625
Loss step 17000: -1409.633544921875
Loss step 18000: -1411.95361328125
Loss step 19000: -1413.703857421875
Loss step 20000: -1414.93994140625
