## **Scalable Constrained Bayesian Optimization (SCBO) of Plasma boundaries for Stellarator Design**

### Import all necessary libraries and dependencies

In [1]:
import math
import os
import warnings
from dataclasses import dataclass

import gpytorch
import torch
from torch import Tensor
from torch.quasirandom import SobolEngine

from gpytorch.constraints import Interval
from gpytorch.kernels import MaternKernel, ScaleKernel
from gpytorch.likelihoods import GaussianLikelihood
from gpytorch.mlls import ExactMarginalLogLikelihood

from botorch.models import SingleTaskGP
from botorch.models.transforms.input import Normalize
from botorch.models.transforms.outcome import Standardize
from botorch.fit import fit_gpytorch_mll

# Constrained Max Posterior Sampling is a new sampling class, similar to MaxPosteriorSampling,
# which implements the constrained version of Thompson Sampling described in [1].
from botorch.generation.sampling import ConstrainedMaxPosteriorSampling
from botorch.models.model_list_gp_regression import ModelListGP
from botorch.utils.transforms import normalize, unnormalize

# Silence every warning category for the rest of the notebook.
warnings.filterwarnings("ignore")

# Tensor placement/precision shared by the cells below: CUDA when available,
# otherwise CPU, always in double precision.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dtype = torch.double
tkwargs = {"device": device, "dtype": dtype}

# Truthy when the SMOKE_TEST environment variable is set; used further down to
# shrink the number of Thompson-sampling candidates.
SMOKE_TEST = os.environ.get("SMOKE_TEST")

  from .autonotebook import tqdm as notebook_tqdm


### Add normalization transform on the fourier coefficients

In [2]:
def fit_minmax(r_cos_data, z_sin_data, margin=0.05):
    """Compute padded per-column min/max bounds for both coefficient sets.

    Each column's observed range is widened on both sides by ``margin`` times
    that range. A column that is constant in the data has zero range, so its
    lower and upper bound come out identical.

    Args:
        r_cos_data: 2-D tensor, one row per sample, of r_cos coefficients.
        z_sin_data: 2-D tensor, one row per sample, of z_sin coefficients.
        margin: Fraction of each column's range added below the minimum and
            above the maximum.

    Returns:
        Tensor with two rows: lower bounds on the first row, upper bounds on
        the second, with the r_cos columns followed by the z_sin columns.
    """

    def _padded_limits(samples):
        # Column-wise extrema, pushed outwards by the requested margin.
        lowest = samples.min(dim=0).values
        highest = samples.max(dim=0).values
        pad = margin * (highest - lowest)
        return lowest - pad, highest + pad

    r_lower, r_upper = _padded_limits(r_cos_data)
    z_lower, z_upper = _padded_limits(z_sin_data)

    lower = torch.cat([r_lower, z_lower])
    upper = torch.cat([r_upper, z_upper])
    return torch.vstack([lower, upper])


def transform(x, bounds):
    """Min-max normalize ``x`` with respect to ``bounds``.

    Thin wrapper around BoTorch's ``normalize``; ``bounds`` holds the lower
    bounds on its first row and the upper bounds on its second row.
    """
    unit_x = normalize(x, bounds=bounds)
    return unit_x


def untransform(x, bounds):
    """Map ``x`` from the unit cube back to the box described by ``bounds``.

    Thin wrapper around BoTorch's ``unnormalize``; the inverse of ``transform``
    for the same ``bounds``.
    """
    physical_x = unnormalize(x, bounds=bounds)
    return physical_x

### Pre-process data and fit the normalization transform

In [3]:
n_coefficients = 45

# Read the stored batch in double precision and drop every row that contains
# at least one NaN entry.
# NOTE(review): torch.load unpickles the file, so only load data from a trusted source.
data_tensor = torch.load("../data/batch_1.pt").to(dtype=dtype)
mask = data_tensor.isnan().any(dim=1)
filtered_tensor = data_tensor[mask.logical_not()]

