Skip to content

Commit

Permalink
revise multi-user support
Browse files Browse the repository at this point in the history
  • Loading branch information
xieyongliang committed Dec 5, 2022
1 parent 18815ab commit b1d9758
Show file tree
Hide file tree
Showing 9 changed files with 203 additions and 175 deletions.
9 changes: 5 additions & 4 deletions launch.py
Expand Up @@ -149,6 +149,7 @@ def prepare_enviroment():
sys.argv += shlex.split(commandline_args)
test_argv = [x for x in sys.argv if x != '--tests']

sys.argv, skip_torch_cuda = extract_arg(sys.argv, '--skip-torch-cuda')
sys.argv, skip_torch_cuda_test = extract_arg(sys.argv, '--skip-torch-cuda-test')
sys.argv, reinstall_xformers = extract_arg(sys.argv, '--reinstall-xformers')
sys.argv, update_check = extract_arg(sys.argv, '--update-check')
Expand All @@ -164,11 +165,11 @@ def prepare_enviroment():

print(f"Python {sys.version}")
print(f"Commit hash: {commit}")
if not is_installed("torch") or not is_installed("torchvision"):

if not skip_torch_cuda and (not is_installed("torch") or not is_installed("torchvision")):
run(f'"{python}" -m {torch_command}', "Installing torch and torchvision", "Couldn't install torch")

if not skip_torch_cuda_test:
if not skip_torch_cuda and not skip_torch_cuda_test:
run_python("import torch; assert torch.cuda.is_available(), 'Torch is not able to use GPU; add --skip-torch-cuda-test to COMMANDLINE_ARGS variable to disable this check'")

if not is_installed("gfpgan"):
Expand Down Expand Up @@ -206,7 +207,7 @@ def prepare_enviroment():
if not is_installed("lpips"):
run_pip(f"install -r {os.path.join(repo_dir('CodeFormer'), 'requirements.txt')}", "requirements for CodeFormer")

run_pip(f"install -r {requirements_file}", "requirements for Web UI")
#run_pip(f"install -r {requirements_file}", "requirements for Web UI")

run_extensions_installers()

Expand Down
30 changes: 29 additions & 1 deletion localizations/zh_CN.json
Expand Up @@ -618,7 +618,35 @@
"favorites": "收藏夹(已保存)",
"others": "其他",
"Collect": "收藏(保存)",

"Create & Train Embedding": "创建并训练 Embedding",
"Train an embedding; you must specify a directory with a set of 1:1 ratio images": "训练 embedding; 必须指定一组具有 1:1 比例图像的目录",
"Embedding settings": "Embedding 设置",
"Image preprocess settings": "图像预处理设置",
"Train settings": "训练设置",
"Create & Train Hypernetwork": "创建并训练 Hypernetwork",
"Train an hypernetwork; you must specify a directory with a set of 1:1 ratio images": "训练 hypernetwork; 必须指定一组具有 1:1 比例图像的目录",
"Hypernetwork settings": "Hypernetwork 设置",
"Sign Options": "登陆选项",
"Sign In": "登入",
"Sign Up": "注册",
"Sign Out": "登出",
"Username": "用户名",
"Password": "密码",
"Email": "电子邮箱",
"Update": "更新",
"Delete": "删除",
"Mismatched username/password or not existed username": "用户名/密码不匹配或用户不存在",
"Signup failed, please check and retry again": "注册失败,请检查后并重试",
"Update failed, please check and retry again": "更新失败,请检查后并重试",
"Output": "输出",
"Images S3 URI": "图像 S3 位置",
"Models S3 URI": "模型 S3 位置",
"Instance type": "实例类型",
"Instance count": "实例数量",
"Submit training job sucessful": "训练任务提交成功",
"Settings saved failed": "设置保存错误",
"SageMaker endpoint": "SageMaker 端点",
"User": "用户",

"--------": "--------"
}
203 changes: 117 additions & 86 deletions localizations/zh_TW.json

Large diffs are not rendered by default.

3 changes: 2 additions & 1 deletion modules/api/api.py
Expand Up @@ -18,7 +18,7 @@
import json
import os
import boto3
from modules import sd_hijack, hypernetworks
from modules import sd_hijack, hypernetworks, sd_models
from typing import Union
import traceback
import requests
Expand Down Expand Up @@ -359,6 +359,7 @@ def invocations(self, req: InvocationsRequest):
response = requests.post(url=f'{api_endpoint}/sd/user', json=inputs)
if response.status_code == 200 and response.text != '':
shared.opts.data = json.loads(response.text)
sd_models.reload_model_weights()

