<a href="https://colab.research.google.com/github/Linaqruf/kohya-trainer/blob/main/tools/code-snippet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
#@title Auto-move file 
#@markdown This cell moves training artifacts out of the output directory into the cloned
#@markdown Hugging Face repositories: `last.ckpt` goes to the model repository, while the
#@markdown `last-state` folder, the `train_data` directory and the `meta_lat.json` file go
#@markdown to the datasets repository. Clone both repositories before running it.
#@markdown Nothing is overwritten: an item is moved only if its source exists and its
#@markdown destination does not, and a message is printed in every case.
#@markdown **Afterwards, every other entry in the datasets repository is deleted**
#@markdown (the repository's `.git` folder and `.gitattributes` file are kept).
import os
import shutil

# The path of the output directory
output_dir = '/content/kohya-trainer/fine_tuned/' #@param {'type':'string'}

# The name of the model
model_name = 'momoko30k' #@param {'type':'string'}

# The path of the cloned model repository
cloned_model_repo = '/content/momoko' #@param {'type':'string'}

# The name of the save state
save_state_name = 'momoko30k-state' #@param {'type':'string'}

# The path of the cloned datasets repository
cloned_datasets_repo = '/content/momoko-tag' #@param {'type':'string'}

# The path of the meta lat json file
meta_lat_json_dir = "/content/kohya-trainer/meta_lat.json" #@param {'type':'string'}

# The path of the train data directory
train_data_dir = "/content/kohya-trainer/train_data" #@param {'type':'string'}


def _move_if_absent(src, dst):
    """Move ``src`` to ``dst`` unless ``dst`` already exists.

    Prints whether the move happened, the destination was already present,
    or the source was missing. Never overwrites an existing destination.
    """
    if not os.path.exists(src):
        print(f'There is no {src} like that\n', flush=True)
    elif os.path.exists(dst):
        print(f'{dst} already exists\n', flush=True)
    else:
        shutil.move(src, dst)
        print(f'Moved {src} to {dst}\n', flush=True)


# NOTE(review): `opt_out` is not defined in this snippet — presumably a
# boolean set by an earlier notebook cell; confirm. When falsy, this cell
# does nothing.
if opt_out:
    dst_train_data_dir = os.path.join(cloned_datasets_repo, 'train_data')
    dst_meta_lat_json_dir = os.path.join(cloned_datasets_repo, 'meta_lat.json')

    # Model checkpoint -> model repo; everything else -> datasets repo.
    _move_if_absent(os.path.join(output_dir, 'last.ckpt'),
                    os.path.join(cloned_model_repo, f'{model_name}.ckpt'))
    _move_if_absent(os.path.join(output_dir, 'last-state'),
                    os.path.join(cloned_datasets_repo, save_state_name))
    _move_if_absent(train_data_dir, dst_train_data_dir)
    _move_if_absent(meta_lat_json_dir, dst_meta_lat_json_dir)

    # Remove everything else from the datasets repo so only the fresh
    # artifacts remain. Git metadata is kept on purpose: deleting `.git`
    # would turn the clone into a plain folder that can no longer be
    # committed or pushed, and `.gitattributes` carries the LFS rules.
    keep = {
        save_state_name,
        os.path.basename(dst_train_data_dir),
        os.path.basename(dst_meta_lat_json_dir),
        '.git',
        '.gitattributes',
    }
    if not os.path.isdir(cloned_datasets_repo):
        print(f'There is no {cloned_datasets_repo} like that\n', flush=True)
    else:
        for filename in os.listdir(cloned_datasets_repo):
            if filename in keep:
                continue
            file_path = os.path.join(cloned_datasets_repo, filename)
            # rmtree refuses symlinks, so a linked folder is unlinked instead.
            if os.path.isdir(file_path) and not os.path.islink(file_path):
                shutil.rmtree(file_path)
                print(f'Deleted folder: {filename}')
            else:
                os.remove(file_path)
                print(f'Deleted file: {filename}')



In [None]:
#@title Using epochs instead of max training step
#@markdown ### Define Parameters

import glob
import math
import os
import urllib.request

V2 = "none" #@param ["none", "V2_base", "V2_768_v"] {allow-input: false}
num_cpu_threads_per_process = 8 #@param {'type':'integer'}
save_state = True #@param {'type':'boolean'}
train_batch_size = 4  #@param {type: "slider", min: 1, max: 10}
learning_rate ="1e-4" #@param {'type':'string'}
num_epoch = 2 #@param {'type':'integer'}
dataset_repeats = 1 #@param {'type':'integer'}
train_text_encoder = False #@param {'type':'boolean'}
lr_scheduler = "constant" #@param  ["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"] {allow-input: false}
max_token_length = "225" #@param  ["150", "225"] {allow-input: false}
clip_skip = 2 #@param {type: "slider", min: 1, max: 10}
mixed_precision = "fp16" #@param ["no","fp16","bf16"] {allow-input: false}
save_model_as = "ckpt" #@param ["default", "ckpt", "safetensors", "diffusers", "diffusers_safetensors"] {allow-input: false}
save_precision = "None" #@param ["None","float", "fp16", "bf16"] {allow-input: false}
save_every_n_epochs = 50 #@param {'type':'integer'}
gradient_accumulation_steps = 1 #@param {type: "slider", min: 1, max: 10}
#@markdown ### Log And Debug
log_prefix = "fine-tune-style1" #@param {'type':'string'}
logs_dst = "/content/fine_tune/training_logs" #@param {'type':'string'}
debug_mode = False #@param {'type':'boolean'}

