In [1]:
import numpy as np
from skopt import gp_minimize
from skopt.space import Real
from skopt.utils import use_named_args
from scipy.spatial.distance import euclidean
import matplotlib.pyplot as plt
import torch
import pandas as pd
import numpy as np
from numpy import random
import torch.nn as nn

import ipywidgets as widgets
from IPython.display import display, clear_output


from modules import EDMPrecond
from diffusion import EdmSampler

import warnings
warnings.filterwarnings("ignore")

In [2]:
# Search interval (min, max) for every conditioning parameter, keyed by its name.
# Insertion order is the model-input order of the conditioning vector.
parameter_bounds = {
    "L1-PL3-PZM52_Setpoint": (-3.0, 3.0),
    "L1-PL3-PZM53_Setpoint": (-3.0, 3.0),
    "L1-PL3-PZM56_Setpoint": (-1.591, 0.6349),
    "L1-PL3-PZM57_Setpoint": (-3.5833, -0.0177),
    "L1-PL3-PZM58_Setpoint": (-3.0, 1.0),
    "L1-PL3-PZM59_Setpoint": (-3.0, 1.0),
    "L1-PL4-PZM66_Setpoint": (-3.75, 3.1082),
    "L1-PL4-PZM67_Setpoint": (-3.75, 0.2802),
    "L1-PL4-PZM68_Setpoint": (-3.75, 3.75),
    "L1-PL4-PZM69_Setpoint": (-1.9848, 3.75),
    "L1-INJ-PM70:VAL_CAL": (-0.35, 41.89),
    "L1-OPA3-5_2-PM98:VAL_CAL": (5.62, 66.25),
    "L1-INJ-PM70:VAL_CAL_diff": (-33.72, 1.08),
}

# Parameter names in conditioning-vector order. Derived from the dict above (dicts keep
# insertion order) so the name list and the bounds can never drift apart.
all_param_names = list(parameter_bounds)

In [3]:
# Checkpoint selection: training-run directory, EMA checkpoint epoch and target device.
epoch = 40
device_name = 'cuda:1'
name = "edm_bs16_do0_cgt20_cg3_ns30_lr1e-3_0_e40_ed999"
model_path = f"models/{name}/ema_ckpt{epoch}.pt"

In [4]:
def load_models(model_path="models/edm_bs16_do0_cgt20_cg3_ns30_lr1e-3_0_e40_ed999/ema_ckpt.pt",
                device="cuda:1",
                noise_steps=30,
                settings_dim=13):
    """
    Build the EDM network, restore its weights and wrap it in a sampler.

    Parameters
    ----------
    model_path : str
        Path to a state-dict checkpoint, loaded with ``weights_only=True``.
    device : str
        Requested torch device; replaced by ``'cpu'`` when CUDA is unavailable.
    noise_steps : int
        Number of sampling steps handed to ``EdmSampler``.
    settings_dim : int
        Accepted for interface compatibility; not used by this function.

    Returns
    -------
    tuple
        ``(model, sampler)``, with the model switched to eval mode.
    """
    if torch.cuda.is_available():
        target_device = torch.device(device)
    else:
        target_device = torch.device('cpu')

    net = EDMPrecond(device=target_device).to(target_device)
    state_dict = torch.load(model_path, map_location=target_device, weights_only=True)
    net.load_state_dict(state_dict)

    edm = EdmSampler(net=net, num_steps=noise_steps)

    # Inference only: disable training-mode behaviour in the network.
    net.eval()

    return net, edm


