In [None]:
from IPython.display import clear_output
from ipywidgets import widgets, Layout, HBox
import os
import sys
import glob
from datetime import datetime, timezone
import json
import shutil

def project_setup(
        dataset = "person_ddim",
        uploaded_images = (),
        project_name = "project_name",
        max_training_steps = 2000,
        class_word = "class_word",
        face_training = False,
        token = "token",
        config = "v1-finetune_unfrozen.yaml",
):
    """Prepare the working directory and write the config for a training run.

    Steps, in order:

    1. Clone the regularization-image repository for ``dataset`` (skipped when
       its folder already exists) and move its images into
       ``regularization_images/<dataset>``.
    2. Recreate ``./training_images`` from scratch and write every uploaded
       image into it, then display the images side by side.
    3. Write a timestamped JSON config file into the current directory and
       copy it into ``./joepenna-dreambooth-configs``, both under its own
       name and as ``active-config.json``.

    Problems (no uploaded images, failed clone) are reported on stderr and the
    function returns early without raising.

    Args:
        dataset: Name of the regularization image set; used as the suffix of
            the repository name and as the image sub-folder name.
        uploaded_images: Uploaded files; each item must expose ``name`` and
            ``content`` attributes.
        project_name: Recorded in the config and used in the config file name.
        max_training_steps: Recorded in the config.
        class_word: Recorded in the config.
        face_training: Recorded in the config.
        token: Recorded in the config.
        config: Recorded in the config.

    Returns:
        None.
    """
    # Imported here so the function does not depend on the cell's import block.
    import subprocess

    if not uploaded_images:
        print("No training images provided, please click the 'Training Images' upload button.", file=sys.stderr)
        return

    # Download Regularization Images
    regularization_images_git_folder = f"./Stable-Diffusion-Regularization-Images-{dataset}"
    if not os.path.exists(regularization_images_git_folder):
        clone_url = f"https://github.com/djbielejeski/Stable-Diffusion-Regularization-Images-{dataset}.git"
        print(f"Downloading regularization images from {clone_url} ...")
        clone_error = None
        try:
            # Argument list (no shell) so `dataset` can never be interpreted
            # as shell syntax; output is captured to report failures.
            clone_result = subprocess.run(
                ["git", "clone", clone_url],
                capture_output=True,
                text=True,
            )
        except OSError as err:
            # git is not installed or could not be started.
            clone_error = str(err)
        else:
            if clone_result.returncode != 0:
                clone_error = clone_result.stderr.strip()
        clear_output()

        if clone_error is not None:
            print(f"Unable to download the '{dataset}' regularization images: {clone_error}", file=sys.stderr)
            return

    regularization_images_root_folder = "regularization_images"
    regularization_images_dataset_folder = f"{regularization_images_root_folder}/{dataset}"
    # Creates the root folder as well when it is missing.
    os.makedirs(regularization_images_dataset_folder, exist_ok=True)

    cloned_images_folder = f"{regularization_images_git_folder}/{dataset}"
    for file_name in os.listdir(cloned_images_folder):
        # Move to an explicit file path: moving into a directory raises when a
        # file of the same name is already there (e.g. after a re-clone).
        shutil.move(
            os.path.join(cloned_images_folder, file_name),
            os.path.join(regularization_images_dataset_folder, file_name),
        )

    # Training images: always start from an empty directory so images from a
    # previous run never leak into this one.
    training_images_save_path = "./training_images"
    if os.path.exists(training_images_save_path):
        shutil.rmtree(training_images_save_path)
    os.mkdir(training_images_save_path)

    images = []
    image_widgets = []
    for img in uploaded_images:
        # Keep only the final path component so an upload name can never
        # point outside the training images directory.
        image_name = os.path.basename(img.name)
        images.append(image_name)
        image_widgets.append(widgets.Image(
            value = img.content
        ))
        with open(f"{training_images_save_path}/{image_name}", "wb") as image_file:
            image_file.write(img.content)

    display(HBox(image_widgets))

    # setup our values for training
    config_date_time = datetime.now(timezone.utc).strftime("%Y-%m-%dT%H-%M-%S")
    project_config = {
        "date_utc": config_date_time,
        "dataset": dataset,
        "project_name": project_name,
        "max_training_steps": max_training_steps,
        "training_images_count": len(images),
        "training_images": images,
        "class_word": class_word,
        "face_training": face_training,
        "token": token,
        "config": config,
    }

    project_config_json = json.dumps(project_config, indent=4)
    project_config_filename = f"{config_date_time}-{project_name}-joepenna-dreambooth-config.json"
    with open(project_config_filename, "w") as config_file:
        config_file.write(project_config_json)

    config_save_path = "./joepenna-dreambooth-configs"
    os.makedirs(config_save_path, exist_ok=True)

    # Keep a dated copy and refresh the "active" config.
    shutil.copy(project_config_filename, f"{config_save_path}/{project_config_filename}")
    shutil.copy(project_config_filename, f"{config_save_path}/active-config.json")

    print(f"✅ {project_config_filename} successfully generated.  Proceed to training.")
    print(project_config_json)


