# Importing the utils

In [None]:
!pip install --upgrade diffusers accelerate transformers ipywidgets peft


In [None]:
from diffusers.schedulers import AysSchedules

from IPython.display import display
from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler
import torch
import numpy as np
import ipywidgets as widgets
import os
import json
import datetime

In [None]:
pipe_id = "stabilityai/stable-diffusion-xl-base-1.0"
# Load the SDXL base pipeline in bfloat16 and move it to the GPU
pipe = DiffusionPipeline.from_pretrained(pipe_id, torch_dtype=torch.bfloat16).to("cuda")
# Replace the default scheduler with DPM-Solver++ (SDE variant), built from the
# pipeline's own scheduler config; it needs only about 20 inference steps.
pipe.scheduler = DPMSolverMultistepScheduler.from_config(
    pipe.scheduler.config,
    algorithm_type="sde-dpmsolver++",
    timestep_spacing="linspace",
)

In [None]:
# LoRA adapters layered on the base pipeline. For each entry:
#   'src'         - Hugging Face repo holding the weights
#   'weight_name' - weights file inside that repo
#   'name'        - adapter name registered on the pipeline
#   'trigger'     - phrase appended to the prompt when the adapter is used
loras = [
    {
        'src': 'Fictiverse/Voxel_XL_Lora',
        'weight_name': 'VoxelXL_v1.safetensors',
        'name': 'voxel',
        'trigger': 'voxel style',
    },
    {
        'src': 'CiroN2022/toy-face',
        'weight_name': 'toy_face_sdxl.safetensors',
        'name': 'toy',
        'trigger': 'toy_face',
    },
]

# Register every LoRA on the pipeline under its own adapter name
for lora in loras:
    pipe.load_lora_weights(
        lora['src'],
        weight_name=lora['weight_name'],
        adapter_name=lora['name'],
    )


# Empty the Torch Cache

In [None]:
# Return unoccupied memory held by PyTorch's CUDA caching allocator to the device;
# tensors that are still referenced (e.g. the loaded pipeline) are not freed.
torch.cuda.empty_cache()

# Creating the interface

## Creating the widgets

In [None]:
saved_directory = 'images_saved/txt2img'
saved_configs_path = 'configs_saved.json'
# Create the image directory (and any missing parents). exist_ok=True makes this a
# no-op when it already exists, without a separate exists() check.
os.makedirs(saved_directory, exist_ok=True)
# Initialise the parameter store as an empty JSON object if it doesn't exist yet
if not os.path.exists(saved_configs_path):
    with open(saved_configs_path, 'w') as f:
        json.dump({}, f)

# Shared keyword arguments for the labelled input widgets, so their descriptions line up.
# The margin lives on the Layout. Passed as a bare widget keyword (like the old `_css`
# option), it is not a widget trait and had no effect.
# NOTE: `layout` is ONE Layout instance shared by every widget built with **align_kw.
# Mutating `some_widget.layout.<attr>` changes all of them, so give a widget its own
# `widgets.Layout(...)` when it needs different sizing.
align_kw = dict(
    style={'description_width': 'auto'},
    layout=widgets.Layout(
        width='auto',
        height='auto',
        margin='0px 0px 5px 12px',
        flex_flow='row',
        align_items='center',
        display='flex',
    ),
)
# Create widgets

# TAB 1

# Positive prompt describing the image to generate
prompt_widget = widgets.Text(
    value='A astronaut on the moon',
    placeholder='Enter your prompt here',
    description='Prompt:',
    disabled=False,
    **align_kw
)

# Optional negative prompt
negative_prompt_widget = widgets.Text(
    value='',
    placeholder='Negative prompt',
    description='Negative Prompt:',
    disabled=False,
    **align_kw
)

# Bounded widgets enforce the intended limits. IntText has no `min` trait, so
# `min=` on it was silently ignored. A value of -1 requests a random seed.
seed_widget = widgets.BoundedIntText(
    value=0,
    min=-1,
    max=2**32 - 1,
    step=1,
    description='Seed (-1 is random):',
    disabled=False,
    **align_kw
)

inference_steps_widget = widgets.BoundedIntText(
    value=20,
    min=1,
    max=1000,
    step=1,
    description='Inference Steps (for this Scheduler 20 is recommended):',
    disabled=False,
    **align_kw
)