# The first n_coefficients columns are taken as r_cos, the next n_coefficients as z_sin.
r_cos_data, z_sin_data = (
    filtered_tensor[:, :n_coefficients],
    filtered_tensor[:, n_coefficients : 2 * n_coefficients],
)

# Padded min/max bounds used by transform/untransform.
bounds = fit_minmax(r_cos_data, z_sin_data)

### Define the evaluator class to evaluate objectives and constraints using the forward model

In [4]:
from constellaration.forward_model import forward_model
from constellaration.geometry.surface_rz_fourier import SurfaceRZFourier


class Evaluator:
    """
    Evaluate a plasma boundary with the forward model and expose the objective
    and the three constraint values derived from its metrics.

    Results are memoized per input vector, so the objective and all constraints
    of the same point share a single forward-model run. Constraint methods
    return values that are ``<= 0`` when the constraint is satisfied.
    """

    def __init__(self, bounds):
        """Store the mode layout, the constraint thresholds and ``bounds``.

        Args:
            bounds: Normalization bounds; stored on the instance but not used
                by any method of this class.
        """
        self.n_poloidal_modes = 5
        self.n_toroidal_modes = 9
        self.n_coeffecients = self.n_poloidal_modes * self.n_toroidal_modes
        self.aspect_ratio_upper_bound = 4.0
        self.average_triangularity_upper_bound = -0.5
        self.edge_rotational_transform_over_n_field_periods_lower_bound = 0.3
        self.bounds = bounds
        self._cache = {}

    @staticmethod
    def _cache_key(x):
        """Return a hashable key built from the *values* of ``x``.

        ``tuple(tensor)`` yields 0-dim tensors, which hash by object identity,
        so two equal vectors would never share a cache entry. Converting to
        Python floats makes the key value-based.
        """
        values = x.tolist() if hasattr(x, "tolist") else x
        return tuple(float(v) for v in values)

    @staticmethod
    def _to_numpy(t):
        """Convert a tensor to a NumPy array regardless of device or autograd state."""
        return t.detach().cpu().numpy()

    def evaluate(self, x):
        """Run the forward model for ``x`` (once) and return its metrics.

        Args:
            x: 1-D tensor whose first ``n_coeffecients`` entries are reshaped to
                the r_cos matrix and whose next ``n_coeffecients`` entries are
                reshaped to the z_sin matrix. The values are used as given:
                neither ``untransform`` nor ``enforce_constraints`` is applied.

        Returns:
            Dict with the keys ``max_elongation``, ``aspect_ratio``,
            ``average_triangularity`` and
            ``edge_rotational_transform_over_n_field_periods``.
        """
        key = self._cache_key(x)
        if key not in self._cache:
            mode_shape = (self.n_poloidal_modes, self.n_toroidal_modes)

            # Split the flat vector into the two Fourier coefficient matrices
            r_cos = x[: self.n_coeffecients].reshape(mode_shape)
            z_sin = x[self.n_coeffecients : 2 * self.n_coeffecients].reshape(
                mode_shape
            )

            # Instantiate the boundary surface with the Fourier coefficients
            boundary_surface = SurfaceRZFourier(
                r_cos=self._to_numpy(r_cos), z_sin=self._to_numpy(z_sin)
            )

            # Evaluate the forward model with the boundary surface and retrieve results
            metrics, _ = forward_model(boundary=boundary_surface)
            self._cache[key] = {
                "max_elongation": metrics.max_elongation,
                "aspect_ratio": metrics.aspect_ratio,
                "average_triangularity": metrics.average_triangularity,
                "edge_rotational_transform_over_n_field_periods": metrics.edge_rotational_transform_over_n_field_periods,
            }

        return self._cache[key]

    def get_objective(self, x):
        """Return ``max_elongation`` for ``x``."""
        # evaluate() already returns the metrics dict for x; it must not be
        # indexed by the cache key a second time.
        return self.evaluate(x)["max_elongation"]

    def aspect_ratio_constraint(self, x):
        """Return ``aspect_ratio - upper bound`` (``<= 0`` means satisfied)."""
        # Goes through evaluate() so the result is cached and call order is irrelevant
        return self.evaluate(x)["aspect_ratio"] - self.aspect_ratio_upper_bound

    def average_triangularity_constraint(self, x):
        """Return ``average_triangularity - upper bound`` (``<= 0`` means satisfied)."""
        return (
            self.evaluate(x)["average_triangularity"]
            - self.average_triangularity_upper_bound
        )

    def edge_rotational_transform_over_n_field_periods_constraint(self, x):
        """Return ``lower bound - edge rotational transform`` (``<= 0`` means satisfied)."""
        return (
            self.edge_rotational_transform_over_n_field_periods_lower_bound
            - self.evaluate(x)["edge_rotational_transform_over_n_field_periods"]
        )

    def enforce_constraints(self, x):
        """Return a copy of ``x`` with the stellarator-symmetry entries overwritten.

        The input tensor is left unmodified; the coefficient matrices are cloned
        before they are edited.
        """
        n = self.n_coeffecients
        mode_shape = (self.n_poloidal_modes, self.n_toroidal_modes)
        r_cos = x[:n].reshape(mode_shape).clone()
        z_sin = x[n : 2 * n].reshape(mode_shape).clone()

        # Enforce constraints for stellarator symmetry
        # r_cos for m=0 and n<0 must be 0.0 for stellarator symmetric surfaces
        r_cos[0, :4] = 0.0
        # Major radius has to equal 1.0
        # NOTE(review): this writes 1.0 into column 4 of *every* poloidal row, not
        # only row 0 -- confirm that this matches the "major radius" intent.
        r_cos[:, 4] = 1.0
        # z_sin for m=0 and n<=0 must be 0.0 for stellarator symmetric surfaces
        z_sin[0, :5] = 0.0

        return torch.cat([r_cos.flatten(), z_sin.flatten()])

    def clear_cache(self):
        """Clear the cache to free memory."""
        self._cache.clear()

