<a href="https://colab.research.google.com/github/Linaqruf/sd-notebook-collection/blob/main/sd-colab-toolkit.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# I. Installation

In [None]:
#@title ## 1.1 Install Colab Toolkit
#@markdown This will install required Python packages
import os
import zipfile
import shutil
import subprocess
import requests
import re
import time
from tqdm import tqdm
from urllib.parse import urlparse, unquote

# Working-directory layout. Every other path is derived from root_dir, so the
# later cells only ever use these names instead of hard-coded paths.
root_dir    = "/content"
repo_dir    = os.path.join(root_dir, "sd-scripts")
models_dir  = os.path.join(root_dir, "models")
vaes_dir    = os.path.join(root_dir, "vae")
lora_dir    = os.path.join(root_dir, "network_weight")
deps_dir    = os.path.join(root_dir, "deps")
drive_dir   = os.path.join(root_dir, "drive/MyDrive")
tools_dir   = os.path.join(repo_dir, "tools")

# Repository cloned into repo_dir by main().
repo_url    = "https://github.com/kohya-ss/sd-scripts"

def cprint(*args, color="default", reset=True, tqdm_desc=False):
    """Print the space-joined arguments in an ANSI colour, or return them.

    Args:
        *args: Values to show; each is converted with ``str`` and joined by spaces.
        color: One of "default", "green", "red", "bold_green", "bold_red".
            Any other value produces uncoloured text.
        reset: When printing, append the reset code after the text.
        tqdm_desc: Return the text (colour prefix, no reset suffix) instead of
            printing it, so it can be passed as a tqdm description.

    Returns:
        The formatted string when ``tqdm_desc`` is true, otherwise ``None``.
    """
    palette = {
        "default"     : "\033[0m",
        "green"       : "\033[0;32m",
        "red"         : "\033[0;31m",
        "bold_green"  : "\033[1;32m",
        "bold_red"    : "\033[1;31m",
    }

    prefix = palette.get(color)

    # Unknown colour name: plain text, no escape codes at all.
    if prefix is None:
        if tqdm_desc:
            return " ".join(str(arg) for arg in args)
        print(*args)
        return None

    message = " ".join(map(str, args))

    # A tqdm description keeps the colour open on purpose (no reset suffix).
    if tqdm_desc:
        return prefix + message

    suffix = palette["default"] if reset else ""
    print(prefix + message + suffix)

def clone_repo(url, dir):
    """Clone ``url`` into ``dir`` with git, unless ``dir`` already exists.

    Raises:
        subprocess.CalledProcessError: if ``git clone`` exits with an error.
    """
    if os.path.exists(dir):
        return
    subprocess.run(["git", "clone", url, dir], check=True)

def ubuntu_deps(url, dst, desc):
    """Download a zip of ``.deb`` packages and install each one with dpkg.

    Args:
        url: Location of the zip archive.
        dst: Scratch directory the archive is unpacked into; removed afterwards.
        desc: Description shown on the tqdm progress bar.

    Raises:
        subprocess.CalledProcessError: if the archive cannot be downloaded.
    """
    os.makedirs(dst, exist_ok=True)
    filename = get_filename(url)

    try:
        # -O pins the saved name to the one opened below (plain `wget url` picks
        # its own name and appends ".1" on a re-run); check=True stops here on a
        # failed download instead of failing later on a missing/partial zip.
        subprocess.run(["wget", "-O", filename, url], stdout=subprocess.DEVNULL, check=True)

        with zipfile.ZipFile(filename, "r") as deps:
            deps.extractall(dst)

        for file in tqdm(os.listdir(dst), desc=desc):
            if file.endswith(".deb"):
                # Best effort: a package that fails to install does not abort the rest.
                subprocess.run(["dpkg", "-i", os.path.join(dst, file)], stdout=subprocess.DEVNULL)
    finally:
        # Clean up the archive and scratch directory even when a step failed.
        if os.path.exists(filename):
            os.remove(filename)
        shutil.rmtree(dst, ignore_errors=True)

def get_filename(url):
    """Work out the file name a download from ``url`` will have.

    The name comes from the ``Content-Disposition`` response header when it
    carries a ``filename=`` value; otherwise it is the unquoted last segment
    of the URL path.

    Raises:
        requests.HTTPError: if the server answers with an error status.
    """
    # Only the headers are needed; the context manager releases the streamed
    # connection instead of leaving it open.
    with requests.get(url, stream=True) as response:
        response.raise_for_status()
        content_disposition = response.headers.get('content-disposition', '')

    # A header without a filename (e.g. just "inline") falls back to the URL
    # path instead of raising IndexError.
    matches = re.findall('filename="?([^"]+)"?', content_disposition)
    if matches:
        return matches[0]

    url_path = urlparse(url).path
    return unquote(os.path.basename(url_path))

