In [1]:
# %% Imports
import jax
from jax import random, numpy as jnp
from functools import partial
from jax import lax
import matplotlib.pyplot as plt
import seaborn as sns
sns.set()

from flax import linen as nn
%load_ext autoreload
%autoreload 2

In [2]:
from modax.data.burgers import burgers
from modax.feature_generators import library_backward
from modax.networks import MLP
from flax import optim
from modax.losses import neg_LL, mse
from modax.logging import Logger
from typing import Sequence

from jax import jit, value_and_grad

In [3]:
def create_update(loss_fn, *args, **kwargs):
    def step(opt, state, loss_fn, *args, **kwargs):
        grad_fn = value_and_grad(loss_fn, argnums=0, has_aux=True)
        (loss, (updated_state, metrics)), grad = grad_fn(
            opt.target, state, *args, **kwargs
        )
        opt = opt.apply_gradient(grad)  # Return the updated optimizer with parameters.

        return (opt, updated_state), metrics

    return jit(lambda opt, state: step(opt, state, loss_fn, *args, **kwargs))

In [4]:
def loss_fn_pinn(params, state, model, x, y):
    variables = {'params': params, **state}
    (prediction, dt, theta, coeffs), updated_state = model.apply(variables, x, mutable=list(state.keys()))

    MSE = mse(prediction, y)
    Reg = mse(dt.squeeze(), (theta @ coeffs).squeeze())
    loss = MSE + Reg
    metrics = {"loss": loss, "mse": MSE, "reg": Reg, "coeff": coeffs}

    return loss, (updated_state, metrics)

# Implementing masked least squares:

In [5]:
data = jnp.load('../test_data.npy', allow_pickle=True).item()
y, X = data['y'], data['X']

X_normed = X / jnp.linalg.norm(X, axis=0, keepdims=True)

In [6]:
mask = jnp.zeros((X.shape[1]), dtype=bool)
mask = jax.ops.index_update(mask, jnp.array([2, 5]), True)

In [7]:
X_masked = X * (~mask * 1e-6 + mask)

In [8]:
jnp.linalg.lstsq(X, y)[0]

DeviceArray([[ 8.7450782e-04],
             [-4.1661084e-02],
             [ 9.8933630e-02],
             [ 1.6169542e-03],
             [-1.4724124e-02],
             [-7.6219821e-01],
             [ 9.2846295e-03],
             [-2.8839568e-03],
             [ 2.4262253e-02],
             [-2.6267979e-01],
             [-8.8574179e-03],
             [-2.2183033e-04]], dtype=float32)

In [9]:
jnp.linalg.lstsq(X_masked, y)[0]

DeviceArray([[-2.6406812e-09],
             [-1.5912620e-06],
             [ 9.5455840e-02],
             [ 3.5970850e-05],
             [-1.3432731e-07],
             [-9.9294835e-01],
             [-6.8076105e-07],
             [ 2.7888411e-05],
             [-1.2501999e-07],
             [-7.0616380e-07],
             [-7.7968679e-07],
             [ 2.3472699e-05]], dtype=float32)

Okay so the idea works, great.

In [5]:
class LeastSquares(nn.Module):
    @nn.compact
    def __call__(self, inputs):
        y, X = inputs
        mask = self.variable(
            "vars",
            "mask",
            lambda n_terms: jnp.ones((n_terms, ), dtype=bool), X.shape[1])
        
        X_masked = X * (~mask.value * 1e-6 + mask.value)
        coeffs =  jnp.linalg.lstsq(X_masked, y)[0]

        return coeffs * mask.value[:, None] # extra multiplication to compensate numerical errors

# Running

In [14]:
key = random.PRNGKey(42)
key_data, key_network = random.split(key)

# Making dataset
x = jnp.linspace(-3, 4, 100)
t = jnp.linspace(0.5, 5.0, 20)

t_grid, x_grid = jnp.meshgrid(t, x, indexing="ij")
u = burgers(x_grid, t_grid, 0.1, 1.0)

X_train = jnp.concatenate([t_grid.reshape(-1, 1), x_grid.reshape(-1, 1)], axis=1)
y_train = u.reshape(-1, 1)
y_train += 0.01 * jnp.std(y_train) * jax.random.normal(key_data, y_train.shape)