### Define all the evaluation functions as a thin wrapper around methods of the `Evaluator` class to stick with semantics of the BoTorch example

In [5]:
# Shared evaluator instance; its cache lets the objective and the constraints
# of one point reuse a single forward-model run.
evaluator = Evaluator(bounds=bounds)
# Input dimension: one r_cos block plus one z_sin block of coefficients.
dim = evaluator.n_coeffecients * 2

# Number of candidates proposed per optimization step.
batch_size = 1
max_cholesky_size = float("inf")  # Always use Cholesky


def eval_objective(x, evaluator):
    """Return the objective value of ``x`` as reported by ``evaluator.get_objective``."""
    objective_value = evaluator.get_objective(x)
    return objective_value


def eval_c1(x, evaluator):
    """Return constraint 1 (aspect ratio) of ``x`` via ``evaluator.aspect_ratio_constraint``."""
    constraint_value = evaluator.aspect_ratio_constraint(x)
    return constraint_value


def eval_c2(x, evaluator):
    """Return constraint 2 (average triangularity) of ``x`` via ``evaluator.average_triangularity_constraint``."""
    constraint_value = evaluator.average_triangularity_constraint(x)
    return constraint_value


def eval_c3(x, evaluator):
    """Return constraint 3 (edge rotational transform) of ``x`` via the matching ``evaluator`` method."""
    constraint_value = (
        evaluator.edge_rotational_transform_over_n_field_periods_constraint(x)
    )
    return constraint_value

### Define the SCBO State

