In [None]:
# UTILITY FUNCTIONS FOR CHAPTER 1

import ipywidgets as widgets
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
from IPython.display import HTML
from matplotlib import animation

# ---------------------------------
# % Animation of damped oscillator
# ---------------------------------


def animate_damped_oscillator(
    damped_oscillator_fn,
    t_anim,
    theta,
    fps=24,
    n_coils=12,
    amp=0.11,
    trace_len=200,
):
    """Render a two-panel animation of a damped oscillator as inline HTML.

    The left panel draws a mass on a zig-zag spring attached to a wall,
    together with a trail of its most recent positions. The right panel
    draws the displacement curve y(t) with a marker and a vertical line that
    track the current frame.

    Args:
        damped_oscillator_fn: Callable ``(t, theta) -> y`` giving the
            displacement at every time in ``t``.
        t_anim: Time points; one frame is rendered per entry.
        theta: Parameter object; only ``theta.A`` is read here, to size the axes.
        fps: Frames per second of the rendered animation.
        n_coils: Number of zig-zag coils used to draw the spring.
        amp: Vertical half-height of the spring's zig-zag.
        trace_len: Maximum number of past positions kept in the trail.

    Returns:
        ``IPython.display.HTML`` containing the JavaScript animation player.
    """
    from collections import deque

    y_anim = damped_oscillator_fn(t_anim, theta)
    t_anim_np = np.asarray(t_anim)
    y_anim_np = np.asarray(y_anim)

    def _spring_points(
        x_left: float, x_right: float, n_coils: int = n_coils, amp: float = amp
    ):
        """Return (xs, ys) vertices of a zig-zag spring from x_left to x_right."""
        xs = np.linspace(x_left, x_right, n_coils * 2 + 1)
        ys = np.zeros_like(xs)
        ys[1:-1:2] = amp  # odd vertices above the axis
        ys[2:-1:2] = -amp  # even interior vertices below the axis
        ys[0] = ys[-1] = 0.0  # both ends sit on the axis
        return xs, ys

    fig, (ax_left, ax_right) = plt.subplots(
        1, 2, figsize=(12, 3), gridspec_kw={"width_ratios": [1, 1.2]}
    )
    fig.suptitle("Damped Harmonic Oscillator — Mass Motion & Displacement", fontsize=14)

    # --- Left panel: wall, spring and mass
    A = theta.A
    ax_left.set_xlim(-A * 1.8 - 0.4, A * 1.8 + 0.4)
    ax_left.set_ylim(-0.8, 0.8)
    ax_left.set_xlabel("Position")
    ax_left.set_yticks([])
    ax_left.grid(True, axis="x")
    wall_x = -A * 1.8 - 0.2
    ax_left.plot([wall_x, wall_x], [-0.8, 0.8], linewidth=4)

    mass_w, mass_h = 0.28, 0.32
    mass = plt.Rectangle(
        (y_anim_np[0] - mass_w / 2, -mass_h / 2), mass_w, mass_h, ec="black", fc="C0"
    )
    ax_left.add_patch(mass)
    (spring_line,) = ax_left.plot([], [], lw=2)
    (trace_line,) = ax_left.plot([], [], lw=1, alpha=0.6)
    # Bounded trail of recent positions: appending beyond maxlen drops the
    # oldest entry in O(1) (a list with `del buf[0]` is O(n) per frame).
    trace_buf = deque(maxlen=max(int(trace_len), 0))

    # --- Right panel: displacement curve and time cursor
    ax_right.set_xlim(t_anim_np[0], t_anim_np[-1])
    ax_right.set_ylim(-A * 1.2, A * 1.2)
    ax_right.set_xlabel("Time t")
    ax_right.set_ylabel("Displacement y(t)")
    ax_right.grid(True)
    ax_right.plot(t_anim_np, y_anim_np, alpha=0.4, label="True displacement")
    (marker_dot,) = ax_right.plot([], [], "o", label="Current position")
    y0, y1 = ax_right.get_ylim()
    (time_line,) = ax_right.plot(
        [t_anim_np[0], t_anim_np[0]], [y0, y1], linestyle="--", alpha=0.6
    )
    ax_right.legend(loc="upper right")

    def init():
        # Clear the trail too, so re-rendering starts from a clean state
        # instead of carrying positions over from a previous pass.
        trace_buf.clear()
        spring_line.set_data([], [])
        trace_line.set_data([], [])
        marker_dot.set_data([], [])
        y0, y1 = ax_right.get_ylim()
        time_line.set_data([t_anim_np[0], t_anim_np[0]], [y0, y1])
        return mass, spring_line, trace_line, marker_dot, time_line

    def animate(i):
        x_pos = y_anim_np[i]
        mass.set_xy((x_pos - mass_w / 2, -mass_h / 2))
        xs, ys = _spring_points(wall_x, x_pos - mass_w / 2)
        spring_line.set_data(xs, ys)
        trace_buf.append(x_pos)
        trace_line.set_data(np.asarray(trace_buf), np.zeros(len(trace_buf)))
        marker_dot.set_data([t_anim_np[i]], [y_anim_np[i]])
        y0, y1 = ax_right.get_ylim()
        time_line.set_data([t_anim_np[i], t_anim_np[i]], [y0, y1])
        return mass, spring_line, trace_line, marker_dot, time_line

    anim = animation.FuncAnimation(
        fig,
        animate,
        init_func=init,
        frames=len(t_anim_np),
        interval=1000 / fps,
        blit=True,
    )
    # Close the figure so only the HTML player is displayed, not a static copy.
    plt.close(fig)
    return HTML(anim.to_jshtml())


# ---------------------------------------
# % Plotting GPs and interactive widgets
# ---------------------------------------

# Turn off matplotlib's interactive mode. Presumably this is so that figures
# created by the widget classes below are not auto-displayed on their own;
# they are shown through the ipywidgets layouts that embed `fig.canvas`.
plt.ioff()


def plot_gp(
    ax,
    gp_or_dist,
    X_grid,
    *,
    n_samples: int = 0,
    seed: int = 0,
    mean_label: str = "Mean",
    shade_k: float = 2.0,
    shade_label: str = r"$\pm2\sigma$",
    true_curve=None,  # (X_true, y_true)
    obs=None,  # (X_obs, y_obs)
    obs_style=None,
    sample_style=None,
    mean_style=None,
    band_alpha: float = 0.2,
):
    """Draw a GP mean, a ±k·σ band, and optional samples, truth and observations.

    Works with either a GP (a callable returning a distribution when
    evaluated at ``X_grid``) or a distribution directly. Anything exposing
    both ``mu`` and ``Sigma`` is treated as a distribution.

    Args:
        ax: Matplotlib axes to draw on.
        gp_or_dist: GP-like callable or distribution with ``mu`` / ``Sigma``.
        X_grid: Evaluation inputs; flattened to 1D for plotting.
        n_samples: Number of samples to draw via ``dist.sample`` (0 = none).
        seed: Seed for the sampling PRNG key.
        mean_label: Legend label of the mean line (ignored if ``mean_style``
            is given).
        shade_k: Half-width of the band in standard deviations.
        shade_label: Legend label of the band (mathtext by default).
        true_curve: Optional ``(X_true, y_true)`` drawn as a dashed black line.
        obs: Optional ``(X_obs, y_obs)`` drawn as a scatter.
        obs_style, sample_style, mean_style: Optional matplotlib kwargs that
            replace the defaults for the respective artists.
        band_alpha: Opacity of the uncertainty band.

    Returns:
        ``(mu, std)`` as 1D NumPy arrays, which is handy for computing metrics.
    """
    # Get a distribution at X_grid
    if hasattr(gp_or_dist, "mu") and hasattr(gp_or_dist, "Sigma"):
        dist = gp_or_dist
    else:
        dist = gp_or_dist(X_grid)  # assume GP-like callable

    Xr = np.asarray(X_grid).ravel()
    mu = np.asarray(dist.mu).reshape(-1)
    # Round-off can leave tiny negative entries on the diagonal of a
    # (posterior) covariance; clip them so sqrt does not produce NaNs.
    var = np.clip(np.asarray(dist.Sigma).diagonal(), 0.0, None)
    std = np.sqrt(var)

    # Styles
    if mean_style is None:
        mean_style = dict(lw=2, label=mean_label)
    if sample_style is None:
        sample_style = dict(lw=1.0, alpha=0.8)
    if obs_style is None:
        obs_style = dict(s=60, marker="x", color="red", zorder=5, label="Observed")

    # Plot uncertainty band + mean
    ax.fill_between(
        Xr, mu - shade_k * std, mu + shade_k * std, alpha=band_alpha, label=shade_label
    )
    ax.plot(Xr, mu, **mean_style)

    # Optional samples
    if n_samples > 0:
        S = dist.sample(jax.random.PRNGKey(seed), num_samples=n_samples)
        ax.plot(Xr, np.asarray(S).T, **sample_style)

    # Optional ground truth
    if true_curve is not None:
        X_true, y_true = true_curve
        X_true, y_true = np.asarray(X_true).ravel(), np.asarray(y_true).ravel()
        ax.plot(
            X_true,
            y_true,
            lw=2,
            alpha=0.7,
            color="black",
            linestyle="--",
            label="True Signal",
        )

    # Optional observations
    if obs is not None:
        X_obs, y_obs = obs
        X_obs, y_obs = np.asarray(X_obs).ravel(), np.asarray(y_obs).ravel()
        ax.scatter(X_obs, y_obs, **obs_style)

    return mu, std  # handy for metrics etc.


class InteractiveGPPlotter:
    """Widget-driven matplotlib view of a kernel, a GP prior or a GP posterior.

    ``plot_type`` selects the mode:

    * ``"kernel"``: plots ``k(x, 0)`` of the selected kernel.
    * ``"prior"``: plots the GP prior (mean, ±2σ band, optional samples).
    * ``"posterior"``: conditions the GP on the first ``n`` training points,
      plots the posterior against the test curve and reports RMSE / MLPD.

    Every control change triggers :meth:`redraw`. Display ``self.ui`` to show
    the controls together with the figure canvas.
    """

    def __init__(
        self,
        plot_type: str,
        *,
        kernel_registry,
        gp_cls=None,
        mean_fn=None,
        training_data=None,
        test_data=None,
    ):
        """
        Args:
            plot_type: One of ``"kernel"``, ``"prior"``, ``"posterior"``.
            kernel_registry: Mapping from display name to a kernel class that
                provides ``from_params(dict)``.
            gp_cls: GP class constructed as ``gp_cls(m=mean_fn, k=kernel)``;
                needed for the prior and posterior modes.
            mean_fn: Mean function passed to ``gp_cls``.
            training_data: ``(X_train, y_train)``; required for ``"posterior"``.
            test_data: ``(X_test, y_test)``; required for ``"posterior"``.

        Raises:
            ValueError: If ``plot_type`` is unknown, or if the posterior mode
                is missing its data.
        """
        # Explicit exception rather than `assert`, which `python -O` strips.
        if plot_type not in {"kernel", "prior", "posterior"}:
            raise ValueError(
                "plot_type must be one of 'kernel', 'prior' or 'posterior', "
                f"got {plot_type!r}."
            )
        self.plot_type = plot_type
        self.kernel_registry = kernel_registry
        self.gp_cls = gp_cls
        self.mean_fn = mean_fn

        # Compare with None instead of truth-testing: truth-testing is
        # ambiguous when the data is passed as a single stacked array.
        self.X_train, self.y_train = (
            training_data if training_data is not None else (None, None)
        )
        self.X_test, self.y_test = test_data if test_data is not None else (None, None)

        if self.plot_type == "posterior" and any(
            v is None for v in (self.X_train, self.y_train, self.X_test, self.y_test)
        ):
            raise ValueError("posterior plot requires training_data and test_data.")

        self.widgets = self._create_widgets()
        self.controls = self._layout_widgets()

        self.fig, self.ax = plt.subplots(figsize=(9, 4.5))
        self.ui = widgets.VBox([self.controls, self.fig.canvas])

        self._link_widgets_to_redraw()
        self.redraw()

    def _create_widgets(self) -> dict[str, widgets.Widget]:
        """Create the kernel controls plus the controls specific to plot_type."""
        style = {"description_width": "110px"}
        w = {
            "kernel": widgets.Dropdown(
                options=list(self.kernel_registry.keys()),
                value=next(iter(self.kernel_registry.keys())),
                description="Kernel",
                style=style,
            ),
            "variance": widgets.FloatLogSlider(
                value=1.0, base=10, min=-2, max=2, description="Variance", style=style
            ),
            "lengthscale": widgets.FloatLogSlider(
                value=1.0,
                base=10,
                min=-2,
                max=2,
                description="Lengthscale",
                style=style,
            ),
            "period": widgets.FloatSlider(
                value=np.pi,
                min=0.2,
                max=8.0,
                step=0.05,
                description="Period",
                style=style,
            ),
        }
        if self.plot_type == "kernel":
            w["xmax"] = widgets.FloatSlider(
                value=5.0, min=1.0, max=15, step=0.1, description="x_max", style=style
            )
        elif self.plot_type == "prior":
            w.update(
                {
                    "samples": widgets.IntSlider(
                        value=5, min=0, max=30, description="# Samples", style=style
                    ),
                    "seed": widgets.IntSlider(
                        value=0, min=0, max=100, description="Seed", style=style
                    ),
                    "xmin": widgets.FloatSlider(
                        value=-5.0,
                        min=-15,
                        max=0,
                        step=0.1,
                        description="x_min",
                        style=style,
                    ),
                    "xmax": widgets.FloatSlider(
                        value=5.0,
                        min=0,
                        max=15,
                        step=0.1,
                        description="x_max",
                        style=style,
                    ),
                }
            )
        else:
            w.update(
                {
                    "noise": widgets.FloatLogSlider(
                        value=0.1,
                        base=10,
                        min=-3,
                        max=0,
                        description="σ_noise",
                        style=style,
                    ),
                    "samples": widgets.IntSlider(
                        value=3, min=0, max=30, description="# Samples", style=style
                    ),
                    "seed": widgets.IntSlider(
                        value=0, min=0, max=100, description="Seed", style=style
                    ),
                    "n_used": widgets.IntSlider(
                        value=min(10, len(self.X_train)),
                        min=0,
                        max=len(self.X_train),
                        description="# Points",
                        style=style,
                    ),
                    "metrics": widgets.HTML(
                        value="<pre style='margin:0'>RMSE: –\nMLPD: –</pre>"
                    ),
                }
            )
        return w

    def _layout_widgets(self) -> widgets.Widget:
        """Arrange the widgets into columns appropriate for plot_type."""
        w = self.widgets
        kernel_box = widgets.VBox(
            [w["kernel"], w["variance"], w["lengthscale"], w["period"]]
        )
        if self.plot_type == "kernel":
            return widgets.HBox([kernel_box, widgets.VBox([w["xmax"]])])
        if self.plot_type == "prior":
            sampling_box = widgets.VBox([w["samples"], w["xmin"], w["xmax"], w["seed"]])
            return widgets.HBox([kernel_box, sampling_box])
        data_box = widgets.VBox([w["noise"], w["n_used"], w["samples"], w["seed"]])
        metrics_box = widgets.VBox([widgets.HTML("<b>Metrics</b>"), w["metrics"]])
        return widgets.HBox([kernel_box, data_box, metrics_box])

    def _link_widgets_to_redraw(self):
        """Redraw on any value change, except of the read-only metrics display."""
        for name, widget in self.widgets.items():
            if name == "metrics":
                continue
            widget.observe(self.redraw, names="value")

    def _get_kernel(self):
        """Build the selected kernel from the sliders.

        Also shows the period slider only for kernels whose class name or
        registry name mentions "period".
        """
        KernelClass = self.kernel_registry[self.widgets["kernel"].value]
        params = {
            "variance": float(self.widgets["variance"].value),
            "lengthscale": float(self.widgets["lengthscale"].value),
            "period": float(self.widgets["period"].value),
        }
        # show/hide period control if the kernel suggests it
        show_period = (
            "period" in KernelClass.__name__.lower()
            or "periodic" in self.widgets["kernel"].value.lower()
        )
        self.widgets["period"].layout.display = "" if show_period else "none"
        return KernelClass.from_params(params)

    @staticmethod
    def _compute_metrics(mu, std, y_true):
        """Return (RMSE, mean log predictive density) of N(mu, std**2) at y_true."""
        eps = 1e-12  # keeps the log and the division finite for zero std
        rmse = float(np.sqrt(np.mean((mu - y_true) ** 2)))
        var = std**2 + eps
        mlpd = float(
            np.mean(-0.5 * np.log(2 * np.pi * var) - 0.5 * ((y_true - mu) ** 2) / var)
        )
        return rmse, mlpd

    # --- redraw
    def redraw(self, *_):
        """Clear the axes and redraw the current mode (also the observe callback)."""
        ax = self.ax
        ax.cla()

        if self.plot_type == "kernel":
            self._plot_kernel(ax)
        elif self.plot_type == "prior":
            self._plot_prior(ax)
        else:
            self._plot_posterior(ax)

        ax.grid(True, alpha=0.5)
        ax.legend(loc="upper right")
        self.fig.canvas.draw_idle()

    # --- individual plot modes
    def _plot_kernel(self, ax: plt.Axes):
        """Plot k(x, 0) on [-x_max, x_max]."""
        w = self.widgets
        kernel = self._get_kernel()
        xmax = float(w["xmax"].value)
        X = jnp.linspace(-xmax, xmax, 500)[:, None]
        ax.plot(X, kernel(X, jnp.array([[0.0]])), lw=2, label="k(x, 0)")
        ax.set(
            ylim=(-0.1, float(w["variance"].value) * 1.1),
            title=f"'{w['kernel'].value}' Kernel Function",
            xlabel="x",
            ylabel="k(x,0)",
        )

    def _plot_prior(self, ax: plt.Axes):
        """Plot the GP prior on [x_min, x_max]."""
        w = self.widgets
        kernel = self._get_kernel()
        gp = self.gp_cls(m=self.mean_fn, k=kernel)
        X_grid = jnp.linspace(float(w["xmin"].value), float(w["xmax"].value), 400)[
            :, None
        ]
        plot_gp(
            ax,
            gp,
            X_grid,
            n_samples=int(w["samples"].value),
            seed=int(w["seed"].value),
            mean_label="Prior Mean",
            shade_k=2.0,
            shade_label="±2σ",
        )
        ax.set(title=f"GP Prior — {w['kernel'].value}", xlabel="x", ylabel="f(x)")

    def _plot_posterior(self, ax: plt.Axes):
        """Plot the GP conditioned on the first n training points, plus metrics."""
        w = self.widgets
        kernel = self._get_kernel()
        gp = self.gp_cls(m=self.mean_fn, k=kernel)

        X_grid = jnp.linspace(float(self.X_test.min()), float(self.X_test.max()), 400)[
            :, None
        ]

        n = int(w["n_used"].value)
        obs = None
        if n > 0:
            X_obs = self.X_train[:n]
            y_obs = self.y_train[:n]

            gp = gp.condition(
                y=jnp.asarray(y_obs),
                X=jnp.asarray(X_obs),
                sigma2=float(w["noise"].value) ** 2,
            )
            obs = (X_obs, y_obs)
        # Evaluate once, after optional conditioning (prior when n == 0).
        dist = gp(X_grid)

        mu, std = plot_gp(
            ax,
            dist,
            X_grid,
            n_samples=int(w["samples"].value),
            seed=int(w["seed"].value),
            mean_label="Posterior Mean",
            true_curve=(self.X_test, self.y_test),
            obs=obs,
        )
        # metrics: np.interp requires increasing sample points, so sort the
        # test data before interpolating the true curve onto the grid.
        x_test = np.asarray(self.X_test).ravel()
        y_test = np.asarray(self.y_test).ravel()
        order = np.argsort(x_test)
        y_interp = np.interp(np.asarray(X_grid).ravel(), x_test[order], y_test[order])
        rmse, mlpd = self._compute_metrics(mu, std, y_interp)
        self.widgets[
            "metrics"
        ].value = f"<pre style='margin:0'>RMSE: {rmse:.4f}\nMLPD: {mlpd:.4f}</pre>"
        ax.set(title=f"GP Posterior — {w['kernel'].value}", xlabel="x", ylabel="f(x)")


# ------------------------------------------------
# Optimization widget
# ------------------------------------------------


class OptimizationWidget:
    """Manages the UI for fitting and plotting."""

    def __init__(
        self,
        tuner,
        kernel_registry: dict,
        param_init_registry: dict,
        test_data: tuple,
        gp_cls,
        extrapolate: float | bool = False,
    ):
        self.tuner = tuner
        self.gp_cls = gp_cls
        self.kernel_registry = kernel_registry
        self.param_init_registry = param_init_registry
        self.X_test, self.y_test = test_data

        self.kernel_sel = widgets.Dropdown(
            options=list(self.kernel_registry.keys()), description="Kernel"
        )
        self.fit_button = widgets.Button(
            description="Fit Hyperparameters", button_style="success"
        )

        self.fig, self.ax = plt.subplots(figsize=(9, 4.5))
        self.ui = widgets.VBox(
            [widgets.HBox([self.kernel_sel, self.fit_button]), self.fig.canvas]
        )
        self.fit_button.on_click(self._on_fit_clicked)

        # Initial empty plot
        self.ax.text(
            0.5,
            0.5,
            "Select a kernel and click 'Fit' to begin.",
            ha="center",
            va="center",
        )
        self.ax.grid(True)

        self.extrapolate = extrapolate

    def _format_results_text(self, kernel_name, fit_results):
        """Formats the optimization results into a string for plotting."""
        lines = [f"--- Optimized: {kernel_name} ---"]
        final_params = fit_results["final_params"]
        log_params = fit_results["optimized_params_log"]

        for key, val in final_params.items():
            if key in ["noise", "sigma2"]:
                continue
            log_key = "log_" + key if key in ["variance", "lengthscale"] else key
            log_val_str = (
                f"(log={log_params.get(log_key, 'N/A'):.2f})"
                if log_key in log_params
                else ""
            )
            lines.append(f"{key:<12s} = {val:.3f} {log_val_str}")

        lines.append(
            f"{'noise':<12s} = {final_params['noise']:.3f}  (log={log_params['log_noise']:.2f})"
        )
        lines.append(f"Final Neg MLL = {fit_results['result'].fun:.2f}")
        return "\n".join(lines)

    def _on_fit_clicked(self, _):
        # 1. Update the plot to show an "Optimizing..." message
        self.ax.cla()
        self.ax.text(0.5, 0.5, "Optimizing...", ha="center", va="center", fontsize=14)
        self.fig.canvas.draw_idle()

        # 2. Run the optimization
        kernel_name = self.kernel_sel.value
        fit_results = self.tuner.fit(
            kernel_name, self.kernel_registry, self.param_init_registry
        )

        # 3. Build the optimized GP posterior
        final_params = fit_results["final_params"]
        sigma2_opt = final_params["sigma2"]
        k_opt = self.kernel_registry[kernel_name].from_params(final_params)

        gp_prior_opt = self.gp_cls(m=lambda x: jnp.zeros(x.shape[0]), k=k_opt)
        post_gp = gp_prior_opt.condition(
            y=self.tuner.y, X=self.tuner.X, sigma2=sigma2_opt
        )

        # 4. Clear the axes and draw the final plot
        self.ax.cla()
        X_grid = (
            jnp.linspace(
                self.X_test.min() - self.extrapolate,
                self.X_test.max() + self.extrapolate,
                500,
            )[:, None]
            if self.extrapolate
            else self.X_test[:, None]
        )
        plot_gp(
            self.ax,
            post_gp,
            X_grid,
            n_samples=3,
            obs=(self.tuner.X, self.tuner.y),
            true_curve=(self.X_test, self.y_test),
        )
        self.ax.set_title(f"Posterior with Optimized '{kernel_name}' Kernel")

        # 5. Add the formatted results text to the plot
        results_text = self._format_results_text(kernel_name, fit_results)
        self.ax.text(
            0.02,
            0.98,
            results_text,
            transform=self.ax.transAxes,
            verticalalignment="top",
            fontfamily="monospace",
            fontsize=9,
            bbox=dict(boxstyle="round,pad=0.5", fc="wheat", alpha=0.7),
        )

        self.ax.legend(loc="upper right")
        self.ax.grid(True)


In [None]:
from collections.abc import Callable
from dataclasses import dataclass

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import scipy.integrate as spi
from scipy.interpolate import interp1d
from tueplots.constants.color import palettes

# First two colours of the tueplots "tue_plot" palette; used below by
# plot_gp_belief_and_pde to colour collocation points and sensor/boundary data.
tue_col_1 = palettes.tue_plot[0]
tue_col_2 = palettes.tue_plot[1]

# ----------------------------------
# CPU Problem
# ----------------------------------


@dataclass
class CPUHeatProblem1D:
    """Container for the 1D steady-state CPU heat-conduction benchmark.

    Instances are produced by ``create_cpu_problem``. The reference
    temperature profile is obtained numerically, so nothing here depends on
    ``linpde_gp``.
    """

    # --- geometry ---
    width: float  # length of the 1D domain
    domain: tuple[float, float]  # (left, right) end points of the domain

    # --- material and heat input ---
    kappa: float  # conductivity-like coefficient in -kappa * u'' = q
    TDP: float  # total heat power injected by the core sources

    # --- reference solution and PDE data ---
    solution: Callable[[np.ndarray], np.ndarray]  # numerical ground truth u(x)
    q_total: Callable[[np.ndarray], np.ndarray]  # PDE right-hand side (sources + sink)

    # --- synthetic DTS sensor readings ---
    X_dts: np.ndarray  # sensor positions
    y_dts: np.ndarray  # noisy temperature readings at X_dts
    dts_noise_std: float  # standard deviation of the sensor noise


def create_cpu_problem() -> CPUHeatProblem1D:
    """Build the 1D CPU heat-conduction benchmark using NumPy and SciPy only.

    The CPU is reduced to its width axis:

    * Three cores act as trapezoid-shaped heat sources. Their profile is
      normalized to unit integral and scaled so the total injected power is
      ``TDP``.
    * A uniform volumetric sink removes the same total power.
    * The ground-truth temperature solves ``-kappa * u'' = q_total`` with
      ``u(0) = 60`` and ``u'(0) = 0``. It is obtained by integrating twice on
      a dense grid.
    * Synthetic DTS sensor readings are taken at the core centers with
      Gaussian noise from a fixed seed.
    """
    # -- Geometry
    width, height, depth = 16.28, 9.19, 0.37
    domain = (0.0, width)
    V = width * height * depth

    # -- Material property
    kappa = 1.56 * 10.0

    # -- Heat sources
    TDP = 95.0
    N_cores_x, core_width, core_offset_x, core_distance_x = 3, 2.5, 1.95, 0.35
    core_centers_xs = (
        core_offset_x
        + (core_width + core_distance_x) * np.arange(N_cores_x, dtype=np.double)
        + core_width / 2.0
    )

    # -- Heat source function (re-implemented with NumPy)
    def create_core_heat_source_fn(rel_heights=(0.9, 0.75, 1.0)):
        """Return a piecewise-linear source profile with unit integral.

        Each core is a trapezoid: a short linear ramp (width ``eps``) up to its
        relative height, a plateau across the core, and a ramp back to zero.
        The profile is zero outside the domain.
        """
        xs, ys = [domain[0]], [0.0]
        eps = core_distance_x / 3  # ramp width on either side of a core
        # strict=True: a height list that does not match the number of cores
        # is an error instead of silently dropping cores.
        for cx, h in zip(core_centers_xs, rel_heights, strict=True):
            xs.extend(
                [
                    cx - core_width / 2 - eps,
                    cx - core_width / 2,
                    cx + core_width / 2,
                    cx + core_width / 2 + eps,
                ]
            )
            ys.extend([0.0, h, h, 0.0])
        xs.append(domain[1])
        ys.append(0.0)

        # Normalize with the trapezoidal rule, which is exact for the
        # piecewise-linear interpolant returned below. (np.trapz is deprecated
        # in NumPy 2.0; SciPy's trapezoid matches cumulative_trapezoid below.)
        norm_const = spi.trapezoid(ys, xs)
        ys_normalized = np.array(ys) / norm_const

        # Return a callable interpolation function
        return interp1d(xs, ys_normalized, bounds_error=False, fill_value=0.0)

    q_src_dist = create_core_heat_source_fn()

    def q_dot_V_src(x):
        # Unit-integral profile times TDP per cross-sectional area.
        return (TDP / (depth * height)) * q_src_dist(x)

    # Uniform sink that removes the full TDP over the volume V.
    q_dot_V_sink = -TDP / V

    def q_total(x):
        return q_dot_V_src(x) + q_dot_V_sink

    # -- Numerical Ground Truth Solution --
    def create_numerical_solution(u0=60.0, du0=0.0):
        """Solve -kappa u'' = q_total with u(0)=u0, u'(0)=du0 by double integration."""
        x_dense = np.linspace(domain[0], domain[1], 2001)

        # We solve u'' = -q_total(x) / kappa
        rhs = -q_total(x_dense) / kappa
        du_dx = spi.cumulative_trapezoid(rhs, x_dense, initial=0.0) + du0
        u_values = spi.cumulative_trapezoid(du_dx, x_dense, initial=0.0) + u0

        # Return a callable interpolation function (held constant outside the domain)
        return interp1d(
            x_dense,
            u_values,
            bounds_error=False,
            fill_value=(u_values[0], u_values[-1]),
        )

    solution_fn = create_numerical_solution(u0=60.0, du0=0.0)

    # -- Synthetic DTS Data --
    dts_noise_std = 0.5
    noise_rng = np.random.default_rng(33215)
    noise_dts = noise_rng.normal(scale=dts_noise_std, size=len(core_centers_xs))
    y_dts = solution_fn(core_centers_xs) + noise_dts

    return CPUHeatProblem1D(
        width=width,
        domain=domain,
        kappa=kappa,
        TDP=TDP,
        solution=solution_fn,
        q_total=q_total,
        X_dts=core_centers_xs,
        y_dts=y_dts,
        dts_noise_std=dts_noise_std,
    )


# -----------------------------
# Plotting utilities
# -----------------------------


def _second_difference_matrix(n: int, h: float) -> np.ndarray:
    """Return the (n, n) central second-difference operator on a uniform grid.

    Interior rows apply ``(f[i-1] - 2 f[i] + f[i+1]) / h**2``. The first and
    last rows copy the stencil of their interior neighbour, so ``(D @ f)[0]``
    equals ``(D @ f)[1]`` (and likewise at the right end).
    """
    D = np.zeros((n, n))
    idx = np.arange(1, n - 1)
    D[idx, idx - 1] = 1.0
    D[idx, idx] = -2.0
    D[idx, idx + 1] = 1.0
    D /= h**2
    # Copy the already-scaled neighbouring rows so the boundary rows carry
    # the same 1/h**2 factor as the interior.
    D[0] = D[1]
    D[-1] = D[-2]
    return D


def plot_gp_belief_and_pde(
    gp,
    problem: CPUHeatProblem1D,
    X_grid: np.ndarray,
    conditions: list | None = None,
    n_samples: int = 0,
    seed: int = 0,
    title: str | None = None,
):
    """Plot a GP belief over u(x) next to a check of the PDE balance.

    Left panel: the true solution, the GP mean, a 95% band (±1.96σ),
    optional samples, and sensor/boundary observations.

    Right panel: the source ``problem.q_total`` against the GP-implied
    ``-kappa * u''(x)``, plus its 95% band and the PDE collocation points.
    ``u''`` is estimated with central finite differences on ``X_grid``, which
    is assumed to be uniformly spaced. One difference operator is applied to
    the mean, the covariance and the samples.

    Args:
        gp: Callable that, evaluated at an (N, 1) array, returns an object
            with ``mu``, ``Sigma`` and ``sample(key, n)``.
        problem: Provides ``solution``, ``q_total``, ``domain`` and ``kappa``.
        X_grid: Uniform evaluation grid (flattened to 1D for plotting).
        conditions: Conditioning objects. Each is drawn according to the
            class name of its ``op`` attribute.
        n_samples: Number of GP samples to draw (0 disables samples).
        seed: Seed for the sampling PRNG key.
        title: Optional figure title.
    """
    fig, (ax_u, ax_pde) = plt.subplots(
        ncols=2, sharex=True, figsize=(12, 3.5), constrained_layout=True
    )
    if conditions is None:
        conditions = []

    # -- Prepare GP data for plotting --
    X_plot = np.asarray(X_grid).reshape(-1, 1)
    x = X_plot.ravel()  # 1D grid used for every line plot
    gp_eval = gp(jnp.asarray(X_plot))
    mu = np.asarray(gp_eval.mu).flatten()
    Sigma = np.asarray(gp_eval.Sigma)
    # Clip round-off negatives on the diagonal before the square root.
    std = np.sqrt(np.maximum(Sigma.diagonal(), 0.0))

    samples = None
    if n_samples > 0:
        samples = np.asarray(gp_eval.sample(jax.random.key(seed), n_samples))

    # --- 1. Left Panel: Belief over Temperature u(x) ---
    ax_u.set(
        ylabel="Temperature (°C)",
        title="Belief over Temperature $u(x)$",
        xlim=problem.domain,
    )
    ax_u.grid(True)
    ax_u.plot(x, problem.solution(x), color="k", lw=2, ls="-", label="True Solution")
    ax_u.fill_between(
        x,
        mu - 1.96 * std,
        mu + 1.96 * std,
        color="C0",
        alpha=0.2,
        label="95% Credible Interval",
    )
    ax_u.plot(x, mu, color="C0", lw=2, label="GP Mean")
    if samples is not None:
        ax_u.plot(x, samples.T, color="C0", lw=1.0, alpha=0.5)

    # --- 2. Right Panel: PDE Balance Check ---
    ax_pde.set(
        xlabel="x-position (mm)",
        ylabel="Heat Flow",
        title=r"PDE Balance: $-\kappa u''(x)$ vs $\dot{q}_V(x)$",
    )
    ax_pde.grid(True)
    ax_pde.plot(
        x,
        problem.q_total(x),
        color="C1",
        lw=2,
        label=r"Heat Source $\dot{q}_V(x)$",
    )

    h = np.median(np.diff(x))
    D = _second_difference_matrix(len(x), h)
    mean_Lu = -problem.kappa * (D @ mu)
    # diag(D Sigma D^T), computed without the second N x N matrix product.
    var_d2 = np.sum((D @ Sigma) * D, axis=1)
    std_Lu = problem.kappa * np.sqrt(np.maximum(var_d2, 0))
    ax_pde.plot(x, mean_Lu, color="C0", lw=2, label=r"GP-implied $-\kappa u''(x)$")
    ax_pde.fill_between(
        x, mean_Lu - 1.96 * std_Lu, mean_Lu + 1.96 * std_Lu, color="C0", alpha=0.2
    )

    if samples is not None:
        # Each row is one sample, so apply D to every row at once.
        Lu_samples = -problem.kappa * (samples @ D.T)
        ax_pde.plot(
            x,
            Lu_samples.T,
            color="C0",
            lw=1.0,
            alpha=0.5,
        )

    # -- Plot all provided conditions --
    for cond in conditions:
        cond_name = type(cond.op).__name__
        if cond_name == "PDEObservation":
            ax_pde.scatter(
                cond.X,
                cond.y_vec,
                s=40,
                marker="x",
                color=tue_col_1,
                label="PDE Collocation Points",
                zorder=5,
            )
        elif cond_name == "SensorObservation":
            ax_u.errorbar(
                cond.X,
                cond.y_vec,
                yerr=np.sqrt(cond.op.sigma2),
                fmt="o",
                color=tue_col_2,
                label="Sensor Observations",
                zorder=5,
            )
            ax_pde.scatter(
                cond.X,
                np.zeros_like(cond.X),
                marker="o",
                color=tue_col_2,
                s=40,
                label="Sensor Locations",
            )
        elif cond_name == "BoundaryObservation":
            ax_u.errorbar(
                cond.X,
                cond.y_vec,
                yerr=np.sqrt(cond.op.sigma2),
                fmt="o",
                color=tue_col_2,
                label="Dirichlet Condition",
                zorder=5,
            )

        else:
            # print rather than warnings.warn: the notebook ignores all warnings.
            print(f"Warning: Unknown condition type for plotting: {cond_name}")

    ax_u.legend(loc="lower left")
    ax_pde.legend(loc="upper right")
    if title:
        fig.suptitle(title, fontsize=16)
    plt.show()

In [None]:
# Install the third-party packages used by this notebook (one package per command).
!pip install jax
!pip install matplotlib
!pip install numpy
!pip install tueplots
!pip install jaxtyping
!pip install beartype
!pip install IPython
!pip install ipykernel
!pip install ipywidgets
!pip install ipympl
!pip install plum-dispatch

# From Gaussian Processes to Physics-Informed Regression

Welcome! In this tutorial, you will implement Gaussian Process (GP) regression from scratch, and we will demonstrate several ways to embed physical laws into GPs. We will cover:

1.  **Part 1: A Primer on Gaussian Processes:** We will start with the fundamentals of GP regression, understanding them as distributions over functions and using them for interpolating data.

2.  **Part 2: Physics-Informed Regression:** We will then see how to condition a Gaussian Process on domain knowledge arising in physical applications, such as boundary conditions or the governing partial differential equation (PDE) itself.

By the end, you will see how this approach allows us to make accurate predictions even with very sparse data, a common situation in many scientific and engineering applications. Our goal is to build solid intuition for these techniques and to understand how to implement such concepts in a general way.


---

### Tools We'll Use

We'll be using a modern, powerful stack of libraries for this tutorial.

* **JAX**: A high-performance numerical computing library from Google. We use **JAX** for its NumPy-like API, automatic differentiation, and just-in-time (JIT) compilation. We'll enable 64-bit (double) precision for better numerical stability in our linear algebra operations.

* **Jaxtyping & Beartype**: To catch bugs early, we'll rely heavily on type annotations. **Jaxtyping** allows us to annotate the shapes and data types of our JAX arrays directly in the function signatures. **Beartype** then acts as a runtime type-checker, catching shape-related errors as soon as they happen, which is invaluable for debugging.

* **Plum & Multiple Dispatch**: We will use `plum-dispatch` to implement multiple dispatch. This is a powerful programming paradigm that allows us to define multiple versions of the same function that operate on different data types. It helps in writing clean and extensible code by avoiding complex `if/else` chains.

* **Tueplots**: For all our visualizations, we will use **Tueplots**, a library designed to create aesthetically pleasing, publication-quality plots with ease.

In [None]:
# %% Imports
%matplotlib widget

import warnings

import jax
import jax.numpy as jnp # API compatible with numpy
import matplotlib.pyplot as plt
import numpy as np # for matplotlib compatibility
from tueplots import bundles

# JAX settings - enable 64-bit precision
jax.config.update("jax_enable_x64", True)

# Suppress warnings for a cleaner notebook.
# NOTE: this filter is process-wide, so it also hides e.g. deprecation warnings
# raised by the libraries used later in the notebook.
warnings.filterwarnings("ignore")

# %% Jaxtyping and Beartype
# Enable runtime type checking to catch shape errors
%load_ext jaxtyping
%jaxtyping.typechecker beartype.beartype

# %% Plot settings - enable latex packages
rcparams = bundles.beamer_moml()
# Matplotlib writes this preamble verbatim into the LaTeX source (it is only
# used when LaTeX rendering is enabled). Separate the packages with real
# newlines: a literal backslash-n from a raw string would reach LaTeX as the
# undefined control sequence `\n`.
rcparams["text.latex.preamble"] = "\n".join(
    [r"\usepackage{amsfonts}", r"\usepackage{siunitx}", r"\usepackage{bm}"]
)
plt.rcParams.update(rcparams)

# %% Utility functions
def atleast_2d(arr: jax.Array) -> jax.Array:
    arr = jnp.asarray(arr)
    if arr.ndim == 1:
        return arr[:, None]
    if arr.ndim == 2:
        return arr
    msg = f"Input array must be 1D or 2D, but got {arr.ndim}D."
    raise ValueError(msg)

### A Quick Example of `jaxtyping`

To understand how `jaxtyping` and `beartype` help us, let's look at a quick example. Imagine we have a function that performs a matrix-vector multiplication.

With `jaxtyping`, we can write the function signature like this:

In [None]:
from jaxtyping import Array, Float

def matrix_vector_product(
    matrix: Float[Array, "rows cols"],
    vector: Float[Array, " cols"],
) -> Float[Array, " rows"]:
    """Multiply a (rows, cols) matrix by a vector of length ``cols``.

    The shared ``cols`` name in the annotations ties the two inputs together,
    so a runtime type checker can reject mismatched shapes before the product
    is computed; the result has length ``rows``.
    """
    product = matrix @ vector
    return product

# Matching shapes: `cols` is 2 for both arguments, so the annotated call
# passes the runtime check and returns a vector of length `rows`.
A = jnp.ones((3, 2))  # Shape is (rows=3, cols=2)
x = jnp.ones(2)       # Shape is (cols=2)
result = matrix_vector_product(A, x)
print(result.shape)  # expected: (3,)

When things go wrong, `beartype` catches the error immediately because the shapes don't match the annotations.

In [None]:
# This will raise an error: the matrix binds `cols` to 2, but the vector has
# length 5, so the runtime type check rejects the call before `@` runs.
A = jnp.ones((3, 2))  # Shape is (rows=3, cols=2)
y = jnp.ones(5)       # Shape is (5,), but expected (cols=2)
try:
    matrix_vector_product(A, y)
except Exception as e:  # broad on purpose: just show whatever the checker raises
    print(f"Error: {e}")

## Part 1: A primer on Gaussian Processes (from an implementation perspective)

In [5]:
from jaxtyping import Array, Float
from collections.abc import Callable
from dataclasses import dataclass 
# We can use this to save writing __init__ for parameter definitions.

# Scalar hyperparameter type: either a plain Python float or a 0-d JAX float
# array (e.g. a value computed inside a jitted / differentiated function).
FloatType = float  | Float[Array, ""]

In this first part, we will see the main idea behind GP regression, learn how to implement GPs (in the style of libraries such as ProbNum, GPyTorch, and GPJax), play around with kernels, and learn how to optimize hyperparameters.

### Part 1 — Problem setup: Damped Oscillator (data for GP regression)

**Goal.** Define a physical signal (damped harmonic oscillator), generate sparse/noisy observations, and visualize the true trajectory.  
This will be the dataset for vanilla GP regression in Part 1.




#### Damped oscillator model

We model the displacement $y(t)$ as
$$ y(t) = A \, e^{-\gamma t} \cos(\omega t),$$
with parameters $\theta = (A, \gamma, \omega)$.

Interpretation:
- $A$: initial amplitude (units of displacement)
- $\gamma$: damping coefficient (1/time); (larger ⇒ faster decay) 
- $\omega$: angular frequency (rad/time); (larger ⇒ faster oscillation)

We keep this model deterministic here; stochasticity will enter via measurement noise when we generate the training data. Why this system? It is linear, underdamped, and has interpretable parameters. It also has a clear structure that we will exploit later, when we learn how to include prior knowledge in the model.

We start the tutorial by collecting data from a physical system. Consider the following damped oscillator simulation: we record its displacement at a small number of noisy observation points (`n_train = 30` in the code below).

In [6]:
@dataclass
class DampedOscillatorParams:
    """Physical parameters for the damped oscillator."""
    A: FloatType = 1.0      # Amplitude (units of displacement)
    gamma: FloatType = 0.25 # Damping coefficient (1/time); larger => faster decay
    omega: FloatType = 2.0  # Angular frequency (rad/time); larger => faster oscillation

def damped_oscillator(
    t: Float[Array, " *T"],
    params: DampedOscillatorParams
) -> Float[Array, " *T"]:
    """
    Displacement y(t) = A * exp(-gamma * t) * cos(omega * t).

    Args:
        t: JAX float array of times, of any shape (a 0-d array for a single
            time, a 1-D array for a time series, ...). The variadic ``*T``
            annotation makes the runtime type check accept all of these.
        params: physical parameters (A, gamma, omega).

    Returns:
        y(t), evaluated elementwise, with the same shape as t.
    """
    A, gamma, omega = params.A, params.gamma, params.omega
    return A * jnp.exp(-gamma * t) * jnp.cos(omega * t)

Let us animate this system.

In [None]:
from chapter1 import animate_damped_oscillator

# Time points for animation (a longer window than the [t_min, t_max] data range)
t_anim = jnp.linspace(0, 20, 200)

# "true" physical parameters for data generation
theta_true = DampedOscillatorParams(A=1.0, gamma=0.25, omega=2.0)

# Kept as the cell's last expression so Jupyter displays whatever it returns
# (presumably an HTML animation — see chapter1).
animate_damped_oscillator(
    damped_oscillator,
    t_anim,
    theta_true
)

#### Test data

We define a **dense test grid** for plotting the true trajectory. We will later sample **sparse, noisy observations** for training. For reproducibility we will use the **JAX PRNG** with a fixed seed.

In [8]:
# reproducibility: fixed PRNG key, also used to sample the training data below
key = jax.random.key(42)

# time ranges
t_min, t_max = 0.0, 10.0
X_test = jnp.linspace(t_min, t_max, 600) # for smooth plots (1-D grid of 600 times)

# Noise-free ground-truth displacement on the dense grid
y_test = damped_oscillator(X_test, theta_true)

#### Training data

We simulate a typical experimental setting:

- Sample $n_{\text{train}}$ time points uniformly in $[t_{\min}, t_{\max}]$.
- Observe $y_i = y(t_i) + \varepsilon_i$ with i.i.d. Gaussian noise $\varepsilon_i \sim \mathcal{N}(0, \sigma^2)$.

**Parameters**
- `n_train`: number of observations
- `sigma_noise`: measurement noise std (controls signal-to-noise ratio).


In [9]:
from jaxtyping import PRNGKeyArray

def make_training_data(
    key: PRNGKeyArray,
    n_train: int,
    t_min: float,
    t_max: float,
    params: DampedOscillatorParams,
    sigma_noise: float = 0.1,
) -> tuple[Float[Array, "N 1"], Float[Array, " N"]]:
    """
    Create sparse, noisy observations (X_train, y_train).

    Args:
        key: JAX PRNG key; split internally into one key for the input
            locations and one for the observation noise.
        n_train: number of observations N.
        t_min, t_max: interval from which the time points are drawn uniformly.
        params: physical parameters of the damped oscillator.
        sigma_noise: standard deviation of the additive i.i.d. Gaussian noise.

    Returns:
        X: (N, 1) column of sampled time points (in random, unsorted order).
        y: (N,) noisy displacements y(t_i) + sigma_noise * eps_i, eps_i ~ N(0, 1).
    """
    key_inputs, key_noise = jax.random.split(key)
    X = jax.random.uniform(key_inputs, shape=(n_train, 1), minval=t_min, maxval=t_max)
    # The oscillator takes a plain time array, so drop the trailing feature axis.
    y_clean = damped_oscillator(X.squeeze(-1), params)
    y_noisy = y_clean + sigma_noise * jax.random.normal(key_noise, shape=y_clean.shape)
    return X, y_noisy

# 30 noisy observations of the true oscillator on [t_min, t_max], noise std 0.05
n_train = 30
sigma_noise = 0.05
X_train, y_train = make_training_data(
    key, n_train=n_train, t_min=t_min, t_max=t_max, params=theta_true, sigma_noise=sigma_noise
)

Let us now inspect the dataset.

- Red line: true (noise-free) trajectory on a dense grid.
- Blue crosses: sparse, noisy training points that the GP will fit in the next step.

In [None]:
# Dense noise-free trajectory (red line) vs. sparse noisy training data (blue crosses).
fig, ax = plt.subplots()
# Explicit colour: the default colour cycle would not give the red line the text describes.
ax.plot(np.asarray(X_test), np.asarray(y_test), label="True signal", linewidth=2, color="red")
ax.scatter(np.asarray(X_train.squeeze(-1)), np.asarray(y_train), marker="x", s=60, label="Training data", color="blue")
ax.set_xlabel("time t")
ax.set_ylabel("displacement y(t)")
ax.set_title("Damped Oscillator — Data for GP Regression")
ax.legend(loc="upper right")
plt.show()

### Implementing Gaussian Processes

#### Gaussian Processes (GPs)

A **Gaussian process** is a powerful tool that lets us define a *distribution over functions*. When we say a function $f$ is drawn from a GP, we write:

$$f \sim \mathcal{GP}(m, k)$$

where $m$ is the **mean function** and $k$ is the **kernel function**.

This means that if you pick any finite set of points, say $X = [x_1, \dots, x_n]$, the function values at those points, $f(X)$, will follow a multivariate normal distribution:

$$ f(X) \sim \mathcal{N}\big(m(X),\, K(X, X) \big) $$

The mean vector is given by $m(X) = [m(x_i)]_i$, and the covariance matrix, often called the **Gram matrix**, is constructed from the kernel function: $K_{ij} = k(x_i, x_j)$.

#### The Kernel Function 

The **kernel function** is the most important ingredient. It encodes our assumptions about the function we are trying to model (e.g., is it smooth? periodic? rough?). A kernel takes two points, $x$ and $x'$, and returns a scalar value that represents their "similarity".

* **High kernel value**: The points are considered "similar," and the function values at these points are expected to be strongly correlated.

* **Low kernel value**: The points are "dissimilar," and their function values are less correlated.

Any function can serve as a kernel as long as the Gram matrix it produces is symmetric and positive semi-definite for every finite set of input points. Here are some popular choices:

**1. Radial Basis Function (RBF) Kernel**

This is the most common kernel, often called the "squared exponential" kernel. It's a great default choice and assumes the function is infinitely differentiable (i.e., very smooth).

$$ k_\mathrm{RBF}(x, x') = \sigma^2 \exp \left( - \frac{\| x - x' \|^2}{2\ell^2} \right) $$

* `variance` ($\sigma^2$): Controls the average vertical variation of the function.
* `lengthscale` ($\ell$): Controls the horizontal "wiggliness". A small $\ell$ leads to functions that vary quickly, while a large $\ell$ leads to smoother, slowly varying functions.


**2. Matérn 5/2 Kernel**

The Matérn family of kernels is a generalization of the RBF kernel. The Matérn 5/2 is a popular choice because it assumes the function is twice-differentiable, making it suitable for modeling physical processes that are not infinitely smooth.

$$ k_{\mathrm{Matérn}5/2}(r) = \sigma^2 \left( 1 + \frac{\sqrt{5}r}{\ell} + \frac{5r^2}{3\ell^2} \right) \exp \left( - \frac{\sqrt{5}r}{\ell} \right), \quad \text{where } r = \| x - x' \| $$


**3. Periodic Kernel**

This kernel is perfect for modeling functions that exhibit repetitive patterns, like seasonal data or wave-like signals. It constructs a periodic function by mapping the input space onto a circle using the sine function and then applying an RBF-like kernel in this new space.

$$k_\mathrm{Periodic}(x, x') = \sigma^2 \exp \left( - \frac{2 \sin^2(\pi \|x - x'\| / p)}{\ell^2} \right)$$

It has three key hyperparameters:

* **`variance` ($\sigma^2$)**: Similar to the RBF kernel, this controls the average vertical variation of the function.
* **`lengthscale` ($\ell$)**: This controls the smoothness of the function *within a single period*. A small lengthscale will lead to more complex, wiggly patterns inside each repetition.
* **`period` ($p$)**: This is the most important parameter here, as it defines the distance over which the function repeats itself.

This kernel is a great choice when you have a strong prior belief that your underlying function is periodic.

#### What we need to implement

We will implement a Gaussian Process from the ground up. This will help solidify our understanding of how they work and what is needed. Our implementation will be modular, consisting of four main components:

1. **Gaussian**: a multivariate Gaussian distributed random variable, which we will use for sampling.

2. **Mean Functions (`m(x)`)**: These define the average value of the function. For simplicity, we'll start with a `ZeroFunction`, which assumes the function's mean is zero everywhere.

3. **Kernel Functions (`k(x, x')`)**: This is the heart of the GP. The kernel, also known as a covariance function, defines the "similarity" between points. It encodes our assumptions about the function's smoothness and shape.

4. **The `GaussianProcess` Class**: This class will bring the mean and kernel functions together to define the GP prior distribution.

For convenience, we will also introduce the type aliases `MeanFunction` and `KernelFunction`.

In [11]:
from functools import cached_property

@dataclass
class Gaussian:
    """Multivariate normal distribution N(mu, Sigma).

    Fields:
        mu: mean vector of shape (N,).
        Sigma: covariance matrix of shape (N, N).
        jitter: small value added to the diagonal of Sigma before the Cholesky
            factorisation, for numerical stability. Because of this, both `L`
            and `log_pdf` actually describe N(mu, Sigma + jitter * I).
    """
    mu: Float[Array, " N"]
    Sigma: Float[Array, "N N"]
    jitter: FloatType = 1e-9

    @cached_property
    def L(self) -> Float[Array, "N N"]:
        """Lower-triangular Cholesky factor of Sigma + jitter * I.

        Computed on first access and cached on the instance afterwards.
        """
        return jnp.linalg.cholesky(self.Sigma + self.jitter * jnp.eye(self.Sigma.shape[0]))

    def sample(self, key: PRNGKeyArray, num_samples: int = 1) -> Float[Array, "num_samples N"]:
        """
        Sample from the multivariate normal distribution N(mu, Sigma).

        Hint: Use the reparameterization trick. Use the Cholesky factor L of the covariance matrix Sigma to transform
        standard normal samples z ~ N(0, I) into samples y ~ N(mu, Sigma).

        Expected output: an array of shape (num_samples, N), one sample per row.
        """

        # YOUR CODE HERE
        raise NotImplementedError("You need to implement the sample method!")

    def log_pdf(self, x: Float[Array, " N"]) -> Float[Array, ""]:
        """
        Computes the log probability density of a vector x under the multivariate normal distribution N(mu, Sigma).

        The formula for the log PDF is:
            log p(x) = -0.5 * [ (x - mu)^T Sigma^{-1} (x - mu) + log|Sigma| + N log(2*pi) ]

        All terms are computed from the (jittered) Cholesky factor `L`, so no
        matrix inverse or explicit determinant is ever formed.
        """
        x = jnp.atleast_1d(x)
        n = x.shape[0]

        # Step 1: Calculate the quadratic term: (x - mu)^T * Sigma^{-1} * (x - mu)
        # Compute difference from the mean
        resid = x - self.mu

        # Solve for v in L * v = resid, which is equivalent to v = L^{-1} @ resid
        v = jax.scipy.linalg.solve_triangular(self.L, resid, lower=True)

        # Compute quadratic term v.T @ v (= resid^T Sigma^{-1} resid since Sigma = L L^T)
        quad = jnp.dot(v, v)

        # Step 2: Calculate the log determinant term: log|Sigma| = 2 * sum(log(diag(L)))
        logdet = 2.0 * jnp.sum(jnp.log(jnp.diag(self.L)))

        # Step 3: Combine terms
        return -0.5 * (quad + logdet + n * jnp.log(2.0 * jnp.pi))

In [None]:
# TEST YOUR CODE
# Requires Gaussian.sample to be implemented; the empirical moments below
# should match the true ones up to Monte-Carlo error.
# 1. Define a simple 2D Gaussian distribution
true_mean = jnp.array([1.0, -2.0])
true_cov = jnp.array([[2.0, 0.8], 
                     [0.8, 1.0]])
gaussian = Gaussian(mu=true_mean, Sigma=true_cov)

# 2. Draw a large number of samples
# NOTE: this rebinds the notebook-global `key` (previously jax.random.key(42)).
key = jax.random.key(0)
num_samples = 200_000
samples = gaussian.sample(key, num_samples=num_samples)

# 3. Compute the empirical mean and covariance (samples are rows: shape (num_samples, 2))
empirical_mean = jnp.mean(samples, axis=0)
empirical_cov = jnp.cov(samples, rowvar=False)

# 4. Compare and print the results
print("True Mean:\n", true_mean)
print("Empirical Mean:\n", empirical_mean)
print("\nTrue Covariance:\n", true_cov)
print("Empirical Covariance:\n", empirical_cov)

In [13]:
# Type of a mean function: maps N inputs of dimension D to N mean values.
MeanFunction = Callable[[Float[Array, "N D"]], Float[Array, " N"]]

@dataclass
class ZeroFunction:
    """Mean function that is zero everywhere: m(x) = 0."""

    def __call__(self, X: Float[Array, "N D"]) -> Float[Array, " N"]:
        return jnp.zeros(X.shape[0])

@dataclass
class ConstantFunction:
    """Mean function with the same value everywhere: m(x) = value."""

    # Typed as FloatType (Python float or 0-d JAX array), like the kernel
    # hyperparameters, so a JAX scalar is also accepted under runtime type checks.
    value: FloatType = 0.0

    def __call__(self, X: Float[Array, "N D"]) -> Float[Array, " N"]:
        return jnp.full(X.shape[0], self.value)

In [14]:
from abc import abstractmethod
from dataclasses import fields

KernelFunction = Callable[[Float[Array, "N D"], Float[Array, "M D"]], Float[Array, "N M"]]

def _sqeuclidean(X: Float[Array, "N D"], Y: Float[Array, "M D"]) -> Float[Array, "N M"]:
    """Return the (N, M) matrix whose (i, j) entry is ||X_i - Y_j||^2.

    X is broadcast to (N, 1, D) and Y to (1, M, D), so the intermediate
    difference tensor has N * M * D entries.
    """
    pairwise_diff = jnp.expand_dims(X, 1) - jnp.expand_dims(Y, 0)  # (N, M, D)
    return jnp.square(pairwise_diff).sum(axis=-1)                  # (N, M)

@dataclass
class Kernel:
    """Common base for the kernel dataclasses.

    Subclasses declare their hyperparameters as dataclass fields and implement
    ``__call__(X, Y)``, returning the (N, M) kernel matrix.
    """

    @classmethod
    def from_params(cls, params: dict[str, FloatType]) -> KernelFunction:
        """Build an instance of ``cls`` from a (possibly larger) parameter dict.

        Keys that are not dataclass fields of ``cls`` are silently dropped, so
        one shared dict (e.g. also holding noise parameters) can be passed to
        any kernel.
        """
        accepted = {spec.name for spec in fields(cls)}
        kwargs = {name: value for name, value in params.items() if name in accepted}
        return cls(**kwargs)

    # NOTE(review): Kernel does not use ABCMeta, so @abstractmethod is only a
    # marker here and does not prevent instantiating Kernel itself.
    @abstractmethod
    def __call__(self, X: Float[Array, "N D"], Y: Float[Array, "M D"]) -> Float[Array, "N M"]:
        """Compute the kernel matrix between two sets of input points X and Y."""
        raise NotImplementedError


@dataclass
class RBF(Kernel):
    """Radial basis function (squared exponential) kernel."""
    variance: FloatType = 1.0     # sigma^2: overall vertical scale
    lengthscale: FloatType = 1.0  # ell: horizontal scale ("wiggliness")

    def __call__(self, X: Float[Array, "N D"], Y: Float[Array, "M D"]) -> Float[Array, "N M"]:
        """
        Implement the RBF kernel.
        k(x, x') = variance * exp(-0.5 * ||x - x'||^2 / lengthscale^2)

        You can use the provided `_sqeuclidean` function to compute pairwise squared distances.
        The result must have shape (N, M): one row per point in X, one column per point in Y.
        """
        # YOUR CODE HERE
        raise NotImplementedError("You need to implement the RBF kernel!")


@dataclass
class Matern52(Kernel):
    """Matérn 5/2 kernel: variance * (1 + s + s^2 / 3) * exp(-s), s = sqrt(5) * r / lengthscale."""
    variance: FloatType = 1.0
    lengthscale: FloatType = 1.0

    def __call__(self, X: Float[Array, "N D"], Y: Float[Array, "M D"]) -> Float[Array, "N M"]:
        """Kernel matrix between X (N, D) and Y (M, D), of shape (N, M)."""
        # Defensive clamp before the square root; the squared distances are
        # sums of squares, so this never actually changes them.
        dist = jnp.sqrt(jnp.maximum(_sqeuclidean(X, Y), 0.0))
        scaled = jnp.sqrt(5.0) * dist / self.lengthscale
        polynomial = 1.0 + scaled + (scaled**2) / 3.0
        return self.variance * polynomial * jnp.exp(-scaled)


@dataclass
class Periodic(Kernel):
    """Periodic kernel: variance * exp(-2 * sum_d sin^2(pi * |x_d - y_d| / period) / lengthscale^2).

    For D > 1 the squared sines are summed per input dimension, rather than
    taking sin^2 of the Euclidean distance; for D = 1 both coincide.
    """
    variance: FloatType = 1.0
    lengthscale: FloatType = 1.0
    period: FloatType = 2.0 * jnp.pi

    def __call__(self, X: Float[Array, "N D"], Y: Float[Array, "M D"]) -> Float[Array, "N M"]:
        """Kernel matrix between X (N, D) and Y (M, D), of shape (N, M)."""
        abs_diff = jnp.abs(jnp.expand_dims(X, 1) - jnp.expand_dims(Y, 0))  # (N, M, D)
        sines = jnp.sin(jnp.pi * abs_diff / self.period)
        sin_sq_sum = jnp.sum(sines * sines, axis=-1)                          # (N, M)
        return self.variance * jnp.exp(-2.0 * sin_sq_sum / (self.lengthscale ** 2))
    
# Maps a display name to a kernel *class* (not an instance); consumers build
# instances themselves, e.g. via `Kernel.from_params`.
KERNEL_REGISTRY: dict[str, type[Kernel]] = {
    "RBF": RBF,
    "Matérn 5/2": Matern52,
    "Periodic": Periodic,
}

In [None]:
# TEST YOUR CODE
# Requires RBF.__call__: evaluate k(x, 0) for 500 grid points against a single
# point at the origin.
x_grid = jnp.linspace(-5, 5, 500)[:, None]
kernel = RBF()
values = kernel(x_grid, jnp.zeros((1, 1)))
print(values.shape) # Should be of shape (500, 1)

In [None]:
from chapter1 import InteractiveGPPlotter

# Interactive kernel explorer built from the registry above.
plotter = InteractiveGPPlotter(
    plot_type="kernel",
    kernel_registry=KERNEL_REGISTRY,
)
# Last expression of the cell, so Jupyter displays it (presumably an ipywidgets UI).
plotter.ui

In [17]:
@dataclass
class GaussianProcess:
    """
    A GP prior is fully specified by its mean and kernel function.
    m: (N, D) -> (N,)
    k: (N, D), (M, D) -> (N, M)
    """
    m: MeanFunction
    k: KernelFunction

    def __call__(self, X: Float[Array, "N D"]) -> "Gaussian":
        """
        Evaluate the GP prior at a given set of (grid) points.

        By definition, the function values of a GP at a finite set of points X
        are jointly Gaussian, with mean m(X) and covariance k(X, X).

        Task: Compute the mean and covariance.
        """
        # YOUR CODE HERE
        # raise NotImplementedError("You need to implement the __call__method!")

        # SOLUTION
        return Gaussian(mu=self.m(X), Sigma=self.k(X, X))

    def condition(
        self,
        y: Float[Array, " N"],
        X: Float[Array, "N D"],
        sigma2: float,
    ) -> "GaussianProcess":
        """Condition the GP on noisy observations y at inputs X.

        sigma2 is the observation-noise variance. Returns the posterior as a
        ConditionalGaussianProcess (a class implemented later in the notebook).
        """
        return ConditionalGaussianProcess(prior=self, y=y, X=X, sigma2=sigma2)

In [None]:
from chapter1 import plot_gp

# TEST YOUR CODE
# Requires RBF.__call__ (GaussianProcess.__call__ is already provided).
gp = GaussianProcess(m=ZeroFunction(), k=RBF(variance=1.0, lengthscale=1.0))
X_grid = jnp.linspace(-5, 5, 500)[:, None]
dist = gp(X_grid)
print(f"Mean shape: {dist.mu.shape}, Cov shape: {dist.Sigma.shape}") # Should be shape (500,) and (500, 500)

# ## TEST PLOT
# Uncomment to plot prior samples (n_samples suggests this needs Gaussian.sample).
# fig= plt.figure(figsize=(9, 4.5))
# ax = plt.gca()
# plot_gp(ax, gp, X_grid, n_samples=5, seed=0)
# plt.show()

Now that we've implemented the core components, let us bring them to life! The widget below allows you to draw function samples directly from the Gaussian Process prior you define.

**Things to Try**:
- **Kernel**: Switch between the `RBF`, `Matérn 5/2` and `Periodic` kernels. Notice the difference in smoothness?
- `lengthscale`: How does a small vs. a large lengthscale affect the functions?
- `variance`: What happens to the vertical spread of the functions when you change the variance?
- `period`: What happens here?

In [None]:
# Interactive prior-sample explorer: builds GaussianProcess(mean_fn, kernel)
# from the widget's kernel choice (presumably using Gaussian.sample to draw functions).
plotter = InteractiveGPPlotter(
    plot_type='prior',
    kernel_registry=KERNEL_REGISTRY,
    gp_cls=GaussianProcess,
    mean_fn=ZeroFunction(),    
)
# Last expression of the cell, so Jupyter displays the widget.
plotter.ui

#### Deriving the GP Posterior: Conditioning on Data

So far, we have a prior -- a distribution over functions before seeing any data. Now, we want to update this prior belief using our noisy observations
$(X, y)$ to get a posterior distribution by conditioning the GP on this data. For GPs, this is elegant because it is just conditioning a big joint Gaussian distribution.  

##### Setup
We start with our GP prior, $f \sim \mathcal{GP}(m,k)$. We consider now our training inputs $X$ (where we have data) and our test inputs $X_*$ (where we want to predict). The function 
values at all these points, $f(X)$ and $f(X_*)$, are jointly Gaussian:

$$
\begin{bmatrix} f(X) \\ f(X_*) \end{bmatrix} \sim \mathcal{N} \left( \begin{bmatrix} m(X) \\ m(X_*) \end{bmatrix}, \begin{bmatrix} K_{XX} & K_{X X_*} \\ K_{X_* X} & K_{X_* X_*} \end{bmatrix} \right)
$$

Here, $K_{AB}$ is the matrix of the kernel evaluated between all points in sets $A$ and $B$.

##### Observation Model
Usually, we do not observe the true function $f(X)$ directly. Instead, we see a noisy version $y = f(X) + \varepsilon$, where the noise $\varepsilon$ is typically Gaussian, $\varepsilon \sim \mathcal{N}(0, \sigma^2 I)$. 
This means our observations $y$ are also Gaussian:

$$ y \sim \mathcal{N}\big( m(X), K_{XX} + \sigma^2 I \big).$$

##### Joint Distribution of Data and Predictions

Now we can write the joint distribution of what we've seen ($y$) and what we want to predict ($f_*$):

$$
\begin{bmatrix} y \\ f_* \end{bmatrix} \sim \mathcal{N}\!\left(
\begin{bmatrix}
m(X) \\
m(X_*)
\end{bmatrix},
\begin{bmatrix}
K_{XX} + \sigma^2 I & K_{X X_*} \\
K_{X_* X} & K_{X_* X_*}
\end{bmatrix}
\right).
$$


##### Closed-form conditioning

For any pair of jointly Gaussian random vectors $(a, b)$, there is a standard formula for the conditional distribution $p(b \mid a)$. Applying this rule to our joint distribution (with $a = y$ and $b = f_*$) gives us the posterior predictive distribution $p(f_* \mid y)$.  

The posterior distribution for $f_*$ at test points $X_*$ is a new Gaussian distribution with the following mean and covariance:

$$
\boxed{
\begin{aligned}
\mu_{\text{post}}(X_*) &= m(X_*) + K_{X_*X}(K_{XX}+\sigma^2 I)^{-1}(y - m(X)), \\\\[6pt]
K_{\text{post}}(X_*,X_*) &= K_{X_*X_*} - K_{X_*X}(K_{XX}+\sigma^2 I)^{-1}K_{XX_*}.
\end{aligned}
}
$$

##### Implementation

As we learned before, we should avoid computing a matrix inverse directly. We can rewrite the update rules using a pre-computed vector $\alpha$ (representer weights) and the 
Cholesky factor $L$ of the training covariance matrix $(K_{XX} + \sigma^2 I)$. 

First, solve for a weight vector $\alpha$:
$$
\alpha = (K_{XX} + \sigma^2 I)^{-1} (y - m(X)).
$$
Then, the posterior mean is simply a weighted combination of kernel similarities:
$$
\mu_{\text{post}}(X_*) = m(X_*) + K_{X_*X}\,\alpha.
$$
The posterior covariance can be computed similarly using efficient triangular solves with the Cholesky factor $L$.

Let us now implement a conditional Gaussian Process.

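TASK: implement the posterior mean `m_post` and the posterior covariance `k_post` in the class below.
- Tip: You can use `jax.scipy.linalg.solve_triangular` with the pre-computed Cholesky factor `self.L` to apply $(K_{XX} + \sigma^2 I)^{-1}$ via two triangular solves, without ever forming the inverse explicitly.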

In [20]:
@dataclass
class ConditionalGaussianProcess(GaussianProcess):
    """
    Represents the GP posterior distribution after conditioning on data.

    The hand-written ``__init__`` below takes precedence over the one
    ``@dataclass`` would generate. The inherited fields ``m`` and ``k`` are set
    to the posterior mean and covariance functions via ``super().__init__``.
    """
    def __init__(
        self,
        prior: GaussianProcess,
        y: Float[Array, " N"],
        X: Float[Array, "N D"],
        sigma2: float,
        jitter: float = 1e-9,
    ):
        self.prior = prior
        self.X = X
        self.y = y
        self.sigma2 = sigma2
        self.jitter = jitter

        # Pre-compute shared quantities for efficient prediction.
        # Kxx = K(X, X) + (sigma^2 + jitter) I is the covariance of the noisy observations.
        Kxx = prior.k(self.X, self.X) + (sigma2 + jitter) * jnp.eye(self.X.shape[0])
        self.L = jax.scipy.linalg.cholesky(Kxx, lower=True)
        resid = self.y - prior.m(self.X)
        # Representer weights alpha = Kxx^{-1} (y - m(X)). Reuse the Cholesky
        # factor (two triangular solves) instead of a second, general O(N^3)
        # solve; this is also the numerically more stable route.
        self.alpha = jax.scipy.linalg.cho_solve((self.L, True), resid)

        def m_post(Xstar: Float[Array, "Nstar D"]) -> Float[Array, " Nstar"]:
            """
            Calculate the posterior mean.
            m_post(Xstar) = m(Xstar) + K(Xstar, X) @ alpha

            Here m and K are the prior mean and kernel (self.prior.m, self.prior.k).

            Your task:
            Compute the posterior mean at new input locations Xstar.
            """
            # YOUR CODE HERE
            raise NotImplementedError("You need to implement the posterior mean fucntion!")

        def k_post(
            Xstar1: Float[Array, "Nstar1 D"],
            Xstar2: Float[Array, "Nstar2 D"] | None = None,
        ) -> Float[Array, "Nstar1 Nstar2"]:
            """
            Calculate the posterior covariance.
            K_post(A, B) = K(A, B) - K(A, X) @ (Kxx)^-1 @ K(X, B),  with A = Xstar1, B = Xstar2

            Here K is the prior kernel (self.prior.k) and Kxx includes the noise
            variance and jitter (Kxx = L @ L.T). Xstar2 defaults to None; a natural
            convention is to treat None as Xstar2 = Xstar1.

            Your task:
            1. Use jax.scipy.linalg.solve_triangular with self.L to solve for v = L^-1 @ K(X, B).
            Then solve for w = (L^T)^-1 @ v.
            2. Compute the final result using the formula.
            """
            # YOUR CODE HERE
            raise NotImplementedError("You need to implement the posterior covariance function!")

        super().__init__(m=m_post, k=k_post)

In [None]:
# TEST YOUR CODE
# Requires RBF.__call__, Gaussian.sample and the posterior m_post / k_post.

# 1. Define a GP Prior
gp_prior = GaussianProcess(m=ZeroFunction(), k=RBF())

# 2. Create some noisy training data (one prior draw at three inputs, plus noise)
X_train_test = jnp.array([[-4.0], [-1.0], [2.5]])
y_true_test = gp_prior(X_train_test).sample(jax.random.key(1), num_samples=1).flatten()
noise = 0.1  # noise standard deviation
y_train_test = y_true_test + noise * jax.random.normal(jax.random.key(2), y_true_test.shape)

# 3. Condition the prior on the data to get the posterior
# condition() expects the noise *variance*, hence noise**2.
gp_posterior = gp_prior.condition(y=y_train_test, X=X_train_test, sigma2=noise**2)

# 4. Let's check the output shapes on a grid
X_grid = jnp.linspace(-5, 5, 500)[:, None]
dist_post = gp_posterior(X_grid)
print(f"Posterior Mean shape: {dist_post.mu.shape}, Posterior Cov shape: {dist_post.Sigma.shape}") # Should be (500,) and (500, 500)

# # 5. Plot the posterior belief
# fig, ax = plt.subplots(figsize=(9, 4.5))
# plot_gp(
#     ax,
#     gp_posterior,
#     X_grid,
#     n_samples=5,
#     seed=42,
#     mean_label="Posterior Mean",
#     obs=(X_train_test, y_train_test)
# )
# ax.set_title("GP Posterior after Conditioning on Data")
# plt.show()

#### How Good is Our Model?

Before we start exploring the posterior, let us think about how to measure success. A plot gives us a good feeling, but a number is better for comparing different models. To do
so we want to introduce two useful metrics to tell us quantitatively how well our GP is doing. This will be crucial for tuning hyperparameters or kernel choices later on.

We will look at two key metrics: one for **accuracy** and one for **quality of uncertainty**.

1. **Root Mean Squared Error (RMSE)**: The most common way to measure the average prediction error is the **RMSE**. It tells you on average how far off the model's mean prediction is from the true value. Of course, lower is better. $$ \text{RMSE} \;=\; \sqrt{\frac{1}{n} \sum_{i=1}^n \big( \hat{y}_i - y_i \big)^2 }. $$
2. **Mean Log Predictive Density (MLPD)**: With GPs we have a whole predictive distribution for each point. The **MLPD** measures how well this distribution explains the data we actually saw. We then average the log of the pointwise probabilities. A higher MLPD is better. The metric rewards a model that correctly expresses high uncertainty when it is not sure. $$ \text{MLPD} \;=\; \frac{1}{n} \sum_{i=1}^n \log \mathcal{N}\!\big(y_i \,;\, \mu_i, \sigma_i^2\big). $$

In the next cell, you will see both of these metrics in action when conditioning on the training data specified at the beginning of the tutorial. Try to find the settings that give you the best RMSE and the best MLPD. Are they always the same?

In [None]:
# Interactive posterior explorer on the oscillator data; presumably reports the
# RMSE / MLPD metrics described above against (X_test, y_test).
plotter = InteractiveGPPlotter(
    plot_type='posterior',
    kernel_registry=KERNEL_REGISTRY,
    gp_cls=GaussianProcess,
    mean_fn=ZeroFunction(),    
    training_data=(X_train, y_train),
    test_data=(X_test, y_test)
)
# Last expression of the cell, so Jupyter displays the widget.
plotter.ui

### Finding the Best Hyperparameters with the Marginal Log-Likelihood

So far, we have been moving sliders to tune our GP's hyperparameters (`lengthscale`, `variance`, `sigma_noise`). This is a good way to build intuition, but it is not practical. How can we find the best values automatically? 
A powerful principle is to choose the hyperparameters that make our observed data most plausible/likely.

##### The Marginal Log-Likelihood (MLL)

The quantity we want to maximize is the Marginal Log-Likelihood (MLL). It is the answer to the question: Under our GP model with a given set of hyperparameters $\theta$, what was the probability of seeing the exact training data $y$ that we observed? 
The term "marginal" is used because we do not care about the specific underlying latent function $f$. We integrate it out to obtain the marginal probability of the data $y$. $$ p(y \mid X, \theta) = \int p(y \mid f)\, p(f \mid X,\theta)\, df. $$

For a GP, this probability has a convenient formula. We work with its logarithm for numerical stability and mathematical convenience:
$$
\underbrace{\log p(y \mid X, \theta)}_{\text{Log-Likelihood}}
= \underbrace{-\tfrac{1}{2}(y-m)^T (K_\theta + \sigma^2 I)^{-1} (y-m)}_{\text{Data Fit Term}}
\;+\; \underbrace{-\tfrac{1}{2}\log \text{det} \big( K_\theta + \sigma^2 I \big)}_{\text{Complexity Penalty}}
\;+\; \underbrace{-\tfrac{n}{2}\log 2\pi}_{\text{Constant}}.
$$

Maximizing the MLL involves a trade-off between two competing terms, which perfectly embodies the principle of Occam's Razor: 

1. **Data Fit Term (Goodness-of-Fit)** This term is just what you would expect: it gets better (less negative) when the model's predictions are closer to the actual data points. If this were the only term, the model would overfit wildly to match the data perfectly.

2. **Complexity Penalty** This term prevents overfitting. It penalizes models that are too complex. A model with a small `lengthscale`, for example, is very "complex" because it can wiggle a lot to fit any data. The GP prior says such functions are inherently less likely, and the log term reflects this by becoming large, which penalizes the overall MLL.

The MLL is maximized for the simplest model that still explains the data well.

In [None]:
from scipy.optimize import minimize

def untyped(func):
    """Return the undecorated original of a beartype-wrapped function.

    Looks up the ``__beartype_func`` attribute on ``func``; when it is absent,
    ``func`` itself is returned unchanged.

    NOTE(review): confirm the installed beartype version actually sets
    ``__beartype_func``; wrappers built with ``functools.wraps`` expose the
    original as ``__wrapped__`` instead.
    """
    try:
        return getattr(func, "__beartype_func")
    except AttributeError:
        return func

class HyperparameterTuner:
    """Handles the logic of optimizing GP hyperparameters.

    Parameters are optimized as one flat vector. Entries whose name starts with
    ``log_`` are optimized in log space and exponentiated before being handed
    to the kernel under the name without the prefix; all other entries are
    used as-is. The observation noise is always optimized as ``log_noise``
    (the log of its standard deviation).
    """

    def __init__(self, X_train: jax.Array, y_train: jax.Array):
        # Keep the training data in float64, matching the x64 setting of the notebook.
        self.X = jnp.asarray(X_train, dtype=jnp.float64)
        self.y = jnp.asarray(y_train, dtype=jnp.float64)

    def _get_opt_keys(self, kernel_name: str, param_init: dict) -> list[str]:
        """Return the ordered names of the parameters to optimize for a kernel.

        These are the kernel's entries in ``param_init`` followed by
        ``"log_noise"``. ``"log_noise"`` is appended only if the registry entry
        does not already contain it, so no parameter is optimized twice.
        """
        opt_keys = list(param_init.get(kernel_name, {}))
        if "log_noise" not in opt_keys:
            opt_keys.append("log_noise")
        return opt_keys

    def _pack_params(self, params_dict: dict, opt_keys: list[str]) -> np.ndarray:
        """Pack a parameter dict into a float64 vector ordered by ``opt_keys``.

        Names missing from ``params_dict`` default to 0.0.
        """
        return np.array([params_dict.get(k, 0.0) for k in opt_keys], dtype=np.float64)

    def _unpack_params(self, theta: np.ndarray | jax.Array, opt_keys: list[str]) -> dict:
        """Inverse of ``_pack_params``: map each name in ``opt_keys`` to its entry of ``theta``."""
        return dict(zip(opt_keys, theta))

    @staticmethod
    def _from_log_space(params: dict, exp) -> dict:
        """Strip ``log_`` prefixes and exponentiate those values with ``exp``.

        ``exp`` is ``jnp.exp`` inside the traced objective and ``np.exp`` when
        post-processing the optimizer result; other entries pass through.
        """
        return {
            name.removeprefix("log_"): exp(value) if name.startswith("log_") else value
            for name, value in params.items()
        }

    def _make_mll_objective(self, kernel_cls: type[Kernel], opt_keys: list[str]):
        """Create ``value_and_grad`` of the negative MLL as a function of the flat parameter vector."""
        @jax.jit
        def neg_mll(theta_values: jax.Array) -> jax.Array:
            # 1. Unpack the flat vector and transform to the natural scale.
            params = self._unpack_params(theta_values, opt_keys)
            log_noise = params.pop("log_noise")
            sigma2 = jnp.exp(2.0 * log_noise)  # noise variance = exp(log_noise) ** 2
            kernel_params = self._from_log_space(params, jnp.exp)

            # 2. Build the kernel; from_params ignores names this kernel does not use.
            kernel = kernel_cls.from_params(kernel_params)

            # TODO: Implement the MLL calculation using the steps outlined below,
            # then delete this `raise`.
            raise NotImplementedError("You need to implement the negative MLL objective!")

            # 3. Define the GP prior (using GaussianProcess class)

            # 4. Evaluate the prior at the training points.

            # 5. The distribution of y is N(mu, K_theta + sigma^2 * I) - use log_pdf
            mll = ...
            return -mll

        return jax.value_and_grad(neg_mll)

    def fit(self, kernel_name: str, kernel_registry: dict, param_init_registry: dict) -> dict:
        """Optimize one kernel's hyperparameters by minimizing the negative MLL with L-BFGS-B.

        Args:
            kernel_name: key into both registries.
            kernel_registry: maps kernel names to kernel classes.
            param_init_registry: maps kernel names to dicts of initial values
                (a ``log_`` prefix means the value is optimized in log space).

        Returns:
            A dict with the scipy ``OptimizeResult`` ("result"), the optimized
            values as the optimizer sees them ("optimized_params_log"), and the
            values on their natural scale ("final_params", which also contains
            the noise std "noise" and its variance "sigma2").
        """
        # 1. Look up the specific kernel class and params from the registries
        kernel_cls = kernel_registry[kernel_name]
        base_params = param_init_registry.get(kernel_name, {}).copy()

        # 2. Setup initial parameters and optimization keys.
        # Default noise initialisation: 10% of the standard deviation of the targets.
        base_params.setdefault("log_noise", jnp.log(0.1 * jnp.std(self.y)))
        opt_keys = self._get_opt_keys(kernel_name, param_init_registry)
        theta0 = self._pack_params(base_params, opt_keys)

        # 3. Create the objective function and its gradient
        mll_value_and_grad = self._make_mll_objective(kernel_cls, opt_keys)

        def objective_for_scipy(theta: np.ndarray):
            # scipy passes/expects float64 NumPy arrays; JAX works on jax.Arrays.
            val, grad = mll_value_and_grad(jnp.asarray(theta))
            return np.asarray(val, dtype=np.float64), np.asarray(grad, dtype=np.float64)

        # 4. Run the optimizer (jac=True: the objective returns (value, gradient)).
        res = minimize(fun=objective_for_scipy, x0=theta0, jac=True, method="L-BFGS-B")

        # 5. Map the optimized parameters back to their natural scale.
        optimized_params_log = self._unpack_params(res.x, opt_keys)
        final_params = self._from_log_space(optimized_params_log, np.exp)
        # "noise" (from "log_noise") is a standard deviation; also expose the variance.
        final_params["sigma2"] = final_params["noise"] ** 2

        return {
            "result": res,
            "optimized_params_log": optimized_params_log,
            "final_params": final_params,
        }

In [None]:
from chapter1 import OptimizationWidget

# 1. Initial hyperparameters for every kernel in the registry.
#    A `log_` prefix means the parameter is stored and optimized in log-space.
PARAMETER_INIT = {
    "RBF": dict(log_variance=0.0, log_lengthscale=0.0),
    "Matérn 5/2": dict(log_variance=0.0, log_lengthscale=0.0),
    "Periodic": dict(log_variance=0.0, log_lengthscale=0.0, period=np.pi * 2),
}

# 2. A tuner that only sees the first `n_obs` training points.
n_obs = 10
tuner = HyperparameterTuner(
    X_train=X_train[:n_obs, :],
    y_train=y_train[:n_obs],
)

# 3. The widget, with every dependency injected explicitly.
optimization_widget = OptimizationWidget(
    tuner=tuner,
    kernel_registry=KERNEL_REGISTRY,
    param_init_registry=PARAMETER_INIT,
    test_data=(X_test, y_test),
    extrapolate=False,
    gp_cls=GaussianProcess,
)

# 4. Display the widget. This MUST be the last line in the cell.
optimization_widget.ui

### Design Your Own Kernel!

Now it is your turn to design a kernel. You can build custom kernels that capture complex, real-world structure by combining simpler ones. This is your chance to get creative and see if you can build a kernel that better explains our damped oscillator data. Your goal is to achieve the lowest possible negative marginal log-likelihood (negative MLL). The person with the lowest negative MLL wins!

#### Some helpful rules of Kernel Algebra.

Just like numbers, kernels have an algebra. If you have two valid kernels, $k_1$ and $k_2$, you can combine them in the following ways to create a new, valid kernel:

- **Addition:** The sum of two kernels is also a valid kernel: $$k_{\text{sum}}(x, x') = k_1(x, x') + k_2(x, x')\,.$$
- **Multiplication:** The product of two kernels is also a valid kernel: $$k_{\text{prod}}(x, x') = k_1(x, x')\, k_2(x, x')\,.$$

In [None]:
# 1. DEFINE YOUR CUSTOM KERNEL
@dataclass
class LinearKernel(Kernel):
    """
    Linear (dot-product) kernel: k(x, y) = variance * x^T y.

    It encodes the assumption that the function is a straight line through
    the origin, with `variance` scaling the slope's prior spread.
    """

    variance: FloatType = 1.0

    def __call__(self, X: Float[Array, "N D"], Y: Float[Array, "M D"]) -> Float[Array, "N M"]:
        # Matrix of pairwise inner products; for 1D inputs this is simply x * y.
        gram = X @ Y.T
        return self.variance * gram

# 2. ADD YOUR KERNEL TO THE REGISTRIES
# Initial values: a `log_` prefix means the parameter is optimized in
# log-space, so log(1.0) corresponds to variance = 1.0.
KERNEL_REGISTRY["Linear"] = LinearKernel
PARAMETER_INIT["Linear"] = {"log_variance": jnp.log(1.0)}

# 3. RUN THE WIDGET (this time on the full training set)
tuner = HyperparameterTuner(X_train=X_train, y_train=y_train)
widget = OptimizationWidget(
    tuner=tuner,
    kernel_registry=KERNEL_REGISTRY,
    param_init_registry=PARAMETER_INIT,
    test_data=(X_test, y_test),
    gp_cls=GaussianProcess,
)
widget.ui

## Part 2: Physics-Informed Regression (1D Poisson with CPU toy model)

#### Conditioning on Physics

In Part 1, we crafted a special, physics-informed kernel from the analytical solution of an ODE. This is elegant, but what if an analytical solution isn't available?

In this part, we explore a more general and powerful technique: we start with a generic kernel (like the Matérn kernel) that only assumes smoothness, and then we force its samples to obey the physics by directly conditioning the GP on the differential equation itself.

The core idea: We will treat the PDE as a source of data. Instead of only having data points of the form `(location, temperature)`, we will create new, virtual data points of the form `(location, value of the PDE's right-hand-side)`. This allows us to inject our physical knowledge directly into the standard GP regression framework.

#### The Toy Problem: 1D Heat Flow in a CPU

We will model the steady-state temperature distribution $u(x)$ along a 1D slice of a CPU bar. The physics is governed by the Poisson equation, a cornerstone of heat transfer and electrostatics:

$$ -\kappa \frac{d^2 u}{dx^2}(x) = \dot q_V(x) $$

Where:

- $u(x)$ is the temperature. 
- $\kappa$ is the material's thermal conductivity.
- $-\kappa \tfrac{d^2u}{dx^2}$ represents the net heat leaving a point due to conduction. 
- $\dot q_V(x)$ is the net volumetric heat source: heat generated by the CPU cores (source) minus heat lost to the environment (sink).

Why is this a good approach?

By conditioning on the PDE, we:

- inject strong prior knowledge about the system's behavior.

- achieve high data efficiency, requiring fewer physical sensors.

- obtain physically consistent predictions, especially when extrapolating outside the data range.

Let us see this in action. We will start with a generic prior and progressively add information.

In [None]:
from chapter2 import create_cpu_problem, plot_gp_belief_and_pde

# Toy problem: 1D steady-state heat flow along a CPU bar.
problem = create_cpu_problem()

# Prior belief: constant mean of 60.0 and a Matérn-5/2 kernel whose
# lengthscale is three quarters of the bar's width.
prior_kernel = Matern52(variance=3.0, lengthscale=0.75 * problem.width)
prior_gp = GaussianProcess(m=ConstantFunction(value=60.0), k=prior_kernel)

# 400 evenly spaced plotting locations across the bar.
X_grid = np.linspace(0.0, problem.width, 400)

# No conditions yet: compare the raw prior against the true physics.
plot_gp_belief_and_pde(
    gp=prior_gp,
    problem=problem,
    X_grid=X_grid,
    conditions=None,
    n_samples=3,
    title="Prior Belief vs. True CPU Physics",
)

#### Linear observations

In "vanilla" GP regression, we condition on observations of the function itself. The key insight for physics-informed GPs is that if $L$ is a linear differential operator (like our $L = -\kappa \frac{d^2}{dx^2}$), we can treat $(L_x u)(x)$ as a new set of observations.

Because a GP is closed under linear operations, if our prior belief about the temperature $u$ is a GP:

$$ u(x) \sim \mathcal{GP}(m(x), k(x,x')) $$

Then our belief about the operator image, $L_xu(x)$ is also a Gaussian Process!

$$ (L_x u)(x) \sim \mathcal{GP}((L_xm)(x), (L_xL_{x'}k)(x, x')) $$

This means we can use the standard rules of GP conditioning on this new, "virtual" data.


#### Building the Joint Distribution

The process is the same as for standard regression: we build a joint Gaussian distribution over what we want to predict (the temperature $u(X)$) and what we are observing (the PDE's value at a set of "collocation points," $X_{\text{pde}}$). 

Our "observation" is that $(Lu)(X_{\text{pde}})$ should equal the known heat source $\dot q_V(X_{\text{pde}})$. We can write this as a set of linear observations with a small noise tolerance $\varepsilon$:
$$\dot{q}_V(X_{\text{pde}}) = (L u)(X_{\text{pde}}) + \varepsilon, \qquad \varepsilon \sim \mathcal{N}(0, \sigma^2_{\text{pde}} I)$$

The joint distribution is:
$$
\begin{bmatrix}
u(X) \\[2pt]
(L u)(X_{\text{pde}})
\end{bmatrix}
\sim
\mathcal{N}\!\left(
\begin{bmatrix}
\text{Prior Mean} \\
\text{Prior Mean of PDE}
\end{bmatrix},
\begin{bmatrix}
\text{Cov}(u, u) & \text{Cov}(u, Lu) \\[2pt]
\text{Cov}(Lu, u) & \text{Cov}(Lu, Lu)
\end{bmatrix}
\right)
$$

This looks complex, but it's just the standard GP setup. Each block in the covariance matrix is derived from our prior kernel, $k(x, x')$, by applying the operator $L$ to its arguments. For our 1D heat equation ($L = -\kappa \frac{d^2}{dx^2}$):

- $\text{Cov}(u, u) = k(X, X)$
- $\text{Cov}(u, Lu) = (L_{x'}k)(X, X_{\text{pde}}) = -\kappa \frac{\partial^2}{\partial {x'}^2} k(X, X_{\text{pde}})$
- $\text{Cov}(Lu, Lu) = (L_x L_{x'}k)(X_{\text{pde}}, X_{\text{pde}}) = \kappa^2 \frac{\partial^4}{\partial x^2 \partial {x'}^2} k(X_{\text{pde}}, X_{\text{pde}})$

Once this joint distribution is built, we can apply the standard Gaussian conditioning formulas to get the posterior mean and covariance.


#### Doing the Math: The Covariances for a Matérn-5/2 Kernel

The theory is general, but to implement it, we need to compute the specific covariance terms like $\text{Cov}(u, Lu)$ and $\text{Cov}(Lu, Lu)$. This involves taking derivatives of our chosen prior kernel. Let's do this for a popular choice: the **Matérn-5/2 kernel**. Its twice-differentiable nature makes it a great candidate for modeling systems governed by second-order PDEs.

We start with the 1D isotropic Matérn-5/2 kernel, which is a function of the distance $r = |x - y|$:
$$k(x,y) = \phi(r) = \sigma^2\left(1 + a r + \frac{a^2 r^2}{3}\right)e^{-a r}, \quad \text{where} \quad a = \frac{\sqrt{5}}{\ell}$$

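Our operator is $L = -\kappa \frac{d^2}{dx^2}$. To find the necessary covariance functions, we need to apply this operator to the kernel. In 1D, the kernel depends only on $d = x - y$, and $\phi$ is even. By the chain rule, each $\partial^2/\partial x^2$ or $\partial^2/\partial y^2$ therefore becomes a second derivative of $\phi$ with respect to $r$; the sign flips from differentiating with respect to $y$ cancel in pairs. This also holds at $r = 0$, because $\phi'(0) = \phi'''(0) = 0$.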

The required derivatives of $\phi(r)$ are:
$$
\begin{aligned}
\phi''(r) &= \frac{a^2\sigma^2(a^2 r^2 - a r - 1)}{3}\,e^{-a r} \\
\phi^{(4)}(r) &= \frac{a^4\sigma^2(a^2 r^2 - 5 a r + 3)}{3}\,e^{-a r}
\end{aligned}
$$

Now we can construct the two key covariance blocks we need for our joint distribution:

1.  **The Cross-Covariance:** This term, $\text{Cov}(u(x), (Lu)(y))$, measures the correlation between the function at one point and the PDE's value at another.
    $$\text{Cov}(u, Lu) = L_y k(x,y) = -\kappa \frac{\partial^2}{\partial y^2}k(x,y) = -\kappa\,\phi''(|x-y|)$$

2.  **The PDE Covariance:** This term, $\text{Cov}((Lu)(x), (Lu)(y))$, forms the Gram matrix for our "PDE observations."
    $$\text{Cov}(Lu, Lu) = L_x L_y k(x,y) = \kappa^2 \frac{\partial^4}{\partial x^2 \partial y^2}k(x,y) = \kappa^2\,\phi^{(4)}(|x-y|)$$

With these two closed-form expressions, we have everything we need to build the full covariance matrix for conditioning.

#### Stacking It All Up: Combining All Your Knowledge

In a real problem, we don't just have one source of information. We might have:
- **Noisy sensor data:** $y_{\text{data}} = u(X_{\text{data}}) + \epsilon$
- **PDE constraints:** $y_{\text{pde}} = (Lu)(X_{\text{pde}}) + \epsilon$
- **Boundary conditions:** $y_{\text{bc}} = u(X_{\text{boundary}}) + \epsilon$

Each of these is a linear observation of our underlying function $u$. The most robust and stable way to combine them is to **stack them into a single, large linear system**. The core of GP conditioning is building a joint Gaussian distribution between what you want to predict and what you have observed. When we have multiple types of observations, we build a joint distribution over all of them.

This results in a "master" covariance matrix that has a **block structure**:

$$y_{\text{all}} = \begin{bmatrix} y_{\text{data}} \\ y_{\text{pde}} \\ y_{\text{bc}} \end{bmatrix}, \qquad K_{YY, \text{all}} = \begin{bmatrix} K_{\text{data,data}} & K_{\text{data,pde}} & K_{\text{data,bc}} \\ K_{\text{pde,data}} & K_{\text{pde,pde}} & K_{\text{pde,bc}} \\ K_{\text{bc,data}} & K_{\text{bc,pde}} & K_{\text{bc,bc}} \end{bmatrix}$$

- **Diagonal Blocks** (e.g., $K_{\text{data}, \text{data}}$): These are the standard Gram matrices you already know. They measure the internal correlation within a single type of observation (e.g., how two sensor readings relate to each other).

- **Off-Diagonal Blocks** (e.g., $K_{\text{data}, \text{pde}}$): This is the new, crucial part! These are the cross-covariance matrices. They measure the relationship between different types of observations (e.g. how a sensor reading at $x_1$ relates to the PDE being satisfied at $x_2$).

Then, we solve this single, larger system to get the final posterior.

In [30]:
import jax.scipy.linalg as jsl
from plum import dispatch

@dataclass
class LinearObservation:
    """
    Base class for a linear observation operator applied to the latent function f.

    Given (prior GP, X, y), a subclass produces one observation block:
        - mY: prior mean of the observations, shape (N,)
        - KYY_tilde: observation Gram matrix plus observation noise, shape (N, N)
        - KfY_fn: function X* -> Cov(f(X*), Y), returning shape (M, N) for M query points
        - y: the provided targets, returned unchanged so the conditioner can stack them

    Subclasses must override ``__call__``. Covariances between *different* blocks
    are supplied separately through the ``cross_covariance`` dispatch registry.
    """

    def __call__(
        self,
        prior,                      # GaussianProcess
        X: Float[Array, "N D"],
        y: Float[Array, "N"],
    ) -> tuple[
        Float[Array, "N"],
        Float[Array, "N N"],
        Callable[[Float[Array, "M D"]], Float[Array, "M N"]],
        Float[Array, "N"]
    ]:
        raise NotImplementedError()

@dataclass
class Block:
    """One registered observation block, as stored by ``LazyConditionalGaussianProcess``."""

    op: LinearObservation                  # operator that produced this block
    X: Float[Array, "N D"] | None          # observation locations
    y_vec: Float[Array, "N"] | None        # observed targets
    mY: Float[Array, "N"] | None           # prior mean of the observations
    KYY_tilde: Float[Array, "N N"] | None  # observation Gram matrix incl. noise
    KfY_fn: Callable[[Float[Array, "M D"]], Float[Array, "M N"]] | None  # X* -> Cov(f(X*), Y)

@dispatch
def cross_covariance(op1: LinearObservation, op2: LinearObservation, prior_k: KernelFunction) -> KernelFunction:
    """
    Generic fallback of the cross-covariance registry.

    Concrete operator pairs register their own implementation with
    ``@cross_covariance.dispatch``; landing here means the pair is unsupported.
    """
    raise NotImplementedError(
        f"cross-covariance not implemented for {type(op1)} and {type(op2)}"
    )

@dataclass
class LazyConditionalGaussianProcess:
    """
    Collect linear observation blocks and condition a GP on all of them jointly.

    Blocks are registered with ``add_condition``. Nothing is solved until
    ``condition`` stacks every block into one joint Gaussian and performs a
    single Cholesky-based conditioning step.

    Attributes:
        prior: The prior GaussianProcess; its ``m`` and ``k`` enter the posterior.
        jitter: Added once to the diagonal of each block's Gram matrix when the
            block is registered, to stabilize the Cholesky factorization.
        _blocks: Registered blocks in registration order (``None`` until the
            first ``add_condition`` call).
    """

    prior: GaussianProcess
    jitter: float = 1e-6
    _blocks: list[Block] = None

    def add_condition(self, X: Float[Array, "N D"], y: Float[Array, "N"], op: LinearObservation):
        """
        Register a linear observation block defined by `op` at locations X with targets y.
        Nothing is solved yet - we only collect blocks.

        The block's Gram matrix is regularized here with ``jitter * I``. The identity
        of the full stacked matrix is block-diagonal, so this equals adding
        ``jitter * I`` to the full matrix; ``condition`` therefore does not add it again.
        """
        mY, KYY_tilde, KfY_fn, y_vec = op(self.prior, X, y)
        KYY_tilde = KYY_tilde + self.jitter * jnp.eye(KYY_tilde.shape[0], dtype=KYY_tilde.dtype)

        if self._blocks is None:
            self._blocks = []
        self._blocks.append(
            Block(op=op, X=X, mY=mY, KYY_tilde=KYY_tilde, KfY_fn=KfY_fn, y_vec=y_vec)
        )

    def _compute_cross_covariance(self, block_i: Block, block_j: Block):
        """Return the noise-free off-diagonal block Cov(Y_i, Y_j) from the dispatch registry."""
        return cross_covariance(block_i.op, block_j.op, self.prior.k)(block_i.X, block_j.X)

    def condition(self) -> GaussianProcess:
        """
        Perform one joint conditioning step over all collected blocks and return a posterior
        GaussianProcess with generic (non-stationary) kernel.

        Returns the prior unchanged if no blocks are registered. The returned posterior is
        a snapshot: blocks registered afterwards do not affect it.
        """
        if not self._blocks:
            return self.prior

        # Snapshot the current state. The closures below must stay consistent with
        # the Cholesky factor computed here, even if more blocks are added later.
        prior = self.prior
        blocks = list(self._blocks)
        kfy_fns = [b.KfY_fn for b in blocks]

        # Stack means and observations in block order.
        mY_all = jnp.concatenate([b.mY.reshape(-1) for b in blocks])
        y_all = jnp.concatenate([b.y_vec.reshape(-1) for b in blocks])

        # Full KYY. Diagonal blocks already carry noise + jitter; off-diagonal blocks
        # are the pure cross-covariances between different observation blocks.
        KYY_all = jnp.block([
            [
                bi.KYY_tilde if i == j else self._compute_cross_covariance(bi, bj)
                for j, bj in enumerate(blocks)
            ]
            for i, bi in enumerate(blocks)
        ])

        # One Cholesky factorization for all constraints.
        L = jsl.cholesky(KYY_all, lower=True)

        def _apply_KYY_inv(B: Float[Array, "N_total M"]) -> Float[Array, "N_total M"]:
            """Return KYY_all^{-1} B via two triangular solves with the Cholesky factor."""
            v = jsl.solve_triangular(L, B, lower=True)
            return jsl.solve_triangular(L.T, v, lower=False)

        alpha = _apply_KYY_inv(y_all - mY_all)

        def KfY_stacked(X: Float[Array, "M D"]) -> Float[Array, "M N_total"]:
            """Horizontally stack Cov(f(X), Y_b) over all snapshotted blocks."""
            X = jnp.atleast_2d(X)
            return jnp.concatenate([fn(X) for fn in kfy_fns], axis=1)

        def m_post(Xstar: Float[Array, "M D"]) -> Float[Array, " M"]:
            Xstar = jnp.atleast_2d(Xstar)
            return prior.m(Xstar) + KfY_stacked(Xstar) @ alpha

        def k_post(Xa: Float[Array, "M D"], Xb: Float[Array, "P D"]) -> Float[Array, "M P"]:
            # Normalize inputs the same way KfY_stacked does, so both terms agree in shape.
            Xa = jnp.atleast_2d(Xa)
            Xb = jnp.atleast_2d(Xb)
            return prior.k(Xa, Xb) - KfY_stacked(Xa) @ _apply_KYY_inv(KfY_stacked(Xb).T)

        return GaussianProcess(m=m_post, k=k_post)

### Sensor Observations

Let us start by creating our first and most fundamental brick: a block for standard sensor observations.

We will call this block `SensorObservation`. It represents noisy, direct measurements of the temperature. It handles observations of the form:

$$ y_{\text{data}} = u(X_{\text{data}}) + \varepsilon, \qquad \varepsilon \sim \mathcal{N}(0, \sigma^2 I) $$

It needs to be able to compute three things:

1. The prior mean of the observations: $$ \mathbb{E}[u(X)] $$
2. The prior covariance among the observations (its Gram matrix): $$\text{Cov}(u(X), u(X))$$
3. The cross-covariance between a new prediction point $X^*$ and the observations: $$ \text{Cov}(u(X^*), u(X)) $$

In [31]:
@dataclass
class SensorObservation(LinearObservation):
    r"""Noisy point observations: y = u(X) + eps, eps \sim N(0, sigma^2 I)."""

    sigma2: float  # noise *variance* sigma^2 (callers pass e.g. noise_std**2)

    # This is a helper for the cross-covariance registry below.
    def L_k(self, prior_k, X, X_other):
        """Computes cov(u(X), u(X_other)) = k(X, X_other)."""
        return prior_k(X, X_other)

    def __call__(
        self,
        prior,  # GaussianProcess (provides .m and .k), as in LinearObservation.__call__
        X: Float[Array, "N 1"],
        y: Float[Array, " N"],
    ):
        """
        Build the observation block for direct, noisy measurements of u.

        Returns:
            (mY, KYY_tilde, KfY_fn, y): prior mean (N,), Gram matrix incl. noise
            (N, N), a callable Xstar -> Cov(u(Xstar), u(X)) of shape (Nstar, N),
            and the unchanged targets y.
        """

        # YOUR CODE HERE
        # TODO: Implement the three key components for a standard GP observation.
        #       Hint: `prior` is a GaussianProcess; use `prior.m` and `prior.k`.

        # 1. Calculate `mY`, the prior mean at the observation points `X`. 
        # 2. Calculate `KYY_tilde`, the prior covariance at `X`, and remember to add the sensor noise variance `self.sigma2`.
        # 3. Define `KfY_fn`, a function that takes new points `Xstar` and returns the cross-covariance between ``u(Xstar)`` and ``u(X)``.

        mY = ...
        KYY_tilde = ...

        def KfY_fn(Xstar: Float[Array, "Nstar 1"]) -> Float[Array, "Nstar N"]:
            return ...

        raise NotImplementedError("You need to implement the SensorObservation operator!")

        return mY, KYY_tilde, KfY_fn, y

# Cross-covariance between two SensorObservation blocks. Off-diagonal blocks
# carry no noise term, so this is just the base kernel k(X1, X2).
@cross_covariance.dispatch
def _(op1: SensorObservation, op2: SensorObservation, prior_k: KernelFunction) -> KernelFunction:
    def sensor_sensor_cov(X1: Float[Array, "N D"], X2: Float[Array, "M D"]) -> Float[Array, "N M"]:
        # Delegates to SensorObservation.L_k, i.e. prior_k(X1, X2).
        return op1.L_k(prior_k, X1, X2)

    return sensor_sensor_cov

In [None]:
lazy_gp = LazyConditionalGaussianProcess(prior=prior_gp, jitter=1e-9)

# Synthetic DTS sensor data stored on the problem object.
sensor_locations = problem.X_dts
noisy_measurements = problem.y_dts
noise_level_std = problem.dts_noise_std

# The sensors observe u directly; SensorObservation takes a variance, hence **2.
sensor_op = SensorObservation(sigma2=noise_level_std**2)
lazy_gp.add_condition(
    X=jnp.asarray(sensor_locations).reshape(-1, 1),
    y=jnp.asarray(noisy_measurements).reshape(-1),
    op=sensor_op,
)

post_gp = lazy_gp.condition()

# Posterior belief and PDE balance after seeing only the sensor data.
plot_gp_belief_and_pde(
    gp=post_gp,
    problem=problem,
    X_grid=X_grid,
    conditions=lazy_gp._blocks,
    n_samples=3,
    title="Posterior Belief after Conditioning on DTS Data",
)

#### Adding PDE observation

Now for the most important block: encoding the physics. The `PDEObservation` block represents our knowledge that the underlying temperature field $u(x)$ must satisfy the Poisson equation. It handles "virtual" observations of the form:

$$ y_{\text{pde}} = (- \kappa u'')(X_{\text{pde}}) + \varepsilon, \qquad \text{where } y_{\text{pde}} \text{ is the known heat source } \dot q_V(X_{\text{pde}}). $$

To build this block, we need to compute the covariances that involve the linear operator $L= -\kappa d^2/dx^2$. This is where the kernel derivatives we derived earlier come into play. This block will be responsible for calculating:

1. The PDE-data cross-covariance: $$ \text{Cov}(u(X), (Lu)(X_{\text{pde}})) $$

2. The PDE-PDE covariance: $$ \text{Cov}((Lu)(X_\text{pde}), (Lu)(X_\text{pde})) $$

Your task is to implement the methods that compute these two crucial components.

In [33]:
@dataclass
class PDEObservation(LinearObservation):
    """
    Strong-form PDE collocation in 1D for L = -kappa d^2/dx^2 with a Matérn-5/2 prior.

    Each observation is a "virtual" measurement y = (L u)(x_pde) + eps with
    eps ~ N(0, sigma2), where y holds the known right-hand side at the
    collocation points.
    """
    kappa: float   # conductivity in L = -kappa d^2/dx^2
    sigma2: float  # variance of the PDE-residual tolerance

    def _check_if_matern52(self, prior_k: KernelFunction):
        """Raise ValueError unless `prior_k` is a Matern52 (the closed forms assume it)."""
        if not isinstance(prior_k, Matern52):
            msg = f"PDEObservation only supports Matérn-5/2 kernels, got {type(prior_k)}"
            raise ValueError(msg)

    def L_k(self, prior_k: KernelFunction, X: Float[Array, "Ndata D"], X_pde: Float[Array, "Npde D"]) -> Float[Array, "Ndata Npde"]:
        """Computes cov(f(X), Lf(X_pde)) = -kappa * phi''(|x - x_pde|).

        `phi''` is the second derivative of the Matérn-5/2 profile phi(r) derived
        in the text above. Rows index `X` (data points), columns index `X_pde`.
        """

        self._check_if_matern52(prior_k)        
        variance, lengthscale = prior_k.variance, prior_k.lengthscale 
        
        # --- YOUR CODE HERE ---
        # TODO: Implement the formula for the cross-covariance cov(f(X), Lf
        # (X_pde))
        raise NotImplementedError("Implement the L_k method for PDEObservation")

    def L_k_L(self, prior_k: KernelFunction, X_pde1: Float[Array, "NPDE1 D"], X_pde2: Float[Array, "NPDE2 D"]) -> Float[Array, "NPDE1 NPDE2"]:
        """Computes cov(Lf(X_pde1), Lf(X_pde2)) = kappa^2 * phi''''(|x1 - x2|).

        `phi''''` is the fourth derivative of the Matérn-5/2 profile phi(r)
        derived in the text above.
        """
        self._check_if_matern52(prior_k)
        variance, lengthscale = prior_k.variance, prior_k.lengthscale

        # --- YOUR CODE HERE
        # TODO: Implement the formula for the cross-covariance cov(Lf(X_pde1), Lf(X_pde2))
        raise NotImplementedError("Implement the L_k_L method for PDEObservation")

    def __call__(
        self,
        prior,  # GaussianProcess (provides .k), as in LinearObservation.__call__
        X: Float[Array, "N D"],
        y: Float[Array, " N"],
    ) -> tuple[Float[Array, " N"], Float[Array, "N N"], Callable, Float[Array, " N"]]:
        """Build the PDE observation block at collocation points `X` with right-hand side `y`."""

        # Mean: (-kappa) * m''(X_pde). m''=0 for a constant mean.
        # NOTE: this assumes the prior mean is constant (as the ConstantFunction used here is).
        mY = jnp.zeros_like(y)
        # Identity sized by the collocation points actually passed in.
        KYY_tilde = self.L_k_L(prior.k, X, X) + self.sigma2 * jnp.eye(X.shape[0])

        def KfY_fn(Xstar: Float[Array, "Nstar D"]) -> Float[Array, "Nstar N"]:
            # Cov(f(Xstar), (L f)(X)): Xstar plays the data role, X the PDE role.
            return self.L_k(prior.k, Xstar, X)

        return mY, KYY_tilde, KfY_fn, y


@cross_covariance.dispatch
def _(op1: PDEObservation, op2: SensorObservation, prior_k: KernelFunction) -> KernelFunction:
    """Cov((L u)(X1), u(X2)) for PDE collocation points X1 and sensor points X2."""

    def k(X1: Float[Array, "N D"], X2: Float[Array, "M D"]) -> Float[Array, "N M"]:
        # L_k(prior_k, X_data, X_pde) returns Cov(u(X_data), (L u)(X_pde)) of shape (M, N);
        # transpose it to get the (N, M) block Cov((L u)(X1), u(X2)).
        return op1.L_k(prior_k, X2, X1).T

    return k


@cross_covariance.dispatch
def _(op1: SensorObservation, op2: PDEObservation, prior_k: KernelFunction) -> KernelFunction:
    """Cov(u(X1), (L u)(X2)) for sensor points X1 and PDE collocation points X2."""

    def k(X1: Float[Array, "N D"], X2: Float[Array, "M D"]) -> Float[Array, "N M"]:
        # Exactly L_k's contract: Cov(u(X_data), (L u)(X_pde)) with shape (N, M).
        return op2.L_k(prior_k, X1, X2)

    return k


@cross_covariance.dispatch
def _(op1: PDEObservation, op2: PDEObservation, prior_k: KernelFunction) -> KernelFunction:
    """Cov((L u)(X1), (L u)(X2)) between two PDE collocation blocks (no noise term)."""

    def pde_pde_cov(X1: Float[Array, "N D"], X2: Float[Array, "M D"]) -> Float[Array, "N M"]:
        # NOTE(review): only op1's parameters are used, so this assumes both
        # blocks share the same kappa — confirm if mixing PDE blocks.
        return op1.L_k_L(prior_k, X1, X2)

    return pde_pde_cov

In [None]:
# 17 interior collocation points: 18 evenly spaced points on [0, width),
# with x = 0 dropped, so neither boundary is used.
X_pde = np.linspace(0.0, problem.width, 18, endpoint=False)[1:]

pde_op = PDEObservation(kappa=problem.kappa, sigma2=0.1**2)
lazy_gp.add_condition(
    X=jnp.asarray(X_pde).reshape(-1, 1),
    y=jnp.asarray(problem.q_total(X_pde)),
    op=pde_op,
)

# Joint posterior over everything registered so far (sensor data + PDE).
post_gp = lazy_gp.condition()

plot_gp_belief_and_pde(
    gp=post_gp,
    problem=problem,
    X_grid=X_grid,
    conditions=lazy_gp._blocks,
    n_samples=1,
    seed=42,
    title="GP Conditioned on DTS Data and PDE",
)

#### Adding boundary conditions

We have now conditioned our GP on sensor data and the underlying physics. If you look at the last plot, however, you will notice the uncertainty (the shaded blue region) is still quite large at the very edges of the domain. Our model is still unsure what the exact temperature is at $x = 0$ and at the far end $x = W$. Here $W$ denotes the width of the bar; we avoid writing $x = L$ because $L$ already denotes the differential operator.

The final piece is to add boundary conditions. A common type is the Dirichlet boundary condition, which directly prescribes the value of the function on the boundary $\partial\Omega$ of the domain $\Omega$:

$$ u(x) = g(x), \qquad \text{for } x \in \partial \Omega, $$

for a given boundary function $g$.

In our 1D case, this means we specify the temperature at $x = 0$ and $x = W$.

In [35]:
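# YOUR CODE HERE
# TODO: Implement the class `DirichletObservation`, modeled on `SensorObservation`:
# a `LinearObservation` with a `sigma2` field (see how it is constructed in the
# next cell). Also register `cross_covariance` dispatches for every pair it can
# meet in the stacked system: (Dirichlet, Dirichlet), (Dirichlet, Sensor),
# (Sensor, Dirichlet), (Dirichlet, PDE) and (PDE, Dirichlet). Then visualize
# the result using the logic from before.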

In [None]:
# Dirichlet data: the exact solution evaluated at both ends of the domain.
X_bc = jnp.asarray([problem.domain[0], problem.domain[1]]).reshape(-1, 1)
y_bc = jnp.asarray(problem.solution(X_bc.ravel()))

# A tiny noise std (variance ~1e-10) pins the boundary values almost exactly.
bc_noise_std = 1e-5
bc_noise_var = bc_noise_std**2

lazy_gp.add_condition(X=X_bc, y=y_bc, op=DirichletObservation(sigma2=bc_noise_var))

# Re-solve with sensors, PDE and boundary values, then visualize.
post_gp_bc = lazy_gp.condition()
plot_gp_belief_and_pde(
    gp=post_gp_bc,
    problem=problem,
    X_grid=X_grid,
    n_samples=1,
    conditions=lazy_gp._blocks,
    title="Posterior Belief after Conditioning on Boundary Values",
)

The second part rebuilds the framework behind the paper ["Physics-Informed Gaussian Process Regression Generalizes Linear PDE Solvers"](https://arxiv.org/abs/2212.12474).

The accompanying codebase contains a much more general framework (more kernels, more PDEs, more kinds of conditions) and many additional tutorials:

https://github.com/marvinpfoertner/linpde-gp