Skip to content

Commit

Permalink
fix bugs with concurrent usage
Browse files Browse the repository at this point in the history
  • Loading branch information
xieyongliang committed Feb 1, 2023
1 parent 2fdec78 commit 2109a90
Show file tree
Hide file tree
Showing 5 changed files with 988 additions and 999 deletions.
116 changes: 0 additions & 116 deletions config.json

This file was deleted.

36 changes: 25 additions & 11 deletions modules/call_queue.py
Expand Up @@ -75,7 +75,7 @@ def handle_sagemaker_inference_async(response):

return processed

def sagemaker_inference(task, infer, *args, **kwargs):
def sagemaker_inference(task, infer, username, sagemaker_endpoint, *args, **kwargs):
infer = 'async'
if task == 'text-to-image' or task == 'image-to-image':
if task == 'text-to-image':
Expand Down Expand Up @@ -139,7 +139,7 @@ def sagemaker_inference(task, infer, *args, **kwargs):
inputs = {
'task': task,
'txt2img_payload': payload,
'username': shared.username
'username': username
}
else:
mode = args[0]
Expand Down Expand Up @@ -273,14 +273,13 @@ def sagemaker_inference(task, infer, *args, **kwargs):
inputs = {
'task': task,
'img2img_payload': payload,
'username': shared.username
'username': username
}
print(sd_samplers.samplers[sampler_index].name)

params = {
'endpoint_name': shared.opts.sagemaker_endpoint
'endpoint_name': sagemaker_endpoint
}

response = requests.post(url=f'{shared.api_endpoint}/inference', params=params, json=inputs)
if infer == 'async':
processed = handle_sagemaker_inference_async(response)
Expand Down Expand Up @@ -340,7 +339,7 @@ def sagemaker_inference(task, infer, *args, **kwargs):
inputs = {
'task': task,
'extras_single_payload': payload,
'username': shared.username
'username': username
}
else:
imageList = []
Expand Down Expand Up @@ -372,11 +371,11 @@ def sagemaker_inference(task, infer, *args, **kwargs):
inputs = {
'task': task,
'extras_batch_payload': payload,
'username': shared.username
'username': username
}

params = {
'endpoint_name': shared.opts.sagemaker_endpoint
'endpoint_name': sagemaker_endpoint
}
response = requests.post(url=f'{shared.api_endpoint}/inference', params=params, json=inputs)
if infer == 'async':
Expand All @@ -395,11 +394,26 @@ def sagemaker_inference(task, infer, *args, **kwargs):

def f(*args, **kwargs):
if cmd_opts.pureui and func == modules.txt2img.txt2img:
res = sagemaker_inference('text-to-image', 'sync', *args, **kwargs)
username = args[len(args) - 2]
sagemaker_endpoint = args[len(args) -1]
args = args[:-2]
print('username:', username)
print('sagemaker_endpoint:', sagemaker_endpoint)
res = sagemaker_inference('text-to-image', 'sync', username, sagemaker_endpoint, *args, **kwargs)
elif cmd_opts.pureui and func == modules.img2img.img2img:
res = sagemaker_inference('image-to-image', 'sync', *args, **kwargs)
username = args[len(args) - 2]
sagemaker_endpoint = args[len(args) -1]
args = args[:-2]
print('username:', username)
print('sagemaker_endpoint:', sagemaker_endpoint)
res = sagemaker_inference('image-to-image', 'sync', username, sagemaker_endpoint, *args, **kwargs)
elif cmd_opts.pureui and func == modules.extras.run_extras:
res = sagemaker_inference('extras', 'sync', *args, **kwargs)
username = args[len(args) - 2]
sagemaker_endpoint = args[len(args) -1]
args = args[:-2]
print('username:', username)
print('sagemaker_endpoint:', sagemaker_endpoint)
res = sagemaker_inference('extras', 'sync', username, sagemaker_endpoint, *args, **kwargs)
else:
shared.state.begin()
with queue_lock:
Expand Down
4 changes: 2 additions & 2 deletions modules/sd_models.py
Expand Up @@ -56,7 +56,7 @@ def checkpoint_tiles():
return sorted([x.title for x in checkpoints_list.values()], key = alphanumeric_key)


def list_models():
def list_models(sagemaker_endpoint=None):
global checkpoints_list

checkpoints_list.clear()
Expand Down Expand Up @@ -100,7 +100,7 @@ def modeltitle(path, shorthash):

if shared.cmd_opts.pureui:
params = {
'endpoint_name': shared.opts.sagemaker_endpoint
'endpoint_name': sagemaker_endpoint
}
response = requests.get(url=f'{api_endpoint}/sd/models', params=params)
if response.status_code == 200:
Expand Down

0 comments on commit 2109a90

Please sign in to comment.