# Algorithm Illustration

We will visualize the different components of IterGP to better understand how it performs GP inference.

In [1]:
import matplotlib as mpl
import matplotlib.pyplot as plt

%matplotlib inline

In [17]:
from probnum import backend, randvars
from probnum.randprocs import kernels, mean_fns

from itergp import GaussianProcess, datasets, methods

## Dataset

We consider a synthetically generated dataset without noise, so that we can evaluate the label function $y(x) = \sin(\pi x^\top \mathbf{1})$.

In [63]:
# Generate a noise-free synthetic dataset (fixed seed for reproducibility)
rng_state = backend.random.rng_state(42)

num_data = 6
input_shape = ()  # scalar inputs
output_shape = ()  # scalar outputs

# Split off a dedicated key for the dataset. The dataset must consume the fresh
# `rng_state_data`, not the carried-forward `rng_state`; otherwise the split is
# pointless and any later use of `rng_state` would not be independent of the data.
rng_state, rng_state_data = backend.random.split(rng_state, num=2)
data = datasets.SyntheticDataset(
    rng_state=rng_state_data,
    size=(num_data, num_data),  # presumably (num_train, num_test) — TODO confirm
    input_shape=input_shape,
    output_shape=output_shape,
    noise_var=0.0,
)
X = data.train.X
y = data.train.y

## Gaussian Process Model

In [78]:
# Model: zero-mean GP prior with an ExpQuad kernel (lengthscale 0.2)
mean_fn = mean_fns.Zero(input_shape=input_shape, output_shape=output_shape)
kernel = kernels.ExpQuad(input_shape=input_shape, lengthscale=0.2)
gp = GaussianProcess(mean_fn, kernel)

# Observation noise: zero-mean Gaussian with covariance `sigma_sq * I`.
# The variance is zero here, matching the noise-free dataset (noise_var=0.0).
sigma_sq = 0.0
noise = randvars.Normal(
    mean=backend.zeros(y.shape),
    cov=sigma_sq * backend.eye(y.shape[0]),
)

## IterGP Inference

In [None]:
# IterGP variants to compare; each gets its own column in the figure below
approx_methods = [
    "IterGP-Chol",
    "IterGP-CG",
    "IterGP-PI",
]

# Pseudo-inputs for IterGP-PI: as many as training points, evenly spaced on [-1, 1]
pseudo_inputs = backend.linspace(-1, 1, len(X))

In [79]:
# Grid of panels: one column per approximation method, three rows
# (prediction, residual, action); x shared within a column, y within a row.
fig, axs = plt.subplots(
    nrows=3,
    ncols=len(approx_methods),
    sharex="col",
    sharey="row",
    figsize=(12, 6),
)

# Make the figure background fully transparent.
fig.patch.set_alpha(0.0)

# Close the figure so no static copy is rendered inline; it is shown via the
# animation further below.
plt.close()


def _shade_training_inputs(ax, inputs, width):
    """Draw a faint gray vertical band (total width `width / 2`) at each training input."""
    for x in inputs:
        ax.axvspan(
            xmin=x - 0.25 * width,
            xmax=x + 0.25 * width,
            color="gray",
            alpha=0.25,
            zorder=-10,
            lw=0.0,
        )


def _plot_signal(ax, x_grid, values_grid, x_data, values_data, color, bar_width):
    """Plot a function on a grid and highlight its values at the data points.

    Args:
        ax: Axes to draw on.
        x_grid: Dense grid of inputs.
        values_grid: Function values on `x_grid` (line plus translucent fill to zero).
        x_data: Training inputs.
        values_data: Function values at `x_data` (dots plus bars starting at zero).
        color: Matplotlib color used for all elements.
        bar_width: Width of the bars in data units.

    The bars are drawn in data coordinates, so they stay aligned with the curve
    when the y-limits change later (explicit `ylim`, or autoscaling of row-shared
    axes as further columns are drawn).
    """
    ax.fill_between(
        x=x_grid,
        y1=backend.zeros_like(x_grid),
        y2=values_grid,
        alpha=0.2,
        lw=0.0,
        color=color,
    )
    ax.plot(x_grid, values_grid, color=color)
    ax.scatter(x_data, values_data, color=color, marker=".")
    ax.axhline(y=0.0, color="black", linestyle="--", lw=0.5)
    ax.bar(
        backend.to_numpy(x_data),
        # `* 1.0` turns boolean indicator actions into numeric bar heights
        backend.to_numpy(values_data) * 1.0,
        width=bar_width,
        color=color,
        zorder=-5,
        lw=0.0,
    )


