Skip to content

Commit

Permalink
fix ckptmerger and add user isolation
Browse files Browse the repository at this point in the history
  • Loading branch information
xie river committed Apr 17, 2023
1 parent 54e983b commit 8832e51
Show file tree
Hide file tree
Showing 4 changed files with 78 additions and 35 deletions.
4 changes: 2 additions & 2 deletions javascript/ui.js
Expand Up @@ -213,8 +213,8 @@ function restart_reload(){
}

function login(username) {
var user=gradioApp().querySelector('#tabs').querySelectorAll('button')[6];
var setting=gradioApp().querySelector('#tabs').querySelectorAll('button')[7];
var user=gradioApp().querySelector('#tabs').querySelectorAll('button')[5];
var setting=gradioApp().querySelector('#tabs').querySelectorAll('button')[6];

if(username=='admin'){
user.style.display='block'
Expand Down
2 changes: 2 additions & 0 deletions localizations/zh_CN.json
Expand Up @@ -858,5 +858,7 @@
"Submit":"确定",
"columns width":"每行图片列数",
"Show current user's images only":"只显示当前用户图片集",
"Don't load other user's models":"不加载其他用户目录模型",
"Merge":"合并",
"--------": "--------"
}
44 changes: 30 additions & 14 deletions modules/model_merger.py
Expand Up @@ -109,11 +109,9 @@ def is_valid_s3_uri(s3_uri):
match = s3_uri_pattern.match(s3_uri)
return bool(match)

def load_checkpoints_from_s3_uri(s3_uri, primary_component,
secondary_component, tertiary_component):
def load_checkpoints_from_s3_uri(s3_uri,load_all_user,username):
global input_chkpt_s3uri
global s3_checkpoints

if not is_valid_s3_uri(s3_uri):
return

Expand All @@ -130,11 +128,22 @@ def load_checkpoints_from_s3_uri(s3_uri, primary_component,
return

text = json.loads(response.text)
for obj in text['payload']:
obj_key = obj['key']
ckpt = obj_key.split('/')[-1]
s3_checkpoints.append(ckpt)


if not load_all_user:
for obj in text['payload']:
title = obj['key'].replace('stable-diffusion-webui/models/Stable-diffusion/','')
# ckpt = title.split('/')[-1]
s3_checkpoints.append(title)
else:
for obj in text['payload']:
title = obj['key'].replace('stable-diffusion-webui/models/Stable-diffusion/','')
##filter by username . e.g title: river/jp-style-girl-3_200_lora.safetensors
dir = title.split('/')
if len(dir) > 1:
dir_user = dir[0]
if dir_user != username:
continue
s3_checkpoints.append(title)
return [gr.Dropdown.update(choices=s3_checkpoints) for _ in range(3)]

def get_checkpoints_to_merge():
Expand Down Expand Up @@ -217,31 +226,38 @@ def get_default_output_model_s3uri():
def run_modelmerger_remote(primary_model_name, secondary_model_name,
tertiary_model_name, interp_method, multiplier,
save_as_half, custom_name, checkpoint_format,
output_chkpt_s3uri, submit_result):
output_chkpt_s3uri, submit_result,request):
""" This is the same as run_modelmerger, but it calls a RESTful API to do the job """
if isinstance(primary_model_name, list) or \
isinstance(secondary_model_name, list):
ret_msg = "At least primary_model_name and secondary_model_name must be set."
set_last_processing_output_message(ret_msg)
return reg_msg
return ret_msg

if output_chkpt_s3uri != '' and not is_valid_s3_uri(output_chkpt_s3uri):
ret_msg = f"output_chkpt_s3uri is not valid: {output_chkpt_s3uri}"
set_last_processing_output_message(ret_msg)
return reg_msg
return ret_msg

input_srcs = f"{input_chkpt_s3uri}/{primary_model_name}," + \
f"{input_chkpt_s3uri}/{secondary_model_name}"
input_dsts = f"/opt/ml/processing/input/primary," + \
f"/opt/ml/processing/input/secondary"

username = shared.get_webui_username(request)
if is_valid_s3_uri(output_chkpt_s3uri):
output_dst = output_chkpt_s3uri
if output_chkpt_s3uri[-1] == '/':
output_dst = output_chkpt_s3uri+username
else:
output_dst = output_chkpt_s3uri+'/'+username
else:
output_dst = get_default_output_model_s3uri()
output_dst = get_default_output_model_s3uri()+'/'+username
output_name = get_merged_chkpt_name(primary_model_name, secondary_model_name,
tertiary_model_name, multiplier, interp_method,
checkpoint_format, custom_name)
##outputName' failed to satisfy constraint: Member must have length less than or equal to 64
if len(output_name) > 64:
output_name = output_name[len(output_name)-64:]

# Make an argument dict to be accessible in the process script
args = {
"primary_model": primary_model_name,
Expand Down
63 changes: 44 additions & 19 deletions modules/ui.py
Expand Up @@ -25,7 +25,6 @@
from modules.paths import script_path

from modules.shared import opts, cmd_opts, restricted_opts,get_default_sagemaker_bucket

import modules.codeformer_model
import modules.generation_parameters_copypaste as parameters_copypaste
import modules.gfpgan_model
Expand Down Expand Up @@ -759,8 +758,11 @@ def image_viewer(path,cols_width,current_only,request:gr.Request):
with gr.Row():
with gr.Column(scale=3):
images_s3_path = gr.Textbox(label="Input S3 path of images",visible=False, value = get_default_sagemaker_bucket()+'/stable-diffusion-webui/generated')
dummy_images_s3_path = gr.Textbox(label="Input S3 path of images",visible=True, interactive=False,
value = get_default_sagemaker_bucket()+'/stable-diffusion-webui/generated/{username}')

with gr.Column(scale=1):
show_user_only = gr.Checkbox(label="Show current user's images only", value=True,visible=False)
show_user_only = gr.Checkbox(label="Show current user's images only", value=True,visible=True,interactive=False)
with gr.Column(scale=1):
cols_width = gr.Slider(minimum=4, maximum=20, step=1, label="columns width", value=8)
with gr.Column(scale=1):
Expand Down Expand Up @@ -1548,16 +1550,32 @@ def update_orig(image, state):
fn=modules.extras.clear_cache,
inputs=[], outputs=[]
)

def load_checkpoints_from_s3_uri(model_s3url,load_all_user,request:gr.Request):
username = shared.get_webui_username(request)
print(username)
return modules.model_merger.load_checkpoints_from_s3_uri(model_s3url,load_all_user,username)

with gr.Blocks(analytics_enabled=False) as modelmerger_interface:
with gr.Row().style(equal_height=False):
with gr.Column(variant='panel'):
with gr.Column():
gr.HTML(value="<p>Merged checkpoints will be put in the specified output S3 location</p>")

default_ckpt_s3 = get_default_sagemaker_bucket()+'/stable-diffusion-webui/models/Stable-diffusion/'
# default_merge_output_s3 = default_ckpt_s3
with gr.Row():
chkpt_s3uri = gr.Textbox(label="Checkpoint S3 URI", placeholder='s3://bucket/stable-diffusion-webui/models/')
chkpt_s3uri_button = gr.Button(value="Load Checkpoints", elem_id="checkpt_s3uri", variant='primary')
merge_output_s3uri = gr.Textbox(label="Merge Result S3 URI", placeholder="If not specified, will put into " + modules.model_merger.get_default_output_model_s3uri())

dummy_s3uri = gr.Textbox(label="Checkpoint S3 URI", elem_id="dummy_chkpt_s3uri",
value='模型存放位置:'+default_ckpt_s3+'{用户名}',
lines=2, visible=True,interactive=False)
chkpt_s3uri = gr.Textbox(label="Checkpoint S3 URI", elem_id="chkpt_s3uri", value= default_ckpt_s3,lines=2,visible=False)
merge_output_s3uri = gr.Textbox(label="Merge Result S3 URI",lines=2, visible=True,placeholder='(选填),默认输出位置:'+default_ckpt_s3+'{用户名}')
with gr.Row():
with gr.Column(scale=1):
load_all_user = gr.Checkbox(label="Don't load other user's models",interactive=False, value=True,visible=True)
with gr.Column(scale=2):
chkpt_s3uri_button = gr.Button(value="Load Checkpoints", variant='primary')



with gr.Row():
primary_model_name = gr.Dropdown(modules.model_merger.get_checkpoints_to_merge(), elem_id="modelmerger_primary_model_name", label="Primary model (A)")
secondary_model_name = gr.Dropdown(modules.model_merger.get_checkpoints_to_merge(), elem_id="modelmerger_secondary_model_name", label="Secondary model (B)")
Expand All @@ -1566,19 +1584,19 @@ def update_orig(image, state):
interp_amount = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, label='Multiplier (M) - set to 0 to get model A', value=0.3)
interp_method = gr.Radio(choices=["Weighted sum", "Add difference"], value="Weighted sum", label="Interpolation Method")

