In [1]:
# %% Imports
import jax
from jax import random, numpy as jnp
from functools import partial
from jax import lax
import matplotlib.pyplot as plt
import seaborn as sns
sns.set()

from flax import linen as nn
from flax.core import freeze
from typing import Tuple, Sequence

from modax.feature_generators import library_backward
from modax.networks import MLP, MultiTaskMLP
from modax.data.burgers import burgers
from modax.logging import Logger
from flax import optim

from code import fwd_solver, bayes_ridge_update, fixed_point_solver, evidence
from sklearn.linear_model import BayesianRidge
from jax.test_util import check_grads


from modax.losses import neg_LL

%load_ext autoreload
%autoreload 2

In [13]:
# NOTE(review): allow_pickle=True unpickles the file's contents, so only load
# test_data.npy from a trusted source.
data = jnp.load('test_data.npy', allow_pickle=True).item()
# The stored object is indexed by the keys 'y' and 'X'.
y, X = data['y'], data['X']

In [8]:
key = random.PRNGKey(42)

In [5]:
%env XLA_FLAGS="--xla_force_host_platform_device_count=4"

env: XLA_FLAGS="--xla_force_host_platform_device_count=4"


In [21]:
class BayesianRegression(nn.Module):
    """Flax module that drives `bayes_ridge_update` to a fixed point.

    The update rule and the solver are imported helpers; this class only wires
    them together and returns whatever `fixed_point_solver` returns.
    """
    hyper_prior: Tuple  # forwarded unchanged as the last argument of bayes_ridge_update
    tol: float = 1e-3  # handed to fixed_point_solver as its `tol` keyword

    @nn.compact
    def __call__(self, inputs):
        # Starting point: 1 for the first entry, inverse variance of inputs[0] for the second.
        start = jnp.stack([1, 1 / jnp.var(inputs[0])], axis=0)

        def update_fn(prior, y, X):
            # Close over the hyper-prior so the solver sees a 3-argument function.
            return bayes_ridge_update(prior, y, X, self.hyper_prior)

        return fixed_point_solver(update_fn, inputs, start, tol=self.tol)

In [14]:
n_samples, n_terms = X.shape
hyper_prior = jnp.stack([n_samples / 2, 1 / (n_samples / 2 * 1e-4)], axis=0)
# The (y, X) pair is what model.init / model.apply receive in the cells below;
# it must be defined here for the notebook to run top-to-bottom on a fresh kernel.
inputs = (y, X)

In [23]:
# Instantiate the module with the hyper-prior (default tol) and initialise it on `inputs`.
model = BayesianRegression(hyper_prior)
params = model.init(key, inputs)

In [24]:
model.apply(params, inputs)

DeviceArray([ 0.36742398, 49.686714  ], dtype=float32)

So it works! Now let's move the construction of the update function to a separate initialization step (`setup`).

In [146]:
class BayesianRegression(nn.Module):
    """Fixed-point Bayesian regression with the update function built in `setup`.

    `setup` binds the hyper-prior into `self.update`; `__call__` passes that
    function, the inputs and an initial guess to `fixed_point_solver`.
    """
    hyper_prior: Tuple  # forwarded unchanged as the last argument of bayes_ridge_update
    tol: float = 1e-3  # handed to fixed_point_solver as its `tol` keyword

    def setup(self):
        # Built once here so __call__ does not have to re-create the closure.
        self.update = lambda prior, y, X: bayes_ridge_update(
            prior, y, X, self.hyper_prior
        )

    @nn.compact
    def __call__(self, inputs):
        targets = inputs[0]
        # Initial guess: 1 and the inverse variance of the first input.
        initial_guess = jnp.stack([1, 1 / jnp.var(targets)], axis=0)
        return fixed_point_solver(self.update, inputs, initial_guess, tol=self.tol)

In [29]:
# Re-create the model from the class defined in the previous cell and initialise it again.
model = BayesianRegression(hyper_prior)
params = model.init(key, inputs)

In [31]:
%%time
model.apply(params, inputs)

CPU times: user 354 ms, sys: 6.35 ms, total: 360 ms
Wall time: 357 ms


DeviceArray([ 0.36742398, 49.686714  ], dtype=float32)

In [32]:
%%time
model.apply(params, inputs)

CPU times: user 370 ms, sys: 7.26 ms, total: 377 ms
Wall time: 375 ms


DeviceArray([ 0.36742398, 49.686714  ], dtype=float32)

Now let's try to work with Flax variables:

In [7]:
class BayesianRegression(nn.Module):
    """Fixed-point Bayesian regression that keeps its solution in a Flax variable.

    The variable 'z' in the 'bayes' collection supplies the solver's starting
    point. Once the variable already exists, the converged value is written
    back, so a later call starts from the previous solution when 'bayes' is
    passed as mutable.
    """
    hyper_prior: Tuple  # forwarded unchanged as the last argument of bayes_ridge_update
    tol: float = 1e-3  # handed to fixed_point_solver as its `tol` keyword

    def setup(self):
        self.update = lambda prior, y, X: bayes_ridge_update(
            prior, y, X, self.hyper_prior
        )

    @nn.compact
    def __call__(self, inputs):
        # Query this before self.variable(...) below, which creates the entry if missing.
        already_stored = self.has_variable('bayes', 'z')

        def default_guess(y):
            # Used only when the variable does not exist yet.
            return jnp.stack([1, 1 / jnp.var(y)], axis=0)

        z_state = self.variable('bayes', 'z', default_guess, inputs[0])

        z_star = fixed_point_solver(self.update, inputs, z_state.value, tol=self.tol)
        # Skip the write on the pass that created the variable, keeping the initial guess.
        if already_stored:
            z_state.value = z_star
        return z_star

In [111]:
# Use a tighter tolerance than the class default (1e-3) and initialise the 'bayes' variable.
model = BayesianRegression(hyper_prior, tol=1e-5)
variables = model.init(key, inputs)

In [112]:
print(variables)

FrozenDict({
    bayes: {
        z: DeviceArray([ 1.    , 30.3171], dtype=float32),
    },
})


In [113]:
%%time
# Bind the solver output to its own name so the regression target `y` is not overwritten.
z_star, updated_state = model.apply(variables, inputs, mutable=['bayes'])

CPU times: user 373 ms, sys: 6.68 ms, total: 379 ms
Wall time: 388 ms


In [114]:
updated_state

FrozenDict({
    bayes: {
        z: DeviceArray([ 0.3674249, 49.686718 ], dtype=float32),
    },
})

In [115]:
# Merge the freshly updated mutable collections back into `variables`.
# An explicit membership test replaces the former bare `except:`, which would
# also have hidden unrelated errors (typos, interrupts, shape problems).
if 'params' in variables:
    # Keep the trainable parameters and replace everything else with the new state.
    old_state, params = variables.pop('params')
    variables = freeze({'params': params, **updated_state})
else:
    # This model has no 'params' collection, so the updated state is all there is.
    variables = freeze(updated_state)

In [116]:
print(variables)

FrozenDict({
    bayes: {
        z: DeviceArray([ 0.3674249, 49.686718 ], dtype=float32),
    },
})


In [118]:
%%time
# Bind the solver output to its own name so the regression target `y` is not overwritten.
z_star, updated_state = model.apply(variables, inputs, mutable=['bayes'])

CPU times: user 343 ms, sys: 5.91 ms, total: 348 ms
Wall time: 345 ms


Now that it works, let's put it in a model:

In [6]:
class Deepmod(nn.Module):
    """Combines an MLP, the `library_backward` feature generator and BayesianRegression.

    Returns the tuple (prediction, dt, theta, z), where the first three come
    from `library_backward` and `z` is the output of the regression module
    applied to (dt, theta).
    """
    features: Sequence[int]  # passed to MLP
    hyper_prior: Tuple  # passed to BayesianRegression
    tol: float = 1e-3  # passed to BayesianRegression

    @nn.compact
    def __call__(self, inputs):
        network = MLP(self.features)
        prediction, dt, theta = library_backward(network, inputs)

        regressor = BayesianRegression(self.hyper_prior, self.tol)
        z = regressor((dt, theta))
        return prediction, dt, theta, z

In [16]:
# Making dataset: evaluate `burgers` on a (t, x) grid and add Gaussian noise.
x = jnp.linspace(-3, 4, 50)
t = jnp.linspace(0.5, 5.0, 20)
t_grid, x_grid = jnp.meshgrid(t, x, indexing="ij")

u = burgers(x_grid, t_grid, 0.1, 1.0)

# One row per grid point: column 0 is t, column 1 is x.
X_train = jnp.stack([t_grid.ravel(), x_grid.ravel()], axis=1)

# Noise amplitude is 10% of the standard deviation of the clean signal.
y_train = u.reshape(-1, 1)
noise_scale = 0.1 * jnp.std(y_train)
y_train = y_train + noise_scale * jax.random.normal(key, y_train.shape)