# Classifier-free guidance scale
guidance_widget = widgets.FloatSlider(
    value=10.0,
    min=0.0,
    max=50.0,
    step=0.01,
    description="Guidance Scale:",
    disabled=False,
    **align_kw
)
# Give the slider its own Layout. align_kw['layout'] is a single object shared by every
# widget above, so setting `guidance_widget.layout.width` would resize all of them.
guidance_widget.layout = widgets.Layout(
    width='400px', height='auto', flex_flow='row', align_items='center', display='flex'
)

# One strength slider and one On/Off checkbox per entry in `loras`, index-aligned with it
lora_strength_widgets = []
lora_activate_widgets = []
for lora in loras:
    strength_slider = widgets.FloatSlider(
        value=1.0,
        min=0.0,
        max=1.0,
        step=0.01,
        description=f"{lora['name']} Strength:",
        disabled=False,
        **align_kw
    )
    # Own Layout per slider: align_kw['layout'] is shared by every **align_kw widget,
    # so mutating `strength_slider.layout.width` would resize all of them.
    strength_slider.layout = widgets.Layout(
        width='400px', height='auto', flex_flow='row', align_items='center', display='flex'
    )
    lora_strength_widgets.append(strength_slider)
    lora_activate_widgets.append(
        widgets.Checkbox(
            value=True,
            description="On/Off",
            disabled=False,
            indent=True
        )
    )
    
# When ticked, each generated image (and its generation parameters) is saved to disk
save_image_widget = widgets.Checkbox(
    description='Save image',
    value=True,
    indent=True,
    disabled=False,
)

# Starts a generation run when clicked
generate_button = widgets.Button(
    description='Generate Image',
    tooltip='Click to generate image',
    button_style='',  # 'success', 'info', 'warning', 'danger' or ''
    disabled=False,
)

# Area that receives the generation log and the resulting image
output_image = widgets.Output()

# TAB 2

def _make_button(label, hint):
    """Return an enabled, unstyled button with the given label and tooltip."""
    return widgets.Button(
        description=label,
        tooltip=hint,
        button_style='',  # 'success', 'info', 'warning', 'danger' or ''
        disabled=False,
    )


# Reloads the list of saved images
refresh_button = _make_button('Refresh', 'Click to refresh the file list')
# Deletes the currently selected image
delete_button = _make_button('Delete', 'Click to delete the selected file')

# List of saved image files; starts empty and is populated by refresh_list_file
file_select = widgets.Select(
    description='Select File:',
    options=[],
    value=None,  # Default selected file
    disabled=False,
)

# PNG preview of the selected image
viewer = widgets.Image(
    format='png',
    value=b'',
    width=512,
    height=512,
)

# Text box showing the generation parameters stored for the selected image
text_viewer = widgets.Textarea(
    description='Parameters:',
    placeholder='Parameters',
    value='',
    disabled=False,
    layout=widgets.Layout(
        width='auto',
        height='auto',
        flex_flow='column',
        align_items='center',
        display='flex',
    ),
)

# Define the callback function
def on_file_selected(change):
    """Show the selected saved image and its stored parameters.

    Registered below as a traitlets observer on ``file_select.value``.

    Args:
        change: Traitlets change notification (a dict with ``'type'``, ``'name'``
            and ``'new'`` keys, among others).
    """
    if change['type'] != 'change' or change['name'] != 'value':
        return
    selected_file = change['new']
    # The value becomes None when the option list is cleared (e.g. no saved files left)
    if selected_file is None:
        return

    file_path = os.path.join(saved_directory, selected_file)
    with open(file_path, "rb") as image_file:
        new_image = image_file.read()

    with open(saved_configs_path, 'r') as config_file:
        configs = json.load(config_file)
    # Files without a stored entry show empty parameters instead of raising NameError
    new_config = str(configs[file_path]) if file_path in configs else ''

    # Update the image and parameters
    viewer.value = new_image
    text_viewer.value = new_config


# Attach the observer to the Select widget
file_select.observe(on_file_selected, names='value')

