<a href="https://colab.research.google.com/github/AK391/fast-stable-diffusion/blob/main/fast-DreamBooth.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# **Colab from https://github.com/TheLastBen/fast-stable-diffusion — if you have any issues, feel free to discuss them there.**
Run this notebook manually, step by step, without skipping any cell. The Colab is still a work in progress while we look for the best settings for DreamBooth.


In [None]:
# Mount Google Drive at /content/gdrive; later cells cache the base model and
# save the trained weights under /content/gdrive/MyDrive.
from google.colab import drive

GDRIVE_MOUNT_POINT = '/content/gdrive'
drive.mount(GDRIVE_MOUNT_POINT)

# Setting up the environment

In [None]:
#@markdown # Install diffusers
%%capture
%cd /content/
!git clone https://github.com/TheLastBen/diffusers
!pip install git+https://github.com/TheLastBen/diffusers
%pip install transformers
%pip install ftfy
%pip install accelerate
%pip install bitsandbytes

# xformers install for T4, P100 and V100

In [None]:
#@markdown # Cloning repo
%%capture
# Fetch the xformers sources from a specific upstream branch into /content/xformers.
# Submodules are not fetched (no --recursive); the next cell overwrites setup.py
# before the checkout is pip-installed.
%cd /content/
!git clone --branch gh/danthe3rd/35/orig https://github.com/facebookresearch/xformers.git

In [None]:
#@markdown # Patching setup.py
%%writefile /content/xformers/setup.py
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import distutils.command.clean
import glob
import os
import re
import shutil
import subprocess
import sys
from pathlib import Path

import setuptools
import torch
from torch.utils.cpp_extension import (
    CUDA_HOME,
    BuildExtension,
    CppExtension,
    CUDAExtension,
)

# Absolute directory containing this setup.py; used below to locate xformers/__init__.py.
this_dir = os.path.dirname(os.path.abspath(__file__))


def fetch_requirements():
    """Return the requirement specifiers listed in ``requirements.txt``.

    The file is opened relative to the current working directory (presumably
    the package root when pip runs this setup.py). Each line is stripped, and
    blank lines and ``#`` comment lines are dropped, so an empty or commented
    file no longer yields ``[""]`` or comment text as a "requirement".
    """
    with open("requirements.txt") as f:
        lines = (line.strip() for line in f)
        return [line for line in lines if line and not line.startswith("#")]


# https://packaging.python.org/guides/single-sourcing-package-version/
def find_version(version_file_path):
    """Return the ``__version__`` declared in *version_file_path*.

    The value of the ``XFORMERS_VERSION_SUFFIX`` environment variable (empty
    when unset) is appended to it.

    Raises:
        RuntimeError: if no ``__version__ = "..."`` line is present.
    """
    with open(version_file_path) as handle:
        contents = handle.read()
    # Release builds may append a suffix; regular users should leave it unset.
    version_suffix = os.getenv("XFORMERS_VERSION_SUFFIX", "")
    found = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]", contents, re.M)
    if not found:
        raise RuntimeError("Unable to find version string.")
    return found.group(1) + version_suffix


def get_cuda_version(cuda_dir) -> int:
    """Return the CUDA toolkit version reported by ``nvcc -V`` as ``major * 100 + minor``.

    For example, ``release 11.3,`` becomes ``1103``.

    Args:
        cuda_dir: CUDA install root (``<cuda_dir>/bin/nvcc`` is run), or ``None``
            to run whichever ``nvcc`` is on ``PATH``.

    Raises:
        ValueError: if the version after ``release`` cannot be parsed.
    """
    nvcc_bin = "nvcc" if cuda_dir is None else cuda_dir + "/bin/nvcc"
    raw_output = subprocess.check_output([nvcc_bin, "-V"], universal_newlines=True)
    output = raw_output.split()
    # The token after "release" looks like "11.3," (note the trailing comma).
    release = output[output.index("release") + 1].split(".")
    bare_metal_major = int(release[0])
    # Parse all leading digits of the minor component instead of only the
    # first one, so e.g. "10," gives 10 rather than 1.
    minor_digits = re.match(r"\d+", release[1])
    if minor_digits is None:
        raise ValueError(f"Unable to parse CUDA version from nvcc output: {raw_output!r}")
    bare_metal_minor = int(minor_digits.group(0))

    # The encoding below only stays unambiguous while minor < 100.
    assert bare_metal_minor < 100
    return bare_metal_major * 100 + bare_metal_minor


