From 18815ab6859498469a24755254f0d1fe59fcd0ad Mon Sep 17 00:00:00 2001 From: xieyongliang Date: Fri, 2 Dec 2022 16:29:19 +0800 Subject: [PATCH] big cleanup and revise for stable-diffusion-webui --- modules/api/api.py | 35 +- modules/api/models.py | 1 + modules/sd_models.py | 48 +- modules/shared.py | 41 +- modules/ui.py | 842 ++++++++++++++++++++++++++++------- requirements.txt.cn | 31 ++ requirements_versions.txt.cn | 28 ++ webui.py | 93 ++-- 8 files changed, 890 insertions(+), 229 deletions(-) create mode 100644 requirements.txt.cn create mode 100644 requirements_versions.txt.cn diff --git a/modules/api/api.py b/modules/api/api.py index 53f03a93f22..a70c46a3824 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -18,9 +18,10 @@ import json import os import boto3 -from modules import sd_hijack +from modules import sd_hijack, hypernetworks from typing import Union import traceback +import requests def upscaler_to_index(name: str): try: @@ -347,18 +348,42 @@ def invocations(self, req: InvocationsRequest): hypernetwork_s3uri = shared.cmd_opts.hypernetwork_s3uri try: + username = req.username + default_options = shared.opts.data + if username != '': + inputs = { + 'action': 'get', + 'username': username + } + api_endpoint = os.environ['api_endpoint'] + response = requests.post(url=f'{api_endpoint}/sd/user', json=inputs) + if response.status_code == 200 and response.text != '': + shared.opts.data = json.loads(response.text) + + self.download_s3files(hypernetwork_s3uri, os.path.join(script_path, shared.cmd_opts.hypernetwork_dir)) + hypernetworks.hypernetwork.load_hypernetwork(shared.opts.sd_hypernetwork) + hypernetworks.hypernetwork.apply_strength() + if req.task == 'text-to-image': self.download_s3files(embeddings_s3uri, os.path.join(script_path, shared.cmd_opts.embeddings_dir)) sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings() - return self.text2imgapi(req.txt2img_payload) + response = self.text2imgapi(req.txt2img_payload) + shared.opts.data = default_options + return response elif req.task == 'image-to-image': self.download_s3files(embeddings_s3uri, os.path.join(script_path, shared.cmd_opts.embeddings_dir)) sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings() - return self.img2imgapi(req.img2img_payload) + response = self.img2imgapi(req.img2img_payload) + shared.opts.data = default_options + return response elif req.task == 'extras-single-image': - return self.extras_single_image_api(req.extras_single_payload) + response = self.extras_single_image_api(req.extras_single_payload) + shared.opts.data = default_options + return response elif req.task == 'extras-batch-images': - return self.extras_batch_images_api(req.extras_batch_payload) + response = self.extras_batch_images_api(req.extras_batch_payload) + shared.opts.data = default_options + return response elif req.task == 'sd-models': return self.get_sd_models() else: diff --git a/modules/api/models.py b/modules/api/models.py index 89beb8ca526..9f1dd9800cf 100644 --- a/modules/api/models.py +++ b/modules/api/models.py @@ -242,6 +242,7 @@ class ArtistItem(BaseModel): class InvocationsRequest(BaseModel): task: str + username: Optional[str] txt2img_payload: Optional[StableDiffusionTxt2ImgProcessingAPI] img2img_payload: Optional[StableDiffusionImg2ImgProcessingAPI] extras_single_payload: Optional[ExtrasSingleImageRequest] diff --git a/modules/sd_models.py b/modules/sd_models.py index b6885078f4e..f8173c0e6b2 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -15,9 +15,6 @@ import 
requests import json -api_endpoint = os.environ['api_endpoint'] if 'api_endpoint' in os.environ else '' -endpoint_name = os.environ['endpoint_name'] if 'endpoint_name' in os.environ else '' - model_dir = "Stable-diffusion" model_path = os.path.abspath(os.path.join(models_path, model_dir)) @@ -53,18 +50,41 @@ def list_models(): checkpoints_list.clear() if shared.cmd_opts.pureui: - response = requests.get(url=f'{api_endpoint}/sd/models') - model_list = json.loads(response.text) - - for model in model_list: - h = model['hash'] - filename = model['filename'] - title = model['title'] - short_model_name = model['model_name'] - config = model['config'] - - checkpoints_list[title] = CheckpointInfo(filename, title, h, short_model_name, config) + api_endpoint = os.environ['api_endpoint'] if 'api_endpoint' in os.environ else '' + endpoint_name = os.environ['endpoint_name'] if 'endpoint_name' in os.environ else '' + class SDModel: + def __init__(self, sd_model_name, sd_model_hash, sd_model_checkpoint, sd_checkpoint_info): + self.sd_model_name = sd_model_name + self.sd_model_hash = sd_model_hash + self.sd_model_checkpoint = sd_model_checkpoint + self.sd_checkpoint_info = sd_checkpoint_info + response = requests.get(url=f'{api_endpoint}/sd/models') + if response.status_code == 200: + model_list = json.loads(response.text) + + for model in model_list: + h = model['hash'] + filename = model['filename'] + title = model['title'] + short_model_name = model['model_name'] + config = model['config'] + + if 'sd_model_checkpoint' not in shared.opts.data: + shared.opts.data['sd_model_checkpoint'] = title + + checkpoints_list[title] = CheckpointInfo(filename, title, h, short_model_name, config) + + sd_model_checkpoint = shared.opts.data['sd_model_checkpoint'] + sd_checkpoint_info = checkpoints_list[sd_model_checkpoint] + sd_model_name = checkpoints_list[sd_model_checkpoint].model_name + sd_model_hash = checkpoints_list[sd_model_checkpoint].hash + shared.sd_model = SDModel( + sd_model_name, + sd_model_hash, + sd_model_checkpoint, + sd_checkpoint_info + ) else: model_list = modelloader.load_models(model_path=model_path, command_path=shared.cmd_opts.ckpt_dir, ext_filter=[".ckpt"]) diff --git a/modules/shared.py b/modules/shared.py index 5b3be5f776a..657b4d246dd 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -98,8 +98,9 @@ parser.add_argument("--train-args", type=str, help='Train args', default='') parser.add_argument('--embeddings-s3uri', default='', type=str, help='Embedding S3Uri') parser.add_argument('--hypernetwork-s3uri', default='', type=str, help='Hypernetwork S3Uri') -parser.add_argument('--industrial-model', default='', type=str, help='Industrial Model') parser.add_argument('--region-name', type=str, help='Region Name') +parser.add_argument('--username', default='', type=str, help='Username') +parser.add_argument('--api-endpoint', default='', type=str, help='API Endpoint') cmd_opts = parser.parse_args() restricted_opts = { @@ -131,6 +132,14 @@ hypernetworks = hypernetwork.list_hypernetworks(cmd_opts.hypernetwork_dir) loaded_hypernetwork = None +if cmd_opts.pureui: + username = '' + api_endpoint = os.environ['api_endpoint'] + industrial_model = '' + endpoint_name = '' + endpoint_names = [] + default_options = {} + def reload_hypernetworks(): global hypernetworks @@ -472,10 +481,13 @@ def load(self, filename): print(f"The program is likely to not work with bad settings.\nSettings file: {filename}\nEither fix the file, or delete it and restart.", file=sys.stderr) if cmd_opts.pureui: - 
opts.show_progressbar = False - api_endpoint = os.environ['api_endpoint'] + global api_endpoint, industrial_model, default_options - if 'industrial_model' not in opts.data: + #opts.show_progressbar = False + response = requests.get(url=f'{api_endpoint}/sd/industrialmodel') + if response.status_code == 200: + industrial_model = response.text + else: model_name = 'stable-diffusion-webui' model_description = model_name inputs = { @@ -493,8 +505,8 @@ def load(self, filename): if response.status_code == 200: body = json.loads(response.text) industrial_model = body['id'] - opts.data['industrial_model'] = industrial_model - opts.save(config_filename) + + default_options = self.data def onchange(self, key, func, call=True): item = self.data_labels.get(key) @@ -534,8 +546,6 @@ def reorder(self): progress_print_out = sys.stdout -userid = '' - class TotalTQDM: def __init__(self): self._tqdm = None @@ -577,3 +587,18 @@ def clear(self): def listfiles(dirname): filenames = [os.path.join(dirname, x) for x in sorted(os.listdir(dirname)) if not x.startswith(".")] return [file for file in filenames if os.path.isfile(file)] + +if cmd_opts.pureui: + def init_endpoints(): + global endpoint_name, endpoint_names, industrial_model, api_endpoint + + endpoints = [] + params = { + 'industrial_model': industrial_model + } + response = requests.get(url=f'{api_endpoint}/endpoint', params=params) + if response.status_code == 200: + for endpoint_item in json.loads(response.text): + endpoints.append(endpoint_item['EndpointName']) + endpoint_name = endpoints[0] if len(endpoints) > 0 else '' + endpoint_names = endpoints \ No newline at end of file diff --git a/modules/ui.py b/modules/ui.py index 6b474fe3c4c..ae6ed75bdca 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -43,6 +43,23 @@ import modules.textual_inversion.ui import modules.hypernetworks.ui from modules.generation_parameters_copypaste import image_from_url_text +if cmd_opts.pureui: + import requests + training_instance_types = [ + 'ml.p2.xlarge', + 'ml.p2.8xlarge', + 'ml.p2.16xlarge', + 'ml.p3.2xlarge', + 'ml.p3.8xlarge', + 'ml.p3.16xlarge', + 'ml.g4dn.xlarge', + 'ml.g4dn.2xlarge', + 'ml.g4dn.4xlarge', + 'ml.g4dn.8xlarge', + 'ml.g4dn.12xlarge', + 'ml.g4dn.16xlarge' + ] + # this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the browser will not show any UI mimetypes.init() @@ -565,7 +582,17 @@ def apply_setting(key, value): if oldval != value and opts.data_labels[key].onchange is not None: opts.data_labels[key].onchange() - opts.save(shared.config_filename) + if cmd_opts.pureui: + if shared.username != '': + inputs = { + 'action': 'edit', + 'username': shared.username, + 'options': json.dumps(opts.data) + } + + response = requests.post(url=f'{shared.api_endpoint}/sd/user', json = inputs) + else: + opts.save(shared.config_filename) return value @@ -870,12 +897,13 @@ def create_ui(wrap_gradio_gpu_call): with gr.Row(): inpaint_full_res = gr.Checkbox(label='Inpaint at full resolution', value=False) inpaint_full_res_padding = gr.Slider(label='Inpaint at full resolution padding, pixels', minimum=0, maximum=256, step=4, value=32) - - with gr.TabItem('Batch img2img', id='batch'): - hidden = '
<br>Disabled when launched with --hide-ui-dir-config.' if shared.cmd_opts.hide_ui_dir_config else ''
-            gr.HTML(f"Process images in a directory on the same machine where the server is running.<br>Use an empty output directory to save pictures normally instead of writing to the output directory.{hidden}
") - img2img_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs) - img2img_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs) + + if not cmd_opts.pureui: + with gr.TabItem('Batch img2img', id='batch'): + hidden = '
<br>Disabled when launched with --hide-ui-dir-config.' if shared.cmd_opts.hide_ui_dir_config else ''
+                    gr.HTML(f"Process images in a directory on the same machine where the server is running.<br>Use an empty output directory to save pictures normally instead of writing to the output directory.{hidden}
") + img2img_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs) + img2img_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs) with gr.Row(): resize_mode = gr.Radio(label="Resize mode", elem_id="resize_mode", show_label=False, choices=["Just resize", "Crop and resize", "Resize and fill"], type="index", value="Just resize") @@ -967,8 +995,8 @@ def create_ui(wrap_gradio_gpu_call): inpaint_full_res, inpaint_full_res_padding, inpainting_mask_invert, - img2img_batch_input_dir, - img2img_batch_output_dir, + img2img_batch_input_dir if not cmd_opts.pureui else dummy_component, + img2img_batch_output_dir if not cmd_opts.pureui else dummy_component, ] + custom_inputs, outputs=[ img2img_gallery, @@ -1417,61 +1445,64 @@ def create_ui(wrap_gradio_gpu_call): with gr.Box(): gr.HTML(value="
Embedding settings
") - new_embedding_name = gr.Textbox(label="Embedding name") + new_embedding_name = gr.Textbox(label="Name") initialization_text = gr.Textbox(label="Initialization text", value="*") nvpt = gr.Slider(label="Number of vectors per token", minimum=1, maximum=75, step=1, value=1) overwrite_old_embedding = gr.Checkbox(value=False, label="Overwrite Old Embedding") with gr.Box(): - gr.HTML(value="
Image preprocessing settings
") + gr.HTML(value="
Image preprocess settings
") - images_s3uri = gr.Textbox(label='Images S3 URI') - process_width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512) - process_height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512) - preprocess_txt_action = gr.Dropdown(label='Existing Caption txt Action', value="ignore", choices=["ignore", "copy", "prepend", "append"]) + embedding_images_s3uri = gr.Textbox(label='Images S3 URI') + embedding_models_s3uri = gr.Textbox(label='Models S3 URI') + embedding_process_width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512) + embedding_process_height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512) + embedding_preprocess_txt_action = gr.Dropdown(label='Existing Caption txt Action', value="ignore", choices=["ignore", "copy", "prepend", "append"]) with gr.Row(): - process_flip = gr.Checkbox(label='Create flipped copies') - process_split = gr.Checkbox(label='Split oversized images') - process_focal_crop = gr.Checkbox(label='Auto focal point crop') - process_caption = gr.Checkbox(label='Use BLIP for caption') - process_caption_deepbooru = gr.Checkbox(label='Use deepbooru for caption', visible=True if cmd_opts.deepdanbooru else False) + embedding_process_flip = gr.Checkbox(label='Create flipped copies') + embedding_process_split = gr.Checkbox(label='Split oversized images') + embedding_process_focal_crop = gr.Checkbox(label='Auto focal point crop') + embedding_process_caption = gr.Checkbox(label='Use BLIP for caption') + embedding_process_caption_deepbooru = gr.Checkbox(label='Use deepbooru for caption', visible=True if cmd_opts.deepdanbooru else False) with gr.Row(visible=False) as process_split_extra_row: - process_split_threshold = gr.Slider(label='Split image threshold', value=0.5, minimum=0.0, maximum=1.0, step=0.05) - process_overlap_ratio = gr.Slider(label='Split image overlap ratio', value=0.2, minimum=0.0, maximum=0.9, step=0.05) + embedding_process_split_threshold = gr.Slider(label='Split image threshold', value=0.5, minimum=0.0, maximum=1.0, step=0.05) + embedding_process_overlap_ratio = gr.Slider(label='Split image overlap ratio', value=0.2, minimum=0.0, maximum=0.9, step=0.05) with gr.Row(visible=False) as process_focal_crop_row: - process_focal_crop_face_weight = gr.Slider(label='Focal point face weight', value=0.9, minimum=0.0, maximum=1.0, step=0.05) - process_focal_crop_entropy_weight = gr.Slider(label='Focal point entropy weight', value=0.15, minimum=0.0, maximum=1.0, step=0.05) - process_focal_crop_edges_weight = gr.Slider(label='Focal point edges weight', value=0.5, minimum=0.0, maximum=1.0, step=0.05) - process_focal_crop_debug = gr.Checkbox(label='Create debug image') + embedding_process_focal_crop_face_weight = gr.Slider(label='Focal point face weight', value=0.9, minimum=0.0, maximum=1.0, step=0.05) + embedding_process_focal_crop_entropy_weight = gr.Slider(label='Focal point entropy weight', value=0.15, minimum=0.0, maximum=1.0, step=0.05) + embedding_process_focal_crop_edges_weight = gr.Slider(label='Focal point edges weight', value=0.5, minimum=0.0, maximum=1.0, step=0.05) + embedding_process_focal_crop_debug = gr.Checkbox(label='Create debug image') with gr.Box(): - gr.HTML(value="
Train embedding settings
") + gr.HTML(value="
Train settings
") + with gr.Row(): + with gr.Column(): + embedding_training_instance_type = gr.Dropdown(label='Instance type', value="ml.g4dn.xlarge", choices=training_instance_types) + with gr.Column(): + embedding_training_instance_count = gr.Number(label='Instance count', value=1, precision=0) + with gr.Row(): embedding_learn_rate = gr.Textbox(label='Embedding Learning rate', placeholder="Embedding Learning rate", value="0.005") - hypernetwork_learn_rate = gr.Textbox(label='Hypernetwork Learning rate', placeholder="Hypernetwork Learning rate", value="0.00001") - batch_size = gr.Number(label='Batch size', value=1, precision=0) - dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images") - log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion") - template_file = gr.Textbox(label='Prompt template file', value=os.path.join(script_path, "textual_inversion_templates", "style_filewords.txt")) - training_width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512) - training_height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512) - steps = gr.Number(label='Max steps', value=100000, precision=0) - create_image_every = gr.Number(label='Save an image to log directory every N steps, 0 to disable', value=500, precision=0) - save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable', value=500, precision=0) - save_image_with_stored_embedding = gr.Checkbox(label='Save images with embedding in PNG chunks', value=True) - preview_from_txt2img = gr.Checkbox(label='Read parameters (prompt, etc...) from txt2img tab when making previews', value=False) + embedding_batch_size = gr.Number(label='Batch size', value=1, precision=0) + embedding_training_width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512) + embedding_training_height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512) + embedding_steps = gr.Number(label='Max steps', value=100000, precision=0) + embedding_create_image_every = gr.Number(label='Save an image to log directory every N steps, 0 to disable', value=500, precision=0) + embedding_save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable', value=500, precision=0) + embedding_save_image_with_stored_embedding = gr.Checkbox(label='Save images with embedding in PNG chunks', value=True) + embedding_preview_from_txt2img = gr.Checkbox(label='Read parameters (prompt, etc...) from txt2img tab when making previews', value=False) with gr.Row(): with gr.Column(scale=3): - gr.HTML(value="") + embedding_output = gr.Label(label='Output') with gr.Column(): - create_train_embedding = gr.Button(value="Create & train embedding", variant='primary', visible=False) + create_train_embedding = gr.Button(value="Train Embedding", variant='primary', visible=False) with gr.Tab(label="Create & Train Hypernetwork"): gr.HTML(value="
Train a hypernetwork; you must specify a directory with a set of 1:1 ratio images [wiki]
") @@ -1479,7 +1510,7 @@ def create_ui(wrap_gradio_gpu_call): with gr.Box(): gr.HTML(value="
Hypernetwork settings
") - new_hypernetwork_name = gr.Textbox(label="Hypernetwork name") + new_hypernetwork_name = gr.Textbox(label="Name") new_hypernetwork_sizes = gr.CheckboxGroup(label="Modules", value=["768", "320", "640", "1280"], choices=["768", "320", "640", "1280"]) new_hypernetwork_layer_structure = gr.Textbox("1, 2, 1", label="Enter hypernetwork layer structure", placeholder="1st and last digit must be 1. ex:'1, 2, 1'") new_hypernetwork_activation_func = gr.Dropdown(value="linear", label="Select activation function of hypernetwork. Recommended : Swish / Linear(none)", choices=modules.hypernetworks.ui.keys) @@ -1489,55 +1520,356 @@ def create_ui(wrap_gradio_gpu_call): overwrite_old_hypernetwork = gr.Checkbox(value=False, label="Overwrite Old Hypernetwork") with gr.Box(): - gr.HTML(value="
Image preprocessing settings
") + gr.HTML(value="
Image preprocess settings
") - images_s3uir = gr.Textbox(label='Images S3 URI') - process_width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512) - process_height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512) - preprocess_txt_action = gr.Dropdown(label='Existing Caption txt Action', value="ignore", choices=["ignore", "copy", "prepend", "append"]) + hypernetwork_images_s3uir = gr.Textbox(label='Images S3 URI') + hypernetwork_models_s3uri = gr.Textbox(label='Models S3 URI') + hypernetwork_process_width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512) + hypernetwork_process_height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512) + hypernetwork_preprocess_txt_action = gr.Dropdown(label='Existing Caption txt Action', value="ignore", choices=["ignore", "copy", "prepend", "append"]) with gr.Row(): - process_flip = gr.Checkbox(label='Create flipped copies') - process_split = gr.Checkbox(label='Split oversized images') - process_focal_crop = gr.Checkbox(label='Auto focal point crop') - process_caption = gr.Checkbox(label='Use BLIP for caption') - process_caption_deepbooru = gr.Checkbox(label='Use deepbooru for caption', visible=True if cmd_opts.deepdanbooru else False) + hypernetwork_process_flip = gr.Checkbox(label='Create flipped copies') + hypernetwork_process_split = gr.Checkbox(label='Split oversized images') + hypernetwork_process_focal_crop = gr.Checkbox(label='Auto focal point crop') + hypernetwork_process_caption = gr.Checkbox(label='Use BLIP for caption') + hypernetwork_process_caption_deepbooru = gr.Checkbox(label='Use deepbooru for caption', visible=True if cmd_opts.deepdanbooru else False) with gr.Row(visible=False) as process_split_extra_row: - process_split_threshold = gr.Slider(label='Split image threshold', value=0.5, minimum=0.0, maximum=1.0, step=0.05) - process_overlap_ratio = gr.Slider(label='Split image overlap ratio', value=0.2, minimum=0.0, maximum=0.9, step=0.05) + hypernetwork_process_split_threshold = gr.Slider(label='Split image threshold', value=0.5, minimum=0.0, maximum=1.0, step=0.05) + hypernetwork_process_overlap_ratio = gr.Slider(label='Split image overlap ratio', value=0.2, minimum=0.0, maximum=0.9, step=0.05) with gr.Row(visible=False) as process_focal_crop_row: - process_focal_crop_face_weight = gr.Slider(label='Focal point face weight', value=0.9, minimum=0.0, maximum=1.0, step=0.05) - process_focal_crop_entropy_weight = gr.Slider(label='Focal point entropy weight', value=0.15, minimum=0.0, maximum=1.0, step=0.05) - process_focal_crop_edges_weight = gr.Slider(label='Focal point edges weight', value=0.5, minimum=0.0, maximum=1.0, step=0.05) - process_focal_crop_debug = gr.Checkbox(label='Create debug image') + hypernetwork_process_focal_crop_face_weight = gr.Slider(label='Focal point face weight', value=0.9, minimum=0.0, maximum=1.0, step=0.05) + hypernetwork_process_focal_crop_entropy_weight = gr.Slider(label='Focal point entropy weight', value=0.15, minimum=0.0, maximum=1.0, step=0.05) + hypernetwork_process_focal_crop_edges_weight = gr.Slider(label='Focal point edges weight', value=0.5, minimum=0.0, maximum=1.0, step=0.05) + hypernetwork_process_focal_crop_debug = gr.Checkbox(label='Create debug image') with gr.Box(): - gr.HTML(value="
Train hypernetwork settings
") + gr.HTML(value="
Train settings
") with gr.Row(): - embedding_learn_rate = gr.Textbox(label='Embedding Learning rate', placeholder="Embedding Learning rate", value="0.005") + with gr.Column(): + hypernetwork_training_instance_type = gr.Dropdown(label='Instance type', value="ml.g4dn.xlarge", choices=training_instance_types) + with gr.Column(): + hypernetwork_training_instance_count = gr.Number(label='Instance count', value=1, precision=0) + + with gr.Row(): hypernetwork_learn_rate = gr.Textbox(label='Hypernetwork Learning rate', placeholder="Hypernetwork Learning rate", value="0.00001") - batch_size = gr.Number(label='Batch size', value=1, precision=0) - dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images") - log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion") - template_file = gr.Textbox(label='Prompt template file', value=os.path.join(script_path, "textual_inversion_templates", "style_filewords.txt")) - training_width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512) - training_height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512) - steps = gr.Number(label='Max steps', value=100000, precision=0) - create_image_every = gr.Number(label='Save an image to log directory every N steps, 0 to disable', value=500, precision=0) - save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable', value=500, precision=0) - save_image_with_stored_embedding = gr.Checkbox(label='Save images with embedding in PNG chunks', value=True) - preview_from_txt2img = gr.Checkbox(label='Read parameters (prompt, etc...) from txt2img tab when making previews', value=False) + hypernetwork_batch_size = gr.Number(label='Batch size', value=1, precision=0) + hypernetwork_training_width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512) + hypernetwork_training_height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512) + hypernetwork_steps = gr.Number(label='Max steps', value=100000, precision=0) + hypernetwork_create_image_every = gr.Number(label='Save an image to log directory every N steps, 0 to disable', value=500, precision=0) + hypernetwork_save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable', value=500, precision=0) + hypernetwork_save_image_with_stored_embedding = gr.Checkbox(label='Save images with embedding in PNG chunks', value=True) + hypernetwork_preview_from_txt2img = gr.Checkbox(label='Read parameters (prompt, etc...) 
from txt2img tab when making previews', value=False) with gr.Row(): with gr.Column(scale=3): - gr.HTML(value="") + hypernetwork_output = gr.Label(label='Output') with gr.Column(): - create_train_hypernetwork = gr.Button(value="Create & train hypernetwork", variant='primary', visible=False) + create_train_hypernetwork = gr.Button(value="Train Hypernetwork", variant='primary', visible=False) + + def sagemaker_train_embedding( + new_embedding_name, + initialization_text, + nvpt, + overwrite_old_embedding, + embedding_images_s3uri, + embedding_models_s3uri, + embedding_process_width, + embedding_process_height, + embedding_preprocess_txt_action, + embedding_process_flip, + embedding_process_split, + embedding_process_focal_crop, + embedding_process_caption, + embedding_process_caption_deepbooru, + embedding_process_split_threshold, + embedding_process_overlap_ratio, + embedding_process_focal_crop_face_weight, + embedding_process_focal_crop_entropy_weight, + embedding_process_focal_crop_edges_weight, + embedding_process_focal_crop_debug, + embedding_learn_rate, + embedding_batch_size, + embedding_training_width, + embedding_training_height, + embedding_steps, + embedding_create_image_every, + embedding_save_embedding_every, + embedding_save_image_with_stored_embedding, + embedding_preview_from_txt2img, + embedding_training_instance_type, + embedding_training_instance_count, + *txt2img_preview_params + ): + + train_args = { + 'embedding_settings': { + 'name': new_embedding_name, + 'nvpt': nvpt, + 'overwrite_old': overwrite_old_embedding, + 'initialization_text': initialization_text + }, + 'images_preprocessing_settings': { + 'process_width': embedding_process_width, + 'process_height': embedding_process_height, + 'preprocess_txt_action': embedding_preprocess_txt_action, + 'process_flip': embedding_process_flip, + 'process_split': embedding_process_split, + 'process_caption': embedding_process_caption, + 'process_caption_deepbooru': embedding_process_caption_deepbooru, + 'process_split_threshold': embedding_process_split_threshold, + 'process_overlap_ratio': embedding_process_overlap_ratio, + 'process_focal_crop': embedding_process_focal_crop, + 'process_focal_crop_face_weight': embedding_process_focal_crop_face_weight, + 'process_focal_crop_entropy_weight': embedding_process_focal_crop_entropy_weight, + 'process_focal_crop_debug': embedding_process_focal_crop_debug + }, + 'train_embedding_settings':{ + 'learn_rate': embedding_learn_rate, + 'batch_size': embedding_batch_size, + 'training_width': embedding_training_width, + 'training_height': embedding_training_height, + 'steps': embedding_steps, + 'create_image_every': embedding_create_image_every, + 'save_embedding_every': embedding_save_embedding_every, + 'save_image_with_stored_embedding': embedding_save_image_with_stored_embedding, + 'preview_from_txt2img': embedding_preview_from_txt2img, + 'txt2img_preview_params': txt2img_preview_params + } + } + + hyperparameters = { + 'train-args': json.dumps(json.dumps(train_args)), + 'train-task': 'embedding', + 'ckpt': '/opt/ml/input/data/models/{0}'.format(shared.sd_model.sd_model_name), + 'username': shared.username, + 'api-endpoint': shared.api_endpoint + } + + inputs = { + 'images': embedding_images_s3uri, + 'models': embedding_models_s3uri + } + + data = { + 'training_job_name': '', + 'model_algorithm': 'stable-diffusion-webui', + 'model_hyperparameters': hyperparameters, + 'industrial_model': shared.industrial_model, + 'instance_type': embedding_training_instance_type, + 'instance_count': 
embedding_training_instance_count, + 'inputs': inputs + } + + response = requests.post(url=f'{shared.api_endpoint}/train', json=data) + if response.status_code == 200: + return { + embedding_output: gr.update(value='Submit training job sucessful') + } + else: + return { + embedding_output: gr.update(value=response.text) + } + + def sagemaker_train_hypernetwork( + new_hypernetwork_name, + new_hypernetwork_sizes, + new_hypernetwork_layer_structure, + new_hypernetwork_activation_func, + new_hypernetwork_initialization_option, + new_hypernetwork_add_layer_norm, + new_hypernetwork_use_dropout, + overwrite_old_hypernetwork, + hypernetwork_images_s3uri, + hypernetwork_models_s3uri, + hypernetwork_process_width, + hypernetwork_process_height, + hypernetwork_preprocess_txt_action, + hypernetwork_process_flip, + hypernetwork_process_split, + hypernetwork_process_focal_crop, + hypernetwork_process_caption, + hypernetwork_process_caption_deepbooru, + hypernetwork_process_split_threshold, + hypernetwork_process_overlap_ratio, + hypernetwork_process_focal_crop_face_weight, + hypernetwork_process_focal_crop_entropy_weight, + hypernetwork_process_focal_crop_edges_weight, + hypernetwork_process_focal_crop_debug, + hypernetwork_learn_rate, + hypernetwork_batch_size, + hypernetwork_training_width, + hypernetwork_training_height, + hypernetwork_steps, + hypernetwork_create_image_every, + hypernetwork_save_embedding_every, + hypernetwork_save_image_with_stored_embedding, + hypernetwork_preview_from_txt2img, + hypernetwork_training_instance_type, + hypernetwork_training_instance_count, + *txt2img_preview_params + ): + + train_args = { + 'hypernetwork_settings': { + 'name': new_hypernetwork_name, + 'enable_sizes': new_hypernetwork_sizes, + 'overwrite_old': overwrite_old_hypernetwork, + 'layer_structure': new_hypernetwork_layer_structure, + 'activation_func': new_hypernetwork_activation_func, + 'weight_init': new_hypernetwork_initialization_option, + 'new_hypernetwork_add_layer_norm': new_hypernetwork_add_layer_norm, + 'new_hypernetwork_use_dropout': new_hypernetwork_use_dropout, + }, + 'images_preprocessing_settings': { + 'process_width': hypernetwork_process_width, + 'process_height': hypernetwork_process_height, + 'preprocess_txt_action': hypernetwork_preprocess_txt_action, + 'process_flip': hypernetwork_process_flip, + 'process_split': hypernetwork_process_split, + 'process_caption': hypernetwork_process_caption, + 'process_caption_deepbooru': hypernetwork_process_caption_deepbooru, + 'process_split_threshold': hypernetwork_process_split_threshold, + 'process_overlap_ratio': hypernetwork_process_overlap_ratio, + 'process_focal_crop': hypernetwork_process_focal_crop, + 'process_focal_crop_face_weight': hypernetwork_process_focal_crop_face_weight, + 'process_focal_crop_entropy_weight': hypernetwork_process_focal_crop_entropy_weight, + 'process_focal_crop_debug': hypernetwork_process_focal_crop_debug + }, + 'train_hypernetwork_settings':{ + 'learn_rate': hypernetwork_learn_rate, + 'batch_size': hypernetwork_batch_size, + 'training_width': hypernetwork_training_width, + 'training_height': hypernetwork_training_height, + 'steps': hypernetwork_steps, + 'create_image_every': hypernetwork_create_image_every, + 'save_embedding_every': hypernetwork_save_embedding_every, + 'save_image_with_stored_embedding': hypernetwork_save_image_with_stored_embedding, + 'preview_from_txt2img': hypernetwork_preview_from_txt2img, + 'txt2img_preview_params': txt2img_preview_params + } + } + + hyperparameters = { + 'train-args': 
json.dumps(json.dumps(train_args)), + 'train-task': 'hypernetwork', + 'ckpt': '/opt/ml/input/data/models/{0}'.format(shared.sd_model.sd_model_name), + 'username': shared.username, + 'api-endpoint': shared.api_endpoint + } + + inputs = { + 'images': hypernetwork_images_s3uri, + 'models': hypernetwork_models_s3uri + } + + data = { + 'training_job_name': '', + 'model_algorithm': 'stable-diffusion-webui', + 'model_hyperparameters': hyperparameters, + 'industrial_model': shared.industrial_model, + 'instance_type': hypernetwork_training_instance_type, + 'instance_count': hypernetwork_training_instance_count, + 'inputs': inputs + } + + response = requests.post(url=f'{shared.api_endpoint}/train', json=data) + if response.status_code == 200: + return { + hypernetwork_output: gr.update(value='Submit training job sucessful') + } + else: + return { + hypernetwork_output: gr.update(value=response.text) + } + + create_train_embedding.click( + fn=sagemaker_train_embedding, + inputs=[ + new_embedding_name, + initialization_text, + nvpt, + overwrite_old_embedding, + embedding_images_s3uri, + embedding_models_s3uri, + embedding_process_width, + embedding_process_height, + embedding_preprocess_txt_action, + embedding_process_flip, + embedding_process_split, + embedding_process_focal_crop, + embedding_process_caption, + embedding_process_caption_deepbooru, + embedding_process_split_threshold, + embedding_process_overlap_ratio, + embedding_process_focal_crop_face_weight, + embedding_process_focal_crop_entropy_weight, + embedding_process_focal_crop_edges_weight, + embedding_process_focal_crop_debug, + embedding_learn_rate, + embedding_batch_size, + embedding_training_width, + embedding_training_height, + embedding_steps, + embedding_create_image_every, + embedding_save_embedding_every, + embedding_save_image_with_stored_embedding, + embedding_preview_from_txt2img, + embedding_training_instance_type, + embedding_training_instance_count, + *txt2img_preview_params + ], + outputs=[embedding_output] + ) + + create_train_hypernetwork.click( + fn=sagemaker_train_hypernetwork, + inputs=[ + new_hypernetwork_name, + new_hypernetwork_sizes, + new_hypernetwork_layer_structure, + new_hypernetwork_activation_func, + new_hypernetwork_initialization_option, + new_hypernetwork_add_layer_norm, + new_hypernetwork_use_dropout, + overwrite_old_hypernetwork, + hypernetwork_images_s3uir, + hypernetwork_models_s3uri, + hypernetwork_process_width, + hypernetwork_process_height, + hypernetwork_preprocess_txt_action, + hypernetwork_process_flip, + hypernetwork_process_split, + hypernetwork_process_focal_crop, + hypernetwork_process_caption, + hypernetwork_process_caption_deepbooru, + hypernetwork_process_split_threshold, + hypernetwork_process_overlap_ratio, + hypernetwork_process_focal_crop_face_weight, + hypernetwork_process_focal_crop_entropy_weight, + hypernetwork_process_focal_crop_edges_weight, + hypernetwork_process_focal_crop_debug, + hypernetwork_learn_rate, + hypernetwork_batch_size, + hypernetwork_training_width, + hypernetwork_training_height, + hypernetwork_steps, + hypernetwork_create_image_every, + hypernetwork_save_embedding_every, + hypernetwork_save_image_with_stored_embedding, + hypernetwork_preview_from_txt2img, + hypernetwork_training_instance_type, + hypernetwork_training_instance_count, + *txt2img_preview_params + ], + outputs=[hypernetwork_output] + ) def create_setting_component(key, is_quicksettings=False): def fun(): @@ -1601,7 +1933,20 @@ def run_settings(*args): changed.append(key) try: - 
opts.save(shared.config_filename) + if cmd_opts.pureui: + if shared.username != '': + inputs = { + 'action': 'edit', + 'username': shared.username, + 'options': json.dumps(opts.data) + } + + response = requests.post(url=f'{shared.api_endpoint}/sd/user', json=inputs) + if response.status_code != 200: + raise RuntimeError("Settings saved failed") + else: + opts.save(shared.config_filename) + except RuntimeError: return opts.dumpjson(), f'{len(changed)} settings changed without save: {", ".join(changed)}.' return opts.dumpjson(), f'{len(changed)} settings changed: {", ".join(changed)}.' @@ -1625,98 +1970,6 @@ def run_settings_single(value, key): return gr.update(value=value), opts.dumpjson() with gr.Blocks(analytics_enabled=False) as settings_interface: - if cmd_opts.pureui: - def change_sign_options(choice): - return { - signin_column: gr.update(visible=(choice=="Sign In")), - signup_column: gr.update(visible=(choice=="Sign Up")) - } - - with gr.Row(visible=(shared.userid!='')) as user_id_row: - with gr.Column(): - login_userid = gr.Label(label="User ID") - login_username = gr.Textbox(label="User Name") - - with gr.Column(): - signout = gr.Button("Sign Out") - edit_user = gr.Button("Edit User") - delete_user = gr.Button("Delete User") - - with gr.Row(visible=(shared.userid=='')) as user_sign_row: - with gr.Column(): - sign_options = gr.Radio(["Sign In", "Sign Up"], label="Sign Options", value="Sign In", interactive=True) - - with gr.Column(visible=(sign_options.value=="Sign In")) as signin_column: - signin_username = gr.Textbox(label="User Name") - signin_password = gr.Textbox(label="Password") - signin = gr.Button("Sign In") - - with gr.Column(visible=(sign_options.value=="Sign Up")) as signup_column: - signup_username = gr.Textbox(label="User Name") - signup_password = gr.Textbox(label="Password") - signup_email = gr.Textbox(label="Email") - signup = gr.Button("Sign Up") - - sign_options.change(change_sign_options, sign_options, [signin_column, signup_column]) - - def user_signin(signin_username, signin_password): - shared.userid='1234' - return { - user_id_row : gr.update(visible=True), - user_sign_row: gr.update(visible=False), - login_userid: gr.update(value=shared.userid), - login_username: gr.update(value=signin_username), - txt2img_submit: gr.update(visible=True), - img2img_submit: gr.update(visible=True), - extras_submit: gr.update(visible=True), - create_train_embedding: gr.update(visible=True), - create_train_hypernetwork: gr.update(visible=True) - } - - def user_signup(signup_username, signup_password, signup_email): - shared.userid='5678' - return { - user_id_row: gr.update(visible=True), - user_sign_row: gr.update(visible=False), - login_userid: gr.update(value=shared.userid), - login_username: gr.update(value=signup_username), - txt2img_submit: gr.update(visible=True), - img2img_submit: gr.update(visible=True), - extras_submit: gr.update(visible=True), - create_train_embedding: gr.update(visible=True), - create_train_hypernetwork: gr.update(visible=True) - } - - def user_signout(): - shared.userid='' - return { - user_id_row : gr.update(visible=False), - user_sign_row: gr.update(visible=True), - txt2img_submit: gr.update(visible=False), - img2img_submit: gr.update(visible=False), - extras_submit: gr.update(visible=True), - create_train_embedding: gr.update(visible=False), - create_train_hypernetwork: gr.update(visible=False) - } - - signin.click( - fn=user_signin, - inputs=[signin_username, signin_password], - outputs=[user_id_row, user_sign_row, login_userid, login_username, 
txt2img_submit, img2img_submit, extras_submit, create_train_embedding, create_train_hypernetwork] - ) - - signup.click( - fn=user_signup, - inputs=[signup_username, signup_password, signup_email], - outputs=[user_id_row, user_sign_row, login_userid, login_username, txt2img_submit, img2img_submit, extras_submit, create_train_embedding, create_train_hypernetwork] - ) - - signout.click( - fn=user_signout, - inputs=[], - outputs=[user_id_row, user_sign_row, txt2img_submit, img2img_submit, extras_submit, create_train_embedding, create_train_hypernetwork] - ) - settings_submit = gr.Button(value="Apply settings", variant='primary') result = gr.HTML() @@ -1809,6 +2062,228 @@ def request_restart(): if column is not None: column.__exit__() + with gr.Blocks(analytics_enabled=False) as user_interface: + def change_sign_options(choice): + return { + signin_column: gr.update(visible=(choice=="Sign In")), + signup_column: gr.update(visible=(choice=="Sign Up")) + } + + with gr.Row(visible=(shared.username!='')) as user_login_row: + with gr.Column(): + login_username = gr.Text(label="Username") + login_password = gr.Text(label="Password", type="password") + login_email = gr.Text(label="Email", type="email") + + with gr.Column(): + signout = gr.Button("Sign Out") + userupdate = gr.Button("Update") + userdelete = gr.Button("Delete") + login_output = gr.Label(label="Output") + + with gr.Row(visible=(shared.username=='')) as user_sign_row: + with gr.Column(): + sign_options = gr.Radio(["Sign In", "Sign Up"], label="Sign Options", value="Sign In", interactive=True) + + with gr.Column(visible=(sign_options.value=="Sign In")) as signin_column: + signin_username = gr.Textbox(label="Username") + signin_password = gr.Textbox(label="Password", type="password") + signin_output = gr.Label(label="Output") + signin = gr.Button("Sign In") + + with gr.Column(visible=(sign_options.value=="Sign Up")) as signup_column: + signup_username = gr.Textbox(label="Username") + signup_password = gr.Textbox(label="Password", type="password") + signup_email = gr.Textbox(label="Email", type="email") + signup_output = gr.Label(label="Output") + signup = gr.Button("Sign Up") + + sign_options.change(change_sign_options, sign_options, [signin_column, signup_column]) + + def user_signin(signin_username, signin_password): + inputs = { + 'action': 'signin', + 'username': signin_username, + 'password': signin_password + } + + response = requests.post(url=f'{shared.api_endpoint}/sd/user', json=inputs) + if response.status_code == 200: + shared.username = json.loads(response.text)['username'] + password = json.loads(response.text)['password'] + email = json.loads(response.text)['email'] + options = json.loads(response.text)['options'] if 'options' in json.loads(response.text) else None + + response = { + user_login_row : gr.update(visible=True), + user_sign_row: gr.update(visible=False), + login_username: gr.update(value=signin_username), + login_password: gr.update(value=password), + login_email: gr.update(value=email), + txt2img_submit: gr.update(visible=True), + img2img_submit: gr.update(visible=True), + extras_submit: gr.update(visible=True), + create_train_embedding: gr.update(visible=True), + create_train_hypernetwork: gr.update(visible=True), + signin_output: gr.update(value='') + } + + if options != None: + opts.data = json.loads(options) + for key in opts.data: + if key in component_dict: + response[component_dict[key]] = gr.update(value=opts.data[key]) + for key in sd_models.checkpoints_list: + if sd_models.checkpoints_list[key].title == 
opts.data['sd_model_checkpoint']: + shared.sd_model.sd_model_name = sd_models.checkpoints_list[key].model_name + break + return response + else: + return { + signin_output: gr.update(value='Mismatched username/password or not existed username') + } + + def user_signup(signup_username, signup_password, signup_email): + inputs = { + 'action': 'signup', + 'username': signup_username, + 'password': signup_password, + 'email': signup_email + } + + response = requests.post(url=f'{shared.api_endpoint}/sd/user', json=inputs) + if response.status_code == 200: + shared.username = json.loads(response.text)['username'] + + return { + user_login_row: gr.update(visible=True), + user_sign_row: gr.update(visible=False), + login_username: gr.update(value=signup_username), + login_password: gr.update(value=signup_password), + login_email: gr.update(value=signup_email), + txt2img_submit: gr.update(visible=True), + img2img_submit: gr.update(visible=True), + extras_submit: gr.update(visible=True), + create_train_embedding: gr.update(visible=True), + create_train_hypernetwork: gr.update(visible=True), + signup_output: gr.update(value='') + } + else: + return { + signup_output: gr.update(value='Signup failed, please check and retry again') + } + + def user_signout(): + shared.username='' + opts.data = shared.default_options + for key in sd_models.checkpoints_list: + if sd_models.checkpoints_list[key].title == opts.data['sd_model_checkpoint']: + shared.sd_model.sd_model_name = sd_models.checkpoints_list[key].model_name + break + + response = { + user_login_row : gr.update(visible=False), + user_sign_row: gr.update(visible=True), + txt2img_submit: gr.update(visible=False), + img2img_submit: gr.update(visible=False), + extras_submit: gr.update(visible=True), + create_train_embedding: gr.update(visible=False), + create_train_hypernetwork: gr.update(visible=False) + } + + for key in opts.data: + if key in component_dict: + response[component_dict[key]] = gr.update(value=opts.data[key]) + + return response + + def user_update(login_username, login_password, login_email): + inputs = { + 'action': 'edit', + 'username': login_username, + 'password': login_password, + 'email': login_email + } + + response = requests.post(url=f'{shared.api_endpoint}/sd/user', json=inputs) + if response.status_code == 200: + shared.username = json.loads(response.text)['username'] + + return { + login_output: gr.update(value='Update succeed') + } + else: + return { + login_output: gr.update(value='Update failed, please check and retry again') + } + + def user_delete(login_username, login_password, login_email): + inputs = { + 'action': 'delete', + 'username': login_username, + 'password': login_password + } + + response = requests.post(url=f'{shared.api_endpoint}/sd/user', json=inputs) + if response.status_code == 200: + shared.username = json.loads(response.text)['username'] + opts.data = shared.default_options + for key in sd_models.checkpoints_list: + if sd_models.checkpoints_list[key].title == opts.data['sd_model_checkpoint']: + shared.sd_model.sd_model_name = sd_models.checkpoints_list[key].model_name + break + + response = { + user_login_row : gr.update(visible=False), + user_sign_row: gr.update(visible=True), + txt2img_submit: gr.update(visible=False), + img2img_submit: gr.update(visible=False), + extras_submit: gr.update(visible=True), + create_train_embedding: gr.update(visible=False), + create_train_hypernetwork: gr.update(visible=False), + login_output: gr.update(value='') + } + + for key in opts.data: + if key in component_dict: 
+ response[component_dict[key]] = gr.update(value=opts.data[key]) + + return response + else: + return { + login_output: gr.update(value='Delete failed, please check and retry again') + } + + signin.click( + fn=user_signin, + inputs=[signin_username, signin_password], + outputs=[user_login_row, user_sign_row, login_username, login_password, login_email,txt2img_submit, img2img_submit, extras_submit, create_train_embedding, create_train_hypernetwork, signin_output] + components + ) + + signup.click( + fn=user_signup, + inputs=[signup_username, signup_password, signup_email], + outputs=[user_login_row, user_sign_row, login_username, login_password, login_email, txt2img_submit, img2img_submit, extras_submit, create_train_embedding, create_train_hypernetwork, signup_output] + ) + + signout.click( + fn=user_signout, + inputs=[], + outputs=[user_login_row, user_sign_row, txt2img_submit, img2img_submit, extras_submit, create_train_embedding, create_train_hypernetwork] + components + ) + + userupdate.click( + fn=user_update, + inputs=[login_username, login_password, login_email], + outputs=[user_login_row, user_sign_row, txt2img_submit, img2img_submit, extras_submit, create_train_embedding, create_train_hypernetwork, login_output] + ) + + userdelete.click( + fn=user_delete, + inputs=[login_username, login_password, login_email], + outputs=[user_login_row, user_sign_row, txt2img_submit, img2img_submit, extras_submit, create_train_embedding, create_train_hypernetwork, login_output] + components + ) + if cmd_opts.pureui: interfaces = [ (txt2img_interface, "txt2img", "txt2img"), @@ -1816,6 +2291,7 @@ def request_restart(): (extras_interface, "Extras", "extras"), (pnginfo_interface, "PNG Info", "pnginfo"), (train_interface, "Train", "ti"), + (user_interface, "User", "user") ] else: interfaces = [ @@ -1847,14 +2323,44 @@ def request_restart(): interfaces += [(settings_interface, "Settings", "settings")] extensions_interface = ui_extensions.create_ui() - interfaces += [(extensions_interface, "Extensions", "extensions")] - + interfaces += [(extensions_interface, "Extensions", "extensions")] + with gr.Blocks(css=css, analytics_enabled=False, title="Stable Diffusion") as demo: with gr.Row(elem_id="quicksettings"): for i, k, item in quicksettings_list: component = create_setting_component(k, is_quicksettings=True) component_dict[k] = component + if cmd_opts.pureui: + shared.init_endpoints() + + with gr.Row(): + with gr.Column(scale=9): + endpoint_names = gr.Dropdown(label='SageMaker endpoint', value=shared.endpoint_name, choices=shared.endpoint_names) + with gr.Column(scale=1): + endpoint_refresh = gr.Button(refresh_symbol) + + def refresh_endpoint(): + shared.init_endpoints() + return { + endpoint_names: gr.update(value=shared.endpoint_name, choices=shared.endpoint_names) + } + + def change_endpoint(endpoint_names): + shared.endpoint_name = endpoint_names + + endpoint_names.change( + fn=change_endpoint, + inputs=[endpoint_names], + outputs=[] + ) + + endpoint_refresh.click( + fn=refresh_endpoint, + inputs=[], + outputs=[endpoint_names] + ) + parameters_copypaste.integrate_settings_paste_fields(component_dict) parameters_copypaste.run_bind() diff --git a/requirements.txt.cn b/requirements.txt.cn new file mode 100644 index 00000000000..ad7fc684661 --- /dev/null +++ b/requirements.txt.cn @@ -0,0 +1,31 @@ +-i https://pypi.tuna.tsinghua.edu.cn/simple +basicsr +diffusers +fairscale==0.4.4 +fonts +font-roboto +gfpgan +gradio==3.9 +invisible-watermark +numpy +omegaconf +opencv-python +requests +piexif +Pillow 
+pytorch_lightning==1.7.7 +realesrgan +scikit-image>=0.19 +timm==0.4.12 +transformers==4.19.2 +torch +einops +jsonmerge +clean-fid +resize-right +torchdiffeq +kornia +lark +inflection +GitPython +boto3 diff --git a/requirements_versions.txt.cn b/requirements_versions.txt.cn new file mode 100644 index 00000000000..a002fc7ebb9 --- /dev/null +++ b/requirements_versions.txt.cn @@ -0,0 +1,28 @@ +-i https://pypi.tuna.tsinghua.edu.cn/simple +transformers==4.19.2 +diffusers==0.3.0 +basicsr==1.4.2 +gfpgan==1.3.8 +gradio==3.9 +numpy==1.23.3 +Pillow==9.2.0 +realesrgan==0.3.0 +torch +omegaconf==2.2.3 +pytorch_lightning==1.7.6 +scikit-image==0.19.2 +fonts +font-roboto +timm==0.6.7 +fairscale==0.4.9 +piexif==1.1.3 +einops==0.4.1 +jsonmerge==1.8.0 +clean-fid==0.1.29 +resize-right==0.0.2 +torchdiffeq==0.2.3 +kornia==0.6.7 +lark==1.1.2 +inflection==0.5.1 +GitPython==3.1.27 +boto3 diff --git a/webui.py b/webui.py index d7b319f037c..d9bf0b962b5 100644 --- a/webui.py +++ b/webui.py @@ -47,9 +47,6 @@ queue_lock = threading.Lock() server_name = "0.0.0.0" if cmd_opts.listen else cmd_opts.server_name -api_endpoint = os.environ['api_endpoint'] if 'api_endpoint' in os.environ else '' -endpoint_name = os.environ['endpoint_name'] if 'endpoint_name' in os.environ else '' - import boto3 import traceback from botocore.exceptions import ClientError @@ -88,7 +85,11 @@ def handle_sagemaker_inference_async(response): params = {'s3uri': s3uri} start = time.time() while True: - response = requests.get(url=f'{api_endpoint}/s3', params = params) + if shared.state.interrupted or shared.state.skipped: + shared.job_count = 0 + return None + + response = requests.get(url=f'{shared.api_endpoint}/s3', params = params) text = json.loads(response.text) if text['count'] > 0: @@ -166,7 +167,8 @@ def sagemaker_inference(task, infer, *args, **kwargs): } inputs = { 'task': task, - 'txt2img_payload': payload + 'txt2img_payload': payload, + 'username': shared.username } else: mode = args[0] @@ -279,19 +281,23 @@ def sagemaker_inference(task, infer, *args, **kwargs): } inputs = { 'task': task, - 'img2img_payload': payload + 'img2img_payload': payload, + 'username': shared.username } params = { - 'endpoint_name': endpoint_name + 'endpoint_name': shared.endpoint_name } - response = requests.post(url=f'{api_endpoint}/inference', params=params, json=inputs) + response = requests.post(url=f'{shared.api_endpoint}/inference', params=params, json=inputs) if infer == 'async': processed = handle_sagemaker_inference_async(response) else: processed = json.loads(response.text) + if processed == None: + return [], "", "" + images = [] for image in processed['images']: images.append(Image.open(io.BytesIO(base64.b64decode(image)))) @@ -341,7 +347,8 @@ def sagemaker_inference(task, infer, *args, **kwargs): task = 'extras-single-image' inputs = { 'task': task, - 'extras_single_payload': payload + 'extras_single_payload': payload, + 'username': shared.username } else: imageList = [] @@ -372,13 +379,14 @@ def sagemaker_inference(task, infer, *args, **kwargs): task = 'extras-batch-images' inputs = { 'task': task, - 'extras_batch_payload': payload + 'extras_batch_payload': payload, + 'username': shared.username } params = { - 'endpoint_name': endpoint_name + 'endpoint_name': shared.endpoint_name } - response = requests.post(url=f'{api_endpoint}/inference', params=params, json=inputs) + response = requests.post(url=f'{shared.api_endpoint}/inference', params=params, json=inputs) if infer == 'async': processed = handle_sagemaker_inference_async(response) else: @@ 
-572,7 +580,23 @@ def train(): train_args = json.loads(cmd_opts.train_args) embeddings_s3uri = cmd_opts.embeddings_s3uri - hypernetworks_s3uri = cmd_opts.hypernetworks_s3uri + hypernetwork_s3uri = cmd_opts.hypernetwork_s3uri + api_endpoint = cmd_opts.api_endpoint + username = cmd_opts.username + + default_options = opts.data + if username != '': + inputs = { + 'action': 'get', + 'username': username + } + response = requests.post(url=f'{api_endpoint}/sd/user', json=inputs) + if response.status_code == 200 and response.text != '': + opts.data = json.loads(response.text) + for key in modules.sd_models.checkpoints_list: + if modules.sd_models.checkpoints_list[key].title == opts.data['sd_model_checkpoint']: + shared.sd_model.sd_model_name = modules.sd_models.checkpoints_list[key].model_name + break if train_task == 'embedding': name = train_args['embedding_settings']['name'] @@ -622,9 +646,9 @@ def train(): process_focal_crop_debug, ) train_embedding_name = name - embedding_learn_rate = train_args['train_embedding_settings']['embedding_learn_rate'] + learn_rate = train_args['train_embedding_settings']['learn_rate'] batch_size = train_args['train_embedding_settings']['batch_size'] - dataset_directory = process_dst + data_root = process_dst log_directory = 'textual_inversion' training_width = train_args['train_embedding_settings']['training_width'] training_height = train_args['train_embedding_settings']['training_height'] @@ -637,9 +661,9 @@ def train(): txt2img_preview_params = train_args['train_embedding_settings']['txt2img_preview_params'] _, filename = modules.textual_inversion.textual_inversion.train_embedding( train_embedding_name, - embedding_learn_rate, + learn_rate, batch_size, - dataset_directory, + data_root, log_directory, training_width, training_height, @@ -656,6 +680,7 @@ def train(): except Exception as e: traceback.print_exc() print(e) + opts.data = default_options elif train_task == 'hypernetwork': name = train_args['hypernetwork_settings']['name'] enable_sizes = train_args['hypernetwork_settings']['enable_sizes'] @@ -668,7 +693,7 @@ def train(): name = "".join( x for x in name if (x.isalnum() or x in "._- ")) - fn = os.path.join(shared.cmd_opts.hypernetwork_dir, f"{name}.pt") + fn = os.path.join(cmd_opts.hypernetwork_dir, f"{name}.pt") if not overwrite_old: assert not os.path.exists(fn), f"file {fn} already exists" @@ -721,22 +746,22 @@ def train(): process_focal_crop_debug, ) train_hypernetwork_name = name - embedding_learn_rate = train_args['train_embedding_settings']['embedding_learn_rate'] - batch_size = train_args['train_embedding_settings']['batch_size'] + learn_rate = train_args['train_hypernetwork_settings']['learn_rate'] + batch_size = train_args['train_hypernetwork_settings']['batch_size'] dataset_directory = process_dst log_directory = 'textual_inversion' - training_width = train_args['train_embedding_settings']['training_width'] - training_height = train_args['train_embedding_settings']['training_height'] - steps = train_args['train_embedding_settings']['steps'] - create_image_every = train_args['train_embedding_settings']['create_image_every'] - save_embedding_every = train_args['train_embedding_settings']['save_embedding_every'] - template_file = 'style_filewords.txt' - save_image_with_stored_embedding = train_args['train_embedding_settings']['save_image_with_stored_embedding'] - preview_from_txt2img = train_args['train_embedding_settings']['preview_from_txt2img'] - txt2img_preview_params = train_args['train_embedding_settings']['txt2img_preview_params'] - _, 
filename = modules.textual_inversion.textual_inversion.train_embedding( + training_width = train_args['train_hypernetwork_settings']['training_width'] + training_height = train_args['train_hypernetwork_settings']['training_height'] + steps = train_args['train_hypernetwork_settings']['steps'] + create_image_every = train_args['train_hypernetwork_settings']['create_image_every'] + save_hypernetwork_every = train_args['train_hypernetwork_settings']['save_embedding_every'] + template_file = os.path.join(script_path, "textual_inversion_templates", "style_filewords.txt") + save_image_with_stored_embedding = train_args['train_hypernetwork_settings']['save_image_with_stored_embedding'] + preview_from_txt2img = train_args['train_hypernetwork_settings']['preview_from_txt2img'] + txt2img_preview_params = train_args['train_hypernetwork_settings']['txt2img_preview_params'] + _, filename = modules.hypernetworks.hypernetwork.train_hypernetwork( train_hypernetwork_name, - embedding_learn_rate, + learn_rate, batch_size, dataset_directory, log_directory, @@ -744,17 +769,17 @@ def train(): training_height, steps, create_image_every, - save_embedding_every, + save_hypernetwork_every, template_file, - save_image_with_stored_embedding, preview_from_txt2img, *txt2img_preview_params ) try: - upload_s3file(hypernetworks_s3uri, os.path.join(cmd_opts.hypernetwork_dir, '{0}.pt'.format(train_hypernetwork_name)), '{0}.pt'.format(train_hypernetwork_name)) + upload_s3file(hypernetwork_s3uri, os.path.join(cmd_opts.hypernetwork_dir, '{0}.pt'.format(train_hypernetwork_name)), '{0}.pt'.format(train_hypernetwork_name)) except Exception as e: traceback.print_exc() print(e) + opts.data = default_options else: print('Incorrect training task') exit(-1)
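
Usage sketch (illustrative): a hypothetical client call against the new /inference route with the per-user `username` field this patch adds to InvocationsRequest, mirroring the payload that webui.py builds. The API endpoint URL, SageMaker endpoint name, and username below are placeholder values, not part of the patch.

    # Minimal client-side sketch; values marked "placeholder" are assumptions.
    import requests

    api_endpoint = 'https://example.execute-api.us-east-1.amazonaws.com/prod'  # placeholder API Gateway URL
    inputs = {
        'task': 'text-to-image',
        'username': 'alice',  # placeholder; invocations() uses it to load per-user options
        'txt2img_payload': {'prompt': 'a photo of a cat', 'steps': 20},
    }
    response = requests.post(
        url=f'{api_endpoint}/inference',
        params={'endpoint_name': 'sd-webui-endpoint'},  # placeholder SageMaker endpoint name
        json=inputs,
    )
    print(response.status_code, response.text[:200])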