Constructing an interpolating lookup table, so we can simply look up the best length scale, given two points $D_1 = (x_1, y_1)$ and $D_2 = (x_2, y_2)$. The lookup table $f$ is constructed such that what we look up is $\log(ell) = f(n, s),$ where $n$ indicates normalized measurement noise level and $s$ indicates normalized signal level. The normalization is such that $D_1 = (0,0)$ and $D_2 = (1,1)$. 

We compute $n$ and $s$ from $\sigma_n$ and $\sigma_s$. We fix $\sigma_n$ by what we know about the lab test measurement noise (standard deviation of the measurement error), and we estimate $\sigma_s$ as the standard deviation of the signal, by taking the standard deviation of a quick smooth fit to the measurements. That won't be a perfect estimate, but it's the best we have. 

If  $\Delta y = y_2 - y_1$, $\Delta x = x_2 - x_1$, then we 
have $n = \log(\sigma_n / \Delta y)$ and $s = \log(\sigma_s / \Delta y).$

Then we compute 
\begin{align}
    ell &= \exp(f(n, s)) \Delta x \\
        &= \exp(f(\log(\sigma_n / \Delta y), \log(\sigma_s / \Delta y))) \Delta x.
\end{align}  

One remaining question: Why is the average loss (`losses`) of the 100 runs at a given set of parameters different from the recomputed loss (`gp_losses`) at those same parameters?

In [None]:
import math
import numpy as np
import pandas as pd
import torch
import gpytorch
import pickle
from copy import deepcopy
from matplotlib import pyplot as plt
from numpy.random import default_rng
from timeit import default_timer as timer
from matplotlib.backends.backend_pdf import PdfPages
from scipy.interpolate import PchipInterpolator, RectBivariateSpline
from datetime import datetime
from warnings import simplefilter, warn, catch_warnings, filterwarnings
import logging
import pickle
from tqdm.notebook import trange, tqdm
from gpytorch.utils.warnings import GPInputWarning
# %matplotlib inline
%load_ext autoreload
%autoreload 2


