[![visitor][visitor-badge]][visitor-stats]
[![ko-fi][ko-fi-badge]][ko-fi-link]

# **Kohya LoRA Trainer XL**
A Colab Notebook For SDXL LoRA Training (Fine-tuning Method)

[visitor-badge]: https://api.visitorbadge.io/api/visitors?path=Kohya%20LoRA%20Trainer%20XL&label=Visitors&labelColor=%2334495E&countColor=%231ABC9C&style=flat&labelStyle=none
[visitor-stats]: https://visitorbadge.io/status?path=Kohya%20LoRA%20Trainer%20XL
[ko-fi-badge]: https://img.shields.io/badge/Support%20me%20on%20Ko--fi-F16061?logo=ko-fi&logoColor=white&style=flat
[ko-fi-link]: https://ko-fi.com/linaqruf

| Notebook Name | Description | Link |
| --- | --- | --- |
| [Kohya LoRA Trainer XL](https://github.com/Linaqruf/kohya-trainer/blob/main/kohya-LoRA-trainer-XL.ipynb) | LoRA Training | [![](https://img.shields.io/static/v1?message=Open%20in%20Colab&logo=googlecolab&labelColor=5c5c5c&color=0f80c1&label=%20&style=flat)](https://colab.research.google.com/github/Linaqruf/kohya-trainer/blob/main/kohya-LoRA-trainer-XL.ipynb) |
| [Kohya Trainer XL](https://github.com/Linaqruf/kohya-trainer/blob/main/kohya-trainer-XL.ipynb) | Native Training | [![](https://img.shields.io/static/v1?message=Open%20in%20Colab&logo=googlecolab&labelColor=5c5c5c&color=0f80c1&label=%20&style=flat)](https://colab.research.google.com/github/Linaqruf/kohya-trainer/blob/main/kohya-trainer-XL.ipynb) |


<hr>
<h4><font color="#4a90e2"><b>NEWS:</b></font> <i>Colab's free-tier users can now train SDXL LoRA using a Diffusers-format model instead of a single checkpoint file as the pretrained model.</i></h4>
<hr>

In [1]:
# @title ## **1.1. Install Kohya Trainer**
import os
import zipfile
import shutil
import time
import requests
import torch
from subprocess import getoutput
from IPython.utils import capture
from google.colab import drive

# Restore any variables previously persisted with IPython's %store.
%store -r

# Working-directory layout: everything lives under Colab's /content; the
# tools/finetune/accelerate entries point inside the cloned trainer repo.
# root_dir
root_dir          = "/content"
drive_dir         = os.path.join(root_dir, "drive/MyDrive")
deps_dir          = os.path.join(root_dir, "deps")
repo_dir          = os.path.join(root_dir, "kohya-trainer")
training_dir      = os.path.join(root_dir, "LoRA")
pretrained_model  = os.path.join(root_dir, "pretrained_model")
vae_dir           = os.path.join(root_dir, "vae")
lora_dir          = os.path.join(root_dir, "network_weight")
repositories_dir  = os.path.join(root_dir, "repositories")
config_dir        = os.path.join(training_dir, "config")
tools_dir         = os.path.join(repo_dir, "tools")
finetune_dir      = os.path.join(repo_dir, "finetune")
accelerate_config = os.path.join(repo_dir, "accelerate_config/config.yaml")

# Persist these paths with %store so later cells can restore them via
# `%store -r`. drive_dir, deps_dir and lora_dir are not in this list.
for store in ["root_dir", "repo_dir", "training_dir", "pretrained_model", "vae_dir", "repositories_dir", "accelerate_config", "tools_dir", "finetune_dir", "config_dir"]:
    # capture_output swallows the confirmation message %store prints.
    with capture.capture_output() as cap:
        %store {store}
        del cap

# Known trainer repositories offered by the form below. The form field allows
# free input, so an entry that is not one of these labels is treated as a
# repository URL itself (instead of raising KeyError), mirroring how the model
# download cell resolves its own dropdowns.
repo_dict = {
    "qaneel/kohya-trainer (forked repo, stable, optimized for colab use)" : "https://github.com/qaneel/kohya-trainer",
    "kohya-ss/sd-scripts (original repo, latest update)"                    : "https://github.com/kohya-ss/sd-scripts",
}

repository        = "qaneel/kohya-trainer (forked repo, stable, optimized for colab use)" #@param ["qaneel/kohya-trainer (forked repo, stable, optimized for colab use)", "kohya-ss/sd-scripts (original repo, latest update)"] {allow-input: true}
repo_url          = repo_dict.get(repository, repository)
branch            = "main"  # @param {type: "string"}
output_to_drive   = True  # @param {type: "boolean"}

def clone_repo(url, dir, branch):
    """Clone *url* at *branch* into *dir* unless *dir* already exists.

    An existing directory is left as-is (no fetch/pull). Uses an IPython
    shell escape, so it only works inside IPython/Colab.
    """
    if not os.path.exists(dir):
       !git clone -b {branch} {url} {dir}

def mount_drive(dir):
    """Return the directory that training outputs should be written to.

    When the ``output_to_drive`` form option is enabled, Google Drive is
    mounted (only if *dir* does not exist yet) and
    ``<dir>/kohya-trainer/output`` is returned. Otherwise the local
    ``<training_dir>/output`` is returned.

    Args:
        dir: Drive folder to write into, e.g. ``/content/drive/MyDrive``.
    """
    output_dir = os.path.join(training_dir, "output")

    if output_to_drive:
        if not os.path.exists(dir):
            # Mount at the parent of *dir* (e.g. /content/drive for
            # /content/drive/MyDrive).
            drive.mount(os.path.dirname(dir))
        output_dir = os.path.join(dir, "kohya-trainer/output")

    return output_dir

def setup_directories():
    """Resolve the output directory and create every working directory.

    Stores the result of ``mount_drive(drive_dir)`` in the module-level
    ``output_dir`` and then creates it together with the training, config,
    pretrained-model, VAE and repositories directories (existing ones are
    left alone).
    """
    global output_dir

    output_dir = mount_drive(drive_dir)

    wanted_dirs = (
        training_dir,
        config_dir,
        pretrained_model,
        vae_dir,
        repositories_dir,
        output_dir,
    )
    for path in wanted_dirs:
        os.makedirs(path, exist_ok=True)

def pastebin_reader(id):
    """Fetch a Pastebin paste and return its text as a list of lines.

    Args:
        id: Either a bare paste ID (e.g. ``"kq6ZmHFU"``) or a full
            ``pastebin.com`` URL. Non-raw URLs are rewritten to the ``/raw/``
            endpoint so the plain text is returned rather than the HTML page.

    Returns:
        The paste body split into lines, without line terminators (both LF
        and CRLF endings are handled).

    Raises:
        requests.HTTPError: if Pastebin answers with an error status.
        requests.Timeout: if Pastebin does not respond within 30 seconds.
    """
    if "pastebin.com" in id:
        url = id
        if "raw" not in url:
            url = url.replace("pastebin.com", "pastebin.com/raw")
    else:
        url = "https://pastebin.com/raw/" + id
    # Without a timeout an unresponsive server would block the cell forever.
    response = requests.get(url, timeout=30)
    response.raise_for_status()
    # splitlines() also drops "\r", in case the paste uses CRLF line endings.
    return response.text.splitlines()

def install_repository():
    global infinite_image_browser_dir, voldy, discordia_archivum_dir

    _, voldy = pastebin_reader("kq6ZmHFU")[:2]

    infinite_image_browser_url  = f"https://github.com/zanllp/{voldy}-infinite-image-browsing.git"
    infinite_image_browser_dir  = os.path.join(repositories_dir, f"infinite-image-browsing")
    infinite_image_browser_deps = os.path.join(infinite_image_browser_dir, "requirements.txt")

    discordia_archivum_url = "https://github.com/Linaqruf/discordia-archivum"
    discordia_archivum_dir = os.path.join(repositories_dir, "discordia-archivum")
    discordia_archivum_deps = os.path.join(discordia_archivum_dir, "requirements.txt")

    clone_repo(infinite_image_browser_url, infinite_image_browser_dir, "main")
    clone_repo(discordia_archivum_url, discordia_archivum_dir, "main")

    !pip install -q --upgrade -r {infinite_image_browser_deps}
    !pip install python-dotenv
    !pip install -q --upgrade -r {discordia_archivum_deps}

def install_dependencies():
    requirements_file = os.path.join(repo_dir, "requirements.txt")
    model_util        = os.path.join(repo_dir, "library/model_util.py")
    gpu_info          = getoutput('nvidia-smi')
    t4_xformers_wheel = "https://github.com/Linaqruf/colab-xformers/releases/download/0.0.20/xformers-0.0.20+1d635e1.d20230519-cp310-cp310-linux_x86_64.whl"

    !apt install aria2 lz4
    !wget https://github.com/camenduru/gperftools/releases/download/v1.0/libtcmalloc_minimal.so.4 -O /content/libtcmalloc_minimal.so.4
    !pip install -q --upgrade -r {requirements_file}

    !pip install -q xformers==0.0.22.post7

    from accelerate.utils import write_basic_config

    if not os.path.exists(accelerate_config):
        write_basic_config(save_location=accelerate_config)

def prepare_environment():
    """Export environment variables for processes launched afterwards.

    Preloads the tcmalloc library saved at /content, silences TensorFlow's
    C++ logging and Python warnings, and enables ``SAFETENSORS_FAST_GPU``.
    """
    os.environ.update(
        {
            "LD_PRELOAD": "/content/libtcmalloc_minimal.so.4",
            "TF_CPP_MIN_LOG_LEVEL": "3",
            "SAFETENSORS_FAST_GPU": "1",
            "PYTHONWARNINGS": "ignore",
        }
    )

def main():
    """Run the whole install cell.

    Clones the selected trainer repository from ``root_dir`` into ``repo_dir``
    and switches into it. Then, in order: creates the working directories,
    installs the auxiliary repositories, installs dependencies, and exports
    the runtime environment variables.
    """
    os.chdir(root_dir)
    clone_repo(repo_url, repo_dir, branch)
    os.chdir(repo_dir)

    setup_steps = (
        setup_directories,
        install_repository,
        install_dependencies,
        prepare_environment,
    )
    for step in setup_steps:
        step()

main()



# @title ## **1.2. Download SDXL**
import os
import re
import json
import glob
import gdown
import requests
import subprocess
from IPython.utils import capture
from urllib.parse import urlparse, unquote
from pathlib import Path
from huggingface_hub import HfFileSystem
from huggingface_hub.utils import validate_repo_id, HfHubHTTPError

# Restore the paths persisted with %store by the install cell
# (root_dir, pretrained_model, vae_dir, ...).
%store -r

os.chdir(root_dir)

# @markdown Place your Huggingface token [here](https://huggingface.co/settings/tokens) to download gated models.

HUGGINGFACE_TOKEN     = "" #@param {type: "string"}
LOAD_DIFFUSERS_MODEL  = True #@param {type: "boolean"}
SDXL_MODEL_URL        = "Linaqruf/animagine-xl" # @param ["gsdf/CounterfeitXL", "Linaqruf/animagine-xl", "stabilityai/stable-diffusion-xl-base-1.0", "PASTE MODEL URL OR GDRIVE PATH HERE"] {allow-input: true}
SDXL_VAE_URL          = "Original VAE" # @param ["None", "Original VAE", "FP16 VAE", "PASTE VAE URL OR GDRIVE PATH HERE"] {allow-input: true}

# Preset dropdown labels mapped to direct single-file download URLs. Values
# not found here (a pasted URL, repo ID or path) pass through unchanged via
# the .get(value, value) lookups below.
MODEL_URLS = {
    "gsdf/CounterfeitXL"        : "https://huggingface.co/gsdf/CounterfeitXL/resolve/main/CounterfeitXL_%CE%B2.safetensors",
    "Linaqruf/animagine-xl"   : "https://huggingface.co/Linaqruf/animagine-xl/resolve/main/animagine-xl.safetensors",
    "stabilityai/stable-diffusion-xl-base-1.0" : "https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/resolve/main/sd_xl_base_1.0.safetensors",
}
# "None" maps to an empty string, which main() skips (no VAE download).
VAE_URLS = {
    "None"                    : "",
    "Original VAE"           : "https://huggingface.co/stabilityai/sdxl-vae/resolve/main/sdxl_vae.safetensors",
    "FP16 VAE"           : "https://huggingface.co/madebyollin/sdxl-vae-fp16-fix/resolve/main/sdxl_vae.safetensors"
}

SDXL_MODEL_URL = MODEL_URLS.get(SDXL_MODEL_URL, SDXL_MODEL_URL)
SDXL_VAE_URL = VAE_URLS.get(SDXL_VAE_URL, SDXL_VAE_URL)

def get_filename(url):
    """Work out the local file name to save *url* under.

    If the URL path already ends in a weight-file extension (``.ckpt``,
    ``.safetensors``, ``.pt``, ``.pth``), the percent-decoded last path
    component is used directly, even when a query string such as
    ``?download=true`` follows. Otherwise the URL is requested and the name
    is taken from the ``Content-Disposition`` header. If that header is
    missing or has no ``filename=``, the last path component is used.

    Args:
        url: Direct download URL.

    Returns:
        The file name, without any directory part.

    Raises:
        requests.HTTPError: if the probing request returns an error status.
    """
    path = urlparse(url).path
    fallback = unquote(os.path.basename(path))
    if path.endswith((".ckpt", ".safetensors", ".pt", ".pth")):
        return fallback

    # Gated Hugging Face files reject anonymous requests; send the same
    # bearer token that is used for the actual download.
    headers = {}
    if "huggingface.co" in url and HUGGINGFACE_TOKEN:
        headers["Authorization"] = f"Bearer {HUGGINGFACE_TOKEN}"

    # stream=True avoids pulling the (possibly multi-GB) body; the context
    # manager releases the connection once the headers have been read.
    with requests.get(url, headers=headers, stream=True, timeout=60) as response:
        response.raise_for_status()
        disposition = response.headers.get("content-disposition", "")

    # Stop at ';' so an unquoted name is not merged with following parameters.
    match = re.search(r'filename="?([^";]+)"?', disposition)
    return match.group(1) if match else fallback

def aria2_download(dir, filename, url):
    """Download *url* to ``dir/filename`` with aria2c (16 connections, resumable).

    The Hugging Face token is sent as a bearer ``Authorization`` header only
    for ``huggingface.co`` URLs and only when a token was provided.

    Args:
        dir: Destination directory.
        filename: Output file name inside *dir*.
        url: Source URL.
    """
    aria2_args = [
        "aria2c",
        "--console-log-level=error",
        "--summary-interval=10",
    ]
    # aria2c treats positional arguments as URIs, so never pass an empty
    # placeholder; add the header only when it actually applies.
    if "huggingface.co" in url and HUGGINGFACE_TOKEN:
        aria2_args.append(f"--header=Authorization: Bearer {HUGGINGFACE_TOKEN}")
    aria2_args += [
        "--continue=true",
        "--max-connection-per-server=16",
        "--min-split-size=1M",
        "--split=16",
        f"--dir={dir}",
        f"--out={filename}",
        url,
    ]
    subprocess.run(aria2_args)

def download(url, dst):
    """Fetch a model/VAE file into *dst* and return the path to use.

    Args:
        url: One of the following:

            * a Google Drive link (downloaded with gdown);
            * an HTTP(S) URL (downloaded with aria2c; Hugging Face ``/blob/``
              page links are rewritten to ``/resolve/`` download links);
            * the path of a file that already exists locally, e.g. on the
              mounted Drive.
        dst: Destination directory for downloaded files.

    Returns:
        *url* itself when it is an existing local path. Otherwise
        ``os.path.join(dst, <resolved file name>)``.
    """
    if os.path.exists(url):
        # A pasted Drive/local path: nothing to download.
        print(f"Using local file: {url}")
        return url

    print(f"Starting downloading from {url}")
    if "huggingface.co" in url and "/blob/" in url:
        # Rewrite before resolving the file name so both steps use the
        # direct-download URL rather than the HTML page.
        url = url.replace("/blob/", "/resolve/")

    filename = get_filename(url)
    filepath = os.path.join(dst, filename)

    if "drive.google.com" in url:
        gdown.download(url, filepath, quiet=False)
    else:
        aria2_download(dst, filename, url)

    print(f"Download finished: {filepath}")
    return filepath

def all_folders_present(base_model_url, sub_folders):
    """Return True if a Hugging Face repo contains every folder in *sub_folders*.

    Args:
        base_model_url: Repository ID such as ``"user/model"``.
        sub_folders: Folder names that must exist at the repository root
            (e.g. the Diffusers ``unet``/``vae``/... folders).

    Returns:
        True only when the repository can be listed and contains all of the
        folders. Returns False when any folder is missing, or when the repo
        cannot be listed (missing repo, invalid repo ID, HTTP error), so the
        caller can fall back to a plain file download.
    """
    fs = HfFileSystem()
    try:
        existing_folders = set(fs.ls(base_model_url, detail=False))
    except (OSError, ValueError, HfHubHTTPError):
        # HfFileSystem reports a missing repo as FileNotFoundError (an
        # OSError) and an invalid repo ID as HFValidationError (a ValueError).
        return False
    return all(f"{base_model_url}/{folder}" in existing_folders for folder in sub_folders)

def get_total_ram_gb():
    """Return total system memory in GiB, read from ``/proc/meminfo``.

    ``MemTotal`` is listed in KiB, hence the division by ``1024**2``.
    Returns None when no ``MemTotal`` line is present.
    """
    with open('/proc/meminfo', 'r') as meminfo:
        mem_total_kib = next(
            (int(entry.split()[1]) for entry in meminfo.readlines() if "MemTotal" in entry),
            None,
        )
    if mem_total_kib is None:
        return None
    return mem_total_kib / (1024**2)  # Convert to GB

def get_gpu_name():
    """Return the name of the first GPU reported by ``nvidia-smi``.

    Returns:
        The GPU name (e.g. ``"Tesla T4"``). Returns ``""`` if ``nvidia-smi``
        prints nothing, and None when ``nvidia-smi`` is missing or fails
        (e.g. no GPU runtime attached).
    """
    command = ["nvidia-smi", "--query-gpu=name", "--format=csv,noheader,nounits"]
    try:
        output = subprocess.check_output(command).decode("ascii")
    except (OSError, subprocess.CalledProcessError, UnicodeDecodeError):
        # Narrow catch: a bare except would also swallow KeyboardInterrupt.
        return None
    # nvidia-smi prints one line per GPU; callers compare a single model name.
    names = output.strip().splitlines()
    return names[0].strip() if names else ""

def main():
    """Resolve the selected SDXL model and VAE into loadable paths.

    Sets the globals ``model_path`` and ``vae_path``. Each is left as None
    when its field is empty or still holds the "PASTE ..." placeholder.
    Each selection is resolved as follows:

    * an existing local path (e.g. a file or Diffusers folder on the mounted
      Drive) is used as-is;
    * for the model with ``LOAD_DIFFUSERS_MODEL`` enabled, a Hugging Face URL
      (or bare ``user/repo`` ID) is reduced to its repo ID. That ID is used
      directly if the repo has all Diffusers sub-folders; otherwise the URL
      the user supplied is downloaded;
    * anything else is downloaded into ``pretrained_model`` / ``vae_dir``.

    ``LOAD_DIFFUSERS_MODEL`` is forced on for Tesla T4/V100 runtimes with
    less than 13 GiB of RAM.
    """
    global model_path, vae_path, LOAD_DIFFUSERS_MODEL

    model_path, vae_path = None, None

    required_sub_folders = [
        'scheduler',
        'text_encoder',
        'text_encoder_2',
        'tokenizer',
        'tokenizer_2',
        'unet',
        'vae',
    ]

    download_targets = {
        "model": (SDXL_MODEL_URL, pretrained_model),
        "vae": (SDXL_VAE_URL, vae_dir),
    }

    total_ram = get_total_ram_gb()
    gpu_name = get_gpu_name()

    # Check hardware constraints (total_ram is None if /proc/meminfo lacks MemTotal).
    if total_ram is not None and total_ram < 13 and gpu_name in ["Tesla T4", "Tesla V100"]:
        print("Attempt to load diffusers model instead due to hardware constraints.")
        LOAD_DIFFUSERS_MODEL = True

    for target, (url, dst) in download_targets.items():
        if not url or url.startswith(f"PASTE {target.upper()} URL OR GDRIVE PATH HERE"):
            continue

        # Only Hugging Face URLs and bare "user/repo" IDs can name a Diffusers repo.
        maybe_hf_repo = "huggingface.co" in url or "://" not in url

        if os.path.exists(url):
            print(f"Using local path: {url}")
            filepath = url
        elif target == "model" and LOAD_DIFFUSERS_MODEL and maybe_hf_repo:
            repo_id = url
            match = re.search(r'huggingface\.co/([^/]+)/([^/]+)', url)
            if match:
                repo_id = f"{match.group(1)}/{match.group(2)}"
            if all_folders_present(repo_id, required_sub_folders):
                print(f"Diffusers model is loaded : {repo_id}")
                filepath = repo_id
            else:
                print("Repository doesn't exist or no diffusers model detected.")
                # Fall back to downloading the URL the user gave, not the
                # derived repo ID (which is not a downloadable URL).
                filepath = download(url, dst)
        else:
            filepath = download(url, dst)

        if target == "model":
            model_path = filepath
        else:
            vae_path = filepath

        print()

    if model_path:
        print(f"Selected model: {model_path}")

    if vae_path:
        print(f"Selected VAE: {vae_path}")

main()

Cloning into '/content/kohya-trainer'...
remote: Enumerating objects: 2441, done.[K
remote: Counting objects: 100% (1045/1045), done.[K
remote: Compressing objects: 100% (273/273), done.[K
remote: Total 2441 (delta 907), reused 772 (delta 772), pack-reused 1396[K
Receiving objects: 100% (2441/2441), 4.13 MiB | 9.88 MiB/s, done.
Resolving deltas: 100% (1632/1632), done.
Mounted at /content/drive
Cloning into '/content/repositories/infinite-image-browsing'...
remote: Enumerating objects: 8417, done.[K
remote: Counting objects: 100% (2813/2813), done.[K
remote: Compressing objects: 100% (806/806), done.[K
remote: Total 8417 (delta 2165), reused 2517 (delta 1951), pack-reused 5604[K
Receiving objects: 100% (8417/8417), 18.84 MiB | 20.64 MiB/s, done.
Resolving deltas: 100% (6050/6050), done.
Cloning into '/content/repositories/discordia-archivum'...
remote: Enumerating objects: 54, done.[K
remote: Counting objects: 100% (54/54), done.[K
remote: Compressing objects: 100% (45/45), d

In [10]:
# @title ## **3.4. Bucketing and Latents Caching**
# Restore the variables persisted with %store by earlier cells.
%store -r


# Parent folder: each entry inside it is processed as a separate project by
# the loop below (metadata merge + bucketing per entry).
train_folder_directory = "/content/drive/MyDrive/dreambooth/dataset/train" #@param {'type':'string'}
# Persist it so later cells can restore it with `%store -r`.
%store train_folder_directory

for training_dir in os.listdir(train_folder_directory):
    train_data_dir = os.path.join(train_folder_directory, training_dir)
    # @markdown This code will create buckets based on the `bucket_resolution` provided for multi-aspect ratio training, and then convert all images within the `train_data_dir` to latents.
    bucketing_json    = os.path.join(train_folder_directory, training_dir, "meta_lat.json")
    metadata_json     = os.path.join(train_folder_directory, training_dir, "meta_clean.json")
    project_name      = training_dir
    bucket_resolution = 1024  # @param {type:"slider", min:512, max:1024, step:128}
    mixed_precision   = "no"  # @param ["no", "fp16", "bf16"] {allow-input: false}
    skip_existing     = False  # @param{type:"boolean"}
    flip_aug          = False  # @param{type:"boolean"}
    # @markdown Use `clean_caption` option to clean such as duplicate tags, `women` to `girl`, etc
    clean_caption     = False #@param {type:"boolean"}
    #@markdown Use the `recursive` option to process subfolders as well
    recursive         = True #@param {type:"boolean"}

    metadata_config = {
        "_train_data_dir": train_data_dir,
        "_out_json": metadata_json,
        "recursive": recursive,
        "full_path": recursive,
        "clean_caption": clean_caption
    }

    bucketing_config = {
        "_train_data_dir": train_data_dir,
        "_in_json": metadata_json,
        "_out_json": bucketing_json,
        "_model_name_or_path": vae_path if vae_path else model_path,
        "recursive": recursive,
        "full_path": recursive,
        "flip_aug": flip_aug,
        "skip_existing": skip_existing,
        "batch_size": 4,
        "max_data_loader_n_workers": 2,
        "max_resolution": f"{bucket_resolution}, {bucket_resolution}",
        "mixed_precision": mixed_precision,
    }

    def generate_args(config):
        """Render *config* as a single command-line argument string.

        Entries are emitted in dict order:

        * keys starting with ``_`` become double-quoted positional arguments;
        * string values become ``--key="value"``;
        * ``True`` becomes a bare ``--key`` flag;
        * ints and floats become ``--key=value``.

        ``False``, ``None`` and values of any other type are omitted.
        """
        pieces = []
        for key, value in config.items():
            if key.startswith("_"):
                pieces.append(f'"{value}"')
            elif isinstance(value, str):
                pieces.append(f'--{key}="{value}"')
            elif isinstance(value, bool):
                # Checked before int: bool is an int subclass.
                if value:
                    pieces.append(f"--{key}")
            elif isinstance(value, (int, float)):
                pieces.append(f"--{key}={value}")
        return " ".join(pieces).strip()

    merge_metadata_args = generate_args(metadata_config)
    prepare_buckets_args = generate_args(bucketing_config)

    merge_metadata_command = f"python merge_all_to_metadata.py {merge_metadata_args}"
    prepare_buckets_command = f"python prepare_buckets_latents.py {prepare_buckets_args}"

    os.chdir(finetune_dir)
    !{merge_metadata_command}
    time.sleep(1)
    !{prepare_buckets_command}

    print("ahat1")









    import toml

    print("ahat2")

    # @title ## **4.1. LoRa: Low-Rank Adaptation Config**
    # @markdown Kohya's `LoRA` renamed to `LoRA-LierLa` and Kohya's `LoCon` renamed to `LoRA-C3Lier`, read [official announcement](https://github.com/kohya-ss/sd-scripts/blob/849bc24d205a35fbe1b2a4063edd7172533c1c01/README.md#naming-of-lora).
    network_category = "LoRA_LierLa"  # @param ["LoRA_LierLa", "LoRA_C3Lier", "DyLoRA_LierLa", "DyLoRA_C3Lier", "LoCon", "LoHa", "IA3", "LoKR", "DyLoRA_Lycoris"]

    # @markdown | network_category | network_dim | network_alpha | conv_dim | conv_alpha | unit |
    # @markdown | :---: | :---: | :---: | :---: | :---: | :---: |
    # @markdown | LoRA-LierLa | 32 | 1 | - | - | - |
    # @markdown | LoCon/LoRA-C3Lier | 16 | 8 | 8 | 1 | - |
    # @markdown | LoHa | 8 | 4 | 4 | 1 | - |
    # @markdown | Other Category | ? | ? | ? | ? | - |

    # @markdown Specify `network_args` to add `optional` training args, like for specifying each 25 block weight, read [this](https://github.com/kohya-ss/sd-scripts/blob/main/train_network_README-ja.md#%E9%9A%8E%E5%B1%A4%E5%88%A5%E5%AD%A6%E7%BF%92%E7%8E%87)
    network_args    = ""  # @param {'type':'string'}

    # @markdown ### **Linear Layer Config**
    # @markdown Used by all `network_category`. When in doubt, set `network_dim = network_alpha`
    network_dim     = 32  # @param {'type':'number'}
    network_alpha   = 16  # @param {'type':'number'}

    # @markdown ### **Convolutional Layer Config**
    # @markdown Only required if `network_category` is not `LoRA_LierLa`, as it involves training convolutional layers in addition to linear layers.
    conv_dim        = 32  # @param {'type':'number'}
    conv_alpha      = 16  # @param {'type':'number'}

    # @markdown ### **DyLoRA Config**
    # @markdown Only required if `network_category` is `DyLoRA_LierLa` and `DyLoRA_C3Lier`
    unit = 4  # @param {'type':'number'}

    if isinstance(network_args, str):
        network_args = network_args.strip()
        if network_args.startswith('[') and network_args.endswith(']'):
            try:
                network_args = ast.literal_eval(network_args)
            except (SyntaxError, ValueError) as e:
                print(f"Error parsing network_args: {e}\n")
                network_args = []
        elif len(network_args) > 0:
            print(f"WARNING! '{network_args}' is not a valid list! Put args like this: [\"args=1\", \"args=2\"]\n")
            network_args = []
        else:
            network_args = []
    else:
        network_args = []

    network_config = {
        "LoRA_LierLa": {
            "module": "networks.lora",
            "args"  : []
        },
        "LoRA_C3Lier": {
            "module": "networks.lora",
            "args"  : [
                f"conv_dim={conv_dim}",
                f"conv_alpha={conv_alpha}"
            ]
        },
        "DyLoRA_LierLa": {
            "module": "networks.dylora",
            "args"  : [
                f"unit={unit}"
            ]
        },
        "DyLoRA_C3Lier": {
            "module": "networks.dylora",
            "args"  : [
                f"conv_dim={conv_dim}",
                f"conv_alpha={conv_alpha}",
                f"unit={unit}"
            ]
        },
        "LoCon": {
            "module": "lycoris.kohya",
            "args"  : [
                f"algo=locon",
                f"conv_dim={conv_dim}",
                f"conv_alpha={conv_alpha}"
            ]
        },
        "LoHa": {
            "module": "lycoris.kohya",
            "args"  : [
                f"algo=loha",
                f"conv_dim={conv_dim}",
                f"conv_alpha={conv_alpha}"
            ]
        },
        "IA3": {
            "module": "lycoris.kohya",
            "args"  : [
                f"algo=ia3",
                f"conv_dim={conv_dim}",
                f"conv_alpha={conv_alpha}"
            ]
        },
        "LoKR": {
            "module": "lycoris.kohya",
            "args"  : [
                f"algo=lokr",
                f"conv_dim={conv_dim}",
                f"conv_alpha={conv_alpha}"
            ]
        },
        "DyLoRA_Lycoris": {
            "module": "lycoris.kohya",
            "args"  : [
                f"algo=dylora",
                f"conv_dim={conv_dim}",
                f"conv_alpha={conv_alpha}"
            ]
        }
    }
    print("ahat3")

    network_module = network_config[network_category]["module"]
    network_args.extend(network_config[network_category]["args"])

    lora_config = {
        "additional_network_arguments": {
            "no_metadata"                     : False,
            "network_module"                  : network_module,
            "network_dim"                     : network_dim,
            "network_alpha"                   : network_alpha,
            "network_args"                    : network_args,
            "network_train_unet_only"         : True,
            "training_comment"                : None,
        },
    }
    print("ahat4")

    print(toml.dumps(lora_config))










    import toml
    import ast
    print("ahat5")

    # @title ## **4.2. Optimizer Config**
    # @markdown Use `Adafactor` optimizer. `RMSprop 8bit` or `Adagrad 8bit` may work. `AdamW 8bit` doesn't seem to work.
    optimizer_type = "AdaFactor"  # @param ["AdamW", "AdamW8bit", "Lion8bit", "Lion", "SGDNesterov", "SGDNesterov8bit", "DAdaptation(DAdaptAdamPreprint)", "DAdaptAdaGrad", "DAdaptAdam", "DAdaptAdan", "DAdaptAdanIP", "DAdaptLion", "DAdaptSGD", "AdaFactor"]
    # @markdown Specify `optimizer_args` to add `additional` args for optimizer, e.g: `["weight_decay=0.6"]`
    optimizer_args = "[ \"scale_parameter=False\", \"relative_step=False\", \"warmup_init=False\" ]"  # @param {'type':'string'}
    # @markdown ### **Learning Rate Config**
    # @markdown Different `optimizer_type` and `network_category` for some condition requires different learning rate. It's recommended to set `text_encoder_lr = 1/2 * unet_lr`
    learning_rate = 1e-4  # @param {'type':'number'}
    # @markdown ### **LR Scheduler Config**
    # @markdown `lr_scheduler` provides several methods to adjust the learning rate based on the number of epochs.
    lr_scheduler = "constant_with_warmup"  # @param ["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup", "adafactor"] {allow-input: false}
    lr_warmup_steps = 100  # @param {'type':'number'}
    # @markdown Specify `lr_scheduler_num` with `num_cycles` value for `cosine_with_restarts` or `power` value for `polynomial`
    lr_scheduler_num = 0  # @param {'type':'number'}

    if isinstance(optimizer_args, str):
        optimizer_args = optimizer_args.strip()
        if optimizer_args.startswith('[') and optimizer_args.endswith(']'):
            try:
                optimizer_args = ast.literal_eval(optimizer_args)
            except (SyntaxError, ValueError) as e:
                print(f"Error parsing optimizer_args: {e}\n")
                optimizer_args = []
        elif len(optimizer_args) > 0:
            print(f"WARNING! '{optimizer_args}' is not a valid list! Put args like this: [\"args=1\", \"args=2\"]\n")
            optimizer_args = []
        else:
            optimizer_args = []
    else:
        optimizer_args = []

    optimizer_config = {
        "optimizer_arguments": {
            "optimizer_type"          : optimizer_type,
            "learning_rate"           : learning_rate,
            "max_grad_norm"           : 0,
            "optimizer_args"          : optimizer_args,
            "lr_scheduler"            : lr_scheduler,
            "lr_warmup_steps"         : lr_warmup_steps,
            "lr_scheduler_num_cycles" : lr_scheduler_num if lr_scheduler == "cosine_with_restarts" else None,
            "lr_scheduler_power"      : lr_scheduler_num if lr_scheduler == "polynomial" else None,
            "lr_scheduler_type"       : None,
            "lr_scheduler_args"       : None,
        },
    }

    print("ahat6")
    print(toml.dumps(optimizer_config))










    # @title ## **4.3. Advanced Training Config** (Optional)
    import toml

    print("ahat7")

    # @markdown ### **Optimizer State Config**
    save_optimizer_state      = False #@param {type:"boolean"}
    load_optimizer_state      = "" #@param {type:"string"}
    # @markdown ### **Noise Control**
    noise_control_type        = "none" #@param ["none", "noise_offset", "multires_noise"]
    # @markdown #### **a. Noise Offset**
    # @markdown Control and easily generating darker or light images by offset the noise when fine-tuning the model. Recommended value: `0.1`. Read [Diffusion With Offset Noise](https://www.crosslabs.org//blog/diffusion-with-offset-noise)
    noise_offset_num          = 0.0357  # @param {type:"number"}
    # @markdown **[Experimental]**
    # @markdown Automatically adjusts the noise offset based on the absolute mean values of each channel in the latents when used with `--noise_offset`. Specify a value around 1/10 to the same magnitude as the `--noise_offset` for best results. Set `0` to disable.
    adaptive_noise_scale      = 0.00357 # @param {type:"number"}
    # @markdown #### **b. Multires Noise**
    # @markdown enable multires noise with this number of iterations (if enabled, around 6-10 is recommended)
    multires_noise_iterations = 6 #@param {type:"slider", min:1, max:10, step:1}
    multires_noise_discount = 0.3 #@param {type:"slider", min:0.1, max:1, step:0.1}
    # @markdown ### **Caption Dropout**
    caption_dropout_rate = 0  # @param {type:"number"}
    caption_tag_dropout_rate = 0.5  # @param {type:"number"}
    caption_dropout_every_n_epochs = 0  # @param {type:"number"}
    # @markdown ### **Custom Train Function**
    # @markdown Gamma for reducing the weight of high-loss timesteps. Lower numbers have a stronger effect. The paper recommends `5`. Read the paper [here](https://arxiv.org/abs/2303.09556).
    min_snr_gamma             = 5 #@param {type:"number"}

    advanced_training_config = {
        "advanced_training_config": {
            "resume"                        : load_optimizer_state,
            "save_state"                    : save_optimizer_state,
            "save_last_n_epochs_state"      : save_optimizer_state,
            "noise_offset"                  : noise_offset_num if noise_control_type == "noise_offset" else None,
            "adaptive_noise_scale"          : adaptive_noise_scale if adaptive_noise_scale and noise_control_type == "noise_offset" else None,
            "multires_noise_iterations"     : multires_noise_iterations if noise_control_type =="multires_noise" else None,
            "multires_noise_discount"       : multires_noise_discount if noise_control_type =="multires_noise" else None,
            "caption_dropout_rate"          : caption_dropout_rate,
            "caption_tag_dropout_rate"      : caption_tag_dropout_rate,
            "caption_dropout_every_n_epochs": caption_dropout_every_n_epochs,
            "min_snr_gamma"                 : min_snr_gamma if not min_snr_gamma == -1 else None,
        }
    }
    print("ahat8")

    print(toml.dumps(advanced_training_config))








    # @title ## **4.4. Training Config**
    import toml
    import os
    from subprocess import getoutput

    %store -r
    print("ahat9")
    print("project_name", project_name)

    # @markdown ### **Project Config**
    output_dir                  = "/content/drive/MyDrive/dreambooth/output" # @param {type:"string"}
    # @markdown Get your `wandb_api_key` [here](https://wandb.ai/settings) to logs with wandb.
    wandb_api_key               = "" # @param {type:"string"}
    in_json                     = bucketing_json
    # @markdown ### **SDXL Config**
    gradient_checkpointing      = True  # @param {type:"boolean"}
    no_half_vae                 = True  # @param {type:"boolean"}
    #@markdown Recommended parameter for SDXL training but if you enable it, `shuffle_caption` won't work
    cache_text_encoder_outputs  = False  # @param {type:"boolean"}
    #@markdown These options can be used to train U-Net with different timesteps. The default values are 0 and 1000.
    min_timestep                = 0 # @param {type:"number"}
    max_timestep                = 1000 # @param {type:"number"}
    # @markdown ### **Dataset Config**
    num_repeats                 = 1  # @param {type:"number"}
    resolution                  = 1024  # @param {type:"slider", min:512, max:1024, step:128}
    keep_tokens                 = 0  # @param {type:"number"}
    # @markdown ### **General Config**
    num_epochs                  = 1  # @param {type:"number"}
    train_batch_size            = 4  # @param {type:"number"}
    mixed_precision             = "fp16"  # @param ["no","fp16","bf16"] {allow-input: false}
    seed                        = -1  # @param {type:"number"}
    optimization                = "scaled dot-product attention" # @param ["xformers", "scaled dot-product attention"]
    # @markdown ### **Save Output Config**
    save_precision              = "fp16"  # @param ["float", "fp16", "bf16"] {allow-input: false}
    save_every_n_epochs         = 1  # @param {type:"number"}
    # @markdown ### **Sample Prompt Config**
    enable_sample               = False  # @param {type:"boolean"}
    sampler                     = "euler_a"  # @param ["ddim", "pndm", "lms", "euler", "euler_a", "heun", "dpm_2", "dpm_2_a", "dpmsolver","dpmsolver++", "dpmsingle", "k_lms", "k_euler", "k_euler_a", "k_dpm_2", "k_dpm_2_a"]
    positive_prompt             = ""
    negative_prompt             = ""
    quality_prompt              = "NovelAI"  # @param ["None", "Waifu Diffusion 1.5", "NovelAI", "AbyssOrangeMix", "Stable Diffusion XL"] {allow-input: false}
    if quality_prompt          == "NovelAI":
        positive_prompt         = "masterpiece, best quality, "
        negative_prompt         = "lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry, "
    if quality_prompt          == "AbyssOrangeMix":
        positive_prompt         = "masterpiece, best quality, "
        negative_prompt         = "(worst quality, low quality:1.4), "
    if quality_prompt          == "Stable Diffusion XL":
        negative_prompt         = "3d render, smooth, plastic, blurry, grainy, low-resolution, deep-fried, oversaturated"
    custom_prompt               = "face focus, cute, 1girl, green hair, sweater, looking at viewer, upper body, beanie, outdoors, night, turtleneck" # @param {type:"string"}
    # @markdown Specify `prompt_from_caption` if you want to use caption as prompt instead. Will be chosen randomly.
    prompt_from_caption         = "none"  # @param ["none", ".txt", ".caption"]
    if prompt_from_caption     != "none":
        custom_prompt           = ""
    num_prompt                  = 2  # @param {type:"number"}
    logging_dir                 = os.path.join(training_dir, "logs")
    lowram                      = int(next(line.split()[1] for line in open('/proc/meminfo') if "MemTotal" in line)) / (1024**2) < 15

    os.chdir(repo_dir)

    # Sample-image generation settings (later serialized to sample_prompt.toml).
    # "subset" starts empty and is filled by prompt_convert() when sampling is
    # enabled.
    prompt_config = {
        "prompt": {
            "negative_prompt" : negative_prompt,
            "width"           : resolution,
            "height"          : resolution,
            "scale"           : 12,
            "sample_steps"    : 28,
            "subset"          : [],
        }
    }

    # Training settings (later serialized with toml.dumps into
    # config_file.toml); each top-level key is one section of that file.
    train_config = {
        "sdxl_arguments": {
            "cache_text_encoder_outputs" : cache_text_encoder_outputs,
            "no_half_vae"                : True,
            "min_timestep"               : min_timestep,
            "max_timestep"               : max_timestep,
            # Caption shuffling is disabled whenever text-encoder outputs are cached.
            "shuffle_caption"            : not cache_text_encoder_outputs,
            "lowram"                     : lowram,
        },
        "model_arguments": {
            "pretrained_model_name_or_path" : model_path,
            "vae"                           : vae_path,
        },
        "dataset_arguments": {
            "debug_dataset"                 : False,
            "in_json"                       : in_json,
            "train_data_dir"                : train_data_dir,
            "dataset_repeats"               : num_repeats,
            "keep_tokens"                   : keep_tokens,
            "resolution"                    : str(resolution) + ',' + str(resolution),
            "color_aug"                     : False,
            "face_crop_aug_range"           : None,
            "token_warmup_min"              : 1,
            "token_warmup_step"             : 0,
        },
        "training_arguments": {
            "output_dir"                    : output_dir,
            "output_name"                   : project_name or "last",
            "save_precision"                : save_precision,
            "save_every_n_epochs"           : save_every_n_epochs,
            "save_n_epoch_ratio"            : None,
            "save_last_n_epochs"            : None,
            "resume"                        : None,
            "train_batch_size"              : train_batch_size,
            "max_token_length"              : 225,
            "mem_eff_attn"                  : False,
            # Exactly one attention backend flag is set, driven by `optimization`.
            "sdpa"                          : optimization == "scaled dot-product attention",
            "xformers"                      : optimization == "xformers",
            "max_train_epochs"              : num_epochs,
            "max_data_loader_n_workers"     : 8,
            "persistent_data_loader_workers": True,
            # A non-positive seed means "no fixed seed".
            "seed"                          : seed if seed > 0 else None,
            "gradient_checkpointing"        : gradient_checkpointing,
            "gradient_accumulation_steps"   : 1,
            "mixed_precision"               : mixed_precision,
        },
        "logging_arguments": {
            # W&B is used when an API key is provided, TensorBoard otherwise.
            "log_with"          : "wandb" if wandb_api_key else "tensorboard",
            "log_tracker_name"  : project_name if wandb_api_key and project_name != "last" else None,
            "logging_dir"       : logging_dir,
            "log_prefix"        : project_name if not wandb_api_key else None,
        },
        "sample_prompt_arguments": {
            "sample_every_n_steps"    : None,
            # Samples are generated on the checkpoint-saving cadence, if enabled.
            "sample_every_n_epochs"   : save_every_n_epochs if enable_sample else None,
            "sample_sampler"          : sampler,
        },
        "saving_arguments": {
            "save_model_as": "safetensors",
        },
    }

    def write_file(filename, contents):
        """Write the text ``contents`` to ``filename``, replacing any existing file.

        The text is encoded as UTF-8 explicitly. The generated TOML embeds
        user-supplied paths that may contain non-ASCII characters, and the
        platform's default encoding is not guaranteed to be UTF-8.
        """
        with open(filename, "w", encoding="utf-8") as f:
            f.write(contents)

    def prompt_convert(enable_sample, num_prompt, train_data_dir, prompt_config, custom_prompt):
        """Return a sample-prompt config whose ``prompt.subset`` list is filled in.

        If ``enable_sample`` is false, ``prompt_config`` is returned unchanged.

        Otherwise the prompts come from one of two sources:

        * Up to ``num_prompt`` randomly chosen caption files. Their extension is
          taken from the enclosing scope's ``prompt_from_caption`` (``"none"``
          disables this source), and they are searched recursively under
          ``train_data_dir``.
        * Otherwise ``custom_prompt``, or a built-in fallback prompt if that is
          empty too.

        The enclosing scope's ``positive_prompt`` is prepended to every prompt.
        ``prompt_config`` itself is never mutated; a deep copy is returned.
        """
        import copy
        import glob
        import random

        if not enable_sample:
            return prompt_config

        caption_files = []
        # Skip the search entirely for "none"; otherwise "**/*none" would be
        # globbed and could match unrelated files or folders.
        if prompt_from_caption != "none":
            search_pattern = os.path.join(train_data_dir, '**/*' + prompt_from_caption)
            caption_files = [
                path for path in glob.glob(search_pattern, recursive=True)
                if os.path.isfile(path)  # a folder named e.g. "x.txt" cannot be read
            ]

        if caption_files:
            sample_size = min(int(num_prompt), len(caption_files))
            prompts = []
            for path in random.sample(caption_files, sample_size):
                with open(path, 'r', encoding="utf-8") as f:
                    prompts.append(f.read().strip())
        else:
            if not custom_prompt:
                custom_prompt = "masterpiece, best quality, 1girl, aqua eyes, baseball cap, blonde hair, closed mouth, earrings, green background, hat, hoop earrings, jewelry, looking at viewer, shirt, short hair, simple background, solo, upper body, yellow shirt"
            prompts = [custom_prompt]

        # A deep copy keeps the caller's nested "prompt" dict untouched
        # (dict.copy() would share it).
        new_prompt_config = copy.deepcopy(prompt_config)
        new_prompt_config['prompt']['subset'] = [
            {"prompt": positive_prompt + prompt if positive_prompt else prompt}
            for prompt in prompts
        ]
        return new_prompt_config

    def eliminate_none_variable(config):
        """Turn empty-string values into ``None``, in place, at most two levels deep.

        Top-level values are checked, and so are the values of any dict stored
        directly at the top level. Deeper nested dicts and list items are left
        as they are. The same (mutated) ``config`` object is returned.
        """
        for section, value in config.items():
            if not isinstance(value, dict):
                if value == "":
                    config[section] = None
                continue
            # Reassigning existing keys during iteration is safe: the dict's
            # size does not change.
            for option, setting in value.items():
                if setting == "":
                    value[option] = None
        return config

    # Merge the section dictionaries produced by earlier cells into
    # train_config. The optimizer and LoRA sections are mandatory; the
    # NameError raised when a cell was skipped is replaced by a clearer one.
    try:
        train_config.update(optimizer_config)
    except NameError:
        raise NameError("'optimizer_config' dictionary is missing. Please run  '4.1. Optimizer Config' cell.") from None

    try:
        train_config.update(lora_config)
    except NameError:
        raise NameError("'lora_config' dictionary is missing. Please run  '4.1. LoRa: Low-Rank Adaptation Config' cell.") from None

    # The advanced section is optional: remember that it is missing and warn
    # after the config has been written, instead of failing.
    advanced_training_warning = False
    try:
        train_config.update(advanced_training_config)
    except NameError:
        advanced_training_warning = True

    prompt_config = prompt_convert(enable_sample, num_prompt, train_data_dir, prompt_config, custom_prompt)

    config_path         = os.path.join(config_dir, "config_file.toml")
    prompt_path         = os.path.join(config_dir, "sample_prompt.toml")

    # eliminate_none_variable() rewrites "" values to None in place before
    # serialization (presumably so toml.dumps leaves those keys out).
    config_str          = toml.dumps(eliminate_none_variable(train_config))
    prompt_str          = toml.dumps(eliminate_none_variable(prompt_config))

    write_file(config_path, config_str)
    write_file(prompt_path, prompt_str)

    print(config_str)

    if advanced_training_warning:
        import textwrap
        error_message = "WARNING: This is not an error message, but the [advanced_training_config] dictionary is missing. Please run the '4.2. Advanced Training Config' cell if you intend to use it, or continue to the next step."
        wrapped_message = textwrap.fill(error_message, width=80)
        # Shown in a reddish 24-bit ANSI color, then reset.
        print('\033[38;2;204;102;102m' + wrapped_message + '\033[0m\n')

    print(prompt_str)


    #@title ## **4.5. Start Training**
    import os
    import toml

    #@markdown Check your config here if you want to edit something:
    #@markdown - `sample_prompt` : /content/LoRA/config/sample_prompt.toml
    #@markdown - `config_file` : /content/LoRA/config/config_file.toml


    #@markdown You can import config from another session if you want.

    # Paths of the TOML files handed to the training script below. They are
    # meant to point at sample_prompt.toml / config_file.toml produced by the
    # config step, or at files from another session.
    # NOTE(review): these defaults are hard-coded to /content/LoRA/config, while
    # the config step writes into `config_dir`; confirm the two locations agree.
    sample_prompt   = "/content/LoRA/config/sample_prompt.toml" #@param {type:'string'}
    config_file     = "/content/LoRA/config/config_file.toml" #@param {type:'string'}

    def read_file(filename):
        """Return the full text of ``filename``, decoded as UTF-8.

        UTF-8 is explicit for symmetry with ``write_file``: the config files may
        contain non-ASCII paths, and the platform default encoding may differ.
        """
        with open(filename, "r", encoding="utf-8") as f:
            return f.read()

    def train(config):
        """Render ``config`` as a shell-ready command-line argument string.

        Despite its name, this only builds arguments; it does not start
        training. Entries are processed in insertion order:

        * key starting with ``_``   -> positional argument (the value itself)
        * ``str`` value             -> ``--key=value``
        * ``True``                  -> bare flag ``--key``
        * ``int`` / ``float`` value -> ``--key=value``
        * anything else (``None``, ``False``, lists, ...) is omitted.

        Every emitted argument is followed by one space, so an empty config
        yields ``""``.

        String and positional values are escaped with ``shlex.quote`` because
        the result is executed by a shell. Inside plain double quotes the shell
        would still interpret ``$``, backticks, ``\\`` and ``"`` in a path or key.
        """
        import shlex

        parts = []
        for k, v in config.items():
            if k.startswith("_"):
                parts.append(shlex.quote(str(v)))
            elif isinstance(v, str):
                parts.append(f"--{k}={shlex.quote(v)}")
            elif isinstance(v, bool):
                # Must be tested before int (bool subclasses int); False is dropped.
                if v:
                    parts.append(f"--{k}")
            elif isinstance(v, (int, float)):
                parts.append(f"--{k}={v}")

        return "".join(f"{part} " for part in parts)

    # Options for `accelerate launch` itself.
    accelerate_conf = {
        "config_file" : "/content/kohya-trainer/accelerate_config/config.yaml",
        "num_cpu_threads_per_process" : 1,
    }

    # Options for sdxl_train_network.py. train() emits nothing for None values,
    # so a missing sample-prompt file or an empty W&B key simply drops that flag.
    train_conf = {
        "sample_prompts"  : sample_prompt if os.path.exists(sample_prompt) else None,
        "config_file"     : config_file,
        # NOTE(review): the W&B key is passed as a command-line argument, so it is
        # visible in the process list while training runs.
        "wandb_api_key"   : wandb_api_key if wandb_api_key else None
    }

    accelerate_args = train(accelerate_conf)
    train_args = train(train_conf)

    final_args = f"accelerate launch {accelerate_args} sdxl_train_network.py {train_args}"

    # The script is named without a directory, so run from the repository root.
    # The `!` line is IPython/Colab shell-escape syntax (not plain Python): it
    # substitutes {final_args} and runs the command in a subshell.
    os.chdir(repo_dir)
    !{final_args}

Stored 'train_folder_directory' (str)
Found 84 images.
Creating a new metadata file
Merging tags and captions into metadata json.
100% 84/84 [00:00<00:00, 101.79it/s]
No captions found for any of the 84 images
All 84 images have tags
Writing metadata: /content/drive/MyDrive/dreambooth/dataset/train/쌍둥이공주_파인1_3/meta_clean.json
Done!
found 84 images.
loading existing metadata: /content/drive/MyDrive/dreambooth/dataset/train/쌍둥이공주_파인1_3/meta_clean.json
load VAE: /content/vae/sdxl_vae.safetensors
100% 84/84 [00:56<00:00,  1.49it/s]
bucket 0 (384, 1024): 1
bucket 1 (448, 1024): 1
bucket 2 (512, 1024): 5
bucket 3 (640, 1024): 8
bucket 4 (704, 1024): 9
bucket 5 (768, 1024): 8
bucket 6 (832, 1024): 11
bucket 7 (896, 1024): 3
bucket 8 (960, 1024): 1
bucket 9 (1024, 704): 1
bucket 10 (1024, 768): 29
bucket 11 (1024, 832): 1
bucket 12 (1024, 896): 1
bucket 13 (1024, 960): 2
bucket 14 (1024, 1024): 3
mean ar error: 0.014482077909482834
writing metadata: /content/drive/MyDrive