Skip to content

Commit

Permalink
cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
xieyongliang committed Mar 25, 2023
1 parent 4642385 commit 3e4f422
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 38 deletions.
1 change: 0 additions & 1 deletion modules/shared.py
Expand Up @@ -99,7 +99,6 @@
parser.add_argument('--sd-models-s3uri', default='', type=str, help='SD Models S3Uri')
parser.add_argument('--db-models-s3uri', default='', type=str, help='DB Models S3Uri')
parser.add_argument('--lora-models-s3uri', default='', type=str, help='Lora Models S3Uri')
parser.add_argument('--region-name', type=str, help='Region Name')
parser.add_argument('--username', default='', type=str, help='Username')
parser.add_argument('--api-endpoint', default='', type=str, help='API Endpoint')
parser.add_argument('--dreambooth-config-id', default='', type=str, help='Dreambooth config ID')
Expand Down
72 changes: 35 additions & 37 deletions webui.py
@@ -1,6 +1,5 @@
import os
import shutil
import threading
import time
import importlib
import signal
Expand All @@ -16,7 +15,7 @@
from modules.call_queue import wrap_queued_call, queue_lock
from modules.paths import script_path

from modules import shared, devices, sd_samplers, upscaler, extensions, localization, ui_tempdir
from modules import shared, sd_samplers, upscaler, extensions, localization, ui_tempdir
import modules.codeformer_model as codeformer
import modules.extras
import modules.face_restoration
Expand Down Expand Up @@ -220,24 +219,23 @@ def webui():
if launch_api:
create_api(app)

ckpt_dir = cmd_opts.ckpt_dir
sd_models_path = os.path.join(shared.models_path, "Stable-diffusion")
if ckpt_dir is not None:
sd_models_path = ckpt_dir
cmd_sd_models_path = cmd_opts.ckpt_dir
sd_models_dir = os.path.join(shared.models_path, "Stable-diffusion")
if cmd_sd_models_path is not None:
sd_models_dir = cmd_sd_models_path

controlnet_dir = cmd_opts.controlnet_dir
cn_models_path = os.path.join(shared.models_path, "ControlNet")
os.makedirs(controlnet_dir, exist_ok=True)
if controlnet_dir is not None:
cn_models_path = controlnet_dir
cmd_controlnet_models_path = cmd_opts.controlnet_dir
cn_models_dir = os.path.join(shared.models_path, "ControlNet")
if cmd_controlnet_models_path is not None:
cn_models_dir = cmd_controlnet_models_path

if 'endpoint_name' in os.environ:
items = []
api_endpoint = os.environ['api_endpoint']
endpoint_name = os.environ['endpoint_name']
for file in os.listdir(sd_models_path):
if os.path.isfile(os.path.join(sd_models_path, file)) and (file.endswith('.ckpt') or file.endswith('.safesentors')):
hash = modules.sd_models.model_hash(os.path.join(sd_models_path, file))
for file in os.listdir(sd_models_dir):
if os.path.isfile(os.path.join(sd_models_dir, file)) and (file.endswith('.ckpt') or file.endswith('.safesentors')):
hash = modules.sd_models.model_hash(os.path.join(sd_models_dir, file))
item = {}
item['model_name'] = file
item['config'] = '/opt/ml/code/stable-diffusion-webui/repositories/stable-diffusion/configs/stable-diffusion/v1-inference.yaml'
Expand All @@ -263,10 +261,10 @@ def webui():
params = {
'module': 'ControlNet'
}
for file in os.listdir(cn_models_path):
if os.path.isfile(os.path.join(cn_models_path, file)) and \
for file in os.listdir(cn_models_dir):
if os.path.isfile(os.path.join(cn_models_dir, file)) and \
(file.endswith('pt') or file.endswith('.pth') or file.endswith('.ckpt') or file.endswith('.safetensors')):
hash = modules.sd_models.model_hash(os.path.join(cn_models_path, file))
hash = modules.sd_models.model_hash(os.path.join(cn_models_dir, file))
item = {}
item['model_name'] = file
item['title'] = '{0} [{1}]'.format(os.path.splitext(file)[0], hash)
Expand Down Expand Up @@ -565,10 +563,10 @@ def train():
opts.data = default_options
elif train_task == 'dreambooth':
db_create_new_db_model = train_args['train_dreambooth_settings']['db_create_new_db_model']
db_use_txt2img = train_args['train_dreambooth_settings']['db_use_txt2img']
db_train_wizard_person = train_args['train_dreambooth_settings']['db_train_wizard_person']
db_train_wizard_object = train_args['train_dreambooth_settings']['db_train_wizard_object']
db_performance_wizard = train_args['train_dreambooth_settings']['db_performance_wizard']
db_use_txt2img = train_args['train_dreambooth_settings']['db_use_txt2img']

