In [None]:
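# Uncomment the line below to install the required packages (%pip installs into the running kernel's environment).
# NOTE(review): the vendored SDXL pipeline presumably also needs `diffusers` and `transformers` — confirm against vendor/.
# %pip install torch ipywidgets ipyfilechooser ipython Pillow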

In [1]:
import os
import torch
from vendor.lpw_stable_diffusion_xl import StableDiffusionXLLongPromptWeightingPipeline
import ipywidgets as widgets
from ipyfilechooser import FileChooser
from IPython.display import display, clear_output
from datetime import datetime
from PIL import Image

In [2]:
def create_file_choosers():
    """Build the model-loading panel and return it as a ``VBox``.

    The panel contains two file choosers (base model and LoRA weights, both
    filtered to ``*.safetensors``), a "Load Model" button, and a status line
    plus progress bar that stay hidden until the button is first clicked.

    Clicking the button loads the selected checkpoint with
    ``StableDiffusionXLLongPromptWeightingPipeline.from_single_file`` in
    float16, moves it to CUDA, applies the LoRA weights if a file was chosen,
    and only then stores the result in the module-level ``pipe``. If anything
    fails, the previously loaded ``pipe`` (if any) is left untouched.

    Returns:
        widgets.VBox: the assembled loader panel.
    """
    import html  # local import: only used to escape error text shown in the HTML status widget

    home_dir = os.path.expanduser("~")  # start browsing from the user's home directory

    model_fc = FileChooser(home_dir, show_hidden=True)
    model_fc.title = 'Select the model file (.safetensors):'
    model_fc.filter_pattern = '*.safetensors'

    weights_fc = FileChooser(home_dir, show_hidden=True)
    weights_fc.title = 'Select the weights file (.safetensors):'
    weights_fc.filter_pattern = '*.safetensors'

    load_model_button = widgets.Button(description='Load Model', layout=widgets.Layout(width='100%'))
    progress_bar = widgets.FloatProgress(value=0.0, min=0.0, max=1.0, bar_style='info', orientation='horizontal')
    status_text = widgets.HTML(value='', placeholder='', description='')

    # Feedback widgets stay hidden until the first load attempt.
    progress_bar.layout.visibility = 'hidden'
    status_text.layout.visibility = 'hidden'

    def on_load_model_button_clicked(b):
        global pipe
        # Reveal the feedback widgets and clear any error styling from a previous attempt.
        progress_bar.layout.visibility = 'visible'
        status_text.layout.visibility = 'visible'
        progress_bar.bar_style = 'info'
        progress_bar.value = 0.0

        if not model_fc.selected:
            progress_bar.bar_style = 'danger'
            status_text.value = 'Please select a model file (.safetensors) first.'
            return

        # Block repeat clicks while the (slow) load is in progress.
        load_model_button.disabled = True
        try:
            progress_bar.value = 0.1
            status_text.value = 'Loading model...'
            loaded = StableDiffusionXLLongPromptWeightingPipeline.from_single_file(
                model_fc.selected, torch_dtype=torch.float16
            ).to("cuda")

            # LoRA is optional: only apply it when a weights file was chosen.
            if weights_fc.selected:
                progress_bar.value = 0.6
                status_text.value = 'Loading lora...'
                loaded.load_lora_weights(weights_fc.selected)
                done_message = 'Model and lora loaded successfully.'
            else:
                done_message = 'Model loaded successfully (no lora selected).'

            # Publish only a fully configured pipeline.
            pipe = loaded
            progress_bar.value = 1.0
            status_text.value = done_message
        except Exception as e:
            progress_bar.bar_style = 'danger'
            # status_text renders HTML; escape so messages containing '<...>' stay visible.
            status_text.value = f'An error occurred: {html.escape(str(e))}'
        finally:
            load_model_button.disabled = False

    load_model_button.on_click(on_load_model_button_clicked)
    file_chooser_box = widgets.VBox([model_fc, weights_fc, load_model_button, status_text, progress_bar])
    return file_chooser_box

In [3]:
def create_widgets():
    """Instantiate every input/output widget of the GUI as module-level globals.

    Creates one ``Label`` per field, the prompt and negative-prompt
    ``Textarea``s (pre-filled with example text), sliders for steps, height,
    width, guidance scale and seed, the Generate/Save buttons, and the
    ``Output`` area. ``create_gui`` and the button callbacks read these globals.
    """
    global prompt_label, negative_prompt_label, iterations_label, height_label, width_label, guidance_scale_label, seed_label
    global prompt_textbox, negative_prompt_textbox, iterations_slider, height_slider, width_slider, guidance_scale_slider, seed_slider
    global generate_button, save_button, output

    def full_width():
        # A fresh Layout for each widget so no two widgets share layout state.
        return widgets.Layout(width='100%')

    def prompt_area(default_text, hint):
        # Both prompt boxes share the same editable, 100px-tall shape.
        return widgets.Textarea(
            value=default_text,
            placeholder=hint,
            disabled=False,
            layout=widgets.Layout(width='99%', height='100px'),
        )

    captions = ('Prompt:', 'Neg Prompt:', 'Iterations:', 'Height:', 'Width:', 'Guidance Scale:', 'Seed:')
    (prompt_label, negative_prompt_label, iterations_label, height_label,
     width_label, guidance_scale_label, seed_label) = (widgets.Label(caption) for caption in captions)

    prompt_textbox = prompt_area(
        "pokemon pikachu, active pose, anime, sketch, cyberpunk, abstract brush strokes, heavy lines, HDR Neo-Noir style, bright yellow neon, minimal <lora:xl_more_art-full:1>",
        'Enter the prompt',
    )
    negative_prompt_textbox = prompt_area(
        "(deformed, distorted, disfigured:1.3), poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, (mutated hands and fingers:1.4), disconnected limbs, mutation, mutated, ugly, disgusting, blurry, amputation,",
        'Enter the negative prompt',
    )

    iterations_slider = widgets.IntSlider(min=1, max=50, value=30, layout=full_width())
    height_slider = widgets.IntSlider(min=256, max=1024, value=1024, step=8, layout=full_width())
    width_slider = widgets.IntSlider(min=256, max=1024, value=768, step=8, layout=full_width())
    guidance_scale_slider = widgets.FloatSlider(min=1.0, max=10.0, value=7.5, step=0.1, layout=full_width())
    seed_slider = widgets.IntSlider(min=1, max=1000, value=80, layout=full_width())

    generate_button = widgets.Button(description="Generate Image", layout=full_width())
    output = widgets.Output(layout=widgets.Layout(width='100%', margin='0 auto'))
    save_button = widgets.Button(description="Save Image", layout=full_width())

def _snap_dimension(value, max_size):
    """Cap ``value`` at ``max_size``, floor it to a multiple of 8, and keep it at least 8."""
    # diffusers Stable Diffusion pipelines reject sizes not divisible by 8.
    snapped = min(value, max_size)
    snapped -= snapped % 8
    return max(snapped, 8)


def generate_image(prompt, negative_prompt, num_iterations, height, width, guidance_scale, seed,
                   *, pipeline=None, max_size=1024, max_steps=50):
    """Run the diffusion pipeline once and return the first generated image.

    Args:
        prompt: positive prompt text passed straight to the pipeline.
        negative_prompt: negative prompt text passed as ``negative_prompt``.
        num_iterations: number of denoising steps; capped at ``max_steps``.
        height: requested height in pixels; capped at ``max_size`` and floored to a multiple of 8 (minimum 8).
        width: requested width in pixels; same rules as ``height``.
        guidance_scale: guidance strength forwarded as ``guidance_scale``.
        seed: seed for a ``torch.Generator`` so a given seed reproduces the same image.
        pipeline: callable pipeline to use; defaults to the global ``pipe`` set by the "Load Model" button.
        max_size: upper bound for each image side (default 1024, the previous hard-coded limit).
        max_steps: upper bound for ``num_iterations`` (default 50, the previous hard-coded limit).

    Returns:
        ``pipeline(...).images[0]`` — presumably a PIL image for diffusers pipelines (not verified here).

    Raises:
        RuntimeError: if no ``pipeline`` was given and no model has been loaded yet.
    """
    if pipeline is None:
        pipeline = globals().get('pipe')
    if pipeline is None:
        # Previously this surfaced as a bare NameError for `pipe`.
        raise RuntimeError("No model loaded. Select a model file and click 'Load Model' first.")

    kwargs = {
        'height': _snap_dimension(height, max_size),
        'width': _snap_dimension(width, max_size),
        'num_inference_steps': min(num_iterations, max_steps),
        'guidance_scale': guidance_scale,
        'negative_prompt': negative_prompt,
        # Explicitly seeded generator so results are reproducible per seed.
        'generator': torch.Generator().manual_seed(seed),
    }
    return pipeline(prompt, **kwargs).images[0]

def on_generate_button_click(b):
    """Callback for the "Generate Image" button.

    Reads the current widget values, calls ``generate_image`` and displays the
    result in the ``output`` area. The result is also stored in the global
    ``image`` so the "Save Image" button can write it to disk; ``image`` is only
    replaced when generation succeeds.

    Args:
        b: the clicked ``Button`` supplied by ``on_click`` (unused).
    """
    global image  # shared with on_save_button_click
    with output:
        clear_output(wait=True)
        if 'pipe' not in globals():
            # Without this guard the user sees a bare NameError traceback for `pipe`.
            print("No model loaded. Select a model file and click 'Load Model' first.")
            return

        # Disable the button during the slow diffusion run so extra clicks
        # don't queue additional generations; always re-enable afterwards.
        generate_button.disabled = True
        try:
            image = generate_image(
                prompt_textbox.value,
                negative_prompt_textbox.value,
                iterations_slider.value,
                height_slider.value,
                width_slider.value,
                guidance_scale_slider.value,
                seed_slider.value
            )
        finally:
            generate_button.disabled = False
        display(image)

def on_save_button_click(b):
    """Callback for the "Save Image" button.

    Saves the most recently generated image (the global ``image``) to the
    current working directory as ``generated_image_<YYYYmmdd_HHMMSS>.png`` and
    reports the file name; if nothing has been generated yet, says so instead.

    Args:
        b: the clicked ``Button`` supplied by ``on_click`` (unused).
    """
    # Report inside the Output widget: on some frontends (e.g. JupyterLab),
    # print() output from widget callbacks is not displayed anywhere.
    with output:
        if 'image' in globals():
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            filename = f"generated_image_{timestamp}.png"
            image.save(filename)
            print(f"Image saved as {filename}")
        else:
            print("No image to save. Please generate an image first.")

In [4]:
def create_gui():
    """Assemble the complete generator UI and render it in the current cell.

    Creates the global widgets, builds the model-loader panel, wires the
    Generate/Save callbacks, stacks each field label above its control, and
    displays everything in a centered column with the shared ``output`` area
    at the bottom.
    """
    create_widgets()
    loader_panel = create_file_choosers()
    generate_button.on_click(on_generate_button_click)
    save_button.on_click(on_save_button_click)

    def stacked(*children, **extra_layout):
        # Full-width vertical box with its own Layout instance.
        return widgets.VBox(list(children), layout=widgets.Layout(width='100%', **extra_layout))

    labeled_controls = [
        (prompt_label, prompt_textbox),
        (negative_prompt_label, negative_prompt_textbox),
        (iterations_label, iterations_slider),
        (height_label, height_slider),
        (width_label, width_slider),
        (guidance_scale_label, guidance_scale_slider),
        (seed_label, seed_slider),
    ]
    field_boxes = [stacked(label, control) for label, control in labeled_controls]
    button_boxes = [stacked(button, justify_content='center') for button in (generate_button, save_button)]

    form_column = widgets.VBox(
        [loader_panel, *field_boxes, *button_boxes],
        layout=widgets.Layout(width='80%'),
    )

    title_banner = widgets.HTML("<h2 style='text-align: center;'>Stable Diffusion Image Generator</h2>")

    page = widgets.VBox(
        [title_banner, form_column, output],
        layout=widgets.Layout(align_items='center', justify_content='space-between', width='50%', margin='auto'),
    )

    display(page)

# A notebook kernel runs cells with __name__ == '__main__', so executing this cell renders the UI.
if '__main__' == __name__:
    create_gui()

VBox(children=(HTML(value="<h2 style='text-align: center;'>Stable Diffusion Image Generator</h2>"), VBox(child…

Image saved as generated_image_20240410_022523.png