In [None]:
class GPModel(gpytorch.models.ExactGP):
    """Exact GP regression model with a constant mean and a scaled RBF kernel.

    Each hyperparameter (offset, lengthscale, signal variance, noise variance)
    can either be pinned to a given value or randomly initialised and optimised
    by ``fit``. After fitting, the parameter state with the lowest loss seen
    during optimisation is restored.
    """

    def __init__(
        self,
        train_inputs,
        train_targets,
        ell_prior=gpytorch.priors.GammaPrior(concentration=2.0, rate=1.0),
        likelihood=None,
        seed=None
    ):
        """Create a stationary GP model

        Args:
            train_inputs (float): observation locations
            train_targets (float): observation values
            ell_prior(gpytorch prior): Lengthscale prior. Defaults to
                GammaPrior(concentration=2.0, rate=1.0); pass None for no prior.
            likelihood(gpytorch likelihood): Defaults to a new
                GaussianLikelihood() created for this instance.
            seed: Seed for the numpy generator that draws the random
                hyperparameter initialisations in ``fit``.
        """
        # Build the likelihood per instance: a default evaluated in the
        # signature would be one module shared (and mutated) by every model.
        if likelihood is None:
            likelihood = gpytorch.likelihoods.GaussianLikelihood()
        train_inputs = torch.as_tensor(train_inputs)
        train_targets = torch.as_tensor(train_targets)
        super().__init__(train_inputs, train_targets, likelihood)
        self.rng = default_rng(seed)

        self.mean_module = gpytorch.means.ConstantMean()

        self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel(lengthscale_prior=ell_prior))
        self.likelihood = likelihood

    @property
    def lengthscale(self):
        """Current RBF lengthscale as a Python float."""
        return self.covar_module.base_kernel.lengthscale.item()

    @lengthscale.setter
    def lengthscale(self, value):
        self.covar_module.base_kernel.lengthscale = torch.tensor([value])

    @property
    def noise_var(self):
        """Current likelihood noise variance as a Python float."""
        return self.likelihood.noise.item()

    @property
    def loss(self):
        """Lowest loss recorded by the most recent ``fit`` call.

        This is NOT recomputed when hyperparameters are changed afterwards
        (e.g. through the ``lengthscale`` setter).
        """
        # float() handles both a 0-dim tensor and the plain ``torch.inf``
        # placeholder that remains when no iteration improved on it.
        return float(self.best_loss)

    @property
    def signal_var(self):
        """Current kernel output scale (signal variance) as a Python float."""
        return self.covar_module.outputscale.item()

    @property
    def offset(self):
        """Current constant mean as a Python float."""
        return self.mean_module.constant.item()

    def loss_on(self, x, y):
        """Return the negative marginal log likelihood of ``(x, y)``.

        The value is computed from the GP prior at ``x`` with the current
        hyperparameters, i.e. the same quantity ``fit`` minimises when
        ``(x, y)`` is the training set. (Calling the model in eval mode, as
        this method used to, yields the posterior predictive conditioned on
        the training data, which is a different quantity.)

        Args:
            x: input locations.
            y: observed values at ``x``.

        Returns:
            0-dim tensor holding the loss.
        """
        x = torch.as_tensor(x)
        y = torch.as_tensor(y)
        # ExactGP.__call__ adds a trailing feature dimension to 1-D inputs;
        # forward() is called directly here, so do the same by hand.
        if x.ndimension() == 1:
            x = x.unsqueeze(-1)
        with torch.no_grad():
            mll = gpytorch.mlls.ExactMarginalLogLikelihood(self.likelihood, self)
            prior = self.forward(x)
            return -mll(prior, y)

    def forward(self, x):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)

    def fit(self, *, n_iter=100,  progress_threshold=100, tol=0.0, lr=0.1, signal_var=None, noise_var=None, offset=None, lengthscale=None, verbose=False):
        """Optimise the free hyperparameters with Adam and keep the best state.

        Any of ``offset``, ``lengthscale``, ``signal_var``, ``noise_var`` that
        is given is set to that value and frozen; the others are randomly
        initialised from ``self.rng`` and optimised.

        Args:
            n_iter: maximum number of optimisation iterations.
            progress_threshold: stop after more than this many consecutive
                iterations without a new best loss, or without an
                iteration-to-iteration improvement larger than ``tol``.
            tol: minimum loss decrease that counts as progress.
            lr: Adam learning rate.
            signal_var, noise_var, offset, lengthscale: fixed values, or None
                to initialise randomly and optimise.
            verbose: print one line per iteration.

        Returns:
            Tuple ``(losses, noise_vars, lengthscales, signal_vars, best_index)``
            where the first four are numpy arrays with one entry per executed
            iteration and ``best_index`` is the iteration with the lowest loss
            (None if no iteration produced a finite improvement).
        """
        # The initializations are specified with a two-point training set
        # (0,0), (1,1) in mind.
        # NOTE: the order of the random draws below (offset, lengthscale,
        # signal variance, noise) determines the seeded results.

        if offset is not None:
            self.mean_module.constant = offset
            self.mean_module.raw_constant.requires_grad = False
        else:
            self.mean_module.constant = self.rng.uniform(-10.0, 10.0)
            self.mean_module.raw_constant.requires_grad = True

        if lengthscale is not None:
            self.covar_module.base_kernel.lengthscale = lengthscale
            self.covar_module.base_kernel.raw_lengthscale.requires_grad = False
        else:
            self.covar_module.base_kernel.lengthscale = self.rng.uniform(1e-2,10.0)
            self.covar_module.base_kernel.raw_lengthscale.requires_grad = True

        if signal_var is not None:
            self.covar_module.outputscale = signal_var
            self.covar_module.raw_outputscale.requires_grad = False
        else:
            self.covar_module.outputscale = self.rng.uniform(1e-2, 10)
            self.covar_module.raw_outputscale.requires_grad = True

        if noise_var is not None:
            self.likelihood.noise = noise_var
            self.likelihood.raw_noise.requires_grad = False
        else:
            # mean: ln(0.1) = -2.3
            self.likelihood.noise = self.rng.lognormal(mean=-2.3, sigma=1.0)
            self.likelihood.raw_noise.requires_grad = True

        losses = torch.empty(n_iter, dtype=torch.float32)
        lengthscales = torch.empty(n_iter, dtype=torch.float32)
        noise_vars = torch.empty(n_iter, dtype=torch.float32)
        signal_vars = torch.empty(n_iter, dtype=torch.float32)
        self.best_loss = torch.inf
        best_index = None
        last_loss = torch.inf
        self.best_lengthscale = None
        # Fallback snapshot so the restore below is defined even when the loop
        # never records an improvement (n_iter == 0 or non-finite losses).
        best_model = deepcopy(self.state_dict())
        n_done = 0
        self.train()
        self.likelihood.train()

        optimizer = torch.optim.Adam(self.parameters(), lr=lr)

        # "Loss" for GPs - the marginal log likelihood
        mll = gpytorch.mlls.ExactMarginalLogLikelihood(self.likelihood, self)
        n_no_improvement = 0
        n_small_improvement = 0

        for i in range(n_iter):
            # Zero gradients from previous iteration
            optimizer.zero_grad()
            # Output from model
            # TODO: figure out how to work with a tuple of inputs, if needed.
            output = self.__call__(self.train_inputs[0])
            # Calc loss and backprop gradients
            loss = -mll(output, self.train_targets)
            loss.backward()
            n_done = i + 1

            # Bookkeeping uses a detached copy so the traces and best_loss do
            # not keep every iteration's autograd graph alive.
            loss_value = loss.detach()
            losses[i] = loss_value
            lengthscales[i] = self.lengthscale
            noise_vars[i] = self.noise_var
            signal_vars[i] = self.signal_var

            best_flag = ""
            n_no_improvement += 1
            n_small_improvement += 1
            if loss_value < self.best_loss:
                n_no_improvement = 0
                best_flag = "*"
                self.best_loss = loss_value
                # Snapshot taken before optimizer.step(), so it matches the
                # parameters that produced this loss.
                best_model = deepcopy(self.state_dict())
                best_index = i
                self.best_lengthscale = self.lengthscale

            if verbose:
                ln_string = f"{self.lengthscale:.4f}"
                best_index_string = "----" if best_index is None else f"{best_index:04d}"
                best_ln_string = "n/a" if self.best_lengthscale is None else f"{self.best_lengthscale:.4f}"
                print(
                    f"Iter {i:04d}/{n_iter} - Loss: {loss.item():.4f}  Lengthscale: {ln_string} noise: {self.noise_var:.3f} ({best_index_string} {best_ln_string}){best_flag}"
                )

            if n_no_improvement > progress_threshold:
                if verbose:
                    print( f"Stopped after {n_no_improvement} iterations without improvement.")
                break

            if last_loss - loss_value > tol:
                n_small_improvement = 0

            if n_small_improvement > progress_threshold:
                if verbose:
                    print( f"Stopped after {n_small_improvement} iterations with only small improvement.")
                break
            last_loss = loss_value

            optimizer.step()

        self.load_state_dict(best_model)
        return losses.numpy()[:n_done], noise_vars.numpy()[:n_done], lengthscales.numpy()[:n_done], signal_vars.numpy()[:n_done], best_index

    def predict(self, test_inputs):
        """Return predictive mean and confidence-region bounds at ``test_inputs``.

        Switches the model and likelihood to eval mode.

        Returns:
            Tuple ``(mean, lower, upper)`` of numpy arrays, where the bounds
            come from the predictive distribution's ``confidence_region()``.
        """
        self.eval()
        self.likelihood.eval()
        test_inputs = torch.as_tensor(test_inputs)
        # Make predictions by feeding model through likelihood
        with torch.no_grad(), gpytorch.settings.fast_pred_var():
            preds = self.likelihood(self.__call__(test_inputs))

        p_mean = preds.mean.numpy()
        p_lower, p_upper = [x.numpy() for x in preds.confidence_region()]
        return p_mean, p_lower, p_upper

    def __str__(self):
        return f"GPModel Loss: {self.loss:0.10f}, ell: {self.lengthscale:0.4f}, svar: {self.signal_var}, nvar: {self.noise_var:0.4f}"