In [18]:
# Build the full model (MLP features [50, 50, 1], default tol) and initialise it on the training inputs.
model = Deepmod([50, 50, 1], hyper_prior)
variables = model.init(key, X_train)

In [11]:
model.apply(variables, X_train, mutable='bayes');

NameError: name 'variables' is not defined

In [12]:
# NOTE(review): this wraps the whole `variables` collection in the optimizer (the next cell
# reads optimizer.target['bayes']), whereas the later training setup pops 'params' first and
# optimises only those. Confirm this cell is exploratory.
optimizer = optim.Adam(learning_rate=2e-3, beta1=0.99, beta2=0.99)
optimizer = optimizer.create(variables)

In [13]:
optimizer.target['bayes']

FrozenDict({
    BayesianRegression_0: {
        z: DeviceArray([ 1.    , 44.1788], dtype=float32),
    },
})

In [2]:
def loss_fn_pinn_bayes_regression(params, state, model, x, y):
    """Loss combining a data-fit term with the evidence of the Bayesian regression.

    `params` must stay the first argument: `create_update` differentiates with
    respect to argument 0.

    Args:
        params: trainable parameters, placed under the 'params' key.
        state: mapping of the remaining variable collections; all of its keys
            are marked mutable in the `model.apply` call.
        model: module whose apply returns (prediction, dt, theta, z) and which
            exposes a `hyper_prior` attribute.
        x: model input.
        y: target compared against the prediction.

    Returns:
        (loss, (updated_state, metrics)), where loss = MSE - Reg and metrics is
        a dict with keys loss, mse, reg, coeff, tau, beta and alpha.
    """
    mutable_collections = list(state.keys())
    outputs, updated_state = model.apply(
        {'params': params, **state}, x, mutable=mutable_collections
    )
    prediction, dt, theta, z = outputs

    # Data-fit term: tau is the inverse of the mean squared residual.
    # Despite its name, `MSE` holds whatever neg_LL returns for that tau.
    sigma_ml = jnp.mean((prediction - y) ** 2)
    tau = 1 / sigma_ml
    MSE = neg_LL(prediction, y, tau)

    # Regression term: evidence evaluated on the column-normalised library.
    # NOTE(review): z comes from the model's own call on theta; if the model is
    # Deepmod, that call uses the unnormalised theta. Confirm this is intended.
    column_norms = jnp.linalg.norm(theta, axis=0, keepdims=True)
    Reg, mn = evidence(z, dt, theta / column_norms, model.hyper_prior)

    loss = MSE - Reg
    metrics = {
        "loss": loss,
        "mse": MSE,
        "reg": Reg,
        "coeff": mn,
        "tau": tau,
        "beta": z[1],
        "alpha": z[0],
    }
    return loss, (updated_state, metrics)

In [3]:
from jax import value_and_grad, jit


def create_update(loss_fn, *args, **kwargs):
    """Build a jitted training step around `loss_fn`.

    Args:
        loss_fn: callable invoked as loss_fn(opt.target, state, *args, **kwargs)
            and returning (loss, (updated_state, metrics)).
        *args, **kwargs: extra arguments bound into every call of `loss_fn`.

    Returns:
        A jitted function (opt, state) -> ((new_opt, updated_state), metrics).
        `opt` must expose `.target` and `.apply_gradient(grad)`.
    """
    # Differentiate w.r.t. the first argument only; the rest of the output is auxiliary.
    loss_and_grad = value_and_grad(loss_fn, argnums=0, has_aux=True)

    def step(opt, state):
        (_, (new_state, metrics)), grads = loss_and_grad(
            opt.target, state, *args, **kwargs
        )
        # apply_gradient returns the updated optimizer with parameters.
        new_opt = opt.apply_gradient(grads)
        return (new_opt, new_state), metrics

    return jit(step)

In [19]:
# Making dataset: evaluate `burgers` on a (t, x) grid and add Gaussian noise.
x = jnp.linspace(-3, 4, 50)
t = jnp.linspace(0.5, 5.0, 20)
t_grid, x_grid = jnp.meshgrid(t, x, indexing="ij")

u = burgers(x_grid, t_grid, 0.1, 1.0)

# One row per grid point: column 0 is t, column 1 is x.
X_train = jnp.stack([t_grid.ravel(), x_grid.ravel()], axis=1)

# Noise amplitude is 1% of the standard deviation of the clean signal.
y_train = u.reshape(-1, 1)
noise_scale = 0.01 * jnp.std(y_train)
y_train = y_train + noise_scale * jax.random.normal(key, y_train.shape)