# This cell turns the form values into command-line flags for `fine_tune.py`
# and derives `max_train_steps` from the requested number of epochs.
# `output_dir` and `train_data_dir` come from the auto-move cell.
# NOTE(review): `resume_path` is not defined in this snippet — presumably set
# by an earlier notebook cell; confirm.

# Make sure the training output directory exists (no-op if it already does).
os.makedirs(output_dir, exist_ok=True)

# --- V2 model flags and inference config ----------------------------------
inference_url = "https://raw.githubusercontent.com/Stability-AI/stablediffusion/main/configs/stable-diffusion/"
v2_model = ""
v2_768v_model = ""
if V2 == "V2_base":
    v2_model = "--v2"
    inference_url += "v2-inference.yaml"
elif V2 == "V2_768_v":
    v2_model = "--v2"
    v2_768v_model = "--v2_parameterization"
    inference_url += "v2-inference-v.yaml"

if v2_model:
    # Downloaded in Python rather than via `!wget`: a shell escape never
    # raises, so a try/except around it could not detect a failed download.
    try:
        with urllib.request.urlopen(inference_url, timeout=60) as response:
            config_bytes = response.read()
    except OSError as err:  # URLError/HTTPError/timeouts all derive from OSError
        print(f"There was an error downloading the file. Please check the URL and try again. ({err})")
    else:
        with open(os.path.join(output_dir, "last.yaml"), "wb") as config_file:
            config_file.write(config_bytes)
        print("File successfully downloaded")

# --- Optional trainer flags (empty string = flag omitted) -----------------
penultimate_layer = f"--clip_skip={clip_skip}" if V2 == "none" else ""
sv_model = "" if save_model_as == "default" else f"--save_model_as {save_model_as}"
sv_state = "--save_state" if save_state else ""
rs_state = f"--resume {resume_path}" if resume_path else ""
save_epoch = f"--save_every_n_epochs={save_every_n_epochs}" if save_every_n_epochs > 0 else ""
sv_precision = "" if save_precision == "None" else f"--save_precision={save_precision}"
debug_dataset = "--debug_dataset" if debug_mode else ""
text_encoder = "--train_text_encoder" if train_text_encoder else ""

# --- Epochs -> max_train_steps ---------------------------------------------
# Each cached latent (.npz) in the train data dir counts as one training image.
image_num = len(glob.glob(os.path.join(train_data_dir, "*.npz")))
if image_num == 0:
    print(f"WARNING: no .npz files found in {train_data_dir}; max_train_steps will be 0.")

print("Total Train Data =", image_num)
print("Total Epoch=", num_epoch)
print("Dataset repeats =", dataset_repeats, "x")
repeats = image_num * dataset_repeats
print("Total Repeats =", image_num, "*", dataset_repeats, "=", repeats)

# Round per epoch: a trailing partial batch still costs a full step.
batches_per_epoch = math.ceil(repeats / train_batch_size)
# NOTE(review): assumes fine_tune.py counts max_train_steps in optimizer
# updates (one per `gradient_accumulation_steps` batches) — confirm.
update_steps_per_epoch = math.ceil(batches_per_epoch / gradient_accumulation_steps)
max_train_steps = update_steps_per_epoch * num_epoch
print("Steps per epoch = ceil(ceil(", repeats, "/", train_batch_size, ") /",
      gradient_accumulation_steps, ") =", update_steps_per_epoch)
print("max_train_steps =", update_steps_per_epoch, "*", num_epoch, "=", max_train_steps, "\n")

# Launch `fine_tune.py` through `accelerate` from the kohya-trainer checkout.
# IPython substitutes each `{name}` below from the notebook namespace; the
# optional flags built in the cell above are empty strings when disabled, so
# they expand to nothing and the flag is simply omitted.
# NOTE(review): `accelerate_config` and `pre_trained_model_path` are not
# defined in this snippet — presumably set by an earlier cell; confirm.
# NOTE(review): values are substituted unquoted, so a path containing spaces
# would be split into several shell arguments.
%cd /content/kohya-trainer

!accelerate launch \
  --config_file {accelerate_config} \
  --num_cpu_threads_per_process {num_cpu_threads_per_process} \
  fine_tune.py \
  {v2_model} \
  {v2_768v_model} \
  --pretrained_model_name_or_path={pre_trained_model_path} \
  --in_json {meta_lat_json_dir} \
  --train_data_dir={train_data_dir} \
  --output_dir={output_dir} \
  --shuffle_caption \
  --keep_tokens 1 \
  --train_batch_size={train_batch_size} \
  --learning_rate={learning_rate} \
  --lr_scheduler={lr_scheduler} \
  --max_token_length={max_token_length} \
  {penultimate_layer} \
  --mixed_precision={mixed_precision} \
  --max_train_steps={max_train_steps} \
  --use_8bit_adam \
  --xformers \
  --gradient_checkpointing \
  --gradient_accumulation_steps {gradient_accumulation_steps} \
  {sv_model} \
  {text_encoder} \
  {sv_state} \
  {rs_state} \
  {save_epoch} \
  {sv_precision} \
  {debug_dataset} \
  --logging_dir={logs_dst} \
  --log_prefix {log_prefix}

