Skip to content

Commit

Permalink
cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
xieyongliang committed Mar 20, 2023
1 parent 8e0e392 commit 14ad93d
Show file tree
Hide file tree
Showing 4 changed files with 80 additions and 108 deletions.
2 changes: 1 addition & 1 deletion localizations/zh_CN.json
Expand Up @@ -683,7 +683,7 @@
"Save Preview/Ckpt Every Epoch": "经过若干个 Epoch 保存预览/检查点",
"Save Checkpoint Frequency": "保存检查点频率",
"Save Preview(s) Frequency": "保存预览频率",
"Batch": "批处理",
"Batching": "批处理",
"Batch Size": "批量大小",
"Class Batch Size": "类批量大小",
"Learning Rate": "学习率",
Expand Down
2 changes: 1 addition & 1 deletion localizations/zh_TW.json
Expand Up @@ -677,7 +677,7 @@
"Save Preview/Ckpt Every Epoch": "經過若干個 Epoch 保存預覽/檢查點",
"Save Checkpoint Frequency": "保存檢查點頻率",
"Save Preview(s) Frequency": "保存預覽頻率",
"Batch": "批處理",
"Batching": "批處理",
"Batch Size": "批量大小",
"Class Batch Size": "類批量大小",
"Learning Rate": "學習率",
Expand Down
1 change: 1 addition & 0 deletions modules/shared.py
Expand Up @@ -98,6 +98,7 @@
parser.add_argument('--hypernetwork-s3uri', default='', type=str, help='Hypernetwork S3Uri')
parser.add_argument('--sd-models-s3uri', default='', type=str, help='SD Models S3Uri')
parser.add_argument('--db-models-s3uri', default='', type=str, help='DB Models S3Uri')
parser.add_argument('--lora-models-s3uri', default='', type=str, help='Lora Models S3Uri')
parser.add_argument('--region-name', type=str, help='Region Name')
parser.add_argument('--username', default='', type=str, help='Username')
parser.add_argument('--api-endpoint', default='', type=str, help='API Endpoint')
Expand Down
183 changes: 77 additions & 106 deletions webui.py
Expand Up @@ -12,7 +12,7 @@
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse

from modules.call_queue import wrap_queued_call, queue_lock, wrap_gradio_gpu_call
from modules.call_queue import wrap_queued_call, queue_lock
from modules.paths import script_path

from modules import shared, devices, sd_samplers, upscaler, extensions, localization, ui_tempdir
Expand Down Expand Up @@ -43,11 +43,9 @@
import json
import uuid
if not cmd_opts.api:
from extensions.sd_dreambooth_extension.dreambooth.db_config import DreamboothConfig, sanitize_name
from extensions.sd_dreambooth_extension.dreambooth.sd_to_diff import extract_checkpoint
from extensions.sd_dreambooth_extension.dreambooth.dreambooth import start_training_from_config
from extensions.sd_dreambooth_extension.dreambooth.dreambooth import performance_wizard, training_wizard
from extensions.sd_dreambooth_extension.dreambooth.db_config import from_file
from extensions.sd_dreambooth_extension.dreambooth.db_config import DreamboothConfig
from extensions.sd_dreambooth_extension.scripts.dreambooth import start_training_from_config, create_model
from extensions.sd_dreambooth_extension.scripts.dreambooth import performance_wizard, training_wizard
from modules import paths
import glob

Expand Down Expand Up @@ -348,6 +346,7 @@ def train():
hypernetwork_s3uri = cmd_opts.hypernetwork_s3uri
sd_models_s3uri = cmd_opts.sd_models_s3uri
db_models_s3uri = cmd_opts.db_models_s3uri
lora_models_s3uri = cmd_opts.lora_models_s3uri
api_endpoint = cmd_opts.api_endpoint
username = cmd_opts.username

Expand Down Expand Up @@ -564,16 +563,10 @@ def train():
opts.data = default_options
elif train_task == 'dreambooth':
db_create_new_db_model = train_args['train_dreambooth_settings']['db_create_new_db_model']