In [6]:
@dataclass
class ScboState:
    """Bookkeeping for a single SCBO trust region.

    Holds the trust-region side length, the success/failure counters that
    drive its resizing, and the incumbent objective/constraint values.
    """

    # Problem dimension and number of candidates evaluated per step.
    dim: int
    batch_size: int
    # Current trust-region length; doubled (capped at length_max) after enough
    # successes, halved after enough failures.
    length: float = 0.8
    # A restart is flagged once length drops below this value.
    length_min: float = 0.5**7
    length_max: float = 1.6
    # Consecutive failures and the count that triggers a shrink.
    failure_counter: int = 0
    failure_tolerance: int = float("nan")  # Note: Post-initialized
    # Consecutive successes and the count that triggers an expansion.
    success_counter: int = 0
    success_tolerance: int = 10  # Note: The original paper uses 3
    # Objective value of the incumbent; -inf until the first update.
    best_value: float = -float("inf")
    # Constraint values of the incumbent; all +inf until the first update.
    # NOTE(review): the default has 2 entries while the optimization loop in this
    # notebook concatenates 3 constraint columns -- only reductions (any/sum) are
    # applied to the default, but confirm the size is intentional.
    best_constraint_values: Tensor = torch.ones(2, **tkwargs) * torch.inf
    # Set to True by update_tr_length when the trust region has collapsed.
    restart_triggered: bool = False

    def __post_init__(self):
        # failure_tolerance = ceil(max(4, dim) / batch_size)
        self.failure_tolerance = math.ceil(
            max([4.0 / self.batch_size, float(self.dim) / self.batch_size])
        )

### Utility functions to update the state

In [7]:
def update_tr_length(state: ScboState):
    """Resize the trust region from the success and failure counters.

    The length is doubled (capped at ``state.length_max``) when the success
    counter reaches its tolerance, and halved when the failure counter reaches
    its tolerance; the counter that fired is reset. If the resulting length is
    below ``state.length_min``, ``state.restart_triggered`` is set.

    Returns:
        The same ``state`` object, modified in place.
    """
    hit_success_limit = state.success_counter == state.success_tolerance
    hit_failure_limit = state.failure_counter == state.failure_tolerance

    if hit_success_limit:
        # Expand, but never beyond the maximum length
        state.length = min(2.0 * state.length, state.length_max)
        state.success_counter = 0
    elif hit_failure_limit:
        # Shrink
        state.length = state.length / 2.0
        state.failure_counter = 0

    # A trust region that has become too small requests a restart
    collapsed = state.length < state.length_min
    if collapsed:
        state.restart_triggered = True

    return state


def get_best_index_for_batch(Y: Tensor, C: Tensor):
    """Return the index of the best point in a batch.

    A point is feasible when all of its constraint values are ``<= 0``. If any
    point is feasible, the feasible point with the largest ``Y`` wins;
    otherwise the point with the smallest total positive constraint violation
    wins.
    """
    feasible = (C <= 0).all(dim=-1)

    if not feasible.any():
        # Nothing feasible: rank by summed constraint violation instead
        total_violation = C.clamp(min=0).sum(dim=-1)
        return total_violation.argmin()

    # Exclude infeasible points from the arg-max by scoring them as -inf
    masked_values = Y.clone()
    masked_values[~feasible] = -float("inf")
    return masked_values.argmax()


def update_state(state, Y_next, C_next):
    """Update the TuRBO/SCBO state after one optimization step.

    The best point of the new batch (objective values ``Y_next``, constraint
    values ``C_next``) is compared with the incumbent stored in ``state``:

    * If the new point is feasible (all constraint values ``<= 0``), the step
      is a success when it beats the incumbent objective by more than a 1e-3
      relative margin, or when the incumbent itself violates a constraint.
    * If the new point is infeasible, the step is a success when its total
      constraint violation (sum of positive constraint values) is smaller than
      the incumbent's.

    A success increments the success counter, resets the failure counter and
    replaces the incumbent; a failure does the opposite and keeps the
    incumbent. Finally the trust-region length is updated from the counters.

    Returns:
        The updated ``state``.
    """
    # Pick the best point from the batch
    winner = get_best_index_for_batch(Y=Y_next, C=C_next)
    y_next = Y_next[winner]
    c_next = C_next[winner]

    if (c_next <= 0).all():
        # The batch winner is feasible: compare objective values
        threshold = state.best_value + 1e-3 * math.fabs(state.best_value)
        is_success = bool(
            y_next > threshold or (state.best_constraint_values > 0).any()
        )
    else:
        # The batch winner is infeasible: compare total constraint violation
        violation_next = c_next.clamp(min=0).sum(dim=-1)
        violation_incumbent = state.best_constraint_values.clamp(min=0).sum(dim=-1)
        is_success = bool(violation_next < violation_incumbent)

    if is_success:
        state.success_counter += 1
        state.failure_counter = 0
        state.best_value = y_next.item()
        state.best_constraint_values = c_next
    else:
        state.success_counter = 0
        state.failure_counter += 1

    # Update the length of the trust region according to the success and failure counters
    state = update_tr_length(state)
    return state