In [None]:
def get_preds(gp, x, y, n):
    """Predict with ``gp`` on an evenly spaced grid slightly wider than ``x``.

    The grid has ``n`` points and extends 10% of the span ``x[-1] - x[0]``
    beyond each end. ``y`` is accepted but not used.

    Returns:
        Tuple ``(pred_x, pred_y, lower, upper)``: the grid and the three
        arrays returned by ``gp.predict`` on it.
    """
    first, last = x[0], x[-1]
    margin = (last - first) * 0.1
    grid = np.linspace(first - margin, last + margin, num=n)
    mean, band_lower, band_upper = gp.predict(grid)
    return grid, mean, band_lower, band_upper

def plot_preds(gp, x, y, n=100, noise_var=None, signal_var=None, lengthscale=None):
    """Plot the GP prediction band together with the observations.

    Draws the predictive mean, the observations, the shaded region between the
    lower and upper bounds from ``get_preds``, and a faint horizontal line at
    ``gp.offset``. Optional scale indicators are drawn when given:

    * ``noise_var``: error bars of height sqrt(noise_var) on each observation.
    * ``signal_var``: vertical bar of height sqrt(signal_var) at ``x[0]``.
    * ``lengthscale``: horizontal bar of that length starting at ``x[0]``.
    """
    grid, mean, band_lower, band_upper = get_preds(gp, x, y, n=n)
    plt.plot(grid, mean, color='b')
    plt.plot(x, y, 'ob')
    plt.fill_between(grid, band_lower, band_upper, color='b', alpha=0.1)
    plt.axhline(gp.offset, color='green', alpha=0.1)

    # Shared look for the three optional scale indicators.
    bar_width = 4
    bar_alpha = 0.5
    left = x[0]

    if noise_var is not None:
        noise_sd = np.sqrt(noise_var)
        plt.errorbar(x, y, yerr=noise_sd, linewidth=0, elinewidth=bar_width, ecolor='blue', alpha=bar_alpha)

    if signal_var is not None:
        base = gp.offset
        top = gp.offset + np.sqrt(signal_var)
        plt.plot((left, left), (base, top), '-b', linewidth=bar_width, alpha=bar_alpha)

    if lengthscale is not None:
        right = left + lengthscale
        plt.plot((left, right), (gp.offset, gp.offset), '-b', linewidth=bar_width, alpha=bar_alpha)


