From dc48a76af3df58b314491aafab6baa36f482a35c Mon Sep 17 00:00:00 2001 From: xie river Date: Wed, 29 Mar 2023 06:01:36 +0000 Subject: [PATCH 01/31] admin add user ui optimize, add wrap line --- modules/ui.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/modules/ui.py b/modules/ui.py index 1bbc11b0e97..75ea817fa53 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1995,7 +1995,8 @@ def sagemaker_train_hypernetwork( interactive=True, visible=True, datatype=["str","str","str", "str"], - type="array" + type="array", + wrap=True, ) with gr.Row(): From 9e45f12642ce3d8aa7ad4bc9a71d67d49942a04c Mon Sep 17 00:00:00 2001 From: xie river Date: Wed, 29 Mar 2023 11:20:49 +0000 Subject: [PATCH 02/31] add training job detail link after submit train job --- localizations/zh_CN.json | 2 +- modules/ui.py | 23 +++++++++++++++++++---- 2 files changed, 20 insertions(+), 5 deletions(-) diff --git a/localizations/zh_CN.json b/localizations/zh_CN.json index 9043cffb895..2f44c978aea 100644 --- a/localizations/zh_CN.json +++ b/localizations/zh_CN.json @@ -839,6 +839,6 @@ "Amount of time to pause between Epochs (s)": "Epochs 间隔等待时间", "Save Preview(s) Frequency (Epochs)": "保存预览频率 (Epochs)", "A generic prompt used to generate a sample image to verify model fidelity.": "用于生成样本图像以验证模型保真度的通用提示。", - + "Job detail":"训练任务详情", "--------": "--------" } diff --git a/modules/ui.py b/modules/ui.py index 75ea817fa53..9258d556279 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1542,6 +1542,9 @@ def update_orig(image, state): with gr.Row(): with gr.Column(scale=3): embedding_output = gr.Label(label='Output') + ##begin add train job info by River + embedding_training_job = gr.Markdown('Job detail') + ##end add train job info by River with gr.Column(): create_train_embedding = gr.Button(value="Train Embedding", variant='primary') @@ -1628,6 +1631,9 @@ def update_orig(image, state): with gr.Row(): with gr.Column(scale=3): hypernetwork_output = gr.Label(label='Output') + ##begin add train job info by River + hypernetwork_training_job = gr.Markdown('Job detail') + ##end add train job info by River with gr.Column(): create_train_hypernetwork = gr.Button(value="Train Hypernetwork", variant='primary') @@ -1753,8 +1759,12 @@ def sagemaker_train_embedding( response = requests.post(url=f'{shared.api_endpoint}/train', json=data) if response.status_code == 200: + ##begin add train job info by River + training_job_url = response.text.replace('\"','') return { - embedding_output: gr.update(value='Submit training job sucessful') + embedding_output: gr.update(value='Submit training job sucessful'), + embedding_training_job:gr.update(value=f'Job detail:[{training_job_url}]({training_job_url})') + ##end add train job info by River } else: return { @@ -1886,8 +1896,13 @@ def sagemaker_train_hypernetwork( response = requests.post(url=f'{shared.api_endpoint}/train', json=data) if response.status_code == 200: + ##begin add train job info by River + training_job_url = response.text.replace('\"','') return { - hypernetwork_output: gr.update(value='Submit training job sucessful') + ##begin add train job info by River + hypernetwork_output: gr.update(value='Submit training job sucessful'), + hypernetwork_training_job:gr.update(value=f'Job detail:[{training_job_url}]({training_job_url})') + ##end add train job info by River } else: return { @@ -1935,7 +1950,7 @@ def sagemaker_train_hypernetwork( embedding_training_instance_count, *txt2img_preview_params ], - outputs=[embedding_output] + 
outputs=[embedding_output,embedding_training_job] ) create_train_hypernetwork.click( @@ -1983,7 +1998,7 @@ def sagemaker_train_hypernetwork( hypernetwork_training_instance_count, *txt2img_preview_params ], - outputs=[hypernetwork_output] + outputs=[hypernetwork_output,hypernetwork_training_job] ) with gr.Blocks(analytics_enabled=False) as user_interface: From 5301172159d3c51d18f9fa827a1024ecea96de28 Mon Sep 17 00:00:00 2001 From: xie river Date: Thu, 30 Mar 2023 16:04:04 +0000 Subject: [PATCH 03/31] can upload train images to s3 via webui --- localizations/zh_CN.json | 5 +++++ modules/shared.py | 1 + modules/ui.py | 24 +++++++++++++++++++++++- 3 files changed, 29 insertions(+), 1 deletion(-) diff --git a/localizations/zh_CN.json b/localizations/zh_CN.json index 2f44c978aea..09bb57b644d 100644 --- a/localizations/zh_CN.json +++ b/localizations/zh_CN.json @@ -840,5 +840,10 @@ "Save Preview(s) Frequency (Epochs)": "保存预览频率 (Epochs)", "A generic prompt used to generate a sample image to verify model fidelity.": "用于生成样本图像以验证模型保真度的通用提示。", "Job detail":"训练任务详情", + "S3 bucket name for uploading train images":"上传训练图片集的S3桶名", + "Output S3 folder":"S3文件夹目录", + "Upload Train Images to S3":"上传训练图片到S3", + "Error, please configure a S3 bucket at settings page first":"失败,请先到设置页面配置S3桶名", + "Upload":"上传", "--------": "--------" } diff --git a/modules/shared.py b/modules/shared.py index a7f056db353..a50a276430d 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -423,6 +423,7 @@ def refresh_sagemaker_endpoints(username): })) options_templates.update(options_section(('saving-paths', "Paths for saving"), { + "train_files_s3bucket":OptionInfo("","S3 bucket name for uploading train images",component_args=hide_dirs), "outdir_samples": OptionInfo("", "Output directory for images; if empty, defaults to three directories below", component_args=hide_dirs), "outdir_txt2img_samples": OptionInfo("outputs/txt2img-images", 'Output directory for txt2img images', component_args=hide_dirs), "outdir_img2img_samples": OptionInfo("outputs/img2img-images", 'Output directory for img2img images', component_args=hide_dirs), diff --git a/modules/ui.py b/modules/ui.py index 9258d556279..8dc7f6b59d8 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -11,7 +11,8 @@ import time import traceback from functools import partial, reduce - +import boto3 +import datetime import gradio as gr import gradio.routes import gradio.utils @@ -1463,6 +1464,27 @@ def update_orig(image, state): with gr.Row().style(equal_height=False): with gr.Tabs(elem_id="train_tabs"): + ## Begin add s3 images upload interface by River + s3 = boto3.client('s3') + def upload_to_s3(imgs): + username = shared.username + timestamp = datetime.datetime.now().strftime('%Y-%m-%dT%H-%M-%S') + bucket_name = opts.train_files_s3bucket + if bucket_name == '': + return 'Error, please configure a S3 bucket at settings page first' + folder_name = f"train-images/{username}/{timestamp}" + for i, img in enumerate(imgs): + filename = img.name.split('/')[-1] + object_name = f"{folder_name}/{filename}" + s3.upload_file(img.name, bucket_name, object_name) + return f"{len(imgs)} images uploaded to S3 folder: s3://{bucket_name}/{folder_name}" + + with gr.Tab(label="Upload Train Images to S3"): + upload_files = gr.Files(label="Files") + url_output = gr.Textbox(label="Output S3 folder") + sub_btn = gr.Button("Upload") + sub_btn.click(fn=upload_to_s3, inputs=upload_files, outputs=url_output) + ## End add s3 images upload interface by River with gr.Tab(label="Train Embedding"): gr.HTML(value="
Train an embedding; you must specify a directory with a set of 1:1 ratio images [wiki]
") From f222d10c8c015863b63a1b3579d31d83fe3ac0cc Mon Sep 17 00:00:00 2001 From: xie river Date: Thu, 30 Mar 2023 16:48:16 +0000 Subject: [PATCH 04/31] can upload train images to s3 via webui-timezone fix --- modules/ui.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/modules/ui.py b/modules/ui.py index 8dc7f6b59d8..0e125cc0042 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -12,7 +12,7 @@ import traceback from functools import partial, reduce import boto3 -import datetime +from datetime import datetime, timedelta, timezone import gradio as gr import gradio.routes import gradio.utils @@ -1468,7 +1468,7 @@ def update_orig(image, state): s3 = boto3.client('s3') def upload_to_s3(imgs): username = shared.username - timestamp = datetime.datetime.now().strftime('%Y-%m-%dT%H-%M-%S') + timestamp = datetime.now(timezone(timedelta(hours=+8))).strftime('%Y-%m-%dT%H:%M:%S') bucket_name = opts.train_files_s3bucket if bucket_name == '': return 'Error, please configure a S3 bucket at settings page first' From 2eaf31c0ac8df10f0b00e58559018e6c935734ef Mon Sep 17 00:00:00 2001 From: xie river Date: Fri, 31 Mar 2023 02:33:15 +0000 Subject: [PATCH 05/31] add s3 upload --- modules/ui.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/modules/ui.py b/modules/ui.py index 0e125cc0042..885e908571b 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -12,6 +12,7 @@ import traceback from functools import partial, reduce import boto3 +from botocore.exceptions import ClientError from datetime import datetime, timedelta, timezone import gradio as gr import gradio.routes @@ -1465,18 +1466,24 @@ def update_orig(image, state): with gr.Row().style(equal_height=False): with gr.Tabs(elem_id="train_tabs"): ## Begin add s3 images upload interface by River - s3 = boto3.client('s3') + s3_resource = boto3.resource('s3') def upload_to_s3(imgs): username = shared.username timestamp = datetime.now(timezone(timedelta(hours=+8))).strftime('%Y-%m-%dT%H:%M:%S') bucket_name = opts.train_files_s3bucket if bucket_name == '': return 'Error, please configure a S3 bucket at settings page first' + s3_bucket = s3_resource.Bucket(bucket_name) folder_name = f"train-images/{username}/{timestamp}" - for i, img in enumerate(imgs): - filename = img.name.split('/')[-1] - object_name = f"{folder_name}/{filename}" - s3.upload_file(img.name, bucket_name, object_name) + try: + for i, img in enumerate(imgs): + filename = img.name.split('/')[-1] + object_name = f"{folder_name}/{filename}" + s3_bucket.upload_file(img.name,object_name) + except ClientError as e: + print(e) + return e + return f"{len(imgs)} images uploaded to S3 folder: s3://{bucket_name}/{folder_name}" with gr.Tab(label="Upload Train Images to S3"): From 37bb2b13918ff02050956123a7a315562111f992 Mon Sep 17 00:00:00 2001 From: xie river Date: Sun, 2 Apr 2023 14:14:09 +0000 Subject: [PATCH 06/31] model file loaded dynamically from s3 --- localizations/zh_CN.json | 2 +- modules/api/api.py | 2 +- modules/shared.py | 2 +- modules/ui.py | 37 +++++++++-- webui.py | 129 ++++++++++++++++++++++++++++++++++++++- 5 files changed, 161 insertions(+), 11 deletions(-) diff --git a/localizations/zh_CN.json b/localizations/zh_CN.json index 09bb57b644d..5dfde0f8fb8 100644 --- a/localizations/zh_CN.json +++ b/localizations/zh_CN.json @@ -840,7 +840,7 @@ "Save Preview(s) Frequency (Epochs)": "保存预览频率 (Epochs)", "A generic prompt used to generate a sample image to verify model fidelity.": "用于生成样本图像以验证模型保真度的通用提示。", "Job detail":"训练任务详情", - "S3 bucket name for 
uploading train images":"上传训练图片集的S3桶名", + "S3 bucket name for uploading/downloading images":"上传训练图片集或者下载生成图片的S3桶名", "Output S3 folder":"S3文件夹目录", "Upload Train Images to S3":"上传训练图片到S3", "Error, please configure a S3 bucket at settings page first":"失败,请先到设置页面配置S3桶名", diff --git a/modules/api/api.py b/modules/api/api.py index b442f8bce4a..a065d314d6e 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -457,7 +457,7 @@ def invocations(self, req: InvocationsRequest): traceback.print_exc() def ping(self): - print('-------ping------') + # print('-------ping------') return {'status': 'Healthy'} def launch(self, server_name, port): diff --git a/modules/shared.py b/modules/shared.py index a50a276430d..f9b9897a04e 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -423,7 +423,7 @@ def refresh_sagemaker_endpoints(username): })) options_templates.update(options_section(('saving-paths', "Paths for saving"), { - "train_files_s3bucket":OptionInfo("","S3 bucket name for uploading train images",component_args=hide_dirs), + "train_files_s3bucket":OptionInfo("","S3 bucket name for uploading/downloading images",component_args=hide_dirs), "outdir_samples": OptionInfo("", "Output directory for images; if empty, defaults to three directories below", component_args=hide_dirs), "outdir_txt2img_samples": OptionInfo("outputs/txt2img-images", 'Output directory for txt2img images', component_args=hide_dirs), "outdir_img2img_samples": OptionInfo("outputs/img2img-images", 'Output directory for img2img images', component_args=hide_dirs), diff --git a/modules/ui.py b/modules/ui.py index 885e908571b..9bf067fd907 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -86,6 +86,29 @@ def gr_show(visible=True): return {"visible": visible, "__type__": "update"} +## Begin output images uploaded to s3 by River +s3_resource = boto3.resource('s3') + +def save_images_to_s3(full_fillnames,timestamp): + username = shared.username + sagemaker_endpoint = shared.opts.sagemaker_endpoint + bucket_name = opts.train_files_s3bucket + if bucket_name == '': + return 'Error, please configure a S3 bucket at settings page first' + s3_bucket = s3_resource.Bucket(bucket_name) + folder_name = f"output-images/{username}/{sagemaker_endpoint}/{timestamp}" + try: + for i, fname in enumerate(full_fillnames): + filename = fname.split('/')[-1] + object_name = f"{folder_name}/{filename}" + s3_bucket.upload_file(fname,object_name) + print (f'upload file [{i}]:{filename} to s3://{bucket_name}/{object_name}') + except ClientError as e: + print(e) + return e + return f"s3://{bucket_name}/{folder_name}" +## End output images uploaded to s3 by River + sample_img2img = "assets/stable-samples/img2img/sketch-mountains-input.jpg" sample_img2img = sample_img2img if os.path.exists(sample_img2img) else None @@ -147,7 +170,7 @@ def __init__(self, d=None): os.makedirs(opts.outdir_save, exist_ok=True) - with open(os.path.join(opts.outdir_save, "log.csv"), "a", encoding="utf8", newline='') as file: + with open(os.path.join(opts.outdir_save, "log.csv"), "w", encoding="utf8", newline='') as file: at_start = file.tell() == 0 writer = csv.writer(file) if at_start: @@ -163,16 +186,19 @@ def __init__(self, d=None): break fullfn, txt_fullfn = save_image(image, path, "", seed=p.all_seeds[i], prompt=p.all_prompts[i], extension=extension, info=p.infotexts[image_index], grid=is_grid, p=p, save_to_dirs=save_to_dirs) - filename = os.path.relpath(fullfn, path) + print(f'fullfn:{fullfn},\n txt_fullfn:{txt_fullfn} \nfilename:{filename}') filenames.append(filename) 
fullfns.append(fullfn) if txt_fullfn: filenames.append(os.path.basename(txt_fullfn)) fullfns.append(txt_fullfn) - writer.writerow([data["prompt"], data["seed"], data["width"], data["height"], data["sampler_name"], data["cfg_scale"], data["steps"], filenames[0], data["negative_prompt"]]) - + + timestamp = datetime.now(timezone(timedelta(hours=+8))).strftime('%Y-%m-%dT%H:%M:%S') + logfile = os.path.join(opts.outdir_save, "log.csv") + s3folder = save_images_to_s3(fullfns,timestamp) + save_images_to_s3([logfile],timestamp) # Make Zip if do_make_zip: zip_filepath = os.path.join(path, "images.zip") @@ -184,7 +210,7 @@ def __init__(self, d=None): zip_file.writestr(filenames[i], f.read()) fullfns.insert(0, zip_filepath) - return gr.File.update(value=fullfns, visible=True), '', '', plaintext_to_html(f"Saved: {filenames[0]}") + return gr.File.update(value=fullfns, visible=True), '', '', plaintext_to_html(f"Saved: {filenames[0]}, \nS3 folder:\n{s3folder}") @@ -1466,7 +1492,6 @@ def update_orig(image, state): with gr.Row().style(equal_height=False): with gr.Tabs(elem_id="train_tabs"): ## Begin add s3 images upload interface by River - s3_resource = boto3.resource('s3') def upload_to_s3(imgs): username = shared.username timestamp = datetime.now(timezone(timedelta(hours=+8))).strftime('%Y-%m-%dT%H:%M:%S') diff --git a/webui.py b/webui.py index b42bcfb197a..3fd617579f2 100644 --- a/webui.py +++ b/webui.py @@ -36,6 +36,9 @@ from modules.shared import cmd_opts, opts import modules.hypernetworks.hypernetwork import boto3 +import threading +import time + import traceback from botocore.exceptions import ClientError import requests @@ -64,6 +67,21 @@ def initialize(): modules.scripts.load_scripts() return + ## auto reload new models from s3 add by River + sd_models_tmp_dir = "/opt/ml/code/stable-diffusion-webui/models/Stable-diffusion/" + cn_models_tmp_dir = "/opt/ml/code/stable-diffusion-webui/models/ControlNet/" + session = boto3.Session() + region_name = session.region_name + sts_client = session.client('sts') + account_id = sts_client.get_caller_identity()['Account'] + sg_defaul_bucket_name = f"sagemaker-{region_name}-{account_id}" + s3_folder_sd = "stable-diffusion-webui/models/Stable-diffusion" + s3_folder_cn = "stable-diffusion-webui/models/ControlNet" + + sync_s3_folder(sg_defaul_bucket_name,s3_folder_sd,sd_models_tmp_dir,'sd') + sync_s3_folder(sg_defaul_bucket_name,s3_folder_cn,cn_models_tmp_dir,'cn') + ## end + modelloader.cleanup_models() modules.sd_models.setup_model() codeformer.setup_model(cmd_opts.codeformer_models_path) @@ -182,6 +200,114 @@ def user_auth(username, password): return response.status_code == 200 + +def register_sd_models(sd_models_dir): + print ('---register_sd_models()----') + if 'endpoint_name' in os.environ: + items = [] + api_endpoint = os.environ['api_endpoint'] + endpoint_name = os.environ['endpoint_name'] + print(f'api_endpoint:{api_endpoint}\nendpoint_name:{endpoint_name}') + for file in os.listdir(sd_models_dir): + if os.path.isfile(os.path.join(sd_models_dir, file)) and (file.endswith('.ckpt') or file.endswith('.safetensors')): + hash = modules.sd_models.model_hash(os.path.join(sd_models_dir, file)) + item = {} + item['model_name'] = file + item['config'] = '/opt/ml/code/stable-diffusion-webui/repositories/stable-diffusion/configs/stable-diffusion/v1-inference.yaml' + item['filename'] = '/opt/ml/code/stable-diffusion-webui/models/Stable-diffusion/{0}'.format(file) + item['hash'] = hash + item['title'] = '{0} [{1}]'.format(file, hash) + item['endpoint_name'] = 
endpoint_name + items.append(item) + inputs = { + 'items': items + } + params = { + 'module': 'Stable-diffusion' + } + if api_endpoint.startswith('http://') or api_endpoint.startswith('https://'): + response = requests.post(url=f'{api_endpoint}/sd/models', json=inputs, params=params) + print(response) + +def register_cn_models(cn_models_dir): + print ('---register_cn_models()----') + if 'endpoint_name' in os.environ: + items = [] + api_endpoint = os.environ['api_endpoint'] + endpoint_name = os.environ['endpoint_name'] + print(f'api_endpoint:{api_endpoint}\nendpoint_name:{endpoint_name}') + + inputs = { + 'items': items + } + params = { + 'module': 'ControlNet' + } + for file in os.listdir(cn_models_dir): + if os.path.isfile(os.path.join(cn_models_dir, file)) and \ + (file.endswith('.pt') or file.endswith('.pth') or file.endswith('.ckpt') or file.endswith('.safetensors')): + hash = modules.sd_models.model_hash(os.path.join(cn_models_dir, file)) + item = {} + item['model_name'] = file + item['title'] = '{0} [{1}]'.format(os.path.splitext(file)[0], hash) + item['endpoint_name'] = endpoint_name + items.append(item) + + if api_endpoint.startswith('http://') or api_endpoint.startswith('https://'): + response = requests.post(url=f'{api_endpoint}/sd/models', json=inputs, params=params) + print(response) + + +def sync_s3_folder(bucket_name, s3_folder, local_folder,mode): + print(f"sync S3 bucket '{bucket_name}', folder '{s3_folder}' for new files...") + # Create tmp folders + os.makedirs(os.path.dirname(local_folder), exist_ok=True) + print(f'create dir: {os.path.dirname(local_folder)}') + # Create an S3 client + s3 = boto3.client('s3') + def sync(): + # List all objects in the S3 folder + response = s3.list_objects_v2(Bucket=bucket_name, Prefix=s3_folder) + # Check if there are any new or deleted files + s3_files = set() + for obj in response.get('Contents', []): + s3_files.add(obj['Key'].replace(s3_folder, '').lstrip('/')) + + local_files = set(os.listdir(local_folder)) + + new_files = s3_files - local_files + del_files = local_files - s3_files + + # Copy new files to local folder + for file in new_files: + s3.download_file(bucket_name, s3_folder + '/' + file, os.path.join(local_folder, file)) + print(f'download_file:from {bucket_name}/{s3_folder}/{file} to {os.path.join(local_folder, file)}') + + # Delete vanished files from local folder + for file in del_files: + os.remove(os.path.join(local_folder, file)) + print(f'remove file {os.path.join(local_folder, file)}') + # If there are changes + if len(new_files) | len(del_files): + if mode == 'sd': + register_sd_models(local_folder) + elif mode == 'cn': + register_cn_models(local_folder) + else: + print(f'unsupported mode:{mode}') + # Create a thread function to keep syncing with the S3 folder + def sync_thread(): + while True: + sync() + time.sleep(60) + # Initialize at launch + sync() + # Start the thread + thread = threading.Thread(target=sync_thread) + thread.start() + return thread + + def webui(): launch_api = cmd_opts.api initialize() @@ -218,7 +344,7 @@ def webui(): if launch_api: create_api(app) - + cmd_sd_models_path = cmd_opts.ckpt_dir sd_models_dir = os.path.join(shared.models_path, "Stable-diffusion") if cmd_sd_models_path is not None: @@ -274,7 +400,6 @@ def webui(): if api_endpoint.startswith('http://') or api_endpoint.startswith('https://'): response = requests.post(url=f'{api_endpoint}/sd/models', json=inputs, params=params) print(response) - modules.script_callbacks.app_started_callback(shared.demo, app) 
wait_on_server(shared.demo) From 4ec1dd88c5d73a3e3d0f2ae32f850c903ea9ce66 Mon Sep 17 00:00:00 2001 From: xieyongliang Date: Tue, 4 Apr 2023 22:33:50 +0800 Subject: [PATCH 07/31] add Lora support --- .../.ipynb_checkpoints/lora-checkpoint.py | 207 ++++++++++ .../Lora/extra_networks_lora.py | 26 ++ extensions-builtin/Lora/lora.py | 362 ++++++++++++++++++ extensions-builtin/Lora/preload.py | 6 + .../lora_script-checkpoint.py | 56 +++ .../Lora/scripts/lora_script.py | 56 +++ .../Lora/ui_extra_networks_lora.py | 31 ++ modules/extra_networks.py | 147 +++++++ modules/extra_networks_hypernet.py | 27 ++ modules/hypernetworks/hypernetwork.py | 19 + modules/processing.py | 12 +- modules/sd_models.py | 22 ++ modules/shared.py | 24 ++ modules/ui_extra_networks.py | 320 ++++++++++++++++ modules/ui_extra_networks_checkpoints.py | 31 ++ modules/ui_extra_networks_hypernets.py | 30 ++ .../ui_extra_networks_textual_inversion.py | 29 ++ webui.py | 22 +- 18 files changed, 1425 insertions(+), 2 deletions(-) create mode 100644 extensions-builtin/Lora/.ipynb_checkpoints/lora-checkpoint.py create mode 100644 extensions-builtin/Lora/extra_networks_lora.py create mode 100644 extensions-builtin/Lora/lora.py create mode 100644 extensions-builtin/Lora/preload.py create mode 100644 extensions-builtin/Lora/scripts/.ipynb_checkpoints/lora_script-checkpoint.py create mode 100644 extensions-builtin/Lora/scripts/lora_script.py create mode 100644 extensions-builtin/Lora/ui_extra_networks_lora.py create mode 100644 modules/extra_networks.py create mode 100644 modules/extra_networks_hypernet.py create mode 100644 modules/ui_extra_networks.py create mode 100644 modules/ui_extra_networks_checkpoints.py create mode 100644 modules/ui_extra_networks_hypernets.py create mode 100644 modules/ui_extra_networks_textual_inversion.py diff --git a/extensions-builtin/Lora/.ipynb_checkpoints/lora-checkpoint.py b/extensions-builtin/Lora/.ipynb_checkpoints/lora-checkpoint.py new file mode 100644 index 00000000000..42b9eb56aaf --- /dev/null +++ b/extensions-builtin/Lora/.ipynb_checkpoints/lora-checkpoint.py @@ -0,0 +1,207 @@ +import glob +import os +import re +import torch + +from modules import shared, devices, sd_models + +re_digits = re.compile(r"\d+") +re_unet_down_blocks = re.compile(r"lora_unet_down_blocks_(\d+)_attentions_(\d+)_(.+)") +re_unet_mid_blocks = re.compile(r"lora_unet_mid_block_attentions_(\d+)_(.+)") +re_unet_up_blocks = re.compile(r"lora_unet_up_blocks_(\d+)_attentions_(\d+)_(.+)") +re_text_block = re.compile(r"lora_te_text_model_encoder_layers_(\d+)_(.+)") + + +def convert_diffusers_name_to_compvis(key): + def match(match_list, regex): + r = re.match(regex, key) + if not r: + return False + + match_list.clear() + match_list.extend([int(x) if re.match(re_digits, x) else x for x in r.groups()]) + return True + + m = [] + + if match(m, re_unet_down_blocks): + return f"diffusion_model_input_blocks_{1 + m[0] * 3 + m[1]}_1_{m[2]}" + + if match(m, re_unet_mid_blocks): + return f"diffusion_model_middle_block_1_{m[1]}" + + if match(m, re_unet_up_blocks): + return f"diffusion_model_output_blocks_{m[0] * 3 + m[1]}_1_{m[2]}" + + if match(m, re_text_block): + return f"transformer_text_model_encoder_layers_{m[0]}_{m[1]}" + + return key + + +class LoraOnDisk: + def __init__(self, name, filename): + self.name = name + self.filename = filename + + +class LoraModule: + def __init__(self, name): + self.name = name + self.multiplier = 1.0 + self.modules = {} + self.mtime = None + + +class LoraUpDownModule: + def __init__(self): + self.up = None 
+ self.down = None + self.alpha = None + + +def assign_lora_names_to_compvis_modules(sd_model): + lora_layer_mapping = {} + + for name, module in shared.sd_model.cond_stage_model.wrapped.named_modules(): + lora_name = name.replace(".", "_") + lora_layer_mapping[lora_name] = module + module.lora_layer_name = lora_name + + for name, module in shared.sd_model.model.named_modules(): + lora_name = name.replace(".", "_") + lora_layer_mapping[lora_name] = module + module.lora_layer_name = lora_name + + sd_model.lora_layer_mapping = lora_layer_mapping + + +def load_lora(name, filename): + lora = LoraModule(name) + lora.mtime = os.path.getmtime(filename) + + sd = sd_models.read_state_dict(filename) + + keys_failed_to_match = [] + + for key_diffusers, weight in sd.items(): + fullkey = convert_diffusers_name_to_compvis(key_diffusers) + key, lora_key = fullkey.split(".", 1) + + sd_module = shared.sd_model.lora_layer_mapping.get(key, None) + if sd_module is None: + keys_failed_to_match.append(key_diffusers) + continue + + lora_module = lora.modules.get(key, None) + if lora_module is None: + lora_module = LoraUpDownModule() + lora.modules[key] = lora_module + + if lora_key == "alpha": + lora_module.alpha = weight.item() + continue + + if type(sd_module) == torch.nn.Linear: + module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False) + elif type(sd_module) == torch.nn.Conv2d: + module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (1, 1), bias=False) + else: + assert False, f'Lora layer {key_diffusers} matched a layer with unsupported type: {type(sd_module).__name__}' + + with torch.no_grad(): + module.weight.copy_(weight) + + module.to(device=devices.device, dtype=devices.dtype) + + if lora_key == "lora_up.weight": + lora_module.up = module + elif lora_key == "lora_down.weight": + lora_module.down = module + else: + assert False, f'Bad Lora layer name: {key_diffusers} - must end in lora_up.weight, lora_down.weight or alpha' + + if len(keys_failed_to_match) > 0: + print(f"Failed to match keys when loading Lora {filename}: {keys_failed_to_match}") + + return lora + + +def load_loras(names, multipliers=None): + already_loaded = {} + + for lora in loaded_loras: + if lora.name in names: + already_loaded[lora.name] = lora + + loaded_loras.clear() + + loras_on_disk = [available_loras.get(name, None) for name in names] + if any([x is None for x in loras_on_disk]): + list_available_loras() + + loras_on_disk = [available_loras.get(name, None) for name in names] + + for i, name in enumerate(names): + lora = already_loaded.get(name, None) + + lora_on_disk = loras_on_disk[i] + if lora_on_disk is not None: + if lora is None or os.path.getmtime(lora_on_disk.filename) > lora.mtime: + lora = load_lora(name, lora_on_disk.filename) + + if lora is None: + print(f"Couldn't find Lora with name {name}") + continue + + lora.multiplier = multipliers[i] if multipliers else 1.0 + loaded_loras.append(lora) + + +def lora_forward(module, input, res): + if len(loaded_loras) == 0: + return res + + lora_layer_name = getattr(module, 'lora_layer_name', None) + for lora in loaded_loras: + module = lora.modules.get(lora_layer_name, None) + if module is not None: + if shared.opts.lora_apply_to_outputs and res.shape == input.shape: + res = res + module.up(module.down(res)) * lora.multiplier * (module.alpha / module.up.weight.shape[1] if module.alpha else 1.0) + else: + res = res + module.up(module.down(input)) * lora.multiplier * (module.alpha / module.up.weight.shape[1] if module.alpha else 1.0) + + return res + + +def 
lora_Linear_forward(self, input): + return lora_forward(self, input, torch.nn.Linear_forward_before_lora(self, input)) + + +def lora_Conv2d_forward(self, input): + return lora_forward(self, input, torch.nn.Conv2d_forward_before_lora(self, input)) + + +def list_available_loras(): + available_loras.clear() + + os.makedirs(shared.cmd_opts.lora_dir, exist_ok=True) + + candidates = \ + glob.glob(os.path.join(shared.cmd_opts.lora_dir, '**/*.pt'), recursive=True) + \ + glob.glob(os.path.join(shared.cmd_opts.lora_dir, '**/*.safetensors'), recursive=True) + \ + glob.glob(os.path.join(shared.cmd_opts.lora_dir, '**/*.ckpt'), recursive=True) + + for filename in sorted(candidates): + if os.path.isdir(filename): + continue + + name = os.path.splitext(os.path.basename(filename))[0] + + available_loras[name] = LoraOnDisk(name, filename) + + +available_loras = {} +loaded_loras = [] + +list_available_loras() diff --git a/extensions-builtin/Lora/extra_networks_lora.py b/extensions-builtin/Lora/extra_networks_lora.py new file mode 100644 index 00000000000..db63a4bb819 --- /dev/null +++ b/extensions-builtin/Lora/extra_networks_lora.py @@ -0,0 +1,26 @@ +from modules import extra_networks, shared +import lora + +class ExtraNetworkLora(extra_networks.ExtraNetwork): + def __init__(self): + super().__init__('lora') + + def activate(self, p, params_list): + additional = shared.opts.sd_lora + + if additional != "" and additional in lora.available_loras and len([x for x in params_list if x.items[0] == additional]) == 0: + p.all_prompts = [x + f"" for x in p.all_prompts] + params_list.append(extra_networks.ExtraNetworkParams(items=[additional, shared.opts.extra_networks_default_multiplier])) + + names = [] + multipliers = [] + for params in params_list: + assert len(params.items) > 0 + + names.append(params.items[0]) + multipliers.append(float(params.items[1]) if len(params.items) > 1 else 1.0) + + lora.load_loras(names, multipliers) + + def deactivate(self, p): + pass diff --git a/extensions-builtin/Lora/lora.py b/extensions-builtin/Lora/lora.py new file mode 100644 index 00000000000..2c4cb80be65 --- /dev/null +++ b/extensions-builtin/Lora/lora.py @@ -0,0 +1,362 @@ +import glob +import os +import re +import torch +from typing import Union + +from modules import shared, devices, sd_models, errors + +metadata_tags_order = {"ss_sd_model_name": 1, "ss_resolution": 2, "ss_clip_skip": 3, "ss_num_train_images": 10, "ss_tag_frequency": 20} + +re_digits = re.compile(r"\d+") +re_x_proj = re.compile(r"(.*)_([qkv]_proj)$") +re_compiled = {} + +suffix_conversion = { + "attentions": {}, + "resnets": { + "conv1": "in_layers_2", + "conv2": "out_layers_3", + "time_emb_proj": "emb_layers_1", + "conv_shortcut": "skip_connection", + } +} + + +def convert_diffusers_name_to_compvis(key, is_sd2): + def match(match_list, regex_text): + regex = re_compiled.get(regex_text) + if regex is None: + regex = re.compile(regex_text) + re_compiled[regex_text] = regex + + r = re.match(regex, key) + if not r: + return False + + match_list.clear() + match_list.extend([int(x) if re.match(re_digits, x) else x for x in r.groups()]) + return True + + m = [] + + if match(m, r"lora_unet_down_blocks_(\d+)_(attentions|resnets)_(\d+)_(.+)"): + suffix = suffix_conversion.get(m[1], {}).get(m[3], m[3]) + return f"diffusion_model_input_blocks_{1 + m[0] * 3 + m[2]}_{1 if m[1] == 'attentions' else 0}_{suffix}" + + if match(m, r"lora_unet_mid_block_(attentions|resnets)_(\d+)_(.+)"): + suffix = suffix_conversion.get(m[0], {}).get(m[2], m[2]) + return 
f"diffusion_model_middle_block_{1 if m[0] == 'attentions' else m[1] * 2}_{suffix}" + + if match(m, r"lora_unet_up_blocks_(\d+)_(attentions|resnets)_(\d+)_(.+)"): + suffix = suffix_conversion.get(m[1], {}).get(m[3], m[3]) + return f"diffusion_model_output_blocks_{m[0] * 3 + m[2]}_{1 if m[1] == 'attentions' else 0}_{suffix}" + + if match(m, r"lora_unet_down_blocks_(\d+)_downsamplers_0_conv"): + return f"diffusion_model_input_blocks_{3 + m[0] * 3}_0_op" + + if match(m, r"lora_unet_up_blocks_(\d+)_upsamplers_0_conv"): + return f"diffusion_model_output_blocks_{2 + m[0] * 3}_{2 if m[0]>0 else 1}_conv" + + if match(m, r"lora_te_text_model_encoder_layers_(\d+)_(.+)"): + if is_sd2: + if 'mlp_fc1' in m[1]: + return f"model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc1', 'mlp_c_fc')}" + elif 'mlp_fc2' in m[1]: + return f"model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc2', 'mlp_c_proj')}" + else: + return f"model_transformer_resblocks_{m[0]}_{m[1].replace('self_attn', 'attn')}" + + return f"transformer_text_model_encoder_layers_{m[0]}_{m[1]}" + + return key + + +class LoraOnDisk: + def __init__(self, name, filename): + self.name = name + self.filename = filename + self.metadata = {} + + _, ext = os.path.splitext(filename) + if ext.lower() == ".safetensors": + try: + self.metadata = sd_models.read_metadata_from_safetensors(filename) + except Exception as e: + errors.display(e, f"reading lora {filename}") + + if self.metadata: + m = {} + for k, v in sorted(self.metadata.items(), key=lambda x: metadata_tags_order.get(x[0], 999)): + m[k] = v + + self.metadata = m + + self.ssmd_cover_images = self.metadata.pop('ssmd_cover_images', None) # those are cover images and they are too big to display in UI as text + + +class LoraModule: + def __init__(self, name): + self.name = name + self.multiplier = 1.0 + self.modules = {} + self.mtime = None + + +class LoraUpDownModule: + def __init__(self): + self.up = None + self.down = None + self.alpha = None + + +def assign_lora_names_to_compvis_modules(sd_model): + lora_layer_mapping = {} + + for name, module in shared.sd_model.cond_stage_model.wrapped.named_modules(): + lora_name = name.replace(".", "_") + lora_layer_mapping[lora_name] = module + module.lora_layer_name = lora_name + + for name, module in shared.sd_model.model.named_modules(): + lora_name = name.replace(".", "_") + lora_layer_mapping[lora_name] = module + module.lora_layer_name = lora_name + + sd_model.lora_layer_mapping = lora_layer_mapping + + +def load_lora(name, filename): + lora = LoraModule(name) + lora.mtime = os.path.getmtime(filename) + + sd = sd_models.read_state_dict(filename) + + keys_failed_to_match = {} + is_sd2 = 'model_transformer_resblocks' in shared.sd_model.lora_layer_mapping + + for key_diffusers, weight in sd.items(): + key_diffusers_without_lora_parts, lora_key = key_diffusers.split(".", 1) + key = convert_diffusers_name_to_compvis(key_diffusers_without_lora_parts, is_sd2) + + sd_module = shared.sd_model.lora_layer_mapping.get(key, None) + + if sd_module is None: + m = re_x_proj.match(key) + if m: + sd_module = shared.sd_model.lora_layer_mapping.get(m.group(1), None) + + if sd_module is None: + keys_failed_to_match[key_diffusers] = key + continue + + lora_module = lora.modules.get(key, None) + if lora_module is None: + lora_module = LoraUpDownModule() + lora.modules[key] = lora_module + + if lora_key == "alpha": + lora_module.alpha = weight.item() + continue + + if type(sd_module) == torch.nn.Linear: + module = torch.nn.Linear(weight.shape[1], weight.shape[0], 
bias=False) + elif type(sd_module) == torch.nn.modules.linear.NonDynamicallyQuantizableLinear: + module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False) + elif type(sd_module) == torch.nn.MultiheadAttention: + module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False) + elif type(sd_module) == torch.nn.Conv2d: + module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (1, 1), bias=False) + else: + print(f'Lora layer {key_diffusers} matched a layer with unsupported type: {type(sd_module).__name__}') + continue + assert False, f'Lora layer {key_diffusers} matched a layer with unsupported type: {type(sd_module).__name__}' + + with torch.no_grad(): + module.weight.copy_(weight) + + module.to(device=devices.cpu, dtype=devices.dtype) + + if lora_key == "lora_up.weight": + lora_module.up = module + elif lora_key == "lora_down.weight": + lora_module.down = module + else: + assert False, f'Bad Lora layer name: {key_diffusers} - must end in lora_up.weight, lora_down.weight or alpha' + + if len(keys_failed_to_match) > 0: + print(f"Failed to match keys when loading Lora {filename}: {keys_failed_to_match}") + + return lora + + +def load_loras(names, multipliers=None): + already_loaded = {} + + for lora in loaded_loras: + if lora.name in names: + already_loaded[lora.name] = lora + + loaded_loras.clear() + + loras_on_disk = [available_loras.get(name, None) for name in names] + if any([x is None for x in loras_on_disk]): + list_available_loras() + + loras_on_disk = [available_loras.get(name, None) for name in names] + + for i, name in enumerate(names): + lora = already_loaded.get(name, None) + + lora_on_disk = loras_on_disk[i] + if lora_on_disk is not None: + if lora is None or os.path.getmtime(lora_on_disk.filename) > lora.mtime: + lora = load_lora(name, lora_on_disk.filename) + + if lora is None: + print(f"Couldn't find Lora with name {name}") + continue + + lora.multiplier = multipliers[i] if multipliers else 1.0 + loaded_loras.append(lora) + + +def lora_calc_updown(lora, module, target): + with torch.no_grad(): + up = module.up.weight.to(target.device, dtype=target.dtype) + down = module.down.weight.to(target.device, dtype=target.dtype) + + if up.shape[2:] == (1, 1) and down.shape[2:] == (1, 1): + updown = (up.squeeze(2).squeeze(2) @ down.squeeze(2).squeeze(2)).unsqueeze(2).unsqueeze(3) + else: + updown = up @ down + + updown = updown * lora.multiplier * (module.alpha / module.up.weight.shape[1] if module.alpha else 1.0) + + return updown + + +def lora_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.MultiheadAttention]): + """ + Applies the currently selected set of Loras to the weights of torch layer self. + If weights already have this particular set of loras applied, does nothing. + If not, restores orginal weights from backup and alters weights according to loras. 
+ """ + + lora_layer_name = getattr(self, 'lora_layer_name', None) + if lora_layer_name is None: + return + + current_names = getattr(self, "lora_current_names", ()) + wanted_names = tuple((x.name, x.multiplier) for x in loaded_loras) + + weights_backup = getattr(self, "lora_weights_backup", None) + if weights_backup is None: + if isinstance(self, torch.nn.MultiheadAttention): + weights_backup = (self.in_proj_weight.to(devices.cpu, copy=True), self.out_proj.weight.to(devices.cpu, copy=True)) + else: + weights_backup = self.weight.to(devices.cpu, copy=True) + + self.lora_weights_backup = weights_backup + + if current_names != wanted_names: + if weights_backup is not None: + if isinstance(self, torch.nn.MultiheadAttention): + self.in_proj_weight.copy_(weights_backup[0]) + self.out_proj.weight.copy_(weights_backup[1]) + else: + self.weight.copy_(weights_backup) + + for lora in loaded_loras: + module = lora.modules.get(lora_layer_name, None) + if module is not None and hasattr(self, 'weight'): + self.weight += lora_calc_updown(lora, module, self.weight) + continue + + module_q = lora.modules.get(lora_layer_name + "_q_proj", None) + module_k = lora.modules.get(lora_layer_name + "_k_proj", None) + module_v = lora.modules.get(lora_layer_name + "_v_proj", None) + module_out = lora.modules.get(lora_layer_name + "_out_proj", None) + + if isinstance(self, torch.nn.MultiheadAttention) and module_q and module_k and module_v and module_out: + updown_q = lora_calc_updown(lora, module_q, self.in_proj_weight) + updown_k = lora_calc_updown(lora, module_k, self.in_proj_weight) + updown_v = lora_calc_updown(lora, module_v, self.in_proj_weight) + updown_qkv = torch.vstack([updown_q, updown_k, updown_v]) + + self.in_proj_weight += updown_qkv + self.out_proj.weight += lora_calc_updown(lora, module_out, self.out_proj.weight) + continue + + if module is None: + continue + + print(f'failed to calculate lora weights for layer {lora_layer_name}') + + setattr(self, "lora_current_names", wanted_names) + + +def lora_reset_cached_weight(self: Union[torch.nn.Conv2d, torch.nn.Linear]): + setattr(self, "lora_current_names", ()) + setattr(self, "lora_weights_backup", None) + + +def lora_Linear_forward(self, input): + lora_apply_weights(self) + + return torch.nn.Linear_forward_before_lora(self, input) + + +def lora_Linear_load_state_dict(self, *args, **kwargs): + lora_reset_cached_weight(self) + + return torch.nn.Linear_load_state_dict_before_lora(self, *args, **kwargs) + + +def lora_Conv2d_forward(self, input): + lora_apply_weights(self) + + return torch.nn.Conv2d_forward_before_lora(self, input) + + +def lora_Conv2d_load_state_dict(self, *args, **kwargs): + lora_reset_cached_weight(self) + + return torch.nn.Conv2d_load_state_dict_before_lora(self, *args, **kwargs) + + +def lora_MultiheadAttention_forward(self, *args, **kwargs): + lora_apply_weights(self) + + return torch.nn.MultiheadAttention_forward_before_lora(self, *args, **kwargs) + + +def lora_MultiheadAttention_load_state_dict(self, *args, **kwargs): + lora_reset_cached_weight(self) + + return torch.nn.MultiheadAttention_load_state_dict_before_lora(self, *args, **kwargs) + + +def list_available_loras(): + available_loras.clear() + + os.makedirs(shared.cmd_opts.lora_dir, exist_ok=True) + + candidates = \ + glob.glob(os.path.join(shared.cmd_opts.lora_dir, '**/*.pt'), recursive=True) + \ + glob.glob(os.path.join(shared.cmd_opts.lora_dir, '**/*.safetensors'), recursive=True) + \ + glob.glob(os.path.join(shared.cmd_opts.lora_dir, '**/*.ckpt'), recursive=True) + + for 
filename in sorted(candidates, key=str.lower): + if os.path.isdir(filename): + continue + + name = os.path.splitext(os.path.basename(filename))[0] + + available_loras[name] = LoraOnDisk(name, filename) + + +available_loras = {} +loaded_loras = [] + +list_available_loras() diff --git a/extensions-builtin/Lora/preload.py b/extensions-builtin/Lora/preload.py new file mode 100644 index 00000000000..c47d7ef4e24 --- /dev/null +++ b/extensions-builtin/Lora/preload.py @@ -0,0 +1,6 @@ +import os +from modules import paths + + +def preload(parser): + parser.add_argument("--lora-dir", type=str, help="Path to directory with Lora networks.", default=os.path.join(paths.models_path, 'Lora')) diff --git a/extensions-builtin/Lora/scripts/.ipynb_checkpoints/lora_script-checkpoint.py b/extensions-builtin/Lora/scripts/.ipynb_checkpoints/lora_script-checkpoint.py new file mode 100644 index 00000000000..302888387c2 --- /dev/null +++ b/extensions-builtin/Lora/scripts/.ipynb_checkpoints/lora_script-checkpoint.py @@ -0,0 +1,56 @@ +import torch +import gradio as gr + +import lora +import extra_networks_lora +import ui_extra_networks_lora +from modules import script_callbacks, ui_extra_networks, extra_networks, shared + + +def unload(): + torch.nn.Linear.forward = torch.nn.Linear_forward_before_lora + torch.nn.Linear._load_from_state_dict = torch.nn.Linear_load_state_dict_before_lora + torch.nn.Conv2d.forward = torch.nn.Conv2d_forward_before_lora + torch.nn.Conv2d._load_from_state_dict = torch.nn.Conv2d_load_state_dict_before_lora + torch.nn.MultiheadAttention.forward = torch.nn.MultiheadAttention_forward_before_lora + torch.nn.MultiheadAttention._load_from_state_dict = torch.nn.MultiheadAttention_load_state_dict_before_lora + + +def before_ui(): + ui_extra_networks.register_page(ui_extra_networks_lora.ExtraNetworksPageLora()) + extra_networks.register_extra_network(extra_networks_lora.ExtraNetworkLora()) + + +if not hasattr(torch.nn, 'Linear_forward_before_lora'): + torch.nn.Linear_forward_before_lora = torch.nn.Linear.forward + +if not hasattr(torch.nn, 'Linear_load_state_dict_before_lora'): + torch.nn.Linear_load_state_dict_before_lora = torch.nn.Linear._load_from_state_dict + +if not hasattr(torch.nn, 'Conv2d_forward_before_lora'): + torch.nn.Conv2d_forward_before_lora = torch.nn.Conv2d.forward + +if not hasattr(torch.nn, 'Conv2d_load_state_dict_before_lora'): + torch.nn.Conv2d_load_state_dict_before_lora = torch.nn.Conv2d._load_from_state_dict + +if not hasattr(torch.nn, 'MultiheadAttention_forward_before_lora'): + torch.nn.MultiheadAttention_forward_before_lora = torch.nn.MultiheadAttention.forward + +if not hasattr(torch.nn, 'MultiheadAttention_load_state_dict_before_lora'): + torch.nn.MultiheadAttention_load_state_dict_before_lora = torch.nn.MultiheadAttention._load_from_state_dict + +torch.nn.Linear.forward = lora.lora_Linear_forward +torch.nn.Linear._load_from_state_dict = lora.lora_Linear_load_state_dict +torch.nn.Conv2d.forward = lora.lora_Conv2d_forward +torch.nn.Conv2d._load_from_state_dict = lora.lora_Conv2d_load_state_dict +torch.nn.MultiheadAttention.forward = lora.lora_MultiheadAttention_forward +torch.nn.MultiheadAttention._load_from_state_dict = lora.lora_MultiheadAttention_load_state_dict + +script_callbacks.on_model_loaded(lora.assign_lora_names_to_compvis_modules) +script_callbacks.on_script_unloaded(unload) +script_callbacks.on_before_ui(before_ui) + + +shared.options_templates.update(shared.options_section(('extra_networks', "Extra Networks"), { + "sd_lora": shared.OptionInfo("None", "Add 
Lora to prompt", gr.Dropdown, lambda: {"choices": [""] + [x for x in lora.available_loras]}, refresh=lora.list_available_loras), +})) diff --git a/extensions-builtin/Lora/scripts/lora_script.py b/extensions-builtin/Lora/scripts/lora_script.py new file mode 100644 index 00000000000..302888387c2 --- /dev/null +++ b/extensions-builtin/Lora/scripts/lora_script.py @@ -0,0 +1,56 @@ +import torch +import gradio as gr + +import lora +import extra_networks_lora +import ui_extra_networks_lora +from modules import script_callbacks, ui_extra_networks, extra_networks, shared + + +def unload(): + torch.nn.Linear.forward = torch.nn.Linear_forward_before_lora + torch.nn.Linear._load_from_state_dict = torch.nn.Linear_load_state_dict_before_lora + torch.nn.Conv2d.forward = torch.nn.Conv2d_forward_before_lora + torch.nn.Conv2d._load_from_state_dict = torch.nn.Conv2d_load_state_dict_before_lora + torch.nn.MultiheadAttention.forward = torch.nn.MultiheadAttention_forward_before_lora + torch.nn.MultiheadAttention._load_from_state_dict = torch.nn.MultiheadAttention_load_state_dict_before_lora + + +def before_ui(): + ui_extra_networks.register_page(ui_extra_networks_lora.ExtraNetworksPageLora()) + extra_networks.register_extra_network(extra_networks_lora.ExtraNetworkLora()) + + +if not hasattr(torch.nn, 'Linear_forward_before_lora'): + torch.nn.Linear_forward_before_lora = torch.nn.Linear.forward + +if not hasattr(torch.nn, 'Linear_load_state_dict_before_lora'): + torch.nn.Linear_load_state_dict_before_lora = torch.nn.Linear._load_from_state_dict + +if not hasattr(torch.nn, 'Conv2d_forward_before_lora'): + torch.nn.Conv2d_forward_before_lora = torch.nn.Conv2d.forward + +if not hasattr(torch.nn, 'Conv2d_load_state_dict_before_lora'): + torch.nn.Conv2d_load_state_dict_before_lora = torch.nn.Conv2d._load_from_state_dict + +if not hasattr(torch.nn, 'MultiheadAttention_forward_before_lora'): + torch.nn.MultiheadAttention_forward_before_lora = torch.nn.MultiheadAttention.forward + +if not hasattr(torch.nn, 'MultiheadAttention_load_state_dict_before_lora'): + torch.nn.MultiheadAttention_load_state_dict_before_lora = torch.nn.MultiheadAttention._load_from_state_dict + +torch.nn.Linear.forward = lora.lora_Linear_forward +torch.nn.Linear._load_from_state_dict = lora.lora_Linear_load_state_dict +torch.nn.Conv2d.forward = lora.lora_Conv2d_forward +torch.nn.Conv2d._load_from_state_dict = lora.lora_Conv2d_load_state_dict +torch.nn.MultiheadAttention.forward = lora.lora_MultiheadAttention_forward +torch.nn.MultiheadAttention._load_from_state_dict = lora.lora_MultiheadAttention_load_state_dict + +script_callbacks.on_model_loaded(lora.assign_lora_names_to_compvis_modules) +script_callbacks.on_script_unloaded(unload) +script_callbacks.on_before_ui(before_ui) + + +shared.options_templates.update(shared.options_section(('extra_networks', "Extra Networks"), { + "sd_lora": shared.OptionInfo("None", "Add Lora to prompt", gr.Dropdown, lambda: {"choices": [""] + [x for x in lora.available_loras]}, refresh=lora.list_available_loras), +})) diff --git a/extensions-builtin/Lora/ui_extra_networks_lora.py b/extensions-builtin/Lora/ui_extra_networks_lora.py new file mode 100644 index 00000000000..eb0fc4634d7 --- /dev/null +++ b/extensions-builtin/Lora/ui_extra_networks_lora.py @@ -0,0 +1,31 @@ +import json +import os +import lora + +from modules import shared, ui_extra_networks + + +class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage): + def __init__(self): + super().__init__('Lora') + + def refresh(self): + 
lora.list_available_loras() + + def list_items(self): + for name, lora_on_disk in lora.available_loras.items(): + path, ext = os.path.splitext(lora_on_disk.filename) + yield { + "name": name, + "filename": path, + "preview": self.find_preview(path), + "description": self.find_description(path), + "search_term": self.search_terms_from_path(lora_on_disk.filename), + "prompt": json.dumps(f""), + "local_preview": f"{path}.{shared.opts.samples_format}", + "metadata": json.dumps(lora_on_disk.metadata, indent=4) if lora_on_disk.metadata else None, + } + + def allowed_directories_for_previews(self): + return [shared.cmd_opts.lora_dir] + diff --git a/modules/extra_networks.py b/modules/extra_networks.py new file mode 100644 index 00000000000..cf6aa790715 --- /dev/null +++ b/modules/extra_networks.py @@ -0,0 +1,147 @@ +import re +from collections import defaultdict + +from modules import errors + +extra_network_registry = {} + + +def initialize(): + extra_network_registry.clear() + + +def register_extra_network(extra_network): + extra_network_registry[extra_network.name] = extra_network + + +class ExtraNetworkParams: + def __init__(self, items=None): + self.items = items or [] + + +class ExtraNetwork: + def __init__(self, name): + self.name = name + + def activate(self, p, params_list): + """ + Called by processing on every run. Whatever the extra network is meant to do should be activated here. + Passes arguments related to this extra network in params_list. + User passes arguments by specifying this in his prompt: + + + + Where name matches the name of this ExtraNetwork object, and arg1:arg2:arg3 are any natural number of text arguments + separated by colon. + + Even if the user does not mention this ExtraNetwork in his prompt, the call will stil be made, with empty params_list - + in this case, all effects of this extra networks should be disabled. + + Can be called multiple times before deactivate() - each new call should override the previous call completely. + + For example, if this ExtraNetwork's name is 'hypernet' and user's prompt is: + + > "1girl, " + + params_list will be: + + [ + ExtraNetworkParams(items=["agm", "1.1"]), + ExtraNetworkParams(items=["ray"]) + ] + + """ + raise NotImplementedError + + def deactivate(self, p): + """ + Called at the end of processing for housekeeping. No need to do anything here. 
+ """ + + raise NotImplementedError + + +def activate(p, extra_network_data): + """call activate for extra networks in extra_network_data in specified order, then call + activate for all remaining registered networks with an empty argument list""" + + for extra_network_name, extra_network_args in extra_network_data.items(): + extra_network = extra_network_registry.get(extra_network_name, None) + if extra_network is None: + print(f"Skipping unknown extra network: {extra_network_name}") + continue + + try: + extra_network.activate(p, extra_network_args) + except Exception as e: + errors.display(e, f"activating extra network {extra_network_name} with arguments {extra_network_args}") + + for extra_network_name, extra_network in extra_network_registry.items(): + args = extra_network_data.get(extra_network_name, None) + if args is not None: + continue + + try: + extra_network.activate(p, []) + except Exception as e: + errors.display(e, f"activating extra network {extra_network_name}") + + +def deactivate(p, extra_network_data): + """call deactivate for extra networks in extra_network_data in specified order, then call + deactivate for all remaining registered networks""" + + for extra_network_name, extra_network_args in extra_network_data.items(): + extra_network = extra_network_registry.get(extra_network_name, None) + if extra_network is None: + continue + + try: + extra_network.deactivate(p) + except Exception as e: + errors.display(e, f"deactivating extra network {extra_network_name}") + + for extra_network_name, extra_network in extra_network_registry.items(): + args = extra_network_data.get(extra_network_name, None) + if args is not None: + continue + + try: + extra_network.deactivate(p) + except Exception as e: + errors.display(e, f"deactivating unmentioned extra network {extra_network_name}") + + +re_extra_net = re.compile(r"<(\w+):([^>]+)>") + + +def parse_prompt(prompt): + res = defaultdict(list) + + def found(m): + name = m.group(1) + args = m.group(2) + + res[name].append(ExtraNetworkParams(items=args.split(":"))) + + return "" + + prompt = re.sub(re_extra_net, found, prompt) + + return prompt, res + + +def parse_prompts(prompts): + res = [] + extra_data = None + + for prompt in prompts: + updated_prompt, parsed_extra_data = parse_prompt(prompt) + + if extra_data is None: + extra_data = parsed_extra_data + + res.append(updated_prompt) + + return res, extra_data + diff --git a/modules/extra_networks_hypernet.py b/modules/extra_networks_hypernet.py new file mode 100644 index 00000000000..207343daa67 --- /dev/null +++ b/modules/extra_networks_hypernet.py @@ -0,0 +1,27 @@ +from modules import extra_networks, shared, extra_networks +from modules.hypernetworks import hypernetwork + + +class ExtraNetworkHypernet(extra_networks.ExtraNetwork): + def __init__(self): + super().__init__('hypernet') + + def activate(self, p, params_list): + additional = shared.opts.sd_hypernetwork + + if additional != "" and additional in shared.hypernetworks and len([x for x in params_list if x.items[0] == additional]) == 0: + p.all_prompts = [x + f"" for x in p.all_prompts] + params_list.append(extra_networks.ExtraNetworkParams(items=[additional, shared.opts.extra_networks_default_multiplier])) + + names = [] + multipliers = [] + for params in params_list: + assert len(params.items) > 0 + + names.append(params.items[0]) + multipliers.append(float(params.items[1]) if len(params.items) > 1 else 1.0) + + hypernetwork.load_hypernetworks(names, multipliers) + + def deactivate(self, p): + pass diff --git 
a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index 9eb27ce5b64..520b53c62e5 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -292,6 +292,25 @@ def load_hypernetwork(filename): shared.loaded_hypernetwork = None +def load_hypernetworks(names, multipliers=None): + already_loaded = {} + + for hypernetwork in shared.loaded_hypernetworks: + if hypernetwork.name in names: + already_loaded[hypernetwork.name] = hypernetwork + + shared.loaded_hypernetworks.clear() + + for i, name in enumerate(names): + hypernetwork = already_loaded.get(name, None) + if hypernetwork is None: + hypernetwork = load_hypernetwork(name) + + if hypernetwork is None: + continue + + hypernetwork.set_multiplier(multipliers[i] if multipliers else 1.0) + shared.loaded_hypernetworks.append(hypernetwork) def find_closest_hypernetwork_name(search: str): if not search: diff --git a/modules/processing.py b/modules/processing.py index b8d5bc3b75e..14f47d78dfa 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -13,7 +13,7 @@ from typing import Any, Dict, List, Optional import modules.sd_hijack -from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste +from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, extra_networks from modules.sd_hijack import model_hijack from modules.shared import opts, cmd_opts, state import modules.shared as shared @@ -120,6 +120,7 @@ def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prom self.s_noise = s_noise or opts.s_noise self.override_settings = {k: v for k, v in (override_settings or {}).items() if k not in shared.restricted_opts} self.is_using_inpainting_conditioning = False + self.disable_extra_networks = False self.script_args = json.loads(script_args) if script_args != None else None @@ -530,6 +531,12 @@ def infotext(iteration=0, position_in_batch=0): if len(prompts) == 0: break + prompts, extra_network_data = extra_networks.parse_prompts(prompts) + + if not p.disable_extra_networks: + with devices.autocast(): + extra_networks.activate(p, extra_network_data) + if p.scripts is not None: p.scripts.process_batch(p, batch_number=n, prompts=prompts, seeds=seeds, subseeds=subseeds) @@ -618,6 +625,9 @@ def infotext(iteration=0, position_in_batch=0): if opts.grid_save: images.save_image(grid, p.outpath_grids, "grid", p.all_seeds[0], p.all_prompts[0], opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename, p=p, grid=True) + if not p.disable_extra_networks and extra_network_data: + extra_networks.deactivate(p, extra_network_data) + devices.torch_gc() res = Processed(p, output_images, p.all_seeds[0], infotext() + "".join(["\n\n" + x for x in comments]), subseed=p.all_subseeds[0], index_of_first_image=index_of_first_image, infotexts=infotexts) diff --git a/modules/sd_models.py b/modules/sd_models.py index 5c2d7e13d2a..0fae36298a0 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -214,6 +214,28 @@ def get_state_dict_from_checkpoint(pl_sd): return pl_sd +def read_metadata_from_safetensors(filename): + import json + + with open(filename, mode="rb") as file: + metadata_len = file.read(8) + metadata_len = int.from_bytes(metadata_len, "little") + json_start = file.read(2) + + assert metadata_len > 2 and json_start in (b'{"', b"{'"), f"{filename} is not a safetensors file" + json_data = json_start + file.read(metadata_len-2) + json_obj = 
json.loads(json_data) + + res = {} + for k, v in json_obj.get("__metadata__", {}).items(): + res[k] = v + if isinstance(v, str) and v[0:1] == '{': + try: + res[k] = json.loads(v) + except Exception as e: + pass + + return res def read_state_dict(checkpoint_file, print_global_state=False, map_location=None): _, extension = os.path.splitext(checkpoint_file) diff --git a/modules/shared.py b/modules/shared.py index a7f056db353..5dc3e912883 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -136,6 +136,7 @@ os.makedirs(cmd_opts.hypernetwork_dir, exist_ok=True) hypernetworks = {} loaded_hypernetwork = None +loaded_hypernetworks = [] if not cmd_opts.train: api_endpoint = os.environ['api_endpoint'] @@ -486,6 +487,15 @@ def refresh_sagemaker_endpoints(username): "deepbooru_escape": OptionInfo(True, "escape (\\) brackets in deepbooru (so they are used as literal brackets and not for emphasis)"), })) +options_templates.update(options_section(('extra_networks', "Extra Networks"), { + "extra_networks_default_view": OptionInfo("cards", "Default view for Extra Networks", gr.Dropdown, {"choices": ["cards", "thumbs"]}), + "extra_networks_default_multiplier": OptionInfo(1.0, "Multiplier for extra networks", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), + "extra_networks_card_width": OptionInfo(0, "Card width for Extra Networks (px)"), + "extra_networks_card_height": OptionInfo(0, "Card height for Extra Networks (px)"), + "extra_networks_add_text_separator": OptionInfo(" ", "Extra text to add before <...> when adding extra network to prompt"), + "sd_hypernetwork": OptionInfo("None", "Add hypernetwork to prompt", gr.Dropdown, lambda: {"choices": [""] + [x for x in hypernetworks.keys()]}, refresh=reload_hypernetworks), +})) + options_templates.update(options_section(('ui', "User interface"), { "show_progressbar": OptionInfo(True, "Show progressbar"), "show_progress_every_n_steps": OptionInfo(0, "Show image creation progress every N sampling steps. Set to 0 to disable. 
Set to -1 to show after completion of batch.", gr.Slider, {"minimum": -1, "maximum": 32, "step": 1}), @@ -502,6 +512,7 @@ def refresh_sagemaker_endpoints(username): "show_progress_in_title": OptionInfo(True, "Show generation progress in window title."), 'quicksettings': OptionInfo("", "Quicksettings list"), 'localization': OptionInfo("None", "Localization (requires restart)", gr.Dropdown, lambda: {"choices": ["None"] + list(localization.localizations.keys())}, refresh=lambda: localization.list_localizations(cmd_opts.localizations_dir)), + "ui_extra_networks_tab_reorder": OptionInfo("", "Extra networks tab order"), })) options_templates.update(options_section(('sampler-params', "Sampler parameters"), { @@ -689,3 +700,16 @@ def clear(self): def listfiles(dirname): filenames = [os.path.join(dirname, x) for x in sorted(os.listdir(dirname)) if not x.startswith(".")] return [file for file in filenames if os.path.isfile(file)] + +def html_path(filename): + return os.path.join(script_path, "html", filename) + + +def html(filename): + path = html_path(filename) + + if os.path.exists(path): + with open(path, encoding="utf8") as file: + return file.read() + + return "" \ No newline at end of file diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py new file mode 100644 index 00000000000..85eee60f1ed --- /dev/null +++ b/modules/ui_extra_networks.py @@ -0,0 +1,320 @@ +import glob +import os.path +import urllib.parse +from pathlib import Path +from PIL import PngImagePlugin + +from modules import shared +from modules.images import read_info_from_image +import gradio as gr +import json +import html + +from modules.generation_parameters_copypaste import image_from_url_text + +extra_pages = [] +allowed_dirs = set() + + +def register_page(page): + """registers extra networks page for the UI; recommend doing it in on_before_ui() callback for extensions""" + + extra_pages.append(page) + allowed_dirs.clear() + allowed_dirs.update(set(sum([x.allowed_directories_for_previews() for x in extra_pages], []))) + + +def fetch_file(filename: str = ""): + from starlette.responses import FileResponse + + if not any([Path(x).absolute() in Path(filename).absolute().parents for x in allowed_dirs]): + raise ValueError(f"File cannot be fetched: {filename}. Must be in one of directories registered by extra pages.") + + ext = os.path.splitext(filename)[1].lower() + if ext not in (".png", ".jpg", ".webp"): + raise ValueError(f"File cannot be fetched: {filename}. 
Only png and jpg and webp.") + + # would profit from returning 304 + return FileResponse(filename, headers={"Accept-Ranges": "bytes"}) + + +def get_metadata(page: str = "", item: str = ""): + from starlette.responses import JSONResponse + + page = next(iter([x for x in extra_pages if x.name == page]), None) + if page is None: + return JSONResponse({}) + + metadata = page.metadata.get(item) + if metadata is None: + return JSONResponse({}) + + return JSONResponse({"metadata": metadata}) + + +def add_pages_to_demo(app): + app.add_api_route("/sd_extra_networks/thumb", fetch_file, methods=["GET"]) + app.add_api_route("/sd_extra_networks/metadata", get_metadata, methods=["GET"]) + + +class ExtraNetworksPage: + def __init__(self, title): + self.title = title + self.name = title.lower() + self.card_page = shared.html("extra-networks-card.html") + self.allow_negative_prompt = False + self.metadata = {} + + def refresh(self): + pass + + def link_preview(self, filename): + return "./sd_extra_networks/thumb?filename=" + urllib.parse.quote(filename.replace('\\', '/')) + "&mtime=" + str(os.path.getmtime(filename)) + + def search_terms_from_path(self, filename, possible_directories=None): + abspath = os.path.abspath(filename) + + for parentdir in (possible_directories if possible_directories is not None else self.allowed_directories_for_previews()): + parentdir = os.path.abspath(parentdir) + if abspath.startswith(parentdir): + return abspath[len(parentdir):].replace('\\', '/') + + return "" + + def create_html(self, tabname): + view = shared.opts.extra_networks_default_view + items_html = '' + + self.metadata = {} + + subdirs = {} + for parentdir in [os.path.abspath(x) for x in self.allowed_directories_for_previews()]: + for x in glob.glob(os.path.join(parentdir, '**/*'), recursive=True): + if not os.path.isdir(x): + continue + + subdir = os.path.abspath(x)[len(parentdir):].replace("\\", "/") + while subdir.startswith("/"): + subdir = subdir[1:] + + is_empty = len(os.listdir(x)) == 0 + if not is_empty and not subdir.endswith("/"): + subdir = subdir + "/" + + subdirs[subdir] = 1 + + if subdirs: + subdirs = {"": 1, **subdirs} + + subdirs_html = "".join([f""" + +""" for subdir in subdirs]) + + for item in self.list_items(): + metadata = item.get("metadata") + if metadata: + self.metadata[item["name"]] = metadata + + items_html += self.create_html_for_item(item, tabname) + + if items_html == '': + dirs = "".join([f"
<li>{x}</li>" for x in self.allowed_directories_for_previews()])
+            items_html = shared.html("extra-networks-no-cards.html").format(dirs=dirs)
+
+        self_name_id = self.name.replace(" ", "_")
+
+        res = f"""
+<div id='{tabname}_{self_name_id}_subdirs' class='extra-network-subdirs extra-network-subdirs-{view}'>
+{subdirs_html}
+</div>
+<div id='{tabname}_{self_name_id}_cards' class='extra-network-{view}'>
+{items_html}
+</div>
    +""" + + return res + + def list_items(self): + raise NotImplementedError() + + def allowed_directories_for_previews(self): + return [] + + def create_html_for_item(self, item, tabname): + preview = item.get("preview", None) + + onclick = item.get("onclick", None) + if onclick is None: + onclick = '"' + html.escape(f"""return cardClicked({json.dumps(tabname)}, {item["prompt"]}, {"true" if self.allow_negative_prompt else "false"})""") + '"' + + height = f"height: {shared.opts.extra_networks_card_height}px;" if shared.opts.extra_networks_card_height else '' + width = f"width: {shared.opts.extra_networks_card_width}px;" if shared.opts.extra_networks_card_width else '' + background_image = f"background-image: url(\"{html.escape(preview)}\");" if preview else '' + metadata_button = "" + metadata = item.get("metadata") + if metadata: + metadata_button = f"" + + args = { + "style": f"'{height}{width}{background_image}'", + "prompt": item.get("prompt", None), + "tabname": json.dumps(tabname), + "local_preview": json.dumps(item["local_preview"]), + "name": item["name"], + "description": (item.get("description") or ""), + "card_clicked": onclick, + "save_card_preview": '"' + html.escape(f"""return saveCardPreview(event, {json.dumps(tabname)}, {json.dumps(item["local_preview"])})""") + '"', + "search_term": item.get("search_term", ""), + "metadata_button": metadata_button, + } + + return self.card_page.format(**args) + + def find_preview(self, path): + """ + Find a preview PNG for a given path (without extension) and call link_preview on it. + """ + + preview_extensions = ["png", "jpg", "webp"] + if shared.opts.samples_format not in preview_extensions: + preview_extensions.append(shared.opts.samples_format) + + potential_files = sum([[path + "." + ext, path + ".preview." + ext] for ext in preview_extensions], []) + + for file in potential_files: + if os.path.isfile(file): + return self.link_preview(file) + + return None + + def find_description(self, path): + """ + Find and read a description file for a given path (without extension). 
+ """ + for file in [f"{path}.txt", f"{path}.description.txt"]: + try: + with open(file, "r", encoding="utf-8", errors="replace") as f: + return f.read() + except OSError: + pass + return None + + +def intialize(): + extra_pages.clear() + + +class ExtraNetworksUi: + def __init__(self): + self.pages = None + self.stored_extra_pages = None + + self.button_save_preview = None + self.preview_target_filename = None + + self.tabname = None + + +def pages_in_preferred_order(pages): + tab_order = [x.lower().strip() for x in shared.opts.ui_extra_networks_tab_reorder.split(",")] + + def tab_name_score(name): + name = name.lower() + for i, possible_match in enumerate(tab_order): + if possible_match in name: + return i + + return len(pages) + + tab_scores = {page.name: (tab_name_score(page.name), original_index) for original_index, page in enumerate(pages)} + + return sorted(pages, key=lambda x: tab_scores[x.name]) + + +def create_ui(container, button, tabname): + ui = ExtraNetworksUi() + ui.pages = [] + ui.stored_extra_pages = pages_in_preferred_order(extra_pages.copy()) + ui.tabname = tabname + + with gr.Tabs(elem_id=tabname+"_extra_tabs") as tabs: + for page in ui.stored_extra_pages: + with gr.Tab(page.title): + + page_elem = gr.HTML(page.create_html(ui.tabname)) + ui.pages.append(page_elem) + + filter = gr.Textbox('', show_label=False, elem_id=tabname+"_extra_search", placeholder="Search...", visible=False) + button_refresh = gr.Button('Refresh', elem_id=tabname+"_extra_refresh") + + ui.button_save_preview = gr.Button('Save preview', elem_id=tabname+"_save_preview", visible=False) + ui.preview_target_filename = gr.Textbox('Preview save filename', elem_id=tabname+"_preview_filename", visible=False) + + def toggle_visibility(is_visible): + is_visible = not is_visible + return is_visible, gr.update(visible=is_visible), gr.update(variant=("secondary-down" if is_visible else "secondary")) + + state_visible = gr.State(value=False) + button.click(fn=toggle_visibility, inputs=[state_visible], outputs=[state_visible, container, button]) + + def refresh(): + res = [] + + for pg in ui.stored_extra_pages: + pg.refresh() + res.append(pg.create_html(ui.tabname)) + + return res + + button_refresh.click(fn=refresh, inputs=[], outputs=ui.pages) + + return ui + + +def path_is_parent(parent_path, child_path): + parent_path = os.path.abspath(parent_path) + child_path = os.path.abspath(child_path) + + return child_path.startswith(parent_path) + + +def setup_ui(ui, gallery): + def save_preview(index, images, filename): + if len(images) == 0: + print("There is no image in gallery to save as a preview.") + return [page.create_html(ui.tabname) for page in ui.stored_extra_pages] + + index = int(index) + index = 0 if index < 0 else index + index = len(images) - 1 if index >= len(images) else index + + img_info = images[index if index >= 0 else 0] + image = image_from_url_text(img_info) + geninfo, items = read_info_from_image(image) + + is_allowed = False + for extra_page in ui.stored_extra_pages: + if any([path_is_parent(x, filename) for x in extra_page.allowed_directories_for_previews()]): + is_allowed = True + break + + assert is_allowed, f'writing to {filename} is not allowed' + + if geninfo: + pnginfo_data = PngImagePlugin.PngInfo() + pnginfo_data.add_text('parameters', geninfo) + image.save(filename, pnginfo=pnginfo_data) + else: + image.save(filename) + + return [page.create_html(ui.tabname) for page in ui.stored_extra_pages] + + ui.button_save_preview.click( + fn=save_preview, + _js="function(x, y, z){return 
[selected_gallery_index(), y, z]}", + inputs=[ui.preview_target_filename, gallery, ui.preview_target_filename], + outputs=[*ui.pages] + ) + diff --git a/modules/ui_extra_networks_checkpoints.py b/modules/ui_extra_networks_checkpoints.py new file mode 100644 index 00000000000..e6d19d60d6b --- /dev/null +++ b/modules/ui_extra_networks_checkpoints.py @@ -0,0 +1,31 @@ +import html +import json +import os + +from modules import shared, ui_extra_networks, sd_models + + +class ExtraNetworksPageCheckpoints(ui_extra_networks.ExtraNetworksPage): + def __init__(self): + super().__init__('Checkpoints') + + def refresh(self): + shared.refresh_checkpoints() + + def list_items(self): + checkpoint: sd_models.CheckpointInfo + for name, checkpoint in sd_models.checkpoints_list.items(): + path, ext = os.path.splitext(checkpoint.filename) + yield { + "name": checkpoint.name_for_extra, + "filename": path, + "preview": self.find_preview(path), + "description": self.find_description(path), + "search_term": self.search_terms_from_path(checkpoint.filename) + " " + (checkpoint.sha256 or ""), + "onclick": '"' + html.escape(f"""return selectCheckpoint({json.dumps(name)})""") + '"', + "local_preview": f"{path}.{shared.opts.samples_format}", + } + + def allowed_directories_for_previews(self): + return [v for v in [shared.cmd_opts.ckpt_dir, sd_models.model_path] if v is not None] + diff --git a/modules/ui_extra_networks_hypernets.py b/modules/ui_extra_networks_hypernets.py new file mode 100644 index 00000000000..d9ad12d8bf0 --- /dev/null +++ b/modules/ui_extra_networks_hypernets.py @@ -0,0 +1,30 @@ +import json +import os + +from modules import shared, ui_extra_networks + + +class ExtraNetworksPageHypernetworks(ui_extra_networks.ExtraNetworksPage): + def __init__(self): + super().__init__('Hypernetworks') + + def refresh(self): + shared.reload_hypernetworks() + + def list_items(self): + for name, path in shared.hypernetworks.items(): + path, ext = os.path.splitext(path) + + yield { + "name": name, + "filename": path, + "preview": self.find_preview(path), + "description": self.find_description(path), + "search_term": self.search_terms_from_path(path), + "prompt": json.dumps(f""), + "local_preview": f"{path}.preview.{shared.opts.samples_format}", + } + + def allowed_directories_for_previews(self): + return [shared.cmd_opts.hypernetwork_dir] + diff --git a/modules/ui_extra_networks_textual_inversion.py b/modules/ui_extra_networks_textual_inversion.py new file mode 100644 index 00000000000..cf09fbcb9a2 --- /dev/null +++ b/modules/ui_extra_networks_textual_inversion.py @@ -0,0 +1,29 @@ +import json +import os + +from modules import ui_extra_networks, sd_hijack, shared + + +class ExtraNetworksPageTextualInversion(ui_extra_networks.ExtraNetworksPage): + def __init__(self): + super().__init__('Textual Inversion') + self.allow_negative_prompt = True + + def refresh(self): + sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True) + + def list_items(self): + for embedding in sd_hijack.model_hijack.embedding_db.word_embeddings.values(): + path, ext = os.path.splitext(embedding.filename) + yield { + "name": embedding.name, + "filename": embedding.filename, + "preview": self.find_preview(path), + "description": self.find_description(path), + "search_term": self.search_terms_from_path(embedding.filename), + "prompt": json.dumps(embedding.name), + "local_preview": f"{path}.preview.{shared.opts.samples_format}", + } + + def allowed_directories_for_previews(self): + return 
list(sd_hijack.model_hijack.embedding_db.embedding_dirs) diff --git a/webui.py b/webui.py index b42bcfb197a..f9164a5045c 100644 --- a/webui.py +++ b/webui.py @@ -11,11 +11,13 @@ from fastapi.encoders import jsonable_encoder from fastapi.exceptions import RequestValidationError from fastapi.responses import JSONResponse +from modules import extra_networks, ui_extra_networks_checkpoints +from modules import extra_networks_hypernet, ui_extra_networks_hypernets, ui_extra_networks_textual_inversion from modules.call_queue import wrap_queued_call, queue_lock from modules.paths import script_path -from modules import shared, sd_samplers, upscaler, extensions, localization, ui_tempdir +from modules import shared, sd_samplers, upscaler, extensions, localization, ui_tempdir, ui_extra_networks import modules.codeformer_model as codeformer import modules.extras import modules.face_restoration @@ -85,6 +87,14 @@ def initialize(): shared.opts.onchange("sd_hypernetwork_strength", modules.hypernetworks.hypernetwork.apply_strength) shared.opts.onchange("temp_dir", ui_tempdir.on_tmpdir_changed) + ui_extra_networks.intialize() + ui_extra_networks.register_page(ui_extra_networks_textual_inversion.ExtraNetworksPageTextualInversion()) + ui_extra_networks.register_page(ui_extra_networks_hypernets.ExtraNetworksPageHypernetworks()) + ui_extra_networks.register_page(ui_extra_networks_checkpoints.ExtraNetworksPageCheckpoints()) + + extra_networks.initialize() + extra_networks.register_extra_network(extra_networks_hypernet.ExtraNetworkHypernet()) + if cmd_opts.tls_keyfile is not None and cmd_opts.tls_keyfile is not None: try: @@ -190,6 +200,7 @@ def webui(): if shared.opts.clean_temp_dir_at_start: ui_tempdir.cleanup_tmpdr() + modules.script_callbacks.before_ui_callback() shared.demo = modules.ui.create_ui() app, local_url, share_url = shared.demo.launch( @@ -274,6 +285,7 @@ def webui(): if api_endpoint.startswith('http://') or api_endpoint.startswith('https://'): response = requests.post(url=f'{api_endpoint}/sd/models', json=inputs, params=params) print(response) + ui_extra_networks.add_pages_to_demo(app) modules.script_callbacks.app_started_callback(shared.demo, app) @@ -296,6 +308,14 @@ def webui(): modules.sd_models.list_models() print('Restarting Gradio') + ui_extra_networks.intialize() + ui_extra_networks.register_page(ui_extra_networks_textual_inversion.ExtraNetworksPageTextualInversion()) + ui_extra_networks.register_page(ui_extra_networks_hypernets.ExtraNetworksPageHypernetworks()) + ui_extra_networks.register_page(ui_extra_networks_checkpoints.ExtraNetworksPageCheckpoints()) + + extra_networks.initialize() + extra_networks.register_extra_network(extra_networks_hypernet.ExtraNetworkHypernet()) + def upload_s3files(s3uri, file_path_with_pattern): pos = s3uri.find('/', 5) bucket = s3uri[5 : pos] From 553d5539532d677bab48ab6ae691244dc354482b Mon Sep 17 00:00:00 2001 From: xieyongliang Date: Thu, 6 Apr 2023 16:53:47 +0800 Subject: [PATCH 08/31] fix for the issues with request --- modules/ui.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/modules/ui.py b/modules/ui.py index 1bbc11b0e97..200462e4b4d 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1637,6 +1637,7 @@ def update_orig(image, state): dreambooth_tab.render() def sagemaker_train_embedding( + request: gr.Request, sd_model_checkpoint, new_embedding_name, initialization_text, @@ -1673,8 +1674,7 @@ def sagemaker_train_embedding( embedding_preview_from_txt2img, embedding_training_instance_type, embedding_training_instance_count, 
- *txt2img_preview_params, - request: gr.Request + *txt2img_preview_params ): tokens = shared.demo.server_app.tokens @@ -1762,6 +1762,7 @@ def sagemaker_train_embedding( } def sagemaker_train_hypernetwork( + request: gr.Request, sd_model_checkpoint, new_hypernetwork_name, new_hypernetwork_sizes, @@ -1802,8 +1803,7 @@ def sagemaker_train_hypernetwork( hypernetwork_preview_from_txt2img, hypernetwork_training_instance_type, hypernetwork_training_instance_count, - *txt2img_preview_params, - request: gr.Request + *txt2img_preview_params ): tokens = shared.demo.server_app.tokens From 91afed17fa06c5bfc4f2114f1eaf8d0d1bdae607 Mon Sep 17 00:00:00 2001 From: xieyongliang Date: Fri, 7 Apr 2023 19:56:24 +0800 Subject: [PATCH 09/31] update for lora --- modules/script_callbacks.py | 67 ++++++- modules/sd_hijack.py | 5 +- modules/sd_models.py | 2 + .../textual_inversion/textual_inversion.py | 173 ++++++++++++------ modules/ui.py | 2 - 5 files changed, 187 insertions(+), 62 deletions(-) diff --git a/modules/script_callbacks.py b/modules/script_callbacks.py index 8e22f875564..c5d41d0e5c8 100644 --- a/modules/script_callbacks.py +++ b/modules/script_callbacks.py @@ -2,7 +2,7 @@ import traceback from collections import namedtuple import inspect -from typing import Optional +from typing import Optional, Dict, Any from fastapi import FastAPI from gradio import Blocks @@ -50,6 +50,11 @@ class UiTrainTabParams: def __init__(self, txt2img_preview_params): self.txt2img_preview_params = txt2img_preview_params +class ImageGridLoopParams: + def __init__(self, imgs, cols, rows): + self.imgs = imgs + self.cols = cols + self.rows = rows ScriptCallback = namedtuple("ScriptCallback", ["script", "callback"]) callback_map = dict( @@ -63,6 +68,10 @@ def __init__(self, txt2img_preview_params): callbacks_cfg_denoiser=[], callbacks_before_component=[], callbacks_after_component=[], + callbacks_image_grid=[], + callbacks_infotext_pasted=[], + callbacks_script_unloaded=[], + callbacks_before_ui=[], ) @@ -154,6 +163,34 @@ def after_component_callback(component, **kwargs): except Exception: report_exception(c, 'after_component_callback') +def image_grid_callback(params: ImageGridLoopParams): + for c in callback_map['callbacks_image_grid']: + try: + c.callback(params) + except Exception: + report_exception(c, 'image_grid') + + +def infotext_pasted_callback(infotext: str, params: Dict[str, Any]): + for c in callback_map['callbacks_infotext_pasted']: + try: + c.callback(infotext, params) + except Exception: + report_exception(c, 'infotext_pasted') + +def script_unloaded_callback(): + for c in reversed(callback_map['callbacks_script_unloaded']): + try: + c.callback() + except Exception: + report_exception(c, 'script_unloaded') + +def before_ui_callback(): + for c in reversed(callback_map['callbacks_before_ui']): + try: + c.callback() + except Exception: + report_exception(c, 'before_ui') def add_callback(callbacks, fun): stack = [x for x in inspect.stack() if x.filename != __file__] @@ -255,3 +292,31 @@ def on_before_component(callback): def on_after_component(callback): """register a function to be called after a component is created. See on_before_component for more.""" add_callback(callback_map['callbacks_after_component'], callback) + +def on_image_grid(callback): + """register a function to be called before making an image grid. + The callback is called with one argument: + - params: ImageGridLoopParams - parameters to be used for grid creation. Can be modified. 
+ """ + add_callback(callback_map['callbacks_image_grid'], callback) + + +def on_infotext_pasted(callback): + """register a function to be called before applying an infotext. + The callback is called with two arguments: + - infotext: str - raw infotext. + - result: Dict[str, any] - parsed infotext parameters. + """ + add_callback(callback_map['callbacks_infotext_pasted'], callback) + + +def on_script_unloaded(callback): + """register a function to be called before the script is unloaded. Any hooks/hijacks/monkeying about that + the script did should be reverted here""" + + add_callback(callback_map['callbacks_script_unloaded'], callback) + +def on_before_ui(callback): + """register a function to be called before the UI is created.""" + + add_callback(callback_map['callbacks_before_ui'], callback) diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index 95a17093da1..2d48748c4da 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -79,7 +79,10 @@ class StableDiffusionModelHijack: circular_enabled = False clip = None - embedding_db = modules.textual_inversion.textual_inversion.EmbeddingDatabase(cmd_opts.embeddings_dir) + embedding_db = modules.textual_inversion.textual_inversion.EmbeddingDatabase() + + def __init__(self): + self.embedding_db.add_embedding_dir(cmd_opts.embeddings_dir) def hijack(self, m): if type(m.cond_stage_model) == ldm.modules.encoders.modules.FrozenCLIPEmbedder: diff --git a/modules/sd_models.py b/modules/sd_models.py index 0fae36298a0..389120ca6dd 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -347,6 +347,8 @@ def load_model(checkpoint_info=None): sd_model.eval() shared.sd_model = sd_model + sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True) # Reload embeddings after model load as they may or may not fit the model + script_callbacks.model_loaded_callback(sd_model) print(f"Model loaded.") diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index e28c357ab47..8b95dbfa8e2 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -52,19 +52,42 @@ def const_hash(a): self.cached_checksum = f'{const_hash(self.vec.reshape(-1) * 100) & 0xffff:04x}' return self.cached_checksum +class DirWithTextualInversionEmbeddings: + def __init__(self, path): + self.path = path + self.mtime = None + + def has_changed(self): + if not os.path.isdir(self.path): + return False + + mt = os.path.getmtime(self.path) + if self.mtime is None or mt > self.mtime: + return True + + def update(self): + if not os.path.isdir(self.path): + return + + self.mtime = os.path.getmtime(self.path) class EmbeddingDatabase: - def __init__(self, embeddings_dir): + def __init__(self): self.ids_lookup = {} self.word_embeddings = {} - self.dir_mtime = None - self.embeddings_dir = embeddings_dir + self.skipped_embeddings = {} + self.expected_shape = -1 + self.embedding_dirs = {} - def register_embedding(self, embedding, model): + def add_embedding_dir(self, path): + self.embedding_dirs[path] = DirWithTextualInversionEmbeddings(path) + def clear_embedding_dirs(self): + self.embedding_dirs.clear() + + def register_embedding(self, embedding, model): self.word_embeddings[embedding.name] = embedding - # TODO changing between clip and open clip changes tokenization, which will cause embeddings to stop working ids = model.cond_stage_model.tokenize([embedding.name])[0] first_id = ids[0] @@ -75,70 +98,104 @@ def register_embedding(self, embedding, 
model): return embedding - def load_textual_inversion_embeddings(self): - mt = os.path.getmtime(self.embeddings_dir) - if self.dir_mtime is not None and mt <= self.dir_mtime: - return - - self.dir_mtime = mt - self.ids_lookup.clear() - self.word_embeddings.clear() + def get_expected_shape(self): + vec = shared.sd_model.cond_stage_model.encode_embedding_init_text(",", 1) + return vec.shape[1] - def process_file(path, filename): - name = os.path.splitext(filename)[0] + def load_from_file(self, path, filename): + name, ext = os.path.splitext(filename) + ext = ext.upper() - data = [] + if ext in ['.PNG', '.WEBP', '.JXL', '.AVIF']: + _, second_ext = os.path.splitext(name) + if second_ext.upper() == '.PREVIEW': + return - if os.path.splitext(filename.upper())[-1] in ['.PNG', '.WEBP', '.JXL', '.AVIF']: - embed_image = Image.open(path) - if hasattr(embed_image, 'text') and 'sd-ti-embedding' in embed_image.text: - data = embedding_from_b64(embed_image.text['sd-ti-embedding']) - name = data.get('name', name) - else: - data = extract_image_data_embed(embed_image) - name = data.get('name', name) + embed_image = Image.open(path) + if hasattr(embed_image, 'text') and 'sd-ti-embedding' in embed_image.text: + data = embedding_from_b64(embed_image.text['sd-ti-embedding']) + name = data.get('name', name) else: - data = torch.load(path, map_location="cpu") - - # textual inversion embeddings - if 'string_to_param' in data: - param_dict = data['string_to_param'] - if hasattr(param_dict, '_parameters'): - param_dict = getattr(param_dict, '_parameters') # fix for torch 1.12.1 loading saved file from torch 1.11 - assert len(param_dict) == 1, 'embedding file has multiple terms in it' - emb = next(iter(param_dict.items()))[1] - # diffuser concepts - elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor: - assert len(data.keys()) == 1, 'embedding file has multiple terms in it' - - emb = next(iter(data.values())) - if len(emb.shape) == 1: - emb = emb.unsqueeze(0) - else: - raise Exception(f"Couldn't identify {filename} as neither textual inversion embedding nor diffuser concept.") + data = extract_image_data_embed(embed_image) + name = data.get('name', name) + elif ext in ['.BIN', '.PT']: + data = torch.load(path, map_location="cpu") + elif ext in ['.SAFETENSORS']: + data = safetensors.torch.load_file(path, device="cpu") + else: + return - vec = emb.detach().to(devices.device, dtype=torch.float32) - embedding = Embedding(vec, name) - embedding.step = data.get('step', None) - embedding.sd_checkpoint = data.get('sd_checkpoint', None) - embedding.sd_checkpoint_name = data.get('sd_checkpoint_name', None) + # textual inversion embeddings + if 'string_to_param' in data: + param_dict = data['string_to_param'] + if hasattr(param_dict, '_parameters'): + param_dict = getattr(param_dict, '_parameters') # fix for torch 1.12.1 loading saved file from torch 1.11 + assert len(param_dict) == 1, 'embedding file has multiple terms in it' + emb = next(iter(param_dict.items()))[1] + # diffuser concepts + elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor: + assert len(data.keys()) == 1, 'embedding file has multiple terms in it' + + emb = next(iter(data.values())) + if len(emb.shape) == 1: + emb = emb.unsqueeze(0) + else: + raise Exception(f"Couldn't identify {filename} as neither textual inversion embedding nor diffuser concept.") + + vec = emb.detach().to(devices.device, dtype=torch.float32) + embedding = Embedding(vec, name) + embedding.step = data.get('step', None) + 
embedding.sd_checkpoint = data.get('sd_checkpoint', None) + embedding.sd_checkpoint_name = data.get('sd_checkpoint_name', None) + embedding.vectors = vec.shape[0] + embedding.shape = vec.shape[-1] + + if self.expected_shape == -1 or self.expected_shape == embedding.shape: self.register_embedding(embedding, shared.sd_model) + else: + self.skipped_embeddings[name] = embedding + + def load_from_dir(self, embdir): + if not os.path.isdir(embdir.path): + return + + for root, dirs, fns in os.walk(embdir.path): + for fn in fns: + try: + fullfn = os.path.join(root, fn) - for fn in os.listdir(self.embeddings_dir): - try: - fullfn = os.path.join(self.embeddings_dir, fn) + if os.stat(fullfn).st_size == 0: + continue - if os.stat(fullfn).st_size == 0: + self.load_from_file(fullfn, fn) + except Exception: + print(f"Error loading embedding {fn}:", file=sys.stderr) + print(traceback.format_exc(), file=sys.stderr) continue - process_file(fullfn, fn) - except Exception: - print(f"Error loading emedding {fn}:", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) - continue + def load_textual_inversion_embeddings(self, force_reload=False): + if not force_reload: + need_reload = False + for path, embdir in self.embedding_dirs.items(): + if embdir.has_changed(): + need_reload = True + break + + if not need_reload: + return + + self.ids_lookup.clear() + self.word_embeddings.clear() + self.skipped_embeddings.clear() + self.expected_shape = self.get_expected_shape() + + for path, embdir in self.embedding_dirs.items(): + self.load_from_dir(embdir) + embdir.update() - print(f"Loaded a total of {len(self.word_embeddings)} textual inversion embeddings.") - print("Embeddings:", ', '.join(self.word_embeddings.keys())) + print(f"Textual inversion embeddings loaded({len(self.word_embeddings)}): {', '.join(self.word_embeddings.keys())}") + if len(self.skipped_embeddings) > 0: + print(f"Textual inversion embeddings skipped({len(self.skipped_embeddings)}): {', '.join(self.skipped_embeddings.keys())}") def find_embedding_at_position(self, tokens, offset): token = tokens[offset] diff --git a/modules/ui.py b/modules/ui.py index 200462e4b4d..b4c57d40280 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1455,8 +1455,6 @@ def update_orig(image, state): with gr.Column(variant='panel'): submit_result = gr.Textbox(elem_id="modelmerger_result", show_label=False) - sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings() - with gr.Blocks(analytics_enabled=False) as train_interface: with gr.Row().style(equal_height=False): gr.HTML(value="

<p style='margin-bottom: 0.7em'>See <b><a href=\"https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Textual-Inversion\">wiki</a></b> for detailed explanation.</p>
    ") From 0cdb81556978d7a394bd543897302cf0d2005036 Mon Sep 17 00:00:00 2001 From: xieyongliang Date: Fri, 7 Apr 2023 20:29:20 +0800 Subject: [PATCH 10/31] update webui.py --- webui.py | 93 +++++++++++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 92 insertions(+), 1 deletion(-) diff --git a/webui.py b/webui.py index f9164a5045c..1fe18ca8318 100644 --- a/webui.py +++ b/webui.py @@ -43,13 +43,56 @@ import requests import json import uuid + +from huggingface_hub import hf_hub_download +import shutil +import glob + if not cmd_opts.api: from extensions.sd_dreambooth_extension.dreambooth.db_config import DreamboothConfig from extensions.sd_dreambooth_extension.scripts.dreambooth import start_training_from_config, create_model from extensions.sd_dreambooth_extension.scripts.dreambooth import performance_wizard, training_wizard from extensions.sd_dreambooth_extension.dreambooth.db_concept import Concept from modules import paths -import glob +elif not cmd_opts.pureui + import requests + cache = dict() + s3_client = boto3.client('s3') + s3_resource= boto3.resource('s3') + + def s3_download(s3uri, path): + pos = s3uri.find('/', 5) + bucket = s3uri[5 : pos] + key = s3uri[pos + 1 : ] + + s3_bucket = s3_resource.Bucket(bucket) + objs = list(s3_bucket.objects.filter(Prefix=key)) + + if os.path.isfile('cache'): + cache = json.load(open('cache', 'r')) + + for obj in objs: + if obj.key == key: + continue + response = s3_client.head_object( + Bucket = bucket, + Key = obj.key + ) + obj_key = 's3://{0}/{1}'.format(bucket, obj.key) + if obj_key not in cache or cache[obj_key] != response['ETag']: + filename = obj.key[obj.key.rfind('/') + 1 : ] + + s3_client.download_file(bucket, obj.key, os.path.join(path, filename)) + cache[obj_key] = response['ETag'] + + json.dump(cache, open('cache', 'w')) + + def http_download(httpuri, path): + with requests.get(httpuri, stream=True) as r: + r.raise_for_status() + with open(path, 'wb') as f: + for chunk in r.iter_content(chunk_size=8192): + f.write(chunk) if cmd_opts.server_name: server_name = cmd_opts.server_name @@ -194,6 +237,54 @@ def user_auth(username, password): def webui(): launch_api = cmd_opts.api + + if launch_api: + models_config_s3uri = os.environ.get('models_config_s3uri', None) + if models_config_s3uri: + bucket, key = get_bucket_and_key(models_config_s3uri) + s3_object = s3_client.get_object(Bucket=bucket, Key=key) + bytes = s3_object["Body"].read() + payload = bytes.decode('utf8') + huggingface_models = json.loads(payload).get('huggingface_models', None) + s3_models = json.loads(payload).get('s3_models', None) + http_models = json.loads(payload).get('http_models', None) + else: + huggingface_models = os.environ.get('huggingface_models', None) + s3_models = os.environ.get('s3_models', None) + http_models = os.environ.get('http_models', None) + + if huggingface_models: + huggingface_models = json.loads(huggingface_models) + huggingface_token = huggingface_models['token'] + os.system(f'huggingface-cli login --token {huggingface_token}') + hf_hub_models = huggingface_models['models'] + for huggingface_model in hf_hub_models: + repo_id = huggingface_model['repo_id'] + filename = huggingface_model['filename'] + name = huggingface_model['name'] + + hf_hub_download( + repo_id=repo_id, + filename=filename, + local_dir=f'/tmp/models/{name}', + cache_dir='/tmp/cache/huggingface' + ) + + if s3_models: + s3_models = json.loads(s3_models) + for s3_model in s3_models: + uri = s3_model['uri'] + name = s3_model['name'] + s3_download(uri, 
f'/tmp/models/{name}') + + if http_models: + http_models = json.loads(http_models) + for http_model in http_models: + uri = http_model['uri'] + filename = http_model['filename'] + name = http_model['name'] + http_download(uri, f'/tmp/models/{name}/{filename}') + initialize() while 1: From 98577258ede60912aea7664a41c8aaf1f5199cf2 Mon Sep 17 00:00:00 2001 From: xieyongliang Date: Fri, 7 Apr 2023 20:36:24 +0800 Subject: [PATCH 11/31] cleanup --- webui.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/webui.py b/webui.py index 1fe18ca8318..f2ead1ed2e1 100644 --- a/webui.py +++ b/webui.py @@ -54,7 +54,7 @@ from extensions.sd_dreambooth_extension.scripts.dreambooth import performance_wizard, training_wizard from extensions.sd_dreambooth_extension.dreambooth.db_concept import Concept from modules import paths -elif not cmd_opts.pureui +elif not cmd_opts.pureui: import requests cache = dict() s3_client = boto3.client('s3') From 5b62b1b94c2fac84aa094079de653db748b904ca Mon Sep 17 00:00:00 2001 From: xieyongliang Date: Fri, 7 Apr 2023 23:04:18 +0800 Subject: [PATCH 12/31] fix issues with users --- modules/shared.py | 1 - modules/ui.py | 68 ++++++++++++++++++++++++++--------------------- webui.py | 20 -------------- 3 files changed, 38 insertions(+), 51 deletions(-) diff --git a/modules/shared.py b/modules/shared.py index 5dc3e912883..1d39070214a 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -145,7 +145,6 @@ sagemaker_endpoint_component = None sd_model_checkpoint_component = None create_train_dreambooth_component = None - username = '' else: api_endpoint = cmd_opts.api_endpoint diff --git a/modules/ui.py b/modules/ui.py index b4c57d40280..0c4d5c81516 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -670,28 +670,6 @@ def open_folder(f): parameters_copypaste.bind_buttons(buttons, result_gallery, "txt2img" if tabname == "txt2img" else None) return result_gallery, generation_info if tabname != "extras" else html_info_x, html_info -def update_sagemaker_endpoint(): - return gr.update(value=shared.opts.sagemaker_endpoint, choices=shared.sagemaker_endpoints) - -def update_sd_model_checkpoint(): - return gr.update(value=shared.opts.sd_model_checkpoint, choices=modules.sd_models.checkpoint_tiles()) - -def update_username(): - if shared.username == 'admin': - inputs = { - 'action': 'load' - } - response = requests.post(url=f'{shared.api_endpoint}/sd/user', json=inputs) - if response.status_code == 200: - items = [] - for item in json.loads(response.text): - items.append([item['username'], item['password'], item['options'] if 'options' in item else '', shared.get_available_sagemaker_endpoints(item)]) - return gr.update(value=shared.username), gr.update(value=items if items != [] else None) - else: - return gr.update(value=shared.username), gr.update() - else: - return gr.update(value=shared.username), gr.update() - def create_ui(): import modules.img2img import modules.txt2img @@ -2038,9 +2016,6 @@ def save_userdata(user_dataframe, request: gr.Request): _js="var if alert('Only admin user can save user data')" ) - user_interface.load(update_sagemaker_endpoint, inputs=None, outputs=[shared.sagemaker_endpoint_component]) - user_interface.load(update_sd_model_checkpoint, inputs=None, outputs=[shared.sd_model_checkpoint_component]) - if cmd_opts.pureui: interfaces += [ (txt2img_interface, "txt2img", "txt2img"), @@ -2099,7 +2074,6 @@ def save_userdata(user_dataframe, request: gr.Request): outputs=[username_state, user_dataframe], _js="login" ) - user_interface.load(update_username, 
inputs=None, outputs=[username_state, user_dataframe]) with gr.Column(scale=1): logout_button = gr.Button(value="Logout") @@ -2150,13 +2124,47 @@ def user_logout(request: gr.Request): component_keys = [k for k in opts.data_labels.keys() if k in component_dict] - def get_settings_values(): - return [getattr(opts, key) for key in component_keys] + def demo_load(request: gr.Request): + tokens = shared.demo.server_app.tokens + cookies = request.headers['cookie'].split('; ') + access_token = None + for cookie in cookies: + if cookie.startswith('access-token'): + access_token = cookie[len('access-token=') : ] + break + username = tokens[access_token] if access_token else None + + inputs = { + 'action': 'load' + } + response = requests.post(url=f'{shared.api_endpoint}/sd/user', json=inputs) + if response.status_code == 200: + if username == 'admin': + items = [] + for item in json.loads(response.text): + items.append([item['username'], item['password'], item['options'] if 'options' in item else '', shared.get_available_sagemaker_endpoints(item)]) + + additional_components = [gr.update(value=username), gr.update(value=items if items != [] else None), gr.update(), gr.update()] + else: + for item in json.loads(response.text): + if item['username'] == username: + try: + shared.opts.data = json.loads(item['options']) + break + except Exception as e: + print(e) + shared.refresh_sagemaker_endpoints(username) + shared.refresh_checkpoints(shared.opts.sagemaker_endpoint) + additional_components = [gr.update(value=username), gr.update(), gr.update(value=shared.opts.sagemaker_endpoint, choices=shared.sagemaker_endpoints), gr.update(value=shared.opts.sd_model_checkpoint, choices=modules.sd_models.checkpoint_tiles())] + else: + additional_components = [gr.update(value=username), gr.update(), gr.update(), gr.update()] + + return [getattr(opts, key) for key in component_keys] + additional_components demo.load( - fn=get_settings_values, + fn=demo_load, inputs=[], - outputs=[component_dict[k] for k in component_keys], + outputs=[component_dict[k] for k in component_keys] + [username_state, user_dataframe, shared.sagemaker_endpoint_component, shared.sd_model_checkpoint_component] ) if not cmd_opts.pureui: diff --git a/webui.py b/webui.py index f2ead1ed2e1..b9b56f85e3a 100644 --- a/webui.py +++ b/webui.py @@ -213,26 +213,6 @@ def user_auth(username, password): response = requests.post(url=f'{api_endpoint}/sd/login', json=inputs) - if response.status_code == 200: - try: - body = json.loads(response.text) - options = json.loads(json.loads(body)['options']) - except Exception as e: - print(e) - options = None - - if options != None: - shared.opts.data = options - - shared.refresh_sagemaker_endpoints(username) - shared.refresh_checkpoints(shared.opts.sagemaker_endpoint) - shared.username = username - modules.ui.update_sagemaker_endpoint() - modules.ui.update_sd_model_checkpoint() - modules.ui.update_username() - else: - print(response.text) - return response.status_code == 200 def webui(): From 78b3a69141d83155c2124ad4e868fc3a38be5d8c Mon Sep 17 00:00:00 2001 From: xieyongliang Date: Sun, 9 Apr 2023 17:28:02 +0800 Subject: [PATCH 13/31] update webui.py --- .../.ipynb_checkpoints/lora-checkpoint.py | 207 ------------------ .../lora_script-checkpoint.py | 56 ----- webui.py | 40 +++- 3 files changed, 33 insertions(+), 270 deletions(-) delete mode 100644 extensions-builtin/Lora/.ipynb_checkpoints/lora-checkpoint.py delete mode 100644 extensions-builtin/Lora/scripts/.ipynb_checkpoints/lora_script-checkpoint.py diff --git 
a/extensions-builtin/Lora/.ipynb_checkpoints/lora-checkpoint.py b/extensions-builtin/Lora/.ipynb_checkpoints/lora-checkpoint.py deleted file mode 100644 index 42b9eb56aaf..00000000000 --- a/extensions-builtin/Lora/.ipynb_checkpoints/lora-checkpoint.py +++ /dev/null @@ -1,207 +0,0 @@ -import glob -import os -import re -import torch - -from modules import shared, devices, sd_models - -re_digits = re.compile(r"\d+") -re_unet_down_blocks = re.compile(r"lora_unet_down_blocks_(\d+)_attentions_(\d+)_(.+)") -re_unet_mid_blocks = re.compile(r"lora_unet_mid_block_attentions_(\d+)_(.+)") -re_unet_up_blocks = re.compile(r"lora_unet_up_blocks_(\d+)_attentions_(\d+)_(.+)") -re_text_block = re.compile(r"lora_te_text_model_encoder_layers_(\d+)_(.+)") - - -def convert_diffusers_name_to_compvis(key): - def match(match_list, regex): - r = re.match(regex, key) - if not r: - return False - - match_list.clear() - match_list.extend([int(x) if re.match(re_digits, x) else x for x in r.groups()]) - return True - - m = [] - - if match(m, re_unet_down_blocks): - return f"diffusion_model_input_blocks_{1 + m[0] * 3 + m[1]}_1_{m[2]}" - - if match(m, re_unet_mid_blocks): - return f"diffusion_model_middle_block_1_{m[1]}" - - if match(m, re_unet_up_blocks): - return f"diffusion_model_output_blocks_{m[0] * 3 + m[1]}_1_{m[2]}" - - if match(m, re_text_block): - return f"transformer_text_model_encoder_layers_{m[0]}_{m[1]}" - - return key - - -class LoraOnDisk: - def __init__(self, name, filename): - self.name = name - self.filename = filename - - -class LoraModule: - def __init__(self, name): - self.name = name - self.multiplier = 1.0 - self.modules = {} - self.mtime = None - - -class LoraUpDownModule: - def __init__(self): - self.up = None - self.down = None - self.alpha = None - - -def assign_lora_names_to_compvis_modules(sd_model): - lora_layer_mapping = {} - - for name, module in shared.sd_model.cond_stage_model.wrapped.named_modules(): - lora_name = name.replace(".", "_") - lora_layer_mapping[lora_name] = module - module.lora_layer_name = lora_name - - for name, module in shared.sd_model.model.named_modules(): - lora_name = name.replace(".", "_") - lora_layer_mapping[lora_name] = module - module.lora_layer_name = lora_name - - sd_model.lora_layer_mapping = lora_layer_mapping - - -def load_lora(name, filename): - lora = LoraModule(name) - lora.mtime = os.path.getmtime(filename) - - sd = sd_models.read_state_dict(filename) - - keys_failed_to_match = [] - - for key_diffusers, weight in sd.items(): - fullkey = convert_diffusers_name_to_compvis(key_diffusers) - key, lora_key = fullkey.split(".", 1) - - sd_module = shared.sd_model.lora_layer_mapping.get(key, None) - if sd_module is None: - keys_failed_to_match.append(key_diffusers) - continue - - lora_module = lora.modules.get(key, None) - if lora_module is None: - lora_module = LoraUpDownModule() - lora.modules[key] = lora_module - - if lora_key == "alpha": - lora_module.alpha = weight.item() - continue - - if type(sd_module) == torch.nn.Linear: - module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False) - elif type(sd_module) == torch.nn.Conv2d: - module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (1, 1), bias=False) - else: - assert False, f'Lora layer {key_diffusers} matched a layer with unsupported type: {type(sd_module).__name__}' - - with torch.no_grad(): - module.weight.copy_(weight) - - module.to(device=devices.device, dtype=devices.dtype) - - if lora_key == "lora_up.weight": - lora_module.up = module - elif lora_key == "lora_down.weight": - 
lora_module.down = module - else: - assert False, f'Bad Lora layer name: {key_diffusers} - must end in lora_up.weight, lora_down.weight or alpha' - - if len(keys_failed_to_match) > 0: - print(f"Failed to match keys when loading Lora {filename}: {keys_failed_to_match}") - - return lora - - -def load_loras(names, multipliers=None): - already_loaded = {} - - for lora in loaded_loras: - if lora.name in names: - already_loaded[lora.name] = lora - - loaded_loras.clear() - - loras_on_disk = [available_loras.get(name, None) for name in names] - if any([x is None for x in loras_on_disk]): - list_available_loras() - - loras_on_disk = [available_loras.get(name, None) for name in names] - - for i, name in enumerate(names): - lora = already_loaded.get(name, None) - - lora_on_disk = loras_on_disk[i] - if lora_on_disk is not None: - if lora is None or os.path.getmtime(lora_on_disk.filename) > lora.mtime: - lora = load_lora(name, lora_on_disk.filename) - - if lora is None: - print(f"Couldn't find Lora with name {name}") - continue - - lora.multiplier = multipliers[i] if multipliers else 1.0 - loaded_loras.append(lora) - - -def lora_forward(module, input, res): - if len(loaded_loras) == 0: - return res - - lora_layer_name = getattr(module, 'lora_layer_name', None) - for lora in loaded_loras: - module = lora.modules.get(lora_layer_name, None) - if module is not None: - if shared.opts.lora_apply_to_outputs and res.shape == input.shape: - res = res + module.up(module.down(res)) * lora.multiplier * (module.alpha / module.up.weight.shape[1] if module.alpha else 1.0) - else: - res = res + module.up(module.down(input)) * lora.multiplier * (module.alpha / module.up.weight.shape[1] if module.alpha else 1.0) - - return res - - -def lora_Linear_forward(self, input): - return lora_forward(self, input, torch.nn.Linear_forward_before_lora(self, input)) - - -def lora_Conv2d_forward(self, input): - return lora_forward(self, input, torch.nn.Conv2d_forward_before_lora(self, input)) - - -def list_available_loras(): - available_loras.clear() - - os.makedirs(shared.cmd_opts.lora_dir, exist_ok=True) - - candidates = \ - glob.glob(os.path.join(shared.cmd_opts.lora_dir, '**/*.pt'), recursive=True) + \ - glob.glob(os.path.join(shared.cmd_opts.lora_dir, '**/*.safetensors'), recursive=True) + \ - glob.glob(os.path.join(shared.cmd_opts.lora_dir, '**/*.ckpt'), recursive=True) - - for filename in sorted(candidates): - if os.path.isdir(filename): - continue - - name = os.path.splitext(os.path.basename(filename))[0] - - available_loras[name] = LoraOnDisk(name, filename) - - -available_loras = {} -loaded_loras = [] - -list_available_loras() diff --git a/extensions-builtin/Lora/scripts/.ipynb_checkpoints/lora_script-checkpoint.py b/extensions-builtin/Lora/scripts/.ipynb_checkpoints/lora_script-checkpoint.py deleted file mode 100644 index 302888387c2..00000000000 --- a/extensions-builtin/Lora/scripts/.ipynb_checkpoints/lora_script-checkpoint.py +++ /dev/null @@ -1,56 +0,0 @@ -import torch -import gradio as gr - -import lora -import extra_networks_lora -import ui_extra_networks_lora -from modules import script_callbacks, ui_extra_networks, extra_networks, shared - - -def unload(): - torch.nn.Linear.forward = torch.nn.Linear_forward_before_lora - torch.nn.Linear._load_from_state_dict = torch.nn.Linear_load_state_dict_before_lora - torch.nn.Conv2d.forward = torch.nn.Conv2d_forward_before_lora - torch.nn.Conv2d._load_from_state_dict = torch.nn.Conv2d_load_state_dict_before_lora - torch.nn.MultiheadAttention.forward = 
torch.nn.MultiheadAttention_forward_before_lora - torch.nn.MultiheadAttention._load_from_state_dict = torch.nn.MultiheadAttention_load_state_dict_before_lora - - -def before_ui(): - ui_extra_networks.register_page(ui_extra_networks_lora.ExtraNetworksPageLora()) - extra_networks.register_extra_network(extra_networks_lora.ExtraNetworkLora()) - - -if not hasattr(torch.nn, 'Linear_forward_before_lora'): - torch.nn.Linear_forward_before_lora = torch.nn.Linear.forward - -if not hasattr(torch.nn, 'Linear_load_state_dict_before_lora'): - torch.nn.Linear_load_state_dict_before_lora = torch.nn.Linear._load_from_state_dict - -if not hasattr(torch.nn, 'Conv2d_forward_before_lora'): - torch.nn.Conv2d_forward_before_lora = torch.nn.Conv2d.forward - -if not hasattr(torch.nn, 'Conv2d_load_state_dict_before_lora'): - torch.nn.Conv2d_load_state_dict_before_lora = torch.nn.Conv2d._load_from_state_dict - -if not hasattr(torch.nn, 'MultiheadAttention_forward_before_lora'): - torch.nn.MultiheadAttention_forward_before_lora = torch.nn.MultiheadAttention.forward - -if not hasattr(torch.nn, 'MultiheadAttention_load_state_dict_before_lora'): - torch.nn.MultiheadAttention_load_state_dict_before_lora = torch.nn.MultiheadAttention._load_from_state_dict - -torch.nn.Linear.forward = lora.lora_Linear_forward -torch.nn.Linear._load_from_state_dict = lora.lora_Linear_load_state_dict -torch.nn.Conv2d.forward = lora.lora_Conv2d_forward -torch.nn.Conv2d._load_from_state_dict = lora.lora_Conv2d_load_state_dict -torch.nn.MultiheadAttention.forward = lora.lora_MultiheadAttention_forward -torch.nn.MultiheadAttention._load_from_state_dict = lora.lora_MultiheadAttention_load_state_dict - -script_callbacks.on_model_loaded(lora.assign_lora_names_to_compvis_modules) -script_callbacks.on_script_unloaded(unload) -script_callbacks.on_before_ui(before_ui) - - -shared.options_templates.update(shared.options_section(('extra_networks', "Extra Networks"), { - "sd_lora": shared.OptionInfo("None", "Add Lora to prompt", gr.Dropdown, lambda: {"choices": [""] + [x for x in lora.available_loras]}, refresh=lora.list_available_loras), -})) diff --git a/webui.py b/webui.py index b9b56f85e3a..86740e93c3a 100644 --- a/webui.py +++ b/webui.py @@ -311,10 +311,19 @@ def webui(): if cmd_controlnet_models_path is not None: cn_models_dir = cmd_controlnet_models_path + cmd_lora_models_path = cmd_opts.lora_dir + lora_models_dir = os.path.join(shared.models_path, "Lora") + if cmd_lora_models_path is not None: + lora_models_dir = cmd_lora_models_path + if 'endpoint_name' in os.environ: - items = [] api_endpoint = os.environ['api_endpoint'] endpoint_name = os.environ['endpoint_name'] + + items = [] + params = { + 'module': 'Stable-diffusion' + } for file in os.listdir(sd_models_dir): if os.path.isfile(os.path.join(sd_models_dir, file)) and (file.endswith('.ckpt') or file.endswith('.safetensors')): hash = modules.sd_models.model_hash(os.path.join(sd_models_dir, file)) @@ -329,17 +338,11 @@ def webui(): inputs = { 'items': items } - params = { - 'module': 'Stable-diffusion' - } if api_endpoint.startswith('http://') or api_endpoint.startswith('https://'): response = requests.post(url=f'{api_endpoint}/sd/models', json=inputs, params=params) print(response) items = [] - inputs = { - 'items': items - } params = { 'module': 'ControlNet' } @@ -352,10 +355,33 @@ def webui(): item['title'] = '{0} [{1}]'.format(os.path.splitext(file)[0], hash) item['endpoint_name'] = endpoint_name items.append(item) + inputs = { + 'items': items + } + if api_endpoint.startswith('http://') or 
api_endpoint.startswith('https://'): + response = requests.post(url=f'{api_endpoint}/sd/models', json=inputs, params=params) + print(response) + items = [] + params = { + 'module': 'Lora' + } + for file in os.listdir(lora_models_dir): + if os.path.isfile(os.path.join(lora_models_dir, file)) and \ + (file.endswith('.pt') or file.endswith('.ckpt') or file.endswith('.safetensors')): + hash = modules.sd_models.model_hash(os.path.join(lora_models_dir, file)) + item = {} + item['model_name'] = file + item['title'] = '{0} [{1}]'.format(os.path.splitext(file)[0], hash) + item['endpoint_name'] = endpoint_name + items.append(item) + inputs = { + 'items': items + } if api_endpoint.startswith('http://') or api_endpoint.startswith('https://'): response = requests.post(url=f'{api_endpoint}/sd/models', json=inputs, params=params) print(response) + ui_extra_networks.add_pages_to_demo(app) modules.script_callbacks.app_started_callback(shared.demo, app) From 5be73c960274cd12f42d6a32e2ad74a95acc0e32 Mon Sep 17 00:00:00 2001 From: xieyongliang Date: Sun, 9 Apr 2023 21:25:27 +0800 Subject: [PATCH 14/31] update webui.py --- modules/api/api.py | 32 ++++++++++++++++++++++++++++++-- 1 file changed, 30 insertions(+), 2 deletions(-) diff --git a/modules/api/api.py b/modules/api/api.py index b442f8bce4a..edda5c941d9 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -137,6 +137,7 @@ def __init__(self, app: FastAPI, queue_lock: Lock): self.cache = dict() self.s3_client = boto3.client('s3') self.s3_resource= boto3.resource('s3') + self.generated_images_s3uri = os.environ.get('generated_images_s3uri', None) def add_api_route(self, path: str, endpoint, **kwargs): if shared.cmd_opts.api_auth: @@ -399,6 +400,25 @@ def download_s3files(self, s3uri, path): json.dump(self.cache, open('cache', 'w')) + def post_invocations(self, b64images): + if self.generated_images_s3uri: + bucket, key = self.get_bucket_and_key(self.generated_images_s3uri) + images = [] + for b64image in b64images: + image = decode_base64_to_image(b64image).convert('RGB') + output = io.BytesIO() + image.save(output, format='JPEG') + image_id = str(uuid.uuid4()) + self.s3_client.put_object( + Body=output.getvalue(), + Bucket=bucket, + Key=f'{key}/{image_id}.jpg' + ) + images.append(f's3://{bucket}/{key}/{image_id}.jpg') + return images + else: + return b64images + def invocations(self, req: InvocationsRequest): print('-------invocation------') print(req) @@ -433,24 +453,26 @@ def invocations(self, req: InvocationsRequest): self.download_s3files(embeddings_s3uri, os.path.join(script_path, shared.cmd_opts.embeddings_dir)) sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings() response = self.text2imgapi(req.txt2img_payload) + response.images = self.post_invocations(response.images) shared.opts.data = default_options return response elif req.task == 'image-to-image': self.download_s3files(embeddings_s3uri, os.path.join(script_path, shared.cmd_opts.embeddings_dir)) sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings() response = self.img2imgapi(req.img2img_payload) + response.images = self.post_invocations(response.images) shared.opts.data = default_options return response elif req.task == 'extras-single-image': response = self.extras_single_image_api(req.extras_single_payload) + response.image = self.post_invocations([response.image])[0] shared.opts.data = default_options return response elif req.task == 'extras-batch-images': response = self.extras_batch_images_api(req.extras_batch_payload) + response.images = 
self.post_invocations(response.images) shared.opts.data = default_options return response - elif req.task == 'sd-models': - return self.get_sd_models() else: raise NotImplementedError except Exception as e: @@ -463,3 +485,9 @@ def ping(self): def launch(self, server_name, port): self.app.include_router(self.router) uvicorn.run(self.app, host=server_name, port=port) + + def get_bucket_and_key(self, s3uri): + pos = s3uri.find('/', 5) + bucket = s3uri[5 : pos] + key = s3uri[pos + 1 : ] + return bucket, key From 43314625e3814fc9191ec2dd2925a9e0f417992e Mon Sep 17 00:00:00 2001 From: xieyongliang Date: Sun, 9 Apr 2023 22:06:13 +0800 Subject: [PATCH 15/31] save generated imagess and models with user specific path --- modules/api/api.py | 18 ++++++++++-------- webui.py | 10 +++++----- 2 files changed, 15 insertions(+), 13 deletions(-) diff --git a/modules/api/api.py b/modules/api/api.py index edda5c941d9..566b2618521 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -137,7 +137,6 @@ def __init__(self, app: FastAPI, queue_lock: Lock): self.cache = dict() self.s3_client = boto3.client('s3') self.s3_resource= boto3.resource('s3') - self.generated_images_s3uri = os.environ.get('generated_images_s3uri', None) def add_api_route(self, path: str, endpoint, **kwargs): if shared.cmd_opts.api_auth: @@ -400,9 +399,12 @@ def download_s3files(self, s3uri, path): json.dump(self.cache, open('cache', 'w')) - def post_invocations(self, b64images): - if self.generated_images_s3uri: - bucket, key = self.get_bucket_and_key(self.generated_images_s3uri) + def post_invocations(self, username, b64images): + generated_images_s3uri = os.environ.get('generated_images_s3uri', None) + + if generated_images_s3uri: + generated_images_s3uri = f'{generated_images_s3uri}{username}/' + bucket, key = self.get_bucket_and_key(generated_images_s3uri) images = [] for b64image in b64images: image = decode_base64_to_image(b64image).convert('RGB') @@ -453,24 +455,24 @@ def invocations(self, req: InvocationsRequest): self.download_s3files(embeddings_s3uri, os.path.join(script_path, shared.cmd_opts.embeddings_dir)) sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings() response = self.text2imgapi(req.txt2img_payload) - response.images = self.post_invocations(response.images) + response.images = self.post_invocations(username, response.images) shared.opts.data = default_options return response elif req.task == 'image-to-image': self.download_s3files(embeddings_s3uri, os.path.join(script_path, shared.cmd_opts.embeddings_dir)) sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings() response = self.img2imgapi(req.img2img_payload) - response.images = self.post_invocations(response.images) + response.images = self.post_invocations(username, response.images) shared.opts.data = default_options return response elif req.task == 'extras-single-image': response = self.extras_single_image_api(req.extras_single_payload) - response.image = self.post_invocations([response.image])[0] + response.image = self.post_invocations(username, [response.image])[0] shared.opts.data = default_options return response elif req.task == 'extras-batch-images': response = self.extras_batch_images_api(req.extras_batch_payload) - response.images = self.post_invocations(response.images) + response.images = self.post_invocations(username, response.images) shared.opts.data = default_options return response else: diff --git a/webui.py b/webui.py index 86740e93c3a..0b4acb324d1 100644 --- a/webui.py +++ b/webui.py @@ -838,28 +838,28 @@ def 
train(): print('Uploading SD Models...') if db_config.v2: upload_s3files( - sd_models_s3uri, + f'{sd_models_s3uri}/{username}/', os.path.join(sd_models_dir, db_model_name, f'{db_model_name}_*.yaml') ) if db_config.save_safetensors: upload_s3files( - sd_models_s3uri, + f'{sd_models_s3uri}/{username}/', os.path.join(sd_models_dir, db_model_name, f'{db_model_name}_*.safetensors') ) else: upload_s3files( - sd_models_s3uri, + f'{sd_models_s3uri}/{username}/', os.path.join(sd_models_dir, db_model_name, f'{db_model_name}_*.ckpt') ) print('Uploading DB Models...') upload_s3folder( - f'{db_models_s3uri}{db_model_name}', + f'{db_models_s3uri}{username}/{db_model_name}', os.path.join(db_model_dir, db_model_name) ) if db_config.use_lora: print('Uploading Lora Models...') upload_s3files( - lora_models_s3uri, + f'{lora_models_s3uri}/{username}/', os.path.join(lora_model_dir, f'{db_model_name}_*.pt') ) #automatic tar latest checkpoint and upload to s3 by zheng on 2023.03.22 From ac15d0324007ce37f749be3f5f2ee3f47e13adeb Mon Sep 17 00:00:00 2001 From: xieyongliang Date: Sun, 9 Apr 2023 23:29:57 +0800 Subject: [PATCH 16/31] update stable-diffusion-webui --- webui.py | 74 ++++++++++++++++++++++++++++++++++---------------------- 1 file changed, 45 insertions(+), 29 deletions(-) diff --git a/webui.py b/webui.py index 0b4acb324d1..d9342241a6b 100644 --- a/webui.py +++ b/webui.py @@ -215,6 +215,27 @@ def user_auth(username, password): return response.status_code == 200 +def get_bucket_and_key(s3uri): + pos = s3uri.find('/', 5) + bucket = s3uri[5 : pos] + key = s3uri[pos + 1 : ] + return bucket, key + +def get_models(path, extensions): + candidates = [] + models = [] + + for extension in extensions: + candidates = candidates + glob.glob(os.path.join(path, f'**/{extension}'), recursive=True) + + for filename in sorted(candidates, key=str.lower): + if os.path.isdir(filename): + continue + + models.append(filename) + + return models + def webui(): launch_api = cmd_opts.api @@ -301,6 +322,8 @@ def webui(): if launch_api: create_api(app) + os.path.splitext(os.path.basename(filename))[0] + cmd_sd_models_path = cmd_opts.ckpt_dir sd_models_dir = os.path.join(shared.models_path, "Stable-diffusion") if cmd_sd_models_path is not None: @@ -324,17 +347,14 @@ def webui(): params = { 'module': 'Stable-diffusion' } - for file in os.listdir(sd_models_dir): - if os.path.isfile(os.path.join(sd_models_dir, file)) and (file.endswith('.ckpt') or file.endswith('.safetensors')): - hash = modules.sd_models.model_hash(os.path.join(sd_models_dir, file)) - item = {} - item['model_name'] = file - item['config'] = '/opt/ml/code/stable-diffusion-webui/repositories/stable-diffusion/configs/stable-diffusion/v1-inference.yaml' - item['filename'] = '/opt/ml/code/stable-diffusion-webui/models/Stable-diffusion/{0}'.format(file) - item['hash'] = hash - item['title'] = '{0} [{1}]'.format(file, hash) - item['endpoint_name'] = endpoint_name - items.append(item) + for file in get_models(sd_models_dir, ['*.ckpt', '*.safetensors']): + hash = modules.sd_models.model_hash(file) + item = {} + item['model_name'] = os.path.basename(file) + item['hash'] = hash + item['title'] = '{0} [{1}]'.format(os.path.basename(file), hash) + item['endpoint_name'] = endpoint_name + items.append(item) inputs = { 'items': items } @@ -346,15 +366,13 @@ def webui(): params = { 'module': 'ControlNet' } - for file in os.listdir(cn_models_dir): - if os.path.isfile(os.path.join(cn_models_dir, file)) and \ - (file.endswith('.pt') or file.endswith('.pth') or file.endswith('.ckpt') or 
file.endswith('.safetensors')): - hash = modules.sd_models.model_hash(os.path.join(cn_models_dir, file)) - item = {} - item['model_name'] = file - item['title'] = '{0} [{1}]'.format(os.path.splitext(file)[0], hash) - item['endpoint_name'] = endpoint_name - items.append(item) + for file in get_models(cn_models_dir, ['*.pt', '*.pth', '*.ckpt', '*.safetensors']): + hash = modules.sd_models.model_hash(os.path.join(cn_models_dir, file)) + item = {} + item['model_name'] = os.path.basename(file) + item['title'] = '{0} [{1}]'.format(os.path.splitext(os.path.basename(file))[0], hash) + item['endpoint_name'] = endpoint_name + items.append(item) inputs = { 'items': items } @@ -366,15 +384,13 @@ def webui(): params = { 'module': 'Lora' } - for file in os.listdir(lora_models_dir): - if os.path.isfile(os.path.join(lora_models_dir, file)) and \ - (file.endswith('.pt') or file.endswith('.ckpt') or file.endswith('.safetensors')): - hash = modules.sd_models.model_hash(os.path.join(lora_models_dir, file)) - item = {} - item['model_name'] = file - item['title'] = '{0} [{1}]'.format(os.path.splitext(file)[0], hash) - item['endpoint_name'] = endpoint_name - items.append(item) + for file in get_models(lora_models_dir, ['*.pt', '*.ckpt', '*.safetensors']): + hash = modules.sd_models.model_hash(os.path.join(lora_models_dir, file)) + item = {} + item['model_name'] = os.path.basename(file) + item['title'] = '{0} [{1}]'.format(os.path.splitext(os.path.basename(file))[0], hash) + item['endpoint_name'] = endpoint_name + items.append(item) inputs = { 'items': items } From b155b27d95fa0c574a771ffe3f2059b87f5bc096 Mon Sep 17 00:00:00 2001 From: xie river Date: Mon, 10 Apr 2023 01:15:18 +0000 Subject: [PATCH 17/31] dynamic model laoding --- localizations/zh_CN.json | 4 + modules/api/api.py | 55 +++- modules/api/models.py | 5 +- modules/call_queue.py | 2 +- modules/sd_models.py | 3 + modules/shared.py | 149 +++++++++- modules/ui.py | 86 +++++- requirements_versions.txt | 1 + requirements_versions.txt.cn | 1 + webui.py | 521 ++++++++++++++++++++++++----------- 10 files changed, 649 insertions(+), 178 deletions(-) diff --git a/localizations/zh_CN.json b/localizations/zh_CN.json index 5dfde0f8fb8..95dd39412f6 100644 --- a/localizations/zh_CN.json +++ b/localizations/zh_CN.json @@ -845,5 +845,9 @@ "Upload Train Images to S3":"上传训练图片到S3", "Error, please configure a S3 bucket at settings page first":"失败,请先到设置页面配置S3桶名", "Upload":"上传", + "Reload all models":"重新加载模型文件", + "Update model files path":"更新模型加载路径", + "S3 path for downloading model files (E.g, s3://bucket-name/models/)":"加载模型的S3路径,例如:s3://bucket-name/models/", + "Images Viewer":"图片浏览器", "--------": "--------" } diff --git a/modules/api/api.py b/modules/api/api.py index a065d314d6e..5a5651c155d 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -8,7 +8,7 @@ from fastapi import APIRouter, Depends, FastAPI, HTTPException from fastapi.security import HTTPBasic, HTTPBasicCredentials from secrets import compare_digest - +from modules.shared import de_register_model import modules.shared as shared from modules import sd_samplers, deepbooru from modules.api.models import * @@ -409,6 +409,7 @@ def invocations(self, req: InvocationsRequest): try: username = req.username default_options = shared.opts.data + if username != '': inputs = { 'action': 'get', @@ -428,7 +429,10 @@ def invocations(self, req: InvocationsRequest): self.download_s3files(hypernetwork_s3uri, os.path.join(script_path, shared.cmd_opts.hypernetwork_dir)) 
hypernetworks.hypernetwork.load_hypernetwork(shared.opts.sd_hypernetwork) hypernetworks.hypernetwork.apply_strength() - + ##add sd model usage stats by River + print(f'default_options:{shared.opts.data}') + shared.sd_models_Ref.add_models_ref(shared.opts.data['sd_model_checkpoint']) + ##end if req.task == 'text-to-image': self.download_s3files(embeddings_s3uri, os.path.join(script_path, shared.cmd_opts.embeddings_dir)) sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings() @@ -451,6 +455,11 @@ def invocations(self, req: InvocationsRequest): return response elif req.task == 'sd-models': return self.get_sd_models() + elif req.task == 'reload-all-models': + return self.reload_all_models() + elif req.task == 'set-models-bucket': + bucket = req.models_bucket + return self.set_models_bucket(bucket) else: raise NotImplementedError except Exception as e: @@ -460,6 +469,48 @@ def ping(self): # print('-------ping------') return {'status': 'Healthy'} + def reload_all_models(self): + print('-------reload_all_models------') + def remove_files(path): + for file_name in os.listdir(path): + file_path = os.path.join(path, file_name) + if os.path.isfile(file_path): + os.remove(file_path) + print(f'{file_path} has been removed') + if file_path.find('Stable-diffusion'): + de_register_model(file_name,'sd') + elif file_path.find('ControlNet'): + de_register_model(file_name,'cn') + elif os.path.isdir(file_path): + remove_files(file_path) + os.rmdir(file_path) + shared.syncLock.acquire() + #remove all files in /tmp/models/ and /tmp/cache/ + remove_files(shared.tmp_models_dir) + remove_files(shared.tmp_cache_dir) + shared.syncLock.release() + return {'simple_result':'success'} + + def set_models_bucket(self,bucket): + shared.syncLock.acquire() + if bucket.endswith('/'): + bucket = bucket[:-1] + url_parts = bucket.replace('s3://','').split('/') + shared.models_s3_bucket = url_parts[0] + lastfolder = url_parts[-1] + if lastfolder == 'Stable-diffusion': + shared.s3_folder_sd = '/'.join(url_parts[1:]) + elif lastfolder == 'ControlNet': + shared.s3_folder_cn = '/'.join(url_parts[1:]) + else: + shared.s3_folder_sd = '/'.join(url_parts[1:]+['Stable-diffusion']) + shared.s3_folder_cn = '/'.join(url_parts[1:]+['ControlNet']) + print(f'set_models_bucket to {shared.models_s3_bucket}') + print(f'set_s3_folder_sd to {shared.s3_folder_sd}') + print(f'set_s3_folder_cn to {shared.s3_folder_cn}') + shared.syncLock.release() + return {'simple_result':'success'} + def launch(self, server_name, port): self.app.include_router(self.router) uvicorn.run(self.app, host=server_name, port=port) diff --git a/modules/api/models.py b/modules/api/models.py index bc61ae90534..d655d8f7b27 100644 --- a/modules/api/models.py +++ b/modules/api/models.py @@ -244,11 +244,12 @@ class ArtistItem(BaseModel): class InvocationsRequest(BaseModel): task: str username: Optional[str] + models_bucket:Optional[str] + simple_result:Optional[str] txt2img_payload: Optional[StableDiffusionTxt2ImgProcessingAPI] img2img_payload: Optional[StableDiffusionImg2ImgProcessingAPI] extras_single_payload: Optional[ExtrasSingleImageRequest] extras_batch_payload: Optional[ExtrasBatchImagesRequest] class PingResponse(BaseModel): - status: str - + status: str \ No newline at end of file diff --git a/modules/call_queue.py b/modules/call_queue.py index 3515a88b67a..7578d963188 100644 --- a/modules/call_queue.py +++ b/modules/call_queue.py @@ -505,4 +505,4 @@ def f(request: gr.Request, *args, extra_outputs_array=extra_outputs, **kwargs): return tuple(res) - return f 
+ return f \ No newline at end of file diff --git a/modules/sd_models.py b/modules/sd_models.py index 5c2d7e13d2a..a4f361c7778 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -19,6 +19,7 @@ model_dir = "Stable-diffusion" model_path = os.path.abspath(os.path.join(models_path, model_dir)) + CheckpointInfo = namedtuple("CheckpointInfo", ['filename', 'title', 'hash', 'model_name', 'config']) checkpoints_list = {} checkpoints_loaded = collections.OrderedDict() @@ -161,6 +162,8 @@ def model_hash(filename): def select_checkpoint(): + ##add log by Rive + print('checkpoints_list:',checkpoints_list) model_checkpoint = shared.opts.sd_model_checkpoint checkpoint_info = checkpoints_list.get(model_checkpoint, None) if checkpoint_info is not None: diff --git a/modules/shared.py b/modules/shared.py index f9b9897a04e..c2d7b152e80 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -4,7 +4,7 @@ import os import sys import time - +import threading import gradio as gr import tqdm @@ -16,8 +16,17 @@ from modules import localization, sd_vae, extensions, script_loading from modules.paths import models_path, script_path, sd_path import requests +import boto3 demo = None +#Add by River +models_s3_bucket = None +s3_folder_sd = None +s3_folder_cn = None +syncLock = threading.Lock() +tmp_models_dir = '/tmp/models' +tmp_cache_dir = '/tmp/cache' +#end sd_model_file = os.path.join(script_path, 'model.ckpt') default_sd_model_file = sd_model_file @@ -273,11 +282,144 @@ def do_set_current_image(self): face_restorers = [] +def get_default_sagemaker_bucket(default_region = 'us-west-2'): + session = boto3.Session() + region_name = session.region_name if session.region_name else default_region + sts_client = session.client('sts') + account_id = sts_client.get_caller_identity()['Account'] + return f"s3://sagemaker-{region_name}-{account_id}" def realesrgan_models_names(): import modules.realesrgan_model return [x.name for x in modules.realesrgan_model.get_realesrgan_models(None)] +#add by River +class ModelsRef: + def __init__(self): + self.models_ref = {} + + def get_models_ref_dict(self): + return self.models_ref + + def add_models_ref(self, model_name): + if model_name in self.models_ref: + self.models_ref[model_name] += 1 + else: + self.models_ref[model_name] = 0 + + def remove_model_ref(self,model_name): + if self.models_ref.get(model_name): + del self.models_ref[model_name] + + def get_models_ref(self, model_name): + return self.models_ref.get(model_name) + + def get_least_ref_model(self): + sorted_models = sorted(self.models_ref.items(), key=lambda item: item[1]) + if sorted_models: + least_ref_model, least_counter = sorted_models[0] + return least_ref_model,least_counter + else: + return None,None + + def pop_least_ref_model(self): + sorted_models = sorted(self.models_ref.items(), key=lambda item: item[1]) + if sorted_models: + least_ref_model, least_counter = sorted_models[0] + del self.models_ref[least_ref_model] + return least_ref_model,least_counter + else: + return None,None + +sd_models_Ref = ModelsRef() +cn_models_Ref = ModelsRef() + +def register_models(models_dir,mode): + if mode == 'sd': + register_sd_models(models_dir) + elif mode == 'cn': + register_cn_models(models_dir) + +def register_sd_models(sd_models_dir): + print ('---register_sd_models()----') + if 'endpoint_name' in os.environ: + items = [] + api_endpoint = os.environ['api_endpoint'] + endpoint_name = os.environ['endpoint_name'] + print(f'api_endpoint:{api_endpoint}\nendpoint_name:{endpoint_name}') + for file in 
os.listdir(sd_models_dir): + if os.path.isfile(os.path.join(sd_models_dir, file)) and (file.endswith('.ckpt') or file.endswith('.safetensors')): + hash = modules.sd_models.model_hash(os.path.join(sd_models_dir, file)) + item = {} + item['model_name'] = file + item['config'] = '/opt/ml/code/stable-diffusion-webui/repositories/stable-diffusion/configs/stable-diffusion/v1-inference.yaml' + item['filename'] = '/opt/ml/code/stable-diffusion-webui/models/Stable-diffusion/{0}'.format(file) + item['hash'] = hash + item['title'] = '{0} [{1}]'.format(file, hash) + item['endpoint_name'] = endpoint_name + items.append(item) + inputs = { + 'items': items + } + params = { + 'module': 'Stable-diffusion' + } + if api_endpoint.startswith('http://') or api_endpoint.startswith('https://'): + response = requests.post(url=f'{api_endpoint}/sd/models', json=inputs, params=params) + print(response) + +def register_cn_models(cn_models_dir): + print ('---register_cn_models()----') + if 'endpoint_name' in os.environ: + items = [] + api_endpoint = os.environ['api_endpoint'] + endpoint_name = os.environ['endpoint_name'] + print(f'api_endpoint:{api_endpoint}\nendpoint_name:{endpoint_name}') + + inputs = { + 'items': items + } + params = { + 'module': 'ControlNet' + } + for file in os.listdir(cn_models_dir): + if os.path.isfile(os.path.join(cn_models_dir, file)) and \ + (file.endswith('.pt') or file.endswith('.pth') or file.endswith('.ckpt') or file.endswith('.safetensors')): + hash = modules.sd_models.model_hash(os.path.join(cn_models_dir, file)) + item = {} + item['model_name'] = file + item['title'] = '{0} [{1}]'.format(os.path.splitext(file)[0], hash) + item['endpoint_name'] = endpoint_name + items.append(item) + + if api_endpoint.startswith('http://') or api_endpoint.startswith('https://'): + response = requests.post(url=f'{api_endpoint}/sd/models', json=inputs, params=params) + print(response) + +def de_register_model(model_name,mode): + models_Ref = sd_models_Ref + if mode == 'sd' : + models_Ref = sd_models_Ref + elif mode == 'cn': + models_Ref = cn_models_Ref + models_Ref.remove_model_ref(model_name) + print (f'---de_register_{mode}_model({model_name})---models_Ref({models_Ref.get_models_ref_dict()})----') + if 'endpoint_name' in os.environ: + api_endpoint = os.environ['api_endpoint'] + endpoint_name = os.environ['endpoint_name'] + data = { + "module":mode, + "model_name": model_name, + "endpoint_name": endpoint_name + } + response = requests.delete(url=f'{api_endpoint}/sd/models', json=data) + # Check if the request was successful + if response.status_code == requests.codes.ok: + print(f"{model_name} deleted successfully!") + else: + print(f"Error deleting {model_name}: ", response.text) + +#end by River class OptionInfo: def __init__(self, default=None, label="", component=None, component_args=None, onchange=None, section=None, refresh=None): @@ -371,6 +513,8 @@ def refresh_sagemaker_endpoints(username): return sagemaker_endpoints + + options_templates.update(options_section(('sd', "Stable Diffusion"), { "sagemaker_endpoint": OptionInfo(None, "SaegMaker endpoint", gr.Dropdown, lambda: {"choices": list_sagemaker_endpoints()}, refresh=refresh_sagemaker_endpoints), "sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": list_checkpoint_tiles()}, refresh=refresh_checkpoints), @@ -423,7 +567,7 @@ def refresh_sagemaker_endpoints(username): })) options_templates.update(options_section(('saving-paths', "Paths for saving"), { - "train_files_s3bucket":OptionInfo("","S3 
bucket name for uploading/downloading images",component_args=hide_dirs), + "train_files_s3bucket":OptionInfo(get_default_sagemaker_bucket(),"S3 bucket name for uploading/downloading images",component_args=hide_dirs), "outdir_samples": OptionInfo("", "Output directory for images; if empty, defaults to three directories below", component_args=hide_dirs), "outdir_txt2img_samples": OptionInfo("outputs/txt2img-images", 'Output directory for txt2img images', component_args=hide_dirs), "outdir_img2img_samples": OptionInfo("outputs/img2img-images", 'Output directory for img2img images', component_args=hide_dirs), @@ -640,6 +784,7 @@ def reorder(self): if cmd_opts.pureui and opts.localization == "None": opts.localization = "zh_CN" + sd_upscalers = [] sd_model = None diff --git a/modules/ui.py b/modules/ui.py index 9bf067fd907..b3c724063ed 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -24,7 +24,7 @@ from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions, deepbooru from modules.paths import script_path -from modules.shared import opts, cmd_opts, restricted_opts +from modules.shared import opts, cmd_opts, restricted_opts,get_default_sagemaker_bucket import modules.codeformer_model import modules.generation_parameters_copypaste as parameters_copypaste @@ -733,6 +733,21 @@ def create_ui(): interfaces = [] + ##add images viewer + def translate(text): + return f'translated:{text}' + with gr.Blocks(analytics_enabled=False) as imagesviewer_interface: + with gr.Row().style(equal_height=False): + with gr.Column(): + english = gr.Textbox(label="Placeholder") + translate_btn = gr.Button(value="Translate") + with gr.Column(): + german = gr.Textbox(label="German Text") + + translate_btn.click(translate, inputs=english, outputs=german, api_name="translate-to-german") + examples = gr.Examples(examples=["I went to the supermarket yesterday.", "Helen is a good swimmer."], + inputs=[english]) + with gr.Blocks(analytics_enabled=False) as pnginfo_interface: with gr.Row().style(equal_height=False): with gr.Column(variant='panel'): @@ -761,6 +776,7 @@ def create_ui(): else: dreambooth_tab = ui_tab[0] + def create_setting_component(key, is_quicksettings=False): def fun(): return opts.data[key] if key in opts.data else opts.data_labels[key].default @@ -879,11 +895,22 @@ def run_settings_single(value, key, request : gr.Request): return gr.update(value=value), opts.dumpjson() + default_sagemaker_s3 = get_default_sagemaker_bucket() + default_s3_path = f"{default_sagemaker_s3}/stable-diffusion-webui/models/" with gr.Blocks(analytics_enabled=False) as settings_interface: dummy_component = gr.Label(visible=False) - - settings_submit = gr.Button(value="Apply settings", variant='primary') - + with gr.Row(): + settings_submit = gr.Button(value="Apply settings", variant='primary') + with gr.Row(): + with gr.Column(scale=2): + models_s3bucket = gr.Textbox(label="S3 path for downloading model files (E.g, s3://bucket-name/models/)", + value=default_s3_path) + with gr.Column(scale=1): + set_models_s3bucket_btn = gr.Button(value="Update model files path",elem_id='id_set_models_s3bucket') + with gr.Column(scale=1): + reload_models_btn = gr.Button(value='Reload all models', elem_id='id_reload_all_models') + + result = gr.HTML() settings_cols = 3 @@ -898,6 +925,7 @@ def run_settings_single(value, key, request : gr.Request): items_displayed = 0 previous_section = None column = None + with gr.Row(elem_id="settings").style(equal_height=False): for i, (k, item) in enumerate(opts.data_labels.items()): 
section_must_be_skipped = item.section[0] is None @@ -916,6 +944,7 @@ def run_settings_single(value, key, request : gr.Request): previous_section = item.section elem_id, text = item.section + print(text) gr.HTML(elem_id="settings_header_text_{}".format(elem_id), value='

<h1 class="gr-button-lg">{}</h1>
    '.format(text)) if k in quicksettings_names and not shared.cmd_opts.freeze_settings: @@ -972,6 +1001,48 @@ def request_restart(): outputs=[], ) + def reload_all_models(): + sagemaker_endpoint=shared.opts.sagemaker_endpoint + print(f'reload_all_models from:{sagemaker_endpoint}') + inputs = {'task': 'reload-all-models'} + params = {'endpoint_name': sagemaker_endpoint} + response = requests.post(url=f'{shared.api_endpoint}/inference', params=params, json=inputs) + if response.status_code == 200: + return f'[{sagemaker_endpoint}] reload_all_models success' + else: + print(response.status_code ) + return f'[{sagemaker_endpoint}] reload_all_models failed' + + reload_models_btn.click( + fn=reload_all_models, + inputs=[], + outputs=[result] + ) + + def set_models_s3bucket(bucket_name): + if bucket_name == '': + return 'Error, please configure a S3 bucket for downloading model files' + sagemaker_endpoint=shared.opts.sagemaker_endpoint + print(f'set_models_s3bucket to:{sagemaker_endpoint}') + inputs = {'task': 'set-models-bucket', + 'models_bucket':bucket_name} + params = {'endpoint_name': + sagemaker_endpoint} + response = requests.post(url=f'{shared.api_endpoint}/inference', params=params, json=inputs) + if response.status_code == 200: + return f'[{sagemaker_endpoint}] set bucket succeess' + else: + print(response.status_code ) + return f'[{sagemaker_endpoint}] set bucket failed' + + + set_models_s3bucket_btn.click( + fn=set_models_s3bucket, + inputs=[models_s3bucket], + outputs=[result] + + ) + if column is not None: column.__exit__() @@ -1498,7 +1569,7 @@ def upload_to_s3(imgs): bucket_name = opts.train_files_s3bucket if bucket_name == '': return 'Error, please configure a S3 bucket at settings page first' - s3_bucket = s3_resource.Bucket(bucket_name) + s3_bucket = s3_resource.Bucket(bucket_name.replace('s3://','')) folder_name = f"train-images/{username}/{timestamp}" try: for i, img in enumerate(imgs): @@ -1509,12 +1580,12 @@ def upload_to_s3(imgs): print(e) return e - return f"{len(imgs)} images uploaded to S3 folder: s3://{bucket_name}/{folder_name}" + return f"{len(imgs)} images uploaded to S3 folder:{bucket_name}/{folder_name}" with gr.Tab(label="Upload Train Images to S3"): upload_files = gr.Files(label="Files") url_output = gr.Textbox(label="Output S3 folder") - sub_btn = gr.Button("Upload") + sub_btn = gr.Button(label="Upload",variant='primary',elem_id='id_upload_train_files') sub_btn.click(fn=upload_to_s3, inputs=upload_files, outputs=url_output) ## End add s3 images upload interface by River with gr.Tab(label="Train Embedding"): @@ -2150,6 +2221,7 @@ def save_userdata(user_dataframe, request: gr.Request): # interfaces += script_callbacks.ui_tabs_callback() interfaces += [(settings_interface, "Settings", "settings")] + interfaces += [(imagesviewer_interface,"Images Viewer","imagesviewer")] extensions_interface = ui_extensions.create_ui() interfaces += [(extensions_interface, "Extensions", "extensions")] diff --git a/requirements_versions.txt b/requirements_versions.txt index b5d15750386..f424578202b 100644 --- a/requirements_versions.txt +++ b/requirements_versions.txt @@ -29,3 +29,4 @@ torchsde==0.2.5 safetensors==0.2.5 fastapi==0.90.1 boto3 +psutil diff --git a/requirements_versions.txt.cn b/requirements_versions.txt.cn index 985a27a33fc..4edd5cba387 100644 --- a/requirements_versions.txt.cn +++ b/requirements_versions.txt.cn @@ -29,3 +29,4 @@ GitPython==3.1.27 torchsde==0.2.5 safetensors==0.2.5 boto3 +psutil \ No newline at end of file diff --git a/webui.py b/webui.py index 
3fd617579f2..8fd5cf34fb2 100644 --- a/webui.py +++ b/webui.py @@ -11,9 +11,10 @@ from fastapi.encoders import jsonable_encoder from fastapi.exceptions import RequestValidationError from fastapi.responses import JSONResponse - +import psutil from modules.call_queue import wrap_queued_call, queue_lock from modules.paths import script_path +from collections import OrderedDict from modules import shared, sd_samplers, upscaler, extensions, localization, ui_tempdir import modules.codeformer_model as codeformer @@ -33,7 +34,7 @@ import modules.ui from modules import modelloader -from modules.shared import cmd_opts, opts +from modules.shared import cmd_opts, opts, sd_model,syncLock,de_register_model,register_models import modules.hypernetworks.hypernetwork import boto3 import threading @@ -57,6 +58,7 @@ else: server_name = "0.0.0.0" if cmd_opts.listen else None +FREESPACE = 20 def initialize(): extensions.list_extensions() @@ -66,20 +68,26 @@ def initialize(): shared.sd_upscalers = upscaler.UpscalerLanczos().scalers modules.scripts.load_scripts() return - + ## auto reload new models from s3 add by River - sd_models_tmp_dir = "/opt/ml/code/stable-diffusion-webui/models/Stable-diffusion/" - cn_models_tmp_dir = "/opt/ml/code/stable-diffusion-webui/models/ControlNet/" - session = boto3.Session() - region_name = session.region_name - sts_client = session.client('sts') - account_id = sts_client.get_caller_identity()['Account'] - sg_defaul_bucket_name = f"sagemaker-{region_name}-{account_id}" - s3_folder_sd = "stable-diffusion-webui/models/Stable-diffusion" - s3_folder_cn = "stable-diffusion-webui/models/ControlNet" - - sync_s3_folder(sg_defaul_bucket_name,s3_folder_sd,sd_models_tmp_dir,'sd') - sync_s3_folder(sg_defaul_bucket_name,s3_folder_cn,cn_models_tmp_dir,'cn') + if not cmd_opts.pureui and not cmd_opts.train: + print(os.system('df -h')) + sd_models_tmp_dir = f"{shared.tmp_models_dir}/Stable-diffusion/" + cn_models_tmp_dir = f"{shared.tmp_models_dir}/ControlNet/" + cache_dir = f"{shared.tmp_cache_dir}/" + session = boto3.Session() + region_name = session.region_name + sts_client = session.client('sts') + account_id = sts_client.get_caller_identity()['Account'] + if not shared.models_s3_bucket: + shared.models_s3_bucket = f"sagemaker-{region_name}-{account_id}" + shared.s3_folder_sd = "stable-diffusion-webui/models/Stable-diffusion" + shared.s3_folder_cn = "stable-diffusion-webui/models/ControlNet" + + #only download the first file from defaul biucket, to accerlate the startup time + initial_s3_download(shared.s3_folder_sd,sd_models_tmp_dir,cache_dir,'sd') + sync_s3_folder(sd_models_tmp_dir,cache_dir,'sd') + sync_s3_folder(cn_models_tmp_dir,cache_dir,'cn') ## end modelloader.cleanup_models() @@ -152,7 +160,6 @@ def wait_on_server(demo=None): def api_only(): initialize() - app = FastAPI() setup_cors(app) app.add_middleware(GZipMiddleware, minimum_size=1000) @@ -200,114 +207,300 @@ def user_auth(username, password): return response.status_code == 200 +# def register_models(models_dir,mode): +# if mode == 'sd': +# register_sd_models(models_dir) +# elif mode == 'cn': +# register_cn_models(models_dir) + +# def register_sd_models(sd_models_dir): +# print ('---register_sd_models()----') +# if 'endpoint_name' in os.environ: +# items = [] +# api_endpoint = os.environ['api_endpoint'] +# endpoint_name = os.environ['endpoint_name'] +# print(f'api_endpoint:{api_endpoint}\nendpoint_name:{endpoint_name}') +# for file in os.listdir(sd_models_dir): +# if os.path.isfile(os.path.join(sd_models_dir, file)) and 
(file.endswith('.ckpt') or file.endswith('.safetensors')): +# hash = modules.sd_models.model_hash(os.path.join(sd_models_dir, file)) +# item = {} +# item['model_name'] = file +# item['config'] = '/opt/ml/code/stable-diffusion-webui/repositories/stable-diffusion/configs/stable-diffusion/v1-inference.yaml' +# item['filename'] = '/opt/ml/code/stable-diffusion-webui/models/Stable-diffusion/{0}'.format(file) +# item['hash'] = hash +# item['title'] = '{0} [{1}]'.format(file, hash) +# item['endpoint_name'] = endpoint_name +# items.append(item) +# inputs = { +# 'items': items +# } +# params = { +# 'module': 'Stable-diffusion' +# } +# if api_endpoint.startswith('http://') or api_endpoint.startswith('https://'): +# response = requests.post(url=f'{api_endpoint}/sd/models', json=inputs, params=params) +# print(response) + +# def register_cn_models(cn_models_dir): +# print ('---register_cn_models()----') +# if 'endpoint_name' in os.environ: +# items = [] +# api_endpoint = os.environ['api_endpoint'] +# endpoint_name = os.environ['endpoint_name'] +# print(f'api_endpoint:{api_endpoint}\nendpoint_name:{endpoint_name}') + +# inputs = { +# 'items': items +# } +# params = { +# 'module': 'ControlNet' +# } +# for file in os.listdir(cn_models_dir): +# if os.path.isfile(os.path.join(cn_models_dir, file)) and \ +# (file.endswith('.pt') or file.endswith('.pth') or file.endswith('.ckpt') or file.endswith('.safetensors')): +# hash = modules.sd_models.model_hash(os.path.join(cn_models_dir, file)) +# item = {} +# item['model_name'] = file +# item['title'] = '{0} [{1}]'.format(os.path.splitext(file)[0], hash) +# item['endpoint_name'] = endpoint_name +# items.append(item) + +# if api_endpoint.startswith('http://') or api_endpoint.startswith('https://'): +# response = requests.post(url=f'{api_endpoint}/sd/models', json=inputs, params=params) +# print(response) + +# def de_register_model(model_name,mode): +# models_Ref = shared.sd_models_Ref +# if mode == 'sd' : +# models_Ref = shared.sd_models_Ref +# elif mode == 'cn': +# models_Ref = shared.cn_models_Ref +# models_Ref.remove_model_ref(model_name) +# print (f'---de_register_{mode}_model({model_name})---models_Ref({models_Ref.get_models_ref_dict()})----') +# if 'endpoint_name' in os.environ: +# api_endpoint = os.environ['api_endpoint'] +# endpoint_name = os.environ['endpoint_name'] +# data = { +# "module":mode, +# "model_name": model_name, +# "endpoint_name": endpoint_name +# } +# response = requests.delete(url=f'{api_endpoint}/sd/models', json=data) +# # Check if the request was successful +# if response.status_code == requests.codes.ok: +# print(f"{model_name} deleted successfully!") +# else: +# print(f"Error deleting {model_name}: ", response.text) + + + +def check_space_s3_download(s3,bucket_name,s3_folder,local_folder,file,size,mode): + src = s3_folder + '/' + file + dist = os.path.join(local_folder, file) + # Get disk usage statistics + disk_usage = psutil.disk_usage('/tmp') + freespace = disk_usage.free/(1024**3) + print(f"Total space: {disk_usage.total/(1024**3)}, Used space: {disk_usage.used/(1024**3)}, Free space: {freespace}") + if freespace - size >= FREESPACE: + try: + s3.download_file(bucket_name, src, dist) + #init ref cnt to 0, when the model file first time download + hash = modules.sd_models.model_hash(dist) + if mode == 'sd' : + shared.sd_models_Ref.add_models_ref('{0} [{1}]'.format(file, hash)) + elif mode == 'cn': + shared.cn_models_Ref.add_models_ref('{0} [{1}]'.format(os.path.splitext(file)[0], hash)) + print(f'download_file success:from 
{bucket_name}/{src} to {dist}') + except Exception as e: + print(f'download_file error: from {bucket_name}/{src} to {dist}') + print(f"An error occurred: {e}") + return False + return True + else: + return False -def register_sd_models(sd_models_dir): - print ('---register_sd_models()----') - if 'endpoint_name' in os.environ: - items = [] - api_endpoint = os.environ['api_endpoint'] - endpoint_name = os.environ['endpoint_name'] - print(f'api_endpoint:{api_endpoint}\nendpoint_name:{endpoint_name}') - for file in os.listdir(sd_models_dir): - if os.path.isfile(os.path.join(sd_models_dir, file)) and (file.endswith('.ckpt') or file.endswith('.safetensors')): - hash = modules.sd_models.model_hash(os.path.join(sd_models_dir, file)) - item = {} - item['model_name'] = file - item['config'] = '/opt/ml/code/stable-diffusion-webui/repositories/stable-diffusion/configs/stable-diffusion/v1-inference.yaml' - item['filename'] = '/opt/ml/code/stable-diffusion-webui/models/Stable-diffusion/{0}'.format(file) - item['hash'] = hash - item['title'] = '{0} [{1}]'.format(file, hash) - item['endpoint_name'] = endpoint_name - items.append(item) - inputs = { - 'items': items - } - params = { - 'module': 'Stable-diffusion' - } - if api_endpoint.startswith('http://') or api_endpoint.startswith('https://'): - response = requests.post(url=f'{api_endpoint}/sd/models', json=inputs, params=params) - print(response) - -def register_cn_models(cn_models_dir): - print ('---register_cn_models()----') - if 'endpoint_name' in os.environ: - items = [] - api_endpoint = os.environ['api_endpoint'] - endpoint_name = os.environ['endpoint_name'] - print(f'api_endpoint:{api_endpoint}\nendpoint_name:{endpoint_name}') - - inputs = { - 'items': items - } - params = { - 'module': 'ControlNet' - } - for file in os.listdir(cn_models_dir): - if os.path.isfile(os.path.join(cn_models_dir, file)) and \ - (file.endswith('.pt') or file.endswith('.pth') or file.endswith('.ckpt') or file.endswith('.safetensors')): - hash = modules.sd_models.model_hash(os.path.join(cn_models_dir, file)) - item = {} - item['model_name'] = file - item['title'] = '{0} [{1}]'.format(os.path.splitext(file)[0], hash) - item['endpoint_name'] = endpoint_name - items.append(item) - - if api_endpoint.startswith('http://') or api_endpoint.startswith('https://'): - response = requests.post(url=f'{api_endpoint}/sd/models', json=inputs, params=params) - print(response) - - -def sync_s3_folder(bucket_name, s3_folder, local_folder,mode): - print(f"sync S3 bucket '{bucket_name}', folder '{s3_folder}' for new files...") +def free_local_disk(local_folder,mode): + models_Ref = None + if mode == 'sd' : + models_Ref = shared.sd_models_Ref + elif mode == 'cn': + models_Ref = shared.cn_models_Ref + # Get disk usage statistics + # disk_usage = psutil.disk_usage('/tmp') + # freespace = disk_usage.free/(1024**3) + # while freespace < FREESPACE: + model_name,ref_cnt = models_Ref.get_least_ref_model() + print (f'shared.{mode}_models_Ref:{models_Ref.get_models_ref_dict()} -- model_name:{model_name}') + if model_name and ref_cnt: + filename = model_name[:model_name.rfind("[")] + os.remove(os.path.join(local_folder, filename)) + disk_usage = psutil.disk_usage('/tmp') + freespace = disk_usage.free/(1024**3) + print(f"Remove file: {os.path.join(local_folder, filename)} now left space:{freespace}") + de_register_model(filename,mode) + else: + ## if ref_cnt == 0, then delete the oldest zero_ref one + zero_ref_models = set([model[:model.rfind(" [")] for model, count in 
models_Ref.get_models_ref_dict().items() if count == 0]) + local_files = set(os.listdir(local_folder)) + # join with local + files = [(os.path.join(local_folder, file), os.path.getctime(os.path.join(local_folder, file))) for file in zero_ref_models.intersection(local_files)] + if len(files) == 0: + print(f"No files to remove in folder: {local_folder}, please remove some files in S3 bucket") + return + files.sort(key=lambda x: x[1]) + oldest_file = files[0][0] + os.remove(oldest_file) + disk_usage = psutil.disk_usage('/tmp') + freespace = disk_usage.free/(1024**3) + print(f"Remove file: {oldest_file} now left space:{freespace}") + filename = os.path.basename(oldest_file) + de_register_model(filename,mode) + + +def initial_s3_download(s3_folder, local_folder,cache_dir,mode): # Create tmp folders os.makedirs(os.path.dirname(local_folder), exist_ok=True) + os.makedirs(os.path.dirname(cache_dir), exist_ok=True) print(f'create dir: {os.path.dirname(local_folder)}') - # Create an S3 client + print(f'create dir: {os.path.dirname(cache_dir)}') + s3_file_name = os.path.join(cache_dir,f's3_files_{mode}.json') + # Create an empty file if not exist + if os.path.isfile(s3_file_name) == False: + s3_files = {} + with open(s3_file_name, "w") as f: + json.dump(s3_files, f) + # Create an S3 clientb s3 = boto3.client('s3') - def sync(): - # List all objects in the S3 folder - response = s3.list_objects_v2(Bucket=bucket_name, Prefix=s3_folder) - # Check if there are any new or deleted files - s3_files = set() - for obj in response.get('Contents', []): - s3_files.add(obj['Key'].replace(s3_folder, '').lstrip('/')) - - local_files = set(os.listdir(local_folder)) + # List all objects in the S3 folder + response = s3.list_objects_v2(Bucket=shared.models_s3_bucket, Prefix=s3_folder) + # only download on model at initialization + s3_objects = response.get('Contents', []) + fnames_dict = {} + # if there v2 models, one root should have two files (.ckpt,.yaml) + for obj in s3_objects: + filename = obj['Key'].replace(s3_folder, '').lstrip('/') + root, ext = os.path.splitext(filename) + model = fnames_dict.get(root) + if model: + model.append(filename) + else: + fnames_dict[root] = [filename] + print(f'-----fnames_dict---{fnames_dict}') + + tmp_s3_files = {} + for i, obj in enumerate (s3_objects): + etag = obj['ETag'].strip('"').strip("'") + size = obj['Size']/(1024**3) + filename = obj['Key'].replace(s3_folder, '').lstrip('/') + tmp_s3_files[filename] = [etag,size] + + #only fetch the first model to download. 
+ s3_files = {} + _, file_names = next(iter(fnames_dict.items())) + for fname in file_names: + s3_files[fname] = tmp_s3_files.get(fname) + check_space_s3_download(s3, shared.models_s3_bucket, s3_folder,local_folder, fname, tmp_s3_files.get(fname)[1], mode) + register_models(local_folder,mode) + print(f'-----s3_files---{s3_files}') + # save the lastest one + with open(s3_file_name, "w") as f: + json.dump(s3_files, f) + - new_files = s3_files - local_files - del_files = local_files - s3_files - # Copy new files to local folder - for file in new_files: - s3.download_file(bucket_name, s3_folder + '/' + file, os.path.join(local_folder, file)) - print(f'download_file:from {bucket_name}/{s3_folder}/{file} to {os.path.join(local_folder, file)}') +def sync_s3_folder(local_folder,cache_dir,mode): + # Create an S3 clientb + s3 = boto3.client('s3') + def sync(mode): + if mode == 'sd': + s3_folder = shared.s3_folder_sd + elif mode == 'cn': + s3_folder = shared.s3_folder_cn + else: + s3_folder = '' + # print(f"sync S3 bucket '{shared.models_s3_bucket}', folder '{s3_folder}' for new files...") + # Check and Create tmp folders + os.makedirs(os.path.dirname(local_folder), exist_ok=True) + os.makedirs(os.path.dirname(cache_dir), exist_ok=True) + # print(f'create dir: {os.path.dirname(local_folder)}') + # print(f'create dir: {os.path.dirname(cache_dir)}') + s3_file_name = os.path.join(cache_dir,f's3_files_{mode}.json') + # Create an empty file if not exist + if os.path.isfile(s3_file_name) == False: + s3_files = {} + with open(s3_file_name, "w") as f: + json.dump(s3_files, f) + # List all objects in the S3 folder + response = s3.list_objects_v2(Bucket=shared.models_s3_bucket, Prefix=s3_folder) + # Check if there are any new or deleted files + s3_files = {} + for obj in response.get('Contents', []): + etag = obj['ETag'].strip('"').strip("'") + size = obj['Size']/(1024**3) + key = obj['Key'].replace(s3_folder, '').lstrip('/') + s3_files[key] = [etag,size] + + # to compared the latest s3 list with last time saved in local json, + # read it first + s3_files_local = {} + with open(s3_file_name, "r") as f: + s3_files_local = json.load(f) + # print(f's3_files:{s3_files}') + # print(f's3_files_local:{s3_files_local}') + # save the lastest one + with open(s3_file_name, "w") as f: + json.dump(s3_files, f) + mod_files = set() + new_files = set([key for key in s3_files if key not in s3_files_local]) + del_files = set([key for key in s3_files_local if key not in s3_files]) + registerflag = False + #compare etag changes + for key in set(s3_files_local.keys()).intersection(s3_files.keys()): + local_etag = s3_files_local.get(key)[0] + if local_etag and local_etag != s3_files[key][0]: + mod_files.add(key) # Delete vanished files from local folder for file in del_files: - os.remove(os.path.join(local_folder, file)) - print(f'remove file {os.path.join(local_folder, file)}') - # If there are changes - if len(new_files) | len(del_files): + if os.path.isfile(os.path.join(local_folder, file)): + os.remove(os.path.join(local_folder, file)) + print(f'remove file {os.path.join(local_folder, file)}') + de_register_model(file,mode) + # Add new files + for file in new_files.union(mod_files): + registerflag = True + retry = 3 ##retry limit times to prevent dead loop in case other folders is empty + while retry: + ret = check_space_s3_download(s3, shared.models_s3_bucket, s3_folder,local_folder, file, s3_files[file][1], mode) + #if the space is not enough free + if ret: + retry = 0 + else: + free_local_disk(local_folder,mode) + retry = 
retry - 1 + if registerflag: + register_models(local_folder,mode) if mode == 'sd': - register_sd_models(local_folder) + #Refreshing Model List + modules.sd_models.list_models() elif mode == 'cn': - register_cn_models(local_folder) - else: - print(f'unsupported mode:{mode}') + #Reload extension models, such as ControlNet + modules.scripts.reload_scripts() + + # Create a thread function to keep syncing with the S3 folder - def sync_thread(): + def sync_thread(mode): while True: - sync() - time.sleep(60) - # Initialize at launch - sync() - # Start the thread - thread = threading.Thread(target=sync_thread) + syncLock.acquire() + sync(mode) + syncLock.release() + time.sleep(30) + thread = threading.Thread(target=sync_thread,args=(mode,)) thread.start() return thread + def webui(): launch_api = cmd_opts.api initialize() @@ -345,61 +538,61 @@ def webui(): if launch_api: create_api(app) - cmd_sd_models_path = cmd_opts.ckpt_dir - sd_models_dir = os.path.join(shared.models_path, "Stable-diffusion") - if cmd_sd_models_path is not None: - sd_models_dir = cmd_sd_models_path - - cmd_controlnet_models_path = cmd_opts.controlnet_dir - cn_models_dir = os.path.join(shared.models_path, "ControlNet") - if cmd_controlnet_models_path is not None: - cn_models_dir = cmd_controlnet_models_path - - if 'endpoint_name' in os.environ: - items = [] - api_endpoint = os.environ['api_endpoint'] - endpoint_name = os.environ['endpoint_name'] - for file in os.listdir(sd_models_dir): - if os.path.isfile(os.path.join(sd_models_dir, file)) and (file.endswith('.ckpt') or file.endswith('.safetensors')): - hash = modules.sd_models.model_hash(os.path.join(sd_models_dir, file)) - item = {} - item['model_name'] = file - item['config'] = '/opt/ml/code/stable-diffusion-webui/repositories/stable-diffusion/configs/stable-diffusion/v1-inference.yaml' - item['filename'] = '/opt/ml/code/stable-diffusion-webui/models/Stable-diffusion/{0}'.format(file) - item['hash'] = hash - item['title'] = '{0} [{1}]'.format(file, hash) - item['endpoint_name'] = endpoint_name - items.append(item) - inputs = { - 'items': items - } - params = { - 'module': 'Stable-diffusion' - } - if api_endpoint.startswith('http://') or api_endpoint.startswith('https://'): - response = requests.post(url=f'{api_endpoint}/sd/models', json=inputs, params=params) - print(response) - - items = [] - inputs = { - 'items': items - } - params = { - 'module': 'ControlNet' - } - for file in os.listdir(cn_models_dir): - if os.path.isfile(os.path.join(cn_models_dir, file)) and \ - (file.endswith('.pt') or file.endswith('.pth') or file.endswith('.ckpt') or file.endswith('.safetensors')): - hash = modules.sd_models.model_hash(os.path.join(cn_models_dir, file)) - item = {} - item['model_name'] = file - item['title'] = '{0} [{1}]'.format(os.path.splitext(file)[0], hash) - item['endpoint_name'] = endpoint_name - items.append(item) - - if api_endpoint.startswith('http://') or api_endpoint.startswith('https://'): - response = requests.post(url=f'{api_endpoint}/sd/models', json=inputs, params=params) - print(response) + # cmd_sd_models_path = cmd_opts.ckpt_dir + # sd_models_dir = os.path.join(shared.models_path, "Stable-diffusion") + # if cmd_sd_models_path is not None: + # sd_models_dir = cmd_sd_models_path + + # cmd_controlnet_models_path = cmd_opts.controlnet_dir + # cn_models_dir = os.path.join(shared.models_path, "ControlNet") + # if cmd_controlnet_models_path is not None: + # cn_models_dir = cmd_controlnet_models_path + + # if 'endpoint_name' in os.environ: + # items = [] + # api_endpoint 
= os.environ['api_endpoint'] + # endpoint_name = os.environ['endpoint_name'] + # for file in os.listdir(sd_models_dir): + # if os.path.isfile(os.path.join(sd_models_dir, file)) and (file.endswith('.ckpt') or file.endswith('.safetensors')): + # hash = modules.sd_models.model_hash(os.path.join(sd_models_dir, file)) + # item = {} + # item['model_name'] = file + # item['config'] = '/opt/ml/code/stable-diffusion-webui/repositories/stable-diffusion/configs/stable-diffusion/v1-inference.yaml' + # item['filename'] = '/opt/ml/code/stable-diffusion-webui/models/Stable-diffusion/{0}'.format(file) + # item['hash'] = hash + # item['title'] = '{0} [{1}]'.format(file, hash) + # item['endpoint_name'] = endpoint_name + # items.append(item) + # inputs = { + # 'items': items + # } + # params = { + # 'module': 'Stable-diffusion' + # } + # if api_endpoint.startswith('http://') or api_endpoint.startswith('https://'): + # response = requests.post(url=f'{api_endpoint}/sd/models', json=inputs, params=params) + # print(response) + + # items = [] + # inputs = { + # 'items': items + # } + # params = { + # 'module': 'ControlNet' + # } + # for file in os.listdir(cn_models_dir): + # if os.path.isfile(os.path.join(cn_models_dir, file)) and \ + # (file.endswith('.pt') or file.endswith('.pth') or file.endswith('.ckpt') or file.endswith('.safetensors')): + # hash = modules.sd_models.model_hash(os.path.join(cn_models_dir, file)) + # item = {} + # item['model_name'] = file + # item['title'] = '{0} [{1}]'.format(os.path.splitext(file)[0], hash) + # item['endpoint_name'] = endpoint_name + # items.append(item) + + # if api_endpoint.startswith('http://') or api_endpoint.startswith('https://'): + # response = requests.post(url=f'{api_endpoint}/sd/models', json=inputs, params=params) + # print(response) modules.script_callbacks.app_started_callback(shared.demo, app) wait_on_server(shared.demo) @@ -838,9 +1031,9 @@ def train(): lora_model_dir = os.path.join(lora_model_dir, "lora") print('---models path---', sd_models_dir, lora_model_dir) - os.system(f'ls -l {sd_models_dir}') - os.system('ls -l {0}'.format(os.path.join(sd_models_dir, db_model_name))) - os.system(f'ls -l {lora_model_dir}') + print(os.system(f'ls -l {sd_models_dir}')) + print(os.system('ls -l {0}'.format(os.path.join(sd_models_dir, db_model_name)))) + print(os.system(f'ls -l {lora_model_dir}')) try: print('Uploading SD Models...') From 1807ea88836819e14746c03ec41fe52afae0eaf3 Mon Sep 17 00:00:00 2001 From: xieyongliang Date: Mon, 10 Apr 2023 16:16:22 +0800 Subject: [PATCH 18/31] update webui.py and api.py --- modules/api/api.py | 1 + webui.py | 2 ++ 2 files changed, 3 insertions(+) diff --git a/modules/api/api.py b/modules/api/api.py index 566b2618521..63a517c54ee 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -29,6 +29,7 @@ import piexif import piexif.helper import numpy as np +import uuid def upscaler_to_index(name: str): try: diff --git a/webui.py b/webui.py index d9342241a6b..db16873a599 100644 --- a/webui.py +++ b/webui.py @@ -352,6 +352,8 @@ def webui(): item = {} item['model_name'] = os.path.basename(file) item['hash'] = hash + item['filename'] = file + item['config'] = '/opt/ml/code/stable-diffusion-webui/repositories/stable-diffusion/configs/stable-diffusion/v1-inference.yaml' item['title'] = '{0} [{1}]'.format(os.path.basename(file), hash) item['endpoint_name'] = endpoint_name items.append(item) From f352ab2e395e7e280c9592eef63dc86d53b3b26f Mon Sep 17 00:00:00 2001 From: xieyongliang Date: Mon, 10 Apr 2023 16:59:40 +0800 Subject: [PATCH 19/31] 
update webui.py --- modules/api/api.py | 13 ++++--------- 1 file changed, 4 insertions(+), 9 deletions(-) diff --git a/modules/api/api.py b/modules/api/api.py index 63a517c54ee..70dd2d7e42e 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -406,7 +406,6 @@ def post_invocations(self, username, b64images): if generated_images_s3uri: generated_images_s3uri = f'{generated_images_s3uri}{username}/' bucket, key = self.get_bucket_and_key(generated_images_s3uri) - images = [] for b64image in b64images: image = decode_base64_to_image(b64image).convert('RGB') output = io.BytesIO() @@ -417,10 +416,6 @@ def post_invocations(self, username, b64images): Bucket=bucket, Key=f'{key}/{image_id}.jpg' ) - images.append(f's3://{bucket}/{key}/{image_id}.jpg') - return images - else: - return b64images def invocations(self, req: InvocationsRequest): print('-------invocation------') @@ -456,24 +451,24 @@ def invocations(self, req: InvocationsRequest): self.download_s3files(embeddings_s3uri, os.path.join(script_path, shared.cmd_opts.embeddings_dir)) sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings() response = self.text2imgapi(req.txt2img_payload) - response.images = self.post_invocations(username, response.images) + self.post_invocations(username, response.images) shared.opts.data = default_options return response elif req.task == 'image-to-image': self.download_s3files(embeddings_s3uri, os.path.join(script_path, shared.cmd_opts.embeddings_dir)) sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings() response = self.img2imgapi(req.img2img_payload) - response.images = self.post_invocations(username, response.images) + self.post_invocations(username, response.images) shared.opts.data = default_options return response elif req.task == 'extras-single-image': response = self.extras_single_image_api(req.extras_single_payload) - response.image = self.post_invocations(username, [response.image])[0] + self.post_invocations(username, [response.image]) shared.opts.data = default_options return response elif req.task == 'extras-batch-images': response = self.extras_batch_images_api(req.extras_batch_payload) - response.images = self.post_invocations(username, response.images) + self.post_invocations(username, response.images) shared.opts.data = default_options return response else: From 972a04f6b780f2131c970f4b2194470d6c742dc5 Mon Sep 17 00:00:00 2001 From: xie river Date: Tue, 11 Apr 2023 02:33:17 +0000 Subject: [PATCH 20/31] dynamic loading models --- localizations/zh_CN.json | 2 +- modules/call_queue.py | 2 +- modules/paths.py | 6 +- modules/scripts.py | 9 ++- modules/shared.py | 3 +- modules/ui.py | 79 ++++++++++++++++---------- webui.py | 119 ++++++++------------------------------- 7 files changed, 88 insertions(+), 132 deletions(-) diff --git a/localizations/zh_CN.json b/localizations/zh_CN.json index 95dd39412f6..2eb9c4d0c3f 100644 --- a/localizations/zh_CN.json +++ b/localizations/zh_CN.json @@ -844,7 +844,7 @@ "Output S3 folder":"S3文件夹目录", "Upload Train Images to S3":"上传训练图片到S3", "Error, please configure a S3 bucket at settings page first":"失败,请先到设置页面配置S3桶名", - "Upload":"上传", + "Upload Images":"上传图片", "Reload all models":"重新加载模型文件", "Update model files path":"更新模型加载路径", "S3 path for downloading model files (E.g, s3://bucket-name/models/)":"加载模型的S3路径,例如:s3://bucket-name/models/", diff --git a/modules/call_queue.py b/modules/call_queue.py index 7578d963188..fce0938cedb 100644 --- a/modules/call_queue.py +++ b/modules/call_queue.py @@ -450,7 +450,7 @@ def f(request: 
gr.Request, *args, extra_outputs_array=extra_outputs, **kwargs): t = time.perf_counter() try: - if func.__name__ == 'f' or func.__name__ == 'run_settings': + if func.__name__ == 'f' or func.__name__ == 'run_settings' or func.__name__ == 'save_files': res = list(func(username, *args, **kwargs)) else: res = list(func(*args, **kwargs)) diff --git a/modules/paths.py b/modules/paths.py index 4dd03a3594c..c80b317bdc0 100644 --- a/modules/paths.py +++ b/modules/paths.py @@ -2,9 +2,11 @@ import os import sys import modules.safe - script_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) -models_path = os.path.join(script_path, "models") +## Change by River +# models_path = os.path.join(script_path, "models") +models_path = '/tmp/models' +## sys.path.insert(0, script_path) # search for directory of stable diffusion in following places diff --git a/modules/scripts.py b/modules/scripts.py index bba8707d0b4..6320d3c9641 100644 --- a/modules/scripts.py +++ b/modules/scripts.py @@ -181,7 +181,7 @@ def load_scripts(): script_callbacks.clear_callbacks() scripts_list = list_scripts("scripts", ".py") - + print('scripts_list:',scripts_list) syspath = sys.path for scriptfile in sorted(scripts_list): @@ -203,6 +203,7 @@ def load_scripts(): finally: sys.path = syspath current_basedir = paths.script_path + print('scripts_data',scripts_data) def wrap_call(func, filename, funcname, *args, default=None, **kwargs): @@ -225,6 +226,9 @@ def __init__(self): self.infotext_fields = [] def initialize_scripts(self, is_img2img): + print('----initialize_scripts()------') + print(f'--scripts_data--{scripts_data}') + traceback.print_stack() self.scripts.clear() self.alwayson_scripts.clear() self.selectable_scripts.clear() @@ -316,7 +320,8 @@ def run(self, p: StableDiffusionProcessing, *args): if script_index == 0: return None - + print('self.selectable_scripts:',self.selectable_scripts) + print('script_index:',script_index) script = self.selectable_scripts[script_index-1] if script is None: diff --git a/modules/shared.py b/modules/shared.py index 2681327dba4..d90009be911 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -25,7 +25,7 @@ s3_folder_cn = None syncLock = threading.Lock() tmp_models_dir = '/tmp/models' -tmp_cache_dir = '/tmp/cache' +tmp_cache_dir = '/tmp/model_sync_cache' #end sd_model_file = os.path.join(script_path, 'model.ckpt') @@ -514,6 +514,7 @@ def refresh_sagemaker_endpoints(username): options_templates.update(options_section(('sd', "Stable Diffusion"), { + # "models_s3_bucket": OptionInfo(f'{get_default_sagemaker_bucket()}/stable-diffusion-webui/models/', "S3 path for downloading model files (E.g, s3://bucket-name/models/)", ), "sagemaker_endpoint": OptionInfo(None, "SaegMaker endpoint", gr.Dropdown, lambda: {"choices": list_sagemaker_endpoints()}, refresh=refresh_sagemaker_endpoints), "sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": list_checkpoint_tiles()}, refresh=refresh_checkpoints), "sd_checkpoint_cache": OptionInfo(0, "Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}), diff --git a/modules/ui.py b/modules/ui.py index 3406400d422..12d2f9d284c 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -89,10 +89,22 @@ def gr_show(visible=True): ## Begin output images uploaded to s3 by River s3_resource = boto3.resource('s3') -def save_images_to_s3(full_fillnames,timestamp): - username = shared.username +def get_webui_username(request): + tokens = shared.demo.server_app.tokens + cookies = 
request.headers['cookie'].split('; ') + access_token = None + for cookie in cookies: + if cookie.startswith('access-token'): + access_token = cookie[len('access-token=') : ] + break + username = tokens[access_token] if access_token else None + return username + +def save_images_to_s3(full_fillnames,timestamp,username): sagemaker_endpoint = shared.opts.sagemaker_endpoint - bucket_name = opts.train_files_s3bucket + bucket_name = opts.train_files_s3bucket.replace('s3://','') + if bucket_name.endswith('/'): + bucket_name= bucket_name[:-1] if bucket_name == '': return 'Error, please configure a S3 bucket at settings page first' s3_bucket = s3_resource.Bucket(bucket_name) @@ -134,6 +146,10 @@ def save_images_to_s3(full_fillnames,timestamp): save_style_symbol = '\U0001f4be' # 💾 apply_style_symbol = '\U0001f4cb' # 📋 +def text_to_hyperlink_html(url): + text= f'
    <p><a href="{url}" target="_blank">{url}</a></p>
    ' + return text + def plaintext_to_html(text): text = "
    <p>" + "<br>\n".join([f"{html.escape(x)}" for x in text.split('\n')]) + "</p>
    " return text @@ -143,7 +159,7 @@ def send_gradio_gallery_to_image(x): return None return image_from_url_text(x[0]) -def save_files(js_data, images, do_make_zip, index): +def save_files(username,js_data, images, do_make_zip, index): import csv filenames = [] fullfns = [] @@ -197,8 +213,7 @@ def __init__(self, d=None): timestamp = datetime.now(timezone(timedelta(hours=+8))).strftime('%Y-%m-%dT%H:%M:%S') logfile = os.path.join(opts.outdir_save, "log.csv") - s3folder = save_images_to_s3(fullfns,timestamp) - save_images_to_s3([logfile],timestamp) + s3folder = save_images_to_s3(fullfns+[logfile],timestamp,username) # Make Zip if do_make_zip: zip_filepath = os.path.join(path, "images.zip") @@ -210,7 +225,7 @@ def __init__(self, d=None): zip_file.writestr(filenames[i], f.read()) fullfns.insert(0, zip_filepath) - return gr.File.update(value=fullfns, visible=True), '', '', plaintext_to_html(f"Saved: {filenames[0]}, \nS3 folder:\n{s3folder}") + return gr.File.update(value=fullfns, visible=True), '', '', plaintext_to_html(f"Saved: {filenames[0]}"),text_to_hyperlink_html(s3folder) @@ -683,13 +698,14 @@ def open_folder(f): generation_info, result_gallery, do_make_zip, - html_info, + html_info ], outputs=[ download_files, html_info, html_info, html_info, + html_info ] ) else: @@ -711,20 +727,20 @@ def create_ui(): interfaces = [] - ##add images viewer - def translate(text): - return f'translated:{text}' - with gr.Blocks(analytics_enabled=False) as imagesviewer_interface: - with gr.Row().style(equal_height=False): - with gr.Column(): - english = gr.Textbox(label="Placeholder") - translate_btn = gr.Button(value="Translate") - with gr.Column(): - german = gr.Textbox(label="German Text") - - translate_btn.click(translate, inputs=english, outputs=german, api_name="translate-to-german") - examples = gr.Examples(examples=["I went to the supermarket yesterday.", "Helen is a good swimmer."], - inputs=[english]) + ##add River + # def translate(text): + # return f'translated:{text}' + # with gr.Blocks(analytics_enabled=False) as imagesviewer_interface: + # with gr.Row().style(equal_height=False): + # with gr.Column(): + # english = gr.Textbox(label="Placeholder") + # translate_btn = gr.Button(value="Translate") + # with gr.Column(): + # german = gr.Textbox(label="German Text") + + # translate_btn.click(translate, inputs=english, outputs=german, api_name="translate-to-german") + # examples = gr.Examples(examples=["I went to the supermarket yesterday.", "Helen is a good swimmer."], + # inputs=[english]) with gr.Blocks(analytics_enabled=False) as pnginfo_interface: with gr.Row().style(equal_height=False): @@ -880,13 +896,14 @@ def run_settings_single(value, key, request : gr.Request): with gr.Row(): settings_submit = gr.Button(value="Apply settings", variant='primary') with gr.Row(): - with gr.Column(scale=2): + with gr.Column(scale=4): models_s3bucket = gr.Textbox(label="S3 path for downloading model files (E.g, s3://bucket-name/models/)", value=default_s3_path) with gr.Column(scale=1): set_models_s3bucket_btn = gr.Button(value="Update model files path",elem_id='id_set_models_s3bucket') with gr.Column(scale=1): reload_models_btn = gr.Button(value='Reload all models', elem_id='id_reload_all_models') + result = gr.HTML() @@ -996,7 +1013,8 @@ def reload_all_models(): inputs=[], outputs=[result] ) - + + # River def set_models_s3bucket(bucket_name): if bucket_name == '': return 'Error, please configure a S3 bucket for downloading model files' @@ -1539,10 +1557,13 @@ def update_orig(image, state): with 
gr.Row().style(equal_height=False): with gr.Tabs(elem_id="train_tabs"): ## Begin add s3 images upload interface by River - def upload_to_s3(imgs): - username = shared.username + def upload_to_s3(imgs,request : gr.Request): + username = get_webui_username(request) + print (f'--get_webui_username--:{username}') timestamp = datetime.now(timezone(timedelta(hours=+8))).strftime('%Y-%m-%dT%H:%M:%S') - bucket_name = opts.train_files_s3bucket + bucket_name = opts.train_files_s3bucket.replace('s3://','') + if bucket_name.endswith('/'): + bucket_name= bucket_name[:-1] if bucket_name == '': return 'Error, please configure a S3 bucket at settings page first' s3_bucket = s3_resource.Bucket(bucket_name.replace('s3://','')) @@ -1561,7 +1582,7 @@ def upload_to_s3(imgs): with gr.Tab(label="Upload Train Images to S3"): upload_files = gr.Files(label="Files") url_output = gr.Textbox(label="Output S3 folder") - sub_btn = gr.Button(label="Upload",variant='primary',elem_id='id_upload_train_files') + sub_btn = gr.Button(value="Upload Images",elem_id='id_upload_train_images',variant='primary') sub_btn.click(fn=upload_to_s3, inputs=upload_files, outputs=url_output) ## End add s3 images upload interface by River with gr.Tab(label="Train Embedding"): @@ -2194,7 +2215,7 @@ def save_userdata(user_dataframe, request: gr.Request): # interfaces += script_callbacks.ui_tabs_callback() interfaces += [(settings_interface, "Settings", "settings")] - interfaces += [(imagesviewer_interface,"Images Viewer","imagesviewer")] + # interfaces += [(imagesviewer_interface,"Images Viewer","imagesviewer")] extensions_interface = ui_extensions.create_ui() interfaces += [(extensions_interface, "Extensions", "extensions")] diff --git a/webui.py b/webui.py index 7d761e6633d..ff51d2b83d1 100644 --- a/webui.py +++ b/webui.py @@ -79,8 +79,11 @@ def initialize(): region_name = session.region_name sts_client = session.client('sts') account_id = sts_client.get_caller_identity()['Account'] + sg_s3_bucket = f"sagemaker-{region_name}-{account_id}" + #print ('environ:',os.environ) if not shared.models_s3_bucket: - shared.models_s3_bucket = f"sagemaker-{region_name}-{account_id}" + + shared.models_s3_bucket = os.environ['sg_default_bucket'] if os.environ.get('sg_default_bucket') else sg_s3_bucket shared.s3_folder_sd = "stable-diffusion-webui/models/Stable-diffusion" shared.s3_folder_cn = "stable-diffusion-webui/models/ControlNet" @@ -99,7 +102,7 @@ def initialize(): modules.scripts.load_scripts() modelloader.load_upscalers() - + modules.sd_vae.refresh_vae_list() if not cmd_opts.pureui: modules.sd_models.load_model() @@ -187,92 +190,6 @@ def user_auth(username, password): return response.status_code == 200 -# def register_models(models_dir,mode): -# if mode == 'sd': -# register_sd_models(models_dir) -# elif mode == 'cn': -# register_cn_models(models_dir) - -# def register_sd_models(sd_models_dir): -# print ('---register_sd_models()----') -# if 'endpoint_name' in os.environ: -# items = [] -# api_endpoint = os.environ['api_endpoint'] -# endpoint_name = os.environ['endpoint_name'] -# print(f'api_endpoint:{api_endpoint}\nendpoint_name:{endpoint_name}') -# for file in os.listdir(sd_models_dir): -# if os.path.isfile(os.path.join(sd_models_dir, file)) and (file.endswith('.ckpt') or file.endswith('.safetensors')): -# hash = modules.sd_models.model_hash(os.path.join(sd_models_dir, file)) -# item = {} -# item['model_name'] = file -# item['config'] = '/opt/ml/code/stable-diffusion-webui/repositories/stable-diffusion/configs/stable-diffusion/v1-inference.yaml' -# 
item['filename'] = '/opt/ml/code/stable-diffusion-webui/models/Stable-diffusion/{0}'.format(file) -# item['hash'] = hash -# item['title'] = '{0} [{1}]'.format(file, hash) -# item['endpoint_name'] = endpoint_name -# items.append(item) -# inputs = { -# 'items': items -# } -# params = { -# 'module': 'Stable-diffusion' -# } -# if api_endpoint.startswith('http://') or api_endpoint.startswith('https://'): -# response = requests.post(url=f'{api_endpoint}/sd/models', json=inputs, params=params) -# print(response) - -# def register_cn_models(cn_models_dir): -# print ('---register_cn_models()----') -# if 'endpoint_name' in os.environ: -# items = [] -# api_endpoint = os.environ['api_endpoint'] -# endpoint_name = os.environ['endpoint_name'] -# print(f'api_endpoint:{api_endpoint}\nendpoint_name:{endpoint_name}') - -# inputs = { -# 'items': items -# } -# params = { -# 'module': 'ControlNet' -# } -# for file in os.listdir(cn_models_dir): -# if os.path.isfile(os.path.join(cn_models_dir, file)) and \ -# (file.endswith('.pt') or file.endswith('.pth') or file.endswith('.ckpt') or file.endswith('.safetensors')): -# hash = modules.sd_models.model_hash(os.path.join(cn_models_dir, file)) -# item = {} -# item['model_name'] = file -# item['title'] = '{0} [{1}]'.format(os.path.splitext(file)[0], hash) -# item['endpoint_name'] = endpoint_name -# items.append(item) - -# if api_endpoint.startswith('http://') or api_endpoint.startswith('https://'): -# response = requests.post(url=f'{api_endpoint}/sd/models', json=inputs, params=params) -# print(response) - -# def de_register_model(model_name,mode): -# models_Ref = shared.sd_models_Ref -# if mode == 'sd' : -# models_Ref = shared.sd_models_Ref -# elif mode == 'cn': -# models_Ref = shared.cn_models_Ref -# models_Ref.remove_model_ref(model_name) -# print (f'---de_register_{mode}_model({model_name})---models_Ref({models_Ref.get_models_ref_dict()})----') -# if 'endpoint_name' in os.environ: -# api_endpoint = os.environ['api_endpoint'] -# endpoint_name = os.environ['endpoint_name'] -# data = { -# "module":mode, -# "model_name": model_name, -# "endpoint_name": endpoint_name -# } -# response = requests.delete(url=f'{api_endpoint}/sd/models', json=data) -# # Check if the request was successful -# if response.status_code == requests.codes.ok: -# print(f"{model_name} deleted successfully!") -# else: -# print(f"Error deleting {model_name}: ", response.text) - - def check_space_s3_download(s3,bucket_name,s3_folder,local_folder,file,size,mode): src = s3_folder + '/' + file @@ -365,7 +282,6 @@ def initial_s3_download(s3_folder, local_folder,cache_dir,mode): model.append(filename) else: fnames_dict[root] = [filename] - print(f'-----fnames_dict---{fnames_dict}') tmp_s3_files = {} for i, obj in enumerate (s3_objects): @@ -460,12 +376,13 @@ def sync(mode): retry = retry - 1 if registerflag: register_models(local_folder,mode) - if mode == 'sd': - #Refreshing Model List - modules.sd_models.list_models() - elif mode == 'cn': - #Reload extension models, such as ControlNet - modules.scripts.reload_scripts() + reload_webui_infer(mode) + # if mode == 'sd': + # #Refreshing Model List + # modules.sd_models.list_models() + # elif mode == 'cn': + # #Reload extension models, such as ControlNet + # modules.scripts.reload_scripts() # Create a thread function to keep syncing with the S3 folder @@ -479,7 +396,17 @@ def sync_thread(mode): thread.start() return thread - +def reload_webui_infer(mode): + extensions.list_extensions() + print('Reloading custom scripts') + modules.scripts.load_scripts() + 
modelloader.load_upscalers() + # print('Reloading modules: modules.ui') + # importlib.reload(modules.ui) + if mode == 'sd': + print('Refreshing SD Model List') + modules.sd_models.list_models() + shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights())) def webui(): launch_api = cmd_opts.api From 02e3d02d1f16552c87003bb2fc36ae7c861b0ff5 Mon Sep 17 00:00:00 2001 From: xie river Date: Wed, 12 Apr 2023 05:28:21 +0000 Subject: [PATCH 21/31] merge upstream --- modules/shared.py | 24 ++++++++++++++++++++++++ webui.py | 27 +-------------------------- 2 files changed, 25 insertions(+), 26 deletions(-) diff --git a/modules/shared.py b/modules/shared.py index 83066dfd3ba..2edf3116998 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -337,6 +337,30 @@ def pop_least_ref_model(self): cn_models_Ref = ModelsRef() lora_models_Ref = ModelsRef() +def de_register_model(model_name,mode): + models_Ref = sd_models_Ref + if mode == 'sd' : + models_Ref = sd_models_Ref + elif mode == 'cn': + models_Ref = cn_models_Ref + elif mode == 'lora': + models_Ref = lora_models_Ref + models_Ref.remove_model_ref(model_name) + print (f'---de_register_{mode}_model({model_name})---models_Ref({models_Ref.get_models_ref_dict()})----') + if 'endpoint_name' in os.environ: + api_endpoint = os.environ['api_endpoint'] + endpoint_name = os.environ['endpoint_name'] + data = { + "module":mode, + "model_name": model_name, + "endpoint_name": endpoint_name + } + response = requests.delete(url=f'{api_endpoint}/sd/models', json=data) + # Check if the request was successful + if response.status_code == requests.codes.ok: + print(f"{model_name} deleted successfully!") + else: + print(f"Error deleting {model_name}: ", response.text) #end by River class OptionInfo: diff --git a/webui.py b/webui.py index 78a5ebb1683..bbf4b8d2353 100644 --- a/webui.py +++ b/webui.py @@ -39,7 +39,7 @@ import modules.ui from modules import modelloader -from modules.shared import cmd_opts, opts, sd_model,syncLock +from modules.shared import cmd_opts, opts, sd_model,syncLock,de_register_model import modules.hypernetworks.hypernetwork import boto3 import threading @@ -548,31 +548,6 @@ def register_cn_models(cn_models_dir): response = requests.post(url=f'{api_endpoint}/sd/models', json=inputs, params=params) print(response) -def de_register_model(model_name,mode): - models_Ref = sd_models_Ref - if mode == 'sd' : - models_Ref = sd_models_Ref - elif mode == 'cn': - models_Ref = cn_models_Ref - elif mode == 'lora': - models_Ref = lora_models_Ref - models_Ref.remove_model_ref(model_name) - print (f'---de_register_{mode}_model({model_name})---models_Ref({models_Ref.get_models_ref_dict()})----') - if 'endpoint_name' in os.environ: - api_endpoint = os.environ['api_endpoint'] - endpoint_name = os.environ['endpoint_name'] - data = { - "module":mode, - "model_name": model_name, - "endpoint_name": endpoint_name - } - response = requests.delete(url=f'{api_endpoint}/sd/models', json=data) - # Check if the request was successful - if response.status_code == requests.codes.ok: - print(f"{model_name} deleted successfully!") - else: - print(f"Error deleting {model_name}: ", response.text) - def webui(): launch_api = cmd_opts.api From d8251f2c59cb850310575f62c9ac3f2a8dfd7e4d Mon Sep 17 00:00:00 2001 From: xie river Date: Wed, 12 Apr 2023 10:13:00 +0000 Subject: [PATCH 22/31] image viewer --- localizations/zh_CN.json | 3 +++ modules/ui.py | 55 ++++++++++++++++++++++++++++------------ 2 files changed, 42 insertions(+), 16 
deletions(-) diff --git a/localizations/zh_CN.json b/localizations/zh_CN.json index 2eb9c4d0c3f..674b3505d36 100644 --- a/localizations/zh_CN.json +++ b/localizations/zh_CN.json @@ -849,5 +849,8 @@ "Update model files path":"更新模型加载路径", "S3 path for downloading model files (E.g, s3://bucket-name/models/)":"加载模型的S3路径,例如:s3://bucket-name/models/", "Images Viewer":"图片浏览器", + "Input S3 path of images":"输入图片的S3路径", + "Submit":"确定", + "columns width":"每行图片列数", "--------": "--------" } diff --git a/modules/ui.py b/modules/ui.py index 12d2f9d284c..1b8bf48587b 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -728,20 +728,44 @@ def create_ui(): interfaces = [] ##add River - # def translate(text): - # return f'translated:{text}' - # with gr.Blocks(analytics_enabled=False) as imagesviewer_interface: - # with gr.Row().style(equal_height=False): - # with gr.Column(): - # english = gr.Textbox(label="Placeholder") - # translate_btn = gr.Button(value="Translate") - # with gr.Column(): - # german = gr.Textbox(label="German Text") - - # translate_btn.click(translate, inputs=english, outputs=german, api_name="translate-to-german") - # examples = gr.Examples(examples=["I went to the supermarket yesterday.", "Helen is a good swimmer."], - # inputs=[english]) + session = boto3.Session() + s3 = session.client('s3') + def list_objects(bucket,prefix=''): + response = s3.list_objects(Bucket=bucket, Prefix=prefix) + objects = response['Contents'] if response.get('Contents') else [] + return [obj['Key'] for obj in objects] + + def image_viewer(path,cols_width,request:gr.Request): + dirs = path.replace('s3://','').split('/') + prefix = '/'.join(dirs[1:]) + bucket = dirs[0] + objects = list_objects(bucket,prefix) + image_url = [] + for object_key in objects: + if object_key.endswith('.jpg') or object_key.endswith('.jpeg') or object_key.endswith('.png'): + image_url.append([s3.generate_presigned_url('get_object', Params={ + 'Bucket': bucket, 'Key': object_key}, ExpiresIn=3600),object_key.split('/')[-1]]) + image_tags = "" + for image,key in image_url: + image_tags += f"
    <div><img src='{image}' style='width:100%'/><p>{key}</p></div>
    " + div = f"
    <div style='display:grid;grid-template-columns:repeat({cols_width},1fr);gap:8px'>{image_tags}</div>
    " + return div + + with gr.Blocks(analytics_enabled=False) as imagesviewer_interface: + with gr.Row(): + with gr.Column(scale=3): + images_s3_path = gr.Textbox(label="Input S3 path of images",value = get_default_sagemaker_bucket()+'/output-images') + with gr.Column(scale=1): + cols_width = gr.Slider(minimum=4, maximum=20, step=1, label="columns width", value=8) + with gr.Column(scale=1): + images_s3_path_btn = gr.Button(value="Submit",variant='primary') + with gr.Row(): + result = gr.HTML("
    ") + images_s3_path_btn.click(fn=image_viewer, inputs=[images_s3_path,cols_width], outputs=[result]) + + + ## end with gr.Blocks(analytics_enabled=False) as pnginfo_interface: with gr.Row().style(equal_height=False): with gr.Column(variant='panel'): @@ -1526,7 +1550,6 @@ def update_orig(image, state): fn=modules.extras.clear_cache, inputs=[], outputs=[] ) - if not cmd_opts.pureui: with gr.Blocks(analytics_enabled=False) as modelmerger_interface: with gr.Row().style(equal_height=False): @@ -1577,7 +1600,7 @@ def upload_to_s3(imgs,request : gr.Request): print(e) return e - return f"{len(imgs)} images uploaded to S3 folder:{bucket_name}/{folder_name}" + return f"{len(imgs)} images uploaded to S3 folder:s3://{bucket_name}/{folder_name}" with gr.Tab(label="Upload Train Images to S3"): upload_files = gr.Files(label="Files") @@ -2215,7 +2238,7 @@ def save_userdata(user_dataframe, request: gr.Request): # interfaces += script_callbacks.ui_tabs_callback() interfaces += [(settings_interface, "Settings", "settings")] - # interfaces += [(imagesviewer_interface,"Images Viewer","imagesviewer")] + interfaces += [(imagesviewer_interface,"Images Viewer","imagesviewer")] extensions_interface = ui_extensions.create_ui() interfaces += [(extensions_interface, "Extensions", "extensions")] From 78afa458a6d140dcc106db84ea96c311d715cc2d Mon Sep 17 00:00:00 2001 From: xie river Date: Wed, 12 Apr 2023 14:05:58 +0000 Subject: [PATCH 23/31] fix model uploading s3 path --- webui.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/webui.py b/webui.py index bbf4b8d2353..e5f9fdc74fd 100644 --- a/webui.py +++ b/webui.py @@ -1137,17 +1137,17 @@ def train(): print('Uploading SD Models...') if db_config.v2: upload_s3files( - f'{sd_models_s3uri}/{username}/', + f'{sd_models_s3uri}{username}/', os.path.join(sd_models_dir, db_model_name, f'{db_model_name}_*.yaml') ) if db_config.save_safetensors: upload_s3files( - f'{sd_models_s3uri}/{username}/', + f'{sd_models_s3uri}{username}/', os.path.join(sd_models_dir, db_model_name, f'{db_model_name}_*.safetensors') ) else: upload_s3files( - f'{sd_models_s3uri}/{username}/', + f'{sd_models_s3uri}{username}/', os.path.join(sd_models_dir, db_model_name, f'{db_model_name}_*.ckpt') ) print('Uploading DB Models...') @@ -1158,7 +1158,7 @@ def train(): if db_config.use_lora: print('Uploading Lora Models...') upload_s3files( - f'{lora_models_s3uri}/{username}/', + f'{lora_models_s3uri}{username}/', os.path.join(lora_model_dir, f'{db_model_name}_*.pt') ) #automatic tar latest checkpoint and upload to s3 by zheng on 2023.03.22 From dd88ce652d627b2a60f9bc393e1d868462bac438 Mon Sep 17 00:00:00 2001 From: xie river Date: Thu, 13 Apr 2023 08:22:36 +0000 Subject: [PATCH 24/31] fix s3 path error --- localizations/zh_CN.json | 1 + modules/ui.py | 11 ++++++++--- webui.py | 6 ++++-- 3 files changed, 13 insertions(+), 5 deletions(-) diff --git a/localizations/zh_CN.json b/localizations/zh_CN.json index 674b3505d36..83f1f64e052 100644 --- a/localizations/zh_CN.json +++ b/localizations/zh_CN.json @@ -852,5 +852,6 @@ "Input S3 path of images":"输入图片的S3路径", "Submit":"确定", "columns width":"每行图片列数", + "Show current user's images only":"只显示当前用户图片集", "--------": "--------" } diff --git a/modules/ui.py b/modules/ui.py index 1b8bf48587b..ef18c9fe776 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -735,7 +735,10 @@ def list_objects(bucket,prefix=''): objects = response['Contents'] if response.get('Contents') else [] return [obj['Key'] for obj in objects] - def 
image_viewer(path,cols_width,request:gr.Request): + def image_viewer(path,cols_width,current_only,request:gr.Request): + if current_only: + username = get_webui_username(request) + path = path+'/'+username dirs = path.replace('s3://','').split('/') prefix = '/'.join(dirs[1:]) bucket = dirs[0] @@ -755,14 +758,16 @@ def image_viewer(path,cols_width,request:gr.Request): with gr.Blocks(analytics_enabled=False) as imagesviewer_interface: with gr.Row(): with gr.Column(scale=3): - images_s3_path = gr.Textbox(label="Input S3 path of images",value = get_default_sagemaker_bucket()+'/output-images') + images_s3_path = gr.Textbox(label="Input S3 path of images",value = get_default_sagemaker_bucket()+'/stable-diffusion-webui/generated') + with gr.Column(scale=1): + show_user_only = gr.Checkbox(label="Show current user's images only", value=True) with gr.Column(scale=1): cols_width = gr.Slider(minimum=4, maximum=20, step=1, label="columns width", value=8) with gr.Column(scale=1): images_s3_path_btn = gr.Button(value="Submit",variant='primary') with gr.Row(): result = gr.HTML("
    ") - images_s3_path_btn.click(fn=image_viewer, inputs=[images_s3_path,cols_width], outputs=[result]) + images_s3_path_btn.click(fn=image_viewer, inputs=[images_s3_path,cols_width,show_user_only], outputs=[result]) ## end diff --git a/webui.py b/webui.py index e5f9fdc74fd..0bf375ea931 100644 --- a/webui.py +++ b/webui.py @@ -242,6 +242,7 @@ def check_space_s3_download(s3_client,bucket_name,s3_folder,local_folder,file,si # s3_client = boto3.client('s3') src = s3_folder + '/' + file dist = os.path.join(local_folder, file) + os.makedirs(os.path.dirname(dist), exist_ok=True) # Get disk usage statistics disk_usage = psutil.disk_usage('/tmp') freespace = disk_usage.free/(1024**3) @@ -508,11 +509,12 @@ def register_sd_models(sd_models_dir): for file in get_models(sd_models_dir, ['*.ckpt', '*.safetensors']): hash = modules.sd_models.model_hash(file) item = {} - item['model_name'] = os.path.basename(file) + ##remove the prefix, but remain the sub dir path in model_name. eg. 'river/jp-style-girl-2_200_lora.safetensors [4d3a456f]' + item['model_name'] = file.replace("/tmp/models/Stable-diffusion/",'') item['hash'] = hash item['filename'] = file item['config'] = '/opt/ml/code/stable-diffusion-webui/repositories/stable-diffusion/configs/stable-diffusion/v1-inference.yaml' - item['title'] = '{0} [{1}]'.format(os.path.basename(file), hash) + item['title'] = '{0} [{1}]'.format(file.replace("/tmp/models/Stable-diffusion/",''), hash) item['endpoint_name'] = endpoint_name items.append(item) inputs = { From b6105fd7805cc077ef9acd747ba399e7b392cc73 Mon Sep 17 00:00:00 2001 From: xieyongliang Date: Thu, 13 Apr 2023 19:55:37 +0800 Subject: [PATCH 25/31] Update webui.py --- webui.py | 1 - 1 file changed, 1 deletion(-) diff --git a/webui.py b/webui.py index 0bf375ea931..404bed535aa 100644 --- a/webui.py +++ b/webui.py @@ -509,7 +509,6 @@ def register_sd_models(sd_models_dir): for file in get_models(sd_models_dir, ['*.ckpt', '*.safetensors']): hash = modules.sd_models.model_hash(file) item = {} - ##remove the prefix, but remain the sub dir path in model_name. eg. 
'river/jp-style-girl-2_200_lora.safetensors [4d3a456f]' item['model_name'] = file.replace("/tmp/models/Stable-diffusion/",'') item['hash'] = hash item['filename'] = file From 823e510f9928de76da2bcc84631a6987b342dafb Mon Sep 17 00:00:00 2001 From: xieyongliang Date: Thu, 13 Apr 2023 20:08:24 +0800 Subject: [PATCH 26/31] update webui.py --- webui.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/webui.py b/webui.py index 404bed535aa..c547bd48286 100644 --- a/webui.py +++ b/webui.py @@ -68,6 +68,8 @@ s3_resource= boto3.resource('s3') def s3_download(s3uri, path): + global cache + pos = s3uri.find('/', 5) bucket = s3uri[5 : pos] key = s3uri[pos + 1 : ] From 19c8ccb5c4df3cc4ee2f682034bcec86e130eea5 Mon Sep 17 00:00:00 2001 From: xieyongliang Date: Thu, 13 Apr 2023 20:58:09 +0800 Subject: [PATCH 27/31] update webui.py --- webui.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/webui.py b/webui.py index c547bd48286..299e2edc492 100644 --- a/webui.py +++ b/webui.py @@ -81,8 +81,6 @@ def s3_download(s3uri, path): cache = json.load(open('cache', 'r')) for obj in objs: - if obj.key == key: - continue response = s3_client.head_object( Bucket = bucket, Key = obj.key From 59abb69d218b6e69c305ffef2664c45b9fdb00b7 Mon Sep 17 00:00:00 2001 From: xieyongliang Date: Thu, 13 Apr 2023 21:58:33 +0800 Subject: [PATCH 28/31] update shared.py and ui.py --- modules/shared.py | 28 ++++++++++++++++++++++++++++ modules/ui.py | 25 +++++++++++++++++++++++++ 2 files changed, 53 insertions(+) diff --git a/modules/shared.py b/modules/shared.py index 2edf3116998..be090bab00f 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -404,11 +404,18 @@ def list_samplers(): sagemaker_endpoints = [] +sd_models = [] + def list_sagemaker_endpoints(): global sagemaker_endpoints return sagemaker_endpoints +def list_sd_models(): + global sd_models + + return sd_models + def intersection(lst1, lst2): set1 = set(lst1) set2 = set(lst2) @@ -454,7 +461,28 @@ def refresh_sagemaker_endpoints(username): return sagemaker_endpoints +def refresh_sd_models(username): + global api_endpoint, sd_models + + names = set() + + if not username: + return sd_models + + params = { + 'module': 'sd_models' + } + params['username'] = username + + response = requests.get(url=f'{api_endpoint}/sd/models', params=params) + if response.status_code == 200: + model_list = json.loads(response.text) + for model in model_list: + names.add(model) + + sd_models = list(names) + return sd_models options_templates.update(options_section(('sd', "Stable Diffusion"), { # "models_s3_bucket": OptionInfo(f'{get_default_sagemaker_bucket()}/stable-diffusion-webui/models/', "S3 path for downloading model files (E.g, s3://bucket-name/models/)", ), diff --git a/modules/ui.py b/modules/ui.py index ef18c9fe776..398e6ea050d 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -597,6 +597,24 @@ def refresh_sagemaker_endpoints(request : gr.Request): return gr.update(**(args or {})) + def refresh_sd_models(request: gr.Request): + tokens = shared.demo.server_app.tokens + cookies = request.headers['cookie'].split('; ') + access_token = None + for cookie in cookies: + if cookie.startswith('access-token'): + access_token = cookie[len('access-token=') : ] + break + username = tokens[access_token] if access_token else None + + refresh_method(username) + args = refreshed_args() if callable(refreshed_args) else refreshed_args + + for k, v in args.items(): + setattr(refresh_component, k, v) + + return gr.update(**(args or {})) + def refresh_checkpoints(sagemaker_endpoint): 
refresh_method(sagemaker_endpoint) args = refreshed_args() if callable(refreshed_args) else refreshed_args @@ -613,6 +631,12 @@ def refresh_checkpoints(sagemaker_endpoint): inputs=[], outputs=[refresh_component] ) + elif elem_id == 'refresh_sd_models': + refresh_button.click( + fn=refresh_sd_models, + inputs=[], + outputs=[refresh_component] + ) elif elem_id == 'refresh_sd_model_checkpoint': refresh_button.click( fn=refresh_checkpoints, @@ -2345,6 +2369,7 @@ def demo_load(request: gr.Request): except Exception as e: print(e) shared.refresh_sagemaker_endpoints(username) + shared.refresh_sd_models(username) shared.refresh_checkpoints(shared.opts.sagemaker_endpoint) additional_components = [gr.update(value=username), gr.update(), gr.update(value=shared.opts.sagemaker_endpoint, choices=shared.sagemaker_endpoints), gr.update(value=shared.opts.sd_model_checkpoint, choices=modules.sd_models.checkpoint_tiles())] else: From 4470feb346afe26c45f375f92b9edebc9a6a16f3 Mon Sep 17 00:00:00 2001 From: xieyongliang Date: Fri, 14 Apr 2023 14:38:46 +0800 Subject: [PATCH 29/31] add support for load VAE models when endpoint is created --- modules/api/api.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/modules/api/api.py b/modules/api/api.py index 971424ac8a0..2aa1f20cbb9 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -30,6 +30,7 @@ import piexif.helper import numpy as np import uuid +import modules.sd_vae def upscaler_to_index(name: str): try: @@ -438,9 +439,13 @@ def invocations(self, req: InvocationsRequest): if response.status_code == 200 and response.text != '': try: data = json.loads(response.text) + sd_model_checkpoint = shared.opts.sd_model_checkpoint shared.opts.data = json.loads(data['options']) + modules.sd_vae.refresh_vae_list() with self.queue_lock: sd_models.reload_model_weights() + if sd_model_checkpoint == shared.opts.sd_model_checkpoint: + modules.sd_vae.reload_vae_weights() except Exception as e: print(e) From e19ab3cd97b2fa926928c8f92f1488f5b7bab531 Mon Sep 17 00:00:00 2001 From: Jianyu Zhan Date: Thu, 6 Apr 2023 17:30:37 +0800 Subject: [PATCH 30/31] Added support for checkpoint merge from S3 --- localizations/zh_CN.json | 5 + modules/call_queue.py | 2 +- modules/model_merger.py | 293 +++++++++++++++++++++++++++++++++++++++ modules/shared.py | 8 +- modules/ui.py | 129 +++++++++-------- webui.py | 2 +- 6 files changed, 377 insertions(+), 62 deletions(-) create mode 100644 modules/model_merger.py diff --git a/localizations/zh_CN.json b/localizations/zh_CN.json index 83f1f64e052..9f5b24c15b6 100644 --- a/localizations/zh_CN.json +++ b/localizations/zh_CN.json @@ -210,6 +210,11 @@ "A merger of the two checkpoints will be generated in your": "合并后的模型(ckpt)会生成在你的", "checkpoint": "模型(ckpt)", "directory.": "目录", + "Merged checkpoints will be put in the specified output S3 location": "合并后的模型(ckpt)会生成在指定的S3目录", + "Checkpoint S3 URI": "用于合并的模型所在S3目录", + "Load Checkpoints": "加载模型(ckpt)", + "Merge Result S3 URI": "合并后的模型存放的S3目录", + "If not specified, will put into ": "如果没有指定,将会放到", "Primary model (A)": "主要模型 (A)", "Secondary model (B)": "第二模型 (B)", "Tertiary model (C)": "第三模型 (C)", diff --git a/modules/call_queue.py b/modules/call_queue.py index fce0938cedb..50c7f2e1a07 100644 --- a/modules/call_queue.py +++ b/modules/call_queue.py @@ -436,7 +436,7 @@ def f(username, *args, **kwargs): def wrap_gradio_call(func, extra_outputs=None, add_stats=False): def f(request: gr.Request, *args, extra_outputs_array=extra_outputs, **kwargs): tokens = shared.demo.server_app.tokens - cookies = 
request.headers['cookie'].split('; ') + cookies = shared.get_cookies(request) access_token = None for cookie in cookies: if cookie.startswith('access-token'): diff --git a/modules/model_merger.py b/modules/model_merger.py new file mode 100644 index 00000000000..422fb9c7cac --- /dev/null +++ b/modules/model_merger.py @@ -0,0 +1,293 @@ +from __future__ import annotations +from datetime import datetime, timedelta +import pytz +import json +import gradio as gr +import os +import re +import requests +import sys +import threading + +from modules import shared + +input_chkpt_s3uri = '' +s3_checkpoints = [] +s3_uri_pattern = re.compile(r"^s3://[\w\-\.]+/[\w\-\.\/]+$") + +job_rwlock = threading.RLock() +processing_jobs = {} +last_processing_output_msg = '' + +def get_processing_jobs(): + global job_rwlock + global processing_jobs + + copy = {} + with job_rwlock: + copy = processing_jobs.copy() + return copy + +def add_processing_job(job_name, output_loc): + global job_rwlock + global processing_jobs + + with job_rwlock: + processing_jobs[job_name] = output_loc + +def delete_processing_job(job_name): + global job_rwlock + global processing_jobs + + with job_rwlock: + if job_name in processing_jobs: + del processing_jobs[job_name] + +def get_last_processing_output_message(): + global job_rwlock + global last_processing_output_msg + + last_msg = '' + with job_rwlock: + last_msg = last_processing_output_msg + return last_msg + +def set_last_processing_output_message(msg): + global job_rwlock + global last_processing_output_msg + + with job_rwlock: + last_processing_output_msg = msg + +time_fmt = '%Y-%m-%d-%H-%M-%S-UTC' +job_fmt = f'model-merge-{time_fmt}' + +def uniq_job_name(): + # Valid job name must start with a letter or number ([a-zA-Z0-9]) and can contain up to 63 characters, including hyphens (-). 
+ global time_fmt + global job_fmt + import pytz + + now_utc = datetime.now(pytz.utc) + current_time_str = now_utc.strftime(time_fmt) + job_name = f'model-merge-{current_time_str}' + return job_name + +def get_job_elapsed_time(job_name): + global job_fmt + + timestamp_utc = None + try: + timestamp_utc = datetime.strptime(job_name, job_fmt).replace(tzinfo=pytz.utc) + except ValueError: + print(f"Error: input string {job_name} does not match format: {job_fmt}.") + + if timestamp_utc is None: + return None + + now_utc = datetime.now(pytz.utc) + time_diff = now_utc - timestamp_utc + return time_diff + +def readable_time_diff(time_diff): + total_seconds = int(time_diff.total_seconds()) + + hours, remainder = divmod(total_seconds, 3600) + minutes, seconds = divmod(remainder, 60) + + if hours > 0: + time_str = f"{hours} hours, {minutes} minutes, {seconds} seconds" + elif minutes > 0: + time_str = f"{minutes} minutes, {seconds} seconds" + else: + time_str = f"{seconds} seconds" + + return time_str + +def is_valid_s3_uri(s3_uri): + global s3_uri_pattern + match = s3_uri_pattern.match(s3_uri) + return bool(match) + +def load_checkpoints_from_s3_uri(s3_uri, primary_component, + secondary_component, tertiary_component): + global input_chkpt_s3uri + global s3_checkpoints + + if not is_valid_s3_uri(s3_uri): + return + + input_chkpt_s3uri = s3_uri.rstrip('/') + + s3_checkpoints.clear() + + params = { + 's3uri': input_chkpt_s3uri, + 'exclude_filters': 'yaml', + } + response = requests.get(url=f'{shared.api_endpoint}/s3', params = params) + if response.status_code != 200: + return + + text = json.loads(response.text) + for obj in text['payload']: + obj_key = obj['key'] + ckpt = obj_key.split('/')[-1] + s3_checkpoints.append(ckpt) + + return [gr.Dropdown.update(choices=s3_checkpoints) for _ in range(3)] + +def get_checkpoints_to_merge(): + global s3_checkpoints + return s3_checkpoints + +def get_chkpt_name(checkpoint_file): + name = os.path.basename(checkpoint_file) + if name.startswith("\\") or name.startswith("/"): + name = name[1:] + + chkpt_name = os.path.splitext(name.replace("/", "_").replace("\\", "_"))[0] + return chkpt_name + +def get_merged_chkpt_name(primary_model_name, secondary_model_name, + tertiary_model_name, multiplier, interp_method, + checkpoint_format, custom_name): + filename = get_chkpt_name(primary_model_name) + '_' + \ + str(round(1-multiplier, 2)) + '-' + \ + get_chkpt_name(secondary_model_name) + '_' + \ + str(round(multiplier, 2)) + '-' + + if isinstance(tertiary_model_name, str) and tertiary_model_name != '': + filename += get_chkpt_name(tertiary_model_name) + '-' + + filename += interp_method.replace(" ", "_") + '-merged.' + checkpoint_format + filename = filename if custom_name == '' else (custom_name + '.' 
+ checkpoint_format) + return filename + +def get_processing_job_status(): + job_dict = get_processing_jobs() + if len(job_dict) == 0: + print("No jobs running yet.") + return get_last_processing_output_message() + + ret_message = '' + for job_name, job_output_loc in job_dict.items(): + inputs = {'job_name': job_name} + response = requests.get(url=f'{shared.api_endpoint}/process', json=inputs) + + if response.status_code != 200: + ret_message += f"Processing job {job_name}:\tjob status unknown\n" + continue + + job_elapsed_time = get_job_elapsed_time(job_name) + job_elapsed_timestr = f"Time elapsed: {readable_time_diff(job_elapsed_time)}" \ + if job_elapsed_time is not None else '' + + text = json.loads(response.text) + job_status = text['job_status'] + shall_delete = False + if job_status == 'Completed': + msg = f"finished successfully. Output: {job_output_loc}. {job_elapsed_timestr}" + shall_delete = True + elif job_status == 'Failed': + msg = f"failed: {text['failure_reason']}. {job_elapsed_timestr}" + shall_delete = True + else: + msg = f"still in progress. {job_elapsed_timestr}" + + ret_message += f"Processing job {job_name}:\t{msg}\n" + print(f"Processing job {job_name}: {msg}") + + if shall_delete or (job_elapsed_time and job_elapsed_time > timedelta(hours=1)): + print(f"Romving processing job '{job_name}', job_staus: {job_status}. {job_elapsed_timestr}") + delete_processing_job(job_name) + + if ret_message == '': + ret_message = get_last_processing_output_message() + else: + set_last_processing_output_message(ret_message) + + return ret_message + +def get_default_output_model_s3uri(): + s3uri = shared.get_default_sagemaker_bucket() + \ + '/stable-diffusion-webui/models/Stable-diffusion' + return s3uri + +def run_modelmerger_remote(primary_model_name, secondary_model_name, + tertiary_model_name, interp_method, multiplier, + save_as_half, custom_name, checkpoint_format, + output_chkpt_s3uri, submit_result): + """ This is the same as run_modelmerger, but it calls a RESTful API to do the job """ + if isinstance(primary_model_name, list) or \ + isinstance(secondary_model_name, list): + ret_msg = "At least primary_model_name and secondary_model_name must be set." 
+ set_last_processing_output_message(ret_msg) + return reg_msg + + if output_chkpt_s3uri != '' and not is_valid_s3_uri(output_chkpt_s3uri): + ret_msg = f"output_chkpt_s3uri is not valid: {output_chkpt_s3uri}" + set_last_processing_output_message(ret_msg) + return reg_msg + + input_srcs = f"{input_chkpt_s3uri}/{primary_model_name}," + \ + f"{input_chkpt_s3uri}/{secondary_model_name}" + input_dsts = f"/opt/ml/processing/input/primary," + \ + f"/opt/ml/processing/input/secondary" + + if is_valid_s3_uri(output_chkpt_s3uri): + output_dst = output_chkpt_s3uri + else: + output_dst = get_default_output_model_s3uri() + output_name = get_merged_chkpt_name(primary_model_name, secondary_model_name, + tertiary_model_name, multiplier, interp_method, + checkpoint_format, custom_name) + # Make an argument dict to be accessible in the process script + args = { + "primary_model": primary_model_name, + "secondary_model": secondary_model_name, + "interp_method": interp_method, + "multiplier": multiplier, + "save_as_half": save_as_half, + "checkpoint_format": checkpoint_format, + 'output_destination': output_dst, + 'output_name': output_name, + } + + if custom_name != '': + args["custom_name"] = custom_name + + if isinstance(tertiary_model_name, str) and tertiary_model_name != '': + input_srcs += f",{input_chkpt_s3uri}/{tertiary_model_name}" + input_dsts += f",/opt/ml/processing/input/tertiary" + args["tertiary_model"] = tertiary_model_name + + inputs = { + 'instance_type': 'ml.m5.4xlarge', # Memory intensive + 'instance_count': 1, + 'process_script': 'process_checkpoint_merge.py', + 'input_sources': input_srcs, + 'input_destination': input_dsts, + 'output_sources': '/opt/ml/processing/output', + 'output_destination': output_dst, + 'output_name': output_name, + 'job_name': uniq_job_name(), + 'arguments': args + } + + response = requests.post(url=f'{shared.api_endpoint}/process', json=inputs) + if response.status_code != 200: + ret_msg = f"Failed to run model merge process job: {response.text}" + set_last_processing_output_message(ret_msg) + return ret_msg + + text = json.loads(response.text) + job_name = text['job_name'] + + # Add the job to the list for later status poll + add_processing_job(job_name, f"{output_dst}/{output_name}") + + ret_msg = f"Merging models in Sagemaker Processing Job...\nJob Name: {job_name}" + set_last_processing_output_message(ret_msg) + + return ret_msg diff --git a/modules/shared.py b/modules/shared.py index be090bab00f..2283cfced12 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -827,4 +827,10 @@ def html(filename): with open(path, encoding="utf8") as file: return file.read() - return "" \ No newline at end of file + return "" + +def get_cookies(request): + # request.headers is of type Gradio.queue.Obj, can't be subscripted + # directly, so we need to retrieve its underlying dict first. 
+ cookies = request.headers.__dict__['cookie'].split('; ') + return cookies diff --git a/modules/ui.py b/modules/ui.py index 398e6ea050d..4a54610d92c 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -34,6 +34,7 @@ import modules.shared as shared import modules.styles import modules.textual_inversion.ui +import modules.model_merger from modules import prompt_parser from modules.images import save_image from modules.sd_hijack import model_hijack @@ -568,7 +569,6 @@ def update_generation_info(args): # if the json parse or anything else fails, just return the old html_info return html_info - def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_id): def refresh(): refresh_method() @@ -581,7 +581,7 @@ def refresh(): def refresh_sagemaker_endpoints(request : gr.Request): tokens = shared.demo.server_app.tokens - cookies = request.headers['cookie'].split('; ') + cookies = shared.get_cookies(request) access_token = None for cookie in cookies: if cookie.startswith('access-token'): @@ -903,7 +903,7 @@ def run_settings(username, *args): def run_settings_single(value, key, request : gr.Request): tokens = shared.demo.server_app.tokens - cookies = request.headers['cookie'].split('; ') + cookies = shared.get_cookies(request) access_token = None for cookie in cookies: if cookie.startswith('access-token'): @@ -1579,28 +1579,43 @@ def update_orig(image, state): fn=modules.extras.clear_cache, inputs=[], outputs=[] ) - if not cmd_opts.pureui: - with gr.Blocks(analytics_enabled=False) as modelmerger_interface: - with gr.Row().style(equal_height=False): - with gr.Column(variant='panel'): - gr.HTML(value="
    <p style='margin-bottom: 2.5em'>A merger of the two checkpoints will be generated in your <b>checkpoint</b> directory.</p>
    ") - with gr.Row(): - primary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_primary_model_name", label="Primary model (A)") - secondary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_secondary_model_name", label="Secondary model (B)") - tertiary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_tertiary_model_name", label="Tertiary model (C)") - custom_name = gr.Textbox(label="Custom Name (Optional)") - interp_amount = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, label='Multiplier (M) - set to 0 to get model A', value=0.3) - interp_method = gr.Radio(choices=["Weighted sum", "Add difference"], value="Weighted sum", label="Interpolation Method") + with gr.Blocks(analytics_enabled=False) as modelmerger_interface: + with gr.Row().style(equal_height=False): + with gr.Column(variant='panel'): + gr.HTML(value="
    <p style='margin-bottom: 2.5em'>Merged checkpoints will be put in the specified output S3 location</p>
    ") - with gr.Row(): - checkpoint_format = gr.Radio(choices=["ckpt", "safetensors"], value="ckpt", label="Checkpoint format") - save_as_half = gr.Checkbox(value=False, label="Save as float16") + with gr.Row(): + chkpt_s3uri = gr.Textbox(label="Checkpoint S3 URI", placeholder='s3://bucket/stable-diffusion-webui/models/') + chkpt_s3uri_button = gr.Button(value="Load Checkpoints", elem_id="checkpt_s3uri", variant='primary') + merge_output_s3uri = gr.Textbox(label="Merge Result S3 URI", placeholder="If not specified, will put into " + modules.model_merger.get_default_output_model_s3uri()) + + with gr.Row(): + primary_model_name = gr.Dropdown(modules.model_merger.get_checkpoints_to_merge(), elem_id="modelmerger_primary_model_name", label="Primary model (A)") + secondary_model_name = gr.Dropdown(modules.model_merger.get_checkpoints_to_merge(), elem_id="modelmerger_secondary_model_name", label="Secondary model (B)") + tertiary_model_name = gr.Dropdown(modules.model_merger.get_checkpoints_to_merge(), elem_id="modelmerger_tertiary_model_name", label="Tertiary model (C)") + custom_name = gr.Textbox(label="Custom Name (Optional)") + interp_amount = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, label='Multiplier (M) - set to 0 to get model A', value=0.3) + interp_method = gr.Radio(choices=["Weighted sum", "Add difference"], value="Weighted sum", label="Interpolation Method") + + chkpt_s3uri_button.click( + fn=modules.model_merger.load_checkpoints_from_s3_uri, + inputs=[chkpt_s3uri], + outputs=[primary_model_name, secondary_model_name, tertiary_model_name]) - modelmerger_merge = gr.Button(elem_id="modelmerger_merge", label="Merge", variant='primary') + with gr.Row(): + checkpoint_format = gr.Radio(choices=["ckpt", "safetensors"], value="ckpt", label="Checkpoint format") + save_as_half = gr.Checkbox(value=False, label="Save as float16") + + modelmerger_merge = gr.Button(elem_id="modelmerger_merge", label="Merge", variant='primary') - with gr.Column(variant='panel'): - submit_result = gr.Textbox(elem_id="modelmerger_result", show_label=False) + with gr.Column(variant='panel'): + submit_result = gr.Textbox(elem_id="modelmerger_result", show_label=False) + + # A periodic function to check the submit output + modelmerger_interface.load(modules.model_merger.get_processing_job_status, + inputs=None, outputs=submit_result, + every=10, queue=True) with gr.Blocks(analytics_enabled=False) as train_interface: with gr.Row().style(equal_height=False): @@ -1858,7 +1873,7 @@ def sagemaker_train_embedding( ): tokens = shared.demo.server_app.tokens - cookies = request.headers['cookie'].split('; ') + cookies = shared.get_cookies(request) access_token = None for cookie in cookies: if cookie.startswith('access-token'): @@ -1991,7 +2006,7 @@ def sagemaker_train_hypernetwork( ): tokens = shared.demo.server_app.tokens - cookies = request.headers['cookie'].split('; ') + cookies = shared.get_cookies(request) access_token = None for cookie in cookies: if cookie.startswith('access-token'): @@ -2193,7 +2208,7 @@ def sagemaker_train_hypernetwork( def save_userdata(user_dataframe, request: gr.Request): tokens = shared.demo.server_app.tokens - cookies = request.headers['cookie'].split('; ') + cookies = shared.get_cookies(request) access_token = None for cookie in cookies: if cookie.startswith('access-token'): @@ -2236,6 +2251,7 @@ def save_userdata(user_dataframe, request: gr.Request): (img2img_interface, "img2img", "img2img"), (extras_interface, "Extras", "extras"), (pnginfo_interface, "PNG Info", "pnginfo"), + 
(modelmerger_interface, "Checkpoint Merger", "modelmerger"), (train_interface, "Train", "ti"), (user_interface, "User", "user") ] @@ -2294,7 +2310,7 @@ def save_userdata(user_dataframe, request: gr.Request): def user_logout(request: gr.Request): tokens = shared.demo.server_app.tokens - cookies = request.headers['cookie'].split('; ') + cookies = shared.get_cookies(request) access_token = None for cookie in cookies: if cookie.startswith('access-token'): @@ -2341,7 +2357,7 @@ def user_logout(request: gr.Request): def demo_load(request: gr.Request): tokens = shared.demo.server_app.tokens - cookies = request.headers['cookie'].split('; ') + cookies = shared.get_cookies(request) access_token = None for cookie in cookies: if cookie.startswith('access-token'): @@ -2383,37 +2399,33 @@ def demo_load(request: gr.Request): outputs=[component_dict[k] for k in component_keys] + [username_state, user_dataframe, shared.sagemaker_endpoint_component, shared.sd_model_checkpoint_component] ) - if not cmd_opts.pureui: - def modelmerger(*args): - try: - results = modules.extras.run_modelmerger(*args) - except Exception as e: - print("Error loading/saving model file:", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) - modules.sd_models.list_models() # to remove the potentially missing models from the list - return ["Error loading/saving model file. It doesn't exist or the name contains illegal characters"] + [gr.Dropdown.update(choices=modules.sd_models.checkpoint_tiles()) for _ in range(3)] - return results - - modelmerger_merge.click( - fn=modelmerger, - inputs=[ - primary_model_name, - secondary_model_name, - tertiary_model_name, - interp_method, - interp_amount, - save_as_half, - custom_name, - checkpoint_format, - ], - outputs=[ - submit_result, - primary_model_name, - secondary_model_name, - tertiary_model_name, - component_dict['sd_model_checkpoint'], - ] - ) + def modelmerger(*args): + try: + results = modules.model_merger.run_modelmerger_remote(*args) + except Exception as e: + print("Error loading/saving model file:", file=sys.stderr) + print(traceback.format_exc(), file=sys.stderr) + return "Error running model merge" + return results + + modelmerger_merge.click( + fn=modelmerger, + inputs=[ + primary_model_name, + secondary_model_name, + tertiary_model_name, + interp_method, + interp_amount, + save_as_half, + custom_name, + checkpoint_format, + merge_output_s3uri, + submit_result, + ], + outputs=[ + submit_result, + ] + ) ui_config_file = cmd_opts.ui_config_file ui_settings = {} @@ -2479,8 +2491,7 @@ def apply_field(obj, field, condition=None, init_field=None): visit(txt2img_interface, loadsave, "txt2img") visit(img2img_interface, loadsave, "img2img") visit(extras_interface, loadsave, "extras") - if not cmd_opts.pureui: - visit(modelmerger_interface, loadsave, "modelmerger") + visit(modelmerger_interface, loadsave, "modelmerger") if not error_loading and (not os.path.exists(ui_config_file) or settings_count != len(ui_settings)): with open(ui_config_file, "w", encoding="utf8") as file: diff --git a/webui.py b/webui.py index 299e2edc492..76cbfefe295 100644 --- a/webui.py +++ b/webui.py @@ -634,7 +634,7 @@ def webui(): modules.script_callbacks.before_ui_callback() shared.demo = modules.ui.create_ui() - app, local_url, share_url = shared.demo.launch( + app, local_url, share_url = shared.demo.queue(concurrency_count=5, max_size=20).launch( share=cmd_opts.share, server_name=server_name, server_port=cmd_opts.port, From 63233fd6136cd72a0fa46876999100f5657844a5 Mon Sep 17 00:00:00 2001 From: 
Jianyu Zhan Date: Fri, 14 Apr 2023 11:06:48 +0800 Subject: [PATCH 31/31] Fix UI user tab missing problem --- javascript/ui.js | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/javascript/ui.js b/javascript/ui.js index 24ff9e8c335..6765f22d9b1 100644 --- a/javascript/ui.js +++ b/javascript/ui.js @@ -213,8 +213,8 @@ function restart_reload(){ } function login(username) { - var user=gradioApp().querySelector('#tabs').querySelectorAll('button')[5]; - var setting=gradioApp().querySelector('#tabs').querySelectorAll('button')[6]; + var user=gradioApp().querySelector('#tabs').querySelectorAll('button')[6]; + var setting=gradioApp().querySelector('#tabs').querySelectorAll('button')[7]; if(username=='admin'){ user.style.display='block'