<a href="https://colab.research.google.com/github/AgileDevArt/stable-diffusion-plugins/blob/main/stable_diffusion_backend.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Stable Diffusion Backend

Original Notebook by [blueturtleai](https://github.com/blueturtleai/gimp-stable-diffusion)

By using this Notebook, you agree to the following Terms of Use, and license:

**Stability.AI Model Terms of Use**

This model is open access and available to all, with a CreativeML OpenRAIL-M license further specifying rights and usage.

The CreativeML OpenRAIL License specifies:

You can't use the model to deliberately produce nor share illegal or harmful outputs or content.
CompVis claims no rights on the outputs you generate, you are free to use them and are accountable for their use which must not go against the provisions set in the license.
You may re-distribute the weights and use the model commercially and/or as a service. If you do, please be aware that you have to include the same use restrictions as the ones in the license and share a copy of the CreativeML OpenRAIL-M with all your users. Please read the license entirely and carefully.


Please read the full license here: https://huggingface.co/spaces/CompVis/stable-diffusion-license

In [None]:
#@title NVIDIA GPU
import subprocess
# Print each GPU's name plus total/free memory. On a runtime without an NVIDIA GPU the
# nvidia-smi binary is absent; report that instead of raising and halting "Run all".
try:
    sub_p_res = subprocess.run(['nvidia-smi', '--query-gpu=name,memory.total,memory.free', '--format=csv,noheader'], stdout=subprocess.PIPE).stdout.decode('utf-8')
except FileNotFoundError:
    sub_p_res = "nvidia-smi not found: no NVIDIA GPU is available in this runtime."
print(sub_p_res)

In [None]:
#@title Mount Google Drive
from google.colab import drive # type: ignore

try:
    drive_path = "/content/drive"
    drive.mount(drive_path, force_remount=False)
except Exception as e:
    # Report the failure instead of stopping the notebook, but keep the reason visible.
    # Catching Exception (not a bare except) lets KeyboardInterrupt still stop the cell.
    print(f"...error mounting drive or with drive path variables: {e}")



In [None]:
#@title Set Model Path
#@markdown **Hints**
#@markdown - It is recommended to use the default path. That way you don't have to adjust the path manually every time. 
#@markdown - If the model file doesn't exist at this location, it is automatically downloaded from Huggingface. Please make sure you have enough free space on your drive (about 4 GB).
#@markdown - **For an individual path:**
#@markdown  - Click on the folder symbol on the left. Open the "drive/MyDrive" folder and navigate to the model file which you uploaded before. Select the model file, click on the three dots and select "copy path". Close the file explorer via the cross. 
#@markdown  - Insert the copied path into the field. Remove the filename and the last "/" at the end. The path should now look for example like this ```/content/drive/MyDrive/AI/models```.

import os

# Resolve where checkpoints are stored and where generated images are written.
# Both default to folders on the Google Drive mounted at /content/drive.
print("Local Path Variables:\n")

models_path_gdrive = "/content/drive/MyDrive/AI/models/" #@param {type:"string"}
output_path_gdrive = "/content/drive/MyDrive/AI/StableDiffusion"
# Strip whitespace that easily sneaks in when a path is pasted into the form field.
models_path = models_path_gdrive.strip()
output_path = output_path_gdrive.strip()

os.makedirs(models_path, exist_ok=True)
os.makedirs(output_path, exist_ok=True)

print(f"models_path: {models_path}")
print(f"output_path: {output_path}")

In [None]:
#@title Setup Environment

setup_environment = True
print_subprocess = True

if setup_environment:
    import subprocess, sys, time
    print("Setting up environment...")
    start_time = time.time()
    # Run pip through the kernel's own interpreter so packages are installed into the
    # environment this notebook actually imports from (a bare `pip` on PATH may differ).
    pip_install = [sys.executable, '-m', 'pip', 'install']
    all_process = [
        pip_install + ['torch==1.12.1+cu113', 'torchvision==0.13.1+cu113', '--extra-index-url', 'https://download.pytorch.org/whl/cu113'],
        pip_install + ['omegaconf==2.2.3', 'einops==0.4.1', 'pytorch-lightning==1.7.4', 'torchmetrics==0.9.3', 'torchtext==0.13.1', 'transformers==4.21.2', 'kornia==0.6.7', 'diffusers', 'xformers', 'triton==2.0.0.dev20221120'],
        pip_install + ['accelerate', 'ftfy', 'jsonmerge', 'matplotlib', 'resize-right', 'timm', 'torchdiffeq'],
        pip_install + ['pyngrok', 'flask-cloudflared'],
        ['git', 'lfs', 'install'],
    ]
    for process in all_process:
        # Merge stderr into stdout so install errors show up in the cell output.
        result = subprocess.run(process, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
        running = result.stdout.decode('utf-8', errors='replace')
        if print_subprocess:
            print(running)
        if result.returncode != 0:
            print(f"WARNING: command failed with exit code {result.returncode}: {' '.join(process)}")

    end_time = time.time()
    print(f"Environment set up in {end_time-start_time:.0f} seconds")

In [None]:
import torch
from torch import autocast
from diffusers import DiffusionPipeline
from diffusers import StableDiffusionPipeline
from diffusers import StableDiffusionImg2ImgPipeline
from diffusers import StableDiffusionInpaintPipeline
from diffusers import StableDiffusionUpscalePipeline

#@title **Select and Load Model**
#@markdown **Hints**

#@markdown If the model file doesn't exist at the location you selected above, it will automatically be downloaded from Huggingface. 
#@markdown Please make sure you have enough free space on your drive (about 4 GB). You need a Huggingface account and must have accepted the model's terms of service; otherwise the file cannot be downloaded. The download also requires a Huggingface token:
#@markdown  - Log in to Huggingface and select "Settings/Access Tokens" on the left.
#@markdown  - Click on "New Token", enter a name, select "Read" as role, click on create and copy the token.
model_checkpoint =  "stable-diffusion-2-base" #@param ["stable-diffusion-2-base", "stable-diffusion-2", "stable-diffusion-2-inpainting", "stable-diffusion-x4-upscaler", "stable-diffusion-inpainting", "stable-diffusion-v1-5"]

# When True, the repository's "fp16" branch is cloned and weights are loaded as float16.
half_precision = False

# Download location and login requirement for every selectable checkpoint.
# Each URL has the form https://huggingface.co/<organisation>/<checkpoint name>.
model_map = {
    name: {
        'url': f'https://huggingface.co/{org}/{name}',
        'requires_login': True,
    }
    for org, name in (
        ('stabilityai', 'stable-diffusion-2-base'),
        ('stabilityai', 'stable-diffusion-2'),
        ('stabilityai', 'stable-diffusion-2-inpainting'),
        ('stabilityai', 'stable-diffusion-x4-upscaler'),
        ('runwayml', 'stable-diffusion-inpainting'),
        ('runwayml', 'stable-diffusion-v1-5'),
    )
}

# Resolve the local checkpoint folder, cloning it from Huggingface on first use.
import getpass
import subprocess
from urllib.parse import quote

ckpt_path = os.path.join(models_path, model_checkpoint)
if os.path.exists(ckpt_path):
    # Reuse the existing clone as-is. (This branch used to print `running`, which was
    # stale output from the setup cell, or a NameError on a fresh kernel.)
    print(f"{ckpt_path} exists...using local copy")
elif 'url' in model_map[model_checkpoint]:
    public_url = model_map[model_checkpoint]['url']
    url = public_url
    token = ""

    # CLI dialogue to authenticate download
    if model_map[model_checkpoint]['requires_login']:
        print("This model requires an authentication token")
        print("Please ensure you have accepted its terms of service before continuing.\n")
        print("Press enter after you inserted your username")

        username = input("What is your huggingface username?:")
        print("\n")
        print("Press enter after you inserted your token")
        # getpass hides the token so it does not end up in the saved notebook output.
        token = getpass.getpass("What is your huggingface token?:")
        print("\n")

        _, path = url.split("https://")
        # Percent-encode the credentials so characters like '@' or ':' cannot break the URL.
        url = f"https://{quote(username, safe='')}:{quote(token, safe='')}@{path}"

    # contact server for model
    print(f"Attempting to download {model_checkpoint}...this may take a while")
    clone_cmd = ['git', '-C', models_path, 'clone']
    if half_precision:
        clone_cmd += ['-b', 'fp16']
    clone_cmd.append(url)
    # Capture stderr too: git writes its status and error messages there.
    result = subprocess.run(clone_cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
    running = result.stdout.decode('utf-8', errors='replace')
    if token:
        # git can echo the remote URL (credentials included) in error messages; redact it.
        running = running.replace(quote(token, safe=''), '***').replace(token, '***')
    print(running)

    if result.returncode != 0 or not os.path.isdir(ckpt_path):
        raise RuntimeError(f"Downloading {model_checkpoint} into {ckpt_path} failed; see the git output above.")

    if url != public_url:
        # git records the clone URL as origin in .git/config; replace it so the token is
        # not left on disk next to the model.
        subprocess.run(['git', '-C', ckpt_path, 'remote', 'set-url', 'origin', public_url],
                       stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
else:
    raise Exception(f"Please download model checkpoint and place in {os.path.join(models_path, model_checkpoint)}")

print(f"Using model: {ckpt_path}")
device_type = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(device_type)
torch_dtype = torch.float16 if half_precision else torch.float32
if torch.cuda.is_available():
    # Let cuDNN autotune kernels and allow TF32 matmuls (faster, slightly lower precision).
    torch.backends.cudnn.benchmark = True
    torch.backends.cuda.matmul.allow_tf32 = True
print(f"Using torch_dtype {torch_dtype} with {device}")

# One slot per generation mode. A slot stays None when the selected checkpoint does not
# provide that mode; the request handler answers 406 for a None pipeline.
text2img = None
img2img = None
inpaint = None
upscale = None

if model_checkpoint.endswith('upscaler'):
    upscale = StableDiffusionUpscalePipeline.from_pretrained(ckpt_path, torch_dtype=torch_dtype)
elif model_checkpoint.endswith('inpainting'):
    inpaint = StableDiffusionInpaintPipeline.from_pretrained(ckpt_path, torch_dtype=torch_dtype)
elif model_checkpoint.startswith('stable-diffusion-2'):
    text2img = DiffusionPipeline.from_pretrained(ckpt_path, torch_dtype=torch_dtype)
else:  # sd_v1
    text2img = StableDiffusionPipeline.from_pretrained(ckpt_path, torch_dtype=torch_dtype)
    # Build img2img from the already-loaded components instead of loading the weights twice.
    img2img = StableDiffusionImg2ImgPipeline(**text2img.components)


def _prepare_pipeline(pipe):
    """Move `pipe` to `device_type` and, on CUDA, enable memory-saving attention.

    Args:
        pipe: a diffusers pipeline, or None.
    Returns:
        The moved pipeline, or None when `pipe` is None.
    """
    if pipe is None:
        return None
    pipe = pipe.to(device_type)
    if torch.cuda.is_available():
        try:
            pipe.enable_xformers_memory_efficient_attention()
        except Exception as e:
            # xformers is optional: keep the default attention rather than aborting the cell.
            print(f"xformers memory efficient attention not enabled: {e}")
        # NOTE(review): in some diffusers versions attention slicing replaces the xformers
        # attention processor; confirm which of the two should win for the pinned version.
        pipe.enable_attention_slicing()
    return pipe


text2img = _prepare_pipeline(text2img)
img2img = _prepare_pipeline(img2img)
inpaint = _prepare_pipeline(inpaint)
upscale = _prepare_pipeline(upscale)

def get_latents(seed, height, width):
    """Return seeded standard-normal noise of shape (1, 4, height, width) on `device`.

    The same seed always yields the same tensor, which makes renders reproducible.
    `height` and `width` are latent sizes (the callers pass pixel size // 8, or the raw
    size for the upscaler).
    """
    seeded_rng = torch.Generator(device=device).manual_seed(seed)
    latent_shape = (1, 4, height, width)
    return torch.randn(latent_shape, generator=seeded_rng, device=device)

def render_image(pipe, prompt, init_image, mask_image, strength, steps, scale, height, width, seed):
    """Run one of the loaded pipelines and return its list of output images.

    Args:
        pipe: one of the module-level pipelines text2img, img2img, inpaint or upscale.
        prompt: text prompt.
        init_image: source image (img2img / inpaint / upscale); ignored by text2img.
        mask_image: inpainting mask; only used by inpaint.
        strength: img2img denoising strength; only used by img2img.
        steps: number of inference steps.
        scale: classifier-free guidance scale.
        height, width: output size in pixels (the upscaler receives them as latent size).
        seed: seed for the initial noise.

    Returns:
        The `.images` list produced by the pipeline.

    Raises:
        ValueError: `pipe` is None or is not one of the loaded pipelines.
    """
    if pipe is None:
        raise ValueError("render_image: no pipeline is loaded for this mode")

    with torch.inference_mode():
        if pipe is text2img:
            # Pass height/width so the pipeline's expected latent shape matches `latents`.
            return pipe(prompt,
                        height=height,
                        width=width,
                        guidance_scale=scale,
                        num_inference_steps=steps,
                        latents=get_latents(seed, height // 8, width // 8)).images

        if pipe is img2img:
            # img2img derives its starting latents from init_image and accepts no
            # `latents` argument; seed its noise through a generator instead.
            generator = torch.Generator(device=device).manual_seed(seed)
            return pipe(prompt,
                        init_image,
                        strength=strength,
                        guidance_scale=scale,
                        num_inference_steps=steps,
                        generator=generator).images

        if pipe is inpaint:
            return pipe(prompt,
                        init_image,
                        mask_image=mask_image,
                        height=height,
                        width=width,
                        guidance_scale=scale,
                        num_inference_steps=steps,
                        latents=get_latents(seed, height // 8, width // 8)).images

        if pipe is upscale:
            # The upscaler's latents have the size of the (low-res) input image.
            return pipe(prompt,
                        init_image,
                        guidance_scale=scale,
                        num_inference_steps=steps,
                        latents=get_latents(seed, height, width)).images

    raise ValueError("render_image: unknown pipeline")

In [None]:
#@title Waiting for plugin requests

import gc, math, os, pathlib, sys, time, random
import json
import base64
import subprocess
import torch
import numpy as np
import torchvision.transforms as T
from torch import autocast
from einops import rearrange, repeat
from torchvision.utils import make_grid
from flask import Flask, Response, request, abort, make_response
from flask_cloudflared import run_with_cloudflared
from IPython import display
from io import BytesIO
from PIL import Image, ImageOps
from types import SimpleNamespace

# Prompts rendered by render_image_batch(); each plugin request replaces this list.
prompts = []

def DeforumArgs():
    """Return the default render settings as an ordered dict of setting name -> value.

    The dict is later wrapped in a SimpleNamespace and individual fields are overwritten
    per request. Building it calls get_output_folder(), which creates the batch folder.
    """
    batch_name = "PLUGIN"
    return dict(
        # output size, filled in per request
        W=0,
        H=0,
        # sampling
        seed=-1,
        sampler='klms',
        steps=0,
        scale=0,
        ddim_eta=0.0,
        dynamic_threshold=None,
        static_threshold=None,
        pipe=None,
        # saving / display
        save_samples=True,
        save_settings=False,
        display_samples=False,
        # batching and file naming
        n_batch=1,
        batch_name=batch_name,
        filename_format="{timestring}_{index}_{prompt}.png",
        make_grid=False,
        grid_rows=2,
        outdir=get_output_folder(output_path, batch_name),
        # init image
        use_init=False,
        strength=0,
        init_image=None,
        # Whiter areas of the mask are areas that change more
        use_mask=False,
        invert_mask=False,
        mask_image=None,
        # per-run values
        prompt="",
        timestring="",
    )

def sanitize(prompt):
    """Reduce a prompt to a filename-safe fragment.

    Only ASCII letters and spaces survive; each surviving space becomes an underscore.
    """
    allowed = frozenset('abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ ')
    return ''.join('_' if ch == ' ' else ch for ch in prompt if ch in allowed)

def get_output_folder(output_path, batch_folder):
    """Create (if needed) and return <output_path>/<YYYY-MM>[/<batch_folder>].

    The batch sub-folder is skipped when `batch_folder` is the empty string.
    """
    parts = [output_path, time.strftime('%Y-%m')]
    if batch_folder != "":
        parts.append(batch_folder)
    folder = os.path.join(*parts)
    os.makedirs(folder, exist_ok=True)
    return folder

def load_img(image, shape, use_alpha_as_mask=False, invert_mask=False):
    """Load an init image and a matching mask, both resized to `shape`.

    Args:
        image: a PIL image, or a filesystem path to one.
        shape: target size as (width, height), as PIL's resize expects.
        use_alpha_as_mask: read the alpha channel of the image as the mask image.
        invert_mask: when False (the default) the alpha mask is inverted, so transparent
            areas become white; when True the alpha channel is used as-is.

    Returns:
        (rgb_image, mask_image), both of size `shape`. When no alpha mask is requested the
        mask is an all-black RGB image. Returns (None, None) if `image` is a path that does
        not exist.
    """
    if isinstance(image, str):
        if not os.path.exists(image):
            return None, None
        image = Image.open(image)

    if use_alpha_as_mask:
        image = image.convert('RGBA')
        mask_image = image.getchannel('A')  # single-band 'L' image
        if not invert_mask:
            mask_image = ImageOps.invert(mask_image)
        # Resize with the image so pipelines that combine image and mask see equal sizes.
        mask_image = mask_image.resize(shape, resample=Image.LANCZOS)
    else:
        # PIL sizes are (width, height); `shape` already has that order.
        mask_image = Image.new("RGB", shape)

    image = image.convert('RGB')
    image = image.resize(shape, resample=Image.LANCZOS)

    return image, mask_image

def render_image_batch(args):
    """Render every prompt in the module-level `prompts` list with the settings in `args`.

    For each prompt and each of `args.n_batch` batches, every init image source is loaded
    with load_img() and passed to render_image(). The resulting images are optionally
    saved to `args.outdir`, displayed, and combined into a grid.

    Args:
        args: the SimpleNamespace of render settings (see DeforumArgs). `args.prompt` and
            `args.prompts` are updated here; `args.init_image` is not replaced by the
            loaded image.

    Returns:
        The list of images from the last render_image() call, or [] if nothing was rendered.

    Raises:
        FileNotFoundError: `args.use_init` is set but `args.init_image` is an empty string.
    """
    args.prompts = {k: f"{v:05d}" for v, k in enumerate(prompts)}

    index = 0

    # Collect init image sources: an in-memory image, a URL, a single file, or every
    # png/jpg/jpeg in a folder.
    init_array = []
    if args.use_init:
        if not isinstance(args.init_image, str):
            init_array = [args.init_image]
        elif args.init_image == "":
            raise FileNotFoundError("No path was given for init_image")
        elif args.init_image.startswith('http://') or args.init_image.startswith('https://'):
            init_array.append(args.init_image)
        elif not os.path.isfile(args.init_image):
            if args.init_image[-1] != "/": # avoids path error by adding / to end if not there
                args.init_image += "/"
            for file_name in sorted(os.listdir(args.init_image)): # iterates dir and appends images to init_array
                if file_name.split(".")[-1] in ("png", "jpg", "jpeg"):
                    init_array.append(args.init_image + file_name)
        else:
            init_array.append(args.init_image)
    else:
        init_array = [""]

    # when doing large batches don't flood browser with images
    clear_between_batches = args.n_batch >= 32
    results = []

    for iprompt, prompt in enumerate(prompts):
        args.prompt = prompt
        print(f"Prompt {iprompt+1} of {len(prompts)}")
        print(f"{args.prompt}")

        all_images = []

        for batch_index in range(args.n_batch):
            if clear_between_batches and batch_index % 32 == 0:
                display.clear_output(wait=True)
            print(f"Batch {batch_index+1} of {args.n_batch}")

            for init_source in init_array: # iterates the init images
                # The mask structure is white for inpainting and black for keeping as is.
                # Load into locals rather than overwriting args.init_image: the caller
                # reuses `args` for every image of a request, and re-loading the already
                # RGB-converted image would drop the alpha channel that carries the mask.
                init_image, mask_image = load_img(init_source,
                                                  shape=(args.W, args.H),
                                                  use_alpha_as_mask=args.use_mask,
                                                  invert_mask=args.invert_mask)

                if args.save_samples:
                    if init_image is not None:
                        init_image.save(os.path.join(output_path, "init.png"))
                    if mask_image is not None:
                        mask_image.save(os.path.join(output_path, "mask.png"))

                results = render_image(args.pipe, args.prompt, init_image, mask_image, args.strength, args.steps, args.scale, args.H, args.W, args.seed)

                for image in results:
                    if args.make_grid:
                        all_images.append(T.functional.pil_to_tensor(image))
                    if args.save_samples:
                        if args.filename_format == "{timestring}_{index}_{prompt}.png":
                            filename = f"{args.timestring}_{index:05}_{sanitize(prompt)[:160]}.png"
                        else:
                            filename = f"{args.timestring}_{index:05}_{args.seed}.png"
                        image.save(os.path.join(args.outdir, filename))
                    if args.display_samples:
                        display.display(image)
                    index += 1

        if args.make_grid:
            # At least one column, even when there are fewer images than grid rows.
            grid = make_grid(all_images, nrow=max(1, len(all_images) // args.grid_rows))
            grid = rearrange(grid, 'c h w -> h w c').cpu().numpy()
            filename = f"{args.timestring}_{iprompt:05d}_grid_{args.seed}.png"
            grid_image = Image.fromarray(grid.astype(np.uint8))
            grid_image.save(os.path.join(args.outdir, filename))
            display.clear_output(wait=True)
            display.display(grid_image)

    return results

# Module-wide render settings; every request overwrites the fields it needs. The
# timestring (taken once, when this cell runs) prefixes every saved filename.
args = SimpleNamespace(**{**DeforumArgs(), 'timestring': time.strftime('%Y%m%d%H%M%S')})

# Requests whose api_version differs from this are rejected with HTTP 405.
API_VERSION = 5

app = Flask(__name__)

@app.route("/api/generate", methods=["POST"])
def generateImages():
    """Handle one plugin render request and return the generated images as JSON.

    The JSON body must contain api_version, mode, init_strength, prompt_strength, steps,
    width, height, prompt, seed and image_count. It must also contain a base64-encoded
    init_img for MODE_IMG2IMG, MODE_INPAINTING and MODE_UPSCALING.

    Aborts with 405 when api_version != API_VERSION and with 406 when no pipeline is
    loaded for the requested mode. On success the body is
    {"images": [{"seed": int, "image": <base64 PNG>}, ...]} with Content-Type
    application/json.
    """
    data = json.loads(request.data.decode("utf-8"))

    api_version = int(data["api_version"]) if "api_version" in data else 0
    if api_version != API_VERSION:
        abort(405)

    print("\n")
    print("Parameters sent from Plugin")
    print("mode: " + data["mode"] + ", init_strength: " + str(data["init_strength"]) + ", prompt_strength: " + str(data["prompt_strength"]) + ", steps: " + str(data["steps"]) + ", width: " + str(data["width"]) + ", height: " + str(data["height"]) + ", prompt: " + data["prompt"] + ", seed: " + str(data["seed"]) + ", api_version: " + str(data["api_version"]))
    print("\n")

    # Round the requested size down to a multiple of 64.
    args.W, args.H = map(lambda x: x - x % 64, (int(data["width"]), int(data["height"])))
    args.strength = max(0.0, min(1.0, float(data["init_strength"])))
    args.scale = float(data["prompt_strength"])
    args.steps = int(data["steps"])
    args.use_init = False
    args.use_mask = False
    args.init_image = None
    args.mask_image = None
    args.pipe = None

    mode = data["mode"]
    if mode in ("MODE_IMG2IMG", "MODE_INPAINTING", "MODE_UPSCALING"):
        img_stream = BytesIO(base64.b64decode(data["init_img"]))
        args.init_image = Image.open(img_stream)

        args.use_init = True
        args.pipe = img2img

        if mode == "MODE_UPSCALING":
            args.strength = 0.0
            args.pipe = upscale

        if mode == "MODE_INPAINTING":
            # The mask comes from the init image's alpha channel (see load_img).
            args.strength = 0.0
            args.use_mask = True
            args.pipe = inpaint
    else:
        args.pipe = text2img

    # The selected checkpoint does not provide a pipeline for this mode.
    if args.pipe is None:
        abort(406)

    global prompts
    prompts = [data["prompt"]]

    args.prompt = prompts
    imgs_return = []

    for _ in range(int(data["image_count"])):
        # clean up unused memory
        gc.collect()
        torch.cuda.empty_cache()

        # -1 requests a random seed; stay within the unsigned 32-bit range.
        requested_seed = int(data["seed"])
        args.seed = requested_seed if requested_seed != -1 else random.randint(0, 2**32 - 1)

        print("Parameters used for generating")
        print(args)

        img = render_image_batch(args)[0]

        img_data = BytesIO()
        img.save(img_data, format="PNG")
        img_encoded = base64.b64encode(img_data.getvalue()).decode("utf-8")

        imgs_return.append({"seed": args.seed, "image": img_encoded})

    response = make_response(json.dumps({"images": imgs_return}))
    # Set the real Content-Type header (a header literally named "mimetype" means nothing to clients).
    response.mimetype = "application/json"
    return response

# Start the plugin backend. run_with_cloudflared() comes from flask_cloudflared and is
# presumably what exposes the local Flask server through a public Cloudflare tunnel URL
# that the plugin connects to; confirm against that package's documentation.
# app.run() then serves /api/generate requests; it normally keeps this cell running.
run_with_cloudflared(app)
app.run()