def update_requirements(repo_dir, desired_module, desired_version, filepath):
    """Pin ``desired_module`` to ``desired_version`` inside a requirements file.

    Only lines whose requirement name *is* ``desired_module`` are rewritten;
    other packages that merely contain the name (``requests-oauthlib`` when
    pinning ``requests``) and comment lines are left untouched.

    Args:
        repo_dir: Unused; kept so existing callers keep working.
        desired_module: Package name to pin.
        desired_version: Version written after ``==``.
        filepath: Requirements file, rewritten in place.
    """
    def _canonical(name):
        # Package names compare case-insensitively with -, _ and . equivalent.
        return re.sub(r"[-_.]+", "-", name).lower()

    target = _canonical(desired_module)
    # Leading requirement name of a line; comments and option lines ("-r ...")
    # do not start with a name character and therefore never match.
    name_pattern = re.compile(r"\s*([A-Za-z0-9][A-Za-z0-9._-]*)")

    with open(filepath, "r") as f:
        lines = f.readlines()

    updated_lines = []
    for line in lines:
        match = name_pattern.match(line)
        if match and _canonical(match.group(1)) == target:
            line = f"{desired_module}=={desired_version}\n"
        updated_lines.append(line)

    with open(filepath, "w") as f:
        f.writelines(updated_lines)

def install_dependencies():
    """Install the apt packages, the RAM patch and the Python requirements.

    Steps, in order: apt packages, the zipped ``.deb`` patch (via
    ``ubuntu_deps``), a ``requests`` pin in the repo's requirements file,
    then ``gdown`` and the requirements themselves with pip.
    """
    ram_patch_url = "https://huggingface.co/Linaqruf/fast-repo/resolve/main/ram_patch.zip"
    requirements_file = os.path.join(repo_dir, "requirements.txt")
    apt_packages = ["aria2", "lz4", "libunwind8-dev"]

    cprint("Installing ubuntu dependencies...", color="green")
    subprocess.run(["apt", "install", *apt_packages, "-y"], check=True)

    patch_desc = cprint("Installing RAM allocation patch", color="green", tqdm_desc=True)
    ubuntu_deps(ram_patch_url, deps_dir, patch_desc)

    cprint("Installing requirements...", color="green")
    # The pin is written before pip reads the requirements file.
    update_requirements(repo_dir, "requests", "2.27.1", requirements_file)

    pip_upgrade = ["pip", "install", "--upgrade"]
    subprocess.run(pip_upgrade + ["--no-cache-dir", "gdown"], check=True)
    subprocess.run(pip_upgrade + ["-r", requirements_file], cwd=repo_dir, check=True)

def calculate_elapsed_time(start_time):
    """Describe the whole seconds elapsed since ``start_time``.

    Args:
        start_time: A ``time.time()`` timestamp taken earlier.

    Returns:
        ``"<n> sec"`` below one minute, otherwise ``"<m> mins <s> sec"``.
    """
    total_seconds = int(time.time() - start_time)

    if total_seconds >= 60:
        minutes, seconds = divmod(total_seconds, 60)
        return f"{minutes} mins {seconds} sec"

    return f"{total_seconds} sec"

def main():
    """Set up the workspace: create folders, clone the repo, install dependencies."""
    os.chdir(root_dir)
    started_at = time.time()

    for folder in (models_dir, vaes_dir):
        os.makedirs(folder, exist_ok=True)

    cprint("Installing 'kohya-ss/sd-scripts'...", color="green")
    clone_repo(repo_url, repo_dir)
    install_dependencies()

    took = calculate_elapsed_time(started_at)

    cprint(f"\nFinished installation. Took {took}.", color="green")
    cprint("All is done! Go to the next step.", color="green")

# Run the installation as soon as this cell is executed.
main()


In [None]:
# @title ## **1.2. Download Model**
import os
import re
import json
import glob
import gdown
import time
import requests
import subprocess
from IPython.utils import capture
from urllib.parse import urlparse, unquote
from pathlib import Path

# root_dir is defined by the installation cell (1.1); run that cell first.
os.chdir(root_dir)

# @markdown ### **Download Stable Diffusion Model**

# Each field accepts either a label from its dropdown or free text (a URL or a
# path); several entries may be given separated by commas.
model_url = "https://civitai.com/api/download/models/82446" #@param ["Anime Model", "Anything V3.1", "AnyLoRA", "ChilloutMix Ni", "Stable Diffusion V1.5", "Replicant V3", "Illuminati Diffusion V1.1", "Waifu Diffusion V1.5 Beta 3", "Stable Diffusion V2.1"] {allow-input: true}
# @markdown ### **Download VAE Model**
vae_url = "" #@param ["", "Anime / Anything VAE", "Blessed VAE", "Waifu Diffusion VAE", "Stable Diffusion VAE"] {allow-input: true}
# @markdown ### **Download LoRA Model**
lora_url = "" #@param {type: "string"}

