From dd88ce652d627b2a60f9bc393e1d868462bac438 Mon Sep 17 00:00:00 2001 From: xie river Date: Thu, 13 Apr 2023 08:22:36 +0000 Subject: [PATCH] fix s3 path error --- localizations/zh_CN.json | 1 + modules/ui.py | 11 ++++++++--- webui.py | 6 ++++-- 3 files changed, 13 insertions(+), 5 deletions(-) diff --git a/localizations/zh_CN.json b/localizations/zh_CN.json index 674b3505d36..83f1f64e052 100644 --- a/localizations/zh_CN.json +++ b/localizations/zh_CN.json @@ -852,5 +852,6 @@ "Input S3 path of images":"输入图片的S3路径", "Submit":"确定", "columns width":"每行图片列数", + "Show current user's images only":"只显示当前用户图片集", "--------": "--------" } diff --git a/modules/ui.py b/modules/ui.py index 1b8bf48587b..ef18c9fe776 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -735,7 +735,10 @@ def list_objects(bucket,prefix=''): objects = response['Contents'] if response.get('Contents') else [] return [obj['Key'] for obj in objects] - def image_viewer(path,cols_width,request:gr.Request): + def image_viewer(path,cols_width,current_only,request:gr.Request): + if current_only: + username = get_webui_username(request) + path = path+'/'+username dirs = path.replace('s3://','').split('/') prefix = '/'.join(dirs[1:]) bucket = dirs[0] @@ -755,14 +758,16 @@ def image_viewer(path,cols_width,request:gr.Request): with gr.Blocks(analytics_enabled=False) as imagesviewer_interface: with gr.Row(): with gr.Column(scale=3): - images_s3_path = gr.Textbox(label="Input S3 path of images",value = get_default_sagemaker_bucket()+'/output-images') + images_s3_path = gr.Textbox(label="Input S3 path of images",value = get_default_sagemaker_bucket()+'/stable-diffusion-webui/generated') + with gr.Column(scale=1): + show_user_only = gr.Checkbox(label="Show current user's images only", value=True) with gr.Column(scale=1): cols_width = gr.Slider(minimum=4, maximum=20, step=1, label="columns width", value=8) with gr.Column(scale=1): images_s3_path_btn = gr.Button(value="Submit",variant='primary') with gr.Row(): result = gr.HTML("
") - images_s3_path_btn.click(fn=image_viewer, inputs=[images_s3_path,cols_width], outputs=[result]) + images_s3_path_btn.click(fn=image_viewer, inputs=[images_s3_path,cols_width,show_user_only], outputs=[result]) ## end diff --git a/webui.py b/webui.py index e5f9fdc74fd..0bf375ea931 100644 --- a/webui.py +++ b/webui.py @@ -242,6 +242,7 @@ def check_space_s3_download(s3_client,bucket_name,s3_folder,local_folder,file,si # s3_client = boto3.client('s3') src = s3_folder + '/' + file dist = os.path.join(local_folder, file) + os.makedirs(os.path.dirname(dist), exist_ok=True) # Get disk usage statistics disk_usage = psutil.disk_usage('/tmp') freespace = disk_usage.free/(1024**3) @@ -508,11 +509,12 @@ def register_sd_models(sd_models_dir): for file in get_models(sd_models_dir, ['*.ckpt', '*.safetensors']): hash = modules.sd_models.model_hash(file) item = {} - item['model_name'] = os.path.basename(file) + ##remove the prefix, but remain the sub dir path in model_name. eg. 'river/jp-style-girl-2_200_lora.safetensors [4d3a456f]' + item['model_name'] = file.replace("/tmp/models/Stable-diffusion/",'') item['hash'] = hash item['filename'] = file item['config'] = '/opt/ml/code/stable-diffusion-webui/repositories/stable-diffusion/configs/stable-diffusion/v1-inference.yaml' - item['title'] = '{0} [{1}]'.format(os.path.basename(file), hash) + item['title'] = '{0} [{1}]'.format(file.replace("/tmp/models/Stable-diffusion/",''), hash) item['endpoint_name'] = endpoint_name items.append(item) inputs = {