<a href="https://colab.research.google.com/github/Linaqruf/kohya-trainer/blob/dev/kohya-trainer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Kohya Trainer V11 - VRAM 12GB


This notebook has been adapted for use in Google Colab based on [kohya-ss/sd-scripts](https://github.com/kohya-ss/sd-scripts). <br>
This notebook was adapted by [Linaqruf](https://github.com/Linaqruf)<br>
You can find the latest update to the notebook [here](https://github.com/Linaqruf/kohya-trainer/blob/main/kohya-trainer.ipynb).


# I. Install Kohya Trainer

In [None]:
#@title ## 1.1. Clone Kohya Trainer
#@markdown Clone Kohya Trainer from GitHub and check for updates. Use textbox below if you want to checkout other branch or old commit. Leave it empty to stay the HEAD on main.

import os
%store -r

# Check GPU Availability
!nvidia-smi

# Define path
root_dir = "/content"
%store root_dir
repo_dir = str(root_dir)+"/kohya-trainer"
%store repo_dir
tools_dir = str(root_dir)+"/kohya-trainer/tools"
%store tools_dir 
finetune_dir = str(root_dir)+"/kohya-trainer/finetune"
%store finetune_dir

# Define identifier
branch = "dev" #@param {type: "string"}
repo_url = "https://github.com/Linaqruf/kohya-trainer"

%cd {root_dir}

def clone_repo():
  # Check if the directory already exists
  if os.path.isdir(repo_dir):
    %cd {repo_dir}
    print("This folder already exists, will do a !git pull instead\n")
    if branch != "":
      !git pull origin {branch} 
    else:
      !git pull
  else:
    !git clone {repo_url}

def checkout_repo(branch_or_commit):
  %cd {repo_dir}
  !git checkout {branch_or_commit}

# Clone or update the Kohya Trainer repository
clone_repo()

# Checkout to the specified branch or commit
if branch != "":
  checkout_repo(branch)


In [None]:
#@title ## 1.2. Installing Dependencies

import os
%store -r

#@markdown This will install required Python packages

# Define variable
accelerate_config = str(repo_dir)+"/accelerate_config/config.yaml"
%store accelerate_config
colab_ram_patch = True #@param {'type':'boolean'}
install_xformers = True #@param {'type':'boolean'}

%cd {repo_dir}

def install_dependencies():

  if colab_ram_patch:
    !sed -i "s@cpu@cuda@" \
    {repo_dir}/train_db.py \
    {repo_dir}/train_network.py \
    {repo_dir}/fine_tune.py \
    {repo_dir}/library/model_util.py \
    {repo_dir}/library/train_util.py 

    !sed -i "s@cuda_count@cpu_count@" \
    {repo_dir}/train_db.py \
    {repo_dir}/train_network.py \
    {repo_dir}/fine_tune.py \
    {repo_dir}/library/model_util.py \
    {repo_dir}/library/train_util.py 
  
  !pip -qqqq install --upgrade -r requirements.txt
  !pip -qqqq install --upgrade gallery-dl
  !pip -qqqq install --upgrade --no-cache-dir gdown
  !apt -qqqq install liblz4-tool aria2

  if install_xformers:
    !pip -qqqq install -U -I --no-deps https://github.com/camenduru/stable-diffusion-webui-colab/releases/download/0.0.15/xformers-0.0.15.dev0+189828c.d20221207-cp38-cp38-linux_x86_64.whl
  else:
    pass

  from accelerate.utils import write_basic_config
  if not os.path.exists(accelerate_config):
    write_basic_config(save_location = accelerate_config) # Write a config file
  else:
    pass

# Install dependencies
install_dependencies()

#@markdown After Accelerate updated its version to 0.15.0, you can't manually input the config using
#@markdown `!accelerate config` in Google Colab. Instead, a `config.yaml` file will be generated by
#@markdown the `write_basic_config()` function. You can find the file [here](/content/kohya-trainer/accelerate_config/config.yaml) after installation.
#@markdown if you want to modify it.



## 1.3. Sign-in to Cloud Service

In [None]:
#@title ### 1.3.1. Login to Huggingface hub
from huggingface_hub import login
%store -r

#@markdown 1. Of course, you need a Huggingface account first.
#@markdown 2. To create a huggingface token, go to [this link](https://huggingface.co/settings/tokens), then `create new token` or copy available token with the `Write` role.

write_token = "your-write-token-here" #@param {type:"string"}
login(write_token, add_to_git_credential=True)

%store write_token


In [None]:
#@title ### 1.3.2. Mount Drive
from google.colab import drive

mount_drive = True #@param {type: "boolean"}

# Google Drive is attached only when the checkbox above is ticked
drive_mount_point = '/content/drive'
if mount_drive:
  drive.mount(drive_mount_point)

# II. Pretrained Model Selection

In [None]:
#@title ## 2.1. Download Available Model 
import os
%store -r

%cd {root_dir}

installModels = []
installv2Models = []

#@markdown ### Available Model
#@markdown Select one of available model to download:

#@markdown ### SD1.x model
modelUrl = ["", \
            "https://huggingface.co/Linaqruf/personal_backup/resolve/main/animeckpt/model-pruned.ckpt", \
            "https://huggingface.co/Linaqruf/anything-V3.0/resolve/main/Anything-V3.0-pruned.ckpt", \
            "https://huggingface.co/Linaqruf/anything-v3-better-vae/resolve/main/any-v3-fp32-better-vae.ckpt", \
            "https://huggingface.co/andite/anything-v4.0/resolve/main/anything-v4.0-pruned.ckpt", \
            "https://huggingface.co/andite/anything-v4.0/resolve/main/anything-v4.5-pruned.ckpt", \
            "https://huggingface.co/Rasgeath/self_made_sauce/resolve/main/Kani-anime-pruned.ckpt", \
            "https://huggingface.co/hesw23168/SD-Elysium-Model/resolve/main/Elysium_Anime_V2.ckpt", \
            "https://huggingface.co/prompthero/openjourney-v2/resolve/main/openjourney-v2.ckpt", \
            "https://huggingface.co/dreamlike-art/dreamlike-diffusion-1.0/resolve/main/dreamlike-diffusion-1.0.ckpt", \
            "https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/main/v1-5-pruned-emaonly.ckpt"]
modelList = ["", \
             "Animefull-final-pruned", \
             "Anything-v3-pruned", \
             "Anything-v3-better-vae", \
             "Anything-v4-pruned", \
             "Anything-v4-5-pruned", \
             "Kani-anime-pruned", \
             "Elysium-anime-V2", \
             "OpenJourney-V2", \
             "Dreamlike-diffusion-V1-0", \
             "Stable-Diffusion-v1-5"]
modelName = "Anything-v3-better-vae" #@param ["", "Animefull-final-pruned", "Anything-v3-pruned", "Anything-v3-better-vae", "Anything-v4-pruned", "Anything-v4-5-pruned", "Kani-anime-pruned", "Elysium-anime-V2", "OpenJourney-V2", "Dreamlike-diffusion-V1-0", "Stable-Diffusion-v1-5"]

#@markdown ### SD2.x model
v2ModelUrl = ["", \
              "https://huggingface.co/stabilityai/stable-diffusion-2-1-base/resolve/main/v2-1_512-ema-pruned.ckpt", \
              "https://huggingface.co/stabilityai/stable-diffusion-2-1/resolve/main/v2-1_768-ema-pruned.ckpt", \
              "https://huggingface.co/hakurei/waifu-diffusion-v1-4/resolve/main/wd-1-4-anime_e1.ckpt", \
              "https://huggingface.co/hakurei/waifu-diffusion-v1-4/resolve/main/wd-1-4-anime_e2.ckpt"]
v2ModelList = ["", \
              "stable-diffusion-2-1-base", \
              "stable-diffusion-2-1-768v", \
              "waifu-diffusion-1-4-anime-e1", \
              "waifu-diffusion-1-4-anime-e2"]
v2ModelName = "" #@param ["", "stable-diffusion-2-1-base", "stable-diffusion-2-1-768v", "waifu-diffusion-1-4-anime-e1", "waifu-diffusion-1-4-anime-e2"]

# Check if user has selected a model
if modelName != "":
  # Map selected model to URL
  installModels.append((modelName, modelUrl[modelList.index(modelName)]))

# Check if user has selected a model
if v2ModelName != "":
  # Map selected model to URL
  installv2Models.append((v2ModelName, v2ModelUrl[v2ModelList.index(v2ModelName)]))

def install(checkpoint_name, url):
  hf_token = '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' 
  user_header = f"\"Authorization: Bearer {hf_token}\""
  !aria2c --console-log-level=error --summary-interval=10 --header={user_header} -c -x 16 -k 1M -s 16 -d {root_dir}/pre_trained_model -o {checkpoint_name}.ckpt "{url}"

def install_checkpoint():
  # Iterate through list of models to install
  for model in installModels:
    # Call install function for each model
    install(model[0], model[1])

  # Iterate through list of models to install
  for v2model in installv2Models:
    # Call install function for each v2model
    install(v2model[0], v2model[1])

install_checkpoint()

In [None]:
#@title ## 2.2. Download Custom Model

import os
%store -r

%cd {root_dir}

#@markdown ### Custom model
modelUrl = "" #@param {'type': 'string'}
dst = str(root_dir)+"/pre_trained_model"

if not os.path.exists(dst):
    os.makedirs(dst)

def install(url):

  if url.startswith("https://drive.google.com"):
    %cd {dst}
    !gdown --fuzzy {url}
  elif url.startswith("https://huggingface.co/"):
    if '/blob/' in url:
      url = url.replace('/blob/', '/resolve/')
    #@markdown Change this part with your own huggingface token if you need to download your private model
    hf_token = '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' #@param {type:"string"}
    user_header = f"\"Authorization: Bearer {hf_token}\""
    !aria2c --console-log-level=error --summary-interval=10 --header={user_header} -c -x 16 -k 1M -s 16 -d {dst} -Z {url}
  else:
    !aria2c --console-log-level=error --summary-interval=10 -c -x 16 -k 1M -s 16 -d {dst} -Z {url}

install(modelUrl)


In [None]:
#@title ## 2.3. Download Available VAE
%store -r 

%cd {root_dir}

installVae = []
#@markdown ### Available VAE
#@markdown Select one of the VAEs to download, select `none` for not download VAE:
vaeUrl = ["", \
          "https://huggingface.co/Linaqruf/personal_backup/resolve/main/animevae/animevae.pt", \
          "https://huggingface.co/hakurei/waifu-diffusion-v1-4/resolve/main/vae/kl-f8-anime.ckpt", \
          "https://huggingface.co/stabilityai/sd-vae-ft-mse-original/resolve/main/vae-ft-mse-840000-ema-pruned.ckpt"]
vaeList = ["none", \
           "anime.vae.pt", \
           "waifudiffusion.vae.pt", \
           "stablediffusion.vae.pt"]
vaeName = "anime.vae.pt" #@param ["none", "anime.vae.pt", "waifudiffusion.vae.pt", "stablediffusion.vae.pt"]

installVae.append((vaeName, vaeUrl[vaeList.index(vaeName)]))

def install(vae_name, url):
  hf_token = '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'
  user_header = f"\"Authorization: Bearer {hf_token}\""
  !aria2c --console-log-level=error --summary-interval=10 --header={user_header} -c -x 16 -k 1M -s 16 -o vae/{vae_name} "{url}"

def install_vae():
  if vaeName != "none":
    for vae in installVae:
      install(vae[0], vae[1])
  else:
    pass

install_vae()

# III. Data Acquisition

You can either upload your own dataset to this notebook or use the image scraper below to bulk-download images from Danbooru or Gelbooru.

If you want to use your own dataset, upload it to the Colab runtime's local files (the `Files` panel).


In [None]:
#@title ## 3.1. Define Train Data Directory
#@markdown Define where your train data will be located. This cell will also create a folder based on your input. 
#@markdown This folder will be used as the target folder for scraping, tagging, bucketing, and training in the next cell.

import os
%store -r

train_data_dir = "/content/fine_tune/train_data" #@param {'type' : 'string'}
%store train_data_dir

if not os.path.exists(train_data_dir):
    os.makedirs(train_data_dir)
else:
    print(f"{train_data_dir} already exists\n")

print(f"Your train data directory : {train_data_dir}")

In [None]:
#@title ## 3.2. Clone Dataset Repository (Optional)
#@markdown *Optional but can be useful for resume training process, because you will need that `last-state` folder*
%store -r

%cd {root_dir}

#@markdown ### Define Parameters
repository_url = "https://huggingface.co/datasets/Linaqruf/hitokomoru-tag-v2"  #@param {'type': 'string'}

#@markdown ### Leave it empty if your datasets is on `main` branch
branch = "" #@param {'type': 'string'}

!git lfs install
if branch != "":
  !git clone --branch {branch} {repository_url}
else:
  !git clone {repository_url}

In [None]:
#@title ## 3.3. Download dataset (.zip)
import os
import shutil
from pathlib import Path
%store -r

#@markdown ### Define download parameter
zipfile_url = "https://huggingface.co/datasets/Linaqruf/anijourneydb/resolve/main/anijourneydbv1_512.zip" #@param {'type': 'string'}
zipfile_path = str(root_dir)+"/train_data.zip"

dirname = os.path.dirname(zipfile_path)
basename = os.path.basename(zipfile_path)

try:
  if zipfile_url.startswith("https://drive.google.com"):
    !gdown --fuzzy  {zipfile_url}
  elif zipfile_url.startswith("magnet:?"):
    !aria2c --summary-interval=10 -c -x 10 -k 1M -s 10 {zipfile_url}
  elif zipfile_url.startswith("https://huggingface.co/"):
    if '/blob/' in zipfile_url:
      zipfile_url = zipfile_url.replace('/blob/', '/resolve/')

    #@markdown Change this part with your own huggingface token if you need to download your private model
    hf_token = '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' #@param {type:"string"}
    user_header = f"\"Authorization: Bearer {hf_token}\""
    !aria2c --console-log-level=error --summary-interval=10 --header={user_header} -c -x 16 -k 1M -s 16 -d {dirname} -o {basename} {zipfile_url}
  else:
    !aria2c --console-log-level=error --summary-interval=10 -c -x 16 -k 1M -s 16 -d {dirname} -o {basename} {zipfile_url}

except Exception as e:
  print("An error occurred while downloading the file:", e)

#@markdown Tick this if you want to extract all files directly to `train_data_dir`, and automatically delete the zip to save the memory
auto_unzip_and_delete = True #@param{'type':'boolean'}

if auto_unzip_and_delete:
  !unzip -j {zipfile_path} -d {train_data_dir}

  # directory to check for files
  # JSON files to move
  files_to_move = ("meta_cap.json", \
                   "meta_cap_dd.json", \
                   "meta_lat.json", \
                   "meta_clean.json")

  # check each file in the directory
  for filename in os.listdir(train_data_dir):
      # get the full file path
      file_path = os.path.join(train_data_dir, filename)
      if filename in files_to_move:
          # move the file to the parent directory
          shutil.move(file_path, os.path.dirname(train_data_dir))
  
  path_obj = Path(zipfile_path)
  zipfile_name = path_obj.parts[-1]
  
  if os.path.isdir(zipfile_path):
    print("\nThis zipfile doesn't exist or has been deleted \n")
  else:
    os.remove(zipfile_path)
    print(f"\n{zipfile_name} has been deleted")

In [None]:
#@title ### 3.3.1. Unzip dataset (.zip)
import shutil
import os
import zipfile
from pathlib import Path

#@markdown ### Define unzip parameter
zipfile_src = '/content/train_data.zip' #@param{'type':'string'}
zipfile_dst = '/content/fine_tune/train_data' #@param{'type':'string'}

#@markdown ### Delete zipfile after unzip process done
delete_zipfile = True #@param{'type':'boolean'}

try:   
  !unzip -j {zipfile_src} -d {zipfile_dst}

except Exception as e:
  print("An error occurred while unzipping the file:", e)

# directory to check for files
  # JSON files to move
files_to_move = ("meta_cap.json", \
                 "meta_cap_dd.json", \
                 "meta_lat.json", \
                 "meta_clean.json")

  # check each file in the directory
for filename in os.listdir(zipfile_dst):
  # get the full file path
  file_path = os.path.join(zipfile_dst, filename)
  if filename in files_to_move:
    # move the file to the parent director
    shutil.move(file_path, os.path.dirname(zipfile_dst))
  
if delete_zipfile:
  path_obj = Path(zipfile_src)
  zipfile_name = path_obj.parts[-1]
  
  if os.path.isdir(zipfile_src):
    print("\nThis zipfile doesn't exist or has been deleted \n")
  else:
    os.remove(zipfile_src)
    print(f"\n{zipfile_name} has been deleted")

In [None]:
#@title ## 3.4. Simple Booru Scraper
#@markdown Use gallery-dl to scrape images from a booru site using the specified tags
import os
import html
%store -r 

%cd {root_dir}

# Set configuration options
booru = "Gelbooru" #@param ["", "Danbooru", "Gelbooru"]
tag1 = "hito_komoru" #@param {type: "string"}
tag2 = "" #@param {type: "string"}
download_tags = False #@param {type: "boolean"}
# Construct the search query
if tag2 != "":
  tags = tag1 + "+" + tag2
else:
  tags = tag1

if download_tags == True:
  write_tags = "--write-tags"
else:
  write_tags = ""

# Scrape images from the specified booru site using the given tags
if booru.lower() == "danbooru":
  !gallery-dl "https://danbooru.donmai.us/posts?tags={tags}" {write_tags} -D {train_data_dir}
elif booru.lower() == "gelbooru":
  !gallery-dl "https://gelbooru.com/index.php?page=post&s=list&tags={tags}" {write_tags} -D {train_data_dir}
else:
  print(f"Unknown booru site: {booru}")

if download_tags == True: 
  # Get a list of all the .txt files in the folder
  files = [f for f in os.listdir(train_data_dir) if f.endswith(".txt")]

  # Loop through each file
  for file in files:
      file_path = os.path.join(train_data_dir, file)

      # Read the contents of the file
      with open(file_path, "r") as f:
          contents = f.read()

      # Decode HTML entities and replace _ with a space
      contents = html.unescape(contents)
      contents = contents.replace("_", " ")

      # Split the contents on newline characters and join with commas
      contents = ", ".join(contents.split("\n"))

      # Write the modified contents back to the file
      with open(file_path, "w") as f:
          f.write(contents)

# IV. Data Preprocessing

In [None]:
#@title ## 4.1. Data Cleaning
#@markdown This will delete unnecessary files and unsupported media like `.mp4`, `.webm`, and `.gif`
%store -r

import os

%cd {root_dir}

test = os.listdir(train_data_dir)

#@markdown I recommend to `keep_metadata` especially if you're doing resume training and you have metadata and bucket latents file from previous training like `.npz`, `.txt`, `.caption`, and `json`.
keep_metadata = True #@param {'type':'boolean'}

# List of supported file types
if keep_metadata == True:
  supported_types = [".png", ".jpg", ".jpeg", ".webp", ".bmp", ".caption", ".npz", ".txt", ".json"]
else:
  supported_types = [".png", ".jpg", ".jpeg", ".webp", ".bmp"]

# Iterate over all files in the directory
for item in test:
    # Extract the file extension from the file name
    file_ext = os.path.splitext(item)[1]
    # If the file extension is not in the list of supported types, delete the file
    if file_ext not in supported_types:
        # Print a message indicating the name of the file being deleted
        print(f"Deleting file {item} from {train_data_dir}")
        # Delete the file
        os.remove(os.path.join(train_data_dir, item))

In [None]:
#@title ## 4.2. Data Annotation
%store -r
%cd {finetune_dir}

#@markdown We're using [BLIP](https://huggingface.co/spaces/Salesforce/BLIP) for image captioning and [Waifu Diffusion 1.4 Tagger](https://huggingface.co/spaces/SmilingWolf/wd-v1-4-tags) for image tagging like danbooru.

#@markdown Tick this if you want to label your dataset with natural language like this: <br>
#@markdown `a girl with long hair holding a cellphone`

Start_BLIP_Captioning = True #@param {type:"boolean"}

#@markdown or Tick this if you want to label your dataset with danbooru tag like this: <br>
#@markdown `1girl, solo, looking_at_viewer, short_hair, bangs, simple_background, shirt, black_hair, white_background, closed_mouth, choker, hair_over_one_eye, head_tilt, grey_eyes, black_shirt, floating_hair, black_choker, eyes_visible_through_hair, portrait`

Start_WD_1_4_Tagger = True #@param {type:"boolean"}

#@markdown or you can use them both

batch_size = 8

if Start_BLIP_Captioning == True:
  !python make_captions.py \
    {train_data_dir} \
    --batch_size {batch_size} \
    --caption_extension .caption
else:
  pass

if Start_WD_1_4_Tagger == True:
  !python tag_images_by_wd14_tagger.py \
    {train_data_dir} \
    --batch_size {batch_size} \
    --caption_extension .txt
else:
  pass

In [None]:
#@title ### 4.2.1. Append Custom Tag (Optional)
import os
%store -r

%cd {root_dir}

def clone_random_repo():
  # Check if the directory already exists
  if os.path.isdir(str(root_dir)+"/cafe-aesthetic-scorer/"):
    %cd {root_dir}/cafe-aesthetic-scorer/
    print("This folder already exists, will do a !git pull instead\n")
    !git pull
  else:
    !git clone https://github.com/Linaqruf/cafe-aesthetic-scorer/

clone_random_repo()

%cd {root_dir}/cafe-aesthetic-scorer/
#@markdown If you want to append custom tag, you can do that here. This cell will add custom tag at the beginning of lines
custom_tag = "masterpiece" #@param {type:"string"}
caption_extension = "txt" #@param ["txt","caption"]
#@markdown Tick this if you want to append custom tag at the end of lines instead
append = False #@param {type: "boolean"}

if append:
  append_tag = "--append"
else:
  append_tag = ""

!python custom_tagger.py \
  {train_data_dir} \
  {caption_extension} \
  {custom_tag} \
  {append_tag}   

In [None]:
#@title ## 4.3. Create JSON file for Finetuning
import os
%store -r

# Change the working directory
%cd {finetune_dir}

#@markdown ### Define Parameter
meta_clean = "/content/fine_tune/meta_clean.json" #@param {type:"string"}
#@markdown This cell will merge all dataset label from captioning, tagging, and custom tagging into one JSON file, and later it will be used as input JSON for bucketing section.

parent_folder = os.path.dirname(meta_clean)
meta_cap_dd = f"{parent_folder}/meta_cap_dd.json"
meta_cap = f"{parent_folder}/meta_cap.json"

# Check if directory exists
if not os.path.exists(parent_folder):
  # Create directory if it doesn't exist
  os.makedirs(parent_folder)

# Check if the train_data_dir exists and is a directory
if os.path.isdir(train_data_dir):
  # Check if there are any .caption files in the train_data_dir
  if any(file.endswith('.caption') for file in os.listdir(train_data_dir)):
    # Create meta_cap.json from captions
    !python merge_captions_to_metadata.py \
      {train_data_dir} \
      {meta_cap}

  # Check if there are any .txtn files in the train_data_dir
  if any(file.endswith('.txt') for file in os.listdir(train_data_dir)):
    # Create meta_cap_dd.json from tags
    !python merge_dd_tags_to_metadata.py \
      {train_data_dir} \
      {meta_cap_dd}
else:
  print("train_data_dir does not exist or is not a directory.")

# Merge meta_cap.json to meta_cap_dd.json
if os.path.exists(meta_cap) and os.path.exists(meta_cap_dd):
  !python merge_dd_tags_to_metadata.py \
    {train_data_dir} \
    --in_json {meta_cap} \
    {meta_cap_dd}

# Clean meta_cap_dd.json and store it to meta_clean.json
if os.path.exists(meta_cap_dd):
  # Clean captions and tags in meta_cap_dd.json and store the result in meta_clean.json
  !python clean_captions_and_tags.py \
    {meta_cap_dd} \
    {meta_clean}
elif os.path.exists(meta_cap):
  # If meta_cap_dd.json does not exist, clean meta_cap.json and store the result in meta_clean.json
  !python clean_captions_and_tags.py \
    {meta_cap} \
    {meta_clean}

In [None]:
#@title ## 4.4. Aspect Ratio Bucketing and Cache Latents
%store -r

# Change the working directory
%cd {finetune_dir}

#@markdown ### Define parameters
V2 = False #@param{type:"boolean"}
model_dir = "/content/pre_trained_model/Anything-v3-better-vae.ckpt" #@param {'type' : 'string'} 
input_json = "/content/fine_tune/meta_clean.json" #@param {'type' : 'string'} 
output_json = "/content/fine_tune/meta_lat.json"#@param {'type' : 'string'} 
batch_size = 4 #@param {'type':'integer'}
max_resolution = "512,512" #@param ["512,512", "640,640", "768,768"] {allow-input: false}
mixed_precision = "no" #@param ["no", "fp16", "bf16"] {allow-input: false}
flip_aug = True #@param{type:"boolean"}

if V2:
  SDV2 = "--v2"
else:
  SDV2 = ""

if flip_aug:
  flip_latents = "--v2"
else:
  flip_latents = "--flip_aug"

# Run script to prepare buckets and latents
!python prepare_buckets_latents.py \
  {train_data_dir} \
  {input_json} \
  {output_json} \
  {model_dir} \
  {SDV2} \
  {flip_latents} \
  --batch_size {batch_size} \
  --max_resolution {max_resolution} \
  --mixed_precision {mixed_precision}

# V. Training Model



In [None]:
#@title ## 5.1. Define Important folder
import os
%store -r

v2 = False #@param {type:"boolean"}
v_parameterization = False #@param {type:"boolean"}
output_name = "hito_komoru" #@param {type:"string"}
pretrained_model_name_or_path = "/content/pre_trained_model/Anything-v3-better-vae.ckpt" #@param {type:"string"}
vae = ""  #@param {type:"string"}
train_data_dir = "/content/fine_tune/train_data"  #@param {type:"string"}
%store train_data_dir
in_json = "/content/fine_tune/meta_lat.json" #@param {type:"string"}
output_dir = "/content/fine_tune/output" #@param {type:"string"}
resume_path = "" #@param {type:"string"}

#@markdown This will ignore `output_dir` defined above, and changed to `/content/drive/MyDrive/fine_tune/output` by default
output_to_drive = False #@param {'type':'boolean'}

if output_to_drive:
  output_dir = "/content/drive/MyDrive/fine_tune/output"

  if not os.path.exists("/content/drive"):
    drive.mount('/content/drive')  

# Check if directory exists
if not os.path.exists(output_dir):
  # Create directory if it doesn't exist
  os.makedirs(output_dir)

#V2 Inference
inference_url = "https://raw.githubusercontent.com/Stability-AI/stablediffusion/main/configs/stable-diffusion/"

if v2 and not v_parameterization:
  inference_url += "v2-inference.yaml"
if v2 and v_parameterization:
  inference_url += "v2-inference-v.yaml"

try:
  if v2:
    !wget {inference_url} -O {output_dir}/{output_name}.yaml
    print("File successfully downloaded")
except:
  print("There was an error downloading the file. Please check the URL and try again.")


In [None]:
from prettytable import PrettyTable
%store -r

#@title ## 5.2. Start Fine-Tuning
#@markdown ### Define Parameter
train_batch_size = 1 #@param {type:"number"}
train_text_encoder = False #@param {'type':'boolean'}
max_train_steps = 5000 #@param {type:"number"}
mixed_precision = "fp16" #@param ["no","fp16","bf16"] {allow-input: false}
save_precision = "float" #@param ["float", "fp16", "bf16"] {allow-input: false}
save_every_n_epochs = 0 
save_last_n_epochs = 0 
save_model_as = "ckpt" #@param ["ckpt", "safetensors", "diffusers", "diffusers_safetensors"] {allow-input: false}
resolution = 512 #@param {type:"number"}
max_token_length = 225 #@param {type:"number"}
clip_skip = 2 #@param {type:"number"}
learning_rate = 2e-6 #@param {type:"number"}
lr_scheduler = "constant" #@param  ["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"] {allow-input: false}
dataset_repeats = 1 #@param {type:"number"}
seed = 1234 #@param {type:"number"}
use_8bit_adam = True #@param {type:"boolean"}
gradient_checkpointing = False #@param {type:"boolean"}
gradient_accumulation_steps = 1 #@param {type:"number"}
logging_dir = repo_dir + "/logs"
additional_argument = "--save_state --shuffle_caption --xformers" #@param {type:"string"}
print_hyperparameter = True #@param {type:"boolean"}

%cd {repo_dir}

train_command=f"""
accelerate launch --config_file={accelerate_config} --num_cpu_threads_per_process 8 fine_tune.py \
  {"--v2" if v2 else ""} \
  {"--v_parameterization" if v2 and v_parameterization else ""} \
  --pretrained_model_name_or_path={pretrained_model_name_or_path} \
  {"--vae=" + vae if vae else ""} \
  --train_data_dir={train_data_dir} \
  --in_json={in_json} \
  --output_dir={output_dir} \
  {"--resume=" + resume_path if resume_path else ""} \
  {"--output_name=" + output_name if output_name else ""} \
  --mixed_precision={mixed_precision} \
  --save_precision={save_precision} \
  {"--save_every_n_epochs=" + save_every_n_epochs if save_every_n_epochs != 0 else ""} \
  {"--save_last_n_epochs=" + save_every_n_epochs if save_last_n_epochs != 0 else ""} \
  --save_model_as={save_model_as} \
  --resolution={resolution} \
  --train_batch_size={train_batch_size} \
  --max_token_length={max_token_length} \
  {"--train_text_encoder" if train_text_encoder else ""} \
  {"--use_8bit_adam" if use_8bit_adam else ""} \
  --learning_rate={learning_rate} \
  --lr_scheduler={lr_scheduler} \
  --lr_warmup_steps={lr_warmup_steps} \
  --dataset_repeats={dataset_repeats} \
  --max_train_steps={max_train_steps} \
  {"--seed=" + format(seed) if seed else ""} \
  {"--gradient_checkpointing" if gradient_checkpointing else ""} \
  {"--gradient_accumulation_steps=" + format(gradient_accumulation_steps) } \
  {"--clip_skip=" + clip_skip if v2=="" else ""} \
  --logging_dir={logging_dir} \
  --log_prefix={output_name} \
  {additional_argument}
  """
  
if print_hyperparameter:
    table = PrettyTable()
    table.field_names = ["Hyperparameter", "Value"]
    for params in debug_params:
        if params != "":
            if globals()[params] == "":
                value = "False"
            else:
                value = globals()[params]
            table.add_row([params, value])
    table.add_row(["train_command", train_command])
    table.align = "l"
    print(table)

f = open("./train.sh", "w")
f.write(train_command)
f.close()
!chmod +x ./train.sh
!./train.sh

# VI. Testing

In [None]:
#@title ## 6.1. Inference
%store -r

v2 = False #@param {type:"boolean"}
v_parameterization = False #@param {type:"boolean"}
prompt = "masterpiece, best quality, high quality, 1girl, solo, sitting, confident expression, long blonde hair, blue eyes, formal dress, jewelry, make-up, luxury, close-up, face, upper body." #@param {type: "string"}
negative = "worst quality, low quality, medium quality, deleted, lowres, comic, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, jpeg artifacts, signature, watermark, username, blurry" #@param {type: "string"}
model = "/content/fine_tune/output/last.ckpt" #@param {type: "string"}
vae = "" #@param {type: "string"}
outdir = "/content/tmp" #@param {type: "string"}
scale = 7 #@param {type: "slider", min: 1, max: 40}
sampler = "ddim" #@param ["ddim", "pndm", "lms", "euler", "euler_a", "heun", "dpm_2", "dpm_2_a", "dpmsolver","dpmsolver++", "dpmsingle", "k_lms", "k_euler", "k_euler_a", "k_dpm_2", "k_dpm_2_a"]
steps = 28 #@param {type: "slider", min: 1, max: 100}
precision = "fp16" #@param ["fp16", "bf16"] {allow-input: false}
width = 512 #@param {type: "integer"}
height = 768 #@param {type: "integer"}
images_per_prompt = 4 #@param {type: "integer"}
batch_size = 4 #@param {type: "integer"}
clip_skip = 2 #@param {type: "slider", min: 1, max: 40}
seed = -1 #@param {type: "integer"}

final_prompt = f"{prompt} --n {negative}"

%cd {repo_dir}

!python gen_img_diffusers.py \
  {"--v2" if v2 else ""} \
  {"--v_parameterization" if v2 and v_parameterization else ""} \
  --ckpt={model} \
  --outdir={outdir} \
  --xformers \
  {"--vae=" + vae if vae else ""} \
  --{precision} \
  --W={width} \
  --H={height} \
  {"--seed=" + format(seed) if seed > 0 else ""} \
  --scale={scale} \
  --sampler={sampler} \
  --steps={steps} \
  {"--max_embeddings_multiples=" + format(3)} \
  --batch_size={batch_size} \
  --images_per_prompt={images_per_prompt} \
  {"--clip_skip=" + clip_skip if v2=="" else ""} \
  --prompt="{final_prompt}"


In [None]:
#@title ## 6.2. Visualize loss graph (Optional)
training_logs_path = "/content/fine_tune/training_logs" #@param {type : "string"}

%cd {repo_dir}
%load_ext tensorboard
%tensorboard --logdir {training_logs_path}

# VII. Extras

In [None]:
#@title ## 7.1. Convert Diffusers to `.ckpt/.safetensors`
%store -r
%cd {tools_dir}


#@markdown ## Define model path
weight = "/content/fine_tune/output/last.ckpt" #@param {'type': 'string'}
weight_dir = os.path.dirname(weight)
base_name = os.path.splitext(os.path.basename(weight))[0]

convert = "ckpt_safetensors_to_diffusers" #@param ["diffusers_to_ckpt_safetensors", "ckpt_safetensors_to_diffusers"] {'allow-input': false}
#@markdown ___
#@markdown ## Conversion Config
#@markdown ___
#@markdown ### Diffusers to `.ckpt/.safetensors`
use_safetensors = False #@param {'type': 'boolean'}

save_precision = "--float" #@param ["--fp16","--bf16","--float"] {'allow-input': false}

#@markdown ### `.ckpt/.safetensors` to Diffusers
#@markdown is your model v1 or v2 based Stable Diffusion Model
version = "--v1" #@param ["--v1","--v2"] {'allow-input': false}
diffusers = os.path.join(weight_dir, base_name)

#@markdown Additional file for diffusers
feature_extractor = True #@param {'type': 'boolean'}
safety_checker = True #@param {'type': 'boolean'}

if use_safetensors:
    checkpoint = str(diffusers)+".safetensors"
else:
    checkpoint = str(diffusers)+".ckpt"

if version == "--v1":
  reference_model = "runwayml/stable-diffusion-v1-5"
elif version == "--v2":
  reference_model = "stabilityai/stable-diffusion-2-1"

if convert == "diffusers_to_ckpt_safetensors":
    if not weight.endswith(".ckpt") or weight.endswith(".safetensors"):
        !python convert_diffusers20_original_sd.py \
            "{weight}" \
            "{checkpoint}"" \
            {save_precision}

else:    
    !python convert_diffusers20_original_sd.py \
        "{weight}" \
        "{diffusers}" \
        {version} \
        --reference_model {reference_model} 

    url1 = "https://huggingface.co/CompVis/stable-diffusion-safety-checker/resolve/main/preprocessor_config.json"
    url2 = "https://huggingface.co/CompVis/stable-diffusion-safety-checker/resolve/main/config.json"
    url3 = "https://huggingface.co/CompVis/stable-diffusion-safety-checker/resolve/main/pytorch_model.bin"

    if feature_extractor == True:
      if not os.path.exists(str(diffusers)+'/feature_extractor'):
        os.makedirs(str(diffusers)+'/feature_extractor')
      
      !aria2c --console-log-level=error --summary-interval=10 -c -x 16 -k 1M -s 16 -d '{diffusers}/feature_extractor' -Z {url1}

    if safety_checker == True:
      if not os.path.exists(str(diffusers)+'/safety_checker'):
        os.makedirs(str(diffusers)+'/safety_checker')
      
      !aria2c --console-log-level=error --summary-interval=10 -c -x 16 -k 1M -s 16 -d '{diffusers}/safety_checker' -Z {url2}
      !aria2c --console-log-level=error --summary-interval=10 -c -x 16 -k 1M -s 16 -d '{diffusers}/safety_checker' -Z {url3}


In [None]:
import os
%store -r
#@title ## 7.2. Model Pruner

%cd {toolsdir}

if os.path.exists('prune.py'):
  pass
else:
  # Add a comment to explain what the code is doing
  # Download the pruning script if it doesn't already exist
  !wget https://raw.githubusercontent.com/lopho/stable-diffusion-prune/main/prune.py

#@markdown Convert to Float16
fp16 = False #@param {'type':'boolean'}
#@markdown Use EMA for weights
ema = False #@param {'type':'boolean'}
#@markdown Strip CLIP weights
no_clip = False #@param {'type':'boolean'}
#@markdown Strip VAE weights
no_vae = False #@param {'type':'boolean'}
#@markdown Strip depth model weights
no_depth = False #@param {'type':'boolean'}
#@markdown Strip UNet weights
no_unet = False #@param {'type':'boolean'}
#@markdown You need to input model ends with `.ckpt`, because `.safetensors` model won't work.

# NOTE: this name shadows the builtin `input` for the rest of the kernel session;
# it is kept because the form field and the prune command refer to it.
input = "/content/fine_tune/output/last.ckpt" #@param {'type' : 'string'}


# Notify the user that the model is being loaded
print(f"Loading model from {input}")

input_path = os.path.dirname(input)
base_name = os.path.basename(input)
# splitext removes only the final extension, so "model.v1.5.ckpt" keeps the
# stem "model.v1.5" instead of being cut at the first dot.
output_name = os.path.splitext(base_name)[0]

# One row per option: (ticked?, message shown to the user, suffix added to the name).
# Row order is the order in which suffixes are appended.
prune_options = [
    (fp16, "Converting to float16", '-fp16'),
    (ema, "Using EMA for weights", '-ema'),
    (no_clip, "Stripping CLIP weights", '-no-clip'),
    (no_vae, "Stripping VAE weights", '-no-vae'),
    (no_depth, "Stripping depth model weights", '-no-depth'),
    (no_unet, "Stripping UNet weights", '-no-unet'),
]
for enabled, message, suffix in prune_options:
    if enabled:
        print(message)
        output_name += suffix

# The pruned file is written next to the input, always with a .ckpt extension.
output_name += '-pruned'
output_path = os.path.join(input_path, output_name + '.ckpt')

!python3 prune.py "{input}" \
  "{output_path}" \
  {'--fp16' if fp16 else ''} \
  {'--ema' if ema else ''} \
  {'--no-clip' if no_clip else ''} \
  {'--no-vae' if no_vae else ''} \
  {'--no-depth' if no_depth else ''} \
  {'--no-unet' if no_unet else ''}

# Notify the user of the output file location
print(f"Saving pruned model to {output_path}")

In [None]:
#@title ## 7.3. Compressing model or dataset
import os
import zipfile
import shutil

zip_module = "zipfile" #@param ["zipfile", "shutil", "pyminizip", "zip"]
directory_to_zip = '/content/fine_tune/train_data' #@param {type: "string"}
output_filename = '/content/train_data.zip' #@param {type: "string"}
password = "" #@param {type: "string"}

if zip_module == "zipfile":
    with zipfile.ZipFile(output_filename, 'w') as zip:
        for directory_to_zip, dirs, files in os.walk(directory_to_zip):
            for file in files:
                zip.write(os.path.join(directory_to_zip, file))
elif zip_module == "shutil":
    shutil.make_archive(output_filename, 'zip', directory_to_zip)
elif zip_module == "pyminizip":
    !pip install pyminizip
    import pyminizip
    for root, dirs, files in os.walk(directory_to_zip):
        for file in files:
            pyminizip.compress(os.path.join(root, file), "", os.path.join("*",output_filename), password, 5)
elif zip_module == "zip":
    !zip -rv -q -j {output_filename} {directory_to_zip}

# VIII. Deployment

In [None]:
#@title ## 8.1. Define your Huggingface Repo

from huggingface_hub import HfApi
from huggingface_hub.utils import validate_repo_id, HfHubHTTPError
%store -r

api = HfApi()
user = api.whoami(write_token)

#@markdown #### If your model/dataset repo didn't exist, it will automatically create your repo.
model_name = "your-model-name" #@param{type:"string"}
dataset_name = "your-dataset-name" #@param{type:"string"}
make_this_model_private = True #@param{type:"boolean"}
clone_with_git = True #@param{type:"boolean"}

model_repo = user['name']+"/"+model_name.strip()
datasets_repo = user['name']+"/"+dataset_name.strip()

validate_repo_id(model_repo)
validate_repo_id(datasets_repo)

if make_this_model_private:
  private_repo = True
else:
  private_repo = False

if model_name != "":
  try:
      api.create_repo(repo_id=model_repo, 
                      private=private_repo)
      print("Model Repo didn't exists, creating repo")
      print("Model Repo: ",model_repo,"created!\n")

  except HfHubHTTPError as e:
      print(f"Model Repo: {model_repo} exists, skipping create repo\n")

if dataset_name != "":
  try:
      api.create_repo(repo_id=datasets_repo,
                      repo_type="dataset",
                      private=private_repo)
      print("Dataset Repo didn't exists, creating repo")
      print("Dataset Repo",datasets_repo,"created!\n")

  except HfHubHTTPError as e:
      print(f"Dataset repo: {datasets_repo} exists, skipping create repo\n")

if clone_with_git:
  !sudo apt-get remove git-lfs

  if model_name != "":
    !git clone https://huggingface.co/{model_repo} /content/{model_name}
  
  if dataset_name != "":
    !git clone https://huggingface.co/datasets/{datasets_repo} /content/{dataset_name}

## 8.2. Upload to Huggingface

In [None]:
#@title ### 8.2.1. Commit using Git 
%cd {root_dir}

#@markdown Tick which repo you want to commit
commit_model = True #@param {'type': 'boolean'}
commit_dataset = True #@param {'type': 'boolean'}

#@markdown Set **git commit identity**
email = "your-email" #@param {'type': 'string'}
name = "your-username" #@param {'type': 'string'}
#@markdown Set **commit message**
commit_m = "feat: upload prototype model" #@param {'type': 'string'}

!sudo apt-get install git-lfs

!git config --global user.email "{email}"
!git config --global user.name "{name}"

if commit_model:
  %cd {root_dir}/{model_name}
  !huggingface-cli lfs-enable-largefiles .
  !git add .
  !git lfs help smudge
  !git commit -m "{commit_m}"
  !git push

if commit_dataset:
  %cd {root_dir}/{dataset_name}
  !huggingface-cli lfs-enable-largefiles .
  !git add .
  !git lfs help smudge
  !git commit -m "{commit_m}"
  !git push

In [None]:
#@title ### 8.2.2. Quick Upload to Huggingface
from huggingface_hub import HfApi
from pathlib import Path

# Hub client used for every upload_file / upload_folder call in this cell.
api = HfApi()

#@markdown #### This will be uploaded to model repo
# A path ending in .ckpt/.safetensors/.pt is uploaded as a single file; anything else as a folder.
model_path = "/content/fine_tune/output/last.ckpt" #@param {type :"string"}

# Only consulted for folder paths: when ticked, the folder contents go to the repo root
# instead of a sub-folder named after the folder.
is_diffusers_model = False #@param {type:"boolean"}

#@markdown #### This will be uploaded to datasets repo, leave it empty if not necessary
last_state_path = "/content/fine_tune/output/last-state" #@param {type :"string"}
# NOTE(review): cell 7.3 defaults to /content/fine_tune/train_data for the training data;
# confirm whether the extra "output" segment here is intended.
train_data_path = "/content/fine_tune/output/train_data" #@param {type :"string"}
# NOTE(review): this path is sent with upload_file (a single file) but the default has no
# extension; possibly "meta_lat.json" was meant - confirm.
meta_lat_path = "/content/fine_tune/output/meta_lat_json" #@param {type :"string"}

#@markdown ##### `Nerd stuff, only if you want to save training logs`
logs_path = "/content/fine_tune/logs" #@param {type :"string"}

#@markdown #### Other Information
# Used as the commit message of every upload made by this cell.
commit_message = "feat: upload a model and dataset" #@param {type :"string"}

# Upload the trained model to the model repo, skipping the whole step when no path is given.
if model_path != "":
  path_obj = Path(model_path)
  # Last path component: the file or folder name used inside the repo.
  trained_model = path_obj.parts[-1]

  # ignore_patterns are matched against each file's path relative to the uploaded
  # folder, so wildcards are required to exclude files *inside* checkpoint folders
  # at any depth; the bare directory name would only match a file with that exact path.
  notebook_checkpoint_patterns = [".ipynb_checkpoints/**", "**/.ipynb_checkpoints/**"]

  print(f"Uploading {trained_model} to https://huggingface.co/"+model_repo)
  print(f"Please wait...")

  if model_path.endswith((".ckpt", ".safetensors", ".pt")):
    # Single checkpoint file: stored at the repo root under its own name.
    api.upload_file(
        path_or_fileobj=model_path,
        path_in_repo=trained_model,
        repo_id=model_repo,
        commit_message=commit_message,
    )
    
    print(f"Upload success, located at https://huggingface.co/"+model_repo+"/blob/main/"+trained_model+"\n")
  
  elif is_diffusers_model:
    # Diffusers folder: its contents go directly into the repo root.
    api.upload_folder(
        folder_path=model_path,
        repo_id=model_repo,
        commit_message=commit_message,
        ignore_patterns=notebook_checkpoint_patterns
    )
    print(f"Upload success, located at https://huggingface.co/"+model_repo+"/tree/main\n")
  
  else:
    # Any other folder: kept as a sub-folder named after the local folder.
    api.upload_folder(
        folder_path=model_path,
        path_in_repo=trained_model,
        repo_id=model_repo,
        commit_message=commit_message,
        ignore_patterns=notebook_checkpoint_patterns
    )
    print(f"Upload success, located at https://huggingface.co/"+model_repo+"/tree/main/"+trained_model+"\n")
def upload_to_dataset_repo(local_path, is_folder=True):
  """Upload one local file or folder to the dataset repo `datasets_repo`.

  Args:
    local_path: Path of the file or folder to upload.
    is_folder: True to send a directory with upload_folder, False to send a
      single file with upload_file.

  The entry is stored in the repo under the last component of `local_path`,
  and every upload uses `commit_message` as its commit message.
  """
  entry_name = Path(local_path).parts[-1]

  print(f"Uploading {entry_name} to https://huggingface.co/datasets/"+datasets_repo)
  print(f"Please wait...")

  if is_folder:
    api.upload_folder(
        folder_path=local_path,
        path_in_repo=entry_name,
        repo_id=datasets_repo,
        repo_type="dataset",
        commit_message=commit_message,
        # Patterns are matched against each file's path relative to the folder, so
        # wildcards are required to exclude files inside checkpoint folders at any depth.
        ignore_patterns=[".ipynb_checkpoints/**", "**/.ipynb_checkpoints/**"],
    )
    view = "tree"
  else:
    api.upload_file(
        path_or_fileobj=local_path,
        path_in_repo=entry_name,
        repo_id=datasets_repo,
        repo_type="dataset",
        commit_message=commit_message,
    )
    view = "blob"

  # Folders are browsed under /tree/, single files under /blob/.
  print(f"Upload success, located at https://huggingface.co/datasets/"+datasets_repo+"/"+view+"/main/"+entry_name+"\n")


# An empty path means "skip this upload".
if last_state_path != "":
  upload_to_dataset_repo(last_state_path)

if train_data_path != "":
  upload_to_dataset_repo(train_data_path)

if meta_lat_path != "":
  upload_to_dataset_repo(meta_lat_path, is_folder=False)

if logs_path != "":
  upload_to_dataset_repo(logs_path)