In [None]:
import pandas as pd
import gpjax as gpx
import jax.numpy as jnp
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import mean_squared_error, r2_score
import matplotlib.pyplot as plt
from jax import jit
import optax as ox

In [None]:
# Load the UCI Yacht Hydrodynamics data: the final column is the response, every
# other column is an input. The last row of the parsed frame is discarded
# (NOTE(review): presumably a parsing artefact at the end of the file — confirm).
YACHT_URL = 'https://archive.ics.uci.edu/ml/machine-learning-databases/00243/yacht_hydrodynamics.data'
yacht_frame = pd.read_fwf(YACHT_URL, header=None)
yacht = yacht_frame.to_numpy()[:-1, :]

X = yacht[:, :-1]
y = yacht[:, -1:]  # keep a trailing axis so y is a column vector (n, 1)

# Single, seeded 70/30 split on the raw data; all preprocessing below is fitted
# on the training portion only.
Xtr, Xte, ytr, yte = train_test_split(
    X,
    y,
    test_size=0.3,
    random_state=123,
)

## Preprocessing

### Response Variable

In [None]:
# Log-transform the response, then standardise it with statistics from the
# training split only. The log is only defined for strictly positive values, so
# fail loudly instead of silently producing NaN / -inf (this also catches a
# re-run of this cell after `ytr` has been replaced by already-scaled values).
assert (ytr > 0).all() and (yte > 0).all(), "log transform requires strictly positive responses"
log_ytr = np.log(ytr)
log_yte = np.log(yte)
y_scaler = StandardScaler().fit(log_ytr)
scaled_y = y_scaler.transform(log_ytr)

# Compare the raw, logged and standardised training response side by side.
fig, ax = plt.subplots(ncols=3, figsize=(16, 4))
for panel, values, title in zip(ax, (ytr, log_ytr, scaled_y), ('y', 'log(y)', 'scaled log(y)')):
    panel.hist(values, bins=30)
    panel.set(title=title, xlabel='value', ylabel='count')
fig.suptitle('Training response before and after transformation');

### Input Variables

In [None]:
# Standardise each input column using the training split's mean and standard
# deviation; the test inputs are transformed with those same statistics so no
# information about them leaks into preprocessing.
x_scaler = StandardScaler()
scaled_Xtr = x_scaler.fit_transform(Xtr)
scaled_Xte = x_scaler.transform(Xte)

### Train/Test Split

In [None]:
# The train/test split was already made on the raw data, and both scalers were
# fitted on the training portion only, so no second split is needed. Splitting
# again here would draw the "test" set out of the training data (and `scaled_X`
# never existed). Instead, point the model-facing names at the scaled arrays.
# `yte` deliberately stays on the original scale so evaluation happens in the
# units of the data.
# NOTE: this rebinds Xtr / Xte / ytr — re-run from the data-loading cell before
# re-running the preprocessing cells above.
Xtr, Xte, ytr = scaled_Xtr, scaled_Xte, scaled_y
assert Xtr.shape[0] == ytr.shape[0], "training inputs and responses are misaligned"
assert Xte.shape[0] == yte.shape[0], "test inputs and responses are misaligned"

## Model fitting

### Model specification

In [None]:
# GP prior with an RBF kernel over every input column, combined with a Gaussian
# likelihood to give the posterior that is optimised below.
n_train, n_covariates = Xtr.shape
all_covariate_dims = [dim for dim in range(n_covariates)]

kernel = gpx.kernels.RBF(active_dims=all_covariate_dims)
prior = gpx.Prior(kernel=kernel)
likelihood = gpx.Gaussian(num_datapoints=n_train)
posterior = prior * likelihood

# `initialise` returns the initial parameters together with trainability flags and
# the constraining / unconstraining maps; the parameters are moved to the
# unconstrained space for optimisation (NOTE(review): based on how the maps are
# used here and after fitting).
initial_params, trainables, constrainers, unconstrainers = gpx.initialise(posterior)
params = gpx.transform(initial_params, unconstrainers)

### Model Optimisation

In [None]:
training_data = gpx.Dataset(X=Xtr, y=ytr)

