In [1]:
import ipywidgets as widgets
from IPython.display import display, clear_output
import torch
from src.diffusion import SpacedDiffusion
from src.modules import UNet_conditional
from src.utils import plot_image_pairs, save_samples
from metrics import create_sections_list, cosine_step_schedule

# --- Interactive controls -------------------------------------------------
# Free-text inputs: target torch device and the model selector ("1", "2", "3").
device_widget = widgets.Text(
    description="Device:",
    value="cuda:1",
)
model_widget = widgets.Text(
    description="Model:",
    value="1",
)

# Number of samples generated per click.
slider_n = widgets.IntSlider(
    description="n",
    min=1,
    max=64,
    value=2,
)

# Conditioning values passed to the model as the parameter vector.
slider_e = widgets.IntSlider(
    description="Energy",
    min=0,
    value=15,
)
slider_p = widgets.IntSlider(
    description="Pressure",
    min=0,
    value=15,
)
slider_ms = widgets.IntSlider(
    description="Acquisition time",
    min=1,
    value=20,
)

# Lower and upper x-limits forwarded to the plotting helper.
slider_minx = widgets.IntSlider(
    description="Energy xmin",
    min=1,
    max=100,
    value=2,
)
slider_maxx = widgets.IntSlider(
    description="Energy xmax",
    min=1,
    max=100,
    value=30,
)

# Action buttons and the area that captures printed text and plots.
button = widgets.Button(description="Generate")
output = widgets.Output()
clear_button = widgets.Button(description="Clear Output")

def clear_output_on_click(b):
    """Button callback: wipe whatever is currently shown in ``output``.

    Args:
        b: The button instance supplied by ``on_click``; not used.
    """
    # Entering the Output widget's context makes clear_output act on that
    # widget's display area.
    with output:
        clear_output(wait=False)

# Per-model settings, keyed by the text typed into the "Model:" box.
# ``section_counts`` is a zero-argument callable so the step schedule is only
# built for the model that is actually selected.
_MODEL_CONFIGS = {
    "1": {
        "path": "models/nophys_full/ema_ckpt.pt",
        "section_counts": lambda: create_sections_list(10, 25, cosine_step_schedule),
        "plot_id": 1,
    },
    "2": {
        "path": "models/cossched_full/ema_ckpt.pt",
        "section_counts": lambda: [15],
        "plot_id": 2,
    },
    "3": {
        "path": "models/x_start_phys/ema_ckpt.pt",
        "section_counts": lambda: [15],
        # The original code passed model=2 here, which looks like a copy-paste
        # slip from the model "2" branch.
        # NOTE(review): confirm plot_image_pairs accepts model=3.
        "plot_id": 3,
    },
}


def generate(b):
    """Button callback: load the selected model, sample from it and plot.

    Reads the current widget values, validates the x-range and the model
    name, loads the matching checkpoint into a ``UNet_conditional``, builds a
    ``SpacedDiffusion`` sampler, draws ``n`` samples conditioned on
    (Energy, Pressure, Acquisition time) and passes them to
    ``plot_image_pairs``. All printed text and plots are captured by the
    ``output`` widget.

    Args:
        b: The button instance supplied by ``on_click``; not used.
    """
    # Strip so that accidental whitespace in the text box still selects a model.
    model_name = model_widget.value.strip()
    device = device_widget.value
    n = slider_n.value
    E = slider_e.value
    P = slider_p.value
    ms = slider_ms.value

    xmin = slider_minx.value
    xmax = slider_maxx.value

    with output:
        if xmax <= xmin:
            print("xmax must be above the value of xmin")
            return

        config = _MODEL_CONFIGS.get(model_name)
        if config is None:
            # Previously an unknown name fell through every branch silently.
            valid = ", ".join(_MODEL_CONFIGS)
            print(f"Unknown model {model_name!r}; choose one of: {valid}")
            return

        path = config["path"]
        print("Loading ", path)
        model = UNet_conditional(
            img_width=128, img_height=64, feat_num=3, device=device
        ).to(device)
        ckpt = torch.load(path, map_location=device)
        model.load_state_dict(ckpt)

        sampler = SpacedDiffusion(
            beta_start=1e-4,
            beta_end=0.02,
            noise_steps=1000,
            section_counts=config["section_counts"](),
            img_height=64,
            img_width=128,
            device=device,
            rescale_timesteps=False,
        )

        # Parameter vector of shape (1, 3): [Energy, Pressure, Acquisition time].
        y = torch.tensor([E, P, ms], dtype=torch.float32, device=device).unsqueeze(0)
        x = sampler.ddim_sample_loop(
            model=model, y=y, cfg_scale=1, device=device, eta=1, n=n
        )
        plot_image_pairs(
            x,
            xlim=[xmin, xmax],
            acquisition_time_ms=ms,
            beam_point_y=128,
            beam_point_x=62,
            energy=E,
            pressure=P,
            model=config["plot_id"],
        )


# Register the click handlers: "Generate" runs sampling, "Clear Output"
# empties the output area.
button.on_click(generate)
clear_button.on_click(clear_output_on_click)

# Render all controls top to bottom, with the output area last.
display(
    device_widget,
    model_widget,
    slider_n,
    slider_e,
    slider_p,
    slider_ms,
    slider_minx,
    slider_maxx,
    button,
    clear_button,
    output,
)


Text(value='cuda:1', description='Device:')

Text(value='1', description='Model:')

IntSlider(value=2, description='n', max=64, min=1)

IntSlider(value=15, description='Energy')

IntSlider(value=15, description='Pressure')

IntSlider(value=20, description='Acquisition time', min=1)

IntSlider(value=2, description='Energy xmin', min=1)

IntSlider(value=30, description='Energy xmax', min=1)

Button(description='Generate', style=ButtonStyle())

Button(description='Clear Output', style=ButtonStyle())

Output()