def get_flash_attention_extensions(cuda_version: int, extra_compile_args):
    """Describe the ``xformers._C_flashattention`` CUDA extension, if it can be built.

    Args:
        cuda_version: CUDA toolkit version encoded as ``major * 100 + minor``
            (as returned by ``get_cuda_version``).
        extra_compile_args: per-compiler flag dict (``"nvcc"`` key optional)
            shared with the main extension. It is not mutated: a new dict and a
            new nvcc flag list are built from it.

    Returns:
        ``[CUDAExtension(...)]``, or ``[]`` when CUDA is older than 11.0,
        ``XFORMERS_DISABLE_FLASH_ATTN`` is set to anything other than ``"0"``,
        or no requested architecture is sm_75 or newer.

    Raises:
        RuntimeError: if ``third_party/flash-attention`` is missing.
        AssertionError: if a ``TORCH_CUDA_ARCH_LIST`` entry does not start
            with ``X.Y`` (suffixes such as ``8.6+PTX`` are allowed).
    """
    # Figure out default archs to target (overridable via TORCH_CUDA_ARCH_LIST).
    if cuda_version > 1100:
        default_archs_list = "7.5;8.0;8.6"
    elif cuda_version == 1100:
        default_archs_list = "7.5;8.0"
    else:
        return []

    if os.getenv("XFORMERS_DISABLE_FLASH_ATTN", "0") != "0":
        return []

    archs_list = os.environ.get("TORCH_CUDA_ARCH_LIST", default_archs_list)
    nvcc_archs_flags = []
    # Accept ';' and/or whitespace as separators (e.g. "7.5 8.6+PTX"). The
    # previous split(";") silently dropped every arch after the first space
    # and tripped the assert on empty entries left by a trailing ';'.
    for arch in archs_list.replace(";", " ").split():
        match = re.match(r"(\d+)\.(\d+)", arch)
        assert match is not None, f"Invalid sm version: {arch}"
        num = 10 * int(match.group(1)) + int(match.group(2))
        # Need at least 7.5
        if num < 75:
            continue
        nvcc_archs_flags.append(f"-gencode=arch=compute_{num},code=sm_{num}")
        if arch.endswith("+PTX"):
            nvcc_archs_flags.append(f"-gencode=arch=compute_{num},code=compute_{num}")
    if not nvcc_archs_flags:
        return []

    this_dir = os.path.dirname(os.path.abspath(__file__))
    flash_root = os.path.join(this_dir, "third_party", "flash-attention")
    if not os.path.exists(flash_root):
        raise RuntimeError(
            "flashattention submodule not found. Did you forget "
            "to run `git submodule update --init --recursive` ?"
        )

    return [
        CUDAExtension(
            name="xformers._C_flashattention",
            sources=[
                os.path.join(flash_root, path)
                for path in [
                    "csrc/flash_attn/fmha_api.cpp",
                    "csrc/flash_attn/src/fmha_fprop_fp16_kernel.sm80.cu",
                    "csrc/flash_attn/src/fmha_dgrad_fp16_kernel_loop.sm80.cu",
                    "csrc/flash_attn/src/fmha_block_fprop_fp16_kernel.sm80.cu",
                    "csrc/flash_attn/src/fmha_block_dgrad_fp16_kernel_loop.sm80.cu",
                ]
            ],
            extra_compile_args={
                **extra_compile_args,
                "nvcc": extra_compile_args.get("nvcc", [])
                + [
                    "-O3",
                    "-U__CUDA_NO_HALF_OPERATORS__",
                    "-U__CUDA_NO_HALF_CONVERSIONS__",
                    "--expt-relaxed-constexpr",
                    "--expt-extended-lambda",
                    "--use_fast_math",
                    "--ptxas-options=-v",
                    "-lineinfo",
                ]
                + nvcc_archs_flags,
            },
            include_dirs=[
                Path(flash_root) / "csrc" / "flash_attn",
                Path(flash_root) / "csrc" / "flash_attn" / "src",
                # flash-attention's own csrc/flash_attn/cutlass include path is
                # intentionally not used; the top-level CUTLASS checkout is.
                Path(this_dir) / "third_party" / "cutlass" / "include",
            ],
        )
    ]


