Copyright 2023-2023 Lawrence Livermore National Security, LLC and other MuyGPyS
Project Developers. See the top-level COPYRIGHT file for details.

SPDX-License-Identifier: MIT

# Shear Kernel Tutorial

This notebook demonstrates how to use the specialized lensing shear kernel (hard-coded to RBF at the moment).

⚠️ _Note that this is still an experimental feature._ ⚠️

In [None]:
import copy
import matplotlib.pyplot as plt
import numpy as np
from matplotlib import cm
from matplotlib.colors import LogNorm, SymLogNorm

from MuyGPyS._test.shear import (
    conventional_Kout,
    conventional_mean,
    conventional_variance,
    conventional_shear,
    targets_from_GP,
)
from MuyGPyS.gp import MuyGPS
from MuyGPyS.gp.deformation import DifferenceIsotropy, F2
from MuyGPyS.gp.hyperparameter import Parameter
from MuyGPyS.gp.kernels.experimental import ShearKernel
from MuyGPyS.neighbors import NN_Wrapper
from MuyGPyS.gp.noise import HomoscedasticNoise, ShearNoise33

We will set a random seed here for consistency when building docs.
In practice we would not fix a seed.

In [None]:
np.random.seed(0)

## Kernel Implementation Comparisons

Here we will compare the analytic implementation of the kernel function to the `MuyGPyS` implementation, using some simple data.

First we build some simple data, which is meant to represent a grid of sky coordinates.

In [None]:
n = 25  # number of galaxies on a side
xmin = 0
xmax = 1
ymin = 0
ymax = 1

xx = np.linspace(xmin, xmax, n)
yy = np.linspace(ymin, ymax, n)

x, y = np.meshgrid(xx, yy)
features = np.vstack((x.flatten(), y.flatten())).T
data_count = features.shape[0]
length_scale = 0.05
shear_model = MuyGPS(
        kernel=ShearKernel(
            deformation=DifferenceIsotropy(
                F2,
                length_scale=Parameter(length_scale),
            ),
        ),
        noise = ShearNoise33(1e-4),
)
diffs = shear_model.kernel.deformation.pairwise_tensor(features, np.arange(data_count))

Define some colormaps for plotting convenience.

In [None]:
# `matplotlib.cm.get_cmap` was removed in matplotlib 3.9; `plt.get_cmap` works across versions
my_cmap = copy.copy(plt.get_cmap('viridis'))
my_cmap.set_bad("white")
my_sym_cmap = copy.copy(plt.get_cmap('coolwarm'))
my_sym_cmap.set_bad((0, 0, 0))

Use an Isotropic distance functor.

### Pairwise kernels (`Kin`)
This code computes the `Kin` kernels.

In [None]:
Kin_analytic = conventional_shear(features, length_scale=length_scale)

Here we do the same using the MuyGPyS implementation. Note the increased efficiency.

In [None]:
Kin_muygps = shear_model.kernel(diffs)

`Kin_muygps` is a more generalized tensor, so we need to flatten it to a conforming shape.

In [None]:
Kin_flat = Kin_muygps.reshape(data_count * 3, data_count * 3)

In [None]:
print(f"Kin_analytic.shape = {Kin_analytic.shape}")
print(f"Kin_muygps.shape = {Kin_muygps.shape}")
print(f"Kin_flat.shape = {Kin_flat.shape}")

Do the two implementations agree?

In [None]:
np.allclose(Kin_analytic, Kin_flat)

In [None]:
Kin_residual = np.abs(Kin_analytic - Kin_flat)
print(f"Kin residual max: {np.max(Kin_residual)}, min: {np.min(Kin_residual)}, mean : {np.mean(Kin_residual)}")
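
`np.allclose(a, b)` passes when `|a - b| <= atol + rtol * |b|` holds elementwise, with defaults `rtol=1e-5` and `atol=1e-8`, so it helps to read it alongside the raw residuals. The following is a minimal sketch on made-up arrays; the names `reference` and `perturbed` are illustrative and not part of this notebook.

```python
import numpy as np

reference = np.array([1.0, 1e-3, 1e-9])
perturbed = reference + 1e-9  # a uniform absolute error of about 1e-9

# Default tolerances: |a - b| <= atol + rtol * |b| with rtol=1e-5, atol=1e-8.
default_close = np.allclose(perturbed, reference)

# A purely relative check is much stricter for entries near zero.
strict_close = np.allclose(perturbed, reference, rtol=1e-5, atol=0.0)

max_residual = np.max(np.abs(perturbed - reference))
print(default_close, strict_close, max_residual)
```

Because kernel entries far from the diagonal decay toward zero, the absolute tolerance dominates there, which is why the maximum residual is reported as well.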

Plot results of the baseline and MuyGPyS implementations. 

In [None]:
fig, axes = plt.subplots(1, 3, figsize=(12, 4))
axes[0].set_title("original shear kernel")
axes[0].imshow(Kin_analytic)
axes[1].set_title("MuyGPyS shear kernel")
axes[1].imshow(Kin_flat)
axes[2].set_title("Residual")
im = axes[2].imshow(Kin_residual, norm=LogNorm(), cmap=my_cmap)
fig.colorbar(im, ax=axes[2])
plt.show()

### Cross-Covariance (`Kcross`)
Now we perform a similar analysis for the cross-covariance.

In [None]:
split = 200
X1 = features[:split]
X2 = features[split:]
n1, _ = X1.shape
n2, _ = X2.shape
crosswise_diffs = shear_model.kernel.deformation.crosswise_tensor(
    X1, X2, np.arange(n1), np.arange(n2)
)
print(X1.shape, X2.shape, crosswise_diffs.shape)

In [None]:
Kcross_analytic = conventional_shear(X1, X2, length_scale=length_scale)

In [None]:
Kcross_muygps = shear_model.kernel(crosswise_diffs, adjust=False)

In [None]:
Kcross_flat = Kcross_muygps.reshape(n1 * 3, n2 * 3)

In [None]:
print(f"Kcross_analytic.shape = {Kcross_analytic.shape}")
print(f"Kcross_muygps.shape = {Kcross_muygps.shape}")
print(f"Kcross_flat.shape = {Kcross_flat.shape}")

In [None]:
np.allclose(Kcross_analytic, Kcross_flat)

In [None]:
Kcross_residual = np.abs(Kcross_analytic - Kcross_flat)
print(f"Kcross residual max: {np.max(Kcross_residual)}, min: {np.min(Kcross_residual)}, mean : {np.mean(Kcross_residual)}")

Now we visualize the comparison.

In [None]:
fig, axes = plt.subplots(1, 3, figsize=(12, 3))
axes[0].set_title("original shear kernel")
axes[0].imshow(Kcross_analytic)
axes[1].set_title("MuyGPyS shear kernel")
axes[1].imshow(Kcross_flat)
axes[2].set_title("Residual")
im = axes[2].imshow(Kcross_residual, norm=LogNorm(), cmap=my_cmap)
fig.colorbar(im, ax=axes[2])
plt.show()

Runtime comparison of the two implementations (change `False` to `True` to run):

In [None]:
if False:
    %timeit conventional_shear(features, length_scale=length_scale)
    %timeit shear_model.kernel(diffs)

## Posterior Comparisons

Now we will test the `posterior_mean` and `posterior_variance` of the analytic and muygps implementations.
We first simulate a dataset.
Each target component is defined on a square grid representing a swath of sky.
Ultimately, the target array will have shape `(625, 3)`, given `n=25` above.

We will then evaluate the posterior means and variances of the analytic and `MuyGPyS` implementations.
We will plot both, along with their residuals (compared against both each other and the ground truth for the mean). 

We compute the posterior mean as
$$ \hat{Y}(X^*|X) = K_{\theta}(X^*,X)(K_{\theta}(X,X)+\epsilon)^{-1}Y(X) $$
where $\epsilon = \sigma^2 I$ with $\sigma^2$ being the noise variance.

The posterior variance is defined as
$$ \mathrm{Var}(\hat{Y}(X^*|X)) = K_{\theta}(X^*,X^*) - K_{\theta}(X^*,X)(K_{\theta}(X,X) + \epsilon)^{-1}K_{\theta}(X,X^*) $$
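
As a rough illustration of these two formulas, here is a minimal NumPy sketch of the conventional posterior mean and covariance. It uses a scalar RBF kernel as a stand-in for the shear kernel, and the helper names `rbf` and `gp_posterior` are hypothetical; this is not the implementation behind `conventional_mean` or `conventional_variance`.

```python
import numpy as np

def rbf(A, B, length_scale):
    """Scalar RBF kernel between the rows of A and B (stand-in for the shear kernel)."""
    sq_dists = np.sum((A[:, None, :] - B[None, :, :]) ** 2, axis=-1)
    return np.exp(-sq_dists / (2.0 * length_scale**2))

def gp_posterior(X_train, y_train, X_test, length_scale, noise_variance):
    # K(X, X) + epsilon, with epsilon = sigma^2 I
    Kin = rbf(X_train, X_train, length_scale) + noise_variance * np.eye(len(X_train))
    Kcross = rbf(X_test, X_train, length_scale)  # K(X*, X)
    Kout = rbf(X_test, X_test, length_scale)     # K(X*, X*)
    mean = Kcross @ np.linalg.solve(Kin, y_train)
    cov = Kout - Kcross @ np.linalg.solve(Kin, Kcross.T)
    return mean, cov

rng = np.random.default_rng(0)
X_train = rng.uniform(size=(30, 2))
y_train = np.sin(4.0 * X_train[:, 0]) + np.cos(3.0 * X_train[:, 1])

# Predict back onto the first five training points.
mean, cov = gp_posterior(X_train, y_train, X_train[:5], 0.3, 1e-4)
print(mean.shape, np.diag(cov))
```

When predicting at training locations, the posterior variance is bounded above by the noise variance, which makes a convenient sanity check.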

Set noise level

In [None]:
noise_prior = 1e-4

Define the target matrices.
This notebook originally used arbitrary targets; as of 1/25/24 the targets are sampled from a GP.

In [None]:
targets = targets_from_GP(features, n, length_scale, noise_prior)

Here we create a train/test split in the dataset.
Modify the `train_ratio` to specify the proportion of data to hold out for training.

In [None]:
train_ratio = 0.2

In [None]:
rng = np.random.default_rng(seed=1)
interval_count = int(data_count * train_ratio)
interval = int(data_count / interval_count)
sfl = rng.permutation(np.arange(data_count))
train_mask = np.zeros(data_count, dtype=bool)
for i in range(interval_count):
    idx = np.random.choice(sfl[i * interval : (i + 1) * interval])
    train_mask[idx] = True
test_mask = np.invert(train_mask)
train_count = np.count_nonzero(train_mask)
test_count = np.count_nonzero(test_mask)
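
The cell above draws one training index from each consecutive chunk of a shuffled index list. A minimal sketch of the same idea as a standalone function follows; the name `one_per_chunk_mask` is hypothetical, and unlike the cell above it uses a single `Generator` for both the shuffle and the draw, so it will not reproduce the same mask.

```python
import numpy as np

def one_per_chunk_mask(data_count, train_ratio, seed=1):
    """Boolean mask with exactly one True entry per chunk of shuffled indices."""
    rng = np.random.default_rng(seed)
    chunk_count = int(data_count * train_ratio)
    chunk_size = data_count // chunk_count
    shuffled = rng.permutation(data_count)
    mask = np.zeros(data_count, dtype=bool)
    for i in range(chunk_count):
        chunk = shuffled[i * chunk_size : (i + 1) * chunk_size]
        mask[rng.choice(chunk)] = True  # chunks are disjoint, so no index repeats
    return mask

train_mask_sketch = one_per_chunk_mask(625, 0.2)
print(np.count_nonzero(train_mask_sketch), np.count_nonzero(~train_mask_sketch))
```

Since the chunks are disjoint, the number of training points is exactly `int(data_count * train_ratio)`.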

In [None]:
train_targets = targets[train_mask, :]
test_targets = targets[test_mask, :]
train_features = features[train_mask, :]
test_features = features[test_mask, :]

Let's visualize the train/test datasets.

In [None]:
def make_im(vec, mask):
    ret = np.zeros(len(mask))
    ret[mask] = vec
    ret[np.invert(mask)] = -np.inf
    return ret.reshape(n, n)

In [None]:
fig, ax = plt.subplots(2, 3,figsize = (10,7))
ax[0, 0].imshow(make_im(train_targets[:,0], train_mask))
ax[0, 0].set_ylabel("train", fontsize = 15)
ax[0, 0].set_title(r"$\kappa$", fontsize = 15)
ax[1, 0].imshow(make_im(test_targets[:,0], test_mask))
ax[1, 0].set_ylabel("test", fontsize = 15)
ax[0, 1].imshow(make_im(train_targets[:,1], train_mask))
ax[0, 1].set_title("g1", fontsize = 15)
ax[1, 1].imshow(make_im(test_targets[:,1], test_mask))
ax[0, 2].imshow(make_im(train_targets[:,2], train_mask))
ax[0, 2].set_title("g2", fontsize = 15)
ax[1, 2].imshow(make_im(test_targets[:,2], test_mask))
plt.show()

Explicitly define the flattened target vectors.
Each vector is ordered by covariate, so that all $\kappa$ values come first, followed by all `g1` values and then all `g2` values.

In [None]:
train_targets_flat = train_targets.swapaxes(0, 1).reshape(3 * train_count)
test_targets_flat = test_targets.swapaxes(0, 1).reshape(3 * test_count)
print(train_targets_flat.shape, test_targets_flat.shape)
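
To make the ordering produced by `swapaxes(0, 1).reshape(...)` concrete, here is a small sketch on a toy array whose entries encode their own (row, column) position. The names `toy_targets`, `toy_flat`, and `restored` are illustrative only.

```python
import numpy as np

count = 4
# Toy targets with columns (kappa, g1, g2); entry 10 * i + j sits at row i, column j.
toy_targets = np.array([[10 * i + j for j in range(3)] for i in range(count)])

# Same flattening as above: covariate-major, i.e. column 0 first, then 1, then 2.
toy_flat = toy_targets.swapaxes(0, 1).reshape(3 * count)
print(toy_flat)

# The inverse operation recovers the original (count, 3) layout.
restored = toy_flat.reshape(3, count).T
```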

For the analytic implementation, we work with the full "flattened" kernel matrices.

In [None]:
Kin_analytic = conventional_shear(train_features, train_features, length_scale=length_scale)
Kcross_analytic = conventional_shear(test_features, train_features, length_scale=length_scale)
Kout_analytic = conventional_Kout(shear_model.kernel, test_count)

In [None]:
print(Kcross_analytic.shape, Kin_analytic.shape, Kout_analytic.shape, train_targets_flat.shape)

In [None]:
posterior_mean_analytic = conventional_mean(
    Kin_analytic,
    Kcross_analytic,
    train_targets_flat,
    noise_prior,
)
posterior_variance_analytic = conventional_variance(
    Kin_analytic, 
    Kcross_analytic,
    Kout_analytic,
    noise_prior,
)

In [None]:
print(posterior_mean_analytic.shape, posterior_variance_analytic.shape)

Create flat solve using MuyGPyS functions.
This should be very close to the analytic solution.

In [None]:
pairwise_diffs = shear_model.kernel.deformation.pairwise_tensor(
    train_features, np.arange(train_count)
)
crosswise_diffs = shear_model.kernel.deformation.crosswise_tensor(
    test_features, train_features, np.arange(test_count), np.arange(train_count)
)
Kin_muygps = shear_model.kernel(pairwise_diffs, adjust=False)
Kcross_muygps = shear_model.kernel(crosswise_diffs, adjust=False)
Kin_flat = Kin_muygps.reshape(3 * train_count, 3 * train_count)
Kcross_flat = Kcross_muygps.reshape(3 * test_count, 3 * train_count)

In [None]:
print(Kin_muygps.shape, Kcross_muygps.shape, Kin_flat.shape, Kcross_flat.shape)

Check that the flattened kernel tensors agree with the analytic tensors (should pass if the above passed).

In [None]:
print(
    np.allclose(Kin_analytic, Kin_flat),
    np.allclose(Kcross_analytic, Kcross_flat),
)

Plotting code.

In [None]:
def show_im(vec, mask, ax):
    mat = make_im(vec, mask)
    im = ax.imshow(mat, norm=LogNorm(), cmap=my_cmap)
    # attach the colorbar to the figure that owns `ax` rather than a global `fig`
    ax.figure.colorbar(im, ax=ax)

def compare_means(truth, first, second, fname, sname, fontsize=12, all_colorbar=False):
    f_residual = np.abs(truth - first) + 1e-15
    s_residual = np.abs(truth - second) + 1e-15
    fs_residual = np.abs(first - second) + 1e-15

    fig, ax = plt.subplots(6, 3, figsize = (10, 18))
    
    for axis_set in ax:
        for axis in axis_set:
            axis.set_xticks([])
            axis.set_yticks([])

    ax[0, 0].set_title(r"$\kappa$")
    ax[0, 1].set_title("g1")
    ax[0, 2].set_title("g2")
    ax[0, 0].set_ylabel("Truth", fontsize=fontsize)
    ax[1, 0].set_ylabel(f"{fname} Mean", fontsize=fontsize)
    ax[2, 0].set_ylabel(f"|truth - {fname}|", fontsize=fontsize)
    ax[3, 0].set_ylabel(f"{sname} Mean", fontsize=fontsize)
    ax[4, 0].set_ylabel(f"|truth - {sname}|", fontsize=fontsize)
    ax[5, 0].set_ylabel(f"|{fname} - {sname}|", fontsize=fontsize)

    # truth
    im00 = ax[0, 0].imshow(make_im(truth[:,0], test_mask))
    im01 = ax[0, 1].imshow(make_im(truth[:,1], test_mask))
    im02 = ax[0, 2].imshow(make_im(truth[:,2], test_mask))
    if all_colorbar is True:
        fig.colorbar(im00, ax=ax[0, 0])
        fig.colorbar(im01, ax=ax[0, 1])
        fig.colorbar(im02, ax=ax[0, 2])

    # first model
    im10 = ax[1, 0].imshow(make_im(first[:,0], test_mask))
    im11 = ax[1, 1].imshow(make_im(first[:,1], test_mask))
    im12 = ax[1, 2].imshow(make_im(first[:,2], test_mask))
    if all_colorbar is True:
        fig.colorbar(im10, ax=ax[1, 0])
        fig.colorbar(im11, ax=ax[1, 1])
        fig.colorbar(im12, ax=ax[1, 2])

    # first model residual
    show_im(f_residual[:,0], test_mask, ax=ax[2, 0])
    show_im(f_residual[:,1], test_mask, ax=ax[2, 1])
    show_im(f_residual[:,2], test_mask, ax=ax[2, 2])

    # second model
    im30 = ax[3, 0].imshow(make_im(second[:,0], test_mask))
    im31 = ax[3, 1].imshow(make_im(second[:,1], test_mask))
    im32 = ax[3, 2].imshow(make_im(second[:,2], test_mask))
    if all_colorbar is True:
        fig.colorbar(im30, ax=ax[3, 0])
        fig.colorbar(im31, ax=ax[3, 1])
        fig.colorbar(im32, ax=ax[3, 2])

    # second model residual
    show_im(s_residual[:, 0], test_mask, ax=ax[4, 0])
    show_im(s_residual[:, 1], test_mask, ax=ax[4, 1])
    show_im(s_residual[:, 2], test_mask, ax=ax[4, 2])

    # residual between the two models
    show_im(fs_residual[:, 0], test_mask, ax=ax[5, 0])
    show_im(fs_residual[:, 1], test_mask, ax=ax[5, 1])
    show_im(fs_residual[:, 2], test_mask, ax=ax[5, 2])

    plt.show()

Now we compute the flattened `MuyGPyS` conventional solution.

In [None]:
posterior_mean_flat = conventional_mean(
    Kin_flat,
    Kcross_flat,
    train_targets_flat,
    noise_prior,
)
posterior_variance_flat = conventional_variance(
    Kin_flat, 
    Kcross_flat, 
    Kout_analytic,
    noise_prior,
)

In [None]:
print(posterior_mean_flat.shape, posterior_variance_flat.shape)

And finally, numerically compare the posterior mean and posterior variance.
The flat and analytic solutions should be very close, up to ~1e-10 or so.

In [None]:
print(np.allclose(posterior_mean_flat, posterior_mean_analytic))
print(np.allclose(posterior_variance_flat, posterior_variance_analytic))

### Mean comparison

Here we plot the posterior mean of the analytic and flattened `MuyGPyS` implementations, conditioned on the full training data.
We also plot residuals with respect to the ground truth and each other.

In [None]:
compare_means(test_targets, posterior_mean_flat, posterior_mean_analytic, "flat", "analytic")

### Full covariance examination

Here we plot the full posterior covariance of each implementation and find that they agree.

In [None]:
residual = np.abs(posterior_variance_analytic - posterior_variance_flat)
fig, ax = plt.subplots(1,3,figsize = (14,4))

ax[0].imshow(posterior_variance_analytic)
ax[0].set_title("Analytic Variance")

ax[1].imshow(posterior_variance_flat)
ax[1].set_title("Flat MuyGPyS Variance")

ax[2].set_title("|Analytic - Flat MuyGPyS|")
im = ax[2].imshow(residual)
fig.colorbar(im, ax = ax[2])

print("Min Resid = ", np.min(residual), ", Max Resid = ", np.max(residual), ", Avg Residual = ", np.mean(residual))

In [None]:
def show_var_im(vec, mask, ax):
    mat = make_im(vec, mask)
    im = ax.imshow(mat, norm=SymLogNorm(linthresh=1e-7), cmap=my_sym_cmap)
    # attach the colorbar to the figure that owns `ax` rather than a global `fig`
    ax.figure.colorbar(im, ax=ax)

def compare_variances(first, second, fname, sname, fontsize=12):
    residual = np.abs(first - second) + 1e-15

    fig, ax = plt.subplots(9, 3, figsize = (8, 24))
    
    for axis_set in ax:
        for axis in axis_set:
            axis.set_xticks([])
            axis.set_yticks([])

    ax[0, 0].set_title(r"$\kappa$")
    ax[0, 1].set_title("g1")
    ax[0, 2].set_title("g2")
    ax[0, 0].set_ylabel(rf"{fname} $\kappa$", fontsize=fontsize)
    ax[1, 0].set_ylabel(f"{fname} g1", fontsize=fontsize)
    ax[2, 0].set_ylabel(f"{fname} g2", fontsize=fontsize)
    ax[3, 0].set_ylabel(rf"{sname} $\kappa$", fontsize=fontsize)
    ax[4, 0].set_ylabel(f"{sname} g1", fontsize=fontsize)
    ax[5, 0].set_ylabel(f"{sname} g2", fontsize=fontsize)
    ax[6, 0].set_ylabel(r"residual $\kappa$", fontsize=fontsize)
    ax[7, 0].set_ylabel("residual g1", fontsize=fontsize)
    ax[8, 0].set_ylabel("residual g2", fontsize=fontsize)

    # first variances
    show_im(first[:, 0, 0], test_mask, ax=ax[0, 0])
    show_var_im(first[:, 0, 1], test_mask, ax=ax[0, 1])
    show_var_im(first[:, 0, 2], test_mask, ax=ax[0, 2])
    show_var_im(first[:, 1, 0], test_mask, ax=ax[1, 0])
    show_im(first[:, 1, 1], test_mask, ax=ax[1, 1])
    show_var_im(first[:, 1, 2], test_mask, ax=ax[1, 2])
    show_var_im(first[:, 2, 0], test_mask, ax=ax[2, 0])
    show_var_im(first[:, 2, 1], test_mask, ax=ax[2, 1])
    show_im(first[:, 2, 2], test_mask, ax=ax[2, 2])

    # second variances
    show_im(second[:, 0, 0], test_mask, ax=ax[3, 0])
    show_var_im(second[:, 0, 1], test_mask, ax=ax[3, 1])
    show_var_im(second[:, 0, 2], test_mask, ax=ax[3, 2])
    show_var_im(second[:, 1, 0], test_mask, ax=ax[4, 0])
    show_im(second[:, 1, 1], test_mask, ax=ax[4, 1])
    show_var_im(second[:, 1, 2], test_mask, ax=ax[4, 2])
    show_var_im(second[:, 2, 0], test_mask, ax=ax[5, 0])
    show_var_im(second[:, 2, 1], test_mask, ax=ax[5, 1])
    show_im(second[:, 2, 2], test_mask, ax=ax[5, 2])

    # variance residuals
    show_im(residual[:, 0, 0], test_mask, ax=ax[6, 0])
    show_im(residual[:, 0, 1], test_mask, ax=ax[6, 1])
    show_im(residual[:, 0, 2], test_mask, ax=ax[6, 2])
    show_im(residual[:, 1, 0], test_mask, ax=ax[7, 0])
    show_im(residual[:, 1, 1], test_mask, ax=ax[7, 1])
    show_im(residual[:, 1, 2], test_mask, ax=ax[7, 2])
    show_im(residual[:, 2, 0], test_mask, ax=ax[8, 0])
    show_im(residual[:, 2, 1], test_mask, ax=ax[8, 1])
    show_im(residual[:, 2, 2], test_mask, ax=ax[8, 2])

    plt.tight_layout()
    plt.show()

In [None]:
def tensor_diagonalize(var):
    count = int(var.shape[0] / 3)
    ret = np.zeros((count, 3, 3))
    for i in range(3):
        for j in range(3):
            ret[:, i, j] = np.diagonal(
                var[
                    (count * i):(count * (i + 1)),
                    (count * j):(count * (j + 1)),
                ]
            )
    return ret
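
The function above extracts, for every test point, the 3x3 matrix of (co)variances between the three shear parameters from the covariate-major flattened covariance. As a sketch of an equivalent vectorized form, assuming the same block layout, the loop can be replaced by a reshape and an `einsum` diagonal; `tensor_diagonalize_loop` and `tensor_diagonalize_einsum` are illustrative names.

```python
import numpy as np

def tensor_diagonalize_loop(var):
    """Reference version: diagonal of each (count x count) block, as in the cell above."""
    count = var.shape[0] // 3
    ret = np.zeros((count, 3, 3))
    for i in range(3):
        for j in range(3):
            block = var[count * i : count * (i + 1), count * j : count * (j + 1)]
            ret[:, i, j] = np.diagonal(block)
    return ret

def tensor_diagonalize_einsum(var):
    """Vectorized version: view as (3, count, 3, count) and take the a == b diagonal."""
    count = var.shape[0] // 3
    blocks = var.reshape(3, count, 3, count)
    return np.einsum("iaja->aij", blocks)

rng = np.random.default_rng(0)
toy_var = rng.normal(size=(12, 12))
print(np.allclose(tensor_diagonalize_loop(toy_var), tensor_diagonalize_einsum(toy_var)))
```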

### Posterior variance comparison

Here we compare the posterior variances of the flattened `MuyGPyS` and analytic solutions, and compare their mutual residual.
For each setting we plot a 3x3 grid of plots showing the variance or residual of the corresponding pair of shear parameters at each test point.

In [None]:
compare_variances(
    tensor_diagonalize(posterior_variance_flat),
    tensor_diagonalize(posterior_variance_analytic),
    "flat",
    "analytic",
)

## MuyGPyS workflow

Here we'll compare an nn-sparsified `MuyGPyS` workflow to the conventional GP using the analytic kernel.
The two approaches should converge (up to ~1e-9 precision) when `nn_count == train_count`.
As `nn_count` decreases, the `MuyGPyS` workflow will get faster but will correspondingly drift from the conventional predictions.
The two solutions should still remain visually similar, however.
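
The convergence claim can be illustrated without `MuyGPyS` at all. The sketch below conditions each prediction only on its nearest training points, using a scalar RBF kernel and brute-force neighbor search as stand-ins; the helper names are hypothetical, and the noise level of `1e-2` is an assumption chosen to keep the toy problem well conditioned.

```python
import numpy as np

def rbf(A, B, length_scale=0.3):
    sq_dists = np.sum((A[:, None, :] - B[None, :, :]) ** 2, axis=-1)
    return np.exp(-sq_dists / (2.0 * length_scale**2))

def full_gp_mean(X_train, y_train, X_test, noise=1e-2):
    """Conventional GP mean conditioned on all training points."""
    Kin = rbf(X_train, X_train) + noise * np.eye(len(X_train))
    return rbf(X_test, X_train) @ np.linalg.solve(Kin, y_train)

def local_gp_mean(X_train, y_train, X_test, nn_count, noise=1e-2):
    """GP mean where each test point is conditioned on its nn_count nearest neighbors."""
    means = np.empty(len(X_test))
    for i, x in enumerate(X_test):
        nn = np.argsort(np.linalg.norm(X_train - x, axis=1))[:nn_count]
        Kin = rbf(X_train[nn], X_train[nn]) + noise * np.eye(nn_count)
        Kcross = rbf(x[None, :], X_train[nn])
        means[i] = (Kcross @ np.linalg.solve(Kin, y_train[nn]))[0]
    return means

rng = np.random.default_rng(0)
X_train = rng.uniform(size=(40, 2))
y_train = np.sin(4.0 * X_train[:, 0]) + np.cos(3.0 * X_train[:, 1])
X_test = rng.uniform(size=(10, 2))

dense = full_gp_mean(X_train, y_train, X_test)
all_neighbors = local_gp_mean(X_train, y_train, X_test, nn_count=len(X_train))
few_neighbors = local_gp_mean(X_train, y_train, X_test, nn_count=10)
print(np.max(np.abs(dense - all_neighbors)), np.max(np.abs(dense - few_neighbors)))
```

With every training point used as a neighbor, the local solve is the dense solve up to a permutation of the training set, so the two means differ only by floating-point error.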

In [None]:
def get_nn_tensors(nn_count=50):
    indices = np.arange(test_count)
    if nn_count == train_count:
        nn_indices = np.array([
            np.arange(train_count) for _ in range(test_count)
        ])
    else:
        nbrs_lookup = NN_Wrapper(train_features, nn_count, nn_method='exact', algorithm='ball_tree')
        nn_indices, _ = nbrs_lookup.get_nns(test_features)
    
    (
        crosswise_diffs,
        pairwise_diffs,
        nn_targets,
    ) = shear_model.make_predict_tensors(
        indices,
        nn_indices,
        test_features,
        train_features,
        train_targets,
    )

    nn_targets= nn_targets.swapaxes(-2, -1)
    
    x0_features = test_features[indices[0]][None, ...]
    x0_nn_features = train_features[nn_indices[0]]
    
    return crosswise_diffs, pairwise_diffs, x0_features, x0_nn_features, nn_targets

### This section consistency checks the various MuyGPyS tensors when `nn_count == train_count`

In [None]:
crosswise_diffs, pairwise_diffs, x0_features, x0_nn_features, x0_nn_targets = get_nn_tensors(nn_count=train_count)

Check that the `nn_targets` agree.

In [None]:
x0_targets_flat = x0_nn_targets.reshape(test_count, 3 * train_count)

In [None]:
print(x0_nn_targets.shape, x0_targets_flat.shape)

In [None]:
np.all([
    np.allclose(train_targets_flat, x0_targets_flat[i])
    for i in range(test_count)
])

Check that the `Kin`s agree.

In [None]:
Kin_test = shear_model.kernel(pairwise_diffs)

In [None]:
print(Kin_test.shape, Kin_flat.shape)

In [None]:
np.all([
    np.allclose(
        Kin_analytic,
        Kin_test[i].reshape(3 * train_count, 3 * train_count),
    ) for i in range(test_count)
])

Check that the `Kcross`es agree.

In [None]:
Kcross_test = shear_model.kernel(crosswise_diffs)

In [None]:
print(Kcross_test.shape, Kcross_analytic.shape)

In [None]:
np.all([
    np.allclose(
        np.squeeze(Kcross_analytic.reshape(3, test_count, 3 * train_count)[:, i, :]),
        Kcross_test[i].reshape(3 * train_count, 3).swapaxes(-2, -1),
    ) for i in range(test_count)
])

Here we check to see if the resulting means agree.

In [None]:
x0_analytic = np.array([
    conventional_mean(
        Kin_analytic,
        np.squeeze(Kcross_analytic.reshape(3, test_count, 3 * train_count)[:, i, :]),
        train_targets_flat,
        noise_prior,
    ) for i in range(test_count)
])

In [None]:
x0_test = np.array([
    conventional_mean(
        Kin_test[i].reshape(3 * train_count, 3 * train_count),
        Kcross_test[i].reshape(3 * train_count, 3).swapaxes(-2, -1),
        x0_targets_flat[i],
        noise_prior,
    ) for i in range(test_count)
])

In [None]:
np.allclose(x0_analytic, x0_test)

Here we define the full `MuyGPyS` workflow.
There are a lot of unnecessary internal checks and prints that were used for debugging.

In [None]:
def muygps_mean_workflow(nn_count=50):
    crosswise_diffs, pairwise_diffs, x0_features, x0_nn_features, nn_targets = get_nn_tensors(nn_count=nn_count)

    Kcross = shear_model.kernel(crosswise_diffs)
    Kin = shear_model.kernel(pairwise_diffs)
    
    print(pairwise_diffs.shape, Kin.shape)
    
    Kin_flat = Kin.reshape(test_count, 3 * nn_count, 3 * nn_count)
    Kcross_flat = Kcross.reshape(test_count, 3 * nn_count, 3)
    nn_targets_flat = nn_targets.reshape(test_count, 3 * nn_count)
    
    Kin_an = conventional_shear(
        x0_nn_features,
        length_scale=length_scale,
    )
    Kcross_an = conventional_shear(
        x0_features,
        x0_nn_features,
        length_scale=length_scale,
    )
    Kout_an = conventional_Kout(shear_model.kernel, 1)

    # here we are consistency checking the tensors of each implementation
    print(f"Kin.shape = {Kin.shape}")
    print(f"Kcross.shape = {Kcross.shape}")
    print(f"nn_targets.shape = {nn_targets.shape}")
    print("----------")
    print(f"Kin_flat.shape = {Kin_flat.shape}")
    print(f"Kcross_flat.shape = {Kcross_flat.shape}")
    print(f"nn_targets_flat.shape = {nn_targets_flat.shape}")
    print("----------")
    print(f"Kin_an.shape = {Kin_an.shape}")
    print(f"Kcross_an.shape = {Kcross_an.shape}")
    print("----------")
    print("----------")
    print(f"Kin_flat[0] == Kin_an? {np.allclose(Kin_flat[0], Kin_an)}")
    print(f"Kcross_flat[0] == Kcross_an? {np.allclose(Kcross_flat[0], Kcross_an.swapaxes(-2, -1))}")
    
    mean = shear_model.posterior_mean(Kin, Kcross, nn_targets)
    variance = shear_model.posterior_variance(Kin, Kcross)

    # This is more spot checking to see whether and to what extent the different implementations
    # agree on a particular prediction.
    mean_flat = np.squeeze(conventional_mean(
        Kin_flat[0],
        Kcross_flat[0].swapaxes(-2, -1),
        nn_targets_flat[0],
        noise_prior,
    ))
    mean_an = np.squeeze(conventional_mean(
        Kin_an,
        Kcross_an,
        nn_targets_flat[0],
        noise_prior,
    ))
    variance_flat = np.squeeze(conventional_variance(
        Kin_flat[0],
        Kcross_flat[0].swapaxes(-2, -1),
        Kout_an,
        noise_prior,
    ))
    variance_an = np.squeeze(conventional_variance(
        Kin_an,
        Kcross_an,
        Kout_an,
        noise_prior,
    ))
    
    print("----------")
    print("----------")
    print(f"mean.shape = {mean.shape}")
    print(f"mean_flat.shape = {mean_flat.shape}")
    print(f"mean_an.shape = {mean_an.shape}")
    print("----------")
    print(f"mean_flat == mean_an? {np.allclose(mean_flat, mean_an)}")
    print(f"mean[0] == mean_flat? {np.allclose(mean[0], mean_flat)}")
    print(f"mean[0] == mean_an? {np.allclose(mean[0], mean_an)}")
    print("----------")
    print("----------")
    print(f"variance.shape = {variance.shape}")
    print(f"variance_flat.shape = {variance_flat.shape}")
    print(f"variance_an.shape = {variance_an.shape}")
    print("----------")
    print("----------")
    print(f"variance_flat == variance_an? {np.allclose(variance_flat, variance_an)}")
    print(f"variance[0] == variance_flat? {np.allclose(variance[0], variance_flat)}")
    print(f"variance[0] == variance_an? {np.allclose(variance[0], variance_an)}")
#     print(mean[0])
#     print(posterior_mean_analytic[:, 0])
#     print(posterior_mean_flat[:, 0])
    
    return mean, variance

Here we compute the MuyGPs posterior mean.
If `nn_count == train_count`, the results should agree with the analytic/flat solutions.
Smaller `nn_count`s will drift (as expected).

In [None]:
posterior_mean_muygps, posterior_variance_muygps = muygps_mean_workflow(nn_count=50)

Check numerically if things are close.

In [None]:
print(f"mean is all close? {np.allclose(posterior_mean_analytic, posterior_mean_muygps)}")
print(f"variance is all close? {np.allclose(tensor_diagonalize(posterior_variance_analytic), posterior_variance_muygps)}")

Check the mean absolute error.

In [None]:
print(f"mean ME: {np.mean(np.abs(posterior_mean_analytic - posterior_mean_muygps))}")
print(f"variance ME: {np.mean(np.abs(tensor_diagonalize(posterior_variance_analytic) - posterior_variance_muygps))}")

### Mean comparison

Finally, we compare the `MuyGPyS` predictions conditioned on the specified number of neighbors to the conventional GP predictions.

In [None]:
compare_means(test_targets, posterior_mean_muygps, posterior_mean_analytic, "MuyGPyS", "Analytic")

### Posterior variance comparison

We follow it up by comparing the posterior variances as before, still conditioned on the specified number of nearest neighbors.

In [None]:
compare_variances(
    posterior_variance_muygps,
    tensor_diagonalize(posterior_variance_analytic),
    "MuyGPyS",
    "analytic",
)

## Optimization Test

The next step is to test the optimizer on the mock data generated above.
Recall that the mock data are sampled with a GP given a `length_scale`, which means that if we run hyperparameter optimization, we *should* recover this length scale if the optimizer is working properly.
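
To illustrate the recovery idea in isolation, the following sketch samples data from a scalar RBF GP and then scans candidate length scales for the smallest negative log marginal likelihood. This is a deliberately simplified stand-in: it uses a grid search rather than Bayesian optimization, a scalar kernel rather than the shear kernel, and none of the `MuyGPyS` loss functions. All names and values are assumptions made for the example.

```python
import numpy as np

def rbf(A, B, length_scale):
    sq_dists = np.sum((A[:, None, :] - B[None, :, :]) ** 2, axis=-1)
    return np.exp(-sq_dists / (2.0 * length_scale**2))

def neg_log_likelihood(X, y, length_scale, noise):
    """Negative log marginal likelihood of a zero-mean GP, up to an additive constant."""
    K = rbf(X, X, length_scale) + noise * np.eye(len(X))
    _, logdet = np.linalg.slogdet(K)
    return 0.5 * (y @ np.linalg.solve(K, y) + logdet)

rng = np.random.default_rng(0)
X = rng.uniform(size=(80, 2))
true_length_scale, noise = 0.2, 1e-2

# Sample targets from the GP prior (plus noise) via a Cholesky factor.
K_true = rbf(X, X, true_length_scale) + noise * np.eye(len(X))
y = np.linalg.cholesky(K_true) @ rng.standard_normal(len(X))

candidates = np.linspace(0.05, 0.9, 35)
losses = [neg_log_likelihood(X, y, ls, noise) for ls in candidates]
best = candidates[int(np.argmin(losses))]
print(f"true length scale: {true_length_scale}, grid-search estimate: {best:.3f}")
```

The estimate is only expected to land near the true value, since it is computed from a single finite sample.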

First, create the train tensors.

In [None]:
# from MuyGPyS.gp.tensors import crosswise_tensor, pairwise_tensor
from MuyGPyS.optimize.batch import sample_batch
from MuyGPyS.optimize import Bayes_optimize
from MuyGPyS.optimize.loss import lool_fn, looph_fn, mse_fn
from MuyGPyS.gp.hyperparameter import AnalyticScale

In [None]:
print(train_targets.shape, train_features.shape)

In [None]:
shear_model = MuyGPS(
    kernel=ShearKernel(
        deformation=DifferenceIsotropy(
            F2,
            length_scale=Parameter(0.5, [0.01, 0.9]),
        ),
    ),
    noise=ShearNoise33(noise_prior),
    scale=AnalyticScale(),
)

train_features_count = train_features.shape[0]

nn_count = 50
nbrs_lookup = NN_Wrapper(train_features, nn_count, nn_method='exact', algorithm='ball_tree')
    

batch_count=500
batch_indices, batch_nn_indices = sample_batch(
    nbrs_lookup, batch_count, train_features_count
)

# need pairwise and crosswise diffs
batch_crosswise_diffs = shear_model.kernel.deformation.crosswise_tensor(
    train_features,
    train_features,
    batch_indices,
    batch_nn_indices,
)

batch_pairwise_diffs = shear_model.kernel.deformation.pairwise_tensor(
    train_features, batch_nn_indices
)

In [None]:
batch_targets = train_targets[batch_indices]
batch_nn_targets= train_targets[batch_nn_indices].swapaxes(-2, -1)

See if the optimization correctly predicts the length scale with the `mse` loss fn.

In [None]:
shear_mse_optimized = Bayes_optimize(
    shear_model,
    batch_targets,
    batch_nn_targets,
    batch_crosswise_diffs,
    batch_pairwise_diffs,
    train_targets,
    loss_fn=mse_fn,
    verbose=True,
    init_points=5,
    n_iter=20,
)

In [None]:
test_features_count = test_features.shape[0]

indices = np.arange(test_features_count)
test_nn_indices, _ = nbrs_lookup.get_nns(test_features)

(
    test_crosswise_diffs,
    test_pairwise_diffs,
    test_nn_targets,
) = shear_model.make_predict_tensors(
    indices,
    test_nn_indices,
    test_features,
    train_features,
    train_targets,
)

test_nn_targets= test_nn_targets.swapaxes(-2, -1)

In [None]:
Kcross = shear_mse_optimized.kernel(test_crosswise_diffs)
Kin = shear_mse_optimized.kernel(test_pairwise_diffs)
posterior_mean_muygps_optimized = shear_mse_optimized.posterior_mean(Kin, Kcross, test_nn_targets)
posterior_variance_muygps_optimized = shear_mse_optimized.posterior_variance(Kin, Kcross)

In [None]:
compare_means(test_targets, posterior_mean_muygps_optimized, posterior_mean_analytic, "Optimized MuyGPyS", "Analytic")

And finally, a presentation of the posterior variance.

In [None]:
compare_variances(
    posterior_variance_muygps_optimized,
    tensor_diagonalize(posterior_variance_analytic),
    "Optimized MuyGPyS",
    "analytic",
)

## 2in3out Exploration

Here we explore the 2in3out variant of the shear kernel, which trains on observations only of `g1` and `g2`, but predicts onto all three covariates.

In [None]:
from MuyGPyS.gp.kernels.experimental import ShearKernel2in3out

In [None]:
Kin_analytic_33 = Kin_analytic
Kcross_analytic_33 = Kcross_analytic
targets_flat_33 = train_targets_flat
mean_analytic_33 = posterior_mean_analytic
variance_analytic_33 = posterior_variance_analytic
variance_diag_33 = np.diag(variance_analytic_33)
ci_analytic_33 = np.sqrt(variance_diag_33) * 1.96
# the flattened variance is ordered by covariate (all kappa, then g1, then g2)
ci_analytic_33 = ci_analytic_33.reshape(3, test_count).T
coverage_analytic_33 = (
    np.count_nonzero(
        np.abs(test_targets - mean_analytic_33) < ci_analytic_33, axis=0
    ) / test_count
)

In [None]:
Kin_analytic_23 = Kin_analytic[train_count:, train_count:]
Kcross_analytic_23 = Kcross_analytic[:, train_count:]
targets_flat_23 = train_targets_flat[train_count:] 
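
The slices above rely on the flattened tensors being ordered by covariate, so that the first `train_count` rows and columns belong to $\kappa$. A toy sketch of that bookkeeping follows; the arrays and the `labels` vector are made up for illustration.

```python
import numpy as np

count = 3  # toy number of training points
# Covariate associated with each flattened training row/column.
labels = np.repeat(np.array(["kappa", "g1", "g2"]), count)

Kin_toy = np.arange((3 * count) ** 2, dtype=float).reshape(3 * count, 3 * count)
Kcross_toy = np.arange(6 * 3 * count, dtype=float).reshape(6, 3 * count)

# Drop the kappa block of the training dimension; keep only g1 and g2 observations.
Kin_23_toy = Kin_toy[count:, count:]
Kcross_23_toy = Kcross_toy[:, count:]  # test rows are untouched: still 3 outputs per point
kept = labels[count:]
print(Kin_23_toy.shape, Kcross_23_toy.shape, kept)
```

Only the training dimension loses its $\kappa$ block, which is why the model still predicts all three covariates.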

In [None]:
mean_analytic_23 = conventional_mean(
    Kin_analytic_23, Kcross_analytic_23, targets_flat_23, noise_prior
)
variance_analytic_23 = conventional_variance(
    Kin_analytic_23, Kcross_analytic_23, Kout_analytic, noise_prior
)
variance_diag_23 = np.diag(variance_analytic_23)
ci_analytic_23 = np.sqrt(variance_diag_23) * 1.96
# the flattened variance is ordered by covariate (all kappa, then g1, then g2)
ci_analytic_23 = ci_analytic_23.reshape(3, test_count).T
coverage_analytic_23 = (
    np.count_nonzero(
        np.abs(test_targets - mean_analytic_23) < ci_analytic_23, axis=0
    ) / test_count
)

In [None]:
print(
    mean_analytic_33.shape, Kin_analytic_33.shape, Kcross_analytic_33.shape, targets_flat_33.shape
)
print(
    mean_analytic_23.shape, Kin_analytic_23.shape, Kcross_analytic_23.shape, targets_flat_23.shape
)

In [None]:
compare_means(test_targets, mean_analytic_23, mean_analytic_33, "2x3 Model", "3x3 Model", all_colorbar=True)

Here we do the same thing with the 2x3 MuyGPyS implementation.

In [None]:
model33 = MuyGPS(
    kernel=ShearKernel(
        deformation=DifferenceIsotropy(
            F2,
            length_scale=Parameter(shear_mse_optimized.kernel.deformation.length_scale()),
        ),
    ),
    noise=ShearNoise33(noise_prior),
)

In [None]:
model23 = MuyGPS(
    kernel=ShearKernel2in3out(
        deformation=DifferenceIsotropy(
            F2,
            length_scale=Parameter(shear_mse_optimized.kernel.deformation.length_scale()),
        ),
    ),
    noise=HomoscedasticNoise(noise_prior),
)

In [None]:
Kcross_33 = model33.kernel(test_crosswise_diffs)
Kin_33 = model33.kernel(test_pairwise_diffs)
nn_targets_33 = test_nn_targets
mean_muygps_33 = model33.posterior_mean(Kin_33, Kcross_33, nn_targets_33)
covariance_muygps_33 = model33.posterior_variance(Kin_33, Kcross_33)
variance_muygps_33 = np.zeros((test_count, 3))
for i in range(test_count):
    variance_muygps_33[i, 0] = covariance_muygps_33[i, 0, 0] 
    variance_muygps_33[i, 1] = covariance_muygps_33[i, 1, 1] 
    variance_muygps_33[i, 2] = covariance_muygps_33[i, 2, 2] 
ci_muygps_33 = np.sqrt(variance_muygps_33) * 1.96
coverage_muygps_33 = (
    np.count_nonzero(
        np.abs(test_targets - mean_muygps_33) < ci_muygps_33, axis=0
    ) / test_count
)

In [None]:
Kcross_23 = model23.kernel(test_crosswise_diffs)
Kin_23 = model23.kernel(test_pairwise_diffs)
nn_targets_23 = test_nn_targets[:, 1:, :]
mean_muygps_23 = model23.posterior_mean(Kin_23, Kcross_23, nn_targets_23)
covariance_muygps_23 = model23.posterior_variance(Kin_23, Kcross_23)
variance_muygps_23 = np.zeros((test_count, 3))
for i in range(test_count):
    variance_muygps_23[i, 0] = covariance_muygps_23[i, 0, 0] 
    variance_muygps_23[i, 1] = covariance_muygps_23[i, 1, 1] 
    variance_muygps_23[i, 2] = covariance_muygps_23[i, 2, 2] 
ci_muygps_23 = np.sqrt(variance_muygps_23) * 1.96
coverage_muygps_23 = (
    np.count_nonzero(
        np.abs(test_targets - mean_muygps_23) < ci_muygps_23, axis=0
    ) / test_count
)

In [None]:
print(
    mean_muygps_23.shape, Kin_23.shape, Kcross_23.shape, nn_targets_23.shape
)

Here we compare the 2x3 MuyGPyS implementation to the 2x3 analytic model.

In [None]:
compare_means(test_targets, mean_analytic_23, mean_muygps_23, "2x3 Analytic", "2x3 MuyGPs", all_colorbar=True)

Here we compare the 2x3 MuyGPyS implementation to the 3x3 Analytic model.
Note that the 3x3 analytic model is a better approximation to the truth, but both look at least visually ok.

In [None]:
compare_means(test_targets, mean_analytic_33, mean_muygps_23, "3x3 Analytic", "2x3 MuyGPs", all_colorbar=True)

In [None]:
variance_muygps_23.shape

In [None]:
compare_variances(
    covariance_muygps_23,
    tensor_diagonalize(variance_analytic_33),
    "2x3 MuyGPyS",
    "3x3 analytic",
)

Now let's look at the covariate-wise mean 95% confidence interval half-widths and the empirical coverage.

In [None]:
mean_ci_a33 = np.mean(ci_analytic_33, axis=0)
mean_ci_a23 = np.mean(ci_analytic_23, axis=0)
mean_ci_m33 = np.mean(ci_muygps_33, axis=0)
mean_ci_m23 = np.mean(ci_muygps_23, axis=0)

In [None]:
print("\t\tconvergence\t\tshear 1\t\t\tshear 2")
print(f"dense 3x3\t{mean_ci_a33[0]}\t{mean_ci_a33[1]}\t{mean_ci_a33[2]}")
print(f"muygps 3x3\t{mean_ci_m33[0]}\t{mean_ci_m33[1]}\t{mean_ci_m33[2]}")
print(f"dense 2x3\t{mean_ci_a23[0]}\t{mean_ci_a23[1]}\t{mean_ci_a23[2]}")
print(f"muygps 2x3\t{mean_ci_m23[0]}\t{mean_ci_m23[1]}\t{mean_ci_m23[2]}")

In [None]:
print("\t\tconvergence\tshear 1\t\tshear 2")
print(f"dense 3x3\t{coverage_analytic_33[0]}\t\t{coverage_analytic_33[1]}\t\t{coverage_analytic_33[2]}")
print(f"muygps 3x3\t{coverage_muygps_33[0]}\t\t{coverage_muygps_33[1]}\t\t{coverage_muygps_33[2]}")
print(f"dense 2x3\t{coverage_analytic_23[0]}\t\t{coverage_analytic_23[1]}\t\t{coverage_analytic_23[2]}")
print(f"muygps 2x3\t{coverage_muygps_23[0]}\t\t{coverage_muygps_23[1]}\t\t{coverage_muygps_23[2]}")
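
For reference, the coverage statistic used throughout this section is the fraction of test points whose true value falls within `1.96` posterior standard deviations of the predicted mean. A minimal sketch on synthetic, perfectly calibrated Gaussian data follows; the function name `empirical_coverage` and the data are assumptions made for the example.

```python
import numpy as np

def empirical_coverage(truth, mean, variance, z=1.96):
    """Per-covariate fraction of points inside the mean +/- z * std interval."""
    half_width = z * np.sqrt(variance)
    inside = np.abs(truth - mean) < half_width
    return np.count_nonzero(inside, axis=0) / truth.shape[0]

rng = np.random.default_rng(0)
sample_count = 20000
mean = np.zeros((sample_count, 3))
variance = np.full((sample_count, 3), 0.25)

# Truth drawn from exactly the predicted distribution, i.e. a calibrated model.
truth = mean + np.sqrt(variance) * rng.standard_normal((sample_count, 3))

coverage = empirical_coverage(truth, mean, variance)
print(coverage)
```

A well-calibrated model should give values near 0.95 for each covariate; values well below that suggest overconfident variances, and values near 1 suggest overly wide intervals.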