in this notebook we implement ridge regression, first as a function, then as a flax layer.

In [1]:
# %% Imports
import jax
from jax import random, numpy as jnp
from functools import partial
from jax import lax
import matplotlib.pyplot as plt
import seaborn as sns
sns.set()

from flax import linen as nn
%load_ext autoreload
%autoreload 2

In [15]:
data = jnp.load('test_data.npy', allow_pickle=True).item()
y, X = data['y'], data['X']

X_normed = X / jnp.linalg.norm(X, axis=0, keepdims=True)

In [2]:
def ridge(X, y, l):
    """Ridge regression using augmente data. X can have dimensions."""
    l_normed = jnp.diag(jnp.sqrt(l) * jnp.linalg.norm(X, axis=0))
    l_normed = jax.ops.index_update(l_normed, jax.ops.index[0, 0], 0.0) # shouldnt apply l2 to offset
    X_augmented = jnp.concatenate([X, l_normed], axis=0)
    y_augmented = jnp.concatenate([y, jnp.zeros((X.shape[1], 1))], axis=0)
    
    coeffs = jnp.linalg.lstsq(X_augmented, y_augmented)[0]
    return coeffs

First we check if we did the normalization correctly:

In [72]:
ridge(X, y, l=1e-3)

DeviceArray([[ 0.00110248],
             [-0.05612651],
             [ 0.09737656],
             [ 0.0015001 ],
             [-0.0152597 ],
             [-0.6944833 ],
             [ 0.01022425],
             [-0.00192492],
             [ 0.02313143],
             [-0.3198123 ],
             [-0.00795702],
             [-0.00107516]], dtype=float32)

In [73]:
ridge(X_normed, y, l=1e-3) / jnp.linalg.norm(X, axis=0)[:, None]

DeviceArray([[ 0.00110242],
             [-0.05612676],
             [ 0.09737664],
             [ 0.0015001 ],
             [-0.01525926],
             [-0.69448274],
             [ 0.0102241 ],
             [-0.00192489],
             [ 0.02313079],
             [-0.31981215],
             [-0.00795697],
             [-0.00107523]], dtype=float32)

Great. Now we compare it to sklearn:

In [74]:
from sklearn.linear_model import Ridge

In [88]:
reg = Ridge(fit_intercept=False, alpha=1e-7)

In [89]:
reg.fit(X_normed, y.squeeze()).coef_[:, None] / jnp.linalg.norm(X, axis=0)[:, None]

DeviceArray([[ 8.7421207e-04],
             [-4.1655444e-02],
             [ 9.8937944e-02],
             [ 1.6173805e-03],
             [-1.4723522e-02],
             [-7.6221049e-01],
             [ 9.2654191e-03],
             [-2.8847728e-03],
             [ 2.4254709e-02],
             [-2.6267287e-01],
             [-8.8438150e-03],
             [-2.2157413e-04]], dtype=float32)

Which is within numerical precision :-). Now let's check if we add in zero columns:

In [112]:
mask = jnp.zeros((X.shape[1]), dtype=bool)
mask = jax.ops.index_update(mask, jnp.array([2, 5]), True)

In [113]:
ridge(X * mask, y, l=1e-7)

DeviceArray([[ 0.0000000e+00],
             [-2.9165541e-08],
             [ 9.5455818e-02],
             [-5.6868320e-21],
             [ 0.0000000e+00],
             [-9.9294829e-01],
             [ 0.0000000e+00],
             [ 0.0000000e+00],
             [ 0.0000000e+00],
             [ 0.0000000e+00],
             [ 0.0000000e+00],
             [ 0.0000000e+00]], dtype=float32)

In [114]:
reg.fit(X_normed * mask, y.squeeze()).coef_[:, None] / jnp.linalg.norm(X, axis=0)[:, None]

DeviceArray([[ 0.        ],
             [ 0.        ],
             [ 0.09545576],
             [ 0.        ],
             [ 0.        ],
             [-0.9929491 ],
             [ 0.        ],
             [ 0.        ],
             [ 0.        ],
             [ 0.        ],
             [ 0.        ],
             [ 0.        ]], dtype=float32)

Seems close enough for me... Now to put it in a flax layer:

In [3]:
class Ridge(nn.Module):
    l: float=1e-7
  
    @nn.compact
    def __call__(self, inputs):
        y, X = inputs
        mask = self.variable(
            "mask",
            "active terms",
            lambda n_terms: jnp.ones((n_terms, ), dtype=bool), X.shape[1])
        
        coeffs =  ridge(X * mask.value, y, l=self.l) 

        return coeffs * mask.value[:, None] # extra multiplication to compensate numerical errors

In [182]:
model = Ridge(l=1e-7)

In [183]:
key = random.PRNGKey(42)

In [184]:
params = model.init(key, (y, X))

In [185]:
model.apply(params, (y, X), mutable=["mask"])

(DeviceArray([[ 8.7459665e-04],
              [-4.1662801e-02],
              [ 9.8933339e-02],
              [ 1.6169511e-03],
              [-1.4724741e-02],
              [-7.6219010e-01],
              [ 9.2849545e-03],
              [-2.8839791e-03],
              [ 2.4262663e-02],
              [-2.6268721e-01],
              [-8.8574085e-03],
              [-2.2184884e-04]], dtype=float32),
 FrozenDict({
     mask: {
         active terms: DeviceArray([ True,  True,  True,  True,  True,  True,  True,  True,
                       True,  True,  True,  True], dtype=bool),
     },
 }))

now lets try and make a new dict and do a smaller update:

In [5]:
from flax.core import freeze

In [4]:

mask = jnp.zeros((X.shape[1]), dtype=bool)
mask = jax.ops.index_update(mask, jnp.array([2, 5]), True)

params = freeze({'mask': {'active terms': mask}})

NameError: name 'X' is not defined

In [187]:
model.apply(params, (y, X), mutable=["mask"])

(DeviceArray([[ 0.        ],
              [-0.        ],
              [ 0.09545582],
              [-0.        ],
              [ 0.        ],
              [-0.9929483 ],
              [ 0.        ],
              [ 0.        ],
              [ 0.        ],
              [ 0.        ],
              [ 0.        ],
              [ 0.        ]], dtype=float32),
 FrozenDict({
     mask: {
         active terms: DeviceArray([False, False,  True, False, False,  True, False, False,
                      False, False, False, False], dtype=bool),
     },
 }))

Comparison with least squares:

In [190]:
jnp.linalg.lstsq(X[:, [2,5]], y)[0]

DeviceArray([[ 0.09545584],
             [-0.99294853]], dtype=float32)

In [6]:
from verifying_bayesian_regression.code import create_update
from modax.data.burgers import burgers
from modax.feature_generators import library_backward
from modax.networks import MLP
from flax import optim
from modax.losses import neg_LL, loss_fn_pinn, mse
from modax.logging import Logger
from typing import Sequence

In [7]:
def loss_fn_pinn(params, state, model, x, y):
    variables = {'params': params, **state}
    (prediction, dt, theta, coeffs), updated_state = model.apply(variables, x, mutable=list(state.keys()))

    MSE = mse(prediction, y)
    Reg = mse(dt.squeeze(), (theta @ coeffs).squeeze())
    loss = MSE + Reg
    metrics = {"loss": loss, "mse": MSE, "reg": Reg, "coeff": coeffs}

    return loss, (updated_state, metrics)

In [8]:
class Deepmod(nn.Module):
    features: Sequence[int]

    @nn.compact
    def __call__(self, inputs):
        prediction, dt, theta = library_backward(MLP(self.features), inputs)
        coeffs = Ridge(l=1e-7)((dt, theta))
        return prediction, dt, theta, coeffs

In [9]:
key = random.PRNGKey(42)

In [10]:
# Making dataset
x = jnp.linspace(-3, 4, 50)
t = jnp.linspace(0.5, 5.0, 20)

t_grid, x_grid = jnp.meshgrid(t, x, indexing="ij")
u = burgers(x_grid, t_grid, 0.1, 1.0)

X_train = jnp.concatenate([t_grid.reshape(-1, 1), x_grid.reshape(-1, 1)], axis=1)
y_train = u.reshape(-1, 1)
y_train += 0.01 * jnp.std(y_train) * jax.random.normal(key, y_train.shape)

In [11]:
model = Deepmod([50, 50, 1])
key_network, _ = random.split(key)
variables = model.init(key_network, X_train)