# Define an example state and print it to show the defaults, including the
# failure_tolerance computed in __post_init__
state = ScboState(dim=dim, batch_size=batch_size)
print(state)

ScboState(dim=90, batch_size=1, length=0.8, length_min=0.0078125, length_max=1.6, failure_counter=0, failure_tolerance=90, success_counter=0, success_tolerance=10, best_value=-inf, best_constraint_values=tensor([inf, inf], dtype=torch.float64), restart_triggered=False)


### Utility function to generate a batch of candidates

In [8]:
def generate_batch(
    state,
    model,  # GP model
    X,  # Evaluated points on the domain [0, 1]^d
    Y,  # Function values
    C,  # Constraint values
    batch_size,
    n_candidates,  # Number of candidates for Thompson sampling
    constraint_model,
    sobol: SobolEngine,
):
    """Propose ``batch_size`` new points with constrained Thompson sampling.

    A trust region of side ``state.length`` is centred on the best evaluated
    point and clipped to ``[0, 1]``. Sobol points inside that region perturb a
    random subset of the centre's coordinates, and
    ``ConstrainedMaxPosteriorSampling`` selects the final points from those
    candidates.

    Args:
        state: Provides ``length``, the current trust-region side length.
        model: Surrogate model of the objective.
        X: Evaluated inputs, one row per point. The trust region is clipped to
            ``[0, 1]`` in every coordinate, so ``X`` is expected to live in the
            unit cube.
        Y: Objective values for the rows of ``X``.
        C: Constraint values for the rows of ``X``.
        batch_size: Number of points to return.
        n_candidates: Number of candidate points offered to the sampler.
        constraint_model: Surrogate model(s) of the constraints.
        sobol: Sobol engine whose dimension matches ``X.shape[-1]``.

    Returns:
        The points selected by the sampler.
    """
    # assert X.min() >= 0.0 and X.max() <= 1.0 and torch.all(torch.isfinite(Y))

    # Create the TR bounds
    best_ind = get_best_index_for_batch(Y=Y, C=C)
    x_center = X[best_ind, :].clone()
    tr_lb = torch.clamp(x_center - state.length / 2.0, 0.0, 1.0)
    tr_ub = torch.clamp(x_center + state.length / 2.0, 0.0, 1.0)

    # Thompson Sampling w/ Constraints (SCBO)
    dim = X.shape[-1]
    # Keep the draws in the dtype and on the device of X so no precision is lost
    # and every tensor combined below lives on the same device.
    pert = sobol.draw(n_candidates).to(dtype=X.dtype, device=X.device)
    pert = tr_lb + (tr_ub - tr_lb) * pert

    # Create a perturbation mask: each coordinate is perturbed with probability
    # min(20 / dim, 1)
    prob_perturb = min(20.0 / dim, 1.0)
    mask = (
        torch.rand(n_candidates, dim, dtype=X.dtype, device=X.device) <= prob_perturb
    )
    # Rows without any perturbed coordinate get one chosen uniformly at random.
    # The upper bound of randint is exclusive, so it must be dim (not dim - 1)
    # for the last coordinate to be selectable.
    ind = torch.where(mask.sum(dim=1) == 0)[0]
    mask[ind, torch.randint(0, dim, size=(len(ind),), device=X.device)] = True

    # Create candidate points from the perturbations and the mask
    X_cand = x_center.expand(n_candidates, dim).clone()
    X_cand[mask] = pert[mask]

    # Sample on the candidate points using Constrained Max Posterior Sampling
    constrained_thompson_sampling = ConstrainedMaxPosteriorSampling(
        model=model, constraint_model=constraint_model, replacement=False
    )
    with torch.no_grad():
        X_next = constrained_thompson_sampling(X_cand, num_samples=batch_size)

    return X_next

