From c25a16cde586c77a2a3a2b3ca67dc98515c7020b Mon Sep 17 00:00:00 2001 From: xieyongliang Date: Sat, 4 Mar 2023 19:33:20 +0800 Subject: [PATCH] cleanup --- modules/api/api.py | 68 +++-- modules/call_queue.py | 6 - modules/processing.py | 15 +- modules/shared.py | 15 +- modules/ui.py | 545 ++++++++++---------------------------- requirements_versions.txt | 1 + webui.py | 129 ++++----- 7 files changed, 277 insertions(+), 502 deletions(-) diff --git a/modules/api/api.py b/modules/api/api.py index 9b5aca227fa..70887d7c0a5 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -26,6 +26,9 @@ from typing import Union import traceback import requests +import piexif +import piexif.helper +import numpy as np def upscaler_to_index(name: str): try: @@ -54,23 +57,48 @@ def decode_base64_to_image(encoding): encoding = encoding.split(";")[1].split(",")[1] return Image.open(BytesIO(base64.b64decode(encoding))) +def encode_to_base64(image): + if type(image) is str: + return image + elif type(image) is Image.Image: + return encode_pil_to_base64(image) + elif type(image) is np.ndarray: + return encode_np_to_base64(image) + else: + return "" + def encode_pil_to_base64(image): with io.BytesIO() as output_bytes: - # Copy any text-only metadata - use_metadata = False - metadata = PngImagePlugin.PngInfo() - for key, value in image.info.items(): - if isinstance(key, str) and isinstance(value, str): - metadata.add_text(key, value) - use_metadata = True + if opts.samples_format.lower() == 'png': + use_metadata = False + metadata = PngImagePlugin.PngInfo() + for key, value in image.info.items(): + if isinstance(key, str) and isinstance(value, str): + metadata.add_text(key, value) + use_metadata = True + image.save(output_bytes, format="PNG", pnginfo=(metadata if use_metadata else None), quality=opts.jpeg_quality) + + elif opts.samples_format.lower() in ("jpg", "jpeg", "webp"): + parameters = image.info.get('parameters', None) + exif_bytes = piexif.dump({ + "Exif": { 
piexif.ExifIFD.UserComment: piexif.helper.UserComment.dump(parameters or "", encoding="unicode") } + }) + if opts.samples_format.lower() in ("jpg", "jpeg"): + image.save(output_bytes, format="JPEG", exif = exif_bytes, quality=opts.jpeg_quality) + else: + image.save(output_bytes, format="WEBP", exif = exif_bytes, quality=opts.jpeg_quality) + + else: + raise HTTPException(status_code=500, detail="Invalid image format") - image.save( - output_bytes, "PNG", pnginfo=(metadata if use_metadata else None) - ) bytes_data = output_bytes.getvalue() + return base64.b64encode(bytes_data) +def encode_np_to_base64(image): + pil = Image.fromarray(image) + return encode_pil_to_base64(pil) class Api: def __init__(self, app: FastAPI, queue_lock: Lock): @@ -138,16 +166,13 @@ def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI): shared.state.begin() with self.queue_lock: - processed = process_images(p) - - if p.script_args is not None: - processed = p.scripts.run(p, *p.script_args) + processed = p.scripts.run(p, *p.script_args) if processed is None: processed = process_images(p) shared.state.end() - b64images = list(map(encode_pil_to_base64, processed.images)) + b64images = list(map(encode_to_base64, processed.images)) return TextToImageResponse(images=b64images, parameters=vars(txt2imgreq), info=processed.js()) @@ -185,16 +210,13 @@ def img2imgapi(self, img2imgreq: StableDiffusionImg2ImgProcessingAPI): shared.state.begin() with self.queue_lock: - processed = process_images(p) - - if p.script_args is not None: - processed = p.scripts.run(p, *p.script_args) + processed = p.scripts.run(p, *p.script_args) if processed is None: processed = process_images(p) shared.state.end() - b64images = list(map(encode_pil_to_base64, processed.images)) + b64images = list(map(encode_to_base64, processed.images)) if not img2imgreq.include_init_images: img2imgreq.init_images = None @@ -210,7 +232,7 @@ def extras_single_image_api(self, req: ExtrasSingleImageRequest): with self.queue_lock: 
result = run_extras(extras_mode=0, image_folder="", input_dir="", output_dir="", **reqDict) - return ExtrasSingleImageResponse(image=encode_pil_to_base64(result[0][0]), html_info=result[1]) + return ExtrasSingleImageResponse(image=encode_to_base64(result[0][0]), html_info=result[1]) def extras_batch_images_api(self, req: ExtrasBatchImagesRequest): reqDict = setUpscalers(req) @@ -226,7 +248,7 @@ def prepareFiles(file): with self.queue_lock: result = run_extras(extras_mode=1, image="", input_dir="", output_dir="", **reqDict) - return ExtrasBatchImagesResponse(images=list(map(encode_pil_to_base64, result[0])), html_info=result[1]) + return ExtrasBatchImagesResponse(images=list(map(encode_to_base64, result[0])), html_info=result[1]) def pnginfoapi(self, req: PNGInfoRequest): if(not req.image.strip()): @@ -260,7 +282,7 @@ def progressapi(self, req: ProgressRequest = Depends()): current_image = None if shared.state.current_image and not req.skip_current_image: - current_image = encode_pil_to_base64(shared.state.current_image) + current_image = encode_to_base64(shared.state.current_image) return ProgressResponse(progress=progress, eta_relative=eta_relative, state=shared.state.dict(), current_image=current_image) diff --git a/modules/call_queue.py b/modules/call_queue.py index a4eefcda4bd..d8442f7a32a 100644 --- a/modules/call_queue.py +++ b/modules/call_queue.py @@ -413,22 +413,16 @@ def f(*args, **kwargs): username = args[len(args) - 2] sagemaker_endpoint = args[len(args) -1] args = args[:-2] - print('username:', username) - print('sagemaker_endpoint:', sagemaker_endpoint) res = sagemaker_inference('text-to-image', 'sync', username, sagemaker_endpoint, *args, **kwargs) elif cmd_opts.pureui and func == modules.img2img.img2img: username = args[len(args) - 2] sagemaker_endpoint = args[len(args) -1] args = args[:-2] - print('username:', username) - print('sagemaker_endpoint:', sagemaker_endpoint) res = sagemaker_inference('image-to-image', 'sync', username, 
sagemaker_endpoint, *args, **kwargs) elif cmd_opts.pureui and func == modules.extras.run_extras: username = args[len(args) - 2] sagemaker_endpoint = args[len(args) -1] args = args[:-2] - print('username:', username) - print('sagemaker_endpoint:', sagemaker_endpoint) res = sagemaker_inference('extras', 'sync', username, sagemaker_endpoint, *args, **kwargs) else: shared.state.begin() diff --git a/modules/processing.py b/modules/processing.py index 69c2a5e2d43..b8d5bc3b75e 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -22,7 +22,8 @@ import modules.styles import logging import base64 -import io +from io import BytesIO +from numpy import asarray # some of those options should not be changed at all because they would break the model, so I removed them from options. opt_C = 4 @@ -68,6 +69,10 @@ def apply_overlay(image, paste_loc, index, overlays): return image +def decode_base64_to_image(encoding): + if encoding.startswith("data:image/"): + encoding = encoding.split(";")[1].split(",")[1] + return Image.open(BytesIO(base64.b64decode(encoding))) class StableDiffusionProcessing(): """ @@ -119,9 +124,11 @@ def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prom self.script_args = json.loads(script_args) if script_args != None else None if self.script_args: - for key in self.script_args: - if key == 'image' or key == 'mask': - self.script_arg[key] = Image.open(io.BytesIO(base64.b64decode(self.script_args[key]))) + for idx in range(len(self.script_args)): + if(isinstance(self.script_args[idx], dict)): + for key in self.script_args[idx]: + if key == 'image' or key == 'mask': + self.script_args[idx][key] = asarray(decode_base64_to_image(self.script_args[idx][key])) if not seed_enable_extras: self.subseed = -1 diff --git a/modules/shared.py b/modules/shared.py index 4824aeab4b9..5c1bedc4e86 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -137,14 +137,13 @@ hypernetworks = {} loaded_hypernetwork = None -if cmd_opts.pureui: - 
api_endpoint = os.environ['api_endpoint'] - industrial_model = '' - default_options = {} - username_state = None - sagemaker_endpoint_component = None - sd_model_checkpoint_component = None - create_train_dreambooth_component = None +api_endpoint = os.environ['api_endpoint'] +industrial_model = '' +default_options = {} +username_state = None +sagemaker_endpoint_component = None +sd_model_checkpoint_component = None +create_train_dreambooth_component = None def reload_hypernetworks(): from modules.hypernetworks import hypernetwork diff --git a/modules/ui.py b/modules/ui.py index 08a5531d8f8..685b5d071eb 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -658,10 +658,9 @@ def open_folder(f): parameters_copypaste.bind_buttons(buttons, result_gallery, "txt2img" if tabname == "txt2img" else None) return result_gallery, generation_info if tabname != "extras" else html_info_x, html_info -if cmd_opts.pureui: - txt2img_submit = None - img2img_submit = None - extras_submit = None +txt2img_submit = None +img2img_submit = None +extras_submit = None def create_ui(): import modules.img2img @@ -697,6 +696,7 @@ def create_ui(): ) ui_tabs = script_callbacks.ui_tabs_callback() + dreambooth_tab = None for ui_tab in ui_tabs: if ui_tab[2] != 'dreambooth_interface': @@ -752,13 +752,10 @@ def fun(): opts.reorder() def run_settings(*args): - assert cmd_opts.pureui - changed = [] username = args[len(args) - 1] args = args[:-1] - print('username:', username) if not username or username == '': return opts.dumpjson(), f'{len(changed)} settings changed: {", ".join(changed)}.' 
@@ -1335,8 +1332,7 @@ def update_orig(image, state): show_extras_results = gr.Checkbox(label='Show result images', value=True) submit = gr.Button('Generate', elem_id="extras_generate", variant='primary', visible=(not cmd_opts.pureui)) - if cmd_opts.pureui: - extras_submit = submit + extras_submit = submit with gr.Tabs(elem_id="extras_resize_mode"): with gr.TabItem('Scale by'): @@ -1434,431 +1430,176 @@ def update_orig(image, state): with gr.Row().style(equal_height=False): with gr.Tabs(elem_id="train_tabs"): + with gr.Tab(label="Train Embedding"): + gr.HTML(value="

Train an embedding; you must specify a directory with a set of 1:1 ratio images [wiki]

") + + with gr.Box(): + gr.HTML(value="

Embedding settings

") - if not cmd_opts.pureui: - with gr.Tab(label="Create embedding"): new_embedding_name = gr.Textbox(label="Name") initialization_text = gr.Textbox(label="Initialization text", value="*") nvpt = gr.Slider(label="Number of vectors per token", minimum=1, maximum=75, step=1, value=1) overwrite_old_embedding = gr.Checkbox(value=False, label="Overwrite Old Embedding") - with gr.Row(): - with gr.Column(scale=3): - gr.HTML(value="") + with gr.Box(): + gr.HTML(value="

Image preprocess settings

") - with gr.Column(): - create_embedding = gr.Button(value="Create embedding", variant='primary') - - with gr.Tab(label="Create hypernetwork"): - new_hypernetwork_name = gr.Textbox(label="Name") - new_hypernetwork_sizes = gr.CheckboxGroup(label="Modules", value=["768", "320", "640", "1280"], choices=["768", "1024", "320", "640", "1280"]) - new_hypernetwork_layer_structure = gr.Textbox("1, 2, 1", label="Enter hypernetwork layer structure", placeholder="1st and last digit must be 1. ex:'1, 2, 1'") - new_hypernetwork_activation_func = gr.Dropdown(value="linear", label="Select activation function of hypernetwork. Recommended : Swish / Linear(none)", choices=modules.hypernetworks.ui.keys) - new_hypernetwork_initialization_option = gr.Dropdown(value = "Normal", label="Select Layer weights initialization. Recommended: Kaiming for relu-like, Xavier for sigmoid-like, Normal otherwise", choices=["Normal", "KaimingUniform", "KaimingNormal", "XavierUniform", "XavierNormal"]) - new_hypernetwork_add_layer_norm = gr.Checkbox(label="Add layer normalization") - new_hypernetwork_use_dropout = gr.Checkbox(label="Use dropout") - overwrite_old_hypernetwork = gr.Checkbox(value=False, label="Overwrite Old Hypernetwork") - - with gr.Row(): - with gr.Column(scale=3): - gr.HTML(value="") - - with gr.Column(): - create_hypernetwork = gr.Button(value="Create hypernetwork", variant='primary') - - with gr.Tab(label="Preprocess images"): - process_src = gr.Textbox(label='Source directory') - process_dst = gr.Textbox(label='Destination directory') - process_width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512) - process_height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512) - preprocess_txt_action = gr.Dropdown(label='Existing Caption txt Action', value="ignore", choices=["ignore", "copy", "prepend", "append"]) - - with gr.Row(): - process_flip = gr.Checkbox(label='Create flipped copies') - process_split = gr.Checkbox(label='Split oversized 
images') - process_focal_crop = gr.Checkbox(label='Auto focal point crop') - process_caption = gr.Checkbox(label='Use BLIP for caption') - process_caption_deepbooru = gr.Checkbox(label='Use deepbooru for caption', visible=True) - - with gr.Row(visible=False) as process_split_extra_row: - process_split_threshold = gr.Slider(label='Split image threshold', value=0.5, minimum=0.0, maximum=1.0, step=0.05) - process_overlap_ratio = gr.Slider(label='Split image overlap ratio', value=0.2, minimum=0.0, maximum=0.9, step=0.05) - - with gr.Row(visible=False) as process_focal_crop_row: - process_focal_crop_face_weight = gr.Slider(label='Focal point face weight', value=0.9, minimum=0.0, maximum=1.0, step=0.05) - process_focal_crop_entropy_weight = gr.Slider(label='Focal point entropy weight', value=0.15, minimum=0.0, maximum=1.0, step=0.05) - process_focal_crop_edges_weight = gr.Slider(label='Focal point edges weight', value=0.5, minimum=0.0, maximum=1.0, step=0.05) - process_focal_crop_debug = gr.Checkbox(label='Create debug image') + embedding_images_s3uri = gr.Textbox(label='Images S3 URI') + embedding_models_s3uri = gr.Textbox(label='Models S3 URI') + embedding_process_width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512) + embedding_process_height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512) + embedding_preprocess_txt_action = gr.Dropdown(label='Existing Caption txt Action', value="ignore", choices=["ignore", "copy", "prepend", "append"]) with gr.Row(): - with gr.Column(scale=3): - gr.HTML(value="") - - with gr.Column(): - with gr.Row(): - interrupt_preprocessing = gr.Button("Interrupt") - run_preprocess = gr.Button(value="Preprocess", variant='primary') - - process_split.change( + embedding_process_flip = gr.Checkbox(label='Create flipped copies') + embedding_process_split = gr.Checkbox(label='Split oversized images') + embedding_process_focal_crop = gr.Checkbox(label='Auto focal point crop') + 
embedding_process_caption = gr.Checkbox(label='Use BLIP for caption') + embedding_process_caption_deepbooru = gr.Checkbox(label='Use deepbooru for caption', visible=True) + + with gr.Row(visible=False) as embedding_process_split_extra_row: + embedding_process_split_threshold = gr.Slider(label='Split image threshold', value=0.5, minimum=0.0, maximum=1.0, step=0.05) + embedding_process_overlap_ratio = gr.Slider(label='Split image overlap ratio', value=0.2, minimum=0.0, maximum=0.9, step=0.05) + + with gr.Row(visible=False) as embedding_process_focal_crop_row: + embedding_process_focal_crop_face_weight = gr.Slider(label='Focal point face weight', value=0.9, minimum=0.0, maximum=1.0, step=0.05) + embedding_process_focal_crop_entropy_weight = gr.Slider(label='Focal point entropy weight', value=0.15, minimum=0.0, maximum=1.0, step=0.05) + embedding_process_focal_crop_edges_weight = gr.Slider(label='Focal point edges weight', value=0.5, minimum=0.0, maximum=1.0, step=0.05) + embedding_process_focal_crop_debug = gr.Checkbox(label='Create debug image') + + embedding_process_split.change( fn=lambda show: gr_show(show), - inputs=[process_split], - outputs=[process_split_extra_row], + inputs=[embedding_process_split], + outputs=[embedding_process_split_extra_row], ) - process_focal_crop.change( + embedding_process_focal_crop.change( fn=lambda show: gr_show(show), - inputs=[process_focal_crop], - outputs=[process_focal_crop_row], + inputs=[embedding_process_focal_crop], + outputs=[embedding_process_focal_crop_row], ) - with gr.Tab(label="Train"): - gr.HTML(value="

Train an embedding or Hypernetwork; you must specify a directory with a set of 1:1 ratio images [wiki]

") - with gr.Row(): - train_embedding_name = gr.Dropdown(label='Embedding', elem_id="train_embedding", choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())) - create_refresh_button(train_embedding_name, sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings, lambda: {"choices": sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())}, "refresh_train_embedding_name") + with gr.Box(): + gr.HTML(value="

Train settings

") + with gr.Row(): - train_hypernetwork_name = gr.Dropdown(label='Hypernetwork', elem_id="train_hypernetwork", choices=[x for x in shared.hypernetworks.keys()]) - create_refresh_button(train_hypernetwork_name, shared.reload_hypernetworks, lambda: {"choices": sorted([x for x in shared.hypernetworks.keys()])}, "refresh_train_hypernetwork_name") + with gr.Column(): + embedding_training_instance_type = gr.Dropdown(label='Instance type', value="ml.g4dn.xlarge", choices=training_instance_types) + with gr.Column(): + embedding_training_instance_count = gr.Number(label='Instance count', value=1, precision=0) + with gr.Row(): embedding_learn_rate = gr.Textbox(label='Embedding Learning rate', placeholder="Embedding Learning rate", value="0.005") - hypernetwork_learn_rate = gr.Textbox(label='Hypernetwork Learning rate', placeholder="Hypernetwork Learning rate", value="0.00001") - batch_size = gr.Number(label='Batch size', value=1, precision=0) - gradient_step = gr.Number(label='Gradient accumulation steps', value=1, precision=0) - dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images") - log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion") - template_file = gr.Textbox(label='Prompt template file', value=os.path.join(script_path, "textual_inversion_templates", "style_filewords.txt")) - training_width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512) - training_height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512) - steps = gr.Number(label='Max steps', value=100000, precision=0) - create_image_every = gr.Number(label='Save an image to log directory every N steps, 0 to disable', value=500, precision=0) - save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable', value=500, precision=0) - save_image_with_stored_embedding = gr.Checkbox(label='Save 
images with embedding in PNG chunks', value=True) - preview_from_txt2img = gr.Checkbox(label='Read parameters (prompt, etc...) from txt2img tab when making previews', value=False) - with gr.Row(): - shuffle_tags = gr.Checkbox(label="Shuffle tags by ',' when creating prompts.", value=False) - tag_drop_out = gr.Slider(minimum=0, maximum=1, step=0.1, label="Drop out tags when creating prompts.", value=0) + embedding_batch_size = gr.Number(label='Batch size', value=1, precision=0) + embedding_gradient_step = gr.Number(label='Gradient accumulation steps', value=1, precision=0) + embedding_training_width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512) + embedding_training_height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512) + embedding_steps = gr.Number(label='Max steps', value=100000, precision=0) + embedding_create_image_every = gr.Number(label='Save an image to log directory every N steps, 0 to disable', value=500, precision=0) + embedding_save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable', value=500, precision=0) + embedding_save_image_with_stored_embedding = gr.Checkbox(label='Save images with embedding in PNG chunks', value=True) + embedding_preview_from_txt2img = gr.Checkbox(label='Read parameters (prompt, etc...) 
from txt2img tab when making previews', value=False) with gr.Row(): - latent_sampling_method = gr.Radio(label='Choose latent sampling method', value="once", choices=['once', 'deterministic', 'random']) - + embedding_shuffle_tags = gr.Checkbox(label="Shuffle tags by ',' when creating prompts.", value=False) + embedding_tag_drop_out = gr.Slider(minimum=0, maximum=1, step=0.1, label="Drop out tags when creating prompts.", value=0) with gr.Row(): - interrupt_training = gr.Button(value="Interrupt") - train_hypernetwork = gr.Button(value="Train Hypernetwork", variant='primary') - train_embedding = gr.Button(value="Train Embedding", variant='primary') + embedding_latent_sampling_method = gr.Radio(label='Choose latent sampling method', value="once", choices=['once', 'deterministic', 'random']) - params = script_callbacks.UiTrainTabParams(txt2img_preview_params) + with gr.Row(): + with gr.Column(scale=3): + embedding_output = gr.Label(label='Output') - script_callbacks.ui_train_tabs_callback(params) + with gr.Column(): + create_train_embedding = gr.Button(value="Train Embedding", variant='primary', visible=False) - with gr.Column(): - progressbar = gr.HTML(elem_id="ti_progressbar") - ti_output = gr.Text(elem_id="ti_output", value="", show_label=False) + with gr.Tab(label="Train Hypernetwork"): + gr.HTML(value="

Train a hypernetwork; you must specify a directory with a set of 1:1 ratio images [wiki]

") - ti_gallery = gr.Gallery(label='Output', show_label=False, elem_id='ti_gallery').style(grid=4) - ti_preview = gr.Image(elem_id='ti_preview', visible=False) - ti_progress = gr.HTML(elem_id="ti_progress", value="") - ti_outcome = gr.HTML(elem_id="ti_error", value="") - setup_progressbar(progressbar, ti_preview, 'ti', textinfo=ti_progress) + with gr.Box(): + gr.HTML(value="

Hypernetwork settings

") - create_embedding.click( - fn=modules.textual_inversion.ui.create_embedding, - inputs=[ - new_embedding_name, - initialization_text, - nvpt, - overwrite_old_embedding, - ], - outputs=[ - train_embedding_name, - ti_output, - ti_outcome, - ] - ) - - create_hypernetwork.click( - fn=modules.hypernetworks.ui.create_hypernetwork, - inputs=[ - new_hypernetwork_name, - new_hypernetwork_sizes, - overwrite_old_hypernetwork, - new_hypernetwork_layer_structure, - new_hypernetwork_activation_func, - new_hypernetwork_initialization_option, - new_hypernetwork_add_layer_norm, - new_hypernetwork_use_dropout - ], - outputs=[ - train_hypernetwork_name, - ti_output, - ti_outcome, - ] - ) - - run_preprocess.click( - fn=wrap_gradio_gpu_call(modules.textual_inversion.ui.preprocess, extra_outputs=[gr.update()]), - _js="start_training_textual_inversion", - inputs=[ - process_src, - process_dst, - process_width, - process_height, - preprocess_txt_action, - process_flip, - process_split, - process_caption, - process_caption_deepbooru, - process_split_threshold, - process_overlap_ratio, - process_focal_crop, - process_focal_crop_face_weight, - process_focal_crop_entropy_weight, - process_focal_crop_edges_weight, - process_focal_crop_debug, - ], - outputs=[ - ti_output, - ti_outcome, - ], - ) - - train_embedding.click( - fn=wrap_gradio_gpu_call(modules.textual_inversion.ui.train_embedding, extra_outputs=[gr.update()]), - _js="start_training_textual_inversion", - inputs=[ - train_embedding_name, - embedding_learn_rate, - batch_size, - gradient_step, - dataset_directory, - log_directory, - training_width, - training_height, - steps, - shuffle_tags, - tag_drop_out, - latent_sampling_method, - create_image_every, - save_embedding_every, - template_file, - save_image_with_stored_embedding, - preview_from_txt2img, - *txt2img_preview_params, - ], - outputs=[ - ti_output, - ti_outcome, - ] - ) - - train_hypernetwork.click( - fn=wrap_gradio_gpu_call(modules.hypernetworks.ui.train_hypernetwork, 
extra_outputs=[gr.update()]), - _js="start_training_textual_inversion", - inputs=[ - train_hypernetwork_name, - hypernetwork_learn_rate, - batch_size, - gradient_step, - dataset_directory, - log_directory, - training_width, - training_height, - steps, - shuffle_tags, - tag_drop_out, - latent_sampling_method, - create_image_every, - save_embedding_every, - template_file, - preview_from_txt2img, - *txt2img_preview_params, - ], - outputs=[ - ti_output, - ti_outcome, - ] - ) - - interrupt_training.click( - fn=lambda: shared.state.interrupt(), - inputs=[], - outputs=[], - ) - - interrupt_preprocessing.click( - fn=lambda: shared.state.interrupt(), - inputs=[], - outputs=[], - ) - else: - with gr.Tab(label="Train Embedding"): - gr.HTML(value="

Train an embedding; you must specify a directory with a set of 1:1 ratio images [wiki]

") - - with gr.Box(): - gr.HTML(value="

Embedding settings

") - - new_embedding_name = gr.Textbox(label="Name") - initialization_text = gr.Textbox(label="Initialization text", value="*") - nvpt = gr.Slider(label="Number of vectors per token", minimum=1, maximum=75, step=1, value=1) - overwrite_old_embedding = gr.Checkbox(value=False, label="Overwrite Old Embedding") - - with gr.Box(): - gr.HTML(value="

Image preprocess settings

") - - embedding_images_s3uri = gr.Textbox(label='Images S3 URI') - embedding_models_s3uri = gr.Textbox(label='Models S3 URI') - embedding_process_width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512) - embedding_process_height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512) - embedding_preprocess_txt_action = gr.Dropdown(label='Existing Caption txt Action', value="ignore", choices=["ignore", "copy", "prepend", "append"]) + new_hypernetwork_name = gr.Textbox(label="Name") + new_hypernetwork_sizes = gr.CheckboxGroup(label="Modules", value=["768", "320", "640", "1280"], choices=["768", "320", "640", "1280"]) + new_hypernetwork_layer_structure = gr.Textbox("1, 2, 1", label="Enter hypernetwork layer structure", placeholder="1st and last digit must be 1. ex:'1, 2, 1'") + new_hypernetwork_activation_func = gr.Dropdown(value="linear", label="Select activation function of hypernetwork. Recommended : Swish / Linear(none)", choices=modules.hypernetworks.ui.keys) + new_hypernetwork_initialization_option = gr.Dropdown(value = "Normal", label="Select Layer weights initialization. 
Recommended: Kaiming for relu-like, Xavier for sigmoid-like, Normal otherwise", choices=["Normal", "KaimingUniform", "KaimingNormal", "XavierUniform", "XavierNormal"]) + new_hypernetwork_add_layer_norm = gr.Checkbox(label="Add layer normalization") + new_hypernetwork_use_dropout = gr.Checkbox(label="Use dropout") + overwrite_old_hypernetwork = gr.Checkbox(value=False, label="Overwrite Old Hypernetwork") - with gr.Row(): - embedding_process_flip = gr.Checkbox(label='Create flipped copies') - embedding_process_split = gr.Checkbox(label='Split oversized images') - embedding_process_focal_crop = gr.Checkbox(label='Auto focal point crop') - embedding_process_caption = gr.Checkbox(label='Use BLIP for caption') - embedding_process_caption_deepbooru = gr.Checkbox(label='Use deepbooru for caption', visible=True) - - with gr.Row(visible=False) as embedding_process_split_extra_row: - embedding_process_split_threshold = gr.Slider(label='Split image threshold', value=0.5, minimum=0.0, maximum=1.0, step=0.05) - embedding_process_overlap_ratio = gr.Slider(label='Split image overlap ratio', value=0.2, minimum=0.0, maximum=0.9, step=0.05) - - with gr.Row(visible=False) as embedding_process_focal_crop_row: - embedding_process_focal_crop_face_weight = gr.Slider(label='Focal point face weight', value=0.9, minimum=0.0, maximum=1.0, step=0.05) - embedding_process_focal_crop_entropy_weight = gr.Slider(label='Focal point entropy weight', value=0.15, minimum=0.0, maximum=1.0, step=0.05) - embedding_process_focal_crop_edges_weight = gr.Slider(label='Focal point edges weight', value=0.5, minimum=0.0, maximum=1.0, step=0.05) - embedding_process_focal_crop_debug = gr.Checkbox(label='Create debug image') - - embedding_process_split.change( - fn=lambda show: gr_show(show), - inputs=[embedding_process_split], - outputs=[embedding_process_split_extra_row], - ) + with gr.Box(): + gr.HTML(value="

Image preprocess settings

") - embedding_process_focal_crop.change( - fn=lambda show: gr_show(show), - inputs=[embedding_process_focal_crop], - outputs=[embedding_process_focal_crop_row], - ) + hypernetwork_images_s3uir = gr.Textbox(label='Images S3 URI') + hypernetwork_models_s3uri = gr.Textbox(label='Models S3 URI') + hypernetwork_process_width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512) + hypernetwork_process_height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512) + hypernetwork_preprocess_txt_action = gr.Dropdown(label='Existing Caption txt Action', value="ignore", choices=["ignore", "copy", "prepend", "append"]) - with gr.Box(): - gr.HTML(value="

Train settings

") + with gr.Row(): + hypernetwork_process_flip = gr.Checkbox(label='Create flipped copies') + hypernetwork_process_split = gr.Checkbox(label='Split oversized images') + hypernetwork_process_focal_crop = gr.Checkbox(label='Auto focal point crop') + hypernetwork_process_caption = gr.Checkbox(label='Use BLIP for caption') + hypernetwork_process_caption_deepbooru = gr.Checkbox(label='Use deepbooru for caption', visible=True) + + with gr.Row(visible=False) as hypernetwork_process_split_extra_row: + hypernetwork_process_split_threshold = gr.Slider(label='Split image threshold', value=0.5, minimum=0.0, maximum=1.0, step=0.05) + hypernetwork_process_overlap_ratio = gr.Slider(label='Split image overlap ratio', value=0.2, minimum=0.0, maximum=0.9, step=0.05) + + with gr.Row(visible=False) as hypernetwork_process_focal_crop_row: + hypernetwork_process_focal_crop_face_weight = gr.Slider(label='Focal point face weight', value=0.9, minimum=0.0, maximum=1.0, step=0.05) + hypernetwork_process_focal_crop_entropy_weight = gr.Slider(label='Focal point entropy weight', value=0.15, minimum=0.0, maximum=1.0, step=0.05) + hypernetwork_process_focal_crop_edges_weight = gr.Slider(label='Focal point edges weight', value=0.5, minimum=0.0, maximum=1.0, step=0.05) + hypernetwork_process_focal_crop_debug = gr.Checkbox(label='Create debug image') + + hypernetwork_process_split.change( + fn=lambda show: gr_show(show), + inputs=[hypernetwork_process_split], + outputs=[hypernetwork_process_split_extra_row], + ) - with gr.Row(): - with gr.Column(): - embedding_training_instance_type = gr.Dropdown(label='Instance type', value="ml.g4dn.xlarge", choices=training_instance_types) - with gr.Column(): - embedding_training_instance_count = gr.Number(label='Instance count', value=1, precision=0) - - with gr.Row(): - embedding_learn_rate = gr.Textbox(label='Embedding Learning rate', placeholder="Embedding Learning rate", value="0.005") - - embedding_batch_size = gr.Number(label='Batch size', value=1, 
precision=0) - embedding_gradient_step = gr.Number(label='Gradient accumulation steps', value=1, precision=0) - embedding_training_width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512) - embedding_training_height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512) - embedding_steps = gr.Number(label='Max steps', value=100000, precision=0) - embedding_create_image_every = gr.Number(label='Save an image to log directory every N steps, 0 to disable', value=500, precision=0) - embedding_save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable', value=500, precision=0) - embedding_save_image_with_stored_embedding = gr.Checkbox(label='Save images with embedding in PNG chunks', value=True) - embedding_preview_from_txt2img = gr.Checkbox(label='Read parameters (prompt, etc...) from txt2img tab when making previews', value=False) - with gr.Row(): - embedding_shuffle_tags = gr.Checkbox(label="Shuffle tags by ',' when creating prompts.", value=False) - embedding_tag_drop_out = gr.Slider(minimum=0, maximum=1, step=0.1, label="Drop out tags when creating prompts.", value=0) - with gr.Row(): - embedding_latent_sampling_method = gr.Radio(label='Choose latent sampling method', value="once", choices=['once', 'deterministic', 'random']) + hypernetwork_process_focal_crop.change( + fn=lambda show: gr_show(show), + inputs=[hypernetwork_process_focal_crop], + outputs=[hypernetwork_process_focal_crop_row], + ) + with gr.Box(): + gr.HTML(value="