In [12]:
optimizer = optim.Adam(learning_rate=2e-3, beta1=0.99, beta2=0.99)
state, params = variables.pop('params')
optimizer = optimizer.create(params)

In [13]:
# Compiling train step
update = create_update(loss_fn_pinn, model=model, x=X_train, y=y_train)
_ = update(optimizer, state)  # triggering compilation

In [14]:
# Running to convergence
max_epochs = 1000
logger = Logger()
for epoch in jnp.arange(max_epochs):
    (optimizer, state), metrics = update(optimizer, state)
    if epoch % 1000 == 0:
        print(f"Loss step {epoch}: {metrics['loss']}")
    if epoch % 25 == 0:
        logger.write(metrics, epoch)
logger.close()

Loss step 0: 0.23318196833133698


So that works as well. now to update the mask every xx epochs:

In [15]:
from sklearn.linear_model import LassoCV

In [16]:
import numpy as np

In [17]:
def update_mask(X, y, reg, threshold=0.1):
    X_normed = X / jnp.linalg.norm(X, axis=0, keepdims=True)
    y_normed = y / jnp.linalg.norm(y, axis=0, keepdims=True)
    coeffs = reg.fit(np.array(X_normed), np.array(y_normed).squeeze()).coef_
    mask = coeffs > threshold 
    return jnp.array(mask)

In [18]:
model = Deepmod([50, 50, 1])
key_network, _ = random.split(key)
variables = model.init(key_network, X_train)

In [19]:
optimizer = optim.Adam(learning_rate=2e-3, beta1=0.99, beta2=0.99)
state, params = variables.pop('params')
optimizer = optimizer.create(params)

In [20]:
print(state)

FrozenDict({
    mask: {
        Ridge_0: {
            active terms: DeviceArray([ True,  True,  True,  True,  True,  True,  True,  True,
                          True,  True,  True,  True], dtype=bool),
        },
    },
})


In [21]:
# Compiling train step
update = create_update(loss_fn_pinn, model=model, x=X_train, y=y_train)
_ = update(optimizer, state)  # triggering compilation

In [22]:
state

FrozenDict({
    mask: {
        Ridge_0: {
            active terms: DeviceArray([ True,  True,  True,  True,  True,  True,  True,  True,
                          True,  True,  True,  True], dtype=bool),
        },
    },
})

In [23]:
# Running to convergence
max_epochs = 1000
reg = LassoCV(fit_intercept=False)
logger = Logger()
for epoch in jnp.arange(max_epochs):
    (optimizer, state), metrics = update(optimizer, state)
    if epoch % 1000 == 0:
        print(f"Loss step {epoch}: {metrics['loss']}")
    if ((epoch % 5000 == 0) and (epoch != 0)):
        dt, theta= model.apply({"params": params, **state}, X_train, mutable=list(state.keys()))[0][1:3]
        mask = update_mask(theta, dt, reg)
        state = freeze({'mask': {'Ridge_0': {'active terms': mask}}})
    if epoch % 25 == 0:
        logger.write(metrics, epoch)
logger.close()

Loss step 0: 0.23318196833133698


In [29]:
dt, theta, coeffs = model.apply({"params": params, **state}, X_train, mutable=list(state.keys()))[0][1:]

In [30]:
dt.shape

(1000, 1)

In [31]:
theta.shape

(1000, 12)

In [32]:
coeffs

DeviceArray([[ 0.05443856],
             [-0.482194  ],
             [ 0.79097605],
             [ 0.04588859],
             [ 0.7138658 ],
             [ 0.7887436 ],
             [-2.1838267 ],
             [-0.30629992],
             [-0.59208643],
             [-0.42597318],
             [ 1.0998764 ],
             [ 0.21274425]], dtype=float32)

In [28]:
jnp.linalg.lstsq(theta / jnp.linalg.norm(theta, axis=0), dt)[0]

DeviceArray([[  1.721303 ],
             [ -4.3714476],
             [  8.153681 ],
             [  1.1237171],
             [ 18.71585  ],
             [  5.806181 ],
             [-19.842834 ],
             [ -5.4475436],
             [-17.152758 ],
             [ -3.0097315],
             [ 10.575122 ],
             [  3.2467732]], dtype=float32)