def animate(idxiter):
    """Redraw all panels for animation frame `idxiter`.

    Each column corresponds to one entry of `approx_methods`. The rows show:
      0. the GP posterior after `idxiter` iterations, the latent function and the data;
      1. the residual `data.fun(x) - posterior mean(x)`;
      2. the action of iteration `idxiter`. This row is drawn only while
         `idxiter < len(X)`; otherwise it is left empty apart from the data bands.

    Args:
        idxiter: Frame index, used as the iteration budget / number of pseudo-inputs.

    Raises:
        NotImplementedError: If `approx_methods` contains an unknown method name.
    """
    for ax in axs.flat:
        ax.cla()

    # Quantities shared by every panel of this frame
    Xnew = backend.linspace(-1, 1, 1000)
    data_width = 0.05 * (X.max() - X.min())
    bar_width = 0.5 * data_width

    for idxmethod, approx_method in enumerate(approx_methods):
        ax_pred, ax_res, ax_act = axs[:, idxmethod]

        # Latent function
        ax_pred.plot(Xnew, data.fun(Xnew), linestyle="--", color="black", lw=0.75)

        # Training data as gray bands in every row
        for ax in (ax_pred, ax_res, ax_act):
            _shade_training_inputs(ax, X, data_width)

        # Gaussian process approximation
        if approx_method == "IterGP-Chol":
            ameth = methods.Cholesky(maxrank=idxiter)
        elif approx_method == "IterGP-CG":
            ameth = methods.CG(maxiter=idxiter)
        elif approx_method == "IterGP-PI":
            ameth = methods.PseudoInput(pseudo_inputs=pseudo_inputs[0:idxiter])
        else:
            raise NotImplementedError(approx_method)

        gp_post = gp.condition_on_data(X, y, b=noise, approx_method=ameth)
        gp_post.plot(X=Xnew, data=(X, y), ax=ax_pred)
        ax_pred.set(ylabel="Prediction", ylim=(-2.01, 2.01))

        # Residual between latent function and posterior mean
        residual_fn = lambda x: data.fun(x) - gp_post.mean(x)
        _plot_signal(
            ax_res, Xnew, residual_fn(Xnew), X, residual_fn(X), "C3", bar_width
        )
        ax_res.set(ylabel="Residual", ylim=(-1.01, 1.01))

        # Action of the current iteration
        if idxiter < len(X):
            if approx_method == "IterGP-Chol":
                # Indicator of the training input X[idxiter]
                action_fn = lambda x: x == X[idxiter]
            elif approx_method == "IterGP-CG":
                action_fn = residual_fn
            elif approx_method == "IterGP-PI":
                # Kernel function centered at the current pseudo-input
                action_fn = lambda x: kernel(pseudo_inputs[idxiter], x)
            else:
                raise NotImplementedError(approx_method)

            _plot_signal(
                ax_act, Xnew, action_fn(Xnew), X, action_fn(X), "C4", bar_width
            )

        ax_act.set(xlabel="Input Space", ylabel="Action")

    # Align once, after every y-label of the frame has been set
    fig.align_ylabels()


from IPython.display import HTML
from matplotlib import animation

# Create animation: frames 0, 1, ..., num_data; 1250 ms between frames and
# a 4000 ms pause before the animation repeats.
anim = animation.FuncAnimation(
    fig,
    func=animate,
    frames=num_data + 1,
    interval=1250,
    repeat_delay=4000,
    blit=False,
)

# Render the animation as an interactive JavaScript/HTML widget
HTML(anim.to_jshtml())

### Observations

Notice how the action can be interpreted as _targeting computation_ towards certain datapoints. Different instances of IterGP differ in how the computation is targeted during a run of the algorithm. After $n$ iterations, where $n$ is the number of datapoints, the residual is zero at the datapoints, which is a consequence of IterGP being a conjugate direction method.