Train settings

") with gr.Row(): - with gr.Column(scale=3): - embedding_output = gr.Label(label='Output') - with gr.Column(): - create_train_embedding = gr.Button(value="Train Embedding", variant='primary', visible=False) - - with gr.Tab(label="Train Hypernetwork"): - gr.HTML(value="

Train an hypernetwork; you must specify a directory with a set of 1:1 ratio images [wiki]

") - - with gr.Box(): - gr.HTML(value="

Hypernetwork settings

") - - new_hypernetwork_name = gr.Textbox(label="Name") - new_hypernetwork_sizes = gr.CheckboxGroup(label="Modules", value=["768", "320", "640", "1280"], choices=["768", "320", "640", "1280"]) - new_hypernetwork_layer_structure = gr.Textbox("1, 2, 1", label="Enter hypernetwork layer structure", placeholder="1st and last digit must be 1. ex:'1, 2, 1'") - new_hypernetwork_activation_func = gr.Dropdown(value="linear", label="Select activation function of hypernetwork. Recommended : Swish / Linear(none)", choices=modules.hypernetworks.ui.keys) - new_hypernetwork_initialization_option = gr.Dropdown(value = "Normal", label="Select Layer weights initialization. Recommended: Kaiming for relu-like, Xavier for sigmoid-like, Normal otherwise", choices=["Normal", "KaimingUniform", "KaimingNormal", "XavierUniform", "XavierNormal"]) - new_hypernetwork_add_layer_norm = gr.Checkbox(label="Add layer normalization") - new_hypernetwork_use_dropout = gr.Checkbox(label="Use dropout") - overwrite_old_hypernetwork = gr.Checkbox(value=False, label="Overwrite Old Hypernetwork") - - with gr.Box(): - gr.HTML(value="

