In [1]:
# MI-regularized Additive GP vs OAK

import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
import gpflow
from gpflow.utilities import print_summary, positive
from gpflow.models import GPR
from gpflow.optimizers import Scipy
from itertools import combinations
import functools

# ---- 1) Synthetic Datasets ----
def make_additive_data(N=50, D=3, noise=0.1, seed=0):
    tf.random.set_seed(seed)
    np.random.seed(seed)
    X = np.random.uniform(-1, 1, (N, D))
    f_terms = []
    if D >= 1: f_terms.append(np.sin(2 * np.pi * X[:, 0]))
    if D >= 2: f_terms.append(X[:, 1]**2 - 1/3)
    if D >= 3: f_terms.append(0.5 * X[:, 2])
    y = sum(f_terms) + noise * np.random.randn(N)
    return X.astype(np.float64), y[:, None].astype(np.float64)

def make_interaction_data(N=50, D=3, noise=0.1, seed=1):
    tf.random.set_seed(seed)
    np.random.seed(seed)
    X = np.random.uniform(-1, 1, (N, D))
    y_main = 0
    if D >= 1: y_main += np.sin(np.pi * X[:, 0])
    if D >= 2: y_main += X[:, 1]**2 - 1/3
    y_inter = X[:, 0] * X[:, 1] if D >= 2 else 0
    y = y_main + y_inter + noise * np.random.randn(N)
    return X.astype(np.float64), y[:, None].astype(np.float64)

INPUT_DIM = 3
X_add, y_add = make_additive_data(D=INPUT_DIM)
X_int, y_int = make_interaction_data(D=INPUT_DIM)

# ---- 2) Kernel factories ----
def create_base_kernel(active_dims, variance=1.0, lengthscales=1.0):
    return gpflow.kernels.SquaredExponential(
        active_dims=active_dims,
        variance=gpflow.Parameter(variance, transform=positive()),
        lengthscales=gpflow.Parameter(lengthscales, transform=positive())
    )

def create_centered_se_base_kernel(X_sample):
    def _creator(active_dims):
        return CenteredSE(
            X_sample_for_centering=X_sample,
            active_dims=active_dims,
            lengthscales=1.0,
            variance=1.0
        )
    return _creator

def general_additive_kernel(input_dim, max_order, base_kernel_creator_fn, name_prefix=""):
    # Build one base kernel per dimension
    dim_kernels = [base_kernel_creator_fn([d]) for d in range(input_dim)]
    terms = []
    # Bias
    terms.append(gpflow.kernels.Constant(
        variance=gpflow.Parameter(1.0, transform=positive(), name=f"{name_prefix}bias_var")))
    # 1..max_order
    for order in range(1, max_order + 1):
        for comb in combinations(range(input_dim), order):
            # product of base kernels
            prod = functools.reduce(lambda a,b: a*b, [dim_kernels[i] for i in comb])
            var_name = f"{name_prefix}ord{order}_{'_'.join(map(str,comb))}_var"
            sigma = gpflow.Parameter(1.0, transform=positive(), name=var_name)
            terms.append(prod * gpflow.kernels.Constant(variance=sigma))
    return gpflow.kernels.Sum(terms)

