<a href="https://colab.research.google.com/github/Linaqruf/kohya-trainer/blob/dev/kohya-LoRA-dreambooth.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

![visitors](https://visitor-badge.glitch.me/badge?page_id=linaqruf.lora-dreambooth)
# Kohya LoRA Dreambooth

Based on [kohya-ss/sd-scripts](https://github.com/kohya-ss/sd-scripts)<br>
Adapted to Google Colab by [Linaqruf](https://github.com/Linaqruf)<br>
You can find the latest update of this notebook [here](https://github.com/Linaqruf/kohya-trainer/blob/main/kohya-LoRA-dreambooth.ipynb)




# I. Install Kohya Trainer

In [None]:
#@title ## 1.1. Clone Kohya Trainer
#@markdown Clone Kohya Trainer from GitHub and check for updates. Use the textbox below if you want to check out another branch or an older commit. Leave it empty to stay on the HEAD of the default branch.

import os
%store -r

# Check GPU Availability
!nvidia-smi

# Define path
root_dir = "/content"
%store root_dir
repo_dir = str(root_dir)+"/kohya-trainer"
%store repo_dir
tools_dir = str(root_dir)+"/kohya-trainer/tools"
%store tools_dir 
finetune_dir = str(root_dir)+"/kohya-trainer/finetune"
%store finetune_dir
training_dir = str(root_dir)+"/dreambooth"
%store training_dir

# Define identifier
branch = "" #@param {type: "string"}
repo_url = "https://github.com/Linaqruf/kohya-trainer"

%cd {root_dir}

def clone_repo():
  # Check if the directory already exists
  if os.path.isdir(repo_dir):
    %cd {repo_dir}
    print("This folder already exists, will do a !git pull instead\n")
    if branch != "":
      !git pull origin {branch} 
    else:
      !git pull
  else:
    !git clone {repo_url}

def checkout_repo(branch_or_commit):
  %cd {repo_dir}
  !git checkout {branch_or_commit}

# Clone or update the Kohya Trainer repository
clone_repo()

# Checkout to the specified branch or commit
if branch != "":
  checkout_repo(branch)


In [None]:
#@title ## 1.2. Installing Dependencies

import os
%store -r

#@markdown This will install required Python packages

# Define variable
accelerate_config = str(repo_dir)+"/accelerate_config/config.yaml"
%store accelerate_config
install_xformers = True #@param {'type':'boolean'}

%cd {repo_dir}

def install_dependencies():
  !pip -qqqq install --upgrade gallery-dl
  !pip -qqqq install --upgrade --no-cache-dir gdown
  !apt -qqqq install liblz4-tool aria2
  !pip -qqqq install imjoy-elfinder
  !pip -qqqq install --upgrade -r requirements.txt

  if install_xformers:
    !pip -qqqq install -U -I --no-deps https://github.com/camenduru/stable-diffusion-webui-colab/releases/download/0.0.15/xformers-0.0.15.dev0+189828c.d20221207-cp38-cp38-linux_x86_64.whl
  else:
    pass

  from accelerate.utils import write_basic_config
  if not os.path.exists(accelerate_config):
    write_basic_config(save_location = accelerate_config) # Write a config file
  else:
    pass

# Install dependencies
install_dependencies()

#@markdown Since Accelerate 0.15.0, the config can no longer be entered interactively with
#@markdown `!accelerate config` in Google Colab. Instead, a `config.yaml` file is generated by
#@markdown the `write_basic_config()` function. After installation you can find the file [here](/content/kohya-trainer/accelerate_config/config.yaml)
#@markdown and edit it if you want to change the configuration.



## 1.3. Sign-in to Cloud Service

In [None]:
#@title ### 1.3.1. Login to Huggingface hub
# Log this runtime in to the Hugging Face Hub with the token entered in the form below.
from huggingface_hub import login
%store -r

#@markdown 1. Of course, you need a Huggingface account first.
#@markdown 2. To create a huggingface token, go to [this link](https://huggingface.co/settings/tokens), then `create new token` or copy available token with the `Write` role.

# The default value is only a placeholder; replace it with a real token before running.
write_token = "your-write-token-here" #@param {type:"string"}
# add_to_git_credential=True: presumably also hands the token to git's credential helper - see huggingface_hub docs
login(write_token, add_to_git_credential=True)

# NOTE(review): %store persists the token so other cells can restore it with `%store -r`;
# keep that in mind before sharing the runtime or its stored state.
%store write_token


In [None]:
#@title ### 1.3.2. Mount Drive
# Optionally mount Google Drive at /content/drive (needed later when output goes to Drive).
from google.colab import drive

mount_drive = True #@param {type: "boolean"}

if mount_drive:
  # Mount point matches the "/content/drive/..." paths used by the training cells
  drive.mount('/content/drive')

In [None]:
#@title ## 1.4. Open Special `File Explorer` for Colab
#@markdown This keeps working in real time, even while other cells are running
# Start an imjoy-elfinder file browser for /content in a background thread and expose it in Colab.
%store -r

import threading
from google.colab import output
from imjoy_elfinder.app import main

# start imjoy-elfinder server
# (runs in its own thread so this cell returns while the server keeps serving on port 8765)
thread = threading.Thread(target=main, args=[["--root-dir=/content", "--port=8765"]])
thread.start()

open_in_new_tab = True #@param {type:"boolean"}

if open_in_new_tab:
  # open imjoy-elfinder in a new tab
  output.serve_kernel_port_as_window(8765)
else:
  # view the file explorer inline, embedded as an iframe in this cell's output
  output.serve_kernel_port_as_iframe(8765, height='500')


# II. Pretrained Model Selection

In [None]:
#@title ## 2.1. Download Available Model 
import os
%store -r

%cd {root_dir}

installModels = []
installv2Models = []

#@markdown ### Available Model
#@markdown Select one of available model to download:

#@markdown ### SD1.x model
modelUrl = ["", \
            "https://huggingface.co/Linaqruf/personal-backup/resolve/main/models/animefull-final-pruned.ckpt", \
            "https://huggingface.co/cag/anything-v3-1/resolve/main/anything-v3-1.safetensors", \
            "https://huggingface.co/andite/anything-v4.0/resolve/main/anything-v4.5-pruned.ckpt", \
            "https://huggingface.co/Rasgeath/self_made_sauce/resolve/main/Kani-anime-pruned.ckpt", \
            "https://huggingface.co/WarriorMama777/OrangeMixs/resolve/main/Models/AbyssOrangeMix2/AbyssOrangeMix2_nsfw.safetensors", \
            "https://huggingface.co/gsdf/Counterfeit-V2.0/resolve/main/Counterfeit-V2.0fp16.safetensors", \
            "https://huggingface.co/closertodeath/dpepteahands3/resolve/main/dpepteahand3.ckpt", \
            "https://huggingface.co/prompthero/openjourney-v2/resolve/main/openjourney-v2.ckpt", \
            "https://huggingface.co/dreamlike-art/dreamlike-diffusion-1.0/resolve/main/dreamlike-diffusion-1.0.ckpt", \
            "https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/main/v1-5-pruned-emaonly.ckpt"]
modelList = ["", \
             "Animefull-final-pruned", \
             "Anything-v3-1", \
             "Anything-v4-5-pruned", \
             "Kani-anime-pruned", \
             "AbyssOrangeMix2-nsfw", \
             "Counterfeit-v2", \
             "DpepTeaHands3", \
             "OpenJourney-v2", \
             "Dreamlike-diffusion-v1-0", \
             "Stable-Diffusion-v1-5"]
modelName = "Anything-v3-1"  #@param ["", "Animefull-final-pruned", "Anything-v3-1", "Anything-v4-5-pruned", "Kani-anime-pruned", "AbyssOrangeMix2-nsfw", "Counterfeit-v2", "DpepTeaHands3", "OpenJourney-v2", "Dreamlike-diffusion-v1-0", "Stable-Diffusion-v1-5"]

#@markdown ### SD2.x model
v2ModelUrl = ["", \
              "https://huggingface.co/stabilityai/stable-diffusion-2-1-base/resolve/main/v2-1_512-ema-pruned.ckpt", \
              "https://huggingface.co/stabilityai/stable-diffusion-2-1/resolve/main/v2-1_768-ema-pruned.ckpt", \
              "https://huggingface.co/hakurei/waifu-diffusion-v1-4/resolve/main/wd-1-4-anime_e2.ckpt", \
              "https://huggingface.co/p1atdev/pd-archive/resolve/main/plat-v1-3-1.safetensors"]
v2ModelList = ["", \
              "stable-diffusion-2-1-base", \
              "stable-diffusion-2-1-768v", \
              "waifu-diffusion-1-4-anime-e2", \
              "plat-diffusion-v1-3-1"]
v2ModelName = "" #@param ["", "stable-diffusion-2-1-base", "stable-diffusion-2-1-768v", "waifu-diffusion-1-4-anime-e2", "plat-diffusion-v1-3-1"]

# Check if user has selected a model
if modelName != "":
  # Map selected model to URL
  installModels.append((modelName, modelUrl[modelList.index(modelName)]))

# Check if user has selected a model
if v2ModelName != "":
  # Map selected model to URL
  installv2Models.append((v2ModelName, v2ModelUrl[v2ModelList.index(v2ModelName)]))

def install(checkpoint_name, url):
  """Download one checkpoint into `{root_dir}/pre_trained_model` with aria2c.

  The file is saved as `<checkpoint_name>.ckpt` when `url` ends with `.ckpt`
  and as `<checkpoint_name>.safetensors` otherwise.
  """
  ext = "ckpt" if url.endswith(".ckpt") else "safetensors"

  # NOTE(review): a Hugging Face token is hard-coded here and sent with every download;
  # consider taking it from user input instead.
  hf_token = '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'
  # The header value carries its own double quotes so the shell passes it as one argument
  user_header = '"Authorization: Bearer ' + hf_token + '"'
  !aria2c --console-log-level=error --summary-interval=10 --header={user_header} -c -x 16 -k 1M -s 16 -d {root_dir}/pre_trained_model -o {checkpoint_name}.{ext} "{url}"

def install_checkpoint():
  """Download every queued checkpoint: SD1.x selections first, then SD2.x."""
  for checkpoint_name, url in installModels + installv2Models:
    install(checkpoint_name, url)

install_checkpoint()

In [None]:
#@title ## 2.2. Download Custom Model

# Set up the destination folder for a user-supplied model URL.
import os
%store -r

%cd {root_dir}

#@markdown ### Custom model
# NOTE: this rebinds `modelUrl` (a list in cell 2.1) to a single URL string
modelUrl = "" #@param {'type': 'string'}
# Same folder that cell 2.1 downloads the predefined checkpoints into
dst = str(root_dir)+"/pre_trained_model"

# Create the download folder on first use
if not os.path.exists(dst):
    os.makedirs(dst)

def install(url):
  base_name = os.path.basename(url)

  if url.startswith("https://drive.google.com"):
    %cd {dst}
    !gdown --fuzzy {url}
  elif url.startswith("https://huggingface.co/"):
    if '/blob/' in url:
      url = url.replace('/blob/', '/resolve/')
    #@markdown Change this part with your own huggingface token if you need to download your private model
    hf_token = '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' #@param {type:"string"}
    user_header = f"\"Authorization: Bearer {hf_token}\""
    !aria2c --console-log-level=error --summary-interval=10 --header={user_header} -c -x 16 -k 1M -s 16 -d {dst} -o {base_name} {url}
  else:
    !aria2c --console-log-level=error --summary-interval=10 -c -x 16 -k 1M -s 16 -d {dst} -o {base_name} {url}

install(modelUrl)


In [None]:
#@title ## 2.3. Download Available VAE
%store -r 

%cd {root_dir}

installVae = []
#@markdown ### Available VAE
#@markdown Select one of the VAEs to download, select `none` for not download VAE:
vaeUrl = ["", \
          "https://huggingface.co/Linaqruf/personal-backup/resolve/main/vae/animevae.pt", \
          "https://huggingface.co/hakurei/waifu-diffusion-v1-4/resolve/main/vae/kl-f8-anime.ckpt", \
          "https://huggingface.co/stabilityai/sd-vae-ft-mse-original/resolve/main/vae-ft-mse-840000-ema-pruned.ckpt"]
vaeList = ["none", \
           "anime.vae.pt", \
           "waifudiffusion.vae.pt", \
           "stablediffusion.vae.pt"]
vaeName = "anime.vae.pt" #@param ["none", "anime.vae.pt", "waifudiffusion.vae.pt", "stablediffusion.vae.pt"]

installVae.append((vaeName, vaeUrl[vaeList.index(vaeName)]))

def install(vae_name, url):
  """Download one VAE with aria2c and save it as `vae/<vae_name>` under the current directory."""
  # NOTE(review): a Hugging Face token is hard-coded here and sent with every download
  hf_token = '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'
  # The header value carries its own double quotes so the shell passes it as one argument
  user_header = '"Authorization: Bearer ' + hf_token + '"'
  !aria2c --console-log-level=error --summary-interval=10 --header={user_header} -c -x 16 -k 1M -s 16 -o vae/{vae_name} "{url}"

def install_vae():
  """Download every queued VAE unless the form selection is `none`."""
  if vaeName == "none":
    return
  for vae_name, url in installVae:
    install(vae_name, url)

install_vae()

# IV. Data Acquisition

You can either upload your own dataset to this notebook, download it as a zip file, or use the image scraper below to bulk-download images from Danbooru or Gelbooru.

If you want to use your own dataset, upload it to Colab's local files.


In [37]:
#@title ## 4.1. Define Train Data Directory
#@markdown Define where your train data will be located. This cell will also create a folder based on your input. 
#@markdown This folder will be used as the target folder for scraping, tagging, bucketing, and training in the next cell.

%store -r

train_folder_directory = "/content/dreambooth/train_data" #@param {type: "string"}
%store train_folder_directory
train_parent_directory = os.path.dirname(train_folder_directory)
reg_folder_directory = f"{train_parent_directory}/reg_data"
%store reg_folder_directory

dataset_repeats = 1 #@param {type: "integer"}
train_concept = "mksks" #@param {type: "string"}
train_class = "" #@param {type: "string"}
#@markdown You can run this cell multiple time to add new concepts

if train_class:
  train_folder = str(dataset_repeats) + "_" + train_concept + " " + train_class
else:
  train_folder = str(dataset_repeats) + "_" + train_concept
  
train_data_dir = f"{train_folder_directory}/{train_folder}"

if not os.path.isdir(reg_folder_directory):
  os.mkdir(reg_folder_directory)

if not os.path.isdir(train_folder_directory):
  os.mkdir(train_folder_directory)

if not os.path.isdir(train_data_dir):
  os.mkdir(train_data_dir)

Stored 'train_folder_directory' (str)
Stored 'reg_folder_directory' (str)


In [38]:
#@title ## 4.2. Download and Extract Zip (.zip)
import os
import shutil
from pathlib import Path
%store -r

#@markdown ### Define Zipfile URL or Zipfile Path
zipfile_url_or_path = "https://huggingface.co/datasets/Linaqruf/your-dataset-name/resolve/main/hito-komoru_dataset.zip" #@param {'type': 'string'}
zipfile_dst = str(root_dir)+"/zip_file.zip"
extract_to = "" #@param {'type': 'string'}

if extract_to != "":
  if not os.path.exist(extract_to):
    os.makedirs(extract_to)
else:
  extract_to = train_data_dir

#@markdown This will ignore `extract_to` path and automatically extracting to `train_data_dir`
is_dataset = True #@param{'type':'boolean'}

#@markdown Tick this if you want to extract all files directly to `extract_to` folder, and automatically delete the zip to save the memory
auto_unzip_and_delete = True #@param{'type':'boolean'}

dirname = os.path.dirname(zipfile_dst)
basename = os.path.basename(zipfile_dst)

try:
  if zipfile_url_or_path.startswith("/content"):
    zipfile_dst = zipfile_url_or_path
    if auto_unzip_and_delete == False:
      if is_dataset:
        extract_to = train_data_dir
      !unzip -j {zipfile_dst} -d "{extract_to}"
  elif zipfile_url_or_path.startswith("https://drive.google.com"):
    !gdown --fuzzy  {zipfile_url_or_path}
  elif zipfile_url_or_path.startswith("magnet:?"):
    !aria2c --summary-interval=10 -c -x 10 -k 1M -s 10 {zipfile_url_or_path}
  elif zipfile_url_or_path.startswith("https://huggingface.co/"):
    if '/blob/' in zipfile_url_or_path:
      zipfile_url_or_path = zipfile_url_or_path.replace('/blob/', '/resolve/')

    hf_token = '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'
    user_header = f"\"Authorization: Bearer {hf_token}\""
    !aria2c --console-log-level=error --summary-interval=10 --header={user_header} -c -x 16 -k 1M -s 16 -d {dirname} -o {basename} {zipfile_url_or_path}
  else:
    !aria2c --console-log-level=error --summary-interval=10 -c -x 16 -k 1M -s 16 -d {dirname} -o {basename} {zipfile_url_or_path}

except Exception as e:
  print("An error occurred while downloading the file:", e)

if is_dataset:
  extract_to = train_data_dir

if auto_unzip_and_delete:
  !unzip -j {zipfile_dst} -d "{extract_to}"

  # directory to check for files
  # JSON files to move
  files_to_move = ("meta_cap.json", \
                   "meta_cap_dd.json", \
                   "meta_lat.json", \
                   "meta_clean.json")

  # check each file in the directory
  for filename in os.listdir(extract_to):
      # get the full file path
      file_path = os.path.join(extract_to, filename)
      if filename in files_to_move:
          # move the file to the parent directory
          shutil.move(file_path, os.path.dirname(extract_to))
  
  path_obj = Path(zipfile_dst)
  zipfile_name = path_obj.parts[-1]
  
  if os.path.isdir(zipfile_dst):
    print("\nThis zipfile doesn't exist or has been deleted \n")
  else:
    os.remove(zipfile_dst)
    print(f"\n{zipfile_name} has been deleted")


01/30 17:19:48 [[1;31mERROR[0m] CUID#7 - Download aborted. URI=https://huggingface.co/datasets/Linaqruf/your-dataset-name/resolve/main/hito-komoru_dataset.zip
Exception: [AbstractCommand.cc:351] errorCode=3 URI=https://huggingface.co/datasets/Linaqruf/your-dataset-name/resolve/main/hito-komoru_dataset.zip
  -> [HttpSkipResponseCommand.cc:218] errorCode=3 Resource not found

Download Results:
gid   |stat|avg speed  |path/URI
50ddb5|[1;31mERR[0m |       0B/s|/content/zip_file.zip

Status Legend:
(ERR):error occurred.

aria2 will resume download if the transfer is restarted.
If there are any errors, then see the log file. See '-l' option in help/man page for details.
unzip:  cannot find or open /content/zip_file.zip, /content/zip_file.zip.zip or /content/zip_file.zip.ZIP.


FileNotFoundError: ignored

In [None]:
#@title ## 4.3. Simple Booru Scraper
#@markdown Use gallery-dl to scrape images from a booru site using the specified tags
import os
import html
%store -r

%cd {root_dir}

# Set configuration options
booru = "Gelbooru" #@param ["", "Danbooru", "Gelbooru"]
tag1 = "hito_komoru" #@param {type: "string"}
tag2 = "" #@param {type: "string"}
download_tags = False #@param {type: "boolean"}
# Construct the search query
if tag2 != "":
  tags = tag1 + "+" + tag2
else:
  tags = tag1

if download_tags == True:
  write_tags = "--write-tags"
else:
  write_tags = ""

# Scrape images from the specified booru site using the given tags
if booru.lower() == "danbooru":
  !gallery-dl "https://danbooru.donmai.us/posts?tags={tags}" {write_tags} -D "{train_data_dir}"
elif booru.lower() == "gelbooru":
  !gallery-dl "https://gelbooru.com/index.php?page=post&s=list&tags={tags}" {write_tags} -D "{train_data_dir}"
else:
  print(f"Unknown booru site: {booru}")

# Turn the downloaded tag files into single-line, comma-separated captions.
if download_tags:
  # Every .txt file in the training folder is treated as a tag file
  files = [f for f in os.listdir(train_data_dir) if f.endswith(".txt")]

  for file in files:
      file_path = os.path.join(train_data_dir, file)

      with open(file_path, "r", encoding="utf-8") as f:
          contents = f.read()

      # Decode HTML entities and replace _ with a space
      contents = html.unescape(contents).replace("_", " ")

      # One tag per line -> "tag, tag, tag". Blank lines are dropped so a
      # trailing newline no longer leaves a dangling ", " at the end.
      tag_lines = (line.strip() for line in contents.splitlines())
      contents = ", ".join(line for line in tag_lines if line)

      with open(file_path, "w", encoding="utf-8") as f:
          f.write(contents)

# VI. Data Preprocessing

In [None]:
#@title ## 6.1. Data Cleaning
#@markdown This will delete unnecessary files and unsupported media like `.mp4`, `.webm`, and `.gif`
%store -r

import os

%cd {root_dir}

test = os.listdir(train_data_dir)

#@markdown I recommend to `keep_metadata` especially if you're doing resume training and you have metadata and bucket latents file from previous training like `.npz`, `.txt`, `.caption`, and `json`.
keep_metadata = True #@param {'type':'boolean'}

# List of supported file types
if keep_metadata == True:
  supported_types = [".png", ".jpg", ".jpeg", ".webp", ".bmp", ".caption", ".npz", ".txt", ".json"]
else:
  supported_types = [".png", ".jpg", ".jpeg", ".webp", ".bmp"]

# Iterate over all files in the directory
for item in test:
    # Extract the file extension from the file name
    file_ext = os.path.splitext(item)[1]
    # If the file extension is not in the list of supported types, delete the file
    if file_ext not in supported_types:
        # Print a message indicating the name of the file being deleted
        print(f"Deleting file {item} from {train_data_dir}")
        # Delete the file
        os.remove(os.path.join(train_data_dir, item))

In [None]:
#@title ## 6.2. Data Annotation
%cd {finetune_dir}

#@markdown We're using [BLIP](https://huggingface.co/spaces/Salesforce/BLIP) for image captioning and [Waifu Diffusion 1.4 Tagger](https://huggingface.co/spaces/SmilingWolf/wd-v1-4-tags) for image tagging like danbooru.

start_labeling = "WD_1_4_Tagger" #@param ["BLIP_Captioning", "WD_1_4_Tagger"]

#@markdown BLIP Captioning example: <br>
#@markdown `a girl with long hair holding a cellphone`

#@markdown Waifu Diffusion 1.4 Tagger example : <br>
#@markdown `1girl, solo, looking_at_viewer, short_hair, bangs, simple_background, shirt, black_hair, white_background, closed_mouth, choker, hair_over_one_eye, head_tilt, grey_eyes, black_shirt, floating_hair, black_choker, eyes_visible_through_hair, portrait`

batch_size = 8

if start_labeling == "BLIP_Captioning":
  !python make_captions.py \
    "{train_data_dir}" \
    --batch_size {batch_size} \
    --caption_extension .caption
elif start_labeling == "WD_1_4_Tagger":
  !python tag_images_by_wd14_tagger.py \
    "{train_data_dir}" \
    --batch_size {batch_size} \
    --caption_extension .txt
else:
  pass
    

# VII. Training Model



In [None]:
#@title ## 7.1. Define Important folder
from google.colab import drive
%store -r

v2 = False #@param {type:"boolean"}
v_parameterization = False #@param {type:"boolean"}
project_name = "masabodo" #@param {type:"string"}
pretrained_model_name_or_path = "/content/pre_trained_model/Anything-v3-1.safetensors" #@param {type:"string"}
vae = ""  #@param {type:"string"}
#@markdown You need to register parent folder and not where `train_data_dir` located
train_folder_directory = "/content/dreambooth/train_data" #@param {'type':'string'}
%store train_folder_directory
reg_folder_directory = "/content/dreambooth/reg_data" #@param {'type':'string'}
%store reg_folder_directory
output_dir = "/content/dreambooth/output" #@param {'type':'string'}
resume_path =""

#@markdown This will ignore `output_dir` defined above, and changed to `/content/drive/MyDrive/fine_tune/output` by default
output_to_drive = False #@param {'type':'boolean'}

if output_to_drive:
  output_dir = "/content/drive/MyDrive/dreambooth/output"

  if not os.path.exists("/content/drive"):
    drive.mount('/content/drive')  

# Check if directory exists
if not os.path.exists(output_dir):
  # Create directory if it doesn't exist
  os.makedirs(output_dir)

#V2 Inference
inference_url = "https://raw.githubusercontent.com/Stability-AI/stablediffusion/main/configs/stable-diffusion/"

if v2 and not v_parameterization:
  inference_url += "v2-inference.yaml"
if v2 and v_parameterization:
  inference_url += "v2-inference-v.yaml"

try:
  if v2:
    !wget {inference_url} -O {output_dir}/{project_name}.yaml
    print("File successfully downloaded")
except:
  print("There was an error downloading the file. Please check the URL and try again.")

In [None]:
#@title ## 7.2. Define Specific LoRA Training parameter
# Form parameters for the LoRA network; they are read by the training cell below.
%store -r

#@markdown ## LoRA - Low Rank Adaptation Dreambooth

#@markdown If you're following `https://rentry.org/lora_train` guide, they set `network_dim` to `128`, you can change it yourself or use default parameter
network_dim = 128 #@param {'type':'number'}
#@markdown For LoRA weight scaling. Not sure what this is, but if you want to get the same result before update, you need to set `network_alpha` the same as `network_dim`.
network_alpha = 128 #@param {'type':'number'}
network_module = "networks.lora"

#@markdown `Specify network_weights for resume training`
network_weights = "" #@param {'type':'string'}

#@markdown When neither `--network_train_unet_only` nor `--network_train_text_encoder_only` is specified (default), both Text Encoder and U-Net LoRA modules are enabled.
network_train_on = "both" #@param ['both','unet_only', 'text_encoder_only'] {'type':'string'}

#@markdown Some people recommend to set `text_encoder_lr` at lower learning rate such as `5e-5`
unet_lr = 1e-4 #@param {'type':'number'}
text_encoder_lr = 5e-5 #@param {'type':'number'}
lr_scheduler = "constant" #@param  ["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"] {allow-input: false}

#@markdown You don't need to change this part below if you're not using `cosine_with_restarts`
lr_scheduler_num_cycles = 1 #@param {'type':'number'}
lr_scheduler_power = 1 #@param {'type':'number'}

#@markdown Tick if you don't want to save metadata in the output model
no_metadata = False #@param {type:"boolean"}
training_comment = "this comment will stored to metadata" #@param {'type':'string'}

# Echo the chosen LoRA settings and drop a `network_weights` path that does not exist.
# `os` is imported here because this cell uses it and has no import of its own.
import os

print("Load network module :", network_module)
print(f"{network_module} dim set to :", network_dim)
print(f"{network_module} alpha set to :", network_alpha)

if network_weights == "":
  print("No LoRA weight loaded")
elif os.path.exists(network_weights):
  print("Load LoRA weight: ", network_weights)
else:
  # Reset to "" so the training cell does not pass a missing file
  print(f"{network_weights} didn't exist")
  network_weights =""

if network_train_on == "unet_only":
  print("Enable LoRA for U-Net")
  print("Disable LoRA for Text Encoder")
  print("UNet learning rate: ", unet_lr)
elif network_train_on == "text_encoder_only":
  print("Disable LoRA for U-Net")
  print("Enable LoRA for Text Encoder")
  print("Text encoder learning rate: ", text_encoder_lr)
else:
  print("Enable LoRA for U-Net")
  print("Enable LoRA for Text Encoder")
  print("UNet learning rate: ", unet_lr)
  print("Text encoder learning rate: ", text_encoder_lr)

print("Learning rate Scheduler:", lr_scheduler)
if lr_scheduler == "cosine_with_restarts":
  print("- num cycles: ", lr_scheduler_num_cycles)
  print("- power: ", lr_scheduler_power)

if no_metadata:
  print("Metadata won't be saved")
elif training_comment:
  print("Training comment:", training_comment)


In [None]:
from prettytable import PrettyTable
import textwrap
import yaml

%store -r

#@title ## 7.3. Start LoRA Dreambooth
#@markdown ### Define Parameter

# Training hyperparameters; each one is turned into a command-line flag in `train_command` below.
train_batch_size = 4 #@param {type:"number"}
num_epochs = 1 #@param {type:"number"}
caption_extension = '.txt' #@param {'type':'string'}
mixed_precision = "fp16" #@param ["no","fp16","bf16"] {allow-input: false}
save_precision = "fp16" #@param ["float", "fp16", "bf16"] {allow-input: false}
# `save_n_epochs_type` picks which of the two save flags receives `save_n_epochs_type_value`
save_n_epochs_type = "save_n_epoch_ratio" #@param ["save_every_n_epochs", "save_n_epoch_ratio"] {allow-input: false}
save_n_epochs_type_value = 5 #@param {type:"number"}
save_model_as = "safetensors" #@param ["ckpt", "pt", "safetensors"] {allow-input: false}
resolution = 512 #@param {type:"number"}
max_token_length = 225 #@param {type:"number"}
# Only passed to the trainer when `v2` is False
clip_skip = 2 #@param {type:"number"}
use_8bit_adam = True #@param {type:"boolean"}
gradient_checkpointing = False #@param {type:"boolean"}
gradient_accumulation_steps = 1 #@param {type:"number"}
# A seed of 0 (or less) means no `--seed` flag is passed
seed = 0 #@param {type:"number"}
logging_dir = "/content/dreambooth/logs"
# `project_name` comes from cell 7.1
log_prefix = project_name
# Appended verbatim to the end of the training command
additional_argument = "--shuffle_caption --xformers --enable_bucket --cache_latents" #@param {type:"string"}
print_hyperparameter = True #@param {type:"boolean"}
prior_loss_weight =1.0
# train_network.py and train.sh are resolved relative to the repository root
%cd {repo_dir}

# Build the `accelerate launch` command line for train_network.py.
# Each `{... if ... else ""}` expression contributes a flag only when its option
# is set. The trailing backslashes are line continuations of the Python string
# literal, so the resulting command is one single line.
#
# - `--training_comment` passes the comment text itself, shell-quoted because it
#   usually contains spaces (it used to pass the literal word "training_comment").
# - `--lr_scheduler_power` is passed for the `polynomial` scheduler, the one it
#   applies to in sd-scripts; `--lr_scheduler_num_cycles` stays with
#   `cosine_with_restarts`.
import shlex

train_command=f"""
accelerate launch --config_file={accelerate_config} --num_cpu_threads_per_process=8 train_network.py \
  {"--v2" if v2 else ""} \
  {"--v_parameterization" if v2 and v_parameterization else ""} \
  --network_dim={network_dim} \
  --network_alpha={network_alpha} \
  --network_module={network_module} \
  {"--network_weights=" + network_weights if network_weights else ""} \
  {"--network_train_unet_only" if network_train_on == "unet_only" else ""} \
  {"--network_train_text_encoder_only" if network_train_on == "text_encoder_only" else ""} \
  {"--unet_lr=" + format(unet_lr) if unet_lr else ""} \
  {"--text_encoder_lr=" + format(text_encoder_lr) if text_encoder_lr else ""} \
  {"--no_metadata" if no_metadata else ""} \
  {"--training_comment=" + shlex.quote(training_comment) if training_comment and not no_metadata else ""} \
  --lr_scheduler={lr_scheduler} \
  {"--lr_scheduler_num_cycles=" + format(lr_scheduler_num_cycles) if lr_scheduler == "cosine_with_restarts" else ""} \
  {"--lr_scheduler_power=" + format(lr_scheduler_power) if lr_scheduler == "polynomial" else ""} \
  --pretrained_model_name_or_path={pretrained_model_name_or_path} \
  {"--vae=" + vae if vae else ""} \
  --caption_extension={caption_extension} \
  --train_data_dir={train_folder_directory} \
  --reg_data_dir={reg_folder_directory} \
  --output_dir={output_dir} \
  --prior_loss_weight={prior_loss_weight} \
  {"--resume=" + resume_path if resume_path else ""} \
  {"--output_name=" + project_name if project_name else ""} \
  --mixed_precision={mixed_precision} \
  --save_precision={save_precision} \
  {"--save_every_n_epochs=" + format(save_n_epochs_type_value) if save_n_epochs_type=="save_every_n_epochs" else ""} \
  {"--save_n_epoch_ratio=" + format(save_n_epochs_type_value) if save_n_epochs_type=="save_n_epoch_ratio" else ""} \
  --save_model_as={save_model_as} \
  --resolution={resolution} \
  --train_batch_size={train_batch_size} \
  --max_token_length={max_token_length} \
  {"--use_8bit_adam" if use_8bit_adam else ""} \
  --max_train_epochs={num_epochs} \
  {"--seed=" + format(seed) if seed > 0 else ""} \
  {"--gradient_checkpointing" if gradient_checkpointing else ""} \
  {"--gradient_accumulation_steps=" + format(gradient_accumulation_steps) } \
  {"--clip_skip=" + format(clip_skip) if v2 == False else ""} \
  --logging_dir={logging_dir} \
  --log_prefix={log_prefix} \
  {additional_argument}
  """

# Names of the global variables listed in the hyperparameter table, in display order.
# Each entry must match a variable defined in cells 7.1 - 7.3.
debug_params = [
    # model / network
    "v2",
    "v_parameterization",
    "network_dim",
    "network_alpha",
    "network_module",
    "network_weights",
    "network_train_on",
    # optimisation / metadata
    "unet_lr",
    "text_encoder_lr",
    "no_metadata",
    "training_comment",
    "lr_scheduler",
    "lr_scheduler_num_cycles",
    "lr_scheduler_power",
    # paths
    "pretrained_model_name_or_path",
    "vae",
    "caption_extension",
    "train_folder_directory",
    "reg_folder_directory",
    "output_dir",
    "prior_loss_weight",
    "resume_path",
    "project_name",
    # precision / saving
    "mixed_precision",
    "save_precision",
    "save_n_epochs_type",
    "save_n_epochs_type_value",
    "save_model_as",
    # training loop
    "resolution",
    "train_batch_size",
    "max_token_length",
    "use_8bit_adam",
    "num_epochs",
    "seed",
    "gradient_checkpointing",
    "gradient_accumulation_steps",
    "clip_skip",
    "logging_dir",
    "log_prefix",
    "additional_argument",
]

if print_hyperparameter:
    table = PrettyTable()
    table.field_names = ["Hyperparameter", "Value"]
    for params in debug_params:
        if params != "":
            if globals()[params] == "":
                value = "False"
            else:
                value = globals()[params]
            table.add_row([params, value])
    table.align = "l"
    print(table)

    arg_list = train_command.split()
    mod_train_command = {'command': arg_list}
    
    train_folder = os.path.dirname(output_dir)
    
    # save the YAML string to a file
    with open(str(train_folder)+'/dreamboothlora_cmd.yaml', 'w') as f:
        yaml.dump(mod_train_command, f)

f = open("./train.sh", "w")
f.write(train_command)
f.close()
!chmod +x ./train.sh
!./train.sh

# VIII. Testing

In [None]:
#@title ## 8.1 Validating LoRA Weights

#@markdown Now you can check if your LoRA trained properly.
network_weights = "/content/dreambooth/output/masabodo.safetensors" #@param {'type':'string'}

import os
import torch
from safetensors.torch import load_file

def main(file):
  """Print summary statistics for every LoRA tensor stored in a weights file.

  Files ending in ``.safetensors`` are read with ``load_file``; anything else
  is read with ``torch.load``. The function prints the number of tensors whose
  key contains ``lora_up`` or ``lora_down``, followed by one
  ``key,mean(|w|),min(|w|)`` line per such tensor.

  Args:
    file: Path to the LoRA weights file.
  """
  print(f"loading: {file}")
  if os.path.splitext(file)[1] == '.safetensors':
    sd = load_file(file)
  else:
    # NOTE: torch.load unpickles the file, so only open checkpoints you trust.
    # Load onto the CPU: these statistics do not need a GPU, and mapping to
    # 'cuda' makes the check fail on a runtime without one.
    sd = torch.load(file, map_location='cpu')

  # Keep only the LoRA up/down projection tensors, in file order.
  values = [
      (key, tensor)
      for key, tensor in sd.items()
      if 'lora_up' in key or 'lora_down' in key
  ]
  print(f"number of LoRA modules: {len(values)}")

  for key, value in values:
    # Cast to float32 so the statistics are computed at full precision.
    magnitude = torch.abs(value.to(torch.float32))
    print(f"{key},{torch.mean(magnitude)},{torch.min(magnitude)}")

# Run the check on the file chosen in the form above.
if __name__ == '__main__':
  main(network_weights)

In [None]:
#@title ## 8.2 Inference
%store -r

#@markdown LoRA Config
network_weights = "/content/dreambooth/output/masabodo.safetensors" #@param {'type':'string'}
network_module = "networks.lora"
network_mul = 0.6 #@param {'type':'number'}

#@markdown Other Config
v2 = False #@param {type:"boolean"}
v_parameterization = False #@param {type:"boolean"}
prompt = "masterpiece, best quality, 1girl, aqua eyes, baseball cap, blonde hair, closed mouth, earrings, green background, hat, hoop earrings, jewelry, looking at viewer, shirt, short hair, simple background, solo, upper body, yellow shirt" #@param {type: "string"}
negative = "lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry" #@param {type: "string"}
model = "/content/pre_trained_model/Anything-v3-1.safetensors" #@param {type: "string"}
vae = "" #@param {type: "string"}
outdir = "/content/tmp" #@param {type: "string"}
scale = 12 #@param {type: "slider", min: 1, max: 40}
sampler = "ddim" #@param ["ddim", "pndm", "lms", "euler", "euler_a", "heun", "dpm_2", "dpm_2_a", "dpmsolver","dpmsolver++", "dpmsingle", "k_lms", "k_euler", "k_euler_a", "k_dpm_2", "k_dpm_2_a"]
steps = 28 #@param {type: "slider", min: 1, max: 100}
precision = "fp16" #@param ["fp16", "bf16"] {allow-input: false}
width = 512 #@param {type: "integer"}
height = 768 #@param {type: "integer"}
images_per_prompt = 4 #@param {type: "integer"}
batch_size = 4 #@param {type: "integer"}
clip_skip = 2 #@param {type: "slider", min: 1, max: 40}
seed = -1 #@param {type: "integer"}

final_prompt = f"{prompt} --n {negative}"

%cd {repo_dir}

!python gen_img_diffusers.py \
  {"--v2" if v2 else ""} \
  {"--v_parameterization" if v2 and v_parameterization else ""} \
  --network_module={network_module} \
  --network_weight={network_weights} \
  --network_mul={network_mul} \
  --ckpt={model} \
  --outdir={outdir} \
  --xformers \
  {"--vae=" + vae if vae else ""} \
  --{precision} \
  --W={width} \
  --H={height} \
  {"--seed=" + format(seed) if seed > 0 else ""} \
  --scale={scale} \
  --sampler={sampler} \
  --steps={steps} \
  --max_embeddings_multiples=3 \
  --batch_size={batch_size} \
  --images_per_prompt={images_per_prompt} \
  {"--clip_skip=" + format(clip_skip) if v2 == False else ""} \
  --prompt="{final_prompt}"



In [None]:
#@title ## 8.3. Visualize loss graph (Optional)
# Directory passed to TensorBoard as its --logdir.
training_logs_path = "/content/dreambooth/logs" #@param {type : "string"}

%cd /content/kohya-trainer
# Load the TensorBoard notebook extension, then open it on the chosen log directory.
%load_ext tensorboard
%tensorboard --logdir {training_logs_path}

# IX. Extras

In [None]:
#@title ## 9.1. Compressing model or dataset
import os
import zipfile
import shutil

zip_module = "zipfile" #@param ["zipfile", "shutil", "pyminizip", "zip"]
directory_to_zip = '/content/dreambooth/output' #@param {type: "string"}
output_filename = '/content/output.zip' #@param {type: "string"}
password = "" #@param {type: "string"}

if zip_module == "zipfile":
    with zipfile.ZipFile(output_filename, 'w') as zip:
        for directory_to_zip, dirs, files in os.walk(directory_to_zip):
            for file in files:
                zip.write(os.path.join(directory_to_zip, file))
elif zip_module == "shutil":
    shutil.make_archive(output_filename, 'zip', directory_to_zip)
elif zip_module == "pyminizip":
    !pip install pyminizip
    import pyminizip
    for root, dirs, files in os.walk(directory_to_zip):
        for file in files:
            pyminizip.compress(os.path.join(root, file), "", os.path.join("*",output_filename), password, 5)
elif zip_module == "zip":
    !zip -rv -q -j {output_filename} {directory_to_zip}

# X. Deployment

In [None]:
#@title ## 10.1. Define your Huggingface Repo

from huggingface_hub import HfApi
from huggingface_hub.utils import validate_repo_id, HfHubHTTPError
%store -r

api = HfApi()
# `write_token` is not defined in this cell; it is expected to be available
# after the `%store -r` above.
user = api.whoami(write_token)

#@markdown #### If your model/dataset repo didn't exist, it will automatically create your repo.
model_name = "your-model-name" #@param{type:"string"}
dataset_name = "your-dataset-name" #@param{type:"string"}
make_this_model_private = True #@param{type:"boolean"}

# Repo ids have the form "<user name>/<repo name>"; surrounding whitespace in
# the form fields is ignored.
model_repo = user['name']+"/"+model_name.strip()
datasets_repo = user['name']+"/"+dataset_name.strip()

# Skip creation when the (stripped) name is empty, matching how the repo id is built.
if model_name.strip() != "":
  try:
      validate_repo_id(model_repo)
      api.create_repo(repo_id=model_repo, 
                      private=make_this_model_private)
      print("Model Repo didn't exists, creating repo")
      print("Model Repo: ",model_repo,"created!\n")

  except HfHubHTTPError as e:
      # Only HTTP 409 (Conflict) means the repo already exists; report any
      # other failure (bad token, quota, network, ...) instead of masking it.
      if getattr(getattr(e, "response", None), "status_code", None) == 409:
          print(f"Model Repo: {model_repo} exists, skipping create repo\n")
      else:
          print(f"Model Repo: {model_repo} could not be created: {e}\n")

if dataset_name.strip() != "":
  try:
      validate_repo_id(datasets_repo)
      # The same privacy checkbox is applied to the dataset repo.
      api.create_repo(repo_id=datasets_repo,
                      repo_type="dataset",
                      private=make_this_model_private)
      print("Dataset Repo didn't exists, creating repo")
      print("Dataset Repo",datasets_repo,"created!\n")

  except HfHubHTTPError as e:
      if getattr(getattr(e, "response", None), "status_code", None) == 409:
          print(f"Dataset repo: {datasets_repo} exists, skipping create repo\n")
      else:
          print(f"Dataset repo: {datasets_repo} could not be created: {e}\n")


## 10.2. Upload with `hf_hub`

In [None]:
#@title ### 10.2.1. Upload LoRA
from huggingface_hub import HfApi
from pathlib import Path

# Hub client used by upload_model() below.
api = HfApi()

#@markdown #### This will be uploaded to model repo

# Local file or folder to upload, and the target path inside the repo.
model_path = "/content/dreambooth/output/masabodo.safetensors" #@param {type :"string"}
path_in_repo = "masabodo.safetensors" #@param {type :"string"}

#@markdown #### Other Information
# Commit message passed to the upload calls.
commit_message = "feat: upload lora model" #@param {type :"string"}

def upload_model(model_paths, is_folder :bool):
  """Upload a trained model file or folder to the Hugging Face model repo.

  Reads the cell-level ``api``, ``model_repo``, ``commit_message`` and
  ``path_in_repo`` values.

  Args:
    model_paths: Local path of the file or folder to upload.
    is_folder: True to upload with ``api.upload_folder``; False to upload a
      single file with ``api.upload_file``.
  """
  path_obj = Path(model_paths)
  # Honour the `path_in_repo` form field (it was previously ignored); fall
  # back to the last component of the local path when the field is empty.
  trained_model = path_in_repo.strip() or path_obj.parts[-1]
  
  if is_folder == True:
    print(f"Uploading {trained_model} to https://huggingface.co/"+model_repo)
    print(f"Please wait...")
    
    api.upload_folder(
        folder_path=model_paths,
        path_in_repo=trained_model,
        repo_id=model_repo,
        commit_message=commit_message,
        ignore_patterns=".ipynb_checkpoints"
        )
    print(f"Upload success, located at https://huggingface.co/"+model_repo+"/tree/main\n")
  else: 
    print(f"Uploading {trained_model} to https://huggingface.co/"+model_repo)
    print(f"Please wait...")
            
    api.upload_file(
        path_or_fileobj=model_paths,
        path_in_repo=trained_model,
        repo_id=model_repo,
        commit_message=commit_message,
        )
        
    print(f"Upload success, located at https://huggingface.co/"+model_repo+"/blob/main/"+trained_model+"\n")
      
def upload():
    """Upload `model_path` as a single file or as a folder.

    Paths ending in .ckpt, .safetensors or .pt are uploaded as one file;
    anything else is uploaded as a folder.
    """
    weight_suffixes = (".ckpt", ".safetensors", ".pt")
    is_single_file = model_path.endswith(weight_suffixes)
    upload_model(model_path, not is_single_file)

upload()

In [None]:
#@title ### 10.2.2. Upload Dataset
from huggingface_hub import HfApi
from pathlib import Path
import shutil
import zipfile
import os

# Hub client used by upload_dataset() below.
api = HfApi()

#@markdown #### This will be compressed to zip and  uploaded to datasets repo, leave it empty if not necessary
train_data_path = "/content/dreambooth/train_style/5_mksks style" #@param {type :"string"}
#@markdown ##### `Nerd stuff, only if you want to save training logs`
logs_path = "/content/dreambooth/logs" #@param {type :"string"}
#@markdown #### Delete zip after upload
delete_zip = True #@param {type :"boolean"}

# Staging folder that is zipped before upload.
# NOTE(review): `project_name` is not defined in this cell; it is presumably
# set by an earlier cell -- confirm it is in scope on a fresh kernel.
if project_name !="":
  tmp_dataset = "/content/dreambooth/"+project_name+"_dataset"
else:
  tmp_dataset = "/content/dreambooth/tmp_dataset"

dataset_zip = tmp_dataset + ".zip"

#@markdown #### Other Information
commit_message = "feat: upload lora dataset" #@param {type :"string"}

# exist_ok=True lets the cell be re-run without raising FileExistsError when
# the staging folder was already created by a previous run.
os.makedirs(tmp_dataset, exist_ok=True)

def upload_dataset(dataset_paths, is_zip : bool):
  """Upload a zip file or a folder to the Hugging Face dataset repo.

  Reads the cell-level ``api``, ``datasets_repo`` and ``commit_message``.

  Args:
    dataset_paths: Local path of the zip file or folder to upload.
    is_zip: True to upload a single file with ``api.upload_file``; False to
      upload a folder with ``api.upload_folder``.
  """
  # The last path component becomes the path inside the repo.
  entry_name = Path(dataset_paths).parts[-1]

  # Both modes announce the upload in the same way.
  print(f"Uploading {entry_name} to https://huggingface.co/datasets/"+datasets_repo)
  print(f"Please wait...")

  if not is_zip:
    api.upload_folder(
        folder_path=dataset_paths,
        path_in_repo=entry_name,
        repo_id=datasets_repo,
        repo_type="dataset",
        commit_message=commit_message,
        ignore_patterns=".ipynb_checkpoints",
    )
    print(f"Upload success, located at https://huggingface.co/datasets/"+datasets_repo+"/tree/main/"+entry_name+"\n")
    return

  api.upload_file(
      path_or_fileobj=dataset_paths,
      path_in_repo=entry_name,
      repo_id=datasets_repo,
      repo_type="dataset",
      commit_message=commit_message,
  )
  print(f"Upload success, located at https://huggingface.co/datasets/"+datasets_repo+"/blob/main/"+entry_name+"\n")
  
def zip_file(tmp):
    """Write every file found under `tmp` into the archive `tmp + ".zip"`.

    Each file is added under the path produced by joining its walked directory
    with its file name (no custom archive name is given).

    Args:
        tmp: Directory whose files are archived.
    """
    archive_path = tmp + ".zip"
    with zipfile.ZipFile(archive_path, 'w') as archive:
        for current_dir, _subdirs, filenames in os.walk(tmp):
            for filename in filenames:
                member_path = os.path.join(current_dir, filename)
                archive.write(member_path)

def move(src_path, dst_path, is_metadata: bool):
  """Move `src_path` to `dst_path`, optionally followed by sibling metadata files.

  Args:
    src_path: File or folder to move; skipped when it does not exist.
    dst_path: Destination handed to ``shutil.move``.
    is_metadata: When True, every entry in the parent directory of
      `src_path` whose name is one of the known ``meta_*.json`` files is
      moved to `dst_path` as well.
  """
  metadata_names = (
      "meta_cap.json",
      "meta_cap_dd.json",
      "meta_lat.json",
      "meta_clean.json",
      "meta_final.json",
  )

  if os.path.exists(src_path):
    shutil.move(src_path, dst_path)

  if not is_metadata:
    return

  # Metadata files are looked up next to the source path.
  metadata_dir = os.path.dirname(src_path)
  for entry in os.listdir(metadata_dir):
    if entry not in metadata_names:
      continue
    shutil.move(os.path.join(metadata_dir, entry), dst_path)

def upload():
  """Upload the zipped training data and/or the logs folder.

  When `train_data_path` is set, it is moved into `tmp_dataset`, the staging
  folder is zipped and the zip is uploaded. When `logs_path` is set, that
  folder is uploaded as-is.
  """
  has_train_data = train_data_path !=""
  has_logs = logs_path !=""

  if has_train_data:
    move(train_data_path, tmp_dataset, False)
    zip_file(tmp_dataset)
    upload_dataset(dataset_zip, True)

  if has_logs:
    upload_dataset(logs_path, False)

upload()

# The zip is only created when `train_data_path` is set, so check that it
# exists before deleting it; an unguarded os.remove raises FileNotFoundError.
if delete_zip and os.path.exists(dataset_zip):
  os.remove(dataset_zip)

## 10.3. Commit using Git (Alternative)

In [None]:
#@title ### 10.3.1. Clone Repository

clone_model = True #@param {'type': 'boolean'}
clone_dataset = True #@param {'type': 'boolean'}

!git lfs install --skip-smudge
!export GIT_LFS_SKIP_SMUDGE=1

if clone_model:
  !git clone https://huggingface.co/{model_repo} /content/{model_name}
clone_dataset
  !git clone https://huggingface.co/datasets/{datasets_repo} /content/{dataset_name}

In [None]:
#@title ### 10.3.2. Commit using Git 
# NOTE(review): `root_dir` is not defined in this cell; it is presumably still
# in the kernel from an earlier cell -- confirm on a fresh runtime.
%cd {root_dir}

#@markdown Tick which repo you want to commit
commit_model = True #@param {'type': 'boolean'}
commit_dataset = True #@param {'type': 'boolean'}

#@markdown Set **git commit identity**
email = "your-email-here" #@param {'type': 'string'}
name = "your-username-here" #@param {'type': 'string'}
#@markdown Set **commit message**
commit_m = "feat: upload model and dataset" #@param {'type': 'string'}

# Store the identity in the global git config so it applies to every repo committed below.
!git config --global user.email "{email}"
!git config --global user.name "{name}"

def commit(repo_folder, commit_message):
  """Stage, commit and push everything inside ``{root_dir}/{repo_folder}``.

  Args:
    repo_folder: Folder name, relative to `root_dir`, of the cloned repo.
    commit_message: Message passed to ``git commit -m``. It is interpolated
      into the shell command between double quotes, so a message that itself
      contains double quotes will break the command.
  """
  %cd {root_dir}/{repo_folder}
  # Set up git-lfs and large-file support for this repo before adding files.
  !git lfs install
  !huggingface-cli lfs-enable-largefiles .
  !git add .
  !git commit -m "{commit_message}"
  !git push

# Only commit the repositories ticked in the form; the `commit_model` and
# `commit_dataset` checkboxes were previously ignored.
if commit_model:
  commit(model_name, commit_m)
if commit_dataset:
  commit(dataset_name, commit_m)