In [6]:
class Deepmod(nn.Module):
    features: Sequence[int]

    @nn.compact
    def __call__(self, inputs):
        prediction, dt, theta = library_backward(MLP(self.features), inputs)
        coeffs = LeastSquares()((dt, theta))
        return prediction, dt, theta, coeffs

In [13]:
model = Deepmod([50, 50, 1])
variables = model.init(key_network, X_train)

In [14]:
optimizer = optim.Adam(learning_rate=2e-3, beta1=0.99, beta2=0.99)
state, params = variables.pop('params')
optimizer = optimizer.create(params)

In [15]:
print(state)

FrozenDict({
    vars: {
        LeastSquares_0: {
            mask: DeviceArray([ True,  True,  True,  True,  True,  True,  True,  True,
                          True,  True,  True,  True], dtype=bool),
        },
    },
})


In [16]:
# Compiling train step
update = create_update(loss_fn_pinn, model=model, x=X_train, y=y_train)
_ = update(optimizer, state)  # triggering compilation

In [8]:
from sklearn.linear_model import LassoCV
import numpy as np
from flax.core import freeze

In [7]:
def update_mask(X, y, reg, threshold=0.1):
    X_normed = X / jnp.linalg.norm(X, axis=0, keepdims=True)
    y_normed = y / jnp.linalg.norm(y, axis=0, keepdims=True)
    coeffs = reg.fit(np.array(X_normed), np.array(y_normed).squeeze()).coef_
    mask = np.abs(coeffs) > threshold 
    return mask

In [19]:
# Running to convergence
max_epochs = 10000
logger = Logger(comment='baseline')
for epoch in jnp.arange(max_epochs):
    (optimizer, state), metrics = update(optimizer, state)
    if epoch % 1000 == 0:
        print(f"Loss step {epoch}: {metrics['loss']}")
    if epoch % 25 == 0:
        logger.write(metrics, epoch)
logger.close()

Loss step 0: 0.19655461609363556
Loss step 1000: 0.0001199067773995921
Loss step 2000: 3.684171133500058e-06
Loss step 3000: 2.1924374777881894e-06
Loss step 4000: 1.9715557755262125e-06
Loss step 5000: 1.9775927739829058e-06
Loss step 6000: 1.885058736661449e-06
Loss step 7000: 1.8930719534182572e-06
Loss step 8000: 1.9562123725336278e-06
Loss step 9000: 1.916803057611105e-06


In [20]:
dt, theta, coeffs = model.apply({"params": optimizer.target, **state}, X_train, mutable=list(state.keys()))[0][1:]

In [21]:
jnp.linalg.lstsq(theta, dt)[0]

DeviceArray([[-1.6329670e-04],
             [-2.3596138e-03],
             [ 1.0054986e-01],
             [-1.7797307e-04],
             [ 4.3829046e-03],
             [-9.8531437e-01],
             [-7.2126538e-03],
             [ 8.3148736e-04],
             [-1.5166596e-02],
             [-3.0938089e-03],
             [ 3.8494654e-03],
             [-2.9145367e-04]], dtype=float32)

In [22]:
reg = LassoCV(fit_intercept=False)
update_mask(theta, dt, reg, threshold=0.1)

array([False, False,  True, False, False,  True, False, False, False,
       False, False, False])

So perfect this works. Now to add in the update:

In [28]:
model = Deepmod([50, 50, 1])
variables = model.init(key_network, X_train)

In [29]:
optimizer = optim.Adam(learning_rate=2e-3, beta1=0.99, beta2=0.99)
state, params = variables.pop('params')
optimizer = optimizer.create(params)

In [31]:
# Running to convergence
max_epochs = 10000
reg = LassoCV(fit_intercept=False)
logger = Logger(comment='updating_mask')
for epoch in jnp.arange(max_epochs):
    (optimizer, state), metrics = update(optimizer, state)
    if epoch % 1000 == 0:
        print(f"Loss step {epoch}: {metrics['loss']}")
    if (epoch % 100 == 0) and (epoch > 2000):
        dt, theta = model.apply({"params": optimizer.target, **state}, X_train, mutable=list(state.keys()))[0][1:3]
        mask = update_mask(theta, dt, reg)
        print(mask)
        state = freeze({'vars': {'LeastSquares_0': {'mask': mask}}})
    if epoch % 25 == 0:
        logger.write(metrics, epoch)