# Dropdown label -> download URL for the model field.
available_models = {
    # SDv1.x Pretrained Model
    "Anime Model"                 : "https://huggingface.co/Linaqruf/personal-backup/resolve/main/models/animefull-final-pruned.ckpt",
    "Anything V3.1"               : "https://huggingface.co/Linaqruf/anything-v3.0/resolve/main/anything-v3-fp16-pruned.safetensors",
    "AnyLoRA"                     : "https://huggingface.co/Lykon/AnyLoRA/resolve/main/AnyLoRA_noVae_fp16-pruned.ckpt",
    "ChilloutMix Ni"              : "https://huggingface.co/naonovn/chilloutmix_NiPrunedFp32Fix/resolve/main/chilloutmix_NiPrunedFp32Fix.safetensors",
    "Stable Diffusion V1.5"       : "https://huggingface.co/Linaqruf/stolen/resolve/main/pruned-models/stable_diffusion_1_5-pruned.safetensors",
    # SDv2.x Pretrained Model
    "Replicant V3"                : "https://huggingface.co/gsdf/Replicant-V3.0/resolve/main/Replicant-V3.0_fp16.safetensors",
    "Illuminati Diffusion V1.1"   : "https://huggingface.co/Linaqruf/stolen/resolve/main/pruned-models/illuminatiDiffusionV1_v11.safetensors",
    "Waifu Diffusion V1.5 Beta 3" : "https://huggingface.co/waifu-diffusion/wd-1-5-beta3/resolve/main/wd-beta3-base-fp16.safetensors",
    "Stable Diffusion V2.1"       : "https://huggingface.co/stabilityai/stable-diffusion-2-1/resolve/main/v2-1_768-ema-pruned.safetensors",
}

# Dropdown label -> download URL for the VAE field.
available_vaes = {
    "Anime / Anything VAE"        : "https://huggingface.co/Linaqruf/personal-backup/resolve/main/vae/animevae.pt",
    "Blessed VAE"                 : "https://huggingface.co/NoCrypt/blessed_vae/resolve/main/blessed2.vae.pt",
    "Waifu Diffusion VAE"         : "https://huggingface.co/hakurei/waifu-diffusion-v1-4/resolve/main/vae/kl-f8-anime.ckpt",
    "Stable Diffusion VAE"        : "https://huggingface.co/stabilityai/sd-vae-ft-mse-original/resolve/main/vae-ft-mse-840000-ema-pruned.ckpt",
}

# A dropdown label is translated to its URL; any other value (a pasted URL or
# a path) is used exactly as entered.
if model_url is not None:
    valid_model_url = available_models.get(model_url, model_url)

if vae_url is not None:
    valid_vae_url = available_vaes.get(vae_url, vae_url)

def get_supported_extensions():
    """Return the file extensions accepted as model weights, as a tuple."""
    return (".ckpt", ".safetensors", ".pt", ".pth")

def get_filename(url, quiet=True):
    """Resolve the model file name behind ``url``.

    Google Drive paths and URLs that already end in a supported extension are
    resolved locally from their last path segment. Anything else is requested
    once: the name comes from the ``Content-Disposition`` header when it has a
    ``filename=`` value, otherwise from the URL path.

    Args:
        url: Download URL or ``/content/drive/MyDrive/...`` path.
        quiet: Unused; kept for compatibility with existing callers.

    Returns:
        The file name when it ends with a supported extension, else ``None``.

    Raises:
        requests.HTTPError: if the server answers with an error status.
    """
    extensions = tuple(get_supported_extensions())

    if url.startswith("/content/drive/MyDrive/") or url.endswith(extensions):
        filename = os.path.basename(url)
    else:
        # Only the headers are needed; the context manager releases the
        # streamed connection instead of leaving it open.
        with requests.get(url, stream=True) as response:
            response.raise_for_status()
            content_disposition = response.headers.get('content-disposition', '')

        # A header without a filename falls back to the URL path instead of
        # raising IndexError.
        matches = re.findall('filename="?([^"]+)"?', content_disposition)
        if matches:
            filename = matches[0]
        else:
            url_path = urlparse(url).path
            filename = unquote(os.path.basename(url_path))

    if filename and filename.endswith(extensions):
        return filename
    return None

def parse_args(config):
    """Turn a config mapping into a list of command-line arguments.

    Rules per entry:
      * key starting with ``_``: the value alone (positional argument);
      * ``True``: ``--key``; ``False``: omitted;
      * str, float or int: ``--key=value``;
      * anything else (e.g. ``None``): omitted.
    """
    cli = []

    for key, value in config.items():
        if key.startswith("_"):
            cli.append(f"{value}")
            continue

        # bool is tested before int because bool is an int subclass.
        if isinstance(value, bool):
            if value:
                cli.append(f"--{key}")
            continue

        if isinstance(value, (str, float, int)):
            cli.append(f"--{key}={value}")

    return cli

def aria2_download(dir, filename, url):
    """Download ``url`` into ``dir`` with aria2c (16 connections, resumable).

    Args:
        dir: Destination directory.
        filename: Output file name; ``None`` lets aria2c choose one.
        url: Source URL.

    An ``Authorization`` header is added only for hosts under huggingface.co
    and only when the ``HF_TOKEN`` environment variable is set. The token is
    read from the environment rather than embedded in the notebook, so no
    credential is published with the file.
    """
    hf_token = os.environ.get("HF_TOKEN")

    # Compare the parsed host, not a substring of the whole URL, so the token
    # is never sent to another site whose URL merely contains "huggingface.co".
    host = urlparse(url).hostname or ""
    is_huggingface = host == "huggingface.co" or host.endswith(".huggingface.co")
    user_header = f"Authorization: Bearer {hf_token}" if (hf_token and is_huggingface) else None

    aria2_config = {
        "console-log-level"         : "error",
        "summary-interval"          : 10,
        "header"                    : user_header,  # None is dropped by parse_args
        "continue"                  : True,
        "max-connection-per-server" : 16,
        "min-split-size"            : "1M",
        "split"                     : 16,
        "dir"                       : dir,
        "out"                       : filename,
        "_url"                      : url,
    }
    aria2_args = parse_args(aria2_config)
    # No check=True: a failed download is left to the caller's file report.
    subprocess.run(["aria2c", *aria2_args])

