Skip to content

Commit

Permalink
revise for controlnet
Browse files Browse the repository at this point in the history
  • Loading branch information
xieyongliang committed Mar 5, 2023
1 parent a5abc21 commit 5078ae9
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 2 deletions.
2 changes: 1 addition & 1 deletion modules/sd_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def modeltitle(path, shorthash):

if shared.cmd_opts.pureui:
params = {
'endpoint_name': sagemaker_endpoint
'module': 'Stable-diffusion', 'endpoint_name': sagemaker_endpoint
}
response = requests.get(url=f'{api_endpoint}/sd/models', params=params)
if response.status_code == 200:
Expand Down
33 changes: 32 additions & 1 deletion webui.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,12 @@ def api_only():
if ckpt_dir is not None:
sd_models_path = ckpt_dir

controlnet_dir = cmd_opts.controlnet_dir
cn_models_path = os.path.join(shared.models_path, "ControlNet")
os.makedirs(controlnet_dir, exist_ok=True)
if controlnet_dir is not None:
cn_models_path = controlnet_dir

if 'endpoint_name' in os.environ:
items = []
api_endpoint = os.environ['api_endpoint']
Expand All @@ -165,8 +171,33 @@ def api_only():
inputs = {
'items': items
}
params = {
'module': 'Stable-diffusion'
}
if api_endpoint.startswith('http://') or api_endpoint.startswith('https://'):
response = requests.post(url=f'{api_endpoint}/sd/models', json=inputs, params=params)
print(response)

items = []
inputs = {
'items': items
}
params = {
'module': 'ControlNet'
}
for file in os.listdir(cn_models_path):
if os.path.isfile(os.path.join(cn_models_path, file)) and \
(file.endswith('pt') or file.endswith('.pth') or file.endswith('.ckpt') or file.endswith('.safetensors')):
hash = modules.sd_models.model_hash(os.path.join(cn_models_path, file))
item = {}
item['model_name'] = file
item['title'] = '{0} [{1}]'.format(file, hash)
item['endpoint_name'] = endpoint_name
items.append(item)

if api_endpoint.startswith('http://') or api_endpoint.startswith('https://'):
response = requests.post(url=f'{api_endpoint}/sd/models', json=inputs)
response = requests.post(url=f'{api_endpoint}/sd/models', json=inputs, params=params)
print(response)

modules.script_callbacks.app_started_callback(None, app)

Expand Down

0 comments on commit 5078ae9

Please sign in to comment.