def get_extensions():
    """Describe the ``xformers._C`` extension (plus flash-attention on CUDA builds).

    Sources are taken from ``xformers/components/attention/csrc``: the top-level,
    ``cpu`` and ``autograd`` ``*.cpp`` files always, and every ``cuda/**/*.cu``
    file when CUDA is available (or ``FORCE_CUDA=1``).

    Environment variables:
        FORCE_CUDA: ``"1"`` builds the CUDA extension even without a visible GPU.
        NVCC_FLAGS: whitespace-separated nvcc flags replacing the defaults
            ``--use_fast_math -DNDEBUG``.

    Raises:
        RuntimeError: if the CUTLASS submodule is missing.
    """
    this_dir = os.path.dirname(os.path.abspath(__file__))
    extensions_dir = os.path.join(
        this_dir, "xformers", "components", "attention", "csrc"
    )

    main_file = glob.glob(os.path.join(extensions_dir, "*.cpp"))
    source_cpu = glob.glob(os.path.join(extensions_dir, "cpu", "*.cpp")) + glob.glob(
        os.path.join(extensions_dir, "autograd", "*.cpp")
    )
    # extensions_dir is absolute, so glob() already returns absolute paths.
    sources = main_file + source_cpu

    source_cuda = glob.glob(
        os.path.join(extensions_dir, "cuda", "**", "*.cu"), recursive=True
    )

    sputnik_dir = os.path.join(this_dir, "third_party", "sputnik")
    cutlass_dir = os.path.join(this_dir, "third_party", "cutlass", "include")
    if not os.path.exists(cutlass_dir):
        raise RuntimeError(
            "CUTLASS submodule not found. Did you forget "
            "to run `git submodule update --init --recursive` ?"
        )

    extension = CppExtension

    define_macros = []

    extra_compile_args = {"cxx": ["-O3"]}
    if sys.platform == "win32":
        define_macros += [("xformers_EXPORTS", None)]
        extra_compile_args["cxx"].append("/MP")
    elif "OpenMP not found" not in torch.__config__.parallel_info():
        extra_compile_args["cxx"].append("-fopenmp")

    include_dirs = [extensions_dir]
    ext_modules = []

    use_cuda = (torch.cuda.is_available() and CUDA_HOME is not None) or os.getenv(
        "FORCE_CUDA", "0"
    ) == "1"
    if use_cuda:
        extension = CUDAExtension
        sources += source_cuda
        include_dirs += [sputnik_dir, cutlass_dir]
        # split() rather than split(" "): repeated, leading or trailing
        # whitespace in NVCC_FLAGS must not become empty "" nvcc arguments.
        nvcc_flags = os.getenv("NVCC_FLAGS", "").split()
        if not nvcc_flags:
            nvcc_flags = ["--use_fast_math", "-DNDEBUG"]
        cuda_version = get_cuda_version(CUDA_HOME)
        # --threads / verbose ptxas are only added for CUDA >= 11.2.
        if cuda_version >= 1102:
            nvcc_flags += [
                "--threads",
                "4",
                "--ptxas-options=-v",
            ]
        extra_compile_args["nvcc"] = nvcc_flags

        ext_modules += get_flash_attention_extensions(
            cuda_version=cuda_version, extra_compile_args=extra_compile_args
        )

    ext_modules.append(
        extension(
            "xformers._C",
            sorted(sources),
            include_dirs=include_dirs,
            define_macros=define_macros,
            extra_compile_args=extra_compile_args,
        )
    )

    return ext_modules