def gdown_download(url, dst):
    """Download a Google Drive file or folder link into ``dst`` with gdown.

    Args:
        url: Drive link of the form ``.../uc?id=...``, ``.../file/d/...`` or
            ``.../drive/folders/...``.
        dst: Destination directory.

    Returns:
        Whatever gdown returns for the download, or ``None`` when the link
        matches none of the recognised forms.
    """
    # Direct-download links look like ".../uc?id=<id>"; the previous test for
    # "/uc?id/" never matched them, so such links were silently skipped.
    if "/uc?" in url:
        return gdown.download(url, dst + "/", quiet=False)
    elif "/file/d/" in url:
        # fuzzy=True lets gdown extract the file id from a share link.
        return gdown.download(url, dst + "/", quiet=False, fuzzy=True)
    elif "/drive/folders/" in url:
        # download_folder writes below the current directory.
        os.chdir(dst)
        return gdown.download_folder(url, quiet=True, use_cookies=False)
    return None

def download(urls, dst, target):
    """Fetch every entry of ``urls`` into ``dst``.

    Args:
        urls: Iterable of download URLs or ``/content/drive/MyDrive/...`` paths.
            Blank entries are ignored.
        dst: Destination directory; created when missing.
        target: Label used in the progress-bar description ("model", "vae", ...).

    Google Drive links go through gdown, Drive paths are copied, and everything
    else is downloaded with aria2c. Per-URL problems are collected and printed
    after the loop, because output produced inside ``capture_output`` is hidden.
    """
    import shutil  # local import: this cell's import list does not include shutil

    # aria2c creates its target directory itself, but gdown and the Drive copy
    # need it to exist (the LoRA folder is not created by the install cell).
    os.makedirs(dst, exist_ok=True)
    problems = []

    for url in tqdm(urls, desc=cprint(f"Downloading {target} from url", color="green", tqdm_desc=True)):
        url = url.replace(" ", "")
        if not url:
            continue

        with capture.capture_output():
            try:
                filename = get_filename(url, quiet=False)
            except Exception as e:
                problems.append(f"Skipping {url}: could not resolve a file name ({e})")
                continue

            if "drive.google.com" in url:
                try:
                    gdown_download(url, dst)
                except Exception as e:
                    problems.append(f"Error occurred: {str(e)}")
            elif url.startswith("/content/drive/MyDrive/"):
                if filename is None:
                    problems.append(f"Skipping {url}: unsupported file extension")
                    continue
                # copyfile streams the data instead of loading a multi-GB model
                # into memory at once.
                shutil.copyfile(url, os.path.join(dst, filename))
            else:
                # Hugging Face "blob" pages are HTML; "resolve" serves the file.
                if "huggingface.co" in url and "/blob/" in url:
                    url = url.replace("/blob/", "/resolve/")
                aria2_download(dst, filename, url)

    for message in problems:
        cprint(message, color="red")
        
def main():
    """Download the selected model, VAE and LoRA files and list what arrived.

    Each form value may hold several comma-separated entries. Blank entries
    and the "PASTE ... HERE" placeholder are ignored; a target with nothing
    left is skipped entirely.
    """
    global model_path, vae_path, lora_path

    start_time = time.time()
    model_path = vae_path = lora_path = None

    download_targets = {
        "model": (valid_model_url.split(','), models_dir),
        "vae": (valid_vae_url.split(','), vaes_dir),
        "lora": (lora_url.split(','), lora_dir),
    }

    for target, (urls, dst) in download_targets.items():
        placeholder = "PASTE {} URL OR GDRIVE PATH HERE".format(target.upper())
        # "".split(',') yields [""], so emptiness must be checked per entry.
        urls = [url for url in (entry.strip() for entry in urls) if url and url != placeholder]
        if not urls:
            continue

        # Snapshot the folder so only files that appear during this run are
        # reported as downloaded.
        initial_files = set(glob.glob(os.path.join(dst, "*")))
        download(urls, dst, target)

        current_files = glob.glob(os.path.join(dst, "*"))
        downloaded_files = [path for path in current_files if path not in initial_files]

        # Nothing new (e.g. the file was already present): list the folder.
        if len(downloaded_files) == 0:
            downloaded_files = current_files

        downloaded_files = sorted(downloaded_files, key=os.path.getmtime, reverse=True)

        cprint(f"Downloaded files for {target}:", color="green")
        for file in downloaded_files:
            cprint("  - ", os.path.basename(file), color="green")

    elapsed_time = calculate_elapsed_time(start_time)

    cprint(f"\nFinished installation. Took {elapsed_time}.", color="green")
    cprint(f"All is done! Go to the next step.", color="green")

