In [1]:
import numpy as np
import torch
import ipywidgets as widgets
from IPython.display import display, clear_output
import matplotlib.pyplot as plt

from sample import sample_from_model
from modules import UNet_conditional, EDMPrecond

from diffusion import SpacedDiffusion
from diffusion import EdmSampler

from evaluate import evaluate
from utils import plot_predictions_with_cond_vectors

Importing the dtw module. When using in academic works please cite:
  T. Giorgino. Computing and Visualizing Dynamic Time Warping Alignments in R: The dtw Package.
  J. Stat. Soft., doi:10.18637/jss.v031.i07.



In [2]:
# x-axis values for the spectra plotted by the sampling UI further down.
wavelengths = np.load('../data/wavelengths.npy')

# Checkpoint location: models/<name>/ema_ckpt<epoch>.pt
epoch = 40
name = "edm_bs16_do0_cgt20_cg3_ns30_lr1e-3_0_e40_ed999"
model_path = f"models/{name}/ema_ckpt{epoch}.pt"

# Run on the first CUDA GPU when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

# Build the network on the chosen device and restore the saved weights.
# weights_only=True asks torch.load to unpickle plain tensor data only.
model = EDMPrecond(device=device)
model = model.to(device)
ckpt = torch.load(model_path, map_location=device, weights_only=True)
model.load_state_dict(ckpt)
# Inference only from here on (presumably EDMPrecond is a torch nn.Module,
# in which case this switches it to evaluation mode — TODO confirm).
model.eval()

# Sampler wrapping the restored network; noise_steps is passed as num_steps.
noise_steps = 30
sampler = EdmSampler(net=model, num_steps=noise_steps)

In [3]:
# Conditioning parameters for the sampling UI: name -> (low, high).
# create_sampling_ui only displays these bounds as a "Usual range" hint next to
# each input box. Key order matters: default_values below is matched to the
# keys by position.
parameter_bounds = {
    "L1-PL3-PZM52_Setpoint": (-3.00, 3.00),
    "L1-PL3-PZM53_Setpoint": (-3.00, 3.00),
    "L1-PL3-PZM56_Setpoint": (-1.5910, 0.6349),
    "L1-PL3-PZM57_Setpoint": (-3.5833, -0.0177),
    "L1-PL3-PZM58_Setpoint": (-3.0000, 1.0000),
    "L1-PL3-PZM59_Setpoint": (-3.0000, 1.0000),
    "L1-PL4-PZM66_Setpoint": (-3.7500, 3.1082),
    "L1-PL4-PZM67_Setpoint": (-3.7500, 0.2802),
    "L1-PL4-PZM68_Setpoint": (-3.7500, 3.7500),
    "L1-PL4-PZM69_Setpoint": (-1.9848, 3.7500),
    "L1-INJ-PM70:VAL_CAL": (-0.3500, 41.8900),
    "L1-OPA3-5_2-PM98:VAL_CAL": (5.6200, 66.2500),
    "L1-INJ-PM70:VAL_CAL_diff": (-23.7200, 1.0800),
}

# Initial value of each input box, one entry per key of parameter_bounds and in
# the same order.
default_values = [
    -0.612, 0.1487,                      # L1-PL3-PZM52 / PZM53
    -0.4105, -1.2972, -1., -1.,          # L1-PL3-PZM56 .. PZM59
    -2.4977, -0.4915, -2.5347, 0.7122,   # L1-PL4-PZM66 .. PZM69
    31.08, 43.41, 0.,                    # PM70, PM98, PM70 diff
]

# The pairing above is purely positional and the UI builder combines the two
# with zip(), which would silently drop entries on a length mismatch.
assert len(default_values) == len(parameter_bounds), (
    f"default_values has {len(default_values)} entries but parameter_bounds "
    f"defines {len(parameter_bounds)} parameters"
)