def refresh_list_file(b):
    """Reload the list of saved images and show the first one, if any.

    Args:
        b: The clicked Button when used as an ``on_click`` handler (unused);
            ``None`` when called directly.
    """
    file_list = sorted(os.listdir(saved_directory))
    config_default = ""
    file_default = None
    image_default = b""
    if file_list:
        file_default = file_list[0]
        file_path = os.path.join(saved_directory, file_default)
        with open(file_path, "rb") as image_file:
            image_default = image_file.read()
        with open(saved_configs_path, 'r') as config_file:
            configs = json.load(config_file)
        if file_path in configs:
            config_default = str(configs[file_path])

    # Select.options must be an iterable; an empty list (not None) clears it
    file_select.options = file_list
    # Update the Select widget and the previews
    file_select.value = file_default
    viewer.value = image_default
    text_viewer.value = config_default


refresh_list_file(None)
refresh_button.on_click(refresh_list_file)

def delete_selected_file(b):
    """Delete the selected image and its stored parameters, then refresh the list.

    Args:
        b: The clicked Button (unused).
    """
    selected_file = file_select.value
    if not selected_file:
        return
    file_path = os.path.join(saved_directory, selected_file)
    # The image may already be gone (e.g. removed outside the UI); still drop its entry
    try:
        os.remove(file_path)
    except FileNotFoundError:
        pass
    # Remove the config from the saved_configs_path
    with open(saved_configs_path, 'r') as f:
        configs = json.load(f)
    configs.pop(file_path, None)
    with open(saved_configs_path, 'w') as f:
        json.dump(configs, f, indent=4)
    # Refresh the file list
    refresh_list_file(None)


delete_button.on_click(delete_selected_file)

tabs = widgets.Tab()

# One row per LoRA: On/Off checkbox next to its strength slider. HBox already lays
# children out in a flex row. Extra layout settings must go through `layout=`,
# because bare keywords such as `flex_flow=`/`justify=` are not widget traits and are ignored.
lora_rows = [
    widgets.HBox([activate, strength], layout=widgets.Layout(justify_content='flex-start'))
    for activate, strength in zip(lora_activate_widgets, lora_strength_widgets)
]

# Generation tab: prompt/model controls, LoRA rows, save option, button, output
generation_tab = widgets.VBox(
    [
        widgets.Label(value="Prompt & Model Parameters"),
        prompt_widget,
        negative_prompt_widget,
        seed_widget,
        inference_steps_widget,
        guidance_widget,
        widgets.Label(value="LoRA Strengths"),
        *lora_rows,
        save_image_widget,
        generate_button,
        output_image,
    ]
)

# Saved Images tab: file list with Refresh/Delete, then image preview and parameters
saved_images_tab = widgets.VBox(
    [
        widgets.HBox([file_select, widgets.VBox([refresh_button, delete_button])]),
        widgets.HBox([viewer, text_viewer]),
    ]
)

tabs.children = [generation_tab, saved_images_tab]

tabs.set_title(0, "Generation")
tabs.set_title(1, "Saved Images")

## Utility Functions

In [None]:
# Optional explicit timestep schedule. When empty, the scheduler spaces
# `inference_steps_widget.value` steps itself. When non-empty, its length sets the
# number of steps. Uncomment the next line to use the Align Your Steps schedule.
sampling_schedule = []
# sampling_schedule = AysSchedules["StableDiffusionXLTimesteps"]