# Start the downloads as soon as this cell is executed.
main()


# II. Model Conversion

In [None]:
import os
# Restore variables previously saved with %store.
%store -r
#@title ## 2.1. Model Pruner

# tools_dir is defined by the installation cell (1.1).
os.chdir(tools_dir)

# Fetch the pruning script once; later runs reuse the local copy.
if not os.path.exists('prune.py'):
    !wget https://raw.githubusercontent.com/lopho/stable-diffusion-prune/main/prune.py

#@markdown Convert to Float16
fp16 = True #@param {'type':'boolean'}
#@markdown Use EMA for weights
ema = False #@param {'type':'boolean'}
#@markdown Strip CLIP weights
no_clip = False #@param {'type':'boolean'}
#@markdown Strip VAE weights
no_vae = False #@param {'type':'boolean'}
#@markdown Strip depth model weights
no_depth = False #@param {'type':'boolean'}
#@markdown Strip UNet weights
no_unet = False #@param {'type':'boolean'}

model_path = "/content/models/Counterfeit-V3.0_fp16.safetensors" #@param {'type' : 'string'}

# Form values keyed by option name; enabled entries become prune.py flags.
config = {
    "fp16": fp16,
    "ema": ema,
    "no_clip": no_clip,
    "no_vae": no_vae,
    "no_depth": no_depth,
    "no_unet": no_unet,
}

# Text appended to the output file name for every enabled option.
suffixes = {
    "fp16": "-fp16",
    "ema": "-ema",
    "no_clip": "-no-clip",
    "no_vae": "-no-vae",
    "no_depth": "-no-depth",
    "no_unet": "-no-unet",
}

print(f"Loading model from {model_path}")

dir_name = os.path.dirname(model_path)
base_name = os.path.basename(model_path)
# splitext removes only the final extension. Splitting on the first "." cut
# names that contain dots, e.g. "Counterfeit-V3.0_fp16" became "Counterfeit-V3".
output_name = os.path.splitext(base_name)[0]

for option, suffix in suffixes.items():
    if config[option]:
        print(f"Applying option {option}")
        output_name += suffix

output_name += '-pruned'
# The output keeps the .ckpt format for .ckpt inputs; everything else is
# written as .safetensors.
output_path = os.path.join(dir_name, output_name + ('.ckpt' if model_path.endswith(".ckpt") else ".safetensors"))

args = ""
for k, v in config.items():
    if k.startswith("_"):
        args += f'"{v}" '
    elif isinstance(v, str):
        args += f'--{k}="{v}" '
    elif isinstance(v, bool) and v:
        args += f"--{k} "
    elif isinstance(v, float) and not isinstance(v, bool):
        args += f"--{k}={v} "
    elif isinstance(v, int) and not isinstance(v, bool):
        args += f"--{k}={v} "

final_args = f"python3 prune.py {model_path} {output_path} {args}"
!{final_args}

print(f"Saving pruned model to {output_path}")

In [None]:
#@title ## 2.2. Convert Diffusers to Checkpoint
import os
# Restore variables previously saved with %store.
%store -r

# tools_dir is defined by the installation cell (1.1).
os.chdir(tools_dir)

#@markdown ### Conversion Config
model_to_load = "" #@param {'type': 'string'}
# Output base name: the input path without its final extension.
model_to_save = os.path.splitext(model_to_load)[0]
convert = "checkpoint_to_diffusers" #@param ["diffusers_to_checkpoint", "checkpoint_to_diffusers"] {'allow-input': false}
v2 = True #@param {type:'boolean'}
global_step = 0 #@param {'type': 'number'}
epoch = 0 #@param {'type': 'number'}
use_safetensors = True #@param {'type': 'boolean'}
# Note: the selectable values include the leading "--".
save_precision_as = "--float" #@param ["--fp16","--bf16","--float"] {'allow-input': false}

#@markdown Additional option for diffusers
feature_extractor = True #@param {'type': 'boolean'}
safety_checker = True #@param {'type': 'boolean'}

reference_model = "stabilityai/stable-diffusion-2-1" if v2 else "runwayml/stable-diffusion-v1-5" 
# File written by the diffusers -> checkpoint direction.
model_output = f"{model_to_save}.safetensors" if use_safetensors else f"{model_to_save}.ckpt"

# (output file name, source URL) pairs. The first entry is stored under
# feature_extractor/, the remaining ones under safety_checker/.
urls = [
    ("preprocessor_config.json", "https://huggingface.co/CompVis/stable-diffusion-safety-checker/resolve/main/preprocessor_config.json"),
    ("config.json", "https://huggingface.co/CompVis/stable-diffusion-safety-checker/resolve/main/config.json"),
    ("pytorch_model.bin", "https://huggingface.co/CompVis/stable-diffusion-safety-checker/resolve/main/pytorch_model.bin"),
]

