Skip to content

Commit

Permalink
dynamic loading models
Browse files Browse the repository at this point in the history
  • Loading branch information
xie river committed Apr 11, 2023
1 parent 6445c3c commit 972a04f
Show file tree
Hide file tree
Showing 7 changed files with 88 additions and 132 deletions.
2 changes: 1 addition & 1 deletion localizations/zh_CN.json
Original file line number Diff line number Diff line change
Expand Up @@ -844,7 +844,7 @@
"Output S3 folder":"S3文件夹目录",
"Upload Train Images to S3":"上传训练图片到S3",
"Error, please configure a S3 bucket at settings page first":"失败,请先到设置页面配置S3桶名",
"Upload":"上传",
"Upload Images":"上传图片",
"Reload all models":"重新加载模型文件",
"Update model files path":"更新模型加载路径",
"S3 path for downloading model files (E.g, s3://bucket-name/models/)":"加载模型的S3路径,例如:s3://bucket-name/models/",
Expand Down
2 changes: 1 addition & 1 deletion modules/call_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -450,7 +450,7 @@ def f(request: gr.Request, *args, extra_outputs_array=extra_outputs, **kwargs):
t = time.perf_counter()

try:
if func.__name__ == 'f' or func.__name__ == 'run_settings':
if func.__name__ == 'f' or func.__name__ == 'run_settings' or func.__name__ == 'save_files':
res = list(func(username, *args, **kwargs))
else:
res = list(func(*args, **kwargs))
Expand Down
6 changes: 4 additions & 2 deletions modules/paths.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@
import os
import sys
import modules.safe

script_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
models_path = os.path.join(script_path, "models")
## Change by River
# models_path = os.path.join(script_path, "models")
models_path = '/tmp/models'
##
sys.path.insert(0, script_path)

# search for directory of stable diffusion in following places
Expand Down
9 changes: 7 additions & 2 deletions modules/scripts.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ def load_scripts():
script_callbacks.clear_callbacks()

scripts_list = list_scripts("scripts", ".py")

print('scripts_list:',scripts_list)
syspath = sys.path

for scriptfile in sorted(scripts_list):
Expand All @@ -203,6 +203,7 @@ def load_scripts():
finally:
sys.path = syspath
current_basedir = paths.script_path
print('scripts_data',scripts_data)


def wrap_call(func, filename, funcname, *args, default=None, **kwargs):
Expand All @@ -225,6 +226,9 @@ def __init__(self):
self.infotext_fields = []

def initialize_scripts(self, is_img2img):
print('----initialize_scripts()------')
print(f'--scripts_data--{scripts_data}')
traceback.print_stack()
self.scripts.clear()
self.alwayson_scripts.clear()
self.selectable_scripts.clear()
Expand Down Expand Up @@ -316,7 +320,8 @@ def run(self, p: StableDiffusionProcessing, *args):

if script_index == 0:
return None

print('self.selectable_scripts:',self.selectable_scripts)
print('script_index:',script_index)
script = self.selectable_scripts[script_index-1]

if script is None:
Expand Down
3 changes: 2 additions & 1 deletion modules/shared.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
s3_folder_cn = None
syncLock = threading.Lock()
tmp_models_dir = '/tmp/models'
tmp_cache_dir = '/tmp/cache'
tmp_cache_dir = '/tmp/model_sync_cache'
#end

sd_model_file = os.path.join(script_path, 'model.ckpt')
Expand Down Expand Up @@ -514,6 +514,7 @@ def refresh_sagemaker_endpoints(username):