def generate_image(b):
    """Generate one image from the current widget values and show it in ``output_image``.

    Enabled LoRAs with a non-zero strength are activated and their trigger phrases are
    appended to the prompt. When "Save image" is ticked, the image is written with
    ``save_image`` and its parameters with ``save_config``.

    Args:
        b: The clicked Button (unused).
    """
    output_image.clear_output()
    with output_image:
        # Choosing the seed. -1 draws a random one. Cast to a Python int so it stays
        # JSON-serializable in the saved metadata (numpy.int64 is not). dtype=np.int64
        # keeps the 2**32 - 1 bound valid where NumPy's default integer is 32-bit.
        seed = seed_widget.value
        if seed == -1:
            seed = int(np.random.randint(0, 2**32 - 1, dtype=np.int64))
        generator = torch.Generator(device="cuda").manual_seed(seed)

        # Prompt and LoRA: collect (names, weights) of the adapters to use
        prompt = prompt_widget.value
        active_adapters = ([], [])
        for lora, strength_widget, activate_widget in zip(loras, lora_strength_widgets, lora_activate_widgets):
            weight = strength_widget.value
            if not activate_widget.value or weight <= 0:
                continue
            active_adapters[0].append(lora['name'])
            active_adapters[1].append(weight)
            if lora['trigger'] not in prompt:
                prompt = f"{prompt}, {lora['trigger']}"

        if active_adapters[0]:
            # Re-enable in case an earlier run with no active adapters disabled LoRA
            pipe.enable_lora()
            pipe.set_adapters(active_adapters[0], active_adapters[1])
        else:
            # Otherwise the adapters set by the previous generation would stay active
            pipe.disable_lora()

        print(f"Prompt used: {prompt}")

        # Sampling schedule and steps
        if sampling_schedule:
            num_inference_steps = len(sampling_schedule)
        else:
            num_inference_steps = inference_steps_widget.value

        # Single source for the scale sent to the pipeline and recorded in the metadata
        lora_scale = 0.85

        # Inference
        image = pipe(prompt,
                     negative_prompt=negative_prompt_widget.value,
                     timesteps=sampling_schedule if sampling_schedule else None,
                     num_inference_steps=num_inference_steps,
                     guidance_scale=guidance_widget.value,
                     cross_attention_kwargs={"scale": lora_scale},
                     generator=generator,
                     ).images[0]

        # Save the image and the parameters that produced it.
        # NOTE: the 'nagative_prompt' / 'lora_useds' spellings are kept so new entries
        # match the keys of configs already saved by earlier runs.
        if save_image_widget.value:
            image_path = save_image(image)
            metadata = {'time_steps_spacing': pipe.scheduler.config.get('timestep_spacing'),
                        'inference_steps': num_inference_steps,
                        'guidance_scale': guidance_widget.value,
                        'prompt': prompt,
                        'nagative_prompt': negative_prompt_widget.value,
                        'seed': seed,
                        'lora_useds': active_adapters,
                        'cross_attention_kwargs': {"scale": lora_scale},
                        }
            save_config(image_path, metadata)

        display(image)

def save_image(image):
    """Save ``image`` as a timestamped PNG in ``saved_directory`` and return its path.

    The path is built with ``os.path.join``, the same way the Saved Images callbacks
    build the keys they look up in the config store. This keeps the returned path
    usable as that key on every OS; a hard-coded "/" did not match on Windows.

    Args:
        image: Object with a ``save(path)`` method (the pipeline's output image).

    Returns:
        The path the image was written to.
    """
    # Format the current time as a string (second resolution)
    formatted_time = datetime.datetime.now().strftime("%Y_%m_%d_%H_%M_%S")
    image_path = os.path.join(saved_directory, f"image_{formatted_time}.png")
    image.save(image_path)
    return image_path

def _json_default(obj):
    """Convert NumPy scalars (e.g. a ``numpy.int64`` seed) to plain Python values for ``json``."""
    if isinstance(obj, np.generic):
        return obj.item()
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")


def save_config(filename, metadata, configs_path=None):
    """Store ``metadata`` under the key ``filename`` in the JSON config store.

    Args:
        filename: Key for the entry (the saved image path returned by ``save_image``).
        metadata: Generation parameters. NumPy scalars are converted to Python numbers;
            NaN/inf values are rejected (``allow_nan=False``).
        configs_path: JSON file to update. Defaults to ``saved_configs_path``.

    Raises:
        TypeError, ValueError: If ``metadata`` cannot be serialized. The existing file
            is left untouched in that case.
    """
    if configs_path is None:
        configs_path = saved_configs_path
    configs = {}
    if os.path.exists(configs_path):
        with open(configs_path, 'r') as f:
            configs = json.load(f)
    configs[filename] = metadata
    # Serialize fully before touching the file. json.dump into a file opened with 'w'
    # truncates it first, so a serialization error would wipe every saved config.
    payload = json.dumps(configs, indent=4, allow_nan=False, default=_json_default)
    tmp_path = f"{configs_path}.tmp"
    with open(tmp_path, 'w') as f:
        f.write(payload)
    os.replace(tmp_path, configs_path)
  

# Displaying the interface

In [None]:
# Initial prompt shown in the Generation tab
default_prompt = 'A deep forest with monkeys, detailed, realistic, 8k'

# Start from an empty output area and wire the Generate button to generate_image
output_image.clear_output()
generate_button.on_click(generate_image)

# Pre-fill the prompt, then render the tabbed interface
prompt_widget.value = default_prompt
display(tabs)