class clean(distutils.command.clean.clean):  # type: ignore
    """``setup.py clean`` that also deletes paths listed in ``.gitignore``."""

    def run(self):
        if os.path.exists(".gitignore"):
            with open(".gitignore", "r") as f:
                ignores = f.read()
            for pattern in ignores.splitlines():
                pattern = pattern.strip()
                # Skip blanks, comments, and negation lines ("!pattern"). In
                # gitignore syntax, negations re-include paths, so they are not
                # deletion targets.
                if not pattern or pattern.startswith(("#", "!")):
                    continue
                # A leading "/" anchors a gitignore pattern at the directory of
                # the .gitignore (the cwd here). Without stripping it, glob()
                # would match from the filesystem root ("/build" -> system /build).
                for filename in glob.glob(pattern.lstrip("/")):
                    try:
                        os.remove(filename)
                    except OSError:
                        # Directories cannot be os.remove()'d; delete the tree instead.
                        shutil.rmtree(filename, ignore_errors=True)

        # It's an old-style class in Python 2.7...
        distutils.command.clean.clean.run(self)


if __name__ == "__main__":
    # NOTE: this patched setup.py passes no ext_modules/cmdclass, so pip only
    # installs the pure-Python package (get_extensions() is defined but never
    # called). The notebook's "Precompiled files" cell then copies prebuilt
    # _C.so / _C_flashattention.so into the installed xformers directory.
    setuptools.setup(
        name="xformers",
        description="XFormers: A collection of composable Transformer building blocks.",
        version=find_version(os.path.join(this_dir, "xformers", "__init__.py")),
        setup_requires=[],
        install_requires=fetch_requirements(),
        packages=setuptools.find_packages(exclude=("tests", "tests.*")),
        url="https://facebookresearch.github.io/xformers/",
        python_requires=">=3.6",
        author="Facebook AI Research",
        author_email="lefaudeux@fb.com",
        # The fragments now end with a space; previously the sentences ran
        # together ("...building blocks.XFormers aims...").
        long_description="XFormers: A collection of composable Transformer building blocks. "
        + "XFormers aims at being able to reproduce most architectures in the Transformer-family SOTA, "
        + "defined as compatible and combined building blocks as opposed to monolithic models",
        long_description_content_type="text/markdown",
        classifiers=[
            "Programming Language :: Python :: 3.7",
            "Programming Language :: Python :: 3.8",
            "Programming Language :: Python :: 3.9",
            "License :: OSI Approved :: BSD License",
            "Topic :: Scientific/Engineering :: Artificial Intelligence",
            "Operating System :: OS Independent",
        ],
        zip_safe=False,
    )


In [None]:
#@markdown # Installing
%%capture
# Install the /content/xformers checkout using the setup.py written by the previous cell.
!pip install /content/xformers
# NOTE(review): pinned triton nightly; presumably required by this xformers branch — confirm.
!pip install triton==2.0.0.dev20220701

In [None]:
#@markdown # Precompiled files
%%capture
from subprocess import getoutput
from IPython.display import HTML

s = getoutput('nvidia-smi')
if 'T4' in s:
  gpu = 'T4'
elif 'P100' in s:
  gpu = 'P100'
elif 'V100' in s:
  gpu = 'V100'

if (gpu=='T4'):
  %cd /content/
  !git clone https://github.com/TheLastBen/fast-stable-diffusion
  %cd /content/fast-stable-diffusion/precompiled
  !mv /content/fast-stable-diffusion/precompiled/_C_flashattention.1 /content/fast-stable-diffusion/precompiled/_C_flashattention.7z.001
  !mv /content/fast-stable-diffusion/precompiled/_C_flashattention.2 /content/fast-stable-diffusion/precompiled/_C_flashattention.7z.002
  !7z x /content/fast-stable-diffusion/precompiled/_C_flashattention.7z.001
  !mv -f /content/fast-stable-diffusion/precompiled/_C_flashattention.so /usr/local/lib/python3.7/dist-packages/xformers
  !mv -f /content/fast-stable-diffusion/precompiled/_C.so /usr/local/lib/python3.7/dist-packages/xformers

