<a href="https://colab.research.google.com/github/R3gm/SD_diffusers_interactive/blob/main/Stable_diffusion_interactive_notebook.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Interactive Stable Diffusion

| Description | Link |
| ----------- | ---- |
| 🎉 Repository | [![GitHub Repository](https://img.shields.io/github/stars/R3gm/SD_diffusers_interactive?style=social)](https://github.com/R3gm/SD_diffusers_interactive) |


- Prompt weights: Depending on the model and CFG, you can weight terms up to around 1.5 or 1.6 before the results start to get weird.
- Controlnet 1.1 for SD
- SDXL models only support txt2img
- More functions, more bugs; less than 10 words, more laughs


This Google Colab notebook offers a user-friendly interface for generating AI images from text prompts using Stable Diffusion. It uses HuggingFace Diffusers and Jupyter widgets to provide a simple, lightweight alternative to web-based tools, so that anyone can `get started with Stable Diffusion`.

Based on [redromnon's repository](https://github.com/redromnon/stable-diffusion-interactive-notebook)

In [None]:
#@title 1. 👇 Installing dependencies { display-mode: "form" }

# Python dependencies. diffusers and compel are installed from their GitHub
# repositories; omegaconf, safetensors, ipywidgets, controlnet_aux and
# mediapipe are pinned to exact versions.
# NOTE(review): safetensors appears twice in this list (once pinned to 0.3.3,
# once unpinned).
!pip install -q omegaconf==2.3.0 torch git+https://github.com/huggingface/diffusers.git git+https://github.com/damian0815/compel.git invisible_watermark  transformers accelerate scipy safetensors==0.3.3 xformers safetensors mediapy ipywidgets==7.7.1 controlnet_aux==0.0.6 mediapipe==0.10.1 pytorch-lightning asdff

# System tools: git-lfs and aria2.
# NOTE(review): the git-lfs install has no -y flag, unlike the aria2 install.
!apt install git-lfs
!git lfs install
!apt -y install -qq aria2

In [None]:
#@title 2. 👇 Download Model: Please provide a link for the Civitai API, Google Drive, or Hugging Face. { form-width: "20%", display-mode: "form" }
import os
# Work from /content: the model/LoRA/VAE folders below are created with
# relative paths, so they end up under this directory.
%cd /content

def download_things(directory, url, hf_token=""):
    url = url.strip()

    if "drive.google.com" in url:
        original_dir = os.getcwd()
        os.chdir(directory)
        !gdown --fuzzy {url}
        os.chdir(original_dir)
    elif "huggingface.co" in url:
        if "/blob/" in url:
            url = url.replace("/blob/", "/resolve/")
        user_header = f'"Authorization: Bearer {hf_token}"'
        !aria2c --console-log-level=error --summary-interval=10 --header={user_header} -c -x 16 -k 1M -s 16 {url} -d {directory}  -o {url.split('/')[-1]}
    else:
        !aria2c --console-log-level=error --summary-interval=10 -c -x 16 -k 1M -s 16 -d {directory} {url}

def get_model_list(directory_path):
    """List the model files found directly inside ``directory_path``.

    A file counts as a model when its extension (compared case-insensitively)
    is one of ``.ckpt``, ``.pt``, ``.pth``, ``.safetensors`` or ``.bin``.
    Entries are visited in sorted filename order so the result is stable
    between runs. Each match is also printed.

    Args:
        directory_path: Folder to scan (not recursive).

    Returns:
        A list of ``(name_without_extension, file_path)`` tuples.
    """
    valid_extensions = {'.ckpt', '.pt', '.pth', '.safetensors', '.bin'}
    model_list = []

    for filename in sorted(os.listdir(directory_path)):
        stem, extension = os.path.splitext(filename)
        if extension.lower() not in valid_extensions:
            continue
        file_path = os.path.join(directory_path, filename)
        model_list.append((stem, file_path))
        # Name and path are separated so the log line stays readable.
        print('\033[34mFILE: ' + stem + ' -> ' + file_path + '\033[0m')
    return model_list

def process_string(input_string):
    """Split a two-part, slash-separated id into ``(second_part, full_id)``.

    Returns ``None`` when the string does not consist of exactly two parts
    separated by a single ``/``.
    """
    pieces = input_string.split('/')
    if len(pieces) != 2:
        return None
    return (pieces[1], input_string)

# Download targets, created relative to the current working directory.
directory_models = 'models'
os.makedirs(directory_models, exist_ok=True)
directory_loras = 'loras'
os.makedirs(directory_loras, exist_ok=True)
directory_vaes = 'vaes'
os.makedirs(directory_vaes, exist_ok=True)

#@markdown ---
#@markdown - **Download a Model**
download_model = "https://civitai.com/api/download/models/125771" # @param {type:"string"}
#@markdown - For SDXL models, only diffuser format models are supported, and you only need the repository name.
load_diffusers_format_model = 'gsdf/CounterfeitXL' # @param {type:"string"}
#@markdown - **Download a VAE**
download_vae = "https://huggingface.co/fp16-guy/anything_kl-f8-anime2_vae-ft-mse-840000-ema-pruned_blessed_clearvae_fp16_cleaned/resolve/main/anything_fp16.safetensors" # @param {type:"string"}
#@markdown - **Download a LoRA**
download_lora = "https://civitai.com/api/download/models/97655" # @param {type:"string"}
#@markdown ---
#@markdown **HF TOKEN** - If you need to download your private model from Hugging Face, input your token here.
hf_token = ""  # @param {type:"string"}
#@markdown
#@markdown ---

# One download per form field; the same token is passed to every call.
download_things(directory_models, download_model, hf_token)
download_things(directory_vaes, download_vae, hf_token)
download_things(directory_loras, download_lora, hf_token)


# TI (presumably textual-inversion embeddings) is more compatible in
# safetensors format; converting an embedding to safetensors may help.
directory_embeds = 'embedings'
os.makedirs(directory_embeds, exist_ok=True)
# Embedding files fetched on every run of this cell.
download_embeds = [
    'https://huggingface.co/datasets/Nerfgun3/bad_prompt/resolve/main/bad_prompt.pt',
    'https://huggingface.co/datasets/Nerfgun3/bad_prompt/blob/main/bad_prompt_version2.pt',
    'https://huggingface.co/sayakpaul/EasyNegative-test/blob/main/EasyNegative.safetensors',
    ]
for url_embed in download_embeds:
    download_things(directory_embeds, url_embed, hf_token)
embed_list = get_model_list(directory_embeds)

# Each list holds (display name, value) tuples built from the downloaded files.
model_list = get_model_list(directory_models)
# A diffusers-format id is added only when it contains exactly one '/'.
if load_diffusers_format_model.strip() != "" and load_diffusers_format_model.count('/') == 1:
    model_list.append(process_string(load_diffusers_format_model))
lora_model_list = get_model_list(directory_loras)
# The "no LoRA" entry carries the Python value None ...
lora_model_list.insert(0, ("None",None))
vae_model_list = get_model_list(directory_vaes)
# ... while the "no VAE" entry carries the string "None".
vae_model_list.insert(0, ("None","None"))



print('\033[33m🏁 Download finished.\033[0m')

### SECOND PART ###
import gc
import time
import numpy as np
import PIL.Image
from diffusers import (
    ControlNetModel,
    DiffusionPipeline,
    StableDiffusionControlNetPipeline,
    StableDiffusionControlNetInpaintPipeline,
    StableDiffusionPipeline,
    AutoencoderKL
)
import torch
import random
from controlnet_aux import (CannyDetector, ContentShuffleDetector, HEDdetector,
                            LineartAnimeDetector, LineartDetector,
                            MidasDetector, MLSDdetector, NormalBaeDetector,
                            OpenposeDetector, PidiNetDetector)
from transformers import pipeline
from controlnet_aux.util import HWC3, ade_palette
from transformers import AutoImageProcessor, UperNetForSemanticSegmentation
import cv2
from diffusers import (
    DPMSolverMultistepScheduler,
    DPMSolverSinglestepScheduler,
    KDPM2DiscreteScheduler,
    KDPM2AncestralDiscreteScheduler,
    EulerDiscreteScheduler,
    EulerAncestralDiscreteScheduler,
    HeunDiscreteScheduler,
    LMSDiscreteScheduler,
    DDIMScheduler,
    DEISMultistepScheduler,
    UniPCMultistepScheduler,
)

# =====================================
# Utils preprocessor
# =====================================
def resize_image(input_image, resolution, interpolation=None):
    """Resize an image so its longer side is about ``resolution`` pixels.

    Both sides are scaled by the same factor and then rounded to the nearest
    multiple of 64. A side is never allowed to round down to zero (which
    ``cv2.resize`` rejects); it is clamped to 64 instead, which matters for
    very elongated images.

    Args:
        input_image: Array whose first two axes are height and width; a
            trailing channel axis is optional.
        resolution: Target size for the longer side before rounding.
        interpolation: OpenCV interpolation flag. When ``None``,
            ``cv2.INTER_LANCZOS4`` is used for upscaling and
            ``cv2.INTER_AREA`` otherwise.

    Returns:
        The resized array as returned by ``cv2.resize``.
    """
    # Only the first two axes are needed, so 2-D (grayscale) arrays work too.
    height, width = input_image.shape[:2]
    scale = float(resolution) / max(height, width)

    target_h = max(64, int(np.round(height * scale / 64.0)) * 64)
    target_w = max(64, int(np.round(width * scale / 64.0)) * 64)

    if interpolation is None:
        interpolation = cv2.INTER_LANCZOS4 if scale > 1 else cv2.INTER_AREA
    # cv2.resize takes the size as (width, height).
    return cv2.resize(input_image, (target_w, target_h), interpolation=interpolation)

class DepthEstimator:
    """Depth preprocessor backed by the ``'depth-estimation'`` pipeline."""

    def __init__(self):
        # Only the task name is given; no model id is passed.
        self.model = pipeline('depth-estimation')

    def __call__(self, image: np.ndarray, **kwargs) -> PIL.Image.Image:
        """Return the depth prediction for ``image`` as a PIL image.

        Recognised keyword arguments (both default to 512):
            detect_resolution: size the input is resized to before inference.
            image_resolution: size the predicted depth map is resized to.
        Any other keyword argument is ignored.
        """
        detect_resolution = kwargs.pop('detect_resolution', 512)
        image_resolution = kwargs.pop('image_resolution', 512)

        source = resize_image(HWC3(np.array(image)), resolution=detect_resolution)
        prediction = self.model(PIL.Image.fromarray(source))

        depth_map = HWC3(np.array(prediction['depth']))
        depth_map = resize_image(depth_map, resolution=image_resolution)
        return PIL.Image.fromarray(depth_map)

class ImageSegmentor:
    """Semantic-segmentation preprocessor using ``openmmlab/upernet-convnext-small``."""

    def __init__(self):
        checkpoint = 'openmmlab/upernet-convnext-small'
        self.image_processor = AutoImageProcessor.from_pretrained(checkpoint)
        self.image_segmentor = UperNetForSemanticSegmentation.from_pretrained(checkpoint)

    @torch.inference_mode()
    def __call__(self, image: np.ndarray, **kwargs) -> PIL.Image.Image:
        """Return a colour-coded segmentation map of ``image`` as a PIL image.

        Recognised keyword arguments (both default to 512):
            detect_resolution: size the input is resized to before inference.
            image_resolution: size the coloured map is resized to.
        Any other keyword argument is ignored.
        """
        detect_resolution = kwargs.pop('detect_resolution', 512)
        image_resolution = kwargs.pop('image_resolution', 512)

        resized_input = resize_image(HWC3(image), resolution=detect_resolution)
        pil_input = PIL.Image.fromarray(resized_input)

        model_inputs = self.image_processor(pil_input, return_tensors='pt')
        outputs = self.image_segmentor(model_inputs.pixel_values)
        # PIL sizes are (width, height); the reversed tuple is what gets passed on.
        seg = self.image_processor.post_process_semantic_segmentation(
            outputs, target_sizes=[pil_input.size[::-1]])[0]

        # Paint every label index with the matching entry of ade_palette().
        rows, cols = seg.shape[0], seg.shape[1]
        color_seg = np.zeros((rows, cols, 3), dtype=np.uint8)
        for label, color in enumerate(ade_palette()):
            color_seg[seg == label, :] = color
        color_seg = color_seg.astype(np.uint8)

        # Nearest-neighbour resizing is requested explicitly for the label colours.
        color_seg = resize_image(color_seg,
                                 resolution=image_resolution,
                                 interpolation=cv2.INTER_NEAREST)
        return PIL.Image.fromarray(color_seg)

class Preprocessor:
    """Holds one annotator model at a time and applies it to images."""

    MODEL_ID = 'lllyasviel/Annotators'

    def __init__(self):
        self.model = None
        self.name = ''

    def load(self, name: str) -> None:
        """Make ``name`` the active annotator, loading it if it is not already.

        Raises:
            ValueError: ``name`` is not one of the supported annotators.
        """
        if name == self.name:
            return

        hub_id = self.MODEL_ID
        # Zero-argument factories: nothing is constructed until one is selected.
        factories = {
            'HED': lambda: HEDdetector.from_pretrained(hub_id),
            'Midas': lambda: MidasDetector.from_pretrained(hub_id),
            'MLSD': lambda: MLSDdetector.from_pretrained(hub_id),
            'Openpose': lambda: OpenposeDetector.from_pretrained(hub_id),
            'PidiNet': lambda: PidiNetDetector.from_pretrained(hub_id),
            'NormalBae': lambda: NormalBaeDetector.from_pretrained(hub_id),
            'Lineart': lambda: LineartDetector.from_pretrained(hub_id),
            'LineartAnime': lambda: LineartAnimeDetector.from_pretrained(hub_id),
            'Canny': lambda: CannyDetector(),
            'ContentShuffle': lambda: ContentShuffleDetector(),
            'DPT': lambda: DepthEstimator(),
            'UPerNet': lambda: ImageSegmentor(),
        }
        build = factories.get(name)
        if build is None:
            raise ValueError

        self.model = build()
        torch.cuda.empty_cache()
        gc.collect()
        self.name = name

    def __call__(self, image: PIL.Image.Image, **kwargs) -> PIL.Image.Image:
        """Run the active annotator on ``image``.

        'Canny' and 'Midas' resize the input here before calling the model and
        wrap the result with ``PIL.Image.fromarray``; every other annotator
        receives ``image`` and ``kwargs`` unchanged and its result is returned
        as-is.
        """
        if self.name == 'Canny':
            # The input is only converted/resized when a detect size was given.
            if 'detect_resolution' in kwargs:
                detect_resolution = kwargs.pop('detect_resolution')
                image = resize_image(HWC3(np.array(image)),
                                     resolution=detect_resolution)
            return PIL.Image.fromarray(self.model(image, **kwargs))

        if self.name == 'Midas':
            detect_resolution = kwargs.pop('detect_resolution', 512)
            image_resolution = kwargs.pop('image_resolution', 512)
            prepared = resize_image(HWC3(np.array(image)),
                                    resolution=detect_resolution)
            result = HWC3(self.model(prepared, **kwargs))
            result = resize_image(result, resolution=image_resolution)
            return PIL.Image.fromarray(result)

        return self.model(image, **kwargs)


# =====================================
# Base Model
# =====================================

# Task name -> model id handed to ControlNetModel.from_pretrained.
# 'txt2img' maps to the placeholder string 'NotControlnet'.
# NOTE(review): presumably not a real repository id — load_pipe looks it up
# but never loads it for the txt2img task.
CONTROLNET_MODEL_IDS = {
    'Openpose': 'lllyasviel/control_v11p_sd15_openpose',
    'Canny': 'lllyasviel/control_v11p_sd15_canny',
    'MLSD': 'lllyasviel/control_v11p_sd15_mlsd',
    'scribble': 'lllyasviel/control_v11p_sd15_scribble',
    'softedge': 'lllyasviel/control_v11p_sd15_softedge',
    'segmentation': 'lllyasviel/control_v11p_sd15_seg',
    'depth': 'lllyasviel/control_v11f1p_sd15_depth',
    'NormalBae': 'lllyasviel/control_v11p_sd15_normalbae',
    'lineart': 'lllyasviel/control_v11p_sd15_lineart',
    'lineart_anime': 'lllyasviel/control_v11p_sd15s2_lineart_anime',
    'shuffle': 'lllyasviel/control_v11e_sd15_shuffle',
    'ip2p': 'lllyasviel/control_v11e_sd15_ip2p',
    'Inpaint': 'lllyasviel/control_v11p_sd15_inpaint',
    'txt2img': 'NotControlnet',
}

def download_all_controlnet_weights() -> None:
    """Load every ControlNet listed in ``CONTROLNET_MODEL_IDS`` once.

    The ``'NotControlnet'`` placeholder (used for the txt2img task) is skipped
    because it is not something ``ControlNetModel.from_pretrained`` can load.
    The loaded models are discarded; only the calls themselves matter.
    """
    for model_id in CONTROLNET_MODEL_IDS.values():
        if model_id == 'NotControlnet':
            continue
        ControlNetModel.from_pretrained(model_id)

class Model:
    def __init__(self,
                 base_model_id: str = 'runwayml/stable-diffusion-v1-5',
                 task_name: str = 'Canny', vae_model=None, type_model_precision = torch.float16):
        self.device = torch.device(
            'cuda:0' if torch.cuda.is_available() else 'cpu')
        self.base_model_id = ''
        self.task_name = ''
        self.vae_model = None
        self.type_model_precision = type_model_precision if torch.cuda.is_available() else torch.float32 # For SD 1.5

        self.pipe = self.load_pipe(base_model_id, task_name, vae_model, type_model_precision)
        self.preprocessor = Preprocessor()

    def load_pipe(self, base_model_id: str, task_name, vae_model=None, type_model_precision = torch.float16, reload=False) -> DiffusionPipeline:
        if base_model_id == self.base_model_id and task_name == self.task_name and hasattr(
                self, 'pipe') and self.vae_model==vae_model and self.pipe is not None and reload==False and type_model_precision == self.type_model_precision:
            print('Previous loaded')
            return self.pipe
        if base_model_id == self.base_model_id and task_name == self.task_name and hasattr(
                self, 'pipe') and self.vae_model==vae_model and self.pipe is not None and reload==False and self.device == "cpu":
            print('Pipe in CPU')
            return self.pipe


        self.type_model_precision = type_model_precision if torch.cuda.is_available() else torch.float32
        self.pipe = None
        torch.cuda.empty_cache()
        gc.collect()

        model_id = CONTROLNET_MODEL_IDS[task_name]

        if task_name == 'txt2img':
            if os.path.exists(base_model_id):
                if self.type_model_precision == torch.float32:
                    print("Working with full precision")
                pipe = StableDiffusionPipeline.from_single_file(
                  base_model_id,
                  vae = None if vae_model == 'None' else AutoencoderKL.from_single_file(vae_model), # , torch_dtype=self.type_model_precision
                  torch_dtype=self.type_model_precision,
                )
                pipe.safety_checker = None
            else:
                print("Default VAE: madebyollin/sdxl-vae-fp16-fix")
                pipe = DiffusionPipeline.from_pretrained(
                    base_model_id,
                    vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16),
                    torch_dtype=torch.float16,
                    use_safetensors=True,
                    variant="fp16",
                    )
                pipe.safety_checker = None
            print('Loaded txt2img pipeline')
        elif task_name == 'Inpaint':
            if self.type_model_precision == torch.float32:
                print("Working with full precision")
            controlnet = ControlNetModel.from_pretrained(model_id,
                                                        torch_dtype=self.type_model_precision)
            if os.path.exists(base_model_id):
                pipe = StableDiffusionControlNetInpaintPipeline.from_single_file(
                    base_model_id,
                    vae = None if vae_model == 'None' else AutoencoderKL.from_single_file(vae_model),
                    safety_checker=None,
                    controlnet=controlnet,
                    torch_dtype=self.type_model_precision)
            print('Loaded ControlNet Inpaint pipeline')
        else:
            if self.type_model_precision == torch.float32:
                print("Working with full precision")
            controlnet = ControlNetModel.from_pretrained(model_id,
                                                        torch_dtype=self.type_model_precision) # for all
            if os.path.exists(base_model_id):
                pipe = StableDiffusionControlNetPipeline.from_single_file(
                    base_model_id,
                    vae = None if vae_model == 'None' else AutoencoderKL.from_single_file(vae_model),
                    safety_checker=None,
                    controlnet=controlnet,
                    torch_dtype=self.type_model_precision)
            else:
                raise ZeroDivisionError("Not implemented")
            #     pipe = StableDiffusionControlNetPipeline.from_pretrained(
            #         base_model_id,
            #         vae = AutoencoderKL.from_pretrained(base_model_id, subfolder='vae') if vae_model == 'None' else AutoencoderKL.from_single_file(vae_model),
            #         safety_checker=None,
            #         controlnet=controlnet,
            #         torch_dtype=torch.float16)
            # print('Loaded ControlNet pipeline')

            pipe.scheduler = UniPCMultistepScheduler.from_config(
                pipe.scheduler.config)

        if self.device.type == 'cuda':
            pipe.enable_xformers_memory_efficient_attention()

        pipe.to(self.device)
        torch.cuda.empty_cache()
        gc.collect()
        self.pipe = pipe
        self.base_model_id = base_model_id
        self.task_name = task_name
        self.vae_model = vae_model

        self.lora_memory = [None, None, None, None, None] # no need __init__
        self.lora_scale_memory = [1.0, 1.0, 1.0, 1.0, 1.0]
        self.FreeU = False
        return pipe

    def set_base_model(self, base_model_id: str) -> str:
        if not base_model_id or base_model_id == self.base_model_id:
            return self.base_model_id
        del self.pipe
        torch.cuda.empty_cache()
        gc.collect()
        try:
            self.pipe = self.load_pipe(base_model_id, self.task_name, self.vae_model)
        except Exception:
            self.pipe = self.load_pipe(self.base_model_id, self.task_name, self.vae_model)
        return self.base_model_id

    def load_controlnet_weight(self, task_name: str) -> None:
        if task_name == self.task_name:
            return
        if self.pipe is not None and hasattr(self.pipe, 'controlnet'):
            del self.pipe.controlnet
        torch.cuda.empty_cache()
        gc.collect()
        model_id = CONTROLNET_MODEL_IDS[task_name]
        controlnet = ControlNetModel.from_pretrained(model_id,
                                                     torch_dtype=self.type_model_precision)
        controlnet.to(self.device)
        torch.cuda.empty_cache()
        gc.collect()
        self.pipe.controlnet = controlnet
        self.task_name = task_name

    def get_prompt(self, prompt: str, additional_prompt: str) -> str:
        if not prompt:
            prompt = additional_prompt
        else:
            prompt = f'{prompt}, {additional_prompt}'
        return prompt

    @torch.autocast('cuda')
    def run_pipe(
        self,
        prompt: str,
        negative_prompt: str,
        prompt_embeds,
        negative_prompt_embeds,
        control_image: PIL.Image.Image,
        num_images: int,
        num_steps: int,
        guidance_scale: float,
        clip_skip: int,
        generator,
        controlnet_conditioning_scale,
        control_guidance_start,
        control_guidance_end,
    ) -> list[PIL.Image.Image]:
        # Return PIL images
        #generator = torch.Generator().manual_seed(seed)
        return self.pipe(
            prompt=prompt,
            negative_prompt=negative_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            guidance_scale=guidance_scale,
            clip_skip = clip_skip,
            num_images_per_prompt=num_images,
            num_inference_steps=num_steps,
            generator=generator,
            controlnet_conditioning_scale = controlnet_conditioning_scale,
            control_guidance_start = control_guidance_start,
            control_guidance_end = control_guidance_end,
            image=control_image
            ).images

    @torch.autocast('cuda')
    def run_pipe_SD(
        self,
        prompt: str,
        negative_prompt: str,
        prompt_embeds,
        negative_prompt_embeds,
        num_images: int,
        num_steps: int,
        guidance_scale: float,
        clip_skip: int,
        height : int,
        width : int,
        generator,
    ) -> list[PIL.Image.Image]:
        # Return PIL images
        #generator = torch.Generator().manual_seed(seed)
        return self.pipe(
            prompt=prompt,
            negative_prompt=negative_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            guidance_scale=guidance_scale,
            clip_skip = clip_skip,
            num_images_per_prompt=num_images,
            num_inference_steps=num_steps,
            generator=generator,
            height = height,
            width = width,
            ).images


    # @torch.autocast('cuda')
    # def run_pipe_SDXL(
    #     self,
    #     prompt: str,
    #     negative_prompt: str,
    #     prompt_embeds,
    #     negative_prompt_embeds,
    #     num_images: int,
    #     num_steps: int,
    #     guidance_scale: float,
    #     clip_skip: int,
    #     height : int,
    #     width : int,
    #     generator,
    #     seddd,
    #     conditioning,
    #     pooled,
    # ) -> list[PIL.Image.Image]:
    #     # Return PIL images
    #     #generator = torch.Generator("cuda").manual_seed(seddd) # generator = torch.Generator("cuda").manual_seed(seed),
    #     return self.pipe(
    #         prompt = None,
    #         negative_prompt = None,
    #         prompt_embeds=conditioning[0:1],
    #         pooled_prompt_embeds=pooled[0:1],
    #         negative_prompt_embeds=conditioning[1:2],
    #         negative_pooled_prompt_embeds=pooled[1:2],
    #         height = height,
    #         width = width,
    #         num_inference_steps = num_steps,
    #         guidance_scale = guidance_scale,
    #         clip_skip = clip_skip,
    #         num_images_per_prompt = num_images,
    #         generator = generator,
    #         ).images

    @torch.autocast('cuda')
    def run_pipe_inpaint(
        self,
        prompt: str,
        negative_prompt: str,
        prompt_embeds,
        negative_prompt_embeds,
        control_image: PIL.Image.Image,
        num_images: int,
        num_steps: int,
        guidance_scale: float,
        clip_skip: int,
        strength: float,
        init_image,
        control_mask,
        controlnet_conditioning_scale,
        control_guidance_start,
        control_guidance_end,
        generator,
    ) -> list[PIL.Image.Image]:
        # Return PIL images
        #generator = torch.Generator().manual_seed(seed)
        return self.pipe(
            prompt= None,
            negative_prompt= None,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            eta=1.0,
            strength = strength,
            image = init_image, # original image
            mask_image = control_mask, # mask, values of 0 to 255
            control_image = control_image, # tensor control image
            num_images_per_prompt  = num_images,
            num_inference_steps = num_steps,
            guidance_scale = guidance_scale,
            clip_skip = clip_skip,
            generator = generator,
            controlnet_conditioning_scale = controlnet_conditioning_scale,
            control_guidance_start = control_guidance_start,
            control_guidance_end = control_guidance_end,
            ).images

    ### self.x_process return image_preprocessor###
    @torch.inference_mode()
    def process_canny(
        self,
        image: np.ndarray,
        image_resolution: int,
        preprocess_resolution: int,
        low_threshold: int,
        high_threshold: int,
    ) -> list[PIL.Image.Image]:
        if image is None:
            raise ValueError

        self.preprocessor.load('Canny')
        control_image = self.preprocessor(
            image=image,
            low_threshold=low_threshold,
            high_threshold=high_threshold,
            image_resolution=image_resolution,
            detect_resolution=preprocess_resolution,
        )

        self.load_controlnet_weight('Canny')

        return control_image

    @torch.inference_mode()
    def process_mlsd(
        self,
        image: np.ndarray,
        image_resolution: int,
        preprocess_resolution: int,
        value_threshold: float,
        distance_threshold: float,
    ) -> list[PIL.Image.Image]:
        if image is None:
            raise ValueError

        self.preprocessor.load('MLSD')
        control_image = self.preprocessor(
            image=image,
            image_resolution=image_resolution,
            detect_resolution=preprocess_resolution,
            thr_v=value_threshold,
            thr_d=distance_threshold,
        )
        self.load_controlnet_weight('MLSD')

        return control_image

    @torch.inference_mode()
    def process_scribble(
        self,
        image: np.ndarray,
        image_resolution: int,
        preprocess_resolution: int,
        preprocessor_name: str,
    ) -> list[PIL.Image.Image]:
        if image is None:
            raise ValueError

        if preprocessor_name == 'None':
            image = HWC3(image)
            image = resize_image(image, resolution=image_resolution)
            control_image = PIL.Image.fromarray(image)
        elif preprocessor_name == 'HED':
            self.preprocessor.load(preprocessor_name)
            control_image = self.preprocessor(
                image=image,
                image_resolution=image_resolution,
                detect_resolution=preprocess_resolution,
                scribble=False,
            )
        elif preprocessor_name == 'PidiNet':
            self.preprocessor.load(preprocessor_name)
            control_image = self.preprocessor(
                image=image,
                image_resolution=image_resolution,
                detect_resolution=preprocess_resolution,
                safe=False,
            )
        self.load_controlnet_weight('scribble')

        return control_image

    @torch.inference_mode()
    def process_scribble_interactive(
        self,
        image_and_mask: dict[str, np.ndarray],
        image_resolution: int,
    ) -> list[PIL.Image.Image]:
        if image_and_mask is None:
            raise ValueError

        image = image_and_mask['mask']
        image = HWC3(image)
        image = resize_image(image, resolution=image_resolution)
        control_image = PIL.Image.fromarray(image)

        self.load_controlnet_weight('scribble')

        return control_image

    @torch.inference_mode()
    def process_softedge(
        self,
        image: np.ndarray,
        image_resolution: int,
        preprocess_resolution: int,
        preprocessor_name: str,
    ) -> list[PIL.Image.Image]:
        if image is None:
            raise ValueError

        if preprocessor_name == 'None':
            image = HWC3(image)
            image = resize_image(image, resolution=image_resolution)
            control_image = PIL.Image.fromarray(image)
        elif preprocessor_name in ['HED', 'HED safe']:
            safe = 'safe' in preprocessor_name
            self.preprocessor.load('HED')
            control_image = self.preprocessor(
                image=image,
                image_resolution=image_resolution,
                detect_resolution=preprocess_resolution,
                scribble=safe,
            )
        elif preprocessor_name in ['PidiNet', 'PidiNet safe']:
            safe = 'safe' in preprocessor_name
            self.preprocessor.load('PidiNet')
            control_image = self.preprocessor(
                image=image,
                image_resolution=image_resolution,
                detect_resolution=preprocess_resolution,
                safe=safe,
            )
        else:
            raise ValueError
        self.load_controlnet_weight('softedge')

        return control_image

    @torch.inference_mode()
    def process_openpose(
        self,
        image: np.ndarray,
        image_resolution: int,
        preprocess_resolution: int,
        preprocessor_name: str,
    ) -> list[PIL.Image.Image]:
        if image is None:
            raise ValueError

        if preprocessor_name == 'None':
            image = HWC3(image)
            image = resize_image(image, resolution=image_resolution)
            control_image = PIL.Image.fromarray(image)
        else:
            self.preprocessor.load('Openpose')
            control_image = self.preprocessor(
                image=image,
                image_resolution=image_resolution,
                detect_resolution=preprocess_resolution,
                hand_and_face=True,
            )
        self.load_controlnet_weight('Openpose')

        return control_image

    @torch.inference_mode()
    def process_segmentation(
        self,
        image: np.ndarray,
        image_resolution: int,
        preprocess_resolution: int,
        preprocessor_name: str,
    ) -> list[PIL.Image.Image]:
        if image is None:
            raise ValueError

        if preprocessor_name == 'None':
            image = HWC3(image)
            image = resize_image(image, resolution=image_resolution)
            control_image = PIL.Image.fromarray(image)
        else:
            self.preprocessor.load(preprocessor_name)
            control_image = self.preprocessor(
                image=image,
                image_resolution=image_resolution,
                detect_resolution=preprocess_resolution,
            )
        self.load_controlnet_weight('segmentation')

        return control_image

    @torch.inference_mode()
    def process_depth(
        self,
        image: np.ndarray,
        image_resolution: int,
        preprocess_resolution: int,
        preprocessor_name: str,
    ) -> list[PIL.Image.Image]:
        if image is None:
            raise ValueError

        if preprocessor_name == 'None':
            image = HWC3(image)
            image = resize_image(image, resolution=image_resolution)
            control_image = PIL.Image.fromarray(image)
        else:
            self.preprocessor.load(preprocessor_name)
            control_image = self.preprocessor(
                image=image,
                image_resolution=image_resolution,
                detect_resolution=preprocess_resolution,
            )
        self.load_controlnet_weight('depth')

        return control_image

    @torch.inference_mode()
    def process_normal(
        self,
        image: np.ndarray,
        image_resolution: int,
        preprocess_resolution: int,
        preprocessor_name: str,
    ) -> list[PIL.Image.Image]:
        """Build the control image for the normal-map task and select 'NormalBae' weights.

        Args:
            image: Source image array; ``None`` is rejected.
            image_resolution: Resolution passed to ``resize_image`` or to the
                preprocessor as ``image_resolution``.
            preprocess_resolution: Passed to the preprocessor as
                ``detect_resolution``; ignored when ``preprocessor_name`` is 'None'.
            preprocessor_name: Only compared against 'None'. Any other value makes
                the method load the 'NormalBae' preprocessor, regardless of the
                name given.

        Returns:
            The control image. NOTE(review): the annotation says ``list`` but the
            'None' branch returns a single PIL image.

        Raises:
            ValueError: If ``image`` is None.
        """
        if image is None:
            raise ValueError

        use_input_directly = preprocessor_name == 'None'
        if not use_input_directly:
            # The preprocessor is always 'NormalBae' here, whatever name was passed.
            self.preprocessor.load('NormalBae')
            control_image = self.preprocessor(
                detect_resolution=preprocess_resolution,
                image_resolution=image_resolution,
                image=image,
            )
        else:
            resized = resize_image(HWC3(image), resolution=image_resolution)
            control_image = PIL.Image.fromarray(resized)

        self.load_controlnet_weight('NormalBae')
        return control_image

    @torch.inference_mode()
    def process_lineart(
        self,
        image: np.ndarray,
        image_resolution: int,
        preprocess_resolution: int,
        preprocessor_name: str,
    ) -> list[PIL.Image.Image]:
        """Build the control image for the lineart tasks and select matching weights.

        Supported ``preprocessor_name`` values:
            * 'None', 'None (anime)' -- use the resized input image as-is.
            * 'Lineart', 'Lineart coarse' -- run the 'Lineart' preprocessor
              (``coarse=True`` when the name contains 'coarse').
            * 'Lineart (anime)' -- run the 'LineartAnime' preprocessor.

        Names containing 'anime' select the 'lineart_anime' ControlNet weights;
        all others select 'lineart'.

        Args:
            image: Source image array; ``None`` is rejected.
            image_resolution: Resolution passed to ``resize_image`` or to the
                preprocessor as ``image_resolution``.
            preprocess_resolution: Passed to the preprocessor as ``detect_resolution``.
            preprocessor_name: One of the names listed above.

        Returns:
            The control image. NOTE(review): the annotation says ``list`` but the
            'None' branches return a single PIL image.

        Raises:
            ValueError: If ``image`` is None or ``preprocessor_name`` is not one
                of the supported names. (Previously an unknown name switched the
                ControlNet weights and then failed with UnboundLocalError.)
        """
        if image is None:
            raise ValueError

        if preprocessor_name in ['None', 'None (anime)']:
            image = HWC3(image)
            image = resize_image(image, resolution=image_resolution)
            control_image = PIL.Image.fromarray(image)
        elif preprocessor_name in ['Lineart', 'Lineart coarse']:
            coarse = 'coarse' in preprocessor_name
            self.preprocessor.load('Lineart')
            control_image = self.preprocessor(
                image=image,
                image_resolution=image_resolution,
                detect_resolution=preprocess_resolution,
                coarse=coarse,
            )
        elif preprocessor_name == 'Lineart (anime)':
            self.preprocessor.load('LineartAnime')
            control_image = self.preprocessor(
                image=image,
                image_resolution=image_resolution,
                detect_resolution=preprocess_resolution,
            )
        else:
            # Fail before touching the ControlNet weights so the object is left unchanged.
            raise ValueError(f"Unsupported lineart preprocessor: {preprocessor_name!r}")

        if 'anime' in preprocessor_name:
            self.load_controlnet_weight('lineart_anime')
        else:
            self.load_controlnet_weight('lineart')

        return control_image

    @torch.inference_mode()
    def process_shuffle(
        self,
        image: np.ndarray,
        image_resolution: int,
        preprocessor_name: str,
    ) -> list[PIL.Image.Image]:
        """Build the control image for the shuffle task and select the 'shuffle' weights.

        Args:
            image: Source image array; ``None`` is rejected.
            image_resolution: Resolution passed to ``resize_image`` or to the
                preprocessor as ``image_resolution``.
            preprocessor_name: 'None' to use the (resized) input as-is, otherwise
                the name handed to ``self.preprocessor.load``.

        Returns:
            The control image. NOTE(review): the annotation says ``list`` but the
            'None' branch returns a single PIL image.

        Raises:
            ValueError: If ``image`` is None.
        """
        if image is None:
            raise ValueError

        use_input_directly = preprocessor_name == 'None'
        if not use_input_directly:
            # Unlike the other tasks, no detect_resolution is passed here.
            self.preprocessor.load(preprocessor_name)
            control_image = self.preprocessor(
                image_resolution=image_resolution,
                image=image,
            )
        else:
            resized = resize_image(HWC3(image), resolution=image_resolution)
            control_image = PIL.Image.fromarray(resized)

        self.load_controlnet_weight('shuffle')
        return control_image

    @torch.inference_mode()
    def process_ip2p(
        self,
        image: np.ndarray,
        image_resolution: int,
    ) -> list[PIL.Image.Image]:
        """Prepare the control image for the ip2p task and select the 'ip2p' weights.

        No preprocessor is involved: the input goes through ``HWC3`` and
        ``resize_image`` and is wrapped as a PIL image.

        Args:
            image: Source image array; ``None`` is rejected.
            image_resolution: Resolution passed to ``resize_image``.

        Returns:
            The control image. NOTE(review): the annotation says ``list`` but a
            single PIL image is returned.

        Raises:
            ValueError: If ``image`` is None.
        """
        if image is None:
            raise ValueError

        resized = resize_image(HWC3(image), resolution=image_resolution)
        control_image = PIL.Image.fromarray(resized)

        self.load_controlnet_weight('ip2p')
        return control_image

    @torch.inference_mode()
    def process_inpaint(
        self,
        image: np.ndarray,
        image_resolution: int,
        preprocess_resolution: int,
        image_mask: np.ndarray,
    ) -> tuple:
        """Prepare the image, mask and ControlNet condition for inpainting.

        Both the image and the mask go through ``HWC3`` and ``resize_image`` with
        the same ``image_resolution`` and are wrapped as PIL images; the condition
        is then built with ``make_inpaint_condition`` and the 'Inpaint' ControlNet
        weights are selected.

        Args:
            image: Source image array.
            image_resolution: Resolution passed to ``resize_image`` for both the
                image and the mask.
            preprocess_resolution: Unused; kept so callers can pass it.
            image_mask: Mask image array (the previous ``str`` annotation was
                wrong: the value is passed to ``HWC3`` and ``resize_image``).

        Returns:
            A 3-tuple ``(init_image, control_mask, control_image)`` where the first
            two are PIL images and the third is whatever
            ``make_inpaint_condition`` returns.

        Raises:
            ValueError: If ``image`` or ``image_mask`` is None.
        """
        if image is None:
            raise ValueError("process_inpaint requires an image")
        if image_mask is None:
            raise ValueError("process_inpaint requires an image mask")

        image = HWC3(image)
        image = resize_image(image, resolution=image_resolution)
        init_image = PIL.Image.fromarray(image)

        image_mask = HWC3(image_mask)
        image_mask = resize_image(image_mask, resolution=image_resolution)
        control_mask = PIL.Image.fromarray(image_mask)

        control_image = make_inpaint_condition(init_image, control_mask)

        self.load_controlnet_weight('Inpaint')

        return init_image, control_mask, control_image

    def get_scheduler(self, name):
      #Get scheduler
      match name:

        case "DPM++ 2M":
          return DPMSolverMultistepScheduler.from_config(self.pipe.scheduler.config)

        case "DPM++ 2M Karras":
          return DPMSolverMultistepScheduler.from_config(self.pipe.scheduler.config, use_karras_sigmas=True)

        case "DPM++ 2M SDE":
          return DPMSolverMultistepScheduler.from_config(self.pipe.scheduler.config, algorithm_type="sde-dpmsolver++")

        case "DPM++ 2M SDE Karras":
          return DPMSolverMultistepScheduler.from_config(self.pipe.scheduler.config, use_karras_sigmas=True, algorithm_type="sde-dpmsolver++")

        case "DPM++ SDE":
          return DPMSolverSinglestepScheduler.from_config(self.pipe.scheduler.config, )

        case "DPM++ SDE Karras":
          return DPMSolverSinglestepScheduler.from_config(self.pipe.scheduler.config, use_karras_sigmas=True)

        case "DPM2":
          return KDPM2DiscreteScheduler.from_config(self.pipe.scheduler.config, )

        case "DPM2 Karras":
          return KDPM2DiscreteScheduler.from_config(self.pipe.scheduler.config, use_karras_sigmas=True)

        case "Euler":
          return EulerDiscreteScheduler.from_config(self.pipe.scheduler.config, )

        case "Euler a":
          return EulerAncestralDiscreteScheduler.from_config(self.pipe.scheduler.config, )

        case "Heun":
          return HeunDiscreteScheduler.from_config(self.pipe.scheduler.config, )

        case "LMS":
          return LMSDiscreteScheduler.from_config(self.pipe.scheduler.config, )

        case "LMS Karras":
          return LMSDiscreteScheduler.from_config(self.pipe.scheduler.config, use_karras_sigmas=True)

        case "DDIMScheduler":
          return DDIMScheduler.from_config(self.pipe.scheduler.config)

        case "DEISMultistepScheduler":
          return DEISMultistepScheduler.from_config(self.pipe.scheduler.config)

        case "UniPCMultistepScheduler":
          return UniPCMultistepScheduler.from_config(self.pipe.scheduler.config)

    def process_lora(self, select_lora, lora_weights_scale, unload=False):
        """Apply a LoRA to ``self.pipe`` or undo a previously applied one.

        Loading calls ``lora_mix_load`` with ``lora_weights_scale``; unloading
        calls it again with the negated scale (fast, but numerically unstable
        according to the original author's note). Failures are reported on
        stdout and never raised, so one bad LoRA does not abort the generation.

        Args:
            select_lora: LoRA identifier handed to ``lora_mix_load``; ``None``
                means "no LoRA in this slot" and is a no-op.
            lora_weights_scale: Scale to apply (negated when ``unload`` is True).
            unload: When True, undo the LoRA instead of applying it.

        Returns:
            ``self.pipe`` (replaced by the result of ``lora_mix_load`` on success).
        """
        if select_lora is None:
            return self.pipe

        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        try:
            scale = -lora_weights_scale if unload else lora_weights_scale
            self.pipe = lora_mix_load(self.pipe, select_lora, scale, device=device, dtype=self.type_model_precision)
            if not unload:
                print(select_lora)
        except Exception as error:
            # `Exception` (not a bare except) so Ctrl-C / SystemExit still propagate.
            if unload:
                print(f"WARNING: could not unload LoRA {select_lora}: {error}")
            else:
                print(f"ERROR: LoRA not compatible: {select_lora}")
        return self.pipe

    def __call__(
        self,
        prompt = None,
        negative_prompt = None,
        img_height = None,
        img_width = None,
        num_images = None,
        num_steps = None,
        guidance_scale = None,
        clip_skip = True,
        seed = None,
        image = None, # path, np.array, or PIL image
        preprocessor_name = None,
        preprocess_resolution = None,
        image_resolution = None,
        additional_prompt = "",
        image_mask = None,
        strength = None,
        low_threshold=None, # Canny
        high_threshold=None, # Canny
        value_threshold=None, # MLSD
        distance_threshold=None, # MLSD
        lora_A = None,
        lora_scale_A = 1.0,
        lora_B = None,
        lora_scale_B = 1.0,
        lora_C = None,
        lora_scale_C = 1.0,
        lora_D = None,
        lora_scale_D = 1.0,
        lora_E = None,
        lora_scale_E = 1.0,
        active_textual_inversion = False,
        textual_inversion = [], # List of tuples [(activation_token, path_embedding),...]
        convert_weights_prompt = False,
        sampler = None,
        xformers_memory_efficient_attention = True,
        gui_active = False,
        loop_generation = 1,
        controlnet_conditioning_scale = 1.0,
        control_guidance_start = 0.0,
        control_guidance_end  = 1.0,
        generator_in_cpu = False, # Initial noise not in CPU
        FreeU = False,
    ):
        """Generate images for the configured task and save them under ./images.

        The steps, in order: validate arguments, make sure a pipeline is loaded,
        synchronise the applied LoRAs with the requested ones, toggle FreeU
        (txt2img only), build prompt embeddings with Compel, install the sampler,
        build the control image (every task except txt2img), then run the
        pipeline ``loop_generation`` times and save every image with metadata.

        Model family is decided by ``os.path.exists(self.base_model_id)``: an
        existing local path takes the SD branch, anything else the SDXL branch
        (as labelled by the original "#sd"/"#sdxl" comments).

        Args:
            prompt, negative_prompt: Prompt texts.
            img_height, img_width: Output size; both must be divisible by 8.
            num_images, num_steps, guidance_scale: Forwarded to the pipeline.
            clip_skip: When truthy, Compel returns penultimate hidden states.
            seed: Seed for the noise generator; ``-1`` draws a random seed per loop.
            image: Control image as a file path, NumPy array or PIL image.
                Required for every task except txt2img.
            preprocessor_name, preprocess_resolution, image_resolution:
                Forwarded to the task's ``process_*`` method where applicable.
            additional_prompt: Currently unused.
            image_mask: Inpaint mask path. When missing (or not an existing file)
                and ``gui_active`` is False, the mask is drawn interactively.
            strength: Inpaint strength.
            low_threshold, high_threshold: Canny options.
            value_threshold, distance_threshold: MLSD options.
            lora_A..lora_E, lora_scale_A..lora_scale_E: Up to five LoRAs and
                their scales; ``None`` leaves a slot empty.
            active_textual_inversion, textual_inversion: Enable and list
                ``(activation_token, path_embedding)`` pairs (SD branch only).
            convert_weights_prompt: Convert "(word:1.2)"-style weights with
                ``prompt_weight_conversor`` before embedding.
            sampler: Name understood by ``get_scheduler``.
            xformers_memory_efficient_attention: Enable xformers attention for
                the generation when CUDA is available.
            gui_active: Changes how a missing image/mask is reported.
            loop_generation: Number of pipeline runs.
            controlnet_conditioning_scale, control_guidance_start,
            control_guidance_end: ControlNet options; start must be strictly less
                than end. The default for ``control_guidance_end`` is 1.0 (it used
                to be 0.0, which made every call with default arguments fail the
                start/end check).
            generator_in_cpu: Create the noise generator on CPU instead of CUDA.
            FreeU: Enable FreeU (txt2img only).

        Returns:
            ``(images, image_list)`` from the last loop iteration: the generated
            images (for ControlNet tasks preceded by the control image) and the
            paths they were saved to. Returns ``None`` when the sampler cannot be
            created, or when ``gui_active`` is set and ``image`` has an
            unsupported type. Returns two empty lists if ``loop_generation < 1``.

        Raises:
            ValueError: On a missing or unsupported image, invalid size, invalid
                control guidance range, or a missing inpaint mask in GUI mode.
        """
        # `is None` instead of `== None`: comparing a NumPy array with `==`
        # yields an array, whose truth value is ambiguous.
        if self.task_name != "txt2img" and image is None:
            raise ValueError("A control image is required for every task except txt2img")
        if img_height % 8 != 0 or img_width % 8 != 0:
            raise ValueError("Height and width must be divisible by 8")
        if control_guidance_start >= control_guidance_end:
            raise ValueError("Control guidance start (ControlNet Start Threshold) cannot be larger or equal to control guidance end (ControlNet Stop Threshold)")

        def _drop_pipeline():
            # Drop the pipeline reference and release cached memory after a fatal input error.
            self.pipe = None
            self.base_model_id = None
            torch.cuda.empty_cache()
            gc.collect()

        if self.pipe is None:
            self.load_pipe(self.base_model_id, task_name = self.task_name, vae_model = self.vae_model)

        # xformers attention is switched off here and set to its final state
        # again after the prompt embeddings have been built (further below).
        if xformers_memory_efficient_attention and torch.cuda.is_available():
            self.pipe.disable_xformers_memory_efficient_attention()
        self.pipe.to(self.device)

        # LoRA: only touch the weights when the requested set differs from what is applied.
        requested_loras = [lora_A, lora_B, lora_C, lora_D, lora_E]
        requested_scales = [lora_scale_A, lora_scale_B, lora_scale_C, lora_scale_D, lora_scale_E]
        if self.lora_memory == requested_loras and self.lora_scale_memory == requested_scales:
            for single_lora in self.lora_memory:
                if single_lora is not None:
                    print(f"LoRA in memory:{single_lora}")
        else:
            # Undo everything currently applied, then apply the new selection.
            for old_lora, old_scale in zip(self.lora_memory, self.lora_scale_memory):
                self.pipe = self.process_lora(old_lora, old_scale, unload=True)
            for new_lora, new_scale in zip(requested_loras, requested_scales):
                self.pipe = self.process_lora(new_lora, new_scale)

        self.lora_memory = requested_loras
        self.lora_scale_memory = requested_scales

        # A base model that exists as a local path takes the SD branch, anything else the SDXL branch.
        is_local_model = os.path.exists(self.base_model_id)

        # FreeU for txt2img
        if self.task_name == 'txt2img':
            if FreeU:
                print("FreeU active")
                if is_local_model:
                    #sd
                    self.pipe.enable_freeu(s1=0.9, s2=0.2, b1=1.2, b2=1.4)
                else:
                    #sdxl
                    self.pipe.enable_freeu(s1=0.6, s2=0.4, b1=1.1, b2=1.2)
                self.FreeU = True
            elif self.FreeU:
                self.pipe.disable_freeu()
                self.FreeU = False

        if is_local_model:
            # Prompt Optimizations for 1.5
            if active_textual_inversion:
                for name, directory_name in textual_inversion:
                    try:
                        self.pipe.load_textual_inversion(directory_name, token=name)
                    except ValueError:
                        # previous loaded ti
                        pass
                    except Exception:
                        print(f"Can't apply {name}")

            # Clip skip is implemented through Compel's returned embeddings type.
            compel_kwargs = {
                "tokenizer": self.pipe.tokenizer,
                "text_encoder": self.pipe.text_encoder,
                "truncate_long_prompts": False,
            }
            if clip_skip:
                compel_kwargs["returned_embeddings_type"] = ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NORMALIZED
            compel = Compel(**compel_kwargs)

            # Prompt weights for textual inversion
            prompt_ti = self.pipe.maybe_convert_prompt(prompt, self.pipe.tokenizer)
            negative_prompt_ti = self.pipe.maybe_convert_prompt(negative_prompt, self.pipe.tokenizer)

            # prompt syntax style a1...
            if convert_weights_prompt:
                prompt_ti = prompt_weight_conversor(prompt_ti)
                negative_prompt_ti = prompt_weight_conversor(negative_prompt_ti)

            # prompt embed chunks style a1...
            prompt_emb = merge_embeds(tokenize_line(prompt_ti, self.pipe.tokenizer), compel)
            negative_prompt_emb = merge_embeds(tokenize_line(negative_prompt_ti, self.pipe.tokenizer), compel)

            # fix error shape
            if prompt_emb.shape != negative_prompt_emb.shape:
                prompt_emb, negative_prompt_emb = compel.pad_conditioning_tensors_to_same_length([prompt_emb, negative_prompt_emb])

            del compel

        else:
            # Prompt Optimizations for SDXL
            if active_textual_inversion:
                print("SDXL textual inversion not available")

            compel_kwargs = {
                "tokenizer": [self.pipe.tokenizer, self.pipe.tokenizer_2],
                "text_encoder": [self.pipe.text_encoder, self.pipe.text_encoder_2],
                "requires_pooled": [False, True],
                "truncate_long_prompts": False,
            }
            if clip_skip:
                compel_kwargs["returned_embeddings_type"] = ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED
            compel = Compel(**compel_kwargs)

            # prompt syntax style a1...
            if convert_weights_prompt:
                prompt_ti = prompt_weight_conversor(prompt)
                negative_prompt_ti = prompt_weight_conversor(negative_prompt)
            else:
                prompt_ti = prompt
                negative_prompt_ti = negative_prompt

            # Positive and negative prompt are embedded together; row 0 / row 1 are used below.
            conditioning, pooled = compel([prompt_ti, negative_prompt_ti])
            prompt_emb = None
            negative_prompt_emb = None

            del compel

        if torch.cuda.is_available():
            if xformers_memory_efficient_attention:
                self.pipe.enable_xformers_memory_efficient_attention()
            else:
                self.pipe.disable_xformers_memory_efficient_attention()

        try:
            scheduler = self.get_scheduler(sampler)
            if scheduler is None:
                # Never install "no scheduler": treat an unknown sampler as an error.
                raise ValueError(f"Unknown sampler: {sampler!r}")
            self.pipe.scheduler = scheduler
        except Exception:
            print("Error in sampler, please try again")
            _drop_pipeline()
            return

        self.pipe.safety_checker = None

        # Get Control image
        if self.task_name != 'txt2img':
            if isinstance(image, str):
                # If the input is a string (file path), open it as an image
                image_pil = Image.open(image)
                numpy_array = np.array(image_pil, dtype=np.uint8)
            elif isinstance(image, Image.Image):
                # If the input is already a PIL Image, convert it to a NumPy array
                numpy_array = np.array(image, dtype=np.uint8)
            elif isinstance(image, np.ndarray):
                # If the input is a NumPy array, np.uint8
                numpy_array = image.astype(np.uint8)
            else:
                _drop_pipeline()
                if gui_active:
                    print("To use this function, you have to upload an image in the cell below first 👇")
                    return
                else:
                    raise ValueError("Unsupported image type or not control image found")

            # Extract the RGB channels
            try:
                array_rgb = numpy_array[:, :, :3]
            except Exception:
                print("Unsupported image type")
                _drop_pipeline()
                raise ValueError("Unsupported image type") # return

        # Get params preprocess
        preprocess_params_config = {}
        if self.task_name != 'txt2img' and self.task_name != 'Inpaint':
            preprocess_params_config["image"] = array_rgb
            preprocess_params_config["image_resolution"] = image_resolution

            if self.task_name != "ip2p":
                if self.task_name != "shuffle":
                    preprocess_params_config["preprocess_resolution"] = preprocess_resolution
                if self.task_name != "MLSD" and self.task_name != "Canny":
                    preprocess_params_config["preprocessor_name"] = preprocessor_name

        # RUN Preprocess
        if self.task_name == 'Inpaint':

            # Get mask for Inpaint
            # `image_mask is not None` guards os.path.exists(None), which raises TypeError.
            if image_mask is not None and (gui_active or os.path.exists(image_mask)):
                # Read image mask from gui
                mask_control_img = Image.open(image_mask)
                numpy_array_mask = np.array(mask_control_img, dtype=np.uint8)
                array_rgb_mask = numpy_array_mask[:, :, :3]
            elif not gui_active:
                # Convert control image for draw
                if not isinstance(image, str):
                    raise ValueError("Drawing an inpaint mask interactively requires the control image as a file path")
                name_without_extension = os.path.splitext(image.split('/')[-1])[0]
                with open(image, 'rb') as image_file:
                    image64 = base64.b64encode(image_file.read())
                image64 = image64.decode('utf-8')
                img = np.array(plt.imread(f'{image}')[:,:,:3])

                # Create mask interactive
                draw(image64, filename=f"./{name_without_extension}_draw.png", w=img.shape[1], h=img.shape[0], line_width=0.04*img.shape[1])

                # Create mask and save: pure red pixels (1, 0, 0) of the drawing become the mask.
                with_mask = np.array(plt.imread(f"./{name_without_extension}_draw.png")[:,:,:3])
                mask = (with_mask[:,:,0]==1)*(with_mask[:,:,1]==0)*(with_mask[:,:,2]==0)
                plt.imsave(f"./{name_without_extension}_mask.png",mask, cmap='gray')
                mask_control = f"./{name_without_extension}_mask.png"
                print(f'Mask saved: {mask_control}')

                # Read image mask
                mask_control_img = Image.open(mask_control)
                numpy_array_mask = np.array(mask_control_img, dtype=np.uint8)
                array_rgb_mask = numpy_array_mask[:, :, :3]
            else:
                # GUI mode without a mask.
                raise ValueError("No images found")

            init_image, control_mask, control_image = self.process_inpaint(
                image=array_rgb,
                image_resolution=image_resolution,
                preprocess_resolution=preprocess_resolution, # Not used
                image_mask=array_rgb_mask,
            )

        elif self.task_name == 'Openpose':
            print('Openpose')
            control_image = self.process_openpose(**preprocess_params_config)

        elif self.task_name == 'Canny':
            print('Canny')
            control_image = self.process_canny(
                **preprocess_params_config,
                low_threshold=low_threshold,
                high_threshold=high_threshold,
            )

        elif self.task_name == 'MLSD':
            print('MLSD')
            control_image = self.process_mlsd(
                **preprocess_params_config,
                value_threshold=value_threshold,
                distance_threshold=distance_threshold,
            )

        elif self.task_name == 'scribble':
            print('Scribble')
            control_image = self.process_scribble(**preprocess_params_config)

        elif self.task_name == 'softedge':
            print('Softedge')
            control_image = self.process_softedge(**preprocess_params_config)

        elif self.task_name == 'segmentation':
            print('Segmentation')
            control_image = self.process_segmentation(**preprocess_params_config)

        elif self.task_name == 'depth':
            print('Depth')
            control_image = self.process_depth(**preprocess_params_config)

        elif self.task_name == 'NormalBae':
            print('NormalBae')
            control_image = self.process_normal(**preprocess_params_config)

        elif 'lineart' in self.task_name:
            print('Lineart')
            control_image = self.process_lineart(**preprocess_params_config)

        elif self.task_name == 'shuffle':
            print('Shuffle')
            control_image = self.process_shuffle(**preprocess_params_config)

        elif self.task_name == 'ip2p':
            print('Ip2p')
            control_image = self.process_ip2p(**preprocess_params_config)

        # Get params for TASK
        pipe_params_config = {
            "prompt": None,
            "negative_prompt": None,
            "prompt_embeds": prompt_emb,
            "negative_prompt_embeds": negative_prompt_emb,
            "num_images": num_images,
            "num_steps": num_steps,
            "guidance_scale": guidance_scale,
            "clip_skip": None, #clip_skip, because we use clip skip of compel
        }

        # The SDXL branch passes its arguments directly in the loop below.
        if is_local_model:
            if self.task_name == 'txt2img':
                pipe_params_config["height"]                        = img_height
                pipe_params_config["width"]                         = img_width
            elif self.task_name == "Inpaint":
                pipe_params_config["strength"]                      = strength
                pipe_params_config["init_image"]                    = init_image
                pipe_params_config["control_mask"]                  = control_mask
                pipe_params_config["control_image"]                 = control_image
                pipe_params_config["controlnet_conditioning_scale"] = controlnet_conditioning_scale
                pipe_params_config["control_guidance_start"]        = control_guidance_start
                pipe_params_config["control_guidance_end"]          = control_guidance_end
                print(f"Image resolution: {str(init_image.size)}")
            else:
                pipe_params_config["control_image"]                 = control_image
                pipe_params_config["controlnet_conditioning_scale"] = controlnet_conditioning_scale
                pipe_params_config["control_guidance_start"]        = control_guidance_start
                pipe_params_config["control_guidance_end"]          = control_guidance_end
                print(f"Image resolution: {str(control_image.size)}")

        ### RUN PIPE ###
        # Pre-initialised so the return below is defined even if the loop does not run.
        images, image_list = [], []
        for _ in range(loop_generation):

            calculate_seed = random.randint(0, 2147483647) if seed == -1 else seed
            if generator_in_cpu:
                pipe_params_config["generator"] = torch.Generator().manual_seed(calculate_seed)
            else:
                try:
                    pipe_params_config["generator"] = torch.Generator("cuda").manual_seed(calculate_seed)
                except Exception:
                    print("Generator in CPU")
                    pipe_params_config["generator"] = torch.Generator().manual_seed(calculate_seed)

            if not is_local_model:
                # SDXL: row 0 of the Compel output is the prompt, row 1 the negative prompt.
                images = self.pipe(
                    prompt = None,
                    negative_prompt = None,
                    prompt_embeds=conditioning[0:1],
                    pooled_prompt_embeds=pooled[0:1],
                    negative_prompt_embeds=conditioning[1:2],
                    negative_pooled_prompt_embeds=pooled[1:2],
                    height = img_height,
                    width = img_width,
                    num_inference_steps = num_steps,
                    guidance_scale = guidance_scale,
                    clip_skip = None,
                    num_images_per_prompt = num_images,
                    generator = pipe_params_config["generator"],
                ).images
            elif self.task_name == 'txt2img':
                images = self.run_pipe_SD(**pipe_params_config)
            elif self.task_name == "Inpaint":
                images = self.run_pipe_inpaint(**pipe_params_config)
            else:
                results = self.run_pipe(**pipe_params_config) ## pipe ControlNet add condition_weights
                # The control image is shown/saved first, followed by the results.
                images = [control_image] + results
                del results

            torch.cuda.empty_cache()
            gc.collect()

            if loop_generation > 1:
                # Show intermediate batches right away when looping.
                mediapy.show_images(images)
                time.sleep(1)

            # save image
            image_list = []
            metadata = [
                prompt,
                negative_prompt,
                self.base_model_id,
                self.vae_model,
                num_steps,
                guidance_scale,
                sampler,
                calculate_seed,
            ]

            for image_ in images:
                image_path = save_pil_image_with_metadata(image_, "./images", metadata)
                image_list.append(image_path)

            if loop_generation > 1:
                torch.cuda.empty_cache()
                gc.collect()
                print(image_list)
            print(f"Seed:\n{calculate_seed}")

        torch.cuda.empty_cache()
        gc.collect()

        return images, image_list



# =====================================
# Prompt weights
# =====================================
from compel import Compel
from diffusers import StableDiffusionPipeline, DDIMScheduler
import re
from compel import EmbeddingsProvider, ReturnedEmbeddingsType

def concat_tensor(t):
    """Lay the size-1 slices of ``t`` along dim 0 side by side along dim 1.

    A tensor of shape ``(n, m, ...)`` becomes ``(1, n * m, ...)``.
    """
    slices = t.split(1, dim=0)
    return torch.cat(slices, dim=1)

def merge_embeds(prompt_chanks, compel):
    """Embed every prompt chunk and blend the embeddings into a single one.

    The chunks are weighted linearly by position: with ``n`` chunks the first
    gets weight ``n / T`` and the last ``1 / T``, where ``T = n * (n + 1) / 2``,
    so the weights sum to 1.

    Args:
        prompt_chanks: List of prompt chunk strings (must not be empty).
        compel: Callable that maps the list of chunks to a tensor with one row
            per chunk along dim 0.

    Returns:
        The weighted sum of the per-chunk embeddings (leading dim of size 1).
    """
    total = len(prompt_chanks)
    unit_weight = 1 / (total * (total + 1) // 2)
    pieces = list(compel(prompt_chanks).split(1, dim=0))
    # rank 1 is the last chunk, rank `total` the first one.
    for rank in range(1, total + 1):
        pieces[-rank] = pieces[-rank] * (rank * unit_weight)
    return torch.stack(pieces, dim=0).sum(dim=0)

def detokenize(chunk, actual_prompt):
    """Rebuild the text of a token chunk and strip it from the remaining prompt.

    Tokens carry a ``</w>`` end-of-word marker. Each marker becomes a space when
    ``actual_prompt`` has a space at that same index, and is dropped otherwise.
    Note that the last element of ``chunk`` is modified in place (its marker is
    removed).

    Args:
        chunk: List of token strings (non-empty).
        actual_prompt: The not-yet-consumed prompt text the chunk was taken from.

    Returns:
        ``(chunk_text, remaining_prompt)``, both stripped. Every occurrence of
        the chunk text is removed from ``actual_prompt``.
    """
    marker = '</w>'
    chunk[-1] = chunk[-1].replace(marker, '')
    text = ''.join(chunk).strip()

    position = text.find(marker)
    while position != -1:
        # Look at the prompt to decide whether this word boundary was a space.
        filler = ' ' if actual_prompt[position] == ' ' else ''
        text = text.replace(marker, filler, 1)
        position = text.find(marker)

    remaining = actual_prompt.replace(text, '')
    return text.strip(), remaining.strip()

def tokenize_line(line, tokenizer):
    """Split a prompt into text chunks that each fit the tokenizer's window.

    The prompt is lower-cased and stripped (an empty prompt is replaced by
    'worst quality'), tokenized, and cut every ``model_max_length - 2`` tokens.
    A full window is cut after its last comma token when it has one (the
    tokens after that comma carry over to the next window); otherwise the whole
    window becomes one chunk.

    Args:
        line: Prompt text.
        tokenizer: Object providing ``tokenize(text)`` and ``model_max_length``.

    Returns:
        List of chunk strings, in prompt order.
    """
    remaining = line.lower().strip()
    if remaining == "":
        remaining = 'worst quality'
    tokens = tokenizer.tokenize(remaining)
    window = tokenizer.model_max_length - 2
    comma = tokenizer.tokenize(',')[0]

    pieces = []
    pending = []
    for token in tokens:
        pending.append(token)
        if len(pending) != window:
            continue

        # Window is full: by default flush all of it...
        cut = window
        if pending[-1] != comma:
            # ...unless an earlier comma offers a cleaner place to cut.
            for index in range(window - 1, -1, -1):
                if pending[index] == comma:
                    cut = index + 1
                    break

        head, pending = pending[:cut], pending[cut:]
        text, remaining = detokenize(head, remaining)
        pieces.append(text)

    if pending:
        # Left-over tokens form the final, shorter chunk.
        text, _ = detokenize(pending, remaining)
        pieces.append(text)

    return pieces

def prompt_weight_conversor(input_string):
    """Convert A1111-style prompt weights into Compel-style syntax.

    The rewrite rules are applied in order:
      1. ``(text:1.2)`` -> ``(text)1.2``
      2. ``[text]``     -> ``(text)0.909090909``
      3. ``[text:n]``   -> ``(text)0.9`` (the number itself is discarded)
      4. ``(text)`` with no weight after it -> ``(text)+``
    """
    rewrite_rules = (
        (r'\(([^:]+):([\d.]+)\)', r'(\1)\2'),
        (r'\[([^:\]]+)\]', r'(\1)0.909090909'),
        (r'\[([^:]+):[\d.]+\]', r'(\1)0.9'),
        (r'\(([^)]+)\)(?![\d.])', r'(\1)+'),
    )

    result = input_string
    for pattern, replacement in rewrite_rules:
        result = re.sub(pattern, replacement, result)
    return result

# =====================================
# IMAGE: METADATA AND SAVE
# =====================================
import os
from PIL import Image
from PIL.PngImagePlugin import PngInfo

def save_pil_image_with_metadata(image, folder_path, metadata_list):
    """Save `image` as the next free ``imageNNN.png`` in `folder_path`.

    Args:
        image: object with a PIL-style ``save(path, pnginfo=...)`` method.
        folder_path: destination directory; created when missing.
        metadata_list: sequence of values written as PNG text chunks, in the
            order Prompt, Negative prompt, Model, VAE, Steps, CFG, Scheduler,
            Seed. A shorter sequence writes only the leading fields. ``None``
            saves the image without metadata.

    Returns:
        The path of the written file.
    """
    os.makedirs(folder_path, exist_ok=True)

    # Start at "number of files + 1" as before, but never reuse a name that is
    # already taken: after deletions the count can point at an existing file,
    # which would silently be overwritten.
    index = len(os.listdir(folder_path)) + 1
    image_path = os.path.join(folder_path, f"image{str(index).zfill(3)}.png")
    while os.path.exists(image_path):
        index += 1
        image_path = os.path.join(folder_path, f"image{str(index).zfill(3)}.png")

    if metadata_list is not None:
        labels = ("Prompt", "Negative prompt", "Model", "VAE", "Steps", "CFG", "Scheduler", "Seed")
        try:
            metadata = PngInfo()
            for label, value in zip(labels, metadata_list):
                metadata.add_text(label, str(value))
            image.save(image_path, pnginfo=metadata)
            return image_path
        except Exception as exc:
            # Best effort: metadata problems must not prevent saving the image.
            print(f'Could not write image metadata: {exc}')

    print('Saving image without metadata')
    image.save(image_path)

    return image_path

# =====================================
# LoRA Loaders
# =====================================
import torch
from safetensors.torch import load_file
from collections import defaultdict
def _resolve_lora_target(root, layer_infos, layer_name):
    """Find the submodule of `root` addressed by underscore-separated name pieces.

    Module names may themselves contain underscores, so pieces are joined with
    "_" until a lookup succeeds. Raises IndexError (the failure the original
    loop produced) when the pieces are exhausted without a match.
    """
    pieces = list(layer_infos)
    curr_layer = root
    temp_name = pieces.pop(0)
    while True:
        try:
            # Deliberately nn.Module.__getattr__ rather than getattr(): plain
            # getattr would also resolve ordinary attributes such as the `.to`
            # method and break names like "to_q".
            curr_layer = curr_layer.__getattr__(temp_name)
        except Exception:
            if not pieces:
                raise IndexError(f"Could not resolve LoRA layer '{layer_name}' in the model") from None
            if len(temp_name) > 0:
                temp_name += "_" + pieces.pop(0)
            else:
                temp_name = pieces.pop(0)
            continue
        if not pieces:
            return curr_layer
        temp_name = pieces.pop(0)


def load_lora_weights(pipeline, checkpoint_path, multiplier, device, dtype):
    """Merge LoRA weights from .safetensors file(s) directly into `pipeline`.

    Args:
        pipeline: object exposing `text_encoder` and `unet` modules.
        checkpoint_path: a single path (str) or an iterable of paths.
        multiplier: strength applied to every LoRA delta.
        device: device passed to safetensors' `load_file`.
        dtype: dtype the LoRA up/down matrices are converted to.

    Returns:
        The same `pipeline`, with its weights updated in place.

    Raises:
        KeyError: a layer lacks 'lora_up.weight' or 'lora_down.weight'.
        IndexError: a layer name cannot be resolved inside the model.

    Every layer of a checkpoint is resolved and validated before any weight of
    that checkpoint is modified, so the errors above leave the model untouched
    by the failing file. A missing 'alpha' entry is treated as scale 1.0.
    """
    LORA_PREFIX_UNET = "lora_unet"
    LORA_PREFIX_TEXT_ENCODER = "lora_te"

    if isinstance(checkpoint_path, str):
        checkpoint_paths = [checkpoint_path]
    else:
        checkpoint_paths = list(checkpoint_path)

    for ckptpath in checkpoint_paths:
        # load LoRA weight from .safetensors
        state_dict = load_file(ckptpath, device=device)

        # Group tensors per layer; keys look like
        # "lora_te_text_model_encoder_layers_0_self_attn_k_proj.lora_down.weight"
        updates = defaultdict(dict)
        for key, value in state_dict.items():
            layer, elem = key.split('.', 1)
            updates[layer][elem] = value

        # Phase 1: resolve targets and collect tensors without touching the model.
        pending = []
        for layer, elems in updates.items():
            if "text" in layer:
                layer_infos = layer.split(LORA_PREFIX_TEXT_ENCODER + "_")[-1].split("_")
                root = pipeline.text_encoder
            else:
                layer_infos = layer.split(LORA_PREFIX_UNET + "_")[-1].split("_")
                root = pipeline.unet

            target = _resolve_lora_target(root, layer_infos, layer)

            weight_up = elems['lora_up.weight'].to(dtype)
            weight_down = elems['lora_down.weight'].to(dtype)
            alpha = elems.get('alpha')
            if alpha is not None and alpha:
                scale = alpha.item() / weight_up.shape[1]
            else:
                scale = 1.0
            pending.append((target, weight_up, weight_down, scale))

        # Phase 2: directly update weights in the diffusers model.
        for target, weight_up, weight_down, scale in pending:
            if len(weight_up.shape) == 4:
                # Conv weights with trailing 1x1 dims: multiply as matrices, then restore dims.
                delta = torch.mm(weight_up.squeeze(3).squeeze(2), weight_down.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
            else:
                delta = torch.mm(weight_up, weight_down)
            target.weight.data += multiplier * scale * delta

    return pipeline

def lora_mix_load(pipe, lora_path, alpha_scale=1.0, device='cuda', dtype=torch.float16):
    """Apply one LoRA file to `pipe`, with a fallback to the pipeline's own loader.

    First tries the manual weight merge (`load_lora_weights`); if that raises,
    falls back to `pipe.load_lora_weights` followed by `pipe.fuse_lora`.

    Args:
        pipe: the pipeline to modify.
        lora_path: path of the LoRA file.
        alpha_scale: LoRA strength used by either code path.
        device: device used when reading the file for the manual merge.
        dtype: dtype for the manual merge (previously ignored in favour of a
            hard-coded float16; the default keeps the old behaviour).

    Returns:
        The pipeline with the LoRA applied.
    """
    try:
        pipe = load_lora_weights(pipe, [lora_path], alpha_scale, device=device, dtype=dtype)
    except Exception:
        # `Exception` rather than a bare except, so Ctrl-C / SystemExit are not
        # swallowed and turned into a second loading attempt.
        pipe.load_lora_weights(lora_path)
        pipe.fuse_lora(lora_scale=alpha_scale)

    return pipe


# =====================================
# Inpainting canvas
# =====================================
# HTML/JavaScript template for the inpainting mask canvas. `draw()` fills the
# %-placeholders in this order: width, height, width, height, image format,
# base64 image data, stroke line width.
# The page lets the user paint red strokes with the mouse; the "Finish" button
# resolves the JS promise `data` with the canvas content as a PNG data URL.
# NOTE(review): `ctx1` is obtained from `canvas`, not from the `canvas1` element,
# so the source image is drawn onto the same canvas the user paints on --
# confirm this is intended before changing it.
# NOTE(review): `<canvas1>` is not a standard HTML element -- presumably a
# leftover of a two-layer design; verify.
canvas_html = """
<style>
.button {
  background-color: #4CAF50;
  border: none;
  color: white;
  padding: 15px 32px;
  text-align: center;
  text-decoration: none;
  display: inline-block;
  font-size: 16px;
  margin: 4px 2px;
  cursor: pointer;
}
</style>
<canvas1 width=%d height=%d>
</canvas1>
<canvas width=%d height=%d>
</canvas>

<button>Finish</button>
<script>
var canvas = document.querySelector('canvas')
var ctx = canvas.getContext('2d')

var canvas1 = document.querySelector('canvas1')
var ctx1 = canvas.getContext('2d')


ctx.strokeStyle = 'red';

var img = new Image();
img.src = "data:image/%s;charset=utf-8;base64,%s";
console.log(img)
img.onload = function() {
  ctx1.drawImage(img, 0, 0);
};
img.crossOrigin = 'Anonymous';

ctx.clearRect(0, 0, canvas.width, canvas.height);

ctx.lineWidth = %d
var button = document.querySelector('button')
var mouse = {x: 0, y: 0}

canvas.addEventListener('mousemove', function(e) {
  mouse.x = e.pageX - this.offsetLeft
  mouse.y = e.pageY - this.offsetTop
})
canvas.onmousedown = ()=>{
  ctx.beginPath()
  ctx.moveTo(mouse.x, mouse.y)
  canvas.addEventListener('mousemove', onPaint)
}
canvas.onmouseup = ()=>{
  canvas.removeEventListener('mousemove', onPaint)
}
var onPaint = ()=>{
  ctx.lineTo(mouse.x, mouse.y)
  ctx.stroke()
}

var data = new Promise(resolve=>{
  button.onclick = ()=>{
    resolve(canvas.toDataURL('image/png'))
  }
})
</script>
"""

import base64, os
from google.colab.output import eval_js
from base64 import b64decode
import matplotlib.pyplot as plt
import numpy as np
from shutil import copyfile
import shutil
import matplotlib.pyplot as plt

def draw(imgm, filename='drawing.png', w=400, h=200, line_width=1):
  """Show the drawing canvas over an image and save the result to `filename`.

  `imgm` is inserted into the page as base64 image data; the image format is
  taken from the extension of `filename`. The call waits for the JS promise
  `data` (resolved by the "Finish" button) and writes the decoded PNG bytes.
  """
  image_format = filename.split('.')[-1]
  page = canvas_html % (w, h, w, h, image_format, imgm, line_width)
  display(HTML(page))

  data_url = eval_js("data")
  # The data URL has the form "<header>,<base64 payload>".
  payload = data_url.split(',')[1]
  with open(filename, 'wb') as out_file:
    out_file.write(b64decode(payload))

# the control image of init_image and mask_image
def make_inpaint_condition(image, image_mask):
    """Build the control tensor for inpainting from an image and its mask.

    Args:
        image: PIL-style image; converted to RGB and scaled to [0, 1].
        image_mask: PIL-style image; converted to "L" and scaled to [0, 1].

    Returns:
        torch tensor of shape (1, 3, height, width) in which every pixel whose
        mask value exceeds 0.5 is set to -1.0.

    Raises:
        AssertionError: image and mask differ in height or width.
    """
    image = np.array(image.convert("RGB")).astype(np.float32) / 255.0
    image_mask = np.array(image_mask.convert("L")).astype(np.float32) / 255.0

    # Compare height AND width; slicing [0:1] only ever checked the height.
    assert image.shape[0:2] == image_mask.shape[0:2], "image and image_mask must have the same image size"
    image[image_mask > 0.5] = -1.0  # set as masked pixel
    image = np.expand_dims(image, 0).transpose(0, 3, 1, 2)  # HWC -> NCHW
    image = torch.from_numpy(image)
    return image

# =====================================
# Adetailer
# =====================================
from functools import partial
from diffusers import DPMSolverMultistepScheduler, DPMSolverSinglestepScheduler
from asdff import AdCnPipeline, AdPipeline, yolo_detector
from huggingface_hub import hf_hub_download

def ad_model_process(
    face_detector_ad,
    person_detector_ad,
    hand_detector_ad,
    model_repo_id,
    common,
    inpaint,
    image_list_task,
    ):
    """Run the ADetailer pipeline over saved images and store the results.

    Args:
        face_detector_ad, person_detector_ad, hand_detector_ad: flags selecting
            which YOLO detectors (from the "Bingsu/adetailer" repo) are used.
        model_repo_id: model passed to `AdPipeline.from_pretrained`.
        common: forwarded to the pipeline as `common`.
        inpaint: forwarded to the pipeline as `inpaint_only`.
        image_list_task: paths of the images to process.

    Returns:
        Paths of the newly saved images. Paths that do not exist are skipped.
        If a pipeline result cannot be displayed, `image_list_task` is returned
        unchanged.

    The pipeline is always released and the CUDA cache emptied on exit.
    """
    pipe = AdPipeline.from_pretrained(model_repo_id, torch_dtype=torch.float16)

    try:
        try:
            pipe.scheduler = DPMSolverSinglestepScheduler.from_config(pipe.scheduler.config, use_karras_sigmas=True)
        except Exception:
            pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)

        pipe.safety_checker = None
        pipe.to("cuda")

        # The detector set does not depend on the image, so build it once.
        requested = (
            (face_detector_ad, "face_yolov8n.pt"),
            (person_detector_ad, "person_yolov8s-seg.pt"),
            (hand_detector_ad, "hand_yolov8n.pt"),
        )
        detectors = [
            partial(yolo_detector, model_path=hf_hub_download("Bingsu/adetailer", weights_file))
            for enabled, weights_file in requested
            if enabled
        ]

        image_list_ad = []

        for path in image_list_task:
            if not os.path.exists(path):
                # Previously a missing file re-used the previous image (or hit
                # an undefined variable on the first iteration).
                print(f"Skipping missing image: {path}")
                continue

            # Open the image using PIL and convert it to PIL.Image.Image
            with Image.open(path) as img:
                images_ad = img.convert("RGB")

            result_ad = pipe(images=[images_ad], common=common, inpaint_only=inpaint, detectors=detectors)

            try:
                mediapy.show_images([result_ad[0][0], result_ad[1][0]])
            except Exception:
                # NOTE(review): presumably the result holds no usable images
                # here; the original input list is handed back -- confirm.
                return image_list_task

            image_path = save_pil_image_with_metadata(result_ad[0][0], f'{os.getcwd()}/images', metadata_list=None)
            image_list_ad.append(image_path)

        return image_list_ad
    finally:
        del pipe
        torch.cuda.empty_cache()
        gc.collect()






# Experimental upscaler
# =====================================
# BASE UPSCALER
# =====================================
devicee = 'cuda' # run in CPU is very slow

from abc import abstractmethod
import PIL
from PIL import Image

# Resampling filters: use the `Image.Resampling` enum when this Pillow build
# has it, otherwise fall back to the constants on the `Image` module.
LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS)
NEAREST = (Image.Resampling.NEAREST if hasattr(Image, 'Resampling') else Image.NEAREST)



class Upscaler:
    """Base class for image upscalers.

    Subclasses implement `do_upscale` (one upscaling pass) and `load_model`;
    `upscale` repeats passes until the requested size is reached.
    """
    name = None
    model_path = None
    model_name = None
    model_url = None
    enable = True
    filter = None
    model = None
    user_path = None
    scalers: list  # set by subclasses / instances; no class-level default
    tile = True
    can_tile = False  # becomes True on instances when cv2 can be imported

    def __init__(self, create_dirs=False):
        """Initialise default settings; create `model_path` when `create_dirs` is true."""
        self.mod_pad_h = None
        self.tile_size = 100
        self.tile_pad = 10
        self.device = devicee # ⭐
        self.img = None
        self.output = None
        self.scale = 4.0
        self.half = True
        self.pre_pad = 0
        self.mod_scale = None
        self.model_download_path = None

        if self.model_path is None and self.name:
            self.model_path = os.path.join('/content/Real-ESRGAN/weights/RealESRGAN_x2plus.pth', self.name)
        if self.model_path and create_dirs:
            os.makedirs(self.model_path, exist_ok=True)

        try:
            import cv2  # noqa: F401
            self.can_tile = True
        except Exception:
            # cv2 missing: keep the class default (False) so reading
            # `can_tile` never raises AttributeError.
            self.can_tile = False

    @abstractmethod
    def do_upscale(self, img: "PIL.Image.Image", selected_model: str):
        """Perform one upscaling pass; the base implementation returns `img` unchanged."""
        return img

    def upscale(self, img: "PIL.Image.Image", scale, selected_model: str = None):
        """Upscale `img` by `scale`, with the target size rounded down to multiples of 8.

        `do_upscale` is called up to three times, stopping early when a pass no
        longer changes the size or the target size is reached. The result is
        then resized with LANCZOS if it does not match the target exactly.
        """
        self.scale = scale
        dest_w = int((img.width * scale) // 8 * 8)
        dest_h = int((img.height * scale) // 8 * 8)

        for _ in range(3):
            shape = (img.width, img.height)

            img = self.do_upscale(img, selected_model)

            if shape == (img.width, img.height):
                break

            if img.width >= dest_w and img.height >= dest_h:
                break

        if img.width != dest_w or img.height != dest_h:
            img = img.resize((int(dest_w), int(dest_h)), resample=LANCZOS)

        return img

    @abstractmethod
    def load_model(self, path: str):
        """Load the model stored at `path`; implemented by subclasses."""
        pass

    def find_models(self, ext_filter=None) -> list:
        """Return model locations found by `load_models` for this upscaler's paths and URL."""
        return load_models(model_path=self.model_path, model_url=self.model_url, command_path=self.user_path, ext_filter=ext_filter)

    def update_status(self, prompt):
        """Print a status line to `shared.progress_print_out`.

        NOTE(review): `shared` is not defined in the visible part of this file;
        confirm it exists before calling this method.
        """
        print(f"\nextras: {prompt}", file=shared.progress_print_out)


class UpscalerData:
    """Record describing one selectable upscaling model.

    Holds the display name, the model location (local path or URL), the owning
    `Upscaler`, the scale factor and an optional already-loaded model.
    """
    name = None
    data_path = None
    scale: int = 4
    scaler: Upscaler = None
    model: None

    def __init__(self, name: str, path: str, upscaler: Upscaler = None, scale: int = 4, model=None):
        # Both path attributes start out pointing at the same location.
        self.data_path = self.local_data_path = path
        self.name, self.scale = name, scale
        self.scaler, self.model = upscaler, model


class UpscalerNone(Upscaler):
    """Pass-through upscaler: `do_upscale` hands the image back unchanged."""
    name = "None"
    scalers = []

    def __init__(self, dirname=None):
        # `dirname` is accepted for signature parity with other upscalers; unused.
        super().__init__(False)
        self.scalers = [UpscalerData("None", None, self)]

    def do_upscale(self, img, selected_model=None):
        return img

    def load_model(self, path):
        pass


class UpscalerLanczos(Upscaler):
    """Upscaler that resizes by `self.scale` using the LANCZOS filter."""
    scalers = []

    def __init__(self, dirname=None):
        # `dirname` is accepted for signature parity with other upscalers; unused.
        super().__init__(False)
        self.name = "Lanczos"
        self.scalers = [UpscalerData("Lanczos", None, self)]

    def load_model(self, _):
        pass

    def do_upscale(self, img, selected_model=None):
        target_size = (int(img.width * self.scale), int(img.height * self.scale))
        return img.resize(target_size, resample=LANCZOS)


class UpscalerNearest(Upscaler):
    """Upscaler that resizes by `self.scale` using the NEAREST filter."""
    scalers = []

    def __init__(self, dirname=None):
        # `dirname` is accepted for signature parity with other upscalers; unused.
        super().__init__(False)
        self.name = "Nearest"
        self.scalers = [UpscalerData("Nearest", None, self)]

    def load_model(self, _):
        pass

    def do_upscale(self, img, selected_model=None):
        target_size = (int(img.width * self.scale), int(img.height * self.scale))
        return img.resize(target_size, resample=NEAREST)



# =====================================
# UpscalerESRGAN
# =====================================
import sys

import numpy as np
import torch
from PIL import Image

def mod2normal(state_dict):
    """Rename "modified" ESRGAN checkpoint keys to the `model.N...` layout.

    A dict without 'conv_first.weight' is returned as-is (same object).
    Otherwise a new dict is built: the first conv, every key containing 'RDB'
    (with `RRDB_trunk.` -> `model.1.sub.` and `.weight`/`.bias` ->
    `.0.weight`/`.0.bias`), then the trunk, upsampling and output convs.
    """
    # this code is copied from https://github.com/victorca25/iNNfer
    if 'conv_first.weight' not in state_dict:
        return state_dict

    converted = {
        'model.0.weight': state_dict['conv_first.weight'],
        'model.0.bias': state_dict['conv_first.bias'],
    }

    for key in list(state_dict):
        if 'RDB' not in key:
            continue
        new_key = key.replace('RRDB_trunk.', 'model.1.sub.')
        if '.weight' in key:
            new_key = new_key.replace('.weight', '.0.weight')
        elif '.bias' in key:
            new_key = new_key.replace('.bias', '.0.bias')
        converted[new_key] = state_dict[key]

    fixed_renames = (
        ('model.1.sub.23.weight', 'trunk_conv.weight'),
        ('model.1.sub.23.bias', 'trunk_conv.bias'),
        ('model.3.weight', 'upconv1.weight'),
        ('model.3.bias', 'upconv1.bias'),
        ('model.6.weight', 'upconv2.weight'),
        ('model.6.bias', 'upconv2.bias'),
        ('model.8.weight', 'HRconv.weight'),
        ('model.8.bias', 'HRconv.bias'),
        ('model.10.weight', 'conv_last.weight'),
        ('model.10.bias', 'conv_last.bias'),
    )
    for new_key, old_key in fixed_renames:
        converted[new_key] = state_dict[old_key]

    return converted


def resrgan2normal(state_dict, nb=23):
    """Rename Real-ESRGAN checkpoint keys to the `model.N...` layout.

    Conversion happens only when both 'conv_first.weight' and
    'body.0.rdb1.conv1.weight' are present; otherwise the input dict is
    returned as-is. `nb` is the index given to the body's trailing conv.
    When 'conv_up3.weight' exists, a third upsampling conv is mapped to
    `model.9` and the following layer indices are shifted by 3.
    """
    # this code is copied from https://github.com/victorca25/iNNfer
    recognised = "conv_first.weight" in state_dict and "body.0.rdb1.conv1.weight" in state_dict
    if not recognised:
        return state_dict

    converted = {
        'model.0.weight': state_dict['conv_first.weight'],
        'model.0.bias': state_dict['conv_first.bias'],
    }

    for key in list(state_dict):
        if "rdb" not in key:
            continue
        new_key = key.replace('body.', 'model.1.sub.').replace('.rdb', '.RDB')
        if '.weight' in key:
            new_key = new_key.replace('.weight', '.0.weight')
        elif '.bias' in key:
            new_key = new_key.replace('.bias', '.0.bias')
        converted[new_key] = state_dict[key]

    converted[f'model.1.sub.{nb}.weight'] = state_dict['conv_body.weight']
    converted[f'model.1.sub.{nb}.bias'] = state_dict['conv_body.bias']
    for index, source in ((3, 'conv_up1'), (6, 'conv_up2')):
        converted[f'model.{index}.weight'] = state_dict[f'{source}.weight']
        converted[f'model.{index}.bias'] = state_dict[f'{source}.bias']

    re8x = 0
    if 'conv_up3.weight' in state_dict:
        # modification supporting: https://github.com/ai-forever/Real-ESRGAN/blob/main/RealESRGAN/rrdbnet_arch.py
        re8x = 3
        converted['model.9.weight'] = state_dict['conv_up3.weight']
        converted['model.9.bias'] = state_dict['conv_up3.bias']

    converted[f'model.{8+re8x}.weight'] = state_dict['conv_hr.weight']
    converted[f'model.{8+re8x}.bias'] = state_dict['conv_hr.bias']
    converted[f'model.{10+re8x}.weight'] = state_dict['conv_last.weight']
    converted[f'model.{10+re8x}.bias'] = state_dict['conv_last.bias']

    return converted


def infer_params(state_dict):
    """Derive RRDBNet constructor parameters from a `model.N...` state dict.

    Returns:
        (in_nc, out_nc, nf, nb, plus, scale) where
        - in_nc / nf come from the shape of 'model.0.weight',
        - out_nc is the first dimension of the entry with the highest layer
          index among three-part keys,
        - nb is the index found in the last five-part `*.*.sub.N.*` key,
        - plus is True when any key contains "conv1x1",
        - scale is 2 ** (number of `model.N.weight` keys with N > 6).
    """
    # this code is copied from https://github.com/victorca25/iNNfer
    scale2x = 0
    scalemin = 6
    n_uplayer = 0
    plus = False

    for key, tensor in state_dict.items():
        parts = key.split(".")
        if len(parts) == 5 and parts[2] == "sub":
            nb = int(parts[3])
        elif len(parts) == 3:
            part_num = int(parts[1])
            if parts[0] == "model" and parts[2] == "weight" and part_num > scalemin:
                scale2x += 1
            if part_num > n_uplayer:
                n_uplayer = part_num
                out_nc = tensor.shape[0]
        plus = plus or "conv1x1" in key

    first_conv_shape = state_dict["model.0.weight"].shape
    nf, in_nc = first_conv_shape[0], first_conv_shape[1]

    return in_nc, out_nc, nf, nb, plus, 2 ** scale2x


class UpscalerESRGAN(Upscaler):
    """ESRGAN / Real-ESRGAN upscaler built on the RRDBNet architecture."""

    def __init__(self, dirname):
        """Register one scaler per model found for `dirname`, or the default URL model."""
        self.name = "ESRGAN"
        self.model_url = "https://github.com/cszn/KAIR/releases/download/v1.0/ESRGAN.pth"
        self.model_name = "ESRGAN_4x"
        self.scalers = []
        self.user_path = dirname
        super().__init__()
        model_paths = self.find_models(ext_filter=[".pt", ".pth"])
        if len(model_paths) == 0:
            # Register the default model on the instance; it used to be
            # appended to a throw-away local list and was therefore lost.
            self.scalers.append(UpscalerData(self.model_name, self.model_url, self, 4))
        for file in model_paths:
            # `friendly_name` is the function defined in this file; no
            # `modelloader` import is visible in this part of the file.
            name = self.model_name if file.startswith("http") else friendly_name(file)
            self.scalers.append(UpscalerData(name, file, self, 4))

    def do_upscale(self, img, selected_model):
        """Upscale `img` with the model at `selected_model`; return `img` unchanged if loading fails."""
        try:
            model = self.load_model(selected_model)
        except Exception as e:
            print(f"Unable to load ESRGAN model {selected_model}: {e}", file=sys.stderr)
            return img
        model.to(devicee) # ⭐
        img = esrgan_upscale(model, img)
        return img

    def load_model(self, path: str):
        """Load a checkpoint and return the matching network in eval mode.

        Args:
            path: local checkpoint path, or a string starting with "http", in
                which case `self.model_url` is downloaded to
                `self.model_download_path`.

        Raises:
            Exception: the state dict is not a recognised ESRGAN layout.
        """
        if path.startswith("http"):
            # TODO: this doesn't use `path` at all?
            # NOTE(review): `model_download_path` is None unless a caller sets
            # it; the download helper then fails on `os.makedirs(None)`.
            filename = load_file_from_url(
                url=self.model_url,
                model_dir=self.model_download_path,
                file_name=f"{self.model_name}.pth",
            )
        else:
            filename = path

        # The previous conditional compared two constant strings and always
        # produced None, so tensors keep the device they were saved with.
        state_dict = torch.load(filename, map_location=None)

        if "params_ema" in state_dict:
            state_dict = state_dict["params_ema"]
        elif "params" in state_dict:
            # Checkpoints that only carry "params" are loaded as SRVGGNetCompact.
            state_dict = state_dict["params"]
            num_conv = 16 if "realesr-animevideov3" in filename else 32
            model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=num_conv, upscale=4, act_type='prelu')
            model.load_state_dict(state_dict)
            model.eval()
            return model

        if "body.0.rdb1.conv1.weight" in state_dict and "conv_first.weight" in state_dict:
            nb = 6 if "RealESRGAN_x4plus_anime_6B" in filename else 23
            state_dict = resrgan2normal(state_dict, nb)
        elif "conv_first.weight" in state_dict:
            state_dict = mod2normal(state_dict)
        elif "model.0.weight" not in state_dict:
            raise Exception("The file is not a recognized ESRGAN model.")

        in_nc, out_nc, nf, nb, plus, mscale = infer_params(state_dict)

        model = RRDBNet(in_nc=in_nc, out_nc=out_nc, nf=nf, nb=nb, upscale=mscale, plus=plus)
        model.load_state_dict(state_dict)
        model.eval()

        return model


def upscale_without_tiling(model, img):
    """Run `model` on the whole image in one pass and return the result as an RGB image.

    The image array has its channel order reversed, is scaled to [0, 1] and
    moved to `devicee`; the model output is clamped to [0, 1], scaled back to
    0-255 and has its channel order reversed again.
    """
    pixels = np.array(img)
    # Reverse the channel axis, then reorder HWC -> CHW and scale to [0, 1].
    chw = np.ascontiguousarray(np.transpose(pixels[:, :, ::-1], (2, 0, 1))) / 255
    batch = torch.from_numpy(chw).float().unsqueeze(0).to(devicee)

    with torch.no_grad():
        prediction = model(batch)

    prediction = prediction.squeeze().float().cpu().clamp_(0, 1).numpy()
    as_uint8 = (255. * np.moveaxis(prediction, 0, 2)).astype(np.uint8)
    return Image.fromarray(as_uint8[:, :, ::-1], 'RGB')


def esrgan_upscale(model, img, tile=100, tile_overlap=10):
    """Upscale `img` with `model`, tile by tile.

    Args:
        model: network applied to each tile via `upscale_without_tiling`.
        img: input image.
        tile: edge length of the square tiles; 0 disables tiling and processes
            the whole image at once. (The previous guard compared a string
            literal with 0 and could never be true, so tiling could not be
            switched off.)
        tile_overlap: overlap between neighbouring tiles.

    Returns:
        The upscaled image, reassembled from the upscaled tiles.
    """
    if tile == 0:
        return upscale_without_tiling(model, img)

    grid = split_grid(img, tile, tile, tile_overlap)
    newtiles = []
    scale_factor = 4  # replaced by the measured factor after the first tile

    for y, h, row in grid.tiles:
        newrow = []
        for tiledata in row:
            x, w, tile_img = tiledata

            output = upscale_without_tiling(model, tile_img)
            scale_factor = output.width // tile_img.width

            newrow.append([x * scale_factor, w * scale_factor, output])
        newtiles.append([y * scale_factor, h * scale_factor, newrow])

    newgrid = Grid(newtiles, grid.tile_w * scale_factor, grid.tile_h * scale_factor, grid.image_w * scale_factor, grid.image_h * scale_factor, grid.overlap * scale_factor)
    output = combine_grid(newgrid)
    return output
# =====================================
# Upscaler Utils
# =====================================

from collections import namedtuple

# Tiled view of an image. As filled by split_grid, `tiles` is a list of
# [y, tile_h, row] entries and each row is a list of [x, tile_w, tile] entries;
# the remaining fields record tile size, image size and tile overlap.
Grid = namedtuple("Grid", ["tiles", "tile_w", "tile_h", "image_w", "image_h", "overlap"])


def split_grid(image, tile_w=512, tile_h=512, overlap=64):
    """Cut `image` into overlapping tiles and return them as a `Grid`.

    Tile origins are spread evenly over the image; an origin whose tile would
    reach or pass the image edge is moved so the tile ends exactly at the edge.
    """
    w, h = image.width, image.height

    stride_w = tile_w - overlap
    stride_h = tile_h - overlap

    cols = math.ceil((w - overlap) / stride_w)
    rows = math.ceil((h - overlap) / stride_h)

    dx = (w - tile_w) / (cols - 1) if cols > 1 else 0
    dy = (h - tile_h) / (rows - 1) if rows > 1 else 0

    def tile_origin(index, step, tile_len, total):
        origin = int(index * step)
        return total - tile_len if origin + tile_len >= total else origin

    # Column origins are identical for every row.
    xs = [tile_origin(col, dx, tile_w, w) for col in range(cols)]

    grid = Grid([], tile_w, tile_h, w, h, overlap)
    for row in range(rows):
        y = tile_origin(row, dy, tile_h, h)
        row_images = [[x, tile_w, image.crop((x, y, x + tile_w, y + tile_h))] for x in xs]
        grid.tiles.append([y, tile_h, row_images])

    return grid


def combine_grid(grid):
    """Paste the tiles of `grid` back into one RGB image.

    Tiles are first joined into full-width rows and the rows are then stacked.
    In each overlap band the new tile/row is pasted through a linear ramp mask
    (0 up to just under 255), so it fades in over what is already there.
    """
    overlap = grid.overlap

    def ramp_mask(values):
        scaled = (values * 255 / overlap).astype(np.uint8)
        return Image.fromarray(scaled, 'L')

    ramp = np.arange(overlap, dtype=np.float32)
    mask_w = ramp_mask(ramp.reshape((1, overlap)).repeat(grid.tile_h, axis=0))
    mask_h = ramp_mask(ramp.reshape((overlap, 1)).repeat(grid.image_w, axis=1))

    canvas = Image.new("RGB", (grid.image_w, grid.image_h))
    for y, h, row in grid.tiles:
        strip = Image.new("RGB", (grid.image_w, h))
        for x, w, tile in row:
            if x == 0:
                strip.paste(tile, (0, 0))
            else:
                # Blend the overlapping band, then copy the rest of the tile.
                strip.paste(tile.crop((0, 0, overlap, h)), (x, 0), mask=mask_w)
                strip.paste(tile.crop((overlap, 0, w, h)), (x + overlap, 0))

        if y == 0:
            canvas.paste(strip, (0, 0))
        else:
            canvas.paste(strip.crop((0, 0, strip.width, overlap)), (0, y), mask=mask_h)
            canvas.paste(strip.crop((0, overlap, strip.width, h)), (0, y + overlap))

    return canvas
# =====================================
# Upscaler loaders
# =====================================
import os
import shutil
import importlib
from urllib.parse import urlparse



def load_file_from_url(
    url: str,
    *,
    model_dir: str,
    progress: bool = True,
    file_name: str | None = None,
) -> str:
    """Download a file from `url` into `model_dir`, using the file present if possible.

    Returns the path to the downloaded file.
    """
    os.makedirs(model_dir, exist_ok=True)
    if not file_name:
        parts = urlparse(url)
        file_name = os.path.basename(parts.path)
    cached_file = os.path.abspath(os.path.join(model_dir, file_name))
    if not os.path.exists(cached_file):
        print(f'Downloading: "{url}" to {cached_file}\n')
        from torch.hub import download_url_to_file
        download_url_to_file(url, cached_file, progress=progress)
    return cached_file


def load_models(model_path: str, model_url: str = None, command_path: str = None, ext_filter=None, download_name=None, ext_blacklist=None) -> list:
    """Collect model files from the configured locations.

    Searched places, in order: `command_path/experiments/pretrained_models` if
    it exists (else `command_path` itself if it exists), then `model_path`.
    Broken symlinks, blacklisted extensions and duplicates are skipped. When
    nothing is found and `model_url` is given, the result is the downloaded
    file (if `download_name` is set) or the URL itself.

    Any exception raised along the way is swallowed and whatever has been
    collected so far is returned.
    NOTE(review): `shared` is not defined in the visible part of this file; if
    it is missing at runtime this function always returns [] -- confirm.
    """
    output = []

    try:
        places = []

        if command_path is not None and command_path != model_path:
            pretrained_path = os.path.join(command_path, 'experiments/pretrained_models')
            if os.path.exists(pretrained_path):
                print(f"Appending path: {pretrained_path}")
                places.append(pretrained_path)
            elif os.path.exists(command_path):
                places.append(command_path)

        places.append(model_path)

        for place in places:
            for full_path in shared.walk_files(place, allowed_extensions=ext_filter):
                broken_link = os.path.islink(full_path) and not os.path.exists(full_path)
                if broken_link:
                    print(f"Skipping broken symlink: {full_path}")
                    continue
                blacklisted = ext_blacklist is not None and any(full_path.endswith(x) for x in ext_blacklist)
                if blacklisted or full_path in output:
                    continue
                output.append(full_path)

        if model_url is not None and not output:
            if download_name is None:
                output.append(model_url)
            else:
                output.append(load_file_from_url(model_url, model_dir=places[0], file_name=download_name))

    except Exception:
        pass

    return output


def friendly_name(file: str):
    """Return the file name of a path or URL without directory and extension."""
    # For URLs only the path component is considered (query string is dropped).
    location = urlparse(file).path if file.startswith("http") else file
    stem, _extension = os.path.splitext(os.path.basename(location))
    return stem

# =====================================
# UpscalerESRGAN ARCH
# =====================================

# this file is adapted from https://github.com/victorca25/iNNfer

from collections import OrderedDict
import math
import torch
import torch.nn as nn
import torch.nn.functional as F


####################
# RRDBNet Generator
####################

class RRDBNet(nn.Module):
    """ESRGAN generator: first conv, `nb` RRDB blocks in a shortcut, upsamplers, two output convs.

    Note: the RRDB blocks are always built with gc=32 and mode='CNA'; the `gc`
    argument is not forwarded, and `mode` only affects the conv after the
    blocks.
    """

    def __init__(self, in_nc, out_nc, nf, nb, nr=3, gc=32, upscale=4, norm_type=None,
            act_type='leakyrelu', mode='CNA', upsample_mode='upconv', convtype='Conv2D',
            finalact=None, gaussian_noise=False, plus=False):
        super().__init__()
        # A 3x model uses a single upsampling stage; otherwise one stage per factor of 2.
        n_upscale = 1 if upscale == 3 else int(math.log(upscale, 2))

        # Selects the pixel-unshuffle applied in forward(), based on the input channel count.
        if in_nc % 16 == 0:
            self.resrgan_scale = 1
        elif in_nc != 4 and in_nc % 4 == 0:
            self.resrgan_scale = 2
        else:
            self.resrgan_scale = 0

        first_conv = conv_block(in_nc, nf, kernel_size=3, norm_type=None, act_type=None, convtype=convtype)
        trunk_blocks = [RRDB(nf, nr, kernel_size=3, gc=32, stride=1, bias=1, pad_type='zero',
            norm_type=norm_type, act_type=act_type, mode='CNA', convtype=convtype,
            gaussian_noise=gaussian_noise, plus=plus) for _ in range(nb)]
        trunk_conv = conv_block(nf, nf, kernel_size=3, norm_type=norm_type, act_type=None, mode=mode, convtype=convtype)

        if upsample_mode == 'upconv':
            make_upsampler = upconv_block
        elif upsample_mode == 'pixelshuffle':
            make_upsampler = pixelshuffle_block
        else:
            raise NotImplementedError(f'upsample mode [{upsample_mode}] is not found')

        if upscale == 3:
            upsampler = make_upsampler(nf, nf, 3, act_type=act_type, convtype=convtype)
        else:
            upsampler = [make_upsampler(nf, nf, act_type=act_type, convtype=convtype) for _ in range(n_upscale)]

        hr_conv = conv_block(nf, nf, kernel_size=3, norm_type=None, act_type=act_type, convtype=convtype)
        out_conv = conv_block(nf, out_nc, kernel_size=3, norm_type=None, act_type=None, convtype=convtype)

        outact = act(finalact) if finalact else None

        self.model = sequential(first_conv, ShortcutBlock(sequential(*trunk_blocks, trunk_conv)),
            *upsampler, hr_conv, out_conv, outact)

    def forward(self, x, outm=None):
        # resrgan_scale 1 -> unshuffle by 4, 2 -> unshuffle by 2, anything else -> input as-is.
        unshuffle_factor = {1: 4, 2: 2}.get(self.resrgan_scale)
        feat = x if unshuffle_factor is None else pixel_unshuffle(x, scale=unshuffle_factor)
        return self.model(feat)


class RRDB(nn.Module):
    """
    Residual in Residual Dense Block
    (ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks)

    Chains ``nr`` ResidualDenseBlock_5C modules and adds their output, scaled
    by 0.2, back onto the block input.
    """

    def __init__(self, nf, nr=3, kernel_size=3, gc=32, stride=1, bias=1, pad_type='zero',
            norm_type=None, act_type='leakyrelu', mode='CNA', convtype='Conv2D',
            spectral_norm=False, gaussian_noise=False, plus=False):
        super().__init__()

        def new_rdb():
            # Every dense block inside this RRDB is built with identical settings.
            return ResidualDenseBlock_5C(
                nf, kernel_size, gc, stride, bias, pad_type,
                norm_type, act_type, mode, convtype,
                spectral_norm=spectral_norm, gaussian_noise=gaussian_noise, plus=plus)

        if nr != 3:
            self.RDBs = nn.Sequential(*(new_rdb() for _ in range(nr)))
        else:
            # The three-block case keeps the attribute names RDB1..RDB3 for
            # backwards compatibility with existing models.
            self.RDB1 = new_rdb()
            self.RDB2 = new_rdb()
            self.RDB3 = new_rdb()

    def forward(self, x):
        """Apply the dense blocks in order and return ``0.2 * out + x``."""
        if hasattr(self, 'RDB1'):
            out = self.RDB3(self.RDB2(self.RDB1(x)))
        else:
            out = self.RDBs(x)
        return out * 0.2 + x


class ResidualDenseBlock_5C(nn.Module):
    """
    Residual Dense Block
    The core module of paper: (Residual Dense Network for Image Super-Resolution, CVPR 18)
    Modified options that can be used:
        - "Partial Convolution based Padding" arXiv:1811.11718
        - "Spectral normalization" arXiv:1802.05957
        - "ICASSP 2020 - ESRGAN+ : Further Improving ESRGAN" N. C.
            {Rakotonirina} and A. {Rasoanaivo}

    Five convolutions are applied; each one receives the block input
    concatenated (along dim 1) with the outputs of all previous convolutions.
    The result is scaled by 0.2 and added to the input.
    """

    def __init__(self, nf=64, kernel_size=3, gc=32, stride=1, bias=1, pad_type='zero',
            norm_type=None, act_type='leakyrelu', mode='CNA', convtype='Conv2D',
            spectral_norm=False, gaussian_noise=False, plus=False):
        super().__init__()

        # Optional ESRGAN+ extras: output noise and a 1x1 skip convolution.
        self.noise = GaussianNoise() if gaussian_noise else None
        self.conv1x1 = conv1x1(nf, gc) if plus else None

        # Settings shared by all five convolutions.
        common = dict(bias=bias, pad_type=pad_type, norm_type=norm_type, mode=mode,
                      convtype=convtype, spectral_norm=spectral_norm)

        self.conv1 = conv_block(nf, gc, kernel_size, stride, act_type=act_type, **common)
        self.conv2 = conv_block(nf + gc, gc, kernel_size, stride, act_type=act_type, **common)
        self.conv3 = conv_block(nf + 2 * gc, gc, kernel_size, stride, act_type=act_type, **common)
        self.conv4 = conv_block(nf + 3 * gc, gc, kernel_size, stride, act_type=act_type, **common)

        # In 'CNA' mode the last convolution has no activation; in the other
        # modes it uses ``act_type``. Its kernel size is fixed at 3.
        last_act = None if mode == 'CNA' else act_type
        self.conv5 = conv_block(nf + 4 * gc, nf, 3, stride, act_type=last_act, **common)

    def forward(self, x):
        use_skip = self.conv1x1 is not None

        x1 = self.conv1(x)
        x2 = self.conv2(torch.cat((x, x1), 1))
        if use_skip:
            x2 = x2 + self.conv1x1(x)
        x3 = self.conv3(torch.cat((x, x1, x2), 1))
        x4 = self.conv4(torch.cat((x, x1, x2, x3), 1))
        if use_skip:
            x4 = x4 + x2
        x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))

        out = x5 * 0.2 + x
        if self.noise is not None:
            return self.noise(out)
        return out


####################
# ESRGANplus
####################

class GaussianNoise(nn.Module):
    """Add noise proportional to the input while the module is in training mode.

    The noise is drawn with ``normal_()`` and multiplied by ``sigma * x``
    (or ``sigma * x.detach()`` when ``is_relative_detach`` is set). In eval
    mode, or when ``sigma`` is 0, the input is returned unchanged.
    """

    def __init__(self, sigma=0.1, is_relative_detach=False):
        super().__init__()
        self.sigma = sigma
        self.is_relative_detach = is_relative_detach
        # Scalar template tensor; repeated to the input's shape on every call.
        self.noise = torch.tensor(0, dtype=torch.float)

    def forward(self, x):
        if not self.training or self.sigma == 0:
            return x
        self.noise = self.noise.to(x.device)
        reference = x.detach() if self.is_relative_detach else x
        scale = self.sigma * reference
        sampled_noise = self.noise.repeat(*x.size()).normal_() * scale
        return x + sampled_noise

def conv1x1(in_planes, out_planes, stride=1):
    """Return a bias-free ``nn.Conv2d`` with a 1x1 kernel from ``in_planes`` to ``out_planes`` channels."""
    pointwise = nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=1,
        stride=stride,
        bias=False,
    )
    return pointwise