options_templates.update(options_section(('sd', "Stable Diffusion"), {
# "models_s3_bucket": OptionInfo(f'{get_default_sagemaker_bucket()}/stable-diffusion-webui/models/', "S3 path for downloading model files (E.g, s3://bucket-name/models/)", ),
"sagemaker_endpoint": OptionInfo(None, "SaegMaker endpoint", gr.Dropdown, lambda: {"choices": list_sagemaker_endpoints()}, refresh=refresh_sagemaker_endpoints),
"sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": list_checkpoint_tiles()}, refresh=refresh_checkpoints),
"sd_checkpoint_cache": OptionInfo(0, "Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
Expand Down
79 changes: 50 additions & 29 deletions modules/ui.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,10 +89,22 @@ def gr_show(visible=True):
## Begin output images uploaded to s3 by River
s3_resource = boto3.resource('s3')

def save_images_to_s3(full_fillnames,timestamp):
username = shared.username
def get_webui_username(request):
tokens = shared.demo.server_app.tokens
cookies = request.headers['cookie'].split('; ')
access_token = None
for cookie in cookies:
if cookie.startswith('access-token'):
access_token = cookie[len('access-token=') : ]
break
username = tokens[access_token] if access_token else None
return username

def save_images_to_s3(full_fillnames,timestamp,username):
sagemaker_endpoint = shared.opts.sagemaker_endpoint
bucket_name = opts.train_files_s3bucket
bucket_name = opts.train_files_s3bucket.replace('s3://','')
if bucket_name.endswith('/'):
bucket_name= bucket_name[:-1]
if bucket_name == '':
return 'Error, please configure a S3 bucket at settings page first'
s3_bucket = s3_resource.Bucket(bucket_name)
Expand Down Expand Up @@ -134,6 +146,10 @@ def save_images_to_s3(full_fillnames,timestamp):
save_style_symbol = '\U0001f4be' # 💾
apply_style_symbol = '\U0001f4cb' # 📋

def text_to_hyperlink_html(url):
text= f'<p><a target="_blank" href="{url}">{url}</a></p>'
return text

def plaintext_to_html(text):
text = "<p>" + "<br>\n".join([f"{html.escape(x)}" for x in text.split('\n')]) + "</p>"
return text
Expand All @@ -143,7 +159,7 @@ def send_gradio_gallery_to_image(x):
return None
return image_from_url_text(x[0])

def save_files(js_data, images, do_make_zip, index):
def save_files(username,js_data, images, do_make_zip, index):
import csv
filenames = []
fullfns = []
Expand Down Expand Up @@ -197,8 +213,7 @@ def __init__(self, d=None):

timestamp = datetime.now(timezone(timedelta(hours=+8))).strftime('%Y-%m-%dT%H:%M:%S')
logfile = os.path.join(opts.outdir_save, "log.csv")
s3folder = save_images_to_s3(fullfns,timestamp)
save_images_to_s3([logfile],timestamp)
s3folder = save_images_to_s3(fullfns+[logfile],timestamp,username)
# Make Zip
if do_make_zip:
zip_filepath = os.path.join(path, "images.zip")
Expand All @@ -210,7 +225,7 @@ def __init__(self, d=None):
zip_file.writestr(filenames[i], f.read())
fullfns.insert(0, zip_filepath)

return gr.File.update(value=fullfns, visible=True), '', '', plaintext_to_html(f"Saved: {filenames[0]}, \nS3 folder:\n{s3folder}")
return gr.File.update(value=fullfns, visible=True), '', '', plaintext_to_html(f"Saved: {filenames[0]}"),text_to_hyperlink_html(s3folder)



Expand Down Expand Up @@ -683,13 +698,14 @@ def open_folder(f):
generation_info,
result_gallery,
do_make_zip,
html_info,
html_info
],
outputs=[
download_files,
html_info,
html_info,
html_info,
html_info
]
)
else:
Expand All @@ -711,20 +727,20 @@ def create_ui():

interfaces = []

##add images viewer
def translate(text):
return f'translated:{text}'
with gr.Blocks(analytics_enabled=False) as imagesviewer_interface:
with gr.Row().style(equal_height=False):
with gr.Column():
english = gr.Textbox(label="Placeholder")
translate_btn = gr.Button(value="Translate")
with gr.Column():
german = gr.Textbox(label="German Text")

