In [1]:
from IPython.display import clear_output
from ipywidgets import widgets, Layout, HBox
import os
import sys
from datetime import datetime, timezone
import json
import shutil

def project_setup(
        dataset = "person_ddim",
        uploaded_images = (),
        project_name = "project_name",
        max_training_steps = 2000,
        class_word = "class_word",
        flip_percent = 0.5,
        token = "token",
        learning_rate = 1.0e-06,
        save_every_x_steps = 0,
):
    """Prepare regularization images, training images and the training config.

    Steps, in order:
      1. Clone the regularization image repository for ``dataset`` (only if
         its folder is not present yet) and move its images into
         ``regularization_images/<dataset>``.
      2. Recreate ``./training_images`` from scratch and write every uploaded
         image into it, then display the uploaded images.
      3. Serialize the settings to JSON and store them in
         ``./joepenna-dreambooth-configs`` twice: under a timestamped file
         name and as ``active-config.json``.

    Args:
        dataset: Name of the regularization image set; used as the suffix of
            the GitHub repository name and as the sub-folder name inside it.
        uploaded_images: Sequence of uploaded files. Each item must expose
            ``.name`` (file name) and ``.content`` (bytes-like image data).
        project_name: Stored in the config and embedded in its file name.
        max_training_steps: Stored in the config as-is.
        class_word: Stored in the config as-is.
        flip_percent: Stored in the config as-is.
        token: Stored in the config as-is.
        learning_rate: Stored in the config as-is.
        save_every_x_steps: Stored in the config as-is.

    Returns:
        None. Missing images and a failed ``git clone`` are reported on
        stderr and abort the setup early.
    """
    # Imported locally so the function does not depend on a module-level
    # import that the surrounding cell does not provide.
    import subprocess

    if not uploaded_images:
        print("No training images provided, please click the 'Training Images' upload button.", file=sys.stderr)
        return

    # Download Regularization Images
    regularization_images_git_folder = f"./Stable-Diffusion-Regularization-Images-{dataset}"
    if not os.path.exists(regularization_images_git_folder):
        clone_url = f"https://github.com/djbielejeski/Stable-Diffusion-Regularization-Images-{dataset}.git"
        try:
            # List arguments (no shell) so `dataset` is never interpreted by a shell.
            clone_result = subprocess.run(
                ["git", "clone", clone_url, regularization_images_git_folder],
                check=False,
            )
        except OSError as err:
            # Raised when the git executable itself cannot be started.
            print(f"Unable to run git: {err}", file=sys.stderr)
            return
        clear_output()
        if clone_result.returncode != 0:
            print(f"Failed to download regularization images from {clone_url}", file=sys.stderr)
            return

    # Creates both "regularization_images" and the per-dataset sub-folder.
    regularization_images_dataset_folder = f"regularization_images/{dataset}"
    os.makedirs(regularization_images_dataset_folder, exist_ok=True)

    regularization_images_source_folder = f"{regularization_images_git_folder}/{dataset}"
    for file_name in os.listdir(regularization_images_source_folder):
        # Move to an explicit file path: moving into a *directory* raises
        # shutil.Error when a file with the same name is already there.
        shutil.move(
            os.path.join(regularization_images_source_folder, file_name),
            os.path.join(regularization_images_dataset_folder, file_name),
        )

    # Training images: always start from an empty directory
    training_images_save_path = "./training_images"
    if os.path.exists(training_images_save_path):
        shutil.rmtree(training_images_save_path)
    os.mkdir(training_images_save_path)

    images = []
    image_widgets = []
    for img in uploaded_images:
        # Keep only the final path component so an unexpected name cannot
        # place the file outside the training images directory.
        image_name = os.path.basename(img.name)
        images.append(image_name)
        image_widgets.append(widgets.Image(value=img.content))
        with open(os.path.join(training_images_save_path, image_name), "wb") as image_file:
            image_file.write(img.content)

    display(HBox(image_widgets))

    # setup our values for training
    config_date_time = datetime.now(timezone.utc).strftime("%Y-%m-%dT%H-%M-%S")
    project_config_filename = f"{config_date_time}-{project_name}-joepenna-dreambooth-config.json"

    project_config = {
        "config_file_name": project_config_filename,
        "date_utc": config_date_time,
        "dataset": dataset,
        "project_name": project_name,
        "max_training_steps": max_training_steps,
        "training_images_count": len(images),
        "training_images": images,
        "class_word": class_word,
        "flip_percent": flip_percent,
        "token": token,
        "learning_rate": learning_rate,
        "save_every_x_steps": save_every_x_steps,
    }
    project_config_json = json.dumps(project_config, indent=4)

    # Write straight into the config folder (no temporary file in the working
    # directory), then duplicate it as the active config.
    config_save_path = "./joepenna-dreambooth-configs"
    os.makedirs(config_save_path, exist_ok=True)

    archived_config_path = f"{config_save_path}/{project_config_filename}"
    with open(archived_config_path, "w", encoding="utf-8") as config_file:
        config_file.write(project_config_json)
    shutil.copy(archived_config_path, f"{config_save_path}/active-config.json")

    print(f"✅ {project_config_filename} successfully generated.  Proceed to training.")
    print(project_config_json)


