<a href="https://colab.research.google.com/github/R3gm/SD_diffusers_interactive/blob/main/Stable_diffusion_interactive_notebook.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Stable Diffusion Interactive Notebook 📓 🤖

| Description | Link |
| ----------- | ---- |
| 🎉 Repository | [![GitHub Repository](https://img.shields.io/github/stars/R3gm/SD_diffusers_interactive?style=social)](https://github.com/R3gm/SD_diffusers_interactive) |


- Compel added for SD 1.5 (prompt weighting)
- ControlNet 1.1 for SD 1.5
- SDXL models only support txt2img
- LoRA usually doesn't work correctly alongside ControlNet
- More functions, more bugs; less than 10 words, more laughs




A widgets-based interactive notebook for Google Colab that lets users generate AI images from prompts (Text2Image) using [Stable Diffusion (by Stability AI, Runway & CompVis)](https://en.wikipedia.org/wiki/Stable_Diffusion).

This notebook aims to be an alternative to WebUIs while offering a simple and lightweight GUI that lets anyone get started with Stable Diffusion.

Uses Stable Diffusion, [HuggingFace](https://huggingface.co/) Diffusers and [Jupyter widgets](https://github.com/jupyter-widgets/ipywidgets).

<br/>

Based on redromnon's repository

[Original GitHub](https://github.com/redromnon/stable-diffusion-interactive-notebook)

In [None]:
#@title 👇 Installing dependencies { display-mode: "form" }
#@markdown ---
#@markdown Make sure to select **GPU** as the runtime type:<br/>
#@markdown *Runtime->Change Runtime Type->Under Hardware accelerator, select GPU*
#@markdown
#@markdown ---

# Install the Python dependencies in a single pip invocation.
# NOTE(review): `safetensors` is listed twice below (pinned to 0.3.3 and unpinned);
# confirm which requirement is intended.
!pip install -q omegaconf==2.3.0 torch git+https://github.com/huggingface/diffusers.git git+https://github.com/damian0815/compel.git invisible_watermark  transformers accelerate scipy safetensors==0.3.3 xformers safetensors mediapy ipywidgets==7.7.1 controlnet_aux==0.0.6 mediapipe==0.10.1 pytorch-lightning asdff

# System packages: git-lfs, plus aria2 (the `aria2c` downloader used by the download cell).
!apt install git-lfs
!git lfs install
!apt -y install -qq aria2

`RESTART THE RUNTIME` before executing the next cell.

In [None]:
#@title 👇 Download Model: Please provide a link for the Civitai API, Google Drive, or Hugging Face. { form-width: "20%", display-mode: "form" }
import os
%cd /content

def download_things(directory, url, hf_token=""):
    """Download ``url`` into ``directory`` with a tool chosen from the URL's host.

    - ``drive.google.com`` links are fetched with ``gdown --fuzzy``.
    - ``huggingface.co`` links are fetched with ``aria2c`` and an
      ``Authorization: Bearer <hf_token>`` header.
    - Every other URL is fetched with ``aria2c`` and no extra header.

    The ``!`` lines are IPython shell escapes, so this function only works
    inside a notebook / IPython session.

    Args:
        directory: Destination directory for the downloaded file.
        url: Source URL; surrounding whitespace is stripped.
        hf_token: Token inserted into the Authorization header for
            Hugging Face URLs. The header is sent even when this is empty.
    """
    url = url.strip()

    if "drive.google.com" in url:
        # No output path is passed to gdown, so run it from inside the target
        # directory and restore the previous working directory afterwards.
        original_dir = os.getcwd()
        os.chdir(directory)
        !gdown --fuzzy {url}
        os.chdir(original_dir)
    elif "huggingface.co" in url:
        # Rewrite "/blob/" page links to "/resolve/" links
        # (presumably the direct-download form — TODO confirm).
        if "/blob/" in url:
            url = url.replace("/blob/", "/resolve/")
        user_header = f'"Authorization: Bearer {hf_token}"'
        # The output file name is the last path segment of the URL.
        !aria2c --console-log-level=error --summary-interval=10 --header={user_header} -c -x 16 -k 1M -s 16 {url} -d {directory}  -o {url.split('/')[-1]}
    else:
        # No -o here: aria2c is left to choose the output file name.
        !aria2c --console-log-level=error --summary-interval=10 -c -x 16 -k 1M -s 16 -d {directory} {url}

def get_model_list(directory_path):
    """List the model files found directly inside ``directory_path``.

    A file counts as a model when its extension (compared case-insensitively)
    is one of ``.ckpt``, ``.pt``, ``.pth``, ``.safetensors`` or ``.bin``.
    Entries are returned in sorted file-name order so the result does not
    depend on the arbitrary order of ``os.listdir``.

    Args:
        directory_path: Directory to scan (not recursive).

    Returns:
        A list of ``(name_without_extension, file_path)`` tuples.

    Raises:
        OSError: If ``directory_path`` cannot be listed.
    """
    valid_extensions = {'.ckpt', '.pt', '.pth', '.safetensors', '.bin'}
    model_list = []

    for filename in sorted(os.listdir(directory_path)):
        name_without_extension, extension = os.path.splitext(filename)
        # Lower-case the extension so files such as "model.CKPT" are not skipped.
        if extension.lower() not in valid_extensions:
            continue
        file_path = os.path.join(directory_path, filename)
        model_list.append((name_without_extension, file_path))
        # Separate name and path; they used to be printed glued together.
        print('\033[34mFILE: ' + name_without_extension + ' -> ' + file_path + '\033[0m')
    return model_list

def process_string(input_string):
    """Turn a ``"owner/name"`` string into a ``(name, "owner/name")`` tuple.

    Returns ``None`` when the string does not split into exactly two parts
    on ``'/'``.
    """
    pieces = input_string.split('/')
    if len(pieces) != 2:
        return None
    # The second piece is the display name; the full string is kept as the value.
    return (pieces[1], input_string)

# Local folders (relative to the current working directory) that receive the downloads.
directory_models = 'models'
os.makedirs(directory_models, exist_ok=True)
directory_loras = 'loras'
os.makedirs(directory_loras, exist_ok=True)
directory_vaes = 'vaes'
os.makedirs(directory_vaes, exist_ok=True)

#@markdown ---
#@markdown - **Download a Model**
download_model = "https://civitai.com/api/download/models/125771" # @param {type:"string"}
#@markdown - For SDXL models, only diffuser format models are supported, and you only need the repository name.
load_diffusers_format_model = 'Linaqruf/animagine-xl' # @param {type:"string"}
#@markdown - **Download a VAE**
download_vae = "https://huggingface.co/WarriorMama777/OrangeMixs/blob/main/VAEs/orangemix.vae.pt" # @param {type:"string"}
#@markdown - **Download a LoRA**
download_lora = "https://civitai.com/api/download/models/97655" # @param {type:"string"}
#@markdown ---
#@markdown **HF TOKEN** - If you need to download your private model from Hugging Face, input your token here.
hf_token = ""  # @param {type:"string"}
#@markdown
#@markdown ---

# Fetch the model, VAE and LoRA selected in the form above into their folders.
download_things(directory_models, download_model, hf_token)
download_things(directory_vaes, download_vae, hf_token)
download_things(directory_loras, download_lora, hf_token)


# TI embeddings tend to be more compatible in safetensors format; converting them to safetensors may help
directory_embeds = 'embedings'
os.makedirs(directory_embeds, exist_ok=True)
# Embedding files that are always downloaded, independent of the form values.
download_embeds = [
    'https://huggingface.co/datasets/Nerfgun3/bad_prompt/resolve/main/bad_prompt.pt',
    'https://huggingface.co/datasets/Nerfgun3/bad_prompt/blob/main/bad_prompt_version2.pt',
    'https://huggingface.co/sayakpaul/EasyNegative-test/blob/main/EasyNegative.safetensors',
    ]
for url_embed in download_embeds:
    download_things(directory_embeds, url_embed, hf_token)
embed_list = get_model_list(directory_embeds)

# Build the (name, path) lists from what is now on disk.
model_list = get_model_list(directory_models)
# A non-empty value with exactly one '/' is appended to the model list as
# (name, "owner/name") alongside the local files.
if load_diffusers_format_model.strip() != "" and load_diffusers_format_model.count('/') == 1:
    model_list.append(process_string(load_diffusers_format_model))
# The LoRA and VAE lists start with a ("None", "None") entry.
lora_model_list = get_model_list(directory_loras)
lora_model_list.insert(0, ("None","None"))
vae_model_list = get_model_list(directory_vaes)
vae_model_list.insert(0, ("None","None"))



print('\033[33m🏁 Download finished.\033[0m')

### SECOND PART ###
from diffusers import StableDiffusionPipeline
import gc
import numpy as np
import PIL.Image
from diffusers import (ControlNetModel, DiffusionPipeline,
                       StableDiffusionControlNetPipeline,
                       StableDiffusionControlNetInpaintPipeline,
                       UniPCMultistepScheduler)
import gc
import torch
from controlnet_aux import (CannyDetector, ContentShuffleDetector, HEDdetector,
                            LineartAnimeDetector, LineartDetector,
                            MidasDetector, MLSDdetector, NormalBaeDetector,
                            OpenposeDetector, PidiNetDetector)
from transformers import pipeline
from controlnet_aux.util import HWC3, ade_palette
from transformers import AutoImageProcessor, UperNetForSemanticSegmentation
import cv2



# =====================================
# Utils preprocessor
# =====================================
def resize_image(input_image, resolution, interpolation=None):
    """Resize an H x W x C array so that its longer side is about ``resolution``.

    The aspect ratio is kept approximately: both target sides are rounded to
    the nearest multiple of 64. Each side is clamped to at least 64 so that
    very elongated images (or very small resolutions) never round down to a
    zero-sized target, which ``cv2.resize`` rejects.

    Args:
        input_image: Array whose ``shape`` unpacks into (height, width, channels).
        resolution: Desired length of the longer side before rounding.
        interpolation: OpenCV interpolation flag. When ``None``,
            ``cv2.INTER_LANCZOS4`` is used for upscaling and
            ``cv2.INTER_AREA`` otherwise.

    Returns:
        The resized image as returned by ``cv2.resize``.
    """
    height, width, _ = input_image.shape
    scale = float(resolution) / max(height, width)
    # Snap to multiples of 64, never below 64 (guards against a 0-pixel side).
    target_h = max(64, int(np.round(float(height) * scale / 64.0)) * 64)
    target_w = max(64, int(np.round(float(width) * scale / 64.0)) * 64)
    if interpolation is None:
        interpolation = cv2.INTER_LANCZOS4 if scale > 1 else cv2.INTER_AREA
    # cv2.resize takes the target size as (width, height).
    return cv2.resize(input_image, (target_w, target_h), interpolation=interpolation)

class DepthEstimator:
    """Callable wrapper around the transformers ``'depth-estimation'`` pipeline."""

    def __init__(self):
        self.model = pipeline('depth-estimation')

    def __call__(self, image: np.ndarray, **kwargs) -> PIL.Image.Image:
        """Return a depth image for ``image``.

        Recognised keyword arguments (both default to 512):
            detect_resolution: resolution the input is resized to before
                the depth model runs.
            image_resolution: resolution the depth output is resized to.
        """
        detect_res = kwargs.pop('detect_resolution', 512)
        output_res = kwargs.pop('image_resolution', 512)

        # Normalise the input with HWC3 and resize it for the model.
        model_input = resize_image(HWC3(np.array(image)), resolution=detect_res)
        prediction = self.model(PIL.Image.fromarray(model_input))

        # The pipeline result is indexed by 'depth'; convert it back through
        # HWC3 and resize to the requested output resolution.
        depth = HWC3(np.array(prediction['depth']))
        return PIL.Image.fromarray(resize_image(depth, resolution=output_res))

class ImageSegmentor:
    """Semantic segmentation with the ``openmmlab/upernet-convnext-small`` checkpoint."""

    def __init__(self):
        checkpoint = 'openmmlab/upernet-convnext-small'
        self.image_processor = AutoImageProcessor.from_pretrained(checkpoint)
        self.image_segmentor = UperNetForSemanticSegmentation.from_pretrained(
            checkpoint)

    @torch.inference_mode()
    def __call__(self, image: np.ndarray, **kwargs) -> PIL.Image.Image:
        """Return a colour-coded segmentation map for ``image``.

        Recognised keyword arguments (both default to 512):
            detect_resolution: resolution the input is resized to before
                segmentation.
            image_resolution: resolution of the returned map.
        """
        detect_res = kwargs.pop('detect_resolution', 512)
        output_res = kwargs.pop('image_resolution', 512)

        resized = resize_image(HWC3(image), resolution=detect_res)
        pil_image = PIL.Image.fromarray(resized)

        inputs = self.image_processor(pil_image, return_tensors='pt')
        outputs = self.image_segmentor(inputs.pixel_values)
        # PIL reports (width, height); target_sizes is given the reversed tuple.
        seg = self.image_processor.post_process_semantic_segmentation(
            outputs, target_sizes=[pil_image.size[::-1]])[0]

        # Paint every label with its colour from ade_palette().
        color_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8)
        for label, color in enumerate(ade_palette()):
            color_seg[seg == label, :] = color
        color_seg = color_seg.astype(np.uint8)

        # Nearest-neighbour interpolation, so no blended colours are introduced.
        color_seg = resize_image(color_seg,
                                 resolution=output_res,
                                 interpolation=cv2.INTER_NEAREST)
        return PIL.Image.fromarray(color_seg)

class Preprocessor:
    """Loads one annotator model at a time and applies it to images."""

    MODEL_ID = 'lllyasviel/Annotators'

    def __init__(self):
        self.model = None
        self.name = ''

    def load(self, name: str) -> None:
        """Make ``name`` the active annotator; a no-op if it already is.

        Raises:
            ValueError: If ``name`` is not a supported annotator.
        """
        if name == self.name:
            return

        # Annotators created via ``from_pretrained(MODEL_ID)``.
        pretrained = {
            'HED': HEDdetector,
            'Midas': MidasDetector,
            'MLSD': MLSDdetector,
            'Openpose': OpenposeDetector,
            'PidiNet': PidiNetDetector,
            'NormalBae': NormalBaeDetector,
            'Lineart': LineartDetector,
            'LineartAnime': LineartAnimeDetector,
        }
        # Annotators created by calling the class with no arguments.
        direct = {
            'Canny': CannyDetector,
            'ContentShuffle': ContentShuffleDetector,
            'DPT': DepthEstimator,
            'UPerNet': ImageSegmentor,
        }

        if name in pretrained:
            self.model = pretrained[name].from_pretrained(self.MODEL_ID)
        elif name in direct:
            self.model = direct[name]()
        else:
            raise ValueError

        # The previous annotator is no longer referenced; reclaim its memory.
        torch.cuda.empty_cache()
        gc.collect()
        self.name = name

    def __call__(self, image: PIL.Image.Image, **kwargs) -> PIL.Image.Image:
        """Apply the active annotator to ``image``, forwarding ``kwargs``.

        'Canny' and 'Midas' get explicit resizing here; every other annotator
        receives ``image`` and ``kwargs`` unchanged.
        """
        if self.name == 'Canny':
            # Only resize when the caller passed detect_resolution.
            if 'detect_resolution' in kwargs:
                detect_resolution = kwargs.pop('detect_resolution')
                image = resize_image(HWC3(np.array(image)),
                                     resolution=detect_resolution)
            return PIL.Image.fromarray(self.model(image, **kwargs))

        if self.name == 'Midas':
            detect_resolution = kwargs.pop('detect_resolution', 512)
            image_resolution = kwargs.pop('image_resolution', 512)
            model_input = resize_image(HWC3(np.array(image)),
                                       resolution=detect_resolution)
            output = HWC3(self.model(model_input, **kwargs))
            output = resize_image(output, resolution=image_resolution)
            return PIL.Image.fromarray(output)

        return self.model(image, **kwargs)

# =====================================
# Base Model
# =====================================
MAX_IMAGE_RESOLUTION = 4096 ## ⭐
MAX_NUM_IMAGES = 16 ## ⭐

CONTROLNET_MODEL_IDS = {
    'Openpose': 'lllyasviel/control_v11p_sd15_openpose',
    'Canny': 'lllyasviel/control_v11p_sd15_canny',
    'MLSD': 'lllyasviel/control_v11p_sd15_mlsd',
    'scribble': 'lllyasviel/control_v11p_sd15_scribble',
    'softedge': 'lllyasviel/control_v11p_sd15_softedge',
    'segmentation': 'lllyasviel/control_v11p_sd15_seg',
    'depth': 'lllyasviel/control_v11f1p_sd15_depth',
    'NormalBae': 'lllyasviel/control_v11p_sd15_normalbae',
    'lineart': 'lllyasviel/control_v11p_sd15_lineart',
    'lineart_anime': 'lllyasviel/control_v11p_sd15s2_lineart_anime',
    'shuffle': 'lllyasviel/control_v11e_sd15_shuffle',
    'ip2p': 'lllyasviel/control_v11e_sd15_ip2p',
    'Inpaint': 'lllyasviel/control_v11p_sd15_inpaint',
    'txt2img': 'NotControlnet',
}



def download_all_controlnet_weights() -> None:
    for model_id in CONTROLNET_MODEL_IDS.values():
        ControlNetModel.from_pretrained(model_id)

from diffusers import AutoencoderKL
class Model:
    def __init__(self,
                 base_model_id: str = 'runwayml/stable-diffusion-v1-5',
                 task_name: str = 'Canny', vae_model=None):
        """Pick the device, load the initial pipeline and create the preprocessor.

        Args:
            base_model_id: Model identifier or local checkpoint path passed
                to ``load_pipe``.
            task_name: Key of ``CONTROLNET_MODEL_IDS`` selecting the task.
            vae_model: VAE selection passed to ``load_pipe``.
        """
        device_name = 'cuda:0' if torch.cuda.is_available() else 'cpu'
        self.device = torch.device(device_name)

        # Start with empty "currently loaded" state so that the load_pipe
        # call below cannot match its already-loaded check.
        self.base_model_id, self.task_name, self.vae_model = '', '', None

        self.pipe = self.load_pipe(base_model_id, task_name, vae_model)
        self.preprocessor = Preprocessor()


    def load_pipe(self, base_model_id: str, task_name, vae_model=None, reload=False) -> DiffusionPipeline:
        """Build the diffusers pipeline for ``task_name``, or reuse the loaded one.

        The current pipeline is returned unchanged when the base model, task
        and VAE all match the loaded state and ``reload`` is false. Otherwise
        the old pipeline is dropped and a new one is created:

        - ``'txt2img'``: a plain pipeline, from a single checkpoint file when
          ``base_model_id`` is an existing path, else ``from_pretrained``
          with the ``madebyollin/sdxl-vae-fp16-fix`` VAE.
        - ``'Inpaint'``: a ControlNet inpaint pipeline.
        - anything else: a ControlNet pipeline using ``UniPCMultistepScheduler``.

        Args:
            base_model_id: Local checkpoint path or model repository id.
            task_name: Key of ``CONTROLNET_MODEL_IDS``.
            vae_model: Path of a single-file VAE. ``None`` and the string
                ``'None'`` both mean "no custom VAE".
            reload: Force a fresh load even if nothing changed.

        Returns:
            The loaded pipeline (also stored on ``self.pipe``).

        Raises:
            KeyError: If ``task_name`` is not in ``CONTROLNET_MODEL_IDS``.
        """
        already_loaded = (
            base_model_id == self.base_model_id
            and task_name == self.task_name
            and hasattr(self, 'pipe')
            and self.vae_model == vae_model
            and self.pipe is not None
            and not reload
        )
        if already_loaded:
            print('previous loaded')
            return self.pipe

        # Drop the old pipeline first so its memory can be reclaimed before
        # the new one is loaded.
        self.pipe = None
        torch.cuda.empty_cache()
        gc.collect()

        model_id = CONTROLNET_MODEL_IDS[task_name]
        is_local_checkpoint = os.path.exists(base_model_id)

        def load_custom_vae():
            # Treat the signature default (None) like the 'None' sentinel
            # instead of handing None to from_single_file.
            if vae_model is None or vae_model == 'None':
                return None
            return AutoencoderKL.from_single_file(vae_model)

        def load_controlnet_pipe(pipeline_cls):
            # Shared by the Inpaint and generic ControlNet branches so both
            # support local checkpoints as well as repository ids.
            controlnet = ControlNetModel.from_pretrained(
                model_id, torch_dtype=torch.float16)
            vae = load_custom_vae()
            if is_local_checkpoint:
                return pipeline_cls.from_single_file(
                    base_model_id,
                    vae=vae,
                    safety_checker=None,
                    controlnet=controlnet,
                    torch_dtype=torch.float16)
            if vae is None:
                # Without a custom VAE, load the one from the repo's 'vae' subfolder.
                vae = AutoencoderKL.from_pretrained(base_model_id,
                                                    subfolder='vae')
            return pipeline_cls.from_pretrained(
                base_model_id,
                vae=vae,
                safety_checker=None,
                controlnet=controlnet,
                torch_dtype=torch.float16)

        if task_name == 'txt2img':
            if is_local_checkpoint:
                pipe = StableDiffusionPipeline.from_single_file(
                    base_model_id,
                    vae=load_custom_vae(),
                    torch_dtype=torch.float16,
                ).to("cuda")
            else:
                pipe = DiffusionPipeline.from_pretrained(
                    base_model_id,
                    vae=AutoencoderKL.from_pretrained(
                        "madebyollin/sdxl-vae-fp16-fix",
                        torch_dtype=torch.float16),
                    torch_dtype=torch.float16,
                    use_safetensors=True,
                    variant="fp16",
                ).to("cuda")
            pipe.safety_checker = None
            print('Loaded txt2img pipeline')
        elif task_name == 'Inpaint':
            # Previously `pipe` stayed unbound here when base_model_id was
            # not a local file; the helper also covers repository ids.
            pipe = load_controlnet_pipe(StableDiffusionControlNetInpaintPipeline)
            print('Loaded ControlNet Inpaint pipeline')
        else:
            pipe = load_controlnet_pipe(StableDiffusionControlNetPipeline)
            print('Loaded ControlNet pipeline')

            pipe.scheduler = UniPCMultistepScheduler.from_config(
                pipe.scheduler.config)

        if self.device.type == 'cuda':
            pipe.enable_xformers_memory_efficient_attention()

        pipe.to(self.device)
        torch.cuda.empty_cache()
        gc.collect()
        # Record the new state only after everything succeeded, so a failed
        # load leaves the previous identifiers in place.
        self.pipe = pipe
        self.base_model_id = base_model_id
        self.task_name = task_name
        self.vae_model = vae_model
        return pipe

    def set_base_model(self, base_model_id: str) -> str:
        """Switch to ``base_model_id`` and return the id that ends up loaded.

        Nothing happens for an empty id or the id that is already loaded.
        If loading the new model raises, the previously recorded model is
        loaded again instead.
        """
        same_model = base_model_id == self.base_model_id
        if same_model or not base_model_id:
            return self.base_model_id

        # Release the current pipeline before loading the replacement.
        del self.pipe
        torch.cuda.empty_cache()
        gc.collect()

        task, vae = self.task_name, self.vae_model
        try:
            self.pipe = self.load_pipe(base_model_id, task, vae)
        except Exception:
            # Best-effort fallback to the model recorded before the switch.
            self.pipe = self.load_pipe(self.base_model_id, task, vae)
        return self.base_model_id

    def load_controlnet_weight(self, task_name: str) -> None:
        """Replace the pipeline's ControlNet with the one for ``task_name``.

        Does nothing when ``task_name`` is already the active task. The model
        id is looked up in ``CONTROLNET_MODEL_IDS``.
        """
        if task_name == self.task_name:
            return

        # Remove the old ControlNet first so its memory can be reclaimed
        # before the new weights are loaded.
        has_controlnet = self.pipe is not None and hasattr(self.pipe, 'controlnet')
        if has_controlnet:
            del self.pipe.controlnet
        torch.cuda.empty_cache()
        gc.collect()

        new_controlnet = ControlNetModel.from_pretrained(
            CONTROLNET_MODEL_IDS[task_name], torch_dtype=torch.float16)
        new_controlnet.to(self.device)
        torch.cuda.empty_cache()
        gc.collect()

        self.pipe.controlnet = new_controlnet
        self.task_name = task_name

    def get_prompt(self, prompt: str, additional_prompt: str) -> str:
        """Join ``prompt`` and ``additional_prompt`` with ``', '``.

        When ``prompt`` is empty, ``additional_prompt`` is returned on its own.
        """
        return f'{prompt}, {additional_prompt}' if prompt else additional_prompt

    @torch.autocast('cuda')
    def run_pipe(
        self,
        prompt: str,
        negative_prompt: str,
        prompt_embeds,
        negative_prompt_embeds,
        control_image: PIL.Image.Image,
        num_images: int,
        num_steps: int,
        guidance_scale: float,
        seed: int,
    ) -> list[PIL.Image.Image]:
        """Run the loaded pipeline and return its ``images``.

        ``prompt`` and ``negative_prompt`` are accepted but not forwarded:
        the pipeline is driven by ``prompt_embeds`` and
        ``negative_prompt_embeds`` instead.
        """
        rng = torch.Generator().manual_seed(seed)
        pipe_kwargs = {
            'prompt_embeds': prompt_embeds,
            'negative_prompt_embeds': negative_prompt_embeds,
            'guidance_scale': guidance_scale,
            'num_images_per_prompt': num_images,
            'num_inference_steps': num_steps,
            'generator': rng,
            'image': control_image,
        }
        output = self.pipe(**pipe_kwargs)
        return output.images

    @torch.inference_mode()
    def process_canny(
        self,
        image: np.ndarray,
        prompt: str,
        prompt_embeds,
        negative_prompt_embeds,
        additional_prompt: str = "best quality, extremely detailed",
        negative_prompt: str = "",
        num_images: int = 1,
        image_resolution: int = 512,
        num_steps: int = 30,
        guidance_scale: float = 7.5,
        seed: int = -1,
        low_threshold: int = 100,
        high_threshold: int = 200,
    ) -> list[PIL.Image.Image]:
        """Generate images conditioned on a Canny edge map of ``image``.

        Returns the control image followed by the generated images.

        Raises:
            ValueError: If ``image`` is None, or ``image_resolution`` /
                ``num_images`` exceed MAX_IMAGE_RESOLUTION / MAX_NUM_IMAGES.
        """
        if (image is None or image_resolution > MAX_IMAGE_RESOLUTION
                or num_images > MAX_NUM_IMAGES):
            raise ValueError

        self.preprocessor.load('Canny')
        # The edge detector runs at the output resolution.
        edge_map = self.preprocessor(image=image,
                                     low_threshold=low_threshold,
                                     high_threshold=high_threshold,
                                     detect_resolution=image_resolution)

        self.load_controlnet_weight('Canny')
        full_prompt = self.get_prompt(prompt, additional_prompt)
        generated = self.run_pipe(
            prompt=full_prompt,
            negative_prompt=negative_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            control_image=edge_map,
            num_images=num_images,
            num_steps=num_steps,
            guidance_scale=guidance_scale,
            seed=seed,
        )
        return [edge_map, *generated]

    @torch.inference_mode()
    def process_mlsd(
        self,
        image: np.ndarray,
        prompt: str,
        prompt_embeds,
        negative_prompt_embeds,
        additional_prompt: str,
        negative_prompt: str,
        num_images: int,
        image_resolution: int,
        preprocess_resolution: int,
        num_steps: int,
        guidance_scale: float,
        seed: int,
        value_threshold: float,
        distance_threshold: float,
    ) -> list[PIL.Image.Image]:
        """Generate images conditioned on the 'MLSD' annotator's output.

        Returns the control image followed by the generated images.

        Raises:
            ValueError: If ``image`` is None, or ``image_resolution`` /
                ``num_images`` exceed MAX_IMAGE_RESOLUTION / MAX_NUM_IMAGES.
        """
        if (image is None or image_resolution > MAX_IMAGE_RESOLUTION
                or num_images > MAX_NUM_IMAGES):
            raise ValueError

        self.preprocessor.load('MLSD')
        # The two thresholds are forwarded to the annotator as thr_v / thr_d.
        line_map = self.preprocessor(
            image=image,
            image_resolution=image_resolution,
            detect_resolution=preprocess_resolution,
            thr_v=value_threshold,
            thr_d=distance_threshold,
        )

        self.load_controlnet_weight('MLSD')
        full_prompt = self.get_prompt(prompt, additional_prompt)
        generated = self.run_pipe(
            prompt=full_prompt,
            negative_prompt=negative_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            control_image=line_map,
            num_images=num_images,
            num_steps=num_steps,
            guidance_scale=guidance_scale,
            seed=seed,
        )
        return [line_map, *generated]

    @torch.inference_mode()
    def process_scribble(
        self,
        image: np.ndarray,
        prompt: str,
        prompt_embeds,
        negative_prompt_embeds,
        additional_prompt: str,
        negative_prompt: str,
        num_images: int,
        image_resolution: int,
        preprocess_resolution: int,
        num_steps: int,
        guidance_scale: float,
        seed: int,
        preprocessor_name: str,
    ) -> list[PIL.Image.Image]:
        """Generate images with the 'scribble' ControlNet.

        ``preprocessor_name`` selects how the control image is produced:
        ``'None'`` uses the resized input as-is, ``'HED'`` and ``'PidiNet'``
        run the corresponding annotator.

        Returns the control image followed by the generated images.

        Raises:
            ValueError: If ``image`` is None, a limit (MAX_IMAGE_RESOLUTION /
                MAX_NUM_IMAGES) is exceeded, or ``preprocessor_name`` is not
                one of the supported values.
        """
        if image is None:
            raise ValueError('image must not be None')
        if image_resolution > MAX_IMAGE_RESOLUTION:
            raise ValueError(
                f'image_resolution must be <= {MAX_IMAGE_RESOLUTION}')
        if num_images > MAX_NUM_IMAGES:
            raise ValueError(f'num_images must be <= {MAX_NUM_IMAGES}')

        if preprocessor_name == 'None':
            image = HWC3(image)
            image = resize_image(image, resolution=image_resolution)
            control_image = PIL.Image.fromarray(image)
        elif preprocessor_name == 'HED':
            self.preprocessor.load(preprocessor_name)
            control_image = self.preprocessor(
                image=image,
                image_resolution=image_resolution,
                detect_resolution=preprocess_resolution,
                scribble=False,
            )
        elif preprocessor_name == 'PidiNet':
            self.preprocessor.load(preprocessor_name)
            control_image = self.preprocessor(
                image=image,
                image_resolution=image_resolution,
                detect_resolution=preprocess_resolution,
                safe=False,
            )
        else:
            # Without this branch an unknown name left control_image unbound
            # and failed later with UnboundLocalError; process_softedge
            # already rejects unknown names with ValueError.
            raise ValueError(
                f'Unsupported scribble preprocessor: {preprocessor_name!r}')

        self.load_controlnet_weight('scribble')
        results = self.run_pipe(
            prompt=self.get_prompt(prompt, additional_prompt),
            negative_prompt=negative_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            control_image=control_image,
            num_images=num_images,
            num_steps=num_steps,
            guidance_scale=guidance_scale,
            seed=seed,
        )
        return [control_image] + results

    @torch.inference_mode()
    def process_scribble_interactive(
        self,
        image_and_mask: dict[str, np.ndarray],
        prompt: str,
        prompt_embeds,
        negative_prompt_embeds,
        additional_prompt: str,
        negative_prompt: str,
        num_images: int,
        image_resolution: int,
        num_steps: int,
        guidance_scale: float,
        seed: int,
    ) -> list[PIL.Image.Image]:
        """Generate images with the 'scribble' ControlNet from a drawn mask.

        The ``'mask'`` entry of ``image_and_mask`` is used as the control
        image; no annotator is run.

        Returns the control image followed by the generated images.

        Raises:
            ValueError: If ``image_and_mask`` is None, or ``image_resolution`` /
                ``num_images`` exceed MAX_IMAGE_RESOLUTION / MAX_NUM_IMAGES.
        """
        if (image_and_mask is None or image_resolution > MAX_IMAGE_RESOLUTION
                or num_images > MAX_NUM_IMAGES):
            raise ValueError

        mask = resize_image(HWC3(image_and_mask['mask']),
                            resolution=image_resolution)
        scribble = PIL.Image.fromarray(mask)

        self.load_controlnet_weight('scribble')
        full_prompt = self.get_prompt(prompt, additional_prompt)
        generated = self.run_pipe(
            prompt=full_prompt,
            negative_prompt=negative_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            control_image=scribble,
            num_images=num_images,
            num_steps=num_steps,
            guidance_scale=guidance_scale,
            seed=seed,
        )
        return [scribble, *generated]

    @torch.inference_mode()
    def process_softedge(
        self,
        image: np.ndarray,
        prompt: str,
        prompt_embeds,
        negative_prompt_embeds,
        additional_prompt: str,
        negative_prompt: str,
        num_images: int,
        image_resolution: int,
        preprocess_resolution: int,
        num_steps: int,
        guidance_scale: float,
        seed: int,
        preprocessor_name: str,
    ) -> list[PIL.Image.Image]:
        """Generate images with the 'softedge' ControlNet.

        ``preprocessor_name`` may be ``'None'`` (use the resized input),
        ``'HED'``, ``'HED safe'``, ``'PidiNet'`` or ``'PidiNet safe'``.

        Returns the control image followed by the generated images.

        Raises:
            ValueError: If ``image`` is None, a limit (MAX_IMAGE_RESOLUTION /
                MAX_NUM_IMAGES) is exceeded, or ``preprocessor_name`` is not
                one of the values above.
        """
        if (image is None or image_resolution > MAX_IMAGE_RESOLUTION
                or num_images > MAX_NUM_IMAGES):
            raise ValueError

        # Within the accepted names, only the "... safe" variants contain 'safe'.
        wants_safe = 'safe' in preprocessor_name
        if preprocessor_name == 'None':
            raw = resize_image(HWC3(image), resolution=image_resolution)
            edge_map = PIL.Image.fromarray(raw)
        elif preprocessor_name in ('HED', 'HED safe'):
            self.preprocessor.load('HED')
            # For HED the safe flag is passed as the `scribble` argument.
            edge_map = self.preprocessor(
                image=image,
                image_resolution=image_resolution,
                detect_resolution=preprocess_resolution,
                scribble=wants_safe,
            )
        elif preprocessor_name in ('PidiNet', 'PidiNet safe'):
            self.preprocessor.load('PidiNet')
            edge_map = self.preprocessor(
                image=image,
                image_resolution=image_resolution,
                detect_resolution=preprocess_resolution,
                safe=wants_safe,
            )
        else:
            raise ValueError

        self.load_controlnet_weight('softedge')
        full_prompt = self.get_prompt(prompt, additional_prompt)
        generated = self.run_pipe(
            prompt=full_prompt,
            negative_prompt=negative_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            control_image=edge_map,
            num_images=num_images,
            num_steps=num_steps,
            guidance_scale=guidance_scale,
            seed=seed,
        )
        return [edge_map, *generated]

    @torch.inference_mode()
    def process_openpose(
        self,
        image: np.ndarray,
        prompt: str,
        prompt_embeds,
        negative_prompt_embeds,
        additional_prompt: str,
        negative_prompt: str,
        num_images: int,
        image_resolution: int,
        preprocess_resolution: int,
        num_steps: int,
        guidance_scale: float,
        seed: int,
        preprocessor_name: str,
    ) -> list[PIL.Image.Image]:
        """Generate images with the 'Openpose' ControlNet.

        With ``preprocessor_name == 'None'`` the resized input is used as the
        control image; any other value runs the 'Openpose' annotator with
        ``hand_and_face=True``.

        Returns the control image followed by the generated images.

        Raises:
            ValueError: If ``image`` is None, or ``image_resolution`` /
                ``num_images`` exceed MAX_IMAGE_RESOLUTION / MAX_NUM_IMAGES.
        """
        if (image is None or image_resolution > MAX_IMAGE_RESOLUTION
                or num_images > MAX_NUM_IMAGES):
            raise ValueError

        if preprocessor_name == 'None':
            raw = resize_image(HWC3(image), resolution=image_resolution)
            pose_map = PIL.Image.fromarray(raw)
        else:
            self.preprocessor.load('Openpose')
            pose_map = self.preprocessor(
                image=image,
                image_resolution=image_resolution,
                detect_resolution=preprocess_resolution,
                hand_and_face=True,
            )

        self.load_controlnet_weight('Openpose')
        full_prompt = self.get_prompt(prompt, additional_prompt)
        generated = self.run_pipe(
            prompt=full_prompt,
            negative_prompt=negative_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            control_image=pose_map,
            num_images=num_images,
            num_steps=num_steps,
            guidance_scale=guidance_scale,
            seed=seed,
        )
        return [pose_map, *generated]

    @torch.inference_mode()
    def process_segmentation(
        self,
        image: np.ndarray,
        prompt: str,
        prompt_embeds,
        negative_prompt_embeds,
        additional_prompt: str,
        negative_prompt: str,
        num_images: int,
        image_resolution: int,
        preprocess_resolution: int,
        num_steps: int,
        guidance_scale: float,
        seed: int,
        preprocessor_name: str,
    ) -> list[PIL.Image.Image]:
        """Generate images with the 'segmentation' ControlNet.

        With ``preprocessor_name == 'None'`` the resized input is used as the
        control image; otherwise ``preprocessor_name`` is passed straight to
        ``self.preprocessor.load``.

        Returns the control image followed by the generated images.

        Raises:
            ValueError: If ``image`` is None, or ``image_resolution`` /
                ``num_images`` exceed MAX_IMAGE_RESOLUTION / MAX_NUM_IMAGES.
        """
        if (image is None or image_resolution > MAX_IMAGE_RESOLUTION
                or num_images > MAX_NUM_IMAGES):
            raise ValueError

        if preprocessor_name == 'None':
            raw = resize_image(HWC3(image), resolution=image_resolution)
            seg_map = PIL.Image.fromarray(raw)
        else:
            self.preprocessor.load(preprocessor_name)
            seg_map = self.preprocessor(
                image=image,
                image_resolution=image_resolution,
                detect_resolution=preprocess_resolution,
            )

        self.load_controlnet_weight('segmentation')
        full_prompt = self.get_prompt(prompt, additional_prompt)
        generated = self.run_pipe(
            prompt=full_prompt,
            negative_prompt=negative_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            control_image=seg_map,
            num_images=num_images,
            num_steps=num_steps,
            guidance_scale=guidance_scale,
            seed=seed,
        )
        return [seg_map, *generated]

    @torch.inference_mode()
    def process_depth(
        self,
        image: np.ndarray,
        prompt: str,
        prompt_embeds,
        negative_prompt_embeds,
        additional_prompt: str,
        negative_prompt: str,
        num_images: int,
        image_resolution: int,
        preprocess_resolution: int,
        num_steps: int,
        guidance_scale: float,
        seed: int,
        preprocessor_name: str,
    ) -> list[PIL.Image.Image]:
        """Generate images with the 'depth' ControlNet.

        With ``preprocessor_name == 'None'`` the resized input is used as the
        control image; otherwise ``preprocessor_name`` is passed straight to
        ``self.preprocessor.load``.

        Returns the control image followed by the generated images.

        Raises:
            ValueError: If ``image`` is None, or ``image_resolution`` /
                ``num_images`` exceed MAX_IMAGE_RESOLUTION / MAX_NUM_IMAGES.
        """
        if (image is None or image_resolution > MAX_IMAGE_RESOLUTION
                or num_images > MAX_NUM_IMAGES):
            raise ValueError

        if preprocessor_name == 'None':
            raw = resize_image(HWC3(image), resolution=image_resolution)
            depth_map = PIL.Image.fromarray(raw)
        else:
            self.preprocessor.load(preprocessor_name)
            depth_map = self.preprocessor(
                image=image,
                image_resolution=image_resolution,
                detect_resolution=preprocess_resolution,
            )

        self.load_controlnet_weight('depth')
        full_prompt = self.get_prompt(prompt, additional_prompt)
        generated = self.run_pipe(
            prompt=full_prompt,
            negative_prompt=negative_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            control_image=depth_map,
            num_images=num_images,
            num_steps=num_steps,
            guidance_scale=guidance_scale,
            seed=seed,
        )
        return [depth_map, *generated]

    @torch.inference_mode()
    def process_normal(
        self,
        image: np.ndarray,
        prompt: str,
        prompt_embeds,
        negative_prompt_embeds,
        additional_prompt: str,
        negative_prompt: str,
        num_images: int,
        image_resolution: int,
        preprocess_resolution: int,
        num_steps: int,
        guidance_scale: float,
        seed: int,
        preprocessor_name: str,
    ) -> list[PIL.Image.Image]:
        """Generate images with the 'NormalBae' ControlNet weight.

        When ``preprocessor_name`` is 'None' the resized input is used as the
        control image. Any other value loads the 'NormalBae' preprocessor
        (the given name itself is not forwarded).

        Returns:
            The control image followed by the images returned by ``run_pipe``.

        Raises:
            ValueError: if ``image`` is None, or ``image_resolution`` /
                ``num_images`` exceed ``MAX_IMAGE_RESOLUTION`` /
                ``MAX_NUM_IMAGES``.
        """
        # Guard clause: reject a missing image or an over-limit request.
        if (
            image is None
            or image_resolution > MAX_IMAGE_RESOLUTION
            or num_images > MAX_NUM_IMAGES
        ):
            raise ValueError

        if preprocessor_name != 'None':
            self.preprocessor.load('NormalBae')
            control_image = self.preprocessor(
                image=image,
                image_resolution=image_resolution,
                detect_resolution=preprocess_resolution,
            )
        else:
            resized = resize_image(HWC3(image), resolution=image_resolution)
            control_image = PIL.Image.fromarray(resized)

        self.load_controlnet_weight('NormalBae')
        full_prompt = self.get_prompt(prompt, additional_prompt)
        generated = self.run_pipe(
            prompt=full_prompt,
            negative_prompt=negative_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            control_image=control_image,
            num_images=num_images,
            num_steps=num_steps,
            guidance_scale=guidance_scale,
            seed=seed,
        )
        return [control_image] + generated

    @torch.inference_mode()
    def process_lineart(
        self,
        image: np.ndarray,
        prompt: str,
        prompt_embeds,
        negative_prompt_embeds,
        additional_prompt: str,
        negative_prompt: str,
        num_images: int,
        image_resolution: int,
        preprocess_resolution: int,
        num_steps: int,
        guidance_scale: float,
        seed: int,
        preprocessor_name: str,
    ) -> list[PIL.Image.Image]:
        """Generate images with the 'lineart' or 'lineart_anime' ControlNet weight.

        Accepted ``preprocessor_name`` values:
          * 'None' / 'None (anime)'     - use the resized input as control image
          * 'Lineart' / 'Lineart coarse' - run the 'Lineart' preprocessor
          * 'Lineart (anime)'           - run the 'LineartAnime' preprocessor

        Names containing 'anime' select the 'lineart_anime' weight, all
        others select 'lineart'.

        Returns:
            The control image followed by the images returned by ``run_pipe``.

        Raises:
            ValueError: if ``image`` is None, a limit is exceeded, or
                ``preprocessor_name`` is not one of the accepted values.
        """
        if image is None:
            raise ValueError
        if image_resolution > MAX_IMAGE_RESOLUTION:
            raise ValueError
        if num_images > MAX_NUM_IMAGES:
            raise ValueError

        if preprocessor_name in ['None', 'None (anime)']:
            image = HWC3(image)
            image = resize_image(image, resolution=image_resolution)
            control_image = PIL.Image.fromarray(image)
        elif preprocessor_name in ['Lineart', 'Lineart coarse']:
            coarse = 'coarse' in preprocessor_name
            self.preprocessor.load('Lineart')
            control_image = self.preprocessor(
                image=image,
                image_resolution=image_resolution,
                detect_resolution=preprocess_resolution,
                coarse=coarse,
            )
        elif preprocessor_name == 'Lineart (anime)':
            self.preprocessor.load('LineartAnime')
            control_image = self.preprocessor(
                image=image,
                image_resolution=image_resolution,
                detect_resolution=preprocess_resolution,
            )
        else:
            # An unknown name used to fall through every branch and only fail
            # later with UnboundLocalError on `control_image`; fail early and
            # clearly instead, before any weights are loaded.
            raise ValueError(
                f'Unknown lineart preprocessor: {preprocessor_name!r}'
            )

        if 'anime' in preprocessor_name:
            self.load_controlnet_weight('lineart_anime')
        else:
            self.load_controlnet_weight('lineart')
        results = self.run_pipe(
            prompt=self.get_prompt(prompt, additional_prompt),
            negative_prompt=negative_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            control_image=control_image,
            num_images=num_images,
            num_steps=num_steps,
            guidance_scale=guidance_scale,
            seed=seed,
        )
        return [control_image] + results

    @torch.inference_mode()
    def process_shuffle(
        self,
        image: np.ndarray,
        prompt: str,
        prompt_embeds,
        negative_prompt_embeds,
        additional_prompt: str,
        negative_prompt: str,
        num_images: int,
        image_resolution: int,
        num_steps: int,
        guidance_scale: float,
        seed: int,
        preprocessor_name: str,
    ) -> list[PIL.Image.Image]:
        """Generate images with the 'shuffle' ControlNet weight.

        The control image is the resized input when ``preprocessor_name`` is
        'None'; otherwise it is produced by the named preprocessor (called
        without a detect resolution).

        Returns:
            The control image followed by the images returned by ``run_pipe``.

        Raises:
            ValueError: if ``image`` is None, or ``image_resolution`` /
                ``num_images`` exceed ``MAX_IMAGE_RESOLUTION`` /
                ``MAX_NUM_IMAGES``.
        """
        # Guard clause: reject a missing image or an over-limit request.
        if (
            image is None
            or image_resolution > MAX_IMAGE_RESOLUTION
            or num_images > MAX_NUM_IMAGES
        ):
            raise ValueError

        if preprocessor_name != 'None':
            self.preprocessor.load(preprocessor_name)
            control_image = self.preprocessor(
                image=image,
                image_resolution=image_resolution,
            )
        else:
            resized = resize_image(HWC3(image), resolution=image_resolution)
            control_image = PIL.Image.fromarray(resized)

        self.load_controlnet_weight('shuffle')
        full_prompt = self.get_prompt(prompt, additional_prompt)
        generated = self.run_pipe(
            prompt=full_prompt,
            negative_prompt=negative_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            control_image=control_image,
            num_images=num_images,
            num_steps=num_steps,
            guidance_scale=guidance_scale,
            seed=seed,
        )
        return [control_image] + generated

    @torch.inference_mode()
    def process_ip2p(
        self,
        image: np.ndarray,
        prompt: str,
        prompt_embeds,
        negative_prompt_embeds,
        additional_prompt: str,
        negative_prompt: str,
        num_images: int,
        image_resolution: int,
        num_steps: int,
        guidance_scale: float,
        seed: int,
    ) -> list[PIL.Image.Image]:
        """Generate images with the 'ip2p' ControlNet weight.

        No preprocessor is involved: the resized input image is used directly
        as the control image.

        Returns:
            The control image followed by the images returned by ``run_pipe``.

        Raises:
            ValueError: if ``image`` is None, or ``image_resolution`` /
                ``num_images`` exceed ``MAX_IMAGE_RESOLUTION`` /
                ``MAX_NUM_IMAGES``.
        """
        # Guard clause: reject a missing image or an over-limit request.
        if (
            image is None
            or image_resolution > MAX_IMAGE_RESOLUTION
            or num_images > MAX_NUM_IMAGES
        ):
            raise ValueError

        resized = resize_image(HWC3(image), resolution=image_resolution)
        control_image = PIL.Image.fromarray(resized)

        self.load_controlnet_weight('ip2p')
        full_prompt = self.get_prompt(prompt, additional_prompt)
        generated = self.run_pipe(
            prompt=full_prompt,
            negative_prompt=negative_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            control_image=control_image,
            num_images=num_images,
            num_steps=num_steps,
            guidance_scale=guidance_scale,
            seed=seed,
        )
        return [control_image] + generated

    @torch.inference_mode()
    def process_inpaint(
        self,
        image: np.ndarray,
        prompt: str,
        prompt_embeds,
        negative_prompt_embeds,
        additional_prompt: str,
        negative_prompt: str,
        num_images: int,
        image_resolution: int,
        preprocess_resolution: int,
        num_steps: int,
        guidance_scale: float,
        seed: int,
        image_mask: str,  # NOTE(review): passed to HWC3 like `image`, so presumably an array - confirm
        strength: float,
    ) -> list[PIL.Image.Image]:
        """Inpaint ``image`` inside ``image_mask`` using the 'Inpaint' ControlNet weight.

        Image and mask are both resized with ``preprocess_resolution``
        (``image_resolution`` is only validated). The text prompts are not
        forwarded to the pipeline; only the prompt embeddings are. The
        pipeline is called directly rather than through ``run_pipe``.

        Returns:
            The resized initial image followed by the generated images.

        Raises:
            ValueError: if ``image`` is None, or ``image_resolution`` /
                ``num_images`` exceed ``MAX_IMAGE_RESOLUTION`` /
                ``MAX_NUM_IMAGES``.
        """
        # Guard clause: reject a missing image or an over-limit request.
        if (
            image is None
            or image_resolution > MAX_IMAGE_RESOLUTION
            or num_images > MAX_NUM_IMAGES
        ):
            raise ValueError

        init_image = PIL.Image.fromarray(
            resize_image(HWC3(image), resolution=preprocess_resolution)
        )
        control_mask = PIL.Image.fromarray(
            resize_image(HWC3(image_mask), resolution=preprocess_resolution)
        )

        # Control tensor: the image with the masked pixels marked.
        control_image = make_inpaint_condition(init_image, control_mask)

        self.load_controlnet_weight('Inpaint')
        pipe_kwargs = dict(
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            eta=1.0,
            strength=strength,
            image=init_image,  # original image
            mask_image=control_mask,  # mask, values of 0 to 255
            control_image=control_image,  # tensor control image
            num_images_per_prompt=num_images,
            num_inference_steps=num_steps,
            guidance_scale=guidance_scale,
            generator=torch.Generator().manual_seed(seed),
        )
        results = self.pipe(**pipe_kwargs).images

        return [init_image] + results

# =====================================
# Prompt weights
# =====================================
from compel import Compel
from diffusers import StableDiffusionPipeline, DDIMScheduler
import re


def concat_tensor(t):
    """Split `t` into size-1 slices along dim 0 and join them along dim 1."""
    return torch.cat(t.split(1, dim=0), dim=1)

def merge_embeds(prompt_chanks):
    """Blend the embeddings of several prompt chunks into one weighted sum.

    Every chunk is embedded with the global ``compel`` object. Counting from
    the end, the k-th chunk is scaled by ``k / (n*(n+1)//2)``, so earlier
    chunks contribute more and the scale factors add up to 1.
    """
    chunk_count = len(prompt_chanks)
    unit_weight = 1/(chunk_count*(chunk_count+1)//2)
    embeddings = compel(prompt_chanks)
    slices = list(torch.split(embeddings, 1, dim=0))
    # rank 1 is the last chunk, rank n the first one.
    for rank in range(1, chunk_count + 1):
        slices[-rank] = slices[-rank] * (rank*unit_weight)
    return torch.stack(slices, dim=0).sum(dim=0)

def detokenize(chunk, actual_prompt):
    """Rebuild the text of a token chunk and strip it from the prompt.

    Tokens carry a ``</w>`` end-of-word marker. Each marker is turned into a
    space when the prompt has a space at that position, and dropped
    otherwise (e.g. before punctuation).

    Args:
        chunk: list of token strings; it is not modified.
        actual_prompt: the prompt text the chunk was tokenized from.

    Returns:
        ``(chunk_text, remaining_prompt)``, both stripped. Only the first
        occurrence of the chunk text is removed from the prompt, so text
        repeated later in the prompt is kept.
    """
    if not chunk:
        return '', actual_prompt.strip()

    # Work on a copy so the caller's token list is left untouched.
    tokens = list(chunk)
    tokens[-1] = tokens[-1].replace('</w>', '')
    chanked_prompt = ''.join(tokens).strip()

    while '</w>' in chanked_prompt:
        marker_at = chanked_prompt.find('</w>')
        # Bounds check: the rebuilt text may run past the prompt when the
        # tokenizer normalised it differently.
        followed_by_space = (
            marker_at < len(actual_prompt) and actual_prompt[marker_at] == ' '
        )
        chanked_prompt = chanked_prompt.replace(
            '</w>', ' ' if followed_by_space else '', 1
        )

    actual_prompt = actual_prompt.replace(chanked_prompt, '', 1)
    return chanked_prompt.strip(), actual_prompt.strip()

def tokenize_line(line, tokenizer): # split into chunks
    """Split a prompt line into text chunks that fit the tokenizer limit.

    The line is lower-cased and stripped; an empty line is replaced by
    'worst quality'. Tokens are collected until a chunk holds
    ``tokenizer.model_max_length - 2`` tokens. A full chunk is cut after its
    last comma token (the rest carries over to the next chunk), or emitted
    whole when it contains no comma.

    Returns:
        The list of chunk texts, in order.
    """
    remaining_prompt = line.lower().strip()
    if remaining_prompt == "":
        remaining_prompt = 'worst quality'
    tokens = tokenizer.tokenize(remaining_prompt)
    chunk_limit = tokenizer.model_max_length - 2
    comma = tokenizer.tokenize(',')[0]

    pieces = []
    pending = []
    for token in tokens:
        pending.append(token)
        if len(pending) != chunk_limit:
            continue
        # Position of the last comma, or the last token when there is none.
        cut = next(
            (pos for pos in range(chunk_limit - 1, -1, -1) if pending[pos] == comma),
            chunk_limit - 1,
        )
        text, remaining_prompt = detokenize(pending[:cut + 1], remaining_prompt)
        pieces.append(text)
        pending = pending[cut + 1:]

    # Whatever is left forms the final, shorter chunk.
    if pending:
        text, _ = detokenize(pending, remaining_prompt)
        pieces.append(text)

    return pieces

def prompt_weight_conversor(input_string):
    """Convert bracket-style prompt weights to Compel syntax.

    The rewrite rules are applied in order:
      1. ``(x:1.2)`` -> ``(x)1.2``
      2. ``[x]``     -> ``(x)0.909090909``
      3. ``[x:n]``   -> ``(x)0.9`` (the number is discarded)
      4. ``(x)`` with no weight after it -> ``(x)+``
      5. ``((x)+)``, i.e. an original ``((x))`` -> ``(x)++``
    """
    rewrite_rules = (
        (r'\(([^:]+):([\d.]+)\)', r'(\1)\2'),
        (r'\[([^:\]]+)\]', r'(\1)0.909090909'),
        (r'\[([^:]+):[\d.]+\]', r'(\1)0.9'),
        (r'\(([^)]+)\)(?![\d.])', r'(\1)+'),
        (r'\(\(([^)]+)\)\+\)', r'(\1)++'),
    )

    converted = input_string
    for pattern, replacement in rewrite_rules:
        converted = re.sub(pattern, replacement, converted)
    return converted

# =====================================
# IMAGE: METADATA AND SAVE
# =====================================
import os
from PIL import Image
from PIL.PngImagePlugin import PngInfo

def save_pil_image_with_metadata(image, folder_path, metadata_list):
    """Save `image` as the next numbered PNG in `folder_path`.

    Args:
        image: object with a PIL-style ``save(path, pnginfo=...)`` method.
        folder_path: target directory; created when missing.
        metadata_list: sequence whose first eight items are written as the
            PNG text chunks Prompt, Negative prompt, Model, VAE, Steps, CFG,
            Scheduler and Seed. If it is None, too short, or writing the
            metadata fails, the image is saved without metadata.

    Returns:
        The path of the written file (``imageNNN.png``).
    """
    os.makedirs(folder_path, exist_ok=True)

    # Start from the file count, then skip names that are already taken so
    # an existing image is never overwritten (e.g. after files were deleted).
    index = len(os.listdir(folder_path)) + 1
    image_path = os.path.join(folder_path, f"image{str(index).zfill(3)}.png")
    while os.path.exists(image_path):
        index += 1
        image_path = os.path.join(folder_path, f"image{str(index).zfill(3)}.png")

    metadata_keys = (
        "Prompt",
        "Negative prompt",
        "Model",
        "VAE",
        "Steps",
        "CFG",
        "Scheduler",
        "Seed",
    )

    try:
        # metadata
        metadata = PngInfo()
        for position, key in enumerate(metadata_keys):
            metadata.add_text(key, str(metadata_list[position]))

        image.save(image_path, pnginfo=metadata)
    except Exception:
        # Best effort: keep the image even when the metadata is unusable.
        print('Saving image without metadata')
        image.save(image_path)

    return image_path

# =====================================
# LoRA Loaders
# =====================================
import torch
from safetensors.torch import load_file
from collections import defaultdict
def _merge_lora_checkpoint(pipeline, ckpt_path, multiplier, device, dtype):
    """Merge one LoRA .safetensors file into the pipeline weights in place."""
    LORA_PREFIX_UNET = "lora_unet"
    LORA_PREFIX_TEXT_ENCODER = "lora_te"

    state_dict = load_file(ckpt_path, device=device)

    # Group the tensors per layer. Keys usually look like
    # "lora_te_text_model_encoder_layers_0_self_attn_k_proj.lora_down.weight"
    updates = defaultdict(dict)
    for key, value in state_dict.items():
        layer, elem = key.split('.', 1)
        updates[layer][elem] = value

    # directly update weight in diffusers model
    for layer, elems in updates.items():

        if "text" in layer:
            layer_infos = layer.split(LORA_PREFIX_TEXT_ENCODER + "_")[-1].split("_")
            curr_layer = pipeline.text_encoder
        else:
            layer_infos = layer.split(LORA_PREFIX_UNET + "_")[-1].split("_")
            curr_layer = pipeline.unet

        # Find the target layer. The flattened name uses "_" both as the
        # separator and inside attribute names, so name parts are glued
        # together until a lookup succeeds. `__getattr__` is called directly
        # (not `getattr`) on purpose: a regular method such as `to` must not
        # be mistaken for a sub-module.
        temp_name = layer_infos.pop(0)
        while True:
            try:
                curr_layer = curr_layer.__getattr__(temp_name)
            except AttributeError:
                # Popping from an exhausted list raises IndexError, which
                # propagates to the caller when the layer cannot be resolved.
                if len(temp_name) > 0:
                    temp_name += "_" + layer_infos.pop(0)
                else:
                    temp_name = layer_infos.pop(0)
                continue
            if not layer_infos:
                break
            temp_name = layer_infos.pop(0)

        # get elements for this layer (a missing key raises KeyError)
        weight_up = elems['lora_up.weight'].to(dtype)
        weight_down = elems['lora_down.weight'].to(dtype)
        alpha = elems['alpha']
        if alpha:
            alpha = alpha.item() / weight_up.shape[1]
        else:
            alpha = 1.0

        # update weight
        if len(weight_up.shape) == 4:
            delta = torch.mm(
                weight_up.squeeze(3).squeeze(2),
                weight_down.squeeze(3).squeeze(2),
            ).unsqueeze(2).unsqueeze(3)
        else:
            delta = torch.mm(weight_up, weight_down)
        curr_layer.weight.data += multiplier * alpha * delta


def load_lora_weights(pipeline, checkpoint_path, multiplier, device, dtype):
    """Merge LoRA weights from .safetensors file(s) into `pipeline` in place.

    Args:
        pipeline: object exposing ``text_encoder`` and ``unet`` modules.
        checkpoint_path: a single path (str) or an iterable of paths.
        multiplier: scale applied to every weight delta.
        device: device argument forwarded to ``load_file``.
        dtype: dtype the LoRA up/down matrices are cast to before merging.

    Returns:
        The same `pipeline` object.

    Raises:
        KeyError / IndexError: when a checkpoint lacks an expected tensor or
        names a layer that cannot be resolved. Layers merged before the
        failure stay merged.
    """
    if isinstance(checkpoint_path, str):
        checkpoint_paths = [checkpoint_path]
    else:
        checkpoint_paths = checkpoint_path

    for ckpt_path in checkpoint_paths:
        _merge_lora_checkpoint(pipeline, ckpt_path, multiplier, device, dtype)
    return pipeline

def lora_mix_load(pipe, lora_path, alpha_scale=1.0, device='cuda'):
    """Apply a LoRA file to `pipe`, falling back to the pipeline's own loader.

    The weights are first merged with ``load_lora_weights`` (float16). If
    that raises, ``pipe.load_lora_weights`` and ``pipe.fuse_lora`` are used
    instead.

    Returns:
        The pipeline with the LoRA applied.
    """
    try:
        pipe = load_lora_weights(pipe, [lora_path], alpha_scale, device, torch.float16)
    except Exception:
        # `Exception` rather than a bare except, so KeyboardInterrupt and
        # SystemExit are not swallowed and turned into a second load attempt.
        pipe.load_lora_weights(lora_path)
        pipe.fuse_lora(lora_scale=alpha_scale)

    return pipe


# =====================================
# Inpainting canvas
# =====================================
# HTML/JS template for the mask-drawing canvas; draw() fills it with the
# % operator. Placeholders, in order:
#   %d %d  width / height of the <canvas1> element
#   %d %d  width / height of the <canvas> element
#   %s     image type used in the data URL (draw() passes the file extension)
#   %s     base64-encoded image data
#   %d     stroke line width
# Strokes are drawn in red while the mouse button is held down. Clicking
# "Finish" resolves the JS promise `data` with the canvas content as a PNG
# data URL.
# NOTE(review): `ctx1` is taken from `canvas` (not `canvas1`), so the loaded
# image is drawn onto the same canvas as the strokes - confirm this is intended.
canvas_html = """
<style>
.button {
  background-color: #4CAF50;
  border: none;
  color: white;
  padding: 15px 32px;
  text-align: center;
  text-decoration: none;
  display: inline-block;
  font-size: 16px;
  margin: 4px 2px;
  cursor: pointer;
}
</style>
<canvas1 width=%d height=%d>
</canvas1>
<canvas width=%d height=%d>
</canvas>

<button>Finish</button>
<script>
var canvas = document.querySelector('canvas')
var ctx = canvas.getContext('2d')

var canvas1 = document.querySelector('canvas1')
var ctx1 = canvas.getContext('2d')


ctx.strokeStyle = 'red';

var img = new Image();
img.src = "data:image/%s;charset=utf-8;base64,%s";
console.log(img)
img.onload = function() {
  ctx1.drawImage(img, 0, 0);
};
img.crossOrigin = 'Anonymous';

ctx.clearRect(0, 0, canvas.width, canvas.height);

ctx.lineWidth = %d
var button = document.querySelector('button')
var mouse = {x: 0, y: 0}

canvas.addEventListener('mousemove', function(e) {
  mouse.x = e.pageX - this.offsetLeft
  mouse.y = e.pageY - this.offsetTop
})
canvas.onmousedown = ()=>{
  ctx.beginPath()
  ctx.moveTo(mouse.x, mouse.y)
  canvas.addEventListener('mousemove', onPaint)
}
canvas.onmouseup = ()=>{
  canvas.removeEventListener('mousemove', onPaint)
}
var onPaint = ()=>{
  ctx.lineTo(mouse.x, mouse.y)
  ctx.stroke()
}

var data = new Promise(resolve=>{
  button.onclick = ()=>{
    resolve(canvas.toDataURL('image/png'))
  }
})
</script>
"""

import base64, os
from google.colab.output import eval_js
from base64 import b64decode
import matplotlib.pyplot as plt
import numpy as np
from shutil import copyfile
import shutil
import matplotlib.pyplot as plt

def draw(imgm, filename='drawing.png', w=400, h=200, line_width=1):
  """Show the drawing canvas and save the finished drawing to `filename`.

  `imgm` is the base64 image data embedded in the canvas page; the image
  type in the data URL is taken from the extension of `filename`. The call
  waits for the JS value `data` (set by the "Finish" button) and writes the
  decoded PNG bytes to `filename`.
  """
  image_type = filename.split('.')[-1]
  page = canvas_html % (w, h, w, h, image_type, imgm, line_width)
  display(HTML(page))

  # The data URL has the form "<header>,<base64 payload>".
  data_url = eval_js("data")
  payload = data_url.split(',')[1]
  with open(filename, 'wb') as out_file:
    out_file.write(b64decode(payload))

# the control image of init_image and mask_image
def make_inpaint_condition(image, image_mask):
    """Build the inpaint control tensor from an image and its mask.

    Args:
        image: PIL-style image; converted to RGB and scaled to [0, 1].
        image_mask: PIL-style image; converted to "L" and scaled to [0, 1].

    Returns:
        A float32 torch tensor of shape (1, 3, H, W) in which every pixel
        whose mask value is above 0.5 is set to -1.0.

    Raises:
        AssertionError: if image and mask differ in height or width.
    """
    image = np.array(image.convert("RGB")).astype(np.float32) / 255.0
    image_mask = np.array(image_mask.convert("L")).astype(np.float32) / 255.0

    # Compare height AND width. The former `[0:1]` slice only checked the
    # height, so a width mismatch surfaced later as an opaque IndexError.
    assert image.shape[0:2] == image_mask.shape[0:2], "image and image_mask must have the same image size"
    image[image_mask > 0.5] = -1.0  # set as masked pixel
    image = np.expand_dims(image, 0).transpose(0, 3, 1, 2)  # HWC -> 1CHW
    image = torch.from_numpy(image)
    return image

# =====================================
# Adetailer
# =====================================
from functools import partial
from diffusers import DPMSolverMultistepScheduler
from asdff import AdCnPipeline, AdPipeline, yolo_detector
from huggingface_hub import hf_hub_download

def ad_model_process(
    face_detector_ad,
    person_detector_ad,
    hand_detector_ad,
    model_repo_id,
    common,
    inpaint,
    image_list_task,
    ):
    """Run the Adetailer pipeline over saved images and save the results.

    Args:
        face_detector_ad, person_detector_ad, hand_detector_ad: flags that
            enable the corresponding YOLO detector from "Bingsu/adetailer".
        model_repo_id: model id passed to ``AdPipeline.from_pretrained``.
        common: forwarded to the pipeline as ``common``.
        inpaint: forwarded to the pipeline as ``inpaint_only``.
        image_list_task: paths of the images to process. Paths that do not
            exist are skipped with a message.

    Returns:
        The paths of the newly saved images. If a result cannot be
        displayed, the original ``image_list_task`` is returned instead.
    """
    # The detectors do not depend on the image, so resolve them once.
    detector_files = (
        (face_detector_ad, "face_yolov8n.pt"),
        (person_detector_ad, "person_yolov8s-seg.pt"),
        (hand_detector_ad, "hand_yolov8n.pt"),
    )
    detectors = [
        partial(yolo_detector, model_path=hf_hub_download("Bingsu/adetailer", filename))
        for enabled, filename in detector_files
        if enabled
    ]

    pipe = AdPipeline.from_pretrained(model_repo_id, torch_dtype=torch.float16)
    try:
        pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
        pipe.safety_checker = None
        pipe.to("cuda")

        image_list_ad = []

        for path in image_list_task:
            if not os.path.exists(path):
                # Previously a missing file reused the image of the previous
                # iteration (or failed with an unbound variable on the first).
                print(f'Image not found, skipping: {path}')
                continue

            # Open the image using PIL and convert it to PIL.Image.Image
            with Image.open(path) as img:
                images_ad = img.convert("RGB")

            result_ad = pipe(images=[images_ad], common=common, inpaint_only=inpaint, detectors=detectors)

            try:
                mediapy.show_images([result_ad[0][0], result_ad[1][0]])
            except Exception:
                # Nothing usable to show: give the caller its input back.
                return image_list_task

            image_path = save_pil_image_with_metadata(result_ad[0][0], f'{os.getcwd()}/images', metadata_list=None)
            image_list_ad.append(image_path)

        return image_list_ad
    finally:
        # Release the pipeline on every exit path, including errors.
        del pipe
        torch.cuda.empty_cache()
        gc.collect()

In [2]:
#@title 👇 Generating Images { form-width: "20%", display-mode: "form" }
#@markdown ---
#@markdown - **Prompt** - Description of the image
#@markdown - **Negative Prompt** - Things you don't want to see or ignore in the image
#@markdown - **Steps** - Number of denoising steps. Higher steps may lead to better results but takes longer time to generate the image. Default is `30`.
#@markdown - **CFG** - Guidance scale ranging from `0` to `20`. Lower values allow the AI to be more creative and less strict at following the prompt. Default is `7.5`.
#@markdown - **Sampler** - The scheduler (sampler) determines how noise is removed at each denoising step. Different schedulers can produce different results, so it is important to choose one that is appropriate for the desired task.
#@markdown - **Seed** - A random value that controls image generation. The same seed and prompt produce the same images. Set `-1` for using random seed values.
#@markdown - **Prompt weights** - A way to control the influence of different parts of the text prompt on the image generation process. Prompt weights can be used to emphasize or de-emphasize certain aspects of the image, such as the object, the scene, or the style. Currently, the [Compel syntax](https://github.com/damian0815/compel/blob/main/doc/syntax.md) is used. You can also enable `Convert Prompt weights` to automatically convert syntax from `(word:1.1)` to `(word)1.1` or `(word)` to `(word)+`, making it compatible with Compel weights. Compel scales more strongly with its weight values, so smaller weights are usually enough for good results.
#@markdown - **ControlNet** - Is a neural network model for controlling diffusion models. It allows users to input additional information, such as edge maps, segmentation maps, and key points, into diffusion models to guide the image generation process.
#@markdown ---
%cd /content
import ipywidgets as widgets, mediapy, random
from diffusers.models.attention_processor import AttnProcessor2_0
from PIL import Image
import IPython.display
from diffusers import (
    DPMSolverMultistepScheduler,
    DPMSolverSinglestepScheduler,
    KDPM2DiscreteScheduler,
    KDPM2AncestralDiscreteScheduler,
    EulerDiscreteScheduler,
    EulerAncestralDiscreteScheduler,
    HeunDiscreteScheduler,
    LMSDiscreteScheduler,
    DDIMScheduler,
    DiffusionPipeline,
)
import time
from IPython.utils import capture
import logging
logging.getLogger("diffusers").setLevel(logging.ERROR)

#from IPython.display import display
from ipywidgets import interactive, Layout, VBox

#Get scheduler
def get_scheduler(name):
  """Create the scheduler selected by its display name.

  The scheduler is built from the config of the current
  `model.pipe.scheduler`. Returns None when `name` is not a known name.
  """
  if not isinstance(name, str):
    return None

  # Display name -> builder. The lambdas keep the construction lazy, so only
  # the selected scheduler class is touched.
  builders = {
    "DPM++ 2M": lambda cfg: DPMSolverMultistepScheduler.from_config(cfg),
    "DPM++ 2M Karras": lambda cfg: DPMSolverMultistepScheduler.from_config(cfg, use_karras_sigmas=True),
    "DPM++ 2M SDE": lambda cfg: DPMSolverMultistepScheduler.from_config(cfg, algorithm_type="sde-dpmsolver++"),
    "DPM++ 2M SDE Karras": lambda cfg: DPMSolverMultistepScheduler.from_config(cfg, use_karras_sigmas=True, algorithm_type="sde-dpmsolver++"),
    "DPM++ SDE": lambda cfg: DPMSolverSinglestepScheduler.from_config(cfg),
    "DPM++ SDE Karras": lambda cfg: DPMSolverSinglestepScheduler.from_config(cfg, use_karras_sigmas=True),
    "DPM2": lambda cfg: KDPM2DiscreteScheduler.from_config(cfg),
    "DPM2 Karras": lambda cfg: KDPM2DiscreteScheduler.from_config(cfg, use_karras_sigmas=True),
    "Euler": lambda cfg: EulerDiscreteScheduler.from_config(cfg),
    "Euler a": lambda cfg: EulerAncestralDiscreteScheduler.from_config(cfg),
    "Heun": lambda cfg: HeunDiscreteScheduler.from_config(cfg),
    "LMS": lambda cfg: LMSDiscreteScheduler.from_config(cfg),
    "LMS Karras": lambda cfg: LMSDiscreteScheduler.from_config(cfg, use_karras_sigmas=True),
    "DDIMScheduler": lambda cfg: DDIMScheduler.from_config(cfg),
  }

  build = builders.get(name)
  if build is None:
    return None
  return build(model.pipe.scheduler.config)

# PARAMETER WIDGETS
# Common width of the compact controls below.
width = "250px"

# Model and VAE pickers; model_list / vae_model_list are defined elsewhere.
select_model = widgets.Dropdown(options=model_list, description="Model:")
vae_model_dropdown = widgets.Dropdown(options=vae_model_list, description="VAE:")


def _prompt_textarea(placeholder):
  """Return an empty 5-row, 550px wide text area with the given placeholder."""
  return widgets.Textarea(
      value="",
      placeholder=placeholder,
      rows=5,
      layout=widgets.Layout(width="550px"),
  )


prompt = _prompt_textarea("Enter prompt")
neg_prompt = _prompt_textarea("Enter negative prompt")

# Numeric generation settings.
num_images = widgets.IntText(value=1, description="Images:", layout=widgets.Layout(width=width))
steps = widgets.IntText(value=30, description="Steps:", layout=widgets.Layout(width=width))
CFG = widgets.FloatText(value=7.5, step=0.5, description="CFG:", layout=widgets.Layout(width=width))

# Display names offered in the scheduler dropdown (the names get_scheduler handles).
_sampler_names = [
    "DPM++ 2M",
    "DPM++ 2M Karras",
    "DPM++ 2M SDE",
    "DPM++ 2M SDE Karras",
    "DPM++ SDE",
    "DPM++ SDE Karras",
    "DPM2",
    "DPM2 Karras",
    "Euler",
    "Euler a",
    "Heun",
    "LMS",
    "LMS Karras",
    "DDIMScheduler",
]
select_sampler = widgets.Dropdown(
    options=_sampler_names,
    description="Scheduler:",
    layout=widgets.Layout(width=width),
)
#select_sampler.style.description_width = "auto"

# Output size in pixels. BoundedIntText is used because plain IntText has no
# `min` / `max` traits, so the intended 256-2048 limits were never enforced.
img_height = widgets.BoundedIntText(
    min=256,
    max=2048,
    value=512,
    description="Height:",
    layout=widgets.Layout(width=width)
)

img_width = widgets.BoundedIntText(
    min=256,
    max=2048,
    value=512,
    description="Width:",
    layout=widgets.Layout(width=width)
)

# Seed input; the default value is -1.
random_seed = widgets.IntText(
    value=-1, description="Seed:", layout=widgets.Layout(width=width), disabled=False,
)

# Action buttons.
generate = widgets.Button(
    description="Generate", disabled=False, button_style="primary",
    layout=widgets.Layout(width=width),
)
# textual inversion
show_textual_inversion = widgets.Button(
    description="List available textual inversions", disabled=False, button_style="info",
    layout=widgets.Layout(width=width),
)

# Option checkboxes, both unchecked by default.
active_ti = widgets.Checkbox(value=False, description='Active Textual Inversion in prompt (Experimental)')
# alternative prompt weights
weights_prompt = widgets.Checkbox(value=False, description='Convert Prompt weights')

# LoRA selectors: five identical (dropdown, scale) pairs.
def _lora_dropdown(slot):
  """Return the dropdown for LoRA slot number `slot`."""
  return widgets.Dropdown(
      options=lora_model_list,
      description=f"Lora{slot}:",
      layout={'width':'190px'}
  )


def _lora_scale():
  """Return the scale input of a LoRA slot, limited to [-2.0, 2.0].

  BoundedFloatText is used because plain FloatText has no `min` / `max`
  traits, so the intended limits were never enforced.
  """
  return widgets.BoundedFloatText(
      min=-2.0,
      max=2.0,
      step=0.01,
      value=1,
      layout={'width':'56px'}
  )


#lora1
select_lora1 = _lora_dropdown(1)
lora_weights_scale1 = _lora_scale()
#lora2
select_lora2 = _lora_dropdown(2)
lora_weights_scale2 = _lora_scale()
#lora3
select_lora3 = _lora_dropdown(3)
lora_weights_scale3 = _lora_scale()
#lora4
select_lora4 = _lora_dropdown(4)
lora_weights_scale4 = _lora_scale()
#lora5
select_lora5 = _lora_dropdown(5)
lora_weights_scale5 = _lora_scale()

display_imgs = widgets.Output()


### second part ####
# Resolution handed to the ControlNet preprocessors.
preprocess_resolution_global = widgets.IntSlider(
    description='Preprocess resolution ControlNet',
    min=256,
    max=2048,
    value=512,
)

# Task names, in the order CONTROLNET_MODEL_IDS declares them.
control_model_list = list(CONTROLNET_MODEL_IDS)

# Order in which the tasks are offered in the dropdown (index 9 is left out).
_TASK_MENU_ORDER = (13, 12, 0, 1, 2, 3, 4, 5, 6, 7, 8, 10, 11)

# Dropdown used to choose the generation task.
options_controlnet = widgets.Dropdown(
    options=[control_model_list[position] for position in _TASK_MENU_ORDER],
    description='TASK:',
    layout=widgets.Layout(width="550px"),
)

neg_prompt = widgets.Textarea(
    value="",
    placeholder="Enter negative prompt",
    rows=5,
    layout=widgets.Layout(width="550px"),
)


def _hidden():
    """Return a fresh Layout that starts hidden (each widget needs its own)."""
    return Layout(visibility='hidden')


def _preprocessor_menu(default, choices):
    """Build a hidden 'Preprocessor:' dropdown offering `choices`, preset to `default`."""
    return widgets.Dropdown(
        value=default,
        description='Preprocessor:',
        options=choices,
        layout=_hidden(),
    )


# Per-task extra controls, keyed by task name.  All of them start hidden;
# update_widgets reveals the list belonging to the selected task.  The
# insertion order is also the order in which they are stacked in the sidebar.
int_inputs = {
    control_model_list[13]: [],
    control_model_list[12]: [
        widgets.FloatSlider(value=1.0, min=0.01, max=1.0, step=0.01, description='Inpaint strength:', layout=_hidden()),
        widgets.Textarea(value="", placeholder="/content/my_mask.png", rows=1, description='Mask path:', layout=_hidden()),
    ],
    control_model_list[0]: [
        _preprocessor_menu('Openpose', ['None', 'Openpose']),
    ],
    control_model_list[1]: [
        widgets.IntText(value=100, min=1, max=255, description='Canny low threshold:', layout=_hidden()),
        widgets.IntText(value=200, min=1, max=255, description='Canny high threshold:', layout=_hidden()),
    ],
    control_model_list[2]: [
        # NOTE(review): the default 0.1 is below the declared min=1 -- confirm
        # the intended bounds.
        widgets.FloatText(value=0.1, min=1, max=2.0, description='Hough value threshold (MLSD):', layout=_hidden()),
        widgets.FloatText(value=0.1, min=1, max=20.0, description='Hough distance threshold (MLSD):', layout=_hidden()),
    ],
    control_model_list[3]: [
        _preprocessor_menu('HED', ['HED', 'PidiNet', 'None']),
    ],
    control_model_list[4]: [
        _preprocessor_menu('PidiNet', ['HED', 'PidiNet', 'HED safe', 'PidiNet safe', 'None']),
    ],
    control_model_list[5]: [
        _preprocessor_menu('UPerNet', ['UPerNet', 'None']),
    ],
    control_model_list[6]: [
        _preprocessor_menu('DPT', ['Midas', 'DPT', 'None']),
    ],
    control_model_list[7]: [
        _preprocessor_menu('NormalBae', ['NormalBae', 'None']),
    ],
    control_model_list[8]: [
        _preprocessor_menu('Lineart', ['Lineart', 'Lineart coarse', 'None', 'Lineart (anime)', 'None (anime)']),
    ],
    control_model_list[10]: [
        _preprocessor_menu('ContentShuffle', ['ContentShuffle', 'None']),
    ],
}

# Show the extra controls of the chosen task and hide everyone else's.
def update_widgets(option):
    """Make only the widgets registered for `option` visible.

    Walks every entry of the global `int_inputs` mapping and sets each
    widget's `layout.visibility` to 'visible' when its task equals `option`,
    and to 'hidden' otherwise.
    """
    for task_name, task_widgets in int_inputs.items():
        state = 'visible' if task_name == option else 'hidden'
        for control in task_widgets:
            control.layout.visibility = state

# Called only for its side effect: the returned container is discarded and
# never displayed.
# NOTE(review): presumably this makes changes of the TASK dropdown call
# update_widgets with the new value -- confirm against the ipywidgets
# `interactive` documentation for the pinned version.
interactive(update_widgets, option=options_controlnet)

### GENERATE ###

# Tasks whose only extra inputs are a preprocessor name and the preprocess
# resolution, mapped to the name of the Model method that runs them.  The
# method is looked up lazily with getattr() so only the selected one is touched.
_PREPROCESSOR_TASK_METHODS = {
    'Openpose': 'process_openpose',
    'scribble': 'process_scribble',
    'softedge': 'process_softedge',
    'segmentation': 'process_segmentation',
    'depth': 'process_depth',
    # The original code routed NormalBae to process_mlsd (copy-paste slip),
    # passing it a preprocessor_name instead of the MLSD thresholds.
    'NormalBae': 'process_normal',
}


def _selected_loras():
    """Return [(slot, lora_name, scale)] for every LoRA slot not set to "None".

    The values are snapshotted once so the scale that is later subtracted is
    exactly the scale that was applied, even if the widgets change meanwhile.
    """
    slots = (
        (select_lora1, lora_weights_scale1),
        (select_lora2, lora_weights_scale2),
        (select_lora3, lora_weights_scale3),
        (select_lora4, lora_weights_scale4),
        (select_lora5, lora_weights_scale5),
    )
    return [
        (slot, selector.value, scale.value)
        for slot, (selector, scale) in enumerate(slots, start=1)
        if selector.value != "None"
    ]


def _load_rgb_array(image_path):
    """Open `image_path` and return its pixels as a 3-channel uint8 array.

    Converting to RGB drops an alpha channel and also lets grayscale or
    palette images through, which plain `[:, :, :3]` slicing cannot handle.
    """
    with Image.open(image_path) as opened:
        return np.array(opened.convert("RGB"), dtype=np.uint8)


def _run_controlnet_task(task, control_image, prompt_emb, negative_prompt_emb, seed):
    """Run the ControlNet / inpaint routine for `task` on the global `model`.

    Args:
        task: Task name selected in the TASK dropdown.
        control_image: RGB array of the uploaded control image.
        prompt_emb: Embedding of the positive prompt.
        negative_prompt_emb: Embedding of the negative prompt.
        seed: Seed forwarded to the task routine.

    Returns:
        Whatever the task routine returns, or None when `task` has no routine.
    """
    global mask_control

    # Arguments every task routine receives.  The text prompts stay empty
    # because the prompts are supplied as precomputed embeddings.
    shared = dict(
        image=control_image,
        prompt='',
        negative_prompt='',
        prompt_embeds=prompt_emb,
        negative_prompt_embeds=negative_prompt_emb,
        additional_prompt="",
        num_images=num_images.value,
        image_resolution=min(img_height.value, img_width.value),
        num_steps=steps.value,
        guidance_scale=CFG.value,
        seed=seed,
    )
    resolution = preprocess_resolution_global.value

    if task == 'Inpaint':
        # A mask path typed into the widget wins; otherwise the mask drawn in
        # the upload cell (global `mask_control`) is used.
        typed_mask_path = int_inputs['Inpaint'][1].value
        if os.path.exists(typed_mask_path):
            mask_control = typed_mask_path
        return model.process_inpaint(
            **shared,
            preprocess_resolution=resolution,
            image_mask=_load_rgb_array(mask_control),
            strength=int_inputs['Inpaint'][0].value,
        )

    if task == 'Canny':
        print('BETA: Canny')
        return model.process_canny(
            **shared,
            low_threshold=int_inputs['Canny'][0].value,
            high_threshold=int_inputs['Canny'][1].value,
        )

    if task == 'MLSD':
        print('BETA: MLSD')
        return model.process_mlsd(
            **shared,
            value_threshold=int_inputs['MLSD'][0].value,
            distance_threshold=int_inputs['MLSD'][1].value,
            preprocess_resolution=resolution,
        )

    if task in _PREPROCESSOR_TASK_METHODS:
        print(f'BETA: {task}, resolution min(W, H)')
        run_task = getattr(model, _PREPROCESSOR_TASK_METHODS[task])
        return run_task(
            **shared,
            preprocessor_name=int_inputs[task][0].value,
            preprocess_resolution=resolution,
        )

    if 'lineart' in task:
        print('BETA: lineart, resolution min(W, H)')
        return model.process_lineart(
            **shared,
            preprocessor_name=int_inputs['lineart'][0].value,
            preprocess_resolution=resolution,
        )

    if task == 'shuffle':
        print('BETA: shuffle, resolution min(W, H)')
        return model.process_shuffle(
            **shared,
            preprocessor_name=int_inputs['shuffle'][0].value,
        )

    if task == 'ip2p':
        print('BETA: (inpaint pix2pix) ip2p, resolution min(W, H)')
        return model.process_ip2p(**shared)

    return None


def generate_img(i):
    """Generate images from the current widget state, then show and save them.

    Click handler of the "Generate" button.  It (re)loads the pipeline for the
    selected model/task/VAE, applies the selected LoRAs, builds the prompt
    embeddings for local models, runs the selected task, removes the LoRAs
    again, displays the images and writes them to ./images.  All output goes
    to the `display_imgs` widget.

    Args:
        i: Unused; the argument supplied by the button's click callback.
    """
    global model, compel, image_list

    # Clear output
    display_imgs.clear_output()
    generate.disabled = True

    # A seed of -1 in the widget means "pick a random seed for this run".
    seed = random.randint(0, 2147483647) if random_seed.value == -1 else random_seed.value
    task = options_controlnet.value

    with display_imgs:

        print("Running...")

        # First load: the global `model` does not exist yet (or was deleted).
        try:
            model
        except NameError:
            model = Model(base_model_id=select_model.value, task_name=task, vae_model=vae_model_dropdown.value)

        model.load_pipe(select_model.value, task_name=task, vae_model=vae_model_dropdown.value)

        display_imgs.clear_output()

        model.pipe.to("cuda")

        # Apply every selected LoRA; an incompatible one is reported and skipped.
        active_loras = _selected_loras()
        for slot, lora_name, lora_scale in active_loras:
            print(f'lora{slot}')
            try:
                model.pipe = lora_mix_load(model.pipe, lora_name, lora_scale)
            except Exception:
                print(f"ERROR: LoRA not compatible:  {lora_name}")

        model.pipe.to("cuda")

        # Models given as a local path get the SD 1.5 prompt handling below;
        # anything else is treated as SDXL and receives the raw text prompts.
        is_local_model = os.path.exists(select_model.value)

        # Prompt Optimizations for 1.5
        if is_local_model:
            if active_ti.value:
                # Textual Inversion
                for name, directory_name in embed_list:
                    try:
                        model.pipe.load_textual_inversion(directory_name, token=name)
                    except ValueError:
                        # Presumably the token was already loaded by an
                        # earlier run; nothing to do.
                        pass
                    except Exception:
                        print(f"Can't apply {name}")

            # Prompt weights.  `compel` is kept as a global for the duration of
            # the embedding step and removed again afterwards.
            compel = Compel(tokenizer=model.pipe.tokenizer, text_encoder=model.pipe.text_encoder, truncate_long_prompts=False)

            prompt_ti = model.pipe.maybe_convert_prompt(prompt.value, model.pipe.tokenizer)
            negative_prompt_ti = model.pipe.maybe_convert_prompt(neg_prompt.value, model.pipe.tokenizer)

            if weights_prompt.value:
                prompt_ti = prompt_weight_conversor(prompt_ti)
                negative_prompt_ti = prompt_weight_conversor(negative_prompt_ti)

            prompt_emb = merge_embeds(tokenize_line(prompt_ti, model.pipe.tokenizer))
            negative_prompt_emb = merge_embeds(tokenize_line(negative_prompt_ti, model.pipe.tokenizer))

            # Both embeddings must have the same shape; pad them if they differ.
            if prompt_emb.shape != negative_prompt_emb.shape:
                print('___')
                prompt_emb, negative_prompt_emb = compel.pad_conditioning_tensors_to_same_length([prompt_emb, negative_prompt_emb])

            del compel

        model.pipe.enable_xformers_memory_efficient_attention()
        model.pipe.scheduler = get_scheduler(select_sampler.value)
        model.pipe.safety_checker = None

        if task != 'txt2img':
            # Every task except txt2img needs the image from the upload cell.
            try:
                print(f'Control image: {destination_path_cn_img}')
                array_rgb = _load_rgb_array(destination_path_cn_img)
            except Exception:
                print("To use this function, you have to upload an image in the cell below first 👇")
                del model
                torch.cuda.empty_cache()
                gc.collect()
                generate.disabled = False
                return

        if not is_local_model or task == 'txt2img':
            pipe_kwargs = dict(
                height=img_height.value,
                width=img_width.value,
                num_inference_steps=steps.value,
                guidance_scale=CFG.value,
                num_images_per_prompt=num_images.value,
                generator=torch.Generator("cuda").manual_seed(seed),
            )
            if not is_local_model:
                # SDXL
                pipe_kwargs.update(prompt=prompt.value, negative_prompt=neg_prompt.value)
            else:
                pipe_kwargs.update(prompt_embeds=prompt_emb, negative_prompt_embeds=negative_prompt_emb)
            images = model.pipe(**pipe_kwargs).images
        else:
            images = _run_controlnet_task(task, array_rgb, prompt_emb, negative_prompt_emb, seed)

        # Undo the LoRAs by applying each one again with the negated scale.
        # Best effort on purpose: a LoRA that failed to load fails here in the
        # same way, and reporting it a second time would only add noise.
        for _, lora_name, lora_scale in active_loras:
            try:
                model.pipe = lora_mix_load(model.pipe, lora_name, -lora_scale)
            except Exception:
                pass

        torch.cuda.empty_cache()
        gc.collect()

        if images is None:
            print(f"No generation routine is available for task: {task}")
            generate.disabled = False
            return

        mediapy.show_images(images)

        # Save img
        image_list = []

        metadata = [
            prompt.value,
            neg_prompt.value,
            select_model.value,
            vae_model_dropdown.value,
            steps.value,
            CFG.value,
            select_sampler.value,
            # Record the seed actually used, not the widget value, which is
            # -1 whenever a random seed was drawn.
            seed,
        ]

        directory_images = './images'
        os.makedirs(directory_images, exist_ok=True)

        for image_ in images:
            image_path = save_pil_image_with_metadata(image_, directory_images, metadata)
            image_list.append(image_path)

        print(f"Seed:\n{seed}")

    generate.disabled = False



def elemets_textual_inversion(value):
    """List the known textual-inversion tokens in the output area for 7 seconds.

    Click handler of the "List available textual inversions" button.  Prints
    the first item of every `embed_list` entry inside `display_imgs`, waits
    seven seconds, then clears the output area again.

    Args:
        value: Unused; the argument supplied by the button's click callback.
    """
    with display_imgs:
        print('Clearing output in 7 seconds')
        print('The embeddings currently supported. Write in the prompt the word for use')
        for token, _ in embed_list:
            print(token)
        time.sleep(7)
        display_imgs.clear_output()


# Hook the two buttons up to their click handlers.
show_textual_inversion.on_click(elemets_textual_inversion)
generate.on_click(generate_img)

# TABS
tab = widgets.Tab()

#Display

# TAB 1 -- left: sampling settings and LoRA rows, center: model and prompt
# controls, right: ControlNet options.
_lora_rows = [
    widgets.HBox([select_lora1, lora_weights_scale1]),
    widgets.HBox([select_lora2, lora_weights_scale2]),
    widgets.HBox([select_lora3, lora_weights_scale3]),
    widgets.HBox([select_lora4, lora_weights_scale4]),
    widgets.HBox([select_lora5, lora_weights_scale5]),
]
_settings_column = widgets.VBox(
    [num_images, steps, CFG, select_sampler, img_height, img_width, random_seed] + _lora_rows
)

_prompt_column = widgets.VBox(
    [
        widgets.HTML(
            value="<h2>SD Interactive</h2>",
            layout=widgets.Layout(display="flex", justify_content="center"),
        ),
        options_controlnet,
        select_model,
        vae_model_dropdown,
        prompt,
        neg_prompt,
        show_textual_inversion,
        active_ti,
        weights_prompt,
        generate,
    ]
)

# The resolution slider first, then every per-task widget in the order the
# int_inputs mapping stores them.
_task_widgets = [control for group in int_inputs.values() for control in group]
_controlnet_column = widgets.VBox([preprocess_resolution_global] + _task_widgets)

tab_sd = widgets.VBox(
    [
        widgets.AppLayout(
            header=None,
            left_sidebar=_settings_column,
            center=_prompt_column,
            right_sidebar=_controlnet_column,
            footer=None,
        ),
        display_imgs,
    ]
)

# TAB 2 -- placeholder page: a heading and an icon.
tab_settings = widgets.VBox(
    [
        widgets.HTML(
            value="<h2>SD Interactive</h2>",
            layout=widgets.Layout(display="flex", justify_content="center"),
        ),
        widgets.HTML(
            value="<p>🔄</p>",
            layout=widgets.Layout(display="flex", justify_content="center"),
        ),
    ]
)

#
# Each page of the Tab widget is wrapped in its own VBox container.
tab.children = [
    widgets.VBox(children=[tab_sd]),
    widgets.VBox(children=[tab_settings]),
]

tab_titles = ["Stable Diffusion", "More Settings"]
# This notebook pins ipywidgets 7.7.1, where Tab has no `titles` trait:
# assigning `tab.titles` only creates an inert Python attribute and the tabs
# stay unlabeled.  `set_title` is the per-page API in 7.x.
# NOTE(review): set_title is believed to still exist in ipywidgets 8 -- confirm
# before unpinning.
for tab_index, tab_title in enumerate(tab_titles):
    tab.set_title(tab_index, tab_title)
tab.selected_index = 0

display(tab)

/content


Tab(children=(VBox(children=(VBox(children=(AppLayout(children=(VBox(children=(IntText(value=1, description='I…

In [3]:
#@title Upload an image here for use in Inpainting or ControlNet. 👈‍‍ 🖼️🖼️🖼️
#@markdown - To use Controlnet, you need to upload the control image with this cell
Create_mask_for_Inpaint = True # @param {type:"boolean"}
stroke_width = 24 # @param {type:"integer"}
from google.colab import files
from IPython.display import HTML
import os
import shutil
%cd /content
uploaded = files.upload()

filename = next(iter(uploaded))
print(f'Uploaded file: {filename}')

upload_folder = 'uploaded_controlnet_image/'
if not os.path.exists(upload_folder):
    os.makedirs(upload_folder)

source_path = filename
destination_path_cn_img = os.path.join(upload_folder, filename)
shutil.move(source_path, destination_path_cn_img)
print(f'Moved file to {destination_path_cn_img}')

if options_controlnet.value == 'Inpaint' or Create_mask_for_Inpaint:
    init_image = destination_path_cn_img
    name_without_extension = os.path.splitext(init_image.split('/')[-1])[0]

    image64 = base64.b64encode(open(init_image, 'rb').read())
    image64 = image64.decode('utf-8')

    print('\033[34m Draw the mask with the mouse \033[0m')
    img = np.array(plt.imread(f'{init_image}')[:,:,:3])

    draw(image64, filename=f"./{name_without_extension}_draw.png", w=img.shape[1], h=img.shape[0], line_width=stroke_width)

    with_mask = np.array(plt.imread(f"./{name_without_extension}_draw.png")[:,:,:3])
    mask = (with_mask[:,:,0]==1)*(with_mask[:,:,1]==0)*(with_mask[:,:,2]==0)
    plt.imsave(f"./{name_without_extension}_mask.png",mask, cmap='gray')
    mask_control = f"./{name_without_extension}_mask.png"
    print(f'\033[34m Mask saved: {mask_control} \033[0m')

/content


Saving image006.png to image006.png
Uploaded file: image006.png
Moved file to uploaded_controlnet_image/image006.png
[34m Draw the mask with the mouse [0m


[34m Mask saved: ./image006_mask.png [0m


In [None]:
#@title 👇 Upscale and face restoration { form-width: "20%", display-mode: "form" }
from IPython.utils import capture
import os
import shutil

# One-time CodeFormer installation.  The existence of the CodeFormer directory
# is the "already installed" marker, so the steps below run only the first time.
%cd /content
directory_codeformer = '/content/CodeFormer/'
# The whole setup runs inside capture_output(), so the installer output is
# captured in `cap` instead of being shown in the cell.
with capture.capture_output() as cap:
  if not os.path.exists(directory_codeformer):
      # NOTE(review): the marker directory is created before the clone, so a
      # failed clone or install is not retried on the next run -- confirm this
      # is acceptable.
      os.makedirs(directory_codeformer)

      # Setup
      # Clone CodeFormer and enter the CodeFormer folder
      %cd /content
      !git clone https://github.com/sczhou/CodeFormer.git
      %cd CodeFormer


      # Set up the environment
      # Install python dependencies
      !pip install -q -r requirements.txt
      !pip -q install ffmpeg
      # Install basicsr
      !python basicsr/setup.py develop

      # Download the pre-trained model
      !python scripts/download_pretrained_models.py facelib
      !python scripts/download_pretrained_models.py CodeFormer
  # The captured output is not needed; drop it.
  del cap
# Visualization function
import cv2
import matplotlib.pyplot as plt
def display(img1, img2):
  fig = plt.figure(figsize=(25, 10))
  ax1 = fig.add_subplot(1, 2, 1)
  plt.title('Input', fontsize=16)
  ax1.axis('off')
  ax2 = fig.add_subplot(1, 2, 2)
  plt.title('CodeFormer', fontsize=16)
  ax2.axis('off')
  ax1.imshow(img1)
  ax2.imshow(img2)
def imread(img_path):
  img = cv2.imread(img_path)
  img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
  return img

# Copy imgs
Select_an_image = "" # @param {type:"string"}

# PROCESS AD
if os.path.exists(Select_an_image.strip()):
    image_list = [Select_an_image.replace('/content/', '').strip()]

destination_directory = '/content/CodeFormer/inputs/user_upload'
!rm -rf /content/CodeFormer/inputs/user_upload/*
os.makedirs(destination_directory, exist_ok=True)
for image_path in image_list:
    image_filename = os.path.basename('/content/'+image_path)
    destination_path = os.path.join(destination_directory, image_filename)
    try:
        shutil.copyfile('/content/'+image_path, destination_path)
        print(f"Image '{image_filename}' has been copied to '{destination_path}'")
    except Exception as e:
        print(f"Failed to copy '{image_filename}' to '{destination_path}': {e}")

#@markdown `CODEFORMER_FIDELITY`: Balance the quality (lower number) and fidelity (higher number)<br>
# you can add '--bg_upsampler realesrgan' to enhance the background
CODEFORMER_FIDELITY = 0.7 #@param {type:"slider", min:0, max:1, step:0.01}
#@markdown `BACKGROUND_ENHANCE`: Enhance background image with Real-ESRGAN<br>
BACKGROUND_ENHANCE = True #@param {type:"boolean"}
#@markdown `FACE_UPSAMPLE`: Upsample restored faces for high-resolution AI-created images<br>
FACE_UPSAMPLE = False #@param {type:"boolean"}
#markdown `HAS_ALIGNED`: Input are cropped and aligned faces<br>
HAS_ALIGNED =  False
#@markdown `UPSCALE`: The final upsampling scale of the image. Default: 2<br>
UPSCALE = 2 #@param {type:"slider", min:2, max:8, step:1}
#markdown `DETECTION_MODEL`: Face detector. Default: retinaface_resnet50<br>
DETECTION_MODEL = "retinaface_resnet50"
#markdown `DRAW_BOX`: Draw the bounding box for the detected faces.
DRAW_BOX = False

BACKGROUND_ENHANCE = '--bg_upsampler realesrgan' if BACKGROUND_ENHANCE else ''
FACE_UPSAMPLE = '--face_upsample' if FACE_UPSAMPLE else ''
HAS_ALIGNED = '--has_aligned' if HAS_ALIGNED else ''
DRAW_BOX = '--draw_box' if DRAW_BOX else ''
%cd CodeFormer
!python inference_codeformer.py -w $CODEFORMER_FIDELITY --input_path {destination_directory} {BACKGROUND_ENHANCE} {FACE_UPSAMPLE} {HAS_ALIGNED} --upscale {UPSCALE} --detection_model {DETECTION_MODEL} {DRAW_BOX}


import os
import glob

input_folder = 'inputs/user_upload'
result_folder = f'results/user_upload_{CODEFORMER_FIDELITY}/final_results'
input_list = sorted(glob.glob(os.path.join(input_folder, '*')))
for input_path in input_list:
  img_input = imread(input_path)
  basename = os.path.splitext(os.path.basename(input_path))[0]
  output_path = os.path.join(result_folder, basename+'.png')
  img_output = imread(output_path)
  display(img_input, img_output)

%cd /content

In [None]:
#@title Download Images
import os
from google.colab import files
!rm /content/results.zip
!ls /content/images
print('Download results')
os.system(f'zip -r results.zip /content/images')
try:
  files.download("results.zip")
except:
  print("Error")

In [None]:
#@title Download Upscale results
import os
from google.colab import files
import shutil
%cd /content/CodeFormer
!ls results
print('Download results')
os.system(f'zip -r results.zip results/user_upload_{CODEFORMER_FIDELITY}/final_results')
try:
  files.download("results.zip")
except:
  files.download(f'/content/CodeFormer/results/{filename[:-4]}_{CODEFORMER_FIDELITY}/{filename}')
%cd /content

# Extras

In [None]:
# You can also use this cell to simply reload the model in case you need to.
# Removing the global `model` makes the next "Generate" click build a new one
# (generate_img creates the model when the name is not defined).
# Raises NameError if no model has been created yet.
del model

In [None]:
from PIL import Image
import os

# Drop the generation model if one exists, then release cached GPU memory.
# Only the "model was never created" case is expected here, so catch just
# NameError; the cleanup runs once on either path.
try:
  del model
except NameError:
  pass
torch.cuda.empty_cache()
gc.collect()

# OPTIONS #
# @markdown # Adetailer
# Detector switches and the inpainting model, editable through the Colab form.
face_detector_ad = True # @param {type:"boolean"}
person_detector_ad = True # @param {type:"boolean"}
hand_detector_ad = False # @param {type:"boolean"}
model_for_inpaint = "frankjoshua/toonyou_beta6" # @param {type:"string"}
# Two settings dictionaries passed to ad_model_process below; both currently
# hold the same prompt and step count.
# NOTE(review): how `common` and `inpaint` are applied is decided inside
# ad_model_process, which is defined elsewhere -- confirm there.
common = {
    "prompt": "masterpiece, best quality",
    "num_inference_steps": 50,
}
inpaint = {
    "prompt": "masterpiece, best quality",
    "num_inference_steps": 50,
}
Select_an_image = "" # @param {type:"string"}

# PROCESS AD
# An existing path typed into the form replaces the image list; otherwise the
# current global `image_list` is used (generate_img fills it with the paths of
# the images it saved).
if os.path.exists(Select_an_image):
    image_list = [Select_an_image]

# The result is stored back into `image_list`.
image_list = ad_model_process(
    face_detector_ad,
    person_detector_ad,
    hand_detector_ad,
    model_for_inpaint,
    common,
    inpaint,
    image_list,
)