db_lora_model_name = train_args['train_dreambooth_settings']['db_lora_model_name']
db_lora_weight = train_args['train_dreambooth_settings']['db_lora_weight']
db_lora_txt_weight = train_args['train_dreambooth_settings']['db_lora_txt_weight']
db_train_imagic_only = train_args['train_dreambooth_settings']['db_train_imagic_only']
db_use_subdir = train_args['train_dreambooth_settings']['db_use_subdir']
db_custom_model_name = train_args['train_dreambooth_settings']['db_custom_model_name']
db_train_wizard_person = train_args['train_dreambooth_settings']['db_train_wizard_person']
db_train_wizard_object = train_args['train_dreambooth_settings']['db_train_wizard_object']
db_performance_wizard = train_args['train_dreambooth_settings']['db_performance_wizard']
db_use_txt2img = train_args['train_dreambooth_settings']['db_use_txt2img']

if db_create_new_db_model:
db_new_model_name = train_args['train_dreambooth_settings']['db_new_model_name']
Expand All @@ -583,15 +576,20 @@ def train():
db_new_model_url = train_args['train_dreambooth_settings']['db_new_model_url']
db_new_model_token = train_args['train_dreambooth_settings']['db_new_model_token']
db_new_model_extract_ema = train_args['train_dreambooth_settings']['db_new_model_extract_ema']
db_model_name, _, db_revision, db_scheduler, db_src, db_has_ema, db_v2, db_resolution = extract_checkpoint(
db_new_model_name,
db_new_model_src,
db_new_model_scheduler,
db_create_from_hub,
db_new_model_url,
db_new_model_token,
db_new_model_extract_ema
)
db_train_unfrozen = train_args['train_dreambooth_settings']['db_train_unfrozen']
db_512_model = train_args['train_dreambooth_settings']['db_512_model']

db_model_name, db_model_path, db_revision, db_epochs, db_scheduler, db_src, _, _, _ = create_model(
db_new_model_name,
db_new_model_src,
db_new_model_scheduler,
db_create_from_hub,
db_new_model_url,
db_new_model_token,
db_new_model_extract_ema,
db_train_unfrozen,
db_512_model,
)
dreambooth_config_id = cmd_opts.dreambooth_config_id
try:
with open(f'/opt/ml/input/data/config/{dreambooth_config_id}.json', 'r') as f:
Expand All @@ -605,91 +603,61 @@ def train():
content = None

if content:
config_dict = json.loads(content)
print(db_model_name, db_revision, db_scheduler, db_src, db_has_ema, db_v2, db_resolution)

config_dict[0] = db_model_name
config_dict[31] = db_revision
config_dict[39] = db_scheduler
config_dict[40] = db_src
config_dict[14] = db_has_ema
config_dict[49] = db_v2
config_dict[30] = db_resolution

db_config = DreamboothConfig(*config_dict)

if db_train_wizard_person:
_, \
max_train_steps, \
num_train_epochs, \
c1_max_steps, \
c1_num_class_images, \
c2_max_steps, \
c2_num_class_images, \
c3_max_steps, \
c3_num_class_images = training_wizard(db_config, True)

config_dict[22] = int(max_train_steps)
config_dict[26] = int(num_train_epochs)
config_dict[59] = c1_max_steps
config_dict[61] = c1_num_class_images
config_dict[77] = c2_max_steps
config_dict[79] = c2_num_class_images
config_dict[95] = c3_max_steps
config_dict[97] = c3_num_class_images
if db_train_wizard_object:
_, \
max_train_steps, \
num_train_epochs, \
c1_max_steps, \
c1_num_class_images, \
c2_max_steps, \
c2_num_class_images, \
c3_max_steps, \
c3_num_class_images = training_wizard(db_config, False)

config_dict[22] = int(max_train_steps)
config_dict[26] = int(num_train_epochs)
config_dict[59] = c1_max_steps
config_dict[61] = c1_num_class_images
config_dict[77] = c2_max_steps
config_dict[79] = c2_num_class_images
config_dict[95] = c3_max_steps
config_dict[97] = c3_num_class_images
params_dict = json.loads(content)

params_dict['db_model_name'] = db_model_name
params_dict['db_model_path'] = db_model_path
params_dict['db_revision'] = db_revision
params_dict['db_epochs'] = db_epochs
params_dict['db_scheduler'] = db_scheduler
params_dict['db_src'] = db_src

if db_train_wizard_person or db_train_wizard_object:
db_num_train_epochs, \
c1_num_class_images_per, \
c2_num_class_images_per, \
c3_num_class_images_per, \
c4_num_class_images_per = training_wizard(db_config, db_train_wizard_person if db_train_wizard_person else db_train_wizard_object)

