From ad9b78f47a155cb1d4436b5362a521b68cb40923 Mon Sep 17 00:00:00 2001 From: xieyongliang Date: Wed, 22 Mar 2023 23:37:54 +0800 Subject: [PATCH] cleanup --- webui.py | 27 +++++++++++++++++++-------- 1 file changed, 19 insertions(+), 8 deletions(-) diff --git a/webui.py b/webui.py index 0488c3f65f5..9a5c4e2a767 100644 --- a/webui.py +++ b/webui.py @@ -39,7 +39,6 @@ import traceback from botocore.exceptions import ClientError import requests -import io import json import uuid if not cmd_opts.api: @@ -579,6 +578,7 @@ def train(): db_new_model_extract_ema = train_args['train_dreambooth_settings']['db_new_model_extract_ema'] db_train_unfrozen = train_args['train_dreambooth_settings']['db_train_unfrozen'] db_512_model = train_args['train_dreambooth_settings']['db_512_model'] + db_save_safetensors = train_args['train_dreambooth_settings']['db_save_safetensors'] db_model_name, db_model_path, db_revision, db_epochs, db_scheduler, db_src, db_has_ema, db_v2, db_resolution = create_model( db_new_model_name, @@ -701,18 +701,29 @@ def train(): db_model_dir = os.path.dirname(cmd_dreambooth_models_path) if cmd_dreambooth_models_path else paths.models_path db_model_dir = os.path.join(db_model_dir, "dreambooth") - lora_models_path = os.path.join(shared.models_path, "Lora") + lora_models_path = os.path.join(shared.models_path, "lora") + + print('---models path---', sd_models_path, lora_models_path) + os.system(f'ls -l {sd_models_path}') + os.system('ls -l {0}'.format(os.path.join(sd_models_path, db_model_name))) + os.system(f'ls -l {lora_models_path}') try: print('Uploading SD Models...') upload_s3files( sd_models_s3uri, - os.path.join(sd_models_path, f'{sd_models_path}/{db_model_name}_*.yaml') - ) - upload_s3files( - sd_models_s3uri, - os.path.join(sd_models_path, f'{sd_models_path}/{db_model_name}_*.ckpt') + os.path.join(sd_models_path, db_model_name, f'{db_model_name}_*.yaml') ) + if db_save_safetensors: + upload_s3files( + sd_models_s3uri, + os.path.join(sd_models_path, db_model_name, f'{db_model_name}_*.safetensors') + ) + else: + upload_s3files( + sd_models_s3uri, + os.path.join(sd_models_path, db_model_name, f'{db_model_name}_*.ckpt') + ) print('Uploading DB Models...') upload_s3folder( f'{db_models_s3uri}{db_model_name}', @@ -722,7 +733,7 @@ def train(): print('Uploading Lora Models...') upload_s3files( lora_models_s3uri, - os.path.join(lora_models_path, f'{lora_models_path}/{db_model_name}_*.pt') + os.path.join(lora_models_path, f'{db_model_name}_*.pt') ) except Exception as e: traceback.print_exc()