def submit_form_click(b):
    """Click handler for the Save button: run ``project_setup`` with the form values.

    The work happens inside the ``output`` widget's context, after that
    context's previous content has been cleared.

    Args:
        b: Argument supplied by the button's click callback; not used.
    """
    with output:
        clear_output()
        # Read each control's current value, keyed by the matching
        # project_setup parameter name.
        form_values = {
            "dataset": reg_images_select.value,
            "uploaded_images": training_images_uploader.value,
            "project_name": project_name_input.value,
            "max_training_steps": max_training_steps_input.value,
            "class_word": class_word_input.value,
            "face_training": i_am_training_a_persons_face_checkbox.value,
            "token": token_input.value,
            "config": config_select.value,
        }
        project_setup(**form_values)


# Shared appearance for every form control.
style = {'description_width': '150px'}
layout = Layout(width="400px")

# Upload control for the training images; several image files may be selected.
training_images_uploader = widgets.FileUpload(
    accept='image/*',
    multiple=True,
    description='Training Images',
    tooltip='Training Images',
    button_style='warning',
    style=style,
    layout=layout,
)

# Regularization image set; the selection is passed to project_setup as `dataset`.
reg_images_select = widgets.Dropdown(
    options=[
        "man_euler",
        "man_unsplash",
        "300_person_ddim",
        "person_ddim",
        "woman_ddim",
        "blonde_woman",
    ],
    value="person_ddim",
    description="Regularization Images",
    style=style,
    layout=layout,
)

project_name_input = widgets.Text(
    placeholder='Project Name',
    description='Project Name:',
    value='ProjectName',
    style=style,
    layout=layout,
)

max_training_steps_input = widgets.BoundedIntText(
    value=2000,  # default value
    min=0,  # min value
    max=100000,  # max value
    step=100,  # increment size
    description='Max Training Steps: ',  # field label
    tooltip='Max Training Steps',
    style=style,
    layout=layout,
)

# typical uses are "man", "person", "woman"
class_word_input = widgets.Text(
    value='person',
    placeholder='man / person / woman / etc',
    description='Class Word:',
    style=style,
    layout=layout,
)

# NOTE(review): `button_style` and `icon` look like button options rather
# than checkbox options -- confirm that Checkbox actually uses them.
i_am_training_a_persons_face_checkbox = widgets.Checkbox(
    value=False,
    description='Training a persons face?',
    tooltip='Training a persons face?',
    button_style='info',
    icon='check',
    style=style,
    layout=layout,
)

token_input = widgets.Text(
    value='firstNameLastName',
    placeholder='firstNameLastName',
    description='Token:',
    style=style,
    layout=layout,
)

# Training config file name; passed to project_setup as `config`.
config_select = widgets.Dropdown(
    options=[
        "v1-finetune_unfrozen.yaml",
        "v1-finetune_unfrozen_save_checkpoints_every_250_steps.yaml",
        "v1-finetune_unfrozen_save_checkpoints_every_500_steps.yaml",
    ],
    value="v1-finetune_unfrozen.yaml",
    description="Config",
    style=style,
    layout=layout,
)

save_form_button = widgets.Button(
    description="Save",
    disabled=False,
    button_style='success',
    tooltip='Save',
    icon='save',
    style=style,
    layout=layout,
)

# Clicking Save runs submit_form_click.
save_form_button.on_click(submit_form_click)

# Output area used by submit_form_click.
output = widgets.Output()

# All controls, in the order they are shown.
form_widgets = [
    training_images_uploader,
    reg_images_select,
    project_name_input,
    max_training_steps_input,
    class_word_input,
    i_am_training_a_persons_face_checkbox,
    token_input,
    config_select,
    save_form_button,
]

# Display the form; the output area is shown together with the Save button.
for widget in form_widgets:
    if widget is save_form_button:
        display(widget, output)
    else:
        display(widget)