<a href="https://colab.research.google.com/github/Linaqruf/erasing/blob/main/Erasing_Stable_Diffusion_Training.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
#@title Install
import os
import shutil
import sys
import zipfile

#@markdown Clone ESD Repository
# Directory layout used by every cell of the notebook.
root_dir = "/content"
deps_dir = os.path.join(root_dir, "deps")
repo_dir = os.path.join(root_dir, "stable-diffusion")
models_dir = os.path.join(repo_dir, "ldm")
train_dir = os.path.join(repo_dir, "train-scripts")

repo_url = "https://github.com/Linaqruf/erasing"
# Location of the bitsandbytes module patched by remove_bitsandbytes_message().
# The "pythonX.Y" component follows the running interpreter instead of being
# pinned to 3.9, so the path stays valid when the runtime's Python changes.
bitsandytes_main_py = (
    f"/usr/local/lib/python{sys.version_info.major}.{sys.version_info.minor}"
    "/dist-packages/bitsandbytes/cuda_setup/main.py"
)

def read_file(filename):
    """Return the entire text content of ``filename``.

    The file is opened in text mode with the platform's default encoding.
    """
    with open(filename, "r") as handle:
        return handle.read()


def write_file(filename, contents):
    """Replace the content of ``filename`` with the text ``contents``.

    The file is created if it does not exist and truncated if it does.
    """
    with open(filename, "w") as handle:
        handle.write(contents)


def clone_repo(url):
    if not os.path.exists(repo_dir):
        os.chdir(root_dir)
        !git clone {url} {repo_dir}
    else:
        os.chdir(repo_dir)
        !git pull