if db_create_new_db_model:
db_new_model_name = train_args['train_dreambooth_settings']['db_new_model_name']
Expand Down Expand Up @@ -623,13 +621,13 @@ def train():
c1_num_class_images_per, \
c2_num_class_images_per, \
c3_num_class_images_per, \
c4_num_class_images_per = training_wizard(db_config, db_train_wizard_person if db_train_wizard_person else db_train_wizard_object)
c4_num_class_images_per = training_wizard(db_train_wizard_person if db_train_wizard_person else db_train_wizard_object)

params_dict['db_num_train_epochs'] = db_num_train_epochs
params_dict[59] = c1_num_class_images_per
params_dict[61] = c2_num_class_images_per
params_dict[77] = c3_num_class_images_per
params_dict[79] = c4_num_class_images_per
params_dict['c1_num_class_images_per'] = c1_num_class_images_per
params_dict['c1_num_class_images_per'] = c2_num_class_images_per
params_dict['c1_num_class_images_per'] = c3_num_class_images_per
params_dict['c1_num_class_images_per'] = c4_num_class_images_per
if db_performance_wizard:
attention, \
gradient_checkpointing, \
Expand Down Expand Up @@ -684,17 +682,17 @@ def train():
db_model_name = train_args['train_dreambooth_settings']['db_model_name']
db_config = DreamboothConfig(db_model_name)

ckpt_dir = cmd_opts.ckpt_dir
sd_models_path = os.path.join(shared.models_path, "Stable-diffusion")
if ckpt_dir is not None:
sd_models_path = ckpt_dir

print(vars(db_config))
start_training_from_config(
db_config,
db_use_txt2img,
)

cmd_sd_models_path = cmd_opts.ckpt_dir
sd_models_dir = os.path.join(shared.models_path, "Stable-diffusion")
if cmd_sd_models_path is not None:
sd_models_dir = cmd_sd_models_path

try:
cmd_dreambooth_models_path = cmd_opts.dreambooth_models_path
except:
Expand All @@ -711,26 +709,26 @@ def train():
lora_model_dir = os.path.dirname(cmd_lora_models_path) if cmd_lora_models_path else paths.models_path
lora_model_dir = os.path.join(lora_model_dir, "lora")

print('---models path---', sd_models_path, lora_model_dir)
os.system(f'ls -l {sd_models_path}')
os.system('ls -l {0}'.format(os.path.join(sd_models_path, db_model_name)))
print('---models path---', sd_models_dir, lora_model_dir)
os.system(f'ls -l {sd_models_dir}')
os.system('ls -l {0}'.format(os.path.join(sd_models_dir, db_model_name)))
os.system(f'ls -l {lora_model_dir}')

try:
print('Uploading SD Models...')
upload_s3files(
sd_models_s3uri,
os.path.join(sd_models_path, db_model_name, f'{db_model_name}_*.yaml')
os.path.join(sd_models_dir, db_model_name, f'{db_model_name}_*.yaml')
)
if db_save_safetensors:
upload_s3files(
sd_models_s3uri,
os.path.join(sd_models_path, db_model_name, f'{db_model_name}_*.safetensors')
os.path.join(sd_models_dir, db_model_name, f'{db_model_name}_*.safetensors')
)
else:
upload_s3files(
sd_models_s3uri,
os.path.join(sd_models_path, db_model_name, f'{db_model_name}_*.ckpt')
os.path.join(sd_models_dir, db_model_name, f'{db_model_name}_*.ckpt')
)
print('Uploading DB Models...')
upload_s3folder(
Expand All @@ -747,15 +745,15 @@ def train():
os.makedirs(os.path.dirname("/opt/ml/model/"), exist_ok=True)
train_steps=int(db_config.revision)
model_file_basename = f'{db_model_name}_{train_steps}_lora' if db_config.use_lora else f'{db_model_name}_{train_steps}'
f1=os.path.join(sd_models_path, db_model_name, f'{model_file_basename}.yaml')
f1=os.path.join(sd_models_dir, db_model_name, f'{model_file_basename}.yaml')
if os.path.exists(f1):
shutil.copy(f1,"/opt/ml/model/")
if db_save_safetensors:
f2=os.path.join(sd_models_path, db_model_name, f'{model_file_basename}.safetensors')
f2=os.path.join(sd_models_dir, db_model_name, f'{model_file_basename}.safetensors')
if os.path.exists(f2):
shutil.copy(f2,"/opt/ml/model/")
else:
f2=os.path.join(sd_models_path, db_model_name, f'{model_file_basename}.ckpt')
f2=os.path.join(sd_models_dir, db_model_name, f'{model_file_basename}.ckpt')
if os.path.exists(f2):
shutil.copy(f2,"/opt/ml/model/")
except Exception as e:
Expand Down

0 comments on commit 3e4f422

Please sign in to comment.