# Arguments for the diffusers -> checkpoint direction.
# NOTE(review): save_precision_as still carries its leading "--" here, so the
# command receives --save_precision_as="--float"; confirm the conversion
# script accepts that form.
diffusers_to_sd_dict = {
    "_model_to_load": model_to_load,
    "_model_to_save": model_output,
    "global_step": global_step,
    "epoch": epoch,
    "save_precision_as": save_precision_as,
}

# Arguments for the checkpoint -> diffusers direction.
sd_to_diffusers_dict = {
    "_model_to_load": model_to_load,
    "_model_to_save": model_to_save,
    "v2": bool(v2),
    "v1": not v2,
    "global_step": global_step,
    "epoch": epoch,
    # The form values are "--fp16" / "--bf16" / "--float"; comparing against
    # the bare "fp16" could never be true, so the flag was never passed.
    "fp16": save_precision_as == "--fp16",
    "use_safetensors": use_safetensors,
    "reference_model": reference_model
}

def convert_dict(config):
    """Render a config mapping as one command-line string.

    Rules per entry (each rendered piece is followed by a single space):
      * key starting with ``_``: the value in double quotes (positional);
      * str: ``--key="value"``;
      * ``True``: ``--key``; ``False``: omitted;
      * float or int: ``--key=value``;
      * anything else: omitted.
    """
    pieces = []

    for key, value in config.items():
        if key.startswith("_"):
            pieces.append(f'"{value}"')
        elif isinstance(value, str):
            pieces.append(f'--{key}="{value}"')
        elif isinstance(value, bool):
            # bool is handled before the numeric case because it is an int subclass.
            if value:
                pieces.append(f"--{key}")
        elif isinstance(value, (float, int)):
            pieces.append(f"--{key}={value}")

    return "".join(piece + " " for piece in pieces)

# Run a Python script from the current directory through the notebook shell.
# script_args is the already-formatted argument string (see convert_dict).
def run_script(script_name, script_args):
    !python {script_name} {script_args}

# Download `url` into `save_dir` as `output` using aria2c.
# NOTE: this reuses the name `download` from the model-download cell (1.2);
# after this cell runs, that earlier helper is replaced in the kernel until
# cell 1.2 is executed again.
def download(output, url, save_dir):
    !aria2c --console-log-level=error --summary-interval=10 -c -x 16 -k 1M -s 16 -d '{save_dir}' -o '{output}' {url}

diffusers_to_sd_args = convert_dict(diffusers_to_sd_dict)
sd_to_diffusers_args = convert_dict(sd_to_diffusers_dict)

# The input counts as a checkpoint file when its name ends in ckpt/safetensors.
is_checkpoint_file = model_to_load.endswith(("ckpt","safetensors"))
wants_checkpoint = convert == "diffusers_to_checkpoint"

if wants_checkpoint and is_checkpoint_file:
    print(f"{os.path.basename(model_to_load)} is not in diffusers format")
elif wants_checkpoint:
    run_script("convert_diffusers20_original_sd.py", diffusers_to_sd_args)
elif not is_checkpoint_file:
    print(f"{os.path.basename(model_to_load)} is not in ckpt/safetensors format")
else:
    run_script("convert_diffusers20_original_sd.py", sd_to_diffusers_args)

    # Optional extras placed next to the converted model: (subfolder, files).
    extras = []
    if feature_extractor:
        extras.append(("feature_extractor", urls[:1]))
    if safety_checker:
        extras.append(("safety_checker", urls[1:]))

    for subfolder, files in extras:
        save_dir = os.path.join(model_to_save, subfolder)
        os.makedirs(save_dir, exist_ok=True)
        for output, url in files:
            download(output, url, save_dir)

In [None]:
#@title 2.3. Replace VAE of Existing Model 

# NOTE: `os` and `tools_dir` are not defined in this cell; they come from the
# cells executed earlier in the session.
os.chdir(tools_dir)
# Fetch the merge script once; later runs reuse the local copy.
if not os.path.exists('merge_vae.py'):
  !wget https://raw.githubusercontent.com/Linaqruf/kohya-trainer/main/tools/merge_vae.py

#@markdown You need to input model ends with `.ckpt`, because `.safetensors` model won't work.

target_model = "" #@param {'type': 'string'}
target_vae = "/content/vae/anime.vae.pt" #@param {'type': 'string'}
# NOTE: use_safetensors is not read anywhere in this cell; the output keeps
# the extension of target_model.
use_safetensors = False #@param {type:'boolean'}
# get the base file name and directory
base_name = os.path.basename(target_model)
base_dir = os.path.dirname(target_model)

# get the file name without extension
file_name = os.path.splitext(base_name)[0]

# create the new file name
new_file_name = file_name + "-vae-swapped"

# get the file extension
file_ext = os.path.splitext(base_name)[1]

# create the output file path
output_model = os.path.join(base_dir, new_file_name + file_ext)

# The paths are interpolated unquoted, so they must not contain spaces.
!python merge_vae.py \
    {target_model} \
    {target_vae} \
    {output_model}



In [None]:
#@title 2.4. Convert CKPT-2-Safetensors

import os
import torch
from safetensors.torch import load_file, save_file
from torch import load, save

model_path = "" #@param {type: 'string'}