def ubuntu_deps(url, name, dst):
    !wget -q --show-progress {url}
    with zipfile.ZipFile(name, "r") as deps:
        deps.extractall(dst)
    !dpkg -i {dst}/*
    os.remove(name)
    shutil.rmtree(dst)


def install_dependencies():
    """Install the Python packages used by the notebook via ``pip``.

    Runs four ``pip install`` shell commands in order. The last one
    (``pip install -e .``) installs whatever project lives in the current
    working directory in editable mode, so the caller is expected to have
    changed into the repository directory beforehand (``main`` does this
    with ``os.chdir(repo_dir)``).
    """
    # Packages from PyPI; pytorch-lightning and bitsandbytes are pinned to exact versions.
    !pip install omegaconf einops diffusers transformers "pytorch-lightning==1.6.5" kornia "bitsandbytes==0.35.0" safetensors
    # Editable installs straight from the upstream Git repositories.
    !pip install -e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers
    !pip install -e git+https://github.com/openai/CLIP.git@main#egg=clip
    # Editable install of the project in the current working directory.
    !pip install -e .


def remove_bitsandbytes_message(filename):
    """Patch bitsandbytes so its welcome banner can be silenced.

    Replaces the unconditional banner at the top of ``evaluate_cuda_setup``
    in ``filename`` with a version that is skipped when the
    ``BITSANDBYTES_NOWELCOME`` environment variable is set to anything
    other than ``0``.

    The patch is purely cosmetic, so it is applied on a best-effort basis:
    if ``filename`` does not exist the function reports that and returns
    instead of raising, and the file is rewritten only when the banner was
    actually found.

    Args:
        filename: Path to bitsandbytes' ``cuda_setup/main.py``.
    """
    welcome_message = """
def evaluate_cuda_setup():
    print('')
    print('='*35 + 'BUG REPORT' + '='*35)
    print('Welcome to bitsandbytes. For bug reports, please submit your error trace to: https://github.com/TimDettmers/bitsandbytes/issues')
    print('For effortless bug reporting copy-paste your error into this form: https://docs.google.com/forms/d/e/1FAIpQLScPB8emS3Thkp66nvqwmjTEgxp8Y9ufuWTzFyr9kJ5AoI47dQ/viewform?usp=sf_link')
    print('='*80)"""

    new_welcome_message = """
def evaluate_cuda_setup():
    import os
    if 'BITSANDBYTES_NOWELCOME' not in os.environ or str(os.environ['BITSANDBYTES_NOWELCOME']) == '0':
        print('')
        print('=' * 35 + 'BUG REPORT' + '=' * 35)
        print('Welcome to bitsandbytes. For bug reports, please submit your error trace to: https://github.com/TimDettmers/bitsandbytes/issues')
        print('For effortless bug reporting copy-paste your error into this form: https://docs.google.com/forms/d/e/1FAIpQLScPB8emS3Thkp66nvqwmjTEgxp8Y9ufuWTzFyr9kJ5AoI47dQ/viewform?usp=sf_link')
        print('To hide this message, set the BITSANDBYTES_NOWELCOME variable like so: export BITSANDBYTES_NOWELCOME=1')
        print('=' * 80)"""

    # A missing file (different Python version, different install location)
    # must not abort the whole setup just because a banner can't be hidden.
    if not os.path.isfile(filename):
        print(f"bitsandbytes file not found, skipping welcome-message patch: {filename}")
        return

    contents = read_file(filename)
    new_contents = contents.replace(welcome_message, new_welcome_message)

    # Avoid touching an installed package file when there is nothing to change
    # (banner already patched, or a bitsandbytes version with different text).
    if new_contents != contents:
        write_file(filename, new_contents)

def main():
    clone_repo(repo_url)

    os.chdir(repo_dir)
    
    !apt -y update -qq
    !apt install libunwind8-dev -qq

    ubuntu_deps(
        "https://huggingface.co/Linaqruf/fast-repo/resolve/main/deb-libs.zip",
        "deb-libs.zip",
        deps_dir,
    )
    install_dependencies()

    remove_bitsandbytes_message(bitsandytes_main_py)

    os.environ["LD_PRELOAD"] = "libtcmalloc.so"
    os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
    os.environ["BITSANDBYTES_NOWELCOME"] = "1"  
    os.environ["LD_LIBRARY_PATH"] = "/usr/local/cuda/lib64/:$LD_LIBRARY_PATH"
    os.environ["SAFETENSORS_FAST_GPU"] = "1"



main()

In [None]:
# @title Download Pretrained Model
import os

# Restore every variable previously saved with IPython's %store magic.
%store -r

# root_dir is not defined in this cell; it is expected to already exist in the
# notebook namespace (the Install cell assigns it).
os.chdir(root_dir)

# Address of the checkpoint handed to install() below; editable via the Colab form.
model_url = "https://huggingface.co/cag/anything-v3-1/resolve/main/anything-v3-1.safetensors"  # @param {'type': 'string'}


def install(url):
    base_name = os.path.basename(url)

    if url.startswith("https://drive.google.com"):
        os.chdir(models_dir)
        !gdown --fuzzy {url}
    elif url.startswith("https://huggingface.co/"):
        if "/blob/" in url:
            url = url.replace("/blob/", "/resolve/")
        # @markdown Change this part with your own huggingface token if you need to download your private model
        hf_token = "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"  # @param {type:"string"}
        user_header = f'"Authorization: Bearer {hf_token}"'
        !aria2c --console-log-level=error --summary-interval=10 --header={user_header} -c -x 16 -k 1M -s 16 -d {models_dir} -o {base_name} {url}
    else:
        !aria2c --console-log-level=error --summary-interval=10 -c -x 16 -k 1M -s 16 -d {models_dir} {url}


install(model_url)

In [None]:
#@title Train
#@markdown Erase a concept from diffusion model `(e.g. "Van Gogh")`.
prompt = "nsfw" #@param {type: 'string'}
#@markdown Choose a parameter to train for erasure from the following options: `('xattn','noxattn', 'selfattn', 'full')`.
train_method = "full" #@param ['xattn','noxattn', 'selfattn', 'full']
#@markdown Set the guidance value for generating images for training.
start_guidance = 3 #@param {type: 'number'}
#@markdown Set the guidance value for erasing the concept from the diffusion model.
negative_guidance = 1 #@param {type: 'number'}
#@markdown Set the number of iterations to train.
iterations = 1000 #@param {type: 'integer'}
#@markdown Set the learning rate for fine-tuning.
lr = 1e-5 #@param {type: 'number'}
#@markdown Set the config path for the CompVis diffusion format.
config_path = "/content/stable-diffusion/configs/stable-diffusion/v1-inference.yaml" #@param {type: 'string'}
#@markdown Set the checkpoint path for the pre-trained CompVis diffusion weights.
ckpt_path = "/content/stable-diffusion/ldm/anything-v3-1.safetensors" #@param {type: 'string'}
#@markdown Use this field to separate individual prompts for simultaneous erasures if the prompt contains commas.
seperator = "" #@param {type: 'string'}
#@markdown Set the image size for generated images.
image_size = 512 #@param {type: 'number'}
#@markdown Set the number of diffusion time steps.
ddim_steps = 50 #@param {type: 'number'}

config = {
    "prompt": prompt,
    "train_method": train_method,
    "start_guidance": float(start_guidance),
    "negative_guidance": float(negative_guidance),
    "iterations": iterations,
    "lr": float(lr),
    "lowram": True,
    "config_path": config_path,
    "ckpt_path": ckpt_path,
    "devices": "0,0",
    "accumulation_steps": 1,
    "use_8bit_adam": True,
    "image_size": image_size,
    "ddim_steps": ddim_steps,
}

args = ""
for k, v in config.items():
    if isinstance(v, str):
        args += f'--{k}="{v}" '
    if isinstance(v, bool) and v:
        args += f"--{k} "
    if isinstance(v, float) and not isinstance(v, bool):
        args += f"--{k}={v} "
    if isinstance(v, int) and not isinstance(v, bool):
        args += f"--{k}={v} "

os.chdir(repo_dir)
final_args = f"python train-scripts/train-esd.py {args} {'--seperator ' + seperator if seperator else ''}"

!{final_args}