In [None]:
from pathlib import Path
import tomli
import jax
import jax.random as jr
from jaxlinop import identity
from jaxutils import Dataset
from geometric_asymptotics.io import load_space
from geometric_asymptotics.experiment import generate_kernel_and_training_data, compute_predictions, compute_extrinsic_length_scale, compute_non_asymptotic_expected_error
from geometric_asymptotics.plot_mpl import plot_kernel, plot_data, plot_errors
# Switch JAX to 64-bit precision; without this flag JAX defaults to 32-bit floats.
jax.config.update("jax_enable_x64", True)

# Load mesh 

Read the experiment configuration, then load the mesh and assemble its adjacency matrix and graph.

In [None]:
# The experiment settings are read from a TOML file located relative to the
# parent of the current working directory.
config_path = Path.cwd() / ".." / "config" / "dumbbell.toml"
with config_path.open("rb") as config_file:  # tomli only accepts binary file handles
    config = tomli.load(config_file)

# Seed used to build the root JAX PRNG key further down.
seed = 0

In [None]:
# The location of the space is taken from the "source" entry of the config and
# resolved relative to the parent of the current working directory.
space_path = Path.cwd() / ".." / config["source"]
graph, coordinates = load_space(space_path)

# Generate full training data

Generate x-values by sampling random points on the mesh, and y-values by sampling from the intrinsic prior.

In [None]:
# Root PRNG key; a subkey is split off for this stage and `key` is carried forward.
key = jr.PRNGKey(seed)
k1, key = jr.split(key)
kernels, training_data, test_data = generate_kernel_and_training_data(
    graph, coordinates, config, k1
)

# Dense Gram matrix of kernels[0] (evaluated with kernels[2] as its parameters)
# restricted to the first `num_plot` rows of test_data[0].
# NOTE(review): kernels[0]/test_data[0] look like the intrinsic kernel and inputs — confirm.
gram_inputs = test_data[0][:config["num_plot"], :]
gram_matrix = kernels[0].gram(kernels[2], gram_inputs)
plot_kernel(gram_matrix.to_dense())

# Training targets drawn over the training locations.
plot_data(training_data[1], training_data[2], lims=(-2, 2))

# Predictions

Compute predictions for both the intrinsic and extrinsic models using a subset of the data points, as a sanity check.

In [None]:
# Split off a dedicated subkey for the prediction step.
k2, key = jr.split(key)
pred_mean, test_pred_mean, test_pred_var, _ = compute_predictions(
    kernels, training_data, test_data, config, k2
)

# `pred_mean` is unpacked into the positional arguments of plot_data.
plot_data(*pred_mean, lims=(-2, 2))

In [None]:
# Colour limits for the two plots: predictive mean first, predictive variance second.
test_mean_lims = (-2, 2)
test_var_lims = (0, 1)
plot_data(*test_pred_mean, lims=test_mean_lims)
plot_data(*test_pred_var, lims=test_var_lims)

# Extrinsic length scale

Since the extrinsic and intrinsic processes have length scales that are defined only up to some unknown multiplicative constant, we need to find an extrinsic length scale that matches the intrinsic one.

In [None]:
# Split off a dedicated subkey for the length-scale computation.
k3, key = jr.split(key)
(
    kappa_extr,
    posterior_ls,
    learned_params_ls,
) = compute_extrinsic_length_scale(kernels, training_data, config, k3)

# Report the length scale returned by the computation.
print("extrinsic optimized length scale:", kappa_extr)

In [None]:
# Dataset built from the first `subset_size_int` training rows; the targets are
# reshaped into a single column.
subset_size_ls = config["subset_size_int"]
subset_inputs_ls = training_data[1][:subset_size_ls, :]
subset_targets_ls = training_data[2][:subset_size_ls, :].reshape(-1, 1)
data_ls_matching = Dataset(X=subset_inputs_ls, y=subset_targets_ls)

# Dense Gram matrix of kernels[1] under the learned kernel parameters, restricted
# to the first `num_plot` rows of test_data[1].
gram_inputs_ls = test_data[1][:config["num_plot"], :]
gram_matrix_ls = kernels[1].gram(learned_params_ls['kernel'], gram_inputs_ls)
plot_kernel(gram_matrix_ls.to_dense())