### Retrieve an initial dataset to train the GP

In [9]:
n_init = 100

# Take the first n_init NaN-free rows as the initial design: columns 0-89 are
# the inputs, columns 90-93 are read as the objective and constraints 1-3.
# NOTE(review): presumably columns 91-93 follow the same "<= 0 is feasible"
# convention as the Evaluator constraint methods -- confirm against the data.
train_X = filtered_tensor[:n_init, :90]
y_init = filtered_tensor[:n_init, 90]
c1_init = filtered_tensor[:n_init, 91]
c2_init = filtered_tensor[:n_init, 92]
c3_init = filtered_tensor[:n_init, 93]

# Add a trailing dimension so each target is a column vector
train_Y, C1, C2, C3 = (
    column.unsqueeze(-1) for column in (y_init, c1_init, c2_init, c3_init)
)

### Run the SCBO algorithm (looks very much like TuRBO)

In [10]:
# Initialize TuRBO state (replaces the example state created earlier)
state = ScboState(dim, batch_size=batch_size)

# Note: We use 2000 candidates here to make the tutorial run faster.
# SCBO actually uses min(5000, max(2000, 200 * dim)) candidate points by default.
# With SMOKE_TEST set, only 4 candidates are used.
N_CANDIDATES = 2000 if not SMOKE_TEST else 4
# Scrambled Sobol sequence with a fixed seed for the candidate perturbations
sobol = SobolEngine(dim, scramble=True, seed=1)


def get_fitted_model(X, Y):
    """Fit and return a ``SingleTaskGP`` for one output.

    The model uses a scaled Matern-5/2 kernel with one lengthscale per input
    dimension, a Gaussian likelihood with noise constrained to [1e-8, 1e-3],
    and standardized outcomes. Inputs are used as given (no input transform is
    attached).

    Args:
        X: Training inputs, one row per point.
        Y: Training targets as a single column.

    Returns:
        The fitted model.
    """
    likelihood = GaussianLikelihood(noise_constraint=Interval(1e-8, 1e-3))
    covar_module = ScaleKernel(  # Use the same lengthscale prior as in the TuRBO paper
        MaternKernel(
            nu=2.5,
            # Size the ARD kernel from the inputs actually passed in rather than
            # from a notebook-level global
            ard_num_dims=X.shape[-1],
            lengthscale_constraint=Interval(0.005, 4.0),
        )
    )
    model = SingleTaskGP(
        X,
        Y,
        covar_module=covar_module,
        likelihood=likelihood,
        outcome_transform=Standardize(m=1),
    )
    mll = ExactMarginalLogLikelihood(model.likelihood, model)

    print("Training Model")

    with gpytorch.settings.max_cholesky_size(max_cholesky_size):
        fit_gpytorch_mll(mll)

    return model


# generate_batch clips its trust region to [0, 1] in every coordinate, so the
# surrogates and the candidate generation must work on min-max normalized inputs.
# train_X and the forward model stay in the original coefficient units.
bounds_lower, bounds_upper = bounds[0], bounds[1]
# Coordinates whose bounds coincide cannot be rescaled; they are given a unit
# width for normalization and are reset to their fixed value after sampling.
is_fixed_dim = bounds_upper <= bounds_lower
unit_bounds = torch.stack(
    [bounds_lower, torch.where(is_fixed_dim, bounds_lower + 1.0, bounds_upper)]
)