Image preprocess settings

") - - hypernetwork_images_s3uir = gr.Textbox(label='Images S3 URI') - hypernetwork_models_s3uri = gr.Textbox(label='Models S3 URI') - hypernetwork_process_width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512) - hypernetwork_process_height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512) - hypernetwork_preprocess_txt_action = gr.Dropdown(label='Existing Caption txt Action', value="ignore", choices=["ignore", "copy", "prepend", "append"]) - - with gr.Row(): - hypernetwork_process_flip = gr.Checkbox(label='Create flipped copies') - hypernetwork_process_split = gr.Checkbox(label='Split oversized images') - hypernetwork_process_focal_crop = gr.Checkbox(label='Auto focal point crop') - hypernetwork_process_caption = gr.Checkbox(label='Use BLIP for caption') - hypernetwork_process_caption_deepbooru = gr.Checkbox(label='Use deepbooru for caption', visible=True) - - with gr.Row(visible=False) as hypernetwork_process_split_extra_row: - hypernetwork_process_split_threshold = gr.Slider(label='Split image threshold', value=0.5, minimum=0.0, maximum=1.0, step=0.05) - hypernetwork_process_overlap_ratio = gr.Slider(label='Split image overlap ratio', value=0.2, minimum=0.0, maximum=0.9, step=0.05) - - with gr.Row(visible=False) as hypernetwork_process_focal_crop_row: - hypernetwork_process_focal_crop_face_weight = gr.Slider(label='Focal point face weight', value=0.9, minimum=0.0, maximum=1.0, step=0.05) - hypernetwork_process_focal_crop_entropy_weight = gr.Slider(label='Focal point entropy weight', value=0.15, minimum=0.0, maximum=1.0, step=0.05) - hypernetwork_process_focal_crop_edges_weight = gr.Slider(label='Focal point edges weight', value=0.5, minimum=0.0, maximum=1.0, step=0.05) - hypernetwork_process_focal_crop_debug = gr.Checkbox(label='Create debug image') - - hypernetwork_process_split.change( - fn=lambda show: gr_show(show), - inputs=[hypernetwork_process_split], - outputs=[hypernetwork_process_split_extra_row], - ) - - 
hypernetwork_process_focal_crop.change( - fn=lambda show: gr_show(show), - inputs=[hypernetwork_process_focal_crop], - outputs=[hypernetwork_process_focal_crop_row], - ) - with gr.Box(): - gr.HTML(value="

Train settings

") + hypernetwork_training_instance_type = gr.Dropdown(label='Instance type', value="ml.g4dn.xlarge", choices=training_instance_types) + with gr.Column(): + hypernetwork_training_instance_count = gr.Number(label='Instance count', value=1, precision=0) - with gr.Row(): - with gr.Column(): - hypernetwork_training_instance_type = gr.Dropdown(label='Instance type', value="ml.g4dn.xlarge", choices=training_instance_types) - with gr.Column(): - hypernetwork_training_instance_count = gr.Number(label='Instance count', value=1, precision=0) - - with gr.Row(): - hypernetwork_learn_rate = gr.Textbox(label='Hypernetwork Learning rate', placeholder="Hypernetwork Learning rate", value="0.00001") - - hypernetwork_batch_size = gr.Number(label='Batch size', value=1, precision=0) - hypernetwork_gradient_step = gr.Number(label='Gradient accumulation steps', value=1, precision=0) - hypernetwork_training_width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512) - hypernetwork_training_height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512) - hypernetwork_steps = gr.Number(label='Max steps', value=100000, precision=0) - hypernetwork_create_image_every = gr.Number(label='Save an image to log directory every N steps, 0 to disable', value=500, precision=0) - hypernetwork_save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable', value=500, precision=0) - hypernetwork_save_image_with_stored_embedding = gr.Checkbox(label='Save images with embedding in PNG chunks', value=True) - hypernetwork_preview_from_txt2img = gr.Checkbox(label='Read parameters (prompt, etc...) 
from txt2img tab when making previews', value=False) - with gr.Row(): - hypernetwork_shuffle_tags = gr.Checkbox(label="Shuffle tags by ',' when creating prompts.", value=False) - hypernetwork_tag_drop_out = gr.Slider(minimum=0, maximum=1, step=0.1, label="Drop out tags when creating prompts.", value=0) - with gr.Row(): - hypernetwork_latent_sampling_method = gr.Radio(label='Choose latent sampling method', value="once", choices=['once', 'deterministic', 'random']) + with gr.Row(): + hypernetwork_learn_rate = gr.Textbox(label='Hypernetwork Learning rate', placeholder="Hypernetwork Learning rate", value="0.00001") + hypernetwork_batch_size = gr.Number(label='Batch size', value=1, precision=0) + hypernetwork_gradient_step = gr.Number(label='Gradient accumulation steps', value=1, precision=0) + hypernetwork_training_width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512) + hypernetwork_training_height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512) + hypernetwork_steps = gr.Number(label='Max steps', value=100000, precision=0) + hypernetwork_create_image_every = gr.Number(label='Save an image to log directory every N steps, 0 to disable', value=500, precision=0) + hypernetwork_save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable', value=500, precision=0) + hypernetwork_save_image_with_stored_embedding = gr.Checkbox(label='Save images with embedding in PNG chunks', value=True) + hypernetwork_preview_from_txt2img = gr.Checkbox(label='Read parameters (prompt, etc...) 
from txt2img tab when making previews', value=False) + with gr.Row(): + hypernetwork_shuffle_tags = gr.Checkbox(label="Shuffle tags by ',' when creating prompts.", value=False) + hypernetwork_tag_drop_out = gr.Slider(minimum=0, maximum=1, step=0.1, label="Drop out tags when creating prompts.", value=0) with gr.Row(): - with gr.Column(scale=3): - hypernetwork_output = gr.Label(label='Output') + hypernetwork_latent_sampling_method = gr.Radio(label='Choose latent sampling method', value="once", choices=['once', 'deterministic', 'random']) - with gr.Column(): - create_train_hypernetwork = gr.Button(value="Train Hypernetwork", variant='primary', visible=False) + with gr.Row(): + with gr.Column(scale=3): + hypernetwork_output = gr.Label(label='Output') + + with gr.Column(): + create_train_hypernetwork = gr.Button(value="Train Hypernetwork", variant='primary', visible=False) + if dreambooth_tab: with gr.Tab(label="Train Dreambooth"): dreambooth_tab.render() @@ -1967,7 +1708,7 @@ def sagemaker_train_embedding( 'instance_count': embedding_training_instance_count, 'inputs': inputs } - + response = requests.post(url=f'{shared.api_endpoint}/train', json=data) if response.status_code == 200: return { @@ -2091,7 +1832,7 @@ def sagemaker_train_hypernetwork( 'instance_count': hypernetwork_training_instance_count, 'inputs': inputs } - + response = requests.post(url=f'{shared.api_endpoint}/train', json=data) if response.status_code == 200: return { @@ -2399,19 +2140,19 @@ def user_delete(login_username, login_password, login_email): signin.click( fn=user_signin, inputs=[signin_username, signin_password], - outputs=[shared.username_state, user_login_row, user_sign_row, login_username, login_password, login_email,txt2img_submit, img2img_submit, extras_submit, create_train_embedding, create_train_hypernetwork, shared.create_train_dreambooth_component, signin_output] + components + outputs=[shared.username_state, user_login_row, user_sign_row, login_username, login_password, 
login_email,txt2img_submit, img2img_submit, extras_submit, create_train_embedding, create_train_hypernetwork, shared.create_train_dreambooth_component if shared.create_train_dreambooth_component else dummy_component, signin_output] + components ) signup.click( fn=user_signup, inputs=[signup_username, signup_password, signup_email], - outputs=[shared.username_state, user_login_row, user_sign_row, login_username, login_password, login_email, txt2img_submit, img2img_submit, extras_submit, create_train_embedding, create_train_hypernetwork, shared.create_train_dreambooth_component, signup_output] + outputs=[shared.username_state, user_login_row, user_sign_row, login_username, login_password, login_email, txt2img_submit, img2img_submit, extras_submit, create_train_embedding, create_train_hypernetwork, shared.create_train_dreambooth_component if shared.create_train_dreambooth_component else dummy_component, signup_output] ) signout.click( fn=user_signout, inputs=[], - outputs=[shared.username_state, user_login_row, user_sign_row, txt2img_submit, img2img_submit, extras_submit, create_train_embedding, create_train_hypernetwork, shared.create_train_dreambooth_component] + components + outputs=[shared.username_state, user_login_row, user_sign_row, txt2img_submit, img2img_submit, extras_submit, create_train_embedding, create_train_hypernetwork, shared.create_train_dreambooth_component if shared.create_train_dreambooth_component else dummy_component] + components ) userupdate.click( @@ -2423,7 +2164,7 @@ def user_delete(login_username, login_password, login_email): userdelete.click( fn=user_delete, inputs=[login_username, login_password, login_email], - outputs=[shared.username_state, user_login_row, user_sign_row, txt2img_submit, img2img_submit, extras_submit, create_train_embedding, create_train_hypernetwork, shared.create_train_dreambooth_component, login_output] + components + outputs=[shared.username_state, user_login_row, user_sign_row, txt2img_submit, img2img_submit, 
extras_submit, create_train_embedding, create_train_hypernetwork, shared.create_train_dreambooth_component if shared.create_train_dreambooth_component else dummy_component, login_output] + components ) if cmd_opts.pureui: diff --git a/requirements_versions.txt b/requirements_versions.txt index bfe9deafd23..ad410616b20 100644 --- a/requirements_versions.txt +++ b/requirements_versions.txt @@ -27,4 +27,5 @@ inflection==0.5.1 GitPython==3.1.27 torchsde==0.2.5 safetensors==0.2.5 +fastapi==0.90.1 boto3 diff --git a/webui.py b/webui.py index 7d1f0f56566..f6fc554363f 100644 --- a/webui.py +++ b/webui.py @@ -142,63 +142,6 @@ def api_only(): app.add_middleware(GZipMiddleware, minimum_size=1000) api = create_api(app) - ckpt_dir = cmd_opts.ckpt_dir - sd_models_path = os.path.join(shared.models_path, "Stable-diffusion") - if ckpt_dir is not None: - sd_models_path = ckpt_dir - - controlnet_dir = cmd_opts.controlnet_dir - cn_models_path = os.path.join(shared.models_path, "ControlNet") - os.makedirs(controlnet_dir, exist_ok=True) - if controlnet_dir is not None: - cn_models_path = controlnet_dir - - if 'endpoint_name' in os.environ: - items = [] - api_endpoint = os.environ['api_endpoint'] - endpoint_name = os.environ['endpoint_name'] - for file in os.listdir(sd_models_path): - if os.path.isfile(os.path.join(sd_models_path, file)) and file.endswith('.ckpt'): - hash = modules.sd_models.model_hash(os.path.join(sd_models_path, file)) - item = {} - item['model_name'] = file - item['config'] = '/opt/ml/code/stable-diffusion-webui/repositories/stable-diffusion/configs/stable-diffusion/v1-inference.yaml' - item['filename'] = '/opt/ml/code/stable-diffusion-webui/models/Stable-diffusion/{0}'.format(file) - item['hash'] = hash - item['title'] = '{0} [{1}]'.format(file, hash) - item['endpoint_name'] = endpoint_name - items.append(item) - inputs = { - 'items': items - } - params = { - 'module': 'Stable-diffusion' - } - if api_endpoint.startswith('http://') or 
api_endpoint.startswith('https://'): - response = requests.post(url=f'{api_endpoint}/sd/models', json=inputs, params=params) - print(response) - - items = [] - inputs = { - 'items': items - } - params = { - 'module': 'ControlNet' - } - for file in os.listdir(cn_models_path): - if os.path.isfile(os.path.join(cn_models_path, file)) and \ - (file.endswith('pt') or file.endswith('.pth') or file.endswith('.ckpt') or file.endswith('.safetensors')): - hash = modules.sd_models.model_hash(os.path.join(cn_models_path, file)) - item = {} - item['model_name'] = file - item['title'] = '{0} [{1}]'.format(file, hash) - item['endpoint_name'] = endpoint_name - items.append(item) - - if api_endpoint.startswith('http://') or api_endpoint.startswith('https://'): - response = requests.post(url=f'{api_endpoint}/sd/models', json=inputs, params=params) - print(response) - modules.script_callbacks.app_started_callback(None, app) @app.exception_handler(RequestValidationError) @@ -210,6 +153,17 @@ async def validation_exception_handler(request: Request, exc: RequestValidationE api.launch(server_name="0.0.0.0" if cmd_opts.listen else "127.0.0.1", port=cmd_opts.port if cmd_opts.port else 7861) +def user_auth(username, password): + inputs = { + 'username': username, + 'password': password + } + api_endpoint = os.environ['api_endpoint'] + + response = requests.post(url=f'{api_endpoint}/sd/login', json=inputs) + print(response) + + return response.status_code == 200 def webui(): launch_api = cmd_opts.api @@ -228,7 +182,8 @@ def webui(): ssl_keyfile=cmd_opts.tls_keyfile, ssl_certfile=cmd_opts.tls_certfile, debug=cmd_opts.gradio_debug, - auth=[tuple(cred.split(':')) for cred in cmd_opts.gradio_auth.strip('"').split(',')] if cmd_opts.gradio_auth else None, + auth=user_auth, + auth_message="This login process is being used to verify your eligibility to use this stable-diffusion-webui. 
It's up to your organization's implementation", inbrowser=cmd_opts.autolaunch, prevent_thread_lock=True ) @@ -248,7 +203,63 @@ def webui(): if launch_api: create_api(app) - modules.script_callbacks.app_started_callback(shared.demo, app) + ckpt_dir = cmd_opts.ckpt_dir + sd_models_path = os.path.join(shared.models_path, "Stable-diffusion") + if ckpt_dir is not None: + sd_models_path = ckpt_dir + + controlnet_dir = cmd_opts.controlnet_dir + cn_models_path = os.path.join(shared.models_path, "ControlNet") + os.makedirs(controlnet_dir, exist_ok=True) + if controlnet_dir is not None: + cn_models_path = controlnet_dir + + if 'endpoint_name' in os.environ: + items = [] + api_endpoint = os.environ['api_endpoint'] + endpoint_name = os.environ['endpoint_name'] + for file in os.listdir(sd_models_path): + if os.path.isfile(os.path.join(sd_models_path, file)) and file.endswith('.ckpt'): + hash = modules.sd_models.model_hash(os.path.join(sd_models_path, file)) + item = {} + item['model_name'] = file + item['config'] = '/opt/ml/code/stable-diffusion-webui/repositories/stable-diffusion/configs/stable-diffusion/v1-inference.yaml' + item['filename'] = '/opt/ml/code/stable-diffusion-webui/models/Stable-diffusion/{0}'.format(file) + item['hash'] = hash + item['title'] = '{0} [{1}]'.format(file, hash) + item['endpoint_name'] = endpoint_name + items.append(item) + inputs = { + 'items': items + } + params = { + 'module': 'Stable-diffusion' + } + if api_endpoint.startswith('http://') or api_endpoint.startswith('https://'): + response = requests.post(url=f'{api_endpoint}/sd/models', json=inputs, params=params) + print(response) + + items = [] + inputs = { + 'items': items + } + params = { + 'module': 'ControlNet' + } + for file in os.listdir(cn_models_path): + if os.path.isfile(os.path.join(cn_models_path, file)) and \ + (file.endswith('pt') or file.endswith('.pth') or file.endswith('.ckpt') or file.endswith('.safetensors')): + hash = modules.sd_models.model_hash(os.path.join(cn_models_path, 
file)) + item = {} + item['model_name'] = file + item['title'] = '{0} [{1}]'.format(os.path.basename(file), hash) + item['endpoint_name'] = endpoint_name + items.append(item) + + if api_endpoint.startswith('http://') or api_endpoint.startswith('https://'): + response = requests.post(url=f'{api_endpoint}/sd/models', json=inputs, params=params) + print(response) + modules.script_callbacks.app_started_callback(shared.demo, app) wait_on_server(shared.demo)