In [None]:
# Evaluate the posterior at the test inputs once and reuse the resulting
# predictive distribution for both plots; calling `predict(...)(...)` separately
# for the mean and for the variance repeats the same conditioning on
# `data_ls_matching`.
predictive_dist_ls = posterior_ls.predict(learned_params_ls, data_ls_matching)(test_data[1])

# Predictive mean, then predictive variance, over the test inputs.
plot_data(test_data[1], predictive_dist_ls.mean(), lims=(-2, 2))
plot_data(test_data[1], predictive_dist_ls.variance(), lims=(0, 1))

# Non-asymptotic expected error

Finally, we condition on subsets of the data and compute the expected error for the two models considered.

In [None]:
# Split off a dedicated subkey for the expected-error computation.
k4, key = jr.split(key)
output = compute_non_asymptotic_expected_error(
    kernels,
    training_data,
    test_data,
    coordinates.shape[0],  # number of rows in `coordinates`
    kappa_extr,
    config,
    k4,
)

# Both plots share the same index entry of `output`.
error_idxs = output["expected_error_idxs"]

# First plot: the three expected-error entries.
plot_errors(
    error_idxs,
    output["expected_errors"],
    output["expected_errors_extr"],
    output["expected_errors_approx_intr"],
)
# Second plot: the matching "_std" entries.
plot_errors(
    error_idxs,
    output["expected_errors_std"],
    output["expected_errors_extr_std"],
    output["expected_errors_approx_intr_std"],
)

# Plotting output

This code produces the output used for the figures in the paper.

In [None]:
# Two subkeys for the two samples, plus the key carried forward.
k5, k6, key = jr.split(key, 3)

num_points = coordinates.shape[0]
noise_variance = config["noise_variance"]

# Sample for kernels[0]: root of (Gram matrix + noise_variance * identity) applied
# to a standard normal column vector.
noisy_cov = kernels[0].gram(kernels[2], test_data[0]) + identity(num_points) * noise_variance
y_sample = noisy_cov.to_root() @ jr.normal(k5, (num_points, 1))

# Same construction for kernels[1], using the learned kernel parameters.
noisy_cov_extr = (
    kernels[1].gram(learned_params_ls['kernel'], test_data[1])
    + identity(num_points) * noise_variance
)
y_sample_extr = noisy_cov_extr.to_root() @ jr.normal(k6, (num_points, 1))

In [None]:
# Both sample figures are written to the "plot" directory next to the working directory.
sample_plot_dir = Path.cwd() / ".." / "plot"
plot_data(
    test_data[1],
    y_sample,
    lims=(-2, 2),
    colorbar=False,
    file_name=sample_plot_dir / "db_sample_intr.png",
)
plot_data(
    test_data[1],
    y_sample_extr,
    lims=(-2, 2),
    colorbar=False,
    file_name=sample_plot_dir / "db_sample_extr.png",
)

In [None]:
# Row index of the test point that the kernel is evaluated against.
kernel_point_idx = 864

# Cross-covariance between every test input and the selected point; indexing with
# `None` keeps the selected point as a single-row 2-D array.
anchor_intr = test_data[0][None, kernel_point_idx, :]
kernel_values = kernels[0].cross_covariance(kernels[2], test_data[0], anchor_intr)

anchor_extr = test_data[1][None, kernel_point_idx, :]
kernel_values_extr = kernels[1].cross_covariance(
    learned_params_ls['kernel'], test_data[1], anchor_extr
)

# Both figures are drawn over test_data[1], highlight the selected point and are
# written to the "plot" directory next to the working directory.
kernel_plot_dir = Path.cwd() / ".." / "plot"
highlight_point = test_data[1][kernel_point_idx, :]
plot_data(
    test_data[1],
    kernel_values,
    lims=(0, 1),
    colorbar=False,
    highlight=highlight_point,
    file_name=kernel_plot_dir / "db_kernel_intr.png",
)
plot_data(
    test_data[1],
    kernel_values_extr,
    lims=(0, 1),
    colorbar=False,
    highlight=highlight_point,
    file_name=kernel_plot_dir / "db_kernel_extr.png",
)