# NOTE(review): get_best_index_for_batch picks the *largest* objective value and
# eval_objective returns max_elongation unchanged -- confirm that maximization is
# intended, or negate the objective consistently in the data and in the evaluator.

while not state.restart_triggered:  # Run until TuRBO converges
    # Normalized view of all evaluated inputs
    X_unit = transform(train_X, unit_bounds)

    # Fit GP models for objective and constraints
    model = get_fitted_model(X_unit, train_Y)
    c1_model = get_fitted_model(X_unit, C1)
    c2_model = get_fitted_model(X_unit, C2)
    c3_model = get_fitted_model(X_unit, C3)

    print("Started TuRBO")

    # Generate a batch of candidates in the unit cube
    with gpytorch.settings.max_cholesky_size(max_cholesky_size):
        X_next_unit = generate_batch(
            state=state,
            model=model,
            X=X_unit,
            Y=model.train_targets,
            C=torch.cat((C1, C2, C3), dim=-1),
            batch_size=batch_size,
            n_candidates=N_CANDIDATES,
            constraint_model=ModelListGP(c1_model, c2_model, c3_model),
            sobol=sobol,
        )

    # Map the candidates back to coefficient units before calling the forward
    # model, keeping coordinates with degenerate bounds at their fixed value
    X_next = untransform(X_next_unit, unit_bounds)
    X_next[:, is_fixed_dim] = bounds_lower[is_fixed_dim]

    # Evaluate both the objective and constraints for the selected candidates.
    # New observations are created with the dtype/device of the existing training
    # data so they can be concatenated with it.
    obs_kwargs = {"dtype": train_Y.dtype, "device": train_Y.device}
    Y_next = torch.tensor(
        [eval_objective(x, evaluator) for x in X_next], **obs_kwargs
    ).unsqueeze(-1)
    C1_next = torch.tensor(
        [eval_c1(x, evaluator) for x in X_next], **obs_kwargs
    ).unsqueeze(-1)
    C2_next = torch.tensor(
        [eval_c2(x, evaluator) for x in X_next], **obs_kwargs
    ).unsqueeze(-1)
    C3_next = torch.tensor(
        [eval_c3(x, evaluator) for x in X_next], **obs_kwargs
    ).unsqueeze(-1)
    C_next = torch.cat([C1_next, C2_next, C3_next], dim=-1)

    # Update TuRBO state
    state = update_state(state=state, Y_next=Y_next, C_next=C_next)

    # Append data. Note that we append all data, even points that violate
    # the constraints. This is so our constraint models can learn more
    # about the constraint functions and gain confidence in where violations occur.
    train_X = torch.cat((train_X, X_next), dim=0)
    train_Y = torch.cat((train_Y, Y_next), dim=0)
    C1 = torch.cat((C1, C1_next), dim=0)
    C2 = torch.cat((C2, C2_next), dim=0)
    C3 = torch.cat((C3, C3_next), dim=0)

    # Print current status. Note that state.best_value is always the best
    # objective value found so far which meets the constraints, or in the case
    # that no points have been found yet which meet the constraints, it is the
    # objective value of the point with the minimum constraint violation.
    if (state.best_constraint_values <= 0).all():
        print(
            f"{len(train_X)}) Best value: {state.best_value:.2e}, TR length: {state.length:.2e}"
        )
    else:
        violation = state.best_constraint_values.clamp(min=0).sum()
        print(
            f"{len(train_X)}) No feasible point yet! Smallest total violation: "
            f"{violation:.2e}, TR length: {state.length:.2e}"
        )

Training Model
Training Model
Training Model
Training Model
Started TuRBO


RuntimeError: Thread 0:
	FATAL ERROR in thread=0Thread 0:
	FATAL ERROR in thread=0