In [20]:
# Build the model with a tighter solver tolerance than the default (1e-3) and initialise it.
model = Deepmod([50, 50, 1], hyper_prior, tol=1e-4)
variables = model.init(key, X_train)

# Split the variables: 'params' goes to the optimizer, the remainder is carried as `state`.
optimizer = optim.Adam(learning_rate=2e-3, beta1=0.99, beta2=0.99)
state, params = variables.pop('params')
optimizer = optimizer.create(params)

In [21]:
# Compiling train step: bind the loss function, model and data into a jitted update.
update = create_update(loss_fn_pinn_bayes_regression, model=model, x=X_train, y=y_train)
# The result of this first call is deliberately discarded.
_ = update(optimizer, state)  # triggering compilation

In [22]:
# Running to convergence
max_epochs = 10000
log_every = 100  # print and log once every this many epochs
logger = Logger()
# Use a plain Python range: iterating a jnp array yields a device scalar per step,
# so the modulo test below would go through JAX on every iteration.
for epoch in range(max_epochs):
    (optimizer, state), metrics = update(optimizer, state)
    if epoch % log_every == 0:
        print(f"Loss step {epoch}: {metrics['loss']}")
        logger.write(metrics, epoch)
logger.close()

Loss step 0: -1360.2958984375
Loss step 100: -3297.87841796875
Loss step 200: -3842.685791015625
Loss step 300: -4392.36376953125
Loss step 400: -5141.0751953125
Loss step 500: -5726.458984375
Loss step 600: -5991.244140625
Loss step 700: -6167.302734375
Loss step 800: -6316.0205078125
Loss step 900: -6411.0087890625
Loss step 1000: -6495.7685546875
Loss step 1100: -6547.9140625
Loss step 1200: -6583.92626953125
Loss step 1300: -6613.85498046875
Loss step 1400: -6658.244140625
Loss step 1500: -6674.17431640625
Loss step 1600: -6695.69921875
Loss step 1700: -6671.9560546875
Loss step 1800: -6719.74609375
Loss step 1900: -6761.48388671875
Loss step 2000: -6744.74560546875
Loss step 2100: -6731.6728515625
Loss step 2200: -6751.22119140625
Loss step 2300: -6743.8203125
Loss step 2400: -6761.58203125
Loss step 2500: -6753.8935546875
Loss step 2600: -6755.419921875
Loss step 2700: -6765.4755859375
Loss step 2800: -6759.00048828125
Loss step 2900: -6764.4111328125
Loss step 3000: -6771.03125


In [18]:
metrics

{'alpha': DeviceArray(0.5818718, dtype=float32),
 'beta': DeviceArray(48.99622, dtype=float32),
 'coeff': DeviceArray([[ 0.15310402],
              [-1.4754432 ],
              [ 2.3529584 ],
              [ 0.71037406],
              [ 0.4666279 ],
              [-2.0328088 ],
              [ 0.28171843],
              [ 0.03225338],
              [-1.1648126 ],
              [-1.2099943 ],
              [-0.49702704],
              [-0.5595163 ]], dtype=float32),
 'loss': DeviceArray(-4549.8765, dtype=float32),
 'mse': DeviceArray(-2594.754, dtype=float32),
 'reg': DeviceArray(1955.1224, dtype=float32),
 'tau': DeviceArray(3063.719, dtype=float32)}

In [25]:
# Running to convergence (continues from the current optimizer and state)
max_epochs = 1000
log_every = 100  # print and log once every this many epochs
logger = Logger()
# Use a plain Python range: iterating a jnp array yields a device scalar per step,
# so the modulo test below would go through JAX on every iteration.
for epoch in range(max_epochs):
    (optimizer, state), metrics = update(optimizer, state)
    if epoch % log_every == 0:
        print(f"Loss step {epoch}: {metrics['loss']}")
        logger.write(metrics, epoch)
logger.close()

Loss step 0: -4480.9990234375
Loss step 100: -4482.5615234375
Loss step 200: -4484.4423828125
Loss step 300: -4485.9072265625
Loss step 400: -4486.9345703125
Loss step 500: -4488.0380859375
Loss step 600: -4488.5107421875
Loss step 700: -4489.59375
Loss step 800: -4489.91943359375
Loss step 900: -4490.802734375


In [26]:
metrics

