Skip to content

Commit

Permalink
cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
xieyongliang committed Mar 22, 2023
1 parent eefad18 commit ad9b78f
Showing 1 changed file with 19 additions and 8 deletions.
27 changes: 19 additions & 8 deletions webui.py
Expand Up @@ -39,7 +39,6 @@
import traceback
from botocore.exceptions import ClientError
import requests
import io
import json
import uuid
if not cmd_opts.api:
Expand Down Expand Up @@ -579,6 +578,7 @@ def train():
db_new_model_extract_ema = train_args['train_dreambooth_settings']['db_new_model_extract_ema']
db_train_unfrozen = train_args['train_dreambooth_settings']['db_train_unfrozen']
db_512_model = train_args['train_dreambooth_settings']['db_512_model']
db_save_safetensors = train_args['train_dreambooth_settings']['db_save_safetensors']

db_model_name, db_model_path, db_revision, db_epochs, db_scheduler, db_src, db_has_ema, db_v2, db_resolution = create_model(
db_new_model_name,
Expand Down Expand Up @@ -701,18 +701,29 @@ def train():
db_model_dir = os.path.dirname(cmd_dreambooth_models_path) if cmd_dreambooth_models_path else paths.models_path
db_model_dir = os.path.join(db_model_dir, "dreambooth")

lora_models_path = os.path.join(shared.models_path, "Lora")
lora_models_path = os.path.join(shared.models_path, "lora")

print('---models path---', sd_models_path, lora_models_path)
os.system(f'ls -l {sd_models_path}')
os.system('ls -l {0}'.format(os.path.join(sd_models_path, db_model_name)))
os.system(f'ls -l {lora_models_path}')

try:
print('Uploading SD Models...')
upload_s3files(
sd_models_s3uri,
os.path.join(sd_models_path, f'{sd_models_path}/{db_model_name}_*.yaml')
)
upload_s3files(
sd_models_s3uri,
os.path.join(sd_models_path, f'{sd_models_path}/{db_model_name}_*.ckpt')
os.path.join(sd_models_path, db_model_name, f'{db_model_name}_*.yaml')
)
if db_save_safetensors:
upload_s3files(
sd_models_s3uri,
os.path.join(sd_models_path, db_model_name, f'{db_model_name}_*.safetensors')
)
else:
upload_s3files(
sd_models_s3uri,
os.path.join(sd_models_path, db_model_name, f'{db_model_name}_*.ckpt')
)
print('Uploading DB Models...')
upload_s3folder(
f'{db_models_s3uri}{db_model_name}',
Expand All @@ -722,7 +733,7 @@ def train():
print('Uploading Lora Models...')
upload_s3files(
lora_models_s3uri,
os.path.join(lora_models_path, f'{lora_models_path}/{db_model_name}_*.pt')
os.path.join(lora_models_path, f'{db_model_name}_*.pt')
)
except Exception as e:
traceback.print_exc()
Expand Down

0 comments on commit ad9b78f

Please sign in to comment.