In [None]:
def plot_training(losses, noise_vars, lengthscales, signal_vars, best_iter):
    """Plot the loss and lengthscale traces of one fit on twin log-scaled axes.

    The loss goes on the left axis (blue) and the lengthscale on the right
    axis (green); a vertical line marks ``best_iter`` and the legend shows the
    values at that iteration. ``noise_vars`` and ``signal_vars`` are accepted
    but not plotted.
    """
    loss_label = f'loss ({losses[best_iter]:0.4g})'
    ell_label = f'ell ({lengthscales[best_iter]:0.2g})'

    fig, loss_ax = plt.subplots()
    loss_ax.semilogy(losses, 'b-', label=loss_label)

    # Second y-axis sharing the iteration axis.
    ell_ax = loss_ax.twinx()
    ell_ax.semilogy(lengthscales, 'g-', label=ell_label)
    ell_ax.axvline(best_iter)

    loss_ax.set_xlabel('iteration')
    loss_ax.set_ylabel('Loss', color='b')
    ell_ax.set_ylabel('lengthscale', color='g')
    fig.legend()

In [None]:
def fit_gp_once(x, y, *, verbose, n_iter, **kwargs):
    """Fit one prior-free ``GPModel`` to ``(x, y)`` and return it.

    The model is built with ``ell_prior=None`` and fitted quietly; remaining
    keyword arguments are forwarded to ``GPModel.fit``. When ``verbose`` is
    set, a one-line summary of the best iteration is printed. A warning is
    logged if the best iteration was the last one allowed by ``n_iter``.
    """
    model = GPModel(x, y, ell_prior=None)
    history = model.fit(verbose=False, n_iter=n_iter, **kwargs)
    loss_trace = history[0]
    best_iter = history[-1]

    if verbose:
        print(f"Iter {best_iter}/ {len(loss_trace)} / {n_iter}")

    last_allowed = n_iter - 1
    if best_iter == last_allowed:
        # Still improving when the iteration budget ran out.
        logging.warning(f"Fit reached iteration limit: {model}")
    return model