{'alpha': DeviceArray(0.4882597, dtype=float32),
 'beta': DeviceArray(49.442085, dtype=float32),
 'coeff': DeviceArray([[ 1.4626276e-03],
              [-4.7906667e-02],
              [ 9.0431087e-02],
              [ 7.0760306e-04],
              [ 5.4294147e-02],
              [-8.2236814e-01],
              [-2.4918601e-02],
              [ 1.8872246e-03],
              [-1.9642621e-01],
              [-2.4825841e-02],
              [ 8.3470047e-03],
              [-3.2375995e-03]], dtype=float32),
 'loss': DeviceArray(-4491.4624, dtype=float32),
 'mse': DeviceArray(-2556.8352, dtype=float32),
 'reg': DeviceArray(1934.6272, dtype=float32),
 'tau': DeviceArray(2839.967, dtype=float32)}

In [27]:
# Running to convergence (continues from the current optimizer and state)
max_epochs = 2000
log_every = 100  # print and log once every this many epochs
logger = Logger()
# Use a plain Python range: iterating a jnp array yields a device scalar per step,
# so the modulo test below would go through JAX on every iteration.
for epoch in range(max_epochs):
    (optimizer, state), metrics = update(optimizer, state)
    if epoch % log_every == 0:
        print(f"Loss step {epoch}: {metrics['loss']}")
        logger.write(metrics, epoch)
logger.close()

Loss step 0: -4491.01123046875
Loss step 100: -4491.70458984375
Loss step 200: -4492.50927734375
Loss step 300: -4492.556640625
Loss step 400: -4492.7998046875
Loss step 500: -4493.5419921875
Loss step 600: -4493.87353515625
Loss step 700: -4494.416015625
Loss step 800: -4495.08251953125
Loss step 900: -4495.1376953125
Loss step 1000: -4495.68017578125
Loss step 1100: -4495.96630859375
Loss step 1200: -4496.14990234375
Loss step 1300: -4496.8369140625
Loss step 1400: -4496.9521484375
Loss step 1500: -4497.6611328125
Loss step 1600: -4498.01611328125
Loss step 1700: -4498.4619140625
Loss step 1800: -4499.03564453125
Loss step 1900: -4499.2626953125


In [28]:
metrics

{'alpha': DeviceArray(0.5149161, dtype=float32),
 'beta': DeviceArray(49.31387, dtype=float32),
 'coeff': DeviceArray([[ 0.00158105],
              [-0.05517259],
              [ 0.09083098],
              [ 0.00277134],
              [ 0.04931932],
              [-0.7894591 ],
              [-0.03046691],
              [-0.00440558],
              [-0.18170032],
              [-0.06253958],
              [ 0.01805699],
              [ 0.00101865]], dtype=float32),
 'loss': DeviceArray(-4499.5713, dtype=float32),
 'mse': DeviceArray(-2567.4487, dtype=float32),
 'reg': DeviceArray(1932.1224, dtype=float32),
 'tau': DeviceArray(2900.8936, dtype=float32)}

In [23]:
def SBL_update(prior_params, y, X, hyper_prior_params):
    """Perform one update of per-term precisions `alpha` and noise precision `beta`.

    Presumably a sparse-Bayesian-learning step, going by the function name.

    Args:
        prior_params: 1-D array; every entry but the last is the previous alpha,
            the last entry is the previous beta.
        y: targets, used in X.T @ y and y - X @ mu.
        X: 2-D design matrix (n_samples, n_terms).
        hyper_prior_params: pair (a, b) entering the beta update.

    Returns:
        (new_params, mu, gamma): new_params is alpha with beta appended; mu and
        gamma are the intermediate quantities used to compute them.
    """
    a, b = hyper_prior_params
    alpha_prev = prior_params[:-1]
    beta_prev = prior_params[-1]
    n_samples, _ = X.shape

    # Intermediate quantities from the previous alpha and beta.
    precision = beta_prev * X.T @ X + jnp.diag(alpha_prev)
    Sigma = jnp.linalg.inv(precision)
    mu = beta_prev * Sigma @ X.T @ y
    gamma = 1 - alpha_prev * jnp.diag(Sigma)

    # New alpha is clipped from above so a vanishing mu cannot produce inf.
    cap = 1e6
    alpha = jnp.minimum(gamma / jnp.squeeze(mu**2), cap)

    residual_ss = jnp.sum((y - X @ mu) ** 2)
    beta = (n_samples - jnp.sum(gamma) + 2 * a) / (residual_ss + 2 * b)

    new_params = jnp.concatenate([alpha, beta[None]], axis=0)
    return new_params, mu, gamma