def edm_sampler(model, 
                sampler, 
                optimized_indices, 
                fixed_indices, 
                target_cond_vector, 
                optimized_params,
                settings_dim=13, 
                cfg_scale=3):
    """
    Generate one spectrum from a preloaded model and sampler for a set of settings.

    The conditioning vector has ``len(target_cond_vector)`` entries. Entries at
    ``fixed_indices`` are copied from ``target_cond_vector``. Entries at
    ``optimized_indices`` are taken, in order, from ``optimized_params``. Indices need not
    be contiguous; any index in neither list stays 0.

    Parameters
    ----------
    model : torch.nn.Module
        Only used to find the device (that of its first parameter).
    sampler : object
        Must expose ``sample(resolution, device, settings, n_samples, cfg_scale, settings_dim)``.
    optimized_indices, fixed_indices : sequence of int
        Positions in the conditioning vector.
    target_cond_vector : array-like of float
        Full-length settings vector; a list or a NumPy array.
    optimized_params : sequence of float
        Values for ``optimized_indices``, same length and order.
    settings_dim : int, optional
        Forwarded to ``sampler.sample``.
    cfg_scale : float, optional
        Forwarded to ``sampler.sample``.

    Returns
    -------
    numpy.ndarray
        The sampler output moved to CPU and flattened to 1-D.
    """
    device = next(model.parameters()).device

    # np.asarray lets callers pass a plain list: list[list_of_indices] raises TypeError.
    target = np.asarray(target_cond_vector, dtype=np.float32)
    candidate = np.asarray(optimized_params, dtype=np.float32)

    settings = torch.zeros(len(target), dtype=torch.float32, device=device)
    if len(fixed_indices) > 0:
        settings[fixed_indices] = torch.as_tensor(target[fixed_indices], device=device)
    if len(optimized_indices) > 0:
        settings[optimized_indices] = torch.as_tensor(candidate, device=device)

    with torch.no_grad():
        pred = sampler.sample(resolution=1024,
                              device=device,
                              settings=settings.unsqueeze(0),  # add batch dimension
                              n_samples=1,
                              cfg_scale=cfg_scale,
                              settings_dim=settings_dim)

    return pred.cpu().numpy().flatten()


def compute_mean_wavelength(intensities, wavelengths):
    """
    Compute the intensity-weighted centroid ("center of gravity") of a spectrum.

    If the spectrum dips below zero, it is shifted up so that its minimum is exactly 0,
    which makes every weight non-negative. A spectrum that is already non-negative is
    used unchanged.

    Parameters
    ----------
    intensities : array-like of float
        Spectrum values.
    wavelengths : array-like of float
        Wavelength of each intensity sample; same length as ``intensities``.

    Returns
    -------
    float
        ``sum(wavelengths * weights) / sum(weights)``. Falls back to the plain mean of
        ``wavelengths`` when the total weight is 0 (e.g. a flat spectrum at its minimum).

    Raises
    ------
    ValueError
        If ``intensities`` is empty.
    """
    intensities = np.asarray(intensities, dtype=float)
    wavelengths = np.asarray(wavelengths, dtype=float)
    if intensities.size == 0:
        raise ValueError("intensities must not be empty")

    # Shift only when needed. The former `+= abs(min)` also added a constant offset when
    # the minimum was positive, which pulled the centroid toward the window centre.
    weights = intensities - min(intensities.min(), 0.0)

    total_intensity = weights.sum()
    if total_intensity == 0:
        # Safe fallback if the spectrum is empty or dead.
        return float(np.mean(wavelengths))
    return float(np.sum(wavelengths * weights) / total_intensity)

def compute_wavelength_at_max_intensity(intensities, wavelengths):
    """
    Return the wavelength at which the spectrum reaches its highest intensity.

    Ties resolve to the first (lowest-index) peak, following ``np.argmax``.
    """
    peak_index = np.argmax(intensities)
    peak_wavelength = wavelengths[peak_index]
    return peak_wavelength