self.download_s3files(hypernetwork_s3uri, os.path.join(script_path, shared.cmd_opts.hypernetwork_dir))
hypernetworks.hypernetwork.load_hypernetwork(shared.opts.sd_hypernetwork)
Expand Down
2 changes: 2 additions & 0 deletions modules/sd_models.py
Expand Up @@ -311,6 +311,8 @@ def reload_model_weights(sd_model=None, info=None):
if not sd_model:
sd_model = shared.sd_model

print('Origin checkpoint: ', sd_model.sd_model_checkpoint)
print('Current checkpoint: ', checkpoint_info.filename)
if sd_model.sd_model_checkpoint == checkpoint_info.filename:
return

Expand Down
89 changes: 43 additions & 46 deletions modules/shared.py
Expand Up @@ -136,8 +136,6 @@
username = ''
api_endpoint = os.environ['api_endpoint']
industrial_model = ''
endpoint_name = ''
endpoint_names = []
default_options = {}

def reload_hypernetworks():
Expand Down Expand Up @@ -268,7 +266,49 @@ def options_section(section_identifier, options_dict):

options_templates = {}

def refresh_sagemaker_endpoints():
global industrial_model, api_endpoint, default_options

if industrial_model == '':
response = requests.get(url=f'{api_endpoint}/sd/industrialmodel')
if response.status_code == 200:
industrial_model = response.text
else:
model_name = 'stable-diffusion-webui'
model_description = model_name
inputs = {
'model_algorithm': 'stable-diffusion-webui',
'model_name': model_name,
'model_description': model_description,
'model_extra': '{"visible": "false"}',
'model_samples': '',
'file_content': {
'data': [(lambda x: int(x))(x) for x in open(os.path.join(script_path, 'logo.ico'), 'rb').read()]
}
}

response = requests.post(url=f'{api_endpoint}/industrialmodel', json = inputs)
if response.status_code == 200:
body = json.loads(response.text)
industrial_model = body['id']

default_options = self.data

sagemaker_endpoints = []

if industrial_model != '':
params = {
'industrial_model': industrial_model
}
response = requests.get(url=f'{api_endpoint}/endpoint', params=params)
if response.status_code == 200:
for endpoint_item in json.loads(response.text):
sagemaker_endpoints.append(endpoint_item['EndpointName'])

return sagemaker_endpoints

