Skip to content

Commit

Permalink
Merge pull request AUTOMATIC1111#1 from xiehust/igg
Browse files Browse the repository at this point in the history
merger igg's ckptmerger fix
  • Loading branch information
xiehust committed Apr 17, 2023
2 parents bddb41f + 8832e51 commit a0a907d
Show file tree
Hide file tree
Showing 6 changed files with 97 additions and 46 deletions.
4 changes: 2 additions & 2 deletions javascript/ui.js
Expand Up @@ -213,8 +213,8 @@ function restart_reload(){
}

function login(username) {
var user=gradioApp().querySelector('#tabs').querySelectorAll('button')[6];
var setting=gradioApp().querySelector('#tabs').querySelectorAll('button')[7];
var user=gradioApp().querySelector('#tabs').querySelectorAll('button')[5];
var setting=gradioApp().querySelector('#tabs').querySelectorAll('button')[6];

if(username=='admin'){
user.style.display='block'
Expand Down
2 changes: 2 additions & 0 deletions localizations/zh_CN.json
Expand Up @@ -858,5 +858,7 @@
"Submit":"确定",
"columns width":"每行图片列数",
"Show current user's images only":"只显示当前用户图片集",
"Don't load other user's models":"不加载其他用户目录模型",
"Merge":"合并",
"--------": "--------"
}
44 changes: 30 additions & 14 deletions modules/model_merger.py
Expand Up @@ -109,11 +109,9 @@ def is_valid_s3_uri(s3_uri):
match = s3_uri_pattern.match(s3_uri)
return bool(match)

def load_checkpoints_from_s3_uri(s3_uri, primary_component,
secondary_component, tertiary_component):
def load_checkpoints_from_s3_uri(s3_uri,load_all_user,username):
global input_chkpt_s3uri
global s3_checkpoints

if not is_valid_s3_uri(s3_uri):
return

Expand All @@ -130,11 +128,22 @@ def load_checkpoints_from_s3_uri(s3_uri, primary_component,
return

text = json.loads(response.text)
for obj in text['payload']:
obj_key = obj['key']
ckpt = obj_key.split('/')[-1]
s3_checkpoints.append(ckpt)


if not load_all_user:
for obj in text['payload']:
title = obj['key'].replace('stable-diffusion-webui/models/Stable-diffusion/','')
# ckpt = title.split('/')[-1]
s3_checkpoints.append(title)
else:
for obj in text['payload']:
title = obj['key'].replace('stable-diffusion-webui/models/Stable-diffusion/','')
##filter by username . e.g title: river/jp-style-girl-3_200_lora.safetensors
dir = title.split('/')
if len(dir) > 1:
dir_user = dir[0]
if dir_user != username:
continue
s3_checkpoints.append(title)
return [gr.Dropdown.update(choices=s3_checkpoints) for _ in range(3)]

def get_checkpoints_to_merge():
Expand Down Expand Up @@ -217,31 +226,38 @@ def get_default_output_model_s3uri():
def run_modelmerger_remote(primary_model_name, secondary_model_name,
tertiary_model_name, interp_method, multiplier,
save_as_half, custom_name, checkpoint_format,
output_chkpt_s3uri, submit_result):
output_chkpt_s3uri, submit_result,request):
""" This is the same as run_modelmerger, but it calls a RESTful API to do the job """
if isinstance(primary_model_name, list) or \
isinstance(secondary_model_name, list):
ret_msg = "At least primary_model_name and secondary_model_name must be set."
set_last_processing_output_message(ret_msg)
return reg_msg
return ret_msg

if output_chkpt_s3uri != '' and not is_valid_s3_uri(output_chkpt_s3uri):
ret_msg = f"output_chkpt_s3uri is not valid: {output_chkpt_s3uri}"
set_last_processing_output_message(ret_msg)
return reg_msg
return ret_msg

input_srcs = f"{input_chkpt_s3uri}/{primary_model_name}," + \
f"{input_chkpt_s3uri}/{secondary_model_name}"
input_dsts = f"/opt/ml/processing/input/primary," + \
f"/opt/ml/processing/input/secondary"

username = shared.get_webui_username(request)
if is_valid_s3_uri(output_chkpt_s3uri):
output_dst = output_chkpt_s3uri
if output_chkpt_s3uri[-1] == '/':
output_dst = output_chkpt_s3uri+username
else:
output_dst = output_chkpt_s3uri+'/'+username
else:
output_dst = get_default_output_model_s3uri()
output_dst = get_default_output_model_s3uri()+'/'+username
output_name = get_merged_chkpt_name(primary_model_name, secondary_model_name,
tertiary_model_name, multiplier, interp_method,
checkpoint_format, custom_name)
##outputName' failed to satisfy constraint: Member must have length less than or equal to 64
if len(output_name) > 64:
output_name = output_name[len(output_name)-64:]

# Make an argument dict to be accessible in the process script
args = {
"primary_model": primary_model_name,
Expand Down
10 changes: 9 additions & 1 deletion modules/sd_models.py
Expand Up @@ -57,7 +57,7 @@ def checkpoint_tiles():
return sorted([x.title for x in checkpoints_list.values()], key = alphanumeric_key)


def list_models(sagemaker_endpoint=None):
def list_models(sagemaker_endpoint=None,username=''):
global checkpoints_list

checkpoints_list.clear()
Expand Down Expand Up @@ -90,12 +90,20 @@ def modeltitle(path, shorthash):
model_list = json.loads(response.text)

for model in model_list:

h = model['hash']
filename = model['filename']
title = model['title']
short_model_name = model['model_name']
config = model['config']

##filter by username . e.g title: river/jp-style-girl-3_200_lora.safetensors
dir = title.split('/')
if len(dir) > 1:
dir_user = dir[0]
if dir_user != username:
continue

if 'sd_model_checkpoint' not in shared.opts.data:
shared.opts.data['sd_model_checkpoint'] = title

Expand Down
7 changes: 3 additions & 4 deletions modules/shared.py
Expand Up @@ -387,9 +387,9 @@ def list_checkpoint_tiles():
return modules.sd_models.checkpoint_tiles()


def refresh_checkpoints(sagemaker_endpoint=None):
def refresh_checkpoints(sagemaker_endpoint=None,username=''):
import modules.sd_models
modules.sd_models.list_models(sagemaker_endpoint)
modules.sd_models.list_models(sagemaker_endpoint,username)
checkpoints = modules.sd_models.checkpoints_list
return checkpoints

Expand Down Expand Up @@ -486,7 +486,6 @@ def refresh_sd_models(username):
return sd_models

options_templates.update(options_section(('sd', "Stable Diffusion"), {
# "models_s3_bucket": OptionInfo(f'{get_default_sagemaker_bucket()}/stable-diffusion-webui/models/', "S3 path for downloading model files (E.g, s3://bucket-name/models/)", ),
"sagemaker_endpoint": OptionInfo(None, "SaegMaker endpoint", gr.Dropdown, lambda: {"choices": list_sagemaker_endpoints()}, refresh=refresh_sagemaker_endpoints),
"sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": list_checkpoint_tiles()}, refresh=refresh_checkpoints),
"sd_checkpoint_cache": OptionInfo(0, "Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
Expand Down Expand Up @@ -538,7 +537,7 @@ def refresh_sd_models(username):
}))

options_templates.update(options_section(('saving-paths', "Paths for saving"), {
"train_files_s3bucket":OptionInfo(get_default_sagemaker_bucket(),"S3 bucket name for uploading/downloading images",component_args=hide_dirs),
# "train_files_s3bucket":OptionInfo(get_default_sagemaker_bucket(),"S3 bucket name for uploading/downloading images",component_args=hide_dirs),
"outdir_samples": OptionInfo("", "Output directory for images; if empty, defaults to three directories below", component_args=hide_dirs),
"outdir_txt2img_samples": OptionInfo("outputs/txt2img-images", 'Output directory for txt2img images', component_args=hide_dirs),
"outdir_img2img_samples": OptionInfo("outputs/img2img-images", 'Output directory for img2img images', component_args=hide_dirs),
Expand Down
76 changes: 51 additions & 25 deletions modules/ui.py
Expand Up @@ -25,7 +25,6 @@
from modules.paths import script_path

from modules.shared import opts, cmd_opts, restricted_opts,get_default_sagemaker_bucket

import modules.codeformer_model
import modules.generation_parameters_copypaste as parameters_copypaste
import modules.gfpgan_model
Expand Down Expand Up @@ -590,8 +589,9 @@ def refresh_sd_models(request: gr.Request):

return gr.update(**(args or {}))

def refresh_checkpoints(sagemaker_endpoint):
refresh_method(sagemaker_endpoint)
def refresh_checkpoints(sagemaker_endpoint,request:gr.Request):
username = shared.get_webui_username(request)
refresh_method(sagemaker_endpoint,username)
args = refreshed_args() if callable(refreshed_args) else refreshed_args

for k, v in args.items():
Expand Down Expand Up @@ -757,9 +757,12 @@ def image_viewer(path,cols_width,current_only,request:gr.Request):
with gr.Blocks(analytics_enabled=False) as imagesviewer_interface:
with gr.Row():
with gr.Column(scale=3):
images_s3_path = gr.Textbox(label="Input S3 path of images",value = get_default_sagemaker_bucket()+'/stable-diffusion-webui/generated')
images_s3_path = gr.Textbox(label="Input S3 path of images",visible=False, value = get_default_sagemaker_bucket()+'/stable-diffusion-webui/generated')
dummy_images_s3_path = gr.Textbox(label="Input S3 path of images",visible=True, interactive=False,
value = get_default_sagemaker_bucket()+'/stable-diffusion-webui/generated/{username}')

with gr.Column(scale=1):
show_user_only = gr.Checkbox(label="Show current user's images only", value=True)
show_user_only = gr.Checkbox(label="Show current user's images only", value=True,visible=True,interactive=False)
with gr.Column(scale=1):
cols_width = gr.Slider(minimum=4, maximum=20, step=1, label="columns width", value=8)
with gr.Column(scale=1):
Expand Down Expand Up @@ -919,9 +922,9 @@ def run_settings_single(value, key, request : gr.Request):
with gr.Row():
with gr.Column(scale=4):
models_s3bucket = gr.Textbox(label="S3 path for downloading model files (E.g, s3://bucket-name/models/)",
value=default_s3_path)
value=default_s3_path,visible=False)
with gr.Column(scale=1):
set_models_s3bucket_btn = gr.Button(value="Update model files path",elem_id='id_set_models_s3bucket')
set_models_s3bucket_btn = gr.Button(value="Update model files path",elem_id='id_set_models_s3bucket',visible=False)
with gr.Column(scale=1):
reload_models_btn = gr.Button(value='Reload all models', elem_id='id_reload_all_models')

Expand Down Expand Up @@ -1547,16 +1550,32 @@ def update_orig(image, state):
fn=modules.extras.clear_cache,
inputs=[], outputs=[]
)

def load_checkpoints_from_s3_uri(model_s3url,load_all_user,request:gr.Request):
username = shared.get_webui_username(request)
print(username)
return modules.model_merger.load_checkpoints_from_s3_uri(model_s3url,load_all_user,username)

with gr.Blocks(analytics_enabled=False) as modelmerger_interface:
with gr.Row().style(equal_height=False):
with gr.Column(variant='panel'):
with gr.Column():
gr.HTML(value="<p>Merged checkpoints will be put in the specified output S3 location</p>")

default_ckpt_s3 = get_default_sagemaker_bucket()+'/stable-diffusion-webui/models/Stable-diffusion/'
# default_merge_output_s3 = default_ckpt_s3
with gr.Row():
chkpt_s3uri = gr.Textbox(label="Checkpoint S3 URI", placeholder='s3://bucket/stable-diffusion-webui/models/')
chkpt_s3uri_button = gr.Button(value="Load Checkpoints", elem_id="checkpt_s3uri", variant='primary')
merge_output_s3uri = gr.Textbox(label="Merge Result S3 URI", placeholder="If not specified, will put into " + modules.model_merger.get_default_output_model_s3uri())

dummy_s3uri = gr.Textbox(label="Checkpoint S3 URI", elem_id="dummy_chkpt_s3uri",
value='模型存放位置:'+default_ckpt_s3+'{用户名}',
lines=2, visible=True,interactive=False)
chkpt_s3uri = gr.Textbox(label="Checkpoint S3 URI", elem_id="chkpt_s3uri", value= default_ckpt_s3,lines=2,visible=False)
merge_output_s3uri = gr.Textbox(label="Merge Result S3 URI",lines=2, visible=True,placeholder='(选填),默认输出位置:'+default_ckpt_s3+'{用户名}')
with gr.Row():
with gr.Column(scale=1):
load_all_user = gr.Checkbox(label="Don't load other user's models",interactive=False, value=True,visible=True)
with gr.Column(scale=2):
chkpt_s3uri_button = gr.Button(value="Load Checkpoints", variant='primary')



with gr.Row():
primary_model_name = gr.Dropdown(modules.model_merger.get_checkpoints_to_merge(), elem_id="modelmerger_primary_model_name", label="Primary model (A)")
secondary_model_name = gr.Dropdown(modules.model_merger.get_checkpoints_to_merge(), elem_id="modelmerger_secondary_model_name", label="Secondary model (B)")
Expand All @@ -1565,19 +1584,19 @@ def update_orig(image, state):
interp_amount = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, label='Multiplier (M) - set to 0 to get model A', value=0.3)
interp_method = gr.Radio(choices=["Weighted sum", "Add difference"], value="Weighted sum", label="Interpolation Method")

chkpt_s3uri_button.click(
fn=modules.model_merger.load_checkpoints_from_s3_uri,
inputs=[chkpt_s3uri],
outputs=[primary_model_name, secondary_model_name, tertiary_model_name])

with gr.Row():
checkpoint_format = gr.Radio(choices=["ckpt", "safetensors"], value="ckpt", label="Checkpoint format")
save_as_half = gr.Checkbox(value=False, label="Save as float16")

modelmerger_merge = gr.Button(elem_id="modelmerger_merge", label="Merge", variant='primary')
modelmerger_merge = gr.Button(elem_id="modelmerger_merge", value="Merge", variant='primary')

with gr.Column(variant='panel'):
with gr.Column():
submit_result = gr.Textbox(elem_id="modelmerger_result", show_label=False)
chkpt_s3uri_button.click(
fn=load_checkpoints_from_s3_uri,
inputs=[chkpt_s3uri,load_all_user],
outputs=[primary_model_name, secondary_model_name, tertiary_model_name])

# A periodic function to check the submit output
modelmerger_interface.load(modules.model_merger.get_processing_job_status,
Expand Down Expand Up @@ -2197,7 +2216,7 @@ def save_userdata(user_dataframe, request: gr.Request):
(img2img_interface, "img2img", "img2img"),
(extras_interface, "Extras", "extras"),
(pnginfo_interface, "PNG Info", "pnginfo"),
(modelmerger_interface, "Checkpoint Merger", "modelmerger"),
# (modelmerger_interface, "Checkpoint Merger", "modelmerger"),
(train_interface, "Train", "ti"),
(user_interface, "User", "user")
]
Expand All @@ -2207,7 +2226,7 @@ def save_userdata(user_dataframe, request: gr.Request):
(img2img_interface, "img2img", "img2img"),
(extras_interface, "Extras", "extras"),
(pnginfo_interface, "PNG Info", "pnginfo"),
(modelmerger_interface, "Checkpoint Merger", "modelmerger"),
# (modelmerger_interface, "Checkpoint Merger", "modelmerger"),
(train_interface, "Train", "ti"),
]