def loss_function(generated_intensities, 
                  target_mean_wavelength, 
                  target_max_wavelength,
                  target_max_intensity,
                  wavelengths,
                  mean_weight=1.0,
                  peak_weight=1.0,
                  intensity_weight=1.0):
    """
    Weighted sum of absolute errors between a generated spectrum and three target features.

    Terms:
    - Mean (centroid) wavelength, computed with ``compute_mean_wavelength``.
    - Wavelength at maximum intensity, computed with ``compute_wavelength_at_max_intensity``.
    - Maximum intensity value.

    The terms are in different units (wavelength vs. intensity). In an unweighted sum, the
    term with the larger numeric spread dominates. The weights let callers rebalance the
    terms; the defaults of 1.0 reproduce the plain sum.

    Parameters
    ----------
    generated_intensities : numpy.ndarray
        Generated spectrum.
    target_mean_wavelength, target_max_wavelength, target_max_intensity : float
        Target feature values.
    wavelengths : numpy.ndarray
        Wavelength grid; same length as ``generated_intensities``.
    mean_weight, peak_weight, intensity_weight : float, optional
        Multipliers for the mean-wavelength, peak-wavelength and peak-intensity errors.

    Returns
    -------
    float
        The weighted total loss.
    """
    # Features of the generated spectrum.
    gen_mean_wavelength = compute_mean_wavelength(generated_intensities, wavelengths)
    gen_max_wavelength = compute_wavelength_at_max_intensity(generated_intensities, wavelengths)
    gen_max_intensity = np.max(generated_intensities)

    # Per-feature absolute errors.
    mean_wavelength_loss = abs(gen_mean_wavelength - target_mean_wavelength)
    max_wavelength_loss = abs(gen_max_wavelength - target_max_wavelength)
    max_intensity_loss = abs(gen_max_intensity - target_max_intensity)

    total_loss = (mean_weight * mean_wavelength_loss
                  + peak_weight * max_wavelength_loss
                  + intensity_weight * max_intensity_loss)

    return total_loss


def objective(optimized_params, model, sampler, 
              optimized_indices, fixed_indices, 
              target_cond_vector, 
              target_mean_wavelength, 
              target_max_wavelength,
              target_max_intensity,
              wavelengths):
    """
    Score one candidate for the optimizer.

    Samples a spectrum with ``optimized_params`` placed at ``optimized_indices``, then
    returns its ``loss_function`` value against the three target features.
    """
    spectrum = edm_sampler(
        model,
        sampler,
        optimized_indices,
        fixed_indices,
        target_cond_vector,
        optimized_params,
    )
    score = loss_function(
        spectrum,
        target_mean_wavelength,
        target_max_wavelength,
        target_max_intensity,
        wavelengths,
    )
    return score


def get_param_space(optimized_param_names, parameter_bounds):
    """
    Build the skopt search space for the parameters being optimized.

    Parameters:
        optimized_param_names (list of str): Names of parameters to optimize, in order.
        parameter_bounds (dict): Maps a parameter name to its (min, max) tuple.

    Returns:
        list: One `Real(min, max, name=...)` dimension per name, in input order.

    Raises:
        ValueError: If a name has no entry in `parameter_bounds`.
    """
    space = []
    for pname in optimized_param_names:
        if pname not in parameter_bounds:
            raise ValueError(f"Parameter {pname} not found in provided parameter bounds.")
        low, high = parameter_bounds[pname]
        space.append(Real(low, high, name=pname))
    return space


def optimize_parameters(model, 
                        sampler, 
                        optimized_indices, 
                        fixed_indices, 
                        target_cond_vector, 
                        target_mean_wavelength, 
                        target_max_wavelength,
                        target_max_intensity,
                        wavelengths,
                        param_space,
                        n_calls=10):
    """
    Minimise the spectrum-matching loss over ``param_space`` with ``gp_minimize``.

    The loss combines three features: mean wavelength, peak wavelength and peak intensity.
    The search uses ``random_state=42`` and the ``"gp_hedge"`` acquisition function.

    Returns
    -------
    tuple
        ``(best_params, best_loss)``: the best point found (in ``param_space`` order) and
        its loss.
    """

    def _score(candidate):
        # gp_minimize calls this with one list of values per evaluation.
        return objective(candidate, model, sampler,
                         optimized_indices, fixed_indices,
                         target_cond_vector,
                         target_mean_wavelength,
                         target_max_wavelength,
                         target_max_intensity,
                         wavelengths)

    result = gp_minimize(
        _score,
        param_space,
        n_calls=n_calls,
        random_state=42,
        acq_func="gp_hedge",
    )
    return result.x, result.fun

