diff --git a/config.json b/config.json deleted file mode 100644 index 1bfd0e0dae3..00000000000 --- a/config.json +++ /dev/null @@ -1,116 +0,0 @@ -{ - "samples_save": true, - "samples_format": "png", - "samples_filename_pattern": "", - "save_images_add_number": true, - "grid_save": true, - "grid_format": "png", - "grid_extended_filename": false, - "grid_only_if_multiple": true, - "grid_prevent_empty_spots": false, - "n_rows": -1, - "enable_pnginfo": true, - "save_txt": false, - "save_images_before_face_restoration": false, - "save_images_before_highres_fix": false, - "save_images_before_color_correction": false, - "jpeg_quality": 80, - "export_for_4chan": true, - "use_original_name_batch": false, - "save_selected_only": true, - "do_not_add_watermark": false, - "temp_dir": "", - "clean_temp_dir_at_start": false, - "outdir_samples": "", - "outdir_txt2img_samples": "outputs/txt2img-images", - "outdir_img2img_samples": "outputs/img2img-images", - "outdir_extras_samples": "outputs/extras-images", - "outdir_grids": "", - "outdir_txt2img_grids": "outputs/txt2img-grids", - "outdir_img2img_grids": "outputs/img2img-grids", - "outdir_save": "log/images", - "save_to_dirs": false, - "grid_save_to_dirs": false, - "use_save_to_dirs_for_ui": false, - "directories_filename_pattern": "", - "directories_max_prompt_words": 8, - "ESRGAN_tile": 192, - "ESRGAN_tile_overlap": 8, - "realesrgan_enabled_models": [ - "R-ESRGAN 4x+", - "R-ESRGAN 4x+ Anime6B" - ], - "SWIN_tile": 192, - "SWIN_tile_overlap": 8, - "ldsr_steps": 100, - "upscaler_for_img2img": null, - "use_scale_latent_for_hires_fix": false, - "face_restoration_model": null, - "code_former_weight": 0.5, - "face_restoration_unload": false, - "memmon_poll_rate": 8, - "samples_log_stdout": false, - "multiple_tqdm": true, - "unload_models_when_training": false, - "pin_memory": false, - "save_optimizer_state": false, - "dataset_filename_word_regex": "", - "dataset_filename_join_string": " ", - "training_image_repeats_per_epoch": 1, - "training_write_csv_every": 500, - "training_xattention_optimizations": false, - "sd_model_checkpoint": "", - "sd_checkpoint_cache": 0, - "sd_vae": "auto", - "sd_vae_as_default": false, - "sd_hypernetwork": "None", - "sd_hypernetwork_strength": 1.0, - "inpainting_mask_weight": 1.0, - "img2img_color_correction": false, - "img2img_fix_steps": false, - "enable_quantization": false, - "enable_emphasis": true, - "use_old_emphasis_implementation": false, - "enable_batch_seeds": true, - "comma_padding_backtrack": 20, - "filter_nsfw": false, - "CLIP_stop_at_last_layers": 1, - "random_artist_categories": [], - "interrogate_keep_models_in_memory": false, - "interrogate_use_builtin_artists": true, - "interrogate_return_ranks": false, - "interrogate_clip_num_beams": 1, - "interrogate_clip_min_length": 24, - "interrogate_clip_max_length": 48, - "interrogate_clip_dict_limit": 1500, - "interrogate_deepbooru_score_threshold": 0.5, - "deepbooru_sort_alpha": true, - "deepbooru_use_spaces": false, - "deepbooru_escape": true, - "show_progressbar": true, - "show_progress_every_n_steps": 0, - "show_progress_grid": true, - "return_grid": true, - "do_not_show_images": false, - "add_model_hash_to_info": true, - "add_model_name_to_info": false, - "disable_weights_auto_swap": false, - "send_seed": true, - "font": "", - "js_modal_lightbox": true, - "js_modal_lightbox_initially_zoomed": true, - "show_progress_in_title": true, - "quicksettings": "", - "localization": "zh_CN", - "hide_samplers": [], - "eta_ddim": 0.0, - "eta_ancestral": 1.0, - "ddim_discretize": "uniform", - "s_churn": 0.0, - "s_tmin": 0.0, - "s_noise": 1.0, - "eta_noise_seed_delta": 0, - "disabled_extensions": [], - "sagemaker_endpoint": "" -} - diff --git a/modules/call_queue.py b/modules/call_queue.py index 386d4a6f994..01ef298f0ae 100644 --- a/modules/call_queue.py +++ b/modules/call_queue.py @@ -75,7 +75,7 @@ def handle_sagemaker_inference_async(response): return processed - def sagemaker_inference(task, infer, *args, **kwargs): + def sagemaker_inference(task, infer, username, sagemaker_endpoint, *args, **kwargs): infer = 'async' if task == 'text-to-image' or task == 'image-to-image': if task == 'text-to-image': @@ -139,7 +139,7 @@ def sagemaker_inference(task, infer, *args, **kwargs): inputs = { 'task': task, 'txt2img_payload': payload, - 'username': shared.username + 'username': username } else: mode = args[0] @@ -273,14 +273,13 @@ def sagemaker_inference(task, infer, *args, **kwargs): inputs = { 'task': task, 'img2img_payload': payload, - 'username': shared.username + 'username': username } print(sd_samplers.samplers[sampler_index].name) params = { - 'endpoint_name': shared.opts.sagemaker_endpoint + 'endpoint_name': sagemaker_endpoint } - response = requests.post(url=f'{shared.api_endpoint}/inference', params=params, json=inputs) if infer == 'async': processed = handle_sagemaker_inference_async(response) @@ -340,7 +339,7 @@ def sagemaker_inference(task, infer, *args, **kwargs): inputs = { 'task': task, 'extras_single_payload': payload, - 'username': shared.username + 'username': username } else: imageList = [] @@ -372,11 +371,11 @@ def sagemaker_inference(task, infer, *args, **kwargs): inputs = { 'task': task, 'extras_batch_payload': payload, - 'username': shared.username + 'username': username } params = { - 'endpoint_name': shared.opts.sagemaker_endpoint + 'endpoint_name': sagemaker_endpoint } response = requests.post(url=f'{shared.api_endpoint}/inference', params=params, json=inputs) if infer == 'async': @@ -395,11 +394,26 @@ def sagemaker_inference(task, infer, *args, **kwargs): def f(*args, **kwargs): if cmd_opts.pureui and func == modules.txt2img.txt2img: - res = sagemaker_inference('text-to-image', 'sync', *args, **kwargs) + username = args[len(args) - 2] + sagemaker_endpoint = args[len(args) -1] + args = args[:-2] + print('username:', username) + print('sagemaker_endpoint:', sagemaker_endpoint) + res = sagemaker_inference('text-to-image', 'sync', username, sagemaker_endpoint, *args, **kwargs) elif cmd_opts.pureui and func == modules.img2img.img2img: - res = sagemaker_inference('image-to-image', 'sync', *args, **kwargs) + username = args[len(args) - 2] + sagemaker_endpoint = args[len(args) -1] + args = args[:-2] + print('username:', username) + print('sagemaker_endpoint:', sagemaker_endpoint) + res = sagemaker_inference('image-to-image', 'sync', username, sagemaker_endpoint, *args, **kwargs) elif cmd_opts.pureui and func == modules.extras.run_extras: - res = sagemaker_inference('extras', 'sync', *args, **kwargs) + username = args[len(args) - 2] + sagemaker_endpoint = args[len(args) -1] + args = args[:-2] + print('username:', username) + print('sagemaker_endpoint:', sagemaker_endpoint) + res = sagemaker_inference('extras', 'sync', username, sagemaker_endpoint, *args, **kwargs) else: shared.state.begin() with queue_lock: diff --git a/modules/sd_models.py b/modules/sd_models.py index c2f2bad6fd9..8c48904102f 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -56,7 +56,7 @@ def checkpoint_tiles(): return sorted([x.title for x in checkpoints_list.values()], key = alphanumeric_key) -def list_models(): +def list_models(sagemaker_endpoint=None): global checkpoints_list checkpoints_list.clear() @@ -100,7 +100,7 @@ def modeltitle(path, shorthash): if shared.cmd_opts.pureui: params = { - 'endpoint_name': shared.opts.sagemaker_endpoint + 'endpoint_name': sagemaker_endpoint } response = requests.get(url=f'{api_endpoint}/sd/models', params=params) if response.status_code == 200: diff --git a/modules/shared.py b/modules/shared.py index 2209c14d5a3..f733f98de35 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -1,646 +1,661 @@ -import argparse -import datetime -import json -import os -import sys -import time - -import gradio as gr -import tqdm - -import modules.artists -import modules.interrogate -import modules.memmon -import modules.styles -import modules.devices as devices -from modules import localization, sd_vae, extensions, script_loading -from modules.paths import models_path, script_path, sd_path -import requests - -demo = None - -sd_model_file = os.path.join(script_path, 'model.ckpt') -default_sd_model_file = sd_model_file -parser = argparse.ArgumentParser() -parser.add_argument("--config", type=str, default=os.path.join(script_path, "v1-inference.yaml"), help="path to config which constructs model",) -parser.add_argument("--ckpt", type=str, default=sd_model_file, help="path to checkpoint of stable diffusion model; if specified, this checkpoint will be added to the list of checkpoints and loaded",) -parser.add_argument("--ckpt-dir", type=str, default=None, help="Path to directory with stable diffusion checkpoints") -parser.add_argument("--gfpgan-dir", type=str, help="GFPGAN directory", default=('./src/gfpgan' if os.path.exists('./src/gfpgan') else './GFPGAN')) -parser.add_argument("--gfpgan-model", type=str, help="GFPGAN model file name", default=None) -parser.add_argument("--no-half", action='store_true', help="do not switch the model to 16-bit floats") -parser.add_argument("--no-half-vae", action='store_true', help="do not switch the VAE model to 16-bit floats") -parser.add_argument("--no-progressbar-hiding", action='store_true', help="do not hide progressbar in gradio UI (we hide it because it slows down ML if you have hardware acceleration in browser)") -parser.add_argument("--max-batch-count", type=int, default=16, help="maximum batch count value for the UI") -parser.add_argument("--embeddings-dir", type=str, default=os.path.join(script_path, 'embeddings'), help="embeddings directory for textual inversion (default: embeddings)") -parser.add_argument("--hypernetwork-dir", type=str, default=os.path.join(models_path, 'hypernetworks'), help="hypernetwork directory") -parser.add_argument("--localizations-dir", type=str, default=os.path.join(script_path, 'localizations'), help="localizations directory") -parser.add_argument("--allow-code", action='store_true', help="allow custom script execution from webui") -parser.add_argument("--medvram", action='store_true', help="enable stable diffusion model optimizations for sacrificing a little speed for low VRM usage") -parser.add_argument("--lowvram", action='store_true', help="enable stable diffusion model optimizations for sacrificing a lot of speed for very low VRM usage") -parser.add_argument("--lowram", action='store_true', help="load stable diffusion checkpoint weights to VRAM instead of RAM") -parser.add_argument("--always-batch-cond-uncond", action='store_true', help="disables cond/uncond batching that is enabled to save memory with --medvram or --lowvram") -parser.add_argument("--unload-gfpgan", action='store_true', help="does not do anything.") -parser.add_argument("--precision", type=str, help="evaluate at this precision", choices=["full", "autocast"], default="autocast") -parser.add_argument("--share", action='store_true', help="use share=True for gradio and make the UI accessible through their site") -parser.add_argument("--ngrok", type=str, help="ngrok authtoken, alternative to gradio --share", default=None) -parser.add_argument("--ngrok-region", type=str, help="The region in which ngrok should start.", default="us") -parser.add_argument("--enable-insecure-extension-access", action='store_true', help="enable extensions tab regardless of other options") -parser.add_argument("--codeformer-models-path", type=str, help="Path to directory with codeformer model file(s).", default=os.path.join(models_path, 'Codeformer')) -parser.add_argument("--gfpgan-models-path", type=str, help="Path to directory with GFPGAN model file(s).", default=os.path.join(models_path, 'GFPGAN')) -parser.add_argument("--esrgan-models-path", type=str, help="Path to directory with ESRGAN model file(s).", default=os.path.join(models_path, 'ESRGAN')) -parser.add_argument("--bsrgan-models-path", type=str, help="Path to directory with BSRGAN model file(s).", default=os.path.join(models_path, 'BSRGAN')) -parser.add_argument("--realesrgan-models-path", type=str, help="Path to directory with RealESRGAN model file(s).", default=os.path.join(models_path, 'RealESRGAN')) -parser.add_argument("--clip-models-path", type=str, help="Path to directory with CLIP model file(s).", default=None) -parser.add_argument("--xformers", action='store_true', help="enable xformers for cross attention layers") -parser.add_argument("--force-enable-xformers", action='store_true', help="enable xformers for cross attention layers regardless of whether the checking code thinks you can run it; do not make bug reports if this fails to work") -parser.add_argument("--deepdanbooru", action='store_true', help="does not do anything") -parser.add_argument("--opt-split-attention", action='store_true', help="force-enables Doggettx's cross-attention layer optimization. By default, it's on for torch cuda.") -parser.add_argument("--opt-split-attention-invokeai", action='store_true', help="force-enables InvokeAI's cross-attention layer optimization. By default, it's on when cuda is unavailable.") -parser.add_argument("--opt-split-attention-v1", action='store_true', help="enable older version of split attention optimization that does not consume all the VRAM it can find") -parser.add_argument("--disable-opt-split-attention", action='store_true', help="force-disables cross-attention layer optimization") -parser.add_argument("--use-cpu", nargs='+', help="use CPU as torch device for specified modules", default=[], type=str.lower) -parser.add_argument("--listen", action='store_true', help="launch gradio with 0.0.0.0 as server name, allowing to respond to network requests") -parser.add_argument("--port", type=int, help="launch gradio with given server port, you need root/admin rights for ports < 1024, defaults to 7860 if available", default=None) -parser.add_argument("--show-negative-prompt", action='store_true', help="does not do anything", default=False) -parser.add_argument("--ui-config-file", type=str, help="filename to use for ui configuration", default=os.path.join(script_path, 'ui-config.json')) -parser.add_argument("--hide-ui-dir-config", action='store_true', help="hide directory configuration from webui", default=False) -parser.add_argument("--freeze-settings", action='store_true', help="disable editing settings", default=False) -parser.add_argument("--ui-settings-file", type=str, help="filename to use for ui settings", default=os.path.join(script_path, 'config.json')) -parser.add_argument("--gradio-debug", action='store_true', help="launch gradio with --debug option") -parser.add_argument("--gradio-auth", type=str, help='set gradio authentication like "username:password"; or comma-delimit multiple like "u1:p1,u2:p2,u3:p3"', default=None) -parser.add_argument("--gradio-img2img-tool", type=str, help='gradio image uploader tool: can be either editor for ctopping, or color-sketch for drawing', choices=["color-sketch", "editor"], default="editor") -parser.add_argument("--gradio-inpaint-tool", type=str, choices=["sketch", "color-sketch"], default="sketch", help="gradio inpainting editor: can be either sketch to only blur/noise the input, or color-sketch to paint over it") -parser.add_argument("--opt-channelslast", action='store_true', help="change memory type for stable diffusion to channels last") -parser.add_argument("--styles-file", type=str, help="filename to use for styles", default=os.path.join(script_path, 'styles.csv')) -parser.add_argument("--autolaunch", action='store_true', help="open the webui URL in the system's default browser upon launch", default=False) -parser.add_argument("--theme", type=str, help="launches the UI with light or dark theme", default=None) -parser.add_argument("--use-textbox-seed", action='store_true', help="use textbox for seeds in UI (no up/down, but possible to input long seeds)", default=False) -parser.add_argument("--disable-console-progressbars", action='store_true', help="do not output progressbars to console", default=False) -parser.add_argument("--enable-console-prompts", action='store_true', help="print prompts to console when generating with txt2img and img2img", default=False) -parser.add_argument('--vae-path', type=str, help='Path to Variational Autoencoders model', default=None) -parser.add_argument("--disable-safe-unpickle", action='store_true', help="disable checking pytorch models for malicious code", default=False) -parser.add_argument("--api", action='store_true', help="use api=True to launch the API together with the webui (use --nowebui instead for only the API)") -parser.add_argument("--api-auth", type=str, help='Set authentication for API like "username:password"; or comma-delimit multiple like "u1:p1,u2:p2,u3:p3"', default=None) -parser.add_argument("--nowebui", action='store_true', help="use api=True to launch the API instead of the webui") -parser.add_argument("--ui-debug-mode", action='store_true', help="Don't load model to quickly launch UI") -parser.add_argument("--device-id", type=str, help="Select the default CUDA device to use (export CUDA_VISIBLE_DEVICES=0,1,etc might be needed before)", default=None) -parser.add_argument("--administrator", action='store_true', help="Administrator rights", default=False) -parser.add_argument("--cors-allow-origins", type=str, help="Allowed CORS origin(s) in the form of a comma-separated list (no spaces)", default=None) -parser.add_argument("--cors-allow-origins-regex", type=str, help="Allowed CORS origin(s) in the form of a single regular expression", default=None) -parser.add_argument("--tls-keyfile", type=str, help="Partially enables TLS, requires --tls-certfile to fully function", default=None) -parser.add_argument("--tls-certfile", type=str, help="Partially enables TLS, requires --tls-keyfile to fully function", default=None) -parser.add_argument("--server-name", type=str, help="Sets hostname of server", default=None) -parser.add_argument("--pureui", action='store_true', help="Pure UI without local inference and progress bar", default=False) -parser.add_argument("--train", action='store_true', help="Train only on SageMaker", default=False) -parser.add_argument("--train-task", type=str, help='Train task - embedding or hypernetwork', default='embedding') -parser.add_argument("--train-args", type=str, help='Train args', default='') -parser.add_argument('--embeddings-s3uri', default='', type=str, help='Embedding S3Uri') -parser.add_argument('--hypernetwork-s3uri', default='', type=str, help='Hypernetwork S3Uri') -parser.add_argument('--sd-models-s3uri', default='', type=str, help='SD Models S3Uri') -parser.add_argument('--db-models-s3uri', default='', type=str, help='DB Models S3Uri') -parser.add_argument('--region-name', type=str, help='Region Name') -parser.add_argument('--username', default='', type=str, help='Username') -parser.add_argument('--api-endpoint', default='', type=str, help='API Endpoint') -parser.add_argument('--dreambooth-config-id', default='', type=str, help='Dreambooth config ID') - -script_loading.preload_extensions(extensions.extensions_dir, parser) -script_loading.preload_extensions(extensions.extensions_builtin_dir, parser) - -cmd_opts = parser.parse_args() - -restricted_opts = { - "samples_filename_pattern", - "directories_filename_pattern", - "outdir_samples", - "outdir_txt2img_samples", - "outdir_img2img_samples", - "outdir_extras_samples", - "outdir_grids", - "outdir_txt2img_grids", - "outdir_save", -} - -cmd_opts.disable_extension_access = (cmd_opts.share or cmd_opts.listen or cmd_opts.server_name) and not cmd_opts.enable_insecure_extension_access - -devices.device, devices.device_interrogate, devices.device_gfpgan, devices.device_esrgan, devices.device_codeformer = \ - (devices.cpu if any(y in cmd_opts.use_cpu for y in [x, 'all']) else devices.get_optimal_device() for x in ['sd', 'interrogate', 'gfpgan', 'esrgan', 'codeformer']) - -device = devices.device -weight_load_location = None if cmd_opts.lowram else "cpu" - -batch_cond_uncond = cmd_opts.always_batch_cond_uncond or not (cmd_opts.lowvram or cmd_opts.medvram) -parallel_processing_allowed = not cmd_opts.lowvram and not cmd_opts.medvram -xformers_available = False -config_filename = cmd_opts.ui_settings_file - -os.makedirs(cmd_opts.hypernetwork_dir, exist_ok=True) -hypernetworks = {} -loaded_hypernetwork = None - -if cmd_opts.pureui: - username = '' - api_endpoint = os.environ['api_endpoint'] - industrial_model = '' - default_options = {} - -def reload_hypernetworks(): - from modules.hypernetworks import hypernetwork - global hypernetworks - - hypernetworks = hypernetwork.list_hypernetworks(cmd_opts.hypernetwork_dir) - hypernetwork.load_hypernetwork(opts.sd_hypernetwork) - -class State: - skipped = False - interrupted = False - job = "" - job_no = 0 - job_count = 0 - job_timestamp = '0' - sampling_step = 0 - sampling_steps = 0 - current_latent = None - current_image = None - current_image_sampling_step = 0 - textinfo = None - time_start = None - need_restart = False - - def skip(self): - self.skipped = True - - def interrupt(self): - self.interrupted = True - - def nextjob(self): - if opts.show_progress_every_n_steps == -1: - self.do_set_current_image() - - self.job_no += 1 - self.sampling_step = 0 - self.current_image_sampling_step = 0 - - def dict(self): - obj = { - "skipped": self.skipped, - "interrupted": self.skipped, - "job": self.job, - "job_count": self.job_count, - "job_no": self.job_no, - "sampling_step": self.sampling_step, - "sampling_steps": self.sampling_steps, - } - - return obj - - def begin(self): - self.sampling_step = 0 - self.job_count = -1 - self.job_no = 0 - self.job_timestamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S") - self.current_latent = None - self.current_image = None - self.current_image_sampling_step = 0 - self.skipped = False - self.interrupted = False - self.textinfo = None - self.time_start = time.time() - - devices.torch_gc() - - def end(self): - self.job = "" - self.job_count = 0 - - devices.torch_gc() - - """sets self.current_image from self.current_latent if enough sampling steps have been made after the last call to this""" - def set_current_image(self): - if self.sampling_step - self.current_image_sampling_step >= opts.show_progress_every_n_steps and opts.show_progress_every_n_steps > 0: - self.do_set_current_image() - - def do_set_current_image(self): - if not parallel_processing_allowed: - return - if self.current_latent is None: - return - - import modules.sd_samplers - if opts.show_progress_grid: - self.current_image = modules.sd_samplers.samples_to_image_grid(self.current_latent) - else: - self.current_image = modules.sd_samplers.sample_to_image(self.current_latent) - - self.current_image_sampling_step = self.sampling_step - -state = State() - -artist_db = modules.artists.ArtistsDatabase(os.path.join(script_path, 'artists.csv')) - -styles_filename = cmd_opts.styles_file -prompt_styles = modules.styles.StyleDatabase(styles_filename) - -interrogator = modules.interrogate.InterrogateModels("interrogate") - -face_restorers = [] - - -def realesrgan_models_names(): - import modules.realesrgan_model - return [x.name for x in modules.realesrgan_model.get_realesrgan_models(None)] - - -class OptionInfo: - def __init__(self, default=None, label="", component=None, component_args=None, onchange=None, section=None, refresh=None): - self.default = default - self.label = label - self.component = component - self.component_args = component_args - self.onchange = onchange - self.section = section - self.refresh = refresh - - -def options_section(section_identifier, options_dict): - for k, v in options_dict.items(): - v.section = section_identifier - - return options_dict - - -def list_checkpoint_tiles(): - import modules.sd_models - return modules.sd_models.checkpoint_tiles() - - -def refresh_checkpoints(): - import modules.sd_models - return modules.sd_models.list_models() - - -def list_samplers(): - import modules.sd_samplers - return modules.sd_samplers.all_samplers - - -hide_dirs = {"visible": not cmd_opts.hide_ui_dir_config} - -options_templates = {} - -def refresh_sagemaker_endpoints(): - global industrial_model, api_endpoint, default_options - - if industrial_model == '': - response = requests.get(url=f'{api_endpoint}/sd/industrialmodel') - if response.status_code == 200: - industrial_model = response.text - else: - model_name = 'stable-diffusion-webui' - model_description = model_name - inputs = { - 'model_algorithm': 'stable-diffusion-webui', - 'model_name': model_name, - 'model_description': model_description, - 'model_extra': '{"visible": "false"}', - 'model_samples': '', - 'file_content': { - 'data': [(lambda x: int(x))(x) for x in open(os.path.join(script_path, 'logo.ico'), 'rb').read()] - } - } - - response = requests.post(url=f'{api_endpoint}/industrialmodel', json = inputs) - if response.status_code == 200: - body = json.loads(response.text) - industrial_model = body['id'] - - default_options = opts.data - - sagemaker_endpoints = [] - - if industrial_model != '': - params = { - 'industrial_model': industrial_model - } - response = requests.get(url=f'{api_endpoint}/endpoint', params=params) - if response.status_code == 200: - for endpoint_item in json.loads(response.text): - sagemaker_endpoints.append(endpoint_item['EndpointName']) - - return sagemaker_endpoints - -options_templates.update(options_section(('sd', "Stable Diffusion"), { - "sagemaker_endpoint": OptionInfo(None, "SaegMaker endpoint", gr.Dropdown, lambda: {"choices": refresh_sagemaker_endpoints()}, refresh=refresh_sagemaker_endpoints), - "sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": list_checkpoint_tiles()}, refresh=refresh_checkpoints), - "sd_checkpoint_cache": OptionInfo(0, "Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}), - "sd_vae": OptionInfo("auto", "SD VAE", gr.Dropdown, lambda: {"choices": sd_vae.vae_list}, refresh=sd_vae.refresh_vae_list), - "sd_vae_as_default": OptionInfo(False, "Ignore selected VAE for stable diffusion checkpoints that have their own .vae.pt next to them"), - "sd_hypernetwork": OptionInfo("None", "Hypernetwork", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in hypernetworks.keys()]}, refresh=reload_hypernetworks), - "sd_hypernetwork_strength": OptionInfo(1.0, "Hypernetwork strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.001}), - "inpainting_mask_weight": OptionInfo(1.0, "Inpainting conditioning mask strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), - "img2img_color_correction": OptionInfo(False, "Apply color correction to img2img results to match original colors."), - "img2img_fix_steps": OptionInfo(False, "With img2img, do exactly the amount of steps the slider specifies (normally you'd do less with less denoising)."), - "enable_quantization": OptionInfo(False, "Enable quantization in K samplers for sharper and cleaner results. This may change existing seeds. Requires restart to apply."), - "enable_emphasis": OptionInfo(True, "Emphasis: use (text) to make model pay more attention to text and [text] to make it pay less attention"), - "use_old_emphasis_implementation": OptionInfo(False, "Use old emphasis implementation. Can be useful to reproduce old seeds."), - "enable_batch_seeds": OptionInfo(True, "Make K-diffusion samplers produce same images in a batch as when making a single image"), - "comma_padding_backtrack": OptionInfo(20, "Increase coherency by padding from the last comma within n tokens when using more than 75 tokens", gr.Slider, {"minimum": 0, "maximum": 74, "step": 1 }), - "filter_nsfw": OptionInfo(False, "Filter NSFW content"), - 'CLIP_stop_at_last_layers': OptionInfo(1, "Clip skip", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}), - "random_artist_categories": OptionInfo([], "Allowed categories for random artists selection when using the Roll button", gr.CheckboxGroup, {"choices": artist_db.categories()}), -})) - -options_templates.update(options_section(('saving-images', "Saving images/grids"), { - "samples_save": OptionInfo(True, "Always save all generated images"), - "samples_format": OptionInfo('png', 'File format for images'), - "samples_filename_pattern": OptionInfo("", "Images filename pattern", component_args=hide_dirs), - "save_images_add_number": OptionInfo(True, "Add number to filename when saving", component_args=hide_dirs), - - "grid_save": OptionInfo(True, "Always save all generated image grids"), - "grid_format": OptionInfo('png', 'File format for grids'), - "grid_extended_filename": OptionInfo(False, "Add extended info (seed, prompt) to filename when saving grid"), - "grid_only_if_multiple": OptionInfo(True, "Do not save grids consisting of one picture"), - "grid_prevent_empty_spots": OptionInfo(False, "Prevent empty spots in grid (when set to autodetect)"), - "n_rows": OptionInfo(-1, "Grid row count; use -1 for autodetect and 0 for it to be same as batch size", gr.Slider, {"minimum": -1, "maximum": 16, "step": 1}), - - "enable_pnginfo": OptionInfo(True, "Save text information about generation parameters as chunks to png files"), - "save_txt": OptionInfo(False, "Create a text file next to every image with generation parameters."), - "save_images_before_face_restoration": OptionInfo(False, "Save a copy of image before doing face restoration."), - "save_images_before_highres_fix": OptionInfo(False, "Save a copy of image before applying highres fix."), - "save_images_before_color_correction": OptionInfo(False, "Save a copy of image before applying color correction to img2img results"), - "jpeg_quality": OptionInfo(80, "Quality for saved jpeg images", gr.Slider, {"minimum": 1, "maximum": 100, "step": 1}), - "export_for_4chan": OptionInfo(True, "If PNG image is larger than 4MB or any dimension is larger than 4000, downscale and save copy as JPG"), - - "use_original_name_batch": OptionInfo(False, "Use original name for output filename during batch process in extras tab"), - "save_selected_only": OptionInfo(True, "When using 'Save' button, only save a single selected image"), - "do_not_add_watermark": OptionInfo(False, "Do not add watermark to images"), - - "temp_dir": OptionInfo("", "Directory for temporary images; leave empty for default"), - "clean_temp_dir_at_start": OptionInfo(False, "Cleanup non-default temporary directory when starting webui"), - -})) - -options_templates.update(options_section(('saving-paths', "Paths for saving"), { - "outdir_samples": OptionInfo("", "Output directory for images; if empty, defaults to three directories below", component_args=hide_dirs), - "outdir_txt2img_samples": OptionInfo("outputs/txt2img-images", 'Output directory for txt2img images', component_args=hide_dirs), - "outdir_img2img_samples": OptionInfo("outputs/img2img-images", 'Output directory for img2img images', component_args=hide_dirs), - "outdir_extras_samples": OptionInfo("outputs/extras-images", 'Output directory for images from extras tab', component_args=hide_dirs), - "outdir_grids": OptionInfo("", "Output directory for grids; if empty, defaults to two directories below", component_args=hide_dirs), - "outdir_txt2img_grids": OptionInfo("outputs/txt2img-grids", 'Output directory for txt2img grids', component_args=hide_dirs), - "outdir_img2img_grids": OptionInfo("outputs/img2img-grids", 'Output directory for img2img grids', component_args=hide_dirs), - "outdir_save": OptionInfo("log/images", "Directory for saving images using the Save button", component_args=hide_dirs), -})) - -options_templates.update(options_section(('saving-to-dirs', "Saving to a directory"), { - "save_to_dirs": OptionInfo(False, "Save images to a subdirectory"), - "grid_save_to_dirs": OptionInfo(False, "Save grids to a subdirectory"), - "use_save_to_dirs_for_ui": OptionInfo(False, "When using \"Save\" button, save images to a subdirectory"), - "directories_filename_pattern": OptionInfo("", "Directory name pattern", component_args=hide_dirs), - "directories_max_prompt_words": OptionInfo(8, "Max prompt words for [prompt_words] pattern", gr.Slider, {"minimum": 1, "maximum": 20, "step": 1, **hide_dirs}), -})) - -options_templates.update(options_section(('upscaling', "Upscaling"), { - "ESRGAN_tile": OptionInfo(192, "Tile size for ESRGAN upscalers. 0 = no tiling.", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}), - "ESRGAN_tile_overlap": OptionInfo(8, "Tile overlap, in pixels for ESRGAN upscalers. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}), - "realesrgan_enabled_models": OptionInfo(["R-ESRGAN 4x+", "R-ESRGAN 4x+ Anime6B"], "Select which Real-ESRGAN models to show in the web UI. (Requires restart)", gr.CheckboxGroup, lambda: {"choices": realesrgan_models_names()}), - "upscaler_for_img2img": OptionInfo(None, "Upscaler for img2img", gr.Dropdown, lambda: {"choices": [x.name for x in sd_upscalers]}), - "use_scale_latent_for_hires_fix": OptionInfo(False, "Upscale latent space image when doing hires. fix"), -})) - -options_templates.update(options_section(('face-restoration', "Face restoration"), { - "face_restoration_model": OptionInfo(None, "Face restoration model", gr.Radio, lambda: {"choices": [x.name() for x in face_restorers]}), - "code_former_weight": OptionInfo(0.5, "CodeFormer weight parameter; 0 = maximum effect; 1 = minimum effect", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.01}), - "face_restoration_unload": OptionInfo(False, "Move face restoration model from VRAM into RAM after processing"), -})) - -options_templates.update(options_section(('system', "System"), { - "memmon_poll_rate": OptionInfo(8, "VRAM usage polls per second during generation. Set to 0 to disable.", gr.Slider, {"minimum": 0, "maximum": 40, "step": 1}), - "samples_log_stdout": OptionInfo(False, "Always print all generation info to standard output"), - "multiple_tqdm": OptionInfo(True, "Add a second progress bar to the console that shows progress for an entire job."), -})) - -options_templates.update(options_section(('training', "Training"), { - "unload_models_when_training": OptionInfo(False, "Move VAE and CLIP to RAM when training if possible. Saves VRAM."), - "pin_memory": OptionInfo(False, "Turn on pin_memory for DataLoader. Makes training slightly faster but can increase memory usage."), - "save_optimizer_state": OptionInfo(False, "Saves Optimizer state as separate *.optim file. Training can be resumed with HN itself and matching optim file."), - "dataset_filename_word_regex": OptionInfo("", "Filename word regex"), - "dataset_filename_join_string": OptionInfo(" ", "Filename join string"), - "training_image_repeats_per_epoch": OptionInfo(1, "Number of repeats for a single input image per epoch; used only for displaying epoch number", gr.Number, {"precision": 0}), - "training_write_csv_every": OptionInfo(500, "Save an csv containing the loss to log directory every N steps, 0 to disable"), - "training_xattention_optimizations": OptionInfo(False, "Use cross attention optimizations while training"), -})) - -options_templates.update(options_section(('interrogate', "Interrogate Options"), { - "interrogate_keep_models_in_memory": OptionInfo(False, "Interrogate: keep models in VRAM"), - "interrogate_use_builtin_artists": OptionInfo(True, "Interrogate: use artists from artists.csv"), - "interrogate_return_ranks": OptionInfo(False, "Interrogate: include ranks of model tags matches in results (Has no effect on caption-based interrogators)."), - "interrogate_clip_num_beams": OptionInfo(1, "Interrogate: num_beams for BLIP", gr.Slider, {"minimum": 1, "maximum": 16, "step": 1}), - "interrogate_clip_min_length": OptionInfo(24, "Interrogate: minimum description length (excluding artists, etc..)", gr.Slider, {"minimum": 1, "maximum": 128, "step": 1}), - "interrogate_clip_max_length": OptionInfo(48, "Interrogate: maximum description length", gr.Slider, {"minimum": 1, "maximum": 256, "step": 1}), - "interrogate_clip_dict_limit": OptionInfo(1500, "CLIP: maximum number of lines in text file (0 = No limit)"), - "interrogate_deepbooru_score_threshold": OptionInfo(0.5, "Interrogate: deepbooru score threshold", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.01}), - "deepbooru_sort_alpha": OptionInfo(True, "Interrogate: deepbooru sort alphabetically"), - "deepbooru_use_spaces": OptionInfo(False, "use spaces for tags in deepbooru"), - "deepbooru_escape": OptionInfo(True, "escape (\\) brackets in deepbooru (so they are used as literal brackets and not for emphasis)"), -})) - -options_templates.update(options_section(('ui', "User interface"), { - "show_progressbar": OptionInfo(True, "Show progressbar"), - "show_progress_every_n_steps": OptionInfo(0, "Show image creation progress every N sampling steps. Set to 0 to disable. Set to -1 to show after completion of batch.", gr.Slider, {"minimum": -1, "maximum": 32, "step": 1}), - "show_progress_grid": OptionInfo(True, "Show previews of all images generated in a batch as a grid"), - "return_grid": OptionInfo(True, "Show grid in results for web"), - "do_not_show_images": OptionInfo(False, "Do not show any images in results for web"), - "add_model_hash_to_info": OptionInfo(True, "Add model hash to generation information"), - "add_model_name_to_info": OptionInfo(False, "Add model name to generation information"), - "disable_weights_auto_swap": OptionInfo(False, "When reading generation parameters from text into UI (from PNG info or pasted text), do not change the selected model/checkpoint."), - "send_seed": OptionInfo(True, "Send seed when sending prompt or image to other interface"), - "font": OptionInfo("", "Font for image grids that have text"), - "js_modal_lightbox": OptionInfo(True, "Enable full page image viewer"), - "js_modal_lightbox_initially_zoomed": OptionInfo(True, "Show images zoomed in by default in full page image viewer"), - "show_progress_in_title": OptionInfo(True, "Show generation progress in window title."), - 'quicksettings': OptionInfo("sagemaker_endpoint", "Quicksettings list"), - 'localization': OptionInfo("None", "Localization (requires restart)", gr.Dropdown, lambda: {"choices": ["None"] + list(localization.localizations.keys())}, refresh=lambda: localization.list_localizations(cmd_opts.localizations_dir)), -})) - -options_templates.update(options_section(('sampler-params', "Sampler parameters"), { - "hide_samplers": OptionInfo([], "Hide samplers in user interface (requires restart)", gr.CheckboxGroup, lambda: {"choices": [x.name for x in list_samplers()]}), - "eta_ddim": OptionInfo(0.0, "eta (noise multiplier) for DDIM", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), - "eta_ancestral": OptionInfo(1.0, "eta (noise multiplier) for ancestral samplers", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), - "ddim_discretize": OptionInfo('uniform', "img2img DDIM discretize", gr.Radio, {"choices": ['uniform', 'quad']}), - 's_churn': OptionInfo(0.0, "sigma churn", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), - 's_tmin': OptionInfo(0.0, "sigma tmin", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), - 's_noise': OptionInfo(1.0, "sigma noise", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), - 'eta_noise_seed_delta': OptionInfo(0, "Eta noise seed delta", gr.Number, {"precision": 0}), -})) - -options_templates.update(options_section((None, "Hidden options"), { - "disabled_extensions": OptionInfo([], "Disable those extensions"), -})) - -options_templates.update() - - -class Options: - data = None - data_labels = options_templates - typemap = {int: float} - - def __init__(self): - self.data = {k: v.default for k, v in self.data_labels.items()} - - def __setattr__(self, key, value): - if self.data is not None: - if key in self.data or key in self.data_labels: - assert not cmd_opts.freeze_settings, "changing settings is disabled" - - info = opts.data_labels.get(key, None) - comp_args = info.component_args if info else None - if isinstance(comp_args, dict) and comp_args.get('visible', True) is False: - raise RuntimeError(f"not possible to set {key} because it is restricted") - - if cmd_opts.hide_ui_dir_config and key in restricted_opts: - raise RuntimeError(f"not possible to set {key} because it is restricted") - - self.data[key] = value - return - - return super(Options, self).__setattr__(key, value) - - def __getattr__(self, item): - if self.data is not None: - if item in self.data: - return self.data[item] - - if item in self.data_labels: - return self.data_labels[item].default - - return super(Options, self).__getattribute__(item) - - def set(self, key, value): - """sets an option and calls its onchange callback, returning True if the option changed and False otherwise""" - - oldval = self.data.get(key, None) - if oldval == value: - return False - - try: - setattr(self, key, value) - except RuntimeError: - return False - - if self.data_labels[key].onchange is not None: - self.data_labels[key].onchange() - - return True - - def save(self, filename): - assert not cmd_opts.freeze_settings, "saving settings is disabled" - - with open(filename, "w", encoding="utf8") as file: - json.dump(self.data, file, indent=4) - - def same_type(self, x, y): - if x is None or y is None: - return True - - type_x = self.typemap.get(type(x), type(x)) - type_y = self.typemap.get(type(y), type(y)) - - return type_x == type_y - - def load(self, filename): - with open(filename, "r", encoding="utf8") as file: - self.data = json.load(file) - - bad_settings = 0 - for k, v in self.data.items(): - info = self.data_labels.get(k, None) - if info is not None and not self.same_type(info.default, v): - print(f"Warning: bad setting value: {k}: {v} ({type(v).__name__}; expected {type(info.default).__name__})", file=sys.stderr) - bad_settings += 1 - - if bad_settings > 0: - print(f"The program is likely to not work with bad settings.\nSettings file: {filename}\nEither fix the file, or delete it and restart.", file=sys.stderr) - - def onchange(self, key, func, call=True): - item = self.data_labels.get(key) - item.onchange = func - - if call: - func() - - def dumpjson(self): - d = {k: self.data.get(k, self.data_labels.get(k).default) for k in self.data_labels.keys()} - return json.dumps(d) - - def add_option(self, key, info): - self.data_labels[key] = info - - def reorder(self): - """reorder settings so that all items related to section always go together""" - - section_ids = {} - settings_items = self.data_labels.items() - for k, item in settings_items: - if item.section not in section_ids: - section_ids[item.section] = len(section_ids) - - self.data_labels = {k: v for k, v in sorted(settings_items, key=lambda x: section_ids[x[1].section])} - - -opts = Options() -if os.path.exists(config_filename): - opts.load(config_filename) - -if cmd_opts.pureui and opts.localization == None: - opts.localization = "zh_CN" - -sd_upscalers = [] - -sd_model = None - -clip_model = None - -progress_print_out = sys.stdout - - -class TotalTQDM: - def __init__(self): - self._tqdm = None - - def reset(self): - self._tqdm = tqdm.tqdm( - desc="Total progress", - total=state.job_count * state.sampling_steps, - position=1, - file=progress_print_out - ) - - def update(self): - if not opts.multiple_tqdm or cmd_opts.disable_console_progressbars: - return - if self._tqdm is None: - self.reset() - self._tqdm.update() - - def updateTotal(self, new_total): - if not opts.multiple_tqdm or cmd_opts.disable_console_progressbars: - return - if self._tqdm is None: - self.reset() - self._tqdm.total=new_total - - def clear(self): - if self._tqdm is not None: - self._tqdm.close() - self._tqdm = None - - -total_tqdm = TotalTQDM() - -mem_mon = modules.memmon.MemUsageMonitor("MemMon", device, opts) -mem_mon.start() - - -def listfiles(dirname): - filenames = [os.path.join(dirname, x) for x in sorted(os.listdir(dirname)) if not x.startswith(".")] - return [file for file in filenames if os.path.isfile(file)] +import argparse +import datetime +import json +import os +import sys +import time + +import gradio as gr +import tqdm + +import modules.artists +import modules.interrogate +import modules.memmon +import modules.styles +import modules.devices as devices +from modules import localization, sd_vae, extensions, script_loading +from modules.paths import models_path, script_path, sd_path +import requests + +demo = None + +sd_model_file = os.path.join(script_path, 'model.ckpt') +default_sd_model_file = sd_model_file +parser = argparse.ArgumentParser() +parser.add_argument("--config", type=str, default=os.path.join(script_path, "v1-inference.yaml"), help="path to config which constructs model",) +parser.add_argument("--ckpt", type=str, default=sd_model_file, help="path to checkpoint of stable diffusion model; if specified, this checkpoint will be added to the list of checkpoints and loaded",) +parser.add_argument("--ckpt-dir", type=str, default=None, help="Path to directory with stable diffusion checkpoints") +parser.add_argument("--gfpgan-dir", type=str, help="GFPGAN directory", default=('./src/gfpgan' if os.path.exists('./src/gfpgan') else './GFPGAN')) +parser.add_argument("--gfpgan-model", type=str, help="GFPGAN model file name", default=None) +parser.add_argument("--no-half", action='store_true', help="do not switch the model to 16-bit floats") +parser.add_argument("--no-half-vae", action='store_true', help="do not switch the VAE model to 16-bit floats") +parser.add_argument("--no-progressbar-hiding", action='store_true', help="do not hide progressbar in gradio UI (we hide it because it slows down ML if you have hardware acceleration in browser)") +parser.add_argument("--max-batch-count", type=int, default=16, help="maximum batch count value for the UI") +parser.add_argument("--embeddings-dir", type=str, default=os.path.join(script_path, 'embeddings'), help="embeddings directory for textual inversion (default: embeddings)") +parser.add_argument("--hypernetwork-dir", type=str, default=os.path.join(models_path, 'hypernetworks'), help="hypernetwork directory") +parser.add_argument("--localizations-dir", type=str, default=os.path.join(script_path, 'localizations'), help="localizations directory") +parser.add_argument("--allow-code", action='store_true', help="allow custom script execution from webui") +parser.add_argument("--medvram", action='store_true', help="enable stable diffusion model optimizations for sacrificing a little speed for low VRM usage") +parser.add_argument("--lowvram", action='store_true', help="enable stable diffusion model optimizations for sacrificing a lot of speed for very low VRM usage") +parser.add_argument("--lowram", action='store_true', help="load stable diffusion checkpoint weights to VRAM instead of RAM") +parser.add_argument("--always-batch-cond-uncond", action='store_true', help="disables cond/uncond batching that is enabled to save memory with --medvram or --lowvram") +parser.add_argument("--unload-gfpgan", action='store_true', help="does not do anything.") +parser.add_argument("--precision", type=str, help="evaluate at this precision", choices=["full", "autocast"], default="autocast") +parser.add_argument("--share", action='store_true', help="use share=True for gradio and make the UI accessible through their site") +parser.add_argument("--ngrok", type=str, help="ngrok authtoken, alternative to gradio --share", default=None) +parser.add_argument("--ngrok-region", type=str, help="The region in which ngrok should start.", default="us") +parser.add_argument("--enable-insecure-extension-access", action='store_true', help="enable extensions tab regardless of other options") +parser.add_argument("--codeformer-models-path", type=str, help="Path to directory with codeformer model file(s).", default=os.path.join(models_path, 'Codeformer')) +parser.add_argument("--gfpgan-models-path", type=str, help="Path to directory with GFPGAN model file(s).", default=os.path.join(models_path, 'GFPGAN')) +parser.add_argument("--esrgan-models-path", type=str, help="Path to directory with ESRGAN model file(s).", default=os.path.join(models_path, 'ESRGAN')) +parser.add_argument("--bsrgan-models-path", type=str, help="Path to directory with BSRGAN model file(s).", default=os.path.join(models_path, 'BSRGAN')) +parser.add_argument("--realesrgan-models-path", type=str, help="Path to directory with RealESRGAN model file(s).", default=os.path.join(models_path, 'RealESRGAN')) +parser.add_argument("--clip-models-path", type=str, help="Path to directory with CLIP model file(s).", default=None) +parser.add_argument("--xformers", action='store_true', help="enable xformers for cross attention layers") +parser.add_argument("--force-enable-xformers", action='store_true', help="enable xformers for cross attention layers regardless of whether the checking code thinks you can run it; do not make bug reports if this fails to work") +parser.add_argument("--deepdanbooru", action='store_true', help="does not do anything") +parser.add_argument("--opt-split-attention", action='store_true', help="force-enables Doggettx's cross-attention layer optimization. By default, it's on for torch cuda.") +parser.add_argument("--opt-split-attention-invokeai", action='store_true', help="force-enables InvokeAI's cross-attention layer optimization. By default, it's on when cuda is unavailable.") +parser.add_argument("--opt-split-attention-v1", action='store_true', help="enable older version of split attention optimization that does not consume all the VRAM it can find") +parser.add_argument("--disable-opt-split-attention", action='store_true', help="force-disables cross-attention layer optimization") +parser.add_argument("--use-cpu", nargs='+', help="use CPU as torch device for specified modules", default=[], type=str.lower) +parser.add_argument("--listen", action='store_true', help="launch gradio with 0.0.0.0 as server name, allowing to respond to network requests") +parser.add_argument("--port", type=int, help="launch gradio with given server port, you need root/admin rights for ports < 1024, defaults to 7860 if available", default=None) +parser.add_argument("--show-negative-prompt", action='store_true', help="does not do anything", default=False) +parser.add_argument("--ui-config-file", type=str, help="filename to use for ui configuration", default=os.path.join(script_path, 'ui-config.json')) +parser.add_argument("--hide-ui-dir-config", action='store_true', help="hide directory configuration from webui", default=False) +parser.add_argument("--freeze-settings", action='store_true', help="disable editing settings", default=False) +parser.add_argument("--ui-settings-file", type=str, help="filename to use for ui settings", default=os.path.join(script_path, 'config.json')) +parser.add_argument("--gradio-debug", action='store_true', help="launch gradio with --debug option") +parser.add_argument("--gradio-auth", type=str, help='set gradio authentication like "username:password"; or comma-delimit multiple like "u1:p1,u2:p2,u3:p3"', default=None) +parser.add_argument("--gradio-img2img-tool", type=str, help='gradio image uploader tool: can be either editor for ctopping, or color-sketch for drawing', choices=["color-sketch", "editor"], default="editor") +parser.add_argument("--gradio-inpaint-tool", type=str, choices=["sketch", "color-sketch"], default="sketch", help="gradio inpainting editor: can be either sketch to only blur/noise the input, or color-sketch to paint over it") +parser.add_argument("--opt-channelslast", action='store_true', help="change memory type for stable diffusion to channels last") +parser.add_argument("--styles-file", type=str, help="filename to use for styles", default=os.path.join(script_path, 'styles.csv')) +parser.add_argument("--autolaunch", action='store_true', help="open the webui URL in the system's default browser upon launch", default=False) +parser.add_argument("--theme", type=str, help="launches the UI with light or dark theme", default=None) +parser.add_argument("--use-textbox-seed", action='store_true', help="use textbox for seeds in UI (no up/down, but possible to input long seeds)", default=False) +parser.add_argument("--disable-console-progressbars", action='store_true', help="do not output progressbars to console", default=False) +parser.add_argument("--enable-console-prompts", action='store_true', help="print prompts to console when generating with txt2img and img2img", default=False) +parser.add_argument('--vae-path', type=str, help='Path to Variational Autoencoders model', default=None) +parser.add_argument("--disable-safe-unpickle", action='store_true', help="disable checking pytorch models for malicious code", default=False) +parser.add_argument("--api", action='store_true', help="use api=True to launch the API together with the webui (use --nowebui instead for only the API)") +parser.add_argument("--api-auth", type=str, help='Set authentication for API like "username:password"; or comma-delimit multiple like "u1:p1,u2:p2,u3:p3"', default=None) +parser.add_argument("--nowebui", action='store_true', help="use api=True to launch the API instead of the webui") +parser.add_argument("--ui-debug-mode", action='store_true', help="Don't load model to quickly launch UI") +parser.add_argument("--device-id", type=str, help="Select the default CUDA device to use (export CUDA_VISIBLE_DEVICES=0,1,etc might be needed before)", default=None) +parser.add_argument("--administrator", action='store_true', help="Administrator rights", default=False) +parser.add_argument("--cors-allow-origins", type=str, help="Allowed CORS origin(s) in the form of a comma-separated list (no spaces)", default=None) +parser.add_argument("--cors-allow-origins-regex", type=str, help="Allowed CORS origin(s) in the form of a single regular expression", default=None) +parser.add_argument("--tls-keyfile", type=str, help="Partially enables TLS, requires --tls-certfile to fully function", default=None) +parser.add_argument("--tls-certfile", type=str, help="Partially enables TLS, requires --tls-keyfile to fully function", default=None) +parser.add_argument("--server-name", type=str, help="Sets hostname of server", default=None) +parser.add_argument("--pureui", action='store_true', help="Pure UI without local inference and progress bar", default=False) +parser.add_argument("--train", action='store_true', help="Train only on SageMaker", default=False) +parser.add_argument("--train-task", type=str, help='Train task - embedding or hypernetwork', default='embedding') +parser.add_argument("--train-args", type=str, help='Train args', default='') +parser.add_argument('--embeddings-s3uri', default='', type=str, help='Embedding S3Uri') +parser.add_argument('--hypernetwork-s3uri', default='', type=str, help='Hypernetwork S3Uri') +parser.add_argument('--sd-models-s3uri', default='', type=str, help='SD Models S3Uri') +parser.add_argument('--db-models-s3uri', default='', type=str, help='DB Models S3Uri') +parser.add_argument('--region-name', type=str, help='Region Name') +parser.add_argument('--username', default='', type=str, help='Username') +parser.add_argument('--api-endpoint', default='', type=str, help='API Endpoint') +parser.add_argument('--dreambooth-config-id', default='', type=str, help='Dreambooth config ID') + +script_loading.preload_extensions(extensions.extensions_dir, parser) +script_loading.preload_extensions(extensions.extensions_builtin_dir, parser) + +cmd_opts = parser.parse_args() + +restricted_opts = { + "samples_filename_pattern", + "directories_filename_pattern", + "outdir_samples", + "outdir_txt2img_samples", + "outdir_img2img_samples", + "outdir_extras_samples", + "outdir_grids", + "outdir_txt2img_grids", + "outdir_save", +} + +cmd_opts.disable_extension_access = (cmd_opts.share or cmd_opts.listen or cmd_opts.server_name) and not cmd_opts.enable_insecure_extension_access + +devices.device, devices.device_interrogate, devices.device_gfpgan, devices.device_esrgan, devices.device_codeformer = \ + (devices.cpu if any(y in cmd_opts.use_cpu for y in [x, 'all']) else devices.get_optimal_device() for x in ['sd', 'interrogate', 'gfpgan', 'esrgan', 'codeformer']) + +device = devices.device +weight_load_location = None if cmd_opts.lowram else "cpu" + +batch_cond_uncond = cmd_opts.always_batch_cond_uncond or not (cmd_opts.lowvram or cmd_opts.medvram) +parallel_processing_allowed = not cmd_opts.lowvram and not cmd_opts.medvram +xformers_available = False +config_filename = cmd_opts.ui_settings_file + +os.makedirs(cmd_opts.hypernetwork_dir, exist_ok=True) +hypernetworks = {} +loaded_hypernetwork = None + +if cmd_opts.pureui: + api_endpoint = os.environ['api_endpoint'] + industrial_model = '' + default_options = {} + username_state = None + sagemaker_endpoint_component = None + +def reload_hypernetworks(): + from modules.hypernetworks import hypernetwork + global hypernetworks + + hypernetworks = hypernetwork.list_hypernetworks(cmd_opts.hypernetwork_dir) + hypernetwork.load_hypernetwork(opts.sd_hypernetwork) + +class State: + skipped = False + interrupted = False + job = "" + job_no = 0 + job_count = 0 + job_timestamp = '0' + sampling_step = 0 + sampling_steps = 0 + current_latent = None + current_image = None + current_image_sampling_step = 0 + textinfo = None + time_start = None + need_restart = False + + def skip(self): + self.skipped = True + + def interrupt(self): + self.interrupted = True + + def nextjob(self): + if opts.show_progress_every_n_steps == -1: + self.do_set_current_image() + + self.job_no += 1 + self.sampling_step = 0 + self.current_image_sampling_step = 0 + + def dict(self): + obj = { + "skipped": self.skipped, + "interrupted": self.skipped, + "job": self.job, + "job_count": self.job_count, + "job_no": self.job_no, + "sampling_step": self.sampling_step, + "sampling_steps": self.sampling_steps, + } + + return obj + + def begin(self): + self.sampling_step = 0 + self.job_count = -1 + self.job_no = 0 + self.job_timestamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S") + self.current_latent = None + self.current_image = None + self.current_image_sampling_step = 0 + self.skipped = False + self.interrupted = False + self.textinfo = None + self.time_start = time.time() + + devices.torch_gc() + + def end(self): + self.job = "" + self.job_count = 0 + + devices.torch_gc() + + """sets self.current_image from self.current_latent if enough sampling steps have been made after the last call to this""" + def set_current_image(self): + if self.sampling_step - self.current_image_sampling_step >= opts.show_progress_every_n_steps and opts.show_progress_every_n_steps > 0: + self.do_set_current_image() + + def do_set_current_image(self): + if not parallel_processing_allowed: + return + if self.current_latent is None: + return + + import modules.sd_samplers + if opts.show_progress_grid: + self.current_image = modules.sd_samplers.samples_to_image_grid(self.current_latent) + else: + self.current_image = modules.sd_samplers.sample_to_image(self.current_latent) + + self.current_image_sampling_step = self.sampling_step + +state = State() + +artist_db = modules.artists.ArtistsDatabase(os.path.join(script_path, 'artists.csv')) + +styles_filename = cmd_opts.styles_file +prompt_styles = modules.styles.StyleDatabase(styles_filename) + +interrogator = modules.interrogate.InterrogateModels("interrogate") + +face_restorers = [] + + +def realesrgan_models_names(): + import modules.realesrgan_model + return [x.name for x in modules.realesrgan_model.get_realesrgan_models(None)] + + +class OptionInfo: + def __init__(self, default=None, label="", component=None, component_args=None, onchange=None, section=None, refresh=None): + self.default = default + self.label = label + self.component = component + self.component_args = component_args + self.onchange = onchange + self.section = section + self.refresh = refresh + + +def options_section(section_identifier, options_dict): + for k, v in options_dict.items(): + v.section = section_identifier + + return options_dict + + +def list_checkpoint_tiles(): + import modules.sd_models + return modules.sd_models.checkpoint_tiles() + + +def refresh_checkpoints(sagemaker_endpoint=None): + print('sagemaker_endpoint2:', sagemaker_endpoint) + + import modules.sd_models + return modules.sd_models.list_models(sagemaker_endpoint) + + +def list_samplers(): + import modules.sd_samplers + return modules.sd_samplers.all_samplers + + +hide_dirs = {"visible": not cmd_opts.hide_ui_dir_config} + +options_templates = {} + +sagemaker_endpoints = [] + +def list_sagemaker_endpoints(): + global sagemaker_endpoints + + return sagemaker_endpoints + +def refresh_sagemaker_endpoints(username=None): + global industrial_model, api_endpoint, sagemaker_endpoints + + print('username2:', username) + + sagemaker_endpoints = [] + + if not username: + return sagemaker_endpoints + + if industrial_model == '': + response = requests.get(url=f'{api_endpoint}/sd/industrialmodel') + if response.status_code == 200: + industrial_model = response.text + else: + model_name = 'stable-diffusion-webui' + model_description = model_name + inputs = { + 'model_algorithm': 'stable-diffusion-webui', + 'model_name': model_name, + 'model_description': model_description, + 'model_extra': '{"visible": "false"}', + 'model_samples': '', + 'file_content': { + 'data': [(lambda x: int(x))(x) for x in open(os.path.join(script_path, 'logo.ico'), 'rb').read()] + } + } + + response = requests.post(url=f'{api_endpoint}/industrialmodel', json = inputs) + if response.status_code == 200: + body = json.loads(response.text) + industrial_model = body['id'] + + if industrial_model != '': + params = { + 'industrial_model': industrial_model + } + response = requests.get(url=f'{api_endpoint}/endpoint', params=params) + if response.status_code == 200: + for endpoint_item in json.loads(response.text): + sagemaker_endpoints.append(endpoint_item['EndpointName']) + + return sagemaker_endpoints + +options_templates.update(options_section(('sd', "Stable Diffusion"), { + "sagemaker_endpoint": OptionInfo(None, "SaegMaker endpoint", gr.Dropdown, lambda: {"choices": list_sagemaker_endpoints()}, refresh=refresh_sagemaker_endpoints), + "sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": list_checkpoint_tiles()}, refresh=refresh_checkpoints), + "sd_checkpoint_cache": OptionInfo(0, "Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}), + "sd_vae": OptionInfo("auto", "SD VAE", gr.Dropdown, lambda: {"choices": sd_vae.vae_list}, refresh=sd_vae.refresh_vae_list), + "sd_vae_as_default": OptionInfo(False, "Ignore selected VAE for stable diffusion checkpoints that have their own .vae.pt next to them"), + "sd_hypernetwork": OptionInfo("None", "Hypernetwork", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in hypernetworks.keys()]}, refresh=reload_hypernetworks), + "sd_hypernetwork_strength": OptionInfo(1.0, "Hypernetwork strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.001}), + "inpainting_mask_weight": OptionInfo(1.0, "Inpainting conditioning mask strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), + "img2img_color_correction": OptionInfo(False, "Apply color correction to img2img results to match original colors."), + "img2img_fix_steps": OptionInfo(False, "With img2img, do exactly the amount of steps the slider specifies (normally you'd do less with less denoising)."), + "enable_quantization": OptionInfo(False, "Enable quantization in K samplers for sharper and cleaner results. This may change existing seeds. Requires restart to apply."), + "enable_emphasis": OptionInfo(True, "Emphasis: use (text) to make model pay more attention to text and [text] to make it pay less attention"), + "use_old_emphasis_implementation": OptionInfo(False, "Use old emphasis implementation. Can be useful to reproduce old seeds."), + "enable_batch_seeds": OptionInfo(True, "Make K-diffusion samplers produce same images in a batch as when making a single image"), + "comma_padding_backtrack": OptionInfo(20, "Increase coherency by padding from the last comma within n tokens when using more than 75 tokens", gr.Slider, {"minimum": 0, "maximum": 74, "step": 1 }), + "filter_nsfw": OptionInfo(False, "Filter NSFW content"), + 'CLIP_stop_at_last_layers': OptionInfo(1, "Clip skip", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}), + "random_artist_categories": OptionInfo([], "Allowed categories for random artists selection when using the Roll button", gr.CheckboxGroup, {"choices": artist_db.categories()}), +})) + +options_templates.update(options_section(('saving-images', "Saving images/grids"), { + "samples_save": OptionInfo(True, "Always save all generated images"), + "samples_format": OptionInfo('png', 'File format for images'), + "samples_filename_pattern": OptionInfo("", "Images filename pattern", component_args=hide_dirs), + "save_images_add_number": OptionInfo(True, "Add number to filename when saving", component_args=hide_dirs), + + "grid_save": OptionInfo(True, "Always save all generated image grids"), + "grid_format": OptionInfo('png', 'File format for grids'), + "grid_extended_filename": OptionInfo(False, "Add extended info (seed, prompt) to filename when saving grid"), + "grid_only_if_multiple": OptionInfo(True, "Do not save grids consisting of one picture"), + "grid_prevent_empty_spots": OptionInfo(False, "Prevent empty spots in grid (when set to autodetect)"), + "n_rows": OptionInfo(-1, "Grid row count; use -1 for autodetect and 0 for it to be same as batch size", gr.Slider, {"minimum": -1, "maximum": 16, "step": 1}), + + "enable_pnginfo": OptionInfo(True, "Save text information about generation parameters as chunks to png files"), + "save_txt": OptionInfo(False, "Create a text file next to every image with generation parameters."), + "save_images_before_face_restoration": OptionInfo(False, "Save a copy of image before doing face restoration."), + "save_images_before_highres_fix": OptionInfo(False, "Save a copy of image before applying highres fix."), + "save_images_before_color_correction": OptionInfo(False, "Save a copy of image before applying color correction to img2img results"), + "jpeg_quality": OptionInfo(80, "Quality for saved jpeg images", gr.Slider, {"minimum": 1, "maximum": 100, "step": 1}), + "export_for_4chan": OptionInfo(True, "If PNG image is larger than 4MB or any dimension is larger than 4000, downscale and save copy as JPG"), + + "use_original_name_batch": OptionInfo(False, "Use original name for output filename during batch process in extras tab"), + "save_selected_only": OptionInfo(True, "When using 'Save' button, only save a single selected image"), + "do_not_add_watermark": OptionInfo(False, "Do not add watermark to images"), + + "temp_dir": OptionInfo("", "Directory for temporary images; leave empty for default"), + "clean_temp_dir_at_start": OptionInfo(False, "Cleanup non-default temporary directory when starting webui"), + +})) + +options_templates.update(options_section(('saving-paths', "Paths for saving"), { + "outdir_samples": OptionInfo("", "Output directory for images; if empty, defaults to three directories below", component_args=hide_dirs), + "outdir_txt2img_samples": OptionInfo("outputs/txt2img-images", 'Output directory for txt2img images', component_args=hide_dirs), + "outdir_img2img_samples": OptionInfo("outputs/img2img-images", 'Output directory for img2img images', component_args=hide_dirs), + "outdir_extras_samples": OptionInfo("outputs/extras-images", 'Output directory for images from extras tab', component_args=hide_dirs), + "outdir_grids": OptionInfo("", "Output directory for grids; if empty, defaults to two directories below", component_args=hide_dirs), + "outdir_txt2img_grids": OptionInfo("outputs/txt2img-grids", 'Output directory for txt2img grids', component_args=hide_dirs), + "outdir_img2img_grids": OptionInfo("outputs/img2img-grids", 'Output directory for img2img grids', component_args=hide_dirs), + "outdir_save": OptionInfo("log/images", "Directory for saving images using the Save button", component_args=hide_dirs), +})) + +options_templates.update(options_section(('saving-to-dirs', "Saving to a directory"), { + "save_to_dirs": OptionInfo(False, "Save images to a subdirectory"), + "grid_save_to_dirs": OptionInfo(False, "Save grids to a subdirectory"), + "use_save_to_dirs_for_ui": OptionInfo(False, "When using \"Save\" button, save images to a subdirectory"), + "directories_filename_pattern": OptionInfo("", "Directory name pattern", component_args=hide_dirs), + "directories_max_prompt_words": OptionInfo(8, "Max prompt words for [prompt_words] pattern", gr.Slider, {"minimum": 1, "maximum": 20, "step": 1, **hide_dirs}), +})) + +options_templates.update(options_section(('upscaling', "Upscaling"), { + "ESRGAN_tile": OptionInfo(192, "Tile size for ESRGAN upscalers. 0 = no tiling.", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}), + "ESRGAN_tile_overlap": OptionInfo(8, "Tile overlap, in pixels for ESRGAN upscalers. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}), + "realesrgan_enabled_models": OptionInfo(["R-ESRGAN 4x+", "R-ESRGAN 4x+ Anime6B"], "Select which Real-ESRGAN models to show in the web UI. (Requires restart)", gr.CheckboxGroup, lambda: {"choices": realesrgan_models_names()}), + "upscaler_for_img2img": OptionInfo(None, "Upscaler for img2img", gr.Dropdown, lambda: {"choices": [x.name for x in sd_upscalers]}), + "use_scale_latent_for_hires_fix": OptionInfo(False, "Upscale latent space image when doing hires. fix"), +})) + +options_templates.update(options_section(('face-restoration', "Face restoration"), { + "face_restoration_model": OptionInfo(None, "Face restoration model", gr.Radio, lambda: {"choices": [x.name() for x in face_restorers]}), + "code_former_weight": OptionInfo(0.5, "CodeFormer weight parameter; 0 = maximum effect; 1 = minimum effect", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.01}), + "face_restoration_unload": OptionInfo(False, "Move face restoration model from VRAM into RAM after processing"), +})) + +options_templates.update(options_section(('system', "System"), { + "memmon_poll_rate": OptionInfo(8, "VRAM usage polls per second during generation. Set to 0 to disable.", gr.Slider, {"minimum": 0, "maximum": 40, "step": 1}), + "samples_log_stdout": OptionInfo(False, "Always print all generation info to standard output"), + "multiple_tqdm": OptionInfo(True, "Add a second progress bar to the console that shows progress for an entire job."), +})) + +options_templates.update(options_section(('training', "Training"), { + "unload_models_when_training": OptionInfo(False, "Move VAE and CLIP to RAM when training if possible. Saves VRAM."), + "pin_memory": OptionInfo(False, "Turn on pin_memory for DataLoader. Makes training slightly faster but can increase memory usage."), + "save_optimizer_state": OptionInfo(False, "Saves Optimizer state as separate *.optim file. Training can be resumed with HN itself and matching optim file."), + "dataset_filename_word_regex": OptionInfo("", "Filename word regex"), + "dataset_filename_join_string": OptionInfo(" ", "Filename join string"), + "training_image_repeats_per_epoch": OptionInfo(1, "Number of repeats for a single input image per epoch; used only for displaying epoch number", gr.Number, {"precision": 0}), + "training_write_csv_every": OptionInfo(500, "Save an csv containing the loss to log directory every N steps, 0 to disable"), + "training_xattention_optimizations": OptionInfo(False, "Use cross attention optimizations while training"), +})) + +options_templates.update(options_section(('interrogate', "Interrogate Options"), { + "interrogate_keep_models_in_memory": OptionInfo(False, "Interrogate: keep models in VRAM"), + "interrogate_use_builtin_artists": OptionInfo(True, "Interrogate: use artists from artists.csv"), + "interrogate_return_ranks": OptionInfo(False, "Interrogate: include ranks of model tags matches in results (Has no effect on caption-based interrogators)."), + "interrogate_clip_num_beams": OptionInfo(1, "Interrogate: num_beams for BLIP", gr.Slider, {"minimum": 1, "maximum": 16, "step": 1}), + "interrogate_clip_min_length": OptionInfo(24, "Interrogate: minimum description length (excluding artists, etc..)", gr.Slider, {"minimum": 1, "maximum": 128, "step": 1}), + "interrogate_clip_max_length": OptionInfo(48, "Interrogate: maximum description length", gr.Slider, {"minimum": 1, "maximum": 256, "step": 1}), + "interrogate_clip_dict_limit": OptionInfo(1500, "CLIP: maximum number of lines in text file (0 = No limit)"), + "interrogate_deepbooru_score_threshold": OptionInfo(0.5, "Interrogate: deepbooru score threshold", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.01}), + "deepbooru_sort_alpha": OptionInfo(True, "Interrogate: deepbooru sort alphabetically"), + "deepbooru_use_spaces": OptionInfo(False, "use spaces for tags in deepbooru"), + "deepbooru_escape": OptionInfo(True, "escape (\\) brackets in deepbooru (so they are used as literal brackets and not for emphasis)"), +})) + +options_templates.update(options_section(('ui', "User interface"), { + "show_progressbar": OptionInfo(True, "Show progressbar"), + "show_progress_every_n_steps": OptionInfo(0, "Show image creation progress every N sampling steps. Set to 0 to disable. Set to -1 to show after completion of batch.", gr.Slider, {"minimum": -1, "maximum": 32, "step": 1}), + "show_progress_grid": OptionInfo(True, "Show previews of all images generated in a batch as a grid"), + "return_grid": OptionInfo(True, "Show grid in results for web"), + "do_not_show_images": OptionInfo(False, "Do not show any images in results for web"), + "add_model_hash_to_info": OptionInfo(True, "Add model hash to generation information"), + "add_model_name_to_info": OptionInfo(False, "Add model name to generation information"), + "disable_weights_auto_swap": OptionInfo(False, "When reading generation parameters from text into UI (from PNG info or pasted text), do not change the selected model/checkpoint."), + "send_seed": OptionInfo(True, "Send seed when sending prompt or image to other interface"), + "font": OptionInfo("", "Font for image grids that have text"), + "js_modal_lightbox": OptionInfo(True, "Enable full page image viewer"), + "js_modal_lightbox_initially_zoomed": OptionInfo(True, "Show images zoomed in by default in full page image viewer"), + "show_progress_in_title": OptionInfo(True, "Show generation progress in window title."), + 'quicksettings': OptionInfo("", "Quicksettings list"), + 'localization': OptionInfo("None", "Localization (requires restart)", gr.Dropdown, lambda: {"choices": ["None"] + list(localization.localizations.keys())}, refresh=lambda: localization.list_localizations(cmd_opts.localizations_dir)), +})) + +options_templates.update(options_section(('sampler-params', "Sampler parameters"), { + "hide_samplers": OptionInfo([], "Hide samplers in user interface (requires restart)", gr.CheckboxGroup, lambda: {"choices": [x.name for x in list_samplers()]}), + "eta_ddim": OptionInfo(0.0, "eta (noise multiplier) for DDIM", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), + "eta_ancestral": OptionInfo(1.0, "eta (noise multiplier) for ancestral samplers", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), + "ddim_discretize": OptionInfo('uniform', "img2img DDIM discretize", gr.Radio, {"choices": ['uniform', 'quad']}), + 's_churn': OptionInfo(0.0, "sigma churn", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), + 's_tmin': OptionInfo(0.0, "sigma tmin", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), + 's_noise': OptionInfo(1.0, "sigma noise", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), + 'eta_noise_seed_delta': OptionInfo(0, "Eta noise seed delta", gr.Number, {"precision": 0}), +})) + +options_templates.update(options_section((None, "Hidden options"), { + "disabled_extensions": OptionInfo([], "Disable those extensions"), +})) + +options_templates.update() + + +class Options: + data = None + data_labels = options_templates + typemap = {int: float} + + def __init__(self): + self.data = {k: v.default for k, v in self.data_labels.items()} + + def __setattr__(self, key, value): + if self.data is not None: + if key in self.data or key in self.data_labels: + assert not cmd_opts.freeze_settings, "changing settings is disabled" + + info = opts.data_labels.get(key, None) + comp_args = info.component_args if info else None + if isinstance(comp_args, dict) and comp_args.get('visible', True) is False: + raise RuntimeError(f"not possible to set {key} because it is restricted") + + if cmd_opts.hide_ui_dir_config and key in restricted_opts: + raise RuntimeError(f"not possible to set {key} because it is restricted") + + self.data[key] = value + return + + return super(Options, self).__setattr__(key, value) + + def __getattr__(self, item): + if self.data is not None: + if item in self.data: + return self.data[item] + + if item in self.data_labels: + return self.data_labels[item].default + + return super(Options, self).__getattribute__(item) + + def set(self, key, value): + """sets an option and calls its onchange callback, returning True if the option changed and False otherwise""" + + oldval = self.data.get(key, None) + if oldval == value: + return False + + try: + setattr(self, key, value) + except RuntimeError: + return False + + if self.data_labels[key].onchange is not None: + self.data_labels[key].onchange() + + return True + + def save(self, filename): + assert not cmd_opts.pureui and not cmd_opts.freeze_settings, "saving settings is disabled" + + with open(filename, "w", encoding="utf8") as file: + json.dump(self.data, file, indent=4) + + def same_type(self, x, y): + if x is None or y is None: + return True + + type_x = self.typemap.get(type(x), type(x)) + type_y = self.typemap.get(type(y), type(y)) + + return type_x == type_y + + def load(self, filename): + assert not cmd_opts.pureui + + with open(filename, "r", encoding="utf8") as file: + self.data = json.load(file) + + bad_settings = 0 + for k, v in self.data.items(): + info = self.data_labels.get(k, None) + if info is not None and not self.same_type(info.default, v): + print(f"Warning: bad setting value: {k}: {v} ({type(v).__name__}; expected {type(info.default).__name__})", file=sys.stderr) + bad_settings += 1 + + if bad_settings > 0: + print(f"The program is likely to not work with bad settings.\nSettings file: {filename}\nEither fix the file, or delete it and restart.", file=sys.stderr) + + def onchange(self, key, func, call=True): + item = self.data_labels.get(key) + item.onchange = func + + if call: + func() + + def dumpjson(self): + d = {k: self.data.get(k, self.data_labels.get(k).default) for k in self.data_labels.keys()} + return json.dumps(d) + + def add_option(self, key, info): + self.data_labels[key] = info + + def reorder(self): + """reorder settings so that all items related to section always go together""" + + section_ids = {} + settings_items = self.data_labels.items() + for k, item in settings_items: + if item.section not in section_ids: + section_ids[item.section] = len(section_ids) + + self.data_labels = {k: v for k, v in sorted(settings_items, key=lambda x: section_ids[x[1].section])} + + +opts = Options() +if os.path.exists(config_filename): + opts.load(config_filename) + +if cmd_opts.pureui and opts.localization == "None": + opts.localization = "zh_CN" + +sd_upscalers = [] + +sd_model = None + +clip_model = None + +progress_print_out = sys.stdout + + +class TotalTQDM: + def __init__(self): + self._tqdm = None + + def reset(self): + self._tqdm = tqdm.tqdm( + desc="Total progress", + total=state.job_count * state.sampling_steps, + position=1, + file=progress_print_out + ) + + def update(self): + if not opts.multiple_tqdm or cmd_opts.disable_console_progressbars: + return + if self._tqdm is None: + self.reset() + self._tqdm.update() + + def updateTotal(self, new_total): + if not opts.multiple_tqdm or cmd_opts.disable_console_progressbars: + return + if self._tqdm is None: + self.reset() + self._tqdm.total=new_total + + def clear(self): + if self._tqdm is not None: + self._tqdm.close() + self._tqdm = None + + +total_tqdm = TotalTQDM() + +mem_mon = modules.memmon.MemUsageMonitor("MemMon", device, opts) +mem_mon.start() + + +def listfiles(dirname): + filenames = [os.path.join(dirname, x) for x in sorted(os.listdir(dirname)) if not x.startswith(".")] + return [file for file in filenames if os.path.isfile(file)] diff --git a/modules/ui.py b/modules/ui.py index 2ce77ed157d..205a4a81e92 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -56,6 +56,7 @@ 'ml.g4dn.12xlarge', 'ml.g4dn.16xlarge' ] + component_dict = {} # this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the browser will not show any UI mimetypes.init() @@ -501,16 +502,7 @@ def apply_setting(key, value): if oldval != value and opts.data_labels[key].onchange is not None: opts.data_labels[key].onchange() - if cmd_opts.pureui: - if shared.username != '': - inputs = { - 'action': 'edit', - 'username': shared.username, - 'options': json.dumps(opts.data) - } - - response = requests.post(url=f'{shared.api_endpoint}/sd/user', json = inputs) - else: + if not cmd_opts.pureui: opts.save(shared.config_filename) return value @@ -539,12 +531,45 @@ def refresh(): return gr.update(**(args or {})) + def refresh_sagemaker_endpoints(username): + print('username:', username) + refresh_method(username) + args = refreshed_args() if callable(refreshed_args) else refreshed_args + + for k, v in args.items(): + setattr(refresh_component, k, v) + + return gr.update(**(args or {})) + + def refresh_checkpoints(sagemaker_endpoint): + print('sagemaker_endpoint:', sagemaker_endpoint) + refresh_method(sagemaker_endpoint) + args = refreshed_args() if callable(refreshed_args) else refreshed_args + + for k, v in args.items(): + setattr(refresh_component, k, v) + + return gr.update(**(args or {})) + refresh_button = gr.Button(value=refresh_symbol, elem_id=elem_id) - refresh_button.click( - fn=refresh, - inputs=[], - outputs=[refresh_component] - ) + if elem_id == 'refresh_sagemaker_endpoint': + refresh_button.click( + fn=refresh_sagemaker_endpoints, + inputs=[shared.username_state], + outputs=[refresh_component] + ) + elif elem_id == 'refresh_sd_model_checkpoint': + refresh_button.click( + fn=refresh_checkpoints, + inputs=[shared.sagemaker_endpoint_component], + outputs=[refresh_component] + ) + else: + refresh_button.click( + fn=refresh, + inputs=[], + outputs=[refresh_component] + ) return refresh_button @@ -649,9 +674,216 @@ def create_ui(): modules.scripts.scripts_current = modules.scripts.scripts_txt2img modules.scripts.scripts_txt2img.initialize_scripts(is_img2img=False) + def create_setting_component(key, is_quicksettings=False): + def fun(): + return opts.data[key] if key in opts.data else opts.data_labels[key].default + + info = opts.data_labels[key] + t = type(info.default) + + args = info.component_args() if callable(info.component_args) else info.component_args + + if info.component is not None: + comp = info.component + elif t == str: + comp = gr.Textbox + elif t == int: + comp = gr.Number + elif t == bool: + comp = gr.Checkbox + else: + raise Exception(f'bad options item type: {str(t)} for key {key}') + + elem_id = "setting_"+key + + if info.refresh is not None: + if is_quicksettings: + res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {})) + create_refresh_button(res, info.refresh, info.component_args, "refresh_" + key) + else: + with gr.Row(variant="compact"): + res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {})) + create_refresh_button(res, info.refresh, info.component_args, "refresh_" + key) + else: + res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {})) + + if key == 'sagemaker_endpoint': + shared.sagemaker_endpoint_component = res + + return res + + components = [] + global component_dict + + script_callbacks.ui_settings_callback() + opts.reorder() + + def run_settings(*args): + assert cmd_opts.pureui + + changed = [] + + username = args[len(args) - 1] + args = args[:-1] + print('username:', username) + + if not username or username == '': + return opts.dumpjson(), f'{len(changed)} settings changed: {", ".join(changed)}.' + + for key, value, comp in zip(opts.data_labels.keys(), args, components): + assert comp == dummy_component or opts.same_type(value, opts.data_labels[key].default), f"Bad value for setting {key}: {value}; expecting {type(opts.data_labels[key].default).__name__}" + + for key, value, comp in zip(opts.data_labels.keys(), args, components): + if comp == dummy_component: + continue + + if opts.set(key, value): + changed.append(key) + + try: + inputs = { + 'action': 'edit', + 'username': username, + 'options': opts.dumpjson() + } + + response = requests.post(url=f'{shared.api_endpoint}/sd/user', json=inputs) + if response.status_code != 200: + raise RuntimeError("Settings saved failed") + except RuntimeError: + return opts.dumpjson(), f'{len(changed)} settings changed without save: {", ".join(changed)}.' + return opts.dumpjson(), f'{len(changed)} settings changed: {", ".join(changed)}.' + + def run_settings_single(value, key, username): + if username and username != '': + if not opts.same_type(value, opts.data_labels[key].default): + return gr.update(visible=True), opts.dumpjson() + + if not opts.set(key, value): + return gr.update(value=getattr(opts, key)), opts.dumpjson() + + try: + if username and username != '': + inputs = { + 'action': 'edit', + 'username': username, + 'options': opts.dumpjson() + } + + response = requests.post(url=f'{shared.api_endpoint}/sd/user', json=inputs) + if response.status_code != 200: + raise RuntimeError("Settings saved failed") + except RuntimeError: + return gr.update(visible=True), opts.dumpjson() + + return gr.update(value=value), opts.dumpjson() + else: + if not opts.same_type(value, opts.data_labels[key].default): + return gr.update(visible=True), opts.dumpjson() + + if not opts.set(key, value): + return gr.update(value=getattr(opts, key)), opts.dumpjson() + + return gr.update(value=value), opts.dumpjson() + + with gr.Blocks(analytics_enabled=False) as settings_interface: + shared.username_state = gr.Text(value='', visible=True) + dummy_component = gr.Label(visible=False) + + settings_submit = gr.Button(value="Apply settings", variant='primary') + result = gr.HTML() + + settings_cols = 3 + items_per_col = int(len(opts.data_labels) * 0.9 / settings_cols) + + quicksettings_names = [x.strip() for x in opts.quicksettings.split(",")] + quicksettings_names = set(x for x in quicksettings_names if x != 'quicksettings') + + quicksettings_list = [] + + cols_displayed = 0 + items_displayed = 0 + previous_section = None + column = None + with gr.Row(elem_id="settings").style(equal_height=False): + for i, (k, item) in enumerate(opts.data_labels.items()): + section_must_be_skipped = item.section[0] is None + + if previous_section != item.section and not section_must_be_skipped: + if cols_displayed < settings_cols and (items_displayed >= items_per_col or previous_section is None): + if column is not None: + column.__exit__() + + column = gr.Column(variant='panel') + column.__enter__() + + items_displayed = 0 + cols_displayed += 1 + + previous_section = item.section + + elem_id, text = item.section + gr.HTML(elem_id="settings_header_text_{}".format(elem_id), value='

{}

'.format(text)) + + if k in quicksettings_names and not shared.cmd_opts.freeze_settings: + quicksettings_list.append((i, k, item)) + components.append(dummy_component) + elif section_must_be_skipped: + components.append(dummy_component) + else: + component = create_setting_component(k) + component_dict[k] = component + components.append(component) + items_displayed += 1 + + with gr.Row(): + request_notifications = gr.Button(value='Request browser notifications', elem_id="request_notifications") + download_localization = gr.Button(value='Download localization template', elem_id="download_localization") + + with gr.Row(): + reload_script_bodies = gr.Button(value='Reload custom script bodies (No ui updates, No restart)', variant='secondary') + restart_gradio = gr.Button(value='Restart Gradio and Refresh components (Custom Scripts, ui.py, js and css only)', variant='primary') + + request_notifications.click( + fn=lambda: None, + inputs=[], + outputs=[], + _js='function(){}' + ) + + download_localization.click( + fn=lambda: None, + inputs=[], + outputs=[], + _js='download_localization' + ) + + def reload_scripts(): + modules.scripts.reload_script_body_only() + reload_javascript() # need to refresh the html page + + reload_script_bodies.click( + fn=reload_scripts, + inputs=[], + outputs=[] + ) + + def request_restart(): + shared.state.interrupt() + shared.state.need_restart = True + + restart_gradio.click( + fn=request_restart, + _js='restart_reload', + inputs=[], + outputs=[], + ) + + if column is not None: + column.__exit__() + with gr.Blocks(analytics_enabled=False) as txt2img_interface: txt2img_prompt, roll, txt2img_prompt_style, txt2img_negative_prompt, txt2img_prompt_style2, submit, _, _, txt2img_prompt_style_apply, txt2img_save_style, txt2img_paste, token_counter, token_button = create_toprow(is_img2img=False) - dummy_component = gr.Label(visible=False) txt_prompt_img = gr.File(label="", elem_id="txt2img_prompt_image", file_count="single", type="bytes", visible=False) txt2img_submit = submit @@ -723,7 +955,7 @@ def create_ui(): denoising_strength, firstphase_width, firstphase_height, - ] + custom_inputs, + ] + custom_inputs + [shared.username_state, shared.sagemaker_endpoint_component], outputs=[ txt2img_gallery, @@ -965,7 +1197,7 @@ def update_orig(image, state): inpainting_mask_invert, img2img_batch_input_dir if not cmd_opts.pureui else dummy_component, img2img_batch_output_dir if not cmd_opts.pureui else dummy_component, - ] + custom_inputs, + ] + custom_inputs + [shared.username_state, shared.sagemaker_endpoint_component], outputs=[ img2img_gallery, generation_info, @@ -1124,7 +1356,7 @@ def update_orig(image, state): extras_upscaler_2, extras_upscaler_2_visibility, upscale_before_face_fix, - ], + ] + [shared.username_state, shared.sagemaker_endpoint_component], outputs=[ result_images, html_info_x, @@ -1614,6 +1846,7 @@ def update_orig(image, state): create_train_hypernetwork = gr.Button(value="Train Hypernetwork", variant='primary', visible=False) def sagemaker_train_embedding( + username, new_embedding_name, initialization_text, nvpt, @@ -1697,7 +1930,7 @@ def sagemaker_train_embedding( 'train-args': json.dumps(json.dumps(train_args)), 'train-task': 'embedding', 'ckpt': '/opt/ml/input/data/models/{0}'.format(shared.sd_model.sd_model_name), - 'username': shared.username, + 'username': username, 'api-endpoint': shared.api_endpoint } @@ -1727,6 +1960,7 @@ def sagemaker_train_embedding( } def sagemaker_train_hypernetwork( + username, new_hypernetwork_name, new_hypernetwork_sizes, new_hypernetwork_layer_structure, @@ -1766,9 +2000,9 @@ def sagemaker_train_hypernetwork( hypernetwork_preview_from_txt2img, hypernetwork_training_instance_type, hypernetwork_training_instance_count, - *txt2img_preview_params + *txt2img_preview_params ): - + train_args = { 'hypernetwork_settings': { 'name': new_hypernetwork_name, @@ -1818,7 +2052,7 @@ def sagemaker_train_hypernetwork( 'train-args': json.dumps(json.dumps(train_args)), 'train-task': 'hypernetwork', 'ckpt': '/opt/ml/input/data/models/{0}'.format(shared.sd_model.sd_model_name), - 'username': shared.username, + 'username': username, 'api-endpoint': shared.api_endpoint } @@ -1850,6 +2084,7 @@ def sagemaker_train_hypernetwork( create_train_embedding.click( fn=sagemaker_train_embedding, inputs=[ + shared.username_state, new_embedding_name, initialization_text, nvpt, @@ -1893,6 +2128,7 @@ def sagemaker_train_hypernetwork( create_train_hypernetwork.click( fn=sagemaker_train_hypernetwork, inputs=[ + shared.username_state, new_hypernetwork_name, new_hypernetwork_sizes, new_hypernetwork_layer_structure, @@ -1937,182 +2173,6 @@ def sagemaker_train_hypernetwork( outputs=[hypernetwork_output] ) - def create_setting_component(key, is_quicksettings=False): - def fun(): - return opts.data[key] if key in opts.data else opts.data_labels[key].default - - info = opts.data_labels[key] - t = type(info.default) - - args = info.component_args() if callable(info.component_args) else info.component_args - - if info.component is not None: - comp = info.component - elif t == str: - comp = gr.Textbox - elif t == int: - comp = gr.Number - elif t == bool: - comp = gr.Checkbox - else: - raise Exception(f'bad options item type: {str(t)} for key {key}') - - elem_id = "setting_"+key - - if info.refresh is not None: - if is_quicksettings: - res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {})) - create_refresh_button(res, info.refresh, info.component_args, "refresh_" + key) - else: - with gr.Row(variant="compact"): - res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {})) - create_refresh_button(res, info.refresh, info.component_args, "refresh_" + key) - else: - res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {})) - - return res - - components = [] - component_dict = {} - - script_callbacks.ui_settings_callback() - opts.reorder() - - def run_settings(*args): - changed = [] - - for key, value, comp in zip(opts.data_labels.keys(), args, components): - assert comp == dummy_component or opts.same_type(value, opts.data_labels[key].default), f"Bad value for setting {key}: {value}; expecting {type(opts.data_labels[key].default).__name__}" - - for key, value, comp in zip(opts.data_labels.keys(), args, components): - if comp == dummy_component: - continue - - if opts.set(key, value): - changed.append(key) - - try: - if cmd_opts.pureui: - if shared.username != '': - inputs = { - 'action': 'edit', - 'username': shared.username, - 'options': json.dumps(opts.data) - } - - response = requests.post(url=f'{shared.api_endpoint}/sd/user', json=inputs) - if response.status_code != 200: - raise RuntimeError("Settings saved failed") - else: - opts.save(shared.config_filename) - except RuntimeError: - return opts.dumpjson(), f'{len(changed)} settings changed without save: {", ".join(changed)}.' - return opts.dumpjson(), f'{len(changed)} settings changed: {", ".join(changed)}.' - - def run_settings_single(value, key): - if not opts.same_type(value, opts.data_labels[key].default): - return gr.update(visible=True), opts.dumpjson() - - if not opts.set(key, value): - return gr.update(value=getattr(opts, key)), opts.dumpjson() - - opts.save(shared.config_filename) - - return gr.update(value=value), opts.dumpjson() - - with gr.Blocks(analytics_enabled=False) as settings_interface: - settings_submit = gr.Button(value="Apply settings", variant='primary') - result = gr.HTML() - - settings_cols = 3 - items_per_col = int(len(opts.data_labels) * 0.9 / settings_cols) - - quicksettings_names = [x.strip() for x in opts.quicksettings.split(",")] - quicksettings_names = set(x for x in quicksettings_names if x != 'quicksettings') - - quicksettings_list = [] - - cols_displayed = 0 - items_displayed = 0 - previous_section = None - column = None - with gr.Row(elem_id="settings").style(equal_height=False): - for i, (k, item) in enumerate(opts.data_labels.items()): - section_must_be_skipped = item.section[0] is None - - if previous_section != item.section and not section_must_be_skipped: - if cols_displayed < settings_cols and (items_displayed >= items_per_col or previous_section is None): - if column is not None: - column.__exit__() - - column = gr.Column(variant='panel') - column.__enter__() - - items_displayed = 0 - cols_displayed += 1 - - previous_section = item.section - - elem_id, text = item.section - gr.HTML(elem_id="settings_header_text_{}".format(elem_id), value='

{}

'.format(text)) - - if k in quicksettings_names and not shared.cmd_opts.freeze_settings: - quicksettings_list.append((i, k, item)) - components.append(dummy_component) - elif section_must_be_skipped: - components.append(dummy_component) - else: - component = create_setting_component(k) - component_dict[k] = component - components.append(component) - items_displayed += 1 - - with gr.Row(): - request_notifications = gr.Button(value='Request browser notifications', elem_id="request_notifications") - download_localization = gr.Button(value='Download localization template', elem_id="download_localization") - - with gr.Row(): - reload_script_bodies = gr.Button(value='Reload custom script bodies (No ui updates, No restart)', variant='secondary') - restart_gradio = gr.Button(value='Restart Gradio and Refresh components (Custom Scripts, ui.py, js and css only)', variant='primary') - - request_notifications.click( - fn=lambda: None, - inputs=[], - outputs=[], - _js='function(){}' - ) - - download_localization.click( - fn=lambda: None, - inputs=[], - outputs=[], - _js='download_localization' - ) - - def reload_scripts(): - modules.scripts.reload_script_body_only() - reload_javascript() # need to refresh the html page - - reload_script_bodies.click( - fn=reload_scripts, - inputs=[], - outputs=[] - ) - - def request_restart(): - shared.state.interrupt() - shared.state.need_restart = True - - restart_gradio.click( - fn=request_restart, - _js='restart_reload', - inputs=[], - outputs=[], - ) - - if column is not None: - column.__exit__() - with gr.Blocks(analytics_enabled=False) as user_interface: def change_sign_options(choice): return { @@ -2120,7 +2180,7 @@ def change_sign_options(choice): signup_column: gr.update(visible=(choice=="Sign Up")) } - with gr.Row(visible=(shared.username!='')) as user_login_row: + with gr.Row(visible=False) as user_login_row: with gr.Column(): login_username = gr.Text(label="Username") login_password = gr.Text(label="Password", type="password") @@ -2132,7 +2192,7 @@ def change_sign_options(choice): userdelete = gr.Button("Delete") login_output = gr.Label(label="Output") - with gr.Row(visible=(shared.username=='')) as user_sign_row: + with gr.Row(visible=True) as user_sign_row: with gr.Column(): sign_options = gr.Radio(["Sign In", "Sign Up"], label="Sign Options", value="Sign In", interactive=True) @@ -2160,13 +2220,14 @@ def user_signin(signin_username, signin_password): response = requests.post(url=f'{shared.api_endpoint}/sd/user', json=inputs) if response.status_code == 200: - shared.username = json.loads(response.text)['username'] + username = json.loads(response.text)['username'] password = json.loads(response.text)['password'] email = json.loads(response.text)['email'] options = json.loads(response.text)['options'] if 'options' in json.loads(response.text) else None response = { - user_login_row : gr.update(visible=True), + shared.username_state: gr.update(value=username), + user_login_row: gr.update(visible=True), user_sign_row: gr.update(visible=False), login_username: gr.update(value=signin_username), login_password: gr.update(value=password), @@ -2183,7 +2244,14 @@ def user_signin(signin_username, signin_password): opts.data = json.loads(options) for key in opts.data: if key in component_dict: - response[component_dict[key]] = gr.update(value=opts.data[key]) + if key == 'sagemaker_endpoint': + sagemaker_endpoint = opts.data[key] + response[component_dict[key]] = gr.update(value=opts.data[key], choices=shared.refresh_sagemaker_endpoints(username)) + elif key == 'sd_model_checkpoint': + shared.refresh_checkpoints(sagemaker_endpoint) + response[component_dict[key]] = gr.update(value=opts.data[key], choices=shared.list_checkpoint_tiles()) + else: + response[component_dict[key]] = gr.update(value=opts.data[key]) for key in sd_models.checkpoints_list: if sd_models.checkpoints_list[key].title == opts.data['sd_model_checkpoint']: if shared.sd_model: @@ -2216,10 +2284,11 @@ def user_signup(signup_username, signup_password, signup_email): } response = requests.post(url=f'{shared.api_endpoint}/sd/user', json=inputs) - if response.status_code == 200: - shared.username = json.loads(response.text)['username'] + if response.status_code == 200: + username = json.loads(response.text)['username'] return { + shared.username_state: gr.update(value=username), user_login_row: gr.update(visible=True), user_sign_row: gr.update(visible=False), login_username: gr.update(value=signup_username), @@ -2238,8 +2307,7 @@ def user_signup(signup_username, signup_password, signup_email): } def user_signout(): - shared.username='' - opts.data = shared.default_options + username = '' if 'sd_model_checkpoint' in opts.data: for key in sd_models.checkpoints_list: if sd_models.checkpoints_list[key].title == opts.data['sd_model_checkpoint']: @@ -2247,7 +2315,8 @@ def user_signout(): break response = { - user_login_row : gr.update(visible=False), + shared.username_state: gr.update(value=username), + user_login_row: gr.update(visible=False), user_sign_row: gr.update(visible=True), txt2img_submit: gr.update(visible=False), img2img_submit: gr.update(visible=False), @@ -2258,7 +2327,12 @@ def user_signout(): for key in opts.data: if key in component_dict: - response[component_dict[key]] = gr.update(value=opts.data[key]) + if key == 'sagemaker_endpoint': + response[component_dict[key]] = gr.update(value=opts.data[key], choices=[]) + elif key == 'sd_model_checkpoint': + response[component_dict[key]] = gr.update(value=opts.data[key], choices=[]) + else: + response[component_dict[key]] = gr.update(value=opts.data[key]) return response @@ -2272,8 +2346,6 @@ def user_update(login_username, login_password, login_email): response = requests.post(url=f'{shared.api_endpoint}/sd/user', json=inputs) if response.status_code == 200: - shared.username = json.loads(response.text)['username'] - return { login_output: gr.update(value='Update succeed') } @@ -2291,15 +2363,15 @@ def user_delete(login_username, login_password, login_email): response = requests.post(url=f'{shared.api_endpoint}/sd/user', json=inputs) if response.status_code == 200: - shared.username = json.loads(response.text)['username'] - opts.data = shared.default_options + username = '' for key in sd_models.checkpoints_list: if sd_models.checkpoints_list[key].title == opts.data['sd_model_checkpoint']: shared.sd_model.sd_model_name = sd_models.checkpoints_list[key].model_name break response = { - user_login_row : gr.update(visible=False), + shared.username_state: gr.update(value=username), + user_login_row: gr.update(visible=False), user_sign_row: gr.update(visible=True), txt2img_submit: gr.update(visible=False), img2img_submit: gr.update(visible=False), @@ -2311,8 +2383,12 @@ def user_delete(login_username, login_password, login_email): for key in opts.data: if key in component_dict: - response[component_dict[key]] = gr.update(value=opts.data[key]) - + if key == 'sagemaker_endpoint': + response[component_dict[key]] = gr.update(value=opts.data[key], choices=[]) + elif key == 'sd_model_checkpoint': + response[component_dict[key]] = gr.update(value=opts.data[key], choices=[]) + else: + response[component_dict[key]] = gr.update(value=opts.data[key]) return response else: return { @@ -2322,19 +2398,19 @@ def user_delete(login_username, login_password, login_email): signin.click( fn=user_signin, inputs=[signin_username, signin_password], - outputs=[user_login_row, user_sign_row, login_username, login_password, login_email,txt2img_submit, img2img_submit, extras_submit, create_train_embedding, create_train_hypernetwork, signin_output] + components + outputs=[shared.username_state, user_login_row, user_sign_row, login_username, login_password, login_email,txt2img_submit, img2img_submit, extras_submit, create_train_embedding, create_train_hypernetwork, signin_output] + components ) signup.click( fn=user_signup, inputs=[signup_username, signup_password, signup_email], - outputs=[user_login_row, user_sign_row, login_username, login_password, login_email, txt2img_submit, img2img_submit, extras_submit, create_train_embedding, create_train_hypernetwork, signup_output] + outputs=[shared.username_state, user_login_row, user_sign_row, login_username, login_password, login_email, txt2img_submit, img2img_submit, extras_submit, create_train_embedding, create_train_hypernetwork, signup_output] ) signout.click( fn=user_signout, inputs=[], - outputs=[user_login_row, user_sign_row, txt2img_submit, img2img_submit, extras_submit, create_train_embedding, create_train_hypernetwork] + components + outputs=[shared.username_state, user_login_row, user_sign_row, txt2img_submit, img2img_submit, extras_submit, create_train_embedding, create_train_hypernetwork] + components ) userupdate.click( @@ -2346,7 +2422,7 @@ def user_delete(login_username, login_password, login_email): userdelete.click( fn=user_delete, inputs=[login_username, login_password, login_email], - outputs=[user_login_row, user_sign_row, txt2img_submit, img2img_submit, extras_submit, create_train_embedding, create_train_hypernetwork, login_output] + components + outputs=[shared.username_state, user_login_row, user_sign_row, txt2img_submit, img2img_submit, extras_submit, create_train_embedding, create_train_hypernetwork, login_output] + components ) if cmd_opts.pureui: @@ -2410,7 +2486,7 @@ def user_delete(login_username, login_password, login_email): text_settings = gr.Textbox(elem_id="settings_json", value=lambda: opts.dumpjson(), visible=False) settings_submit.click( fn=wrap_gradio_call(run_settings, extra_outputs=[gr.update()]), - inputs=components, + inputs=components + [shared.username_state], outputs=[text_settings, result], ) @@ -2418,8 +2494,8 @@ def user_delete(login_username, login_password, login_email): component = component_dict[k] component.change( - fn=lambda value, k=k: run_settings_single(value, key=k), - inputs=[component], + fn=lambda value, k=k: run_settings_single(value, key=k, username=shared.username_state), + inputs=[component] + [shared.username_state], outputs=[component, text_settings], )