Skip to content

Commit

Permalink
cleanup with sd_model
Browse files Browse the repository at this point in the history
  • Loading branch information
xieyongliang committed Feb 6, 2023
1 parent 326969a commit 4e7db20
Showing 1 changed file with 6 additions and 39 deletions.
45 changes: 6 additions & 39 deletions modules/sd_models.py
Expand Up @@ -60,7 +60,6 @@ def list_models(sagemaker_endpoint=None):
global checkpoints_list

checkpoints_list.clear()
model_list = modelloader.load_models(model_path=model_path, command_path=shared.cmd_opts.ckpt_dir, ext_filter=[".ckpt", ".safetensors"])

def modeltitle(path, shorthash):
abspath = os.path.abspath(path)
Expand All @@ -79,25 +78,6 @@ def modeltitle(path, shorthash):

return f'{name} [{shorthash}]', shortname

cmd_ckpt = shared.cmd_opts.ckpt
if os.path.exists(cmd_ckpt):
h = model_hash(cmd_ckpt)
title, short_model_name = modeltitle(cmd_ckpt, h)
checkpoints_list[title] = CheckpointInfo(cmd_ckpt, title, h, short_model_name, shared.cmd_opts.config)
shared.opts.data['sd_model_checkpoint'] = title
elif cmd_ckpt is not None and cmd_ckpt != shared.default_sd_model_file:
print(f"Checkpoint in --ckpt argument not found (Possible it was moved to {model_path}: {cmd_ckpt}", file=sys.stderr)
for filename in model_list:
h = model_hash(filename)
title, short_model_name = modeltitle(filename, h)

basename, _ = os.path.splitext(filename)
config = basename + ".yaml"
if not os.path.exists(config):
config = shared.cmd_opts.config

checkpoints_list[title] = CheckpointInfo(filename, title, h, short_model_name, config)

if shared.cmd_opts.pureui:
params = {
'endpoint_name': sagemaker_endpoint
Expand Down Expand Up @@ -132,26 +112,9 @@ def modeltitle(path, shorthash):
else:
shared.sd_model = None
else:
model_list = modelloader.load_models(model_path=model_path, command_path=shared.cmd_opts.ckpt_dir, ext_filter=[".ckpt"])

def modeltitle(path, shorthash):
abspath = os.path.abspath(path)

if shared.cmd_opts.ckpt_dir is not None and abspath.startswith(shared.cmd_opts.ckpt_dir):
name = abspath.replace(shared.cmd_opts.ckpt_dir, '')
elif abspath.startswith(model_path):
name = abspath.replace(model_path, '')
else:
name = os.path.basename(path)

if name.startswith("\\") or name.startswith("/"):
name = name[1:]

shortname = os.path.splitext(name.replace("/", "_").replace("\\", "_"))[0]

return f'{name} [{shorthash}]', shortname

cmd_ckpt = shared.cmd_opts.ckpt
model_list = modelloader.load_models(model_path=model_path, command_path=shared.cmd_opts.ckpt_dir, ext_filter=[".ckpt", ".safetensors"])

if os.path.exists(cmd_ckpt):
h = model_hash(cmd_ckpt)
title, short_model_name = modeltitle(cmd_ckpt, h)
Expand Down Expand Up @@ -369,6 +332,10 @@ def reload_model_weights(sd_model=None, info=None):
if not sd_model:
sd_model = shared.sd_model

while not sd_model:
load_model()
sd_model = shared.sd_model

if sd_model.sd_model_checkpoint == checkpoint_info.filename:
return

Expand Down

0 comments on commit 4e7db20

Please sign in to comment.