def plot_results_bayesian_optimization(best_params,
                                       target_cond_vector,
                                       target_mean_wavelength, 
                                       target_max_wavelength,
                                       target_max_intensity,
                                       wavelengths, 
                                       optimized_indices, 
                                       fixed_indices, 
                                       model, 
                                       sampler):
    """
    Re-sample a spectrum at ``best_params`` and plot it against the target features.

    Dashed vertical lines mark the mean wavelength and dotted lines the peak wavelength.
    Dash-dot horizontal lines mark the peak intensity. Blue lines are targets and red
    lines are predictions.

    NOTE(review): if the sampler is stochastic, the spectrum drawn here can differ from
    the one that was scored during optimisation.
    """
    predicted_intensities = edm_sampler(model, 
                                        sampler, 
                                        optimized_indices, 
                                        fixed_indices, 
                                        target_cond_vector,
                                        best_params)

    mean_wavelength_predicted = compute_mean_wavelength(predicted_intensities, wavelengths)
    max_wavelength_predicted = compute_wavelength_at_max_intensity(predicted_intensities, wavelengths)
    predicted_max_intensity = np.max(predicted_intensities)

    fig, ax = plt.subplots(figsize=(8, 5))

    # Plot spectrum
    ax.plot(wavelengths, predicted_intensities, label='Predicted Spectrum', color='tab:orange')

    # Mean wavelength lines
    ax.axvline(target_mean_wavelength, color='tab:blue', linestyle='--', label=f'Target mean λ = {target_mean_wavelength:.2f}')
    ax.axvline(mean_wavelength_predicted, color='tab:red', linestyle='--', label=f'Predicted mean λ = {mean_wavelength_predicted:.2f}')

    # Peak wavelength lines (vertical)
    ax.axvline(target_max_wavelength, color='tab:blue', linestyle=':', label=f'Target Max λ = {target_max_wavelength:.2f}')
    ax.axvline(max_wavelength_predicted, color='tab:red', linestyle=':', label=f'Predicted Max λ = {max_wavelength_predicted:.2f}')

    # Peak intensity lines (horizontal)
    ax.axhline(target_max_intensity, color='tab:blue', linestyle='-.', label=f'Target Max Intensity = {target_max_intensity:.2f}')
    ax.axhline(predicted_max_intensity, color='tab:red', linestyle='-.', label=f'Predicted Max Intensity = {predicted_max_intensity:.2f}')

    ax.set_title("Predicted Spectrum")
    ax.set_xlabel("Wavelengths")
    ax.set_ylabel("Intensity")
    ax.legend(loc='center left', bbox_to_anchor=(1, 0.5))

    # Keep the usual 1200 ceiling. Raise it (with a 5% margin) when the spectrum peak or
    # the target line exceeds it, so neither is clipped off the plot.
    highest = max(float(predicted_max_intensity), float(target_max_intensity))
    ax.set_ylim(top=max(1200.0, 1.05 * highest))

    plt.show()

