In [160]:
import jax
from sklearn.model_selection import train_test_split

# Target function and problem size. Training inputs lie on [-n*pi, 0] and test
# inputs on [0, n*pi], so the test set measures extrapolation beyond the
# training domain. Each split has 100 * n evenly spaced points shaped (N, 1).
f = jax.numpy.cos
n = 1000

X_train, X_test = (
    jax.numpy.linspace(lo, hi, 100 * n).reshape(-1, 1)
    for lo, hi in ((-n * jax.numpy.pi, 0), (0, n * jax.numpy.pi))
)
y_train, y_test = f(X_train), f(X_test)

In [161]:
import jax

from mvtk import interprenet

# Build a single-input model constrained by interprenet's monotonic constraint.
# `get_layers` maps the input width to the hidden-layer sizes (presumably; see
# the mvtk interprenet docs). Pre/post-processing are identities.
init_params, model = interprenet.constrained_model(
    (frozenset({interprenet.monotonic_constraint}),),
    get_layers=lambda width: [width + 1],
    preprocess=interprenet.identity,
    postprocess=interprenet.identity,
)

# Prepend the sinusoid-feature parameters (m, b), both initialised to zero and
# unpacked by `scaled_model`, to the constrained model's own parameters.
init_params = ((jax.numpy.asarray([0.]), jax.numpy.asarray([0.])), init_params)
def scaled_model(params, x):
    """Apply ``model`` to a learned sinusoidal feature of ``x``.

    The input is first mapped to ``u = sin(x * exp(m) + b)``. That feature is
    then passed to ``model`` (built by ``interprenet.constrained_model``).

    Parameters
    ----------
    params : tuple
        ``((m, b), model_params)``. ``m`` is the log-frequency, so the
        frequency ``exp(m)`` is always positive. ``b`` is the phase in
        radians. ``model_params`` are forwarded to ``model`` unchanged.
    x : array
        Model inputs. In this notebook they are (N, 1) arrays.

    Returns
    -------
    Whatever ``model(model_params, u)`` returns.
    """
    (m, b), model_params = params
    # The phase is used as-is. sin is 2*pi-periodic, so an unbounded phase
    # loses nothing, and it can reach +/-pi/2, as needed for
    # cos(x) = sin(x + pi/2). The previous `arctan(b)` restricted the phase to
    # the open interval (-pi/2, pi/2), so that value was only approached as
    # b -> inf while the gradient 1/(1 + b**2) vanished. At b = 0 both forms
    # agree in value and slope.
    u = jax.numpy.sin(x * jax.numpy.exp(m) + b)
    return model(model_params, u)

In [None]:
def loss(y, y_pred):
    """Mean squared error between targets ``y`` and predictions ``y_pred``.

    The two arrays are subtracted, with the usual array broadcasting. The
    squared residuals are then averaged over every element, so callers should
    pass matching shapes.
    """
    residual = y - y_pred
    return (residual ** 2).mean()

# Jointly fit the sinusoid-feature parameters (m, b) and the constrained model,
# minimising MSE on the training half-line. The test split is passed as the
# second argument; presumably interprenet.train reports `metric` on it
# (TODO confirm against the mvtk docs). `loss` is passed directly as the
# metric, so no lambda wrapper is needed.
trained_params = interprenet.train((X_train, y_train),
                                   (X_test, y_test),
                                   (init_params, scaled_model),
                                   metric=loss,
                                   step_size=0.01,
                                   mini_batch_size=32,
                                   loss_fn=loss,
                                   num_epochs=128)

In [None]:
# Sanity check: scoring the targets against themselves must give a loss of 0.
loss(y=y_test, y_pred=y_test)

In [None]:
def trained_model(X):
    """Evaluate ``scaled_model`` with the fitted ``trained_params`` on inputs ``X``."""
    return scaled_model(trained_params, X)


# Extrapolation error on the held-out half-line.
y_pred = trained_model(X_test)
loss(y_test, y_pred)

In [None]:
import matplotlib
import matplotlib.pyplot as pylab

# Compare target and prediction on the first q test points, i.e. the start of
# the extrapolation region just beyond the training domain.
q = 1000
fig, ax = pylab.subplots()
ax.plot(X_test[:q], y_test[:q], label="target (y_test)")
ax.plot(X_test[:q], y_pred[:q], label="prediction (y_pred)")
ax.set(xlabel="x", ylabel="y", title=f"Extrapolation: first {q} test points")
ax.legend()

pylab.show()