In [2]:
%load_ext autoreload
%autoreload 2

GPX can save a model to disk and load it back.
Each model class exposes two methods, `save` and `load`, which respectively write and read the **values** held by the model's inner `ModelState`.

In [1]:
import jax.numpy as jnp
from jax import random

In [3]:
import gpx
from gpx.kernels import SquaredExponential
from gpx.utils import softplus, inverse_softplus
from gpx.models import GPR

In [4]:
model = GPR(
    kernel=SquaredExponential(),
)

No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


In [5]:
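# Generate synthetic data: a noisy sine wave on [0, 1]
train_x = jnp.linspace(0, 1, 100)
key = random.PRNGKey(0)
# Additive Gaussian noise with variance 0.04 (standard deviation 0.2)
train_y = jnp.sin(train_x * (2 * jnp.pi)) + random.normal(
    key, shape=train_x.shape
) * jnp.sqrt(0.04)
# Test inputs extend beyond the training range, with noise-free targets
test_x = jnp.linspace(-0.5, 1.5, 51)
test_f = jnp.sin(test_x * (2 * jnp.pi))

# Reshape everything into column vectors of shape (n, 1)
train_x = train_x.reshape(-1, 1)
train_y = train_y.reshape(-1, 1)
test_x = test_x.reshape(-1, 1)
test_f = test_f.reshape(-1, 1)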

In [6]:
# Fit using SciPy's L-BFGS-B
model.fit(train_x, train_y)

<gpx.models.gpr.GaussianProcessRegression at 0x14e9d5c4e8f0>

In [7]:
model.print()

┌────────────────────┬─────────────┬───────────┬──────────────────┬──────────────────┬───────────┬─────────┬─────────┬──────────┐
│ name               │ trainable   │ forward   │ backward         │ prior            │ type      │ dtype   │ shape   │    value │
├────────────────────┼─────────────┼───────────┼──────────────────┼──────────────────┼───────────┼─────────┼─────────┼──────────┤
│ kernel lengthscale │ True        │ softplus  │ inverse_softplus │ Normal(0.0, 1.0) │ ArrayImpl │ float64 │ ()      │ 0.411059 │
├────────────────────┼─────────────┼───────────┼──────────────────┼──────────────────┼───────────┼─────────┼─────────┼──────────┤
│ sigma              │ True        │ softplus  │ inverse_softplus │ Normal(0.0, 1.0) │ ArrayImpl │ float64 │ ()      │ 0.217113 │
└────────────────────┴─────────────┴───────────┴──────────────────┴──────────────────┴───────────┴─────────┴─────────┴──────────┘


The model state values are written to an uncompressed NumPy `.npz` file.
The `save` method also returns the dictionary of values that it saved.
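For reference, the snippet below sketches what an uncompressed `.npz` round trip looks like with plain NumPy. It is not GPX's internal code: the flat name-to-array layout and the key names `kernel_lengthscale` and `sigma` are assumptions made for illustration. `np.savez` stores each array uncompressed (unlike `np.savez_compressed`), and `np.load` returns an `NpzFile` mapping that is best used as a context manager.

```python
import os
import tempfile
import zipfile

import numpy as np

# Hypothetical flat mapping of parameter names to values; these key names are
# illustrative and not necessarily the ones GPX writes.
values = {
    "kernel_lengthscale": np.asarray(0.411059),
    "sigma": np.asarray(0.217113),
}

path = os.path.join(tempfile.mkdtemp(), "model_state.npz")

# np.savez writes each array as an uncompressed .npy member of a zip archive.
np.savez(path, **values)

# Confirm that every member of the archive is stored without compression.
with zipfile.ZipFile(path) as zf:
    stored = all(info.compress_type == zipfile.ZIP_STORED for info in zf.infolist())

# np.load on an .npz file returns an NpzFile; the context manager closes it.
with np.load(path) as archive:
    restored = {name: archive[name] for name in archive.files}

print(stored)
print(sorted(restored))
```

Because the archive only holds named arrays, nothing about the model's structure (kernel type, transforms, priors) travels with the file.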

In [8]:
saved_dict = model.save("model_state.npz")

Loading a model follows the same pattern as the [PyTorch API](https://pytorch.org/tutorials/beginner/saving_loading_models.html#save-load-state-dict-recommended): first re-instantiate the model, then update its parameters by reading their values from file.
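The two-step pattern can be sketched with a toy class. `TinyModel`, `state_values` and `load_values` are hypothetical names invented for this illustration and are not part of GPX; the strict key check mimics the spirit of PyTorch's `load_state_dict(strict=True)`, and whether GPX performs a similar check is an assumption.

```python
import numpy as np


class TinyModel:
    """Hypothetical stand-in for a model whose structure is defined in code.

    Only parameter values are persisted; the structure (which parameters
    exist) is rebuilt by calling the constructor.
    """

    def __init__(self):
        # Default values, analogous to a freshly instantiated model.
        self.values = {
            "kernel_lengthscale": np.asarray(1.0),
            "sigma": np.asarray(1.0),
        }

    def state_values(self):
        # Copy the arrays so later changes to the model do not alter the snapshot.
        return {name: np.array(value) for name, value in self.values.items()}

    def load_values(self, saved):
        # Reject snapshots whose parameter names do not match this structure.
        missing = self.values.keys() - saved.keys()
        unexpected = saved.keys() - self.values.keys()
        if missing or unexpected:
            raise KeyError(f"missing={sorted(missing)}, unexpected={sorted(unexpected)}")
        for name in self.values:
            self.values[name] = np.asarray(saved[name])
        return self


trained = TinyModel()
trained.values["sigma"] = np.asarray(0.217113)  # pretend this came from fitting
snapshot = trained.state_values()

fresh = TinyModel()          # step 1: re-instantiate with default values
fresh.load_values(snapshot)  # step 2: overwrite the values from the snapshot
print(float(fresh.values["sigma"]))
```

Keeping the structure in code and only the numbers on disk is what makes re-instantiation necessary: the file cannot rebuild the model on its own.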

In [9]:
new_model = GPR(
    kernel=SquaredExponential(),
)

In [10]:
new_model.print()

┌────────────────────┬─────────────┬───────────┬──────────────────┬──────────────────┬───────────┬─────────┬─────────┬─────────┐
│ name               │ trainable   │ forward   │ backward         │ prior            │ type      │ dtype   │ shape   │   value │
├────────────────────┼─────────────┼───────────┼──────────────────┼──────────────────┼───────────┼─────────┼─────────┼─────────┤
│ kernel lengthscale │ True        │ softplus  │ inverse_softplus │ Normal(0.0, 1.0) │ ArrayImpl │ float64 │ ()      │       1 │
├────────────────────┼─────────────┼───────────┼──────────────────┼──────────────────┼───────────┼─────────┼─────────┼─────────┤
│ sigma              │ True        │ softplus  │ inverse_softplus │ Normal(0.0, 1.0) │ ArrayImpl │ float64 │ ()      │       1 │
└────────────────────┴─────────────┴───────────┴──────────────────┴──────────────────┴───────────┴─────────┴─────────┴─────────┘


In [11]:
new_model.load("model_state.npz")

<gpx.models.gpr.GaussianProcessRegression at 0x14e9a82e7c10>

In [12]:
new_model.print()

┌────────────────────┬─────────────┬───────────┬──────────────────┬──────────────────┬───────────┬─────────┬─────────┬──────────┐
│ name               │ trainable   │ forward   │ backward         │ prior            │ type      │ dtype   │ shape   │    value │
├────────────────────┼─────────────┼───────────┼──────────────────┼──────────────────┼───────────┼─────────┼─────────┼──────────┤
│ kernel lengthscale │ True        │ softplus  │ inverse_softplus │ Normal(0.0, 1.0) │ ArrayImpl │ float64 │ ()      │ 0.411059 │
├────────────────────┼─────────────┼───────────┼──────────────────┼──────────────────┼───────────┼─────────┼─────────┼──────────┤
│ sigma              │ True        │ softplus  │ inverse_softplus │ Normal(0.0, 1.0) │ ArrayImpl │ float64 │ ()      │ 0.217113 │
└────────────────────┴─────────────┴───────────┴──────────────────┴──────────────────┴───────────┴─────────┴─────────┴──────────┘


In [13]:
new_model.state.is_fitted

array(True)

In [14]:
%%bash

if [ -f model_state.npz ] ; then rm model_state.npz ; fi