Skip to content

Commit

Permalink
cleanup & fix issues with cookies handling
Browse files Browse the repository at this point in the history
  • Loading branch information
xieyongliang committed Apr 16, 2023
1 parent ad72065 commit 003d7d7
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 86 deletions.
11 changes: 11 additions & 0 deletions modules/shared.py
Expand Up @@ -834,3 +834,14 @@ def get_cookies(request):
# directly, so we need to retrieve its underlying dict first.
cookies = request.headers.__dict__['cookie'].split('; ')
return cookies

def get_webui_username(request):
tokens = demo.server_app.tokens
cookies = request.headers.__dict__['cookie'].split('; ')
access_token = None
for cookie in cookies:
if cookie.startswith('access-token'):
access_token = cookie[len('access-token=') : ]
break
username = tokens[access_token] if access_token else None
return username
95 changes: 9 additions & 86 deletions modules/ui.py
Expand Up @@ -90,17 +90,6 @@ def gr_show(visible=True):
## Begin output images uploaded to s3 by River
s3_resource = boto3.resource('s3')

def get_webui_username(request):
tokens = shared.demo.server_app.tokens
cookies = request.headers['cookie'].split('; ')
access_token = None
for cookie in cookies:
if cookie.startswith('access-token'):
access_token = cookie[len('access-token=') : ]
break
username = tokens[access_token] if access_token else None
return username

def save_images_to_s3(full_fillnames,timestamp,username):
sagemaker_endpoint = shared.opts.sagemaker_endpoint
bucket_name = opts.train_files_s3bucket.replace('s3://','')
Expand Down Expand Up @@ -580,14 +569,7 @@ def refresh():
return gr.update(**(args or {}))

def refresh_sagemaker_endpoints(request : gr.Request):
tokens = shared.demo.server_app.tokens
cookies = shared.get_cookies(request)
access_token = None
for cookie in cookies:
if cookie.startswith('access-token'):
access_token = cookie[len('access-token=') : ]
break
username = tokens[access_token] if access_token else None
username = shared.get_webui_username(request)

refresh_method(username)
args = refreshed_args() if callable(refreshed_args) else refreshed_args
Expand All @@ -598,32 +580,7 @@ def refresh_sagemaker_endpoints(request : gr.Request):
return gr.update(**(args or {}))

def refresh_sd_models(request: gr.Request):
tokens = shared.demo.server_app.tokens
cookies = request.headers['cookie'].split('; ')
access_token = None
for cookie in cookies:
if cookie.startswith('access-token'):
access_token = cookie[len('access-token=') : ]
break
username = tokens[access_token] if access_token else None

refresh_method(username)
args = refreshed_args() if callable(refreshed_args) else refreshed_args

for k, v in args.items():
setattr(refresh_component, k, v)

return gr.update(**(args or {}))

def refresh_sd_models(request: gr.Request):
tokens = shared.demo.server_app.tokens
cookies = request.headers['cookie'].split('; ')
access_token = None
for cookie in cookies:
if cookie.startswith('access-token'):
access_token = cookie[len('access-token=') : ]
break
username = tokens[access_token] if access_token else None
username = shared.get_webui_username(request)

refresh_method(username)
args = refreshed_args() if callable(refreshed_args) else refreshed_args
Expand Down Expand Up @@ -779,7 +736,7 @@ def list_objects(bucket,prefix=''):

def image_viewer(path,cols_width,current_only,request:gr.Request):
if current_only:
username = get_webui_username(request)
username = shared.get_webui_username(request)
path = path+'/'+username
dirs = path.replace('s3://','').split('/')
prefix = '/'.join(dirs[1:])
Expand Down Expand Up @@ -920,14 +877,7 @@ def run_settings(username, *args):
return opts.dumpjson(), 'Settings changed and saved'

def run_settings_single(value, key, request : gr.Request):
tokens = shared.demo.server_app.tokens
cookies = shared.get_cookies(request)
access_token = None
for cookie in cookies:
if cookie.startswith('access-token'):
access_token = cookie[len('access-token=') : ]
break
username = tokens[access_token] if access_token else None
username = shared.get_webui_username(request)

if username and username != '':
if not opts.same_type(value, opts.data_labels[key].default):
Expand Down Expand Up @@ -1889,14 +1839,7 @@ def sagemaker_train_embedding(
*txt2img_preview_params
):

tokens = shared.demo.server_app.tokens
cookies = shared.get_cookies(request)
access_token = None
for cookie in cookies:
if cookie.startswith('access-token'):
access_token = cookie[len('access-token=') : ]
break
username = tokens[access_token] if access_token else None
username = shared.get_webui_username(request)

train_args = {
'embedding_settings': {
Expand Down Expand Up @@ -2022,14 +1965,7 @@ def sagemaker_train_hypernetwork(
*txt2img_preview_params
):

tokens = shared.demo.server_app.tokens
cookies = shared.get_cookies(request)
access_token = None
for cookie in cookies:
if cookie.startswith('access-token'):
access_token = cookie[len('access-token=') : ]
break
username = tokens[access_token] if access_token else None
username = shared.get_webui_username(request)

train_args = {
'hypernetwork_settings': {
Expand Down Expand Up @@ -2224,14 +2160,8 @@ def sagemaker_train_hypernetwork(
save_userdata_btn = gr.Button(value="Save")

def save_userdata(user_dataframe, request: gr.Request):
tokens = shared.demo.server_app.tokens
cookies = shared.get_cookies(request)
access_token = None
for cookie in cookies:
if cookie.startswith('access-token'):
access_token = cookie[len('access-token=') : ]
break
if not access_token or tokens[access_token] != 'admin':
username = shared.get_webui_username(request)
if username == 'admin':
return gr.update()
items = []
for user_df in user_dataframe:
Expand Down Expand Up @@ -2373,14 +2303,7 @@ def user_logout(request: gr.Request):
component_keys = [k for k in opts.data_labels.keys() if k in component_dict]

def demo_load(request: gr.Request):
tokens = shared.demo.server_app.tokens
cookies = shared.get_cookies(request)
access_token = None
for cookie in cookies:
if cookie.startswith('access-token'):
access_token = cookie[len('access-token=') : ]
break
username = tokens[access_token] if access_token else None
username = shared.get_webui_username(request)

inputs = {
'action': 'load'
Expand Down

0 comments on commit 003d7d7

Please sign in to comment.