def run_bayesian_optimization(
                            target_cond_vector,
                            target_mean_wavelength, 
                            target_max_wavelength,
                            target_max_intensity,
                            optimized_param_names,
                            parameter_bounds,
                            all_param_names,
                            model,
                            sampler,
                            wavelengths,
                            n_calls=50):
    """
    Optimise the selected parameters toward the target spectrum features, then plot the result.

    Parameters not listed in ``optimized_param_names`` keep their values from
    ``target_cond_vector``.

    Parameters
    ----------
    target_cond_vector : array-like of float
        One value per name in ``all_param_names``, in the same order.
    target_mean_wavelength, target_max_wavelength, target_max_intensity : float
        Target spectrum features.
    optimized_param_names : list of str
        Names to optimise; each must appear in ``all_param_names`` and ``parameter_bounds``.
    parameter_bounds : dict
        name -> (min, max) search interval.
    all_param_names : list of str
        Every parameter name, in conditioning-vector order.
    model, sampler :
        Preloaded network and sampler.
    wavelengths : numpy.ndarray
        Wavelength grid of the generated spectrum.
    n_calls : int, optional
        Number of objective evaluations for ``gp_minimize``.

    Returns
    -------
    tuple
        ``(best_params, best_loss)``, with ``best_params`` in ``optimized_param_names`` order.

    Raises
    ------
    ValueError
        If ``target_cond_vector`` does not have one entry per parameter name, if a name is
        not in ``all_param_names``, or if a name has no bounds.
    """
    # Convert once, so the optimiser and the final plot both receive an array that
    # supports fancy indexing with the index lists built below.
    target_cond_vector = np.asarray(target_cond_vector, dtype=float)
    if target_cond_vector.shape != (len(all_param_names),):
        raise ValueError(
            f"target_cond_vector has shape {target_cond_vector.shape}; "
            f"expected ({len(all_param_names)},) to match all_param_names."
        )

    # Map parameter names to positions in the conditioning vector.
    optimized_indices = [all_param_names.index(name) for name in optimized_param_names]
    optimized_set = set(optimized_indices)
    fixed_indices = [i for i in range(len(all_param_names)) if i not in optimized_set]

    # Generate search space
    param_space = get_param_space(optimized_param_names, parameter_bounds)

    # Run optimization
    print(f"🚀 Running Bayesian optimization with {n_calls} calls...")
    best_params, best_loss = optimize_parameters(
                                        model,
                                        sampler,
                                        optimized_indices,
                                        fixed_indices,
                                        target_cond_vector,
                                        target_mean_wavelength,
                                        target_max_wavelength,
                                        target_max_intensity,
                                        wavelengths,
                                        param_space,
                                        n_calls=n_calls
                                    )

    print("\n✅ Optimization completed")
    # Plot results
    plot_results_bayesian_optimization(
                                        best_params,
                                        target_cond_vector,
                                        target_mean_wavelength,
                                        target_max_wavelength,
                                        target_max_intensity,
                                        wavelengths,
                                        optimized_indices,
                                        fixed_indices,
                                        model,
                                        sampler
                                    )

    return best_params, best_loss

import ipywidgets as widgets
from IPython.display import display, clear_output
import numpy as np