def is_safetensors(path):
  """Return True when ``path`` has a ``.safetensors`` extension (any letter case)."""
  _, extension = os.path.splitext(path)
  return extension.lower() == '.safetensors'

def convert(model_path):
  """Convert a checkpoint between ``.safetensors`` and ``.ckpt``.

  A ``.safetensors`` input is written next to itself as ``.ckpt``; any other
  input is written as ``.safetensors``. Errors are reported on stdout and not
  re-raised, and ``Done!`` is printed in either case.

  NOTE: non-safetensors inputs are read with ``torch.load``, which unpickles
  the file; only use it on checkpoints from a trusted source.
  """
  source_name = os.path.basename(model_path)
  print("Loading model:", source_name)

  try:
      with torch.no_grad():
          print("Conversion in progress, please wait...")
          from_safetensors = is_safetensors(model_path)

          if from_safetensors:
            model = load_file(model_path, device="cpu")
          else:
            model = load(model_path, map_location="cpu")

          # Some checkpoints nest the weights under a 'state_dict' key.
          sd = model['state_dict'] if 'state_dict' in model else model

          output = os.path.splitext(model_path)[0] + (".ckpt" if from_safetensors else ".safetensors")

          writer = save if from_safetensors else save_file
          writer(sd, output)

      print(f'Successfully converted {source_name} to {os.path.basename(output)}')
      print(f'located in this path : {output}')
  except Exception as ex:
      print(f'ERROR converting {source_name}: {ex}')

  print('Done!')

def main():
  """Convert the checkpoint selected in the ``model_path`` form field."""
  convert(model_path)
# Run the conversion as soon as this cell is executed.
main()


# VIII. Deployment

In [None]:
# @title ## 8.1. Upload Config
from huggingface_hub import login
from huggingface_hub import HfApi
from huggingface_hub.utils import validate_repo_id, HfHubHTTPError

# @markdown Login to Huggingface Hub
# @markdown > Get **your** huggingface `WRITE` token [here](https://huggingface.co/settings/tokens)
# The token typed here is stored in the notebook; clear the field before
# sharing or committing the file.
write_token = ""  # @param {type:"string"}
# @markdown Fill this if you want to upload to your organization, or just leave it empty.
orgs_name = ""  # @param{type:"string"}
# @markdown If your model/dataset repo does not exist, it will automatically create it.
repo_name = "stolen"  # @param{type:"string"}
make_private = False  # @param{type:"boolean"}

def authenticate(write_token):
    """Log in to the Hugging Face Hub and return ``(account_info, api)``.

    Args:
        write_token: Access token used for the login and the account lookup.

    Returns:
        A tuple of the ``whoami`` result for the token and the ``HfApi`` client.
    """
    login(write_token, add_to_git_credential=True)
    hub_api = HfApi()
    account = hub_api.whoami(write_token)
    return account, hub_api

def create_repo(api, user, orgs_name, repo_name, make_private=False):
    """Create the model repo on the Hub if it does not exist yet.

    Args:
        api: ``HfApi`` client.
        user: Account info mapping; ``user["name"]`` is the owner when no
            organisation is given.
        orgs_name: Organisation that owns the repo, or ``""`` for the user.
        repo_name: Repository name (surrounding whitespace is stripped).
        make_private: Create the repo as private.

    Returns:
        The ``owner/name`` repo id.

    Raises:
        HfHubHTTPError: for any Hub error other than "repo already exists"
            (e.g. an invalid token or missing permission).
    """
    owner = orgs_name if orgs_name != "" else user["name"]
    repo_id = owner + "/" + repo_name.strip()

    validate_repo_id(repo_id)
    try:
        api.create_repo(repo_id=repo_id, private=make_private)
        print(f"Model repo '{repo_id}' didn't exist, creating repo")
    except HfHubHTTPError as e:
        # Only HTTP 409 (conflict) means the repo is already there. Other
        # failures used to be reported as "exists", hiding authentication and
        # permission problems until the upload step.
        status = getattr(getattr(e, "response", None), "status_code", None)
        if status != 409:
            print(f"Could not create model repo '{repo_id}': {e}")
            raise
        print(f"Model repo '{repo_id}' exists, skipping create repo")

    print(f"Model repo '{repo_id}' link: https://huggingface.co/{repo_id}\n")
    return repo_id

# Log in with the token from the form, then make sure the target repo exists.
user, api = authenticate(write_token)

# model_repo ("owner/name") is used by the upload and git cells below.
model_repo = create_repo(api, user, orgs_name, repo_name, make_private)


## 8.2. Upload with Hugging Face Hub

In [None]:
#@title ### 8.2.1. Upload Model
from huggingface_hub import HfApi
from huggingface_hub.utils import validate_repo_id, HfHubHTTPError
from pathlib import Path

import os  # used below; this cell's own import list does not include it

api = HfApi()

#@markdown This will be uploaded to model repo

model_path = "/content/models" #@param {type :"string"}
path_in_repo = "fp16" #@param {type :"string"}
revision = "" #@param {type :"string"}
if revision:
  api.create_branch(repo_id=model_repo, 
                branch=revision, 
                exist_ok=True)
else:
  revision = "main"
