# Importing the utils

In [None]:
!pip install --upgrade diffusers accelerate transformers ipywidgets peft


In [2]:
from IPython.display import display
from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler
import torch
import numpy as np
import ipywidgets as widgets

In [3]:
pipe_id = "stabilityai/stable-diffusion-xl-base-1.0"
pipe = DiffusionPipeline.from_pretrained(pipe_id, torch_dtype=torch.float16).to("cuda")
# More efficient Scheduler (need just 20 inference steps)
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)

Loading pipeline components...:   0%|          | 0/7 [00:00<?, ?it/s]

In [None]:
loras = [{'src': 'Fictiverse/Voxel_XL_Lora', 'weight_name': 'VoxelXL_v1.safetensors','name': 'voxel', 'trigger': 'voxel style'},
         {'src': 'CiroN2022/toy-face', 'weight_name': 'toy_face_sdxl.safetensors','name': 'toy', 'trigger': 'toy_face'}
         ]
# Load LoRA weights
for lora in loras:
    pipe.load_lora_weights(lora['src'], weight_name=lora['weight_name'], adapter_name=lora['name'])




# Empty the Torch Cache

In [5]:
torch.cuda.empty_cache()

# Creating the inteface

## Creating the widgets

In [None]:
# Align the end of the description of the widgets
align_kw = dict(
    _css = (('.widget-label', 'min-width', '20ex'),),
    margin = '0px 0px 5px 12px',
    style = {'description_width': 'auto'},
    layout = widgets.Layout(width='auto', height='auto', flex_flow='row', align_items='center', display='flex')
)
# Create widgets
prompt_widget = widgets.Text(
    value='A female pirate smoking a pipe',
    placeholder='Enter your prompt here',
    description='Prompt:',
    disabled=False,
    **align_kw
)

seed_widget = widgets.IntText(
    value=0,
    min=-1,
    step=1,
    description='Seed (-1 is random):',
    disabled=False,
    **align_kw
)

inference_steps_widget = widgets.IntText(
    value=20,
    
    min=1,
    step=1,
    description='Inference Steps (for this Scheduler 20 is recommanded):',
    disabled=False,
    **align_kw
)

lora_strength_widgets = []
for lora in loras:
    lora_strength_widgets.append(
        widgets.FloatSlider(
            value=1.0,
            min=0.0,
            max=1.0,
            step=0.01,
            description=f"{lora['name']} Strength:",
            disabled=False,
            **align_kw
        )
    )
    lora_strength_widgets[-1].layout.width = '400px'
    

generate_button = widgets.Button(
    description='Generate Image',
    disabled=False,
    button_style='',  # 'success', 'info', 'warning', 'danger' or ''
    tooltip='Click to generate image'
)


output_image = widgets.Output()

## Utils Functions

In [111]:
def generate_image(b):
  with output_image:
    output_image.clear_output()
    seed = seed_widget.value
    if seed == -1:
      seed = np.random.randint(0, 2**32 - 1)
    generator = torch.Generator(device="cuda").manual_seed(seed)
    
    prompt = prompt_widget.value
    adapter_weights=[lora_widget.value for lora_widget in lora_strength_widgets]
    adapter_name = [lora['name'] for lora in loras]
    pipe.set_adapters(adapter_name, adapter_weights=adapter_weights)
    triggers = [lora['trigger'] for lora in loras]
    
    for trigger in triggers:
        if trigger not in prompt:
            prompt = f"{prompt} {trigger}"
            
    print(f"Prompt used: {prompt}")
    image = pipe(prompt,
                  num_inference_steps=inference_steps_widget.value,
                  cross_attention_kwargs={"scale": 0.85},
                  generator=generator
                  ).images[0]
    display(image)

## Displaying the widgets

In [112]:
# Display the widgets
output_image.clear_output()
# Link the button to the function
generate_button.on_click(generate_image)
display(prompt_widget,
        seed_widget,
        inference_steps_widget,
        *lora_strength_widgets,
        generate_button,
        output_image
        )

Text(value='A female pirate on a boat', description='Prompt:', layout=Layout(align_items='center', display='fl…

IntText(value=0, description='Seed (-1 is random):', layout=Layout(align_items='center', display='flex', flex_…

IntText(value=20, description='Inference Steps (for this Scheduler 20 is recommanded):', layout=Layout(align_i…

FloatSlider(value=1.0, description='voxel Strength:', layout=Layout(align_items='center', display='flex', flex…

FloatSlider(value=1.0, description='toy Strength:', layout=Layout(align_items='center', display='flex', flex_f…

Button(description='Generate Image', style=ButtonStyle(), tooltip='Click to generate image')

Output()