Skip to content

Commit

Permalink
revise to get username from session token
Browse files Browse the repository at this point in the history
  • Loading branch information
xieyongliang committed Mar 8, 2023
1 parent 231076a commit c25694c
Showing 1 changed file with 20 additions and 27 deletions.
47 changes: 20 additions & 27 deletions modules/call_queue.py
Expand Up @@ -32,16 +32,7 @@ def f(*args, **kwargs):

return f

def wrap_gradio_gpu_call(func, request:gr.Request, extra_outputs=None):
tokens = shared.demo.server_app.tokens
cookies = request.headers['cookie'].split('; ')
access_token = None
for cookie in cookies:
if cookie.startswith('access-token'):
access_token = cookie[len('access-token=') : ]
break
username = tokens[access_token] if access_token else None

def wrap_gradio_gpu_call(func, extra_outputs=None):
def encode_image_to_base64(image):
if isinstance(image, bytes):
encoded_string = base64.b64encode(image)
Expand Down Expand Up @@ -85,7 +76,7 @@ def handle_sagemaker_inference_async(response):

return processed

def sagemaker_inference(task, infer, sagemaker_endpoint, *args, **kwargs):
def sagemaker_inference(task, infer, username, sagemaker_endpoint, *args, **kwargs):
infer = 'async'
if task == 'text-to-image' or task == 'image-to-image':
if task == 'text-to-image':
Expand Down Expand Up @@ -418,19 +409,19 @@ def sagemaker_inference(task, infer, sagemaker_endpoint, *args, **kwargs):
info = processed['html_info']
return images, modules.ui.plaintext_to_html(info), ''

def f(*args, **kwargs):
def f(username, *args, **kwargs):
if cmd_opts.pureui and func == modules.txt2img.txt2img:
sagemaker_endpoint = args[len(args) -1]
args = args[:-2]
res = sagemaker_inference('text-to-image', 'sync', sagemaker_endpoint, *args, **kwargs)
res = sagemaker_inference('text-to-image', 'sync', username, sagemaker_endpoint, *args, **kwargs)
elif cmd_opts.pureui and func == modules.img2img.img2img:
sagemaker_endpoint = args[len(args) -1]
args = args[:-2]
res = sagemaker_inference('image-to-image', 'sync', sagemaker_endpoint, *args, **kwargs)
res = sagemaker_inference('image-to-image', 'sync', username, sagemaker_endpoint, *args, **kwargs)
elif cmd_opts.pureui and func == modules.extras.run_extras:
sagemaker_endpoint = args[len(args) -1]
args = args[:-2]
res = sagemaker_inference('extras', 'sync', sagemaker_endpoint, *args, **kwargs)
res = sagemaker_inference('extras', 'sync', username, sagemaker_endpoint, *args, **kwargs)
else:
shared.state.begin()
with queue_lock:
Expand All @@ -442,25 +433,27 @@ def f(*args, **kwargs):
return wrap_gradio_call(f, extra_outputs=extra_outputs, add_stats=True)


def wrap_gradio_call(func, request : gr.Request, extra_outputs=None, add_stats=False):
tokens = shared.demo.server_app.tokens
cookies = request.headers['cookie'].split('; ')
access_token = None
for cookie in cookies:
if cookie.startswith('access-token'):
access_token = cookie[len('access-token=') : ]
break
username = tokens[access_token] if access_token else None
def wrap_gradio_call(func, extra_outputs=None, add_stats=False):
def f(request: gr.Request, *args, extra_outputs_array=extra_outputs, **kwargs):
tokens = shared.demo.server_app.tokens
cookies = request.headers['cookie'].split('; ')
access_token = None
for cookie in cookies:
if cookie.startswith('access-token'):
access_token = cookie[len('access-token=') : ]
break
username = tokens[access_token]

def f(*args, extra_outputs_array=extra_outputs, **kwargs):
run_memmon = shared.opts.memmon_poll_rate > 0 and not shared.mem_mon.disabled and add_stats
if run_memmon:
shared.mem_mon.monitor()
t = time.perf_counter()

try:
args.append(username)
res = list(func(*args, **kwargs))
if func.__name__ == 'f':
res = list(func(username, *args, **kwargs))
else:
res = list(func(*args, **kwargs))
except Exception as e:
# When printing out our debug argument list, do not print out more than a MB of text
max_debug_str_len = 131072 # (1024*1024)/8
Expand Down

0 comments on commit c25694c

Please sign in to comment.