# Bayesian Optimization on the 3-Sphere

Showcase notebook for Bayesian optimization (BO) over the vertices of a mesh of the 3-sphere $S^3$, embedded in $\mathbb{R}^4$. It covers the geometry only; the earlier 3D/chemistry notes have been removed.

In [1]:
%pip install -e "../Altered packages/GeometricKernels" pymanopt matplotlib ipympl kaleido plotly scipy

import random
from pathlib import Path

import geometric_kernels
import kaleido
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import plotly
import plotly.io as pio
from geometric_kernels.kernels import MaternGeometricKernel, MaternKarhunenLoeveKernel
from geometric_kernels.spaces import Mesh, Hypersphere
from geometric_kernels.spaces.eigenfunctions import EigenfunctionsFromEigenvectors
from plotly.subplots import make_subplots
from Plotting import *
from scipy.stats import norm

pio.renderers.default = "browser"  # "browser" opens in browser; "vscode" for in-editor


Obtaining file:///workspaces/Newest-try-/Altered%20packages/GeometricKernels
  Installing build dependencies ... [?25ldone
[?25h  Checking if build backend supports build_editable ... [?25ldone
[?25h  Getting requirements to build editable ... [?25ldone
[?25h  Preparing editable metadata (pyproject.toml) ... [?25ldone
Building wheels for collected packages: geometric_kernels
  Building editable for geometric_kernels (pyproject.toml) ... [?25ldone
[?25h  Created wheel for geometric_kernels: filename=geometric_kernels-0.4-py3-none-any.whl size=8835 sha256=84623443072de2c554826b98dc0c7530ef53b187429818cdb5c3ace0f4c02adc
  Stored in directory: /tmp/pip-ephem-wheel-cache-pro1lqb2/wheels/a3/88/0f/62cc36c2f8a5adcb8e307bac8051fd1078533735e772aadd33
Successfully built geometric_kernels
Installing collected packages: geometric_kernels
  Attempting uninstall: geometric_kernels
    Found existing installation: geometric_kernels 0.4
    Uninstalling geometric_kernels-0.4:
      Successfull

INFO (geometric_kernels): Numpy backend is enabled. To enable other backends, don't forget to `import geometric_kernels.*backend name*`.
INFO (geometric_kernels): We may be suppressing some logging of external libraries. To override the logging policy, call `logging.basicConfig`.


Note: you may need to restart the kernel to use updated packages.


  from pkg_resources import resource_filename


## Mesh and kernel setup

In [2]:
MeshFolder_dir = Path.cwd()
Filename = "Delaunay 100point sphere.obj"
mesh = Mesh.load_mesh(str(MeshFolder_dir / Filename))
print("DIR:", MeshFolder_dir)
print("num_vertices:", mesh.num_vertices)

# Number of stored coordinates per vertex (the ambient dimension of the mesh file).
ambient_dim = mesh.vertices.shape[1]
# Plotly draws at most three coordinates, so extra ones are dropped for display only.
display_vertices = mesh.vertices[:, :3] if ambient_dim > 3 else mesh.vertices

# Precomputed eigenpairs for this mesh, read from the same folder as the mesh file.
eigenvals = np.load(MeshFolder_dir / "4D_eigenvals.npy").reshape(-1, 1)
eigenvectors = np.load(MeshFolder_dir / "4D_eigenvecs.npy")
# Every eigenvalue needs a matching eigenvector along one axis of the array.
if eigenvals.shape[0] not in eigenvectors.shape:
    raise ValueError(
        f"{eigenvals.shape[0]} eigenvalues do not match an eigenvector array "
        f"of shape {eigenvectors.shape}"
    )

eigenfunctions = EigenfunctionsFromEigenvectors(eigenvectors=eigenvectors)
kernel = MaternKarhunenLoeveKernel(
    space=mesh,
    eigenfunctions=eigenfunctions,
    eigenvalues_laplacian=eigenvals,
    # One Karhunen-Loeve level per loaded eigenvalue. Counting eigenvalues does not
    # depend on whether the eigenvector array is stored vertices-first or transposed.
    num_levels=eigenvals.shape[0],
    dimension=3,  # intrinsic dimension of S^3
    normalize=True,
)

LENGTH_SCALE, NU = 4, 0.5
# Placeholder only: the (altered) package does not expose a variance parameter yet,
# so this value is not passed to the kernel.
VARIANCE = "NEEDS_CONFIGURING"

params = kernel.init_params()
# Float arrays, so the hyperparameters never take an integer dtype.
params["lengthscale"] = np.array([LENGTH_SCALE], dtype=float)
params["nu"] = np.array([NU], dtype=float)


DIR: /workspaces/Newest-try-/4D sphere demo
num_vertices: 100


## Objective on $S^3$

In [3]:
def f(x):
    """Objective on S^3: the first ambient coordinate of each requested vertex.

    Args:
        x: vertex id(s), 1-indexed (a scalar or array-like of ints).

    Returns:
        np.ndarray: 1-D array with one value per id, in flattened input order.

    Raises:
        IndexError: if any id is < 1. Without this check, id 0 would silently
            wrap around to the last vertex through negative indexing.
    """
    idx = np.asarray(x, dtype=np.int64).flatten()
    if np.any(idx < 1):
        raise IndexError(f"vertex ids are 1-indexed; got {idx[idx < 1]}")
    coords = mesh.vertices[idx - 1]
    return coords[..., 0]


def j(x):
    """Element-wise version of `f`: returns an array shaped like `x`.

    `f` already handles arrays, so this only restores the input shape. A scalar
    input gives a 0-d array, as `np.vectorize(f)` did. Unlike `np.vectorize`, it
    does not call `f` once per element and convert each 1-element result to a
    scalar. It also accepts empty input.
    """
    x = np.asarray(x)
    return f(x).reshape(x.shape)


## BO setup

In [4]:
def expected_improvement(mu, sigma, f_best, xi=0.9):
    """
    Calculates the Expected Improvement (EI) for *minimisation* at a set of points.

        Z  = (f_best - mu - xi) / sigma
        EI = (f_best - mu - xi) * Phi(Z) + sigma * phi(Z)

    EI is defined as 0 wherever sigma <= 1e-10, i.e. where the posterior has no
    spread.

    Args:
        mu (np.ndarray): Posterior mean; any shape with n elements.
        sigma (np.ndarray): Posterior standard deviation; n elements.
        f_best (float): The best (lowest) observed function value.
        xi (float): Exploration margin; larger values favour exploration.
            The default of 0.9 matches the value this notebook has always used.

    Returns:
        np.ndarray: The EI values, shape (n, 1).
    """
    mu = np.asarray(mu, dtype=float).reshape(-1, 1)
    sigma = np.asarray(sigma, dtype=float).reshape(-1, 1)

    improvement = f_best - mu - xi
    has_spread = sigma > 1e-10
    # Divide only where sigma is meaningfully positive. Elsewhere the result is
    # masked to 0 anyway, so a dummy divisor avoids divide/invalid warnings.
    safe_sigma = np.where(has_spread, sigma, 1.0)
    Z = improvement / safe_sigma
    ei = improvement * norm.cdf(Z) + safe_sigma * norm.pdf(Z)

    return np.where(has_spread, ei, 0.0)


def BO_loop_fixed(num_iterations, x_obs=None, objective_func=None,
                  exploration_weight=0.1):
    """
    Run GP-based Bayesian optimisation (minimisation) over the mesh vertices.

    Each iteration does three things:
    - conditions a zero-mean GP (global `kernel` with `params`) on all observations so far;
    - scores every vertex with Expected Improvement;
    - evaluates the objective at the arg-max vertex.

    Args:
        num_iterations: number of new evaluations to make; must be >= 1.
        x_obs: (n, 1) integer array of already-observed vertex ids, 1-indexed.
            Defaults to the notebook-level `x_observed`, which must then exist.
        objective_func: callable taking 1-indexed vertex id(s); defaults to `f`.
        exploration_weight: currently unused; kept for backward compatibility.

    Returns:
        mew_vec: (V, 1) posterior mean from the last fit. This fit excludes the
            final evaluated point.
        Sigma_vec: (V, 1) posterior variance from that fit, clipped at 0.
        EI_vec: EI scores from that fit (used to choose the final point).
        x_obs: (n + num_iterations, 1) observed vertex ids, 1-indexed.
        y_observed: (n + num_iterations, 1) objective values at x_obs.

    Raises:
        ValueError: if num_iterations < 1 (no posterior would be fitted).
    """
    if num_iterations < 1:
        raise ValueError("num_iterations must be >= 1 so that a posterior is fitted")
    if objective_func is None:
        objective_func = f
    if x_obs is None:
        x_obs = x_observed  # notebook-level default; must be defined before calling

    num_verts = mesh.num_vertices
    whole_domain = np.arange(1, num_verts + 1).reshape(-1, 1)  # 1-indexed vertex ids
    zero_based_domain = whole_domain - 1  # kernel.K is called with 0-indexed ids

    y_observed = np.atleast_2d(np.apply_along_axis(objective_func, 1, x_obs)).reshape(-1, 1)
    K_XX_prior = kernel.K(params, zero_based_domain, zero_based_domain)
    mu_prior_vector = np.zeros((num_verts, 1))

    for _ in range(num_iterations):
        # --- GP posterior conditioned on every observation so far ---
        x_zero_based = x_obs - 1
        m_vector = mu_prior_vector[x_obs.flatten() - 1]
        K_xX = kernel.K(params, x_zero_based, zero_based_domain)
        K_xx = kernel.K(params, x_zero_based, x_zero_based)
        K_Xx = K_xX.T

        # Jitter plus pseudo-inverse keep this stable when a vertex is observed twice.
        C_inv = np.linalg.pinv(K_xx + np.eye(K_xx.shape[0]) * 1e-6)
        gain = K_Xx @ C_inv  # shared by the mean and covariance updates

        mew_vec = mu_prior_vector + gain @ (y_observed - m_vector)
        Current_K_matrix = K_XX_prior - gain @ K_xX
        # Round-off can make tiny variances negative; clip before the sqrt below.
        Sigma_vec = np.clip(np.diag(Current_K_matrix), 0.0, None).reshape(-1, 1)

        # --- Acquisition: EI against the best (lowest) value seen ---
        EI_vec = expected_improvement(mew_vec, np.sqrt(Sigma_vec), np.min(y_observed))
        next_point = np.atleast_2d(np.argmax(EI_vec) + 1)  # back to a 1-indexed id

        # --- Evaluate and append. The objective takes 1-indexed ids, like x_obs. ---
        y_next = np.atleast_2d(objective_func(next_point))
        x_obs = np.vstack((x_obs, next_point))
        y_observed = np.vstack((y_observed, y_next))

    return mew_vec, Sigma_vec, EI_vec, x_obs, y_observed

## BO run

In [5]:
# Objective value at every vertex. Node ids are 1-indexed, matching `f`;
# `f` accepts arrays directly, so no element-wise wrapper is needed here.
node_ids = np.arange(1, mesh.num_vertices + 1)
objective_vals = f(node_ids)

objective_plot_kwargs = dict(
    name="objective value",
    marker=dict(size=10, colorscale="hot"),
)
objective_trace = vector_values_to_mesh_trace(mesh, objective_vals, **objective_plot_kwargs)
hover_kwargs = dict(
    customdata=np.column_stack([objective_vals, node_ids]),
    hovertemplate=(
        "x : %{x:.2f}<br>"
        + "y : %{y:.2f}<br>"
        + "z : %{z:.2f}<br>"
        + "objective value: %{customdata[0]:.2f}<br>"
        + "node index: %{customdata[1]}"
    ),
)
objective_trace = add_custom_hover_data(objective_trace, **hover_kwargs)

# Reproducible random start, drawn uniformly from ALL vertices 1..V.
# Generator.integers excludes its upper bound, hence the +1.
BO_SEED = 0
rng = np.random.default_rng(BO_SEED)
initial_point = np.array([[rng.integers(1, mesh.num_vertices + 1)]])

# Three independent runs of increasing length from the same starting vertex.
mu_1, sigma_1, ei_1, X_1, Y_1 = BO_loop_fixed(1, x_obs=initial_point, objective_func=f)
mu_2, sigma_2, ei_2, X_2, Y_2 = BO_loop_fixed(5, x_obs=initial_point, objective_func=f)
mu_3, sigma_3, ei_3, X_3, Y_3 = BO_loop_fixed(10, x_obs=initial_point, objective_func=f)

advanced_fig = go.Figure()

# Posterior means after 1, 5 and 10 iterations, followed by the true objective.
mu_1_trace = add_custom_hover_data(vector_values_to_mesh_trace(mesh, mu_1), customdata=mu_1)
mu_2_trace = add_custom_hover_data(vector_values_to_mesh_trace(mesh, mu_2), customdata=mu_2)
mu_3_trace = add_custom_hover_data(vector_values_to_mesh_trace(mesh, mu_3), customdata=mu_3)
for trace in (mu_1_trace, mu_2_trace, mu_3_trace, objective_trace):
    advanced_fig.add_trace(trace)

# Vertices visited by the 5- and 10-iteration runs (first three coordinates, for 3D viewing).
proj_points_5 = display_vertices[np.int64(X_2.flatten() - 1)]
proj_points_10 = display_vertices[np.int64(X_3.flatten() - 1)]
point_plot_trace = go.Scatter3d(
    x=proj_points_5[:, 0],
    y=proj_points_5[:, 1],
    z=proj_points_5[:, 2],
    marker=dict(color="cyan"),
    mode="markers",
    name="explored (5 iterations)",
)
second_point_plot_trace = go.Scatter3d(
    x=proj_points_10[:, 0],
    y=proj_points_10[:, 1],
    z=proj_points_10[:, 2],
    marker=dict(color="purple"),
    mode="markers",
    name="explored (10 iterations)",
)

advanced_fig.add_trace(point_plot_trace)
advanced_fig.add_trace(second_point_plot_trace)

fig_show(advanced_fig)

# Report the best vertex as a 1-indexed node id, matching the hover "node index".
best_node = int(np.argmin(objective_vals)) + 1
print(
    "the sampled values of the 5-iteration run are (sorted ascending)",
    np.sort(Y_2.ravel()),
    ". While the true best is",
    np.min(objective_vals),
    "which is at node index:",
    best_node,
)
print("the initial point's value is", objective_vals[initial_point.item() - 1])


KeyboardInterrupt: 

## Kernel influence

In [None]:
def visualize_kernel_influence(mesh, kernel, params, source_point, objective_vals):
    """Visualize how much influence a single point has across the mesh.

    Colours every vertex v by k(source_point, v) and marks the source vertex.

    Args:
        mesh: mesh whose vertices are coloured.
        kernel: kernel whose `K(params, X, X2)` takes 0-indexed vertex ids.
        params: kernel hyperparameters.
        source_point: 0-indexed id of the source vertex.
        objective_vals: unused; kept so existing calls keep working.

    Returns:
        plotly Figure with the influence map and the source marker.
    """
    source_idx = np.atleast_2d([source_point])
    all_points = np.arange(mesh.num_vertices).reshape(-1, 1)

    K_influence = kernel.K(params, source_idx, all_points).flatten()
    influence_customdata = K_influence.reshape(-1, 1)

    fig = go.Figure()
    influence_trace = vector_values_to_mesh_trace(
        mesh,
        K_influence,
        marker=dict(colorscale="Viridis", colorbar=dict(title="Kernel Value")),
    )
    influence_trace = add_custom_hover_data(influence_trace, customdata=influence_customdata)
    fig.add_trace(influence_trace)

    # Use the mesh passed in (first three coordinates, for 3D display) rather than
    # a notebook-global projection.
    source_coord = mesh.vertices[source_point, :3]
    source_trace = go.Scatter3d(
        x=[source_coord[0]],
        y=[source_coord[1]],
        z=[source_coord[2]],
        mode="markers",
        marker=dict(size=15, color="red", symbol="diamond"),
        name="Source Point",
    )
    # The marker is one point, so it carries only its own kernel value. Attaching
    # the whole vector would make its hover show vertex 0's value instead.
    source_trace = add_custom_hover_data(
        source_trace, customdata=influence_customdata[[source_point]]
    )
    fig.add_trace(source_trace)

    fig.update_layout(title=f"Kernel Influence from Point {source_point}")
    return fig

# Vertex whose kernel "footprint" is shown, as a 1-indexed node id like the BO plots.
# A dedicated name is used so the BO run's `initial_point` is not overwritten.
INFLUENCE_SOURCE_NODE = 4
influence_source = INFLUENCE_SOURCE_NODE - 1  # visualize_kernel_influence takes 0-indexed ids
influence_fig = visualize_kernel_influence(
    mesh, kernel, params,
    influence_source, objective_vals,
)
fig_show(influence_fig)


KeyboardInterrupt: 

## Geodesic vs. mesh kernel

In [None]:
# Empty (0, ambient_dim) float32 placeholder for the original point cloud read from PLY.
original_mesh_points = np.zeros((0, ambient_dim), np.float32)
# Vertex coordinates of the loaded mesh (named after the Poisson reconstruction it
# presumably comes from -- confirm).
poisson_mesh_points = mesh.vertices


In [None]:
# Load the original point cloud from an ASCII PLY file. The array is rebuilt from
# scratch every time, so re-running this cell does not append duplicate points.
def _read_ascii_ply_vertices(path, ndim):
    """Return the first `ndim` coordinates of each vertex in an ASCII PLY file, shape (N, ndim)."""
    rows = []
    n_declared = None  # vertex count from "element vertex N", if the header has one
    with open(path, "r") as ply_file:
        for line in ply_file:  # header
            tokens = line.split()
            if len(tokens) >= 3 and tokens[:2] == ["element", "vertex"]:
                n_declared = int(tokens[2])
            if tokens == ["end_header"]:
                break
        for line in ply_file:  # body
            tokens = line.split()
            if not tokens:
                continue  # tolerate blank lines
            if n_declared is not None and len(rows) == n_declared:
                break  # stop before any face/edge elements that follow the vertices
            rows.append([float(token) for token in tokens[:ndim]])
    if not rows:
        return np.empty((0, ndim))
    return np.asarray(rows, dtype=float)


original_mesh_points = _read_ascii_ply_vertices("100_points_before_poisson.ply", ambient_dim)


In [None]:
def find_nearest_poisson(original_point, vertices=None):
    """Return the 0-indexed id of the mesh vertex nearest to `original_point`.

    Mesh vertices are first projected onto the unit sphere, and Euclidean distance
    to `original_point` is measured from there. Ties resolve to the lowest index.
    Zero-length vertices, which cannot be normalised, are skipped.

    Args:
        original_point: (D,) coordinates of the query point.
        vertices: optional (V, D) vertex array; defaults to the global `mesh.vertices`.

    Returns:
        int: index of the nearest vertex.
    """
    if vertices is None:
        vertices = mesh.vertices
    vertices = np.asarray(vertices, dtype=float)
    with np.errstate(invalid="ignore", divide="ignore"):
        unit_vertices = vertices / np.linalg.norm(vertices, axis=1, keepdims=True)
    distances = np.linalg.norm(unit_vertices - np.asarray(original_point, dtype=float), axis=1)
    # nanargmin ignores NaN distances from zero-length vertices and keeps the first minimum.
    return int(np.nanargmin(distances))

# Snap each original point to its nearest mesh vertex (0-indexed id). The ids are
# stored as integers because they are passed straight to kernel.K as vertex indices.
INDS_VECTOR = np.array(
    [find_nearest_poisson(point) for point in original_mesh_points], dtype=np.int64
).reshape(-1, 1)

# Mesh-kernel Gram matrix between the snapped points (same params as the BO kernel).
BASTARD_KERNEL_MATRIX = kernel.K(params, INDS_VECTOR, INDS_VECTOR)


In [None]:
# Compute the continuous Matern kernel on S^3 at the original (unsnapped) points.
proper_cts_sphere = Hypersphere(dim=3)
cts_kernel = MaternGeometricKernel(proper_cts_sphere)

# Project the original points onto the unit sphere before treating them as points of S^3.
original_mesh_points_normalized = original_mesh_points / np.linalg.norm(
    original_mesh_points, axis=1, keepdims=True
)
# Reuses the mesh kernel's `params`, presumably so the two kernels are comparable.
CTS_kernel_matrix = cts_kernel.K(params, original_mesh_points_normalized, original_mesh_points_normalized)

# Geodesic (great-circle) distance between unit vectors = arccos of their dot product,
# clipped to [-1, 1] to absorb round-off. Vectorised, so no loop variable named `j`
# overwrites the notebook's objective function `j`.
cosines = np.clip(
    original_mesh_points_normalized @ original_mesh_points_normalized.T, -1.0, 1.0
)
Dist_matrix = np.arccos(cosines)
# Mirror the upper triangle so the matrix is exactly symmetric.
Dist_matrix = np.triu(Dist_matrix) + np.triu(Dist_matrix, 1).T


## Visualization

In [None]:
# Flattened (row-major) copies of the three pairwise matrices. Position k in each
# array refers to the same point pair, so every geodesic distance in `c` lines up
# with its continuous-kernel value in `b` and its mesh-kernel value in `a`.
c, b, a = (
    matrix.flatten()
    for matrix in (Dist_matrix, CTS_kernel_matrix, BASTARD_KERNEL_MATRIX)
)


In [None]:
# Scatter plot comparison: kernel value against geodesic distance for every point pair.
fig, ax = plt.subplots(figsize=(10, 6))
ax.scatter(c, a, c='tab:red', label='Mesh kernel', s=3)
# Mathtext label instead of a raw "S³" character, which had been mis-encoded as "SÂ³".
ax.scatter(c, b, c='tab:blue', label='Continuous $S^3$ kernel', s=3)
ax.set(
    xlabel='Geodesic distance',
    ylabel='Kernel values',
    title='Mesh vs. Continuous Kernel: Geodesic Distance Comparison',
)
ax.legend(loc='upper left')
plt.show()


"\nimport numpy as np\nimport matplotlib.pyplot as plt\n\n\n\n# Ensure data is 1D for plotting (flattening the column vectors)\nc = c.flatten()\na = a.flatten()\nb = b.flatten()\n\n# 1. Create figure and primary Axes (ax1)\nfig, ax1 = plt.subplots(figsize=(10, 6))\n\n# Set the primary plot (a vs c) on the left Y-axis\ncolor_a = 'tab:red'\nax1.set_xlabel('Vector c (X-axis)')\nax1.set_ylabel('Vector a (Y1)', color=color_a)\nax1.scatter(c, a, c=color_a, label='Vector a', s = 3)\nax1.tick_params(axis='y', labelcolor=color_a)\n\n# 2. Create secondary Axes (ax2) sharing the X-axis\nax2 = ax1.twinx()\n\n# Set the secondary plot (b vs c) on the right Y-axis\ncolor_b = 'tab:blue'\nax2.set_ylabel('Vector b (Y2)', color=color_b)\nax2.scatter(c, b, c=color_b, label='Vector b', s = 3)\nax2.tick_params(axis='y', labelcolor=color_b)\n\n\n#TRYING TO FIX THE SCALING ISSUES:\n\n\n\n# 3. Add title and legend\nplt.title('Scatter Plots with Dual Y-Axes')\n# Combine legends from both axes\nlines1, labels1 =