Expand All @@ -2229,6 +2248,7 @@ def save_userdata(user_dataframe, request: gr.Request):

# interfaces += script_callbacks.ui_tabs_callback()
interfaces += [(settings_interface, "Settings", "settings")]
interfaces += [(modelmerger_interface,"Checkpoint Merger", "modelmerger")]
interfaces += [(imagesviewer_interface,"Images Viewer","imagesviewer")]

extensions_interface = ui_extensions.create_ui()
Expand Down Expand Up @@ -2325,7 +2345,7 @@ def demo_load(request: gr.Request):
print(e)
shared.refresh_sagemaker_endpoints(username)
shared.refresh_sd_models(username)
shared.refresh_checkpoints(shared.opts.sagemaker_endpoint)
shared.refresh_checkpoints(shared.opts.sagemaker_endpoint,username)
additional_components = [gr.update(value=username), gr.update(), gr.update(value=shared.opts.sagemaker_endpoint, choices=shared.sagemaker_endpoints), gr.update(value=shared.opts.sd_model_checkpoint, choices=modules.sd_models.checkpoint_tiles())]
else:
additional_components = [gr.update(value=username), gr.update(), gr.update(), gr.update()]
Expand All @@ -2338,9 +2358,15 @@ def demo_load(request: gr.Request):
outputs=[component_dict[k] for k in component_keys] + [username_state, user_dataframe, shared.sagemaker_endpoint_component, shared.sd_model_checkpoint_component]
)

def modelmerger(*args):
def modelmerger(primary_model_name, secondary_model_name,
tertiary_model_name, interp_method, multiplier,
save_as_half, custom_name, checkpoint_format,
output_chkpt_s3uri, submit_result,request:gr.Request):
try:
results = modules.model_merger.run_modelmerger_remote(*args)
results = modules.model_merger.run_modelmerger_remote(primary_model_name, secondary_model_name,
tertiary_model_name, interp_method, multiplier,
save_as_half, custom_name, checkpoint_format,
output_chkpt_s3uri, submit_result,request)
except Exception as e:
print("Error loading/saving model file:", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
Expand Down Expand Up @@ -2430,7 +2456,7 @@ def apply_field(obj, field, condition=None, init_field=None):
visit(txt2img_interface, loadsave, "txt2img")
visit(img2img_interface, loadsave, "img2img")
visit(extras_interface, loadsave, "extras")
visit(modelmerger_interface, loadsave, "modelmerger")
# visit(modelmerger_interface, loadsave, "modelmerger")

if not error_loading and (not os.path.exists(ui_config_file) or settings_count != len(ui_settings)):
with open(ui_config_file, "w", encoding="utf8") as file:
Expand Down

0 comments on commit a0a907d

Please sign in to comment.