# Name used in the default commit message: the last path component, without
# the extension when model_path points at a single weights file. (The former
# check compared the whole basename against bare extensions and never matched.)
project_name = os.path.basename(model_path.rstrip("/"))
if project_name.endswith((".safetensors", ".ckpt", ".pt")):
  project_name = os.path.splitext(project_name)[0]
# @markdown Other Information
commit_message = ""  # @param {type :"string"}

if not commit_message:
    commit_message = "feat: upload " + project_name + " checkpoint"

# A folder holding vae/, unet/ and text_encoder/ is treated as a diffusers
# model. The flags are always defined (False when model_path is missing), so
# upload_model never hits an undefined name.
vae_exists = os.path.exists(os.path.join(model_path, 'vae'))
unet_exists = os.path.exists(os.path.join(model_path, 'unet'))
text_encoder_exists = os.path.exists(os.path.join(model_path, 'text_encoder'))
    
def upload_model(model_paths, is_folder :bool, commit_message):
  """Upload a model file or folder to the Hub repo ``model_repo``.

  Args:
    model_paths: Local path of the file or folder to upload.
    is_folder: True to upload a folder, False to upload a single file.
    commit_message: Commit message; when empty a default is generated.

  The name in the repo is the ``path_in_repo`` form value when set, otherwise
  the last component of ``model_paths``. A folder containing vae/, unet/ and
  text_encoder/ is uploaded to the repo root; any other folder is uploaded
  under that name. The printed link points at the branch that was uploaded to.

  NOTE(review): for a single file, a non-empty ``path_in_repo`` is used as the
  complete target path (not as a directory) — confirm this is intended.
  """
  trained_model = path_in_repo if path_in_repo else Path(model_paths).parts[-1]
  repo_link = "https://huggingface.co/" + model_repo

  print(f"Uploading {trained_model} to {repo_link}")
  print("Please wait...")

  if is_folder:
    if vae_exists and unet_exists and text_encoder_exists:
      if not commit_message:
        commit_message = f"feat: upload diffusers version of {trained_model}"

      # Diffusers layout goes to the repo root.
      api.upload_folder(
          folder_path=model_paths,
          repo_id=model_repo,
          revision=revision,
          commit_message=commit_message,
          ignore_patterns=".ipynb_checkpoints"
          )

    else:
      if not commit_message:
        commit_message = f"feat: upload {trained_model} checkpoint folder"

      api.upload_folder(
          folder_path=model_paths,
          path_in_repo=trained_model,
          repo_id=model_repo,
          revision=revision,
          commit_message=commit_message,
          ignore_patterns=".ipynb_checkpoints"
          )
    # Link to the uploaded branch; "main" was hard-coded before and was wrong
    # whenever a revision had been selected.
    print(f"Upload success, located at {repo_link}/tree/{revision}\n")
  else:
    if not commit_message:
      if model_paths.endswith(".safetensors"):
        commit_message = f"feat: upload safetensors version of {trained_model} "
      else:
        commit_message = f"feat: upload {trained_model} checkpoint"

    api.upload_file(
        path_or_fileobj=model_paths,
        path_in_repo=trained_model,
        repo_id=model_repo,
        revision=revision,
        commit_message=commit_message,
        )

    print(f"Upload success, located at {repo_link}/blob/{revision}/{trained_model}\n")
      
def upload():
    """Upload ``model_path`` as a single file or as a folder, judged by its extension."""
    is_single_file = model_path.endswith((".ckpt", ".safetensors", ".pt"))
    upload_model(model_path, not is_single_file, commit_message)

# Start the upload as soon as this cell is executed.
upload()

## 8.3. Upload with GIT (Alternative)

In [None]:
#@title ### 8.3.1. Clone Repository
%cd /content/
clone_model = True #@param {'type': 'boolean'}

!git lfs install --skip-smudge
!export GIT_LFS_SKIP_SMUDGE=1

if clone_model:
  !git clone https://huggingface.co/{model_repo} /content/{repo_name}

In [None]:
#@title ### 8.3.2. Commit using Git 
import os

# root_dir is defined by the installation cell (1.1).
os.chdir(root_dir)

#@markdown Choose which repo you want to commit
commit_model = True #@param {'type': 'boolean'}
#@markdown #### Other Information
commit_message = "" #@param {type :"string"}

# Default message; repo_name comes from the "Upload Config" cell.
if not commit_message:
  commit_message = f"feat: upload {repo_name}"

# Placeholder identity used for the commit; replace with your own details.
!git config --global user.email "example@mail.com"
!git config --global user.name "example"

# Stage, commit and push everything inside <root_dir>/<repo_folder>.
# The working directory is changed to that folder and stays there afterwards.
# NOTE: commit_message is placed inside double quotes on the shell line, so a
# message that itself contains a double quote will break the command.
def commit(repo_folder, commit_message):
  os.chdir(os.path.join(root_dir, repo_folder))
  !git lfs install
  !huggingface-cli lfs-enable-largefiles .
  !git add .
  !git commit -m "{commit_message}"
  !git push

# Commit and push the repository cloned in step 8.3.1.
commit(repo_name, commit_message)