def launch_bayesian_optimization_tool(all_param_names, 
                                      parameter_bounds, 
                                      model, 
                                      sampler, 
                                      wavelengths,
                                      default_values=None):
    """
    Display an ipywidgets form for running the Bayesian optimisation interactively.

    The form contains:
    - One row per parameter: a checkbox (ticked = optimise this parameter), a float box
      holding the value used when the parameter stays fixed, and its usual range.
    - Boxes for the three target spectrum features.
    - An ``n_calls`` slider and a Run button.

    Run calls ``run_bayesian_optimization`` and prints the best parameters and loss into
    an Output widget.

    Parameters
    ----------
    all_param_names : list of str
        Every conditioning parameter, in model-input order.
    parameter_bounds : dict
        name -> (min, max). Required for every parameter that gets optimised.
    model, sampler :
        Preloaded network and sampler.
    wavelengths : numpy.ndarray
        Wavelength grid of the generated spectrum.
    default_values : sequence of float, optional
        Initial value-box contents, one per name in ``all_param_names``. Defaults to the
        built-in 13-value reference setting.

    Raises
    ------
    ValueError
        If the number of default values differs from the number of parameter names.
    """
    if default_values is None:
        default_values = [-0.612 ,  0.1487, -0.4105, 
                          -1.2972, -1.    , -1.    , 
                          -2.4977, -0.4915, -2.5347,  
                          0.7122, 31.08  , 43.41  ,  0.    ]
    if len(default_values) != len(all_param_names):
        # zip() would silently drop parameters and Run would later fail with a KeyError.
        raise ValueError(
            f"Expected {len(all_param_names)} default values, got {len(default_values)}."
        )

    # Layout for consistent width
    text_input_layout = widgets.Layout(width='150px')

    # One (checkbox, value box, range label) triple per parameter.
    param_widgets = {}
    for name, default_val in zip(all_param_names, default_values):
        checkbox = widgets.Checkbox(value=False, description=name, layout=widgets.Layout(width='350px'))
        text_input = widgets.FloatText(value=default_val, layout=text_input_layout)
        if name in parameter_bounds:
            bounds = parameter_bounds[name]
            range_label = widgets.Label(f"usual range: [{bounds[0]}, {bounds[1]}]", layout=widgets.Layout(width='200px'))
        else:
            range_label = widgets.Label("Range: N/A", layout=widgets.Layout(width='200px'))
        param_widgets[name] = (checkbox, text_input, range_label)

    select_all_button = widgets.Button(
        description="Select All",
        button_style='info',
        icon='check'
    )

    deselect_all_button = widgets.Button(
        description="Deselect All",
        button_style='warning',
        icon='times'
    )

    n_calls_widget = widgets.IntSlider(
        value=50,
        min=10,
        max=200,
        step=10,
        description='n_calls:'
    )

    # Target specification widgets
    target_mean_wavelength_widget = widgets.FloatText(value=840.0, layout=text_input_layout, style={'description_width': 'initial'})
    target_max_wavelength_widget = widgets.FloatText(value=911.0, layout=text_input_layout, style={'description_width': 'initial'})
    target_max_intensity_widget = widgets.FloatText(value=800.0, layout=text_input_layout, style={'description_width': 'initial'})

    run_button = widgets.Button(
        description="Run Optimization",
        button_style='success',
        icon='play'
    )

    output = widgets.Output()

    def on_select_all_clicked(b):
        for cb, _, _ in param_widgets.values():
            cb.value = True

    def on_deselect_all_clicked(b):
        for cb, _, _ in param_widgets.values():
            cb.value = False

    select_all_button.on_click(on_select_all_clicked)
    deselect_all_button.on_click(on_deselect_all_clicked)

    def on_run_button_clicked(b):
        # Everything printed or plotted below goes into the `output` widget.
        with output:
            clear_output()
            selected_params = [name for name, (cb, _, _) in param_widgets.items() if cb.value]
            all_text_values = {name: text.value for name, (_, text, _) in param_widgets.items()}

            n_calls = n_calls_widget.value

            target_mean_wavelength = target_mean_wavelength_widget.value
            target_max_wavelength = target_max_wavelength_widget.value
            target_max_intensity = target_max_intensity_widget.value

            if not selected_params:
                print("⚠️ Please select at least one parameter to optimize.")
                return

            # Rows shown with "Range: N/A" have no search interval. Report them instead of
            # failing with a KeyError while building selected_bounds.
            unbounded = [param for param in selected_params if param not in parameter_bounds]
            if unbounded:
                print(f"⚠️ No bounds defined for: {', '.join(unbounded)}")
                return

            # Build the target_cond_vector from all current values (model-input order).
            target_cond_vector = np.array([all_text_values[name] for name in all_param_names])

            # Build selected bounds
            selected_bounds = {
                param: parameter_bounds[param]
                for param in selected_params
            }

            print("📈 Launching optimization...")
            best_params, best_loss = run_bayesian_optimization(
                                                                target_cond_vector,
                                                                target_mean_wavelength, 
                                                                target_max_wavelength,
                                                                target_max_intensity,
                                                                selected_params,
                                                                selected_bounds,
                                                                all_param_names,
                                                                model,
                                                                sampler,
                                                                wavelengths,
                                                                n_calls=n_calls
                                                                )
            print("\n✅ Best Parameters:")
            for name, value in zip(selected_params, best_params):
                print(f"  {name}: {value:.5f}")
            print(f"  Best loss: {best_loss:.5f}")

    run_button.on_click(on_run_button_clicked)

    def display_widgets():
        clear_output(wait=True)

        param_rows = []
        for name, (checkbox, text_input, range_label) in param_widgets.items():
            row = widgets.HBox([checkbox, text_input, range_label])
            param_rows.append(row)

        param_box = widgets.VBox(param_rows)

        # Spectrum target properties arranged properly left-aligned
        target_rows = widgets.VBox([
            widgets.HBox([
                widgets.Label("Target Mean λ", layout=widgets.Layout(width='200px')), 
                target_mean_wavelength_widget,
                widgets.Label("usual range: [830, 850]", layout=widgets.Layout(width='200px'))
            ]),
            widgets.HBox([
                widgets.Label("Target Max λ", layout=widgets.Layout(width='200px')), 
                target_max_wavelength_widget,
                widgets.Label("usual range: [794, 913]", layout=widgets.Layout(width='200px'))
            ]),
            widgets.HBox([
                widgets.Label("Target Max Intensity", layout=widgets.Layout(width='200px')), 
                target_max_intensity_widget,
                widgets.Label("usual range: [600, 820]", layout=widgets.Layout(width='200px'))
            ])
        ])

        display(widgets.VBox([
            widgets.HBox([select_all_button, deselect_all_button]),
            widgets.HTML("<h3><b>Select parameters to optimize and set values for fixed parameters:</b></h3>"),
            param_box,
            widgets.HTML("<h3><b>Target spectrum properties:</b></h3>"),
            target_rows,
            n_calls_widget,
            run_button,
            output
        ]))

    display_widgets()