def submit_form_click(b):
    """Click handler for the Save button: run ``project_setup`` with the form values.

    Everything happens inside the ``output`` widget context, which is cleared
    first so each click replaces the previous result.

    Args:
        b: The argument supplied by the button's click callback (unused).
    """
    with output:
        clear_output()
        # Snapshot the current value of every form control, keyed by the
        # matching project_setup parameter name.
        form_values = {
            "dataset": reg_images_select.value,
            "uploaded_images": training_images_uploader.value,
            "project_name": project_name_input.value,
            "max_training_steps": max_training_steps_input.value,
            "class_word": class_word_input.value,
            "flip_percent": flip_slider.value,
            "token": token_input.value,
            "learning_rate": learning_rate_select.value,
            "save_every_x_steps": save_every_x_steps_input.value,
        }
        project_setup(**form_values)



# Shared appearance settings passed to the form widgets below.
style = dict(description_width="150px")
label_style = dict(font_size="10px", text_color="#777")
input_and_description_layout = Layout(width="812px")  # one input + label row
layout = Layout(width="400px")  # a single input or label

# Widgets are collected here in the order they are displayed.
form_widgets = []

# Training Images: upload button accepting several image files at once
training_images_uploader = widgets.FileUpload(
    description="Training Images",
    tooltip="Training Images",
    accept="image/*",
    multiple=True,
    button_style="warning",
    layout=layout,
    style=style,
)
form_widgets.append(training_images_uploader)

# Regularization Images: dataset dropdown with a recommendation next to it
reg_images_select = widgets.Dropdown(
    description="Regularization Images: ",
    options=["man_euler", "man_unsplash", "person_ddim", "woman_ddim", "artstyle"],
    value="person_ddim",
    layout=layout,
    style=style,
)
reg_images_label = widgets.Label(
    value="person_ddim recommended",
    layout=layout,
    style=label_style,
)

# Dropdown and hint share one row.
reg_images_label_and_input = widgets.HBox(
    children=[reg_images_select, reg_images_label],
    layout=input_and_description_layout,
)
form_widgets.append(reg_images_label_and_input)

# Project Name: free-text field with a help label next to it
project_name_input = widgets.Text(
    placeholder='Project Name',
    description='Project Name: ',
    value='ProjectName',
    style=style,
    layout=layout,
)
project_name_label = widgets.Label(
    # Help text shown to the user ("your trained" corrected to "you trained").
    value = "This isn't used for training, just to help you remember what you trained into the model.",
    style=label_style,
    layout=layout,
)

# Input and help label share one row.
project_name_label_and_input = widgets.HBox(
    [project_name_input, project_name_label],
    layout=input_and_description_layout
)
form_widgets.append(project_name_label_and_input)

# Max Training steps: bounded integer field with a help label next to it
max_training_steps_input = widgets.BoundedIntText(
    description="Max Training Steps: ",  # text shown in front of the field
    tooltip="Max Training Steps",
    value=2000,  # initial value
    min=0,  # lower bound
    max=100000,  # upper bound
    step=100,  # increment size
    layout=layout,
    style=style,
)
max_training_steps_label = widgets.Label(
    value="How many steps do you want to train for?",
    layout=layout,
    style=label_style,
)

# Input and help label share one row.
max_training_steps_label_and_input = widgets.HBox(
    children=[max_training_steps_input, max_training_steps_label],
    layout=input_and_description_layout,
)
form_widgets.append(max_training_steps_label_and_input)

# Learning Rate: dropdown of preset rates with a help label next to it
learning_rate_select = widgets.Dropdown(
    description="Learning Rate: ",
    options=[2.0e-06, 1.5e-06, 1.0e-06, 8.0e-07, 6.0e-07, 5.0e-07, 4.0e-07],
    value=1.0e-06,
    layout=layout,
    style=style,
)
learning_rate_label = widgets.Label(
    value="How fast do you want to train? 1.0e-06 is highly recommended.",
    layout=layout,
    style=label_style,
)