class CenteredSE(gpflow.kernels.Kernel):
    def __init__(self, X_sample_for_centering, active_dims, lengthscales=1.0, variance=1.0):
        super().__init__(active_dims=active_dims)
        self.variance = gpflow.Parameter(variance, transform=positive())
        self.lengthscales = gpflow.Parameter(lengthscales, transform=positive())
        Xs = tf.cast(X_sample_for_centering, gpflow.default_float())
        if len(active_dims) != 1:
            raise ValueError("CenteredSE needs a single active_dim")
        idx = active_dims[0]
        self.X_sample = Xs[:, idx:idx+1]

    def _se(self, X1, X2):
        sq = (X1 - tf.transpose(X2))**2
        return self.variance * tf.exp(-0.5 * sq / self.lengthscales**2)

    def K(self, X, X2=None):
        X1 = tf.gather(X, self.active_dims, axis=1)
        X2 = X1 if X2 is None else tf.gather(X2, self.active_dims, axis=1)
        Kxz = self._se(X1, X2)
        Exu = tf.reduce_mean(self._se(X1, self.X_sample), axis=1, keepdims=True)
        Euz = tf.reduce_mean(self._se(self.X_sample, X2), axis=0, keepdims=True)
        Euv = tf.reduce_mean(self._se(self.X_sample, self.X_sample))
        return Kxz - Exu - Euz + Euv

    def K_diag(self, X):
        X1 = tf.gather(X, self.active_dims, axis=1)
        diag = tf.linalg.diag_part(self._se(X1, X1))
        Exu = tf.reduce_mean(self._se(X1, self.X_sample), axis=1)
        Euv = tf.reduce_mean(self._se(self.X_sample, self.X_sample))
        return diag - 2*Exu + Euv

# ---- 3) MI penalty ----
def tf_pairwise_gaussian_mi_penalty(comps, eps=1e-6):
    means = comps - tf.reduce_mean(comps,0,keepdims=True)
    std = tf.sqrt(tf.reduce_mean(means**2,0,keepdims=True) + eps)
    normed = means / std
    N = tf.cast(tf.shape(normed)[0], normed.dtype)
    Corr = tf.matmul(normed, normed, transpose_a=True) / N
    Corr = Corr - tf.linalg.diag(tf.linalg.diag_part(Corr))
    MI = -0.5 * tf.math.log(1 - Corr**2 + eps)
    mask = tf.linalg.band_part(tf.ones_like(MI), 0, -1) - tf.eye(tf.shape(MI)[0], dtype=MI.dtype)
    return tf.reduce_sum(MI * mask)

# ---- 4) GP training ----
def get_component_posterior_means(model, X_tf):
    if not isinstance(model.kernel, gpflow.kernels.Sum):
        return [model.predict_f(X_tf)[0]]
    K = model.kernel.K(X_tf) + tf.eye(tf.shape(X_tf)[0], dtype=gpflow.default_float())*model.likelihood.variance
    alpha = tf.linalg.inv(K) @ model.data[1]
    mats = []
    for k in model.kernel.kernels:
        Kc = k.K(X_tf, X_tf)
        mats.append(Kc @ alpha)
    return tf.concat(mats, axis=1)

def train_gp(X, y, kernel, penalty_type=None, beta=0.0, maxiter=100):
    xb = tf.constant(X, gpflow.default_float())
    yb = tf.constant(y, gpflow.default_float())
    model = GPR((xb, yb), kernel=kernel, mean_function=None)

    @tf.function
    def loss():
        L = -model.log_marginal_likelihood()
        if penalty_type=='mi_gaussian':
            comps = get_component_posterior_means(model, xb)
            if tf.shape(comps)[1] > 1:
                L += beta * tf_pairwise_gaussian_mi_penalty(comps)
        return L

    opt = Scipy()
    opt.minimize(loss, model.trainable_variables,
                 options=dict(maxiter=maxiter, disp=False))
    return model

def calculate_mi_heatmap_numpy(cm):
    C = cm.shape[1]
    miM = np.zeros((C,C))
    for i in range(C):
        for j in range(i+1,C):
            ui, uj = cm[:,i], cm[:,j]
            rho = np.corrcoef(ui, uj)[0,1]
            miM[i,j] = miM[j,i] = -0.5*np.log(1-rho*rho+1e-9)
    return miM

# ---- 5) Run experiments & diagnostics ----
datasets = {"Additive": (X_add, y_add), "Interaction": (X_int, y_int)}
results = {}