In [8]:
# Wavelength grid for the generated spectra, then the network and sampler for the chosen checkpoint.
wavelengths = np.load('../data/wavelengths.npy')
model, sampler = load_models(device=device_name, model_path=model_path)

In [9]:
# Launch the interactive optimisation form.
launch_bayesian_optimization_tool(
    all_param_names=all_param_names,
    parameter_bounds=parameter_bounds,
    model=model,
    sampler=sampler,
    wavelengths=wavelengths,
)

VBox(children=(HBox(children=(Button(button_style='info', description='Select All', icon='check', style=Button…

In [118]:
# NOTE(review): repeats an earlier launch cell; each run creates an independent form.
launch_bayesian_optimization_tool(
    all_param_names=all_param_names,
    parameter_bounds=parameter_bounds,
    model=model,
    sampler=sampler,
    wavelengths=wavelengths,
)

VBox(children=(HBox(children=(Button(button_style='info', description='Select All', icon='check', style=Button…

In [120]:
# NOTE(review): repeats an earlier launch cell; each run creates an independent form.
launch_bayesian_optimization_tool(
    all_param_names=all_param_names,
    parameter_bounds=parameter_bounds,
    model=model,
    sampler=sampler,
    wavelengths=wavelengths,
)

VBox(children=(HBox(children=(Button(button_style='info', description='Select All', icon='check', style=Button…

In [119]:
# NOTE(review): repeats an earlier launch cell; each run creates an independent form.
launch_bayesian_optimization_tool(
    all_param_names=all_param_names,
    parameter_bounds=parameter_bounds,
    model=model,
    sampler=sampler,
    wavelengths=wavelengths,
)

VBox(children=(HBox(children=(Button(button_style='info', description='Select All', icon='check', style=Button…

In [121]:
# NOTE(review): repeats an earlier launch cell; each run creates an independent form.
launch_bayesian_optimization_tool(
    all_param_names=all_param_names,
    parameter_bounds=parameter_bounds,
    model=model,
    sampler=sampler,
    wavelengths=wavelengths,
)

VBox(children=(HBox(children=(Button(button_style='info', description='Select All', icon='check', style=Button…

In [127]:
# NOTE(review): repeats an earlier launch cell; each run creates an independent form.
launch_bayesian_optimization_tool(
    all_param_names=all_param_names,
    parameter_bounds=parameter_bounds,
    model=model,
    sampler=sampler,
    wavelengths=wavelengths,
)

VBox(children=(HBox(children=(Button(button_style='info', description='Select All', icon='check', style=Button…

In [122]:
# NOTE(review): repeats an earlier launch cell; each run creates an independent form.
launch_bayesian_optimization_tool(
    all_param_names=all_param_names,
    parameter_bounds=parameter_bounds,
    model=model,
    sampler=sampler,
    wavelengths=wavelengths,
)

VBox(children=(HBox(children=(Button(button_style='info', description='Select All', icon='check', style=Button…

In [132]:
# NOTE(review): repeats an earlier launch cell; each run creates an independent form.
launch_bayesian_optimization_tool(
    all_param_names=all_param_names,
    parameter_bounds=parameter_bounds,
    model=model,
    sampler=sampler,
    wavelengths=wavelengths,
)

VBox(children=(HBox(children=(Button(button_style='info', description='Select All', icon='check', style=Button…