# Dropdown and help label share one row.
learning_rate_label_and_input = widgets.HBox(
    children=[learning_rate_select, learning_rate_label],
    layout=input_and_description_layout,
)
form_widgets.append(learning_rate_label_and_input)


# Class: free-text class word with a help label next to it
class_word_input = widgets.Text(
    description="Class Word: ",
    placeholder="man / person / woman / artstyle / etc",
    value="person",
    layout=layout,
    style=style,
)
class_word_label = widgets.Label(
    value="Typical uses are 'man', 'person', 'woman', or 'artstyle'",
    layout=layout,
    style=label_style,
)

# Input and help label share one row.
class_word_label_and_input = widgets.HBox(
    children=[class_word_input, class_word_label],
    layout=input_and_description_layout,
)
form_widgets.append(class_word_label_and_input)


# Flip slider: float slider from 0 to 1 in steps of 0.05, with a help label
flip_slider = widgets.FloatSlider(
    description="Flip Images %: ",
    value=0.5,
    min=0,
    max=1,
    step=0.05,
    layout=layout,
    style=style,
)
flip_label = widgets.Label(
    value="Set to 0.0 or 0.1 if you are training a person's face.  0.75 is the same as 0.25",
    layout=layout,
    style=label_style,
)

# Slider and help label share one row.
flip_label_and_input = widgets.HBox(
    children=[flip_slider, flip_label],
    layout=input_and_description_layout,
)
form_widgets.append(flip_label_and_input)

# Token: free-text field for the token, with a help label next to it
token_input = widgets.Text(
    value='firstNameLastName',
    placeholder='firstNameLastName',
    description='Token: ',
    style=style,
    layout=layout,
)
token_label = widgets.Label(
    # Help text shown to the user ("Chose" corrected to "Choose").
    value = "Choose your unique token you want to train into stable diffusion (don't use 'sks')",
    style=label_style,
    layout=layout,
)

# Input and help label share one row.
token_label_and_input = widgets.HBox(
    [token_input, token_label],
    layout=input_and_description_layout
)
form_widgets.append(token_label_and_input)


# Save every x steps: bounded integer field with a help label next to it
save_every_x_steps_input = widgets.BoundedIntText(
    description="Save every (x) steps: ",  # text shown in front of the field
    tooltip="Save every (x) steps.  Leave at 0 to only save the final checkpoint",
    value=0,  # initial value
    min=0,  # lower bound
    max=100000,  # upper bound
    step=50,  # increment size
    layout=layout,
    style=style,
)
save_every_x_steps_label = widgets.Label(
    value="Change to save intermediate checkpoints",
    layout=layout,
    style=label_style,
)

# Input and help label share one row.
save_every_x_steps_label_and_input = widgets.HBox(
    children=[save_every_x_steps_input, save_every_x_steps_label],
    layout=input_and_description_layout,
)
form_widgets.append(save_every_x_steps_label_and_input)

# Save button: last entry of the form
save_form_button = widgets.Button(
    description="Save",
    tooltip="Save",
    icon="save",
    button_style="success",
    disabled=False,
    layout=layout,
    style=style,
)
form_widgets.append(save_form_button)

# Clicking the button runs submit_form_click
save_form_button.on_click(submit_form_click)

# Output area that submit_form_click writes into
output = widgets.Output()

# Show the form; the output area is displayed together with the Save button
for form_widget in form_widgets:
    if form_widget is save_form_button:
        display(form_widget, output)
    else:
        display(form_widget)



HBox(children=(Dropdown(description='Regularization Images: ', index=3, layout=Layout(width='400px'), options=…

HBox(children=(Text(value='ProjectName', description='Project Name: ', layout=Layout(width='400px'), placehold…

HBox(children=(BoundedIntText(value=2000, description='Max Training Steps: ', layout=Layout(width='400px'), ma…

HBox(children=(Dropdown(description='Learning Rate: ', index=2, layout=Layout(width='400px'), options=('2.0e-0…

HBox(children=(Text(value='person', description='Class Word: ', layout=Layout(width='400px'), placeholder='man…

HBox(children=(FloatSlider(value=0.5, description='Flip Images %: ', layout=Layout(width='400px'), max=0.5, st…

Text(value='firstNameLastName', description='Token: ', layout=Layout(width='400px'), placeholder='firstNameLas…

HBox(children=(BoundedIntText(value=0, description='Save every (x) steps: ', layout=Layout(width='400px'), max…

Button(button_style='success', description='Save', icon='save', layout=Layout(width='400px'), style=ButtonStyl…

Output()