logger.close()

Loss step 0: 0.19655461609363556
Loss step 1000: 0.0001199067773995921
Loss step 2000: 3.684171133500058e-06
[False False  True False False  True False False False False False False]
[False False  True False False  True False False False False False False]
[False False  True False False  True False False False False False False]
[False False  True False False  True False False False False False False]
[False False  True False False  True False False False False False False]
[False False  True False False  True False False False False False False]
[False False  True False False  True False False False False False False]
[False False  True False False  True False False False False False False]
[False False  True False False  True False False False False False False]
Loss step 3000: 2.292486215083045e-06
[False False  True False False  True False False False False False False]
[False False  True False False  True False False False False False False]
[False False  True False False  True Fa

In [9]:
from dataclasses import dataclass

@dataclass
class mask_scheduler:
    patience: int = 500
    delta: float = 1e-5
    periodicity: int = 200
        
    periodic: bool = False
    best_loss = None
    best_iteration = None
    
    def __call__(self, loss, iteration, optimizer):
        if self.periodic is True:
            if (iteration - self.best_iteration) % self.periodicity == 0:
                update_mask, optimizer = True, optimizer
            else:
                 update_mask, optimizer = False, optimizer

        elif self.best_loss is None:
            self.best_loss = loss
            self.best_iteration = iteration
            self.best_optim_state = optimizer
            update_mask, optimizer = False, optimizer

        # If it didnt improve, check if we're past patience
        elif (self.best_loss - loss) < self.delta:
            if (iteration - self.best_iteration) >= self.patience:
                self.periodic = True  # switch to periodic regime
                self.best_iteration = iteration  # because the iterator doesnt reset
                update_mask, optimizer = True, self.best_optim_state
            else:
                update_mask, optimizer = False, optimizer

        # If not, keep going
        else:
            self.best_loss = loss
            self.best_iteration = iteration
            self.best_optim_state = optimizer
            update_mask, optimizer = False, optimizer

        return update_mask, optimizer

In [20]:
key = random.PRNGKey(42)
key_data, key_network = random.split(key)

# Making dataset
x = jnp.linspace(-3, 4, 100)
t = jnp.linspace(0.5, 5.0, 20)

t_grid, x_grid = jnp.meshgrid(t, x, indexing="ij")
u = burgers(x_grid, t_grid, 0.1, 1.0)

X = jnp.concatenate([t_grid.reshape(-1, 1), x_grid.reshape(-1, 1)], axis=1)
y = u.reshape(-1, 1)
y += 0.10 * jnp.std(y_train) * jax.random.normal(key_data, y.shape)

In [21]:
rand_idx = random.permutation(key, X.shape[0])
X = X[rand_idx, :]
y = y[rand_idx, :]

split = int(0.8 * X.shape[0])
X_train, X_test = X[:split, :], X[split:, :]
y_train, y_test = y[:split, :], y[split:, :]

In [39]:
model = Deepmod([50, 50, 1])
variables = model.init(key_network, X_train)

In [40]:
optimizer = optim.Adam(learning_rate=2e-3, beta1=0.99, beta2=0.99)
state, params = variables.pop('params')
optimizer = optimizer.create(params)

In [41]:
validation_metric = jit(lambda opt, state: loss_fn_pinn(opt.target, state, model, X_test, y_test)[1][1])

In [42]:
# Running to convergence
max_epochs = 10000
reg = LassoCV(fit_intercept=False, )
logger = Logger(comment='validation')
scheduler = mask_scheduler()

for epoch in jnp.arange(max_epochs):
    (optimizer, state), train_metrics = update(optimizer, state)
    
    if epoch % 1000 == 0:
        print(f"Loss step {epoch}: {train_metrics['loss']}")
        
    if epoch % 25 == 0:
        logger.write(train_metrics, epoch)
        
        val_metrics = validation_metric(optimizer, state)
        apply_sparsity, optimizer = scheduler(val_metrics['mse'], epoch, optimizer)

        if apply_sparsity:
            dt, theta = model.apply({"params": optimizer.target, **state}, X_train, mutable=list(state.keys()))[0][1:3]
            mask = update_mask(theta, dt, reg)
            state = freeze({'vars': {'LeastSquares_0': {'mask': mask}}})
            print(mask)
        
logger.close()

Loss step 0: 0.19655461609363556
Loss step 1000: 0.0001199067773995921
[False False  True False False  True  True False  True False False False]
Loss step 2000: 2.886177389882505e-05
[False False  True False False  True False False False False False False]
[False False  True False False  True False False False False False False]
[False False  True False False  True False False False False False False]
[False False  True False False  True False False False False False False]
[False False  True False False  True False False False False False False]
Loss step 3000: 2.844449682015693e-06
[False False  True False False  True False False False False False False]
[False False  True False False  True False False False False False False]
[False False  True False False  True False False False False False False]
[False False  True False False  True False False False False False False]
[False False  True False False  True False False False False False False]
Loss step 4000: 2.1348250811570324e-06


In [22]:
from sklearn.linear_model import LassoCV
import numpy as np
from flax.core import freeze

In [23]:
def update_mask(X, y, reg, threshold=0.1):
    X_normed = X / jnp.linalg.norm(X, axis=0, keepdims=True)
    y_normed = y / jnp.linalg.norm(y, axis=0, keepdims=True)
    coeffs = reg.fit(np.array(X_normed), np.array(y_normed).squeeze()).coef_
    mask = np.abs(coeffs) > threshold 
    return mask

In [24]:
model = Deepmod([50, 50, 1])
variables = model.init(key_network, X_train)

In [25]:
optimizer = optim.Adam(learning_rate=2e-3, beta1=0.99, beta2=0.99)
state, params = variables.pop('params')
optimizer = optimizer.create(params)

In [26]:
validation_metric = jit(lambda opt, state: loss_fn_pinn(opt.target, state, model, X_test, y_test)[1][1])

In [27]:
# Compiling train step
update = create_update(loss_fn_pinn, model=model, x=X_train, y=y_train)
_ = update(optimizer, state)  # triggering compilation

In [28]:
# Running to convergence
max_epochs = 10000
reg = LassoCV(fit_intercept=False, )
logger = Logger(comment='validation')
scheduler = jax.jit(lambda metrics, epoch, opt: mask_scheduler()(metrics['mse'], epoch, optimizer))

for epoch in jnp.arange(max_epochs):
    (optimizer, state), train_metrics = update(optimizer, state)
    
    if epoch % 1000 == 0:
        print(f"Loss step {epoch}: {train_metrics['loss']}")
        
    if epoch % 25 == 0:
        logger.write(train_metrics, epoch)
        
        val_metrics = validation_metric(optimizer, state)
        apply_sparsity, optimizer = scheduler(val_metrics, epoch, optimizer)

        if apply_sparsity:
            dt, theta = model.apply({"params": optimizer.target, **state}, X_train, mutable=list(state.keys()))[0][1:3]
            mask = update_mask(theta, dt, reg)
            state = freeze({'mask': {'LeastSquares_0': {'active terms': mask}}})
            print(mask)
        
logger.close()

Loss step 0: 0.1946893334388733
Loss step 1000: 0.04699103534221649
Loss step 2000: 0.04699103534221649
Loss step 3000: 0.04699103534221649
Loss step 4000: 0.04699103534221649
Loss step 5000: 0.04699103534221649
Loss step 6000: 0.04699103534221649
Loss step 7000: 0.04699103534221649
Loss step 8000: 0.04699103534221649
Loss step 9000: 0.04699103534221649


In [30]:
mask

NameError: name 'mask' is not defined

In [31]:
state

FrozenDict({
    vars: {
        LeastSquares_0: {
            mask: DeviceArray([ True,  True,  True,  True,  True,  True,  True,  True,
                          True,  True,  True,  True], dtype=bool),
        },
    },
})