chkpt_s3uri_button.click(
fn=modules.model_merger.load_checkpoints_from_s3_uri,
inputs=[chkpt_s3uri],
outputs=[primary_model_name, secondary_model_name, tertiary_model_name])

with gr.Row():
checkpoint_format = gr.Radio(choices=["ckpt", "safetensors"], value="ckpt", label="Checkpoint format")
save_as_half = gr.Checkbox(value=False, label="Save as float16")

modelmerger_merge = gr.Button(elem_id="modelmerger_merge", label="Merge", variant='primary')
modelmerger_merge = gr.Button(elem_id="modelmerger_merge", value="Merge", variant='primary')

with gr.Column(variant='panel'):
with gr.Column():
submit_result = gr.Textbox(elem_id="modelmerger_result", show_label=False)
chkpt_s3uri_button.click(
fn=load_checkpoints_from_s3_uri,
inputs=[chkpt_s3uri,load_all_user],
outputs=[primary_model_name, secondary_model_name, tertiary_model_name])

# A periodic function to check the submit output
modelmerger_interface.load(modules.model_merger.get_processing_job_status,
Expand Down Expand Up @@ -2198,7 +2216,7 @@ def save_userdata(user_dataframe, request: gr.Request):
(img2img_interface, "img2img", "img2img"),
(extras_interface, "Extras", "extras"),
(pnginfo_interface, "PNG Info", "pnginfo"),
(modelmerger_interface, "Checkpoint Merger", "modelmerger"),
# (modelmerger_interface, "Checkpoint Merger", "modelmerger"),
(train_interface, "Train", "ti"),
(user_interface, "User", "user")
]
Expand All @@ -2208,7 +2226,7 @@ def save_userdata(user_dataframe, request: gr.Request):
(img2img_interface, "img2img", "img2img"),
(extras_interface, "Extras", "extras"),
(pnginfo_interface, "PNG Info", "pnginfo"),
(modelmerger_interface, "Checkpoint Merger", "modelmerger"),
# (modelmerger_interface, "Checkpoint Merger", "modelmerger"),
(train_interface, "Train", "ti"),
]