####################
# SRVGGNetCompact
####################

class SRVGGNetCompact(nn.Module):
    """A compact VGG-style network structure for super-resolution.
    This class is copied from https://github.com/xinntao/Real-ESRGAN

    The body is a flat ``nn.ModuleList``: one input conv + activation,
    ``num_conv`` further conv + activation pairs, and a final conv producing
    ``num_out_ch * upscale**2`` channels that ``nn.PixelShuffle`` rearranges
    into the upscaled image.

    Args:
        num_in_ch (int): channels of the input image.
        num_out_ch (int): channels of the output image.
        num_feat (int): channels of the intermediate feature maps.
        num_conv (int): number of conv + activation pairs after the first one.
        upscale (int): upscaling factor.
        act_type (str): 'relu', 'prelu' or 'leakyrelu'.

    Raises:
        NotImplementedError: if ``act_type`` is not one of the supported names.
    """

    def __init__(self, num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu'):
        super(SRVGGNetCompact, self).__init__()
        self.num_in_ch = num_in_ch
        self.num_out_ch = num_out_ch
        self.num_feat = num_feat
        self.num_conv = num_conv
        self.upscale = upscale
        self.act_type = act_type

        self.body = nn.ModuleList()
        # the first conv and its activation
        self.body.append(nn.Conv2d(num_in_ch, num_feat, 3, 1, 1))
        self.body.append(self._make_activation(act_type, num_feat))

        # the body structure: num_conv x (conv, activation)
        for _ in range(num_conv):
            self.body.append(nn.Conv2d(num_feat, num_feat, 3, 1, 1))
            self.body.append(self._make_activation(act_type, num_feat))

        # the last conv
        self.body.append(nn.Conv2d(num_feat, num_out_ch * upscale * upscale, 3, 1, 1))
        # upsample
        self.upsampler = nn.PixelShuffle(upscale)

    @staticmethod
    def _make_activation(act_type, num_feat):
        """Create a fresh activation module for ``act_type``.

        An unsupported name used to leave the activation variable unassigned
        and fail with an UnboundLocalError; it is now rejected explicitly.
        """
        if act_type == 'relu':
            return nn.ReLU(inplace=True)
        if act_type == 'prelu':
            return nn.PReLU(num_parameters=num_feat)
        if act_type == 'leakyrelu':
            return nn.LeakyReLU(negative_slope=0.1, inplace=True)
        raise NotImplementedError(f'activation layer [{act_type}] is not found')

    def forward(self, x):
        out = x
        for layer in self.body:
            out = layer(out)

        out = self.upsampler(out)
        # add the nearest upsampled image, so that the network learns the residual
        base = F.interpolate(x, scale_factor=self.upscale, mode='nearest')
        out += base
        return out


####################
# Upsampler
####################

class Upsample(nn.Module):
    r"""Upsamples a given multi-channel 1D (temporal), 2D (spatial) or 3D (volumetric) data.
    The input data is assumed to be of the form
    `minibatch x channels x [optional depth] x [optional height] x width`.

    Thin module wrapper around ``nn.functional.interpolate``; the stored
    ``size``, ``scale_factor``, ``mode`` and ``align_corners`` are passed
    straight through on every call.
    """

    def __init__(self, size=None, scale_factor=None, mode="nearest", align_corners=None):
        super().__init__()
        # Scale factors are stored as floats (element-wise for tuples); a
        # falsy scale factor is stored as None.
        if isinstance(scale_factor, tuple):
            factor = tuple(float(item) for item in scale_factor)
        elif scale_factor:
            factor = float(scale_factor)
        else:
            factor = None
        self.scale_factor = factor
        self.mode = mode
        self.size = size
        self.align_corners = align_corners

    def forward(self, x):
        return nn.functional.interpolate(
            x,
            size=self.size,
            scale_factor=self.scale_factor,
            mode=self.mode,
            align_corners=self.align_corners,
        )

    def extra_repr(self):
        if self.scale_factor is None:
            lead = f'size={self.size}'
        else:
            lead = f'scale_factor={self.scale_factor}'
        return f'{lead}, mode={self.mode}'


def pixel_unshuffle(x, scale):
    """ Pixel unshuffle.
    Args:
        x (Tensor): Input feature with shape (b, c, hh, hw).
        scale (int): Downsample ratio; must divide both hh and hw.
    Returns:
        Tensor: the pixel unshuffled feature, shape
            (b, c * scale**2, hh // scale, hw // scale).
    """
    b, c, hh, hw = x.size()
    assert hh % scale == 0 and hw % scale == 0
    h, w = hh // scale, hw // scale
    # Split each spatial axis into (coarse position, offset inside the block).
    blocks = x.view(b, c, h, scale, w, scale)
    # Bring both in-block offsets next to the channel axis and fold them into it.
    return blocks.permute(0, 1, 3, 5, 2, 4).reshape(b, c * (scale**2), h, w)


def pixelshuffle_block(in_nc, out_nc, upscale_factor=2, kernel_size=3, stride=1, bias=True,
                        pad_type='zero', norm_type=None, act_type='relu', convtype='Conv2D'):
    """
    Pixel shuffle layer
    (Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional
    Neural Network, CVPR17)

    Builds: conv (to ``out_nc * upscale_factor**2`` channels, without norm or
    activation) -> ``nn.PixelShuffle`` -> optional norm -> optional activation.
    """
    expanded_nc = out_nc * (upscale_factor ** 2)
    stages = [
        conv_block(in_nc, expanded_nc, kernel_size, stride, bias=bias,
                   pad_type=pad_type, norm_type=None, act_type=None, convtype=convtype),
        nn.PixelShuffle(upscale_factor),
        # Norm and activation are applied after the shuffle, on out_nc channels.
        norm(norm_type, out_nc) if norm_type else None,
        act(act_type) if act_type else None,
    ]
    return sequential(*stages)


def upconv_block(in_nc, out_nc, upscale_factor=2, kernel_size=3, stride=1, bias=True,
                pad_type='zero', norm_type=None, act_type='relu', mode='nearest', convtype='Conv2D'):
    """ Upconv layer: an ``Upsample`` (interpolation) followed by a ``conv_block``. """
    if convtype == 'Conv3D':
        # For 3D convolutions the first scaled axis is left at factor 1.
        upscale_factor = (1, upscale_factor, upscale_factor)
    return sequential(
        Upsample(scale_factor=upscale_factor, mode=mode),
        conv_block(in_nc, out_nc, kernel_size, stride, bias=bias,
                   pad_type=pad_type, norm_type=norm_type, act_type=act_type, convtype=convtype),
    )








####################
# Basic blocks
####################


def make_layer(basic_block, num_basic_block, **kwarg):
    """Make layers by stacking the same blocks.
    Args:
        basic_block (nn.module): nn.module class for basic block. (block)
        num_basic_block (int): number of blocks. (n_layers)
        **kwarg: keyword arguments forwarded to every ``basic_block(...)`` call.
    Returns:
        nn.Sequential: Stacked blocks in nn.Sequential.
    """
    return nn.Sequential(*[basic_block(**kwarg) for _ in range(num_basic_block)])


def act(act_type, inplace=True, neg_slope=0.2, n_prelu=1, beta=1.0):
    """ activation helper

    Returns a new activation module for ``act_type`` (case-insensitive):
    'relu', 'leakyrelu'/'lrelu', 'prelu', 'tanh' or 'sigmoid'. ``neg_slope``
    is the LeakyReLU slope and the PReLU initial value; ``beta`` is unused.
    Raises NotImplementedError for any other name.
    """
    act_type = act_type.lower()
    if act_type == 'relu':
        return nn.ReLU(inplace)
    if act_type in ('leakyrelu', 'lrelu'):
        return nn.LeakyReLU(neg_slope, inplace)
    if act_type == 'prelu':
        return nn.PReLU(num_parameters=n_prelu, init=neg_slope)
    if act_type == 'tanh':  # [-1, 1] range output
        return nn.Tanh()
    if act_type == 'sigmoid':  # [0, 1] range output
        return nn.Sigmoid()
    raise NotImplementedError(f'activation layer [{act_type}] is not found')


class Identity(nn.Module):
    """Pass-through module: ``forward`` returns its first argument unchanged.

    Extra positional arguments to the constructor and to ``forward`` are
    accepted and ignored.
    """

    def __init__(self, *args):
        super().__init__()

    def forward(self, x, *args):
        return x


def norm(norm_type, nc):
    """ Return a normalization layer

    Args:
        norm_type (str): 'batch', 'instance' or 'none' (case-insensitive).
        nc (int): number of channels the layer normalizes.
    Returns:
        nn.Module: ``nn.BatchNorm2d(nc, affine=True)`` for 'batch',
            ``nn.InstanceNorm2d(nc, affine=False)`` for 'instance', or a
            pass-through identity module for 'none'.
    Raises:
        NotImplementedError: for any other ``norm_type``.
    """
    norm_type = norm_type.lower()
    if norm_type == 'batch':
        layer = nn.BatchNorm2d(nc, affine=True)
    elif norm_type == 'instance':
        layer = nn.InstanceNorm2d(nc, affine=False)
    elif norm_type == 'none':
        # This branch used to define an unused local function and then reach
        # `return layer` with `layer` unassigned (UnboundLocalError).
        layer = nn.Identity()
    else:
        raise NotImplementedError(f'normalization layer [{norm_type}] is not found')
    return layer


def pad(pad_type, padding):
    """ padding layer helper

    Returns None when ``padding`` is 0; otherwise a 2D padding module chosen
    by ``pad_type`` (case-insensitive): 'reflect', 'replicate' or 'zero'.
    Raises NotImplementedError for any other ``pad_type``.
    """
    pad_type = pad_type.lower()
    if padding == 0:
        return None
    if pad_type == 'reflect':
        return nn.ReflectionPad2d(padding)
    if pad_type == 'replicate':
        return nn.ReplicationPad2d(padding)
    if pad_type == 'zero':
        return nn.ZeroPad2d(padding)
    raise NotImplementedError(f'padding layer [{pad_type}] is not implemented')


def get_valid_padding(kernel_size, dilation):
    """Return ``(k_eff - 1) // 2`` where ``k_eff`` is the dilated kernel extent.

    ``k_eff = kernel_size + (kernel_size - 1) * (dilation - 1)``.
    """
    effective_kernel = kernel_size + (kernel_size - 1) * (dilation - 1)
    return (effective_kernel - 1) // 2


class ShortcutBlock(nn.Module):
    """ Elementwise sum the output of a submodule to its input """

    def __init__(self, submodule):
        super().__init__()
        # Registered under the name `sub`.
        self.sub = submodule

    def forward(self, x):
        return x + self.sub(x)

    def __repr__(self):
        # Prefix every continuation line of the submodule's repr with '|'.
        inner = repr(self.sub).replace('\n', '\n|')
        return 'Identity + \n|' + inner


def sequential(*args):
    """ Flatten Sequential. It unwraps nn.Sequential.

    A single argument is returned as-is (an OrderedDict is rejected with
    NotImplementedError). Otherwise the children of every ``nn.Sequential``
    argument and every other ``nn.Module`` argument are collected, in order,
    into one ``nn.Sequential``; arguments that are not modules (e.g. None)
    are dropped.
    """
    if len(args) == 1:
        only = args[0]
        if isinstance(only, OrderedDict):
            raise NotImplementedError('sequential does not support OrderedDict input.')
        return only  # No sequential is needed.
    flattened = []
    for item in args:
        if isinstance(item, nn.Sequential):
            flattened.extend(item.children())
        elif isinstance(item, nn.Module):
            flattened.append(item)
    return nn.Sequential(*flattened)


def conv_block(in_nc, out_nc, kernel_size, stride=1, dilation=1, groups=1, bias=True,
               pad_type='zero', norm_type=None, act_type='relu', mode='CNA', convtype='Conv2D',
               spectral_norm=False):
    """ Conv layer with padding, normalization, activation

    ``mode`` fixes the layer order: any mode containing 'CNA' yields
    pad -> conv -> norm -> act (norm sized for ``out_nc``); 'NAC' yields
    norm -> act -> pad -> conv (norm sized for ``in_nc``). Missing pieces
    are dropped by ``sequential``.
    """
    assert mode in ['CNA', 'NAC', 'CNAC'], f'Wrong conv mode [{mode}]'

    valid_padding = get_valid_padding(kernel_size, dilation)
    # Zero padding is done by the convolution itself; any other padding style
    # gets an explicit padding layer and the convolution pads nothing.
    pad_layer = pad(pad_type, valid_padding) if pad_type and pad_type != 'zero' else None
    conv_padding = valid_padding if pad_type == 'zero' else 0

    conv_kwargs = dict(kernel_size=kernel_size, stride=stride, padding=conv_padding,
                       dilation=dilation, bias=bias, groups=groups)
    if convtype == 'PartialConv2D':
        # this is definitely not going to work, but PartialConv2d doesn't work anyway and this shuts up static analyzer
        from torchvision.ops import PartialConv2d
        conv_cls = PartialConv2d
    elif convtype == 'DeformConv2D':
        from torchvision.ops import DeformConv2d  # not tested
        conv_cls = DeformConv2d
    elif convtype == 'Conv3D':
        conv_cls = nn.Conv3d
    else:
        conv_cls = nn.Conv2d
    conv = conv_cls(in_nc, out_nc, **conv_kwargs)

    if spectral_norm:
        conv = nn.utils.spectral_norm(conv)

    activation = act(act_type) if act_type else None
    if 'CNA' in mode:
        normalization = norm(norm_type, out_nc) if norm_type else None
        return sequential(pad_layer, conv, normalization, activation)
    elif mode == 'NAC':
        if norm_type is None and act_type is not None:
            # Without a norm layer in front, the activation must not be in-place.
            activation = act(act_type, inplace=False)
        normalization = norm(norm_type, in_nc) if norm_type else None
        return sequential(normalization, activation, pad_layer, conv)

In [None]:
#@title 3. 👇 Generating Images { form-width: "20%", display-mode: "form" }
#@markdown ---
#@markdown - **Prompt** - Tell the model what you want to see.
#@markdown - **Negative Prompt** - Tell the model what you don't want to see.
#@markdown - **Steps** - How long the model should work on the image.
#@markdown - **CFG** - Controls how much the image generation process follows the text prompt.  Guidance scale ranging from 0 to 20. Lower values allow the AI to be more creative and less strict at following the prompt. Default is 7.5
#@markdown - **Sampler** - Progressively reduce image noise through denoising steps.
#@markdown - **Seed** -  A number that helps the model start generating the image. Set `-1` for using random seed values.
#@markdown - **Clip Skip** - Controls the level of detail and accuracy in the generated images by skipping certain layers of the CLIP model during the image generation process.
#@markdown - **Prompt weights** -  Helps the model focus on different parts of the prompt. Prompt weights can be used to emphasize or de-emphasize certain aspects of the image, such as the object, the scene, or the style. Currently, the [Compel syntax](https://github.com/damian0815/compel/blob/main/doc/syntax.md) is used. You can also enable `Convert Prompt weights` to automatically convert the syntax from `(word:1.1)` to `(word)1.1`, or from `(word)` to `(word)+`, so that it is compatible with Compel weights. Compel scales more strongly with its weight values, so lower weights are enough for good results.
#@markdown - **FreeU** - A method that substantially improves diffusion model sample quality at no cost.
#@markdown - **ControlNet** - Enhances text-to-image diffusion models by allowing various spatial contexts to serve as additional conditioning, enabling the generation of more controlled and context-aware images.
#@markdown ---
%cd /content
import ipywidgets as widgets, mediapy, random
from diffusers.models.attention_processor import AttnProcessor2_0
from PIL import Image
import IPython.display
import time
from IPython.utils import capture
import logging
from ipywidgets import Button, Layout, jslink, IntText, IntSlider, BoundedFloatText
logging.getLogger("diffusers").setLevel(logging.ERROR)

#from IPython.display import display
from ipywidgets import TwoByTwoLayout, Layout, Button, Box, FloatText, Textarea, Dropdown, Label, IntSlider, interactive, HBox, VBox, BoundedIntText, BoundedFloatText

# PARAMETER WIDGETS
# Shared sizing/styling values reused by the widget definitions in this cell.
width = "225px"
auto_layout = Layout(height='auto', width='auto')
lora_layout = {'width':'165px'}
lora_scale_layout = {'width':'50px'}
style = {'description_width': 'initial'} # show full text

# =====================================
# Left
# =====================================

# Number of images per run.
num_images = BoundedIntText(value=1, min=1, description="Images:", layout=Layout(width=width))

# Number of denoising steps.
steps = BoundedIntText(value=30, min=1, max=100, description="Steps:", layout=Layout(width=width))

# Guidance scale.
CFG = BoundedFloatText(value=7.5, min=0, step=0.5, description="CFG:", layout=Layout(width=width))

select_sampler = Dropdown(
    options=[
        "DPM++ 2M",
        "DPM++ 2M Karras",
        "DPM++ 2M SDE",
        "DPM++ 2M SDE Karras",
        "DPM++ SDE",
        "DPM++ SDE Karras",
        "DPM2",
        "DPM2 Karras",
        "Euler",
        "Euler a",
        "Heun",
        "LMS",
        "LMS Karras",
        "DDIMScheduler",
        "DEISMultistepScheduler",
        "UniPCMultistepScheduler",
    ],
    description="Sampler:",
    layout=Layout(width=width),
)

# Output image size, adjustable in steps of 8.
img_height = BoundedIntText(min=64, max=4096, step=8, value=512, description="Height:", layout=Layout(width=width))

img_width = BoundedIntText(min=64, max=4096, step=8, value=512, description="Width:", layout=Layout(width=width))

# Seed; the form help text says -1 selects random seed values.
random_seed = IntText(value=-1, description="Seed:", layout=Layout(width=width), disabled=False)

def _make_lora_controls(slot):
    """Build the (model selector, weight scale) widget pair for one LoRA slot.

    Args:
        slot (int): 1-based slot number, used only in the dropdown label.
    Returns:
        tuple: ``(Dropdown, BoundedFloatText)``.
    """
    selector = widgets.Dropdown(
        options=lora_model_list,
        description=f"Lora{slot}:",
        layout=lora_layout
    )
    # BoundedFloatText is used so that the [-2, 2] range is actually enforced;
    # the plain FloatText used before has no min/max and ignored these limits.
    scale = widgets.BoundedFloatText(
        min=-2.0,
        max=2.0,
        step=0.01,
        value=1,
        layout=lora_scale_layout
    )
    return selector, scale

# Five independent LoRA slots; the variable names are referenced by the
# generate handler.
select_lora1, lora_weights_scale1 = _make_lora_controls(1)
select_lora2, lora_weights_scale2 = _make_lora_controls(2)
select_lora3, lora_weights_scale3 = _make_lora_controls(3)
select_lora4, lora_weights_scale4 = _make_lora_controls(4)
select_lora5, lora_weights_scale5 = _make_lora_controls(5)

# Clip skip toggle (checked by default).
select_clip_skip = widgets.Checkbox(value=True, description='Layer 2 Clip Skip', layout=Layout(width=width))

# FreeU toggle (unchecked by default).
freeu_check = widgets.Checkbox(value=False, description='FreeU for txt2img', layout=Layout(width=width))

# =====================================
# Center
# =====================================

# Output area that receives progress text and the generated images.
display_imgs = widgets.Output() # layout={'border': '1px solid black'}

select_model = Dropdown(options=model_list, description="Model:", layout=auto_layout)

vae_model_dropdown = Dropdown(options=vae_model_list, description="VAE:", layout=auto_layout)

prompt = Textarea(value="", placeholder="Enter prompt", rows=5, layout=auto_layout)

neg_prompt = Textarea(value="", placeholder="Enter negative prompt", rows=5, layout=auto_layout)

active_ti = widgets.Checkbox(
    value=False,
    description='Active Textual Inversion in prompt (Experimental)',
    #style = style,
    layout=auto_layout,
)

# alternative prompt weights
weights_prompt = widgets.Checkbox(
    value=False,
    description='Convert Prompt weights',
    #style = style,
    layout=auto_layout,
)

generate = Button(
    description="Generate",
    disabled=False,
    button_style="primary",
    layout=Layout(height='auto', width='auto'),
)

### GENERATE ###

def generate_img(i):
  """Click handler for the Generate button.

  Loads (or reloads) the pipeline for the current model/task/VAE/precision
  selection, collects every widget value into one parameter dict, runs the
  model and shows the resulting images inside ``display_imgs``. The button
  is disabled while the handler runs and re-enabled when it finishes, fails
  or returns early.

  Args:
      i: Argument supplied by ``Button.on_click``; unused.
  """
  global model
  global destination_path_cn_img, mask_control, image_list

  #Clear output
  display_imgs.clear_output()
  generate.disabled = True

  try:
    with display_imgs:

      print("Running...")

      # First load: `model` is not defined until the first click.
      try:
          model
      except NameError:
          model = Model(base_model_id=select_model.value, task_name=select_task.value, vae_model = vae_model_dropdown.value, type_model_precision = model_precision.value)

      model.load_pipe(select_model.value, task_name=select_task.value, vae_model = vae_model_dropdown.value, type_model_precision = model_precision.value)

      display_imgs.clear_output()

      # Value of the first task-specific widget; tasks without an entry or
      # with an empty widget list have none.
      try:
        preprocessor_name_found = int_inputs[select_task.value][0].value
      except (KeyError, IndexError):
        preprocessor_name_found = None

      image_control_base = None
      if select_task.value != "txt2img":
          # `destination_path_cn_img` is a global set elsewhere in the
          # notebook; it is undefined until a control image has been provided.
          try:
              image_control_base = destination_path_cn_img
          except NameError:
              print("No control image found")
              return

      mask_control_base = None
      if select_task.value == "Inpaint":
          # A mask path typed into the widget wins over the global `mask_control`.
          if os.path.exists(int_inputs['Inpaint'][1].value):
              mask_control_base = int_inputs['Inpaint'][1].value
          else:
              try:
                  mask_control_base = mask_control
              except NameError:
                  print("No mask image found")
                  return

      pipe_params = {
          "prompt": prompt.value,
          "negative_prompt": neg_prompt.value,
          "img_height": img_height.value,
          "img_width": img_width.value,
          "num_images": num_images.value,
          "num_steps": steps.value,
          "guidance_scale": CFG.value,
          "clip_skip": select_clip_skip.value,
          "seed": random_seed.value,
          "image": image_control_base,
          "preprocessor_name": preprocessor_name_found,
          "preprocess_resolution": preprocess_resolution_global.value,
          "image_resolution": image_resolution_global.value,
          "additional_prompt": "",
          "image_mask": mask_control_base, # only for Inpaint
          "strength": int_inputs['Inpaint'][0].value, # only for Inpaint
          "low_threshold": int_inputs['Canny'][0].value,
          "high_threshold": int_inputs['Canny'][1].value,
          "value_threshold": int_inputs['MLSD'][0].value,
          "distance_threshold": int_inputs['MLSD'][1].value,
          "lora_A": select_lora1.value,
          "lora_scale_A": lora_weights_scale1.value,
          "lora_B": select_lora2.value,
          "lora_scale_B": lora_weights_scale2.value,
          "lora_C": select_lora3.value,
          "lora_scale_C": lora_weights_scale3.value,
          "lora_D": select_lora4.value,
          "lora_scale_D": lora_weights_scale4.value,
          "lora_E": select_lora5.value,
          "lora_scale_E": lora_weights_scale5.value,
          "active_textual_inversion": active_ti.value,
          "textual_inversion": embed_list,
          "convert_weights_prompt": weights_prompt.value,
          "sampler": select_sampler.value,
          "xformers_memory_efficient_attention": xformers_memory_efficient_attention.value,
          "gui_active": True,
          "loop_generation": loop_generator.value,
          "controlnet_conditioning_scale" : controlnet_output_scaling_in_unet.value,
          "control_guidance_start" : controlnet_start_threshold.value,
          "control_guidance_end" : controlnet_stop_threshold.value,
          "generator_in_cpu" : init_generator_in_cpu.value,
          "FreeU" : freeu_check.value,
      }

      try:
          images, image_list = model(**pipe_params)

          if loop_generator.value == 1:
              mediapy.show_images(images)
      finally:
          # Release cached GPU memory even when generation fails.
          torch.cuda.empty_cache()
          gc.collect()
  finally:
    # Always hand the button back, including on early return or error.
    generate.disabled = False
  return

generate.on_click(generate_img)

show_textual_inversion = Button(
    description="List available textual inversions",
    disabled=False,
    button_style="info",
    layout=Layout(height='auto', width='auto'),
)

def elemets_textual_inversion(value):
  """Click handler: print the name of every entry in ``embed_list`` into ``display_imgs``.

  ``value`` is the argument supplied by ``Button.on_click``; it is unused.
  """
  with display_imgs:
    print('The embeddings currently supported. Write in the prompt the word for use')
    # Each entry is a pair; only its first element (the name) is shown.
    for embedding_name, _embedding_location in embed_list:
        print(embedding_name)
    return

show_textual_inversion.on_click(elemets_textual_inversion)

clear_outputs = Button(
    description="Clear outputs",
    disabled=False,
    button_style="info",
    layout=Layout(height='auto', width='auto'),
)

def clear_outputs_run(value):
  """Click handler: clear everything shown in ``display_imgs``; ``value`` is unused."""
  display_imgs.clear_output()

clear_outputs.on_click(clear_outputs_run)

# =====================================
# Right ControlNet
# =====================================

control_model_list = list(CONTROLNET_MODEL_IDS.keys())

# Dropdown for selecting the task. Entries are picked from control_model_list
# by position, in display order; index 9 is not offered.
select_task = widgets.Dropdown(
    options=[
        control_model_list[position]
        for position in (13, 12, 0, 1, 2, 3, 4, 5, 6, 7, 8, 10, 11)
    ],
    description='TASK:',
    layout=auto_layout,
)

# BoundedFloatText is used so that the 0-5 range is actually enforced; the
# plain FloatText used before has no min/max and ignored these limits.
controlnet_output_scaling_in_unet = widgets.BoundedFloatText(value=1.0, min=0.0, max=5.0, step=0.1, description='ControlNet Output Scaling in UNet:', style=style)
controlnet_start_threshold = widgets.FloatSlider(value=0.0, min=0.00, max=1.0, step=0.01, description='ControlNet Start Threshold (%):', style=style)
controlnet_stop_threshold = widgets.FloatSlider(value=1.0, min=0.00, max=1.0, step=0.01, description='ControlNet Stop Threshold (%):', style=style)

preprocess_resolution_global = widgets.IntSlider(
    value=512,
    min=64,
    max=2048,
    description='Preprocessor resolution',
    style=style,
)

image_resolution_global = widgets.IntSlider(
    value=512,
    min=64,
    step=64,
    max=2048,
    description='Image resolution',
    style=style,
)

# Map each task name to the list of task-specific input widgets shown for it.
# Every widget starts hidden and gets its own Layout object, because
# visibility is toggled per widget when the selected task changes.
int_inputs = {

    control_model_list[13]: [
    ],
    control_model_list[12]: [
        widgets.FloatSlider(value=1.0, min=0.01, max=1.0, step=0.01, description='Inpaint strength:', layout=Layout(visibility='hidden'), style=style),
        widgets.Text(value="", placeholder="/content/my_mask.png", rows=1, description='Mask path:', layout=Layout(visibility='hidden'), style=style)
    ],
    control_model_list[0]: [
        widgets.Dropdown(value='Openpose', description='Preprocessor:', options=['None','Openpose'], layout=Layout(visibility='hidden'), style=style),
    ],
    control_model_list[1]: [
        widgets.BoundedIntText(value=100, min=1, max=255, description='Canny low threshold:', layout=Layout(visibility='hidden'), style=style),
        widgets.BoundedIntText(value=200, min=1, max=255, description='Canny high threshold:', layout=Layout(visibility='hidden'), style=style)
    ],
    control_model_list[2]: [
        # min must not exceed the 0.1 default: with the previous min=1 the
        # bounded widget clamped the default to 1.0 and rejected smaller values.
        widgets.BoundedFloatText(value=0.1, min=0.01, max=2.0, step=0.01, description='Hough value threshold (MLSD):', layout=Layout(visibility='hidden'), style=style),
        widgets.BoundedFloatText(value=0.1, min=0.01, max=20.0, step=0.01, description='Hough distance threshold (MLSD):', layout=Layout(visibility='hidden'), style=style)
    ],
    control_model_list[3]: [
        widgets.Dropdown(value='HED', description='Preprocessor:', options=['HED','PidiNet', 'None'], layout=Layout(visibility='hidden'), style=style),
    ],
    control_model_list[4]: [
        widgets.Dropdown(value='PidiNet', description='Preprocessor:', options=['HED','PidiNet', 'HED safe', 'PidiNet safe','None'], layout=Layout(visibility='hidden'), style=style),
    ],
    control_model_list[5]: [
        widgets.Dropdown(value='UPerNet', description='Preprocessor:', options=['UPerNet','None'], layout=Layout(visibility='hidden'), style=style),
    ],
    control_model_list[6]: [
        widgets.Dropdown(value='DPT', description='Preprocessor:', options=['Midas', 'DPT','None'], layout=Layout(visibility='hidden'), style=style),
    ],
    control_model_list[7]: [
        widgets.Dropdown(value='NormalBae', description='Preprocessor:', options=['NormalBae','None'], layout=Layout(visibility='hidden'), style=style),
    ],
    control_model_list[8]: [
        widgets.Dropdown(value='Lineart', description='Preprocessor:', options=['Lineart','Lineart coarse', 'None', 'Lineart (anime)', 'None (anime)'], layout=Layout(visibility='hidden'), style=style),
    ],
    control_model_list[10]: [
        widgets.Dropdown(value='ContentShuffle', description='Preprocessor:', options=['ContentShuffle','None'], layout=Layout(visibility='hidden'), style=style),
    ],
}

def update_widgets(option):
    """Show the input widgets registered for ``option`` in ``int_inputs`` and hide all others."""
    for task_name, task_widgets in int_inputs.items():
        visibility = 'visible' if task_name == option else 'hidden'
        for task_widget in task_widgets:
            task_widget.layout.visibility = visibility

# Re-run update_widgets whenever the selected task changes.
interactive(update_widgets, option=select_task)

# =====================================
# Settings
# =====================================
# Widgets shown on the "Settings" tab (they are placed into `tab_settings` further down).

# Integer slider from 1 to 10, starting at 1.
# NOTE(review): presumably the number of times generation is repeated — confirm where it is read.
loop_generator = widgets.IntSlider(
    value=1,
    min=1,
    max=10,
    description='Loops 🔁',
    #layout=auto_layout,
    style = style,
)

# Checkbox, ticked by default.
xformers_memory_efficient_attention = widgets.Checkbox(
    value=True,
    description='Xformers memory efficient attention',
    layout=auto_layout,
    style = style,
)

# Radio buttons whose labels map to torch dtypes, so `.value` is a torch dtype object.
# No explicit `value` is passed here.
model_precision = widgets.RadioButtons(
    options=[("float16", torch.float16),("float32", torch.float32)],
    description="Model precision (float32 need more memory):",
    layout=auto_layout,
    style = style,
)

# Checkbox, unticked by default.
init_generator_in_cpu = widgets.Checkbox(
    value=False,
    description='Generate initial random noise with the seed on the CPU',
    layout=auto_layout,
    style = style,
)
# =====================================
# App
# =====================================
# Assembles the widgets defined above into a two-tab UI and displays it.

# Centered HTML headings, one per tab.
title_tab_one = widgets.HTML(
    value="<h2>SD Interactive</h2>",
    layout=widgets.Layout(display="flex", justify_content="center")
)
title_tab_two = widgets.HTML(
    value="<h2>Settings</h2>",
    layout=widgets.Layout(display="flex", justify_content="center")
)

# Button grid: generate (top-left), textual inversion (top-right), clear outputs
# (bottom-right); the bottom-left slot is intentionally left unset.
buttons_ = TwoByTwoLayout(top_left=generate,
               top_right=show_textual_inversion,
               #bottom_left=None,
               bottom_right=clear_outputs, merge=True)

# Button colours: lighter green for the secondary actions, darker green for generate.
show_textual_inversion.style.button_color = '#97BC62'
clear_outputs.style.button_color = '#97BC62'
generate.style.button_color = '#2C5F2D'

# TAB 1
# Three-pane AppLayout (left / center / right) with the image output area underneath.
tab_sd = widgets.VBox(
    [
      widgets.AppLayout(
        header=None,
        # Left pane: generation parameters, the five LoRA selector/scale rows, clip skip, FreeU.
        left_sidebar=widgets.VBox(
            [
                title_tab_one,
                num_images,
                steps,
                CFG,
                select_sampler,
                img_width,
                img_height,
                random_seed,
                widgets.HBox([select_lora1, lora_weights_scale1],layout=widgets.Layout(width=width)),
                widgets.HBox([select_lora2, lora_weights_scale2],layout=widgets.Layout(width=width)),
                widgets.HBox([select_lora3, lora_weights_scale3],layout=widgets.Layout(width=width)),
                widgets.HBox([select_lora4, lora_weights_scale4],layout=widgets.Layout(width=width)),
                widgets.HBox([select_lora5, lora_weights_scale5],layout=widgets.Layout(width=width)),
                select_clip_skip,
                freeu_check,
            ]
        ),
        # Center pane: task / model / VAE selection, prompts and the action buttons.
        center=widgets.VBox(
            [
                select_task,
                select_model,
                vae_model_dropdown,
                prompt,
                neg_prompt,
                widgets.HBox([weights_prompt ,active_ti]),
                buttons_,
            ]
        ),
        # Right pane: ControlNet controls, the two global resolution widgets, then every
        # widget from `int_inputs` flattened into one list (their visibility is toggled
        # by `update_widgets`).
        right_sidebar=widgets.VBox(
            [
                controlnet_output_scaling_in_unet,
                controlnet_start_threshold,
                controlnet_stop_threshold,
            ] +
            [preprocess_resolution_global] + [image_resolution_global] + [int_input for int_inputs_list in int_inputs.values() for int_input in int_inputs_list]
        ),
        footer=None,
        # Width values for the left, center and right panes, in that order.
        pane_widths=[1.1, 1.4, 1.2],
        # pane_heights=["0px", 1, '0px'],
      ),
      display_imgs,
    ]
)

# TAB 2
# Settings tab: heading, the four settings widgets and a centered HTML footer.
tab_settings = widgets.VBox([
    title_tab_two,
    loop_generator,
    xformers_memory_efficient_attention,
    init_generator_in_cpu,
    model_precision,
    widgets.HTML(
        value="<p>⏲️</p>",
        layout=widgets.Layout(display="flex", justify_content="center")
    ),
])

# APP
# Each tab's content is wrapped in its own VBox before being attached to the Tab widget.
tab = widgets.Tab()
tab.children = [
  widgets.VBox(
    children = [
        tab_sd,
    ]),
  widgets.VBox(
    children = [
        tab_settings,
    ]),
]

tab.set_title(0, "Stable Diffusion")
tab.set_title(1,  "Settings")
# Open on the first ("Stable Diffusion") tab.
tab.selected_index = 0

display(tab)

In [None]:
#@title Upload an image here for use in Inpainting or ControlNet. 👈‍‍ 🖼️🖼️🖼️
#@markdown - To use Controlnet, you need to upload the control image with this cell
Create_mask_for_Inpaint = False # @param {type:"boolean"}
stroke_width = 24 # @param {type:"integer"}
from google.colab import files
from IPython.display import HTML
import os
import shutil
%cd /content
uploaded = files.upload()

filename = next(iter(uploaded))
print(f'Uploaded file: {filename}')

upload_folder = 'uploaded_controlnet_image/'
if not os.path.exists(upload_folder):
    os.makedirs(upload_folder)

source_path = filename
destination_path_cn_img = os.path.join(upload_folder, filename)
shutil.move(source_path, destination_path_cn_img)
print(f'Moved file to {destination_path_cn_img}')

if select_task.value == 'Inpaint' or Create_mask_for_Inpaint:
    init_image = destination_path_cn_img
    name_without_extension = os.path.splitext(init_image.split('/')[-1])[0]

    image64 = base64.b64encode(open(init_image, 'rb').read())
    image64 = image64.decode('utf-8')

    print('\033[34m Draw the mask with the mouse \033[0m')
    img = np.array(plt.imread(f'{init_image}')[:,:,:3])

    draw(image64, filename=f"./{name_without_extension}_draw.png", w=img.shape[1], h=img.shape[0], line_width=stroke_width)

    with_mask = np.array(plt.imread(f"./{name_without_extension}_draw.png")[:,:,:3])
    mask = (with_mask[:,:,0]==1)*(with_mask[:,:,1]==0)*(with_mask[:,:,2]==0)
    plt.imsave(f"./{name_without_extension}_mask.png",mask, cmap='gray')
    mask_control = f"./{name_without_extension}_mask.png"
    print(f'\033[34m Mask saved: {mask_control} \033[0m')

In [None]:
#@title Upscale and face restoration { form-width: "20%", display-mode: "form" }
from IPython.utils import capture
import os
import shutil

# One-time CodeFormer installation. Everything inside the `with` block has its
# output captured (and then discarded) so the cell stays quiet.
%cd /content
directory_codeformer = '/content/CodeFormer/'
with capture.capture_output() as cap:
  # The existence of the target directory is the only "already installed" check:
  # the steps below run only when /content/CodeFormer/ does not exist yet.
  if not os.path.exists(directory_codeformer):
      os.makedirs(directory_codeformer)

      # Setup
      # Clone CodeFormer and enter the CodeFormer folder
      %cd /content
      !git clone https://github.com/sczhou/CodeFormer.git
      # After this line the working directory is /content/CodeFormer for the rest of the cell.
      %cd CodeFormer


      # Set up the environment
      # Install python dependencies
      !pip install -q -r requirements.txt
      # NOTE(review): this installs the PyPI distribution named `ffmpeg` — confirm that is
      # what was intended rather than the ffmpeg binary.
      !pip -q install ffmpeg
      # Install basicsr
      !python basicsr/setup.py develop

      # Download the pre-trained model
      !python scripts/download_pretrained_models.py facelib
      !python scripts/download_pretrained_models.py CodeFormer
  # Drop the captured output; it is never shown.
  del cap
# Visualization function
import cv2
import matplotlib.pyplot as plt
def display_codeformer(img1, img2):
  """Plot the input image and its CodeFormer result side by side.

  Creates a 25x10 figure with two axis-free panels: the left one is titled
  'Input' and shows `img1`, the right one is titled 'CodeFormer' and shows
  `img2`.
  """
  fig = plt.figure(figsize=(25, 10))
  panels = []
  for column, caption in enumerate(('Input', 'CodeFormer'), start=1):
    panel = fig.add_subplot(1, 2, column)
    panel.set_title(caption, fontsize=16)
    panel.axis('off')
    panels.append(panel)
  # Draw the images only after both panels exist, as before.
  for panel, picture in zip(panels, (img1, img2)):
    panel.imshow(picture)
def imread(img_path):
  """Read an image file with OpenCV and return it in RGB channel order.

  Args:
    img_path: Path of the image file to load.

  Returns:
    The image array converted from OpenCV's BGR order to RGB.

  Raises:
    FileNotFoundError: If OpenCV could not load the file (cv2.imread returns
      None for a missing or unreadable path rather than raising).
  """
  img = cv2.imread(img_path)
  if img is None:
    # Fail here with the offending path instead of an opaque error inside cvtColor.
    raise FileNotFoundError(f"Could not read image: {img_path}")
  return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

# Copy imgs
Select_an_image = "" # @param {type:"string"}

# PROCESS AD
if os.path.exists(Select_an_image.strip()):
    image_list = [Select_an_image.replace('/content/', '').strip()]

destination_directory = '/content/CodeFormer/inputs/user_upload'
!rm -rf /content/CodeFormer/inputs/user_upload/*
os.makedirs(destination_directory, exist_ok=True)
for image_path in image_list:
    image_filename = os.path.basename('/content/'+image_path)
    destination_path = os.path.join(destination_directory, image_filename)
    try:
        shutil.copyfile('/content/'+image_path, destination_path)
        print(f"Image '{image_filename}' has been copied to '{destination_path}'")
    except Exception as e:
        print(f"Failed to copy '{image_filename}' to '{destination_path}': {e}")

#@markdown `CODEFORMER_FIDELITY`: Balance the quality (lower number) and fidelity (higher number)<br>
# you can add '--bg_upsampler realesrgan' to enhance the background
CODEFORMER_FIDELITY = 0.7 #@param {type:"slider", min:0, max:1, step:0.01}
#@markdown `BACKGROUND_ENHANCE`: Enhance background image with Real-ESRGAN<br>
BACKGROUND_ENHANCE = True #@param {type:"boolean"}
#@markdown `FACE_UPSAMPLE`: Upsample restored faces for high-resolution AI-created images<br>
FACE_UPSAMPLE = False #@param {type:"boolean"}
#markdown `HAS_ALIGNED`: Input are cropped and aligned faces<br>
HAS_ALIGNED =  False
#@markdown `UPSCALE`: The final upsampling scale of the image. Default: 2<br>
UPSCALE = 3 #@param {type:"slider", min:2, max:8, step:1}
#markdown `DETECTION_MODEL`: Face detector. Default: retinaface_resnet50<br>
DETECTION_MODEL = "retinaface_resnet50"
#markdown `DRAW_BOX`: Draw the bounding box for the detected faces.
DRAW_BOX = False

BACKGROUND_ENHANCE = '--bg_upsampler realesrgan' if BACKGROUND_ENHANCE else ''
FACE_UPSAMPLE = '--face_upsample' if FACE_UPSAMPLE else ''
HAS_ALIGNED = '--has_aligned' if HAS_ALIGNED else ''
DRAW_BOX = '--draw_box' if DRAW_BOX else ''
%cd CodeFormer
!python inference_codeformer.py -w $CODEFORMER_FIDELITY --input_path {destination_directory} {BACKGROUND_ENHANCE} {FACE_UPSAMPLE} {HAS_ALIGNED} --upscale {UPSCALE} --detection_model {DETECTION_MODEL} {DRAW_BOX}


import os
import glob

input_folder = 'inputs/user_upload'
result_folder = f'results/user_upload_{CODEFORMER_FIDELITY}/final_results'
input_list = sorted(glob.glob(os.path.join(input_folder, '*')))
for input_path in input_list:
  img_input = imread(input_path)
  basename = os.path.splitext(os.path.basename(input_path))[0]
  output_path = os.path.join(result_folder, basename+'.png')
  img_output = imread(output_path)
  display_codeformer(img_input, img_output)

%cd /content

Results are saved in `/content/CodeFormer/results`.


In [None]:
#@title Download Images
import os
from google.colab import files
!rm /content/results.zip
!ls /content/images
print('Download results')
os.system(f'zip -r results.zip /content/images')
try:
  files.download("results.zip")
except:
  print("Error")

In [None]:
#@title Download Upscale results
import os
from google.colab import files
import shutil
%cd /content/CodeFormer
!ls results
print('Download results')
os.system(f'zip -r results.zip results/user_upload_{CODEFORMER_FIDELITY}/final_results')
try:
  files.download("results.zip")
except:
  files.download(f'/content/CodeFormer/results/{filename[:-4]}_{CODEFORMER_FIDELITY}/{filename}')
%cd /content

# Extras

In [None]:
# You can also use this cell to simply reload the model in case you need to.
# Removes the global `model` binding; raises NameError if `model` is not currently defined.
# NOTE(review): presumably the generation cell loads the model again when it is missing — confirm.
del model

In [None]:
import os
from PIL import Image

# Maps each upscaler name offered in the `MODEL_UPSCALER` form field to the download
# URL of its `.pth` weights file. The last URL segment is used as the local file name,
# so URL-encoded characters such as `%20` are part of that name.
upscaler_dict = {
    "RealESRGAN_x4plus" : "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth",
    "RealESRNet_x4plus" : "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/RealESRNet_x4plus.pth",
    "RealESRGAN_x4plus_anime_6B": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth",
    "RealESRGAN_x2plus": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth",
    "realesr-animevideov3": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-animevideov3.pth",
    "realesr-general-x4v3": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth",
    "realesr-general-wdn-x4v3" : "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-wdn-x4v3.pth",
    "4x-UltraSharp" : "https://huggingface.co/Shandypur/ESRGAN-4x-UltraSharp/resolve/main/4x-UltraSharp.pth",
    "4x_foolhardy_Remacri" : "https://huggingface.co/FacehugmanIII/4x_foolhardy_Remacri/resolve/main/4x_foolhardy_Remacri.pth",
    "Remacri4xExtraSmoother" : "https://huggingface.co/hollowstrawberry/upscalers-backup/resolve/main/ESRGAN/Remacri%204x%20ExtraSmoother.pth",
    "AnimeSharp4x" : "https://huggingface.co/hollowstrawberry/upscalers-backup/resolve/main/ESRGAN/AnimeSharp%204x.pth",
    "lollypop" : "https://huggingface.co/hollowstrawberry/upscalers-backup/resolve/main/ESRGAN/lollypop.pth",
    "RealisticRescaler4x" : "https://huggingface.co/hollowstrawberry/upscalers-backup/resolve/main/ESRGAN/RealisticRescaler%204x.pth",
    "NickelbackFS4x" : "https://huggingface.co/hollowstrawberry/upscalers-backup/resolve/main/ESRGAN/NickelbackFS%204x.pth"
}

#@markdown # Alternative Upscaler Tool
#@markdown You can leave `Select_an_image` blank to process the last generated images.
Select_an_image = "" # @param {type:"string"}
MODEL_UPSCALER = "RealESRGAN_x4plus_anime_6B" #@param ["RealESRGAN_x4plus", "RealESRNet_x4plus", "RealESRGAN_x2plus", "RealESRGAN_x4plus_anime_6B", "realesr-animevideov3", "realesr-general-x4v3", "realesr-general-wdn-x4v3", "4x-UltraSharp", "4x_foolhardy_Remacri", "Remacri4xExtraSmoother", "AnimeSharp4x", "lollypop", "RealisticRescaler4x", "NickelbackFS4x"]
Scale_of_the_image_x = 1.5 #@param {type:"slider", min:1, max:4, step:0.5}
show_result = True #@param {type: "boolean"}

# Folder that holds the downloaded upscaler weights.
directory_upscalers = 'upscalers'
os.makedirs(directory_upscalers, exist_ok=True)

# Download URL of the chosen upscaler and the local path of its weights file
# (the file name is the last segment of the URL).
url_upscaler = upscaler_dict[MODEL_UPSCALER]
upscaler_weights_path = f"./upscalers/{url_upscaler.split('/')[-1]}"

# Fetch the weights only when they are not on disk yet.
if not os.path.exists(upscaler_weights_path):
    download_things(directory_upscalers, url_upscaler, hf_token)

scaler_beta = UpscalerESRGAN("")

# An existing `Select_an_image` path replaces the image list; otherwise the
# `image_list` already present in the kernel is processed.
if os.path.exists(Select_an_image.strip()):
    image_list = [Select_an_image.replace('/content/', '').strip()]

for img_base in image_list:
    img_pil = Image.open(img_base)
    img_up = scaler_beta.upscale(img_pil, Scale_of_the_image_x , upscaler_weights_path)

    if show_result:
        display(img_up)

    # Save the upscaled image under <cwd>/up_images and report where it went.
    image_path = save_pil_image_with_metadata(img_up, f'{os.getcwd()}/up_images', metadata_list=None)
    print(image_path)

In [None]:
# @markdown CONVERT SAFETENSORS TO DIFFUSERS for SD 1.5
path_safetensor_model = "" # @param {type:"string"}
path_diffusers_model = "./converted_model/" # @param {type:"string"}

from diffusers import StableDiffusionPipeline

# Fail early with a readable message when the form field was left empty.
if not path_safetensor_model.strip():
    raise ValueError("Set `path_safetensor_model` to the path of the .safetensors file to convert.")

# Load the single-file checkpoint and write it back out in the diffusers folder layout.
# The form fields above are the inputs; the names previously used here
# (`input_safetensor_model`, `output_diffusers_model`) are not defined by this cell.
pipe = StableDiffusionPipeline.from_single_file(path_safetensor_model.strip()).to("cuda")
pipe.save_pretrained(path_diffusers_model)

# Release the pipeline and the GPU memory it held.
del pipe
torch.cuda.empty_cache()
gc.collect()

See also this utility for converting SDXL models: https://github.com/Linaqruf/sdxl-model-converter

In [None]:
from PIL import Image
import os

# Drop the loaded model (if there is one) and free cached GPU memory before Adetailer runs.
try:
  del model
except NameError:
  # No model is loaded in this session; there is nothing to delete.
  pass
# The cleanup runs in both cases, so it no longer needs to be duplicated in each branch.
torch.cuda.empty_cache()
gc.collect()



# OPTIONS #
# @markdown # Adetailer 🚫 beta
# @markdown work in progress for implement safetensors format
face_detector_ad = True # @param {type:"boolean"}
person_detector_ad = True # @param {type:"boolean"}
hand_detector_ad = False # @param {type:"boolean"}
path_diffusers_model = "./converted_model/" # @param {type:"string"}
# Two option dicts handed to `ad_model_process`; both start from the same prompt,
# step count and the current height/width widget values.
common = dict(
    prompt="masterpiece, best quality",
    num_inference_steps=40,
    height=img_height.value,
    width=img_width.value,
)
# Independent copy with identical contents.
inpaint = dict(common)
Select_an_image = "" # @param {type:"string"}

# PROCESS AD
# An existing `Select_an_image` path replaces the image list; otherwise the
# `image_list` already present in the kernel is processed.
if os.path.exists(Select_an_image):
    image_list = [Select_an_image]

# NOTE(review): `model_for_inpaint` is not defined in this cell and `path_diffusers_model`
# above is not referenced here — confirm which of the two is meant to be passed.
image_list = ad_model_process(
    face_detector_ad,
    person_detector_ad,
    hand_detector_ad,
    model_for_inpaint,
    common,
    inpaint,
    image_list,
)