From c25694cc4f1e8926f03185f7f5582eee5d290e2a Mon Sep 17 00:00:00 2001 From: xieyongliang Date: Wed, 8 Mar 2023 09:41:05 +0800 Subject: [PATCH] revise to get username from session token --- modules/call_queue.py | 47 ++++++++++++++++++------------------------- 1 file changed, 20 insertions(+), 27 deletions(-) diff --git a/modules/call_queue.py b/modules/call_queue.py index 69e45c0bae8..19d17f436cf 100644 --- a/modules/call_queue.py +++ b/modules/call_queue.py @@ -32,16 +32,7 @@ def f(*args, **kwargs): return f -def wrap_gradio_gpu_call(func, request:gr.Request, extra_outputs=None): - tokens = shared.demo.server_app.tokens - cookies = request.headers['cookie'].split('; ') - access_token = None - for cookie in cookies: - if cookie.startswith('access-token'): - access_token = cookie[len('access-token=') : ] - break - username = tokens[access_token] if access_token else None - +def wrap_gradio_gpu_call(func, extra_outputs=None): def encode_image_to_base64(image): if isinstance(image, bytes): encoded_string = base64.b64encode(image) @@ -85,7 +76,7 @@ def handle_sagemaker_inference_async(response): return processed - def sagemaker_inference(task, infer, sagemaker_endpoint, *args, **kwargs): + def sagemaker_inference(task, infer, username, sagemaker_endpoint, *args, **kwargs): infer = 'async' if task == 'text-to-image' or task == 'image-to-image': if task == 'text-to-image': @@ -418,19 +409,19 @@ def sagemaker_inference(task, infer, sagemaker_endpoint, *args, **kwargs): info = processed['html_info'] return images, modules.ui.plaintext_to_html(info), '' - def f(*args, **kwargs): + def f(username, *args, **kwargs): if cmd_opts.pureui and func == modules.txt2img.txt2img: sagemaker_endpoint = args[len(args) -1] args = args[:-2] - res = sagemaker_inference('text-to-image', 'sync', sagemaker_endpoint, *args, **kwargs) + res = sagemaker_inference('text-to-image', 'sync', username, sagemaker_endpoint, *args, **kwargs) elif cmd_opts.pureui and func == modules.img2img.img2img: sagemaker_endpoint = args[len(args) -1] args = args[:-2] - res = sagemaker_inference('image-to-image', 'sync', sagemaker_endpoint, *args, **kwargs) + res = sagemaker_inference('image-to-image', 'sync', username, sagemaker_endpoint, *args, **kwargs) elif cmd_opts.pureui and func == modules.extras.run_extras: sagemaker_endpoint = args[len(args) -1] args = args[:-2] - res = sagemaker_inference('extras', 'sync', sagemaker_endpoint, *args, **kwargs) + res = sagemaker_inference('extras', 'sync', username, sagemaker_endpoint, *args, **kwargs) else: shared.state.begin() with queue_lock: @@ -442,25 +433,27 @@ def f(*args, **kwargs): return wrap_gradio_call(f, extra_outputs=extra_outputs, add_stats=True) -def wrap_gradio_call(func, request : gr.Request, extra_outputs=None, add_stats=False): - tokens = shared.demo.server_app.tokens - cookies = request.headers['cookie'].split('; ') - access_token = None - for cookie in cookies: - if cookie.startswith('access-token'): - access_token = cookie[len('access-token=') : ] - break - username = tokens[access_token] if access_token else None +def wrap_gradio_call(func, extra_outputs=None, add_stats=False): + def f(request: gr.Request, *args, extra_outputs_array=extra_outputs, **kwargs): + tokens = shared.demo.server_app.tokens + cookies = request.headers['cookie'].split('; ') + access_token = None + for cookie in cookies: + if cookie.startswith('access-token'): + access_token = cookie[len('access-token=') : ] + break + username = tokens[access_token] - def f(*args, extra_outputs_array=extra_outputs, **kwargs): run_memmon = shared.opts.memmon_poll_rate > 0 and not shared.mem_mon.disabled and add_stats if run_memmon: shared.mem_mon.monitor() t = time.perf_counter() try: - args.append(username) - res = list(func(*args, **kwargs)) + if func.__name__ == 'f': + res = list(func(username, *args, **kwargs)) + else: + res = list(func(*args, **kwargs)) except Exception as e: # When printing out our debug argument list, do not print out more than a MB of text max_debug_str_len = 131072 # (1024*1024)/8