elif (gpu=='P100'):
  %cd /content/
  !git clone https://github.com/TheLastBen/fast-stable-diffusion
  %cd /content/fast-stable-diffusion/precompiled
  !mv /content/fast-stable-diffusion/precompiled/_C_flashattention-p100.1 /content/fast-stable-diffusion/precompiled/_C_flashattention.7z.001
  !mv /content/fast-stable-diffusion/precompiled/_C_flashattention-p100.2 /content/fast-stable-diffusion/precompiled/_C_flashattention.7z.002
  !7z x /content/fast-stable-diffusion/precompiled/_C_flashattention.7z.001
  !mv -f /content/fast-stable-diffusion/precompiled/_C.flashattention.so /usr/local/lib/python3.7/dist-packages/xformers/_C_flashattention.so
  !mv -f /content/fast-stable-diffusion/precompiled/_C-p100.so /usr/local/lib/python3.7/dist-packages/xformers/_C.so
  
elif (gpu=='V100'):
  %cd /content/
  !git clone https://github.com/TheLastBen/fast-stable-diffusion
  %cd /content/fast-stable-diffusion/precompiled
  !mv /content/fast-stable-diffusion/precompiled/_C_flashattention-v100.1 /content/fast-stable-diffusion/precompiled/_C_flashattention.7z.001
  !mv /content/fast-stable-diffusion/precompiled/_C_flashattention-v100.2 /content/fast-stable-diffusion/precompiled/_C_flashattention.7z.002
  !7z x /content/fast-stable-diffusion/precompiled/_C_flashattention.7z.001
  !mv -f /content/fast-stable-diffusion/precompiled/_C_flashattention.so /usr/local/lib/python3.7/dist-packages/xformers/
  !mv -f /content/fast-stable-diffusion/precompiled/_C-v100.so /usr/local/lib/python3.7/dist-packages/xformers/_C.so
  

# Caching the model in GDrive

In [None]:
#@markdown # Downloading the model
%%capture
import os
Huggingface_Token = "" #@param {type:"string"}
#@markdown ---
#@markdown (Make sure you accepted the terms in https://huggingface.co/CompVis/stable-diffusion-v1-4)
token=Huggingface_Token
if token == "" and not os.path.exists('/content/gdrive/MyDrive/stable-diffusion-v1-4'):
  token=input("Insert your huggingface token :")
  %cd /content/
  !git init
  !git lfs install --system --skip-repo
  !git clone "https://USER:{token}@huggingface.co/CompVis/stable-diffusion-v1-4"

elif not os.path.exists('/content/gdrive/MyDrive/stable-diffusion-v1-4'):
  %cd /content/
  !git init
  !git lfs install --system --skip-repo
  !git clone "https://USER:{token}@huggingface.co/CompVis/stable-diffusion-v1-4"

else:
  print("Model already exists")

!rsync -av --progress /content/stable-diffusion-v1-4 /content/gdrive/MyDrive --exclude .git
!rm -r /content/stable-diffusion-v1-4


# Dreambooth

In [None]:
import os
import shutil
from google.colab import files
#@markdown #Setting up
#@markdown ---
MODEL_NAME="/content/gdrive/MyDrive/stable-diffusion-v1-4"
#@markdown ### Training subject (is it a person? a dog? a car? Pick the correct category):
SUBJECT_NAME= "" #@param{type: 'string'}
#@markdown ### Identifier (choose a unique identifier unknown to Stable Diffusion):
INSTANCE_NAME= "" #@param{type: 'string'}

#@markdown ### This cell will ask you to upload your reference images. For best results, make sure they are square, e.g. 1024x1024.
if not SUBJECT_NAME.strip() or not INSTANCE_NAME.strip():
  # Both names are used to build the data/output paths and the training prompt.
  raise ValueError("Set SUBJECT_NAME and INSTANCE_NAME before running this cell.")
INSTANCE_DIR="/content/data/"+INSTANCE_NAME
# Python-side mkdir: no shell word-splitting if INSTANCE_NAME contains spaces.
os.makedirs(INSTANCE_DIR, exist_ok=True)
CLASS_DIR="/content/data/"+ SUBJECT_NAME
OUTPUT_DIR="/content/models/"+ INSTANCE_NAME
# upload images
uploaded = files.upload()
for filename in uploaded.keys():
  # Move to an explicit destination file, so re-uploading an image with the same
  # name replaces it instead of raising "Destination path ... already exists".
  shutil.move(filename, os.path.join(INSTANCE_DIR, os.path.basename(filename)))