for name, (X_np, y_np) in datasets.items():
    print(f"\n=== {name} ===")

    # baseline SE-ARD
    k_se = gpflow.kernels.SquaredExponential(lengthscales=[1.]*INPUT_DIM,
                                             variance=gpflow.Parameter(1.0, transform=positive()))
    m_se = train_gp(X_np, y_np, k_se)
    print_summary(m_se)

    # AGP no penalty
    se_creator = lambda d: create_base_kernel(active_dims=d)
    k_agp = general_additive_kernel(INPUT_DIM, 2, se_creator, "agp_")
    m_agp = train_gp(X_np, y_np, k_agp, maxiter=200)
    print_summary(m_agp)

    # AGP + MI
    mi_models = {}
    for β in [0.01, 0.1, 1.0]:
        k_mi = general_additive_kernel(INPUT_DIM, 2, se_creator, f"agpmi_{β}_")
        mi_models[f"AGP_MI_{β}"] = train_gp(X_np, y_np, k_mi,
                                           penalty_type='mi_gaussian', beta=β, maxiter=200)
        print_summary(mi_models[f"AGP_MI_{β}"])

    # OAK no penalty
    oak_creator = create_centered_se_base_kernel(X_np)
    k_oak = general_additive_kernel(INPUT_DIM, 2, oak_creator, "oak_")
    m_oak = train_gp(X_np, y_np, k_oak, maxiter=200)
    print_summary(m_oak)

    # MI heatmaps
    plt.figure(figsize=(12,4))
    for i,(label,mdl) in enumerate(
            [("AGP",m_agp)] + list(mi_models.items()) + [("OAK",m_oak)]):
        comps = get_component_posterior_means(mdl, tf.constant(X_np, gpflow.default_float())).numpy()
        hm = calculate_mi_heatmap_numpy(comps)
        ax=plt.subplot(1, len(mi_models)+3, i+1)
        ax.imshow(hm, cmap='viridis', vmin=0)
        ax.set_title(label)
        ax.set_xticks([]); ax.set_yticks([])
    plt.tight_layout(); plt.show()






=== Additive ===
╒═════════════════════════╤═══════════╤══════════════════╤═════════╤═════════════╤═════════╤═════════╤════════════════════════════════════════════════╕
│ name                    │ class     │ transform        │ prior   │ trainable   │ shape   │ dtype   │ value                                          │
╞═════════════════════════╪═══════════╪══════════════════╪═════════╪═════════════╪═════════╪═════════╪════════════════════════════════════════════════╡
│ GPR.kernel.variance     │ Parameter │ Softplus         │         │ True        │ ()      │ float64 │ 0.27497019836157727                            │
├─────────────────────────┼───────────┼──────────────────┼─────────┼─────────────┼─────────┼─────────┼────────────────────────────────────────────────┤
│ GPR.kernel.lengthscales │ Parameter │ Softplus         │         │ True        │ (3,)    │ float64 │ [1.80428973e+03 2.10174039e-01 7.14546443e-01] │
├─────────────────────────┼───────────┼──────────────────┼─────────

InvalidArgumentError: Graph execution error:

Detected at node GatherV2_3 defined at (most recent call last):
  File "<frozen runpy>", line 198, in _run_module_as_main

  File "<frozen runpy>", line 88, in _run_code

  File "c:\Users\User\Documents\MEng-Oak\MEng-Project\.venv\Lib\site-packages\ipykernel_launcher.py", line 18, in <module>

  File "c:\Users\User\Documents\MEng-Oak\MEng-Project\.venv\Lib\site-packages\traitlets\config\application.py", line 1075, in launch_instance

  File "c:\Users\User\Documents\MEng-Oak\MEng-Project\.venv\Lib\site-packages\ipykernel\kernelapp.py", line 739, in start

  File "c:\Users\User\Documents\MEng-Oak\MEng-Project\.venv\Lib\site-packages\tornado\platform\asyncio.py", line 205, in start

  File "C:\Program Files\WindowsApps\PythonSoftwareFoundation.Python.3.12_3.12.2800.0_x64__qbz5n2kfra8p0\Lib\asyncio\base_events.py", line 645, in run_forever

  File "C:\Program Files\WindowsApps\PythonSoftwareFoundation.Python.3.12_3.12.2800.0_x64__qbz5n2kfra8p0\Lib\asyncio\base_events.py", line 1999, in _run_once

  File "C:\Program Files\WindowsApps\PythonSoftwareFoundation.Python.3.12_3.12.2800.0_x64__qbz5n2kfra8p0\Lib\asyncio\events.py", line 88, in _run

  File "c:\Users\User\Documents\MEng-Oak\MEng-Project\.venv\Lib\site-packages\ipykernel\kernelbase.py", line 545, in dispatch_queue

  File "c:\Users\User\Documents\MEng-Oak\MEng-Project\.venv\Lib\site-packages\ipykernel\kernelbase.py", line 534, in process_one

  File "c:\Users\User\Documents\MEng-Oak\MEng-Project\.venv\Lib\site-packages\ipykernel\kernelbase.py", line 437, in dispatch_shell

  File "c:\Users\User\Documents\MEng-Oak\MEng-Project\.venv\Lib\site-packages\ipykernel\ipkernel.py", line 362, in execute_request

  File "c:\Users\User\Documents\MEng-Oak\MEng-Project\.venv\Lib\site-packages\ipykernel\kernelbase.py", line 778, in execute_request

  File "c:\Users\User\Documents\MEng-Oak\MEng-Project\.venv\Lib\site-packages\ipykernel\ipkernel.py", line 449, in do_execute

  File "c:\Users\User\Documents\MEng-Oak\MEng-Project\.venv\Lib\site-packages\ipykernel\zmqshell.py", line 549, in run_cell

  File "c:\Users\User\Documents\MEng-Oak\MEng-Project\.venv\Lib\site-packages\IPython\core\interactiveshell.py", line 3075, in run_cell

  File "c:\Users\User\Documents\MEng-Oak\MEng-Project\.venv\Lib\site-packages\IPython\core\interactiveshell.py", line 3130, in _run_cell

  File "c:\Users\User\Documents\MEng-Oak\MEng-Project\.venv\Lib\site-packages\IPython\core\async_helpers.py", line 128, in _pseudo_sync_runner

  File "c:\Users\User\Documents\MEng-Oak\MEng-Project\.venv\Lib\site-packages\IPython\core\interactiveshell.py", line 3334, in run_cell_async

  File "c:\Users\User\Documents\MEng-Oak\MEng-Project\.venv\Lib\site-packages\IPython\core\interactiveshell.py", line 3517, in run_ast_nodes

  File "c:\Users\User\Documents\MEng-Oak\MEng-Project\.venv\Lib\site-packages\IPython\core\interactiveshell.py", line 3577, in run_code

  File "C:\Users\User\AppData\Local\Temp\ipykernel_29532\674261087.py", line 189, in <module>

  File "C:\Users\User\AppData\Local\Temp\ipykernel_29532\674261087.py", line 145, in train_gp

  File "c:\Users\User\Documents\MEng-Oak\MEng-Project\.venv\Lib\site-packages\gpflow\optimizers\scipy.py", line 159, in minimize

  File "c:\Users\User\Documents\MEng-Oak\MEng-Project\.venv\Lib\site-packages\scipy\optimize\_minimize.py", line 713, in minimize

  File "c:\Users\User\Documents\MEng-Oak\MEng-Project\.venv\Lib\site-packages\scipy\optimize\_lbfgsb_py.py", line 309, in _minimize_lbfgsb

  File "c:\Users\User\Documents\MEng-Oak\MEng-Project\.venv\Lib\site-packages\scipy\optimize\_optimize.py", line 402, in _prepare_scalar_function

  File "c:\Users\User\Documents\MEng-Oak\MEng-Project\.venv\Lib\site-packages\scipy\optimize\_differentiable_functions.py", line 166, in __init__

  File "c:\Users\User\Documents\MEng-Oak\MEng-Project\.venv\Lib\site-packages\scipy\optimize\_differentiable_functions.py", line 262, in _update_fun

  File "c:\Users\User\Documents\MEng-Oak\MEng-Project\.venv\Lib\site-packages\scipy\optimize\_differentiable_functions.py", line 163, in update_fun

  File "c:\Users\User\Documents\MEng-Oak\MEng-Project\.venv\Lib\site-packages\scipy\optimize\_differentiable_functions.py", line 145, in fun_wrapped

  File "c:\Users\User\Documents\MEng-Oak\MEng-Project\.venv\Lib\site-packages\scipy\optimize\_optimize.py", line 78, in __call__

  File "c:\Users\User\Documents\MEng-Oak\MEng-Project\.venv\Lib\site-packages\scipy\optimize\_optimize.py", line 72, in _compute_if_needed

  File "c:\Users\User\Documents\MEng-Oak\MEng-Project\.venv\Lib\site-packages\gpflow\optimizers\scipy.py", line 224, in _eval

  File "c:\Users\User\Documents\MEng-Oak\MEng-Project\.venv\Lib\site-packages\gpflow\optimizers\scipy.py", line 190, in _tf_eval

  File "c:\Users\User\Documents\MEng-Oak\MEng-Project\.venv\Lib\site-packages\gpflow\optimizers\scipy.py", line 192, in _tf_eval

  File "c:\Users\User\Documents\MEng-Oak\MEng-Project\.venv\Lib\site-packages\gpflow\optimizers\scipy.py", line 329, in _compute_loss_and_gradients

  File "C:\Users\User\AppData\Local\Temp\ipykernel_29532\674261087.py", line 137, in loss

  File "c:\Users\User\Documents\MEng-Oak\MEng-Project\.venv\Lib\site-packages\check_shapes\integration\tf.py", line 89, in wrapped_method

  File "c:\Users\User\Documents\MEng-Oak\MEng-Project\.venv\Lib\site-packages\check_shapes\decorator.py", line 121, in wrapped_function

  File "c:\Users\User\Documents\MEng-Oak\MEng-Project\.venv\Lib\site-packages\check_shapes\decorator.py", line 123, in wrapped_function

  File "c:\Users\User\Documents\MEng-Oak\MEng-Project\.venv\Lib\site-packages\gpflow\models\gpr.py", line 100, in log_marginal_likelihood

  File "c:\Users\User\Documents\MEng-Oak\MEng-Project\.venv\Lib\site-packages\gpflow\kernels\base.py", line 290, in __call__

  File "c:\Users\User\Documents\MEng-Oak\MEng-Project\.venv\Lib\site-packages\gpflow\kernels\base.py", line 290, in __call__

  File "c:\Users\User\Documents\MEng-Oak\MEng-Project\.venv\Lib\site-packages\check_shapes\integration\tf.py", line 89, in wrapped_method

  File "c:\Users\User\Documents\MEng-Oak\MEng-Project\.venv\Lib\site-packages\check_shapes\decorator.py", line 121, in wrapped_function

  File "c:\Users\User\Documents\MEng-Oak\MEng-Project\.venv\Lib\site-packages\check_shapes\decorator.py", line 123, in wrapped_function

  File "c:\Users\User\Documents\MEng-Oak\MEng-Project\.venv\Lib\site-packages\gpflow\kernels\base.py", line 209, in __call__

  File "c:\Users\User\Documents\MEng-Oak\MEng-Project\.venv\Lib\site-packages\gpflow\kernels\base.py", line 214, in __call__

  File "C:\Users\User\AppData\Local\Temp\ipykernel_29532\674261087.py", line 91, in K

indices[0] = 1 is not in [0, 1)
	 [[{{node GatherV2_3}}]] [Op:__inference__tf_eval_39757]