# Objective: negative marginal log-likelihood of the training data, evaluated on
# unconstrained parameters (`constrainers` is passed so they are presumably mapped
# back before evaluation). JIT-compiled because the optimiser calls it repeatedly.
negative_mll = posterior.marginal_log_likelihood(
    train_data=training_data,
    transformations=constrainers,
    negative=True,
)
mll = jit(negative_mll)

# 1000 Adam steps (learning rate 0.05), logging every 50 iterations.
unconstrained_params = gpx.fit(
    objective=mll,
    params=params,
    trainables=trainables,
    optax_optim=ox.adam(0.05),
    n_iters=1000,
    log_rate=50,
)
# Map the optimised parameters back to their constrained space for prediction.
learned_params = gpx.transform(unconstrained_params, constrainers)

## Prediction

In [None]:
def lognormal_transform(mean, variance):
    """Convert Gaussian moments on log(y) into the moments of y.

    If log(y) ~ Normal(mean, variance), then y is log-normal with

        E[y]   = exp(mean + variance / 2)
        Var[y] = (exp(variance) - 1) * exp(2 * mean + variance)

    Args:
        mean: Mean of log(y); scalar or array.
        variance: Variance of log(y); broadcastable against ``mean``.

    Returns:
        Tuple ``(mu, sigma2)`` holding the log-normal mean and variance.
    """
    exponent_of_mean = mean + variance / 2
    exponent_of_var = 2 * mean + variance
    mu = jnp.exp(exponent_of_mean)
    sigma2 = (jnp.exp(variance) - 1) * jnp.exp(exponent_of_var)
    return mu, sigma2

# Predictive distribution at the (scaled) test inputs. Its moments live in the
# model's space: standardised log(y).
latent_dist = posterior(training_data, learned_params)(Xte)
predictive_dist = likelihood(latent_dist, learned_params)

scaled_log_mean = predictive_dist.mean()
scaled_log_variance = predictive_dist.variance()

# Step 1 — undo the standardisation to get Gaussian moments of log(y).
# For x = z * scale + mean:  E[x] = E[z] * scale + mean,  Var[x] = Var[z] * scale**2.
# The variance must NOT be shifted by the mean, which is why `inverse_transform`
# cannot be used for it.
log_mean = scaled_log_mean * y_scaler.scale_ + y_scaler.mean_
log_variance = scaled_log_variance * y_scaler.scale_ ** 2

# Step 2 — only now map the Gaussian on log(y) to the log-normal moments of y.
# Doing this before step 1 would exponentiate values that are still standardised.
predictive_mean, predictive_variance = lognormal_transform(log_mean, log_variance)

# Column vectors on the original scale of y, matching the shape of `yte`.
predictive_mean = np.asarray(predictive_mean).reshape(-1, 1)
predictive_variance = np.asarray(predictive_variance).reshape(-1, 1)

In [None]:
# Point-prediction accuracy on the held-out data, in the original units of y
# (`yte` is kept untransformed; `predictive_mean` is expected on the same scale).
test_y_true = np.ravel(yte)
test_y_pred = np.ravel(predictive_mean)
test_mse = mean_squared_error(y_true=test_y_true, y_pred=test_y_pred)
pd.Series({
    'MSE': test_mse,
    'RMSE': np.sqrt(test_mse),
    'R2': r2_score(y_true=test_y_true, y_pred=test_y_pred),
})

In [None]:
# Predicted vs. actual on the test set; points on the orange line are perfect predictions.
predicted = np.ravel(predictive_mean)
actual = np.ravel(yte)

# Draw the identity line in *data* coordinates. An axes-fraction diagonal
# (transform=ax.transAxes) only coincides with y = x when the x and y limits
# happen to be identical, which they generally are not.
lo = min(predicted.min(), actual.min())
hi = max(predicted.max(), actual.max())

fig, ax = plt.subplots()
ax.scatter(predicted, actual)
ax.plot([lo, hi], [lo, hi], color='tab:orange', label='y = x')
ax.set(xlabel='Predicted', ylabel='Actual', title='Predicted vs Actual')
ax.legend();