In [4]:
def create_sampling_ui(parameter_bounds, default_values, sampler, wavelengths, device,
                       resolution=1024, cfg_scale=2, settings_dim=None):
    """Build and display an interactive widget panel for sampling spectra.

    One ``FloatText`` box is created per entry of ``parameter_bounds``. Pressing
    the "Sample" button collects the current box values into a float32
    conditioning vector, calls ``sampler.sample`` under ``torch.no_grad()`` and
    plots the result into an ``Output`` widget displayed together with the
    controls.

    Args:
        parameter_bounds: Mapping ``name -> (low, high)``. The bounds are only
            shown next to each input as a "Usual range" hint; they are not
            enforced.
        default_values: Initial values of the inputs, one per entry of
            ``parameter_bounds`` and in the same order.
        sampler: Object exposing ``sample(resolution=, device=, settings=,
            n_samples=, cfg_scale=, settings_dim=)``. Its result is indexed as
            ``[:, 0, :]`` after conversion to NumPy, so a 3-D tensor with one
            row per sample is expected.
        wavelengths: x-axis values for the plotted spectra.
        device: Device the conditioning vector is created on; also forwarded
            to ``sampler.sample``.
        resolution: Forwarded to ``sampler.sample`` (default 1024).
        cfg_scale: Forwarded to ``sampler.sample`` (default 2).
        settings_dim: Forwarded to ``sampler.sample``. ``None`` (default) uses
            the number of input boxes, i.e. the length of the conditioning
            vector.

    Returns:
        None. The widgets are rendered via ``display`` as a side effect.

    Raises:
        ValueError: If the number of default values differs from the number of
            parameters (``zip`` would otherwise silently drop the surplus).
    """
    default_values = list(default_values)
    if len(default_values) != len(parameter_bounds):
        raise ValueError(
            f"Expected {len(parameter_bounds)} default values "
            f"(one per parameter), got {len(default_values)}"
        )

    output = widgets.Output()

    # One row per parameter: numeric input box + label with the usual range.
    text_inputs = {}
    input_rows = []
    for (param, (low, high)), default in zip(parameter_bounds.items(), default_values):
        input_box = widgets.FloatText(
            value=default,
            description=param,
            style={'description_width': '250px'},
            layout=widgets.Layout(width='400px')
        )
        bounds_label = widgets.Label(
            value=f"Usual range: ({low:.2f}, {high:.2f})",
            layout=widgets.Layout(width='200px')
        )
        text_inputs[param] = input_box
        input_rows.append(widgets.HBox([input_box, bounds_label]))

    inputs_box = widgets.VBox(input_rows)

    if settings_dim is None:
        settings_dim = len(text_inputs)

    n_samples_input = widgets.IntText(
        value=1,
        description='n_samples',
        style={'description_width': '150px'},
        layout=widgets.Layout(width='300px')
    )

    run_button = widgets.Button(description="Sample", button_style='success')

    # Up to this many samples are drawn as individual curves, one colour each;
    # larger batches are summarised as a min/max envelope.
    sample_colors = ('tab:blue', 'tab:green', 'tab:orange')

    def plot_spectra(predicted):
        """Plot the sampled spectra (rows of ``predicted``) against ``wavelengths``."""
        fig, ax = plt.subplots(figsize=(12, 6), dpi=300)
        n_curves = predicted.shape[0]
        if n_curves <= len(sample_colors):
            for idx, (spectrum, color) in enumerate(zip(predicted, sample_colors), start=1):
                # A single curve needs no running number in its legend entry.
                label = "Predicted sample" if n_curves == 1 else f"Predicted sample {idx}"
                ax.plot(wavelengths, spectrum, color=color, alpha=0.7, label=label)
        else:
            min_intensities = np.min(predicted, axis=0)
            max_intensities = np.max(predicted, axis=0)
            ax.fill_between(wavelengths, min_intensities, max_intensities,
                            color='tab:blue', alpha=0.3, label="Predicted range")
        ax.set_xlabel("Wavelengths")
        ax.set_ylabel("Intensity")
        ax.legend()
        plt.show()

    def sample_and_plot(b):
        """Button callback: read the inputs, draw samples and plot them."""
        # Block further clicks while sampling runs so they do not queue up
        # redundant runs; the button is always re-enabled afterwards.
        run_button.disabled = True
        try:
            with output:
                clear_output()
                n_samples = n_samples_input.value
                if n_samples < 1:
                    # Nothing sensible to sample or plot for a non-positive count.
                    print(f"n_samples must be at least 1 (got {n_samples}).")
                    return

                cond_vec = torch.tensor(
                    [box.value for box in text_inputs.values()],
                    dtype=torch.float32
                ).to(device)

                with torch.no_grad():
                    pred = sampler.sample(
                        resolution=resolution,
                        device=device,
                        settings=cond_vec,
                        n_samples=n_samples,
                        cfg_scale=cfg_scale,
                        settings_dim=settings_dim
                    )
                # Keep index 0 of the second axis -> one row per sample.
                pred_np = pred.cpu().numpy()[:, 0, :]
                plot_spectra(pred_np)
        finally:
            run_button.disabled = False

    run_button.on_click(sample_and_plot)

    controls_box = widgets.VBox([n_samples_input, run_button])
    ui = widgets.HBox([inputs_box, controls_box])

    display(ui, output)

In [13]:
# Render a sampling panel; each call builds its own widgets and output area.
create_sampling_ui(
    parameter_bounds=parameter_bounds,
    default_values=default_values,
    sampler=sampler,
    wavelengths=wavelengths,
    device=device,
)

HBox(children=(VBox(children=(HBox(children=(FloatText(value=-0.612, description='L1-PL3-PZM52_Setpoint', layo…

Output()

In [12]:
# Render a sampling panel; each call builds its own widgets and output area.
create_sampling_ui(
    parameter_bounds=parameter_bounds,
    default_values=default_values,
    sampler=sampler,
    wavelengths=wavelengths,
    device=device,
)

HBox(children=(VBox(children=(HBox(children=(FloatText(value=-0.612, description='L1-PL3-PZM52_Setpoint', layo…

Output()

In [11]:
# Render a sampling panel; each call builds its own widgets and output area.
create_sampling_ui(
    parameter_bounds=parameter_bounds,
    default_values=default_values,
    sampler=sampler,
    wavelengths=wavelengths,
    device=device,
)

HBox(children=(VBox(children=(HBox(children=(FloatText(value=-0.612, description='L1-PL3-PZM52_Setpoint', layo…

Output()

In [14]:
# Render a sampling panel; each call builds its own widgets and output area.
create_sampling_ui(
    parameter_bounds=parameter_bounds,
    default_values=default_values,
    sampler=sampler,
    wavelengths=wavelengths,
    device=device,
)

HBox(children=(VBox(children=(HBox(children=(FloatText(value=-0.612, description='L1-PL3-PZM52_Setpoint', layo…

Output()