In [7]:
import os, shutil, toml

In [8]:
# 存放输出数据的文件夹
root_dir: str = "database/working"
# 存放临时数据的文件夹
tmp_dir: str = "database/"
# lora的名称
lora_name: str = "juzhisam"
# 加密密钥
encrypt_key: str = None

config_dict = {
	"network_arguments": {
		"unet_lr": 3e-4*4, # UNet 学习率
		"text_encoder_lr": 6e-5*4, # TextEncoder 学习率
		"network_dim": 16, # LORA的大小
		"network_alpha": 8,

		"network_module": "networks.lora",
		"network_args": None,
		"network_train_unet_only": False,
	},
	"optimizer_arguments": {
		# lr_scheduler: ["constant", "cosine", "cosine_with_restarts", "constant_with_warmup", "linear", "polynomial"]
		"lr_scheduler": "cosine_with_restarts",
		"lr_scheduler_num_cycles": 3,
		"lr_scheduler_power": None,
		"lr_warmup_steps": 0,

		# optimizer: ["AdamW8bit", "Prodigy", "DAdaptation", "DadaptAdam", "DadaptLion", "AdamW", "Lion", "SGDNesterov", "SGDNesterov8bit", "AdaFactor"]
		"optimizer_type": "AdamW8bit",
		"optimizer_args": [
			"weight_decay=0.1", 
			"betas=[0.9,0.99]"
		],
	},
	"training_arguments": {
		# 模型文件
		"pretrained_model_name_or_path": "D:\\GITHUB\\stable-diffusion-webui\\models\\Stable-diffusion\\yiffymix_v61Noobxl.safetensors",

		# vae 文件
		# - "stabilityai/sdxl-vae": 使用 huggingface 模型
		# - None: 使用模型自带的vae
		"vae": None,

		"max_train_epochs": 15, # 训练多少epoch
		"train_batch_size": 2, # 训练的批大小

		# 尽可能直接加载到vram而不是ram中
		"lowram": True,

		# mixed precision: ["bf16", "fp16"]
		"mixed_precision": "fp16",

		# 二选一
		"xformers": False,
		"sdpa": True,

		"cache_latents": True,
		"cache_latents_to_disk": True,
		"cache_text_encoder_outputs": False,
		"seed": 42,
		"max_token_length": 225,
		"min_snr_gamma": 7.0,
		"no_half_vae": True,
		"gradient_checkpointing": True,
		"gradient_accumulation_steps": 1,
		"max_data_loader_n_workers": 8,
		"persistent_data_loader_workers": True,
		"min_timestep": 0,
		"max_timestep": 1000,
		"prior_loss_weight": 1.0,
	},
	"saving_arguments": {
		# 保存的精度
		"save_precision": "fp16",
		# 保存的格式
		"save_model_as": "safetensors",
		
		# 保存的频率
		"save_every_n_epochs": 1,
		"save_last_n_epochs": 15,
		
		"output_name": lora_name,
		"log_prefix": lora_name,
	}
}

dataset_dict = {
	"general": {
		# 是否需要打乱标签
		"shuffle_caption": True,
		# 保持前几个标签不打乱
		"keep_tokens": 1,

		# 训练的分辨率
		"resolution": 1024,
		"flip_aug": False,
		"caption_extension": ".txt",
		"enable_bucket": True,
		"bucket_no_upscale": True,
		"bucket_reso_steps": 64,
		"min_bucket_reso": 256,
		"max_bucket_reso": 4096,
	},
	# 数据集，可设置多个数据集
	"datasets": [{
		"subsets": [
			{
				# 数据重复几次
				"num_repeats": 5,
				# 数据集路径，最终会自动复制到tmpdir目录下使用
				"image_dir": "D:\\GITHUB\\sd-scripts\\samples\\juzhi",
				# 是否使用mep加密，设为None表示不使用，设为true表示使用
				"cache_info": None,
			}
		]
	}]
}

In [None]:
config_dict["optimizer_arguments"]["learning_rate"] = config_dict["network_arguments"]["unet_lr"]
config_dict["training_arguments"]["full_bf16"] = (config_dict["training_arguments"]["mixed_precision"] == "bf16")
config_dict["training_arguments"]["full_fp16"] = (config_dict["training_arguments"]["mixed_precision"] == "fp16")

workingdir = os.path.join(root_dir, lora_name)
configdir = os.path.join(workingdir, "configs")
logdir = os.path.join(workingdir, "logs")
outputdir = os.path.join(workingdir, "outputs")
datasetdir = os.path.join(tmp_dir, "dataset")

for x in [workingdir, configdir, logdir, outputdir, datasetdir]:
	os.makedirs(x, exist_ok = True)

config_dict["saving_arguments"]["output_dir"] = os.path.abspath(outputdir)
config_dict["saving_arguments"]["logging_dir"] = os.path.abspath(logdir)

resolution = dataset_dict["general"]["resolution"]
temp_resolution = round(resolution / 128) * 128
if (resolution != temp_resolution):
	resolution = temp_resolution
	print("⚠️ resolution is rouned to nearest step: ", resolution)
	dataset_dict["general"]["resolution"] = resolution

for i in range(len(dataset_dict["datasets"][0]["subsets"])):
	targetfolder = os.path.join(datasetdir, str(i))
	if os.path.exists(targetfolder):
		shutil.rmtree(targetfolder)
	shutil.copytree(dataset_dict["datasets"][0]["subsets"][i]["image_dir"], targetfolder)
	dataset_dict["datasets"][0]["subsets"][i]["image_dir"] = os.path.abspath(targetfolder)

def CleanConfigAndSave(name: str, filename: str, config_dict: dict):
	for key in config_dict:
		if isinstance(config_dict[key], dict):
			config_dict[key] = {k: v for k, v in config_dict[key].items() if v is not None}
	
	with open(filename, "w") as f:
		f.write(toml.dumps(config_dict))

	print(f"📄 {name} config saved to {filename}")

config_file = os.path.join(configdir, "training_config.toml")
CleanConfigAndSave("Train", config_file, config_dict)

dataset_file = os.path.join(configdir, "dataset_config.toml")
CleanConfigAndSave("Dataset", dataset_file, dataset_dict)

accelerate_file = os.path.join(configdir, "accelerate_config", "config.yaml")
from accelerate.utils import write_basic_config
write_basic_config(save_location=accelerate_file)
print(f"📄 Accelerate config saved to {accelerate_file}")

📄 Train config saved to database/working\juzhisam\configs\training_config.toml
📄 Dataset config saved to database/working\juzhisam\configs\dataset_config.toml
Configuration already exists at database/working\juzhisam\configs\accelerate_config\config.yaml, will not override. Run `accelerate config` manually or pass a different `save_location`.
📄 Accelerate config saved to database/working\juzhisam\configs\accelerate_config\config.yaml


In [10]:
print("⭐ Starting trainer...")

#!accelerate launch --config_file={accelerate_file} --num_cpu_threads_per_process=1 --mixed_precision={config_dict["training_arguments"]["mixed_precision"]} train_network_xl_wrapper.py --dataset_config={dataset_file} --config_file={config_file} --mep_key {encrypt_key}

⭐ Starting trainer...


In [11]:
!python train_network_xl_wrapper.py --dataset_config={dataset_file} --config_file={config_file} --mep_key {encrypt_key}

^C