options_templates.update(options_section(('sd', "Stable Diffusion"), {
"sagemaker_endpoint": OptionInfo(None, "SaegMaker endpoint", gr.Dropdown, lambda: {"choices": refresh_sagemaker_endpoints()}, refresh=refresh_sagemaker_endpoints),
"sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, refresh=sd_models.list_models),
"sd_checkpoint_cache": OptionInfo(0, "Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
"sd_vae": OptionInfo("auto", "SD VAE", gr.Dropdown, lambda: {"choices": list(sd_vae.vae_list)}, refresh=sd_vae.refresh_vae_list),
Expand Down Expand Up @@ -393,7 +433,7 @@ def options_section(section_identifier, options_dict):
"js_modal_lightbox": OptionInfo(True, "Enable full page image viewer"),
"js_modal_lightbox_initially_zoomed": OptionInfo(True, "Show images zoomed in by default in full page image viewer"),
"show_progress_in_title": OptionInfo(True, "Show generation progress in window title."),
'quicksettings': OptionInfo("sd_model_checkpoint", "Quicksettings list"),
'quicksettings': OptionInfo("sagemaker_endpoint", "Quicksettings list"),
'localization': OptionInfo("None", "Localization (requires restart)", gr.Dropdown, lambda: {"choices": ["None"] + list(localization.localizations.keys())}, refresh=lambda: localization.list_localizations(cmd_opts.localizations_dir)),
}))

Expand Down Expand Up @@ -480,34 +520,6 @@ def load(self, filename):
if bad_settings > 0:
print(f"The program is likely to not work with bad settings.\nSettings file: {filename}\nEither fix the file, or delete it and restart.", file=sys.stderr)

if cmd_opts.pureui:
global api_endpoint, industrial_model, default_options

#opts.show_progressbar = False
response = requests.get(url=f'{api_endpoint}/sd/industrialmodel')
if response.status_code == 200:
industrial_model = response.text
else:
model_name = 'stable-diffusion-webui'
model_description = model_name
inputs = {
'model_algorithm': 'stable-diffusion-webui',
'model_name': model_name,
'model_description': model_description,
'model_extra': '{"visible": "false"}',
'model_samples': '',
'file_content': {
'data': [(lambda x: int(x))(x) for x in open(os.path.join(script_path, 'logo.ico'), 'rb').read()]
}
}

response = requests.post(url=f'{api_endpoint}/industrialmodel', json = inputs)
if response.status_code == 200:
body = json.loads(response.text)
industrial_model = body['id']

default_options = self.data

def onchange(self, key, func, call=True):
item = self.data_labels.get(key)
item.onchange = func
Expand Down Expand Up @@ -587,18 +599,3 @@ def clear(self):
def listfiles(dirname):
filenames = [os.path.join(dirname, x) for x in sorted(os.listdir(dirname)) if not x.startswith(".")]
return [file for file in filenames if os.path.isfile(file)]

if cmd_opts.pureui:
def init_endpoints():
global endpoint_name, endpoint_names, industrial_model, api_endpoint

endpoints = []
params = {
'industrial_model': industrial_model
}
response = requests.get(url=f'{api_endpoint}/endpoint', params=params)
if response.status_code == 200:
for endpoint_item in json.loads(response.text):
endpoints.append(endpoint_item['EndpointName'])
endpoint_name = endpoints[0] if len(endpoints) > 0 else ''
endpoint_names = endpoints
30 changes: 0 additions & 30 deletions modules/ui.py
Expand Up @@ -2331,36 +2331,6 @@ def user_delete(login_username, login_password, login_email):
component = create_setting_component(k, is_quicksettings=True)
component_dict[k] = component

if cmd_opts.pureui:
shared.init_endpoints()

with gr.Row():
with gr.Column(scale=9):
endpoint_names = gr.Dropdown(label='SageMaker endpoint', value=shared.endpoint_name, choices=shared.endpoint_names)
with gr.Column(scale=1):
endpoint_refresh = gr.Button(refresh_symbol)

def refresh_endpoint():
shared.init_endpoints()
return {
endpoint_names: gr.update(value=shared.endpoint_name, choices=shared.endpoint_names)
}

def change_endpoint(endpoint_names):
shared.endpoint_name = endpoint_names

endpoint_names.change(
fn=change_endpoint,
inputs=[endpoint_names],
outputs=[]
)

endpoint_refresh.click(
fn=refresh_endpoint,
inputs=[],
outputs=[endpoint_names]
)

parameters_copypaste.integrate_settings_paste_fields(component_dict)
parameters_copypaste.run_bind()

Expand Down
2 changes: 1 addition & 1 deletion style.css
Expand Up @@ -501,7 +501,7 @@ input[type="range"]{
padding: 0;
}

#refresh_sd_model_checkpoint, #refresh_sd_vae, #refresh_sd_hypernetwork, #refresh_train_hypernetwork_name, #refresh_train_embedding_name, #refresh_localization{
#refresh_sagemaker_endpoint, #refresh_sd_model_checkpoint, #refresh_sd_vae, #refresh_sd_hypernetwork, #refresh_train_hypernetwork_name, #refresh_train_embedding_name, #refresh_localization{
max-width: 2.5em;
min-width: 2.5em;
height: 2.4em;
Expand Down
10 changes: 4 additions & 6 deletions webui.py
Expand Up @@ -286,7 +286,7 @@ def sagemaker_inference(task, infer, *args, **kwargs):
}

params = {
'endpoint_name': shared.endpoint_name
'endpoint_name': shared.opts.sagemaker_endpoint
}

response = requests.post(url=f'{shared.api_endpoint}/inference', params=params, json=inputs)
Expand Down Expand Up @@ -384,7 +384,7 @@ def sagemaker_inference(task, infer, *args, **kwargs):
}

params = {
'endpoint_name': shared.endpoint_name
'endpoint_name': shared.opts.sagemaker_endpoint
}
response = requests.post(url=f'{shared.api_endpoint}/inference', params=params, json=inputs)
if infer == 'async':
Expand Down Expand Up @@ -440,6 +440,7 @@ def initialize():
modules.scripts.load_scripts()

modules.sd_vae.refresh_vae_list()

shared.opts.onchange("sd_vae", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False)

if not cmd_opts.pureui:
Expand Down Expand Up @@ -593,10 +594,7 @@ def train():
response = requests.post(url=f'{api_endpoint}/sd/user', json=inputs)
if response.status_code == 200 and response.text != '':
opts.data = json.loads(response.text)
for key in modules.sd_models.checkpoints_list:
if modules.sd_models.checkpoints_list[key].title == opts.data['sd_model_checkpoint']:
shared.sd_model.sd_model_name = modules.sd_models.checkpoints_list[key].model_name
break
modules.sd_models.load_model()

if train_task == 'embedding':
name = train_args['embedding_settings']['name']
Expand Down

0 comments on commit b1d9758

Please sign in to comment.