In [None]:
#@markdown ---
#@markdown #Start DreamBooth
#@markdown ---
Training_Steps=800 #@param{type: 'string'}
Seed="12345" #@param{type: 'string'}
#@markdown ####More steps, better results, but longer training time
!accelerate launch /content/diffusers/examples/dreambooth/train_dreambooth.py \
  --pretrained_model_name_or_path=$MODEL_NAME \
  --instance_data_dir=$INSTANCE_DIR \
  --output_dir=$OUTPUT_DIR \
  --instance_prompt="photo of {INSTANCE_NAME} {SUBJECT_NAME}"\
  --seed=$Seed \
  --resolution=512 \
  --mixed_precision="fp16" \
  --train_batch_size=1 \
  --gradient_accumulation_steps=1 \
  --use_8bit_adam \
  --learning_rate=5e-6 \
  --lr_scheduler="constant" \
  --lr_warmup_steps=0 \
  --max_train_steps=$Training_Steps

In [None]:
#@markdown #Save the new model in your Gdrive (make sure there is enough space)
# Copy everything under /content/models (including OUTPUT_DIR) to
# /content/gdrive/MyDrive/models, the location the model-loading cell reads from.
!cp -r "/content/models/" /content/gdrive/MyDrive

# Test the model

In [None]:
#@markdown #Load the new Model
import torch
from torch import autocast
from diffusers import StableDiffusionPipeline
from IPython.display import display

# Load the fine-tuned weights copied to Drive under models/<INSTANCE_NAME>, in fp16, onto the GPU.
pipe = StableDiffusionPipeline.from_pretrained('/content/gdrive/MyDrive/models/'+INSTANCE_NAME, torch_dtype=torch.float16).to("cuda")
def dummy(images, **kwargs):
    """No-op stand-in for the pipeline's safety checker.

    Returns the images unchanged, plus one ``False`` ("not flagged") entry per
    image rather than a single scalar ``False``. This presumably matches the
    per-image ``has_nsfw_concept`` list the real checker returns; verify
    against the installed diffusers version. Extra keyword arguments
    (e.g. ``clip_input``) are accepted and ignored.
    """
    return images, [False] * len(images)
# Replace the pipeline's safety checker with the no-op `dummy`, disabling NSFW filtering.
pipe.safety_checker = dummy

In [None]:
#@markdown #Stable Diffusion

#@markdown #####Run the Stable Diffusion pipeline with interactive UI Demo on Gradio

#@markdown ---

import gradio as gr

def inference(prompt,Height,Width,Steps,Scale, num_samples):
    """Generate ``num_samples`` images for ``prompt`` with the loaded ``pipe``.

    Gradio sliders may deliver numeric values as floats, so the integer-valued
    settings are cast to ``int`` before use (``[prompt] * 2.0`` would raise).

    Returns:
        A list of the generated images, for the Gradio gallery.
    """
    with torch.autocast("cuda"):
        result = pipe(
            [prompt] * int(num_samples),
            height=int(Height),
            width=int(Width),
            num_inference_steps=int(Steps),
            guidance_scale=Scale,
        )
    return list(result.images)

with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            prompt = gr.Textbox(label="prompt")
            # Explicit ranges: gr.Slider's default range (0-100) cannot even
            # hold the 512px default, and it allows 0 and sizes that are not
            # multiples of 8 (presumably rejected by the pipeline).
            Height = gr.Slider(label="Height", minimum=256, maximum=1024, step=64, value=512)
            Width = gr.Slider(label="Width", minimum=256, maximum=1024, step=64, value=512)
            Steps = gr.Slider(label="Steps", minimum=1, maximum=150, step=1, value=50)
            samples = gr.Slider(label="Samples", minimum=1, maximum=4, step=1, value=1)
            Scale = gr.Slider(label="Scale", minimum=1, maximum=20, step=0.5, value=8)
            run = gr.Button(value="Run")
        with gr.Column():
            gallery = gr.Gallery(show_label=False)

    run.click(inference, inputs=[prompt, Height, Width, Steps, Scale, samples], outputs=gallery)

demo.launch()