Skip to content

Commit

Permalink
add infer_type which will depends on the onvironment variable
Browse files Browse the repository at this point in the history
  • Loading branch information
xieyongliang committed Apr 17, 2023
1 parent bddb41f commit 970f489
Showing 1 changed file with 8 additions and 7 deletions.
15 changes: 8 additions & 7 deletions modules/call_queue.py
Expand Up @@ -17,6 +17,7 @@
import numpy as np
from PIL import Image, ImageOps, ImageFilter, ImageEnhance, ImageChops

import os
from modules import shared
import gradio as gr

Expand Down Expand Up @@ -76,8 +77,7 @@ def handle_sagemaker_inference_async(response):

return processed

def sagemaker_inference(task, infer, username, sagemaker_endpoint, *args, **kwargs):
infer = 'async'
def sagemaker_inference(task, infer_type, username, sagemaker_endpoint, *args, **kwargs):
if task == 'text-to-image' or task == 'image-to-image':
if task == 'text-to-image':
script_args = []
Expand Down Expand Up @@ -298,7 +298,7 @@ def sagemaker_inference(task, infer, username, sagemaker_endpoint, *args, **kwar
'endpoint_name': sagemaker_endpoint
}
response = requests.post(url=f'{shared.api_endpoint}/inference', params=params, json=inputs)
if infer == 'async':
if infer_type == 'async':
processed = handle_sagemaker_inference_async(response)
else:
processed = json.loads(response.text)
Expand Down Expand Up @@ -395,7 +395,7 @@ def sagemaker_inference(task, infer, username, sagemaker_endpoint, *args, **kwar
'endpoint_name': sagemaker_endpoint
}
response = requests.post(url=f'{shared.api_endpoint}/inference', params=params, json=inputs)
if infer == 'async':
if infer_type == 'async':
processed = handle_sagemaker_inference_async(response)
else:
processed = json.loads(response.text)
Expand All @@ -410,18 +410,19 @@ def sagemaker_inference(task, infer, username, sagemaker_endpoint, *args, **kwar
return images, json.dumps(info), modules.ui.plaintext_to_html('\n'.join(info['infotexts']))

def f(username, *args, **kwargs):
infer_type = os.environ.get('infer_type', 'async')
if cmd_opts.pureui and func == modules.txt2img.txt2img:
sagemaker_endpoint = args[len(args) -1]
args = args[:-1]
res = sagemaker_inference('text-to-image', 'sync', username, sagemaker_endpoint, *args, **kwargs)
res = sagemaker_inference('text-to-image', infer_type, username, sagemaker_endpoint, *args, **kwargs)
elif cmd_opts.pureui and func == modules.img2img.img2img:
sagemaker_endpoint = args[len(args) -1]
args = args[:-1]
res = sagemaker_inference('image-to-image', 'sync', username, sagemaker_endpoint, *args, **kwargs)
res = sagemaker_inference('image-to-image', infer_type, username, sagemaker_endpoint, *args, **kwargs)
elif cmd_opts.pureui and func == modules.extras.run_extras:
sagemaker_endpoint = args[len(args) -1]
args = args[:-1]
res = sagemaker_inference('extras', 'sync', username, sagemaker_endpoint, *args, **kwargs)
res = sagemaker_inference('extras', infer_type, username, sagemaker_endpoint, *args, **kwargs)
else:
shared.state.begin()
with queue_lock:
Expand Down

0 comments on commit 970f489

Please sign in to comment.