In [1]:
# 1.0 DEFINE DIRECTORIES

import os
import ipywidgets as widgets

# root_dir: every working directory of this notebook lives under the Studio Lab notebooks folder.
root_dir = "/home/studio-lab-user/sagemaker-studiolab-notebooks"
deps_dir, repo_dir, pretrained_dir, vae_dir = (
    os.path.join(root_dir, name)
    for name in ("deps", "kohya-trainer", "pretrained_model", "vae")
)

# DreamBooth working tree: config/, output/ (with output/sample/) and logs/.
dreambooth_training_dir = os.path.join(root_dir, "dreambooth")
dreambooth_config_dir, dreambooth_output_dir, dreambooth_logging_dir = (
    os.path.join(dreambooth_training_dir, name)
    for name in ("config", "output", "logs")
)
dreambooth_sample_dir = os.path.join(dreambooth_output_dir, "sample")

inference_dir = os.path.join(root_dir, "txt2img")

train_data_dir, reg_data_dir = (
    os.path.join(root_dir, name) for name in ("train_data", "reg_data")
)

# repo_dir: locations inside the cloned kohya-trainer checkout.
accelerate_config = os.path.join(repo_dir, "accelerate_config", "config.yaml")
tools_dir = os.path.join(repo_dir, "tools")
finetune_dir = os.path.join(repo_dir, "finetune")

os.chdir(root_dir)

# Create every working directory up front (no-op when it already exists).
# The loop variable is deliberately not called `dir`: at notebook top level that
# would shadow the builtin dir() for the rest of the kernel session.
for required_dir in (
    deps_dir,
    dreambooth_training_dir,
    dreambooth_config_dir,
    dreambooth_output_dir,
    pretrained_dir,
    dreambooth_sample_dir,
    inference_dir,
    vae_dir,
    train_data_dir,
    reg_data_dir,
):
    os.makedirs(required_dir, exist_ok=True)

# User settings for the clone/install cell.
repo_url = widgets.Text(
    value="https://github.com/TensorMouse/kohya-trainer",
    description="Repository URL:",
    style={"description_width": "initial"},
    layout=widgets.Layout(width="50%"),
)
branch = widgets.Text(
    value="",
    description="Branch:",
    style={"description_width": "initial"},
    layout=widgets.Layout(width="50%"),
)
# Hint rendered above the Branch box. This box selects a branch, not a repository.
tooltip_branch = widgets.HTML(
    value='<span style="color: blue;">Leave the box empty to use the default branch.</span>'
)
install_xformers = widgets.Checkbox(
    value=True,
    description="Install xformers",
)
verbose = widgets.Checkbox(
    value=False,
    description="Verbose",
)

box = widgets.VBox([
    repo_url,
    widgets.VBox([tooltip_branch, branch]),
    install_xformers,
    verbose,
])
box

