diff --git a/launch.py b/launch.py
index 5fa115606f9..484f4a79481 100644
--- a/launch.py
+++ b/launch.py
@@ -241,7 +241,9 @@ def tests(argv):
 def start():
     print(f"Launching {'API server' if '--nowebui' in sys.argv else 'Web UI'} with arguments: {' '.join(sys.argv[1:])}")
     import webui
-    if '--nowebui' in sys.argv:
+    if '--train' in sys.argv:
+        webui.train()
+    elif '--nowebui' in sys.argv:
         webui.api_only()
     else:
         webui.webui()
diff --git a/modules/api/api.py b/modules/api/api.py
index 00e2b30825b..53f03a93f22 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -14,6 +14,13 @@ from modules.sd_models import checkpoints_list
 from modules.realesrgan_model import get_realesrgan_models
 from typing import List
 
+from modules.paths import script_path
+import json
+import os
+import boto3
+from modules import sd_hijack
+from typing import Union
+import traceback
 
 def upscaler_to_index(name: str):
     try:
@@ -77,8 +84,11 @@ def __init__(self, app: FastAPI, queue_lock: Lock):
         self.app.add_api_route("/sdapi/v1/prompt-styles", self.get_promp_styles, methods=["GET"], response_model=List[PromptStyleItem])
         self.app.add_api_route("/sdapi/v1/artist-categories", self.get_artists_categories, methods=["GET"], response_model=List[str])
         self.app.add_api_route("/sdapi/v1/artists", self.get_artists, methods=["GET"], response_model=List[ArtistItem])
-        self.app.add_api_route("/invocations", self.invocations, methods=["POST"], response_model=InvocationsResponse)
+        self.app.add_api_route("/invocations", self.invocations, methods=["POST"], response_model=Union[TextToImageResponse, ImageToImageResponse, ExtrasSingleImageResponse, ExtrasBatchImagesResponse, List[SDModelItem]])
         self.app.add_api_route("/ping", self.ping, methods=["GET"], response_model=PingResponse)
+
+        self.cache = dict()
+        self.s3_client = boto3.client('s3')
+        self.s3_resource = boto3.resource('s3')
 
     def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI):
         sampler_index = sampler_to_index(txt2imgreq.sampler_index)
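Reviewer context: `/invocations` and `/ping` are the two routes SageMaker hosting requires of a bring-your-own inference container, which is why they sit alongside the regular `/sdapi/v1/*` API. A minimal client-side sketch of calling the deployed endpoint (the endpoint name and payload values are placeholders, not part of this patch):

```python
import json
import boto3

runtime = boto3.client('sagemaker-runtime')

response = runtime.invoke_endpoint(
    EndpointName='stable-diffusion-webui',  # placeholder endpoint name
    ContentType='application/json',
    Body=json.dumps({
        'task': 'text-to-image',
        'txt2img_payload': {'prompt': 'a photo of an astronaut', 'steps': 20},
    }),
)
print(json.loads(response['Body'].read()))
```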
@@ -166,10 +176,8 @@ def extras_single_image_api(self, req: ExtrasSingleImageRequest):
         reqDict = setUpscalers(req)
 
         reqDict['image'] = decode_base64_to_image(reqDict['image'])
-
         with self.queue_lock:
             result = run_extras(extras_mode=0, image_folder="", input_dir="", output_dir="", **reqDict)
-
         return ExtrasSingleImageResponse(image=encode_pil_to_base64(result[0][0]), html_info=result[1])
 
     def extras_batch_images_api(self, req: ExtrasBatchImagesRequest):
@@ -183,9 +191,9 @@ def prepareFiles(file):
         reqDict['image_folder'] = list(map(prepareFiles, reqDict['imageList']))
         reqDict.pop('imageList')
+
         with self.queue_lock:
             result = run_extras(extras_mode=1, image="", input_dir="", output_dir="", **reqDict)
-
         return ExtrasBatchImagesResponse(images=list(map(encode_pil_to_base64, result[0])), html_info=result[1])
 
     def pnginfoapi(self, req: PNGInfoRequest):
@@ -306,15 +314,60 @@ def get_artists_categories(self):
     def get_artists(self):
         return [{"name":x[0], "score":x[1], "category":x[2]} for x in shared.artist_db.artists]
 
+    def download_s3files(self, s3uri, path):
+        pos = s3uri.find('/', 5)
+        bucket = s3uri[5 : pos]
+        key = s3uri[pos + 1 : ]
+
+        s3_bucket = self.s3_resource.Bucket(bucket)
+        objs = list(s3_bucket.objects.filter(Prefix=key))
+
+        if os.path.isfile('cache'):
+            self.cache = json.load(open('cache', 'r'))
+
+        for obj in objs:
+            response = self.s3_client.head_object(
+                Bucket = bucket,
+                Key = obj.key
+            )
+            obj_key = 's3://{0}/{1}'.format(bucket, obj.key)
+            if obj_key not in self.cache or self.cache[obj_key] != response['ETag']:
+                filename = obj.key[obj.key.rfind('/') + 1 : ]
+
+                self.s3_client.download_file(bucket, obj.key, os.path.join(path, filename))
+                self.cache[obj_key] = response['ETag']
+
+        json.dump(self.cache, open('cache', 'w'))
+
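Note on `download_s3files()`: the bucket/key split via `s3uri.find('/', 5)` works but is easy to misread; `urllib.parse` expresses the same thing. A sketch of the alternative (the helper name and example URI are ours, not in the patch):

```python
from urllib.parse import urlparse

def split_s3uri(s3uri: str):
    # s3://bucket/prefix/ -> ('bucket', 'prefix/')
    parsed = urlparse(s3uri)
    assert parsed.scheme == 's3', f'not an S3 URI: {s3uri}'
    return parsed.netloc, parsed.path.lstrip('/')

bucket, key = split_s3uri('s3://my-bucket/embeddings/')  # hypothetical URI
```

One caveat worth keeping in mind: for multipart uploads the ETag is not an MD5 of the object body, but it does change whenever the object changes, which is all this change-detection cache relies on.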
     def invocations(self, req: InvocationsRequest):
-        if req.task == 'text-to-image':
-            return self.text2imgapi(req.payload)
-        elif req.task == 'image-to-image':
-            return self.img2imgapi(req.payload)
-        else:
-            raise NotImplementedError
+        print('-------invocation------')
+        print(req)
+
+        embeddings_s3uri = shared.cmd_opts.embeddings_s3uri
+        hypernetwork_s3uri = shared.cmd_opts.hypernetwork_s3uri
+
+        try:
+            if req.task == 'text-to-image':
+                self.download_s3files(embeddings_s3uri, os.path.join(script_path, shared.cmd_opts.embeddings_dir))
+                sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings()
+                return self.text2imgapi(req.txt2img_payload)
+            elif req.task == 'image-to-image':
+                self.download_s3files(embeddings_s3uri, os.path.join(script_path, shared.cmd_opts.embeddings_dir))
+                sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings()
+                return self.img2imgapi(req.img2img_payload)
+            elif req.task == 'extras-single-image':
+                return self.extras_single_image_api(req.extras_single_payload)
+            elif req.task == 'extras-batch-images':
+                return self.extras_batch_images_api(req.extras_batch_payload)
+            elif req.task == 'sd-models':
+                return self.get_sd_models()
+            else:
+                raise NotImplementedError
+        except Exception as e:
+            traceback.print_exc()
 
     def ping(self):
+        print('-------ping------')
         return {'status': 'Healthy'}
 
     def launch(self, server_name, port):
diff --git a/modules/api/models.py b/modules/api/models.py
index 6f1433de6bd..89beb8ca526 100644
--- a/modules/api/models.py
+++ b/modules/api/models.py
@@ -6,7 +6,7 @@
 from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img
 from modules.shared import sd_upscalers, opts, parser
 from typing import Dict, List
-from typing import Union
+from typing import Optional
 
 API_NOT_ALLOWED = [
     "self",
@@ -142,7 +142,7 @@ class ExtrasSingleImageRequest(ExtrasBaseRequest):
     image: str = Field(default="", title="Image", description="Image to work on, must be a Base64 string containing the image's data.")
 
 class ExtrasSingleImageResponse(ExtraBaseResponse):
-    image: str = Field(default=None, title="Image", description="The generated image in base64 format.")
+    image: str = Field(title="Image", description="The generated image in base64 format.")
 
 class FileData(BaseModel):
     data: str = Field(title="File data", description="Base64 representation of the file")
@@ -242,12 +242,10 @@ class ArtistItem(BaseModel):
 
 class InvocationsRequest(BaseModel):
     task: str
-    payload: Union[StableDiffusionTxt2ImgProcessingAPI, StableDiffusionImg2ImgProcessingAPI]
-
-class InvocationsResponse(BaseModel):
-    images: list[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
-    parameters: dict
-    info: str
+    txt2img_payload: Optional[StableDiffusionTxt2ImgProcessingAPI]
+    img2img_payload: Optional[StableDiffusionImg2ImgProcessingAPI]
+    extras_single_payload: Optional[ExtrasSingleImageRequest]
+    extras_batch_payload: Optional[ExtrasBatchImagesRequest]
 
 class PingResponse(BaseModel):
     status: str
\ No newline at end of file
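For reference, the reworked `InvocationsRequest` means an `/invocations` body sets exactly one of the `Optional` payload fields, selected by `task`. Illustrative examples (values are made up):

```python
txt2img_request = {
    'task': 'text-to-image',
    'txt2img_payload': {'prompt': 'a red barn', 'steps': 20, 'batch_size': 1},
}

extras_request = {
    'task': 'extras-single-image',
    'extras_single_payload': {'image': '<base64 image data>', 'upscaling_resize': 2},
}

model_list_request = {'task': 'sd-models'}  # no payload; routed to get_sd_models()
```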
diff --git a/modules/processing.py b/modules/processing.py
index b5de9f161d9..4a1848d7563 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -243,6 +243,9 @@ def __init__(self, p: StableDiffusionProcessing, images_list, seed=-1, info="",
         self.all_subseeds = all_subseeds or [self.subseed]
         self.infotexts = infotexts or [info]
 
+        self.scripts = p.scripts
+        self.script_args = p.script_args
+
     def js(self):
         obj = {
             "prompt": self.prompt,
@@ -472,10 +475,8 @@ def infotext(iteration=0, position_in_batch=0):
     if os.path.exists(cmd_opts.embeddings_dir) and not p.do_not_reload_embeddings:
         model_hijack.embedding_db.load_textual_inversion_embeddings()
 
-
     if p.scripts is not None:
         p.scripts.process(p)
-
     infotexts = []
     output_images = []
@@ -609,6 +610,8 @@ def __init__(self, enable_hr: bool=False, denoising_strength: float=0.75, firstp
         self.firstphase_height = firstphase_height
         self.truncate_x = 0
         self.truncate_y = 0
+        self.scripts = modules.scripts.scripts_txt2img
+        self.scripts.setup_scripts(False)
 
     def init(self, all_prompts, all_seeds, all_subseeds):
         if self.enable_hr:
@@ -740,6 +743,8 @@ def __init__(self, init_images: list=None, resize_mode: int=0, denoising_strengt
         self.mask = None
         self.nmask = None
         self.image_conditioning = None
+        self.scripts = modules.scripts.scripts_img2img
+        self.scripts.setup_scripts(True)
 
     def init(self, all_prompts, all_seeds, all_subseeds):
         self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers_for_img2img, self.sampler_index, self.sd_model)
diff --git a/modules/sd_models.py b/modules/sd_models.py
index 34c57bfa73a..b6885078f4e 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -12,6 +12,11 @@
 from modules import shared, modelloader, devices, script_callbacks, sd_vae
 from modules.paths import models_path
 from modules.sd_hijack_inpainting import do_inpainting_hijack, should_hijack_inpainting
+import requests
+import json
+
+api_endpoint = os.environ['api_endpoint'] if 'api_endpoint' in os.environ else ''
+endpoint_name = os.environ['endpoint_name'] if 'endpoint_name' in os.environ else ''
 
 model_dir = "Stable-diffusion"
 model_path = os.path.abspath(os.path.join(models_path, model_dir))
@@ -44,44 +49,61 @@ def checkpoint_tiles():
 
 def list_models():
+    global checkpoints_list
     checkpoints_list.clear()
-    model_list = modelloader.load_models(model_path=model_path, command_path=shared.cmd_opts.ckpt_dir, ext_filter=[".ckpt"])
-
-    def modeltitle(path, shorthash):
-        abspath = os.path.abspath(path)
-
-        if shared.cmd_opts.ckpt_dir is not None and abspath.startswith(shared.cmd_opts.ckpt_dir):
-            name = abspath.replace(shared.cmd_opts.ckpt_dir, '')
-        elif abspath.startswith(model_path):
-            name = abspath.replace(model_path, '')
-        else:
-            name = os.path.basename(path)
-
-        if name.startswith("\\") or name.startswith("/"):
-            name = name[1:]
-
-        shortname = os.path.splitext(name.replace("/", "_").replace("\\", "_"))[0]
-
-        return f'{name} [{shorthash}]', shortname
-
-    cmd_ckpt = shared.cmd_opts.ckpt
-    if os.path.exists(cmd_ckpt):
-        h = model_hash(cmd_ckpt)
-        title, short_model_name = modeltitle(cmd_ckpt, h)
-        checkpoints_list[title] = CheckpointInfo(cmd_ckpt, title, h, short_model_name, shared.cmd_opts.config)
-        shared.opts.data['sd_model_checkpoint'] = title
-    elif cmd_ckpt is not None and cmd_ckpt != shared.default_sd_model_file:
-        print(f"Checkpoint in --ckpt argument not found (Possible it was moved to {model_path}: {cmd_ckpt}", file=sys.stderr)
-    for filename in model_list:
-        h = model_hash(filename)
-        title, short_model_name = modeltitle(filename, h)
-
-        basename, _ = os.path.splitext(filename)
-        config = basename + ".yaml"
-        if not os.path.exists(config):
-            config = shared.cmd_opts.config
-
-        checkpoints_list[title] = CheckpointInfo(filename, title, h, short_model_name, config)
+
+    if shared.cmd_opts.pureui:
+        response = requests.get(url=f'{api_endpoint}/sd/models')
+        model_list = json.loads(response.text)
+
+        for model in model_list:
+            h = model['hash']
+            filename = model['filename']
+            title = model['title']
+            short_model_name = model['model_name']
+            config = model['config']
+
+            checkpoints_list[title] = CheckpointInfo(filename, title, h, short_model_name, config)
+
+    else:
+        model_list = modelloader.load_models(model_path=model_path, command_path=shared.cmd_opts.ckpt_dir, ext_filter=[".ckpt"])
+
+        def modeltitle(path, shorthash):
+            abspath = os.path.abspath(path)
+
+            if shared.cmd_opts.ckpt_dir is not None and abspath.startswith(shared.cmd_opts.ckpt_dir):
+                name = abspath.replace(shared.cmd_opts.ckpt_dir, '')
+            elif abspath.startswith(model_path):
+                name = abspath.replace(model_path, '')
+            else:
+                name = os.path.basename(path)
+
+            if name.startswith("\\") or name.startswith("/"):
+                name = name[1:]
+
+            shortname = os.path.splitext(name.replace("/", "_").replace("\\", "_"))[0]
+
+            return f'{name} [{shorthash}]', shortname
+
+        cmd_ckpt = shared.cmd_opts.ckpt
+        if os.path.exists(cmd_ckpt):
+            h = model_hash(cmd_ckpt)
+            title, short_model_name = modeltitle(cmd_ckpt, h)
+            checkpoints_list[title] = CheckpointInfo(cmd_ckpt, title, h, short_model_name, shared.cmd_opts.config)
+            shared.opts.data['sd_model_checkpoint'] = title
+        elif cmd_ckpt is not None and cmd_ckpt != shared.default_sd_model_file:
+            print(f"Checkpoint in --ckpt argument not found (Possible it was moved to {model_path}: {cmd_ckpt}", file=sys.stderr)
+
+        for filename in model_list:
+            h = model_hash(filename)
+            title, short_model_name = modeltitle(filename, h)
+
+            basename, _ = os.path.splitext(filename)
+            config = basename + ".yaml"
+            if not os.path.exists(config):
+                config = shared.cmd_opts.config
+
+            checkpoints_list[title] = CheckpointInfo(filename, title, h, short_model_name, config)
 
 def get_closet_checkpoint_match(searchString):
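The `pureui` branch of `list_models()` assumes `GET {api_endpoint}/sd/models` returns a JSON list whose elements carry the five fields read in the loop above. An illustrative element (values invented for the example):

```python
[
    {
        'title': 'v1-5-pruned-emaonly.ckpt [81761151]',
        'model_name': 'v1-5-pruned-emaonly',
        'hash': '81761151',
        'filename': '/opt/ml/model/Stable-diffusion/v1-5-pruned-emaonly.ckpt',
        'config': '/opt/ml/model/Stable-diffusion/v1-5-pruned-emaonly.yaml',
    },
]
```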
diff --git a/modules/shared.py b/modules/shared.py
index 91869947953..5b3be5f776a 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -19,6 +19,8 @@
 from modules.hypernetworks import hypernetwork
 from modules.paths import models_path, script_path, sd_path
 
+import requests
+
 sd_model_file = os.path.join(script_path, 'model.ckpt')
 default_sd_model_file = sd_model_file
 parser = argparse.ArgumentParser()
@@ -94,6 +96,10 @@
 parser.add_argument("--train", action='store_true', help="Train only on SageMaker", default=False)
 parser.add_argument("--train-task", type=str, help='Train task - embedding or hypernetwork', default='embedding')
 parser.add_argument("--train-args", type=str, help='Train args', default='')
+parser.add_argument('--embeddings-s3uri', default='', type=str, help='Embedding S3Uri')
+parser.add_argument('--hypernetwork-s3uri', default='', type=str, help='Hypernetwork S3Uri')
+parser.add_argument('--industrial-model', default='', type=str, help='Industrial Model')
+parser.add_argument('--region-name', type=str, help='Region Name')
 
 cmd_opts = parser.parse_args()
 restricted_opts = {
@@ -253,6 +259,25 @@ def options_section(section_identifier, options_dict):
 
 options_templates = {}
 
+options_templates.update(options_section(('sd', "Stable Diffusion"), {
+    "sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, refresh=sd_models.list_models),
+    "sd_checkpoint_cache": OptionInfo(0, "Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
+    "sd_vae": OptionInfo("auto", "SD VAE", gr.Dropdown, lambda: {"choices": list(sd_vae.vae_list)}, refresh=sd_vae.refresh_vae_list),
+    "sd_hypernetwork": OptionInfo("None", "Hypernetwork", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in hypernetworks.keys()]}, refresh=reload_hypernetworks),
+    "sd_hypernetwork_strength": OptionInfo(1.0, "Hypernetwork strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.001}),
+    "inpainting_mask_weight": OptionInfo(1.0, "Inpainting conditioning mask strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
+    "img2img_color_correction": OptionInfo(False, "Apply color correction to img2img results to match original colors."),
+    "img2img_fix_steps": OptionInfo(False, "With img2img, do exactly the amount of steps the slider specifies (normally you'd do less with less denoising)."),
+    "enable_quantization": OptionInfo(False, "Enable quantization in K samplers for sharper and cleaner results. This may change existing seeds. Requires restart to apply."),
+    "enable_emphasis": OptionInfo(True, "Emphasis: use (text) to make model pay more attention to text and [text] to make it pay less attention"),
+    "use_old_emphasis_implementation": OptionInfo(False, "Use old emphasis implementation. Can be useful to reproduce old seeds."),
+    "enable_batch_seeds": OptionInfo(True, "Make K-diffusion samplers produce same images in a batch as when making a single image"),
+    "comma_padding_backtrack": OptionInfo(20, "Increase coherency by padding from the last comma within n tokens when using more than 75 tokens", gr.Slider, {"minimum": 0, "maximum": 74, "step": 1 }),
+    "filter_nsfw": OptionInfo(False, "Filter NSFW content"),
+    'CLIP_stop_at_last_layers': OptionInfo(1, "Stop At last layers of CLIP model", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}),
+    "random_artist_categories": OptionInfo([], "Allowed categories for random artists selection when using the Roll button", gr.CheckboxGroup, {"choices": artist_db.categories()}),
+}))
+
 options_templates.update(options_section(('saving-images', "Saving images/grids"), {
     "samples_save": OptionInfo(True, "Always save all generated images"),
     "samples_format": OptionInfo('png', 'File format for images'),
@@ -331,25 +356,6 @@
     "training_xattention_optimizations": OptionInfo(False, "Use cross attention optimizations while training"),
 }))
 
-options_templates.update(options_section(('sd', "Stable Diffusion"), {
-    "sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, refresh=sd_models.list_models),
-    "sd_checkpoint_cache": OptionInfo(0, "Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
-    "sd_vae": OptionInfo("auto", "SD VAE", gr.Dropdown, lambda: {"choices": list(sd_vae.vae_list)}, refresh=sd_vae.refresh_vae_list),
-    "sd_hypernetwork": OptionInfo("None", "Hypernetwork", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in hypernetworks.keys()]}, refresh=reload_hypernetworks),
-    "sd_hypernetwork_strength": OptionInfo(1.0, "Hypernetwork strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.001}),
-    "inpainting_mask_weight": OptionInfo(1.0, "Inpainting conditioning mask strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
-    "img2img_color_correction": OptionInfo(False, "Apply color correction to img2img results to match original colors."),
-    "img2img_fix_steps": OptionInfo(False, "With img2img, do exactly the amount of steps the slider specifies (normally you'd do less with less denoising)."),
-    "enable_quantization": OptionInfo(False, "Enable quantization in K samplers for sharper and cleaner results. This may change existing seeds. Requires restart to apply."),
-    "enable_emphasis": OptionInfo(True, "Emphasis: use (text) to make model pay more attention to text and [text] to make it pay less attention"),
-    "use_old_emphasis_implementation": OptionInfo(False, "Use old emphasis implementation. Can be useful to reproduce old seeds."),
-    "enable_batch_seeds": OptionInfo(True, "Make K-diffusion samplers produce same images in a batch as when making a single image"),
-    "comma_padding_backtrack": OptionInfo(20, "Increase coherency by padding from the last comma within n tokens when using more than 75 tokens", gr.Slider, {"minimum": 0, "maximum": 74, "step": 1 }),
-    "filter_nsfw": OptionInfo(False, "Filter NSFW content"),
-    'CLIP_stop_at_last_layers': OptionInfo(1, "Stop At last layers of CLIP model", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}),
-    "random_artist_categories": OptionInfo([], "Allowed categories for random artists selection when using the Roll button", gr.CheckboxGroup, {"choices": artist_db.categories()}),
-}))
-
 options_templates.update(options_section(('interrogate', "Interrogate Options"), {
     "interrogate_keep_models_in_memory": OptionInfo(False, "Interrogate: keep models in VRAM"),
     "interrogate_use_builtin_artists": OptionInfo(True, "Interrogate: use artists from artists.csv"),
@@ -465,6 +471,31 @@ def load(self, filename):
         if bad_settings > 0:
             print(f"The program is likely to not work with bad settings.\nSettings file: {filename}\nEither fix the file, or delete it and restart.", file=sys.stderr)
 
+        if cmd_opts.pureui:
+            opts.show_progressbar = False
+            api_endpoint = os.environ['api_endpoint']
+
+            if 'industrial_model' not in opts.data:
+                model_name = 'stable-diffusion-webui'
+                model_description = model_name
+                inputs = {
+                    'model_algorithm': 'stable-diffusion-webui',
+                    'model_name': model_name,
+                    'model_description': model_description,
+                    'model_extra': '{"visible": "false"}',
+                    'model_samples': '',
+                    'file_content': {
+                        'data': list(open(os.path.join(script_path, 'logo.ico'), 'rb').read())  # bytes already iterate as ints in Python 3
+                    }
+                }
+
+                response = requests.post(url=f'{api_endpoint}/industrialmodel', json = inputs)
+                if response.status_code == 200:
+                    body = json.loads(response.text)
+                    industrial_model = body['id']
+                    opts.data['industrial_model'] = industrial_model
+                    opts.save(config_filename)
+
     def onchange(self, key, func, call=True):
         item = self.data_labels.get(key)
         item.onchange = func
@@ -494,8 +525,6 @@ def reorder(self):
 opts = Options()
 if os.path.exists(config_filename):
     opts.load(config_filename)
-if cmd_opts.pureui:
-    opts.show_progressbar = False
 
 sd_upscalers = []
@@ -505,6 +534,7 @@ def reorder(self):
 
 progress_print_out = sys.stdout
 
+userid = ''
 
 class TotalTQDM:
     def __init__(self):
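Since `--train-args` is an opaque JSON string, here is a partial sketch of an embedding-training payload built from the keys that `train()` in webui.py reads (only keys visible in this patch are shown; values are illustrative):

```python
import json

train_args = json.dumps({
    'embedding_settings': {
        'name': 'my-style',  # embedding to create
        'nvpt': 8,           # vectors per token
    },
    'images_preprocessing_settings': {
        'process_width': 512,
    },
    'train_embedding_settings': {
        'embedding_learn_rate': '0.005',
        'batch_size': 1,
        'steps': 3000,
        'create_image_every': 500,
        'save_embedding_every': 500,
        'save_image_with_stored_embedding': True,
        'preview_from_txt2img': False,
        'txt2img_preview_params': [],
    },
})
# passed on the command line as: --train --train-task embedding --train-args "$train_args"
```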
diff --git a/modules/ui.py b/modules/ui.py
index d998e3b0b51..6b474fe3c4c 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -429,20 +429,23 @@ def copy_seed(gen_info_string: str, index):
 
 def update_token_counter(text, steps):
-    try:
-        _, prompt_flat_list, _ = prompt_parser.get_multicond_prompt_list([text])
-        prompt_schedules = prompt_parser.get_learned_conditioning_prompt_schedules(prompt_flat_list, steps)
-
-    except Exception:
-        # a parsing error can happen here during typing, and we don't want to bother the user with
-        # messages related to it in console
-        prompt_schedules = [[[steps, text]]]
+    if not cmd_opts.pureui:
+        try:
+            _, prompt_flat_list, _ = prompt_parser.get_multicond_prompt_list([text])
+            prompt_schedules = prompt_parser.get_learned_conditioning_prompt_schedules(prompt_flat_list, steps)
 
-    flat_prompts = reduce(lambda list1, list2: list1+list2, prompt_schedules)
-    prompts = [prompt_text for step, prompt_text in flat_prompts]
-    tokens, token_count, max_length = max([model_hijack.tokenize(prompt) for prompt in prompts], key=lambda args: args[1])
-    style_class = ' class="red"' if (token_count > max_length) else ""
-    return f"<span {style_class}>{token_count}/{max_length}</span>"
+        except Exception:
+            # a parsing error can happen here during typing, and we don't want to bother the user with
+            # messages related to it in console
+            prompt_schedules = [[[steps, text]]]
+
+        flat_prompts = reduce(lambda list1, list2: list1+list2, prompt_schedules)
+        prompts = [prompt_text for step, prompt_text in flat_prompts]
+        tokens, token_count, max_length = max([model_hijack.tokenize(prompt) for prompt in prompts], key=lambda args: args[1])
+        style_class = ' class="red"' if (token_count > max_length) else ""
+        return f"<span {style_class}>{token_count}/{max_length}</span>"
+    else:
+        return "N/A"
 
 def create_toprow(is_img2img):
@@ -486,7 +489,7 @@ def create_toprow(is_img2img):
             with gr.Row():
                 skip = gr.Button('Skip', elem_id=f"{id_part}_skip")
                 interrupt = gr.Button('Interrupt', elem_id=f"{id_part}_interrupt")
-                submit = gr.Button('Generate', elem_id=f"{id_part}_generate", variant='primary')
+                submit = gr.Button('Generate', elem_id=f"{id_part}_generate", variant='primary', visible=(not cmd_opts.pureui))
 
                 skip.click(
                     fn=lambda: shared.state.skip(),
@@ -661,6 +664,10 @@ def open_folder(f):
         parameters_copypaste.bind_buttons(buttons, result_gallery, "txt2img" if tabname == "txt2img" else None)
         return result_gallery, generation_info if tabname != "extras" else html_info_x, html_info
 
+if cmd_opts.pureui:
+    txt2img_submit = None
+    img2img_submit = None
+    extras_submit = None
 
 def create_ui(wrap_gradio_gpu_call):
     import modules.img2img
@@ -675,6 +682,8 @@ def create_ui(wrap_gradio_gpu_call):
         dummy_component = gr.Label(visible=False)
         txt_prompt_img = gr.File(label="", elem_id="txt2img_prompt_image", file_count="single", type="bytes", visible=False)
 
+        txt2img_submit = submit
+
         with gr.Row(elem_id='txt2img_progress_row'):
             with gr.Column(scale=1):
                 pass
@@ -824,6 +833,8 @@ def create_ui(wrap_gradio_gpu_call):
     with gr.Blocks(analytics_enabled=False) as img2img_interface:
         img2img_prompt, roll, img2img_prompt_style, img2img_negative_prompt, img2img_prompt_style2, submit, img2img_interrogate, img2img_deepbooru, img2img_prompt_style_apply, img2img_save_style, img2img_paste, token_counter, token_button = create_toprow(is_img2img=True)
 
+        img2img_submit = submit
+
         with gr.Row(elem_id='img2img_progress_row'):
             img2img_prompt_img = gr.File(label="", elem_id="img2img_prompt_image", file_count="single", type="bytes", visible=False)
 
@@ -1049,13 +1060,15 @@ def create_ui(wrap_gradio_gpu_call):
             with gr.TabItem('Batch Process'):
                 image_batch = gr.File(label="Batch Process", file_count="multiple", interactive=True, type="file")
 
+            if not cmd_opts.pureui:
+                with gr.TabItem('Batch from Directory', visible=False):
+                    extras_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs, placeholder="A directory on the same machine where the server is running.")
+                    extras_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs, placeholder="Leave blank to save images to the default path.")
+                    show_extras_results = gr.Checkbox(label='Show result images', value=True)
 
-            with gr.TabItem('Batch from Directory'):
-                extras_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs, placeholder="A directory on the same machine where the server is running.")
-                extras_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs, placeholder="Leave blank to save images to the default path.")
-                show_extras_results = gr.Checkbox(label='Show result images', value=True)
-
-        submit = gr.Button('Generate', elem_id="extras_generate", variant='primary')
+        submit = gr.Button('Generate', elem_id="extras_generate", variant='primary', visible=(not cmd_opts.pureui))
+        if cmd_opts.pureui:
+            extras_submit = submit
 
         with gr.Tabs(elem_id="extras_resize_mode"):
             with gr.TabItem('Scale by'):
@@ -1094,9 +1107,9 @@ def create_ui(wrap_gradio_gpu_call):
                 dummy_component,
                 extras_image,
                 image_batch,
-                extras_batch_input_dir,
-                extras_batch_output_dir,
-                show_extras_results,
+                extras_batch_input_dir if not cmd_opts.pureui else dummy_component,
+                extras_batch_output_dir if not cmd_opts.pureui else dummy_component,
+                show_extras_results if not cmd_opts.pureui else dummy_component,
                 gfpgan_visibility,
                 codeformer_visibility,
                 codeformer_weight,
@@ -1115,6 +1128,7 @@ def create_ui(wrap_gradio_gpu_call):
                 html_info,
             ]
         )
+
         parameters_copypaste.add_paste_fields("extras", extras_image, None)
 
         extras_image.change(
@@ -1141,25 +1155,27 @@ def create_ui(wrap_gradio_gpu_call):
             outputs=[html, generation_info, html2],
        )
 
-    with gr.Blocks(analytics_enabled=False) as modelmerger_interface:
-        with gr.Row().style(equal_height=False):
-            with gr.Column(variant='panel'):
-                gr.HTML(value="A merger of the two checkpoints will be generated in your checkpoint directory.")
+    if not cmd_opts.pureui:
+        with gr.Blocks(analytics_enabled=False) as modelmerger_interface:
+            with gr.Row().style(equal_height=False):
+                with gr.Column(variant='panel'):
+                    gr.HTML(value="A merger of the two checkpoints will be generated in your checkpoint directory.")
 
-            with gr.Row():
-                primary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_primary_model_name", label="Primary model (A)")
-                secondary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_secondary_model_name", label="Secondary model (B)")
-                tertiary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_tertiary_model_name", label="Tertiary model (C)")
-                custom_name = gr.Textbox(label="Custom Name (Optional)")
-                interp_amount = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, label='Multiplier (M) - set to 0 to get model A', value=0.3)
-                interp_method = gr.Radio(choices=["Weighted sum", "Add difference"], value="Weighted sum", label="Interpolation Method")
-                save_as_half = gr.Checkbox(value=False, label="Save as float16")
-                modelmerger_merge = gr.Button(elem_id="modelmerger_merge", label="Merge", variant='primary')
+                with gr.Row():
+                    primary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_primary_model_name", label="Primary model (A)")
+                    secondary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_secondary_model_name", label="Secondary model (B)")
+                    tertiary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_tertiary_model_name", label="Tertiary model (C)")
+                    custom_name = gr.Textbox(label="Custom Name (Optional)")
+                    interp_amount = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, label='Multiplier (M) - set to 0 to get model A', value=0.3)
+                    interp_method = gr.Radio(choices=["Weighted sum", "Add difference"], value="Weighted sum", label="Interpolation Method")
+                    save_as_half = gr.Checkbox(value=False, label="Save as float16")
+                    modelmerger_merge = gr.Button(elem_id="modelmerger_merge", label="Merge", variant='primary')
 
-            with gr.Column(variant='panel'):
-                submit_result = gr.Textbox(elem_id="modelmerger_result", show_label=False)
+                with gr.Column(variant='panel'):
+                    submit_result = gr.Textbox(elem_id="modelmerger_result", show_label=False)
 
-    sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings()
+    if not cmd_opts.pureui:
+        sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings()
 
     with gr.Blocks(analytics_enabled=False) as train_interface:
         with gr.Row().style(equal_height=False):
@@ -1455,7 +1471,7 @@ def create_ui(wrap_gradio_gpu_call):
                         gr.HTML(value="")
 
                     with gr.Column():
-                        create_train_embedding = gr.Button(value="Create & train embedding", variant='primary')
+                        create_train_embedding = gr.Button(value="Create & train embedding", variant='primary', visible=False)
 
                 with gr.Tab(label="Create & Train Hypernetwork"):
                     gr.HTML(value="Train an hypernetwork; you must specify a directory with a set of 1:1 ratio images [wiki]")
@@ -1521,7 +1537,7 @@ def create_ui(wrap_gradio_gpu_call):
                         gr.HTML(value="")
 
                     with gr.Column():
-                        create_train_hypernetwork = gr.Button(value="Create & train hypernetwork", variant='primary')
+                        create_train_hypernetwork = gr.Button(value="Create & train hypernetwork", variant='primary', visible=False)
 
     def create_setting_component(key, is_quicksettings=False):
         def fun():
@@ -1609,6 +1625,98 @@ def run_settings_single(value, key):
         return gr.update(value=value), opts.dumpjson()
 
     with gr.Blocks(analytics_enabled=False) as settings_interface:
+        if cmd_opts.pureui:
+            def change_sign_options(choice):
+                return {
+                    signin_column: gr.update(visible=(choice=="Sign In")),
+                    signup_column: gr.update(visible=(choice=="Sign Up"))
+                }
+
+            with gr.Row(visible=(shared.userid!='')) as user_id_row:
+                with gr.Column():
+                    login_userid = gr.Label(label="User ID")
+                    login_username = gr.Textbox(label="User Name")
+
+                with gr.Column():
+                    signout = gr.Button("Sign Out")
+                    edit_user = gr.Button("Edit User")
+                    delete_user = gr.Button("Delete User")
+
+            with gr.Row(visible=(shared.userid=='')) as user_sign_row:
+                with gr.Column():
+                    sign_options = gr.Radio(["Sign In", "Sign Up"], label="Sign Options", value="Sign In", interactive=True)
+
+                with gr.Column(visible=(sign_options.value=="Sign In")) as signin_column:
+                    signin_username = gr.Textbox(label="User Name")
+                    signin_password = gr.Textbox(label="Password")
+                    signin = gr.Button("Sign In")
+
+                with gr.Column(visible=(sign_options.value=="Sign Up")) as signup_column:
+                    signup_username = gr.Textbox(label="User Name")
+                    signup_password = gr.Textbox(label="Password")
+                    signup_email = gr.Textbox(label="Email")
+                    signup = gr.Button("Sign Up")
+
+            sign_options.change(change_sign_options, sign_options, [signin_column, signup_column])
+
+            def user_signin(signin_username, signin_password):
+                shared.userid = '1234'
+                return {
+                    user_id_row: gr.update(visible=True),
+                    user_sign_row: gr.update(visible=False),
+                    login_userid: gr.update(value=shared.userid),
+                    login_username: gr.update(value=signin_username),
+                    txt2img_submit: gr.update(visible=True),
+                    img2img_submit: gr.update(visible=True),
+                    extras_submit: gr.update(visible=True),
+                    create_train_embedding: gr.update(visible=True),
+                    create_train_hypernetwork: gr.update(visible=True)
+                }
+
+            def user_signup(signup_username, signup_password, signup_email):
+                shared.userid = '5678'
+                return {
+                    user_id_row: gr.update(visible=True),
+                    user_sign_row: gr.update(visible=False),
+                    login_userid: gr.update(value=shared.userid),
+                    login_username: gr.update(value=signup_username),
+                    txt2img_submit: gr.update(visible=True),
+                    img2img_submit: gr.update(visible=True),
+                    extras_submit: gr.update(visible=True),
+                    create_train_embedding: gr.update(visible=True),
+                    create_train_hypernetwork: gr.update(visible=True)
+                }
+
+            def user_signout():
+                shared.userid = ''
+                return {
+                    user_id_row: gr.update(visible=False),
+                    user_sign_row: gr.update(visible=True),
+                    txt2img_submit: gr.update(visible=False),
+                    img2img_submit: gr.update(visible=False),
+                    extras_submit: gr.update(visible=False),
+                    create_train_embedding: gr.update(visible=False),
+                    create_train_hypernetwork: gr.update(visible=False)
+                }
+
+            signin.click(
+                fn=user_signin,
+                inputs=[signin_username, signin_password],
+                outputs=[user_id_row, user_sign_row, login_userid, login_username, txt2img_submit, img2img_submit, extras_submit, create_train_embedding, create_train_hypernetwork]
+            )
+
+            signup.click(
+                fn=user_signup,
+                inputs=[signup_username, signup_password, signup_email],
+                outputs=[user_id_row, user_sign_row, login_userid, login_username, txt2img_submit, img2img_submit, extras_submit, create_train_embedding, create_train_hypernetwork]
+            )
+
+            signout.click(
+                fn=user_signout,
+                inputs=[],
+                outputs=[user_id_row, user_sign_row, txt2img_submit, img2img_submit, extras_submit, create_train_embedding, create_train_hypernetwork]
+            )
+
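The sign-in flow above relies on gradio's pattern of handlers returning `gr.update(...)` dicts keyed by output component. A minimal self-contained sketch of the same pattern (component names are illustrative):

```python
import gradio as gr

with gr.Blocks() as demo:
    with gr.Column(visible=True) as signed_out:
        signin = gr.Button("Sign In")
    with gr.Column(visible=False) as signed_in:
        signout = gr.Button("Sign Out")

    def toggle(show_signed_in):
        # return a dict mapping each output component to its update
        return {signed_out: gr.update(visible=not show_signed_in),
                signed_in: gr.update(visible=show_signed_in)}

    signin.click(lambda: toggle(True), inputs=[], outputs=[signed_out, signed_in])
    signout.click(lambda: toggle(False), inputs=[], outputs=[signed_out, signed_in])
```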
 
         settings_submit = gr.Button(value="Apply settings", variant='primary')
         result = gr.HTML()
@@ -1701,14 +1809,23 @@ def request_restart():
         if column is not None:
             column.__exit__()
 
-    interfaces = [
-        (txt2img_interface, "txt2img", "txt2img"),
-        (img2img_interface, "img2img", "img2img"),
-        (extras_interface, "Extras", "extras"),
-        (pnginfo_interface, "PNG Info", "pnginfo"),
-        (modelmerger_interface, "Checkpoint Merger", "modelmerger"),
-        (train_interface, "Train", "ti"),
-    ]
+    if cmd_opts.pureui:
+        interfaces = [
+            (txt2img_interface, "txt2img", "txt2img"),
+            (img2img_interface, "img2img", "img2img"),
+            (extras_interface, "Extras", "extras"),
+            (pnginfo_interface, "PNG Info", "pnginfo"),
+            (train_interface, "Train", "ti"),
+        ]
+    else:
+        interfaces = [
+            (txt2img_interface, "txt2img", "txt2img"),
+            (img2img_interface, "img2img", "img2img"),
+            (extras_interface, "Extras", "extras"),
+            (pnginfo_interface, "PNG Info", "pnginfo"),
+            (modelmerger_interface, "Checkpoint Merger", "modelmerger"),
+            (train_interface, "Train", "ti"),
+        ]
 
     css = ""
 
@@ -1776,35 +1893,36 @@ def get_settings_values():
         outputs=[component_dict[k] for k in component_keys],
     )
 
-    def modelmerger(*args):
-        try:
-            results = modules.extras.run_modelmerger(*args)
-        except Exception as e:
-            print("Error loading/saving model file:", file=sys.stderr)
-            print(traceback.format_exc(), file=sys.stderr)
-            modules.sd_models.list_models()  # to remove the potentially missing models from the list
-            return ["Error loading/saving model file. It doesn't exist or the name contains illegal characters"] + [gr.Dropdown.update(choices=modules.sd_models.checkpoint_tiles()) for _ in range(3)]
-        return results
-
-    modelmerger_merge.click(
-        fn=modelmerger,
-        inputs=[
-            primary_model_name,
-            secondary_model_name,
-            tertiary_model_name,
-            interp_method,
-            interp_amount,
-            save_as_half,
-            custom_name,
-        ],
-        outputs=[
-            submit_result,
-            primary_model_name,
-            secondary_model_name,
-            tertiary_model_name,
-            component_dict['sd_model_checkpoint'],
-        ]
-    )
+    if not cmd_opts.pureui:
+        def modelmerger(*args):
+            try:
+                results = modules.extras.run_modelmerger(*args)
+            except Exception as e:
+                print("Error loading/saving model file:", file=sys.stderr)
+                print(traceback.format_exc(), file=sys.stderr)
+                modules.sd_models.list_models()  # to remove the potentially missing models from the list
+                return ["Error loading/saving model file. It doesn't exist or the name contains illegal characters"] + [gr.Dropdown.update(choices=modules.sd_models.checkpoint_tiles()) for _ in range(3)]
+            return results
+
+        modelmerger_merge.click(
+            fn=modelmerger,
+            inputs=[
+                primary_model_name,
+                secondary_model_name,
+                tertiary_model_name,
+                interp_method,
+                interp_amount,
+                save_as_half,
+                custom_name,
+            ],
+            outputs=[
+                submit_result,
+                primary_model_name,
+                secondary_model_name,
+                tertiary_model_name,
+                component_dict['sd_model_checkpoint'],
+            ]
+        )
 
     ui_config_file = cmd_opts.ui_config_file
     ui_settings = {}
@@ -1870,7 +1988,8 @@ def apply_field(obj, field, condition=None, init_field=None):
     visit(txt2img_interface, loadsave, "txt2img")
     visit(img2img_interface, loadsave, "img2img")
     visit(extras_interface, loadsave, "extras")
-    visit(modelmerger_interface, loadsave, "modelmerger")
+    if not cmd_opts.pureui:
+        visit(modelmerger_interface, loadsave, "modelmerger")
 
     if not error_loading and (not os.path.exists(ui_config_file) or settings_count != len(ui_settings)):
         with open(ui_config_file, "w", encoding="utf8") as file:
diff --git a/requirements.txt b/requirements.txt
index 0fba0b233ac..43258b1aa26 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -27,3 +27,4 @@ kornia
 lark
 inflection
 GitPython
+boto3
diff --git a/requirements_versions.txt b/requirements_versions.txt
index f7059f205c4..17837633cca 100644
--- a/requirements_versions.txt
+++ b/requirements_versions.txt
@@ -24,3 +24,4 @@ kornia==0.6.7
 lark==1.1.2
 inflection==0.5.1
 GitPython==3.1.27
+boto3
diff --git a/webui.py b/webui.py
index 14738cd797e..d7b319f037c 100644
--- a/webui.py
+++ b/webui.py
@@ -42,12 +42,18 @@
 import uuid
 
+from PIL import Image, ImageOps, ImageChops
+
 queue_lock = threading.Lock()
 server_name = "0.0.0.0" if cmd_opts.listen else cmd_opts.server_name
 
 api_endpoint = os.environ['api_endpoint'] if 'api_endpoint' in os.environ else ''
 endpoint_name = os.environ['endpoint_name'] if 'endpoint_name' in os.environ else ''
 
+import boto3
+import traceback
+from botocore.exceptions import ClientError
+
 
 def wrap_queued_call(func):
     def f(*args, **kwargs):
@@ -58,51 +64,26 @@ def f(*args, **kwargs):
     return f
 
 
 def wrap_gradio_gpu_call(func, extra_outputs=None):
-    def sagemaker_inference(task, *args, **kwargs):
-        script_args = []
-        for i in range(23, len(args)):
-            script_args.append(args[i])
-
-        payload = {
-            "prompt": args[0],
-            "negative_prompt": args[1],
-            "styles": [args[2], args[3]],
-            "steps": args[4],
-            "sampler_index": sd_samplers.samplers[args[5]].name,
-            "restore_faces": args[6],
-            "tiling": args[7],
-            "batch_count": args[8],
-            "batch_size": args[9],
-            "cfg_scale": args[10],
-            "seed": args[11],
-            "subseed": args[12],
-            "subseed_strength": args[13],
-            "seed_resize_from_h": args[14],
-            "seed_resize_from_w": args[15],
-            "seed_checkbox": args[16],
-            "width": args[17],
-            "height": args[18],
-            "enable_hr": args[19],
-            "denoising_strength": args[20],
-            "firstphase_width": args[21],
-            "firstphase_height": args[22],
-            "script_args": json.dumps(script_args),
-            "eta": opts.eta_ddim if sd_samplers.samplers[args[5]].name == 'DDIM' or sd_samplers.samplers[args[5]].name == 'PLMS' else opts.eta_ancestral,
-            "s_churn": opts.s_churn,
-            "s_tmax": None,
-            "s_tmin": opts.s_tmin,
-            "s_noise": opts.s_noise,
-        }
-        inputs = {
-            'task': task,
-            'payload': payload
-        }
-        params = {
-            'endpoint_name': endpoint_name,
-            'infer_type': 'async'
-        }
-
-        response = requests.post(url=f'{api_endpoint}/inference', params = params, json = inputs)
+    def encode_image_to_base64(image):
+        if isinstance(image, bytes):
+            encoded_string = base64.b64encode(image)
+        else:
+            image_bytes = io.BytesIO()
+            image.save(image_bytes, format='PNG')
+            encoded_string = base64.b64encode(image_bytes.getvalue())
+
+        base64_str = str(encoded_string, "utf-8")
+        mimetype = 'image/png'
+        image_encoded_in_base64 = (
+            "data:"
+            + (mimetype if mimetype is not None else "")
+            + ";base64,"
+            + base64_str
+        )
+        return image_encoded_in_base64
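For symmetry with `encode_image_to_base64()` above: the decode direction, as used further down when endpoint responses are unpacked, strips the data-URI prefix if present and base64-decodes the rest (the helper name is ours, not in the patch):

```python
import base64
import io
from PIL import Image

def decode_base64_to_image(encoded: str) -> Image.Image:
    if encoded.startswith('data:'):
        encoded = encoded.split(',', 1)[1]  # drop "data:image/png;base64,"
    return Image.open(io.BytesIO(base64.b64decode(encoded)))
```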
+
+    def handle_sagemaker_inference_async(response):
+        s3uri = response.text
+        params = {'s3uri': s3uri}
+        start = time.time()
@@ -114,24 +95,311 @@ def sagemaker_inference(task, *args, **kwargs):
                 break
             else:
                 time.sleep(1)
-
+
         httpuri = text['payload'][0]['httpuri']
         response = requests.get(url=httpuri)
         processed = json.loads(response.text)
-        images = []
-        for image in processed['images']:
-            images.append(Image.open(io.BytesIO(base64.b64decode(image.split(',')[1]))))
-        parameters = processed['parameters']
-        info = processed['info']
         print(f"Time taken: {time.time() - start}s")
-
-        return images, json.dumps(payload), modules.ui.plaintext_to_html(info)
+
+        return processed
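Note: the polling loop in `handle_sagemaker_inference_async()` has no upper bound, so a lost async job would block the caller forever. A bounded variant might look like the following sketch (`check_status` stands in for the status request elided between the hunks above; the timeout value is an assumption):

```python
import time

def poll_until_done(check_status, timeout_s=900, interval_s=1.0):
    """check_status() should return the result payload once ready, else None."""
    deadline = time.time() + timeout_s
    while time.time() < deadline:
        result = check_status()
        if result is not None:
            return result
        time.sleep(interval_s)
    raise TimeoutError('async inference result not ready within timeout')
```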
+
+    def sagemaker_inference(task, infer, *args, **kwargs):
+        infer = 'async'  # async is forced for now; the infer argument is effectively ignored
+        if task == 'text-to-image' or task == 'image-to-image':
+            if task == 'text-to-image':
+                script_args = []
+                for i in range(23, len(args)):
+                    script_args.append(args[i])
+
+                prompt = args[0]
+                negative_prompt = args[1]
+                styles = [args[2], args[3]]
+                steps = args[4]
+                sampler_index = sd_samplers.samplers[args[5]].name
+                restore_faces = args[6]
+                tiling = args[7]
+                n_iter = args[8]
+                batch_size = args[9]
+                cfg_scale = args[10]
+                seed = args[11]
+                subseed = args[12]
+                subseed_strength = args[13]
+                seed_resize_from_h = args[14]
+                seed_resize_from_w = args[15]
+                seed_enable_extras = args[16]
+                height = args[17]
+                width = args[18]
+                enable_hr = args[19]
+                denoising_strength = args[20]
+                firstphase_width = args[21]
+                firstphase_height = args[22]
+
+                payload = {
+                    "enable_hr": enable_hr,
+                    "denoising_strength": denoising_strength,
+                    "firstphase_width": firstphase_width,
+                    "firstphase_height": firstphase_height,
+                    "prompt": prompt,
+                    "styles": styles,
+                    "seed": seed,
+                    "subseed": subseed,
+                    "subseed_strength": subseed_strength,
+                    "seed_resize_from_h": seed_resize_from_h,
+                    "seed_resize_from_w": seed_resize_from_w,
+                    "sampler_index": sampler_index,
+                    "batch_size": batch_size,
+                    "n_iter": n_iter,
+                    "steps": steps,
+                    "cfg_scale": cfg_scale,
+                    "width": width,
+                    "height": height,
+                    "restore_faces": restore_faces,
+                    "tiling": tiling,
+                    "negative_prompt": negative_prompt,
+                    "eta": opts.eta_ddim if sampler_index in ('DDIM', 'PLMS') else opts.eta_ancestral,
+                    "s_churn": opts.s_churn,
+                    "s_tmax": None,
+                    "s_tmin": opts.s_tmin,
+                    "s_noise": opts.s_noise,
+                    "override_settings": {},
+                    "script_args": json.dumps(script_args),
+                }
+                inputs = {
+                    'task': task,
+                    'txt2img_payload': payload
+                }
+            else:
+                mode = args[0]
+                prompt = args[1]
+                negative_prompt = args[2]
+                styles = [args[3], args[4]]
+                init_img = args[5]
+                init_img_with_mask = args[6]
+                init_img_inpaint = args[7]
+                init_mask_inpaint = args[8]
+                mask_mode = args[9]
+                steps = args[10]
+                sampler_index = sd_samplers.samplers[args[11]].name
+                mask_blur = args[12]
+                inpainting_fill = args[13]
+                restore_faces = args[14]
+                tiling = args[15]
+                n_iter = args[16]
+                batch_size = args[17]
+                cfg_scale = args[18]
+                denoising_strength = args[19]
+                seed = args[20]
+                subseed = args[21]
+                subseed_strength = args[22]
+                seed_resize_from_h = args[23]
+                seed_resize_from_w = args[24]
+                seed_enable_extras = args[25]
+                height = args[26]
+                width = args[27]
+                resize_mode = args[28]
+                inpaint_full_res = args[29]
+                inpaint_full_res_padding = args[30]
+                inpainting_mask_invert = args[31]
+                img2img_batch_input_dir = args[32]
+                img2img_batch_output_dir = args[33]
+
+                script_args = []
+                for i in range(34, len(args)):
+                    script_args.append(args[i])
+
+                if mode == 1:
+                    # Drawn mask
+                    if mask_mode == 0:
+                        image = init_img_with_mask['image']
+                        mask = init_img_with_mask['mask']
+                        alpha_mask = ImageOps.invert(image.split()[-1]).convert('L').point(lambda x: 255 if x > 0 else 0, mode='1')
+                        mask = ImageChops.lighter(alpha_mask, mask.convert('L')).convert('L')
+                        image = image.convert('RGB')
+                    # Uploaded mask
+                    else:
+                        image = init_img_inpaint
+                        mask = init_mask_inpaint
+                # No mask
+                else:
+                    image = init_img
+                    mask = None
+
+                # Use the EXIF orientation of photos taken by smartphones.
+                if image is not None:
+                    image = ImageOps.exif_transpose(image)
+
+                assert 0. <= denoising_strength <= 1., 'can only work with strength in [0.0, 1.0]'
+
+                image_encoded_in_base64 = encode_image_to_base64(image)
+                mask_encoded_in_base64 = encode_image_to_base64(mask) if mask else None
+
+                if init_img_with_mask:
+                    image = init_img_with_mask['image']
+                    image_encoded_in_base64 = encode_image_to_base64(image)
+                    mask_encoded_in_base64 = encode_image_to_base64(mask) if mask else None
+                    init_img_with_mask['image'] = image_encoded_in_base64
+                    init_img_with_mask['mask'] = mask_encoded_in_base64
+
+                payload = {
+                    "init_images": [image_encoded_in_base64],
+                    "resize_mode": resize_mode,
+                    "denoising_strength": denoising_strength,
+                    "mask": mask_encoded_in_base64,
+                    "mask_blur": mask_blur,
+                    "inpainting_fill": inpainting_fill,
+                    "inpaint_full_res": inpaint_full_res,
+                    "inpaint_full_res_padding": inpaint_full_res_padding,
+                    "inpainting_mask_invert": inpainting_mask_invert,
+                    "prompt": prompt,
+                    "styles": styles,
+                    "seed": seed,
+                    "subseed": subseed,
+                    "subseed_strength": subseed_strength,
+                    "seed_resize_from_h": seed_resize_from_h,
+                    "seed_resize_from_w": seed_resize_from_w,
+                    "sampler_index": sampler_index,
+                    "batch_size": batch_size,
+                    "n_iter": n_iter,
+                    "steps": steps,
+                    "cfg_scale": cfg_scale,
+                    "width": width,
+                    "height": height,
+                    "restore_faces": restore_faces,
+                    "tiling": tiling,
+                    "negative_prompt": negative_prompt,
+                    "eta": opts.eta_ddim if sampler_index in ('DDIM', 'PLMS') else opts.eta_ancestral,
+                    "s_churn": opts.s_churn,
+                    "s_tmax": None,
+                    "s_tmin": opts.s_tmin,
+                    "s_noise": opts.s_noise,
+                    "override_settings": {},
+                    "include_init_images": False,
+                    "script_args": json.dumps(script_args)
+                }
+                inputs = {
+                    'task': task,
+                    'img2img_payload': payload
+                }
+
+            params = {
+                'endpoint_name': endpoint_name
+            }
+
+            response = requests.post(url=f'{api_endpoint}/inference', params=params, json=inputs)
+            if infer == 'async':
+                processed = handle_sagemaker_inference_async(response)
+            else:
+                processed = json.loads(response.text)
+
+            images = []
+            for image in processed['images']:
+                images.append(Image.open(io.BytesIO(base64.b64decode(image))))
+            parameters = processed['parameters']
+            info = processed['info']
+
+            return images, json.dumps(payload), modules.ui.plaintext_to_html(info)
+        else:
+            extras_mode = args[0]
+            resize_mode = args[1]
+            image = args[2]
+            image_folder = args[3]
+            input_dir = args[4]
+            output_dir = args[5]
+            show_extras_results = args[6]
+            gfpgan_visibility = args[7]
+            codeformer_visibility = args[8]
+            codeformer_weight = args[9]
+            upscaling_resize = args[10]
+            upscaling_resize_w = args[11]
+            upscaling_resize_h = args[12]
+            upscaling_crop = args[13]
+            extras_upscaler_1 = shared.sd_upscalers[args[14]].name
+            extras_upscaler_2 = shared.sd_upscalers[args[15]].name
+            extras_upscaler_2_visibility = args[16]
+            upscale_first = args[17]
+
+            if extras_mode == 0:
+                image_encoded_in_base64 = encode_image_to_base64(image)
+
+                payload = {
+                    "resize_mode": resize_mode,
+                    "show_extras_results": show_extras_results if show_extras_results else True,
+                    "gfpgan_visibility": gfpgan_visibility,
+                    "codeformer_visibility": codeformer_visibility,
+                    "codeformer_weight": codeformer_weight,
+                    "upscaling_resize": upscaling_resize,
+                    "upscaling_resize_w": upscaling_resize_w,
+                    "upscaling_resize_h": upscaling_resize_h,
+                    "upscaling_crop": upscaling_crop,
+                    "upscaler_1": extras_upscaler_1,
+                    "upscaler_2": extras_upscaler_2,
+                    "extras_upscaler_2_visibility": extras_upscaler_2_visibility,
+                    "upscale_first": upscale_first,
+                    "image": image_encoded_in_base64
+                }
+                task = 'extras-single-image'
+                inputs = {
+                    'task': task,
+                    'extras_single_payload': payload
+                }
+            else:
+                imageList = []
+                for img in image_folder:
+                    image_encoded_in_base64 = encode_image_to_base64(Image.open(img))
+                    imageList.append(
+                        {
+                            'data': image_encoded_in_base64,
+                            'name': img.name
+                        }
+                    )
+                payload = {
+                    "resize_mode": resize_mode,
+                    "show_extras_results": show_extras_results if show_extras_results else True,
+                    "gfpgan_visibility": gfpgan_visibility,
+                    "codeformer_visibility": codeformer_visibility,
+                    "codeformer_weight": codeformer_weight,
+                    "upscaling_resize": upscaling_resize,
+                    "upscaling_resize_w": upscaling_resize_w,
+                    "upscaling_resize_h": upscaling_resize_h,
+                    "upscaling_crop": upscaling_crop,
+                    "upscaler_1": extras_upscaler_1,
+                    "upscaler_2": extras_upscaler_2,
+                    "extras_upscaler_2_visibility": extras_upscaler_2_visibility,
+                    "upscale_first": upscale_first,
+                    "imageList": imageList
+                }
+                task = 'extras-batch-images'
+                inputs = {
+                    'task': task,
+                    'extras_batch_payload': payload
+                }
+
+            params = {
+                'endpoint_name': endpoint_name
+            }
+
+            response = requests.post(url=f'{api_endpoint}/inference', params=params, json=inputs)
+            if infer == 'async':
+                processed = handle_sagemaker_inference_async(response)
+            else:
+                processed = json.loads(response.text)
+
+            if task == 'extras-single-image':
+                images = [Image.open(io.BytesIO(base64.b64decode(processed['image'])))]
+            else:
+                images = []
+                for image in processed['images']:
+                    images.append(Image.open(io.BytesIO(base64.b64decode(image))))
+            info = processed['html_info']
+            return images, modules.ui.plaintext_to_html(info), ''
 
     def f(*args, **kwargs):
         if cmd_opts.pureui and func == modules.txt2img.txt2img:
-            res = sagemaker_inference('text-to-image', *args, **kwargs)
-        elif(cmd_opts.pureui and func == modules.img2img.img2img):
-            res = sagemaker_inference('image-to-image', *args, **kwargs)
+            res = sagemaker_inference('text-to-image', 'sync', *args, **kwargs)
+        elif cmd_opts.pureui and func == modules.img2img.img2img:
+            res = sagemaker_inference('image-to-image', 'sync', *args, **kwargs)
+        elif cmd_opts.pureui and func == modules.extras.run_extras:
+            res = sagemaker_inference('extras', 'sync', *args, **kwargs)
         else:
             shared.state.begin()
@@ -164,11 +432,13 @@ def initialize():
     modules.scripts.load_scripts()
 
     modules.sd_vae.refresh_vae_list()
-    modules.sd_models.load_model()
-    shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights()))
     shared.opts.onchange("sd_vae", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False)
-    shared.opts.onchange("sd_hypernetwork", wrap_queued_call(lambda: modules.hypernetworks.hypernetwork.load_hypernetwork(shared.opts.sd_hypernetwork)))
-    shared.opts.onchange("sd_hypernetwork_strength", modules.hypernetworks.hypernetwork.apply_strength)
+
+    if not cmd_opts.pureui:
+        modules.sd_models.load_model()
+        shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights()))
+        shared.opts.onchange("sd_hypernetwork", wrap_queued_call(lambda: modules.hypernetworks.hypernetwork.load_hypernetwork(shared.opts.sd_hypernetwork)))
+        shared.opts.onchange("sd_hypernetwork_strength", modules.hypernetworks.hypernetwork.apply_strength)
 
     if cmd_opts.tls_keyfile is not None and cmd_opts.tls_keyfile is not None:
 
@@ -279,16 +549,31 @@ def webui():
         modules.sd_models.list_models()
         print('Restarting Gradio')
 
-def train():
-    os.system('mount -t efs -o tls fs-06810c18b16c76fed:/ /mnt/efs')
-    os.system('mkdir -p /mnt/efs/embedding')
-    os.system('mkdir -p /mnt/efs/hypernetwork')
+def upload_s3file(s3uri, file_path, file_name):
+    s3_client = boto3.client('s3', region_name = cmd_opts.region_name)
+
+    pos = s3uri.find('/', 5)
+    bucket = s3uri[5 : pos]
+    key = s3uri[pos + 1 : ]
 
+    binary = io.BytesIO(open(file_path, 'rb').read())
+    key = key + file_name
+    try:
+        s3_client.upload_fileobj(binary, bucket, key)
+    except ClientError as e:
+        print(e)
+        return False
+    return True
+
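`upload_s3file()` builds the object key as `key + file_name`, so the S3 URI passed in should end with a trailing slash; otherwise the last prefix segment and the file name fuse into one. Illustrative call (bucket and paths are placeholders):

```python
ok = upload_s3file(
    's3://my-training-bucket/embeddings/',   # trailing '/' matters
    '/opt/ml/model/embeddings/my-style.pt',  # local file to upload
    'my-style.pt',                           # object name appended to the prefix
)
assert ok, 'S3 upload failed (ClientError was printed)'
```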
+def train():
     initialize()
 
     train_task = cmd_opts.train_task
     train_args = json.loads(cmd_opts.train_args)
 
+    embeddings_s3uri = cmd_opts.embeddings_s3uri
+    hypernetwork_s3uri = cmd_opts.hypernetwork_s3uri
+
     if train_task == 'embedding':
         name = train_args['embedding_settings']['name']
         nvpt = train_args['embedding_settings']['nvpt']
@@ -300,6 +585,8 @@ def train():
             overwrite_old,
             init_text=initialization_text
         )
+        if not cmd_opts.pureui:
+            modules.sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings()
         process_src = '/opt/ml/input/data/images'
         process_dst = str(uuid.uuid4())
         process_width = train_args['images_preprocessing_settings']['process_width']
@@ -344,7 +631,7 @@ def train():
         steps = train_args['train_embedding_settings']['steps']
         create_image_every = train_args['train_embedding_settings']['create_image_every']
         save_embedding_every = train_args['train_embedding_settings']['save_embedding_every']
-        template_file = 'style_filewords.txt'
+        template_file = os.path.join(script_path, "textual_inversion_templates", "style_filewords.txt")
         save_image_with_stored_embedding = train_args['train_embedding_settings']['save_image_with_stored_embedding']
         preview_from_txt2img = train_args['train_embedding_settings']['preview_from_txt2img']
         txt2img_preview_params = train_args['train_embedding_settings']['txt2img_preview_params']
@@ -364,6 +651,11 @@ def train():
             preview_from_txt2img,
             *txt2img_preview_params
         )
+        try:
+            upload_s3file(embeddings_s3uri, os.path.join(cmd_opts.embeddings_dir, '{0}.pt'.format(train_embedding_name)), '{0}.pt'.format(train_embedding_name))
+        except Exception as e:
+            traceback.print_exc()
+            print(e)
     elif train_task == 'hypernetwork':
         name = train_args['hypernetwork_settings']['name']
         enable_sizes = train_args['hypernetwork_settings']['enable_sizes']
@@ -428,7 +720,7 @@ def train():
             process_focal_crop_edges_weight,
             process_focal_crop_debug,
         )
-        train_embedding_name = name
+        train_hypernetwork_name = name
         embedding_learn_rate = train_args['train_embedding_settings']['embedding_learn_rate']
         batch_size = train_args['train_embedding_settings']['batch_size']
         dataset_directory = process_dst
@@ -443,7 +735,7 @@ def train():
         preview_from_txt2img = train_args['train_embedding_settings']['preview_from_txt2img']
         txt2img_preview_params = train_args['train_embedding_settings']['txt2img_preview_params']
         _, filename = modules.textual_inversion.textual_inversion.train_embedding(
-            train_embedding_name,
+            train_hypernetwork_name,
             embedding_learn_rate,
             batch_size,
             dataset_directory,
@@ -458,8 +750,13 @@ def train():
             preview_from_txt2img,
             *txt2img_preview_params
         )
+        try:
+            upload_s3file(hypernetwork_s3uri, os.path.join(cmd_opts.hypernetwork_dir, '{0}.pt'.format(train_hypernetwork_name)), '{0}.pt'.format(train_hypernetwork_name))
+        except Exception as e:
+            traceback.print_exc()
+            print(e)
     else:
-        print('Incorrect trainingg task')
+        print('Incorrect training task')
         exit(-1)
 
 if __name__ == "__main__":
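For completeness, one way the `--train` entry point might be driven from a SageMaker training job. This is a sketch under assumptions: the image URI, role, and bucket are placeholders, and hyperparameters are assumed to be surfaced to `launch.py` as CLI flags (note `launch.py` only checks for the presence of `'--train'` in `sys.argv`). The `images` channel matches the hard-coded `process_src = '/opt/ml/input/data/images'`:

```python
from sagemaker.estimator import Estimator

estimator = Estimator(
    image_uri='<account>.dkr.ecr.<region>.amazonaws.com/stable-diffusion-webui:latest',
    role='arn:aws:iam::<account>:role/SageMakerExecutionRole',
    instance_count=1,
    instance_type='ml.g4dn.xlarge',
    hyperparameters={
        'train': 'true',              # only the flag's presence matters to launch.py
        'train-task': 'embedding',
        'train-args': train_args,     # JSON string, see the sketch after modules/shared.py
        'embeddings-s3uri': 's3://my-bucket/embeddings/',
        'region-name': '<region>',
    },
)
estimator.fit({'images': 's3://my-bucket/training-images/'})
```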