Skip to content

Commit

Permalink
fix issues with username and training embedding/hypernetwork and cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
xieyongliang committed Apr 9, 2023
1 parent 53de970 commit 18cf813
Show file tree
Hide file tree
Showing 4 changed files with 41 additions and 56 deletions.
3 changes: 2 additions & 1 deletion modules/sd_models.py
Expand Up @@ -81,7 +81,8 @@ def modeltitle(path, shorthash):
if shared.cmd_opts.pureui:
if sagemaker_endpoint:
params = {
'module': 'Stable-diffusion', 'endpoint_name': sagemaker_endpoint
'module': 'Stable-diffusion',
'endpoint_name': sagemaker_endpoint
}
response = requests.get(url=f'{api_endpoint}/sd/models', params=params)
if response.status_code == 200:
Expand Down
4 changes: 1 addition & 3 deletions modules/shared.py
Expand Up @@ -144,7 +144,6 @@
sagemaker_endpoint_component = None
sd_model_checkpoint_component = None
create_train_dreambooth_component = None
username = ''
else:
api_endpoint = cmd_opts.api_endpoint

Expand Down Expand Up @@ -349,8 +348,7 @@ def refresh_sagemaker_endpoints(username):

if industrial_model != '':
params = {
'industrial_model': industrial_model,
'username': username
'industrial_model': industrial_model
}
response = requests.get(url=f'{api_endpoint}/endpoint', params=params)
if response.status_code == 200:
Expand Down
70 changes: 38 additions & 32 deletions modules/ui.py
Expand Up @@ -670,28 +670,6 @@ def open_folder(f):
parameters_copypaste.bind_buttons(buttons, result_gallery, "txt2img" if tabname == "txt2img" else None)
return result_gallery, generation_info if tabname != "extras" else html_info_x, html_info

def update_sagemaker_endpoint():
return gr.update(value=shared.opts.sagemaker_endpoint, choices=shared.sagemaker_endpoints)

def update_sd_model_checkpoint():
return gr.update(value=shared.opts.sd_model_checkpoint, choices=modules.sd_models.checkpoint_tiles())

def update_username():
if shared.username == 'admin':
inputs = {
'action': 'load'
}
response = requests.post(url=f'{shared.api_endpoint}/sd/user', json=inputs)
if response.status_code == 200:
items = []
for item in json.loads(response.text):
items.append([item['username'], item['password'], item['options'] if 'options' in item else '', shared.get_available_sagemaker_endpoints(item)])
return gr.update(value=shared.username), gr.update(value=items if items != [] else None)
else:
return gr.update(value=shared.username), gr.update()
else:
return gr.update(value=shared.username), gr.update()

def create_ui():
import modules.img2img
import modules.txt2img
Expand Down Expand Up @@ -1455,8 +1433,6 @@ def update_orig(image, state):
with gr.Column(variant='panel'):
submit_result = gr.Textbox(elem_id="modelmerger_result", show_label=False)

sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings()

with gr.Blocks(analytics_enabled=False) as train_interface:
with gr.Row().style(equal_height=False):
gr.HTML(value="<p style='margin-bottom: 0.7em'>See <b><a href=\"https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Textual-Inversion\">wiki</a></b> for detailed explanation.</p>")
Expand Down Expand Up @@ -2040,9 +2016,6 @@ def save_userdata(user_dataframe, request: gr.Request):
_js="var if alert('Only admin user can save user data')"
)

user_interface.load(update_sagemaker_endpoint, inputs=None, outputs=[shared.sagemaker_endpoint_component])
user_interface.load(update_sd_model_checkpoint, inputs=None, outputs=[shared.sd_model_checkpoint_component])

if cmd_opts.pureui:
interfaces += [
(txt2img_interface, "txt2img", "txt2img"),
Expand Down Expand Up @@ -2101,7 +2074,6 @@ def save_userdata(user_dataframe, request: gr.Request):
outputs=[username_state, user_dataframe],
_js="login"
)
user_interface.load(update_username, inputs=None, outputs=[username_state, user_dataframe])
with gr.Column(scale=1):
logout_button = gr.Button(value="Logout")

Expand Down Expand Up @@ -2152,13 +2124,47 @@ def user_logout(request: gr.Request):

component_keys = [k for k in opts.data_labels.keys() if k in component_dict]

def get_settings_values():
return [getattr(opts, key) for key in component_keys]
def demo_load(request: gr.Request):
tokens = shared.demo.server_app.tokens
cookies = request.headers['cookie'].split('; ')
access_token = None
for cookie in cookies:
if cookie.startswith('access-token'):
access_token = cookie[len('access-token=') : ]
break
username = tokens[access_token] if access_token else None

inputs = {
'action': 'load'
}
response = requests.post(url=f'{shared.api_endpoint}/sd/user', json=inputs)
if response.status_code == 200:
if username == 'admin':
items = []
for item in json.loads(response.text):
items.append([item['username'], item['password'], item['options'] if 'options' in item else '', shared.get_available_sagemaker_endpoints(item)])

additional_components = [gr.update(value=username), gr.update(value=items if items != [] else None), gr.update(), gr.update()]
else:
for item in json.loads(response.text):
if item['username'] == username:
try:
shared.opts.data = json.loads(item['options'])
break
except Exception as e:
print(e)
shared.refresh_sagemaker_endpoints(username)
shared.refresh_checkpoints(shared.opts.sagemaker_endpoint)
additional_components = [gr.update(value=username), gr.update(), gr.update(value=shared.opts.sagemaker_endpoint, choices=shared.sagemaker_endpoints), gr.update(value=shared.opts.sd_model_checkpoint, choices=modules.sd_models.checkpoint_tiles())]
else:
additional_components = [gr.update(value=username), gr.update(), gr.update(), gr.update()]

return [getattr(opts, key) for key in component_keys] + additional_components

demo.load(
fn=get_settings_values,
fn=demo_load,
inputs=[],
outputs=[component_dict[k] for k in component_keys],
outputs=[component_dict[k] for k in component_keys] + [username_state, user_dataframe, shared.sagemaker_endpoint_component, shared.sd_model_checkpoint_component]
)

if not cmd_opts.pureui:
Expand Down
20 changes: 0 additions & 20 deletions webui.py
Expand Up @@ -160,26 +160,6 @@ def user_auth(username, password):

response = requests.post(url=f'{api_endpoint}/sd/login', json=inputs)

if response.status_code == 200:
try:
body = json.loads(response.text)
options = json.loads(json.loads(body)['options'])
except Exception as e:
print(e)
options = None

if options != None:
shared.opts.data = options

shared.refresh_sagemaker_endpoints(username)
shared.refresh_checkpoints(shared.opts.sagemaker_endpoint)
shared.username = username
modules.ui.update_sagemaker_endpoint()
modules.ui.update_sd_model_checkpoint()
modules.ui.update_username()
else:
print(response.text)

return response.status_code == 200

def webui():
Expand Down

0 comments on commit 18cf813

Please sign in to comment.