In [None]:
def fit_gp_sampled(x, y,  *, verbose, n_samples, **kwargs ):
    """Fit ``n_samples`` randomly initialised GPs and average their lengthscales.

    Each fit is produced by ``fit_gp_once``; the lengthscales are combined in
    a weighted average with weights proportional to ``exp(-loss)``.

    Args:
        x, y: training data passed to ``fit_gp_once``.
        verbose: print one line per fitted model.
        n_samples: number of independent fits (must be at least 1).
        **kwargs: forwarded to ``fit_gp_once``.

    Returns:
        Tuple ``(gp, ell_mean, ells, losses)``. ``gp`` is the LAST fitted
        model with its lengthscale overwritten by ``ell_mean``; its ``loss``
        property still reports that last fit's best loss, not the loss at
        ``ell_mean``. ``ells`` and ``losses`` hold the per-fit values.

    Raises:
        ValueError: if ``n_samples`` is less than 1.
    """
    if n_samples < 1:
        raise ValueError("n_samples must be at least 1")

    losses = np.empty(n_samples)
    ells = np.empty(n_samples)
    gp = None
    for i in range(n_samples):
        gp = fit_gp_once(x, y, verbose=verbose, **kwargs)
        losses[i] = gp.loss
        ells[i] = gp.lengthscale
        if verbose:
            print(f"[{i:03d}]: {gp}")

    # Shift by the smallest loss before exponentiating: the weighted average
    # is unchanged by a common factor, but this keeps the weights from all
    # underflowing to zero when the losses are large.
    weights = np.exp(-(losses - losses.min()))
    ell_mean = np.average(ells, weights=weights)
    gp.lengthscale = ell_mean

    return gp, ell_mean, ells, losses


In [None]:
# Build the length-scale lookup table on a res x res grid of signal and noise
# standard deviations, for the normalised two-point training set (0,0), (1,1).
res = 50
n_samples = 100
logging.basicConfig(filename=f'sampled_table_creation_rep_{n_samples}_res_{res}_v2.log', encoding='utf-8', filemode='w', level=logging.INFO)
# Silence GPyTorch's GPInputWarning for the whole run.
filterwarnings("ignore", category=GPInputWarning)

x = np.array([0.0, 1.0])
y = np.array([0.0, 1.0])
offset = y.mean()

n_iter = 1000
tol = 1e-8
progress_threshold = 100
lr = 0.2

signal_sds = np.logspace(-1, 1, num=res)
slen = len(signal_sds)
noise_sds = np.logspace(-1, 1, num=res)
nlen = len(noise_sds)

# Float tables indexed [signal index, noise index]. NaN marks cells that have
# not been computed (instead of object arrays holding None).
ells = np.full((slen, nlen), np.nan)
losses = np.full((slen, nlen), np.nan)
gp_losses = np.full((slen, nlen), np.nan)
for si in tqdm(range(slen), desc="signal"):
    for ni in tqdm(range(nlen), desc="noise", leave=False):
        # The GP is parameterised by variances; the grid is in standard deviations.
        sv = signal_sds[si]**2
        nv = noise_sds[ni]**2
        gp, ell_mean, iter_ells, iter_losses = fit_gp_sampled(x=x, y=y, offset=offset, signal_var=sv, noise_var=nv, n_samples=n_samples, n_iter=n_iter, progress_threshold=progress_threshold, tol=tol, lr=lr, verbose=False)
        ells[si, ni] = ell_mean
        # Plain mean of the per-fit best losses.
        lm = iter_losses.mean()
        losses[si, ni] = lm
        # Loss recomputed on the returned model (lengthscale = ell_mean);
        # converted to a float so no torch tensor ends up in the table/pickle.
        lo = float(gp.loss_on(x, y))
        gp_losses[si, ni] = lo
        logging.info(f"{datetime.now()} sv: {sv:0.2g}, nv:{nv:0.2g}, ell: {gp.lengthscale:0.3g} mean loss: {lm:0.5g} rerun loss: {lo:0.5g}")