translate_btn.click(translate, inputs=english, outputs=german, api_name="translate-to-german")
examples = gr.Examples(examples=["I went to the supermarket yesterday.", "Helen is a good swimmer."],
inputs=[english])
##add River
# def translate(text):
# return f'translated:{text}'
# with gr.Blocks(analytics_enabled=False) as imagesviewer_interface:
# with gr.Row().style(equal_height=False):
# with gr.Column():
# english = gr.Textbox(label="Placeholder")
# translate_btn = gr.Button(value="Translate")
# with gr.Column():
# german = gr.Textbox(label="German Text")

# translate_btn.click(translate, inputs=english, outputs=german, api_name="translate-to-german")
# examples = gr.Examples(examples=["I went to the supermarket yesterday.", "Helen is a good swimmer."],
# inputs=[english])

with gr.Blocks(analytics_enabled=False) as pnginfo_interface:
with gr.Row().style(equal_height=False):
Expand Down Expand Up @@ -880,13 +896,14 @@ def run_settings_single(value, key, request : gr.Request):
with gr.Row():
settings_submit = gr.Button(value="Apply settings", variant='primary')
with gr.Row():
with gr.Column(scale=2):
with gr.Column(scale=4):
models_s3bucket = gr.Textbox(label="S3 path for downloading model files (E.g, s3://bucket-name/models/)",
value=default_s3_path)
with gr.Column(scale=1):
set_models_s3bucket_btn = gr.Button(value="Update model files path",elem_id='id_set_models_s3bucket')
with gr.Column(scale=1):
reload_models_btn = gr.Button(value='Reload all models', elem_id='id_reload_all_models')



result = gr.HTML()
Expand Down Expand Up @@ -996,7 +1013,8 @@ def reload_all_models():
inputs=[],
outputs=[result]
)


# River
def set_models_s3bucket(bucket_name):
if bucket_name == '':
return 'Error, please configure a S3 bucket for downloading model files'
Expand Down Expand Up @@ -1539,10 +1557,13 @@ def update_orig(image, state):
with gr.Row().style(equal_height=False):
with gr.Tabs(elem_id="train_tabs"):
## Begin add s3 images upload interface by River
def upload_to_s3(imgs):
username = shared.username
def upload_to_s3(imgs,request : gr.Request):
username = get_webui_username(request)
print (f'--get_webui_username--:{username}')
timestamp = datetime.now(timezone(timedelta(hours=+8))).strftime('%Y-%m-%dT%H:%M:%S')
bucket_name = opts.train_files_s3bucket
bucket_name = opts.train_files_s3bucket.replace('s3://','')
if bucket_name.endswith('/'):
bucket_name= bucket_name[:-1]
if bucket_name == '':
return 'Error, please configure a S3 bucket at settings page first'
s3_bucket = s3_resource.Bucket(bucket_name.replace('s3://',''))
Expand All @@ -1561,7 +1582,7 @@ def upload_to_s3(imgs):
with gr.Tab(label="Upload Train Images to S3"):
upload_files = gr.Files(label="Files")
url_output = gr.Textbox(label="Output S3 folder")
sub_btn = gr.Button(label="Upload",variant='primary',elem_id='id_upload_train_files')
sub_btn = gr.Button(value="Upload Images",elem_id='id_upload_train_images',variant='primary')
sub_btn.click(fn=upload_to_s3, inputs=upload_files, outputs=url_output)
## End add s3 images upload interface by River
with gr.Tab(label="Train Embedding"):
Expand Down Expand Up @@ -2194,7 +2215,7 @@ def save_userdata(user_dataframe, request: gr.Request):

# interfaces += script_callbacks.ui_tabs_callback()
interfaces += [(settings_interface, "Settings", "settings")]
interfaces += [(imagesviewer_interface,"Images Viewer","imagesviewer")]
# interfaces += [(imagesviewer_interface,"Images Viewer","imagesviewer")]

extensions_interface = ui_extensions.create_ui()
interfaces += [(extensions_interface, "Extensions", "extensions")]
Expand Down
Loading

0 comments on commit 972a04f

Please sign in to comment.