VBox(children=(Text(value='https://github.com/TensorMouse/kohya-trainer', description='Repository URL:', layou…

In [2]:
# @title ## 1.1.0 CLONE REPO
# @markdown Clone Kohya Trainer from GitHub (an existing clone is reused as-is) and install the required libraries. The Branch box in the first cell is for checking out another branch or an older commit; leave it empty to stay on the HEAD of the default branch.

import os
import shutil
from subprocess import getoutput

def clone_repo(url,repo_dir):
    if not os.path.exists(repo_dir):
        os.chdir(root_dir)
        !git clone {url} {repo_dir}
    else:
        print("Repo already exists")
        #os.chdir(repo_dir)
        #!git pull origin {branch.value} if branch.value else !git pull
        
def install_dependencies():
    print('Installation can take multiple minutes, enable "Verbose" to see progress')
    s = getoutput('nvidia-smi')

    if 'T4' in s:
        !sed -i "s@cpu@cuda@" library/model_util.py

    !pip install {'-q' if not verbose.value else ''} --upgrade -r requirements.txt
    !pip install {'-q' if not verbose.value else ''} torch==2.0.0+cu118 torchvision==0.15.1+cu118 torchaudio==2.0.1+cu118 torchtext==0.15.1 torchdata==0.6.0 --extra-index-url https://download.pytorch.org/whl/cu118 -U
    !conda install -c conda-forge glib --yes
    
    if install_xformers.value:
        !pip install {'-q' if not verbose.value else ''} xformers==0.0.19 triton==2.0.0 -U
        
    from accelerate.utils import write_basic_config

    if not os.path.exists(accelerate_config):
        write_basic_config(save_location=accelerate_config)


def main():
    """Clone (or reuse) the trainer repo, then install its dependencies from inside it."""
    clone_repo(repo_url.value, repo_dir)
    # install_dependencies() relies on repo-relative paths, so switch into the checkout first.
    os.chdir(repo_dir)
    install_dependencies()

# Run the clone + install steps, then clear pip's download cache
# (presumably to free disk space in the Studio Lab environment).
main()
!pip cache purge

Cloning into '/home/studio-lab-user/sagemaker-studiolab-notebooks/kohya-trainer'...
remote: Enumerating objects: 1988, done.[K
remote: Counting objects: 100% (656/656), done.[K
remote: Compressing objects: 100% (188/188), done.[K
remote: Total 1988 (delta 528), reused 514 (delta 468), pack-reused 1332[K
Receiving objects: 100% (1988/1988), 3.66 MiB | 27.79 MiB/s, done.
Resolving deltas: 100% (1283/1283), done.
Installation can take multiple minutes, enable "Verbose" to see progress
[33m  DEPRECATION: A future pip version will change local packages to be built in-place without first copying to a temporary directory. We recommend you use --use-feature=in-tree-build to test your packages with this new behavior before it becomes the default.
   pip 21.3 will remove support for this functionality. You can find discussion regarding this at https://github.com/pypa/pip/issues/7555.[0m
Collecting package metadata (current_repodata.json): done
Solving environment: done


  current version:



Files removed: 32


In [3]:
# @title ## 2.1. Download Available Model
import os
import ipywidgets as widgets

os.chdir(root_dir)

# Widget values are read live by the next (download) cell. Re-running this cell
# recreates the widgets with their defaults, so it must not be re-run to "lock in" choices.
tooltip_model = widgets.HTML(
    value='<span style="color: blue;">Leave a box empty to skip that download. Your selections are used when you run the next cell; re-running this cell resets them.</span>'
)

# SD 1.x checkpoint catalogue: display name -> direct download URL.
# The download cell checks for (and saves) each file under the URL's basename.
models = {
    "Animefull-final-pruned": "https://huggingface.co/Linaqruf/personal-backup/resolve/main/models/animefull-final-pruned.ckpt",
    "Anything-v3-1": "https://huggingface.co/cag/anything-v3-1/resolve/main/anything-v3-1.safetensors",
    "AnyLoRA": "https://huggingface.co/Linaqruf/stolen/resolve/main/pruned-models/AnyLoRA_noVae_fp16-pruned.safetensors",
    "AnimePastelDream": "https://huggingface.co/Lykon/AnimePastelDream/resolve/main/AnimePastelDream_Soft_noVae_fp16.safetensors",
    "Chillout-mix": "https://huggingface.co/Linaqruf/stolen/resolve/main/pruned-models/chillout_mix-pruned.safetensors",
    "OpenJourney-v4": "https://huggingface.co/prompthero/openjourney-v4/resolve/main/openjourney-v4.ckpt",
    "Stable-Diffusion-v1-5": "https://huggingface.co/Linaqruf/stolen/resolve/main/pruned-models/stable_diffusion_1_5-pruned.safetensors",
}

# SD 2.x checkpoint catalogue (same name -> URL shape as `models`).
v2_models = {
    "stable-diffusion-2-1-base": "https://huggingface.co/stabilityai/stable-diffusion-2-1-base/resolve/main/v2-1_512-ema-pruned.safetensors",
    "stable-diffusion-2-1-768v": "https://huggingface.co/stabilityai/stable-diffusion-2-1/resolve/main/v2-1_768-ema-pruned.safetensors",
    "plat-diffusion-v1-3-1": "https://huggingface.co/p1atdev/pd-archive/resolve/main/plat-v1-3-1.safetensors",
    "replicant-v1": "https://huggingface.co/gsdf/Replicant-V1.0/resolve/main/Replicant-V1.0.safetensors",
    "illuminati-diffusion-v1-0": "https://huggingface.co/IlluminatiAI/Illuminati_Diffusion_v1.0/resolve/main/illuminati_diffusion_v1.0.safetensors",
    "illuminati-diffusion-v1-1": "https://huggingface.co/4eJIoBek/Illuminati-Diffusion-v1-1/resolve/main/illuminatiDiffusionV1_v11.safetensors",
    "waifu-diffusion-1-4-anime-e2": "https://huggingface.co/hakurei/waifu-diffusion-v1-4/resolve/main/wd-1-4-anime_e2.ckpt",
    "waifu-diffusion-1-5-e2": "https://huggingface.co/waifu-diffusion/wd-1-5-beta2/resolve/main/checkpoints/wd-1-5-beta2-fp32.safetensors",
    "waifu-diffusion-1-5-e2-aesthetic": "https://huggingface.co/waifu-diffusion/wd-1-5-beta2/resolve/main/checkpoints/wd-1-5-beta2-aesthetic-fp32.safetensors",
}

# VAE catalogue. The display keys are not the saved file names (files keep the
# URL basename). "none" maps to an empty URL; together with the dropdown's ""
# entry, both choices result in no VAE download.
vaes = {
    "none": "",
    "anime.vae.pt": "https://huggingface.co/Linaqruf/personal-backup/resolve/main/vae/animevae.pt",
    "waifudiffusion.vae.pt": "https://huggingface.co/hakurei/waifu-diffusion-v1-4/resolve/main/vae/kl-f8-anime.ckpt",
    "stablediffusion.vae.pt": "https://huggingface.co/stabilityai/sd-vae-ft-mse-original/resolve/main/vae-ft-mse-840000-ema-pruned.ckpt",
}

# Catalogue pickers: the leading "" entry means "download nothing from this list".
model_name = widgets.Dropdown(options=["", *models], value="Stable-Diffusion-v1-5", description="SD1.x model:")
v2_model_name = widgets.Dropdown(options=["", *v2_models], value="", description="SD2.x model:")

# Free-form URL for a checkpoint that is not in the catalogues.
custom_model_url_text = widgets.Text(placeholder="Paste custom model URL", description="Custom Model:")

# VAE picker plus free-form VAE URL.
vaes_name = widgets.Dropdown(options=["", *vaes], value="stablediffusion.vae.pt", description="Vae model:")
vaes_custom = widgets.Text(placeholder="Paste custom vaes URL", description="Custom Vae Model:")

dropdown_box = widgets.VBox(
    [tooltip_model, model_name, v2_model_name, custom_model_url_text, vaes_name, vaes_custom]
)
dropdown_box

VBox(children=(HTML(value='<span style="color: blue;">Leave boxes empty to not install a model. Run this box a…

In [4]:
os.chdir(pretrained_dir)

if model_name.value:
    model_url = models.get(model_name.value)
    model_file = os.path.basename(model_url)
    if not os.path.exists(os.path.join(pretrained_dir, model_file)):
        !wget {model_url}
    else:
        print(f"Model '{model_file}' already exists in the directory. Skipping download.")

if v2_model_name.value:
    v2_model_url = v2_models.get(v2_model_name.value)
    v2_model_file = os.path.basename(v2_model_url)
    if not os.path.exists(os.path.join(pretrained_dir, v2_model_file)):
        !wget {v2_model_url}
    else:
        print(f"Model '{v2_model_file}' already exists in the directory. Skipping download.")

if custom_model_url_text.value:
    !wget -nc --content-disposition {custom_model_url_text.value}
    
os.chdir(vae_dir)
    
if vaes_name.value:
    vaes_url = vaes.get(vaes_name.value)
    vaes_file = os.path.basename(vaes_url)
    if not os.path.exists(os.path.join(vae_dir, vaes_file)):
        !wget {vaes_url}
    else:
        print(f"Vae Model '{vaes_file}' already exists in the directory. Skipping download.")

if vaes_custom.value:
    !wget -nc --content-disposition {vaes_custom.value}


--2023-06-02 12:39:28--  https://civitai.com/api/download/models/29460
Resolving civitai.com (civitai.com)... 104.18.23.206, 104.18.22.206, 2606:4700::6812:17ce, ...
Connecting to civitai.com (civitai.com)|104.18.23.206|:443... connected.
HTTP request sent, awaiting response... 307 Temporary Redirect
Location: https://civitai-delivery-worker-prod-2023-06-01.5ac0637cfd0766c97916cefa3764fbdf.r2.cloudflarestorage.com/26957/training-images/realisticVisionV20Fp16.Or1n.safetensors?X-Amz-Expires=86400&response-content-disposition=attachment%3B%20filename%3D%22realisticVisionV20_v20NoVAE.safetensors%22&X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=2fea663d76bd24a496545da373d610fc/20230602/us-east-1/s3/aws4_request&X-Amz-Date=20230602T123928Z&X-Amz-SignedHeaders=host&X-Amz-Signature=8a277a7026b1feb9089c343b18cbce3bf61144fccb9bb61856eb6bb4dd6600f5 [following]
--2023-06-02 12:39:28--  https://civitai-delivery-worker-prod-2023-06-01.5ac0637cfd0766c97916cefa3764fbdf.r2.cloudflarestorage.com/269

In [5]:
import ipywidgets as widgets

# Image-conversion options, created in display order: convert, random_color, recursive.
convert, random_color, recursive = (
    widgets.Checkbox(value=False, description=label, layout=widgets.Layout(width="auto"))
    for label in (
        "Convert to RGB with white background",
        "Use random color background",
        "Process subfolders as well",
    )
)

def update_random_color_checkbox(change):
    """Observer for ``convert``: the random-colour option is editable only while conversion is ticked."""
    conversion_enabled = change["new"]
    random_color.disabled = not conversion_enabled

convert.observe(update_random_color_checkbox, "value")
# observe() only fires on later *changes*, so apply the rule to the initial state too:
# with "Convert" unticked, the random-colour option must start disabled.
random_color.disabled = not convert.value

widget_box = widgets.VBox([convert, random_color, recursive])
widget_box

VBox(children=(Checkbox(value=False, description='Convert to RGB with white background', layout=Layout(width='…

In [6]:
# @title ## 4.1. Data Cleaning
import os
import random
import concurrent.futures
from tqdm import tqdm
from PIL import Image

os.chdir(root_dir)

# Number of images handed to the thread pool per progress-bar step in the conversion loop.
# NOTE(review): the captioning-settings cell later rebinds `batch_size` to an IntSlider;
# re-running this cell after that one replaces the slider with this plain int.
batch_size = 32

# Extensions clean_directory keeps; files with any other extension are deleted.
supported_types = [
    # images
    ".png",
    ".jpg",
    ".jpeg",
    ".webp",
    ".bmp",
    # sidecar files
    ".caption",
    ".npz",
    ".txt",
    ".json",
]

# Candidate fill colours when "Use random color background" is ticked (white is the default).
background_colors = [
    (255, 255, 255),
    (0, 0, 0),
    (255, 0, 0),
    (0, 255, 0),
    (0, 0, 255),
    (255, 255, 0),
    (255, 0, 255),
    (0, 255, 255),
]

def clean_directory(directory, recurse=None, allowed_types=None):
    """Delete every file in ``directory`` whose extension is not an allowed type.

    Args:
        directory: folder to clean.
        recurse: descend into subfolders. ``None`` (the default) reads the
            "Process subfolders as well" checkbox (``recursive.value``). The original
            tested the widget object itself, which is always truthy.
        allowed_types: extensions to keep. ``None`` means the module-level
            ``supported_types``. Matching ignores case, so ``IMG_001.JPG`` is kept
            instead of being deleted.
    """
    if recurse is None:
        recurse = recursive.value
    if allowed_types is None:
        allowed_types = supported_types
    allowed = {ext.lower() for ext in allowed_types}

    for item in os.listdir(directory):
        file_path = os.path.join(directory, item)
        if os.path.isfile(file_path):
            if os.path.splitext(item)[1].lower() not in allowed:
                print(f"Deleting file {item} from {directory}")
                os.remove(file_path)
        elif recurse and os.path.isdir(file_path):
            clean_directory(file_path, recurse, allowed_types)

def process_image(image_path):
    """Flatten transparency onto a solid background and normalise the file format.

    * RGBA/LA images, and palette images that carry a transparency entry, are
      composited onto white. When the "Use random color background" checkbox is
      ticked, a random entry of ``background_colors`` is used instead.
    * ``.webp`` files are re-encoded as ``<stem>.jpg`` in the same folder and the
      original is deleted. Every other input is re-saved in place as PNG.
    """
    img_dir, image_name = os.path.split(image_path)
    stem, ext = os.path.splitext(image_name)

    # copy() forces a full decode, so the source file handle is closed before
    # this same path is overwritten or deleted below.
    with Image.open(image_path) as src:
        img = src.copy()

    # Palette images keep transparency in img.info rather than in an alpha band.
    if img.mode == "P" and "transparency" in img.info:
        img = img.convert("RGBA")

    has_alpha = img.mode in ("RGBA", "LA")
    if has_alpha:
        background_color = random.choice(background_colors) if random_color.value else (255, 255, 255)
        flattened = Image.new("RGB", img.size, background_color)
        flattened.paste(img, mask=img.split()[-1])
        img = flattened

    if ext.lower() == ".webp":
        # JPEG cannot store palette or alpha modes; convert anything unusual to RGB first.
        if img.mode not in ("RGB", "L"):
            img = img.convert("RGB")
        new_image_path = os.path.join(img_dir, stem + ".jpg")
        img.save(new_image_path, "JPEG")
        os.remove(image_path)
        print(f" Converted image: {image_name} to {os.path.basename(new_image_path)}")
    else:
        img.save(image_path, "PNG")
        if has_alpha:
            print(f" Converted image: {image_name}")

def find_images(directory):
    """Recursively collect paths of files under ``directory`` ending in ``.png`` or ``.webp`` (case-sensitive)."""
    return [
        os.path.join(root, name)
        for root, _, names in os.walk(directory)
        for name in names
        if name.endswith((".png", ".webp"))
    ]

clean_directory(train_data_dir)
images = find_images(train_data_dir)
# Ceiling division, so there is no extra empty batch when len(images) is an exact
# multiple of batch_size.
num_batches = -(-len(images) // batch_size)

if convert.value:
    with concurrent.futures.ThreadPoolExecutor() as executor:
        for i in tqdm(range(num_batches)):
            batch = images[i * batch_size:(i + 1) * batch_size]
            # Consuming the iterator waits for the batch before the bar advances, and it
            # re-raises any worker exception here instead of silently dropping it.
            list(executor.map(process_image, batch))

    print("All images have been converted")
print('all good')

100%|██████████| 1/1 [00:00<00:00, 11715.93it/s]

All images have been converted
all good





In [7]:
import ipywidgets as widgets

def _auto_width():
    """Fresh layout/style kwargs for one control (each widget gets its own Layout)."""
    return {"layout": widgets.Layout(width="auto"), "style": {"description_width": "initial"}}


# Which captioner to run.
captioning = widgets.Dropdown(options=["BLIP", "Waifu", "No captions"], value="BLIP", description="Captioning:", **_auto_width())

# Options passed to both captioners.
batch_size = widgets.IntSlider(value=8, min=1, max=16, step=1, description="Batch Size:", **_auto_width())
max_data_loader_n_workers = widgets.IntSlider(value=2, min=1, max=8, step=1, description="Data Loader Workers:", **_auto_width())

# Options only the BLIP captioner uses.
beam_search = widgets.Checkbox(value=True, description="Beam Search", **_auto_width())
min_length = widgets.IntSlider(value=5, min=0, max=100, step=5, description="Minimum Length:", **_auto_width())
max_length = widgets.IntSlider(value=75, min=0, max=100, step=5, description="Maximum Length:", **_auto_width())

# Shared again. NOTE: this rebinds `recursive`, replacing the image-conversion cell's checkbox.
recursive = widgets.Checkbox(value=False, description="Recursive", **_auto_width())
verbose_logging = widgets.Checkbox(value=True, description="Verbose Logging", **_auto_width())

# Options only the WD14 ("Waifu") tagger uses.
model = widgets.Dropdown(
    options=[
        "SmilingWolf/wd-v1-4-convnextv2-tagger-v2",
        "SmilingWolf/wd-v1-4-swinv2-tagger-v2",
        "SmilingWolf/wd-v1-4-convnext-tagger-v2",
        "SmilingWolf/wd-v1-4-vit-tagger-v2",
    ],
    value="SmilingWolf/wd-v1-4-convnextv2-tagger-v2",
    description="Model:",
    **_auto_width(),
)
undesired_tags = widgets.Text(value="", description="Undesired Tags:", **_auto_width())
general_threshold = widgets.FloatSlider(value=0.35, min=0.0, max=1.0, step=0.05, description="General Threshold:", **_auto_width())
character_threshold = widgets.FloatSlider(value=0.35, min=0.0, max=1.0, step=0.05, description="Character Threshold:", **_auto_width())

widget_box = widgets.VBox([
    captioning, batch_size, max_data_loader_n_workers,
    beam_search, min_length, max_length,
    recursive, verbose_logging,
    model, undesired_tags, general_threshold, character_threshold,
])

del _auto_width

def caption_default():
    """Disable the tagger-only widgets (the captioning dropdown starts on "BLIP")."""
    for tagger_widget in (model, undesired_tags, general_threshold, character_threshold):
        tagger_widget.disabled = True

def on_captioning_change(change):
    """Observer for ``captioning``: tagger settings are editable only while "Waifu" is selected."""
    locked = change.new != "Waifu"
    for tagger_widget in (model, undesired_tags, general_threshold, character_threshold):
        tagger_widget.disabled = locked

# Start with the tagger-only widgets disabled, then keep them in sync with the dropdown.
caption_default()
captioning.observe(on_captioning_change, names='value')

widget_box

VBox(children=(Dropdown(description='Captioning:', layout=Layout(width='auto'), options=('BLIP', 'Waifu', 'No …

In [8]:
#@title ### 4.2.1. BLIP Captioning
# Use BLIP for general images
# Use Waifu for anime/manga images
import os

os.chdir(finetune_dir)

if captioning.value == "BLIP":
    config = {
        "_train_data_dir" : train_data_dir,
        "batch_size" : batch_size.value,
        "beam_search" : beam_search.value,
        "min_length" : min_length.value,
        "max_length" : max_length.value,
        "debug" : verbose_logging.value,
        "caption_extension" : ".caption",
        "max_data_loader_n_workers" : max_data_loader_n_workers.value,
        "recursive" : recursive.value
    }

elif captioning.value == "Waifu":
    config = {
        "_train_data_dir": train_data_dir,
        "batch_size": batch_size.value,
        "repo_id": model.value,
        "recursive": recursive.value,
        "remove_underscore": True,
        "general_threshold": general_threshold.value,
        "character_threshold": character_threshold.value,
        "caption_extension": ".txt",
        "max_data_loader_n_workers": max_data_loader_n_workers.value,
        "debug": verbose_logging.value,
        "undesired_tags": undesired_tags.value
    }

else:
    print("No captioning option selected. Skipping captioning process.")

if captioning.value == "BLIP" or captioning == "Waifu":
    args = ""
    for k, v in config.items():
        if k.startswith("_"):
            args += f'"{v}" '
        elif isinstance(v, str):
            args += f'--{k}="{v}" '
        elif isinstance(v, bool) and v:
            args += f"--{k} "
        elif isinstance(v, float) and not isinstance(v, bool):
            args += f"--{k}={v} "
        elif isinstance(v, int) and not isinstance(v, bool):
            args += f"--{k}={v} "

    final_args = f"python make_captions.py {args}" if captioning.value == "BLIP" else f"python tag_images_by_wd14_tagger.py {args}"

    os.chdir(finetune_dir)
    !{final_args}

load images from /home/studio-lab-user/sagemaker-studiolab-notebooks/train_data
found 24 images.
loading BLIP caption: https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_caption.pth
Downloading (…)solve/main/vocab.txt: 100%|███| 232k/232k [00:00<00:00, 29.8MB/s]
Downloading (…)okenizer_config.json: 100%|███| 28.0/28.0 [00:00<00:00, 4.70kB/s]
Downloading (…)lve/main/config.json: 100%|██████| 570/570 [00:00<00:00, 347kB/s]
100%|██████████████████████████████████████| 1.66G/1.66G [00:29<00:00, 60.9MB/s]
load checkpoint from https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_caption.pth
BLIP loaded
  0%|                                                     | 0/3 [00:00<?, ?it/s]/home/studio-lab-user/sagemaker-studiolab-notebooks/train_data/0530dfe0-c3b8-11eb-aff9-7670003c98ba.jpg a woman in a red dress posing for a picture
/home/studio-lab-user/sagemaker-studiolab-notebooks/train_data/08f5d98b9e9e9f91fddfb1d4a9ad8ae6e3-13-chr

In [9]:
# @title ### 4.2.3. Custom Caption/Tag config
import ipywidgets as widgets

def _auto_width():
    """Fresh layout/style kwargs for one control (each widget gets its own Layout)."""
    return {"layout": widgets.Layout(width="auto"), "style": {"description_width": "initial"}}


# Which caption files to edit.
extension = widgets.Dropdown(options=[".txt", ".caption"], value=".txt", description="Extension:", **_auto_width())

# Free-text inputs. Subfolder "" means train_data_dir itself; "--all" also processes subfolders.
custom_tag, keyword, sub_folder = (
    widgets.Text(value="", description=label, **_auto_width())
    for label in ("Custom Tag:", "Keyword replaced by custom tag:", "Subfolder:")
)

# Behaviour toggles. NOTE: this rebinds `recursive` once more.
append, prefix_tag, remove_tag, recursive = (
    widgets.Checkbox(value=False, description=label, **_auto_width())
    for label in (
        "Append Custom Tags",
        "Prefix Custom Tags",
        "Remove Captions/Tags (Only this option will be executed if checked)",
        "Recursive",
    )
)

widget_box = widgets.VBox([
    extension, custom_tag, keyword, sub_folder,
    append, prefix_tag, remove_tag, recursive,
])

del _auto_width
widget_box

VBox(children=(Dropdown(description='Extension:', layout=Layout(width='auto'), options=('.txt', '.caption'), s…

In [21]:
# @title ### 4.2.3. Custom Caption/Tag
import os

os.chdir(root_dir)

# Subfolder box: "" -> train_data_dir; "--all" -> train_data_dir including every
# subfolder; anything else -> that subfolder of train_data_dir (created if missing).
if sub_folder.value == "":
    image_dir = train_data_dir
elif sub_folder.value == "--all":
    image_dir = train_data_dir
    # Tick the Recursive checkbox rather than rebinding the name: `recursive = True`
    # replaced the widget with a bool, so `recursive.value` later in this cell raised
    # AttributeError.
    recursive.value = True
else:
    image_dir = os.path.join(train_data_dir, sub_folder.value)
    os.makedirs(image_dir, exist_ok=True)

def read_file(filename):
    """Return the full text of ``filename``, decoded as UTF-8.

    The encoding is explicit because captions/tags may be non-ASCII and the
    platform default depends on the container's locale.
    """
    with open(filename, "r", encoding="utf-8") as f:
        return f.read()

def write_file(filename, contents):
    """Overwrite ``filename`` with ``contents``, encoded as UTF-8 (locale-independent)."""
    with open(filename, "w", encoding="utf-8") as f:
        f.write(contents)

def process_tags(filename, custom_tag, append, prefix_tag, remove_tag, keyword):
    """Edit one comma-separated caption/tag file in place.

    ``custom_tag`` may contain several comma-separated tags; underscores in them
    become spaces.

    * remove_tag: delete every occurrence of each custom tag (nothing else is done).
    * keyword: otherwise, each occurrence of this substring inside an existing tag
      is replaced by the first custom tag. An empty keyword means no replacement.
    * append / prefix_tag: add the custom tags at the end / start, in the given order.

    Empty entries are dropped, so an empty file does not gain a stray leading ", ".
    """
    with open(filename, "r", encoding="utf-8") as f:
        contents = f.read()
    tags = [tag.strip() for tag in contents.split(",") if tag.strip()]
    custom_tags = [tag.strip().replace("_", " ") for tag in custom_tag.split(",") if tag.strip()]

    if remove_tag:
        tags = [tag for tag in tags if tag not in custom_tags]
    elif custom_tags:
        # `"" in s` is always True and s.replace("", x) splices x between every
        # character, so an empty keyword (the widget default) must skip replacement.
        if keyword:
            tags = [tag.replace(keyword, custom_tags[0]) for tag in tags]
        if append:
            tags = tags + custom_tags
        if prefix_tag:
            tags = custom_tags + tags

    with open(filename, "w", encoding="utf-8") as f:
        f.write(", ".join(tags))

def process_directory(image_dir, tag, append, prefix_tag, remove_tag, recursive, keyword):
    """Run process_tags on every entry of ``image_dir`` ending with the selected extension.

    Subdirectories are descended into only when ``recursive`` is true.
    """
    for entry in os.listdir(image_dir):
        entry_path = os.path.join(image_dir, entry)
        if recursive and os.path.isdir(entry_path):
            process_directory(entry_path, tag, append, prefix_tag, remove_tag, recursive, keyword)
            continue
        if entry.endswith(extension.value):
            process_tags(entry_path, tag, append, prefix_tag, remove_tag, keyword)

tag = custom_tag.value

# If the folder has no caption files yet, create an empty one per image so the tag
# pass below has something to edit.
if not any(name.endswith(extension.value) for name in os.listdir(image_dir)):
    for name in os.listdir(image_dir):
        stem, image_ext = os.path.splitext(name)
        # splitext keeps dotted stems intact ("img.v2.png" -> "img.v2.txt"), whereas
        # split(".")[0] truncated them ("img.txt") so captions no longer paired with images.
        if image_ext.lower() in (".png", ".jpg", ".jpeg", ".webp", ".bmp"):
            open(os.path.join(image_dir, stem + extension.value), "w").close()

if custom_tag.value:
    process_directory(image_dir, tag, append.value, prefix_tag.value, remove_tag.value, recursive.value, keyword.value)


In [22]:
# @title ## 5.1. Model Config widget

import ipywidgets as widgets

def _list_files(directory):
    """Return full paths of the regular files directly inside ``directory``, sorted by name.

    Sorting makes dropdown order (and therefore the default selection) deterministic;
    raw os.listdir order is arbitrary.
    """
    return sorted(
        os.path.join(directory, filename)
        for filename in os.listdir(directory)
        if os.path.isfile(os.path.join(directory, filename))
    )


pretrained_model_options = _list_files(pretrained_dir)
vae_options = _list_files(vae_dir)
# Earlier DreamBooth outputs, offered both as base models and as resume points.
resume_options = _list_files(dreambooth_output_dir)

def _auto_width():
    """Fresh layout/style kwargs for one control (each widget gets its own Layout)."""
    return {"layout": widgets.Layout(width="auto"), "style": {"description_width": "initial"}}


# Model-type flags.
v2, v_parameterization = (
    widgets.Checkbox(value=False, description=label, **_auto_width())
    for label in ("v2", "v_parameterization")
)

project_name = widgets.Text(value="Last", description="Project Name:", **_auto_width())

# Base checkpoint: downloaded models first, then earlier DreamBooth outputs.
pretrained_model_name_or_path = widgets.Dropdown(
    options=[*pretrained_model_options, *resume_options],
    description="Pretrained Model:",
    **_auto_width(),
)
vae = widgets.Dropdown(options=vae_options, description="VAE:", **_auto_width())
# The leading "" entry lets this be left blank.
resume_path = widgets.Dropdown(options=["", *resume_options], description="Resume Path:", **_auto_width())

widget_box = widgets.VBox([
    v2, v_parameterization, project_name,
    pretrained_model_name_or_path, vae, resume_path,
])

del _auto_width
widget_box

VBox(children=(Checkbox(value=False, description='v2', layout=Layout(width='auto'), style=DescriptionStyle(des…

In [23]:
# @title ## 5.2. Dataset Config widget
# @markdown default values are designed for `one concept` training. Refer to this [guide](https://rentry.org/kohyaminiguide#b-multi-concept-training) for multi-concept training.
import ipywidgets as widgets

def _auto_width():
    """Fresh layout/style kwargs for one control (each widget gets its own Layout)."""
    return {"layout": widgets.Layout(width="auto"), "style": {"description_width": "initial"}}


dataset_repeats = widgets.IntSlider(value=10, min=1, max=100, step=1, description="Dataset Repeats:", **_auto_width())
# Class tokens for images directly in train_data_dir.
activation_word = widgets.Text(value="mksks style", description="Activation Word:", **_auto_width())
caption_extension = widgets.Dropdown(options=["none", ".txt", ".caption"], value=".caption", description="Caption Extension:", **_auto_width())
token_to_captions = widgets.Checkbox(value=False, description="Token to Captions", **_auto_width())
resolution = widgets.IntSlider(value=512, min=512, max=1024, step=128, description="Resolution:", **_auto_width())
flip_aug = widgets.Checkbox(value=False, description="Flip Augmentation", **_auto_width())
keep_tokens = widgets.IntSlider(value=0, min=0, max=100, step=1, description="Keep Tokens:", **_auto_width())

widget_box = widgets.VBox([
    dataset_repeats, activation_word, caption_extension, token_to_captions,
    resolution, flip_aug, keep_tokens,
])

del _auto_width
widget_box

VBox(children=(IntSlider(value=10, description='Dataset Repeats:', layout=Layout(width='auto'), min=1, style=S…

In [24]:
# @title ## 5.2. Dataset Config
import toml
import glob

# class_token = last word of the activation word ("mksks style" -> "style"); it is what
# get_class_token_from_folder_name returns for train_data_dir itself. It is always
# defined: a single-word activation word used to leave it unset, which caused a NameError
# once captions in train_data_dir were processed.
words = activation_word.value.replace(',', ' ').split()
class_token = words[-1] if words else ""


def read_file(filename):
    """Return the full text of ``filename``, decoded as UTF-8.

    The encoding is explicit because captions/tags may be non-ASCII and the
    platform default depends on the container's locale.
    """
    with open(filename, "r", encoding="utf-8") as f:
        return f.read()


def write_file(filename, contents):
    """Overwrite ``filename`` with ``contents``, encoded as UTF-8 (locale-independent)."""
    with open(filename, "w", encoding="utf-8") as f:
        f.write(contents)


def get_supported_images(folder):
    """Return paths of the image files directly inside ``folder`` (not recursive).

    ``folder`` is glob-escaped, so names containing ``[``/``]``/``*`` work, and the
    extension match ignores case (``IMG.JPG`` counts). A missing folder yields ``[]``.
    """
    supported_extensions = (".png", ".jpg", ".jpeg", ".webp", ".bmp")
    candidates = glob.glob(os.path.join(glob.escape(folder), "*"))
    return [path for path in candidates if path.lower().endswith(supported_extensions)]


def get_subfolders_with_supported_images(folder):
    """Return the immediate subdirectories of ``folder`` that directly contain at least one supported image."""
    matches = []
    for name in os.listdir(folder):
        candidate = os.path.join(folder, name)
        if os.path.isdir(candidate) and get_supported_images(candidate):
            matches.append(candidate)
    return matches


def process_tags(filename, custom_tag, remove_tag):
    """Prepend (or, with ``remove_tag``, strip) the comma-separated ``custom_tag`` entries in a caption file.

    Underscores in ``custom_tag`` become spaces. When adding, tags already present
    are not duplicated, and new ones go to the front in the order given. Empty
    entries are dropped, so an empty file does not end up with a trailing ", ".
    """
    with open(filename, "r", encoding="utf-8") as f:
        contents = f.read()
    tags = [tag.strip() for tag in contents.split(",") if tag.strip()]
    wanted = [tag.strip().replace("_", " ") for tag in custom_tag.split(",") if tag.strip()]

    if remove_tag:
        tags = [tag for tag in tags if tag not in wanted]
    else:
        # dict.fromkeys de-duplicates while keeping the caller's order.
        missing = [tag for tag in dict.fromkeys(wanted) if tag not in tags]
        tags = missing + tags

    with open(filename, "w", encoding="utf-8") as f:
        f.write(", ".join(tags))


def process_folder_recursively(folder):
    """Add or strip each folder's class token in every caption file under ``folder``.

    The tag for a caption file comes from get_class_token_from_folder_name for its
    directory. If that is empty, the activation word is used, but only when
    train_data_dir itself holds images. The tag is added when "Token to Captions" is
    ticked and removed otherwise. Files with no resulting tag are left untouched.
    """
    # Loop-invariant: computed once instead of re-globbing train_data_dir for every caption file.
    fallback_tag = activation_word.value if get_supported_images(train_data_dir) else ""
    remove = not token_to_captions.value

    for root, _, files in os.walk(folder):
        for file in files:
            if not file.endswith(caption_extension.value):
                continue
            tag = get_class_token_from_folder_name(root, folder) or fallback_tag
            if tag:
                process_tags(os.path.join(root, file), tag, remove_tag=remove)


def get_num_repeats(folder):
    """Parse the ``<repeats>_`` prefix of a folder name (``10_sks dog`` -> 10).

    Returns 1 when the name has no underscore or the prefix is not an integer.
    """
    prefix, separator, _ = os.path.basename(folder).partition('_')
    if not separator:
        return 1
    try:
        return int(prefix)
    except ValueError:
        return 1


def get_class_token_from_folder_name(folder, parent_folder):
    """Derive the class token(s) for ``folder`` from its name.

    - ``folder`` is ``parent_folder`` (the dataset root): return the global
      ``class_token`` if one is defined, else "".
    - folder named "<repeats>_<concept>": return "<concept>", i.e. everything
      after the first underscore.
    - any other name: return "".
    """
    if os.path.normpath(folder) == os.path.normpath(parent_folder):
        # class_token is optional in this notebook; don't raise NameError without it.
        return globals().get("class_token", "")

    _, underscore, concept = os.path.basename(folder).partition("_")
    return concept if underscore else ""
        
train_supported_images = get_supported_images(train_data_dir)
train_subfolders = get_subfolders_with_supported_images(train_data_dir)
reg_supported_images = get_supported_images(reg_data_dir)
reg_subfolders = get_subfolders_with_supported_images(reg_data_dir)

# Adjust keep_tokens BEFORE the config dict below copies its value; changing the
# widget afterwards never reached dataset_config.toml.
if token_to_captions.value and keep_tokens.value < 2:
    keep_tokens.value = 1

subsets = []

config = {
    "general": {
        "enable_bucket": True,
        "caption_extension": caption_extension.value,
        "shuffle_caption": True,
        "keep_tokens": keep_tokens.value,
        "bucket_reso_steps": 64,
        "bucket_no_upscale": False,
    },
    "datasets": [
        {
            "resolution": resolution.value,
            "min_bucket_reso": 320 if resolution.value > 640 else 256,
            "max_bucket_reso": 1280 if resolution.value > 640 else 1024,
            "caption_dropout_rate": 0,
            "caption_tag_dropout_rate": 0,
            "caption_dropout_every_n_epochs": 0,
            "flip_aug": flip_aug.value,
            "color_aug": False,
            "face_crop_aug_range": None,
            "subsets": subsets,
        }
    ],
}

# Add (or remove) the class / activation token inside the caption files.
if caption_extension.value != "none":
    process_folder_recursively(train_data_dir)

# Images placed directly in train_data_dir form one subset driven by the widgets.
if train_supported_images:
    subsets.append({
        "image_dir": train_data_dir,
        "class_tokens": activation_word.value,
        "num_repeats": dataset_repeats.value,
    })

# Each "<repeats>_<class tokens>" sub-folder becomes its own subset.
# train_data_dir / reg_data_dir are plain path strings, not widgets (no .value).
for subfolder in train_subfolders:
    extracted_class_token = get_class_token_from_folder_name(subfolder, train_data_dir)
    subsets.append({
        "image_dir": subfolder,
        "class_tokens": extracted_class_token if extracted_class_token else None,
        "num_repeats": get_num_repeats(subfolder),
    })

if reg_supported_images:
    subsets.append({
        "is_reg": True,
        "image_dir": reg_data_dir,
        "class_tokens": class_token if 'class_token' in globals() else None,
        "num_repeats": 1,
    })

for subfolder in reg_subfolders:
    extracted_class_token = get_class_token_from_folder_name(subfolder, reg_data_dir)
    subsets.append({
        "is_reg": True,
        "image_dir": subfolder,
        "class_tokens": extracted_class_token if extracted_class_token else None,
        # Parsed from this folder's own name, not left over from the train loop.
        "num_repeats": get_num_repeats(subfolder),
    })

# Subsets with no caption files for the configured extension fall back to the
# activation word as their class tokens.
for subset in subsets:
    caption_pattern = os.path.join(glob.escape(subset["image_dir"]), f"*{caption_extension.value}")
    if not glob.glob(caption_pattern):
        subset["class_tokens"] = activation_word.value

dataset_config = os.path.join(dreambooth_config_dir, "dataset_config.toml")

# Empty strings become None so toml.dumps leaves those keys out.
for key in config:
    if isinstance(config[key], dict):
        for sub_key in config[key]:
            if config[key][sub_key] == "":
                config[key][sub_key] = None
    elif config[key] == "":
        config[key] = None

config_str = toml.dumps(config)

with open(dataset_config, "w") as f:
    f.write(config_str)

print(config_str)

[[datasets]]
resolution = 512
min_bucket_reso = 256
max_bucket_reso = 1024
caption_dropout_rate = 0
caption_tag_dropout_rate = 0
caption_dropout_every_n_epochs = 0
flip_aug = false
color_aug = false
[[datasets.subsets]]
image_dir = "/home/studio-lab-user/sagemaker-studiolab-notebooks/train_data"
class_tokens = "mksks style"
num_repeats = 10


[general]
enable_bucket = true
caption_extension = ".caption"
shuffle_caption = true
keep_tokens = 0
bucket_reso_steps = 64
bucket_no_upscale = false



In [25]:
# @title ## 5.3. Optimizer Config

import ipywidgets as widgets

# Sentinels: min_snr_gamma == -1 and stop_train_text_encoder <= 0 are written
# to the training config as None (i.e. disabled).
min_snr_gamma = widgets.FloatText(
    value=-1,
    description='min_snr_gamma:',
    step=0.01,
    style={"description_width": "initial"}
)

optimizer_type = widgets.Dropdown(
    options=["AdamW", "AdamW8bit", "Lion", "SGDNesterov", "SGDNesterov8bit", "DAdaptation", "AdaFactor"],
    value="AdamW8bit",
    description='optimizer_type:',
    style={"description_width": "initial"}
)

optimizer_args = widgets.Text(
    value="",
    description='optimizer_args:',
    style={"description_width": "initial"}
)

learning_rate = widgets.FloatText(
    value=2e-6,
    description='learning_rate:',
    step=1e-7,
    style={"description_width": "initial"}
)

stop_train_text_encoder = widgets.IntText(
    value=-1,
    description='stop_train_text_encoder:',
    style={"description_width": "initial"}
)

lr_scheduler = widgets.Dropdown(
    options=["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup", "adafactor"],
    value="constant",
    description='lr_scheduler:',
    style={"description_width": "initial"}
)

lr_warmup_steps = widgets.IntText(
    value=0,
    description='lr_warmup_steps:',
    style={"description_width": "initial"}
)

# Only used with "cosine_with_restarts". Default 1: with 0 cycles the diffusers
# hard-restarts schedule computes cos(pi * ((0 * progress) % 1)) == 1, so the LR
# would never decay.
lr_scheduler_num_cycles = widgets.IntText(
    value=1,
    description='lr_scheduler_num_cycles:',
    style={"description_width": "initial"}
)

# Only used with "polynomial"; the power is a float (1.0 = linear decay).
lr_scheduler_power = widgets.FloatText(
    value=1.0,
    description='lr_scheduler_power:',
    step=0.1,
    style={"description_width": "initial"}
)

widget_list = [
    min_snr_gamma,
    optimizer_type,
    optimizer_args,
    learning_rate,
    stop_train_text_encoder,
    lr_scheduler,
    lr_warmup_steps,
    lr_scheduler_num_cycles,
    lr_scheduler_power,
]

widget_box = widgets.VBox(widget_list)
widget_box

VBox(children=(FloatText(value=-1.0, description='min_snr_gamma:', step=0.01, style=DescriptionStyle(descripti…

In [27]:
# @title ## 5.4. Training Config widget
import ipywidgets as widgets

enable_sample_prompt = widgets.Checkbox(
    value=True,
    description='enable_sample_prompt:',
)

samples_per_prompt = widgets.IntText(
    value=1,
    description='samples per prompt:',
    style={"description_width": "initial"}
)

sampler = widgets.Dropdown(
    options=["ddim", "pndm", "lms", "euler", "euler_a", "heun", "dpm_2", "dpm_2_a", "dpmsolver", "dpmsolver++", "dpmsingle", "k_lms", "k_euler", "k_euler_a", "k_dpm_2", "k_dpm_2_a"],
    value="ddim",
    description='sampler:',
    style={"description_width": "initial"}
)

noise_offset = widgets.FloatText(
    value=0.0,
    description='noise_offset:',
    step=0.01,
    style={"description_width": "initial"}
)

max_train_steps = widgets.IntText(
    value=2500,
    description='max_train_steps:',
    style={"description_width": "initial"}
)

vae_batch_size = widgets.IntText(
    value=1,
    description='vae_batch_size:',
    style={"description_width": "initial"}
)

train_batch_size = widgets.IntText(
    value=4,
    description='train_batch_size:',
    style={"description_width": "initial"}
)

mixed_precision = widgets.Dropdown(
    options=["no", "fp16", "bf16"],
    value="fp16",
    description='mixed_precision:',
    disabled=False,
    style={"description_width": "initial"}
)

save_state = widgets.Checkbox(
    value=False,
    description='save_state:',
    style={"description_width": "initial"}
)

save_precision = widgets.Dropdown(
    options=["float", "fp16", "bf16"],
    value="fp16",
    description='save_precision:',
    style={"description_width": "initial"},
    disabled=False,
)

save_n_epoch_ratio = widgets.FloatText(
    value=1,
    description='save_n_epoch_ratio:',
    step=0.1,
    style={"description_width": "initial"}
)

save_model_as = widgets.Dropdown(
    options=["ckpt", "safetensors", "diffusers", "diffusers_safetensors"],
    value="ckpt",
    description='save_model_as:',
    style={"description_width": "initial"},
    disabled=False,
)

max_token_length = widgets.IntText(
    value=225,
    description='max_token_length:',
    style={"description_width": "initial"}
)

clip_skip = widgets.IntText(
    value=2,
    description='clip_skip:',
    style={"description_width": "initial"}
)

gradient_checkpointing = widgets.Checkbox(
    value=False,
    description='gradient_checkpointing:',
    style={"description_width": "initial"}
)

gradient_accumulation_steps = widgets.IntText(
    value=1,
    description='gradient_accumulation_steps:',
    style={"description_width": "initial"}
)

# -1 (or any value <= 0) means "no fixed seed" in the training config.
seed = widgets.IntText(
    value=-1,
    description='seed:',
    style={"description_width": "initial"}
)

huggingface_repo_id = widgets.Text(
    value="xxthekingxx/myHendricks2",
    description='huggingface repo id:',
    style={"description_width": "initial"}
)

huggingface_path_in_repo = widgets.Text(
    value="mymodel",
    description='path in repo:',
    style={"description_width": "initial"}
)

# Password widget: the secret token is masked on screen instead of shown in clear text.
huggingface_token = widgets.Password(
    value="",
    description='huggingface write token:',
    style={"description_width": "initial"}
)

save_checkpoint_local = widgets.Checkbox(
    value=False,
    description='save epochs local?:',
    style={"description_width": "initial"}
)

async_upload = widgets.Checkbox(
    value=True,
    description='upload asynchronously (requires save local):',
    style={"description_width": "initial"}
)

prior_loss_weight = widgets.FloatText(
    value=1.0,
    description='prior_loss_weight:',
    step=0.1,
    style={"description_width": "initial"}
)

widget_list = [
    enable_sample_prompt,
    samples_per_prompt,
    sampler,
    noise_offset,
    max_train_steps,
    vae_batch_size,
    train_batch_size,
    mixed_precision,
    save_checkpoint_local,
    huggingface_repo_id,
    huggingface_path_in_repo,
    huggingface_token,
    async_upload,
    save_state,
    save_precision,
    save_n_epoch_ratio,
    save_model_as,
    max_token_length,
    clip_skip,
    gradient_checkpointing,
    gradient_accumulation_steps,
    seed,
    prior_loss_weight,
]

widget_box = widgets.VBox(widget_list)
widget_box


VBox(children=(Checkbox(value=True, description='enable_sample_prompt:'), IntText(value=1, description='sample…

In [28]:
import ast
import os

import toml

os.chdir(repo_dir)

# Sample prompt written to sample_prompt.txt; "--n" introduces the negative prompt.
# The trailing backslashes are line continuations inside the literal, so the
# whole prompt ends up on a single line.
sample_str = """
  xhendx in a swimsuit \
  --n lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry \
  --w 512 \
  --h 768 \
  --l 7 \
  --s 35    
"""

config = {
    "model_arguments": {
        "v2": v2.value,
        "v_parameterization": v_parameterization.value if v2.value and v_parameterization.value else False,
        "pretrained_model_name_or_path": pretrained_model_name_or_path.value,
        "vae": vae.value,
    },
    "optimizer_arguments": {
        "min_snr_gamma": min_snr_gamma.value if not min_snr_gamma.value == -1 else None,
        "optimizer_type": optimizer_type.value,
        "learning_rate": learning_rate.value,
        "max_grad_norm": 1.0,
        "stop_train_text_encoder": stop_train_text_encoder.value if stop_train_text_encoder.value > 0 else None,
        # literal_eval only accepts Python literals (e.g. a list of strings);
        # unlike eval it cannot execute arbitrary code typed into the widget.
        "optimizer_args": ast.literal_eval(optimizer_args.value) if optimizer_args.value else None,
        "lr_scheduler": lr_scheduler.value,
        "lr_warmup_steps": lr_warmup_steps.value,
        "lr_scheduler_num_cycles": lr_scheduler_num_cycles.value if lr_scheduler.value == "cosine_with_restarts" else None,
        "lr_scheduler_power": lr_scheduler_power.value if lr_scheduler.value == "polynomial" else None,
    },
    "dataset_arguments": {
        "cache_latents": True,
        "debug_dataset": False,
        "vae_batch_size": vae_batch_size.value,
    },
    "huggingface_arguments": {
        "repo_type": None,
        "huggingface_path_in_repo": huggingface_path_in_repo.value,
        "huggingface_repo_visibility": None,  # private if not "public"
        "huggingface_token": huggingface_token.value,
        "async_upload": async_upload.value,
        "huggingface_repo_id": huggingface_repo_id.value,
    },
    "training_arguments": {
        "output_dir": dreambooth_output_dir,
        "output_name": project_name.value,
        "save_precision": save_precision.value,
        "save_every_n_epochs": None,
        "save_n_epoch_ratio": save_n_epoch_ratio.value,
        "save_checkpoint_local": save_checkpoint_local.value,
        "save_last_n_epochs": None,
        "save_state": save_state.value,
        "save_last_n_epochs_state": None,
        "resume": resume_path.value,
        "train_batch_size": train_batch_size.value,
        # Read from the widget instead of a hardcoded 225.
        "max_token_length": max_token_length.value,
        "mem_eff_attn": False,
        "xformers": True,
        "max_train_steps": max_train_steps.value,
        "max_data_loader_n_workers": 8,
        "persistent_data_loader_workers": True,
        "seed": seed.value if seed.value > 0 else None,
        "gradient_checkpointing": gradient_checkpointing.value,
        "gradient_accumulation_steps": gradient_accumulation_steps.value,
        "mixed_precision": mixed_precision.value,
        "clip_skip": clip_skip.value if not v2.value else None,
        "logging_dir": dreambooth_logging_dir,
        "log_prefix": project_name.value,
        "noise_offset": noise_offset.value if noise_offset.value > 0 else None,
    },
    "sample_prompt_arguments": {
        "samples_per_prompt": samples_per_prompt.value,
        # 999999 effectively disables sampling when the checkbox is off.
        "sample_every_n_steps": 100 if enable_sample_prompt.value else 999999,
        "sample_every_n_epochs": None,
        "sample_sampler": sampler.value,
    },
    "dreambooth_arguments": {
        # Read from the widget instead of a hardcoded 1.0.
        "prior_loss_weight": prior_loss_weight.value,
    },
    "saving_arguments": {
        "save_model_as": save_model_as.value
    },
}


config_path = os.path.join(dreambooth_config_dir, "config_file.toml")
prompt_path = os.path.join(dreambooth_config_dir, "sample_prompt.txt")

# Empty strings become None so toml.dumps leaves those keys out.
for key in config:
    if isinstance(config[key], dict):
        for sub_key in config[key]:
            if config[key][sub_key] == "":
                config[key][sub_key] = None
    elif config[key] == "":
        config[key] = None

config_str = toml.dumps(config)

for path, text in ((config_path, config_str), (prompt_path, sample_str)):
    with open(path, "w") as f:
        f.write(text)

print(config_str)


[model_arguments]
v2 = false
v_parameterization = false
pretrained_model_name_or_path = "/home/studio-lab-user/sagemaker-studiolab-notebooks/pretrained_model/realisticVisionV20_v20NoVAE.safetensors"
vae = "/home/studio-lab-user/sagemaker-studiolab-notebooks/vae/vae-ft-mse-840000-ema-pruned.ckpt"

[optimizer_arguments]
optimizer_type = "AdamW8bit"
learning_rate = 4e-6
max_grad_norm = 1.0
lr_scheduler = "cosine_with_restarts"
lr_warmup_steps = 0
lr_scheduler_num_cycles = 0

[dataset_arguments]
cache_latents = true
debug_dataset = false
vae_batch_size = 1

[huggingface_arguments]
huggingface_path_in_repo = "mymodel"
async_upload = false
huggingface_repo_id = "xxthekingxx/myHendricks4"

[training_arguments]
output_dir = "/home/studio-lab-user/sagemaker-studiolab-notebooks/dreambooth/output"
output_name = "Hendricks4"
save_precision = "fp16"
save_n_epoch_ratio = 1.0
save_checkpoint_local = false
save_state = false
train_batch_size = 4
max_token_length = 225
mem_eff_attn = false
xformers = t

In [29]:
#@title ## 5.5. Start Training

#@markdown Check your config here if you want to edit something (all under `root_dir`):
#@markdown - `sample_prompt` : dreambooth/config/sample_prompt.txt
#@markdown - `config_file` : dreambooth/config/config_file.toml
#@markdown - `dataset_config` : dreambooth/config/dataset_config.toml

#@markdown Generated samples can be seen here: dreambooth/output/sample

#@markdown You can import config from another session if you want.
sample_prompt = os.path.join(dreambooth_config_dir, "sample_prompt.txt") #@param {type:'string'}
config_file = os.path.join(dreambooth_config_dir, "config_file.toml") #@param {type:'string'}
dataset_config = os.path.join(dreambooth_config_dir, "dataset_config.toml") #@param {type:'string'}

# Arguments for `accelerate launch` itself.
accelerate_conf = {
    "config_file": accelerate_config,
    "num_cpu_threads_per_process": 1,
}

# Arguments for train_db.py: everything else is read from the two TOML files.
train_conf = {
    "sample_prompts": sample_prompt,
    "dataset_config": dataset_config,
    "config_file": config_file,
}

def train(config):
    """Render ``config`` as a shell-safe command-line argument string.

    Despite the name, this does not start training; it only formats arguments:

    - keys starting with "_" become positional arguments (value only);
    - str values become ``--key=value``;
    - True becomes a bare ``--key`` flag; False is omitted;
    - int / float values become ``--key=value``;
    - anything else (e.g. None) is omitted.

    Every value is escaped with ``shlex.quote`` so paths or prompts containing
    spaces, quotes or ``$`` reach the program unchanged through the shell.
    Each argument is followed by a single space (empty config -> "").
    """
    import shlex

    parts = []
    for key, value in config.items():
        if key.startswith("_"):
            parts.append(shlex.quote(str(value)))
        elif isinstance(value, str):
            parts.append(f"--{key}={shlex.quote(value)}")
        elif isinstance(value, bool):
            if value:
                parts.append(f"--{key}")
        elif isinstance(value, (int, float)):
            parts.append(f"--{key}={value}")

    return "".join(f"{part} " for part in parts)

# Render both argument dicts and launch train_db.py through `accelerate launch`.
accelerate_args = train(accelerate_conf)
train_args = train(train_conf)
final_args = f"accelerate launch {accelerate_args} train_db.py {train_args}"
# train_db.py is referenced by a relative path, so run from the repository root.
os.chdir(repo_dir)
# IPython shell escape: interpolates {final_args} and runs it in a subshell.
!{final_args}

Loading settings from /home/studio-lab-user/sagemaker-studiolab-notebooks/dreambooth/config/config_file.toml...
/home/studio-lab-user/sagemaker-studiolab-notebooks/dreambooth/config/config_file
prepare tokenizer
update token length: 225
Load dataset config from /home/studio-lab-user/sagemaker-studiolab-notebooks/dreambooth/config/dataset_config.toml
prepare images.
found directory /home/studio-lab-user/sagemaker-studiolab-notebooks/train_data contains 24 image files
240 train images with repeating.
0 reg images.
no regularization images / 正則化画像が見つかりませんでした
[Dataset 0]
  batch_size: 4
  resolution: (512, 512)
  enable_bucket: True
  min_bucket_reso: 256
  max_bucket_reso: 1024
  bucket_reso_steps: 64
  bucket_no_upscale: False

  [Subset 0 of Dataset 0]
    image_dir: "/home/studio-lab-user/sagemaker-studiolab-notebooks/train_data"
    image_count: 24
    num_repeats: 10
    shuffle_caption: True
    keep_tokens: 0
    caption_dropout_rate: 0
    caption_dropout_every_n_epoches: 0
    ca

In [6]:
# @title ## 6.2. Inference
v2 = False  # @param {type:"boolean"}
v_parameterization = False  # @param {type:"boolean"}
prompt = "RAW photo, xhendx in a bikini, high detailed skin, 8k uhd, dslr, soft lighting, high quality, film grain, Fujifilm XT3"  # @param {type: "string"}
negative = "(weird eyes, disfigured eyes, looking different direction:1.3), cgi, 3d, render, mutated hands, mutated fingers, deformed, distorted, disfigured, poorly drawn, bad anatomy, bad quality, worst quality"  # @param {type: "string"}
model = os.path.join(dreambooth_output_dir,'Hen1-fp16-pruned.ckpt')  # @param {type: "string"}
vae = os.path.join(vae_dir,'vae-ft-mse-840000-ema-pruned.ckpt')  # @param {type: "string"}
outdir = inference_dir  # @param {type: "string"}
scale = 7  # @param {type: "slider", min: 1, max: 40}
sampler = "euler_a"  # @param ["ddim", "pndm", "lms", "euler", "euler_a", "heun", "dpm_2", "dpm_2_a", "dpmsolver","dpmsolver++", "dpmsingle", "k_lms", "k_euler", "k_euler_a", "k_dpm_2", "k_dpm_2_a"]
steps = 35  # @param {type: "slider", min: 1, max: 100}
precision = "fp16"  # @param ["fp16", "bf16"] {allow-input: false}
width = 512  # @param {type: "integer"}
height = 768  # @param {type: "integer"}
images_per_prompt = 12  # @param {type: "integer"}
batch_size = 1  # @param {type: "integer"}
clip_skip = 1  # @param {type: "slider", min: 1, max: 40}
seed = -1  # @param {type: "integer"}

final_prompt = f"{prompt} --n {negative}"

config = {
    "v2": v2,
    "v_parameterization": v_parameterization,
    "ckpt": model,
    "outdir": outdir,
    "xformers": True,
    "vae": vae if vae else None,
    "fp16": True,
    "W": width,
    "H": height,
    "seed": seed if seed > 0 else None,
    "scale": scale,
    "sampler": sampler,
    "steps": steps,
    "max_embeddings_multiples": 3,
    "batch_size": batch_size,
    "images_per_prompt": images_per_prompt,
    "clip_skip": clip_skip if not v2 else None,
    "prompt": final_prompt,
}

args = ""
for k, v in config.items():
    if isinstance(v, str):
        args += f'--{k}="{v}" '
    if isinstance(v, bool) and v:
        args += f"--{k} "
    if isinstance(v, float) and not isinstance(v, bool):
        args += f"--{k}={v} "
    if isinstance(v, int) and not isinstance(v, bool):
        args += f"--{k}={v} "

final_args = f"python gen_img_diffusers.py {args}"

os.chdir(repo_dir)
!{final_args}

2023-06-01 16:26:03.325452: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
load StableDiffusion checkpoint
Traceback (most recent call last):
  File "/home/studio-lab-user/sagemaker-studiolab-notebooks/kohya-trainer/gen_img_diffusers.py", line 3262, in <module>
    main(args)
  File "/home/studio-lab-user/sagemaker-studiolab-notebooks/kohya-trainer/gen_img_diffusers.py", line 2110, in main
    text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.ckpt)
  File "/home/studio-lab-user/sagemaker-studiolab-notebooks/kohya-trainer/library/model_util.py", line 850, in load_models_from_stable_diffusion_checkpoint
    _, state_dict = load_checkpoint_with_text_encoder_conversion(ckpt_path, device)
  File "/home/st

In [3]:
import os
#@title ## 7.2. Model Pruner

os.chdir(tools_dir)

if not os.path.exists('prune.py'):
    !wget https://raw.githubusercontent.com/lopho/stable-diffusion-prune/main/prune.py

#@markdown Convert to Float16
fp16 = True #@param {'type':'boolean'}
#@markdown Use EMA for weights
ema = False #@param {'type':'boolean'}
#@markdown Strip CLIP weights
no_clip = False #@param {'type':'boolean'}
#@markdown Strip VAE weights
no_vae = False #@param {'type':'boolean'}
#@markdown Strip depth model weights
no_depth = False #@param {'type':'boolean'}
#@markdown Strip UNet weights
no_unet = False #@param {'type':'boolean'}

model_path = "/home/studio-lab-user/sagemaker-studiolab-notebooks/dreambooth/output/Hen1.ckpt" #@param {'type' : 'string'}

config = {
    "fp16": fp16,
    "ema": ema,
    "no_clip": no_clip,
    "no_vae": no_vae,
    "no_depth": no_depth,
    "no_unet": no_unet,
}

suffixes = {
    "fp16": "-fp16",
    "ema": "-ema",
    "no_clip": "-no-clip",
    "no_vae": "-no-vae",
    "no_depth": "-no-depth",
    "no_unet": "-no-unet",
}

print(f"Loading model from {model_path}")

dir_name = os.path.dirname(model_path)
base_name = os.path.basename(model_path)
output_name = base_name.split('.')[0]

for option, suffix in suffixes.items():
    if config[option]:
        print(f"Applying option {option}")
        output_name += suffix
        
output_name += '-pruned'
output_path = os.path.join(dir_name, output_name + ('.ckpt' if model_path.endswith(".ckpt") else ".safetensors"))

args = ""
for k, v in config.items():
    if k.startswith("_"):
        args += f'"{v}" '
    elif isinstance(v, str):
        args += f'--{k}="{v}" '
    elif isinstance(v, bool) and v:
        args += f"--{k} "
    elif isinstance(v, float) and not isinstance(v, bool):
        args += f"--{k}={v} "
    elif isinstance(v, int) and not isinstance(v, bool):
        args += f"--{k}={v} "

final_args = f"python3 prune.py {model_path} {output_path} {args}"
!{final_args}

print(f"Saving pruned model to {output_path}")

--2023-06-01 15:52:46--  https://raw.githubusercontent.com/lopho/stable-diffusion-prune/main/prune.py
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 4374 (4.3K) [text/plain]
Saving to: 'prune.py'


2023-06-01 15:52:46 (43.5 MB/s) - 'prune.py' saved [4374/4374]

Loading model from /home/studio-lab-user/sagemaker-studiolab-notebooks/dreambooth/output/Hen1.ckpt
Applying option fp16
Saving pruned model to /home/studio-lab-user/sagemaker-studiolab-notebooks/dreambooth/output/Hen1-fp16-pruned.ckpt


In [4]:
# @title ## 7.1. Upload Config
import getpass
import os

from huggingface_hub import login
from huggingface_hub import HfApi
from huggingface_hub.utils import validate_repo_id, HfHubHTTPError


# @markdown Login to Huggingface Hub
# @markdown > Get **your** huggingface `WRITE` token [here](https://huggingface.co/settings/tokens)
# @markdown Leave empty to use the `HF_TOKEN` environment variable, or to be prompted (input hidden).
write_token = ""  # @param {type:"string"}
# @markdown Fill this if you want to upload to your organization, or just leave it empty.
orgs_name = ""  # @param{type:"string"}
# @markdown If your model/dataset repo does not exist, it will automatically create it.
model_name = "myHendricks2"  # @param{type:"string"}
dataset_name = ""  # @param{type:"string"}
make_private = True  # @param{type:"boolean"}

# Never hardcode the token in the notebook: it is saved and shared along with the file.
if not write_token:
    write_token = os.environ.get("HF_TOKEN") or getpass.getpass("Hugging Face WRITE token: ")


def authenticate(write_token):
    """Log in with ``write_token``; return ``(whoami_info, HfApi())``."""
    login(write_token, add_to_git_credential=True)
    api = HfApi()
    return api.whoami(write_token), api


def create_repo(api, user, orgs_name, repo_name, repo_type, make_private=False):
    """Create ``<owner>/<repo_name>`` if needed and record it in a global.

    The owner is ``orgs_name`` when given, else the logged-in user's name.
    Sets the global ``model_repo`` (repo_type "model") or ``datasets_repo``.

    Raises:
        HfHubHTTPError: for any Hub error other than "already exists" (409),
            e.g. an invalid token or missing write permission.
    """
    global model_repo
    global datasets_repo

    owner = user["name"] if orgs_name == "" else orgs_name
    repo_id = owner + "/" + repo_name.strip()
    validate_repo_id(repo_id)
    label = repo_type.capitalize()

    try:
        api.create_repo(repo_id=repo_id, repo_type=repo_type, private=make_private)
        print(f"{label} repo '{repo_id}' didn't exist, creating repo")
    except HfHubHTTPError as err:
        # 409 Conflict means the repo already exists; anything else is a real failure.
        if getattr(err.response, "status_code", None) != 409:
            raise
        print(f"{label} repo '{repo_id}' exists, skipping create repo")

    if repo_type == "model":
        model_repo = repo_id
        print(f"{label} repo '{repo_id}' link: https://huggingface.co/{repo_id}\n")
    else:
        datasets_repo = repo_id
        print(f"{label} repo '{repo_id}' link: https://huggingface.co/datasets/{repo_id}\n")


user, api = authenticate(write_token)

if model_name:
    create_repo(api, user, orgs_name, model_name, "model", make_private)
if dataset_name:
    create_repo(api, user, orgs_name, dataset_name, "dataset", make_private)


Token is valid.
[1m[31mCannot authenticate through git-credential as no helper is defined on your machine.
You might have to re-authenticate when pushing to the Hugging Face Hub.
Run the following command in your terminal in case you want to set the 'store' credential helper as default.

git config --global credential.helper store

Read https://git-scm.com/book/en/v2/Git-Tools-Credential-Storage for more details.[0m
Token has not been saved to git credential helper.
Your token has been saved to /home/studio-lab-user/.cache/huggingface/token
Login successful
Model repo 'xxthekingxx/myHendricks2' exists, skipping create repo
Model repo 'xxthekingxx/myHendricks2' link: https://huggingface.co/xxthekingxx/myHendricks2



In [5]:
# @title ### 8.2.1. Upload Model
import os
from pathlib import Path

from huggingface_hub import HfApi

api = HfApi()

# @markdown This will be uploaded to model repo
model_path = os.path.join(dreambooth_output_dir,"Hen1.ckpt")  # @param {type :"string"}
path_in_repo = ""  # @param {type :"string"}
# @markdown Now you can save your config file for future use
config_path = dreambooth_config_dir  # @param {type :"string"}
# @markdown Other Information
commit_message = "uploading model"  # @param {type :"string"}

if not commit_message:
    commit_message = "feat: upload " + project_name.value + " checkpoint"

# Patterns are matched against each file's path relative to the uploaded folder,
# so a bare ".ipynb_checkpoints" would never match files inside that directory.
IGNORE_PATTERNS = [".ipynb_checkpoints/*", "*/.ipynb_checkpoints/*"]


def _is_diffusers_folder(folder):
    """Return True when ``folder`` has vae/, unet/ and text_encoder/ entries."""
    return all(os.path.exists(os.path.join(folder, part)) for part in ("vae", "unet", "text_encoder"))


def upload_model(model_paths, is_folder: bool, is_config: bool):
    """Upload a file or folder to the global ``model_repo``.

    The destination name is the last path component of ``model_paths``,
    overridden by ``path_in_repo`` when set. Config uploads go to
    "<path_in_repo or project name>_config". A folder laid out as a diffusers
    model (vae/, unet/, text_encoder/) is uploaded to the repo root; any other
    folder is placed under the destination name.
    """
    trained_model = Path(model_paths).parts[-1]
    if path_in_repo:
        trained_model = path_in_repo
    if is_config:
        trained_model = f"{path_in_repo}_config" if path_in_repo else f"{project_name.value}_config"

    print(f"Uploading {trained_model} to https://huggingface.co/" + model_repo)
    print("Please wait...")

    if is_folder:
        folder_kwargs = {}
        # Inspect the folder actually being uploaded (not the global model_path),
        # so the config folder is never dumped into the repo root.
        if is_config or not _is_diffusers_folder(model_paths):
            folder_kwargs["path_in_repo"] = trained_model
        api.upload_folder(
            folder_path=model_paths,
            repo_id=model_repo,
            commit_message=commit_message,
            ignore_patterns=IGNORE_PATTERNS,
            **folder_kwargs,
        )
        print(
            "Upload success, located at https://huggingface.co/"
            + model_repo
            + "/tree/main\n"
        )
    else:
        api.upload_file(
            path_or_fileobj=model_paths,
            path_in_repo=trained_model,
            repo_id=model_repo,
            commit_message=commit_message,
        )
        print(
            "Upload success, located at https://huggingface.co/"
            + model_repo
            + "/blob/main/"
            + trained_model
            + "\n"
        )


def upload():
    """Upload ``model_path`` (single checkpoint file or folder), then ``config_path`` if set."""
    model_is_file = model_path.endswith((".ckpt", ".safetensors", ".pt"))
    upload_model(model_path, not model_is_file, False)

    if config_path:
        upload_model(config_path, True, True)


upload()

Uploading Hen1.ckpt to https://huggingface.co/xxthekingxx/myHendricks2
Please wait...


Hen1.ckpt:   0%|          | 0.00/2.13G [00:00<?, ?B/s]

Upload 1 LFS files:   0%|          | 0/1 [00:00<?, ?it/s]

Upload success, located at https://huggingface.co/xxthekingxx/myHendricks2/blob/main/Hen1.ckpt

Uploading Hen1_config to https://huggingface.co/xxthekingxx/myHendricks2
Please wait...
Upload success, located at https://huggingface.co/xxthekingxx/myHendricks2/tree/main

