# Setting up a 3D Bayesian optimization problem with Gaussian process regression

Jackson S. Bentley, Sumner B. Harris
09/01/25

In [None]:
# Environment setup for the notebook: pin a jax / jaxlib / ml_dtypes combination,
# install gpax on top of it, then import both and configure JAX.
# Lines starting with "!" are notebook shell escapes, not Python statements.
# NOTE(review): the recorded output of this cell reports JAX 0.8.0 although 0.6.2
# is pinned here - possibly the later gpax install replaced it, or the runtime
# was not restarted after the reinstall. Confirm which versions are really in use.

# Step 1: Clean uninstall of conflicting packages
!pip uninstall -y jax jaxlib ml-dtypes gpax

# Step 2: Install compatible versions
!pip install jax==0.6.2 jaxlib==0.6.2 ml_dtypes==0.5.1

# Step 3: Install gpax version that supports jax>=0.6.2
!pip install gpax==0.1.9

# Step 4: Force CPU use and test
import jax
import gpax

jax.config.update("jax_platform_name", "cpu")  # CPU only
# Switch gpax/JAX to 64-bit floats via the gpax helper.
gpax.utils.enable_x64()

# Report the JAX version that was actually imported (not necessarily the pinned one).
print("✅ GPAX is ready with JAX version:", jax.__version__)


Found existing installation: jax 0.5.3
Uninstalling jax-0.5.3:
  Successfully uninstalled jax-0.5.3
Found existing installation: jaxlib 0.5.3
Uninstalling jaxlib-0.5.3:
  Successfully uninstalled jaxlib-0.5.3
Found existing installation: ml_dtypes 0.5.3
Uninstalling ml_dtypes-0.5.3:
  Successfully uninstalled ml_dtypes-0.5.3
[0mCollecting jax==0.6.2
  Downloading jax-0.6.2-py3-none-any.whl.metadata (13 kB)
Collecting jaxlib==0.6.2
  Downloading jaxlib-0.6.2-cp312-cp312-manylinux2014_x86_64.whl.metadata (1.3 kB)
Collecting ml_dtypes==0.5.1
  Downloading ml_dtypes-0.5.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (21 kB)
Downloading jax-0.6.2-py3-none-any.whl (2.7 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.7/2.7 MB[0m [31m35.2 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading jaxlib-0.6.2-cp312-cp312-manylinux2014_x86_64.whl (89.9 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m89.9/89.9 MB[0m [31m6.8 MB/s[0m eta [36



✅ GPAX is ready with JAX version: 0.8.0


In [None]:
# Degree sign (U+00B0), kept in a variable for use in plot labels later on
deg_sgn = chr(0x00B0)


# GP-BO weighting loop, starting from scratch: varying only n and k (the weights on the deviation from the literature out-of-plane lattice parameter and on the integrated LaVO4 intensity)

In [None]:
# Script that:
# - Loads Excel data
# - Builds the 3D grid (Pressure (log10), Temperature, Fluence)
# - Loops over n,k in {1,2,3} (m=l=1), computes y_train for each weight set
# - Retrains the gpax viGP for each (n,k) from scratch
# - Predicts across the full 3D grid, finds the best predicted point and best measured point
# - Saves a results table, histogram, and an interactive Plotly 3D scatter of best points
#
# Notes:
# - This script intentionally does NOT plot the full GP maps for every (n,k) to save time.
#   You can enable per-run plotting by setting `plot_each_run = True`.
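# - Speed tips: reduce grid resolution (larger step sizes) or shorten each fit (if you expose
#   fit arguments; the model used below is gpax.viGP, a variational-inference GP, so the knob
#   is the number of optimization steps rather than HMC samples). Avoid running multiple GP
#   fits simultaneously (JAX + multiprocessing can deadlock).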

# %%
# Install (uncomment if needed)
# NOTE: gpax may be a private package in your environment. If gpax is installed locally, skip install.
# !pip install gpax plotly openpyxl numpyro

# %%
# Imports and global settings
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from itertools import product
from pathlib import Path

import jax
import jax.numpy as jnp
import gpax

# Enable 64-bit (same as your previous cell)
gpax.utils.enable_x64()

# Plotly for interactive 3D
import plotly.graph_objects as go

# Colab file upload helper
# `files` is the Colab upload helper when available and None everywhere else.
files = None
try:
    from google.colab import files
except Exception:
    # Not running in Colab (or the helper could not be imported): keep the None sentinel.
    files = None


# Literature value of the out-of-plane (c) lattice parameter. It is subtracted
# from the spreadsheet's "C Lattice Parameter ... (Å)" column in the weight loop,
# so it is presumably in Å as well - confirm against the literature source.
LVO_c_lit = 3.945 # out-of-plane lattice parameter from literature

# %%
# --- Helper functions (normalize / inverse normalize / save/load) ---

def normalize(data, min_val, max_val):
    """Min-max scale ``data`` so that ``min_val`` maps to 0 and ``max_val`` to 1.

    Works elementwise on anything supporting arithmetic (scalars, arrays).
    Values outside ``[min_val, max_val]`` are not clipped.
    """
    span = max_val - min_val
    shifted = data - min_val
    return shifted / span


def inverse_normalize(norm_data, min_val, max_val):
    """Undo :func:`normalize`: map 0 back to ``min_val`` and 1 back to ``max_val``.

    Works elementwise on anything supporting arithmetic (scalars, arrays).
    """
    span = max_val - min_val
    scaled = norm_data * span
    return scaled + min_val


def save_data(
    file_name,
    X_train, y_train,
    X_test,
    x1, x2, x3,
    y_pred, y_sampled,
    acq, next_point,
    running_best,
    iteration
):
    """Write one optimization checkpoint with ``np.savez``.

    Every argument after ``file_name`` is stored under a key equal to its
    parameter name; ``load_data`` reads the same keys back.

    Args:
        file_name: Target passed straight to ``np.savez`` (path or file object).
        X_train, y_train: Training inputs and targets.
        X_test: Candidate points the model was evaluated on.
        x1, x2, x3: The three grid axes.
        y_pred, y_sampled: Model predictions on ``X_test``.
        acq: Acquisition values on ``X_test``.
        next_point: Suggested next point (may be None).
        running_best: Best score so far.
        iteration: Iteration counter.
    """
    checkpoint = {
        'X_train': X_train,
        'y_train': y_train,
        'X_test': X_test,
        'x1': x1,
        'x2': x2,
        'x3': x3,
        'y_pred': y_pred,
        'y_sampled': y_sampled,
        'acq': acq,
        'next_point': next_point,
        'running_best': running_best,
        'iteration': iteration,
    }
    np.savez(file_name, **checkpoint)


def load_data(file_name):
    """Read back a checkpoint written by ``save_data``.

    Args:
        file_name: Path (or file object) of the ``.npz`` archive.

    Returns:
        A 12-tuple of arrays in this order: X_train, y_train, X_test, x1, x2,
        x3, y_pred, y_sampled, acq, next_point, running_best, iteration.

    Raises:
        KeyError: If the archive lacks one of the expected keys.
    """
    keys = (
        'X_train', 'y_train', 'X_test',
        'x1', 'x2', 'x3',
        'y_pred', 'y_sampled',
        'acq', 'next_point',
        'running_best', 'iteration',
    )
    # SECURITY NOTE: allow_pickle=True is needed for object arrays (e.g. a
    # checkpoint saved with next_point=None) but unpickling can run arbitrary
    # code - only load checkpoints you created yourself.
    # The context manager closes the underlying archive; every array is read
    # into memory inside the block, so the returned values stay valid.
    with np.load(file_name, allow_pickle=True) as ds:
        return tuple(ds[key] for key in keys)


def update_datapoints(X_new, y_new, X_train, y_train):
    """Return the training set extended by the new observations.

    ``X_new`` / ``y_new`` are appended after the existing rows along axis 0;
    the results are JAX arrays and the inputs are left untouched.
    """
    grown_inputs = jnp.append(X_train, X_new, axis=0)
    grown_targets = jnp.append(y_train, y_new, axis=0)
    return grown_inputs, grown_targets

# Plotting routine (matplotlib projections of a 3D field onto the three axis pairs).
# Note: this function reads the module-level normalization bounds pressure_min, pressure_max,
# T_min, T_max, fluence_min, fluence_max, which are set in the grid build cell below.
# The weights printed in the figure title come from the `weights` argument or, when it is
# omitted, from the module-level alpha, beta, gamma, delta set by the weight loop.

def plot_3d_projections(
    field_flat,
    P_array,
    T_array,
    F_array,
    title: str,
    X_train: np.ndarray = None,
    y_train: np.ndarray = None,
    deg_sgn: str = '°',
    weights=None,
):
    """Show the three 2D projections of a field sampled on the (P, T, F) grid.

    The flat field is reshaped to ``(len(P_array), len(T_array), len(F_array))``
    and averaged over the remaining axis for each of the panels T-vs-P, F-vs-P
    and F-vs-T. If both ``X_train`` and ``y_train`` are given, the training
    points are overlaid, colored by score, with their own colorbar.

    Args:
        field_flat: Field values, one per grid point, flattened with P as the
            slowest and F as the fastest index.
        P_array: Pressure axis (plotted as-is; labelled as log10 Torr).
        T_array: Temperature axis.
        F_array: Fluence axis.
        title: Prefix for the panel titles and the figure title.
        X_train: Normalized training inputs with columns (P, T, F); mapped
            back to axis units with the module-level bounds before plotting.
        y_train: Training scores used to color the overlaid points.
        deg_sgn: Degree symbol used in the temperature labels.
        weights: Optional ``(alpha, beta, gamma, delta)`` shown in the figure
            title. When None, the module-level variables of those names are
            used if all four exist; otherwise the weights line is omitted
            (previously this case raised NameError after drawing the figure).

    Raises:
        ValueError: If ``field_flat`` does not have one value per grid point.
    """
    field = np.asarray(field_flat)
    nP, nT, nF = len(P_array), len(T_array), len(F_array)
    if field.size != nP * nT * nF:
        raise ValueError(f"field length {field.size} != {nP}*{nT}*{nF}")

    # Resolve the title weights up front so a missing global cannot abort the
    # call after all the drawing work has been done.
    if weights is None:
        module_scope = globals()
        weight_names = ('alpha', 'beta', 'gamma', 'delta')
        if all(name in module_scope for name in weight_names):
            weights = tuple(module_scope[name] for name in weight_names)

    field3d = field.reshape(nP, nT, nF)
    proj_PT = field3d.mean(axis=2)  # average over fluence
    proj_PF = field3d.mean(axis=1)  # average over temperature
    proj_TF = field3d.mean(axis=0)  # average over pressure

    fig, axs = plt.subplots(1, 3, figsize=(18, 6), dpi=200, constrained_layout=True)
    for ax in axs:
        ax.tick_params(axis='both', labelsize=12)
    font_title = 20
    font_label = 18
    font_tick = 14

    # pcolormesh expects C with shape (len(y), len(x)), hence the transposes.
    im0 = axs[0].pcolormesh(P_array, T_array, proj_PT.T, shading='auto', cmap='nipy_spectral')
    axs[0].set_xlabel('log$_{10}$ O$_2$ Partial Pressure (log$_{10}$ Torr)', fontsize=font_label)
    axs[0].set_ylabel(f'Deposition Temperature ({deg_sgn}C)', fontsize=font_label)
    axs[0].set_title(f'{title} → T vs P', fontsize=font_title)
    fig.colorbar(im0, ax=axs[0], orientation='vertical')

    im1 = axs[1].pcolormesh(P_array, F_array, proj_PF.T, shading='auto', cmap='nipy_spectral')
    axs[1].set_xlabel('log$_{10}$ O$_2$ Partial Pressure (log$_{10}$ Torr)', fontsize=font_label)
    axs[1].set_ylabel('Fluence (J/cm$^2$)', fontsize=font_label)
    axs[1].set_title(f'{title} → F vs P', fontsize=font_title)
    fig.colorbar(im1, ax=axs[1], orientation='vertical')

    im2 = axs[2].pcolormesh(T_array, F_array, proj_TF.T, shading='auto', cmap='nipy_spectral')
    axs[2].set_xlabel(f'Deposition Temperature ({deg_sgn}C)', fontsize=font_label)
    axs[2].set_ylabel('Fluence (J/cm$^2$)', fontsize=font_label)
    axs[2].set_title(f'{title} → F vs T', fontsize=font_title)
    fig.colorbar(im2, ax=axs[2], orientation='vertical')

    if X_train is not None and y_train is not None:
        # Training inputs are stored normalized; convert back to axis units.
        P_train = inverse_normalize(X_train[:,0], pressure_min, pressure_max)
        T_train = inverse_normalize(X_train[:,1], T_min, T_max)
        F_train = inverse_normalize(X_train[:,2], fluence_min, fluence_max)
        y_vals  = y_train.flatten()

        scatter_kwargs = dict(c=y_vals, cmap='Reds', edgecolor='k', s=200, alpha=0.8)

        # All three overlays share the same color mapping, so only the first
        # handle is needed for the colorbar.
        sc0 = axs[0].scatter(P_train, T_train, **scatter_kwargs)
        axs[1].scatter(P_train, F_train, **scatter_kwargs)
        axs[2].scatter(T_train, F_train, **scatter_kwargs)

        from mpl_toolkits.axes_grid1.inset_locator import inset_axes
        cbar_ax = inset_axes(
            axs[2],
            width="60%",
            height="5%",
            loc='upper left',
            bbox_to_anchor=(0.55, 0.3, 1, 1),
            bbox_transform=axs[2].transAxes,
            borderpad=0,
        )
        cbar = fig.colorbar(sc0, cax=cbar_ax, orientation='horizontal')
        cbar.set_label('Sample\nScores', fontsize=font_label)
        cbar.ax.tick_params(labelsize=font_tick)

    title_lines = [
        f"{title}: Min-Max Scaling",
        r"$y_{\mathrm{train}} = \alpha \cdot |c - 3.945| + \beta \cdot R_{RMS} + \gamma \cdot FWHM + \delta \cdot LaVO_{4}$",
    ]
    if weights is not None:
        w_alpha, w_beta, w_gamma, w_delta = weights
        title_lines.append(f"[α = {w_alpha}, β = {w_beta}, γ = {w_gamma}, δ = {w_delta}]")
    fig.suptitle(
        "\n".join(title_lines),
        fontsize=16,
        fontweight='bold',
        y=1.17
    )
    plt.show()

# %%
# --- Build the 3D grid (pressures are represented as log10(pressure)) ---
# This cell defines the module-level axes (Pressure, Temperature, Fluence), the
# normalization bounds (pressure_min/max, T_min/max, fluence_min/max) used by the
# plotting and loop code, and the flattened candidate grid points_3d / X_test_jnp.

# Pressure bounds
start, stop = 3e-8, 3e-4
num_points_per_decade = 10
# NOTE(review): int() truncates, so a log10 ratio that lands just below a whole
# number would lose a decade. The recorded run shows 41 pressure points (4 decades).
total_decades = int(np.log10(stop / start))
total_points = total_decades * num_points_per_decade + 1

# Pressure array stored as log10(pressure)
Pressure = np.log10(np.geomspace(start, stop, num=total_points))
pressure_min, pressure_max = Pressure.min(), Pressure.max()

# Temperature
# NOTE(review): np.arange excludes the stop value, so the axis tops out below
# T_max (the recorded run shows 17 points, i.e. 500..820) while normalization
# still uses T_max = 835. Confirm this is intended.
T_min, T_max, T_stepsize = 500, 835, 20
Temperature = np.arange(T_min, T_max, T_stepsize, dtype=np.float32)

# Fluence
# NOTE(review): same exclusion here - the recorded run shows 14 points, so the
# largest grid fluence is about 2.1 and fluence_max = 2.2 itself is never a
# candidate. Confirm this is intended.
fluence_min, fluence_max, fluence_stepsize = 0.8, 2.2, 0.1
Fluence = np.arange(fluence_min, fluence_max, fluence_stepsize, dtype=np.float32)

# Normalized grids
P_norm = normalize(Pressure, pressure_min, pressure_max)
T_norm = normalize(Temperature, T_min, T_max)
F_norm = normalize(Fluence, fluence_min, fluence_max)

# Meshgrid in normalized space
# indexing='ij' keeps the axis order (P, T, F), so flattening makes P the slowest
# and F the fastest index - the layout plot_3d_projections reshapes back to.
Pg, Tg, Fg = np.meshgrid(P_norm, T_norm, F_norm, indexing='ij')
P_flat = Pg.reshape(-1, 1)
T_flat = Tg.reshape(-1, 1)
F_flat = Fg.reshape(-1, 1)
# One row per grid point, columns (P, T, F) in normalized units.
points_3d = np.hstack((P_flat, T_flat, F_flat))

print('3D parameter space has size:', points_3d.shape)
print('Pressure points (log10):', Pressure.shape)
print('Temperature points:', Temperature.shape)
print('Fluence points:', Fluence.shape)
print('Example back-transformed Pressure (Torr):', 10**(Pressure[:3]))

# Convert X_test to jax array once for faster device transfer
X_test_jnp = jnp.asarray(points_3d)

# %%
# --- GP step function (wraps gpax.viGP) ---
import numpyro
from numpyro import distributions
from jax import random

# Optionally define priors here (you can tune these or set to None)
# Both values are forwarded unchanged to gpax.viGP through step_GP.
# Presumably None makes gpax fall back to its own default priors - confirm in the gpax docs.
lengthscale_prior_dist = None
noise_prior_dist = None


def step_GP(X_measured, y_measured, X_unmeasured,
            noise_prior_dist=None,
            lengthscale_prior_dist=None):
    """Fit a fresh ``gpax.viGP`` on the measured data and evaluate it on the candidates.

    A new 3-input Matern-kernel model is built on every call, trained on
    ``(X_measured, y_measured)``, then used to predict on ``X_unmeasured`` and
    to compute the gpax Expected Improvement acquisition for minimization
    (``maximize=False``), with the measured points passed as ``recent_points``
    and ``penalty='delta'``.

    Args:
        X_measured: Training inputs; converted with ``jnp.asarray``.
        y_measured: Training targets; converted and flattened to 1D.
        X_unmeasured: Candidate points to predict on.
        noise_prior_dist: Forwarded to ``gpax.viGP``.
        lengthscale_prior_dist: Forwarded to ``gpax.viGP``.

    Returns:
        ``(acquisition, y_pred, y_sampled, paras)``; the first three are
        converted to NumPy arrays, ``paras`` is whatever
        ``gp_model.get_samples()`` returns.
    """
    fit_key, eval_key = gpax.utils.get_keys()

    # The model is fed JAX arrays, with targets as a flat vector.
    train_inputs = jnp.asarray(X_measured)
    train_targets = jnp.asarray(y_measured).ravel()

    model_options = dict(
        input_dim=3,
        kernel='Matern',
        noise_prior_dist=noise_prior_dist,
        lengthscale_prior_dist=lengthscale_prior_dist,
    )
    surrogate = gpax.viGP(**model_options)

    print('Training Model (this may take some time).')
    # Extra keyword arguments can be handed to .fit here if gpax supports them
    # (e.g. to shorten training) - check the gpax documentation.
    surrogate.fit(fit_key, train_inputs, train_targets, jitter=1e-2)

    print('Getting Model Predictions.')
    prediction = surrogate.predict_in_batches(
        eval_key,
        X_unmeasured,
        noiseless=False,
        jitter=1e-2
    )
    y_pred, y_sampled = prediction

    print('Calculating acquisition.')
    ei_options = dict(
        maximize=False,
        recent_points=train_inputs,
        noiseless=False,
        jitter=1e-2,
        penalty='delta',
    )
    acquisition = gpax.acquisition.EI(eval_key, surrogate, X_unmeasured, **ei_options)

    print('Getting parameter samples.')
    paras = surrogate.get_samples()

    # Hand plain NumPy arrays back to the caller for convenience.
    acquisition_np = np.asarray(acquisition)
    y_pred_np = np.asarray(y_pred)
    y_sampled_np = np.asarray(y_sampled)
    return acquisition_np, y_pred_np, y_sampled_np, paras

# %%
# --- Load Excel / Build seed points ---
# Resolve which workbook to read, then read it exactly once into `df`.
# Resolution order:
#   1. the configured file_path, if that file exists;
#   2. otherwise the first .xlsx (sorted by path) found in /content or the working directory;
#   3. otherwise an interactive Colab upload;
#   4. otherwise FileNotFoundError.
# (Previously the configured path was read eagerly before any fallback ran, so a
# missing file raised FileNotFoundError and the discovery/upload code was never reached.)

# Option A: Set path directly (if you already uploaded the file to /content)
file_path = '/content/LaVO3_PLD_Parameters_working_pressure_fixed_cleaned_up_no_10nm_no_LaVO4_to_B38.xlsx'

# Option B: Use files.upload() in Colab
if files is not None:
    print('If running in Colab: use files.upload() to upload your Excel file.\nOr set file_path directly in the script.')

# Candidate workbooks in /content and the working directory. The results workbook
# written at the end of this notebook and Excel lock files ("~$...") are skipped
# so that a re-run cannot mistake them for the input data.
content_files = sorted(
    candidate
    for candidate in list(Path('/content').glob('*.xlsx')) + list(Path('.').glob('*.xlsx'))
    if candidate.name != 'best_points_by_weights.xlsx' and not candidate.name.startswith('~$')
)

if not Path(file_path).is_file():
    if content_files:
        file_path = str(content_files[0])
    elif files is not None:
        uploaded = files.upload()
        if not uploaded:
            raise FileNotFoundError('No Excel file was uploaded.')
        # take the first uploaded file
        first = next(iter(uploaded.keys()))
        file_path = f'/content/{first}'
    else:
        raise FileNotFoundError('No .xlsx files found in working directory and files.upload() is unavailable.')

print('Using Excel file:', file_path)

# Read sheet (adjust sheet_name if necessary)
df = pd.read_excel(file_path, sheet_name=0)

# Extract the measured columns from `df` as float arrays.
# The keys must match the spreadsheet headers exactly; if a KeyError occurs, inspect df.columns.

# NOTE(review): colmap is not referenced anywhere in the visible code; it is kept
# unchanged in case other cells rely on it.
colmap = {
    'Sample ID': 'Sample ID',
    'Date (YYMMDD)': 'Date (YYMMDD)',
    'Fluence': 'Fluence (J/cm2)',
    'O2 Pressure': 'O2 Pressure (Torr)',
    'Working pressure': 'Working pressure (Torr)',
}

# Only a missing header (KeyError) gets the "column name mismatch" hint. Other
# failures, such as a non-numeric cell breaking astype(float), propagate with
# their own message instead of being mislabelled.
try:
    fluence = df['Fluence (J/cm2)'].values.astype(float)
    o2_pressure = df['O2 Pressure (Torr)'].values.astype(float)
    substrate_temp = df['Substrate Temp. (C) (based on average of temperature measured near substrate)'].values.astype(float)
    roughness = df['Roughness (nm) (standard deviation of height: 3x3 um, 2 Hz, 512 sam/l, B25 sub only = 0.147 )'].values.astype(float)
    c_lattice_parameter = df['C Lattice Parameter from Fringes (Å) (or 2Tw  peak position*)'].values.astype(float)
    fwhm_rocking_curve = df['FWHM (deg) (Gaussian fitting of rocking curve: B25: sub only, 1 twin=0.0050 )'].values.astype(float)
    integrated_intensity = df['Integrated LaVO4 Intensity: 48-51, normalized, x10^6 [B25 sub only= 4.7 (250520)]'].values.astype(float)
except KeyError:
    print('Column name mismatch: please inspect df.columns')
    print(df.columns)
    # Bare raise keeps the original traceback intact.
    raise

# Physical seed points, one row per sample: (O2 pressure, substrate temperature, fluence)
seed_points_raw = np.column_stack((o2_pressure, substrate_temp, fluence))

# Map every column into the grid's normalized coordinates (float32).
# Pressure goes through log10 first because the grid's pressure axis is log10(pressure).
raw_pressure, raw_temperature, raw_fluence = seed_points_raw.T
seed_points = np.column_stack((
    normalize(np.log10(raw_pressure), pressure_min, pressure_max),
    normalize(raw_temperature, T_min, T_max),
    normalize(raw_fluence, fluence_min, fluence_max),
)).astype(np.float32)

# Starting training inputs (normalized); same array object as seed_points
X_train_initial = seed_points

# NOTE: the targets (y_train) are rebuilt inside the weight loop for each (n,k)

# %%
# --- Main loop over weight combinations (n,k) ---
# For every (n, k) in {1,2,3} x {1,2,3}: build the weighted score, drop unusable
# rows, refit the GP from scratch via step_GP, then record the grid point with the
# lowest predicted score and the measured sample with the lowest actual score.
results = []

# Control flags
plot_each_run = False   # set True if you want to plot the GP results for each (n,k)
save_checkpoints = True # save a .npz checkpoint per run

# Loop-invariant quantities, computed once instead of on every pass.
lattice_parameter_mismatch = np.abs(c_lattice_parameter - LVO_c_lit)
# Rows whose normalized inputs are NaN/inf (e.g. log10 of a zero or missing
# pressure) must not reach the GP either, so they are masked along with bad targets.
inputs_finite = np.all(np.isfinite(X_train_initial), axis=1)

for n, k in product([1,2,3], repeat=2):
    print('\n' + '='*60)
    print(f'Running weights: n={n}, k={k} (m=l=1)')
    m = 1
    l = 1
    alpha = 6.0606 * n
    beta = 0.0876 * m
    gamma = 4.3879 * l
    delta = 0.0338 * k

    # Compute y components and full target (lower is better)
    y_1 = alpha * lattice_parameter_mismatch
    y_2 = beta * roughness
    y_3 = gamma * fwhm_rocking_curve
    y_4 = delta * integrated_intensity
    y_train_full = (y_1 + y_2 + y_3 + y_4).astype(float)

    # Handle NaNs and infs in either the targets or the inputs
    mask_good = np.isfinite(y_train_full) & inputs_finite
    n_bad = int(np.count_nonzero(~mask_good))
    if n_bad:
        print(f'Warning: {n_bad} rows with non-finite inputs or targets; dropping them from the training set.')
    if n_bad == mask_good.size:
        # Fail with a clear message instead of an opaque error further down.
        raise ValueError('No finite training rows remain after masking; check the spreadsheet columns.')

    X_train_masked = X_train_initial[mask_good]
    y_train_masked = y_train_full[mask_good]

    # step_GP converts its inputs to jax arrays itself
    acq, y_pred, y_sampled, paras = step_GP(X_train_masked, y_train_masked, X_test_jnp,
                                           noise_prior_dist=noise_prior_dist,
                                           lengthscale_prior_dist=lengthscale_prior_dist)

    # Find best predicted point (minimum since maximize=False)
    best_pred_idx = int(np.argmin(y_pred))
    best_pred_norm = points_3d[best_pred_idx]  # normalized coords
    best_pred_log10P = inverse_normalize(best_pred_norm[0], pressure_min, pressure_max)
    best_pred_T = inverse_normalize(best_pred_norm[1], T_min, T_max)
    best_pred_F = inverse_normalize(best_pred_norm[2], fluence_min, fluence_max)

    # Best measured (from the seed / measured dataset)
    meas_best_idx = int(np.argmin(y_train_masked))
    meas_best_norm = X_train_masked[meas_best_idx]
    meas_best_log10P = inverse_normalize(meas_best_norm[0], pressure_min, pressure_max)
    meas_best_T = inverse_normalize(meas_best_norm[1], T_min, T_max)
    meas_best_F = inverse_normalize(meas_best_norm[2], fluence_min, fluence_max)
    meas_best_y = float(y_train_masked[meas_best_idx])

    # Save checkpoint
    if save_checkpoints:
        ckpt_name = f'BO_ckpt_n{n}_k{k}.npz'
        save_data(
            ckpt_name,
            X_train_masked, y_train_masked, points_3d,
            x1=Pressure, x2=Temperature, x3=Fluence,
            y_pred=y_pred, y_sampled=y_sampled, acq=acq,
            next_point=None, running_best=meas_best_y, iteration=0
        )

    # Optional per-run plotting (disabled by default via plot_each_run)
    if plot_each_run:
        plot_3d_projections(y_pred, Pressure, Temperature, Fluence, title=f'GP Mean n{n}_k{k}', X_train=X_train_masked, y_train=y_train_masked)

    # Append to results (pressures are reported both as log10 and in Torr)
    results.append({
        'n': n,
        'k': k,
        'alpha': alpha,
        'beta': beta,
        'gamma': gamma,
        'delta': delta,
        'best_pred_score': float(np.min(y_pred)),
        'best_pred_log10P': float(best_pred_log10P),
        'best_pred_pressure_Torr': 10**(float(best_pred_log10P)),
        'best_pred_temperature_C': float(best_pred_T),
        'best_pred_fluence_Jpcm2': float(best_pred_F),
        'meas_best_score': meas_best_y,
        'meas_best_log10P': float(meas_best_log10P),
        'meas_best_pressure_Torr': 10**(float(meas_best_log10P)),
        'meas_best_temperature_C': float(meas_best_T),
        'meas_best_fluence_Jpcm2': float(meas_best_F),
    })

# %%
# --- Save and summarize results ---
# One row per (n,k) weight set, with columns equal to the keys of the dicts
# collected in `results`.
results_df = pd.DataFrame(results)
print('\nSummary:')
print(results_df)

# Save table to excel/csv
# Both files go to the current working directory and are overwritten on re-run.
results_df.to_csv('best_points_by_weights.csv', index=False)
results_df.to_excel('best_points_by_weights.xlsx', index=False)
print('Saved best_points_by_weights.csv/.xlsx to current directory')

# Histogram of best predicted scores
# Distribution of the lowest predicted score across the weight sets (8 bins).
plt.figure(figsize=(8,5))
plt.hist(results_df['best_pred_score'], bins=8, edgecolor='k')
plt.xlabel('Best Predicted Score')
plt.ylabel('Count')
plt.title('Histogram of Best Predicted Samples (over weight sets)')
plt.show()

# %%
# --- 3D Plotly of best predicted points (one marker per (n,k)) ---

# Hover label per weight set: weights, predicted score and the physical coordinates
hover_text = [
    (f"n={row['n']}, k={row['k']}<br>score={row['best_pred_score']:.4f}" +
     f"<br>P={row['best_pred_pressure_Torr']:.2e} Torr<br>T={row['best_pred_temperature_C']:.0f} C<br>F={row['best_pred_fluence_Jpcm2']:.2f} J/cm^2")
    for row in results
]

# Marker coordinates: x is log10(P), y is temperature, z is fluence
best_log10_pressures = [row['best_pred_log10P'] for row in results]
best_temperatures = [row['best_pred_temperature_C'] for row in results]
best_fluences = [row['best_pred_fluence_Jpcm2'] for row in results]
best_scores = [row['best_pred_score'] for row in results]
marker_labels = [f"n={row['n']},k={row['k']}" for row in results]

# Markers are colored by the predicted score
marker_style = dict(size=6, color=best_scores, colorscale='Viridis', colorbar=dict(title='Score'))
best_points_trace = go.Scatter3d(
    x=best_log10_pressures,
    y=best_temperatures,
    z=best_fluences,
    mode='markers+text',
    text=marker_labels,
    hovertext=hover_text,
    marker=marker_style,
)

fig = go.Figure(data=[best_points_trace])
axis_titles = dict(
    xaxis_title='log10(Pressure) (Torr)',
    yaxis_title='Temperature (C)',
    zaxis_title='Fluence (J/cm^2)'
)
fig.update_layout(
    scene=axis_titles,
    title='Best Predicted Points for Different Weight Sets'
)
fig.show()

print('Done. Check the CSV/XLSX for the full table of results.')


3D parameter space has size: (9758, 3)
Pressure points (log10): (41,)
Temperature points: (17,)
Fluence points: (14,)
Example back-transformed Pressure (Torr): [3.00000000e-08 3.77677624e-08 4.75467958e-08]


FileNotFoundError: [Errno 2] No such file or directory: '/content/LaVO3_PLD_Parameters_working_pressure_fixed_cleaned_up_no_10nm_no_LaVO4_to_B38.xlsx'

# GP-BO weighting loop, starting from scratch: varying the weights of all 4 metrics (n, l, m, k)

In [None]:
# Script that:
# - Loads Excel data
# - Builds the 3D grid (Pressure (log10), Temperature, Fluence)
# - Loops over the weight multipliers (per this cell's title: all four of n, l, m, k), computes y_train for each weight set
# - Retrains the gpax viGP for each weight set from scratch
# - Predicts across the full 3D grid, finds the best predicted point and best measured point
# - Saves a results table, histogram, and an interactive Plotly 3D scatter of best points
#
# Notes:
# - Run the cells in order in Colab. If you want to upload a file instead of setting file_path,
#   use the files.upload() flow provided.
# - This script intentionally does NOT plot the full GP maps for every (n,k) to save time.
#   You can enable per-run plotting by setting `plot_each_run = True`.
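# - Speed tips: reduce grid resolution (larger step sizes) or shorten each fit (if you expose
#   fit arguments; the model used below is gpax.viGP, a variational-inference GP, so the knob
#   is the number of optimization steps rather than HMC samples). Avoid running multiple GP
#   fits simultaneously (JAX + multiprocessing can deadlock).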

# %%
# Install (uncomment in Colab if needed)
# NOTE: gpax may be a private package in your environment. If gpax is installed locally, skip install.
# !pip install gpax plotly openpyxl numpyro

# %%
# Imports and global settings
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from itertools import product
from pathlib import Path

import jax
import jax.numpy as jnp
import gpax

# Enable 64-bit (same as your previous cell)
gpax.utils.enable_x64()

# Plotly for interactive 3D
import plotly.graph_objects as go

# Colab file upload helper
# `files` is the Colab upload helper when available and None everywhere else.
files = None
try:
    from google.colab import files
except Exception:
    # Not running in Colab (or the helper could not be imported): keep the None sentinel.
    files = None

# %%
# --- Helper functions (normalize / inverse normalize / save/load) ---

def normalize(data, min_val, max_val):
    """Min-max scale ``data`` so that ``min_val`` maps to 0 and ``max_val`` to 1.

    Works elementwise on anything supporting arithmetic (scalars, arrays).
    Values outside ``[min_val, max_val]`` are not clipped.
    """
    span = max_val - min_val
    shifted = data - min_val
    return shifted / span


def inverse_normalize(norm_data, min_val, max_val):
    """Undo :func:`normalize`: map 0 back to ``min_val`` and 1 back to ``max_val``.

    Works elementwise on anything supporting arithmetic (scalars, arrays).
    """
    span = max_val - min_val
    scaled = norm_data * span
    return scaled + min_val


def save_data(
    file_name,
    X_train, y_train,
    X_test,
    x1, x2, x3,
    y_pred, y_sampled,
    acq, next_point,
    running_best,
    iteration
):
    """Write one optimization checkpoint with ``np.savez``.

    Every argument after ``file_name`` is stored under a key equal to its
    parameter name; ``load_data`` reads the same keys back.

    Args:
        file_name: Target passed straight to ``np.savez`` (path or file object).
        X_train, y_train: Training inputs and targets.
        X_test: Candidate points the model was evaluated on.
        x1, x2, x3: The three grid axes.
        y_pred, y_sampled: Model predictions on ``X_test``.
        acq: Acquisition values on ``X_test``.
        next_point: Suggested next point (may be None).
        running_best: Best score so far.
        iteration: Iteration counter.
    """
    checkpoint = {
        'X_train': X_train,
        'y_train': y_train,
        'X_test': X_test,
        'x1': x1,
        'x2': x2,
        'x3': x3,
        'y_pred': y_pred,
        'y_sampled': y_sampled,
        'acq': acq,
        'next_point': next_point,
        'running_best': running_best,
        'iteration': iteration,
    }
    np.savez(file_name, **checkpoint)


def load_data(file_name):
    """Read back a checkpoint written by ``save_data``.

    Args:
        file_name: Path (or file object) of the ``.npz`` archive.

    Returns:
        A 12-tuple of arrays in this order: X_train, y_train, X_test, x1, x2,
        x3, y_pred, y_sampled, acq, next_point, running_best, iteration.

    Raises:
        KeyError: If the archive lacks one of the expected keys.
    """
    keys = (
        'X_train', 'y_train', 'X_test',
        'x1', 'x2', 'x3',
        'y_pred', 'y_sampled',
        'acq', 'next_point',
        'running_best', 'iteration',
    )
    # SECURITY NOTE: allow_pickle=True is needed for object arrays (e.g. a
    # checkpoint saved with next_point=None) but unpickling can run arbitrary
    # code - only load checkpoints you created yourself.
    # The context manager closes the underlying archive; every array is read
    # into memory inside the block, so the returned values stay valid.
    with np.load(file_name, allow_pickle=True) as ds:
        return tuple(ds[key] for key in keys)


def update_datapoints(X_new, y_new, X_train, y_train):
    """Return the training set extended by the new observations.

    ``X_new`` / ``y_new`` are appended after the existing rows along axis 0;
    the results are JAX arrays and the inputs are left untouched.
    """
    grown_inputs = jnp.append(X_train, X_new, axis=0)
    grown_targets = jnp.append(y_train, y_new, axis=0)
    return grown_inputs, grown_targets

# Plotting routine (matplotlib projections of a 3D field onto the three axis pairs).
# Note: this function reads the module-level normalization bounds pressure_min, pressure_max,
# T_min, T_max, fluence_min, fluence_max, which are set in the grid build cell below.
# The weights printed in the figure title come from the `weights` argument or, when it is
# omitted, from the module-level alpha, beta, gamma, delta set by the weight loop.

def plot_3d_projections(
    field_flat,
    P_array,
    T_array,
    F_array,
    title: str,
    X_train: np.ndarray = None,
    y_train: np.ndarray = None,
    deg_sgn: str = '°',
    weights=None,
):
    """Show the three 2D projections of a field sampled on the (P, T, F) grid.

    The flat field is reshaped to ``(len(P_array), len(T_array), len(F_array))``
    and averaged over the remaining axis for each of the panels T-vs-P, F-vs-P
    and F-vs-T. If both ``X_train`` and ``y_train`` are given, the training
    points are overlaid, colored by score, with their own colorbar.

    Args:
        field_flat: Field values, one per grid point, flattened with P as the
            slowest and F as the fastest index.
        P_array: Pressure axis (plotted as-is; labelled as log10 Torr).
        T_array: Temperature axis.
        F_array: Fluence axis.
        title: Prefix for the panel titles and the figure title.
        X_train: Normalized training inputs with columns (P, T, F); mapped
            back to axis units with the module-level bounds before plotting.
        y_train: Training scores used to color the overlaid points.
        deg_sgn: Degree symbol used in the temperature labels.
        weights: Optional ``(alpha, beta, gamma, delta)`` shown in the figure
            title. When None, the module-level variables of those names are
            used if all four exist; otherwise the weights line is omitted
            (previously this case raised NameError after drawing the figure).

    Raises:
        ValueError: If ``field_flat`` does not have one value per grid point.
    """
    field = np.asarray(field_flat)
    nP, nT, nF = len(P_array), len(T_array), len(F_array)
    if field.size != nP * nT * nF:
        raise ValueError(f"field length {field.size} != {nP}*{nT}*{nF}")

    # Resolve the title weights up front so a missing global cannot abort the
    # call after all the drawing work has been done.
    if weights is None:
        module_scope = globals()
        weight_names = ('alpha', 'beta', 'gamma', 'delta')
        if all(name in module_scope for name in weight_names):
            weights = tuple(module_scope[name] for name in weight_names)

    field3d = field.reshape(nP, nT, nF)
    proj_PT = field3d.mean(axis=2)  # average over fluence
    proj_PF = field3d.mean(axis=1)  # average over temperature
    proj_TF = field3d.mean(axis=0)  # average over pressure

    fig, axs = plt.subplots(1, 3, figsize=(18, 6), dpi=200, constrained_layout=True)
    for ax in axs:
        ax.tick_params(axis='both', labelsize=12)
    font_title = 20
    font_label = 18
    font_tick = 14

    # pcolormesh expects C with shape (len(y), len(x)), hence the transposes.
    im0 = axs[0].pcolormesh(P_array, T_array, proj_PT.T, shading='auto', cmap='nipy_spectral')
    axs[0].set_xlabel('log$_{10}$ O$_2$ Partial Pressure (log$_{10}$ Torr)', fontsize=font_label)
    axs[0].set_ylabel(f'Deposition Temperature ({deg_sgn}C)', fontsize=font_label)
    axs[0].set_title(f'{title} → T vs P', fontsize=font_title)
    fig.colorbar(im0, ax=axs[0], orientation='vertical')

    im1 = axs[1].pcolormesh(P_array, F_array, proj_PF.T, shading='auto', cmap='nipy_spectral')
    axs[1].set_xlabel('log$_{10}$ O$_2$ Partial Pressure (log$_{10}$ Torr)', fontsize=font_label)
    axs[1].set_ylabel('Fluence (J/cm$^2$)', fontsize=font_label)
    axs[1].set_title(f'{title} → F vs P', fontsize=font_title)
    fig.colorbar(im1, ax=axs[1], orientation='vertical')

    im2 = axs[2].pcolormesh(T_array, F_array, proj_TF.T, shading='auto', cmap='nipy_spectral')
    axs[2].set_xlabel(f'Deposition Temperature ({deg_sgn}C)', fontsize=font_label)
    axs[2].set_ylabel('Fluence (J/cm$^2$)', fontsize=font_label)
    axs[2].set_title(f'{title} → F vs T', fontsize=font_title)
    fig.colorbar(im2, ax=axs[2], orientation='vertical')

    if X_train is not None and y_train is not None:
        # Training inputs are stored normalized; convert back to axis units.
        P_train = inverse_normalize(X_train[:,0], pressure_min, pressure_max)
        T_train = inverse_normalize(X_train[:,1], T_min, T_max)
        F_train = inverse_normalize(X_train[:,2], fluence_min, fluence_max)
        y_vals  = y_train.flatten()

        scatter_kwargs = dict(c=y_vals, cmap='Reds', edgecolor='k', s=200, alpha=0.8)

        # All three overlays share the same color mapping, so only the first
        # handle is needed for the colorbar.
        sc0 = axs[0].scatter(P_train, T_train, **scatter_kwargs)
        axs[1].scatter(P_train, F_train, **scatter_kwargs)
        axs[2].scatter(T_train, F_train, **scatter_kwargs)

        from mpl_toolkits.axes_grid1.inset_locator import inset_axes
        cbar_ax = inset_axes(
            axs[2],
            width="60%",
            height="5%",
            loc='upper left',
            bbox_to_anchor=(0.55, 0.3, 1, 1),
            bbox_transform=axs[2].transAxes,
            borderpad=0,
        )
        cbar = fig.colorbar(sc0, cax=cbar_ax, orientation='horizontal')
        cbar.set_label('Sample\nScores', fontsize=font_label)
        cbar.ax.tick_params(labelsize=font_tick)

    title_lines = [
        f"{title}: Min-Max Scaling",
        r"$y_{\mathrm{train}} = \alpha \cdot |c - 3.945| + \beta \cdot R_{RMS} + \gamma \cdot FWHM + \delta \cdot LaVO_{4}$",
    ]
    if weights is not None:
        w_alpha, w_beta, w_gamma, w_delta = weights
        title_lines.append(f"[α = {w_alpha}, β = {w_beta}, γ = {w_gamma}, δ = {w_delta}]")
    fig.suptitle(
        "\n".join(title_lines),
        fontsize=16,
        fontweight='bold',
        y=1.17
    )
    plt.show()

# %%
# --- Build the 3D grid (pressures are represented as log10(pressure)) ---
# This cell defines the module-level axes (Pressure, Temperature, Fluence), the
# normalization bounds (pressure_min/max, T_min/max, fluence_min/max) used by the
# plotting and loop code, and the flattened candidate grid points_3d / X_test_jnp.

# Pressure bounds
start, stop = 3e-8, 3e-4
num_points_per_decade = 10
# NOTE(review): int() truncates, so a log10 ratio that lands just below a whole
# number would lose a decade. With these bounds the intent is 4 decades (41 points).
total_decades = int(np.log10(stop / start))
total_points = total_decades * num_points_per_decade + 1

# Pressure array stored as log10(pressure)
Pressure = np.log10(np.geomspace(start, stop, num=total_points))
pressure_min, pressure_max = Pressure.min(), Pressure.max()

# Temperature
# NOTE(review): np.arange excludes the stop value, so the axis tops out below
# T_max while normalization still uses T_max = 835. Confirm this is intended.
T_min, T_max, T_stepsize = 500, 835, 20
Temperature = np.arange(T_min, T_max, T_stepsize, dtype=np.float32)

# Fluence
# NOTE(review): same exclusion here - fluence_max = 2.2 itself is not expected
# to be a grid candidate. Confirm this is intended.
fluence_min, fluence_max, fluence_stepsize = 0.8, 2.2, 0.1
Fluence = np.arange(fluence_min, fluence_max, fluence_stepsize, dtype=np.float32)

# Normalized grids
P_norm = normalize(Pressure, pressure_min, pressure_max)
T_norm = normalize(Temperature, T_min, T_max)
F_norm = normalize(Fluence, fluence_min, fluence_max)

# Meshgrid in normalized space
# indexing='ij' keeps the axis order (P, T, F), so flattening makes P the slowest
# and F the fastest index - the layout plot_3d_projections reshapes back to.
Pg, Tg, Fg = np.meshgrid(P_norm, T_norm, F_norm, indexing='ij')
P_flat = Pg.reshape(-1, 1)
T_flat = Tg.reshape(-1, 1)
F_flat = Fg.reshape(-1, 1)
# One row per grid point, columns (P, T, F) in normalized units.
points_3d = np.hstack((P_flat, T_flat, F_flat))

print('3D parameter space has size:', points_3d.shape)
print('Pressure points (log10):', Pressure.shape)
print('Temperature points:', Temperature.shape)
print('Fluence points:', Fluence.shape)
print('Example back-transformed Pressure (Torr):', 10**(Pressure[:3]))

# Convert X_test to jax array once for faster device transfer
X_test_jnp = jnp.asarray(points_3d)

# %%
# --- GP step function (wraps gpax.viGP) ---
import numpyro
from numpyro import distributions
from jax import random

# Optionally define priors here (you can tune these or set to None)
# Both values are forwarded unchanged to gpax.viGP through step_GP.
# Presumably None makes gpax fall back to its own default priors - confirm in the gpax docs.
lengthscale_prior_dist = None
noise_prior_dist = None


def step_GP(X_measured, y_measured, X_unmeasured,
            noise_prior_dist=None,
            lengthscale_prior_dist=None):
    """Fit a gpax.viGP model and evaluate it on the unmeasured points.

    Args:
        X_measured: Training inputs; converted with ``jnp.asarray``.
        y_measured: Training targets; converted with ``jnp.asarray`` and
            flattened to 1D.
        X_unmeasured: Points passed to the prediction and acquisition calls.
        noise_prior_dist: Forwarded unchanged to ``gpax.viGP``.
        lengthscale_prior_dist: Forwarded unchanged to ``gpax.viGP``.

    Returns:
        Tuple ``(acquisition, y_pred, y_sampled, paras)``. The first three
        are converted to NumPy arrays; ``paras`` is whatever
        ``gp_model.get_samples()`` returns.
    """
    fit_key, eval_key = gpax.utils.get_keys()

    # Move the training data onto jax arrays; targets must be flat.
    train_inputs = jnp.asarray(X_measured)
    train_targets = jnp.ravel(jnp.asarray(y_measured))

    surrogate = gpax.viGP(
        input_dim=3,
        kernel='Matern',
        noise_prior_dist=noise_prior_dist,
        lengthscale_prior_dist=lengthscale_prior_dist,
    )

    print('Training Model (this may take some time).')
    # Extra keyword arguments accepted by .fit could be added to this call.
    surrogate.fit(fit_key, train_inputs, train_targets, jitter=1e-2)

    # Prediction and acquisition use the same noise/jitter settings.
    eval_settings = {'noiseless': False, 'jitter': 1e-2}

    print('Getting Model Predictions.')
    posterior_mean, posterior_draws = surrogate.predict_in_batches(
        eval_key, X_unmeasured, **eval_settings
    )

    print('Calculating acquisition.')
    # NOTE: the key used for prediction above is reused here.
    # maximize=False: the acquisition targets low objective values.
    ei_values = gpax.acquisition.EI(
        eval_key, surrogate, X_unmeasured,
        maximize=False,
        recent_points=train_inputs,
        penalty='delta',
        **eval_settings,
    )

    print('Getting parameter samples.')
    learned_params = surrogate.get_samples()

    # Hand NumPy arrays back to the caller for convenience.
    return (
        np.asarray(ei_values),
        np.asarray(posterior_mean),
        np.asarray(posterior_draws),
        learned_params,
    )

# %%
# --- Load Excel / Build seed points ---

# Option A: preferred workbook location. It is used as-is when the file
# exists; it is NOT read here, so a missing file falls through to the
# discovery/upload logic below instead of raising immediately.
file_path = '/content/LaVO3_PLD_Parameters_working_pressure_fixed_cleaned_up_no_10nm_no_LaVO4_to_B38.xlsx'

# Option B: Use files.upload() in Colab
if files is not None:
    print('If running in Colab: use files.upload() to upload your Excel file.\nOr set file_path directly in the script.')

# Candidate workbooks in /content and the working directory. The results
# workbook written at the end of this script is excluded so a re-run never
# mistakes its own output for the input data; sorting makes the pick
# deterministic.
content_files = sorted(
    p for p in list(Path('/content').glob('*.xlsx')) + list(Path('.').glob('*.xlsx'))
    if p.name != 'best_points_by_weights.xlsx'
)

if not Path(file_path).is_file():
    if content_files:
        file_path = str(content_files[0])
    elif files is not None:
        uploaded = files.upload()
        if not uploaded:
            raise FileNotFoundError('No .xlsx file was uploaded.')
        # take the first uploaded file
        first = next(iter(uploaded.keys()))
        file_path = f'/content/{first}'
    else:
        raise FileNotFoundError('No .xlsx files found in working directory and files.upload() is unavailable.')

print('Using Excel file:', file_path)

# Read the first sheet once (adjust sheet_name if necessary)
df = pd.read_excel(file_path, sheet_name=0)

# Short-name -> column-header reference table. It is not consulted by the
# extraction below, which uses the full header strings directly.
colmap = {
    'Sample ID': 'Sample ID',
    'Date (YYMMDD)': 'Date (YYMMDD)',
    'Fluence': 'Fluence (J/cm2)',
    'O2 Pressure': 'O2 Pressure (Torr)',
    'Working pressure': 'Working pressure (Torr)',
}

# Pull the process parameters and film metrics out of the sheet as float
# arrays. Only a missing column (KeyError) triggers the "column mismatch"
# hint; other failures (e.g. non-numeric cells) propagate unchanged so the
# hint is not misleading.
try:
    fluence = df['Fluence (J/cm2)'].values.astype(float)
    o2_pressure = df['O2 Pressure (Torr)'].values.astype(float)
    substrate_temp = df['Substrate Temp. (C) (based on average of temperature measured near substrate)'].values.astype(float)
    roughness = df['Roughness (nm) (standard deviation of height: 3x3 um, 2 Hz, 512 sam/l, B25 sub only = 0.147 )'].values.astype(float)
    c_lattice_parameter = df['C Lattice Parameter from Fringes (Å) (or 2Tw  peak position*)'].values.astype(float)
    fwhm_rocking_curve = df['FWHM (deg) (Gaussian fitting of rocking curve: B25: sub only, 1 twin=0.0050 )'].values.astype(float)
    integrated_intensity = df['Integrated LaVO4 Intensity: 48-51, normalized, x10^6 [B25 sub only= 4.7 (250520)]'].values.astype(float)
except KeyError:
    print('Column name mismatch: please inspect df.columns')
    print(df.columns)
    raise

# Build raw seed points array (physical values), columns ordered
# (pressure, temperature, fluence) to match the grid axes.
seed_points_raw = np.column_stack((o2_pressure, substrate_temp, fluence))

# Normalize seed points with the same bounds as the grid; pressure goes
# through log10 first because the grid stores log10(pressure).
seed_points = np.zeros_like(seed_points_raw, dtype=np.float32)
seed_points[:, 0] = normalize(np.log10(seed_points_raw[:, 0]), pressure_min, pressure_max)
seed_points[:, 1] = normalize(seed_points_raw[:, 1], T_min, T_max)
seed_points[:, 2] = normalize(seed_points_raw[:, 2], fluence_min, fluence_max)

# initial X_train (normalized) and placeholders
X_train_initial = seed_points

# NOTE: y_train is recalculated inside the main loop for each weight set

# NOTE: y_train will be recalculated inside the loop for each (n,k)

# %%
# --- Main loop over weight combinations (n,m,l,k) ---
# For every weight set: build the scalar objective from the four film
# metrics, fit the GP surrogate on the finite rows, and record both the
# GP-predicted minimum on the grid and the best measured sample.
results = []

# Control flags
plot_each_run = False   # set True if you want to plot the GP results for each weight set
save_checkpoints = True # save a .npz checkpoint per run

# The lattice mismatch does not depend on the weights, so compute it once.
lattice_parameter_mismatch = np.abs(c_lattice_parameter - LVO_c_lit)

# Rows whose normalized inputs are NaN/inf (e.g. blank cells, or a zero
# pressure turned into -inf by log10) must not reach the GP: they would
# poison the fit even when the objective value itself is finite.
x_finite = np.all(np.isfinite(X_train_initial), axis=1)
if not np.all(x_finite):
    print(f'Warning: {np.count_nonzero(~x_finite)} seed rows have non-finite inputs; excluding them from training.')

# --- Loop over n, m, l, k ---
for n, m, l, k in product([0,1,2,3], repeat=4):
    print('\n' + '='*60)
    print(f'Running weights: n={n}, m={m}, l={l}, k={k}')
    # Integer multipliers scale the fixed base coefficient of each metric.
    alpha = 6.0606 * n
    beta = 0.0876 * m
    gamma = 4.3879 * l
    delta = 0.0338 * k

    # Weighted sum of the four metrics; lower is better.
    y_1 = alpha * lattice_parameter_mismatch
    y_2 = beta * roughness
    y_3 = gamma * fwhm_rocking_curve
    y_4 = delta * integrated_intensity
    y_train_full = (y_1 + y_2 + y_3 + y_4).astype(float)

    y_finite = np.isfinite(y_train_full)
    if not np.all(y_finite):
        print(f'Warning: {np.count_nonzero(~y_finite)} bad rows in y_train; dropping them.')

    # Keep only rows where both the objective and the inputs are finite.
    mask_good = y_finite & x_finite

    X_train_masked = X_train_initial[mask_good]
    y_train_masked = y_train_full[mask_good]

    print(f"Number of samples used to train GP surrogate model for weights n={n}, m={m}, l={l}, k={k}: {X_train_masked.shape[0]}")


    acq, y_pred, y_sampled, paras = step_GP(
        X_train_masked, y_train_masked, X_test_jnp,
        noise_prior_dist=noise_prior_dist,
        lengthscale_prior_dist=lengthscale_prior_dist
    )

    # Grid node with the lowest predicted objective, mapped back to
    # physical units (pressure stays as log10 here).
    best_pred_idx = int(np.argmin(y_pred))
    best_pred_norm = points_3d[best_pred_idx]
    best_pred_log10P = inverse_normalize(best_pred_norm[0], pressure_min, pressure_max)
    best_pred_T = inverse_normalize(best_pred_norm[1], T_min, T_max)
    best_pred_F = inverse_normalize(best_pred_norm[2], fluence_min, fluence_max)

    # Measured sample with the lowest objective, for comparison.
    meas_best_idx = int(np.argmin(y_train_masked))
    meas_best_norm = X_train_masked[meas_best_idx]
    meas_best_log10P = inverse_normalize(meas_best_norm[0], pressure_min, pressure_max)
    meas_best_T = inverse_normalize(meas_best_norm[1], T_min, T_max)
    meas_best_F = inverse_normalize(meas_best_norm[2], fluence_min, fluence_max)
    meas_best_y = float(y_train_masked[meas_best_idx])

    if save_checkpoints:
        ckpt_name = f'BO_ckpt_n{n}_m{m}_l{l}_k{k}.npz'
        save_data(
            ckpt_name,
            X_train_masked, y_train_masked, points_3d,
            x1=Pressure, x2=Temperature, x3=Fluence,
            y_pred=y_pred, y_sampled=y_sampled, acq=acq,
            next_point=None, running_best=meas_best_y, iteration=0
        )

    if plot_each_run:
        plot_3d_projections(
            y_pred, Pressure, Temperature, Fluence,
            title=f'GP Mean n{n}_m{m}_l{l}_k{k}',
            X_train=X_train_masked, y_train=y_train_masked
        )

    results.append({
        'n': n,
        'm': m,
        'l': l,
        'k': k,
        'alpha': alpha,
        'beta': beta,
        'gamma': gamma,
        'delta': delta,
        'best_pred_score': float(np.min(y_pred)),
        'best_pred_log10P': float(best_pred_log10P),
        'best_pred_pressure_Torr': 10**(float(best_pred_log10P)),
        'best_pred_temperature_C': float(best_pred_T),
        'best_pred_fluence_Jpcm2': float(best_pred_F),
        'meas_best_score': meas_best_y,
        'meas_best_log10P': float(meas_best_log10P),
        'meas_best_pressure_Torr': 10**(float(meas_best_log10P)),
        'meas_best_temperature_C': float(meas_best_T),
        'meas_best_fluence_Jpcm2': float(meas_best_F),
    })

# After loop: one row per weight set, written as both CSV and XLSX
results_df = pd.DataFrame(results)
print('\nSummary:')
print(results_df)
results_df.to_csv('best_points_by_weights.csv', index=False)
results_df.to_excel('best_points_by_weights.xlsx', index=False)
print('Saved best_points_by_weights.csv/.xlsx to current directory')

# --- Histogram ---
# Count how many weight sets put their best predicted point in each cell of
# a coarse 3D binning of the parameter space, and plot the occupied cells.

# Number of bins per dimension
n_bins = 7

# Define bin edges for each parameter (pressure edges are in log10 units)
pressure_bins = np.linspace(pressure_min, pressure_max, n_bins + 1)
temperature_bins = np.linspace(T_min, T_max, n_bins + 1)
fluence_bins = np.linspace(fluence_min, fluence_max, n_bins + 1)

best_P = np.array([r['best_pred_log10P'] for r in results])
best_T = np.array([r['best_pred_temperature_C'] for r in results])
best_F = np.array([r['best_pred_fluence_Jpcm2'] for r in results])

# Clip to the domain bounds first: the values come from a normalize /
# inverse-normalize round trip, so a point on the domain edge can overshoot
# the outermost bin edge by floating-point round-off.
best_points = np.column_stack((
    np.clip(best_P, pressure_min, pressure_max),
    np.clip(best_T, T_min, T_max),
    np.clip(best_F, fluence_min, fluence_max),
))

# np.histogramdd closes the last bin on the right, so optima sitting exactly
# on the upper boundary of the domain are counted instead of being dropped
# (strictly half-open bins would lose them).
counts_3d, _ = np.histogramdd(
    best_points,
    bins=(pressure_bins, temperature_bins, fluence_bins),
)

bin_labels = []
bin_counts = []

# Gather only bins with points; argwhere yields indices in row-major order
# (pressure slowest, fluence fastest).
for p_idx, t_idx, f_idx in np.argwhere(counts_3d > 0):
    p_min, p_max = pressure_bins[p_idx], pressure_bins[p_idx + 1]
    t_min, t_max = temperature_bins[t_idx], temperature_bins[t_idx + 1]
    f_min, f_max = fluence_bins[f_idx], fluence_bins[f_idx + 1]

    bin_counts.append(int(counts_3d[p_idx, t_idx, f_idx]))
    label = (
        f"P: {10**p_min:.1e}-{10**p_max:.1e} Torr\n"
        f"T: {t_min:.0f}-{t_max:.0f} C\n"
        f"F: {f_min:.2f}-{f_max:.2f} J/cm²"
    )
    bin_labels.append(label)

# Plot only bins with points
plt.figure(figsize=(14, 6), dpi=200)
bars = plt.bar(range(len(bin_counts)), bin_counts, edgecolor='k')
plt.xticks(range(len(bin_counts)), bin_labels, rotation=75, fontsize=9)
plt.ylabel('Count of Best Predicted Points')
plt.title(f'Histogram of Best Predicted Points Binned by 3D Parameter Space ({n_bins} bins per dimension)')

# Add count labels above each bar
for bar in bars:
    height = bar.get_height()
    plt.text(
        bar.get_x() + bar.get_width() / 2,  # x-position: center of the bar
        height,                             # y-position: top of the bar
        f'{int(height)}',                   # label text
        ha='center', va='bottom', fontsize=10
    )

plt.tight_layout()
plt.show()



# Interactive 3D scatter of the best predicted point for every weight set.
# Each marker is colored by its predicted score and its hover box lists the
# (n,m,l,k) weights together with the physical parameters.
hover_text = [
    (f"n={r['n']}, m={r['m']}, l={r['l']}, k={r['k']}<br>"
     f"score={r['best_pred_score']:.4f}<br>"
     f"P={r['best_pred_pressure_Torr']:.2e} Torr<br>"
     f"T={r['best_pred_temperature_C']:.0f} C<br>"
     f"F={r['best_pred_fluence_Jpcm2']:.2f} J/cm^2")
    for r in results
]

# Marker appearance: color encodes the predicted score.
score_markers = {
    'size': 10,
    'color': [r['best_pred_score'] for r in results],
    'colorscale': 'Viridis',
    'colorbar': {'title': 'Score'},
}

best_points_trace = go.Scatter3d(
    x=[r['best_pred_log10P'] for r in results],  # log10(P)
    y=[r['best_pred_temperature_C'] for r in results],
    z=[r['best_pred_fluence_Jpcm2'] for r in results],
    mode='markers',
    hovertext=hover_text,
    marker=score_markers,
)

fig = go.Figure(data=[best_points_trace])
fig.update_layout(
    scene={
        'xaxis_title': 'log10(Pressure) (Torr)',
        'yaxis_title': 'Temperature (C)',
        'zaxis_title': 'Fluence (J/cm^2)',
    },
    title='Best Predicted Points for Different Weight Sets (n,m,l,k)',
    width=1100,   # wider plot
    height=700,   # taller plot
)
fig.show()

print('Done. Check the CSV/XLSX for the full table of results.')