# The five objects are written as consecutive pickles and must be read back
# in the same order.
with open(f'table_creation_rep_{n_samples}_res_{res}_v2.pkl', 'wb') as f:
    pickle.dump(ells, f, pickle.HIGHEST_PROTOCOL)
    pickle.dump(losses, f, pickle.HIGHEST_PROTOCOL)
    pickle.dump(gp_losses, f, pickle.HIGHEST_PROTOCOL)
    pickle.dump(signal_sds, f, pickle.HIGHEST_PROTOCOL)
    pickle.dump(noise_sds, f, pickle.HIGHEST_PROTOCOL)


In [None]:
# Writes the same five objects to the same file name as the end of the
# table-creation cell; running it overwrites that file. The objects are stored
# as consecutive pickles and must be loaded back in this order.
with open(f'table_creation_rep_{n_samples}_res_{res}_v2.pkl', 'wb') as f:
    pickle.dump(ells, f, pickle.HIGHEST_PROTOCOL)
    pickle.dump(losses, f, pickle.HIGHEST_PROTOCOL)
    pickle.dump(gp_losses, f, pickle.HIGHEST_PROTOCOL)
    pickle.dump(signal_sds, f, pickle.HIGHEST_PROTOCOL)
    pickle.dump(noise_sds, f, pickle.HIGHEST_PROTOCOL)

In [None]:
def _plot_family(xvals, curves, xlabel, ylabel):
    """Draw one log-log line per entry of ``curves`` against ``xvals`` on the current axes."""
    for curve in curves:
        plt.loglog(xvals, curve, '-')
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)


# The tables are indexed [signal index, noise index]: iterating a table walks
# the signal levels (each curve runs over noise), iterating its transpose
# walks the noise levels (each curve runs over signal).
_plot_family(noise_sds, ells, 'Measurement Noise SD', 'Length Scale')
plt.figure()
_plot_family(signal_sds, ells.T, 'Signal SD', 'Length Scale')
plt.figure()
_plot_family(noise_sds, losses, 'Measurement Noise SD', 'Average Loss')
plt.figure()
_plot_family(signal_sds, losses.T, 'Signal SD', 'Average Loss')
plt.figure()
_plot_family(noise_sds, gp_losses, 'Measurement Noise SD', 'Recomputed Loss')
plt.figure()
_plot_family(signal_sds, gp_losses.T, 'Signal SD', 'Recomputed Loss')
# New figure so the average-vs-recomputed comparison is not drawn on top of
# the previous plot.
plt.figure()
for avg_curve, rerun_curve in zip(losses.T, gp_losses.T):
    plt.loglog(avg_curve, rerun_curve, '-')
plt.xlabel('Average Loss')
plt.ylabel('Recomputed Loss')

In [None]:
from mpl_toolkits.mplot3d import axes3d
%matplotlib widget

fig = plt.figure()
ax = fig.add_subplot(111, projection="3d")
X, Y = np.meshgrid(np.log10(noise_sds), np.log10(signal_sds))
# X, Y = np.meshgrid(noise_sds, signal_sds)
Z = np.log10(ells.astype(float))
ax.plot_surface(X, Y, Z, cmap="autumn_r", lw=0.5, rstride=1, cstride=1, alpha=0.5)
# ax.contour(X, Y, Z, 10, lw=3, cmap="autumn_r", linestyles="solid", offset=-5)
ax.contour(X, Y, Z, 10, lw=3, colors="k", linestyles="solid")
plt.show()

In [None]:
from scipy import interpolate
# Both axes are in log10 space: x = log10(noise SD), y = log10(signal SD).
# NOTE: this rebinds x and y, which earlier held the two-point training set.
# NOTE(review): the table uses base-10 logs while the notebook's intro text
# writes log/exp — confirm that the consumer of the lookup uses base 10.
x, y = np.log10(noise_sds), np.log10(signal_sds)
# log10 of the tabulated length scales; rows follow signal_sds and columns
# follow noise_sds (the table is filled as ells[si, ni]).
Z = np.log10(ells.astype(float))
# NOTE(review): SciPy's release notes list interp2d as deprecated in 1.10 and
# removed in 1.14 — confirm the installed SciPy still provides it.
interp_func = interpolate.interp2d(x, y, Z, kind='linear')


In [None]:
# Persist the interpolator object itself. Unpickling it later needs a SciPy
# installation that can reconstruct this object type.
with open(f'../results/lookup_{n_samples}_res_{res}_v2.pkl', 'wb') as f:
    pickle.dump(interp_func, f, pickle.HIGHEST_PROTOCOL)



