In [1]:
import pickle
import numpy as np
import matplotlib.pyplot as plt
from ipywidgets import interact, SelectionSlider

# Load the data from file
with open("waveform_results_sweeps.pkl", "rb") as f:
    waveform_dict = pickle.load(f)


In [2]:
print(waveform_dict.keys())

dict_keys([(100, 0, 0.5, 'quadratic'), (100, 0, 0.5, 'linear'), (100, 0, 1.0, 'quadratic'), (100, 0, 1.0, 'linear'), (100, 0, 1.5, 'quadratic'), (100, 0, 1.5, 'linear'), (100, 0, 2.5, 'quadratic'), (100, 0, 2.5, 'linear'), (100, 0.2, 0.5, 'quadratic'), (100, 0.2, 0.5, 'linear'), (100, 0.2, 1.0, 'quadratic'), (100, 0.2, 1.0, 'linear'), (100, 0.2, 1.5, 'quadratic'), (100, 0.2, 1.5, 'linear'), (100, 0.2, 2.5, 'quadratic'), (100, 0.2, 2.5, 'linear'), (100, 0.5, 0.5, 'quadratic'), (100, 0.5, 0.5, 'linear'), (100, 0.5, 1.0, 'quadratic'), (100, 0.5, 1.0, 'linear'), (100, 0.5, 1.5, 'quadratic'), (100, 0.5, 1.5, 'linear'), (100, 0.5, 2.5, 'quadratic'), (100, 0.5, 2.5, 'linear'), (100, 0.8, 0.5, 'quadratic'), (100, 0.8, 0.5, 'linear'), (100, 0.8, 1.0, 'quadratic'), (100, 0.8, 1.0, 'linear'), (100, 0.8, 1.5, 'quadratic'), (100, 0.8, 1.5, 'linear'), (100, 0.8, 2.5, 'quadratic'), (100, 0.8, 2.5, 'linear'), (100, 1.0, 0.5, 'quadratic'), (100, 1.0, 0.5, 'linear'), (100, 1.0, 1.0, 'quadratic'), (100, 

In [4]:

# Extract unique parameter choices (assuming keys are tuples: (guidance_scale, steps, eta, condition))
steps_list = sorted(set(key[0] for key in waveform_dict.keys()))
etas = sorted(set(key[1] for key in waveform_dict.keys()))
guidance_scales = sorted(set(key[2] for key in waveform_dict.keys()))
methods = sorted(set(key[3] for key in waveform_dict.keys()))
conds = sorted(set(key for key in list(waveform_dict.values())[0].keys()))

def plot_for_params(steps, eta, guidance_scale, method, cond):
    config_key = (steps, eta, guidance_scale, method)
    if config_key not in waveform_dict:
        print(f"Parameters {config_key} not found.")
        return
    wfs = waveform_dict[config_key][cond]  # shape (batch_size, 512)
    x = np.arange(wfs.shape[1])
    x = np.flip(x, -1)
    plt.figure(figsize=(9, 6))
    for wf in wfs:
        plt.plot(x, wf, alpha=0.1)
    plt.title(f"Guidance: {guidance_scale}, Steps: {steps}, Eta: {eta}, Tree-cover: {cond}")
    plt.xlim(0, wfs.shape[1] - 1)
    plt.ylim(0, 1.2)
    plt.show()


slider_steps = SelectionSlider(
    options=steps_list,
    description="Steps",
    continuous_update=True
)

slider_eta = SelectionSlider(
    options=etas,
    description="Eta",
    continuous_update=True
)

slider_gs = SelectionSlider(
    options=guidance_scales,
    description="Guidance Scale",
    continuous_update=True
)

slider_method = SelectionSlider(
    options=methods,
    description="Method",
    continuous_update=True
)

slider_cond = SelectionSlider(
    options=conds,
    description="Tree-cover",
    continuous_update=True
)

interact(plot_for_params, steps=slider_steps, eta=slider_eta, guidance_scale=slider_gs, method=slider_method, cond=slider_cond)


interactive(children=(SelectionSlider(description='Steps', options=(100, 250, 400), value=100), SelectionSlide…

<function __main__.plot_for_params(steps, eta, guidance_scale, method, cond)>