params_dict['db_num_train_epochs'] = db_num_train_epochs
params_dict[59] = c1_num_class_images_per
params_dict[61] = c2_num_class_images_per
params_dict[77] = c3_num_class_images_per
params_dict[79] = c4_num_class_images_per
if db_performance_wizard:
_, \
attention, \
gradient_checkpointing, \
gradient_accumulation_steps, \
mixed_precision, \
not_cache_latents, \
cache_latents, \
sample_batch_size, \
train_batch_size, \
train_text_encoder, \
stop_text_encoder, \
use_8bit_adam, \
use_cpu, \
use_ema = performance_wizard()

config_dict[5] = attention
config_dict[12] = gradient_checkpointing
config_dict[23] = mixed_precision
config_dict[25] = not_cache_latents
config_dict[32] = sample_batch_size
config_dict[42] = train_batch_size
config_dict[43] = train_text_encoder
config_dict[44] = use_8bit_adam
config_dict[46] = use_cpu
config_dict[47] = use_ema
use_lora, \
use_ema, \
save_samples_every, \
save_weights_every = performance_wizard()

params_dict['attention'] = attention
params_dict['gradient_checkpointing'] = gradient_checkpointing
params_dict['gradient_accumulation_steps'] = gradient_accumulation_steps
params_dict['mixed_precision'] = mixed_precision
params_dict['cache_latents'] = cache_latents
params_dict['sample_batch_size'] = sample_batch_size
params_dict['train_batch_size'] = train_batch_size
params_dict['stop_text_encoder'] = stop_text_encoder
params_dict['use_8bit_adam'] = use_8bit_adam
params_dict['use_lora'] = use_lora
params_dict['use_ema'] = use_ema
params_dict['save_samples_every'] = save_samples_every
params_dict['params_dict'] = save_weights_every

db_config = DreamboothConfig(db_model_name)
db_config.load_params(params_dict)
else:
db_model_name = train_args['train_dreambooth_settings']['db_model_name']
db_model_name = sanitize_name(db_model_name)
db_models_path = cmd_opts.dreambooth_models_path
if db_models_path == "" or db_models_path is None:
db_models_path = os.path.join(shared.models_path, "dreambooth")
working_dir = os.path.join(db_models_path, db_model_name, "working")
config_dict = from_file(os.path.join(db_models_path, db_model_name))
config_dict["pretrained_model_name_or_path"] = working_dir

db_config = DreamboothConfig(*config_dict)
db_config = DreamboothConfig(db_model_name)

ckpt_dir = cmd_opts.ckpt_dir
sd_models_path = os.path.join(shared.models_path, "Stable-diffusion")
Expand All @@ -699,12 +667,7 @@ def train():
print(vars(db_config))
start_training_from_config(
db_config,
db_lora_model_name if db_lora_model_name != '' else None,
db_lora_weight,
db_lora_txt_weight,
db_train_imagic_only,
db_use_subdir,
db_custom_model_name
db_use_txt2img,
)

try:
Expand All @@ -715,12 +678,14 @@ def train():
db_model_dir = os.path.dirname(cmd_dreambooth_models_path) if cmd_dreambooth_models_path else paths.models_path
db_model_dir = os.path.join(db_model_dir, "dreambooth")

lora_models_path = os.path.join(shared.models_path, "Lora")

try:
print('Uploading SD Models...')
print('Uploading SD Models...')
upload_s3files(
sd_models_s3uri,
os.path.join(sd_models_path, f'{sd_models_path}/{db_model_name}_*.yaml')
)
)
upload_s3files(
sd_models_s3uri,
os.path.join(sd_models_path, f'{sd_models_path}/{db_model_name}_*.ckpt')
Expand All @@ -730,6 +695,12 @@ def train():
f'{db_models_s3uri}{db_model_name}',
os.path.join(db_model_dir, db_model_name)
)
if db_config.use_lora:
print('Uploading Lora Models...')
upload_s3files(
lora_models_s3uri,
os.path.join(lora_models_path, f'{lora_models_path}/{db_model_name}_*.pt')
)
except Exception as e:
traceback.print_exc()
print(e)
Expand Down

0 comments on commit 14ad93d

Please sign in to comment.