Creating a RectBivariateSpline version, because interp2d doesn't handle vectorized input for lookup.

In [None]:
# Same constants and grids as in the table-creation cell, repeated here —
# presumably so this section can be run without recomputing the table.
res=50
n_samples = 100

# Normalised two-point training set. x and y are rebound to the log10 grids
# further down.
x = np.array([0.0, 1.0])
y = np.array([0.0, 1.0])
offset = y.mean()

n_iter = 1000
tol = 1e-8
progress_threshold = 100
lr = 0.2

# signal_sds / noise_sds are replaced by the pickled arrays when the table is
# loaded in the next cell.
signal_sds = np.logspace(-1, 1, num=res)
slen = len(signal_sds)
noise_sds = np.logspace(-1, 1, num=res)
nlen = len(noise_sds)

In [None]:
# Load the table written by the table-creation cell. The file name is
# hard-coded for n_samples=100, res=50. The five objects are read in the same
# order in which they were dumped.
# NOTE: pickle.load executes arbitrary code from the file; only load files
# produced by this notebook.
with open(f'table_creation_rep_100_res_50_v2.pkl', 'rb') as f:
    ells = pickle.load(f)
    losses = pickle.load(f)
    gp_losses = pickle.load(f)
    signal_sds = pickle.load(f)
    noise_sds = pickle.load(f)

In [None]:
# Lookup axes in log10 space: x = log10(noise SD), y = log10(signal SD).
# NOTE(review): the table uses base-10 logs while the notebook's intro text
# writes log/exp — confirm that the consumer of the lookup uses base 10.
x, y = np.log10(noise_sds), np.log10(signal_sds)
# ells is indexed [signal, noise], but RectBivariateSpline(x, y, z) expects z
# with shape (x.size, y.size) = (noise, signal). Transpose so that
# interp_func(n, s) matches the documented f(n, s) and the argument order of
# the interp2d version. From here on Z is indexed [noise, signal].
Z = np.log10(ells.astype(float)).T
# kx=ky=1 gives a piecewise (bi)linear interpolant through the grid values.
interp_func = RectBivariateSpline(x, y, Z, kx=1, ky=1)

In [None]:
# Visual check of the interpolant along one slice of the table.
plt.figure()
# Fine evaluation grids with step 0.01 on [-1, 1) for both axes.
xnew = np.arange(-1.00, 1.00, 1e-2)

ynew = np.arange(-1.00, 1.00, 1e-2)

# Grid-mode evaluation: znew has shape (len(xnew), len(ynew)).
znew = interp_func(xnew, ynew)
i = 30
# Row i of the table (red) ...
plt.plot(x, Z[i, :], 'ro')
# ... against the interpolated row at xnew[i * 4]. Assuming the 50-point grid
# on [-1, 1] (spacing 2/49, about 0.041), fine-grid index i * 4 lies close to
# table index i.
plt.plot(xnew, znew[i * 4, :], 'bo-', alpha=0.2)

plt.show()

In [None]:
# Evaluate the spline on its own grid and compare with the Z it was built from.
z_reconstructed = interp_func(x, y)
z_delta = Z - z_reconstructed

# Exact floating-point equality test; the cell displays the resulting pair of
# booleans. NOTE(review): a tolerance-based comparison (e.g. np.allclose)
# would be less sensitive to round-off — confirm whether exact equality is
# the intended check.
(np.all(z_delta == 0), np.all(z_reconstructed == Z))

In [None]:
# grid=False evaluates pointwise at the pairs (xnew[k], ynew[k]) instead of on
# the full grid, so znew is 1-D here. xnew and ynew are identical arrays, so
# these points lie on the diagonal of the lookup domain.
znew = interp_func(xnew, ynew, grid=False)

In [None]:
# Plot the current znew against its sample index.
plt.plot(znew)

In [None]:
# Persist the RectBivariateSpline lookup object. Unlike the other output
# files, this file name carries no "_v2" suffix.
with open(f'../results/rbs_lookup_{n_samples}_res_{res}.pkl', 'wb') as f:
    pickle.dump(interp_func, f, pickle.HIGHEST_PROTOCOL)