Expand All @@ -2230,6 +2248,7 @@ def save_userdata(user_dataframe, request: gr.Request):

# interfaces += script_callbacks.ui_tabs_callback()
interfaces += [(settings_interface, "Settings", "settings")]
interfaces += [(modelmerger_interface,"Checkpoint Merger", "modelmerger")]
interfaces += [(imagesviewer_interface,"Images Viewer","imagesviewer")]

extensions_interface = ui_extensions.create_ui()
Expand Down Expand Up @@ -2339,9 +2358,15 @@ def demo_load(request: gr.Request):
outputs=[component_dict[k] for k in component_keys] + [username_state, user_dataframe, shared.sagemaker_endpoint_component, shared.sd_model_checkpoint_component]
)

def modelmerger(*args):
def modelmerger(primary_model_name, secondary_model_name,
tertiary_model_name, interp_method, multiplier,
save_as_half, custom_name, checkpoint_format,
output_chkpt_s3uri, submit_result,request:gr.Request):
try:
results = modules.model_merger.run_modelmerger_remote(*args)
results = modules.model_merger.run_modelmerger_remote(primary_model_name, secondary_model_name,
tertiary_model_name, interp_method, multiplier,
save_as_half, custom_name, checkpoint_format,
output_chkpt_s3uri, submit_result,request)
except Exception as e:
print("Error loading/saving model file:", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
Expand Down Expand Up @@ -2431,7 +2456,7 @@ def apply_field(obj, field, condition=None, init_field=None):
visit(txt2img_interface, loadsave, "txt2img")
visit(img2img_interface, loadsave, "img2img")
visit(extras_interface, loadsave, "extras")
visit(modelmerger_interface, loadsave, "modelmerger")
# visit(modelmerger_interface, loadsave, "modelmerger")

if not error_loading and (not os.path.exists(ui_config_file) or settings_count != len(ui_settings)):
with open(ui_config_file, "w", encoding="utf8") as file:
Expand Down

0 comments on commit 8832e51

Please sign in to comment.